Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions onnxruntime/contrib_ops/cpu/bert/attention_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,14 +96,14 @@ void PrepareMask(const int32_t* mask_index,
// mask_index is 1D: (B) or (2B) => (Bx)T

// Handle right-side padding: mask value at or after the end position will be mask_filter_value
int end_position = std::max(0, std::min(static_cast<int>(mask_index[b_i]), all_sequence_length));
int end_position = std::clamp(static_cast<int>(mask_index[b_i]), 0, all_sequence_length);
for (int m_i = end_position; m_i < all_sequence_length; m_i++) {
p_mask[m_i] = static_cast<T>(mask_filter_value);
}

// Handle left-side padding: mask value before the start position will be mask_filter_value
if (has_mask_start_position) {
int start_position = std::max(0, std::min(static_cast<int>(mask_index[b_i + batch_size]), all_sequence_length));
int start_position = std::clamp(static_cast<int>(mask_index[b_i + batch_size]), 0, all_sequence_length);
for (int m_i = 0; m_i < start_position; m_i++) {
p_mask[m_i] = static_cast<T>(mask_filter_value);
}
Expand Down
43 changes: 43 additions & 0 deletions onnxruntime/test/contrib_ops/attention_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1425,6 +1425,49 @@ TEST(ContribOpAttentionTest, AttentionBatch2LeftPaddingMaskIndex2) {
AttentionMaskType::MASK_1D_END_START);
}

// Verifies that out-of-range 1D mask_index values are clamped rather than
// causing out-of-bounds memory access. end_position > all_sequence_length
// must clamp to all_sequence_length (no right-side masking), and
// start_position < 0 must clamp to 0 (no left-side masking). Both cases
// should produce the same output as a fully-unmasked sequence.
TEST(ContribOpAttentionTest, AttentionMaskIndex1DClampOOB) {
int batch_size = 1;
int sequence_length = 2;
int hidden_size = 4;
int number_of_heads = 2;

std::vector<float> input_data = {
0.8f, -0.5f, 0.0f, 1.f,
0.5f, 0.2f, 0.3f, -0.6f};

std::vector<float> weight_data = {
0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f,
0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f,
0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f,
0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f};

std::vector<float> bias_data = {
-0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f};

// end_position=999 is well above sequence_length=2, so it must clamp to 2
// (no right-side masking). Expected output equals the fully-unmasked case.
std::vector<int32_t> mask_index_data_end_oob = {999};
std::vector<float> output_data = {
3.1495983600616455f, 0.10843668878078461f, 4.25f, 5.6499996185302734f,
3.9696791172027588f, 0.073143675923347473f, 4.2499995231628418f, 5.6499991416931152f};

RunAttentionTest(input_data, weight_data, bias_data, mask_index_data_end_oob, output_data,
batch_size, sequence_length, hidden_size, number_of_heads);

// start_position=-5 is below zero, so it must clamp to 0 (no left-side
// masking). end_position=2 keeps all tokens unmasked on the right.
// Expected output is again the fully-unmasked case.
std::vector<int32_t> mask_index_data_start_neg = {2, -5};
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data_start_neg, output_data,
batch_size, sequence_length, hidden_size, number_of_heads, false, false, false, 0,
nullptr, nullptr, AttentionMaskType::MASK_1D_END_START);
}

TEST(ContribOpAttentionTest, Attention3DMask) {
int batch_size = 2;
int sequence_length = 2;
Expand Down
Loading