diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
index debda282eb4f1..1080498831474 100644
--- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
@@ -145,6 +145,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const {
   // Validate seqlens_k values before they are used as GEMM dimensions to prevent OOB access.
   {
     const int32_t* seqlens_k_data = seqlens_k->Data<int32_t>();
+    const int past_kv_seqlen = parameters.seqlen_past_kv_cache;
     for (int b = 0; b < batch_size; b++) {
       if (seqlens_k_data[b] < 0 || seqlens_k_data[b] >= present_kv_seqlen) {
         return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
@@ -156,6 +157,25 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const {
                                "seqlens_k[", b, "] = ", seqlens_k_data[b],
                                " is too small for sequence_length ", sequence_length);
       }
+      // Bound the number of past KV rows copied out of the past key/value buffers during
+      // token generation (decode). ConcatStateChunkGQA copies (seqlens_k + 1 - sequence_length)
+      // rows from the past buffer (sized past_kv_seqlen). The present-buffer check above does
+      // not bound this past-side read, because the present buffer can be larger than the past
+      // buffer when total_sequence_length exceeds the past sequence length. A large seqlens_k
+      // combined with a small past buffer would otherwise read past the end of the past tensors.
+      // Shared KV (kv_sequence_length == 0) appends no new KV and its past read is already
+      // bounded by the present-buffer check together with the total_sequence_length <=
+      // seqlen_past_kv_cache enforcement in the apply-attention paths, so it needs no check here.
+      if (past_key != nullptr && past_value != nullptr && parameters.kv_sequence_length != 0 &&
+          !parameters.is_first_prompt) {
+        const int64_t past_rows = static_cast<int64_t>(seqlens_k_data[b]) + 1 - sequence_length;
+        if (past_rows > past_kv_seqlen) {
+          return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                                 "seqlens_k[", b, "] = ", seqlens_k_data[b], " requires ", past_rows,
+                                 " past KV rows, which exceeds the past buffer sequence length ",
+                                 past_kv_seqlen, ".");
+        }
+      }
     }
   }
   int q_hidden_size = parameters.hidden_size;
diff --git a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc
index ef7a04bc67d4f..e17aff159e6cc 100644
--- a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc
+++ b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc
@@ -369,6 +369,35 @@ TEST(GroupQueryAttentionTest, NonPromptSeqlensKUnderflow_OOB) {
       /*past_seq_len=*/4);
 }
 
+// Regression: present buffer large enough (total_seq_len passes the present-buffer check),
+// but the past buffer is much smaller. ConcatStateChunkGQA would copy
+// (seqlens_k + 1 - sequence_length) rows out of the small past buffer, reading past its end.
+TEST(GroupQueryAttentionTest, SeqlensKExceedsPastBuffer_OOBRead) {
+  // present_kv_seqlen = max(total_seq_len=100, past_seq_len=2) = 100, so seqlens_k=50 passes the
+  // present-buffer check, but past_seqlen = 51 - 1 = 50 rows >> past buffer (2 rows) => OOB read.
+  RunGQASeqlensKTest(
+      /*seqlens_k_data=*/{50},
+      /*total_seq_len=*/100,
+      /*batch_size=*/1,
+      /*sequence_length=*/1,
+      OpTester::ExpectResult::kExpectFailure,
+      "exceeds the past buffer sequence length",
+      /*provide_past=*/true,
+      /*past_seq_len=*/2);
+}
+
+TEST(GroupQueryAttentionTest, SeqlensKExceedsEmptyPastBuffer_OOBRead) {
+  RunGQASeqlensKTest(
+      /*seqlens_k_data=*/{50},
+      /*total_seq_len=*/100,
+      /*batch_size=*/1,
+      /*sequence_length=*/1,
+      OpTester::ExpectResult::kExpectFailure,
+      "exceeds the past buffer sequence length",
+      /*provide_past=*/true,
+      /*past_seq_len=*/0);
+}
+
 // INT32_MAX seqlens_k: rejected by the >= present_kv_seqlen check.
 TEST(GroupQueryAttentionTest, Int32MaxSeqlensK_OOB) {
   RunGQASeqlensKTest(