fix(nnx): raise ValueError instead of assert for q/k/v shape checks in attention#5514
Open
Sumu004 wants to merge 1 commit into
Open
fix(nnx): raise ValueError instead of assert for q/k/v shape checks in attention#5514Sumu004 wants to merge 1 commit into
Sumu004 wants to merge 1 commit into
Conversation
…n attention dot_product_attention and dot_product_attention_weights validated q/k/v rank and shape invariants with plain `assert` statements. Python strips `assert` under `-O` (or PYTHONOPTIMIZE=1), which is common in production serving. On the slow path -- taken whenever dropout_rate > 0 or a module is passed for sowing attention weights -- a rank or shape mismatch then silently falls through to the einsum instead of raising, and returns a wrong-shaped output with no error at all. The fast path (dropout_rate == 0, module is None) delegates to jax.nn.dot_product_attention, which does its own validation regardless of -O, so it was never affected. Replaced the asserts with explicit `if ...: raise ValueError(...)` so the same clear error is raised with or without -O. Fixes google#5496
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What
flax.nnx.nn.attention.dot_product_attention(anddot_product_attention_weights) validate q/k/v rank and shape invariants with plainassertstatements. Python stripsassertunder-O/PYTHONOPTIMIZE=1, which is common in production serving. On the slow path -- taken wheneverdropout_rate > 0or amoduleis passed for sowing attention weights -- a rank or shape mismatch then silently falls through to the einsum instead of raising, and returns a wrong-shaped output with no error at all.Root cause
The fast path (
dropout_rate == 0,module is None) delegates tojax.nn.dot_product_attention, which does its own validation regardless of-O, so it was never affected. The slow path's own checks (lines ~110-112, ~258-262 before this change) are bareasserts, which Python compiles out entirely under-O-- there's no runtime check left at all, not even a degraded one.Fix
Replaced the asserts in both
dot_product_attention_weightsanddot_product_attentionwith explicitif ...: raise ValueError(...), so the same clear error is raised with or without-O. No behavior change in the normal (non--O) case -- same exceptions, same messages, justValueErrorinstead ofAssertionError(matching the existingValueErrorused a few lines below for the GQA head-count check, which already used explicit raises rather than assert).Checklist
test_rank_mismatch_raises_under_dropout_pathregression test, asserting aValueErroron the dropout-path rank mismatch from the issue reproassertis compiled out under-Owhileif: raiseis nottest_gqa_invalid_heads(and other attention tests) unaffected -- this only swaps the exception type and is unconditional regardless of-OFixes #5496