diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
index 162d7257d0a4c..e3e87a9181ed8 100644
--- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
+++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -236,6 +236,91 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All
                                 /*out*/ bool& is_packed,
                                 /*out*/ PrePackedWeights* prepacked_weights) {
   is_packed = false;
+
+  // Validate the incoming initializer's shape against the attribute-derived shape before any of
+  // the pack routines below dereference tensor.DataRaw(). The MLAS pack routines size their reads
+  // from the (N, K, bits, block_size) attributes; without this check a crafted model whose
+  // attributes overstate the real tensor extents would trigger a heap-buffer-overflow READ at
+  // session initialization. The matching guard in matmul_nbits_helper::CheckInputs is invoked
+  // from Compute() -- too late, because PrePack has already done the OOB read, and by then the
+  // original B tensor is passed as nullptr so the Compute-time check never sees it.
+  //
+  // When input_idx == B, this guard also validates the constant scales and zero_points
+  // initializers (looked up via TryGetConstantInput). SessionState::PrepackConstantInitializedTensors
+  // iterates inputs in index order, so the B PrePack call runs before scales/zero_points are
+  // prepacked on their own. The B prepack path reads those constant tensors and passes their
+  // raw data to the MLAS pack routines (MlasLutGemmPack, MlasQNBitGemmPackQuantBData), which size
+  // their reads from the same (N, K, bits, block_size) attributes. Without validating scales /
+  // zero_points here, a crafted model with an undersized scales or zero_points buffer would still
+  // trigger an OOB read inside the B packing pass before each tensor's own PrePack call could
+  // catch the mismatch.
+  //
+  // This validation runs before the early-return guards below (has_g_idx_, unquantized ZP,
+  // dynamic-ZP-with-LUT, !MlasIsQNBitGemmAvailable). On builds where MLAS QNBit GEMM is not
+  // available (e.g. Windows x86 32-bit) PrePack would otherwise short-circuit before reaching
+  // these checks, and the original B tensor is dropped after PrePack so Compute()'s helper-time
+  // check never sees it. Running the validation first makes session init reject bad-shape models
+  // consistently across all build configurations. The checks are cheap (a few TensorShape
+  // equality comparisons) and independent of any MLAS kernel availability.
+  {
+    const int64_t n = static_cast<int64_t>(N_);
+    const int64_t k = static_cast<int64_t>(K_);
+    const int64_t bs = static_cast<int64_t>(block_size_);
+    const int64_t bits = static_cast<int64_t>(nbits_);
+    const int64_t k_blocks = (k + bs - 1) / bs;
+    const int64_t blob_size = bs * bits / 8;
+    const int64_t zp_blob_size_uint8 = (k_blocks * bits + 7) / 8;
+
+    auto validate_scales_shape = [&](const TensorShape& s) -> Status {
+      // scales may be 1D [n * k_blocks] or 2D [n, k_blocks] for backward compatibility.
+      ORT_RETURN_IF_NOT(s == TensorShape({n * k_blocks}) || s == TensorShape({n, k_blocks}),
+                        "MatMulNBits PrePack: scales initializer shape ", s,
+                        " does not match attribute-derived shape [", n * k_blocks, "] or [",
+                        n, ",", k_blocks, "]");
+      return Status::OK();
+    };
+
+    auto validate_zero_points_shape = [&](const TensorShape& s, int32_t element_type) -> Status {
+      if (element_type == ONNX_NAMESPACE::TensorProto_DataType_UINT8) {
+        ORT_RETURN_IF_NOT(s == TensorShape({n * zp_blob_size_uint8}) || s == TensorShape({n, zp_blob_size_uint8}),
+                          "MatMulNBits PrePack: zero_points initializer shape ", s,
+                          " does not match attribute-derived shape [", n * zp_blob_size_uint8, "] or [",
+                          n, ",", zp_blob_size_uint8, "]");
+      } else {
+        ORT_RETURN_IF_NOT(s == TensorShape({n * k_blocks}) || s == TensorShape({n, k_blocks}),
+                          "MatMulNBits PrePack: zero_points initializer shape ", s,
+                          " does not match attribute-derived shape [", n * k_blocks, "] or [",
+                          n, ",", k_blocks, "]");
+      }
+      return Status::OK();
+    };
+
+    const TensorShape& shape = tensor.Shape();
+
+    if (input_idx == InputIndex::B) {
+      ORT_RETURN_IF_NOT(shape == TensorShape({n, k_blocks, blob_size}),
+                        "MatMulNBits PrePack: B initializer shape ", shape,
+                        " does not match attribute-derived shape [", n, ",", k_blocks, ",", blob_size,
+                        "] (N=", N_, ", K=", K_, ", bits=", nbits_, ", block_size=", block_size_, ")");
+
+      // Also validate constant scales / zero_points, which the B prepack path below dereferences
+      // (via TryGetConstantInput) and hands to MLAS, before their own PrePack calls run.
+      const Tensor* scales_tensor = nullptr;
+      if (OpKernel::Info().TryGetConstantInput(InputIndex::scales, &scales_tensor) && scales_tensor != nullptr) {
+        ORT_RETURN_IF_ERROR(validate_scales_shape(scales_tensor->Shape()));
+      }
+      const Tensor* zp_tensor = nullptr;
+      if (has_zp_arg_ && has_zp_input_ &&
+          OpKernel::Info().TryGetConstantInput(InputIndex::zero_points, &zp_tensor) && zp_tensor != nullptr) {
+        ORT_RETURN_IF_ERROR(validate_zero_points_shape(zp_tensor->Shape(), zp_tensor->GetElementType()));
+      }
+    } else if (input_idx == InputIndex::scales) {
+      ORT_RETURN_IF_ERROR(validate_scales_shape(shape));
+    } else if (input_idx == InputIndex::zero_points) {
+      ORT_RETURN_IF_ERROR(validate_zero_points_shape(shape, tensor.GetElementType()));
+    }
+  }
+
   if (has_g_idx_) {
     return Status::OK();
   }
diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
index bedf035d320f8..88ba303187d1e 100644
--- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
+++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc
@@ -1269,6 +1269,217 @@ TEST(MatMulNBits, UnsupportedBlockSize_512) {
            {}, nullptr, &execution_providers);
 }
 
+// The following tests cover the shape-validation guard added at the top of
+// MatMulNBits::PrePack. The guard rejects initializer shapes that do not
+// match the attribute-derived shape so that a crafted model whose (N, K, bits,
+// block_size) attributes overstate the real tensor extents cannot trigger an
+// out-of-bounds READ inside the MLAS pack routines during session
+// initialization. Each test passes a B/scales/zero_points initializer whose
+// declared shape (and matching data buffer size) is inconsistent with the
+// attribute-derived shape, and expects session creation to fail with
+// "MatMulNBits PrePack:" (i.e. before Compute() is ever invoked).
+
+// B shape mismatches the (N, k_blocks, blob_size) shape derived from attributes.
+TEST(MatMulNBits, PrePack_InvalidBShape_RejectsAtSessionInit) {
+  constexpr int64_t M = 1, N = 4, K = 32, block_size = 32;
+  constexpr int64_t k_blocks = (K + block_size - 1) / block_size;  // 1
+  constexpr int64_t blob_size = block_size * QBits / 8;            // 16
+
+  OpTester test("MatMulNBits", 1, kMSDomain);
+  test.AddAttribute<int64_t>("K", K);
+  test.AddAttribute<int64_t>("N", N);
+  test.AddAttribute<int64_t>("block_size", block_size);
+  test.AddAttribute<int64_t>("bits", QBits);
+  test.AddAttribute<int64_t>("accuracy_level", int64_t{0});
+
+  std::vector<float> a_data(M * K, 1.0f);
+  test.AddInput<float>("A", {M, K}, a_data, false);
+
+  // Declare B with one fewer row than attributes claim. The data buffer matches
+  // the smaller declared shape, exactly mirroring the crafted-model scenario in
+  // which the attributes overstate the tensor's real extent.
+  constexpr int64_t bad_N = N - 1;
+  std::vector<uint8_t> b_data(bad_N * k_blocks * blob_size, 0);
+  test.AddInput<uint8_t>("B", {bad_N, k_blocks, blob_size}, b_data, true);
+
+  std::vector<float> scales(N * k_blocks, 1.0f);
+  test.AddInput<float>("scales", {N, k_blocks}, scales, true);
+
+  std::vector<float> y_data(M * N, 0.0f);
+  test.AddOutput<float>("Y", {M, N}, y_data);
+
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+  execution_providers.push_back(DefaultCpuExecutionProvider());
+  test.Run(OpTester::ExpectResult::kExpectFailure,
+           "MatMulNBits PrePack: B initializer shape",
+           {}, nullptr, &execution_providers);
+}
+
+// B shape has the wrong rank (2D instead of (N, k_blocks, blob_size)).
+TEST(MatMulNBits, PrePack_InvalidBRank_RejectsAtSessionInit) {
+  constexpr int64_t M = 1, N = 4, K = 32, block_size = 32;
+  constexpr int64_t k_blocks = (K + block_size - 1) / block_size;
+  constexpr int64_t blob_size = block_size * QBits / 8;
+
+  OpTester test("MatMulNBits", 1, kMSDomain);
+  test.AddAttribute<int64_t>("K", K);
+  test.AddAttribute<int64_t>("N", N);
+  test.AddAttribute<int64_t>("block_size", block_size);
+  test.AddAttribute<int64_t>("bits", QBits);
+  test.AddAttribute<int64_t>("accuracy_level", int64_t{0});
+
+  std::vector<float> a_data(M * K, 1.0f);
+  test.AddInput<float>("A", {M, K}, a_data, false);
+
+  // Flatten the trailing k_blocks/blob_size dims into a single dimension.
+  // The total element count still matches, but the rank differs from the
+  // attribute-derived (N, k_blocks, blob_size) shape.
+  std::vector<uint8_t> b_data(N * k_blocks * blob_size, 0);
+  test.AddInput<uint8_t>("B", {N, k_blocks * blob_size}, b_data, true);
+
+  std::vector<float> scales(N * k_blocks, 1.0f);
+  test.AddInput<float>("scales", {N, k_blocks}, scales, true);
+
+  std::vector<float> y_data(M * N, 0.0f);
+  test.AddOutput<float>("Y", {M, N}, y_data);
+
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+  execution_providers.push_back(DefaultCpuExecutionProvider());
+  test.Run(OpTester::ExpectResult::kExpectFailure,
+           "MatMulNBits PrePack: B initializer shape",
+           {}, nullptr, &execution_providers);
+}
+
+// scales shape does not match either of the accepted layouts
+// ([N * k_blocks] or [N, k_blocks]).
+TEST(MatMulNBits, PrePack_InvalidScalesShape_RejectsAtSessionInit) {
+  constexpr int64_t M = 1, N = 4, K = 32, block_size = 32;
+  constexpr int64_t k_blocks = (K + block_size - 1) / block_size;
+  constexpr int64_t blob_size = block_size * QBits / 8;
+
+  OpTester test("MatMulNBits", 1, kMSDomain);
+  test.AddAttribute<int64_t>("K", K);
+  test.AddAttribute<int64_t>("N", N);
+  test.AddAttribute<int64_t>("block_size", block_size);
+  test.AddAttribute<int64_t>("bits", QBits);
+  test.AddAttribute<int64_t>("accuracy_level", int64_t{0});
+
+  std::vector<float> a_data(M * K, 1.0f);
+  test.AddInput<float>("A", {M, K}, a_data, false);
+
+  std::vector<uint8_t> b_data(N * k_blocks * blob_size, 0);
+  test.AddInput<uint8_t>("B", {N, k_blocks, blob_size}, b_data, true);
+
+  // Declare scales with one fewer row than the attribute-derived layout.
+  constexpr int64_t bad_N = N - 1;
+  std::vector<float> scales(bad_N * k_blocks, 1.0f);
+  test.AddInput<float>("scales", {bad_N, k_blocks}, scales, true);
+
+  std::vector<float> y_data(M * N, 0.0f);
+  test.AddOutput<float>("Y", {M, N}, y_data);
+
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+  execution_providers.push_back(DefaultCpuExecutionProvider());
+  test.Run(OpTester::ExpectResult::kExpectFailure,
+           "MatMulNBits PrePack: scales initializer shape",
+           {}, nullptr, &execution_providers);
+}
+
+// uint8 (packed) zero_points shape does not match the
+// [N * zp_blob_size] / [N, zp_blob_size] layout derived from attributes.
+TEST(MatMulNBits, PrePack_InvalidUInt8ZeroPointsShape_RejectsAtSessionInit) {
+  constexpr int64_t M = 1, N = 4, K = 32, block_size = 32;
+  constexpr int64_t k_blocks = (K + block_size - 1) / block_size;
+  constexpr int64_t blob_size = block_size * QBits / 8;
+  constexpr int64_t zp_blob_size = (k_blocks * QBits + 7) / 8;
+
+  OpTester test("MatMulNBits", 1, kMSDomain);
+  test.AddAttribute<int64_t>("K", K);
+  test.AddAttribute<int64_t>("N", N);
+  test.AddAttribute<int64_t>("block_size", block_size);
+  test.AddAttribute<int64_t>("bits", QBits);
+  test.AddAttribute<int64_t>("accuracy_level", int64_t{0});
+
+  std::vector<float> a_data(M * K, 1.0f);
+  test.AddInput<float>("A", {M, K}, a_data, false);
+
+  std::vector<uint8_t> b_data(N * k_blocks * blob_size, 0);
+  test.AddInput<uint8_t>("B", {N, k_blocks, blob_size}, b_data, true);
+
+  std::vector<float> scales(N * k_blocks, 1.0f);
+  test.AddInput<float>("scales", {N, k_blocks}, scales, true);
+
+  // Declare uint8 zero_points with one fewer row than the attribute-derived
+  // layout. zp_blob_size==1 here, so this is also distinguishable from any
+  // legacy 1D layout that would otherwise be accepted.
+  constexpr int64_t bad_N = N - 1;
+  std::vector<uint8_t> zp(bad_N * zp_blob_size, 0);
+  test.AddInput<uint8_t>("zero_points", {bad_N, zp_blob_size}, zp, true);
+
+  std::vector<float> y_data(M * N, 0.0f);
+  test.AddOutput<float>("Y", {M, N}, y_data);
+
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+  execution_providers.push_back(DefaultCpuExecutionProvider());
+  test.Run(OpTester::ExpectResult::kExpectFailure,
+           "MatMulNBits PrePack: zero_points initializer shape",
+           {}, nullptr, &execution_providers);
+}
+
+// Sanity check: the legacy 1D layouts for scales and uint8 zero_points are
+// still accepted by the new shape validation guard (i.e. the guard only
+// rejects truly mismatched shapes and does not regress backward
+// compatibility for existing models).
+TEST(MatMulNBits, PrePack_LegacyFlattenedShapes_Accepted) {
+  constexpr int64_t M = 1, N = 4, K = 32, block_size = 32;
+  constexpr int64_t k_blocks = (K + block_size - 1) / block_size;
+  constexpr int64_t blob_size = block_size * QBits / 8;
+  constexpr int64_t zp_blob_size = (k_blocks * QBits + 7) / 8;
+
+  RandomValueGenerator random{1234};
+  std::vector<float> a_vals(random.Gaussian<float>(AsSpan({M, K}), 0.0f, 0.25f));
+  std::vector<float> b_f_vals(random.Gaussian<float>(AsSpan({K, N}), 0.0f, 0.25f));
+
+  std::vector<uint8_t> b_data(N * k_blocks * blob_size);
+  std::vector<float> scales(N * k_blocks);
+  std::vector<uint8_t> zp(N * zp_blob_size);
+  QuantizeDequantize(b_f_vals, b_data, scales, &zp,
+                     static_cast<int32_t>(N), static_cast<int32_t>(K),
+                     static_cast<int32_t>(block_size));  // NOTE(review): int32_t assumed; confirm vs. helper
+
+  std::vector<float> expected(M * N);
+  for (int64_t m = 0; m < M; ++m) {
+    for (int64_t n = 0; n < N; ++n) {
+      float sum = 0.0f;
+      for (int64_t k = 0; k < K; ++k) {
+        sum += a_vals[m * K + k] * b_f_vals[n * K + k];
+      }
+      expected[m * N + n] = sum;
+    }
+  }
+
+  OpTester test("MatMulNBits", 1, kMSDomain);
+  test.AddAttribute<int64_t>("K", K);
+  test.AddAttribute<int64_t>("N", N);
+  test.AddAttribute<int64_t>("block_size", block_size);
+  test.AddAttribute<int64_t>("bits", QBits);
+  test.AddAttribute<int64_t>("accuracy_level", int64_t{0});
+
+  test.AddInput<float>("A", {M, K}, a_vals, false);
+  test.AddInput<uint8_t>("B", {N, k_blocks, blob_size}, b_data, true);
+  // Legacy flattened 1D layouts for scales and zero_points.
+  test.AddInput<float>("scales", {N * k_blocks}, scales, true);
+  test.AddInput<uint8_t>("zero_points", {N * zp_blob_size}, zp, true);
+
+  test.AddOutput<float>("Y", {M, N}, expected);
+  test.SetOutputAbsErr("Y", 0.1f);
+
+  std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+  execution_providers.push_back(DefaultCpuExecutionProvider());
+  test.Run(OpTester::ExpectResult::kExpectSuccess, "",
+           {}, nullptr, &execution_providers);
+}
+
 }  // namespace test
 }  // namespace onnxruntime