diff --git a/python/CuTeDSL/cutlass/cute/atom.py b/python/CuTeDSL/cutlass/cute/atom.py index 6bfea228ba..e6967a7c38 100644 --- a/python/CuTeDSL/cutlass/cute/atom.py +++ b/python/CuTeDSL/cutlass/cute/atom.py @@ -352,7 +352,18 @@ def tv_layout_C( loc: Optional[ir.Location] = None, ip: Optional[ir.InsertionPoint] = None, ) -> Layout: - return static(self._trait.value.type.layout_c_tv) + layout = static(self._trait.value.type.layout_c_tv, loc=loc, ip=ip) + mnk = self.shape_mnk + if len(mnk) == 3 and mnk[1] == 8: + shape = layout.shape + stride = layout.stride + if len(shape) == 2 and len(shape[1]) == 2: + fixed_shape = (shape[0], (shape[1][0], shape[1][1], 1)) + fixed_stride = (stride[0], (stride[1][0], stride[1][1], 512)) + layout = make_layout( + fixed_shape, stride=fixed_stride, loc=loc, ip=ip + ) + return layout # # make_fragment