diff --git a/python/CuTeDSL/cutlass/base_dsl/arch.py b/python/CuTeDSL/cutlass/base_dsl/arch.py index 1740486340..36fadf39d3 100644 --- a/python/CuTeDSL/cutlass/base_dsl/arch.py +++ b/python/CuTeDSL/cutlass/base_dsl/arch.py @@ -10,59 +10,12 @@ # is strictly prohibited. from collections.abc import Callable -from enum import Enum, EnumMeta -import re -from typing import Any - - -class ArchMeta(EnumMeta): - """ - Custom metaclass for Arch enum that supports dynamic aliases based on CUDA version. - - - If cuda_version >= 13.0: sm_101/sm_101a/sm_101f are aliases of sm_110/sm_110a/sm_110f, use sm_110 as the canonical name - - Otherwise: sm_110/sm_110a/sm_110f are aliases of sm_101/sm_101a/sm_101f, use sm_101 as the canonical name - """ - - _arch_aliases: dict[str, str] = {} - - def __new__( - mcs, name: str, bases: tuple[type, ...], namespace: dict[str, Any] - ) -> "ArchMeta": - cls = super().__new__(mcs, name, bases, namespace) # type: ignore[arg-type] - from .version_info import CUDA_VERSION - - if CUDA_VERSION.major >= 13: - # sm_101 -> sm_110, use sm_110 as the canonical name - mcs._arch_aliases = { - "sm_101": "sm_110", - "sm_101a": "sm_110a", - "sm_101f": "sm_110f", - } - else: - # sm_110 -> sm_101, use sm_101 as the canonical name - mcs._arch_aliases = { - "sm_110": "sm_101", - "sm_110a": "sm_101a", - "sm_110f": "sm_101f", - } - return cls - - def __getattribute__(cls, name: str) -> Any: - # Use type.__getattribute__ to avoid recursion when accessing _arch_aliases - aliases = type.__getattribute__(cls, "_arch_aliases") - if name in aliases: - # Redirect to the target member - return type.__getattribute__(cls, aliases[name]) - return super().__getattribute__(name) - - def __getitem__(cls, name: str) -> "Arch": # type: ignore[override] - # Support Arch["sm_101"] style access - if name in cls._arch_aliases: - return super().__getitem__(cls._arch_aliases[name]) - return super().__getitem__(name) - - -class Arch(Enum, metaclass=ArchMeta): +from enum import Enum + +from .version_info import CUDA_VERSION + + +class Arch(Enum): # sm_arch = (major, minor, suffix) # Ampere sm_80 = (8, 0, "") @@ -77,21 +30,30 @@ class Arch(Enum, metaclass=ArchMeta): sm_100 = (10, 0, "") sm_100a = (10, 0, "a") sm_100f = (10, 0, "f") - sm_101 = (10, 1, "") - sm_101a = (10, 1, "a") - sm_101f = (10, 1, "f") + if CUDA_VERSION.major >= 13: + sm_110 = (11, 0, "") + sm_110a = (11, 0, "a") + sm_110f = (11, 0, "f") + sm_101 = sm_110 + sm_101a = sm_110a + sm_101f = sm_110f + else: + sm_101 = (10, 1, "") + sm_101a = (10, 1, "a") + sm_101f = (10, 1, "f") + sm_110 = sm_101 + sm_110a = sm_101a + sm_110f = sm_101f sm_103 = (10, 3, "") sm_103a = (10, 3, "a") sm_103f = (10, 3, "f") - sm_110 = (11, 0, "") - sm_110a = (11, 0, "a") - sm_110f = (11, 0, "f") sm_120 = (12, 0, "") sm_120a = (12, 0, "a") sm_120f = (12, 0, "f") sm_121 = (12, 1, "") sm_121a = (12, 1, "a") sm_121f = (12, 1, "f") + def __init__(self, major: int, minor: int, suffix: str) -> None: self.major = major self.minor = minor @@ -112,7 +74,7 @@ def HopperArchs(cls) -> tuple["Arch", ...]: @classmethod def BlackwellArchs(cls) -> tuple["Arch", ...]: - return ( + archs = ( Arch.sm_100, Arch.sm_100a, Arch.sm_100f, @@ -132,6 +94,7 @@ def BlackwellArchs(cls) -> tuple["Arch", ...]: Arch.sm_121a, Arch.sm_121f, ) + return tuple(dict.fromkeys(archs)) def __str__(self) -> str: return self.name @@ -168,7 +131,7 @@ def is_family_of(self, arch: "Arch") -> bool: """ # sm_101 is renamed to sm_110, sm_101f is family of sm_110f, but is not family of sm_100f if self in [Arch.sm_101a, Arch.sm_101f]: - return arch.major == 11 and arch.minor >= 0 + return arch in [Arch.sm_101, Arch.sm_101a, Arch.sm_101f] return ( self.major == arch.major