Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/embed_layer_norm.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <limits>

#include "core/providers/cuda/cuda_common.h"
#include "contrib_ops/cpu/bert/embed_layer_norm_helper.h"
#include "embed_layer_norm.h"
Expand Down Expand Up @@ -61,6 +63,15 @@ Status EmbedLayerNorm<T>::ComputeInternal(OpKernelContext* context) const {
int sequence_length = static_cast<int>(input_dims[1]);
size_t element_size = sizeof(T);

// Element offsets into the output are 32-bit on device; reject shapes whose element count would
// exceed the 32-bit indexable range. The maximum output write index is
// batch_size * sequence_length * hidden_size - 1, so this guard covers every device write site.
const int64_t output_element_count =
static_cast<int64_t>(batch_size) * sequence_length * hidden_size;
ORT_RETURN_IF_NOT(output_element_count <= static_cast<int64_t>(std::numeric_limits<int32_t>::max()),
"EmbedLayerNormalization: output element count (", output_element_count,
") exceeds the supported 32-bit indexing range.");

const bool broadcast_position_ids = (nullptr != position_ids && position_ids->Shape()[0] == 1);

return LaunchEmbedLayerNormKernel(
Expand Down
11 changes: 11 additions & 0 deletions onnxruntime/contrib_ops/cuda/bert/skip_layer_norm.cc
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <limits>

#include "core/providers/cuda/cuda_common.h"
#include "core/common/narrow.h"
#include "skip_layer_norm.h"
Expand Down Expand Up @@ -77,6 +79,15 @@ Status SkipLayerNorm<T, Simplified>::ComputeInternal(OpKernelContext* ctx) const
return Status::OK();
}

// Element offsets into the output are 32-bit on device; reject shapes whose element count would
// exceed the 32-bit indexable range. The output shares the input shape, so input->Shape().Size()
// is the output element count. The maximum output write index is
// row_count * hidden_size - 1 == output element count - 1, so this guard covers every device write site.
const int64_t output_element_count = input->Shape().Size();
ORT_RETURN_IF_NOT(output_element_count <= static_cast<int64_t>(std::numeric_limits<int32_t>::max()),
"SkipLayerNormalization: output element count (", output_element_count,
") exceeds the supported 32-bit indexing range.");

typedef typename ToCudaType<T>::MappedType CudaT;

const int skip_size = onnxruntime::narrow<int>(skip->Shape().Size());
Expand Down
Loading