diff --git a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h index 984e8f2490a73..91b4c9a736316 100644 --- a/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/rotary_embedding_helper.h @@ -115,8 +115,30 @@ Status CheckInputs(const T* input, if (rotary_embedding_dim == 0) { int cache_width = 0; ORT_RETURN_IF_ERROR(detail::NarrowNonNegativeToInt32(cos_cache_dims[1], "cache_width", cache_width)); + + int effective_rotary_dim = 0; + ORT_RETURN_IF_ERROR(detail::CheckedMulToInt32(cache_width, 2, "effective_rotary_dim", effective_rotary_dim)); + if (head_size == 0) { - ORT_RETURN_IF_ERROR(detail::CheckedMulToInt32(cache_width, 2, "head_size", head_size)); + head_size = effective_rotary_dim; + } + + // Rotary embedding is applied per head, so the inferred dimension must not exceed head_size. + if (head_size > 0 && effective_rotary_dim > head_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "RotaryEmbedding: cos_cache dimension (", cache_width, + " * 2 = ", effective_rotary_dim, + ") exceeds head_size (", head_size, + ") when rotary_embedding_dim is 0"); + } + + // Also guard against exceeding the full hidden_size (covers num_heads==0 / rank-3 without num_heads). + if (hidden_size > 0 && effective_rotary_dim > hidden_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "RotaryEmbedding: cos_cache dimension (", cache_width, + " * 2 = ", effective_rotary_dim, + ") exceeds input hidden_size (", hidden_size, + ") when rotary_embedding_dim is 0"); } } else { if (!transposed) { diff --git a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc index 880c10137f3fe..1b6235fca88fa 100644 --- a/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc +++ b/onnxruntime/test/contrib_ops/rotary_embedding_op_test.cc @@ -1034,7 +1034,7 @@ TEST(RotaryEmbeddingTest, ContribRotaryEmbedding_RejectsRank3MalformedCacheWidth std::vector> execution_providers; execution_providers.push_back(DefaultCpuExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectFailure, - "Input 'cos_cache' dimension 1 should be same as head_size / 2 or rotary_embedding_dim / 2, got 8", + "exceeds head_size", {}, nullptr, &execution_providers); } @@ -1051,7 +1051,7 @@ TEST(RotaryEmbeddingTest, ContribRotaryEmbedding_RejectsRank4MalformedCacheWidth std::vector> execution_providers; execution_providers.push_back(DefaultCpuExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectFailure, - "Input 'cos_cache' dimension 1 should be same as head_size / 2 or rotary_embedding_dim / 2, got 8", + "exceeds head_size", {}, nullptr, &execution_providers); } @@ -1172,5 +1172,63 @@ TEST(RotaryEmbeddingTest, ContribRotaryEmbedding_PositionIds_Negative_WebGPU_Pas test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } +// Test that cos_cache dimension exceeding head_size is rejected when rotary_embedding_dim=0. +TEST(RotaryEmbeddingTest, ContribRotaryEmbedding_RejectsCosCacheExceedsHeadSize) { + int batch_size = 1; + int sequence_length = 1; + int hidden_size = 64; + int half_rotary_dim = 64; // makes cos_cache_dims[1]*2 = 128 > hidden_size + int max_sequence_length = 2; + + OpTester test("RotaryEmbedding", 1, onnxruntime::kMSDomain); + test.AddAttribute("interleaved", static_cast(0)); + test.AddAttribute("num_heads", static_cast(1)); + + test.AddInput("input", {batch_size, sequence_length, hidden_size}, + std::vector(hidden_size, 42.0f)); + test.AddInput("position_ids", {1}, {0}); + test.AddInput("cos_cache", {max_sequence_length, half_rotary_dim}, + std::vector(max_sequence_length * half_rotary_dim, 0.0f)); + test.AddInput("sin_cache", {max_sequence_length, half_rotary_dim}, + std::vector(max_sequence_length * half_rotary_dim, 1.0f)); + test.AddOutput("output", {batch_size, sequence_length, hidden_size}, + std::vector(hidden_size, 0.0f)); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectFailure, + "exceeds head_size", {}, nullptr, &execution_providers); +} + +// Test that cos_cache dimension exceeding hidden_size is rejected when rotary_embedding_dim=0 +// and head_size is inferred from the cache (rank-3 input without num_heads). This exercises the +// `effective_rotary_dim > hidden_size` guard rather than the `exceeds head_size` guard. +TEST(RotaryEmbeddingTest, ContribRotaryEmbedding_RejectsCosCacheExceedsHiddenSize_NoNumHeads) { + int batch_size = 1; + int sequence_length = 1; + int hidden_size = 64; + int half_rotary_dim = 64; // cos_cache_dims[1]*2 = 128 > hidden_size; head_size inferred to 128 + int max_sequence_length = 2; + + OpTester test("RotaryEmbedding", 1, onnxruntime::kMSDomain); + test.AddAttribute("interleaved", static_cast(0)); + // num_heads intentionally NOT set so head_size stays 0 on entry and is inferred from cos_cache. + + test.AddInput("input", {batch_size, sequence_length, hidden_size}, + std::vector(hidden_size, 42.0f)); + test.AddInput("position_ids", {1}, {0}); + test.AddInput("cos_cache", {max_sequence_length, half_rotary_dim}, + std::vector(max_sequence_length * half_rotary_dim, 0.0f)); + test.AddInput("sin_cache", {max_sequence_length, half_rotary_dim}, + std::vector(max_sequence_length * half_rotary_dim, 1.0f)); + test.AddOutput("output", {batch_size, sequence_length, hidden_size}, + std::vector(hidden_size, 0.0f)); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectFailure, + "exceeds input hidden_size", {}, nullptr, &execution_providers); +} + } // namespace test } // namespace onnxruntime diff --git a/onnxruntime/test/providers/cpu/llm/rotary_embedding_op_test.cc b/onnxruntime/test/providers/cpu/llm/rotary_embedding_op_test.cc index 2f51b8a7a5690..bd14fcfe1ec06 100644 --- a/onnxruntime/test/providers/cpu/llm/rotary_embedding_op_test.cc +++ b/onnxruntime/test/providers/cpu/llm/rotary_embedding_op_test.cc @@ -1412,5 +1412,35 @@ TEST(RotaryEmbeddingTest, RotaryEmbedding_PositionIds_OOB_InBatch_WebGPU_Passthr test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } +// Test that cos_cache dimension exceeding hidden_size is rejected when rotary_embedding_dim=0. +TEST(RotaryEmbeddingTest, RotaryEmbedding_RejectsCosCacheExceedsHiddenSize) { + // hidden_size = 64, cos_cache dim1 = 64 => effective rotary dim = 128 > 64 + int batch_size = 1; + int sequence_length = 1; + int hidden_size = 64; + int half_rotary_dim = 64; // makes cos_cache_dims[1]*2 = 128 > hidden_size + int max_sequence_length = 2; + + OpTester test("RotaryEmbedding", 23, onnxruntime::kOnnxDomain); + test.AddAttribute("interleaved", static_cast(0)); + test.AddAttribute("num_heads", static_cast(1)); + + test.AddInput("input", {batch_size, sequence_length, hidden_size}, + std::vector(hidden_size, 42.0f)); + test.AddInput("cos_cache", {max_sequence_length, half_rotary_dim}, + std::vector(max_sequence_length * half_rotary_dim, 0.0f)); + test.AddInput("sin_cache", {max_sequence_length, half_rotary_dim}, + std::vector(max_sequence_length * half_rotary_dim, 1.0f)); + test.AddInput("position_ids", {batch_size, sequence_length}, {0}); + test.AddOutput("output", {batch_size, sequence_length, hidden_size}, + std::vector(hidden_size, 0.0f)); + + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectFailure, + "Input 'cos_cache' dimension 1 should be same as head_size / 2 or rotary_embedding_dim / 2", + {}, nullptr, &execution_providers); +} + } // namespace test } // namespace onnxruntime