-
Notifications
You must be signed in to change notification settings - Fork 81
[release/2.12] Add support for gfx1250 #3327
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: release/2.12
Are you sure you want to change the base?
Changes from all commits
632d652
a3dd7ed
aec8828
1e71988
07ec88f
873d4c5
aeb64a7
55719a5
bbecc56
04b5405
43a62ac
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. These changes are unnecessary unless we know for certain any build workflows that would use it. TheRock build workflows don't. |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -268,8 +268,15 @@ if(USE_FLASH_ATTENTION) | |
| CK_ENABLE_FP64 | ||
| CK_ENABLE_FP8 | ||
| CK_ENABLE_INT8 | ||
| CK_USE_FNUZ_FP8 | ||
| CK_USE_GFX94 | ||
| #CK_USE_FNUZ_FP8 | ||
| #CK_USE_GFX94 | ||
| CK_USE_GFX1250 | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here and below change the CK SDPA compile definitions globally. |
||
| CK_USE_NATIVE_MX_SUPPORT | ||
| CK_GFX1250_SUPPORT | ||
| CK_GFX12_SUPPORT | ||
| CK_USE_OCP_FP8 | ||
| CK_USE_WMMA | ||
| CK_USE_WMMA_FP8 | ||
| CK_USE_XDL | ||
| __HIP_PLATFORM_AMD__=1 | ||
| __HIP_PLATFORM_HCC__=1 | ||
|
|
@@ -430,7 +437,7 @@ IF(USE_MSLK) | |
| list(PREPEND MSLK_EXTRA_HIPCC_FLAGS -mllvm -amdgpu-coerce-illegal-types=1) | ||
| endif() | ||
|
|
||
| # Only compile for gfx942 and gfx950. | ||
| # Only compile for gfx942 and gfx950 (composable_kernel lacks gfx1250 support). | ||
| set(HIP_CLANG_FLAGS_ORIGINAL ${HIP_CLANG_FLAGS}) | ||
| string(REGEX REPLACE "--offload-arch=[^ ]*" "" FILTERED_HIP_CLANG_FLAGS "${HIP_CLANG_FLAGS}") | ||
| foreach(ARCH gfx942 gfx950) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -714,14 +714,16 @@ std::optional<c10::ScalarType> out_dtype) { | |
| bool use_fast_path = false; | ||
| // ifdef USE_ROCM_CK_GEMM is required since ROCm systems w/o CK should not call ck path. | ||
| #if defined(USE_ROCM_CK_GEMM) | ||
| if (at::globalContext().rocmAllowGroupGemmCk() && at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950", "gfx90a"})) { | ||
| if (at::globalContext().rocmAllowGroupGemmCk() && at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950", "gfx90a", "gfx1250"})) { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it true that the existing CK grouped GEMM path is the Wave64/MFMA/XDL path used gfx90a/gfx942/gfx950? If so, because gfx1250 is Wave32 and WMMA/SWMMAC-oriented, it may not be routed into this path by arch-name allowlisting. |
||
| use_fast_path = true; | ||
| } | ||
| #endif //USE_ROCM_CK_GEMM | ||
| const auto out_dtype_ = _resolve_grouped_mm_out_dtype(mat_a, mat_b, out_dtype); | ||
| Tensor out = create_grouped_gemm_output_tensor(mat_a, mat_b, offs, out_dtype_); | ||
| if (use_fast_path) { | ||
| #if defined(USE_ROCM_CK_GEMM) | ||
| at::hip::detail::group_gemm_ck(mat_a, mat_b, offs, bias, out); | ||
| #endif //USE_ROCM_CK_GEMM | ||
| } else { | ||
| _grouped_mm_fallback(mat_a, mat_b, offs, bias, out_dtype, out); | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -13,7 +13,8 @@ | |
|
|
||
| #if ROCM_VERSION < 60400 | ||
| __device__ inline __hip_bfloat162 preview_unsafeAtomicAdd(__hip_bfloat162* address, __hip_bfloat162 value) { | ||
| #if (defined(__gfx942__)) && \ | ||
| // `__gfx1250__`-specific `s_wait_loadcnt(0)` path for committed store already there | ||
| #if (defined(__gfx942__) || defined(__gfx1250__)) && \ | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Does this change matter now, if the outer condition is
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Addressed in #3347 |
||
| __has_builtin(__builtin_amdgcn_flat_atomic_fadd_v2bf16) | ||
| typedef unsigned short __attribute__((ext_vector_type(2))) vec_short2; | ||
| static_assert(sizeof(vec_short2) == sizeof(__hip_bfloat162_raw)); | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -187,7 +187,8 @@ template <int vec_size, typename scalar_t> | |
| __device__ aligned_vector<scalar_t, vec_size> load_vector(const scalar_t *base_ptr, uint32_t offset) { | ||
| using vec_t = aligned_vector<scalar_t, vec_size>; | ||
| auto *from = reinterpret_cast<const vec_t *>(base_ptr); | ||
| #if defined(USE_ROCM) && defined(__gfx942__) | ||
| // Extend the non-temporal load optimization to GFX1250. | ||
| #if defined(USE_ROCM) && (defined(__gfx942__) || defined(__gfx1250__)) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Simply extending to gfx1250, would this be another Wave64-tuned path being applied to Wave32 hardware? |
||
| using longx2 = __attribute__((__vector_size__(4*sizeof(int)))) int; | ||
| if constexpr (sizeof(vec_t) == sizeof(int)) { | ||
| union { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -79,6 +79,9 @@ bool _scaled_mm_allowed_device(bool sm90_only=false, bool sm100_only=false) { | |
| #endif | ||
| #if ROCM_VERSION >= 60500 | ||
| "gfx950" | ||
| #endif | ||
| #if ROCM_VERSION >= 70200 | ||
| , "gfx1250" | ||
| #endif | ||
| }; | ||
| return at::detail::getCUDAHooks().isGPUArch(archs); | ||
|
|
@@ -622,8 +625,8 @@ _scaled_mm_out_cuda(const Tensor& mat1, const Tensor& mat2, | |
| else if (scaling_choice_a == ScalingType::BlockWise1x32 && scaling_choice_b == ScalingType::BlockWise1x32) { | ||
| #ifdef USE_ROCM | ||
| #if ROCM_VERSION >= 70000 | ||
| TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950"}), | ||
| "Block-wise scaling for Float8_e8m0fnu is only supported on gfx950"); | ||
| TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950", "gfx1250"}), | ||
| "Block-wise scaling for Float8_e8m0fnu is only supported on gfx950/gfx1250"); | ||
|
|
||
| int packed_factor = 1; | ||
| if (mat1.scalar_type() == ScalarType::Float4_e2m1fn_x2) { | ||
|
|
@@ -1064,8 +1067,8 @@ _scaled_mxfp8_mxfp8( | |
|
|
||
| #ifdef USE_ROCM | ||
| #if ROCM_VERSION >= 70000 | ||
| TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950"}), | ||
| "Block-wise scaling for Float8_e8m0fnu is only supported on gfx950"); | ||
| TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950", "gfx1250"}), | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Above How about nesting |
||
| "Block-wise scaling for Float8_e8m0fnu is only supported on gfx950/gfx1250"); | ||
|
|
||
| TORCH_CHECK_VALUE(mat_a.size(0) % 32 == 0 && mat_a.size(1) % 32 == 0 && | ||
| mat_b.size(0) % 32 == 0 && mat_b.size(1) % 32 == 0, | ||
|
|
@@ -1150,8 +1153,8 @@ _scaled_mxfp4_mxfp4( | |
| auto scaling_choice_b = ScalingType::BlockWise1x32; | ||
|
|
||
| #if ROCM_VERSION >= 70000 | ||
| TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950"}), | ||
| "Block-wise scaling for Float8_e8m0fnu is only supported on gfx950"); | ||
| TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950", "gfx1250"}), | ||
| "Block-wise scaling for Float8_e8m0fnu is only supported on gfx950/gfx1250"); | ||
|
|
||
| TORCH_CHECK_VALUE(mat_a.size(0) % 32 == 0 && mat_a.size(1) % 32 == 0 && | ||
| mat_b.size(0) % 32 == 0 && mat_b.size(1) % 32 == 0, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -30,7 +30,7 @@ static void initHipSparseLtSupport() { | |
| // Check only the first available device | ||
| try { | ||
| if (at::cuda::device_count() > 0) { | ||
| g_hipSparseLtSupported = at::detail::getCUDAHooks().isGPUArch({"gfx950", "gfx942"}, 0); | ||
| g_hipSparseLtSupported = at::detail::getCUDAHooks().isGPUArch({"gfx950", "gfx942", "gfx1250"}, 0); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we confirm whether hipSparseLT requires ROCm 7.2+?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, hipsparselt actually requires ROCm >=7.12. PR is in progress pytorch#178737 |
||
| } | ||
| } catch (const std::exception&) { | ||
| // If an exception occurs during device property check, we assume hipSparseLt is not supported | ||
|
|
@@ -49,7 +49,7 @@ static bool isHipSparseLtSupported() { | |
| TORCH_CHECK( | ||
| false, | ||
| "hipSparseLt not supported on this device, supported architectures: " | ||
| "gfx950, gfx942. " | ||
| "gfx1250, gfx950, gfx942. " | ||
| "required ROCM version: 6.4.0 or later."); | ||
| } | ||
| return g_hipSparseLtSupported; | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
These changes are unnecessary unless we know for certain any build workflows that would use it. TheRock build workflows don't.