Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 21 additions & 8 deletions flax/nnx/nn/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,15 @@ def dot_product_attention_weights(
query, key = promote_dtype((query, key), dtype=dtype) # type: ignore[bad-unpacking]
dtype = query.dtype

assert query.ndim == key.ndim, 'q, k must have same rank.'
assert query.shape[:-3] == key.shape[:-3], 'q, k batch dims must match.'
assert query.shape[-1] == key.shape[-1], 'q, k depths must match.'
# Plain `assert` is dropped under `python -O`, which would let a rank or
# shape mismatch silently fall through to the einsum below and return a
# wrong-shaped result instead of raising.
if query.ndim != key.ndim:
raise ValueError('q, k must have same rank.')
if query.shape[:-3] != key.shape[:-3]:
raise ValueError('q, k batch dims must match.')
if query.shape[-1] != key.shape[-1]:
raise ValueError('q, k depths must match.')

# check if we need to broadcast Key heads to match Query heads
is_gqa = False
Expand Down Expand Up @@ -255,11 +261,18 @@ def dot_product_attention(
query, key, value = promote_dtype((query, key, value), dtype=dtype) # type: ignore[bad-unpacking]
dtype = query.dtype

assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.'
assert (
query.shape[:-3] == key.shape[:-3] == value.shape[:-3]
), 'q, k, v batch dims must match.'
assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.'
# Plain `assert` is dropped under `python -O`, which would let a rank or
# shape mismatch silently fall through to the einsums below and return a
# wrong-shaped result instead of raising (the fast path a few lines down
# is backstopped by jax.nn.dot_product_attention's own checks, but this
# slow path -- used whenever dropout_rate > 0 or module is not None -- is
# not).
if not (key.ndim == query.ndim == value.ndim):
raise ValueError('q, k, v must have same rank.')
if not (query.shape[:-3] == key.shape[:-3] == value.shape[:-3]):
raise ValueError('q, k, v batch dims must match.')
if key.shape[-3] != value.shape[-3]:
raise ValueError('k, v lengths must match.')

# Criteria that invoke the more optimized dot product attention
if dropout_rate == 0.0 and module is None:
Expand Down
17 changes: 17 additions & 0 deletions tests/nnx/nn/attention_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,23 @@ def test_gqa_invalid_heads(self):
with self.assertRaisesRegex(ValueError, "must be a multiple"):
nnx.dot_product_attention(query, key, value)

def test_rank_mismatch_raises_under_dropout_path(self):
# Regression test for https://github.com/google/flax/issues/5496
# The dropout_rate > 0 path used `assert` for q/k/v rank and shape
# invariants. `assert` is compiled out under `python -O`, which let a
# rank mismatch fall through to the einsum and silently return a
# wrong-shaped output instead of raising. The fast path (dropout_rate==0,
# module=None) isn't affected since jax.nn.dot_product_attention does
# its own validation, so this exercises the slow path specifically.
query = jax.random.normal(jax.random.key(0), (2, 16, 4, 8)) # rank 4
key = jax.random.normal(jax.random.key(1), (16, 4, 8)) # rank 3
value = jax.random.normal(jax.random.key(2), (16, 4, 8)) # rank 3

with self.assertRaisesRegex(ValueError, "must have same rank"):
nnx.dot_product_attention(
query, key, value, dropout_rate=0.1, dropout_rng=jax.random.key(3)
)

def test_gqa_multihead_attention(self):
in_feat = 128
n_heads = 32
Expand Down