Fix signed-int overflow in SamplingState::Init to prevent heap-buffer-overflow#29443
Fix signed-int overflow in SamplingState::Init to prevent heap-buffer-overflow#29443apsonawane wants to merge 5 commits into
Conversation
…erflow SamplingState<T>::Init computed int total_count = batch_size * vocab_size as a bare int*int multiply with model-controlled operands, then wrapped the already-overflowed result in SafeInt<size_t>. SafeInt rejected the negative-wrap case but silently accepted positive-wrap (e.g. 4 * 0x40000001 wraps to 4), under-sizing sorted_scores / cumulative_probs. The companion next_token_scores buffer sizes the same product correctly via SafeInt<size_t>(batch_size) * vocab_size, so the later memcpy in SamplingCpuHelper::Sample copies the large size into the small buffer -- a heap-buffer-overflow WRITE triggerable by a hostile .onnx model with a com.microsoft::Sampling node. Fix: compute the product in SafeInt's checked domain by casting an operand first, matching the pattern already used for next_token_scores. Apply the same operand-first pattern to the batch_size * max_iter site and to SafeInt<size_t>(batch_size + 1) (which itself could wrap in int).
There was a problem hiding this comment.
Pull request overview
This PR hardens SamplingState::Init in the generation/transformers greedy-search implementation by moving buffer element-count computations into SafeInt<size_t> so integer overflow can’t lead to under-allocation and downstream memory errors.
Changes:
- Compute
batch_size * vocab_sizeusingSafeInt<size_t>to prevent overflow before buffer allocation. - Reuse the checked
total_countacross CPU/CUDA allocations inSamplingState.
tianleiwu
left a comment
There was a problem hiding this comment.
The production change in SamplingState::Init is correct and a genuine hardening win. int total_count = batch_size * vocab_size; is replaced with a checked SafeInt<size_t> product before any value can positive-wrap, and the two remaining in-int sub-expressions (batch_size + 1, batch_size * max_iter) are fixed the same way. Since batch_size (BatchBeamSize()) and vocab_size are model-controlled ints, this is a reachable defensive fix.
One concern on the test file (inline): the added tests re-implement the SafeInt expression inside the test itself and never call the production code, so they cannot fail if someone reverts SamplingState::Init to int * int — which contradicts the test's own comment. They only prove that SafeInt<size_t> throws on negative operands, a property of SafeInt rather than of this PR.
Minor (optional): safeint.h already exposes SafeMul<size_t>(batch_size, vocab_size) which does exactly this multiply; not required since the explicit form matches the existing next_token_scores pattern.
Verdict: COMMENT — production fix looks good; please make the regression test actually exercise the production path (or correct the misleading comment).
| // const SafeInt<size_t> total_count = | ||
| // SafeInt<size_t>(batch_size) * SafeInt<size_t>(vocab_size); | ||
| size_t ComputeSamplingTotalCount(int batch_size, int vocab_size) { | ||
| return SafeInt<size_t>(batch_size) * SafeInt<size_t>(vocab_size); |
There was a problem hiding this comment.
This test is tautological: ComputeSamplingTotalCount is a copy of the SafeInt expression defined here in the test, and it never includes or calls greedy_search_impl_base.h. Reverting the production code to int * int or static_cast<size_t>(...) would leave this test green, which directly contradicts the header comment above ("They will fail if anyone reverts the production code..."). As written, these tests only verify that SafeInt<size_t> multiplication throws on negative operands — a property of SafeInt, not of this change.
Suggestion: extract the size computation into a small function in the production header (e.g. inline SafeInt<size_t> SamplingBufferElementCount(int batch_size, int vocab_size)) and have both Init and this test call it, so a revert genuinely breaks the test. If you keep the mirrored expression, please at least fix the comment so it no longer claims to catch a production revert.
This pull request improves the safety of buffer size calculations in the
SamplingStateinitialization logic by ensuring that all multiplications involvingbatch_sizeandvocab_sizeare safely performed usingSafeInt<size_t>. This prevents potential integer overflow bugs that could lead to under-allocated buffers and memory errors.Buffer allocation safety improvements:
batch_sizeandvocab_sizenow useSafeInt<size_t>to ensure checked arithmetic, preventing silent integer overflows that could cause heap-buffer-overflow issues. This includes allocations for both CPU and CUDA buffers inSamplingState. [1] [2]h_sampled_allnow also safely castsmax_itertosize_tbefore multiplication, further protecting against overflow.These changes make the code more robust and secure, especially when handling large or model-controlled input sizes.