From a4b483d73ef79afb5df9f4055d506bd6fb554423 Mon Sep 17 00:00:00 2001 From: titaiwangms Date: Thu, 25 Jun 2026 22:08:54 +0000 Subject: [PATCH] Reject CUDA BERT EmbedLayerNorm/SkipLayerNorm shapes exceeding 32-bit output indexing The CUDA EmbedLayerNormalization and SkipLayerNormalization kernels compute output write offsets (row_index * hidden_size) using 32-bit arithmetic. For very large output tensors the element count can exceed INT32_MAX and the offset would no longer be representable in 32 bits. Every output write index in these kernels is a pure function of the launch grid and hidden_size (no data-dependent write indexing), so the maximum index is exactly output_element_count - 1, which the host knows from the input shapes before launch. Add a host-side guard in each ComputeInternal that computes the output element count in 64-bit arithmetic and returns a clear error when it exceeds the supported 32-bit indexing range, instead of silently relying on the int32 kernels for shapes they cannot index. Kernels are unchanged (int32 baseline); no numeric behavior change for supported shapes. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- onnxruntime/contrib_ops/cuda/bert/embed_layer_norm.cc | 11 +++++++++++ onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc | 11 +++++++++++ 2 files changed, 22 insertions(+) diff --git a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm.cc b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm.cc index 864e2d1623923..20fafb1139dbb 100644 --- a/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm.cc +++ b/onnxruntime/contrib_ops/cuda/bert/embed_layer_norm.cc @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
+#include <limits> + #include "core/providers/cuda/cuda_common.h" #include "contrib_ops/cpu/bert/embed_layer_norm_helper.h" #include "embed_layer_norm.h" @@ -61,6 +63,15 @@ Status EmbedLayerNorm<T>::ComputeInternal(OpKernelContext* context) const { int sequence_length = static_cast<int>(input_dims[1]); size_t element_size = sizeof(T); + // Element offsets into the output are 32-bit on device; reject shapes whose element count would + // exceed the 32-bit indexable range. The maximum output write index is + // batch_size * sequence_length * hidden_size - 1, so this guard covers every device write site. + const int64_t output_element_count = + static_cast<int64_t>(batch_size) * sequence_length * hidden_size; + ORT_RETURN_IF_NOT(output_element_count <= static_cast<int64_t>(std::numeric_limits<int32_t>::max()), + "EmbedLayerNormalization: output element count (", output_element_count, + ") exceeds the supported 32-bit indexing range."); + const bool broadcast_position_ids = (nullptr != position_ids && position_ids->Shape()[0] == 1); return LaunchEmbedLayerNormKernel( diff --git a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc index 8557e326e5b15..865202801b2ce 100644 --- a/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc +++ b/onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc @@ -1,6 +1,8 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. +#include <limits> + #include "core/providers/cuda/cuda_common.h" #include "core/common/narrow.h" #include "skip_layer_norm.h" @@ -77,6 +79,15 @@ Status SkipLayerNorm<T, Simplified>::ComputeInternal(OpKernelContext* ctx) const return Status::OK(); } + // Element offsets into the output are 32-bit on device; reject shapes whose element count would + // exceed the 32-bit indexable range. The output shares the input shape, so input->Shape().Size() + // is the output element count. 
The maximum output write index is + // row_count * hidden_size - 1 == output element count - 1, so this guard covers every device write site. + const int64_t output_element_count = input->Shape().Size(); + ORT_RETURN_IF_NOT(output_element_count <= static_cast<int64_t>(std::numeric_limits<int32_t>::max()), + "SkipLayerNormalization: output element count (", output_element_count, + ") exceeds the supported 32-bit indexing range."); + typedef typename ToCudaType<T>::MappedType CudaT; const int skip_size = onnxruntime::narrow<int>(skip->Shape().Size());