diff --git a/.ci/docker/build.sh b/.ci/docker/build.sh index 7df6453c22da9..8e216eb1ea15a 100755 --- a/.ci/docker/build.sh +++ b/.ci/docker/build.sh @@ -190,7 +190,7 @@ case "$tag" in KATEX=yes UCX_COMMIT=${_UCX_COMMIT} UCC_COMMIT=${_UCC_COMMIT} - PYTORCH_ROCM_ARCH="gfx90a;gfx942;gfx950;gfx1100" + PYTORCH_ROCM_ARCH="gfx90a;gfx942;gfx950;gfx1100;gfx1250" if [[ $tag =~ "benchmarks" ]]; then INDUCTOR_BENCHMARKS=yes fi diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt index f0849cc7d8f63..2eedac500e0ec 100644 --- a/.ci/docker/ci_commit_pins/triton.txt +++ b/.ci/docker/ci_commit_pins/triton.txt @@ -1 +1 @@ -4ed888920c5a0871957f1cf912e557bc79fbe56c +110cd8e2ddf80d46fcc935d46dfcae7130d13b24 diff --git a/.ci/docker/triton_version.txt b/.ci/docker/triton_version.txt index 40c341bdcdbe8..a76ccff2a6e0d 100644 --- a/.ci/docker/triton_version.txt +++ b/.ci/docker/triton_version.txt @@ -1 +1 @@ -3.6.0 +3.7.1 diff --git a/CMakeLists.txt b/CMakeLists.txt index 188f08bb272f5..5d44e05486016 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -945,8 +945,12 @@ cmake_dependent_option( OFF) +# TODO: +# MSLK-related parts that already exist upstream are missing here. +# gfx1250 also needs to be handled for MSLK.
+ IF(USE_ROCM AND ("gfx942" IN_LIST PYTORCH_ROCM_ARCH OR "gfx950" IN_LIST PYTORCH_ROCM_ARCH)) - message(WARNING "Setting USE_MSLK for gfx942/gfx950 to ON by default, doing ROCM build") + message(STATUS "Setting USE_MSLK for gfx942/gfx950 to ON by default, doing ROCM build") set(USE_MSLK_DEFAULT ON) elseif(USE_CUDA AND "$ENV{TORCH_CUDA_ARCH_LIST}" MATCHES "10.0" AND CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8 AND NOT WIN32) message(STATUS "Setting USE_MSLK to ON by default , doing CUDA build for SM100a") diff --git a/aten/src/ATen/CMakeLists.txt b/aten/src/ATen/CMakeLists.txt index f1ac6246f4d56..58910ae7f106d 100644 --- a/aten/src/ATen/CMakeLists.txt +++ b/aten/src/ATen/CMakeLists.txt @@ -200,6 +200,21 @@ file(GLOB native_flash_attn_api_cpp "native/transformers/cuda/flash_attn/flash_a # flash_attention hip sources file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip") +# composable_kernel lacks gfx1250 support. CK GEMM/SDPA are otherwise built for +# every arch except gfx1250 (the --offload-arch filtering below). If gfx1250 is +# the ONLY arch there is no supported arch left to build CK for, so disable both +# entirely here. caffe2_update_option writes the cache, so this is honored by the +# conditional CK GEMM/SDPA defines and links in caffe2/CMakeLists.txt. 
+if(USE_ROCM AND "gfx1250" IN_LIST PYTORCH_ROCM_ARCH) + set(_ck_supported_archs ${PYTORCH_ROCM_ARCH}) + list(REMOVE_ITEM _ck_supported_archs gfx1250) + if("${_ck_supported_archs}" STREQUAL "") + message(WARNING "gfx1250 is the only arch in PYTORCH_ROCM_ARCH: disabling USE_ROCM_CK_GEMM and USE_ROCM_CK_SDPA (composable_kernel lacks gfx1250 support)") + caffe2_update_option(USE_ROCM_CK_GEMM OFF) + caffe2_update_option(USE_ROCM_CK_SDPA OFF) + endif() +endif() + # if USE_FLASH_ATTENTION is set, ensure CK instances get generated if(USE_FLASH_ATTENTION) if("$ENV{USE_CK_FLASH_ATTENTION}" STREQUAL "1") @@ -289,9 +304,20 @@ if(USE_FLASH_ATTENTION) "native/transformers/hip/flash_attn/ck/fav_v3/*.hip") set_source_files_properties(${ck_sdpa_sources_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) + # composable_kernel lacks gfx1250 support, so build CK SDPA for every arch + # except gfx1250 by filtering the --offload-arch flags for this target only. + set(HIP_CLANG_FLAGS_ORIGINAL ${HIP_CLANG_FLAGS}) + string(REGEX REPLACE "--offload-arch=[^ ]*" "" FILTERED_HIP_CLANG_FLAGS "${HIP_CLANG_FLAGS}") + foreach(ARCH ${PYTORCH_ROCM_ARCH}) + if(NOT ARCH STREQUAL "gfx1250") + list(APPEND FILTERED_HIP_CLANG_FLAGS --offload-arch=${ARCH}) + endif() + endforeach() + set(HIP_CLANG_FLAGS ${FILTERED_HIP_CLANG_FLAGS}) hip_add_library(ck_sdpa STATIC ${ck_sdpa_sources_hip} HIPCC_OPTIONS ${HIP_HCC_FLAGS} ${CK_SDPA_EXTRA_HIPCC_FLAGS}) + set(HIP_CLANG_FLAGS ${HIP_CLANG_FLAGS_ORIGINAL}) set_target_properties(ck_sdpa PROPERTIES POSITION_INDEPENDENT_CODE ON) target_compile_definitions(ck_sdpa PUBLIC ${CK_SDPA_EXTRA_HIPCC_OPTIONS}) target_compile_definitions(ck_sdpa PRIVATE AITER_EMBEDDED_HSA_HEADER="aiter_embedded_hsa.h") @@ -430,7 +456,7 @@ IF(USE_MSLK) list(PREPEND MSLK_EXTRA_HIPCC_FLAGS -mllvm -amdgpu-coerce-illegal-types=1) endif() - # Only compile for gfx942 and gfx950. + # Only compile for gfx942 and gfx950 (composable_kernel lacks gfx1250 support). 
set(HIP_CLANG_FLAGS_ORIGINAL ${HIP_CLANG_FLAGS}) string(REGEX REPLACE "--offload-arch=[^ ]*" "" FILTERED_HIP_CLANG_FLAGS "${HIP_CLANG_FLAGS}") foreach(ARCH gfx942 gfx950) @@ -627,11 +653,41 @@ if(USE_ROCM) ${native_quantized_hip_hip} ${native_transformers_hip_hip} ${native_transformers_src_hip_hip} ) + file(GLOB native_hip_bgemm "native/hip/bgemm_kernels/*.hip") + file(GLOB native_hip_ck "native/hip/ck*.hip") if(NOT USE_ROCM_CK_GEMM) - file(GLOB native_hip_bgemm "native/hip/bgemm_kernels/*.hip") - file(GLOB native_hip_ck "native/hip/ck*.hip") exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}" ${native_hip_bgemm} ${native_hip_ck}) + else() + # composable_kernel lacks gfx1250 support, so the CK GEMM kernels are removed + # from the main HIP sources and compiled into a dedicated ck_gemm library for + # every arch except gfx1250 (the rest of torch_hip still builds for all archs). + # The --offload-arch filtering mirrors the mslk pattern above. + exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}" + ${native_hip_bgemm} ${native_hip_ck}) + set(ck_gemm_sources_hip ${native_hip_bgemm} ${native_hip_ck}) + set_source_files_properties(${ck_gemm_sources_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) + set(HIP_CLANG_FLAGS_ORIGINAL ${HIP_CLANG_FLAGS}) + string(REGEX REPLACE "--offload-arch=[^ ]*" "" FILTERED_HIP_CLANG_FLAGS "${HIP_CLANG_FLAGS}") + foreach(ARCH ${PYTORCH_ROCM_ARCH}) + if(NOT ARCH STREQUAL "gfx1250") + list(APPEND FILTERED_HIP_CLANG_FLAGS --offload-arch=${ARCH}) + endif() + endforeach() + set(HIP_CLANG_FLAGS ${FILTERED_HIP_CLANG_FLAGS}) + hip_add_library(ck_gemm STATIC ${ck_gemm_sources_hip}) + set(HIP_CLANG_FLAGS ${HIP_CLANG_FLAGS_ORIGINAL}) + set_target_properties(ck_gemm PROPERTIES POSITION_INDEPENDENT_CODE ON) + # The define is no longer added globally in cmake/Dependencies.cmake; the + # ck_gemm_*.hip sources are guarded by #if defined(USE_ROCM_CK_GEMM). 
+ target_compile_definitions(ck_gemm PRIVATE USE_ROCM_CK_GEMM) + target_include_directories(ck_gemm PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include + ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include + ${CMAKE_CURRENT_BINARY_DIR}/composable_kernel) + if(TARGET torch_cpu) + add_dependencies(ck_gemm torch_cpu) + endif() endif() # TODO: Codegen separate files for HIP and use those (s/cuda_generated_sources/hip_generated_sources) diff --git a/aten/src/ATen/Context.cpp b/aten/src/ATen/Context.cpp index c342590b58c42..ba63e73e1017c 100644 --- a/aten/src/ATen/Context.cpp +++ b/aten/src/ATen/Context.cpp @@ -505,7 +505,7 @@ at::BlasBackend Context::blasPreferredBackend() { bool Context::ckSupported() { #ifdef USE_ROCM static const std::vector supported_archs = { - "gfx90a", "gfx942", "gfx950" + "gfx90a", "gfx942", "gfx950", /* gfx1250 is deliberately absent: the CK GEMM/SDPA libraries are built for every arch except gfx1250 */ }; for (auto index : c10::irange(detail::getCUDAHooks().deviceCount())) { if(!detail::getCUDAHooks().isGPUArch(supported_archs, index)) { diff --git a/aten/src/ATen/cuda/CUDABlas.cpp b/aten/src/ATen/cuda/CUDABlas.cpp index b9e66fea5ebdb..0d113df3d2e40 100644 --- a/aten/src/ATen/cuda/CUDABlas.cpp +++ b/aten/src/ATen/cuda/CUDABlas.cpp @@ -2020,6 +2020,14 @@ void scaled_gemm( "Got m=", m, ", n=", n, ", k=", k); } #endif + #if ROCM_VERSION >= 70200 + if (at::detail::getCUDAHooks().isGPUArch({"gfx1250"})) { + // TODO: add constraints based on hipblaslt internals + TORCH_CHECK((m % 16 == 0) && (n % 16 == 0) && (k % 128 == 0), + "M, N must be multiples of 16 and K should be multiple of 128 for MX format. 
" + "Got m=", m, ", n=", n, ", k=", k); + } + #endif } #elif (CUDA_VERSION < 12090) && !defined(USE_ROCM) // hipblaslt supported row-wise before cublas, and did so their own way (via diff --git a/aten/src/ATen/cuda/CUDAScaledBlas.h b/aten/src/ATen/cuda/CUDAScaledBlas.h index 3fd0e2a6a3aae..90dde0384b87a 100644 --- a/aten/src/ATen/cuda/CUDAScaledBlas.h +++ b/aten/src/ATen/cuda/CUDAScaledBlas.h @@ -67,6 +67,9 @@ static bool _scaled_mm_allowed_device(bool sm90_only=false, bool sm100_only=fals #endif #if ROCM_VERSION >= 60500 "gfx950" +#endif +#if ROCM_VERSION >= 70200 + , "gfx1250" #endif }; return at::detail::getCUDAHooks().isGPUArch(archs); diff --git a/aten/src/ATen/cuda/CublasHandlePool.cpp b/aten/src/ATen/cuda/CublasHandlePool.cpp index b5008bad832d0..75d9ae9ae586d 100644 --- a/aten/src/ATen/cuda/CublasHandlePool.cpp +++ b/aten/src/ATen/cuda/CublasHandlePool.cpp @@ -141,8 +141,8 @@ size_t parseChosenWorkspaceSize() { val = c10::utils::get_env("ROCBLAS_WORKSPACE_CONFIG"); } /* 32MiB default, 128MiB for gfx94x/gfx95x */ - const bool gfx94_95 = at::detail::getCUDAHooks().isGPUArch({"gfx94", "gfx95"}); - const size_t default_size = gfx94_95 ? 1024 * 128 * 1024 : 1024 * 32 * 1024; + const bool gfx94_95_125 = at::detail::getCUDAHooks().isGPUArch({"gfx94", "gfx95", "gfx125"}); + const size_t default_size = gfx94_95_125 ? 
1024 * 128 * 1024 : 1024 * 32 * 1024; #else /* :4096:2:16:8 default, 32MiB for Hopper */ cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties(); diff --git a/aten/src/ATen/cuda/detail/CUDAHooks.cpp b/aten/src/ATen/cuda/detail/CUDAHooks.cpp index 03a3a97525a43..4dfc3003e0cdb 100644 --- a/aten/src/ATen/cuda/detail/CUDAHooks.cpp +++ b/aten/src/ATen/cuda/detail/CUDAHooks.cpp @@ -558,7 +558,10 @@ const std::vector& CUDAHooks::getHipblasltPreferredArchs() const { "gfx1200", "gfx1201", #endif #if ROCM_VERSION >= 70000 - "gfx950" + "gfx950", +#endif +#if ROCM_VERSION >= 70200 + "gfx1250" #endif }; return archs; diff --git a/aten/src/ATen/native/cuda/CUDALoops.cuh b/aten/src/ATen/native/cuda/CUDALoops.cuh index b1031a95a3d0e..aa0878566052c 100644 --- a/aten/src/ATen/native/cuda/CUDALoops.cuh +++ b/aten/src/ATen/native/cuda/CUDALoops.cuh @@ -228,8 +228,12 @@ C10_LAUNCH_BOUNDS_1(num_threads()) __global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) { using traits = function_traits; constexpr auto io_size = calc_io_size(); + // tws=16 is tuned for gfx942 (CDNA3, Wave64) and is NOT portable to Wave32 + // archs. gfx1250 (GFX12.5, Wave32) falls through to the default `tws` below; + // re-enabling this fast path for gfx1250 requires retuning the constant + // against the 2x wave-width ratio. Keep this in sync with the host-side `tws` + // computation in launch_vectorized_kernel() (also gfx942-only). #if defined(USE_ROCM) && defined(__gfx942__) - // Similar check in launch_vectorized_kernel() as well. Both should be in sync. constexpr int tws = 16; #else constexpr int tws = elems_per_thread(); diff --git a/aten/src/ATen/native/cuda/GroupedBlas.cpp b/aten/src/ATen/native/cuda/GroupedBlas.cpp index 70c33e27aa0a3..90cbf0a435423 100644 --- a/aten/src/ATen/native/cuda/GroupedBlas.cpp +++ b/aten/src/ATen/native/cuda/GroupedBlas.cpp @@ -691,6 +691,9 @@ std::optional out_dtype) { // To enable CK path, use env variable ROCM_ALLOW_GROUP_GEMM_CK=1. 
bool use_fast_path = false; // ifdef USE_ROCM_CK_GEMM is required since ROCm systems w/o CK should not call ck path. + // NOTE: gfx1250 is intentionally excluded. The CK grouped GEMM path dispatches + // Wave64/MFMA-style XDL templates; gfx1250 is Wave32 and needs a WMMA/SWMMAC + // path, so it must stay on the fallback until a gfx1250-safe CK path exists. #if defined(USE_ROCM_CK_GEMM) if (at::globalContext().rocmAllowGroupGemmCk() && at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950", "gfx90a"})) { use_fast_path = true; @@ -699,7 +702,9 @@ std::optional out_dtype) { const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype); Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_); if (use_fast_path) { +#if defined(USE_ROCM_CK_GEMM) at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out); +#endif //USE_ROCM_CK_GEMM } else { _grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out); } diff --git a/aten/src/ATen/native/cuda/MemoryAccess.cuh b/aten/src/ATen/native/cuda/MemoryAccess.cuh index 373b44cca7901..e5506b02d267c 100644 --- a/aten/src/ATen/native/cuda/MemoryAccess.cuh +++ b/aten/src/ATen/native/cuda/MemoryAccess.cuh @@ -187,6 +187,10 @@ template __device__ aligned_vector load_vector(const scalar_t *base_ptr, uint32_t offset) { using vec_t = aligned_vector; auto *from = reinterpret_cast(base_ptr); + // This nontemporal vectorized load is tuned for gfx942 (CDNA3, Wave64). Do NOT + // extend to gfx1250 (GFX12.5, Wave32) without first validating intrinsic + // codegen and re-benchmarking -- the register layout and wave width differ. + // gfx1250 falls through to the generic path below, which is correctness-safe. 
#if defined(USE_ROCM) && defined(__gfx942__) using longx2 = __attribute__((__vector_size__(4*sizeof(int)))) int; if constexpr (sizeof(vec_t) == sizeof(int)) { diff --git a/aten/src/ATen/native/cuda/RangeFactories.cu b/aten/src/ATen/native/cuda/RangeFactories.cu index 9d7ead7e49892..bc8decb5469ec 100644 --- a/aten/src/ATen/native/cuda/RangeFactories.cu +++ b/aten/src/ATen/native/cuda/RangeFactories.cu @@ -8,6 +8,9 @@ #include #include #include +#if defined(USE_ROCM) +#include +#endif #ifndef AT_PER_OPERATOR_HEADERS #include @@ -48,12 +51,49 @@ __global__ void elementwise_kernel_with_index(index_t N, func_t f, typename func } } +#if defined(USE_ROCM) +// HIP does not support launches with gridDim.x * blockDim.x >= 2^32: +// depending on the ROCm version the launch returns +// hipErrorInvalidConfiguration or is accepted silently with the kernel +// never executing, leaving zero-initialized output. A grid-stride kernel +// with a fixed grid sized to device occupancy avoids the limit. +template +C10_LAUNCH_BOUNDS_1(num_threads()) +__global__ void elementwise_kernel_with_index_grid_stride( + index_t N, func_t f, + typename function_traits::result_type *data) { + index_t idx = static_cast(blockIdx.x) * blockDim.x + threadIdx.x; + const index_t stride = static_cast(gridDim.x) * blockDim.x; + for (; idx < N; idx = (N - idx > stride) ? idx + stride : N) { /* never forms idx + stride >= N: with a 32-bit index_t that sum can overflow (UB) when N is near the type's maximum */ + data[idx] = f(idx); + } +} +#endif + template void gpu_kernel_with_index(at::Tensor &output, func_t f) { int64_t N = output.numel(); if (N == 0) { return; } +#if defined(USE_ROCM) + constexpr int blocks_per_sm = 4; + const int sm_count = + at::cuda::getCurrentDeviceProperties()->multiProcessorCount; + const int64_t orig_grid = (N + block_work_size - 1) / block_work_size; + int64_t grid = std::min( + orig_grid, static_cast(sm_count) * blocks_per_sm); + grid = std::max(grid, 1); + auto stream = at::cuda::getCurrentCUDAStream(); + using scalar_t = typename function_traits::result_type; + if (N <= std::numeric_limits::max()) { + 
elementwise_kernel_with_index_grid_stride<<>>(N, f, output.mutable_data_ptr()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } else { + elementwise_kernel_with_index_grid_stride<<>>(N, f, output.mutable_data_ptr()); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + } +#else int64_t grid = (N + block_work_size - 1) / block_work_size; auto stream = at::cuda::getCurrentCUDAStream(); using scalar_t = typename function_traits::result_type; @@ -64,6 +104,7 @@ void gpu_kernel_with_index(at::Tensor &output, func_t f) { elementwise_kernel_with_index<<>>(N, f, output.mutable_data_ptr()); C10_CUDA_KERNEL_LAUNCH_CHECK(); } +#endif } } // namespace diff --git a/aten/src/ATen/native/cuda/ScaledBlas.cpp b/aten/src/ATen/native/cuda/ScaledBlas.cpp index 223f10c53a318..0522d0f4a601c 100644 --- a/aten/src/ATen/native/cuda/ScaledBlas.cpp +++ b/aten/src/ATen/native/cuda/ScaledBlas.cpp @@ -78,6 +78,9 @@ static bool _scaled_mm_allowed_device(bool sm90_only=false, bool sm100_only=fals #endif #if ROCM_VERSION >= 60500 "gfx950" +#endif +#if ROCM_VERSION >= 70200 + , "gfx1250" #endif }; return at::detail::getCUDAHooks().isGPUArch(archs); @@ -622,9 +625,14 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, } else if (scaling_choice_a == ScalingType::BlockWise1x32 && scaling_choice_b == ScalingType::BlockWise1x32) { #ifdef USE_ROCM - #if ROCM_VERSION >= 70000 +#if ROCM_VERSION >= 70000 +#if ROCM_VERSION >= 70200 + TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950", "gfx1250"}), + "Block-wise scaling for Float8_e8m0fnu is only supported on gfx950/gfx1250"); +#else TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950"}), "Block-wise scaling for Float8_e8m0fnu is only supported on gfx950"); +#endif int packed_factor = 1; if (mat1.scalar_type() == ScalarType::Float4_e2m1fn_x2) { @@ -1067,8 +1075,13 @@ _scaled_mxfp8_mxfp8( #ifdef USE_ROCM #if ROCM_VERSION >= 70000 +#if ROCM_VERSION >= 70200 + 
TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950", "gfx1250"}), + "Block-wise scaling for Float8_e8m0fnu is only supported on gfx950/gfx1250"); +#else TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950"}), "Block-wise scaling for Float8_e8m0fnu is only supported on gfx950"); +#endif TORCH_CHECK_VALUE(mat_a.size(0) % 32 == 0 && mat_a.size(1) % 32 == 0 && mat_b.size(0) % 32 == 0 && mat_b.size(1) % 32 == 0, @@ -1153,8 +1166,13 @@ _scaled_mxfp4_mxfp4( auto scaling_choice_b = ScalingType::BlockWise1x32; #if ROCM_VERSION >= 70000 +#if ROCM_VERSION >= 70200 + TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950", "gfx1250"}), + "Block-wise scaling for Float8_e8m0fnu is only supported on gfx950/gfx1250"); +#else TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950"}), "Block-wise scaling for Float8_e8m0fnu is only supported on gfx950"); +#endif TORCH_CHECK_VALUE(mat_a.size(0) % 32 == 0 && mat_a.size(1) % 32 == 0 && mat_b.size(0) % 32 == 0 && mat_b.size(1) % 32 == 0, diff --git a/aten/src/ATen/native/cuda/int4mm.cu b/aten/src/ATen/native/cuda/int4mm.cu index 8765bed83345a..b0d2126beac26 100644 --- a/aten/src/ATen/native/cuda/int4mm.cu +++ b/aten/src/ATen/native/cuda/int4mm.cu @@ -127,7 +127,8 @@ inline __host__ __device__ uint32_t getAlignmentRoundUp(const void* p) { return diff == 0 ? 
0 : uint32_t(Align) - diff; } -#if defined (__gfx90a__) || defined(__gfx942__) || defined(__gfx950__) +// CDNA2+ arch with MFMA and Wave64 support +#if defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__) #define CDNA2_OR_LATER 1 #else #define CDNA2_OR_LATER 0 @@ -146,6 +147,12 @@ static bool isCDNA2orLater(int index) { return at::detail::getCUDAHooks().isGPUArch({"gfx90a", "gfx942", "gfx950"}, index); } +// Conceptual for now and subject to change +// gfx1250 (CDNA5 / CDNA-next / UDNA) +static bool isCDNA5orLater(int index) { + return at::detail::getCUDAHooks().isGPUArch({"gfx1250"}, index); +} + #else constexpr int32_t kWarpSize = 32; #endif @@ -1098,6 +1105,11 @@ at::Tensor _weight_int4pack_mm_cuda( A.device() == B.device() && A.device() == qScaleAndZeros.device()); #if defined(USE_ROCM) + if (isCDNA5orLater(A.device().index())) { + TORCH_CHECK(false, + "_weight_int4pack_mm_cuda is not yet supported on gfx1250. " + "A WMMA-based implementation is required for gfx1250."); + } if (!isCDNA2orLater(A.device().index())) { TORCH_CHECK(false, "_weight_int4pack_mm_cuda is only supported on AMD gpu arch greater than or equal to CDNA2"); } @@ -1293,6 +1305,11 @@ at::Tensor _convert_weight_to_int4pack_cuda( TORCH_CHECK(innerKTiles == 2 || innerKTiles == 4 || innerKTiles == 8); #if defined(USE_ROCM) + if (isCDNA5orLater(in.device().index())) { + TORCH_CHECK(false, + "_convert_weight_to_int4pack_cuda is not yet supported on gfx1250. 
" + "A WMMA-based implementation is required for gfx1250."); + } if (!isCDNA2orLater(in.device().index())) { TORCH_CHECK(false, "_convert_weight_to_int4pack_cuda is only supported on AMD gpu arch greater than or equal to CDNA2"); } diff --git a/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp b/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp index 9d735ac0f2c88..b3ea27f1d357e 100644 --- a/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp +++ b/aten/src/ATen/native/sparse/cuda/cuSPARSELtOps.cpp @@ -30,7 +30,16 @@ static void initHipSparseLtSupport() { // Check only the first available device try { if (at::cuda::device_count() > 0) { - g_hipSparseLtSupported = at::detail::getCUDAHooks().isGPUArch({"gfx950", "gfx942"}, 0); + // hipSparseLt gfx1250 support requires ROCm 7.2+. Older ROCm builds + // should not treat gfx1250 as supported here to avoid failing deeper + // in hipSparseLt after we pass this gate. +#if ROCM_VERSION >= 70200 + g_hipSparseLtSupported = at::detail::getCUDAHooks().isGPUArch( + {"gfx950", "gfx942", "gfx1250"}, 0); +#else + g_hipSparseLtSupported = at::detail::getCUDAHooks().isGPUArch( + {"gfx950", "gfx942"}, 0); +#endif } } catch (const std::exception&) { // If an exception occurs during device property check, we assume hipSparseLt is not supported @@ -48,9 +57,13 @@ static bool isHipSparseLtSupported() { if (!g_hipSparseLtSupported) { TORCH_CHECK( false, - "hipSparseLt not supported on this device, supported architectures: " - "gfx950, gfx942. " - "required ROCM version: 6.4.0 or later."); + "hipSparseLt not supported on this device. Supported architectures: " +#if ROCM_VERSION >= 70200 + "gfx1250, gfx950, gfx942. " +#else + "gfx950, gfx942 (gfx1250 requires a PyTorch build against ROCm 7.2 or newer). 
" +#endif + "hipSparseLt on ROCm requires ROCm 7.12 or newer."); } return g_hipSparseLtSupported; } diff --git a/aten/src/ATen/native/transformers/cuda/attention.cu b/aten/src/ATen/native/transformers/cuda/attention.cu index f2df246521fb5..7fee27adf9fea 100644 --- a/aten/src/ATen/native/transformers/cuda/attention.cu +++ b/aten/src/ATen/native/transformers/cuda/attention.cu @@ -1493,6 +1493,10 @@ std::tuple _efficient_ // compute_logsumexp is false constexpr int kAlignLSE = 1; res = at::empty({B, M, num_heads, Kv}, query.options()); + // TODO: Use Compact Varlen LSE + // The current memory allocation is strictly larger than necessary + // (total_q <= max_seqlen_q * B) + // The problem is total_q is not available here. at::Tensor softmax_lse; logsumexp = at::empty( { B, num_heads, compute_logsumexp ? max_seqlen_q : 0}, @@ -1521,8 +1525,6 @@ std::tuple _efficient_ atomic_counter = at::zeros({1}, query.options().dtype(at::kInt)); } - using aotriton::v2::flash::attn_fwd; - using aotriton::v2::flash::attn_fwd_compact_varlen; using sdp::aotriton_adapter::mk_aotensor; using sdp::aotriton_adapter::mk_aoscalartensor; using sdp::aotriton_adapter::mk_philoxtensor; @@ -1538,92 +1540,47 @@ std::tuple _efficient_ auto offset_output = mk_philoxtensor(use_philox_state ? offset_t.data_ptr() : nullptr); auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr() : nullptr); hipError_t err; // TODO: Error handling - if constexpr (AOTRITON_ALWAYS_V3_API) { // Better readability than nesting ifdef -#if AOTRITON_V3_API // if constexpr does not stop errors from undefined functions - using aotriton::v3::flash::CausalType; - using aotriton::v3::flash::VarlenType; - using aotriton::v3::flash::WindowValue; - aotriton::v3::flash::attn_fwd_params params; - params.Q = mk_aotensor(q_t, "q"); - params.K = mk_aotensor(k_t, "k"); - params.V = mk_aotensor(v_t, "v"); - params.Sm_scale = softmax_scale; - params.L = compute_logsumexp ? 
mk_aotensor<2>(softmax_lse, "M") : empty_t2; - params.Out = mk_aotensor(output_t, "Out"); - params.Max_seqlen_q = max_seqlen_q; // Unused if cu_seqlens_q is empty - params.Max_seqlen_k = max_seqlen_k; // Unused if cu_seqlens_k is empty - params.dropout_p = dropout_p; - params.philox_seed_ptr = seed; - params.philox_offset1 = offset1; - params.philox_offset2 = offset2; - params.philox_seed_output = seed_output; - params.philox_offset_output = offset_output; - params.encoded_softmax = mk_aotensor(softmax_fa_t, "encoded_softmax"); - params.persistent_atomic_counter = persistent_counter; - params.causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None; - if (static_cast(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) { - params.window_left = WindowValue::TopLeftAligned; - params.window_right = WindowValue::TopLeftAligned; - } else if (static_cast(sdp::CustomMaskType::CausalFromBottomRight) == custom_mask_type) { - params.window_left = WindowValue::BottomRightAligned; - params.window_right = WindowValue::BottomRightAligned; - } - if (bias.has_value()) { - params.B = mk_aotensor(bias.value(), "bias"); - } - if (seqstart_q.has_value()) { - params.varlen_type = VarlenType::CompactVarlen; - params.cu_seqlens_q = mk_aotensor<1>(seqstart_q.value(), "cu_seqlens_q"); - params.cu_seqlens_k = mk_aotensor<1>(seqstart_k.value(), "cu_seqlens_k"); - } else { - params.varlen_type = VarlenType::None; - } - err = aotriton::v3::flash::attn_fwd(params, - aotriton::v3::flash::attn_fwd_params::kVersion, - stream); -#endif // AOTRITON_V3_API - } else if (seqstart_q.has_value()) { - // varlen aka nested tensor - err = attn_fwd_compact_varlen(mk_aotensor(q_t, "q"), - mk_aotensor(k_t, "k"), - mk_aotensor(v_t, "v"), - bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4, - mk_aotensor<1>(seqstart_q.value(), "cu_seqlens_q"), - mk_aotensor<1>(seqstart_k.value(), "cu_seqlens_k"), - max_seqlen_q, - max_seqlen_k, - softmax_scale, - compute_logsumexp ? 
mk_aotensor<2>(softmax_lse, "M") : empty_t2, - mk_aotensor(output_t, "Out"), - dropout_p, - seed, - offset1, - offset2, - seed_output, - offset_output, - mk_aotensor(softmax_fa_t, "encoded_softmax"), - is_causal, - persistent_counter, - stream); + using aotriton::v3::flash::CausalType; + using aotriton::v3::flash::VarlenType; + using aotriton::v3::flash::WindowValue; + aotriton::v3::flash::attn_fwd_params params; + params.Q = mk_aotensor(q_t, "q"); + params.K = mk_aotensor(k_t, "k"); + params.V = mk_aotensor(v_t, "v"); + params.Sm_scale = softmax_scale; + params.L = compute_logsumexp ? mk_aotensor<2>(softmax_lse, "M") : empty_t2; + params.Out = mk_aotensor(output_t, "Out"); + params.Max_seqlen_q = max_seqlen_q; // Unused if cu_seqlens_q is empty + params.Max_seqlen_k = max_seqlen_k; // Unused if cu_seqlens_k is empty + params.dropout_p = dropout_p; + params.philox_seed_ptr = seed; + params.philox_offset1 = offset1; + params.philox_offset2 = offset2; + params.philox_seed_output = seed_output; + params.philox_offset_output = offset_output; + params.encoded_softmax = mk_aotensor(softmax_fa_t, "encoded_softmax"); + params.persistent_atomic_counter = persistent_counter; + params.causal_type = is_causal ? 
CausalType::WindowedAttention : CausalType::None; + if (static_cast(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) { + params.window_left = WindowValue::TopLeftAligned; + params.window_right = WindowValue::TopLeftAligned; + } else if (static_cast(sdp::CustomMaskType::CausalFromBottomRight) == custom_mask_type) { + params.window_left = WindowValue::BottomRightAligned; + params.window_right = WindowValue::BottomRightAligned; + } + if (bias.has_value()) { + params.B = mk_aotensor(bias.value(), "bias"); + } + if (seqstart_q.has_value()) { + params.varlen_type = VarlenType::CompactVarlen; + params.cu_seqlens_q = mk_aotensor<1>(seqstart_q.value(), "cu_seqlens_q"); + params.cu_seqlens_k = mk_aotensor<1>(seqstart_k.value(), "cu_seqlens_k"); } else { - err = attn_fwd(mk_aotensor(q_t, "q"), - mk_aotensor(k_t, "k"), - mk_aotensor(v_t, "v"), - bias.has_value() ? mk_aotensor(bias.value(), "bias"): empty_t4, - softmax_scale, - compute_logsumexp ? mk_aotensor<2>(softmax_lse, "M") : empty_t2, - mk_aotensor(output_t, "Out"), - dropout_p, - seed, - offset1, - offset2, - seed_output, - offset_output, - mk_aotensor(softmax_fa_t, "encoded_softmax"), - is_causal, - persistent_counter, - stream); + params.varlen_type = VarlenType::None; } + err = aotriton::v3::flash::attn_fwd(params, + aotriton::v3::flash::attn_fwd_params::kVersion, + stream); #else TORCH_CHECK(false, "Attempting to use AOTriton mem_eff_forward backend in a build that has not built AOTriton"); #endif diff --git a/aten/src/ATen/native/transformers/cuda/attention_backward.cu b/aten/src/ATen/native/transformers/cuda/attention_backward.cu index 183f99e975cda..b7e52617a0ec6 100644 --- a/aten/src/ATen/native/transformers/cuda/attention_backward.cu +++ b/aten/src/ATen/native/transformers/cuda/attention_backward.cu @@ -547,6 +547,17 @@ _efficient_attention_backward( "[AOTriton] Accelerated SDPA only supports MI200/MI300X/7900XTX/9070XT GPUs" " (gfx90a/gfx942/gfx1100/gfx1201)") } + bool deterministic{false}; + auto& 
ctx = at::globalContext(); + if (ctx.deterministicAlgorithms()) { + if (ctx.deterministicAlgorithmsWarnOnly()) { + TORCH_WARN_ONCE( + "Memory Efficient attention defaults to a non-deterministic algorithm. ", + "To explicitly enable determinism call torch.use_deterministic_algorithms(True, warn_only=False)."); + } else { + deterministic = true; + } + } const auto softmax_scale = sdp::calculate_scale(query, scale).expect_float(); bool is_causal; if (static_cast(sdp::CustomMaskType::NoCustomMask) == custom_mask_type) { @@ -569,139 +580,60 @@ _efficient_attention_backward( at::Tensor dout_t = grad_out.permute({0,2,1,3}); at::Tensor softmax_lse = logsumexp.view({B * nH, max_seqlen_q}); hipError_t err; - using aotriton::v2::flash::attn_bwd; - using aotriton::v2::flash::attn_bwd_fused; - using aotriton::v2::flash::attn_bwd_compact_varlen; using sdp::aotriton_adapter::mk_aotensor; using sdp::aotriton_adapter::mk_aoscalartensor; using sdp::aotriton_adapter::cast_dtype; aotriton::TensorView<4> empty_t4(0, {0, 0, 0, 0}, {0, 0, 0, 0}, cast_dtype(query.dtype())); - if constexpr (AOTRITON_ALWAYS_V3_API) { // Better readability than nesting ifdef -#if AOTRITON_V3_API // if constexpr does not stop errors from undefined functions - using aotriton::v3::flash::CausalType; - using aotriton::v3::flash::VarlenType; - using aotriton::v3::flash::WindowValue; - aotriton::v3::flash::attn_bwd_params params; - params.Q = mk_aotensor(q_t, "q"); - params.K = mk_aotensor(k_t, "k"); - params.V = mk_aotensor(v_t, "v"); - params.B = bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4; - params.Sm_scale = softmax_scale; - params.Out = mk_aotensor(out_t, "out"); - params.DO = mk_aotensor(dout_t, "dout"); - params.DK = mk_aotensor(dk_t, "dk"); - params.DV = mk_aotensor(dv_t, "dv"); - params.DQ = mk_aotensor(dq_t, "dq"); - params.DB = bias_requires_grad ? 
mk_aotensor(grad_bias, "db") : empty_t4; - params.L = mk_aotensor<2>(softmax_lse, "L"); - params.Max_seqlen_q = max_seqlen_q; // Unused if cu_seqlens_q is empty - params.Max_seqlen_k = max_seqlen_k; // Unused if cu_seqlens_k is empty - params.dropout_p = float(dropout_p); - params.philox_seed_ptr = mk_aoscalartensor(philox_seed); - params.philox_offset1 = mk_aoscalartensor(philox_offset); - params.philox_offset2 = 0; - params.causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None; - if (static_cast(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) { - params.window_left = WindowValue::TopLeftAligned; - params.window_right = WindowValue::TopLeftAligned; - } else if (static_cast(sdp::CustomMaskType::CausalFromBottomRight) == custom_mask_type) { - params.window_left = WindowValue::BottomRightAligned; - params.window_right = WindowValue::BottomRightAligned; - } -#if AOTRITON_ALWAYS_V3_API - using sdp::aotriton_adapter::mklazy_empty_like; - using sdp::aotriton_adapter::mklazy_fp32zeros; - using sdp::aotriton_adapter::LazyTensorContext; - LazyTensorContext lazy_delta { .like_tensor = softmax_lse, .tensor_name = "delta" }; - LazyTensorContext lazy_dq_acc { .like_tensor = dq_t, .tensor_name = "dq_acc" }; - params.D = mklazy_empty_like<2>(&lazy_delta); - params.DQ_ACC = mklazy_fp32zeros<4>(&lazy_dq_acc); -#else - at::Tensor delta = at::empty_like(softmax_lse).contiguous(); - params.D = mk_aotensor<2>(delta, "delta"); -#endif - if (cu_seqlens_q.has_value()) { - params.varlen_type = VarlenType::CompactVarlen; - params.cu_seqlens_q = mk_aotensor<1>(cu_seqlens_q.value(), "cu_seqlens_q"); - params.cu_seqlens_k = mk_aotensor<1>(cu_seqlens_k.value(), "cu_seqlens_k"); - } else { - params.varlen_type = VarlenType::None; - } - err = aotriton::v3::flash::attn_bwd(params, - aotriton::v3::flash::attn_bwd_params::kVersion, - stream); -#endif // AOTRITON_V3_API - } else if (cu_seqlens_q.has_value()) { - at::Tensor delta = 
at::empty_like(softmax_lse).contiguous(); - // varlen aka Nested tensor - err = attn_bwd_compact_varlen(mk_aotensor(q_t, "q"), - mk_aotensor(k_t, "k"), - mk_aotensor(v_t, "v"), - mk_aotensor<1>(cu_seqlens_q.value(), "cu_seqlens_q"), - mk_aotensor<1>(cu_seqlens_k.value(), "cu_seqlens_k"), - max_seqlen_q, - max_seqlen_k, - bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4, - softmax_scale, - mk_aotensor(out_t, "out"), - mk_aotensor(dout_t, "dout"), - mk_aotensor(dq_t, "dq"), - mk_aotensor(dk_t, "dk"), - mk_aotensor(dv_t, "dv"), - bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4, - mk_aotensor<2>(softmax_lse, "L"), - mk_aotensor<2>(delta, "delta"), - float(dropout_p), - mk_aoscalartensor(philox_seed), - mk_aoscalartensor(philox_offset), - 0, - is_causal, - stream); - } else { // cu_seqlens.has_value - auto d_head = Kv; - bool use_fused_bwd = d_head <= 192 && d_head * max_seqlen_q < 64 * 512; - if (use_fused_bwd) { - err = attn_bwd_fused(mk_aotensor(q_t, "q"), - mk_aotensor(k_t, "k"), - mk_aotensor(v_t, "v"), - bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4, - softmax_scale, - mk_aotensor(out_t, "out"), - mk_aotensor(dout_t, "dout"), - mk_aotensor(dq_t, "dq"), - mk_aotensor(dk_t, "dk"), - mk_aotensor(dv_t, "dv"), - bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4, - mk_aotensor<2>(softmax_lse, "L"), - float(dropout_p), - mk_aoscalartensor(philox_seed), - mk_aoscalartensor(philox_offset), - 0, - is_causal, - stream); - } else { - at::Tensor delta = at::empty_like(softmax_lse).contiguous(); - err = attn_bwd(mk_aotensor(q_t, "q"), - mk_aotensor(k_t, "k"), - mk_aotensor(v_t, "v"), - bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4, - softmax_scale, - mk_aotensor(out_t, "out"), - mk_aotensor(dout_t, "dout"), - mk_aotensor(dq_t, "dq"), - mk_aotensor(dk_t, "dk"), - mk_aotensor(dv_t, "dv"), - bias_requires_grad ? 
mk_aotensor(grad_bias, "db") : empty_t4, - mk_aotensor<2>(softmax_lse, "L"), - mk_aotensor<2>(delta, "delta"), - float(dropout_p), - mk_aoscalartensor(philox_seed), - mk_aoscalartensor(philox_offset), - 0, - is_causal, - stream); - } //used_fused_bwd - } // cuseqlen.has_value + using aotriton::v3::flash::CausalType; + using aotriton::v3::flash::VarlenType; + using aotriton::v3::flash::WindowValue; + aotriton::v3::flash::attn_bwd_params params; + params.Q = mk_aotensor(q_t, "q"); + params.K = mk_aotensor(k_t, "k"); + params.V = mk_aotensor(v_t, "v"); + params.B = bias.has_value() ? mk_aotensor(bias.value(), "bias") : empty_t4; + params.Sm_scale = softmax_scale; + params.Out = mk_aotensor(out_t, "out"); + params.DO = mk_aotensor(dout_t, "dout"); + params.DK = mk_aotensor(dk_t, "dk"); + params.DV = mk_aotensor(dv_t, "dv"); + params.DQ = mk_aotensor(dq_t, "dq"); + params.DB = bias_requires_grad ? mk_aotensor(grad_bias, "db") : empty_t4; + params.L = mk_aotensor<2>(softmax_lse, "L"); + params.Max_seqlen_q = max_seqlen_q; // Unused if cu_seqlens_q is empty + params.Max_seqlen_k = max_seqlen_k; // Unused if cu_seqlens_k is empty + params.dropout_p = float(dropout_p); + params.philox_seed_ptr = mk_aoscalartensor(philox_seed); + params.philox_offset1 = mk_aoscalartensor(philox_offset); + params.philox_offset2 = 0; + params.causal_type = is_causal ? 
CausalType::WindowedAttention : CausalType::None; + if (static_cast(sdp::CustomMaskType::CausalFromTopLeft) == custom_mask_type) { + params.window_left = WindowValue::TopLeftAligned; + params.window_right = WindowValue::TopLeftAligned; + } else if (static_cast(sdp::CustomMaskType::CausalFromBottomRight) == custom_mask_type) { + params.window_left = WindowValue::BottomRightAligned; + params.window_right = WindowValue::BottomRightAligned; + } + using sdp::aotriton_adapter::mklazy_empty_like; + using sdp::aotriton_adapter::mklazy_fp32zeros; + using sdp::aotriton_adapter::LazyTensorContext; + LazyTensorContext lazy_delta { .like_tensor = softmax_lse, .tensor_name = "delta" }; + LazyTensorContext lazy_dq_acc { .like_tensor = dq_t, .tensor_name = "dq_acc" }; + params.D = mklazy_empty_like<2>(&lazy_delta); + params.DQ_ACC = mklazy_fp32zeros<4>(&lazy_dq_acc); + if (cu_seqlens_q.has_value()) { + params.varlen_type = VarlenType::CompactVarlen; + params.cu_seqlens_q = mk_aotensor<1>(cu_seqlens_q.value(), "cu_seqlens_q"); + params.cu_seqlens_k = mk_aotensor<1>(cu_seqlens_k.value(), "cu_seqlens_k"); + } else { + params.varlen_type = VarlenType::None; + } + aotriton::v3::flash::attn_options opts; + opts.deterministic = deterministic; + err = aotriton::v3::flash::attn_bwd(params, + aotriton::v3::flash::attn_bwd_params::kVersion, + stream, + &opts); #else // DISABLE_AOTRITON TORCH_CHECK(false, "Attempting to use aotriton mem_eff_backward backend in a build that has not built AOTriton"); #endif diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp index 67a5a296e2afe..72bc84e2d10a2 100644 --- a/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp +++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.cpp @@ -140,19 +140,11 @@ int64_t minimum_gemm_alignment(sdp_params const& params) { return matmul_alignment_mn; } -// On ROCM, ME and FA share the backend, and hence they share the checking -// function for fundamental 
limitations by the GPU kernel -// caller_is_meff is added to make the TORCH_WARN message showing the correct result -template -bool check_head_dim_size_flash(sdp_params const& params, bool debug) { #if USE_ROCM_ATTENTION - if (at::cuda::device_count() == 0) { - return false; - } - // AOTriton 0.9+ supports head_dim up to 512 - const static auto max_hdim = []() { +inline int aotriton_max_hdim() { + static const int max_hdim = []() { #if AOTRITON_VERSION_CURRENT == AOTRITON_VERSION_INT(0, 11) - // gfx11xx only support hdim <= 256 on AOTriton 0.11 + // gfx11xx only support hdim <= 256 on AOTriton 0.11/0.12 auto dprops = at::cuda::getCurrentDeviceProperties(); const c10::basic_string_view arch(dprops->gcnArchName); if (arch.starts_with("gfx11")) { @@ -165,7 +157,28 @@ bool check_head_dim_size_flash(sdp_params const& params, bool debug) { return 256; #endif }(); - const auto max_size = c10::SymInt(max_hdim); + return max_hdim; +} +#endif // USE_ROCM_ATTENTION + +// For AOTriton <= 0.11: +// On ROCM, ME and FA share the backend, and hence they share the checking +// function for fundamental limitations by the GPU kernel +// caller_is_meff is added to make the TORCH_WARN message showing the correct result +// +// FIXME: revert this reuse when removing AOTriton <= 0.11 support +// +// AOTriton 0.12 supports hdim_qk != hdim_vo, but we cannot enable this in +// check_head_dim_size_flash because it changes the backend selection logic for +// FA, which can break certain workloads that rely on the behavior of rejecting +// FA for hdim_qk != hdim_vo +template +bool check_head_dim_size_flash(sdp_params const& params, bool debug) { +#if USE_ROCM_ATTENTION + if (at::cuda::device_count() == 0) { + return false; + } + const auto max_size = c10::SymInt(aotriton_max_hdim()); #else // All head_dim sizes must be equal and less than 256 const auto max_size = c10::SymInt(256); @@ -245,9 +258,20 @@ bool check_head_dim_size_flash_nested(sdp_params const& params, bool debug) { } bool 
check_head_dim_size_mem_efficient(sdp_params const& params, bool debug) { +#if USE_ROCM_ATTENTION +#if AOTRITON_VERSION_CURRENT < AOTRITON_VERSION_INT(0, 12) + return check_head_dim_size_flash_nested(params, debug); +#endif +#endif const auto query_size_last = params.query.sym_size(-1); const auto value_size_last = params.value.sym_size(-1); +#ifdef USE_ROCM + bool is_half = (params.query.dtype() == at::kHalf) || + (params.query.dtype() == at::kBFloat16); + const int64_t alignment = is_half ? 8 : 4; +#else const int64_t alignment = minimum_gemm_alignment(params); +#endif if (!(query_size_last == params.key.sym_size(-1) && query_size_last % alignment == 0 && query_size_last > 0 && value_size_last % alignment == 0 && value_size_last > 0)) { @@ -266,6 +290,27 @@ bool check_head_dim_size_mem_efficient(sdp_params const& params, bool debug) { } return false; } +#if USE_ROCM_ATTENTION +#if AOTRITON_VERSION_CURRENT >= AOTRITON_VERSION_INT(0, 12) + const auto max_size = c10::SymInt(aotriton_max_hdim()); + if (!(query_size_last <= max_size && value_size_last <= max_size)) { + if (debug) { + TORCH_WARN( + "Mem efficient attention on ROCM requires last dimension of inputs to less or equal than ", + max_size, + ". ", + "Got Query.size(-1): ", + query_size_last, + ", Key.size(-1): ", + params.key.sym_size(-1), + ", Value.size(-1): ", + params.value.sym_size(-1), + " instead. (Note this limit differs among architectures)"); + } + return false; + } +#endif // AOTRITON_VERSION_CURRENT >= AOTRITON_VERSION_INT(0, 12) +#endif // USE_ROCM_ATTENTION return true; } @@ -301,6 +346,7 @@ bool check_flash_attention_hardware_support(sdp_params const& params, bool debug using sm121 = SMVersion<12, 1>; #if USE_ROCM #if USE_ROCM_ATTENTION +// TODO: gfx1250 if(at::globalContext().getROCmFAPreferredBackend() == at::ROCmFABackend::Ck) { // User explicitly set CK as the flash attention backend. 
Return true for now // TODO: Flesh out sanity checks @@ -867,11 +913,7 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) { check_all_tensors_on_device, check_mem_efficient_hardware_support, check_tensor_shapes, -#ifdef USE_ROCM - check_head_dim_size_flash -#else check_head_dim_size_mem_efficient -#endif ); for (auto& constraint : general_constraints) { if (!constraint(params, debug)) { @@ -881,11 +923,6 @@ bool can_use_mem_efficient_attention(sdp_params const& params, bool debug) { if (has_for_nested_inputs(params)) { constexpr auto nested_constraints = c10::array_of( -#ifndef USE_ROCM // ME and FA shares backend on ROCM and thus supports training - check_requires_grad_and_nested, -#else // Meanwhile ME on ROCM share the limits of FA about head dimensions - check_head_dim_size_flash_nested, -#endif check_batch_size_nested, check_for_seq_len_0_nested_tensor); for (auto& constraint : nested_constraints) { diff --git a/aten/src/ATen/native/transformers/hip/aotriton_adapter.h b/aten/src/ATen/native/transformers/hip/aotriton_adapter.h index e40376ae0c3a7..c0cac2a76a2a5 100644 --- a/aten/src/ATen/native/transformers/hip/aotriton_adapter.h +++ b/aten/src/ATen/native/transformers/hip/aotriton_adapter.h @@ -11,6 +11,10 @@ #include #include +#if AOTRITON_VERSION_CURRENT >= AOTRITON_VERSION_INT(0, 12) +#define AOTRITON_V2_API_FLASH_ATTN_H // Suppress the include of deprecated flash/v2.h +#endif + //////////////////////////////////////////////////////////////////////////////// // Common macros copied from cuda/mem_eff_attention/gemm_kernel_utils.h //////////////////////////////////////////////////////////////////////////////// @@ -127,8 +131,17 @@ struct LazyTensorContext { template struct LazyTensorFunctions : public LazyTensorContext { - static aotriton::TensorView acquire(void* cookie) { - auto ctx = (LazyTensorContext*)cookie; +#if AOTRITON_VERSION_CURRENT >= AOTRITON_VERSION_INT(0, 12) + using HolderType = aotriton::LazyTensor; +#else + using 
HolderType = void; +#endif + static aotriton::TensorView acquire(HolderType* self) { +#if AOTRITON_VERSION_CURRENT >= AOTRITON_VERSION_INT(0, 12) + auto ctx = (LazyTensorContext*)self->cookie; +#else + auto ctx = (LazyTensorContext*)self; +#endif if (!ctx->tensor.defined()) { auto q = ctx->like_tensor; if constexpr (kRequireZeros) { @@ -141,7 +154,7 @@ struct LazyTensorFunctions : public LazyTensorContext { return mk_aotensor(ctx->tensor, ctx->tensor_name); } - static void dispose(void* cookie) { + static void dispose(HolderType* cookie) { } }; diff --git a/aten/src/ATen/native/transformers/hip/aotriton_versions.h b/aten/src/ATen/native/transformers/hip/aotriton_versions.h index 2f5d3f0e12228..8b06cc1acc789 100644 --- a/aten/src/ATen/native/transformers/hip/aotriton_versions.h +++ b/aten/src/ATen/native/transformers/hip/aotriton_versions.h @@ -17,4 +17,10 @@ #define AOTRITON_V3_API 0 #endif +#if AOTRITON_VERSION_CURRENT >= AOTRITON_VERSION_INT(0, 12) +#define AOTRITON_COMPACT_VARLEN_LSE 1 +#else +#define AOTRITON_COMPACT_VARLEN_LSE 0 +#endif + #endif diff --git a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip index e809f23e61def..c16f7d1aad233 100644 --- a/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip +++ b/aten/src/ATen/native/transformers/hip/flash_attn/aot/mha_all_aot.hip @@ -257,7 +257,6 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x } hipError_t err; // TODO: Error handling - using aotriton::v2::flash::attn_fwd; using sdp::aotriton_adapter::mk_aotensor; using sdp::aotriton_adapter::mk_aoscalartensor; using sdp::aotriton_adapter::mk_philoxtensor; @@ -270,56 +269,42 @@ mha_fwd_aot(const at::Tensor &q, // batch_size x seqlen_q x num_heads x auto seed_output = mk_philoxtensor(use_philox_state ? seed_t.data_ptr() : nullptr); auto offset_output = mk_philoxtensor(use_philox_state ? 
offset_t.data_ptr() : nullptr); auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr() : nullptr); - if (uses_swa || AOTRITON_ALWAYS_V3_API) { -#if AOTRITON_V3_API - using aotriton::v3::flash::CausalType; - using aotriton::v3::flash::VarlenType; - aotriton::v3::flash::attn_fwd_params params; - params.Q = mk_aotensor(q_t, "q"); - params.K = mk_aotensor(k_t, "k"); - params.V = mk_aotensor(v_t, "v"); - params.Sm_scale = softmax_scale; - params.L = mk_aotensor<2>(M, "M"); - params.Out = mk_aotensor(output_t, "Out"); - params.Max_seqlen_q = seqlen_q; // Unused if cu_seqlens_q is empty - params.Max_seqlen_k = seqlen_k; // Unused if cu_seqlens_k is empty - params.dropout_p = p_dropout; - params.philox_seed_ptr = seed; - params.philox_offset1 = offset1; - params.philox_offset2 = offset2; - params.philox_seed_output = seed_output; - params.philox_offset_output = offset_output; - params.encoded_softmax = mk_aotensor(softmax_fa_t, "encoded_softmax"); - params.persistent_atomic_counter = persistent_counter; - params.causal_type = is_causal ? 
CausalType::WindowedAttention : CausalType::None; - params.varlen_type = VarlenType::None; - params.window_left = window_left; - params.window_right = window_right; - err = aotriton::v3::flash::attn_fwd(params, - aotriton::v3::flash::attn_fwd_params::kVersion, - stream); -#endif - } else { - err = attn_fwd(mk_aotensor(q_t, "q"), - mk_aotensor(k_t, "k"), - mk_aotensor(v_t, "v"), - empty_bias, - softmax_scale, - mk_aotensor<2>(M, "M"), - mk_aotensor(output_t, "Out"), - p_dropout, - seed, - offset1, - offset2, - seed_output, - offset_output, - mk_aotensor(softmax_fa_t, "encoded_softmax"), - is_causal, - persistent_counter, - stream); - } - - return {out, q_padded, k_padded, v_padded, M.view({batch_size, num_heads, seqlen_q}), seed_t, offset_t, softmax_fa_t}; + using aotriton::v3::flash::CausalType; + using aotriton::v3::flash::VarlenType; + aotriton::v3::flash::attn_fwd_params params; + params.Q = mk_aotensor(q_t, "q"); + params.K = mk_aotensor(k_t, "k"); + params.V = mk_aotensor(v_t, "v"); + params.Sm_scale = softmax_scale; + params.L = mk_aotensor<2>(M, "M"); + params.Out = mk_aotensor(output_t, "Out"); + params.Max_seqlen_q = seqlen_q; // Unused if cu_seqlens_q is empty + params.Max_seqlen_k = seqlen_k; // Unused if cu_seqlens_k is empty + params.dropout_p = p_dropout; + params.philox_seed_ptr = seed; + params.philox_offset1 = offset1; + params.philox_offset2 = offset2; + params.philox_seed_output = seed_output; + params.philox_offset_output = offset_output; + params.encoded_softmax = mk_aotensor(softmax_fa_t, "encoded_softmax"); + params.persistent_atomic_counter = persistent_counter; + params.causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None; + params.varlen_type = VarlenType::None; + params.window_left = window_left; + params.window_right = window_right; + err = aotriton::v3::flash::attn_fwd(params, + aotriton::v3::flash::attn_fwd_params::kVersion, + stream); + // Note: These are propagated up to the return of mha_fwd(). 
comments + // represent the assignments at that level + return {out, // output + q_padded, // q_padded + k_padded, // k_padded + v_padded, // v_padded + M.view({batch_size, num_heads, seqlen_q}), // logsumexp + seed_t, // philox_seed + offset_t, // philox_offset + softmax_fa_t};// debug_attn_mask } std::tuple @@ -342,7 +327,7 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot std::optional window_size_right, const bool return_softmax, const std::optional& gen_) { - TORCH_CHECK(!seqused_k.has_value(), "[ROCm] mha_varlen_fwd: seqused_k must be nullopt"); + bool strided_varlen = seqused_k.has_value(); const bool paged_KV = block_table_.has_value(); TORCH_CHECK(!paged_KV, "[ROCm] mha_varlen_fwd: block_table_ must be nullopt"); TORCH_CHECK(!alibi_slopes_.has_value(), "[ROCm] mha_varlen_fwd: alibi_slopes_ must be nullopt"); @@ -417,8 +402,13 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot auto opts = q.options(); +#if AOTRITON_COMPACT_VARLEN_LSE + auto softmax_lse = at::empty({num_heads, total_q}, opts.dtype(at::kFloat)); + at::Tensor M = softmax_lse; +#else auto softmax_lse = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); at::Tensor M = softmax_lse.view({batch_size * num_heads, max_seqlen_q}); +#endif at::Tensor softmax_fa_t; // Only return softmax if there's dropout to reduce compilation time if (return_softmax) { @@ -457,7 +447,6 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot if (max_seqlen_k > 0) { hipError_t err; // TODO: Error handling - using aotriton::v2::flash::attn_fwd_compact_varlen; using sdp::aotriton_adapter::mk_aotensor; using sdp::aotriton_adapter::mk_aoscalartensor; using sdp::aotriton_adapter::mk_philoxtensor; @@ -475,60 +464,48 @@ mha_varlen_fwd_aot(const at::Tensor &q, // total_q x num_heads x head_size, tot auto seed_output = use_philox_state ? 
mk_philoxtensor(seed_t.data_ptr()) : nullscalar; auto offset_output = use_philox_state ? mk_philoxtensor(offset_t.data_ptr()) : nullscalar; auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr() : nullptr); - if (uses_swa || AOTRITON_ALWAYS_V3_API) { -#if AOTRITON_V3_API - using aotriton::v3::flash::CausalType; - using aotriton::v3::flash::VarlenType; - aotriton::v3::flash::attn_fwd_params params; - params.Q = mk_aotensor(q_padded, "q"); - params.K = mk_aotensor(k_padded, "k"); - params.V = mk_aotensor(v_padded, "v"); - params.Sm_scale = softmax_scale; - params.L = mk_aotensor<2>(M, "M"); - params.Out = mk_aotensor(out_padded, "Out"); + using aotriton::v3::flash::CausalType; + using aotriton::v3::flash::VarlenType; + aotriton::v3::flash::attn_fwd_params params; + params.Q = mk_aotensor(q_padded, "q"); + params.K = mk_aotensor(k_padded, "k"); + params.V = mk_aotensor(v_padded, "v"); + params.Sm_scale = softmax_scale; + params.L = mk_aotensor<2>(M, "logsumexp"); + params.Out = mk_aotensor(out_padded, "Out"); + params.Max_seqlen_q = max_seqlen_q; // Unused if cu_seqlens_q is empty + params.Max_seqlen_k = max_seqlen_k; // Unused if cu_seqlens_k is empty + params.dropout_p = p_dropout; + params.philox_seed_ptr = seed; + params.philox_offset1 = offset1; + params.philox_offset2 = offset2; + params.philox_seed_output = seed_output; + params.philox_offset_output = offset_output; + params.encoded_softmax = mk_aotensor(softmax_fa_t, "encoded_softmax"); + params.persistent_atomic_counter = persistent_counter; + params.causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None; + params.varlen_type = strided_varlen ? VarlenType::StridedVarlen : VarlenType::CompactVarlen; + if (strided_varlen) { + // seqused_k holds per-batch actual kv lengths; the kernel expects cumulative + // offsets for cu_seqlens_k so it can compute seqlen_k via differencing. + // seq_strides_k carries the real memory offsets into the KV cache. 
+ const auto& seqused = seqused_k.value(); + const int num_seqs = seqused.size(0); + at::Tensor cu_seqlens_k_from_seqused = at::zeros({num_seqs + 1}, seqused.options()); + cu_seqlens_k_from_seqused.slice(0, 1).copy_(seqused.cumsum(0)); params.cu_seqlens_q = mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"); - params.cu_seqlens_k = mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"); - params.Max_seqlen_q = max_seqlen_q; // Unused if cu_seqlens_q is empty - params.Max_seqlen_k = max_seqlen_k; // Unused if cu_seqlens_k is empty - params.dropout_p = p_dropout; - params.philox_seed_ptr = seed; - params.philox_offset1 = offset1; - params.philox_offset2 = offset2; - params.philox_seed_output = seed_output; - params.philox_offset_output = offset_output; - params.encoded_softmax = mk_aotensor(softmax_fa_t, "encoded_softmax"); - params.persistent_atomic_counter = persistent_counter; - params.causal_type = is_causal ? CausalType::WindowedAttention : CausalType::None; - params.varlen_type = VarlenType::CompactVarlen; - params.window_left = window_left; - params.window_right = window_right; - err = aotriton::v3::flash::attn_fwd(params, - aotriton::v3::flash::attn_fwd_params::kVersion, - stream); -#endif + params.cu_seqlens_k = mk_aotensor<1>(cu_seqlens_k_from_seqused, "cu_seqlens_k"); + params.seq_strides_q = mk_aotensor<1>(cu_seqlens_q, "seq_strides_q"); + params.seq_strides_k = mk_aotensor<1>(cu_seqlens_k, "seq_strides_k"); } else { - err = attn_fwd_compact_varlen(mk_aotensor(q_padded, "q"), - mk_aotensor(k_padded, "k"), - mk_aotensor(v_padded, "v"), - empty_bias, - mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"), - mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"), - max_seqlen_q, - max_seqlen_k, - softmax_scale, - mk_aotensor<2>(M, "M"), - mk_aotensor(out_padded, "Out"), - p_dropout, - seed, - offset1, - offset2, - seed_output, - offset_output, - mk_aotensor(softmax_fa_t, "encoded_softmax"), - is_causal, - persistent_counter, - stream); + params.cu_seqlens_q = mk_aotensor<1>(cu_seqlens_q, 
"cu_seqlens_q"); + params.cu_seqlens_k = mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"); } + params.window_left = window_left; + params.window_right = window_right; + err = aotriton::v3::flash::attn_fwd(params, + aotriton::v3::flash::attn_fwd_params::kVersion, + stream); } else { // If seqlen_k == 0, then we have an empty tensor. We need to set the output to 0. out.zero_(); @@ -670,96 +647,44 @@ mha_bwd_aot(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x hea hipError_t err; // TODO: Error handling using sdp::aotriton_adapter::mk_aotensor; using sdp::aotriton_adapter::mk_aoscalartensor; - if (uses_swa || AOTRITON_ALWAYS_V3_API) { -#if AOTRITON_V3_API - // Fused BWD does not support SWA - using aotriton::v3::flash::CausalType; - using aotriton::v3::flash::VarlenType; - aotriton::v3::flash::attn_bwd_params params; - params.Q = mk_aotensor(q_t, "q"); - params.K = mk_aotensor(k_t, "k"); - params.V = mk_aotensor(v_t, "v"); - params.Sm_scale = softmax_scale; - params.Out = mk_aotensor(out_t, "out"); - params.DO = mk_aotensor(dout_t, "dout"); - params.DQ = mk_aotensor(dq_t, "dq"); - params.DK = mk_aotensor(dk_t, "dk"); - params.DV = mk_aotensor(dv_t, "dv"); - params.L = mk_aotensor<2>(softmax_lse_cont, "L"); - params.Max_seqlen_q = seqlen_q; // Unused if cu_seqlens_q is empty - params.Max_seqlen_k = seqlen_k; // Unused if cu_seqlens_k is empty - params.dropout_p = p_dropout; - params.philox_seed_ptr = mk_aoscalartensor(philox_seed); - params.philox_offset1 = mk_aoscalartensor(philox_offset); - params.philox_offset2 = 0; - // SWA in AOTriton Kernels is treated as "Generalized Causal masks" - params.causal_type = is_causal || uses_swa ? 
CausalType::WindowedAttention : CausalType::None; - params.window_left = window_left; - params.window_right = window_right; - params.varlen_type = VarlenType::None; -#if AOTRITON_ALWAYS_V3_API - using sdp::aotriton_adapter::mklazy_empty_like; - using sdp::aotriton_adapter::mklazy_fp32zeros; - using sdp::aotriton_adapter::LazyTensorContext; - LazyTensorContext lazy_delta { .like_tensor = softmax_lse_cont, .tensor_name = "delta" }; - LazyTensorContext lazy_dq_acc { .like_tensor = dq_t, .tensor_name = "dq_acc" }; - params.D = mklazy_empty_like<2>(&lazy_delta); - params.DQ_ACC = mklazy_fp32zeros<4>(&lazy_dq_acc); -#else - at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous(); - params.D = mk_aotensor<2>(delta, "delta"); -#endif - err = aotriton::v3::flash::attn_bwd(params, - aotriton::v3::flash::attn_bwd_params::kVersion, - stream); -#endif - } else if (use_fused_bwd) { - using aotriton::v2::flash::attn_bwd_fused; - using sdp::aotriton_adapter::cast_dtype; - aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); - err = attn_bwd_fused(mk_aotensor(q_t, "q"), - mk_aotensor(k_t, "k"), - mk_aotensor(v_t, "v"), - empty_bias, - softmax_scale, - mk_aotensor(out_t, "out"), - mk_aotensor(dout_t, "dout"), - mk_aotensor(dq_t, "dq"), - mk_aotensor(dk_t, "dk"), - mk_aotensor(dv_t, "dv"), - empty_bias, // dbb - mk_aotensor<2>(softmax_lse_cont, "L"), - p_dropout, - mk_aoscalartensor(philox_seed), - mk_aoscalartensor(philox_offset), - 0, - is_causal, - stream); - } else { - at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous(); - using aotriton::v2::flash::attn_bwd; - using sdp::aotriton_adapter::cast_dtype; - aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); - err = attn_bwd(mk_aotensor(q_t, "q"), - mk_aotensor(k_t, "k"), - mk_aotensor(v_t, "v"), - empty_bias, - softmax_scale, - mk_aotensor(out_t, "out"), - mk_aotensor(dout_t, "dout"), - mk_aotensor(dq_t, "dq"), - mk_aotensor(dk_t, "dk"), - 
mk_aotensor(dv_t, "dv"), - empty_bias, // db - mk_aotensor<2>(softmax_lse_cont, "L"), - mk_aotensor<2>(delta, "delta"), - p_dropout, - mk_aoscalartensor(philox_seed), - mk_aoscalartensor(philox_offset), - 0, - is_causal, - stream); - } + // Fused BWD does not support SWA + using aotriton::v3::flash::CausalType; + using aotriton::v3::flash::VarlenType; + aotriton::v3::flash::attn_bwd_params params; + params.Q = mk_aotensor(q_t, "q"); + params.K = mk_aotensor(k_t, "k"); + params.V = mk_aotensor(v_t, "v"); + params.Sm_scale = softmax_scale; + params.Out = mk_aotensor(out_t, "out"); + params.DO = mk_aotensor(dout_t, "dout"); + params.DQ = mk_aotensor(dq_t, "dq"); + params.DK = mk_aotensor(dk_t, "dk"); + params.DV = mk_aotensor(dv_t, "dv"); + params.L = mk_aotensor<2>(softmax_lse_cont, "L"); + params.Max_seqlen_q = seqlen_q; // Unused if cu_seqlens_q is empty + params.Max_seqlen_k = seqlen_k; // Unused if cu_seqlens_k is empty + params.dropout_p = p_dropout; + params.philox_seed_ptr = mk_aoscalartensor(philox_seed); + params.philox_offset1 = mk_aoscalartensor(philox_offset); + params.philox_offset2 = 0; + // SWA in AOTriton Kernels is treated as "Generalized Causal masks" + params.causal_type = is_causal || uses_swa ? 
CausalType::WindowedAttention : CausalType::None; + params.window_left = window_left; + params.window_right = window_right; + params.varlen_type = VarlenType::None; + using sdp::aotriton_adapter::mklazy_empty_like; + using sdp::aotriton_adapter::mklazy_fp32zeros; + using sdp::aotriton_adapter::LazyTensorContext; + LazyTensorContext lazy_delta { .like_tensor = softmax_lse_cont, .tensor_name = "delta" }; + LazyTensorContext lazy_dq_acc { .like_tensor = dq_t, .tensor_name = "dq_acc" }; + params.D = mklazy_empty_like<2>(&lazy_delta); + params.DQ_ACC = mklazy_fp32zeros<4>(&lazy_dq_acc); + aotriton::v3::flash::attn_options options; + options.deterministic = deterministic; + err = aotriton::v3::flash::attn_bwd(params, + aotriton::v3::flash::attn_bwd_params::kVersion, + stream, + &options); return { dq, dk, dv, softmax_d }; } @@ -842,7 +767,11 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size CHECK_SHAPE(cu_seqlens_q, batch_size + 1); CHECK_SHAPE(cu_seqlens_k, batch_size + 1); +#if AOTRITON_COMPACT_VARLEN_LSE + at::Tensor softmax_lse_cont = softmax_lse.view({num_heads, total_q}).contiguous(); +#else at::Tensor softmax_lse_cont = softmax_lse.view({batch_size * num_heads, max_seqlen_q}).contiguous(); +#endif at::Tensor q_padded, k_padded, v_padded; q_padded = q.unsqueeze(0).transpose(1, 2); @@ -923,79 +852,45 @@ mha_varlen_bwd_aot(const at::Tensor &dout, // total_q x num_heads, x head_size hipError_t err; // TODO: Error handling using sdp::aotriton_adapter::mk_aotensor; using sdp::aotriton_adapter::mk_aoscalartensor; - if (uses_swa || AOTRITON_ALWAYS_V3_API) { -#if AOTRITON_V3_API - using aotriton::v3::flash::CausalType; - using aotriton::v3::flash::VarlenType; - aotriton::v3::flash::attn_bwd_params params; - params.Q = mk_aotensor(q_padded, "q"); - params.K = mk_aotensor(k_padded, "k"); - params.V = mk_aotensor(v_padded, "v"); - params.Sm_scale = softmax_scale; - params.Out = mk_aotensor(out_t, "out"); - params.DO = mk_aotensor(dout_t, 
"dout"); - params.DK = mk_aotensor(dk_padded, "dk"); - params.DV = mk_aotensor(dv_padded, "dv"); - params.DQ = mk_aotensor(dq_padded, "dq"); - params.L = mk_aotensor<2>(softmax_lse_cont, "L"); - params.cu_seqlens_q = mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"); - params.cu_seqlens_k = mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"); - params.Max_seqlen_q = max_seqlen_q; // Unused if cu_seqlens_q is empty - params.Max_seqlen_k = max_seqlen_k; // Unused if cu_seqlens_k is empty - params.dropout_p = p_dropout; - params.philox_seed_ptr = mk_aoscalartensor(philox_seed); - params.philox_offset1 = mk_aoscalartensor(philox_offset); - params.philox_offset2 = 0; - // SWA in AOTriton Kernels is treated as "Generalized Causal masks" - params.causal_type = is_causal || uses_swa ? CausalType::WindowedAttention : CausalType::None; - params.varlen_type = VarlenType::CompactVarlen; - params.window_left = window_left; - params.window_right = window_right; -#if AOTRITON_ALWAYS_V3_API - using sdp::aotriton_adapter::mklazy_empty_like; - using sdp::aotriton_adapter::mklazy_fp32zeros; - using sdp::aotriton_adapter::LazyTensorContext; - LazyTensorContext lazy_delta { .like_tensor = softmax_lse_cont, .tensor_name = "delta" }; - LazyTensorContext lazy_dq_acc { .like_tensor = dq_padded, .tensor_name = "dq_acc" }; - params.D = mklazy_empty_like<2>(&lazy_delta); - params.DQ_ACC = mklazy_fp32zeros<4>(&lazy_dq_acc); -#else - at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous(); - params.D = mk_aotensor<2>(delta, "delta"); -#endif - err = aotriton::v3::flash::attn_bwd(params, - aotriton::v3::flash::attn_bwd_params::kVersion, - stream); -#endif // AOTRITON_ALWAYS_V3_API - } else { - using aotriton::v2::flash::attn_bwd_compact_varlen; - using sdp::aotriton_adapter::cast_dtype; - at::Tensor delta = at::empty_like(softmax_lse_cont).contiguous(); - aotriton::TensorView<4> empty_bias(0, {0,0,0,0}, {0,0,0,0}, cast_dtype(q.dtype())); - err = attn_bwd_compact_varlen(mk_aotensor(q_padded, "q"), - 
mk_aotensor(k_padded, "k"), - mk_aotensor(v_padded, "v"), - mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"), - mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"), - max_seqlen_q, - max_seqlen_k, - empty_bias, - softmax_scale, - mk_aotensor(out_t, "out"), - mk_aotensor(dout_t, "dout"), - mk_aotensor(dq_padded, "dq"), - mk_aotensor(dk_padded, "dk"), - mk_aotensor(dv_padded, "dv"), - empty_bias, - mk_aotensor<2>(softmax_lse_cont, "L"), - mk_aotensor<2>(delta, "delta"), - p_dropout, - mk_aoscalartensor(philox_seed), - mk_aoscalartensor(philox_offset), - 0, - is_causal, - stream); - } + using aotriton::v3::flash::CausalType; + using aotriton::v3::flash::VarlenType; + aotriton::v3::flash::attn_bwd_params params; + params.Q = mk_aotensor(q_padded, "q"); + params.K = mk_aotensor(k_padded, "k"); + params.V = mk_aotensor(v_padded, "v"); + params.Sm_scale = softmax_scale; + params.Out = mk_aotensor(out_t, "out"); + params.DO = mk_aotensor(dout_t, "dout"); + params.DK = mk_aotensor(dk_padded, "dk"); + params.DV = mk_aotensor(dv_padded, "dv"); + params.DQ = mk_aotensor(dq_padded, "dq"); + params.L = mk_aotensor<2>(softmax_lse_cont, "L"); + params.cu_seqlens_q = mk_aotensor<1>(cu_seqlens_q, "cu_seqlens_q"); + params.cu_seqlens_k = mk_aotensor<1>(cu_seqlens_k, "cu_seqlens_k"); + params.Max_seqlen_q = max_seqlen_q; // Unused if cu_seqlens_q is empty + params.Max_seqlen_k = max_seqlen_k; // Unused if cu_seqlens_k is empty + params.dropout_p = p_dropout; + params.philox_seed_ptr = mk_aoscalartensor(philox_seed); + params.philox_offset1 = mk_aoscalartensor(philox_offset); + params.philox_offset2 = 0; + // SWA in AOTriton Kernels is treated as "Generalized Causal masks" + params.causal_type = is_causal || uses_swa ? 
CausalType::WindowedAttention : CausalType::None; + params.varlen_type = VarlenType::CompactVarlen; + params.window_left = window_left; + params.window_right = window_right; + using sdp::aotriton_adapter::mklazy_empty_like; + using sdp::aotriton_adapter::mklazy_fp32zeros; + using sdp::aotriton_adapter::LazyTensorContext; + LazyTensorContext lazy_delta { .like_tensor = softmax_lse_cont, .tensor_name = "delta" }; + LazyTensorContext lazy_dq_acc { .like_tensor = dq_padded, .tensor_name = "dq_acc" }; + params.D = mklazy_empty_like<2>(&lazy_delta); + params.DQ_ACC = mklazy_fp32zeros<4>(&lazy_dq_acc); + aotriton::v3::flash::attn_options options; + options.deterministic = deterministic; + err = aotriton::v3::flash::attn_bwd(params, + aotriton::v3::flash::attn_bwd_params::kVersion, + stream, + &options); } else { // If seqlen_q == 0, then we have an empty tensor. We need to set the output to 0. dq.zero_(); diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt index 0a419e46a5bce..ff519ce57b663 100644 --- a/caffe2/CMakeLists.txt +++ b/caffe2/CMakeLists.txt @@ -1461,6 +1461,9 @@ if(USE_ROCM) if(USE_ROCM_CK_SDPA) target_compile_definitions(torch_hip PRIVATE USE_ROCM_CK_SDPA) endif() + if(USE_ROCM_CK_GEMM) + target_compile_definitions(torch_hip PRIVATE USE_ROCM_CK_GEMM) + endif() endif() if(BUILD_LITE_INTERPRETER) @@ -1792,6 +1795,10 @@ if(USE_ROCM) target_link_libraries(torch_hip PRIVATE ck_sdpa) endif() + if(USE_ROCM_CK_GEMM) + target_link_libraries(torch_hip PRIVATE ck_gemm) + endif() + if(USE_MSLK) if(USE_ROCM) target_link_libraries(torch_hip PRIVATE mslk) diff --git a/cmake/Dependencies.cmake b/cmake/Dependencies.cmake index 203cdc7c029db..815ddc2d814e6 100644 --- a/cmake/Dependencies.cmake +++ b/cmake/Dependencies.cmake @@ -1019,9 +1019,6 @@ if(USE_ROCM) if(HIPBLASLT_VEC_EXT) list(APPEND HIP_CXX_FLAGS -DHIPBLASLT_VEC_EXT) endif() - if(USE_ROCM_CK_GEMM) - list(APPEND HIP_CXX_FLAGS -DUSE_ROCM_CK_GEMM) - endif() list(APPEND HIP_HIPCC_FLAGS --offload-compress) 
list(APPEND HIP_HIPCC_FLAGS -std=c++17) # Pass device library path for theRock nightly builds diff --git a/cmake/External/aotriton.cmake b/cmake/External/aotriton.cmake index 93f852f014887..91f9daf18565b 100644 --- a/cmake/External/aotriton.cmake +++ b/cmake/External/aotriton.cmake @@ -9,47 +9,55 @@ if(NOT __AOTRITON_INCLUDED) # Replaces .ci/docker/aotriton_version.txt # Note packages information may have versions skipped (due to no ABI breaks) # But they must be listed from lower version to higher version - set(__AOTRITON_VER "0.11.2b") + set(__AOTRITON_VER "0.12.50tp2") set(__AOTRITON_MANYLINUX_LIST - "manylinux_2_28" # rocm6.2 - "manylinux_2_28" # rocm6.3 "manylinux_2_28" # rocm6.4 "manylinux_2_28" # rocm7.0 "manylinux_2_28" # rocm7.1 "manylinux_2_28" # rocm7.2 + "manylinux_2_28" # rocm7.14 ) set(__AOTRITON_ROCM_LIST - "rocm6.2" - "rocm6.3" "rocm6.4" "rocm7.0" "rocm7.1" "rocm7.2" + "rocm7.14" ) - set(__AOTRITON_CI_COMMIT "dd1b68b604b5258ee7a9f7b66ad95e7a82c18065") + if(DEFINED ENV{PYTORCH_AOTRITON_COMMIT}) + set(__AOTRITON_CI_COMMIT "$ENV{PYTORCH_AOTRITON_COMMIT}") + else() + set(__AOTRITON_CI_COMMIT "05f26c2d8fbdcf8dd6d95b9544d04f7a39fc9920") + endif() set(__AOTRITON_SHA256_LIST - "d784314849ba1911181dfc80cd845064ff6f0cdad10e2f4c53eb84a8b89245b9" # rocm6.2 - "f4b14dc111c334e967b28a1cf9ed4c63264c634dbdccbb5849aa9490022992f7" # rocm6.3 - "6b51d8479c85b902334e4f5518f404a8f5d563fd8d4732cb8b621ed4b45c2876" # rocm6.4 - "5501a0a3b300890001b6625f2a3539a7bad60f386f0a061ebe7d4ed5ca0fafb9" # rocm7.0 - "fee36beb3ea484ce18155bbafe026c577fd6705e4469e59405b260bd74b8cc10" # rocm7.1 - "cd8abf27bbb63cec45c94135e9b28745966074263a6b0555e5878ae1cb6a2349" # rocm7.2 + "70020f5938d84e2dbb7cd98a4ca9cb46c4c0a61b683611c61dac20e99c1ee377" # rocm6.4 + "2a334e29f316e5a40766dbcf0b3075f7c73cce57f5044b39b7c3a7b33a7826bd" # rocm7.0 + "cc5894fed1dd44464a6187cafb82e61d9e4034b3035257c0fdfa93d4f0e906cc" # rocm7.1 + "af4ca6c5889c7ba302659115051ea1bfe6a09fb89f311dae214a6f942e93f605" # rocm7.2 + 
"21895066db0c0e2079f5483d07fbfde659f8de906c589bbfbc774bb9b4565e11" # rocm7.14 ) set(__AOTRITON_IMAGE_LIST "amd-gfx90a" "amd-gfx942" "amd-gfx950" - "amd-gfx11xx" + "amd-gfx110x" + "amd-gfx115x" "amd-gfx120x" + "amd-gfx1250" ) set(__AOTRITON_IMAGE_SHA256_LIST - "fe9f04b66bf52ac27cd025e1d89cfd04974dd3fb3ae076192f783641a4d80fdf" # amd-gfx90a - "0a7bcee19d3bb6d548732248c3234f7b92736c2ab7a7aae65294b87a7fd64c06" # amd-gfx942 - "c1ba3bfe84217fd67df3dd1f8b67c80a7f7b33d0ad4d74b41d6567036e032ace" # amd-gfx950 - "839299637fccb13fbe3e7823d57d1b2dcd0e0bed78abbcb7005ea5f4fd82b928" # amd-gfx11xx - "0a4ff324bffdac0c2fde87a8a7f70563d3c84a80ad4e8f31345f2b40a1384e95" # amd-gfx120x + "bb8bf2237b77fc503bc2967ea0d99d6ca419126c479e951ea42b712737128086" # amd-gfx90a + "f08edacf83c9ccf1c4bdcb51f1cab052d1680abea31c9e035f3f9fadb2f13ba4" # amd-gfx942 + "307a37d729cda3a2120449909e5192cd71c2badccbd37f0222786098e69c7a91" # amd-gfx950 + "c9cac7cf6f277168e1659ac2f04706f8823580b7c7e3e895f5a5503ed6bdd55f" # amd-gfx110x + "3177387a15c678b30057f4584d1fc1b8f8db56163890cb5c98f27450209f5a7b" # amd-gfx115x + "68572511ce6487a83f9014bd255bd69c8943f87d0c93bd57b2daac5fbc6c79c1" # amd-gfx120x + "c6ed084f1dce1c963c17055e38bcbc41d8e0fc48390d8fff6e11782454b26dbc" # amd-gfx1250 ) - set(__AOTRITON_BASE_URL "https://github.com/ROCm/aotriton/releases/download/") # @lint-ignore + set(__AOTRITON_BASE_URL "$ENV{PYTORCH_AOTRITON_BASE_URL}") + if(NOT __AOTRITON_BASE_URL) + set(__AOTRITON_BASE_URL "https://github.com/ROCm/aotriton/releases/download/") # @lint-ignore + endif() set(__AOTRITON_Z "gz") # Set the default __AOTRITON_LIB path if(NOT WIN32) diff --git a/test/inductor/test_flex_attention.py b/test/inductor/test_flex_attention.py index 972e74c34dfa6..88d0f30f733bd 100644 --- a/test/inductor/test_flex_attention.py +++ b/test/inductor/test_flex_attention.py @@ -447,6 +447,20 @@ def batch_reserve(paged_attention: PagedAttention, target_seq_len: Tensor): ) +class TestFlexAttentionTDMOptions(InductorTestCase): + def 
test_apply_tdm_num_stages_uses_triton_launch_option(self): + from torch._inductor.kernel.flex.common import apply_tdm_num_stages + + kernel_options = {"num_stages": 1, "NUM_STAGES": 4} + + apply_tdm_num_stages(kernel_options) + + self.assertEqual( + kernel_options["num_stages"], config.tdm.max_outstanding_per_wave + ) + self.assertNotIn("NUM_STAGES", kernel_options) + + @large_tensor_test_class("2GB", device=test_device[0]) class TestFlexAttention(InductorTestCase): def setUp(self): diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index 5d3855aa73e93..2eccaf1d592d5 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -60,6 +60,7 @@ CUDAPersistentTMATemplateConfigHeuristic, GemmConfig, get_shared_memory_checker_opts, + ROCmMMTemplateConfigHeuristic, XPUMMTemplateConfigHeuristic, XPUPersistentTMATemplateConfigHeuristic, ) @@ -327,6 +328,122 @@ def mock_get_tma_workspace_arg(*args, **kwargs): mm_tma_heuristic.mm_configs = original_tma_configs mm_heuristic.mm_configs = original_mm_configs + def test_tdm_arch_gate_accepts_only_gfx1250(self): + from torch._inductor.utils import is_gfx1250_arch + + self.assertTrue(is_gfx1250_arch("gfx1250")) + self.assertTrue(is_gfx1250_arch("gfx1250:sramecc+:xnack-")) + self.assertFalse(is_gfx1250_arch("gfx1251")) + self.assertFalse(is_gfx1250_arch("amd-gfx1250")) + + def test_tdm_persistent_template_precedes_rocm_tma_fallback(self): + from torch._inductor.kernel import mm as mm_kernel + + for persistent_tma_enabled in (False, True): + templates = [] + with ( + config.patch( + { + "triton.enable_persistent_tma_matmul": persistent_tma_enabled + } + ), + mock.patch.object( + mm_kernel, + "use_triton_blackwell_tma_template", + return_value=False, + ), + mock.patch.object( + mm_kernel, + "use_triton_tdm_template", + return_value=True, + ), + mock.patch.object( + mm_kernel, + "use_triton_tma_template", + side_effect=AssertionError("TMA fallback should not run"), + ), + 
): + selected = mm_kernel._append_persistent_mm_template( + templates, mock.Mock(), mock.Mock(), mock.Mock() + ) + + self.assertEqual(selected, "tdm") + self.assertEqual(templates, [mm_kernel.persistent_mm_template]) + + def test_tdm_template_add_guards_checks_compile_time_device(self): + from torch._inductor.utils import use_triton_tdm_template + + class FakeMatrix: + def __init__(self, device): + self.device = device + + def get_device(self): + return self.device + + def get_dtype(self): + return torch.float16 + + def get_stride(self): + return (128, 1) + + mat = FakeMatrix(torch.device("cuda", 0)) + with ( + config.patch({"enable_tdm_configs": True}), + mock.patch.object(torch.version, "hip", "7.2.0"), + mock.patch.object( + torch.cuda, + "get_device_properties", + return_value=mock.Mock(gcnArchName="gfx1250"), + ), + ): + self.assertTrue( + use_triton_tdm_template( + mat, + output_layout=mock.Mock(device=torch.device("cuda", 0)), + add_guards=True, + ) + ) + self.assertFalse( + use_triton_tdm_template( + mat, + output_layout=mock.Mock(device=torch.device("cuda", 1)), + add_guards=True, + ) + ) + + def test_tdm_block_k_filter_is_dtype_size_aware(self): + from torch._inductor.template_heuristics.triton import ( + _filter_tdm_block_k_configs, + ROCmGemmConfig, + ) + + configs = [ + ROCmGemmConfig(128, 64, 64, 4, 4, group_m=8), + ROCmGemmConfig(128, 64, 128, 4, 4, group_m=8), + ] + + self.assertEqual( + [c.block_k for c in _filter_tdm_block_k_configs(configs, 2)], + [64, 128], + ) + self.assertEqual( + [c.block_k for c in _filter_tdm_block_k_configs(configs, 1)], + [128], + ) + self.assertEqual( + [c.block_k for c in _filter_tdm_block_k_configs(configs, 4)], + [64, 128], + ) + + def test_shared_memory_estimation_counts_num_stages_once(self): + heuristic = ROCmMMTemplateConfigHeuristic() + gemm_config = GemmConfig(128, 64, 64, 4, 4, group_m=8) + + self.assertEqual( + heuristic.get_shared_memory_estimation(gemm_config, 2, False, 0), + (128 * 64 + 64 * 64) * 2 * 4 + 
128, + ) + @unittest.skipIf( not has_triton_tma_device(), "Need device-side TMA support in Triton" ) diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py index 0d1d2427855f8..4c67c431319a1 100644 --- a/test/inductor/test_torchinductor_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_dynamic_shapes.py @@ -655,7 +655,7 @@ def test_embedding_backward_dynamic_shapes_large_grid(self, device): max_grid_x = 2147483647 if torch.version.hip: - warp_size = 64 # TODO: query warp size once #129663 is merged + warp_size = torch.cuda.get_device_properties(device).warp_size # ROCm limits total threads (num_blocks * num_warps * warp_size) self.assertLessEqual( result_num_blocks * num_warps * warp_size, diff --git a/test/inductor/test_triton_kernels.py b/test/inductor/test_triton_kernels.py index e264c56add2ae..aa042f3b798fe 100644 --- a/test/inductor/test_triton_kernels.py +++ b/test/inductor/test_triton_kernels.py @@ -402,8 +402,8 @@ def can_fuse(self, scheduler, node1, node2, shared_data_score) -> bool: output_code = log_stream.getvalue() - FileCheck().check("del buf3").check( - "dual_output_kernel_with_inline_asm_0.run(x_1, buf0," + FileCheck().check_regex(r"del buf[0-9]*").check_regex( + r"dual_output_kernel_with_inline_asm_0\.run\(x_1, buf[0-9]*, buf[0-9]*," ).run(output_code) eager_result = f(t.clone())[0] @@ -1208,6 +1208,82 @@ def f(inp): compiled_out = torch.compile(f)(inp) self.assertEqual(compiled_out, eager_out) + @requires_gpu + def test_triton_kernel_mutates_strided_intermediate(self): + @triton.jit + def add_one_strided_kernel( + in_ptr, + out_ptr, + n_cols: tl.constexpr, + stride_b: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + row = tl.program_id(0) + offsets = tl.arange(0, BLOCK_N) + mask = offsets < n_cols + values = tl.load(in_ptr + row * stride_b + offsets, mask=mask, other=0.0) + tl.store(out_ptr + row * stride_b + offsets, values + 1.0, mask=mask) + + def triton_update(x): + out = x + 
add_one_strided_kernel[(x.size(0),)]( + x, + out, + x.size(1), + x.stride(0), + BLOCK_N=4096, + ) + return out + + def f(base): + x = (base + 1)[:, :4096] + return triton_update(x) + + base = torch.randn(64, 8192, device=GPU_TYPE) + eager_out = f(base.clone()) + compiled_out = torch.compile(f, fullgraph=True)(base.clone()) + + self.assertEqual(compiled_out, eager_out) + self.assertEqual(compiled_out.stride(), eager_out.stride()) + + @requires_gpu + def test_triton_kernel_mutates_expanded_intermediate_errors(self): + @triton.jit + def add_one_expanded_kernel( + in_ptr, + out_ptr, + n_elements: tl.constexpr, + stride_n: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + offsets = tl.arange(0, BLOCK_N) + mask = offsets < n_elements + value = tl.load(in_ptr) + tl.store(out_ptr + offsets * stride_n, value + 1.0, mask=mask) + + def triton_update(x): + out = x + add_one_expanded_kernel[(1,)]( + x, + out, + x.numel(), + x.stride(0), + BLOCK_N=16, + ) + return out + + def f(base): + x = (base + 1).expand(4) + return triton_update(x) + + base = torch.randn(1, device=GPU_TYPE) + with self.assertRaisesRegex( + RuntimeError, + "Cannot safely clone an internally overlapping mutated Triton kernel " + "argument.", + ): + torch.compile(f, fullgraph=True)(base.clone()) + @inductor_config.patch( triton_kernel_default_layout_constraint="needs_fixed_stride_order" ) @@ -2769,6 +2845,43 @@ def fn(a, b): ) from e raise + @requires_gpu + def test_constexpr_handling(self): + @triton.jit + def copy_kernel( + src_ptr, + dst_ptr, + n_elements, + stride, + BLOCK_SIZE: tl.constexpr, + ): + pid = tl.program_id(0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < n_elements + x = tl.load(src_ptr + offs * stride, mask=mask) + tl.store(dst_ptr + offs * stride, x, mask=mask) + + t = torch.randn(1024, device=GPU_TYPE) + out = torch.empty(1024, device=GPU_TYPE) + + kwargs = { + "src_ptr": t, + "dst_ptr": out, + "n_elements": 1024, + "stride": 1, + "BLOCK_SIZE": 256, + } + + ttir_module, _ 
= generate_ttir(copy_kernel, kwargs, tma_descriptor_metadata={}) + ttir_str = str(ttir_module) + + # `constexpr` values get inlined, and do not appear as function parameters. + self.assertIn("src_ptr", ttir_str) + self.assertIn("dst_ptr", ttir_str) + self.assertIn("n_elements", ttir_str) + self.assertIn("stride", ttir_str) + self.assertNotIn("BLOCK_SIZE", ttir_str) + def make_mutation_test(fn): @requires_gpu @@ -3152,7 +3265,7 @@ def add_4_times_kernel( in_ptr0, in_ptr1, out_ptr, - n_elements, + n_elements: "tl.constexpr", BLOCK_SIZE: "tl.constexpr", ): pid = tl.program_id(axis=0) @@ -3222,7 +3335,7 @@ def add_4_times_kernel( in_ptr0, in_ptr1, out_ptr, - n_elements, + n_elements: "tl.constexpr", BLOCK_SIZE: "tl.constexpr", ): pid = tl.program_id(axis=0) @@ -3258,7 +3371,7 @@ def add_4_times_kernel( in_ptr0, in_ptr1, out_ptr, - n_elements, + n_elements: "tl.constexpr", BLOCK_SIZE: "tl.constexpr", ): pid = tl.program_id(axis=0) diff --git a/test/test_cuda.py b/test/test_cuda.py index 5dd2a7346c79b..bbcd369aa2ca0 100644 --- a/test/test_cuda.py +++ b/test/test_cuda.py @@ -684,7 +684,7 @@ def _check_default(): gcn_arch = str( torch.cuda.get_device_properties(0).gcnArchName.split(":", 1)[0] ) - if gcn_arch in ["gfx90a", "gfx942", "gfx950", "gfx1200", "gfx1201"]: + if gcn_arch in ["gfx90a", "gfx942", "gfx950", "gfx1200", "gfx1201", "gfx1250"]: self.assertTrue(default == torch._C._BlasBackend.Cublaslt) else: self.assertTrue(default == torch._C._BlasBackend.Cublas) @@ -754,7 +754,7 @@ def test_cublas_workspace_explicit_allocation(self): gcn_arch = str( torch.cuda.get_device_properties(0).gcnArchName.split(":", 1)[0] ) - if "gfx94" in gcn_arch or "gfx95" in gcn_arch: + if "gfx94" in gcn_arch or "gfx95" in gcn_arch or "gfx1250" in gcn_arch: default_workspace_size = 1024 * 128 * 1024 # :1024:128 else: default_workspace_size = ( @@ -8011,9 +8011,13 @@ def test_compile_kernel_large_shared_memory(self): # Test error handling with more than supported shared memory size if 
torch.version.hip: - max_smem = ( - 65536 if get_device_properties().gcnArchName != "gfx950" else 160 * 1024 - ) + arch_name = get_device_properties().gcnArchName + if "gfx1250" in arch_name: + max_smem = 320 * 1024 + elif "gfx950" in arch_name: + max_smem = 160 * 1024 + else: + max_smem = 65536 else: max_smem = get_device_properties().shared_memory_per_block_optin excessive_shared_mem = max_smem * 2 diff --git a/test/test_linalg.py b/test/test_linalg.py index 25a157343db15..fc2e91c9db3cd 100644 --- a/test/test_linalg.py +++ b/test/test_linalg.py @@ -65,6 +65,10 @@ def blaslt_supported_device(): archs.extend(['gfx110', 'gfx120']) if ROCM_VERSION >= (6, 5): archs.append('gfx95') + # We extend this in a way not assuming the exact ROCm version + # where gfx1250 lands, so we don't pin a ROCm_VERSION gate here + # because the exact landing version may shift. + archs.append('gfx1250') for arch in archs: if arch in torch.cuda.get_device_properties(0).gcnArchName: return True diff --git a/test/test_tensor_creation_ops.py b/test/test_tensor_creation_ops.py index 287ee3cb7e421..87830f8b0750f 100644 --- a/test/test_tensor_creation_ops.py +++ b/test/test_tensor_creation_ops.py @@ -2789,6 +2789,24 @@ def test_range_factories_64bit_indexing(self, device): self.assertEqual(t[-1].item(), 2) del t + # On ROCm, launches with gridDim.x * blockDim.x >= 2^32 are not + # supported and either return hipErrorInvalidConfiguration or fail + # silently. Exercise the just-over case (~16 GB at int32) and a + # far-above case (~33 GB) when memory permits. arange computes + # int64 values then casts down, so for int32 the trailing + # 2^32 - 2, 2^32 - 1, 2^32 wrap to -2, -1, 0. 
+ if TEST_WITH_ROCM: + for bigint in (2 ** 32 + 1, 2 ** 33 + 1): + free, _ = torch.cuda.mem_get_info(device) + if free < bigint * 4 + (3 << 30): + continue + t = torch.arange(bigint, dtype=torch.int32, device=device) + self.assertEqual(t.numel(), bigint) + self.assertEqual( + t[-3:].cpu(), torch.tensor([-2, -1, 0], dtype=torch.int32) + ) + del t + @expectedFailureMeta # RuntimeError: The tensor has a non-zero number of elements @onlyNativeDeviceTypes def test_tensor_ctor_device_inference(self, device): diff --git a/test/test_transformers.py b/test/test_transformers.py index 284eb5ad64704..fde79dcaead95 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -3131,8 +3131,6 @@ def test_scaled_dot_product_attention_fused_kernels_packed(self, device, type: s @parametrize("type", ["nested"]) @parametrize("is_contiguous", [True, False]) def test_scaled_dot_product_attention_cudnn_nested(self, device, type: str, is_contiguous: bool): - if TEST_WITH_ROCM and type == 'nested': - self.skipTest("ROCM does not support efficient attention on nested tensors, for now") make_tensor = partial(rand_sdpa_tensor, type=type, device=device, dtype=torch.float16, packed=True) batch_size, seq_len, num_heads, head_dim = 8, 64, 16, 64 @@ -3430,7 +3428,6 @@ def compiled_func(order): reset_order = torch._C._get_sdp_priority_order() self.assertEqual(default_order, reset_order, "expected SDPA context manager to reset priority order.") - @skipIfRocm # Missing deterministic algo @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_ATTENTION, "Fused SDPA was not built for this system") @parametrize("fused_kernel", PLATFORM_SPECIFIC_SDPA) @parametrize("warn_only", [True, False]) @@ -4147,9 +4144,6 @@ def test_fused_kernels_nested_broadcasting( rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=dtype) batch, num_heads, head_dim = 32, 8, 64 head_dim_v = 32 if is_efficient else head_dim - if TEST_WITH_ROCM and head_dim != head_dim_v: - self.skipTest("head_dim 
!= head_dim_v unsupported on ROCm for now") - return seq_lens_q = (torch.randint(low=1, high=5, size=(1,)).item() if expand_q_batch else torch.randint(low=1, high=32, size=(batch,)).tolist()) @@ -4213,7 +4207,6 @@ def _broadcast(t, batch_broadcasted, num_heads_broadcasted): self.assertEqual(actual.contiguous(), math_ref.contiguous().to(dtype), atol=1.5e-3, rtol=1e-2) - @skipIfRocm(msg="Efficient Attention on ROCM does not support head_dim != head_dim_v for now.") @unittest.skipIf(not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Fused SDPA was not built for this system") def test_fused_kernels_nested_broadcasting_query_dense(self, device): rand_nested_tensor = partial(rand_sdpa_tensor, type="nested", device=device, dtype=torch.float32) diff --git a/torch/_higher_order_ops/triton_kernel_wrap.py b/torch/_higher_order_ops/triton_kernel_wrap.py index ae1e13bd1faa2..b3480ac818756 100644 --- a/torch/_higher_order_ops/triton_kernel_wrap.py +++ b/torch/_higher_order_ops/triton_kernel_wrap.py @@ -9,18 +9,19 @@ import threading from collections import defaultdict from collections.abc import Callable, Sequence -from typing import Any, Optional, TYPE_CHECKING, Union +from typing import Any, cast, Optional, TYPE_CHECKING, Union from typing_extensions import Never import sympy +import torch import torch.fx as fx import torch.utils._pytree as pytree from torch import SymInt, Tensor from torch._C import DispatchKey from torch._higher_order_ops.utils import redirect_to_mode from torch._ops import HigherOrderOperator -from torch._prims_common import clone_preserve_strides +from torch._prims_common import compute_required_storage_length from torch._subclasses.fake_tensor import FakeTensorMode from torch.fx.experimental.proxy_tensor import ( disable_proxy_modes_tracing, @@ -196,6 +197,23 @@ def reset_table(self) -> None: kernel_side_table = KernelSideTable() +def clone_preserve_strides_for_triton_kernel_wrapper(x: Tensor) -> Tensor: + storage_offset = cast(int, x.storage_offset()) + 
needed_size = compute_required_storage_length(x.size(), x.stride(), storage_offset) + if torch._debug_has_internal_overlap(x) == 1: + raise RuntimeError( + "Cannot safely clone an internally overlapping mutated Triton kernel " + "argument." + ) + buffer = torch.empty_strided((needed_size,), (1,), dtype=x.dtype, device=x.device) + out = torch.as_strided(buffer, x.size(), x.stride(), storage_offset) + # Copy logical elements only. Inductor may use a compact internal + # materialization for strided views, so flattening the source storage span + # can read past the realized buffer. + out.copy_(x) + return out + + ############################################################################### # Mutation Tracker @@ -330,11 +348,6 @@ def is_stable_tensor_descriptor_arg(arg: Any) -> bool: return True return False - def is_tensor_like_arg(arg: Any) -> bool: - if isinstance(arg, Tensor) or is_stable_tensor_descriptor_arg(arg): - return True - return False - # Note: one would expect that each input to the triton kernel maps to # one input parameter in the TTIR. This is _not_ true for TMA descriptors: # one TMA descriptor gets converted into: @@ -343,9 +356,13 @@ def is_tensor_like_arg(arg: Any) -> bool: # * N sizes, for a rank-N tensor # To account for this, we inject some fake arg names as placeholders for # the stride and size parameters. - def get_tensor_names(name: str, arg: Any) -> list[str]: - if isinstance(arg, Tensor): - return [name] + # + # Tensors and scalars both become single TTIR parameters, whereas + # `constexpr` are inlined. This matters for "odd" ordering + # (eg. [tensor, scalar, tensor]). 
+ def get_arg_names(name: str, arg: Any, is_constexpr) -> list[str]: + if is_constexpr or arg is None: + return [] if is_stable_tensor_descriptor_arg(arg): stable_meta = maybe_unpack_tma_stable_metadata( tma_descriptor_metadata[name] @@ -358,11 +375,12 @@ def get_tensor_names(name: str, arg: Any) -> list[str]: names.extend(name + f" STRIDE PLACEHOLDER {i}" for i in range(tensor_rank)) names.extend(name + f" SIZE PLACEHOLDER {i}" for i in range(tensor_rank)) return names - return [] + return [name] - ordered_tensor_names = list( + ordered_arg_names = list( itertools.chain.from_iterable( - get_tensor_names(name, arg) for name, arg in ordered_args.items() + get_arg_names(name, arg, param.is_constexpr) + for (name, arg), param in zip(ordered_args.items(), kernel.params) ) ) @@ -452,8 +470,14 @@ def _native_specialize_impl( return attrs specialization = _get_specialization(ordered_args.values()) + + # Triton explicitly interprets ASTSource.constants entries as constexpr + # (triton-lang/triton#8248). Thus, only arguments marked `is_constexpr` + # should be treated as such, not just non-tensor-like arguments. 
constants = { - name: arg for name, arg in ordered_args.items() if not is_tensor_like_arg(arg) + (i,): arg + for i, ((_, arg), param) in enumerate(zip(ordered_args.items(), kernel.params)) + if param.is_constexpr } if (mangle_type := getattr(triton.runtime.jit, "mangle_type", None)) is not None: @@ -518,7 +542,7 @@ def get_signature_value(idx: int, arg: Any) -> str: if not ttir_module.verify(): raise RuntimeError("Verification for TTIR module has failed") - return ttir_module, ordered_tensor_names + return ttir_module, ordered_arg_names def ttir_to_functions( @@ -987,11 +1011,12 @@ def identify_mutated_tensors( 2) Parses the TTIR and creates a control flow graph 3) Analyzes the graph to detect all input tensor mutations """ + from torch._inductor.ir import TensorBox ttir_module = None functions = None try: - ttir_module, ordered_tensor_names = generate_ttir( + ttir_module, ordered_arg_names = generate_ttir( kernel, kwargs, tma_descriptor_metadata ) @@ -1014,11 +1039,13 @@ def identify_mutated_tensors( analyze_kernel_mutations.reset() get_tma_stores.reset() mutations = analyze_kernel_mutations( - functions, kernel_name, len(ordered_tensor_names) + functions, kernel_name, len(ordered_arg_names) ) return [ - ordered_tensor_names[i] for i, mutated in enumerate(mutations) if mutated + ordered_arg_names[i] + for i, mutated in enumerate(mutations) + if mutated and isinstance(kwargs[ordered_arg_names[i]], (Tensor, TensorBox)) ] except Exception: import torch._inductor.ir @@ -1339,11 +1366,15 @@ def triton_kernel_wrapper_functional_dense( tensors_to_clone: list[str], ) -> dict[str, Any]: # TODO(oulgen): For performance reasons, we want to ensure that these - # `clone_preserve_strides` calls are never executed at runtime + # strided clone calls are never executed at runtime # (inductor should always optimize them away). 
# Requires https://github.com/pytorch/pytorch/issues/109240 kwargs = { - key: (clone_preserve_strides(val) if key in tensors_to_clone else val) + key: ( + clone_preserve_strides_for_triton_kernel_wrapper(val) + if key in tensors_to_clone + else val + ) for key, val in kwargs.items() } triton_kernel_wrapper_mutation( @@ -1368,12 +1399,12 @@ def triton_kernel_wrapper_functional_fake_tensor_mode( tensors_to_clone: list[str], ) -> dict[str, Any]: # TODO(oulgen): For performance reasons, we want to ensure that these - # `clone_preserve_strides` calls are never executed at runtime + # strided clone calls are never executed at runtime # (inductor should always optimize them away). # Requires https://github.com/pytorch/pytorch/issues/109240 with mode: return { - key: clone_preserve_strides(val) + key: clone_preserve_strides_for_triton_kernel_wrapper(val) for key, val in kwargs.items() if key in tensors_to_clone } diff --git a/torch/_inductor/codegen/rocm/compile_command.py b/torch/_inductor/codegen/rocm/compile_command.py index aa935b14af23c..1ff830089cfa4 100644 --- a/torch/_inductor/codegen/rocm/compile_command.py +++ b/torch/_inductor/codegen/rocm/compile_command.py @@ -75,6 +75,10 @@ def _rocm_lib_options(dst_file_ext: str) -> list[str]: def _rocm_compiler_options() -> list[str]: + # `config.rocm.arch` is populated from either: + # - The `PYTORCH_ROCM_ARCH` environment variable, or + # - Runtime device detection. + # The string "native" tells `hipcc` to compile for the current GPU. 
arch_list = config.rocm.arch or ["native"] gpu_arch_flags = [f"--offload-arch={arch}" for arch in arch_list] opts = [ diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index e2fee26f45cc1..d379050857403 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -734,6 +734,18 @@ def use_autoheuristic(name: str) -> bool: force_layout_optimization = os.environ.get("TORCHINDUCTOR_FORCE_LAYOUT_OPT", "0") == "1" +# AMD TDM (gfx1250) config flag. This is a configuration gate +# rather than something that causes TDM instructions to be emitted directly. +# It controls whether gfx1250-specific autotuning configs +# (with larger tile sizes, more pipeline stages) are included +# in the candidate set during `max_autotune`. +# When enabled and running on a gfx1250 device, Inductor may emit Triton kernels +# that leverage TDM for asynchronous global->LDS tensor tile copies. +# Requires Triton >= 3.6.0 with AMD TDM backend support. +# Disabled by default, including on ROCm; TDM is further gated by device arch (gfx1250) +# in `use_triton_tdm_template()` at codegen time. +# Set TORCHINDUCTOR_ENABLE_TDM_CONFIGS=1 to enable; any other value keeps TDM off. +enable_tdm_configs = os.environ.get("TORCHINDUCTOR_ENABLE_TDM_CONFIGS", "0") == "1" # Whether to keep the output strides the same as eager after layout optimization. keep_output_stride = os.environ.get("TORCHINDUCTOR_KEEP_OUTPUT_STRIDE", "1") == "1" @@ -2199,7 +2211,9 @@ class rocm: # If empty, the `native` arch is used arch: list[str] = [] - # Enable the CK backend for CDNA2 and CDNA3 only (for now) + # Enable the CK backend for supported CDNA archs only (for now). + # gfx1250 is Wave32 and must stay off the current CK/XDL paths until + # composable_kernel has a gfx1250-safe WMMA/SWMMAC implementation. 
# Processor name reference: https://llvm.org/docs/AMDGPUUsage.html#processors ck_supported_arch: list[Literal["gfx90a", "gfx942", "gfx950"]] = [ "gfx90a", @@ -2267,6 +2281,17 @@ class rocm: contiguous_threshold: int = 16 +class tdm: + # Maximum outstanding TDM address translations per wave is 4. + # For small tiles (e.g., 128x64 FP16), this is the binding constraint. + # For larger tiles with 2 waves/SIMD, the SIMD limit of 6 applies. + max_outstanding_per_wave: int = 4 + max_outstanding_per_simd: int = 6 + # TDM requires 128B/256B aligned contiguous regions + # in both global memory and LDS. + alignment_bytes: int = 128 + + # Backend to use for CPU codegen either "cpp" or "triton" (experimental) or "halide" (experimental) or "pallas" (experimental) cpu_backend: Literal["cpp", "triton", "halide", "pallas"] = "cpp" diff --git a/torch/_inductor/kernel/flex/common.py b/torch/_inductor/kernel/flex/common.py index bf4006de8399a..764d7cd386691 100644 --- a/torch/_inductor/kernel/flex/common.py +++ b/torch/_inductor/kernel/flex/common.py @@ -43,12 +43,17 @@ to_dtype, ) from ...select_algorithm import realize_inputs -from ...utils import load_template +from ...utils import config, load_template SubgraphResults = Union[list[Optional[ComputedBuffer]], Optional[ComputedBuffer]] +def apply_tdm_num_stages(kernel_options: dict[str, Any]) -> None: + kernel_options["num_stages"] = config.tdm.max_outstanding_per_wave + kernel_options.pop("NUM_STAGES", None) + + def zeros_and_scatter_lowering(shape: list[int], indices, values): """To support backwards on captured buffers we register a specific lowering for our specific custom up""" # Always accumulate into fp32 then cast diff --git a/torch/_inductor/kernel/flex/flex_attention.py b/torch/_inductor/kernel/flex/flex_attention.py index c6017fadb0387..6277d26f43c4a 100644 --- a/torch/_inductor/kernel/flex/flex_attention.py +++ b/torch/_inductor/kernel/flex/flex_attention.py @@ -23,8 +23,9 @@ SymbolicGridFn, TritonTemplate, ) -from 
...utils import can_use_tma +from ...utils import can_use_tma, use_triton_tdm_template from .common import ( + apply_tdm_num_stages, build_subgraph_buffer, create_indices_fake, create_num_blocks_fake_generator, @@ -410,6 +411,15 @@ def flex_attention( if cur_kernel_options["USE_TMA"] and not can_use_tma(query, key, value): cur_kernel_options["USE_TMA"] = False + # For gfx1250 TDM, ensure the non-TMA path uses enough pipeline stages + # to trigger TDM async copies in Triton's AMD backend (TTGIR pipelining pass). + # Standard attention block sizes (64, 128, 256) with FP16/BF16 + # produce 128B aligned tiles and are compatible. + if not cur_kernel_options["USE_TMA"] and use_triton_tdm_template( + query, key, value + ): + apply_tdm_num_stages(cur_kernel_options) + cur_kernel_options.setdefault("BLOCK_M", conf.block_m) cur_kernel_options.setdefault("BLOCK_N", conf.block_n) # Blocksparse options @@ -927,6 +937,12 @@ def flex_attention_backward(*args, **kwargs): if cur_kernel_options["USE_TMA"] and not can_use_tma(query, key, value): cur_kernel_options["USE_TMA"] = False + # See the comments at the corresponding place in function `flex_attention` above. 
+ if not cur_kernel_options["USE_TMA"] and use_triton_tdm_template( + query, key, value + ): + apply_tdm_num_stages(cur_kernel_options) + cur_kernel_options.setdefault("BLOCK_M1", conf.block_m1) cur_kernel_options.setdefault("BLOCK_N1", conf.block_n1) cur_kernel_options.setdefault("BLOCK_M2", conf.block_m2) diff --git a/torch/_inductor/kernel/flex/flex_decoding.py b/torch/_inductor/kernel/flex/flex_decoding.py index 4111915c26082..8a24b71e398f4 100644 --- a/torch/_inductor/kernel/flex/flex_decoding.py +++ b/torch/_inductor/kernel/flex/flex_decoding.py @@ -18,8 +18,9 @@ SymbolicGridFn, TritonTemplate, ) -from ...utils import can_use_tma +from ...utils import can_use_tma, use_triton_tdm_template from .common import ( + apply_tdm_num_stages, create_indices_fake, create_num_blocks_fake_generator, freeze_irnodes, @@ -354,6 +355,15 @@ def create_flex_decoding_kernel(*args, **kwargs): if cur_kernel_options["USE_TMA"] and not can_use_tma(query, key, value): cur_kernel_options["USE_TMA"] = False + # For gfx1250 TDM, ensure the non-TMA path uses enough pipeline stages + # to trigger TDM async copies in Triton's AMD backend (TTGIR pipelining pass). + # Standard attention block sizes (64, 128, 256) with FP16/BF16 + # produce 128B aligned tiles and are compatible. 
+ if not cur_kernel_options["USE_TMA"] and use_triton_tdm_template( + query, key, value + ): + apply_tdm_num_stages(cur_kernel_options) + # Add ROCm-specific parameters if they exist in the config for attrib in ["kpack", "matrix_instr_nonkdim", "waves_per_eu"]: if hasattr(conf, attrib): diff --git a/torch/_inductor/kernel/mm.py b/torch/_inductor/kernel/mm.py index 112f201d52f1d..51d922deec82f 100644 --- a/torch/_inductor/kernel/mm.py +++ b/torch/_inductor/kernel/mm.py @@ -54,6 +54,7 @@ use_triton_scaling_template, use_triton_template, use_triton_tma_template, + use_triton_tdm_template, ) from .mm_common import ( _is_static_problem, @@ -100,6 +101,14 @@ source=load_kernel_template("triton_persistent_tma_mm"), ) +# Non-TMA Triton template for persistent MM +# used on AMD +persistent_mm_template = TritonTemplate( + name="mm_persistent", + grid=persistent_mm_grid, + source=load_kernel_template("triton_persistent_mm"), +) + scaled_mm_device_tma_epilogue_scaling_template = TritonTemplate( name="scaled_mm_device_tma_epilogue_scaling", @@ -121,6 +130,59 @@ ) +def _append_persistent_mm_template( + templates_to_use: list[ExternKernelChoice | KernelTemplate], + mat1: Buffer, + mat2: Buffer, + layout: Layout, +) -> str | None: + if use_triton_blackwell_tma_template( + mat1, mat2, output_layout=layout, add_guards=True + ): + templates_to_use.append(blackwell_ws_persistent_device_tma_mm_template) + return "blackwell_tma" + if use_triton_tdm_template(mat1, mat2, output_layout=layout, add_guards=True): + # GFX1250 TDM: use the non-TMA persistent template. Triton's AMD backend + # automatically inserts TDM instructions when compiling with num_stages > 1. + # + # TDM hardware constraints (gfx1250): + # Documentation-only for now because Triton's pipeliner handles the + # wave assignment internally. If Triton exposes wave-specialization + # knobs in the future, these comments identify where to hook them. 
+ # - 1 TDM unit per SIMD pair (2 SIMDs share 1 TDM unit) + # - Max 4 outstanding TDM address translations per wave + # - Max 6 outstanding TDM address translations per SIMD + # - Recommended: waves specialized to load A or B + # - 8 waves/WG (num_warps=8): 4 waves for A, 4 for B + # - 4 waves/WG (num_warps=4): 2 waves for A, 2 for B + # - All waves can load both A & B. This is preferable for uniform + # unrolling across CUs and helps multicast load latency. + # - Do not interleave different TDMs within a wave. Complete one TDM + # instruction's unrolling before switching. + # + # The persistent template heuristic supplies gfx1250 TDM configs. + # Key constraints: + # - TDM requests target 128B or 256B aligned contiguous regions in both + # global memory and LDS. + # - For FP16: BLOCK_K must be a multiple of 64 (64 * 2B = 128B). + # - For FP8/FP4: BLOCK_K must be a multiple of 128 (128 * 1B = 128B). + # - For FP32: BLOCK_K must be a multiple of 32 (32 * 4B = 128B). + # - BLOCK_M and BLOCK_N should also produce 128B-aligned LDS rows. + # TDM is most beneficial for small tiles such as 128x64/64x128 FP16 + # and 128x128/128x256 FP8 with num_stages=4 (matching the + # 4-outstanding-per-wave TDM address translation limit) and + # num_warps=4 or 8 (for wave-specialized A/B loading). 
+ templates_to_use.append(persistent_mm_template) + return "tdm" + if use_triton_tma_template(mat1, mat2, output_layout=layout, add_guards=True): + if torch.version.hip is None: + templates_to_use.append(persistent_tma_mm_template) + else: + templates_to_use.append(persistent_mm_template) + return "tma" + return None + + # prevent duplication registration of extern functions @functools.cache def lazy_register_extern_choice(fn): @@ -419,10 +481,7 @@ def _to_dtype(x): if is_exhaustive or not use_decompose_k_choice(m, n, k, threshold_multiple=2): templates_to_use.append(mm_template) - if use_triton_blackwell_tma_template(mat1, mat2, output_layout=layout): - templates_to_use.append(blackwell_ws_persistent_device_tma_mm_template) - elif use_triton_tma_template(mat1, mat2, output_layout=layout): - templates_to_use.append(persistent_tma_mm_template) + _append_persistent_mm_template(templates_to_use, mat1, mat2, layout) if ( inductor_config.is_fbcode() @@ -664,10 +723,7 @@ def tuned_addmm(inp, mat1, mat2, *, alpha=1, beta=1, layout=None): if is_nonzero and use_triton_template(layout, check_max_autotune=False): templates_to_use.append(mm_template) - if use_triton_blackwell_tma_template(mat1, mat2, output_layout=layout): - templates_to_use.append(blackwell_ws_persistent_device_tma_mm_template) - elif use_triton_tma_template(mat1, mat2, output_layout=layout): - templates_to_use.append(persistent_tma_mm_template) + _append_persistent_mm_template(templates_to_use, mat1, mat2, layout) templates_to_use.append(addmm_contiguous_subgraph_template) diff --git a/torch/_inductor/kernel/templates/triton_persistent_mm.py.jinja b/torch/_inductor/kernel/templates/triton_persistent_mm.py.jinja new file mode 100644 index 0000000000000..7450533be5adf --- /dev/null +++ b/torch/_inductor/kernel/templates/triton_persistent_mm.py.jinja @@ -0,0 +1,75 @@ +{{def_kernel("A", "B")}} + M = {{size("A", 0)}} + N = {{size("B", 1)}} + K = {{size("A", 1)}} + if M * N == 0: + # early exit due to zero-size 
input(s) + return + stride_am = {{stride("A", 0)}} + stride_ak = {{stride("A", 1)}} + stride_bk = {{stride("B", 0)}} + stride_bn = {{stride("B", 1)}} + + # persistent kernel: each CTA processes multiple tiles + start_pid = tl.program_id(0).to(INDEX_DTYPE) + grid_m = tl.cdiv(M, BLOCK_M) + grid_n = tl.cdiv(N, BLOCK_N) + num_tiles = grid_m * grid_n + width = GROUP_M * grid_n + + for tile_id in tl.range(start_pid, num_tiles, NUM_SMS): + + # re-order program ID for better L2 performance + group_id = tile_id // width + group_size = min(grid_m - group_id * GROUP_M, GROUP_M) + pid_m = group_id * GROUP_M + (tile_id % group_size) + pid_n = (tile_id % width) // (group_size) + tl.assume(pid_m >= 0) + tl.assume(pid_n >= 0) + + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + if ((stride_am == 1 and stride_ak == M) or (stride_am == K and stride_ak == 1)) and (M >= BLOCK_M and K > 1): + offs_a_m = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M) + else: + offs_a_m = rm % M + if ((stride_bk == 1 and stride_bn == K) or (stride_bk == N and stride_bn == 1)) and (N >= BLOCK_N and K > 1): + offs_b_n = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N) + else: + offs_b_n = rn % N + offs_k = tl.arange(0, BLOCK_K) + acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE) + + for k_idx in range(0, tl.cdiv(K, BLOCK_K)): + {% if not EVEN_K %} + a_mask = offs_k[None, :] < (K - k_idx * BLOCK_K) + b_mask = offs_k[:, None] < (K - k_idx * BLOCK_K) + {% endif %} + a_k_idx_vals = offs_k[None, :] + (k_idx * BLOCK_K) + b_k_idx_vals = offs_k[:, None] + (k_idx * BLOCK_K) + + idx_m = offs_a_m[:, None] + idx_n = a_k_idx_vals + {{load_input("A", "a", ("idx_m", "idx_n"), mask=None if EVEN_K else "a_mask", + indent_width=12, index_shape=("BLOCK_M", "BLOCK_K"))}} + + idx_m = b_k_idx_vals + idx_n = offs_b_n[None, :] + {{load_input("B", "b", ("idx_m", "idx_n"), mask=None if EVEN_K else "b_mask", + indent_width=12, index_shape=("BLOCK_K", "BLOCK_N"))}} 
+ + {% if USE_FAST_ACCUM %} + acc = tl.dot(a, b, acc, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% else %} + acc += tl.dot(a, b, allow_tf32=ALLOW_TF32, out_dtype=ACC_TYPE) + {% endif %} + + # rematerialize rm and rn to save registers + rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M) + rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N) + idx_m = rm[:, None] + idx_n = rn[None, :] + mask = (idx_m < M) & (idx_n < N) + + # inductor generates a suffix + {{store_output(("idx_m", "idx_n"), "acc", "mask", indent_width=8, val_shape=("BLOCK_M", "BLOCK_N"))}} diff --git a/torch/_inductor/rocm_multiarch_utils.py b/torch/_inductor/rocm_multiarch_utils.py index a1a6103e10915..1f4c4b1f928dd 100644 --- a/torch/_inductor/rocm_multiarch_utils.py +++ b/torch/_inductor/rocm_multiarch_utils.py @@ -3,7 +3,9 @@ Compile LLVM IR to multi-arch bundles that HIP can load automatically. """ +import logging import os +import re import subprocess from typing import Optional @@ -11,6 +13,9 @@ from torch.utils.cpp_extension import _join_rocm_home, ROCM_HOME +log = logging.getLogger(__name__) + + def get_rocm_compiler() -> str: """ Get path to ROCm's clang compiler. @@ -69,19 +74,21 @@ def get_rocm_bundler() -> str: def get_rocm_target_archs() -> list[str]: - """ - Get target architectures from environment or config. 
- Returns: List of architecture strings (e.g., ['gfx90a', 'gfx942']) - """ - # Check PYTORCH_ROCM_ARCH environment variable env_archs = os.environ.get("PYTORCH_ROCM_ARCH", "").strip() if env_archs: archs = [arch.strip() for arch in env_archs.replace(";", ",").split(",")] archs = [arch for arch in archs if arch] if archs: + # Ensure current device arch is included + if torch.cuda.is_available(): + for dev_idx in range(torch.cuda.device_count()): + current_arch = torch.cuda.get_device_properties( + dev_idx + ).gcnArchName.split(":")[0] + if current_arch not in archs: + archs.append(current_arch) return archs - # Try to get from inductor config try: from torch._inductor import config @@ -96,6 +103,43 @@ def get_rocm_target_archs() -> list[str]: return torch.cuda.get_arch_list() +def _sanitize_llvm_ir_for_rocm(llvm_ir_path: str) -> str: + """ + Sanitize LLVM IR to be compatible with ROCm's clang. + + Triton's LLVM (upstream) may emit attributes and metadata that ROCm's + older clang does not yet support. Only strips attributes confirmed to + cause parse errors — preserves all others to maintain correct codegen. + + Currently strips: + - nocreateundeforpoison: function attribute (upstream LLVM, not in ROCm) + - dwarfAddressSpace: debug metadata field (upstream LLVM, not in ROCm) + + Returns: + Path to sanitized .ll file, or original path if no changes needed. 
+ """ + with open(llvm_ir_path) as f: + content = f.read() + + sanitized = content + sanitized = re.sub(r"\bnocreateundeforpoison\b\s*", "", sanitized) + sanitized = re.sub(r",\s*dwarfAddressSpace:\s*\d+", "", sanitized) + + if sanitized == content: + return llvm_ir_path + + sanitized_path = llvm_ir_path + ".sanitized.ll" + with open(sanitized_path, "w") as f: + f.write(sanitized) + + log.debug( + "Sanitized LLVM IR for ROCm clang compatibility: %s -> %s", + llvm_ir_path, + sanitized_path, + ) + return sanitized_path + + def compile_llvm_ir_to_code_object( llvm_ir_path: str, output_path: str, target_arch: str ) -> bool: @@ -120,6 +164,9 @@ def compile_llvm_ir_to_code_object( except RuntimeError: return False + # Sanitize LLVM IR to remove attributes unsupported by ROCm's clang + llvm_ir_path = _sanitize_llvm_ir_for_rocm(llvm_ir_path) + # Using clang and not hipcc since we are not compiling source code # Instead we use the LLVM IR (.ll) provided by triton cmd = [ diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index 16bc308839b99..d7e98b7d5bf96 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -2349,25 +2349,59 @@ def check_max_block(cfg: dict[str, int]): ) -def _num_warps(num_warps, max_num_warps=8, min_num_warps=2, register_intensive=False): - # On AMD GPU each warp has 64 lanes which is double the size on NV GPU, - # therefore using half the number of warps here correspondingly. +def _device_warp_size_for_heuristics() -> int: + # Wave size is not uniform across AMD GPUs: gfx90a/942/950 are Wave64 while + # gfx1250 are Wave32. Query the device so the warp-count heuristics below + # scale correctly instead of assuming Wave64. + # + # NOTE: This intentionally queries the *current* device (no index argument). 
+ # The triton_config* helpers that consume this build configs at lowering time + # and are not parameterized by a target device, so there is no device index to + # thread through here. Inductor bakes the warp size into the config at + # construction time against the active device; on a hypothetical host mixing + # Wave32 and Wave64 GPUs, configs built under one device should not be reused + # verbatim under another. + if not torch.version.hip: + return _NUM_THREADS_PER_WARP + + try: + warp_size = torch.cuda.get_device_properties().warp_size + except (AssertionError, AttributeError, RuntimeError): + # Device properties are not available during config construction. + # Fall back to 64, which preserves the historical ROCm (Wave64) + # heuristic. Wave32 parts like gfx1250 report a real warp size of 32 + # when properties are queryable, so they take the branch above. + return 64 + + return warp_size or 64 + + +def _num_warps( + num_warps, + max_num_warps=8, + min_num_warps=2, + register_intensive=False, + warp_size=None, +): if torch.version.hip: - max_num_warps = (max_num_warps + 1) // 2 - min_num_warps = (min_num_warps + 1) // 2 + # On Wave64 AMD GPUs each warp has 64 lanes, double the 32-lane NV/Wave32 + # warp, so the warp budget is scaled down by the wave-width ratio. Wave32 + # parts (e.g. gfx1250) have warp_scale == 1 and keep the full warp count. 
+ warp_size = warp_size or _device_warp_size_for_heuristics() + warp_scale = max(warp_size // _NUM_THREADS_PER_WARP, 1) + max_num_warps = (max_num_warps + warp_scale - 1) // warp_scale + min_num_warps = (min_num_warps + warp_scale - 1) // warp_scale # persistent reduction is register intensive if register_intensive: max_num_warps = max_num_warps // 2 return next_power_of_2(min(max(num_warps, min_num_warps), max_num_warps)) -def _check_max_grid_x(size_hints, x, num_warps): +def _check_max_grid_x(size_hints, x, num_warps, warp_size=None): # Check if maxGridSize is exceeded - if so then must scale XBLOCK further max_grid_x = 2147483647 max_block_x = TRITON_MAX_BLOCK["X"] - warp_size = ( - 64 if torch.version.hip else 32 - ) # TODO: query warp size once #129663 is merged + warp_size = warp_size or _device_warp_size_for_heuristics() num_blocks = (size_hints["x"] + x - 1) // x if torch.version.hip: @@ -2459,10 +2493,14 @@ def triton_config( ): z *= 2 + warp_size = _device_warp_size_for_heuristics() + # Calculate num_warps if they are not hard passed to config if num_warps is None: num_warps = _num_warps( - conditional_product(x, y, z) // num_elements_per_warp, min_num_warps=1 + conditional_product(x, y, z) // num_elements_per_warp, + min_num_warps=1, + warp_size=warp_size, ) # we are going to arrive at 2 warps only if bs was too small due to # numel being too small. However to workaround some ptx bugs we still @@ -2477,13 +2515,18 @@ def triton_config( znumel = size_hints.get("z") # Increase x to satisfy min_elem_per_thread requirements. + # NOTE: Keep this expressed in 32-lane units (_NUM_THREADS_PER_WARP), not the + # device warp size. Using the real warp size would double this min block size + # on Wave64 archs (gfx942/gfx950) relative to historical behavior, a silent + # perf change for already-shipping parts. gfx1250 is Wave32, so its real warp + # size is already 32 and it is unaffected by keeping the constant here. 
block_size = max( conditional_product(x, y, z), min_elem_per_thread * _NUM_THREADS_PER_WARP * num_warps, ) x *= math.ceil(block_size / conditional_product(x, y, z)) - x, _num_blocks = _check_max_grid_x(size_hints, x, num_warps) + x, _num_blocks = _check_max_grid_x(size_hints, x, num_warps, warp_size) x = min(x, size_hints["x"]) cfg = {"XBLOCK": x} @@ -2579,6 +2622,8 @@ def total_numel() -> int: while rnumels[prefix] < size_hints[prefix] and total_numel() < target: rnumels[prefix] *= 2 + warp_size = _device_warp_size_for_heuristics() + if num_warps is None: if reduction_hint == ReductionHint.INNER: # r is contiguous, ensure at least 8 elements per thread @@ -2594,10 +2639,13 @@ def total_numel() -> int: _num_warps_func = _num_warps num_warps = _num_warps_func( - num_warps, max_num_warps=max_num_warps, register_intensive=register_intensive + num_warps, + max_num_warps=max_num_warps, + register_intensive=register_intensive, + warp_size=warp_size, ) - x, _num_blocks = _check_max_grid_x(size_hints, x, num_warps) + x, _num_blocks = _check_max_grid_x(size_hints, x, num_warps, warp_size) for prefix in sorted(rnumels): while total_numel() > target: @@ -2757,9 +2805,13 @@ def total_numel() -> int: y *= 2 cfg = _get_config({"x": x, "y": y, **rnumels}) - num_warps = _num_warps(total_numel() // 256, min_num_warps=1) + warp_size = _device_warp_size_for_heuristics() + num_warps = _num_warps(total_numel() // 256, min_num_warps=1, warp_size=warp_size) num_warps = _num_warps( - num_warps, max_num_warps=16, register_intensive=register_intensive + num_warps, + max_num_warps=16, + register_intensive=register_intensive, + warp_size=warp_size, ) check_config(cfg, xnumel=size_hints["x"], ynumel=size_hints["y"]) check_max_block(cfg) diff --git a/torch/_inductor/template_heuristics/triton.py b/torch/_inductor/template_heuristics/triton.py index e15ff07ee5a4f..fb6740e48b37f 100644 --- a/torch/_inductor/template_heuristics/triton.py +++ b/torch/_inductor/template_heuristics/triton.py @@ -22,6 
+22,7 @@ get_scaling_options, get_tile_size, mm_template, + persistent_mm_template, persistent_tma_mm_template, scaled_mm_device_tma_epilogue_scaling_template, scaled_mm_device_tma_main_loop_scaling_template, @@ -32,6 +33,7 @@ get_backend_num_stages, get_num_sms, get_tma_workspace_arg, + is_gfx1250_arch, TMA_DESCRIPTOR_SIZE, using_b200, ) @@ -49,6 +51,42 @@ from torch._inductor.runtime.triton_compat import Config as TritonConfig +def _is_gfx1250_device() -> bool: + """Detect whether the current device is AMD gfx1250. + + gfx1250 has 320 KB LDS per CU, 512 B/clk/CU LDS bandwidth (R/W combined), + and TDM for async descriptor-based global->LDS copies. When num_stages > 1, + Triton's AMD backend automatically converts tl.load into TDM instructions. + + TDM hardware constraints: + - 1 TDM unit per SIMD pair (2 SIMDs share 1 TDM unit) + - Max 4 outstanding TDM address translations per wave + - Max 6 outstanding TDM address translations per SIMD + - Requests target 128B or 256B aligned contiguous regions + in both global memory and LDS + """ + if not torch.version.hip: + return False + if not inductor_config.enable_tdm_configs: + return False + try: + props = torch.cuda.get_device_properties(torch.cuda.current_device()) + arch = getattr(props, "gcnArchName", "") + return is_gfx1250_arch(arch) + except Exception: + return False + + +def _filter_tdm_block_k_configs( + configs: list[BaseConfig], + dtype_size: int, +) -> list[BaseConfig]: + if dtype_size <= 0: + return configs + block_k_multiple = max(inductor_config.tdm.alignment_bytes // dtype_size, 1) + return [c for c in configs if c.block_k % block_k_multiple == 0] + + # Gemm Configs @dataclasses.dataclass class BaseConfig: @@ -1381,6 +1419,7 @@ def __init__(self) -> None: super().__init__() self.default_num_stages = get_backend_num_stages() + self.uses_tdm_configs = False self.mm_configs: list[BaseConfig] = [ ROCmGemmConfig( @@ -1442,6 +1481,42 @@ def __init__(self) -> None: ROCmGemmConfig(256, 256, 64, 
self.default_num_stages, 8, group_m=4), ] + # TDM-optimized persistent MM configs for gfx1250. + # + # Hardware constraints on gfx1250: + # - 320 KB LDS per CU, 512 B/clk/CU bandwidth + # - Max 4 outstanding TDM translations per wave (bounds num_stages) + # - Max 6 outstanding TDM translations per SIMD + # - TDM requests target 128B/256B aligned contiguous regions + # - 1 TDM unit per SIMD pair -> recommended 1 wave/SIMD pair issuing + # TDM instructions (num_warps=4 gives 1 wave/SIMD on a 4-SIMD CU) + # + # BLOCK_K alignment requirements (128B / dtype_bytes): + # FP16/BF16 (2B): BLOCK_K multiple of 64 + # FP8/INT8 (1B): BLOCK_K multiple of 128 + # FP32 (4B): BLOCK_K multiple of 32 + # + # LDS usage estimate: (BM*BK + BK*BN)*dtype_size*num_stages + # Must stay under 320 KB. + self.tdm_persistent_mm_configs: list[BaseConfig] = [ + # Small tiles: TDM most beneficial here (LDS-BW-limited to 1 wave/SIMD). + # These are where TDM eliminates cluster load issue overhead. + # `waves_per_eu=0` means Triton compiler should figure out the best value. + ROCmGemmConfig(128, 64, 64, 4, 4, group_m=8, waves_per_eu=0), + ROCmGemmConfig( 64, 128, 64, 4, 4, group_m=8, waves_per_eu=0), + ROCmGemmConfig(128, 64, 128, 4, 4, group_m=8), + ROCmGemmConfig( 64, 128, 128, 4, 4, group_m=8), + # Medium tiles: benefit from TDM prologue overhead reduction. + ROCmGemmConfig(128, 128, 64, 4, 4, group_m=8), + ROCmGemmConfig(128, 128, 64, 4, 8, group_m=16), + ROCmGemmConfig(128, 128, 128, 4, 4, group_m=8), + ROCmGemmConfig(128, 128, 128, 3, 8, group_m=16), + # Larger tiles: exploit the 320 KB LDS with deep pipelines. 
+ ROCmGemmConfig(256, 128, 64, 4, 8, group_m=16), + ROCmGemmConfig(128, 256, 64, 4, 8, group_m=16), + ROCmGemmConfig(256, 128, 128, 3, 8, group_m=16), + ] + # Exhaustive search for mm configs self.exhaustive_configs: list[BaseConfig] = [ ROCmGemmConfig( @@ -1563,12 +1638,57 @@ def _prune_exhaustive_configs( ] return pruned_configs + def preprocess_mm_configs( + self, + m: int, + n: int, + k: int, + configs: list[BaseConfig], + has_int8_tensor: bool = False, + scale: float = 1.0, + exclude: Callable[ + [sympy.Integer, sympy.Integer, sympy.Integer], bool + ] = lambda m, n, k: False, + dtype_size: int = 0, + op_name: str = "mm", + **kwargs, + ) -> Generator[TritonConfig, None, None]: + if self.uses_tdm_configs: + configs = _filter_tdm_block_k_configs(configs, dtype_size) + return super().preprocess_mm_configs( + m, + n, + k, + configs, + has_int8_tensor, + scale, + exclude, + dtype_size, + op_name, + **kwargs, + ) + def _filter_configs(self, configs: list[BaseConfig]) -> list[BaseConfig]: """ ROCm specific filtering + + Normally we force num_stages to `default_num_stages` because AMD's + Triton backend historically didn't benefit from multi-stage pipelining. + However, on gfx1250 with TDM enabled, num_stages > 1 is exactly + what triggers TDM async copies in the StreamPipeliner pass, so we + must preserve the requested num_stages for TDM configs. """ + tdm_enabled = _is_gfx1250_device() for c in configs: - c.num_stages = self.default_num_stages + # On gfx1250, preserve num_stages as provided (but cap at 4 to + # respect TDM's per-wave outstanding limit). Otherwise, fall + # back to the legacy behavior of forcing default_num_stages. 
+ if tdm_enabled: + c.num_stages = ( + min(c.num_stages, 4) if c.num_stages > 1 else c.num_stages + ) + else: + c.num_stages = self.default_num_stages return super()._filter_configs(configs) def _finalize_mm_configs( @@ -2914,6 +3034,56 @@ def __init__(self) -> None: self.exhaustive_configs = self.mm_plus_mm_configs +# ROCm persistent MM template heuristic (non-TMA, standard pointer loads) + + +@register_template_heuristic( + persistent_mm_template.uid, + "cuda", + register=torch.version.hip is not None, +) +class PersistentMMTemplateConfigHeuristic( + MMTemplateConfigMixin, + ROCmConfigHeuristic, # type: ignore[misc] +): + """Persistent MM template heuristic (no TMA, standard pointer loads)""" + + def __init__(self) -> None: + super().__init__() + if _is_gfx1250_device(): + # On gfx1250, use TDM-optimized configs that exploit: + # - 320 KB LDS (larger tiles possible) + # - TDM async global->LDS copies (num_stages > 1) + # - Wave specialization across SIMD pairs + self.mm_configs = self.tdm_persistent_mm_configs + self.uses_tdm_configs = True + else: + self.mm_configs = self.persistent_mm_configs + + def _get_template_configs_impl( + self, + kernel_inputs: KernelInputs, + op_name: str, + **kwargs, + ) -> Generator[dict[str, Any], None, None]: + for template_kwargs in super()._get_template_configs_impl( + kernel_inputs, op_name, **kwargs + ): + yield {**template_kwargs, "NUM_SMS": get_num_sms()} + + +@register_template_heuristic( + persistent_mm_template.uid, + "cuda", + register=torch.version.hip is not None, + op_name="addmm", +) +class ROCmAddMMPersistentTemplateConfigHeuristic( + AddMMConfigMixin, PersistentMMTemplateConfigHeuristic +): + """Addmm specific mixin for persistent MM on ROCm""" + + # MTIA template-specific classes diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index c9c60aa71eb0e..5184b16001a35 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -1977,6 +1977,118 @@ def use_triton_tma_template( ) +def 
is_gfx1250_arch(arch: str) -> bool: + """Return True only for the gfx1250 target, including feature-suffixed names.""" + return arch.split(":", 1)[0] == "gfx1250" + + +def use_triton_tdm_template(*matrices: IRNode, output_layout=None, add_guards=False): + """Coarsely check whether AMD TDM-optimized persistent template should be used. + Despite the fact that TDM has some restrictions on direct LDS loads ( + when unrolling requests, the requests are to 128B or 256B aligned and + continuous regions in global memory, and writing back to 128B or 256B + aligned and continuous regions in LDS), this gate doesn't check per-element + alignment, instead delegating to Triton compiler's AMD backend. + That said, for GeMM specifically, it checks the leading dimensions produce + 128B-aligned rows, which is the common case where TDM is beneficial. + + Args: + - output_layout: Output layout (used for device detection) + - add_guards: If True, add shape guards to the graph + Returns True when: + - Running on ROCm (torch.version.hip is not None) + - Device is gfx1250 + - config.enable_tdm_configs (defined elsewhere) is True + - Triton version supports gfx1250 TDM (>= 3.6.0) + """ + if not torch.version.hip: + return False + if not config.enable_tdm_configs: + return False + device = matrices[0].get_device() + if device.type != "cuda": + return False + try: + # props = torch.cuda.get_device_properties(torch.cuda.current_device()) + props = torch.cuda.get_device_properties(device) + arch = getattr(props, "gcnArchName", "") + if not is_gfx1250_arch(arch): + return False + except Exception: + return False + + # Check Triton version supports TDM (3.6.0+) + try: + from torch.torch_version import TorchVersion + import triton + if TorchVersion(triton.__version__) < "3.6.0": + return False + except (ImportError, Exception): + return False + + # TDM can lower the same matrix element types as standard Triton GeMM. 
+ # The persistent-template config heuristic still filters BLOCK_K by dtype + # size so each TDM request covers an aligned 128B region. + # FIXME: torch.float4 + triton_tdm_dtypes = { + torch.float16, + torch.bfloat16, + torch.float32, + torch.float8_e4m3fn, + torch.float8_e5m2, + torch.float8_e4m3fnuz, + torch.float8_e5m2fnuz, + torch.int8, + } + # TDM unrolls requests to 128B or 256B aligned contiguous regions + # in both global memory and LDS. For GeMM, this means the leading + # dimension (stride) of each matrix should produce rows that start + # at 128B-aligned addresses. If strides are not 128B-aligned, TDM + # may still work (Triton's backend can fall back per-load), but + # performance benefit is reduced. We don't hard-gate on this because: + # 1. The autotuner will benchmark and reject TDM configs if they're slower. + # 2. Triton's pipeliner handles misaligned cases gracefully. + # 3. Most real GeMM workloads have contiguous row-major or col-major layouts. + # For now, only log for debugging purposes. May be upgraded to a hard gate later. + for mat in matrices: + mat_dtype = mat.get_dtype() + if mat_dtype not in triton_tdm_dtypes: + return False + strides = mat.get_stride() + dtype_bytes = mat_dtype.itemsize + # Check if the innermost stride is 1 (contiguous) and + # the outer stride produces 128B-aligned rows. + if len(strides) >= 2: + inner_stride = strides[-1] + outer_stride = strides[-2] + if hasattr(inner_stride, '__int__'): + inner_val = int(inner_stride) + if inner_val != 1: + # TODO: Should be a hard rejection, not just a warning. 
+ log.debug( + "TDM: matrix has non-unit inner stride %d, " + "TDM requests may not be contiguous", + inner_val, + ) + if hasattr(outer_stride, '__int__'): + row_bytes = int(outer_stride) * dtype_bytes + if row_bytes % config.tdm.alignment_bytes != 0: + log.debug( + "TDM: matrix row stride %d bytes is not %dB-aligned, " + "TDM performance may be suboptimal", + row_bytes, + config.tdm.alignment_bytes, + ) + + if add_guards and output_layout is not None: + # `SizeVarAllocator` does not guard device properties. Keep this as a + # static consistency check for the layout/device pair used at compile time. + if (output_layout.device.index or 0) != (device.index or 0): + return False + + return True + + def use_triton_blackwell_tma_template( *matrices: IRNode, output_layout: Layout, add_guards: bool = False ) -> bool: diff --git a/torch/cuda/_utils.py b/torch/cuda/_utils.py index d842e8b56ef41..e64e36192fa62 100644 --- a/torch/cuda/_utils.py +++ b/torch/cuda/_utils.py @@ -435,11 +435,15 @@ def set_shared_memory_config(self, shared_mem_bytes: int) -> None: device_props = torch.cuda.get_device_properties() # HIP doesn't have shared_memory_per_block_optin in device properties, so we hard-code it here if torch.version.hip: - # navi, CDNA1-CDNA3 allows a max of 64KB shared memory - # CDNA4 allows a max of 160KB shared memory - max_shared_mem = ( - 65536 if device_props.gcnArchName != "gfx950" else 160 * 1024 - ) + # navi, CDNA1-CDNA3 allows a max of 64KB shared memory, + # CDNA4 (gfx950) 160KB, and CDNA5 (gfx1250) 320KB. 
+ gcn_arch = device_props.gcnArchName.split(":", 1)[0] + if gcn_arch == "gfx950": + max_shared_mem = 160 * 1024 + elif gcn_arch == "gfx1250": + max_shared_mem = 320 * 1024 + else: + max_shared_mem = 65536 else: max_shared_mem = getattr( device_props, "shared_memory_per_block_optin", 49152 diff --git a/torch/nn/attention/varlen.py b/torch/nn/attention/varlen.py index 5be971e1bd96c..3f281c4675453 100644 --- a/torch/nn/attention/varlen.py +++ b/torch/nn/attention/varlen.py @@ -140,16 +140,9 @@ def _varlen_attn_fake( # For varlen path: logsumexp shape is (num_heads, total_q) total_q = query.size(0) num_heads = query.size(1) - if torch.version.hip: - # ROCm uses batched format: [batch_size, num_heads, max_q] - batch_size = cu_seq_q.size(0) - 1 - logsumexp = torch.empty( - (batch_size, num_heads, max_q), dtype=torch.float, device=query.device - ) - else: - logsumexp = torch.empty( - (num_heads, total_q), dtype=torch.float, device=query.device - ) + logsumexp = torch.empty( + (num_heads, total_q), dtype=torch.float, device=query.device + ) rng_state = torch.empty((2,), dtype=torch.uint64, device=query.device) diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py index 3f539d586e8bc..3a4c400fa9131 100644 --- a/torch/testing/_internal/common_cuda.py +++ b/torch/testing/_internal/common_cuda.py @@ -53,15 +53,26 @@ def evaluate_gfx_arch_within(arch_list): # Hence the matching should be done reversely return any(arch in effective_arch for arch in arch_list) +# Per-generation gfx targets, ordered oldest -> newest. Each "OrLater" helper +# below unions its own generation with every newer one, so the predicates are +# nested by construction: CDNA5OrLater => CDNA3OrLater => CDNA2OrLater. 
+_CDNA2_ARCHS = ["gfx90a"] +_CDNA3_ARCHS = ["gfx942", "gfx950"] +# CDNA 5 (CDNA-next / UDAN) +_CDNA5_ARCHS = ["gfx1250"] + +def CDNA5OrLater(): + return evaluate_gfx_arch_within(_CDNA5_ARCHS) + def CDNA3OrLater(): - return evaluate_gfx_arch_within(["gfx942", "gfx950"]) + return evaluate_gfx_arch_within(_CDNA3_ARCHS + _CDNA5_ARCHS) def CDNA2OrLater(): - return evaluate_gfx_arch_within(["gfx90a", "gfx942"]) + return evaluate_gfx_arch_within(_CDNA2_ARCHS + _CDNA3_ARCHS + _CDNA5_ARCHS) def evaluate_platform_supports_flash_attention(): if TEST_WITH_ROCM: - arch_list = ["gfx90a", "gfx942", "gfx1100", "gfx1201", "gfx950"] + arch_list = ["gfx90a", "gfx942", "gfx1100", "gfx1201", "gfx950", "gfx1250"] if os.environ.get("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "0") != "0": arch_list += ["gfx1101", "gfx1102", "gfx1150", "gfx1151", "gfx1200"] return evaluate_gfx_arch_within(arch_list) @@ -71,7 +82,7 @@ def evaluate_platform_supports_flash_attention(): def evaluate_platform_supports_efficient_attention(): if TEST_WITH_ROCM: - arch_list = ["gfx90a", "gfx942", "gfx1100", "gfx1201", "gfx950"] + arch_list = ["gfx90a", "gfx942", "gfx1100", "gfx1201", "gfx950", "gfx1250"] if os.environ.get("TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL", "0") != "0": arch_list += ["gfx1101", "gfx1102", "gfx1150", "gfx1151", "gfx1200"] return evaluate_gfx_arch_within(arch_list) @@ -139,6 +150,8 @@ def evaluate_platform_supports_fp8(): archs.extend(['gfx120']) if ROCM_VERSION >= (6, 5): archs.append('gfx95') + if ROCM_VERSION >= (7, 2): + archs.append('gfx1250') for arch in archs: if arch in torch.cuda.get_device_properties(0).gcnArchName: return True @@ -151,7 +164,7 @@ def evaluate_platform_supports_fp8_grouped_gemm(): if torch.version.hip: if "USE_MSLK" not in torch.__config__.show(): return False - archs = ['gfx942', 'gfx950'] + archs = ['gfx942', 'gfx950', 'gfx1250'] for arch in archs: if arch in torch.cuda.get_device_properties(0).gcnArchName: return True @@ -163,7 +176,8 @@ def 
evaluate_platform_supports_mx_gemm(): if torch.cuda.is_available(): if torch.version.hip: if ROCM_VERSION >= (7, 0): - return 'gfx950' in torch.cuda.get_device_properties(0).gcnArchName + gcn_name = torch.cuda.get_device_properties(0).gcnArchName + return 'gfx950' in gcn_name or ('gfx1250' in gcn_name and ROCM_VERSION >= (7, 2)) else: return SM100OrLater return False