Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions python/CuTeDSL/cutlass/cute/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6342,31 +6342,36 @@ def _divisor(self, value: ir.Value) -> None:

def __extract_mlir_values__(self) -> List[ir.Value]:
    """Extract MLIR values for Host->Device transfer.

    Returns:
        A two-element list:

        1. ``self._divisor_mlir`` -- the encoded FastDivmod SSA value.
        2. The scalar divisor as an ``ir.Value``. If ``_original_divisor``
           is already an ``ir.Value`` it is forwarded unchanged; otherwise
           it is wrapped through ``Int32(...).ir_value()``.
    """
    # Passing the encoded FastDivmod SSA value ("_divisor_mlir") allows
    # GridInvariantCodeMotionPass to:
    #   1. Recognize FastDivmodCreateDivisorOp in the IR
    #   2. Hoist it to the host side before kernel launch
    #   3. Pass the pre-computed divisor as a kernel argument
    #
    # The scalar divisor must travel as its own SSA value so that after
    # ``__new_from_mlir_values__`` inside an isolated kernel region,
    # ``.divisor`` refers to SSA defined in that region (see GitHub issue #3243).
    divisor_ir = self._original_divisor
    if not isinstance(divisor_ir, ir.Value):
        # Not an ir.Value yet: materialize it as an Int32 SSA value so it
        # can be packed next to the encoded divisor.
        divisor_ir = Int32(divisor_ir).ir_value()
    return [self._divisor_mlir, divisor_ir]

def __new_from_mlir_values__(self, values: List[ir.Value]) -> "FastDivmodDivisor":
    """Reconstruct FastDivmodDivisor from MLIR values.

    Args:
        values: The values produced by ``__extract_mlir_values__``:
            ``values[0]`` is the encoded divisor and ``values[1]`` is the
            scalar divisor that travelled with it.

    Returns:
        A new ``FastDivmodDivisor`` built without running ``__init__``, whose
        ``_divisor_mlir`` is ``values[0]`` and whose ``_original_divisor`` is
        ``IntValue(values[1])``.

    Raises:
        AssertionError: If fewer than two values are supplied.
    """
    assert len(values) >= 2, (
        "FastDivmodDivisor reconstructs from the encoded divisor plus the transit "
        f"scalar divisor IR; expected 2 values, got {len(values)}."
    )
    new_obj = object.__new__(FastDivmodDivisor)

    # Directly use the passed FastDivmodDivisor value without recreating it.
    # This is critical to avoid generating new create_divisor ops on device side,
    # which would bypass GridInvariantCodeMotionPass optimization.
    new_obj._divisor_mlir = values[0]

    # Rebuild the scalar divisor from the transported SSA value instead of
    # copying ``self._original_divisor``, so the public ``divisor`` property
    # refers to a value defined in the region receiving ``values``
    # (see GitHub issue #3243).
    new_obj._original_divisor = IntValue(values[1])

    return new_obj

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,16 +209,20 @@ def __new_from_mlir_values__(

if hasattr(self, "_fastdivmod_indices") and len(self._fastdivmod_indices) > 0:
# Override the FastDivmod divisors created by __init__ with reconstructed ones
for j, original_index in enumerate(self._fastdivmod_indices):
fdd_tail_offset = 0
for original_index in self._fastdivmod_indices:
fdd_name = fdd_names[original_index]
# Get the original FastDivmodDivisor object
original_fdd = getattr(self, fdd_name)
if original_fdd is not None and j < len(values_copy):
# Each FastDivmodDivisor has 1 MLIR value
if original_fdd is None:
continue
n_fdd = len(extract_mlir_values(original_fdd))
end = fdd_tail_offset + n_fdd
if end <= len(values_copy):
reconstructed_fdd = new_from_mlir_values(
original_fdd, [values_copy[j]]
original_fdd, values_copy[fdd_tail_offset:end]
)
setattr(new_params, fdd_name, reconstructed_fdd)
fdd_tail_offset = end

return new_params

Expand Down
14 changes: 9 additions & 5 deletions python/CuTeDSL/cutlass/utils/static_persistent_tile_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,16 +274,20 @@ def __new_from_mlir_values__(

if hasattr(self, "_fastdivmod_indices") and len(self._fastdivmod_indices) > 0:
# Override the FastDivmod divisors created by __init__ with reconstructed ones
for j, original_index in enumerate(self._fastdivmod_indices):
fdd_tail_offset = 0
for original_index in self._fastdivmod_indices:
fdd_name = fdd_names[original_index]
# Get the original FastDivmodDivisor object
original_fdd = getattr(self, fdd_name)
if original_fdd is not None and j < len(values_copy):
# Each FastDivmodDivisor has 1 MLIR value
if original_fdd is None:
continue
n_fdd = len(extract_mlir_values(original_fdd))
end = fdd_tail_offset + n_fdd
if end <= len(values_copy):
reconstructed_fdd = new_from_mlir_values(
original_fdd, [values_copy[j]]
original_fdd, values_copy[fdd_tail_offset:end]
)
setattr(new_params, fdd_name, reconstructed_fdd)
fdd_tail_offset = end

return new_params

Expand Down
41 changes: 41 additions & 0 deletions test/examples/CuTeDSL/test_fast_divmod_divisor_param_transport.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright (c) 2025 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

# Regression for https://github.com/NVIDIA/cutlass/issues/3243 — FastDivmodDivisor
# must carry the scalar divisor SSA alongside the encoded divisor so that
# ``params.fdd.divisor`` is legal inside an isolated kernel region.

from dataclasses import dataclass

import cutlass
import cutlass.cute as cute
from cutlass import Int32
from cutlass.cute.runtime import make_fake_tensor


@dataclass
class Params:
    """Kernel parameter bundle wrapping a single ``FastDivmodDivisor``."""

    # Divisor whose ``.divisor`` property is read inside the kernel.
    fdd: cute.FastDivmodDivisor


@cute.jit
def make_params(divisor: Int32) -> Params:
    """Build a ``Params`` holding a ``FastDivmodDivisor`` made from ``divisor``."""
    fast_divisor = cute.FastDivmodDivisor(divisor)
    return Params(fdd=fast_divisor)


@cute.kernel
def write_divisor(out: cute.Tensor, params: Params):
    """Store ``params.fdd.divisor`` into ``out[0]`` from the thread with index 0."""
    thread_x, _, _ = cute.arch.thread_idx()
    # Only one thread performs the store.
    if thread_x == 0:
        out[0] = params.fdd.divisor


@cute.jit
def entry(out: cute.Tensor, divisor: Int32):
    """Pack ``divisor`` into ``Params`` and launch ``write_divisor`` on ``out``."""
    packed_params = make_params(divisor)
    kernel_call = write_divisor(out, packed_params)
    # One block of 32 threads on a 1x1x1 grid.
    kernel_call.launch(grid=(1, 1, 1), block=(32, 1, 1))


def test_fast_divmod_divisor_inside_params_compiles():
    """Compile ``entry`` with a fake one-element Int32 tensor and an Int32 divisor."""
    fake_out = make_fake_tensor(
        cutlass.Int32,
        (1,),
        stride=(1,),
        assumed_align=4,
    )
    sample_divisor = Int32(0)
    # The test passes as long as compilation completes without raising.
    cute.compile(entry, fake_out, sample_divisor)