Skip to content

feat(megatron): split-API train-step state machine on MegatronPolicyWorker#2683

Open
mehraakash wants to merge 5 commits into
NVIDIA-NeMo:mainfrom
mehraakash:asyncrl/split_train_mcore
Open

feat(megatron): split-API train-step state machine on MegatronPolicyWorker#2683
mehraakash wants to merge 5 commits into
NVIDIA-NeMo:mainfrom
mehraakash:asyncrl/split_train_mcore

Conversation

@mehraakash
Copy link
Copy Markdown

Adds begin_train_step / train_microbatch / finish_train_step / abort_train_step on MegatronPolicyWorkerImpl, mirroring the DTensor v1/v2 implementations but adapted for mcore's contiguous grad bucket + pipeline-schedule reduce path.

Mechanism:

  • begin_train_step: zero_grad_buffer + optimizer.zero_grad, store loss_fn / gbs / mbs / local_valid_seqs/toks accumulators on _train_step_state, and null model.config.grad_sync_func (saved for restore) so the PP scheduler's direct reduce dispatch cannot bypass no_sync.
  • train_microbatch(data): wrap one megatron_forward_backward invocation in with self.model.no_sync(): so mcore DDP hooks accumulate param.main_grad locally without dispatching the cross-DP reduce. Pass global_valid_seqs/toks=tensor(1.0) so the loss returns un-normalized sums; backward deposits raw d(sum)/dθ. Accumulate local mask sums + per-mb metrics + the total pipeline-microbatch count (for finish-time MoE aux-loss scaling).
  • finish_train_step: all_reduce mask sums to get true N (toks for TOKEN_LEVEL loss, seqs for SEQUENCE_LEVEL), call self.model.scale_gradients(1/N), then the one true cross-DP reduce via start_grad_sync + finish_grad_sync, optimizer.step (clips internally), restore grad_sync_func, scheduler.step(increment=gbs). Rescale per-mb metrics by 1/N (linear-in-1/N math), aggregate, surface global counts.
  • abort_train_step: restore grad_sync_func, zero_grad_buffer + zero_grad, drop state. trainer_version unchanged.

Sync train() is left untouched.

Includes CPU unit tests at tests/unit/models/policy/test_megatron_split_state.py covering the lifecycle and call-order invariants (no_sync wrap, grad_sync_func save/restore, mask-sum accumulation, N selection by loss_type, abort idempotence, MoE scaling). Marked pytest.mark.mcore so they run only in mcore-enabled CI containers.

What does this PR do ?

Add a one line overview of what this PR aims to accomplish.

Issues

List issues that this PR closes (syntax):

Usage

  • You can potentially add a usage example below
# Add a code snippet demonstrating how to use this

Before your PR is "Ready for review"

Pre checks:

  • Make sure you read and followed Contributor guidelines
  • Did you write any new necessary tests?
  • Did you run the unit tests and functional tests locally? Visit our Testing Guide for how to run tests
  • Did you add or update any necessary documentation? Visit our Document Development Guide for how to write, build and test the docs.

Additional Information

  • ...

mehraakash and others added 2 commits June 3, 2026 22:56
…orker

Adds begin_train_step / train_microbatch / finish_train_step / abort_train_step
on MegatronPolicyWorkerImpl, mirroring the DTensor v1/v2 implementations but
adapted for mcore's contiguous grad bucket + pipeline-schedule reduce path.

Mechanism:
- begin_train_step: zero_grad_buffer + optimizer.zero_grad, store loss_fn /
  gbs / mbs / local_valid_seqs/toks accumulators on _train_step_state, and
  null model.config.grad_sync_func (saved for restore) so the PP scheduler's
  direct reduce dispatch cannot bypass no_sync.
- train_microbatch(data): wrap one ``megatron_forward_backward`` invocation
  in ``with self.model.no_sync():`` so mcore DDP hooks accumulate
  ``param.main_grad`` locally without dispatching the cross-DP reduce.
  Pass ``global_valid_seqs/toks=tensor(1.0)`` so the loss returns
  un-normalized sums; backward deposits raw d(sum)/dθ. Accumulate local
  mask sums + per-mb metrics + the total pipeline-microbatch count
  (for finish-time MoE aux-loss scaling).
- finish_train_step: all_reduce mask sums to get true N (toks for
  TOKEN_LEVEL loss, seqs for SEQUENCE_LEVEL), call
  self.model.scale_gradients(1/N), then the one true cross-DP reduce via
  start_grad_sync + finish_grad_sync, optimizer.step (clips internally),
  restore grad_sync_func, scheduler.step(increment=gbs). Rescale per-mb
  metrics by 1/N (linear-in-1/N math), aggregate, surface global counts.
- abort_train_step: restore grad_sync_func, zero_grad_buffer + zero_grad,
  drop state. ``trainer_version`` unchanged.

Sync ``train()`` is left untouched.

Includes CPU unit tests at tests/unit/models/policy/test_megatron_split_state.py
covering the lifecycle and call-order invariants (no_sync wrap,
grad_sync_func save/restore, mask-sum accumulation, N selection by
loss_type, abort idempotence, MoE scaling). Marked pytest.mark.mcore so
they run only in mcore-enabled CI containers.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Signed-off-by: Akash Mehra <akamehra@nvidia.com>
Signed-off-by: Akash Mehra <akamehra@nvidia.com>
@mehraakash mehraakash requested review from a team as code owners June 4, 2026 06:09
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented Jun 4, 2026

Auto-sync is disabled for ready for review pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@mehraakash
Copy link
Copy Markdown
Author

/ok to test

Pre-existing zero-error file from NVIDIA-NeMo#2078 (Eagle3) that was never added
to the project-includes whitelist. Carrying the fix forward in this
PR to unblock the lint job.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Signed-off-by: Akash Mehra <akamehra@nvidia.com>
@mehraakash
Copy link
Copy Markdown
Author

/ok to test 3d24224

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Signed-off-by: Akash Mehra <akamehra@nvidia.com>
@mehraakash
Copy link
Copy Markdown
Author

/ok to test cb65400

The file is introduced by NVIDIA-NeMo#2692 (DTensor PR), not by this branch.
Whitelisting it here causes pyrefly to fail with 'No Python files
matched pattern' since the file does not exist on mcore.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Signed-off-by: Akash Mehra <akamehra@nvidia.com>
@mehraakash
Copy link
Copy Markdown
Author

/ok to test 108ec17

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CI:L0 Run doctests and unit tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant