diff --git a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm.cc b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm.cc index 864e2d1623923..20fafb1139dbb 100644 --- a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm.cc +++ b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm.cc @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include + #include "core/providers/cuda/cuda_common.h" #include "contrib_ops/cpu/bert/embed_layer_norm_helper.h" #include "embed_layer_norm.h" @@ -61,6 +63,15 @@ Status EmbedLayerNorm::ComputeInternal(OpKernelContext* context) const { int sequence_length = static_cast(input_dims[1]); size_t element_size = sizeof(T); + // Element offsets into the output are 32-bit on device; reject shapes whose element count would + // exceed the 32-bit indexable range. The maximum output write index is + // batch_size * sequence_length * hidden_size - 1, so this guard covers every device write site. + const int64_t output_element_count = + static_cast(batch_size) * sequence_length * hidden_size; + ORT_RETURN_IF_NOT(output_element_count <= static_cast(std::numeric_limits::max()), + "EmbedLayerNormalization: output element count (", output_element_count, + ") exceeds the supported 32-bit indexing range."); + const bool broadcast_position_ids = (nullptr != position_ids && position_ids->Shape()[0] == 1); return LaunchEmbedLayerNormKernel( diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc index 8557e326e5b15..865202801b2ce 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include + #include "core/providers/cuda/cuda_common.h" #include "core/common/narrow.h" #include "skip_layer_norm.h" @@ -77,6 +79,15 @@ Status SkipLayerNorm::ComputeInternal(OpKernelContext* ctx) const return Status::OK(); } + // Element offsets into the output are 32-bit on device; reject shapes whose element count would + // exceed the 32-bit indexable range. The output shares the input shape, so input->Shape().Size() + // is the output element count. The maximum output write index is + // row_count * hidden_size - 1 == output element count - 1, so this guard covers every device write site. + const int64_t output_element_count = input->Shape().Size(); + ORT_RETURN_IF_NOT(output_element_count <= static_cast(std::numeric_limits::max()), + "SkipLayerNormalization: output element count (", output_element_count, + ") exceeds the supported 32-bit indexing range."); + typedef typename ToCudaType::MappedType CudaT; const int skip_size = onnxruntime::narrow(skip->Shape().Size());