Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 19 additions & 12 deletions onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,32 @@ struct SamplingState : public ISamplingState<T> {
int seed,
bool is_cuda,
Stream* stream) {
int total_count = batch_size * vocab_size;
// Compute the product entirely in SafeInt's checked domain; an `int * int` multiply
// with model-controlled operands can silently positive-wrap before any SafeInt cast
// sees it, leading to under-allocated buffers (heap-buffer-overflow on the downstream
// memcpy in SamplingCpuHelper::Sample). Using SafeInt<size_t> on both operands also
// preserves sign/overflow checking for negative inputs (e.g. unvalidated vocab_size),
// which a static_cast<size_t> would silently turn into a huge value. Matches the
// pattern used for next_token_scores in GreedySearchState::Init below.
const SafeInt<size_t> total_count = SafeInt<size_t>(batch_size) * SafeInt<size_t>(vocab_size);

this->h_softmaxed_score = AllocateBuffer<float>(cpu_allocator, h_softmaxed_score_buffer_, SafeInt<size_t>(total_count), stream);
this->h_softmaxed_score = AllocateBuffer<float>(cpu_allocator, h_softmaxed_score_buffer_, total_count, stream);

this->generator = std::default_random_engine{gsl::narrow_cast<uint32_t>(seed)};

if (is_cuda) {
this->d_index_in = AllocateBuffer<int>(allocator, d_index_in_buffer_, SafeInt<size_t>(total_count), stream);
this->d_index_out = AllocateBuffer<int>(allocator, d_index_out_buffer_, SafeInt<size_t>(total_count), stream);
this->d_offset = AllocateBuffer<int>(allocator, d_offset_buffer_, SafeInt<size_t>(batch_size + 1), stream);
this->d_sorted_score = AllocateBuffer<T>(allocator, d_sorted_score_buffer_, SafeInt<size_t>(total_count), stream);
this->d_sorted_softmaxed_score = AllocateBuffer<float>(allocator, d_sorted_softmaxed_score_buffer_, SafeInt<size_t>(total_count), stream);
this->d_softmaxed_score = AllocateBuffer<float>(allocator, d_softmaxed_score_buffer_, SafeInt<size_t>(total_count), stream);
this->d_index_in = AllocateBuffer<int>(allocator, d_index_in_buffer_, total_count, stream);
this->d_index_out = AllocateBuffer<int>(allocator, d_index_out_buffer_, total_count, stream);
this->d_offset = AllocateBuffer<int>(allocator, d_offset_buffer_, SafeInt<size_t>(batch_size) + 1, stream);
this->d_sorted_score = AllocateBuffer<T>(allocator, d_sorted_score_buffer_, total_count, stream);
this->d_sorted_softmaxed_score = AllocateBuffer<float>(allocator, d_sorted_softmaxed_score_buffer_, total_count, stream);
this->d_softmaxed_score = AllocateBuffer<float>(allocator, d_softmaxed_score_buffer_, total_count, stream);
this->d_sampled = AllocateBuffer<float>(allocator, d_sampled_buffer_, SafeInt<size_t>(batch_size), stream);
this->h_sampled_all = AllocateBuffer<float>(cpu_allocator, h_sampled_all_buffer_, SafeInt<size_t>(batch_size * max_iter), stream);
this->h_sampled_all = AllocateBuffer<float>(cpu_allocator, h_sampled_all_buffer_, SafeInt<size_t>(batch_size) * SafeInt<size_t>(max_iter), stream);
this->d_indices = AllocateBuffer<int32_t>(allocator, d_indices_buffer_, SafeInt<size_t>(batch_size), stream);
this->temp_storage_bytes = 0;
// TODO: Do not allocate this buffer if there's no presence_mask
this->d_presence_mask = AllocateBuffer<int>(allocator, d_presence_mask_buffer_, SafeInt<size_t>(total_count), stream);
this->d_presence_mask = AllocateBuffer<int>(allocator, d_presence_mask_buffer_, total_count, stream);

std::uniform_real_distribution<float> distribution(0.0, 1.0);
static_cast<void>(distribution(this->generator));
Expand All @@ -49,8 +56,8 @@ struct SamplingState : public ISamplingState<T> {
}
} else {
// TODO: Some buffer can be reused for CPU
this->sorted_scores = AllocateBuffer<T>(cpu_allocator, sorted_scores_buffer_, SafeInt<size_t>(total_count), stream);
this->cumulative_probs = AllocateBuffer<T>(cpu_allocator, cumulative_probs_buffer_, SafeInt<size_t>(total_count), stream);
this->sorted_scores = AllocateBuffer<T>(cpu_allocator, sorted_scores_buffer_, total_count, stream);
this->cumulative_probs = AllocateBuffer<T>(cpu_allocator, cumulative_probs_buffer_, total_count, stream);
}
}

Expand Down
72 changes: 72 additions & 0 deletions onnxruntime/test/contrib_ops/sampling_state_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// Regression tests for the buffer-size arithmetic in
// onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h
// (SamplingState::Init). The bug being guarded was that a plain `int * int`
// multiply of `batch_size * vocab_size`, or a `static_cast<size_t>(...)` of a
// negative/unvalidated operand, could silently wrap and lead to under-allocated
// buffers (heap-buffer-overflow on the downstream memcpy in
// SamplingCpuHelper::Sample).
//
// `greedy_search_impl_base.h` is not a self-contained public header (it
// transitively requires internal framework types such as OpKernelContextInternal
// that are unavailable to test code), so these tests reproduce the exact
// SafeInt<size_t> expression that the fix introduced rather than constructing
// a SamplingState directly. They will fail if anyone reverts the production
// code to use `int * int` or `static_cast<size_t>` on operands that may be
// negative.

#include "gtest/gtest.h"

#include "core/common/common.h"
#include "core/common/safeint.h"

namespace onnxruntime {
namespace test {

namespace {

// Mirrors the production computation in SamplingState::Init:
// const SafeInt<size_t> total_count =
// SafeInt<size_t>(batch_size) * SafeInt<size_t>(vocab_size);
size_t ComputeSamplingTotalCount(int batch_size, int vocab_size) {
return SafeInt<size_t>(batch_size) * SafeInt<size_t>(vocab_size);
}

// Mirrors the production computation for `h_sampled_all`:
// SafeInt<size_t>(batch_size) * SafeInt<size_t>(max_iter)
size_t ComputeSampledAllCount(int batch_size, int max_iter) {
return SafeInt<size_t>(batch_size) * SafeInt<size_t>(max_iter);
}

} // namespace

// Sanity check: well-formed inputs produce the expected element count.
TEST(SamplingStateArithmeticTest, ProducesExpectedTotalCountForValidInputs) {
EXPECT_EQ(ComputeSamplingTotalCount(4, 32), static_cast<size_t>(4) * 32u);
EXPECT_EQ(ComputeSamplingTotalCount(1, 50257), static_cast<size_t>(50257));
EXPECT_EQ(ComputeSampledAllCount(8, 16), static_cast<size_t>(8) * 16u);
}

// A negative `vocab_size` (e.g. an unvalidated default of -1) used to be turned
// into SIZE_MAX by `static_cast<size_t>(vocab_size)`, yielding a multiplication
// result that either silently wrapped or requested an absurdly large buffer.
// SafeInt<size_t> rejects the negative-to-unsigned conversion up front.
TEST(SamplingStateArithmeticTest, ThrowsOnNegativeVocabSize) {
EXPECT_THROW(ComputeSamplingTotalCount(4, -1), OnnxRuntimeException);
}

// Symmetric check for a negative `batch_size`.
TEST(SamplingStateArithmeticTest, ThrowsOnNegativeBatchSize) {
EXPECT_THROW(ComputeSamplingTotalCount(-1, 32), OnnxRuntimeException);
}

// `max_iter` flows through the same SafeInt<size_t> path for the `h_sampled_all`
// allocation, so a negative value must also be rejected.
TEST(SamplingStateArithmeticTest, ThrowsOnNegativeMaxIter) {
EXPECT_THROW(ComputeSampledAllCount(4, -1), OnnxRuntimeException);
}

} // namespace test
} // namespace onnxruntime
Loading