diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h index 8dfcb50b916e7..be07e0663db2d 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h @@ -96,14 +96,14 @@ void PrepareMask(const int32_t* mask_index, // mask_index is 1D: (B) or (2B) => (Bx)T // Handle right-side padding: mask value at or after the end position will be mask_filter_value - int end_position = std::max(0, std::min(static_cast(mask_index[b_i]), all_sequence_length)); + int end_position = std::clamp(static_cast(mask_index[b_i]), 0, all_sequence_length); for (int m_i = end_position; m_i < all_sequence_length; m_i++) { p_mask[m_i] = static_cast(mask_filter_value); } // Handle left-side padding: mask value before the start position will be mask_filter_value if (has_mask_start_position) { - int start_position = std::max(0, std::min(static_cast(mask_index[b_i + batch_size]), all_sequence_length)); + int start_position = std::clamp(static_cast(mask_index[b_i + batch_size]), 0, all_sequence_length); for (int m_i = 0; m_i < start_position; m_i++) { p_mask[m_i] = static_cast(mask_filter_value); } diff --git a/onnxruntime/test/contrib_ops/attention_op_test.cc b/onnxruntime/test/contrib_ops/attention_op_test.cc index a4d059ced5d42..d0bee280d9972 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/attention_op_test.cc @@ -1425,6 +1425,49 @@ TEST(ContribOpAttentionTest, AttentionBatch2LeftPaddingMaskIndex2) { AttentionMaskType::MASK_1D_END_START); } +// Verifies that out-of-range 1D mask_index values are clamped rather than +// causing out-of-bounds memory access. end_position > all_sequence_length +// must clamp to all_sequence_length (no right-side masking), and +// start_position < 0 must clamp to 0 (no left-side masking). Both cases +// should produce the same output as a fully-unmasked sequence. +TEST(ContribOpAttentionTest, AttentionMaskIndex1DClampOOB) { + int batch_size = 1; + int sequence_length = 2; + int hidden_size = 4; + int number_of_heads = 2; + + std::vector input_data = { + 0.8f, -0.5f, 0.0f, 1.f, + 0.5f, 0.2f, 0.3f, -0.6f}; + + std::vector weight_data = { + 0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f, + 0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f, + 0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f, + 0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f}; + + std::vector bias_data = { + -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f}; + + // end_position=999 is well above sequence_length=2, so it must clamp to 2 + // (no right-side masking). Expected output equals the fully-unmasked case. + std::vector mask_index_data_end_oob = {999}; + std::vector output_data = { + 3.1495983600616455f, 0.10843668878078461f, 4.25f, 5.6499996185302734f, + 3.9696791172027588f, 0.073143675923347473f, 4.2499995231628418f, 5.6499991416931152f}; + + RunAttentionTest(input_data, weight_data, bias_data, mask_index_data_end_oob, output_data, + batch_size, sequence_length, hidden_size, number_of_heads); + + // start_position=-5 is below zero, so it must clamp to 0 (no left-side + // masking). end_position=2 keeps all tokens unmasked on the right. + // Expected output is again the fully-unmasked case. + std::vector mask_index_data_start_neg = {2, -5}; + RunAttentionTest(input_data, weight_data, bias_data, mask_index_data_start_neg, output_data, + batch_size, sequence_length, hidden_size, number_of_heads, false, false, false, 0, + nullptr, nullptr, AttentionMaskType::MASK_1D_END_START); +} + TEST(ContribOpAttentionTest, Attention3DMask) { int batch_size = 2; int sequence_length = 2;