Guard MlasConvPrepare working-buffer products with SafeInt#29444
Open
apsonawane wants to merge 3 commits into
Open
Guard MlasConvPrepare working-buffer products with SafeInt#29444apsonawane wants to merge 3 commits into
apsonawane wants to merge 3 commits into
Conversation
MlasConvPrepare computed the im2col working-buffer size as raw size_t multiplies of attacker-controlled tensor shape factors (OutputSize * K, and the thread-partition product TargetThreadCount * MLAS_CONV_WORKING_BUFFER_SIZE_PER_THREAD). The MSVC C26451 arithmetic-overflow lint that flags exactly this class was suppressed by #pragma warning(disable : 26451) rather than the arithmetic being guarded. On wrap, the CPU EP allocates a small working buffer and MlasConvIm2Col then writes the full unwrapped count of floats past it -- heap-buffer-overflow WRITE. Per-tensor element counts are already SafeInt-checked, but the product across tensors was not. Fix: accumulate dim products via SafeInt<size_t> and compute the final working-buffer size with SafeInt-guarded multiplies. SafeInt throws on overflow; the caller propagates as Status failure. Also guards the adjacent BatchCount * GroupCount product. Removes the C26451 suppression block. Adds #include "core/common/safeint.h" (already used elsewhere in core/mlas/lib).
Contributor
There was a problem hiding this comment.
Pull request overview
This PR hardens MLAS convolution preparation (MlasConvPrepare) against size_t overflow by moving several shape-dependent products into SafeInt<size_t> computations, and adds unit tests that validate overflow cases throw rather than silently producing wrapped sizes.
Changes:
- Switched
InputSize,OutputSize, andKaccumulation inMlasConvPreparetoSafeInt<size_t>and materialized them assize_tonly after checked multiplication. - Guarded working-buffer element-count products (
OutputSize * K,TargetThreadCount * per_thread) usingSafeInt<size_t>. - Added a new unit test file that constructs overflow-inducing shapes and validates exception behavior.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
onnxruntime/core/mlas/lib/convolve.cpp |
Uses SafeInt<size_t> to validate shape-derived products and working-buffer size calculations in MlasConvPrepare. |
onnxruntime/test/mlas/unittest/test_conv_prepare_safeint.cpp |
Adds gtest coverage to ensure overflow cases throw (instead of wrapping) and that small “happy path” inputs still work. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This pull request strengthens the safety of arithmetic operations in the convolution implementation by using
SafeIntto prevent integer overflows, especially in cases where tensor shapes may be attacker-controlled. It also removes some compiler-specific warning pragmas that are no longer needed due to these changes.Enhanced integer safety in convolution calculations:
size_tmultiplications withSafeInt<size_t>for calculating input/output sizes and kernel dimensions, ensuring that any arithmetic overflow is caught and handled appropriately. This is particularly important for preventing security vulnerabilities from attacker-controlled tensor shapes. [1] [2]SafeInt<size_t>for all working buffer size calculations, including those involving thread counts and batch/group products, to prevent buffer overflows and ensure correct memory allocation. [1] [2]Code cleanup and maintenance:
SafeIntnow guards against arithmetic overflow, making these warnings obsolete. [1] [2]core/common/safeint.hinclude to provide access to theSafeIntfunctionality.