From d1e07be4349a3c098b10788431b13235d6706059 Mon Sep 17 00:00:00 2001 From: Akshay Sonawane Date: Tue, 30 Jun 2026 10:44:43 -0700 Subject: [PATCH 1/3] Guard MlasConvPrepare working-buffer products with SafeInt MlasConvPrepare computed the im2col working-buffer size as raw size_t multiplies of attacker-controlled tensor shape factors (OutputSize * K, and the thread-partition product TargetThreadCount * MLAS_CONV_WORKING_BUFFER_SIZE_PER_THREAD). The MSVC C26451 arithmetic-overflow lint that flags exactly this class was suppressed by #pragma warning(disable : 26451) rather than the arithmetic being guarded. On wrap, the CPU EP allocates a small working buffer and MlasConvIm2Col then writes the full unwrapped count of floats past it -- heap-buffer-overflow WRITE. Per-tensor element counts are already SafeInt-checked, but the product across tensors was not. Fix: accumulate dim products via SafeInt and compute the final working-buffer size with SafeInt-guarded multiplies. SafeInt throws on overflow; the caller propagates as Status failure. Also guards the adjacent BatchCount * GroupCount product. Removes the C26451 suppression block. Adds #include "core/common/safeint.h" (already used elsewhere in core/mlas/lib). --- onnxruntime/core/mlas/lib/convolve.cpp | 46 +++++++++++++++----------- 1 file changed, 27 insertions(+), 19 deletions(-) diff --git a/onnxruntime/core/mlas/lib/convolve.cpp b/onnxruntime/core/mlas/lib/convolve.cpp index 9bff72b29d8fb..04d7d570aa8c7 100644 --- a/onnxruntime/core/mlas/lib/convolve.cpp +++ b/onnxruntime/core/mlas/lib/convolve.cpp @@ -15,6 +15,7 @@ Module Name: --*/ #include "mlasi.h" +#include "core/common/safeint.h" // // Define the number of working buffer elements required per thread. @@ -1328,11 +1329,6 @@ Return Value: } } } -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(push) -// Chance of arithmetic overflow could be reduced -#pragma warning(disable : 26451) -#endif namespace { @@ -1560,9 +1556,14 @@ Return Value: Parameters->FilterCount = FilterCount; Parameters->Beta = Beta; - size_t InputSize = 1; - size_t OutputSize = 1; - size_t K = InputChannels; + // Accumulate dimension products in SafeInt's checked domain. The raw size_t multiplies + // here are reachable from attacker-controlled tensor shapes, and the cross-tensor product + // is not otherwise guarded (per-tensor element counts are SafeInt-checked elsewhere, but + // the product across tensors is not). On overflow SafeInt throws and the caller propagates + // it as a Status failure. + SafeInt SafeInputSize = 1; + SafeInt SafeOutputSize = 1; + SafeInt SafeK = SafeInt(InputChannels); bool AllStridesAreOne = true; bool AllDilationsAreOne = true; @@ -1578,15 +1579,19 @@ Return Value: Parameters->Padding[dim + Dimensions] = size_t(Padding[dim + Dimensions]); Parameters->StrideShape[dim] = size_t(StrideShape[dim]); - InputSize *= Parameters->InputShape[dim]; - OutputSize *= Parameters->OutputShape[dim]; - K *= Parameters->KernelShape[dim]; + SafeInputSize *= Parameters->InputShape[dim]; + SafeOutputSize *= Parameters->OutputShape[dim]; + SafeK *= Parameters->KernelShape[dim]; AllStridesAreOne &= (Parameters->StrideShape[dim] == 1); AllDilationsAreOne &= (Parameters->DilationShape[dim] == 1); AllPaddingIsZero &= (Parameters->Padding[dim] == 0 && Parameters->Padding[dim + Dimensions] == 0); } + const size_t InputSize = SafeInputSize; + const size_t OutputSize = SafeOutputSize; + const size_t K = SafeK; + Parameters->InputSize = InputSize; Parameters->OutputSize = OutputSize; Parameters->K = K; @@ -1675,7 +1680,10 @@ Return Value: Parameters->Algorithm = MlasConvAlgorithmExpandThenGemm; - *WorkingBufferSize = OutputSize * K; + // SafeInt guards against wrap of the cross-tensor product; the raw size_t multiply + // is reachable from attacker-controlled shapes and would otherwise under-size the + // im2col working buffer (heap-buffer-overflow on the downstream MlasConvIm2Col write). + *WorkingBufferSize = SafeInt(OutputSize) * K; } else { @@ -1774,18 +1782,18 @@ Return Value: Parameters->Algorithm = MlasConvAlgorithmExpandThenGemmSegmented; Parameters->u.ExpandThenGemmSegmented.ThreadStrideN = StrideN; - *WorkingBufferSize = TargetThreadCount * MLAS_CONV_WORKING_BUFFER_SIZE_PER_THREAD; + // SafeInt-guarded products: TargetThreadCount and the BatchCount*GroupCount product + // are both reachable from attacker-controlled inputs. + *WorkingBufferSize = SafeInt(TargetThreadCount) * MLAS_CONV_WORKING_BUFFER_SIZE_PER_THREAD; if (Parameters->BatchCount > 1 || Parameters->GroupCount > 1) { TargetThreadCount = MaximumThreadCount; - if (static_cast(TargetThreadCount) >= Parameters->BatchCount * Parameters->GroupCount) { - TargetThreadCount = static_cast(Parameters->BatchCount * Parameters->GroupCount); + const size_t BatchGroupProduct = SafeInt(Parameters->BatchCount) * Parameters->GroupCount; + if (static_cast(TargetThreadCount) >= BatchGroupProduct) { + TargetThreadCount = static_cast(BatchGroupProduct); } - *WorkingBufferSize = TargetThreadCount * MLAS_CONV_WORKING_BUFFER_SIZE_PER_THREAD; + *WorkingBufferSize = SafeInt(TargetThreadCount) * MLAS_CONV_WORKING_BUFFER_SIZE_PER_THREAD; } } } -#if defined(_MSC_VER) && !defined(__clang__) -#pragma warning(pop) -#endif From 5f4b82e97d9c6ebb54ea1cf6c5a1be7c37a9e70f Mon Sep 17 00:00:00 2001 From: Akshay Sonawane Date: Tue, 30 Jun 2026 11:16:56 -0700 Subject: [PATCH 2/3] Add unit tests --- .../unittest/test_conv_prepare_safeint.cpp | 195 ++++++++++++++++++ 1 file changed, 195 insertions(+) create mode 100644 onnxruntime/test/mlas/unittest/test_conv_prepare_safeint.cpp diff --git a/onnxruntime/test/mlas/unittest/test_conv_prepare_safeint.cpp b/onnxruntime/test/mlas/unittest/test_conv_prepare_safeint.cpp new file mode 100644 index 0000000000000..07224522a9285 --- /dev/null +++ b/onnxruntime/test/mlas/unittest/test_conv_prepare_safeint.cpp @@ -0,0 +1,195 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// Tests for the SafeInt-guarded working-buffer/size products in +// MlasConvPrepare (see PR #29444). These cases construct attacker-controlled +// shape/count combinations whose products overflow size_t and verify that +// MlasConvPrepare throws onnxruntime::OnnxRuntimeException (raised by +// SafeInt's overflow handler) rather than silently producing an under-sized +// working buffer or wrapped tensor size. + +#include "gtest/gtest.h" + +#include "core/common/exceptions.h" +#include "mlas.h" + +#include +#include +#include + +#ifndef ORT_NO_EXCEPTIONS + +namespace { + +// A value whose square overflows size_t on every supported pointer width +// (64-bit: 2^33, 32-bit: 2^17). Each individual value still fits in size_t +// and in int64_t shape entries, so the loop accumulators reach the SafeInt +// multiplication before the per-tensor value itself is invalid. +constexpr size_t kHalfShift = static_cast(1) << ((sizeof(size_t) * 8 / 2) + 1); + +// Identity activation, shared across all tests because MlasConvPrepare +// dereferences but does not invoke it. +MLAS_ACTIVATION MakeIdentityActivation() { + MLAS_ACTIVATION activation{}; + activation.ActivationKind = MlasIdentityActivation; + return activation; +} + +// A baseline-valid 3D conv setup. Callers mutate one field per test to drive +// the specific SafeInt-guarded product into overflow. 3D is chosen so the +// KleidiAI MlasConvPrepareOverride (which only handles 2D) is bypassed on +// every platform. +struct ConvPrepareInputs { + size_t Dimensions = 3; + size_t BatchCount = 1; + size_t GroupCount = 1; + size_t InputChannels = 1; + size_t FilterCount = 1; + int64_t InputShape[3] = {1, 1, 1}; + int64_t KernelShape[3] = {1, 1, 1}; + int64_t DilationShape[3] = {1, 1, 1}; + // Stride[1] = 2 (rather than all-ones) keeps execution out of the pointwise + // "direct GEMM" early-return so the SafeInt-guarded code paths are reached. + int64_t StrideShape[3] = {1, 1, 2}; + int64_t Padding[6] = {0, 0, 0, 0, 0, 0}; + int64_t OutputShape[3] = {1, 1, 1}; + float Beta = 0.0f; + bool ChannelsLast = false; +}; + +void RunConvPrepare(const ConvPrepareInputs& in, size_t* working_buffer_size_out = nullptr) { + MLAS_CONV_PARAMETERS parameters{}; + MLAS_ACTIVATION activation = MakeIdentityActivation(); + size_t working_buffer_size = 0; + + MlasConvPrepare(¶meters, + in.Dimensions, + in.BatchCount, + in.GroupCount, + in.InputChannels, + in.InputShape, + in.KernelShape, + in.DilationShape, + in.Padding, + in.StrideShape, + in.OutputShape, + in.FilterCount, + &activation, + &working_buffer_size, + in.ChannelsLast, + in.Beta, + /*ThreadPool=*/nullptr); + + if (working_buffer_size_out != nullptr) { + *working_buffer_size_out = working_buffer_size; + } +} + +} // namespace + +// Sanity check: the SafeInt instrumentation must not regress the happy path. +TEST(MlasConvPrepareSafeIntTest, SmallShapeDoesNotThrow) { + ConvPrepareInputs in; + in.InputShape[0] = 4; + in.InputShape[1] = 4; + in.InputShape[2] = 4; + in.KernelShape[0] = 1; + in.KernelShape[1] = 1; + in.KernelShape[2] = 1; + in.OutputShape[0] = 2; + in.OutputShape[1] = 2; + in.OutputShape[2] = 2; + in.InputChannels = 2; + in.FilterCount = 2; + + size_t working_buffer_size = 0; + EXPECT_NO_THROW(RunConvPrepare(in, &working_buffer_size)); +} + +// SafeInputSize *= Parameters->InputShape[dim] must trip on overflow. +TEST(MlasConvPrepareSafeIntTest, InputSizeProductOverflowsThrows) { + ConvPrepareInputs in; + in.InputShape[0] = static_cast(kHalfShift); + in.InputShape[1] = static_cast(kHalfShift); + in.InputShape[2] = 1; + + EXPECT_THROW(RunConvPrepare(in), onnxruntime::OnnxRuntimeException); +} + +// SafeOutputSize *= Parameters->OutputShape[dim] must trip on overflow. +TEST(MlasConvPrepareSafeIntTest, OutputSizeProductOverflowsThrows) { + ConvPrepareInputs in; + in.OutputShape[0] = static_cast(kHalfShift); + in.OutputShape[1] = static_cast(kHalfShift); + in.OutputShape[2] = 1; + + EXPECT_THROW(RunConvPrepare(in), onnxruntime::OnnxRuntimeException); +} + +// SafeK is seeded with InputChannels and then folded against the kernel shape; +// growing the kernel dimensions until the running product overflows must throw. +TEST(MlasConvPrepareSafeIntTest, KernelProductOverflowsThrows) { + ConvPrepareInputs in; + in.InputChannels = kHalfShift; + in.KernelShape[0] = static_cast(kHalfShift); + in.KernelShape[1] = 1; + in.KernelShape[2] = 1; + + EXPECT_THROW(RunConvPrepare(in), onnxruntime::OnnxRuntimeException); +} + +// In the ExpandThenGemm path *WorkingBufferSize = SafeInt(OutputSize) * K. +// Individual values fit, but the cross-tensor product overflows and must throw +// rather than under-sizing the im2col buffer. +TEST(MlasConvPrepareSafeIntTest, ExpandThenGemmWorkingBufferOverflowsThrows) { + ConvPrepareInputs in; + // OutputSize = kHalfShift (only one non-unit output dim so the running + // SafeOutputSize accumulation stays in range). + in.OutputShape[0] = static_cast(kHalfShift); + in.OutputShape[1] = 1; + in.OutputShape[2] = 1; + // K = InputChannels * prod(KernelShape) = kHalfShift, again in range. + in.InputChannels = kHalfShift; + in.KernelShape[0] = 1; + in.KernelShape[1] = 1; + in.KernelShape[2] = 1; + // FilterCount > OutputSize selects MlasConvAlgorithmExpandThenGemm. + in.FilterCount = kHalfShift + 1; + // Non-trivial stride keeps us out of the pointwise GemmDirect early return + // even though AllPaddingIsZero remains true. + in.StrideShape[0] = 1; + in.StrideShape[1] = 1; + in.StrideShape[2] = 2; + + EXPECT_THROW(RunConvPrepare(in), onnxruntime::OnnxRuntimeException); +} + +// The MlasConvAlgorithmExpandThenGemmSegmented path multiplies BatchCount and +// GroupCount inside SafeInt before clamping TargetThreadCount; that product +// must throw on overflow rather than wrapping silently. +TEST(MlasConvPrepareSafeIntTest, BatchGroupProductOverflowsThrows) { + ConvPrepareInputs in; + in.BatchCount = kHalfShift; + in.GroupCount = kHalfShift; + // Keep every per-tensor accumulator small so the only SafeInt product that + // can fail is the BatchCount * GroupCount one inside MlasConvPrepare. + in.InputChannels = 1; + in.FilterCount = 1; // FilterCount <= OutputSize -> reaches the segmented branch. + in.InputShape[0] = 1; + in.InputShape[1] = 1; + in.InputShape[2] = 1; + in.KernelShape[0] = 1; + in.KernelShape[1] = 1; + in.KernelShape[2] = 1; + in.OutputShape[0] = 1; + in.OutputShape[1] = 1; + in.OutputShape[2] = 1; + // Non-trivial stride keeps us out of the pointwise GemmDirect early return. + in.StrideShape[0] = 1; + in.StrideShape[1] = 1; + in.StrideShape[2] = 2; + + EXPECT_THROW(RunConvPrepare(in), onnxruntime::OnnxRuntimeException); +} + +#endif // ORT_NO_EXCEPTIONS From 7720b9495d5be1460fdbafd9ee66c5751c4c1f23 Mon Sep 17 00:00:00 2001 From: Akshay Sonawane Date: Tue, 30 Jun 2026 12:12:19 -0700 Subject: [PATCH 3/3] Address review comments: gate SafeInt include and tests on BUILD_MLAS_NO_ONNXRUNTIME --- onnxruntime/core/mlas/lib/convolve.cpp | 7 ++++++ .../unittest/test_conv_prepare_safeint.cpp | 25 +++++++++++-------- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/onnxruntime/core/mlas/lib/convolve.cpp b/onnxruntime/core/mlas/lib/convolve.cpp index 04d7d570aa8c7..5d33d7dd7b734 100644 --- a/onnxruntime/core/mlas/lib/convolve.cpp +++ b/onnxruntime/core/mlas/lib/convolve.cpp @@ -15,7 +15,14 @@ Module Name: --*/ #include "mlasi.h" +#if defined(BUILD_MLAS_NO_ONNXRUNTIME) +// Standalone MLAS builds don't have access to the ORT-internal SafeInt +// wrapper; fall back to the SafeInt.hpp header directly (its default +// exception handler still throws on overflow). +#include "SafeInt.hpp" +#else #include "core/common/safeint.h" +#endif // // Define the number of working buffer elements required per thread. diff --git a/onnxruntime/test/mlas/unittest/test_conv_prepare_safeint.cpp b/onnxruntime/test/mlas/unittest/test_conv_prepare_safeint.cpp index 07224522a9285..f0555704caffe 100644 --- a/onnxruntime/test/mlas/unittest/test_conv_prepare_safeint.cpp +++ b/onnxruntime/test/mlas/unittest/test_conv_prepare_safeint.cpp @@ -1,23 +1,26 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -// Tests for the SafeInt-guarded working-buffer/size products in -// MlasConvPrepare (see PR #29444). These cases construct attacker-controlled -// shape/count combinations whose products overflow size_t and verify that -// MlasConvPrepare throws onnxruntime::OnnxRuntimeException (raised by -// SafeInt's overflow handler) rather than silently producing an under-sized -// working buffer or wrapped tensor size. +// Tests for the SafeInt-guarded working-buffer/size products in MlasConvPrepare. +// These cases construct attacker-controlled shape/count combinations whose +// products overflow size_t and verify that MlasConvPrepare throws +// onnxruntime::OnnxRuntimeException (raised by SafeInt's overflow handler) +// rather than silently producing an under-sized working buffer or wrapped +// tensor size. #include "gtest/gtest.h" -#include "core/common/exceptions.h" #include "mlas.h" +// These tests rely on the ORT-internal SafeInt overflow handler +// (onnxruntime::OnnxRuntimeException) and on ORT exception support, so they +// are skipped in standalone MLAS builds and in no-exception configurations. +#if !defined(ORT_NO_EXCEPTIONS) && !defined(BUILD_MLAS_NO_ONNXRUNTIME) + +#include "core/common/exceptions.h" + #include #include -#include - -#ifndef ORT_NO_EXCEPTIONS namespace { @@ -192,4 +195,4 @@ TEST(MlasConvPrepareSafeIntTest, BatchGroupProductOverflowsThrows) { EXPECT_THROW(RunConvPrepare(in), onnxruntime::OnnxRuntimeException); } -#endif // ORT_NO_EXCEPTIONS +#endif // !defined(ORT_NO_EXCEPTIONS) && !defined(BUILD_MLAS_NO_ONNXRUNTIME)