[release/2.12] Support for gfx1250 #3322
Closed
rraminen wants to merge 5 commits into
…odule + triton pins)
* CK - gfx1250 support (pytorch#5)
* Enable ROCM_CK_SDPA build
* [submodule] composable_kernel and aiter update (pytorch#172592)

  Summary:
  - update ck to commit ROCm/composable_kernel@fcc9372
  - update aiter to commit ROCm/aiter@9a469a6
  - changes of caffe2/aten/src/ATen/CMakeLists.txt and caffe2/caffe2/CMakeLists.txt are adopted from pytorch#161759
  - updated caffe2/aten/src/ATen/native/transformers/hip/flash_attn/ck/launch_kernel_pt.hpp to match the ck version in https://github.com/ROCm/composable_kernel/blob/292df2719f28cd01464d5d059820684790c101da/include/ck_tile/host/kernel_launch.hpp
  - update aiter fav3 bwd codegen according to changes in ROCm/aiter#1573
  - update caffe2/aten/src/ATen/native/transformers/hip/flash_attn/ck mha fwd/bwd kernels according to the interfaces in https://github.com/ROCm/composable_kernel/tree/292df2719f28cd01464d5d059820684790c101da/example/ck_tile/01_fmha

  Differential Revision: D88991877
  Pull Request resolved: pytorch#172592
  Approved by: https://github.com/alugorey, https://github.com/izaitsevfb
* Added MI450 support and packages
* Fix misaligned ck api
* Replace aiter with ck for bwd
* [ROCm] Bump AOTriton to 0.11.2b (pytorch#174105)

  Notable new features:
  * AOTriton 0.11.2b adds gfx1151/1152/1153 support.
  * Add precompiled AOTriton runtime for ROCM 7.2
  * Match the sliding window attention behavior of `_flash_attention_forward/backward` with CUTLASS backend.

  Bug fixes:
  * Fixes pytorch#173204. Now all tests in `test/test_varlen_attention.py` are enabled on ROCm

  Notes: This replaces PR pytorch#173820 and pytorch#173469
  Pull Request resolved: pytorch#174105
  Approved by: https://github.com/jeffdaily
* Fix philox data types for this version of ck
* Update CK to use new gfx1250_pytorch branch
* Add new gfx1250 compile flags for CK
* add --targets to generate and a couple new compile flags
* Remove default USE_ROCM_CK_SDPA

---------

Co-authored-by: blorange-amd <bo.li2@amd.com>
Co-authored-by: Yu Guo <yuguo@meta.com>
Co-authored-by: Xinya Zhang <Xinya.Zhang@amd.com>

* Updated aiter module
* Fixed merge error
* Fixed additional merge error
* Reset USE_ROCM_CK_SDPA config

---------

Co-authored-by: LugoReyes, Andy <Andy.LugoReyes@amd.com>
Co-authored-by: Yu Guo <yuguo@meta.com>
Co-authored-by: Xinya Zhang <Xinya.Zhang@amd.com>
Fix `torch.arange` (and the other range factories sharing this kernel) for very large outputs on ROCm.

`torch.arange(N)` with `N >= 2^32` fails on ROCm because `hipLaunchKernel` does not support `gridDim.x * blockDim.x >= 2^32`, which the per-thread kernel previously used in `aten/src/ATen/native/cuda/RangeFactories.cu` requires at that size. Depending on the ROCm version, the launch either returns `hipErrorInvalidConfiguration` or is accepted silently with the kernel never executing, leaving zero-initialized output. Concrete repro: `torch.arange(2 ** 32 + 1, device="cuda", dtype=torch.int32)`.

The fix replaces the per-thread launch on the ROCm path with a grid-stride loop that fixes the grid at `sm_count * 4` blocks, so correctness no longer depends on the launch limit, regardless of `N`. The non-ROCm path is untouched.

On MI250X the grid-stride kernel matches the per-thread kernel within noise at `N=1024` and is 24-60% faster from `N=1M` up across `int32`, `int64`, and `float32`. On MI300X the grid-stride kernel matches within noise at `N=1024` and `N=1M`, and is 2-5x faster from `N=64M` up across `int32`, `int64`, and `float32`.

The 64-bit-indexing test is extended to also cover `N = 2^32 + 1` and `N = 2^33 + 1` on ROCm when memory permits.

Pull Request resolved: pytorch#182657
Approved by: https://github.com/jeffdaily

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
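To illustrate the grid-stride idea in the arange fix, here is a minimal pure-Python sketch, not the actual kernel in `RangeFactories.cu`. The function names, the block size of 256, and the tiny simulated grid are assumptions chosen for illustration. The first helper shows why a one-element-per-thread launch needs at least 2^32 threads for `N = 2^32 + 1`. The second simulates a fixed-size grid in which each thread strides over the output, so the thread count is independent of `N`.

```python
def per_thread_launch_threads(n: int, block_size: int = 256) -> int:
    """Total threads a one-element-per-thread launch would request.

    block_size=256 is an assumed value, not taken from the PR.
    """
    num_blocks = (n + block_size - 1) // block_size  # ceil(n / block_size)
    return num_blocks * block_size


def grid_stride_arange(n, start=0, step=1, num_blocks=4, block_size=8):
    """Simulate a grid-stride kernel filling out[i] = start + i * step.

    num_blocks stands in for the fixed grid (the PR uses sm_count * 4);
    the small defaults here only keep the simulation cheap.
    """
    out = [None] * n
    total_threads = num_blocks * block_size
    for block in range(num_blocks):
        for thread in range(block_size):
            i = block * block_size + thread  # this thread's first element
            while i < n:                     # stride until past the end
                out[i] = start + i * step
                i += total_threads
    return out


if __name__ == "__main__":
    # A per-thread launch for N = 2**32 + 1 exceeds the 2**32 thread limit.
    print(per_thread_launch_threads(2**32 + 1) >= 2**32)
    # 32 simulated threads still cover all 100 elements.
    print(grid_stride_arange(100) == list(range(100)))
```

Because the grid size is fixed up front, the number of threads launched stays constant as `N` grows, and only the number of loop iterations per thread changes.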
* TDM on release/2.11 for bring-up based on careful selection
* Triton commit: Upstream fe0c38b5262c0447fed6df0d37e02cb8ea75deb4 -> AMD-ROCm-Internal Triton 250bb5d5b821377f49dc2d83d87ded75b952f0f7; consequence: Triton TDM support may be missing.
* Refinement according to reviewers' comments
* Added/modified UT cases; NUM_STAGES ineffectiveness issue
* A couple of changes to related UTs
* Got rid of configs like `waves_per_cu=2`
Jenkins build for 3ba54943493ab56f7724b8cf79014121ac1af81f commit finished as FAILURE

#3327 is more comprehensive, closing this one