Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,30 @@ Status CheckInputs(const T* input,
if (rotary_embedding_dim == 0) {
int cache_width = 0;
ORT_RETURN_IF_ERROR(detail::NarrowNonNegativeToInt32(cos_cache_dims[1], "cache_width", cache_width));

int effective_rotary_dim = 0;
ORT_RETURN_IF_ERROR(detail::CheckedMulToInt32(cache_width, 2, "effective_rotary_dim", effective_rotary_dim));

if (head_size == 0) {
ORT_RETURN_IF_ERROR(detail::CheckedMulToInt32(cache_width, 2, "head_size", head_size));
head_size = effective_rotary_dim;
}

// Rotary embedding is applied per head, so the inferred dimension must not exceed head_size.
if (head_size > 0 && effective_rotary_dim > head_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"RotaryEmbedding: cos_cache dimension (", cache_width,
" * 2 = ", effective_rotary_dim,
") exceeds head_size (", head_size,
") when rotary_embedding_dim is 0");
}

// Also guard against exceeding the full hidden_size (covers num_heads==0 / rank-3 without num_heads).
if (hidden_size > 0 && effective_rotary_dim > hidden_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"RotaryEmbedding: cos_cache dimension (", cache_width,
" * 2 = ", effective_rotary_dim,
") exceeds input hidden_size (", hidden_size,
") when rotary_embedding_dim is 0");
}
} else {
if (!transposed) {
Expand Down
62 changes: 60 additions & 2 deletions onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1034,7 +1034,7 @@ TEST(RotaryEmbeddingTest, ContribRotaryEmbedding_RejectsRank3MalformedCacheWidth
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectFailure,
"Input 'cos_cache' dimension 1 should be same as head_size / 2 or rotary_embedding_dim / 2, got 8",
"exceeds head_size",
{}, nullptr, &execution_providers);
}

Expand All @@ -1051,7 +1051,7 @@ TEST(RotaryEmbeddingTest, ContribRotaryEmbedding_RejectsRank4MalformedCacheWidth
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectFailure,
"Input 'cos_cache' dimension 1 should be same as head_size / 2 or rotary_embedding_dim / 2, got 8",
"exceeds head_size",
{}, nullptr, &execution_providers);
}

Expand Down Expand Up @@ -1172,5 +1172,63 @@ TEST(RotaryEmbeddingTest, ContribRotaryEmbedding_PositionIds_Negative_WebGPU_Pas
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}

// Test that cos_cache dimension exceeding head_size is rejected when rotary_embedding_dim=0.
TEST(RotaryEmbeddingTest, ContribRotaryEmbedding_RejectsCosCacheExceedsHeadSize) {
int batch_size = 1;
int sequence_length = 1;
int hidden_size = 64;
int half_rotary_dim = 64; // makes cos_cache_dims[1]*2 = 128 > hidden_size
int max_sequence_length = 2;

OpTester test("RotaryEmbedding", 1, onnxruntime::kMSDomain);
test.AddAttribute<int64_t>("interleaved", static_cast<int64_t>(0));
test.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(1));
Comment thread
apsonawane marked this conversation as resolved.

test.AddInput<float>("input", {batch_size, sequence_length, hidden_size},
std::vector<float>(hidden_size, 42.0f));
test.AddInput<int64_t>("position_ids", {1}, {0});
test.AddInput<float>("cos_cache", {max_sequence_length, half_rotary_dim},
std::vector<float>(max_sequence_length * half_rotary_dim, 0.0f));
test.AddInput<float>("sin_cache", {max_sequence_length, half_rotary_dim},
std::vector<float>(max_sequence_length * half_rotary_dim, 1.0f));
test.AddOutput<float>("output", {batch_size, sequence_length, hidden_size},
std::vector<float>(hidden_size, 0.0f));

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectFailure,
"exceeds head_size", {}, nullptr, &execution_providers);
}

// Test that cos_cache dimension exceeding hidden_size is rejected when rotary_embedding_dim=0
// and head_size is inferred from the cache (rank-3 input without num_heads). This exercises the
// `effective_rotary_dim > hidden_size` guard rather than the `exceeds head_size` guard.
TEST(RotaryEmbeddingTest, ContribRotaryEmbedding_RejectsCosCacheExceedsHiddenSize_NoNumHeads) {
int batch_size = 1;
int sequence_length = 1;
int hidden_size = 64;
int half_rotary_dim = 64; // cos_cache_dims[1]*2 = 128 > hidden_size; head_size inferred to 128
int max_sequence_length = 2;

OpTester test("RotaryEmbedding", 1, onnxruntime::kMSDomain);
test.AddAttribute<int64_t>("interleaved", static_cast<int64_t>(0));
// num_heads intentionally NOT set so head_size stays 0 on entry and is inferred from cos_cache.

test.AddInput<float>("input", {batch_size, sequence_length, hidden_size},
std::vector<float>(hidden_size, 42.0f));
test.AddInput<int64_t>("position_ids", {1}, {0});
test.AddInput<float>("cos_cache", {max_sequence_length, half_rotary_dim},
std::vector<float>(max_sequence_length * half_rotary_dim, 0.0f));
test.AddInput<float>("sin_cache", {max_sequence_length, half_rotary_dim},
std::vector<float>(max_sequence_length * half_rotary_dim, 1.0f));
test.AddOutput<float>("output", {batch_size, sequence_length, hidden_size},
std::vector<float>(hidden_size, 0.0f));

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectFailure,
"exceeds input hidden_size", {}, nullptr, &execution_providers);
}

} // namespace test
} // namespace onnxruntime
30 changes: 30 additions & 0 deletions onnxruntime/test/providers/cpu/llm/rotary_embedding_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1412,5 +1412,35 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_PositionIds_OOB_InBatch_WebGPU_Passthr
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}

// Test that cos_cache dimension exceeding hidden_size is rejected when rotary_embedding_dim=0.
TEST(RotaryEmbeddingTest, RotaryEmbedding_RejectsCosCacheExceedsHiddenSize) {
// hidden_size = 64, cos_cache dim1 = 64 => effective rotary dim = 128 > 64
int batch_size = 1;
int sequence_length = 1;
int hidden_size = 64;
int half_rotary_dim = 64; // makes cos_cache_dims[1]*2 = 128 > hidden_size
int max_sequence_length = 2;

OpTester test("RotaryEmbedding", 23, onnxruntime::kOnnxDomain);
test.AddAttribute<int64_t>("interleaved", static_cast<int64_t>(0));
test.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(1));

test.AddInput<float>("input", {batch_size, sequence_length, hidden_size},
std::vector<float>(hidden_size, 42.0f));
test.AddInput<float>("cos_cache", {max_sequence_length, half_rotary_dim},
std::vector<float>(max_sequence_length * half_rotary_dim, 0.0f));
test.AddInput<float>("sin_cache", {max_sequence_length, half_rotary_dim},
std::vector<float>(max_sequence_length * half_rotary_dim, 1.0f));
test.AddInput<int64_t>("position_ids", {batch_size, sequence_length}, {0});
test.AddOutput<float>("output", {batch_size, sequence_length, hidden_size},
std::vector<float>(hidden_size, 0.0f));

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
execution_providers.push_back(DefaultCpuExecutionProvider());
test.Run(OpTester::ExpectResult::kExpectFailure,
"Input 'cos_cache' dimension 1 should be same as head_size / 2 or rotary_embedding_dim / 2",
{}, nullptr, &execution_providers);
}

} // namespace test
} // namespace onnxruntime
Loading