-
Notifications
You must be signed in to change notification settings - Fork 416
fix: Fix fp8 memory fragmentation #2670
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 5 commits
ea5beff
8a7da38
37294c4
10469c3
97efa3a
9e8f799
468ee20
ec41a13
378c101
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1234,6 +1234,41 @@ def prepare_for_training(self, *args, **kwargs): | |
| if self.cfg["megatron_cfg"]["empty_unused_memory_level"] >= 1: | ||
| torch.cuda.empty_cache() | ||
|
|
||
| def _clear_fp8_caches(self): | ||
| """Clear FP8 workspace caches and release fragmented GPU memory. | ||
|
|
||
| The main memory issue in the train→offload→refit→generate cycle is CUDA | ||
| allocator fragmentation, not leaked FP8 tensors. This method: | ||
| 1. Clears TE workspace buffers (which hold references to scratch memory) | ||
| 2. Runs gc.collect() + empty_cache() to return freed blocks to CUDA | ||
|
|
||
| For anti-fragmentation, configure PYTORCH_CUDA_ALLOC_CONF in the recipe YAML: | ||
| - "max_split_size_mb:512" — fast, prevents large-block splitting | ||
| - "expandable_segments:True" — most effective but ~5x slower weight transfer | ||
| """ | ||
| # 1. Clear Transformer Engine workspaces | ||
| workspace_count = 0 | ||
| for module in self.model.modules(): | ||
| if hasattr(module, "_fp8_workspaces"): | ||
| module._fp8_workspaces.clear() | ||
| workspace_count += 1 | ||
|
|
||
| # 2. Clear global TE workspace | ||
| try: | ||
| import transformer_engine.pytorch as te | ||
|
|
||
| if hasattr(te, "module") and hasattr(te.module.base, "clear_workspace"): | ||
| te.module.base.clear_workspace() | ||
| except ImportError: | ||
| pass | ||
|
|
||
| print( | ||
| f"[_clear_fp8_caches] Cleared {workspace_count} workspace modules on rank {self.rank}" | ||
| ) | ||
|
|
||
| gc.collect() | ||
| torch.cuda.empty_cache() | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would it be more efficient to run |
||
|
|
||
| @wrap_with_nvtx_name("megatron_policy_worker/offload_before_refit") | ||
| def offload_before_refit(self): | ||
| """Offload the optimizer and buffers to the CPU.""" | ||
|
|
@@ -1247,6 +1282,11 @@ def offload_before_refit(self): | |
| self.model = self.move_model( | ||
| self.model, "cpu", move_params=False, move_grads=True | ||
| ) # get rid of grad buffers | ||
|
|
||
| # Clear FP8 caches (uint8/int16 tensors) to CPU | ||
| if self.fp8_cfg.get("force_clear_fp8_caches", False): | ||
| self._clear_fp8_caches() | ||
|
|
||
| torch.randn(1).cuda() # wake up torch allocator | ||
| if ( | ||
| hasattr(self, "optimizer") | ||
|
|
@@ -1276,6 +1316,58 @@ def offload_after_refit(self): | |
| torch.randn(1).cuda() # wake up torch allocator | ||
| self.offload_before_refit() # rerun the old offload function | ||
|
|
||
| if self.cfg["megatron_cfg"].get("clear_memory_caches_after_refit", False): | ||
| # Clear RotaryEmbedding's @lru_cache(maxsize=32). The cache accumulates one | ||
| # entry per unique (max_seq_len, offset, packed_seq) seen, and each entry is | ||
| # a GPU tensor (the concatenated sin/cos embedding). With training + logprob | ||
| # runs at different sequence lengths, the cache fills quickly and the tensors | ||
| # anchor large CUDA segments. | ||
| try: | ||
| from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding | ||
| RotaryEmbedding.forward.cache_clear() | ||
| except Exception: | ||
| pass | ||
|
|
||
| # Clear MoE token dispatcher persistent routing tensors. | ||
| # | ||
| # MoETokenDispatcher is a plain Python class (NOT an nn.Module), so iterating | ||
| # self.model.modules() never yields it. We must access it via the token_dispatcher | ||
| # attribute on MoELayer nn.Module objects. | ||
| # | ||
| # When recompute_mlp=True and fp8=True, | ||
| # transformer_layer._forward_mlp wraps self.mlp (the MoE layer) with te_checkpoint. | ||
| # te_checkpoint._CheckpointFunction.backward recomputes the forward with | ||
| # torch.enable_grad(), which causes dispatch_preprocess to store | ||
| # dispatcher.probs = routing_probs (with grad_fn, under enable_grad) | ||
| # This creates a reference cycle: | ||
| # _CheckpointFunctionBackward → ctx → ctx.run_function=mlp | ||
| # → mlp.token_dispatcher.probs → probs.grad_fn → ... → _CheckpointFunctionBackward | ||
| # | ||
| # Breaking this cycle by nulling dispatcher.probs frees BOTH: | ||
| # - the routing tensors | ||
| # - the te_checkpoint ctx saved tensors | ||
| try: | ||
| for module in self.model.modules(): | ||
| if not hasattr(module, "token_dispatcher"): | ||
| continue | ||
| dispatcher = module.token_dispatcher | ||
| if dispatcher is None: | ||
| continue | ||
| for attr in ( | ||
| "probs", # AllToAll + AllGather | ||
| "routing_map", # AllToAll | ||
| "reversed_local_input_permutation_mapping", # AllToAll | ||
| "local_probs", # AllGather | ||
| "local_map", # AllGather | ||
| ): | ||
| if isinstance(getattr(dispatcher, attr, None), torch.Tensor): | ||
| setattr(dispatcher, attr, None) | ||
| except Exception: | ||
| pass | ||
|
|
||
| gc.collect() | ||
| torch.cuda.empty_cache() | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same above: |
||
|
|
||
| allocated = torch.cuda.memory_allocated() / (1024**3) # Convert to GB | ||
| reserved = torch.cuda.memory_reserved() / (1024**3) # Convert to GB | ||
| print( | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi, @ashors1 might be my issue. I can't find this in TE and I don't think this piece code is helpful to the memory saving, maybe we can simply remove it?