Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 24 additions & 61 deletions python/CuTeDSL/cutlass/base_dsl/arch.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,59 +10,12 @@
# is strictly prohibited.

from collections.abc import Callable
from enum import Enum, EnumMeta
import re
from typing import Any


class ArchMeta(EnumMeta):
"""
Custom metaclass for Arch enum that supports dynamic aliases based on CUDA version.

- If cuda_version >= 13.0: sm_101/sm_101a/sm_101f are aliases of sm_110/sm_110a/sm_110f, use sm_110 as the canonical name
- Otherwise: sm_110/sm_110a/sm_110f are aliases of sm_101/sm_101a/sm_101f, use sm_101 as the canonical name
"""

_arch_aliases: dict[str, str] = {}

def __new__(
mcs, name: str, bases: tuple[type, ...], namespace: dict[str, Any]
) -> "ArchMeta":
cls = super().__new__(mcs, name, bases, namespace) # type: ignore[arg-type]
from .version_info import CUDA_VERSION

if CUDA_VERSION.major >= 13:
# sm_101 -> sm_110, use sm_110 as the canonical name
mcs._arch_aliases = {
"sm_101": "sm_110",
"sm_101a": "sm_110a",
"sm_101f": "sm_110f",
}
else:
# sm_110 -> sm_101, use sm_101 as the canonical name
mcs._arch_aliases = {
"sm_110": "sm_101",
"sm_110a": "sm_101a",
"sm_110f": "sm_101f",
}
return cls

def __getattribute__(cls, name: str) -> Any:
# Use type.__getattribute__ to avoid recursion when accessing _arch_aliases
aliases = type.__getattribute__(cls, "_arch_aliases")
if name in aliases:
# Redirect to the target member
return type.__getattribute__(cls, aliases[name])
return super().__getattribute__(name)

def __getitem__(cls, name: str) -> "Arch": # type: ignore[override]
# Support Arch["sm_101"] style access
if name in cls._arch_aliases:
return super().__getitem__(cls._arch_aliases[name])
return super().__getitem__(name)


class Arch(Enum, metaclass=ArchMeta):
from enum import Enum

from .version_info import CUDA_VERSION


class Arch(Enum):
# sm_arch = (major, minor, suffix)
# Ampere
sm_80 = (8, 0, "")
Expand All @@ -77,21 +30,30 @@ class Arch(Enum, metaclass=ArchMeta):
sm_100 = (10, 0, "")
sm_100a = (10, 0, "a")
sm_100f = (10, 0, "f")
sm_101 = (10, 1, "")
sm_101a = (10, 1, "a")
sm_101f = (10, 1, "f")
if CUDA_VERSION.major >= 13:
sm_110 = (11, 0, "")
sm_110a = (11, 0, "a")
sm_110f = (11, 0, "f")
sm_101 = sm_110
sm_101a = sm_110a
sm_101f = sm_110f
else:
sm_101 = (10, 1, "")
sm_101a = (10, 1, "a")
sm_101f = (10, 1, "f")
sm_110 = sm_101
sm_110a = sm_101a
sm_110f = sm_101f
sm_103 = (10, 3, "")
sm_103a = (10, 3, "a")
sm_103f = (10, 3, "f")
sm_110 = (11, 0, "")
sm_110a = (11, 0, "a")
sm_110f = (11, 0, "f")
sm_120 = (12, 0, "")
sm_120a = (12, 0, "a")
sm_120f = (12, 0, "f")
sm_121 = (12, 1, "")
sm_121a = (12, 1, "a")
sm_121f = (12, 1, "f")

def __init__(self, major: int, minor: int, suffix: str) -> None:
self.major = major
self.minor = minor
Expand All @@ -112,7 +74,7 @@ def HopperArchs(cls) -> tuple["Arch", ...]:

@classmethod
def BlackwellArchs(cls) -> tuple["Arch", ...]:
return (
archs = (
Arch.sm_100,
Arch.sm_100a,
Arch.sm_100f,
Expand All @@ -132,6 +94,7 @@ def BlackwellArchs(cls) -> tuple["Arch", ...]:
Arch.sm_121a,
Arch.sm_121f,
)
return tuple(dict.fromkeys(archs))

def __str__(self) -> str:
return self.name
Expand Down Expand Up @@ -168,7 +131,7 @@ def is_family_of(self, arch: "Arch") -> bool:
"""
# sm_101 is renamed to sm_110, sm_101f is family of sm_110f, but is not family of sm_100f
if self in [Arch.sm_101a, Arch.sm_101f]:
return arch.major == 11 and arch.minor >= 0
return arch in [Arch.sm_101, Arch.sm_101a, Arch.sm_101f]

return (
self.major == arch.major
Expand Down