diff --git a/nemo_rl/models/policy/__init__.py b/nemo_rl/models/policy/__init__.py
index d3e82457bc..7feecdf7f6 100644
--- a/nemo_rl/models/policy/__init__.py
+++ b/nemo_rl/models/policy/__init__.py
@@ -194,6 +194,25 @@ class MegatronDDPConfig(TypedDict):
     data_parallel_sharding_strategy: str
 
 
+class Fp8Config(TypedDict):
+    # Master switch for FP8 training. When False, all other fields are ignored.
+    enabled: bool
+    # FP8 format used for the GEMMs (e.g. "e4m3").
+    fp8: NotRequired[str]
+    # FP8 scaling recipe (e.g. "blockwise").
+    fp8_recipe: NotRequired[str]
+    # When True, keep parameters in FP8. Can cause NaN token_mult_prob_error;
+    # use with caution (see https://github.com/NVIDIA-NeMo/RL/issues/1164).
+    fp8_param: NotRequired[bool]
+    # When True, clear Transformer Engine's per-module _fp8_workspaces scratch
+    # buffers in offload_before_refit (before weight transfer to the inference
+    # engine). These FP8 workspace tensors anchor large CUDA segments and
+    # aggravate allocator fragmentation across the train->offload->refit->generate
+    # cycle. Useful for FP8 training runs that observe growing reserved GPU memory
+    # after offload.
+    force_clear_fp8_caches: NotRequired[bool]
+
+
 # Type exists to be lax if not specified
 class MegatronConfigDisabled(TypedDict):
     enabled: Literal[False]
@@ -277,6 +296,14 @@ class MegatronConfig(TypedDict):
     linear_ce_fusion_chunk_size: NotRequired[int]
     # When mtp_num_layers=0, Multi-Token Prediction is disabled.
     mtp_num_layers: NotRequired[int]
+    # When True, clear the RotaryEmbedding LRU cache and MoE token dispatcher
+    # routing tensors in offload_before_refit (before weight transfer to the
+    # inference engine). Useful when training and logprob runs use different
+    # sequence lengths (rope cache) or for MoE models with activation recompute
+    # (dispatcher reference cycles).
+    clear_memory_caches_before_refit: NotRequired[bool]
+    # FP8 quantization settings for the Megatron training backend.
+    fp8_cfg: NotRequired[Fp8Config]
 
 
 class DraftConfigDisabled(TypedDict):
diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py
index 0395d2e0c9..fb83c524e2 100644
--- a/nemo_rl/models/policy/workers/megatron_policy_worker.py
+++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py
@@ -1237,6 +1237,31 @@ def prepare_for_training(self, *args, **kwargs):
         if self.cfg["megatron_cfg"]["empty_unused_memory_level"] >= 1:
             torch.cuda.empty_cache()
 
+    def _clear_fp8_caches(self):
+        """Drop Transformer Engine's per-module FP8 workspace buffers.
+
+        The main memory issue in the train→offload→refit→generate cycle is CUDA
+        allocator fragmentation, not leaked FP8 tensors. This method only clears
+        the per-module _fp8_workspaces containers (scratch memory references); it
+        does not free CUDA memory itself. The caller is responsible for running
+        gc.collect() + empty_cache() once all references have been dropped.
+
+        For anti-fragmentation, configure PYTORCH_CUDA_ALLOC_CONF in the recipe YAML:
+        - "max_split_size_mb:512" — fast, prevents large-block splitting
+        - "expandable_segments:True" — most effective but ~5x slower weight transfer
+        """
+        # Clear the workspaces on every submodule that carries them.
+        workspace_count = 0
+        for module in self.model.modules():
+            if hasattr(module, "_fp8_workspaces"):
+                module._fp8_workspaces.clear()
+                workspace_count += 1
+
+        print(
+            f"[_clear_fp8_caches] Cleared {workspace_count} workspace modules on rank {self.rank}"
+        )
+
+
     @wrap_with_nvtx_name("megatron_policy_worker/offload_before_refit")
     def offload_before_refit(self):
         """Offload the optimizer and buffers to the CPU."""
@@ -1250,6 +1275,68 @@ def offload_before_refit(self):
         self.model = self.move_model(
             self.model, "cpu", move_params=False, move_grads=True
         )  # get rid of grad buffers
+
+        # Opt-in (fp8_cfg.force_clear_fp8_caches): drop Transformer Engine's
+        # per-module _fp8_workspaces scratch buffers before weights are
+        # transferred to the inference engine.
+        if self.fp8_cfg and self.fp8_cfg.get("force_clear_fp8_caches", False):
+            self._clear_fp8_caches()
+
+        if self.cfg["megatron_cfg"].get("clear_memory_caches_before_refit", False):
+            # Clear RotaryEmbedding's @lru_cache(maxsize=32). The cache accumulates one
+            # entry per unique (max_seq_len, offset, packed_seq) seen, and each entry is
+            # a GPU tensor (the concatenated sin/cos embedding). With training + logprob
+            # runs at different sequence lengths, the cache fills quickly and the tensors
+            # anchor large CUDA segments.
+            try:
+                from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
+                RotaryEmbedding.forward.cache_clear()
+            except Exception as exc:
+                # Best-effort cleanup: never fail the refit, but report the failure.
+                print(
+                    f"[offload_before_refit] Could not clear RotaryEmbedding cache: {exc!r}"
+                )
+
+            # Clear MoE token dispatcher persistent routing tensors.
+            #
+            # MoETokenDispatcher is a plain Python class (NOT an nn.Module), so iterating
+            # self.model.modules() never yields it. We must access it via the token_dispatcher
+            # attribute on MoELayer nn.Module objects.
+            #
+            # When recompute_mlp=True and fp8=True,
+            # transformer_layer._forward_mlp wraps self.mlp (the MoE layer) with te_checkpoint.
+            # te_checkpoint._CheckpointFunction.backward recomputes the forward with
+            # torch.enable_grad(), which causes dispatch_preprocess to store
+            # dispatcher.probs = routing_probs (with grad_fn, under enable_grad)
+            # This creates a reference cycle:
+            # _CheckpointFunctionBackward → ctx → ctx.run_function=mlp
+            # → mlp.token_dispatcher.probs → probs.grad_fn → ... → _CheckpointFunctionBackward
+            #
+            # Breaking this cycle by nulling dispatcher.probs frees BOTH:
+            # - the routing tensors
+            # - the te_checkpoint ctx saved tensors
+            try:
+                for module in self.model.modules():
+                    if not hasattr(module, "token_dispatcher"):
+                        continue
+                    dispatcher = module.token_dispatcher
+                    if dispatcher is None:
+                        continue
+                    for attr in (
+                        "probs",  # AllToAll + AllGather
+                        "routing_map",  # AllToAll
+                        "reversed_local_input_permutation_mapping",  # AllToAll
+                        "local_probs",  # AllGather
+                        "local_map",  # AllGather
+                    ):
+                        if isinstance(getattr(dispatcher, attr, None), torch.Tensor):
+                            setattr(dispatcher, attr, None)
+            except Exception as exc:
+                # Best-effort cleanup: never fail the refit, but report the failure.
+                print(
+                    f"[offload_before_refit] Could not clear MoE dispatcher tensors: {exc!r}"
+                )
+
         torch.randn(1).cuda()  # wake up torch allocator
         if (
             hasattr(self, "optimizer")