
GeoT optimization 3/4: Add fused batched Muon optimizer #1743

Open
coreyjadams wants to merge 6 commits into main from geoT-opt-muon-opt-fusion

Conversation

@coreyjadams

Collaborator

PhysicsNeMo Pull Request

Cursor made this implementation, and I want to clean it up into a tighter integration with torch before we merge. The key is that the overhead of looping over params is actually pretty significant for models like GeoT, so this is a first draft at that fusion.

We won't merge it in this state, but I wanted a branch as a placeholder for putting all the pieces together.

Description

Checklist

Dependencies

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

Add physicsnemo.optim.Muon, a fused/batched drop-in replacement for
torch.optim.Muon that groups 2-D parameters by (shape, dtype, device)
and runs batched Newton-Schulz via torch.bmm/baddbmm with
torch._foreach_* momentum/weight-decay updates. Matches torch.optim.Muon
hyperparameters, momentum_buffer state, and LR-adjustment modes.

Export it from physicsnemo.optim and switch the unified external aero
recipe's build_muon_optimizer to use it via CombinedOptimizer.
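
The grouping step described above can be sketched without torch. This is only an illustration of bucketing by a (shape, dtype, device) key, not the code in this PR: `FakeParam` and `group_params` are hypothetical names, and rejecting non-2-D parameters is an assumption of the sketch.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeParam:
    """Hypothetical stand-in for a tensor: only the fields used as the grouping key."""
    name: str
    shape: tuple
    dtype: str
    device: str


def group_params(params):
    """Bucket 2-D parameters by (shape, dtype, device), preserving input order."""
    groups = defaultdict(list)
    for p in params:
        if len(p.shape) != 2:
            # Assumption of this sketch: only matrices are handled.
            raise ValueError(f"{p.name}: expected a 2-D parameter, got shape {p.shape}")
        groups[(p.shape, p.dtype, p.device)].append(p)
    return dict(groups)


params = [
    FakeParam("blk0.q", (256, 256), "bfloat16", "cuda:0"),
    FakeParam("blk0.mlp", (256, 1024), "bfloat16", "cuda:0"),
    FakeParam("blk1.q", (256, 256), "bfloat16", "cuda:0"),
    FakeParam("blk1.mlp", (256, 1024), "bfloat16", "cuda:0"),
]
groups = group_params(params)
for key, members in groups.items():
    print(key, [p.name for p in members])
```

Each bucket holds same-shaped matrices, so in a real implementation its members could be stacked into one 3-D tensor and passed through a single batched matmul per Newton-Schulz step.
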
@copy-pr-bot

copy-pr-bot Bot commented Jun 22, 2026


Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@copy-pr-bot

copy-pr-bot Bot commented Jun 29, 2026


This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@coreyjadams coreyjadams marked this pull request as ready for review June 29, 2026 19:53
@greptile-apps

greptile-apps Bot commented Jun 29, 2026

Contributor

Greptile Summary

This PR introduces physicsnemo.optim.Muon, a subclass of torch.optim.Muon that batches the Newton-Schulz orthogonalization across same-shaped parameters using torch.bmm/torch.baddbmm, reducing kernel launches from O(num_params × ns_steps) to O(num_shape_groups × ns_steps). The author explicitly marks this as a WIP placeholder that won't be merged as-is.
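
To make the scaling claim concrete, here is a rough count of Newton-Schulz iterations issued per optimizer step under the two schemes. The shape list and `ns_steps=5` are made-up illustrative values, not measurements from GeoT, and `ns_iteration_counts` is a hypothetical helper.

```python
from collections import Counter


def ns_iteration_counts(shapes, ns_steps):
    """Return (per-parameter, per-shape-group) Newton-Schulz iteration counts."""
    per_param = len(shapes) * ns_steps            # one NS loop per parameter
    per_group = len(Counter(shapes)) * ns_steps   # one batched NS loop per distinct shape
    return per_param, per_group


# Hypothetical model: 96 weight matrices spread over 3 distinct shapes.
shapes = [(256, 256)] * 48 + [(256, 1024)] * 24 + [(1024, 256)] * 24
per_param, per_group = ns_iteration_counts(shapes, ns_steps=5)
print(per_param, per_group)
```

Each iteration launches a fixed number of kernels either way, so the launch count shrinks by the same ratio as the iteration count. The total FLOPs are unchanged; the saving is in launch overhead.
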

  • physicsnemo/optim/muon.py: New fused optimizer with batched NS iteration, DTensor/FSDP2 rejection guards, and a full docstring. It relies on torch.optim._muon._adjust_lr (a private PyTorch internal), and its Nesterov update uses lerp(grad, buf, momentum), whose equivalence to torch.optim.Muon still needs confirming on a PyTorch 2.10 build.
  • test/optim/test_muon.py: Good test coverage for grouping, state-dict roundtrip, and DTensor rejection; the key numerical-equivalence test is gated on torch.optim.Muon availability and may not have run yet.
  • examples/.../utils.py and physicsnemo/optim/__init__.py: Minimal wiring changes to surface the new optimizer.
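
The Nesterov concern comes down to simple arithmetic, shown here on scalars. This sketch does not assert which form upstream uses; it only shows that a lerp-based update and an additive `grad + momentum * buf` update differ by exactly `momentum * grad`. The local `lerp` mirrors the documented `torch.lerp` formula, and the numeric values are arbitrary.

```python
def lerp(start, end, weight):
    """Scalar version of torch.lerp: start + weight * (end - start)."""
    return start + weight * (end - start)


grad, buf, momentum = 0.5, 2.0, 0.95

# lerp form: (1 - momentum) * grad + momentum * buf
lerp_form = lerp(grad, buf, momentum)

# additive form: grad + momentum * buf
additive_form = grad + momentum * buf

# The gap between the two forms is momentum * grad.
gap = additive_form - lerp_form
print(lerp_form, additive_form, gap)
```

A numerical-equivalence test against the upstream optimizer would catch this difference if the wrong form were used.
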

Important Files Changed

| Filename | Overview |
| --- | --- |
| physicsnemo/optim/muon.py | New fused Muon optimizer subclass: two correctness concerns (private _adjust_lr import that will hard-fail if PyTorch refactors the internal module, and an unverified Nesterov formula that may diverge from the upstream by momentum*grad). |
| test/optim/test_muon.py | Good test coverage for grouping, state-dict roundtrip, and DTensor rejection; the critical numerical-equivalence test (test_matches_torch_muon) is guarded by skipif and may not have been run against a real torch.optim.Muon build yet. |
| physicsnemo/optim/__init__.py | Exports new Muon class; change is minimal and correct. |
| examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/utils.py | Switches call sites from torch.optim.Muon to physicsnemo.optim.Muon; straightforward drop-in replacement with no other changes. |
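
One way to soften the hard failure on the private import is a guarded lookup, sketched below. `load_private_attr` is a hypothetical helper, and what to do when the lookup returns `None` (for example, fall back to a local copy of the LR-adjustment logic) is left open here because the signature of `_adjust_lr` is not shown in this thread.

```python
import importlib


def load_private_attr(module_name, attr_name):
    """Return module_name.attr_name, or None if the module or attribute is missing."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # Covers a missing package as well as a renamed or removed internal module.
        return None
    return getattr(module, attr_name, None)


# Runs whether or not torch is installed; None means a fallback is needed.
_adjust_lr = load_private_attr("torch.optim._muon", "_adjust_lr")
if _adjust_lr is None:
    print("torch.optim._muon._adjust_lr not found; a local fallback would be needed")
```

The lookup turns an import-time crash into an explicit branch, so the optimizer module can still be imported and can report a clear error only when the missing helper is actually needed.
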

Reviews (1): Last reviewed commit: "Merge branch 'main' into geoT-opt-muon-o..."

Comment thread physicsnemo/optim/muon.py Outdated
Comment thread physicsnemo/optim/muon.py
Comment thread physicsnemo/optim/muon.py
Comment thread physicsnemo/optim/muon.py Outdated

@peterdsharpe peterdsharpe left a comment

Collaborator


Great job with this!

The bucketing logic ("grouping") is a particularly nice touch that I think will really benefit kernel-launch-bound training. TBH, this might be worth upstreaming to PyTorch's Muon impl too (unless they do it first).

Comment thread physicsnemo/optim/muon.py Outdated
Comment thread physicsnemo/optim/muon.py Outdated
Comment thread physicsnemo/optim/muon.py
Comment thread physicsnemo/optim/muon.py
Comment thread physicsnemo/optim/muon.py
@coreyjadams

Collaborator Author

/ok to test fdc0b5a


2 participants