diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt index ab27e757c6e57..9c78a57cca3ad 100644 --- a/.ci/docker/ci_commit_pins/triton.txt +++ b/.ci/docker/ci_commit_pins/triton.txt @@ -1 +1 @@ -b4e20bbe55617cc798b986c2555a2bc7b303c737 +9c610c781cb810a11bfcc9accba094550b189a5e diff --git a/.ci/docker/common/install_triton.sh b/.ci/docker/common/install_triton.sh index 1b68e3c247839..c68bef68b473c 100755 --- a/.ci/docker/common/install_triton.sh +++ b/.ci/docker/common/install_triton.sh @@ -21,12 +21,16 @@ elif [ -n "${TRITON_CPU}" ]; then TRITON_REPO="https://github.com/triton-lang/triton-cpu" TRITON_TEXT_FILE="triton-cpu" else - TRITON_REPO="https://github.com/triton-lang/triton" + TRITON_REPO="https://github.com/ROCm/triton" TRITON_TEXT_FILE="triton" fi # The logic here is copied from .ci/pytorch/common_utils.sh -TRITON_PINNED_COMMIT=$(get_pinned_commit ${TRITON_TEXT_FILE}) +if [ "${TRITON_TEXT_FILE}" = "triton" ]; then + TRITON_REF="release/internal/3.7.x" +else + TRITON_REF=$(get_pinned_commit ${TRITON_TEXT_FILE}) +fi if [ -n "${UBUNTU_VERSION}" ];then apt update @@ -49,7 +53,7 @@ pushd /var/lib/jenkins/ as_jenkins git clone --recursive ${TRITON_REPO} triton cd triton -as_jenkins git checkout ${TRITON_PINNED_COMMIT} +as_jenkins git checkout ${TRITON_REF} as_jenkins git submodule update --init --recursive # Old versions of python have setup.py in ./python; newer versions have it in ./ diff --git a/.github/scripts/build_triton_wheel.py b/.github/scripts/build_triton_wheel.py index de0fa7739edc5..59f17b28bde38 100644 --- a/.github/scripts/build_triton_wheel.py +++ b/.github/scripts/build_triton_wheel.py @@ -70,7 +70,7 @@ def build_triton( triton_basedir = Path(tmpdir) / "triton" triton_pythondir = triton_basedir / "python" - triton_repo = "https://github.com/openai/triton" + triton_repo = "https://github.com/ROCm/triton" if device == "rocm": triton_pkg_name = "triton-rocm" elif device == "xpu": @@ -79,21 +79,22 @@ def build_triton( else: triton_pkg_name = "triton" check_call(["git", "clone", triton_repo, "triton"], cwd=tmpdir) - if release: + if device == "xpu" and release: ver, rev, patch = version.split(".") - if device == "xpu": - # XPU uses the patch version in the release branch name - check_call( - ["git", "checkout", f"release/{ver}.{rev}.{patch}"], - cwd=triton_basedir, - ) - else: - check_call( - ["git", "checkout", f"release/{ver}.{rev}.x"], cwd=triton_basedir - ) - else: + # XPU uses the patch version in the release branch name + check_call( + ["git", "checkout", f"release/{ver}.{rev}.{patch}"], + cwd=triton_basedir, + ) + elif device == "xpu": check_call(["git", "fetch", "origin", commit_hash], cwd=triton_basedir) check_call(["git", "checkout", commit_hash], cwd=triton_basedir) + else: + ver, rev, _patch = version.split(".") + check_call( + ["git", "checkout", f"release/internal/{ver}.{rev}.x"], + cwd=triton_basedir, + ) # change built wheel name and version env["TRITON_WHEEL_NAME"] = triton_pkg_name