Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions nemo_rl/models/policy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,25 @@ class MegatronDDPConfig(TypedDict):
data_parallel_sharding_strategy: str


class Fp8Config(TypedDict):
# Master switch for FP8 training. When False, all other fields are ignored.
enabled: bool
# FP8 format used for the GEMMs (e.g. "e4m3").
fp8: NotRequired[str]
# FP8 scaling recipe (e.g. "blockwise").
fp8_recipe: NotRequired[str]
# When True, keep parameters in FP8. Can cause NaN token_mult_prob_error;
# use with caution (see https://github.com/NVIDIA-NeMo/RL/issues/1164).
fp8_param: NotRequired[bool]
# When True, clear Transformer Engine's per-module _fp8_workspaces scratch
# buffers in offload_before_refit (before weight transfer to the inference
# engine). These FP8 workspace tensors anchor large CUDA segments and
# aggravate allocator fragmentation across the train->offload->refit->generate
# cycle. Useful for FP8 training runs that observe growing reserved GPU memory
# after offload.
force_clear_fp8_caches: NotRequired[bool]


# Type exists to be lax if not specified
class MegatronConfigDisabled(TypedDict):
enabled: Literal[False]
Expand Down Expand Up @@ -277,6 +296,14 @@ class MegatronConfig(TypedDict):
linear_ce_fusion_chunk_size: NotRequired[int]
# When mtp_num_layers=0, Multi-Token Prediction is disabled.
mtp_num_layers: NotRequired[int]
# When True, clear the RotaryEmbedding LRU cache and MoE token dispatcher
# routing tensors in offload_before_refit (before weight transfer to the
# inference engine). Useful when training and logprob runs use different
# sequence lengths (rope cache) or for MoE models with activation recompute
# (dispatcher reference cycles).
clear_memory_caches_before_refit: NotRequired[bool]
# FP8 quantization settings for the Megatron training backend.
fp8_cfg: NotRequired[Fp8Config]


class DraftConfigDisabled(TypedDict):
Expand Down
81 changes: 81 additions & 0 deletions nemo_rl/models/policy/workers/megatron_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1237,6 +1237,31 @@ def prepare_for_training(self, *args, **kwargs):
if self.cfg["megatron_cfg"]["empty_unused_memory_level"] >= 1:
torch.cuda.empty_cache()

def _clear_fp8_caches(self):
"""Clear FP8 workspace caches and release fragmented GPU memory.

The main memory issue in the train→offload→refit→generate cycle is CUDA
allocator fragmentation, not leaked FP8 tensors. This method clears
per-module _fp8_workspaces buffers (scratch memory references). The
caller is responsible for running gc.collect() + empty_cache() once
all references have been dropped.

For anti-fragmentation, configure PYTORCH_CUDA_ALLOC_CONF in the recipe YAML:
- "max_split_size_mb:512" — fast, prevents large-block splitting
- "expandable_segments:True" — most effective but ~5x slower weight transfer
"""
# 1. Clear Transformer Engine workspaces
workspace_count = 0
for module in self.model.modules():
if hasattr(module, "_fp8_workspaces"):
module._fp8_workspaces.clear()
workspace_count += 1

print(
f"[_clear_fp8_caches] Cleared {workspace_count} workspace modules on rank {self.rank}"
)


@wrap_with_nvtx_name("megatron_policy_worker/offload_before_refit")
def offload_before_refit(self):
"""Offload the optimizer and buffers to the CPU."""
Expand All @@ -1250,6 +1275,62 @@ def offload_before_refit(self):
self.model = self.move_model(
self.model, "cpu", move_params=False, move_grads=True
) # get rid of grad buffers

# When True, clear Transformer Engine's per-module _fp8_workspaces scratch
# buffers in offload_before_refit (before weight transfer to the inference
# engine).
if self.fp8_cfg and self.fp8_cfg.get("force_clear_fp8_caches", False):
self._clear_fp8_caches()

if self.cfg["megatron_cfg"].get("clear_memory_caches_before_refit", False):
# Clear RotaryEmbedding's @lru_cache(maxsize=32). The cache accumulates one
# entry per unique (max_seq_len, offset, packed_seq) seen, and each entry is
# a GPU tensor (the concatenated sin/cos embedding). With training + logprob
# runs at different sequence lengths, the cache fills quickly and the tensors
# anchor large CUDA segments.
try:
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
RotaryEmbedding.forward.cache_clear()
except Exception:
pass

# Clear MoE token dispatcher persistent routing tensors.
#
# MoETokenDispatcher is a plain Python class (NOT an nn.Module), so iterating
# self.model.modules() never yields it. We must access it via the token_dispatcher
# attribute on MoELayer nn.Module objects.
#
# When recompute_mlp=True and fp8=True,
# transformer_layer._forward_mlp wraps self.mlp (the MoE layer) with te_checkpoint.
# te_checkpoint._CheckpointFunction.backward recomputes the forward with
# torch.enable_grad(), which causes dispatch_preprocess to store
# dispatcher.probs = routing_probs (with grad_fn, under enable_grad)
# This creates a reference cycle:
# _CheckpointFunctionBackward → ctx → ctx.run_function=mlp
# → mlp.token_dispatcher.probs → probs.grad_fn → ... → _CheckpointFunctionBackward
#
# Breaking this cycle by nulling dispatcher.probs frees BOTH:
# - the routing tensors
# - the te_checkpoint ctx saved tensors
try:
for module in self.model.modules():
if not hasattr(module, "token_dispatcher"):
continue
dispatcher = module.token_dispatcher
if dispatcher is None:
continue
for attr in (
"probs", # AllToAll + AllGather
"routing_map", # AllToAll
"reversed_local_input_permutation_mapping", # AllToAll
"local_probs", # AllGather
"local_map", # AllGather
):
if isinstance(getattr(dispatcher, attr, None), torch.Tensor):
setattr(dispatcher, attr, None)
except Exception:
pass

torch.randn(1).cuda() # wake up torch allocator
if (
hasattr(self, "optimizer")
Expand Down
Loading