Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,49 @@ Status MatMulNBits<T1>::PrePack(const Tensor& tensor, int input_idx, /*out*/ All
return Status::OK();
}

// Validate the incoming initializer's shape against the attribute-derived shape before any of
// the pack routines below dereference tensor.DataRaw(). The MLAS pack routines size their reads
// from the (N, K, bits, block_size) attributes; without this check a crafted model whose
// attributes overstate the real tensor extents would trigger a heap-buffer-overflow READ at
// session initialization. The matching guard in matmul_nbits_helper::CheckInputs is invoked
// from Compute() -- too late, because PrePack has already done the OOB read, and by then the
// original B tensor is passed as nullptr so the Compute-time check never sees it.
{
const int64_t n = static_cast<int64_t>(N_);
const int64_t k = static_cast<int64_t>(K_);
const int64_t bs = static_cast<int64_t>(block_size_);
const int64_t bits = static_cast<int64_t>(nbits_);
const int64_t k_blocks = (k + bs - 1) / bs;
const int64_t blob_size = bs * bits / 8;
const TensorShape& shape = tensor.Shape();

if (input_idx == InputIndex::B) {
ORT_RETURN_IF_NOT(shape == TensorShape({n, k_blocks, blob_size}),
"MatMulNBits PrePack: B initializer shape ", shape,
" does not match attribute-derived shape [", n, ",", k_blocks, ",", blob_size,
"] (N=", N_, ", K=", K_, ", bits=", nbits_, ", block_size=", block_size_, ")");
} else if (input_idx == InputIndex::scales) {
// scales may be 1D [n * k_blocks] or 2D [n, k_blocks] for backward compatibility.
ORT_RETURN_IF_NOT(shape == TensorShape({n * k_blocks}) || shape == TensorShape({n, k_blocks}),
"MatMulNBits PrePack: scales initializer shape ", shape,
" does not match attribute-derived shape [", n * k_blocks, "] or [",
n, ",", k_blocks, "]");
Comment thread
apsonawane marked this conversation as resolved.
Outdated
} else if (input_idx == InputIndex::zero_points) {
if (tensor.GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_UINT8) {
const int64_t zp_blob_size = (k_blocks * bits + 7) / 8;
ORT_RETURN_IF_NOT(shape == TensorShape({n * zp_blob_size}) || shape == TensorShape({n, zp_blob_size}),
"MatMulNBits PrePack: zero_points initializer shape ", shape,
" does not match attribute-derived shape [", n * zp_blob_size, "] or [",
n, ",", zp_blob_size, "]");
} else {
ORT_RETURN_IF_NOT(shape == TensorShape({n * k_blocks}) || shape == TensorShape({n, k_blocks}),
"MatMulNBits PrePack: zero_points initializer shape ", shape,
" does not match attribute-derived shape [", n * k_blocks, "] or [",
n, ",", k_blocks, "]");
}
}
}

// Create a temporary threadpool for parallel packing
// This is used during model load time to speed up weight prepacking
std::unique_ptr<concurrency::ThreadPool> temp_threadpool;
Expand Down
211 changes: 211 additions & 0 deletions onnxruntime/test/contrib_ops/matmul_4bits_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1269,6 +1269,217 @@ TEST(MatMulNBits, UnsupportedBlockSize_512) {
{}, nullptr, &execution_providers);
}

// The following tests cover the shape-validation guard added at the top of
// MatMulNBits<T1>::PrePack. The guard rejects initializer shapes that do not
// match the attribute-derived shape so that a crafted model whose (N, K, bits,
// block_size) attributes overstate the real tensor extents cannot trigger an
// out-of-bounds READ inside the MLAS pack routines during session
// initialization. Each test passes a B/scales/zero_points initializer whose
// declared shape (and matching data buffer size) is inconsistent with the
// attribute-derived shape, and expects session creation to fail with
// "MatMulNBits PrePack:" (i.e. before Compute() is ever invoked).

// B shape mismatches the (N, k_blocks, blob_size) shape derived from attributes.
TEST(MatMulNBits, PrePack_InvalidBShape_RejectsAtSessionInit) {
constexpr int64_t M = 1, N = 4, K = 32, block_size = 32;
constexpr int64_t k_blocks = (K + block_size - 1) / block_size; // 1
constexpr int64_t blob_size = block_size * QBits / 8; // 16

OpTester test("MatMulNBits", 1, kMSDomain);
test.AddAttribute<int64_t>("K", K);
test.AddAttribute<int64_t>("N", N);
test.AddAttribute<int64_t>("block_size", block_size);
test.AddAttribute<int64_t>("bits", QBits);
test.AddAttribute<int64_t>("accuracy_level", int64_t{0});

std::vector<float> a_data(M * K, 1.0f);
test.AddInput<float>("A", {M, K}, a_data, false);

// Declare B with one fewer row than attributes claim. The data buffer matches
// the smaller declared shape, exactly mirroring the crafted-model scenario in
// which the attributes overstate the tensor's real extent.
constexpr int64_t bad_N = N - 1;
std::vector<uint8_t> b_data(bad_N * k_blocks * blob_size, 0);
test.AddInput<uint8_t>("B", {bad_N, k_blocks, blob_size}, b_data, true);

std::vector<float> scales(N * k_blocks, 1.0f);
test.AddInput<float>("scales", {N, k_blocks}, scales, true);

std::vector<float> y_data(M * N, 0.0f);
test.AddOutput<float>("Y", {M, N}, y_data);

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectFailure,
"MatMulNBits PrePack: B initializer shape",
{}, nullptr, &execution_providers);
}

// B shape has the wrong rank (2D instead of (N, k_blocks, blob_size)).
TEST(MatMulNBits, PrePack_InvalidBRank_RejectsAtSessionInit) {
constexpr int64_t M = 1, N = 4, K = 32, block_size = 32;
constexpr int64_t k_blocks = (K + block_size - 1) / block_size;
constexpr int64_t blob_size = block_size * QBits / 8;

OpTester test("MatMulNBits", 1, kMSDomain);
test.AddAttribute<int64_t>("K", K);
test.AddAttribute<int64_t>("N", N);
test.AddAttribute<int64_t>("block_size", block_size);
test.AddAttribute<int64_t>("bits", QBits);
test.AddAttribute<int64_t>("accuracy_level", int64_t{0});

std::vector<float> a_data(M * K, 1.0f);
test.AddInput<float>("A", {M, K}, a_data, false);

// Flatten the trailing k_blocks/blob_size dims into a single dimension.
// The total element count still matches, but the rank differs from the
// attribute-derived (N, k_blocks, blob_size) shape.
std::vector<uint8_t> b_data(N * k_blocks * blob_size, 0);
test.AddInput<uint8_t>("B", {N, k_blocks * blob_size}, b_data, true);

std::vector<float> scales(N * k_blocks, 1.0f);
test.AddInput<float>("scales", {N, k_blocks}, scales, true);

std::vector<float> y_data(M * N, 0.0f);
test.AddOutput<float>("Y", {M, N}, y_data);

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectFailure,
"MatMulNBits PrePack: B initializer shape",
{}, nullptr, &execution_providers);
}

// scales shape does not match either of the accepted layouts
// ([N * k_blocks] or [N, k_blocks]).
TEST(MatMulNBits, PrePack_InvalidScalesShape_RejectsAtSessionInit) {
constexpr int64_t M = 1, N = 4, K = 32, block_size = 32;
constexpr int64_t k_blocks = (K + block_size - 1) / block_size;
constexpr int64_t blob_size = block_size * QBits / 8;

OpTester test("MatMulNBits", 1, kMSDomain);
test.AddAttribute<int64_t>("K", K);
test.AddAttribute<int64_t>("N", N);
test.AddAttribute<int64_t>("block_size", block_size);
test.AddAttribute<int64_t>("bits", QBits);
test.AddAttribute<int64_t>("accuracy_level", int64_t{0});

std::vector<float> a_data(M * K, 1.0f);
test.AddInput<float>("A", {M, K}, a_data, false);

std::vector<uint8_t> b_data(N * k_blocks * blob_size, 0);
test.AddInput<uint8_t>("B", {N, k_blocks, blob_size}, b_data, true);

// Declare scales with one fewer row than the attribute-derived layout.
constexpr int64_t bad_N = N - 1;
std::vector<float> scales(bad_N * k_blocks, 1.0f);
test.AddInput<float>("scales", {bad_N, k_blocks}, scales, true);

std::vector<float> y_data(M * N, 0.0f);
test.AddOutput<float>("Y", {M, N}, y_data);

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectFailure,
"MatMulNBits PrePack: scales initializer shape",
{}, nullptr, &execution_providers);
}

// uint8 (packed) zero_points shape does not match the
// [N * zp_blob_size] / [N, zp_blob_size] layout derived from attributes.
TEST(MatMulNBits, PrePack_InvalidUInt8ZeroPointsShape_RejectsAtSessionInit) {
constexpr int64_t M = 1, N = 4, K = 32, block_size = 32;
constexpr int64_t k_blocks = (K + block_size - 1) / block_size;
constexpr int64_t blob_size = block_size * QBits / 8;
constexpr int64_t zp_blob_size = (k_blocks * QBits + 7) / 8;

OpTester test("MatMulNBits", 1, kMSDomain);
test.AddAttribute<int64_t>("K", K);
test.AddAttribute<int64_t>("N", N);
test.AddAttribute<int64_t>("block_size", block_size);
test.AddAttribute<int64_t>("bits", QBits);
test.AddAttribute<int64_t>("accuracy_level", int64_t{0});

std::vector<float> a_data(M * K, 1.0f);
test.AddInput<float>("A", {M, K}, a_data, false);

std::vector<uint8_t> b_data(N * k_blocks * blob_size, 0);
test.AddInput<uint8_t>("B", {N, k_blocks, blob_size}, b_data, true);

std::vector<float> scales(N * k_blocks, 1.0f);
test.AddInput<float>("scales", {N, k_blocks}, scales, true);

// Declare uint8 zero_points with one fewer row than the attribute-derived
// layout. zp_blob_size==1 here, so this is also distinguishable from any
// legacy 1D layout that would otherwise be accepted.
constexpr int64_t bad_N = N - 1;
std::vector<uint8_t> zp(bad_N * zp_blob_size, 0);
test.AddInput<uint8_t>("zero_points", {bad_N, zp_blob_size}, zp, true);

std::vector<float> y_data(M * N, 0.0f);
test.AddOutput<float>("Y", {M, N}, y_data);

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectFailure,
"MatMulNBits PrePack: zero_points initializer shape",
{}, nullptr, &execution_providers);
}

// Sanity check: the legacy 1D layouts for scales and uint8 zero_points are
// still accepted by the new shape validation guard (i.e. the guard only
// rejects truly mismatched shapes and does not regress backward
// compatibility for existing models).
TEST(MatMulNBits, PrePack_LegacyFlattenedShapes_Accepted) {
constexpr int64_t M = 1, N = 4, K = 32, block_size = 32;
constexpr int64_t k_blocks = (K + block_size - 1) / block_size;
constexpr int64_t blob_size = block_size * QBits / 8;
constexpr int64_t zp_blob_size = (k_blocks * QBits + 7) / 8;

RandomValueGenerator random{1234};
std::vector<float> a_vals(random.Gaussian<float>(AsSpan({M, K}), 0.0f, 0.25f));
std::vector<float> b_f_vals(random.Gaussian<float>(AsSpan({K, N}), 0.0f, 0.25f));

std::vector<uint8_t> b_data(N * k_blocks * blob_size);
std::vector<float> scales(N * k_blocks);
std::vector<uint8_t> zp(N * zp_blob_size);
QuantizeDequantize(b_f_vals, b_data, scales, &zp,
static_cast<int32_t>(N), static_cast<int32_t>(K),
static_cast<int32_t>(block_size));

std::vector<float> expected(M * N);
for (int64_t m = 0; m < M; ++m) {
for (int64_t n = 0; n < N; ++n) {
float sum = 0.0f;
for (int64_t k = 0; k < K; ++k) {
sum += a_vals[m * K + k] * b_f_vals[n * K + k];
}
expected[m * N + n] = sum;
}
}

OpTester test("MatMulNBits", 1, kMSDomain);
test.AddAttribute<int64_t>("K", K);
test.AddAttribute<int64_t>("N", N);
test.AddAttribute<int64_t>("block_size", block_size);
test.AddAttribute<int64_t>("bits", QBits);
test.AddAttribute<int64_t>("accuracy_level", int64_t{0});

test.AddInput<float>("A", {M, K}, a_vals, false);
test.AddInput<uint8_t>("B", {N, k_blocks, blob_size}, b_data, true);
// Legacy flattened 1D layouts for scales and zero_points.
test.AddInput<float>("scales", {N * k_blocks}, scales, true);
test.AddInput<uint8_t>("zero_points", {N * zp_blob_size}, zp, true);

test.AddOutput<float>("Y", {M, N}, expected);
test.SetOutputAbsErr("Y", 0.1f);

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectSuccess, "",
{}, nullptr, &execution_providers);
}

} // namespace test
} // namespace onnxruntime

Expand Down
Loading