diff --git a/examples/python/CuTeDSL/cute/blackwell/kernel/blockscaled_gemm/sm120_mxf4nvf4_native_tma_microtile.py b/examples/python/CuTeDSL/cute/blackwell/kernel/blockscaled_gemm/sm120_mxf4nvf4_native_tma_microtile.py new file mode 100644 index 0000000000..6f73e964cd --- /dev/null +++ b/examples/python/CuTeDSL/cute/blackwell/kernel/blockscaled_gemm/sm120_mxf4nvf4_native_tma_microtile.py @@ -0,0 +1,281 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +"""Minimal SM120 MXF4/NVFP4 native-TMA microtile smoke example. + +This module exposes a callable JIT entry point. The corresponding pytest smoke +test demonstrates invocation and checks the fixed instruction mix. +""" + +from typing import Optional + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass._mlir import ir +from cutlass.cute.nvgpu import cpasync, warp +from cutlass.cutlass_dsl import dsl_user_op +from cutlass.utils.gemm import sm120 +from cutlass.utils.gemm.sm120.constants import ( + MXF4NVF4_SCALE_TMA_BYTES, + mxf4nvf4_full_tma_bytes, +) +from cutlass.utils.smem_allocator import SmemAllocator + +_MXF4NVF4_SCALE_INTERLEAVED_K64_BYTES = MXF4NVF4_SCALE_TMA_BYTES // 2 +_SM120_MXF4NVF4_MICROTILE_SMEM_BYTES = 49152 + + +@dsl_user_op +def _issue_tma_load( + tma_atom: cute.CopyAtom, + tma_tensor: cute.Tensor, + smem_tensor: cute.Tensor, + tma_bar_ptr: cute.Pointer, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + tS, tG = cpasync.tma_partition( + tma_atom, + 0, + cute.make_layout(1, loc=loc, ip=ip), + cute.group_modes(smem_tensor, 0, cute.rank(smem_tensor) - 1, loc=loc, ip=ip), + cute.group_modes(tma_tensor, 0, cute.rank(tma_tensor), loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + cute.copy( + tma_atom, + tG, + tS[(None, 0)], + tma_bar_ptr=tma_bar_ptr, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def _load_ab_k_block( + tiled_mma: cute.TiledMma, + sA: cute.Tensor, + sB: cute.Tensor, + a_frag: cute.Tensor, + b_frag: cute.Tensor, + tidx: cutlass.Int32, + k_block: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + copy_atom = cute.make_copy_atom( + warp.LdMatrix8x8x16bOp(transpose=False, num_matrices=4), + cutlass.Float4E2M1FN, + loc=loc, + ip=ip, + ) + tiled_copy_a = cute.make_tiled_copy_A(copy_atom, tiled_mma, loc=loc, ip=ip) + tiled_copy_b = cute.make_tiled_copy_B(copy_atom, tiled_mma, loc=loc, ip=ip) + thr_copy_a = tiled_copy_a.get_slice(tidx) + thr_copy_b = tiled_copy_b.get_slice(tidx) + sA_src = cute.as_position_independent_swizzle_tensor(sA, loc=loc, ip=ip) + sB_src = cute.as_position_independent_swizzle_tensor(sB, loc=loc, ip=ip) + tCsA = thr_copy_a.partition_S(sA_src, loc=loc, ip=ip) + tCsB = thr_copy_b.partition_S(sB_src, loc=loc, ip=ip) + tCrA = thr_copy_a.retile(a_frag, loc=loc, ip=ip) + tCrB = thr_copy_b.retile(b_frag, loc=loc, ip=ip) + tCsA_stage = tCsA[(None, None, None, 0)] + tCsB_stage = tCsB[(None, None, None, 0)] + cute.copy( + tiled_copy_a, + tCsA_stage[(None, None, k_block)], + tCrA[(None, None, k_block)], + loc=loc, + ip=ip, + ) + cute.copy( + tiled_copy_b, + tCsB_stage[(None, None, k_block)], + tCrB[(None, None, k_block)], + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def _load_uniform_scale_fragment_from_first_scale_column( + scale_smem: cute.Tensor, + is_sfa: cutlass.Constexpr[bool], + k_block: int, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.Tensor: + """Load a scale fragment for this uniform-scale microtile only. + + This intentionally reads the first compact FP8 scale column for each K64 + half. It is not a general SM120 scale-fragment partitioner. + """ + if is_sfa: + scale_frag = warp.make_mxf4nvf4_sfa_fragment(loc=loc, ip=ip) + else: + scale_frag = warp.make_mxf4nvf4_sfb_fragment(loc=loc, ip=ip) + scale_src = cute.recast_tensor( + cute.make_tensor( + scale_smem.iterator + k_block * _MXF4NVF4_SCALE_INTERLEAVED_K64_BYTES, + cute.make_layout(4, loc=loc, ip=ip), + ), + cutlass.Float8E4M3FN, + loc=loc, + ip=ip, + ) + cute.filter_zeros(scale_frag, loc=loc, ip=ip).store( + scale_src.load(loc=loc, ip=ip), loc=loc, ip=ip + ) + return scale_frag + + +@dsl_user_op +def _store_accumulator( + thr_mma: cute.ThrMma, + acc: cute.Tensor, + out: cute.Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + tDgD = thr_mma.partition_C(out, loc=loc, ip=ip) + rD = cute.make_rmem_tensor(acc.layout, out.element_type, loc=loc, ip=ip) + rD.store(acc.load(loc=loc, ip=ip).to(out.element_type), loc=loc, ip=ip) + copy_atom = cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), out.element_type, loc=loc, ip=ip + ) + cute.copy(copy_atom, rD, tDgD, loc=loc, ip=ip) + + +@cute.jit +def sm120_mxf4nvf4_native_tma_microtile( + a: cute.Tensor, + b: cute.Tensor, + sfa: cute.Tensor, + sfb: cute.Tensor, + d: cute.Tensor, + stream: cuda.CUstream = cuda.CUstream(0), +): + """Compute one 16x8 output microtile from a native-TMA SM120 K128 tile.""" + gSFA = cute.make_tensor( + sfa.iterator, + sm120.make_mxf4nvf4_scale_interleaved_gmem_layout(128, 128, 1), + ) + gSFB = cute.make_tensor( + sfb.iterator, + sm120.make_mxf4nvf4_scale_interleaved_gmem_layout(128, 128, 1), + ) + ( + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_sfa, + tma_tensor_sfa, + tma_atom_sfb, + tma_tensor_sfb, + ) = sm120.make_mxf4nvf4_native_tma_atoms( + a, + b, + gSFA, + gSFB, + ab_smem_format="packed", + ab_tile_coord=(0, 0, 0), + scale_tile_coord=(0, 0, 0, 0), + ) + + _sm120_mxf4nvf4_native_tma_microtile_kernel( + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_sfa, + tma_tensor_sfa, + tma_atom_sfb, + tma_tensor_sfb, + d, + ).launch( + grid=[1, 1, 1], + block=[32, 1, 1], + stream=stream, + smem=_SM120_MXF4NVF4_MICROTILE_SMEM_BYTES, + ) + + +@cute.kernel +def _sm120_mxf4nvf4_native_tma_microtile_kernel( + tma_atom_a: cute.CopyAtom, + tma_tensor_a: cute.Tensor, + tma_atom_b: cute.CopyAtom, + tma_tensor_b: cute.Tensor, + tma_atom_sfa: cute.CopyAtom, + tma_tensor_sfa: cute.Tensor, + tma_atom_sfb: cute.CopyAtom, + tma_tensor_sfb: cute.Tensor, + d: cute.Tensor, +): + tidx, _, _ = cute.arch.thread_idx() + + @cute.struct + class SharedStorage: + tma_barrier: cute.struct.MemRange[cutlass.Int64, 1] + + smem = SmemAllocator() + storage = smem.allocate(SharedStorage) + sA, sB, sSFA, sSFB = sm120.make_mxf4nvf4_native_tma_smem_views( + smem, + ab_smem_format="packed", + ) + tma_bar_ptr = storage.tma_barrier.data_ptr() + + with cute.arch.elect_one(): + cute.arch.mbarrier_init(tma_bar_ptr, 1) + cute.arch.mbarrier_expect_tx( + tma_bar_ptr, mxf4nvf4_full_tma_bytes("packed") + ) + cute.arch.mbarrier_init_fence() + cute.arch.barrier() + + _issue_tma_load(tma_atom_a, tma_tensor_a, sA, tma_bar_ptr) + _issue_tma_load(tma_atom_b, tma_tensor_b, sB, tma_bar_ptr) + _issue_tma_load(tma_atom_sfa, tma_tensor_sfa, sSFA, tma_bar_ptr) + _issue_tma_load(tma_atom_sfb, tma_tensor_sfb, sSFB, tma_bar_ptr) + with cute.arch.elect_one(): + cute.arch.mbarrier_arrive(tma_bar_ptr) + + cute.arch.mbarrier_wait(tma_bar_ptr, 0) + cute.arch.barrier() + + tiled_mma = sm120.make_mxf4nvf4_tiled_mma(atom_layout_mnk=(1, 1, 1)) + thr_mma = tiled_mma.get_slice(tidx) + tCsA_mma = thr_mma.partition_A(sA) + tCsB_mma = thr_mma.partition_B(sB) + a_frag = tiled_mma.make_fragment_A(tCsA_mma[None, None, None, 0]) + b_frag = tiled_mma.make_fragment_B(tCsB_mma[None, None, None, 0]) + acc = cute.make_rmem_tensor(tiled_mma.partition_shape_C((16, 8)), cutlass.Float32) + acc.fill(0.0) + + for k_block in range(2): + _load_ab_k_block(tiled_mma, sA, sB, a_frag, b_frag, tidx, k_block) + sfa_frag = _load_uniform_scale_fragment_from_first_scale_column( + sSFA, True, k_block + ) + sfb_frag = _load_uniform_scale_fragment_from_first_scale_column( + sSFB, False, k_block + ) + cute.gemm( + tiled_mma, + acc, + (a_frag[(None, 0, k_block)], sfa_frag), + (b_frag[(None, 0, k_block)], sfb_frag), + acc, + ) + + _store_accumulator(thr_mma, acc, d) diff --git a/python/CuTeDSL/cutlass/cute/__init__.py b/python/CuTeDSL/cutlass/cute/__init__.py index aaa7787e63..2527d1e8b2 100644 --- a/python/CuTeDSL/cutlass/cute/__init__.py +++ b/python/CuTeDSL/cutlass/cute/__init__.py @@ -149,6 +149,7 @@ TensorSSA, ReductionOp, make_tensor, + as_position_independent_swizzle_tensor, make_identity_tensor, make_fragment, make_fragment_like, @@ -286,6 +287,7 @@ # Tensor functions "make_ptr", "make_tensor", + "as_position_independent_swizzle_tensor", "make_identity_tensor", "make_fragment", "make_fragment_like", diff --git a/python/CuTeDSL/cutlass/cute/algorithm.py b/python/CuTeDSL/cutlass/cute/algorithm.py index a6a99ea0c8..b8f4931b3b 100644 --- a/python/CuTeDSL/cutlass/cute/algorithm.py +++ b/python/CuTeDSL/cutlass/cute/algorithm.py @@ -115,6 +115,15 @@ def gemm( a_list = _normalize_variadic_tensor_operand(a, "a") b_list = _normalize_variadic_tensor_operand(b, "b") + if len(a_list) == 2 and len(b_list) == 2: + from .nvgpu.warp.mma import MmaSM120BlockScaledOp, mma_mxf4nvf4 + + if ( + isinstance(atom.op, MmaSM120BlockScaledOp) + and atom.op.is_mxf4nvf4_sm120() + ): + return mma_mxf4nvf4(atom, d, a_list, b_list, c, loc=loc, ip=ip) + # Rank validations based on the primary A/B tensors (guaranteed non-empty) a_rank = rank(a_list[0].shape) b_rank = rank(b_list[0].shape) diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/warp/__init__.py b/python/CuTeDSL/cutlass/cute/nvgpu/warp/__init__.py index 2c0c32be67..67549cd40a 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/warp/__init__.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/warp/__init__.py @@ -9,6 +9,8 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. +# ruff: noqa: F403, F405 + from .copy import * from .mma import * @@ -24,6 +26,11 @@ "MmaMXF8Op", "MmaMXF8F6F4Op", "MXF8F6F4_SUPPORTED_PAIRS", + "make_mxf4nvf4_sfa_fragment", + "make_mxf4nvf4_sfa_layout", + "make_mxf4nvf4_sfb_fragment", + "make_mxf4nvf4_sfb_layout", + "mma_mxf4nvf4", # copy.py "LdMatrix8x8x16bOp", "LdMatrix16x8x8bOp", diff --git a/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py b/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py index dbdc296a4b..4010714be7 100644 --- a/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py +++ b/python/CuTeDSL/cutlass/cute/nvgpu/warp/mma.py @@ -10,11 +10,11 @@ # is strictly prohibited. from dataclasses import dataclass -from typing import Any, Optional, Type +from typing import Any, List, Optional, Tuple, Type, Union import enum from cutlass.base_dsl.arch import Arch -from cutlass.cutlass_dsl import BaseDSL +from cutlass.cutlass_dsl import BaseDSL, dsl_user_op from ..common import OpError @@ -27,15 +27,22 @@ Float16, BFloat16, Float32, + Int8, + Int32, Numeric, Pointer, + Uint8, + AddressSpace, + Layout, ) -from ...core import _pack_shape +from ...core import _pack_shape, cosize, filter_zeros, make_layout, rank, size +from ...tensor import TensorSSA, make_rmem_tensor, recast_tensor from ...typing import Tensor -from ...atom import MmaOp, Trait, make_atom +from ...atom import MmaAtom, MmaOp, Trait, make_atom from cutlass._mlir import ir import cutlass._mlir.dialects.cute_nvgpu as _cute_nvgpu_ir +import cutlass._mlir.dialects.vector as vector #################################################################################################### @@ -318,6 +325,15 @@ def _verify_fragment_B( ) -> None: pass + def is_mxf4nvf4_sm120(self) -> bool: + return ( + self.shape_mnk == (16, 8, 64) + and self.ab_dtype == Float4E2M1FN + and self.acc_dtype == Float32 + and self.sf_type == Float8E4M3FN + and self.sf_vec_size == 16 + ) + class Field(enum.Enum): """ @@ -493,6 +509,468 @@ class MmaMXF4NVF4Trait(MmaBlockScaledTrait): pass +def _normalize_mxf4nvf4_operand( + operand: Union[List[Tensor], Tuple[Tensor, ...]], + name: str, +) -> Tuple[Tensor, Tensor]: + if not isinstance(operand, (list, tuple)) or len(operand) != 2: + raise TypeError(f"`{name}` must be a two-tensor sequence `(fragment, scale)`") + fragment, scale = operand + if not isinstance(fragment, Tensor) or not isinstance(scale, Tensor): + raise TypeError(f"`{name}` must contain only Tensor operands") + return fragment, scale + + +def _require_rmem(tensor: Tensor, name: str) -> None: + if tensor.memspace != AddressSpace.rmem: + raise ValueError(f"`{name}` must be register-resident") + + +def _require_static_size(actual: Any, expected: int, name: str) -> None: + if actual != expected: + raise ValueError(f"`{name}` must have static size {expected}, but got {actual}") + + +def _validate_mxf4nvf4_packed_fragment_layout( + fragment: Tensor, + name: str, + *, + expected_logical_size: int, + expected_i32_size: int, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Tensor: + _require_rmem(fragment, name) + if fragment.element_type == Float4E2M1FN: + _require_static_size( + size(fragment, loc=loc, ip=ip), expected_logical_size, name + ) + compact = filter_zeros(fragment, loc=loc, ip=ip) + elif fragment.element_type in (Int8, Uint8): + _require_static_size( + size(fragment, loc=loc, ip=ip), + expected_i32_size * 4, + name, + ) + compact = filter_zeros(fragment, loc=loc, ip=ip) + else: + raise TypeError( + f"`{name}` must have element type {Float4E2M1FN}, {Int8}, or {Uint8}" + ) + + i32_fragment = recast_tensor(compact, Int32, loc=loc, ip=ip) + _require_static_size( + size(i32_fragment, loc=loc, ip=ip), + expected_i32_size, + f"packed `{name}`", + ) + return i32_fragment + + +def _mxf4nvf4_fragment_atom_size(fragment: Tensor, fp4_size: int) -> int: + return fp4_size // 2 if fragment.element_type in (Int8, Uint8) else fp4_size + + +def _validate_mxf4nvf4_accumulator_layout( + accumulator: Tensor, + name: str, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + _require_rmem(accumulator, name) + if accumulator.element_type != Float32: + raise TypeError(f"`{name}` must have element type {Float32}") + _require_static_size(size(accumulator, loc=loc, ip=ip), 4, name) + + +def _validate_mxf4nvf4_scale_fragment_layout( + scale: Tensor, + name: str, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Tensor: + _require_rmem(scale, name) + if scale.element_type != Float8E4M3FN: + raise TypeError(f"`{name}` must have element type {Float8E4M3FN}") + _require_static_size(size(scale, loc=loc, ip=ip), 64, name) + + compact = filter_zeros(scale, loc=loc, ip=ip) + _require_static_size( + cosize(compact.layout, loc=loc, ip=ip), + 4, + f"compact `{name}`", + ) + return compact + + +def _validate_mxf4nvf4_atom(atom: MmaAtom) -> None: + op = getattr(atom, "op", None) + if not isinstance(op, MmaSM120BlockScaledOp): + raise TypeError("`mma_mxf4nvf4` expects an SM120 warp blockscaled MMA atom") + if not op.is_mxf4nvf4_sm120(): + raise ValueError( + "SM120 NVFP4 MMA requires Float4E2M1FN A/B, Float32 accumulators, " + "Float8E4M3FN scales, shape (16, 8, 64), and scale_vec::4X" + ) + + +def _is_mxf4nvf4_full_k_tiled_fragment( + d: Tensor, + a_fragment: Tensor, + b_fragment: Tensor, + c: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> bool: + if ( + rank(a_fragment) < 3 + or rank(b_fragment) < 3 + or rank(d) < 3 + or rank(c) < 3 + ): + return False + a_atom_size = _mxf4nvf4_fragment_atom_size(a_fragment, 32) + b_atom_size = _mxf4nvf4_fragment_atom_size(b_fragment, 16) + if size(a_fragment, mode=[0], loc=loc, ip=ip) != a_atom_size: + return False + if size(b_fragment, mode=[0], loc=loc, ip=ip) != b_atom_size: + return False + if size(d, mode=[0], loc=loc, ip=ip) != 4: + return False + if size(c, mode=[0], loc=loc, ip=ip) != 4: + return False + m_tiles = size(a_fragment, mode=[1], loc=loc, ip=ip) + n_tiles = size(b_fragment, mode=[1], loc=loc, ip=ip) + k_tiles = size(a_fragment, mode=[2], loc=loc, ip=ip) + if size(b_fragment, mode=[2], loc=loc, ip=ip) != k_tiles: + return False + return ( + size(d, mode=[1], loc=loc, ip=ip) == m_tiles + and size(d, mode=[2], loc=loc, ip=ip) == n_tiles + and size(c, mode=[1], loc=loc, ip=ip) == m_tiles + and size(c, mode=[2], loc=loc, ip=ip) == n_tiles + ) + + +def _is_mxf4nvf4_tiled_fragment( + d: Tensor, + a_fragment: Tensor, + b_fragment: Tensor, + c: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> bool: + a_size = size(a_fragment, loc=loc, ip=ip) + b_size = size(b_fragment, loc=loc, ip=ip) + d_size = size(d, loc=loc, ip=ip) + c_size = size(c, loc=loc, ip=ip) + a_atom_size = _mxf4nvf4_fragment_atom_size(a_fragment, 32) + b_atom_size = _mxf4nvf4_fragment_atom_size(b_fragment, 16) + if ( + a_size == a_atom_size + and b_size == b_atom_size + and d_size == 4 + and c_size == 4 + ): + return False + if _is_mxf4nvf4_full_k_tiled_fragment( + d, a_fragment, b_fragment, c, loc=loc, ip=ip + ): + return True + if a_size % a_atom_size != 0 or b_size % b_atom_size != 0: + return False + a_tiles = a_size // a_atom_size + b_tiles = b_size // b_atom_size + return d_size == 4 * a_tiles * b_tiles and c_size == d_size + + +def _select_mxf4nvf4_scale_fragment( + scale: Tensor, + tile_idx: int, + tile_count: int, + name: str, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Tensor: + scale_size = size(scale, loc=loc, ip=ip) + if scale_size == 64: + return scale + if scale_size == 64 * tile_count and rank(scale) >= 2: + return scale[(None, tile_idx)] + raise ValueError( + f"`{name}` must be a canonical scale fragment or one scale fragment per tile" + ) + + +def _select_mxf4nvf4_full_k_scale_fragment( + scale: Tensor, + major_idx: int, + k_idx: int, + major_tiles: int, + k_tiles: int, + name: str, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Tensor: + scale_size = size(scale, loc=loc, ip=ip) + if scale_size == 64: + return scale + scale_rank = rank(scale) + mode_0_size = size(scale, mode=[0], loc=loc, ip=ip) if scale_rank >= 1 else 0 + mode_1_size = size(scale, mode=[1], loc=loc, ip=ip) if scale_rank >= 2 else 0 + mode_2_size = size(scale, mode=[2], loc=loc, ip=ip) if scale_rank >= 3 else 0 + if ( + scale_rank >= 3 + and mode_0_size == 64 + and mode_1_size == major_tiles + and mode_2_size == k_tiles + ): + return scale[(None, major_idx, k_idx)] + raise ValueError( + f"`{name}` must be a canonical scale fragment or one scale fragment per " + "major/K tile; got " + f"rank={scale_rank}, size={scale_size}, " + f"modes=({mode_0_size}, {mode_1_size}, {mode_2_size}), " + f"expected=(64, {major_tiles}, {k_tiles})" + ) + + +def _mxf4nvf4_tiled_gemm_indices(m_tiles: int, n_tiles: int): + for m_idx in range(m_tiles): + for n_idx in range(n_tiles): + yield m_idx, n_idx + + +def _mma_mxf4nvf4_full_k_tiled( + atom: MmaAtom, + d: Tensor, + a_fragment: Tensor, + sfa: Tensor, + b_fragment: Tensor, + sfb: Tensor, + c: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + m_tiles = size(a_fragment, mode=[1], loc=loc, ip=ip) + n_tiles = size(b_fragment, mode=[1], loc=loc, ip=ip) + k_tiles = size(a_fragment, mode=[2], loc=loc, ip=ip) + for k_idx in range(k_tiles): + for m_idx, n_idx in _mxf4nvf4_tiled_gemm_indices(m_tiles, n_tiles): + sfa_tile = _select_mxf4nvf4_full_k_scale_fragment( + sfa, m_idx, k_idx, m_tiles, k_tiles, "sfa", loc=loc, ip=ip + ) + sfb_tile = _select_mxf4nvf4_full_k_scale_fragment( + sfb, n_idx, k_idx, n_tiles, k_tiles, "sfb", loc=loc, ip=ip + ) + c_tile = ( + c[(None, m_idx, n_idx)] + if k_idx == 0 + else d[(None, m_idx, n_idx)] + ) + mma_mxf4nvf4( + atom, + d[(None, m_idx, n_idx)], + (a_fragment[(None, m_idx, k_idx)], sfa_tile), + (b_fragment[(None, n_idx, k_idx)], sfb_tile), + c_tile, + loc=loc, + ip=ip, + ) + + +def _mma_mxf4nvf4_tiled( + atom: MmaAtom, + d: Tensor, + a_fragment: Tensor, + sfa: Tensor, + b_fragment: Tensor, + sfb: Tensor, + c: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + if _is_mxf4nvf4_full_k_tiled_fragment( + d, a_fragment, b_fragment, c, loc=loc, ip=ip + ): + return _mma_mxf4nvf4_full_k_tiled( + atom, d, a_fragment, sfa, b_fragment, sfb, c, loc=loc, ip=ip + ) + a_atom_size = _mxf4nvf4_fragment_atom_size(a_fragment, 32) + b_atom_size = _mxf4nvf4_fragment_atom_size(b_fragment, 16) + a_tiles = size(a_fragment, loc=loc, ip=ip) // a_atom_size + b_tiles = size(b_fragment, loc=loc, ip=ip) // b_atom_size + if rank(a_fragment) < 2 or rank(b_fragment) < 2: + raise ValueError("tiled SM120 MXF4NVF4 fragments must expose tile modes") + if rank(d) < 3 or rank(c) < 3: + raise ValueError("tiled SM120 MXF4NVF4 accumulators must expose M/N tile modes") + for m_idx, n_idx in _mxf4nvf4_tiled_gemm_indices(a_tiles, b_tiles): + sfa_tile = _select_mxf4nvf4_scale_fragment( + sfa, m_idx, a_tiles, "sfa", loc=loc, ip=ip + ) + sfb_tile = _select_mxf4nvf4_scale_fragment( + sfb, n_idx, b_tiles, "sfb", loc=loc, ip=ip + ) + mma_mxf4nvf4( + atom, + d[(None, m_idx, n_idx)], + (a_fragment[(None, m_idx)], sfa_tile), + (b_fragment[(None, n_idx)], sfb_tile), + c[(None, m_idx, n_idx)], + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_mxf4nvf4_sfa_layout( + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Layout: + """Return the SM120 MXF4NVF4 SFA register scale-fragment layout.""" + return make_layout(((16, 4),), stride=((0, 1),), loc=loc, ip=ip) + + +@dsl_user_op +def make_mxf4nvf4_sfb_layout( + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Layout: + """Return the SM120 MXF4NVF4 SFB register scale-fragment layout.""" + return make_layout(((16, 4),), stride=((0, 1),), loc=loc, ip=ip) + + +@dsl_user_op +def make_mxf4nvf4_sfa_fragment( + dtype: Type[Numeric] = Float8E4M3FN, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Tensor: + """Return an SM120 MXF4NVF4 SFA register scale fragment.""" + if dtype != Float8E4M3FN: + raise TypeError("SM120 MXF4NVF4 SFA fragments require Float8E4M3FN") + return make_rmem_tensor( + make_mxf4nvf4_sfa_layout(loc=loc, ip=ip), dtype, loc=loc, ip=ip + ) + + +@dsl_user_op +def make_mxf4nvf4_sfb_fragment( + dtype: Type[Numeric] = Float8E4M3FN, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Tensor: + """Return an SM120 MXF4NVF4 SFB register scale fragment.""" + if dtype != Float8E4M3FN: + raise TypeError("SM120 MXF4NVF4 SFB fragments require Float8E4M3FN") + return make_rmem_tensor( + make_mxf4nvf4_sfb_layout(loc=loc, ip=ip), dtype, loc=loc, ip=ip + ) + + +@dsl_user_op +def mma_mxf4nvf4( + atom: MmaAtom, + d: Tensor, + a: Union[List[Tensor], Tuple[Tensor, ...]], + b: Union[List[Tensor], Tuple[Tensor, ...]], + c: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> None: + """Issue SM120 MXF4NVF4 warp MMA with bundled FP4 and E4M3 scale fragments. + + This helper consumes already-partitioned register fragments. SMEM scale + partitioning remains in the Blackwell block-scaled helper path + (``partition_fragment_SFA`` / ``partition_fragment_SFB``). + """ + _validate_mxf4nvf4_atom(atom) + a_fragment, sfa = _normalize_mxf4nvf4_operand(a, "a") + b_fragment, sfb = _normalize_mxf4nvf4_operand(b, "b") + + if _is_mxf4nvf4_tiled_fragment( + d, a_fragment, b_fragment, c, loc=loc, ip=ip + ): + return _mma_mxf4nvf4_tiled( + atom, d, a_fragment, sfa, b_fragment, sfb, c, loc=loc, ip=ip + ) + + a_i32 = _validate_mxf4nvf4_packed_fragment_layout( + a_fragment, + "a fragment", + expected_logical_size=32, + expected_i32_size=4, + loc=loc, + ip=ip, + ) + b_i32 = _validate_mxf4nvf4_packed_fragment_layout( + b_fragment, + "b fragment", + expected_logical_size=16, + expected_i32_size=2, + loc=loc, + ip=ip, + ) + _validate_mxf4nvf4_accumulator_layout(d, "d", loc=loc, ip=ip) + _validate_mxf4nvf4_accumulator_layout(c, "c", loc=loc, ip=ip) + + compact_sfa = _validate_mxf4nvf4_scale_fragment_layout( + sfa, "sfa", loc=loc, ip=ip + ) + compact_sfb = _validate_mxf4nvf4_scale_fragment_layout( + sfb, "sfb", loc=loc, ip=ip + ) + + sfa_i32 = recast_tensor(compact_sfa, Int32, loc=loc, ip=ip) + sfb_i32 = recast_tensor(compact_sfb, Int32, loc=loc, ip=ip) + + a_vec = a_i32.load(loc=loc, ip=ip) + b_vec = b_i32.load(loc=loc, ip=ip) + c_vec = c.load(loc=loc, ip=ip) + a_regs = [a_vec[i].ir_value(loc=loc, ip=ip) for i in range(4)] + b_regs = [b_vec[i].ir_value(loc=loc, ip=ip) for i in range(2)] + c_regs = [c_vec[i].ir_value(loc=loc, ip=ip) for i in range(4)] + shape_mnk = _pack_shape((16, 8, 64), loc=loc, ip=ip) + result = _cute_nvgpu_ir.arch_mma_SM120_block_scaled( + [Float32.mlir_type] * 4, + shape_mnk.type.attribute, + 16, + ir.TypeAttr.get(Float4E2M1FN.mlir_type), + ir.TypeAttr.get(Float4E2M1FN.mlir_type), + ir.TypeAttr.get(Float8E4M3FN.mlir_type), + a_regs, + b_regs, + c_regs, + Int32(sfa_i32[0]).ir_value(loc=loc, ip=ip), + Int32(sfb_i32[0]).ir_value(loc=loc, ip=ip), + thread_id_a=0, + thread_id_b=0, + loc=loc, + ip=ip, + ) + result_vec = vector.from_elements( + ir.VectorType.get([4], Float32.mlir_type), result, loc=loc, ip=ip + ) + d.store( + TensorSSA(result_vec, d.shape, Float32, loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + # # MXF8 MMA # diff --git a/python/CuTeDSL/cutlass/cute/tensor.py b/python/CuTeDSL/cutlass/cute/tensor.py index 0627fc8700..b36b5fa51a 100644 --- a/python/CuTeDSL/cutlass/cute/tensor.py +++ b/python/CuTeDSL/cutlass/cute/tensor.py @@ -38,7 +38,6 @@ Int8, Int32, BFloat16, - Float32, IntTuple, Coord, Shape, @@ -68,6 +67,8 @@ append, depth, flatten, + get_nonswizzle_portion, + get_swizzle_portion, has_underscore, make_layout, select, @@ -95,6 +96,7 @@ "make_fragment_like", "make_rmem_tensor_like", "make_rmem_tensor", + "as_position_independent_swizzle_tensor", "recast_tensor", "domain_offset", "print_tensor", @@ -907,6 +909,31 @@ def make_fragment( return make_rmem_tensor(layout_or_shape, dtype, loc=loc, ip=ip) +@dsl_user_op +def as_position_independent_swizzle_tensor( + src: Tensor, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> Tensor: + """Return a shared-memory tensor with layout swizzle moved onto the pointer.""" + if not isinstance(src, Tensor): + raise TypeError(f"expects a Tensor, but got {type(src)}") + if src.memspace != AddressSpace.smem: + raise TypeError("expects a shared-memory tensor") + + swizzle = get_swizzle_portion(src.layout, loc=loc, ip=ip) + layout = get_nonswizzle_portion(src.layout, loc=loc, ip=ip) + ptr = recast_ptr( + src.iterator, + swizzle_=swizzle, + dtype=src.element_type, + loc=loc, + ip=ip, + ) + return make_tensor(ptr, layout, loc=loc, ip=ip) + + @dsl_user_op def make_rmem_tensor_like( src: Union[Layout, ComposedLayout, Tensor, "TensorSSA"], diff --git a/python/CuTeDSL/cutlass/pipeline/sm90.py b/python/CuTeDSL/cutlass/pipeline/sm90.py index 1818660882..83baa47b93 100644 --- a/python/CuTeDSL/cutlass/pipeline/sm90.py +++ b/python/CuTeDSL/cutlass/pipeline/sm90.py @@ -601,6 +601,39 @@ def producer_acquire( ) self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip) # type: ignore[call-arg] + @dsl_user_op + def producer_acquire_already_elected( + self, + state: PipelineState, + try_acquire_token: Optional[Boolean] = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, + ) -> None: + """ + Acquire a TMA load stage from inside an existing ``elect_one`` block. + + This is equivalent to ``producer_acquire`` for TMA load pipelines, except + the transaction-barrier arrive does not perform its own thread election. + Use it only when the caller has already elected a single producer thread. + Calling this outside such a region is incorrect because every calling + thread would arrive on the transaction barrier. + """ + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait( # type: ignore[call-arg] + state.index, state.phase, loc=loc, ip=ip + ), + loc=loc, + ip=ip, + ) + cute.arch.mbarrier_arrive_and_expect_tx( + self.producer_get_barrier(state, loc=loc, ip=ip), + self.sync_object_full.tx_count, # type: ignore[attr-defined] + loc=loc, + ip=ip, + ) + @dsl_user_op def producer_commit( self, diff --git a/python/CuTeDSL/cutlass/utils/gemm/__init__.py b/python/CuTeDSL/cutlass/utils/gemm/__init__.py index 2357e392fa..18a2a36817 100644 --- a/python/CuTeDSL/cutlass/utils/gemm/__init__.py +++ b/python/CuTeDSL/cutlass/utils/gemm/__init__.py @@ -9,8 +9,9 @@ # and related documentation outside the scope permitted by the EULA # is strictly prohibited. -from . import sm100 +from . import sm100, sm120 __all__ = [ "sm100", + "sm120", ] diff --git a/python/CuTeDSL/cutlass/utils/gemm/sm120/__init__.py b/python/CuTeDSL/cutlass/utils/gemm/sm120/__init__.py new file mode 100644 index 0000000000..5ee754db1e --- /dev/null +++ b/python/CuTeDSL/cutlass/utils/gemm/sm120/__init__.py @@ -0,0 +1,51 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +"""Narrow SM120 MXF4/NVFP4 GEMM helper API.""" + +from .constants import ( + MXF4NVF4_CTA_SHAPE_MNK, + MXF4NVF4_MMA_SHAPE_MNK, + MXF4NVF4_SCALE_TMA_BYTES, + MXF4NVF4_SCALE_VEC_SIZE, + mxf4nvf4_ab_tma_bytes, + mxf4nvf4_full_tma_bytes, +) +from .layouts import ( + make_mxf4nvf4_a_gmem_layout, + make_mxf4nvf4_ab_tma_physical_layout_staged, + make_mxf4nvf4_b_gmem_layout, + make_mxf4nvf4_native_tma_smem_views, + make_mxf4nvf4_scale_interleaved_gmem_layout, + make_mxf4nvf4_scale_interleaved_tma_layout_staged, + make_mxf4nvf4_tiled_mma, +) +from .tma import make_mxf4nvf4_native_tma_atoms +from .validation import mxf4nvf4_can_implement, validate_mxf4nvf4_gemm_config + +__all__ = [ + "MXF4NVF4_CTA_SHAPE_MNK", + "MXF4NVF4_MMA_SHAPE_MNK", + "MXF4NVF4_SCALE_TMA_BYTES", + "MXF4NVF4_SCALE_VEC_SIZE", + "make_mxf4nvf4_native_tma_atoms", + "make_mxf4nvf4_native_tma_smem_views", + "make_mxf4nvf4_a_gmem_layout", + "make_mxf4nvf4_ab_tma_physical_layout_staged", + "make_mxf4nvf4_b_gmem_layout", + "make_mxf4nvf4_scale_interleaved_gmem_layout", + "make_mxf4nvf4_scale_interleaved_tma_layout_staged", + "make_mxf4nvf4_tiled_mma", + "mxf4nvf4_ab_tma_bytes", + "mxf4nvf4_can_implement", + "mxf4nvf4_full_tma_bytes", + "validate_mxf4nvf4_gemm_config", +] diff --git a/python/CuTeDSL/cutlass/utils/gemm/sm120/constants.py b/python/CuTeDSL/cutlass/utils/gemm/sm120/constants.py new file mode 100644 index 0000000000..31ee6aa38a --- /dev/null +++ b/python/CuTeDSL/cutlass/utils/gemm/sm120/constants.py @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +"""Constants for SM120 MXF4/NVFP4 GEMM helpers.""" + +from typing import Literal + +AbSmemFormat = Literal["packed", "unpack"] + +MXF4NVF4_CTA_SHAPE_MNK = (128, 128, 128) +MXF4NVF4_MMA_SHAPE_MNK = (16, 8, 64) +MXF4NVF4_SCALE_VEC_SIZE = 16 +MXF4NVF4_SCALE_K = MXF4NVF4_CTA_SHAPE_MNK[2] // MXF4NVF4_SCALE_VEC_SIZE +MXF4NVF4_SCALE_TMA_MIN_L = 2 + +# Packed-path fixed CTA transaction sizes. Use mxf4nvf4_*_tma_bytes() +# when the A/B SMEM format is configurable. +MXF4NVF4_AB_TMA_BYTES = ( + MXF4NVF4_CTA_SHAPE_MNK[0] * MXF4NVF4_CTA_SHAPE_MNK[2] // 2 +) +MXF4NVF4_SCALE_TMA_BYTES = MXF4NVF4_CTA_SHAPE_MNK[0] * MXF4NVF4_SCALE_K +MXF4NVF4_FULL_TMA_BYTES = 2 * MXF4NVF4_AB_TMA_BYTES + 2 * MXF4NVF4_SCALE_TMA_BYTES + + +def mxf4nvf4_ab_tma_bytes(ab_smem_format: AbSmemFormat = "packed") -> int: + """Return one SM120 MXF4/NVFP4 A or B TMA transaction size in bytes.""" + major_extent = MXF4NVF4_CTA_SHAPE_MNK[0] + tile_k = MXF4NVF4_CTA_SHAPE_MNK[2] + if ab_smem_format == "packed": + return major_extent * tile_k // 2 + if ab_smem_format == "unpack": + return major_extent * tile_k + raise ValueError(f"unsupported ab_smem_format: {ab_smem_format!r}") + + +def mxf4nvf4_full_tma_bytes(ab_smem_format: AbSmemFormat = "packed") -> int: + """Return the total fixed microtile TMA transaction size in bytes.""" + return 2 * mxf4nvf4_ab_tma_bytes(ab_smem_format) + 2 * MXF4NVF4_SCALE_TMA_BYTES + + +__all__ = [ + "AbSmemFormat", + "MXF4NVF4_CTA_SHAPE_MNK", + "MXF4NVF4_FULL_TMA_BYTES", + "MXF4NVF4_MMA_SHAPE_MNK", + "MXF4NVF4_SCALE_K", + "MXF4NVF4_SCALE_TMA_MIN_L", + "MXF4NVF4_SCALE_TMA_BYTES", + "MXF4NVF4_SCALE_VEC_SIZE", + "mxf4nvf4_ab_tma_bytes", + "mxf4nvf4_full_tma_bytes", +] diff --git a/python/CuTeDSL/cutlass/utils/gemm/sm120/layouts.py b/python/CuTeDSL/cutlass/utils/gemm/sm120/layouts.py new file mode 100644 index 0000000000..6b1a72743e --- /dev/null +++ b/python/CuTeDSL/cutlass/utils/gemm/sm120/layouts.py @@ -0,0 +1,368 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +"""Layouts for SM120 MXF4/NVFP4 GEMM helpers.""" + +from typing import Any, Optional + +import cutlass +import cutlass.cute as cute +from cutlass._mlir import ir +from cutlass.cute.nvgpu import warp +from cutlass.cutlass_dsl import dsl_user_op +from cutlass.utils.smem_allocator import SmemAllocator + +from .constants import ( + AbSmemFormat, + MXF4NVF4_CTA_SHAPE_MNK, + MXF4NVF4_SCALE_TMA_MIN_L, + MXF4NVF4_SCALE_VEC_SIZE, +) +from .validation import _check_default_tile, _check_positive + + +@dsl_user_op +def make_mxf4nvf4_tiled_mma( + atom_layout_mnk: Any = None, + permutation_mnk: Any = None, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.TiledMma: + """Create the tested SM120 128x128x128 MXF4/NVFP4 tiled MMA.""" + mma_op = warp.MmaMXF4NVF4Op( + cutlass.Float4E2M1FN, + cutlass.Float32, + cutlass.Float8E4M3FN, + ) + if atom_layout_mnk is None: + atom_layout_mnk = cute.make_layout((4, 2, 1), stride=(1, 4, 0), loc=loc, ip=ip) + if permutation_mnk is None: + permutation_mnk = ( + 128, + cute.make_layout((8, 2, 2), stride=(1, 16, 8), loc=loc, ip=ip), + 64, + ) + return cute.make_tiled_mma( + mma_op, + atom_layout_mnk=atom_layout_mnk, + permutation_mnk=permutation_mnk, + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_mxf4nvf4_a_gmem_layout( + m: int = 128, + k: int = 128, + l_extent: int = 1, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.Layout: + """Return the logical K-major A GMEM layout.""" + _check_positive("m", m) + _check_positive("k", k) + _check_positive("l_extent", l_extent) + return cute.make_layout((m, k, l_extent), stride=(k, 1, m * k), loc=loc, ip=ip) + + +@dsl_user_op +def make_mxf4nvf4_b_gmem_layout( + n: int = 128, + k: int = 128, + l_extent: int = 1, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.Layout: + """Return the logical K-major B GMEM layout.""" + _check_positive("n", n) + _check_positive("k", k) + _check_positive("l_extent", l_extent) + return cute.make_layout((n, k, l_extent), stride=(k, 1, n * k), loc=loc, ip=ip) + + +@dsl_user_op +def make_mxf4nvf4_scale_interleaved_gmem_layout( + major_extent: int = 128, + logical_k_extent: int = 128, + l_extent: int = 1, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.Layout: + """Return the compact interleaved FP8 GMEM scale layout consumed by TMA.""" + _check_positive("major_extent", major_extent) + _check_positive("logical_k_extent", logical_k_extent) + _check_positive("l_extent", l_extent) + _check_positive("sf_vec_size", sf_vec_size) + if sf_vec_size != MXF4NVF4_SCALE_VEC_SIZE: + raise ValueError("SM120 MXF4NVF4 scale layout requires sf_vec_size=16") + if major_extent % 128 != 0: + raise ValueError("SM120 scale interleaved layout requires major_extent % 128 == 0") + if logical_k_extent % sf_vec_size != 0: + raise ValueError( + "SM120 scale interleaved layout requires " + "logical_k_extent % sf_vec_size == 0" + ) + logical_scale_k = cute.ceil_div(logical_k_extent, sf_vec_size) + if logical_scale_k % 4 != 0: + raise ValueError("SM120 scale interleaved layout requires scale_k % 4 == 0") + major_tiles = major_extent // 128 + scale_tiles = logical_scale_k // 4 + l_stride = major_tiles * scale_tiles * 512 + return cute.make_layout( + (((32, 4), major_tiles), 4, scale_tiles, l_extent), + stride=(((16, 4), 512), 1, major_tiles * 512, l_stride), + loc=loc, + ip=ip, + ) + + +def mxf4nvf4_padded_scale_k_extent(logical_scale_k_extent: int) -> int: + """Return the padded physical scale-K extent for SM120 scale TMA.""" + _check_positive("logical_scale_k_extent", logical_scale_k_extent) + granularity = MXF4NVF4_CTA_SHAPE_MNK[2] // MXF4NVF4_SCALE_VEC_SIZE * 2 + return ( + (logical_scale_k_extent + granularity - 1) // granularity + ) * granularity + + +def mxf4nvf4_scale_tma_physical_k_extent( + k: int, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, +) -> int: + """Return the physical scale-K extent needed to back a logical K extent.""" + _check_positive("k", k) + _check_positive("sf_vec_size", sf_vec_size) + if sf_vec_size != MXF4NVF4_SCALE_VEC_SIZE: + raise ValueError("SM120 MXF4NVF4 scale TMA requires sf_vec_size=16") + if k % sf_vec_size != 0: + raise ValueError("SM120 MXF4NVF4 K extent must be divisible by sf_vec_size") + return mxf4nvf4_padded_scale_k_extent(k // sf_vec_size) + + +def mxf4nvf4_scale_tma_physical_l_extent(logical_l_extent: int) -> int: + """Return the physical scale-L extent used by native SM120 scale TMA.""" + _check_positive("logical_l_extent", logical_l_extent) + return max(logical_l_extent, MXF4NVF4_SCALE_TMA_MIN_L) + + +@dsl_user_op +def make_mxf4nvf4_ab_tma_physical_layout_staged( + major_extent: int = 128, + tile_k: int = 128, + num_stages: int = 1, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.Layout: + """Return the A/B physical SMEM byte layout populated by TMA.""" + _check_default_tile(major_extent, tile_k, MXF4NVF4_SCALE_VEC_SIZE) + _check_positive("num_stages", num_stages) + return cute.make_layout( + (major_extent, tile_k, num_stages), + stride=(tile_k, 1, major_extent * tile_k), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def _make_mxf4nvf4_ab_consumer_layout_atom( + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.ComposedLayout: + return cute.make_composed_layout( + cute.make_swizzle(2, 4, 3, loc=loc, ip=ip), + 0, + cute.make_layout((8, 128), stride=(128, 1), loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def _make_mxf4nvf4_a_consumer_smem_layout_staged( + major_extent: int = 128, + tile_k: int = 128, + num_stages: int = 1, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.ComposedLayout: + _check_default_tile(major_extent, tile_k, MXF4NVF4_SCALE_VEC_SIZE) + _check_positive("num_stages", num_stages) + return cute.tile_to_shape( + _make_mxf4nvf4_ab_consumer_layout_atom(loc=loc, ip=ip), + (major_extent, tile_k, num_stages), + (0, 1, 2), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def _make_mxf4nvf4_b_consumer_smem_layout_staged( + major_extent: int = 128, + tile_k: int = 128, + num_stages: int = 1, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.ComposedLayout: + return _make_mxf4nvf4_a_consumer_smem_layout_staged( + major_extent, tile_k, num_stages, loc=loc, ip=ip + ) + + +@dsl_user_op +def make_mxf4nvf4_scale_interleaved_tma_layout_staged( + major_extent: int = 128, + tile_k: int = 128, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + num_stages: int = 1, + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +) -> cute.Layout: + """Return the compact interleaved FP8 scale TMA SMEM layout.""" + _check_default_tile(major_extent, tile_k, sf_vec_size) + _check_positive("num_stages", num_stages) + if major_extent % 128 != 0: + raise ValueError( + "SM120 scale interleaved SMEM layout requires major_extent % 128 == 0" + ) + scale_k = tile_k // sf_vec_size + if scale_k % 4 != 0: + raise ValueError("SM120 scale interleaved SMEM layout requires scale_k % 4 == 0") + major_tiles = major_extent // 128 + scale_tiles = scale_k // 4 + stage_stride = major_tiles * scale_tiles * 512 + return cute.make_layout( + (((32, 4), major_tiles), 4, scale_tiles, num_stages), + stride=(((16, 4), 512), 1, major_tiles * 512, stage_stride), + loc=loc, + ip=ip, + ) + + +def _mxf4nvf4_ab_smem_dtype(ab_smem_format: AbSmemFormat): + if ab_smem_format == "unpack": + return cutlass.Uint8 + if ab_smem_format == "packed": + return cutlass.Float4E2M1FN + raise ValueError( + f"`ab_smem_format` must be 'packed' or 'unpack', but got {ab_smem_format!r}" + ) + + +def _allocate_mxf4nvf4_ab_smem_views( + smem: SmemAllocator, + *, + ab_smem_format: AbSmemFormat = "packed", + num_stages: int = 1, + tile_m: int = 128, + tile_n: int = 128, + tile_k: int = 128, +) -> tuple[cute.Tensor, cute.Tensor]: + ab_smem_dtype = _mxf4nvf4_ab_smem_dtype(ab_smem_format) + return ( + smem.allocate_tensor( + ab_smem_dtype, + _make_mxf4nvf4_a_consumer_smem_layout_staged(tile_m, tile_k, num_stages), + byte_alignment=128, + ), + smem.allocate_tensor( + ab_smem_dtype, + _make_mxf4nvf4_b_consumer_smem_layout_staged(tile_n, tile_k, num_stages), + byte_alignment=128, + ), + ) + + +def _allocate_mxf4nvf4_scale_interleaved_smem_views( + smem: SmemAllocator, + *, + tile_m: int = 128, + tile_n: int = 128, + tile_k: int = 128, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + num_stages: int = 1, +) -> tuple[cute.Tensor, cute.Tensor]: + return ( + smem.allocate_tensor( + cutlass.Float8E4M3FN, + make_mxf4nvf4_scale_interleaved_tma_layout_staged( + tile_m, tile_k, sf_vec_size, num_stages + ), + byte_alignment=128, + ), + smem.allocate_tensor( + cutlass.Float8E4M3FN, + make_mxf4nvf4_scale_interleaved_tma_layout_staged( + tile_n, tile_k, sf_vec_size, num_stages + ), + byte_alignment=128, + ), + ) + + +def make_mxf4nvf4_native_tma_smem_views( + smem: SmemAllocator, + *, + ab_smem_format: AbSmemFormat = "packed", + num_stages: int = 1, + tile_m: int = 128, + tile_n: int = 128, + tile_k: int = 128, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, +) -> tuple[cute.Tensor, cute.Tensor, cute.Tensor, cute.Tensor]: + """Allocate A/B/SFA/SFB SMEM views for native SM120 TMA atoms. + + The scale path is intentionally interleaved-only; this is the compact fast + path matching the SM120 native FP8 scale tensor-map usage. + """ + sA, sB = _allocate_mxf4nvf4_ab_smem_views( + smem, + ab_smem_format=ab_smem_format, + num_stages=num_stages, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + ) + sSFA, sSFB = _allocate_mxf4nvf4_scale_interleaved_smem_views( + smem, + tile_m=tile_m, + tile_n=tile_n, + tile_k=tile_k, + sf_vec_size=sf_vec_size, + num_stages=num_stages, + ) + return (sA, sB, sSFA, sSFB) + + +__all__ = [ + "make_mxf4nvf4_a_gmem_layout", + "make_mxf4nvf4_ab_tma_physical_layout_staged", + "make_mxf4nvf4_b_gmem_layout", + "make_mxf4nvf4_native_tma_smem_views", + "make_mxf4nvf4_scale_interleaved_gmem_layout", + "make_mxf4nvf4_scale_interleaved_tma_layout_staged", + "make_mxf4nvf4_tiled_mma", + "mxf4nvf4_padded_scale_k_extent", + "mxf4nvf4_scale_tma_physical_k_extent", + "mxf4nvf4_scale_tma_physical_l_extent", +] diff --git a/python/CuTeDSL/cutlass/utils/gemm/sm120/tma.py b/python/CuTeDSL/cutlass/utils/gemm/sm120/tma.py new file mode 100644 index 0000000000..a5b0a93ac4 --- /dev/null +++ b/python/CuTeDSL/cutlass/utils/gemm/sm120/tma.py @@ -0,0 +1,328 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +"""TMA atom helpers for SM120 MXF4/NVFP4 GEMM.""" + +from typing import Optional, Type + +import cutlass +import cutlass.cute as cute +from cutlass import const_expr +from cutlass._mlir import ir +from cutlass.base_dsl.typing import Numeric +from cutlass.cute.nvgpu import cpasync +from cutlass.cutlass_dsl import dsl_user_op + +from .constants import ( + AbSmemFormat, + MXF4NVF4_CTA_SHAPE_MNK, + MXF4NVF4_SCALE_TMA_MIN_L, + MXF4NVF4_SCALE_VEC_SIZE, +) +from .layouts import ( + _make_mxf4nvf4_a_consumer_smem_layout_staged, + _make_mxf4nvf4_b_consumer_smem_layout_staged, + make_mxf4nvf4_scale_interleaved_tma_layout_staged, +) + + +def _normalize_mxf4nvf4_ab_smem_format(smem_format: AbSmemFormat) -> AbSmemFormat: + if smem_format in ("packed", "unpack"): + return smem_format + raise ValueError( + f"`smem_format` must be 'packed' or 'unpack', but got {smem_format!r}" + ) + + +def _mxf4nvf4_ab_tma_internal_type( + smem_format: AbSmemFormat, +) -> Optional[Type[Numeric]]: + if _normalize_mxf4nvf4_ab_smem_format(smem_format) == "unpack": + return cutlass.Uint8 + return None + + +def _preserve_mxf4nvf4_ab_tma_l_mode(gmem_tensor: cute.Tensor) -> cute.Tensor: + """Keep A/B tensor maps rank-3 even for logical L=1. + + The SM120 native TMA path is closest to the C++ 79a path when A/B tensor + maps keep the L coordinate in the instruction stream. For logical L=1, use + a physical L extent of at least two while preserving the original strides. + """ + if const_expr(cute.size(gmem_tensor, mode=[2]) != 1): + return gmem_tensor + return cute.make_tensor( + gmem_tensor.iterator, + cute.make_layout( + ( + gmem_tensor.shape[0], + gmem_tensor.shape[1], + MXF4NVF4_SCALE_TMA_MIN_L, + ), + stride=gmem_tensor.layout.stride, + ), + ) + + +def _make_mxf4nvf4_tiled_tma_atom( + gmem_tensor: cute.Tensor, + smem_layout: cute.Layout, + cta_tiler: cute.Tile, + *, + internal_type: Optional[Type[Numeric]] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +): + return cpasync.make_tiled_tma_atom( + cpasync.CopyBulkTensorTileG2SOp(), + gmem_tensor, + smem_layout, + cta_tiler, + internal_type=internal_type, + loc=loc, + ip=ip, + ) + + +def _validate_mxf4nvf4_tma_tensor( + name: str, + tensor: cute.Tensor, + dtype: Type[Numeric], + rank: int, +) -> None: + if tensor.element_type != dtype: + raise TypeError( + f"`{name}` must have element type {dtype}, but got {tensor.element_type}" + ) + if cute.rank(tensor) != rank: + raise ValueError( + f"`{name}` must have rank {rank}, but got rank {cute.rank(tensor)}" + ) + if tensor.memspace != cute.AddressSpace.gmem: + raise ValueError(f"`{name}` must be a global-memory tensor") + + +@dsl_user_op +def _make_mxf4nvf4_tiled_tma_atom_a( + gmem_tensor: cute.Tensor, + smem_layout: Optional[cute.Layout] = None, + cta_tiler: cute.Tile = (128, 128, 1), + smem_format: AbSmemFormat = "packed", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +): + smem_format = _normalize_mxf4nvf4_ab_smem_format(smem_format) + if const_expr(smem_layout is None): + smem_layout = _make_mxf4nvf4_a_consumer_smem_layout_staged(loc=loc, ip=ip) + return _make_mxf4nvf4_tiled_tma_atom( + gmem_tensor, + smem_layout, + cta_tiler, + internal_type=_mxf4nvf4_ab_tma_internal_type(smem_format), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def _make_mxf4nvf4_tiled_tma_atom_b( + gmem_tensor: cute.Tensor, + smem_layout: Optional[cute.Layout] = None, + cta_tiler: cute.Tile = (128, 128, 1), + smem_format: AbSmemFormat = "packed", + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +): + smem_format = _normalize_mxf4nvf4_ab_smem_format(smem_format) + if const_expr(smem_layout is None): + smem_layout = _make_mxf4nvf4_b_consumer_smem_layout_staged(loc=loc, ip=ip) + return _make_mxf4nvf4_tiled_tma_atom( + gmem_tensor, + smem_layout, + cta_tiler, + internal_type=_mxf4nvf4_ab_tma_internal_type(smem_format), + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def _make_mxf4nvf4_sfa_tiled_tma_atom( + gmem_tensor: cute.Tensor, + smem_layout: Optional[cute.Layout] = None, + cta_tiler: cute.Tile = (128, 4, 2, 1), + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +): + if const_expr(smem_layout is None): + smem_layout = make_mxf4nvf4_scale_interleaved_tma_layout_staged( + loc=loc, ip=ip + ) + return _make_mxf4nvf4_tiled_tma_atom( + gmem_tensor, + smem_layout, + cta_tiler, + # Do not pass an internal type here: unlike A/B unpack-SMEM TMA, + # the scale tensor-map format must remain the native FP8 GMEM type. + # In CuTe DSL, internal_type controls tensor-map data format, so a + # Uint16 internal type would change the descriptor away from FP8. + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def _make_mxf4nvf4_sfb_tiled_tma_atom( + gmem_tensor: cute.Tensor, + smem_layout: Optional[cute.Layout] = None, + cta_tiler: cute.Tile = (128, 4, 2, 1), + *, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +): + if const_expr(smem_layout is None): + smem_layout = make_mxf4nvf4_scale_interleaved_tma_layout_staged( + loc=loc, ip=ip + ) + return _make_mxf4nvf4_tiled_tma_atom( + gmem_tensor, + smem_layout, + cta_tiler, + # Keep the scale tensor-map format native FP8. See SFA helper above. + loc=loc, + ip=ip, + ) + + +@dsl_user_op +def make_mxf4nvf4_native_tma_atoms( + gA: cute.Tensor, + gB: cute.Tensor, + gSFA: cute.Tensor, + gSFB: cute.Tensor, + *, + ab_smem_format: AbSmemFormat = "packed", + ab_cta_tiler: cute.Tile = (128, 128, 1), + ab_tile_coord: Optional[tuple[int, int, int]] = None, + ab_tile_coord_a: Optional[tuple[int, int, int]] = None, + ab_tile_coord_b: Optional[tuple[int, int, int]] = None, + scale_tile_coord: Optional[tuple[int, int, int, int]] = None, + scale_tile_coord_sfa: Optional[tuple[int, int, int, int]] = None, + scale_tile_coord_sfb: Optional[tuple[int, int, int, int]] = None, + loc: Optional[ir.Location] = None, + ip: Optional[ir.InsertionPoint] = None, +): + """Create A/B/SFA/SFB native TMA atoms for the SM120 NVFP4 path. + + A/B tensors remain logical FP4. Passing ``ab_smem_format="unpack"`` uses an + 8-bit internal TMA type, triggering the FP4 unpack-SMEM tensor-map format. + Scale tensors remain native FP8 and use the compact interleaved scale TMA + layout. Tile coordinates default to ``None`` for both A/B and scale tensors; + callers that need local tiles can request them consistently. + """ + _validate_mxf4nvf4_tma_tensor("gA", gA, cutlass.Float4E2M1FN, 3) + _validate_mxf4nvf4_tma_tensor("gB", gB, cutlass.Float4E2M1FN, 3) + _validate_mxf4nvf4_tma_tensor("gSFA", gSFA, cutlass.Float8E4M3FN, 4) + _validate_mxf4nvf4_tma_tensor("gSFB", gSFB, cutlass.Float8E4M3FN, 4) + scale_cta_tiler = (128, 4, 2, 1) + gA = _preserve_mxf4nvf4_ab_tma_l_mode(gA) + gB = _preserve_mxf4nvf4_ab_tma_l_mode(gB) + tma_atom_a, tma_tensor_a = _make_mxf4nvf4_tiled_tma_atom_a( + gA, cta_tiler=ab_cta_tiler, smem_format=ab_smem_format, loc=loc, ip=ip + ) + tma_atom_b, tma_tensor_b = _make_mxf4nvf4_tiled_tma_atom_b( + gB, cta_tiler=ab_cta_tiler, smem_format=ab_smem_format, loc=loc, ip=ip + ) + + if ab_tile_coord_a is None: + ab_tile_coord_a = ab_tile_coord + if ab_tile_coord_b is None: + ab_tile_coord_b = ab_tile_coord + if ab_tile_coord_a is not None: + tma_tensor_a = cute.local_tile( + tma_tensor_a, ab_cta_tiler, ab_tile_coord_a, loc=loc, ip=ip + ) + if ab_tile_coord_b is not None: + tma_tensor_b = cute.local_tile( + tma_tensor_b, ab_cta_tiler, ab_tile_coord_b, loc=loc, ip=ip + ) + + scale_smem_layout_a = make_mxf4nvf4_scale_interleaved_tma_layout_staged( + MXF4NVF4_CTA_SHAPE_MNK[0], + MXF4NVF4_CTA_SHAPE_MNK[2], + MXF4NVF4_SCALE_VEC_SIZE, + 1, + loc=loc, + ip=ip, + ) + scale_smem_layout_b = make_mxf4nvf4_scale_interleaved_tma_layout_staged( + MXF4NVF4_CTA_SHAPE_MNK[1], + MXF4NVF4_CTA_SHAPE_MNK[2], + MXF4NVF4_SCALE_VEC_SIZE, + 1, + loc=loc, + ip=ip, + ) + tma_atom_sfa, tma_tensor_sfa = _make_mxf4nvf4_sfa_tiled_tma_atom( + gSFA, + smem_layout=scale_smem_layout_a, + cta_tiler=scale_cta_tiler, + loc=loc, + ip=ip, + ) + tma_atom_sfb, tma_tensor_sfb = _make_mxf4nvf4_sfb_tiled_tma_atom( + gSFB, + smem_layout=scale_smem_layout_b, + cta_tiler=scale_cta_tiler, + loc=loc, + ip=ip, + ) + + if scale_tile_coord_sfa is None: + scale_tile_coord_sfa = scale_tile_coord + if scale_tile_coord_sfb is None: + scale_tile_coord_sfb = scale_tile_coord + if scale_tile_coord_sfa is not None: + tma_tensor_sfa = cute.local_tile( + tma_tensor_sfa, + scale_cta_tiler, + scale_tile_coord_sfa, + loc=loc, + ip=ip, + ) + if scale_tile_coord_sfb is not None: + tma_tensor_sfb = cute.local_tile( + tma_tensor_sfb, + scale_cta_tiler, + scale_tile_coord_sfb, + loc=loc, + ip=ip, + ) + + return ( + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_sfa, + tma_tensor_sfa, + tma_atom_sfb, + tma_tensor_sfb, + ) + + +__all__ = [ + "AbSmemFormat", + "make_mxf4nvf4_native_tma_atoms", +] diff --git a/python/CuTeDSL/cutlass/utils/gemm/sm120/validation.py b/python/CuTeDSL/cutlass/utils/gemm/sm120/validation.py new file mode 100644 index 0000000000..305783e027 --- /dev/null +++ b/python/CuTeDSL/cutlass/utils/gemm/sm120/validation.py @@ -0,0 +1,207 @@ +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-NvidiaProprietary +# +# Use of this software is governed by the terms and conditions of the +# NVIDIA End User License Agreement (EULA), available at: +# https://docs.nvidia.com/cutlass/latest/media/docs/pythonDSL/license.html +# +# Any use, reproduction, disclosure, or distribution of this software +# and related documentation outside the scope permitted by the EULA +# is strictly prohibited. + +"""Validation helpers for SM120 MXF4/NVFP4 GEMM support.""" + +from typing import Type + +import cutlass +from cutlass.base_dsl.typing import Numeric + +from .constants import ( + MXF4NVF4_CTA_SHAPE_MNK, + MXF4NVF4_SCALE_VEC_SIZE, +) + + +def _check_positive(name: str, value: int) -> None: + if value <= 0: + raise ValueError(f"`{name}` must be positive, but got {value}") + + +def _check_default_tile(tile_mn: int, tile_k: int, sf_vec_size: int) -> None: + _check_positive("tile_mn", tile_mn) + _check_positive("tile_k", tile_k) + _check_positive("sf_vec_size", sf_vec_size) + if tile_k != MXF4NVF4_CTA_SHAPE_MNK[2]: + raise ValueError("SM120 MXF4NVF4 helpers currently support tile_k=128") + if sf_vec_size != MXF4NVF4_SCALE_VEC_SIZE: + raise ValueError("SM120 MXF4NVF4 helpers currently support sf_vec_size=16") + + +def _check_tuple(name: str, value: tuple[int, ...], rank: int) -> None: + if len(value) != rank: + raise ValueError(f"`{name}` must have rank {rank}, but got {value}") + + +def _contiguous_alignment(dtype: Type[Numeric]) -> int: + return 16 * 8 // dtype.width + + +def _mxf4nvf4_gemm_config_errors( + *, + m: int = 128, + n: int = 128, + k: int = 128, + l_extent: int = 1, + a_dtype: Type[Numeric] = cutlass.Float4E2M1FN, + b_dtype: Type[Numeric] = cutlass.Float4E2M1FN, + sf_dtype: Type[Numeric] = cutlass.Float8E4M3FN, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + c_dtype: Type[Numeric] = cutlass.BFloat16, + acc_dtype: Type[Numeric] = cutlass.Float32, + tile_shape_mnk: tuple[int, int, int] = MXF4NVF4_CTA_SHAPE_MNK, + cluster_shape_mnk: tuple[int, int, int] = (1, 1, 1), + a_major: str = "k", + b_major: str = "k", + c_major: str = "n", +) -> list[str]: + errors: list[str] = [] + for name, value in (("m", m), ("n", n), ("k", k), ("l_extent", l_extent)): + if value <= 0: + errors.append(f"`{name}` must be positive") + + try: + _check_tuple("tile_shape_mnk", tile_shape_mnk, 3) + except ValueError as exc: + errors.append(str(exc)) + try: + _check_tuple("cluster_shape_mnk", cluster_shape_mnk, 3) + except ValueError as exc: + errors.append(str(exc)) + + if a_dtype != cutlass.Float4E2M1FN: + errors.append("A dtype must be Float4E2M1FN") + if b_dtype != cutlass.Float4E2M1FN: + errors.append("B dtype must be Float4E2M1FN") + if sf_dtype != cutlass.Float8E4M3FN: + errors.append("scale dtype must be Float8E4M3FN") + if sf_vec_size != MXF4NVF4_SCALE_VEC_SIZE: + errors.append(f"sf_vec_size must be {MXF4NVF4_SCALE_VEC_SIZE}") + if acc_dtype != cutlass.Float32: + errors.append("accumulator dtype must be Float32") + if c_dtype not in {cutlass.Float32, cutlass.Float16, cutlass.BFloat16}: + errors.append("output dtype must be Float32, Float16, or BFloat16") + + if tile_shape_mnk != MXF4NVF4_CTA_SHAPE_MNK: + errors.append(f"tile_shape_mnk must be {MXF4NVF4_CTA_SHAPE_MNK}") + if cluster_shape_mnk != (1, 1, 1): + errors.append("cluster_shape_mnk must be (1, 1, 1)") + if a_major != "k": + errors.append("A layout must be K-major") + if b_major != "k": + errors.append("B layout must be K-major") + if c_major not in {"n", "m"}: + errors.append("output layout must be N-major or M-major") + + if len(tile_shape_mnk) == 3: + tile_m, tile_n, tile_k = tile_shape_mnk + if m % tile_m != 0: + errors.append("m must be divisible by tile_shape_mnk[0]") + if n % tile_n != 0: + errors.append("n must be divisible by tile_shape_mnk[1]") + if k % tile_k != 0: + errors.append("k must be divisible by tile_shape_mnk[2]") + + if a_dtype == cutlass.Float4E2M1FN and k % _contiguous_alignment(a_dtype): + errors.append("K-major A requires k to be 16-byte aligned") + if b_dtype == cutlass.Float4E2M1FN and k % _contiguous_alignment(b_dtype): + errors.append("K-major B requires k to be 16-byte aligned") + if c_dtype in {cutlass.Float32, cutlass.Float16, cutlass.BFloat16}: + c_contiguous_extent = m if c_major == "m" else n + if c_contiguous_extent % _contiguous_alignment(c_dtype): + errors.append("output contiguous dimension must be 16-byte aligned") + + return errors + + +def mxf4nvf4_can_implement( + *, + m: int = 128, + n: int = 128, + k: int = 128, + l_extent: int = 1, + a_dtype: Type[Numeric] = cutlass.Float4E2M1FN, + b_dtype: Type[Numeric] = cutlass.Float4E2M1FN, + sf_dtype: Type[Numeric] = cutlass.Float8E4M3FN, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + c_dtype: Type[Numeric] = cutlass.BFloat16, + acc_dtype: Type[Numeric] = cutlass.Float32, + tile_shape_mnk: tuple[int, int, int] = MXF4NVF4_CTA_SHAPE_MNK, + cluster_shape_mnk: tuple[int, int, int] = (1, 1, 1), + a_major: str = "k", + b_major: str = "k", + c_major: str = "n", +) -> bool: + """Return whether the narrow SM120 MXF4/NVFP4 GEMM helper supports a config.""" + return not _mxf4nvf4_gemm_config_errors( + m=m, + n=n, + k=k, + l_extent=l_extent, + a_dtype=a_dtype, + b_dtype=b_dtype, + sf_dtype=sf_dtype, + sf_vec_size=sf_vec_size, + c_dtype=c_dtype, + acc_dtype=acc_dtype, + tile_shape_mnk=tile_shape_mnk, + cluster_shape_mnk=cluster_shape_mnk, + a_major=a_major, + b_major=b_major, + c_major=c_major, + ) + + +def validate_mxf4nvf4_gemm_config( + *, + m: int = 128, + n: int = 128, + k: int = 128, + l_extent: int = 1, + a_dtype: Type[Numeric] = cutlass.Float4E2M1FN, + b_dtype: Type[Numeric] = cutlass.Float4E2M1FN, + sf_dtype: Type[Numeric] = cutlass.Float8E4M3FN, + sf_vec_size: int = MXF4NVF4_SCALE_VEC_SIZE, + c_dtype: Type[Numeric] = cutlass.BFloat16, + acc_dtype: Type[Numeric] = cutlass.Float32, + tile_shape_mnk: tuple[int, int, int] = MXF4NVF4_CTA_SHAPE_MNK, + cluster_shape_mnk: tuple[int, int, int] = (1, 1, 1), + a_major: str = "k", + b_major: str = "k", + c_major: str = "n", +) -> None: + """Raise ``ValueError`` if a config is outside the narrow SM120 NVFP4 path.""" + errors = _mxf4nvf4_gemm_config_errors( + m=m, + n=n, + k=k, + l_extent=l_extent, + a_dtype=a_dtype, + b_dtype=b_dtype, + sf_dtype=sf_dtype, + sf_vec_size=sf_vec_size, + c_dtype=c_dtype, + acc_dtype=acc_dtype, + tile_shape_mnk=tile_shape_mnk, + cluster_shape_mnk=cluster_shape_mnk, + a_major=a_major, + b_major=b_major, + c_major=c_major, + ) + if errors: + raise ValueError("unsupported SM120 MXF4NVF4 GEMM config: " + "; ".join(errors)) + + +__all__ = [ + "mxf4nvf4_can_implement", + "validate_mxf4nvf4_gemm_config", +] diff --git a/test/examples/CuTeDSL/sm_120a/test_sm120_mxf4nvf4_native_tma_microtile.py b/test/examples/CuTeDSL/sm_120a/test_sm120_mxf4nvf4_native_tma_microtile.py new file mode 100644 index 0000000000..4989e2134e --- /dev/null +++ b/test/examples/CuTeDSL/sm_120a/test_sm120_mxf4nvf4_native_tma_microtile.py @@ -0,0 +1,90 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +from pathlib import Path + +import pytest + +import cutlass +import cutlass.cute as cute +from cutlass.cute.runtime import from_dlpack + + +pytestmark = [pytest.mark.arch(["120a"])] + +_MXF4NVF4_MMA_PTX = ( + "mma.sync.aligned.m16n8k64.row.col.kind::mxf4nvf4.block_scale." + "scale_vec::4X.f32.e2m1.e2m1.f32.ue4m3" +) + +_EXAMPLES_DIR = Path(__file__).parents[4] / "examples/python/CuTeDSL" + + +def _make_cute_fp4_tensor(torch, logical_major: int, logical_k: int): + storage = torch.empty( + (1, logical_major, logical_k // 2), + device="cuda", + dtype=torch.float4_e2m1fn_x2, + ) + storage.view(torch.uint8).fill_(0x22) + tensor = storage.permute(1, 2, 0) + # Torch exposes packed FP4 storage as x2 bytes. Disable TVM FFI conversion + # and restore the logical CuTe element type so TMA sees Float4E2M1FN. + cute_tensor = from_dlpack(tensor, assumed_align=16, enable_tvm_ffi=False) + cute_tensor.element_type = cutlass.Float4E2M1FN + assert cute_tensor.element_type is cutlass.Float4E2M1FN + assert cute_tensor.shape == (logical_major, logical_k, 1) + return cute_tensor + + +def _make_cute_tensor(tensor): + return from_dlpack(tensor, assumed_align=16, enable_tvm_ffi=True) + + +def test_sm120_mxf4nvf4_native_tma_microtile_example(monkeypatch): + torch = pytest.importorskip("torch") + if not torch.cuda.is_available(): + pytest.skip("CUDA device unavailable") + if torch.cuda.get_device_capability()[0] < 12: + pytest.skip("SM120 CUDA device required") + if not hasattr(torch, "float4_e2m1fn_x2"): + pytest.skip("torch Float4E2M1FN storage dtype unavailable") + if not hasattr(torch, "float8_e4m3fn"): + pytest.skip("torch Float8E4M3FN dtype unavailable") + monkeypatch.syspath_prepend(str(_EXAMPLES_DIR)) + from cute.blackwell.kernel.blockscaled_gemm.sm120_mxf4nvf4_native_tma_microtile import ( + sm120_mxf4nvf4_native_tma_microtile, + ) + + a = _make_cute_fp4_tensor(torch, 128, 128) + b = _make_cute_fp4_tensor(torch, 128, 128) + sfa_storage = torch.ones((1024,), device="cuda", dtype=torch.float8_e4m3fn) + sfa_storage[512:].fill_(2.0) + sfb_storage = torch.ones((1024,), device="cuda", dtype=torch.float8_e4m3fn) + d_storage = torch.empty((16, 8), device="cuda", dtype=torch.bfloat16) + sfa = _make_cute_tensor(sfa_storage) + sfb = _make_cute_tensor(sfb_storage) + d = _make_cute_tensor(d_storage) + + compiled = cute.compile( + sm120_mxf4nvf4_native_tma_microtile, + a, + b, + sfa, + sfb, + d, + options="--keep-ptx", + ) + # Exact instruction counts are intentional for this fixed 16x8 K128 + # microtile: two K64 MMAs, two A/B TMA loads, and two scale TMA loads. + assert compiled.__ptx__.count(_MXF4NVF4_MMA_PTX) == 2 + assert compiled.__ptx__.count("cp.async.bulk.tensor.3d") == 2 + assert compiled.__ptx__.count("cp.async.bulk.tensor.2d") == 2 + + d_storage.zero_() + compiled(a, b, sfa, sfb, d) + torch.cuda.synchronize() + + torch.testing.assert_close( + d_storage.float(), torch.full_like(d_storage.float(), 192.0), rtol=0, atol=0 + ) diff --git a/test/examples/CuTeDSL/sm_120a/test_sm120_mxf4nvf4_scale_mapping.py b/test/examples/CuTeDSL/sm_120a/test_sm120_mxf4nvf4_scale_mapping.py new file mode 100644 index 0000000000..be397d7193 --- /dev/null +++ b/test/examples/CuTeDSL/sm_120a/test_sm120_mxf4nvf4_scale_mapping.py @@ -0,0 +1,360 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu import warp +from cutlass.cute.runtime import from_dlpack + + +pytestmark = [pytest.mark.arch(["120a"])] + +_MXF4NVF4_MMA_PTX = ( + "mma.sync.aligned.m16n8k64.row.col.kind::mxf4nvf4.block_scale." + "scale_vec::4X" +) + + +@cute.jit +def _make_full_k_tiled_mma(): + mma_op = warp.MmaMXF4NVF4Op( + cutlass.Float4E2M1FN, cutlass.Float32, cutlass.Float8E4M3FN + ) + return cute.make_tiled_mma( + mma_op, + atom_layout_mnk=cute.make_layout((4, 2, 1), stride=(1, 4, 0)), + permutation_mnk=( + 128, + cute.make_layout((8, 2, 2), stride=(1, 16, 8)), + 64, + ), + ) + + +@cute.jit +def _make_sfa_fragment(): + return cute.make_rmem_tensor( + cute.make_layout(((16, 4), 2, 2), stride=((0, 1), 4, 8)), + cutlass.Float8E4M3FN, + ) + + +@cute.jit +def _make_sfb_fragment(): + return cute.make_rmem_tensor( + cute.make_layout(((16, 4), 8, 2), stride=((0, 1), 4, 32)), + cutlass.Float8E4M3FN, + ) + + +@cute.jit +def _store_accumulator(thr_mma: cute.ThrMma, acc: cute.Tensor, out: cute.Tensor): + tDgD = thr_mma.partition_C(out) + rD = cute.make_rmem_tensor(acc.layout, out.element_type) + rD.store(acc.load().to(out.element_type)) + copy_atom = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), out.element_type) + cute.copy(copy_atom, rD, tDgD) + + +@cute.jit +def _fill_sfa_k64_split(sfa: cute.Tensor): + sfa.fill(1.0) + for major in cutlass.range_constexpr(2): + for scale_col in cutlass.range_constexpr(4): + for row in cutlass.range_constexpr(16): + sfa[((row, scale_col), major, 1)] = cutlass.Float8E4M3FN(2.0) + + +@cute.jit +def _fill_row_col_k64_scales(sfa: cute.Tensor, sfb: cute.Tensor): + for major in cutlass.range_constexpr(2): + for k_tile in cutlass.range_constexpr(2): + sfa_value = 1 << (major + k_tile) + for scale_col in cutlass.range_constexpr(4): + for row in cutlass.range_constexpr(16): + sfa[((row, scale_col), major, k_tile)] = cutlass.Float8E4M3FN( + sfa_value + ) + + for major in cutlass.range_constexpr(8): + for k_tile in cutlass.range_constexpr(2): + sfb_value = 1 << k_tile + if major >= 4: + sfb_value = 4 << k_tile + for scale_col in cutlass.range_constexpr(4): + for row in cutlass.range_constexpr(16): + sfb[((row, scale_col), major, k_tile)] = cutlass.Float8E4M3FN( + sfb_value + ) + + +@cute.jit +def _fill_one_logical_scale_column( + sfa: cute.Tensor, + sfb: cute.Tensor, + selected_scale_col: cutlass.Constexpr[int], +): + sfa.fill(0.0) + sfb.fill(0.0) + selected_k_tile = selected_scale_col // 4 + selected_col = selected_scale_col % 4 + sfa_base = 1 << selected_scale_col + sfb_base = 1 << ((selected_scale_col + 1) % 4) + + for major in cutlass.range_constexpr(2): + sfa_value = sfa_base + if major != 0: + sfa_value = 2 * sfa_base + for row in cutlass.range_constexpr(16): + sfa[((row, selected_col), major, selected_k_tile)] = ( + cutlass.Float8E4M3FN(sfa_value) + ) + + for major in cutlass.range_constexpr(8): + sfb_value = sfb_base + if major >= 4: + sfb_value = 4 * sfb_base + for row in cutlass.range_constexpr(16): + sfb[((row, selected_col), major, selected_k_tile)] = ( + cutlass.Float8E4M3FN(sfb_value) + ) + + +@cute.kernel +def _full_k_distinct_c_kernel(out: cute.Tensor): + tidx, _, _ = cute.arch.thread_idx() + tiled_mma = _make_full_k_tiled_mma() + a = cute.make_rmem_tensor( + tiled_mma.partition_shape_A((128, 128)), cutlass.Float4E2M1FN + ) + b = cute.make_rmem_tensor( + tiled_mma.partition_shape_B((128, 128)), cutlass.Float4E2M1FN + ) + acc = cute.make_rmem_tensor( + tiled_mma.partition_shape_C((128, 128)), cutlass.Float32 + ) + c = cute.make_rmem_tensor(tiled_mma.partition_shape_C((128, 128)), cutlass.Float32) + sfa = _make_sfa_fragment() + sfb = _make_sfb_fragment() + + cute.recast_tensor(a, cutlass.Int32).fill(0x22222222) + cute.recast_tensor(b, cutlass.Int32).fill(0x22222222) + acc.fill(0.0) + c.fill(7.0) + sfa.fill(1.0) + sfb.fill(1.0) + cute.gemm(tiled_mma, acc, (a, sfa), (b, sfb), c) + + thr_mma = tiled_mma.get_slice(tidx) + _store_accumulator(thr_mma, acc, out) + + +@cute.kernel +def _two_k64_scale_columns_kernel(out: cute.Tensor): + tidx, _, _ = cute.arch.thread_idx() + tiled_mma = _make_full_k_tiled_mma() + a = cute.make_rmem_tensor( + tiled_mma.partition_shape_A((128, 128)), cutlass.Float4E2M1FN + ) + b = cute.make_rmem_tensor( + tiled_mma.partition_shape_B((128, 128)), cutlass.Float4E2M1FN + ) + acc = cute.make_rmem_tensor( + tiled_mma.partition_shape_C((128, 128)), cutlass.Float32 + ) + sfa = _make_sfa_fragment() + sfb = _make_sfb_fragment() + + cute.recast_tensor(a, cutlass.Int32).fill(0x22222222) + cute.recast_tensor(b, cutlass.Int32).fill(0x22222222) + acc.fill(0.0) + _fill_sfa_k64_split(sfa) + sfb.fill(1.0) + cute.gemm(tiled_mma, acc, (a, sfa), (b, sfb), acc) + + thr_mma = tiled_mma.get_slice(tidx) + _store_accumulator(thr_mma, acc, out) + + +@cute.kernel +def _row_col_k64_scale_mapping_kernel(out: cute.Tensor): + tidx, _, _ = cute.arch.thread_idx() + tiled_mma = _make_full_k_tiled_mma() + a = cute.make_rmem_tensor( + tiled_mma.partition_shape_A((128, 128)), cutlass.Float4E2M1FN + ) + b = cute.make_rmem_tensor( + tiled_mma.partition_shape_B((128, 128)), cutlass.Float4E2M1FN + ) + acc = cute.make_rmem_tensor( + tiled_mma.partition_shape_C((128, 128)), cutlass.Float32 + ) + sfa = _make_sfa_fragment() + sfb = _make_sfb_fragment() + + cute.recast_tensor(a, cutlass.Int32).fill(0x22222222) + cute.recast_tensor(b, cutlass.Int32).fill(0x22222222) + acc.fill(0.0) + _fill_row_col_k64_scales(sfa, sfb) + cute.gemm(tiled_mma, acc, (a, sfa), (b, sfb), acc) + + thr_mma = tiled_mma.get_slice(tidx) + _store_accumulator(thr_mma, acc, out) + + +@cute.kernel +def _per_scale_column_mapping_kernel(out: cute.Tensor): + tidx, _, _ = cute.arch.thread_idx() + tiled_mma = _make_full_k_tiled_mma() + a = cute.make_rmem_tensor( + tiled_mma.partition_shape_A((128, 128)), cutlass.Float4E2M1FN + ) + b = cute.make_rmem_tensor( + tiled_mma.partition_shape_B((128, 128)), cutlass.Float4E2M1FN + ) + acc = cute.make_rmem_tensor( + tiled_mma.partition_shape_C((128, 128)), cutlass.Float32 + ) + sfa = _make_sfa_fragment() + sfb = _make_sfb_fragment() + + cute.recast_tensor(a, cutlass.Int32).fill(0x22222222) + cute.recast_tensor(b, cutlass.Int32).fill(0x22222222) + thr_mma = tiled_mma.get_slice(tidx) + + for selected_scale_col in cutlass.range_constexpr(8): + acc.fill(0.0) + _fill_one_logical_scale_column(sfa, sfb, selected_scale_col) + cute.gemm(tiled_mma, acc, (a, sfa), (b, sfb), acc) + _store_accumulator(thr_mma, acc, out[(selected_scale_col, None, None)]) + cute.arch.sync_threads() + + +@cute.kernel +def _row_varying_sfa_mapping_kernel(out: cute.Tensor): + tidx, _, _ = cute.arch.thread_idx() + tiled_mma = _make_full_k_tiled_mma() + a = cute.make_rmem_tensor( + tiled_mma.partition_shape_A((128, 128)), cutlass.Float4E2M1FN + ) + b = cute.make_rmem_tensor( + tiled_mma.partition_shape_B((128, 128)), cutlass.Float4E2M1FN + ) + acc = cute.make_rmem_tensor( + tiled_mma.partition_shape_C((128, 128)), cutlass.Float32 + ) + sfa = _make_sfa_fragment() + sfb = _make_sfb_fragment() + + cute.recast_tensor(a, cutlass.Int32).fill(0x22222222) + cute.recast_tensor(b, cutlass.Int32).fill(0x22222222) + acc.fill(0.0) + sfa.fill(1.0) + row_mod = (tidx // 4) % 4 + if row_mod == 1: + sfa.fill(2.0) + if row_mod == 2: + sfa.fill(3.0) + if row_mod == 3: + sfa.fill(4.0) + sfb.fill(1.0) + cute.gemm(tiled_mma, acc, (a, sfa), (b, sfb), acc) + + thr_mma = tiled_mma.get_slice(tidx) + _store_accumulator(thr_mma, acc, out) + + +@cute.jit +def _launch_full_k_distinct_c(out: cute.Tensor): + _full_k_distinct_c_kernel(out).launch(grid=[1, 1, 1], block=[256, 1, 1]) + + +@cute.jit +def _launch_two_k64_scale_columns(out: cute.Tensor): + _two_k64_scale_columns_kernel(out).launch(grid=[1, 1, 1], block=[256, 1, 1]) + + +@cute.jit +def _launch_row_col_k64_scale_mapping(out: cute.Tensor): + _row_col_k64_scale_mapping_kernel(out).launch( + grid=[1, 1, 1], block=[256, 1, 1] + ) + + +@cute.jit +def _launch_per_scale_column_mapping(out: cute.Tensor): + _per_scale_column_mapping_kernel(out).launch(grid=[1, 1, 1], block=[256, 1, 1]) + + +@cute.jit +def _launch_row_varying_sfa_mapping(out: cute.Tensor): + _row_varying_sfa_mapping_kernel(out).launch(grid=[1, 1, 1], block=[256, 1, 1]) + + +def _cuda_out(*shape): + torch = pytest.importorskip("torch") + if not torch.cuda.is_available(): + pytest.skip("CUDA device unavailable") + out = torch.empty(shape, device="cuda", dtype=torch.float32) + return torch, out, from_dlpack(out, enable_tvm_ffi=True) + + +def _compile(fn, *args): + compiled = cute.compile(fn, *args, options="--keep-ptx") + assert compiled.__ptx__.count(_MXF4NVF4_MMA_PTX) >= 2 + return compiled + + +def test_sm120_mxf4nvf4_scale_mapping_uses_both_k64_halves(): + torch, out, cute_out = _cuda_out(128, 128) + _compile(_launch_two_k64_scale_columns, cute_out)(cute_out) + torch.cuda.synchronize() + torch.testing.assert_close(out, torch.full_like(out, 192.0), rtol=0, atol=0) + + +def test_sm120_mxf4nvf4_full_k_distinct_d_and_c_accumulates_across_k(): + torch, out, cute_out = _cuda_out(128, 128) + _compile(_launch_full_k_distinct_c, cute_out)(cute_out) + torch.cuda.synchronize() + torch.testing.assert_close(out, torch.full_like(out, 135.0), rtol=0, atol=0) + + +def test_sm120_mxf4nvf4_scale_mapping_covers_row_col_and_k64(): + torch, out, cute_out = _cuda_out(128, 128) + _compile(_launch_row_col_k64_scale_mapping, cute_out)(cute_out) + torch.cuda.synchronize() + + expected = torch.empty_like(out) + expected[:64, :64] = 320.0 + expected[:64, 64:] = 1280.0 + expected[64:, :64] = 640.0 + expected[64:, 64:] = 2560.0 + torch.testing.assert_close(out, expected, rtol=0, atol=0) + + +def test_sm120_mxf4nvf4_scale_mapping_covers_each_logical_scale_column(): + torch, out, cute_out = _cuda_out(8, 128, 128) + _compile(_launch_per_scale_column_mapping, cute_out)(cute_out) + torch.cuda.synchronize() + + expected = torch.empty_like(out) + for scale_col in range(8): + sfa_value = 1 << scale_col + sfb_value = 1 << ((scale_col + 1) % 4) + expected[scale_col, :64, :64] = 16.0 * sfa_value * sfb_value + expected[scale_col, :64, 64:] = 64.0 * sfa_value * sfb_value + expected[scale_col, 64:, :64] = 32.0 * sfa_value * sfb_value + expected[scale_col, 64:, 64:] = 128.0 * sfa_value * sfb_value + torch.testing.assert_close(out, expected, rtol=0, atol=0) + + +def test_sm120_mxf4nvf4_scale_mapping_covers_row_varying_sfa(): + torch, out, cute_out = _cuda_out(128, 128) + _compile(_launch_row_varying_sfa_mapping, cute_out)(cute_out) + torch.cuda.synchronize() + + row_values = (torch.arange(128, device="cuda") % 4 + 1).to(torch.float32) * 128.0 + expected = row_values[:, None].expand_as(out) + torch.testing.assert_close(out, expected, rtol=0, atol=0) diff --git a/test/examples/CuTeDSL/sm_120a/test_sm120_mxf4nvf4_warp_mma.py b/test/examples/CuTeDSL/sm_120a/test_sm120_mxf4nvf4_warp_mma.py new file mode 100644 index 0000000000..347b0bac55 --- /dev/null +++ b/test/examples/CuTeDSL/sm_120a/test_sm120_mxf4nvf4_warp_mma.py @@ -0,0 +1,326 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu import warp +from cutlass.cute.runtime import from_dlpack, make_fake_compact_tensor + + +pytestmark = [pytest.mark.arch(["120a"])] + +_MXF4NVF4_UE4M3_MMA_PTX = ( + "mma.sync.aligned.m16n8k64.row.col.kind::mxf4nvf4.block_scale." + "scale_vec::4X.f32.e2m1.e2m1.f32.ue4m3" +) + + +@cute.kernel +def _mma_mxf4nvf4_kernel(out: cute.Tensor, scale_a2: cutlass.Constexpr[bool]): + tidx, _, _ = cute.arch.thread_idx() + mma_op = warp.MmaMXF4NVF4Op( + cutlass.Float4E2M1FN, cutlass.Float32, cutlass.Float8E4M3FN + ) + tiled_mma = cute.make_tiled_mma(mma_op) + a = cute.make_rmem_tensor( + tiled_mma.partition_shape_A((16, 64)), cutlass.Float4E2M1FN + ) + b = cute.make_rmem_tensor( + tiled_mma.partition_shape_B((8, 64)), cutlass.Float4E2M1FN + ) + acc = cute.make_rmem_tensor(tiled_mma.partition_shape_C((16, 8)), cutlass.Float32) + sfa = warp.make_mxf4nvf4_sfa_fragment() + sfb = warp.make_mxf4nvf4_sfb_fragment() + + cute.recast_tensor(a, cutlass.Int32).fill(0x22222222) + cute.recast_tensor(b, cutlass.Int32).fill(0x22222222) + acc.fill(0.0) + sfa.fill(2.0 if scale_a2 else 1.0) + sfb.fill(1.0) + + warp.mma_mxf4nvf4(tiled_mma, acc, (a, sfa), (b, sfb), acc) + + acc_size = cute.size(acc) + for i in cutlass.range(acc_size, unroll_full=True): + out[tidx * acc_size + i] = acc[i] + + +@cute.kernel +def _gemm_bundle_kernel(out: cute.Tensor): + tidx, _, _ = cute.arch.thread_idx() + mma_op = warp.MmaMXF4NVF4Op( + cutlass.Float4E2M1FN, cutlass.Float32, cutlass.Float8E4M3FN + ) + tiled_mma = cute.make_tiled_mma(mma_op) + a = cute.make_rmem_tensor( + tiled_mma.partition_shape_A((16, 64)), cutlass.Float4E2M1FN + ) + b = cute.make_rmem_tensor( + tiled_mma.partition_shape_B((8, 64)), cutlass.Float4E2M1FN + ) + acc = cute.make_rmem_tensor(tiled_mma.partition_shape_C((16, 8)), cutlass.Float32) + sfa = warp.make_mxf4nvf4_sfa_fragment() + sfb = warp.make_mxf4nvf4_sfb_fragment() + + cute.recast_tensor(a, cutlass.Int32).fill(0x22222222) + cute.recast_tensor(b, cutlass.Int32).fill(0x22222222) + acc.fill(0.0) + sfa.fill(1.0) + sfb.fill(1.0) + + cute.gemm(tiled_mma, acc, (a, sfa), (b, sfb), acc) + + acc_size = cute.size(acc) + for i in cutlass.range(acc_size, unroll_full=True): + out[tidx * acc_size + i] = acc[i] + + +@cute.jit +def _launch_mma_mxf4nvf4(out: cute.Tensor): + _mma_mxf4nvf4_kernel(out, False).launch(grid=[1, 1, 1], block=[32, 1, 1]) + + +@cute.jit +def _launch_mma_mxf4nvf4_scale_a2(out: cute.Tensor): + _mma_mxf4nvf4_kernel(out, True).launch(grid=[1, 1, 1], block=[32, 1, 1]) + + +@cute.jit +def _launch_gemm_bundle(out: cute.Tensor): + _gemm_bundle_kernel(out).launch(grid=[1, 1, 1], block=[32, 1, 1]) + + +@cute.kernel +def _full_fragment_gemm_bundle_kernel(out: cute.Tensor): + tidx, _, _ = cute.arch.thread_idx() + mma_op = warp.MmaMXF4NVF4Op( + cutlass.Float4E2M1FN, cutlass.Float32, cutlass.Float8E4M3FN + ) + tiled_mma = cute.make_tiled_mma( + mma_op, + atom_layout_mnk=cute.make_layout((4, 2, 1), stride=(1, 4, 0)), + permutation_mnk=( + 128, + cute.make_layout((8, 2, 2), stride=(1, 16, 8)), + 64, + ), + ) + a = cute.make_rmem_tensor( + tiled_mma.partition_shape_A((128, 128)), cutlass.Float4E2M1FN + ) + b = cute.make_rmem_tensor( + tiled_mma.partition_shape_B((128, 128)), cutlass.Float4E2M1FN + ) + acc = cute.make_rmem_tensor( + tiled_mma.partition_shape_C((128, 128)), cutlass.Float32 + ) + sfa = cute.make_rmem_tensor( + cute.make_layout(((16, 4), 2, 2), stride=((0, 1), 4, 8)), + cutlass.Float8E4M3FN, + ) + sfb = cute.make_rmem_tensor( + cute.make_layout(((16, 4), 8, 2), stride=((0, 1), 4, 32)), + cutlass.Float8E4M3FN, + ) + + cute.recast_tensor(a, cutlass.Int32).fill(0x22222222) + cute.recast_tensor(b, cutlass.Int32).fill(0x22222222) + acc.fill(0.0) + sfa.fill(1.0) + sfb.fill(1.0) + cute.gemm(tiled_mma, acc, (a, sfa), (b, sfb), acc) + + acc_size = cute.size(acc) + for i in cutlass.range(acc_size, unroll_full=True): + out[tidx * acc_size + i] = acc[i] + + +@cute.jit +def _launch_full_fragment_gemm_bundle(out: cute.Tensor): + _full_fragment_gemm_bundle_kernel(out).launch( + grid=[1, 1, 1], block=[256, 1, 1] + ) + + +@cute.kernel +def _wrong_scale_dtype_kernel(out: cute.Tensor): + mma_op = warp.MmaMXF4NVF4Op( + cutlass.Float4E2M1FN, cutlass.Float32, cutlass.Float8E4M3FN + ) + tiled_mma = cute.make_tiled_mma(mma_op) + a = cute.make_rmem_tensor( + tiled_mma.partition_shape_A((16, 64)), cutlass.Float4E2M1FN + ) + b = cute.make_rmem_tensor( + tiled_mma.partition_shape_B((8, 64)), cutlass.Float4E2M1FN + ) + acc = cute.make_rmem_tensor(tiled_mma.partition_shape_C((16, 8)), cutlass.Float32) + sfa = cute.make_rmem_tensor(warp.make_mxf4nvf4_sfa_layout(), cutlass.Float8E5M2) + sfb = warp.make_mxf4nvf4_sfb_fragment() + cute.gemm(tiled_mma, acc, (a, sfa), (b, sfb), acc) + + +@cute.jit +def _launch_wrong_scale_dtype(out: cute.Tensor): + _wrong_scale_dtype_kernel(out).launch(grid=[1, 1, 1], block=[32, 1, 1]) + + +@cute.kernel +def _wrong_accumulator_dtype_kernel(out: cute.Tensor): + mma_op = warp.MmaMXF4NVF4Op( + cutlass.Float4E2M1FN, cutlass.Float32, cutlass.Float8E4M3FN + ) + tiled_mma = cute.make_tiled_mma(mma_op) + a = cute.make_rmem_tensor( + tiled_mma.partition_shape_A((16, 64)), cutlass.Float4E2M1FN + ) + b = cute.make_rmem_tensor( + tiled_mma.partition_shape_B((8, 64)), cutlass.Float4E2M1FN + ) + acc = cute.make_rmem_tensor(tiled_mma.partition_shape_C((16, 8)), cutlass.Float16) + sfa = warp.make_mxf4nvf4_sfa_fragment() + sfb = warp.make_mxf4nvf4_sfb_fragment() + cute.gemm(tiled_mma, acc, (a, sfa), (b, sfb), acc) + + +@cute.jit +def _launch_wrong_accumulator_dtype(out: cute.Tensor): + _wrong_accumulator_dtype_kernel(out).launch(grid=[1, 1, 1], block=[32, 1, 1]) + + +@cute.kernel +def _non_rmem_scale_kernel(out: cute.Tensor, scale: cute.Tensor): + mma_op = warp.MmaMXF4NVF4Op( + cutlass.Float4E2M1FN, cutlass.Float32, cutlass.Float8E4M3FN + ) + tiled_mma = cute.make_tiled_mma(mma_op) + a = cute.make_rmem_tensor( + tiled_mma.partition_shape_A((16, 64)), cutlass.Float4E2M1FN + ) + b = cute.make_rmem_tensor( + tiled_mma.partition_shape_B((8, 64)), cutlass.Float4E2M1FN + ) + acc = cute.make_rmem_tensor(tiled_mma.partition_shape_C((16, 8)), cutlass.Float32) + sfb = warp.make_mxf4nvf4_sfb_fragment() + cute.gemm(tiled_mma, acc, (a, scale), (b, sfb), acc) + + +@cute.jit +def _launch_non_rmem_scale(out: cute.Tensor, scale: cute.Tensor): + _non_rmem_scale_kernel(out, scale).launch(grid=[1, 1, 1], block=[32, 1, 1]) + + +@cute.kernel +def _plain_f16bf16_gemm_kernel(out: cute.Tensor): + mma_op = warp.MmaF16BF16Op(cutlass.Float16, cutlass.Float32, (16, 8, 16)) + tiled_mma = cute.make_tiled_mma(mma_op) + a = cute.make_rmem_tensor(tiled_mma.partition_shape_A((16, 16)), cutlass.Float16) + b = cute.make_rmem_tensor(tiled_mma.partition_shape_B((8, 16)), cutlass.Float16) + acc = cute.make_rmem_tensor(tiled_mma.partition_shape_C((16, 8)), cutlass.Float32) + a.fill(1.0) + b.fill(1.0) + acc.fill(0.0) + cute.gemm(tiled_mma, acc, a, b, acc) + + tidx, _, _ = cute.arch.thread_idx() + if tidx == 0: + out[0] = acc[0] + + +@cute.jit +def _launch_plain_f16bf16_gemm(out: cute.Tensor): + _plain_f16bf16_gemm_kernel(out).launch(grid=[1, 1, 1], block=[32, 1, 1]) + + +def _cuda_out(): + torch = pytest.importorskip("torch") + if not torch.cuda.is_available(): + pytest.skip("CUDA device unavailable") + out = torch.empty((32 * 4,), device="cuda", dtype=torch.float32) + return torch, out, from_dlpack(out, enable_tvm_ffi=True) + + +def _compile_mma_runtime(fn, *args): + compiled = cute.compile(fn, *args, options="--keep-ptx --keep-cubin") + assert _MXF4NVF4_UE4M3_MMA_PTX in compiled.__ptx__ + assert "_mma.block_scale" not in compiled.__ptx__ + if not isinstance(compiled.artifacts.CUBIN, bytes): + pytest.skip("CuTe DSL CUBIN artifact unavailable for SM120 MMA backend check") + return compiled + + +def _fake_float32_out(): + return make_fake_compact_tensor( + cutlass.Float32, + (32 * 4,), + memspace=cute.AddressSpace.gmem, + assumed_align=16, + ) + + +def test_sm120_mxf4nvf4_direct_helper_all_ones(): + torch, out, cute_out = _cuda_out() + compiled = _compile_mma_runtime(_launch_mma_mxf4nvf4, cute_out) + compiled(cute_out) + torch.cuda.synchronize() + torch.testing.assert_close(out, torch.full_like(out, 64.0), rtol=0, atol=0) + + +def test_sm120_mxf4nvf4_gemm_bundle_matches_direct_helper(): + torch, out_direct, cute_direct = _cuda_out() + _, out_gemm, cute_gemm = _cuda_out() + _compile_mma_runtime(_launch_mma_mxf4nvf4, cute_direct)(cute_direct) + _compile_mma_runtime(_launch_gemm_bundle, cute_gemm)(cute_gemm) + torch.cuda.synchronize() + torch.testing.assert_close(out_gemm, out_direct, rtol=0, atol=0) + + +def test_sm120_mxf4nvf4_direct_helper_scale_a2(): + torch, out, cute_out = _cuda_out() + compiled = _compile_mma_runtime(_launch_mma_mxf4nvf4_scale_a2, cute_out) + compiled(cute_out) + torch.cuda.synchronize() + torch.testing.assert_close(out, torch.full_like(out, 128.0), rtol=0, atol=0) + + +def test_sm120_mxf4nvf4_full_fragment_gemm_bundle_compile(): + out = make_fake_compact_tensor( + cutlass.Float32, + (256 * 64,), + memspace=cute.AddressSpace.gmem, + assumed_align=16, + ) + compiled = cute.compile(_launch_full_fragment_gemm_bundle, out, options="--keep-ptx") + assert ( + compiled.__ptx__.count( + "mma.sync.aligned.m16n8k64.row.col.kind::mxf4nvf4.block_scale.scale_vec::4X" + ) + == 2 + ) + + +def test_sm120_mxf4nvf4_wrong_scale_dtype_rejects(): + with pytest.raises(Exception, match="sfa.*Float8E4M3FN"): + cute.compile(_launch_wrong_scale_dtype, _fake_float32_out()) + + +def test_sm120_mxf4nvf4_wrong_accumulator_dtype_rejects(): + with pytest.raises(Exception, match="d.*Float32"): + cute.compile(_launch_wrong_accumulator_dtype, _fake_float32_out()) + + +def test_sm120_mxf4nvf4_non_rmem_scale_rejects(): + scale = make_fake_compact_tensor( + cutlass.Float8E4M3FN, + (64,), + memspace=cute.AddressSpace.gmem, + assumed_align=16, + ) + with pytest.raises(Exception, match="sfa.*register-resident"): + cute.compile(_launch_non_rmem_scale, _fake_float32_out(), scale) + + +def test_sm120_mxf4nvf4_plain_f16bf16_gemm_still_compiles(): + cute.compile(_launch_plain_f16bf16_gemm, _fake_float32_out()) diff --git a/test/python/CuTeDSL/test_pipeline_tma_already_elected.py b/test/python/CuTeDSL/test_pipeline_tma_already_elected.py new file mode 100644 index 0000000000..b702a7e99c --- /dev/null +++ b/test/python/CuTeDSL/test_pipeline_tma_already_elected.py @@ -0,0 +1,96 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +import cutlass +import cutlass.cute as cute +import cutlass.pipeline as pipeline +from cutlass.cute.runtime import make_fake_compact_tensor + + +@cute.kernel +def _already_elected_tma_acquire_with_token_kernel(out: cute.Tensor): + tidx, _, _ = cute.arch.thread_idx() + barrier_storage = cute.arch.alloc_smem(cutlass.Int64, 2, alignment=8) + pipe = pipeline.PipelineTmaAsync.create( + num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=16, + barrier_storage=barrier_storage, + tidx=tidx, + defer_sync=True, + ) + state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, 1) + + token = pipe.producer_try_acquire(state) + with cute.arch.elect_one(): + pipe.producer_acquire_already_elected(state, token) + + if tidx == 0: + out[0] = cutlass.Int32(0) + + +@cute.jit +def _launch_already_elected_tma_acquire_with_token(out: cute.Tensor): + _already_elected_tma_acquire_with_token_kernel(out).launch( + grid=[1, 1, 1], block=[32, 1, 1] + ) + + +@cute.kernel +def _already_elected_tma_acquire_kernel(out: cute.Tensor): + tidx, _, _ = cute.arch.thread_idx() + barrier_storage = cute.arch.alloc_smem(cutlass.Int64, 2, alignment=8) + pipe = pipeline.PipelineTmaAsync.create( + num_stages=1, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread, 1), + tx_count=16, + barrier_storage=barrier_storage, + tidx=tidx, + defer_sync=True, + ) + state = pipeline.make_pipeline_state(pipeline.PipelineUserType.Producer, 1) + + with cute.arch.elect_one(): + pipe.producer_acquire_already_elected(state) + + if tidx == 0: + out[0] = cutlass.Int32(0) + + +@cute.jit +def _launch_already_elected_tma_acquire(out: cute.Tensor): + _already_elected_tma_acquire_kernel(out).launch( + grid=[1, 1, 1], block=[32, 1, 1] + ) + + +def test_pipeline_tma_producer_acquire_already_elected_compiles(): + out = make_fake_compact_tensor( + cutlass.Int32, + (1,), + memspace=cute.AddressSpace.gmem, + assumed_align=16, + ) + + compiled = cute.compile(_launch_already_elected_tma_acquire, out, options="--keep-ptx") + + assert "mbarrier.try_wait" in compiled.__ptx__ + assert "mbarrier.arrive.expect_tx" in compiled.__ptx__ + + +def test_pipeline_tma_producer_acquire_already_elected_with_token_compiles(): + out = make_fake_compact_tensor( + cutlass.Int32, + (1,), + memspace=cute.AddressSpace.gmem, + assumed_align=16, + ) + + compiled = cute.compile( + _launch_already_elected_tma_acquire_with_token, out, options="--keep-ptx" + ) + + assert "mbarrier.try_wait" in compiled.__ptx__ + assert "mbarrier.arrive.expect_tx" in compiled.__ptx__ diff --git a/test/python/CuTeDSL/test_sm120_mxf4nvf4_tma_layouts.py b/test/python/CuTeDSL/test_sm120_mxf4nvf4_tma_layouts.py new file mode 100644 index 0000000000..3f0fe7f154 --- /dev/null +++ b/test/python/CuTeDSL/test_sm120_mxf4nvf4_tma_layouts.py @@ -0,0 +1,291 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest + +import cutlass +import cutlass.cute as cute +from cutlass.cute.runtime import make_fake_compact_tensor +from cutlass.utils.gemm import sm120 +from cutlass.utils.gemm.sm120.constants import ( + MXF4NVF4_AB_TMA_BYTES, + MXF4NVF4_FULL_TMA_BYTES, + mxf4nvf4_ab_tma_bytes, + mxf4nvf4_full_tma_bytes, +) +from cutlass.utils.gemm.sm120.layouts import ( + mxf4nvf4_padded_scale_k_extent, + mxf4nvf4_scale_tma_physical_k_extent, + mxf4nvf4_scale_tma_physical_l_extent, +) +from cutlass.utils.smem_allocator import SmemAllocator + + +def test_sm120_mxf4nvf4_public_api_is_narrow(): + assert set(sm120.__all__) == { + "MXF4NVF4_CTA_SHAPE_MNK", + "MXF4NVF4_MMA_SHAPE_MNK", + "MXF4NVF4_SCALE_TMA_BYTES", + "MXF4NVF4_SCALE_VEC_SIZE", + "make_mxf4nvf4_a_gmem_layout", + "make_mxf4nvf4_ab_tma_physical_layout_staged", + "make_mxf4nvf4_b_gmem_layout", + "make_mxf4nvf4_native_tma_atoms", + "make_mxf4nvf4_native_tma_smem_views", + "make_mxf4nvf4_scale_interleaved_gmem_layout", + "make_mxf4nvf4_scale_interleaved_tma_layout_staged", + "make_mxf4nvf4_tiled_mma", + "mxf4nvf4_ab_tma_bytes", + "mxf4nvf4_can_implement", + "mxf4nvf4_full_tma_bytes", + "validate_mxf4nvf4_gemm_config", + } + + +def test_sm120_mxf4nvf4_config_and_extent_validation(): + assert sm120.MXF4NVF4_CTA_SHAPE_MNK == (128, 128, 128) + assert sm120.MXF4NVF4_MMA_SHAPE_MNK == (16, 8, 64) + assert sm120.MXF4NVF4_SCALE_VEC_SIZE == 16 + assert sm120.mxf4nvf4_can_implement() + assert not sm120.mxf4nvf4_can_implement(k=64) + + sm120.validate_mxf4nvf4_gemm_config() + with pytest.raises(ValueError, match="tile_shape_mnk"): + sm120.validate_mxf4nvf4_gemm_config(tile_shape_mnk=(64, 128, 128)) + + assert mxf4nvf4_padded_scale_k_extent(8) == 16 + assert mxf4nvf4_padded_scale_k_extent(17) == 32 + assert mxf4nvf4_scale_tma_physical_k_extent(384) == 32 + assert mxf4nvf4_scale_tma_physical_l_extent(1) == 2 + with pytest.raises(ValueError, match="major_extent % 128"): + sm120.make_mxf4nvf4_scale_interleaved_tma_layout_staged(64, 128, 16, 1) + with pytest.raises(ValueError, match="logical_k_extent % sf_vec_size"): + sm120.make_mxf4nvf4_scale_interleaved_gmem_layout(128, 127, 1) + assert mxf4nvf4_ab_tma_bytes("packed") == MXF4NVF4_AB_TMA_BYTES + assert mxf4nvf4_ab_tma_bytes("unpack") == 2 * MXF4NVF4_AB_TMA_BYTES + assert mxf4nvf4_full_tma_bytes("packed") == MXF4NVF4_FULL_TMA_BYTES + assert mxf4nvf4_full_tma_bytes("unpack") > MXF4NVF4_FULL_TMA_BYTES + + +@cute.jit +def _check_scale_interleaved_layouts_exact(): + ab_tma_layout = sm120.make_mxf4nvf4_ab_tma_physical_layout_staged(128, 128, 1) + gmem_layout = sm120.make_mxf4nvf4_scale_interleaved_gmem_layout(128, 128, 1) + smem_layout = sm120.make_mxf4nvf4_scale_interleaved_tma_layout_staged( + 128, 128, 16, 1 + ) + assert ab_tma_layout.shape == (128, 128, 1) + assert ab_tma_layout.stride == (128, 1, 16384) + expected_shape = (((32, 4), 1), 4, 2, 1) + expected_stride = (((16, 4), 512), 1, 512, 1024) + assert gmem_layout.shape == expected_shape + assert gmem_layout.stride == expected_stride + assert smem_layout.shape == expected_shape + assert smem_layout.stride == expected_stride + + +@cute.jit +def _build_native_tma_atoms( + gA: cute.Tensor, + gB: cute.Tensor, + gSFA_storage: cute.Tensor, + gSFB_storage: cute.Tensor, +): + gSFA = cute.make_tensor( + gSFA_storage.iterator, + sm120.make_mxf4nvf4_scale_interleaved_gmem_layout(128, 128, 1), + ) + gSFB = cute.make_tensor( + gSFB_storage.iterator, + sm120.make_mxf4nvf4_scale_interleaved_gmem_layout(128, 128, 1), + ) + ( + tma_atom_a, + tma_tensor_a, + tma_atom_b, + tma_tensor_b, + tma_atom_sfa, + tma_tensor_sfa, + tma_atom_sfb, + tma_tensor_sfb, + ) = sm120.make_mxf4nvf4_native_tma_atoms(gA, gB, gSFA, gSFB) + assert tma_atom_a is not None + assert tma_atom_b is not None + assert tma_atom_sfa is not None + assert tma_atom_sfb is not None + assert cute.rank(tma_tensor_a) == 3 + assert cute.rank(tma_tensor_b) == 3 + assert cute.size(tma_tensor_a, mode=[2]) == 2 + assert cute.size(tma_tensor_b, mode=[2]) == 2 + assert cute.rank(tma_tensor_sfa) == 4 + assert cute.rank(tma_tensor_sfb) == 4 + + +@cute.jit +def _build_native_tma_atoms_unpack( + gA: cute.Tensor, + gB: cute.Tensor, + gSFA_storage: cute.Tensor, + gSFB_storage: cute.Tensor, +): + gSFA = cute.make_tensor( + gSFA_storage.iterator, + sm120.make_mxf4nvf4_scale_interleaved_gmem_layout(128, 128, 1), + ) + gSFB = cute.make_tensor( + gSFB_storage.iterator, + sm120.make_mxf4nvf4_scale_interleaved_gmem_layout(128, 128, 1), + ) + sm120.make_mxf4nvf4_native_tma_atoms( + gA, + gB, + gSFA, + gSFB, + ab_smem_format="unpack", + ) + + +@cute.jit +def _check_default_tiled_mma(): + tiled_mma = sm120.make_mxf4nvf4_tiled_mma() + assert tiled_mma.size == 256 + assert tiled_mma.get_tile_size(0) == 128 + assert tiled_mma.get_tile_size(2) == 64 + + +@cute.kernel +def _check_native_tma_smem_view_dtypes_kernel(ab_smem_format: cutlass.Constexpr[str]): + smem = SmemAllocator() + sA, sB, sSFA, sSFB = sm120.make_mxf4nvf4_native_tma_smem_views( + smem, + ab_smem_format=ab_smem_format, + ) + assert sA.element_type == cutlass.Float4E2M1FN + assert sB.element_type == cutlass.Float4E2M1FN + assert sSFA.element_type == cutlass.Float8E4M3FN + assert sSFB.element_type == cutlass.Float8E4M3FN + + +@cute.kernel +def _check_native_tma_unpack_smem_view_dtypes_kernel(): + smem = SmemAllocator() + sA, sB, sSFA, sSFB = sm120.make_mxf4nvf4_native_tma_smem_views( + smem, + ab_smem_format="unpack", + ) + assert sA.element_type == cutlass.Uint8 + assert sB.element_type == cutlass.Uint8 + assert sSFA.element_type == cutlass.Float8E4M3FN + assert sSFB.element_type == cutlass.Float8E4M3FN + + +@cute.jit +def _launch_check_native_tma_smem_view_dtypes(): + _check_native_tma_smem_view_dtypes_kernel("packed").launch( + grid=[1, 1, 1], block=[1, 1, 1] + ) + + +@cute.jit +def _launch_check_native_tma_unpack_smem_view_dtypes(): + _check_native_tma_unpack_smem_view_dtypes_kernel().launch( + grid=[1, 1, 1], block=[1, 1, 1] + ) + + +def _fake_fp4_tensor(memspace: cute.AddressSpace = cute.AddressSpace.gmem): + return make_fake_compact_tensor( + cutlass.Float4E2M1FN, + (128, 128, 1), + stride_order=(1, 0, 2), + memspace=memspace, + assumed_align=16, + ) + + +def _fake_fp8_scale_storage(): + return make_fake_compact_tensor( + cutlass.Float8E4M3FN, + (1024,), + memspace=cute.AddressSpace.gmem, + assumed_align=16, + ) + + +def _fake_uint8_scale_storage(): + return make_fake_compact_tensor( + cutlass.Uint8, + (1024,), + memspace=cute.AddressSpace.gmem, + assumed_align=16, + ) + + +def _fake_f16_tensor(): + return make_fake_compact_tensor( + cutlass.Float16, + (128, 128, 1), + stride_order=(1, 0, 2), + memspace=cute.AddressSpace.gmem, + assumed_align=16, + ) + + +def test_sm120_mxf4nvf4_scale_interleaved_layouts_are_exact(): + _check_scale_interleaved_layouts_exact() + + +def test_sm120_mxf4nvf4_default_tiled_mma_matches_full_tile_config(): + _check_default_tiled_mma() + + +def test_sm120_mxf4nvf4_native_tma_smem_view_dtypes_compile(): + cute.compile(_launch_check_native_tma_smem_view_dtypes) + cute.compile(_launch_check_native_tma_unpack_smem_view_dtypes) + + +def test_sm120_mxf4nvf4_native_tma_atoms_compile(): + cute.compile( + _build_native_tma_atoms, + _fake_fp4_tensor(), + _fake_fp4_tensor(), + _fake_fp8_scale_storage(), + _fake_fp8_scale_storage(), + ) + + +def test_sm120_mxf4nvf4_native_tma_atoms_unpack_compile(): + cute.compile( + _build_native_tma_atoms_unpack, + _fake_fp4_tensor(), + _fake_fp4_tensor(), + _fake_fp8_scale_storage(), + _fake_fp8_scale_storage(), + ) + + +def test_sm120_mxf4nvf4_native_tma_atoms_reject_wrong_dtypes(): + with pytest.raises(Exception, match="gA.*Float4E2M1FN"): + cute.compile( + _build_native_tma_atoms, + _fake_f16_tensor(), + _fake_fp4_tensor(), + _fake_fp8_scale_storage(), + _fake_fp8_scale_storage(), + ) + with pytest.raises(Exception, match="gSFA.*Float8E4M3FN"): + cute.compile( + _build_native_tma_atoms, + _fake_fp4_tensor(), + _fake_fp4_tensor(), + _fake_uint8_scale_storage(), + _fake_fp8_scale_storage(), + ) + with pytest.raises(Exception, match="gA.*global-memory"): + cute.compile( + _build_native_tma_atoms, + _fake_fp4_tensor(cute.AddressSpace.smem), + _fake_fp4_tensor(), + _fake_fp8_scale_storage(), + _fake_fp8_scale_storage(), + ) diff --git a/test/python/CuTeDSL/test_tensor_swizzle_helpers.py b/test/python/CuTeDSL/test_tensor_swizzle_helpers.py new file mode 100644 index 0000000000..2962b3b010 --- /dev/null +++ b/test/python/CuTeDSL/test_tensor_swizzle_helpers.py @@ -0,0 +1,194 @@ +# Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +from dataclasses import dataclass + +import pytest + +import cutlass +import cutlass.cute as cute +import cutlass.cute.tensor as cute_tensor +from cutlass.cute.runtime import make_fake_compact_tensor +from cutlass.cute.typing import Tensor + + +@dataclass +class _DummyTensor(Tensor): + _memspace: cute.AddressSpace + _iterator: object + _layout: object + _element_type: object = cutlass.Float32 + + def __str__(self) -> str: + return "_DummyTensor" + + def __getitem__(self, idx): + raise NotImplementedError + + def __setitem__(self, idx, value) -> None: + raise NotImplementedError + + @property + def element_type(self): + return self._element_type + + @property + def memspace(self): + return self._memspace + + @property + def iterator(self): + return self._iterator + + @property + def leading_dim(self): + return None + + @property + def layout(self): + return self._layout + + @property + def shape(self): + return (1,) + + @property + def stride(self): + return (1,) + + def fill(self, value) -> None: + raise NotImplementedError + + +def test_as_position_independent_swizzle_tensor_rejects_non_tensor(): + with pytest.raises(TypeError, match="expects a Tensor"): + cute.as_position_independent_swizzle_tensor(object()) + + +def test_as_position_independent_swizzle_tensor_rejects_non_smem_tensor(): + tensor = make_fake_compact_tensor( + cutlass.Float32, + (4,), + memspace=cute.AddressSpace.gmem, + assumed_align=16, + ) + + with pytest.raises(TypeError, match="shared-memory tensor"): + cute.as_position_independent_swizzle_tensor(tensor) + + +def test_as_position_independent_swizzle_tensor_moves_swizzle_to_pointer( + monkeypatch, +): + source = _DummyTensor( + cute.AddressSpace.smem, + _iterator=object(), + _layout=object(), + ) + swizzle = object() + nonswizzle_layout = object() + recast_pointer = object() + result_tensor = object() + + def fake_get_swizzle_portion(layout, *, loc=None, ip=None): + assert layout is source.layout + return swizzle + + def fake_get_nonswizzle_portion(layout, *, loc=None, ip=None): + assert layout is source.layout + return nonswizzle_layout + + def fake_recast_ptr(ptr, swizzle_=None, dtype=None, loc=None, ip=None): + assert ptr is source.iterator + assert swizzle_ is swizzle + assert dtype is cutlass.Float32 + return recast_pointer + + def fake_make_tensor(ptr, layout, *, loc=None, ip=None): + assert ptr is recast_pointer + assert layout is nonswizzle_layout + return result_tensor + + monkeypatch.setattr(cute_tensor, "get_swizzle_portion", fake_get_swizzle_portion) + monkeypatch.setattr( + cute_tensor, "get_nonswizzle_portion", fake_get_nonswizzle_portion + ) + monkeypatch.setattr(cute_tensor, "recast_ptr", fake_recast_ptr) + monkeypatch.setattr(cute_tensor, "make_tensor", fake_make_tensor) + + assert cute.as_position_independent_swizzle_tensor(source) is result_tensor + + +@cute.kernel +def _copy_position_independent_plain_smem_tensor_kernel(out: cute.Tensor): + smem_ptr = cute.arch.alloc_smem(cutlass.Float32, 64, alignment=16) + layout = cute.make_layout((8, 8), stride=(8, 1)) + smem_tensor = cute.make_tensor(smem_ptr, layout) + smem_view = cute.as_position_independent_swizzle_tensor(smem_tensor) + rmem_tensor = cute.make_rmem_tensor(layout, cutlass.Float32) + + rmem_tensor.fill(1.0) + cute.basic_copy(rmem_tensor, smem_view) + rmem_tensor.fill(0.0) + cute.basic_copy(smem_view, rmem_tensor) + + tidx, _, _ = cute.arch.thread_idx() + if tidx == 0: + out[0] = rmem_tensor[0, 0] + + +@cute.jit +def _launch_copy_position_independent_plain_smem_tensor(out: cute.Tensor): + _copy_position_independent_plain_smem_tensor_kernel(out).launch( + grid=[1, 1, 1], block=[32, 1, 1] + ) + + +@cute.kernel +def _copy_position_independent_swizzle_tensor_kernel(out: cute.Tensor): + smem_ptr = cute.arch.alloc_smem(cutlass.Float32, 64, alignment=16) + base_layout = cute.make_layout((8, 8), stride=(8, 1)) + swizzled_layout = cute.make_composed_layout( + cute.make_swizzle(1, 4, 3), 0, base_layout + ) + smem_tensor = cute.make_tensor(smem_ptr, swizzled_layout) + smem_view = cute.as_position_independent_swizzle_tensor(smem_tensor) + rmem_tensor = cute.make_rmem_tensor(base_layout, cutlass.Float32) + + rmem_tensor.fill(1.0) + cute.basic_copy(rmem_tensor, smem_view) + rmem_tensor.fill(0.0) + cute.basic_copy(smem_view, rmem_tensor) + + tidx, _, _ = cute.arch.thread_idx() + if tidx == 0: + out[0] = rmem_tensor[0, 0] + + +@cute.jit +def _launch_copy_position_independent_swizzle_tensor(out: cute.Tensor): + _copy_position_independent_swizzle_tensor_kernel(out).launch( + grid=[1, 1, 1], block=[32, 1, 1] + ) + + +def test_as_position_independent_swizzle_tensor_copy_path_compiles(): + out = make_fake_compact_tensor( + cutlass.Float32, + (1,), + memspace=cute.AddressSpace.gmem, + assumed_align=16, + ) + + cute.compile(_launch_copy_position_independent_swizzle_tensor, out) + + +def test_as_position_independent_swizzle_tensor_plain_smem_copy_path_compiles(): + out = make_fake_compact_tensor( + cutlass.Float32, + (1,), + memspace=cute.AddressSpace.gmem, + assumed_align=16, + ) + + cute.compile(_launch_copy_position_independent_plain_smem_tensor, out)