From ae4bd77eb7054539ef3dbca099672a97a6961de4 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 30 Jun 2026 19:45:32 +0000 Subject: [PATCH 1/3] Fix GQA oob read in past buffer --- .../cpu/bert/group_query_attention.cc | 22 +++++++++++++++++++ .../group_query_attention_op_test.cc | 17 ++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index debda282eb4f1..d8d18c451d80f 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -145,6 +145,7 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { // Validate seqlens_k values before they are used as GEMM dimensions to prevent OOB access. { const int32_t* seqlens_k_data = seqlens_k->Data<int32_t>(); + const int past_kv_seqlen = parameters.seqlen_past_kv_cache; for (int b = 0; b < batch_size; b++) { if (seqlens_k_data[b] < 0 || seqlens_k_data[b] >= present_kv_seqlen) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, @@ -156,6 +157,27 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { "seqlens_k[", b, "] = ", seqlens_k_data[b], " is too small for sequence_length ", sequence_length); } + // Bound the number of past KV rows copied out of the past key/value buffers. + // ConcatStateChunkGQA copies `past_seqlen` rows from the past buffer (sized + // past_kv_seqlen). The present buffer (validated above) can be larger than the past + // buffer when total_sequence_length exceeds the past sequence length, so the check + // above does not bound the past-side read. A large seqlens_k combined with a small + // past buffer would otherwise read past the end of the past key/value tensors. 
+ if (past_kv_seqlen > 0) { + const int64_t total_seqlen_b = static_cast<int64_t>(seqlens_k_data[b]) + 1; + int64_t past_rows = 0; + if (parameters.kv_sequence_length == 0) { + past_rows = total_seqlen_b; // shared KV: the entire past cache is copied + } else if (!parameters.is_first_prompt) { + past_rows = total_seqlen_b - sequence_length; + } + if (past_rows > past_kv_seqlen) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "seqlens_k[", b, "] = ", seqlens_k_data[b], " requires ", past_rows, + " past KV rows, which exceeds the past buffer sequence length ", + past_kv_seqlen, "."); + } + } } } int q_hidden_size = parameters.hidden_size; diff --git a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc index ef7a04bc67d4f..a5121affcea21 100644 --- a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc @@ -369,6 +369,23 @@ TEST(GroupQueryAttentionTest, NonPromptSeqlensKUnderflow_OOB) { /*past_seq_len=*/4); } +// Regression: present buffer large enough (total_seq_len passes the present-buffer check), +// but the past buffer is much smaller. ConcatStateChunkGQA would copy +// (seqlens_k + 1 - sequence_length) rows out of the small past buffer, reading past its end. +TEST(GroupQueryAttentionTest, SeqlensKExceedsPastBuffer_OOBRead) { + // present_kv_seqlen = max(total_seq_len=100, past_seq_len=2) = 100, so seqlens_k=50 passes the + // present-buffer check, but past_seqlen = 51 - 1 = 50 rows >> past buffer (2 rows) => OOB read. + RunGQASeqlensKTest( + /*seqlens_k_data=*/{50}, + /*total_seq_len=*/100, + /*batch_size=*/1, + /*sequence_length=*/1, + OpTester::ExpectResult::kExpectFailure, + "exceeds the past buffer sequence length", + /*provide_past=*/true, + /*past_seq_len=*/2); +} + // INT32_MAX seqlens_k: rejected by the >= present_kv_seqlen check. 
TEST(GroupQueryAttentionTest, Int32MaxSeqlensK_OOB) { RunGQASeqlensKTest( From b0e7209dce1d54abd75a61c63fde3b013b169c77 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Tue, 30 Jun 2026 20:45:09 +0000 Subject: [PATCH 2/3] refine check --- .../cpu/bert/group_query_attention.cc | 25 ++++++++----------- 1 file changed, 11 insertions(+), 14 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index d8d18c451d80f..ea8be5d28ae78 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -157,20 +157,17 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { "seqlens_k[", b, "] = ", seqlens_k_data[b], " is too small for sequence_length ", sequence_length); } - // Bound the number of past KV rows copied out of the past key/value buffers. - // ConcatStateChunkGQA copies `past_seqlen` rows from the past buffer (sized - // past_kv_seqlen). The present buffer (validated above) can be larger than the past - // buffer when total_sequence_length exceeds the past sequence length, so the check - // above does not bound the past-side read. A large seqlens_k combined with a small - // past buffer would otherwise read past the end of the past key/value tensors. - if (past_kv_seqlen > 0) { - const int64_t total_seqlen_b = static_cast<int64_t>(seqlens_k_data[b]) + 1; - int64_t past_rows = 0; - if (parameters.kv_sequence_length == 0) { - past_rows = total_seqlen_b; // shared KV: the entire past cache is copied - } else if (!parameters.is_first_prompt) { - past_rows = total_seqlen_b - sequence_length; - } + // Bound the number of past KV rows copied out of the past key/value buffers during + // token generation (decode). ConcatStateChunkGQA copies (seqlens_k + 1 - sequence_length) + // rows from the past buffer (sized past_kv_seqlen). 
The present-buffer check above does + // not bound this past-side read, because the present buffer can be larger than the past + // buffer when total_sequence_length exceeds the past sequence length. A large seqlens_k + // combined with a small past buffer would otherwise read past the end of the past tensors. + // Shared KV (kv_sequence_length == 0) appends no new KV and its past read is already + // bounded by the present-buffer check together with the total_sequence_length <= + // seqlen_past_kv_cache enforcement in the apply-attention paths, so it needs no check here. + if (past_kv_seqlen > 0 && parameters.kv_sequence_length != 0 && !parameters.is_first_prompt) { + const int64_t past_rows = static_cast<int64_t>(seqlens_k_data[b]) + 1 - sequence_length; if (past_rows > past_kv_seqlen) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "seqlens_k[", b, "] = ", seqlens_k_data[b], " requires ", past_rows, From e633ff2971277171b29dca12e702632eb7a0679b Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 1 Jul 2026 01:21:13 +0000 Subject: [PATCH 3/3] address feedbacks --- .../contrib_ops/cpu/bert/group_query_attention.cc | 3 ++- .../contrib_ops/group_query_attention_op_test.cc | 12 ++++++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc index ea8be5d28ae78..1080498831474 100644 --- a/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc @@ -166,7 +166,8 @@ Status GroupQueryAttention::Compute(OpKernelContext* context) const { // Shared KV (kv_sequence_length == 0) appends no new KV and its past read is already // bounded by the present-buffer check together with the total_sequence_length <= // seqlen_past_kv_cache enforcement in the apply-attention paths, so it needs no check here. 
- if (past_kv_seqlen > 0 && parameters.kv_sequence_length != 0 && !parameters.is_first_prompt) { + if (past_key != nullptr && past_value != nullptr && parameters.kv_sequence_length != 0 && + !parameters.is_first_prompt) { const int64_t past_rows = static_cast<int64_t>(seqlens_k_data[b]) + 1 - sequence_length; if (past_rows > past_kv_seqlen) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, diff --git a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc index a5121affcea21..e17aff159e6cc 100644 --- a/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/group_query_attention_op_test.cc @@ -386,6 +386,18 @@ TEST(GroupQueryAttentionTest, SeqlensKExceedsPastBuffer_OOBRead) { /*past_seq_len=*/2); } +TEST(GroupQueryAttentionTest, SeqlensKExceedsEmptyPastBuffer_OOBRead) { + RunGQASeqlensKTest( + /*seqlens_k_data=*/{50}, + /*total_seq_len=*/100, + /*batch_size=*/1, + /*sequence_length=*/1, + OpTester::ExpectResult::kExpectFailure, + "exceeds the past buffer sequence length", + /*provide_past=*/true, + /*past_seq_len=*/0); +} + // INT32_MAX seqlens_k: rejected by the >= present_kv_seqlen check. TEST(GroupQueryAttentionTest, Int32MaxSeqlensK_OOB) { RunGQASeqlensKTest(