Validate B/scales/zero_points shape in MatMulNBits::PrePack#29445
Open
apsonawane wants to merge 4 commits into
MatMulNBits::PrePack ran at session initialization and called the MLAS
pack routines using byte counts derived from the node attributes
(N, K, bits, block_size) without ever comparing those attributes to
the actual tensor Shape(). A crafted .onnx whose attributes overstate
the real B (or scales / zero_points) extent triggered a
heap-buffer-overflow READ inside MlasQNBitGemmPackQuantBData /
MlasLutGemmPack during OrtApis::CreateSession (no Run() required).
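
To make the size mismatch concrete, here is a minimal arithmetic sketch. It is not onnxruntime code: the helper names are hypothetical, and the formulas for k_blocks and blob_size (ceiling divisions over K, block_size, and bits) are assumptions based on the documented MatMulNBits layout of B.

```cpp
#include <cstdint>

// Hypothetical helper (not an onnxruntime API): number of bytes a pack
// routine would read for B if it trusts the node attributes alone.
// Assumes k_blocks = ceil(K / block_size) and
// blob_size = ceil(block_size * bits / 8).
inline int64_t AttrDerivedBBytes(int64_t N, int64_t K, int64_t bits,
                                 int64_t block_size) {
  const int64_t k_blocks = (K + block_size - 1) / block_size;
  const int64_t blob_size = (block_size * bits + 7) / 8;
  return N * k_blocks * blob_size;
}

// Hypothetical helper: how far past the end of the real buffer such a
// read would run. Zero means the attributes do not overstate the tensor.
inline int64_t OverreadBytes(int64_t attr_bytes, int64_t actual_bytes) {
  return attr_bytes > actual_bytes ? attr_bytes - actual_bytes : 0;
}
```

For example, attributes N=64, K=128, bits=4, block_size=32 imply 64 * 4 * 16 = 4096 bytes of B. If the initializer actually stored in the model holds only 1024 bytes, a packer that trusts the attributes would read 3072 bytes past the end.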
The canonical shape check already lives in
matmul_nbits_helper::CheckInputs, but is invoked only from Compute()
-- after PrePack has already done the OOB read, and by then the
original B tensor is replaced with nullptr in the kernel context so
the Compute-time check never re-validates it.
Fix: at the top of PrePack, after the existing early-return guards
and before any tensor.DataRaw() read, validate the incoming
initializer's Shape() against the attribute-derived shape:
- B -> (N, k_blocks, blob_size)
- scales -> (N * k_blocks) or (N, k_blocks)
- zero_points -> uint8: (N * zp_blob) or (N, zp_blob); else
  (N * k_blocks) or (N, k_blocks)
A mismatch returns INVALID_ARGUMENT so the session fails to load
rather than reading past the buffer.
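
The rules above can be sketched as a standalone check. This is an illustration only, not the code in this PR: Dims, MatMulNBitsAttrs, PackedInput, and ValidatePrePackShape are hypothetical names, a std::string stands in for the status object, and the formulas for k_blocks, blob_size, and zp_blob are assumed ceiling divisions.

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical stand-ins for the real tensor shape and kernel attributes.
using Dims = std::vector<int64_t>;

struct MatMulNBitsAttrs {
  int64_t N;
  int64_t K;
  int64_t bits;
  int64_t block_size;
};

enum class PackedInput { B, Scales, ZeroPoints };

inline int64_t CeilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

// Accepts both the legacy flattened layout (rows * cols) and the 2-D
// layout (rows, cols).
inline bool IsFlatOr2D(const Dims& dims, int64_t rows, int64_t cols) {
  return dims == Dims{rows * cols} || dims == Dims{rows, cols};
}

// Returns an empty string when the shape matches the attribute-derived
// shape, otherwise an error message (standing in for INVALID_ARGUMENT).
inline std::string ValidatePrePackShape(const MatMulNBitsAttrs& a,
                                        PackedInput which, const Dims& dims,
                                        bool zp_is_uint8 = true) {
  if (a.N <= 0 || a.K <= 0 || a.bits <= 0 || a.block_size <= 0) {
    return "INVALID_ARGUMENT: non-positive attribute";
  }
  // Assumed definitions of the derived extents.
  const int64_t k_blocks = CeilDiv(a.K, a.block_size);
  const int64_t blob_size = CeilDiv(a.block_size * a.bits, 8);
  const int64_t zp_blob = CeilDiv(k_blocks * a.bits, 8);

  bool ok = false;
  switch (which) {
    case PackedInput::B:
      ok = (dims == Dims{a.N, k_blocks, blob_size});
      break;
    case PackedInput::Scales:
      ok = IsFlatOr2D(dims, a.N, k_blocks);
      break;
    case PackedInput::ZeroPoints:
      // Packed uint8 zero points use zp_blob bytes per row; otherwise
      // there is one zero point per block, like scales.
      ok = IsFlatOr2D(dims, a.N, zp_is_uint8 ? zp_blob : k_blocks);
      break;
  }
  return ok ? "" : "INVALID_ARGUMENT: initializer shape mismatch";
}
```

With N=16, K=64, bits=4, block_size=32 this accepts B of shape (16, 2, 16) and scales of shape (32) or (16, 2), and rejects B of shape (4, 2, 16). Returning an error instead of asserting mirrors the intent stated above: the session fails to load rather than the packer reading out of bounds.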
Pull request overview
This PR hardens the CPU MatMulNBits contrib op against malformed models by adding early shape validation in MatMulNBits<T1>::PrePack() so that session initialization rejects inconsistent initializers before any MLAS packing routine can read past the provided buffers.
Changes:
- Add attribute-derived initializer shape checks for B, scales, and zero_points at the top of MatMulNBits<T1>::PrePack().
- Add new unit tests that expect session creation to fail (pre-Compute()) for mismatched initializer shapes, plus a compatibility test for legacy flattened scales/zero_points layouts.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.
| File | Description |
|---|---|
| onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc | Adds new PrePack-time shape validation intended to prevent OOB reads during weight packing at session init. |
| onnxruntime/test/contrib_ops/matmul_4bits_test.cc | Adds tests that exercise PrePack-time rejection for malformed initializer shapes and verifies legacy flattened layouts remain accepted. |