From 32cf3c4785848945096fcf7ca5fc70a7db9afdf3 Mon Sep 17 00:00:00 2001 From: Sumukh Chaluvaraju Date: Thu, 25 Jun 2026 02:50:58 +0100 Subject: [PATCH] fix(nnx): raise ValueError instead of assert for q/k/v shape checks in attention dot_product_attention and dot_product_attention_weights validated q/k/v rank and shape invariants with plain `assert` statements. Python strips `assert` under `-O` (or PYTHONOPTIMIZE=1), which is common in production serving. On the slow path -- taken whenever dropout_rate > 0 or a module is passed for sowing attention weights -- a rank or shape mismatch then silently falls through to the einsum instead of raising, and returns a wrong-shaped output with no error at all. The fast path (dropout_rate == 0, module is None) delegates to jax.nn.dot_product_attention, which does its own validation regardless of -O, so it was never affected. Replaced the asserts with explicit `if ...: raise ValueError(...)` so the same clear error is raised with or without -O. Fixes #5496 --- flax/nnx/nn/attention.py | 29 +++++++++++++++++++++-------- tests/nnx/nn/attention_test.py | 17 +++++++++++++++++ 2 files changed, 38 insertions(+), 8 deletions(-) diff --git a/flax/nnx/nn/attention.py b/flax/nnx/nn/attention.py index 4f0c4f0cd..0ac5b4f50 100644 --- a/flax/nnx/nn/attention.py +++ b/flax/nnx/nn/attention.py @@ -107,9 +107,15 @@ def dot_product_attention_weights( query, key = promote_dtype((query, key), dtype=dtype) # type: ignore[bad-unpacking] dtype = query.dtype - assert query.ndim == key.ndim, 'q, k must have same rank.' - assert query.shape[:-3] == key.shape[:-3], 'q, k batch dims must match.' - assert query.shape[-1] == key.shape[-1], 'q, k depths must match.' + # Plain `assert` is dropped under `python -O`, which would let a rank or + # shape mismatch silently fall through to the einsum below and return a + # wrong-shaped result instead of raising. + if query.ndim != key.ndim: + raise ValueError('q, k must have same rank.') + if query.shape[:-3] != key.shape[:-3]: + raise ValueError('q, k batch dims must match.') + if query.shape[-1] != key.shape[-1]: + raise ValueError('q, k depths must match.') # check if we need to broadcast Key heads to match Query heads is_gqa = False @@ -255,11 +261,18 @@ def dot_product_attention( query, key, value = promote_dtype((query, key, value), dtype=dtype) # type: ignore[bad-unpacking] dtype = query.dtype - assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.' - assert ( - query.shape[:-3] == key.shape[:-3] == value.shape[:-3] - ), 'q, k, v batch dims must match.' - assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.' + # Plain `assert` is dropped under `python -O`, which would let a rank or + # shape mismatch silently fall through to the einsums below and return a + # wrong-shaped result instead of raising (the fast path a few lines down + # is backstopped by jax.nn.dot_product_attention's own checks, but this + # slow path -- used whenever dropout_rate > 0 or module is not None -- is + # not). + if not (key.ndim == query.ndim == value.ndim): + raise ValueError('q, k, v must have same rank.') + if not (query.shape[:-3] == key.shape[:-3] == value.shape[:-3]): + raise ValueError('q, k, v batch dims must match.') + if key.shape[-3] != value.shape[-3]: + raise ValueError('k, v lengths must match.') # Criteria that invoke the more optimized dot product attention if dropout_rate == 0.0 and module is None: diff --git a/tests/nnx/nn/attention_test.py b/tests/nnx/nn/attention_test.py index 167fbf8cf..ed7df331d 100644 --- a/tests/nnx/nn/attention_test.py +++ b/tests/nnx/nn/attention_test.py @@ -325,6 +325,23 @@ def test_gqa_invalid_heads(self): with self.assertRaisesRegex(ValueError, "must be a multiple"): nnx.dot_product_attention(query, key, value) + def test_rank_mismatch_raises_under_dropout_path(self): + # Regression test for https://github.com/google/flax/issues/5496 + # The dropout_rate > 0 path used `assert` for q/k/v rank and shape + # invariants. `assert` is compiled out under `python -O`, which let a + # rank mismatch fall through to the einsum and silently return a + # wrong-shaped output instead of raising. The fast path (dropout_rate==0, + # module=None) isn't affected since jax.nn.dot_product_attention does + # its own validation, so this exercises the slow path specifically. + query = jax.random.normal(jax.random.key(0), (2, 16, 4, 8)) # rank 4 + key = jax.random.normal(jax.random.key(1), (16, 4, 8)) # rank 3 + value = jax.random.normal(jax.random.key(2), (16, 4, 8)) # rank 3 + + with self.assertRaisesRegex(ValueError, "must have same rank"): + nnx.dot_product_attention( + query, key, value, dropout_rate=0.1, dropout_rng=jax.random.key(3) + ) + def test_gqa_multihead_attention(self): in_feat = 128 n_heads = 32