Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions nemo_rl/models/policy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,11 @@ class MegatronConfig(TypedDict):
linear_ce_fusion_chunk_size: NotRequired[int]
# When mtp_num_layers=0, Multi-Token Prediction is disabled.
mtp_num_layers: NotRequired[int]
# When True, clear the RotaryEmbedding LRU cache and MoE token dispatcher
# routing tensors after offload_after_refit. Useful when training and logprob
# runs use different sequence lengths (rope cache) or for MoE models with
# activation recompute (dispatcher reference cycles).
clear_memory_caches_after_refit: NotRequired[bool]


class DraftConfigDisabled(TypedDict):
Expand Down
92 changes: 92 additions & 0 deletions nemo_rl/models/policy/workers/megatron_policy_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1234,6 +1234,41 @@ def prepare_for_training(self, *args, **kwargs):
if self.cfg["megatron_cfg"]["empty_unused_memory_level"] >= 1:
torch.cuda.empty_cache()

def _clear_fp8_caches(self):
"""Clear FP8 workspace caches and release fragmented GPU memory.

The main memory issue in the train→offload→refit→generate cycle is CUDA
allocator fragmentation, not leaked FP8 tensors. This method:
1. Clears TE workspace buffers (which hold references to scratch memory)
2. Runs gc.collect() + empty_cache() to return freed blocks to CUDA

For anti-fragmentation, configure PYTORCH_CUDA_ALLOC_CONF in the recipe YAML:
- "max_split_size_mb:512" — fast, prevents large-block splitting
- "expandable_segments:True" — most effective but ~5x slower weight transfer
"""
# 1. Clear Transformer Engine workspaces
workspace_count = 0
for module in self.model.modules():
if hasattr(module, "_fp8_workspaces"):
module._fp8_workspaces.clear()
workspace_count += 1

# 2. Clear global TE workspace
try:
import transformer_engine.pytorch as te

if hasattr(te, "module") and hasattr(te.module.base, "clear_workspace"):
te.module.base.clear_workspace()
except ImportError:
pass
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @ashors1, this might be an issue on my end, but I can't find this in TE, and I don't think this piece of code helps with memory savings. Maybe we can simply remove it?


print(
f"[_clear_fp8_caches] Cleared {workspace_count} workspace modules on rank {self.rank}"
)

gc.collect()
torch.cuda.empty_cache()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it be more efficient to run gc.collect() and torch.cuda.empty_cache() just once after all references are cleared? Calling them frequently might introduce unnecessary overhead to the training pipeline.


@wrap_with_nvtx_name("megatron_policy_worker/offload_before_refit")
def offload_before_refit(self):
"""Offload the optimizer and buffers to the CPU."""
Expand All @@ -1247,6 +1282,11 @@ def offload_before_refit(self):
self.model = self.move_model(
self.model, "cpu", move_params=False, move_grads=True
) # get rid of grad buffers

# Optionally clear FP8 workspace caches and release cached GPU memory
if self.fp8_cfg.get("force_clear_fp8_caches", False):
self._clear_fp8_caches()

torch.randn(1).cuda() # wake up torch allocator
if (
hasattr(self, "optimizer")
Expand Down Expand Up @@ -1276,6 +1316,58 @@ def offload_after_refit(self):
torch.randn(1).cuda() # wake up torch allocator
self.offload_before_refit() # rerun the old offload function

if self.cfg["megatron_cfg"].get("clear_memory_caches_after_refit", False):
# Clear RotaryEmbedding's @lru_cache(maxsize=32). The cache accumulates one
# entry per unique (max_seq_len, offset, packed_seq) seen, and each entry is
# a GPU tensor (the concatenated sin/cos embedding). With training + logprob
# runs at different sequence lengths, the cache fills quickly and the tensors
# anchor large CUDA segments.
try:
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding
RotaryEmbedding.forward.cache_clear()
except Exception:
pass

# Clear MoE token dispatcher persistent routing tensors.
#
# MoETokenDispatcher is a plain Python class (NOT an nn.Module), so iterating
# self.model.modules() never yields it. We must access it via the token_dispatcher
# attribute on MoELayer nn.Module objects.
#
# When recompute_mlp=True and fp8=True,
# transformer_layer._forward_mlp wraps self.mlp (the MoE layer) with te_checkpoint.
# te_checkpoint._CheckpointFunction.backward recomputes the forward with
# torch.enable_grad(), which causes dispatch_preprocess to store
# dispatcher.probs = routing_probs (with grad_fn, under enable_grad)
# This creates a reference cycle:
# _CheckpointFunctionBackward → ctx → ctx.run_function=mlp
# → mlp.token_dispatcher.probs → probs.grad_fn → ... → _CheckpointFunctionBackward
#
# Breaking this cycle by nulling dispatcher.probs frees BOTH:
# - the routing tensors
# - the te_checkpoint ctx saved tensors
try:
for module in self.model.modules():
if not hasattr(module, "token_dispatcher"):
continue
dispatcher = module.token_dispatcher
if dispatcher is None:
continue
for attr in (
"probs", # AllToAll + AllGather
"routing_map", # AllToAll
"reversed_local_input_permutation_mapping", # AllToAll
"local_probs", # AllGather
"local_map", # AllGather
):
if isinstance(getattr(dispatcher, attr, None), torch.Tensor):
setattr(dispatcher, attr, None)
except Exception:
pass

gc.collect()
torch.cuda.empty_cache()
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above:
Would it be more efficient to run gc.collect() and torch.cuda.empty_cache() just once after all references are cleared? Calling them frequently might introduce unnecessary overhead to the training pipeline.


allocated = torch.cuda.memory_allocated() / (1024**3) # Convert to GB
reserved = torch.cuda.memory_reserved() / (1024**3) # Convert to GB
print(
Expand Down
Loading