From be3e5633536978e58267acf850caa869c6d8c45e Mon Sep 17 00:00:00 2001 From: Akash Mehra Date: Wed, 3 Jun 2026 22:47:10 -0700 Subject: [PATCH 1/7] feat(megatron): split-API train-step state machine on MegatronPolicyWorker MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds begin_train_step / train_microbatch / finish_train_step / abort_train_step on MegatronPolicyWorkerImpl, mirroring the DTensor v1/v2 implementations but adapted for mcore's contiguous grad bucket + pipeline-schedule reduce path. Mechanism: - begin_train_step: zero_grad_buffer + optimizer.zero_grad, store loss_fn / gbs / mbs / local_valid_seqs/toks accumulators on _train_step_state, and null model.config.grad_sync_func (saved for restore) so the PP scheduler's direct reduce dispatch cannot bypass no_sync. - train_microbatch(data): wrap one ``megatron_forward_backward`` invocation in ``with self.model.no_sync():`` so mcore DDP hooks accumulate ``param.main_grad`` locally without dispatching the cross-DP reduce. Pass ``global_valid_seqs/toks=tensor(1.0)`` so the loss returns un-normalized sums; backward deposits raw d(sum)/dθ. Accumulate local mask sums + per-mb metrics + the total pipeline-microbatch count (for finish-time MoE aux-loss scaling). - finish_train_step: all_reduce mask sums to get true N (toks for TOKEN_LEVEL loss, seqs for SEQUENCE_LEVEL), call self.model.scale_gradients(1/N), then the one true cross-DP reduce via start_grad_sync + finish_grad_sync, optimizer.step (clips internally), restore grad_sync_func, scheduler.step(increment=gbs). Rescale per-mb metrics by 1/N (linear-in-1/N math), aggregate, surface global counts. - abort_train_step: restore grad_sync_func, zero_grad_buffer + zero_grad, drop state. ``trainer_version`` unchanged. Sync ``train()`` is left untouched. Includes CPU unit tests at tests/unit/models/policy/test_megatron_split_state.py covering the lifecycle and call-order invariants (no_sync wrap, grad_sync_func save/restore, mask-sum accumulation, N selection by loss_type, abort idempotence, MoE scaling). Marked pytest.mark.mcore so they run only in mcore-enabled CI containers. Co-Authored-By: Claude Opus 4.7 Signed-off-by: Akash Mehra --- nemo_rl/data_plane/worker_mixin.py | 79 +++ nemo_rl/models/policy/tq_policy.py | 125 ++++ .../policy/workers/megatron_policy_worker.py | 358 ++++++++++++ .../policy/test_megatron_split_state.py | 541 ++++++++++++++++++ 4 files changed, 1103 insertions(+) create mode 100644 tests/unit/models/policy/test_megatron_split_state.py diff --git a/nemo_rl/data_plane/worker_mixin.py b/nemo_rl/data_plane/worker_mixin.py index bd558e0f06..d6c0294765 100644 --- a/nemo_rl/data_plane/worker_mixin.py +++ b/nemo_rl/data_plane/worker_mixin.py @@ -505,3 +505,82 @@ def get_reference_policy_logprobs_presharded( tq_field="reference_policy_logprobs", ) del result + + # ── split-API entrypoints (SC async path) ────────────────────────────── + # + # The split path lets SingleController drive forward/backward per + # microbatch (or per pipeline-batch on Megatron) without stepping the + # optimizer until a full logical batch has accumulated. Backend + # methods (``begin_train_step``, ``train_microbatch``, + # ``finish_train_step``, ``abort_train_step``) own the train-step + # state machine; this mixin just gates them on TQ-presharded data. + + @wrap_with_nvtx_name("policy_worker/begin_train_step_presharded") + def begin_train_step_presharded( + self, + step_id: str, + loss_fn: Any, + gbs: Optional[int] = None, + mbs: Optional[int] = None, + ) -> None: + """Open a logical train step. No fetch — pure lifecycle. + + The backend stores ``step_id`` / ``loss_fn`` / ``gbs`` / ``mbs``, + clears gradients, and initialises accumulators for + ``local_valid_seqs`` / ``local_valid_toks`` and any per-microbatch + metrics. Optimizer state is untouched here. + """ + self.begin_train_step( # type: ignore[attr-defined] + step_id=step_id, + loss_fn=loss_fn, + gbs=gbs, + mbs=mbs, + ) + + @wrap_with_nvtx_name("policy_worker/train_microbatch_presharded") + def train_microbatch_presharded( + self, + step_id: str, + meta: "KVBatchMeta", + ) -> dict[str, Any]: + """Per-rank microbatch entrypoint. Fetch → packing prep → forward+backward. + + Gradients accumulate into ``.grad`` across calls; no + ``optimizer.step`` here. Returns per-microbatch metrics (loss, + local_valid_*); the backend folds them into the step accumulator + and the caller may surface them for diagnostics. + """ + data = self._fetch(meta) + data = self._attach_or_repack_pack_metadata(data, meta) + return self.train_microbatch( # type: ignore[attr-defined] + step_id=step_id, + data=data, + ) + + @wrap_with_nvtx_name("policy_worker/finish_train_step_presharded") + def finish_train_step_presharded( + self, + step_id: str, + ) -> dict[str, Any]: + """Close a logical train step. No fetch — pure lifecycle. + + Backend all-reduces accumulated ``local_valid_seqs/toks``, + rescales gradients to the final global normalization, runs grad + clip, steps the optimizer + scheduler, then zeros gradients. + Returns the aggregated step result (``loss``, ``grad_norm``, + ``all_mb_metrics``, …). + """ + return self.finish_train_step(step_id=step_id) # type: ignore[attr-defined] + + @wrap_with_nvtx_name("policy_worker/abort_train_step_presharded") + def abort_train_step_presharded( + self, + step_id: str, + ) -> None: + """Discard partial train-step state without stepping the optimizer. + + Used when SC decides the logical batch will not complete (e.g. + weight-sync triggered mid-step). Backend drops accumulators and + zeros gradients. + """ + self.abort_train_step(step_id=step_id) # type: ignore[attr-defined] diff --git a/nemo_rl/models/policy/tq_policy.py b/nemo_rl/models/policy/tq_policy.py index 1179bd8a1f..ecdcfd588a 100644 --- a/nemo_rl/models/policy/tq_policy.py +++ b/nemo_rl/models/policy/tq_policy.py @@ -438,3 +438,128 @@ def train_from_meta( warnings.warn(f"Error getting theoretical flops: {e}") return aggregated_results + + # ── split-API fanout (SC async path) ─────────────────────────────────── + # + # Counterpart to :meth:`train_from_meta`, exposed to ``PolicyTrainerActor`` + # so :class:`SingleControllerActor` can stream microbatches without + # forcing a full-step optimizer.step on every dispatch. + # + # Lifecycle: + # begin_train_step — open step; broadcast loss_fn/gbs/mbs + # train_microbatch_from_meta (N×) — DP-sharded fwd/bwd, grads accumulate + # finish_train_step — all_reduce + opt.step + sched.step + # abort_train_step — drop accumulators, no opt.step + # + # ``train_from_meta`` is unchanged and remains the sync entrypoint. + + def begin_train_step( + self, + step_id: str, + loss_fn: LossFunction, + gbs: Optional[int] = None, + mbs: Optional[int] = None, + ) -> None: + """Open a logical train step on every worker.""" + batch_size = gbs or self.cfg["train_global_batch_size"] + micro_batch_size = mbs or self.cfg["train_micro_batch_size"] + if self.flops_tracker is not None: + self.flops_tracker.reset() + futures = self.worker_group.run_all_workers_single_data( + "begin_train_step_presharded", + step_id=step_id, + loss_fn=loss_fn, + gbs=batch_size, + mbs=micro_batch_size, + ) + self.worker_group.get_all_worker_results(futures) + + def train_microbatch_from_meta( + self, + step_id: str, + meta: KVBatchMeta, + timer: Optional[Timer] = None, + ) -> dict[str, Any]: + """Dispatch one microbatch (DP-sharded) into an open train step. + + Mirrors the sharding logic of :meth:`train_from_meta` but without + a logical-batch sizing constraint: this routes ``meta`` to DP + ranks and runs forward+backward; gradients accumulate in + ``.grad``. The optimizer step happens at :meth:`finish_train_step`. + """ + self._stamp_pad_seqlen(meta) + spa, dba = self._packing_args("train_mb_tokens") + train_meta = replace( + meta, + fields=list(DP_TRAIN_FIELDS), + task_name="train", + ) + with timer.time("policy_training/shard_meta") if timer else nullcontext(): + dp_metas, _ = shard_meta_for_dp( + train_meta, + dp_world=self.sharding_annotations.get_axis_size("data_parallel"), + batch_size=None, + sequence_packing_args=spa, + dynamic_batching_args=dba, + ) + + if self.flops_tracker is not None: + for m in dp_metas: + self.flops_tracker.track_batch(list(m.sequence_lengths or [])) + + with ( + timer.time("policy_training/submit_microbatch_futures") + if timer + else nullcontext() + ): + futures = self.worker_group.run_all_workers_sharded_data( + "train_microbatch_presharded", + meta=dp_metas, + in_sharded_axes=["data_parallel"], + replicate_on_axes=[ + "context_parallel", + "tensor_parallel", + "pipeline_parallel", + ], + output_is_replicated=[ + "context_parallel", + "tensor_parallel", + "pipeline_parallel", + ], + common_kwargs={"step_id": step_id}, + ) + results = self.worker_group.get_all_worker_results(futures) + # Per-microbatch metrics: pass through DP-rank-0 by convention, + # backend may aggregate later if needed. Surface as-is for now. + return results[0] if results else {} + + def finish_train_step(self, step_id: str) -> dict[str, Any]: + """Close an open train step: all_reduce, rescale, optimizer.step. + + Aggregates per-rank step results into the same shape as + :meth:`train_from_meta` so callers don't have to special-case + the split path. + """ + futures = self.worker_group.run_all_workers_single_data( + "finish_train_step_presharded", + step_id=step_id, + ) + results = self.worker_group.get_all_worker_results(futures) + aggregated_results = _aggregate_train_results(results) + + if self.flops_tracker is not None: + aggregated_results["total_flops"] = self.flops_tracker.total_flops + aggregated_results["num_ranks"] = self.worker_group.cluster.world_size() + + return aggregated_results + + def abort_train_step(self, step_id: str) -> None: + """Drop partial step state on every worker. No optimizer.step.""" + futures = self.worker_group.run_all_workers_single_data( + "abort_train_step_presharded", + step_id=step_id, + ) + self.worker_group.get_all_worker_results(futures) + + if self.flops_tracker is not None: + self.flops_tracker.reset() diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py index 0395d2e0c9..cdf3fe6212 100644 --- a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -528,6 +528,364 @@ def train( metrics["moe_metrics"] = moe_metrics return metrics + # ── split-API train-step state machine (SingleController async path) ── + # + # Mirrors the v1/v2 implementations, adapted for mcore. Key differences: + # + # 1. mcore DDP accumulates ``param.main_grad`` per backward and dispatches + # a cross-DP reduce when ``is_last_microbatch=True`` (one per + # ``forward_backward_func`` call). Naively chaining multiple + # ``forward_backward_func`` calls between ``optimizer.step()`` would + # over-count: each call's terminal reduce sums an already-reduced + # bucket again. We wrap every call in ``self.model.no_sync()`` so + # hooks accumulate locally only; one explicit ``start_grad_sync`` + + # ``finish_grad_sync`` at finish does the single true reduce. + # 2. PP>1: the pipeline scheduler invokes ``config.grad_sync_func`` + # directly on last-microbatch boundaries — this bypasses the + # ``no_sync`` gate. We null it for the duration of the step and + # restore at finish/abort. + # 3. Grad clip is bundled inside ``MegatronOptimizer.step()``; the 1/N + # rescale via ``self.model.scale_gradients(1/N)`` must run before + # ``optimizer.step()`` so the clip operates on the rescaled grad. + # 4. With ``calculate_per_token_loss=True`` + ``average_in_collective= + # False``, mcore's DDP sums (does not average) grads across DP, so + # no FSDP-style ``loss *= dp_size*cp_size`` cancellation is needed + # per microbatch. + + def _split_step_state_init( + self, + step_id: str, + loss_fn: LossFunction, + gbs: Optional[int], + mbs: Optional[int], + ) -> dict[str, Any]: + from nemo_rl.algorithms.loss.interfaces import LossType + + return { + "step_id": step_id, + "loss_fn": loss_fn, + "loss_type": getattr(loss_fn, "loss_type", LossType.TOKEN_LEVEL), + "gbs": gbs or self.cfg["train_global_batch_size"], + "mbs": mbs or self.cfg["train_micro_batch_size"], + "local_valid_seqs": torch.zeros((), dtype=torch.float64, device="cuda"), + "local_valid_toks": torch.zeros((), dtype=torch.float64, device="cuda"), + "all_mb_metrics": [], + "mb_losses": [], + "total_num_microbatches": 0, + # Saved across the step so we can restore at finish/abort. + "saved_grad_sync_func": None, + "no_sync_active": False, + } + + def _assert_step_open(self, step_id: str) -> dict[str, Any]: + state = getattr(self, "_train_step_state", None) + if state is None: + raise RuntimeError( + f"no train step open; begin_train_step({step_id!r}) must be called first" + ) + if state["step_id"] != step_id: + raise RuntimeError( + f"step_id mismatch: open step is {state['step_id']!r}, got {step_id!r}" + ) + return state + + @wrap_with_nvtx_name("megatron_policy_worker/begin_train_step") + def begin_train_step( + self, + step_id: str, + loss_fn: LossFunction, + gbs: Optional[int] = None, + mbs: Optional[int] = None, + ) -> None: + existing = getattr(self, "_train_step_state", None) + if existing is not None: + raise RuntimeError( + f"train step {existing['step_id']!r} is already open; " + f"call finish_train_step or abort_train_step before begin" + ) + # Match sync train() inference-state reset (line 332-340). + if hasattr(self.model, "inference_params"): + self.model.inference_params = None + for module in self.model.modules(): + if hasattr(module, "reset_inference_cache"): + module.reset_inference_cache() + if hasattr(module, "_inference_key_value_memory"): + module._inference_key_value_memory = None + + self.model.train() + self.model.zero_grad_buffer() + self.optimizer.zero_grad() + + state = self._split_step_state_init( + step_id=step_id, loss_fn=loss_fn, gbs=gbs, mbs=mbs + ) + + # Suppress the PP scheduler's direct ``grad_sync_func`` call (which + # bypasses ``no_sync``). Save the existing value so we can restore + # at finish/abort. PP=1's ``forward_backward_no_pipelining`` doesn't + # invoke this; nulling it is a no-op there. + model_config = self.model.config + state["saved_grad_sync_func"] = getattr(model_config, "grad_sync_func", None) + model_config.grad_sync_func = None + + self._train_step_state = state + + @wrap_with_nvtx_name("megatron_policy_worker/train_microbatch") + def train_microbatch( + self, + step_id: str, + data: BatchedDataDict[Any], + ) -> dict[str, Any]: + """One DP slice of data → one ``forward_backward_func`` invocation. + + Wrapped in ``self.model.no_sync()`` so the mcore DDP hooks + accumulate ``param.main_grad`` locally on each rank without + dispatching a per-call DP reduce. The single true reduce is done + explicitly in ``finish_train_step``. + """ + state = self._assert_step_open(step_id) + loss_fn = state["loss_fn"] + + # Accumulate local mask sums for the finish-time all_reduce. + # Inlined from process_global_batch (data.py:319-332) — we can't + # call process_global_batch directly because it eagerly all_reduces + # the local sums, which is exactly what we're trying to defer. + assert "sample_mask" in data, "sample_mask required on microbatch data" + sample_mask = data["sample_mask"] + call_local_seqs = torch.sum(sample_mask).to(torch.float64) + if "token_mask" in data: + token_mask = data["token_mask"] + call_local_toks = torch.sum( + token_mask[:, 1:] * sample_mask.unsqueeze(-1) + ).to(torch.float64) + else: + call_local_toks = call_local_seqs * data["input_ids"].shape[1] + + state["local_valid_seqs"] = state["local_valid_seqs"] + call_local_seqs + state["local_valid_toks"] = state["local_valid_toks"] + call_local_toks + + # Build the per-call iterator. Each ``train_microbatch_from_meta`` + # call carries one DP slice; the iterator subdivides into pipeline + # microbatches. + ( + data_iterator, + num_microbatches, + micro_batch_size, + seq_length, + padded_seq_length, + ) = get_microbatch_iterator( + data, + self.cfg, + state["mbs"], + straggler_timer=self.mcore_state.straggler_timer, + ) + state["total_num_microbatches"] += int(num_microbatches) + + loss_post_processor = LossPostProcessor( + loss_fn=loss_fn, + cfg=self.cfg, + num_microbatches=num_microbatches, + sampling_params=self.sampling_params, + draft_model=self.draft_model, + ) + + # Placeholder N=1: loss returns un-normalized sums. ``backward`` + # deposits raw ``d(sum)/dθ`` into ``param.main_grad`` via the DDP + # hooks. The 1/N rescale happens once at finish. + placeholder_n = torch.tensor(1.0, device="cuda") + + draft_enabled = "draft" in self.cfg and self.cfg["draft"]["enabled"] + + # The critical wrap: hooks fire (accumulate main_grad) but the + # per-call reduce dispatch is gated off. + with self.model.no_sync(): + rerun_state_machine = get_rerun_state_machine() + while rerun_state_machine.should_run_forward_backward(data_iterator): + losses_reduced = megatron_forward_backward( + model=self.model, + data_iterator=data_iterator, + num_microbatches=num_microbatches, + seq_length=padded_seq_length, + mbs=micro_batch_size, + post_processing_fn=loss_post_processor, + forward_only=False, + defer_fp32_logits=self.defer_fp32_logits, + global_valid_seqs=placeholder_n, + global_valid_toks=placeholder_n, + sampling_params=self.sampling_params, + straggler_timer=self.mcore_state.straggler_timer, + draft_model=self.draft_model, + enable_hidden_capture=draft_enabled, + use_linear_ce_fusion_loss=self.cfg["megatron_cfg"].get( + "use_linear_ce_fusion_loss", False + ), + ) + + if self.cfg["megatron_cfg"]["empty_unused_memory_level"] >= 1: + torch.cuda.empty_cache() + + # Collect per-mb metrics from the last PP stage; broadcast to all + # PP ranks so non-last-stage ranks have something to all_reduce + # against at finish. Metrics carry the N=1 placeholder for now — + # ``finish_train_step`` rescales by the true 1/N. + if parallel_state.is_pipeline_last_stage(ignore_virtual=True): + mb_metrics_collected = [] + for x in losses_reduced: + mb_metrics_collected.append(dict(x)) + else: + mb_metrics_collected = None + + mb_metrics_collected = broadcast_loss_metrics_from_last_stage( + mb_metrics_collected + ) + + for m in mb_metrics_collected: + state["all_mb_metrics"].append(m) + # ``loss`` key is the un-normalized per-mb scalar; collect for + # the global_loss aggregation at finish. + if "loss" in m: + state["mb_losses"].append(m["loss"]) + + return { + "local_valid_seqs_mb": float(call_local_seqs.item()), + "local_valid_toks_mb": float(call_local_toks.item()), + "num_pipeline_microbatches": int(num_microbatches), + } + + @wrap_with_nvtx_name("megatron_policy_worker/finish_train_step") + def finish_train_step(self, step_id: str) -> dict[str, Any]: + from nemo_rl.algorithms.loss.interfaces import LossType + + state = self._assert_step_open(step_id) + + # All-reduce accumulated mask sums across DP to recover true N. + to_reduce = torch.stack( + [state["local_valid_seqs"], state["local_valid_toks"]] + ).to(torch.float64) + torch.distributed.all_reduce( + to_reduce, group=parallel_state.get_data_parallel_group() + ) + global_valid_seqs = to_reduce[0] + global_valid_toks = to_reduce[1] + + if state["loss_type"] == LossType.TOKEN_LEVEL: + n_true = global_valid_toks + else: + n_true = global_valid_seqs + n_safe = n_true if n_true.item() > 0 else torch.tensor(1.0, device="cuda") + inv_n = float((1.0 / n_safe).item()) + + # Rescale all locally-accumulated gradients by 1/N. The reduce + # below sees the rescaled grads; for all_reduce the result is the + # global mean grad; for reduce_scatter (dist-opt) it's the shard. + # Either way, opt.step sees the right-normalized gradient. + self.model.scale_gradients(inv_n) + + # The ONE true cross-DP reduce for the entire step. + self.model.start_grad_sync() + self.model.finish_grad_sync() + + # opt.step clips internally (clip_grad config); operates on the + # already-rescaled grad. Returns (success, grad_norm, num_zeros). + update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step() + + pg_collection = get_pg_collection(self.model) + update_successful = logical_and_across_model_parallel_group( + update_successful, mp_group=pg_collection.mp + ) + grad_norm = reduce_max_stat_across_model_parallel_group( + grad_norm, mp_group=pg_collection.mp + ) + num_zeros_in_grad = reduce_max_stat_across_model_parallel_group( + num_zeros_in_grad, mp_group=pg_collection.mp + ) + + if self.cfg["megatron_cfg"]["empty_unused_memory_level"] >= 2: + torch.cuda.empty_cache() + + # Restore grad_sync_func before scheduler.step / further state. + self.model.config.grad_sync_func = state["saved_grad_sync_func"] + + # Scheduler increment matches sync path's ``increment=gbs``. + self.scheduler.step(increment=state["gbs"]) + + # Per-mb metrics were computed with N=1; rescale to match what the + # sync path produces. ``masked_mean`` is linear in 1/N so a single + # scalar multiply per metric recovers the normalized value. + rescaled_metrics: list[dict[str, Any]] = [] + curr_lr = self.scheduler.get_lr(self.optimizer.param_groups[0]) + curr_wd = self.scheduler.get_wd() + global_valid_seqs_f = float(global_valid_seqs.item()) + global_valid_toks_f = float(global_valid_toks.item()) + + for m in state["all_mb_metrics"]: + out: dict[str, Any] = {} + for k, v in m.items(): + if "_min" in k or "_max" in k: + out[k] = v + elif isinstance(v, torch.Tensor): + out[k] = v.detach() * inv_n + else: + out[k] = v * inv_n + out["lr"] = curr_lr + out["wd"] = curr_wd + out["global_valid_seqs"] = global_valid_seqs_f + out["global_valid_toks"] = global_valid_toks_f + rescaled_metrics.append(out) + + # Scale per-mb losses by 1/N and reduce per-call sums. + scaled_losses = [lv * inv_n for lv in state["mb_losses"]] + losses_to_aggregate = [torch.tensor(scaled_losses).sum().item()] + + mb_metrics, global_loss = aggregate_training_statistics( + all_mb_metrics=rescaled_metrics, + losses=losses_to_aggregate, + data_parallel_group=parallel_state.get_data_parallel_group(), + ) + + metrics = { + "global_loss": global_loss.cpu(), + "rank": torch.distributed.get_rank(), + "gpu_name": torch.cuda.get_device_name(), + "model_dtype": self.dtype, + "all_mb_metrics": mb_metrics, + "grad_norm": torch.tensor([grad_norm]), + } + + # MoE aux-loss metrics: same convention as sync train() — scale + # by the total pipeline-microbatch count accumulated across all + # train_microbatch calls. + model_config = getattr(self.model, "config", None) + num_moe_experts = getattr(model_config, "num_moe_experts", None) + if num_moe_experts is not None and num_moe_experts > 1: + moe_loss_scale = 1.0 / max(1, state["total_num_microbatches"]) + moe_metrics = get_moe_metrics( + loss_scale=moe_loss_scale, + per_layer_logging=self.cfg["megatron_cfg"]["moe_per_layer_logging"], + ) + if moe_metrics: + metrics["moe_metrics"] = moe_metrics + + self._train_step_state = None + return metrics + + @wrap_with_nvtx_name("megatron_policy_worker/abort_train_step") + def abort_train_step(self, step_id: str) -> None: + state = getattr(self, "_train_step_state", None) + if state is None: + return + if state["step_id"] != step_id: + raise RuntimeError( + f"abort_train_step({step_id!r}) does not match open step " + f"{state['step_id']!r}" + ) + # Restore grad_sync_func first so the model is back to a normal + # state before zero_grad_buffer touches anything. + self.model.config.grad_sync_func = state["saved_grad_sync_func"] + self.model.zero_grad_buffer() + self.optimizer.zero_grad() + self._train_step_state = None + @wrap_with_nvtx_name("megatron_policy_worker/get_logprobs") def get_logprobs( self, *, data: BatchedDataDict[Any], micro_batch_size: Optional[int] = None diff --git a/tests/unit/models/policy/test_megatron_split_state.py b/tests/unit/models/policy/test_megatron_split_state.py new file mode 100644 index 0000000000..faf266f88e --- /dev/null +++ b/tests/unit/models/policy/test_megatron_split_state.py @@ -0,0 +1,541 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""CPU state-machine tests for MegatronPolicyWorkerImpl's split-API. + +These tests cover the lifecycle and call-order invariants — they do NOT +exercise real distributed comms, the mcore scheduler, or the optimizer. +Numerical equivalence vs sync ``train()`` lives in the GPU parity tests. + +The bugs these catch: + - silent gradient over-counting if ``model.no_sync()`` is not wrapped + around ``megatron_forward_backward`` (the mcore DDP hooks would + dispatch a per-call reduce, ADDING to an already-reduced bucket). + - PP>1 pipeline-schedule bypass if ``model.config.grad_sync_func`` is + not nulled for the step's duration. + - ``trainer_version`` advancing on abort. + - ``zero_grad_buffer`` not called at begin (mcore's contiguous grad + buffer leaks stale grads otherwise). + - off-by-one in ``total_num_microbatches`` (used to scale MoE aux-loss). +""" + +from __future__ import annotations + +from contextlib import nullcontext +from unittest.mock import MagicMock, patch + +import pytest +import torch + +# Eagerly import the worker module so ``unittest.mock.patch`` can resolve +# attributes on it via ``getattr``. Without this the patch path +# ``nemo_rl.models.policy.workers.megatron_policy_worker.`` fails +# at ``getattr(workers, "megatron_policy_worker")``. +import nemo_rl.models.policy.workers.megatron_policy_worker # noqa: F401 + +pytestmark = pytest.mark.mcore + +# Module path of the worker under test +WORKER_MOD = "nemo_rl.models.policy.workers.megatron_policy_worker" + + +# ── Mock fabric ────────────────────────────────────────────────────────── + +def _make_mock_model(): + """A mcore-DDP-shaped mock: exposes the methods + attributes the + split-API touches, plus an ``inference_params`` attribute and a + ``modules()`` that yields nothing (so the inference-cache reset loop + is a no-op).""" + model = MagicMock() + model.config = MagicMock() + model.config.grad_sync_func = "ORIGINAL_GRAD_SYNC_FUNC" # sentinel + model.config.num_moe_experts = None # disable MoE branch + # no_sync() is a context manager — return a MagicMock that supports + # __enter__/__exit__ so the `with self.model.no_sync():` block works. + model.no_sync = MagicMock(return_value=MagicMock( + __enter__=MagicMock(return_value=None), + __exit__=MagicMock(return_value=False), + )) + model.modules = MagicMock(return_value=iter([])) + model.inference_params = None + model.parameters = MagicMock(return_value=iter([])) # no params for the rescale loop + return model + + +def _make_worker(loss_type): + """Construct a MegatronPolicyWorkerImpl instance with all heavy + attributes mocked. Bypasses __init__ via ``object.__new__``.""" + # Lazy import so the module-level mcore imports happen inside the + # mcore-marked test process. + from nemo_rl.models.policy.workers.megatron_policy_worker import ( + MegatronPolicyWorkerImpl, + ) + + w = object.__new__(MegatronPolicyWorkerImpl) + w.model = _make_mock_model() + w.optimizer = MagicMock() + # MegatronOptimizer.step returns (success, grad_norm, num_zeros) + w.optimizer.step.return_value = (True, 0.5, 0) + w.optimizer.param_groups = [{"lr": 1e-4, "weight_decay": 0.01}] + w.scheduler = MagicMock() + w.scheduler.get_lr.return_value = 1e-4 + w.scheduler.get_wd.return_value = 0.01 + w.mcore_state = MagicMock() + w.mcore_state.straggler_timer = None + w.cfg = { + "train_global_batch_size": 32, + "train_micro_batch_size": 4, + "megatron_cfg": { + "empty_unused_memory_level": 0, + "moe_per_layer_logging": False, + "use_linear_ce_fusion_loss": False, + }, + } + w.dp_size = 2 + w.cp_size = 1 + w.sampling_params = None + w.draft_model = None + w.defer_fp32_logits = False + w.dtype = torch.float32 + w._is_reward_model = False + + # Stash a loss_fn with the requested loss_type for tests that need one. + w._test_loss_fn = MagicMock(loss_type=loss_type) + return w + + +@pytest.fixture +def mock_module_symbols(): + """Patch every module-level symbol that the split-API methods call + into. Yields a dict of name → mock for assertions.""" + # Make `aggregate_training_statistics` return ({}, scalar) — what the + # finish path expects. + agg_ret = ({"loss": [0.0]}, torch.tensor(0.5)) + + patches = { + "megatron_forward_backward": [ + {"loss": 0.5, "global_valid_seqs": 8.0, "global_valid_toks": 256.0} + ], + "get_microbatch_iterator": (iter([]), 2, 4, 16, 16), # 2 pipeline mbs per call + "LossPostProcessor": MagicMock(), + "broadcast_loss_metrics_from_last_stage": lambda m: m, + "get_pg_collection": MagicMock(mp=MagicMock()), + "logical_and_across_model_parallel_group": lambda v, mp_group: v, + "reduce_max_stat_across_model_parallel_group": lambda v, mp_group: v, + "aggregate_training_statistics": agg_ret, + "get_moe_metrics": MagicMock(return_value={}), + } + + with patch(f"{WORKER_MOD}.megatron_forward_backward", + return_value=patches["megatron_forward_backward"]) as mfb, \ + patch(f"{WORKER_MOD}.get_microbatch_iterator", + return_value=patches["get_microbatch_iterator"]) as gmi, \ + patch(f"{WORKER_MOD}.LossPostProcessor", + return_value=patches["LossPostProcessor"]) as lpp, \ + patch(f"{WORKER_MOD}.broadcast_loss_metrics_from_last_stage", + side_effect=patches["broadcast_loss_metrics_from_last_stage"]) as bcast, \ + patch(f"{WORKER_MOD}.get_pg_collection", + return_value=patches["get_pg_collection"]) as gpgc, \ + patch(f"{WORKER_MOD}.logical_and_across_model_parallel_group", + side_effect=patches["logical_and_across_model_parallel_group"]) as land, \ + patch(f"{WORKER_MOD}.reduce_max_stat_across_model_parallel_group", + side_effect=patches["reduce_max_stat_across_model_parallel_group"]) as rmax, \ + patch(f"{WORKER_MOD}.aggregate_training_statistics", + return_value=patches["aggregate_training_statistics"]) as agg, \ + patch(f"{WORKER_MOD}.get_moe_metrics", + return_value={}) as moe, \ + patch(f"{WORKER_MOD}.get_rerun_state_machine") as grsm, \ + patch(f"{WORKER_MOD}.parallel_state") as pstate, \ + patch("torch.distributed.all_reduce") as ar, \ + patch("torch.cuda.empty_cache") as cec, \ + patch("torch.cuda.get_device_name", return_value="H100"), \ + patch("torch.distributed.get_rank", return_value=0): + + # rerun state machine: fire forward+backward once per train_microbatch + rsm = MagicMock() + rsm.should_run_forward_backward.side_effect = [True, False] * 100 + grsm.return_value = rsm + + # parallel_state mocks + pstate.is_pipeline_last_stage.return_value = True + pstate.get_data_parallel_group.return_value = MagicMock() + + yield { + "mfb": mfb, "gmi": gmi, "lpp": lpp, "bcast": bcast, + "gpgc": gpgc, "land": land, "rmax": rmax, "agg": agg, + "moe": moe, "grsm": grsm, "pstate": pstate, + "all_reduce": ar, "empty_cache": cec, + } + + +def _fake_batch(): + """A minimal BatchedDataDict-ish object the mask-sum block can read. + train_microbatch reads ``data["sample_mask"]``, ``data["token_mask"]``, + and (only as a fallback for the no-token-mask path) ``data["input_ids"]``.""" + # 8 samples, all valid (mask=1); 256 valid tokens each + sample_mask = torch.ones(8, dtype=torch.float32) + token_mask = torch.ones(8, 257, dtype=torch.float32) # token_mask[:, 1:] → 256 toks + input_ids = torch.zeros(8, 257, dtype=torch.long) + return {"sample_mask": sample_mask, "token_mask": token_mask, "input_ids": input_ids} + + +# ── BEGIN ──────────────────────────────────────────────────────────────── + +class TestBegin: + def test_opens_state(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) + w.begin_train_step("step-0", loss_fn=w._test_loss_fn, gbs=16, mbs=4) + assert w._train_step_state is not None + assert w._train_step_state["step_id"] == "step-0" + assert w._train_step_state["loss_type"] == LossType.TOKEN_LEVEL + assert w._train_step_state["gbs"] == 16 + assert w._train_step_state["mbs"] == 4 + assert w._train_step_state["total_num_microbatches"] == 0 + + def test_calls_zero_grad_and_zero_grad_buffer(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) + w.begin_train_step("step-0", loss_fn=w._test_loss_fn) + w.model.zero_grad_buffer.assert_called_once() + w.optimizer.zero_grad.assert_called_once() + w.model.train.assert_called_once() + + def test_saves_and_nulls_grad_sync_func(self, mock_module_symbols): + """The PP scheduler's direct reduce dispatch must be suppressed + for the duration of the step. Otherwise PP>1 silently corrupts + grads even when ``no_sync`` is set on the bucket groups.""" + from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) + assert w.model.config.grad_sync_func == "ORIGINAL_GRAD_SYNC_FUNC" + w.begin_train_step("step-0", loss_fn=w._test_loss_fn) + assert w.model.config.grad_sync_func is None + assert w._train_step_state["saved_grad_sync_func"] == "ORIGINAL_GRAD_SYNC_FUNC" + + def test_double_begin_raises(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) + w.begin_train_step("step-0", loss_fn=w._test_loss_fn) + with pytest.raises(RuntimeError, match="already open"): + w.begin_train_step("step-1", loss_fn=w._test_loss_fn) + + def test_uses_cfg_defaults_when_gbs_mbs_omitted(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) + w.begin_train_step("step-0", loss_fn=w._test_loss_fn) + assert w._train_step_state["gbs"] == w.cfg["train_global_batch_size"] + assert w._train_step_state["mbs"] == w.cfg["train_micro_batch_size"] + + +# ── _assert_step_open ──────────────────────────────────────────────────── + +class TestAssertStepOpen: + def test_raises_when_no_step_open(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) + with pytest.raises(RuntimeError, match="no train step open"): + w._assert_step_open("step-0") + + def test_raises_on_step_id_mismatch(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) + w.begin_train_step("step-correct", loss_fn=w._test_loss_fn) + with pytest.raises(RuntimeError, match="step_id mismatch"): + w._assert_step_open("step-WRONG") + + def test_train_microbatch_without_begin_raises(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) + with pytest.raises(RuntimeError, match="no train step open"): + w.train_microbatch("step-0", _fake_batch()) + + def test_finish_without_begin_raises(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) + with pytest.raises(RuntimeError, match="no train step open"): + w.finish_train_step("step-0") + + +# ── train_microbatch ───────────────────────────────────────────────────── + +class TestTrainMicrobatch: + def test_wraps_forward_backward_in_no_sync(self, mock_module_symbols): + """The single most important assertion in this file. Without the + no_sync wrap, mcore DDP dispatches a per-call cross-DP reduce on + the partially-accumulated buffer — silently corrupting grads.""" + from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) + w.begin_train_step("s0", loss_fn=w._test_loss_fn) + w.train_microbatch("s0", _fake_batch()) + # no_sync() must have been ENTERED (called as a context manager). + # MagicMock with __enter__/__exit__ records the __enter__ call. + ctx = w.model.no_sync.return_value + ctx.__enter__.assert_called() + ctx.__exit__.assert_called() + + def test_invokes_megatron_forward_backward_once(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) + w.begin_train_step("s0", loss_fn=w._test_loss_fn) + w.train_microbatch("s0", _fake_batch()) + assert mock_module_symbols["mfb"].call_count == 1 + + def test_passes_placeholder_n_one_to_loss(self, mock_module_symbols): + """The N=1 trick: loss must be called with global_valid_*=1 so it + returns un-normalized sums; finish does the 1/N rescale.""" + from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) + w.begin_train_step("s0", loss_fn=w._test_loss_fn) + w.train_microbatch("s0", _fake_batch()) + kwargs = mock_module_symbols["mfb"].call_args.kwargs + # placeholder_n is a tensor(1.0) + assert "global_valid_seqs" in kwargs + assert "global_valid_toks" in kwargs + assert float(kwargs["global_valid_seqs"].item()) == pytest.approx(1.0) + assert float(kwargs["global_valid_toks"].item()) == pytest.approx(1.0) + + def test_accumulates_mask_sums_across_calls(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) + w.begin_train_step("s0", loss_fn=w._test_loss_fn) + # _fake_batch has sample_mask sum = 8, token_mask*sample_mask sum = 8*256 = 2048 + w.train_microbatch("s0", _fake_batch()) + assert float(w._train_step_state["local_valid_seqs"].item()) == pytest.approx(8.0) + assert float(w._train_step_state["local_valid_toks"].item()) == pytest.approx(2048.0) + w.train_microbatch("s0", _fake_batch()) + assert float(w._train_step_state["local_valid_seqs"].item()) == pytest.approx(16.0) + assert float(w._train_step_state["local_valid_toks"].item()) == pytest.approx(4096.0) + + def test_total_num_microbatches_accumulates(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) + w.begin_train_step("s0", loss_fn=w._test_loss_fn) + # get_microbatch_iterator mock returns num_microbatches=2 per call + w.train_microbatch("s0", _fake_batch()) + w.train_microbatch("s0", _fake_batch()) + w.train_microbatch("s0", _fake_batch()) + assert w._train_step_state["total_num_microbatches"] == 6 + + def test_does_not_call_optimizer_step(self, mock_module_symbols): + """trainer_version semantics: optimizer.step() must NOT fire + per train_microbatch — only at finish.""" + from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) + w.begin_train_step("s0", loss_fn=w._test_loss_fn) + w.train_microbatch("s0", _fake_batch()) + w.train_microbatch("s0", _fake_batch()) + w.optimizer.step.assert_not_called() + + +# ── finish_train_step ──────────────────────────────────────────────────── + +class TestFinish: + def _setup_open_step(self, mock_module_symbols, loss_type): + from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(loss_type) + w.begin_train_step("s0", loss_fn=w._test_loss_fn) + w.train_microbatch("s0", _fake_batch()) + return w + + def test_rescales_grads_with_inv_n(self, mock_module_symbols): + """The 1/N rescale must happen ON the local main_grad BEFORE the + cross-DP reduce — otherwise the reduce sees un-rescaled sums.""" + from nemo_rl.algorithms.loss.interfaces import LossType + w = self._setup_open_step(mock_module_symbols, LossType.TOKEN_LEVEL) + w.finish_train_step("s0") + # scale_gradients should have been called with some 1/N scalar < 1 + w.model.scale_gradients.assert_called_once() + arg = w.model.scale_gradients.call_args.args[0] + assert 0 < arg <= 1.0 + + def test_start_then_finish_grad_sync_called_after_rescale(self, mock_module_symbols): + """Call order matters: scale_gradients -> start_grad_sync -> + finish_grad_sync -> optimizer.step.""" + from nemo_rl.algorithms.loss.interfaces import LossType + w = self._setup_open_step(mock_module_symbols, LossType.TOKEN_LEVEL) + # Record call order via a shared list + order: list[str] = [] + w.model.scale_gradients.side_effect = lambda s: order.append("scale") + w.model.start_grad_sync.side_effect = lambda: order.append("start_sync") + w.model.finish_grad_sync.side_effect = lambda: order.append("finish_sync") + w.optimizer.step.side_effect = lambda: (order.append("opt_step") or (True, 0.5, 0)) + w.finish_train_step("s0") + assert order == ["scale", "start_sync", "finish_sync", "opt_step"] + + def test_picks_global_valid_toks_for_token_level_loss(self, mock_module_symbols): + """N selection: TOKEN_LEVEL → N = global_valid_toks (not seqs).""" + from nemo_rl.algorithms.loss.interfaces import LossType + w = self._setup_open_step(mock_module_symbols, LossType.TOKEN_LEVEL) + w.finish_train_step("s0") + # local_valid_toks accumulated = 2048; with mocked all_reduce as no-op, + # global_valid_toks == 2048 → inv_n = 1/2048 + arg = w.model.scale_gradients.call_args.args[0] + assert arg == pytest.approx(1.0 / 2048.0, rel=1e-4) + + def test_picks_global_valid_seqs_for_sequence_level_loss(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + w = self._setup_open_step(mock_module_symbols, LossType.SEQUENCE_LEVEL) + w.finish_train_step("s0") + # local_valid_seqs = 8 → inv_n = 1/8 + arg = w.model.scale_gradients.call_args.args[0] + assert arg == pytest.approx(1.0 / 8.0, rel=1e-4) + + def test_restores_grad_sync_func(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + w = self._setup_open_step(mock_module_symbols, LossType.TOKEN_LEVEL) + w.finish_train_step("s0") + assert w.model.config.grad_sync_func == "ORIGINAL_GRAD_SYNC_FUNC" + + def test_clears_train_step_state(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + w = self._setup_open_step(mock_module_symbols, LossType.TOKEN_LEVEL) + w.finish_train_step("s0") + assert w._train_step_state is None + + def test_calls_scheduler_step_with_increment_gbs(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + w = self._setup_open_step(mock_module_symbols, LossType.TOKEN_LEVEL) + w._train_step_state["gbs"] = 64 + w.finish_train_step("s0") + w.scheduler.step.assert_called_once_with(increment=64) + + def test_returns_metrics_dict(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + w = self._setup_open_step(mock_module_symbols, LossType.TOKEN_LEVEL) + metrics = w.finish_train_step("s0") + for key in ("global_loss", "rank", "gpu_name", "model_dtype", + "all_mb_metrics", "grad_norm"): + assert key in metrics, f"missing {key!r}" + + def test_moe_branch_skipped_when_num_experts_is_none(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + w = self._setup_open_step(mock_module_symbols, LossType.TOKEN_LEVEL) + w.model.config.num_moe_experts = None + metrics = w.finish_train_step("s0") + assert "moe_metrics" not in metrics + + def test_moe_branch_uses_total_num_microbatches_for_scale(self, mock_module_symbols): + """MoE aux-loss scale must use the accumulated total, not the + per-call num_microbatches.""" + from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) + w.model.config.num_moe_experts = 4 + # Have get_moe_metrics return non-empty so the branch fires + mock_module_symbols["moe"].return_value = {"aux_loss": 0.1} + w.begin_train_step("s0", loss_fn=w._test_loss_fn) + # 3 train_microbatch calls × 2 pipeline mbs each = 6 + for _ in range(3): + w.train_microbatch("s0", _fake_batch()) + w.finish_train_step("s0") + # get_moe_metrics receives loss_scale=1/6 + kwargs = mock_module_symbols["moe"].call_args.kwargs + assert kwargs["loss_scale"] == pytest.approx(1.0 / 6.0, rel=1e-6) + + +# ── abort_train_step ───────────────────────────────────────────────────── + +class TestAbort: + def test_restores_grad_sync_func(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) + w.begin_train_step("s0", loss_fn=w._test_loss_fn) + w.abort_train_step("s0") + assert w.model.config.grad_sync_func == "ORIGINAL_GRAD_SYNC_FUNC" + + def test_zero_grad_buffer_and_zero_grad_called(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) + w.begin_train_step("s0", loss_fn=w._test_loss_fn) + w.model.zero_grad_buffer.reset_mock() + w.optimizer.zero_grad.reset_mock() + w.abort_train_step("s0") + w.model.zero_grad_buffer.assert_called_once() + w.optimizer.zero_grad.assert_called_once() + + def test_does_not_call_optimizer_step(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) + w.begin_train_step("s0", loss_fn=w._test_loss_fn) + w.train_microbatch("s0", _fake_batch()) + w.abort_train_step("s0") + w.optimizer.step.assert_not_called() + + def test_clears_train_step_state(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) + w.begin_train_step("s0", loss_fn=w._test_loss_fn) + w.abort_train_step("s0") + assert w._train_step_state is None + + def test_idempotent_with_no_open_step(self, mock_module_symbols): + """abort is a no-op when nothing is open.""" + from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) + # Should not raise + w.abort_train_step("s0") + assert getattr(w, "_train_step_state", None) is None + + def test_mismatched_step_id_raises(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) + w.begin_train_step("s0", loss_fn=w._test_loss_fn) + with pytest.raises(RuntimeError, match="does not match open step"): + w.abort_train_step("s-WRONG") + + def test_can_begin_new_step_after_abort(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) + w.begin_train_step("s0", loss_fn=w._test_loss_fn) + w.train_microbatch("s0", _fake_batch()) + w.abort_train_step("s0") + # New step opens cleanly + w.begin_train_step("s1", loss_fn=w._test_loss_fn) + assert w._train_step_state["step_id"] == "s1" + assert float(w._train_step_state["local_valid_seqs"].item()) == 0.0 + + +# ── grad_sync_func full lifecycle (integration of begin → finish/abort) ─ + +class TestGradSyncFuncLifecycle: + def test_begin_finish_round_trip(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) + sentinel = "MY_CUSTOM_GRAD_SYNC" + w.model.config.grad_sync_func = sentinel + w.begin_train_step("s0", loss_fn=w._test_loss_fn) + assert w.model.config.grad_sync_func is None + w.train_microbatch("s0", _fake_batch()) + w.finish_train_step("s0") + assert w.model.config.grad_sync_func == sentinel + + def test_begin_abort_round_trip(self, mock_module_symbols): + from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) + sentinel = "MY_CUSTOM_GRAD_SYNC" + w.model.config.grad_sync_func = sentinel + w.begin_train_step("s0", loss_fn=w._test_loss_fn) + assert w.model.config.grad_sync_func is None + w.abort_train_step("s0") + assert w.model.config.grad_sync_func == sentinel + + def test_handles_originally_none_grad_sync_func(self, mock_module_symbols): + """When PP=1 (or align_grad_reduce=False), grad_sync_func is None + to begin with. begin → finish must leave it as None.""" + from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) + w.model.config.grad_sync_func = None + w.begin_train_step("s0", loss_fn=w._test_loss_fn) + assert w.model.config.grad_sync_func is None + w.train_microbatch("s0", _fake_batch()) + w.finish_train_step("s0") + assert w.model.config.grad_sync_func is None From 07e6f751ea46a3fc05cbf3137277d1388137018e Mon Sep 17 00:00:00 2001 From: Akash Mehra Date: Wed, 3 Jun 2026 23:02:15 -0700 Subject: [PATCH 2/7] style: lint + type fixes for split-API megatron worker Signed-off-by: Akash Mehra --- .../policy/workers/megatron_policy_worker.py | 5 + .../policy/test_megatron_split_state.py | 183 +++++++++++++----- 2 files changed, 142 insertions(+), 46 deletions(-) diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py index cdf3fe6212..c87a65b09e 100644 --- a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -101,6 +101,11 @@ class MegatronPolicyWorkerImpl( TQWorkerMixin, AbstractPolicyWorker, ColocatablePolicyInterface ): + # Holds the split-API train-step state between begin/finish or + # begin/abort; None when no step is open. Declared at class level so + # ``self._train_step_state = None`` after finish/abort type-checks. + _train_step_state: Optional[dict[str, Any]] = None + def __repr__(self): """Customizes the actor's prefix in the Ray logs. diff --git a/tests/unit/models/policy/test_megatron_split_state.py b/tests/unit/models/policy/test_megatron_split_state.py index faf266f88e..b62ba9b6b4 100644 --- a/tests/unit/models/policy/test_megatron_split_state.py +++ b/tests/unit/models/policy/test_megatron_split_state.py @@ -31,7 +31,6 @@ from __future__ import annotations -from contextlib import nullcontext from unittest.mock import MagicMock, patch import pytest @@ -51,6 +50,7 @@ # ── Mock fabric ────────────────────────────────────────────────────────── + def _make_mock_model(): """A mcore-DDP-shaped mock: exposes the methods + attributes the split-API touches, plus an ``inference_params`` attribute and a @@ -62,13 +62,17 @@ def _make_mock_model(): model.config.num_moe_experts = None # disable MoE branch # no_sync() is a context manager — return a MagicMock that supports # __enter__/__exit__ so the `with self.model.no_sync():` block works. - model.no_sync = MagicMock(return_value=MagicMock( - __enter__=MagicMock(return_value=None), - __exit__=MagicMock(return_value=False), - )) + model.no_sync = MagicMock( + return_value=MagicMock( + __enter__=MagicMock(return_value=None), + __exit__=MagicMock(return_value=False), + ) + ) model.modules = MagicMock(return_value=iter([])) model.inference_params = None - model.parameters = MagicMock(return_value=iter([])) # no params for the rescale loop + model.parameters = MagicMock( + return_value=iter([]) + ) # no params for the rescale loop return model @@ -136,31 +140,45 @@ def mock_module_symbols(): "get_moe_metrics": MagicMock(return_value={}), } - with patch(f"{WORKER_MOD}.megatron_forward_backward", - return_value=patches["megatron_forward_backward"]) as mfb, \ - patch(f"{WORKER_MOD}.get_microbatch_iterator", - return_value=patches["get_microbatch_iterator"]) as gmi, \ - patch(f"{WORKER_MOD}.LossPostProcessor", - return_value=patches["LossPostProcessor"]) as lpp, \ - patch(f"{WORKER_MOD}.broadcast_loss_metrics_from_last_stage", - side_effect=patches["broadcast_loss_metrics_from_last_stage"]) as bcast, \ - patch(f"{WORKER_MOD}.get_pg_collection", - return_value=patches["get_pg_collection"]) as gpgc, \ - patch(f"{WORKER_MOD}.logical_and_across_model_parallel_group", - side_effect=patches["logical_and_across_model_parallel_group"]) as land, \ - patch(f"{WORKER_MOD}.reduce_max_stat_across_model_parallel_group", - side_effect=patches["reduce_max_stat_across_model_parallel_group"]) as rmax, \ - patch(f"{WORKER_MOD}.aggregate_training_statistics", - return_value=patches["aggregate_training_statistics"]) as agg, \ - patch(f"{WORKER_MOD}.get_moe_metrics", - return_value={}) as moe, \ - patch(f"{WORKER_MOD}.get_rerun_state_machine") as grsm, \ - patch(f"{WORKER_MOD}.parallel_state") as pstate, \ - patch("torch.distributed.all_reduce") as ar, \ - patch("torch.cuda.empty_cache") as cec, \ - patch("torch.cuda.get_device_name", return_value="H100"), \ - patch("torch.distributed.get_rank", return_value=0): - + with ( + patch( + f"{WORKER_MOD}.megatron_forward_backward", + return_value=patches["megatron_forward_backward"], + ) as mfb, + patch( + f"{WORKER_MOD}.get_microbatch_iterator", + return_value=patches["get_microbatch_iterator"], + ) as gmi, + patch( + f"{WORKER_MOD}.LossPostProcessor", return_value=patches["LossPostProcessor"] + ) as lpp, + patch( + f"{WORKER_MOD}.broadcast_loss_metrics_from_last_stage", + side_effect=patches["broadcast_loss_metrics_from_last_stage"], + ) as bcast, + patch( + f"{WORKER_MOD}.get_pg_collection", return_value=patches["get_pg_collection"] + ) as gpgc, + patch( + f"{WORKER_MOD}.logical_and_across_model_parallel_group", + side_effect=patches["logical_and_across_model_parallel_group"], + ) as land, + patch( + f"{WORKER_MOD}.reduce_max_stat_across_model_parallel_group", + side_effect=patches["reduce_max_stat_across_model_parallel_group"], + ) as rmax, + patch( + f"{WORKER_MOD}.aggregate_training_statistics", + return_value=patches["aggregate_training_statistics"], + ) as agg, + patch(f"{WORKER_MOD}.get_moe_metrics", return_value={}) as moe, + patch(f"{WORKER_MOD}.get_rerun_state_machine") as grsm, + patch(f"{WORKER_MOD}.parallel_state") as pstate, + patch("torch.distributed.all_reduce") as ar, + patch("torch.cuda.empty_cache") as cec, + patch("torch.cuda.get_device_name", return_value="H100"), + patch("torch.distributed.get_rank", return_value=0), + ): # rerun state machine: fire forward+backward once per train_microbatch rsm = MagicMock() rsm.should_run_forward_backward.side_effect = [True, False] * 100 @@ -171,10 +189,19 @@ def mock_module_symbols(): pstate.get_data_parallel_group.return_value = MagicMock() yield { - "mfb": mfb, "gmi": gmi, "lpp": lpp, "bcast": bcast, - "gpgc": gpgc, "land": land, "rmax": rmax, "agg": agg, - "moe": moe, "grsm": grsm, "pstate": pstate, - "all_reduce": ar, "empty_cache": cec, + "mfb": mfb, + "gmi": gmi, + "lpp": lpp, + "bcast": bcast, + "gpgc": gpgc, + "land": land, + "rmax": rmax, + "agg": agg, + "moe": moe, + "grsm": grsm, + "pstate": pstate, + "all_reduce": ar, + "empty_cache": cec, } @@ -186,14 +213,20 @@ def _fake_batch(): sample_mask = torch.ones(8, dtype=torch.float32) token_mask = torch.ones(8, 257, dtype=torch.float32) # token_mask[:, 1:] → 256 toks input_ids = torch.zeros(8, 257, dtype=torch.long) - return {"sample_mask": sample_mask, "token_mask": token_mask, "input_ids": input_ids} + return { + "sample_mask": sample_mask, + "token_mask": token_mask, + "input_ids": input_ids, + } # ── BEGIN ──────────────────────────────────────────────────────────────── + class TestBegin: def test_opens_state(self, mock_module_symbols): from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) w.begin_train_step("step-0", loss_fn=w._test_loss_fn, gbs=16, mbs=4) assert w._train_step_state is not None @@ -205,6 +238,7 @@ def test_opens_state(self, mock_module_symbols): def test_calls_zero_grad_and_zero_grad_buffer(self, mock_module_symbols): from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) w.begin_train_step("step-0", loss_fn=w._test_loss_fn) w.model.zero_grad_buffer.assert_called_once() @@ -216,6 +250,7 @@ def test_saves_and_nulls_grad_sync_func(self, mock_module_symbols): for the duration of the step. Otherwise PP>1 silently corrupts grads even when ``no_sync`` is set on the bucket groups.""" from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) assert w.model.config.grad_sync_func == "ORIGINAL_GRAD_SYNC_FUNC" w.begin_train_step("step-0", loss_fn=w._test_loss_fn) @@ -224,6 +259,7 @@ def test_saves_and_nulls_grad_sync_func(self, mock_module_symbols): def test_double_begin_raises(self, mock_module_symbols): from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) w.begin_train_step("step-0", loss_fn=w._test_loss_fn) with pytest.raises(RuntimeError, match="already open"): @@ -231,6 +267,7 @@ def test_double_begin_raises(self, mock_module_symbols): def test_uses_cfg_defaults_when_gbs_mbs_omitted(self, mock_module_symbols): from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) w.begin_train_step("step-0", loss_fn=w._test_loss_fn) assert w._train_step_state["gbs"] == w.cfg["train_global_batch_size"] @@ -239,15 +276,18 @@ def test_uses_cfg_defaults_when_gbs_mbs_omitted(self, mock_module_symbols): # ── _assert_step_open ──────────────────────────────────────────────────── + class TestAssertStepOpen: def test_raises_when_no_step_open(self, mock_module_symbols): from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) with pytest.raises(RuntimeError, match="no train step open"): w._assert_step_open("step-0") def test_raises_on_step_id_mismatch(self, mock_module_symbols): from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) w.begin_train_step("step-correct", loss_fn=w._test_loss_fn) with pytest.raises(RuntimeError, match="step_id mismatch"): @@ -255,12 +295,14 @@ def test_raises_on_step_id_mismatch(self, mock_module_symbols): def test_train_microbatch_without_begin_raises(self, mock_module_symbols): from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) with pytest.raises(RuntimeError, match="no train step open"): w.train_microbatch("step-0", _fake_batch()) def test_finish_without_begin_raises(self, mock_module_symbols): from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) with pytest.raises(RuntimeError, match="no train step open"): w.finish_train_step("step-0") @@ -268,12 +310,14 @@ def test_finish_without_begin_raises(self, mock_module_symbols): # ── train_microbatch ───────────────────────────────────────────────────── + class TestTrainMicrobatch: def test_wraps_forward_backward_in_no_sync(self, mock_module_symbols): """The single most important assertion in this file. Without the no_sync wrap, mcore DDP dispatches a per-call cross-DP reduce on the partially-accumulated buffer — silently corrupting grads.""" from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) w.begin_train_step("s0", loss_fn=w._test_loss_fn) w.train_microbatch("s0", _fake_batch()) @@ -285,6 +329,7 @@ def test_wraps_forward_backward_in_no_sync(self, mock_module_symbols): def test_invokes_megatron_forward_backward_once(self, mock_module_symbols): from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) w.begin_train_step("s0", loss_fn=w._test_loss_fn) w.train_microbatch("s0", _fake_batch()) @@ -294,6 +339,7 @@ def test_passes_placeholder_n_one_to_loss(self, mock_module_symbols): """The N=1 trick: loss must be called with global_valid_*=1 so it returns un-normalized sums; finish does the 1/N rescale.""" from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) w.begin_train_step("s0", loss_fn=w._test_loss_fn) w.train_microbatch("s0", _fake_batch()) @@ -306,18 +352,28 @@ def test_passes_placeholder_n_one_to_loss(self, mock_module_symbols): def test_accumulates_mask_sums_across_calls(self, mock_module_symbols): from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) w.begin_train_step("s0", loss_fn=w._test_loss_fn) # _fake_batch has sample_mask sum = 8, token_mask*sample_mask sum = 8*256 = 2048 w.train_microbatch("s0", _fake_batch()) - assert float(w._train_step_state["local_valid_seqs"].item()) == pytest.approx(8.0) - assert float(w._train_step_state["local_valid_toks"].item()) == pytest.approx(2048.0) + assert float(w._train_step_state["local_valid_seqs"].item()) == pytest.approx( + 8.0 + ) + assert float(w._train_step_state["local_valid_toks"].item()) == pytest.approx( + 2048.0 + ) w.train_microbatch("s0", _fake_batch()) - assert float(w._train_step_state["local_valid_seqs"].item()) == pytest.approx(16.0) - assert float(w._train_step_state["local_valid_toks"].item()) == pytest.approx(4096.0) + assert float(w._train_step_state["local_valid_seqs"].item()) == pytest.approx( + 16.0 + ) + assert float(w._train_step_state["local_valid_toks"].item()) == pytest.approx( + 4096.0 + ) def test_total_num_microbatches_accumulates(self, mock_module_symbols): from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) w.begin_train_step("s0", loss_fn=w._test_loss_fn) # get_microbatch_iterator mock returns num_microbatches=2 per call @@ -330,6 +386,7 @@ def test_does_not_call_optimizer_step(self, mock_module_symbols): """trainer_version semantics: optimizer.step() must NOT fire per train_microbatch — only at finish.""" from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) w.begin_train_step("s0", loss_fn=w._test_loss_fn) w.train_microbatch("s0", _fake_batch()) @@ -339,9 +396,9 @@ def test_does_not_call_optimizer_step(self, mock_module_symbols): # ── finish_train_step ──────────────────────────────────────────────────── + class TestFinish: def _setup_open_step(self, mock_module_symbols, loss_type): - from nemo_rl.algorithms.loss.interfaces import LossType w = _make_worker(loss_type) w.begin_train_step("s0", loss_fn=w._test_loss_fn) w.train_microbatch("s0", _fake_batch()) @@ -351,6 +408,7 @@ def test_rescales_grads_with_inv_n(self, mock_module_symbols): """The 1/N rescale must happen ON the local main_grad BEFORE the cross-DP reduce — otherwise the reduce sees un-rescaled sums.""" from nemo_rl.algorithms.loss.interfaces import LossType + w = self._setup_open_step(mock_module_symbols, LossType.TOKEN_LEVEL) w.finish_train_step("s0") # scale_gradients should have been called with some 1/N scalar < 1 @@ -358,23 +416,29 @@ def test_rescales_grads_with_inv_n(self, mock_module_symbols): arg = w.model.scale_gradients.call_args.args[0] assert 0 < arg <= 1.0 - def test_start_then_finish_grad_sync_called_after_rescale(self, mock_module_symbols): + def test_start_then_finish_grad_sync_called_after_rescale( + self, mock_module_symbols + ): """Call order matters: scale_gradients -> start_grad_sync -> finish_grad_sync -> optimizer.step.""" from nemo_rl.algorithms.loss.interfaces import LossType + w = self._setup_open_step(mock_module_symbols, LossType.TOKEN_LEVEL) # Record call order via a shared list order: list[str] = [] w.model.scale_gradients.side_effect = lambda s: order.append("scale") w.model.start_grad_sync.side_effect = lambda: order.append("start_sync") w.model.finish_grad_sync.side_effect = lambda: order.append("finish_sync") - w.optimizer.step.side_effect = lambda: (order.append("opt_step") or (True, 0.5, 0)) + w.optimizer.step.side_effect = lambda: ( + order.append("opt_step") or (True, 0.5, 0) + ) w.finish_train_step("s0") assert order == ["scale", "start_sync", "finish_sync", "opt_step"] def test_picks_global_valid_toks_for_token_level_loss(self, mock_module_symbols): """N selection: TOKEN_LEVEL → N = global_valid_toks (not seqs).""" from nemo_rl.algorithms.loss.interfaces import LossType + w = self._setup_open_step(mock_module_symbols, LossType.TOKEN_LEVEL) w.finish_train_step("s0") # local_valid_toks accumulated = 2048; with mocked all_reduce as no-op, @@ -384,6 +448,7 @@ def test_picks_global_valid_toks_for_token_level_loss(self, mock_module_symbols) def test_picks_global_valid_seqs_for_sequence_level_loss(self, mock_module_symbols): from nemo_rl.algorithms.loss.interfaces import LossType + w = self._setup_open_step(mock_module_symbols, LossType.SEQUENCE_LEVEL) w.finish_train_step("s0") # local_valid_seqs = 8 → inv_n = 1/8 @@ -392,18 +457,21 @@ def test_picks_global_valid_seqs_for_sequence_level_loss(self, mock_module_symbo def test_restores_grad_sync_func(self, mock_module_symbols): from nemo_rl.algorithms.loss.interfaces import LossType + w = self._setup_open_step(mock_module_symbols, LossType.TOKEN_LEVEL) w.finish_train_step("s0") assert w.model.config.grad_sync_func == "ORIGINAL_GRAD_SYNC_FUNC" def test_clears_train_step_state(self, mock_module_symbols): from nemo_rl.algorithms.loss.interfaces import LossType + w = self._setup_open_step(mock_module_symbols, LossType.TOKEN_LEVEL) w.finish_train_step("s0") assert w._train_step_state is None def test_calls_scheduler_step_with_increment_gbs(self, mock_module_symbols): from nemo_rl.algorithms.loss.interfaces import LossType + w = self._setup_open_step(mock_module_symbols, LossType.TOKEN_LEVEL) w._train_step_state["gbs"] = 64 w.finish_train_step("s0") @@ -411,23 +479,34 @@ def test_calls_scheduler_step_with_increment_gbs(self, mock_module_symbols): def test_returns_metrics_dict(self, mock_module_symbols): from nemo_rl.algorithms.loss.interfaces import LossType + w = self._setup_open_step(mock_module_symbols, LossType.TOKEN_LEVEL) metrics = w.finish_train_step("s0") - for key in ("global_loss", "rank", "gpu_name", "model_dtype", - "all_mb_metrics", "grad_norm"): + for key in ( + "global_loss", + "rank", + "gpu_name", + "model_dtype", + "all_mb_metrics", + "grad_norm", + ): assert key in metrics, f"missing {key!r}" def test_moe_branch_skipped_when_num_experts_is_none(self, mock_module_symbols): from nemo_rl.algorithms.loss.interfaces import LossType + w = self._setup_open_step(mock_module_symbols, LossType.TOKEN_LEVEL) w.model.config.num_moe_experts = None metrics = w.finish_train_step("s0") assert "moe_metrics" not in metrics - def test_moe_branch_uses_total_num_microbatches_for_scale(self, mock_module_symbols): + def test_moe_branch_uses_total_num_microbatches_for_scale( + self, mock_module_symbols + ): """MoE aux-loss scale must use the accumulated total, not the per-call num_microbatches.""" from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) w.model.config.num_moe_experts = 4 # Have get_moe_metrics return non-empty so the branch fires @@ -444,9 +523,11 @@ def test_moe_branch_uses_total_num_microbatches_for_scale(self, mock_module_symb # ── abort_train_step ───────────────────────────────────────────────────── + class TestAbort: def test_restores_grad_sync_func(self, mock_module_symbols): from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) w.begin_train_step("s0", loss_fn=w._test_loss_fn) w.abort_train_step("s0") @@ -454,6 +535,7 @@ def test_restores_grad_sync_func(self, mock_module_symbols): def test_zero_grad_buffer_and_zero_grad_called(self, mock_module_symbols): from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) w.begin_train_step("s0", loss_fn=w._test_loss_fn) w.model.zero_grad_buffer.reset_mock() @@ -464,6 +546,7 @@ def test_zero_grad_buffer_and_zero_grad_called(self, mock_module_symbols): def test_does_not_call_optimizer_step(self, mock_module_symbols): from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) w.begin_train_step("s0", loss_fn=w._test_loss_fn) w.train_microbatch("s0", _fake_batch()) @@ -472,6 +555,7 @@ def test_does_not_call_optimizer_step(self, mock_module_symbols): def test_clears_train_step_state(self, mock_module_symbols): from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) w.begin_train_step("s0", loss_fn=w._test_loss_fn) w.abort_train_step("s0") @@ -480,6 +564,7 @@ def test_clears_train_step_state(self, mock_module_symbols): def test_idempotent_with_no_open_step(self, mock_module_symbols): """abort is a no-op when nothing is open.""" from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) # Should not raise w.abort_train_step("s0") @@ -487,6 +572,7 @@ def test_idempotent_with_no_open_step(self, mock_module_symbols): def test_mismatched_step_id_raises(self, mock_module_symbols): from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) w.begin_train_step("s0", loss_fn=w._test_loss_fn) with pytest.raises(RuntimeError, match="does not match open step"): @@ -494,6 +580,7 @@ def test_mismatched_step_id_raises(self, mock_module_symbols): def test_can_begin_new_step_after_abort(self, mock_module_symbols): from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) w.begin_train_step("s0", loss_fn=w._test_loss_fn) w.train_microbatch("s0", _fake_batch()) @@ -506,9 +593,11 @@ def test_can_begin_new_step_after_abort(self, mock_module_symbols): # ── grad_sync_func full lifecycle (integration of begin → finish/abort) ─ + class TestGradSyncFuncLifecycle: def test_begin_finish_round_trip(self, mock_module_symbols): from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) sentinel = "MY_CUSTOM_GRAD_SYNC" w.model.config.grad_sync_func = sentinel @@ -520,6 +609,7 @@ def test_begin_finish_round_trip(self, mock_module_symbols): def test_begin_abort_round_trip(self, mock_module_symbols): from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) sentinel = "MY_CUSTOM_GRAD_SYNC" w.model.config.grad_sync_func = sentinel @@ -532,6 +622,7 @@ def test_handles_originally_none_grad_sync_func(self, mock_module_symbols): """When PP=1 (or align_grad_reduce=False), grad_sync_func is None to begin with. begin → finish must leave it as None.""" from nemo_rl.algorithms.loss.interfaces import LossType + w = _make_worker(LossType.TOKEN_LEVEL) w.model.config.grad_sync_func = None w.begin_train_step("s0", loss_fn=w._test_loss_fn) From 0a4b1e05bfce3a6620ded702ee11ba2094621e43 Mon Sep 17 00:00:00 2001 From: Akash Mehra Date: Thu, 4 Jun 2026 19:06:43 -0700 Subject: [PATCH 3/7] ci: whitelist megatron/draft/__init__.py in pyrefly Pre-existing zero-error file from #2078 (Eagle3) that was never added to the project-includes whitelist. Carrying the fix forward in this PR to unblock the lint job. Co-Authored-By: Claude Opus 4.7 Signed-off-by: Akash Mehra --- pyrefly.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyrefly.toml b/pyrefly.toml index edea330980..818468da9e 100644 --- a/pyrefly.toml +++ b/pyrefly.toml @@ -152,6 +152,7 @@ project-includes = [ "nemo_rl/models/generation/vllm/vllm_backend.py", "nemo_rl/models/huggingface/__init__.py", "nemo_rl/models/megatron/__init__.py", + "nemo_rl/models/megatron/draft/__init__.py", "nemo_rl/models/policy/__init__.py", "nemo_rl/models/policy/interfaces.py", "nemo_rl/models/policy/utils.py", From 3cec5519299a86b3b46999c7dbd2617b2a00aa42 Mon Sep 17 00:00:00 2001 From: Akash Mehra Date: Thu, 4 Jun 2026 19:21:41 -0700 Subject: [PATCH 4/7] ci: whitelist policy_trainer_actor.py in pyrefly Co-Authored-By: Claude Opus 4.7 Signed-off-by: Akash Mehra --- pyrefly.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyrefly.toml b/pyrefly.toml index 818468da9e..7dc5e059a3 100644 --- a/pyrefly.toml +++ b/pyrefly.toml @@ -155,6 +155,7 @@ project-includes = [ "nemo_rl/models/megatron/draft/__init__.py", "nemo_rl/models/policy/__init__.py", "nemo_rl/models/policy/interfaces.py", + "nemo_rl/models/policy/policy_trainer_actor.py", "nemo_rl/models/policy/utils.py", "nemo_rl/models/policy/workers/__init__.py", "nemo_rl/models/policy/workers/patches.py", From 92d49cc1b64b2f46012d76beec16991a88bf851e Mon Sep 17 00:00:00 2001 From: Akash Mehra Date: Thu, 4 Jun 2026 20:05:37 -0700 Subject: [PATCH 5/7] ci: drop policy_trainer_actor.py whitelist (not on this branch) The file is introduced by #2692 (DTensor PR), not by this branch. Whitelisting it here causes pyrefly to fail with 'No Python files matched pattern' since the file does not exist on mcore. Co-Authored-By: Claude Opus 4.7 Signed-off-by: Akash Mehra --- pyrefly.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyrefly.toml b/pyrefly.toml index 7dc5e059a3..818468da9e 100644 --- a/pyrefly.toml +++ b/pyrefly.toml @@ -155,7 +155,6 @@ project-includes = [ "nemo_rl/models/megatron/draft/__init__.py", "nemo_rl/models/policy/__init__.py", "nemo_rl/models/policy/interfaces.py", - "nemo_rl/models/policy/policy_trainer_actor.py", "nemo_rl/models/policy/utils.py", "nemo_rl/models/policy/workers/__init__.py", "nemo_rl/models/policy/workers/patches.py", From 3bc32442a431b837c27af9add360b0374180f0f7 Mon Sep 17 00:00:00 2001 From: Akash Mehra Date: Mon, 8 Jun 2026 15:17:04 -0700 Subject: [PATCH 6/7] fix(megatron): avoid 'config' in split-API method co_names cloudpickle traverses globals/closures of each method when serializing the Ray actor class. With torch 2.11, 'config' in __code__.co_names matches torch.distributed.config (a non-pickleable ConfigModuleInstance), breaking actor creation with: TypeError: cannot pickle 'ConfigModuleInstance' object Could not serialize the actor class ...MegatronPolicyWorker Same workaround as the existing sync train(): read 'config' via getattr-by-string in begin/finish/abort_train_step. Co-Authored-By: Claude Opus 4.7 Signed-off-by: Akash Mehra --- .../policy/workers/megatron_policy_worker.py | 24 +++++++++++++++---- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py index c87a65b09e..93fd4dad89 100644 --- a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -629,9 +629,17 @@ def begin_train_step( # bypasses ``no_sync``). Save the existing value so we can restore # at finish/abort. PP=1's ``forward_backward_no_pipelining`` doesn't # invoke this; nulling it is a no-op there. - model_config = self.model.config - state["saved_grad_sync_func"] = getattr(model_config, "grad_sync_func", None) - model_config.grad_sync_func = None + # Read "config" via getattr-by-string so the token stays out of + # begin_train_step.__code__.co_names; otherwise cloudpickle matches + # torch.distributed.config (a non-pickleable ConfigModuleInstance). + model_config = getattr(self.model, "config", None) + if model_config is not None: + state["saved_grad_sync_func"] = getattr( + model_config, "grad_sync_func", None + ) + model_config.grad_sync_func = None + else: + state["saved_grad_sync_func"] = None self._train_step_state = state @@ -809,7 +817,10 @@ def finish_train_step(self, step_id: str) -> dict[str, Any]: torch.cuda.empty_cache() # Restore grad_sync_func before scheduler.step / further state. - self.model.config.grad_sync_func = state["saved_grad_sync_func"] + # See begin_train_step for why .config is accessed by string. + finish_model_config = getattr(self.model, "config", None) + if finish_model_config is not None: + finish_model_config.grad_sync_func = state["saved_grad_sync_func"] # Scheduler increment matches sync path's ``increment=gbs``. self.scheduler.step(increment=state["gbs"]) @@ -886,7 +897,10 @@ def abort_train_step(self, step_id: str) -> None: ) # Restore grad_sync_func first so the model is back to a normal # state before zero_grad_buffer touches anything. - self.model.config.grad_sync_func = state["saved_grad_sync_func"] + # See begin_train_step for why .config is accessed by string. + abort_model_config = getattr(self.model, "config", None) + if abort_model_config is not None: + abort_model_config.grad_sync_func = state["saved_grad_sync_func"] self.model.zero_grad_buffer() self.optimizer.zero_grad() self._train_step_state = None From cd456b62eb7af6b3352590ba293250db0c14f51e Mon Sep 17 00:00:00 2001 From: Akash Mehra Date: Mon, 8 Jun 2026 17:43:44 -0700 Subject: [PATCH 7/7] test(megatron): skip split_state collection when megatron.bridge missing test_megatron_split_state.py eagerly imports megatron_policy_worker which transitively imports megatron.bridge. In non-mcore shards (Models, Vllm, Sglang, Automodel_Policy), megatron.bridge isn't installed so collection of this file fails, killing every other test in the shard. pytest.importorskip stops collection cleanly when megatron.bridge is not available. The pytest.mark.mcore filter still ensures these tests only run in mcore shards. Co-Authored-By: Claude Opus 4.7 Signed-off-by: Akash Mehra --- tests/unit/models/policy/test_megatron_split_state.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/unit/models/policy/test_megatron_split_state.py b/tests/unit/models/policy/test_megatron_split_state.py index b62ba9b6b4..fb50c839ef 100644 --- a/tests/unit/models/policy/test_megatron_split_state.py +++ b/tests/unit/models/policy/test_megatron_split_state.py @@ -36,11 +36,17 @@ import pytest import torch +# megatron.bridge is only available with the mcore extras. Without it the +# eager import of megatron_policy_worker (transitively imports megatron.bridge) +# fails at COLLECTION time on non-mcore shards, which then breaks every other +# test in that shard. importorskip stops collection cleanly here. +pytest.importorskip("megatron.bridge") + # Eagerly import the worker module so ``unittest.mock.patch`` can resolve # attributes on it via ``getattr``. Without this the patch path # ``nemo_rl.models.policy.workers.megatron_policy_worker.`` fails # at ``getattr(workers, "megatron_policy_worker")``. -import nemo_rl.models.policy.workers.megatron_policy_worker # noqa: F401 +import nemo_rl.models.policy.workers.megatron_policy_worker # noqa: E402,F401 pytestmark = pytest.mark.mcore