Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .ci/docker/build.sh

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@BLOrange-AMD @glen-amd TheRock CI flows do not use this file, so these changes would be irrelevant for it. Do you know if this file was invoked by the PyTorch NPI build flows?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The only change is setting PYTORCH_ROCM_ARCH, I believe we should still keep this change for any legacy build.
In the ROCk this is controlled in the environment variable.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even for our rocm-ci legacy builds, we used to set PYTORCH_ROCM_ARCH env variable directly from the CI job parameters IIRC. So I'm not sure why/where the build-environment-based method was being used.

Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ case "$tag" in
KATEX=yes
UCX_COMMIT=${_UCX_COMMIT}
UCC_COMMIT=${_UCC_COMMIT}
PYTORCH_ROCM_ARCH="gfx90a;gfx942;gfx950;gfx1100"
PYTORCH_ROCM_ARCH="gfx90a;gfx942;gfx950;gfx1100;gfx1250"
if [[ $tag =~ "benchmarks" ]]; then
INDUCTOR_BENCHMARKS=yes
fi
Expand Down
2 changes: 1 addition & 1 deletion .ci/docker/ci_commit_pins/triton.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
4ed888920c5a0871957f1cf912e557bc79fbe56c
9c610c781cb810a11bfcc9accba094550b189a5e
2 changes: 1 addition & 1 deletion .ci/docker/triton_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3.6.0
3.7.0

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jataylo / @iupaikov-amd to signoff on this (and the corresponding triton.txt change), since this is moving to a newer major version of triton, so just need confidence on compatibility for non-gfx1250 archs.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ugh.. I somehow mis-read yesterday morning. I thought we were already using 3.7.0 in release/2.11. 3.6 is too old to be compatible with gfx1250.

@jataylo / @iupaikov-amd - can you please guide on which UTs we can run as a sanity testing for this triton bump?

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pragupta I have in my notes that the minimum UT suite would be test/inductor/test_torchinductor.py and test/inductor/test_max_autotune.py.

@naromero77amd naromero77amd Jun 16, 2026

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pragupta FWIW, we normally document PyTorch-Triton compability here:
https://amd.atlassian.net/wiki/spaces/MLSE/pages/1032521014/PyTorch+-+Triton+-+Team+responsibilities

but looks like we haven't kept the page up to date.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

concerned about which triton code we are pulling here.
commit - 9c610c781cb810a11bfcc9accba094550b189a5e belongs to which repo? Is it public?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, it is tip of https://github.com/ROCm/triton/commits/release/internal/3.7.x/ then if the UTs pass, then ok.

20 changes: 17 additions & 3 deletions .ci/pytorch/build.sh

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@BLOrange-AMD @glen-amd TheRock CI flows do not use this file, so these changes would be irrelevant for it. Do you know if this file was invoked by the PyTorch NPI build flows?

Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ if [[ "$BUILD_ENVIRONMENT" == *vulkan* ]]; then
source /var/lib/jenkins/vulkansdk/setup-env.sh
fi

# Example BUILD_ENVIRONMENT: linux-noble-rocm-py3.12-gfx1250
if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then
# hcc used to run out of memory, silently exiting without stopping
# the build process, leaving undefined symbols in the shared lib,
Expand All @@ -159,10 +160,23 @@ if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then
export MAX_JOBS=$(($(nproc) - 1))
fi

# Logic for multiple architectures based on the discriminator BUILD_ENVIRONMENT
# that is set by the workflow YAML and follows a consistent naming pattern.
if [[ -n "$CI" && -z "$PYTORCH_ROCM_ARCH" ]]; then
# Set ROCM_ARCH to gfx906 for CI builds, if user doesn't override.
echo "Limiting PYTORCH_ROCM_ARCH to gfx906 for CI builds"
export PYTORCH_ROCM_ARCH="gfx906"
if [[ "$BUILD_ENVIRONMENT" == *gfx1250* ]]; then
echo "Setting PYTORCH_ROCM_ARCH to gfx1250 for CI builds"
export PYTORCH_ROCM_ARCH="gfx1250"
elif [[ "$BUILD_ENVIRONMENT" == *mi355* ]] || [[ "$BUILD_ENVIRONMENT" == *gfx950* ]]; then
echo "Setting PYTORCH_ROCM_ARCH to gfx950 for CI builds"
export PYTORCH_ROCM_ARCH="gfx950"
elif [[ "$BUILD_ENVIRONMENT" == *mi300* ]] || [[ "$BUILD_ENVIRONMENT" == *gfx942* ]]; then
echo "Setting PYTORCH_ROCM_ARCH to gfx942 for CI builds"
export PYTORCH_ROCM_ARCH="gfx942"
else
# Set ROCM_ARCH to gfx906 for CI builds, if user doesn't override.
echo "Limiting PYTORCH_ROCM_ARCH to gfx906 for CI builds"
export PYTORCH_ROCM_ARCH="gfx906"
fi
fi

# hipify sources
Expand Down
2 changes: 1 addition & 1 deletion .circleci/scripts/binary_populate_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,7 @@ export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="${PYTORCH_EXTRA_INSTALL_REQUIREMENTS:
# TODO: We don't need this anymore IIUC
export TORCH_PACKAGE_NAME='torch'

export USE_FBGEMM=1
export USE_FBGEMM=0

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rraminen @pragupta I don't think this file is used by theRock either. Do we want to do this only for gfx1250? If yes, we probably need to do this elsewhere, and only for gfx1250, since USE_FBGEMM is set to ON for release/2.11 builds in theRock today.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we use a similar strategy to cb4e545?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jithunnair-amd ,
should this setting of USE_FBGEMM=0 be in the pytorch build workflow (github wf) of ROCk.

export PIP_UPLOAD_FOLDER="$PIP_UPLOAD_FOLDER"
export DOCKER_IMAGE="$DOCKER_IMAGE"

Expand Down
4 changes: 4 additions & 0 deletions .github/actionlint.yaml

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not needed either?

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For those files (not only this one), if we are sure they are not used, feel free to disregard or handle them in whatever way you think is best.

Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,13 @@ self-hosted-runner:
# gfx942 runners
- linux.rocm.gpu.gfx942.1
- linux.rocm.gpu.gfx942.4
- linux.rocm.gfx942.docker-cache
# gfx950 runners
- linux.rocm.gpu.gfx950.1
- linux.rocm.gpu.gfx950.4
# gfx1250 runners
- linux.rocm.gpu.gfx1250.1
- linux.rocm.gpu.gfx1250.4
# Org wise AWS `mac2.metal` runners (2020 Mac mini hardware powered by Apple silicon M1 processors)
- macos-m1-stable
- macos-m1-14
Expand Down
77 changes: 77 additions & 0 deletions .github/workflows/inductor-rocm-gfx1250.yml

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not needed at all?

Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
# The name of this file is subject to change to stay consistent with other .yml files.
#
# The MI355 workflow (.github/workflows/inductor-rocm-mi355.yml) uses:
# - _linux-build.yml and _rocm-test.yml reusable workflows
# - Build environment linux-noble-rocm-py3.12-mi355
# - Runner label linux.rocm.gpu.gfx950.1
# - Docker image ci-image:pytorch-linux-noble-rocm-n-py3
# - 2-shard test matrix for the inductor config
#
# The GFX1250 equivalent is following this exact pattern.

name: inductor-rocm-gfx1250

on:
push:
branches:
- main
- release/*
tags:
- ciflow/inductor-rocm-gfx1250/*
workflow_dispatch:

concurrency:
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }}
cancel-in-progress: true

permissions:
id-token: write
contents: read
actions: read

jobs:
target-determination:
if: github.repository_owner == 'pytorch'
name: before-test
uses: ./.github/workflows/target_determination.yml

get-label-type:
name: get-label-type
uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.11
if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }}
with:
triggering_actor: ${{ github.triggering_actor }}
issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }}
curr_branch: ${{ github.head_ref || github.ref_name }}
curr_ref_type: ${{ github.ref_type }}
opt_out_experiments: lf

linux-noble-rocm-py3_12-inductor-build:
name: linux-noble-rocm-py3.12-gfx1250
uses: ./.github/workflows/_linux-build.yml
needs: get-label-type
with:
runner_prefix: "${{ needs.get-label-type.outputs.label-type }}"
build-environment: linux-noble-rocm-py3.12-gfx1250
# Docker image stays the same as MI355 because ROCm image supports multiple arches.
docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3
# Set PYTORCH_ROCM_ARCH directly in the workflow YAML as an env variable,
# so build.sh never needs to parse BUILD_ENVIRONMENT.
#env-var-script: |
# export PYTORCH_ROCM_ARCH=gfx1250
test-matrix: |
{ include: [
{ config: "inductor", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx1250.1" }, # It requires provisioning hardware.
{ config: "inductor", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx1250.1" },
]}
secrets: inherit

linux-noble-rocm-py3_12-inductor-test:
name: linux-noble-rocm-py3.12-gfx1250
uses: ./.github/workflows/_rocm-test.yml
needs: linux-noble-rocm-py3_12-inductor-build
with:
build-environment: ${{ needs.linux-noble-rocm-py3_12-inductor-build.outputs.build-environment }}
docker-image: ${{ needs.linux-noble-rocm-py3_12-inductor-build.outputs.docker-image }}
test-matrix: ${{ needs.linux-noble-rocm-py3_12-inductor-build.outputs.test-matrix }}
secrets: inherit
10 changes: 7 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,7 @@ cmake_dependent_option(USE_CUSPARSELT "Use cuSPARSELt" ON "USE_CUDA" OFF)
cmake_dependent_option(USE_CUDSS "Use cuDSS" ON "USE_CUDA" OFF)
# USE_ROCM is guarded against in Dependencies.cmake because USE_ROCM is not properly defined here
cmake_dependent_option(USE_CUFILE "Use cuFile" ON "USE_CUDA AND NOT WIN32" OFF)
option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" ON)
option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" OFF)

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@pragupta @rraminen We probably don't want to turn this off wholesale for 2.11, and do it only for gfx1250 instead, if needed.

option(USE_KINETO "Use Kineto profiling library" ON)
option(USE_CUPTI_SO "Use CUPTI as a shared library" ON)
option(USE_GFLAGS "Use GFLAGS" OFF)
Expand Down Expand Up @@ -945,9 +945,13 @@ cmake_dependent_option(
OFF)


# TODO:
# MSLK related parts are missing that already exists upstream.
# gfx1250 for MSLK needs to be involved as well.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we please open an issue to track this? Because 2.11 branch uses MSLK by default, we are changing that here. We can probably just turn it off and use mslk until it's ready?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that is better to turn it off.
We need to have a ROCm fork branch of MSLK for 2.11.
FBGEMM (Li Li) was checking on this we should reach out to him.

IF(USE_ROCM AND ("gfx942" IN_LIST PYTORCH_ROCM_ARCH OR "gfx950" IN_LIST PYTORCH_ROCM_ARCH))
message(WARNING "Setting USE_MSLK for gfx942/gfx950 to ON by default, doing ROCM build")
set(USE_MSLK_DEFAULT ON)
message(WARNING "Setting USE_FBGEMM_GENAI for gfx942/gfx950 to ON by default, doing ROCM build")
set(USE_FBGEMM_GENAI_DEFAULT ON)
elseif(USE_CUDA AND "$ENV{TORCH_CUDA_ARCH_LIST}" MATCHES "10.0" AND CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8 AND NOT WIN32)
message(STATUS "Setting USE_MSLK to ON by default , doing CUDA build for SM100a")
set(USE_MSLK_DEFAULT ON)
Expand Down
15 changes: 11 additions & 4 deletions aten/src/ATen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -268,8 +268,15 @@ if(USE_FLASH_ATTENTION)
CK_ENABLE_FP64
CK_ENABLE_FP8
CK_ENABLE_INT8
CK_USE_FNUZ_FP8
CK_USE_GFX94
#CK_USE_FNUZ_FP8
#CK_USE_GFX94
CK_USE_GFX1250
CK_USE_NATIVE_MX_SUPPORT
CK_GFX1250_SUPPORT
CK_GFX12_SUPPORT
CK_USE_OCP_FP8
CK_USE_WMMA
CK_USE_WMMA_FP8
CK_USE_XDL
__HIP_PLATFORM_AMD__=1
__HIP_PLATFORM_HCC__=1
Expand Down Expand Up @@ -430,10 +437,10 @@ IF(USE_MSLK)
list(PREPEND MSLK_EXTRA_HIPCC_FLAGS -mllvm -amdgpu-coerce-illegal-types=1)
endif()

# Only compile for gfx942 and gfx950.
# Only compile for gfx942, gfx950, and gfx1250.
set(HIP_CLANG_FLAGS_ORIGINAL ${HIP_CLANG_FLAGS})
string(REGEX REPLACE "--offload-arch=[^ ]*" "" FILTERED_HIP_CLANG_FLAGS "${HIP_CLANG_FLAGS}")
foreach(ARCH gfx942 gfx950)
foreach(ARCH gfx942 gfx950 gfx1250)
if(${ARCH} IN_LIST PYTORCH_ROCM_ARCH)
list(APPEND FILTERED_HIP_CLANG_FLAGS --offload-arch=${ARCH})
endif()
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/Context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ at::BlasBackend Context::blasPreferredBackend() {
bool Context::ckSupported() {
#ifdef USE_ROCM
static const std::vector<std::string> supported_archs = {
"gfx90a", "gfx942", "gfx950"
"gfx90a", "gfx942", "gfx950", "gfx1250",
};
for (auto index : c10::irange(detail::getCUDAHooks().deviceCount())) {
if(!detail::getCUDAHooks().isGPUArch(supported_archs, index)) {
Expand Down
8 changes: 8 additions & 0 deletions aten/src/ATen/cuda/CUDABlas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2020,6 +2020,14 @@ void scaled_gemm(
"Got m=", m, ", n=", n, ", k=", k);
}
#endif
#if ROCM_VERSION >= 70200
if (at::detail::getCUDAHooks().isGPUArch({"gfx1250"})) {
// TODO: add constraints based on hipblaslt internals
TORCH_CHECK((m % 16 == 0) && (n % 16 == 0) && (k % 128 == 0),
"M, N must be multiples of 16 and K should be multiple of 128 for MX format. "
"Got m=", m, ", n=", n, ", k=", k);
}
#endif
}
#elif (CUDA_VERSION < 12090) && !defined(USE_ROCM)
// hipblaslt supported row-wise before cublas, and did so their own way (via
Expand Down
3 changes: 3 additions & 0 deletions aten/src/ATen/cuda/CUDAScaledBlas.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ static bool _scaled_mm_allowed_device(bool sm90_only=false, bool sm100_only=fals
#endif
#if ROCM_VERSION >= 60500
"gfx950"
#endif
#if ROCM_VERSION >= 70200
, "gfx1250"
#endif
};
return at::detail::getCUDAHooks().isGPUArch(archs);
Expand Down
4 changes: 2 additions & 2 deletions aten/src/ATen/cuda/CublasHandlePool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ size_t parseChosenWorkspaceSize() {
val = c10::utils::get_env("ROCBLAS_WORKSPACE_CONFIG");
}
/* 32MiB default, 128MiB for gfx94x/gfx95x */
const bool gfx94_95 = at::detail::getCUDAHooks().isGPUArch({"gfx94", "gfx95"});
const size_t default_size = gfx94_95 ? 1024 * 128 * 1024 : 1024 * 32 * 1024;
const bool gfx94_95_125 = at::detail::getCUDAHooks().isGPUArch({"gfx94", "gfx95", "gfx125"});
const size_t default_size = gfx94_95_125 ? 1024 * 128 * 1024 : 1024 * 32 * 1024;
#else
/* :4096:2:16:8 default, 32MiB for Hopper */
cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties();
Expand Down
5 changes: 4 additions & 1 deletion aten/src/ATen/cuda/detail/CUDAHooks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,10 @@ const std::vector<std::string>& CUDAHooks::getHipblasltPreferredArchs() const {
"gfx1200", "gfx1201",
#endif
#if ROCM_VERSION >= 70000
"gfx950"
"gfx950",
#endif
#if ROCM_VERSION >= 70200
"gfx1250"
#endif
};
return archs;
Expand Down
3 changes: 2 additions & 1 deletion aten/src/ATen/native/cuda/CUDALoops.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,8 @@ C10_LAUNCH_BOUNDS_1(num_threads())
__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
using traits = function_traits<func_t>;
constexpr auto io_size = calc_io_size<func_t>();
#if defined(USE_ROCM) && defined(__gfx942__)
// Extend the TWS (16) to GFX1250.
#if defined(USE_ROCM) && (defined(__gfx942__) || defined(__gfx1250__))
// Similar check in launch_vectorized_kernel() as well. Both should be in sync.
constexpr int tws = 16;
#else
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/GroupedBlas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -692,7 +692,7 @@ std::optional<c10::ScalarType> out_dtype) {
bool use_fast_path = false;
// ifdef USE_ROCM_CK_GEMM is required since ROCm systems w/o CK should not call ck path.
#if defined(USE_ROCM_CK_GEMM)
if (at::globalContext().rocmAllowGroupGemmCk() && at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950", "gfx90a"})) {
if (at::globalContext().rocmAllowGroupGemmCk() && at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950", "gfx90a", "gfx1250"})) {
use_fast_path = true;
}
#endif //USE_ROCM_CK_GEMM
Expand Down
3 changes: 2 additions & 1 deletion aten/src/ATen/native/cuda/KernelUtils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@

#if ROCM_VERSION < 60400
__device__ inline __hip_bfloat162 preview_unsafeAtomicAdd(__hip_bfloat162* address, __hip_bfloat162 value) {
#if (defined(__gfx942__)) && \
// `__gfx1250__`-specific `s_wait_loadcnt(0)` path for committed store already there
#if (defined(__gfx942__) || defined(__gfx1250__)) && \
__has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2bf16)
typedef unsigned short __attribute__((ext_vector_type(2))) vec_short2;
static_assert(sizeof(vec_short2) == sizeof(__hip_bfloat162_raw));
Expand Down
3 changes: 2 additions & 1 deletion aten/src/ATen/native/cuda/MemoryAccess.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,8 @@ template <int vec_size, typename scalar_t>
__device__ aligned_vector<scalar_t, vec_size> load_vector(const scalar_t *base_ptr, uint32_t offset) {
using vec_t = aligned_vector<scalar_t, vec_size>;
auto *from = reinterpret_cast<const vec_t *>(base_ptr);
#if defined(USE_ROCM) && defined(__gfx942__)
// Extend the non-temporal load optimization to GFX1250.
#if defined(USE_ROCM) && (defined(__gfx942__) || defined(__gfx1250__))
using longx2 = __attribute__((__vector_size__(4*sizeof(int)))) int;
if constexpr (sizeof(vec_t) == sizeof(int)) {
union {
Expand Down
41 changes: 41 additions & 0 deletions aten/src/ATen/native/cuda/RangeFactories.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
#include <ATen/native/RangeUtils.h>
#include <cmath>
#include <limits>
#if defined(USE_ROCM)
#include <algorithm>
#endif

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
Expand Down Expand Up @@ -48,12 +51,49 @@ __global__ void elementwise_kernel_with_index(index_t N, func_t f, typename func
}
}

#if defined(USE_ROCM)
// HIP does not support launches with gridDim.x * blockDim.x >= 2^32:
// depending on the ROCm version the launch returns
// hipErrorInvalidConfiguration or is accepted silently with the kernel
// never executing, leaving zero-initialized output. A grid-stride kernel
// with a fixed grid sized to device occupancy avoids the limit.
template<typename index_t, typename func_t>
C10_LAUNCH_BOUNDS_1(num_threads())
__global__ void elementwise_kernel_with_index_grid_stride(
index_t N, func_t f,
typename function_traits<func_t>::result_type *data) {
index_t idx = static_cast<index_t>(blockIdx.x) * blockDim.x + threadIdx.x;
const index_t stride = static_cast<index_t>(gridDim.x) * blockDim.x;
for (; idx < N; idx += stride) {
data[idx] = f(idx);
}
}
#endif

template<typename func_t>
void gpu_kernel_with_index(at::Tensor &output, func_t f) {
int64_t N = output.numel();
if (N == 0) {
return;
}
#if defined(USE_ROCM)
constexpr int blocks_per_sm = 4;
const int sm_count =
at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
const int64_t orig_grid = (N + block_work_size - 1) / block_work_size;
int64_t grid = std::min<int64_t>(
orig_grid, static_cast<int64_t>(sm_count) * blocks_per_sm);
grid = std::max<int64_t>(grid, 1);
auto stream = at::cuda::getCurrentCUDAStream();
using scalar_t = typename function_traits<func_t>::result_type;
if (N <= std::numeric_limits<int>::max()) {
elementwise_kernel_with_index_grid_stride<int><<<grid, num_threads(), 0, stream>>>(N, f, output.mutable_data_ptr<scalar_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
elementwise_kernel_with_index_grid_stride<int64_t><<<grid, num_threads(), 0, stream>>>(N, f, output.mutable_data_ptr<scalar_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
#else
int64_t grid = (N + block_work_size - 1) / block_work_size;
auto stream = at::cuda::getCurrentCUDAStream();
using scalar_t = typename function_traits<func_t>::result_type;
Expand All @@ -64,6 +104,7 @@ void gpu_kernel_with_index(at::Tensor &output, func_t f) {
elementwise_kernel_with_index<int64_t><<<grid, num_threads(), 0, stream>>>(N, f, output.mutable_data_ptr<scalar_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
#endif
}

} // namespace
Expand Down
Loading