diff --git a/.automation_scripts/parse_xml_results.py b/.automation_scripts/parse_xml_results.py
new file mode 100644
index 000000000000..7db2e1ce9233
--- /dev/null
+++ b/.automation_scripts/parse_xml_results.py
@@ -0,0 +1,178 @@
+""" The Python PyTorch testing script.
+##
+# Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+"""
+
+import xml.etree.ElementTree as ET
+from pathlib import Path
+from typing import Any, Dict, Tuple
+
+# Backends list
+BACKENDS_LIST = [
+ "dist-gloo",
+ "dist-nccl"
+]
+
+TARGET_WORKFLOW = "--rerun-disabled-tests"
+
+def get_job_id(report: Path) -> int:
+ # [Job id in artifacts]
+ # Retrieve the job id from the report path. In our GHA workflows, we append
+ # the job id to the end of the report name, so `report` looks like:
+ # unzipped-test-reports-foo_5596745227/test/test-reports/foo/TEST-foo.xml
+ # and we want to get `5596745227` out of it.
+ try:
+ return int(report.parts[0].rpartition("_")[2])
+ except ValueError:
+ return -1
+
+def is_rerun_disabled_tests(root: ET.ElementTree) -> bool:
+ """
+ Check if the test report is coming from rerun_disabled_tests workflow
+ """
+ skipped = root.find(".//*skipped")
+ # Need to check against None here, if not skipped doesn't work as expected
+ if skipped is None:
+ return False
+
+ message = skipped.attrib.get("message", "")
+ return TARGET_WORKFLOW in message or "num_red" in message
+
+def parse_xml_report(
+ tag: str,
+ report: Path,
+ workflow_id: int,
+ workflow_run_attempt: int,
+ work_flow_name: str
+) -> Dict[Tuple[str], Dict[str, Any]]:
+ """Convert a test report xml file into a JSON-serializable list of test cases."""
+ print(f"Parsing {tag}s for test report: {report}")
+
+ job_id = get_job_id(report)
+ print(f"Found job id: {job_id}")
+
+ test_cases: Dict[Tuple[str], Dict[str, Any]] = {}
+
+ root = ET.parse(report)
+ # TODO: unlike unittest, pytest-flakefinder used by rerun disabled tests for test_ops
+ # includes skipped messages multiple times (50 times by default). This slows down
+ # this script too much (O(n)) because it tries to gather all the stats. This should
+ # be fixed later in the way we use pytest-flakefinder. A zipped test report from rerun
+ # disabled test is only few MB, but will balloon up to a much bigger XML file after
+ # extracting from a dozen to few hundred MB
+ if is_rerun_disabled_tests(root):
+ return test_cases
+
+ for test_case in root.iter(tag):
+ case = process_xml_element(test_case)
+ if tag == 'testcase':
+ case["workflow_id"] = workflow_id
+ case["workflow_run_attempt"] = workflow_run_attempt
+ case["job_id"] = job_id
+ case["work_flow_name"] = work_flow_name
+
+ # [invoking file]
+ # The name of the file that the test is located in is not necessarily
+ # the same as the name of the file that invoked the test.
+ # For example, `test_jit.py` calls into multiple other test files (e.g.
+ # jit/test_dce.py). For sharding/test selection purposes, we want to
+ # record the file that invoked the test.
+ #
+ # To do this, we leverage an implementation detail of how we write out
+ # tests (https://bit.ly/3ajEV1M), which is that reports are created
+ # under a folder with the same name as the invoking file.
+ case_name = report.parent.name
+ for ind in range(len(BACKENDS_LIST)):
+ if BACKENDS_LIST[ind] in report.parts:
+ case_name = case_name + "_" + BACKENDS_LIST[ind]
+ break
+ case["invoking_file"] = case_name
+ test_cases[ ( case["invoking_file"], case["classname"], case["name"], case["work_flow_name"] ) ] = case
+ elif tag == 'testsuite':
+ case["work_flow_name"] = work_flow_name
+ case["invoking_xml"] = report.name
+ case["running_time_xml"] = case["time"]
+ case_name = report.parent.name
+ for ind in range(len(BACKENDS_LIST)):
+ if BACKENDS_LIST[ind] in report.parts:
+ case_name = case_name + "_" + BACKENDS_LIST[ind]
+ break
+ case["invoking_file"] = case_name
+
+ test_cases[ ( case["invoking_file"], case["invoking_xml"], case["work_flow_name"] ) ] = case
+
+ return test_cases
+
+def process_xml_element(element: ET.Element) -> Dict[str, Any]:
+ """Convert a test suite element into a JSON-serializable dict."""
+ ret: Dict[str, Any] = {}
+
+ # Convert attributes directly into dict elements.
+ # e.g.
+ #
+ # becomes:
+ # {"name": "test_foo", "classname": "test_bar"}
+ ret.update(element.attrib)
+
+ # The XML format encodes all values as strings. Convert to ints/floats if
+ # possible to make aggregation possible in Rockset.
+ for k, v in ret.items():
+ try:
+ ret[k] = int(v)
+ except ValueError:
+ pass
+ try:
+ ret[k] = float(v)
+ except ValueError:
+ pass
+
+ # Convert inner and outer text into special dict elements.
+ # e.g.
+ # my_inner_text my_tail
+ # becomes:
+ # {"text": "my_inner_text", "tail": " my_tail"}
+ if element.text and element.text.strip():
+ ret["text"] = element.text
+ if element.tail and element.tail.strip():
+ ret["tail"] = element.tail
+
+ # Convert child elements recursively, placing them at a key:
+ # e.g.
+ #
+ # hello
+ # world
+ # another
+ #
+ # becomes
+ # {
+ # "foo": [{"text": "hello"}, {"text": "world"}],
+ # "bar": {"text": "another"}
+ # }
+ for child in element:
+ if child.tag not in ret:
+ ret[child.tag] = process_xml_element(child)
+ else:
+ # If there are multiple tags with the same name, they should be
+ # coalesced into a list.
+ if not isinstance(ret[child.tag], list):
+ ret[child.tag] = [ret[child.tag]]
+ ret[child.tag].append(process_xml_element(child))
+ return ret
\ No newline at end of file
diff --git a/.automation_scripts/pytorch-unit-test-scripts/auto_classify_skip_reasons.py b/.automation_scripts/pytorch-unit-test-scripts/auto_classify_skip_reasons.py
new file mode 100644
index 000000000000..cf948495ec04
--- /dev/null
+++ b/.automation_scripts/pytorch-unit-test-scripts/auto_classify_skip_reasons.py
@@ -0,0 +1,1027 @@
+#!/usr/bin/env python3
+"""
+Auto-classify skip reasons for ROCm parity CSV tests.
+
+Takes a parity CSV (output of summarize_xml_testreports.py) and automatically
+assigns skip_reason categories to tests where ROCm=SKIPPED/MISSED and CUDA=PASSED
+based on patterns in:
+ - The skip message (message_rocm column)
+ - The test file name
+ - The test class name
+ - The test name
+
+Rules are ordered by specificity: combined match rules first, then message-based,
+then file+class combos, then file-only fallbacks. First matching rule wins.
+
+Usage:
+ python auto_classify_skip_reasons.py -i input.csv -o output.csv [--report]
+ python auto_classify_skip_reasons.py -i input.csv -o output.csv --tsv-out updated_skip_reasons.tsv
+ python auto_classify_skip_reasons.py -i input.csv --dry-run --report
+"""
+
+import argparse
+import ast
+import csv
+import re
+import sys
+from collections import Counter, defaultdict
+
+
+# ---------------------------------------------------------------------------
+# Rules are evaluated top-to-bottom; first match wins.
+# Each rule is a dict with:
+# reason: the skip_reason category string
+# msg: (optional) regex to match against the skip message
+# file: (optional) regex to match against test_file
+# cls: (optional) regex to match against test_class
+# name: (optional) regex to match against test_name
+# workflow: (optional) one of "default", "distributed", "inductor"
+#
+# All provided fields must match (AND logic). Omitted fields match anything.
+# msg="" matches empty messages; omitting msg matches anything.
+# ---------------------------------------------------------------------------
+
+RULES = [
+ # ==================================================================
+ # TIER 1: High-specificity combined rules (message + file/class)
+ # ==================================================================
+
+ # --- bfloat16_SDPA_ME: dropout mask in test_transformers with bfloat16 in TEST NAME ---
+ # Must be before generic SDPA_ME rule
+ {"reason": "bfloat16_SDPA_ME",
+ "msg": r"_fill_mem_eff_dropout_mask",
+ "file": r"^test_transformers$",
+ "name": r"(?i)bfloat16|bf16"},
+
+ # --- GEMMS: test_mm_bmm in test_matmul_cuda with accuracy regression ---
+ # Must be before generic hipblas rule
+ {"reason": "GEMMS",
+ "msg": r"accuracy regression in hipblas",
+ "file": r"^test_matmul_cuda$",
+ "name": r"test_mm_bmm"},
+
+ # --- hipblas hipblaslt: test_addmm/test_cublas/other in test_matmul_cuda ---
+ {"reason": "hipblas hipblaslt",
+ "msg": r"accuracy regression in hipblas",
+ "file": r"^test_matmul_cuda$"},
+ {"reason": "hipblas hipblaslt",
+ "msg": r"skipIfRocm.*doesn't currently work",
+ "file": r"^test_matmul_cuda$"},
+ {"reason": "hipblas hipblaslt",
+ "file": r"^test_matmul_cuda$",
+ "msg": r"Green contexts are not supported"},
+
+ # --- Expected to work: skipCUDAIfRocm in test_meta for ldl_solve ops ---
+ {"reason": "Expected to work",
+ "msg": r"skipCUDAIfRocm.*doesn't currently work",
+ "file": r"^test_meta$",
+ "name": r"(?i)ldl_solve"},
+
+ # --- Linalg: skipCUDAIfRocm in test_meta for other linalg ops ---
+ {"reason": "Linalg",
+ "msg": r"skipCUDAIfRocm.*doesn't currently work",
+ "file": r"^test_meta$"},
+
+ # --- Linalg: skipCUDAIfRocm in test_ops/test_linalg/test_meta/test_ops_fwd_gradients/test_ops_gradients ---
+ # These are ops like linalg.svd, linalg.eigh, etc.
+ {"reason": "Linalg",
+ "msg": r"skipCUDAIfRocm.*doesn't currently work",
+ "file": r"^test_linalg$"},
+ {"reason": "Linalg",
+ "msg": r"_convert_weight_to_int4pack_cuda.*(supported only for|is supported only for) CDNA"},
+ {"reason": "Linalg",
+ "msg": r"bfloat16 NCHW train failed"},
+ {"reason": "Linalg",
+ "msg": r"skipCUDAIfRocm.*doesn't currently work",
+ "file": r"^test_ops$",
+ "name": r"(?i)linalg|svd|eig[hs]?|cholesky|lstsq|solve|inv|det|qr|lu|pinv|matrix_rank|cross|norm|cond|householder|ormqr|geqrf|triangular|vecdot|multi_dot"},
+ {"reason": "Linalg",
+ "msg": r"skipCUDAIfRocm.*doesn't currently work",
+ "file": r"^test_ops_fwd_gradients$"},
+ {"reason": "Linalg",
+ "msg": r"skipCUDAIfRocm.*doesn't currently work",
+ "file": r"^test_ops_gradients$",
+ "name": r"(?i)linalg|svd|eig[hs]?|cholesky|lstsq|solve|inv|det|qr|lu|pinv|householder|ormqr|geqrf|triangular"},
+ {"reason": "Linalg",
+ "msg": r"skipCUDAIfRocm.*doesn't currently work",
+ "file": r"^test_meta$",
+ "name": r"(?i)linalg|svd|eig[hs]?|cholesky|lstsq|solve|inv|det|qr|lu|pinv|householder|ormqr|geqrf|triangular"},
+ {"reason": "Linalg",
+ "file": r"^test_nn$",
+ "msg": r"skipIfRocm.*doesn't currently work"},
+
+ # --- hipSolver/Magma: skipCUDAIfRocm in test_ops for ldl_solve, scaled_dot_product, conv_transpose3d ---
+ {"reason": "hipSolver/Magma",
+ "msg": r"skipCUDAIfRocm.*doesn't currently work",
+ "file": r"^test_ops$",
+ "name": r"(?i)ldl_solve|scaled_dot_product|conv_transpose3d"},
+ {"reason": "hipSolver/Magma",
+ "msg": r"skipCUDAIfRocm.*doesn't currently work",
+ "file": r"^test_ops_jit$"},
+ {"reason": "hipSolver/Magma",
+ "msg": r"skipCUDAIfRocm.*doesn't currently work",
+ "file": r"^test_decomp$"},
+ {"reason": "hipSolver/Magma",
+ "msg": r"skipCUDAIfRocm.*doesn't currently work",
+ "file": r"^test_schema_check$"},
+ {"reason": "hipSolver/Magma",
+ "msg": r"skipCUDAIfRocm.*doesn't currently work",
+ "file": r"^test_testing$"},
+ {"reason": "hipSolver/Magma",
+ "msg": r"Skipped for ROCm!"},
+ {"reason": "hipSolver/Magma",
+ "msg": r"test_cow_input does not work with efficient attention on ROCM"},
+
+ # --- Compiler issue: "Skipped!" in test_ops for specific compiler-related tests ---
+ {"reason": "Compiler issue",
+ "msg": r"^Skipped!$",
+ "file": r"^test_ops$",
+ "name": r"(?i)special_hermite_polynomial_h|special_laguerre"},
+
+ # --- non-standard bool: "Skipped!" in test_ops for bool-related tests ---
+ {"reason": "non-standard bool",
+ "msg": r"^Skipped!$",
+ "file": r"^test_ops$",
+ "name": r"(?i)bool"},
+
+ # --- pow: "Skipped!" in test_ops/test_decomp for pow tests ---
+ {"reason": "pow",
+ "msg": r"^Skipped!$",
+ "file": r"^test_ops$|^test_decomp$",
+ "name": r"(?i)^pow$|_pow_|float_power"},
+
+ # --- fft: "Skipped!" or "Skipped on ROCm" in test_ops for fft tests ---
+ {"reason": "fft",
+ "msg": r"^Skipped(!| on ROCm)$",
+ "file": r"^test_ops$",
+ "name": r"(?i)fft"},
+
+ # --- NHWC: "Skipped!" in test_modules for NHWC tests ---
+ {"reason": "NHWC",
+ "msg": r"^Skipped!$",
+ "file": r"^test_modules$"},
+
+ # (FakeTensor removed — "Requires CUDA" messages are explicit NVIDIA test per policy)
+
+ # --- hermite_polynomial_h: custom_mask_type in test_ops for hermite ---
+ {"reason": "hermite_polynomial_h",
+ "msg": r"Efficient attention on ROCM doesn't support custom_mask_type",
+ "file": r"^test_ops$",
+ "name": r"(?i)hermite"},
+
+ # --- fake_crossref: skipCUDAIfRocm in test_ops for crossref tests ---
+ {"reason": "fake_crossref",
+ "msg": r"skipCUDAIfRocm.*doesn't currently work",
+ "file": r"^test_ops$",
+ "name": r"(?i)crossref|fake_crossref"},
+
+ # --- Jit: Tensor-likes not close in test_jit_fuser ---
+ {"reason": "Jit",
+ "msg": r"Tensor-likes are not close",
+ "file": r"test_jit_fuser"},
+
+ # --- Memory allocation: TestBlockStateAbsorption in test_cuda ---
+ {"reason": "Memory allocation",
+ "file": r"^test_cuda$",
+ "cls": r"^TestBlockStateAbsorption$"},
+
+ # --- cuda allocator: TestCudaAllocator in test_cuda ---
+ {"reason": "cuda allocator",
+ "file": r"^test_cuda$",
+ "cls": r"^TestCudaAllocator$"},
+
+ # --- hipGraph/cudaGraph: CudaGraph-related classes in test_cuda ---
+ {"reason": "hipGraph/cudaGraph",
+ "file": r"^test_cuda$",
+ "cls": r"CachingHostAllocatorCudaGraph|GreenContext"},
+
+ # --- Memory allocation: TestMemPool in test_cuda ---
+ {"reason": "Memory allocation",
+ "file": r"^test_cuda$",
+ "cls": r"^TestMemPool$"},
+
+ # --- Profiler: TestFXMemoryProfiler in test_cuda ---
+ {"reason": "Profiler",
+ "file": r"^test_cuda$",
+ "cls": r"FXMemoryProfiler"},
+
+ # --- compiled optimizer: ROCm numerical behavior in inductor.test_compiled_optimizers ---
+ {"reason": "compiled optimizer",
+ "msg": r"ROCm may have different numerical behavior",
+ "file": r"inductor\.test_compiled_optimizers"},
+
+ # --- functorch: FuncTorch classes in inductor.test_compiled_autograd ---
+ {"reason": "functorch",
+ "file": r"^inductor\.test_compiled_autograd$",
+ "cls": r"FuncTorch"},
+
+ # --- PT2.0 - Distributed: DTensor classes in inductor.test_compiled_autograd ---
+ {"reason": "PT2.0 - Distributed",
+ "file": r"^inductor\.test_compiled_autograd$",
+ "cls": r"DTensor"},
+
+ # --- hipdnn: cudnn Attention messages ---
+ {"reason": "hipdnn",
+ "msg": r"[Cc]u[Dd][Nn][Nn] Attention is not supported"},
+ {"reason": "hipdnn",
+ "msg": r"Efficient or cuDNN Attention was not built"},
+
+ # --- Will not be supported on ROCm: test_transformers with (no message) ---
+ {"reason": "Will not be supported on ROCm",
+ "file": r"^test_transformers$",
+ "cls": r"SDPA.*CUDA",
+ "msg": r"^$"},
+
+ # --- transformers: test_transformers / test_flop_counter with misc messages ---
+ {"reason": "transformers",
+ "file": r"^test_transformers$",
+ "msg": r"Does not support all SDPA backends"},
+ {"reason": "transformers",
+ "file": r"^test_flop_counter$"},
+
+ # --- bfloat16: test_sparse_csr with (no message) ---
+ {"reason": "bfloat16",
+ "file": r"^test_sparse_csr$",
+ "cls": r"[Bb]float16|bf16"},
+ {"reason": "bfloat16",
+ "file": r"^test_sparse$",
+ "cls": r"[Bb]float16|bf16"},
+ {"reason": "bfloat16",
+ "file": r"^test_matmul_cuda$",
+ "msg": r"ROCm doesn't support CUTLASS"},
+
+ # --- explicit NVIDIA test: test_sparse_semi_structured with cutlass in NAME ---
+ {"reason": "explicit NVIDIA test",
+ "file": r"^test_sparse_semi_structured$",
+ "name": r"(?i)cutlass"},
+
+ # --- cusparselt: everything else in test_sparse_semi_structured ---
+ {"reason": "cusparselt",
+ "file": r"^test_sparse_semi_structured$"},
+
+ # --- Quantization: distributed quantization tests ---
+ {"reason": "Quantization",
+ "msg": r"Test skipped for ROCm",
+ "file": r"distributed\.algorithms\.quantization"},
+
+ # --- Process Group: distributed spawn/c10d with "Test skipped for ROCm" ---
+ {"reason": "Process Group",
+ "msg": r"Test skipped for ROCm",
+ "file": r"distributed\.test_distributed_spawn.*nccl"},
+
+ # ==================================================================
+ # TIER 2: Message-based rules (strong signal from skip message)
+ # ==================================================================
+
+ # SDPA_ME
+ {"reason": "SDPA_ME",
+ "msg": r"_fill_mem_eff_dropout_mask"},
+ {"reason": "SDPA_ME",
+ "msg": r"Efficient attention on ROCM doesn't support custom_mask_type"},
+ {"reason": "SDPA_ME",
+ "msg": r"Efficient Attention on ROCM does not support head_dim"},
+
+ # SDPA_FA
+ {"reason": "SDPA_FA",
+ "msg": r"Large numerical errors on ROCM"},
+ {"reason": "SDPA_FA",
+ "msg": r"flash attention not supported"},
+
+ # Will not be supported on ROCm
+ {"reason": "Will not be supported on ROCm",
+ "msg": r"head_dim != head_dim_v unsupported on ROCm"},
+
+ # Triton 3.7 bump
+ {"reason": "triton 3.7 bump",
+ "msg": r"skipIfRocm.*Fails with Triton 3\.7"},
+
+ # MIOpen
+ {"reason": "MIOpen Convolutions",
+ "msg": r"Marked as skipped for MIOpen"},
+
+ # Static CUDA launcher
+ {"reason": "static cuda launcher",
+ "msg": r"Static cuda launcher doesn't work with ROCM"},
+
+ # NUMBA
+ {"reason": "NUMBA",
+ "msg": r"No numba\.cuda"},
+
+ # int4
+ {"reason": "int4",
+ "msg": r"_int4_mm is supported only for CDNA"},
+
+ # FP8
+ {"reason": "FP8",
+ "msg": r"cuBLAS blockwise scaling"},
+
+ # variable length attention
+ {"reason": "variable length attention",
+ "msg": r"ROCm does not support seqused_k"},
+
+ # CUDA IPC
+ {"reason": "Pass with unskip or minor mod",
+ "msg": r"CUDA IPC not available"},
+
+ # Python version
+ {"reason": "Python version",
+ "msg": r"Not supported in Python 3\.1[0-9]+"},
+
+ # cpp_test / CUDA not found
+ {"reason": "cpp_test",
+ "msg": r"CUDA not found"},
+ {"reason": "cpp_test",
+ "msg": r"CUDA_HOME not set"},
+
+ # Foreach
+ {"reason": "Foreach",
+ "msg": r"failed starting on ROCm"},
+
+ # CUTLASS
+ {"reason": "cutlass",
+ "msg": r"ROCm doesn't support CUTLASS|CUTLASS backend is not supported on HIP|ROCm and Windows doesn't support CUTLASS"},
+
+ # Transformers dependency
+ {"reason": "transformers",
+ "msg": r"No transformers"},
+
+ # hipGraph / cudaGraph (but NOT in functorch files -- those stay functorch)
+ {"reason": "hipGraph/cudaGraph",
+ "msg": r"Green contexts are not supported"},
+ {"reason": "functorch",
+ "msg": r"CUDA 12\.4 or greater is required for CUDA Graphs",
+ "file": r"^functorch\."},
+ {"reason": "hipGraph/cudaGraph",
+ "msg": r"CUDA 12\.4 or greater is required for CUDA Graphs"},
+ {"reason": "hipGraph/cudaGraph",
+ "msg": r"ROCM >= 5\.3 required for graphs.*cuda-bindings"},
+
+ # TMA / Blackwell
+ {"reason": "Will not be supported on ROCm",
+ "msg": r"Need.*TMA support"},
+ {"reason": "Will not be supported on ROCm",
+ "msg": r"Need Blackwell"},
+
+ # CUDA SM requirements
+ {"reason": "explicit NVIDIA test",
+ "msg": r"Requires CUDA SM >= [0-9]"},
+ {"reason": "explicit NVIDIA test",
+ "msg": r"Requires CUDA with SM >= [0-9]"},
+ {"reason": "explicit NVIDIA test",
+ "msg": r"Test is only supported on CUDA 1[0-9]"},
+ {"reason": "explicit NVIDIA test",
+ "msg": r"Requires NCCL version greater than"},
+ {"reason": "explicit NVIDIA test",
+ "msg": r"Excluded from CUDA tests"},
+
+ # FP8 — MI300+ / H100+ only
+ {"reason": "FP8",
+ "msg": r"FP8 is only supported on H100\+|FP8 is not supported on this platform|FP8 requires H100\+"},
+ {"reason": "FP8",
+ "msg": r"requires gpu with fp8 support"},
+
+ # Symmetric memory
+ {"reason": "Symmetric memory",
+ "msg": r"SymmMem is not supported on this ROCm arch"},
+
+ # Python version / 3.12+
+ {"reason": "Python version",
+ "msg": r"Failing on python 3\.12\+|torch\.compile is not supported on python 3\.12\+|complex flaky in 3\.12"},
+
+ # Greater than 4 GPU (distributed)
+ {"reason": "Greater than 4 GPU",
+ "msg": r"Need at least 4 CUDA devices"},
+ {"reason": "Greater than 4 GPU",
+ "msg": r"Test requires.*world size of 4"},
+ {"reason": "Greater than 4 GPU",
+ "msg": r"requires [34] GPUs, found [12]"},
+
+ # tensor_parallel — architecture-specific skip
+ {"reason": "tensor_parallel",
+ "msg": r"test only runs on \('gfx942'"},
+
+ # Process Group: subprocess level skip
+ {"reason": "Process Group",
+ "msg": r"Test skipped at subprocess level"},
+
+ # Sharded Tensor: subprocess level skip in _shard
+ {"reason": "Sharded Tensor",
+ "msg": r"Test skipped at subprocess level",
+ "file": r"distributed\._shard"},
+
+ # Process Group: NCCL version / device assert
+ {"reason": "Process Group",
+ "msg": r"NCCL test requires 2\+ GPUs"},
+
+ # Misc: ROCm preserves subnormals
+ {"reason": "Misc",
+ "msg": r"ROCm preserves subnormals"},
+
+ # Misc: GCC codegen
+ {"reason": "Misc",
+ "msg": r"Fails under GCC 1[0-9] due to vector codegen"},
+
+ # Misc: Skipped on ROCm due to hang
+ {"reason": "Misc",
+ "msg": r"Skipped on ROCm due to hang"},
+
+ # Misc: Test skipped for ROCm (generic distributed)
+ {"reason": "Misc",
+ "msg": r"Test skipped for ROCm"},
+
+ # Misc: architecture-specific skips
+ {"reason": "Misc",
+ "msg": r"test skipped on \('gfx"},
+
+ # cuFFT-specific
+ {"reason": "Misc",
+ "msg": r"cuFFT-specific"},
+
+ # ROCTracer profiler
+ {"reason": "Memory allocation",
+ "msg": r"ROCTracer does not capture"},
+
+ # expandable_segments-related messages
+ {"reason": "expandable_segments",
+ "msg": r"expandable_segments mode is not supported on ROCm"},
+ {"reason": "expandable_segments",
+ "msg": r"CUDA >= 11\.0 required for external events in cuda graphs.*rocm"},
+
+ # not enabled by default on rocm
+ {"reason": "expandable_segments",
+ "msg": r"not enabled by default on rocm"},
+
+ # HIP runtime context
+ {"reason": "Misc",
+ "msg": r"HIP runtime doesn't create context"},
+
+ # ==================================================================
+ # TIER 3: File + class based rules (for empty/generic messages)
+ # ==================================================================
+
+ # --- test_cuda class-based disambiguation ---
+ {"reason": "Misc",
+ "file": r"^test_cuda$",
+ "cls": r"^TestCuda$"},
+ {"reason": "compiled optimizer",
+ "file": r"^test_cuda$",
+ "cls": r"TestCudaOptims"},
+ {"reason": "Misc",
+ "file": r"^test_cuda$",
+ "cls": r"TestCudaAutocast"},
+ {"reason": "cpp_test",
+ "file": r"^test_cuda$",
+ "cls": r"TestCompileKernel"},
+
+ # --- test_nn (MI200-specific skips, no message) ---
+ {"reason": "Misc",
+ "file": r"^test_nn$"},
+
+ # --- inductor.test_fp8 ---
+ {"reason": "FP8",
+ "file": r"^inductor\.test_fp8$"},
+
+ # --- test_scaled_matmul_cuda ---
+ {"reason": "FP8",
+ "file": r"^test_scaled_matmul_cuda$"},
+
+ # --- inductor.test_torchinductor_strided_blocks ---
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^inductor\.test_torchinductor_strided_blocks$"},
+
+ # --- inductor.test_flex_decoding ---
+ {"reason": "flex_decoding",
+ "file": r"^inductor\.test_flex_decoding$"},
+
+ # --- inductor.test_loop_ordering ---
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^inductor\.test_loop_ordering$"},
+
+ # --- torch_np / numpy tests ---
+ {"reason": "NumPy",
+ "file": r"^torch_np\."},
+
+ # --- test_binary_ufuncs ---
+ {"reason": "Misc",
+ "file": r"^test_binary_ufuncs$"},
+
+ # --- test_fx ---
+ {"reason": "FX",
+ "file": r"^test_fx$"},
+
+ # --- profiler.test_execution_trace ---
+ {"reason": "Profiler",
+ "file": r"^profiler\.test_execution_trace$"},
+
+ # --- test_cpp_api_parity ---
+ {"reason": "cpp_test",
+ "file": r"^test_cpp_api_parity$"},
+
+ # --- test_expanded_weights ---
+ {"reason": "Misc",
+ "file": r"^test_expanded_weights$"},
+
+ # --- test_linalg (arch-specific skips) ---
+ {"reason": "Linalg",
+ "file": r"^test_linalg$"},
+
+ # --- test_torch (arch-specific skips) ---
+ {"reason": "Misc",
+ "file": r"^test_torch$"},
+
+ # --- nn.test_convolution (arch-specific) ---
+ {"reason": "MIOpen Convolutions",
+ "file": r"^nn\.test_convolution$"},
+
+ # --- inductor.test_aot_inductor_arrayref ---
+ {"reason": "PT2.0 - AOTInductor",
+ "file": r"^inductor\.test_aot_inductor_arrayref$"},
+
+ # --- distributed.test_symmetric_memory ---
+ {"reason": "Symmetric memory",
+ "file": r"^distributed\.test_symmetric_memory$"},
+
+ # --- inductor.test_compiled_autograd HigherOrderOp (MI300 has more classes) ---
+ {"reason": "functorch",
+ "file": r"^inductor\.test_compiled_autograd$",
+ "cls": r"HigherOrderOp"},
+
+ # --- explicit NVIDIA test in various files ---
+ {"reason": "explicit NVIDIA test",
+ "file": r"^test_cuda_nvml_based_avail$"},
+ {"reason": "explicit NVIDIA test",
+ "file": r"^test_cpp_extensions_aot"},
+
+ # --- hipGraph/cudaGraph: only test_graph_* (NOT test_cuda_graph_*) in test_cuda_expandable_segments ---
+ {"reason": "hipGraph/cudaGraph",
+ "file": r"^test_cuda_expandable_segments$",
+ "name": r"^test_graph_"},
+
+ # --- expandable_segments (everything else in test_cuda_expandable_segments) ---
+ {"reason": "expandable_segments",
+ "file": r"^test_cuda_expandable_segments$"},
+
+ # --- Profiler ---
+ {"reason": "Profiler",
+ "file": r"^profiler\.test_profiler$"},
+
+ # --- serialization ---
+ {"reason": "serialization",
+ "file": r"^test_serialization$"},
+
+ # --- dataloader ---
+ {"reason": "dataloader",
+ "file": r"^test_dataloader$"},
+
+ # --- Multi-Processing ---
+ {"reason": "Multi-Processing",
+ "file": r"^test_multiprocessing_spawn$"},
+ {"reason": "Multi-Processing",
+ "file": r"^test_multiprocessing$"},
+
+ # --- hipSparse ---
+ {"reason": "hipSparse",
+ "file": r"^test_sparse_csr$"},
+ {"reason": "hipSparse",
+ "file": r"^test_sparse$",
+ "msg": r"^$"},
+
+ # --- nested tensor ---
+ {"reason": "nested tensor",
+ "file": r"^test_nestedtensor$"},
+
+ # --- asm_elementwise ---
+ {"reason": "asm_elementwise",
+ "file": r"higher_order_ops\.test_inline_asm_elementwise"},
+
+ # --- torchinductor_opinfo_properties ---
+ {"reason": "torchinductor_opinfo_properties",
+ "file": r"^inductor\.test_torchinductor_opinfo_properties$"},
+
+ # --- flex_attention ---
+ {"reason": "flex_attention",
+ "file": r"^inductor\.test_flex_attention$"},
+
+ # --- compiled optimizer ---
+ {"reason": "compiled optimizer",
+ "file": r"^inductor\.test_compiled_optimizers$"},
+
+ # --- inductor combo_kernels ---
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^inductor\.test_combo_kernels$"},
+
+ # --- inductor compiled_autograd (remaining after FuncTorch/DTensor class rules) ---
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^inductor\.test_compiled_autograd$"},
+
+ # --- Foreach (inductor) ---
+ {"reason": "Foreach",
+ "file": r"^inductor\.test_foreach$"},
+
+ # --- inductor codecache / cudacodecache ---
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^inductor\.test_codecache$"},
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^inductor\.test_cudacodecache$"},
+
+ # --- inductor GPU cpp wrapper ---
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^inductor\.test_gpu_cpp_wrapper$"},
+
+ # --- inductor torchinductor variants ---
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^inductor\.test_torchinductor$"},
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^inductor\.test_torchinductor_dynamic_shapes$"},
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^inductor\.test_torchinductor_codegen_dynamic_shapes$"},
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^inductor\.test_torchinductor_opinfo$"},
+
+ # --- inductor compile subprocess ---
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^inductor\.test_compile_subprocess$"},
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^inductor\.test_compile_worker$"},
+
+ # --- inductor cpu/cuda repro ---
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^inductor\.test_cpu_repro$"},
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^inductor\.test_cuda_repro$"},
+
+ # --- inductor custom lowering / minifier ---
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^inductor\.test_custom_lowering$"},
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^inductor\.test_minifier"},
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^inductor\.test_mix_order"},
+
+ # --- inductor aot_inductor ---
+ {"reason": "PT2.0 - AOTInductor",
+ "file": r"^inductor\.test_aot_inductor"},
+
+ # --- functorch ---
+ {"reason": "functorch",
+ "file": r"^functorch\."},
+
+ # --- dynamo ---
+ {"reason": "PT2.0 - Dynamo",
+ "file": r"^dynamo\."},
+
+ # --- export ---
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^export\."},
+
+ # --- tf32: test_nn with "Test is disabled" ---
+ {"reason": "tf32",
+ "file": r"^test_nn$",
+ "msg": r"Test is disabled"},
+
+ # --- MIOpen Convolutions ---
+ {"reason": "MIOpen Convolutions",
+ "file": r"^nn\.test_convolution$"},
+
+ # --- test_stateless ---
+ {"reason": "Misc",
+ "file": r"^test_stateless$"},
+
+ # --- test_cuda_primary_ctx ---
+ {"reason": "Misc",
+ "file": r"^test_cuda_primary_ctx$"},
+
+ # --- test_torchfuzz ---
+ {"reason": "Misc",
+ "file": r"^test_torchfuzz"},
+
+ # ==================================================================
+ # TIER 4: Distributed file-based rules
+ # ==================================================================
+
+ # Sharded Tensor
+ {"reason": "Sharded Tensor",
+ "file": r"^distributed\._shard\."},
+ {"reason": "Sharded Tensor",
+ "file": r"^distributed\._composable\.fsdp\.test_fully_shard_training$"},
+ {"reason": "Sharded Tensor",
+ "file": r"^distributed\._composable\.fsdp\.test_fully_shard_clip_grad"},
+
+ # tensor_parallel
+ {"reason": "tensor_parallel",
+ "file": r"^distributed\.tensor\.parallel\."},
+
+ # pipeline_parallel
+ {"reason": "pipeline_parallel",
+ "file": r"^distributed\.pipelining\."},
+
+ # FSDP
+ {"reason": "FSDP",
+ "file": r"^distributed\.fsdp\."},
+ {"reason": "FSDP",
+ "file": r"^distributed\._composable\.fsdp\."},
+
+ # 2D FSDP / composability
+ {"reason": "2D FSDP",
+ "file": r"^distributed\._composable\.test_composability"},
+
+ # DDP / replicate
+ {"reason": "DDP",
+ "file": r"^distributed\._composable\.test_replicate"},
+
+ # Process Group / c10d
+ {"reason": "Process Group",
+ "file": r"^distributed\.test_c10d_"},
+
+ # PT2.0 - Distributed (dynamo_distributed)
+ {"reason": "PT2.0 - Distributed",
+ "file": r"^distributed\.test_dynamo_distributed$"},
+
+ # Collectives (tensor ops, composability, nccl)
+ {"reason": "Collectives",
+ "file": r"^distributed\.tensor\.test_"},
+ {"reason": "Collectives",
+ "file": r"^distributed\.test_composability$"},
+ {"reason": "Collectives",
+ "file": r"^distributed\.test_nccl$"},
+
+ # Distributed tools
+ {"reason": "Misc",
+ "file": r"^distributed\._tools\."},
+
+ # Distributed elastic
+ {"reason": "elastic",
+ "file": r"^distributed\.elastic\."},
+
+ # Distributed quantization
+ {"reason": "Quantization",
+ "file": r"^distributed\.algorithms\.quantization"},
+
+ # Distributed rpc
+ {"reason": "Misc",
+ "file": r"^distributed\.rpc\."},
+
+ # Distributed spawn
+ {"reason": "Misc",
+ "file": r"^distributed\.test_distributed_spawn"},
+
+ # Distributed (generic catch-all)
+ {"reason": "Misc",
+ "file": r"^distributed\."},
+
+ # ==================================================================
+ # TIER 5: Generic message fallbacks
+ # ==================================================================
+
+ # "Test is disabled" messages
+ {"reason": "Misc",
+ "msg": r"Test is disabled because an issue exists disabling it"},
+
+ # Generic skipIfRocm / skipCUDAIfRocm
+ {"reason": "Misc",
+ "msg": r"skipIfRocm.*doesn't currently work on the ROCm stack"},
+ {"reason": "Misc",
+ "msg": r"skipCUDAIfRocm.*doesn't currently work on the ROCm stack"},
+
+ # "Skipped!" / "Skipped"
+ {"reason": "Misc",
+ "msg": r"^Skipped!?$"},
+
+ # "Skipped on ROCm"
+ {"reason": "Misc",
+ "msg": r"^Skipped on ROCm$"},
+
+ # Not supported on ROCm (generic)
+ {"reason": "Will not be supported on ROCm",
+ "msg": r"Not supported on ROCm"},
+
+ # ==================================================================
+ # TIER 6: Catch-all for remaining test_cuda (no message, generic class)
+ # ==================================================================
+ {"reason": "Misc",
+ "file": r"^test_cuda$"},
+]
+
+
+def extract_message(raw_msg: str) -> str:
+ """Extract a clean message string from the raw CSV message_rocm value."""
+ if not raw_msg or raw_msg.strip() == '':
+ return ''
+ try:
+ d = ast.literal_eval(raw_msg)
+ if isinstance(d, dict):
+ return d.get('message', str(d))
+ except (ValueError, SyntaxError):
+ pass
+ return raw_msg.strip()
+
+
+def classify_test(msg: str, test_file: str, test_class: str, test_name: str,
+ workflow: str = '') -> str | None:
+ """Return the skip_reason for a test, or None if no rule matches."""
+ for rule in RULES:
+ match = True
+ if 'msg' in rule:
+ if not re.search(rule['msg'], msg, re.IGNORECASE):
+ match = False
+ if 'file' in rule and match:
+ if not re.search(rule['file'], test_file):
+ match = False
+ if 'cls' in rule and match:
+ if not re.search(rule['cls'], test_class):
+ match = False
+ if 'name' in rule and match:
+ if not re.search(rule['name'], test_name):
+ match = False
+ if 'workflow' in rule and match:
+ if workflow and workflow != rule['workflow']:
+ match = False
+ if match:
+ return rule['reason']
+ return None
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description='Auto-classify skip reasons for ROCm parity CSVs')
+ parser.add_argument('-i', '--input', required=True,
+ help='Input parity CSV file')
+ parser.add_argument('-o', '--output',
+ help='Output CSV with auto-classified skip_reason column')
+ parser.add_argument('--tsv-out',
+ help='Also write a TSV file in skip_reasons format '
+ '(compatible with --skip_reasons in summarize_xml_testreports.py)')
+ parser.add_argument('--only-unclassified', action='store_true',
+ help='Only classify tests that have no skip_reason (default)')
+ parser.add_argument('--reclassify-all', action='store_true',
+ help='Re-classify all tests, overwriting existing skip_reason')
+ parser.add_argument('--report', action='store_true',
+ help='Print classification report to stderr')
+ parser.add_argument('--dry-run', action='store_true',
+ help='Print report but do not write output files')
+ return parser.parse_args()
+
+
+def detect_columns(fieldnames):
+ """Detect whether CSV uses status_rocm/status_cuda or status_set1/status_set2."""
+ if 'status_rocm' in fieldnames:
+ return 'status_rocm', 'status_cuda', 'message_rocm'
+ elif 'status_set1' in fieldnames:
+ return 'status_set1', 'status_set2', 'message_set1'
+ else:
+ raise ValueError(f"Cannot detect status columns. Available: {fieldnames}")
+
+
+def main():
+ args = parse_args()
+
+ rows = []
+ with open(args.input, newline='') as f:
+ reader = csv.DictReader(f)
+ fieldnames = list(reader.fieldnames)
+ for row in reader:
+ rows.append(row)
+
+ col_rocm, col_cuda, col_msg = detect_columns(fieldnames)
+
+ for col in ('skip_reason', 'assignee', 'comments'):
+ if col not in fieldnames:
+ fieldnames.append(col)
+
+ classified_count = 0
+ already_had_count = 0
+ unclassified_count = 0
+ overwritten_count = 0
+ auto_reasons = Counter()
+ unclassified_msgs = Counter()
+ unclassified_files = Counter()
+ unclassified_details = []
+
+ tsv_entries = []
+
+ for row in rows:
+ status_rocm = row.get(col_rocm, '')
+ status_cuda = row.get(col_cuda, '')
+ existing_reason = row.get('skip_reason', '').strip()
+
+ needs_reason = (
+ status_rocm in ('SKIPPED', 'MISSED')
+ and status_cuda == 'PASSED'
+ )
+
+ if not needs_reason:
+ continue
+
+ raw_msg = row.get(col_msg, '')
+ msg = extract_message(raw_msg)
+ test_file = row.get('test_file', '')
+ test_class = row.get('test_class', '')
+ test_name = row.get('test_name', '')
+ workflow = row.get('test_config', '')
+
+ if existing_reason and not args.reclassify_all:
+ already_had_count += 1
+ tsv_entries.append({
+ 'test_file': test_file,
+ 'test_name': test_name,
+ 'test_class': test_class,
+ 'skip_reason': existing_reason,
+ 'assignee': row.get('assignee', ' '),
+ 'comments': row.get('comments', ' '),
+ })
+ continue
+
+ reason = classify_test(msg, test_file, test_class, test_name, workflow)
+
+ if reason:
+ if existing_reason and existing_reason != reason:
+ overwritten_count += 1
+ row['skip_reason'] = reason
+ row.setdefault('assignee', '')
+ row.setdefault('comments', 'auto-classified')
+ classified_count += 1
+ auto_reasons[reason] += 1
+ tsv_entries.append({
+ 'test_file': test_file,
+ 'test_name': test_name,
+ 'test_class': test_class,
+ 'skip_reason': reason,
+ 'assignee': row.get('assignee', ' ') if not args.reclassify_all else ' ',
+ 'comments': 'auto-classified',
+ })
+ else:
+ unclassified_count += 1
+ display_msg = msg[:100] if msg else '(no message)'
+ unclassified_msgs[display_msg] += 1
+ unclassified_files[test_file] += 1
+ unclassified_details.append(
+ f" {test_file:55s} {test_class:45s} {test_name[:40]:42s} {display_msg[:50]}")
+
+ if args.report or args.dry_run:
+ total = already_had_count + classified_count + unclassified_count
+ print(f"\n{'='*60}", file=sys.stderr)
+ print(f"AUTO-CLASSIFICATION REPORT", file=sys.stderr)
+ print(f"{'='*60}", file=sys.stderr)
+ print(f"Already had skip_reason: {already_had_count}", file=sys.stderr)
+ print(f"Auto-classified: {classified_count}", file=sys.stderr)
+ if overwritten_count:
+ print(f" (overwritten existing: {overwritten_count})", file=sys.stderr)
+ print(f"Still unclassified: {unclassified_count}", file=sys.stderr)
+ if total:
+ pct = (already_had_count + classified_count) / total * 100
+ print(f"Coverage: {pct:.1f}%", file=sys.stderr)
+ print(f"Total target tests: {total}", file=sys.stderr)
+
+ if auto_reasons:
+ print(f"\nAuto-classified by category:", file=sys.stderr)
+ for reason, cnt in auto_reasons.most_common():
+ print(f" {cnt:5d} {reason}", file=sys.stderr)
+
+ if unclassified_msgs:
+ print(f"\nUnclassified — top messages:", file=sys.stderr)
+ for msg_key, cnt in unclassified_msgs.most_common(15):
+ print(f" {cnt:5d} {msg_key}", file=sys.stderr)
+
+ if unclassified_files:
+ print(f"\nUnclassified — top files:", file=sys.stderr)
+ for f, cnt in unclassified_files.most_common(15):
+ print(f" {cnt:5d} {f}", file=sys.stderr)
+
+ if unclassified_details and len(unclassified_details) <= 50:
+ print(f"\nUnclassified tests:", file=sys.stderr)
+ for d in unclassified_details:
+ print(d, file=sys.stderr)
+
+ if args.dry_run:
+ return
+
+ if not args.output:
+ print("No --output specified; use --dry-run for report-only mode.",
+ file=sys.stderr)
+ sys.exit(1)
+
+ with open(args.output, 'w', newline='') as f:
+ writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction='ignore')
+ writer.writeheader()
+ for row in rows:
+ writer.writerow(row)
+
+ if args.tsv_out and tsv_entries:
+ with open(args.tsv_out, 'w', newline='') as f:
+ writer = csv.DictWriter(
+ f,
+ fieldnames=['test_file', 'test_name', 'test_class',
+ 'skip_reason', 'assignee', 'comments'],
+ delimiter='\t',
+ )
+ writer.writeheader()
+ for entry in tsv_entries:
+ writer.writerow(entry)
+ print(f"\nWrote TSV with {len(tsv_entries)} entries to {args.tsv_out}",
+ file=sys.stderr)
+
+ print(f"Wrote {len(rows)} rows to {args.output}", file=sys.stderr)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/.automation_scripts/pytorch-unit-test-scripts/detect_log_failures.py b/.automation_scripts/pytorch-unit-test-scripts/detect_log_failures.py
new file mode 100755
index 000000000000..b563fcf74bc1
--- /dev/null
+++ b/.automation_scripts/pytorch-unit-test-scripts/detect_log_failures.py
@@ -0,0 +1,570 @@
+#!/usr/bin/env python3
+"""Scan CI log files (.txt) for test failures not captured in XML reports.
+
+Tests that timeout (exit code 124), crash (SIGIOT, SIGSEGV, Fatal Python error),
+or are killed (SIGKILL, OOM) never produce JUnit XML output. This script detects
+those failures from the raw log files and outputs a CSV/summary.
+
+Usage:
+ python detect_log_failures.py --logs-dir [--output ]
+"""
+
+import argparse
+import csv
+import os
+import re
+import sys
+from collections import defaultdict
+from datetime import datetime
+from pathlib import Path
+
+
+RE_RUNNING = re.compile(
+ r"Running (?P\S+) (?P\d+)/(?P\d+) \.\.\."
+)
+RE_SUCCESS = re.compile(
+ r"(?P\S+) (?P\d+)/(?P\d+) was successful"
+)
+RE_FAILED = re.compile(
+ r"(?P\S+) (?P\d+)/(?P\d+) failed!(?P.*)"
+)
+RE_EXIT_CODE = re.compile(r"Got exit code (?P\d+)")
+RE_TIMEOUT = re.compile(r"Command took >(\d+)min, returning 124")
+RE_FAILED_CONSISTENTLY = re.compile(
+ r"FAILED CONSISTENTLY: (?P\S+)"
+)
+RE_STEPCURRENT = re.compile(
+ r"stepcurrent:.*Running only (?:test/)?(?P\S+)"
+)
+RE_INDIVIDUAL_TEST = re.compile(
+ r"(?P\S+\.py::(?P\w+)::(?P\w+))"
+)
+RE_INDIV_PASSED = re.compile(
+ r"(?:test/)?(?P\S+\.py)::(?P\w+)::(?P\S+?)\s+PASSED"
+)
+RE_NEW_PROCESS_SUCCESS = re.compile(r"Test succeeded in new process")
+
+CRASH_PATTERNS = [
+ (re.compile(r"Segmentation fault", re.IGNORECASE), "SEGFAULT"),
+ (re.compile(r"SIGSEGV"), "SIGSEGV"),
+ (re.compile(r"SIGIOT"), "SIGIOT"),
+ (re.compile(r"SIGABRT"), "SIGABRT"),
+ (re.compile(r"SIGKILL"), "SIGKILL"),
+ (re.compile(r"Fatal Python error", re.IGNORECASE), "FATAL_PYTHON"),
+ (re.compile(r"core dumped", re.IGNORECASE), "CORE_DUMP"),
+ (re.compile(r"Aborted \(core dumped\)", re.IGNORECASE), "ABORTED"),
+ (re.compile(r"torch\.cuda\.OutOfMemoryError"), "CUDA_OOM"),
+ (re.compile(r"std::bad_alloc"), "BAD_ALLOC"),
+]
+
+LOG_FILE_MAP = {
+ "rocm": ("rocm", "default"),
+ "rocm_dist": ("rocm", "distributed"),
+ "rocm_inductor": ("rocm", "inductor"),
+ "cuda": ("cuda", "default"),
+ "cuda_dist": ("cuda", "distributed"),
+ "cuda_inductor": ("cuda", "inductor"),
+ "baseline": ("baseline", "default"),
+}
+
+
+def classify_log_file(filename):
+ """Return (platform, test_config, shard_num) from a log filename like rocm3.txt."""
+ stem = Path(filename).stem
+ for prefix, (platform, test_config) in sorted(LOG_FILE_MAP.items(), key=lambda x: -len(x[0])):
+ if stem.startswith(prefix):
+ remainder = stem[len(prefix):]
+ if remainder.isdigit():
+ return platform, test_config, int(remainder)
+ return None, None, None
+
+
+RE_TIMESTAMP = re.compile(r"^\d{4}-\d{2}-\d{2}T[\d:.]+Z\s*")
+RE_TS_CAPTURE = re.compile(r"^(\d{4}-\d{2}-\d{2}T[\d:.]+)Z")
+
+
+def _parse_ts(s):
+ """Parse a CI log ISO-8601 timestamp (without the trailing Z) to datetime.
+
+ Returns None if it doesn't match the expected shapes (with or without
+ fractional seconds), so callers can skip un-timestamped lines."""
+ for fmt in ("%Y-%m-%dT%H:%M:%S.%f", "%Y-%m-%dT%H:%M:%S"):
+ try:
+ return datetime.strptime(s, fmt)
+ except ValueError:
+ continue
+ return None
+
+
+def parse_log_file(filepath):
+ """Parse a single log file and return test file results, consistent failures,
+ and flaky tests.
+
+ A flaky test is one that failed in its normal-process run but PASSED when the
+ CI harness re-ran it alone in a new subprocess (indicated by a PASSED line
+ for the specific test::class::method, followed by 'Test succeeded in new
+ process, continuing with the rest of the tests').
+ """
+ results = {}
+ current_test = None
+ last_failed_test = None
+ consistent_failures = []
+ flaky_tests = []
+ last_passed_individual = None
+ first_ts = None
+ last_ts = None
+
+ with open(filepath, "r", errors="replace") as f:
+ for line in f:
+ # First/last ISO timestamps give the job's wall-clock (the run time
+ # for log-based failures, which have no XML). Must run before the
+ # line filter below, since skipped lines still advance the clock.
+ m_ts = RE_TS_CAPTURE.match(line)
+ if m_ts:
+ ts = _parse_ts(m_ts.group(1))
+ if ts is not None:
+ if first_ts is None:
+ first_ts = ts
+ last_ts = ts
+
+ # Lightweight tracking of individual pytest test lines.
+ # These are very frequent (~37% of lines) so we extract the
+ # test name directly without timestamp stripping.
+ if ".py::" in line:
+ m_ind = RE_INDIVIDUAL_TEST.search(line)
+ if m_ind:
+ active = current_test or last_failed_test
+ if active and active in results:
+ # Only update if the pytest path belongs to this shard's test file,
+ # otherwise rerun output from earlier shards contaminates later ones.
+ shard_file = results[active]["test_file"]
+ if shard_file + ".py" in m_ind.group("test_path"):
+ results[active]["last_test"] = f"{m_ind.group('cls')}::{m_ind.group('method')}"
+
+ if " ... [" not in line and "was successful" not in line \
+ and "failed!" not in line and "Got exit code" not in line \
+ and "returning 124" not in line and "FAILED CONSISTENTLY" not in line \
+ and "Retrying" not in line \
+ and "Segmentation fault" not in line and "SIGIOT" not in line \
+ and "SIGSEGV" not in line and "SIGABRT" not in line \
+ and "SIGKILL" not in line \
+ and "Fatal Python error" not in line and "core dumped" not in line \
+ and "Aborted (core dumped)" not in line \
+ and "OutOfMemoryError" not in line \
+ and "bad_alloc" not in line \
+ and "stepcurrent" not in line \
+ and "PASSED" not in line \
+ and "new process" not in line:
+ continue
+
+ stripped = RE_TIMESTAMP.sub("", line).rstrip()
+
+ m = RE_RUNNING.search(stripped)
+ if m:
+ key = f"{m.group('test_file')} {m.group('shard')}/{m.group('total')}"
+ current_test = key
+ if key not in results:
+ results[key] = {
+ "test_file": m.group("test_file"),
+ "shard": int(m.group("shard")),
+ "total": int(m.group("total")),
+ "status": "RUNNING",
+ "reason": "",
+ "exit_codes": [],
+ "crashes": [],
+ "crash_tests": [],
+ "last_test": "",
+ }
+ continue
+
+ m = RE_SUCCESS.search(stripped)
+ if m:
+ key = f"{m.group('test_file')} {m.group('shard')}/{m.group('total')}"
+ if key in results:
+ results[key]["status"] = "PASSED"
+ current_test = None
+ last_failed_test = None
+ continue
+
+ m = RE_FAILED.search(stripped)
+ if m:
+ key = f"{m.group('test_file')} {m.group('shard')}/{m.group('total')}"
+ reason = m.group("reason").strip()
+ if key in results:
+ results[key]["status"] = "FAILED"
+ if reason:
+ results[key]["reason"] = reason
+ last_failed_test = key
+ current_test = key
+ continue
+
+ active = current_test or last_failed_test
+
+ # Track stepcurrent rerun lines — identifies crash-causing test
+ m = RE_STEPCURRENT.search(stripped)
+ if m:
+ test_path = m.group("test_path")
+ parts = test_path.split("::")
+ if len(parts) >= 3:
+ crash_id = f"{parts[1]}::{parts[2]}"
+ elif len(parts) == 2:
+ crash_id = parts[1]
+ else:
+ crash_id = None
+ if crash_id and active and active in results:
+ shard_file = results[active]["test_file"]
+ if shard_file in test_path:
+ if crash_id not in results[active]["crash_tests"]:
+ results[active]["crash_tests"].append(crash_id)
+ continue
+
+ # Track individual pytest test lines for last-running-test context
+ m_ind = RE_INDIVIDUAL_TEST.search(stripped)
+ if m_ind and active and active in results:
+ cls = m_ind.group("cls")
+ method = m_ind.group("method")
+ results[active]["last_test"] = f"{cls}::{method}"
+
+ m = RE_EXIT_CODE.search(stripped)
+ if m:
+ code = int(m.group("code"))
+ if active and active in results:
+ results[active]["exit_codes"].append(code)
+
+ m = RE_TIMEOUT.search(stripped)
+ if m and active and active in results:
+ if "TIMEOUT" not in results[active]["crashes"]:
+ results[active]["crashes"].append("TIMEOUT")
+
+ m = RE_FAILED_CONSISTENTLY.search(stripped)
+ if m:
+ shard_str = ""
+ if active and active in results:
+ info = results[active]
+ shard_str = f"{info['shard']}/{info['total']}"
+ consistent_failures.append((m.group("test_path"), shard_str))
+
+ # Detect individual PASSED lines for flaky-rerun tracking.
+ m = RE_INDIV_PASSED.search(stripped)
+ if m:
+ last_passed_individual = {
+ "file": m.group("file"),
+ "cls": m.group("cls"),
+ "method": m.group("method"),
+ "active": active,
+ }
+
+ # When we see 'Test succeeded in new process' after a PASSED
+ # individual test, that test was originally failing in the main
+ # process (CI only falls back to rerun-in-new-process for tests
+ # that crashed or failed) but passed on retry -> flaky.
+ if RE_NEW_PROCESS_SUCCESS.search(stripped) and last_passed_individual:
+ lp = last_passed_individual
+ lp_active = lp.get("active")
+ test_shard = ""
+ if lp_active and lp_active in results:
+ info = results[lp_active]
+ test_shard = f"{info['shard']}/{info['total']}"
+ flaky_tests.append({
+ "file": lp["file"],
+ "cls": lp["cls"],
+ "method": lp["method"],
+ "test_shard": test_shard,
+ })
+ last_passed_individual = None
+
+ if active and active in results:
+ for pattern, label in CRASH_PATTERNS:
+ if pattern.search(stripped):
+ if label not in results[active]["crashes"]:
+ results[active]["crashes"].append(label)
+
+ job_run_time = ""
+ if first_ts is not None and last_ts is not None:
+ job_run_time = round(max(0.0, (last_ts - first_ts).total_seconds()), 2)
+
+ return results, consistent_failures, flaky_tests, job_run_time
+
+
+def scan_logs(logs_dir):
+ """Scan all log files and return non-passing test file results plus a
+ test-level shard inventory.
+
+ Returns (all_failures, shard_inventory) where shard_inventory is a list
+ of dicts with one entry per (platform, test_config, job_shard, test_file)
+ combination seen in the logs, plus a sorted comma-separated list of the
+ test-level shards observed (e.g. "1/1" or "1/15,2/15,...,15/15"). This
+ lets downstream consumers look up the test-level shard for any XML-based
+ failure whose only shard info is the job-level shard."""
+ all_failures = []
+ all_flaky = []
+ shard_map = defaultdict(set)
+
+ # Pre-compute job-level shard totals per (platform, test_config) by
+ # counting how many log files belong to each group. Log files are
+ # 1-indexed (e.g. rocm1.txt..rocm6.txt for a 6-way sharded job), so
+ # the count == total shards for that CI job.
+ shard_totals = defaultdict(int)
+ for fname in os.listdir(logs_dir):
+ if not fname.endswith(".txt"):
+ continue
+ platform, test_config, shard_num = classify_log_file(fname)
+ if platform is None:
+ continue
+ shard_totals[(platform, test_config)] += 1
+
+ for fname in sorted(os.listdir(logs_dir)):
+ if not fname.endswith(".txt"):
+ continue
+
+ platform, test_config, shard_num = classify_log_file(fname)
+ if platform is None:
+ continue
+
+ job_total = shard_totals.get((platform, test_config), 0)
+ job_shard_str = f"{shard_num}/{job_total}" if job_total else str(shard_num)
+
+ # If download_testlogs left a ".job_url" file next to this log,
+ # it contains the URL of the upstream pytorch CI job that produced
+ # the log. We surface it in the LOG-BASED FAILURES table as a link
+ # to that job's page. Empty for older runs that predate this.
+ job_url_file = os.path.join(logs_dir, fname + ".job_url")
+ job_url = ""
+ if os.path.isfile(job_url_file):
+ with open(job_url_file) as f:
+ job_url = f.read().strip()
+
+ filepath = os.path.join(logs_dir, fname)
+ results, consistent_failures, flaky_tests, job_run_time = parse_log_file(filepath)
+
+ for ft in flaky_tests:
+ file_part = ft["file"].replace("test/", "").replace(".py", "")
+ all_flaky.append({
+ "log_file": fname,
+ "platform": platform,
+ "test_config": test_config,
+ "test_file": file_part,
+ "test_class": ft["cls"],
+ "test_name": ft["method"],
+ "job_shard": job_shard_str,
+ "test_shard": ft["test_shard"],
+ "run_time": job_run_time,
+ "job_url": job_url,
+ })
+
+ # Record every (test_file, test_shard) observed in this log file,
+ # including PASSED ones, so the inventory covers the full run.
+ for info in results.values():
+ shard_map[(platform, test_config, job_shard_str, info["test_file"])].add(
+ f"{info['shard']}/{info['total']}"
+ )
+
+ for key, info in results.items():
+ if info["status"] == "PASSED":
+ continue
+
+ categories = []
+ if 124 in info["exit_codes"] or "TIMEOUT" in info["crashes"]:
+ categories.append("TIMEOUT")
+ for c in info["crashes"]:
+ if c != "TIMEOUT":
+ categories.append(c)
+ if info["status"] == "FAILED" and not categories:
+ categories.append("FAILED")
+ if info["status"] == "RUNNING" and not categories:
+ categories.append("INCOMPLETE")
+
+ if not categories:
+ continue
+ # Skip tests stuck in RUNNING with no evidence of failure —
+ # these are typically from multi-shard logs where a different
+ # shard's "Running ..." line appeared but the result was elsewhere.
+ if info["status"] == "RUNNING" and categories == ["INCOMPLETE"]:
+ continue
+
+ reason = info["reason"]
+ # Populate reason with identified crash/timeout test name
+ crash_tests = info.get("crash_tests", [])
+ last_test = info.get("last_test", "")
+ identified_test = ""
+ if crash_tests:
+ identified_test = crash_tests[0]
+ elif last_test:
+ identified_test = last_test
+
+ if identified_test and "::" in identified_test:
+ if not reason:
+ reason = identified_test
+ elif "::" not in reason:
+ reason = f"{identified_test} | {reason}"
+
+ all_failures.append({
+ "log_file": fname,
+ "platform": platform,
+ "test_config": test_config,
+ "test_file": info["test_file"],
+ "job_shard": job_shard_str,
+ "test_shard": f"{info['shard']}/{info['total']}",
+ "status": info["status"],
+ "category": "+".join(categories),
+ "reason": reason,
+ "exit_codes": ",".join(str(c) for c in info["exit_codes"]),
+ "run_time": job_run_time,
+ "job_url": job_url,
+ })
+
+ for test_path, shard_str in consistent_failures:
+ parts = test_path.split("::")
+ file_part = parts[0].replace("test/", "").replace(".py", "")
+ test_class = parts[1] if len(parts) > 1 else ""
+ test_name = parts[2] if len(parts) > 2 else ""
+
+ all_failures.append({
+ "log_file": fname,
+ "platform": platform,
+ "test_config": test_config,
+ "test_file": file_part,
+ "job_shard": job_shard_str,
+ "test_shard": shard_str,
+ "status": "FAILED_CONSISTENTLY",
+ "category": "CONSISTENT_FAILURE",
+ "reason": f"{test_class}::{test_name}" if test_class else "",
+ "exit_codes": "",
+ "run_time": job_run_time,
+ "job_url": job_url,
+ })
+
+ def _sort_shards(vals):
+ def key(v):
+ try:
+ a, b = v.split("/", 1)
+ return (int(b), int(a))
+ except (ValueError, AttributeError):
+ return (0, 0)
+ return sorted(vals, key=key)
+
+ shard_inventory = [
+ {
+ "platform": platform,
+ "test_config": test_config,
+ "job_shard": job_shard_str,
+ "test_file": test_file,
+ "test_shards": ",".join(_sort_shards(shards)),
+ }
+ for (platform, test_config, job_shard_str, test_file), shards in shard_map.items()
+ ]
+ shard_inventory.sort(key=lambda r: (r["platform"], r["test_config"],
+ r["job_shard"], r["test_file"]))
+
+ all_flaky.sort(key=lambda r: (r["platform"], r["test_config"],
+ r["job_shard"], r["test_file"],
+ r["test_class"], r["test_name"]))
+
+ return all_failures, shard_inventory, all_flaky
+
+
+def write_csv_report(failures, output_path):
+ fieldnames = [
+ "log_file", "platform", "test_config", "test_file",
+ "job_shard", "test_shard",
+ "status", "category", "reason", "exit_codes",
+ "run_time",
+ "job_url",
+ ]
+ with open(output_path, "w", newline="") as f:
+ writer = csv.DictWriter(f, fieldnames=fieldnames)
+ writer.writeheader()
+ writer.writerows(failures)
+ print(f"Log failure report: {output_path} ({len(failures)} entries)")
+
+
+def write_shards_report(inventory, output_path):
+ fieldnames = ["platform", "test_config", "job_shard", "test_file", "test_shards"]
+ with open(output_path, "w", newline="") as f:
+ writer = csv.DictWriter(f, fieldnames=fieldnames)
+ writer.writeheader()
+ writer.writerows(inventory)
+ print(f"Log shard inventory: {output_path} ({len(inventory)} entries)")
+
+
+def write_flaky_report(flaky, output_path):
+ fieldnames = [
+ "log_file", "platform", "test_config", "test_file",
+ "test_class", "test_name", "job_shard", "test_shard",
+ "run_time",
+ "job_url",
+ ]
+ with open(output_path, "w", newline="") as f:
+ writer = csv.DictWriter(f, fieldnames=fieldnames)
+ writer.writeheader()
+ writer.writerows(flaky)
+ print(f"Flaky test report: {output_path} ({len(flaky)} entries)")
+
+
+def _derive_sibling_path(output_path, new_prefix):
+ """Given an output path like '.../log_failures_mi355.csv' and
+ new_prefix='log_shards', return '.../log_shards_mi355.csv'. Falls back to
+ appending '.{new_prefix}.csv' if the expected prefix isn't present."""
+ d, base = os.path.split(output_path)
+ if base.startswith("log_failures"):
+ return os.path.join(d, new_prefix + base[len("log_failures"):])
+ stem, ext = os.path.splitext(base)
+ return os.path.join(d, f"{stem}.{new_prefix}{ext or '.csv'}")
+
+
+def _derive_shards_path(output_path):
+ return _derive_sibling_path(output_path, "log_shards")
+
+
+def _derive_flaky_path(output_path):
+ return _derive_sibling_path(output_path, "flaky_tests")
+
+
+def print_summary(failures):
+ if not failures:
+ print("No log-based failures detected.")
+ return
+
+ by_category = defaultdict(list)
+ for f in failures:
+ by_category[f["category"]].append(f)
+
+ print(f"\n{'='*60}")
+ print("LOG FAILURE DETECTION SUMMARY")
+ print(f"{'='*60}")
+ print(f"Total failures detected: {len(failures)}")
+ print()
+
+ for cat, items in sorted(by_category.items()):
+ print(f" {cat}: {len(items)}")
+ for item in items:
+ print(f" - {item['test_file']} ({item['platform']}/{item['test_config']}) [{item['log_file']}]")
+ if item["reason"]:
+ print(f" Reason: {item['reason'][:120]}")
+ print()
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Detect test failures from CI log files not captured in XML reports"
+ )
+ parser.add_argument(
+ "--logs-dir", required=True,
+ help="Directory containing .txt log files"
+ )
+ parser.add_argument(
+ "--output", default="log_failures.csv",
+ help="Output CSV path (default: log_failures.csv)"
+ )
+ args = parser.parse_args()
+
+ failures, shard_inventory, flaky_tests = scan_logs(args.logs_dir)
+ print_summary(failures)
+ write_csv_report(failures, args.output)
+ write_shards_report(shard_inventory, _derive_shards_path(args.output))
+ write_flaky_report(flaky_tests, _derive_flaky_path(args.output))
+ return 0 if not failures else 1
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/.automation_scripts/pytorch-unit-test-scripts/download_testlogs b/.automation_scripts/pytorch-unit-test-scripts/download_testlogs
new file mode 100755
index 000000000000..59eb9fbba81e
--- /dev/null
+++ b/.automation_scripts/pytorch-unit-test-scripts/download_testlogs
@@ -0,0 +1,1201 @@
+#!/usr/bin/env python3
+
+
+try:
+ import os
+ import json
+ import time
+ import argparse
+ import requests
+ import re
+ import sys
+ from upload_stats_lib import unzip
+ from upload_test_stats import download_gha_artifacts, download_s3_artifacts
+except ImportError:
+ import subprocess
+ result = subprocess.run(["pip3", "install", "-U", "-r", "requirements.txt"], capture_output=True, text=True)
+ print(result.stdout)
+ print("Please rerun the download_testlogs script")
+ sys.exit(1)
+
+
+# Check if environment variables are set
+required_env_vars = ['GITHUB_TOKEN', 'AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY']
+
+missing_vars = [var for var in required_env_vars if not os.getenv(var)]
+if missing_vars:
+ print(f"ERROR: Please set these environment variables: {', '.join(missing_vars)}")
+ sys.exit(1)
+
+
+# global variables
+error_msgs = []
+
+# Job-name matching config (workflow names, job-name prefixes, shard counts and
+# check-run regexes) lives in a single JSON file next to this script so
+# download_testlogs and parity-auto.yml share one source of truth. See
+# parity_job_config.json for the schema.
+PARITY_CONFIG_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "parity_job_config.json")
+with open(PARITY_CONFIG_PATH) as _f:
+ PARITY_CONFIG = json.load(_f)
+
+# Every ROCm runner label follows the rocm.gpu convention, so the S3 artifact
+# downloader always filters on it to avoid pulling foreign platforms' reports.
+ROCM_ARTIFACT_SUBSTRINGS = ["rocm.gpu"]
+
+# Test configs whose sources are listed in the per-arch config (in order).
+_TEST_CONFIGS = ("default", "distributed", "inductor")
+
+
+def parity_config_views(cfg):
+ """Unpack the per-config ordered [{workflow, job_prefix}, ...] lists into the
+ flat dicts the matching logic expects: (workflows, job_prefixes, fallbacks).
+
+ The first list entry per test config is the primary source; the second (if
+ present) is its fallback. Test configs absent from cfg are simply skipped.
+ """
+ workflows, job_prefixes, fallbacks = {}, {}, {}
+ for tc in _TEST_CONFIGS:
+ sources = cfg.get(tc)
+ if not sources:
+ continue
+ workflows[tc] = sources[0]["workflow"]
+ job_prefixes[tc] = sources[0]["job_prefix"]
+ if len(sources) > 1:
+ fallbacks[tc] = {"workflow": sources[1]["workflow"],
+ "job_prefix": sources[1]["job_prefix"]}
+ return workflows, job_prefixes, fallbacks
+
+
+# Workflow names mapped to TEST_CONFIG values in PyTorch CI.
+# ROCmWorkflowNames is set per --arch in main(); CUDA is arch-independent.
+ROCmWorkflowNames = {}
+CUDAWorkflowNames, CUDA_JOB_PREFIXES, _ = parity_config_views(PARITY_CONFIG["cuda"])
+
+authentication_headers = None
+
+def get_commit_hashes(pr_id, token):
+ owner = "pytorch"
+ repo = "pytorch"
+ commits_url = f"https://api.github.com/repos/{owner}/{repo}/pulls/{pr_id}/commits"
+ headers = {
+ "Authorization": f"token {token}",
+ "Accept": "application/vnd.github.v3+json"
+ }
+ page = 1
+ commits = []
+ while True:
+ response = requests.get(commits_url, headers=headers, params={'page': page})
+ if response.status_code == 200:
+ new_commits = response.json()
+ if not new_commits:
+ break
+ commits.extend(new_commits)
+ page += 1
+ else:
+ print(f"Failed to fetch commits: {response.status_code}")
+ break
+ return commits
+
+def get_latest_commit_sha(pr_id, token):
+ commits = get_commit_hashes(pr_id, token)
+ if commits:
+ return commits[-1]['sha']
+ else:
+ print("No commits found for the given pull request.")
+ sys.exit(1)
+
+def write_test_log_to_file(filename, test_key, jobs, sha):
+ js = [j for j in jobs if test_key in j['name']]
+ if len(js) > 0:
+ if len(js) > 1:
+ print(f"WARNING: Found multiple jobs with key: '{test_key}', selecting first one")
+ for j in js:
+ print(j['name'])
+ test_id = js[0]['id']
+ print(f"key: {test_key}, job Name: {js[0]['name']}, job ID: {test_id}, Downloading to {filename}")
+ else:
+ # Not being able to download logs is not a fatal error since we primarily depend on xml artifacts
+ # so log error and continue
+ error_msg = f"Error: TEST KEY: {test_key} DOES NOT EXIST IN JOBS.\nCheck url - https://hud.pytorch.org/hud/pytorch/pytorch/{sha}/1?per_page=50 - for job name"
+ print(error_msg)
+ error_msgs.append(error_msg)
+ return
+ response = requests.get( "https://ossci-raw-job-status.s3.amazonaws.com/log/" + str(test_id) )
+ with open(filename, "w", encoding="utf-8") as f:
+ f.write(response.text)
+
+ # Save the upstream pytorch CI job's page URL next to the log so
+ # detect_log_failures.py can later surface it as a link in the
+ # LOG-BASED FAILURES table of the parity summary.
+ job_url = js[0].get('html_url', '')
+ if job_url:
+ with open(filename + ".job_url", "w", encoding="utf-8") as f:
+ f.write(job_url)
+
+def _fetch_jobs_page(jobs_url, params, max_retries=5):
+ """GET one page of the jobs API, retrying transient failures.
+
+ When many parity runs fire at once they hammer the jobs API and GitHub
+ responds with 403 (secondary rate limit) or 5xx. Those bodies have no
+ 'jobs' key, which previously surfaced as a bare KeyError. Retry with
+ exponential backoff (honouring Retry-After when present) and only give
+ up with a clear error after exhausting retries.
+ """
+ for attempt in range(max_retries):
+ response = requests.get(jobs_url, headers=authentication_headers, params=params)
+ if response.status_code == 200:
+ response_json = response.json()
+ if "jobs" in response_json:
+ return response_json
+ if attempt == max_retries - 1:
+ break
+ retry_after = response.headers.get("Retry-After")
+ delay = int(retry_after) if retry_after and retry_after.isdigit() else 2 ** attempt
+ print(f" WARNING: jobs API returned HTTP {response.status_code} for {jobs_url} "
+ f"(attempt {attempt + 1}/{max_retries}); retrying in {delay}s")
+ time.sleep(delay)
+ raise Exception(
+ f"jobs API did not return a 'jobs' payload after {max_retries} attempts "
+ f"(last HTTP {response.status_code}) for {jobs_url}"
+ )
+
+def get_workflow_jobs(wf, all_attempts=False):
+ """Get all jobs for a workflow run.
+
+ By default GitHub returns only the latest attempt's jobs. Pass
+ all_attempts=True to include jobs from every run attempt: when a run is
+ re-run, shards that already passed are carried over (not re-executed) and
+ their test-report artifacts keep the job ID from the attempt that produced
+ them, so artifact matching by job ID must consider all attempts.
+ """
+ if wf is None:
+ raise Exception("wf is None!")
+ page_size = 100 #max allowed by Github API
+ base_params = {'per_page': page_size}
+ if all_attempts:
+ base_params['filter'] = 'all'
+ response_json = _fetch_jobs_page(wf['jobs_url'], base_params)
+ jobs = response_json["jobs"]
+
+ if response_json['total_count'] > page_size:
+ import math
+ for i in range(2, math.ceil(response_json['total_count']/page_size) + 1):
+ page_json = _fetch_jobs_page(wf['jobs_url'], {**base_params, 'page': i})
+ jobs += page_json["jobs"]
+ return jobs
+
+def get_check_runs_for_commit(sha, prefix):
+ """Get check runs for a commit filtered by name prefix.
+
+ The workflow jobs API does not return jobs from reusable workflows
+ (workflow_call). The check-runs API returns all jobs regardless of
+ workflow nesting, so we use it as a fallback.
+ """
+ check_runs = []
+ page = 1
+ while True:
+ response = requests.get(
+ f"https://api.github.com/repos/pytorch/pytorch/commits/{sha}/check-runs",
+ headers=authentication_headers,
+ params={'per_page': 100, 'page': page},
+ )
+ data = response.json()
+ runs = data.get('check_runs', [])
+ check_runs.extend([cr for cr in runs if prefix in cr.get('name', '')])
+ if len(runs) < 100:
+ break
+ page += 1
+ return check_runs
+
+def get_job_ids_by_prefix(wf, prefix):
+ """Get job IDs (as strings) for jobs whose name contains the given prefix."""
+ jobs = get_workflow_jobs(wf)
+ return [str(j['id']) for j in jobs if prefix in j['name']]
+
+def matches_job_prefix(job_name, prefix):
+ """Match the exact CUDA job family without also matching -debug/-sm86 jobs."""
+ return job_name.startswith(f"{prefix} / ")
+
+# CUDA test jobs are named "test-osdc" on main but plain "test" on PRs, so we
+# probe both spellings (in priority order) from the shared config.
+CUDA_TEST_KINDS = tuple(PARITY_CONFIG["cuda"]["test_kinds"])
+
+def get_cuda_test_jobs(jobs, cuda_job_prefix):
+ """Return the CUDA test kind and jobs for either main or PR CI layouts."""
+ for test_kind in CUDA_TEST_KINDS:
+ test_jobs = [
+ j for j in jobs
+ if matches_job_prefix(j['name'], cuda_job_prefix)
+ and f"/ {test_kind} (" in j['name']
+ ]
+ if test_jobs:
+ return test_kind, test_jobs
+ return CUDA_TEST_KINDS[0], []
+
+def get_cuda_inductor_test_jobs(jobs):
+ """Return the CUDA inductor test kind and jobs for either main or PR CI layouts."""
+ inductor_prefix = CUDA_JOB_PREFIXES["inductor"]
+ for test_kind in CUDA_TEST_KINDS:
+ test_jobs = [
+ j for j in jobs
+ if f"{inductor_prefix} / {test_kind} (inductor," in j['name']
+ ]
+ if test_jobs:
+ return test_kind, test_jobs
+ return CUDA_TEST_KINDS[0], []
+
+def download_logs(wf, test_log_list, test_folder, jobs=None):
+ if wf is None:
+ raise Exception("wf is None!")
+
+ if jobs is None:
+ jobs = get_workflow_jobs(wf)
+
+ for test_log in test_log_list:
+ write_out_file = test_folder + "/" + test_log[0]
+ write_test_log_to_file(write_out_file, test_log[1], jobs, wf['head_sha'])
+
+def download_gha_artifacts_filtered(workflow_run_id, workflow_run_attempt, prefixes=[], allowed_substrings=None):
+ """Download GHA artifacts matching prefixes and optional substring filters.
+
+ GHA artifact names include run attempt info, e.g.:
+ test-reports-runattempt1-test-default-3-6-linux.rocm.gpu.gfx942.1_68425162477.zip
+ while S3 prefixes look like:
+ test-reports-test-default-3-6
+ We strip the runattemptN- portion before matching prefixes.
+
+ When a shard is re-run, only the latest attempt's artifact exists for that
+ shard, while other shards keep their original attempt. We collect all
+ matching artifacts and prefer the highest run attempt per shard key.
+ """
+ from pathlib import Path
+ from collections import defaultdict
+ artifact_paths = []
+ response = requests.get(
+ f"https://api.github.com/repos/pytorch/pytorch/actions/runs/{workflow_run_id}/artifacts?per_page=100",
+ headers=authentication_headers,
+ )
+ artifacts = response.json().get("artifacts", [])
+ while "next" in response.links:
+ response = requests.get(response.links["next"]["url"], headers=authentication_headers)
+ artifacts.extend(response.json().get("artifacts", []))
+
+ # Group matching artifacts by shard key, keeping highest run attempt
+ # shard key = normalized name without runattemptN- and without runner/jobid suffix
+ best_per_shard = {}
+ for artifact in artifacts:
+ name = artifact["name"]
+ if not name.startswith("test-reports-"):
+ continue
+ if "rerun_disabled" in name:
+ continue
+ normalized = re.sub(r'runattempt\d+-', '', name)
+ if not any(normalized.startswith(pfx) for pfx in prefixes):
+ continue
+ if allowed_substrings and not any(sub in name for sub in allowed_substrings):
+ continue
+ # Extract run attempt number
+ attempt_match = re.search(r'runattempt(\d+)', name)
+ attempt_num = int(attempt_match.group(1)) if attempt_match else 0
+ # Use the shard portion as key (e.g., test-reports-test-default-3-6)
+ shard_key = re.sub(r'-[a-z]+\..*$', '', normalized)
+ if shard_key not in best_per_shard or attempt_num > best_per_shard[shard_key][0]:
+ best_per_shard[shard_key] = (attempt_num, name, artifact["archive_download_url"])
+
+ for shard_key, (attempt_num, name, url) in best_per_shard.items():
+ print(f"Downloading GHA artifact: {name}")
+ dl_response = requests.get(url, headers=authentication_headers)
+ if dl_response.status_code != 200:
+ print(f" WARNING: Failed to download (HTTP {dl_response.status_code})")
+ continue
+ p = Path(name if name.endswith(".zip") else name + ".zip")
+ with open(p, "wb") as f:
+ f.write(dl_response.content)
+ artifact_paths.append(p)
+
+ return artifact_paths
+
+def _shorten_unzipped_dirs():
+ """Rename unzipped-* directories to short names for Windows MAX_PATH compatibility.
+
+ Converts names like:
+ unzipped-test-reports-runattempt1-test-default-1-6-linux.rocm.gpu.gfx942.1_68613413431.zip
+ unzipped-test-reports-runattempt1-test-osdc-default-1-5-mt-l-x86aavx2-29-113-l4_73385044118.zip
+ to:
+ test-default-1-6_68613413431
+ test-default-1-5_73385044118
+
+ Preserves the 'test-' prefix so that summarize_xml_testreports.py
+ can still detect workflow type via substring matching. The trailing
+ '_' is the upstream pytorch CI job id, used to link to the
+ failing job from the parity summary.
+ """
+ from pathlib import Path
+ for d in sorted(Path(".").glob("unzipped-*")):
+ if not d.is_dir():
+ continue
+ m = re.search(r'test-(?:osdc-)?(default|distributed|inductor)-(\d+)-(\d+)', d.name)
+ if m:
+ short_name = f"test-{m.group(1)}-{m.group(2)}-{m.group(3)}"
+ # The original artifact name ends with "_.zip" where
+ # is the upstream pytorch CI job id (e.g.
+ # ..._68613413431.zip). Carry it onto short_name so
+ # summarize_xml_testreports.py can link to that job.
+ job_id_match = re.search(r'_(\d{6,})\.zip$', d.name)
+ if job_id_match:
+ short_name += f"_{job_id_match.group(1)}"
+ if not Path(short_name).exists():
+ d.rename(short_name)
+ print(f" Renamed {d.name} -> {short_name}")
+ else:
+ print(f" WARNING: {short_name} already exists, keeping {d.name}")
+
+def download_xml_files(workflow_run_id, workflow_run_attempts, prefixes=[], allowed_substrings=None):
+ # Get from S3 artifacts.
+ #
+ # When an upstream run is re-run, GitHub only re-executes the jobs that
+ # failed; jobs that already succeeded are not re-run and keep their
+ # artifacts under the attempt in which they ran. So a run whose current
+ # attempt is N can still have test reports sitting under an earlier
+ # attempt's S3 path (e.g. CUDA test-osdc shards uploaded on attempt 1
+ # while a retried ROCm shard lives under attempt 2). Searching only the
+ # latest attempt misses those carried-over artifacts. For each prefix
+ # (one shard) try attempts from the latest down to 1 and take the first
+ # (highest) attempt that has it, mirroring the GHA fallback's
+ # "prefer highest attempt per shard" behaviour.
+ artifact_paths = []
+ try:
+ max_attempt = int(workflow_run_attempts)
+ except (TypeError, ValueError):
+ max_attempt = 1
+ if max_attempt < 1:
+ max_attempt = 1
+ for prefix in prefixes:
+ for attempt in range(max_attempt, 0, -1):
+ print("Trying to download S3 artifacts for workflow_run_attempt {} with prefix {}".format(attempt, prefix))
+ found = download_s3_artifacts(
+ prefix,
+ workflow_run_id,
+ attempt,
+ allowed_substrings=allowed_substrings,
+ )
+ if found:
+ artifact_paths += found
+ break
+
+ # Filter out rerun_disabled_tests artifacts (same prefix, different job)
+ before = len(artifact_paths)
+ artifact_paths = [p for p in artifact_paths if "rerun_disabled" not in p.name]
+ if before != len(artifact_paths):
+ print(f" Filtered out {before - len(artifact_paths)} rerun_disabled artifacts")
+
+ # Fall back to GHA artifacts if S3 returned nothing
+ if len(artifact_paths) == 0:
+ print(f"No S3 artifacts found, trying GHA artifacts as fallback...")
+ artifact_paths = download_gha_artifacts_filtered(
+ workflow_run_id,
+ workflow_run_attempts,
+ prefixes=prefixes,
+ allowed_substrings=allowed_substrings,
+ )
+
+ if len(artifact_paths) == 0:
+ error_msg = f"WARNING: workflow run id: {workflow_run_id} - no artifacts found (S3 or GHA) for prefixes: {prefixes}"
+ print(error_msg)
+ error_msgs.append(error_msg)
+ return
+
+ for path in artifact_paths:
+ unzip(path)
+
+ _shorten_unzipped_dirs()
+
+ with open("_wf_run_id", "w") as f:
+ f.write(str(workflow_run_id))
+
+ # Delete raw zip files now that contents are extracted
+ for path in artifact_paths:
+ try:
+ path.unlink()
+ print(f" Deleted {path}")
+ except Exception:
+ pass
+
+def download_artifacts(wf, prefixes=[], test_folder=".", allowed_substrings=None):
+ os.chdir(test_folder)
+ #download the xml files
+ download_xml_files(
+ wf['id'],
+ wf.get('run_attempt',1),
+ prefixes,
+ allowed_substrings=allowed_substrings,
+ )
+ os.chdir("..")
+# for older runs, add 'created':'<=YYYY-MM-DD'. see https://docs.github.com/en/search-github/getting-started-with-searching-on-github/understanding-the-search-syntax#query-for-dates
+def download_workflow_run(created=None, max_pages=10, workflow=None, sha=None, ignore_status=False, status='success', error_msg='Error downloading workflow runs'):
+ if not workflow:
+ raise Exception("Workflow must be specified")
+ for page in range(max_pages):
+ params = {'per_page': 30, 'page': page}
+ if not ignore_status:
+ if status:
+ params['status'] = status
+ if created:
+ params['created'] = created
+ if sha:
+ params['head_sha'] = sha
+ else:
+ params['branch'] = "main"
+ print(".")
+
+ # Uncomment below for additional debug info
+ # print(f"authentication_headers: {authentication_headers}")
+ # print(f"params: {params}")
+ # print("https://api.github.com/repos/pytorch/pytorch/actions/workflows/{}.yml/runs".format(workflow))
+ response = requests.get("https://api.github.com/repos/pytorch/pytorch/actions/workflows/{}.yml/runs".format(workflow), headers=authentication_headers, params=params)
+ #print(response.json())
+ workflow_runs = None
+ try:
+ workflow_runs = response.json()['workflow_runs']
+ except:
+ raise Exception(response.text)
+ if not workflow_runs:
+ continue
+ # Prefer completed runs over in-progress ones. When multiple
+ # runs exist for the same SHA, the most recent may still be
+ # running and have no artifacts yet.
+ completed = [wf for wf in workflow_runs if wf.get('status') == 'completed']
+ if completed:
+ return completed[0]
+ return workflow_runs[0]
+
+ # Should not reach here ideally
+ raise Exception(error_msg)
+
+def create_test_folder(wf):
+ if wf is None:
+ raise Exception("wf is None!")
+ #return
+ test_folder = re.sub('T.*Z', '', wf['created_at'].replace(":", "").replace("-", "")) + "_" + wf['head_sha']
+ if not os.path.exists(test_folder):
+ os.mkdir(test_folder)
+
+ cuda_xml_folder = test_folder + "/cuda_xml"
+ if not os.path.exists(cuda_xml_folder):
+ os.mkdir(cuda_xml_folder)
+
+ rocm_xml_folder = test_folder + "/rocm_xml"
+ if not os.path.exists(rocm_xml_folder):
+ os.mkdir(rocm_xml_folder)
+ return [test_folder, cuda_xml_folder, rocm_xml_folder]
+
+_first_folder = None
+
+def get_or_create_test_folder(wf):
+ """Reuse the first folder created so all artifacts land in one place.
+
+ Different upstream workflows for the same SHA can have different created_at
+ dates (e.g. spanning midnight), which would cause create_test_folder to
+ create separate directories. This wrapper ensures every call returns the
+ same folder that was established by the very first invocation.
+ """
+ global _first_folder
+ if _first_folder is not None:
+ test_folder = _first_folder
+ cuda_xml = test_folder + "/cuda_xml"
+ rocm_xml = test_folder + "/rocm_xml"
+ os.makedirs(cuda_xml, exist_ok=True)
+ os.makedirs(rocm_xml, exist_ok=True)
+ return [test_folder, cuda_xml, rocm_xml]
+ result = create_test_folder(wf)
+ _first_folder = result[0]
+ return result
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='Download pytorch unit test logs')
+ parser.add_argument('--created', const=None, help='eg., \'<=YYYY-MM-DD\'. See https://docs.github.com/en/search-github/getting-started-with-searching-on-github/understanding-the-search-syntax#query-for-dates')
+ parser.add_argument('--max_pages', type=int, default=10, help='eg., 100')
+ parser.add_argument('--sha1', const=None, help='eg., 3dcd67a1b374faea01f4d2e17beb6bb1fff76d76')
+ parser.add_argument('--exclude_distributed', action='store_true')
+ parser.add_argument('--exclude_inductor', action='store_true')
+ parser.add_argument('--exclude_default', action='store_true')
+ parser.add_argument('--ignore_status', action='store_true')
+ parser.add_argument('--artifacts_only', action='store_true')
+ parser.add_argument('--no_rocm', action='store_true')
+ parser.add_argument('--no_cuda', action='store_true')
+ parser.add_argument('--pr_id', type=int, help='The pull request ID')
+ parser.add_argument('--arch', type=str, choices=['mi200', 'mi300', 'mi350', 'navi31', 'nightly'], default='mi350', help='ROCm GPU architecture (mi200, mi300, mi350, navi31, or nightly, default: mi350)')
+ parser.add_argument('--include_inductor_periodic', action='store_true', help='Also download inductor-periodic benchmark artifacts (into a separate directory, not included in parity CSV)')
+ parser.add_argument('--baseline_sha', type=str, help='Baseline commit SHA to compare against. Downloads the same ROCm workflows for this commit into baseline_xml/.')
+ return parser.parse_args()
+
+# Rate-limit issues
+# Authenticated users get 5000 requests/day
+# Check rate-limit without penalty: curl -H "Authorization: token $GITHUB_TOKEN" -I https://api.github.com/users/octocat
+
+def resolve_full_trunk_run(sha, cuda_job_prefix, status='success', ignore_status=False):
+ """Resolve the canonical trunk.yml run for a SHA.
+
+ Upstream can produce more than one trunk.yml run for the same commit (for
+ example a reduced re-trigger that omits CUDA and only re-runs a subset of
+ the ROCm shards). The parity report has to read ROCm and CUDA from the
+ SAME full trunk push, otherwise it pairs a partial test set on one side
+ with the full set on the other and most columns come back empty. Only the
+ full push runs the CUDA test jobs, so anchor on the run that contains them
+ (falling back to the newest completed run when none can be found).
+
+ Returns (trunk_wf, cuda_test_jobs, cuda_test_kind, all_cuda_jobs).
+ """
+ params = {'per_page': 10, 'head_sha': sha}
+ if not ignore_status and status:
+ params['status'] = status
+ resp = requests.get(
+ f"https://api.github.com/repos/pytorch/pytorch/actions/workflows/{CUDAWorkflowNames['default']}.yml/runs",
+ headers=authentication_headers, params=params,
+ )
+ trunk_runs = resp.json().get('workflow_runs', [])
+ # Prefer the normal push trunk run over scheduled trunk runs. The scheduled
+ # trunk.yml runs are the periodic rerun_disabled_tests / mem_leak_check
+ # variants: their test jobs share the same names but run in a degenerate
+ # mode and produce near-empty reports. A stable sort keeps the API's
+ # newest-first order within each group, so we pick the newest push run.
+ trunk_runs.sort(key=lambda r: r.get('event') != 'push')
+
+ trunk_wf = None
+ cuda_test_jobs = []
+ cuda_test_kind = CUDA_TEST_KINDS[0]
+ all_cuda_jobs = []
+
+ for run in trunk_runs:
+ jobs = get_workflow_jobs(run)
+ kind, test_jobs = get_cuda_test_jobs(jobs, cuda_job_prefix)
+ if test_jobs:
+ trunk_wf = run
+ all_cuda_jobs = jobs
+ cuda_test_jobs = test_jobs
+ cuda_test_kind = kind
+ print(f"Found CUDA test jobs in trunk run {run['id']}")
+ break
+
+ if not cuda_test_jobs and trunk_runs:
+ # CUDA jobs may live in a run the jobs API does not surface (e.g. a
+ # reusable workflow); the check-runs API resolves the actual run.
+ print("No CUDA test jobs in any trunk run's jobs API, trying check-runs API...")
+ check_runs = get_check_runs_for_commit(sha, cuda_job_prefix)
+ cuda_test_kind, cuda_test_jobs = get_cuda_test_jobs(check_runs, cuda_job_prefix)
+ if cuda_test_jobs:
+ run_match = re.search(r'/runs/(\d+)/', cuda_test_jobs[0].get('details_url', ''))
+ if run_match:
+ actual_run_id = int(run_match.group(1))
+ trunk_wf = next((r for r in trunk_runs if r['id'] == actual_run_id), None)
+ if trunk_wf is None:
+ resp = requests.get(
+ f"https://api.github.com/repos/pytorch/pytorch/actions/runs/{actual_run_id}",
+ headers=authentication_headers,
+ )
+ trunk_wf = resp.json()
+ print(f"CUDA test jobs are in trunk run {trunk_wf['id']} (found via check-runs)")
+ all_cuda_jobs = list(cuda_test_jobs)
+
+ if trunk_wf is None and trunk_runs:
+ trunk_wf = trunk_runs[0]
+
+ return trunk_wf, cuda_test_jobs, cuda_test_kind, all_cuda_jobs
+
+
+def main():
+ global args
+ args = parse_args()
+ if args.max_pages < 1:
+ args.max_pages=1
+
+ # All arch-specific matching (which upstream workflow each test type lives
+ # in, job-name prefixes, shard counts and fallbacks) comes from
+ # parity_job_config.json - see PARITY_CONFIG at the top of this file.
+ global ROCmWorkflowNames
+ arch = args.arch # 'mi200', 'mi300', 'mi350', 'navi31', or 'nightly'
+ arch_config = PARITY_CONFIG["rocm"][arch]
+
+ ROCmWorkflowNames, rocm_job_prefix, arch_fallbacks = parity_config_views(arch_config)
+ rocm_shards = dict(arch_config["shard_counts"])
+ rocm_artifact_substrings = ROCM_ARTIFACT_SUBSTRINGS
+
+ # Fallback (workflow, job_prefix) to try per test type when the primary
+ # workflow run is missing for a SHA. Kept as tuples keyed by arch so the
+ # existing fallback logic below is unchanged.
+ def _fallbacks_for(test_type):
+ result = {}
+ for other_arch, other_config in PARITY_CONFIG["rocm"].items():
+ fallback = parity_config_views(other_config)[2].get(test_type)
+ if fallback:
+ result[other_arch] = (fallback["workflow"], fallback["job_prefix"])
+ return result
+ # navi31 only has default tests (no distributed/inductor workflows)
+ if arch in ("navi31",):
+ if not args.exclude_distributed:
+ print(f"NOTE: {arch} has no distributed workflow, auto-excluding distributed")
+ args.exclude_distributed = True
+ if not args.exclude_inductor:
+ print(f"NOTE: {arch} has no inductor workflow, auto-excluding inductor")
+ args.exclude_inductor = True
+ if args.baseline_sha and not args.no_cuda:
+ print("NOTE: baseline_sha provided, auto-skipping CUDA (commit-vs-commit comparison)")
+ args.no_cuda = True
+
+ print(f"Using ROCm architecture: {arch}")
+ print(f"Using ROCm workflows: {ROCmWorkflowNames}")
+ if not args.no_cuda:
+ print(f"Using CUDA workflows: {CUDAWorkflowNames}")
+ print(f"Using ROCm job prefixes: {rocm_job_prefix}")
+ print(f"Using initial ROCm shard counts (may be updated based on actual workflow used): {rocm_shards}")
+
+ token = os.getenv('GITHUB_TOKEN', '...')
+ global authentication_headers
+ authentication_headers = {'Authorization': f'token {token}'}
+ if (args.pr_id and args.sha1) or (not args.pr_id and not args.sha1):
+ error_msg = "Error: Please provide either pr_id or sha!"
+ print(error_msg)
+ sys.exit(1)
+ if args.pr_id:
+ pr_id = args.pr_id
+ sha = get_latest_commit_sha(pr_id, token)
+ else:
+ sha = args.sha1
+ pr_id = None
+ status = "success"
+ print(f"sha: {sha}")
+
+ # When comparing two commits, prefix log filenames with short SHAs
+ if args.baseline_sha:
+ current_prefix = sha[:8] + "_"
+ baseline_prefix = args.baseline_sha[:8] + "_"
+ else:
+ current_prefix = ""
+ baseline_prefix = "baseline_"
+
+ # Resolve the canonical trunk run for this SHA exactly once, so every
+ # section that reads from trunk (ROCm default, the ROCm distributed ->
+ # trunk fallback, and CUDA) shares the SAME full trunk push. Without this,
+ # ROCm picked the newest completed trunk run while CUDA picked the run
+ # carrying CUDA jobs; when upstream had a reduced re-trigger for the SHA
+ # those were different runs, pairing a partial set on one side with the
+ # full set on the other (mostly-empty columns).
+ arch_uses_trunk = "trunk" in ROCmWorkflowNames.values()
+ fallback_uses_trunk = any(
+ fb.get("workflow") == "trunk"
+ for fb in arch_fallbacks.values()
+ )
+ trunk_full_wf = None
+ trunk_cuda_test_jobs = []
+ trunk_cuda_test_kind = CUDA_TEST_KINDS[0]
+ trunk_all_cuda_jobs = []
+ if (not args.no_cuda) or arch_uses_trunk or fallback_uses_trunk:
+ print("==============================================")
+ print(f"Resolving canonical trunk run for sha: {sha}")
+ print("==============================================")
+ trunk_full_wf, trunk_cuda_test_jobs, trunk_cuda_test_kind, trunk_all_cuda_jobs = \
+ resolve_full_trunk_run(sha, CUDA_JOB_PREFIXES["default"], status, args.ignore_status)
+ if trunk_full_wf is not None:
+ print(f"Using trunk run {trunk_full_wf['id']} as the canonical run for this SHA "
+ f"({len(trunk_cuda_test_jobs)} CUDA test jobs); ROCm-trunk and CUDA share it")
+
+ if not args.exclude_distributed and not args.no_rocm:
+ periodic_sha = sha
+ print("==============================================")
+ print(f"Finding ROCm distributed tests in workflow '{ROCmWorkflowNames['distributed']}' by sha: {sha}")
+ print("==============================================")
+ # find distributed test in periodic workflow with success status
+ error_msg="Error: Periodic workflow not found in scanned workflow runs."
+ #https://docs.github.com/en/rest/actions/workflow-runs#list-workflow-runs-for-a-repository
+ periodic_fallback_used = False
+ try:
+ periodic_wf = download_workflow_run(created=args.created, max_pages=args.max_pages, workflow=ROCmWorkflowNames["distributed"], sha=periodic_sha, ignore_status=args.ignore_status, status=status, error_msg=error_msg)
+ except (IndexError, Exception):
+ periodic_wf = None
+ periodic_fallbacks = _fallbacks_for("distributed")
+ if periodic_wf is None and arch in periodic_fallbacks:
+ fallback_wf, fallback_prefix = periodic_fallbacks[arch]
+ print(f"Distributed not found in {ROCmWorkflowNames['distributed']}, falling back to {fallback_wf}")
+ if fallback_wf == "trunk" and trunk_full_wf is not None:
+ # Reuse the canonical trunk run so distributed reads from the
+ # same full push as CUDA (see resolve_full_trunk_run).
+ periodic_wf = trunk_full_wf
+ else:
+ periodic_wf = download_workflow_run(created=args.created, max_pages=args.max_pages, workflow=fallback_wf, sha=periodic_sha, ignore_status=args.ignore_status, status=status, error_msg=error_msg)
+ periodic_fallback_used = True
+ if periodic_wf is None:
+ raise Exception(error_msg)
+ dist_wf_name = ROCmWorkflowNames['distributed'] if not periodic_fallback_used else periodic_fallbacks[arch][0]
+ print(f"Using workflow '{dist_wf_name}' with id:{periodic_wf['id']} for ROCm distributed")
+
+ if periodic_fallback_used and arch in periodic_fallbacks:
+ dist_job_prefix = periodic_fallbacks[arch][1]
+ else:
+ dist_job_prefix = rocm_job_prefix['distributed']
+
+ folder_list = get_or_create_test_folder(periodic_wf)
+
+ # Download logs
+ # If the ROCm distributed logs aren't found you might want to check the HUD for the correct tags
+ # HUD link: https://hud.pytorch.org/hud/pytorch/pytorch/main/1?per_page=50&name_filter=rocm
+ # Make sure "Hide unstable jobs" is unselected, in case ROCm jobs are marked as unstable
+
+ if arch == "mi350":
+ dist_shards = 3 if not periodic_fallback_used else rocm_shards["distributed"]
+ else:
+ dist_shards = rocm_shards["distributed"]
+ print(f"Using final ROCm shard count {dist_shards} for distributed")
+
+ if not args.artifacts_only:
+ test_log_list_rocm_distributed = [
+ [f"{current_prefix}rocm_dist{i}.txt", f"{dist_job_prefix} / test (distributed, {i}, {dist_shards}"]
+ for i in range(1, dist_shards + 1)
+ ]
+ download_logs(periodic_wf, test_log_list_rocm_distributed, folder_list[0])
+
+ # Download artifacts
+ test_artifacts_list_rocm_distributed = [
+ f"test-reports-test-distributed-{i}-{dist_shards}"
+ for i in range(1, dist_shards + 1)
+ ]
+ download_artifacts(
+ periodic_wf,
+ test_artifacts_list_rocm_distributed,
+ folder_list[2],
+ allowed_substrings=rocm_artifact_substrings,
+ )
+ os.chdir("..")
+
+ # Download ROCm default rocm_wf when ROCm is enabled
+ if not args.no_rocm and not args.exclude_default:
+ rocm_sha = sha
+ print("===========================================")
+ print(f"Finding ROCm default tests in workflow '{ROCmWorkflowNames['default']}' by sha: {rocm_sha}")
+ print("===========================================")
+ # find tests in rocm workflow with given sha and success status
+ #https://docs.github.com/en/rest/actions/workflow-runs#list-workflow-runs-for-a-repository
+ error_msg="Error: rocm workflow not found in scanned workflow runs. Try increasing max_pages."
+ default_fallback_used = False
+ if ROCmWorkflowNames["default"] == "trunk" and trunk_full_wf is not None:
+ # Read ROCm default from the canonical trunk run rather than the
+ # newest completed one, so it matches the CUDA side (see
+ # resolve_full_trunk_run).
+ rocm_wf = trunk_full_wf
+ else:
+ try:
+ rocm_wf = download_workflow_run(created=args.created, max_pages=args.max_pages, workflow=ROCmWorkflowNames["default"], sha=rocm_sha, ignore_status=args.ignore_status, status=status, error_msg=error_msg)
+ except (IndexError, Exception):
+ rocm_wf = None
+ default_fallbacks = _fallbacks_for("default")
+ if rocm_wf is None and arch in default_fallbacks:
+ fallback_wf, fallback_prefix = default_fallbacks[arch]
+ print(f"Default not found in {ROCmWorkflowNames['default']}, falling back to {fallback_wf}")
+ rocm_wf = download_workflow_run(created=args.created, max_pages=args.max_pages, workflow=fallback_wf, sha=rocm_sha, ignore_status=args.ignore_status, status=status, error_msg=error_msg)
+ default_fallback_used = True
+ rocm_job_prefix['default'] = fallback_prefix
+ if rocm_wf is None:
+ raise Exception(error_msg)
+ default_wf_name = ROCmWorkflowNames['default'] if not default_fallback_used else default_fallbacks[arch][0]
+ print(f"Using workflow '{default_wf_name}' with id:{rocm_wf['id']} for ROCm default{' (fallback)' if default_fallback_used else ''}")
+
+ folder_list = get_or_create_test_folder(rocm_wf)
+
+ # Download logs
+ # If logs aren't found you might want to check the HUD for the correct tags
+ # HUD link: https://hud.pytorch.org/hud/pytorch/pytorch/main/1?per_page=50&name_filter=rocm
+ if arch == "mi350":
+ default_shards = 6 if default_fallback_used else rocm_shards["default"]
+ else:
+ default_shards = rocm_shards["default"]
+ print(f"Using final ROCm shard count {default_shards} for default")
+
+ if not args.artifacts_only:
+ test_log_list_rocm_default = [
+ [f"{current_prefix}rocm{i}.txt", f"{rocm_job_prefix['default']} / test (default, {i}, {default_shards}"]
+ for i in range(1, default_shards + 1)
+ ]
+ download_logs(rocm_wf, test_log_list_rocm_default, folder_list[0])
+
+ # Download artifacts
+ test_artifacts_list_rocm_default = [
+ f"test-reports-test-default-{i}-{default_shards}"
+ for i in range(1, default_shards + 1)
+ ]
+ if not args.exclude_default:
+ download_artifacts(
+ rocm_wf,
+ test_artifacts_list_rocm_default,
+ test_folder=folder_list[2],
+ allowed_substrings=rocm_artifact_substrings,
+ )
+ os.chdir("..")
+
+ # add new inductor workflow downloading for ROCm
+ if not args.no_rocm and not args.exclude_inductor:
+ inductor_rocm_sha = sha
+ # find tests in inductor workflow with given sha and success status
+ #https://docs.github.com/en/rest/actions/workflow-runs#list-workflow-runs-for-a-repository
+ print("===========================================")
+ print(f"Finding ROCm inductor tests in workflow '{ROCmWorkflowNames['inductor']}' by sha: {inductor_rocm_sha}")
+ print("===========================================")
+ error_msg="Error: inductor workflow not found in scanned workflow runs. Try increasing max_pages."
+ inductor_fallback_used = False
+ try:
+ inductor_wf_rocm = download_workflow_run(created=args.created, max_pages=args.max_pages, workflow=ROCmWorkflowNames["inductor"], sha=inductor_rocm_sha, ignore_status=args.ignore_status, status=status, error_msg=error_msg)
+ except (IndexError, Exception):
+ inductor_wf_rocm = None
+ inductor_fallbacks = _fallbacks_for("inductor")
+ if inductor_wf_rocm is None and arch in inductor_fallbacks:
+ fallback_wf, fallback_prefix = inductor_fallbacks[arch]
+ print(f"Inductor not found in {ROCmWorkflowNames['inductor']}, falling back to {fallback_wf}")
+ inductor_wf_rocm = download_workflow_run(created=args.created, max_pages=args.max_pages, workflow=fallback_wf, sha=inductor_rocm_sha, ignore_status=args.ignore_status, status=status, error_msg=error_msg)
+ inductor_fallback_used = True
+ if inductor_wf_rocm is None:
+ # Inductor is optional and does not run for every SHA. Skip it
+ # instead of aborting the whole run, which would also drop the
+ # CUDA download below and publish an empty report.
+ print(f"WARNING: {error_msg} Skipping ROCm inductor for this SHA.")
+ else:
+ inductor_wf_name = ROCmWorkflowNames['inductor'] if not inductor_fallback_used else inductor_fallbacks[arch][0]
+ print(f"Using workflow '{inductor_wf_name}' with id:{inductor_wf_rocm['id']} for ROCm inductor")
+
+ folder_list = get_or_create_test_folder(inductor_wf_rocm)
+
+ inductor_shards = rocm_shards["inductor"]
+ print(f"Using final ROCm shard count {inductor_shards} for inductor")
+ if inductor_fallback_used and arch in inductor_fallbacks:
+ inductor_job_prefix = inductor_fallbacks[arch][1]
+ else:
+ inductor_job_prefix = rocm_job_prefix['inductor']
+
+ # Download logs
+ if not args.artifacts_only:
+ test_log_list_rocm_inductor = [
+ [f"{current_prefix}rocm_inductor{i}.txt", f"{inductor_job_prefix} / test (inductor, {i}, {inductor_shards}"]
+ for i in range(1, inductor_shards + 1)
+ ]
+ download_logs(inductor_wf_rocm, test_log_list_rocm_inductor, folder_list[0])
+
+ #Download artifacts
+ test_artifacts_list_rocm_inductor = [
+ f"test-reports-test-inductor-{i}-{inductor_shards}"
+ for i in range(1, inductor_shards + 1)
+ ]
+ download_artifacts(
+ inductor_wf_rocm,
+ test_artifacts_list_rocm_inductor,
+ test_folder=folder_list[2],
+ allowed_substrings=rocm_artifact_substrings,
+ )
+ os.chdir("..")
+
+ if not args.no_cuda:
+ cuda_config = PARITY_CONFIG["cuda"]
+ cuda_job_prefix = CUDA_JOB_PREFIXES["default"]
+ cuda_shards = cuda_config["shard_counts"]
+ print("==========================================")
+ print(f"Finding CUDA tests in workflow '{CUDAWorkflowNames['default']}' by sha: {sha}")
+ print("==========================================")
+
+ # Reuse the canonical trunk run resolved once at the top of main() so
+ # CUDA reads from the same full trunk push as the ROCm-trunk sections
+ # (see resolve_full_trunk_run).
+ trunk_wf = trunk_full_wf
+ cuda_test_jobs = trunk_cuda_test_jobs
+ cuda_test_job_kind = trunk_cuda_test_kind
+ all_cuda_jobs = trunk_all_cuda_jobs
+ if trunk_wf is None:
+ raise Exception("Error: No trunk workflow run found for CUDA tests")
+
+ print(f"Using workflow '{CUDAWorkflowNames['default']}' with id:{trunk_wf['id']} for CUDA default")
+
+ # Match artifacts by job ID across ALL run attempts. If trunk was
+ # re-run, CUDA shards that passed earlier are carried over and their
+ # S3 reports keep the original attempt's job ID; filtering only by the
+ # latest attempt's IDs would silently drop them (missing CUDA columns).
+ all_attempt_cuda_jobs = get_workflow_jobs(trunk_wf, all_attempts=True)
+ _, all_attempt_cuda_test_jobs = get_cuda_test_jobs(all_attempt_cuda_jobs, cuda_job_prefix)
+ cuda_job_ids = sorted({str(j['id']) for j in (cuda_test_jobs + all_attempt_cuda_test_jobs)})
+ cuda_artifact_substrings = [f"_{jid}" for jid in cuda_job_ids] if cuda_job_ids else ["nvidia.gpu"]
+ print(f"Using CUDA job prefix: {cuda_job_prefix}")
+ print(f"Using CUDA test job kind: {cuda_test_job_kind}")
+ print(f"Found {len(cuda_test_jobs)} CUDA test jobs matching prefix")
+
+ folder_list = get_or_create_test_folder(trunk_wf)
+
+ cuda_default_shards = cuda_shards["default"]
+ cuda_dist_shards = cuda_shards["distributed"]
+
+ # Download logs
+ if not args.artifacts_only:
+ test_log_list_cuda = [
+ [f"cuda{i}.txt", f"{cuda_job_prefix} / {cuda_test_job_kind} (default, {i}, {cuda_default_shards}"]
+ for i in range(1, cuda_default_shards + 1)
+ ]
+ if not args.exclude_distributed:
+ test_log_list_cuda += [
+ [f"cuda_dist{i}.txt", f"{cuda_job_prefix} / {cuda_test_job_kind} (distributed, {i}, {cuda_dist_shards}"]
+ for i in range(1, cuda_dist_shards + 1)
+ ]
+
+ download_logs(trunk_wf, test_log_list_cuda, folder_list[0], jobs=all_cuda_jobs)
+
+ # Download artifacts
+ test_artifacts_list_cuda = []
+ if not args.exclude_default:
+ test_artifacts_list_cuda += [
+ f"test-reports-{cuda_test_job_kind}-default-{i}-{cuda_default_shards}"
+ for i in range(1, cuda_default_shards + 1)
+ ]
+
+ if not args.exclude_distributed:
+ test_artifacts_list_cuda += [
+ f"test-reports-{cuda_test_job_kind}-distributed-{i}-{cuda_dist_shards}"
+ for i in range(1, cuda_dist_shards + 1)
+ ]
+
+ if test_artifacts_list_cuda:
+ download_artifacts(
+ trunk_wf,
+ test_artifacts_list_cuda,
+ test_folder=folder_list[1],
+ allowed_substrings=cuda_artifact_substrings,
+ )
+ os.chdir("..")
+
+ # add new inductor workflow downloading for CUDA
+ if not args.exclude_inductor:
+ inductor_sha = sha
+ print("==========================================")
+ print(f"Finding CUDA inductor tests in workflow '{CUDAWorkflowNames['inductor']}' by sha: {inductor_sha}")
+ print("==========================================")
+ # find tests in inductor workflow with given sha and success status
+ #https://docs.github.com/en/rest/actions/workflow-runs#list-workflow-runs-for-a-repository
+ error_msg="Error: inductor workflow not found in scanned workflow runs. Try increasing max_pages."
+ inductor_wf_cuda = download_workflow_run(created=args.created, max_pages=args.max_pages, workflow=CUDAWorkflowNames["inductor"], sha=inductor_sha, ignore_status=args.ignore_status, status=status, error_msg=error_msg)
+ print(f"Using workflow '{CUDAWorkflowNames['inductor']}' with id:{inductor_wf_cuda['id']} for CUDA inductor")
+
+ inductor_cuda_jobs = get_workflow_jobs(inductor_wf_cuda)
+ cuda_inductor_test_job_kind, cuda_inductor_test_jobs = get_cuda_inductor_test_jobs(inductor_cuda_jobs)
+ # Same carried-over-on-rerun reasoning as CUDA default above: match
+ # artifacts by job ID across all attempts of the inductor run.
+ all_attempt_inductor_jobs = get_workflow_jobs(inductor_wf_cuda, all_attempts=True)
+ _, all_attempt_inductor_test_jobs = get_cuda_inductor_test_jobs(all_attempt_inductor_jobs)
+ cuda_inductor_job_ids = sorted({str(j['id']) for j in (cuda_inductor_test_jobs + all_attempt_inductor_test_jobs)})
+ cuda_inductor_artifact_substrings = (
+ [f"_{jid}" for jid in cuda_inductor_job_ids]
+ if cuda_inductor_job_ids
+ else None
+ )
+ print(f"Using CUDA inductor test job kind: {cuda_inductor_test_job_kind}")
+ print(f"Found {len(cuda_inductor_test_jobs)} CUDA inductor test jobs")
+
+ folder_list = get_or_create_test_folder(inductor_wf_cuda)
+
+ cuda_inductor_prefix = CUDA_JOB_PREFIXES["inductor"]
+ cuda_inductor_shards = cuda_shards["inductor"]
+
+ # Download logs
+ if not args.artifacts_only:
+ test_log_list_cuda_inductor = [
+ [f"cuda_inductor{i}.txt", f"{cuda_inductor_prefix} / {cuda_inductor_test_job_kind} (inductor, {i}, {cuda_inductor_shards}"]
+ for i in range(1, cuda_inductor_shards + 1)
+ ]
+ download_logs(inductor_wf_cuda, test_log_list_cuda_inductor, folder_list[0], jobs=inductor_cuda_jobs)
+
+ test_artifacts_list_cuda_inductor = [
+ f"test-reports-{cuda_inductor_test_job_kind}-inductor-{i}-{cuda_inductor_shards}"
+ for i in range(1, cuda_inductor_shards + 1)
+ ]
+ download_artifacts(
+ inductor_wf_cuda,
+ test_artifacts_list_cuda_inductor,
+ test_folder=folder_list[1],
+ allowed_substrings=cuda_inductor_artifact_substrings,
+ )
+ os.chdir("..")
+
+ # Download baseline commit artifacts for commit-vs-commit comparison
+ if args.baseline_sha and not args.no_rocm:
+ baseline_sha = args.baseline_sha
+ print("==============================================")
+ print(f"Downloading BASELINE ROCm artifacts for sha: {baseline_sha}")
+ print("==============================================")
+
+ import glob
+ existing_folders = sorted(glob.glob("[0-9]*_[0-9a-f]*"), key=os.path.getmtime, reverse=True)
+ if existing_folders:
+ test_folder = existing_folders[0]
+ else:
+ raise Exception("No output folder found from primary downloads")
+
+ baseline_xml_dir = os.path.join(test_folder, "baseline_xml")
+ os.makedirs(baseline_xml_dir, exist_ok=True)
+
+ if not args.exclude_default:
+ try:
+ baseline_default_wf = download_workflow_run(
+ created=args.created, max_pages=args.max_pages,
+ workflow=ROCmWorkflowNames["default"], sha=baseline_sha,
+ ignore_status=args.ignore_status, status=status,
+ error_msg=f"Baseline default workflow not found for {baseline_sha}",
+ )
+ print(f"Baseline default workflow '{ROCmWorkflowNames['default']}' id: {baseline_default_wf['id']}")
+ default_shards = rocm_shards["default"]
+
+ if not args.artifacts_only:
+ baseline_default_logs = [
+ [f"{baseline_prefix}rocm{i}.txt", f"{rocm_job_prefix['default']} / test (default, {i}, {default_shards}"]
+ for i in range(1, default_shards + 1)
+ ]
+ download_logs(baseline_default_wf, baseline_default_logs, test_folder)
+
+ baseline_default_prefixes = [
+ f"test-reports-test-default-{i}-{default_shards}"
+ for i in range(1, default_shards + 1)
+ ]
+ download_artifacts(
+ baseline_default_wf,
+ baseline_default_prefixes,
+ test_folder=baseline_xml_dir,
+ allowed_substrings=rocm_artifact_substrings,
+ )
+ os.chdir("..")
+ except Exception as e:
+ print(f"WARNING: Could not download baseline default artifacts: {e}")
+
+ if not args.exclude_distributed and "distributed" in ROCmWorkflowNames:
+ try:
+ baseline_dist_wf = download_workflow_run(
+ created=args.created, max_pages=args.max_pages,
+ workflow=ROCmWorkflowNames["distributed"], sha=baseline_sha,
+ ignore_status=args.ignore_status, status=status,
+ error_msg=f"Baseline distributed workflow not found for {baseline_sha}",
+ )
+ print(f"Baseline distributed workflow '{ROCmWorkflowNames['distributed']}' id: {baseline_dist_wf['id']}")
+ dist_shards = rocm_shards["distributed"]
+
+ if not args.artifacts_only:
+ baseline_dist_logs = [
+ [f"{baseline_prefix}rocm_dist{i}.txt", f"{rocm_job_prefix['distributed']} / test (distributed, {i}, {dist_shards}"]
+ for i in range(1, dist_shards + 1)
+ ]
+ download_logs(baseline_dist_wf, baseline_dist_logs, test_folder)
+
+ baseline_dist_prefixes = [
+ f"test-reports-test-distributed-{i}-{dist_shards}"
+ for i in range(1, dist_shards + 1)
+ ]
+ download_artifacts(
+ baseline_dist_wf,
+ baseline_dist_prefixes,
+ test_folder=baseline_xml_dir,
+ allowed_substrings=rocm_artifact_substrings,
+ )
+ os.chdir("..")
+ except Exception as e:
+ print(f"WARNING: Could not download baseline distributed artifacts: {e}")
+
+ if not args.exclude_inductor and "inductor" in ROCmWorkflowNames:
+ try:
+ baseline_inductor_wf = download_workflow_run(
+ created=args.created, max_pages=args.max_pages,
+ workflow=ROCmWorkflowNames["inductor"], sha=baseline_sha,
+ ignore_status=args.ignore_status, status=status,
+ error_msg=f"Baseline inductor workflow not found for {baseline_sha}",
+ )
+ print(f"Baseline inductor workflow '{ROCmWorkflowNames['inductor']}' id: {baseline_inductor_wf['id']}")
+ inductor_shards = rocm_shards["inductor"]
+
+ if not args.artifacts_only:
+ baseline_inductor_logs = [
+ [f"{baseline_prefix}rocm_inductor{i}.txt", f"{rocm_job_prefix['inductor']} / test (inductor, {i}, {inductor_shards}"]
+ for i in range(1, inductor_shards + 1)
+ ]
+ download_logs(baseline_inductor_wf, baseline_inductor_logs, test_folder)
+
+ baseline_inductor_prefixes = [
+ f"test-reports-test-inductor-{i}-{inductor_shards}"
+ for i in range(1, inductor_shards + 1)
+ ]
+ download_artifacts(
+ baseline_inductor_wf,
+ baseline_inductor_prefixes,
+ test_folder=baseline_xml_dir,
+ allowed_substrings=rocm_artifact_substrings,
+ )
+ os.chdir("..")
+ except Exception as e:
+ print(f"WARNING: Could not download baseline inductor artifacts: {e}")
+
+ print(f"Baseline artifacts saved to: {baseline_xml_dir}")
+
+ # Download inductor-periodic benchmark artifacts (separate from parity CSV)
+ if args.include_inductor_periodic:
+ print("==============================================")
+ print(f"Finding inductor-periodic tests in workflow 'inductor-periodic' by sha: {sha}")
+ print("==============================================")
+ error_msg = "Error: inductor-periodic workflow not found for this SHA. It may not have run on this commit."
+ try:
+ inductor_periodic_wf = download_workflow_run(
+ created=args.created, max_pages=args.max_pages,
+ workflow="inductor-periodic", sha=sha,
+ ignore_status=args.ignore_status, status=status,
+ error_msg=error_msg,
+ )
+ except (IndexError, Exception) as e:
+ print(f"WARNING: {e}")
+ inductor_periodic_wf = None
+
+ if inductor_periodic_wf:
+ print(f"Using workflow 'inductor-periodic' with id:{inductor_periodic_wf['id']} for inductor-periodic")
+
+ folder_list = get_or_create_test_folder(inductor_periodic_wf)
+ test_folder = folder_list[0]
+
+ rocm_periodic_dir = os.path.join(test_folder, "inductor_periodic_rocm_dir")
+ cuda_periodic_dir = os.path.join(test_folder, "inductor_periodic_cuda_dir")
+ os.makedirs(rocm_periodic_dir, exist_ok=True)
+ os.makedirs(cuda_periodic_dir, exist_ok=True)
+
+ if not args.no_rocm:
+ print("Downloading inductor-periodic ROCm artifacts...")
+ download_artifacts(
+ inductor_periodic_wf,
+ ["test-reports-"],
+ test_folder=rocm_periodic_dir,
+ allowed_substrings=["rocm.gpu"],
+ )
+ os.chdir("..")
+
+ if not args.no_cuda:
+ print("Downloading inductor-periodic CUDA artifacts...")
+ cuda_periodic_job_ids = get_job_ids_by_prefix(inductor_periodic_wf, "linux.g5")
+ cuda_periodic_substrings = (
+ [f"_{jid}" for jid in cuda_periodic_job_ids]
+ if cuda_periodic_job_ids
+ else ["nvidia.gpu"]
+ )
+ download_artifacts(
+ inductor_periodic_wf,
+ ["test-reports-"],
+ test_folder=cuda_periodic_dir,
+ allowed_substrings=cuda_periodic_substrings,
+ )
+ os.chdir("..")
+
+ print(f"Inductor-periodic artifacts saved to:")
+ print(f" ROCm: {rocm_periodic_dir}")
+ print(f" CUDA: {cuda_periodic_dir}")
+ else:
+ print("Skipping inductor-periodic download (workflow run not found)")
+
+ return
+
+if __name__ == "__main__":
+ main()
+ if error_msgs:
+ for msg in error_msgs:
+ print(msg)
+ exit(1)
diff --git a/.automation_scripts/pytorch-unit-test-scripts/generate_summary.py b/.automation_scripts/pytorch-unit-test-scripts/generate_summary.py
new file mode 100644
index 000000000000..67398bbf4cf1
--- /dev/null
+++ b/.automation_scripts/pytorch-unit-test-scripts/generate_summary.py
@@ -0,0 +1,995 @@
+#!/usr/bin/env python3
+
+import argparse
+import ast
+import csv
+import os
+import re
+import sys
+
+
+TEST_CONFIGS = ['default', 'distributed', 'inductor']
+TEST_CONFIG_DISPLAY = {
+ 'default': 'TEST DEFAULT',
+ 'distributed': 'TEST DISTRIBUTED',
+ 'inductor': 'TEST INDUCTOR',
+}
+MAX_DIAGNOSTIC_FIELD_CHARS = 20_000
+DIAGNOSTIC_FIELDS = {
+ 'comments',
+ 'message_cuda',
+ 'message_rocm',
+ 'message_set1',
+ 'message_set2',
+ 'reason',
+ 'skip_reason',
+}
+
+
+def _configure_csv_field_limit():
+ limit = sys.maxsize
+ while True:
+ try:
+ csv.field_size_limit(limit)
+ return
+ except OverflowError:
+ limit //= 10
+
+
+_configure_csv_field_limit()
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description='Generate a parity summary from per-architecture test status CSVs'
+ )
+ parser.add_argument(
+ '--csv', nargs='+', required=True,
+ help='CSV file(s) to summarize (one per architecture, same order as --arch)'
+ )
+ parser.add_argument(
+ '--arch', nargs='+', required=True,
+ help='Architecture labels matching --csv order (e.g. mi200 mi300 mi355)'
+ )
+ parser.add_argument('--sha', type=str, default='', help='Commit SHA')
+ parser.add_argument('--pr_id', type=str, default='', help='Pull request ID')
+ parser.add_argument(
+ '--set1_name', type=str, default='set1',
+ help='Name used for set1 in CSV column headers (default: set1)'
+ )
+ parser.add_argument(
+ '--set2_name', type=str, default='set2',
+ help='Name used for set2 in CSV column headers (default: set2)'
+ )
+ parser.add_argument(
+ '--output', type=str, default='parity_summary',
+ help='Output path prefix (produces .csv and .md)'
+ )
+ parser.add_argument(
+ '--log-failures', nargs='*', default=[],
+ help='CSV file(s) from detect_log_failures.py to include in summary'
+ )
+ return parser.parse_args()
+
+
+def load_csv(filepath):
+ with open(filepath, newline='') as f:
+ return [_truncate_diagnostic_fields(row) for row in csv.DictReader(f)]
+
+
+def _truncate_diagnostic_fields(row):
+ for field in DIAGNOSTIC_FIELDS:
+ value = row.get(field, '')
+ if len(value) > MAX_DIAGNOSTIC_FIELD_CHARS:
+ omitted = len(value) - MAX_DIAGNOSTIC_FIELD_CHARS
+ row[field] = (
+ value[:MAX_DIAGNOSTIC_FIELD_CHARS]
+ + f'\n...[truncated {omitted:,} chars by generate_summary.py]'
+ )
+ return row
+
+
+def detect_columns(headers, set1_name, set2_name):
+ s1_status = f'status_{set1_name}'
+ s2_status = f'status_{set2_name}'
+ s1_time = f'running_time_{set1_name}'
+ s2_time = f'running_time_{set2_name}'
+ if s1_status not in headers:
+ s1_status = 'status_set1'
+ s2_status = 'status_set2'
+ s1_time = 'running_time_set1'
+ s2_time = 'running_time_set2'
+ return s1_status, s2_status, s1_time, s2_time
+
+
+def test_config_stats_keys(s1_name, s2_name, has_set2=True):
+ s1 = s1_name.upper()
+ s2 = s2_name.upper()
+ if not has_set2:
+ return [
+ f'PASSED ({s1_name})',
+ f'SKIPPED ({s1_name})',
+ f'FAILED ({s1_name})',
+ f'MISSED ({s1_name})',
+ f'TOTAL {s1}',
+ ]
+ return [
+ f'SKIPPED (on {s1_name}, but not on {s2_name})',
+ f'SKIPPED (on {s1_name})',
+ f'SKIPPED (on {s2_name})',
+ f'MISSED (MISSED on {s1_name}, NOT SKIPPED on {s2_name})',
+ f'{s1}ONLY (PASSED on {s1}, NOT PASSED on {s2})',
+ s2,
+ s1,
+ 'SKIPPED + MISSED',
+ f'{s2} - (SKIPPED + MISSED)',
+ f'DISAGREE [(SKIPPED+MISSED)/{s2}] %',
+ ]
+
+
+def compute_test_config_stats(rows, s1_col, s2_col, s1_name, s2_name, has_set2=True):
+ s1 = s1_name.upper()
+ s2 = s2_name.upper()
+
+ if not has_set2:
+ vals = {}
+ keys = test_config_stats_keys(s1_name, s2_name, has_set2=False)
+ vals[keys[0]] = sum(1 for r in rows if r[s1_col] == 'PASSED')
+ vals[keys[1]] = sum(1 for r in rows if r[s1_col] == 'SKIPPED')
+ vals[keys[2]] = sum(1 for r in rows if r[s1_col] == 'FAILED')
+ vals[keys[3]] = sum(1 for r in rows if r[s1_col] == 'MISSED')
+ vals[keys[4]] = sum(1 for r in rows if r[s1_col].strip())
+ return vals
+
+ s1_skip_not_s2 = sum(
+ 1 for r in rows
+ if r[s1_col] == 'SKIPPED' and r[s2_col] != 'SKIPPED'
+ )
+ s1_skip = sum(1 for r in rows if r[s1_col] == 'SKIPPED')
+ s2_skip = sum(1 for r in rows if r[s2_col] == 'SKIPPED')
+ s1_miss_not_s2_skip = sum(
+ 1 for r in rows
+ if r[s1_col] == 'MISSED' and r[s2_col] != 'SKIPPED'
+ )
+ only_s1 = sum(
+ 1 for r in rows
+ if r[s1_col] == 'PASSED' and r[s2_col] != 'PASSED'
+ )
+ total_s2 = sum(1 for r in rows if r[s2_col].strip() and r[s2_col].strip() != 'MISSED')
+ total_s1 = sum(1 for r in rows if r[s1_col].strip() and r[s1_col].strip() != 'MISSED')
+
+ skip_miss = s1_skip_not_s2 + s1_miss_not_s2_skip
+ s2_minus = total_s2 - skip_miss
+ pct = (skip_miss / total_s2 * 100) if total_s2 else 0
+
+ vals = {}
+ keys = test_config_stats_keys(s1_name, s2_name)
+ vals[keys[0]] = s1_skip_not_s2
+ vals[keys[1]] = s1_skip
+ vals[keys[2]] = s2_skip
+ vals[keys[3]] = s1_miss_not_s2_skip
+ vals[keys[4]] = only_s1
+ vals[keys[5]] = total_s2
+ vals[keys[6]] = total_s1
+ vals[keys[7]] = skip_miss
+ vals[keys[8]] = s2_minus
+ vals[keys[9]] = f'{pct:.2f}%'
+ return vals
+
+
+def overall_stats_keys(s1_name, s2_name, has_set2=True):
+ s1 = s1_name.upper()
+ s2 = s2_name.upper()
+ if not has_set2:
+ keys = []
+ for status in ['PASSED', 'SKIPPED', 'FAILED', 'XFAILED']:
+ keys.append(f'{status}({s1_name})')
+ keys += [
+ f'TOTAL {s1}',
+ f'TOTAL {s1} RUNNING TIME',
+ ]
+ return keys
+ keys = [
+ 'Overall DISAGREE%',
+ 'Overall AGREE%',
+ ]
+ for status in ['PASSED', 'SKIPPED', 'FAILED', 'XFAILED']:
+ keys.append(f'{status}({s1_name})')
+ keys.append(f'{status}({s2_name})')
+ keys += [
+ f'TOTAL {s2}',
+ f'TOTAL {s1}',
+ f'TOTAL {s1} RUNNING TIME',
+ f'TOTAL {s2} RUNNING TIME',
+ ]
+ return keys
+
+
+def compute_overall_stats(rows, s1_col, s2_col, s1_time_col, s2_time_col, s1_name, s2_name, has_set2=True):
+ s1 = s1_name.upper()
+ s2 = s2_name.upper()
+
+ def safe_float(v):
+ try:
+ return float(v)
+ except (ValueError, TypeError):
+ return 0.0
+
+ if not has_set2:
+ vals = {}
+ keys = overall_stats_keys(s1_name, s2_name, has_set2=False)
+ idx = 0
+ for status in ['PASSED', 'SKIPPED', 'FAILED', 'XFAILED']:
+ vals[keys[idx]] = sum(1 for r in rows if r[s1_col] == status)
+ idx += 1
+ vals[keys[idx]] = sum(1 for r in rows if r[s1_col].strip())
+ idx += 1
+ vals[keys[idx]] = f'{sum(safe_float(r[s1_time_col]) for r in rows):.2f}'
+ return vals
+
+ total_disagree = 0
+ total_s2 = 0
+ for wf in TEST_CONFIGS:
+ wf_rows = [r for r in rows if r['test_config'] == wf]
+ s1_skip_not_s2 = sum(
+ 1 for r in wf_rows
+ if r[s1_col] == 'SKIPPED' and r[s2_col] != 'SKIPPED'
+ )
+ s1_miss_not_s2_skip = sum(
+ 1 for r in wf_rows
+ if r[s1_col] == 'MISSED' and r[s2_col] != 'SKIPPED'
+ )
+ total_disagree += s1_skip_not_s2 + s1_miss_not_s2_skip
+ total_s2 += sum(1 for r in wf_rows if r[s2_col].strip() and r[s2_col].strip() != 'MISSED')
+
+ disagree_pct = (total_disagree / total_s2 * 100) if total_s2 else 0
+ agree_pct = 100 - disagree_pct
+
+ vals = {}
+ keys = overall_stats_keys(s1_name, s2_name)
+ vals[keys[0]] = f'{disagree_pct:.2f}%'
+ vals[keys[1]] = f'{agree_pct:.2f}%'
+
+ idx = 2
+ for status in ['PASSED', 'SKIPPED', 'FAILED', 'XFAILED']:
+ vals[keys[idx]] = sum(1 for r in rows if r[s1_col] == status)
+ vals[keys[idx + 1]] = sum(1 for r in rows if r[s2_col] == status)
+ idx += 2
+
+ vals[keys[idx]] = sum(1 for r in rows if r[s2_col].strip() and r[s2_col].strip() != 'MISSED')
+ idx += 1
+ vals[keys[idx]] = sum(1 for r in rows if r[s1_col].strip() and r[s1_col].strip() != 'MISSED')
+ idx += 1
+
+ vals[keys[idx]] = f'{sum(safe_float(r[s1_time_col]) for r in rows):.2f}'
+ idx += 1
+ vals[keys[idx]] = f'{sum(safe_float(r[s2_time_col]) for r in rows):.2f}'
+ return vals
+
+
+def _extract_message(raw):
+ """Turn a stored XML failure/error value into readable text.
+
+ summarize_xml_testreports.py writes the parsed XML failure node (a dict like
+ {'message': 'AssertionError: ...', 'text': ''}) into the
+ message_ CSV column via str(), so cells arrive as a dict repr. Prefer
+ the full traceback ('text'), fall back to the short 'message', and return
+ the raw string unchanged if it is not a dict repr.
+ """
+ if not raw:
+ return ''
+ raw = raw.strip()
+ if raw.startswith('{') and raw.endswith('}'):
+ try:
+ d = ast.literal_eval(raw)
+ if isinstance(d, dict):
+ return (d.get('text') or d.get('message') or '').strip()
+ except (ValueError, SyntaxError):
+ pass
+ return raw
+
+
+def _html_escape(s):
+ return (s or '').replace('&', '&').replace('<', '<').replace('>', '>')
+
+
+# The whole .md is piped into $GITHUB_STEP_SUMMARY, and GitHub restricts each
+# step summary to 1 MiB - past that the upload fails (and content can be dropped
+# silently as it nears the limit), so a single run with many large tracebacks
+# would lose the entire summary. We keep the rendered summary under this budget
+# (1 MiB minus headroom for the surrounding tables/markdown) and only clip the
+# longest failure messages when a run would otherwise exceed it; the full text
+# always stays in the 'Error Message' column of the CSV artifact.
+# https://docs.github.com/en/actions/reference/workflows-and-actions/workflow-commands#step-isolation-and-limits
+STEP_SUMMARY_BYTE_BUDGET = 950_000
+
+# Per-message hard cap. Generous on purpose: real tracebacks fit well within it,
+# so they render in full; it only bounds pathological multi-MB messages. The
+# global budget above is what protects the summary when there are many failures.
+PER_MESSAGE_CHAR_LIMIT = 50_000
+
+
+def _truncate_message(msg, limit=PER_MESSAGE_CHAR_LIMIT):
+ """Cap a failure message for the markdown summary.
+
+ Keeps the tail (the exception/assertion lives at the end of a traceback) and
+ points readers to the CSV 'Error Message' column, which keeps the full text
+ for dashboards.
+ """
+ if not msg or len(msg) <= limit:
+ return msg
+ return ('...(truncated; full message in the Error Message column of the CSV artifact)\n'
+ + msg[-limit:])
+
+
+def _message_cell(msg, char_limit=PER_MESSAGE_CHAR_LIMIT):
+ """Render a failure message as a collapsible cell for a markdown table.
+
+ Table cells can't contain raw newlines or unescaped pipes, so the message
+ is HTML-escaped, pipes are escaped, and newlines become
inside a
+
(GitHub renders these as line breaks). The whole thing is wrapped in
+ a so each row shows only a small 'view' toggle by default.
+ `char_limit` is set per-run by the budget pass below; 0 drops the body.
+ """
+ if not msg:
+ return ''
+ if char_limit <= 0:
+ return ('message omitted to keep the summary under GitHub\u2019s 1 MiB '
+ 'limit; see the Error Message column of the CSV artifact')
+ msg = _truncate_message(msg, char_limit)
+ body = (_html_escape(msg)
+ .replace('\r', '')
+ .replace('\n', '
')
+ .replace('|', '\\|'))
+ return f'view
{body}
'
+
+
+def _fit_message_cap(messages, budget):
+ """Largest uniform per-message char cap whose rendered cells fit `budget`
+ bytes. Messages shorter than the cap are untouched, so only the longest
+ (outlier) messages get clipped, and only when a run would exceed the budget.
+ """
+ if budget <= 0:
+ return 0
+
+ def total_bytes(cap):
+ return sum(len(_message_cell(m, cap).encode('utf-8')) for m in messages)
+
+ if total_bytes(PER_MESSAGE_CHAR_LIMIT) <= budget:
+ return PER_MESSAGE_CHAR_LIMIT
+ lo, hi = 0, PER_MESSAGE_CHAR_LIMIT
+ while lo < hi:
+ mid = (lo + hi + 1) // 2
+ if total_bytes(mid) <= budget:
+ lo = mid
+ else:
+ hi = mid - 1
+ return lo
+
+
+def _fill_failure_messages(lines, fail_rows):
+ """Fill the deferred 'Error Message' cells of the FAILED TESTS table.
+
+ `fail_rows` is a list of (line_index, row_prefix, raw_message). We first
+ measure every other byte in the summary (rendering these cells empty), then
+ spend whatever remains of STEP_SUMMARY_BYTE_BUDGET on the messages via a
+ single uniform cap, so typical runs show full tracebacks and only oversized
+ runs clip their longest messages.
+ """
+ if not fail_rows:
+ return
+ for idx, prefix, _ in fail_rows:
+ lines[idx] = f"{prefix} | |"
+ base_bytes = len('\n'.join('' if l is None else l for l in lines).encode('utf-8'))
+ cap = _fit_message_cap([m for _, _, m in fail_rows], STEP_SUMMARY_BYTE_BUDGET - base_bytes)
+ for idx, prefix, msg in fail_rows:
+ lines[idx] = f"{prefix} | {_message_cell(msg, cap)} |"
+
+
+def collect_failed_tests(arch_data, archs, s1_name, s2_name):
+ """Return a list of failed test rows across all architectures.
+
+ Only collects tests where s1 (ROCm) is FAILED. Each entry records shards
+ for both s1 and s2 so the reviewer can look up the failure in either CI
+ job. 'also_failing_in' is populated later once log failures are known so
+ CUDA log-only failures can be included.
+ """
+ failed = []
+ for arch in archs:
+ d = arch_data[arch]
+ s1_col, s2_col, s1_time, _ = d['cols']
+ # message_ columns mirror the status_ naming (incl. the
+ # status_set1/set2 fallback), so derive them from the resolved cols.
+ s1_msg_col = s1_col.replace('status_', 'message_', 1)
+ s2_msg_col = s2_col.replace('status_', 'message_', 1)
+ has_set2 = d.get('has_set2', True)
+ for r in d['rows']:
+ s1 = r[s1_col].strip()
+ s2 = r[s2_col].strip() if has_set2 else ''
+ if s1 == 'FAILED':
+ entry = {
+ 'arch': arch,
+ 'test_file': r.get('test_file', ''),
+ 'test_class': r.get('test_class', ''),
+ 'test_name': r.get('test_name', ''),
+ 'test_config': r.get('test_config', ''),
+ 'run_time': r.get(s1_time, ''),
+ f'shard_{s1_name}': r.get(f'shard_{s1_name}', ''),
+ f'job_url_{s1_name}': r.get(f'job_url_{s1_name}', ''),
+ f'status_{s1_name}': s1,
+ 'error_message': _extract_message(r.get(s1_msg_col, '')),
+ }
+ if has_set2:
+ entry[f'shard_{s2_name}'] = r.get(f'shard_{s2_name}', '')
+ entry[f'job_url_{s2_name}'] = r.get(f'job_url_{s2_name}', '')
+ entry[f'status_{s2_name}'] = s2
+ entry['error_message_set2'] = _extract_message(r.get(s2_msg_col, ''))
+ failed.append(entry)
+
+ return failed
+
+
+def _add_cross_arch_info(failed_tests, log_failures, s2_name):
+ """Populate 'also_failing_in' for each entry.
+
+ Matches across other ROCm architectures (from XML-based failures) and also
+ includes s2 (CUDA) if a log failure is recorded for the same test tuple.
+ """
+ from collections import defaultdict
+ by_tuple = defaultdict(set)
+ for t in failed_tests:
+ key = (t['test_file'], t['test_class'], t['test_name'])
+ by_tuple[key].add(t['arch'])
+
+ cuda_log_tuples = set()
+ for lf in log_failures or []:
+ if lf.get('platform', '') == s2_name:
+ test_class, test_name = _parse_log_failure_names(lf)
+ cuda_log_tuples.add((lf.get('test_file', ''), test_class, test_name))
+
+ for t in failed_tests:
+ key = (t['test_file'], t['test_class'], t['test_name'])
+ others = sorted(a for a in by_tuple[key] if a != t['arch'])
+ if key in cuda_log_tuples and s2_name not in others:
+ others.append(s2_name)
+ t['also_failing_in'] = ', '.join(others)
+
+
+def _add_log_failure_cross_arch(log_failures, failed_tests, s1_name, s2_name):
+ """Populate 'also_failing_in' for each log failure entry.
+
+ Cross-references: other archs that have the same test failing (either as
+ a log failure or as an XML-based failure), plus s2 (CUDA) if it appears
+ in log failures for the same test tuple.
+ """
+ from collections import defaultdict
+ by_tuple_archs = defaultdict(set)
+
+ for lf in log_failures or []:
+ if lf.get('platform', '') == s1_name:
+ test_class, test_name = _parse_log_failure_names(lf)
+ key = (lf.get('test_file', ''), test_class, test_name)
+ by_tuple_archs[key].add(lf.get('arch', ''))
+ for t in failed_tests or []:
+ key = (t['test_file'], t['test_class'], t['test_name'])
+ by_tuple_archs[key].add(t['arch'])
+
+ cuda_log_tuples = set()
+ for lf in log_failures or []:
+ if lf.get('platform', '') == s2_name:
+ test_class, test_name = _parse_log_failure_names(lf)
+ cuda_log_tuples.add((lf.get('test_file', ''), test_class, test_name))
+
+ for lf in log_failures or []:
+ test_class, test_name = _parse_log_failure_names(lf)
+ key = (lf.get('test_file', ''), test_class, test_name)
+ arch = lf.get('arch', '')
+ others = sorted(a for a in by_tuple_archs[key] if a and a != arch)
+ if key in cuda_log_tuples and s2_name not in others:
+ others.append(s2_name)
+ lf['also_failing_in'] = ', '.join(others)
+
+
+def load_log_failures(filepaths):
+ """Load log failure CSVs from detect_log_failures.py.
+
+ Extracts the architecture from the filename (e.g. log_failures_mi355.csv -> mi355).
+ """
+ entries = []
+ for fp in filepaths:
+ if not os.path.isfile(fp):
+ continue
+ basename = os.path.basename(fp)
+ arch = ''
+ if basename.startswith('log_failures_') and basename.endswith('.csv'):
+ arch = basename[len('log_failures_'):-len('.csv')]
+ with open(fp, newline='') as f:
+ for row in csv.DictReader(f):
+ row['arch'] = arch
+ entries.append(row)
+ return entries
+
+
+def load_flaky_tests_as_log_failures(filepaths):
+ """Load flaky_tests_.csv and return entries shaped like log-failure rows.
+
+ Each returned dict has the same schema as the entries produced by
+ load_log_failures, with category='FLAKY' and reason='::',
+ so they can be appended to the log_failures list and surfaced in the
+ LOG-BASED FAILURES table alongside crashes/timeouts/etc.
+ """
+ entries = []
+ for fp in filepaths or []:
+ if not fp:
+ continue
+ basename = os.path.basename(fp)
+ if not (basename.startswith('log_failures_') and basename.endswith('.csv')):
+ continue
+ arch = basename[len('log_failures_'):-len('.csv')]
+ flaky_path = os.path.join(
+ os.path.dirname(fp),
+ 'flaky_tests_' + basename[len('log_failures_'):],
+ )
+ if not os.path.isfile(flaky_path):
+ continue
+ with open(flaky_path, newline='') as f:
+ for row in csv.DictReader(f):
+ test_class = row.get('test_class', '')
+ test_name = row.get('test_name', '')
+ entries.append({
+ 'arch': arch,
+ 'log_file': row.get('log_file', ''),
+ 'platform': row.get('platform', ''),
+ 'test_config': row.get('test_config', ''),
+ 'test_file': row.get('test_file', ''),
+ 'job_shard': row.get('job_shard', ''),
+ 'test_shard': row.get('test_shard', ''),
+ 'status': 'FLAKY',
+ 'category': 'FLAKY',
+ 'reason': f'{test_class}::{test_name}' if test_class else test_name,
+ 'exit_codes': '',
+ 'run_time': row.get('run_time', ''),
+ 'job_url': row.get('job_url', ''),
+ })
+ return entries
+
+
+def load_log_shards(filepaths):
+ """Load log shard inventory CSVs written alongside log_failures files.
+
+ For each log_failures_.csv, looks for a sibling log_shards_.csv
+ and returns a lookup dict:
+ (arch, platform, test_config, job_shard, normalized_test_file) -> test_shards_str
+
+ The CSV is produced by detect_log_failures.py and records every
+ (test_file, test_shard) pair observed per job-level shard. If an XML-based
+ failure's key matches, we can back-fill the test-level shard value.
+ """
+ lookup = {}
+ for fp in filepaths:
+ if not fp:
+ continue
+ basename = os.path.basename(fp)
+ arch = ''
+ if basename.startswith('log_failures_') and basename.endswith('.csv'):
+ arch = basename[len('log_failures_'):-len('.csv')]
+ shards_path = os.path.join(
+ os.path.dirname(fp),
+ 'log_shards_' + basename[len('log_failures_'):],
+ )
+ else:
+ continue
+ if not os.path.isfile(shards_path):
+ continue
+ with open(shards_path, newline='') as f:
+ for row in csv.DictReader(f):
+ key = (arch, row.get('platform', ''), row.get('test_config', ''),
+ row.get('job_shard', ''),
+ _norm_test_file(row.get('test_file', '')))
+ lookup[key] = row.get('test_shards', '')
+ return lookup
+
+
+def _format_test_shards(shards_str):
+ """Collapse a test_shards inventory string into a compact display value.
+
+ - '' -> ''
+ - '1/1' -> '1/1'
+ - '3/14' -> '3/14'
+ - '1/14,6/14,12/14' -> '1,6,12/14' (multiple test-level shards observed)
+ - mixed totals fall back to the raw string."""
+ if not shards_str:
+ return ''
+ parts = [p for p in shards_str.split(',') if p]
+ if len(parts) == 1:
+ return parts[0]
+ totals = set()
+ nums = []
+ for p in parts:
+ if '/' not in p:
+ return shards_str
+ a, b = p.split('/', 1)
+ totals.add(b)
+ nums.append(a)
+ if len(totals) == 1:
+ return f"{','.join(nums)}/{totals.pop()}"
+ return shards_str
+
+
+def fmt_val(v):
+ if isinstance(v, int):
+ return f'{v:,}'
+ return str(v)
+
+
+def fmt_run_time(v):
+ """Format a per-test run time (seconds, from the XML 'time' attribute).
+
+ Returns '' for blank/non-numeric values so log-based rows stay empty."""
+ try:
+ return f'{float(v):.2f}'
+ except (ValueError, TypeError):
+ return ''
+
+
+def build_rows(args, archs, arch_data):
+ """Return a list of (label, val_per_arch...) tuples and section markers."""
+ out = []
+ any_has_set2 = any(d.get('has_set2', True) for d in arch_data.values())
+
+ if args.sha:
+ out.append(('__header__', f'Commit SHA: {args.sha}'))
+ # Link straight to the upstream HUD page filtered (regex) to the trunk
+ # CUDA / inductor / rocm test jobs this report is built from, so a
+ # reviewer can jump from the summary to the matching CI jobs. Parens and
+ # pipes are percent-encoded to keep the markdown link valid.
+ hud_url = (
+ f'https://hud.pytorch.org/hud/pytorch/pytorch/{args.sha}/1'
+ '?per_page=50'
+ '&name_filter=%28trunk.*cuda%7Cinductor%7Crocm%29.*test.*'
+ '%28default%7Cdistributed%7Cinductor%29%2C'
+ '&useRegexFilter=true'
+ )
+ out.append(('__header__', f'HUD: [parity jobs for this commit]({hud_url})'))
+ if args.pr_id:
+ out.append(('__header__', f'PR ID: {args.pr_id}'))
+
+ wf_keys = test_config_stats_keys(args.set1_name, args.set2_name, has_set2=any_has_set2)
+ for wf in TEST_CONFIGS:
+ out.append(('__section__', TEST_CONFIG_DISPLAY[wf]))
+ arch_stats = {}
+ for arch in archs:
+ d = arch_data[arch]
+ s1_col, s2_col, _, _ = d['cols']
+ has_set2 = d.get('has_set2', True)
+ wf_rows = [r for r in d['rows'] if r['test_config'] == wf]
+ arch_stats[arch] = compute_test_config_stats(
+ wf_rows, s1_col, s2_col, args.set1_name, args.set2_name,
+ has_set2=has_set2,
+ )
+ for key in wf_keys:
+ out.append((key, [arch_stats[a].get(key, 0) for a in archs]))
+
+ out.append(('__section__', 'OVERALL'))
+ ov_keys = overall_stats_keys(args.set1_name, args.set2_name, has_set2=any_has_set2)
+ arch_overall = {}
+ for arch in archs:
+ d = arch_data[arch]
+ s1_col, s2_col, s1_time, s2_time = d['cols']
+ has_set2 = d.get('has_set2', True)
+ arch_overall[arch] = compute_overall_stats(
+ d['rows'], s1_col, s2_col, s1_time, s2_time,
+ args.set1_name, args.set2_name, has_set2=has_set2,
+ )
+ for key in ov_keys:
+ out.append((key, [arch_overall[a].get(key, 0) for a in archs]))
+ return out
+
+
+def _norm_test_file(path):
+ """Normalize a test_file string so XML-sourced ('a.b.c') and log-sourced
+ ('a/b/c') forms compare equal. Also strips a trailing .py if present."""
+ if not path:
+ return ''
+ s = path.replace('/', '.')
+ if s.endswith('.py'):
+ s = s[:-3]
+ return s
+
+
+def _parse_log_failure_names(lf):
+ """Extract test_class and test_name from a log failure's reason field.
+
+ Handles formats like 'TestClass::test_method' and
+ 'TestClass::test_method | extra reason text'.
+ """
+ reason = lf.get('reason', '')
+ if '::' not in reason:
+ return '', ''
+ test_part = reason.split(' | ', 1)[0] if ' | ' in reason else reason
+ parts = test_part.split('::', 1)
+ return parts[0], parts[1]
+
+
+def write_csv(rows, archs, output_path, failed_tests=None, s1_name='set1', s2_name='set2', has_set2=True, log_failures=None, shard_lookup=None):
+ csv_rows = []
+ csv_rows.append([''] + list(archs))
+ for label, vals in rows:
+ if label == '__header__':
+ csv_rows.append([vals])
+ elif label == '__section__':
+ csv_rows.append([])
+ csv_rows.append([vals])
+ else:
+ csv_rows.append([label] + list(vals))
+ csv_rows.append([])
+
+ s1_failed = [t for t in (failed_tests or []) if t.get(f'status_{s1_name}') == 'FAILED']
+
+ shard_lookup = shard_lookup or {}
+
+ def _xml_test_shard(t, platform):
+ key = (t.get('arch', ''), platform, t.get('test_config', ''),
+ t.get(f'shard_{platform}', ''),
+ _norm_test_file(t.get('test_file', '')))
+ return _format_test_shards(shard_lookup.get(key, ''))
+
+ if s1_failed:
+ csv_rows.append(['FAILED TESTS'])
+ header = ['Arch', 'Test Config', 'Test File', 'Test Class',
+ 'Test Name', 'Run Time (s)',
+ f'Job-Level Shard ({s1_name})',
+ f'Test-Level Shard ({s1_name})']
+ if has_set2:
+ header.append(f'Job-Level Shard ({s2_name})')
+ header.append(f'Test-Level Shard ({s2_name})')
+ header.append(f'Status ({s1_name})')
+ if has_set2:
+ header.append(f'Status ({s2_name})')
+ header.append('Also Failing In')
+ header.append('Error Message')
+ csv_rows.append(header)
+ for t in s1_failed:
+ row = [t['arch'], t['test_config'], t['test_file'],
+ t['test_class'], t['test_name'],
+ fmt_run_time(t.get('run_time', '')),
+ t.get(f'shard_{s1_name}', ''),
+ _xml_test_shard(t, s1_name)]
+ if has_set2:
+ row.append(t.get(f'shard_{s2_name}', ''))
+ row.append(_xml_test_shard(t, s2_name))
+ row.append(t[f'status_{s1_name}'])
+ if has_set2:
+ row.append(t.get(f'status_{s2_name}', ''))
+ row.append(t.get('also_failing_in', ''))
+ row.append(t.get('error_message', ''))
+ csv_rows.append(row)
+ csv_rows.append([])
+
+ if log_failures:
+ xml_failed_keys = {
+ (t['arch'], _norm_test_file(t['test_file']), t['test_class'], t['test_name'])
+ for t in (failed_tests or [])
+ }
+ rocm_log_failures = []
+ for lf in log_failures:
+ if lf.get('platform', '') != s1_name:
+ continue
+ test_class, test_name = _parse_log_failure_names(lf)
+ key = (lf.get('arch', ''), _norm_test_file(lf.get('test_file', '')),
+ test_class, test_name)
+ # Skip entries already present in the XML-based FAILED TESTS table
+ # to avoid double-counting the same failure, except for FLAKY
+ # entries which represent an independent signal (a rerun passed).
+ if key in xml_failed_keys and lf.get('category', '') != 'FLAKY':
+ continue
+ rocm_log_failures.append(lf)
+ if rocm_log_failures:
+ csv_rows.append(['LOG-BASED FAILURES (not in XML)'])
+ csv_rows.append(['Arch', 'Platform', 'Test Config', 'Test File', 'Test Class',
+ 'Test Name', 'Run Time (s)', 'Job-Level Shard', 'Test-Level Shard',
+ 'Category', 'Also Failing In', 'Log File'])
+ for lf in rocm_log_failures:
+ test_class, test_name = _parse_log_failure_names(lf)
+ csv_rows.append([
+ lf.get('arch', ''), lf.get('platform', ''), lf.get('test_config', ''),
+ lf.get('test_file', ''), test_class, test_name,
+ fmt_run_time(lf.get('run_time', '')),
+ lf.get('job_shard', ''),
+ lf.get('test_shard', lf.get('shard', '')),
+ lf.get('category', ''),
+ lf.get('also_failing_in', ''),
+ lf.get('log_file', ''),
+ ])
+ csv_rows.append([])
+
+ with open(output_path, 'w', newline='') as f:
+ csv.writer(f).writerows(csv_rows)
+ print(f'CSV written to {output_path}')
+
+
+def write_markdown(rows, archs, output_path, failed_tests=None, s1_name='set1', s2_name='set2', has_set2=True, log_failures=None, shard_lookup=None):
+ lines = []
+ current_section = []
+
+ def flush_table():
+ if not current_section:
+ return
+ header = '| Metric | ' + ' | '.join(archs) + ' |'
+ sep = '| :--- | ' + ' | '.join(['---:'] * len(archs)) + ' |'
+ lines.append(header)
+ lines.append(sep)
+ for label, vals in current_section:
+ formatted = [fmt_val(v) for v in vals]
+ lines.append(f'| {label} | ' + ' | '.join(formatted) + ' |')
+ lines.append('')
+ current_section.clear()
+
+ for label, vals in rows:
+ if label == '__header__':
+ flush_table()
+ lines.append(f'**{vals}**')
+ lines.append('')
+ elif label == '__section__':
+ flush_table()
+ lines.append(f'### {vals}')
+ lines.append('')
+ else:
+ current_section.append((label, vals))
+
+ flush_table()
+
+ s1_failed = [t for t in (failed_tests or []) if t.get(f'status_{s1_name}') == 'FAILED']
+
+ shard_lookup = shard_lookup or {}
+
+ def _xml_test_shard(t, platform):
+ key = (t.get('arch', ''), platform, t.get('test_config', ''),
+ t.get(f'shard_{platform}', ''),
+ _norm_test_file(t.get('test_file', '')))
+ return _format_test_shards(shard_lookup.get(key, ''))
+
+ def _job_id_link(url):
+ if not url:
+ return ''
+ # Use the job id (digits after "/job/" in the URL) as the visible
+ # link label so the cell reads e.g. [76905282313](...).
+ m = re.search(r'/job/(\d+)', url)
+ if not m:
+ return ''
+ return f'[{m.group(1)}]({url})'
+
+ cols = ['Arch', 'Test Config', 'Test File', 'Test Class', 'Test Name',
+ 'Run Time (s)',
+ f'Job-Level Shard ({s1_name})',
+ f'Test-Level Shard ({s1_name})']
+ if has_set2:
+ cols.append(f'Job-Level Shard ({s2_name})')
+ cols.append(f'Test-Level Shard ({s2_name})')
+ cols.append(f'Status ({s1_name})')
+ if has_set2:
+ cols.append(f'Status ({s2_name})')
+ cols.append('Also Failing In')
+ cols.append(f'Job ID ({s1_name})')
+ if has_set2:
+ cols.append(f'Job ID ({s2_name})')
+ cols.append('Error Message')
+
+ # Filled in by _fill_failure_messages once the rest of the summary is sized,
+ # so the message column can use whatever byte budget is left.
+ fail_rows = []
+ if s1_failed:
+ lines.append(f'### FAILED TESTS ({len(s1_failed)})')
+ lines.append('')
+ lines.append('| ' + ' | '.join(cols) + ' |')
+ lines.append('| ' + ' | '.join(['---'] * len(cols)) + ' |')
+ for t in s1_failed:
+ line = (f"| {t['arch']} | {t['test_config']} | {t['test_file']} "
+ f"| {t['test_class']} | {t['test_name']} "
+ f"| {fmt_run_time(t.get('run_time', ''))} "
+ f"| {t.get(f'shard_{s1_name}', '')} "
+ f"| {_xml_test_shard(t, s1_name)}")
+ if has_set2:
+ line += f" | {t.get(f'shard_{s2_name}', '')}"
+ line += f" | {_xml_test_shard(t, s2_name)}"
+ line += f" | {t[f'status_{s1_name}']}"
+ if has_set2:
+ line += f" | {t.get(f'status_{s2_name}', '')}"
+ line += f" | {t.get('also_failing_in', '')}"
+ line += f" | {_job_id_link(t.get(f'job_url_{s1_name}', ''))}"
+ if has_set2:
+ line += f" | {_job_id_link(t.get(f'job_url_{s2_name}', ''))}"
+ lines.append(None) # placeholder; message cell filled in below
+ fail_rows.append((len(lines) - 1, line, t.get('error_message', '')))
+ lines.append('')
+ else:
+ lines.append('### FAILED TESTS')
+ lines.append('')
+ lines.append('No failed tests found.')
+ lines.append('')
+
+ if log_failures:
+ xml_failed_keys = {
+ (t['arch'], _norm_test_file(t['test_file']), t['test_class'], t['test_name'])
+ for t in (failed_tests or [])
+ }
+ rocm_log_failures = []
+ for lf in log_failures:
+ if lf.get('platform', '') != s1_name:
+ continue
+ test_class, test_name = _parse_log_failure_names(lf)
+ key = (lf.get('arch', ''), _norm_test_file(lf.get('test_file', '')),
+ test_class, test_name)
+ if key in xml_failed_keys and lf.get('category', '') != 'FLAKY':
+ continue
+ rocm_log_failures.append(lf)
+ if rocm_log_failures:
+ lines.append(f'### LOG-BASED FAILURES (not in XML) ({len(rocm_log_failures)})')
+ lines.append('')
+ lines.append('These test failures were detected from CI log files but have no XML report')
+ lines.append('(typically due to timeouts, crashes, or process kills).')
+ lines.append('')
+ lines.append('| Arch | Platform | Test Config | Test File | Test Class | Test Name | Run Time (s) | Job-Level Shard | Test-Level Shard | Category | Also Failing In | Job ID |')
+ lines.append('| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |')
+ for lf in rocm_log_failures:
+ test_class, test_name = _parse_log_failure_names(lf)
+ lines.append(
+ f"| {lf.get('arch', '')} | {lf.get('platform', '')} | {lf.get('test_config', '')} "
+ f"| {lf.get('test_file', '')} | {test_class} "
+ f"| {test_name} "
+ f"| {fmt_run_time(lf.get('run_time', ''))} "
+ f"| {lf.get('job_shard', '')} "
+ f"| {lf.get('test_shard', lf.get('shard', ''))} "
+ f"| {lf.get('category', '')} "
+ f"| {lf.get('also_failing_in', '')} "
+ f"| {_job_id_link(lf.get('job_url', ''))} |"
+ )
+ lines.append('')
+
+ _fill_failure_messages(lines, fail_rows)
+
+ md = '\n'.join(lines)
+ with open(output_path, 'w') as f:
+ f.write(md)
+ print(f'Markdown written to {output_path}')
+ return md
+
+
+def main():
+ args = parse_args()
+
+ if len(args.csv) != len(args.arch):
+ print('Error: --csv and --arch must have the same number of values')
+ sys.exit(1)
+
+ archs = args.arch
+ arch_data = {}
+ for csv_path, arch in zip(args.csv, archs):
+ rows = load_csv(csv_path)
+ headers = set(rows[0].keys()) if rows else set()
+ cols = detect_columns(headers, args.set1_name, args.set2_name)
+ s2_col = cols[1]
+ has_set2 = any(r.get(s2_col, '').strip() for r in rows)
+ arch_data[arch] = {'rows': rows, 'cols': cols, 'has_set2': has_set2}
+
+ data_rows = build_rows(args, archs, arch_data)
+ failed = collect_failed_tests(arch_data, archs, args.set1_name, args.set2_name)
+ any_has_set2 = any(d.get('has_set2', True) for d in arch_data.values())
+ log_failures = load_log_failures(args.log_failures) if args.log_failures else []
+ if args.log_failures:
+ log_failures.extend(load_flaky_tests_as_log_failures(args.log_failures))
+ shard_lookup = load_log_shards(args.log_failures) if args.log_failures else {}
+
+ _add_cross_arch_info(failed, log_failures, args.set2_name)
+ _add_log_failure_cross_arch(log_failures, failed, args.set1_name, args.set2_name)
+
+ output_base = args.output
+ if output_base.endswith('.csv') or output_base.endswith('.md'):
+ output_base = output_base.rsplit('.', 1)[0]
+
+ write_csv(data_rows, archs, f'{output_base}.csv', failed, args.set1_name, args.set2_name, has_set2=any_has_set2, log_failures=log_failures, shard_lookup=shard_lookup)
+ write_markdown(data_rows, archs, f'{output_base}.md', failed, args.set1_name, args.set2_name, has_set2=any_has_set2, log_failures=log_failures, shard_lookup=shard_lookup)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/.automation_scripts/pytorch-unit-test-scripts/parity_job_config.json b/.automation_scripts/pytorch-unit-test-scripts/parity_job_config.json
new file mode 100644
index 000000000000..976fb2e8ae27
--- /dev/null
+++ b/.automation_scripts/pytorch-unit-test-scripts/parity_job_config.json
@@ -0,0 +1,75 @@
+{
+ "_comment": "Single source of truth for pytorch/pytorch job-name matching used by the parity tooling. Consumed by download_testlogs (Python, S3 artifact + log matching) and parity-auto.yml (bash/jq, upstream check-run gating before dispatching parity.yml). See _subfields for the per-arch schema.",
+ "_subfields": {
+ "default/distributed/inductor": "Ordered list of {workflow, job_prefix} sources for that test config. The first entry is the primary upstream workflow + check-run/job-name prefix; any later entries are fallbacks tried (in order) when the primary workflow run is missing for a SHA.",
+ "shard_counts": "Number of shards per test config.",
+ "test_kinds": "CUDA only: upstream test-job kinds to match (e.g. test-osdc, test).",
+ "checkrun_regex": "PCRE matching this arch's ROCm test check-run names (parity-auto gating). CUDA uses it for the trunk CUDA test jobs."
+ },
+ "_notes": {
+ "artifact_substrings": "The S3 artifact downloader always filters on the ROCm-GPU runner tag (rocm.gpu), since every ROCm runner label follows that convention; no per-arch config is needed.",
+ "workflow_regex": "parity-auto derives the per-arch upstream-workflow-path regex from the union of the workflow values across this arch's configs, so it is not stored here."
+ },
+ "cuda": {
+ "default": [{ "workflow": "trunk", "job_prefix": "linux-jammy-cuda13.0-py3.10-gcc11" }],
+ "distributed": [{ "workflow": "trunk", "job_prefix": "linux-jammy-cuda13.0-py3.10-gcc11" }],
+ "inductor": [{ "workflow": "inductor", "job_prefix": "unit-test / inductor-test" }],
+ "test_kinds": ["test-osdc", "test"],
+ "shard_counts": { "default": 5, "distributed": 3, "inductor": 2 },
+ "checkrun_regex": "(linux-jammy-cuda13[.]0-py3[.]10-gcc11 / (test-osdc|test) [(](default|distributed),|unit-test / inductor-test / (test-osdc|test) [(]inductor,)"
+ },
+ "rocm": {
+ "mi350": {
+ "default": [
+ { "workflow": "trunk", "job_prefix": "linux-jammy-rocm-py3.10-mi350" },
+ { "workflow": "rocm-mi350", "job_prefix": "linux-noble-rocm-py3.12-mi350" }
+ ],
+ "distributed": [
+ { "workflow": "periodic-rocm-mi350", "job_prefix": "linux-noble-rocm-py3.12-mi350" },
+ { "workflow": "trunk", "job_prefix": "linux-jammy-rocm-py3.10-mi350" }
+ ],
+ "inductor": [
+ { "workflow": "trunk", "job_prefix": "linux-jammy-rocm-py3.10-mi350" }
+ ],
+ "shard_counts": { "default": 8, "distributed": 3, "inductor": 2 },
+ "checkrun_regex": "rocm.*mi350.*/ test [(](default|distributed|inductor),"
+ },
+ "mi300": {
+ "default": [{ "workflow": "rocm-mi300", "job_prefix": "linux-noble-rocm-py3.12-mi300" }],
+ "distributed": [{ "workflow": "periodic-rocm-mi300", "job_prefix": "linux-noble-rocm-py3.12-mi300" }],
+ "inductor": [{ "workflow": "inductor-rocm-mi300", "job_prefix": "linux-noble-rocm-py3.12-mi300" }],
+ "shard_counts": { "default": 6, "distributed": 3, "inductor": 2 },
+ "checkrun_regex": "rocm.*mi300.*/ test [(](default|distributed|inductor),"
+ },
+ "mi200": {
+ "default": [
+ { "workflow": "rocm-mi200", "job_prefix": "linux-jammy-rocm-py3.10-mi200" },
+ { "workflow": "trunk-rocm-sandbox", "job_prefix": "linux-jammy-rocm-py3.10" }
+ ],
+ "distributed": [
+ { "workflow": "periodic-rocm-mi200", "job_prefix": "linux-jammy-rocm-py3.10-mi200" },
+ { "workflow": "trunk-rocm-sandbox", "job_prefix": "linux-jammy-rocm-py3.10" }
+ ],
+ "inductor": [
+ { "workflow": "inductor-rocm-mi200", "job_prefix": "linux-jammy-rocm-py3.10-mi200" },
+ { "workflow": "trunk-rocm-sandbox", "job_prefix": "linux-jammy-rocm-py3.10" }
+ ],
+ "shard_counts": { "default": 6, "distributed": 3, "inductor": 2 },
+ "checkrun_regex": "(rocm.*(mi200|mi210).*/ test [(](default|distributed|inductor),|linux-jammy-rocm-py3[.]10 / test [(](default|distributed|inductor),)"
+ },
+ "navi31": {
+ "default": [{ "workflow": "rocm-navi31", "job_prefix": "linux-jammy-rocm-py3.10-navi31" }],
+ "distributed": [{ "workflow": "periodic-rocm-navi31", "job_prefix": "linux-jammy-rocm-py3.10-navi31" }],
+ "inductor": [{ "workflow": "inductor-rocm-navi31", "job_prefix": "linux-jammy-rocm-py3.10-navi31" }],
+ "shard_counts": { "default": 2, "distributed": 3, "inductor": 2 },
+ "checkrun_regex": "rocm.*navi31.*/ test [(]default,"
+ },
+ "nightly": {
+ "default": [{ "workflow": "rocm-nightly", "job_prefix": "linux-noble-rocm-nightly-py3.12-gfx942" }],
+ "distributed": [{ "workflow": "rocm-nightly", "job_prefix": "linux-noble-rocm-nightly-py3.12-gfx942" }],
+ "inductor": [{ "workflow": "rocm-nightly", "job_prefix": "linux-noble-rocm-nightly-py3.12-gfx942" }],
+ "shard_counts": { "default": 6, "distributed": 3, "inductor": 2 },
+ "checkrun_regex": "rocm-nightly.*/ test [(](default|distributed|inductor),"
+ }
+ }
+}
diff --git a/.automation_scripts/pytorch-unit-test-scripts/requirements.txt b/.automation_scripts/pytorch-unit-test-scripts/requirements.txt
new file mode 100644
index 000000000000..9ee33b404d9c
--- /dev/null
+++ b/.automation_scripts/pytorch-unit-test-scripts/requirements.txt
@@ -0,0 +1,4 @@
+pandas
+rockset
+boto3
+requests
diff --git a/.automation_scripts/pytorch-unit-test-scripts/summarize_xml_testreports.py b/.automation_scripts/pytorch-unit-test-scripts/summarize_xml_testreports.py
new file mode 100755
index 000000000000..4349a90f21f8
--- /dev/null
+++ b/.automation_scripts/pytorch-unit-test-scripts/summarize_xml_testreports.py
@@ -0,0 +1,760 @@
+#!/usr/bin/env python3
+
+import argparse
+import csv
+import os
+import re
+import pandas as pd
+from enum import Enum
+from itertools import chain
+from pathlib import Path
+from upload_test_stats import (
+ parse_xml_report,
+ get_pytest_parallel_times,
+ summarize_test_cases,
+)
+
+# unit test status list
+UT_STATUS_LIST = [
+ "PASSED",
+ "MISSED",
+ "SKIPPED",
+ "FAILED",
+ "XFAILED",
+ "ERROR"
+]
+
+# excluded test suites for comparison
+EXCLUDED_TEST_SUITES = [
+ "_nvfuser.test_dynamo",
+ "_nvfuser.test_python_frontend",
+ "_nvfuser.test_torchscript",
+ "test_nvfuser_dynamo",
+ "test_nvfuser_frontend",
+ # Blocklisted on ROCm upstream, so it never runs there and would otherwise
+ # show up as falsely MISSED against the CUDA baseline.
+ "test_jit_cuda_fuser",
+ "inductor.test_cpu_repro"
+]
+
+
+EXCLUDED_TEST_CLASSES = [
+ "nvfuser_tests",
+ "TensorPipeCudaDdpComparisonTest",
+ "TensorPipeCudaDistAutogradTest",
+ "TensorPipeCudaRemoteModuleTest",
+ "TensorPipeCudaRpcTest",
+ "TensorPipeTensorPipeAgentCudaRpcTest",
+ "TensorPipeTensorPipeCudaDistAutogradTest",
+ "test_cpp_rpc"
+]
+EXCLUDED_TESTS = [
+]
+
+
+# Test config names
+TestConfigName = Enum('TestConfigName', ['default', 'distributed', 'inductor'])
+
+def _status_priority(test_case):
+ """Return a numeric priority for deduplication of retried tests.
+ PASSED/XFAILED are preferred over FAILED/ERROR/SKIPPED since a
+ passing retry means the test is considered passing (flaky) in CI."""
+ status = get_test_status(test_case)
+ return {"PASSED": 4, "XFAILED": 3, "SKIPPED": 2, "FAILED": 1, "ERROR": 1, "MISSED": 0}.get(status, 0)
+
+def _extract_shard(dirname):
+ """Extract shard number from directory names like 'test-default-3-6'."""
+ m = re.match(r'test-\w+-(\d+)-(\d+)', dirname)
+ if m:
+ return f"{m.group(1)}/{m.group(2)}"
+ return ""
+
+def parse_xml_reports_as_dict(workflow_run_id, workflow_run_attempt, tag, path="."):
+ test_config = ""
+ test_cases = {}
+
+ # download_testlogs writes the upstream pytorch CI workflow run id
+ # into "_wf_run_id" alongside the shard dirs. We combine it with each
+ # shard dir's trailing "_" to form the URL
+ # https://github.com/pytorch/pytorch/actions/runs//job/
+ # surfaced as the "Job ID" column in the FAILED TESTS table.
+ wf_run_id = ""
+ wf_id_file = os.path.join(path, "_wf_run_id")
+ if os.path.isfile(wf_id_file):
+ with open(wf_id_file) as f:
+ wf_run_id = f.read().strip()
+
+ items_list = os.listdir(path)
+ for dir in items_list:
+ new_dir = path + '/' + dir + '/'
+ if os.path.isdir(new_dir):
+ if "test-default" in new_dir:
+ test_config = TestConfigName.default.name
+ elif "test-distributed" in new_dir:
+ test_config = TestConfigName.distributed.name
+ elif "test-inductor" in new_dir:
+ test_config = TestConfigName.inductor.name
+ shard = _extract_shard(dir)
+ jid = re.search(r'_(\d+)$', dir)
+ job_url = (
+ f"https://github.com/pytorch/pytorch/actions/runs/{wf_run_id}/job/{jid.group(1)}"
+ if wf_run_id and jid else ""
+ )
+ for xml_report in Path(new_dir).glob("**/*.xml"):
+ try:
+ new_cases = parse_xml_report(
+ tag,
+ xml_report,
+ workflow_run_id,
+ workflow_run_attempt,
+ test_config
+ )
+ except Exception as e:
+ print(f"WARNING: Skipping malformed XML {xml_report}: {e}")
+ continue
+ for key, case in new_cases.items():
+ case["shard"] = shard
+ case["job_url"] = job_url
+ existing = test_cases.get(key)
+ if existing is None or _status_priority(case) > _status_priority(existing):
+ test_cases[key] = case
+ return test_cases
+
+def get_test_status(test_case):
+ # In order of priority: S=skipped, F=failure, E=error, P=pass
+ if not test_case:
+ return "MISSED"
+ elif "skipped" in test_case and test_case["skipped"]:
+ type_message = test_case["skipped"]
+ if type_message.__contains__('type') and type_message['type'] == "pytest.xfail":
+ return "XFAILED"
+ else:
+ return "SKIPPED"
+ elif "failure" in test_case and test_case["failure"]:
+ return "FAILED"
+ elif "error" in test_case and test_case["error"]:
+ return "ERROR"
+ else:
+ return "PASSED"
+
+def get_test_message(test_case, status=None):
+ if status == "SKIPPED":
+ return test_case["skipped"] if "skipped" in test_case else ""
+ elif status == "FAILED":
+ return test_case["failure"] if "failure" in test_case else ""
+ elif status == "ERROR":
+ return test_case["error"] if "error" in test_case else ""
+ else:
+ if "skipped" in test_case:
+ return test_case["skipped"]
+ elif "failure" in test_case:
+ return test_case["failure"]
+ elif "error" in test_case:
+ return test_case["error"]
+ else:
+ return ""
+
+def get_running_time(test_case):
+ status = get_test_status(test_case)
+ if test_case.__contains__('time'):
+ return test_case["time"]
+ return ""
+
+def check_time_valid(time):
+ if time == "":
+ return False
+ return True
+
+def summarize_xml_files(args):
+ # TODO: Add arguments and parse accordingly
+ set1_path = args.set1 if args.set1 else "."
+ set2_path = args.set2
+ set1_name = args.set1_name
+ set2_name = args.set2_name
+
+ # statistics
+ SKIPPED_DEFAULT = 0
+ MISSED_DEFAULT = 0
+ CUDA_DEFAULT = 0
+ ROCM_DEFAULT = 0
+ ROCMONLY_DEFAULT = 0
+
+ SKIPPED_DISTRIBUTED = 0
+ MISSED_DISTRIBUTED = 0
+ CUDA_DISTRIBUTED = 0
+ ROCM_DISTRIBUTED = 0
+ ROCMONLY_DISTRIBUTED = 0
+
+ SKIPPED_INDUCTOR = 0
+ MISSED_INDUCTOR = 0
+ CUDA_INDUCTOR = 0
+ ROCM_INDUCTOR = 0
+ ROCMONLY_INDUCTOR = 0
+
+ TOTAL_CUDA_RUNNING_TIME = 0.0
+ TOTAL_ROCM_RUNNING_TIME = 0.0
+
+ # filter example: --filter SKIPPED-PASSED-MISSED-PASSED (tuples: set1 status1 - set2 status1, set1 status2 - set2 status2)
+ ut_status_filter = args.filter if args.filter else "."
+ list_of_status = ut_status_filter.split('-') if args.filter else []
+ # assertion: should be an even number length
+ assert len(list_of_status) % 2 == 0
+ list_status_set1 = []
+ list_status_set2 = []
+
+ index = 0
+ while index < len(list_of_status):
+ # special handling for status-NOT_status scenario
+ if "NOT" in list_of_status[index] or "NOT" in list_of_status[index+1]:
+ if "NOT" in list_of_status[index]:
+ items = list_of_status[index].split('_')
+ not_item = items[1]
+ for ind in range(len(UT_STATUS_LIST)):
+ if UT_STATUS_LIST[ind] != not_item:
+ list_status_set1.append(UT_STATUS_LIST[ind])
+ list_status_set2.append(list_of_status[index+1])
+ else:
+ items = list_of_status[index+1].split('_')
+ not_item = items[1]
+ for ind in range(len(UT_STATUS_LIST)):
+ if UT_STATUS_LIST[ind] != not_item:
+ list_status_set2.append(UT_STATUS_LIST[ind])
+ list_status_set1.append(list_of_status[index])
+ index += 2
+ else:
+ list_status_set1.append(list_of_status[index])
+ index += 1
+ list_status_set2.append(list_of_status[index])
+ index += 1
+
+ assert len(list_status_set1) == len(list_status_set2), \
+ "status_list not specified correctly, should be in pairs of two"
+ len_status_filter = len(list_status_set1)
+
+ # define column list
+ column_list = ['set1', 'set2', 'skip_reason', 'assignee', 'comments']
+
+ # function location pattern
+ pattern = "at 0x"
+
+ #parse the xml files
+ test_cases_set1_running_time = parse_xml_reports_as_dict(-1, -1, 'testsuite', set1_path)
+ # TODO: Does it matter what the workflow_run_attempt is set to below??
+ # test_cases is dict of dicts, with keys as tuple of test_file, test_class, test_name and test_config
+ test_cases_set1 = parse_xml_reports_as_dict(-1, -1, 'testcase', set1_path)
+ for (k,v) in list(test_cases_set1.items()):
+ if v['test_config'] == TestConfigName.default.name:
+ ROCM_DEFAULT += 1
+ elif v['test_config'] == TestConfigName.distributed.name:
+ ROCM_DISTRIBUTED += 1
+ elif v['test_config'] == TestConfigName.inductor.name:
+ ROCM_INDUCTOR += 1
+
+ # start with creating empty dicts for set2 for each test tuple
+ # for rocm/cuda comparison(with valid set2_path), sometimes parity sheet has inaccurate resutls due to different function string but with same test names,
+ # such as test_np_argmin_argmax_keepdims_size_(1, 2, 3, 4)_axis_-4_method_
+ test_cases_set1_new: Dict[Tuple[str], Dict[str, Any]] = {}
+ if set2_path:
+ for (k,v) in list(test_cases_set1.items()):
+ if pattern in k[2]:
+ values = list(k)
+ index = k[2].find(pattern)
+ values[2] = k[2][0 : index]
+ k_new = tuple(values)
+ test_cases_set1_new[k_new] = v
+ del test_cases_set1[k]
+ #combine two dict
+ test_cases_set1_combined = {**test_cases_set1, **test_cases_set1_new}
+ test_cases = { k:[v, {}] for (k,v) in test_cases_set1_combined.items() }
+ else:
+ test_cases = { k:[v, {}] for (k,v) in test_cases_set1.items() }
+
+ test_cases_set2_running_time = {}
+ if set2_path:
+ assert set2_path != set1_path, \
+ "set2 path not specified correctly, should be different from set1 path"
+ test_cases_set2_running_time = parse_xml_reports_as_dict(-1, -1, 'testsuite', set2_path)
+ test_cases_set2 = parse_xml_reports_as_dict(-1, -1, 'testcase', set2_path)
+ for (k,v) in list(test_cases_set2.items()):
+ if v['test_config'] == TestConfigName.default.name:
+ CUDA_DEFAULT += 1
+ elif v['test_config'] == TestConfigName.distributed.name:
+ CUDA_DISTRIBUTED += 1
+ elif v['test_config'] == TestConfigName.inductor.name:
+ CUDA_INDUCTOR += 1
+
+ # for rocm/cuda comparison, sometimes parity sheet has inaccurate resutls due to different function string but with same test names,
+ # such as test_np_argmin_argmax_keepdims_size_(1, 2, 3, 4)_axis_-4_method_
+ test_cases_set2_new: Dict[Tuple[str], Dict[str, Any]] = {}
+ for (k,v) in list(test_cases_set2.items()):
+ if pattern in k[2]:
+ values = list(k)
+ index = k[2].find(pattern)
+ values[2] = k[2][0 : index]
+ k_new = tuple(values)
+ test_cases_set2_new[k_new] = v
+ del test_cases_set2[k]
+ #combine two dict
+ test_cases_set2_combined = {**test_cases_set2, **test_cases_set2_new}
+
+ # repopulate set2 dicts for test_tuples from test_cases_set2,
+ # creating empty dicts for set1 if test_tuple doesn't exist in test_cases
+ for test_case in test_cases_set2_combined:
+ test_cases[test_case] = [test_cases_set1_combined[test_case] if test_case in test_cases_set1_combined else {}, test_cases_set2_combined[test_case]]
+
+ # expand with skip_reason, assignee and comments
+ for (k,v) in list(test_cases.items()):
+ # set1, set2, skip_reason, assignee and comments
+ while len(v) < len(column_list):
+ v.append('')
+
+ # get running time statistics before any exclusion and filter since they are only for comparison
+ # total running time: ROCm and CUDA
+ for (k,v) in list(test_cases_set1_running_time.items()):
+ TOTAL_ROCM_RUNNING_TIME += v["running_time_xml"]
+ for (k,v) in list(test_cases_set2_running_time.items()):
+ TOTAL_CUDA_RUNNING_TIME += v["running_time_xml"]
+
+ # test file level running time: ROCm and CUDA
+ test_file_level_ROCm: Dict[Tuple[str], float] = {}
+ test_file_level_CUDA: Dict[Tuple[str], float] = {}
+ for (k,v) in list(test_cases_set1_running_time.items()):
+ test_file_name = k[0]
+ test_config_name = k[2]
+ tar_tup_rocm = (test_file_name, test_config_name,)
+ if test_file_level_ROCm.get(tar_tup_rocm) == None:
+ test_file_level_ROCm[ ( test_file_name, test_config_name ) ] = v["running_time_xml"]
+ else:
+ test_file_level_ROCm[ ( test_file_name, test_config_name ) ] += v["running_time_xml"]
+ for (k,v) in list(test_cases_set2_running_time.items()):
+ test_file_name = k[0]
+ test_config_name = k[2]
+ tar_tup_cuda = (test_file_name, test_config_name)
+ if test_file_level_CUDA.get(tar_tup_cuda) == None:
+ test_file_level_CUDA[ ( test_file_name, test_config_name ) ] = v["running_time_xml"]
+ else:
+ test_file_level_CUDA[ ( test_file_name, test_config_name ) ] += v["running_time_xml"]
+
+ # test file level counts: ROCm tests run, passed, skipped, missed; CUDA tests run
+ test_file_counts_ROCm: Dict[Tuple[str], Dict[str, int]] = {}
+ test_file_counts_CUDA: Dict[Tuple[str], int] = {}
+ for (k,v) in list(test_cases_set1.items()):
+ test_file_name = k[0]
+ test_config_name = v['test_config']
+ tar_tup = (test_file_name, test_config_name)
+ if tar_tup not in test_file_counts_ROCm:
+ test_file_counts_ROCm[tar_tup] = {'tests_run': 0, 'passed': 0, 'skipped': 0, 'missed': 0}
+ test_file_counts_ROCm[tar_tup]['tests_run'] += 1
+ status = get_test_status(v)
+ if status == "PASSED":
+ test_file_counts_ROCm[tar_tup]['passed'] += 1
+ elif status == "SKIPPED":
+ test_file_counts_ROCm[tar_tup]['skipped'] += 1
+ elif status == "MISSED":
+ test_file_counts_ROCm[tar_tup]['missed'] += 1
+ for (k,v) in list(test_cases_set2.items()) if set2_path else []:
+ test_file_name = k[0]
+ test_config_name = v['test_config']
+ tar_tup = (test_file_name, test_config_name)
+ if tar_tup not in test_file_counts_CUDA:
+ test_file_counts_CUDA[tar_tup] = 0
+ test_file_counts_CUDA[tar_tup] += 1
+
+ # exclude certain tests for comparison
+ if set2_path:
+ for (k,v) in list(test_cases.items()):
+ if k[0] in EXCLUDED_TEST_SUITES:
+ test_cases.pop(k)
+ elif k[1] in EXCLUDED_TEST_CLASSES:
+ test_cases.pop(k)
+ elif (k[0], k[1], k[2]) in EXCLUDED_TESTS:
+ test_cases.pop(k)
+
+ # remove unmatched items if user specified ut status filters
+ if len_status_filter > 0:
+ case_matched = True
+ for (k,v) in list(test_cases.items()):
+ case_matched = False
+ status_set_1 = get_test_status(v[0])
+ status_set_2 = get_test_status(v[1]) if set2_path else ""
+ for index in range(len_status_filter):
+ if status_set_1 == list_status_set1[index] and status_set_2 == list_status_set2[index]:
+ case_matched = True
+ break
+
+ if not case_matched:
+ test_cases.pop(k)
+
+ # insert skip_reason, assignee and comments info for the cases that: rocm-missed+cuda-passed OR rocm-skipped+cuda-passed
+ # To do: assume set1 is ROCm currently. Should insert another arg for ROCm and CUDA order?
+ skip_reasons_stat_default = dict()
+ skip_reasons_stat_distributed = dict()
+ skip_reasons_stat_inductor = dict()
+ if args.skip_reasons:
+ # read skip reasons csv file
+ known_skips = pd.read_csv(args.skip_reasons, sep='\t')
+ known_skips = known_skips.to_dict(orient="records")
+
+ # Load previous week's CSV to check if tests existed and get skip reasons
+ prev_week_tests = set()
+ prev_week_skip_reasons = {} # Maps (test_file, test_class, test_name) -> (skip_reason, assignee, comments)
+ if args.prev_week_csv:
+ prev_week_df = pd.read_csv(args.prev_week_csv)
+ for _, row in prev_week_df.iterrows():
+ test_key = (row['test_file'], row['test_class'], row['test_name'])
+ prev_week_tests.add(test_key)
+ # Also extract skip_reason, assignee, comments if they exist
+ skip_reason = row.get('skip_reason', '') if 'skip_reason' in row and not pd.isna(row.get('skip_reason', '')) else ''
+ assignee = row.get('assignee', '') if 'assignee' in row and not pd.isna(row.get('assignee', '')) else ''
+ comments = row.get('comments', '') if 'comments' in row and not pd.isna(row.get('comments', '')) else ''
+ if skip_reason or assignee or comments:
+ prev_week_skip_reasons[test_key] = (skip_reason, assignee, comments)
+
+ for (k,v) in list(test_cases.items()):
+ status_set_1 = get_test_status(v[0])
+ status_set_2 = get_test_status(v[1]) if set2_path else ""
+ test_file_name = k[0]
+ test_info = v[0]
+ test_info_set2 = []
+ if status_set_1 == "SKIPPED" and status_set_2 != "SKIPPED":
+ if test_info['test_config'] == TestConfigName.default.name:
+ SKIPPED_DEFAULT += 1
+ elif test_info['test_config'] == TestConfigName.distributed.name:
+ SKIPPED_DISTRIBUTED += 1
+ elif test_info['test_config'] == TestConfigName.inductor.name:
+ SKIPPED_INDUCTOR += 1
+ elif set2_path:
+ test_info_set2 = v[1]
+ if status_set_1 == "MISSED" and status_set_2 != "MISSED":
+ if test_info_set2['test_config'] == TestConfigName.default.name:
+ MISSED_DEFAULT += 1
+ elif test_info_set2['test_config'] == TestConfigName.distributed.name:
+ MISSED_DISTRIBUTED += 1
+ elif test_info_set2['test_config'] == TestConfigName.inductor.name:
+ MISSED_INDUCTOR += 1
+
+
+ if args.skip_reasons:
+ if (status_set_1 == "SKIPPED" and status_set_2 != "SKIPPED") or status_set_1 == "MISSED":
+ for known_skip in known_skips:
+ if test_file_name == known_skip['test_file'] and k[1] == known_skip['test_class'] and k[2] == known_skip['test_name']:
+ v[2] = known_skip['skip_reason'] if known_skip.__contains__('skip_reason') and not pd.isna(known_skip['skip_reason']) else ' '
+ if (test_info.__contains__('test_config') and test_info['test_config'] == TestConfigName.default.name) or (test_info_set2.__contains__('test_config') and test_info_set2['test_config'] == TestConfigName.default.name):
+ if not skip_reasons_stat_default.__contains__(v[2]):
+ skip_reasons_stat_default[v[2]] = 1
+ else:
+ skip_reasons_stat_default[v[2]] += 1
+ elif (test_info.__contains__('test_config') and test_info['test_config'] == TestConfigName.distributed.name) or (test_info_set2.__contains__('test_config') and test_info_set2['test_config'] == TestConfigName.distributed.name):
+ if not skip_reasons_stat_distributed.__contains__(v[2]):
+ skip_reasons_stat_distributed[v[2]] = 1
+ else:
+ skip_reasons_stat_distributed[v[2]] += 1
+ elif (test_info.__contains__('test_config') and test_info['test_config'] == TestConfigName.inductor.name) or (test_info_set2.__contains__('test_config') and test_info_set2['test_config'] == TestConfigName.inductor.name):
+ if not skip_reasons_stat_inductor.__contains__(v[2]):
+ skip_reasons_stat_inductor[v[2]] = 1
+ else:
+ skip_reasons_stat_inductor[v[2]] += 1
+ v[3] = known_skip['assignee'] if known_skip.__contains__('assignee') and not pd.isna(known_skip['assignee']) else ' '
+ v[4] = known_skip['comments'] if known_skip.__contains__('comments') and not pd.isna(known_skip['comments']) else ' '
+ break
+
+ if status_set_1 == "PASSED" and status_set_2 != "PASSED" and set2_path:
+ if test_info['test_config'] == TestConfigName.default.name:
+ ROCMONLY_DEFAULT += 1
+ elif test_info['test_config'] == TestConfigName.distributed.name:
+ ROCMONLY_DISTRIBUTED += 1
+ elif test_info['test_config'] == TestConfigName.inductor.name:
+ ROCMONLY_INDUCTOR += 1
+
+ skip_reasons_stat_default.pop(' ', None)
+ skip_reasons_stat_distributed.pop(' ', None)
+
+ test_cases_for_csv = {}
+ # k is test_tuple, v is list of rocm and cuda info for that test_tuple
+ skip_reason_file_specified = False
+ if args.skip_reasons:
+ skip_reason_file_specified = True
+ for (k,v) in test_cases.items():
+ item_values = {}
+ item_values["test_file"] = k[0]
+ item_values["test_class"] = k[1]
+ item_values["test_name"] = k[2]
+ item_values[f"status_{set1_name}"] = get_test_status(v[0])
+ item_values[f"status_{set2_name}"] = get_test_status(v[1]) if set2_path else ""
+ # get test config info
+ v_values = v[0]
+ v1_values = v[1] if set2_path else []
+ config_name = ""
+ item_values["test_config"] = ""
+ if item_values[f"status_{set1_name}"] != "MISSED":
+ config_name = v_values['test_config']
+ elif item_values[f"status_{set2_name}"] != "MISSED" and item_values[f"status_{set2_name}"] != "":
+ config_name = v1_values['test_config']
+ item_values["test_config"] = config_name
+ item_values[f"shard_{set1_name}"] = v_values.get('shard', '') if v_values else ''
+ item_values[f"shard_{set2_name}"] = v1_values.get('shard', '') if v1_values else ''
+ item_values[f"job_url_{set1_name}"] = v_values.get('job_url', '') if v_values else ''
+ item_values[f"job_url_{set2_name}"] = v1_values.get('job_url', '') if v1_values else ''
+ # get test related info
+ item_values[f"message_{set1_name}"] = get_test_message(v[0])
+ item_values[f"message_{set2_name}"] = get_test_message(v[1]) if set2_path else ""
+ # Get skip_reason, assignee, comments from --skip_reasons file if specified
+ if skip_reason_file_specified:
+ item_values["skip_reason"] = v[2]
+ item_values["assignee"] = v[3]
+ item_values["comments"] = v[4]
+ # Check if test existed in previous week's CSV and get skip reasons from there
+ if args.prev_week_csv:
+ test_key = (k[0], k[1], k[2]) # (test_file, test_class, test_name)
+ item_values["existed_last_week"] = "yes" if test_key in prev_week_tests else "no"
+ # If skip_reason not set by --skip_reasons, try to get from prev_week_csv
+ if not skip_reason_file_specified:
+ if test_key in prev_week_skip_reasons:
+ prev_skip_reason, prev_assignee, prev_comments = prev_week_skip_reasons[test_key]
+ item_values["skip_reason"] = prev_skip_reason
+ item_values["assignee"] = prev_assignee
+ item_values["comments"] = prev_comments
+ else:
+ item_values["skip_reason"] = ""
+ item_values["assignee"] = ""
+ item_values["comments"] = ""
+ if not skip_reason_file_specified and not args.prev_week_csv:
+ item_values["skip_reason"] = ""
+ item_values["assignee"] = ""
+ item_values["comments"] = ""
+ running_time1 = get_running_time(v[0])
+ item_values[f"running_time_{set1_name}"] = running_time1
+ running_time2 = get_running_time(v[1])
+ item_values[f"running_time_{set2_name}"] = running_time2
+ item_values["abs_time_diff"] = ""
+ item_values["relative_time_diff"] = ""
+ if check_time_valid(running_time1) and check_time_valid(running_time2):
+ item_values["abs_time_diff"] = running_time1 - running_time2
+ if get_running_time(v[1]) != 0.0:
+ item_values["relative_time_diff"] = 100 * (running_time1 - running_time2) / running_time2
+ test_cases_for_csv[k] = item_values
+
+ test_cases_for_csv = dict(sorted(test_cases_for_csv.items()))
+
+ #store test_cases in csv
+ tests_from_xml_filename = args.output_csv
+ keys_list = list(set(chain.from_iterable(sub.keys() for sub in test_cases_for_csv.values())))
+
+ def sorting_key(e):
+ if e == "invoking_file":
+ return 0
+ elif e == "test_file":
+ return 1
+ elif e == "test_class":
+ return 2
+ elif e == "test_name":
+ return 3
+ elif e == "test_config":
+ return 4
+ elif e == "skip_reason":
+ return 5
+ elif e == "assignee":
+ return 6
+ elif e == "comments":
+ return 7
+ elif e == f"status_{set1_name}":
+ return 8
+ elif e == f"message_{set1_name}":
+ return 9
+ elif e == f"running_time_{set1_name}":
+ return 10
+ elif e == f"status_{set2_name}":
+ return 11
+ elif e == f"message_{set2_name}":
+ return 12
+ elif e == f"running_time_{set2_name}":
+ return 13
+ elif e == "abs_time_diff":
+ return 14
+ elif e == "relative_time_diff":
+ return 15
+ elif e == "skipped":
+ return 16
+ elif e == "failure":
+ return 17
+ elif e == "error":
+ return 18
+ elif e == "system-out":
+ return 19
+ elif e == "existed_last_week":
+ return 20
+ elif e == f"shard_{set1_name}":
+ return 21
+ elif e == f"shard_{set2_name}":
+ return 22
+ elif e == f"job_url_{set1_name}":
+ return 23
+ elif e == f"job_url_{set2_name}":
+ return 24
+ elif e == "workflow_run_attempt" or e == "job_id":
+ return 1000
+ else:
+ return 100
+
+ keys_list.sort(key=sorting_key)
+
+ with open(tests_from_xml_filename, "w") as outfile:
+ writer = csv.DictWriter(outfile, fieldnames = keys_list)
+ writer.writeheader()
+ writer.writerows(test_cases_for_csv.values())
+ ## TODO - usage yet to be identified
+ #pytest_parallel_times = get_pytest_parallel_times()
+ ##extract test cases summary and save them in csv file
+ #test_cases_summary = summarize_test_cases(test_cases)
+ #testcases_summary_filename = "testcases_summary.csv"
+ #keys_list = list(set(chain.from_iterable(sub.keys() for sub in test_cases_summary)))
+ #with open(testcases_summary_filename, "w") as outfile:
+ # writer = csv.DictWriter(outfile, fieldnames = keys_list)
+ # writer.writeheader()
+ # writer.writerows(test_cases_summary)
+
+ # write test file running time to file
+ test_file_running_time_for_csv = {}
+ for key_rocm in test_file_level_ROCm.keys():
+ item_values = {}
+ item_values["test_file"] = key_rocm[0]
+ item_values["test_config"] = key_rocm[1]
+ item_values["rocm_running_time"] = test_file_level_ROCm[key_rocm]
+ item_values["cuda_running_time"] = 0.0
+ if key_rocm in test_file_level_CUDA.keys():
+ item_values["cuda_running_time"] = test_file_level_CUDA[key_rocm]
+ item_values["abs_time_diff"] = item_values["rocm_running_time"] - item_values["cuda_running_time"]
+ item_values["relative_time_diff"] = 0.0
+ if item_values["cuda_running_time"] != 0.0:
+ item_values["relative_time_diff"] = 100 * (item_values["rocm_running_time"] - item_values["cuda_running_time"]) / item_values["cuda_running_time"]
+ # Add test counts
+ item_values["rocm_tests_run"] = test_file_counts_ROCm.get(key_rocm, {}).get('tests_run', 0)
+ item_values["cuda_tests_run"] = test_file_counts_CUDA.get(key_rocm, 0)
+ item_values["rocm_passed"] = test_file_counts_ROCm.get(key_rocm, {}).get('passed', 0)
+ item_values["rocm_skipped"] = test_file_counts_ROCm.get(key_rocm, {}).get('skipped', 0)
+ item_values["rocm_missed"] = test_file_counts_ROCm.get(key_rocm, {}).get('missed', 0)
+ test_file_running_time_for_csv[key_rocm] = item_values
+
+ for key_cuda in test_file_level_CUDA.keys():
+ if not key_cuda in test_file_level_ROCm.keys():
+ item_values = {}
+ item_values["test_file"] = key_cuda[0]
+ item_values["test_config"] = key_cuda[1]
+ item_values["rocm_running_time"] = 0.0
+ item_values["cuda_running_time"] = test_file_level_CUDA[key_cuda]
+ item_values["abs_time_diff"] = item_values["rocm_running_time"] - item_values["cuda_running_time"]
+ item_values["relative_time_diff"] = 0.0
+ if item_values["cuda_running_time"] != 0.0:
+ item_values["relative_time_diff"] = 100 * (item_values["rocm_running_time"] - item_values["cuda_running_time"]) / item_values["cuda_running_time"]
+ # Add test counts
+ item_values["rocm_tests_run"] = test_file_counts_ROCm.get(key_cuda, {}).get('tests_run', 0)
+ item_values["cuda_tests_run"] = test_file_counts_CUDA.get(key_cuda, 0)
+ item_values["rocm_passed"] = test_file_counts_ROCm.get(key_cuda, {}).get('passed', 0)
+ item_values["rocm_skipped"] = test_file_counts_ROCm.get(key_cuda, {}).get('skipped', 0)
+ item_values["rocm_missed"] = test_file_counts_ROCm.get(key_cuda, {}).get('missed', 0)
+ test_file_running_time_for_csv[key_cuda] = item_values
+
+ test_file_running_time_for_csv = dict(sorted(test_file_running_time_for_csv.items()))
+ keys_list_running_time = list(set(chain.from_iterable(sub.keys() for sub in test_file_running_time_for_csv.values())))
+ def sorting_key_running_time(e):
+ if e == "test_file":
+ return 0
+ elif e == "test_config":
+ return 1
+ elif e == "rocm_running_time":
+ return 2
+ elif e == "cuda_running_time":
+ return 3
+ elif e == "abs_time_diff":
+ return 4
+ elif e == "relative_time_diff":
+ return 5
+ elif e == "rocm_tests_run":
+ return 6
+ elif e == "cuda_tests_run":
+ return 7
+ elif e == "rocm_passed":
+ return 8
+ elif e == "rocm_skipped":
+ return 9
+ elif e == "rocm_missed":
+ return 10
+ else:
+ return 100
+
+ keys_list_running_time.sort(key=sorting_key_running_time)
+ tests_from_xml_file_running_time = args.test_file_running_time_output_csv
+ with open(tests_from_xml_file_running_time, "w") as outfile:
+ writer = csv.DictWriter(outfile, fieldnames = keys_list_running_time)
+ writer.writeheader()
+ writer.writerows(test_file_running_time_for_csv.values())
+
+ # print summary
+ print( " " )
+ print( "_____________________________________" )
+ print( "Test-results" )
+ print( " " )
+ print( "=====Single GPU Number=====" )
+ print( "SKIPPED_DEFAULT, MISSED_DEFAULT, ROCMONLY_DEFAULT, CUDA_DEFAULT, ROCM_DEFAULT" )
+ print( str(SKIPPED_DEFAULT) + ", " + str(MISSED_DEFAULT) + ", " + str(ROCMONLY_DEFAULT) + ", " + str(CUDA_DEFAULT) + ", " + str(ROCM_DEFAULT) )
+ print( " " )
+ print( "=====Distributed GPU Number=====" )
+ print( "SKIPPED_DISTRIBUTED, MISSED_DISTRIBUTED, ROCMONLY_DISTRIBUTED, CUDA_DISTRIBUTED, ROCM_DISTRIBUTED" )
+ print( str(SKIPPED_DISTRIBUTED) + ", " + str(MISSED_DISTRIBUTED) + ", " + str(ROCMONLY_DISTRIBUTED) + ", " + str(CUDA_DISTRIBUTED) + ", " + str(ROCM_DISTRIBUTED) )
+ print( " " )
+ print( "=====Inductor GPU Number=====" )
+ print( "SKIPPED_INDUCTOR, MISSED_INDUCTOR, ROCMONLY_INDUCTOR, CUDA_INDUCTOR, ROCM_INDUCTOR" )
+ print( str(SKIPPED_INDUCTOR) + ", " + str(MISSED_INDUCTOR) + ", " + str(ROCMONLY_INDUCTOR) + ", " + str(CUDA_INDUCTOR) + ", " + str(ROCM_INDUCTOR) )
+ print( " " )
+ print( "SELECTED CAUSES SUMMARY" )
+ print( " " )
+ print( "=====================" )
+ print( "Single GPU test" )
+ sorted_skip_reasons_statistics_default = sorted(skip_reasons_stat_default.keys(), key = lambda x : x.lower())
+ for skip_reason_entry in sorted_skip_reasons_statistics_default:
+ print( skip_reason_entry, ": ", skip_reasons_stat_default[skip_reason_entry] )
+ print( " " )
+ print( "=====================" )
+ print( "Distributed test" )
+ sorted_skip_reasons_distributed_statistics = sorted(skip_reasons_stat_distributed.keys(), key = lambda x : x.lower())
+ for skip_reason_entry in sorted_skip_reasons_distributed_statistics:
+ print( skip_reason_entry, ": ", skip_reasons_stat_distributed[skip_reason_entry] )
+ print( " " )
+ print( "=====================" )
+ print( "Inductor test" )
+ sorted_skip_reasons_statistics_inductor = sorted(skip_reasons_stat_inductor.keys(), key = lambda x : x.lower())
+ for skip_reason_entry in sorted_skip_reasons_statistics_inductor:
+ print( skip_reason_entry, ": ", skip_reasons_stat_inductor[skip_reason_entry] )
+ print( " " )
+ print( "=====================" )
+ print( "Time statistics" )
+ print( "ROCM_RUNNING_TIME, CUDA_RUNNING_TIME" )
+ print( str(TOTAL_ROCM_RUNNING_TIME) + ", " + str(TOTAL_CUDA_RUNNING_TIME) )
+ #print( "ROCm test file level time statistics" )
+ #for (k,v) in list(test_file_level_ROCm.items()):
+ #print( k[0] + ", " + k[1] + ", " + k[2] + ", " + str(v) )
+ #print( "CUDA test file level time statistics" )
+ #for (k,v) in list(test_file_level_CUDA.items()):
+ #print( k[0] + ", " + k[1] + ", " + k[2] + ", " + str(v) )
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='Parse xml test-reports')
+ parser.add_argument("--set1", required=False, type=str, help="absolute or relative path to first test-reports dir")
+ parser.add_argument("--set2", required=False, type=str, help="absolute or relative path to second test-reports dir")
+ parser.add_argument("--set1_name", required=False, type=str, default="set1", help="display name for set1 in CSV column headers (default: set1)")
+ parser.add_argument("--set2_name", required=False, type=str, default="set2", help="display name for set2 in CSV column headers (default: set2)")
+ parser.add_argument("--output_csv", required=False, type=str, help="output csv filename", default="tests_from_xml.csv")
+ parser.add_argument("--filter", required=False, type=str, help="ut status filter flag")
+ parser.add_argument("--skip_reasons", required=False, type=str, help='skip reasons file')
+ parser.add_argument("--test_file_running_time_output_csv", required=False, type=str, help="file running time output csv filename", default="file_running_time_output.csv")
+ parser.add_argument("--prev_week_csv", required=False, type=str, help="previous week's all tests status CSV file to check if tests existed")
+ return parser.parse_args()
+
+def main():
+ global args
+ args = parse_args()
+ summarize_xml_files(args)
+
+if __name__ == "__main__":
+ main()
+
diff --git a/.automation_scripts/pytorch-unit-test-scripts/upload_stats_lib.py b/.automation_scripts/pytorch-unit-test-scripts/upload_stats_lib.py
new file mode 100644
index 000000000000..218e35768ef2
--- /dev/null
+++ b/.automation_scripts/pytorch-unit-test-scripts/upload_stats_lib.py
@@ -0,0 +1,187 @@
+import gzip
+import io
+import json
+import os
+import xml.etree.ElementTree as ET
+import zipfile
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import boto3 # type: ignore[import]
+import requests
+import rockset # type: ignore[import]
+
+PYTORCH_REPO = "https://api.github.com/repos/pytorch/pytorch"
+S3_RESOURCE = boto3.resource("s3")
+TARGET_WORKFLOW = "--rerun-disabled-tests"
+
+
+def _get_request_headers() -> Dict[str, str]:
+ return {
+ "Accept": "application/vnd.github.v3+json",
+ "Authorization": "token " + os.environ["GITHUB_TOKEN"],
+ }
+
+
+def _get_artifact_urls(prefix: str, workflow_run_id: int) -> Dict[Path, str]:
+ """Get all workflow artifacts with 'test-report' in the name."""
+ response = requests.get(
+ f"{PYTORCH_REPO}/actions/runs/{workflow_run_id}/artifacts?per_page=100",
+ )
+ artifacts = response.json()["artifacts"]
+ while "next" in response.links.keys():
+ response = requests.get(
+ response.links["next"]["url"], headers=_get_request_headers()
+ )
+ artifacts.extend(response.json()["artifacts"])
+
+ artifact_urls = {}
+ for artifact in artifacts:
+ if artifact["name"].startswith(prefix):
+ artifact_urls[Path(artifact["name"])] = artifact["archive_download_url"]
+ return artifact_urls
+
+
+def _download_artifact(
+ artifact_name: Path, artifact_url: str, workflow_run_attempt: int
+) -> Path:
+ # [Artifact run attempt]
+ # All artifacts on a workflow share a single namespace. However, we can
+ # re-run a workflow and produce a new set of artifacts. To avoid name
+ # collisions, we add `-runattempt1-` somewhere in the artifact name.
+ #
+ # This code parses out the run attempt number from the artifact name. If it
+ # doesn't match the one specified on the command line, skip it.
+ atoms = str(artifact_name).split("-")
+ for atom in atoms:
+ if atom.startswith("runattempt"):
+ found_run_attempt = int(atom[len("runattempt") :])
+ if workflow_run_attempt != found_run_attempt:
+ print(
+ f"Skipping {artifact_name} as it is an invalid run attempt. "
+ f"Expected {workflow_run_attempt}, found {found_run_attempt}."
+ )
+
+ print(f"Downloading {artifact_name}")
+
+ response = requests.get(artifact_url, headers=_get_request_headers())
+ with open(artifact_name, "wb") as f:
+ f.write(response.content)
+ return artifact_name
+
+
+def download_s3_artifacts(
+ prefix: str,
+ workflow_run_id: int,
+ workflow_run_attempt: int,
+ allowed_substrings: Optional[List[str]] = None,
+) -> List[Path]:
+ bucket = S3_RESOURCE.Bucket("gha-artifacts")
+ objs = bucket.objects.filter(
+ Prefix=f"pytorch/pytorch/{workflow_run_id}/{workflow_run_attempt}/artifact/{prefix}"
+ )
+
+ found_one = False
+ paths = []
+ for obj in objs:
+ p = Path(Path(obj.key).name)
+ if allowed_substrings and not any(sub in p.name for sub in allowed_substrings):
+ continue
+ found_one = True
+ print(f"Downloading {p}")
+ with open(p, "wb") as f:
+ f.write(obj.get()["Body"].read())
+ paths.append(p)
+
+ if not found_one:
+ print(
+ "::warning title=s3 artifacts not found::"
+ "Didn't find any test reports in s3, there might be a bug!"
+ )
+ return paths
+
+
+def download_gha_artifacts(
+ prefix: str, workflow_run_id: int, workflow_run_attempt: int
+) -> List[Path]:
+ artifact_urls = _get_artifact_urls(prefix, workflow_run_id)
+ paths = []
+ for name, url in artifact_urls.items():
+ paths.append(_download_artifact(Path(name), url, workflow_run_attempt))
+ return paths
+
+
+def upload_to_rockset(collection: str, docs: List[Any]) -> None:
+ print(f"Writing {len(docs)} documents to Rockset")
+ client = rockset.Client(
+ api_server="api.rs2.usw2.rockset.com", api_key=os.environ["ROCKSET_API_KEY"]
+ )
+ client.Collection.retrieve(collection).add_docs(docs)
+ print("Done!")
+
+
+def upload_to_s3(
+ workflow_run_id: int,
+ workflow_run_attempt: int,
+ collection: str,
+ docs: List[Dict[str, Any]],
+) -> None:
+ print(f"Writing {len(docs)} documents to S3")
+ body = io.StringIO()
+ for doc in docs:
+ json.dump(doc, body)
+ body.write("\n")
+
+ S3_RESOURCE.Object(
+ "ossci-raw-job-status",
+ f"{collection}/{workflow_run_id}/{workflow_run_attempt}",
+ ).put(
+ Body=gzip.compress(body.getvalue().encode()),
+ ContentEncoding="gzip",
+ ContentType="application/json",
+ )
+ print("Done!")
+
+
+def upload_file_to_s3(
+ file_name: str,
+ bucket: str,
+ key: str,
+) -> None:
+ """
+ Upload a local file to S3
+ """
+ print(f"Upload {file_name} to s3://{bucket}/{key}")
+ boto3.client("s3").upload_file(
+ file_name,
+ bucket,
+ key,
+ )
+
+
+def unzip(p: Path) -> None:
+ """Unzip the provided zipfile to a similarly-named directory.
+
+ Returns None if `p` is not a zipfile.
+
+ Looks like: /tmp/test-reports.zip -> /tmp/unzipped-test-reports/
+ """
+ assert p.is_file()
+ unzipped_dir = p.with_name("unzipped-" + p.stem)
+ print(f"Extracting {p} to {unzipped_dir}")
+
+ with zipfile.ZipFile(p, "r") as zip:
+ zip.extractall(unzipped_dir)
+
+
+def is_rerun_disabled_tests(root: ET.ElementTree) -> bool:
+ """
+ Check if the test report is coming from rerun_disabled_tests workflow
+ """
+ skipped = root.find(".//*skipped")
+ # Need to check against None here, if not skipped doesn't work as expected
+ if skipped is None:
+ return False
+
+ message = skipped.attrib.get("message", "")
+ return TARGET_WORKFLOW in message or "num_red" in message
diff --git a/.automation_scripts/pytorch-unit-test-scripts/upload_test_stats.py b/.automation_scripts/pytorch-unit-test-scripts/upload_test_stats.py
new file mode 100644
index 000000000000..29384d3bd0b4
--- /dev/null
+++ b/.automation_scripts/pytorch-unit-test-scripts/upload_test_stats.py
@@ -0,0 +1,394 @@
+import argparse
+import os
+import sys
+import xml.etree.ElementTree as ET
+from pathlib import Path
+from tempfile import TemporaryDirectory
+from typing import Any, Dict, List, Tuple
+
+from upload_stats_lib import (
+ download_gha_artifacts,
+ download_s3_artifacts,
+ is_rerun_disabled_tests,
+ unzip,
+ upload_to_s3,
+)
+
+
+# Backends list
+BACKENDS_LIST = [
+ "dist-gloo",
+ "dist-nccl"
+]
+
+def get_job_id(report: Path) -> int:
+ # [Job id in artifacts]
+ # Retrieve the job id from the report path. In our GHA workflows, we append
+ # the job id to the end of the report name, so `report` looks like:
+ # unzipped-test-reports-foo_5596745227/test/test-reports/foo/TEST-foo.xml
+ # and we want to get `5596745227` out of it.
+ try:
+ return int(report.parts[0].rpartition("_")[2])
+ except ValueError:
+ return -1
+
+
+def parse_xml_report(
+ tag: str,
+ report: Path,
+ workflow_id: int,
+ workflow_run_attempt: int,
+ test_config: str
+) -> Dict[Tuple[str], Dict[str, Any]]:
+ """Convert a test report xml file into a JSON-serializable list of test cases."""
+ #print(f"Parsing {tag}s for test report: {report}")
+ print(".", end="", flush=True)
+
+ job_id = get_job_id(report)
+ #print(f"Found job id: {job_id}")
+
+ test_cases: Dict[Tuple[str], Dict[str, Any]] = {}
+
+ root = ET.parse(report)
+ # TODO: unlike unittest, pytest-flakefinder used by rerun disabled tests for test_ops
+ # includes skipped messages multiple times (50 times by default). This slows down
+ # this script too much (O(n)) because it tries to gather all the stats. This should
+ # be fixed later in the way we use pytest-flakefinder. A zipped test report from rerun
+ # disabled test is only few MB, but will balloon up to a much bigger XML file after
+ # extracting from a dozen to few hundred MB
+ if is_rerun_disabled_tests(root):
+ return test_cases
+
+ for test_case in root.iter(tag):
+ case = process_xml_element(test_case)
+ if tag == 'testcase':
+ case["workflow_id"] = workflow_id
+ case["workflow_run_attempt"] = workflow_run_attempt
+ case["job_id"] = job_id
+ case["test_config"] = test_config
+
+ # [invoking file]
+ # The name of the file that the test is located in is not necessarily
+ # the same as the name of the file that invoked the test.
+ # For example, `test_jit.py` calls into multiple other test files (e.g.
+ # jit/test_dce.py). For sharding/test selection purposes, we want to
+ # record the file that invoked the test.
+ #
+ # To do this, we leverage an implementation detail of how we write out
+ # tests (https://bit.ly/3ajEV1M), which is that reports are created
+ # under a folder with the same name as the invoking file.
+ case_name = report.parent.name
+ for part in report.parts:
+ for backend in BACKENDS_LIST:
+ if backend in part:
+ case_name = case_name + "_" + part
+ break
+ else:
+ continue
+ break
+ case["invoking_file"] = case_name
+ test_cases[ ( case["invoking_file"], case["classname"], case["name"], case["test_config"] ) ] = case
+ elif tag == 'testsuite':
+ case["test_config"] = test_config
+ case["invoking_xml"] = report.name
+ case["running_time_xml"] = case["time"]
+ case_name = report.parent.name
+ for part in report.parts:
+ for backend in BACKENDS_LIST:
+ if backend in part:
+ case_name = case_name + "_" + part
+ break
+ else:
+ continue
+ break
+ case["invoking_file"] = case_name
+ test_cases[ ( case["invoking_file"], case["invoking_xml"], case["test_config"] ) ] = case
+
+ return test_cases
+
+
+def process_xml_element(element: ET.Element) -> Dict[str, Any]:
+ """Convert a test suite element into a JSON-serializable dict."""
+ ret: Dict[str, Any] = {}
+
+ # Convert attributes directly into dict elements.
+ # e.g.
+ #
+ # becomes:
+ # {"name": "test_foo", "classname": "test_bar"}
+ ret.update(element.attrib)
+
+ # The XML format encodes all values as strings. Convert to ints/floats if
+ # possible to make aggregation possible in Rockset.
+ for k, v in ret.items():
+ try:
+ ret[k] = int(v)
+ except ValueError:
+ pass
+ try:
+ ret[k] = float(v)
+ except ValueError:
+ pass
+
+ # Convert inner and outer text into special dict elements.
+ # e.g.
+ # my_inner_text my_tail
+ # becomes:
+ # {"text": "my_inner_text", "tail": " my_tail"}
+ if element.text and element.text.strip():
+ ret["text"] = element.text
+ if element.tail and element.tail.strip():
+ ret["tail"] = element.tail
+
+ # Convert child elements recursively, placing them at a key:
+ # e.g.
+ #
+ # hello
+ # world
+ # another
+ #
+ # becomes
+ # {
+ # "foo": [{"text": "hello"}, {"text": "world"}],
+ # "bar": {"text": "another"}
+ # }
+ for child in element:
+ if child.tag not in ret:
+ ret[child.tag] = process_xml_element(child)
+ else:
+ # If there are multiple tags with the same name, they should be
+ # coalesced into a list.
+ if not isinstance(ret[child.tag], list):
+ ret[child.tag] = [ret[child.tag]]
+ ret[child.tag].append(process_xml_element(child))
+ return ret
+
+
+def get_pytest_parallel_times() -> Dict[Any, Any]:
+ pytest_parallel_times: Dict[Any, Any] = {}
+ for report in Path(".").glob("**/python-pytest/**/*.xml"):
+ invoking_file = report.parent.name
+
+ root = ET.parse(report)
+ # TODO: Skip test reports from rerun disabled tests, same reason as mentioned
+ # above
+ if is_rerun_disabled_tests(root):
+ continue
+
+ assert len(list(root.iter("testsuite"))) == 1
+ for test_suite in root.iter("testsuite"):
+ pytest_parallel_times[
+ (invoking_file, get_job_id(report))
+ ] = test_suite.attrib["time"]
+ return pytest_parallel_times
+
+
+def get_tests(
+ workflow_run_id: int, workflow_run_attempt: int
+) -> Tuple[List[Dict[str, Any]], Dict[Any, Any]]:
+ with TemporaryDirectory() as temp_dir:
+ print("Using temporary directory:", temp_dir)
+ os.chdir(temp_dir)
+
+ # Download and extract all the reports (both GHA and S3)
+ s3_paths = download_s3_artifacts(
+ "test-report", workflow_run_id, workflow_run_attempt
+ )
+ for path in s3_paths:
+ unzip(path)
+
+ artifact_paths = download_gha_artifacts(
+ "test-report", workflow_run_id, workflow_run_attempt
+ )
+ for path in artifact_paths:
+ unzip(path)
+
+ # Parse the reports and transform them to JSON
+ test_cases = []
+ for xml_report in Path(".").glob("**/*.xml"):
+ test_cases.extend(
+ parse_xml_report(
+ "testcase",
+ xml_report,
+ workflow_run_id,
+ workflow_run_attempt,
+ )
+ )
+
+ pytest_parallel_times = get_pytest_parallel_times()
+
+ return test_cases, pytest_parallel_times
+
+
+def get_tests_for_circleci(
+ workflow_run_id: int, workflow_run_attempt: int
+) -> Tuple[List[Dict[str, Any]], Dict[Any, Any]]:
+ # Parse the reports and transform them to JSON
+ test_cases = []
+ for xml_report in Path(".").glob("**/test/test-reports/**/*.xml"):
+ test_cases.extend(
+ parse_xml_report(
+ "testcase", xml_report, workflow_run_id, workflow_run_attempt
+ )
+ )
+
+ pytest_parallel_times = get_pytest_parallel_times()
+
+ return test_cases, pytest_parallel_times
+
+
+def get_invoking_file_times(
+ test_case_summaries: List[Dict[str, Any]], pytest_parallel_times: Dict[Any, Any]
+) -> List[Dict[str, Any]]:
+ def get_key(summary: Dict[str, Any]) -> Any:
+ return (
+ summary["invoking_file"],
+ summary["job_id"],
+ )
+
+ def init_value(summary: Dict[str, Any]) -> Any:
+ return {
+ "job_id": summary["job_id"],
+ "workflow_id": summary["workflow_id"],
+ "workflow_run_attempt": summary["workflow_run_attempt"],
+ "invoking_file": summary["invoking_file"],
+ "time": 0.0,
+ }
+
+ ret = {}
+ for summary in test_case_summaries:
+ key = get_key(summary)
+ if key not in ret:
+ ret[key] = init_value(summary)
+ ret[key]["time"] += summary["time"]
+
+ for key, val in ret.items():
+ # when running in parallel in pytest, adding the test times will not give the correct
+ # time used to run the file, which will make the sharding incorrect, so if the test is
+ # run in parallel, we take the time reported by the testsuite
+ if key in pytest_parallel_times:
+ val["time"] = pytest_parallel_times[key]
+
+ return list(ret.values())
+
+
+def summarize_test_cases(test_cases: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """Group test cases by classname, file, and job_id. We perform the aggregation
+ manually instead of using the `test-suite` XML tag because xmlrunner does
+ not produce reliable output for it.
+ """
+
+ def get_key(test_case: Dict[str, Any]) -> Any:
+ return (
+ test_case.get("file"),
+ test_case.get("classname"),
+ test_case["job_id"],
+ test_case["workflow_id"],
+ test_case["workflow_run_attempt"],
+ # [see: invoking file]
+ test_case["invoking_file"],
+ )
+
+ def init_value(test_case: Dict[str, Any]) -> Dict[str, Any]:
+ return {
+ "file": test_case.get("file"),
+ "classname": test_case.get("classname"),
+ "job_id": test_case["job_id"],
+ "workflow_id": test_case["workflow_id"],
+ "workflow_run_attempt": test_case["workflow_run_attempt"],
+ # [see: invoking file]
+ "invoking_file": test_case["invoking_file"],
+ "tests": 0,
+ "failures": 0,
+ "errors": 0,
+ "skipped": 0,
+ "successes": 0,
+ "time": 0.0,
+ }
+
+ ret = {}
+ for test_case in test_cases:
+ key = get_key(test_case)
+ if key not in ret:
+ ret[key] = init_value(test_case)
+
+ ret[key]["tests"] += 1
+
+ if "failure" in test_case:
+ ret[key]["failures"] += 1
+ elif "error" in test_case:
+ ret[key]["errors"] += 1
+ elif "skipped" in test_case:
+ ret[key]["skipped"] += 1
+ else:
+ ret[key]["successes"] += 1
+
+ ret[key]["time"] += test_case["time"]
+ return list(ret.values())
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Upload test stats to Rockset")
+ parser.add_argument(
+ "--workflow-run-id",
+ required=True,
+ help="id of the workflow to get artifacts from",
+ )
+ parser.add_argument(
+ "--workflow-run-attempt",
+ type=int,
+ required=True,
+ help="which retry of the workflow this is",
+ )
+ parser.add_argument(
+ "--head-branch",
+ required=True,
+ help="Head branch of the workflow",
+ )
+ parser.add_argument(
+ "--circleci",
+ action="store_true",
+ help="If this is being run through circleci",
+ )
+ args = parser.parse_args()
+
+ print(f"Workflow id is: {args.workflow_run_id}")
+
+ if args.circleci:
+ test_cases, pytest_parallel_times = get_tests_for_circleci(
+ args.workflow_run_id, args.workflow_run_attempt
+ )
+ else:
+ test_cases, pytest_parallel_times = get_tests(
+ args.workflow_run_id, args.workflow_run_attempt
+ )
+
+ # Flush stdout so that any errors in rockset upload show up last in the logs.
+ sys.stdout.flush()
+
+ # For PRs, only upload a summary of test_runs. This helps lower the
+ # volume of writes we do to Rockset.
+ test_case_summary = summarize_test_cases(test_cases)
+ invoking_file_times = get_invoking_file_times(
+ test_case_summary, pytest_parallel_times
+ )
+
+ upload_to_s3(
+ args.workflow_run_id,
+ args.workflow_run_attempt,
+ "test_run_summary",
+ test_case_summary,
+ )
+
+ upload_to_s3(
+ args.workflow_run_id,
+ args.workflow_run_attempt,
+ "invoking_file_times",
+ invoking_file_times,
+ )
+
+ if args.head_branch == "master":
+ # For master jobs, upload everytihng.
+ upload_to_s3(
+ args.workflow_run_id, args.workflow_run_attempt, "test_run", test_cases
+ )
diff --git a/.automation_scripts/run_pytorch_unit_tests.py b/.automation_scripts/run_pytorch_unit_tests.py
new file mode 100644
index 000000000000..514afd19624c
--- /dev/null
+++ b/.automation_scripts/run_pytorch_unit_tests.py
@@ -0,0 +1,518 @@
+#!/usr/bin/env python3
+
+""" The Python PyTorch testing script.
+##
+# Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+"""
+
+import argparse
+import os
+import shutil
+import subprocess
+from subprocess import STDOUT, CalledProcessError
+
+from collections import namedtuple
+from datetime import datetime
+from pathlib import Path
+from parse_xml_results import (
+ parse_xml_report
+)
+from pprint import pprint
+from typing import Any, Dict, List
+
+# unit test status list
+UT_STATUS_LIST = [
+ "PASSED",
+ "MISSED",
+ "SKIPPED",
+ "FAILED",
+ "XFAILED",
+ "ERROR"
+]
+
+DEFAULT_CORE_TESTS = [
+ "test_nn",
+ "test_torch",
+ "test_cuda",
+ "test_ops",
+ "test_unary_ufuncs",
+ "test_autograd",
+ "inductor/test_torchinductor"
+]
+
+DISTRIBUTED_CORE_TESTS = [
+ "distributed/test_c10d_common",
+ "distributed/test_c10d_nccl",
+ "distributed/test_distributed_spawn"
+]
+
+CONSOLIDATED_LOG_FILE_NAME="pytorch_unit_tests.log"
+
+def parse_xml_reports_as_dict(workflow_run_id, workflow_run_attempt, tag, workflow_name, path="."):
+ test_cases = {}
+ items_list = os.listdir(path)
+ for dir in items_list:
+ new_dir = path + '/' + dir + '/'
+ if os.path.isdir(new_dir):
+ for xml_report in Path(new_dir).glob("**/*.xml"):
+ test_cases.update(
+ parse_xml_report(
+ tag,
+ xml_report,
+ workflow_run_id,
+ workflow_run_attempt,
+ workflow_name
+ )
+ )
+ return test_cases
+
+def get_test_status(test_case):
+ # In order of priority: S=skipped, F=failure, E=error, P=pass
+ if "skipped" in test_case and test_case["skipped"]:
+ type_message = test_case["skipped"]
+ if type_message.__contains__('type') and type_message['type'] == "pytest.xfail":
+ return "XFAILED"
+ else:
+ return "SKIPPED"
+ elif "failure" in test_case and test_case["failure"]:
+ return "FAILED"
+ elif "error" in test_case and test_case["error"]:
+ return "ERROR"
+ else:
+ return "PASSED"
+
+def get_test_message(test_case, status=None):
+ if status == "SKIPPED":
+ return test_case["skipped"] if "skipped" in test_case else ""
+ elif status == "FAILED":
+ return test_case["failure"] if "failure" in test_case else ""
+ elif status == "ERROR":
+ return test_case["error"] if "error" in test_case else ""
+ else:
+ if "skipped" in test_case:
+ return test_case["skipped"]
+ elif "failure" in test_case:
+ return test_case["failure"]
+ elif "error" in test_case:
+ return test_case["error"]
+ else:
+ return ""
+
+def get_test_file_running_time(test_suite):
+ if test_suite.__contains__('time'):
+ return test_suite["time"]
+ return 0
+
+def get_test_running_time(test_case):
+ if test_case.__contains__('time'):
+ return test_case["time"]
+ return ""
+
+def summarize_xml_files(path, workflow_name):
+ # statistics
+ TOTAL_TEST_NUM = 0
+ TOTAL_PASSED_NUM = 0
+ TOTAL_SKIPPED_NUM = 0
+ TOTAL_XFAIL_NUM = 0
+ TOTAL_FAILED_NUM = 0
+ TOTAL_ERROR_NUM = 0
+ TOTAL_EXECUTION_TIME = 0
+
+ #parse the xml files
+ test_cases = parse_xml_reports_as_dict(-1, -1, 'testcase', workflow_name, path)
+ test_suites = parse_xml_reports_as_dict(-1, -1, 'testsuite', workflow_name, path)
+ test_file_and_status = namedtuple("test_file_and_status", ["file_name", "status"])
+ # results dict
+ res = {}
+ res_item_list = [ "PASSED", "SKIPPED", "XFAILED", "FAILED", "ERROR" ]
+ test_file_items = set()
+ for (k,v) in list(test_suites.items()):
+ file_name = k[0]
+ if not file_name in test_file_items:
+ test_file_items.add(file_name)
+ # initialization
+ for item in res_item_list:
+ temp_item = test_file_and_status(file_name, item)
+ res[temp_item] = {}
+ temp_item_statistics = test_file_and_status(file_name, "STATISTICS")
+ res[temp_item_statistics] = {'TOTAL': 0, 'PASSED': 0, 'SKIPPED': 0, 'XFAILED': 0, 'FAILED': 0, 'ERROR': 0, 'EXECUTION_TIME': 0}
+ test_running_time = get_test_file_running_time(v)
+ res[temp_item_statistics]["EXECUTION_TIME"] += test_running_time
+ TOTAL_EXECUTION_TIME += test_running_time
+ else:
+ test_tuple_key_statistics = test_file_and_status(file_name, "STATISTICS")
+ test_running_time = get_test_file_running_time(v)
+ res[test_tuple_key_statistics]["EXECUTION_TIME"] += test_running_time
+ TOTAL_EXECUTION_TIME += test_running_time
+
+ for (k,v) in list(test_cases.items()):
+ file_name = k[0]
+ class_name = k[1]
+ test_name = k[2]
+ combined_name = file_name + "::" + class_name + "::" + test_name
+ test_status = get_test_status(v)
+ test_running_time = get_test_running_time(v)
+ test_message = get_test_message(v, test_status)
+ test_info_value = ""
+ test_tuple_key_status = test_file_and_status(file_name, test_status)
+ test_tuple_key_statistics = test_file_and_status(file_name, "STATISTICS")
+ TOTAL_TEST_NUM += 1
+ res[test_tuple_key_statistics]["TOTAL"] += 1
+ if test_status == "PASSED":
+ test_info_value = str(test_running_time)
+ res[test_tuple_key_status][combined_name] = test_info_value
+ res[test_tuple_key_statistics]["PASSED"] += 1
+ TOTAL_PASSED_NUM += 1
+ elif test_status == "SKIPPED":
+ test_info_value = str(test_running_time)
+ res[test_tuple_key_status][combined_name] = test_info_value
+ res[test_tuple_key_statistics]["SKIPPED"] += 1
+ TOTAL_SKIPPED_NUM += 1
+ elif test_status == "XFAILED":
+ test_info_value = str(test_running_time)
+ res[test_tuple_key_status][combined_name] = test_info_value
+ res[test_tuple_key_statistics]["XFAILED"] += 1
+ TOTAL_XFAIL_NUM += 1
+ elif test_status == "FAILED":
+ test_info_value = test_message
+ res[test_tuple_key_status][combined_name] = test_info_value
+ res[test_tuple_key_statistics]["FAILED"] += 1
+ TOTAL_FAILED_NUM += 1
+ elif test_status == "ERROR":
+ test_info_value = test_message
+ res[test_tuple_key_status][combined_name] = test_info_value
+ res[test_tuple_key_statistics]["ERROR"] += 1
+ TOTAL_ERROR_NUM += 1
+
+ # generate statistics_dict
+ statistics_dict = {}
+ statistics_dict["TOTAL"] = TOTAL_TEST_NUM
+ statistics_dict["PASSED"] = TOTAL_PASSED_NUM
+ statistics_dict["SKIPPED"] = TOTAL_SKIPPED_NUM
+ statistics_dict["XFAILED"] = TOTAL_XFAIL_NUM
+ statistics_dict["FAILED"] = TOTAL_FAILED_NUM
+ statistics_dict["ERROR"] = TOTAL_ERROR_NUM
+ statistics_dict["EXECUTION_TIME"] = TOTAL_EXECUTION_TIME
+ aggregate_item = workflow_name + "_aggregate"
+ total_item = test_file_and_status(aggregate_item, "STATISTICS")
+ res[total_item] = statistics_dict
+
+ return res
+
+def run_command_and_capture_output(cmd):
+ try:
+ print(f"Running command '{cmd}'")
+ with open(CONSOLIDATED_LOG_FILE_PATH, "a+") as output_file:
+ print(f"========================================", file=output_file, flush=True)
+ print(f"[RUN_PYTORCH_UNIT_TESTS] Running command '{cmd}'", file=output_file, flush=True) # send to consolidated file as well
+ print(f"========================================", file=output_file, flush=True)
+ p = subprocess.run(cmd, shell=True, stdout=output_file, stderr=STDOUT, text=True)
+ except CalledProcessError as e:
+ print(f"ERROR: Cmd {cmd} failed with return code: {e.returncode}!")
+
+def run_entire_tests(workflow_name, test_shell_path, overall_logs_path_current_run, test_reports_src):
+ if os.path.exists(test_reports_src):
+ shutil.rmtree(test_reports_src)
+
+ os.mkdir(test_reports_src)
+ copied_logs_path = ""
+ if workflow_name == "default":
+ os.environ['TEST_CONFIG'] = 'default'
+ copied_logs_path = overall_logs_path_current_run + "default_xml_results_entire_tests/"
+ elif workflow_name == "distributed":
+ os.environ['TEST_CONFIG'] = 'distributed'
+ copied_logs_path = overall_logs_path_current_run + "distributed_xml_results_entire_tests/"
+ elif workflow_name == "inductor":
+ os.environ['TEST_CONFIG'] = 'inductor'
+ copied_logs_path = overall_logs_path_current_run + "inductor_xml_results_entire_tests/"
+ # use test.sh for tests execution
+ run_command_and_capture_output(test_shell_path)
+ copied_logs_path_destination = shutil.copytree(test_reports_src, copied_logs_path)
+ entire_results_dict = summarize_xml_files(copied_logs_path_destination, workflow_name)
+ return entire_results_dict
+
+def run_priority_tests(workflow_name, test_run_test_path, overall_logs_path_current_run, test_reports_src):
+ if os.path.exists(test_reports_src):
+ shutil.rmtree(test_reports_src)
+
+ os.mkdir(test_reports_src)
+ copied_logs_path = ""
+ if workflow_name == "default":
+ os.environ['TEST_CONFIG'] = 'default'
+ os.environ['HIP_VISIBLE_DEVICES'] = '0'
+ copied_logs_path = overall_logs_path_current_run + "default_xml_results_priority_tests/"
+ # use run_test.py for tests execution
+ default_priority_test_suites = " ".join(DEFAULT_CORE_TESTS)
+ command = "python3 " + test_run_test_path + " --include " + default_priority_test_suites + " --exclude-jit-executor --exclude-distributed-tests --verbose"
+ run_command_and_capture_output(command)
+ del os.environ['HIP_VISIBLE_DEVICES']
+ elif workflow_name == "distributed":
+ os.environ['TEST_CONFIG'] = 'distributed'
+ os.environ['HIP_VISIBLE_DEVICES'] = '0,1'
+ copied_logs_path = overall_logs_path_current_run + "distributed_xml_results_priority_tests/"
+ # use run_test.py for tests execution
+ distributed_priority_test_suites = " ".join(DISTRIBUTED_CORE_TESTS)
+ command = "python3 " + test_run_test_path + " --include " + distributed_priority_test_suites + " --distributed-tests --verbose"
+ run_command_and_capture_output(command)
+ del os.environ['HIP_VISIBLE_DEVICES']
+ copied_logs_path_destination = shutil.copytree(test_reports_src, copied_logs_path)
+ priority_results_dict = summarize_xml_files(copied_logs_path_destination, workflow_name)
+
+ return priority_results_dict
+
+def run_selected_tests(workflow_name, test_run_test_path, overall_logs_path_current_run, test_reports_src, selected_list):
+ if os.path.exists(test_reports_src):
+ shutil.rmtree(test_reports_src)
+
+ os.mkdir(test_reports_src)
+ copied_logs_path = ""
+ if workflow_name == "default":
+ os.environ['TEST_CONFIG'] = 'default'
+ os.environ['HIP_VISIBLE_DEVICES'] = '0'
+ copied_logs_path = overall_logs_path_current_run + "default_xml_results_selected_tests/"
+ # use run_test.py for tests execution
+ default_selected_test_suites = " ".join(selected_list)
+ command = "python3 " + test_run_test_path + " --include " + default_selected_test_suites + " --exclude-jit-executor --exclude-distributed-tests --verbose"
+ run_command_and_capture_output(command)
+ del os.environ['HIP_VISIBLE_DEVICES']
+ elif workflow_name == "distributed":
+ os.environ['TEST_CONFIG'] = 'distributed'
+ os.environ['HIP_VISIBLE_DEVICES'] = '0,1'
+ copied_logs_path = overall_logs_path_current_run + "distributed_xml_results_selected_tests/"
+ # use run_test.py for tests execution
+ distributed_selected_test_suites = " ".join(selected_list)
+ command = "python3 " + test_run_test_path + " --include " + distributed_selected_test_suites + " --distributed-tests --verbose"
+ run_command_and_capture_output(command)
+ del os.environ['HIP_VISIBLE_DEVICES']
+ elif workflow_name == "inductor":
+ os.environ['TEST_CONFIG'] = 'inductor'
+ copied_logs_path = overall_logs_path_current_run + "inductor_xml_results_selected_tests/"
+ inductor_selected_test_suites = ""
+ non_inductor_selected_test_suites = ""
+ for item in selected_list:
+ if "inductor/" in item:
+ inductor_selected_test_suites += item
+ inductor_selected_test_suites += " "
+ else:
+ non_inductor_selected_test_suites += item
+ non_inductor_selected_test_suites += " "
+ if inductor_selected_test_suites != "":
+ inductor_selected_test_suites = inductor_selected_test_suites[:-1]
+ command = "python3 " + test_run_test_path + " --include " + inductor_selected_test_suites + " --verbose"
+ run_command_and_capture_output(command)
+ if non_inductor_selected_test_suites != "":
+ non_inductor_selected_test_suites = non_inductor_selected_test_suites[:-1]
+ command = "python3 " + test_run_test_path + " --inductor --include " + non_inductor_selected_test_suites + " --verbose"
+ run_command_and_capture_output(command)
+ copied_logs_path_destination = shutil.copytree(test_reports_src, copied_logs_path)
+ selected_results_dict = summarize_xml_files(copied_logs_path_destination, workflow_name)
+
+ return selected_results_dict
+
+def run_test_and_summarize_results(
+ pytorch_root_dir: str,
+ priority_tests: bool,
+ test_config: List[str],
+ default_list: List[str],
+ distributed_list: List[str],
+ inductor_list: List[str],
+ skip_rerun: bool) -> Dict[str, Any]:
+
+ # copy current environment variables
+ _environ = dict(os.environ)
+
+ # modify path
+ test_shell_path = pytorch_root_dir + "/.ci/pytorch/test.sh"
+ test_run_test_path = pytorch_root_dir + "/test/run_test.py"
+ repo_test_log_folder_path = pytorch_root_dir + "/.automation_logs/"
+ test_reports_src = pytorch_root_dir + "/test/test-reports/"
+ run_test_python_file = pytorch_root_dir + "/test/run_test.py"
+
+ # change directory to pytorch root
+ os.chdir(pytorch_root_dir)
+
+ # all test results dict
+ res_all_tests_dict = {}
+
+ # patterns
+ search_text = "--reruns=2"
+ replace_text = "--reruns=0"
+
+ # create logs folder
+ if not os.path.exists(repo_test_log_folder_path):
+ os.mkdir(repo_test_log_folder_path)
+
+ # Set common environment variables for all scenarios
+ os.environ['CI'] = '1'
+ os.environ['PYTORCH_TEST_WITH_ROCM'] = '1'
+ os.environ['HSA_FORCE_FINE_GRAIN_PCIE'] = '1'
+ os.environ['PYTORCH_TESTING_DEVICE_ONLY_FOR'] = 'cuda'
+ os.environ['CONTINUE_THROUGH_ERROR'] = 'True'
+ if skip_rerun:
+ # modify run_test.py in-place
+ with open(run_test_python_file, 'r') as file:
+ data = file.read()
+ data = data.replace(search_text, replace_text)
+ with open(run_test_python_file, 'w') as file:
+ file.write(data)
+
+ # Time stamp
+ current_datetime = datetime.now().strftime("%Y%m%d_%H-%M-%S")
+ print("Current date & time : ", current_datetime)
+ # performed as Job ID
+ str_current_datetime = str(current_datetime)
+ overall_logs_path_current_run = repo_test_log_folder_path + str_current_datetime + "/"
+ os.mkdir(overall_logs_path_current_run)
+
+ global CONSOLIDATED_LOG_FILE_PATH
+ CONSOLIDATED_LOG_FILE_PATH = overall_logs_path_current_run + CONSOLIDATED_LOG_FILE_NAME
+
+ # Check multi gpu availability if distributed tests are enabled
+ if ("distributed" in test_config) or len(distributed_list) != 0:
+ check_num_gpus_for_distributed()
+
+ # Install test requirements
+ command = "pip3 install -r requirements.txt && pip3 install -r .ci/docker/requirements-ci.txt"
+ run_command_and_capture_output(command)
+
+ # Run entire tests for each workflow
+ if not priority_tests and not default_list and not distributed_list and not inductor_list:
+ # run entire tests for default, distributed and inductor workflows → use test.sh
+ if not test_config:
+ check_num_gpus_for_distributed()
+ # default test process
+ res_default_all = run_entire_tests("default", test_shell_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["default"] = res_default_all
+ # distributed test process
+ res_distributed_all = run_entire_tests("distributed", test_shell_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["distributed"] = res_distributed_all
+ # inductor test process
+ res_inductor_all = run_entire_tests("inductor", test_shell_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["inductor"] = res_inductor_all
+ else:
+ workflow_list = []
+ for item in test_config:
+ workflow_list.append(item)
+ if "default" in workflow_list:
+ res_default_all = run_entire_tests("default", test_shell_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["default"] = res_default_all
+ if "distributed" in workflow_list:
+ res_distributed_all = run_entire_tests("distributed", test_shell_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["distributed"] = res_distributed_all
+ if "inductor" in workflow_list:
+ res_inductor_all = run_entire_tests("inductor", test_shell_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["inductor"] = res_inductor_all
+ # Run priority test for each workflow
+ elif priority_tests and not default_list and not distributed_list and not inductor_list:
+ if not test_config:
+ check_num_gpus_for_distributed()
+ # default test process
+ res_default_priority = run_priority_tests("default", test_run_test_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["default"] = res_default_priority
+ # distributed test process
+ res_distributed_priority = run_priority_tests("distributed", test_run_test_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["distributed"] = res_distributed_priority
+ # will not run inductor priority tests
+ print("Inductor priority tests cannot run since no core tests defined with inductor workflow.")
+ else:
+ workflow_list = []
+ for item in test_config:
+ workflow_list.append(item)
+ if "default" in workflow_list:
+ res_default_priority = run_priority_tests("default", test_run_test_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["default"] = res_default_priority
+ if "distributed" in workflow_list:
+ res_distributed_priority = run_priority_tests("distributed", test_run_test_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["distributed"] = res_distributed_priority
+ if "inductor" in workflow_list:
+ print("Inductor priority tests cannot run since no core tests defined with inductor workflow.")
+ # Run specified tests for each workflow
+ elif (default_list or distributed_list or inductor_list) and not test_config and not priority_tests:
+ if default_list:
+ default_workflow_list = []
+ for item in default_list:
+ default_workflow_list.append(item)
+ res_default_selected = run_selected_tests("default", test_run_test_path, overall_logs_path_current_run, test_reports_src, default_workflow_list)
+ res_all_tests_dict["default"] = res_default_selected
+ if distributed_list:
+ distributed_workflow_list = []
+ for item in distributed_list:
+ distributed_workflow_list.append(item)
+ res_distributed_selected = run_selected_tests("distributed", test_run_test_path, overall_logs_path_current_run, test_reports_src, distributed_workflow_list)
+ res_all_tests_dict["distributed"] = res_distributed_selected
+ if inductor_list:
+ inductor_workflow_list = []
+ for item in inductor_list:
+ inductor_workflow_list.append(item)
+ res_inductor_selected = run_selected_tests("inductor", test_run_test_path, overall_logs_path_current_run, test_reports_src, inductor_workflow_list)
+ res_all_tests_dict["inductor"] = res_inductor_selected
+ else:
+ raise Exception("Invalid test configurations!")
+
+ # restore environment variables
+ os.environ.clear()
+ os.environ.update(_environ)
+
+ # restore files
+ if skip_rerun:
+ # modify run_test.py in-place
+ with open(run_test_python_file, 'r') as file:
+ data = file.read()
+ data = data.replace(replace_text, search_text)
+ with open(run_test_python_file, 'w') as file:
+ file.write(data)
+
+ return res_all_tests_dict
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='Run PyTorch unit tests and generate xml results summary', formatter_class=argparse.RawTextHelpFormatter)
+ parser.add_argument('--test_config', nargs='+', default=[], type=str, help="space-separated list of test workflows to be executed eg. 'default distributed'")
+ parser.add_argument('--priority_tests', action='store_true', help="run priority tests only")
+ parser.add_argument('--default_list', nargs='+', default=[], help="space-separated list of 'default' config test suites/files to be executed eg. 'test_weak test_dlpack'")
+ parser.add_argument('--distributed_list', nargs='+', default=[], help="space-separated list of 'distributed' config test suites/files to be executed eg. 'distributed/test_c10d_common distributed/test_c10d_nccl'")
+ parser.add_argument('--inductor_list', nargs='+', default=[], help="space-separated list of 'inductor' config test suites/files to be executed eg. 'inductor/test_torchinductor test_ops'")
+ parser.add_argument('--pytorch_root', default='.', type=str, help="PyTorch root directory")
+ parser.add_argument('--skip_rerun', action='store_true', help="skip rerun process")
+ parser.add_argument('--example_output', type=str, help="{'workflow_name': {\n"
+ " test_file_and_status(file_name='workflow_aggregate', status='STATISTICS'): {}, \n"
+ " test_file_and_status(file_name='test_file_name_1', status='ERROR'): {}, \n"
+ " test_file_and_status(file_name='test_file_name_1', status='FAILED'): {}, \n"
+ " test_file_and_status(file_name='test_file_name_1', status='PASSED'): {}, \n"
+ " test_file_and_status(file_name='test_file_name_1', status='SKIPPED'): {}, \n"
+ " test_file_and_status(file_name='test_file_name_1', status='STATISTICS'): {} \n"
+ "}}\n")
+ parser.add_argument('--example_usages', type=str, help="RUN ALL TESTS: python3 run_pytorch_unit_tests.py \n"
+ "RUN PRIORITY TESTS: python3 run_pytorch_unit_tests.py --test_config distributed --priority_test \n"
+ "RUN SELECTED TESTS: python3 run_pytorch_unit_tests.py --default_list test_weak test_dlpack --inductor_list inductor/test_torchinductor")
+ return parser.parse_args()
+
+def check_num_gpus_for_distributed():
+ p = subprocess.run("rocminfo | grep -cE 'Name:\s+gfx'", shell=True, capture_output=True, text=True)
+ num_gpus_visible = int(p.stdout)
+ assert num_gpus_visible > 1, "Number of visible GPUs should be >1 to run distributed unit tests"
+
+def main():
+ args = parse_args()
+ all_tests_results = run_test_and_summarize_results(args.pytorch_root, args.priority_tests, args.test_config, args.default_list, args.distributed_list, args.inductor_list, args.skip_rerun)
+ pprint(dict(all_tests_results))
+
+if __name__ == "__main__":
+ main()
diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt
index 23407b4d540c..3d17e9c0de64 100644
--- a/.ci/docker/ci_commit_pins/triton.txt
+++ b/.ci/docker/ci_commit_pins/triton.txt
@@ -1 +1 @@
-9844da955a9db14ec69c9aac828ee9803085e288
+ba5c1517e6f5906761cf5783036efb587026208d
diff --git a/.ci/docker/common/install_cache.sh b/.ci/docker/common/install_cache.sh
index 040a31fc379d..9bb80a4e80ec 100644
--- a/.ci/docker/common/install_cache.sh
+++ b/.ci/docker/common/install_cache.sh
@@ -38,7 +38,12 @@ sed -e 's|PATH="\(.*\)"|PATH="/opt/cache/bin:\1"|g' -i /etc/environment
export PATH="/opt/cache/bin:$PATH"
# Setup compiler cache
-install_ubuntu
+if [ -n "$ROCM_VERSION" ]; then
+ curl --retry 3 http://repo.radeon.com/misc/.sccache_amd/sccache -o /opt/cache/bin/sccache
+else
+ install_ubuntu
+fi
+
chmod a+x /opt/cache/bin/sccache
function write_sccache_stub() {
diff --git a/.ci/docker/common/install_triton.sh b/.ci/docker/common/install_triton.sh
index 1b68e3c24783..b2fdebdcc474 100755
--- a/.ci/docker/common/install_triton.sh
+++ b/.ci/docker/common/install_triton.sh
@@ -21,7 +21,7 @@ elif [ -n "${TRITON_CPU}" ]; then
TRITON_REPO="https://github.com/triton-lang/triton-cpu"
TRITON_TEXT_FILE="triton-cpu"
else
- TRITON_REPO="https://github.com/triton-lang/triton"
+ TRITON_REPO="https://github.com/ROCm/triton"
TRITON_TEXT_FILE="triton"
fi
diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt
index 14b8ff59fcfb..cf79a13b4e44 100644
--- a/.ci/docker/requirements-ci.txt
+++ b/.ci/docker/requirements-ci.txt
@@ -120,7 +120,7 @@ ninja==1.11.1.4
numba==0.57.1 ; python_version == "3.10" and platform_machine != "s390x"
numba==0.60.0 ; python_version == "3.12" and platform_machine != "s390x"
#Description: Just-In-Time Compiler for Numerical Functions
-#Pinned versions: 0.55.2, 0.60.0
+#Pinned versions: 0.54.1, 0.49.0, <=0.49.1
#test that import: test_numba_integration.py
#Need release > 0.61.2 for s390x due to https://github.com/numba/numba/pull/10073
@@ -141,6 +141,7 @@ numpy==1.26.2; python_version == "3.11" or python_version == "3.12"
numpy==2.1.2; python_version >= "3.13" and python_version < "3.14"
numpy==2.3.4; python_version >= "3.14"
+
pandas==2.0.3; python_version < "3.12"
pandas==2.2.3; python_version >= "3.12" and python_version < "3.14"
pandas==2.3.3; python_version >= "3.14"
@@ -254,8 +255,7 @@ scikit-image==0.22.0
#Pinned versions: 0.20.3
#test that import:
-scipy==1.10.1 ; python_version <= "3.11"
-scipy==1.14.1 ; python_version > "3.11" and python_version < "3.14"
+scipy==1.14.1 ; python_version > "3.9" and python_version < "3.14"
scipy==1.16.2 ; python_version >= "3.14"
# Pin SciPy because of failing distribution tests (see #60347)
#Description: scientific python
@@ -316,8 +316,7 @@ z3-solver==4.15.1.0 ; platform_machine != "s390x"
#Pinned versions:
#test that import:
-tensorboard==2.13.0 ; python_version < "3.13"
-tensorboard==2.18.0 ; python_version >= "3.13"
+tensorboard==2.18.0
#Description: Also included in .ci/docker/requirements-docs.txt
#Pinned versions:
#test that import: test_tensorboard
diff --git a/.ci/pytorch/common_utils.sh b/.ci/pytorch/common_utils.sh
index b5c0a5e43dea..88e587ab5ff7 100644
--- a/.ci/pytorch/common_utils.sh
+++ b/.ci/pytorch/common_utils.sh
@@ -67,13 +67,13 @@ function pip_install_whl() {
# Loop through each path and install individually
for path in "${paths[@]}"; do
echo "Installing $path"
- python3 -mpip install --no-index --no-deps "$path"
+ python3 -mpip install "$path"
done
else
# Loop through each argument and install individually
for path in "${args[@]}"; do
echo "Installing $path"
- python3 -mpip install --no-index --no-deps "$path"
+ python3 -mpip install "$path"
done
fi
}
diff --git a/.github/scripts/build_triton_wheel.py b/.github/scripts/build_triton_wheel.py
index e5ac9c5937df..0979c6f3f436 100644
--- a/.github/scripts/build_triton_wheel.py
+++ b/.github/scripts/build_triton_wheel.py
@@ -3,6 +3,7 @@
from __future__ import annotations
import os
+import re
import shutil
import sys
from pathlib import Path
@@ -51,6 +52,31 @@ def patch_init_py(
with open(path, "w") as f:
f.write(orig)
+def get_rocm_version() -> str:
+ rocm_path = os.environ.get('ROCM_HOME') or os.environ.get('ROCM_PATH') or "/opt/rocm"
+ rocm_version = "0.0.0"
+ rocm_version_h = f"{rocm_path}/include/rocm-core/rocm_version.h"
+ if not os.path.isfile(rocm_version_h):
+ rocm_version_h = f"{rocm_path}/include/rocm_version.h"
+
+ # The file could be missing due to 1) ROCm version < 5.2, or 2) no ROCm install.
+ if os.path.isfile(rocm_version_h):
+ RE_MAJOR = re.compile(r"#define\s+ROCM_VERSION_MAJOR\s+(\d+)")
+ RE_MINOR = re.compile(r"#define\s+ROCM_VERSION_MINOR\s+(\d+)")
+ RE_PATCH = re.compile(r"#define\s+ROCM_VERSION_PATCH\s+(\d+)")
+ major, minor, patch = 0, 0, 0
+ for line in open(rocm_version_h):
+ match = RE_MAJOR.search(line)
+ if match:
+ major = int(match.group(1))
+ match = RE_MINOR.search(line)
+ if match:
+ minor = int(match.group(1))
+ match = RE_PATCH.search(line)
+ if match:
+ patch = int(match.group(1))
+ rocm_version = str(major)+"."+str(minor)+"."+str(patch)
+ return rocm_version
def build_triton(
*,
@@ -66,13 +92,22 @@ def build_triton(
max_jobs = os.cpu_count() or 1
env["MAX_JOBS"] = str(max_jobs)
+ version_suffix = ""
+ if not release:
+ # Nightly binaries include the triton commit hash, i.e. 2.1.0+e6216047b8
+ # while release build should only include the version, i.e. 2.1.0
+ rocm_version = get_rocm_version()
+ version_suffix = f"+rocm{rocm_version}.git{commit_hash[:8]}"
+ version += version_suffix
+
with TemporaryDirectory() as tmpdir:
triton_basedir = Path(tmpdir) / "triton"
triton_pythondir = triton_basedir / "python"
triton_repo = "https://github.com/openai/triton"
if device == "rocm":
- triton_pkg_name = "triton-rocm"
+ triton_pkg_name = "triton"
+ triton_repo = "https://github.com/ROCm/triton"
elif device == "xpu":
triton_pkg_name = "triton-xpu"
triton_repo = "https://github.com/intel/intel-xpu-backend-for-triton"
@@ -90,6 +125,7 @@ def build_triton(
# change built wheel name and version
env["TRITON_WHEEL_NAME"] = triton_pkg_name
+ env["TRITON_WHEEL_VERSION_SUFFIX"] = version_suffix
if with_clang_ldd:
env["TRITON_BUILD_WITH_CLANG_LLD"] = "1"
@@ -128,6 +164,13 @@ def build_triton(
cwd=triton_basedir,
)
+ # For gpt-oss models, triton requires this extra triton_kernels wheel
+ # triton_kernels came after pytorch release/2.8
+ triton_kernels_dir = Path(f"{triton_basedir}/python/triton_kernels")
+ check_call([sys.executable, "-m", "build", "--wheel"], cwd=triton_kernels_dir, env=env)
+ kernels_whl_path = next(iter((triton_kernels_dir / "dist").glob("*.whl")))
+ shutil.copy(kernels_whl_path, Path.cwd())
+
return Path.cwd() / whl_path.name
diff --git a/.github/scripts/install_pytorch_wheels.py b/.github/scripts/install_pytorch_wheels.py
new file mode 100644
index 000000000000..cf8dc5eccc0c
--- /dev/null
+++ b/.github/scripts/install_pytorch_wheels.py
@@ -0,0 +1,306 @@
+#!/usr/bin/env python3
+"""
+install_pytorch_wheels.py
+
+Installs PyTorch wheels from a pip index URL.
+
+Usage (from repo root):
+ python .github/scripts/install_pytorch_wheels.py --index-url --amdgpu-family [OPTIONS]
+
+Examples:
+ # Install latest versions
+ python .github/scripts/install_pytorch_wheels.py \
+ --index-url /whl \
+ --amdgpu-family gfx1250
+
+ # Install specific versions (matching ROCm builds)
+ python .github/scripts/install_pytorch_wheels.py \
+ --index-url /whl \
+ --amdgpu-family gfx1250 \
+ --torch-version "2.10.0+devrocm7.12.0.dev0.849eec43b..." \
+ --torchaudio-version "2.11.0a0+devrocm7.12.0.dev0.849eec43b..." \
+ --torchvision-version "0.25.0a0+devrocm7.12.0.dev0.849eec43b..."
+"""
+
+import argparse
+import re
+import subprocess
+import sys
+import urllib.parse
+import urllib.request
+
+
+# Package configuration: (name, always_install)
+PACKAGES = {
+ "torch": True,
+ "torchaudio": True,
+ "torchvision": True,
+ "triton": False,
+ "rocm[devel]": True,
+}
+PYTORCH_PKGS = ["torch", "torchaudio", "torchvision", "triton"]
+
+
+def print_banner(title: str) -> None:
+ """Print a formatted banner."""
+ print("=" * 50)
+ print(title)
+ print("=" * 50)
+
+
+def build_package_spec(name: str, version: str | None) -> str:
+ """Build a pip package spec (e.g., 'torch==2.10.0' or 'torch')."""
+ return f"{name}=={version}" if version else name
+
+def get_latest_package_version_for_rocm(
+ index_url: str, package_name: str, rocm_version: str, required: bool = True,
+ version_prefix: str | None = None,
+) -> str | None:
+ """Return latest package version containing rocm_version by parsing the index HTML.
+
+ If version_prefix is set (e.g. "2.9"), only versions whose base part starts
+ with that prefix are considered.
+ """
+
+ # Build the URL for this package's index page (e.g. .../gfx1250/torch/).
+ rocm_tag = f"rocm{rocm_version}"
+ url = f"{index_url.rstrip('/')}/{package_name}/"
+ # Fetch the package index page; on failure (e.g. 404, timeout) fail if always_install, else return None.
+ try:
+ with urllib.request.urlopen(url, timeout=30) as resp:
+ html = resp.read().decode("utf-8", errors="ignore")
+ except Exception as e:
+ print(f"Error: failed to fetch index for {package_name}: {e}", file=sys.stderr)
+ sys.exit(1)
+ # Parse wheel links: format is package-VERSION-...whl (e.g. torch-0.26.0a0+rocm7.12...-cp312-....whl).
+ # Version can contain dots and + (URL-encoded as %2B), so we capture everything up to .whl.
+ pattern = re.compile(
+ re.escape(package_name) + r"-(.+?)\.whl",
+ re.IGNORECASE,
+ )
+ all_suffixes = [m.group(1).strip() for m in pattern.finditer(html)]
+ # Keep only wheels whose version string contains the requested ROCm tag (e.g. rocm7.12.0a20260224).
+ # Version is the first segment before "-" in the suffix; decode %2B to + for comparison.
+ matching = []
+ for s in all_suffixes:
+ ver = s.split("-")[0]
+ if rocm_tag in ver:
+ matching.append(urllib.parse.unquote(ver))
+ # Filter by version prefix (e.g. "2.9" matches "2.9.0+...", "2.9.1+...").
+ if version_prefix and matching:
+ matching = [v for v in matching if v.split("+")[0].startswith(version_prefix)]
+ # No matching wheels: if required (always_install), fail; otherwise return None (package will be skipped).
+ if not matching:
+ if required:
+ msg = f"Error: no wheel found for {package_name} with ROCm {rocm_version}"
+ if version_prefix:
+ msg += f" and version prefix {version_prefix}"
+ print(msg, file=sys.stderr)
+ sys.exit(1)
+ return None
+ # Pick the latest version by comparing all numeric parts including the ROCm date.
+ def _key(v: str) -> tuple[int, ...]:
+ try:
+ return tuple(int(x) for x in re.split(r"[.\-a+]", v) if x.isdigit())
+ except (ValueError, AttributeError):
+ return (0,)
+ return max(matching, key=_key)
+
+
+def run_pip_install(
+ index_url: str, packages: list[str], break_system_packages: bool = True
+) -> None:
+ """Run pip install with the given packages."""
+ cmd = [sys.executable, "-m", "pip", "install", "--index-url", index_url]
+
+ if break_system_packages:
+ cmd.append("--break-system-packages")
+
+ cmd.extend(packages)
+
+ print(f"Running: {' '.join(cmd)}")
+ result = subprocess.run(cmd, check=False)
+
+ if result.returncode != 0:
+ print(f"Error: pip install failed with return code {result.returncode}")
+ sys.exit(result.returncode)
+
+
+def check_package(name: str) -> tuple[bool, str | None]:
+ """Check if a package is installed and return (installed, version)."""
+ try:
+ module = __import__(name)
+ return True, getattr(module, "__version__", "unknown")
+ except ImportError:
+ return False, None
+
+
+def verify_installation() -> bool:
+ """Verify PyTorch installation and print version info."""
+ print_banner("Verifying Installation")
+
+ # Check torch separately for ROCm info
+ try:
+ import torch as _torch
+
+ version = getattr(_torch, "__version__", "unknown")
+ except ImportError as e:
+ print(f"Error: torch import failed ({e!r}). If wheels are installed, run rocm-sdk init first.")
+ return False
+
+ print(f"torch: {version}")
+
+ hip_version = _torch.version.hip
+ print(f"ROCm/HIP: {hip_version or 'not available'}")
+ print(f"Built with ROCm: {hip_version is not None}")
+
+ # Check other packages
+ for name in ["torchaudio", "torchvision", "triton", "rocm"]:
+ installed, version = check_package(name)
+ status = version if installed else "not installed"
+ print(f"{name}: {status}")
+
+ return True
+
+
+def list_installed_packages() -> None:
+ """List installed torch-related packages."""
+ print("\nInstalled PyTorch packages:")
+ result = subprocess.run(
+ [sys.executable, "-m", "pip", "list"],
+ capture_output=True,
+ text=True,
+ check=False,
+ )
+
+ if result.returncode == 0:
+ keywords = ["torch", "triton", "rocm"]
+ for line in result.stdout.splitlines():
+ if any(kw in line.lower() for kw in keywords):
+ print(f" {line}")
+
+
+def main() -> int:
+ """Main entry point."""
+ parser = argparse.ArgumentParser(
+ description="Install PyTorch wheels from a pip index URL",
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ epilog=__doc__,
+ )
+
+ parser.add_argument(
+ "--index-url", required=True, help="Base URL for PyTorch wheels index"
+ )
+ parser.add_argument(
+ "--amdgpu-family", required=True, help="AMD GPU family (e.g., gfx1250)"
+ )
+ parser.add_argument(
+ "--rocm-version",
+ help="Optional. ROCm version (e.g. 7.12.0a20260126). When set without --torch-version: discovers and installs latest torch/torchaudio/torchvision/triton built for this ROCm. ",
+ )
+ parser.add_argument(
+ "--torch-version", help="Specific torch version (default: latest)"
+ )
+ parser.add_argument(
+ "--torch-version-prefix",
+ help="Torch version prefix for discovery (e.g. '2.9' matches 2.9.x). "
+ "Only used in auto-discovery mode (--rocm-version without --torch-version).",
+ )
+ parser.add_argument(
+ "--torchaudio-version", help="Specific torchaudio version (default: latest)"
+ )
+ parser.add_argument(
+ "--torchvision-version", help="Specific torchvision version (default: latest)"
+ )
+ parser.add_argument(
+ "--triton-version",
+ help="Specific triton version (default: from torch dependency)",
+ )
+ parser.add_argument(
+ "--no-break-system-packages",
+ action="store_true",
+ help="Don't use --break-system-packages",
+ )
+ parser.add_argument(
+ "--skip-verify", action="store_true", help="Skip verification step"
+ )
+
+ args = parser.parse_args()
+
+ # Build the full index URL
+ index_url = f"{args.index_url.rstrip('/')}/{args.amdgpu_family}/"
+
+ rocm = args.rocm_version
+ rocm_only = bool(rocm and not args.torch_version)
+ torch_prefix = args.torch_version_prefix if rocm_only else None
+ break_sys = not args.no_break_system_packages
+
+ if rocm_only:
+ # Two-pass install:
+ # Pass 1: torch (pinned) + rocm[devel] (pinned)
+ # Pass 2: torchaudio, torchvision, triton (unpinned — pip resolves compatibility)
+ torch_version = get_latest_package_version_for_rocm(
+ index_url, "torch", rocm, required=True, version_prefix=torch_prefix,
+ )
+
+ print_banner("PyTorch Wheels Installation")
+ print(f"Index URL: {index_url}")
+ print(f"AMDGPU Family: {args.amdgpu_family}")
+ print(f"Python: {sys.version_info.major}.{sys.version_info.minor}")
+ print(f"torch: {torch_version}")
+ print(f"rocm[devel]: {rocm}")
+ print(f"torchaudio: (pip resolves)")
+ print(f"torchvision: (pip resolves)")
+ print(f"triton: (torch dependency)")
+ print("=" * 50)
+
+ # Pass 1: install torch + rocm[devel] with exact versions.
+ # torch's declared dependency on triton pulls in the correct build.
+ primary = [
+ build_package_spec("torch", torch_version),
+ build_package_spec("rocm[devel]", rocm),
+ ]
+ print_banner("Pass 1: torch + rocm[devel]")
+ print(f"Installing: {', '.join(primary)}")
+ run_pip_install(index_url, primary, break_sys)
+
+ # Pass 2: install torchaudio/torchvision without pinning — pip picks
+ # versions compatible with the torch that's already installed
+ companions = ["torchaudio", "torchvision"]
+ print_banner("Pass 2: torchaudio, torchvision (unpinned)")
+ print(f"Installing: {', '.join(companions)}")
+ run_pip_install(index_url, companions, break_sys)
+ else:
+ # Explicit versions mode — install everything in one shot
+ arg_attrs = ["torch_version", "torchaudio_version", "torchvision_version", "triton_version"]
+ versions = {p: getattr(args, a) for p, a in zip(PYTORCH_PKGS, arg_attrs)}
+ versions["rocm[devel]"] = rocm if rocm else None
+
+ print_banner("PyTorch Wheels Installation")
+ print(f"Index URL: {index_url}")
+ print(f"AMDGPU Family: {args.amdgpu_family}")
+ print(f"Python: {sys.version_info.major}.{sys.version_info.minor}")
+ for name, version in versions.items():
+ print(f"{name:14}: {version or 'latest'}")
+ print("=" * 50)
+
+ packages = []
+ for name, always_install in PACKAGES.items():
+ version = versions.get(name)
+ if always_install or version:
+ packages.append(build_package_spec(name, version))
+
+ print(f"Installing: {', '.join(packages)}")
+ run_pip_install(index_url, packages, break_sys)
+
+ # Verify
+ if not args.skip_verify and not verify_installation():
+ return 1
+
+ list_installed_packages()
+ print_banner("Installation complete")
+ return 0
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/.github/scripts/install_rocm_deps.sh b/.github/scripts/install_rocm_deps.sh
new file mode 100644
index 000000000000..e4c0fd91a106
--- /dev/null
+++ b/.github/scripts/install_rocm_deps.sh
@@ -0,0 +1,114 @@
+#!/bin/bash
+# install_rocm_deps.sh
+#
+# Installs runtime dependencies for ROCm on various Linux distributions.
+# Automatically detects the distribution and uses the appropriate package manager.
+#
+# Supported distributions:
+# - Ubuntu 22.04, 24.04 (apt)
+# - AlmaLinux 8 (dnf)
+# - Azure Linux 3 (tdnf)
+
+set -e
+
+# Detect distribution type from /etc/os-release
+detect_distro() {
+ if [ -f /etc/os-release ]; then
+ . /etc/os-release
+ echo "$ID"
+ else
+ echo "unknown"
+ fi
+}
+
+DISTRO=$(detect_distro)
+echo "Detected distribution: $DISTRO"
+
+case "$DISTRO" in
+ ubuntu)
+ echo "Installing dependencies using apt..."
+ apt-get update
+ apt-get install -y --no-install-recommends \
+ ca-certificates \
+ curl \
+ build-essential \
+ libelf1 \
+ libnuma1 \
+ libunwind8 \
+ libncurses6 \
+ perl \
+ file \
+ nano \
+ git \
+ python3 \
+ python3-dev \
+ python3-pip \
+ python3-venv \
+ kmod \
+ pkg-config \
+ liblzma-dev \
+ libdrm-dev
+ # libdw: libdw1t64 for Ubuntu 24.04+, libdw1 for older versions
+ apt-get install -y --no-install-recommends libdw1t64 2>/dev/null || \
+ apt-get install -y --no-install-recommends libdw1 || true
+ # libssl: libssl3 for Ubuntu 22.04+, libssl1.1 for older versions
+ apt-get install -y --no-install-recommends libssl3 2>/dev/null || \
+ apt-get install -y --no-install-recommends libssl1.1 || true
+ rm -rf /var/lib/apt/lists/*
+ ;;
+
+ almalinux)
+ echo "Installing dependencies using dnf..."
+ # Fix AlmaLinux repo to use direct baseurl instead of mirrorlist
+ if [ -f /etc/yum.repos.d/almalinux.repo ]; then
+ sed -i 's/^mirrorlist=/#mirrorlist=/g' /etc/yum.repos.d/almalinux.repo
+ sed -i 's/^# baseurl=/baseurl=/g' /etc/yum.repos.d/almalinux.repo
+ fi
+ dnf install -y --setopt=install_weak_deps=False \
+ ca-certificates \
+ curl \
+ libatomic \
+ elfutils-libelf \
+ elfutils-libs \
+ numactl-libs \
+ ncurses-libs \
+ openssl-libs \
+ perl \
+ file \
+ python3 \
+ python3-devel \
+ python3-pip \
+ kmod
+ dnf clean all
+ ;;
+
+ azurelinux)
+ echo "Installing dependencies using tdnf..."
+ tdnf install -y \
+ ca-certificates \
+ curl \
+ tar \
+ libatomic \
+ elfutils-libelf \
+ elfutils-libs \
+ numactl-libs \
+ libunwind \
+ ncurses-libs \
+ openssl-libs \
+ perl \
+ file \
+ python3 \
+ python3-devel \
+ python3-pip \
+ kmod
+ tdnf clean all
+ ;;
+
+ *)
+ echo "Error: Unsupported distribution: $DISTRO"
+ echo "Supported distributions: ubuntu, almalinux, azurelinux"
+ exit 1
+ ;;
+esac
+
+echo "Dependencies installed successfully for $DISTRO"
diff --git a/.github/workflows/build_portable_linux_pytorch_dockers.yml b/.github/workflows/build_portable_linux_pytorch_dockers.yml
new file mode 100644
index 000000000000..d5c9a94c3b1a
--- /dev/null
+++ b/.github/workflows/build_portable_linux_pytorch_dockers.yml
@@ -0,0 +1,427 @@
+name: Build Portable Linux PyTorch Dockers
+
+on:
+ schedule:
+ - cron: "0 6 * * *" # daily at 06:00 UTC
+ workflow_dispatch:
+ inputs:
+ pytorch_repo:
+ description: "GitHub repo to clone into the image (e.g. 'pytorch/pytorch' or 'ROCm/pytorch')"
+ type: string
+ default: "pytorch/pytorch"
+ pytorch_branch:
+ description: "Branch to clone. Default 'nightly' matches theRock wheel builds. For releases use ROCm/pytorch with 'release/2.11', 'release/2.10', etc."
+ type: string
+ default: "nightly"
+ python_version:
+ type: choice
+ options:
+ - "3.12"
+ - "3.10"
+ - "3.11"
+ - "3.13"
+ - "3.14"
+ default: "3.12"
+ amdgpu_family:
+ type: choice
+ options:
+ - gfx950-dcgpu
+ - gfx94X-dcgpu
+ - gfx90X-dcgpu
+ - gfx120X-all
+ - gfx110X-all
+ - gfx110X-dgpu
+ - gfx103X-dgpu
+ - gfx101X-dgpu
+ default: gfx94X-dcgpu
+ rocm_version:
+ description: "ROCm version (e.g. '7.13.0a20260413'). Leave empty to auto-discover from the latest available torch wheel."
+ type: string
+ index_url:
+ description: Base URL for PyTorch wheels index
+ type: string
+ default: "https://rocm.nightlies.amd.com/v2-staging"
+
+permissions:
+ contents: read
+
+run-name: >-
+ ${{ github.event_name == 'schedule' && 'Nightly Docker builds' ||
+ format('Build PyTorch Docker ({0}, {1}/{2}, ROCm {3})',
+ inputs.amdgpu_family || 'gfx950-dcgpu',
+ inputs.pytorch_repo || 'pytorch/pytorch',
+ inputs.pytorch_branch || 'nightly',
+ inputs.rocm_version || 'auto') }}
+
+env:
+ REGISTRY: docker.io
+ IMAGE_NAME: rocm/pytorch-private
+ DEFAULT_AMDGPU_FAMILY: gfx950-dcgpu
+ DEFAULT_PYTHON_VERSION: "3.12"
+ DEFAULT_INDEX_URL: "https://rocm.nightlies.amd.com/v2-staging"
+ DEFAULT_BASE_IMAGE: "ubuntu:24.04"
+
+jobs:
+ # ── Nightly matrix build (schedule only) ─────────────────────────────────
+ nightly-matrix:
+ if: github.event_name == 'schedule'
+ strategy:
+ fail-fast: false
+ matrix:
+ include:
+ - pytorch_repo: pytorch/pytorch
+ pytorch_branch: nightly
+ label: nightly
+ - pytorch_repo: ROCm/pytorch
+ pytorch_branch: release/2.11
+ label: "2.11"
+ - pytorch_repo: ROCm/pytorch
+ pytorch_branch: release/2.10
+ label: "2.10"
+ - pytorch_repo: ROCm/pytorch
+ pytorch_branch: release/2.9
+ label: "2.9"
+ name: "Nightly | torch ${{ matrix.label }} | MI355"
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout workflow files
+ uses: actions/checkout@v4
+
+ - name: Checkout PyTorch source
+ uses: actions/checkout@v4
+ with:
+ repository: ${{ matrix.pytorch_repo }}
+ ref: ${{ matrix.pytorch_branch }}
+ path: pytorch-src
+ fetch-depth: 1
+
+ - name: Derive torch version prefix from branch
+ id: prefix
+ run: |
+ BRANCH="${{ matrix.pytorch_branch }}"
+ if [[ "$BRANCH" =~ ^release/([0-9]+\.[0-9]+) ]]; then
+ echo "value=${BASH_REMATCH[1]}" >> $GITHUB_OUTPUT
+ echo "Derived torch prefix: ${BASH_REMATCH[1]}"
+ else
+ echo "value=" >> $GITHUB_OUTPUT
+ echo "No prefix (nightly/main branch)"
+ fi
+
+ - name: Discover ROCm version from index
+ id: discover
+ run: |
+ python3 - "${{ env.DEFAULT_INDEX_URL }}" "${{ env.DEFAULT_AMDGPU_FAMILY }}" "${{ steps.prefix.outputs.value }}" <<'PYEOF'
+ import re, sys, urllib.request, urllib.parse
+
+ index_url, gpu_family = sys.argv[1], sys.argv[2]
+ prefix = sys.argv[3] if len(sys.argv) > 3 else ""
+
+ url = f"{index_url.rstrip('/')}/{gpu_family}/torch/"
+ print(f"Fetching torch index: {url}")
+ html = urllib.request.urlopen(url, timeout=60).read().decode()
+
+ pattern = re.compile(r"torch-(.+?)\.whl", re.IGNORECASE)
+ versions = []
+ for m in pattern.finditer(html):
+ ver = urllib.parse.unquote(m.group(1).split("-")[0])
+ if "+rocm" in ver:
+ versions.append(ver)
+
+ if prefix:
+ versions = [v for v in versions if v.split("+")[0].startswith(prefix)]
+
+ if not versions:
+ print(f"::error::No torch wheels found (prefix={prefix!r})")
+ sys.exit(1)
+
+ def key(v):
+ try:
+ return tuple(int(x) for x in re.split(r"[.\-a+]", v) if x.isdigit())
+ except (ValueError, AttributeError):
+ return (0,)
+
+ latest = max(versions, key=key)
+ rocm_ver = re.search(r"\+rocm(.+)", latest).group(1)
+
+ print(f"Latest torch wheel: {latest}")
+ print(f"Discovered ROCm version: {rocm_ver}")
+
+ import os
+ with open(os.environ["GITHUB_OUTPUT"], "a") as f:
+ f.write(f"rocm_version={rocm_ver}\n")
+ f.write(f"torch_wheel_version={latest}\n")
+ PYEOF
+
+ - name: Resolve config
+ id: cfg
+ run: |
+ echo "amdgpu_family=${{ env.DEFAULT_AMDGPU_FAMILY }}" >> $GITHUB_OUTPUT
+ echo "python_version=${{ env.DEFAULT_PYTHON_VERSION }}" >> $GITHUB_OUTPUT
+ echo "rocm_version=${{ steps.discover.outputs.rocm_version }}" >> $GITHUB_OUTPUT
+ echo "index_url=${{ env.DEFAULT_INDEX_URL }}" >> $GITHUB_OUTPUT
+ echo "base_image=${{ env.DEFAULT_BASE_IMAGE }}" >> $GITHUB_OUTPUT
+ echo "torch_prefix=${{ steps.prefix.outputs.value }}" >> $GITHUB_OUTPUT
+ echo "pytorch_repo=${{ matrix.pytorch_repo }}" >> $GITHUB_OUTPUT
+ echo "pytorch_branch=${{ matrix.pytorch_branch }}" >> $GITHUB_OUTPUT
+
+ COMMIT="$(cd pytorch-src && git rev-parse --short=8 HEAD)"
+ echo "pytorch_commit=${COMMIT}" >> $GITHUB_OUTPUT
+
+ - name: Generate Docker image tag
+ id: docker-tag
+ run: |
+ BRANCH="${{ matrix.pytorch_branch }}"
+ BRANCH_SAFE="${BRANCH//\//-}"
+ COMMIT="${{ steps.cfg.outputs.pytorch_commit }}"
+ ROCM_VERSION="${{ steps.cfg.outputs.rocm_version }}"
+ PYTHON_VERSION="${{ steps.cfg.outputs.python_version }}"
+ GFX="${{ steps.cfg.outputs.amdgpu_family }}"
+ BASE_IMAGE="${{ steps.cfg.outputs.base_image }}"
+ OS=$(echo "${BASE_IMAGE}" | tr -d ':' | tr '/' '-')
+
+ IMAGE_TAG="pytorch-${BRANCH_SAFE}-${COMMIT}-rocm${ROCM_VERSION}-${OS}-py${PYTHON_VERSION}-${GFX}"
+ IMAGE_TAG="${IMAGE_TAG//+/-}"
+ echo "tag=${IMAGE_TAG}" >> $GITHUB_OUTPUT
+ echo "Generated image tag: ${IMAGE_TAG}"
+
+ - name: Log in to Docker Hub
+ uses: docker/login-action@v3
+ with:
+ username: ${{ secrets.DOCKERUSERNAME }}
+ password: ${{ secrets.DOCKERTOKEN }}
+
+ - name: Prepare build context
+ run: |
+ cp dockerfiles/Dockerfile pytorch-src/
+ mkdir -p pytorch-src/.github/scripts
+ cp .github/scripts/install_rocm_deps.sh pytorch-src/.github/scripts/
+ cp .github/scripts/install_pytorch_wheels.py pytorch-src/.github/scripts/
+
+ - name: Build Docker image
+ run: |
+ IMAGE="${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.docker-tag.outputs.tag }}"
+
+ docker build \
+ --file pytorch-src/Dockerfile \
+ --tag "${IMAGE}" \
+ --label "pytorch.repo=${{ matrix.pytorch_repo }}" \
+ --label "pytorch.branch=${{ matrix.pytorch_branch }}" \
+ --label "pytorch.commit=${{ steps.cfg.outputs.pytorch_commit }}" \
+ --build-arg "BASE_IMAGE=${{ steps.cfg.outputs.base_image }}" \
+ --build-arg "ROCM_VERSION=${{ steps.cfg.outputs.rocm_version }}" \
+ --build-arg "AMDGPU_FAMILY=${{ steps.cfg.outputs.amdgpu_family }}" \
+ --build-arg "PYTHON_VERSION=${{ steps.cfg.outputs.python_version }}" \
+ --build-arg "INDEX_URL=${{ steps.cfg.outputs.index_url }}" \
+ --build-arg "TORCH_VERSION_PREFIX=${{ steps.prefix.outputs.value }}" \
+ pytorch-src
+
+ echo "Docker image built successfully: ${IMAGE}"
+
+ - name: Get ROCm packages info
+ id: rocm-packages
+ run: |
+ IMAGE="${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.docker-tag.outputs.tag }}"
+ ROCM_PACKAGES=$(docker run --rm "${IMAGE}" pip freeze | grep -i rocm || echo "No ROCm packages found")
+ echo "rocm_packages<> $GITHUB_OUTPUT
+ echo "${ROCM_PACKAGES}" >> $GITHUB_OUTPUT
+ echo "EOF" >> $GITHUB_OUTPUT
+ echo "ROCm packages:"
+ echo "${ROCM_PACKAGES}"
+
+ - name: Push Docker image
+ run: |
+ docker push ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.docker-tag.outputs.tag }}
+ echo "Docker image pushed successfully"
+
+ - name: Post-build summary
+ run: |
+ IMAGE="${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.docker-tag.outputs.tag }}"
+ echo "## PyTorch Docker Build Summary — ${{ matrix.label }}" >> $GITHUB_STEP_SUMMARY
+ echo "" >> $GITHUB_STEP_SUMMARY
+ echo "| Parameter | Value |" >> $GITHUB_STEP_SUMMARY
+ echo "|-----------|-------|" >> $GITHUB_STEP_SUMMARY
+ echo "| Image | \`${IMAGE}\` |" >> $GITHUB_STEP_SUMMARY
+ echo "| Torch Wheel | ${{ steps.discover.outputs.torch_wheel_version }} |" >> $GITHUB_STEP_SUMMARY
+ echo "| PyTorch Repo | ${{ matrix.pytorch_repo }} |" >> $GITHUB_STEP_SUMMARY
+ echo "| PyTorch Branch | ${{ matrix.pytorch_branch }} |" >> $GITHUB_STEP_SUMMARY
+ echo "| PyTorch Commit | ${{ steps.cfg.outputs.pytorch_commit }} |" >> $GITHUB_STEP_SUMMARY
+ echo "| AMDGPU Family | ${{ steps.cfg.outputs.amdgpu_family }} |" >> $GITHUB_STEP_SUMMARY
+ echo "| Python | ${{ steps.cfg.outputs.python_version }} |" >> $GITHUB_STEP_SUMMARY
+ echo "| ROCm (discovered) | ${{ steps.cfg.outputs.rocm_version }} |" >> $GITHUB_STEP_SUMMARY
+
+ # ── Single image build (manual dispatch) ──────────────────────────────────
+ build-docker:
+ if: github.event_name == 'workflow_dispatch'
+ name: "Build | ${{ inputs.amdgpu_family }} | ${{ inputs.pytorch_repo || 'pytorch/pytorch' }}@${{ inputs.pytorch_branch || 'nightly' }}"
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout workflow files
+ uses: actions/checkout@v4
+
+ - name: Checkout PyTorch source
+ uses: actions/checkout@v4
+ with:
+ repository: ${{ inputs.pytorch_repo || 'pytorch/pytorch' }}
+ ref: ${{ inputs.pytorch_branch || 'nightly' }}
+ path: pytorch-src
+ fetch-depth: 1
+
+ - name: Derive torch version prefix from branch
+ id: prefix
+ run: |
+ BRANCH="${{ inputs.pytorch_branch || 'nightly' }}"
+ if [[ "$BRANCH" =~ ^release/([0-9]+\.[0-9]+) ]]; then
+ echo "value=${BASH_REMATCH[1]}" >> $GITHUB_OUTPUT
+ echo "Derived torch prefix: ${BASH_REMATCH[1]}"
+ else
+ echo "value=" >> $GITHUB_OUTPUT
+ echo "No prefix (nightly/main branch)"
+ fi
+
+ - name: Discover ROCm version from index
+ id: discover
+ if: ${{ !inputs.rocm_version }}
+ run: |
+ python3 - "${{ inputs.index_url || env.DEFAULT_INDEX_URL }}" "${{ inputs.amdgpu_family || env.DEFAULT_AMDGPU_FAMILY }}" "${{ steps.prefix.outputs.value }}" <<'PYEOF'
+ import re, sys, urllib.request, urllib.parse
+
+ index_url, gpu_family = sys.argv[1], sys.argv[2]
+ prefix = sys.argv[3] if len(sys.argv) > 3 else ""
+
+ url = f"{index_url.rstrip('/')}/{gpu_family}/torch/"
+ print(f"Fetching torch index: {url}")
+ html = urllib.request.urlopen(url, timeout=60).read().decode()
+
+ pattern = re.compile(r"torch-(.+?)\.whl", re.IGNORECASE)
+ versions = []
+ for m in pattern.finditer(html):
+ ver = urllib.parse.unquote(m.group(1).split("-")[0])
+ if "+rocm" in ver:
+ versions.append(ver)
+
+ if prefix:
+ versions = [v for v in versions if v.split("+")[0].startswith(prefix)]
+
+ if not versions:
+ print(f"::error::No torch wheels found (prefix={prefix!r})")
+ sys.exit(1)
+
+ def key(v):
+ try:
+ return tuple(int(x) for x in re.split(r"[.\-a+]", v) if x.isdigit())
+ except (ValueError, AttributeError):
+ return (0,)
+
+ latest = max(versions, key=key)
+ rocm_ver = re.search(r"\+rocm(.+)", latest).group(1)
+
+ print(f"Latest torch wheel: {latest}")
+ print(f"Discovered ROCm version: {rocm_ver}")
+
+ import os
+ with open(os.environ["GITHUB_OUTPUT"], "a") as f:
+ f.write(f"rocm_version={rocm_ver}\n")
+ f.write(f"torch_wheel_version={latest}\n")
+ PYEOF
+
+ - name: Resolve inputs with defaults
+ id: cfg
+ run: |
+ echo "amdgpu_family=${{ inputs.amdgpu_family || env.DEFAULT_AMDGPU_FAMILY }}" >> $GITHUB_OUTPUT
+ echo "python_version=${{ inputs.python_version || env.DEFAULT_PYTHON_VERSION }}" >> $GITHUB_OUTPUT
+
+ # Use explicit rocm_version if provided, otherwise use discovered version
+ ROCM="${{ inputs.rocm_version || steps.discover.outputs.rocm_version }}"
+ echo "rocm_version=${ROCM}" >> $GITHUB_OUTPUT
+
+ echo "index_url=${{ inputs.index_url || env.DEFAULT_INDEX_URL }}" >> $GITHUB_OUTPUT
+ echo "base_image=${{ env.DEFAULT_BASE_IMAGE }}" >> $GITHUB_OUTPUT
+ echo "torch_prefix=${{ steps.prefix.outputs.value }}" >> $GITHUB_OUTPUT
+ echo "pytorch_repo=${{ inputs.pytorch_repo || 'pytorch/pytorch' }}" >> $GITHUB_OUTPUT
+ echo "pytorch_branch=${{ inputs.pytorch_branch || 'nightly' }}" >> $GITHUB_OUTPUT
+
+ COMMIT="$(cd pytorch-src && git rev-parse --short=8 HEAD)"
+ echo "pytorch_commit=${COMMIT}" >> $GITHUB_OUTPUT
+
+ - name: Generate Docker image tag
+ id: docker-tag
+ run: |
+ BRANCH="${{ steps.cfg.outputs.pytorch_branch }}"
+ BRANCH_SAFE="${BRANCH//\//-}"
+ COMMIT="${{ steps.cfg.outputs.pytorch_commit }}"
+ ROCM_VERSION="${{ steps.cfg.outputs.rocm_version }}"
+ PYTHON_VERSION="${{ steps.cfg.outputs.python_version }}"
+ GFX="${{ steps.cfg.outputs.amdgpu_family }}"
+ BASE_IMAGE="${{ steps.cfg.outputs.base_image }}"
+ OS=$(echo "${BASE_IMAGE}" | tr -d ':' | tr '/' '-')
+
+ IMAGE_TAG="pytorch-${BRANCH_SAFE}-${COMMIT}-rocm${ROCM_VERSION}-${OS}-py${PYTHON_VERSION}-${GFX}"
+ IMAGE_TAG="${IMAGE_TAG//+/-}"
+ echo "tag=${IMAGE_TAG}" >> $GITHUB_OUTPUT
+ echo "Generated image tag: ${IMAGE_TAG}"
+
+ - name: Log in to Docker Hub
+ uses: docker/login-action@v3
+ with:
+ username: ${{ secrets.DOCKERUSERNAME }}
+ password: ${{ secrets.DOCKERTOKEN }}
+
+ - name: Prepare build context
+ run: |
+ cp dockerfiles/Dockerfile pytorch-src/
+ mkdir -p pytorch-src/.github/scripts
+ cp .github/scripts/install_rocm_deps.sh pytorch-src/.github/scripts/
+ cp .github/scripts/install_pytorch_wheels.py pytorch-src/.github/scripts/
+
+ - name: Build Docker image
+ run: |
+ IMAGE="${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.docker-tag.outputs.tag }}"
+
+ docker build \
+ --file pytorch-src/Dockerfile \
+ --tag "${IMAGE}" \
+ --label "pytorch.repo=${{ steps.cfg.outputs.pytorch_repo }}" \
+ --label "pytorch.branch=${{ steps.cfg.outputs.pytorch_branch }}" \
+ --label "pytorch.commit=${{ steps.cfg.outputs.pytorch_commit }}" \
+ --build-arg "BASE_IMAGE=${{ steps.cfg.outputs.base_image }}" \
+ --build-arg "ROCM_VERSION=${{ steps.cfg.outputs.rocm_version }}" \
+ --build-arg "AMDGPU_FAMILY=${{ steps.cfg.outputs.amdgpu_family }}" \
+ --build-arg "PYTHON_VERSION=${{ steps.cfg.outputs.python_version }}" \
+ --build-arg "INDEX_URL=${{ steps.cfg.outputs.index_url }}" \
+ --build-arg "TORCH_VERSION_PREFIX=${{ steps.cfg.outputs.torch_prefix }}" \
+ pytorch-src
+
+ echo "Docker image built successfully: ${IMAGE}"
+
+ - name: Get ROCm packages info
+ id: rocm-packages
+ run: |
+ IMAGE="${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.docker-tag.outputs.tag }}"
+ ROCM_PACKAGES=$(docker run --rm "${IMAGE}" pip freeze | grep -i rocm || echo "No ROCm packages found")
+ echo "rocm_packages<> $GITHUB_OUTPUT
+ echo "${ROCM_PACKAGES}" >> $GITHUB_OUTPUT
+ echo "EOF" >> $GITHUB_OUTPUT
+ echo "ROCm packages:"
+ echo "${ROCM_PACKAGES}"
+
+ - name: Push Docker image
+ run: |
+ docker push ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.docker-tag.outputs.tag }}
+ echo "Docker image pushed successfully"
+
+ - name: Post-build summary
+ run: |
+ IMAGE="${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.docker-tag.outputs.tag }}"
+ echo "## PyTorch Docker Build Summary" >> $GITHUB_STEP_SUMMARY
+ echo "" >> $GITHUB_STEP_SUMMARY
+ echo "| Parameter | Value |" >> $GITHUB_STEP_SUMMARY
+ echo "|-----------|-------|" >> $GITHUB_STEP_SUMMARY
+ echo "| Image | \`${IMAGE}\` |" >> $GITHUB_STEP_SUMMARY
+ echo "| PyTorch Repo | ${{ steps.cfg.outputs.pytorch_repo }} |" >> $GITHUB_STEP_SUMMARY
+ echo "| PyTorch Branch | ${{ steps.cfg.outputs.pytorch_branch }} |" >> $GITHUB_STEP_SUMMARY
+ echo "| PyTorch Commit | ${{ steps.cfg.outputs.pytorch_commit }} |" >> $GITHUB_STEP_SUMMARY
+ echo "| AMDGPU Family | ${{ steps.cfg.outputs.amdgpu_family }} |" >> $GITHUB_STEP_SUMMARY
+ echo "| Python | ${{ steps.cfg.outputs.python_version }} |" >> $GITHUB_STEP_SUMMARY
+ echo "| ROCm | ${{ steps.cfg.outputs.rocm_version }} |" >> $GITHUB_STEP_SUMMARY
+ echo "| Torch Version Prefix | ${{ steps.cfg.outputs.torch_prefix || 'latest' }} |" >> $GITHUB_STEP_SUMMARY
+ echo "| Index URL | ${{ steps.cfg.outputs.index_url }} |" >> $GITHUB_STEP_SUMMARY
diff --git a/.github/workflows/create_ifu_issues.yml b/.github/workflows/create_ifu_issues.yml
new file mode 100644
index 000000000000..8e2e7da07ab4
--- /dev/null
+++ b/.github/workflows/create_ifu_issues.yml
@@ -0,0 +1,352 @@
+name: Create issues for ROCm commits
+
+on:
+ # Manual trigger for testing
+ workflow_dispatch:
+ inputs:
+ prev_post_tag:
+ description: "Issue range start ref (previous IFU post tag or cold-start SHA)"
+ required: true
+ type: string
+ curr_pre_tag:
+ description: "Current IFU pre tag"
+ required: true
+ type: string
+ target_repo:
+ description: "Target repo for issue creation"
+ required: false
+ default: "ROCm/pytorch"
+ type: string
+ project_number:
+ description: "GitHub Project number"
+ required: false
+ default: "114"
+ type: string
+ project_owner:
+ description: "Project owner"
+ required: false
+ default: "ROCm"
+ type: string
+
+ # Called by create_ifu_tag.yml after tagging
+ workflow_call:
+ inputs:
+ prev_post_tag:
+ description: "Issue range start ref (previous IFU post tag or cold-start SHA)"
+ required: true
+ type: string
+ curr_pre_tag:
+ description: "Current IFU pre tag"
+ required: true
+ type: string
+ target_repo:
+ description: "Target repo for issue creation"
+ required: false
+ default: "ROCm/pytorch"
+ type: string
+ project_number:
+ description: "GitHub Project number"
+ required: false
+ default: "114"
+ type: string
+ project_owner:
+ description: "Project owner"
+ required: false
+ default: "ROCm"
+ type: string
+ secrets:
+ IFU_GITHUB_TOKEN:
+ required: true
+
+permissions:
+ contents: read
+ issues: write
+
+jobs:
+ create-issues:
+ runs-on: ubuntu-latest
+ env:
+ # Use passed secret for workflow_call, direct secret for workflow_dispatch
+ GH_TOKEN: ${{ secrets.IFU_GITHUB_TOKEN }}
+
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+
+ - name: Fetch tags
+ run: git fetch origin --tags --force
+
+ - name: Extract branch from tag
+ id: parse
+ env:
+ CURR_PRE_TAG: ${{ inputs.curr_pre_tag }}
+ run: |
+ branch="${CURR_PRE_TAG%_IFU_*}"
+ echo "Branch: $branch"
+ echo "branch=$branch" >> $GITHUB_OUTPUT
+
+ - name: Fetch upstream
+ run: |
+ git remote add upstream https://github.com/pytorch/pytorch.git 2>/dev/null || true
+ git fetch upstream main --force
+
+ - name: List commits in range
+ run: |
+ echo "ROCm-only commits between ${{ inputs.prev_post_tag }} and ${{ inputs.curr_pre_tag }}:"
+ git log ${{ inputs.prev_post_tag }}..${{ inputs.curr_pre_tag }} --oneline --no-merges --not upstream/main
+
+ - name: Get or create project fields
+ id: project_fields
+ if: ${{ inputs.project_number != '' }}
+ env:
+ PROJECT_NUMBER: ${{ inputs.project_number }}
+ PROJECT_OWNER: ${{ inputs.project_owner }}
+ run: |
+ echo "Getting project information..."
+
+ # Try user-owned project first.
+ project_data=$(gh api graphql -f query='
+ query($owner: String!, $number: Int!) {
+ user(login: $owner) {
+ projectV2(number: $number) {
+ id
+ fields(first: 50) {
+ nodes {
+ ... on ProjectV2Field {
+ id
+ name
+ dataType
+ }
+ ... on ProjectV2SingleSelectField {
+ id
+ name
+ dataType
+ }
+ }
+ }
+ }
+ }
+ }' -f owner="${PROJECT_OWNER}" -F number="${PROJECT_NUMBER}" 2>/dev/null || true)
+
+ project_id=$(echo "$project_data" | jq -r '.data.user.projectV2.id // empty' 2>/dev/null || true)
+ echo "User project ID: ${project_id:-'(none)'}"
+
+ if [[ -z "$project_id" ]]; then
+ echo "User project not found (or owner is an org). Trying organization query..."
+ project_data=$(gh api graphql -f query='
+ query($owner: String!, $number: Int!) {
+ organization(login: $owner) {
+ projectV2(number: $number) {
+ id
+ fields(first: 50) {
+ nodes {
+ ... on ProjectV2Field {
+ id
+ name
+ dataType
+ }
+ ... on ProjectV2SingleSelectField {
+ id
+ name
+ dataType
+ }
+ }
+ }
+ }
+ }
+ }' -f owner="${PROJECT_OWNER}" -F number="${PROJECT_NUMBER}" 2>/dev/null || true)
+
+ project_id=$(echo "$project_data" | jq -r '.data.organization.projectV2.id // empty' 2>/dev/null || true)
+ fields_json=$(echo "$project_data" | jq -r '.data.organization.projectV2.fields.nodes // empty' 2>/dev/null || true)
+ else
+ fields_json=$(echo "$project_data" | jq -r '.data.user.projectV2.fields.nodes // empty' 2>/dev/null || true)
+ fi
+
+ if [[ -z "$project_id" || -z "$fields_json" ]]; then
+ echo "Error: Could not resolve project owner '${PROJECT_OWNER}' project #${PROJECT_NUMBER}."
+ echo "If PROJECT_OWNER is an organization, ensure PROJECT_OWNER is exactly the org login and token has org access."
+ exit 1
+ fi
+
+ echo "Project ID: $project_id"
+ echo "project_id=$project_id" >> $GITHUB_OUTPUT
+
+ # Find or create 'branch' field
+ branch_field_id=$(echo "$fields_json" | jq -r '.[] | select(.name == "branch") | .id')
+ if [[ -z "$branch_field_id" || "$branch_field_id" == "null" ]]; then
+ echo "Creating 'branch' field..."
+ branch_field_id=$(gh api graphql -f query='
+ mutation($projectId: ID!, $name: String!) {
+ createProjectV2Field(input: {projectId: $projectId, dataType: TEXT, name: $name}) {
+ projectV2Field {
+ ... on ProjectV2Field {
+ id
+ }
+ }
+ }
+ }' -f projectId="$project_id" -f name="branch" --jq '.data.createProjectV2Field.projectV2Field.id')
+ echo "Created 'branch' field: $branch_field_id"
+ else
+ echo "Found existing 'branch' field: $branch_field_id"
+ fi
+ echo "branch_field_id=$branch_field_id" >> $GITHUB_OUTPUT
+
+ # Find or create 'commit_hash' field
+ commit_hash_field_id=$(echo "$fields_json" | jq -r '.[] | select(.name == "commit_hash") | .id')
+ if [[ -z "$commit_hash_field_id" || "$commit_hash_field_id" == "null" ]]; then
+ echo "Creating 'commit_hash' field..."
+ commit_hash_field_id=$(gh api graphql -f query='
+ mutation($projectId: ID!, $name: String!) {
+ createProjectV2Field(input: {projectId: $projectId, dataType: TEXT, name: $name}) {
+ projectV2Field {
+ ... on ProjectV2Field {
+ id
+ }
+ }
+ }
+ }' -f projectId="$project_id" -f name="commit_hash" --jq '.data.createProjectV2Field.projectV2Field.id')
+ echo "Created 'commit_hash' field: $commit_hash_field_id"
+ else
+ echo "Found existing 'commit_hash' field: $commit_hash_field_id"
+ fi
+ echo "commit_hash_field_id=$commit_hash_field_id" >> $GITHUB_OUTPUT
+
+ - name: Create issues for commits
+ env:
+ PREV_POST_TAG: ${{ inputs.prev_post_tag }}
+ CURR_PRE_TAG: ${{ inputs.curr_pre_tag }}
+ TARGET_REPO: ${{ inputs.target_repo }}
+ PROJECT_NUMBER: ${{ inputs.project_number }}
+ PROJECT_OWNER: ${{ inputs.project_owner }}
+ REPO_NAME: ${{ github.repository }}
+ BRANCH: ${{ steps.parse.outputs.branch }}
+ PROJECT_ID: ${{ steps.project_fields.outputs.project_id }}
+ BRANCH_FIELD_ID: ${{ steps.project_fields.outputs.branch_field_id }}
+ COMMIT_HASH_FIELD_ID: ${{ steps.project_fields.outputs.commit_hash_field_id }}
+ run: |
+ echo "Creating issues for commits..."
+
+ commit_count=$(git rev-list --count --no-merges "${PREV_POST_TAG}..${CURR_PRE_TAG}" --not upstream/main)
+ if [[ "${commit_count}" -eq 0 ]]; then
+ echo "No ROCm-only commits in range ${PREV_POST_TAG}..${CURR_PRE_TAG}; nothing to create."
+ exit 0
+ fi
+
+ echo "Found ${commit_count} ROCm-only commits to process."
+
+ git log "${PREV_POST_TAG}..${CURR_PRE_TAG}" --format="%H" --no-merges --not upstream/main | while read hash; do
+ short_hash="${hash:0:5}"
+ subject=$(git log -1 --format="%s" "$hash")
+ author=$(git log -1 --format="%an" "$hash")
+ email=$(git log -1 --format="%ae" "$hash")
+
+ echo "Processing ${short_hash}: ${subject}"
+
+ # Try to get GitHub username via API first
+ gh_username=""
+ gh_username=$(gh api "repos/${REPO_NAME}/commits/${hash}" --jq '.author.login // empty' 2>/dev/null || true)
+
+ if [[ -z "${gh_username}" ]]; then
+ # Fallback: try to extract from noreply email
+ if [[ "$email" =~ ^[0-9]+\+([^@]+)@users\.noreply\.github\.com$ ]]; then
+ gh_username="${BASH_REMATCH[1]}"
+ echo " Extracted username from email: ${gh_username}"
+ fi
+ else
+ echo " Found GitHub username via API: ${gh_username}"
+ fi
+
+ # Dedupe by commit hash marker in issue body across all issue states.
+ existing_issue_url=$(gh issue list \
+ --repo "${TARGET_REPO}" \
+ --state all \
+ --search "\"${hash}\" in:body" \
+ --limit 20 \
+ --json url,body \
+ | jq -r --arg hash "$hash" '.[] | select((.body // "") | contains("**Commit:** " + $hash)) | .url' \
+ | head -n 1 || true)
+ if [[ -n "${existing_issue_url}" ]]; then
+ echo " Existing issue found for commit ${short_hash}: ${existing_issue_url}"
+ echo " Skipping duplicate issue creation."
+ continue
+ fi
+
+ body="**Commit:** ${hash}"$'\n'"**Author:** ${author} (${email})"$'\n'"**Branch:** ${BRANCH}"$'\n'"**Link:** [View commit](https://github.com/${REPO_NAME}/commit/${hash})"
+
+ issue_url=$(gh issue create \
+ --repo "${TARGET_REPO}" \
+ --title "${subject}" \
+ --body "${body}" 2>/dev/null || true)
+
+ if [[ -z "${issue_url}" ]]; then
+ echo " ERROR: Failed to create issue for ${short_hash}. Skipping."
+ continue
+ fi
+
+ echo " Created: ${issue_url}"
+
+ # Try to assign the issue
+ if [[ -n "${gh_username}" ]]; then
+ echo " Trying to assign to @${gh_username}..."
+ if gh issue edit "${issue_url}" --add-assignee "${gh_username}" 2>/dev/null; then
+ echo " Successfully assigned issue"
+ else
+ echo " Could not assign, adding comment instead"
+ gh issue comment "${issue_url}" --body "cc @${gh_username} - you authored this commit" || true
+ fi
+ fi
+
+ # Add to project and set field values
+ if [[ -n "${PROJECT_NUMBER}" && -n "${PROJECT_ID}" ]]; then
+ echo " Adding to project..."
+ item_id=$(gh project item-add "${PROJECT_NUMBER}" --owner "${PROJECT_OWNER}" --url "${issue_url}" --format json 2>/dev/null | jq -r '.id' || true)
+
+ if [[ -n "${item_id}" && "${item_id}" != "null" ]]; then
+ echo " Project item ID: ${item_id}"
+
+ # Set branch field
+ if [[ -n "${BRANCH_FIELD_ID}" ]]; then
+ echo " Setting branch field to: ${BRANCH}"
+ gh api graphql -f query='
+ mutation($projectId: ID!, $itemId: ID!, $fieldId: ID!, $value: String!) {
+ updateProjectV2ItemFieldValue(input: {
+ projectId: $projectId
+ itemId: $itemId
+ fieldId: $fieldId
+ value: {text: $value}
+ }) {
+ projectV2Item {
+ id
+ }
+ }
+ }' -f projectId="${PROJECT_ID}" -f itemId="${item_id}" -f fieldId="${BRANCH_FIELD_ID}" -f value="${BRANCH}" || echo " Warning: Failed to set branch field"
+ fi
+
+ # Set commit_hash field
+ if [[ -n "${COMMIT_HASH_FIELD_ID}" ]]; then
+ echo " Setting commit_hash field to: ${hash}"
+ gh api graphql -f query='
+ mutation($projectId: ID!, $itemId: ID!, $fieldId: ID!, $value: String!) {
+ updateProjectV2ItemFieldValue(input: {
+ projectId: $projectId
+ itemId: $itemId
+ fieldId: $fieldId
+ value: {text: $value}
+ }) {
+ projectV2Item {
+ id
+ }
+ }
+ }' -f projectId="${PROJECT_ID}" -f itemId="${item_id}" -f fieldId="${COMMIT_HASH_FIELD_ID}" -f value="${hash}" || echo " Warning: Failed to set commit_hash field"
+ fi
+ else
+ echo " Warning: Could not get project item ID"
+ fi
+ fi
+
+ sleep 1
+ done
+
+ echo "Done creating issues!"
diff --git a/.github/workflows/create_ifu_tag.yml b/.github/workflows/create_ifu_tag.yml
new file mode 100644
index 000000000000..7dc766cd06b0
--- /dev/null
+++ b/.github/workflows/create_ifu_tag.yml
@@ -0,0 +1,352 @@
+name: Create git tags for IFU PRs
+
+on:
+ # ORIGINAL: Triggered when an IFU PR is merged
+ pull_request:
+ types: [closed]
+
+ # Test harness - manually trigger to test without a real PR merge
+ workflow_dispatch:
+ inputs:
+ test_branch:
+ description: "Branch name to test (e.g., rocm7.1_internal_testing)"
+ required: true
+ type: string
+ test_curr_pre_tag:
+ description: "Pre tag to use as curr_pre_tag (required for full chain test)"
+ required: false
+ type: string
+ test_issue_prev_ref:
+ description: "Optional issue range start ref for cold-start full-chain test (tag or SHA)"
+ required: false
+ type: string
+ run_full_chain:
+ description: "Run full chain - actually call create_ifu_issues.yml (will create real issues!)"
+ required: false
+ default: false
+ type: boolean
+ pr_num:
+ description: "Merged IFU PR number — runs full pipeline (tags, PR body, create_issues) as if that PR just merged"
+ required: false
+ default: 0
+ type: number
+
+permissions:
+ contents: write # create/push tags
+ pull-requests: write # edit PR body
+ issues: write # needed for create_ifu_issues.yml when called
+
+jobs:
+ tag-ifu:
+ # Run for workflow_dispatch (test mode) OR for real PR merges
+ if: >
+ github.event_name == 'workflow_dispatch' ||
+ (github.event.pull_request.merged == true &&
+ contains(github.event.pull_request.title, '[AUTOGENERATED]') &&
+ contains(github.event.pull_request.title, 'IFU'))
+ runs-on: ubuntu-latest
+
+ # Export values so the create-issues job can use them
+ outputs:
+ prev_post_tag: ${{ steps.prev_tag.outputs.prev_post_tag }}
+ curr_pre_tag: ${{ (github.event_name == 'workflow_dispatch' && inputs.pr_num == 0 && inputs.test_curr_pre_tag) || steps.tagname.outputs.PRE_TAG }}
+ has_prev_tag: ${{ steps.prev_tag.outputs.has_prev_tag }}
+ issue_prev_ref: ${{ steps.prev_ref.outputs.issue_prev_ref }}
+ can_create_issues: ${{ steps.prev_ref.outputs.can_create_issues }}
+
+ steps:
+ - name: Validate test inputs
+ if: github.event_name == 'workflow_dispatch' && inputs.run_full_chain == true
+ run: |
+ if [[ -z "${{ inputs.test_curr_pre_tag }}" ]]; then
+ echo "ERROR: test_curr_pre_tag is required when run_full_chain is enabled"
+ echo "Please provide an existing pre tag (e.g., rocm7.1_internal_testing_IFU_2025-10-29_pre)"
+ exit 1
+ fi
+ echo "Full chain test enabled with:"
+ echo " test_branch: ${{ inputs.test_branch }}"
+ echo " test_curr_pre_tag: ${{ inputs.test_curr_pre_tag }}"
+ if [[ -n "${{ inputs.test_issue_prev_ref }}" ]]; then
+ echo " test_issue_prev_ref: ${{ inputs.test_issue_prev_ref }}"
+ fi
+
+ # When dispatch + pr_num: fetch PR via API so we have base.ref, head.sha, merge_commit_sha, title (no event.pull_request in dispatch).
+ - name: Get PR details
+ id: get_pr
+ if: github.event_name == 'workflow_dispatch' && inputs.pr_num != 0
+ env:
+ GH_TOKEN: ${{ secrets.IFU_GITHUB_TOKEN }}
+ run: |
+ set -euo pipefail
+ PR_JSON=$(gh api "repos/${{ github.repository }}/pulls/${{ inputs.pr_num }}")
+ MERGE_SHA=$(echo "$PR_JSON" | jq -r .merge_commit_sha)
+ if [[ "$MERGE_SHA" == "null" || -z "$MERGE_SHA" ]]; then
+ echo "ERROR: PR #${{ inputs.pr_num }} is not merged yet. Use a merged IFU PR number."
+ exit 1
+ fi
+ echo "base_ref=$(echo "$PR_JSON" | jq -r .base.ref)" >> "$GITHUB_OUTPUT"
+ echo "head_sha=$(echo "$PR_JSON" | jq -r .head.sha)" >> "$GITHUB_OUTPUT"
+ echo "merge_sha=$MERGE_SHA" >> "$GITHUB_OUTPUT"
+ echo "title=$(echo "$PR_JSON" | jq -r .title)" >> "$GITHUB_OUTPUT"
+ echo "pr_num=${{ inputs.pr_num }}" >> "$GITHUB_OUTPUT"
+ echo "Fetched PR #${{ inputs.pr_num }}: base=$(echo "$PR_JSON" | jq -r .base.ref), merge_sha=$MERGE_SHA"
+
+ - name: Checkout base repo (full history)
+ uses: actions/checkout@v4
+ with:
+ # Worflow_dispatch
+ # pr_num != 0 -> use pr details from json which we got in get_pr step
+ # pr_num == 0 -> use current branch
+ # PR merge -> use base.ref
+ ref: ${{ (github.event_name == 'workflow_dispatch' && inputs.pr_num != 0 && steps.get_pr.outputs.base_ref) || (github.event_name == 'workflow_dispatch' && github.ref) || github.event.pull_request.base.ref }}
+ fetch-depth: 0
+ token: ${{ secrets.IFU_GITHUB_TOKEN }}
+
+ # Fetch all tags so we can find the previous post tag
+ - name: Fetch all tags
+ run: git fetch origin --tags --force
+
+ - name: Configure Git user
+ run: |
+ git config user.name "github-actions[bot]"
+ git config user.email "github-actions[bot]@users.noreply.github.com"
+
+ - name: Derive key SHAs (rocm base, upstream main, merge)
+ id: shas
+ if: (github.event_name == 'workflow_dispatch' && inputs.pr_num != 0) || (github.event_name != 'workflow_dispatch')
+ env:
+ PR_NUM: ${{ steps.get_pr.outputs.pr_num || github.event.pull_request.number }}
+ BASE_REF: ${{ steps.get_pr.outputs.base_ref || github.event.pull_request.base.ref }}
+ HEAD_SHA: ${{ steps.get_pr.outputs.head_sha || github.event.pull_request.head.sha }}
+ MERGE_SHA: ${{ steps.get_pr.outputs.merge_sha || github.event.pull_request.merge_commit_sha }}
+ shell: bash
+ run: |
+ set -euo pipefail
+
+ # Upstream ref is usually the same as base branch. For rocm/pytorch's
+ # develop branch, compare against upstream/main.
+ UPSTREAM_REF="$BASE_REF"
+ if [ "$UPSTREAM_REF" == "develop" ]; then
+ UPSTREAM_REF="main"
+ fi
+
+ echo "PR_NUM=$PR_NUM"
+ echo "BASE_REF=$BASE_REF"
+ echo "UPSTREAM_REF=$UPSTREAM_REF"
+ echo "HEAD_SHA=$HEAD_SHA"
+ echo "MERGE_SHA=$MERGE_SHA"
+
+ # The ROCm base commit is the first parent of the merge commit that landed the PR
+ # (i.e., the base branch tip BEFORE this PR merged).
+ ROCM_BASE_SHA=$(git rev-parse "${MERGE_SHA}^1")
+
+ # Add upstream if missing.
+ if ! git remote get-url upstream >/dev/null 2>&1; then
+ git remote add upstream "https://github.com/pytorch/pytorch.git"
+ fi
+
+ # Some IFU base branches may not exist in upstream (e.g., fork-only/test branches).
+ # In that case, fall back to upstream/main.
+ if ! git ls-remote --exit-code --heads upstream "$UPSTREAM_REF" >/dev/null 2>&1; then
+ echo "Upstream branch '$UPSTREAM_REF' not found; falling back to upstream/main"
+ UPSTREAM_REF="main"
+ fi
+ git fetch upstream "$UPSTREAM_REF"
+
+ # Heuristic: the upstream commit integrated by the PR's head is the merge-base
+ # between the PR head commit and upstream/main as fetched now.
+ # This gives you the exact upstream commit (or the best common ancestor) that HEAD included.
+ UPSTREAM_MAIN_SHA=$(git merge-base "${HEAD_SHA}" "upstream/$UPSTREAM_REF")
+ echo "ROCM_BASE_SHA=$ROCM_BASE_SHA"
+ echo "UPSTREAM_MAIN_SHA=$UPSTREAM_MAIN_SHA"
+ echo "UPSTREAM_REF_USED=$UPSTREAM_REF"
+
+
+ echo "PR_NUM=$PR_NUM" >> "$GITHUB_OUTPUT"
+ echo "BASE_REF=$BASE_REF" >> "$GITHUB_OUTPUT"
+ echo "UPSTREAM_REF_USED=$UPSTREAM_REF" >> "$GITHUB_OUTPUT"
+ echo "HEAD_SHA=$HEAD_SHA" >> "$GITHUB_OUTPUT"
+ echo "MERGE_SHA=$MERGE_SHA" >> "$GITHUB_OUTPUT"
+ echo "ROCM_BASE_SHA=$ROCM_BASE_SHA" >> "$GITHUB_OUTPUT"
+ echo "UPSTREAM_MAIN_SHA=$UPSTREAM_MAIN_SHA" >> "$GITHUB_OUTPUT"
+
+ - name: Extract tag base from PR title
+ id: tagname
+ if: (github.event_name == 'workflow_dispatch' && inputs.pr_num != 0) || (github.event_name != 'workflow_dispatch')
+ env:
+ TITLE: ${{ steps.get_pr.outputs.title || github.event.pull_request.title }}
+ run: |
+ # Remove everything up to and including "[AUTOGENERATED]"
+ # Remove trailing whitespace
+ BASE_TAG=$(echo "$TITLE" | sed -E 's/^\[AUTOGENERATED\][[:space:]]*//' | sed -E 's/[[:space:]]+$//')
+
+ echo "BASE_TAG=$BASE_TAG"
+ echo "PRE_TAG=${BASE_TAG}_pre"
+ echo "POST_TAG=${BASE_TAG}_post"
+
+ # Extract branch name from BASE_TAG (everything before _IFU_)
+ BRANCH="${BASE_TAG%_IFU_*}"
+ echo "BRANCH=$BRANCH"
+
+ echo "BASE_TAG=$BASE_TAG" >> $GITHUB_OUTPUT
+ echo "PRE_TAG=${BASE_TAG}_pre" >> $GITHUB_OUTPUT
+ echo "POST_TAG=${BASE_TAG}_post" >> $GITHUB_OUTPUT
+ echo "BRANCH=$BRANCH" >> $GITHUB_OUTPUT
+
+ # Find the most recent post tag for this branch
+ # This is needed to know the range of commits for issue creation
+ - name: Find previous post tag
+ id: prev_tag
+ env:
+ # Dispatch without pr_num: test_branch; dispatch+pr_num or PR merge: from tagname
+ BRANCH: ${{ (github.event_name == 'workflow_dispatch' && inputs.pr_num == 0 && inputs.test_branch) || steps.tagname.outputs.BRANCH }}
+ run: |
+ echo "Finding previous post tag for branch: ${BRANCH}"
+
+ # List all post tags for this branch, sorted by version (date in tag name)
+ echo "All post tags for ${BRANCH}:"
+ git tag --list "${BRANCH}_IFU_*_post" --sort=-version:refname
+
+ # Get the most recent post tag
+ prev_post_tag=$(git tag --list "${BRANCH}_IFU_*_post" --sort=-version:refname | head -n 1)
+
+ if [[ -z "$prev_post_tag" ]]; then
+ echo "WARNING: No previous post tag found for branch ${BRANCH}"
+ echo "This might be the first IFU for this branch"
+ echo "prev_post_tag=" >> $GITHUB_OUTPUT
+ echo "has_prev_tag=false" >> $GITHUB_OUTPUT
+ else
+ echo "Found previous post tag: $prev_post_tag"
+ echo "prev_post_tag=$prev_post_tag" >> $GITHUB_OUTPUT
+ echo "has_prev_tag=true" >> $GITHUB_OUTPUT
+ fi
+
+ - name: Validate full-chain test start ref
+ if: github.event_name == 'workflow_dispatch' && inputs.run_full_chain == true
+ env:
+ HAS_PREV_TAG: ${{ steps.prev_tag.outputs.has_prev_tag }}
+ TEST_ISSUE_PREV_REF: ${{ inputs.test_issue_prev_ref }}
+ run: |
+ if [[ "${HAS_PREV_TAG}" != "true" && -z "${TEST_ISSUE_PREV_REF}" ]]; then
+ echo "ERROR: No previous post tag found for this branch."
+ echo "For cold-start full-chain tests, provide test_issue_prev_ref (tag or SHA)."
+ exit 1
+ fi
+
+ # In test mode, print a summary of what was found
+ - name: Test mode summary
+ if: github.event_name == 'workflow_dispatch'
+ env:
+ BRANCH: ${{ inputs.test_branch }}
+ PREV_POST_TAG: ${{ steps.prev_tag.outputs.prev_post_tag }}
+ HAS_PREV_TAG: ${{ steps.prev_tag.outputs.has_prev_tag }}
+ TEST_CURR_PRE_TAG: ${{ inputs.test_curr_pre_tag }}
+ TEST_ISSUE_PREV_REF: ${{ inputs.test_issue_prev_ref }}
+ RUN_FULL_CHAIN: ${{ inputs.run_full_chain }}
+ run: |
+ echo "=========================================="
+ echo "TEST MODE SUMMARY"
+ echo "=========================================="
+ echo "Branch: ${BRANCH}"
+ echo "Has previous post tag: ${HAS_PREV_TAG}"
+ echo "Previous post tag: ${PREV_POST_TAG:-'(none)'}"
+ echo ""
+ if [[ "${RUN_FULL_CHAIN}" == "true" ]]; then
+ echo " FULL CHAIN TEST ENABLED"
+ echo "Will call create_ifu_issues.yml with:"
+ echo " - prev_post_tag: ${PREV_POST_TAG}"
+ echo " - curr_pre_tag: ${TEST_CURR_PRE_TAG}"
+ if [[ -n "${TEST_ISSUE_PREV_REF}" ]]; then
+ echo " - test_issue_prev_ref override: ${TEST_ISSUE_PREV_REF}"
+ fi
+ echo ""
+ echo " WARNING: This will create REAL issues!"
+ else
+ echo " Full chain test NOT enabled"
+ echo "To test issue creation, re-run with:"
+ echo " - run_full_chain: true"
+ echo " - test_curr_pre_tag: (an existing pre tag)"
+ fi
+ echo "=========================================="
+
+ # Determine the start reference for issue creation.
+ # Priority:
+ # 1) previous IFU post tag (normal path)
+ # 2) test_issue_prev_ref (cold-start fallback for workflow_dispatch test path)
+ # 3) UPSTREAM_MAIN_SHA (cold-start fallback on real merge path only)
+ - name: Resolve issue range start reference
+ id: prev_ref
+ env:
+ HAS_PREV_TAG: ${{ steps.prev_tag.outputs.has_prev_tag }}
+ PREV_POST_TAG: ${{ steps.prev_tag.outputs.prev_post_tag }}
+ TEST_ISSUE_PREV_REF: ${{ inputs.test_issue_prev_ref }}
+ EVENT_NAME: ${{ github.event_name }}
+ UPSTREAM_MAIN_SHA: ${{ steps.shas.outputs.UPSTREAM_MAIN_SHA }}
+ run: |
+ if [[ "${HAS_PREV_TAG}" == "true" && -n "${PREV_POST_TAG}" ]]; then
+ echo "Using previous IFU post tag for issue range start: ${PREV_POST_TAG}"
+ echo "issue_prev_ref=${PREV_POST_TAG}" >> "$GITHUB_OUTPUT"
+ echo "can_create_issues=true" >> "$GITHUB_OUTPUT"
+ elif [[ "${EVENT_NAME}" == "workflow_dispatch" && -n "${TEST_ISSUE_PREV_REF}" ]]; then
+ echo "Using test override for issue range start: ${TEST_ISSUE_PREV_REF}"
+ echo "issue_prev_ref=${TEST_ISSUE_PREV_REF}" >> "$GITHUB_OUTPUT"
+ echo "can_create_issues=true" >> "$GITHUB_OUTPUT"
+ elif [[ "${EVENT_NAME}" != "workflow_dispatch" && -n "${UPSTREAM_MAIN_SHA:-}" ]]; then
+ echo "No previous IFU post tag found; using cold-start fallback UPSTREAM_MAIN_SHA: ${UPSTREAM_MAIN_SHA}"
+ echo "issue_prev_ref=${UPSTREAM_MAIN_SHA}" >> "$GITHUB_OUTPUT"
+ echo "can_create_issues=true" >> "$GITHUB_OUTPUT"
+ else
+ echo "Could not determine issue range start reference."
+ echo "issue_prev_ref=" >> "$GITHUB_OUTPUT"
+ echo "can_create_issues=false" >> "$GITHUB_OUTPUT"
+ fi
+
+ - name: Create pre/post tags
+ if: (github.event_name == 'pull_request') || (github.event_name == 'workflow_dispatch' && inputs.pr_num != 0)
+ shell: bash
+ run: |
+ set -euo pipefail
+ echo "Tagging:"
+ echo " ${{ steps.tagname.outputs.PRE_TAG }} @ ${{ steps.shas.outputs.ROCM_BASE_SHA }}"
+ echo " ${{ steps.tagname.outputs.POST_TAG }} @ ${{ steps.shas.outputs.MERGE_SHA }}"
+
+ git tag -a "${{ steps.tagname.outputs.PRE_TAG }}" -m "IFU pre (PR #${{ steps.shas.outputs.PR_NUM }})" "${{ steps.shas.outputs.ROCM_BASE_SHA }}"
+ git tag -a "${{ steps.tagname.outputs.POST_TAG }}" -m "IFU post (PR #${{ steps.shas.outputs.PR_NUM }})" "${{ steps.shas.outputs.MERGE_SHA }}"
+
+ #Force pushing is safe. If we land a new PR, we'd wanna retag a commit if we have to.
+ git push origin "refs/tags/${{ steps.tagname.outputs.PRE_TAG }}" -f
+ git push origin "refs/tags/${{ steps.tagname.outputs.POST_TAG }}" -f
+
+ - name: Append rocm_base & upstream_main to PR body
+ if: (github.event_name == 'pull_request') || (github.event_name == 'workflow_dispatch' && inputs.pr_num != 0)
+ env:
+ GH_TOKEN: ${{ secrets.IFU_GITHUB_TOKEN }}
+ shell: bash
+ run: |
+ set -euo pipefail
+ # Read current body
+ PR="${{ steps.shas.outputs.PR_NUM }}"
+ CURR=$(gh api repos/${{ github.repository }}/pulls/$PR --jq .body)
+ APPEND=$'\n'"rocm_base: ${{ steps.shas.outputs.ROCM_BASE_SHA }}"$'\n'"upstream_main: ${{ steps.shas.outputs.UPSTREAM_MAIN_SHA }}"$'\n'
+ NEW_BODY="${CURR}${APPEND}"
+
+ # Write to a temp file and update PR body
+ printf '%s' "$NEW_BODY" > body.txt
+ gh api --method PATCH -H "Accept: application/vnd.github+json" \
+ repos/${{ github.repository }}/pulls/$PR -F body=@body.txt
+
+ # Calls create_ifu_issues.yml after tagging
+ # Runs for:
+ # - Real PR merges (when a start reference can be resolved)
+ # - Test mode with run_full_chain=true (when a start reference can be resolved)
+ create-issues:
+ needs: tag-ifu
+ if: >
+ needs.tag-ifu.outputs.can_create_issues == 'true' &&
+ (github.event_name != 'workflow_dispatch' || inputs.run_full_chain == true || inputs.pr_num != 0)
+ uses: ./.github/workflows/create_ifu_issues.yml
+ with:
+ prev_post_tag: ${{ needs.tag-ifu.outputs.issue_prev_ref }}
+ curr_pre_tag: ${{ needs.tag-ifu.outputs.curr_pre_tag }}
+ secrets:
+ IFU_GITHUB_TOKEN: ${{ secrets.IFU_GITHUB_TOKEN }}
diff --git a/.github/workflows/parity-auto.yml b/.github/workflows/parity-auto.yml
new file mode 100644
index 000000000000..06eb450f9bdb
--- /dev/null
+++ b/.github/workflows/parity-auto.yml
@@ -0,0 +1,412 @@
+name: Parity Auto Trigger
+run-name: "Parity auto-trigger · pytorch/pytorch main"
+
+# Every 10 min, dispatch parity.yml once per completed upstream trunk.yml push
+# whose parity inputs have all finished. Scope is trunk only: the mi350 ROCm
+# shards that ride along in trunk.yml vs that run's CUDA shards. Other arches
+# have their own periodic workflows - run parity.yml manually for those.
+#
+# Readiness is gated per check-run, not on workflow_run conclusion: one failed
+# shard flips the parent run to failure while siblings are still going, so we
+# wait for every ROCm + CUDA test check-run to complete before dispatching.
+
+on:
+ schedule:
+ - cron: '*/10 * * * *'
+ pull_request:
+ paths:
+ - '.github/workflows/parity-auto.yml'
+ - '.github/workflows/parity.yml'
+ - '.automation_scripts/pytorch-unit-test-scripts/parity_job_config.json'
+ workflow_dispatch:
+ inputs:
+ max_commits:
+ description: 'How many of the most recent completed upstream trunk.yml pushes on main to scan.'
+ required: false
+ default: '200'
+ type: string
+ max_dispatches:
+ description: 'Maximum number of ready upstream commits to dispatch in one scan.'
+ required: false
+ default: '50'
+ type: string
+ max_age_hours:
+ description: 'Skip commits older than this (avoid back-filling ancient SHAs).'
+ required: false
+ default: '72'
+ type: string
+ arch_jobname_regex_map:
+ description: 'Optional override: JSON map of arch -> PCRE regex for that arch''s ROCm test check-run names. Blank = use parity_job_config.json (rocm..checkrun_regex).'
+ required: false
+ default: ''
+ type: string
+ arch_workflow_regex_map:
+ description: 'Optional override: JSON map of arch -> PCRE regex for upstream workflow paths meaning the arch ran. Blank = derive from parity_job_config.json (union of each arch''s workflow values).'
+ required: false
+ default: ''
+ type: string
+ target_ref:
+ description: 'Ref of this repo to dispatch parity.yml against. Leave blank to use this workflow run''s ref.'
+ required: false
+ default: ''
+ type: string
+ dry_run:
+ description: 'Scan and log, but do not actually dispatch parity.yml.'
+ required: false
+ default: false
+ type: boolean
+
+permissions:
+ contents: read
+ actions: write
+
+concurrency:
+ group: parity-auto-trigger
+ cancel-in-progress: false
+
+jobs:
+ scan-and-dispatch:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Find ready arches per upstream commit and dispatch parity.yml
+ env:
+ GH_TOKEN: ${{ github.token }}
+ UPSTREAM: pytorch/pytorch
+ BRANCH: main
+ MAX_COMMITS: ${{ github.event_name == 'pull_request' && '20' || inputs.max_commits || '200' }}
+ MAX_DISPATCHES: ${{ github.event_name == 'pull_request' && '5' || inputs.max_dispatches || '50' }}
+ MAX_AGE_HOURS: ${{ inputs.max_age_hours || '72' }}
+ # Auto-parity is trunk-scoped: mi350 is the only ROCm arch that rides
+ # along in trunk.yml. Other arches have their own periodic workflows.
+ ARCHS_IN: mi350
+ # Optional manual overrides; blank means "derive from parity_job_config.json".
+ ARCH_JOBNAME_REGEX_OVERRIDE: ${{ inputs.arch_jobname_regex_map || '' }}
+ ARCH_WORKFLOW_REGEX_OVERRIDE: ${{ inputs.arch_workflow_regex_map || '' }}
+ CONFIG_PATH: .automation_scripts/pytorch-unit-test-scripts/parity_job_config.json
+ TARGET_REF_IN: ${{ inputs.target_ref || '' }}
+ DRY_RUN: ${{ github.event_name == 'pull_request' && 'true' || inputs.dry_run || 'false' }}
+ run: |
+ # GitHub runs this step as `bash -e`, which exits on the first non-zero
+ # status. Our paginated-API calls return non-zero in normal cases and
+ # were silently killing the scan loop, so disable -e.
+ # Keep -u (catch unset vars) and pipefail (catch errors mid-pipe).
+ set +e
+ set -uo pipefail
+
+ # Check-runs from rerun/flaky tooling are not parity test shards.
+ NON_TEST_CHECKRUNS='mem_leak_check|rerun_disabled_tests'
+
+ # Read parity_job_config.json over the API at GITHUB_SHA (this job never
+ # checks out the fork) and build the regex maps used for matching.
+ # Manual override inputs win when set. Sets ARCH_JOBNAME_REGEX_MAP,
+ # ARCH_WORKFLOW_REGEX_MAP, and CUDA_JOBNAME_REGEX.
+ load_matching_config() {
+ local config_json
+ config_json=$(gh api "repos/$GITHUB_REPOSITORY/contents/$CONFIG_PATH?ref=$GITHUB_SHA" --jq '.content' | base64 -d)
+ if [ -z "$config_json" ] || ! echo "$config_json" | jq -e . >/dev/null 2>&1; then
+ echo "::error::Could not read $CONFIG_PATH at $GITHUB_SHA"
+ exit 1
+ fi
+ ARCH_JOBNAME_REGEX_MAP=${ARCH_JOBNAME_REGEX_OVERRIDE:-$(echo "$config_json" | jq -c '.rocm | map_values(.checkrun_regex)')}
+ # workflow_regex is not stored: derive it per arch from the union of
+ # the workflow values across that arch's default/distributed/inductor
+ # source lists (primary + fallback entries).
+ ARCH_WORKFLOW_REGEX_MAP=${ARCH_WORKFLOW_REGEX_OVERRIDE:-$(echo "$config_json" | jq -c '
+ .rocm | map_values(
+ [ (.default[]?, .distributed[]?, .inductor[]?).workflow ] | unique
+ | "(^|/)(" + join("|") + ")[.]yml$"
+ )')}
+ CUDA_JOBNAME_REGEX=$(echo "$config_json" | jq -r '.cuda.checkrun_regex')
+ }
+
+ print_run_config() {
+ printf '%s\n' \
+ "Upstream: $UPSTREAM@$BRANCH" \
+ "Target ref: $TARGET_REF" \
+ "Scope archs: $ARCHS" \
+ "Max trunk runs: $MAX_COMMITS" \
+ "Max dispatches: $MAX_DISPATCHES" \
+ "Max age: ${MAX_AGE_HOURS}h" \
+ "Dry run: $DRY_RUN" \
+ "Arch->jobs: $ARCH_JOBNAME_REGEX_MAP" \
+ "Arch->workflows: $ARCH_WORKFLOW_REGEX_MAP" \
+ "CUDA jobs: $CUDA_JOBNAME_REGEX" \
+ ""
+ }
+
+ # Echo "" lines for the most recent completed
+ # trunk.yml pushes, deduped and newest-first, up to MAX_COMMITS.
+ # Candidate SHAs come from completed trunk pushes (not raw main
+ # commits): the report consumes trunk's jobs, so a completed trunk run
+ # is the earliest a SHA can be ready.
+ fetch_trunk_commits() {
+ local commits_json='[]' page=1 page_runs
+ while [ "$(echo "$commits_json" | jq 'length')" -lt "$MAX_COMMITS" ]; do
+ page_runs=$(gh api \
+ "repos/$UPSTREAM/actions/workflows/trunk.yml/runs?branch=$BRANCH&event=push&status=completed&per_page=100&page=$page" \
+ --jq '.workflow_runs | map({head_sha, created_at})' 2>/dev/null)
+ # On a gh api failure (rate limit / transient 5xx) page_runs is
+ # empty or non-JSON; stop paginating and use what we have so far.
+ if ! echo "$page_runs" | jq -e . >/dev/null 2>&1; then
+ echo "::warning::trunk.yml runs page $page returned no/invalid JSON - stopping pagination" >&2
+ break
+ fi
+ [ "$(echo "$page_runs" | jq 'length')" -eq 0 ] && break
+ commits_json=$(jq -s --arg max "$MAX_COMMITS" '
+ (.[0] + .[1]) as $runs
+ | reduce $runs[] as $run ({seen:{}, rows:[]};
+ if .seen[$run.head_sha] then .
+ else .seen[$run.head_sha] = true | .rows += [$run]
+ end
+ )
+ | .rows[:($max | tonumber)]
+ ' <(echo "$commits_json") <(echo "$page_runs"))
+ page=$((page + 1))
+ done
+ echo "$commits_json" | jq -r '.[] | "\(.head_sha) \(.created_at)"'
+ }
+
+ # Echo a JSON array of recent auto-parity workflow_dispatch runs in our
+ # repo, used to skip SHAs we already dispatched. Auto runs come from
+ # github-actions[bot] and carry an "autoparity-" run-name prefix; we
+ # match either signal so manual parity.yml runs don't suppress us.
+ fetch_existing_dispatches() {
+ gh api --paginate \
+ "repos/$GITHUB_REPOSITORY/actions/workflows/parity.yml/runs?event=workflow_dispatch&created=%3E%3D$(date -u -d "@$MAX_AGE_EPOCH" '+%Y-%m-%dT%H:%M:%SZ')&per_page=100" \
+ --jq '.workflow_runs[] | {display_title, actor: .actor.login}' |
+ jq -s '.'
+ }
+
+ # True if EXISTING already has an auto-parity run for this SHA.
+ sha_already_dispatched() {
+ local sha="$1"
+ echo "$EXISTING" | jq -e --arg sha "$sha" \
+ 'any(.[]; ((.display_title // "") | contains($sha)) and (((.display_title // "") | startswith("autoparity-")) or (.actor == "github-actions[bot]")))' >/dev/null
+ }
+
+ # Filter a check-run JSON array down to the parity test shards whose
+ # name matches $2 (PCRE), dropping non-test check-runs. Reads the array
+ # from stdin, echoes the filtered array.
+ select_test_check_runs() {
+ local regex="$1"
+ jq --arg rx "$regex" --arg skip "$NON_TEST_CHECKRUNS" \
+ '[.[] | select((.name | test($rx)) and (.name | test($skip) | not))]'
+ }
+
+ # Determine which in-scope archs had their upstream workflow run on $1.
+ # Sets RUN_ARCHS (space separated) and NOT_RUN_NOTES (per-arch reasons
+ # for the ones that did not, for logging). Sets globals rather than
+ # echoing so both outputs survive - command substitution would run this
+ # in a subshell and drop NOT_RUN_NOTES.
+ archs_that_ran() {
+ local sha="$1" all_runs arch wf_regex wf_total run_archs="" notes=""
+ all_runs=$(gh api --paginate \
+ "repos/$UPSTREAM/actions/runs?head_sha=$sha&per_page=100" \
+ --jq '.workflow_runs[] | {name,path,status,conclusion}' \
+ 2>/dev/null | jq -s '.' || echo '[]')
+ for arch in $ARCHS; do
+ wf_regex=$(echo "$ARCH_WORKFLOW_REGEX_MAP" | jq -r --arg a "$arch" '.[$a] // ""')
+ if [ -z "$wf_regex" ]; then
+ notes="$notes $arch:no-workflow-regex"
+ continue
+ fi
+ wf_total=$(echo "$all_runs" | jq --arg rx "$wf_regex" \
+ 'map(select((.path // "") | test($rx))) | length')
+ if [ "$wf_total" -eq 0 ]; then
+ notes="$notes $arch:no-workflow"
+ else
+ run_archs="$run_archs $arch"
+ fi
+ done
+ RUN_ARCHS=$(echo "$run_archs" | xargs)
+ NOT_RUN_NOTES=$(echo "$notes" | xargs)
+ }
+
+ # Determine which archs from $1 have at least one ROCm test check-run
+ # present in ROCM_CHECK_RUNS. Sets READY (space separated) and
+ # NOT_READY_NOTES (per-arch reasons for the ones still missing shards).
+ # Sets globals for the same subshell reason as archs_that_ran.
+ archs_with_shards() {
+ local run_archs="$1" arch regex total ready="" notes=""
+ for arch in $run_archs; do
+ regex=$(echo "$ARCH_JOBNAME_REGEX_MAP" | jq -r --arg a "$arch" '.[$a] // ""')
+ if [ -z "$regex" ]; then
+ notes="$notes $arch:no-regex"
+ continue
+ fi
+ total=$(echo "$ROCM_CHECK_RUNS" | jq --arg rx "$regex" \
+ 'map(select(.name | test($rx))) | length')
+ if [ "$total" -eq 0 ]; then
+ notes="$notes $arch:workflow-run-no-shards-yet"
+ else
+ ready="$ready $arch"
+ fi
+ done
+ READY=$(echo "$ready" | xargs)
+ NOT_READY_NOTES=$(echo "$notes" | xargs)
+ }
+
+ # Evaluate one trunk SHA and dispatch parity.yml when it is ready and
+ # not already processed. Returns 0 to keep scanning, 1 to stop the scan
+ # (commit too old, or per-scan dispatch cap reached).
+ process_commit() {
+ local sha="$1" date="$2"
+ local short commit_epoch all_check_runs cuda_check_runs
+ local total_cr pending_cr pending_sample arch_dispatch
+ short=$(echo "$sha" | cut -c1-8)
+
+ commit_epoch=$(date -u -d "$date" +%s 2>/dev/null || echo 0)
+ if [ "$commit_epoch" -ne 0 ] && [ "$commit_epoch" -lt "$MAX_AGE_EPOCH" ]; then
+ echo "[$short] $date too old (>${MAX_AGE_HOURS}h) - stopping scan"
+ return 1
+ fi
+
+ if sha_already_dispatched "$sha"; then
+ echo "[$short] parity report already exists for this SHA - skip"
+ return 0
+ fi
+
+ # Sets RUN_ARCHS and NOT_RUN_NOTES.
+ archs_that_ran "$sha"
+ if [ -z "$RUN_ARCHS" ]; then
+ echo "[$short] $date no in-scope ROCm workflows ran on upstream (${NOT_RUN_NOTES:-none}) - skip"
+ return 0
+ fi
+
+ # Per-shard check-run state for this SHA (workflow_run conclusion can
+ # flip to failure before sibling shards finish, so we look per shard).
+ all_check_runs=$(gh api --paginate \
+ "repos/$UPSTREAM/commits/$sha/check-runs?per_page=100" \
+ --jq '.check_runs[] | {name,status,conclusion}' \
+ 2>/dev/null | jq -s '.' || echo '[]')
+
+ ROCM_CHECK_RUNS='[]'
+ local arch regex arch_check_runs
+ for arch in $RUN_ARCHS; do
+ regex=$(echo "$ARCH_JOBNAME_REGEX_MAP" | jq -r --arg a "$arch" '.[$a] // ""')
+ [ -z "$regex" ] && continue
+ arch_check_runs=$(echo "$all_check_runs" | select_test_check_runs "$regex")
+ ROCM_CHECK_RUNS=$(jq -s 'add | unique_by(.name)' \
+ <(echo "$ROCM_CHECK_RUNS") <(echo "$arch_check_runs"))
+ done
+ cuda_check_runs=$(echo "$all_check_runs" | select_test_check_runs "$CUDA_JOBNAME_REGEX")
+
+ if [ "$(echo "$ROCM_CHECK_RUNS" | jq 'length')" -eq 0 ]; then
+ echo "[$short] $date ROCm workflows ran ($RUN_ARCHS) but no parity check-runs yet - skip"
+ return 0
+ fi
+ if [ "$(echo "$cuda_check_runs" | jq 'length')" -eq 0 ]; then
+ echo "[$short] $date no CUDA parity check-runs yet on upstream - skip"
+ return 0
+ fi
+
+ # Gate 1: every check-run the report consumes (ROCm shards for arches
+ # that ran + CUDA tests) must be status=completed - we author the
+ # SHA's report on dispatch, so dispatching early = partial data.
+ local gate_check_runs
+ gate_check_runs=$(jq -s 'add' <(echo "$ROCM_CHECK_RUNS") <(echo "$cuda_check_runs"))
+ total_cr=$(echo "$gate_check_runs" | jq 'length')
+ pending_cr=$(echo "$gate_check_runs" | jq 'map(select(.status != "completed")) | length')
+ if [ "$pending_cr" -ne 0 ]; then
+ pending_sample=$(echo "$gate_check_runs" | jq -r \
+ 'map(select(.status != "completed")) | .[0:3] | map(.name) | join(", ")')
+ echo "[$short] $date ${pending_cr}/${total_cr} parity check-runs still pending - skip (e.g. $pending_sample)"
+ return 0
+ fi
+
+ # Gate 2: every arch workflow that ran must actually have its test
+ # shards present before we author the SHA's one report.
+ # Sets READY and NOT_READY_NOTES.
+ archs_with_shards "$RUN_ARCHS"
+ if [ -n "$NOT_READY_NOTES" ]; then
+ echo "[$short] $date ROCm workflows ran ($RUN_ARCHS) but some test shards are missing - skip (${NOT_READY_NOTES})"
+ return 0
+ fi
+ if [ -z "$READY" ]; then
+ echo "[$short] $date ROCm workflows ran ($RUN_ARCHS) but no in-scope arches are ready"
+ return 0
+ fi
+
+ arch_dispatch=$(echo "$READY" | sed 's/ /, /g')
+ echo "[$short] READY archs: '$(echo "$READY" | tr ' ' ',')' (committed $date; not-run: ${NOT_RUN_NOTES:-none})"
+ echo "[$short] dispatching for: '$(echo "$READY" | tr ' ' ',')'"
+
+ if [ "$DRY_RUN" = "true" ]; then
+ echo "[$short] DRY_RUN=true - not dispatching"
+ else
+ gh workflow run parity.yml \
+ --repo "$GITHUB_REPOSITORY" \
+ --ref "$TARGET_REF" \
+ -f sha="$sha" \
+ -f arch="$arch_dispatch" \
+ -f auto_triggered=true
+ fi
+
+ DISPATCHED_COUNT=$((DISPATCHED_COUNT + 1))
+ DISPATCHED_SUMMARY="${DISPATCHED_SUMMARY}${short}:${arch_dispatch}"$'\n'
+ if [ "$DISPATCHED_COUNT" -ge "$MAX_DISPATCHES" ]; then
+ echo "Reached max dispatches for this scan ($MAX_DISPATCHES); stopping"
+ return 1
+ fi
+ return 0
+ }
+
+ write_step_summary() {
+ local result_line dispatched_list
+ if [ "$DISPATCHED_COUNT" -gt 0 ]; then
+ result_line=$([ "$DRY_RUN" = "true" ] \
+ && echo "would dispatch $DISPATCHED_COUNT parity run(s) (dry-run)" \
+ || echo "dispatched $DISPATCHED_COUNT parity run(s)")
+ dispatched_list=$(echo "$DISPATCHED_SUMMARY" | sed '/^$/d; s/^/- /')
+ else
+ result_line="no ready unprocessed SHAs found"
+ dispatched_list=""
+ fi
+ {
+ printf '%s\n' \
+ "### Parity auto-trigger" \
+ "" \
+ "- Upstream: \`$UPSTREAM@$BRANCH\`" \
+ "- Scope archs: \`$ARCHS\`" \
+ "- Max commits: $MAX_COMMITS" \
+ "- Max dispatches: $MAX_DISPATCHES" \
+ "- Max age: ${MAX_AGE_HOURS}h" \
+ "- Target ref: \`$TARGET_REF\`" \
+ "- Result: $result_line"
+ # if (not &&) so a 0-dispatch scan leaves a 0 exit status, not the
+ # 1 from the empty-string test - which would fail the whole step.
+ if [ -n "$dispatched_list" ]; then printf '\n%s\n' "$dispatched_list"; fi
+ } >> "$GITHUB_STEP_SUMMARY"
+ }
+
+ main() {
+ NOW_EPOCH=$(date -u +%s)
+ MAX_AGE_EPOCH=$((NOW_EPOCH - MAX_AGE_HOURS * 3600))
+ TARGET_REF="${TARGET_REF_IN:-$GITHUB_REF_NAME}"
+ ARCHS=$(echo "$ARCHS_IN" | tr ',' ' ' | xargs)
+
+ load_matching_config
+ print_run_config
+
+ local commits
+ commits=$(fetch_trunk_commits)
+ if [ -z "$commits" ]; then
+ echo "::warning::No completed trunk.yml push runs returned from $UPSTREAM@$BRANCH"
+ exit 0
+ fi
+
+ EXISTING=$(fetch_existing_dispatches)
+
+ DISPATCHED_COUNT=0
+ DISPATCHED_SUMMARY=""
+ local sha date
+ while IFS=' ' read -r sha date; do
+ [ -z "$sha" ] && continue
+ # process_commit returns non-zero to stop the scan (too old / cap).
+ process_commit "$sha" "$date" || break
+ done <<< "$commits"
+
+ write_step_summary
+ }
+
+ main
diff --git a/.github/workflows/parity.yml b/.github/workflows/parity.yml
new file mode 100644
index 000000000000..fc53323a25ee
--- /dev/null
+++ b/.github/workflows/parity.yml
@@ -0,0 +1,428 @@
+name: Parity Report
+
+# Generates the ROCm-vs-CUDA (or commit-vs-commit) test parity report for one
+# or more architectures. Three jobs run in sequence:
+# 1. setup-matrix - turn the "arch" input into a job matrix + filename prefix
+# 2. generate-parity - per arch: download_testlogs -> per-arch CSV artifact
+# 3. summarize - merge the per-arch CSVs into one report + step summary
+#
+# The run-name below is just a human-friendly title shown in the Actions list.
+# It reads as: [autoparity-][-]