Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .ci/docker/ci_commit_pins/triton.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
b4e20bbe55617cc798b986c2555a2bc7b303c737
9c610c781cb810a11bfcc9accba094550b189a5e
10 changes: 7 additions & 3 deletions .ci/docker/common/install_triton.sh
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,16 @@ elif [ -n "${TRITON_CPU}" ]; then
TRITON_REPO="https://github.com/triton-lang/triton-cpu"
TRITON_TEXT_FILE="triton-cpu"
else
TRITON_REPO="https://github.com/triton-lang/triton"
TRITON_REPO="https://github.com/ROCm/triton"
TRITON_TEXT_FILE="triton"
fi

# The logic here is copied from .ci/pytorch/common_utils.sh
TRITON_PINNED_COMMIT=$(get_pinned_commit ${TRITON_TEXT_FILE})
if [ "${TRITON_TEXT_FILE}" = "triton" ]; then
TRITON_REF="release/internal/3.7.x"
else
TRITON_REF=$(get_pinned_commit ${TRITON_TEXT_FILE})
fi

if [ -n "${UBUNTU_VERSION}" ];then
apt update
Expand All @@ -49,7 +53,7 @@ pushd /var/lib/jenkins/

as_jenkins git clone --recursive ${TRITON_REPO} triton
cd triton
as_jenkins git checkout ${TRITON_PINNED_COMMIT}
as_jenkins git checkout ${TRITON_REF}
as_jenkins git submodule update --init --recursive

# Old versions of python have setup.py in ./python; newer versions have it in ./
Expand Down
27 changes: 14 additions & 13 deletions .github/scripts/build_triton_wheel.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def build_triton(
triton_basedir = Path(tmpdir) / "triton"
triton_pythondir = triton_basedir / "python"

triton_repo = "https://github.com/openai/triton"
triton_repo = "https://github.com/ROCm/triton"
if device == "rocm":
triton_pkg_name = "triton-rocm"
elif device == "xpu":
Expand All @@ -79,21 +79,22 @@ def build_triton(
else:
triton_pkg_name = "triton"
check_call(["git", "clone", triton_repo, "triton"], cwd=tmpdir)
if release:
if device == "xpu" and release:
ver, rev, patch = version.split(".")
if device == "xpu":
# XPU uses the patch version in the release branch name
check_call(
["git", "checkout", f"release/{ver}.{rev}.{patch}"],
cwd=triton_basedir,
)
else:
check_call(
["git", "checkout", f"release/{ver}.{rev}.x"], cwd=triton_basedir
)
else:
# XPU uses the patch version in the release branch name
check_call(
["git", "checkout", f"release/{ver}.{rev}.{patch}"],
cwd=triton_basedir,
)
elif device == "xpu":
check_call(["git", "fetch", "origin", commit_hash], cwd=triton_basedir)
check_call(["git", "checkout", commit_hash], cwd=triton_basedir)
else:
ver, rev, _patch = version.split(".")
check_call(
["git", "checkout", f"release/internal/{ver}.{rev}.x"],
cwd=triton_basedir,
)

# change built wheel name and version
env["TRITON_WHEEL_NAME"] = triton_pkg_name
Expand Down