diff --git a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/utils.py b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/utils.py index 85bd4fc26f..4e32479802 100644 --- a/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/utils.py +++ b/examples/cfd/external_aerodynamics/unified_external_aero_recipe/src/utils.py @@ -33,7 +33,7 @@ from torch.amp import autocast from physicsnemo.mesh import DomainMesh, Mesh -from physicsnemo.optim import CombinedOptimizer +from physicsnemo.optim import CombinedOptimizer, Muon ### Recipe-wide type aliases. Re-exported for use in loss.py, metrics.py, ### output_normalize.py, forward_kwargs.py, collate.py, train.py, infer.py, @@ -112,7 +112,7 @@ def build_muon_optimizer( if muon_params and other_params: return CombinedOptimizer( [ - torch.optim.Muon( + Muon( muon_params, lr=lr, weight_decay=weight_decay, @@ -129,7 +129,7 @@ def build_muon_optimizer( torch_compile_kwargs=compile_kwargs, ) elif muon_params: - opt = torch.optim.Muon( + opt = Muon( muon_params, lr=lr, weight_decay=weight_decay, diff --git a/physicsnemo/optim/__init__.py b/physicsnemo/optim/__init__.py index f71db6db97..be7733c7c2 100644 --- a/physicsnemo/optim/__init__.py +++ b/physicsnemo/optim/__init__.py @@ -17,5 +17,6 @@ """Optimizer utilities for PhysicsNeMo.""" from physicsnemo.optim.combined_optimizer import CombinedOptimizer +from physicsnemo.optim.muon import Muon -__all__ = ["CombinedOptimizer"] +__all__ = ["CombinedOptimizer", "Muon"] diff --git a/physicsnemo/optim/muon.py b/physicsnemo/optim/muon.py new file mode 100644 index 0000000000..859a574f39 --- /dev/null +++ b/physicsnemo/optim/muon.py @@ -0,0 +1,293 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Fused Muon optimizer. + +PyTorch's :class:`torch.optim.Muon` runs the Newton-Schulz orthogonalization +one parameter at a time -- its functional ``muon()`` explicitly raises for the +``foreach`` path -- so a model with many 2-D weight matrices issues hundreds of +tiny, serial, launch-bound matmuls per step. The Newton-Schulz iteration of one +parameter is independent of every other, so parameters that share a shape can be +stacked and orthogonalized together with batched matmuls (``bmm`` / ``baddbmm``). + +``torch.bmm`` computes each batch element independently, so the batched result is +the same as the per-parameter loop; only the number of kernel launches changes, +from ``O(num_params * ns_steps)`` to ``O(num_shape_groups * ns_steps)``. + +This module provides :class:`Muon`, a subclass of :class:`torch.optim.Muon` that +overrides only the per-step orthogonalization with a batched implementation. The +constructor signature, hyperparameter semantics, validation, and ``momentum_buffer`` +optimizer-state key are all inherited from :class:`torch.optim.Muon`, so checkpoints +remain interchangeable. +""" + +from __future__ import annotations + +from collections import defaultdict +from typing import Callable + +import torch +from torch import Tensor +from torch.optim import Muon as _TorchMuon + +from physicsnemo.core.version_check import OptionalImport + +# Prevent import errors against internal API changes: +_torch_muon_internal = OptionalImport("torch.optim._muon") + +__all__ = ["Muon"] + + +def _batched_newton_schulz( + updates: Tensor, + ns_coefficients: tuple[float, float, float], + ns_steps: int, + eps: float, +) -> Tensor: + """Batched Newton-Schulz orthogonalization of a stack of 2-D matrices. + + Performs, for every matrix in the batch independently, the same quintic + Newton-Schulz iteration as :func:`torch.optim._muon._zeropower_via_newtonschulz`, + but with batched matmuls so a whole group of equally-shaped parameters is + orthogonalized in a handful of kernel launches. + + Parameters + ---------- + updates : torch.Tensor + Stack of update matrices of shape ``(G, M, N)``. + ns_coefficients : tuple[float, float, float] + Quintic polynomial coefficients ``(a, b, c)``. + ns_steps : int + Number of Newton-Schulz iterations. + eps : float + Numerical-stability floor for the spectral-norm normalization. + + Returns + ------- + torch.Tensor + Orthogonalized stack of shape ``(G, M, N)`` in ``bfloat16`` (cast back + to the parameter dtype by the caller). + """ + if ns_steps >= 100: + # This is a decision that exactly mirrors upstream pytorch. + raise ValueError( + "Number of steps must be less than 100 for computational efficiency" + ) + if updates.ndim != 3: + raise ValueError("Batched Newton-Schulz expects a 3D (G, M, N) tensor") + if len(ns_coefficients) != 3: + raise ValueError("Coefficients must be a tuple of exactly 3 values") + + a, b, c = ns_coefficients + ortho = updates.bfloat16() + + # Orient so rows <= cols (the Gram matrix is then the smaller M x M), the + # same orientation rule torch uses per matrix. All matrices in the batch + # share a shape, so the decision is uniform. + transpose = ortho.size(1) > ortho.size(2) + if transpose: + ortho = ortho.transpose(1, 2) + + # Ensure each matrix's spectral norm is at most 1 (Frobenius is + # transpose-invariant, so doing it after the orient above is fine). + norm = ortho.norm(dim=(1, 2), keepdim=True).clamp(min=eps) + ortho = ortho / norm + + for _ in range(ns_steps): + gram = torch.bmm(ortho, ortho.transpose(1, 2)) + # b * gram + c * (gram @ gram) + gram_update = torch.baddbmm(gram, gram, gram, beta=b, alpha=c) + # a * ortho + (gram_update @ ortho) + ortho = torch.baddbmm(ortho, gram_update, ortho, beta=a, alpha=1.0) + + if transpose: + ortho = ortho.transpose(1, 2) + return ortho + + +class Muon(_TorchMuon): + r"""Fused Muon optimizer for 2-D parameters. + + Subclass of :class:`torch.optim.Muon` that batches the Newton-Schulz + orthogonalization across parameters of the same shape using ``torch.bmm`` / + ``torch.baddbmm``, and applies the momentum and weight-decay updates with the + ``torch._foreach_*`` fused kernels. Construction, validation, hyperparameter + defaults, and the ``momentum_buffer`` state key are inherited unchanged from + :class:`torch.optim.Muon`, so it is numerically equivalent (batched matmuls + compute each matrix independently) and checkpoint-compatible. + + Muon only optimizes 2-D parameters (linear / attention weight matrices). Use + a standard optimizer such as AdamW for biases, norms, and embeddings -- for + example via :class:`physicsnemo.optim.CombinedOptimizer`. + + Parameters + ---------- + params : iterable + Iterable of 2-D parameters or parameter-group dicts. + lr : float, optional + Learning rate. Default 1e-3. + weight_decay : float, optional + Decoupled weight decay. Default 0.1. + momentum : float, optional + Momentum factor. Default 0.95. + nesterov : bool, optional + Enable Nesterov momentum. Default True. + ns_coefficients : tuple[float, float, float], optional + Newton-Schulz quintic coefficients ``(a, b, c)``. + eps : float, optional + Numerical-stability term for the spectral-norm normalization. + ns_steps : int, optional + Number of Newton-Schulz iterations. Default 5. + adjust_lr_fn : str, optional + One of ``"original"`` or ``"match_rms_adamw"``. Default None + (treated as ``"original"``). + + Forward + ------- + Call :meth:`step` after ``loss.backward()`` to apply one optimization step. + + Outputs + ------- + The optional closure loss returned by :meth:`step`, or ``None``. + + .. important:: + The fused path stacks equally-shaped matrices and orthogonalizes them + with batched matmuls, which is only correct for **replicated 2-D + parameters**. Single-GPU and **DDP** are fully supported and numerically + equal to :class:`torch.optim.Muon` (DDP gradients are dense, replicated + tensors, and :meth:`step` runs after ``backward()`` returns, by which + point DDP's bucketed all-reduce has already been synchronized). + + Notes + ----- + See :class:`torch.optim.Muon` for the full algorithm description. + + Examples + -------- + >>> import torch + >>> from physicsnemo.optim import Muon + >>> weights = [torch.nn.Parameter(torch.randn(8, 8)) for _ in range(3)] + >>> opt = Muon(weights, lr=0.02) + >>> for w in weights: + ... w.grad = torch.randn_like(w) + >>> _ = opt.step() + """ + + @staticmethod + def _group_params_by_shape(params: list[Tensor]) -> dict[tuple, list[int]]: + """Bucket parameter indices by ``(shape, dtype, device)``. + + Parameters that share all three can be stacked and orthogonalized in a + single batched Newton-Schulz call. Insertion order is preserved so the + batched result maps back to the original parameter order. + + Parameters + ---------- + params : list[torch.Tensor] + Parameters (or per-parameter update tensors) to group. + + Returns + ------- + dict[tuple, list[int]] + Mapping from ``(tuple(shape), dtype, device)`` to the list of + indices into *params* that belong to that group. + """ + groups: dict[tuple, list[int]] = defaultdict(list) + for i, p in enumerate(params): + groups[(tuple(p.shape), p.dtype, p.device)].append(i) + return groups + + @torch.no_grad() + def step(self, closure: Callable[[], float] | None = None) -> float | None: + """Perform a single optimization step. + + Parameters + ---------- + closure : Callable[[], float], optional + Optional callable that reevaluates the model and returns the loss. + Default None. + + Returns + ------- + float or None + The loss returned by *closure*, or ``None`` if no closure was given. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + lr = group["lr"] + if isinstance(lr, Tensor): + lr = lr.item() + weight_decay = group["weight_decay"] + momentum = group["momentum"] + nesterov = group["nesterov"] + ns_coefficients = group["ns_coefficients"] + eps = group["eps"] + ns_steps = group["ns_steps"] + adjust_lr_fn = group["adjust_lr_fn"] + + params_with_grad: list[Tensor] = [] + grads: list[Tensor] = [] + momentum_bufs: list[Tensor] = [] + + # Reuse the upstream collector: it appends params/grads, rejects + # complex/sparse, and lazily initializes the momentum_buffer state. + self._init_group(group, params_with_grad, grads, momentum_bufs) + + if not params_with_grad: + continue + + for g in grads: + if g.ndim != 2: + raise ValueError("Param gradient must be a 2D matrix") + + # Momentum (fused across all shapes): buf = momentum*buf + (1-momentum)*grad + torch._foreach_lerp_(momentum_bufs, grads, 1 - momentum) + if nesterov: + # update = grad + momentum*(buf - grad) + updates = torch._foreach_lerp(grads, momentum_bufs, momentum) + else: + updates = list(momentum_bufs) + + # Decoupled weight decay (fused across all shapes). + torch._foreach_mul_(params_with_grad, 1 - lr * weight_decay) + + # Group equally-shaped updates and orthogonalize each group with one + # batched Newton-Schulz, then apply the (per-group, shape-dependent) + # learning rate. + groups = self._group_params_by_shape(params_with_grad) + + for (shape, _dtype, _device), idxs in groups.items(): + stacked = torch.stack([updates[i] for i in idxs], dim=0) + ortho = _batched_newton_schulz(stacked, ns_coefficients, ns_steps, eps) + adjusted_lr = _torch_muon_internal._adjust_lr( + lr, adjust_lr_fn, torch.Size(shape) + ) + + group_params = [params_with_grad[i] for i in idxs] + # Cast back to the parameter dtype (NS runs in bf16). + ortho_list = [ + ortho[j].to(group_params[j].dtype) for j in range(len(idxs)) + ] + torch._foreach_add_(group_params, ortho_list, alpha=-adjusted_lr) + + return loss + + def __repr__(self) -> str: + return f"{self.__class__.__name__} (fused Newton-Schulz)" diff --git a/test/optim/test_muon.py b/test/optim/test_muon.py new file mode 100644 index 0000000000..d84aa0633c --- /dev/null +++ b/test/optim/test_muon.py @@ -0,0 +1,148 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2026 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import copy + +import pytest +import torch + +from physicsnemo.optim import Muon + +# torch.optim.Muon is the per-parameter reference implementation we batch. +_TORCH_MUON = getattr(torch.optim, "Muon", None) +_HAS_TORCH_MUON = _TORCH_MUON is not None + + +def _make_params(shapes, device, seed=0): + """Create a list of 2-D parameters with reproducible random values.""" + gen = torch.Generator(device="cpu").manual_seed(seed) + return [ + torch.nn.Parameter(torch.randn(*s, generator=gen).to(device)) for s in shapes + ] + + +def _make_grads(shapes, device, seed): + """Create a list of gradients aligned with ``shapes``.""" + gen = torch.Generator(device="cpu").manual_seed(seed) + return [torch.randn(*s, generator=gen).to(device) for s in shapes] + + +@pytest.mark.skipif(not _HAS_TORCH_MUON, reason="torch.optim.Muon unavailable") +@pytest.mark.parametrize("nesterov", [True, False]) +@pytest.mark.parametrize("adjust_lr_fn", ["original", "match_rms_adamw"]) +def test_matches_torch_muon(device, nesterov, adjust_lr_fn): + """Fused Muon matches torch.optim.Muon step-for-step within tolerance. + + Shapes include square, wide, tall, and a repeated shape so the batched + Newton-Schulz path (group size > 1) is exercised. + """ + shapes = [(8, 8), (8, 8), (8, 16), (16, 8)] + + ref_params = _make_params(shapes, device, seed=1) + fused_params = [torch.nn.Parameter(p.detach().clone()) for p in ref_params] + + kwargs = dict( + lr=0.02, + weight_decay=0.01, + momentum=0.95, + nesterov=nesterov, + adjust_lr_fn=adjust_lr_fn, + ) + ref_opt = _TORCH_MUON(ref_params, **kwargs) + fused_opt = Muon(fused_params, **kwargs) + + for step in range(5): + grads = _make_grads(shapes, device, seed=100 + step) + for p, g in zip(ref_params, grads): + p.grad = g.clone() + for p, g in zip(fused_params, grads): + p.grad = g.clone() + ref_opt.step() + fused_opt.step() + + for ref_p, fused_p in zip(ref_params, fused_params): + torch.testing.assert_close(fused_p, ref_p, atol=1e-3, rtol=1e-3) + + +def test_lazy_adjust_lr_proxy_resolves(): + """The private torch._adjust_lr is reachable via the lazy OptionalImport. + + physicsnemo.optim.muon imports torch.optim._muon lazily so a future + rename/removal fails at step() runtime rather than at module import time. + This asserts the proxy resolves the symbol on the installed torch. + """ + from physicsnemo.optim.muon import _torch_muon_internal + + assert callable(_torch_muon_internal._adjust_lr) + + +def test_group_params_by_shape(device): + """Equally-shaped params bucket together; distinct shapes stay separate.""" + shapes = [(8, 8), (8, 8), (8, 16), (16, 8), (8, 8)] + params = _make_params(shapes, device, seed=2) + + groups = Muon._group_params_by_shape(params) + + sizes = sorted(len(idxs) for idxs in groups.values()) + # (8,8) x3, (8,16) x1, (16,8) x1 + assert sizes == [1, 1, 3] + # The repeated (8,8) shape collapses to one group with the right indices. + eight = [idxs for key, idxs in groups.items() if key[0] == (8, 8)] + assert eight == [[0, 1, 4]] + + +def test_state_dict_roundtrip(device): + """Saving and restoring state (incl. momentum buffers) resumes identically.""" + shapes = [(8, 8), (8, 16)] + params = _make_params(shapes, device, seed=3) + opt = Muon(params, lr=0.02, weight_decay=0.01, adjust_lr_fn="match_rms_adamw") + + # Two steps to populate momentum buffers. + for step in range(2): + grads = _make_grads(shapes, device, seed=200 + step) + for p, g in zip(params, grads): + p.grad = g.clone() + opt.step() + + saved_state = copy.deepcopy(opt.state_dict()) + snapshot = [p.detach().clone() for p in params] + + # Continue one more step on the original optimizer (the reference). + final_grads = _make_grads(shapes, device, seed=999) + for p, g in zip(params, final_grads): + p.grad = g.clone() + opt.step() + reference_final = [p.detach().clone() for p in params] + + # Fresh optimizer restored from the snapshot + saved state must match. + restored_params = [torch.nn.Parameter(s.clone()) for s in snapshot] + restored_opt = Muon( + restored_params, lr=0.02, weight_decay=0.01, adjust_lr_fn="match_rms_adamw" + ) + restored_opt.load_state_dict(saved_state) + for p, g in zip(restored_params, final_grads): + p.grad = g.clone() + restored_opt.step() + + for ref_p, restored_p in zip(reference_final, restored_params): + torch.testing.assert_close(restored_p, ref_p, atol=1e-6, rtol=1e-6) + + +def test_rejects_non_2d_params(device): + """Muon only supports 2-D parameters.""" + param_1d = torch.nn.Parameter(torch.randn(8).to(device)) + with pytest.raises(ValueError, match="2D"): + Muon([param_1d], lr=0.02)