    # ── split-API entrypoints (SC async path) ──────────────────────────────
    #
    # The split path lets SingleController drive forward/backward per
    # microbatch (or per pipeline-batch on Megatron) without stepping the
    # optimizer until a full logical batch has accumulated. Backend
    # methods (``begin_train_step``, ``train_microbatch``,
    # ``finish_train_step``, ``abort_train_step``) own the train-step
    # state machine; this mixin just gates them on TQ-presharded data.
    #
    # Only ``train_microbatch_presharded`` touches data; the other three
    # entrypoints forward their arguments to the backend untouched.

    @wrap_with_nvtx_name("policy_worker/begin_train_step_presharded")
    def begin_train_step_presharded(
        self,
        step_id: str,
        loss_fn: Any,
        gbs: Optional[int] = None,
        mbs: Optional[int] = None,
    ) -> None:
        """Open a logical train step. No fetch — pure lifecycle.

        Forwards every argument unchanged to the backend's
        ``begin_train_step``. The backend stores ``step_id`` / ``loss_fn``
        / ``gbs`` / ``mbs``, clears gradients, and initialises accumulators
        for ``local_valid_seqs`` / ``local_valid_toks`` and any
        per-microbatch metrics. Optimizer state is untouched here.

        Args:
            step_id: Identifier of the logical step. The Megatron backend
                rejects later microbatch / finish calls whose ``step_id``
                differs from the one given here.
            loss_fn: Loss callable the backend keeps for the whole step.
            gbs: Optional global batch size, forwarded as-is. The Megatron
                backend falls back to ``cfg["train_global_batch_size"]``
                when this is falsy.
            mbs: Optional micro batch size, forwarded as-is. The Megatron
                backend falls back to ``cfg["train_micro_batch_size"]``
                when this is falsy.

        Raises:
            RuntimeError: Raised by the Megatron backend when a previous
                step is still open. NOTE(review): other backends are not
                visible from here — confirm they behave the same.
        """
        # The backend method comes from the concrete worker class that mixes
        # this mixin in, hence the attr-defined suppression.
        self.begin_train_step(  # type: ignore[attr-defined]
            step_id=step_id,
            loss_fn=loss_fn,
            gbs=gbs,
            mbs=mbs,
        )

    @wrap_with_nvtx_name("policy_worker/train_microbatch_presharded")
    def train_microbatch_presharded(
        self,
        step_id: str,
        meta: "KVBatchMeta",
    ) -> dict[str, Any]:
        """Per-rank microbatch entrypoint. Fetch → packing prep → forward+backward.

        Materialises the data described by ``meta`` with ``self._fetch``,
        passes it through ``self._attach_or_repack_pack_metadata`` and hands
        the result to the backend's ``train_microbatch``. Gradients
        accumulate across calls; no ``optimizer.step`` happens here.

        Args:
            step_id: Identifier of the step opened by
                ``begin_train_step_presharded``.
            meta: Descriptor of this rank's slice of the microbatch.

        Returns:
            Whatever the backend's ``train_microbatch`` returns, unchanged.
            The Megatron backend returns bookkeeping counts only
            (``local_valid_seqs_mb``, ``local_valid_toks_mb``,
            ``num_pipeline_microbatches``); its per-microbatch loss metrics
            stay in the backend's step state and are surfaced by
            ``finish_train_step``.
        """
        data = self._fetch(meta)
        # Packing metadata must be attached before the backend builds its
        # microbatch iterator from ``data``.
        data = self._attach_or_repack_pack_metadata(data, meta)
        return self.train_microbatch(  # type: ignore[attr-defined]
            step_id=step_id,
            data=data,
        )

    @wrap_with_nvtx_name("policy_worker/finish_train_step_presharded")
    def finish_train_step_presharded(
        self,
        step_id: str,
    ) -> dict[str, Any]:
        """Close a logical train step. No fetch — pure lifecycle.

        Backend all-reduces accumulated ``local_valid_seqs/toks``,
        rescales gradients to the final global normalization, runs grad
        clip, and steps the optimizer + scheduler. Whether gradients are
        zeroed at this point is backend-specific: the Megatron backend
        does not zero them here and relies on the next
        ``begin_train_step`` doing so.

        Args:
            step_id: Identifier of the step to close.

        Returns:
            The backend's aggregated step result, unchanged. For the
            Megatron backend the keys are ``global_loss``, ``grad_norm``,
            ``all_mb_metrics``, ``rank``, ``gpu_name``, ``model_dtype`` and,
            for MoE models that report them, ``moe_metrics``.

        Raises:
            RuntimeError: Raised by the Megatron backend when no step is
                open or ``step_id`` does not match the open step.
        """
        return self.finish_train_step(step_id=step_id)  # type: ignore[attr-defined]

    @wrap_with_nvtx_name("policy_worker/abort_train_step_presharded")
    def abort_train_step_presharded(
        self,
        step_id: str,
    ) -> None:
        """Discard partial train-step state without stepping the optimizer.

        Used when SC decides the logical batch will not complete (e.g.
        weight-sync triggered mid-step). Backend drops accumulators and
        zeros gradients.

        Args:
            step_id: Identifier of the step to discard.

        Raises:
            RuntimeError: Raised by the Megatron backend when ``step_id``
                does not match the open step. With no step open, that
                backend returns silently instead of raising.
        """
        self.abort_train_step(step_id=step_id)  # type: ignore[attr-defined]
+ # + # Lifecycle: + # begin_train_step — open step; broadcast loss_fn/gbs/mbs + # train_microbatch_from_meta (N×) — DP-sharded fwd/bwd, grads accumulate + # finish_train_step — all_reduce + opt.step + sched.step + # abort_train_step — drop accumulators, no opt.step + # + # ``train_from_meta`` is unchanged and remains the sync entrypoint. + + def begin_train_step( + self, + step_id: str, + loss_fn: LossFunction, + gbs: Optional[int] = None, + mbs: Optional[int] = None, + ) -> None: + """Open a logical train step on every worker.""" + batch_size = gbs or self.cfg["train_global_batch_size"] + micro_batch_size = mbs or self.cfg["train_micro_batch_size"] + if self.flops_tracker is not None: + self.flops_tracker.reset() + futures = self.worker_group.run_all_workers_single_data( + "begin_train_step_presharded", + step_id=step_id, + loss_fn=loss_fn, + gbs=batch_size, + mbs=micro_batch_size, + ) + self.worker_group.get_all_worker_results(futures) + + def train_microbatch_from_meta( + self, + step_id: str, + meta: KVBatchMeta, + timer: Optional[Timer] = None, + ) -> dict[str, Any]: + """Dispatch one microbatch (DP-sharded) into an open train step. + + Mirrors the sharding logic of :meth:`train_from_meta` but without + a logical-batch sizing constraint: this routes ``meta`` to DP + ranks and runs forward+backward; gradients accumulate in + ``.grad``. The optimizer step happens at :meth:`finish_train_step`. 
+ """ + self._stamp_pad_seqlen(meta) + spa, dba = self._packing_args("train_mb_tokens") + train_meta = replace( + meta, + fields=list(DP_TRAIN_FIELDS), + task_name="train", + ) + with timer.time("policy_training/shard_meta") if timer else nullcontext(): + dp_metas, _ = shard_meta_for_dp( + train_meta, + dp_world=self.sharding_annotations.get_axis_size("data_parallel"), + batch_size=None, + sequence_packing_args=spa, + dynamic_batching_args=dba, + ) + + if self.flops_tracker is not None: + for m in dp_metas: + self.flops_tracker.track_batch(list(m.sequence_lengths or [])) + + with ( + timer.time("policy_training/submit_microbatch_futures") + if timer + else nullcontext() + ): + futures = self.worker_group.run_all_workers_sharded_data( + "train_microbatch_presharded", + meta=dp_metas, + in_sharded_axes=["data_parallel"], + replicate_on_axes=[ + "context_parallel", + "tensor_parallel", + "pipeline_parallel", + ], + output_is_replicated=[ + "context_parallel", + "tensor_parallel", + "pipeline_parallel", + ], + common_kwargs={"step_id": step_id}, + ) + results = self.worker_group.get_all_worker_results(futures) + # Per-microbatch metrics: pass through DP-rank-0 by convention, + # backend may aggregate later if needed. Surface as-is for now. + return results[0] if results else {} + + def finish_train_step(self, step_id: str) -> dict[str, Any]: + """Close an open train step: all_reduce, rescale, optimizer.step. + + Aggregates per-rank step results into the same shape as + :meth:`train_from_meta` so callers don't have to special-case + the split path. 
+ """ + futures = self.worker_group.run_all_workers_single_data( + "finish_train_step_presharded", + step_id=step_id, + ) + results = self.worker_group.get_all_worker_results(futures) + aggregated_results = _aggregate_train_results(results) + + if self.flops_tracker is not None: + aggregated_results["total_flops"] = self.flops_tracker.total_flops + aggregated_results["num_ranks"] = self.worker_group.cluster.world_size() + + return aggregated_results + + def abort_train_step(self, step_id: str) -> None: + """Drop partial step state on every worker. No optimizer.step.""" + futures = self.worker_group.run_all_workers_single_data( + "abort_train_step_presharded", + step_id=step_id, + ) + self.worker_group.get_all_worker_results(futures) + + if self.flops_tracker is not None: + self.flops_tracker.reset() diff --git a/nemo_rl/models/policy/workers/megatron_policy_worker.py b/nemo_rl/models/policy/workers/megatron_policy_worker.py index 0395d2e0c9..93fd4dad89 100644 --- a/nemo_rl/models/policy/workers/megatron_policy_worker.py +++ b/nemo_rl/models/policy/workers/megatron_policy_worker.py @@ -101,6 +101,11 @@ class MegatronPolicyWorkerImpl( TQWorkerMixin, AbstractPolicyWorker, ColocatablePolicyInterface ): + # Holds the split-API train-step state between begin/finish or + # begin/abort; None when no step is open. Declared at class level so + # ``self._train_step_state = None`` after finish/abort type-checks. + _train_step_state: Optional[dict[str, Any]] = None + def __repr__(self): """Customizes the actor's prefix in the Ray logs. @@ -528,6 +533,378 @@ def train( metrics["moe_metrics"] = moe_metrics return metrics + # ── split-API train-step state machine (SingleController async path) ── + # + # Mirrors the v1/v2 implementations, adapted for mcore. Key differences: + # + # 1. mcore DDP accumulates ``param.main_grad`` per backward and dispatches + # a cross-DP reduce when ``is_last_microbatch=True`` (one per + # ``forward_backward_func`` call). 
Naively chaining multiple + # ``forward_backward_func`` calls between ``optimizer.step()`` would + # over-count: each call's terminal reduce sums an already-reduced + # bucket again. We wrap every call in ``self.model.no_sync()`` so + # hooks accumulate locally only; one explicit ``start_grad_sync`` + + # ``finish_grad_sync`` at finish does the single true reduce. + # 2. PP>1: the pipeline scheduler invokes ``config.grad_sync_func`` + # directly on last-microbatch boundaries — this bypasses the + # ``no_sync`` gate. We null it for the duration of the step and + # restore at finish/abort. + # 3. Grad clip is bundled inside ``MegatronOptimizer.step()``; the 1/N + # rescale via ``self.model.scale_gradients(1/N)`` must run before + # ``optimizer.step()`` so the clip operates on the rescaled grad. + # 4. With ``calculate_per_token_loss=True`` + ``average_in_collective= + # False``, mcore's DDP sums (does not average) grads across DP, so + # no FSDP-style ``loss *= dp_size*cp_size`` cancellation is needed + # per microbatch. + + def _split_step_state_init( + self, + step_id: str, + loss_fn: LossFunction, + gbs: Optional[int], + mbs: Optional[int], + ) -> dict[str, Any]: + from nemo_rl.algorithms.loss.interfaces import LossType + + return { + "step_id": step_id, + "loss_fn": loss_fn, + "loss_type": getattr(loss_fn, "loss_type", LossType.TOKEN_LEVEL), + "gbs": gbs or self.cfg["train_global_batch_size"], + "mbs": mbs or self.cfg["train_micro_batch_size"], + "local_valid_seqs": torch.zeros((), dtype=torch.float64, device="cuda"), + "local_valid_toks": torch.zeros((), dtype=torch.float64, device="cuda"), + "all_mb_metrics": [], + "mb_losses": [], + "total_num_microbatches": 0, + # Saved across the step so we can restore at finish/abort. 
+ "saved_grad_sync_func": None, + "no_sync_active": False, + } + + def _assert_step_open(self, step_id: str) -> dict[str, Any]: + state = getattr(self, "_train_step_state", None) + if state is None: + raise RuntimeError( + f"no train step open; begin_train_step({step_id!r}) must be called first" + ) + if state["step_id"] != step_id: + raise RuntimeError( + f"step_id mismatch: open step is {state['step_id']!r}, got {step_id!r}" + ) + return state + + @wrap_with_nvtx_name("megatron_policy_worker/begin_train_step") + def begin_train_step( + self, + step_id: str, + loss_fn: LossFunction, + gbs: Optional[int] = None, + mbs: Optional[int] = None, + ) -> None: + existing = getattr(self, "_train_step_state", None) + if existing is not None: + raise RuntimeError( + f"train step {existing['step_id']!r} is already open; " + f"call finish_train_step or abort_train_step before begin" + ) + # Match sync train() inference-state reset (line 332-340). + if hasattr(self.model, "inference_params"): + self.model.inference_params = None + for module in self.model.modules(): + if hasattr(module, "reset_inference_cache"): + module.reset_inference_cache() + if hasattr(module, "_inference_key_value_memory"): + module._inference_key_value_memory = None + + self.model.train() + self.model.zero_grad_buffer() + self.optimizer.zero_grad() + + state = self._split_step_state_init( + step_id=step_id, loss_fn=loss_fn, gbs=gbs, mbs=mbs + ) + + # Suppress the PP scheduler's direct ``grad_sync_func`` call (which + # bypasses ``no_sync``). Save the existing value so we can restore + # at finish/abort. PP=1's ``forward_backward_no_pipelining`` doesn't + # invoke this; nulling it is a no-op there. + # Read "config" via getattr-by-string so the token stays out of + # begin_train_step.__code__.co_names; otherwise cloudpickle matches + # torch.distributed.config (a non-pickleable ConfigModuleInstance). 
+ model_config = getattr(self.model, "config", None) + if model_config is not None: + state["saved_grad_sync_func"] = getattr( + model_config, "grad_sync_func", None + ) + model_config.grad_sync_func = None + else: + state["saved_grad_sync_func"] = None + + self._train_step_state = state + + @wrap_with_nvtx_name("megatron_policy_worker/train_microbatch") + def train_microbatch( + self, + step_id: str, + data: BatchedDataDict[Any], + ) -> dict[str, Any]: + """One DP slice of data → one ``forward_backward_func`` invocation. + + Wrapped in ``self.model.no_sync()`` so the mcore DDP hooks + accumulate ``param.main_grad`` locally on each rank without + dispatching a per-call DP reduce. The single true reduce is done + explicitly in ``finish_train_step``. + """ + state = self._assert_step_open(step_id) + loss_fn = state["loss_fn"] + + # Accumulate local mask sums for the finish-time all_reduce. + # Inlined from process_global_batch (data.py:319-332) — we can't + # call process_global_batch directly because it eagerly all_reduces + # the local sums, which is exactly what we're trying to defer. + assert "sample_mask" in data, "sample_mask required on microbatch data" + sample_mask = data["sample_mask"] + call_local_seqs = torch.sum(sample_mask).to(torch.float64) + if "token_mask" in data: + token_mask = data["token_mask"] + call_local_toks = torch.sum( + token_mask[:, 1:] * sample_mask.unsqueeze(-1) + ).to(torch.float64) + else: + call_local_toks = call_local_seqs * data["input_ids"].shape[1] + + state["local_valid_seqs"] = state["local_valid_seqs"] + call_local_seqs + state["local_valid_toks"] = state["local_valid_toks"] + call_local_toks + + # Build the per-call iterator. Each ``train_microbatch_from_meta`` + # call carries one DP slice; the iterator subdivides into pipeline + # microbatches. 
+ ( + data_iterator, + num_microbatches, + micro_batch_size, + seq_length, + padded_seq_length, + ) = get_microbatch_iterator( + data, + self.cfg, + state["mbs"], + straggler_timer=self.mcore_state.straggler_timer, + ) + state["total_num_microbatches"] += int(num_microbatches) + + loss_post_processor = LossPostProcessor( + loss_fn=loss_fn, + cfg=self.cfg, + num_microbatches=num_microbatches, + sampling_params=self.sampling_params, + draft_model=self.draft_model, + ) + + # Placeholder N=1: loss returns un-normalized sums. ``backward`` + # deposits raw ``d(sum)/dθ`` into ``param.main_grad`` via the DDP + # hooks. The 1/N rescale happens once at finish. + placeholder_n = torch.tensor(1.0, device="cuda") + + draft_enabled = "draft" in self.cfg and self.cfg["draft"]["enabled"] + + # The critical wrap: hooks fire (accumulate main_grad) but the + # per-call reduce dispatch is gated off. + with self.model.no_sync(): + rerun_state_machine = get_rerun_state_machine() + while rerun_state_machine.should_run_forward_backward(data_iterator): + losses_reduced = megatron_forward_backward( + model=self.model, + data_iterator=data_iterator, + num_microbatches=num_microbatches, + seq_length=padded_seq_length, + mbs=micro_batch_size, + post_processing_fn=loss_post_processor, + forward_only=False, + defer_fp32_logits=self.defer_fp32_logits, + global_valid_seqs=placeholder_n, + global_valid_toks=placeholder_n, + sampling_params=self.sampling_params, + straggler_timer=self.mcore_state.straggler_timer, + draft_model=self.draft_model, + enable_hidden_capture=draft_enabled, + use_linear_ce_fusion_loss=self.cfg["megatron_cfg"].get( + "use_linear_ce_fusion_loss", False + ), + ) + + if self.cfg["megatron_cfg"]["empty_unused_memory_level"] >= 1: + torch.cuda.empty_cache() + + # Collect per-mb metrics from the last PP stage; broadcast to all + # PP ranks so non-last-stage ranks have something to all_reduce + # against at finish. 
Metrics carry the N=1 placeholder for now — + # ``finish_train_step`` rescales by the true 1/N. + if parallel_state.is_pipeline_last_stage(ignore_virtual=True): + mb_metrics_collected = [] + for x in losses_reduced: + mb_metrics_collected.append(dict(x)) + else: + mb_metrics_collected = None + + mb_metrics_collected = broadcast_loss_metrics_from_last_stage( + mb_metrics_collected + ) + + for m in mb_metrics_collected: + state["all_mb_metrics"].append(m) + # ``loss`` key is the un-normalized per-mb scalar; collect for + # the global_loss aggregation at finish. + if "loss" in m: + state["mb_losses"].append(m["loss"]) + + return { + "local_valid_seqs_mb": float(call_local_seqs.item()), + "local_valid_toks_mb": float(call_local_toks.item()), + "num_pipeline_microbatches": int(num_microbatches), + } + + @wrap_with_nvtx_name("megatron_policy_worker/finish_train_step") + def finish_train_step(self, step_id: str) -> dict[str, Any]: + from nemo_rl.algorithms.loss.interfaces import LossType + + state = self._assert_step_open(step_id) + + # All-reduce accumulated mask sums across DP to recover true N. + to_reduce = torch.stack( + [state["local_valid_seqs"], state["local_valid_toks"]] + ).to(torch.float64) + torch.distributed.all_reduce( + to_reduce, group=parallel_state.get_data_parallel_group() + ) + global_valid_seqs = to_reduce[0] + global_valid_toks = to_reduce[1] + + if state["loss_type"] == LossType.TOKEN_LEVEL: + n_true = global_valid_toks + else: + n_true = global_valid_seqs + n_safe = n_true if n_true.item() > 0 else torch.tensor(1.0, device="cuda") + inv_n = float((1.0 / n_safe).item()) + + # Rescale all locally-accumulated gradients by 1/N. The reduce + # below sees the rescaled grads; for all_reduce the result is the + # global mean grad; for reduce_scatter (dist-opt) it's the shard. + # Either way, opt.step sees the right-normalized gradient. + self.model.scale_gradients(inv_n) + + # The ONE true cross-DP reduce for the entire step. 
+ self.model.start_grad_sync() + self.model.finish_grad_sync() + + # opt.step clips internally (clip_grad config); operates on the + # already-rescaled grad. Returns (success, grad_norm, num_zeros). + update_successful, grad_norm, num_zeros_in_grad = self.optimizer.step() + + pg_collection = get_pg_collection(self.model) + update_successful = logical_and_across_model_parallel_group( + update_successful, mp_group=pg_collection.mp + ) + grad_norm = reduce_max_stat_across_model_parallel_group( + grad_norm, mp_group=pg_collection.mp + ) + num_zeros_in_grad = reduce_max_stat_across_model_parallel_group( + num_zeros_in_grad, mp_group=pg_collection.mp + ) + + if self.cfg["megatron_cfg"]["empty_unused_memory_level"] >= 2: + torch.cuda.empty_cache() + + # Restore grad_sync_func before scheduler.step / further state. + # See begin_train_step for why .config is accessed by string. + finish_model_config = getattr(self.model, "config", None) + if finish_model_config is not None: + finish_model_config.grad_sync_func = state["saved_grad_sync_func"] + + # Scheduler increment matches sync path's ``increment=gbs``. + self.scheduler.step(increment=state["gbs"]) + + # Per-mb metrics were computed with N=1; rescale to match what the + # sync path produces. ``masked_mean`` is linear in 1/N so a single + # scalar multiply per metric recovers the normalized value. 
+ rescaled_metrics: list[dict[str, Any]] = [] + curr_lr = self.scheduler.get_lr(self.optimizer.param_groups[0]) + curr_wd = self.scheduler.get_wd() + global_valid_seqs_f = float(global_valid_seqs.item()) + global_valid_toks_f = float(global_valid_toks.item()) + + for m in state["all_mb_metrics"]: + out: dict[str, Any] = {} + for k, v in m.items(): + if "_min" in k or "_max" in k: + out[k] = v + elif isinstance(v, torch.Tensor): + out[k] = v.detach() * inv_n + else: + out[k] = v * inv_n + out["lr"] = curr_lr + out["wd"] = curr_wd + out["global_valid_seqs"] = global_valid_seqs_f + out["global_valid_toks"] = global_valid_toks_f + rescaled_metrics.append(out) + + # Scale per-mb losses by 1/N and reduce per-call sums. + scaled_losses = [lv * inv_n for lv in state["mb_losses"]] + losses_to_aggregate = [torch.tensor(scaled_losses).sum().item()] + + mb_metrics, global_loss = aggregate_training_statistics( + all_mb_metrics=rescaled_metrics, + losses=losses_to_aggregate, + data_parallel_group=parallel_state.get_data_parallel_group(), + ) + + metrics = { + "global_loss": global_loss.cpu(), + "rank": torch.distributed.get_rank(), + "gpu_name": torch.cuda.get_device_name(), + "model_dtype": self.dtype, + "all_mb_metrics": mb_metrics, + "grad_norm": torch.tensor([grad_norm]), + } + + # MoE aux-loss metrics: same convention as sync train() — scale + # by the total pipeline-microbatch count accumulated across all + # train_microbatch calls. 
+ model_config = getattr(self.model, "config", None) + num_moe_experts = getattr(model_config, "num_moe_experts", None) + if num_moe_experts is not None and num_moe_experts > 1: + moe_loss_scale = 1.0 / max(1, state["total_num_microbatches"]) + moe_metrics = get_moe_metrics( + loss_scale=moe_loss_scale, + per_layer_logging=self.cfg["megatron_cfg"]["moe_per_layer_logging"], + ) + if moe_metrics: + metrics["moe_metrics"] = moe_metrics + + self._train_step_state = None + return metrics + + @wrap_with_nvtx_name("megatron_policy_worker/abort_train_step") + def abort_train_step(self, step_id: str) -> None: + state = getattr(self, "_train_step_state", None) + if state is None: + return + if state["step_id"] != step_id: + raise RuntimeError( + f"abort_train_step({step_id!r}) does not match open step " + f"{state['step_id']!r}" + ) + # Restore grad_sync_func first so the model is back to a normal + # state before zero_grad_buffer touches anything. + # See begin_train_step for why .config is accessed by string. 
+ abort_model_config = getattr(self.model, "config", None) + if abort_model_config is not None: + abort_model_config.grad_sync_func = state["saved_grad_sync_func"] + self.model.zero_grad_buffer() + self.optimizer.zero_grad() + self._train_step_state = None + @wrap_with_nvtx_name("megatron_policy_worker/get_logprobs") def get_logprobs( self, *, data: BatchedDataDict[Any], micro_batch_size: Optional[int] = None diff --git a/pyrefly.toml b/pyrefly.toml index edea330980..818468da9e 100644 --- a/pyrefly.toml +++ b/pyrefly.toml @@ -152,6 +152,7 @@ project-includes = [ "nemo_rl/models/generation/vllm/vllm_backend.py", "nemo_rl/models/huggingface/__init__.py", "nemo_rl/models/megatron/__init__.py", + "nemo_rl/models/megatron/draft/__init__.py", "nemo_rl/models/policy/__init__.py", "nemo_rl/models/policy/interfaces.py", "nemo_rl/models/policy/utils.py", diff --git a/tests/unit/models/policy/test_megatron_split_state.py b/tests/unit/models/policy/test_megatron_split_state.py new file mode 100644 index 0000000000..fb50c839ef --- /dev/null +++ b/tests/unit/models/policy/test_megatron_split_state.py @@ -0,0 +1,638 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""CPU state-machine tests for MegatronPolicyWorkerImpl's split-API. + +These tests cover the lifecycle and call-order invariants — they do NOT +exercise real distributed comms, the mcore scheduler, or the optimizer. 
+Numerical equivalence vs sync ``train()`` lives in the GPU parity tests. + +The bugs these catch: + - silent gradient over-counting if ``model.no_sync()`` is not wrapped + around ``megatron_forward_backward`` (the mcore DDP hooks would + dispatch a per-call reduce, ADDING to an already-reduced bucket). + - PP>1 pipeline-schedule bypass if ``model.config.grad_sync_func`` is + not nulled for the step's duration. + - ``trainer_version`` advancing on abort. + - ``zero_grad_buffer`` not called at begin (mcore's contiguous grad + buffer leaks stale grads otherwise). + - off-by-one in ``total_num_microbatches`` (used to scale MoE aux-loss). +""" + +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest +import torch + +# megatron.bridge is only available with the mcore extras. Without it the +# eager import of megatron_policy_worker (transitively imports megatron.bridge) +# fails at COLLECTION time on non-mcore shards, which then breaks every other +# test in that shard. importorskip stops collection cleanly here. +pytest.importorskip("megatron.bridge") + +# Eagerly import the worker module so ``unittest.mock.patch`` can resolve +# attributes on it via ``getattr``. Without this the patch path +# ``nemo_rl.models.policy.workers.megatron_policy_worker.`` fails +# at ``getattr(workers, "megatron_policy_worker")``. 
+import nemo_rl.models.policy.workers.megatron_policy_worker # noqa: E402,F401 + +pytestmark = pytest.mark.mcore + +# Module path of the worker under test +WORKER_MOD = "nemo_rl.models.policy.workers.megatron_policy_worker" + + +# ── Mock fabric ────────────────────────────────────────────────────────── + + +def _make_mock_model(): + """A mcore-DDP-shaped mock: exposes the methods + attributes the + split-API touches, plus an ``inference_params`` attribute and a + ``modules()`` that yields nothing (so the inference-cache reset loop + is a no-op).""" + model = MagicMock() + model.config = MagicMock() + model.config.grad_sync_func = "ORIGINAL_GRAD_SYNC_FUNC" # sentinel + model.config.num_moe_experts = None # disable MoE branch + # no_sync() is a context manager — return a MagicMock that supports + # __enter__/__exit__ so the `with self.model.no_sync():` block works. + model.no_sync = MagicMock( + return_value=MagicMock( + __enter__=MagicMock(return_value=None), + __exit__=MagicMock(return_value=False), + ) + ) + model.modules = MagicMock(return_value=iter([])) + model.inference_params = None + model.parameters = MagicMock( + return_value=iter([]) + ) # no params for the rescale loop + return model + + +def _make_worker(loss_type): + """Construct a MegatronPolicyWorkerImpl instance with all heavy + attributes mocked. Bypasses __init__ via ``object.__new__``.""" + # Lazy import so the module-level mcore imports happen inside the + # mcore-marked test process. 
+ from nemo_rl.models.policy.workers.megatron_policy_worker import ( + MegatronPolicyWorkerImpl, + ) + + w = object.__new__(MegatronPolicyWorkerImpl) + w.model = _make_mock_model() + w.optimizer = MagicMock() + # MegatronOptimizer.step returns (success, grad_norm, num_zeros) + w.optimizer.step.return_value = (True, 0.5, 0) + w.optimizer.param_groups = [{"lr": 1e-4, "weight_decay": 0.01}] + w.scheduler = MagicMock() + w.scheduler.get_lr.return_value = 1e-4 + w.scheduler.get_wd.return_value = 0.01 + w.mcore_state = MagicMock() + w.mcore_state.straggler_timer = None + w.cfg = { + "train_global_batch_size": 32, + "train_micro_batch_size": 4, + "megatron_cfg": { + "empty_unused_memory_level": 0, + "moe_per_layer_logging": False, + "use_linear_ce_fusion_loss": False, + }, + } + w.dp_size = 2 + w.cp_size = 1 + w.sampling_params = None + w.draft_model = None + w.defer_fp32_logits = False + w.dtype = torch.float32 + w._is_reward_model = False + + # Stash a loss_fn with the requested loss_type for tests that need one. + w._test_loss_fn = MagicMock(loss_type=loss_type) + return w + + +@pytest.fixture +def mock_module_symbols(): + """Patch every module-level symbol that the split-API methods call + into. Yields a dict of name → mock for assertions.""" + # Make `aggregate_training_statistics` return ({}, scalar) — what the + # finish path expects. 
+    # Canned return value for the patched ``aggregate_training_statistics``:
+    # a (dict, tensor) pair.
+    agg_ret = ({"loss": [0.0]}, torch.tensor(0.5))
+
+    # Canned values keyed by the worker-module symbol they stand in for.
+    # NOTE(review): the "get_moe_metrics" entry is never read in the visible
+    # part of this fixture; the corresponding patch() below passes
+    # ``return_value={}`` directly.
+    patches = {
+        "megatron_forward_backward": [
+            {"loss": 0.5, "global_valid_seqs": 8.0, "global_valid_toks": 256.0}
+        ],
+        "get_microbatch_iterator": (iter([]), 2, 4, 16, 16),  # 2 pipeline mbs per call
+        "LossPostProcessor": MagicMock(),
+        "broadcast_loss_metrics_from_last_stage": lambda m: m,
+        "get_pg_collection": MagicMock(mp=MagicMock()),
+        "logical_and_across_model_parallel_group": lambda v, mp_group: v,
+        "reduce_max_stat_across_model_parallel_group": lambda v, mp_group: v,
+        "aggregate_training_statistics": agg_ret,
+        "get_moe_metrics": MagicMock(return_value={}),
+    }
+
+    # Replace the worker-module symbols and a few torch entry points for as
+    # long as the fixture is active; the resulting mocks are handed to the
+    # test through the dict yielded below.
+    with (
+        patch(
+            f"{WORKER_MOD}.megatron_forward_backward",
+            return_value=patches["megatron_forward_backward"],
+        ) as mfb,
+        patch(
+            f"{WORKER_MOD}.get_microbatch_iterator",
+            return_value=patches["get_microbatch_iterator"],
+        ) as gmi,
+        patch(
+            f"{WORKER_MOD}.LossPostProcessor", return_value=patches["LossPostProcessor"]
+        ) as lpp,
+        patch(
+            f"{WORKER_MOD}.broadcast_loss_metrics_from_last_stage",
+            side_effect=patches["broadcast_loss_metrics_from_last_stage"],
+        ) as bcast,
+        patch(
+            f"{WORKER_MOD}.get_pg_collection", return_value=patches["get_pg_collection"]
+        ) as gpgc,
+        patch(
+            f"{WORKER_MOD}.logical_and_across_model_parallel_group",
+            side_effect=patches["logical_and_across_model_parallel_group"],
+        ) as land,
+        patch(
+            f"{WORKER_MOD}.reduce_max_stat_across_model_parallel_group",
+            side_effect=patches["reduce_max_stat_across_model_parallel_group"],
+        ) as rmax,
+        patch(
+            f"{WORKER_MOD}.aggregate_training_statistics",
+            return_value=patches["aggregate_training_statistics"],
+        ) as agg,
+        patch(f"{WORKER_MOD}.get_moe_metrics", return_value={}) as moe,
+        patch(f"{WORKER_MOD}.get_rerun_state_machine") as grsm,
+        patch(f"{WORKER_MOD}.parallel_state") as pstate,
+        patch("torch.distributed.all_reduce") as ar,
+        patch("torch.cuda.empty_cache") as cec,
+        patch("torch.cuda.get_device_name", return_value="H100"),
+        patch("torch.distributed.get_rank", return_value=0),
+    ):
+        # rerun state machine: fire forward+backward once per train_microbatch
+        # (the side_effect list alternates True/False, 200 values in total)
+        rsm = MagicMock()
+        rsm.should_run_forward_backward.side_effect = [True, False] * 100
+        grsm.return_value = rsm
+
+        # parallel_state mocks
+        pstate.is_pipeline_last_stage.return_value = True
+        pstate.get_data_parallel_group.return_value = MagicMock()
+
+        # Short-name handles so tests can inspect or reconfigure the mocks.
+        # The get_device_name / get_rank patches are active but not exposed.
+        yield {
+            "mfb": mfb,
+            "gmi": gmi,
+            "lpp": lpp,
+            "bcast": bcast,
+            "gpgc": gpgc,
+            "land": land,
+            "rmax": rmax,
+            "agg": agg,
+            "moe": moe,
+            "grsm": grsm,
+            "pstate": pstate,
+            "all_reduce": ar,
+            "empty_cache": cec,
+        }
+
+
+def _fake_batch():
+    """Return a minimal BatchedDataDict-ish dict the mask-sum block can read.
+
+    train_microbatch reads ``data["sample_mask"]``, ``data["token_mask"]``,
+    and (only as a fallback for the no-token-mask path) ``data["input_ids"]``.
+
+    Returns:
+        A plain dict with ``sample_mask`` (float32 ones, shape ``(8,)``),
+        ``token_mask`` (float32 ones, shape ``(8, 257)``) and ``input_ids``
+        (int64 zeros, shape ``(8, 257)``). A fresh dict is built on each call.
+    """
+    # 8 samples, all valid (mask=1); 256 valid tokens each
+    sample_mask = torch.ones(8, dtype=torch.float32)
+    token_mask = torch.ones(8, 257, dtype=torch.float32)  # token_mask[:, 1:] → 256 toks
+    input_ids = torch.zeros(8, 257, dtype=torch.long)
+    return {
+        "sample_mask": sample_mask,
+        "token_mask": token_mask,
+        "input_ids": input_ids,
+    }
+
+
+# ── BEGIN ────────────────────────────────────────────────────────────────
+
+
+class TestBegin:
+    """begin_train_step: state initialisation, gradient reset, double-begin guard."""
+
+    def test_opens_state(self, mock_module_symbols):
+        """begin records step_id, loss_type, gbs and mbs, and starts the
+        microbatch counter at zero."""
+        from nemo_rl.algorithms.loss.interfaces import LossType
+
+        w = _make_worker(LossType.TOKEN_LEVEL)
+        w.begin_train_step("step-0", loss_fn=w._test_loss_fn, gbs=16, mbs=4)
+        assert w._train_step_state is not None
+        assert w._train_step_state["step_id"] == "step-0"
+        assert w._train_step_state["loss_type"] == LossType.TOKEN_LEVEL
+        assert w._train_step_state["gbs"] == 16
+        assert w._train_step_state["mbs"] == 4
+        assert w._train_step_state["total_num_microbatches"] == 0
+
+    def test_calls_zero_grad_and_zero_grad_buffer(self, mock_module_symbols):
+        """begin calls model.zero_grad_buffer, optimizer.zero_grad and
+        model.train exactly once each."""
+        from nemo_rl.algorithms.loss.interfaces import LossType
+
+        w = _make_worker(LossType.TOKEN_LEVEL)
+        w.begin_train_step("step-0", loss_fn=w._test_loss_fn)
+        w.model.zero_grad_buffer.assert_called_once()
+        w.optimizer.zero_grad.assert_called_once()
+        w.model.train.assert_called_once()
+
+    def test_saves_and_nulls_grad_sync_func(self, mock_module_symbols):
+        """The PP scheduler's direct reduce dispatch must be suppressed
+        for the duration of the step. Otherwise PP>1 silently corrupts
+        grads even when ``no_sync`` is set on the bucket groups."""
+        from nemo_rl.algorithms.loss.interfaces import LossType
+
+        w = _make_worker(LossType.TOKEN_LEVEL)
+        # Precondition: the worker built by _make_worker carries this sentinel.
+        assert w.model.config.grad_sync_func == "ORIGINAL_GRAD_SYNC_FUNC"
+        w.begin_train_step("step-0", loss_fn=w._test_loss_fn)
+        assert w.model.config.grad_sync_func is None
+        assert w._train_step_state["saved_grad_sync_func"] == "ORIGINAL_GRAD_SYNC_FUNC"
+
+    def test_double_begin_raises(self, mock_module_symbols):
+        """A second begin while a step is open raises RuntimeError
+        matching "already open"."""
+        from nemo_rl.algorithms.loss.interfaces import LossType
+
+        w = _make_worker(LossType.TOKEN_LEVEL)
+        w.begin_train_step("step-0", loss_fn=w._test_loss_fn)
+        with pytest.raises(RuntimeError, match="already open"):
+            w.begin_train_step("step-1", loss_fn=w._test_loss_fn)
+
+    def test_uses_cfg_defaults_when_gbs_mbs_omitted(self, mock_module_symbols):
+        """Without explicit gbs/mbs, the stored values equal the worker cfg's
+        train_global_batch_size / train_micro_batch_size."""
+        from nemo_rl.algorithms.loss.interfaces import LossType
+
+        w = _make_worker(LossType.TOKEN_LEVEL)
+        w.begin_train_step("step-0", loss_fn=w._test_loss_fn)
+        assert w._train_step_state["gbs"] == w.cfg["train_global_batch_size"]
+        assert w._train_step_state["mbs"] == w.cfg["train_micro_batch_size"]
+
+
+# ── _assert_step_open ────────────────────────────────────────────────────
+
+
+class TestAssertStepOpen:
+    """Step-open guard as seen through _assert_step_open, train_microbatch
+    and finish_train_step."""
+
+    def test_raises_when_no_step_open(self, mock_module_symbols):
+        """With no begin, the guard raises RuntimeError matching
+        "no train step open"."""
+        from nemo_rl.algorithms.loss.interfaces import LossType
+
+        w = _make_worker(LossType.TOKEN_LEVEL)
+        with pytest.raises(RuntimeError, match="no train step open"):
+            w._assert_step_open("step-0")
+
+    def test_raises_on_step_id_mismatch(self, mock_module_symbols):
+        """A step_id different from the open one raises RuntimeError
+        matching "step_id mismatch"."""
+        from nemo_rl.algorithms.loss.interfaces import LossType
+
+        w = _make_worker(LossType.TOKEN_LEVEL)
+        w.begin_train_step("step-correct", loss_fn=w._test_loss_fn)
+        with pytest.raises(RuntimeError, match="step_id mismatch"):
+            w._assert_step_open("step-WRONG")
+
+    def test_train_microbatch_without_begin_raises(self, mock_module_symbols):
+        """train_microbatch before any begin raises the same
+        "no train step open" error."""
+        from nemo_rl.algorithms.loss.interfaces import LossType
+
+        w = _make_worker(LossType.TOKEN_LEVEL)
+        with pytest.raises(RuntimeError, match="no train step open"):
+            w.train_microbatch("step-0", _fake_batch())
+
+    def test_finish_without_begin_raises(self, mock_module_symbols):
+        """finish_train_step before any begin raises the same
+        "no train step open" error."""
+        from nemo_rl.algorithms.loss.interfaces import LossType
+
+        w = _make_worker(LossType.TOKEN_LEVEL)
+        with pytest.raises(RuntimeError, match="no train step open"):
+            w.finish_train_step("step-0")
+
+
+# ── train_microbatch ─────────────────────────────────────────────────────
+
+
+class TestTrainMicrobatch:
+    """train_microbatch: no_sync wrapping, placeholder normalisers and
+    per-step accumulation."""
+
+    def test_wraps_forward_backward_in_no_sync(self, mock_module_symbols):
+        """The single most important assertion in this file. Without the
+        no_sync wrap, mcore DDP dispatches a per-call cross-DP reduce on
+        the partially-accumulated buffer — silently corrupting grads."""
+        from nemo_rl.algorithms.loss.interfaces import LossType
+
+        w = _make_worker(LossType.TOKEN_LEVEL)
+        w.begin_train_step("s0", loss_fn=w._test_loss_fn)
+        w.train_microbatch("s0", _fake_batch())
+        # no_sync() must have been ENTERED (called as a context manager).
+        # MagicMock with __enter__/__exit__ records the __enter__ call.
+        ctx = w.model.no_sync.return_value
+        ctx.__enter__.assert_called()
+        ctx.__exit__.assert_called()
+
+    def test_invokes_megatron_forward_backward_once(self, mock_module_symbols):
+        """One train_microbatch call results in exactly one call to the
+        patched megatron_forward_backward."""
+        from nemo_rl.algorithms.loss.interfaces import LossType
+
+        w = _make_worker(LossType.TOKEN_LEVEL)
+        w.begin_train_step("s0", loss_fn=w._test_loss_fn)
+        w.train_microbatch("s0", _fake_batch())
+        assert mock_module_symbols["mfb"].call_count == 1
+
+    def test_passes_placeholder_n_one_to_loss(self, mock_module_symbols):
+        """The N=1 trick: loss must be called with global_valid_*=1 so it
+        returns un-normalized sums; finish does the 1/N rescale."""
+        from nemo_rl.algorithms.loss.interfaces import LossType
+
+        w = _make_worker(LossType.TOKEN_LEVEL)
+        w.begin_train_step("s0", loss_fn=w._test_loss_fn)
+        w.train_microbatch("s0", _fake_batch())
+        # Inspect the keyword arguments of the most recent mocked call.
+        kwargs = mock_module_symbols["mfb"].call_args.kwargs
+        # placeholder_n is a tensor(1.0)
+        assert "global_valid_seqs" in kwargs
+        assert "global_valid_toks" in kwargs
+        assert float(kwargs["global_valid_seqs"].item()) == pytest.approx(1.0)
+        assert float(kwargs["global_valid_toks"].item()) == pytest.approx(1.0)
+
+    def test_accumulates_mask_sums_across_calls(self, mock_module_symbols):
+        """local_valid_seqs / local_valid_toks grow by one batch's worth
+        (8 and 2048) on every train_microbatch call."""
+        from nemo_rl.algorithms.loss.interfaces import LossType
+
+        w = _make_worker(LossType.TOKEN_LEVEL)
+        w.begin_train_step("s0", loss_fn=w._test_loss_fn)
+        # _fake_batch has sample_mask sum = 8, token_mask*sample_mask sum = 8*256 = 2048
+        w.train_microbatch("s0", _fake_batch())
+        assert float(w._train_step_state["local_valid_seqs"].item()) == pytest.approx(
+            8.0
+        )
+        assert float(w._train_step_state["local_valid_toks"].item()) == pytest.approx(
+            2048.0
+        )
+        w.train_microbatch("s0", _fake_batch())
+        assert float(w._train_step_state["local_valid_seqs"].item()) == pytest.approx(
+            16.0
+        )
+        assert float(w._train_step_state["local_valid_toks"].item()) == pytest.approx(
+            4096.0
+        )
+
+    def test_total_num_microbatches_accumulates(self, mock_module_symbols):
+        """total_num_microbatches sums the per-call microbatch count
+        (3 calls x 2 = 6 with the fixture's mock)."""
+        from nemo_rl.algorithms.loss.interfaces import LossType
+
+        w = _make_worker(LossType.TOKEN_LEVEL)
+        w.begin_train_step("s0", loss_fn=w._test_loss_fn)
+        # get_microbatch_iterator mock returns num_microbatches=2 per call
+        w.train_microbatch("s0", _fake_batch())
+        w.train_microbatch("s0", _fake_batch())
+        w.train_microbatch("s0", _fake_batch())
+        assert w._train_step_state["total_num_microbatches"] == 6
+
+    def test_does_not_call_optimizer_step(self, mock_module_symbols):
+        """trainer_version semantics: optimizer.step() must NOT fire
+        per train_microbatch — only at finish."""
+        from nemo_rl.algorithms.loss.interfaces import LossType
+
+        w = _make_worker(LossType.TOKEN_LEVEL)
+        w.begin_train_step("s0", loss_fn=w._test_loss_fn)
+        w.train_microbatch("s0", _fake_batch())
+        w.train_microbatch("s0", _fake_batch())
+        w.optimizer.step.assert_not_called()
+
+
+# ── finish_train_step ────────────────────────────────────────────────────
+
+
+class TestFinish:
+    """finish_train_step: gradient rescale, sync/step ordering, cleanup and
+    the returned metrics."""
+
+    def _setup_open_step(self, mock_module_symbols, loss_type):
+        """Return a worker with step "s0" open and one microbatch trained.
+
+        ``mock_module_symbols`` is accepted but not used inside this helper.
+        """
+        w = _make_worker(loss_type)
+        w.begin_train_step("s0", loss_fn=w._test_loss_fn)
+        w.train_microbatch("s0", _fake_batch())
+        return w
+
+    def test_rescales_grads_with_inv_n(self, mock_module_symbols):
+        """The 1/N rescale must happen ON the local main_grad BEFORE the
+        cross-DP reduce — otherwise the reduce sees un-rescaled sums."""
+        from nemo_rl.algorithms.loss.interfaces import LossType
+
+        w = self._setup_open_step(mock_module_symbols, LossType.TOKEN_LEVEL)
+        w.finish_train_step("s0")
+        # scale_gradients should have been called once with a positive 1/N
+        # scalar; the bound below also admits exactly 1.0 (i.e. N == 1).
+        w.model.scale_gradients.assert_called_once()
+        arg = w.model.scale_gradients.call_args.args[0]
+        assert 0 < arg <= 1.0
+
+    def test_start_then_finish_grad_sync_called_after_rescale(
+        self, mock_module_symbols
+    ):
+        """Call order matters: scale_gradients -> start_grad_sync ->
+        finish_grad_sync -> optimizer.step."""
+        from nemo_rl.algorithms.loss.interfaces import LossType
+
+        w = self._setup_open_step(mock_module_symbols, LossType.TOKEN_LEVEL)
+        # Record call order via a shared list
+        order: list[str] = []
+        w.model.scale_gradients.side_effect = lambda s: order.append("scale")
+        w.model.start_grad_sync.side_effect = lambda: order.append("start_sync")
+        w.model.finish_grad_sync.side_effect = lambda: order.append("finish_sync")
+        # list.append returns None, so ``or`` makes the mocked step() return
+        # the 3-tuple after recording the call.
+        w.optimizer.step.side_effect = lambda: (
+            order.append("opt_step") or (True, 0.5, 0)
+        )
+        w.finish_train_step("s0")
+        assert order == ["scale", "start_sync", "finish_sync", "opt_step"]
+
+    def test_picks_global_valid_toks_for_token_level_loss(self, mock_module_symbols):
+        """N selection: TOKEN_LEVEL → N = global_valid_toks (not seqs)."""
+        from nemo_rl.algorithms.loss.interfaces import LossType
+
+        w = self._setup_open_step(mock_module_symbols, LossType.TOKEN_LEVEL)
+        w.finish_train_step("s0")
+        # local_valid_toks accumulated = 2048; with mocked all_reduce as no-op,
+        # global_valid_toks == 2048 → inv_n = 1/2048
+        arg = w.model.scale_gradients.call_args.args[0]
+        assert arg == pytest.approx(1.0 / 2048.0, rel=1e-4)
+
+    def test_picks_global_valid_seqs_for_sequence_level_loss(self, mock_module_symbols):
+        """N selection: SEQUENCE_LEVEL → N = global_valid_seqs (8 here)."""
+        from nemo_rl.algorithms.loss.interfaces import LossType
+
+        w = self._setup_open_step(mock_module_symbols, LossType.SEQUENCE_LEVEL)
+        w.finish_train_step("s0")
+        # local_valid_seqs = 8 → inv_n = 1/8
+        arg = w.model.scale_gradients.call_args.args[0]
+        assert arg == pytest.approx(1.0 / 8.0, rel=1e-4)
+
+    def test_restores_grad_sync_func(self, mock_module_symbols):
+        """After finish, model.config.grad_sync_func is back to its
+        pre-begin sentinel."""
+        from nemo_rl.algorithms.loss.interfaces import LossType
+
+        w = self._setup_open_step(mock_module_symbols, LossType.TOKEN_LEVEL)
+        w.finish_train_step("s0")
+        assert w.model.config.grad_sync_func == "ORIGINAL_GRAD_SYNC_FUNC"
+
+    def test_clears_train_step_state(self, mock_module_symbols):
+        """finish leaves _train_step_state set to None."""
+        from nemo_rl.algorithms.loss.interfaces import LossType
+
+        w = self._setup_open_step(mock_module_symbols, LossType.TOKEN_LEVEL)
+        w.finish_train_step("s0")
+        assert w._train_step_state is None
+
+    def test_calls_scheduler_step_with_increment_gbs(self, mock_module_symbols):
+        """scheduler.step is called once with increment equal to the gbs
+        held in the step state at finish time."""
+        from nemo_rl.algorithms.loss.interfaces import LossType
+
+        w = self._setup_open_step(mock_module_symbols, LossType.TOKEN_LEVEL)
+        # Overwrite the stored gbs so the expected increment is unambiguous.
+        w._train_step_state["gbs"] = 64
+        w.finish_train_step("s0")
+        w.scheduler.step.assert_called_once_with(increment=64)
+
+    def test_returns_metrics_dict(self, mock_module_symbols):
+        """The value returned by finish contains every key listed below."""
+        from nemo_rl.algorithms.loss.interfaces import LossType
+
+        w = self._setup_open_step(mock_module_symbols, LossType.TOKEN_LEVEL)
+        metrics = w.finish_train_step("s0")
+        for key in (
+            "global_loss",
+            "rank",
+            "gpu_name",
+            "model_dtype",
+            "all_mb_metrics",
+            "grad_norm",
+        ):
+            assert key in metrics, f"missing {key!r}"
+
+    def test_moe_branch_skipped_when_num_experts_is_none(self, mock_module_symbols):
+        """With num_moe_experts set to None, the result has no
+        "moe_metrics" key."""
+        from nemo_rl.algorithms.loss.interfaces import LossType
+
+        w = self._setup_open_step(mock_module_symbols, LossType.TOKEN_LEVEL)
+        w.model.config.num_moe_experts = None
+        metrics = w.finish_train_step("s0")
+        assert "moe_metrics" not in metrics
+
+    def test_moe_branch_uses_total_num_microbatches_for_scale(
+        self, mock_module_symbols
+    ):
+        """MoE aux-loss scale must use the accumulated total, not the
+        per-call num_microbatches."""
+        from nemo_rl.algorithms.loss.interfaces import LossType
+
+        w = _make_worker(LossType.TOKEN_LEVEL)
+        w.model.config.num_moe_experts = 4
+        # Have get_moe_metrics return non-empty so the branch fires
+        mock_module_symbols["moe"].return_value = {"aux_loss": 0.1}
+        w.begin_train_step("s0", loss_fn=w._test_loss_fn)
+        # 3 train_microbatch calls × 2 pipeline mbs each = 6
+        for _ in range(3):
+            w.train_microbatch("s0", _fake_batch())
+        w.finish_train_step("s0")
+        # get_moe_metrics receives loss_scale=1/6
+        kwargs = mock_module_symbols["moe"].call_args.kwargs
+        assert kwargs["loss_scale"] == pytest.approx(1.0 / 6.0, rel=1e-6)
+
+
+# ── abort_train_step ─────────────────────────────────────────────────────
+
+
+class TestAbort:
+    """abort_train_step: cleanup without an optimizer step."""
+
+    def test_restores_grad_sync_func(self, mock_module_symbols):
+        """After abort, model.config.grad_sync_func is back to its
+        pre-begin sentinel."""
+        from nemo_rl.algorithms.loss.interfaces import LossType
+
+        w = _make_worker(LossType.TOKEN_LEVEL)
+        w.begin_train_step("s0", loss_fn=w._test_loss_fn)
+        w.abort_train_step("s0")
+        assert w.model.config.grad_sync_func == "ORIGINAL_GRAD_SYNC_FUNC"
+
+    def test_zero_grad_buffer_and_zero_grad_called(self, mock_module_symbols):
+        """abort itself calls zero_grad_buffer and optimizer.zero_grad once."""
+        from nemo_rl.algorithms.loss.interfaces import LossType
+
+        w = _make_worker(LossType.TOKEN_LEVEL)
+        w.begin_train_step("s0", loss_fn=w._test_loss_fn)
+        # Forget any calls recorded so far so only abort's own calls are counted.
+        w.model.zero_grad_buffer.reset_mock()
+        w.optimizer.zero_grad.reset_mock()
+        w.abort_train_step("s0")
+        w.model.zero_grad_buffer.assert_called_once()
+        w.optimizer.zero_grad.assert_called_once()
+
+    def test_does_not_call_optimizer_step(self, mock_module_symbols):
+        """begin → train_microbatch → abort never calls optimizer.step."""
+        from nemo_rl.algorithms.loss.interfaces import LossType
+
+        w = _make_worker(LossType.TOKEN_LEVEL)
+        w.begin_train_step("s0", loss_fn=w._test_loss_fn)
+        w.train_microbatch("s0", _fake_batch())
+        w.abort_train_step("s0")
+        w.optimizer.step.assert_not_called()
+
+    def test_clears_train_step_state(self, mock_module_symbols):
+        """abort leaves _train_step_state set to None."""
+        from nemo_rl.algorithms.loss.interfaces import LossType
+
+        w = _make_worker(LossType.TOKEN_LEVEL)
+        w.begin_train_step("s0", loss_fn=w._test_loss_fn)
+        w.abort_train_step("s0")
+        assert w._train_step_state is None
+
+    def test_idempotent_with_no_open_step(self, mock_module_symbols):
+        """abort is a no-op when nothing is open."""
+        from nemo_rl.algorithms.loss.interfaces import LossType
+
+        w = _make_worker(LossType.TOKEN_LEVEL)
+        # Should not raise
+        w.abort_train_step("s0")
+        # getattr default tolerates the attribute being absent altogether.
+        assert getattr(w, "_train_step_state", None) is None
+
+    def test_mismatched_step_id_raises(self, mock_module_symbols):
+        """abort with a step_id other than the open one raises RuntimeError
+        matching "does not match open step"."""
+        from nemo_rl.algorithms.loss.interfaces import LossType
+
+        w = _make_worker(LossType.TOKEN_LEVEL)
+        w.begin_train_step("s0", loss_fn=w._test_loss_fn)
+        with pytest.raises(RuntimeError, match="does not match open step"):
+            w.abort_train_step("s-WRONG")
+
+    def test_can_begin_new_step_after_abort(self, mock_module_symbols):
+        """After an abort, a new begin succeeds and starts with
+        local_valid_seqs at zero."""
+        from nemo_rl.algorithms.loss.interfaces import LossType
+
+        w = _make_worker(LossType.TOKEN_LEVEL)
+        w.begin_train_step("s0", loss_fn=w._test_loss_fn)
+        w.train_microbatch("s0", _fake_batch())
+        w.abort_train_step("s0")
+        # New step opens cleanly
+        w.begin_train_step("s1", loss_fn=w._test_loss_fn)
+        assert w._train_step_state["step_id"] == "s1"
+        assert float(w._train_step_state["local_valid_seqs"].item()) == 0.0
+
+
+# ── grad_sync_func full lifecycle (integration of begin → finish/abort) ─
+
+
+class TestGradSyncFuncLifecycle:
+    """model.config.grad_sync_func is nulled by begin and restored by
+    finish or abort."""
+
+    def test_begin_finish_round_trip(self, mock_module_symbols):
+        """A custom grad_sync_func is None while the step is open and is
+        restored by finish."""
+        from nemo_rl.algorithms.loss.interfaces import LossType
+
+        w = _make_worker(LossType.TOKEN_LEVEL)
+        sentinel = "MY_CUSTOM_GRAD_SYNC"
+        w.model.config.grad_sync_func = sentinel
+        w.begin_train_step("s0", loss_fn=w._test_loss_fn)
+        assert w.model.config.grad_sync_func is None
+        w.train_microbatch("s0", _fake_batch())
+        w.finish_train_step("s0")
+        assert w.model.config.grad_sync_func == sentinel
+
+    def test_begin_abort_round_trip(self, mock_module_symbols):
+        """A custom grad_sync_func is None while the step is open and is
+        restored by abort."""
+        from nemo_rl.algorithms.loss.interfaces import LossType
+
+        w = _make_worker(LossType.TOKEN_LEVEL)
+        sentinel = "MY_CUSTOM_GRAD_SYNC"
+        w.model.config.grad_sync_func = sentinel
+        w.begin_train_step("s0", loss_fn=w._test_loss_fn)
+        assert w.model.config.grad_sync_func is None
+        w.abort_train_step("s0")
+        assert w.model.config.grad_sync_func == sentinel
+
+    def test_handles_originally_none_grad_sync_func(self, mock_module_symbols):
+        """When PP=1 (or align_grad_reduce=False), grad_sync_func is None
+        to begin with. begin → finish must leave it as None."""
+        from nemo_rl.algorithms.loss.interfaces import LossType
+
+        w = _make_worker(LossType.TOKEN_LEVEL)
+        w.model.config.grad_sync_func = None
+        w.begin_train_step("s0", loss_fn=w._test_loss_fn)
+        assert w.model.config.grad_sync_func is None
+        w.train_microbatch("s0", _fake_batch())
+        w.finish_train_step("s0")
+        assert w.model.config.grad_sync_func is None