diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py index 0cdb8dd087..fb88e7a488 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py @@ -786,7 +786,8 @@ class _S2TCopyBase(CopyOp): :param cta_group: Cooperative Thread Array (CTA) group configuration :type cta_group: CtaGroup - :raises OpError: If the current architecture is not SM100f family or if invalid parameters are provided + :raises OpError: If the current architecture is not SM100f or SM110f + family or if invalid parameters are provided """ cta_group: CtaGroup @@ -794,8 +795,11 @@ class _S2TCopyBase(CopyOp): def __post_init__(self) -> None: # Arch verification arch = BaseDSL._get_dsl().get_arch_enum() - if not arch.is_family_of(Arch.sm_100f): - supported = Arch.filter(lambda a: a.is_family_of(Arch.sm_100f)) + # S2T tcgen05 copy encodings are valid on both SM100 and Thor SM110. + if not (arch.is_family_of(Arch.sm_100f) or arch.is_family_of(Arch.sm_110f)): + supported = Arch.filter( + lambda a: a.is_family_of(Arch.sm_100f) or a.is_family_of(Arch.sm_110f) + ) raise OpError( self, f"expects arch to be one of {supported}, but got {arch}", diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py index 2c9902424c..e76fc32d8e 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py @@ -387,6 +387,7 @@ class BlockScaledMmaOp(Tcgen05MmaOp): admissible_archs = [ Arch.sm_100a, Arch.sm_103a, + Arch.sm_110a, ] def __post_init__(self) -> None: diff --git a/test/examples/CuTeDSL/sm_100a/test_dense_blockscaled_gemm_persistent_prefetch.py b/test/examples/CuTeDSL/sm_100a/test_dense_blockscaled_gemm_persistent_prefetch.py index e6f5b8eac0..105769fc38 100644 --- a/test/examples/CuTeDSL/sm_100a/test_dense_blockscaled_gemm_persistent_prefetch.py +++ b/test/examples/CuTeDSL/sm_100a/test_dense_blockscaled_gemm_persistent_prefetch.py @@ -51,6 +51,54 @@ pytestmark = [pytest.mark.arch(["100a"])] +@pytest.mark.L0 +@pytest.mark.arch(["110a"]) +def test_dense_blockscaled_gemm_prefetch_sm110a_compile(): + mnkl = (128, 128, 64, 1) + ab_dtype = cutlass.Float4E2M1FN + sf_dtype = cutlass.Float8E8M0FNU + sf_vec_size = 16 + c_dtype = cutlass.Float16 + a_major = "k" + b_major = "k" + c_major = "n" + mma_tiler_mn = (128, 128) + cluster_shape_mn = (1, 1) + + if not Sm100BlockScaledPersistentDenseGemmKernel.can_implement( + ab_dtype, + sf_dtype, + sf_vec_size, + c_dtype, + mma_tiler_mn, + cluster_shape_mn, + mnkl[0], + mnkl[1], + mnkl[2], + mnkl[3], + a_major, + b_major, + c_major, + ): + pytest.skip("Configuration not supported on SM110a") + + run( + mnkl, + ab_dtype, + sf_dtype, + sf_vec_size, + c_dtype, + a_major, + b_major, + c_major, + mma_tiler_mn, + cluster_shape_mn, + warmup_iterations=0, + iterations=0, + skip_ref_check=True, + ) + + @pytest.mark.invalid_case( lambda: ( not Sm100BlockScaledPersistentDenseGemmKernel.can_implement(