Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions nemo_rl/data_plane/worker_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,3 +505,82 @@ def get_reference_policy_logprobs_presharded(
tq_field="reference_policy_logprobs",
)
del result

# ── split-API entrypoints (SC async path) ──────────────────────────────
#
# The split path lets SingleController drive forward/backward per
# microbatch (or per pipeline-batch on Megatron) without stepping the
# optimizer until a full logical batch has accumulated. Backend
# methods (``begin_train_step``, ``train_microbatch``,
# ``finish_train_step``, ``abort_train_step``) own the train-step
# state machine; this mixin just gates them on TQ-presharded data.

@wrap_with_nvtx_name("policy_worker/begin_train_step_presharded")
def begin_train_step_presharded(
    self,
    step_id: str,
    loss_fn: Any,
    gbs: Optional[int] = None,
    mbs: Optional[int] = None,
) -> None:
    """Start a logical train step on this worker without fetching any data.

    This is a pure lifecycle hook: every argument is handed to the
    backend's ``begin_train_step`` unchanged. Per the split-API contract
    the backend is expected to record ``step_id`` / ``loss_fn`` /
    ``gbs`` / ``mbs``, clear gradients, and set up its accumulators
    (``local_valid_seqs`` / ``local_valid_toks`` plus per-microbatch
    metrics) while leaving optimizer state alone.

    Args:
        step_id: Identifier of the logical train step being opened.
        loss_fn: Loss function the backend should use for this step.
        gbs: Optional global batch size, passed through as given.
        mbs: Optional micro batch size, passed through as given.
    """
    # Bundle the arguments once so the delegation is a single forwarding call.
    lifecycle_kwargs = {
        "step_id": step_id,
        "loss_fn": loss_fn,
        "gbs": gbs,
        "mbs": mbs,
    }
    self.begin_train_step(**lifecycle_kwargs)  # type: ignore[attr-defined]

@wrap_with_nvtx_name("policy_worker/train_microbatch_presharded")
def train_microbatch_presharded(
    self,
    step_id: str,
    meta: "KVBatchMeta",
) -> dict[str, Any]:
    """Run one microbatch on this rank: fetch, prepare packing, forward+backward.

    The batch described by ``meta`` is fetched, passed through
    ``_attach_or_repack_pack_metadata``, and then handed to the backend's
    ``train_microbatch``. Per the split-API contract, gradients keep
    accumulating in ``.grad`` from call to call and no ``optimizer.step``
    happens here.

    Args:
        step_id: Identifier of the open logical train step.
        meta: Batch metadata describing what this rank should fetch.

    Returns:
        Whatever the backend's ``train_microbatch`` returns — expected to
        be the per-microbatch metrics (loss, local_valid_*), which the
        caller may surface for diagnostics.
    """
    # A single name is rebound on purpose: the raw fetched batch is not
    # kept referenced once the pack-metadata step has produced its output.
    batch = self._fetch(meta)
    batch = self._attach_or_repack_pack_metadata(batch, meta)
    return self.train_microbatch(step_id=step_id, data=batch)  # type: ignore[attr-defined]

@wrap_with_nvtx_name("policy_worker/finish_train_step_presharded")
def finish_train_step_presharded(
    self,
    step_id: str,
) -> dict[str, Any]:
    """Finalize a logical train step on this worker without fetching any data.

    Pure lifecycle hook that delegates to the backend's
    ``finish_train_step``. Per the split-API contract the backend is
    expected to all-reduce the accumulated ``local_valid_seqs/toks``,
    rescale gradients to the final global normalization, clip, step the
    optimizer and scheduler, and then zero the gradients.

    Args:
        step_id: Identifier of the logical train step being closed.

    Returns:
        The backend's aggregated step result (``loss``, ``grad_norm``,
        ``all_mb_metrics``, …), returned unchanged.
    """
    step_result = self.finish_train_step(step_id=step_id)  # type: ignore[attr-defined]
    return step_result

@wrap_with_nvtx_name("policy_worker/abort_train_step_presharded")
def abort_train_step_presharded(
    self,
    step_id: str,
) -> None:
    """Throw away a partially accumulated train step; the optimizer is not stepped.

    Intended for the case where SC concludes the logical batch will never
    complete (for example, a weight-sync fires mid-step). The call is
    delegated to the backend's ``abort_train_step``, which per the
    split-API contract drops its accumulators and zeros the gradients.

    Args:
        step_id: Identifier of the logical train step being discarded.
    """
    # Any value returned by the backend is intentionally not propagated.
    abort_kwargs = {"step_id": step_id}
    self.abort_train_step(**abort_kwargs)  # type: ignore[attr-defined]
125 changes: 125 additions & 0 deletions nemo_rl/models/policy/tq_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,3 +438,128 @@ def train_from_meta(
warnings.warn(f"Error getting theoretical flops: {e}")

return aggregated_results

# ── split-API fanout (SC async path) ───────────────────────────────────
#
# Counterpart to :meth:`train_from_meta`, exposed to ``PolicyTrainerActor``
# so :class:`SingleControllerActor` can stream microbatches without
# forcing a full-step optimizer.step on every dispatch.
#
# Lifecycle:
# begin_train_step — open step; broadcast loss_fn/gbs/mbs
# train_microbatch_from_meta (N×) — DP-sharded fwd/bwd, grads accumulate
# finish_train_step — all_reduce + opt.step + sched.step
# abort_train_step — drop accumulators, no opt.step
#
# ``train_from_meta`` is unchanged and remains the sync entrypoint.

def begin_train_step(
    self,
    step_id: str,
    loss_fn: LossFunction,
    gbs: Optional[int] = None,
    mbs: Optional[int] = None,
) -> None:
    """Open a logical train step on every worker.

    Resolves the batch sizes, resets the FLOPs tracker (when one is
    configured), and fans ``begin_train_step_presharded`` out to all
    workers, blocking until every worker has answered.

    Args:
        step_id: Identifier of the logical train step being opened.
        loss_fn: Loss function broadcast to every worker.
        gbs: Global batch size; a falsy value (``None`` or ``0``) selects
            ``cfg["train_global_batch_size"]``.
        mbs: Micro batch size; a falsy value (``None`` or ``0``) selects
            ``cfg["train_micro_batch_size"]``.
    """
    # The config is only consulted when the caller did not supply a usable
    # override.
    effective_gbs = gbs if gbs else self.cfg["train_global_batch_size"]
    effective_mbs = mbs if mbs else self.cfg["train_micro_batch_size"]

    tracker = self.flops_tracker
    if tracker is not None:
        # Start FLOPs accounting from scratch for the new step.
        tracker.reset()

    fanout_kwargs = {
        "step_id": step_id,
        "loss_fn": loss_fn,
        "gbs": effective_gbs,
        "mbs": effective_mbs,
    }
    pending = self.worker_group.run_all_workers_single_data(
        "begin_train_step_presharded",
        **fanout_kwargs,
    )
    # Wait for all workers; their return values are not used.
    self.worker_group.get_all_worker_results(pending)

def train_microbatch_from_meta(
    self,
    step_id: str,
    meta: KVBatchMeta,
    timer: Optional[Timer] = None,
) -> dict[str, Any]:
    """Send one DP-sharded microbatch into an already-open train step.

    Follows the same sharding approach as :meth:`train_from_meta`, except
    that no logical-batch size is imposed (``batch_size=None``): ``meta``
    is split across data-parallel ranks and each rank runs
    forward+backward via ``train_microbatch_presharded``. Gradients keep
    accumulating in ``.grad``; the optimizer step is deferred to
    :meth:`finish_train_step`.

    Args:
        step_id: Identifier of the open logical train step.
        meta: Batch metadata for the microbatch to dispatch.
        timer: Optional timer; when truthy, the sharding and the future
            submission phases are timed under ``policy_training/*`` labels.

    Returns:
        The first entry of the gathered worker results, or an empty dict
        when no results came back.
    """
    self._stamp_pad_seqlen(meta)
    seq_packing_args, dyn_batching_args = self._packing_args("train_mb_tokens")
    train_meta = replace(meta, fields=list(DP_TRAIN_FIELDS), task_name="train")

    shard_timing = (
        timer.time("policy_training/shard_meta") if timer else nullcontext()
    )
    with shard_timing:
        per_rank_metas, _ = shard_meta_for_dp(
            train_meta,
            dp_world=self.sharding_annotations.get_axis_size("data_parallel"),
            batch_size=None,
            sequence_packing_args=seq_packing_args,
            dynamic_batching_args=dyn_batching_args,
        )

    tracker = self.flops_tracker
    if tracker is not None:
        # Feed each rank's sequence lengths into the FLOPs accounting.
        for rank_meta in per_rank_metas:
            tracker.track_batch(list(rank_meta.sequence_lengths or []))

    # Axes along which inputs are replicated and outputs are identical.
    model_parallel_axes = (
        "context_parallel",
        "tensor_parallel",
        "pipeline_parallel",
    )
    submit_timing = (
        timer.time("policy_training/submit_microbatch_futures")
        if timer
        else nullcontext()
    )
    with submit_timing:
        futures = self.worker_group.run_all_workers_sharded_data(
            "train_microbatch_presharded",
            meta=per_rank_metas,
            in_sharded_axes=["data_parallel"],
            replicate_on_axes=list(model_parallel_axes),
            output_is_replicated=list(model_parallel_axes),
            common_kwargs={"step_id": step_id},
        )

    rank_results = self.worker_group.get_all_worker_results(futures)
    # Per-microbatch metrics: pass through DP-rank-0 by convention,
    # backend may aggregate later if needed. Surface as-is for now.
    if not rank_results:
        return {}
    return rank_results[0]

def finish_train_step(self, step_id: str) -> dict[str, Any]:
    """Close an open train step: all_reduce, rescale, optimizer.step.

    Fans ``finish_train_step_presharded`` out to every worker, then folds
    the per-rank results with ``_aggregate_train_results`` so the return
    value has the same shape as :meth:`train_from_meta` and callers need
    no special case for the split path.

    Args:
        step_id: Identifier of the logical train step being closed.

    Returns:
        The aggregated step results. When a FLOPs tracker is configured,
        ``total_flops`` and ``num_ranks`` are added to them.
    """
    pending = self.worker_group.run_all_workers_single_data(
        "finish_train_step_presharded",
        step_id=step_id,
    )
    per_rank_results = self.worker_group.get_all_worker_results(pending)
    step_summary = _aggregate_train_results(per_rank_results)

    tracker = self.flops_tracker
    if tracker is None:
        return step_summary

    # FLOPs bookkeeping is only attached when tracking is enabled.
    step_summary["total_flops"] = tracker.total_flops
    step_summary["num_ranks"] = self.worker_group.cluster.world_size()
    return step_summary

def abort_train_step(self, step_id: str) -> None:
    """Drop partial step state on every worker. No optimizer.step.

    Fans ``abort_train_step_presharded`` out to all workers, waits for
    them to finish, and then resets the FLOPs tracker when one is
    configured.

    Args:
        step_id: Identifier of the logical train step being discarded.
    """
    group = self.worker_group
    pending = group.run_all_workers_single_data(
        "abort_train_step_presharded",
        step_id=step_id,
    )
    # Block until every worker has answered; the results are discarded.
    group.get_all_worker_results(pending)

    tracker = self.flops_tracker
    if tracker is None:
        return
    # Discard the FLOPs counted for the abandoned step.
    tracker.reset()
Loading
Loading