diff --git a/flax/nnx/nn/attention.py b/flax/nnx/nn/attention.py index 4f0c4f0cd..0ac5b4f50 100644 --- a/flax/nnx/nn/attention.py +++ b/flax/nnx/nn/attention.py @@ -107,9 +107,15 @@ def dot_product_attention_weights( query, key = promote_dtype((query, key), dtype=dtype) # type: ignore[bad-unpacking] dtype = query.dtype - assert query.ndim == key.ndim, 'q, k must have same rank.' - assert query.shape[:-3] == key.shape[:-3], 'q, k batch dims must match.' - assert query.shape[-1] == key.shape[-1], 'q, k depths must match.' + # Plain `assert` is dropped under `python -O`, which would let a rank or + # shape mismatch silently fall through to the einsum below and return a + # wrong-shaped result instead of raising. + if query.ndim != key.ndim: + raise ValueError('q, k must have same rank.') + if query.shape[:-3] != key.shape[:-3]: + raise ValueError('q, k batch dims must match.') + if query.shape[-1] != key.shape[-1]: + raise ValueError('q, k depths must match.') # check if we need to broadcast Key heads to match Query heads is_gqa = False @@ -255,11 +261,18 @@ def dot_product_attention( query, key, value = promote_dtype((query, key, value), dtype=dtype) # type: ignore[bad-unpacking] dtype = query.dtype - assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.' - assert ( - query.shape[:-3] == key.shape[:-3] == value.shape[:-3] - ), 'q, k, v batch dims must match.' - assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.' + # Plain `assert` is dropped under `python -O`, which would let a rank or + # shape mismatch silently fall through to the einsums below and return a + # wrong-shaped result instead of raising (the fast path a few lines down + # is backstopped by jax.nn.dot_product_attention's own checks, but this + # slow path -- used whenever dropout_rate > 0 or module is not None -- is + # not). + if not (key.ndim == query.ndim == value.ndim): + raise ValueError('q, k, v must have same rank.') + if not (query.shape[:-3] == key.shape[:-3] == value.shape[:-3]): + raise ValueError('q, k, v batch dims must match.') + if key.shape[-3] != value.shape[-3]: + raise ValueError('k, v lengths must match.') # Criteria that invoke the more optimized dot product attention if dropout_rate == 0.0 and module is None: diff --git a/tests/nnx/nn/attention_test.py b/tests/nnx/nn/attention_test.py index 167fbf8cf..ed7df331d 100644 --- a/tests/nnx/nn/attention_test.py +++ b/tests/nnx/nn/attention_test.py @@ -325,6 +325,23 @@ def test_gqa_invalid_heads(self): with self.assertRaisesRegex(ValueError, "must be a multiple"): nnx.dot_product_attention(query, key, value) + def test_rank_mismatch_raises_under_dropout_path(self): + # Regression test for https://github.com/google/flax/issues/5496 + # The dropout_rate > 0 path used `assert` for q/k/v rank and shape + # invariants. `assert` is compiled out under `python -O`, which let a + # rank mismatch fall through to the einsum and silently return a + # wrong-shaped output instead of raising. The fast path (dropout_rate==0, + # module=None) isn't affected since jax.nn.dot_product_attention does + # its own validation, so this exercises the slow path specifically. + query = jax.random.normal(jax.random.key(0), (2, 16, 4, 8)) # rank 4 + key = jax.random.normal(jax.random.key(1), (16, 4, 8)) # rank 3 + value = jax.random.normal(jax.random.key(2), (16, 4, 8)) # rank 3 + + with self.assertRaisesRegex(ValueError, "must have same rank"): + nnx.dot_product_attention( + query, key, value, dropout_rate=0.1, dropout_rng=jax.random.key(3) + ) + def test_gqa_multihead_attention(self): in_feat = 128 n_heads = 32