fix(nnx): raise ValueError instead of assert for q/k/v shape checks in attention #5514

Open
Sumu004 wants to merge 1 commit into google:main from Sumu004:fix/attention-assert-under-O

Sumu004 commented Jun 25, 2026

What

flax.nnx.nn.attention.dot_product_attention (and dot_product_attention_weights) validate q/k/v rank and shape invariants with plain assert statements. Python strips assert under -O / PYTHONOPTIMIZE=1, which is common in production serving. On the slow path -- taken whenever dropout_rate > 0 or a module is passed for sowing attention weights -- a rank or shape mismatch then silently falls through to the einsum instead of raising, and returns a wrong-shaped output with no error at all.

```python
import jax
import flax.nnx as nnx

attn = nnx.nn.attention.dot_product_attention
q = jax.random.normal(jax.random.PRNGKey(0), (2, 16, 4, 8))  # rank 4
k = jax.random.normal(jax.random.PRNGKey(0), (16, 4, 8))     # rank 3 (mismatch)
v = jax.random.normal(jax.random.PRNGKey(0), (16, 4, 8))     # rank 3 (mismatch)
# dropout_rate > 0 forces the slow path, where the asserts live
out = attn(q, k, v, dropout_rate=0.1, dropout_rng=jax.random.PRNGKey(1))
print("result shape:", out.shape)
# python    repro.py -> AssertionError: q, k, v must have same rank.
# python -O repro.py -> result shape: (2, 16, 4, 8)   (silently wrong, no error)
```

Root cause

The fast path (dropout_rate == 0, module is None) delegates to jax.nn.dot_product_attention, which does its own validation regardless of -O, so it was never affected. The slow path's own checks (lines ~110-112, ~258-262 before this change) are bare asserts, which Python compiles out entirely under -O -- there's no runtime check left at all, not even a degraded one.
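To make the mechanism concrete without involving JAX, here is a minimal standalone sketch. It compiles the same rank check two ways and loads each at optimization level 0 and level 1 (the level that `-O` selects), using the built-in `compile(..., optimize=...)`. The `check`, `load_check`, and `outcome` names are made up for this illustration and are not part of Flax.

```python
# Standalone illustration (no JAX/Flax): the same check written two ways,
# compiled at optimization level 0 (asserts kept) and 1 (what -O selects).
ASSERT_SRC = '''
def check(q_rank, k_rank):
    assert q_rank == k_rank, "q, k, v must have same rank."
    return "no error"
'''

RAISE_SRC = '''
def check(q_rank, k_rank):
    if q_rank != k_rank:
        raise ValueError("q, k, v must have same rank.")
    return "no error"
'''


def load_check(source, optimize):
    """Compile `source` at the given optimization level and return its `check`."""
    namespace = {}
    exec(compile(source, "<demo>", "exec", optimize=optimize), namespace)
    return namespace["check"]


def outcome(check):
    """Call `check` with mismatched ranks and report what happened."""
    try:
        return check(4, 3)
    except (AssertionError, ValueError) as err:
        return type(err).__name__


if __name__ == "__main__":
    for label, source in (("assert", ASSERT_SRC), ("if/raise", RAISE_SRC)):
        for level in (0, 1):
            result = outcome(load_check(source, level))
            print(f"{label:8s} optimize={level}: {result}")
```

At level 1 the `assert` version returns normally on mismatched ranks, while the `if`/`raise` version still raises at both levels.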

Fix

Replaced the asserts in both dot_product_attention_weights and dot_product_attention with explicit if ...: raise ValueError(...), so the same clear error is raised with or without -O. In the normal (non -O) case the checks fire under the same conditions and with the same messages; the only visible change is the exception type, ValueError instead of AssertionError. This matches the GQA head-count check a few lines below, which already raises ValueError explicitly.
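For reviewers who want the shape of the change without opening the diff, the pattern looks roughly like the sketch below. This is not the code in the PR: `check_qkv_shapes` is a hypothetical helper that works on plain shape tuples, it assumes the usual `[batch..., length, num_heads, depth]` layout, and only the rank message is taken from the repro above. The other invariants and messages are illustrative and may differ from the Flax source.

```python
def check_qkv_shapes(q_shape, k_shape, v_shape):
    """Hypothetical helper: validate q/k/v shapes with explicit raises.

    Assumes shapes are laid out as [batch..., length, num_heads, depth].
    Unlike `assert`, these checks still run under `python -O`.
    """
    if not (len(q_shape) == len(k_shape) == len(v_shape)):
        raise ValueError("q, k, v must have same rank.")
    # Illustrative invariants below; the real checks may be worded differently.
    if not (q_shape[:-3] == k_shape[:-3] == v_shape[:-3]):
        raise ValueError("q, k, v batch dims must match.")
    if q_shape[-1] != k_shape[-1]:
        raise ValueError("q, k depths must match.")
    if k_shape[-3] != v_shape[-3]:
        raise ValueError("k, v lengths must match.")


if __name__ == "__main__":
    # Matching shapes pass silently.
    check_qkv_shapes((2, 16, 4, 8), (2, 16, 4, 8), (2, 16, 4, 8))
    # The shapes from the repro raise, with or without -O.
    try:
        check_qkv_shapes((2, 16, 4, 8), (16, 4, 8), (16, 4, 8))
    except ValueError as err:
        print("ValueError:", err)
```

Keeping each invariant as its own `if` means the error names the specific mismatch, just as the individual asserts did.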

Checklist

  • New test_rank_mismatch_raises_under_dropout_path regression test, asserting a ValueError on the dropout-path rank mismatch from the issue repro
  • Verified independently (plain Python, no JAX) that assert is compiled out under -O while if: raise is not
  • Existing test_gqa_invalid_heads (and other attention tests) unaffected -- the change only swaps the exception type and makes the checks run regardless of -O

Fixes #5496

…n attention

dot_product_attention and dot_product_attention_weights validated q/k/v
rank and shape invariants with plain `assert` statements. Python strips
`assert` under `-O` (or PYTHONOPTIMIZE=1), which is common in production
serving. On the slow path -- taken whenever dropout_rate > 0 or a module
is passed for sowing attention weights -- a rank or shape mismatch then
silently falls through to the einsum instead of raising, and returns a
wrong-shaped output with no error at all.

The fast path (dropout_rate == 0, module is None) delegates to
jax.nn.dot_product_attention, which does its own validation regardless of
-O, so it was never affected.

Replaced the asserts with explicit `if ...: raise ValueError(...)` so the
same clear error is raised with or without -O.

Fixes google#5496

Development

Successfully merging this pull request may close these issues.

flax.nnx.dot_product_attention silently returns wrong-shaped output under python -O (rank asserts dropped)
