diff --git a/.automation_scripts/parse_xml_results.py b/.automation_scripts/parse_xml_results.py
new file mode 100644
index 0000000000000..7db2e1ce9233c
--- /dev/null
+++ b/.automation_scripts/parse_xml_results.py
@@ -0,0 +1,178 @@
+""" The Python PyTorch testing script.
+##
+# Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+"""
+
+import xml.etree.ElementTree as ET
+from pathlib import Path
+from typing import Any, Dict, Tuple
+
+# Backends list
+BACKENDS_LIST = [
+ "dist-gloo",
+ "dist-nccl"
+]
+
+TARGET_WORKFLOW = "--rerun-disabled-tests"
+
+def get_job_id(report: Path) -> int:
+ # [Job id in artifacts]
+ # Retrieve the job id from the report path. In our GHA workflows, we append
+ # the job id to the end of the report name, so `report` looks like:
+ # unzipped-test-reports-foo_5596745227/test/test-reports/foo/TEST-foo.xml
+ # and we want to get `5596745227` out of it.
+ try:
+ return int(report.parts[0].rpartition("_")[2])
+ except ValueError:
+ return -1
+
+def is_rerun_disabled_tests(root: ET.ElementTree) -> bool:
+ """
+ Check if the test report is coming from rerun_disabled_tests workflow
+ """
+ skipped = root.find(".//*skipped")
+ # Need to check against None here, if not skipped doesn't work as expected
+ if skipped is None:
+ return False
+
+ message = skipped.attrib.get("message", "")
+ return TARGET_WORKFLOW in message or "num_red" in message
+
+def parse_xml_report(
+ tag: str,
+ report: Path,
+ workflow_id: int,
+ workflow_run_attempt: int,
+ work_flow_name: str
+) -> Dict[Tuple[str], Dict[str, Any]]:
+ """Convert a test report xml file into a JSON-serializable list of test cases."""
+ print(f"Parsing {tag}s for test report: {report}")
+
+ job_id = get_job_id(report)
+ print(f"Found job id: {job_id}")
+
+ test_cases: Dict[Tuple[str], Dict[str, Any]] = {}
+
+ root = ET.parse(report)
+ # TODO: unlike unittest, pytest-flakefinder used by rerun disabled tests for test_ops
+ # includes skipped messages multiple times (50 times by default). This slows down
+ # this script too much (O(n)) because it tries to gather all the stats. This should
+ # be fixed later in the way we use pytest-flakefinder. A zipped test report from rerun
+ # disabled test is only few MB, but will balloon up to a much bigger XML file after
+ # extracting from a dozen to few hundred MB
+ if is_rerun_disabled_tests(root):
+ return test_cases
+
+ for test_case in root.iter(tag):
+ case = process_xml_element(test_case)
+ if tag == 'testcase':
+ case["workflow_id"] = workflow_id
+ case["workflow_run_attempt"] = workflow_run_attempt
+ case["job_id"] = job_id
+ case["work_flow_name"] = work_flow_name
+
+ # [invoking file]
+ # The name of the file that the test is located in is not necessarily
+ # the same as the name of the file that invoked the test.
+ # For example, `test_jit.py` calls into multiple other test files (e.g.
+ # jit/test_dce.py). For sharding/test selection purposes, we want to
+ # record the file that invoked the test.
+ #
+ # To do this, we leverage an implementation detail of how we write out
+ # tests (https://bit.ly/3ajEV1M), which is that reports are created
+ # under a folder with the same name as the invoking file.
+ case_name = report.parent.name
+ for ind in range(len(BACKENDS_LIST)):
+ if BACKENDS_LIST[ind] in report.parts:
+ case_name = case_name + "_" + BACKENDS_LIST[ind]
+ break
+ case["invoking_file"] = case_name
+ test_cases[ ( case["invoking_file"], case["classname"], case["name"], case["work_flow_name"] ) ] = case
+ elif tag == 'testsuite':
+ case["work_flow_name"] = work_flow_name
+ case["invoking_xml"] = report.name
+ case["running_time_xml"] = case["time"]
+ case_name = report.parent.name
+ for ind in range(len(BACKENDS_LIST)):
+ if BACKENDS_LIST[ind] in report.parts:
+ case_name = case_name + "_" + BACKENDS_LIST[ind]
+ break
+ case["invoking_file"] = case_name
+
+ test_cases[ ( case["invoking_file"], case["invoking_xml"], case["work_flow_name"] ) ] = case
+
+ return test_cases
+
+def process_xml_element(element: ET.Element) -> Dict[str, Any]:
+ """Convert a test suite element into a JSON-serializable dict."""
+ ret: Dict[str, Any] = {}
+
+ # Convert attributes directly into dict elements.
+ # e.g.
+ #
+ # becomes:
+ # {"name": "test_foo", "classname": "test_bar"}
+ ret.update(element.attrib)
+
+ # The XML format encodes all values as strings. Convert to ints/floats if
+ # possible to make aggregation possible in Rockset.
+ for k, v in ret.items():
+ try:
+ ret[k] = int(v)
+ except ValueError:
+ pass
+ try:
+ ret[k] = float(v)
+ except ValueError:
+ pass
+
+ # Convert inner and outer text into special dict elements.
+ # e.g.
+ # my_inner_text my_tail
+ # becomes:
+ # {"text": "my_inner_text", "tail": " my_tail"}
+ if element.text and element.text.strip():
+ ret["text"] = element.text
+ if element.tail and element.tail.strip():
+ ret["tail"] = element.tail
+
+ # Convert child elements recursively, placing them at a key:
+ # e.g.
+ #
+ # hello
+ # world
+ # another
+ #
+ # becomes
+ # {
+ # "foo": [{"text": "hello"}, {"text": "world"}],
+ # "bar": {"text": "another"}
+ # }
+ for child in element:
+ if child.tag not in ret:
+ ret[child.tag] = process_xml_element(child)
+ else:
+ # If there are multiple tags with the same name, they should be
+ # coalesced into a list.
+ if not isinstance(ret[child.tag], list):
+ ret[child.tag] = [ret[child.tag]]
+ ret[child.tag].append(process_xml_element(child))
+ return ret
\ No newline at end of file
diff --git a/.automation_scripts/pytorch-unit-test-scripts/auto_classify_skip_reasons.py b/.automation_scripts/pytorch-unit-test-scripts/auto_classify_skip_reasons.py
new file mode 100644
index 0000000000000..cf948495ec04e
--- /dev/null
+++ b/.automation_scripts/pytorch-unit-test-scripts/auto_classify_skip_reasons.py
@@ -0,0 +1,1027 @@
+#!/usr/bin/env python3
+"""
+Auto-classify skip reasons for ROCm parity CSV tests.
+
+Takes a parity CSV (output of summarize_xml_testreports.py) and automatically
+assigns skip_reason categories to tests where ROCm=SKIPPED/MISSED and CUDA=PASSED
+based on patterns in:
+ - The skip message (message_rocm column)
+ - The test file name
+ - The test class name
+ - The test name
+
+Rules are ordered by specificity: combined match rules first, then message-based,
+then file+class combos, then file-only fallbacks. First matching rule wins.
+
+Usage:
+ python auto_classify_skip_reasons.py -i input.csv -o output.csv [--report]
+ python auto_classify_skip_reasons.py -i input.csv -o output.csv --tsv-out updated_skip_reasons.tsv
+ python auto_classify_skip_reasons.py -i input.csv --dry-run --report
+"""
+
+import argparse
+import ast
+import csv
+import re
+import sys
+from collections import Counter, defaultdict
+
+
+# ---------------------------------------------------------------------------
+# Rules are evaluated top-to-bottom; first match wins.
+# Each rule is a dict with:
+# reason: the skip_reason category string
+# msg: (optional) regex to match against the skip message
+# file: (optional) regex to match against test_file
+# cls: (optional) regex to match against test_class
+# name: (optional) regex to match against test_name
+# workflow: (optional) one of "default", "distributed", "inductor"
+#
+# All provided fields must match (AND logic). Omitted fields match anything.
+# msg="" matches empty messages; omitting msg matches anything.
+# ---------------------------------------------------------------------------
+
+RULES = [
+ # ==================================================================
+ # TIER 1: High-specificity combined rules (message + file/class)
+ # ==================================================================
+
+ # --- bfloat16_SDPA_ME: dropout mask in test_transformers with bfloat16 in TEST NAME ---
+ # Must be before generic SDPA_ME rule
+ {"reason": "bfloat16_SDPA_ME",
+ "msg": r"_fill_mem_eff_dropout_mask",
+ "file": r"^test_transformers$",
+ "name": r"(?i)bfloat16|bf16"},
+
+ # --- GEMMS: test_mm_bmm in test_matmul_cuda with accuracy regression ---
+ # Must be before generic hipblas rule
+ {"reason": "GEMMS",
+ "msg": r"accuracy regression in hipblas",
+ "file": r"^test_matmul_cuda$",
+ "name": r"test_mm_bmm"},
+
+ # --- hipblas hipblaslt: test_addmm/test_cublas/other in test_matmul_cuda ---
+ {"reason": "hipblas hipblaslt",
+ "msg": r"accuracy regression in hipblas",
+ "file": r"^test_matmul_cuda$"},
+ {"reason": "hipblas hipblaslt",
+ "msg": r"skipIfRocm.*doesn't currently work",
+ "file": r"^test_matmul_cuda$"},
+ {"reason": "hipblas hipblaslt",
+ "file": r"^test_matmul_cuda$",
+ "msg": r"Green contexts are not supported"},
+
+ # --- Expected to work: skipCUDAIfRocm in test_meta for ldl_solve ops ---
+ {"reason": "Expected to work",
+ "msg": r"skipCUDAIfRocm.*doesn't currently work",
+ "file": r"^test_meta$",
+ "name": r"(?i)ldl_solve"},
+
+ # --- Linalg: skipCUDAIfRocm in test_meta for other linalg ops ---
+ {"reason": "Linalg",
+ "msg": r"skipCUDAIfRocm.*doesn't currently work",
+ "file": r"^test_meta$"},
+
+ # --- Linalg: skipCUDAIfRocm in test_ops/test_linalg/test_meta/test_ops_fwd_gradients/test_ops_gradients ---
+ # These are ops like linalg.svd, linalg.eigh, etc.
+ {"reason": "Linalg",
+ "msg": r"skipCUDAIfRocm.*doesn't currently work",
+ "file": r"^test_linalg$"},
+ {"reason": "Linalg",
+ "msg": r"_convert_weight_to_int4pack_cuda.*(supported only for|is supported only for) CDNA"},
+ {"reason": "Linalg",
+ "msg": r"bfloat16 NCHW train failed"},
+ {"reason": "Linalg",
+ "msg": r"skipCUDAIfRocm.*doesn't currently work",
+ "file": r"^test_ops$",
+ "name": r"(?i)linalg|svd|eig[hs]?|cholesky|lstsq|solve|inv|det|qr|lu|pinv|matrix_rank|cross|norm|cond|householder|ormqr|geqrf|triangular|vecdot|multi_dot"},
+ {"reason": "Linalg",
+ "msg": r"skipCUDAIfRocm.*doesn't currently work",
+ "file": r"^test_ops_fwd_gradients$"},
+ {"reason": "Linalg",
+ "msg": r"skipCUDAIfRocm.*doesn't currently work",
+ "file": r"^test_ops_gradients$",
+ "name": r"(?i)linalg|svd|eig[hs]?|cholesky|lstsq|solve|inv|det|qr|lu|pinv|householder|ormqr|geqrf|triangular"},
+ {"reason": "Linalg",
+ "msg": r"skipCUDAIfRocm.*doesn't currently work",
+ "file": r"^test_meta$",
+ "name": r"(?i)linalg|svd|eig[hs]?|cholesky|lstsq|solve|inv|det|qr|lu|pinv|householder|ormqr|geqrf|triangular"},
+ {"reason": "Linalg",
+ "file": r"^test_nn$",
+ "msg": r"skipIfRocm.*doesn't currently work"},
+
+ # --- hipSolver/Magma: skipCUDAIfRocm in test_ops for ldl_solve, scaled_dot_product, conv_transpose3d ---
+ {"reason": "hipSolver/Magma",
+ "msg": r"skipCUDAIfRocm.*doesn't currently work",
+ "file": r"^test_ops$",
+ "name": r"(?i)ldl_solve|scaled_dot_product|conv_transpose3d"},
+ {"reason": "hipSolver/Magma",
+ "msg": r"skipCUDAIfRocm.*doesn't currently work",
+ "file": r"^test_ops_jit$"},
+ {"reason": "hipSolver/Magma",
+ "msg": r"skipCUDAIfRocm.*doesn't currently work",
+ "file": r"^test_decomp$"},
+ {"reason": "hipSolver/Magma",
+ "msg": r"skipCUDAIfRocm.*doesn't currently work",
+ "file": r"^test_schema_check$"},
+ {"reason": "hipSolver/Magma",
+ "msg": r"skipCUDAIfRocm.*doesn't currently work",
+ "file": r"^test_testing$"},
+ {"reason": "hipSolver/Magma",
+ "msg": r"Skipped for ROCm!"},
+ {"reason": "hipSolver/Magma",
+ "msg": r"test_cow_input does not work with efficient attention on ROCM"},
+
+ # --- Compiler issue: "Skipped!" in test_ops for specific compiler-related tests ---
+ {"reason": "Compiler issue",
+ "msg": r"^Skipped!$",
+ "file": r"^test_ops$",
+ "name": r"(?i)special_hermite_polynomial_h|special_laguerre"},
+
+ # --- non-standard bool: "Skipped!" in test_ops for bool-related tests ---
+ {"reason": "non-standard bool",
+ "msg": r"^Skipped!$",
+ "file": r"^test_ops$",
+ "name": r"(?i)bool"},
+
+ # --- pow: "Skipped!" in test_ops/test_decomp for pow tests ---
+ {"reason": "pow",
+ "msg": r"^Skipped!$",
+ "file": r"^test_ops$|^test_decomp$",
+ "name": r"(?i)^pow$|_pow_|float_power"},
+
+ # --- fft: "Skipped!" or "Skipped on ROCm" in test_ops for fft tests ---
+ {"reason": "fft",
+ "msg": r"^Skipped(!| on ROCm)$",
+ "file": r"^test_ops$",
+ "name": r"(?i)fft"},
+
+ # --- NHWC: "Skipped!" in test_modules for NHWC tests ---
+ {"reason": "NHWC",
+ "msg": r"^Skipped!$",
+ "file": r"^test_modules$"},
+
+ # (FakeTensor removed — "Requires CUDA" messages are explicit NVIDIA test per policy)
+
+ # --- hermite_polynomial_h: custom_mask_type in test_ops for hermite ---
+ {"reason": "hermite_polynomial_h",
+ "msg": r"Efficient attention on ROCM doesn't support custom_mask_type",
+ "file": r"^test_ops$",
+ "name": r"(?i)hermite"},
+
+ # --- fake_crossref: skipCUDAIfRocm in test_ops for crossref tests ---
+ {"reason": "fake_crossref",
+ "msg": r"skipCUDAIfRocm.*doesn't currently work",
+ "file": r"^test_ops$",
+ "name": r"(?i)crossref|fake_crossref"},
+
+ # --- Jit: Tensor-likes not close in test_jit_fuser ---
+ {"reason": "Jit",
+ "msg": r"Tensor-likes are not close",
+ "file": r"test_jit_fuser"},
+
+ # --- Memory allocation: TestBlockStateAbsorption in test_cuda ---
+ {"reason": "Memory allocation",
+ "file": r"^test_cuda$",
+ "cls": r"^TestBlockStateAbsorption$"},
+
+ # --- cuda allocator: TestCudaAllocator in test_cuda ---
+ {"reason": "cuda allocator",
+ "file": r"^test_cuda$",
+ "cls": r"^TestCudaAllocator$"},
+
+ # --- hipGraph/cudaGraph: CudaGraph-related classes in test_cuda ---
+ {"reason": "hipGraph/cudaGraph",
+ "file": r"^test_cuda$",
+ "cls": r"CachingHostAllocatorCudaGraph|GreenContext"},
+
+ # --- Memory allocation: TestMemPool in test_cuda ---
+ {"reason": "Memory allocation",
+ "file": r"^test_cuda$",
+ "cls": r"^TestMemPool$"},
+
+ # --- Profiler: TestFXMemoryProfiler in test_cuda ---
+ {"reason": "Profiler",
+ "file": r"^test_cuda$",
+ "cls": r"FXMemoryProfiler"},
+
+ # --- compiled optimizer: ROCm numerical behavior in inductor.test_compiled_optimizers ---
+ {"reason": "compiled optimizer",
+ "msg": r"ROCm may have different numerical behavior",
+ "file": r"inductor\.test_compiled_optimizers"},
+
+ # --- functorch: FuncTorch classes in inductor.test_compiled_autograd ---
+ {"reason": "functorch",
+ "file": r"^inductor\.test_compiled_autograd$",
+ "cls": r"FuncTorch"},
+
+ # --- PT2.0 - Distributed: DTensor classes in inductor.test_compiled_autograd ---
+ {"reason": "PT2.0 - Distributed",
+ "file": r"^inductor\.test_compiled_autograd$",
+ "cls": r"DTensor"},
+
+ # --- hipdnn: cudnn Attention messages ---
+ {"reason": "hipdnn",
+ "msg": r"[Cc]u[Dd][Nn][Nn] Attention is not supported"},
+ {"reason": "hipdnn",
+ "msg": r"Efficient or cuDNN Attention was not built"},
+
+ # --- Will not be supported on ROCm: test_transformers with (no message) ---
+ {"reason": "Will not be supported on ROCm",
+ "file": r"^test_transformers$",
+ "cls": r"SDPA.*CUDA",
+ "msg": r"^$"},
+
+ # --- transformers: test_transformers / test_flop_counter with misc messages ---
+ {"reason": "transformers",
+ "file": r"^test_transformers$",
+ "msg": r"Does not support all SDPA backends"},
+ {"reason": "transformers",
+ "file": r"^test_flop_counter$"},
+
+ # --- bfloat16: test_sparse_csr with (no message) ---
+ {"reason": "bfloat16",
+ "file": r"^test_sparse_csr$",
+ "cls": r"[Bb]float16|bf16"},
+ {"reason": "bfloat16",
+ "file": r"^test_sparse$",
+ "cls": r"[Bb]float16|bf16"},
+ {"reason": "bfloat16",
+ "file": r"^test_matmul_cuda$",
+ "msg": r"ROCm doesn't support CUTLASS"},
+
+ # --- explicit NVIDIA test: test_sparse_semi_structured with cutlass in NAME ---
+ {"reason": "explicit NVIDIA test",
+ "file": r"^test_sparse_semi_structured$",
+ "name": r"(?i)cutlass"},
+
+ # --- cusparselt: everything else in test_sparse_semi_structured ---
+ {"reason": "cusparselt",
+ "file": r"^test_sparse_semi_structured$"},
+
+ # --- Quantization: distributed quantization tests ---
+ {"reason": "Quantization",
+ "msg": r"Test skipped for ROCm",
+ "file": r"distributed\.algorithms\.quantization"},
+
+ # --- Process Group: distributed spawn/c10d with "Test skipped for ROCm" ---
+ {"reason": "Process Group",
+ "msg": r"Test skipped for ROCm",
+ "file": r"distributed\.test_distributed_spawn.*nccl"},
+
+ # ==================================================================
+ # TIER 2: Message-based rules (strong signal from skip message)
+ # ==================================================================
+
+ # SDPA_ME
+ {"reason": "SDPA_ME",
+ "msg": r"_fill_mem_eff_dropout_mask"},
+ {"reason": "SDPA_ME",
+ "msg": r"Efficient attention on ROCM doesn't support custom_mask_type"},
+ {"reason": "SDPA_ME",
+ "msg": r"Efficient Attention on ROCM does not support head_dim"},
+
+ # SDPA_FA
+ {"reason": "SDPA_FA",
+ "msg": r"Large numerical errors on ROCM"},
+ {"reason": "SDPA_FA",
+ "msg": r"flash attention not supported"},
+
+ # Will not be supported on ROCm
+ {"reason": "Will not be supported on ROCm",
+ "msg": r"head_dim != head_dim_v unsupported on ROCm"},
+
+ # Triton 3.7 bump
+ {"reason": "triton 3.7 bump",
+ "msg": r"skipIfRocm.*Fails with Triton 3\.7"},
+
+ # MIOpen
+ {"reason": "MIOpen Convolutions",
+ "msg": r"Marked as skipped for MIOpen"},
+
+ # Static CUDA launcher
+ {"reason": "static cuda launcher",
+ "msg": r"Static cuda launcher doesn't work with ROCM"},
+
+ # NUMBA
+ {"reason": "NUMBA",
+ "msg": r"No numba\.cuda"},
+
+ # int4
+ {"reason": "int4",
+ "msg": r"_int4_mm is supported only for CDNA"},
+
+ # FP8
+ {"reason": "FP8",
+ "msg": r"cuBLAS blockwise scaling"},
+
+ # variable length attention
+ {"reason": "variable length attention",
+ "msg": r"ROCm does not support seqused_k"},
+
+ # CUDA IPC
+ {"reason": "Pass with unskip or minor mod",
+ "msg": r"CUDA IPC not available"},
+
+ # Python version
+ {"reason": "Python version",
+ "msg": r"Not supported in Python 3\.1[0-9]+"},
+
+ # cpp_test / CUDA not found
+ {"reason": "cpp_test",
+ "msg": r"CUDA not found"},
+ {"reason": "cpp_test",
+ "msg": r"CUDA_HOME not set"},
+
+ # Foreach
+ {"reason": "Foreach",
+ "msg": r"failed starting on ROCm"},
+
+ # CUTLASS
+ {"reason": "cutlass",
+ "msg": r"ROCm doesn't support CUTLASS|CUTLASS backend is not supported on HIP|ROCm and Windows doesn't support CUTLASS"},
+
+ # Transformers dependency
+ {"reason": "transformers",
+ "msg": r"No transformers"},
+
+ # hipGraph / cudaGraph (but NOT in functorch files -- those stay functorch)
+ {"reason": "hipGraph/cudaGraph",
+ "msg": r"Green contexts are not supported"},
+ {"reason": "functorch",
+ "msg": r"CUDA 12\.4 or greater is required for CUDA Graphs",
+ "file": r"^functorch\."},
+ {"reason": "hipGraph/cudaGraph",
+ "msg": r"CUDA 12\.4 or greater is required for CUDA Graphs"},
+ {"reason": "hipGraph/cudaGraph",
+ "msg": r"ROCM >= 5\.3 required for graphs.*cuda-bindings"},
+
+ # TMA / Blackwell
+ {"reason": "Will not be supported on ROCm",
+ "msg": r"Need.*TMA support"},
+ {"reason": "Will not be supported on ROCm",
+ "msg": r"Need Blackwell"},
+
+ # CUDA SM requirements
+ {"reason": "explicit NVIDIA test",
+ "msg": r"Requires CUDA SM >= [0-9]"},
+ {"reason": "explicit NVIDIA test",
+ "msg": r"Requires CUDA with SM >= [0-9]"},
+ {"reason": "explicit NVIDIA test",
+ "msg": r"Test is only supported on CUDA 1[0-9]"},
+ {"reason": "explicit NVIDIA test",
+ "msg": r"Requires NCCL version greater than"},
+ {"reason": "explicit NVIDIA test",
+ "msg": r"Excluded from CUDA tests"},
+
+ # FP8 — MI300+ / H100+ only
+ {"reason": "FP8",
+ "msg": r"FP8 is only supported on H100\+|FP8 is not supported on this platform|FP8 requires H100\+"},
+ {"reason": "FP8",
+ "msg": r"requires gpu with fp8 support"},
+
+ # Symmetric memory
+ {"reason": "Symmetric memory",
+ "msg": r"SymmMem is not supported on this ROCm arch"},
+
+ # Python version / 3.12+
+ {"reason": "Python version",
+ "msg": r"Failing on python 3\.12\+|torch\.compile is not supported on python 3\.12\+|complex flaky in 3\.12"},
+
+ # Greater than 4 GPU (distributed)
+ {"reason": "Greater than 4 GPU",
+ "msg": r"Need at least 4 CUDA devices"},
+ {"reason": "Greater than 4 GPU",
+ "msg": r"Test requires.*world size of 4"},
+ {"reason": "Greater than 4 GPU",
+ "msg": r"requires [34] GPUs, found [12]"},
+
+ # tensor_parallel — architecture-specific skip
+ {"reason": "tensor_parallel",
+ "msg": r"test only runs on \('gfx942'"},
+
+ # Process Group: subprocess level skip
+ {"reason": "Process Group",
+ "msg": r"Test skipped at subprocess level"},
+
+ # Sharded Tensor: subprocess level skip in _shard
+ {"reason": "Sharded Tensor",
+ "msg": r"Test skipped at subprocess level",
+ "file": r"distributed\._shard"},
+
+ # Process Group: NCCL version / device assert
+ {"reason": "Process Group",
+ "msg": r"NCCL test requires 2\+ GPUs"},
+
+ # Misc: ROCm preserves subnormals
+ {"reason": "Misc",
+ "msg": r"ROCm preserves subnormals"},
+
+ # Misc: GCC codegen
+ {"reason": "Misc",
+ "msg": r"Fails under GCC 1[0-9] due to vector codegen"},
+
+ # Misc: Skipped on ROCm due to hang
+ {"reason": "Misc",
+ "msg": r"Skipped on ROCm due to hang"},
+
+ # Misc: Test skipped for ROCm (generic distributed)
+ {"reason": "Misc",
+ "msg": r"Test skipped for ROCm"},
+
+ # Misc: architecture-specific skips
+ {"reason": "Misc",
+ "msg": r"test skipped on \('gfx"},
+
+ # cuFFT-specific
+ {"reason": "Misc",
+ "msg": r"cuFFT-specific"},
+
+ # ROCTracer profiler
+ {"reason": "Memory allocation",
+ "msg": r"ROCTracer does not capture"},
+
+ # expandable_segments-related messages
+ {"reason": "expandable_segments",
+ "msg": r"expandable_segments mode is not supported on ROCm"},
+ {"reason": "expandable_segments",
+ "msg": r"CUDA >= 11\.0 required for external events in cuda graphs.*rocm"},
+
+ # not enabled by default on rocm
+ {"reason": "expandable_segments",
+ "msg": r"not enabled by default on rocm"},
+
+ # HIP runtime context
+ {"reason": "Misc",
+ "msg": r"HIP runtime doesn't create context"},
+
+ # ==================================================================
+ # TIER 3: File + class based rules (for empty/generic messages)
+ # ==================================================================
+
+ # --- test_cuda class-based disambiguation ---
+ {"reason": "Misc",
+ "file": r"^test_cuda$",
+ "cls": r"^TestCuda$"},
+ {"reason": "compiled optimizer",
+ "file": r"^test_cuda$",
+ "cls": r"TestCudaOptims"},
+ {"reason": "Misc",
+ "file": r"^test_cuda$",
+ "cls": r"TestCudaAutocast"},
+ {"reason": "cpp_test",
+ "file": r"^test_cuda$",
+ "cls": r"TestCompileKernel"},
+
+ # --- test_nn (MI200-specific skips, no message) ---
+ {"reason": "Misc",
+ "file": r"^test_nn$"},
+
+ # --- inductor.test_fp8 ---
+ {"reason": "FP8",
+ "file": r"^inductor\.test_fp8$"},
+
+ # --- test_scaled_matmul_cuda ---
+ {"reason": "FP8",
+ "file": r"^test_scaled_matmul_cuda$"},
+
+ # --- inductor.test_torchinductor_strided_blocks ---
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^inductor\.test_torchinductor_strided_blocks$"},
+
+ # --- inductor.test_flex_decoding ---
+ {"reason": "flex_decoding",
+ "file": r"^inductor\.test_flex_decoding$"},
+
+ # --- inductor.test_loop_ordering ---
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^inductor\.test_loop_ordering$"},
+
+ # --- torch_np / numpy tests ---
+ {"reason": "NumPy",
+ "file": r"^torch_np\."},
+
+ # --- test_binary_ufuncs ---
+ {"reason": "Misc",
+ "file": r"^test_binary_ufuncs$"},
+
+ # --- test_fx ---
+ {"reason": "FX",
+ "file": r"^test_fx$"},
+
+ # --- profiler.test_execution_trace ---
+ {"reason": "Profiler",
+ "file": r"^profiler\.test_execution_trace$"},
+
+ # --- test_cpp_api_parity ---
+ {"reason": "cpp_test",
+ "file": r"^test_cpp_api_parity$"},
+
+ # --- test_expanded_weights ---
+ {"reason": "Misc",
+ "file": r"^test_expanded_weights$"},
+
+ # --- test_linalg (arch-specific skips) ---
+ {"reason": "Linalg",
+ "file": r"^test_linalg$"},
+
+ # --- test_torch (arch-specific skips) ---
+ {"reason": "Misc",
+ "file": r"^test_torch$"},
+
+ # --- nn.test_convolution (arch-specific) ---
+ {"reason": "MIOpen Convolutions",
+ "file": r"^nn\.test_convolution$"},
+
+ # --- inductor.test_aot_inductor_arrayref ---
+ {"reason": "PT2.0 - AOTInductor",
+ "file": r"^inductor\.test_aot_inductor_arrayref$"},
+
+ # --- distributed.test_symmetric_memory ---
+ {"reason": "Symmetric memory",
+ "file": r"^distributed\.test_symmetric_memory$"},
+
+ # --- inductor.test_compiled_autograd HigherOrderOp (MI300 has more classes) ---
+ {"reason": "functorch",
+ "file": r"^inductor\.test_compiled_autograd$",
+ "cls": r"HigherOrderOp"},
+
+ # --- explicit NVIDIA test in various files ---
+ {"reason": "explicit NVIDIA test",
+ "file": r"^test_cuda_nvml_based_avail$"},
+ {"reason": "explicit NVIDIA test",
+ "file": r"^test_cpp_extensions_aot"},
+
+ # --- hipGraph/cudaGraph: only test_graph_* (NOT test_cuda_graph_*) in test_cuda_expandable_segments ---
+ {"reason": "hipGraph/cudaGraph",
+ "file": r"^test_cuda_expandable_segments$",
+ "name": r"^test_graph_"},
+
+ # --- expandable_segments (everything else in test_cuda_expandable_segments) ---
+ {"reason": "expandable_segments",
+ "file": r"^test_cuda_expandable_segments$"},
+
+ # --- Profiler ---
+ {"reason": "Profiler",
+ "file": r"^profiler\.test_profiler$"},
+
+ # --- serialization ---
+ {"reason": "serialization",
+ "file": r"^test_serialization$"},
+
+ # --- dataloader ---
+ {"reason": "dataloader",
+ "file": r"^test_dataloader$"},
+
+ # --- Multi-Processing ---
+ {"reason": "Multi-Processing",
+ "file": r"^test_multiprocessing_spawn$"},
+ {"reason": "Multi-Processing",
+ "file": r"^test_multiprocessing$"},
+
+ # --- hipSparse ---
+ {"reason": "hipSparse",
+ "file": r"^test_sparse_csr$"},
+ {"reason": "hipSparse",
+ "file": r"^test_sparse$",
+ "msg": r"^$"},
+
+ # --- nested tensor ---
+ {"reason": "nested tensor",
+ "file": r"^test_nestedtensor$"},
+
+ # --- asm_elementwise ---
+ {"reason": "asm_elementwise",
+ "file": r"higher_order_ops\.test_inline_asm_elementwise"},
+
+ # --- torchinductor_opinfo_properties ---
+ {"reason": "torchinductor_opinfo_properties",
+ "file": r"^inductor\.test_torchinductor_opinfo_properties$"},
+
+ # --- flex_attention ---
+ {"reason": "flex_attention",
+ "file": r"^inductor\.test_flex_attention$"},
+
+ # --- compiled optimizer ---
+ {"reason": "compiled optimizer",
+ "file": r"^inductor\.test_compiled_optimizers$"},
+
+ # --- inductor combo_kernels ---
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^inductor\.test_combo_kernels$"},
+
+ # --- inductor compiled_autograd (remaining after FuncTorch/DTensor class rules) ---
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^inductor\.test_compiled_autograd$"},
+
+ # --- Foreach (inductor) ---
+ {"reason": "Foreach",
+ "file": r"^inductor\.test_foreach$"},
+
+ # --- inductor codecache / cudacodecache ---
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^inductor\.test_codecache$"},
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^inductor\.test_cudacodecache$"},
+
+ # --- inductor GPU cpp wrapper ---
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^inductor\.test_gpu_cpp_wrapper$"},
+
+ # --- inductor torchinductor variants ---
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^inductor\.test_torchinductor$"},
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^inductor\.test_torchinductor_dynamic_shapes$"},
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^inductor\.test_torchinductor_codegen_dynamic_shapes$"},
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^inductor\.test_torchinductor_opinfo$"},
+
+ # --- inductor compile subprocess ---
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^inductor\.test_compile_subprocess$"},
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^inductor\.test_compile_worker$"},
+
+ # --- inductor cpu/cuda repro ---
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^inductor\.test_cpu_repro$"},
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^inductor\.test_cuda_repro$"},
+
+ # --- inductor custom lowering / minifier ---
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^inductor\.test_custom_lowering$"},
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^inductor\.test_minifier"},
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^inductor\.test_mix_order"},
+
+ # --- inductor aot_inductor ---
+ {"reason": "PT2.0 - AOTInductor",
+ "file": r"^inductor\.test_aot_inductor"},
+
+ # --- functorch ---
+ {"reason": "functorch",
+ "file": r"^functorch\."},
+
+ # --- dynamo ---
+ {"reason": "PT2.0 - Dynamo",
+ "file": r"^dynamo\."},
+
+ # --- export ---
+ {"reason": "PT2.0 - Inductor",
+ "file": r"^export\."},
+
+ # --- tf32: test_nn with "Test is disabled" ---
+ {"reason": "tf32",
+ "file": r"^test_nn$",
+ "msg": r"Test is disabled"},
+
+ # --- MIOpen Convolutions ---
+ {"reason": "MIOpen Convolutions",
+ "file": r"^nn\.test_convolution$"},
+
+ # --- test_stateless ---
+ {"reason": "Misc",
+ "file": r"^test_stateless$"},
+
+ # --- test_cuda_primary_ctx ---
+ {"reason": "Misc",
+ "file": r"^test_cuda_primary_ctx$"},
+
+ # --- test_torchfuzz ---
+ {"reason": "Misc",
+ "file": r"^test_torchfuzz"},
+
+ # ==================================================================
+ # TIER 4: Distributed file-based rules
+ # ==================================================================
+
+ # Sharded Tensor
+ {"reason": "Sharded Tensor",
+ "file": r"^distributed\._shard\."},
+ {"reason": "Sharded Tensor",
+ "file": r"^distributed\._composable\.fsdp\.test_fully_shard_training$"},
+ {"reason": "Sharded Tensor",
+ "file": r"^distributed\._composable\.fsdp\.test_fully_shard_clip_grad"},
+
+ # tensor_parallel
+ {"reason": "tensor_parallel",
+ "file": r"^distributed\.tensor\.parallel\."},
+
+ # pipeline_parallel
+ {"reason": "pipeline_parallel",
+ "file": r"^distributed\.pipelining\."},
+
+ # FSDP
+ {"reason": "FSDP",
+ "file": r"^distributed\.fsdp\."},
+ {"reason": "FSDP",
+ "file": r"^distributed\._composable\.fsdp\."},
+
+ # 2D FSDP / composability
+ {"reason": "2D FSDP",
+ "file": r"^distributed\._composable\.test_composability"},
+
+ # DDP / replicate
+ {"reason": "DDP",
+ "file": r"^distributed\._composable\.test_replicate"},
+
+ # Process Group / c10d
+ {"reason": "Process Group",
+ "file": r"^distributed\.test_c10d_"},
+
+ # PT2.0 - Distributed (dynamo_distributed)
+ {"reason": "PT2.0 - Distributed",
+ "file": r"^distributed\.test_dynamo_distributed$"},
+
+ # Collectives (tensor ops, composability, nccl)
+ {"reason": "Collectives",
+ "file": r"^distributed\.tensor\.test_"},
+ {"reason": "Collectives",
+ "file": r"^distributed\.test_composability$"},
+ {"reason": "Collectives",
+ "file": r"^distributed\.test_nccl$"},
+
+ # Distributed tools
+ {"reason": "Misc",
+ "file": r"^distributed\._tools\."},
+
+ # Distributed elastic
+ {"reason": "elastic",
+ "file": r"^distributed\.elastic\."},
+
+ # Distributed quantization
+ {"reason": "Quantization",
+ "file": r"^distributed\.algorithms\.quantization"},
+
+ # Distributed rpc
+ {"reason": "Misc",
+ "file": r"^distributed\.rpc\."},
+
+ # Distributed spawn
+ {"reason": "Misc",
+ "file": r"^distributed\.test_distributed_spawn"},
+
+ # Distributed (generic catch-all)
+ {"reason": "Misc",
+ "file": r"^distributed\."},
+
+ # ==================================================================
+ # TIER 5: Generic message fallbacks
+ # ==================================================================
+
+ # "Test is disabled" messages
+ {"reason": "Misc",
+ "msg": r"Test is disabled because an issue exists disabling it"},
+
+ # Generic skipIfRocm / skipCUDAIfRocm
+ {"reason": "Misc",
+ "msg": r"skipIfRocm.*doesn't currently work on the ROCm stack"},
+ {"reason": "Misc",
+ "msg": r"skipCUDAIfRocm.*doesn't currently work on the ROCm stack"},
+
+ # "Skipped!" / "Skipped"
+ {"reason": "Misc",
+ "msg": r"^Skipped!?$"},
+
+ # "Skipped on ROCm"
+ {"reason": "Misc",
+ "msg": r"^Skipped on ROCm$"},
+
+ # Not supported on ROCm (generic)
+ {"reason": "Will not be supported on ROCm",
+ "msg": r"Not supported on ROCm"},
+
+ # ==================================================================
+ # TIER 6: Catch-all for remaining test_cuda (no message, generic class)
+ # ==================================================================
+ {"reason": "Misc",
+ "file": r"^test_cuda$"},
+]
+
+
+def extract_message(raw_msg: str) -> str:
+ """Extract a clean message string from the raw CSV message_rocm value."""
+ if not raw_msg or raw_msg.strip() == '':
+ return ''
+ try:
+ d = ast.literal_eval(raw_msg)
+ if isinstance(d, dict):
+ return d.get('message', str(d))
+ except (ValueError, SyntaxError):
+ pass
+ return raw_msg.strip()
+
+
+def classify_test(msg: str, test_file: str, test_class: str, test_name: str,
+ workflow: str = '') -> str | None:
+ """Return the skip_reason for a test, or None if no rule matches."""
+ for rule in RULES:
+ match = True
+ if 'msg' in rule:
+ if not re.search(rule['msg'], msg, re.IGNORECASE):
+ match = False
+ if 'file' in rule and match:
+ if not re.search(rule['file'], test_file):
+ match = False
+ if 'cls' in rule and match:
+ if not re.search(rule['cls'], test_class):
+ match = False
+ if 'name' in rule and match:
+ if not re.search(rule['name'], test_name):
+ match = False
+ if 'workflow' in rule and match:
+ if workflow and workflow != rule['workflow']:
+ match = False
+ if match:
+ return rule['reason']
+ return None
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description='Auto-classify skip reasons for ROCm parity CSVs')
+ parser.add_argument('-i', '--input', required=True,
+ help='Input parity CSV file')
+ parser.add_argument('-o', '--output',
+ help='Output CSV with auto-classified skip_reason column')
+ parser.add_argument('--tsv-out',
+ help='Also write a TSV file in skip_reasons format '
+ '(compatible with --skip_reasons in summarize_xml_testreports.py)')
+ parser.add_argument('--only-unclassified', action='store_true',
+ help='Only classify tests that have no skip_reason (default)')
+ parser.add_argument('--reclassify-all', action='store_true',
+ help='Re-classify all tests, overwriting existing skip_reason')
+ parser.add_argument('--report', action='store_true',
+ help='Print classification report to stderr')
+ parser.add_argument('--dry-run', action='store_true',
+ help='Print report but do not write output files')
+ return parser.parse_args()
+
+
+def detect_columns(fieldnames):
+ """Detect whether CSV uses status_rocm/status_cuda or status_set1/status_set2."""
+ if 'status_rocm' in fieldnames:
+ return 'status_rocm', 'status_cuda', 'message_rocm'
+ elif 'status_set1' in fieldnames:
+ return 'status_set1', 'status_set2', 'message_set1'
+ else:
+ raise ValueError(f"Cannot detect status columns. Available: {fieldnames}")
+
+
+def main():
+ args = parse_args()
+
+ rows = []
+ with open(args.input, newline='') as f:
+ reader = csv.DictReader(f)
+ fieldnames = list(reader.fieldnames)
+ for row in reader:
+ rows.append(row)
+
+ col_rocm, col_cuda, col_msg = detect_columns(fieldnames)
+
+ for col in ('skip_reason', 'assignee', 'comments'):
+ if col not in fieldnames:
+ fieldnames.append(col)
+
+ classified_count = 0
+ already_had_count = 0
+ unclassified_count = 0
+ overwritten_count = 0
+ auto_reasons = Counter()
+ unclassified_msgs = Counter()
+ unclassified_files = Counter()
+ unclassified_details = []
+
+ tsv_entries = []
+
+ for row in rows:
+ status_rocm = row.get(col_rocm, '')
+ status_cuda = row.get(col_cuda, '')
+ existing_reason = row.get('skip_reason', '').strip()
+
+ needs_reason = (
+ status_rocm in ('SKIPPED', 'MISSED')
+ and status_cuda == 'PASSED'
+ )
+
+ if not needs_reason:
+ continue
+
+ raw_msg = row.get(col_msg, '')
+ msg = extract_message(raw_msg)
+ test_file = row.get('test_file', '')
+ test_class = row.get('test_class', '')
+ test_name = row.get('test_name', '')
+ workflow = row.get('test_config', '')
+
+ if existing_reason and not args.reclassify_all:
+ already_had_count += 1
+ tsv_entries.append({
+ 'test_file': test_file,
+ 'test_name': test_name,
+ 'test_class': test_class,
+ 'skip_reason': existing_reason,
+ 'assignee': row.get('assignee', ' '),
+ 'comments': row.get('comments', ' '),
+ })
+ continue
+
+ reason = classify_test(msg, test_file, test_class, test_name, workflow)
+
+ if reason:
+ if existing_reason and existing_reason != reason:
+ overwritten_count += 1
+ row['skip_reason'] = reason
+ row.setdefault('assignee', '')
+ row.setdefault('comments', 'auto-classified')
+ classified_count += 1
+ auto_reasons[reason] += 1
+ tsv_entries.append({
+ 'test_file': test_file,
+ 'test_name': test_name,
+ 'test_class': test_class,
+ 'skip_reason': reason,
+ 'assignee': row.get('assignee', ' ') if not args.reclassify_all else ' ',
+ 'comments': 'auto-classified',
+ })
+ else:
+ unclassified_count += 1
+ display_msg = msg[:100] if msg else '(no message)'
+ unclassified_msgs[display_msg] += 1
+ unclassified_files[test_file] += 1
+ unclassified_details.append(
+ f" {test_file:55s} {test_class:45s} {test_name[:40]:42s} {display_msg[:50]}")
+
+ if args.report or args.dry_run:
+ total = already_had_count + classified_count + unclassified_count
+ print(f"\n{'='*60}", file=sys.stderr)
+ print(f"AUTO-CLASSIFICATION REPORT", file=sys.stderr)
+ print(f"{'='*60}", file=sys.stderr)
+ print(f"Already had skip_reason: {already_had_count}", file=sys.stderr)
+ print(f"Auto-classified: {classified_count}", file=sys.stderr)
+ if overwritten_count:
+ print(f" (overwritten existing: {overwritten_count})", file=sys.stderr)
+ print(f"Still unclassified: {unclassified_count}", file=sys.stderr)
+ if total:
+ pct = (already_had_count + classified_count) / total * 100
+ print(f"Coverage: {pct:.1f}%", file=sys.stderr)
+ print(f"Total target tests: {total}", file=sys.stderr)
+
+ if auto_reasons:
+ print(f"\nAuto-classified by category:", file=sys.stderr)
+ for reason, cnt in auto_reasons.most_common():
+ print(f" {cnt:5d} {reason}", file=sys.stderr)
+
+ if unclassified_msgs:
+ print(f"\nUnclassified — top messages:", file=sys.stderr)
+ for msg_key, cnt in unclassified_msgs.most_common(15):
+ print(f" {cnt:5d} {msg_key}", file=sys.stderr)
+
+ if unclassified_files:
+ print(f"\nUnclassified — top files:", file=sys.stderr)
+ for f, cnt in unclassified_files.most_common(15):
+ print(f" {cnt:5d} {f}", file=sys.stderr)
+
+ if unclassified_details and len(unclassified_details) <= 50:
+ print(f"\nUnclassified tests:", file=sys.stderr)
+ for d in unclassified_details:
+ print(d, file=sys.stderr)
+
+ if args.dry_run:
+ return
+
+ if not args.output:
+ print("No --output specified; use --dry-run for report-only mode.",
+ file=sys.stderr)
+ sys.exit(1)
+
+ with open(args.output, 'w', newline='') as f:
+ writer = csv.DictWriter(f, fieldnames=fieldnames, extrasaction='ignore')
+ writer.writeheader()
+ for row in rows:
+ writer.writerow(row)
+
+ if args.tsv_out and tsv_entries:
+ with open(args.tsv_out, 'w', newline='') as f:
+ writer = csv.DictWriter(
+ f,
+ fieldnames=['test_file', 'test_name', 'test_class',
+ 'skip_reason', 'assignee', 'comments'],
+ delimiter='\t',
+ )
+ writer.writeheader()
+ for entry in tsv_entries:
+ writer.writerow(entry)
+ print(f"\nWrote TSV with {len(tsv_entries)} entries to {args.tsv_out}",
+ file=sys.stderr)
+
+ print(f"Wrote {len(rows)} rows to {args.output}", file=sys.stderr)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/.automation_scripts/pytorch-unit-test-scripts/detect_log_failures.py b/.automation_scripts/pytorch-unit-test-scripts/detect_log_failures.py
new file mode 100755
index 0000000000000..0156624c35973
--- /dev/null
+++ b/.automation_scripts/pytorch-unit-test-scripts/detect_log_failures.py
@@ -0,0 +1,518 @@
+#!/usr/bin/env python3
+"""Scan CI log files (.txt) for test failures not captured in XML reports.
+
+Tests that timeout (exit code 124), crash (SIGIOT, SIGSEGV, Fatal Python error),
+or are killed (SIGKILL, OOM) never produce JUnit XML output. This script detects
+those failures from the raw log files and outputs a CSV/summary.
+
+Usage:
+ python detect_log_failures.py --logs-dir [--output ]
+"""
+
+import argparse
+import csv
+import os
+import re
+import sys
+from collections import defaultdict
+from pathlib import Path
+
+
+RE_RUNNING = re.compile(
+ r"Running (?P\S+) (?P\d+)/(?P\d+) \.\.\."
+)
+RE_SUCCESS = re.compile(
+ r"(?P\S+) (?P\d+)/(?P\d+) was successful"
+)
+RE_FAILED = re.compile(
+ r"(?P\S+) (?P\d+)/(?P\d+) failed!(?P.*)"
+)
+RE_EXIT_CODE = re.compile(r"Got exit code (?P\d+)")
+RE_TIMEOUT = re.compile(r"Command took >(\d+)min, returning 124")
+RE_FAILED_CONSISTENTLY = re.compile(
+ r"FAILED CONSISTENTLY: (?P\S+)"
+)
+RE_STEPCURRENT = re.compile(
+ r"stepcurrent:.*Running only (?:test/)?(?P\S+)"
+)
+RE_INDIVIDUAL_TEST = re.compile(
+ r"(?P\S+\.py::(?P\w+)::(?P\w+))"
+)
+RE_INDIV_PASSED = re.compile(
+ r"(?:test/)?(?P\S+\.py)::(?P\w+)::(?P\S+?)\s+PASSED"
+)
+RE_NEW_PROCESS_SUCCESS = re.compile(r"Test succeeded in new process")
+
+CRASH_PATTERNS = [
+ (re.compile(r"Segmentation fault", re.IGNORECASE), "SEGFAULT"),
+ (re.compile(r"SIGSEGV"), "SIGSEGV"),
+ (re.compile(r"SIGIOT"), "SIGIOT"),
+ (re.compile(r"SIGABRT"), "SIGABRT"),
+ (re.compile(r"SIGKILL"), "SIGKILL"),
+ (re.compile(r"Fatal Python error", re.IGNORECASE), "FATAL_PYTHON"),
+ (re.compile(r"core dumped", re.IGNORECASE), "CORE_DUMP"),
+ (re.compile(r"Aborted \(core dumped\)", re.IGNORECASE), "ABORTED"),
+ (re.compile(r"torch\.cuda\.OutOfMemoryError"), "CUDA_OOM"),
+ (re.compile(r"std::bad_alloc"), "BAD_ALLOC"),
+]
+
+LOG_FILE_MAP = {
+ "rocm": ("rocm", "default"),
+ "rocm_dist": ("rocm", "distributed"),
+ "rocm_inductor": ("rocm", "inductor"),
+ "cuda": ("cuda", "default"),
+ "cuda_dist": ("cuda", "distributed"),
+ "cuda_inductor": ("cuda", "inductor"),
+ "baseline": ("baseline", "default"),
+}
+
+
+def classify_log_file(filename):
+ """Return (platform, test_config, shard_num) from a log filename like rocm3.txt."""
+ stem = Path(filename).stem
+ for prefix, (platform, test_config) in sorted(LOG_FILE_MAP.items(), key=lambda x: -len(x[0])):
+ if stem.startswith(prefix):
+ remainder = stem[len(prefix):]
+ if remainder.isdigit():
+ return platform, test_config, int(remainder)
+ return None, None, None
+
+
+RE_TIMESTAMP = re.compile(r"^\d{4}-\d{2}-\d{2}T[\d:.]+Z\s*")
+
+
+def parse_log_file(filepath):
+ """Parse a single log file and return test file results, consistent failures,
+ and flaky tests.
+
+ A flaky test is one that failed in its normal-process run but PASSED when the
+ CI harness re-ran it alone in a new subprocess (indicated by a PASSED line
+ for the specific test::class::method, followed by 'Test succeeded in new
+ process, continuing with the rest of the tests').
+ """
+ results = {}
+ current_test = None
+ last_failed_test = None
+ consistent_failures = []
+ flaky_tests = []
+ last_passed_individual = None
+
+ with open(filepath, "r", errors="replace") as f:
+ for line in f:
+ # Lightweight tracking of individual pytest test lines.
+ # These are very frequent (~37% of lines) so we extract the
+ # test name directly without timestamp stripping.
+ if ".py::" in line:
+ m_ind = RE_INDIVIDUAL_TEST.search(line)
+ if m_ind:
+ active = current_test or last_failed_test
+ if active and active in results:
+ # Only update if the pytest path belongs to this shard's test file,
+ # otherwise rerun output from earlier shards contaminates later ones.
+ shard_file = results[active]["test_file"]
+ if shard_file + ".py" in m_ind.group("test_path"):
+ results[active]["last_test"] = f"{m_ind.group('cls')}::{m_ind.group('method')}"
+
+ if " ... [" not in line and "was successful" not in line \
+ and "failed!" not in line and "Got exit code" not in line \
+ and "returning 124" not in line and "FAILED CONSISTENTLY" not in line \
+ and "Retrying" not in line \
+ and "Segmentation fault" not in line and "SIGIOT" not in line \
+ and "SIGSEGV" not in line and "SIGABRT" not in line \
+ and "SIGKILL" not in line \
+ and "Fatal Python error" not in line and "core dumped" not in line \
+ and "Aborted (core dumped)" not in line \
+ and "OutOfMemoryError" not in line \
+ and "bad_alloc" not in line \
+ and "stepcurrent" not in line \
+ and "PASSED" not in line \
+ and "new process" not in line:
+ continue
+
+ stripped = RE_TIMESTAMP.sub("", line).rstrip()
+
+ m = RE_RUNNING.search(stripped)
+ if m:
+ key = f"{m.group('test_file')} {m.group('shard')}/{m.group('total')}"
+ current_test = key
+ if key not in results:
+ results[key] = {
+ "test_file": m.group("test_file"),
+ "shard": int(m.group("shard")),
+ "total": int(m.group("total")),
+ "status": "RUNNING",
+ "reason": "",
+ "exit_codes": [],
+ "crashes": [],
+ "crash_tests": [],
+ "last_test": "",
+ }
+ continue
+
+ m = RE_SUCCESS.search(stripped)
+ if m:
+ key = f"{m.group('test_file')} {m.group('shard')}/{m.group('total')}"
+ if key in results:
+ results[key]["status"] = "PASSED"
+ current_test = None
+ last_failed_test = None
+ continue
+
+ m = RE_FAILED.search(stripped)
+ if m:
+ key = f"{m.group('test_file')} {m.group('shard')}/{m.group('total')}"
+ reason = m.group("reason").strip()
+ if key in results:
+ results[key]["status"] = "FAILED"
+ if reason:
+ results[key]["reason"] = reason
+ last_failed_test = key
+ current_test = key
+ continue
+
+ active = current_test or last_failed_test
+
+ # Track stepcurrent rerun lines — identifies crash-causing test
+ m = RE_STEPCURRENT.search(stripped)
+ if m:
+ test_path = m.group("test_path")
+ parts = test_path.split("::")
+ if len(parts) >= 3:
+ crash_id = f"{parts[1]}::{parts[2]}"
+ elif len(parts) == 2:
+ crash_id = parts[1]
+ else:
+ crash_id = None
+ if crash_id and active and active in results:
+ shard_file = results[active]["test_file"]
+ if shard_file in test_path:
+ if crash_id not in results[active]["crash_tests"]:
+ results[active]["crash_tests"].append(crash_id)
+ continue
+
+ # Track individual pytest test lines for last-running-test context
+ m_ind = RE_INDIVIDUAL_TEST.search(stripped)
+ if m_ind and active and active in results:
+ cls = m_ind.group("cls")
+ method = m_ind.group("method")
+ results[active]["last_test"] = f"{cls}::{method}"
+
+ m = RE_EXIT_CODE.search(stripped)
+ if m:
+ code = int(m.group("code"))
+ if active and active in results:
+ results[active]["exit_codes"].append(code)
+
+ m = RE_TIMEOUT.search(stripped)
+ if m and active and active in results:
+ if "TIMEOUT" not in results[active]["crashes"]:
+ results[active]["crashes"].append("TIMEOUT")
+
+ m = RE_FAILED_CONSISTENTLY.search(stripped)
+ if m:
+ shard_str = ""
+ if active and active in results:
+ info = results[active]
+ shard_str = f"{info['shard']}/{info['total']}"
+ consistent_failures.append((m.group("test_path"), shard_str))
+
+ # Detect individual PASSED lines for flaky-rerun tracking.
+ m = RE_INDIV_PASSED.search(stripped)
+ if m:
+ last_passed_individual = {
+ "file": m.group("file"),
+ "cls": m.group("cls"),
+ "method": m.group("method"),
+ "active": active,
+ }
+
+ # When we see 'Test succeeded in new process' after a PASSED
+ # individual test, that test was originally failing in the main
+ # process (CI only falls back to rerun-in-new-process for tests
+ # that crashed or failed) but passed on retry -> flaky.
+ if RE_NEW_PROCESS_SUCCESS.search(stripped) and last_passed_individual:
+ lp = last_passed_individual
+ lp_active = lp.get("active")
+ test_shard = ""
+ if lp_active and lp_active in results:
+ info = results[lp_active]
+ test_shard = f"{info['shard']}/{info['total']}"
+ flaky_tests.append({
+ "file": lp["file"],
+ "cls": lp["cls"],
+ "method": lp["method"],
+ "test_shard": test_shard,
+ })
+ last_passed_individual = None
+
+ if active and active in results:
+ for pattern, label in CRASH_PATTERNS:
+ if pattern.search(stripped):
+ if label not in results[active]["crashes"]:
+ results[active]["crashes"].append(label)
+
+ return results, consistent_failures, flaky_tests
+
+
+def scan_logs(logs_dir):
+ """Scan all log files and return non-passing test file results plus a
+ test-level shard inventory.
+
+ Returns (all_failures, shard_inventory) where shard_inventory is a list
+ of dicts with one entry per (platform, test_config, job_shard, test_file)
+ combination seen in the logs, plus a sorted comma-separated list of the
+ test-level shards observed (e.g. "1/1" or "1/15,2/15,...,15/15"). This
+ lets downstream consumers look up the test-level shard for any XML-based
+ failure whose only shard info is the job-level shard."""
+ all_failures = []
+ all_flaky = []
+ shard_map = defaultdict(set)
+
+ # Pre-compute job-level shard totals per (platform, test_config) by
+ # counting how many log files belong to each group. Log files are
+ # 1-indexed (e.g. rocm1.txt..rocm6.txt for a 6-way sharded job), so
+ # the count == total shards for that CI job.
+ shard_totals = defaultdict(int)
+ for fname in os.listdir(logs_dir):
+ if not fname.endswith(".txt"):
+ continue
+ platform, test_config, shard_num = classify_log_file(fname)
+ if platform is None:
+ continue
+ shard_totals[(platform, test_config)] += 1
+
+ for fname in sorted(os.listdir(logs_dir)):
+ if not fname.endswith(".txt"):
+ continue
+
+ platform, test_config, shard_num = classify_log_file(fname)
+ if platform is None:
+ continue
+
+ job_total = shard_totals.get((platform, test_config), 0)
+ job_shard_str = f"{shard_num}/{job_total}" if job_total else str(shard_num)
+
+ filepath = os.path.join(logs_dir, fname)
+ results, consistent_failures, flaky_tests = parse_log_file(filepath)
+
+ for ft in flaky_tests:
+ file_part = ft["file"].replace("test/", "").replace(".py", "")
+ all_flaky.append({
+ "log_file": fname,
+ "platform": platform,
+ "test_config": test_config,
+ "test_file": file_part,
+ "test_class": ft["cls"],
+ "test_name": ft["method"],
+ "job_shard": job_shard_str,
+ "test_shard": ft["test_shard"],
+ })
+
+ # Record every (test_file, test_shard) observed in this log file,
+ # including PASSED ones, so the inventory covers the full run.
+ for info in results.values():
+ shard_map[(platform, test_config, job_shard_str, info["test_file"])].add(
+ f"{info['shard']}/{info['total']}"
+ )
+
+ for key, info in results.items():
+ if info["status"] == "PASSED":
+ continue
+
+ categories = []
+ if 124 in info["exit_codes"] or "TIMEOUT" in info["crashes"]:
+ categories.append("TIMEOUT")
+ for c in info["crashes"]:
+ if c != "TIMEOUT":
+ categories.append(c)
+ if info["status"] == "FAILED" and not categories:
+ categories.append("FAILED")
+ if info["status"] == "RUNNING" and not categories:
+ categories.append("INCOMPLETE")
+
+ if not categories:
+ continue
+ # Skip tests stuck in RUNNING with no evidence of failure —
+ # these are typically from multi-shard logs where a different
+ # shard's "Running ..." line appeared but the result was elsewhere.
+ if info["status"] == "RUNNING" and categories == ["INCOMPLETE"]:
+ continue
+
+ reason = info["reason"]
+ # Populate reason with identified crash/timeout test name
+ crash_tests = info.get("crash_tests", [])
+ last_test = info.get("last_test", "")
+ identified_test = ""
+ if crash_tests:
+ identified_test = crash_tests[0]
+ elif last_test:
+ identified_test = last_test
+
+ if identified_test and "::" in identified_test:
+ if not reason:
+ reason = identified_test
+ elif "::" not in reason:
+ reason = f"{identified_test} | {reason}"
+
+ all_failures.append({
+ "log_file": fname,
+ "platform": platform,
+ "test_config": test_config,
+ "test_file": info["test_file"],
+ "job_shard": job_shard_str,
+ "test_shard": f"{info['shard']}/{info['total']}",
+ "status": info["status"],
+ "category": "+".join(categories),
+ "reason": reason,
+ "exit_codes": ",".join(str(c) for c in info["exit_codes"]),
+ })
+
+ for test_path, shard_str in consistent_failures:
+ parts = test_path.split("::")
+ file_part = parts[0].replace("test/", "").replace(".py", "")
+ test_class = parts[1] if len(parts) > 1 else ""
+ test_name = parts[2] if len(parts) > 2 else ""
+
+ all_failures.append({
+ "log_file": fname,
+ "platform": platform,
+ "test_config": test_config,
+ "test_file": file_part,
+ "job_shard": job_shard_str,
+ "test_shard": shard_str,
+ "status": "FAILED_CONSISTENTLY",
+ "category": "CONSISTENT_FAILURE",
+ "reason": f"{test_class}::{test_name}" if test_class else "",
+ "exit_codes": "",
+ })
+
+ def _sort_shards(vals):
+ def key(v):
+ try:
+ a, b = v.split("/", 1)
+ return (int(b), int(a))
+ except (ValueError, AttributeError):
+ return (0, 0)
+ return sorted(vals, key=key)
+
+ shard_inventory = [
+ {
+ "platform": platform,
+ "test_config": test_config,
+ "job_shard": job_shard_str,
+ "test_file": test_file,
+ "test_shards": ",".join(_sort_shards(shards)),
+ }
+ for (platform, test_config, job_shard_str, test_file), shards in shard_map.items()
+ ]
+ shard_inventory.sort(key=lambda r: (r["platform"], r["test_config"],
+ r["job_shard"], r["test_file"]))
+
+ all_flaky.sort(key=lambda r: (r["platform"], r["test_config"],
+ r["job_shard"], r["test_file"],
+ r["test_class"], r["test_name"]))
+
+ return all_failures, shard_inventory, all_flaky
+
+
+def write_csv_report(failures, output_path):
+ fieldnames = [
+ "log_file", "platform", "test_config", "test_file",
+ "job_shard", "test_shard",
+ "status", "category", "reason", "exit_codes",
+ ]
+ with open(output_path, "w", newline="") as f:
+ writer = csv.DictWriter(f, fieldnames=fieldnames)
+ writer.writeheader()
+ writer.writerows(failures)
+ print(f"Log failure report: {output_path} ({len(failures)} entries)")
+
+
+def write_shards_report(inventory, output_path):
+ fieldnames = ["platform", "test_config", "job_shard", "test_file", "test_shards"]
+ with open(output_path, "w", newline="") as f:
+ writer = csv.DictWriter(f, fieldnames=fieldnames)
+ writer.writeheader()
+ writer.writerows(inventory)
+ print(f"Log shard inventory: {output_path} ({len(inventory)} entries)")
+
+
+def write_flaky_report(flaky, output_path):
+ fieldnames = [
+ "log_file", "platform", "test_config", "test_file",
+ "test_class", "test_name", "job_shard", "test_shard",
+ ]
+ with open(output_path, "w", newline="") as f:
+ writer = csv.DictWriter(f, fieldnames=fieldnames)
+ writer.writeheader()
+ writer.writerows(flaky)
+ print(f"Flaky test report: {output_path} ({len(flaky)} entries)")
+
+
+def _derive_sibling_path(output_path, new_prefix):
+ """Given an output path like '.../log_failures_mi355.csv' and
+ new_prefix='log_shards', return '.../log_shards_mi355.csv'. Falls back to
+ appending '.{new_prefix}.csv' if the expected prefix isn't present."""
+ d, base = os.path.split(output_path)
+ if base.startswith("log_failures"):
+ return os.path.join(d, new_prefix + base[len("log_failures"):])
+ stem, ext = os.path.splitext(base)
+ return os.path.join(d, f"{stem}.{new_prefix}{ext or '.csv'}")
+
+
+def _derive_shards_path(output_path):
+ return _derive_sibling_path(output_path, "log_shards")
+
+
+def _derive_flaky_path(output_path):
+ return _derive_sibling_path(output_path, "flaky_tests")
+
+
+def print_summary(failures):
+ if not failures:
+ print("No log-based failures detected.")
+ return
+
+ by_category = defaultdict(list)
+ for f in failures:
+ by_category[f["category"]].append(f)
+
+ print(f"\n{'='*60}")
+ print("LOG FAILURE DETECTION SUMMARY")
+ print(f"{'='*60}")
+ print(f"Total failures detected: {len(failures)}")
+ print()
+
+ for cat, items in sorted(by_category.items()):
+ print(f" {cat}: {len(items)}")
+ for item in items:
+ print(f" - {item['test_file']} ({item['platform']}/{item['test_config']}) [{item['log_file']}]")
+ if item["reason"]:
+ print(f" Reason: {item['reason'][:120]}")
+ print()
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description="Detect test failures from CI log files not captured in XML reports"
+ )
+ parser.add_argument(
+ "--logs-dir", required=True,
+ help="Directory containing .txt log files"
+ )
+ parser.add_argument(
+ "--output", default="log_failures.csv",
+ help="Output CSV path (default: log_failures.csv)"
+ )
+ args = parser.parse_args()
+
+ failures, shard_inventory, flaky_tests = scan_logs(args.logs_dir)
+ print_summary(failures)
+ write_csv_report(failures, args.output)
+ write_shards_report(shard_inventory, _derive_shards_path(args.output))
+ write_flaky_report(flaky_tests, _derive_flaky_path(args.output))
+ return 0 if not failures else 1
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/.automation_scripts/pytorch-unit-test-scripts/download_testlogs b/.automation_scripts/pytorch-unit-test-scripts/download_testlogs
new file mode 100755
index 0000000000000..ac4214f99fecd
--- /dev/null
+++ b/.automation_scripts/pytorch-unit-test-scripts/download_testlogs
@@ -0,0 +1,1074 @@
+#!/usr/bin/env python3
+
+
+try:
+ import os
+ import json
+ import argparse
+ import requests
+ import re
+ import sys
+ from upload_stats_lib import unzip
+ from upload_test_stats import download_gha_artifacts, download_s3_artifacts
+except ImportError:
+ import subprocess
+ result = subprocess.run(["pip3", "install", "-U", "-r", "requirements.txt"], capture_output=True, text=True)
+ print(result.stdout)
+ print("Please rerun the download_testlogs script")
+ sys.exit(1)
+
+
+# Check if environment variables are set
+required_env_vars = ['GITHUB_TOKEN', 'AWS_ACCESS_KEY_ID', 'AWS_SECRET_ACCESS_KEY']
+
+missing_vars = [var for var in required_env_vars if not os.getenv(var)]
+if missing_vars:
+ print(f"ERROR: Please set these environment variables: {', '.join(missing_vars)}")
+ sys.exit(1)
+
+
+# global variables
+error_msgs = []
+# Workflow names mapped to TEST_CONFIG values in PyTorch CI
+# These are set dynamically based on --arch argument in main()
+ROCmWorkflowNames = {}
+CUDAWorkflowNames = {"default": "trunk",
+ # Same as default, so not used for now
+ # "distributed": "pull",
+ "inductor": "inductor"}
+
+authentication_headers = None
+
+def get_commit_hashes(pr_id, token):
+ owner = "pytorch"
+ repo = "pytorch"
+ commits_url = f"https://api.github.com/repos/{owner}/{repo}/pulls/{pr_id}/commits"
+ headers = {
+ "Authorization": f"token {token}",
+ "Accept": "application/vnd.github.v3+json"
+ }
+ page = 1
+ commits = []
+ while True:
+ response = requests.get(commits_url, headers=headers, params={'page': page})
+ if response.status_code == 200:
+ new_commits = response.json()
+ if not new_commits:
+ break
+ commits.extend(new_commits)
+ page += 1
+ else:
+ print(f"Failed to fetch commits: {response.status_code}")
+ break
+ return commits
+
+def get_latest_commit_sha(pr_id, token):
+ commits = get_commit_hashes(pr_id, token)
+ if commits:
+ return commits[-1]['sha']
+ else:
+ print("No commits found for the given pull request.")
+ sys.exit(1)
+
+def write_test_log_to_file(filename, test_key, jobs, sha):
+ js = [j for j in jobs if test_key in j['name']]
+ if len(js) > 0:
+ if len(js) > 1:
+ print(f"WARNING: Found multiple jobs with key: '{test_key}', selecting first one")
+ for j in js:
+ print(j['name'])
+ test_id = js[0]['id']
+ print(f"key: {test_key}, job Name: {js[0]['name']}, job ID: {test_id}, Downloading to {filename}")
+ else:
+ # Not being able to download logs is not a fatal error since we primarily depend on xml artifacts
+ # so log error and continue
+ error_msg = f"Error: TEST KEY: {test_key} DOES NOT EXIST IN JOBS.\nCheck url - https://hud.pytorch.org/hud/pytorch/pytorch/{sha}/1?per_page=50 - for job name"
+ print(error_msg)
+ error_msgs.append(error_msg)
+ return
+ response = requests.get( "https://ossci-raw-job-status.s3.amazonaws.com/log/" + str(test_id) )
+ with open(filename, "w", encoding="utf-8") as f:
+ f.write(response.text)
+
+def get_workflow_jobs(wf):
+ """Get all jobs for a workflow run."""
+ if wf is None:
+ raise Exception("wf is None!")
+ page_size = 100 #max allowed by Github API
+ response = requests.get( wf['jobs_url'], headers=authentication_headers, params={'per_page':page_size} )
+ response_json = response.json()
+ jobs = response_json["jobs"]
+
+ if response_json['total_count'] > page_size:
+ import math
+ for i in range(2, math.ceil(response_json['total_count']/page_size) + 1):
+ response = requests.get( wf['jobs_url'], headers=authentication_headers, params={'per_page':page_size, 'page':i} )
+ jobs += response.json()["jobs"]
+ return jobs
+
+def get_check_runs_for_commit(sha, prefix):
+ """Get check runs for a commit filtered by name prefix.
+
+ The workflow jobs API does not return jobs from reusable workflows
+ (workflow_call). The check-runs API returns all jobs regardless of
+ workflow nesting, so we use it as a fallback.
+ """
+ check_runs = []
+ page = 1
+ while True:
+ response = requests.get(
+ f"https://api.github.com/repos/pytorch/pytorch/commits/{sha}/check-runs",
+ headers=authentication_headers,
+ params={'per_page': 100, 'page': page},
+ )
+ data = response.json()
+ runs = data.get('check_runs', [])
+ check_runs.extend([cr for cr in runs if prefix in cr.get('name', '')])
+ if len(runs) < 100:
+ break
+ page += 1
+ return check_runs
+
+def get_job_ids_by_prefix(wf, prefix):
+ """Get job IDs (as strings) for jobs whose name contains the given prefix."""
+ jobs = get_workflow_jobs(wf)
+ return [str(j['id']) for j in jobs if prefix in j['name']]
+
+def matches_job_prefix(job_name, prefix):
+ """Match the exact CUDA job family without also matching -debug/-sm86 jobs."""
+ return job_name.startswith(f"{prefix} / ")
+
+def get_cuda_test_jobs(jobs, cuda_job_prefix):
+ """Return the CUDA test kind and jobs for either main or PR CI layouts."""
+ for test_kind in ("test-osdc", "test"):
+ test_jobs = [
+ j for j in jobs
+ if matches_job_prefix(j['name'], cuda_job_prefix)
+ and f"/ {test_kind} (" in j['name']
+ ]
+ if test_jobs:
+ return test_kind, test_jobs
+ return "test-osdc", []
+
+def get_cuda_inductor_test_jobs(jobs):
+ """Return the CUDA inductor test kind and jobs for either main or PR CI layouts."""
+ for test_kind in ("test-osdc", "test"):
+ test_jobs = [
+ j for j in jobs
+ if f"unit-test / inductor-test / {test_kind} (inductor," in j['name']
+ ]
+ if test_jobs:
+ return test_kind, test_jobs
+ return "test-osdc", []
+
+def download_logs(wf, test_log_list, test_folder, jobs=None):
+ if wf is None:
+ raise Exception("wf is None!")
+
+ if jobs is None:
+ jobs = get_workflow_jobs(wf)
+
+ for test_log in test_log_list:
+ write_out_file = test_folder + "/" + test_log[0]
+ write_test_log_to_file(write_out_file, test_log[1], jobs, wf['head_sha'])
+
+def download_gha_artifacts_filtered(workflow_run_id, workflow_run_attempt, prefixes=[], allowed_substrings=None):
+ """Download GHA artifacts matching prefixes and optional substring filters.
+
+ GHA artifact names include run attempt info, e.g.:
+ test-reports-runattempt1-test-default-3-6-linux.rocm.gpu.gfx942.1_68425162477.zip
+ while S3 prefixes look like:
+ test-reports-test-default-3-6
+ We strip the runattemptN- portion before matching prefixes.
+
+ When a shard is re-run, only the latest attempt's artifact exists for that
+ shard, while other shards keep their original attempt. We collect all
+ matching artifacts and prefer the highest run attempt per shard key.
+ """
+ from pathlib import Path
+ from collections import defaultdict
+ artifact_paths = []
+ response = requests.get(
+ f"https://api.github.com/repos/pytorch/pytorch/actions/runs/{workflow_run_id}/artifacts?per_page=100",
+ headers=authentication_headers,
+ )
+ artifacts = response.json().get("artifacts", [])
+ while "next" in response.links:
+ response = requests.get(response.links["next"]["url"], headers=authentication_headers)
+ artifacts.extend(response.json().get("artifacts", []))
+
+ # Group matching artifacts by shard key, keeping highest run attempt
+ # shard key = normalized name without runattemptN- and without runner/jobid suffix
+ best_per_shard = {}
+ for artifact in artifacts:
+ name = artifact["name"]
+ if not name.startswith("test-reports-"):
+ continue
+ if "rerun_disabled" in name:
+ continue
+ normalized = re.sub(r'runattempt\d+-', '', name)
+ if not any(normalized.startswith(pfx) for pfx in prefixes):
+ continue
+ if allowed_substrings and not any(sub in name for sub in allowed_substrings):
+ continue
+ # Extract run attempt number
+ attempt_match = re.search(r'runattempt(\d+)', name)
+ attempt_num = int(attempt_match.group(1)) if attempt_match else 0
+ # Use the shard portion as key (e.g., test-reports-test-default-3-6)
+ shard_key = re.sub(r'-[a-z]+\..*$', '', normalized)
+ if shard_key not in best_per_shard or attempt_num > best_per_shard[shard_key][0]:
+ best_per_shard[shard_key] = (attempt_num, name, artifact["archive_download_url"])
+
+ for shard_key, (attempt_num, name, url) in best_per_shard.items():
+ print(f"Downloading GHA artifact: {name}")
+ dl_response = requests.get(url, headers=authentication_headers)
+ if dl_response.status_code != 200:
+ print(f" WARNING: Failed to download (HTTP {dl_response.status_code})")
+ continue
+ p = Path(name if name.endswith(".zip") else name + ".zip")
+ with open(p, "wb") as f:
+ f.write(dl_response.content)
+ artifact_paths.append(p)
+
+ return artifact_paths
+
+def _shorten_unzipped_dirs():
+ """Rename unzipped-* directories to short names for Windows MAX_PATH compatibility.
+
+ Converts names like:
+ unzipped-test-reports-runattempt1-test-default-1-6-linux.rocm.gpu.gfx942.1_68613413431.zip
+ unzipped-test-reports-runattempt1-test-osdc-default-1-5-mt-l-x86aavx2-29-113-l4_73385044118.zip
+ to:
+ test-default-1-6
+ test-default-1-5
+
+ Preserves the 'test-' prefix so that summarize_xml_testreports.py
+ can still detect workflow type via substring matching.
+ """
+ from pathlib import Path
+ for d in sorted(Path(".").glob("unzipped-*")):
+ if not d.is_dir():
+ continue
+ m = re.search(r'test-(?:osdc-)?(default|distributed|inductor)-(\d+)-(\d+)', d.name)
+ if m:
+ short_name = f"test-{m.group(1)}-{m.group(2)}-{m.group(3)}"
+ if not Path(short_name).exists():
+ d.rename(short_name)
+ print(f" Renamed {d.name} -> {short_name}")
+ else:
+ print(f" WARNING: {short_name} already exists, keeping {d.name}")
+
+def download_xml_files(workflow_run_id, workflow_run_attempts, prefixes=[], allowed_substrings=None):
+ # Get from S3 artifacts
+ artifact_paths = []
+ for prefix in prefixes:
+ print("Trying to download S3 artifacts for workflow_run_attempt {} with prefix {}".format(workflow_run_attempts, prefix))
+ artifact_paths += download_s3_artifacts(
+ prefix,
+ workflow_run_id,
+ workflow_run_attempts,
+ allowed_substrings=allowed_substrings,
+ )
+
+ # Filter out rerun_disabled_tests artifacts (same prefix, different job)
+ before = len(artifact_paths)
+ artifact_paths = [p for p in artifact_paths if "rerun_disabled" not in p.name]
+ if before != len(artifact_paths):
+ print(f" Filtered out {before - len(artifact_paths)} rerun_disabled artifacts")
+
+ # Fall back to GHA artifacts if S3 returned nothing
+ if len(artifact_paths) == 0:
+ print(f"No S3 artifacts found, trying GHA artifacts as fallback...")
+ artifact_paths = download_gha_artifacts_filtered(
+ workflow_run_id,
+ workflow_run_attempts,
+ prefixes=prefixes,
+ allowed_substrings=allowed_substrings,
+ )
+
+ if len(artifact_paths) == 0:
+ error_msg = f"WARNING: workflow run id: {workflow_run_id} - no artifacts found (S3 or GHA) for prefixes: {prefixes}"
+ print(error_msg)
+ error_msgs.append(error_msg)
+ return
+
+ for path in artifact_paths:
+ unzip(path)
+
+ _shorten_unzipped_dirs()
+
+ # Delete raw zip files now that contents are extracted
+ for path in artifact_paths:
+ try:
+ path.unlink()
+ print(f" Deleted {path}")
+ except Exception:
+ pass
+
+def download_artifacts(wf, prefixes=[], test_folder=".", allowed_substrings=None):
+ os.chdir(test_folder)
+ #download the xml files
+ download_xml_files(
+ wf['id'],
+ wf.get('run_attempt',1),
+ prefixes,
+ allowed_substrings=allowed_substrings,
+ )
+ os.chdir("..")
+# for older runs, add 'created':'<=YYYY-MM-DD'. see https://docs.github.com/en/search-github/getting-started-with-searching-on-github/understanding-the-search-syntax#query-for-dates
+def download_workflow_run(created=None, max_pages=10, workflow=None, sha=None, ignore_status=False, status='success', error_msg='Error downloading workflow runs'):
+ if not workflow:
+ raise Exception("Workflow must be specified")
+ for page in range(max_pages):
+ params = {'per_page': 30, 'page': page}
+ if not ignore_status:
+ if status:
+ params['status'] = status
+ if created:
+ params['created'] = created
+ if sha:
+ params['head_sha'] = sha
+ else:
+ params['branch'] = "main"
+ print(".")
+
+ # Uncomment below for additional debug info
+ # print(f"authentication_headers: {authentication_headers}")
+ # print(f"params: {params}")
+ # print("https://api.github.com/repos/pytorch/pytorch/actions/workflows/{}.yml/runs".format(workflow))
+ response = requests.get("https://api.github.com/repos/pytorch/pytorch/actions/workflows/{}.yml/runs".format(workflow), headers=authentication_headers, params=params)
+ #print(response.json())
+ workflow_runs = None
+ try:
+ workflow_runs = response.json()['workflow_runs']
+ except:
+ raise Exception(response.text)
+ if not workflow_runs:
+ continue
+ # Prefer completed runs over in-progress ones. When multiple
+ # runs exist for the same SHA, the most recent may still be
+ # running and have no artifacts yet.
+ completed = [wf for wf in workflow_runs if wf.get('status') == 'completed']
+ if completed:
+ return completed[0]
+ return workflow_runs[0]
+
+ # Should not reach here ideally
+ raise Exception(error_msg)
+
+def create_test_folder(wf):
+ if wf is None:
+ raise Exception("wf is None!")
+ #return
+ test_folder = re.sub('T.*Z', '', wf['created_at'].replace(":", "").replace("-", "")) + "_" + wf['head_sha']
+ if not os.path.exists(test_folder):
+ os.mkdir(test_folder)
+
+ cuda_xml_folder = test_folder + "/cuda_xml"
+ if not os.path.exists(cuda_xml_folder):
+ os.mkdir(cuda_xml_folder)
+
+ rocm_xml_folder = test_folder + "/rocm_xml"
+ if not os.path.exists(rocm_xml_folder):
+ os.mkdir(rocm_xml_folder)
+ return [test_folder, cuda_xml_folder, rocm_xml_folder]
+
+_first_folder = None
+
+def get_or_create_test_folder(wf):
+ """Reuse the first folder created so all artifacts land in one place.
+
+ Different upstream workflows for the same SHA can have different created_at
+ dates (e.g. spanning midnight), which would cause create_test_folder to
+ create separate directories. This wrapper ensures every call returns the
+ same folder that was established by the very first invocation.
+ """
+ global _first_folder
+ if _first_folder is not None:
+ test_folder = _first_folder
+ cuda_xml = test_folder + "/cuda_xml"
+ rocm_xml = test_folder + "/rocm_xml"
+ os.makedirs(cuda_xml, exist_ok=True)
+ os.makedirs(rocm_xml, exist_ok=True)
+ return [test_folder, cuda_xml, rocm_xml]
+ result = create_test_folder(wf)
+ _first_folder = result[0]
+ return result
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='Download pytorch unit test logs')
+ parser.add_argument('--created', const=None, help='eg., \'<=YYYY-MM-DD\'. See https://docs.github.com/en/search-github/getting-started-with-searching-on-github/understanding-the-search-syntax#query-for-dates')
+ parser.add_argument('--max_pages', type=int, default=10, help='eg., 100')
+ parser.add_argument('--sha1', const=None, help='eg., 3dcd67a1b374faea01f4d2e17beb6bb1fff76d76')
+ parser.add_argument('--exclude_distributed', action='store_true')
+ parser.add_argument('--exclude_inductor', action='store_true')
+ parser.add_argument('--exclude_default', action='store_true')
+ parser.add_argument('--ignore_status', action='store_true')
+ parser.add_argument('--artifacts_only', action='store_true')
+ parser.add_argument('--no_rocm', action='store_true')
+ parser.add_argument('--no_cuda', action='store_true')
+ parser.add_argument('--pr_id', type=int, help='The pull request ID')
+ parser.add_argument('--arch', type=str, choices=['mi200', 'mi300', 'mi355', 'navi31', 'nightly'], default='mi355', help='ROCm GPU architecture (mi200, mi300, mi355, navi31, or nightly, default: mi355)')
+ parser.add_argument('--include_inductor_periodic', action='store_true', help='Also download inductor-periodic benchmark artifacts (into a separate directory, not included in parity CSV)')
+ parser.add_argument('--baseline_sha', type=str, help='Baseline commit SHA to compare against. Downloads the same ROCm workflows for this commit into baseline_xml/.')
+ return parser.parse_args()
+
+# Rate-limit issues
+# Authenticated users get 5000 requests/day
+# Check rate-limit without penalty: curl -H "Authorization: token $GITHUB_TOKEN" -I https://api.github.com/users/octocat
+
+def main():
+ global args
+ args = parse_args()
+ if args.max_pages < 1:
+ args.max_pages=1
+
+ # Set ROCm workflow names based on architecture
+ global ROCmWorkflowNames
+ arch = args.arch # 'mi200', 'mi300', 'mi355', 'navi31', or 'nightly'
+ if arch == 'nightly':
+ ROCmWorkflowNames = {
+ "default": "rocm-nightly",
+ "distributed": "rocm-nightly",
+ "inductor": "rocm-nightly",
+ }
+ elif arch == 'mi355':
+ ROCmWorkflowNames = {
+ "default": "trunk",
+ "distributed": "periodic-rocm-mi355",
+ "inductor": "inductor-rocm-mi355"
+ }
+ elif arch == 'mi200':
+ ROCmWorkflowNames = {
+ "default": "rocm-mi200",
+ "distributed": "periodic-rocm-mi200",
+ "inductor": "inductor-rocm-mi200"
+ }
+ else:
+ # MI300 and navi31 use dedicated ROCm workflows
+ ROCmWorkflowNames = {
+ "default": f"rocm-{arch}",
+ "distributed": f"periodic-rocm-{arch}",
+ "inductor": f"inductor-rocm-{arch}"
+ }
+ # Job key prefix for log downloads - architecture specific
+ # MI200 uses older jammy/py3.10 config, MI300 uses noble/py3.12
+ # Inductor jobs have a different naming format
+ rocm_job_prefixes = {
+ "nightly": {
+ "default": "linux-noble-rocm-nightly-py3.12-gfx942",
+ "distributed": "linux-noble-rocm-nightly-py3.12-gfx942",
+ "inductor": "linux-noble-rocm-nightly-py3.12-gfx942",
+ },
+ "mi200": {
+ "default": "linux-jammy-rocm-py3.10-mi200",
+ "distributed": "linux-jammy-rocm-py3.10-mi200",
+ "inductor": "linux-jammy-rocm-py3.10-mi200"
+ },
+ "mi300": {
+ "default": "linux-noble-rocm-py3.12-mi300",
+ "distributed": "linux-noble-rocm-py3.12-mi300",
+ "inductor": "linux-noble-rocm-py3.12-mi300"
+ },
+ "mi355": {
+ "default": "linux-jammy-rocm-py3.10-mi355",
+ "distributed": "linux-noble-rocm-py3.12-mi355",
+ "inductor": "linux-noble-rocm-py3.12-mi355"
+ },
+ "navi31": {
+ "default": "linux-jammy-rocm-py3.10-navi31",
+ "distributed": "linux-jammy-rocm-py3.10-navi31",
+ "inductor": "linux-jammy-rocm-py3.10-navi31"
+ }
+ }
+ # Architecture-specific shard counts
+ rocm_shard_counts = {
+ "nightly": {"default": 6, "distributed": 3, "inductor": 2},
+ "mi200": {"default": 6, "distributed": 3, "inductor": 2},
+ "mi300": {"default": 6, "distributed": 3, "inductor": 2},
+ "mi355": {"default": 10, "distributed": 4, "inductor": 2}, #trunk
+ "navi31": {"default": 2, "distributed": 3, "inductor": 2},
+ }
+ rocm_job_prefix = rocm_job_prefixes[arch]
+ rocm_shards = rocm_shard_counts[arch]
+ rocm_artifact_substrings = ["rocm.gpu"] if arch in ("mi355", "nightly") else None
+ # navi31 only has default tests (no distributed/inductor workflows)
+ if arch in ("navi31",):
+ if not args.exclude_distributed:
+ print(f"NOTE: {arch} has no distributed workflow, auto-excluding distributed")
+ args.exclude_distributed = True
+ if not args.exclude_inductor:
+ print(f"NOTE: {arch} has no inductor workflow, auto-excluding inductor")
+ args.exclude_inductor = True
+ if args.baseline_sha and not args.no_cuda:
+ print("NOTE: baseline_sha provided, auto-skipping CUDA (commit-vs-commit comparison)")
+ args.no_cuda = True
+
+ print(f"Using ROCm architecture: {arch}")
+ print(f"Using ROCm workflows: {ROCmWorkflowNames}")
+ if not args.no_cuda:
+ print(f"Using CUDA workflows: {CUDAWorkflowNames}")
+ print(f"Using ROCm job prefixes: {rocm_job_prefix}")
+ print(f"Using initial ROCm shard counts (may be updated based on actual workflow used): {rocm_shards}")
+
+ token = os.getenv('GITHUB_TOKEN', '...')
+ global authentication_headers
+ authentication_headers = {'Authorization': f'token {token}'}
+ if (args.pr_id and args.sha1) or (not args.pr_id and not args.sha1):
+ error_msg = "Error: Please provide either pr_id or sha!"
+ print(error_msg)
+ sys.exit(1)
+ if args.pr_id:
+ pr_id = args.pr_id
+ sha = get_latest_commit_sha(pr_id, token)
+ else:
+ sha = args.sha1
+ pr_id = None
+ status = "success"
+ print(f"sha: {sha}")
+
+ # When comparing two commits, prefix log filenames with short SHAs
+ if args.baseline_sha:
+ current_prefix = sha[:8] + "_"
+ baseline_prefix = args.baseline_sha[:8] + "_"
+ else:
+ current_prefix = ""
+ baseline_prefix = "baseline_"
+
+ if not args.exclude_distributed and not args.no_rocm:
+ periodic_sha = sha
+ print("==============================================")
+ print(f"Finding ROCm distributed tests in workflow '{ROCmWorkflowNames['distributed']}' by sha: {sha}")
+ print("==============================================")
+ # find distributed test in periodic workflow with success status
+ error_msg="Error: Periodic workflow not found in scanned workflow runs."
+ #https://docs.github.com/en/rest/actions/workflow-runs#list-workflow-runs-for-a-repository
+ periodic_fallback_used = False
+ try:
+ periodic_wf = download_workflow_run(created=args.created, max_pages=args.max_pages, workflow=ROCmWorkflowNames["distributed"], sha=periodic_sha, ignore_status=args.ignore_status, status=status, error_msg=error_msg)
+ except (IndexError, Exception):
+ periodic_wf = None
+ periodic_fallbacks = {
+ "mi355": ("trunk", "linux-jammy-rocm-py3.10-mi355"),
+ "mi200": ("trunk-rocm-sandbox", "linux-jammy-rocm-py3.10"),
+ }
+ if periodic_wf is None and arch in periodic_fallbacks:
+ fallback_wf, fallback_prefix = periodic_fallbacks[arch]
+ print(f"Distributed not found in {ROCmWorkflowNames['distributed']}, falling back to {fallback_wf}")
+ periodic_wf = download_workflow_run(created=args.created, max_pages=args.max_pages, workflow=fallback_wf, sha=periodic_sha, ignore_status=args.ignore_status, status=status, error_msg=error_msg)
+ periodic_fallback_used = True
+ if periodic_wf is None:
+ raise Exception(error_msg)
+ dist_wf_name = ROCmWorkflowNames['distributed'] if not periodic_fallback_used else periodic_fallbacks[arch][0]
+ print(f"Using workflow '{dist_wf_name}' with id:{periodic_wf['id']} for ROCm distributed")
+
+ if periodic_fallback_used and arch in periodic_fallbacks:
+ dist_job_prefix = periodic_fallbacks[arch][1]
+ else:
+ dist_job_prefix = rocm_job_prefix['distributed']
+
+ folder_list = get_or_create_test_folder(periodic_wf)
+
+ # Download logs
+ # If the ROCm distributed logs aren't found you might want to check the HUD for the correct tags
+ # HUD link: https://hud.pytorch.org/hud/pytorch/pytorch/main/1?per_page=50&name_filter=rocm
+ # Make sure "Hide unstable jobs" is unselected, in case ROCm jobs are marked as unstable
+
+ if arch == "mi355":
+ dist_shards = 3 if not periodic_fallback_used else rocm_shards["distributed"]
+ else:
+ dist_shards = rocm_shards["distributed"]
+ print(f"Using final ROCm shard count {dist_shards} for distributed")
+
+ if not args.artifacts_only:
+ test_log_list_rocm_distributed = [
+ [f"{current_prefix}rocm_dist{i}.txt", f"{dist_job_prefix} / test (distributed, {i}, {dist_shards}"]
+ for i in range(1, dist_shards + 1)
+ ]
+ download_logs(periodic_wf, test_log_list_rocm_distributed, folder_list[0])
+
+ # Download artifacts
+ test_artifacts_list_rocm_distributed = [
+ f"test-reports-test-distributed-{i}-{dist_shards}"
+ for i in range(1, dist_shards + 1)
+ ]
+ download_artifacts(
+ periodic_wf,
+ test_artifacts_list_rocm_distributed,
+ folder_list[2],
+ allowed_substrings=rocm_artifact_substrings,
+ )
+ os.chdir("..")
+
+ # Download ROCm default rocm_wf when ROCm is enabled
+ if not args.no_rocm and not args.exclude_default:
+ rocm_sha = sha
+ print("===========================================")
+ print(f"Finding ROCm default tests in workflow '{ROCmWorkflowNames['default']}' by sha: {rocm_sha}")
+ print("===========================================")
+ # find tests in rocm workflow with given sha and success status
+ #https://docs.github.com/en/rest/actions/workflow-runs#list-workflow-runs-for-a-repository
+ error_msg="Error: rocm workflow not found in scanned workflow runs. Try increasing max_pages."
+ default_fallback_used = False
+ try:
+ rocm_wf = download_workflow_run(created=args.created, max_pages=args.max_pages, workflow=ROCmWorkflowNames["default"], sha=rocm_sha, ignore_status=args.ignore_status, status=status, error_msg=error_msg)
+ except (IndexError, Exception):
+ rocm_wf = None
+ default_fallbacks = {
+ "mi355": ("rocm-mi355", "linux-noble-rocm-py3.12-mi355"),
+ "mi200": ("trunk-rocm-sandbox", "linux-jammy-rocm-py3.10"),
+ }
+ if rocm_wf is None and arch in default_fallbacks:
+ fallback_wf, fallback_prefix = default_fallbacks[arch]
+ print(f"Default not found in {ROCmWorkflowNames['default']}, falling back to {fallback_wf}")
+ rocm_wf = download_workflow_run(created=args.created, max_pages=args.max_pages, workflow=fallback_wf, sha=rocm_sha, ignore_status=args.ignore_status, status=status, error_msg=error_msg)
+ default_fallback_used = True
+ rocm_job_prefix['default'] = fallback_prefix
+ if rocm_wf is None:
+ raise Exception(error_msg)
+ default_wf_name = ROCmWorkflowNames['default'] if not default_fallback_used else default_fallbacks[arch][0]
+ print(f"Using workflow '{default_wf_name}' with id:{rocm_wf['id']} for ROCm default{' (fallback)' if default_fallback_used else ''}")
+
+ folder_list = get_or_create_test_folder(rocm_wf)
+
+ # Download logs
+ # If logs aren't found you might want to check the HUD for the correct tags
+ # HUD link: https://hud.pytorch.org/hud/pytorch/pytorch/main/1?per_page=50&name_filter=rocm
+ if arch == "mi355":
+ default_shards = 6 if default_fallback_used else rocm_shards["default"]
+ else:
+ default_shards = rocm_shards["default"]
+ print(f"Using final ROCm shard count {default_shards} for default")
+
+ if not args.artifacts_only:
+ test_log_list_rocm_default = [
+ [f"{current_prefix}rocm{i}.txt", f"{rocm_job_prefix['default']} / test (default, {i}, {default_shards}"]
+ for i in range(1, default_shards + 1)
+ ]
+ download_logs(rocm_wf, test_log_list_rocm_default, folder_list[0])
+
+ # Download artifacts
+ test_artifacts_list_rocm_default = [
+ f"test-reports-test-default-{i}-{default_shards}"
+ for i in range(1, default_shards + 1)
+ ]
+ if not args.exclude_default:
+ download_artifacts(
+ rocm_wf,
+ test_artifacts_list_rocm_default,
+ test_folder=folder_list[2],
+ allowed_substrings=rocm_artifact_substrings,
+ )
+ os.chdir("..")
+
+ # add new inductor workflow downloading for ROCm
+ if not args.no_rocm and not args.exclude_inductor:
+ inductor_rocm_sha = sha
+ # find tests in inductor workflow with given sha and success status
+ #https://docs.github.com/en/rest/actions/workflow-runs#list-workflow-runs-for-a-repository
+ print("===========================================")
+ print(f"Finding ROCm inductor tests in workflow '{ROCmWorkflowNames['inductor']}' by sha: {inductor_rocm_sha}")
+ print("===========================================")
+ error_msg="Error: inductor workflow not found in scanned workflow runs. Try increasing max_pages."
+ inductor_fallback_used = False
+ try:
+ inductor_wf_rocm = download_workflow_run(created=args.created, max_pages=args.max_pages, workflow=ROCmWorkflowNames["inductor"], sha=inductor_rocm_sha, ignore_status=args.ignore_status, status=status, error_msg=error_msg)
+ except (IndexError, Exception):
+ inductor_wf_rocm = None
+ inductor_fallbacks = {
+ "mi200": ("trunk-rocm-sandbox", "linux-jammy-rocm-py3.10"),
+ }
+ if inductor_wf_rocm is None and arch in inductor_fallbacks:
+ fallback_wf, fallback_prefix = inductor_fallbacks[arch]
+ print(f"Inductor not found in {ROCmWorkflowNames['inductor']}, falling back to {fallback_wf}")
+ inductor_wf_rocm = download_workflow_run(created=args.created, max_pages=args.max_pages, workflow=fallback_wf, sha=inductor_rocm_sha, ignore_status=args.ignore_status, status=status, error_msg=error_msg)
+ inductor_fallback_used = True
+ if inductor_wf_rocm is None:
+ raise Exception(error_msg)
+ inductor_wf_name = ROCmWorkflowNames['inductor'] if not inductor_fallback_used else inductor_fallbacks[arch][0]
+ print(f"Using workflow '{inductor_wf_name}' with id:{inductor_wf_rocm['id']} for ROCm inductor")
+
+ folder_list = get_or_create_test_folder(inductor_wf_rocm)
+
+ inductor_shards = rocm_shards["inductor"]
+ print(f"Using final ROCm shard count {inductor_shards} for inductor")
+ if inductor_fallback_used and arch in inductor_fallbacks:
+ inductor_job_prefix = inductor_fallbacks[arch][1]
+ else:
+ inductor_job_prefix = rocm_job_prefix['inductor']
+
+ # Download logs
+ if not args.artifacts_only:
+ test_log_list_rocm_inductor = [
+ [f"{current_prefix}rocm_inductor{i}.txt", f"{inductor_job_prefix} / test (inductor, {i}, {inductor_shards}"]
+ for i in range(1, inductor_shards + 1)
+ ]
+ download_logs(inductor_wf_rocm, test_log_list_rocm_inductor, folder_list[0])
+
+ #Download artifacts
+ test_artifacts_list_rocm_inductor = [
+ f"test-reports-test-inductor-{i}-{inductor_shards}"
+ for i in range(1, inductor_shards + 1)
+ ]
+ download_artifacts(
+ inductor_wf_rocm,
+ test_artifacts_list_rocm_inductor,
+ test_folder=folder_list[2],
+ allowed_substrings=rocm_artifact_substrings,
+ )
+ os.chdir("..")
+
+ if not args.no_cuda:
+ cuda_job_prefix = "linux-jammy-cuda13.0-py3.10-gcc11"
+ print("==========================================")
+ print(f"Finding CUDA tests in workflow '{CUDAWorkflowNames['default']}' by sha: {sha}")
+ print("==========================================")
+
+ # There can be multiple trunk runs for the same SHA. Find the one
+ # that actually contains CUDA test jobs by checking each run's jobs
+ # list, falling back to check-runs API to resolve the correct run.
+ trunk_wf = None
+ all_cuda_jobs = []
+ cuda_test_jobs = []
+
+ trunk_runs = []
+ params = {'per_page': 10}
+ if not args.ignore_status:
+ params['status'] = status
+ params['head_sha'] = sha
+ resp = requests.get(
+ f"https://api.github.com/repos/pytorch/pytorch/actions/workflows/{CUDAWorkflowNames['default']}.yml/runs",
+ headers=authentication_headers, params=params,
+ )
+ trunk_runs = resp.json().get('workflow_runs', [])
+
+ for run in trunk_runs:
+ jobs = get_workflow_jobs(run)
+ test_kind, test_jobs = get_cuda_test_jobs(jobs, cuda_job_prefix)
+ if test_jobs:
+ trunk_wf = run
+ all_cuda_jobs = jobs
+ cuda_test_jobs = test_jobs
+ cuda_test_job_kind = test_kind
+ print(f"Found CUDA test jobs in trunk run {run['id']}")
+ break
+
+ if not cuda_test_jobs and trunk_runs:
+ # CUDA test jobs may be in a different run than the one returned
+ # by the jobs API. Use check-runs API to find the actual run.
+ print("No CUDA test jobs in any trunk run's jobs API, trying check-runs API...")
+ check_runs = get_check_runs_for_commit(sha, cuda_job_prefix)
+ cuda_test_job_kind, cuda_test_jobs = get_cuda_test_jobs(check_runs, cuda_job_prefix)
+ if cuda_test_jobs:
+ # Extract the actual workflow run ID from the check-run details URL
+ import re as _re
+ run_match = _re.search(r'/runs/(\d+)/', cuda_test_jobs[0].get('details_url', ''))
+ if run_match:
+ actual_run_id = int(run_match.group(1))
+ # Find or fetch the correct workflow run
+ for run in trunk_runs:
+ if run['id'] == actual_run_id:
+ trunk_wf = run
+ break
+ if trunk_wf is None:
+ resp = requests.get(
+ f"https://api.github.com/repos/pytorch/pytorch/actions/runs/{actual_run_id}",
+ headers=authentication_headers,
+ )
+ trunk_wf = resp.json()
+ print(f"CUDA test jobs are in trunk run {trunk_wf['id']} (found via check-runs)")
+ all_cuda_jobs = list(cuda_test_jobs)
+
+ if trunk_wf is None:
+ trunk_wf = trunk_runs[0] if trunk_runs else None
+ if trunk_wf is None:
+ raise Exception("Error: No trunk workflow run found for CUDA tests")
+
+ print(f"Using workflow '{CUDAWorkflowNames['default']}' with id:{trunk_wf['id']} for CUDA default")
+
+ cuda_job_ids = [str(j['id']) for j in cuda_test_jobs]
+ cuda_artifact_substrings = [f"_{jid}" for jid in cuda_job_ids] if cuda_job_ids else ["nvidia.gpu"]
+ print(f"Using CUDA job prefix: {cuda_job_prefix}")
+ print(f"Using CUDA test job kind: {cuda_test_job_kind}")
+ print(f"Found {len(cuda_test_jobs)} CUDA test jobs matching prefix")
+
+ folder_list = get_or_create_test_folder(trunk_wf)
+
+ # Download logs
+ if not args.artifacts_only:
+ test_log_list_cuda_default = [
+ ["cuda1.txt", f"{cuda_job_prefix} / {cuda_test_job_kind} (default, 1, 5"],
+ ["cuda2.txt", f"{cuda_job_prefix} / {cuda_test_job_kind} (default, 2, 5"],
+ ["cuda3.txt", f"{cuda_job_prefix} / {cuda_test_job_kind} (default, 3, 5"],
+ ["cuda4.txt", f"{cuda_job_prefix} / {cuda_test_job_kind} (default, 4, 5"],
+ ["cuda5.txt", f"{cuda_job_prefix} / {cuda_test_job_kind} (default, 5, 5"],
+ ]
+ test_log_list_cuda = test_log_list_cuda_default
+ if not args.exclude_distributed:
+ test_log_list_cuda_distributed = [
+ ["cuda_dist1.txt", f"{cuda_job_prefix} / {cuda_test_job_kind} (distributed, 1, 3"],
+ ["cuda_dist2.txt", f"{cuda_job_prefix} / {cuda_test_job_kind} (distributed, 2, 3"],
+ ["cuda_dist3.txt", f"{cuda_job_prefix} / {cuda_test_job_kind} (distributed, 3, 3"],
+ ]
+ test_log_list_cuda += test_log_list_cuda_distributed
+
+ download_logs(trunk_wf, test_log_list_cuda, folder_list[0], jobs=all_cuda_jobs)
+
+ # Download artifacts
+ test_artifacts_list_cuda_default = [
+ f"test-reports-{cuda_test_job_kind}-default-1-5",
+ f"test-reports-{cuda_test_job_kind}-default-2-5",
+ f"test-reports-{cuda_test_job_kind}-default-3-5",
+ f"test-reports-{cuda_test_job_kind}-default-4-5",
+ f"test-reports-{cuda_test_job_kind}-default-5-5",
+ ]
+
+ test_artifacts_list_cuda = []
+ if not args.exclude_default:
+ test_artifacts_list_cuda += test_artifacts_list_cuda_default
+
+ if not args.exclude_distributed:
+ test_artifacts_list_cuda_distributed = [
+ f"test-reports-{cuda_test_job_kind}-distributed-1-3",
+ f"test-reports-{cuda_test_job_kind}-distributed-2-3",
+ f"test-reports-{cuda_test_job_kind}-distributed-3-3",
+ ]
+ test_artifacts_list_cuda += test_artifacts_list_cuda_distributed
+
+ if test_artifacts_list_cuda:
+ download_artifacts(
+ trunk_wf,
+ test_artifacts_list_cuda,
+ test_folder=folder_list[1],
+ allowed_substrings=cuda_artifact_substrings,
+ )
+ os.chdir("..")
+
+ # add new inductor workflow downloading for CUDA
+ if not args.exclude_inductor:
+ inductor_sha = sha
+ print("==========================================")
+ print(f"Finding CUDA inductor tests in workflow '{CUDAWorkflowNames['inductor']}' by sha: {inductor_sha}")
+ print("==========================================")
+ # find tests in inductor workflow with given sha and success status
+ #https://docs.github.com/en/rest/actions/workflow-runs#list-workflow-runs-for-a-repository
+ error_msg="Error: inductor workflow not found in scanned workflow runs. Try increasing max_pages."
+ inductor_wf_cuda = download_workflow_run(created=args.created, max_pages=args.max_pages, workflow=CUDAWorkflowNames["inductor"], sha=inductor_sha, ignore_status=args.ignore_status, status=status, error_msg=error_msg)
+ print(f"Using workflow '{CUDAWorkflowNames['inductor']}' with id:{inductor_wf_cuda['id']} for CUDA inductor")
+
+ inductor_cuda_jobs = get_workflow_jobs(inductor_wf_cuda)
+ cuda_inductor_test_job_kind, cuda_inductor_test_jobs = get_cuda_inductor_test_jobs(inductor_cuda_jobs)
+ cuda_inductor_job_ids = [str(j['id']) for j in cuda_inductor_test_jobs]
+ cuda_inductor_artifact_substrings = (
+ [f"_{jid}" for jid in cuda_inductor_job_ids]
+ if cuda_inductor_job_ids
+ else None
+ )
+ print(f"Using CUDA inductor test job kind: {cuda_inductor_test_job_kind}")
+ print(f"Found {len(cuda_inductor_test_jobs)} CUDA inductor test jobs")
+
+ folder_list = get_or_create_test_folder(inductor_wf_cuda)
+
+ # Download logs
+ if not args.artifacts_only:
+ test_log_list_cuda_inductor = [
+ ["cuda_inductor1.txt", f"unit-test / inductor-test / {cuda_inductor_test_job_kind} (inductor, 1, 2"],
+ ["cuda_inductor2.txt", f"unit-test / inductor-test / {cuda_inductor_test_job_kind} (inductor, 2, 2"],
+ ]
+ download_logs(inductor_wf_cuda, test_log_list_cuda_inductor, folder_list[0], jobs=inductor_cuda_jobs)
+
+ test_artifacts_list_cuda_inductor = [
+ f"test-reports-{cuda_inductor_test_job_kind}-inductor-1-2",
+ f"test-reports-{cuda_inductor_test_job_kind}-inductor-2-2"
+ ]
+ download_artifacts(
+ inductor_wf_cuda,
+ test_artifacts_list_cuda_inductor,
+ test_folder=folder_list[1],
+ allowed_substrings=cuda_inductor_artifact_substrings,
+ )
+ os.chdir("..")
+
+ # Download baseline commit artifacts for commit-vs-commit comparison
+ if args.baseline_sha and not args.no_rocm:
+ baseline_sha = args.baseline_sha
+ print("==============================================")
+ print(f"Downloading BASELINE ROCm artifacts for sha: {baseline_sha}")
+ print("==============================================")
+
+ import glob
+ existing_folders = sorted(glob.glob("[0-9]*_[0-9a-f]*"), key=os.path.getmtime, reverse=True)
+ if existing_folders:
+ test_folder = existing_folders[0]
+ else:
+ raise Exception("No output folder found from primary downloads")
+
+ baseline_xml_dir = os.path.join(test_folder, "baseline_xml")
+ os.makedirs(baseline_xml_dir, exist_ok=True)
+
+ if not args.exclude_default:
+ try:
+ baseline_default_wf = download_workflow_run(
+ created=args.created, max_pages=args.max_pages,
+ workflow=ROCmWorkflowNames["default"], sha=baseline_sha,
+ ignore_status=args.ignore_status, status=status,
+ error_msg=f"Baseline default workflow not found for {baseline_sha}",
+ )
+ print(f"Baseline default workflow '{ROCmWorkflowNames['default']}' id: {baseline_default_wf['id']}")
+ default_shards = rocm_shards["default"]
+
+ if not args.artifacts_only:
+ baseline_default_logs = [
+ [f"{baseline_prefix}rocm{i}.txt", f"{rocm_job_prefix['default']} / test (default, {i}, {default_shards}"]
+ for i in range(1, default_shards + 1)
+ ]
+ download_logs(baseline_default_wf, baseline_default_logs, test_folder)
+
+ baseline_default_prefixes = [
+ f"test-reports-test-default-{i}-{default_shards}"
+ for i in range(1, default_shards + 1)
+ ]
+ download_artifacts(
+ baseline_default_wf,
+ baseline_default_prefixes,
+ test_folder=baseline_xml_dir,
+ allowed_substrings=rocm_artifact_substrings,
+ )
+ os.chdir("..")
+ except Exception as e:
+ print(f"WARNING: Could not download baseline default artifacts: {e}")
+
+ if not args.exclude_distributed and "distributed" in ROCmWorkflowNames:
+ try:
+ baseline_dist_wf = download_workflow_run(
+ created=args.created, max_pages=args.max_pages,
+ workflow=ROCmWorkflowNames["distributed"], sha=baseline_sha,
+ ignore_status=args.ignore_status, status=status,
+ error_msg=f"Baseline distributed workflow not found for {baseline_sha}",
+ )
+ print(f"Baseline distributed workflow '{ROCmWorkflowNames['distributed']}' id: {baseline_dist_wf['id']}")
+ dist_shards = rocm_shards["distributed"]
+
+ if not args.artifacts_only:
+ baseline_dist_logs = [
+ [f"{baseline_prefix}rocm_dist{i}.txt", f"{rocm_job_prefix['distributed']} / test (distributed, {i}, {dist_shards}"]
+ for i in range(1, dist_shards + 1)
+ ]
+ download_logs(baseline_dist_wf, baseline_dist_logs, test_folder)
+
+ baseline_dist_prefixes = [
+ f"test-reports-test-distributed-{i}-{dist_shards}"
+ for i in range(1, dist_shards + 1)
+ ]
+ download_artifacts(
+ baseline_dist_wf,
+ baseline_dist_prefixes,
+ test_folder=baseline_xml_dir,
+ allowed_substrings=rocm_artifact_substrings,
+ )
+ os.chdir("..")
+ except Exception as e:
+ print(f"WARNING: Could not download baseline distributed artifacts: {e}")
+
+ if not args.exclude_inductor and "inductor" in ROCmWorkflowNames:
+ try:
+ baseline_inductor_wf = download_workflow_run(
+ created=args.created, max_pages=args.max_pages,
+ workflow=ROCmWorkflowNames["inductor"], sha=baseline_sha,
+ ignore_status=args.ignore_status, status=status,
+ error_msg=f"Baseline inductor workflow not found for {baseline_sha}",
+ )
+ print(f"Baseline inductor workflow '{ROCmWorkflowNames['inductor']}' id: {baseline_inductor_wf['id']}")
+ inductor_shards = rocm_shards["inductor"]
+
+ if not args.artifacts_only:
+ baseline_inductor_logs = [
+ [f"{baseline_prefix}rocm_inductor{i}.txt", f"{rocm_job_prefix['inductor']} / test (inductor, {i}, {inductor_shards}"]
+ for i in range(1, inductor_shards + 1)
+ ]
+ download_logs(baseline_inductor_wf, baseline_inductor_logs, test_folder)
+
+ baseline_inductor_prefixes = [
+ f"test-reports-test-inductor-{i}-{inductor_shards}"
+ for i in range(1, inductor_shards + 1)
+ ]
+ download_artifacts(
+ baseline_inductor_wf,
+ baseline_inductor_prefixes,
+ test_folder=baseline_xml_dir,
+ allowed_substrings=rocm_artifact_substrings,
+ )
+ os.chdir("..")
+ except Exception as e:
+ print(f"WARNING: Could not download baseline inductor artifacts: {e}")
+
+ print(f"Baseline artifacts saved to: {baseline_xml_dir}")
+
+ # Download inductor-periodic benchmark artifacts (separate from parity CSV)
+ if args.include_inductor_periodic:
+ print("==============================================")
+ print(f"Finding inductor-periodic tests in workflow 'inductor-periodic' by sha: {sha}")
+ print("==============================================")
+ error_msg = "Error: inductor-periodic workflow not found for this SHA. It may not have run on this commit."
+ try:
+ inductor_periodic_wf = download_workflow_run(
+ created=args.created, max_pages=args.max_pages,
+ workflow="inductor-periodic", sha=sha,
+ ignore_status=args.ignore_status, status=status,
+ error_msg=error_msg,
+ )
+ except (IndexError, Exception) as e:
+ print(f"WARNING: {e}")
+ inductor_periodic_wf = None
+
+ if inductor_periodic_wf:
+ print(f"Using workflow 'inductor-periodic' with id:{inductor_periodic_wf['id']} for inductor-periodic")
+
+ folder_list = get_or_create_test_folder(inductor_periodic_wf)
+ test_folder = folder_list[0]
+
+ rocm_periodic_dir = os.path.join(test_folder, "inductor_periodic_rocm_dir")
+ cuda_periodic_dir = os.path.join(test_folder, "inductor_periodic_cuda_dir")
+ os.makedirs(rocm_periodic_dir, exist_ok=True)
+ os.makedirs(cuda_periodic_dir, exist_ok=True)
+
+ if not args.no_rocm:
+ print("Downloading inductor-periodic ROCm artifacts...")
+ download_artifacts(
+ inductor_periodic_wf,
+ ["test-reports-"],
+ test_folder=rocm_periodic_dir,
+ allowed_substrings=["rocm.gpu"],
+ )
+ os.chdir("..")
+
+ if not args.no_cuda:
+ print("Downloading inductor-periodic CUDA artifacts...")
+ cuda_periodic_job_ids = get_job_ids_by_prefix(inductor_periodic_wf, "linux.g5")
+ cuda_periodic_substrings = (
+ [f"_{jid}" for jid in cuda_periodic_job_ids]
+ if cuda_periodic_job_ids
+ else ["nvidia.gpu"]
+ )
+ download_artifacts(
+ inductor_periodic_wf,
+ ["test-reports-"],
+ test_folder=cuda_periodic_dir,
+ allowed_substrings=cuda_periodic_substrings,
+ )
+ os.chdir("..")
+
+ print(f"Inductor-periodic artifacts saved to:")
+ print(f" ROCm: {rocm_periodic_dir}")
+ print(f" CUDA: {cuda_periodic_dir}")
+ else:
+ print("Skipping inductor-periodic download (workflow run not found)")
+
+ return
+
+if __name__ == "__main__":
+ main()
+ if error_msgs:
+ for msg in error_msgs:
+ print(msg)
+ exit(1)
diff --git a/.automation_scripts/pytorch-unit-test-scripts/generate_summary.py b/.automation_scripts/pytorch-unit-test-scripts/generate_summary.py
new file mode 100644
index 0000000000000..406f4b49b78cc
--- /dev/null
+++ b/.automation_scripts/pytorch-unit-test-scripts/generate_summary.py
@@ -0,0 +1,810 @@
+#!/usr/bin/env python3
+
+import argparse
+import csv
+import os
+import sys
+
+
+TEST_CONFIGS = ['default', 'distributed', 'inductor']
+TEST_CONFIG_DISPLAY = {
+ 'default': 'TEST DEFAULT',
+ 'distributed': 'TEST DISTRIBUTED',
+ 'inductor': 'TEST INDUCTOR',
+}
+MAX_DIAGNOSTIC_FIELD_CHARS = 20_000
+DIAGNOSTIC_FIELDS = {
+ 'comments',
+ 'message_cuda',
+ 'message_rocm',
+ 'message_set1',
+ 'message_set2',
+ 'reason',
+ 'skip_reason',
+}
+
+
+def _configure_csv_field_limit():
+ limit = sys.maxsize
+ while True:
+ try:
+ csv.field_size_limit(limit)
+ return
+ except OverflowError:
+ limit //= 10
+
+
+_configure_csv_field_limit()
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description='Generate a parity summary from per-architecture test status CSVs'
+ )
+ parser.add_argument(
+ '--csv', nargs='+', required=True,
+ help='CSV file(s) to summarize (one per architecture, same order as --arch)'
+ )
+ parser.add_argument(
+ '--arch', nargs='+', required=True,
+ help='Architecture labels matching --csv order (e.g. mi200 mi300 mi355)'
+ )
+ parser.add_argument('--sha', type=str, default='', help='Commit SHA')
+ parser.add_argument('--pr_id', type=str, default='', help='Pull request ID')
+ parser.add_argument(
+ '--set1_name', type=str, default='set1',
+ help='Name used for set1 in CSV column headers (default: set1)'
+ )
+ parser.add_argument(
+ '--set2_name', type=str, default='set2',
+ help='Name used for set2 in CSV column headers (default: set2)'
+ )
+ parser.add_argument(
+ '--output', type=str, default='parity_summary',
+ help='Output path prefix (produces .csv and .md)'
+ )
+ parser.add_argument(
+ '--log-failures', nargs='*', default=[],
+ help='CSV file(s) from detect_log_failures.py to include in summary'
+ )
+ return parser.parse_args()
+
+
+def load_csv(filepath):
+ with open(filepath, newline='') as f:
+ return [_truncate_diagnostic_fields(row) for row in csv.DictReader(f)]
+
+
+def _truncate_diagnostic_fields(row):
+ for field in DIAGNOSTIC_FIELDS:
+ value = row.get(field, '')
+ if len(value) > MAX_DIAGNOSTIC_FIELD_CHARS:
+ omitted = len(value) - MAX_DIAGNOSTIC_FIELD_CHARS
+ row[field] = (
+ value[:MAX_DIAGNOSTIC_FIELD_CHARS]
+ + f'\n...[truncated {omitted:,} chars by generate_summary.py]'
+ )
+ return row
+
+
+def detect_columns(headers, set1_name, set2_name):
+ s1_status = f'status_{set1_name}'
+ s2_status = f'status_{set2_name}'
+ s1_time = f'running_time_{set1_name}'
+ s2_time = f'running_time_{set2_name}'
+ if s1_status not in headers:
+ s1_status = 'status_set1'
+ s2_status = 'status_set2'
+ s1_time = 'running_time_set1'
+ s2_time = 'running_time_set2'
+ return s1_status, s2_status, s1_time, s2_time
+
+
+def test_config_stats_keys(s1_name, s2_name, has_set2=True):
+ s1 = s1_name.upper()
+ s2 = s2_name.upper()
+ if not has_set2:
+ return [
+ f'PASSED ({s1_name})',
+ f'SKIPPED ({s1_name})',
+ f'FAILED ({s1_name})',
+ f'MISSED ({s1_name})',
+ f'TOTAL {s1}',
+ ]
+ return [
+ f'SKIPPED (on {s1_name}, but not on {s2_name})',
+ f'SKIPPED (on {s1_name})',
+ f'SKIPPED (on {s2_name})',
+ f'MISSED (MISSED on {s1_name}, NOT SKIPPED on {s2_name})',
+ f'{s1}ONLY (PASSED on {s1}, NOT PASSED on {s2})',
+ s2,
+ s1,
+ 'SKIPPED + MISSED',
+ f'{s2} - (SKIPPED + MISSED)',
+ f'DISAGREE [(SKIPPED+MISSED)/{s2}] %',
+ ]
+
+
+def compute_test_config_stats(rows, s1_col, s2_col, s1_name, s2_name, has_set2=True):
+ s1 = s1_name.upper()
+ s2 = s2_name.upper()
+
+ if not has_set2:
+ vals = {}
+ keys = test_config_stats_keys(s1_name, s2_name, has_set2=False)
+ vals[keys[0]] = sum(1 for r in rows if r[s1_col] == 'PASSED')
+ vals[keys[1]] = sum(1 for r in rows if r[s1_col] == 'SKIPPED')
+ vals[keys[2]] = sum(1 for r in rows if r[s1_col] == 'FAILED')
+ vals[keys[3]] = sum(1 for r in rows if r[s1_col] == 'MISSED')
+ vals[keys[4]] = sum(1 for r in rows if r[s1_col].strip())
+ return vals
+
+ s1_skip_not_s2 = sum(
+ 1 for r in rows
+ if r[s1_col] == 'SKIPPED' and r[s2_col] != 'SKIPPED'
+ )
+ s1_skip = sum(1 for r in rows if r[s1_col] == 'SKIPPED')
+ s2_skip = sum(1 for r in rows if r[s2_col] == 'SKIPPED')
+ s1_miss_not_s2_skip = sum(
+ 1 for r in rows
+ if r[s1_col] == 'MISSED' and r[s2_col] != 'SKIPPED'
+ )
+ only_s1 = sum(
+ 1 for r in rows
+ if r[s1_col] == 'PASSED' and r[s2_col] != 'PASSED'
+ )
+ total_s2 = sum(1 for r in rows if r[s2_col].strip() and r[s2_col].strip() != 'MISSED')
+ total_s1 = sum(1 for r in rows if r[s1_col].strip() and r[s1_col].strip() != 'MISSED')
+
+ skip_miss = s1_skip_not_s2 + s1_miss_not_s2_skip
+ s2_minus = total_s2 - skip_miss
+ pct = (skip_miss / total_s2 * 100) if total_s2 else 0
+
+ vals = {}
+ keys = test_config_stats_keys(s1_name, s2_name)
+ vals[keys[0]] = s1_skip_not_s2
+ vals[keys[1]] = s1_skip
+ vals[keys[2]] = s2_skip
+ vals[keys[3]] = s1_miss_not_s2_skip
+ vals[keys[4]] = only_s1
+ vals[keys[5]] = total_s2
+ vals[keys[6]] = total_s1
+ vals[keys[7]] = skip_miss
+ vals[keys[8]] = s2_minus
+ vals[keys[9]] = f'{pct:.2f}%'
+ return vals
+
+
+def overall_stats_keys(s1_name, s2_name, has_set2=True):
+ s1 = s1_name.upper()
+ s2 = s2_name.upper()
+ if not has_set2:
+ keys = []
+ for status in ['PASSED', 'SKIPPED', 'FAILED', 'XFAILED']:
+ keys.append(f'{status}({s1_name})')
+ keys += [
+ f'TOTAL {s1}',
+ f'TOTAL {s1} RUNNING TIME',
+ ]
+ return keys
+ keys = [
+ 'Overall DISAGREE%',
+ 'Overall AGREE%',
+ ]
+ for status in ['PASSED', 'SKIPPED', 'FAILED', 'XFAILED']:
+ keys.append(f'{status}({s1_name})')
+ keys.append(f'{status}({s2_name})')
+ keys += [
+ f'TOTAL {s2}',
+ f'TOTAL {s1}',
+ f'TOTAL {s1} RUNNING TIME',
+ f'TOTAL {s2} RUNNING TIME',
+ ]
+ return keys
+
+
+def compute_overall_stats(rows, s1_col, s2_col, s1_time_col, s2_time_col, s1_name, s2_name, has_set2=True):
+ s1 = s1_name.upper()
+ s2 = s2_name.upper()
+
+ def safe_float(v):
+ try:
+ return float(v)
+ except (ValueError, TypeError):
+ return 0.0
+
+ if not has_set2:
+ vals = {}
+ keys = overall_stats_keys(s1_name, s2_name, has_set2=False)
+ idx = 0
+ for status in ['PASSED', 'SKIPPED', 'FAILED', 'XFAILED']:
+ vals[keys[idx]] = sum(1 for r in rows if r[s1_col] == status)
+ idx += 1
+ vals[keys[idx]] = sum(1 for r in rows if r[s1_col].strip())
+ idx += 1
+ vals[keys[idx]] = f'{sum(safe_float(r[s1_time_col]) for r in rows):.2f}'
+ return vals
+
+ total_disagree = 0
+ total_s2 = 0
+ for wf in TEST_CONFIGS:
+ wf_rows = [r for r in rows if r['test_config'] == wf]
+ s1_skip_not_s2 = sum(
+ 1 for r in wf_rows
+ if r[s1_col] == 'SKIPPED' and r[s2_col] != 'SKIPPED'
+ )
+ s1_miss_not_s2_skip = sum(
+ 1 for r in wf_rows
+ if r[s1_col] == 'MISSED' and r[s2_col] != 'SKIPPED'
+ )
+ total_disagree += s1_skip_not_s2 + s1_miss_not_s2_skip
+ total_s2 += sum(1 for r in wf_rows if r[s2_col].strip() and r[s2_col].strip() != 'MISSED')
+
+ disagree_pct = (total_disagree / total_s2 * 100) if total_s2 else 0
+ agree_pct = 100 - disagree_pct
+
+ vals = {}
+ keys = overall_stats_keys(s1_name, s2_name)
+ vals[keys[0]] = f'{disagree_pct:.2f}%'
+ vals[keys[1]] = f'{agree_pct:.2f}%'
+
+ idx = 2
+ for status in ['PASSED', 'SKIPPED', 'FAILED', 'XFAILED']:
+ vals[keys[idx]] = sum(1 for r in rows if r[s1_col] == status)
+ vals[keys[idx + 1]] = sum(1 for r in rows if r[s2_col] == status)
+ idx += 2
+
+ vals[keys[idx]] = sum(1 for r in rows if r[s2_col].strip() and r[s2_col].strip() != 'MISSED')
+ idx += 1
+ vals[keys[idx]] = sum(1 for r in rows if r[s1_col].strip() and r[s1_col].strip() != 'MISSED')
+ idx += 1
+
+ vals[keys[idx]] = f'{sum(safe_float(r[s1_time_col]) for r in rows):.2f}'
+ idx += 1
+ vals[keys[idx]] = f'{sum(safe_float(r[s2_time_col]) for r in rows):.2f}'
+ return vals
+
+
+def collect_failed_tests(arch_data, archs, s1_name, s2_name):
+ """Return a list of failed test rows across all architectures.
+
+ Only collects tests where s1 (ROCm) is FAILED. Each entry records shards
+ for both s1 and s2 so the reviewer can look up the failure in either CI
+ job. 'also_failing_in' is populated later once log failures are known so
+ CUDA log-only failures can be included.
+ """
+ failed = []
+ for arch in archs:
+ d = arch_data[arch]
+ s1_col, s2_col, _, _ = d['cols']
+ has_set2 = d.get('has_set2', True)
+ for r in d['rows']:
+ s1 = r[s1_col].strip()
+ s2 = r[s2_col].strip() if has_set2 else ''
+ if s1 == 'FAILED':
+ entry = {
+ 'arch': arch,
+ 'test_file': r.get('test_file', ''),
+ 'test_class': r.get('test_class', ''),
+ 'test_name': r.get('test_name', ''),
+ 'test_config': r.get('test_config', ''),
+ f'shard_{s1_name}': r.get(f'shard_{s1_name}', ''),
+ f'status_{s1_name}': s1,
+ }
+ if has_set2:
+ entry[f'shard_{s2_name}'] = r.get(f'shard_{s2_name}', '')
+ entry[f'status_{s2_name}'] = s2
+ failed.append(entry)
+
+ return failed
+
+
+def _add_cross_arch_info(failed_tests, log_failures, s2_name):
+ """Populate 'also_failing_in' for each entry.
+
+ Matches across other ROCm architectures (from XML-based failures) and also
+ includes s2 (CUDA) if a log failure is recorded for the same test tuple.
+ """
+ from collections import defaultdict
+ by_tuple = defaultdict(set)
+ for t in failed_tests:
+ key = (t['test_file'], t['test_class'], t['test_name'])
+ by_tuple[key].add(t['arch'])
+
+ cuda_log_tuples = set()
+ for lf in log_failures or []:
+ if lf.get('platform', '') == s2_name:
+ test_class, test_name = _parse_log_failure_names(lf)
+ cuda_log_tuples.add((lf.get('test_file', ''), test_class, test_name))
+
+ for t in failed_tests:
+ key = (t['test_file'], t['test_class'], t['test_name'])
+ others = sorted(a for a in by_tuple[key] if a != t['arch'])
+ if key in cuda_log_tuples and s2_name not in others:
+ others.append(s2_name)
+ t['also_failing_in'] = ', '.join(others)
+
+
+def _add_log_failure_cross_arch(log_failures, failed_tests, s1_name, s2_name):
+ """Populate 'also_failing_in' for each log failure entry.
+
+ Cross-references: other archs that have the same test failing (either as
+ a log failure or as an XML-based failure), plus s2 (CUDA) if it appears
+ in log failures for the same test tuple.
+ """
+ from collections import defaultdict
+ by_tuple_archs = defaultdict(set)
+
+ for lf in log_failures or []:
+ if lf.get('platform', '') == s1_name:
+ test_class, test_name = _parse_log_failure_names(lf)
+ key = (lf.get('test_file', ''), test_class, test_name)
+ by_tuple_archs[key].add(lf.get('arch', ''))
+ for t in failed_tests or []:
+ key = (t['test_file'], t['test_class'], t['test_name'])
+ by_tuple_archs[key].add(t['arch'])
+
+ cuda_log_tuples = set()
+ for lf in log_failures or []:
+ if lf.get('platform', '') == s2_name:
+ test_class, test_name = _parse_log_failure_names(lf)
+ cuda_log_tuples.add((lf.get('test_file', ''), test_class, test_name))
+
+ for lf in log_failures or []:
+ test_class, test_name = _parse_log_failure_names(lf)
+ key = (lf.get('test_file', ''), test_class, test_name)
+ arch = lf.get('arch', '')
+ others = sorted(a for a in by_tuple_archs[key] if a and a != arch)
+ if key in cuda_log_tuples and s2_name not in others:
+ others.append(s2_name)
+ lf['also_failing_in'] = ', '.join(others)
+
+
+def load_log_failures(filepaths):
+ """Load log failure CSVs from detect_log_failures.py.
+
+ Extracts the architecture from the filename (e.g. log_failures_mi355.csv -> mi355).
+ """
+ entries = []
+ for fp in filepaths:
+ if not os.path.isfile(fp):
+ continue
+ basename = os.path.basename(fp)
+ arch = ''
+ if basename.startswith('log_failures_') and basename.endswith('.csv'):
+ arch = basename[len('log_failures_'):-len('.csv')]
+ with open(fp, newline='') as f:
+ for row in csv.DictReader(f):
+ row['arch'] = arch
+ entries.append(row)
+ return entries
+
+
+def load_flaky_tests_as_log_failures(filepaths):
+ """Load flaky_tests_.csv and return entries shaped like log-failure rows.
+
+ Each returned dict has the same schema as the entries produced by
+ load_log_failures, with category='FLAKY' and reason='::',
+ so they can be appended to the log_failures list and surfaced in the
+ LOG-BASED FAILURES table alongside crashes/timeouts/etc.
+ """
+ entries = []
+ for fp in filepaths or []:
+ if not fp:
+ continue
+ basename = os.path.basename(fp)
+ if not (basename.startswith('log_failures_') and basename.endswith('.csv')):
+ continue
+ arch = basename[len('log_failures_'):-len('.csv')]
+ flaky_path = os.path.join(
+ os.path.dirname(fp),
+ 'flaky_tests_' + basename[len('log_failures_'):],
+ )
+ if not os.path.isfile(flaky_path):
+ continue
+ with open(flaky_path, newline='') as f:
+ for row in csv.DictReader(f):
+ test_class = row.get('test_class', '')
+ test_name = row.get('test_name', '')
+ entries.append({
+ 'arch': arch,
+ 'log_file': row.get('log_file', ''),
+ 'platform': row.get('platform', ''),
+ 'test_config': row.get('test_config', ''),
+ 'test_file': row.get('test_file', ''),
+ 'job_shard': row.get('job_shard', ''),
+ 'test_shard': row.get('test_shard', ''),
+ 'status': 'FLAKY',
+ 'category': 'FLAKY',
+ 'reason': f'{test_class}::{test_name}' if test_class else test_name,
+ 'exit_codes': '',
+ })
+ return entries
+
+
+def load_log_shards(filepaths):
+ """Load log shard inventory CSVs written alongside log_failures files.
+
+ For each log_failures_.csv, looks for a sibling log_shards_.csv
+ and returns a lookup dict:
+ (arch, platform, test_config, job_shard, normalized_test_file) -> test_shards_str
+
+ The CSV is produced by detect_log_failures.py and records every
+ (test_file, test_shard) pair observed per job-level shard. If an XML-based
+ failure's key matches, we can back-fill the test-level shard value.
+ """
+ lookup = {}
+ for fp in filepaths:
+ if not fp:
+ continue
+ basename = os.path.basename(fp)
+ arch = ''
+ if basename.startswith('log_failures_') and basename.endswith('.csv'):
+ arch = basename[len('log_failures_'):-len('.csv')]
+ shards_path = os.path.join(
+ os.path.dirname(fp),
+ 'log_shards_' + basename[len('log_failures_'):],
+ )
+ else:
+ continue
+ if not os.path.isfile(shards_path):
+ continue
+ with open(shards_path, newline='') as f:
+ for row in csv.DictReader(f):
+ key = (arch, row.get('platform', ''), row.get('test_config', ''),
+ row.get('job_shard', ''),
+ _norm_test_file(row.get('test_file', '')))
+ lookup[key] = row.get('test_shards', '')
+ return lookup
+
+
+def _format_test_shards(shards_str):
+ """Collapse a test_shards inventory string into a compact display value.
+
+ - '' -> ''
+ - '1/1' -> '1/1'
+ - '3/14' -> '3/14'
+ - '1/14,6/14,12/14' -> '1,6,12/14' (multiple test-level shards observed)
+ - mixed totals fall back to the raw string."""
+ if not shards_str:
+ return ''
+ parts = [p for p in shards_str.split(',') if p]
+ if len(parts) == 1:
+ return parts[0]
+ totals = set()
+ nums = []
+ for p in parts:
+ if '/' not in p:
+ return shards_str
+ a, b = p.split('/', 1)
+ totals.add(b)
+ nums.append(a)
+ if len(totals) == 1:
+ return f"{','.join(nums)}/{totals.pop()}"
+ return shards_str
+
+
+def fmt_val(v):
+ if isinstance(v, int):
+ return f'{v:,}'
+ return str(v)
+
+
+def build_rows(args, archs, arch_data):
+ """Return a list of (label, val_per_arch...) tuples and section markers."""
+ out = []
+ any_has_set2 = any(d.get('has_set2', True) for d in arch_data.values())
+
+ if args.sha:
+ out.append(('__header__', f'Commit SHA: {args.sha}'))
+ if args.pr_id:
+ out.append(('__header__', f'PR ID: {args.pr_id}'))
+
+ wf_keys = test_config_stats_keys(args.set1_name, args.set2_name, has_set2=any_has_set2)
+ for wf in TEST_CONFIGS:
+ out.append(('__section__', TEST_CONFIG_DISPLAY[wf]))
+ arch_stats = {}
+ for arch in archs:
+ d = arch_data[arch]
+ s1_col, s2_col, _, _ = d['cols']
+ has_set2 = d.get('has_set2', True)
+ wf_rows = [r for r in d['rows'] if r['test_config'] == wf]
+ arch_stats[arch] = compute_test_config_stats(
+ wf_rows, s1_col, s2_col, args.set1_name, args.set2_name,
+ has_set2=has_set2,
+ )
+ for key in wf_keys:
+ out.append((key, [arch_stats[a].get(key, 0) for a in archs]))
+
+ out.append(('__section__', 'OVERALL'))
+ ov_keys = overall_stats_keys(args.set1_name, args.set2_name, has_set2=any_has_set2)
+ arch_overall = {}
+ for arch in archs:
+ d = arch_data[arch]
+ s1_col, s2_col, s1_time, s2_time = d['cols']
+ has_set2 = d.get('has_set2', True)
+ arch_overall[arch] = compute_overall_stats(
+ d['rows'], s1_col, s2_col, s1_time, s2_time,
+ args.set1_name, args.set2_name, has_set2=has_set2,
+ )
+ for key in ov_keys:
+ out.append((key, [arch_overall[a].get(key, 0) for a in archs]))
+ return out
+
+
+def _norm_test_file(path):
+ """Normalize a test_file string so XML-sourced ('a.b.c') and log-sourced
+ ('a/b/c') forms compare equal. Also strips a trailing .py if present."""
+ if not path:
+ return ''
+ s = path.replace('/', '.')
+ if s.endswith('.py'):
+ s = s[:-3]
+ return s
+
+
+def _parse_log_failure_names(lf):
+ """Extract test_class and test_name from a log failure's reason field.
+
+ Handles formats like 'TestClass::test_method' and
+ 'TestClass::test_method | extra reason text'.
+ """
+ reason = lf.get('reason', '')
+ if '::' not in reason:
+ return '', ''
+ test_part = reason.split(' | ', 1)[0] if ' | ' in reason else reason
+ parts = test_part.split('::', 1)
+ return parts[0], parts[1]
+
+
+def write_csv(rows, archs, output_path, failed_tests=None, s1_name='set1', s2_name='set2', has_set2=True, log_failures=None, shard_lookup=None):
+ csv_rows = []
+ csv_rows.append([''] + list(archs))
+ for label, vals in rows:
+ if label == '__header__':
+ csv_rows.append([vals])
+ elif label == '__section__':
+ csv_rows.append([])
+ csv_rows.append([vals])
+ else:
+ csv_rows.append([label] + list(vals))
+ csv_rows.append([])
+
+ s1_failed = [t for t in (failed_tests or []) if t.get(f'status_{s1_name}') == 'FAILED']
+
+ shard_lookup = shard_lookup or {}
+
+ def _xml_test_shard(t, platform):
+ key = (t.get('arch', ''), platform, t.get('test_config', ''),
+ t.get(f'shard_{platform}', ''),
+ _norm_test_file(t.get('test_file', '')))
+ return _format_test_shards(shard_lookup.get(key, ''))
+
+ if s1_failed:
+ csv_rows.append(['FAILED TESTS'])
+ header = ['Arch', 'Test Config', 'Test File', 'Test Class',
+ 'Test Name',
+ f'Job-Level Shard ({s1_name})',
+ f'Test-Level Shard ({s1_name})']
+ if has_set2:
+ header.append(f'Job-Level Shard ({s2_name})')
+ header.append(f'Test-Level Shard ({s2_name})')
+ header.append(f'Status ({s1_name})')
+ if has_set2:
+ header.append(f'Status ({s2_name})')
+ header.append('Also Failing In')
+ csv_rows.append(header)
+ for t in s1_failed:
+ row = [t['arch'], t['test_config'], t['test_file'],
+ t['test_class'], t['test_name'],
+ t.get(f'shard_{s1_name}', ''),
+ _xml_test_shard(t, s1_name)]
+ if has_set2:
+ row.append(t.get(f'shard_{s2_name}', ''))
+ row.append(_xml_test_shard(t, s2_name))
+ row.append(t[f'status_{s1_name}'])
+ if has_set2:
+ row.append(t.get(f'status_{s2_name}', ''))
+ row.append(t.get('also_failing_in', ''))
+ csv_rows.append(row)
+ csv_rows.append([])
+
+ if log_failures:
+ xml_failed_keys = {
+ (t['arch'], _norm_test_file(t['test_file']), t['test_class'], t['test_name'])
+ for t in (failed_tests or [])
+ }
+ rocm_log_failures = []
+ for lf in log_failures:
+ if lf.get('platform', '') != s1_name:
+ continue
+ test_class, test_name = _parse_log_failure_names(lf)
+ key = (lf.get('arch', ''), _norm_test_file(lf.get('test_file', '')),
+ test_class, test_name)
+ # Skip entries already present in the XML-based FAILED TESTS table
+ # to avoid double-counting the same failure, except for FLAKY
+ # entries which represent an independent signal (a rerun passed).
+ if key in xml_failed_keys and lf.get('category', '') != 'FLAKY':
+ continue
+ rocm_log_failures.append(lf)
+ if rocm_log_failures:
+ csv_rows.append(['LOG-BASED FAILURES (not in XML)'])
+ csv_rows.append(['Arch', 'Platform', 'Test Config', 'Test File', 'Test Class',
+ 'Test Name', 'Job-Level Shard', 'Test-Level Shard',
+ 'Category', 'Also Failing In', 'Log File'])
+ for lf in rocm_log_failures:
+ test_class, test_name = _parse_log_failure_names(lf)
+ csv_rows.append([
+ lf.get('arch', ''), lf.get('platform', ''), lf.get('test_config', ''),
+ lf.get('test_file', ''), test_class, test_name,
+ lf.get('job_shard', ''),
+ lf.get('test_shard', lf.get('shard', '')),
+ lf.get('category', ''),
+ lf.get('also_failing_in', ''),
+ lf.get('log_file', ''),
+ ])
+ csv_rows.append([])
+
+ with open(output_path, 'w', newline='') as f:
+ csv.writer(f).writerows(csv_rows)
+ print(f'CSV written to {output_path}')
+
+
+def write_markdown(rows, archs, output_path, failed_tests=None, s1_name='set1', s2_name='set2', has_set2=True, log_failures=None, shard_lookup=None):
+ lines = []
+ current_section = []
+
+ def flush_table():
+ if not current_section:
+ return
+ header = '| Metric | ' + ' | '.join(archs) + ' |'
+ sep = '| :--- | ' + ' | '.join(['---:'] * len(archs)) + ' |'
+ lines.append(header)
+ lines.append(sep)
+ for label, vals in current_section:
+ formatted = [fmt_val(v) for v in vals]
+ lines.append(f'| {label} | ' + ' | '.join(formatted) + ' |')
+ lines.append('')
+ current_section.clear()
+
+ for label, vals in rows:
+ if label == '__header__':
+ flush_table()
+ lines.append(f'**{vals}**')
+ lines.append('')
+ elif label == '__section__':
+ flush_table()
+ lines.append(f'### {vals}')
+ lines.append('')
+ else:
+ current_section.append((label, vals))
+
+ flush_table()
+
+ s1_failed = [t for t in (failed_tests or []) if t.get(f'status_{s1_name}') == 'FAILED']
+
+ shard_lookup = shard_lookup or {}
+
+ def _xml_test_shard(t, platform):
+ key = (t.get('arch', ''), platform, t.get('test_config', ''),
+ t.get(f'shard_{platform}', ''),
+ _norm_test_file(t.get('test_file', '')))
+ return _format_test_shards(shard_lookup.get(key, ''))
+
+ cols = ['Arch', 'Test Config', 'Test File', 'Test Class', 'Test Name',
+ f'Job-Level Shard ({s1_name})',
+ f'Test-Level Shard ({s1_name})']
+ if has_set2:
+ cols.append(f'Job-Level Shard ({s2_name})')
+ cols.append(f'Test-Level Shard ({s2_name})')
+ cols.append(f'Status ({s1_name})')
+ if has_set2:
+ cols.append(f'Status ({s2_name})')
+ cols.append('Also Failing In')
+
+ if s1_failed:
+ lines.append(f'### FAILED TESTS ({len(s1_failed)})')
+ lines.append('')
+ lines.append('| ' + ' | '.join(cols) + ' |')
+ lines.append('| ' + ' | '.join(['---'] * len(cols)) + ' |')
+ for t in s1_failed:
+ line = (f"| {t['arch']} | {t['test_config']} | {t['test_file']} "
+ f"| {t['test_class']} | {t['test_name']} "
+ f"| {t.get(f'shard_{s1_name}', '')} "
+ f"| {_xml_test_shard(t, s1_name)}")
+ if has_set2:
+ line += f" | {t.get(f'shard_{s2_name}', '')}"
+ line += f" | {_xml_test_shard(t, s2_name)}"
+ line += f" | {t[f'status_{s1_name}']}"
+ if has_set2:
+ line += f" | {t.get(f'status_{s2_name}', '')}"
+ line += f" | {t.get('also_failing_in', '')} |"
+ lines.append(line)
+ lines.append('')
+ else:
+ lines.append('### FAILED TESTS')
+ lines.append('')
+ lines.append('No failed tests found.')
+ lines.append('')
+
+ if log_failures:
+ xml_failed_keys = {
+ (t['arch'], _norm_test_file(t['test_file']), t['test_class'], t['test_name'])
+ for t in (failed_tests or [])
+ }
+ rocm_log_failures = []
+ for lf in log_failures:
+ if lf.get('platform', '') != s1_name:
+ continue
+ test_class, test_name = _parse_log_failure_names(lf)
+ key = (lf.get('arch', ''), _norm_test_file(lf.get('test_file', '')),
+ test_class, test_name)
+ if key in xml_failed_keys and lf.get('category', '') != 'FLAKY':
+ continue
+ rocm_log_failures.append(lf)
+ if rocm_log_failures:
+ lines.append(f'### LOG-BASED FAILURES (not in XML) ({len(rocm_log_failures)})')
+ lines.append('')
+ lines.append('These test failures were detected from CI log files but have no XML report')
+ lines.append('(typically due to timeouts, crashes, or process kills).')
+ lines.append('')
+ lines.append('| Arch | Platform | Test Config | Test File | Test Class | Test Name | Job-Level Shard | Test-Level Shard | Category | Also Failing In |')
+ lines.append('| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |')
+ for lf in rocm_log_failures:
+ test_class, test_name = _parse_log_failure_names(lf)
+ lines.append(
+ f"| {lf.get('arch', '')} | {lf.get('platform', '')} | {lf.get('test_config', '')} "
+ f"| {lf.get('test_file', '')} | {test_class} "
+ f"| {test_name} "
+ f"| {lf.get('job_shard', '')} "
+ f"| {lf.get('test_shard', lf.get('shard', ''))} "
+ f"| {lf.get('category', '')} "
+ f"| {lf.get('also_failing_in', '')} |"
+ )
+ lines.append('')
+
+ md = '\n'.join(lines)
+ with open(output_path, 'w') as f:
+ f.write(md)
+ print(f'Markdown written to {output_path}')
+ return md
+
+
+def main():
+ args = parse_args()
+
+ if len(args.csv) != len(args.arch):
+ print('Error: --csv and --arch must have the same number of values')
+ sys.exit(1)
+
+ archs = args.arch
+ arch_data = {}
+ for csv_path, arch in zip(args.csv, archs):
+ rows = load_csv(csv_path)
+ headers = set(rows[0].keys()) if rows else set()
+ cols = detect_columns(headers, args.set1_name, args.set2_name)
+ s2_col = cols[1]
+ has_set2 = any(r.get(s2_col, '').strip() for r in rows)
+ arch_data[arch] = {'rows': rows, 'cols': cols, 'has_set2': has_set2}
+
+ data_rows = build_rows(args, archs, arch_data)
+ failed = collect_failed_tests(arch_data, archs, args.set1_name, args.set2_name)
+ any_has_set2 = any(d.get('has_set2', True) for d in arch_data.values())
+ log_failures = load_log_failures(args.log_failures) if args.log_failures else []
+ if args.log_failures:
+ log_failures.extend(load_flaky_tests_as_log_failures(args.log_failures))
+ shard_lookup = load_log_shards(args.log_failures) if args.log_failures else {}
+
+ _add_cross_arch_info(failed, log_failures, args.set2_name)
+ _add_log_failure_cross_arch(log_failures, failed, args.set1_name, args.set2_name)
+
+ output_base = args.output
+ if output_base.endswith('.csv') or output_base.endswith('.md'):
+ output_base = output_base.rsplit('.', 1)[0]
+
+ write_csv(data_rows, archs, f'{output_base}.csv', failed, args.set1_name, args.set2_name, has_set2=any_has_set2, log_failures=log_failures, shard_lookup=shard_lookup)
+ write_markdown(data_rows, archs, f'{output_base}.md', failed, args.set1_name, args.set2_name, has_set2=any_has_set2, log_failures=log_failures, shard_lookup=shard_lookup)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/.automation_scripts/pytorch-unit-test-scripts/requirements.txt b/.automation_scripts/pytorch-unit-test-scripts/requirements.txt
new file mode 100644
index 0000000000000..9ee33b404d9cd
--- /dev/null
+++ b/.automation_scripts/pytorch-unit-test-scripts/requirements.txt
@@ -0,0 +1,4 @@
+pandas
+rockset
+boto3
+requests
diff --git a/.automation_scripts/pytorch-unit-test-scripts/summarize_xml_testreports.py b/.automation_scripts/pytorch-unit-test-scripts/summarize_xml_testreports.py
new file mode 100755
index 0000000000000..72e587bbf54bd
--- /dev/null
+++ b/.automation_scripts/pytorch-unit-test-scripts/summarize_xml_testreports.py
@@ -0,0 +1,733 @@
+#!/usr/bin/env python3
+
+import argparse
+import csv
+import os
+import re
+import pandas as pd
+from enum import Enum
+from itertools import chain
+from pathlib import Path
+from upload_test_stats import (
+ parse_xml_report,
+ get_pytest_parallel_times,
+ summarize_test_cases,
+)
+
+# unit test status list
+UT_STATUS_LIST = [
+ "PASSED",
+ "MISSED",
+ "SKIPPED",
+ "FAILED",
+ "XFAILED",
+ "ERROR"
+]
+
+# excluded test suites for comparison
+EXCLUDED_TEST_SUITES = [
+ "_nvfuser.test_dynamo",
+ "_nvfuser.test_python_frontend",
+ "_nvfuser.test_torchscript",
+ "test_jit_cuda_fuser",
+ "test_nvfuser_dynamo",
+ "test_nvfuser_frontend"
+]
+
+
+EXCLUDED_TEST_CLASSES = [
+ "nvfuser_tests",
+ "TensorPipeCudaDdpComparisonTest",
+ "TensorPipeCudaDistAutogradTest",
+ "TensorPipeCudaRemoteModuleTest",
+ "TensorPipeCudaRpcTest",
+ "TensorPipeTensorPipeAgentCudaRpcTest",
+ "TensorPipeTensorPipeCudaDistAutogradTest",
+ "test_cpp_rpc"
+]
+EXCLUDED_TESTS = [
+]
+
+
+# Test config names
+TestConfigName = Enum('TestConfigName', ['default', 'distributed', 'inductor'])
+
+def _status_priority(test_case):
+ """Return a numeric priority for deduplication of retried tests.
+ PASSED/XFAILED are preferred over FAILED/ERROR/SKIPPED since a
+ passing retry means the test is considered passing (flaky) in CI."""
+ status = get_test_status(test_case)
+ return {"PASSED": 4, "XFAILED": 3, "SKIPPED": 2, "FAILED": 1, "ERROR": 1, "MISSED": 0}.get(status, 0)
+
+def _extract_shard(dirname):
+ """Extract shard number from directory names like 'test-default-3-6'."""
+ m = re.match(r'test-\w+-(\d+)-(\d+)', dirname)
+ if m:
+ return f"{m.group(1)}/{m.group(2)}"
+ return ""
+
+def parse_xml_reports_as_dict(workflow_run_id, workflow_run_attempt, tag, path="."):
+ test_config = ""
+ test_cases = {}
+ items_list = os.listdir(path)
+ for dir in items_list:
+ new_dir = path + '/' + dir + '/'
+ if os.path.isdir(new_dir):
+ if "test-default" in new_dir:
+ test_config = TestConfigName.default.name
+ elif "test-distributed" in new_dir:
+ test_config = TestConfigName.distributed.name
+ elif "test-inductor" in new_dir:
+ test_config = TestConfigName.inductor.name
+ shard = _extract_shard(dir)
+ for xml_report in Path(new_dir).glob("**/*.xml"):
+ try:
+ new_cases = parse_xml_report(
+ tag,
+ xml_report,
+ workflow_run_id,
+ workflow_run_attempt,
+ test_config
+ )
+ except Exception as e:
+ print(f"WARNING: Skipping malformed XML {xml_report}: {e}")
+ continue
+ for key, case in new_cases.items():
+ case["shard"] = shard
+ existing = test_cases.get(key)
+ if existing is None or _status_priority(case) > _status_priority(existing):
+ test_cases[key] = case
+ return test_cases
+
+def get_test_status(test_case):
+ # In order of priority: S=skipped, F=failure, E=error, P=pass
+ if not test_case:
+ return "MISSED"
+ elif "skipped" in test_case and test_case["skipped"]:
+ type_message = test_case["skipped"]
+ if type_message.__contains__('type') and type_message['type'] == "pytest.xfail":
+ return "XFAILED"
+ else:
+ return "SKIPPED"
+ elif "failure" in test_case and test_case["failure"]:
+ return "FAILED"
+ elif "error" in test_case and test_case["error"]:
+ return "ERROR"
+ else:
+ return "PASSED"
+
+def get_test_message(test_case, status=None):
+ if status == "SKIPPED":
+ return test_case["skipped"] if "skipped" in test_case else ""
+ elif status == "FAILED":
+ return test_case["failure"] if "failure" in test_case else ""
+ elif status == "ERROR":
+ return test_case["error"] if "error" in test_case else ""
+ else:
+ if "skipped" in test_case:
+ return test_case["skipped"]
+ elif "failure" in test_case:
+ return test_case["failure"]
+ elif "error" in test_case:
+ return test_case["error"]
+ else:
+ return ""
+
+def get_running_time(test_case):
+ status = get_test_status(test_case)
+ if test_case.__contains__('time'):
+ return test_case["time"]
+ return ""
+
+def check_time_valid(time):
+ if time == "":
+ return False
+ return True
+
+def summarize_xml_files(args):
+ # TODO: Add arguments and parse accordingly
+ set1_path = args.set1 if args.set1 else "."
+ set2_path = args.set2
+ set1_name = args.set1_name
+ set2_name = args.set2_name
+
+ # statistics
+ SKIPPED_DEFAULT = 0
+ MISSED_DEFAULT = 0
+ CUDA_DEFAULT = 0
+ ROCM_DEFAULT = 0
+ ROCMONLY_DEFAULT = 0
+
+ SKIPPED_DISTRIBUTED = 0
+ MISSED_DISTRIBUTED = 0
+ CUDA_DISTRIBUTED = 0
+ ROCM_DISTRIBUTED = 0
+ ROCMONLY_DISTRIBUTED = 0
+
+ SKIPPED_INDUCTOR = 0
+ MISSED_INDUCTOR = 0
+ CUDA_INDUCTOR = 0
+ ROCM_INDUCTOR = 0
+ ROCMONLY_INDUCTOR = 0
+
+ TOTAL_CUDA_RUNNING_TIME = 0.0
+ TOTAL_ROCM_RUNNING_TIME = 0.0
+
+ # filter example: --filter SKIPPED-PASSED-MISSED-PASSED (tuples: set1 status1 - set2 status1, set1 status2 - set2 status2)
+ ut_status_filter = args.filter if args.filter else "."
+ list_of_status = ut_status_filter.split('-') if args.filter else []
+ # assertion: should be an even number length
+ assert len(list_of_status) % 2 == 0
+ list_status_set1 = []
+ list_status_set2 = []
+
+ index = 0
+ while index < len(list_of_status):
+ # special handling for status-NOT_status scenario
+ if "NOT" in list_of_status[index] or "NOT" in list_of_status[index+1]:
+ if "NOT" in list_of_status[index]:
+ items = list_of_status[index].split('_')
+ not_item = items[1]
+ for ind in range(len(UT_STATUS_LIST)):
+ if UT_STATUS_LIST[ind] != not_item:
+ list_status_set1.append(UT_STATUS_LIST[ind])
+ list_status_set2.append(list_of_status[index+1])
+ else:
+ items = list_of_status[index+1].split('_')
+ not_item = items[1]
+ for ind in range(len(UT_STATUS_LIST)):
+ if UT_STATUS_LIST[ind] != not_item:
+ list_status_set2.append(UT_STATUS_LIST[ind])
+ list_status_set1.append(list_of_status[index])
+ index += 2
+ else:
+ list_status_set1.append(list_of_status[index])
+ index += 1
+ list_status_set2.append(list_of_status[index])
+ index += 1
+
+ assert len(list_status_set1) == len(list_status_set2), \
+ "status_list not specified correctly, should be in pairs of two"
+ len_status_filter = len(list_status_set1)
+
+ # define column list
+ column_list = ['set1', 'set2', 'skip_reason', 'assignee', 'comments']
+
+ # function location pattern
+ pattern = "at 0x"
+
+ #parse the xml files
+ test_cases_set1_running_time = parse_xml_reports_as_dict(-1, -1, 'testsuite', set1_path)
+ # TODO: Does it matter what the workflow_run_attempt is set to below??
+ # test_cases is dict of dicts, with keys as tuple of test_file, test_class, test_name and test_config
+ test_cases_set1 = parse_xml_reports_as_dict(-1, -1, 'testcase', set1_path)
+ for (k,v) in list(test_cases_set1.items()):
+ if v['test_config'] == TestConfigName.default.name:
+ ROCM_DEFAULT += 1
+ elif v['test_config'] == TestConfigName.distributed.name:
+ ROCM_DISTRIBUTED += 1
+ elif v['test_config'] == TestConfigName.inductor.name:
+ ROCM_INDUCTOR += 1
+
+ # start with creating empty dicts for set2 for each test tuple
+ # for rocm/cuda comparison(with valid set2_path), sometimes parity sheet has inaccurate resutls due to different function string but with same test names,
+ # such as test_np_argmin_argmax_keepdims_size_(1, 2, 3, 4)_axis_-4_method_
+ test_cases_set1_new: Dict[Tuple[str], Dict[str, Any]] = {}
+ if set2_path:
+ for (k,v) in list(test_cases_set1.items()):
+ if pattern in k[2]:
+ values = list(k)
+ index = k[2].find(pattern)
+ values[2] = k[2][0 : index]
+ k_new = tuple(values)
+ test_cases_set1_new[k_new] = v
+ del test_cases_set1[k]
+ #combine two dict
+ test_cases_set1_combined = {**test_cases_set1, **test_cases_set1_new}
+ test_cases = { k:[v, {}] for (k,v) in test_cases_set1_combined.items() }
+ else:
+ test_cases = { k:[v, {}] for (k,v) in test_cases_set1.items() }
+
+ test_cases_set2_running_time = {}
+ if set2_path:
+ assert set2_path != set1_path, \
+ "set2 path not specified correctly, should be different from set1 path"
+ test_cases_set2_running_time = parse_xml_reports_as_dict(-1, -1, 'testsuite', set2_path)
+ test_cases_set2 = parse_xml_reports_as_dict(-1, -1, 'testcase', set2_path)
+ for (k,v) in list(test_cases_set2.items()):
+ if v['test_config'] == TestConfigName.default.name:
+ CUDA_DEFAULT += 1
+ elif v['test_config'] == TestConfigName.distributed.name:
+ CUDA_DISTRIBUTED += 1
+ elif v['test_config'] == TestConfigName.inductor.name:
+ CUDA_INDUCTOR += 1
+
+ # for rocm/cuda comparison, sometimes parity sheet has inaccurate resutls due to different function string but with same test names,
+ # such as test_np_argmin_argmax_keepdims_size_(1, 2, 3, 4)_axis_-4_method_
+ test_cases_set2_new: Dict[Tuple[str], Dict[str, Any]] = {}
+ for (k,v) in list(test_cases_set2.items()):
+ if pattern in k[2]:
+ values = list(k)
+ index = k[2].find(pattern)
+ values[2] = k[2][0 : index]
+ k_new = tuple(values)
+ test_cases_set2_new[k_new] = v
+ del test_cases_set2[k]
+ #combine two dict
+ test_cases_set2_combined = {**test_cases_set2, **test_cases_set2_new}
+
+ # repopulate set2 dicts for test_tuples from test_cases_set2,
+ # creating empty dicts for set1 if test_tuple doesn't exist in test_cases
+ for test_case in test_cases_set2_combined:
+ test_cases[test_case] = [test_cases_set1_combined[test_case] if test_case in test_cases_set1_combined else {}, test_cases_set2_combined[test_case]]
+
+ # expand with skip_reason, assignee and comments
+ for (k,v) in list(test_cases.items()):
+ # set1, set2, skip_reason, assignee and comments
+ while len(v) < len(column_list):
+ v.append('')
+
+ # get running time statistics before any exclusion and filter since they are only for comparison
+ # total running time: ROCm and CUDA
+ for (k,v) in list(test_cases_set1_running_time.items()):
+ TOTAL_ROCM_RUNNING_TIME += v["running_time_xml"]
+ for (k,v) in list(test_cases_set2_running_time.items()):
+ TOTAL_CUDA_RUNNING_TIME += v["running_time_xml"]
+
+ # test file level running time: ROCm and CUDA
+ test_file_level_ROCm: Dict[Tuple[str], float] = {}
+ test_file_level_CUDA: Dict[Tuple[str], float] = {}
+ for (k,v) in list(test_cases_set1_running_time.items()):
+ test_file_name = k[0]
+ test_config_name = k[2]
+ tar_tup_rocm = (test_file_name, test_config_name,)
+ if test_file_level_ROCm.get(tar_tup_rocm) == None:
+ test_file_level_ROCm[ ( test_file_name, test_config_name ) ] = v["running_time_xml"]
+ else:
+ test_file_level_ROCm[ ( test_file_name, test_config_name ) ] += v["running_time_xml"]
+ for (k,v) in list(test_cases_set2_running_time.items()):
+ test_file_name = k[0]
+ test_config_name = k[2]
+ tar_tup_cuda = (test_file_name, test_config_name)
+ if test_file_level_CUDA.get(tar_tup_cuda) == None:
+ test_file_level_CUDA[ ( test_file_name, test_config_name ) ] = v["running_time_xml"]
+ else:
+ test_file_level_CUDA[ ( test_file_name, test_config_name ) ] += v["running_time_xml"]
+
+ # test file level counts: ROCm tests run, passed, skipped, missed; CUDA tests run
+ test_file_counts_ROCm: Dict[Tuple[str], Dict[str, int]] = {}
+ test_file_counts_CUDA: Dict[Tuple[str], int] = {}
+ for (k,v) in list(test_cases_set1.items()):
+ test_file_name = k[0]
+ test_config_name = v['test_config']
+ tar_tup = (test_file_name, test_config_name)
+ if tar_tup not in test_file_counts_ROCm:
+ test_file_counts_ROCm[tar_tup] = {'tests_run': 0, 'passed': 0, 'skipped': 0, 'missed': 0}
+ test_file_counts_ROCm[tar_tup]['tests_run'] += 1
+ status = get_test_status(v)
+ if status == "PASSED":
+ test_file_counts_ROCm[tar_tup]['passed'] += 1
+ elif status == "SKIPPED":
+ test_file_counts_ROCm[tar_tup]['skipped'] += 1
+ elif status == "MISSED":
+ test_file_counts_ROCm[tar_tup]['missed'] += 1
+ for (k,v) in list(test_cases_set2.items()) if set2_path else []:
+ test_file_name = k[0]
+ test_config_name = v['test_config']
+ tar_tup = (test_file_name, test_config_name)
+ if tar_tup not in test_file_counts_CUDA:
+ test_file_counts_CUDA[tar_tup] = 0
+ test_file_counts_CUDA[tar_tup] += 1
+
+ # exclude certain tests for comparison
+ if set2_path:
+ for (k,v) in list(test_cases.items()):
+ if k[0] in EXCLUDED_TEST_SUITES:
+ test_cases.pop(k)
+ elif k[1] in EXCLUDED_TEST_CLASSES:
+ test_cases.pop(k)
+ elif (k[0], k[1], k[2]) in EXCLUDED_TESTS:
+ test_cases.pop(k)
+
+ # remove unmatched items if user specified ut status filters
+ if len_status_filter > 0:
+ case_matched = True
+ for (k,v) in list(test_cases.items()):
+ case_matched = False
+ status_set_1 = get_test_status(v[0])
+ status_set_2 = get_test_status(v[1]) if set2_path else ""
+ for index in range(len_status_filter):
+ if status_set_1 == list_status_set1[index] and status_set_2 == list_status_set2[index]:
+ case_matched = True
+ break
+
+ if not case_matched:
+ test_cases.pop(k)
+
+ # insert skip_reason, assignee and comments info for the cases that: rocm-missed+cuda-passed OR rocm-skipped+cuda-passed
+ # To do: assume set1 is ROCm currently. Should insert another arg for ROCm and CUDA order?
+ skip_reasons_stat_default = dict()
+ skip_reasons_stat_distributed = dict()
+ skip_reasons_stat_inductor = dict()
+ if args.skip_reasons:
+ # read skip reasons csv file
+ known_skips = pd.read_csv(args.skip_reasons, sep='\t')
+ known_skips = known_skips.to_dict(orient="records")
+
+ # Load previous week's CSV to check if tests existed and get skip reasons
+ prev_week_tests = set()
+ prev_week_skip_reasons = {} # Maps (test_file, test_class, test_name) -> (skip_reason, assignee, comments)
+ if args.prev_week_csv:
+ prev_week_df = pd.read_csv(args.prev_week_csv)
+ for _, row in prev_week_df.iterrows():
+ test_key = (row['test_file'], row['test_class'], row['test_name'])
+ prev_week_tests.add(test_key)
+ # Also extract skip_reason, assignee, comments if they exist
+ skip_reason = row.get('skip_reason', '') if 'skip_reason' in row and not pd.isna(row.get('skip_reason', '')) else ''
+ assignee = row.get('assignee', '') if 'assignee' in row and not pd.isna(row.get('assignee', '')) else ''
+ comments = row.get('comments', '') if 'comments' in row and not pd.isna(row.get('comments', '')) else ''
+ if skip_reason or assignee or comments:
+ prev_week_skip_reasons[test_key] = (skip_reason, assignee, comments)
+
+ for (k,v) in list(test_cases.items()):
+ status_set_1 = get_test_status(v[0])
+ status_set_2 = get_test_status(v[1]) if set2_path else ""
+ test_file_name = k[0]
+ test_info = v[0]
+ test_info_set2 = []
+ if status_set_1 == "SKIPPED" and status_set_2 != "SKIPPED":
+ if test_info['test_config'] == TestConfigName.default.name:
+ SKIPPED_DEFAULT += 1
+ elif test_info['test_config'] == TestConfigName.distributed.name:
+ SKIPPED_DISTRIBUTED += 1
+ elif test_info['test_config'] == TestConfigName.inductor.name:
+ SKIPPED_INDUCTOR += 1
+ elif set2_path:
+ test_info_set2 = v[1]
+ if status_set_1 == "MISSED" and status_set_2 != "MISSED":
+ if test_info_set2['test_config'] == TestConfigName.default.name:
+ MISSED_DEFAULT += 1
+ elif test_info_set2['test_config'] == TestConfigName.distributed.name:
+ MISSED_DISTRIBUTED += 1
+ elif test_info_set2['test_config'] == TestConfigName.inductor.name:
+ MISSED_INDUCTOR += 1
+
+
+ if args.skip_reasons:
+ if (status_set_1 == "SKIPPED" and status_set_2 != "SKIPPED") or status_set_1 == "MISSED":
+ for known_skip in known_skips:
+ if test_file_name == known_skip['test_file'] and k[1] == known_skip['test_class'] and k[2] == known_skip['test_name']:
+ v[2] = known_skip['skip_reason'] if known_skip.__contains__('skip_reason') and not pd.isna(known_skip['skip_reason']) else ' '
+ if (test_info.__contains__('test_config') and test_info['test_config'] == TestConfigName.default.name) or (test_info_set2.__contains__('test_config') and test_info_set2['test_config'] == TestConfigName.default.name):
+ if not skip_reasons_stat_default.__contains__(v[2]):
+ skip_reasons_stat_default[v[2]] = 1
+ else:
+ skip_reasons_stat_default[v[2]] += 1
+ elif (test_info.__contains__('test_config') and test_info['test_config'] == TestConfigName.distributed.name) or (test_info_set2.__contains__('test_config') and test_info_set2['test_config'] == TestConfigName.distributed.name):
+ if not skip_reasons_stat_distributed.__contains__(v[2]):
+ skip_reasons_stat_distributed[v[2]] = 1
+ else:
+ skip_reasons_stat_distributed[v[2]] += 1
+ elif (test_info.__contains__('test_config') and test_info['test_config'] == TestConfigName.inductor.name) or (test_info_set2.__contains__('test_config') and test_info_set2['test_config'] == TestConfigName.inductor.name):
+ if not skip_reasons_stat_inductor.__contains__(v[2]):
+ skip_reasons_stat_inductor[v[2]] = 1
+ else:
+ skip_reasons_stat_inductor[v[2]] += 1
+ v[3] = known_skip['assignee'] if known_skip.__contains__('assignee') and not pd.isna(known_skip['assignee']) else ' '
+ v[4] = known_skip['comments'] if known_skip.__contains__('comments') and not pd.isna(known_skip['comments']) else ' '
+ break
+
+ if status_set_1 == "PASSED" and status_set_2 != "PASSED" and set2_path:
+ if test_info['test_config'] == TestConfigName.default.name:
+ ROCMONLY_DEFAULT += 1
+ elif test_info['test_config'] == TestConfigName.distributed.name:
+ ROCMONLY_DISTRIBUTED += 1
+ elif test_info['test_config'] == TestConfigName.inductor.name:
+ ROCMONLY_INDUCTOR += 1
+
+ skip_reasons_stat_default.pop(' ', None)
+ skip_reasons_stat_distributed.pop(' ', None)
+
+ test_cases_for_csv = {}
+ # k is test_tuple, v is list of rocm and cuda info for that test_tuple
+ skip_reason_file_specified = False
+ if args.skip_reasons:
+ skip_reason_file_specified = True
+ for (k,v) in test_cases.items():
+ item_values = {}
+ item_values["test_file"] = k[0]
+ item_values["test_class"] = k[1]
+ item_values["test_name"] = k[2]
+ item_values[f"status_{set1_name}"] = get_test_status(v[0])
+ item_values[f"status_{set2_name}"] = get_test_status(v[1]) if set2_path else ""
+ # get test config info
+ v_values = v[0]
+ v1_values = v[1] if set2_path else []
+ config_name = ""
+ item_values["test_config"] = ""
+ if item_values[f"status_{set1_name}"] != "MISSED":
+ config_name = v_values['test_config']
+ elif item_values[f"status_{set2_name}"] != "MISSED" and item_values[f"status_{set2_name}"] != "":
+ config_name = v1_values['test_config']
+ item_values["test_config"] = config_name
+ item_values[f"shard_{set1_name}"] = v_values.get('shard', '') if v_values else ''
+ item_values[f"shard_{set2_name}"] = v1_values.get('shard', '') if v1_values else ''
+ # get test related info
+ item_values[f"message_{set1_name}"] = get_test_message(v[0])
+ item_values[f"message_{set2_name}"] = get_test_message(v[1]) if set2_path else ""
+ # Get skip_reason, assignee, comments from --skip_reasons file if specified
+ if skip_reason_file_specified:
+ item_values["skip_reason"] = v[2]
+ item_values["assignee"] = v[3]
+ item_values["comments"] = v[4]
+ # Check if test existed in previous week's CSV and get skip reasons from there
+ if args.prev_week_csv:
+ test_key = (k[0], k[1], k[2]) # (test_file, test_class, test_name)
+ item_values["existed_last_week"] = "yes" if test_key in prev_week_tests else "no"
+ # If skip_reason not set by --skip_reasons, try to get from prev_week_csv
+ if not skip_reason_file_specified:
+ if test_key in prev_week_skip_reasons:
+ prev_skip_reason, prev_assignee, prev_comments = prev_week_skip_reasons[test_key]
+ item_values["skip_reason"] = prev_skip_reason
+ item_values["assignee"] = prev_assignee
+ item_values["comments"] = prev_comments
+ else:
+ item_values["skip_reason"] = ""
+ item_values["assignee"] = ""
+ item_values["comments"] = ""
+ if not skip_reason_file_specified and not args.prev_week_csv:
+ item_values["skip_reason"] = ""
+ item_values["assignee"] = ""
+ item_values["comments"] = ""
+ running_time1 = get_running_time(v[0])
+ item_values[f"running_time_{set1_name}"] = running_time1
+ running_time2 = get_running_time(v[1])
+ item_values[f"running_time_{set2_name}"] = running_time2
+ item_values["abs_time_diff"] = ""
+ item_values["relative_time_diff"] = ""
+ if check_time_valid(running_time1) and check_time_valid(running_time2):
+ item_values["abs_time_diff"] = running_time1 - running_time2
+ if get_running_time(v[1]) != 0.0:
+ item_values["relative_time_diff"] = 100 * (running_time1 - running_time2) / running_time2
+ test_cases_for_csv[k] = item_values
+
+ test_cases_for_csv = dict(sorted(test_cases_for_csv.items()))
+
+ #store test_cases in csv
+ tests_from_xml_filename = args.output_csv
+ keys_list = list(set(chain.from_iterable(sub.keys() for sub in test_cases_for_csv.values())))
+
+ def sorting_key(e):
+ if e == "invoking_file":
+ return 0
+ elif e == "test_file":
+ return 1
+ elif e == "test_class":
+ return 2
+ elif e == "test_name":
+ return 3
+ elif e == "test_config":
+ return 4
+ elif e == "skip_reason":
+ return 5
+ elif e == "assignee":
+ return 6
+ elif e == "comments":
+ return 7
+ elif e == f"status_{set1_name}":
+ return 8
+ elif e == f"message_{set1_name}":
+ return 9
+ elif e == f"running_time_{set1_name}":
+ return 10
+ elif e == f"status_{set2_name}":
+ return 11
+ elif e == f"message_{set2_name}":
+ return 12
+ elif e == f"running_time_{set2_name}":
+ return 13
+ elif e == "abs_time_diff":
+ return 14
+ elif e == "relative_time_diff":
+ return 15
+ elif e == "skipped":
+ return 16
+ elif e == "failure":
+ return 17
+ elif e == "error":
+ return 18
+ elif e == "system-out":
+ return 19
+ elif e == "existed_last_week":
+ return 20
+ elif e == f"shard_{set1_name}":
+ return 21
+ elif e == f"shard_{set2_name}":
+ return 22
+ elif e == "workflow_run_attempt" or e == "job_id":
+ return 1000
+ else:
+ return 100
+
+ keys_list.sort(key=sorting_key)
+
+ with open(tests_from_xml_filename, "w") as outfile:
+ writer = csv.DictWriter(outfile, fieldnames = keys_list)
+ writer.writeheader()
+ writer.writerows(test_cases_for_csv.values())
+ ## TODO - usage yet to be identified
+ #pytest_parallel_times = get_pytest_parallel_times()
+ ##extract test cases summary and save them in csv file
+ #test_cases_summary = summarize_test_cases(test_cases)
+ #testcases_summary_filename = "testcases_summary.csv"
+ #keys_list = list(set(chain.from_iterable(sub.keys() for sub in test_cases_summary)))
+ #with open(testcases_summary_filename, "w") as outfile:
+ # writer = csv.DictWriter(outfile, fieldnames = keys_list)
+ # writer.writeheader()
+ # writer.writerows(test_cases_summary)
+
+ # write test file running time to file
+ test_file_running_time_for_csv = {}
+ for key_rocm in test_file_level_ROCm.keys():
+ item_values = {}
+ item_values["test_file"] = key_rocm[0]
+ item_values["test_config"] = key_rocm[1]
+ item_values["rocm_running_time"] = test_file_level_ROCm[key_rocm]
+ item_values["cuda_running_time"] = 0.0
+ if key_rocm in test_file_level_CUDA.keys():
+ item_values["cuda_running_time"] = test_file_level_CUDA[key_rocm]
+ item_values["abs_time_diff"] = item_values["rocm_running_time"] - item_values["cuda_running_time"]
+ item_values["relative_time_diff"] = 0.0
+ if item_values["cuda_running_time"] != 0.0:
+ item_values["relative_time_diff"] = 100 * (item_values["rocm_running_time"] - item_values["cuda_running_time"]) / item_values["cuda_running_time"]
+ # Add test counts
+ item_values["rocm_tests_run"] = test_file_counts_ROCm.get(key_rocm, {}).get('tests_run', 0)
+ item_values["cuda_tests_run"] = test_file_counts_CUDA.get(key_rocm, 0)
+ item_values["rocm_passed"] = test_file_counts_ROCm.get(key_rocm, {}).get('passed', 0)
+ item_values["rocm_skipped"] = test_file_counts_ROCm.get(key_rocm, {}).get('skipped', 0)
+ item_values["rocm_missed"] = test_file_counts_ROCm.get(key_rocm, {}).get('missed', 0)
+ test_file_running_time_for_csv[key_rocm] = item_values
+
+ for key_cuda in test_file_level_CUDA.keys():
+ if not key_cuda in test_file_level_ROCm.keys():
+ item_values = {}
+ item_values["test_file"] = key_cuda[0]
+ item_values["test_config"] = key_cuda[1]
+ item_values["rocm_running_time"] = 0.0
+ item_values["cuda_running_time"] = test_file_level_CUDA[key_cuda]
+ item_values["abs_time_diff"] = item_values["rocm_running_time"] - item_values["cuda_running_time"]
+ item_values["relative_time_diff"] = 0.0
+ if item_values["cuda_running_time"] != 0.0:
+ item_values["relative_time_diff"] = 100 * (item_values["rocm_running_time"] - item_values["cuda_running_time"]) / item_values["cuda_running_time"]
+ # Add test counts
+ item_values["rocm_tests_run"] = test_file_counts_ROCm.get(key_cuda, {}).get('tests_run', 0)
+ item_values["cuda_tests_run"] = test_file_counts_CUDA.get(key_cuda, 0)
+ item_values["rocm_passed"] = test_file_counts_ROCm.get(key_cuda, {}).get('passed', 0)
+ item_values["rocm_skipped"] = test_file_counts_ROCm.get(key_cuda, {}).get('skipped', 0)
+ item_values["rocm_missed"] = test_file_counts_ROCm.get(key_cuda, {}).get('missed', 0)
+ test_file_running_time_for_csv[key_cuda] = item_values
+
+ test_file_running_time_for_csv = dict(sorted(test_file_running_time_for_csv.items()))
+ keys_list_running_time = list(set(chain.from_iterable(sub.keys() for sub in test_file_running_time_for_csv.values())))
+ def sorting_key_running_time(e):
+ if e == "test_file":
+ return 0
+ elif e == "test_config":
+ return 1
+ elif e == "rocm_running_time":
+ return 2
+ elif e == "cuda_running_time":
+ return 3
+ elif e == "abs_time_diff":
+ return 4
+ elif e == "relative_time_diff":
+ return 5
+ elif e == "rocm_tests_run":
+ return 6
+ elif e == "cuda_tests_run":
+ return 7
+ elif e == "rocm_passed":
+ return 8
+ elif e == "rocm_skipped":
+ return 9
+ elif e == "rocm_missed":
+ return 10
+ else:
+ return 100
+
+ keys_list_running_time.sort(key=sorting_key_running_time)
+ tests_from_xml_file_running_time = args.test_file_running_time_output_csv
+ with open(tests_from_xml_file_running_time, "w") as outfile:
+ writer = csv.DictWriter(outfile, fieldnames = keys_list_running_time)
+ writer.writeheader()
+ writer.writerows(test_file_running_time_for_csv.values())
+
+ # print summary
+ print( " " )
+ print( "_____________________________________" )
+ print( "Test-results" )
+ print( " " )
+ print( "=====Single GPU Number=====" )
+ print( "SKIPPED_DEFAULT, MISSED_DEFAULT, ROCMONLY_DEFAULT, CUDA_DEFAULT, ROCM_DEFAULT" )
+ print( str(SKIPPED_DEFAULT) + ", " + str(MISSED_DEFAULT) + ", " + str(ROCMONLY_DEFAULT) + ", " + str(CUDA_DEFAULT) + ", " + str(ROCM_DEFAULT) )
+ print( " " )
+ print( "=====Distributed GPU Number=====" )
+ print( "SKIPPED_DISTRIBUTED, MISSED_DISTRIBUTED, ROCMONLY_DISTRIBUTED, CUDA_DISTRIBUTED, ROCM_DISTRIBUTED" )
+ print( str(SKIPPED_DISTRIBUTED) + ", " + str(MISSED_DISTRIBUTED) + ", " + str(ROCMONLY_DISTRIBUTED) + ", " + str(CUDA_DISTRIBUTED) + ", " + str(ROCM_DISTRIBUTED) )
+ print( " " )
+ print( "=====Inductor GPU Number=====" )
+ print( "SKIPPED_INDUCTOR, MISSED_INDUCTOR, ROCMONLY_INDUCTOR, CUDA_INDUCTOR, ROCM_INDUCTOR" )
+ print( str(SKIPPED_INDUCTOR) + ", " + str(MISSED_INDUCTOR) + ", " + str(ROCMONLY_INDUCTOR) + ", " + str(CUDA_INDUCTOR) + ", " + str(ROCM_INDUCTOR) )
+ print( " " )
+ print( "SELECTED CAUSES SUMMARY" )
+ print( " " )
+ print( "=====================" )
+ print( "Single GPU test" )
+ sorted_skip_reasons_statistics_default = sorted(skip_reasons_stat_default.keys(), key = lambda x : x.lower())
+ for skip_reason_entry in sorted_skip_reasons_statistics_default:
+ print( skip_reason_entry, ": ", skip_reasons_stat_default[skip_reason_entry] )
+ print( " " )
+ print( "=====================" )
+ print( "Distributed test" )
+ sorted_skip_reasons_distributed_statistics = sorted(skip_reasons_stat_distributed.keys(), key = lambda x : x.lower())
+ for skip_reason_entry in sorted_skip_reasons_distributed_statistics:
+ print( skip_reason_entry, ": ", skip_reasons_stat_distributed[skip_reason_entry] )
+ print( " " )
+ print( "=====================" )
+ print( "Inductor test" )
+ sorted_skip_reasons_statistics_inductor = sorted(skip_reasons_stat_inductor.keys(), key = lambda x : x.lower())
+ for skip_reason_entry in sorted_skip_reasons_statistics_inductor:
+ print( skip_reason_entry, ": ", skip_reasons_stat_inductor[skip_reason_entry] )
+ print( " " )
+ print( "=====================" )
+ print( "Time statistics" )
+ print( "ROCM_RUNNING_TIME, CUDA_RUNNING_TIME" )
+ print( str(TOTAL_ROCM_RUNNING_TIME) + ", " + str(TOTAL_CUDA_RUNNING_TIME) )
+ #print( "ROCm test file level time statistics" )
+ #for (k,v) in list(test_file_level_ROCm.items()):
+ #print( k[0] + ", " + k[1] + ", " + k[2] + ", " + str(v) )
+ #print( "CUDA test file level time statistics" )
+ #for (k,v) in list(test_file_level_CUDA.items()):
+ #print( k[0] + ", " + k[1] + ", " + k[2] + ", " + str(v) )
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='Parse xml test-reports')
+ parser.add_argument("--set1", required=False, type=str, help="absolute or relative path to first test-reports dir")
+ parser.add_argument("--set2", required=False, type=str, help="absolute or relative path to second test-reports dir")
+ parser.add_argument("--set1_name", required=False, type=str, default="set1", help="display name for set1 in CSV column headers (default: set1)")
+ parser.add_argument("--set2_name", required=False, type=str, default="set2", help="display name for set2 in CSV column headers (default: set2)")
+ parser.add_argument("--output_csv", required=False, type=str, help="output csv filename", default="tests_from_xml.csv")
+ parser.add_argument("--filter", required=False, type=str, help="ut status filter flag")
+ parser.add_argument("--skip_reasons", required=False, type=str, help='skip reasons file')
+ parser.add_argument("--test_file_running_time_output_csv", required=False, type=str, help="file running time output csv filename", default="file_running_time_output.csv")
+ parser.add_argument("--prev_week_csv", required=False, type=str, help="previous week's all tests status CSV file to check if tests existed")
+ return parser.parse_args()
+
+def main():
+ global args
+ args = parse_args()
+ summarize_xml_files(args)
+
+if __name__ == "__main__":
+ main()
+
diff --git a/.automation_scripts/pytorch-unit-test-scripts/upload_stats_lib.py b/.automation_scripts/pytorch-unit-test-scripts/upload_stats_lib.py
new file mode 100644
index 0000000000000..218e35768ef2c
--- /dev/null
+++ b/.automation_scripts/pytorch-unit-test-scripts/upload_stats_lib.py
@@ -0,0 +1,187 @@
+import gzip
+import io
+import json
+import os
+import xml.etree.ElementTree as ET
+import zipfile
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+
+import boto3 # type: ignore[import]
+import requests
+import rockset # type: ignore[import]
+
+PYTORCH_REPO = "https://api.github.com/repos/pytorch/pytorch"
+S3_RESOURCE = boto3.resource("s3")
+TARGET_WORKFLOW = "--rerun-disabled-tests"
+
+
+def _get_request_headers() -> Dict[str, str]:
+ return {
+ "Accept": "application/vnd.github.v3+json",
+ "Authorization": "token " + os.environ["GITHUB_TOKEN"],
+ }
+
+
+def _get_artifact_urls(prefix: str, workflow_run_id: int) -> Dict[Path, str]:
+ """Get all workflow artifacts with 'test-report' in the name."""
+ response = requests.get(
+ f"{PYTORCH_REPO}/actions/runs/{workflow_run_id}/artifacts?per_page=100",
+ )
+ artifacts = response.json()["artifacts"]
+ while "next" in response.links.keys():
+ response = requests.get(
+ response.links["next"]["url"], headers=_get_request_headers()
+ )
+ artifacts.extend(response.json()["artifacts"])
+
+ artifact_urls = {}
+ for artifact in artifacts:
+ if artifact["name"].startswith(prefix):
+ artifact_urls[Path(artifact["name"])] = artifact["archive_download_url"]
+ return artifact_urls
+
+
+def _download_artifact(
+ artifact_name: Path, artifact_url: str, workflow_run_attempt: int
+) -> Path:
+ # [Artifact run attempt]
+ # All artifacts on a workflow share a single namespace. However, we can
+ # re-run a workflow and produce a new set of artifacts. To avoid name
+ # collisions, we add `-runattempt1-` somewhere in the artifact name.
+ #
+ # This code parses out the run attempt number from the artifact name. If it
+ # doesn't match the one specified on the command line, skip it.
+ atoms = str(artifact_name).split("-")
+ for atom in atoms:
+ if atom.startswith("runattempt"):
+ found_run_attempt = int(atom[len("runattempt") :])
+ if workflow_run_attempt != found_run_attempt:
+ print(
+ f"Skipping {artifact_name} as it is an invalid run attempt. "
+ f"Expected {workflow_run_attempt}, found {found_run_attempt}."
+ )
+
+ print(f"Downloading {artifact_name}")
+
+ response = requests.get(artifact_url, headers=_get_request_headers())
+ with open(artifact_name, "wb") as f:
+ f.write(response.content)
+ return artifact_name
+
+
+def download_s3_artifacts(
+ prefix: str,
+ workflow_run_id: int,
+ workflow_run_attempt: int,
+ allowed_substrings: Optional[List[str]] = None,
+) -> List[Path]:
+ bucket = S3_RESOURCE.Bucket("gha-artifacts")
+ objs = bucket.objects.filter(
+ Prefix=f"pytorch/pytorch/{workflow_run_id}/{workflow_run_attempt}/artifact/{prefix}"
+ )
+
+ found_one = False
+ paths = []
+ for obj in objs:
+ p = Path(Path(obj.key).name)
+ if allowed_substrings and not any(sub in p.name for sub in allowed_substrings):
+ continue
+ found_one = True
+ print(f"Downloading {p}")
+ with open(p, "wb") as f:
+ f.write(obj.get()["Body"].read())
+ paths.append(p)
+
+ if not found_one:
+ print(
+ "::warning title=s3 artifacts not found::"
+ "Didn't find any test reports in s3, there might be a bug!"
+ )
+ return paths
+
+
+def download_gha_artifacts(
+ prefix: str, workflow_run_id: int, workflow_run_attempt: int
+) -> List[Path]:
+ artifact_urls = _get_artifact_urls(prefix, workflow_run_id)
+ paths = []
+ for name, url in artifact_urls.items():
+ paths.append(_download_artifact(Path(name), url, workflow_run_attempt))
+ return paths
+
+
+def upload_to_rockset(collection: str, docs: List[Any]) -> None:
+ print(f"Writing {len(docs)} documents to Rockset")
+ client = rockset.Client(
+ api_server="api.rs2.usw2.rockset.com", api_key=os.environ["ROCKSET_API_KEY"]
+ )
+ client.Collection.retrieve(collection).add_docs(docs)
+ print("Done!")
+
+
+def upload_to_s3(
+ workflow_run_id: int,
+ workflow_run_attempt: int,
+ collection: str,
+ docs: List[Dict[str, Any]],
+) -> None:
+ print(f"Writing {len(docs)} documents to S3")
+ body = io.StringIO()
+ for doc in docs:
+ json.dump(doc, body)
+ body.write("\n")
+
+ S3_RESOURCE.Object(
+ "ossci-raw-job-status",
+ f"{collection}/{workflow_run_id}/{workflow_run_attempt}",
+ ).put(
+ Body=gzip.compress(body.getvalue().encode()),
+ ContentEncoding="gzip",
+ ContentType="application/json",
+ )
+ print("Done!")
+
+
+def upload_file_to_s3(
+ file_name: str,
+ bucket: str,
+ key: str,
+) -> None:
+ """
+ Upload a local file to S3
+ """
+ print(f"Upload {file_name} to s3://{bucket}/{key}")
+ boto3.client("s3").upload_file(
+ file_name,
+ bucket,
+ key,
+ )
+
+
+def unzip(p: Path) -> None:
+ """Unzip the provided zipfile to a similarly-named directory.
+
+ Returns None if `p` is not a zipfile.
+
+ Looks like: /tmp/test-reports.zip -> /tmp/unzipped-test-reports/
+ """
+ assert p.is_file()
+ unzipped_dir = p.with_name("unzipped-" + p.stem)
+ print(f"Extracting {p} to {unzipped_dir}")
+
+ with zipfile.ZipFile(p, "r") as zip:
+ zip.extractall(unzipped_dir)
+
+
+def is_rerun_disabled_tests(root: ET.ElementTree) -> bool:
+ """
+ Check if the test report is coming from rerun_disabled_tests workflow
+ """
+ skipped = root.find(".//*skipped")
+ # Need to check against None here, if not skipped doesn't work as expected
+ if skipped is None:
+ return False
+
+ message = skipped.attrib.get("message", "")
+ return TARGET_WORKFLOW in message or "num_red" in message
diff --git a/.automation_scripts/pytorch-unit-test-scripts/upload_test_stats.py b/.automation_scripts/pytorch-unit-test-scripts/upload_test_stats.py
new file mode 100644
index 0000000000000..29384d3bd0b41
--- /dev/null
+++ b/.automation_scripts/pytorch-unit-test-scripts/upload_test_stats.py
@@ -0,0 +1,394 @@
+import argparse
+import os
+import sys
+import xml.etree.ElementTree as ET
+from pathlib import Path
+from tempfile import TemporaryDirectory
+from typing import Any, Dict, List, Tuple
+
+from upload_stats_lib import (
+ download_gha_artifacts,
+ download_s3_artifacts,
+ is_rerun_disabled_tests,
+ unzip,
+ upload_to_s3,
+)
+
+
+# Backends list
+BACKENDS_LIST = [
+ "dist-gloo",
+ "dist-nccl"
+]
+
+def get_job_id(report: Path) -> int:
+ # [Job id in artifacts]
+ # Retrieve the job id from the report path. In our GHA workflows, we append
+ # the job id to the end of the report name, so `report` looks like:
+ # unzipped-test-reports-foo_5596745227/test/test-reports/foo/TEST-foo.xml
+ # and we want to get `5596745227` out of it.
+ try:
+ return int(report.parts[0].rpartition("_")[2])
+ except ValueError:
+ return -1
+
+
+def parse_xml_report(
+ tag: str,
+ report: Path,
+ workflow_id: int,
+ workflow_run_attempt: int,
+ test_config: str
+) -> Dict[Tuple[str], Dict[str, Any]]:
+ """Convert a test report xml file into a JSON-serializable list of test cases."""
+ #print(f"Parsing {tag}s for test report: {report}")
+ print(".", end="", flush=True)
+
+ job_id = get_job_id(report)
+ #print(f"Found job id: {job_id}")
+
+ test_cases: Dict[Tuple[str], Dict[str, Any]] = {}
+
+ root = ET.parse(report)
+ # TODO: unlike unittest, pytest-flakefinder used by rerun disabled tests for test_ops
+ # includes skipped messages multiple times (50 times by default). This slows down
+ # this script too much (O(n)) because it tries to gather all the stats. This should
+ # be fixed later in the way we use pytest-flakefinder. A zipped test report from rerun
+ # disabled test is only few MB, but will balloon up to a much bigger XML file after
+ # extracting from a dozen to few hundred MB
+ if is_rerun_disabled_tests(root):
+ return test_cases
+
+ for test_case in root.iter(tag):
+ case = process_xml_element(test_case)
+ if tag == 'testcase':
+ case["workflow_id"] = workflow_id
+ case["workflow_run_attempt"] = workflow_run_attempt
+ case["job_id"] = job_id
+ case["test_config"] = test_config
+
+ # [invoking file]
+ # The name of the file that the test is located in is not necessarily
+ # the same as the name of the file that invoked the test.
+ # For example, `test_jit.py` calls into multiple other test files (e.g.
+ # jit/test_dce.py). For sharding/test selection purposes, we want to
+ # record the file that invoked the test.
+ #
+ # To do this, we leverage an implementation detail of how we write out
+ # tests (https://bit.ly/3ajEV1M), which is that reports are created
+ # under a folder with the same name as the invoking file.
+ case_name = report.parent.name
+ for part in report.parts:
+ for backend in BACKENDS_LIST:
+ if backend in part:
+ case_name = case_name + "_" + part
+ break
+ else:
+ continue
+ break
+ case["invoking_file"] = case_name
+ test_cases[ ( case["invoking_file"], case["classname"], case["name"], case["test_config"] ) ] = case
+ elif tag == 'testsuite':
+ case["test_config"] = test_config
+ case["invoking_xml"] = report.name
+ case["running_time_xml"] = case["time"]
+ case_name = report.parent.name
+ for part in report.parts:
+ for backend in BACKENDS_LIST:
+ if backend in part:
+ case_name = case_name + "_" + part
+ break
+ else:
+ continue
+ break
+ case["invoking_file"] = case_name
+ test_cases[ ( case["invoking_file"], case["invoking_xml"], case["test_config"] ) ] = case
+
+ return test_cases
+
+
+def process_xml_element(element: ET.Element) -> Dict[str, Any]:
+ """Convert a test suite element into a JSON-serializable dict."""
+ ret: Dict[str, Any] = {}
+
+ # Convert attributes directly into dict elements.
+ # e.g.
+ #
+ # becomes:
+ # {"name": "test_foo", "classname": "test_bar"}
+ ret.update(element.attrib)
+
+ # The XML format encodes all values as strings. Convert to ints/floats if
+ # possible to make aggregation possible in Rockset.
+ for k, v in ret.items():
+ try:
+ ret[k] = int(v)
+ except ValueError:
+ pass
+ try:
+ ret[k] = float(v)
+ except ValueError:
+ pass
+
+ # Convert inner and outer text into special dict elements.
+ # e.g.
+ # my_inner_text my_tail
+ # becomes:
+ # {"text": "my_inner_text", "tail": " my_tail"}
+ if element.text and element.text.strip():
+ ret["text"] = element.text
+ if element.tail and element.tail.strip():
+ ret["tail"] = element.tail
+
+ # Convert child elements recursively, placing them at a key:
+ # e.g.
+ #
+ # hello
+ # world
+ # another
+ #
+ # becomes
+ # {
+ # "foo": [{"text": "hello"}, {"text": "world"}],
+ # "bar": {"text": "another"}
+ # }
+ for child in element:
+ if child.tag not in ret:
+ ret[child.tag] = process_xml_element(child)
+ else:
+ # If there are multiple tags with the same name, they should be
+ # coalesced into a list.
+ if not isinstance(ret[child.tag], list):
+ ret[child.tag] = [ret[child.tag]]
+ ret[child.tag].append(process_xml_element(child))
+ return ret
+
+
+def get_pytest_parallel_times() -> Dict[Any, Any]:
+ pytest_parallel_times: Dict[Any, Any] = {}
+ for report in Path(".").glob("**/python-pytest/**/*.xml"):
+ invoking_file = report.parent.name
+
+ root = ET.parse(report)
+ # TODO: Skip test reports from rerun disabled tests, same reason as mentioned
+ # above
+ if is_rerun_disabled_tests(root):
+ continue
+
+ assert len(list(root.iter("testsuite"))) == 1
+ for test_suite in root.iter("testsuite"):
+ pytest_parallel_times[
+ (invoking_file, get_job_id(report))
+ ] = test_suite.attrib["time"]
+ return pytest_parallel_times
+
+
+def get_tests(
+ workflow_run_id: int, workflow_run_attempt: int
+) -> Tuple[List[Dict[str, Any]], Dict[Any, Any]]:
+ with TemporaryDirectory() as temp_dir:
+ print("Using temporary directory:", temp_dir)
+ os.chdir(temp_dir)
+
+ # Download and extract all the reports (both GHA and S3)
+ s3_paths = download_s3_artifacts(
+ "test-report", workflow_run_id, workflow_run_attempt
+ )
+ for path in s3_paths:
+ unzip(path)
+
+ artifact_paths = download_gha_artifacts(
+ "test-report", workflow_run_id, workflow_run_attempt
+ )
+ for path in artifact_paths:
+ unzip(path)
+
+ # Parse the reports and transform them to JSON
+ test_cases = []
+ for xml_report in Path(".").glob("**/*.xml"):
+ test_cases.extend(
+ parse_xml_report(
+ "testcase",
+ xml_report,
+ workflow_run_id,
+ workflow_run_attempt,
+ )
+ )
+
+ pytest_parallel_times = get_pytest_parallel_times()
+
+ return test_cases, pytest_parallel_times
+
+
+def get_tests_for_circleci(
+ workflow_run_id: int, workflow_run_attempt: int
+) -> Tuple[List[Dict[str, Any]], Dict[Any, Any]]:
+ # Parse the reports and transform them to JSON
+ test_cases = []
+ for xml_report in Path(".").glob("**/test/test-reports/**/*.xml"):
+ test_cases.extend(
+ parse_xml_report(
+ "testcase", xml_report, workflow_run_id, workflow_run_attempt
+ )
+ )
+
+ pytest_parallel_times = get_pytest_parallel_times()
+
+ return test_cases, pytest_parallel_times
+
+
+def get_invoking_file_times(
+ test_case_summaries: List[Dict[str, Any]], pytest_parallel_times: Dict[Any, Any]
+) -> List[Dict[str, Any]]:
+ def get_key(summary: Dict[str, Any]) -> Any:
+ return (
+ summary["invoking_file"],
+ summary["job_id"],
+ )
+
+ def init_value(summary: Dict[str, Any]) -> Any:
+ return {
+ "job_id": summary["job_id"],
+ "workflow_id": summary["workflow_id"],
+ "workflow_run_attempt": summary["workflow_run_attempt"],
+ "invoking_file": summary["invoking_file"],
+ "time": 0.0,
+ }
+
+ ret = {}
+ for summary in test_case_summaries:
+ key = get_key(summary)
+ if key not in ret:
+ ret[key] = init_value(summary)
+ ret[key]["time"] += summary["time"]
+
+ for key, val in ret.items():
+ # when running in parallel in pytest, adding the test times will not give the correct
+ # time used to run the file, which will make the sharding incorrect, so if the test is
+ # run in parallel, we take the time reported by the testsuite
+ if key in pytest_parallel_times:
+ val["time"] = pytest_parallel_times[key]
+
+ return list(ret.values())
+
+
+def summarize_test_cases(test_cases: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """Group test cases by classname, file, and job_id. We perform the aggregation
+ manually instead of using the `test-suite` XML tag because xmlrunner does
+ not produce reliable output for it.
+ """
+
+ def get_key(test_case: Dict[str, Any]) -> Any:
+ return (
+ test_case.get("file"),
+ test_case.get("classname"),
+ test_case["job_id"],
+ test_case["workflow_id"],
+ test_case["workflow_run_attempt"],
+ # [see: invoking file]
+ test_case["invoking_file"],
+ )
+
+ def init_value(test_case: Dict[str, Any]) -> Dict[str, Any]:
+ return {
+ "file": test_case.get("file"),
+ "classname": test_case.get("classname"),
+ "job_id": test_case["job_id"],
+ "workflow_id": test_case["workflow_id"],
+ "workflow_run_attempt": test_case["workflow_run_attempt"],
+ # [see: invoking file]
+ "invoking_file": test_case["invoking_file"],
+ "tests": 0,
+ "failures": 0,
+ "errors": 0,
+ "skipped": 0,
+ "successes": 0,
+ "time": 0.0,
+ }
+
+ ret = {}
+ for test_case in test_cases:
+ key = get_key(test_case)
+ if key not in ret:
+ ret[key] = init_value(test_case)
+
+ ret[key]["tests"] += 1
+
+ if "failure" in test_case:
+ ret[key]["failures"] += 1
+ elif "error" in test_case:
+ ret[key]["errors"] += 1
+ elif "skipped" in test_case:
+ ret[key]["skipped"] += 1
+ else:
+ ret[key]["successes"] += 1
+
+ ret[key]["time"] += test_case["time"]
+ return list(ret.values())
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Upload test stats to Rockset")
+ parser.add_argument(
+ "--workflow-run-id",
+ required=True,
+ help="id of the workflow to get artifacts from",
+ )
+ parser.add_argument(
+ "--workflow-run-attempt",
+ type=int,
+ required=True,
+ help="which retry of the workflow this is",
+ )
+ parser.add_argument(
+ "--head-branch",
+ required=True,
+ help="Head branch of the workflow",
+ )
+ parser.add_argument(
+ "--circleci",
+ action="store_true",
+ help="If this is being run through circleci",
+ )
+ args = parser.parse_args()
+
+ print(f"Workflow id is: {args.workflow_run_id}")
+
+ if args.circleci:
+ test_cases, pytest_parallel_times = get_tests_for_circleci(
+ args.workflow_run_id, args.workflow_run_attempt
+ )
+ else:
+ test_cases, pytest_parallel_times = get_tests(
+ args.workflow_run_id, args.workflow_run_attempt
+ )
+
+ # Flush stdout so that any errors in rockset upload show up last in the logs.
+ sys.stdout.flush()
+
+ # For PRs, only upload a summary of test_runs. This helps lower the
+ # volume of writes we do to Rockset.
+ test_case_summary = summarize_test_cases(test_cases)
+ invoking_file_times = get_invoking_file_times(
+ test_case_summary, pytest_parallel_times
+ )
+
+ upload_to_s3(
+ args.workflow_run_id,
+ args.workflow_run_attempt,
+ "test_run_summary",
+ test_case_summary,
+ )
+
+ upload_to_s3(
+ args.workflow_run_id,
+ args.workflow_run_attempt,
+ "invoking_file_times",
+ invoking_file_times,
+ )
+
+ if args.head_branch == "master":
+ # For master jobs, upload everytihng.
+ upload_to_s3(
+ args.workflow_run_id, args.workflow_run_attempt, "test_run", test_cases
+ )
diff --git a/.automation_scripts/run_pytorch_unit_tests.py b/.automation_scripts/run_pytorch_unit_tests.py
new file mode 100644
index 0000000000000..514afd19624c3
--- /dev/null
+++ b/.automation_scripts/run_pytorch_unit_tests.py
@@ -0,0 +1,518 @@
+#!/usr/bin/env python3
+
+""" The Python PyTorch testing script.
+##
+# Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved.
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in
+# all copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+# THE SOFTWARE.
+"""
+
+import argparse
+import os
+import shutil
+import subprocess
+from subprocess import STDOUT, CalledProcessError
+
+from collections import namedtuple
+from datetime import datetime
+from pathlib import Path
+from parse_xml_results import (
+ parse_xml_report
+)
+from pprint import pprint
+from typing import Any, Dict, List
+
+# unit test status list
+UT_STATUS_LIST = [
+ "PASSED",
+ "MISSED",
+ "SKIPPED",
+ "FAILED",
+ "XFAILED",
+ "ERROR"
+]
+
+DEFAULT_CORE_TESTS = [
+ "test_nn",
+ "test_torch",
+ "test_cuda",
+ "test_ops",
+ "test_unary_ufuncs",
+ "test_autograd",
+ "inductor/test_torchinductor"
+]
+
+DISTRIBUTED_CORE_TESTS = [
+ "distributed/test_c10d_common",
+ "distributed/test_c10d_nccl",
+ "distributed/test_distributed_spawn"
+]
+
+CONSOLIDATED_LOG_FILE_NAME="pytorch_unit_tests.log"
+
+def parse_xml_reports_as_dict(workflow_run_id, workflow_run_attempt, tag, workflow_name, path="."):
+ test_cases = {}
+ items_list = os.listdir(path)
+ for dir in items_list:
+ new_dir = path + '/' + dir + '/'
+ if os.path.isdir(new_dir):
+ for xml_report in Path(new_dir).glob("**/*.xml"):
+ test_cases.update(
+ parse_xml_report(
+ tag,
+ xml_report,
+ workflow_run_id,
+ workflow_run_attempt,
+ workflow_name
+ )
+ )
+ return test_cases
+
+def get_test_status(test_case):
+ # In order of priority: S=skipped, F=failure, E=error, P=pass
+ if "skipped" in test_case and test_case["skipped"]:
+ type_message = test_case["skipped"]
+ if type_message.__contains__('type') and type_message['type'] == "pytest.xfail":
+ return "XFAILED"
+ else:
+ return "SKIPPED"
+ elif "failure" in test_case and test_case["failure"]:
+ return "FAILED"
+ elif "error" in test_case and test_case["error"]:
+ return "ERROR"
+ else:
+ return "PASSED"
+
+def get_test_message(test_case, status=None):
+ if status == "SKIPPED":
+ return test_case["skipped"] if "skipped" in test_case else ""
+ elif status == "FAILED":
+ return test_case["failure"] if "failure" in test_case else ""
+ elif status == "ERROR":
+ return test_case["error"] if "error" in test_case else ""
+ else:
+ if "skipped" in test_case:
+ return test_case["skipped"]
+ elif "failure" in test_case:
+ return test_case["failure"]
+ elif "error" in test_case:
+ return test_case["error"]
+ else:
+ return ""
+
+def get_test_file_running_time(test_suite):
+ if test_suite.__contains__('time'):
+ return test_suite["time"]
+ return 0
+
+def get_test_running_time(test_case):
+ if test_case.__contains__('time'):
+ return test_case["time"]
+ return ""
+
+def summarize_xml_files(path, workflow_name):
+ # statistics
+ TOTAL_TEST_NUM = 0
+ TOTAL_PASSED_NUM = 0
+ TOTAL_SKIPPED_NUM = 0
+ TOTAL_XFAIL_NUM = 0
+ TOTAL_FAILED_NUM = 0
+ TOTAL_ERROR_NUM = 0
+ TOTAL_EXECUTION_TIME = 0
+
+ #parse the xml files
+ test_cases = parse_xml_reports_as_dict(-1, -1, 'testcase', workflow_name, path)
+ test_suites = parse_xml_reports_as_dict(-1, -1, 'testsuite', workflow_name, path)
+ test_file_and_status = namedtuple("test_file_and_status", ["file_name", "status"])
+ # results dict
+ res = {}
+ res_item_list = [ "PASSED", "SKIPPED", "XFAILED", "FAILED", "ERROR" ]
+ test_file_items = set()
+ for (k,v) in list(test_suites.items()):
+ file_name = k[0]
+ if not file_name in test_file_items:
+ test_file_items.add(file_name)
+ # initialization
+ for item in res_item_list:
+ temp_item = test_file_and_status(file_name, item)
+ res[temp_item] = {}
+ temp_item_statistics = test_file_and_status(file_name, "STATISTICS")
+ res[temp_item_statistics] = {'TOTAL': 0, 'PASSED': 0, 'SKIPPED': 0, 'XFAILED': 0, 'FAILED': 0, 'ERROR': 0, 'EXECUTION_TIME': 0}
+ test_running_time = get_test_file_running_time(v)
+ res[temp_item_statistics]["EXECUTION_TIME"] += test_running_time
+ TOTAL_EXECUTION_TIME += test_running_time
+ else:
+ test_tuple_key_statistics = test_file_and_status(file_name, "STATISTICS")
+ test_running_time = get_test_file_running_time(v)
+ res[test_tuple_key_statistics]["EXECUTION_TIME"] += test_running_time
+ TOTAL_EXECUTION_TIME += test_running_time
+
+ for (k,v) in list(test_cases.items()):
+ file_name = k[0]
+ class_name = k[1]
+ test_name = k[2]
+ combined_name = file_name + "::" + class_name + "::" + test_name
+ test_status = get_test_status(v)
+ test_running_time = get_test_running_time(v)
+ test_message = get_test_message(v, test_status)
+ test_info_value = ""
+ test_tuple_key_status = test_file_and_status(file_name, test_status)
+ test_tuple_key_statistics = test_file_and_status(file_name, "STATISTICS")
+ TOTAL_TEST_NUM += 1
+ res[test_tuple_key_statistics]["TOTAL"] += 1
+ if test_status == "PASSED":
+ test_info_value = str(test_running_time)
+ res[test_tuple_key_status][combined_name] = test_info_value
+ res[test_tuple_key_statistics]["PASSED"] += 1
+ TOTAL_PASSED_NUM += 1
+ elif test_status == "SKIPPED":
+ test_info_value = str(test_running_time)
+ res[test_tuple_key_status][combined_name] = test_info_value
+ res[test_tuple_key_statistics]["SKIPPED"] += 1
+ TOTAL_SKIPPED_NUM += 1
+ elif test_status == "XFAILED":
+ test_info_value = str(test_running_time)
+ res[test_tuple_key_status][combined_name] = test_info_value
+ res[test_tuple_key_statistics]["XFAILED"] += 1
+ TOTAL_XFAIL_NUM += 1
+ elif test_status == "FAILED":
+ test_info_value = test_message
+ res[test_tuple_key_status][combined_name] = test_info_value
+ res[test_tuple_key_statistics]["FAILED"] += 1
+ TOTAL_FAILED_NUM += 1
+ elif test_status == "ERROR":
+ test_info_value = test_message
+ res[test_tuple_key_status][combined_name] = test_info_value
+ res[test_tuple_key_statistics]["ERROR"] += 1
+ TOTAL_ERROR_NUM += 1
+
+ # generate statistics_dict
+ statistics_dict = {}
+ statistics_dict["TOTAL"] = TOTAL_TEST_NUM
+ statistics_dict["PASSED"] = TOTAL_PASSED_NUM
+ statistics_dict["SKIPPED"] = TOTAL_SKIPPED_NUM
+ statistics_dict["XFAILED"] = TOTAL_XFAIL_NUM
+ statistics_dict["FAILED"] = TOTAL_FAILED_NUM
+ statistics_dict["ERROR"] = TOTAL_ERROR_NUM
+ statistics_dict["EXECUTION_TIME"] = TOTAL_EXECUTION_TIME
+ aggregate_item = workflow_name + "_aggregate"
+ total_item = test_file_and_status(aggregate_item, "STATISTICS")
+ res[total_item] = statistics_dict
+
+ return res
+
+def run_command_and_capture_output(cmd):
+ try:
+ print(f"Running command '{cmd}'")
+ with open(CONSOLIDATED_LOG_FILE_PATH, "a+") as output_file:
+ print(f"========================================", file=output_file, flush=True)
+ print(f"[RUN_PYTORCH_UNIT_TESTS] Running command '{cmd}'", file=output_file, flush=True) # send to consolidated file as well
+ print(f"========================================", file=output_file, flush=True)
+ p = subprocess.run(cmd, shell=True, stdout=output_file, stderr=STDOUT, text=True)
+ except CalledProcessError as e:
+ print(f"ERROR: Cmd {cmd} failed with return code: {e.returncode}!")
+
+def run_entire_tests(workflow_name, test_shell_path, overall_logs_path_current_run, test_reports_src):
+ if os.path.exists(test_reports_src):
+ shutil.rmtree(test_reports_src)
+
+ os.mkdir(test_reports_src)
+ copied_logs_path = ""
+ if workflow_name == "default":
+ os.environ['TEST_CONFIG'] = 'default'
+ copied_logs_path = overall_logs_path_current_run + "default_xml_results_entire_tests/"
+ elif workflow_name == "distributed":
+ os.environ['TEST_CONFIG'] = 'distributed'
+ copied_logs_path = overall_logs_path_current_run + "distributed_xml_results_entire_tests/"
+ elif workflow_name == "inductor":
+ os.environ['TEST_CONFIG'] = 'inductor'
+ copied_logs_path = overall_logs_path_current_run + "inductor_xml_results_entire_tests/"
+ # use test.sh for tests execution
+ run_command_and_capture_output(test_shell_path)
+ copied_logs_path_destination = shutil.copytree(test_reports_src, copied_logs_path)
+ entire_results_dict = summarize_xml_files(copied_logs_path_destination, workflow_name)
+ return entire_results_dict
+
+def run_priority_tests(workflow_name, test_run_test_path, overall_logs_path_current_run, test_reports_src):
+ if os.path.exists(test_reports_src):
+ shutil.rmtree(test_reports_src)
+
+ os.mkdir(test_reports_src)
+ copied_logs_path = ""
+ if workflow_name == "default":
+ os.environ['TEST_CONFIG'] = 'default'
+ os.environ['HIP_VISIBLE_DEVICES'] = '0'
+ copied_logs_path = overall_logs_path_current_run + "default_xml_results_priority_tests/"
+ # use run_test.py for tests execution
+ default_priority_test_suites = " ".join(DEFAULT_CORE_TESTS)
+ command = "python3 " + test_run_test_path + " --include " + default_priority_test_suites + " --exclude-jit-executor --exclude-distributed-tests --verbose"
+ run_command_and_capture_output(command)
+ del os.environ['HIP_VISIBLE_DEVICES']
+ elif workflow_name == "distributed":
+ os.environ['TEST_CONFIG'] = 'distributed'
+ os.environ['HIP_VISIBLE_DEVICES'] = '0,1'
+ copied_logs_path = overall_logs_path_current_run + "distributed_xml_results_priority_tests/"
+ # use run_test.py for tests execution
+ distributed_priority_test_suites = " ".join(DISTRIBUTED_CORE_TESTS)
+ command = "python3 " + test_run_test_path + " --include " + distributed_priority_test_suites + " --distributed-tests --verbose"
+ run_command_and_capture_output(command)
+ del os.environ['HIP_VISIBLE_DEVICES']
+ copied_logs_path_destination = shutil.copytree(test_reports_src, copied_logs_path)
+ priority_results_dict = summarize_xml_files(copied_logs_path_destination, workflow_name)
+
+ return priority_results_dict
+
+def run_selected_tests(workflow_name, test_run_test_path, overall_logs_path_current_run, test_reports_src, selected_list):
+ if os.path.exists(test_reports_src):
+ shutil.rmtree(test_reports_src)
+
+ os.mkdir(test_reports_src)
+ copied_logs_path = ""
+ if workflow_name == "default":
+ os.environ['TEST_CONFIG'] = 'default'
+ os.environ['HIP_VISIBLE_DEVICES'] = '0'
+ copied_logs_path = overall_logs_path_current_run + "default_xml_results_selected_tests/"
+ # use run_test.py for tests execution
+ default_selected_test_suites = " ".join(selected_list)
+ command = "python3 " + test_run_test_path + " --include " + default_selected_test_suites + " --exclude-jit-executor --exclude-distributed-tests --verbose"
+ run_command_and_capture_output(command)
+ del os.environ['HIP_VISIBLE_DEVICES']
+ elif workflow_name == "distributed":
+ os.environ['TEST_CONFIG'] = 'distributed'
+ os.environ['HIP_VISIBLE_DEVICES'] = '0,1'
+ copied_logs_path = overall_logs_path_current_run + "distributed_xml_results_selected_tests/"
+ # use run_test.py for tests execution
+ distributed_selected_test_suites = " ".join(selected_list)
+ command = "python3 " + test_run_test_path + " --include " + distributed_selected_test_suites + " --distributed-tests --verbose"
+ run_command_and_capture_output(command)
+ del os.environ['HIP_VISIBLE_DEVICES']
+ elif workflow_name == "inductor":
+ os.environ['TEST_CONFIG'] = 'inductor'
+ copied_logs_path = overall_logs_path_current_run + "inductor_xml_results_selected_tests/"
+ inductor_selected_test_suites = ""
+ non_inductor_selected_test_suites = ""
+ for item in selected_list:
+ if "inductor/" in item:
+ inductor_selected_test_suites += item
+ inductor_selected_test_suites += " "
+ else:
+ non_inductor_selected_test_suites += item
+ non_inductor_selected_test_suites += " "
+ if inductor_selected_test_suites != "":
+ inductor_selected_test_suites = inductor_selected_test_suites[:-1]
+ command = "python3 " + test_run_test_path + " --include " + inductor_selected_test_suites + " --verbose"
+ run_command_and_capture_output(command)
+ if non_inductor_selected_test_suites != "":
+ non_inductor_selected_test_suites = non_inductor_selected_test_suites[:-1]
+ command = "python3 " + test_run_test_path + " --inductor --include " + non_inductor_selected_test_suites + " --verbose"
+ run_command_and_capture_output(command)
+ copied_logs_path_destination = shutil.copytree(test_reports_src, copied_logs_path)
+ selected_results_dict = summarize_xml_files(copied_logs_path_destination, workflow_name)
+
+ return selected_results_dict
+
+def run_test_and_summarize_results(
+ pytorch_root_dir: str,
+ priority_tests: bool,
+ test_config: List[str],
+ default_list: List[str],
+ distributed_list: List[str],
+ inductor_list: List[str],
+ skip_rerun: bool) -> Dict[str, Any]:
+
+ # copy current environment variables
+ _environ = dict(os.environ)
+
+ # modify path
+ test_shell_path = pytorch_root_dir + "/.ci/pytorch/test.sh"
+ test_run_test_path = pytorch_root_dir + "/test/run_test.py"
+ repo_test_log_folder_path = pytorch_root_dir + "/.automation_logs/"
+ test_reports_src = pytorch_root_dir + "/test/test-reports/"
+ run_test_python_file = pytorch_root_dir + "/test/run_test.py"
+
+ # change directory to pytorch root
+ os.chdir(pytorch_root_dir)
+
+ # all test results dict
+ res_all_tests_dict = {}
+
+ # patterns
+ search_text = "--reruns=2"
+ replace_text = "--reruns=0"
+
+ # create logs folder
+ if not os.path.exists(repo_test_log_folder_path):
+ os.mkdir(repo_test_log_folder_path)
+
+ # Set common environment variables for all scenarios
+ os.environ['CI'] = '1'
+ os.environ['PYTORCH_TEST_WITH_ROCM'] = '1'
+ os.environ['HSA_FORCE_FINE_GRAIN_PCIE'] = '1'
+ os.environ['PYTORCH_TESTING_DEVICE_ONLY_FOR'] = 'cuda'
+ os.environ['CONTINUE_THROUGH_ERROR'] = 'True'
+ if skip_rerun:
+ # modify run_test.py in-place
+ with open(run_test_python_file, 'r') as file:
+ data = file.read()
+ data = data.replace(search_text, replace_text)
+ with open(run_test_python_file, 'w') as file:
+ file.write(data)
+
+ # Time stamp
+ current_datetime = datetime.now().strftime("%Y%m%d_%H-%M-%S")
+ print("Current date & time : ", current_datetime)
+ # performed as Job ID
+ str_current_datetime = str(current_datetime)
+ overall_logs_path_current_run = repo_test_log_folder_path + str_current_datetime + "/"
+ os.mkdir(overall_logs_path_current_run)
+
+ global CONSOLIDATED_LOG_FILE_PATH
+ CONSOLIDATED_LOG_FILE_PATH = overall_logs_path_current_run + CONSOLIDATED_LOG_FILE_NAME
+
+ # Check multi gpu availability if distributed tests are enabled
+ if ("distributed" in test_config) or len(distributed_list) != 0:
+ check_num_gpus_for_distributed()
+
+ # Install test requirements
+ command = "pip3 install -r requirements.txt && pip3 install -r .ci/docker/requirements-ci.txt"
+ run_command_and_capture_output(command)
+
+ # Run entire tests for each workflow
+ if not priority_tests and not default_list and not distributed_list and not inductor_list:
+ # run entire tests for default, distributed and inductor workflows → use test.sh
+ if not test_config:
+ check_num_gpus_for_distributed()
+ # default test process
+ res_default_all = run_entire_tests("default", test_shell_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["default"] = res_default_all
+ # distributed test process
+ res_distributed_all = run_entire_tests("distributed", test_shell_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["distributed"] = res_distributed_all
+ # inductor test process
+ res_inductor_all = run_entire_tests("inductor", test_shell_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["inductor"] = res_inductor_all
+ else:
+ workflow_list = []
+ for item in test_config:
+ workflow_list.append(item)
+ if "default" in workflow_list:
+ res_default_all = run_entire_tests("default", test_shell_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["default"] = res_default_all
+ if "distributed" in workflow_list:
+ res_distributed_all = run_entire_tests("distributed", test_shell_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["distributed"] = res_distributed_all
+ if "inductor" in workflow_list:
+ res_inductor_all = run_entire_tests("inductor", test_shell_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["inductor"] = res_inductor_all
+ # Run priority test for each workflow
+ elif priority_tests and not default_list and not distributed_list and not inductor_list:
+ if not test_config:
+ check_num_gpus_for_distributed()
+ # default test process
+ res_default_priority = run_priority_tests("default", test_run_test_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["default"] = res_default_priority
+ # distributed test process
+ res_distributed_priority = run_priority_tests("distributed", test_run_test_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["distributed"] = res_distributed_priority
+ # will not run inductor priority tests
+ print("Inductor priority tests cannot run since no core tests defined with inductor workflow.")
+ else:
+ workflow_list = []
+ for item in test_config:
+ workflow_list.append(item)
+ if "default" in workflow_list:
+ res_default_priority = run_priority_tests("default", test_run_test_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["default"] = res_default_priority
+ if "distributed" in workflow_list:
+ res_distributed_priority = run_priority_tests("distributed", test_run_test_path, overall_logs_path_current_run, test_reports_src)
+ res_all_tests_dict["distributed"] = res_distributed_priority
+ if "inductor" in workflow_list:
+ print("Inductor priority tests cannot run since no core tests defined with inductor workflow.")
+ # Run specified tests for each workflow
+ elif (default_list or distributed_list or inductor_list) and not test_config and not priority_tests:
+ if default_list:
+ default_workflow_list = []
+ for item in default_list:
+ default_workflow_list.append(item)
+ res_default_selected = run_selected_tests("default", test_run_test_path, overall_logs_path_current_run, test_reports_src, default_workflow_list)
+ res_all_tests_dict["default"] = res_default_selected
+ if distributed_list:
+ distributed_workflow_list = []
+ for item in distributed_list:
+ distributed_workflow_list.append(item)
+ res_distributed_selected = run_selected_tests("distributed", test_run_test_path, overall_logs_path_current_run, test_reports_src, distributed_workflow_list)
+ res_all_tests_dict["distributed"] = res_distributed_selected
+ if inductor_list:
+ inductor_workflow_list = []
+ for item in inductor_list:
+ inductor_workflow_list.append(item)
+ res_inductor_selected = run_selected_tests("inductor", test_run_test_path, overall_logs_path_current_run, test_reports_src, inductor_workflow_list)
+ res_all_tests_dict["inductor"] = res_inductor_selected
+ else:
+ raise Exception("Invalid test configurations!")
+
+ # restore environment variables
+ os.environ.clear()
+ os.environ.update(_environ)
+
+ # restore files
+ if skip_rerun:
+ # modify run_test.py in-place
+ with open(run_test_python_file, 'r') as file:
+ data = file.read()
+ data = data.replace(replace_text, search_text)
+ with open(run_test_python_file, 'w') as file:
+ file.write(data)
+
+ return res_all_tests_dict
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='Run PyTorch unit tests and generate xml results summary', formatter_class=argparse.RawTextHelpFormatter)
+ parser.add_argument('--test_config', nargs='+', default=[], type=str, help="space-separated list of test workflows to be executed eg. 'default distributed'")
+ parser.add_argument('--priority_tests', action='store_true', help="run priority tests only")
+ parser.add_argument('--default_list', nargs='+', default=[], help="space-separated list of 'default' config test suites/files to be executed eg. 'test_weak test_dlpack'")
+ parser.add_argument('--distributed_list', nargs='+', default=[], help="space-separated list of 'distributed' config test suites/files to be executed eg. 'distributed/test_c10d_common distributed/test_c10d_nccl'")
+ parser.add_argument('--inductor_list', nargs='+', default=[], help="space-separated list of 'inductor' config test suites/files to be executed eg. 'inductor/test_torchinductor test_ops'")
+ parser.add_argument('--pytorch_root', default='.', type=str, help="PyTorch root directory")
+ parser.add_argument('--skip_rerun', action='store_true', help="skip rerun process")
+ parser.add_argument('--example_output', type=str, help="{'workflow_name': {\n"
+ " test_file_and_status(file_name='workflow_aggregate', status='STATISTICS'): {}, \n"
+ " test_file_and_status(file_name='test_file_name_1', status='ERROR'): {}, \n"
+ " test_file_and_status(file_name='test_file_name_1', status='FAILED'): {}, \n"
+ " test_file_and_status(file_name='test_file_name_1', status='PASSED'): {}, \n"
+ " test_file_and_status(file_name='test_file_name_1', status='SKIPPED'): {}, \n"
+ " test_file_and_status(file_name='test_file_name_1', status='STATISTICS'): {} \n"
+ "}}\n")
+ parser.add_argument('--example_usages', type=str, help="RUN ALL TESTS: python3 run_pytorch_unit_tests.py \n"
+ "RUN PRIORITY TESTS: python3 run_pytorch_unit_tests.py --test_config distributed --priority_test \n"
+ "RUN SELECTED TESTS: python3 run_pytorch_unit_tests.py --default_list test_weak test_dlpack --inductor_list inductor/test_torchinductor")
+ return parser.parse_args()
+
+def check_num_gpus_for_distributed():
+ p = subprocess.run("rocminfo | grep -cE 'Name:\s+gfx'", shell=True, capture_output=True, text=True)
+ num_gpus_visible = int(p.stdout)
+ assert num_gpus_visible > 1, "Number of visible GPUs should be >1 to run distributed unit tests"
+
+def main():
+ args = parse_args()
+ all_tests_results = run_test_and_summarize_results(args.pytorch_root, args.priority_tests, args.test_config, args.default_list, args.distributed_list, args.inductor_list, args.skip_rerun)
+ pprint(dict(all_tests_results))
+
+if __name__ == "__main__":
+ main()
diff --git a/.ci/docker/ci_commit_pins/triton.txt b/.ci/docker/ci_commit_pins/triton.txt
index 23407b4d540c4..3d17e9c0de64b 100644
--- a/.ci/docker/ci_commit_pins/triton.txt
+++ b/.ci/docker/ci_commit_pins/triton.txt
@@ -1 +1 @@
-9844da955a9db14ec69c9aac828ee9803085e288
+ba5c1517e6f5906761cf5783036efb587026208d
diff --git a/.ci/docker/common/install_cache.sh b/.ci/docker/common/install_cache.sh
index 040a31fc379d0..9bb80a4e80eca 100644
--- a/.ci/docker/common/install_cache.sh
+++ b/.ci/docker/common/install_cache.sh
@@ -38,7 +38,12 @@ sed -e 's|PATH="\(.*\)"|PATH="/opt/cache/bin:\1"|g' -i /etc/environment
export PATH="/opt/cache/bin:$PATH"
# Setup compiler cache
-install_ubuntu
+if [ -n "$ROCM_VERSION" ]; then
+ curl --retry 3 http://repo.radeon.com/misc/.sccache_amd/sccache -o /opt/cache/bin/sccache
+else
+ install_ubuntu
+fi
+
chmod a+x /opt/cache/bin/sccache
function write_sccache_stub() {
diff --git a/.ci/docker/common/install_triton.sh b/.ci/docker/common/install_triton.sh
index 1b68e3c247839..b2fdebdcc4747 100755
--- a/.ci/docker/common/install_triton.sh
+++ b/.ci/docker/common/install_triton.sh
@@ -21,7 +21,7 @@ elif [ -n "${TRITON_CPU}" ]; then
TRITON_REPO="https://github.com/triton-lang/triton-cpu"
TRITON_TEXT_FILE="triton-cpu"
else
- TRITON_REPO="https://github.com/triton-lang/triton"
+ TRITON_REPO="https://github.com/ROCm/triton"
TRITON_TEXT_FILE="triton"
fi
diff --git a/.ci/docker/requirements-ci.txt b/.ci/docker/requirements-ci.txt
index 14b8ff59fcfbe..cf79a13b4e444 100644
--- a/.ci/docker/requirements-ci.txt
+++ b/.ci/docker/requirements-ci.txt
@@ -120,7 +120,7 @@ ninja==1.11.1.4
numba==0.57.1 ; python_version == "3.10" and platform_machine != "s390x"
numba==0.60.0 ; python_version == "3.12" and platform_machine != "s390x"
#Description: Just-In-Time Compiler for Numerical Functions
-#Pinned versions: 0.55.2, 0.60.0
+#Pinned versions: 0.54.1, 0.49.0, <=0.49.1
#test that import: test_numba_integration.py
#Need release > 0.61.2 for s390x due to https://github.com/numba/numba/pull/10073
@@ -141,6 +141,7 @@ numpy==1.26.2; python_version == "3.11" or python_version == "3.12"
numpy==2.1.2; python_version >= "3.13" and python_version < "3.14"
numpy==2.3.4; python_version >= "3.14"
+
pandas==2.0.3; python_version < "3.12"
pandas==2.2.3; python_version >= "3.12" and python_version < "3.14"
pandas==2.3.3; python_version >= "3.14"
@@ -254,8 +255,7 @@ scikit-image==0.22.0
#Pinned versions: 0.20.3
#test that import:
-scipy==1.10.1 ; python_version <= "3.11"
-scipy==1.14.1 ; python_version > "3.11" and python_version < "3.14"
+scipy==1.14.1 ; python_version > "3.9" and python_version < "3.14"
scipy==1.16.2 ; python_version >= "3.14"
# Pin SciPy because of failing distribution tests (see #60347)
#Description: scientific python
@@ -316,8 +316,7 @@ z3-solver==4.15.1.0 ; platform_machine != "s390x"
#Pinned versions:
#test that import:
-tensorboard==2.13.0 ; python_version < "3.13"
-tensorboard==2.18.0 ; python_version >= "3.13"
+tensorboard==2.18.0
#Description: Also included in .ci/docker/requirements-docs.txt
#Pinned versions:
#test that import: test_tensorboard
diff --git a/.ci/pytorch/common_utils.sh b/.ci/pytorch/common_utils.sh
index b5c0a5e43dea7..88e587ab5ff7d 100644
--- a/.ci/pytorch/common_utils.sh
+++ b/.ci/pytorch/common_utils.sh
@@ -67,13 +67,13 @@ function pip_install_whl() {
# Loop through each path and install individually
for path in "${paths[@]}"; do
echo "Installing $path"
- python3 -mpip install --no-index --no-deps "$path"
+ python3 -mpip install "$path"
done
else
# Loop through each argument and install individually
for path in "${args[@]}"; do
echo "Installing $path"
- python3 -mpip install --no-index --no-deps "$path"
+ python3 -mpip install "$path"
done
fi
}
diff --git a/.github/scripts/build_triton_wheel.py b/.github/scripts/build_triton_wheel.py
index e5ac9c5937dfa..0979c6f3f436e 100644
--- a/.github/scripts/build_triton_wheel.py
+++ b/.github/scripts/build_triton_wheel.py
@@ -3,6 +3,7 @@
from __future__ import annotations
import os
+import re
import shutil
import sys
from pathlib import Path
@@ -51,6 +52,31 @@ def patch_init_py(
with open(path, "w") as f:
f.write(orig)
+def get_rocm_version() -> str:
+ rocm_path = os.environ.get('ROCM_HOME') or os.environ.get('ROCM_PATH') or "/opt/rocm"
+ rocm_version = "0.0.0"
+ rocm_version_h = f"{rocm_path}/include/rocm-core/rocm_version.h"
+ if not os.path.isfile(rocm_version_h):
+ rocm_version_h = f"{rocm_path}/include/rocm_version.h"
+
+ # The file could be missing due to 1) ROCm version < 5.2, or 2) no ROCm install.
+ if os.path.isfile(rocm_version_h):
+ RE_MAJOR = re.compile(r"#define\s+ROCM_VERSION_MAJOR\s+(\d+)")
+ RE_MINOR = re.compile(r"#define\s+ROCM_VERSION_MINOR\s+(\d+)")
+ RE_PATCH = re.compile(r"#define\s+ROCM_VERSION_PATCH\s+(\d+)")
+ major, minor, patch = 0, 0, 0
+ for line in open(rocm_version_h):
+ match = RE_MAJOR.search(line)
+ if match:
+ major = int(match.group(1))
+ match = RE_MINOR.search(line)
+ if match:
+ minor = int(match.group(1))
+ match = RE_PATCH.search(line)
+ if match:
+ patch = int(match.group(1))
+ rocm_version = str(major)+"."+str(minor)+"."+str(patch)
+ return rocm_version
def build_triton(
*,
@@ -66,13 +92,22 @@ def build_triton(
max_jobs = os.cpu_count() or 1
env["MAX_JOBS"] = str(max_jobs)
+ version_suffix = ""
+ if not release:
+ # Nightly binaries include the triton commit hash, i.e. 2.1.0+e6216047b8
+ # while release build should only include the version, i.e. 2.1.0
+ rocm_version = get_rocm_version()
+ version_suffix = f"+rocm{rocm_version}.git{commit_hash[:8]}"
+ version += version_suffix
+
with TemporaryDirectory() as tmpdir:
triton_basedir = Path(tmpdir) / "triton"
triton_pythondir = triton_basedir / "python"
triton_repo = "https://github.com/openai/triton"
if device == "rocm":
- triton_pkg_name = "triton-rocm"
+ triton_pkg_name = "triton"
+ triton_repo = "https://github.com/ROCm/triton"
elif device == "xpu":
triton_pkg_name = "triton-xpu"
triton_repo = "https://github.com/intel/intel-xpu-backend-for-triton"
@@ -90,6 +125,7 @@ def build_triton(
# change built wheel name and version
env["TRITON_WHEEL_NAME"] = triton_pkg_name
+ env["TRITON_WHEEL_VERSION_SUFFIX"] = version_suffix
if with_clang_ldd:
env["TRITON_BUILD_WITH_CLANG_LLD"] = "1"
@@ -128,6 +164,13 @@ def build_triton(
cwd=triton_basedir,
)
+ # For gpt-oss models, triton requires this extra triton_kernels wheel
+ # triton_kernels came after pytorch release/2.8
+ triton_kernels_dir = Path(f"{triton_basedir}/python/triton_kernels")
+ check_call([sys.executable, "-m", "build", "--wheel"], cwd=triton_kernels_dir, env=env)
+ kernels_whl_path = next(iter((triton_kernels_dir / "dist").glob("*.whl")))
+ shutil.copy(kernels_whl_path, Path.cwd())
+
return Path.cwd() / whl_path.name
diff --git a/.github/scripts/install_pytorch_wheels.py b/.github/scripts/install_pytorch_wheels.py
new file mode 100644
index 0000000000000..cf8dc5eccc0c6
--- /dev/null
+++ b/.github/scripts/install_pytorch_wheels.py
@@ -0,0 +1,306 @@
+#!/usr/bin/env python3
+"""
+install_pytorch_wheels.py
+
+Installs PyTorch wheels from a pip index URL.
+
+Usage (from repo root):
+ python .github/scripts/install_pytorch_wheels.py --index-url --amdgpu-family [OPTIONS]
+
+Examples:
+ # Install latest versions
+ python .github/scripts/install_pytorch_wheels.py \
+ --index-url /whl \
+ --amdgpu-family gfx1250
+
+ # Install specific versions (matching ROCm builds)
+ python .github/scripts/install_pytorch_wheels.py \
+ --index-url /whl \
+ --amdgpu-family gfx1250 \
+ --torch-version "2.10.0+devrocm7.12.0.dev0.849eec43b..." \
+ --torchaudio-version "2.11.0a0+devrocm7.12.0.dev0.849eec43b..." \
+ --torchvision-version "0.25.0a0+devrocm7.12.0.dev0.849eec43b..."
+"""
+
+import argparse
+import re
+import subprocess
+import sys
+import urllib.parse
+import urllib.request
+
+
+# Package configuration: (name, always_install)
+PACKAGES = {
+ "torch": True,
+ "torchaudio": True,
+ "torchvision": True,
+ "triton": False,
+ "rocm[devel]": True,
+}
+PYTORCH_PKGS = ["torch", "torchaudio", "torchvision", "triton"]
+
+
+def print_banner(title: str) -> None:
+ """Print a formatted banner."""
+ print("=" * 50)
+ print(title)
+ print("=" * 50)
+
+
+def build_package_spec(name: str, version: str | None) -> str:
+ """Build a pip package spec (e.g., 'torch==2.10.0' or 'torch')."""
+ return f"{name}=={version}" if version else name
+
+def get_latest_package_version_for_rocm(
+ index_url: str, package_name: str, rocm_version: str, required: bool = True,
+ version_prefix: str | None = None,
+) -> str | None:
+ """Return latest package version containing rocm_version by parsing the index HTML.
+
+ If version_prefix is set (e.g. "2.9"), only versions whose base part starts
+ with that prefix are considered.
+ """
+
+ # Build the URL for this package's index page (e.g. .../gfx1250/torch/).
+ rocm_tag = f"rocm{rocm_version}"
+ url = f"{index_url.rstrip('/')}/{package_name}/"
+ # Fetch the package index page; on failure (e.g. 404, timeout) fail if always_install, else return None.
+ try:
+ with urllib.request.urlopen(url, timeout=30) as resp:
+ html = resp.read().decode("utf-8", errors="ignore")
+ except Exception as e:
+ print(f"Error: failed to fetch index for {package_name}: {e}", file=sys.stderr)
+ sys.exit(1)
+ # Parse wheel links: format is package-VERSION-...whl (e.g. torch-0.26.0a0+rocm7.12...-cp312-....whl).
+ # Version can contain dots and + (URL-encoded as %2B), so we capture everything up to .whl.
+ pattern = re.compile(
+ re.escape(package_name) + r"-(.+?)\.whl",
+ re.IGNORECASE,
+ )
+ all_suffixes = [m.group(1).strip() for m in pattern.finditer(html)]
+ # Keep only wheels whose version string contains the requested ROCm tag (e.g. rocm7.12.0a20260224).
+ # Version is the first segment before "-" in the suffix; decode %2B to + for comparison.
+ matching = []
+ for s in all_suffixes:
+ ver = s.split("-")[0]
+ if rocm_tag in ver:
+ matching.append(urllib.parse.unquote(ver))
+ # Filter by version prefix (e.g. "2.9" matches "2.9.0+...", "2.9.1+...").
+ if version_prefix and matching:
+ matching = [v for v in matching if v.split("+")[0].startswith(version_prefix)]
+ # No matching wheels: if required (always_install), fail; otherwise return None (package will be skipped).
+ if not matching:
+ if required:
+ msg = f"Error: no wheel found for {package_name} with ROCm {rocm_version}"
+ if version_prefix:
+ msg += f" and version prefix {version_prefix}"
+ print(msg, file=sys.stderr)
+ sys.exit(1)
+ return None
+ # Pick the latest version by comparing all numeric parts including the ROCm date.
+ def _key(v: str) -> tuple[int, ...]:
+ try:
+ return tuple(int(x) for x in re.split(r"[.\-a+]", v) if x.isdigit())
+ except (ValueError, AttributeError):
+ return (0,)
+ return max(matching, key=_key)
+
+
+def run_pip_install(
+ index_url: str, packages: list[str], break_system_packages: bool = True
+) -> None:
+ """Run pip install with the given packages."""
+ cmd = [sys.executable, "-m", "pip", "install", "--index-url", index_url]
+
+ if break_system_packages:
+ cmd.append("--break-system-packages")
+
+ cmd.extend(packages)
+
+ print(f"Running: {' '.join(cmd)}")
+ result = subprocess.run(cmd, check=False)
+
+ if result.returncode != 0:
+ print(f"Error: pip install failed with return code {result.returncode}")
+ sys.exit(result.returncode)
+
+
+def check_package(name: str) -> tuple[bool, str | None]:
+ """Check if a package is installed and return (installed, version)."""
+ try:
+ module = __import__(name)
+ return True, getattr(module, "__version__", "unknown")
+ except ImportError:
+ return False, None
+
+
+def verify_installation() -> bool:
+ """Verify PyTorch installation and print version info."""
+ print_banner("Verifying Installation")
+
+ # Check torch separately for ROCm info
+ try:
+ import torch as _torch
+
+ version = getattr(_torch, "__version__", "unknown")
+ except ImportError as e:
+ print(f"Error: torch import failed ({e!r}). If wheels are installed, run rocm-sdk init first.")
+ return False
+
+ print(f"torch: {version}")
+
+ hip_version = _torch.version.hip
+ print(f"ROCm/HIP: {hip_version or 'not available'}")
+ print(f"Built with ROCm: {hip_version is not None}")
+
+ # Check other packages
+ for name in ["torchaudio", "torchvision", "triton", "rocm"]:
+ installed, version = check_package(name)
+ status = version if installed else "not installed"
+ print(f"{name}: {status}")
+
+ return True
+
+
+def list_installed_packages() -> None:
+ """List installed torch-related packages."""
+ print("\nInstalled PyTorch packages:")
+ result = subprocess.run(
+ [sys.executable, "-m", "pip", "list"],
+ capture_output=True,
+ text=True,
+ check=False,
+ )
+
+ if result.returncode == 0:
+ keywords = ["torch", "triton", "rocm"]
+ for line in result.stdout.splitlines():
+ if any(kw in line.lower() for kw in keywords):
+ print(f" {line}")
+
+
+def main() -> int:
+ """Main entry point."""
+ parser = argparse.ArgumentParser(
+ description="Install PyTorch wheels from a pip index URL",
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ epilog=__doc__,
+ )
+
+ parser.add_argument(
+ "--index-url", required=True, help="Base URL for PyTorch wheels index"
+ )
+ parser.add_argument(
+ "--amdgpu-family", required=True, help="AMD GPU family (e.g., gfx1250)"
+ )
+ parser.add_argument(
+ "--rocm-version",
+ help="Optional. ROCm version (e.g. 7.12.0a20260126). When set without --torch-version: discovers and installs latest torch/torchaudio/torchvision/triton built for this ROCm. ",
+ )
+ parser.add_argument(
+ "--torch-version", help="Specific torch version (default: latest)"
+ )
+ parser.add_argument(
+ "--torch-version-prefix",
+ help="Torch version prefix for discovery (e.g. '2.9' matches 2.9.x). "
+ "Only used in auto-discovery mode (--rocm-version without --torch-version).",
+ )
+ parser.add_argument(
+ "--torchaudio-version", help="Specific torchaudio version (default: latest)"
+ )
+ parser.add_argument(
+ "--torchvision-version", help="Specific torchvision version (default: latest)"
+ )
+ parser.add_argument(
+ "--triton-version",
+ help="Specific triton version (default: from torch dependency)",
+ )
+ parser.add_argument(
+ "--no-break-system-packages",
+ action="store_true",
+ help="Don't use --break-system-packages",
+ )
+ parser.add_argument(
+ "--skip-verify", action="store_true", help="Skip verification step"
+ )
+
+ args = parser.parse_args()
+
+ # Build the full index URL
+ index_url = f"{args.index_url.rstrip('/')}/{args.amdgpu_family}/"
+
+ rocm = args.rocm_version
+ rocm_only = bool(rocm and not args.torch_version)
+ torch_prefix = args.torch_version_prefix if rocm_only else None
+ break_sys = not args.no_break_system_packages
+
+ if rocm_only:
+ # Two-pass install:
+ # Pass 1: torch (pinned) + rocm[devel] (pinned)
+ # Pass 2: torchaudio, torchvision, triton (unpinned — pip resolves compatibility)
+ torch_version = get_latest_package_version_for_rocm(
+ index_url, "torch", rocm, required=True, version_prefix=torch_prefix,
+ )
+
+ print_banner("PyTorch Wheels Installation")
+ print(f"Index URL: {index_url}")
+ print(f"AMDGPU Family: {args.amdgpu_family}")
+ print(f"Python: {sys.version_info.major}.{sys.version_info.minor}")
+ print(f"torch: {torch_version}")
+ print(f"rocm[devel]: {rocm}")
+ print(f"torchaudio: (pip resolves)")
+ print(f"torchvision: (pip resolves)")
+ print(f"triton: (torch dependency)")
+ print("=" * 50)
+
+ # Pass 1: install torch + rocm[devel] with exact versions.
+ # torch's declared dependency on triton pulls in the correct build.
+ primary = [
+ build_package_spec("torch", torch_version),
+ build_package_spec("rocm[devel]", rocm),
+ ]
+ print_banner("Pass 1: torch + rocm[devel]")
+ print(f"Installing: {', '.join(primary)}")
+ run_pip_install(index_url, primary, break_sys)
+
+ # Pass 2: install torchaudio/torchvision without pinning — pip picks
+ # versions compatible with the torch that's already installed
+ companions = ["torchaudio", "torchvision"]
+ print_banner("Pass 2: torchaudio, torchvision (unpinned)")
+ print(f"Installing: {', '.join(companions)}")
+ run_pip_install(index_url, companions, break_sys)
+ else:
+ # Explicit versions mode — install everything in one shot
+ arg_attrs = ["torch_version", "torchaudio_version", "torchvision_version", "triton_version"]
+ versions = {p: getattr(args, a) for p, a in zip(PYTORCH_PKGS, arg_attrs)}
+ versions["rocm[devel]"] = rocm if rocm else None
+
+ print_banner("PyTorch Wheels Installation")
+ print(f"Index URL: {index_url}")
+ print(f"AMDGPU Family: {args.amdgpu_family}")
+ print(f"Python: {sys.version_info.major}.{sys.version_info.minor}")
+ for name, version in versions.items():
+ print(f"{name:14}: {version or 'latest'}")
+ print("=" * 50)
+
+ packages = []
+ for name, always_install in PACKAGES.items():
+ version = versions.get(name)
+ if always_install or version:
+ packages.append(build_package_spec(name, version))
+
+ print(f"Installing: {', '.join(packages)}")
+ run_pip_install(index_url, packages, break_sys)
+
+ # Verify
+ if not args.skip_verify and not verify_installation():
+ return 1
+
+ list_installed_packages()
+ print_banner("Installation complete")
+ return 0
+
+
+if __name__ == "__main__":
+ sys.exit(main())
diff --git a/.github/scripts/install_rocm_deps.sh b/.github/scripts/install_rocm_deps.sh
new file mode 100644
index 0000000000000..e4c0fd91a1066
--- /dev/null
+++ b/.github/scripts/install_rocm_deps.sh
@@ -0,0 +1,114 @@
+#!/bin/bash
+# install_rocm_deps.sh
+#
+# Installs runtime dependencies for ROCm on various Linux distributions.
+# Automatically detects the distribution and uses the appropriate package manager.
+#
+# Supported distributions:
+# - Ubuntu 22.04, 24.04 (apt)
+# - AlmaLinux 8 (dnf)
+# - Azure Linux 3 (tdnf)
+
+set -e
+
+# Detect distribution type from /etc/os-release
+detect_distro() {
+ if [ -f /etc/os-release ]; then
+ . /etc/os-release
+ echo "$ID"
+ else
+ echo "unknown"
+ fi
+}
+
+DISTRO=$(detect_distro)
+echo "Detected distribution: $DISTRO"
+
+case "$DISTRO" in
+ ubuntu)
+ echo "Installing dependencies using apt..."
+ apt-get update
+ apt-get install -y --no-install-recommends \
+ ca-certificates \
+ curl \
+ build-essential \
+ libelf1 \
+ libnuma1 \
+ libunwind8 \
+ libncurses6 \
+ perl \
+ file \
+ nano \
+ git \
+ python3 \
+ python3-dev \
+ python3-pip \
+ python3-venv \
+ kmod \
+ pkg-config \
+ liblzma-dev \
+ libdrm-dev
+ # libdw: libdw1t64 for Ubuntu 24.04+, libdw1 for older versions
+ apt-get install -y --no-install-recommends libdw1t64 2>/dev/null || \
+ apt-get install -y --no-install-recommends libdw1 || true
+ # libssl: libssl3 for Ubuntu 22.04+, libssl1.1 for older versions
+ apt-get install -y --no-install-recommends libssl3 2>/dev/null || \
+ apt-get install -y --no-install-recommends libssl1.1 || true
+ rm -rf /var/lib/apt/lists/*
+ ;;
+
+ almalinux)
+ echo "Installing dependencies using dnf..."
+ # Fix AlmaLinux repo to use direct baseurl instead of mirrorlist
+ if [ -f /etc/yum.repos.d/almalinux.repo ]; then
+ sed -i 's/^mirrorlist=/#mirrorlist=/g' /etc/yum.repos.d/almalinux.repo
+ sed -i 's/^# baseurl=/baseurl=/g' /etc/yum.repos.d/almalinux.repo
+ fi
+ dnf install -y --setopt=install_weak_deps=False \
+ ca-certificates \
+ curl \
+ libatomic \
+ elfutils-libelf \
+ elfutils-libs \
+ numactl-libs \
+ ncurses-libs \
+ openssl-libs \
+ perl \
+ file \
+ python3 \
+ python3-devel \
+ python3-pip \
+ kmod
+ dnf clean all
+ ;;
+
+ azurelinux)
+ echo "Installing dependencies using tdnf..."
+ tdnf install -y \
+ ca-certificates \
+ curl \
+ tar \
+ libatomic \
+ elfutils-libelf \
+ elfutils-libs \
+ numactl-libs \
+ libunwind \
+ ncurses-libs \
+ openssl-libs \
+ perl \
+ file \
+ python3 \
+ python3-devel \
+ python3-pip \
+ kmod
+ tdnf clean all
+ ;;
+
+ *)
+ echo "Error: Unsupported distribution: $DISTRO"
+ echo "Supported distributions: ubuntu, almalinux, azurelinux"
+ exit 1
+ ;;
+esac
+
+echo "Dependencies installed successfully for $DISTRO"
diff --git a/.github/workflows/build_portable_linux_pytorch_dockers.yml b/.github/workflows/build_portable_linux_pytorch_dockers.yml
new file mode 100644
index 0000000000000..d5c9a94c3b1ad
--- /dev/null
+++ b/.github/workflows/build_portable_linux_pytorch_dockers.yml
@@ -0,0 +1,427 @@
+name: Build Portable Linux PyTorch Dockers
+
+on:
+ schedule:
+ - cron: "0 6 * * *" # daily at 06:00 UTC
+ workflow_dispatch:
+ inputs:
+ pytorch_repo:
+ description: "GitHub repo to clone into the image (e.g. 'pytorch/pytorch' or 'ROCm/pytorch')"
+ type: string
+ default: "pytorch/pytorch"
+ pytorch_branch:
+ description: "Branch to clone. Default 'nightly' matches theRock wheel builds. For releases use ROCm/pytorch with 'release/2.11', 'release/2.10', etc."
+ type: string
+ default: "nightly"
+ python_version:
+ type: choice
+ options:
+ - "3.12"
+ - "3.10"
+ - "3.11"
+ - "3.13"
+ - "3.14"
+ default: "3.12"
+ amdgpu_family:
+ type: choice
+ options:
+ - gfx950-dcgpu
+ - gfx94X-dcgpu
+ - gfx90X-dcgpu
+ - gfx120X-all
+ - gfx110X-all
+ - gfx110X-dgpu
+ - gfx103X-dgpu
+ - gfx101X-dgpu
+ default: gfx94X-dcgpu
+ rocm_version:
+ description: "ROCm version (e.g. '7.13.0a20260413'). Leave empty to auto-discover from the latest available torch wheel."
+ type: string
+ index_url:
+ description: Base URL for PyTorch wheels index
+ type: string
+ default: "https://rocm.nightlies.amd.com/v2-staging"
+
+permissions:
+ contents: read
+
+run-name: >-
+ ${{ github.event_name == 'schedule' && 'Nightly Docker builds' ||
+ format('Build PyTorch Docker ({0}, {1}/{2}, ROCm {3})',
+ inputs.amdgpu_family || 'gfx950-dcgpu',
+ inputs.pytorch_repo || 'pytorch/pytorch',
+ inputs.pytorch_branch || 'nightly',
+ inputs.rocm_version || 'auto') }}
+
+env:
+ REGISTRY: docker.io
+ IMAGE_NAME: rocm/pytorch-private
+ DEFAULT_AMDGPU_FAMILY: gfx950-dcgpu
+ DEFAULT_PYTHON_VERSION: "3.12"
+ DEFAULT_INDEX_URL: "https://rocm.nightlies.amd.com/v2-staging"
+ DEFAULT_BASE_IMAGE: "ubuntu:24.04"
+
+jobs:
+ # ── Nightly matrix build (schedule only) ─────────────────────────────────
+ nightly-matrix:
+ if: github.event_name == 'schedule'
+ strategy:
+ fail-fast: false
+ matrix:
+ include:
+ - pytorch_repo: pytorch/pytorch
+ pytorch_branch: nightly
+ label: nightly
+ - pytorch_repo: ROCm/pytorch
+ pytorch_branch: release/2.11
+ label: "2.11"
+ - pytorch_repo: ROCm/pytorch
+ pytorch_branch: release/2.10
+ label: "2.10"
+ - pytorch_repo: ROCm/pytorch
+ pytorch_branch: release/2.9
+ label: "2.9"
+ name: "Nightly | torch ${{ matrix.label }} | MI355"
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout workflow files
+ uses: actions/checkout@v4
+
+ - name: Checkout PyTorch source
+ uses: actions/checkout@v4
+ with:
+ repository: ${{ matrix.pytorch_repo }}
+ ref: ${{ matrix.pytorch_branch }}
+ path: pytorch-src
+ fetch-depth: 1
+
+ - name: Derive torch version prefix from branch
+ id: prefix
+ run: |
+ BRANCH="${{ matrix.pytorch_branch }}"
+ if [[ "$BRANCH" =~ ^release/([0-9]+\.[0-9]+) ]]; then
+ echo "value=${BASH_REMATCH[1]}" >> $GITHUB_OUTPUT
+ echo "Derived torch prefix: ${BASH_REMATCH[1]}"
+ else
+ echo "value=" >> $GITHUB_OUTPUT
+ echo "No prefix (nightly/main branch)"
+ fi
+
+ - name: Discover ROCm version from index
+ id: discover
+ run: |
+ python3 - "${{ env.DEFAULT_INDEX_URL }}" "${{ env.DEFAULT_AMDGPU_FAMILY }}" "${{ steps.prefix.outputs.value }}" <<'PYEOF'
+ import re, sys, urllib.request, urllib.parse
+
+ index_url, gpu_family = sys.argv[1], sys.argv[2]
+ prefix = sys.argv[3] if len(sys.argv) > 3 else ""
+
+ url = f"{index_url.rstrip('/')}/{gpu_family}/torch/"
+ print(f"Fetching torch index: {url}")
+ html = urllib.request.urlopen(url, timeout=60).read().decode()
+
+ pattern = re.compile(r"torch-(.+?)\.whl", re.IGNORECASE)
+ versions = []
+ for m in pattern.finditer(html):
+ ver = urllib.parse.unquote(m.group(1).split("-")[0])
+ if "+rocm" in ver:
+ versions.append(ver)
+
+ if prefix:
+ versions = [v for v in versions if v.split("+")[0].startswith(prefix)]
+
+ if not versions:
+ print(f"::error::No torch wheels found (prefix={prefix!r})")
+ sys.exit(1)
+
+ def key(v):
+ try:
+ return tuple(int(x) for x in re.split(r"[.\-a+]", v) if x.isdigit())
+ except (ValueError, AttributeError):
+ return (0,)
+
+ latest = max(versions, key=key)
+ rocm_ver = re.search(r"\+rocm(.+)", latest).group(1)
+
+ print(f"Latest torch wheel: {latest}")
+ print(f"Discovered ROCm version: {rocm_ver}")
+
+ import os
+ with open(os.environ["GITHUB_OUTPUT"], "a") as f:
+ f.write(f"rocm_version={rocm_ver}\n")
+ f.write(f"torch_wheel_version={latest}\n")
+ PYEOF
+
+ - name: Resolve config
+ id: cfg
+ run: |
+ echo "amdgpu_family=${{ env.DEFAULT_AMDGPU_FAMILY }}" >> $GITHUB_OUTPUT
+ echo "python_version=${{ env.DEFAULT_PYTHON_VERSION }}" >> $GITHUB_OUTPUT
+ echo "rocm_version=${{ steps.discover.outputs.rocm_version }}" >> $GITHUB_OUTPUT
+ echo "index_url=${{ env.DEFAULT_INDEX_URL }}" >> $GITHUB_OUTPUT
+ echo "base_image=${{ env.DEFAULT_BASE_IMAGE }}" >> $GITHUB_OUTPUT
+ echo "torch_prefix=${{ steps.prefix.outputs.value }}" >> $GITHUB_OUTPUT
+ echo "pytorch_repo=${{ matrix.pytorch_repo }}" >> $GITHUB_OUTPUT
+ echo "pytorch_branch=${{ matrix.pytorch_branch }}" >> $GITHUB_OUTPUT
+
+ COMMIT="$(cd pytorch-src && git rev-parse --short=8 HEAD)"
+ echo "pytorch_commit=${COMMIT}" >> $GITHUB_OUTPUT
+
+ - name: Generate Docker image tag
+ id: docker-tag
+ run: |
+ BRANCH="${{ matrix.pytorch_branch }}"
+ BRANCH_SAFE="${BRANCH//\//-}"
+ COMMIT="${{ steps.cfg.outputs.pytorch_commit }}"
+ ROCM_VERSION="${{ steps.cfg.outputs.rocm_version }}"
+ PYTHON_VERSION="${{ steps.cfg.outputs.python_version }}"
+ GFX="${{ steps.cfg.outputs.amdgpu_family }}"
+ BASE_IMAGE="${{ steps.cfg.outputs.base_image }}"
+ OS=$(echo "${BASE_IMAGE}" | tr -d ':' | tr '/' '-')
+
+ IMAGE_TAG="pytorch-${BRANCH_SAFE}-${COMMIT}-rocm${ROCM_VERSION}-${OS}-py${PYTHON_VERSION}-${GFX}"
+ IMAGE_TAG="${IMAGE_TAG//+/-}"
+ echo "tag=${IMAGE_TAG}" >> $GITHUB_OUTPUT
+ echo "Generated image tag: ${IMAGE_TAG}"
+
+ - name: Log in to Docker Hub
+ uses: docker/login-action@v3
+ with:
+ username: ${{ secrets.DOCKERUSERNAME }}
+ password: ${{ secrets.DOCKERTOKEN }}
+
+ - name: Prepare build context
+ run: |
+ cp dockerfiles/Dockerfile pytorch-src/
+ mkdir -p pytorch-src/.github/scripts
+ cp .github/scripts/install_rocm_deps.sh pytorch-src/.github/scripts/
+ cp .github/scripts/install_pytorch_wheels.py pytorch-src/.github/scripts/
+
+ - name: Build Docker image
+ run: |
+ IMAGE="${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.docker-tag.outputs.tag }}"
+
+ docker build \
+ --file pytorch-src/Dockerfile \
+ --tag "${IMAGE}" \
+ --label "pytorch.repo=${{ matrix.pytorch_repo }}" \
+ --label "pytorch.branch=${{ matrix.pytorch_branch }}" \
+ --label "pytorch.commit=${{ steps.cfg.outputs.pytorch_commit }}" \
+ --build-arg "BASE_IMAGE=${{ steps.cfg.outputs.base_image }}" \
+ --build-arg "ROCM_VERSION=${{ steps.cfg.outputs.rocm_version }}" \
+ --build-arg "AMDGPU_FAMILY=${{ steps.cfg.outputs.amdgpu_family }}" \
+ --build-arg "PYTHON_VERSION=${{ steps.cfg.outputs.python_version }}" \
+ --build-arg "INDEX_URL=${{ steps.cfg.outputs.index_url }}" \
+ --build-arg "TORCH_VERSION_PREFIX=${{ steps.prefix.outputs.value }}" \
+ pytorch-src
+
+ echo "Docker image built successfully: ${IMAGE}"
+
+ - name: Get ROCm packages info
+ id: rocm-packages
+ run: |
+ IMAGE="${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.docker-tag.outputs.tag }}"
+ ROCM_PACKAGES=$(docker run --rm "${IMAGE}" pip freeze | grep -i rocm || echo "No ROCm packages found")
+ echo "rocm_packages<> $GITHUB_OUTPUT
+ echo "${ROCM_PACKAGES}" >> $GITHUB_OUTPUT
+ echo "EOF" >> $GITHUB_OUTPUT
+ echo "ROCm packages:"
+ echo "${ROCM_PACKAGES}"
+
+ - name: Push Docker image
+ run: |
+ docker push ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.docker-tag.outputs.tag }}
+ echo "Docker image pushed successfully"
+
+ - name: Post-build summary
+ run: |
+ IMAGE="${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.docker-tag.outputs.tag }}"
+ echo "## PyTorch Docker Build Summary — ${{ matrix.label }}" >> $GITHUB_STEP_SUMMARY
+ echo "" >> $GITHUB_STEP_SUMMARY
+ echo "| Parameter | Value |" >> $GITHUB_STEP_SUMMARY
+ echo "|-----------|-------|" >> $GITHUB_STEP_SUMMARY
+ echo "| Image | \`${IMAGE}\` |" >> $GITHUB_STEP_SUMMARY
+ echo "| Torch Wheel | ${{ steps.discover.outputs.torch_wheel_version }} |" >> $GITHUB_STEP_SUMMARY
+ echo "| PyTorch Repo | ${{ matrix.pytorch_repo }} |" >> $GITHUB_STEP_SUMMARY
+ echo "| PyTorch Branch | ${{ matrix.pytorch_branch }} |" >> $GITHUB_STEP_SUMMARY
+ echo "| PyTorch Commit | ${{ steps.cfg.outputs.pytorch_commit }} |" >> $GITHUB_STEP_SUMMARY
+ echo "| AMDGPU Family | ${{ steps.cfg.outputs.amdgpu_family }} |" >> $GITHUB_STEP_SUMMARY
+ echo "| Python | ${{ steps.cfg.outputs.python_version }} |" >> $GITHUB_STEP_SUMMARY
+ echo "| ROCm (discovered) | ${{ steps.cfg.outputs.rocm_version }} |" >> $GITHUB_STEP_SUMMARY
+
+ # ── Single image build (manual dispatch) ──────────────────────────────────
+ build-docker:
+ if: github.event_name == 'workflow_dispatch'
+ name: "Build | ${{ inputs.amdgpu_family }} | ${{ inputs.pytorch_repo || 'pytorch/pytorch' }}@${{ inputs.pytorch_branch || 'nightly' }}"
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout workflow files
+ uses: actions/checkout@v4
+
+ - name: Checkout PyTorch source
+ uses: actions/checkout@v4
+ with:
+ repository: ${{ inputs.pytorch_repo || 'pytorch/pytorch' }}
+ ref: ${{ inputs.pytorch_branch || 'nightly' }}
+ path: pytorch-src
+ fetch-depth: 1
+
+ - name: Derive torch version prefix from branch
+ id: prefix
+ run: |
+ BRANCH="${{ inputs.pytorch_branch || 'nightly' }}"
+ if [[ "$BRANCH" =~ ^release/([0-9]+\.[0-9]+) ]]; then
+ echo "value=${BASH_REMATCH[1]}" >> $GITHUB_OUTPUT
+ echo "Derived torch prefix: ${BASH_REMATCH[1]}"
+ else
+ echo "value=" >> $GITHUB_OUTPUT
+ echo "No prefix (nightly/main branch)"
+ fi
+
+ - name: Discover ROCm version from index
+ id: discover
+ if: ${{ !inputs.rocm_version }}
+ run: |
+ python3 - "${{ inputs.index_url || env.DEFAULT_INDEX_URL }}" "${{ inputs.amdgpu_family || env.DEFAULT_AMDGPU_FAMILY }}" "${{ steps.prefix.outputs.value }}" <<'PYEOF'
+ import re, sys, urllib.request, urllib.parse
+
+ index_url, gpu_family = sys.argv[1], sys.argv[2]
+ prefix = sys.argv[3] if len(sys.argv) > 3 else ""
+
+ url = f"{index_url.rstrip('/')}/{gpu_family}/torch/"
+ print(f"Fetching torch index: {url}")
+ html = urllib.request.urlopen(url, timeout=60).read().decode()
+
+ pattern = re.compile(r"torch-(.+?)\.whl", re.IGNORECASE)
+ versions = []
+ for m in pattern.finditer(html):
+ ver = urllib.parse.unquote(m.group(1).split("-")[0])
+ if "+rocm" in ver:
+ versions.append(ver)
+
+ if prefix:
+ versions = [v for v in versions if v.split("+")[0].startswith(prefix)]
+
+ if not versions:
+ print(f"::error::No torch wheels found (prefix={prefix!r})")
+ sys.exit(1)
+
+ def key(v):
+ try:
+ return tuple(int(x) for x in re.split(r"[.\-a+]", v) if x.isdigit())
+ except (ValueError, AttributeError):
+ return (0,)
+
+ latest = max(versions, key=key)
+ rocm_ver = re.search(r"\+rocm(.+)", latest).group(1)
+
+ print(f"Latest torch wheel: {latest}")
+ print(f"Discovered ROCm version: {rocm_ver}")
+
+ import os
+ with open(os.environ["GITHUB_OUTPUT"], "a") as f:
+ f.write(f"rocm_version={rocm_ver}\n")
+ f.write(f"torch_wheel_version={latest}\n")
+ PYEOF
+
+ - name: Resolve inputs with defaults
+ id: cfg
+ run: |
+ echo "amdgpu_family=${{ inputs.amdgpu_family || env.DEFAULT_AMDGPU_FAMILY }}" >> $GITHUB_OUTPUT
+ echo "python_version=${{ inputs.python_version || env.DEFAULT_PYTHON_VERSION }}" >> $GITHUB_OUTPUT
+
+ # Use explicit rocm_version if provided, otherwise use discovered version
+ ROCM="${{ inputs.rocm_version || steps.discover.outputs.rocm_version }}"
+ echo "rocm_version=${ROCM}" >> $GITHUB_OUTPUT
+
+ echo "index_url=${{ inputs.index_url || env.DEFAULT_INDEX_URL }}" >> $GITHUB_OUTPUT
+ echo "base_image=${{ env.DEFAULT_BASE_IMAGE }}" >> $GITHUB_OUTPUT
+ echo "torch_prefix=${{ steps.prefix.outputs.value }}" >> $GITHUB_OUTPUT
+ echo "pytorch_repo=${{ inputs.pytorch_repo || 'pytorch/pytorch' }}" >> $GITHUB_OUTPUT
+ echo "pytorch_branch=${{ inputs.pytorch_branch || 'nightly' }}" >> $GITHUB_OUTPUT
+
+ COMMIT="$(cd pytorch-src && git rev-parse --short=8 HEAD)"
+ echo "pytorch_commit=${COMMIT}" >> $GITHUB_OUTPUT
+
+ - name: Generate Docker image tag
+ id: docker-tag
+ run: |
+ BRANCH="${{ steps.cfg.outputs.pytorch_branch }}"
+ BRANCH_SAFE="${BRANCH//\//-}"
+ COMMIT="${{ steps.cfg.outputs.pytorch_commit }}"
+ ROCM_VERSION="${{ steps.cfg.outputs.rocm_version }}"
+ PYTHON_VERSION="${{ steps.cfg.outputs.python_version }}"
+ GFX="${{ steps.cfg.outputs.amdgpu_family }}"
+ BASE_IMAGE="${{ steps.cfg.outputs.base_image }}"
+ OS=$(echo "${BASE_IMAGE}" | tr -d ':' | tr '/' '-')
+
+ IMAGE_TAG="pytorch-${BRANCH_SAFE}-${COMMIT}-rocm${ROCM_VERSION}-${OS}-py${PYTHON_VERSION}-${GFX}"
+ IMAGE_TAG="${IMAGE_TAG//+/-}"
+ echo "tag=${IMAGE_TAG}" >> $GITHUB_OUTPUT
+ echo "Generated image tag: ${IMAGE_TAG}"
+
+ - name: Log in to Docker Hub
+ uses: docker/login-action@v3
+ with:
+ username: ${{ secrets.DOCKERUSERNAME }}
+ password: ${{ secrets.DOCKERTOKEN }}
+
+ - name: Prepare build context
+ run: |
+ cp dockerfiles/Dockerfile pytorch-src/
+ mkdir -p pytorch-src/.github/scripts
+ cp .github/scripts/install_rocm_deps.sh pytorch-src/.github/scripts/
+ cp .github/scripts/install_pytorch_wheels.py pytorch-src/.github/scripts/
+
+ - name: Build Docker image
+ run: |
+ IMAGE="${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.docker-tag.outputs.tag }}"
+
+ docker build \
+ --file pytorch-src/Dockerfile \
+ --tag "${IMAGE}" \
+ --label "pytorch.repo=${{ steps.cfg.outputs.pytorch_repo }}" \
+ --label "pytorch.branch=${{ steps.cfg.outputs.pytorch_branch }}" \
+ --label "pytorch.commit=${{ steps.cfg.outputs.pytorch_commit }}" \
+ --build-arg "BASE_IMAGE=${{ steps.cfg.outputs.base_image }}" \
+ --build-arg "ROCM_VERSION=${{ steps.cfg.outputs.rocm_version }}" \
+ --build-arg "AMDGPU_FAMILY=${{ steps.cfg.outputs.amdgpu_family }}" \
+ --build-arg "PYTHON_VERSION=${{ steps.cfg.outputs.python_version }}" \
+ --build-arg "INDEX_URL=${{ steps.cfg.outputs.index_url }}" \
+ --build-arg "TORCH_VERSION_PREFIX=${{ steps.cfg.outputs.torch_prefix }}" \
+ pytorch-src
+
+ echo "Docker image built successfully: ${IMAGE}"
+
+ - name: Get ROCm packages info
+ id: rocm-packages
+ run: |
+ IMAGE="${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.docker-tag.outputs.tag }}"
+ ROCM_PACKAGES=$(docker run --rm "${IMAGE}" pip freeze | grep -i rocm || echo "No ROCm packages found")
+ echo "rocm_packages<> $GITHUB_OUTPUT
+ echo "${ROCM_PACKAGES}" >> $GITHUB_OUTPUT
+ echo "EOF" >> $GITHUB_OUTPUT
+ echo "ROCm packages:"
+ echo "${ROCM_PACKAGES}"
+
+ - name: Push Docker image
+ run: |
+ docker push ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.docker-tag.outputs.tag }}
+ echo "Docker image pushed successfully"
+
+ - name: Post-build summary
+ run: |
+ IMAGE="${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}:${{ steps.docker-tag.outputs.tag }}"
+ echo "## PyTorch Docker Build Summary" >> $GITHUB_STEP_SUMMARY
+ echo "" >> $GITHUB_STEP_SUMMARY
+ echo "| Parameter | Value |" >> $GITHUB_STEP_SUMMARY
+ echo "|-----------|-------|" >> $GITHUB_STEP_SUMMARY
+ echo "| Image | \`${IMAGE}\` |" >> $GITHUB_STEP_SUMMARY
+ echo "| PyTorch Repo | ${{ steps.cfg.outputs.pytorch_repo }} |" >> $GITHUB_STEP_SUMMARY
+ echo "| PyTorch Branch | ${{ steps.cfg.outputs.pytorch_branch }} |" >> $GITHUB_STEP_SUMMARY
+ echo "| PyTorch Commit | ${{ steps.cfg.outputs.pytorch_commit }} |" >> $GITHUB_STEP_SUMMARY
+ echo "| AMDGPU Family | ${{ steps.cfg.outputs.amdgpu_family }} |" >> $GITHUB_STEP_SUMMARY
+ echo "| Python | ${{ steps.cfg.outputs.python_version }} |" >> $GITHUB_STEP_SUMMARY
+ echo "| ROCm | ${{ steps.cfg.outputs.rocm_version }} |" >> $GITHUB_STEP_SUMMARY
+ echo "| Torch Version Prefix | ${{ steps.cfg.outputs.torch_prefix || 'latest' }} |" >> $GITHUB_STEP_SUMMARY
+ echo "| Index URL | ${{ steps.cfg.outputs.index_url }} |" >> $GITHUB_STEP_SUMMARY
diff --git a/.github/workflows/create_ifu_issues.yml b/.github/workflows/create_ifu_issues.yml
new file mode 100644
index 0000000000000..8e2e7da07ab43
--- /dev/null
+++ b/.github/workflows/create_ifu_issues.yml
@@ -0,0 +1,352 @@
+name: Create issues for ROCm commits
+
+on:
+ # Manual trigger for testing
+ workflow_dispatch:
+ inputs:
+ prev_post_tag:
+ description: "Issue range start ref (previous IFU post tag or cold-start SHA)"
+ required: true
+ type: string
+ curr_pre_tag:
+ description: "Current IFU pre tag"
+ required: true
+ type: string
+ target_repo:
+ description: "Target repo for issue creation"
+ required: false
+ default: "ROCm/pytorch"
+ type: string
+ project_number:
+ description: "GitHub Project number"
+ required: false
+ default: "114"
+ type: string
+ project_owner:
+ description: "Project owner"
+ required: false
+ default: "ROCm"
+ type: string
+
+ # Called by create_ifu_tag.yml after tagging
+ workflow_call:
+ inputs:
+ prev_post_tag:
+ description: "Issue range start ref (previous IFU post tag or cold-start SHA)"
+ required: true
+ type: string
+ curr_pre_tag:
+ description: "Current IFU pre tag"
+ required: true
+ type: string
+ target_repo:
+ description: "Target repo for issue creation"
+ required: false
+ default: "ROCm/pytorch"
+ type: string
+ project_number:
+ description: "GitHub Project number"
+ required: false
+ default: "114"
+ type: string
+ project_owner:
+ description: "Project owner"
+ required: false
+ default: "ROCm"
+ type: string
+ secrets:
+ IFU_GITHUB_TOKEN:
+ required: true
+
+permissions:
+ contents: read
+ issues: write
+
+jobs:
+ create-issues:
+ runs-on: ubuntu-latest
+ env:
+ # Use passed secret for workflow_call, direct secret for workflow_dispatch
+ GH_TOKEN: ${{ secrets.IFU_GITHUB_TOKEN }}
+
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@v4
+ with:
+ fetch-depth: 0
+
+ - name: Fetch tags
+ run: git fetch origin --tags --force
+
+ - name: Extract branch from tag
+ id: parse
+ env:
+ CURR_PRE_TAG: ${{ inputs.curr_pre_tag }}
+ run: |
+ branch="${CURR_PRE_TAG%_IFU_*}"
+ echo "Branch: $branch"
+ echo "branch=$branch" >> $GITHUB_OUTPUT
+
+ - name: Fetch upstream
+ run: |
+ git remote add upstream https://github.com/pytorch/pytorch.git 2>/dev/null || true
+ git fetch upstream main --force
+
+ - name: List commits in range
+ run: |
+ echo "ROCm-only commits between ${{ inputs.prev_post_tag }} and ${{ inputs.curr_pre_tag }}:"
+ git log ${{ inputs.prev_post_tag }}..${{ inputs.curr_pre_tag }} --oneline --no-merges --not upstream/main
+
+ - name: Get or create project fields
+ id: project_fields
+ if: ${{ inputs.project_number != '' }}
+ env:
+ PROJECT_NUMBER: ${{ inputs.project_number }}
+ PROJECT_OWNER: ${{ inputs.project_owner }}
+ run: |
+ echo "Getting project information..."
+
+ # Try user-owned project first.
+ project_data=$(gh api graphql -f query='
+ query($owner: String!, $number: Int!) {
+ user(login: $owner) {
+ projectV2(number: $number) {
+ id
+ fields(first: 50) {
+ nodes {
+ ... on ProjectV2Field {
+ id
+ name
+ dataType
+ }
+ ... on ProjectV2SingleSelectField {
+ id
+ name
+ dataType
+ }
+ }
+ }
+ }
+ }
+ }' -f owner="${PROJECT_OWNER}" -F number="${PROJECT_NUMBER}" 2>/dev/null || true)
+
+ project_id=$(echo "$project_data" | jq -r '.data.user.projectV2.id // empty' 2>/dev/null || true)
+ echo "User project ID: ${project_id:-'(none)'}"
+
+ if [[ -z "$project_id" ]]; then
+ echo "User project not found (or owner is an org). Trying organization query..."
+ project_data=$(gh api graphql -f query='
+ query($owner: String!, $number: Int!) {
+ organization(login: $owner) {
+ projectV2(number: $number) {
+ id
+ fields(first: 50) {
+ nodes {
+ ... on ProjectV2Field {
+ id
+ name
+ dataType
+ }
+ ... on ProjectV2SingleSelectField {
+ id
+ name
+ dataType
+ }
+ }
+ }
+ }
+ }
+ }' -f owner="${PROJECT_OWNER}" -F number="${PROJECT_NUMBER}" 2>/dev/null || true)
+
+ project_id=$(echo "$project_data" | jq -r '.data.organization.projectV2.id // empty' 2>/dev/null || true)
+ fields_json=$(echo "$project_data" | jq -r '.data.organization.projectV2.fields.nodes // empty' 2>/dev/null || true)
+ else
+ fields_json=$(echo "$project_data" | jq -r '.data.user.projectV2.fields.nodes // empty' 2>/dev/null || true)
+ fi
+
+ if [[ -z "$project_id" || -z "$fields_json" ]]; then
+ echo "Error: Could not resolve project owner '${PROJECT_OWNER}' project #${PROJECT_NUMBER}."
+ echo "If PROJECT_OWNER is an organization, ensure PROJECT_OWNER is exactly the org login and token has org access."
+ exit 1
+ fi
+
+ echo "Project ID: $project_id"
+ echo "project_id=$project_id" >> $GITHUB_OUTPUT
+
+ # Find or create 'branch' field
+ branch_field_id=$(echo "$fields_json" | jq -r '.[] | select(.name == "branch") | .id')
+ if [[ -z "$branch_field_id" || "$branch_field_id" == "null" ]]; then
+ echo "Creating 'branch' field..."
+ branch_field_id=$(gh api graphql -f query='
+ mutation($projectId: ID!, $name: String!) {
+ createProjectV2Field(input: {projectId: $projectId, dataType: TEXT, name: $name}) {
+ projectV2Field {
+ ... on ProjectV2Field {
+ id
+ }
+ }
+ }
+ }' -f projectId="$project_id" -f name="branch" --jq '.data.createProjectV2Field.projectV2Field.id')
+ echo "Created 'branch' field: $branch_field_id"
+ else
+ echo "Found existing 'branch' field: $branch_field_id"
+ fi
+ echo "branch_field_id=$branch_field_id" >> $GITHUB_OUTPUT
+
+ # Find or create 'commit_hash' field
+ commit_hash_field_id=$(echo "$fields_json" | jq -r '.[] | select(.name == "commit_hash") | .id')
+ if [[ -z "$commit_hash_field_id" || "$commit_hash_field_id" == "null" ]]; then
+ echo "Creating 'commit_hash' field..."
+ commit_hash_field_id=$(gh api graphql -f query='
+ mutation($projectId: ID!, $name: String!) {
+ createProjectV2Field(input: {projectId: $projectId, dataType: TEXT, name: $name}) {
+ projectV2Field {
+ ... on ProjectV2Field {
+ id
+ }
+ }
+ }
+ }' -f projectId="$project_id" -f name="commit_hash" --jq '.data.createProjectV2Field.projectV2Field.id')
+ echo "Created 'commit_hash' field: $commit_hash_field_id"
+ else
+ echo "Found existing 'commit_hash' field: $commit_hash_field_id"
+ fi
+ echo "commit_hash_field_id=$commit_hash_field_id" >> $GITHUB_OUTPUT
+
+ - name: Create issues for commits
+ env:
+ PREV_POST_TAG: ${{ inputs.prev_post_tag }}
+ CURR_PRE_TAG: ${{ inputs.curr_pre_tag }}
+ TARGET_REPO: ${{ inputs.target_repo }}
+ PROJECT_NUMBER: ${{ inputs.project_number }}
+ PROJECT_OWNER: ${{ inputs.project_owner }}
+ REPO_NAME: ${{ github.repository }}
+ BRANCH: ${{ steps.parse.outputs.branch }}
+ PROJECT_ID: ${{ steps.project_fields.outputs.project_id }}
+ BRANCH_FIELD_ID: ${{ steps.project_fields.outputs.branch_field_id }}
+ COMMIT_HASH_FIELD_ID: ${{ steps.project_fields.outputs.commit_hash_field_id }}
+ run: |
+ echo "Creating issues for commits..."
+
+ commit_count=$(git rev-list --count --no-merges "${PREV_POST_TAG}..${CURR_PRE_TAG}" --not upstream/main)
+ if [[ "${commit_count}" -eq 0 ]]; then
+ echo "No ROCm-only commits in range ${PREV_POST_TAG}..${CURR_PRE_TAG}; nothing to create."
+ exit 0
+ fi
+
+ echo "Found ${commit_count} ROCm-only commits to process."
+
+ git log "${PREV_POST_TAG}..${CURR_PRE_TAG}" --format="%H" --no-merges --not upstream/main | while read hash; do
+ short_hash="${hash:0:5}"
+ subject=$(git log -1 --format="%s" "$hash")
+ author=$(git log -1 --format="%an" "$hash")
+ email=$(git log -1 --format="%ae" "$hash")
+
+ echo "Processing ${short_hash}: ${subject}"
+
+ # Try to get GitHub username via API first
+ gh_username=""
+ gh_username=$(gh api "repos/${REPO_NAME}/commits/${hash}" --jq '.author.login // empty' 2>/dev/null || true)
+
+ if [[ -z "${gh_username}" ]]; then
+ # Fallback: try to extract from noreply email
+ if [[ "$email" =~ ^[0-9]+\+([^@]+)@users\.noreply\.github\.com$ ]]; then
+ gh_username="${BASH_REMATCH[1]}"
+ echo " Extracted username from email: ${gh_username}"
+ fi
+ else
+ echo " Found GitHub username via API: ${gh_username}"
+ fi
+
+ # Dedupe by commit hash marker in issue body across all issue states.
+ existing_issue_url=$(gh issue list \
+ --repo "${TARGET_REPO}" \
+ --state all \
+ --search "\"${hash}\" in:body" \
+ --limit 20 \
+ --json url,body \
+ | jq -r --arg hash "$hash" '.[] | select((.body // "") | contains("**Commit:** " + $hash)) | .url' \
+ | head -n 1 || true)
+ if [[ -n "${existing_issue_url}" ]]; then
+ echo " Existing issue found for commit ${short_hash}: ${existing_issue_url}"
+ echo " Skipping duplicate issue creation."
+ continue
+ fi
+
+ body="**Commit:** ${hash}"$'\n'"**Author:** ${author} (${email})"$'\n'"**Branch:** ${BRANCH}"$'\n'"**Link:** [View commit](https://github.com/${REPO_NAME}/commit/${hash})"
+
+ issue_url=$(gh issue create \
+ --repo "${TARGET_REPO}" \
+ --title "${subject}" \
+ --body "${body}" 2>/dev/null || true)
+
+ if [[ -z "${issue_url}" ]]; then
+ echo " ERROR: Failed to create issue for ${short_hash}. Skipping."
+ continue
+ fi
+
+ echo " Created: ${issue_url}"
+
+ # Try to assign the issue
+ if [[ -n "${gh_username}" ]]; then
+ echo " Trying to assign to @${gh_username}..."
+ if gh issue edit "${issue_url}" --add-assignee "${gh_username}" 2>/dev/null; then
+ echo " Successfully assigned issue"
+ else
+ echo " Could not assign, adding comment instead"
+ gh issue comment "${issue_url}" --body "cc @${gh_username} - you authored this commit" || true
+ fi
+ fi
+
+ # Add to project and set field values
+ if [[ -n "${PROJECT_NUMBER}" && -n "${PROJECT_ID}" ]]; then
+ echo " Adding to project..."
+ item_id=$(gh project item-add "${PROJECT_NUMBER}" --owner "${PROJECT_OWNER}" --url "${issue_url}" --format json 2>/dev/null | jq -r '.id' || true)
+
+ if [[ -n "${item_id}" && "${item_id}" != "null" ]]; then
+ echo " Project item ID: ${item_id}"
+
+ # Set branch field
+ if [[ -n "${BRANCH_FIELD_ID}" ]]; then
+ echo " Setting branch field to: ${BRANCH}"
+ gh api graphql -f query='
+ mutation($projectId: ID!, $itemId: ID!, $fieldId: ID!, $value: String!) {
+ updateProjectV2ItemFieldValue(input: {
+ projectId: $projectId
+ itemId: $itemId
+ fieldId: $fieldId
+ value: {text: $value}
+ }) {
+ projectV2Item {
+ id
+ }
+ }
+ }' -f projectId="${PROJECT_ID}" -f itemId="${item_id}" -f fieldId="${BRANCH_FIELD_ID}" -f value="${BRANCH}" || echo " Warning: Failed to set branch field"
+ fi
+
+ # Set commit_hash field
+ if [[ -n "${COMMIT_HASH_FIELD_ID}" ]]; then
+ echo " Setting commit_hash field to: ${hash}"
+ gh api graphql -f query='
+ mutation($projectId: ID!, $itemId: ID!, $fieldId: ID!, $value: String!) {
+ updateProjectV2ItemFieldValue(input: {
+ projectId: $projectId
+ itemId: $itemId
+ fieldId: $fieldId
+ value: {text: $value}
+ }) {
+ projectV2Item {
+ id
+ }
+ }
+ }' -f projectId="${PROJECT_ID}" -f itemId="${item_id}" -f fieldId="${COMMIT_HASH_FIELD_ID}" -f value="${hash}" || echo " Warning: Failed to set commit_hash field"
+ fi
+ else
+ echo " Warning: Could not get project item ID"
+ fi
+ fi
+
+ sleep 1
+ done
+
+ echo "Done creating issues!"
diff --git a/.github/workflows/create_ifu_tag.yml b/.github/workflows/create_ifu_tag.yml
new file mode 100644
index 0000000000000..7dc766cd06b0a
--- /dev/null
+++ b/.github/workflows/create_ifu_tag.yml
@@ -0,0 +1,352 @@
+name: Create git tags for IFU PRs
+
+on:
+ # ORIGINAL: Triggered when an IFU PR is merged
+ pull_request:
+ types: [closed]
+
+ # Test harness - manually trigger to test without a real PR merge
+ workflow_dispatch:
+ inputs:
+ test_branch:
+ description: "Branch name to test (e.g., rocm7.1_internal_testing)"
+ required: true
+ type: string
+ test_curr_pre_tag:
+ description: "Pre tag to use as curr_pre_tag (required for full chain test)"
+ required: false
+ type: string
+ test_issue_prev_ref:
+ description: "Optional issue range start ref for cold-start full-chain test (tag or SHA)"
+ required: false
+ type: string
+ run_full_chain:
+ description: "Run full chain - actually call create_ifu_issues.yml (will create real issues!)"
+ required: false
+ default: false
+ type: boolean
+ pr_num:
+ description: "Merged IFU PR number — runs full pipeline (tags, PR body, create_issues) as if that PR just merged"
+ required: false
+ default: 0
+ type: number
+
+permissions:
+ contents: write # create/push tags
+ pull-requests: write # edit PR body
+ issues: write # needed for create_ifu_issues.yml when called
+
+jobs:
+ tag-ifu:
+ # Run for workflow_dispatch (test mode) OR for real PR merges
+ if: >
+ github.event_name == 'workflow_dispatch' ||
+ (github.event.pull_request.merged == true &&
+ contains(github.event.pull_request.title, '[AUTOGENERATED]') &&
+ contains(github.event.pull_request.title, 'IFU'))
+ runs-on: ubuntu-latest
+
+ # Export values so the create-issues job can use them
+ outputs:
+ prev_post_tag: ${{ steps.prev_tag.outputs.prev_post_tag }}
+ curr_pre_tag: ${{ (github.event_name == 'workflow_dispatch' && inputs.pr_num == 0 && inputs.test_curr_pre_tag) || steps.tagname.outputs.PRE_TAG }}
+ has_prev_tag: ${{ steps.prev_tag.outputs.has_prev_tag }}
+ issue_prev_ref: ${{ steps.prev_ref.outputs.issue_prev_ref }}
+ can_create_issues: ${{ steps.prev_ref.outputs.can_create_issues }}
+
+ steps:
+ - name: Validate test inputs
+ if: github.event_name == 'workflow_dispatch' && inputs.run_full_chain == true
+ run: |
+ if [[ -z "${{ inputs.test_curr_pre_tag }}" ]]; then
+ echo "ERROR: test_curr_pre_tag is required when run_full_chain is enabled"
+ echo "Please provide an existing pre tag (e.g., rocm7.1_internal_testing_IFU_2025-10-29_pre)"
+ exit 1
+ fi
+ echo "Full chain test enabled with:"
+ echo " test_branch: ${{ inputs.test_branch }}"
+ echo " test_curr_pre_tag: ${{ inputs.test_curr_pre_tag }}"
+ if [[ -n "${{ inputs.test_issue_prev_ref }}" ]]; then
+ echo " test_issue_prev_ref: ${{ inputs.test_issue_prev_ref }}"
+ fi
+
+ # When dispatch + pr_num: fetch PR via API so we have base.ref, head.sha, merge_commit_sha, title (no event.pull_request in dispatch).
+ - name: Get PR details
+ id: get_pr
+ if: github.event_name == 'workflow_dispatch' && inputs.pr_num != 0
+ env:
+ GH_TOKEN: ${{ secrets.IFU_GITHUB_TOKEN }}
+ run: |
+ set -euo pipefail
+ PR_JSON=$(gh api "repos/${{ github.repository }}/pulls/${{ inputs.pr_num }}")
+ MERGE_SHA=$(echo "$PR_JSON" | jq -r .merge_commit_sha)
+ if [[ "$MERGE_SHA" == "null" || -z "$MERGE_SHA" ]]; then
+ echo "ERROR: PR #${{ inputs.pr_num }} is not merged yet. Use a merged IFU PR number."
+ exit 1
+ fi
+ echo "base_ref=$(echo "$PR_JSON" | jq -r .base.ref)" >> "$GITHUB_OUTPUT"
+ echo "head_sha=$(echo "$PR_JSON" | jq -r .head.sha)" >> "$GITHUB_OUTPUT"
+ echo "merge_sha=$MERGE_SHA" >> "$GITHUB_OUTPUT"
+ echo "title=$(echo "$PR_JSON" | jq -r .title)" >> "$GITHUB_OUTPUT"
+ echo "pr_num=${{ inputs.pr_num }}" >> "$GITHUB_OUTPUT"
+ echo "Fetched PR #${{ inputs.pr_num }}: base=$(echo "$PR_JSON" | jq -r .base.ref), merge_sha=$MERGE_SHA"
+
+ - name: Checkout base repo (full history)
+ uses: actions/checkout@v4
+ with:
+ # Worflow_dispatch
+ # pr_num != 0 -> use pr details from json which we got in get_pr step
+ # pr_num == 0 -> use current branch
+ # PR merge -> use base.ref
+ ref: ${{ (github.event_name == 'workflow_dispatch' && inputs.pr_num != 0 && steps.get_pr.outputs.base_ref) || (github.event_name == 'workflow_dispatch' && github.ref) || github.event.pull_request.base.ref }}
+ fetch-depth: 0
+ token: ${{ secrets.IFU_GITHUB_TOKEN }}
+
+ # Fetch all tags so we can find the previous post tag
+ - name: Fetch all tags
+ run: git fetch origin --tags --force
+
+ - name: Configure Git user
+ run: |
+ git config user.name "github-actions[bot]"
+ git config user.email "github-actions[bot]@users.noreply.github.com"
+
+ - name: Derive key SHAs (rocm base, upstream main, merge)
+ id: shas
+ if: (github.event_name == 'workflow_dispatch' && inputs.pr_num != 0) || (github.event_name != 'workflow_dispatch')
+ env:
+ PR_NUM: ${{ steps.get_pr.outputs.pr_num || github.event.pull_request.number }}
+ BASE_REF: ${{ steps.get_pr.outputs.base_ref || github.event.pull_request.base.ref }}
+ HEAD_SHA: ${{ steps.get_pr.outputs.head_sha || github.event.pull_request.head.sha }}
+ MERGE_SHA: ${{ steps.get_pr.outputs.merge_sha || github.event.pull_request.merge_commit_sha }}
+ shell: bash
+ run: |
+ set -euo pipefail
+
+ # Upstream ref is usually the same as base branch. For rocm/pytorch's
+ # develop branch, compare against upstream/main.
+ UPSTREAM_REF="$BASE_REF"
+ if [ "$UPSTREAM_REF" == "develop" ]; then
+ UPSTREAM_REF="main"
+ fi
+
+ echo "PR_NUM=$PR_NUM"
+ echo "BASE_REF=$BASE_REF"
+ echo "UPSTREAM_REF=$UPSTREAM_REF"
+ echo "HEAD_SHA=$HEAD_SHA"
+ echo "MERGE_SHA=$MERGE_SHA"
+
+ # The ROCm base commit is the first parent of the merge commit that landed the PR
+ # (i.e., the base branch tip BEFORE this PR merged).
+ ROCM_BASE_SHA=$(git rev-parse "${MERGE_SHA}^1")
+
+ # Add upstream if missing.
+ if ! git remote get-url upstream >/dev/null 2>&1; then
+ git remote add upstream "https://github.com/pytorch/pytorch.git"
+ fi
+
+ # Some IFU base branches may not exist in upstream (e.g., fork-only/test branches).
+ # In that case, fall back to upstream/main.
+ if ! git ls-remote --exit-code --heads upstream "$UPSTREAM_REF" >/dev/null 2>&1; then
+ echo "Upstream branch '$UPSTREAM_REF' not found; falling back to upstream/main"
+ UPSTREAM_REF="main"
+ fi
+ git fetch upstream "$UPSTREAM_REF"
+
+ # Heuristic: the upstream commit integrated by the PR's head is the merge-base
+ # between the PR head commit and upstream/main as fetched now.
+ # This gives you the exact upstream commit (or the best common ancestor) that HEAD included.
+ UPSTREAM_MAIN_SHA=$(git merge-base "${HEAD_SHA}" "upstream/$UPSTREAM_REF")
+ echo "ROCM_BASE_SHA=$ROCM_BASE_SHA"
+ echo "UPSTREAM_MAIN_SHA=$UPSTREAM_MAIN_SHA"
+ echo "UPSTREAM_REF_USED=$UPSTREAM_REF"
+
+
+ echo "PR_NUM=$PR_NUM" >> "$GITHUB_OUTPUT"
+ echo "BASE_REF=$BASE_REF" >> "$GITHUB_OUTPUT"
+ echo "UPSTREAM_REF_USED=$UPSTREAM_REF" >> "$GITHUB_OUTPUT"
+ echo "HEAD_SHA=$HEAD_SHA" >> "$GITHUB_OUTPUT"
+ echo "MERGE_SHA=$MERGE_SHA" >> "$GITHUB_OUTPUT"
+ echo "ROCM_BASE_SHA=$ROCM_BASE_SHA" >> "$GITHUB_OUTPUT"
+ echo "UPSTREAM_MAIN_SHA=$UPSTREAM_MAIN_SHA" >> "$GITHUB_OUTPUT"
+
+ - name: Extract tag base from PR title
+ id: tagname
+ if: (github.event_name == 'workflow_dispatch' && inputs.pr_num != 0) || (github.event_name != 'workflow_dispatch')
+ env:
+ TITLE: ${{ steps.get_pr.outputs.title || github.event.pull_request.title }}
+ run: |
+ # Remove everything up to and including "[AUTOGENERATED]"
+ # Remove trailing whitespace
+ BASE_TAG=$(echo "$TITLE" | sed -E 's/^\[AUTOGENERATED\][[:space:]]*//' | sed -E 's/[[:space:]]+$//')
+
+ echo "BASE_TAG=$BASE_TAG"
+ echo "PRE_TAG=${BASE_TAG}_pre"
+ echo "POST_TAG=${BASE_TAG}_post"
+
+ # Extract branch name from BASE_TAG (everything before _IFU_)
+ BRANCH="${BASE_TAG%_IFU_*}"
+ echo "BRANCH=$BRANCH"
+
+ echo "BASE_TAG=$BASE_TAG" >> $GITHUB_OUTPUT
+ echo "PRE_TAG=${BASE_TAG}_pre" >> $GITHUB_OUTPUT
+ echo "POST_TAG=${BASE_TAG}_post" >> $GITHUB_OUTPUT
+ echo "BRANCH=$BRANCH" >> $GITHUB_OUTPUT
+
+ # Find the most recent post tag for this branch
+ # This is needed to know the range of commits for issue creation
+ - name: Find previous post tag
+ id: prev_tag
+ env:
+ # Dispatch without pr_num: test_branch; dispatch+pr_num or PR merge: from tagname
+ BRANCH: ${{ (github.event_name == 'workflow_dispatch' && inputs.pr_num == 0 && inputs.test_branch) || steps.tagname.outputs.BRANCH }}
+ run: |
+ echo "Finding previous post tag for branch: ${BRANCH}"
+
+ # List all post tags for this branch, sorted by version (date in tag name)
+ echo "All post tags for ${BRANCH}:"
+ git tag --list "${BRANCH}_IFU_*_post" --sort=-version:refname
+
+ # Get the most recent post tag
+ prev_post_tag=$(git tag --list "${BRANCH}_IFU_*_post" --sort=-version:refname | head -n 1)
+
+ if [[ -z "$prev_post_tag" ]]; then
+ echo "WARNING: No previous post tag found for branch ${BRANCH}"
+ echo "This might be the first IFU for this branch"
+ echo "prev_post_tag=" >> $GITHUB_OUTPUT
+ echo "has_prev_tag=false" >> $GITHUB_OUTPUT
+ else
+ echo "Found previous post tag: $prev_post_tag"
+ echo "prev_post_tag=$prev_post_tag" >> $GITHUB_OUTPUT
+ echo "has_prev_tag=true" >> $GITHUB_OUTPUT
+ fi
+
+ - name: Validate full-chain test start ref
+ if: github.event_name == 'workflow_dispatch' && inputs.run_full_chain == true
+ env:
+ HAS_PREV_TAG: ${{ steps.prev_tag.outputs.has_prev_tag }}
+ TEST_ISSUE_PREV_REF: ${{ inputs.test_issue_prev_ref }}
+ run: |
+ if [[ "${HAS_PREV_TAG}" != "true" && -z "${TEST_ISSUE_PREV_REF}" ]]; then
+ echo "ERROR: No previous post tag found for this branch."
+ echo "For cold-start full-chain tests, provide test_issue_prev_ref (tag or SHA)."
+ exit 1
+ fi
+
+ # In test mode, print a summary of what was found
+ - name: Test mode summary
+ if: github.event_name == 'workflow_dispatch'
+ env:
+ BRANCH: ${{ inputs.test_branch }}
+ PREV_POST_TAG: ${{ steps.prev_tag.outputs.prev_post_tag }}
+ HAS_PREV_TAG: ${{ steps.prev_tag.outputs.has_prev_tag }}
+ TEST_CURR_PRE_TAG: ${{ inputs.test_curr_pre_tag }}
+ TEST_ISSUE_PREV_REF: ${{ inputs.test_issue_prev_ref }}
+ RUN_FULL_CHAIN: ${{ inputs.run_full_chain }}
+ run: |
+ echo "=========================================="
+ echo "TEST MODE SUMMARY"
+ echo "=========================================="
+ echo "Branch: ${BRANCH}"
+ echo "Has previous post tag: ${HAS_PREV_TAG}"
+ echo "Previous post tag: ${PREV_POST_TAG:-'(none)'}"
+ echo ""
+ if [[ "${RUN_FULL_CHAIN}" == "true" ]]; then
+ echo " FULL CHAIN TEST ENABLED"
+ echo "Will call create_ifu_issues.yml with:"
+ echo " - prev_post_tag: ${PREV_POST_TAG}"
+ echo " - curr_pre_tag: ${TEST_CURR_PRE_TAG}"
+ if [[ -n "${TEST_ISSUE_PREV_REF}" ]]; then
+ echo " - test_issue_prev_ref override: ${TEST_ISSUE_PREV_REF}"
+ fi
+ echo ""
+ echo " WARNING: This will create REAL issues!"
+ else
+ echo " Full chain test NOT enabled"
+ echo "To test issue creation, re-run with:"
+ echo " - run_full_chain: true"
+ echo " - test_curr_pre_tag: (an existing pre tag)"
+ fi
+ echo "=========================================="
+
+ # Determine the start reference for issue creation.
+ # Priority:
+ # 1) previous IFU post tag (normal path)
+ # 2) test_issue_prev_ref (cold-start fallback for workflow_dispatch test path)
+ # 3) UPSTREAM_MAIN_SHA (cold-start fallback on real merge path only)
+ - name: Resolve issue range start reference
+ id: prev_ref
+ env:
+ HAS_PREV_TAG: ${{ steps.prev_tag.outputs.has_prev_tag }}
+ PREV_POST_TAG: ${{ steps.prev_tag.outputs.prev_post_tag }}
+ TEST_ISSUE_PREV_REF: ${{ inputs.test_issue_prev_ref }}
+ EVENT_NAME: ${{ github.event_name }}
+ UPSTREAM_MAIN_SHA: ${{ steps.shas.outputs.UPSTREAM_MAIN_SHA }}
+ run: |
+ if [[ "${HAS_PREV_TAG}" == "true" && -n "${PREV_POST_TAG}" ]]; then
+ echo "Using previous IFU post tag for issue range start: ${PREV_POST_TAG}"
+ echo "issue_prev_ref=${PREV_POST_TAG}" >> "$GITHUB_OUTPUT"
+ echo "can_create_issues=true" >> "$GITHUB_OUTPUT"
+ elif [[ "${EVENT_NAME}" == "workflow_dispatch" && -n "${TEST_ISSUE_PREV_REF}" ]]; then
+ echo "Using test override for issue range start: ${TEST_ISSUE_PREV_REF}"
+ echo "issue_prev_ref=${TEST_ISSUE_PREV_REF}" >> "$GITHUB_OUTPUT"
+ echo "can_create_issues=true" >> "$GITHUB_OUTPUT"
+ elif [[ "${EVENT_NAME}" != "workflow_dispatch" && -n "${UPSTREAM_MAIN_SHA:-}" ]]; then
+ echo "No previous IFU post tag found; using cold-start fallback UPSTREAM_MAIN_SHA: ${UPSTREAM_MAIN_SHA}"
+ echo "issue_prev_ref=${UPSTREAM_MAIN_SHA}" >> "$GITHUB_OUTPUT"
+ echo "can_create_issues=true" >> "$GITHUB_OUTPUT"
+ else
+ echo "Could not determine issue range start reference."
+ echo "issue_prev_ref=" >> "$GITHUB_OUTPUT"
+ echo "can_create_issues=false" >> "$GITHUB_OUTPUT"
+ fi
+
+ - name: Create pre/post tags
+ if: (github.event_name == 'pull_request') || (github.event_name == 'workflow_dispatch' && inputs.pr_num != 0)
+ shell: bash
+ run: |
+ set -euo pipefail
+ echo "Tagging:"
+ echo " ${{ steps.tagname.outputs.PRE_TAG }} @ ${{ steps.shas.outputs.ROCM_BASE_SHA }}"
+ echo " ${{ steps.tagname.outputs.POST_TAG }} @ ${{ steps.shas.outputs.MERGE_SHA }}"
+
+ git tag -a "${{ steps.tagname.outputs.PRE_TAG }}" -m "IFU pre (PR #${{ steps.shas.outputs.PR_NUM }})" "${{ steps.shas.outputs.ROCM_BASE_SHA }}"
+ git tag -a "${{ steps.tagname.outputs.POST_TAG }}" -m "IFU post (PR #${{ steps.shas.outputs.PR_NUM }})" "${{ steps.shas.outputs.MERGE_SHA }}"
+
+ #Force pushing is safe. If we land a new PR, we'd wanna retag a commit if we have to.
+ git push origin "refs/tags/${{ steps.tagname.outputs.PRE_TAG }}" -f
+ git push origin "refs/tags/${{ steps.tagname.outputs.POST_TAG }}" -f
+
+ - name: Append rocm_base & upstream_main to PR body
+ if: (github.event_name == 'pull_request') || (github.event_name == 'workflow_dispatch' && inputs.pr_num != 0)
+ env:
+ GH_TOKEN: ${{ secrets.IFU_GITHUB_TOKEN }}
+ shell: bash
+ run: |
+ set -euo pipefail
+ # Read current body
+ PR="${{ steps.shas.outputs.PR_NUM }}"
+ CURR=$(gh api repos/${{ github.repository }}/pulls/$PR --jq .body)
+ APPEND=$'\n'"rocm_base: ${{ steps.shas.outputs.ROCM_BASE_SHA }}"$'\n'"upstream_main: ${{ steps.shas.outputs.UPSTREAM_MAIN_SHA }}"$'\n'
+ NEW_BODY="${CURR}${APPEND}"
+
+ # Write to a temp file and update PR body
+ printf '%s' "$NEW_BODY" > body.txt
+ gh api --method PATCH -H "Accept: application/vnd.github+json" \
+ repos/${{ github.repository }}/pulls/$PR -F body=@body.txt
+
+ # Calls create_ifu_issues.yml after tagging
+ # Runs for:
+ # - Real PR merges (when a start reference can be resolved)
+ # - Test mode with run_full_chain=true (when a start reference can be resolved)
+ create-issues:
+ needs: tag-ifu
+ if: >
+ needs.tag-ifu.outputs.can_create_issues == 'true' &&
+ (github.event_name != 'workflow_dispatch' || inputs.run_full_chain == true || inputs.pr_num != 0)
+ uses: ./.github/workflows/create_ifu_issues.yml
+ with:
+ prev_post_tag: ${{ needs.tag-ifu.outputs.issue_prev_ref }}
+ curr_pre_tag: ${{ needs.tag-ifu.outputs.curr_pre_tag }}
+ secrets:
+ IFU_GITHUB_TOKEN: ${{ secrets.IFU_GITHUB_TOKEN }}
diff --git a/.github/workflows/parity.yml b/.github/workflows/parity.yml
new file mode 100644
index 0000000000000..5f88548712818
--- /dev/null
+++ b/.github/workflows/parity.yml
@@ -0,0 +1,394 @@
+name: Parity Report
+run-name: "${{ inputs.baseline_sha && format('{0} vs {1}', inputs.sha || 'latest', inputs.baseline_sha) || inputs.csv_name || inputs.pr_id && format('PR {0}', inputs.pr_id) || inputs.sha || 'latest' }} · ${{ inputs.arch || 'mi355, mi300, mi200' }}"
+
+on:
+ workflow_dispatch:
+ inputs:
+ # download_testlogs flags
+ sha:
+ description: 'Commit SHA to pull test results for. Example: 67f1ccf46a966e75f37facd497a03f7d1bd72982. Leave empty for latest green on main.'
+ required: false
+ type: string
+ baseline_sha:
+ description: 'Baseline commit SHA to compare against (same workflow/arch). Produces a commit-vs-commit report instead of ROCm-vs-CUDA.'
+ required: false
+ type: string
+ pr_id:
+ description: 'Pull request number (alternative to SHA, uses latest commit). Example: 176306'
+ required: false
+ type: string
+ arch:
+ description: 'ROCm architectures, comma or space separated. Options: mi355, mi300, mi200, nightly, navi31. Example: "nightly, mi355" or "mi300"'
+ required: false
+ default: 'mi355, mi300, mi200'
+ type: string
+ exclude_distributed:
+ description: 'Exclude distributed tests (auto-excluded for navi31)'
+ required: false
+ default: false
+ type: boolean
+ exclude_inductor:
+ description: 'Exclude inductor tests (auto-excluded for navi31)'
+ required: false
+ default: false
+ type: boolean
+ exclude_default:
+ description: 'Exclude default tests'
+ required: false
+ default: false
+ type: boolean
+ include_logs:
+ description: 'Download and include CI log files (.txt) in artifact zip'
+ required: false
+ default: true
+ type: boolean
+ skip_rocm:
+ description: 'Skip downloading ROCm test results (generate CUDA-only report)'
+ required: false
+ default: false
+ type: boolean
+ skip_cuda:
+ description: 'Skip downloading CUDA test results (generate ROCm-only report)'
+ required: false
+ default: false
+ type: boolean
+ # summarize_xml_testreports flags
+ set1_name:
+ description: 'Label for ROCm columns in output CSV. Examples: rocm, nightly, mi300. Default: rocm'
+ required: false
+ default: 'rocm'
+ type: string
+ set2_name:
+ description: 'Label for CUDA columns in output CSV. Examples: cuda, trunk. Default: cuda'
+ required: false
+ default: 'cuda'
+ type: string
+ csv_name:
+ description: 'Custom prefix for output filenames and artifacts. Default: YYYYMMDD_all_tests_status'
+ required: false
+ type: string
+ include_inductor_periodic:
+ description: 'Download inductor-periodic benchmark artifacts (separate from parity CSV)'
+ required: false
+ default: false
+ type: boolean
+ include_xml:
+ description: 'Include raw XML test reports in artifact zip (WARNING: drastically increases artifact size ~10x)'
+ required: false
+ default: false
+ type: boolean
+ auto_classify:
+ description: 'Auto-classify skip reasons for SKIPPED/MISSED tests in the output CSV'
+ required: false
+ default: false
+ type: boolean
+
+jobs:
+ setup-matrix:
+ runs-on: ubuntu-latest
+ outputs:
+ arch-matrix: ${{ steps.parse.outputs.matrix }}
+ prefix: ${{ steps.parse.outputs.prefix }}
+ steps:
+ - name: Parse arch input into matrix
+ id: parse
+ run: |
+ ARCHS="${{ inputs.arch }}"
+ ARCHS=$(echo "$ARCHS" | tr ',[:space:]' '\n' | sed '/^$/d' | tr '\n' ' ')
+ JSON=$(echo "$ARCHS" | tr ' ' '\n' | sed '/^$/d' | sed 's/^/"/;s/$/"/' | paste -sd',' | sed 's/^/[/;s/$/]/')
+ echo "matrix=$JSON" >> "$GITHUB_OUTPUT"
+ echo "Architectures: $JSON"
+
+ if [ -n "${{ inputs.csv_name }}" ]; then
+ PREFIX="${{ inputs.csv_name }}"
+ elif [ -n "${{ inputs.sha }}" ]; then
+ PREFIX="${{ inputs.sha }}"
+ elif [ -n "${{ inputs.pr_id }}" ]; then
+ PREFIX="${{ inputs.pr_id }}"
+ else
+ PREFIX="parity"
+ fi
+ PREFIX=$(echo "$PREFIX" | xargs)
+ echo "prefix=$PREFIX" >> "$GITHUB_OUTPUT"
+ echo "Artifact prefix: $PREFIX"
+
+ generate-parity:
+ needs: setup-matrix
+ runs-on: ubuntu-latest
+ strategy:
+ fail-fast: false
+ matrix:
+ arch: ${{ fromJson(needs.setup-matrix.outputs.arch-matrix) }}
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v4
+
+ - name: Set up Python
+ uses: actions/setup-python@v5
+ with:
+ python-version: '3.10'
+
+ - name: Install dependencies
+ working-directory: .automation_scripts/pytorch-unit-test-scripts
+ run: pip install -r requirements.txt
+
+ - name: Download artifacts
+ working-directory: .automation_scripts/pytorch-unit-test-scripts
+ env:
+ GITHUB_TOKEN: ${{ secrets.PARITY_GITHUB_TOKEN }}
+ AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
+ AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
+ run: |
+ ARGS="--arch ${{ matrix.arch }}"
+
+ if [ -n "${{ inputs.sha }}" ]; then
+ ARGS="$ARGS --sha1 ${{ inputs.sha }}"
+ fi
+ if [ -n "${{ inputs.pr_id }}" ]; then
+ ARGS="$ARGS --pr_id ${{ inputs.pr_id }}"
+ fi
+ if [ "${{ inputs.exclude_distributed }}" = "true" ]; then
+ ARGS="$ARGS --exclude_distributed"
+ fi
+ if [ "${{ inputs.exclude_inductor }}" = "true" ]; then
+ ARGS="$ARGS --exclude_inductor"
+ fi
+ if [ "${{ inputs.exclude_default }}" = "true" ]; then
+ ARGS="$ARGS --exclude_default"
+ fi
+ ARGS="$ARGS --ignore_status"
+ if [ "${{ inputs.include_logs }}" != "true" ]; then
+ ARGS="$ARGS --artifacts_only"
+ fi
+ if [ "${{ inputs.skip_rocm }}" = "true" ]; then
+ ARGS="$ARGS --no_rocm"
+ fi
+ if [ "${{ inputs.skip_cuda }}" = "true" ]; then
+ ARGS="$ARGS --no_cuda"
+ fi
+ if [ "${{ inputs.include_inductor_periodic }}" = "true" ]; then
+ ARGS="$ARGS --include_inductor_periodic"
+ fi
+ if [ -n "${{ inputs.baseline_sha }}" ]; then
+ ARGS="$ARGS --baseline_sha ${{ inputs.baseline_sha }}"
+ fi
+
+ echo "Running: python3 ./download_testlogs $ARGS"
+ python3 ./download_testlogs $ARGS 2>&1 | tee download_${{ matrix.arch }}.log
+
+ - name: Identify output folder
+ id: folder
+ working-directory: .automation_scripts/pytorch-unit-test-scripts
+ run: |
+ FOLDER=$(ls -dt [0-9]*_[0-9a-f]*/ 2>/dev/null | head -1 | sed 's:/$::')
+ if [ -z "$FOLDER" ]; then
+ echo "ERROR: No output folder found"
+ exit 1
+ fi
+ echo "folder=$FOLDER" >> "$GITHUB_OUTPUT"
+ SHA=$(echo "$FOLDER" | grep -oP '[0-9a-f]{40}')
+ echo "sha=$SHA" >> "$GITHUB_OUTPUT"
+ DATE=$(TZ='America/Los_Angeles' date '+%Y%m%d')
+ echo "date=$DATE" >> "$GITHUB_OUTPUT"
+ mv download_${{ matrix.arch }}.log "$FOLDER/" 2>/dev/null || true
+ echo "Output folder: $FOLDER, SHA: $SHA, Date: $DATE"
+
+ - name: Generate CSV
+ working-directory: .automation_scripts/pytorch-unit-test-scripts
+ run: |
+ FOLDER="${{ steps.folder.outputs.folder }}"
+ DATE="${{ steps.folder.outputs.date }}"
+ ARCH="${{ matrix.arch }}"
+
+ if [ -n "${{ inputs.csv_name }}" ]; then
+ CSV_NAME="${{ inputs.csv_name }}_${ARCH}"
+ else
+ CSV_NAME="${DATE}_all_tests_status_${ARCH}"
+ fi
+
+ ARGS="--set1 $FOLDER/rocm_xml"
+ if [ -n "${{ inputs.baseline_sha }}" ]; then
+ ARGS="$ARGS --set2 $FOLDER/baseline_xml"
+ CURRENT_SHORT=$(echo "${{ steps.folder.outputs.sha }}" | cut -c1-8)
+ BASELINE_SHORT=$(echo "${{ inputs.baseline_sha }}" | cut -c1-8)
+ ARGS="$ARGS --set1_name ${CURRENT_SHORT}"
+ ARGS="$ARGS --set2_name ${BASELINE_SHORT}"
+ else
+ if [ "${{ inputs.skip_cuda }}" != "true" ]; then
+ ARGS="$ARGS --set2 $FOLDER/cuda_xml"
+ fi
+ ARGS="$ARGS --set1_name ${{ inputs.set1_name }}"
+ ARGS="$ARGS --set2_name ${{ inputs.set2_name }}"
+ fi
+ ARGS="$ARGS --output_csv $FOLDER/${CSV_NAME}.csv"
+ SHORT_ARCH=$(echo "$ARCH" | sed 's/^mi//')
+ if [ -n "${{ inputs.csv_name }}" ]; then
+ RT_NAME="${{ inputs.csv_name }}_running_time_${SHORT_ARCH}"
+ else
+ RT_NAME="${DATE}_running_time_${SHORT_ARCH}"
+ fi
+ ARGS="$ARGS --test_file_running_time_output_csv $FOLDER/${RT_NAME}.csv"
+
+ echo "Running: python3 -u summarize_xml_testreports.py $ARGS"
+ python3 -u summarize_xml_testreports.py $ARGS 2>&1 | tee "$FOLDER/xml_processing_${DATE}.log"
+
+ - name: Auto-classify skip reasons
+ if: ${{ inputs.auto_classify }}
+ working-directory: .automation_scripts/pytorch-unit-test-scripts
+ run: |
+ FOLDER="${{ steps.folder.outputs.folder }}"
+ CSV=$(find "$FOLDER" -maxdepth 1 -name "*.csv" ! -name "*_running_time*" | head -1)
+ if [ -n "$CSV" ]; then
+ echo "Auto-classifying skip reasons in $CSV"
+ python3 auto_classify_skip_reasons.py -i "$CSV" -o "$CSV" --report 2>&1
+ else
+ echo "No parity CSV found in $FOLDER, skipping auto-classify"
+ fi
+
+ - name: Detect log-based failures (timeouts, crashes)
+ if: ${{ inputs.include_logs }}
+ working-directory: .automation_scripts/pytorch-unit-test-scripts
+ run: |
+ FOLDER="${{ steps.folder.outputs.folder }}"
+ if ls "$FOLDER"/*.txt 1>/dev/null 2>&1; then
+ python3 detect_log_failures.py --logs-dir "$FOLDER" --output "$FOLDER/log_failures_${{ matrix.arch }}.csv" 2>&1 || true
+ else
+ echo "No log files found in $FOLDER, skipping log failure detection"
+ fi
+
+ - name: Collect upload paths
+ id: upload-paths
+ run: |
+ FOLDER=".automation_scripts/pytorch-unit-test-scripts/${{ steps.folder.outputs.folder }}"
+ PATHS="${FOLDER}/*.csv
+ ${FOLDER}/*.log
+ ${FOLDER}/*.txt
+ ${FOLDER}/inductor_periodic_rocm_dir/
+ ${FOLDER}/inductor_periodic_cuda_dir/"
+ if [ "${{ inputs.include_xml }}" = "true" ]; then
+ PATHS="${PATHS}
+ ${FOLDER}/rocm_xml/
+ ${FOLDER}/cuda_xml/
+ ${FOLDER}/baseline_xml/"
+ fi
+ echo "paths<> "$GITHUB_OUTPUT"
+ echo "$PATHS" >> "$GITHUB_OUTPUT"
+ echo "EOF" >> "$GITHUB_OUTPUT"
+
+ - name: Upload artifacts
+ uses: actions/upload-artifact@v4
+ with:
+ name: ${{ needs.setup-matrix.outputs.prefix }}-results-${{ matrix.arch }}
+ retention-days: 1
+ path: ${{ steps.upload-paths.outputs.paths }}
+
+ summarize:
+ needs: [setup-matrix, generate-parity]
+ if: ${{ !cancelled() }}
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v4
+
+ - name: Set up Python
+ uses: actions/setup-python@v5
+ with:
+ python-version: '3.10'
+
+ - name: Download all per-arch CSV artifacts
+ uses: actions/download-artifact@v4
+ with:
+ pattern: ${{ needs.setup-matrix.outputs.prefix }}-results-*
+ path: artifacts
+
+ - name: Build parity report
+ working-directory: .automation_scripts/pytorch-unit-test-scripts
+ run: |
+ ARCHS="${{ inputs.arch }}"
+ SHA="${{ inputs.sha }}"
+ PR_ID="${{ inputs.pr_id }}"
+ BASELINE_SHA="${{ inputs.baseline_sha }}"
+ if [ -n "$BASELINE_SHA" ]; then
+ SET1=$(echo "$SHA" | cut -c1-8)
+ SET2=$(echo "$BASELINE_SHA" | cut -c1-8)
+ else
+ SET1="${{ inputs.set1_name }}"
+ SET2="${{ inputs.set2_name }}"
+ fi
+
+ ARCHS=$(echo "$ARCHS" | tr ',[:space:]' ' ')
+ PREFIX=$(echo "${{ needs.setup-matrix.outputs.prefix }}" | xargs)
+ CSV_ARGS=()
+ ARCH_ARGS=()
+ for ARCH in $ARCHS; do
+ ARTIFACT_DIR="../../artifacts/${PREFIX}-results-${ARCH}"
+ CSV=$(find "$ARTIFACT_DIR"/ -maxdepth 2 -name "*.csv" ! -name "*_running_time*" ! -name "*_summary*" ! -name "log_failures_*" ! -name "log_shards_*" ! -name "flaky_tests_*" 2>/dev/null | head -1)
+ if [ -z "$CSV" ]; then
+ echo "WARNING: No CSV found for $ARCH, skipping"
+ continue
+ fi
+ echo "Found CSV for $ARCH: $CSV"
+ CSV_ARGS+=("$CSV")
+ ARCH_ARGS+=("$ARCH")
+ done
+
+ if [ ${#CSV_ARGS[@]} -eq 0 ]; then
+ echo "::warning::No CSVs found for any architecture — some or all generate-parity jobs may have failed"
+ echo "## ⚠ No CSVs produced" >> "$GITHUB_STEP_SUMMARY"
+ echo "No parity CSVs were found. Check the generate-parity job logs for errors." >> "$GITHUB_STEP_SUMMARY"
+ exit 0
+ fi
+
+ ARGS=(--csv "${CSV_ARGS[@]}" --arch "${ARCH_ARGS[@]}")
+ ARGS+=(--set1_name "$SET1" --set2_name "$SET2")
+
+ if [ -n "$SHA" ]; then
+ ARGS+=(--sha "$SHA")
+ else
+ DETECTED_SHA=$(basename "$(find ../../artifacts/ -name '*.csv' | head -1)" | grep -oP '[0-9a-f]{40}' || true)
+ if [ -n "$DETECTED_SHA" ]; then
+ ARGS+=(--sha "$DETECTED_SHA")
+ fi
+ fi
+ if [ -n "$PR_ID" ]; then
+ ARGS+=(--pr_id "$PR_ID")
+ fi
+
+ # Collect log failure CSVs if they exist
+ LOG_FAIL_ARGS=()
+ for ARCH in $ARCHS; do
+ LF="../../artifacts/${PREFIX}-results-${ARCH}"
+ LF_CSV=$(find "$LF"/ -maxdepth 2 -name "log_failures_*.csv" 2>/dev/null | head -1)
+ if [ -n "$LF_CSV" ]; then
+ LOG_FAIL_ARGS+=("$LF_CSV")
+ echo "Found log failures for $ARCH: $LF_CSV"
+ fi
+ done
+ if [ ${#LOG_FAIL_ARGS[@]} -gt 0 ]; then
+ ARGS+=(--log-failures "${LOG_FAIL_ARGS[@]}")
+ fi
+
+ OUTPUT="${PREFIX}_summary"
+ ARGS+=(--output "$OUTPUT")
+
+ echo "Running: python3 generate_summary.py ${ARGS[*]}"
+ python3 generate_summary.py "${ARGS[@]}"
+
+ cat "${OUTPUT}.md" >> "$GITHUB_STEP_SUMMARY"
+
+ - name: Add artifact links to summary
+ env:
+ GH_TOKEN: ${{ github.token }}
+ run: |
+ ARTIFACTS_JSON=$(gh api repos/${{ github.repository }}/actions/runs/${{ github.run_id }}/artifacts --paginate -q '.artifacts[] | {name, id}')
+ RUN_URL="${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}/artifacts"
+
+ {
+ echo ""
+ echo "### ARTIFACTS"
+ echo ""
+ echo "| Artifact | Link |"
+ echo "| --- | --- |"
+ echo "$ARTIFACTS_JSON" | jq -r '"| \(.name) | [Download]('"${RUN_URL}"'/\(.id)) |"'
+ echo ""
+ } >> "$GITHUB_STEP_SUMMARY"
diff --git a/.github/workflows/pytorch_ifu.yml b/.github/workflows/pytorch_ifu.yml
new file mode 100644
index 0000000000000..a06c567a61dcb
--- /dev/null
+++ b/.github/workflows/pytorch_ifu.yml
@@ -0,0 +1,145 @@
+name: PyTorch IFU (Sync with upstream)
+
+on:
+ workflow_dispatch:
+ inputs:
+ ifu_target_repo:
+ description: "Target repo for IFU"
+ required: false
+ default: "ROCm/pytorch"
+ type: string
+ ifu_target_branch:
+ description: "Target branch for IFU"
+ required: true
+ default: "rocm7.1_internal_testing"
+ type: string
+ ifu_source_repo:
+ description: "Source repo for IFU"
+ required: false
+ default: "pytorch/pytorch"
+ type: string
+ ifu_source_branch:
+ description: "Source branch for IFU"
+ required: false
+ default: "main"
+ type: string
+ # schedule:
+ # # Runs every 14 days at 09:00 AM UTC/ 04:00 AM CST
+ # - cron: "0 9 */14 * *"
+
+permissions:
+ contents: write # push branches/tags
+ pull-requests: write # create PRs
+
+concurrency:
+ group: ifu
+ # If two jobs are running simultaneously, we will queue them (not cancel the one running)
+ cancel-in-progress: false
+
+jobs:
+ ifu:
+ runs-on: ubuntu-latest
+ env:
+ UPSTREAM_REMOTE: upstream # IFU source remote name
+ UPSTREAM_REPO: ${{ inputs.ifu_source_repo }} # source repo for IFU
+ UPSTREAM_BRANCH: ${{ inputs.ifu_source_branch }} # source branch for IFU
+ DOWNSTREAM_REMOTE: origin # IFU target remote name
+ DOWNSTREAM_REPO: ${{ inputs.ifu_target_repo }} # target repo for IFU (fork); actions/checkout sets this to origin
+ DOWNSTREAM_BRANCH: ${{ inputs.ifu_target_branch }} # target branch for IFU
+ GH_TOKEN: ${{ secrets.IFU_GITHUB_TOKEN }} # used by gh; provided by Action
+ steps:
+ - name: Checkout repository (${{ env.DOWNSTREAM_REPO }}) (full history)
+ uses: actions/checkout@v4
+ with:
+ repository: ${{ env.DOWNSTREAM_REPO }}
+ path: ${{ env.DOWNSTREAM_REPO }}
+ ref: ${{ env.DOWNSTREAM_BRANCH }}
+ token: ${{ env.GH_TOKEN }}
+ fetch-depth: 0 # need full history for merges/tags
+ submodules: recursive
+
+ - name: Add upstream remote (${{ env.UPSTREAM_REPO }})
+ working-directory: ${{ env.DOWNSTREAM_REPO }}
+ run: |
+ if ! git remote get-url ${UPSTREAM_REMOTE} >/dev/null 2>&1; then
+ git remote add ${UPSTREAM_REMOTE} https://github.com/${UPSTREAM_REPO}.git
+ fi
+ # Confirm remotes
+ git remote -v
+
+ - name: Configure Git user
+ working-directory: ${{ env.DOWNSTREAM_REPO }}
+ run: |
+ git config user.name "github-actions[bot]"
+ git config user.email "github-actions[bot]@users.noreply.github.com"
+
+ - name: Fetch upstream and local branch
+ working-directory: ${{ env.DOWNSTREAM_REPO }}
+ run: |
+ git fetch ${UPSTREAM_REMOTE} ${UPSTREAM_BRANCH}
+ git fetch ${DOWNSTREAM_REMOTE} ${DOWNSTREAM_BRANCH}
+
+ - name: Compute date tag and create working branch
+ working-directory: ${{ env.DOWNSTREAM_REPO }}
+ id: tag
+ shell: bash
+ run: |
+ DATE="$(date +"%Y%m%d")"
+ TAG="${DOWNSTREAM_BRANCH}_IFU_${DATE}"
+ echo "TAG=${TAG}" >> $GITHUB_OUTPUT
+ # Start from rocm branch
+ git checkout -b "$TAG" "${DOWNSTREAM_REMOTE}/${DOWNSTREAM_BRANCH}"
+
+ - name: Save ROCm base commit
+ working-directory: ${{ env.DOWNSTREAM_REPO }}
+ id: rocm_base
+ run: |
+ base_commit=`git rev-parse --short HEAD`
+ echo "ROCM_BASE_COMMIT=$base_commit" >> $GITHUB_OUTPUT
+
+ - name: Merge upstream into working branch (non-interactive)
+ working-directory: ${{ env.DOWNSTREAM_REPO }}
+ id: merge
+ run: |
+ if git merge "${UPSTREAM_REMOTE}/${UPSTREAM_BRANCH}" --no-edit; then
+ echo "merge_status=clean" >> $GITHUB_OUTPUT
+ else
+ echo "Merge conflicts detected. Committing current resolution snapshot."
+ git submodule sync
+ git submodule update --init --recursive
+ git add -A
+ git status
+ git commit --no-edit
+ echo "merge_status=conflict" >> $GITHUB_OUTPUT
+ fi
+
+ - name: Push branch & tag to fork
+ working-directory: ${{ env.DOWNSTREAM_REPO }}
+ run: |
+ git push ${DOWNSTREAM_REMOTE} "${{ steps.tag.outputs.TAG }}"
+
+ - name: Authenticate gh (non-interactive)
+ working-directory: ${{ env.DOWNSTREAM_REPO }}
+ run: |
+ # The GitHub-hosted runner has gh preinstalled.
+ gh auth status || echo "$GH_TOKEN" | gh auth login --with-token
+ gh repo set-default "${{ env.DOWNSTREAM_REPO }}"
+
+ - name: Create Pull Request with gh
+ working-directory: ${{ env.DOWNSTREAM_REPO }}
+ run: |
+ BASE="${DOWNSTREAM_BRANCH}"
+ HEAD="${{ steps.tag.outputs.TAG }}"
+ TITLE="[AUTOGENERATED] $HEAD"
+ BODY="rocm_base: ${{ steps.rocm_base.outputs.ROCM_BASE_COMMIT }}"
+
+ # If a PR for this head already exists, skip creating a new one
+ if gh pr list --head "$HEAD" --base "$BASE" --state all --json number | grep -q '[0-9]'; then
+ echo "PR already exists for $HEAD -> $BASE. Skipping creation."
+ else
+ gh pr create --base "$BASE" --head "$HEAD" --title "$TITLE" --body "$BODY"
+ fi
+
+ - name: Summarize
+ run: |
+ echo "::notice title=IFU Completed::Branch ${{ steps.tag.outputs.TAG }} pushed. PR created (or already existed). Merge status: ${{ steps.merge.outputs.merge_status }}"
diff --git a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp
index 65c89cb709790..2edb89891d889 100644
--- a/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp
+++ b/aten/src/ATen/native/cuda/linalg/BatchLinearAlgebraLib.cpp
@@ -323,11 +323,11 @@ static void svd_cusolver_gesvd(const Tensor& A, const Tensor& U, const Tensor& S
// gesvd just knows how to handle m >= n, so in the other case we need to transpose A
const auto not_A_H = A.size(-2) >= A.size(-1);
Tensor Vcopy = V; // Shallow copy
-#ifdef USE_ROCM
+#ifdef ROCM_VERSION
// Similar to the case in svd_magma(), experiments have shown Vh tensor is
// not guaranteed to be column major on ROCM, we have to create a copy to
// deal with this
- if (!not_A_H) {
+ if (compute_uv && !not_A_H) {
Vcopy = at::empty_like(V.mT(),
V.options()
.device(V.device())
@@ -342,8 +342,8 @@ static void svd_cusolver_gesvd(const Tensor& A, const Tensor& U, const Tensor& S
infos,
full_matrices, compute_uv, calculate_all_batches, batches);
});
-#ifdef USE_ROCM
- if (!not_A_H) {
+#ifdef ROCM_VERSION
+ if (compute_uv && !not_A_H) {
V.copy_(Vcopy);
}
#endif
@@ -517,8 +517,8 @@ static void svd_cusolver_gesvdjBatched(const Tensor& A, const Tensor& U, const T
template
static void apply_svd_cusolver_gesvdaStridedBatched(const Tensor& A, const Tensor& U, const Tensor& S, const Tensor& V,
const Tensor& infos, bool full_matrices, bool compute_uv) {
-#ifndef CUDART_VERSION
- TORCH_CHECK(false, "gesvda: Batched version is supported only with cuBLAS backend.")
+#if defined(CUDART_VERSION) || defined(USE_ROCM) && ROCM_VERSION < 60100
+ TORCH_CHECK(false, "gesvda: Batched version is supported only with cuBLAS backend or ROCM >= 5.7.0.")
#else
using value_t = typename c10::scalar_value_type::type;
int m = cuda_int_cast(A.size(-2), "m");
@@ -656,7 +656,7 @@ void svd_cusolver(const Tensor& A,
static constexpr const char* check_svd_doc = "Check doc at https://pytorch.org/docs/stable/generated/torch.linalg.svd.html";
// The default heuristic is to use gesvdj driver
-#ifdef USE_ROCM
+#if defined(ROCM_VERSION) && ROCM_VERSION < 60100
const auto driver_v = std::string_view("gesvdj");
#else
const auto driver_v = driver.value_or("gesvdj");
diff --git a/aten/src/ATen/native/cuda/linalg/CUDASolver.cpp b/aten/src/ATen/native/cuda/linalg/CUDASolver.cpp
index 3f7a13424294b..ae58a21f9a437 100644
--- a/aten/src/ATen/native/cuda/linalg/CUDASolver.cpp
+++ b/aten/src/ATen/native/cuda/linalg/CUDASolver.cpp
@@ -469,8 +469,8 @@ void gesvdjBatched>(
}
-// ROCM does not implement gesdva yet
-#ifdef CUDART_VERSION
+// ROCM does not implement gesdva correctly before 6.1
+#if defined(CUDART_VERSION) || defined(ROCM_VERSION) && ROCM_VERSION >= 60100
template<>
void gesvdaStridedBatched_buffersize(
cusolverDnHandle_t handle, cusolverEigMode_t jobz, int rank, int m, int n, float *A, int lda, long long int strideA,
diff --git a/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu b/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu
index 745c9eb9af6ab..de17ba15d90f5 100644
--- a/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu
+++ b/aten/src/ATen/native/sparse/cuda/SparseMatMul.cu
@@ -39,7 +39,27 @@
#include
+#if defined(__CUDACC__) && (defined(CUSPARSE_VERSION) || (defined(USE_ROCM) && ROCM_VERSION >= 60300))
+#define IS_CUSPARSE11_AVAILABLE() 1
+#else
+#define IS_CUSPARSE11_AVAILABLE() 0
+#endif
+
+#if defined(USE_ROCM) && (ROCM_VERSION >= 70000)
+#define HIPSPARSE_FP16_SUPPORT 1
+#else
+#define HIPSPARSE_FP16_SUPPORT 0
+#endif
+
+#if defined(USE_ROCM) && (ROCM_VERSION >= 70100)
+#define HIPSPARSE_FP16_BF16_SUPPORT 1
+#else
+#define HIPSPARSE_FP16_BF16_SUPPORT 0
+#endif
+
+#if IS_CUSPARSE11_AVAILABLE()
#include
+#endif
namespace at::native {
diff --git a/dockerfiles/Dockerfile b/dockerfiles/Dockerfile
new file mode 100644
index 0000000000000..361d0219eceef
--- /dev/null
+++ b/dockerfiles/Dockerfile
@@ -0,0 +1,159 @@
+# Dockerfile
+#
+# PyTorch + ROCm image built from TheRock portable wheels.
+# Does NOT include kernel drivers — the host must provide compatible
+# AMDGPU/ROCm kernel components and device access.
+#
+# Recommended docker run flags (mirrors TheRock CI container options):
+# docker run \
+# --shm-size=10g \
+# --cap-add=SYS_PTRACE \
+# --group-add video \
+# --device /dev/kfd \
+# --device /dev/dri \
+#
+#
+# Supported base images (examples)
+# - ubuntu:24.04
+# - almalinux:8
+# - mcr.microsoft.com/azurelinux/base/core:3.0
+#
+# Build arguments
+# - BASE_IMAGE : Base Docker image (default: ubuntu:24.04)
+# - ROCM_VERSION : Full ROCm version string. Supported formats:
+# - Nightly: 7.13.0a20260413
+# - Dev: 7.12.0.dev0+849eec43b2075459511b9a9ffe3bf1948490e9ee
+# - AMDGPU_FAMILY : AMD GPU family (e.g., gfx94X-dcgpu, gfx90X-dcgpu, gfx950-dcgpu)
+# - PYTHON_VERSION : Python version for PyTorch (default: 3.12)
+# - INDEX_URL : (Required) Base URL for PyTorch wheels index
+# - TORCH_VERSION : Optional specific PyTorch version. If not set, installs latest.
+# - TORCHAUDIO_VERSION : Optional specific torchaudio version. If not set, installs latest.
+# - TORCHVISION_VERSION: Optional specific torchvision version. If not set, installs latest.
+# - TRITON_VERSION : Optional specific triton version. If not set, uses torch's dependency.
+#
+# Note: The PyTorch source is included at /workspace/pytorch (from the repo root).
+#
+# Build example (run from repo root):
+#
+# docker build \
+# --build-arg BASE_IMAGE=ubuntu:24.04 \
+# --build-arg ROCM_VERSION=7.13.0a20260413 \
+# --build-arg AMDGPU_FAMILY=gfx94X-dcgpu \
+# --build-arg PYTHON_VERSION=3.12 \
+# --build-arg INDEX_URL=https://rocm.nightlies.amd.com/v2-staging \
+# -f dockerfiles/Dockerfile \
+# -t pytorch-rocm:ubuntu24.04-gfx94X-dcgpu-7.13.0a20260413 \
+# .
+#
+
+# Base image selection
+ARG BASE_IMAGE=ubuntu:24.04
+FROM ${BASE_IMAGE}
+
+# ROCm configuration arguments
+ARG ROCM_VERSION
+ARG AMDGPU_FAMILY
+ARG RELEASE_TYPE=nightly
+
+# PyTorch configuration arguments
+ARG PYTHON_VERSION=3.12
+ARG INDEX_URL
+ARG TORCH_VERSION
+ARG TORCH_VERSION_PREFIX
+ARG TORCHAUDIO_VERSION
+ARG TORCHVISION_VERSION
+ARG TRITON_VERSION
+
+# Copy installation scripts
+COPY .github/scripts/install_rocm_deps.sh /tmp/
+COPY .github/scripts/install_pytorch_wheels.py /tmp/
+
+# Copy PyTorch source from the repo root
+COPY . /workspace/pytorch
+
+# Install system dependencies
+RUN chmod +x /tmp/install_rocm_deps.sh && \
+ /tmp/install_rocm_deps.sh
+
+# Install the requested Python version if not already available.
+# Ubuntu 24.04 ships with 3.12; other versions come from deadsnakes PPA.
+RUN if ! command -v python${PYTHON_VERSION} >/dev/null 2>&1; then \
+ apt-get update && \
+ apt-get install -y --no-install-recommends software-properties-common && \
+ add-apt-repository -y ppa:deadsnakes/ppa && \
+ apt-get update && \
+ apt-get install -y --no-install-recommends \
+ python${PYTHON_VERSION} \
+ python${PYTHON_VERSION}-dev \
+ python${PYTHON_VERSION}-venv && \
+ rm -rf /var/lib/apt/lists/*; \
+ fi
+
+# Create Python virtual environment and upgrade pip/setuptools
+RUN python${PYTHON_VERSION} -m venv /opt/venv && \
+ /opt/venv/bin/python -m pip install --upgrade pip && \
+ /opt/venv/bin/python -m pip install --upgrade setuptools
+
+ENV PATH="/opt/venv/bin:${PATH}"
+
+# Install PyTorch wheels from the public nightlies index.
+RUN /opt/venv/bin/python /tmp/install_pytorch_wheels.py \
+ --no-break-system-packages \
+ --skip-verify \
+ --index-url "${INDEX_URL}" \
+ --amdgpu-family "${AMDGPU_FAMILY}" \
+ ${ROCM_VERSION:+--rocm-version "${ROCM_VERSION}"} \
+ ${TORCH_VERSION:+--torch-version "${TORCH_VERSION}"} \
+ ${TORCH_VERSION_PREFIX:+--torch-version-prefix "${TORCH_VERSION_PREFIX}"} \
+ ${TORCHAUDIO_VERSION:+--torchaudio-version "${TORCHAUDIO_VERSION}"} \
+ ${TORCHVISION_VERSION:+--torchvision-version "${TORCHVISION_VERSION}"} \
+ ${TRITON_VERSION:+--triton-version "${TRITON_VERSION}"}
+
+# Run rocm-sdk init to make rocm buildable
+RUN rocm-sdk init
+
+# ROCm environment variables (mirrors TheRock CI setup in
+# test_pytorch_wheels_full.yml "Initialize ROCm SDK and configure environment").
+# All paths derive from ROCM_HOME which is the rocm-sdk install location.
+ENV ROCM_HOME="/opt/venv/lib/python${PYTHON_VERSION}/site-packages/_rocm_sdk_devel" \
+ ROCM_PATH="/opt/venv/lib/python${PYTHON_VERSION}/site-packages/_rocm_sdk_devel" \
+ ROCM_SOURCE_DIR="/opt/venv/lib/python${PYTHON_VERSION}/site-packages/_rocm_sdk_devel" \
+ ROCM_BIN="/opt/venv/lib/python${PYTHON_VERSION}/site-packages/_rocm_sdk_devel/bin" \
+ ROCM_CMAKE="/opt/venv/lib/python${PYTHON_VERSION}/site-packages/_rocm_sdk_devel/lib/cmake" \
+ PYTORCH_ROCM_ARCH="${AMDGPU_FAMILY}" \
+ VIRTUAL_ENV=/opt/venv \
+ USE_MSLK=0
+
+ENV CMAKE_PREFIX_PATH="${ROCM_CMAKE}" \
+ HIP_DEVICE_LIB_PATH="${ROCM_HOME}/lib/llvm/amdgcn/bitcode" \
+ ROCM_DEVICE_LIB_PATH="${ROCM_HOME}/lib/llvm/amdgcn/bitcode" \
+ ROCM_SYSDEPS_INCLUDE="${ROCM_HOME}/lib/rocm_sysdeps/include" \
+ CPLUS_INCLUDE_PATH="${ROCM_HOME}/lib/rocm_sysdeps/include" \
+ C_INCLUDE_PATH="${ROCM_HOME}/lib/rocm_sysdeps/include" \
+ PKG_CONFIG_PATH="${ROCM_HOME}/lib/rocm_sysdeps/lib/pkgconfig" \
+ LD_LIBRARY_PATH="${ROCM_HOME}/lib/host-math/lib:${ROCM_HOME}/lib/rocm_sysdeps/lib" \
+ LIBRARY_PATH="${ROCM_HOME}/lib/host-math/lib:${ROCM_HOME}/lib/rocm_sysdeps/lib" \
+ CC="${ROCM_HOME}/lib/llvm/bin/clang" \
+ CXX="${ROCM_HOME}/lib/llvm/bin/clang++" \
+ PATH="${ROCM_BIN}:${PATH}"
+
+# Verify PyTorch imports and environment
+RUN python <<'PYEOF'
+import os, torch
+print('torch', torch.__version__)
+print('ROCm/HIP', torch.version.hip)
+print(f'ROCM_HOME={os.environ.get("ROCM_HOME", "NOT SET")}')
+print(f'CC={os.environ.get("CC", "NOT SET")}')
+print(f'CXX={os.environ.get("CXX", "NOT SET")}')
+for mod in ['torchaudio', 'torchvision', 'triton']:
+ try:
+ m = __import__(mod)
+ print(f'{mod} {m.__version__}')
+ except Exception as e:
+ print(f'{mod}: skipped ({e})')
+PYEOF
+
+# Clean up installation scripts
+RUN rm -f /tmp/install_rocm_deps.sh /tmp/install_pytorch_wheels.py
+
+WORKDIR /workspace/pytorch
diff --git a/related_commits b/related_commits
new file mode 100644
index 0000000000000..ee36e55601d0f
--- /dev/null
+++ b/related_commits
@@ -0,0 +1,10 @@
+ubuntu|pytorch|apex|master|2190fbaeb88384ed792373adbb83c182af117ca0|https://github.com/ROCm/apex
+centos|pytorch|apex|master|2190fbaeb88384ed792373adbb83c182af117ca0|https://github.com/ROCm/apex
+ubuntu|pytorch|torchvision|main|218d2ab791d437309f91e0486eb9fa7f00badc17|https://github.com/pytorch/vision
+centos|pytorch|torchvision|main|218d2ab791d437309f91e0486eb9fa7f00badc17|https://github.com/pytorch/vision
+ubuntu|pytorch|torchdata|main|92950795e0790eb74df995daf40b658e85fd2c9f|https://github.com/pytorch/data
+centos|pytorch|torchdata|main|92950795e0790eb74df995daf40b658e85fd2c9f|https://github.com/pytorch/data
+ubuntu|pytorch|torchaudio|main|3b0e7a6f192ca2715e7e6cbe5db007aea7165fe2|https://github.com/pytorch/audio
+centos|pytorch|torchaudio|main|3b0e7a6f192ca2715e7e6cbe5db007aea7165fe2|https://github.com/pytorch/audio
+ubuntu|pytorch|ao|main|3577306c8b32517afe8eb6eb7e84335601180598|https://github.com/pytorch/ao
+centos|pytorch|ao|main|3577306c8b32517afe8eb6eb7e84335601180598|https://github.com/pytorch/ao
diff --git a/test/test_sparse.py b/test/test_sparse.py
index 87a9bccd22dfc..5e21192e19147 100644
--- a/test/test_sparse.py
+++ b/test/test_sparse.py
@@ -130,6 +130,12 @@ def wrapped(fn, inputs, *args, **kwargs):
) or (not IS_WINDOWS and not TEST_WITH_ROCM)
HIPSPARSE_SPMM_COMPLEX128_SUPPORTED = torch.version.hip and version.parse(torch.version.hip.split("-")[0]) >= version.parse("6.0")
+HIPSPARSE_FP16_SUPPORTED = torch.version.hip and version.parse(torch.version.hip.split("-")[0]) >= version.parse("7.0")
+HIPSPARSE_BF16_SUPPORTED = torch.version.hip and version.parse(torch.version.hip.split("-")[0]) >= version.parse("7.1")
+
+SPARSE_COMPLEX128_SUPPORTED = CUSPARSE_SPMM_COMPLEX128_SUPPORTED or HIPSPARSE_SPMM_COMPLEX128_SUPPORTED
+SPARSE_FLOAT16_SUPPORTED = (SM53OrLater and torch.version.cuda) or (HIPSPARSE_FP16_SUPPORTED)
+SPARSE_BFLOAT16_SUPPORTED = (SM80OrLater and torch.version.cuda) or (HIPSPARSE_BF16_SUPPORTED)
def all_sparse_layouts(test_name='layout', include_strided=False):
return parametrize(test_name, [
diff --git a/test/test_sparse_csr.py b/test/test_sparse_csr.py
index fa865499cc8ea..7fbc4d9f451b8 100644
--- a/test/test_sparse_csr.py
+++ b/test/test_sparse_csr.py
@@ -26,7 +26,8 @@
all_types_and_complex, floating_and_complex_types_and)
from torch.testing._internal.opinfo.definitions.linalg import sample_inputs_linalg_solve
from torch.testing._internal.opinfo.definitions.sparse import validate_sample_input_sparse
-from test_sparse import CUSPARSE_SPMM_COMPLEX128_SUPPORTED, HIPSPARSE_SPMM_COMPLEX128_SUPPORTED
+from test_sparse import HIPSPARSE_BF16_SUPPORTED, HIPSPARSE_FP16_SUPPORTED, \
+ SPARSE_FLOAT16_SUPPORTED, SPARSE_BFLOAT16_SUPPORTED, SPARSE_COMPLEX128_SUPPORTED
import operator
if TEST_SCIPY:
diff --git a/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h b/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h
index 16ccc5002f9ab..71695657c1ebb 100644
--- a/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h
+++ b/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h
@@ -274,6 +274,11 @@ constexpr auto bfloat16_support_literal =
#define __align__(x) __attribute__((aligned(x)))
#endif
)" BF16_UINT32_DEF R"(
+#if defined(_HIP_INCLUDE_HIP_AMD_DETAIL_HIP_BF16_H_)
+typedef __hip_bfloat16 __nv_bfloat16;
+typedef struct __align__(2) { unsigned short x; } __nv_bfloat16_raw;
+#else
+
typedef struct __align__(2) {
unsigned short x;
}
@@ -333,6 +338,7 @@ __device__ float __bfloat162float(const __nv_bfloat16 a) {
return u.fp32;
}
#endif /* defined(__cplusplus) */
+#endif /* !defined(_HIP_INCLUDE_HIP_AMD_DETAIL_HIP_BF16_H_) */
)";
#else
constexpr auto bfloat16_support_literal =
diff --git a/torch/testing/_internal/common_cuda.py b/torch/testing/_internal/common_cuda.py
index 0dc3d0b4f10cb..8575c0a75f77d 100644
--- a/torch/testing/_internal/common_cuda.py
+++ b/torch/testing/_internal/common_cuda.py
@@ -228,6 +228,9 @@ def tf32_off():
@contextlib.contextmanager
def tf32_on(self, tf32_precision=1e-5):
+ if torch.version.hip:
+ hip_allow_tf32 = os.environ.get("HIPBLASLT_ALLOW_TF32", None)
+ os.environ["HIPBLASLT_ALLOW_TF32"] = "1"
old_allow_tf32_matmul = torch.backends.cuda.matmul.allow_tf32
old_precision = self.precision
try:
@@ -236,6 +239,11 @@ def tf32_on(self, tf32_precision=1e-5):
with torch.backends.cudnn.flags(enabled=None, benchmark=None, deterministic=None, allow_tf32=True):
yield
finally:
+ if torch.version.hip:
+ if hip_allow_tf32 is not None:
+ os.environ["HIPBLASLT_ALLOW_TF32"] = hip_allow_tf32
+ else:
+ del os.environ["HIPBLASLT_ALLOW_TF32"]
torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul
self.precision = old_precision