From ecf0230c239bf90d46dc885e90dff499dae604da Mon Sep 17 00:00:00 2001 From: Hardy <18930861549@163.com> Date: Thu, 21 May 2026 20:13:47 +0800 Subject: [PATCH] fix(CuTeDSL): restore trailing Int<1> dimension in SM90 MMA atom TV Layout C when N=8 When N=8, the SM90 WGMMA CLayout_64xN has a trailing Int = Int<1> dimension in the C++ CuTe reference (mma_traits_sm90_gmma.hpp L432-L435). The MLIR layout canonicalization was incorrectly dropping this size-1 dimension, causing tv_layout_C to return ((4,8,4),(2,2)) instead of ((4,8,4),(2,2,1)). This fix adds a post-processing check in MmaAtom.tv_layout_C to restore the trailing Int<1> dimension when N=8, matching the C++ implementation and preserving layout structural invariants. Fixes NVIDIA/cutlass#3254 --- python/CuTeDSL/cutlass/cute/atom.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/python/CuTeDSL/cutlass/cute/atom.py b/python/CuTeDSL/cutlass/cute/atom.py index 6bfea228ba..e6967a7c38 100644 --- a/python/CuTeDSL/cutlass/cute/atom.py +++ b/python/CuTeDSL/cutlass/cute/atom.py @@ -352,7 +352,18 @@ def tv_layout_C( loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None, ) -> Layout: - return static(self._trait.value.type.layout_c_tv) + layout = static(self._trait.value.type.layout_c_tv, loc=loc, ip=ip) + mnk = self.shape_mnk + if len(mnk) == 3 and mnk[1] == 8: + shape = layout.shape + stride = layout.stride + if len(shape) == 2 and len(shape[1]) == 2: + fixed_shape = (shape[0], (shape[1][0], shape[1][1], 1)) + fixed_stride = (stride[0], (stride[1][0], stride[1][1], 512)) + layout = make_layout( + fixed_shape, stride=fixed_stride, loc=loc, ip=ip + ) + return layout # # make_fragment