From 807390233b809c529bb78a3888e59a319e0c26ff Mon Sep 17 00:00:00 2001
From: Akshay Sonawane
Date: Tue, 30 Jun 2026 10:49:02 -0700
Subject: [PATCH 1/4] Validate B/scales/zero_points shape in MatMulNBits::PrePack

MatMulNBits::PrePack ran at session initialization and called the MLAS pack
routines using byte counts derived from the node attributes (N, K, bits,
block_size) without ever comparing those attributes to the actual tensor
Shape(). A crafted .onnx whose attributes overstate the real B (or scales /
zero_points) extent triggered a heap-buffer-overflow READ inside
MlasQNBitGemmPackQuantBData / MlasLutGemmPack during OrtApis::CreateSession
(no Run() required).

The canonical shape check already lives in matmul_nbits_helper::CheckInputs,
but it is invoked only from Compute() -- after PrePack has already done the
OOB read. By then the original B tensor has been replaced with nullptr in the
kernel context, so the Compute-time check never re-validates it.

Fix: at the top of PrePack, after the existing early-return guards and before
any tensor.DataRaw() read, validate the incoming initializer's Shape() against
the attribute-derived shape:

- B           -> (N, k_blocks, blob_size)
- scales      -> (N * k_blocks) or (N, k_blocks)
- zero_points -> uint8: (N * zp_blob) or (N, zp_blob);
                 else (N * k_blocks) or (N, k_blocks)

A mismatch returns an error Status from PrePack (the checks use
ORT_RETURN_IF_NOT, which yields a FAIL status, not INVALID_ARGUMENT) so the
session fails to load rather than reading past the buffer.
--- .../cpu/quantization/matmul_nbits.cc | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 162d7257d0a4c..67f88bdf4605b 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -254,6 +254,49 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All return Status::OK(); } + // Validate the incoming initializer's shape against the attribute-derived shape before any of + // the pack routines below dereference tensor.DataRaw(). The MLAS pack routines size their reads + // from the (N, K, bits, block_size) attributes; without this check a crafted model whose + // attributes overstate the real tensor extents would trigger a heap-buffer-overflow READ at + // session initialization. The matching guard in matmul_nbits_helper::CheckInputs is invoked + // from Compute() -- too late, because PrePack has already done the OOB read, and by then the + // original B tensor is passed as nullptr so the Compute-time check never sees it. + { + const int64_t n = static_cast(N_); + const int64_t k = static_cast(K_); + const int64_t bs = static_cast(block_size_); + const int64_t bits = static_cast(nbits_); + const int64_t k_blocks = (k + bs - 1) / bs; + const int64_t blob_size = bs * bits / 8; + const TensorShape& shape = tensor.Shape(); + + if (input_idx == InputIndex::B) { + ORT_RETURN_IF_NOT(shape == TensorShape({n, k_blocks, blob_size}), + "MatMulNBits PrePack: B initializer shape ", shape, + " does not match attribute-derived shape [", n, ",", k_blocks, ",", blob_size, + "] (N=", N_, ", K=", K_, ", bits=", nbits_, ", block_size=", block_size_, ")"); + } else if (input_idx == InputIndex::scales) { + // scales may be 1D [n * k_blocks] or 2D [n, k_blocks] for backward compatibility. 
+ ORT_RETURN_IF_NOT(shape == TensorShape({n * k_blocks}) || shape == TensorShape({n, k_blocks}), + "MatMulNBits PrePack: scales initializer shape ", shape, + " does not match attribute-derived shape [", n * k_blocks, "] or [", + n, ",", k_blocks, "]"); + } else if (input_idx == InputIndex::zero_points) { + if (tensor.GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_UINT8) { + const int64_t zp_blob_size = (k_blocks * bits + 7) / 8; + ORT_RETURN_IF_NOT(shape == TensorShape({n * zp_blob_size}) || shape == TensorShape({n, zp_blob_size}), + "MatMulNBits PrePack: zero_points initializer shape ", shape, + " does not match attribute-derived shape [", n * zp_blob_size, "] or [", + n, ",", zp_blob_size, "]"); + } else { + ORT_RETURN_IF_NOT(shape == TensorShape({n * k_blocks}) || shape == TensorShape({n, k_blocks}), + "MatMulNBits PrePack: zero_points initializer shape ", shape, + " does not match attribute-derived shape [", n * k_blocks, "] or [", + n, ",", k_blocks, "]"); + } + } + } + // Create a temporary threadpool for parallel packing // This is used during model load time to speed up weight prepacking std::unique_ptr temp_threadpool; From 495fbd2a2e2d9b5d6b696b26520e19f61528940d Mon Sep 17 00:00:00 2001 From: Akshay Sonawane Date: Tue, 30 Jun 2026 11:06:45 -0700 Subject: [PATCH 2/4] add unit tests --- .../test/contrib_ops/matmul_4bits_test.cc | 211 ++++++++++++++++++ 1 file changed, 211 insertions(+) diff --git a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc index bedf035d320f8..88ba303187d1e 100644 --- a/onnxruntime/test/contrib_ops/matmul_4bits_test.cc +++ b/onnxruntime/test/contrib_ops/matmul_4bits_test.cc @@ -1269,6 +1269,217 @@ TEST(MatMulNBits, UnsupportedBlockSize_512) { {}, nullptr, &execution_providers); } +// The following tests cover the shape-validation guard added at the top of +// MatMulNBits::PrePack. 
The guard rejects initializer shapes that do not +// match the attribute-derived shape so that a crafted model whose (N, K, bits, +// block_size) attributes overstate the real tensor extents cannot trigger an +// out-of-bounds READ inside the MLAS pack routines during session +// initialization. Each test passes a B/scales/zero_points initializer whose +// declared shape (and matching data buffer size) is inconsistent with the +// attribute-derived shape, and expects session creation to fail with +// "MatMulNBits PrePack:" (i.e. before Compute() is ever invoked). + +// B shape mismatches the (N, k_blocks, blob_size) shape derived from attributes. +TEST(MatMulNBits, PrePack_InvalidBShape_RejectsAtSessionInit) { + constexpr int64_t M = 1, N = 4, K = 32, block_size = 32; + constexpr int64_t k_blocks = (K + block_size - 1) / block_size; // 1 + constexpr int64_t blob_size = block_size * QBits / 8; // 16 + + OpTester test("MatMulNBits", 1, kMSDomain); + test.AddAttribute("K", K); + test.AddAttribute("N", N); + test.AddAttribute("block_size", block_size); + test.AddAttribute("bits", QBits); + test.AddAttribute("accuracy_level", int64_t{0}); + + std::vector a_data(M * K, 1.0f); + test.AddInput("A", {M, K}, a_data, false); + + // Declare B with one fewer row than attributes claim. The data buffer matches + // the smaller declared shape, exactly mirroring the crafted-model scenario in + // which the attributes overstate the tensor's real extent. 
+ constexpr int64_t bad_N = N - 1; + std::vector b_data(bad_N * k_blocks * blob_size, 0); + test.AddInput("B", {bad_N, k_blocks, blob_size}, b_data, true); + + std::vector scales(N * k_blocks, 1.0f); + test.AddInput("scales", {N, k_blocks}, scales, true); + + std::vector y_data(M * N, 0.0f); + test.AddOutput("Y", {M, N}, y_data); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectFailure, + "MatMulNBits PrePack: B initializer shape", + {}, nullptr, &execution_providers); +} + +// B shape has the wrong rank (2D instead of (N, k_blocks, blob_size)). +TEST(MatMulNBits, PrePack_InvalidBRank_RejectsAtSessionInit) { + constexpr int64_t M = 1, N = 4, K = 32, block_size = 32; + constexpr int64_t k_blocks = (K + block_size - 1) / block_size; + constexpr int64_t blob_size = block_size * QBits / 8; + + OpTester test("MatMulNBits", 1, kMSDomain); + test.AddAttribute("K", K); + test.AddAttribute("N", N); + test.AddAttribute("block_size", block_size); + test.AddAttribute("bits", QBits); + test.AddAttribute("accuracy_level", int64_t{0}); + + std::vector a_data(M * K, 1.0f); + test.AddInput("A", {M, K}, a_data, false); + + // Flatten the trailing k_blocks/blob_size dims into a single dimension. + // The total element count still matches, but the rank differs from the + // attribute-derived (N, k_blocks, blob_size) shape. 
+ std::vector b_data(N * k_blocks * blob_size, 0); + test.AddInput("B", {N, k_blocks * blob_size}, b_data, true); + + std::vector scales(N * k_blocks, 1.0f); + test.AddInput("scales", {N, k_blocks}, scales, true); + + std::vector y_data(M * N, 0.0f); + test.AddOutput("Y", {M, N}, y_data); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectFailure, + "MatMulNBits PrePack: B initializer shape", + {}, nullptr, &execution_providers); +} + +// scales shape does not match either of the accepted layouts +// ([N * k_blocks] or [N, k_blocks]). +TEST(MatMulNBits, PrePack_InvalidScalesShape_RejectsAtSessionInit) { + constexpr int64_t M = 1, N = 4, K = 32, block_size = 32; + constexpr int64_t k_blocks = (K + block_size - 1) / block_size; + constexpr int64_t blob_size = block_size * QBits / 8; + + OpTester test("MatMulNBits", 1, kMSDomain); + test.AddAttribute("K", K); + test.AddAttribute("N", N); + test.AddAttribute("block_size", block_size); + test.AddAttribute("bits", QBits); + test.AddAttribute("accuracy_level", int64_t{0}); + + std::vector a_data(M * K, 1.0f); + test.AddInput("A", {M, K}, a_data, false); + + std::vector b_data(N * k_blocks * blob_size, 0); + test.AddInput("B", {N, k_blocks, blob_size}, b_data, true); + + // Declare scales with one fewer row than the attribute-derived layout. + constexpr int64_t bad_N = N - 1; + std::vector scales(bad_N * k_blocks, 1.0f); + test.AddInput("scales", {bad_N, k_blocks}, scales, true); + + std::vector y_data(M * N, 0.0f); + test.AddOutput("Y", {M, N}, y_data); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectFailure, + "MatMulNBits PrePack: scales initializer shape", + {}, nullptr, &execution_providers); +} + +// uint8 (packed) zero_points shape does not match the +// [N * zp_blob_size] / [N, zp_blob_size] layout derived from attributes. 
+TEST(MatMulNBits, PrePack_InvalidUInt8ZeroPointsShape_RejectsAtSessionInit) { + constexpr int64_t M = 1, N = 4, K = 32, block_size = 32; + constexpr int64_t k_blocks = (K + block_size - 1) / block_size; + constexpr int64_t blob_size = block_size * QBits / 8; + constexpr int64_t zp_blob_size = (k_blocks * QBits + 7) / 8; + + OpTester test("MatMulNBits", 1, kMSDomain); + test.AddAttribute("K", K); + test.AddAttribute("N", N); + test.AddAttribute("block_size", block_size); + test.AddAttribute("bits", QBits); + test.AddAttribute("accuracy_level", int64_t{0}); + + std::vector a_data(M * K, 1.0f); + test.AddInput("A", {M, K}, a_data, false); + + std::vector b_data(N * k_blocks * blob_size, 0); + test.AddInput("B", {N, k_blocks, blob_size}, b_data, true); + + std::vector scales(N * k_blocks, 1.0f); + test.AddInput("scales", {N, k_blocks}, scales, true); + + // Declare uint8 zero_points with one fewer row than the attribute-derived + // layout. zp_blob_size==1 here, so this is also distinguishable from any + // legacy 1D layout that would otherwise be accepted. + constexpr int64_t bad_N = N - 1; + std::vector zp(bad_N * zp_blob_size, 0); + test.AddInput("zero_points", {bad_N, zp_blob_size}, zp, true); + + std::vector y_data(M * N, 0.0f); + test.AddOutput("Y", {M, N}, y_data); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectFailure, + "MatMulNBits PrePack: zero_points initializer shape", + {}, nullptr, &execution_providers); +} + +// Sanity check: the legacy 1D layouts for scales and uint8 zero_points are +// still accepted by the new shape validation guard (i.e. the guard only +// rejects truly mismatched shapes and does not regress backward +// compatibility for existing models). 
+TEST(MatMulNBits, PrePack_LegacyFlattenedShapes_Accepted) { + constexpr int64_t M = 1, N = 4, K = 32, block_size = 32; + constexpr int64_t k_blocks = (K + block_size - 1) / block_size; + constexpr int64_t blob_size = block_size * QBits / 8; + constexpr int64_t zp_blob_size = (k_blocks * QBits + 7) / 8; + + RandomValueGenerator random{1234}; + std::vector a_vals(random.Gaussian(AsSpan({M, K}), 0.0f, 0.25f)); + std::vector b_f_vals(random.Gaussian(AsSpan({K, N}), 0.0f, 0.25f)); + + std::vector b_data(N * k_blocks * blob_size); + std::vector scales(N * k_blocks); + std::vector zp(N * zp_blob_size); + QuantizeDequantize(b_f_vals, b_data, scales, &zp, + static_cast(N), static_cast(K), + static_cast(block_size)); + + std::vector expected(M * N); + for (int64_t m = 0; m < M; ++m) { + for (int64_t n = 0; n < N; ++n) { + float sum = 0.0f; + for (int64_t k = 0; k < K; ++k) { + sum += a_vals[m * K + k] * b_f_vals[n * K + k]; + } + expected[m * N + n] = sum; + } + } + + OpTester test("MatMulNBits", 1, kMSDomain); + test.AddAttribute("K", K); + test.AddAttribute("N", N); + test.AddAttribute("block_size", block_size); + test.AddAttribute("bits", QBits); + test.AddAttribute("accuracy_level", int64_t{0}); + + test.AddInput("A", {M, K}, a_vals, false); + test.AddInput("B", {N, k_blocks, blob_size}, b_data, true); + // Legacy flattened 1D layouts for scales and zero_points. 
+ test.AddInput("scales", {N * k_blocks}, scales, true); + test.AddInput("zero_points", {N * zp_blob_size}, zp, true); + + test.AddOutput("Y", {M, N}, expected); + test.SetOutputAbsErr("Y", 0.1f); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", + {}, nullptr, &execution_providers); +} + } // namespace test } // namespace onnxruntime From bcba24f720e5908f978e793bed063552bdae77e4 Mon Sep 17 00:00:00 2001 From: Akshay Sonawane Date: Tue, 30 Jun 2026 11:58:27 -0700 Subject: [PATCH 3/4] Address comments --- .../cpu/quantization/matmul_nbits.cc | 67 ++++++++++++++----- 1 file changed, 50 insertions(+), 17 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 67f88bdf4605b..5ebd78771c9a9 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -261,6 +261,16 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All // session initialization. The matching guard in matmul_nbits_helper::CheckInputs is invoked // from Compute() -- too late, because PrePack has already done the OOB read, and by then the // original B tensor is passed as nullptr so the Compute-time check never sees it. + // + // When input_idx == B, this guard also validates the constant scales and zero_points + // initializers (looked up via TryGetConstantInput). SessionState::PrepackConstantInitializedTensors + // iterates inputs in index order, so the B PrePack call runs before scales/zero_points are + // prepacked on their own. The B prepack path reads those constant tensors and passes their + // raw data to the MLAS pack routines (MlasLutGemmPack, MlasQNBitGemmPackQuantBData), which size + // their reads from the same (N, K, bits, block_size) attributes. 
Without validating scales / + // zero_points here, a crafted model with an undersized scales or zero_points buffer would still + // trigger an OOB read inside the B packing pass before each tensor's own PrePack call could + // catch the mismatch. { const int64_t n = static_cast(N_); const int64_t k = static_cast(K_); @@ -268,6 +278,32 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All const int64_t bits = static_cast(nbits_); const int64_t k_blocks = (k + bs - 1) / bs; const int64_t blob_size = bs * bits / 8; + const int64_t zp_blob_size_uint8 = (k_blocks * bits + 7) / 8; + + auto validate_scales_shape = [&](const TensorShape& s) -> Status { + // scales may be 1D [n * k_blocks] or 2D [n, k_blocks] for backward compatibility. + ORT_RETURN_IF_NOT(s == TensorShape({n * k_blocks}) || s == TensorShape({n, k_blocks}), + "MatMulNBits PrePack: scales initializer shape ", s, + " does not match attribute-derived shape [", n * k_blocks, "] or [", + n, ",", k_blocks, "]"); + return Status::OK(); + }; + + auto validate_zero_points_shape = [&](const TensorShape& s, int32_t element_type) -> Status { + if (element_type == ONNX_NAMESPACE::TensorProto_DataType_UINT8) { + ORT_RETURN_IF_NOT(s == TensorShape({n * zp_blob_size_uint8}) || s == TensorShape({n, zp_blob_size_uint8}), + "MatMulNBits PrePack: zero_points initializer shape ", s, + " does not match attribute-derived shape [", n * zp_blob_size_uint8, "] or [", + n, ",", zp_blob_size_uint8, "]"); + } else { + ORT_RETURN_IF_NOT(s == TensorShape({n * k_blocks}) || s == TensorShape({n, k_blocks}), + "MatMulNBits PrePack: zero_points initializer shape ", s, + " does not match attribute-derived shape [", n * k_blocks, "] or [", + n, ",", k_blocks, "]"); + } + return Status::OK(); + }; + const TensorShape& shape = tensor.Shape(); if (input_idx == InputIndex::B) { @@ -275,25 +311,22 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All "MatMulNBits PrePack: B initializer shape ", 
shape, " does not match attribute-derived shape [", n, ",", k_blocks, ",", blob_size, "] (N=", N_, ", K=", K_, ", bits=", nbits_, ", block_size=", block_size_, ")"); + + // Also validate constant scales / zero_points, which the B prepack path below dereferences + // (via TryGetConstantInput) and hands to MLAS, before their own PrePack calls run. + const Tensor* scales_tensor = nullptr; + if (OpKernel::Info().TryGetConstantInput(InputIndex::scales, &scales_tensor) && scales_tensor != nullptr) { + ORT_RETURN_IF_ERROR(validate_scales_shape(scales_tensor->Shape())); + } + const Tensor* zp_tensor = nullptr; + if (has_zp_arg_ && has_zp_input_ && + OpKernel::Info().TryGetConstantInput(InputIndex::zero_points, &zp_tensor) && zp_tensor != nullptr) { + ORT_RETURN_IF_ERROR(validate_zero_points_shape(zp_tensor->Shape(), zp_tensor->GetElementType())); + } } else if (input_idx == InputIndex::scales) { - // scales may be 1D [n * k_blocks] or 2D [n, k_blocks] for backward compatibility. - ORT_RETURN_IF_NOT(shape == TensorShape({n * k_blocks}) || shape == TensorShape({n, k_blocks}), - "MatMulNBits PrePack: scales initializer shape ", shape, - " does not match attribute-derived shape [", n * k_blocks, "] or [", - n, ",", k_blocks, "]"); + ORT_RETURN_IF_ERROR(validate_scales_shape(shape)); } else if (input_idx == InputIndex::zero_points) { - if (tensor.GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_UINT8) { - const int64_t zp_blob_size = (k_blocks * bits + 7) / 8; - ORT_RETURN_IF_NOT(shape == TensorShape({n * zp_blob_size}) || shape == TensorShape({n, zp_blob_size}), - "MatMulNBits PrePack: zero_points initializer shape ", shape, - " does not match attribute-derived shape [", n * zp_blob_size, "] or [", - n, ",", zp_blob_size, "]"); - } else { - ORT_RETURN_IF_NOT(shape == TensorShape({n * k_blocks}) || shape == TensorShape({n, k_blocks}), - "MatMulNBits PrePack: zero_points initializer shape ", shape, - " does not match attribute-derived shape [", n * k_blocks, "] or [", - 
n, ",", k_blocks, "]"); - } + ORT_RETURN_IF_ERROR(validate_zero_points_shape(shape, tensor.GetElementType())); } } From 43c1b16b8a174d2eb8a5f9b543e31e48c5b2f2ac Mon Sep 17 00:00:00 2001 From: Akshay Sonawane Date: Tue, 30 Jun 2026 14:49:55 -0700 Subject: [PATCH 4/4] Fix pipeline --- .../cpu/quantization/matmul_nbits.cc | 43 +++++++++++-------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc index 5ebd78771c9a9..e3e87a9181ed8 100644 --- a/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc +++ b/onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc @@ -236,23 +236,6 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All /*out*/ bool& is_packed, /*out*/ PrePackedWeights* prepacked_weights) { is_packed = false; - if (has_g_idx_) { - return Status::OK(); - } - if (has_unquantized_zero_point_ && !prefer_lut_gemm_) { - return Status::OK(); - } - - // LUT GEMM requires ZP to be a constant initializer for prepacking. If the node - // has a ZP input but it's dynamic, skip LUT packing and fall through to the - // unpacked dequant path at compute time (similar to KleidiAI's dynamic ZP fallback). - if (prefer_lut_gemm_ && has_zp_arg_ && !has_zp_input_) { - return Status::OK(); - } - - if (!MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_) && !prefer_lut_gemm_) { - return Status::OK(); - } // Validate the incoming initializer's shape against the attribute-derived shape before any of // the pack routines below dereference tensor.DataRaw(). The MLAS pack routines size their reads @@ -271,6 +254,14 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All // zero_points here, a crafted model with an undersized scales or zero_points buffer would still // trigger an OOB read inside the B packing pass before each tensor's own PrePack call could // catch the mismatch. 
+ // + // This validation runs before the early-return guards below (has_g_idx_, unquantized ZP, + // dynamic-ZP-with-LUT, !MlasIsQNBitGemmAvailable). On builds where MLAS QNBit GEMM is not + // available (e.g. Windows x86 32-bit) PrePack would otherwise short-circuit before reaching + // these checks, and the original B tensor is dropped after PrePack so Compute()'s helper-time + // check never sees it. Running the validation first makes session init reject bad-shape models + // consistently across all build configurations. The checks are cheap (a few TensorShape + // equality comparisons) and independent of any MLAS kernel availability. { const int64_t n = static_cast(N_); const int64_t k = static_cast(K_); @@ -330,6 +321,24 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ All } } + if (has_g_idx_) { + return Status::OK(); + } + if (has_unquantized_zero_point_ && !prefer_lut_gemm_) { + return Status::OK(); + } + + // LUT GEMM requires ZP to be a constant initializer for prepacking. If the node + // has a ZP input but it's dynamic, skip LUT packing and fall through to the + // unpacked dequant path at compute time (similar to KleidiAI's dynamic ZP fallback). + if (prefer_lut_gemm_ && has_zp_arg_ && !has_zp_input_) { + return Status::OK(); + } + + if (!MlasIsQNBitGemmAvailable(nbits_, block_size_, compute_type_) && !prefer_lut_gemm_) { + return Status::OK(); + } + // Create a temporary threadpool for parallel packing // This is used during model load time to speed up weight prepacking std::unique_ptr temp_threadpool;