Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
9d8de03
Add initial support for gfx1250
rraminen Jun 15, 2026
082dbc9
[ROCm] Fix large ROCm arange launch (#182657) (#16)
jammm May 7, 2026
471cdf1
Temp 2.11 1250 tdm (#20)
glen-amd May 26, 2026
4694286
Bump triton version to 3.7.0
rraminen Jun 15, 2026
2d6992a
[release/2.11_gfx1250] Bump AOTriton to 0.12.50tp (#184288) (#3312)
xinyazhang Jun 19, 2026
ff88759
[release/2.11] Fix all arch build (#3324)
pragupta Jun 17, 2026
58eacd9
Fixes previous PR: [ROCm] Bump AOTriton to 0.12.50tp (#3328)
pragupta Jun 17, 2026
536aac8
Turn off MSLK's CK Kernels for gfx1250 (#3329)
pragupta Jun 17, 2026
dcea57c
[release/2.11_gfx1250] Support for gfx1250 - bug fixes (#3330)
rraminen Jun 17, 2026
6b6cf88
[release/2.11_gfx1250] Fix typo (#3331)
rraminen Jun 17, 2026
564e7c1
[release/2.11] Clean up gfx1250 specific code (#3347)
rraminen Jun 19, 2026
d3a0630
[release/2.11_gfx1250] Address PR #3346 review comments
pragupta Jun 23, 2026
62cced8
[ROCm] Build CK GEMM/SDPA per-arch and disable them for gfx1250
pragupta Jun 23, 2026
2999693
Advance Triton pin to 3.7.1 with module fix
naromero77amd Jun 23, 2026
d39ac76
[Inductor] Fix constants handling for Triton constexpr (triton#8248) …
pytorchbot Dec 23, 2025
331e9d6
Fix Triton HOP clone for strided mutated views (#184050)
jansel May 17, 2026
adb92d2
[ROCm] Fix multi-arch AOT Inductor compilation with newer Triton LLVM…
chinmaydk99 Jun 23, 2026
8d83604
bump to 0.12.50tp2 to fix gfx1250 gpu detection
xinyazhang Jun 23, 2026
36c2d98
GFX1250 warp-/wave-size related fixes
glen-amd Jun 23, 2026
75f45e1
Refined a piece of comments
glen-amd Jun 23, 2026
1e5a66b
Reverted an earlier commit to eliminate redundant CK related dependen…
glen-amd Jun 24, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .ci/docker/build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ case "$tag" in
KATEX=yes
UCX_COMMIT=${_UCX_COMMIT}
UCC_COMMIT=${_UCC_COMMIT}
PYTORCH_ROCM_ARCH="gfx90a;gfx942;gfx950;gfx1100"
PYTORCH_ROCM_ARCH="gfx90a;gfx942;gfx950;gfx1100;gfx1250"
if [[ $tag =~ "benchmarks" ]]; then
INDUCTOR_BENCHMARKS=yes
fi
Expand Down
2 changes: 1 addition & 1 deletion .ci/docker/ci_commit_pins/triton.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
4ed888920c5a0871957f1cf912e557bc79fbe56c
110cd8e2ddf80d46fcc935d46dfcae7130d13b24
2 changes: 1 addition & 1 deletion .ci/docker/triton_version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
3.6.0
3.7.1
6 changes: 5 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -945,8 +945,12 @@ cmake_dependent_option(
OFF)


# TODO:
# The MSLK-related parts that already exist upstream are missing here.
# gfx1250 also needs to be handled for MSLK.

IF(USE_ROCM AND ("gfx942" IN_LIST PYTORCH_ROCM_ARCH OR "gfx950" IN_LIST PYTORCH_ROCM_ARCH))
message(WARNING "Setting USE_MSLK for gfx942/gfx950 to ON by default, doing ROCM build")
message(STATUS "Setting USE_MSLK for gfx942/gfx950 to ON by default, doing ROCM build")
set(USE_MSLK_DEFAULT ON)
elseif(USE_CUDA AND "$ENV{TORCH_CUDA_ARCH_LIST}" MATCHES "10.0" AND CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 12.8 AND NOT WIN32)
message(STATUS "Setting USE_MSLK to ON by default , doing CUDA build for SM100a")
Expand Down
62 changes: 59 additions & 3 deletions aten/src/ATen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,21 @@ file(GLOB native_flash_attn_api_cpp "native/transformers/cuda/flash_attn/flash_a

# flash_attention hip sources
file(GLOB flash_attention_hip_hip "native/transformers/hip/flash_attn/*.hip")
# composable_kernel has no gfx1250 support. The CK GEMM/SDPA targets are built
# for every requested arch except gfx1250 (see the --offload-arch filtering
# further down). When gfx1250 is the only requested arch, nothing is left to
# build CK for, so both CK features are switched off here.
# caffe2_update_option writes the cache, so the conditional CK GEMM/SDPA defines
# and links in caffe2/CMakeLists.txt honor this decision.
if(USE_ROCM AND "gfx1250" IN_LIST PYTORCH_ROCM_ARCH)
  # Work on a copy so PYTORCH_ROCM_ARCH itself is left untouched.
  set(_ck_supported_archs ${PYTORCH_ROCM_ARCH})
  list(REMOVE_ITEM _ck_supported_archs gfx1250)
  list(LENGTH _ck_supported_archs _ck_supported_arch_count)
  if(_ck_supported_arch_count EQUAL 0)
    message(WARNING "gfx1250 is the only arch in PYTORCH_ROCM_ARCH: disabling USE_ROCM_CK_GEMM and USE_ROCM_CK_SDPA (composable_kernel lacks gfx1250 support)")
    caffe2_update_option(USE_ROCM_CK_GEMM OFF)
    caffe2_update_option(USE_ROCM_CK_SDPA OFF)
  endif()
endif()

# if USE_FLASH_ATTENTION is set, ensure CK instances get generated
if(USE_FLASH_ATTENTION)
if("$ENV{USE_CK_FLASH_ATTENTION}" STREQUAL "1")
Expand Down Expand Up @@ -289,9 +304,20 @@ if(USE_FLASH_ATTENTION)
"native/transformers/hip/flash_attn/ck/fav_v3/*.hip")

set_source_files_properties(${ck_sdpa_sources_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
# CK SDPA must not be compiled for gfx1250 (composable_kernel lacks support for
# it). HIP_CLANG_FLAGS is swapped for a copy whose --offload-arch entries are
# rebuilt without gfx1250 while the ck_sdpa target is declared, and is
# restored immediately afterwards so other targets keep the full arch list.
set(HIP_CLANG_FLAGS_ORIGINAL ${HIP_CLANG_FLAGS})
# Strip every existing --offload-arch=<arch> flag ...
string(REGEX REPLACE "--offload-arch=[^ ]*" "" FILTERED_HIP_CLANG_FLAGS "${HIP_CLANG_FLAGS}")
# ... then re-add one per requested arch, skipping gfx1250.
foreach(ARCH ${PYTORCH_ROCM_ARCH})
  if(ARCH STREQUAL "gfx1250")
    continue()
  endif()
  list(APPEND FILTERED_HIP_CLANG_FLAGS --offload-arch=${ARCH})
endforeach()
set(HIP_CLANG_FLAGS ${FILTERED_HIP_CLANG_FLAGS})
hip_add_library(ck_sdpa
  STATIC
  ${ck_sdpa_sources_hip}
  HIPCC_OPTIONS
  ${HIP_HCC_FLAGS}
  ${CK_SDPA_EXTRA_HIPCC_FLAGS})
set(HIP_CLANG_FLAGS ${HIP_CLANG_FLAGS_ORIGINAL})
set_target_properties(ck_sdpa PROPERTIES POSITION_INDEPENDENT_CODE ON)
target_compile_definitions(ck_sdpa PUBLIC ${CK_SDPA_EXTRA_HIPCC_OPTIONS})
target_compile_definitions(ck_sdpa PRIVATE AITER_EMBEDDED_HSA_HEADER="aiter_embedded_hsa.h")
Expand Down Expand Up @@ -430,7 +456,7 @@ IF(USE_MSLK)
list(PREPEND MSLK_EXTRA_HIPCC_FLAGS -mllvm -amdgpu-coerce-illegal-types=1)
endif()

# Only compile for gfx942 and gfx950.
# Only compile for gfx942 and gfx950 (composable_kernel lacks gfx1250 support).
set(HIP_CLANG_FLAGS_ORIGINAL ${HIP_CLANG_FLAGS})
string(REGEX REPLACE "--offload-arch=[^ ]*" "" FILTERED_HIP_CLANG_FLAGS "${HIP_CLANG_FLAGS}")
foreach(ARCH gfx942 gfx950)
Comment thread
pragupta marked this conversation as resolved.
Expand Down Expand Up @@ -627,11 +653,41 @@ if(USE_ROCM)
${native_quantized_hip_hip}
${native_transformers_hip_hip} ${native_transformers_src_hip_hip}
)
# CK GEMM sources. They are always removed from the main HIP source list:
#  - USE_ROCM_CK_GEMM OFF: they are not built at all.
#  - USE_ROCM_CK_GEMM ON:  they are built into the dedicated ck_gemm library.
# Glob once and exclude once instead of repeating both in each branch.
file(GLOB native_hip_bgemm "native/hip/bgemm_kernels/*.hip")
file(GLOB native_hip_ck "native/hip/ck*.hip")
exclude(ATen_HIP_SRCS "${ATen_HIP_SRCS}"
  ${native_hip_bgemm} ${native_hip_ck})
if(USE_ROCM_CK_GEMM)
  # composable_kernel lacks gfx1250 support, so the CK GEMM kernels are
  # compiled into a dedicated ck_gemm library for every arch except gfx1250
  # (the rest of torch_hip still builds for all archs).
  # The --offload-arch filtering mirrors the mslk pattern above.
  set(ck_gemm_sources_hip ${native_hip_bgemm} ${native_hip_ck})
  set_source_files_properties(${ck_gemm_sources_hip} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
  # Temporarily replace HIP_CLANG_FLAGS with a copy whose --offload-arch flags
  # exclude gfx1250; restored right after the target is declared.
  set(HIP_CLANG_FLAGS_ORIGINAL ${HIP_CLANG_FLAGS})
  string(REGEX REPLACE "--offload-arch=[^ ]*" "" FILTERED_HIP_CLANG_FLAGS "${HIP_CLANG_FLAGS}")
  foreach(ARCH ${PYTORCH_ROCM_ARCH})
    if(NOT ARCH STREQUAL "gfx1250")
      list(APPEND FILTERED_HIP_CLANG_FLAGS --offload-arch=${ARCH})
    endif()
  endforeach()
  set(HIP_CLANG_FLAGS ${FILTERED_HIP_CLANG_FLAGS})
  hip_add_library(ck_gemm STATIC ${ck_gemm_sources_hip})
  set(HIP_CLANG_FLAGS ${HIP_CLANG_FLAGS_ORIGINAL})
  set_target_properties(ck_gemm PROPERTIES POSITION_INDEPENDENT_CODE ON)
  # The define is no longer added globally in cmake/Dependencies.cmake; the
  # ck_gemm_*.hip sources are guarded by #if defined(USE_ROCM_CK_GEMM).
  target_compile_definitions(ck_gemm PRIVATE USE_ROCM_CK_GEMM)
  target_include_directories(ck_gemm PRIVATE
    ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/include
    ${CMAKE_CURRENT_SOURCE_DIR}/../../../third_party/composable_kernel/library/include
    ${CMAKE_CURRENT_BINARY_DIR}/composable_kernel)
  # Build-order edge only (no link); applied only when torch_cpu is already a
  # known target at this point of the configure step.
  if(TARGET torch_cpu)
    add_dependencies(ck_gemm torch_cpu)
  endif()
endif()

# TODO: Codegen separate files for HIP and use those (s/cuda_generated_sources/hip_generated_sources)
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/Context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ at::BlasBackend Context::blasPreferredBackend() {
bool Context::ckSupported() {
#ifdef USE_ROCM
static const std::vector<std::string> supported_archs = {
"gfx90a", "gfx942", "gfx950"
"gfx90a", "gfx942", "gfx950", "gfx1250",
};
for (auto index : c10::irange(detail::getCUDAHooks().deviceCount())) {
if(!detail::getCUDAHooks().isGPUArch(supported_archs, index)) {
Expand Down
8 changes: 8 additions & 0 deletions aten/src/ATen/cuda/CUDABlas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2020,6 +2020,14 @@ void scaled_gemm(
"Got m=", m, ", n=", n, ", k=", k);
}
#endif
#if ROCM_VERSION >= 70200
if (at::detail::getCUDAHooks().isGPUArch({"gfx1250"})) {
// TODO: add constraints based on hipblaslt internals
TORCH_CHECK((m % 16 == 0) && (n % 16 == 0) && (k % 128 == 0),
"M, N must be multiples of 16 and K should be multiple of 128 for MX format. "
"Got m=", m, ", n=", n, ", k=", k);
}
#endif
}
#elif (CUDA_VERSION < 12090) && !defined(USE_ROCM)
// hipblaslt supported row-wise before cublas, and did so their own way (via
Expand Down
3 changes: 3 additions & 0 deletions aten/src/ATen/cuda/CUDAScaledBlas.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@ static bool _scaled_mm_allowed_device(bool sm90_only=false, bool sm100_only=fals
#endif
#if ROCM_VERSION >= 60500
"gfx950"
#endif
#if ROCM_VERSION >= 70200
, "gfx1250"
#endif
};
return at::detail::getCUDAHooks().isGPUArch(archs);
Expand Down
4 changes: 2 additions & 2 deletions aten/src/ATen/cuda/CublasHandlePool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -141,8 +141,8 @@ size_t parseChosenWorkspaceSize() {
val = c10::utils::get_env("ROCBLAS_WORKSPACE_CONFIG");
}
/* 32MiB default, 128MiB for gfx94x/gfx95x/gfx125x */
const bool gfx94_95 = at::detail::getCUDAHooks().isGPUArch({"gfx94", "gfx95"});
const size_t default_size = gfx94_95 ? 1024 * 128 * 1024 : 1024 * 32 * 1024;
const bool gfx94_95_125 = at::detail::getCUDAHooks().isGPUArch({"gfx94", "gfx95", "gfx125"});
Comment thread
pragupta marked this conversation as resolved.
const size_t default_size = gfx94_95_125 ? 1024 * 128 * 1024 : 1024 * 32 * 1024;
#else
/* :4096:2:16:8 default, 32MiB for Hopper */
cudaDeviceProp* properties = at::cuda::getCurrentDeviceProperties();
Expand Down
5 changes: 4 additions & 1 deletion aten/src/ATen/cuda/detail/CUDAHooks.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,10 @@ const std::vector<std::string>& CUDAHooks::getHipblasltPreferredArchs() const {
"gfx1200", "gfx1201",
#endif
#if ROCM_VERSION >= 70000
"gfx950"
"gfx950",
#endif
#if ROCM_VERSION >= 70200
"gfx1250"
#endif
};
return archs;
Expand Down
6 changes: 5 additions & 1 deletion aten/src/ATen/native/cuda/CUDALoops.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -228,8 +228,12 @@ C10_LAUNCH_BOUNDS_1(num_threads())
__global__ void vectorized_elementwise_kernel(int N, func_t f, array_t data) {
using traits = function_traits<func_t>;
constexpr auto io_size = calc_io_size<func_t>();
// tws=16 is tuned for gfx942 (CDNA3, Wave64) and is NOT portable to Wave32
// archs. gfx1250 (GFX12.5, Wave32) falls through to the default `tws` below;
// re-enabling this fast path for gfx1250 requires retuning the constant
// against the 2x wave-width ratio. Keep this in sync with the host-side `tws`
// computation in launch_vectorized_kernel() (also gfx942-only).
#if defined(USE_ROCM) && defined(__gfx942__)
// Similar check in launch_vectorized_kernel() as well. Both should be in sync.
constexpr int tws = 16;
#else
constexpr int tws = elems_per_thread<io_size>();
Expand Down
5 changes: 5 additions & 0 deletions aten/src/ATen/native/cuda/GroupedBlas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -691,6 +691,9 @@ std::optional<c10::ScalarType> out_dtype) {
// To enable CK path, use env variable ROCM_ALLOW_GROUP_GEMM_CK=1.
bool use_fast_path = false;
// ifdef USE_ROCM_CK_GEMM is required since ROCm systems w/o CK should not call ck path.
// NOTE: gfx1250 is intentionally excluded. The CK grouped GEMM path dispatches
// Wave64/MFMA-style XDL templates; gfx1250 is Wave32 and needs a WMMA/SWMMAC
// path, so it must stay on the fallback until a gfx1250-safe CK path exists.
#if defined(USE_ROCM_CK_GEMM)
if (at::globalContext().rocmAllowGroupGemmCk() && at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950", "gfx90a"})) {
use_fast_path = true;
Expand All @@ -699,7 +702,9 @@ std::optional<c10::ScalarType> out_dtype) {
const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype);
Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_);
if (use_fast_path) {
#if defined(USE_ROCM_CK_GEMM)
at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out);
#endif //USE_ROCM_CK_GEMM
} else {
_grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out);
}
Expand Down
4 changes: 4 additions & 0 deletions aten/src/ATen/native/cuda/MemoryAccess.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,10 @@ template <int vec_size, typename scalar_t>
__device__ aligned_vector<scalar_t, vec_size> load_vector(const scalar_t *base_ptr, uint32_t offset) {
using vec_t = aligned_vector<scalar_t, vec_size>;
auto *from = reinterpret_cast<const vec_t *>(base_ptr);
// This nontemporal vectorized load is tuned for gfx942 (CDNA3, Wave64). Do NOT
// extend to gfx1250 (GFX12.5, Wave32) without first validating intrinsic
// codegen and re-benchmarking -- the register layout and wave width differ.
// gfx1250 falls through to the generic path below, which is correctness-safe.
#if defined(USE_ROCM) && defined(__gfx942__)
using longx2 = __attribute__((__vector_size__(4*sizeof(int)))) int;
if constexpr (sizeof(vec_t) == sizeof(int)) {
Expand Down
41 changes: 41 additions & 0 deletions aten/src/ATen/native/cuda/RangeFactories.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
#include <ATen/native/RangeUtils.h>
#include <cmath>
#include <limits>
#if defined(USE_ROCM)
#include <algorithm>
#endif

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
Expand Down Expand Up @@ -48,12 +51,49 @@ __global__ void elementwise_kernel_with_index(index_t N, func_t f, typename func
}
}

#if defined(USE_ROCM)
// HIP does not support launches with gridDim.x * blockDim.x >= 2^32:
// depending on the ROCm version the launch returns
// hipErrorInvalidConfiguration or is accepted silently with the kernel
// never executing, leaving zero-initialized output. A grid-stride kernel
// with a fixed grid sized to device occupancy avoids the limit.
//
// Each thread writes data[idx] = f(idx) for idx = first, first + stride, ...
// while idx < N, where stride = gridDim.x * blockDim.x.
//
// index_t may be a signed 32-bit type even when N is close to its maximum
// value, so the loop must never compute an index past N: `idx + stride` could
// overflow (undefined behaviour; in practice a negative index that still
// satisfies `idx < N` and writes out of bounds). The remaining distance
// `N - idx` is always representable, so it is checked before advancing.
template<typename index_t, typename func_t>
C10_LAUNCH_BOUNDS_1(num_threads())
__global__ void elementwise_kernel_with_index_grid_stride(
    index_t N, func_t f,
    typename function_traits<func_t>::result_type *data) {
  const index_t stride = static_cast<index_t>(gridDim.x) * blockDim.x;
  index_t idx = static_cast<index_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  while (idx < N) {
    data[idx] = f(idx);
    // Stop if the next step would reach or pass N; this also guarantees the
    // addition below cannot overflow index_t.
    if (N - idx <= stride) {
      break;
    }
    idx += stride;
  }
}
#endif

template<typename func_t>
void gpu_kernel_with_index(at::Tensor &output, func_t f) {
int64_t N = output.numel();
if (N == 0) {
return;
}
#if defined(USE_ROCM)
constexpr int blocks_per_sm = 4;
const int sm_count =
at::cuda::getCurrentDeviceProperties()->multiProcessorCount;
const int64_t orig_grid = (N + block_work_size - 1) / block_work_size;
int64_t grid = std::min<int64_t>(
orig_grid, static_cast<int64_t>(sm_count) * blocks_per_sm);
grid = std::max<int64_t>(grid, 1);
auto stream = at::cuda::getCurrentCUDAStream();
using scalar_t = typename function_traits<func_t>::result_type;
if (N <= std::numeric_limits<int>::max()) {
elementwise_kernel_with_index_grid_stride<int><<<grid, num_threads(), 0, stream>>>(N, f, output.mutable_data_ptr<scalar_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
elementwise_kernel_with_index_grid_stride<int64_t><<<grid, num_threads(), 0, stream>>>(N, f, output.mutable_data_ptr<scalar_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
#else
int64_t grid = (N + block_work_size - 1) / block_work_size;
auto stream = at::cuda::getCurrentCUDAStream();
using scalar_t = typename function_traits<func_t>::result_type;
Expand All @@ -64,6 +104,7 @@ void gpu_kernel_with_index(at::Tensor &output, func_t f) {
elementwise_kernel_with_index<int64_t><<<grid, num_threads(), 0, stream>>>(N, f, output.mutable_data_ptr<scalar_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
#endif
}

} // namespace
Expand Down
20 changes: 19 additions & 1 deletion aten/src/ATen/native/cuda/ScaledBlas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ static bool _scaled_mm_allowed_device(bool sm90_only=false, bool sm100_only=fals
#endif
#if ROCM_VERSION >= 60500
"gfx950"
#endif
#if ROCM_VERSION >= 70200
, "gfx1250"
#endif
};
return at::detail::getCUDAHooks().isGPUArch(archs);
Expand Down Expand Up @@ -622,9 +625,14 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2,
}
else if (scaling_choice_a == ScalingType::BlockWise1x32 && scaling_choice_b == ScalingType::BlockWise1x32) {
#ifdef USE_ROCM
#if ROCM_VERSION >= 70000
#if ROCM_VERSION >= 70000
#if ROCM_VERSION >= 70200
TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950", "gfx1250"}),
"Block-wise scaling for Float8_e8m0fnu is only supported on gfx950/gfx1250");
#else
TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950"}),
"Block-wise scaling for Float8_e8m0fnu is only supported on gfx950");
#endif

int packed_factor = 1;
if (mat1.scalar_type() == ScalarType::Float4_e2m1fn_x2) {
Expand Down Expand Up @@ -1067,8 +1075,13 @@ _scaled_mxfp8_mxfp8(

#ifdef USE_ROCM
#if ROCM_VERSION >= 70000
#if ROCM_VERSION >= 70200
TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950", "gfx1250"}),
"Block-wise scaling for Float8_e8m0fnu is only supported on gfx950/gfx1250");
#else
TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950"}),
"Block-wise scaling for Float8_e8m0fnu is only supported on gfx950");
#endif

TORCH_CHECK_VALUE(mat_a.size(0) % 32 == 0 && mat_a.size(1) % 32 == 0 &&
mat_b.size(0) % 32 == 0 && mat_b.size(1) % 32 == 0,
Expand Down Expand Up @@ -1153,8 +1166,13 @@ _scaled_mxfp4_mxfp4(
auto scaling_choice_b = ScalingType::BlockWise1x32;

#if ROCM_VERSION >= 70000
#if ROCM_VERSION >= 70200
TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950", "gfx1250"}),
"Block-wise scaling for Float8_e8m0fnu is only supported on gfx950/gfx1250");
#else
TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950"}),
"Block-wise scaling for Float8_e8m0fnu is only supported on gfx950");
#endif

TORCH_CHECK_VALUE(mat_a.size(0) % 32 == 0 && mat_a.size(1) % 32 == 0 &&
mat_b.size(0) % 32 == 0 && mat_b.size(1) % 32 == 0,
Expand Down
19 changes: 18 additions & 1 deletion aten/src/ATen/native/cuda/int4mm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ inline __host__ __device__ uint32_t getAlignmentRoundUp(const void* p) {
return diff == 0 ? 0 : uint32_t(Align) - diff;
}

#if defined (__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)
// CDNA2+ arch with MFMA and Wave64 support
#if defined(__gfx90a__) || defined(__gfx942__) || defined(__gfx950__)
#define CDNA2_OR_LATER 1
#else
#define CDNA2_OR_LATER 0
Expand All @@ -146,6 +147,12 @@ static bool isCDNA2orLater(int index) {
return at::detail::getCUDAHooks().isGPUArch({"gfx90a", "gfx942", "gfx950"}, index);
}

// Reports whether the device at `index` is a gfx1250 GPU.
// The "CDNA5 or later" grouping is conceptual for now and subject to change;
// gfx1250 is referred to as CDNA5 / CDNA-next / UDNA.
static bool isCDNA5orLater(int index) {
  const auto& hooks = at::detail::getCUDAHooks();
  return hooks.isGPUArch({"gfx1250"}, index);
}

#else
constexpr int32_t kWarpSize = 32;
#endif
Expand Down Expand Up @@ -1098,6 +1105,11 @@ at::Tensor _weight_int4pack_mm_cuda(
A.device() == B.device() && A.device() == qScaleAndZeros.device());

#if defined(USE_ROCM)
if (isCDNA5orLater(A.device().index())) {
TORCH_CHECK(false,
"_weight_int4pack_mm_cuda is not yet supported on gfx1250. "
"A WMMA-based implementation is required for gfx1250.");
}
if (!isCDNA2orLater(A.device().index())) {
TORCH_CHECK(false, "_weight_int4pack_mm_cuda is only supported on AMD gpu arch greater than or equal to CDNA2");
}
Expand Down Expand Up @@ -1293,6 +1305,11 @@ at::Tensor _convert_weight_to_int4pack_cuda(
TORCH_CHECK(innerKTiles == 2 || innerKTiles == 4 || innerKTiles == 8);

#if defined(USE_ROCM)
if (isCDNA5orLater(in.device().index())) {
TORCH_CHECK(false,
"_convert_weight_to_int4pack_cuda is not yet supported on gfx1250. "
"A WMMA-based implementation is required for gfx1250.");
}
if (!isCDNA2orLater(in.device().index())) {
TORCH_CHECK(false, "_convert_weight_to_int4pack_cuda is only supported on AMD gpu arch greater than or equal to CDNA2");
}
Expand Down
Loading