Skip to content

Add vmap reset() regression test for nnx.metrics (#5483)#5491

Open
chenkuanliao wants to merge 1 commit into
google:mainfrom
chenkuanliao:fix-issue-5483
Open

Add vmap reset() regression test for nnx.metrics (#5483)#5491
chenkuanliao wants to merge 1 commit into
google:mainfrom
chenkuanliao:fix-issue-5483

Conversation

@chenkuanliao

@chenkuanliao chenkuanliao commented Jun 11, 2026

Copy link
Copy Markdown

What does this PR do?

This adds a regression test for the nnx.metrics behavior reported in #5483.

nnx.metrics.Average.reset() (and Welford.reset()) re-zero their state by assigning a scalar (jnp.array(0, ...)). When a metric is built under nnx.vmap, its state carries a leading batch axis of shape (N,). On the reporter's versions (flax 0.12.0 / jax 0.7.2) the scalar assignment replaced the whole array, collapsing (N,) to (), so a later vmapped update failed with:

ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())

While investigating I found the crash no longer reproduces on main: the Variable state layer has since changed so a full-index assignment broadcasts a scalar into the existing array in place instead of replacing it, so reset() preserves the batch axis today. I confirmed it still reproduces on flax 0.12.0 / jax 0.7.2 and is gone on main independent of any single commit.

Since there is no longer a live crash to fix on main, this PR is scoped as a regression test only, to lock in the current vmap reset() behavior so it can't silently break again.

What's in this PR

  • tests/nnx/metrics_test.pytest_vmap_reset_preserves_shape: construct a MultiMetric(loss=Average('loss')) under nnx.vmap (state shape (N,)), call reset(), assert the state stays shape (N,), then run a vmapped update and assert compute()['loss'] returns the expected per-batch values.

Checklist

  • This PR fixes a minor issue (e.g.: typo or small bug) or improves the docs (you can dismiss the other checks if that's the case).
  • This change is discussed in a Github issue/discussion (please add a link). — metrics does not work well with vmap #5483
  • The documentation and docstrings adhere to the documentation guidelines.
  • This change includes necessary high-coverage tests. (No quality testing = no merge!)

Construct a MultiMetric(loss=Average) under nnx.vmap so its state
carries a leading batch axis, call reset(), then run a vmapped update
and assert compute() returns the expected per-batch means. Locks in the
shape-preserving reset() behavior tracked in google#5483.
@chenkuanliao chenkuanliao changed the title Make metric reset() shape-preserving under vmap (#5483) Add vmap reset() regression test for nnx.metrics (#5483) Jun 24, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant