From a400a902740031918abfe8e934073cc4b6ee8775 Mon Sep 17 00:00:00 2001 From: rraminen_amdeng Date: Mon, 15 Jun 2026 19:03:15 +0000 Subject: [PATCH 1/5] Cherry-pick bdbcbea8dbf09fb95685d499cd6b1de1e04fe4b0 (exclude CK submodule + triton pins) --- .ci/docker/build.sh | 2 +- .circleci/scripts/binary_populate_env.sh | 2 +- .github/actionlint.yaml | 4 ++++ CMakeLists.txt | 10 +++++++--- aten/src/ATen/CMakeLists.txt | 4 ++-- aten/src/ATen/Context.cpp | 2 +- aten/src/ATen/cuda/CUDABlas.cpp | 8 ++++++++ aten/src/ATen/cuda/CUDAScaledBlas.h | 3 +++ aten/src/ATen/cuda/CublasHandlePool.cpp | 4 ++-- aten/src/ATen/cuda/detail/CUDAHooks.cpp | 5 ++++- aten/src/ATen/native/cuda/CUDALoops.cuh | 3 ++- aten/src/ATen/native/cuda/GroupedBlas.cpp | 2 +- aten/src/ATen/native/cuda/KernelUtils.cuh | 3 ++- aten/src/ATen/native/cuda/MemoryAccess.cuh | 3 ++- aten/src/ATen/native/cuda/ScaledBlas.cpp | 15 +++++++++------ aten/src/ATen/native/cuda/int4mm.cu | 19 ++++++++++++++++++- .../ATen/native/sparse/cuda/cuSPARSELtOps.cpp | 4 ++-- .../native/transformers/cuda/sdp_utils.cpp | 1 + .../hip/flash_attn/ck/launch_kernel_pt.hpp | 2 +- c10/core/AllocatorConfig.h | 8 +++++++- cmake/External/aotriton.cmake | 5 +++++ test/test_cuda.py | 8 +++++--- test/test_linalg.py | 6 ++++++ torch/_inductor/config.py | 3 ++- torch/cuda/_utils.py | 13 ++++++++----- torch/testing/_internal/common_cuda.py | 15 +++++++++++---- torch/testing/_internal/common_distributed.py | 2 +- 27 files changed, 116 insertions(+), 40 deletions(-) diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index 7df6453c22da9..8e216eb1ea15a 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -190,7 +190,7 @@ case "$tag" in KATEX=yes UCX_COMMIT=${_UCX_COMMIT} UCC_COMMIT=${_UCC_COMMIT} - PYTORCH_ROCM_ARCH="gfx90a;gfx942;gfx950;gfx1100" + PYTORCH_ROCM_ARCH="gfx90a;gfx942;gfx950;gfx1100;gfx1250" if [[ $tag =~ "benchmarks" ]]; then INDUCTOR_BENCHMARKS=yes fi diff --git a/.circleci/scripts/binary_populate_env.sh 
b/.circleci/scripts/binary_populate_env.sh index c25e351768607..425c52a08d062 100755 --- a/.circleci/scripts/binary_populate_env.sh +++ b/.circleci/scripts/binary_populate_env.sh @@ -152,7 +152,7 @@ export PYTORCH_EXTRA_INSTALL_REQUIREMENTS="${PYTORCH_EXTRA_INSTALL_REQUIREMENTS: # TODO: We don't need this anymore IIUC export TORCH_PACKAGE_NAME='torch' -export USE_FBGEMM=1 +export USE_FBGEMM=0 export PIP_UPLOAD_FOLDER="$PIP_UPLOAD_FOLDER" export DOCKER_IMAGE="$DOCKER_IMAGE" diff --git a/.github/actionlint.yaml b/.github/actionlint.yaml index 8ef1ee2240a2e..e27610c780cb6 100644 --- a/.github/actionlint.yaml +++ b/.github/actionlint.yaml @@ -61,9 +61,13 @@ self-hosted-runner: # gfx942 runners - linux.rocm.gpu.gfx942.1 - linux.rocm.gpu.gfx942.4 + - linux.rocm.gfx942.docker-cache # gfx950 runners - linux.rocm.gpu.gfx950.1 - linux.rocm.gpu.gfx950.4 + # gfx1250 runners + - linux.rocm.gpu.gfx1250.1 + - linux.rocm.gpu.gfx1250.4 # Org wise AWS `mac2.metal` runners (2020 Mac mini hardware powered by Apple silicon M1 processors) - macos-m1-stable - macos-m1-14 diff --git a/CMakeLists.txt b/CMakeLists.txt index 188f08bb272f5..fcfcb8dae33e5 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -263,7 +263,7 @@ cmake_dependent_option(USE_CUSPARSELT "Use cuSPARSELt" ON "USE_CUDA" OFF) cmake_dependent_option(USE_CUDSS "Use cuDSS" ON "USE_CUDA" OFF) # USE_ROCM is guarded against in Dependencies.cmake because USE_ROCM is not properly defined here cmake_dependent_option(USE_CUFILE "Use cuFile" ON "USE_CUDA AND NOT WIN32" OFF) -option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" ON) +option(USE_FBGEMM "Use FBGEMM (quantized 8-bit server operators)" OFF) option(USE_KINETO "Use Kineto profiling library" ON) option(USE_CUPTI_SO "Use CUPTI as a shared library" ON) option(USE_GFLAGS "Use GFLAGS" OFF) @@ -945,9 +945,13 @@ cmake_dependent_option( OFF) +# TODO: +# MSLK related parts are missing that already exists upstream. +# gfx1250 for MSLK needs to be involved as well. 
+ IF(USE_ROCM AND ("gfx942" IN_LIST PYTORCH_ROCM_ARCH OR "gfx950" IN_LIST PYTORCH_ROCM_ARCH)) - message(WARNING "Setting USE_MSLK for gfx942/gfx950 to ON by default, doing ROCM build") - set(USE_MSLK_DEFAULT ON) + message(WARNING "Setting USE_FBGEMM_GENAI for gfx942/gfx950 to ON by default, doing ROCM build") + set(USE_FBGEMM_GENAI_DEFAULT ON) elseif(USE_CUDA AND "$ENV{TORCH_CUDA_ARCH_LIST}" MATCHES "10.0" AND CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8 AND NOT WIN32) message(STATUS "Setting USE_MSLK to ON by default , doing CUDA build for SM100a") set(USE_MSLK_DEFAULT ON) diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index f1ac6246f4d56..38f6eb201713a 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -430,10 +430,10 @@ IF(USE_MSLK) list(PREPEND MSLK_EXTRA_HIPCC_FLAGS -mllvm -amdgpu-coerce-illegal-types=1) endif() - # Only compile for gfx942 and gfx950. + # Only compile for gfx942, gfx950, and gfx1250. set(HIP_CLANG_FLAGS_ORIGINAL ${HIP_CLANG_FLAGS}) string(REGEX REPLACE "--offload-arch=[^ ]*" "" FILTERED_HIP_CLANG_FLAGS "${HIP_CLANG_FLAGS}") - foreach(ARCH gfx942 gfx950) + foreach(ARCH gfx942 gfx950 gfx1250) if(${ARCH} IN_LIST PYTORCH_ROCM_ARCH) list(APPEND FILTERED_HIP_CLANG_FLAGS --offload-arch=${ARCH}) endif() diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index c342590b58c42..ba63e73e1017c 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -505,7 +505,7 @@ at::BlasBackend Context::blasPreferredBackend() { bool Context::ckSupported() { #ifdef USE_ROCM static const std::vector supported_archs = { - "gfx90a", "gfx942", "gfx950" + "gfx90a", "gfx942", "gfx950", "gfx1250", }; for (auto index : c10::irange(detail::getCUDAHooks().deviceCount())) { if(!detail::getCUDAHooks().isGPUArch(supported_archs, index)) { diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index b9e66fea5ebdb..0d113df3d2e40 100644 --- 
a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -2020,6 +2020,14 @@ void scaled_gemm( "Got m=", m, ", n=", n, ", k=", k); } #endif + #if ROCM_VERSION >= 70200 + if (at::detail::getCUDAHooks().isGPUArch({"gfx1250"})) { + // TODO: add constraints based on hipblaslt internals + TORCH_CHECK((m % 16 == 0) && (n % 16 == 0) && (k % 128 == 0), + "M, N must be multiples of 16 and K should be multiple of 128 for MX format. " + "Got m=", m, ", n=", n, ", k=", k); + } + #endif } #elif (CUDA_VERSION < 12090) && !defined(USE_ROCM) // hipblaslt supported row-wise before cublas, and did so their own way (via diff --git a/aten/src/ATen/cuda/CUDAScaledBlas.h b/aten/src/ATen/cuda/CUDAScaledBlas.h index 3fd0e2a6a3aae..90dde0384b87a 100644 --- a/aten/src/ATen/cuda/CUDAScaledBlas.h +++ b/aten/src/ATen/cuda/CUDAScaledBlas.h @@ -67,6 +67,9 @@ static bool _scaled_mm_allowed_device(bool sm90_only=false, bool sm100_only=fals #endif #if ROCM_VERSION >= 60500 "gfx950" +#endif +#if ROCM_VERSION >= 70200 + , "gfx1250" #endif }; return at::detail::getCUDAHooks().isGPUArch(archs); diff --git a/aten/src/ATen/cuda/CublasHandlePool.cpp b/aten/src/ATen/cuda/CublasHandlePool.cpp index b5008bad832d0..75d9ae9ae586d 100644 --- a/aten/src/ATen/cuda/CublasHandlePool.cpp +++ b/aten/src/ATen/cuda/CublasHandlePool.cpp @@ -141,8 +141,8 @@ size_t parseChosenWorkspaceSize() { val = c10::utils::get_env("ROCBLAS_WORKSPACE_CONFIG"); } /* 32MiB default, 128MiB for gfx94x/gfx95x */ - const bool gfx94_95 = at::detail::getCUDAHooks().isGPUArch({"gfx94", "gfx95"}); - const size_t default_size = gfx94_95 ? 1024 * 128 * 1024 : 1024 * 32 * 1024; + const bool gfx94_95_125 = at::detail::getCUDAHooks().isGPUArch({"gfx94", "gfx95", "gfx125"}); + const size_t default_size = gfx94_95_125 ? 
1024 * 128 * 1024 : 1024 * 32 * 1024; #else /* :4096:2:16:8 default, 32MiB for Hopper */ cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties(); diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.cpp b/aten/src/ATen/cuda/detail/CUDAHooks.cpp index 03a3a97525a43..4dfc3003e0cdb 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.cpp +++ b/aten/src/ATen/cuda/detail/CUDAHooks.cpp @@ -558,7 +558,10 @@ const std::vector& CUDAHooks::getHipblasltPreferredArchs() const { "gfx1200", "gfx1201", #endif #if ROCM_VERSION >= 70000 - "gfx950" + "gfx950", +#endif +#if ROCM_VERSION >= 70200 + "gfx1250" #endif }; return archs; diff --git a/aten/src/ATen/native/cuda/CUDALoops.cuh b/aten/src/ATen/native/cuda/CUDALoops.cuh index b1031a95a3d0e..2d4846d1030bd 100644 --- a/aten/src/ATen/native/cuda/CUDALoops.cuh +++ b/aten/src/ATen/native/cuda/CUDALoops.cuh @@ -228,7 +228,8 @@ C10_LAUNCH_BOUNDS_1(num_threads()) __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) { using traits = function_traits; constexpr auto io_size = calc_io_size(); -#if defined(USE_ROCM) && defined(__gfx942__) + // Extend the TWS (16) to GFX1250. +#if defined(USE_ROCM) && (defined(__gfx942__) || defined(__gfx1250__)) // Similar check in launch_vectorized_kernel() as well. Both should be in sync. constexpr int tws = 16; #else diff --git a/aten/src/ATen/native/cuda/GroupedBlas.cpp b/aten/src/ATen/native/cuda/GroupedBlas.cpp index 70c33e27aa0a3..5a3ddb8ef94bf 100644 --- a/aten/src/ATen/native/cuda/GroupedBlas.cpp +++ b/aten/src/ATen/native/cuda/GroupedBlas.cpp @@ -692,7 +692,7 @@ std::optional out_dtype) { bool use_fast_path = false; // ifdef USE_ROCM_CK_GEMM is required since ROCm systems w/o CK should not call ck path. 
#if defined(USE_ROCM_CK_GEMM) - if (at::globalContext().rocmAllowGroupGemmCk() && at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950", "gfx90a"})) { + if (at::globalContext().rocmAllowGroupGemmCk() && at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950", "gfx90a", "gfx1250"})) { use_fast_path = true; } #endif //USE_ROCM_CK_GEMM diff --git a/aten/src/ATen/native/cuda/KernelUtils.cuh b/aten/src/ATen/native/cuda/KernelUtils.cuh index 12feeb6d63af3..cf36105fb6e84 100644 --- a/aten/src/ATen/native/cuda/KernelUtils.cuh +++ b/aten/src/ATen/native/cuda/KernelUtils.cuh @@ -13,7 +13,8 @@ #if ROCM_VERSION < 60400 __device__ inline __hip_bfloat162 preview_unsafeAtomicAdd(__hip_bfloat162* address, __hip_bfloat162 value) { -#if (defined(__gfx942__)) && \ +// `__gfx1250__`-specific `s_wait_loadcnt(0)` path for committed store already there +#if (defined(__gfx942__) || defined(__gfx1250__)) && \ __has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2bf16) typedef unsigned short __attribute__((ext_vector_type(2))) vec_short2; static_assert(sizeof(vec_short2) == sizeof(__hip_bfloat162_raw)); diff --git a/aten/src/ATen/native/cuda/MemoryAccess.cuh b/aten/src/ATen/native/cuda/MemoryAccess.cuh index 373b44cca7901..e96876c8ca149 100644 --- a/aten/src/ATen/native/cuda/MemoryAccess.cuh +++ b/aten/src/ATen/native/cuda/MemoryAccess.cuh @@ -187,7 +187,8 @@ template __device__ aligned_vector load_vector(const scalar_t *base_ptr, uint32_t offset) { using vec_t = aligned_vector; auto *from = reinterpret_cast(base_ptr); -#if defined(USE_ROCM) && defined(__gfx942__) + // Extend the non-temporal load optimization to GFX1250. 
+#if defined(USE_ROCM) && (defined(__gfx942__) || defined(__gfx1250__)) using longx2 = __attribute__((__vector_size__(4*sizeof(int)))) int; if constexpr (sizeof(vec_t) == sizeof(int)) { union { diff --git a/aten/src/ATen/native/cuda/ScaledBlas.cpp b/aten/src/ATen/native/cuda/ScaledBlas.cpp index 223f10c53a318..e6e0554a35b7d 100644 --- a/aten/src/ATen/native/cuda/ScaledBlas.cpp +++ b/aten/src/ATen/native/cuda/ScaledBlas.cpp @@ -78,6 +78,9 @@ static bool _scaled_mm_allowed_device(bool sm90_only=false, bool sm100_only=fals #endif #if ROCM_VERSION >= 60500 "gfx950" +#endif +#if ROCM_VERSION >= 70200 + , "gfx1250" #endif }; return at::detail::getCUDAHooks().isGPUArch(archs); @@ -623,8 +626,8 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, else if (scaling_choice_a == ScalingType::BlockWise1x32 && scaling_choice_b == ScalingType::BlockWise1x32) { #ifdef USE_ROCM #if ROCM_VERSION >= 70000 - TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950"}), - "Block-wise scaling for Float8_e8m0fnu is only supported on gfx950"); + TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950", "gfx1250"}), + "Block-wise scaling for Float8_e8m0fnu is only supported on gfx950/gfx1250"); int packed_factor = 1; if (mat1.scalar_type() == ScalarType::Float4_e2m1fn_x2) { @@ -1067,8 +1070,8 @@ _scaled_mxfp8_mxfp8( #ifdef USE_ROCM #if ROCM_VERSION >= 70000 - TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950"}), - "Block-wise scaling for Float8_e8m0fnu is only supported on gfx950"); + TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950", "gfx1250"}), + "Block-wise scaling for Float8_e8m0fnu is only supported on gfx950/gfx1250"); TORCH_CHECK_VALUE(mat_a.size(0) % 32 == 0 && mat_a.size(1) % 32 == 0 && mat_b.size(0) % 32 == 0 && mat_b.size(1) % 32 == 0, @@ -1153,8 +1156,8 @@ _scaled_mxfp4_mxfp4( auto scaling_choice_b = ScalingType::BlockWise1x32; #if ROCM_VERSION >= 70000 - 
TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950"}), - "Block-wise scaling for Float8_e8m0fnu is only supported on gfx950"); + TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950", "gfx1250"}), + "Block-wise scaling for Float8_e8m0fnu is only supported on gfx950/gfx1250"); TORCH_CHECK_VALUE(mat_a.size(0) % 32 == 0 && mat_a.size(1) % 32 == 0 && mat_b.size(0) % 32 == 0 && mat_b.size(1) % 32 == 0, diff --git a/aten/src/ATen/native/cuda/int4mm.cu b/aten/src/ATen/native/cuda/int4mm.cu index 8765bed83345a..07f3b9443ed40 100644 --- a/aten/src/ATen/native/cuda/int4mm.cu +++ b/aten/src/ATen/native/cuda/int4mm.cu @@ -127,7 +127,8 @@ inline __host__ __device__ uint32_t getAlignmentRoundUp(const void* p) { return diff == 0 ? 0 : uint32_t(Align) - diff; } -#if defined (__gfx90a__) || defined(__gfx942__) || defined(__gfx950__) +// CDNA arch with MFMA and Warp-32 support +#if defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__) #define CDNA2_OR_LATER 1 #else #define CDNA2_OR_LATER 0 @@ -146,6 +147,12 @@ static bool isCDNA2orLater(int index) { return at::detail::getCUDAHooks().isGPUArch({"gfx90a", "gfx942", "gfx950"}, index); } +// Conceptual for now and subject to change +// gfx1250 (CDNA5 / CDNA-next / UDNA) +static bool isCDNA5orLater(int index) { + return at::detail::getCUDAHooks().isGPUArch({"gfx1250"}, index); +} + #else constexpr int32_t kWarpSize = 32; #endif @@ -1098,6 +1105,11 @@ at::Tensor _weight_int4pack_mm_cuda( A.device() == B.device() && A.device() == qScaleAndZeros.device()); #if defined(USE_ROCM) + if (isCDNA5orLater(A.device().index())) { + TORCH_CHECK(false, + "_weight_int4pack_mm_cuda is not yet supported on gfx1250. 
" + "A WMMA-based implementation is required for gfx1250.") + } if (!isCDNA2orLater(A.device().index())) { TORCH_CHECK(false, "_weight_int4pack_mm_cuda is only supported on AMD gpu arch greater than or equal to CDNA2"); } @@ -1293,6 +1305,11 @@ at::Tensor _convert_weight_to_int4pack_cuda( TORCH_CHECK(innerKTiles == 2 || innerKTiles == 4 || innerKTiles == 8); #if defined(USE_ROCM) + if (isCDNA5orLater(in.device().index())) { + TORCH_CHECK(false, + "_convert_weight_to_int4pack_cuda is not yet supported on gfx1250. " + "A WMMA-based implementation is required for gfx1250.") + } if (!isCDNA2orLater(in.device().index())) { TORCH_CHECK(false, "_convert_weight_to_int4pack_cuda is only supported on AMD gpu arch greater than or equal to CDNA2"); } diff --git a/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp b/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp index 9d735ac0f2c88..934c6a7c91403 100644 --- a/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp +++ b/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp @@ -30,7 +30,7 @@ static void initHipSparseLtSupport() { // Check only the first available device try { if (at::cuda::device_count() > 0) { - g_hipSparseLtSupported = at::detail::getCUDAHooks().isGPUArch({"gfx950", "gfx942"}, 0); + g_hipSparseLtSupported = at::detail::getCUDAHooks().isGPUArch({"gfx950", "gfx942", "gfx1250"}, 0); } } catch (const std::exception&) { // If an exception occurs during device property check, we assume hipSparseLt is not supported @@ -49,7 +49,7 @@ static bool isHipSparseLtSupported() { TORCH_CHECK( false, "hipSparseLt not supported on this device, supported architectures: " - "gfx950, gfx942. " + "gfx1250, gfx950, gfx942. 
" "required ROCM version: 6.4.0 or later."); } return g_hipSparseLtSupported; diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index 67a5a296e2afe..ad20704e47c78 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -301,6 +301,7 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug using sm121 = SMVersion<12, 1>; #if USE_ROCM #if USE_ROCM_ATTENTION +// TODO: gfx1250 if(at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) { // User explicitly set CK as the flash attention backend. Return true for now // TODO: Flesh out sanity checks diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/launch_kernel_pt.hpp b/aten/src/ATen/native/transformers/hip/flash_attn/ck/launch_kernel_pt.hpp index 5f7a16cffa1c7..3838b3e2edf2b 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/launch_kernel_pt.hpp +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/launch_kernel_pt.hpp @@ -27,7 +27,7 @@ __launch_bounds__(Kernel::kBlockSize, MinBlockPerCu) #endif __global__ void kentry_pt(Args... args) { -#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) +#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__) || defined(__gfx1250__)) Kernel{}(args...); #else CUDA_KERNEL_ASSERT(false && "Fatal! 
Attempting to call a CK SDPA kernel on unsupported hardware"); diff --git a/c10/core/AllocatorConfig.h b/c10/core/AllocatorConfig.h index d314c93f0494f..ef4e798939c5e 100644 --- a/c10/core/AllocatorConfig.h +++ b/c10/core/AllocatorConfig.h @@ -19,7 +19,13 @@ constexpr size_t kMinBlockSize = 512; // largest "small" allocation is 1 MiB constexpr size_t kSmallSize = 1048576; // allocations between 1 and 10 MiB may use kLargeBuffer -constexpr size_t kMinLargeAlloc = 10485760; +#if defined(USE_ROCM) && defined(__gfx1250__) +// Increase the buffer threshold for gfx1250 +// to avoid fragmentation on 432GB devices. +constexpr size_t kMinLargeAlloc = 20 * 1024 * 1024; // 20 MiB +#else +constexpr size_t kMinLargeAlloc = 10485760; // 10 MiB +#endif // round up large allocations to 2 MiB constexpr size_t kRoundLarge = 2097152; diff --git a/cmake/External/aotriton.cmake b/cmake/External/aotriton.cmake index 93f852f014887..7b70d0a2107d6 100644 --- a/cmake/External/aotriton.cmake +++ b/cmake/External/aotriton.cmake @@ -41,6 +41,8 @@ if(NOT __AOTRITON_INCLUDED) "amd-gfx950" "amd-gfx11xx" "amd-gfx120x" + # TODO: Update on AOTriton integration + #"amd-gfx1250" ) set(__AOTRITON_IMAGE_SHA256_LIST "fe9f04b66bf52ac27cd025e1d89cfd04974dd3fb3ae076192f783641a4d80fdf" # amd-gfx90a @@ -48,6 +50,9 @@ if(NOT __AOTRITON_INCLUDED) "c1ba3bfe84217fd67df3dd1f8b67c80a7f7b33d0ad4d74b41d6567036e032ace" # amd-gfx950 "839299637fccb13fbe3e7823d57d1b2dcd0e0bed78abbcb7005ea5f4fd82b928" # amd-gfx11xx "0a4ff324bffdac0c2fde87a8a7f70563d3c84a80ad4e8f31345f2b40a1384e95" # amd-gfx120x + # TODO: Update when AOTriton publishes gfx1250 images. + # Until then, may need to set AOTRITON_INSTALL_FROM_SOURCE=1 to build from source. 
+ #"0000000000000000000000000000000000000000000000000000000000000000" # amd-gfx1250 ) set(__AOTRITON_BASE_URL "https://github.com/ROCm/aotriton/releases/download/") # @lint-ignore set(__AOTRITON_Z "gz") diff --git a/test/test_cuda.py b/test/test_cuda.py index 5dd2a7346c79b..c068c44f65656 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -684,7 +684,7 @@ def _check_default(): gcn_arch = str( torch.cuda.get_device_properties(0).gcnArchName.split(":", 1)[0] ) - if gcn_arch in ["gfx90a", "gfx942", "gfx950", "gfx1200", "gfx1201"]: + if gcn_arch in ["gfx90a", "gfx942", "gfx950", "gfx1200", "gfx1201", "gfx1250"]: self.assertTrue(default == torch._C._BlasBackend.Cublaslt) else: self.assertTrue(default == torch._C._BlasBackend.Cublas) @@ -754,7 +754,7 @@ def test_cublas_workspace_explicit_allocation(self): gcn_arch = str( torch.cuda.get_device_properties(0).gcnArchName.split(":", 1)[0] ) - if "gfx94" in gcn_arch or "gfx95" in gcn_arch: + if "gfx94" in gcn_arch or "gfx95" in gcn_arch or "gfx1250" in gcn_arch: default_workspace_size = 1024 * 128 * 1024 # :1024:128 else: default_workspace_size = ( @@ -8012,7 +8012,9 @@ def test_compile_kernel_large_shared_memory(self): # Test error handling with more than supported shared memory size if torch.version.hip: max_smem = ( - 65536 if get_device_properties().gcnArchName != "gfx950" else 160 * 1024 + 65536 + if get_device_properties().gcnArchName not in ["gfx950", "gfx1250"] + else 160 * 1024 ) else: max_smem = get_device_properties().shared_memory_per_block_optin diff --git a/test/test_linalg.py b/test/test_linalg.py index 25a157343db15..403af3890cd31 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -65,6 +65,12 @@ def blaslt_supported_device(): archs.extend(['gfx110', 'gfx120']) if ROCM_VERSION >= (6, 5): archs.append('gfx95') + # We extend this in a way not assuming the exact ROCm version + # where gfx1250 lands, so we don't pin a ROCm_VERSION gate here + # because the exact landing version may shift. 
+ # Prefix checks rather than exact matches treat MI-series / gfx1250 + # as BLASLt-capabile. + archs.append('gfx125') for arch in archs: if arch in torch.cuda.get_device_properties(0).gcnArchName: return True diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index e2fee26f45cc1..da4a60574c37c 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -2201,10 +2201,11 @@ class rocm: # Enable the CK backend for CDNA2 and CDNA3 only (for now) # Processor name reference: https://llvm.org/docs/AMDGPUUsage.html#processors - ck_supported_arch: list[Literal["gfx90a", "gfx942", "gfx950"]] = [ + ck_supported_arch: list[Literal["gfx90a", "gfx942", "gfx950", "gfx1250"]] = [ "gfx90a", "gfx942", "gfx950", + "gfx1250", ] # Optimization level, use to balance compilation speed and runtime performance. diff --git a/torch/cuda/_utils.py b/torch/cuda/_utils.py index d842e8b56ef41..8797c1e1ff178 100644 --- a/torch/cuda/_utils.py +++ b/torch/cuda/_utils.py @@ -435,11 +435,14 @@ def set_shared_memory_config(self, shared_mem_bytes: int) -> None: device_props = torch.cuda.get_device_properties() # HIP doesn't have shared_memory_per_block_optin in device properties, so we hard-code it here if torch.version.hip: - # navi, CDNA1-CDNA3 allows a max of 64KB shared memory - # CDNA4 allows a max of 160KB shared memory - max_shared_mem = ( - 65536 if device_props.gcnArchName != "gfx950" else 160 * 1024 - ) + # navi, CDNA1-CDNA3 allows a max of 64KB shared memory, + # CDNA4 (gfx950) 160KB, and CDNA5 (gfx1250) 320KB. 
+ if device_props.gcnArchName == "gfx950": + max_shared_mem = 160 * 1024 + elif device_props.gcnArchName == "gfx1250": + max_shared_mem = 320 * 1024 + else: + max_shared_mem = 65536 else: max_shared_mem = getattr( device_props, "shared_memory_per_block_optin", 49152 diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py index 3f539d586e8bc..840d69c1eb8b0 100644 --- a/torch/testing/_internal/common_cuda.py +++ b/torch/testing/_internal/common_cuda.py @@ -53,6 +53,10 @@ def evaluate_gfx_arch_within(arch_list): # Hence the matching should be done reversely return any(arch in effective_arch for arch in arch_list) +# CDNA 5 (CDNA-next / UDAN) arch helper +def CDNA5OrLater(): + return evaluate_gfx_arch_within(["gfx1250"]) + def CDNA3OrLater(): return evaluate_gfx_arch_within(["gfx942", "gfx950"]) @@ -61,7 +65,7 @@ def CDNA2OrLater(): def evaluate_platform_supports_flash_attention(): if TEST_WITH_ROCM: - arch_list = ["gfx90a", "gfx942", "gfx1100", "gfx1201", "gfx950"] + arch_list = ["gfx90a", "gfx942", "gfx1100", "gfx1201", "gfx950", "gfx1250"] if os.environ.get("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "0") != "0": arch_list += ["gfx1101", "gfx1102", "gfx1150", "gfx1151", "gfx1200"] return evaluate_gfx_arch_within(arch_list) @@ -71,7 +75,7 @@ def evaluate_platform_supports_flash_attention(): def evaluate_platform_supports_efficient_attention(): if TEST_WITH_ROCM: - arch_list = ["gfx90a", "gfx942", "gfx1100", "gfx1201", "gfx950"] + arch_list = ["gfx90a", "gfx942", "gfx1100", "gfx1201", "gfx950", "gfx1250"] if os.environ.get("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "0") != "0": arch_list += ["gfx1101", "gfx1102", "gfx1150", "gfx1151", "gfx1200"] return evaluate_gfx_arch_within(arch_list) @@ -139,6 +143,8 @@ def evaluate_platform_supports_fp8(): archs.extend(['gfx120']) if ROCM_VERSION >= (6, 5): archs.append('gfx95') + if ROCM_VERSION >= (7, 2): + archs.append('gfx1250') for arch in archs: if arch in 
torch.cuda.get_device_properties(0).gcnArchName: return True @@ -151,7 +157,7 @@ def evaluate_platform_supports_fp8_grouped_gemm(): if torch.version.hip: if "USE_MSLK" not in torch.__config__.show(): return False - archs = ['gfx942', 'gfx950'] + archs = ['gfx942', 'gfx950', 'gfx1250'] for arch in archs: if arch in torch.cuda.get_device_properties(0).gcnArchName: return True @@ -163,7 +169,8 @@ def evaluate_platform_supports_mx_gemm(): if torch.cuda.is_available(): if torch.version.hip: if ROCM_VERSION >= (7, 0): - return 'gfx950' in torch.cuda.get_device_properties(0).gcnArchName + gcn_name = torch.cuda.get_device_properties(0).gcnArchName + return 'gfx950' in gcn_name or ('gfx1250' in gcn_name and ROCM_VERSION >= (7, 2)) else: return SM100OrLater return False diff --git a/torch/testing/_internal/common_distributed.py b/torch/testing/_internal/common_distributed.py index 1705cd1398e56..06b07fa654f88 100644 --- a/torch/testing/_internal/common_distributed.py +++ b/torch/testing/_internal/common_distributed.py @@ -493,7 +493,7 @@ def requires_multicast_support(): def evaluate_platform_supports_symm_mem(): if TEST_CUDA: if TEST_WITH_ROCM: - arch_list = ["gfx942", "gfx950"] + arch_list = ["gfx942", "gfx950", "gfx1250"] for arch in arch_list: if arch in torch.cuda.get_device_properties(0).gcnArchName: return True From 611da0622aa73e2c0f1bd0da9babc33835d0fd66 Mon Sep 17 00:00:00 2001 From: "Li, Bo" Date: Thu, 30 Apr 2026 12:03:31 -0500 Subject: [PATCH 2/5] CK - gfx1250 supports (#14) * CK - gfx1250 support (#5) * Enable ROCM_CK_SDPA build * [submodule] composable_kernel and aiter update (#172592) Summary: update ck to commit https://github.com/ROCm/composable_kernel/commit/fcc9372c009c8e0a23fece77b582da83b04a654f update aiter to commit https://github.com/ROCm/aiter/commit/9a469a608b2c10b7157df573a38d31e5bf4038b4 changes of caffe2/aten/src/ATen/CMakeLists.txt and caffe2/caffe2/CMakeLists.txt are adopted from https://github.com/pytorch/pytorch/pull/161759 updated 
caffe2/aten/src/ATen/native/transformers/hip/flash_attn/ck/launch_kernel_pt.hpp to match the ck version in https://github.com/ROCm/composable_kernel/blob/292df2719f28cd01464d5d059820684790c101da/include/ck_tile/host/kernel_launch.hpp update aiter fav3 bwd codegen according to changes in https://github.com/ROCm/aiter/pull/1573 update caffe2/aten/src/ATen/native/transformers/hip/flash_attn/ck mha fwd/bwd kernels according to the interfaces in https://github.com/ROCm/composable_kernel/tree/292df2719f28cd01464d5d059820684790c101da/example/ck_tile/01_fmha Differential Revision: D88991877 Pull Request resolved: https://github.com/pytorch/pytorch/pull/172592 Approved by: https://github.com/alugorey, https://github.com/izaitsevfb * Added MI450 supports and packages * Fix misaligned ck api * Replace aiter with ck for bwd * [ROCm] Bump AOTriton to 0.11.2b (#174105) Notable new features: * AOTriton 0.11.2b adds gfx1151/1152/1153 support. * Add precompiled AOTriton runtime for ROCM 7.2 * Match the sliding window attention behavior of `_flash_attention_forward/backward` with CUTLASS backend. Bug fixes: * Fixes #173204. 
Now all tests in `test/test_varlen_attention.py` are enabled on ROCm Notes: This replaces PR #173820 and #173469 Pull Request resolved: https://github.com/pytorch/pytorch/pull/174105 Approved by: https://github.com/jeffdaily * Fix philox data types for this version of ck * Update CK to use new gfx1250_pytorch branch * Add new gfx1250 compile flags for CK * add --targets to generate and a couple new compile flags * Remove default USE_ROCM_CK_SDPA --------- Co-authored-by: blorange-amd Co-authored-by: Yu Guo Co-authored-by: Xinya Zhang * Updated aiter module * Fixed merged error * Fixed additional merged error * Reset USE_ROCM_CK_SDPA config --------- Co-authored-by: LugoReyes, Andy Co-authored-by: Yu Guo Co-authored-by: Xinya Zhang --- aten/src/ATen/CMakeLists.txt | 11 ++- .../hip/flash_attn/ck/CMakeLists.txt | 16 ++-- .../hip/flash_attn/ck/fav_v3/CMakeLists.txt | 2 +- .../hip/flash_attn/ck/launch_kernel_pt.hpp | 2 +- .../hip/flash_attn/ck/mha_bwd_ck.hip | 79 +++++++++++-------- .../hip/flash_attn/ck/mha_fwd_ck.hip | 30 +++++-- .../hip/flash_attn/ck/mha_varlen_fwd_ck.hip | 23 +++++- requirements.txt | 1 + third_party/aiter | 2 +- 9 files changed, 111 insertions(+), 55 deletions(-) diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index 38f6eb201713a..b46ccff1ccb7a 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -268,8 +268,15 @@ if(USE_FLASH_ATTENTION) CK_ENABLE_FP64 CK_ENABLE_FP8 CK_ENABLE_INT8 - CK_USE_FNUZ_FP8 - CK_USE_GFX94 + #CK_USE_FNUZ_FP8 + #CK_USE_GFX94 + CK_USE_GFX1250 + CK_USE_NATIVE_MX_SUPPORT + CK_GFX1250_SUPPORT + CK_GFX12_SUPPORT + CK_USE_OCP_FP8 + CK_USE_WMMA + CK_USE_WMMA_FP8 CK_USE_XDL __HIP_PLATFORM_AMD__=1 __HIP_PLATFORM_HCC__=1 diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/CMakeLists.txt b/aten/src/ATen/native/transformers/hip/flash_attn/ck/CMakeLists.txt index 819880cf3bc5c..3fb754cb8bf14 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/CMakeLists.txt +++ 
b/aten/src/ATen/native/transformers/hip/flash_attn/ck/CMakeLists.txt @@ -1,6 +1,6 @@ # generate a list of kernels, but not actually emit files at config stage execute_process( - COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py + COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --targets gfx1250 --api fwd --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_blob_list.txt RESULT_VARIABLE ret ) @@ -10,7 +10,7 @@ if(ret AND NOT ret EQUAL 0) endif() execute_process( - COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py + COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --targets gfx1250 --api fwd_splitkv --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_splitkv_blob_list.txt RESULT_VARIABLE ret ) @@ -20,7 +20,7 @@ if(ret AND NOT ret EQUAL 0) endif() execute_process( - COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py + COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --targets gfx1250 --api fwd_appendkv --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/fwd_appendkv_blob_list.txt RESULT_VARIABLE ret ) @@ -30,7 +30,7 @@ if(ret AND NOT ret EQUAL 0) endif() execute_process( - COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py + COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --targets gfx1250 --api bwd --receipt 4 --list_blobs ${CMAKE_CURRENT_LIST_DIR}/bwd_blob_list.txt RESULT_VARIABLE ret ) @@ -40,28 +40,28 @@ if(ret AND NOT ret EQUAL 0) endif() # Generate the files for both fwd, fwd_splitkv, fwd_appendkv, and bwd -execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd --receipt 4 
--output_dir ${CMAKE_CURRENT_LIST_DIR} +execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --targets gfx1250 --api fwd --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR} ) if(ret AND NOT ret EQUAL 0) message( FATAL_ERROR "CK Tile FMHA FAILED to generate FWD kernels.") endif() -execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd_splitkv --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR} +execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --targets gfx1250 --api fwd_splitkv --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR} ) if(ret AND NOT ret EQUAL 0) message( FATAL_ERROR "CK Tile FMHA FAILED to generate FWD_SPLITKV kernels.") endif() -execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api fwd_appendkv --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR} +execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --targets gfx1250 --api fwd_appendkv --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR} ) if(ret AND NOT ret EQUAL 0) message( FATAL_ERROR "CK Tile FMHA FAILED to generate FWD_APPENDKV kernels.") endif() -execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --api bwd --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR} +execute_process(COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --targets gfx1250 --api bwd --receipt 4 --output_dir ${CMAKE_CURRENT_LIST_DIR} RESULT_VARIABLE ret ) diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fav_v3/CMakeLists.txt b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fav_v3/CMakeLists.txt index e0eb28652c1e6..2bed9e565d200 100644 --- 
a/aten/src/ATen/native/transformers/hip/flash_attn/ck/fav_v3/CMakeLists.txt +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/fav_v3/CMakeLists.txt @@ -2,7 +2,7 @@ include(CMakePrintHelpers) # Generate AITER/CK Asm code execute_process( - COMMAND ${CMAKE_COMMAND} -E env "AITER_GPU_ARCHS=gfx942;gfx950" + COMMAND ${CMAKE_COMMAND} -E env "AITER_GPU_ARCHS=gfx942;gfx950;gfx1250" python3 ${CMAKE_SOURCE_DIR}/third_party/aiter/hsa/codegen.py -m fmha_v3_bwd --output_dir ${CMAKE_CURRENT_LIST_DIR} RESULT_VARIABLE ret ) diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/launch_kernel_pt.hpp b/aten/src/ATen/native/transformers/hip/flash_attn/ck/launch_kernel_pt.hpp index 3838b3e2edf2b..6bbcf09d5c92d 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/launch_kernel_pt.hpp +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/launch_kernel_pt.hpp @@ -14,7 +14,7 @@ __launch_bounds__(Kernel::kBlockSize, MinBlockPerCu) #endif __global__ void kentry_pt(Args... args) { -#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)) +#if (defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__) || defined(__gfx1250__)) Kernel{}(args...); #else CUDA_KERNEL_ASSERT(false && "Fatal! 
Attempting to call a CK SDPA kernel on unsupported hardware"); diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_bwd_ck.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_bwd_ck.hip index 2d3692d9f98df..73b48585d5371 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_bwd_ck.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_bwd_ck.hip @@ -3,13 +3,32 @@ ******************************************************************************/ #include -#include #include #include namespace pytorch_flash { -aiter::mha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask, +fmha_bwd_traits get_ck_fmha_bwd_traits(const mask_info &mask, + std::string dtype, + int head_size, + bool has_dropout, + bool enable_bias, + bool deterministic, + bool bias_requires_grad) +{ + return fmha_bwd_traits{head_size, + head_size, + dtype, + false, // is_group_mode (batch mode) + mask.type, + enable_bias ? bias_enum::elementwise_bias : bias_enum::no_bias, + bias_requires_grad, + has_dropout, + false, // is_store_randval + deterministic}; +} + +fmha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask, std::string dtype, bool has_dropout, bool enable_bias, @@ -124,27 +143,7 @@ aiter::mha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask, float p_undrop = 1.0 - p_dropout; - return aiter::mha_bwd_args{ - // aiter args - static_cast(mask.type), - true, // use_asm_v3 - true, // v3_atomic_fp32 - 1, // v3_bf16_cvt - false, // v3_api_check - - // From ck fmha_bwd_traits - hdim, // hdim_q - hdim, // hdim_v - dtype, // data_type - false, // is_group_mode - static_cast(mask.type), // ck_mask_type - enable_bias ? 
static_cast(bias_enum::elementwise_bias) : static_cast(bias_enum::no_bias), - bias_requires_grad, // has_dbias - has_dropout, - false, // is_store_randval - deterministic, // is_deterministic - - // From ck fmha_bwd_args + return fmha_bwd_args{ q.data_ptr(), k.data_ptr(), v.data_ptr(), @@ -167,16 +166,18 @@ aiter::mha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask, nullptr, // cu_seqlen_k_ptr seqlen_q, seqlen_k, - b, // batch + b, seqlen_q, // max_seqlen_q seqlen_k, // max_seqlen_k - h, // nhead_q - h_k, // nhead_k - softmax_scale, // scale + hdim, // hdim_q + hdim, // hdim_v + h, // nhead_q + h_k, // nhead_k + softmax_scale, stride_q, stride_k, stride_v, - stride_attn_bias, // stride_bias + stride_attn_bias, stride_o, 0, // stride_randval stride_do, @@ -212,9 +213,10 @@ aiter::mha_bwd_args get_ck_fmha_bwd_args(const mask_info &mask, batch_stride_dv, batch_stride_dbias, split_stride_dq_acc, - mask.left, // window_size_left - mask.right, // window_size_right - p_dropout, // p_drop + mask.left, + mask.right, + static_cast(mask.type), + p_dropout, p_undrop, drop_seed_offset }; @@ -376,8 +378,8 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x } uint64_t drop_seed = 1, drop_offset = 0; - drop_seed = *philox_seed.data_ptr(); - drop_offset = *philox_offset.data_ptr(); + drop_seed = *philox_seed.data_ptr(); + drop_offset = *philox_offset.data_ptr(); auto drop_seed_offset = std::make_pair(&drop_seed, &drop_offset); @@ -385,6 +387,15 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x ck_tile::stream_config stream_config{stream}; dq.zero_(); // ck use atomic operation on dq + auto traits = + get_ck_fmha_bwd_traits(mask, + q_dtype_str, + head_size_8x, + is_dropout, + attn_bias_.has_value(), + deterministic, + bias_requires_grad); + auto args = get_ck_fmha_bwd_args( mask, @@ -416,7 +427,7 @@ mha_bwd_ck(const at::Tensor &dout, // batch_size x seqlen_q x p_dropout, drop_seed_offset); - float t = aiter::mha_bwd(args, stream_config); + float t = 
fmha_bwd(traits, args, stream_config); TORCH_CHECK(t >= 0, "invalid argument for fmha_bwd"); } else { diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_fwd_ck.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_fwd_ck.hip index 441589e70d763..a584591d6b18b 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_fwd_ck.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_fwd_ck.hip @@ -87,6 +87,15 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse, ck_tile::index_t batch_stride_bias = 0; ck_tile::index_t nhead_stride_bias = 0; + ck_tile::index_t nhead_stride_q_descale = 0; + ck_tile::index_t nhead_stride_k_descale = 0; + ck_tile::index_t nhead_stride_v_descale = 0; + ck_tile::index_t batch_stride_q_descale = 0; + ck_tile::index_t batch_stride_k_descale = 0; + ck_tile::index_t batch_stride_v_descale = 0; + ck_tile::index_t block_scale_size_q = 0; + ck_tile::index_t block_scale_size_kv = 0; + if (attn_bias_.has_value()) { auto a_b = attn_bias_.value(); CHECK_DEVICE(a_b); @@ -112,7 +121,8 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse, nullptr, // seqlen_k_ptr nullptr, // cu_seqlen_q_ptr nullptr, // cu_seqlen_k_ptr - nullptr, // sink_ptr + nullptr, // block_scale_seqstart_q_ptr + nullptr, // block_scale_seqstart_k_ptr seqlen_q, seqlen_k, b, @@ -136,6 +146,9 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse, nhead_stride_randval, nhead_stride_lse, nhead_stride_o, + nhead_stride_q_descale, + nhead_stride_k_descale, + nhead_stride_v_descale, batch_stride_q, batch_stride_k, batch_stride_v, @@ -143,6 +156,9 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse, batch_stride_randval, batch_stride_lse, batch_stride_o, + batch_stride_q_descale, + batch_stride_k_descale, + batch_stride_v_descale, mask.left, mask.right, 0, // sink_size @@ -150,7 +166,9 @@ fmha_fwd_args get_ck_fmha_fwd_args(bool has_lse, -1, // min_seqlen_q p_dropout, has_dropout_randval, - drop_seed_offset}; + drop_seed_offset, + block_scale_size_q, + 
block_scale_size_kv}; } std::tuple @@ -306,13 +324,13 @@ mha_fwd_ck(const at::Tensor &q, // batch_size x seqlen_q x hipLaunchKernelGGL( flash::ParsePhiloxCudaState, dim3(1), dim3(64), 0, at::cuda::getCurrentCUDAStream(), philox_args, rng_state_ptr); - seed_t = at::scalar_tensor(at::Scalar(static_cast(rng_state_ptr[0])), at::dtype(at::kLong)); - offset_t = at::scalar_tensor(at::Scalar(static_cast(rng_state_ptr[1])), at::dtype(at::kLong)); + seed_t = at::scalar_tensor(at::Scalar(static_cast(rng_state_ptr[0])), at::dtype(at::kUInt64)); + offset_t = at::scalar_tensor(at::Scalar(static_cast(rng_state_ptr[1])), at::dtype(at::kUInt64)); } else { - seed_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); - offset_t = at::empty({}, at::dtype(at::kLong).device(at::kCUDA)); + seed_t = at::empty({}, at::dtype(at::kUInt64).device(at::kCUDA)); + offset_t = at::empty({}, at::dtype(at::kUInt64).device(at::kCUDA)); } std::optional attn_bias; diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_fwd_ck.hip b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_fwd_ck.hip index 2a9d4899e8236..a7f7fe101aee6 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_fwd_ck.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/ck/mha_varlen_fwd_ck.hip @@ -84,6 +84,16 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, ck_tile::index_t batch_stride_lse = has_lse ? 
softmax_lse.stride(0) : 0; ck_tile::index_t batch_stride_randval = 0; + + ck_tile::index_t nhead_stride_q_descale = 0; + ck_tile::index_t nhead_stride_k_descale = 0; + ck_tile::index_t nhead_stride_v_descale = 0; + ck_tile::index_t batch_stride_q_descale = 0; + ck_tile::index_t batch_stride_k_descale = 0; + ck_tile::index_t batch_stride_v_descale = 0; + ck_tile::index_t block_scale_size_q = 0; + ck_tile::index_t block_scale_size_kv = 0; + void *attn_bias_ptr = nullptr; ck_tile::index_t stride_attn_bias = 0; @@ -113,7 +123,8 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, nullptr, // seqlen_k_ptr nullptr, // cu_seqlen_q_ptr nullptr, // cu_seqlen_k_ptr - nullptr, // sink_ptr + nullptr, // block_scale_seqstart_q_ptr + nullptr, // block_scale_seqstart_k_ptr total_q, total_k, b, @@ -137,6 +148,9 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, nhead_stride_randval, nhead_stride_lse, nhead_stride_o, + nhead_stride_q_descale, + nhead_stride_k_descale, + nhead_stride_v_descale, batch_stride_q, batch_stride_k, batch_stride_v, @@ -144,6 +158,9 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, batch_stride_randval, batch_stride_lse, batch_stride_o, + batch_stride_q_descale, + batch_stride_k_descale, + batch_stride_v_descale, mask.left, mask.right, 0, // sink_size @@ -151,7 +168,9 @@ fmha_fwd_args get_ck_fmha_varlen_fwd_args(bool has_lse, -1, // min_seqlen_q p_dropout, has_dropout_randval, - drop_seed_offset}; + drop_seed_offset, + block_scale_size_q, + block_scale_size_kv}; } std::tuple diff --git a/requirements.txt b/requirements.txt index ceb41d722e320..4ae3e31656fe1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -18,6 +18,7 @@ numpy==2.1.2 ; python_version > "3.9" and python_version < "3.14" numpy==2.4.3 ; python_version >= "3.14" optree==0.13.0 ; python_version < "3.14" optree==0.17.0 ; python_version >= "3.14" +pandas psutil==7.2.2 spin==0.17 sympy==1.13.3 diff --git a/third_party/aiter b/third_party/aiter index 
9a469a608b2c1..bfa86a0b25652 160000 --- a/third_party/aiter +++ b/third_party/aiter @@ -1 +1 @@ -Subproject commit 9a469a608b2c10b7157df573a38d31e5bf4038b4 +Subproject commit bfa86a0b25652b9b0da5c3e3692136789dc4f984 From aae31f5f8136236e6e7a1b2b55a0524c2996e8ca Mon Sep 17 00:00:00 2001 From: "Vasishta, Aaryaman (Jam)" Date: Fri, 8 May 2026 03:26:34 +0900 Subject: [PATCH 3/5] [ROCm] Fix large ROCm arange launch (#182657) (#16) Fix `torch.arange` (and the other range factories sharing this kernel) for very large outputs on ROCm. `torch.arange(N)` with `N >= 2^32` fails on ROCm because `hipLaunchKernel` does not support `gridDim.x * blockDim.x >= 2^32` for the per-thread kernel `aten/src/ATen/native/cuda/RangeFactories.cu` previously used. Depending on the ROCm version the launch returns `hipErrorInvalidConfiguration` or is accepted silently with the kernel never executing, leaving zero-initialized output. Concrete repro: `torch.arange(2 ** 32 + 1, device="cuda", dtype=torch.int32)`. The fix replaces the per-thread launch on the ROCm path with a grid-stride loop that fixes the grid at `sm_count * 4` blocks, so the launch limit is no longer load-bearing for correctness regardless of `N`. The non-ROCm path is untouched. On MI250X the grid-stride kernel matches the per-thread kernel within noise at `N=1024` and is 24-60% faster from `N=1M` up across `int32`, `int64`, and `float32`. On MI300X the grid-stride kernel matches within noise at `N=1024` and `N=1M`, and is 2-5x faster from `N=64M` up across `int32`, `int64`, and `float32`. The 64-bit-indexing test is extended to also cover `N = 2^32 + 1` and `N = 2^33 + 1` on ROCm when memory permits. 
Pull Request resolved: https://github.com/pytorch/pytorch/pull/182657 Approved by: https://github.com/jeffdaily Co-authored-by: Jeff Daily --- aten/src/ATen/native/cuda/RangeFactories.cu | 41 +++++++++++++++++++++ test/test_tensor_creation_ops.py | 18 +++++++++ 2 files changed, 59 insertions(+) diff --git a/aten/src/ATen/native/cuda/RangeFactories.cu b/aten/src/ATen/native/cuda/RangeFactories.cu index 9d7ead7e49892..bc8decb5469ec 100644 --- a/aten/src/ATen/native/cuda/RangeFactories.cu +++ b/aten/src/ATen/native/cuda/RangeFactories.cu @@ -8,6 +8,9 @@ #include #include #include +#if defined(USE_ROCM) +#include +#endif #ifndef AT_PER_OPERATOR_HEADERS #include @@ -48,12 +51,49 @@ __global__ void elementwise_kernel_with_index(index_t N, func_t f, typename func } } +#if defined(USE_ROCM) +// HIP does not support launches with gridDim.x * blockDim.x >= 2^32: +// depending on the ROCm version the launch returns +// hipErrorInvalidConfiguration or is accepted silently with the kernel +// never executing, leaving zero-initialized output. A grid-stride kernel +// with a fixed grid sized to device occupancy avoids the limit. 
+template +C10_LAUNCH_BOUNDS_1(num_threads()) +__global__ void elementwise_kernel_with_index_grid_stride( + index_t N, func_t f, + typename function_traits::result_type *data) { + index_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + const index_t stride = static_cast(gridDim.x) * blockDim.x; + for (; idx < N; idx += stride) { + data[idx] = f(idx); + } +} +#endif + template void gpu_kernel_with_index(at::Tensor &output, func_t f) { int64_t N = output.numel(); if (N == 0) { return; } +#if defined(USE_ROCM) + constexpr int blocks_per_sm = 4; + const int sm_count = + at::cuda::getCurrentDeviceProperties()->multiProcessorCount; + const int64_t orig_grid = (N + block_work_size - 1) / block_work_size; + int64_t grid = std::min( + orig_grid, static_cast(sm_count) * blocks_per_sm); + grid = std::max(grid, 1); + auto stream = at::cuda::getCurrentCUDAStream(); + using scalar_t = typename function_traits::result_type; + if (N <= std::numeric_limits::max()) { + elementwise_kernel_with_index_grid_stride<<>>(N, f, output.mutable_data_ptr()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + elementwise_kernel_with_index_grid_stride<<>>(N, f, output.mutable_data_ptr()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } +#else int64_t grid = (N + block_work_size - 1) / block_work_size; auto stream = at::cuda::getCurrentCUDAStream(); using scalar_t = typename function_traits::result_type; @@ -64,6 +104,7 @@ void gpu_kernel_with_index(at::Tensor &output, func_t f) { elementwise_kernel_with_index<<>>(N, f, output.mutable_data_ptr()); C10_CUDA_KERNEL_LAUNCH_CHECK(); } +#endif } } // namespace diff --git a/test/test_tensor_creation_ops.py b/test/test_tensor_creation_ops.py index 287ee3cb7e421..87830f8b0750f 100644 --- a/test/test_tensor_creation_ops.py +++ b/test/test_tensor_creation_ops.py @@ -2789,6 +2789,24 @@ def test_range_factories_64bit_indexing(self, device): self.assertEqual(t[-1].item(), 2) del t + # On ROCm, launches with gridDim.x * blockDim.x >= 2^32 are not + # supported 
and either return hipErrorInvalidConfiguration or fail + # silently. Exercise the just-over case (~16 GB at int32) and a + # far-above case (~33 GB) when memory permits. arange computes + # int64 values then casts down, so for int32 the trailing + # 2^32 - 2, 2^32 - 1, 2^32 wrap to -2, -1, 0. + if TEST_WITH_ROCM: + for bigint in (2 ** 32 + 1, 2 ** 33 + 1): + free, _ = torch.cuda.mem_get_info(device) + if free < bigint * 4 + (3 << 30): + continue + t = torch.arange(bigint, dtype=torch.int32, device=device) + self.assertEqual(t.numel(), bigint) + self.assertEqual( + t[-3:].cpu(), torch.tensor([-2, -1, 0], dtype=torch.int32) + ) + del t + @expectedFailureMeta # RuntimeError: The tensor has a non-zero number of elements @onlyNativeDeviceTypes def test_tensor_ctor_device_inference(self, device): From 1bc755ac30d6a8b5c183eafc823b8fe307a77417 Mon Sep 17 00:00:00 2001 From: "Cao, Glen" Date: Tue, 26 May 2026 13:35:23 -0700 Subject: [PATCH 4/5] Temp 2.11 1250 tdm (#20) * TDM on release/2.11 for bring-up based on careful selection * Triton commit: Upstream fe0c38b5262c0447fed6df0d37e02cb8ea75deb4 -> AMD-ROCm-Internal Triton 250bb5d5b821377f49dc2d83d87ded75b952f0f7; Consequence: Triton TDM support may be missing.
* Refinement according to reviewers' comments * Added/modified UT cases; NUM_STAGES issue of ineffectiveness * A couple of changes to related UTs * Got rid of configs like `waves_per_cu=2` --- .ci/pytorch/build.sh | 20 +- .github/workflows/inductor-rocm-gfx1250.yml | 77 ++++++++ test/inductor/test_flex_attention.py | 14 ++ test/inductor/test_max_autotune.py | 118 ++++++++++++ .../_inductor/codegen/rocm/compile_command.py | 4 + torch/_inductor/config.py | 23 +++ torch/_inductor/kernel/flex/common.py | 7 +- torch/_inductor/kernel/flex/flex_attention.py | 18 +- torch/_inductor/kernel/flex/flex_decoding.py | 12 +- torch/_inductor/kernel/mm.py | 72 +++++++- .../templates/triton_persistent_mm.py.jinja | 75 ++++++++ torch/_inductor/template_heuristics/triton.py | 172 +++++++++++++++++- torch/_inductor/utils.py | 112 ++++++++++++ 13 files changed, 709 insertions(+), 15 deletions(-) create mode 100644 .github/workflows/inductor-rocm-gfx1250.yml create mode 100644 torch/_inductor/kernel/templates/triton_persistent_mm.py.jinja diff --git a/.ci/pytorch/build.sh b/.ci/pytorch/build.sh index eb3529a0c43f3..fce87199941fc 100755 --- a/.ci/pytorch/build.sh +++ b/.ci/pytorch/build.sh @@ -150,6 +150,7 @@ if [[ "$BUILD_ENVIRONMENT" == *vulkan* ]]; then source /var/lib/jenkins/vulkansdk/setup-env.sh fi +# Example BUILD_ENVIRONMENT: linux-noble-rocm-py3.12-gfx1250 if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then # hcc used to run out of memory, silently exiting without stopping # the build process, leaving undefined symbols in the shared lib, @@ -159,10 +160,23 @@ if [[ "$BUILD_ENVIRONMENT" == *rocm* ]]; then export MAX_JOBS=$(($(nproc) - 1)) fi + # Logic for multiple architectures based on the discriminator BUILD_ENVIRONMENT + # that is set by the workflow YAML and follows a consistent naming pattern. if [[ -n "$CI" && -z "$PYTORCH_ROCM_ARCH" ]]; then - # Set ROCM_ARCH to gfx906 for CI builds, if user doesn't override. 
- echo "Limiting PYTORCH_ROCM_ARCH to gfx906 for CI builds" - export PYTORCH_ROCM_ARCH="gfx906" + if [[ "$BUILD_ENVIRONMENT" == *gfx1250* ]]; then + echo "Setting PYTORCH_ROCM_ARCH to gfx1250 for CI builds" + export PYTORCH_ROCM_ARCH="gfx1250" + elif [[ "$BUILD_ENVIRONMENT" == *mi355* ]] || [[ "$BUILD_ENVIRONMENT" == *gfx950* ]]; then + echo "Setting PYTORCH_ROCM_ARCH to gfx950 for CI builds" + export PYTORCH_ROCM_ARCH="gfx950" + elif [[ "$BUILD_ENVIRONMENT" == *mi300* ]] || [[ "$BUILD_ENVIRONMENT" == *gfx942* ]]; then + echo "Setting PYTORCH_ROCM_ARCH to gfx942 for CI builds" + export PYTORCH_ROCM_ARCH="gfx942" + else + # Set ROCM_ARCH to gfx906 for CI builds, if user doesn't override. + echo "Limiting PYTORCH_ROCM_ARCH to gfx906 for CI builds" + export PYTORCH_ROCM_ARCH="gfx906" + fi fi # hipify sources diff --git a/.github/workflows/inductor-rocm-gfx1250.yml b/.github/workflows/inductor-rocm-gfx1250.yml new file mode 100644 index 0000000000000..33449a7d52bf1 --- /dev/null +++ b/.github/workflows/inductor-rocm-gfx1250.yml @@ -0,0 +1,77 @@ +# The name of this file is subject to change to stay consistent with other .yml files. +# +# The MI355 workflow (.github/workflows/inductor-rocm-mi355.yml) uses: +# - _linux-build.yml and _rocm-test.yml reusable workflows +# - Build environment linux-noble-rocm-py3.12-mi355 +# - Runner label linux.rocm.gpu.gfx950.1 +# - Docker image ci-image:pytorch-linux-noble-rocm-n-py3 +# - 2-shard test matrix for the inductor config +# +# The GFX1250 equivalent is following this exact pattern. 
+ +name: inductor-rocm-gfx1250 + +on: + push: + branches: + - main + - release/* + tags: + - ciflow/inductor-rocm-gfx1250/* + workflow_dispatch: + +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }}-${{ github.event_name == 'workflow_dispatch' }}-${{ github.event_name == 'schedule' }} + cancel-in-progress: true + +permissions: + id-token: write + contents: read + actions: read + +jobs: + target-determination: + if: github.repository_owner == 'pytorch' + name: before-test + uses: ./.github/workflows/target_determination.yml + + get-label-type: + name: get-label-type + uses: pytorch/pytorch/.github/workflows/_runner-determinator.yml@release/2.11 + if: ${{ (github.event_name != 'schedule' || github.repository == 'pytorch/pytorch') && github.repository_owner == 'pytorch' }} + with: + triggering_actor: ${{ github.triggering_actor }} + issue_owner: ${{ github.event.pull_request.user.login || github.event.issue.user.login }} + curr_branch: ${{ github.head_ref || github.ref_name }} + curr_ref_type: ${{ github.ref_type }} + opt_out_experiments: lf + + linux-noble-rocm-py3_12-inductor-build: + name: linux-noble-rocm-py3.12-gfx1250 + uses: ./.github/workflows/_linux-build.yml + needs: get-label-type + with: + runner_prefix: "${{ needs.get-label-type.outputs.label-type }}" + build-environment: linux-noble-rocm-py3.12-gfx1250 + # Docker image stays the same as MI355 because ROCm image supports multiple arches. + docker-image-name: ci-image:pytorch-linux-noble-rocm-n-py3 + # Set PYTORCH_ROCM_ARCH directly in the workflow YAML as an env variable, + # so build.sh never needs to parse BUILD_ENVIRONMENT. + #env-var-script: | + # export PYTORCH_ROCM_ARCH=gfx1250 + test-matrix: | + { include: [ + { config: "inductor", shard: 1, num_shards: 2, runner: "linux.rocm.gpu.gfx1250.1" }, # It requires provisioning hardware. 
+ { config: "inductor", shard: 2, num_shards: 2, runner: "linux.rocm.gpu.gfx1250.1" }, + ]} + secrets: inherit + + linux-noble-rocm-py3_12-inductor-test: + name: linux-noble-rocm-py3.12-gfx1250 + uses: ./.github/workflows/_rocm-test.yml + needs: linux-noble-rocm-py3_12-inductor-build + with: + build-environment: ${{ needs.linux-noble-rocm-py3_12-inductor-build.outputs.build-environment }} + docker-image: ${{ needs.linux-noble-rocm-py3_12-inductor-build.outputs.docker-image }} + test-matrix: ${{ needs.linux-noble-rocm-py3_12-inductor-build.outputs.test-matrix }} + secrets: inherit diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index 972e74c34dfa6..88d0f30f733bd 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -447,6 +447,20 @@ def batch_reserve(paged_attention: PagedAttention, target_seq_len: Tensor): ) +class TestFlexAttentionTDMOptions(InductorTestCase): + def test_apply_tdm_num_stages_uses_triton_launch_option(self): + from torch._inductor.kernel.flex.common import apply_tdm_num_stages + + kernel_options = {"num_stages": 1, "NUM_STAGES": 4} + + apply_tdm_num_stages(kernel_options) + + self.assertEqual( + kernel_options["num_stages"], config.tdm.max_outstanding_per_wave + ) + self.assertNotIn("NUM_STAGES", kernel_options) + + @large_tensor_test_class("2GB", device=test_device[0]) class TestFlexAttention(InductorTestCase): def setUp(self): diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index 5d3855aa73e93..fe45772a9fce4 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -60,6 +60,7 @@ CUDAPersistentTMATemplateConfigHeuristic, GemmConfig, get_shared_memory_checker_opts, + ROCmMMTemplateConfigHeuristic, XPUMMTemplateConfigHeuristic, XPUPersistentTMATemplateConfigHeuristic, ) @@ -327,6 +328,123 @@ def mock_get_tma_workspace_arg(*args, **kwargs): mm_tma_heuristic.mm_configs = original_tma_configs 
mm_heuristic.mm_configs = original_mm_configs + def test_tdm_arch_gate_accepts_only_gfx1250(self): + from torch._inductor.utils import is_gfx1250_arch + + self.assertTrue(is_gfx1250_arch("gfx1250")) + self.assertTrue(is_gfx1250_arch("gfx1250:sramecc+:xnack-")) + self.assertFalse(is_gfx1250_arch("gfx1251")) + self.assertFalse(is_gfx1250_arch("gfx1260")) + self.assertFalse(is_gfx1250_arch("amd-gfx1250")) + + def test_tdm_persistent_template_precedes_rocm_tma_fallback(self): + from torch._inductor.kernel import mm as mm_kernel + + for persistent_tma_enabled in (False, True): + templates = [] + with ( + config.patch( + { + "triton.enable_persistent_tma_matmul": persistent_tma_enabled + } + ), + mock.patch.object( + mm_kernel, + "use_triton_blackwell_tma_template", + return_value=False, + ), + mock.patch.object( + mm_kernel, + "use_triton_tdm_template", + return_value=True, + ), + mock.patch.object( + mm_kernel, + "use_triton_tma_template", + side_effect=AssertionError("TMA fallback should not run"), + ), + ): + selected = mm_kernel._append_persistent_mm_template( + templates, mock.Mock(), mock.Mock(), mock.Mock() + ) + + self.assertEqual(selected, "tdm") + self.assertEqual(templates, [mm_kernel.persistent_mm_template]) + + def test_tdm_template_add_guards_checks_compile_time_device(self): + from torch._inductor.utils import use_triton_tdm_template + + class FakeMatrix: + def __init__(self, device): + self.device = device + + def get_device(self): + return self.device + + def get_dtype(self): + return torch.float16 + + def get_stride(self): + return (128, 1) + + mat = FakeMatrix(torch.device("cuda", 0)) + with ( + config.patch({"enable_tdm_configs": True}), + mock.patch.object(torch.version, "hip", "7.2.0"), + mock.patch.object( + torch.cuda, + "get_device_properties", + return_value=mock.Mock(gcnArchName="gfx1250"), + ), + ): + self.assertTrue( + use_triton_tdm_template( + mat, + output_layout=mock.Mock(device=torch.device("cuda", 0)), + add_guards=True, + ) + ) + 
self.assertFalse( + use_triton_tdm_template( + mat, + output_layout=mock.Mock(device=torch.device("cuda", 1)), + add_guards=True, + ) + ) + + def test_tdm_block_k_filter_is_dtype_size_aware(self): + from torch._inductor.template_heuristics.triton import ( + _filter_tdm_block_k_configs, + ROCmGemmConfig, + ) + + configs = [ + ROCmGemmConfig(128, 64, 64, 4, 4, group_m=8), + ROCmGemmConfig(128, 64, 128, 4, 4, group_m=8), + ] + + self.assertEqual( + [c.block_k for c in _filter_tdm_block_k_configs(configs, 2)], + [64, 128], + ) + self.assertEqual( + [c.block_k for c in _filter_tdm_block_k_configs(configs, 1)], + [128], + ) + self.assertEqual( + [c.block_k for c in _filter_tdm_block_k_configs(configs, 4)], + [64, 128], + ) + + def test_shared_memory_estimation_counts_num_stages_once(self): + heuristic = ROCmMMTemplateConfigHeuristic() + gemm_config = GemmConfig(128, 64, 64, 4, 4, group_m=8) + + self.assertEqual( + heuristic.get_shared_memory_estimation(gemm_config, 2, False, 0), + (128 * 64 + 64 * 64) * 2 * 4 + 128, + ) + @unittest.skipIf( not has_triton_tma_device(), "Need device-side TMA support in Triton" ) diff --git a/torch/_inductor/codegen/rocm/compile_command.py b/torch/_inductor/codegen/rocm/compile_command.py index aa935b14af23c..1ff830089cfa4 100644 --- a/torch/_inductor/codegen/rocm/compile_command.py +++ b/torch/_inductor/codegen/rocm/compile_command.py @@ -75,6 +75,10 @@ def _rocm_lib_options(dst_file_ext: str) -> list[str]: def _rocm_compiler_options() -> list[str]: + # `config.rocm.arch` is populated from either: + # - The `PYTORCH_ROCM_ARCH` environment variable, or + # - Runtime device detection. + # The string "native" tells `hipcc` to compile for the current GPU.
arch_list = config.rocm.arch or ["native"] gpu_arch_flags = [f"--offload-arch={arch}" for arch in arch_list] opts = [ diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index da4a60574c37c..407ded38adc12 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -734,6 +734,18 @@ def use_autoheuristic(name: str) -> bool: force_layout_optimization = os.environ.get("TORCHINDUCTOR_FORCE_LAYOUT_OPT", "0") == "1" +# AMD TDM config flag for gfx1250. This is a configuration gate +# rather than something that causes TDM instructions to be emitted directly. +# It controls whether gfx1250-specific autotuning configs +# (with larger tile sizes, more pipeline stages) are included +# in the candidate set during `max_autotune`. +# When enabled and running on a gfx1250 device, Inductor may emit Triton kernels +# that leverage TDM for asynchronous global->LDS tensor tile copies. +# Requires Triton >= 3.6.0 with AMD TDM backend support. +# Disabled by default; TDM is further gated by device arch (gfx1250) +# in `use_triton_tdm_template()` at codegen time. +# Set TORCHINDUCTOR_ENABLE_TDM_CONFIGS=1 to enable; any other value keeps TDM off, even on gfx1250 hardware. +enable_tdm_configs = os.environ.get("TORCHINDUCTOR_ENABLE_TDM_CONFIGS", "0") == "1" # Whether to keep the output strides the same as eager after layout optimization. keep_output_stride = os.environ.get("TORCHINDUCTOR_KEEP_OUTPUT_STRIDE", "1") == "1" @@ -2268,6 +2280,17 @@ class rocm: contiguous_threshold: int = 16 +class tdm: + # Maximum outstanding TDM address translations per wave is 4. + # For small tiles (e.g., 128x64 FP16), this is the binding constraint. + # For larger tiles with 2 waves/SIMD, the SIMD limit of 6 applies. + max_outstanding_per_wave: int = 4 + max_outstanding_per_simd: int = 6 + # TDM requires 128B/256B aligned contiguous regions + # in both global memory and LDS.
+ alignment_bytes: int = 128 + + # Backend to use for CPU codegen either "cpp" or "triton" (experimental) or "halide" (experimental) or "pallas" (experimental) cpu_backend: Literal["cpp", "triton", "halide", "pallas"] = "cpp" diff --git a/torch/_inductor/kernel/flex/common.py b/torch/_inductor/kernel/flex/common.py index bf4006de8399a..764d7cd386691 100644 --- a/torch/_inductor/kernel/flex/common.py +++ b/torch/_inductor/kernel/flex/common.py @@ -43,12 +43,17 @@ to_dtype, ) from ...select_algorithm import realize_inputs -from ...utils import load_template +from ...utils import config, load_template SubgraphResults = Union[list[Optional[ComputedBuffer]], Optional[ComputedBuffer]] +def apply_tdm_num_stages(kernel_options: dict[str, Any]) -> None: + kernel_options["num_stages"] = config.tdm.max_outstanding_per_wave + kernel_options.pop("NUM_STAGES", None) + + def zeros_and_scatter_lowering(shape: list[int], indices, values): """To support backwards on captured buffers we register a specific lowering for our specific custom up""" # Always accumulate into fp32 then cast diff --git a/torch/_inductor/kernel/flex/flex_attention.py b/torch/_inductor/kernel/flex/flex_attention.py index c6017fadb0387..6277d26f43c4a 100644 --- a/torch/_inductor/kernel/flex/flex_attention.py +++ b/torch/_inductor/kernel/flex/flex_attention.py @@ -23,8 +23,9 @@ SymbolicGridFn, TritonTemplate, ) -from ...utils import can_use_tma +from ...utils import can_use_tma, use_triton_tdm_template from .common import ( + apply_tdm_num_stages, build_subgraph_buffer, create_indices_fake, create_num_blocks_fake_generator, @@ -410,6 +411,15 @@ def flex_attention( if cur_kernel_options["USE_TMA"] and not can_use_tma(query, key, value): cur_kernel_options["USE_TMA"] = False + # For gfx1250 TDM, ensure the non-TMA path uses enough pipeline stages + # to trigger TDM async copies in Triton's AMD backend (TTGIR pipelining pass). 
+ # Standard attention block sizes (64, 128, 256) with FP16/BF16 + # produce 128B aligned tiles and are compatible. + if not cur_kernel_options["USE_TMA"] and use_triton_tdm_template( + query, key, value + ): + apply_tdm_num_stages(cur_kernel_options) + cur_kernel_options.setdefault("BLOCK_M", conf.block_m) cur_kernel_options.setdefault("BLOCK_N", conf.block_n) # Blocksparse options @@ -927,6 +937,12 @@ def flex_attention_backward(*args, **kwargs): if cur_kernel_options["USE_TMA"] and not can_use_tma(query, key, value): cur_kernel_options["USE_TMA"] = False + # See the comments at the corresponding place in function `flex_attention` above. + if not cur_kernel_options["USE_TMA"] and use_triton_tdm_template( + query, key, value + ): + apply_tdm_num_stages(cur_kernel_options) + cur_kernel_options.setdefault("BLOCK_M1", conf.block_m1) cur_kernel_options.setdefault("BLOCK_N1", conf.block_n1) cur_kernel_options.setdefault("BLOCK_M2", conf.block_m2) diff --git a/torch/_inductor/kernel/flex/flex_decoding.py b/torch/_inductor/kernel/flex/flex_decoding.py index 4111915c26082..8a24b71e398f4 100644 --- a/torch/_inductor/kernel/flex/flex_decoding.py +++ b/torch/_inductor/kernel/flex/flex_decoding.py @@ -18,8 +18,9 @@ SymbolicGridFn, TritonTemplate, ) -from ...utils import can_use_tma +from ...utils import can_use_tma, use_triton_tdm_template from .common import ( + apply_tdm_num_stages, create_indices_fake, create_num_blocks_fake_generator, freeze_irnodes, @@ -354,6 +355,15 @@ def create_flex_decoding_kernel(*args, **kwargs): if cur_kernel_options["USE_TMA"] and not can_use_tma(query, key, value): cur_kernel_options["USE_TMA"] = False + # For gfx1250 TDM, ensure the non-TMA path uses enough pipeline stages + # to trigger TDM async copies in Triton's AMD backend (TTGIR pipelining pass). + # Standard attention block sizes (64, 128, 256) with FP16/BF16 + # produce 128B aligned tiles and are compatible. 
+ if not cur_kernel_options["USE_TMA"] and use_triton_tdm_template( + query, key, value + ): + apply_tdm_num_stages(cur_kernel_options) + # Add ROCm-specific parameters if they exist in the config for attrib in ["kpack", "matrix_instr_nonkdim", "waves_per_eu"]: if hasattr(conf, attrib): diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index 112f201d52f1d..51d922deec82f 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -54,6 +54,7 @@ use_triton_scaling_template, use_triton_template, use_triton_tma_template, + use_triton_tdm_template, ) from .mm_common import ( _is_static_problem, @@ -100,6 +101,14 @@ source=load_kernel_template("triton_persistent_tma_mm"), ) +# Non-TMA Triton template for persistent MM +# used on AMD +persistent_mm_template = TritonTemplate( + name="mm_persistent", + grid=persistent_mm_grid, + source=load_kernel_template("triton_persistent_mm"), +) + scaled_mm_device_tma_epilogue_scaling_template = TritonTemplate( name="scaled_mm_device_tma_epilogue_scaling", @@ -121,6 +130,59 @@ ) +def _append_persistent_mm_template( + templates_to_use: list[ExternKernelChoice | KernelTemplate], + mat1: Buffer, + mat2: Buffer, + layout: Layout, +) -> str | None: + if use_triton_blackwell_tma_template( + mat1, mat2, output_layout=layout, add_guards=True + ): + templates_to_use.append(blackwell_ws_persistent_device_tma_mm_template) + return "blackwell_tma" + if use_triton_tdm_template(mat1, mat2, output_layout=layout, add_guards=True): + # GFX1250 TDM: use the non-TMA persistent template. Triton's AMD backend + # automatically inserts TDM instructions when compiling with num_stages > 1. + # + # TDM hardware constraints (gfx1250): + # Documentation-only for now because Triton's pipeliner handles the + # wave assignment internally. If Triton exposes wave-specialization + # knobs in the future, these comments identify where to hook them. 
+ # - 1 TDM unit per SIMD pair (2 SIMDs share 1 TDM unit) + # - Max 4 outstanding TDM address translations per wave + # - Max 6 outstanding TDM address translations per SIMD + # - Recommended: waves specialized to load A or B + # - 8 waves/WG (num_warps=8): 4 waves for A, 4 for B + # - 4 waves/WG (num_warps=4): 2 waves for A, 2 for B + # - All waves can load both A & B. This is preferable for uniform + # unrolling across CUs and helps multicast load latency. + # - Do not interleave different TDMs within a wave. Complete one TDM + # instruction's unrolling before switching. + # + # The persistent template heuristic supplies gfx1250 TDM configs. + # Key constraints: + # - TDM requests target 128B or 256B aligned contiguous regions in both + # global memory and LDS. + # - For FP16: BLOCK_K must be a multiple of 64 (64 * 2B = 128B). + # - For FP8/FP4: BLOCK_K must be a multiple of 128 (128 * 1B = 128B). + # - For FP32: BLOCK_K must be a multiple of 32 (32 * 4B = 128B). + # - BLOCK_M and BLOCK_N should also produce 128B-aligned LDS rows. + # TDM is most beneficial for small tiles such as 128x64/64x128 FP16 + # and 128x128/128x256 FP8 with num_stages=4 (matching the + # 4-outstanding-per-wave TDM address translation limit) and + # num_warps=4 or 8 (for wave-specialized A/B loading). 
+ templates_to_use.append(persistent_mm_template) + return "tdm" + if use_triton_tma_template(mat1, mat2, output_layout=layout, add_guards=True): + if torch.version.hip is None: + templates_to_use.append(persistent_tma_mm_template) + else: + templates_to_use.append(persistent_mm_template) + return "tma" + return None + + # prevent duplication registration of extern functions @functools.cache def lazy_register_extern_choice(fn): @@ -419,10 +481,7 @@ def _to_dtype(x): if is_exhaustive or not use_decompose_k_choice(m, n, k, threshold_multiple=2): templates_to_use.append(mm_template) - if use_triton_blackwell_tma_template(mat1, mat2, output_layout=layout): - templates_to_use.append(blackwell_ws_persistent_device_tma_mm_template) - elif use_triton_tma_template(mat1, mat2, output_layout=layout): - templates_to_use.append(persistent_tma_mm_template) + _append_persistent_mm_template(templates_to_use, mat1, mat2, layout) if ( inductor_config.is_fbcode() @@ -664,10 +723,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): if is_nonzero and use_triton_template(layout, check_max_autotune=False): templates_to_use.append(mm_template) - if use_triton_blackwell_tma_template(mat1, mat2, output_layout=layout): - templates_to_use.append(blackwell_ws_persistent_device_tma_mm_template) - elif use_triton_tma_template(mat1, mat2, output_layout=layout): - templates_to_use.append(persistent_tma_mm_template) + _append_persistent_mm_template(templates_to_use, mat1, mat2, layout) templates_to_use.append(addmm_contiguous_subgraph_template) diff --git a/torch/_inductor/kernel/templates/triton_persistent_mm.py.jinja b/torch/_inductor/kernel/templates/triton_persistent_mm.py.jinja new file mode 100644 index 0000000000000..7450533be5adf --- /dev/null +++ b/torch/_inductor/kernel/templates/triton_persistent_mm.py.jinja @@ -0,0 +1,75 @@ +{{def_kernel("A", "B")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size 
input(s) + return + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + + # persistent kernel: each CTA processes multiple tiles + start_pid = tl.program_id(0).to(INDEX_DTYPE) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + num_tiles = grid_m * grid_n + width = GROUP_M * grid_n + + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS): + + # re-order program ID for better L2 performance + group_id = tile_id // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (tile_id % group_size) + pid_n = (tile_id % width) // (group_size) + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + if ((stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1)) and (M >= BLOCK_M and K > 1): + offs_a_m = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + else: + offs_a_m = rm % M + if ((stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1)) and (N >= BLOCK_N and K > 1): + offs_b_n = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + else: + offs_b_n = rn % N + offs_k = tl.arange(0, BLOCK_K) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + + for k_idx in range(0, tl.cdiv(K, BLOCK_K)): + {% if not EVEN_K %} + a_mask = offs_k[None, :] < (K - k_idx * BLOCK_K) + b_mask = offs_k[:, None] < (K - k_idx * BLOCK_K) + {% endif %} + a_k_idx_vals = offs_k[None, :] + (k_idx * BLOCK_K) + b_k_idx_vals = offs_k[:, None] + (k_idx * BLOCK_K) + + idx_m = offs_a_m[:, None] + idx_n = a_k_idx_vals + {{load_input("A", "a", ("idx_m", "idx_n"), mask=None if EVEN_K else "a_mask", + indent_width=12, index_shape=("BLOCK_M", "BLOCK_K"))}} + + idx_m = b_k_idx_vals + idx_n = offs_b_n[None, :] + {{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", + indent_width=12, index_shape=("BLOCK_K", "BLOCK_N"))}} 
+ + {% if USE_FAST_ACCUM %} + acc = tl.dot(a, b, acc, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% else %} + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% endif %} + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "acc", "mask", indent_width=8, val_shape=("BLOCK_M", "BLOCK_N"))}} diff --git a/torch/_inductor/template_heuristics/triton.py b/torch/_inductor/template_heuristics/triton.py index e15ff07ee5a4f..fb6740e48b37f 100644 --- a/torch/_inductor/template_heuristics/triton.py +++ b/torch/_inductor/template_heuristics/triton.py @@ -22,6 +22,7 @@ get_scaling_options, get_tile_size, mm_template, + persistent_mm_template, persistent_tma_mm_template, scaled_mm_device_tma_epilogue_scaling_template, scaled_mm_device_tma_main_loop_scaling_template, @@ -32,6 +33,7 @@ get_backend_num_stages, get_num_sms, get_tma_workspace_arg, + is_gfx1250_arch, TMA_DESCRIPTOR_SIZE, using_b200, ) @@ -49,6 +51,42 @@ from torch._inductor.runtime.triton_compat import Config as TritonConfig +def _is_gfx1250_device() -> bool: + """Detect whether the current device is AMD gfx1250. + + gfx1250 has 320 KB LDS per CU, 512 B/clk/CU LDS bandwidth (R/W combined), + and TDM for async descriptor-based global->LDS copies. When num_stages > 1, + Triton's AMD backend automatically converts tl.load into TDM instructions. 
+ + TDM hardware constraints: + - 1 TDM unit per SIMD pair (2 SIMDs share 1 TDM unit) + - Max 4 outstanding TDM address translations per wave + - Max 6 outstanding TDM address translations per SIMD + - Requests target 128B or 256B aligned contiguous regions + in both global memory and LDS + """ + if not torch.version.hip: + return False + if not inductor_config.enable_tdm_configs: + return False + try: + props = torch.cuda.get_device_properties(torch.cuda.current_device()) + arch = getattr(props, "gcnArchName", "") + return is_gfx1250_arch(arch) + except Exception: + return False + + +def _filter_tdm_block_k_configs( + configs: list[BaseConfig], + dtype_size: int, +) -> list[BaseConfig]: + if dtype_size <= 0: + return configs + block_k_multiple = max(inductor_config.tdm.alignment_bytes // dtype_size, 1) + return [c for c in configs if c.block_k % block_k_multiple == 0] + + # Gemm Configs @dataclasses.dataclass class BaseConfig: @@ -1381,6 +1419,7 @@ def __init__(self) -> None: super().__init__() self.default_num_stages = get_backend_num_stages() + self.uses_tdm_configs = False self.mm_configs: list[BaseConfig] = [ ROCmGemmConfig( @@ -1442,6 +1481,42 @@ def __init__(self) -> None: ROCmGemmConfig(256, 256, 64, self.default_num_stages, 8, group_m=4), ] + # TDM-optimized persistent MM configs for gfx1250. 
+ # + # Hardware constraints on gfx1250: + # - 320 KB LDS per CU, 512 B/clk/CU bandwidth + # - Max 4 outstanding TDM translations per wave (bounds num_stages) + # - Max 6 outstanding TDM translations per SIMD + # - TDM requests target 128B/256B aligned contiguous regions + # - 1 TDM unit per SIMD pair -> recommended 1 wave/SIMD pair issuing + # TDM instructions (num_warps=4 gives 1 wave/SIMD on a 4-SIMD CU) + # + # BLOCK_K alignment requirements (128B / dtype_bytes): + # FP16/BF16 (2B): BLOCK_K multiple of 64 + # FP8/INT8 (1B): BLOCK_K multiple of 128 + # FP32 (4B): BLOCK_K multiple of 32 + # + # LDS usage estimate: (BM*BK + BK*BN)*dtype_size*num_stages + # Must stay under 320 KB. + self.tdm_persistent_mm_configs: list[BaseConfig] = [ + # Small tiles: TDM most beneficial here (LDS-BW-limited to 1 wave/SIMD). + # These are where TDM eliminates cluster load issue overhead. + # `waves_per_eu=0` means Triton compiler should figure out the best value. + ROCmGemmConfig(128, 64, 64, 4, 4, group_m=8, waves_per_eu=0), + ROCmGemmConfig( 64, 128, 64, 4, 4, group_m=8, waves_per_eu=0), + ROCmGemmConfig(128, 64, 128, 4, 4, group_m=8), + ROCmGemmConfig( 64, 128, 128, 4, 4, group_m=8), + # Medium tiles: benefit from TDM prologue overhead reduction. + ROCmGemmConfig(128, 128, 64, 4, 4, group_m=8), + ROCmGemmConfig(128, 128, 64, 4, 8, group_m=16), + ROCmGemmConfig(128, 128, 128, 4, 4, group_m=8), + ROCmGemmConfig(128, 128, 128, 3, 8, group_m=16), + # Larger tiles: exploit the 320 KB LDS with deep pipelines. 
+ ROCmGemmConfig(256, 128, 64, 4, 8, group_m=16), + ROCmGemmConfig(128, 256, 64, 4, 8, group_m=16), + ROCmGemmConfig(256, 128, 128, 3, 8, group_m=16), + ] + # Exhaustive search for mm configs self.exhaustive_configs: list[BaseConfig] = [ ROCmGemmConfig( @@ -1563,12 +1638,57 @@ def _prune_exhaustive_configs( ] return pruned_configs + def preprocess_mm_configs( + self, + m: int, + n: int, + k: int, + configs: list[BaseConfig], + has_int8_tensor: bool = False, + scale: float = 1.0, + exclude: Callable[ + [sympy.Integer, sympy.Integer, sympy.Integer], bool + ] = lambda m, n, k: False, + dtype_size: int = 0, + op_name: str = "mm", + **kwargs, + ) -> Generator[TritonConfig, None, None]: + if self.uses_tdm_configs: + configs = _filter_tdm_block_k_configs(configs, dtype_size) + return super().preprocess_mm_configs( + m, + n, + k, + configs, + has_int8_tensor, + scale, + exclude, + dtype_size, + op_name, + **kwargs, + ) + def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]: """ ROCm specific filtering + + Normally we force num_stages to `default_num_stages` because AMD's + Triton backend historically didn't benefit from multi-stage pipelining. + However, on gfx1250 with TDM enabled, num_stages > 1 is exactly + what triggers TDM async copies in the StreamPipeliner pass, so we + must preserve the requested num_stages for TDM configs. """ + tdm_enabled = _is_gfx1250_device() for c in configs: - c.num_stages = self.default_num_stages + # On gfx1250, preserve num_stages as provided (but cap at 4 to + # respect TDM's per-wave outstanding limit). Otherwise, fall + # back to the legacy behavior of forcing default_num_stages. 
+ if tdm_enabled: + c.num_stages = ( + min(c.num_stages, 4) if c.num_stages > 1 else c.num_stages + ) + else: + c.num_stages = self.default_num_stages return super()._filter_configs(configs) def _finalize_mm_configs( @@ -2914,6 +3034,56 @@ def __init__(self) -> None: self.exhaustive_configs = self.mm_plus_mm_configs +# ROCm persistent MM template heuristic (non-TMA, standard pointer loads) + + +@register_template_heuristic( + persistent_mm_template.uid, + "cuda", + register=torch.version.hip is not None, +) +class PersistentMMTemplateConfigHeuristic( + MMTemplateConfigMixin, + ROCmConfigHeuristic, # type: ignore[misc] +): + """Persistent MM template heuristic (no TMA, standard pointer loads)""" + + def __init__(self) -> None: + super().__init__() + if _is_gfx1250_device(): + # On gfx1250, use TDM-optimized configs that exploit: + # - 320 KB LDS (larger tiles possible) + # - TDM async global->LDS copies (num_stages > 1) + # - Wave specialization across SIMD pairs + self.mm_configs = self.tdm_persistent_mm_configs + self.uses_tdm_configs = True + else: + self.mm_configs = self.persistent_mm_configs + + def _get_template_configs_impl( + self, + kernel_inputs: KernelInputs, + op_name: str, + **kwargs, + ) -> Generator[dict[str, Any], None, None]: + for template_kwargs in super()._get_template_configs_impl( + kernel_inputs, op_name, **kwargs + ): + yield {**template_kwargs, "NUM_SMS": get_num_sms()} + + +@register_template_heuristic( + persistent_mm_template.uid, + "cuda", + register=torch.version.hip is not None, + op_name="addmm", +) +class ROCmAddMMPersistentTemplateConfigHeuristic( + AddMMConfigMixin, PersistentMMTemplateConfigHeuristic +): + """Addmm specific mixin for persistent MM on ROCm""" + + # MTIA template-specific classes diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index c9c60aa71eb0e..5184b16001a35 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1977,6 +1977,118 @@ def use_triton_tma_template( ) +def 
is_gfx1250_arch(arch: str) -> bool: + """Return True only for the gfx1250 target, including feature-suffixed names.""" + return arch.split(":", 1)[0] == "gfx1250" + + +def use_triton_tdm_template(*matrices: IRNode, output_layout=None, add_guards=False): + """Coarsely check whether AMD TDM-optimized persistent template should be used. + Despite the fact that TDM has some restrictions on direct LDS loads ( + when unrolling requests, the requests are to 128B or 256B aligned and + contiguous regions in global memory, and writing back to 128B or 256B + aligned and contiguous regions in LDS), this gate doesn't check per-element + alignment, instead delegating to Triton compiler's AMD backend. + That said, for GeMM specifically, it debug-logs strides that do not produce + 128B-aligned rows; aligned rows are the common case where TDM is beneficial. + + Args: + - output_layout: Output layout (its device is compared with the matrices' device) + - add_guards: If True, reject when output_layout's device index differs + Returns True when: + - Running on ROCm (torch.version.hip is not None) + - Device is gfx1250 + - config.enable_tdm_configs (defined elsewhere) is True + - Triton version supports gfx1250 TDM (>= 3.6.0) + """ + if not torch.version.hip: + return False + if not config.enable_tdm_configs: + return False + device = matrices[0].get_device() + if device.type != "cuda": + return False + try: + # props = torch.cuda.get_device_properties(torch.cuda.current_device()) + props = torch.cuda.get_device_properties(device) + arch = getattr(props, "gcnArchName", "") + if not is_gfx1250_arch(arch): + return False + except Exception: + return False + + # Check Triton version supports TDM (3.6.0+) + try: + from torch.torch_version import TorchVersion + import triton + if TorchVersion(triton.__version__) < "3.6.0": + return False + except (ImportError, Exception): + return False + + # TDM can lower the same matrix element types as standard Triton GeMM.
+ # The persistent-template config heuristic still filters BLOCK_K by dtype + # size so each TDM request covers an aligned 128B region. + # FIXME: torch.float4 + triton_tdm_dtypes = { + torch.float16, + torch.bfloat16, + torch.float32, + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.float8_e4m3fnuz, + torch.float8_e5m2fnuz, + torch.int8, + } + # TDM unrolls requests to 128B or 256B aligned contiguous regions + # in both global memory and LDS. For GeMM, this means the leading + # dimension (stride) of each matrix should produce rows that start + # at 128B-aligned addresses. If strides are not 128B-aligned, TDM + # may still work (Triton's backend can fall back per-load), but + # performance benefit is reduced. We don't hard-gate on this because: + # 1. The autotuner will benchmark and reject TDM configs if they're slower. + # 2. Triton's pipeliner handles misaligned cases gracefully. + # 3. Most real GeMM workloads have contiguous row-major or col-major layouts. + # For now, only log for debugging purposes. May be upgraded to a hard gate later. + for mat in matrices: + mat_dtype = mat.get_dtype() + if mat_dtype not in triton_tdm_dtypes: + return False + strides = mat.get_stride() + dtype_bytes = mat_dtype.itemsize + # Check if the innermost stride is 1 (contiguous) and + # the outer stride produces 128B-aligned rows. + if len(strides) >= 2: + inner_stride = strides[-1] + outer_stride = strides[-2] + if hasattr(inner_stride, '__int__'): + inner_val = int(inner_stride) + if inner_val != 1: + # TODO: Should be a hard rejection, not just a warning. 
+ log.debug( + "TDM: matrix has non-unit inner stride %d, " + "TDM requests may not be contiguous", + inner_val, + ) + if hasattr(outer_stride, '__int__'): + row_bytes = int(outer_stride) * dtype_bytes + if row_bytes % config.tdm.alignment_bytes != 0: + log.debug( + "TDM: matrix row stride %d bytes is not %dB-aligned, " + "TDM performance may be suboptimal", + row_bytes, + config.tdm.alignment_bytes, + ) + + if add_guards and output_layout is not None: + # `SizeVarAllocator` does not guard device properties. Keep this as a + # static consistency check for the layout/device pair used at compile time. + if (output_layout.device.index or 0) != (device.index or 0): + return False + + return True + + def use_triton_blackwell_tma_template( *matrices: IRNode, output_layout: Layout, add_guards: bool = False ) -> bool: From 9c8cbb3cecb9f73c443f28830a60226c2bf4e260 Mon Sep 17 00:00:00 2001 From: rraminen_amdeng Date: Mon, 15 Jun 2026 23:19:31 +0000 Subject: [PATCH 5/5] Bump triton version to 3.7.0 --- .ci/docker/ci_commit_pins/triton.txt | 2 +- .ci/docker/triton_version.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt index f0849cc7d8f63..9c78a57cca3ad 100644 --- a/.ci/docker/ci_commit_pins/triton.txt +++ b/.ci/docker/ci_commit_pins/triton.txt @@ -1 +1 @@ -4ed888920c5a0871957f1cf912e557bc79fbe56c +9c610c781cb810a11bfcc9accba094550b189a5e diff --git a/.ci/docker/triton_version.txt b/.ci/docker/triton_version.txt index 40c341bdcdbe8..7c69a55dbb185 100644 --- a/.ci/docker/triton_version.txt +++ b/.ci/docker/triton_version.txt @@ -1 +1 @@ -3.6.0 +3.7.0