[release/2.12] Add support for gfx1250 #3327
…odule + triton pins)
* CK - gfx1250 support (#5) * Enable ROCM_CK_SDPA build * [submodule] composable_kernel and aiter update (pytorch#172592) Summary: update ck to commit ROCm/composable_kernel@fcc9372 update aiter to commit ROCm/aiter@9a469a6 changes of caffe2/aten/src/ATen/CMakeLists.txt and caffe2/caffe2/CMakeLists.txt are adopted from pytorch#161759 updated caffe2/aten/src/ATen/native/transformers/hip/flash_attn/ck/launch_kernel_pt.hpp to match the ck version in https://github.com/ROCm/composable_kernel/blob/292df2719f28cd01464d5d059820684790c101da/include/ck_tile/host/kernel_launch.hpp update aiter fav3 bwd codegen according to changes in ROCm/aiter#1573 update caffe2/aten/src/ATen/native/transformers/hip/flash_attn/ck mha fwd/bwd kernels according to the interfaces in https://github.com/ROCm/composable_kernel/tree/292df2719f28cd01464d5d059820684790c101da/example/ck_tile/01_fmha Differential Revision: D88991877 Pull Request resolved: pytorch#172592 Approved by: https://github.com/alugorey, https://github.com/izaitsevfb * Added MI450 supports and packages * Fix misalinged ck api * Replace aiter with ck for bwd * [ROCm] Bump AOTriton to 0.11.2b (pytorch#174105) Notable new features: * AOTriton 0.11.2b adds gfx1151/1152/1153 support. * Add precompiled AOTriton runtime for ROCM 7.2 * Match the sliding window attention behavior of `_flash_attention_forward/backward` with CUTLASS backend. Bug fixes: * Fixes pytorch#173204. 
Now all tests in `test/test_varlen_attention.py` are enabled on ROCm Notes: This replaces PR pytorch#173820 and pytorch#173469 Pull Request resolved: pytorch#174105 Approved by: https://github.com/jeffdaily * Fix philox data types for this version of ck * Update CK to use new gfx1250_pytorch branch * Add new gfx1250 compile flags for CK * add --targets to generate and a couple new compile flags * Remove default USE_ROCM_CK_SDPA --------- Co-authored-by: blorange-amd <bo.li2@amd.com> Co-authored-by: Yu Guo <yuguo@meta.com> Co-authored-by: Xinya Zhang <Xinya.Zhang@amd.com> * Updated aiter module * Fixed merged error * Fixed additional merged error * Reset USE_ROCM_CK_SDPA config --------- Co-authored-by: LugoReyes, Andy <Andy.LugoReyes@amd.com> Co-authored-by: Yu Guo <yuguo@meta.com> Co-authored-by: Xinya Zhang <Xinya.Zhang@amd.com>
Fix `torch.arange` (and the other range factories sharing this kernel) for very large outputs on ROCm. `torch.arange(N)` with `N >= 2^32` fails on ROCm because `hipLaunchKernel` does not support `gridDim.x * blockDim.x >= 2^32` for the per-thread kernel `aten/src/ATen/native/cuda/RangeFactories.cu` previously used. Depending on the ROCm version the launch returns `hipErrorInvalidConfiguration` or is accepted silently with the kernel never executing, leaving zero-initialized output. Concrete repro: `torch.arange(2 ** 32 + 1, device="cuda", dtype=torch.int32)`. The fix replaces the per-thread launch on the ROCm path with a grid-stride loop that fixes the grid at `sm_count * 4` blocks, so the launch limit is no longer load-bearing for correctness regardless of `N`. The non-ROCm path is untouched. On MI250X the grid-stride kernel matches the per-thread kernel within noise at `N=1024` and is 24-60% faster from `N=1M` up across `int32`, `int64`, and `float32`. On MI300X the grid-stride kernel matches within noise at `N=1024` and `N=1M`, and is 2-5x faster from `N=64M` up across `int32`, `int64`, and `float32`. The 64-bit-indexing test is extended to also cover `N = 2^32 + 1` and `N = 2^33 + 1` on ROCm when memory permits. Pull Request resolved: pytorch#182657 Approved by: https://github.com/jeffdaily Co-authored-by: Jeff Daily <jeff.daily@amd.com>
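The launch-size arithmetic and the grid-stride idea described above can be sketched in plain Python. This is only a host-side model, not the HIP kernel from `RangeFactories.cu`; the block size of 256 and the helper names are assumptions made for illustration.

```python
def per_thread_total_threads(n: int, block_size: int = 256) -> int:
    """Threads launched when every thread writes exactly one element.

    block_size=256 is an assumed value, not taken from the PR.
    """
    num_blocks = (n + block_size - 1) // block_size  # ceiling division
    return num_blocks * block_size


def simulate_grid_stride_arange(n: int, num_blocks: int, block_size: int) -> list:
    """Model a grid-stride loop: a fixed number of threads covers all n elements."""
    total_threads = num_blocks * block_size
    out = [-1] * n
    for tid in range(total_threads):
        # Each thread starts at its own id and advances by the whole grid size.
        for i in range(tid, n, total_threads):
            out[i] = i
    return out


# With one element per thread, N = 2^32 + 1 needs at least 2^32 threads.
print(per_thread_total_threads(2**32 + 1) >= 2**32)

# With a grid-stride loop, a tiny fixed grid still fills every element.
print(simulate_grid_stride_arange(1000, num_blocks=4, block_size=8) == list(range(1000)))
```

Because the grid size no longer depends on `N`, the total thread count stays far below the launch limit for any output size.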
* TDM on release/2.11 for bring-up based on careful selection
* Triton commit: Upstream fe0c38b5262c0447fed6df0d37e02cb8ea75deb4 -> AMD-ROCm-Internal Triton 250bb5d5b821377f49dc2d83d87ded75b952f0f7; Consequence: Triton TDM support may be missing.
* Refinement according to reviewers' comments
* Added/modified UT cases; NUM_STAGES issue of ineffectiveness
* A couple of changes to related UTs
* Got rid of configs like `waves_per_cu=2`
- Need to turn MSLK on for mi300 and mi350
- Need to turn CK off for gfx1250
# Bump to AOTriton 0.12.50tp

Notable new features: Enable gfx1250

## Features from AOTriton 0.12b

Notable new features:
* **BREAKING** Varlen LSE tensor shape changes to (H, Total_seqlen)
* Support head_dim != head_dim_v
* Support `use_deterministic_algorithms`
* Support seqused_k in test/test_varlen_attention.py
* gfx1100 and gfx1151 promoted out of experimental
* Partial FAv3 support on gfx950

Bug Fixes:
* GQA kernel failed to read bias tensor with the right offset.

Known Issues:
* gfx950's Triton kernel has problems handling hdim=16 in fwd, in addition to hdim=48/80 in bwd.
* Disables gfx90a's CK SDPA support due to GPU Segfault.

---------

Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Prachi Gupta <pracgupt@amd.com>
Jenkins build for aeb64a7497d08b2da400801c9340834bd6bde3f1 commit finished as FAILURE. Detected error during PyTorch building:
Co-authored-by: Xinya Zhang <Xinya.Zhang@amd.com>
Jenkins build for bbecc5657577e15f0c0aa057daf34b4e4be41c31 commit finished as NOT_BUILT
Jenkins build for 04b54055439beb8a156f244c3fd3cdb9e31a1d3b commit finished as NOT_BUILT
This PR addresses the review comments on PR #3307.
Jenkins build for 43a62ac6cc17a57487826c1b7b0f6e7cf96a43c1 commit finished as NOT_BUILT
Jenkins build for 43a62ac6cc17a57487826c1b7b0f6e7cf96a43c1 commit finished as FAILURE
| # generate a list of kernels, but not actually emit files at config stage | ||
| execute_process( | ||
| COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py | ||
| COMMAND python3 ${CMAKE_SOURCE_DIR}/third_party/composable_kernel/example/ck_tile/01_fmha/generate.py --targets gfx1250 |
Here, --targets gfx1250 is appended to every generate.py invocation, unconditionally.
This restricts CK FMHA blob generation to gfx1250 for all builds, including pure gfx942/gfx950 builds. A PYTORCH_ROCM_ARCH=gfx942 build will now emit only gfx1250 kernels and lose its own FMHA code objects.
Suggested fix: derive --targets from PYTORCH_ROCM_ARCH (filtered to CK-supported archs), or drop the flag entirely and keep the generator's default multi-target behavior.
CK is turned off for gfx1250, not a priority at the moment. We can probably just drop this.
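For reference, the first option in the suggested fix (deriving the target list from PYTORCH_ROCM_ARCH) could look roughly like the Python sketch below. The allowlist `CK_FMHA_ARCHS` and the helper name are hypothetical; the real set of architectures CK FMHA supports is not stated in this thread, and how the result would be passed to generate.py is left open.

```python
# Hypothetical allowlist, for illustration only. The actual set of archs that
# CK FMHA supports is not given in this review thread.
CK_FMHA_ARCHS = frozenset({"gfx942", "gfx950"})


def ck_fmha_targets(pytorch_rocm_arch: str) -> list:
    """Filter a semicolon-separated PYTORCH_ROCM_ARCH value to CK-capable archs."""
    requested = [a.strip() for a in pytorch_rocm_arch.split(";") if a.strip()]
    # Keep the caller's order so the generated kernel list stays deterministic.
    return [a for a in requested if a in CK_FMHA_ARCHS]


print(ck_fmha_targets("gfx942"))
print(ck_fmha_targets("gfx942;gfx950;gfx1250"))
print(ck_fmha_targets("gfx1250"))
```

With a filter like this, a pure gfx942 build keeps its own FMHA targets, and a gfx1250-only build ends up with an empty list instead of forcing gfx1250 onto everyone.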
| constexpr size_t kSmallSize = 1048576; | ||
| // allocations between 1 and 10 MiB may use kLargeBuffer | ||
| constexpr size_t kMinLargeAlloc = 10485760; | ||
| #if defined(USE_ROCM) && defined(__gfx1250__) |
This is host-side allocator code, but __gfx1250__ is a device-compilation predefine. Does that mean this #if block is never compiled in?
| auto persistent_counter = mk_atomictensor(is_causal ? atomic_counter.data_ptr<int32_t>() : nullptr); | ||
| hipError_t err; // TODO: Error handling | ||
| if constexpr (AOTRITON_ALWAYS_V3_API) { // Better readability than nesting ifdef | ||
| #if AOTRITON_V3_API // if constexpr does not stop errors from undefined functions |
Removing the code block below makes AOTriton v3 a hard switch and drops the v2 fallback without a build guard. Does this introduce a portability risk?
Hi @xinyazhang, could you please help me address this review w.r.t. the cherry-pick of aeb64a7?
Ultimately, 0.12.50tp is 0.12b with gfx1250 support (yes, it's ABI compatible). The related PR is also 0.12b's PR plus a version bump 0.12b->0.12.50tp.
| #if ROCM_VERSION >= 70000 | ||
| TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950"}), | ||
| "Block-wise scaling for Float8_e8m0fnu is only supported on gfx950"); | ||
| TORCH_CHECK_NOT_IMPLEMENTED(at::detail::getCUDAHooks().isGPUArch({"gfx950", "gfx1250"}), |
Above, _scaled_mm_allowed_device() (line ~82) gates gfx1250 at >= 70200. So on ROCm 7.0/7.1 the device is rejected by _scaled_mm_allowed_device, yet these inner checks would have admitted it.
How about nesting #if ROCM_VERSION >= 70200 inside each isGPUArch({...})?
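To make the mismatch concrete, here is a small Python model of the two gates. It assumes ROCM_VERSION is encoded as major * 10000 + minor * 100 + patch (so 70200 is ROCm 7.2.0); the function names are hypothetical and the outer gate is modeled only for gfx1250, as described above.

```python
def rocm_version(major: int, minor: int, patch: int = 0) -> int:
    """Assumed ROCM_VERSION encoding: major * 10000 + minor * 100 + patch."""
    return major * 10000 + minor * 100 + patch


GFX1250_MIN_ROCM = rocm_version(7, 2)  # 70200, the threshold named in the review


def outer_gate_gfx1250(version: int) -> bool:
    """Model of the outer allowlist for gfx1250 only, as described in the comment."""
    return version >= GFX1250_MIN_ROCM


def inner_check_as_written(arch: str, version: int) -> bool:
    """Inner check from the diff: one 70000 guard around both archs."""
    return version >= rocm_version(7, 0) and arch in {"gfx950", "gfx1250"}


def inner_check_nested(arch: str, version: int) -> bool:
    """Suggested variant: gfx1250 joins the list only from 70200 onward."""
    if version < rocm_version(7, 0):
        return False
    allowed = {"gfx950"}
    if version >= GFX1250_MIN_ROCM:
        allowed.add("gfx1250")
    return arch in allowed


v71 = rocm_version(7, 1)
print(outer_gate_gfx1250(v71), inner_check_as_written("gfx1250", v71), inner_check_nested("gfx1250", v71))
```

On ROCm 7.1 the outer gate and the nested variant agree on rejecting gfx1250, while the check as written would admit it.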
| try { | ||
| if (at::cuda::device_count() > 0) { | ||
| g_hipSparseLtSupported = at::detail::getCUDAHooks().isGPUArch({"gfx950", "gfx942"}, 0); | ||
| g_hipSparseLtSupported = at::detail::getCUDAHooks().isGPUArch({"gfx950", "gfx942", "gfx1250"}, 0); |
Can we confirm whether hipSparseLT requires ROCm 7.2+?
gfx1250 is advertised unconditionally here, which might fail deeper.
Yes, hipsparselt actually requires ROCm >=7.12. PR is in progress pytorch#178737
| CK_USE_GFX94 | ||
| #CK_USE_FNUZ_FP8 | ||
| #CK_USE_GFX94 | ||
| CK_USE_GFX1250 |
The changes here and below alter the CK SDPA compile definitions globally.
Are CK/AITER artifacts for GFX1250 actually ready and validated?
| # composable_kernel has no gfx1250 support, so its CK GEMM/SDPA kernels fail | ||
| # to compile for that arch. | ||
| if("gfx1250" IN_LIST PYTORCH_ROCM_ARCH) |
This disables both USE_ROCM_CK_GEMM and USE_ROCM_CK_SDPA whenever PYTORCH_ROCM_ARCH contains gfx1250.
Would this break mixed-arch builds such as gfx942;gfx950;gfx1250?
Yes, this is highly problematic for multi-arch builds. Could we follow the same approach we use for other sub-components of the PyTorch build, with a temporary HIP_CLANG_FLAGS override?
This was solved on release/2.11 in #3346 (merged). That PR moves the logic out of Dependencies.cmake and into aten/src/ATen/CMakeLists.txt:
pytorch/aten/src/ATen/CMakeLists.txt
Lines 203 to 216 in 712584b
| // ifdef USE_ROCM_CK_GEMM is required since ROCm systems w/o CK should not call ck path. | ||
| #if defined(USE_ROCM_CK_GEMM) | ||
| if (at::globalContext().rocmAllowGroupGemmCk() && at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950", "gfx90a"})) { | ||
| if (at::globalContext().rocmAllowGroupGemmCk() && at::detail::getCUDAHooks().isGPUArch({"gfx942", "gfx950", "gfx90a", "gfx1250"})) { |
Is it true that the existing CK grouped GEMM path is the Wave64/MFMA/XDL path used by gfx90a/gfx942/gfx950? If so, because gfx1250 is Wave32 and WMMA/SWMMAC-oriented, it probably should not be routed into this path by arch-name allowlisting alone.
| auto *from = reinterpret_cast<const vec_t *>(base_ptr); | ||
| #if defined(USE_ROCM) && defined(__gfx942__) | ||
| // Extend the non-temporal load optimization to GFX1250. | ||
| #if defined(USE_ROCM) && (defined(__gfx942__) || defined(__gfx1250__)) |
Would simply extending this to gfx1250 apply another Wave64-tuned path to Wave32 hardware?
| # CDNA4 (gfx950) 160KB, and CDNA5 (gfx1250) 320KB. | ||
| if device_props.gcnArchName == "gfx950": | ||
| max_shared_mem = 160 * 1024 | ||
| elif device_props.gcnArchName == "gfx1250": |
gcnArchName can include feature suffixes such as gfx1250:sramecc+:xnack-. Would this exact comparison lead to unexpected fallback?
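One way to make the lookup robust to such suffixes is to compare only the base architecture name. A minimal Python sketch follows; the helper names and the table are hypothetical, the two sizes come from the code comment in the diff, and the fallback value is deliberately left to the caller because it is not shown here.

```python
# Sizes taken from the code comment in the diff; the table itself is illustrative.
SHARED_MEM_BYTES = {
    "gfx950": 160 * 1024,   # CDNA4
    "gfx1250": 320 * 1024,  # CDNA5
}


def base_arch(gcn_arch_name: str) -> str:
    """Strip feature suffixes: 'gfx1250:sramecc+:xnack-' -> 'gfx1250'."""
    return gcn_arch_name.split(":", 1)[0]


def max_shared_mem(gcn_arch_name: str, default: int) -> int:
    """Look up shared memory by base arch; `default` is supplied by the caller."""
    return SHARED_MEM_BYTES.get(base_arch(gcn_arch_name), default)


print(base_arch("gfx1250:sramecc+:xnack-"))
print(max_shared_mem("gfx1250:sramecc+:xnack-", default=64 * 1024))
```

An exact `==` comparison on the raw string would miss the suffixed name and fall through to the default.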
These changes are unnecessary unless we know for certain that some build workflow would use them. TheRock build workflows don't.
| __device__ inline __hip_bfloat162 preview_unsafeAtomicAdd(__hip_bfloat162* address, __hip_bfloat162 value) { | ||
| #if (defined(__gfx942__)) && \ | ||
| // `__gfx1250__`-specific `s_wait_loadcnt(0)` path for committed store already there | ||
| #if (defined(__gfx942__) || defined(__gfx1250__)) && \ |
Does this change matter now, if the outer condition is #if ROCM_VERSION < 60400?
Add support for gfx1250
TheRock Validation: https://github.com/ROCm/TheRock/actions/runs/27717422954
Build is passing. Testing is in progress.