Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/group_query_attention.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
// Validate seqlens_k values before they are used as GEMM dimensions to prevent OOB access.
{
const int32_t* seqlens_k_data = seqlens_k->Data<int32_t>();
const int past_kv_seqlen = parameters.seqlen_past_kv_cache;
for (int b = 0; b < batch_size; b++) {
if (seqlens_k_data[b] < 0 || seqlens_k_data[b] >= present_kv_seqlen) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
Expand All @@ -156,6 +157,25 @@ Status GroupQueryAttention<T>::Compute(OpKernelContext* context) const {
"seqlens_k[", b, "] = ", seqlens_k_data[b],
" is too small for sequence_length ", sequence_length);
}
// Bound the number of past KV rows copied out of the past key/value buffers during
// token generation (decode). ConcatStateChunkGQA copies (seqlens_k + 1 - sequence_length)
// rows from the past buffer (sized past_kv_seqlen). The present-buffer check above does
// not bound this past-side read, because the present buffer can be larger than the past
// buffer when total_sequence_length exceeds the past sequence length. A large seqlens_k
// combined with a small past buffer would otherwise read past the end of the past tensors.
// Shared KV (kv_sequence_length == 0) appends no new KV and its past read is already
// bounded by the present-buffer check together with the total_sequence_length <=
// seqlen_past_kv_cache enforcement in the apply-attention paths, so it needs no check here.
if (past_key != nullptr && past_value != nullptr && parameters.kv_sequence_length != 0 &&
!parameters.is_first_prompt) {
const int64_t past_rows = static_cast<int64_t>(seqlens_k_data[b]) + 1 - sequence_length;
if (past_rows > past_kv_seqlen) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"seqlens_k[", b, "] = ", seqlens_k_data[b], " requires ", past_rows,
" past KV rows, which exceeds the past buffer sequence length ",
past_kv_seqlen, ".");
}
}
}
}
int q_hidden_size = parameters.hidden_size;
Expand Down
29 changes: 29 additions & 0 deletions onnxruntime/test/contrib_ops/group_query_attention_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -369,6 +369,35 @@ TEST(GroupQueryAttentionTest, NonPromptSeqlensKUnderflow_OOB) {
/*past_seq_len=*/4);
}

// Regression: present buffer large enough (total_seq_len passes the present-buffer check),
// but the past buffer is much smaller. ConcatStateChunkGQA would copy
// (seqlens_k + 1 - sequence_length) rows out of the small past buffer, reading past its end.
TEST(GroupQueryAttentionTest, SeqlensKExceedsPastBuffer_OOBRead) {
  // Scenario: seqlens_k fits the present buffer but not the (much smaller) past buffer.
  //   present_kv_seqlen = max(total_seq_len = 100, past_seq_len = 2) = 100, so seqlens_k = 50
  //   clears the present-buffer check. The past rows required are 50 + 1 - 1 = 50, far more
  //   than the 2 rows the past buffer holds, so the op must fail instead of reading OOB.
  constexpr int kTotalSeqLen = 100;
  constexpr int kBatchSize = 1;
  constexpr int kSequenceLength = 1;
  constexpr bool kProvidePast = true;
  constexpr int kPastSeqLen = 2;

  RunGQASeqlensKTest({50},  // seqlens_k_data
                     kTotalSeqLen,
                     kBatchSize,
                     kSequenceLength,
                     OpTester::ExpectResult::kExpectFailure,
                     "exceeds the past buffer sequence length",
                     kProvidePast,
                     kPastSeqLen);
}

TEST(GroupQueryAttentionTest, SeqlensKExceedsEmptyPastBuffer_OOBRead) {
  // Same shape as the small-past-buffer case, but with past_seq_len = 0: past tensors are
  // requested (provide_past = true) yet hold zero rows. seqlens_k = 50 with sequence_length = 1
  // still asks for 50 past rows, so the test expects the same past-buffer rejection message.
  constexpr int kTotalSeqLen = 100;
  constexpr int kBatchSize = 1;
  constexpr int kSequenceLength = 1;
  constexpr bool kProvidePast = true;
  constexpr int kPastSeqLen = 0;

  RunGQASeqlensKTest({50},  // seqlens_k_data
                     kTotalSeqLen,
                     kBatchSize,
                     kSequenceLength,
                     OpTester::ExpectResult::kExpectFailure,
                     "exceeds the past buffer sequence length",
                     kProvidePast,
                     kPastSeqLen);
}

// INT32_MAX seqlens_k: rejected by the >= present_kv_seqlen check.
TEST(GroupQueryAttentionTest, Int32MaxSeqlensK_OOB) {
RunGQASeqlensKTest(
Expand Down
Loading