Validate B/scales/zero_points shape in MatMulNBits::PrePack#29445

Open
apsonawane wants to merge 4 commits into main from asonawane/edge-3

Conversation

@apsonawane

Contributor

MatMulNBits::PrePack ran at session initialization and called the MLAS pack routines using byte counts derived from the node attributes (N, K, bits, block_size) without ever comparing those attributes to the actual tensor Shape(). A crafted .onnx whose attributes overstate the real B (or scales / zero_points) extent triggered a heap-buffer-overflow READ inside MlasQNBitGemmPackQuantBData / MlasLutGemmPack during OrtApis::CreateSession (no Run() required).
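To make the size mismatch concrete, here is a minimal sketch of the arithmetic involved. It assumes that B is laid out as N rows of ceil(K / block_size) blocks, each block occupying ceil(block_size * bits / 8) bytes; the function names `AttrDerivedBBytes` and `WouldOverread` are hypothetical and are not part of onnxruntime or MLAS.

```cpp
#include <cstddef>
#include <cstdint>

// Bytes a packer would read for B if it trusted the node attributes alone.
// Assumed layout (not taken from the PR): N rows, ceil(K / block_size)
// blocks per row, ceil(block_size * bits / 8) bytes per block.
inline std::size_t AttrDerivedBBytes(int64_t N, int64_t K, int64_t bits,
                                     int64_t block_size) {
  const int64_t k_blocks = (K + block_size - 1) / block_size;
  const int64_t blob_size = (block_size * bits + 7) / 8;
  return static_cast<std::size_t>(N * k_blocks * blob_size);
}

// True when the attribute-derived extent is larger than the buffer the
// initializer actually provides, i.e. the read would run past the end.
inline bool WouldOverread(std::size_t actual_bytes, int64_t N, int64_t K,
                          int64_t bits, int64_t block_size) {
  return AttrDerivedBBytes(N, K, bits, block_size) > actual_bytes;
}
```

With N = 8, K = 64, bits = 4 and block_size = 32, the attributes imply 8 * 2 * 16 bytes, so an initializer that supplies only half of that would be overread under these assumptions.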

The canonical shape check already lives in matmul_nbits_helper::CheckInputs, but it is invoked only from Compute(), after PrePack has already performed the out-of-bounds read. By then the original B tensor has been replaced with nullptr in the kernel context, so the Compute-time check never re-validates it.

Fix: at the top of PrePack, after the existing early-return guards and before any tensor.DataRaw() read, validate the incoming initializer's Shape() against the attribute-derived shape:

  • B -> (N, k_blocks, blob_size)
  • scales -> (N * k_blocks) or (N, k_blocks)
  • zero_points -> if uint8: (N * zp_blob) or (N, zp_blob); otherwise (N * k_blocks) or (N, k_blocks)

A mismatch returns INVALID_ARGUMENT so the session fails to load rather than reading past the buffer.
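The shape rules listed above can be sketched as standalone predicates. This is an illustration under stated assumptions, not the code in the PR: `NBitsAttrs`, `Shape`, and the `Valid*Shape` functions are hypothetical names, and the formulas k_blocks = ceil(K / block_size), blob_size = ceil(block_size * bits / 8), and zp_blob = ceil(k_blocks * bits / 8) are assumed definitions that the description does not spell out.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the node attributes of a MatMulNBits node.
struct NBitsAttrs {
  int64_t N;
  int64_t K;
  int64_t bits;
  int64_t block_size;
};

using Shape = std::vector<int64_t>;

inline int64_t CeilDiv(int64_t a, int64_t b) { return (a + b - 1) / b; }

// Accepts either the flattened 1-D layout (rows * cols) or the 2-D layout
// (rows, cols).
inline bool IsFlatOr2D(const Shape& s, int64_t rows, int64_t cols) {
  return s == Shape{rows * cols} || s == Shape{rows, cols};
}

// B must be exactly (N, k_blocks, blob_size).
inline bool ValidBShape(const NBitsAttrs& a, const Shape& s) {
  const int64_t k_blocks = CeilDiv(a.K, a.block_size);
  const int64_t blob_size = CeilDiv(a.block_size * a.bits, 8);  // assumed
  return s == Shape{a.N, k_blocks, blob_size};
}

// scales must be (N * k_blocks) or (N, k_blocks).
inline bool ValidScalesShape(const NBitsAttrs& a, const Shape& s) {
  return IsFlatOr2D(s, a.N, CeilDiv(a.K, a.block_size));
}

// Packed uint8 zero_points use zp_blob columns per row; any other element
// type shares the scales shape.
inline bool ValidZeroPointsShape(const NBitsAttrs& a, const Shape& s,
                                 bool is_uint8) {
  const int64_t k_blocks = CeilDiv(a.K, a.block_size);
  if (!is_uint8) return IsFlatOr2D(s, a.N, k_blocks);
  const int64_t zp_blob = CeilDiv(k_blocks * a.bits, 8);  // assumed
  return IsFlatOr2D(s, a.N, zp_blob);
}
```

For attributes N = 8, K = 64, bits = 4, block_size = 32, these predicates accept B of shape (8, 2, 16) and scales of shape (16) or (8, 2), and reject a B of shape (8, 1, 16). In the actual kernel, a failed check would map to the INVALID_ARGUMENT status described above rather than a boolean.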

Copilot AI left a comment

Pull request overview

This PR hardens the CPU MatMulNBits contrib op against malformed models by adding early shape validation in MatMulNBits<T1>::PrePack() so that session initialization rejects inconsistent initializers before any MLAS packing routine can read past the provided buffers.

Changes:

  • Add attribute-derived initializer shape checks for B, scales, and zero_points at the top of MatMulNBits<T1>::PrePack().
  • Add new unit tests that expect session creation to fail (pre-Compute()) for mismatched initializer shapes, plus a compatibility test for legacy flattened scales/zero_points layouts.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc | Adds new PrePack-time shape validation intended to prevent OOB reads during weight packing at session init. |
| onnxruntime/test/contrib_ops/matmul_4bits_test.cc | Adds tests that exercise PrePack-time rejection for malformed initializer shapes and verifies legacy flattened layouts remain accepted. |

Comment thread onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc Outdated
2 participants