feat(policy): split-API train-step state machine on DTensor v1/v2#2692

Open
mehraakash wants to merge 6 commits into
NVIDIA-NeMo:mainfrom
mehraakash:asyncrl/split_train_dtensor

@mehraakash

Summary

Adds the DTensor v1 + v2 backend implementations of the split-API train-step
methods (begin / train_microbatch / finish / abort), plus the
PolicyTrainerActor Ray wrapper that exposes the API to SingleController via
.remote(...). Sync train() is left untouched on both backends.

This is PR 2 of 3 in the SingleController split-API series:

  • PR 1 — `asyncrl/split_train_mcore` (Megatron + shared infra: `worker_mixin`, `tq_policy`). That PR should land on `main` first; until it does, this PR's diff also includes its commits.
  • PR 2 — this PR (DTensor v1 + v2 + `PolicyTrainerActor` + GPU parity tests).
  • PR 3 — coming next (SC `_train_pump` rewrite to drive the split API end-to-end).

What's new

  • DTensor v1: state machine using the `N=1` placeholder trick. Per microbatch, the loss is called with `global_valid_*=tensor(1.0)` so that backward deposits un-normalized gradients; `finish_train_step` then `all_reduce`s the local mask sums to recover the true `N` (tokens for TOKEN_LEVEL, sequences for SEQUENCE_LEVEL), rescales `p.grad` by `1/N`, and runs grad-norm/clip, `optimizer.step`, and `scheduler.step`. Bin iteration with DP-rank dummy-bin padding handles the uneven splits produced by sequence packing and dynamic batching without desyncing NCCL.
  • DTensor v2: the same state machine, built on the v2 helpers (`LossPostProcessor`, `get_microbatch_iterator`, `automodel_forward_backward`, `scale_grads_and_clip_grad_norm`). The clip helper is MoE/EP-aware; the manual `1/N` rescale runs before it.
  • `PolicyTrainerActor` at `nemo_rl/models/policy/policy_trainer_actor.py` — `@ray.remote(num_cpus=1, num_gpus=0)` wrapper that owns a `TQPolicy` instance and exposes `train_from_meta` (sync proxy), `prepare_logprobs_from_meta`, and the split API. `trainer_version` advances on sync `train_from_meta` or on `finish_train_step` (never on abort).
  • GPU parity tests at `tests/unit/models/policy/test_split_train_step.py`:
    • numerical equivalence vs sync `train()` for token-level and seq-level losses;
    • multi-bin-per-call under seq-packing;
    • split-state-machine lifecycle (double begin, abort idempotence).
    All three are parameterised over v1/v2.
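
The arithmetic behind the `N=1` placeholder trick can be checked with a small pure-Python sketch. It is illustrative only and does not use the PR's code: the scalar model `w * x`, the squared-error loss, the helper names, and the toy data are all assumptions, and the cross-rank reduction is modeled as a plain Python `sum`. It shows that accumulating un-normalized gradients per microbatch and dividing once by the globally reduced mask count gives the same gradient as normalizing the full batch up front.

```python
# Each token is (x, y, mask); the per-token loss is mask * (w * x - y) ** 2.
# All names and data here are hypothetical, chosen only to illustrate the math.

def unnormalized_grad(w, tokens):
    """d/dw of sum(mask * (w*x - y)^2): the loss evaluated with N = 1."""
    return sum(m * 2.0 * (w * x - y) * x for x, y, m in tokens)

def valid_tokens(tokens):
    """Local mask sum for a group of tokens."""
    return sum(m for _, _, m in tokens)

w = 0.5
# Two simulated DP ranks, each holding an uneven list of microbatches.
ranks = [
    [[(1.0, 2.0, 1), (2.0, 1.0, 0)], [(3.0, 0.5, 1)]],
    [[(0.5, 1.5, 1), (1.5, 2.5, 1), (2.5, 0.0, 0)], [(4.0, 1.0, 1)]],
]

# Split path: accumulate raw gradients and mask sums per rank.
local_grads, local_counts = [], []
for rank_microbatches in ranks:
    grad, count = 0.0, 0
    for microbatch in rank_microbatches:
        grad += unnormalized_grad(w, microbatch)   # like train_microbatch
        count += valid_tokens(microbatch)
    local_grads.append(grad)
    local_counts.append(count)

global_n = sum(local_counts)                # stands in for all_reduce(SUM)
split_grad = sum(local_grads) / global_n    # the finish-time 1/N rescale

# Reference path: normalize by N over the whole global batch at once.
all_tokens = [t for rank in ranks for mb in rank for t in mb]
reference_grad = unnormalized_grad(w, all_tokens) / valid_tokens(all_tokens)

print(split_grad, reference_grad)
```

Because the loss is linear in `1/N`, the rescale can be deferred until every microbatch has been seen, which is what lets the caller feed microbatches one at a time without knowing `N` in advance.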

Test plan

  • `pytest tests/unit/models/policy/test_split_train_step.py -v` on a multi-GPU CI runner.
  • Sync `train()` path unchanged — existing test_dtensor_worker* tests still pass.
  • Nothing in production calls the split API yet (PR 3 will wire SC up), so this PR cannot cause a recipe regression.

mehraakash and others added 3 commits June 3, 2026 22:56
…orker

Adds begin_train_step / train_microbatch / finish_train_step / abort_train_step
on MegatronPolicyWorkerImpl, mirroring the DTensor v1/v2 implementations but
adapted for mcore's contiguous grad bucket + pipeline-schedule reduce path.

Mechanism:
- begin_train_step: zero_grad_buffer + optimizer.zero_grad, store loss_fn /
  gbs / mbs / local_valid_seqs/toks accumulators on _train_step_state, and
  null model.config.grad_sync_func (saved for restore) so the PP scheduler's
  direct reduce dispatch cannot bypass no_sync.
- train_microbatch(data): wrap one ``megatron_forward_backward`` invocation
  in ``with self.model.no_sync():`` so mcore DDP hooks accumulate
  ``param.main_grad`` locally without dispatching the cross-DP reduce.
  Pass ``global_valid_seqs/toks=tensor(1.0)`` so the loss returns
  un-normalized sums; backward deposits raw d(sum)/dθ. Accumulate local
  mask sums + per-mb metrics + the total pipeline-microbatch count
  (for finish-time MoE aux-loss scaling).
- finish_train_step: all_reduce mask sums to get true N (toks for
  TOKEN_LEVEL loss, seqs for SEQUENCE_LEVEL), call
  self.model.scale_gradients(1/N), then the one true cross-DP reduce via
  start_grad_sync + finish_grad_sync, optimizer.step (clips internally),
  restore grad_sync_func, scheduler.step(increment=gbs). Rescale per-mb
  metrics by 1/N (linear-in-1/N math), aggregate, surface global counts.
- abort_train_step: restore grad_sync_func, zero_grad_buffer + zero_grad,
  drop state. ``trainer_version`` unchanged.
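
The lifecycle described above can be sketched as a small pure-Python state machine. This is a toy model under stated assumptions, not the worker implementation: `FakeConfig` and `SplitStepSketch` are hypothetical names, raising `RuntimeError` on a double begin is an assumed behavior, and `trainer_version` is kept on the same object only for brevity.

```python
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class FakeConfig:
    """Hypothetical stand-in for a model config holding a grad-sync hook."""
    grad_sync_func: Optional[Callable[[], None]] = None


class SplitStepSketch:
    """Toy begin / train_microbatch / finish / abort lifecycle."""

    def __init__(self, config: FakeConfig) -> None:
        self.config = config
        self.trainer_version = 0
        self._state: Optional[dict] = None

    def begin_train_step(self) -> None:
        if self._state is not None:
            # Assumed behavior: a second begin without finish/abort is an error.
            raise RuntimeError("a train step is already in flight")
        # Save the hook, then null it so nothing can trigger an early reduce.
        self._state = {"saved_hook": self.config.grad_sync_func,
                       "local_valid_toks": 0.0}
        self.config.grad_sync_func = None

    def train_microbatch(self, valid_toks: float) -> None:
        if self._state is None:
            raise RuntimeError("begin_train_step was not called")
        self._state["local_valid_toks"] += valid_toks

    def finish_train_step(self) -> float:
        if self._state is None:
            raise RuntimeError("begin_train_step was not called")
        state = self._release()
        self.trainer_version += 1          # advances only on finish
        return state["local_valid_toks"]

    def abort_train_step(self) -> None:
        if self._state is None:
            return                         # idempotent: nothing to undo
        self._release()                    # version is left unchanged

    def _release(self) -> dict:
        state, self._state = self._state, None
        self.config.grad_sync_func = state["saved_hook"]  # restore the hook
        return state


config = FakeConfig(grad_sync_func=lambda: None)
original_hook = config.grad_sync_func
worker = SplitStepSketch(config)

worker.begin_train_step()
hook_during_step = config.grad_sync_func
worker.train_microbatch(valid_toks=3.0)
worker.train_microbatch(valid_toks=2.0)
local_toks = worker.finish_train_step()
version_after_finish = worker.trainer_version

worker.begin_train_step()
worker.abort_train_step()
worker.abort_train_step()                  # second abort is a no-op
version_after_abort = worker.trainer_version
```

Routing both finish and abort through a single release helper keeps the hook restore on every exit path, which is the invariant the call-order tests are described as checking.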

Sync ``train()`` is left untouched.

Includes CPU unit tests at tests/unit/models/policy/test_megatron_split_state.py
covering the lifecycle and call-order invariants (no_sync wrap,
grad_sync_func save/restore, mask-sum accumulation, N selection by
loss_type, abort idempotence, MoE scaling). Marked pytest.mark.mcore so
they run only in mcore-enabled CI containers.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Signed-off-by: Akash Mehra <akamehra@nvidia.com>
Signed-off-by: Akash Mehra <akamehra@nvidia.com>
Adds the DTensor v1 + v2 backend implementations of the split-API methods
(begin_train_step / train_microbatch / finish_train_step / abort_train_step)
introduced in the mcore PR, plus the PolicyTrainerActor Ray wrapper that
exposes the API to SingleController via ``.remote(...)``.

- DTensor v1: state machine using the N=1 placeholder trick. Per microbatch
  the loss is called with ``global_valid_*=tensor(1.0)`` so backward
  deposits un-normalized gradients; ``finish_train_step`` all_reduces local
  mask sums to recover the true N (toks for TOKEN_LEVEL, seqs for
  SEQUENCE_LEVEL), rescales ``p.grad`` by 1/N, runs grad_norm/clip,
  optimizer.step, scheduler.step. Bin iteration with DP-rank dummy-bin
  padding handles seq-packing / dynamic-batching uneven splits without
  desyncing NCCL.

- DTensor v2: same shape, built on the v2 helpers (LossPostProcessor,
  get_microbatch_iterator, automodel_forward_backward,
  scale_grads_and_clip_grad_norm). The clip helper has MoE/EP awareness;
  the manual 1/N rescale runs before it.

- PolicyTrainerActor: ``@ray.remote(num_cpus=1, num_gpus=0)`` wrapper that
  owns a TQPolicy instance and exposes ``train_from_meta`` (sync proxy),
  ``prepare_logprobs_from_meta``, and the split API. ``trainer_version``
  advances on sync train_from_meta or on ``finish_train_step`` (never on
  abort).

- GPU parity tests at tests/unit/models/policy/test_split_train_step.py:
  numerical equivalence vs sync ``train()`` for token-level and seq-level
  losses, multi-bin-per-call under seq-packing, and split-state-machine
  lifecycle (double begin, abort idempotence). Parameterised over v1/v2.
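
The dummy-bin padding mentioned for DTensor v1 can be illustrated with a minimal pure-Python sketch. It rests on assumptions the commit message does not spell out: that ranks agree on the maximum bin count (modeled here with `max`, standing in for a MAX reduction), and that a dummy bin contributes zero valid tokens. The helper name `pad_bins` and the toy data are hypothetical.

```python
def pad_bins(local_bins, max_bins, dummy_bin):
    """Pad one rank's bin list so it runs exactly max_bins iterations."""
    return list(local_bins) + [dummy_bin] * (max_bins - len(local_bins))

# Each bin is a list of sequence lengths; ranks end up with uneven bin counts.
bins_per_rank = {
    0: [[5, 3], [4]],
    1: [[7]],
    2: [[2, 2], [6], [1]],
}

max_bins = max(len(bins) for bins in bins_per_rank.values())  # MAX reduction
DUMMY_BIN = []  # assumed: an empty (fully masked) bin with no valid tokens

padded = {rank: pad_bins(bins, max_bins, DUMMY_BIN)
          for rank, bins in bins_per_rank.items()}

# Every rank now performs the same number of forward/backward iterations,
# so per-iteration collectives line up across ranks.
iterations = {rank: len(bins) for rank, bins in padded.items()}

# Padding leaves each rank's valid-token count, and hence N, unchanged.
tokens_before = {r: sum(sum(b) for b in bins) for r, bins in bins_per_rank.items()}
tokens_after = {r: sum(sum(b) for b in bins) for r, bins in padded.items()}

print(iterations)
```

Since the dummy bins add nothing to the mask sums, the finish-time `1/N` rescale is unaffected by how many of them a rank had to run.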

Signed-off-by: Akash Mehra <akamehra@nvidia.com>
@mehraakash mehraakash requested review from a team as code owners June 4, 2026 22:33

@mehraakash
Author

/ok to test 21d91efcf

Signed-off-by: Akash Mehra <akamehra@nvidia.com>
@mehraakash
Author

/ok to test

Pre-existing zero-error file from NVIDIA-NeMo#2078 (Eagle3) that was never added
to the project-includes whitelist. Carrying the fix forward in this
PR to unblock the lint job.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Signed-off-by: Akash Mehra <akamehra@nvidia.com>
@mehraakash
Author

/ok to test d45f58d

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Signed-off-by: Akash Mehra <akamehra@nvidia.com>
@mehraakash
Author

/ok to test 3748982

mehraakash added a commit to mehraakash/RL that referenced this pull request Jun 5, 2026
The file is introduced by NVIDIA-NeMo#2692 (DTensor PR), not by this branch.
Whitelisting it here causes pyrefly to fail with 'No Python files
matched pattern' since the file does not exist on mcore.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Signed-off-by: Akash Mehra <akamehra@nvidia.com>
mehraakash added a commit to mehraakash/RL that referenced this pull request Jun 5, 2026
The file is introduced by NVIDIA-NeMo#2692 (DTensor PR), not by this branch.
Whitelisting it here causes pyrefly to fail with 'No Python files
matched pattern' since the file does not exist on this branch.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Signed-off-by: Akash Mehra <akamehra@nvidia.com>
Labels

CI:L0 Run doctests and unit tests
