diff --git a/.github/scripts/benchmarks/gather_metadata.py b/.github/scripts/benchmarks/gather_metadata.py index 50011fb3f7..e38c8b5bdf 100755 --- a/.github/scripts/benchmarks/gather_metadata.py +++ b/.github/scripts/benchmarks/gather_metadata.py @@ -5,8 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import os import json +import os import time from typing import Any diff --git a/.github/scripts/get_tutorials_stats.py b/.github/scripts/get_tutorials_stats.py index f1f4c8ec18..a65c9fe002 100644 --- a/.github/scripts/get_tutorials_stats.py +++ b/.github/scripts/get_tutorials_stats.py @@ -10,6 +10,7 @@ import boto3 # type: ignore[import] + METADATA_PATH = "ossci_tutorials_stats/metadata.csv" FILENAMES_PATH = "ossci_tutorials_stats/filenames.csv" diff --git a/.github/scripts/update_commit_hashes.py b/.github/scripts/update_commit_hashes.py index a5407f42b2..59047aa2d9 100644 --- a/.github/scripts/update_commit_hashes.py +++ b/.github/scripts/update_commit_hashes.py @@ -6,6 +6,7 @@ import requests + UPDATEBOT_TOKEN = os.environ["UPDATEBOT_TOKEN"] PYTORCHBOT_TOKEN = os.environ["PYTORCHBOT_TOKEN"] diff --git a/.github/scripts/upload_benchmark_results.py b/.github/scripts/upload_benchmark_results.py index 338b20e5f2..f67c4962fd 100755 --- a/.github/scripts/upload_benchmark_results.py +++ b/.github/scripts/upload_benchmark_results.py @@ -14,13 +14,13 @@ from argparse import Action, ArgumentParser, Namespace from decimal import Decimal from json.decoder import JSONDecodeError - from logging import info from typing import Any, Callable, Dict, List, Optional from warnings import warn import boto3 + logging.basicConfig(level=logging.INFO) diff --git a/.github/scripts/validate_scale_config.py b/.github/scripts/validate_scale_config.py index 00c270e4f4..cff250e51c 100644 --- a/.github/scripts/validate_scale_config.py +++ b/.github/scripts/validate_scale_config.py @@ -9,16 +9,14 @@ import copy import json import os - import urllib.request from pathlib import Path - from typing import Any, cast, Dict, List, NamedTuple import jsonschema - import yaml + MAX_AVAILABLE_MINIMUM = 50 # Paths relative to their respective repositories diff --git a/tools/analytics/cubinsizes.py b/tools/analytics/cubinsizes.py index 33875057a7..8a07a627b2 100755 --- a/tools/analytics/cubinsizes.py +++ b/tools/analytics/cubinsizes.py @@ -12,27 +12,30 @@ try: from elftools.elf.elffile import ELFFile except ModuleNotFoundError: - print(f'elftools module not found, trying to install it from pip') + print(f"elftools module not found, trying to install it from pip") from pip._internal import main as pip_main + try: pip_main(["install", "pyelftools", "--user"]) except SystemExit: - print(f'PIP installation failed, please install it manually by invoking "{sys.executable} -mpip install pyelftools --user"') + print( + f'PIP installation failed, please install it manually by invoking "{sys.executable} -mpip install pyelftools --user"' + ) sys.exit(-1) from elftools.elf.elffile import ELFFile # From https://stackoverflow.com/questions/1094841/reusable-library-to-get-human-readable-version-of-file-size -def sizeof_fmt(num, suffix='B'): - for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']: +def sizeof_fmt(num, suffix="B"): + for unit in ["", "Ki", "Mi", "Gi", "Ti", "Pi", "Ei", "Zi"]: if abs(num) < 1024.0: return "%3.1f%s%s" % (num, unit, suffix) num /= 1024.0 - return "%.1f%s%s" % (num, 'Yi', suffix) + return "%.1f%s%s" % (num, "Yi", suffix) -def compute_cubin_sizes(file_name, section_name='.nv_fatbin', debug=False): - with open(file_name, 'rb') as f: +def compute_cubin_sizes(file_name, section_name=".nv_fatbin", debug=False): + with open(file_name, "rb") as f: elf_file = ELFFile(f) nv_fatbin = elf_file.get_section_by_name(section_name) if nv_fatbin is None: @@ -41,20 +44,32 @@ def compute_cubin_sizes(file_name, section_name='.nv_fatbin', debug=False): idx, offs = 0, 0 elf_sizes = {} while offs < len(data): - (magic, version, header_size, fatbin_size) = struct.unpack('IHHL', data[offs: offs + 16]) - if magic != 0xba55ed50 or version != 1: - raise RuntimeError(f"Unexpected fatbin magic {hex(magic)} or version {version}") + (magic, version, header_size, fatbin_size) = struct.unpack( + "IHHL", data[offs : offs + 16] + ) + if magic != 0xBA55ED50 or version != 1: + raise RuntimeError( + f"Unexpected fatbin magic {hex(magic)} or version {version}" + ) if debug: - print(f"Found fatbin at {offs} header_size={header_size} fatbin_size={fatbin_size}") + print( + f"Found fatbin at {offs} header_size={header_size} fatbin_size={fatbin_size}" + ) offs += header_size fatbin_end = offs + fatbin_size while offs < fatbin_end: - (kind, version, hdr_size, elf_size, empty, code_ver, sm_ver) = struct.unpack('HHILLIH', data[offs: offs + 30]) + (kind, version, hdr_size, elf_size, empty, code_ver, sm_ver) = ( + struct.unpack("HHILLIH", data[offs : offs + 30]) + ) if version != 0x0101 or kind not in [1, 2]: - raise RuntimeError(f"Unexpected cubin version {hex(version)} or kind {kind}") + raise RuntimeError( + f"Unexpected cubin version {hex(version)} or kind {kind}" + ) sm_ver = f'{"ptx" if kind == 1 else "sm"}_{sm_ver}' if debug: - print(f" {idx}: elf_size={elf_size} code_ver={hex(code_ver)} sm={sm_ver}") + print( + f" {idx}: elf_size={elf_size} code_ver={hex(code_ver)} sm={sm_ver}" + ) if sm_ver not in elf_sizes: elf_sizes[sm_ver] = 0 elf_sizes[sm_ver] += elf_size @@ -71,7 +86,7 @@ def __init__(self, ar_name: str) -> None: def __enter__(self) -> str: self._pwd = os.getcwd() rc = self._tmpdir.__enter__() - subprocess.check_call(['ar', 'x', self.ar_name]) + subprocess.check_call(["ar", "x", self.ar_name]) return rc def __exit__(self, ex, value, tb) -> None: @@ -86,13 +101,16 @@ def dict_add(rc: Dict[str, int], b: Dict[str, int]) -> Dict[str, int]: def main(): - if sys.platform != 'linux': - print('This script only works with Linux ELF files') + if sys.platform != "linux": + print("This script only works with Linux ELF files") return if len(sys.argv) < 2: - print(f"{sys.argv[0]} invoked without any arguments trying to infer location of libtorch_cuda") + print( + f"{sys.argv[0]} invoked without any arguments trying to infer location of libtorch_cuda" + ) import torch - fname = os.path.join(os.path.dirname(torch.__file__), 'lib', 'libtorch_cuda.so') + + fname = os.path.join(os.path.dirname(torch.__file__), "lib", "libtorch_cuda.so") else: fname = sys.argv[1] @@ -100,26 +118,27 @@ def main(): print(f"Can't find {fname}") sys.exit(-1) - section_names = ['.nv_fatbin', '__nv_relfatbin'] + section_names = [".nv_fatbin", "__nv_relfatbin"] results = {name: {} for name in section_names} print(f"Analyzing {fname}") - if os.path.splitext(fname)[1] == '.a': + if os.path.splitext(fname)[1] == ".a": with ArFileCtx(fname): for fname in os.listdir("."): - if not fname.endswith(".o"): continue + if not fname.endswith(".o"): + continue for section_name in section_names: elf_sizes = compute_cubin_sizes(fname, section_name) dict_add(results[section_name], elf_sizes) else: - for section_name in ['.nv_fatbin', '__nv_relfatbin']: + for section_name in [".nv_fatbin", "__nv_relfatbin"]: dict_add(results[section_name], compute_cubin_sizes(fname, section_name)) for section_name in section_names: elf_sizes = results[section_name] print(f"{section_name} size {sizeof_fmt(sum(elf_sizes.values()))}") - for (sm_ver, total_size) in elf_sizes.items(): + for sm_ver, total_size in elf_sizes.items(): print(f" {sm_ver}: {sizeof_fmt(total_size)}") -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/tools/analytics/download_count_wheels.py b/tools/analytics/download_count_wheels.py index 277edaa164..39a2b83daa 100644 --- a/tools/analytics/download_count_wheels.py +++ b/tools/analytics/download_count_wheels.py @@ -1,16 +1,18 @@ -from collections import defaultdict -from datetime import datetime, timedelta, timezone import gzip import os import re import urllib +from collections import defaultdict +from datetime import datetime, timedelta, timezone -from tqdm import tqdm import boto3 +from tqdm import tqdm + + +S3 = boto3.resource("s3") +CLIENT = boto3.client("s3") +BUCKET = S3.Bucket("pytorch") -S3 = boto3.resource('s3') -CLIENT = boto3.client('s3') -BUCKET = S3.Bucket('pytorch') class CacheEntry: _size = None @@ -38,20 +40,15 @@ def target_arch(self) -> str: @property def package_name(self) -> str: - filename_contents = os.path.basename(self.download_uri).split('-') + filename_contents = os.path.basename(self.download_uri).split("-") return filename_contents[0] @property def package_version(self) -> str: if "dev" in self.download_uri: - results = re.search( - r"[0-9]+\.[0-9]+\.[0-9]+\.dev[0-9]+", - self.download_uri - ) + results = re.search(r"[0-9]+\.[0-9]+\.[0-9]+\.dev[0-9]+", self.download_uri) else: - results = re.search( - r"[0-9]+\.[0-9]+\.[0-9]+", self.download_uri - ) + results = re.search(r"[0-9]+\.[0-9]+\.[0-9]+", self.download_uri) if not results: raise Exception("Wtf there's no version o.O") return results[0] @@ -59,45 +56,40 @@ def package_version(self) -> str: @property def size(self) -> int: if self._size is None: - for key in BUCKET.objects.filter( - Prefix=self.download_uri.lstrip("/") - ): + for key in BUCKET.objects.filter(Prefix=self.download_uri.lstrip("/")): self._size = key.size if self._size is None: - raise Exception( - f"No object found for prefix {self.download_uri}" - ) + raise Exception(f"No object found for prefix {self.download_uri}") return self._size @property def downloads(self): return self.bytes_sent // self.size + def parse_logs(log_directory: str) -> dict: bytes_cache = {} - for (dirpath, _, filenames) in os.walk(log_directory): + for dirpath, _, filenames in os.walk(log_directory): for filename in tqdm(filenames): - with gzip.open(os.path.join(dirpath, filename), 'r') as gf: + with gzip.open(os.path.join(dirpath, filename), "r") as gf: string = gf.read().decode("utf-8") entries = [] entries += string.splitlines()[2:] for entry in entries: - columns = entry.split('\t') + columns = entry.split("\t") bytes_sent = int(columns[3]) - download_uri = urllib.parse.unquote( - urllib.parse.unquote(columns[7]) - ) + download_uri = urllib.parse.unquote(urllib.parse.unquote(columns[7])) status = columns[8] - if not all([ - status.startswith("2"), - download_uri.endswith((".whl", ".zip")) - ]): + if not all( + [status.startswith("2"), download_uri.endswith((".whl", ".zip"))] + ): continue if not bytes_cache.get(download_uri): bytes_cache[download_uri] = CacheEntry(download_uri) bytes_cache[download_uri].bytes_sent += bytes_sent return bytes_cache + def output_results(bytes_cache: dict) -> None: os_results = defaultdict(int) arch_results = defaultdict(int) @@ -106,25 +98,19 @@ def output_results(bytes_cache: dict) -> None: try: os_results[val.os_type] += val.downloads arch_results[val.target_arch] += val.downloads - package_results[val.package_name][val.package_version] += ( - val.downloads - ) + package_results[val.package_name][val.package_version] += val.downloads except Exception: pass print("=-=-= Results =-=-=") print("=-=-= OS =-=-=") total_os_num = sum(os_results.values()) for os_type, num in os_results.items(): - print( - f"\t* {os_type}: {num} ({(num/total_os_num)*100:.2f}%)" - ) + print(f"\t* {os_type}: {num} ({(num/total_os_num)*100:.2f}%)") print("=-=-= ARCH =-=-=") total_arch_num = sum(arch_results.values()) for arch_type, num in arch_results.items(): - print( - f"\t* {arch_type}: {num} ({(num/total_arch_num) * 100:.2f}%)" - ) + print(f"\t* {arch_type}: {num} ({(num/total_arch_num) * 100:.2f}%)") print("=-=-= By Package =-=-=") for package_name, upper_val in package_results.items(): @@ -135,11 +121,14 @@ def output_results(bytes_cache: dict) -> None: f"\t* {package_version}: {num} ({(num/total_package_num) * 100:.2f}%)" ) + def download_logs(log_directory: str, since: float): dt_now = datetime.now(timezone.utc) dt_end = datetime(dt_now.year, dt_now.month, dt_now.day, tzinfo=timezone.utc) - dt_start = dt_end - timedelta(days=1, hours=1) # Add 1 hour padding to account for potentially missed logs due to timing - for key in tqdm(BUCKET.objects.filter(Prefix='cflogs')): + dt_start = dt_end - timedelta( + days=1, hours=1 + ) # Add 1 hour padding to account for potentially missed logs due to timing + for key in tqdm(BUCKET.objects.filter(Prefix="cflogs")): remote_fname = key.key local_fname = os.path.join(log_directory, remote_fname) # Only download things from yesterday @@ -156,8 +145,8 @@ def download_logs(log_directory: str, since: float): if __name__ == "__main__": print("Downloading logs") - download_logs('cache', 1) + download_logs("cache", 1) print("Parsing logs") - cache = parse_logs('cache/cflogs/') + cache = parse_logs("cache/cflogs/") print("Calculating results") output_results(cache) diff --git a/tools/analytics/duplicates_analyze.py b/tools/analytics/duplicates_analyze.py index 8fdc3af22b..64dcaaf49a 100755 --- a/tools/analytics/duplicates_analyze.py +++ b/tools/analytics/duplicates_analyze.py @@ -1,41 +1,55 @@ #!/usr/bin/env python3 -from typing import Dict, List -from subprocess import check_output import os import sys +from subprocess import check_output +from typing import Dict, List def get_defined_symbols(fname: str, verbose: bool = False) -> Dict[str, int]: if verbose: - print(f"Processing {fname}...", end='', flush=True) - if sys.platform == 'darwin': - lines = check_output(['nm', '--defined-only', '-n', fname]).decode('ascii').split("\n")[:-1] + print(f"Processing {fname}...", end="", flush=True) + if sys.platform == "darwin": + lines = ( + check_output(["nm", "--defined-only", "-n", fname]) + .decode("ascii") + .split("\n")[:-1] + ) rc = {} for idx, line in enumerate(lines): - addr, stype, name = line.split(' ') - size = 4 if idx + 1 == len(lines) else (int(lines[idx + 1].split(' ')[0], 16) - int(addr, 16)) + addr, stype, name = line.split(" ") + size = ( + 4 + if idx + 1 == len(lines) + else (int(lines[idx + 1].split(" ")[0], 16) - int(addr, 16)) + ) rc[name] = size else: - lines = check_output(['nm', '--print-size', '--defined-only', fname]).decode('ascii').split('\n') - rc = {e[3]: int(e[1], 16) for e in [line.split() for line in lines] if len(e) == 4} + lines = ( + check_output(["nm", "--print-size", "--defined-only", fname]) + .decode("ascii") + .split("\n") + ) + rc = { + e[3]: int(e[1], 16) for e in [line.split() for line in lines] if len(e) == 4 + } if verbose: print("done") return rc def get_deps(fname: str) -> List[str]: - if sys.platform == 'darwin': + if sys.platform == "darwin": rc = [] - lines = check_output(['otool', '-l', fname]).decode('ascii').split("\n")[1:-1] + lines = check_output(["otool", "-l", fname]).decode("ascii").split("\n")[1:-1] for idx, line in enumerate(lines): - if line.strip() != 'cmd LC_LOAD_DYLIB': + if line.strip() != "cmd LC_LOAD_DYLIB": continue path = lines[idx + 2].strip() - assert path.startswith('name') - rc.append(os.path.basename(path.split(' ')[1])) + assert path.startswith("name") + rc.append(os.path.basename(path.split(" ")[1])) return rc - lines = check_output(['readelf', '--dynamic', fname]).decode('ascii').split('\n') - return [line.split('[')[1][:-1] for line in lines if '(NEEDED)' in line] + lines = check_output(["readelf", "--dynamic", fname]).decode("ascii").split("\n") + return [line.split("[")[1][:-1] for line in lines if "(NEEDED)" in line] def humansize(size): @@ -85,14 +99,18 @@ def print_symbols_overlap(libname1: str, libname2: str) -> None: sym_overlap = set(sym1.keys()).intersection(set(sym2.keys())) overlap_size = sum(sym1[s] for s in sym_overlap) if overlap_size == 0: - print(f"{libname1} symbols size {humansize(sym1_size)} does not overlap with {libname2}") + print( + f"{libname1} symbols size {humansize(sym1_size)} does not overlap with {libname2}" + ) return - print(f"{libname1} symbols size {humansize(sym1_size)} overlap {humansize(overlap_size)} ({100.0 * overlap_size/sym1_size :.2f}%)") + print( + f"{libname1} symbols size {humansize(sym1_size)} overlap {humansize(overlap_size)} ({100.0 * overlap_size/sym1_size :.2f}%)" + ) for sym in sym_overlap: print(sym) -if __name__ == '__main__': +if __name__ == "__main__": if len(sys.argv) == 3: print_symbols_overlap(sys.argv[1], sys.argv[2]) else: diff --git a/tools/analytics/github_analyze.py b/tools/analytics/github_analyze.py index b6a37aaf3a..fe1ec14582 100755 --- a/tools/analytics/github_analyze.py +++ b/tools/analytics/github_analyze.py @@ -1,12 +1,13 @@ #!/usr/bin/env python3 -from datetime import datetime, timedelta -from typing import Any, Dict, List, Iterable, Optional, Union -from urllib.request import urlopen, Request -from urllib.error import HTTPError -import json import enum +import json import os +from datetime import datetime, timedelta +from typing import Any, Dict, Iterable, List, Optional, Union +from urllib.error import HTTPError +from urllib.request import Request, urlopen + class IssueState(enum.Enum): OPEN = "open" @@ -26,14 +27,16 @@ class GitCommit: commit_date: Optional[datetime] pr_url: str - def __init__(self, - commit_hash: str, - author: str, - author_date: datetime, - title: str, - body: str, - pr_url : str, - commit_date: Optional[datetime] = None) -> None: + def __init__( + self, + commit_hash: str, + author: str, + author_date: datetime, + title: str, + body: str, + pr_url: str, + commit_date: Optional[datetime] = None, + ) -> None: self.commit_hash = commit_hash self.author = author self.author_date = author_date @@ -53,14 +56,18 @@ def is_issue_mentioned(self, issue_url: str) -> bool: issue_hash = f"#{issue_url.split('issues/')[1]}" if "fixes" in self.title.lower() and issue_hash in self.title: return True - return any("fixes" in line.lower() and issue_hash in line for line in self.body.split("\n")) + return any( + "fixes" in line.lower() and issue_hash in line + for line in self.body.split("\n") + ) def get_revert_revision(commit: GitCommit) -> Optional[str]: import re + body_rc = re.search("Original Phabricator Diff: (D\\d+)", commit.body) - if commit.title.startswith("Back out \"") and body_rc is not None: + if commit.title.startswith('Back out "') and body_rc is not None: return body_rc.group(1) rc = re.match("Revert (D\\d+):", commit.title) @@ -71,6 +78,7 @@ def get_revert_revision(commit: GitCommit) -> Optional[str]: def get_diff_revision(commit: GitCommit) -> Optional[str]: import re + rc = re.search("\\s*Differential Revision: (D\\d+)", commit.body) if rc is None: return None @@ -79,18 +87,25 @@ def get_diff_revision(commit: GitCommit) -> Optional[str]: def get_ghf_revert_revision(commit: GitCommit) -> Optional[str]: import re + rc = re.search("\\s*This reverts commit ([0-9a-f]+).", commit.body) - if all([ - commit.title.startswith("Revert"), - commit.author == "PyTorch MergeBot ", - rc is not None - ]): + if all( + [ + commit.title.startswith("Revert"), + commit.author + == "PyTorch MergeBot ", + rc is not None, + ] + ): return rc.group(1) return None def is_revert(commit: GitCommit) -> bool: - return get_revert_revision(commit) is not None or get_ghf_revert_revision(commit) is not None + return ( + get_revert_revision(commit) is not None + or get_ghf_revert_revision(commit) is not None + ) def parse_medium_format(lines: Union[str, List[str]]) -> GitCommit: @@ -115,12 +130,13 @@ def parse_medium_format(lines: Union[str, List[str]]) -> GitCommit: assert lines[1].startswith("Author: ") assert lines[2].startswith("Date: ") assert len(lines[3]) == 0 - return GitCommit(commit_hash=lines[0].split()[1].strip(), - author=lines[1].split(":", 1)[1].strip(), - author_date=datetime.fromtimestamp(int(lines[2].split(":", 1)[1].strip())), - title=lines[4].strip(), - body="\n".join(lines[5:]), - ) + return GitCommit( + commit_hash=lines[0].split()[1].strip(), + author=lines[1].split(":", 1)[1].strip(), + author_date=datetime.fromtimestamp(int(lines[2].split(":", 1)[1].strip())), + title=lines[4].strip(), + body="\n".join(lines[5:]), + ) def parse_fuller_format(lines: Union[str, List[str]]) -> GitCommit: @@ -156,37 +172,43 @@ def parse_fuller_format(lines: Union[str, List[str]]) -> GitCommit: prUrl = line.split("Pull Request resolved:")[1].strip() break - return GitCommit(commit_hash=lines[0].split()[1].strip(), - author=lines[1].split(":", 1)[1].strip(), - author_date=datetime.fromtimestamp(int(lines[2].split(":", 1)[1].strip())), - commit_date=datetime.fromtimestamp(int(lines[4].split(":", 1)[1].strip())), - title=lines[6].strip(), - body="\n".join(lines[7:]), - pr_url=prUrl, - ) + return GitCommit( + commit_hash=lines[0].split()[1].strip(), + author=lines[1].split(":", 1)[1].strip(), + author_date=datetime.fromtimestamp(int(lines[2].split(":", 1)[1].strip())), + commit_date=datetime.fromtimestamp(int(lines[4].split(":", 1)[1].strip())), + title=lines[6].strip(), + body="\n".join(lines[7:]), + pr_url=prUrl, + ) -def _check_output(items: List[str], encoding='utf-8') -> str: +def _check_output(items: List[str], encoding="utf-8") -> str: from subprocess import check_output + return check_output(items).decode(encoding) def get_git_remotes(path: str) -> Dict[str, str]: keys = _check_output(["git", "-C", path, "remote"]).strip().split("\n") - return {key: _check_output(["git", "-C", path, "remote", "get-url", key]).strip() for key in keys} + return { + key: _check_output(["git", "-C", path, "remote", "get-url", key]).strip() + for key in keys + } class GitRepo: - def __init__(self, path, remote='upstream'): + def __init__(self, path, remote="upstream"): self.repo_dir = path self.remote = remote def _run_git_cmd(self, *args) -> str: - return _check_output(['git', '-C', self.repo_dir] + list(args)) + return _check_output(["git", "-C", self.repo_dir] + list(args)) def _run_git_log(self, revision_range) -> List[GitCommit]: - log = self._run_git_cmd('log', '--format=fuller', - '--date=unix', revision_range, '--', '.').split("\n") + log = self._run_git_cmd( + "log", "--format=fuller", "--date=unix", revision_range, "--", "." + ).split("\n") rc: List[GitCommit] = [] cur_msg: List[str] = [] for line in log: @@ -203,7 +225,14 @@ def get_commit_list(self, from_ref, to_ref) -> List[GitCommit]: return self._run_git_log(f"{self.remote}/{from_ref}..{self.remote}/{to_ref}") def get_ghstack_orig_branches(self) -> List[str]: - return [x.strip() for x in self._run_git_cmd("branch", "--remotes", "--list", self.remote + "/gh/*/orig").strip().split("\n")] + return [ + x.strip() + for x in self._run_git_cmd( + "branch", "--remotes", "--list", self.remote + "/gh/*/orig" + ) + .strip() + .split("\n") + ] def show_ref(self, ref) -> str: return self._run_git_cmd("show-ref", ref).split(" ")[0] @@ -212,7 +241,9 @@ def merge_base(self, ref1, ref2) -> str: return self._run_git_cmd("merge-base", ref1, ref2).strip() def rev_list(self, ref): - return self._run_git_cmd("rev-list", f"{self.remote}/main..{ref}").strip().split() + return ( + self._run_git_cmd("rev-list", f"{self.remote}/main..{ref}").strip().split() + ) def build_commit_dict(commits: List[GitCommit]) -> Dict[str, GitCommit]: @@ -223,22 +254,31 @@ def build_commit_dict(commits: List[GitCommit]) -> Dict[str, GitCommit]: return rc -def fetch_json(url: str, params: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]: - headers = {'Accept': 'application/vnd.github.v3+json'} +def fetch_json( + url: str, params: Optional[Dict[str, Any]] = None +) -> List[Dict[str, Any]]: + headers = {"Accept": "application/vnd.github.v3+json"} token = os.environ.get("GITHUB_TOKEN") - if token is not None and url.startswith('https://api.github.com/'): - headers['Authorization'] = f'token {token}' + if token is not None and url.startswith("https://api.github.com/"): + headers["Authorization"] = f"token {token}" if params is not None and len(params) > 0: - url += '?' + '&'.join(f"{name}={val}" for name, val in params.items()) + url += "?" + "&".join(f"{name}={val}" for name, val in params.items()) try: with urlopen(Request(url, headers=headers)) as data: return json.load(data) except HTTPError as err: - if err.code == 403 and all(key in err.headers for key in ['X-RateLimit-Limit', 'X-RateLimit-Used']): - print(f"Rate limit exceeded: {err.headers['X-RateLimit-Used']}/{err.headers['X-RateLimit-Limit']}") + if err.code == 403 and all( + key in err.headers for key in ["X-RateLimit-Limit", "X-RateLimit-Used"] + ): + print( + f"Rate limit exceeded: {err.headers['X-RateLimit-Used']}/{err.headers['X-RateLimit-Limit']}" + ) raise -def fetch_multipage_json(url: str, params: Optional[Dict[str, Any]] = None) -> List[Dict[str, Any]]: + +def fetch_multipage_json( + url: str, params: Optional[Dict[str, Any]] = None +) -> List[Dict[str, Any]]: if params is None: params = {} assert "page" not in params @@ -251,17 +291,22 @@ def fetch_multipage_json(url: str, params: Optional[Dict[str, Any]] = None) -> L return rc -def gh_get_milestones(org='pytorch', project='pytorch', state: IssueState = IssueState.OPEN) -> List[Dict[str, Any]]: - url = f'https://api.github.com/repos/{org}/{project}/milestones' +def gh_get_milestones( + org="pytorch", project="pytorch", state: IssueState = IssueState.OPEN +) -> List[Dict[str, Any]]: + url = f"https://api.github.com/repos/{org}/{project}/milestones" return fetch_multipage_json(url, {"state": state}) -def gh_get_milestone_issues(org: str, project: str, milestone_idx: int, state: IssueState = IssueState.OPEN): - url = f'https://api.github.com/repos/{org}/{project}/issues' + +def gh_get_milestone_issues( + org: str, project: str, milestone_idx: int, state: IssueState = IssueState.OPEN +): + url = f"https://api.github.com/repos/{org}/{project}/issues" return fetch_multipage_json(url, {"milestone": milestone_idx, "state": state}) def gh_get_ref_statuses(org: str, project: str, ref: str) -> Dict[str, Any]: - url = f'https://api.github.com/repos/{org}/{project}/commits/{ref}/status' + url = f"https://api.github.com/repos/{org}/{project}/commits/{ref}/status" params = {"page": 1, "per_page": 100} nrc = rc = fetch_json(url, params) while "statuses" in nrc and len(nrc["statuses"]) == 100: @@ -271,11 +316,13 @@ def gh_get_ref_statuses(org: str, project: str, ref: str) -> Dict[str, Any]: rc["statuses"] += nrc["statuses"] return rc -def get_issue_comments(org: str, project: str, issue_number : int): - url = f'https://api.github.com/repos/{org}/{project}/issues/{issue_number}/comments' + +def get_issue_comments(org: str, project: str, issue_number: int): + url = f"https://api.github.com/repos/{org}/{project}/issues/{issue_number}/comments" return fetch_multipage_json(url) + def extract_statuses_map(json: Dict[str, Any]): return {s["context"]: s["state"] for s in json["statuses"]} @@ -286,7 +333,9 @@ class PeriodStats: authors: int date: datetime - def __init__(self, date: datetime, commits: int, reverts: int, authors: int) -> None: + def __init__( + self, date: datetime, commits: int, reverts: int, authors: int + ) -> None: self.date = date self.commits = commits self.reverts = reverts @@ -296,11 +345,19 @@ def __init__(self, date: datetime, commits: int, reverts: int, authors: int) -> def get_monthly_stats(commits: List[GitCommit]) -> Iterable[PeriodStats]: y, m, total, reverts, authors = None, None, 0, 0, set() for commit in commits: - commit_date = commit.commit_date if commit.commit_date is not None else commit.author_date + commit_date = ( + commit.commit_date if commit.commit_date is not None else commit.author_date + ) if y != commit_date.year or m != commit_date.month: if y is not None: yield PeriodStats(datetime(y, m, 1), total, reverts, len(authors)) - y, m, total, reverts, authors = commit_date.year, commit_date.month, 0, 0, set() + y, m, total, reverts, authors = ( + commit_date.year, + commit_date.month, + 0, + 0, + set(), + ) if is_revert(commit): reverts += 1 total += 1 @@ -317,8 +374,10 @@ def print_monthly_stats(commits: List[GitCommit]) -> None: if idx + 1 < len(stats): commits_growth = 100.0 * (stat.commits / stats[idx + 1].commits - 1) else: - commits_growth = float('nan') - print(f"{y}-{m:02d}: commits {total} ({commits_growth:+.1f}%) reverts {reverts} ({reverts_ratio:.1f}%) authors {authors}") + commits_growth = float("nan") + print( + f"{y}-{m:02d}: commits {total} ({commits_growth:+.1f}%) reverts {reverts} ({reverts_ratio:.1f}%) authors {authors}" + ) def print_reverts(commits: List[GitCommit]) -> None: @@ -341,9 +400,13 @@ def analyze_reverts(commits: List[GitCommit]): if orig_commit is None: print(f"Failed to find original commit for {commit.title}") continue - print(f"{commit.commit_hash} is a revert of {orig_commit.commit_hash}: {orig_commit.title}") + print( + f"{commit.commit_hash} is a revert of {orig_commit.commit_hash}: {orig_commit.title}" + ) revert_statuses = gh_get_ref_statuses("pytorch", "pytorch", commit.commit_hash) - orig_statuses = gh_get_ref_statuses("pytorch", "pytorch", orig_commit.commit_hash) + orig_statuses = gh_get_ref_statuses( + "pytorch", "pytorch", orig_commit.commit_hash + ) orig_sm = extract_statuses_map(orig_statuses) revert_sm = extract_statuses_map(revert_statuses) for k in revert_sm.keys(): @@ -367,39 +430,65 @@ def print_contributor_stats(commits, delta: Optional[timedelta] = None) -> None: authors[author] = 0 authors[author] += 1 - print(f"{len(authors)} contributors made {sum(authors.values())} commits in last {delta.days} days") - for count, author in sorted(((commit, author) for author, commit in authors.items()), reverse=True): + print( + f"{len(authors)} contributors made {sum(authors.values())} commits in last {delta.days} days" + ) + for count, author in sorted( + ((commit, author) for author, commit in authors.items()), reverse=True + ): print(f"{author}: {count}") -def commits_missing_in_branch(repo: GitRepo, branch: str, orig_branch: str, milestone_idx: int) -> None: +def commits_missing_in_branch( + repo: GitRepo, branch: str, orig_branch: str, milestone_idx: int +) -> None: def get_commits_dict(x, y): return build_commit_dict(repo.get_commit_list(x, y)) - main_commits = get_commits_dict(orig_branch, 'main') + + main_commits = get_commits_dict(orig_branch, "main") release_commits = get_commits_dict(orig_branch, branch) print(f"len(main_commits)={len(main_commits)}") print(f"len(release_commits)={len(release_commits)}") print("URL;Title;Status") - for issue in gh_get_milestone_issues('pytorch', 'pytorch', milestone_idx, IssueState.ALL): + for issue in gh_get_milestone_issues( + "pytorch", "pytorch", milestone_idx, IssueState.ALL + ): issue_url, state = issue["html_url"], issue["state"] # Skip closed states if they were landed before merge date if state == "closed": - mentioned_after_cut = any(commit.is_issue_mentioned(issue_url) for commit in main_commits.values()) + mentioned_after_cut = any( + commit.is_issue_mentioned(issue_url) for commit in main_commits.values() + ) # If issue is not mentioned after cut, that it must be present in release branch if not mentioned_after_cut: continue - mentioned_in_release = any(commit.is_issue_mentioned(issue_url) for commit in release_commits.values()) + mentioned_in_release = any( + commit.is_issue_mentioned(issue_url) + for commit in release_commits.values() + ) # if Issue is mentioned is release branch, than it was picked already if mentioned_in_release: continue print(f'{issue_url};{issue["title"]};{state}') -def commits_missing_in_release(repo: GitRepo, branch: str, orig_branch: str, minor_release: str, milestone_idx: int, cut_off_date : datetime, issue_num : int) -> None: + +def commits_missing_in_release( + repo: GitRepo, + branch: str, + orig_branch: str, + minor_release: str, + milestone_idx: int, + cut_off_date: datetime, + issue_num: int, +) -> None: def get_commits_dict(x, y): return build_commit_dict(repo.get_commit_list(x, y)) - main_commits = get_commits_dict(minor_release, 'main') + + main_commits = get_commits_dict(minor_release, "main") prev_release_commits = get_commits_dict(orig_branch, branch) - current_issue_comments = get_issue_comments('pytorch', 'pytorch',issue_num) # issue comments for the release tracker as cherry picks + current_issue_comments = get_issue_comments( + "pytorch", "pytorch", issue_num + ) # issue comments for the release tracker as cherry picks print(f"len(main_commits)={len(main_commits)}") print(f"len(prev_release_commits)={len(prev_release_commits)}") print(f"len(current_issue_comments)={len(current_issue_comments)}") @@ -408,58 +497,70 @@ def get_commits_dict(x, y): # Iterate over the previous release branch to find potentially missing cherry picks in the current issue. for commit in prev_release_commits.values(): - not_cherry_picked_in_current_issue = any(commit.pr_url not in issue_comment['body'] for issue_comment in current_issue_comments) + not_cherry_picked_in_current_issue = any( + commit.pr_url not in issue_comment["body"] + for issue_comment in current_issue_comments + ) for main_commit in main_commits.values(): - if main_commit.pr_url == commit.pr_url : + if main_commit.pr_url == commit.pr_url: mentioned_after_cut_off_date = cut_off_date < main_commit.commit_date if not_cherry_picked_in_current_issue and mentioned_after_cut_off_date: # Commits that are release only, which exist in previous release branch and not in main. - print(f'{commit.pr_url};{commit.title};{commit.commit_date}') + print(f"{commit.pr_url};{commit.title};{commit.commit_date}") break + def analyze_stacks(repo: GitRepo) -> None: from tqdm.contrib.concurrent import thread_map + branches = repo.get_ghstack_orig_branches() stacks_by_author: Dict[str, List[int]] = {} - for branch,rv_commits in thread_map(lambda x: (x, repo.rev_list(x)), branches, max_workers=10): + for branch, rv_commits in thread_map( + lambda x: (x, repo.rev_list(x)), branches, max_workers=10 + ): author = branch.split("/")[2] if author not in stacks_by_author: - stacks_by_author[author]=[] + stacks_by_author[author] = [] stacks_by_author[author].append(len(rv_commits)) - for author, slen in sorted(stacks_by_author.items(), key=lambda x:len(x[1]), reverse=True): + for author, slen in sorted( + stacks_by_author.items(), key=lambda x: len(x[1]), reverse=True + ): if len(slen) == 1: print(f"{author} has 1 stack of depth {slen[0]}") continue - print(f"{author} has {len(slen)} stacks max depth is {max(slen)} avg depth is {sum(slen)/len(slen):.2f} mean is {slen[len(slen)//2]}") + print( + f"{author} has {len(slen)} stacks max depth is {max(slen)} avg depth is {sum(slen)/len(slen):.2f} mean is {slen[len(slen)//2]}" + ) def parse_arguments(): from argparse import ArgumentParser + parser = ArgumentParser(description="Print GitHub repo stats") - parser.add_argument("--repo-path", - type=str, - help="Path to PyTorch git checkout", - default=os.path.expanduser("~/git/pytorch/pytorch")) + parser.add_argument( + "--repo-path", + type=str, + help="Path to PyTorch git checkout", + default=os.path.expanduser("~/git/pytorch/pytorch"), + ) parser.add_argument("--milestone-id", type=str) parser.add_argument("--branch", type=str) parser.add_argument("--minor-release", type=str) - parser.add_argument("--remote", - type=str, - help="Remote to base off of", - default="") + parser.add_argument("--remote", type=str, help="Remote to base off of", default="") parser.add_argument("--analyze-reverts", action="store_true") parser.add_argument("--print-reverts", action="store_true") parser.add_argument("--contributor-stats", action="store_true") parser.add_argument("--missing-in-branch", action="store_true") parser.add_argument("--missing-in-release", action="store_true") parser.add_argument("--analyze-stacks", action="store_true") - parser.add_argument('--date', type=lambda d: datetime.strptime(d, '%Y-%m-%d')) + parser.add_argument("--date", type=lambda d: datetime.strptime(d, "%Y-%m-%d")) parser.add_argument("--issue-num", type=int) return parser.parse_args() def main(): import time + args = parse_arguments() remote = args.remote if not remote: @@ -467,7 +568,7 @@ def main(): # Pick best remote remote = next(iter(remotes.keys())) for key in remotes: - if remotes[key].endswith('github.com/pytorch/pytorch'): + if remotes[key].endswith("github.com/pytorch/pytorch"): remote = key repo = GitRepo(args.repo_path, remote) @@ -483,31 +584,31 @@ def main(): milestone_idx = -1 milestones = gh_get_milestones() for milestone in milestones: - if milestone.get('title', '') == args.milestone_id: - milestone_idx = int(milestone.get('number', '-2')) + if milestone.get("title", "") == args.milestone_id: + milestone_idx = int(milestone.get("number", "-2")) if milestone_idx < 0: - print(f'Could not find milestone {args.milestone_id}') + print(f"Could not find milestone {args.milestone_id}") return if args.missing_in_branch: - commits_missing_in_branch(repo, - args.branch, - f'orig/{args.branch}', - milestone_idx) + commits_missing_in_branch( + repo, args.branch, f"orig/{args.branch}", milestone_idx + ) return if args.missing_in_release: - commits_missing_in_release(repo, - args.branch, - f'orig/{args.branch}', - args.minor_release, - milestone_idx, - args.date, - args.issue_num - ) + commits_missing_in_release( + repo, + args.branch, + f"orig/{args.branch}", + args.minor_release, + milestone_idx, + args.date, + args.issue_num, + ) return - print(f"Parsing git history with remote {remote}...", end='', flush=True) + print(f"Parsing git history with remote {remote}...", end="", flush=True) start_time = time.time() x = repo._run_git_log(f"{remote}/main") print(f"done in {time.time()-start_time:.1f} sec") @@ -516,7 +617,7 @@ def main(): elif args.contributor_stats: print_contributor_stats(x) elif args.print_reverts: - print_reverts(x[:2**9]) + print_reverts(x[: 2**9]) else: print_monthly_stats(x) diff --git a/tools/analytics/s3_test_stats_analyze.py b/tools/analytics/s3_test_stats_analyze.py index 78ea3a7fd8..5e99c993b7 100644 --- a/tools/analytics/s3_test_stats_analyze.py +++ b/tools/analytics/s3_test_stats_analyze.py @@ -1,20 +1,20 @@ import argparse -import boto3 import bz2 import json import os import re -import requests +from datetime import datetime, timedelta +from typing import Any, Dict, List, Optional +import boto3 import pandas as pd - -from datetime import datetime, timedelta +import requests from tqdm import tqdm -from typing import Any, Dict, Optional, List -S3 = boto3.resource('s3') -CLIENT = boto3.client('s3') -BUCKET = S3.Bucket('ossci-metrics') + +S3 = boto3.resource("s3") +CLIENT = boto3.client("s3") +BUCKET = S3.Bucket("ossci-metrics") GITHUB_API_BASE = "https://api.github.com/" GITHUB_COMMITS_API = "repos/pytorch/pytorch/commits" @@ -22,42 +22,50 @@ CACHE_PICKLE = "cache/test_time/dataframe.pickle" + def _get_latests_git_commit_sha_list(lookback: int): - sha_since = (datetime.utcnow() - timedelta(hours = lookback)).strftime(STRF_FORMAT) + sha_since = (datetime.utcnow() - timedelta(hours=lookback)).strftime(STRF_FORMAT) resp = requests.get(GITHUB_API_BASE + GITHUB_COMMITS_API + f"?since={sha_since}") if resp.status_code == 200: - return [e.get('sha') for e in resp.json()] + return [e.get("sha") for e in resp.json()] else: return [] + def _json_to_df(data: Dict[str, Any], granularity: str) -> pd.DataFrame: reformed_data = list() - for fname, fdata in data['files'].items(): - if granularity == 'file': - reformed_data.append({ - "job": data['job'], - "sha": data['sha'], - 'file': fname, - 'file_total_sec': fdata['total_seconds'], - }) + for fname, fdata in data["files"].items(): + if granularity == "file": + reformed_data.append( + { + "job": data["job"], + "sha": data["sha"], + "file": fname, + "file_total_sec": fdata["total_seconds"], + } + ) else: - for sname, sdata in fdata['suites'].items(): - if granularity == 'suite': - reformed_data.append({ - "job": data['job'], - "sha": data['sha'], - 'suite': sname, - 'suite_total_sec': sdata['total_seconds'], - }) + for sname, sdata in fdata["suites"].items(): + if granularity == "suite": + reformed_data.append( + { + "job": data["job"], + "sha": data["sha"], + "suite": sname, + "suite_total_sec": sdata["total_seconds"], + } + ) else: - for cname, cdata in sdata['cases'].items(): - reformed_data.append({ - "job": data['job'], - "sha": data['sha'], - 'case': cname, - 'case_status': cdata['status'], - 'case_sec': cdata['seconds'], - }) + for cname, cdata in sdata["cases"].items(): + reformed_data.append( + { + "job": data["job"], + "sha": data["sha"], + "case": cname, + "case_status": cdata["status"], + "case_sec": cdata["seconds"], + } + ) df = pd.json_normalize(reformed_data) return df @@ -65,7 +73,7 @@ def _json_to_df(data: Dict[str, Any], granularity: str) -> pd.DataFrame: def download_stats(folder: str, lookback: int): commit_sha_list = _get_latests_git_commit_sha_list(lookback) for commit_sha in commit_sha_list: - for key in tqdm(BUCKET.objects.filter(Prefix=f'test_time/{commit_sha}')): + for key in tqdm(BUCKET.objects.filter(Prefix=f"test_time/{commit_sha}")): remote_fname = key.key local_fname = os.path.join(folder, remote_fname) # TODO: Do this in parallel @@ -79,20 +87,22 @@ def download_stats(folder: str, lookback: int): CLIENT.download_file("ossci-metrics", remote_fname, local_fname) -def parse_and_export_stats(folder: str, granularity: str, commit_sha_lists: Optional[List[str]] = None): +def parse_and_export_stats( + folder: str, granularity: str, commit_sha_lists: Optional[List[str]] = None +): dataframe = None - for (dirpath, _, filenames) in os.walk(folder): + for dirpath, _, filenames in os.walk(folder): for filename in tqdm(filenames): splits = dirpath.split("/") job_name = splits[-1] sha = splits[-2] if not commit_sha_lists or sha in commit_sha_lists: - with bz2.open(os.path.join(dirpath, filename), 'r') as zf: + with bz2.open(os.path.join(dirpath, filename), "r") as zf: string = zf.read().decode("utf-8") data = json.loads(string) # create a deep json with sha and job info - data['sha'] = sha - data['job'] = job_name + data["sha"] = sha + data["job"] = job_name df = _json_to_df(data, granularity) dataframe = df if dataframe is None else dataframe.append(df) return dataframe @@ -105,26 +115,26 @@ def main(): formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( - '--lookback', + "--lookback", type=int, - help='lookback in # of hours', + help="lookback in # of hours", default=24, ) parser.add_argument( - '--output', - help='output filename', - default='cache/df.pickle', + "--output", + help="output filename", + default="cache/df.pickle", ) parser.add_argument( - '--cache_folder', - help='cache folder', - default='cache', + "--cache_folder", + help="cache folder", + default="cache", ) parser.add_argument( - '--granularity', - choices=['file', 'suite', 'case'], - help='granularity of stats summary', - default='file', + "--granularity", + choices=["file", "suite", "case"], + help="granularity of stats summary", + default="file", ) args = parser.parse_args() @@ -137,10 +147,9 @@ def main(): download_stats(cache_folder, lookback) print("Parsing test stats and write to pd dataframe") if not os.path.exists(output): - dataframe = parse_and_export_stats(f'{cache_folder}/test_time/', granularity) + dataframe = parse_and_export_stats(f"{cache_folder}/test_time/", granularity) dataframe.to_pickle(output) - if __name__ == "__main__": main() diff --git a/tools/analytics/validate_binaries.py b/tools/analytics/validate_binaries.py index 65965c59ad..5686c4e5ab 100644 --- a/tools/analytics/validate_binaries.py +++ b/tools/analytics/validate_binaries.py @@ -1,13 +1,13 @@ +import json +from datetime import datetime + from conda.cli.python_api import Commands, run_command from tabulate import tabulate -from datetime import datetime -import json + PLATFORMS = ["osx-64", "linux-64", "win-64"] PYTHON_VERSIONS = ["3.10", "3.9", "3.8", "3.7"] -CUDA_CUDNN_VERSION = [ - ("11.7", "8.5.0"), ("cpu", None) -] +CUDA_CUDNN_VERSION = [("11.7", "8.5.0"), ("cpu", None)] CHANNEL = "pytorch-test" VERSION = "1.13.*" @@ -47,7 +47,10 @@ def main() -> None: # Actual builds available in Conda stdout, stderr, return_code = run_command( - Commands.SEARCH, f"{CHANNEL}::*[name=pytorch version={VERSION} subdir={platform}]", "--json") + Commands.SEARCH, + f"{CHANNEL}::*[name=pytorch version={VERSION} subdir={platform}]", + "--json", + ) if return_code != 0: raise Exception(stderr) @@ -58,18 +61,22 @@ def main() -> None: actual_builds = set() for version in available_versions["pytorch"]: actual_builds.add(version["build"]) - output_data.append(( - version["fn"], - datetime.fromtimestamp(version["timestamp"] / 1000), - size_format(version["size"]) - )) + output_data.append( + ( + version["fn"], + datetime.fromtimestamp(version["timestamp"] / 1000), + size_format(version["size"]), + ) + ) assert len(expected_builds) > 0, "expected builds set should not be empty." - assert expected_builds == actual_builds, ( - f"Missing following builds in conda: {expected_builds.difference(actual_builds)} for platform {platform}" - ) + assert ( + expected_builds == actual_builds + ), f"Missing following builds in conda: {expected_builds.difference(actual_builds)} for platform {platform}" - print(f"\nSuccessfully verified following binaries are available in Conda for {platform}...") + print( + f"\nSuccessfully verified following binaries are available in Conda for {platform}..." + ) print(tabulate(output_data, headers=headers, tablefmt="grid")) diff --git a/tools/analytics/validate_pypi_staging.py b/tools/analytics/validate_pypi_staging.py index 1be8b0f852..8e8801ff8c 100644 --- a/tools/analytics/validate_pypi_staging.py +++ b/tools/analytics/validate_pypi_staging.py @@ -9,26 +9,21 @@ import boto3 import botocore + PLATFORMS = [ "manylinux1_x86_64", "manylinux2014_aarch64", "win_amd64", "macosx_11_0_arm64", ] -PYTHON_VERSIONS = [ - "cp38", - "cp39", - "cp310", - "cp311", - "cp312" - ] +PYTHON_VERSIONS = ["cp38", "cp39", "cp310", "cp311", "cp312"] S3_PYPI_STAGING = "pytorch-backup" PACKAGE_RELEASES = { "torch": "2.3.1", "torchvision": "0.18.1", "torchaudio": "2.3.1", "torchtext": "0.18.0", - "executorch": "0.2.1" + "executorch": "0.2.1", } PATTERN_V = "Version:" diff --git a/tools/binary_size_validation/binary_size_validation.py b/tools/binary_size_validation/binary_size_validation.py index b0725bc215..0034fc3389 100644 --- a/tools/binary_size_validation/binary_size_validation.py +++ b/tools/binary_size_validation/binary_size_validation.py @@ -9,6 +9,7 @@ import requests from bs4 import BeautifulSoup + Wheel = namedtuple("Wheel", ["name", "url"]) diff --git a/tools/binary_size_validation/test_binary_size_validation.py b/tools/binary_size_validation/test_binary_size_validation.py index dee89efca9..1bea38bbc9 100644 --- a/tools/binary_size_validation/test_binary_size_validation.py +++ b/tools/binary_size_validation/test_binary_size_validation.py @@ -1,5 +1,6 @@ from binary_size_validation import parse_index + # ignore long lines in this file # flake8: noqa: E501 test_html = """ diff --git a/tools/clang-tidy-checks/check_s3.py b/tools/clang-tidy-checks/check_s3.py index 52c1384d3a..9899d357f8 100644 --- a/tools/clang-tidy-checks/check_s3.py +++ b/tools/clang-tidy-checks/check_s3.py @@ -8,10 +8,11 @@ s3 path and hash to the lintrunner s3 init config: https://github.com/pytorch/pytorch/blob/915625307eeda338fef00c984e223c5774c00a2b/tools/linter/adapters/s3_init_config.json#L1 """ -import hashlib -from urllib.request import Request, urlopen + import argparse +import hashlib from urllib.error import HTTPError +from urllib.request import Request, urlopen def download_s3_file(s3_key): diff --git a/tools/linter/adapters/exec_linter.py b/tools/linter/adapters/exec_linter.py index f00dc60afb..d3cfcf8d16 100644 --- a/tools/linter/adapters/exec_linter.py +++ b/tools/linter/adapters/exec_linter.py @@ -1,15 +1,16 @@ """ EXEC: Ensure that source files are not executable. """ + import argparse import json import logging import os import sys - from enum import Enum from typing import NamedTuple, Optional + LINTER_CODE = "EXEC" diff --git a/tools/linter/adapters/newlines_linter.py b/tools/linter/adapters/newlines_linter.py index a2cb1c5ccd..c3dadc7715 100644 --- a/tools/linter/adapters/newlines_linter.py +++ b/tools/linter/adapters/newlines_linter.py @@ -1,14 +1,15 @@ """ NEWLINE: Checks files to make sure there are no trailing newlines. """ + import argparse import json import logging import sys - from enum import Enum from typing import List, NamedTuple, Optional + NEWLINE = 10 # ASCII "\n" CARRIAGE_RETURN = 13 # ASCII "\r" LINTER_CODE = "NEWLINE" diff --git a/tools/linter/adapters/pip_init.py b/tools/linter/adapters/pip_init.py index f177a920d0..7b7de3bbd9 100644 --- a/tools/linter/adapters/pip_init.py +++ b/tools/linter/adapters/pip_init.py @@ -1,13 +1,13 @@ """ Initializer script that installs stuff to pip. """ + import argparse import logging import os import subprocess import sys import time - from typing import List diff --git a/tools/linter/adapters/s3_init.py b/tools/linter/adapters/s3_init.py index c3d6e8e03c..b954e62423 100644 --- a/tools/linter/adapters/s3_init.py +++ b/tools/linter/adapters/s3_init.py @@ -11,6 +11,7 @@ import urllib.request from pathlib import Path + # String representing the host platform (e.g. Linux, Darwin). HOST_PLATFORM = platform.system() HOST_PLATFORM_ARCH = platform.system() + "-" + platform.processor() diff --git a/tools/linter/convert_to_sarif.py b/tools/linter/convert_to_sarif.py index 8b540e69e5..6281ae7444 100644 --- a/tools/linter/convert_to_sarif.py +++ b/tools/linter/convert_to_sarif.py @@ -19,7 +19,7 @@ def severity_to_github_level(severity: str) -> str: def parse_single_lintrunner_result( - lintrunner_result: dict[str, Any] + lintrunner_result: dict[str, Any], ) -> tuple[dict[str, Any], dict[str, Any]]: r"""Parse a single lintrunner result. diff --git a/tools/pkg-helpers/pytorch_pkg_helpers/__main__.py b/tools/pkg-helpers/pytorch_pkg_helpers/__main__.py index ffc7bd59e7..96b5ebac89 100644 --- a/tools/pkg-helpers/pytorch_pkg_helpers/__main__.py +++ b/tools/pkg-helpers/pytorch_pkg_helpers/__main__.py @@ -42,10 +42,11 @@ def parse_args() -> argparse.Namespace: "--platform", help="Platform to generate for", type=str, - default= ( - sys.platform if os.getenv("PLATFORM", sys.platform) == "" + default=( + sys.platform + if os.getenv("PLATFORM", sys.platform) == "" else os.getenv("PLATFORM", sys.platform) - ), + ), ) parser.add_argument( "--gpu-arch-version", diff --git a/tools/pkg-helpers/pytorch_pkg_helpers/conda.py b/tools/pkg-helpers/pytorch_pkg_helpers/conda.py index fadd99ac5f..1c0939911f 100644 --- a/tools/pkg-helpers/pytorch_pkg_helpers/conda.py +++ b/tools/pkg-helpers/pytorch_pkg_helpers/conda.py @@ -1,5 +1,4 @@ import re - from typing import List from .utils import transform_cuversion diff --git a/tools/pkg-helpers/pytorch_pkg_helpers/cuda.py b/tools/pkg-helpers/pytorch_pkg_helpers/cuda.py index 32a3ff0147..3799e87930 100644 --- a/tools/pkg-helpers/pytorch_pkg_helpers/cuda.py +++ b/tools/pkg-helpers/pytorch_pkg_helpers/cuda.py @@ -1,9 +1,9 @@ import sys - from typing import List from .utils import transform_cuversion + WINDOWS_PATH_PREFIX = "C:\\Program Files\\NVIDIA GPU Computing Toolkit\\CUDA\\v" diff --git a/tools/pkg-helpers/pytorch_pkg_helpers/version.py b/tools/pkg-helpers/pytorch_pkg_helpers/version.py index e619d2c2dd..3196a6993e 100644 --- a/tools/pkg-helpers/pytorch_pkg_helpers/version.py +++ b/tools/pkg-helpers/pytorch_pkg_helpers/version.py @@ -1,10 +1,10 @@ import re import subprocess - from datetime import datetime from pathlib import Path from typing import List + LEADING_V_PATTERN = re.compile("^v") TRAILING_RC_PATTERN = re.compile("-rc[0-9]*$") LEGACY_BASE_VERSION_SUFFIX_PATTERN = re.compile("a0$") @@ -104,7 +104,11 @@ def get_version_variables( ) -> List[str]: version = PytorchVersion( gpu_arch_version=gpu_arch_version, - no_build_suffix=(platform == "darwin" or platform == "linux-aarch64" or package_type == "conda"), + no_build_suffix=( + platform == "darwin" + or platform == "linux-aarch64" + or package_type == "conda" + ), base_build_version=base_build_version, ) output_version = version.get_nightly_version() diff --git a/tools/pkg-helpers/tests/test_conda.py b/tools/pkg-helpers/tests/test_conda.py index 73b04e1796..d4d62f5436 100644 --- a/tools/pkg-helpers/tests/test_conda.py +++ b/tools/pkg-helpers/tests/test_conda.py @@ -1,7 +1,6 @@ import json import pytest - from pytorch_pkg_helpers.conda import ( get_conda_cuda_variables, get_conda_version_variables, diff --git a/tools/pkg-helpers/tests/test_cuda.py b/tools/pkg-helpers/tests/test_cuda.py index 875b5a8874..288b102b0a 100644 --- a/tools/pkg-helpers/tests/test_cuda.py +++ b/tools/pkg-helpers/tests/test_cuda.py @@ -1,5 +1,4 @@ import pytest - from pytorch_pkg_helpers.cuda import get_cuda_arch_list, get_cuda_variables from pytorch_pkg_helpers.utils import transform_cuversion diff --git a/tools/pkg-helpers/tests/test_macos.py b/tools/pkg-helpers/tests/test_macos.py index a1a5a98cf9..7299e9afce 100644 --- a/tools/pkg-helpers/tests/test_macos.py +++ b/tools/pkg-helpers/tests/test_macos.py @@ -1,5 +1,4 @@ import pytest - from pytorch_pkg_helpers.macos import get_macos_variables diff --git a/tools/pkg-helpers/tests/test_version.py b/tools/pkg-helpers/tests/test_version.py index ab6694ee4b..f6fad78201 100644 --- a/tools/pkg-helpers/tests/test_version.py +++ b/tools/pkg-helpers/tests/test_version.py @@ -1,9 +1,9 @@ from datetime import datetime import pytest - from pytorch_pkg_helpers.version import get_version_variables + DATE_STR = datetime.today().strftime("%Y%m%d") diff --git a/tools/pkg-helpers/tests/test_wheel.py b/tools/pkg-helpers/tests/test_wheel.py index 53f564bce2..ed4be51fd5 100644 --- a/tools/pkg-helpers/tests/test_wheel.py +++ b/tools/pkg-helpers/tests/test_wheel.py @@ -1,5 +1,4 @@ import pytest - from pytorch_pkg_helpers.wheel import ( get_pytorch_pip_install_command, get_pytorch_s3_bucket_path, diff --git a/tools/rockset_migration/compare_keys.py b/tools/rockset_migration/compare_keys.py index b166a47f95..a83c47b689 100644 --- a/tools/rockset_migration/compare_keys.py +++ b/tools/rockset_migration/compare_keys.py @@ -2,18 +2,19 @@ Helper script to compare dynamo keys present between Rockset and Clickhouse, and upload missing keys to Clickhouse if any are missing """ + +import os from argparse import ArgumentParser from functools import lru_cache from typing import Any, List -import os -import rockset +import rockset from dynamo2ch import ( ADAPTERS, + get_clickhouse_client, get_dynamo_client, unmarshal, upload_to_clickhouse, - get_clickhouse_client, ) diff --git a/tools/rockset_migration/create_clickhouse_schema.py b/tools/rockset_migration/create_clickhouse_schema.py index 78063972b6..29a3c2e01e 100644 --- a/tools/rockset_migration/create_clickhouse_schema.py +++ b/tools/rockset_migration/create_clickhouse_schema.py @@ -3,8 +3,10 @@ some manual work to verify and fill in some types if the script cannot infer them. """ + import re from typing import Dict, List + from rockset_queries import get_query_lambdas from torchci.rockset_utils import query_rockset diff --git a/tools/rockset_migration/dynamo2ch.py b/tools/rockset_migration/dynamo2ch.py index 2478cf19d5..0ae6b5e260 100755 --- a/tools/rockset_migration/dynamo2ch.py +++ b/tools/rockset_migration/dynamo2ch.py @@ -5,20 +5,20 @@ """ import datetime -from functools import lru_cache import json -from multiprocessing import Pool import os import time from argparse import ArgumentParser +from functools import lru_cache +from multiprocessing import Pool from typing import Any, Dict, Optional, Union - import boto3 import clickhouse_connect import line_profiler from prefetch_generator import BackgroundGenerator + S3_RESOURCE = boto3.resource("s3") CLICKHOUSE_ENDPOINT = os.environ.get("CLICKHOUSE_ENDPOINT", "localhost") CLICKHOUSE_USERNAME = os.environ.get("CLICKHOUSE_USERNAME", "username") diff --git a/tools/rockset_migration/rockset_2_dynamodb.py b/tools/rockset_migration/rockset_2_dynamodb.py index 48faf25931..33e146726c 100755 --- a/tools/rockset_migration/rockset_2_dynamodb.py +++ b/tools/rockset_migration/rockset_2_dynamodb.py @@ -9,7 +9,6 @@ from typing import Any, Dict, List import boto3 - from dateutil import parser # type: ignore[import] from rockset import RocksetClient # type: ignore[import] from rockset.models import QueryParameter, QueryRequestSql # type: ignore[import] diff --git a/tools/rockset_migration/rockset_queries.py b/tools/rockset_migration/rockset_queries.py index 812a34943a..8324eda23c 100755 --- a/tools/rockset_migration/rockset_queries.py +++ b/tools/rockset_migration/rockset_queries.py @@ -19,11 +19,11 @@ import json import os from pathlib import Path - from typing import Any, Dict, List, NamedTuple import requests + ROCKSET_API_KEY = os.environ.get("ROCKSET_API_KEY") # In[ ]: @@ -249,17 +249,14 @@ def backup_lambdas(queries: Dict[str, LambdaQuery], dir: Path) -> None: # In[ ]: if __name__ == "__main__": - queries = get_query_lambdas() # In[ ]: - backup_lambdas(queries, Path("lambdas_backup")) # In[ ]: - prob_unneeded = { **not_run(queries), **not_recently_run(queries, 60), @@ -267,7 +264,6 @@ def backup_lambdas(queries: Dict[str, LambdaQuery], dir: Path) -> None: # In[ ]: - # This code will be used to delete unused lambads, 10 at a time # # Deletes lambadas that have never been run @@ -283,7 +279,6 @@ def backup_lambdas(queries: Dict[str, LambdaQuery], dir: Path) -> None: # In[ ]: - important_queries = not_in(queries, prob_unneeded) len(have_human_descriptions(important_queries)) @@ -295,7 +290,6 @@ def printq(queries: Dict[str, LambdaQuery], fields: List[str]) -> None: query.printfields(fields) print() - def print_query_descriptions(queries: Dict[str, LambdaQuery]) -> None: for query in queries.values(): print(f"{query.workspace}.{query.name}", end="") @@ -306,15 +300,11 @@ def print_query_descriptions(queries: Dict[str, LambdaQuery]) -> None: else: print() - occasionally_run = not_in(important_queries, queries_run_recently(queries, 7)) len(occasionally_run) - - # In[ ]: - collections = get_collections() len(collections) @@ -330,7 +320,6 @@ def unused_collections( used_collections.update(query.collections) return {k: v for k, v in collections.items() if k not in used_collections} - def used_collections( collections: Dict[str, Collections], queries: Dict[str, LambdaQuery] ) -> Dict[str, Collections]: @@ -339,7 +328,6 @@ def used_collections( used_collections.update(query.collections) return {k: v for k, v in collections.items() if k in used_collections} - print("Used collections:") for collection in used_collections(collections, important_queries).values(): print(f"{collection.workspace}.{collection.name}") diff --git a/tools/rockset_migration/s32ch.py b/tools/rockset_migration/s32ch.py index 1daeabe7cf..07bdc870f7 100755 --- a/tools/rockset_migration/s32ch.py +++ b/tools/rockset_migration/s32ch.py @@ -5,24 +5,24 @@ """ import datetime -from functools import lru_cache import importlib import json -from multiprocessing import Pool import os -from pathlib import Path import sys import time +import urllib from argparse import ArgumentParser +from functools import lru_cache +from multiprocessing import Pool +from pathlib import Path from typing import Any, Optional -import urllib - import boto3 import clickhouse_connect import line_profiler from prefetch_generator import BackgroundGenerator + REPO_ROOT = Path(__file__).resolve().parents[2] sys.path.append(str(REPO_ROOT)) lambda_function = importlib.import_module( diff --git a/tools/scripts/analyze_ci_workflows.py b/tools/scripts/analyze_ci_workflows.py index c2a0136116..b9eaac2072 100755 --- a/tools/scripts/analyze_ci_workflows.py +++ b/tools/scripts/analyze_ci_workflows.py @@ -2,7 +2,6 @@ import argparse import collections - import re import typing diff --git a/tools/scripts/backfill_events.py b/tools/scripts/backfill_events.py index ec69339f8d..c57503d2e5 100755 --- a/tools/scripts/backfill_events.py +++ b/tools/scripts/backfill_events.py @@ -4,12 +4,13 @@ import json import os from typing import Any -from warnings import warn from urllib.request import urlopen +from warnings import warn import boto3 from octokit import Octokit + S3 = boto3.resource("s3") BUCKET_NAME = "ossci-raw-job-status" BUCKET = S3.Bucket(BUCKET_NAME) @@ -22,7 +23,9 @@ def json_dumps(body: Any) -> str: return json.dumps(body, sort_keys=True, indent=4, separators=(",", ": ")) -def upload_log(client: Octokit, owner: str, repo: str, job_id: int, conclusion: str) -> None: +def upload_log( + client: Octokit, owner: str, repo: str, job_id: int, conclusion: str +) -> None: # This logic is copied from github-status-test lambda function log = client.actions.download_job_logs_for_workflow_run( owner=owner, repo=repo, job_id=job_id @@ -48,8 +51,8 @@ def upload_log(client: Octokit, owner: str, repo: str, job_id: int, conclusion: ) except Exception as error: warn( - f"Failed to upload {log} for job {job_id} from repo {owner}/{repo}: " + - f"{error}, skipping..." + f"Failed to upload {log} for job {job_id} from repo {owner}/{repo}: " + + f"{error}, skipping..." ) @@ -104,8 +107,8 @@ def process_workflow_run( response = client.actions.list_jobs_for_workflow_run(**params).json if not response: warn( - f"Fetching workflow_job for run {run_id} from repo {owner}/{repo} " + - f"with {params} returns no response, skipping..." + f"Fetching workflow_job for run {run_id} from repo {owner}/{repo} " + + f"with {params} returns no response, skipping..." ) return @@ -137,7 +140,9 @@ def process_workflow_run( params["page"] += 1 -def backfill(owner: str, repo: str, event: str, branch: str = "", limit: int = 0) -> None: +def backfill( + owner: str, repo: str, event: str, branch: str = "", limit: int = 0 +) -> None: token = os.environ.get("GITHUB_TOKEN", "") client = Octokit(auth="token", token=token) count = 0 diff --git a/tools/scripts/fetch_latest_green_commit.py b/tools/scripts/fetch_latest_green_commit.py index 43bf9445f4..465f93dac2 100644 --- a/tools/scripts/fetch_latest_green_commit.py +++ b/tools/scripts/fetch_latest_green_commit.py @@ -1,13 +1,15 @@ import json -from pathlib import Path import re import sys +from pathlib import Path from typing import Any, cast, Dict, List, NamedTuple, Optional, Tuple + REPO_ROOT = Path(__file__).resolve().parents[2] sys.path.insert(0, str(REPO_ROOT / "tools")) -from torchci.clickhouse import query_clickhouse_saved from scripts.gitutils import _check_output +from torchci.clickhouse import query_clickhouse_saved + sys.path.pop(0) diff --git a/tools/scripts/generate_binary_build_matrix.py b/tools/scripts/generate_binary_build_matrix.py index 7001f28ae9..b34050d89e 100755 --- a/tools/scripts/generate_binary_build_matrix.py +++ b/tools/scripts/generate_binary_build_matrix.py @@ -14,13 +14,12 @@ * Latest XPU """ - import argparse import json import os import sys +from typing import Any, Callable, Dict, List, Optional, Tuple -from typing import Dict, List, Optional, Tuple, Any, Callable PYTHON_ARCHES_DICT = { "nightly": ["3.9", "3.10", "3.11", "3.12", "3.13"], @@ -472,7 +471,6 @@ def generate_wheels_matrix( ret: List[Dict[str, Any]] = [] for python_version in python_versions: for arch_version in arches: - gpu_arch_type = arch_type(arch_version) gpu_arch_version = ( "" if arch_version in [CPU, CPU_AARCH64, XPU] else arch_version diff --git a/tools/scripts/generate_docker_release_matrix.py b/tools/scripts/generate_docker_release_matrix.py index 21e801d3fe..ec52695a58 100644 --- a/tools/scripts/generate_docker_release_matrix.py +++ b/tools/scripts/generate_docker_release_matrix.py @@ -11,28 +11,36 @@ """ +import argparse import json import os import sys -import argparse -from typing import Dict, List from datetime import datetime +from typing import Dict, List import generate_binary_build_matrix -DOCKER_IMAGE_TYPES = ["runtime", "devel"] +DOCKER_IMAGE_TYPES = ["runtime", "devel"] -def generate_docker_matrix(channel: str, generate_dockerhub_images: str) -> Dict[str, List[Dict[str, str]]]: +def generate_docker_matrix( + channel: str, generate_dockerhub_images: str +) -> Dict[str, List[Dict[str, str]]]: ret: List[Dict[str, str]] = [] prefix = "ghcr.io/pytorch/pytorch" docker_image_version = "" if channel == "release": - prefix_for_release = prefix.replace("ghcr.io/", "") if generate_dockerhub_images == "true" else prefix + prefix_for_release = ( + prefix.replace("ghcr.io/", "") + if generate_dockerhub_images == "true" + else prefix + ) docker_image_version = f"{prefix_for_release}:{generate_binary_build_matrix.CURRENT_STABLE_VERSION}" elif channel == "test": - docker_image_version = f"{prefix}-test:{generate_binary_build_matrix.CURRENT_CANDIDATE_VERSION}" + docker_image_version = ( + f"{prefix}-test:{generate_binary_build_matrix.CURRENT_CANDIDATE_VERSION}" + ) else: docker_image_version = f"{prefix}-nightly:{generate_binary_build_matrix.CURRENT_NIGHTLY_VERSION}.dev{datetime.today().strftime('%Y%m%d')}" @@ -83,8 +91,11 @@ def main() -> None: ) options = parser.parse_args() - build_matrix = generate_docker_matrix(options.channel, options.generate_dockerhub_images) + build_matrix = generate_docker_matrix( + options.channel, options.generate_dockerhub_images + ) print(json.dumps(build_matrix)) + if __name__ == "__main__": main() diff --git a/tools/scripts/generate_release_matrix.py b/tools/scripts/generate_release_matrix.py index b809dbbc2e..e38c0c180a 100644 --- a/tools/scripts/generate_release_matrix.py +++ b/tools/scripts/generate_release_matrix.py @@ -6,28 +6,100 @@ """ - import argparse import json -import sys import os +import sys from typing import Dict + mod = sys.modules[__name__] RELEASE_DICT = { - "2.1.0": { 'torch': '2.1.0', 'torchvision': '0.16.0', 'torchaudio': '2.1.0', 'torchtext': '0.16.0', 'torchdata': '0.7.0'}, - "2.1.1": { 'torch': '2.1.1', 'torchvision': '0.16.1', 'torchaudio': '2.1.1', 'torchtext': '0.16.1', 'torchdata': '0.7.1'}, - "2.1.2": { 'torch': '2.1.2', 'torchvision': '0.16.2', 'torchaudio': '2.1.2', 'torchtext': '0.16.2', 'torchdata': '0.7.1'}, - "2.2.0": { 'torch': '2.2.0', 'torchvision': '0.17.0', 'torchaudio': '2.2.0', 'torchtext': '0.17.0', 'torchdata': '0.7.1'}, - "2.2.1": { 'torch': '2.2.1', 'torchvision': '0.17.1', 'torchaudio': '2.2.1', 'torchtext': '0.17.1', 'torchdata': '0.7.1'}, - "2.2.2": { 'torch': '2.2.2', 'torchvision': '0.17.2', 'torchaudio': '2.2.2', 'torchtext': '0.17.2', 'torchdata': '0.7.1'}, - "2.3.0": { 'torch': '2.3.0', 'torchvision': '0.18.0', 'torchaudio': '2.3.0', 'torchtext': '0.18.0', 'torchdata': '0.7.1'}, - "2.3.1": { 'torch': '2.3.1', 'torchvision': '0.18.1', 'torchaudio': '2.3.1', 'torchtext': '0.18.1', 'torchdata': '0.7.1'}, - "2.4.0": { 'torch': '2.4.0', 'torchvision': '0.19.0', 'torchaudio': '2.4.0', 'torchtext': '0.18.1', 'torchdata': '0.7.1'}, - "2.4.1": { 'torch': '2.4.1', 'torchvision': '0.19.1', 'torchaudio': '2.4.1', 'torchtext': '0.18.1', 'torchdata': '0.7.1'}, - "2.5.0": { 'torch': '2.5.0', 'torchvision': '0.20.0', 'torchaudio': '2.5.0', 'torchtext': '0.18.1', 'torchdata': '0.7.1'}, - "2.5.1": { 'torch': '2.5.1', 'torchvision': '0.20.1', 'torchaudio': '2.5.1', 'torchtext': '0.18.1', 'torchdata': '0.7.1'}, + "2.1.0": { + "torch": "2.1.0", + "torchvision": "0.16.0", + "torchaudio": "2.1.0", + "torchtext": "0.16.0", + "torchdata": "0.7.0", + }, + "2.1.1": { + "torch": "2.1.1", + "torchvision": "0.16.1", + "torchaudio": "2.1.1", + "torchtext": "0.16.1", + "torchdata": "0.7.1", + }, + "2.1.2": { + "torch": "2.1.2", + "torchvision": "0.16.2", + "torchaudio": "2.1.2", + "torchtext": "0.16.2", + "torchdata": "0.7.1", + }, + "2.2.0": { + "torch": "2.2.0", + "torchvision": "0.17.0", + "torchaudio": "2.2.0", + "torchtext": "0.17.0", + "torchdata": "0.7.1", + }, + "2.2.1": { + "torch": "2.2.1", + "torchvision": "0.17.1", + "torchaudio": "2.2.1", + "torchtext": "0.17.1", + "torchdata": "0.7.1", + }, + "2.2.2": { + "torch": "2.2.2", + "torchvision": "0.17.2", + "torchaudio": "2.2.2", + "torchtext": "0.17.2", + "torchdata": "0.7.1", + }, + "2.3.0": { + "torch": "2.3.0", + "torchvision": "0.18.0", + "torchaudio": "2.3.0", + "torchtext": "0.18.0", + "torchdata": "0.7.1", + }, + "2.3.1": { + "torch": "2.3.1", + "torchvision": "0.18.1", + "torchaudio": "2.3.1", + "torchtext": "0.18.1", + "torchdata": "0.7.1", + }, + "2.4.0": { + "torch": "2.4.0", + "torchvision": "0.19.0", + "torchaudio": "2.4.0", + "torchtext": "0.18.1", + "torchdata": "0.7.1", + }, + "2.4.1": { + "torch": "2.4.1", + "torchvision": "0.19.1", + "torchaudio": "2.4.1", + "torchtext": "0.18.1", + "torchdata": "0.7.1", + }, + "2.5.0": { + "torch": "2.5.0", + "torchvision": "0.20.0", + "torchaudio": "2.5.0", + "torchtext": "0.18.1", + "torchdata": "0.7.1", + }, + "2.5.1": { + "torch": "2.5.1", + "torchvision": "0.20.1", + "torchaudio": "2.5.1", + "torchtext": "0.18.1", + "torchdata": "0.7.1", + }, } diff --git a/tools/scripts/gitutils.py b/tools/scripts/gitutils.py index 88230eb689..d45915ca0b 100644 --- a/tools/scripts/gitutils.py +++ b/tools/scripts/gitutils.py @@ -17,6 +17,7 @@ Union, ) + T = TypeVar("T") RE_GITHUB_URL_MATCH = re.compile("^https://.*@?github.com/(.+)/(.+)$") diff --git a/tools/scripts/list_prs_from_partners_by_label.py b/tools/scripts/list_prs_from_partners_by_label.py index 9e9cab11cc..89c6d6723a 100644 --- a/tools/scripts/list_prs_from_partners_by_label.py +++ b/tools/scripts/list_prs_from_partners_by_label.py @@ -17,6 +17,7 @@ import requests + token = os.environ.get("GITHUB_TOKEN") local_cache = os.environ.get("CACHE") diff --git a/tools/self-hosted-runner-utils/replace_runners_prefix_submit_pr.py b/tools/self-hosted-runner-utils/replace_runners_prefix_submit_pr.py index 21ac2a38ff..8c92608053 100644 --- a/tools/self-hosted-runner-utils/replace_runners_prefix_submit_pr.py +++ b/tools/self-hosted-runner-utils/replace_runners_prefix_submit_pr.py @@ -9,12 +9,13 @@ """ import argparse -from datetime import datetime import fnmatch import os +import subprocess import sys +from datetime import datetime + import yaml -import subprocess REPOS = [ @@ -35,9 +36,13 @@ def get_opts() -> argparse.Namespace: - parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + parser = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter + ) parser.add_argument("--scale-config", type=str, required=True) - parser.add_argument("--temp-folder", type=str, default="/Users/jschmidt/.the_tmp_repl_runners") + parser.add_argument( + "--temp-folder", type=str, default="/Users/jschmidt/.the_tmp_repl_runners" + ) parser.add_argument("--prefix", type=str, default="amz2023.") return parser.parse_args() @@ -45,11 +50,21 @@ def get_opts() -> argparse.Namespace: def get_runners_names(prefix: str, scale_config: str) -> list[str]: with open(scale_config, "r") as f: config = yaml.safe_load(f) - runners_w_prefix = set(runner_name.replace(prefix, "") for runner_name in config["runner_types"].keys() if runner_name.startswith(prefix)) - runners_wo_prefix = [runner_name for runner_name in config["runner_types"].keys() if not runner_name.startswith(prefix) and runner_name in runners_w_prefix] + runners_w_prefix = set( + runner_name.replace(prefix, "") + for runner_name in config["runner_types"].keys() + if runner_name.startswith(prefix) + ) + runners_wo_prefix = [ + runner_name + for runner_name in config["runner_types"].keys() + if not runner_name.startswith(prefix) and runner_name in runners_w_prefix + ] runners_wo_prefix.sort() - runners = [runners_wo_prefix[0], ] + runners = [ + runners_wo_prefix[0], + ] for idx in range(1, len(runners_wo_prefix)): if runners_wo_prefix[idx].startswith(runners[-1]): continue @@ -74,44 +89,106 @@ def find_replace(directory, runners, prefix, filePattern): continue -def commit_push_open_pr(repo_name: str, temp_folder: str, branch_name: str, comment: str) -> None: - subprocess.run(["git", "add", "-A"], cwd=f"{temp_folder}/{repo_name}", ) - subprocess.run(["git", "commit", "-m", comment], cwd=f"{temp_folder}/{repo_name}", ) - subprocess.run(["git", "push", "origin", branch_name], cwd=f"{temp_folder}/{repo_name}", ) +def commit_push_open_pr( + repo_name: str, temp_folder: str, branch_name: str, comment: str +) -> None: + subprocess.run( + ["git", "add", "-A"], + cwd=f"{temp_folder}/{repo_name}", + ) + subprocess.run( + ["git", "commit", "-m", comment], + cwd=f"{temp_folder}/{repo_name}", + ) + subprocess.run( + ["git", "push", "origin", branch_name], + cwd=f"{temp_folder}/{repo_name}", + ) subprocess.run( [ - "gh", "pr", "create", - "--repo", f"pytorch/{repo_name}", - "--base", "main", - "--head", branch_name, - "--title", comment, - "--body", f"testing new runners", + "gh", + "pr", + "create", + "--repo", + f"pytorch/{repo_name}", + "--base", + "main", + "--head", + branch_name, + "--title", + comment, + "--body", + f"testing new runners", ], - cwd=f"{temp_folder}/{repo_name}" + cwd=f"{temp_folder}/{repo_name}", ) def open_branch(repo, repo_name: str, temp_folder: str, branch_name: str) -> None: - subprocess.run(["git", "clone", repo, f"{temp_folder}/{repo_name}", ]) - subprocess.run(["git", "branch", branch_name, ], cwd=f"{temp_folder}/{repo_name}") - subprocess.run(["git", "checkout", branch_name, ], cwd=f"{temp_folder}/{repo_name}") + subprocess.run( + [ + "git", + "clone", + repo, + f"{temp_folder}/{repo_name}", + ] + ) + subprocess.run( + [ + "git", + "branch", + branch_name, + ], + cwd=f"{temp_folder}/{repo_name}", + ) + subprocess.run( + [ + "git", + "checkout", + branch_name, + ], + cwd=f"{temp_folder}/{repo_name}", + ) def main() -> None: opts = get_opts() runners = get_runners_names(opts.prefix, opts.scale_config) branch_name = f"replace_runners_prefix_{datetime.today().strftime('%Y%m%d%H%M%S')}" - subprocess.run(["rm", "-rf", opts.temp_folder, ]) - subprocess.run(["mkdir", "-p", opts.temp_folder, ]) + subprocess.run( + [ + "rm", + "-rf", + opts.temp_folder, + ] + ) + subprocess.run( + [ + "mkdir", + "-p", + opts.temp_folder, + ] + ) try: for repo in REPOS: - repo_name = repo.split('/')[-1] + repo_name = repo.split("/")[-1] open_branch(repo, repo_name, opts.temp_folder, branch_name) find_replace(f"{opts.temp_folder}/{repo_name}", runners, opts.prefix, "*") - commit_push_open_pr(repo_name, opts.temp_folder, branch_name, f"Replace runners prefix {opts.prefix}") + commit_push_open_pr( + repo_name, + opts.temp_folder, + branch_name, + f"Replace runners prefix {opts.prefix}", + ) finally: pass - subprocess.run(["rm", "-rf", opts.temp_folder, ]) + subprocess.run( + [ + "rm", + "-rf", + opts.temp_folder, + ] + ) if __name__ == "__main__": diff --git a/tools/stronghold/src/api/__init__.py b/tools/stronghold/src/api/__init__.py index 9e885647dc..5ed07375f1 100644 --- a/tools/stronghold/src/api/__init__.py +++ b/tools/stronghold/src/api/__init__.py @@ -3,7 +3,6 @@ from __future__ import annotations import dataclasses - from collections.abc import Sequence from typing import Optional diff --git a/tools/stronghold/src/api/ast.py b/tools/stronghold/src/api/ast.py index fb45db2eb3..51a6a57d51 100644 --- a/tools/stronghold/src/api/ast.py +++ b/tools/stronghold/src/api/ast.py @@ -105,5 +105,5 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None: def visit_FunctionDef(self, node: ast.FunctionDef) -> None: # Records this function. - name = '.'.join(list(self._context) + [node.name]) + name = ".".join(list(self._context) + [node.name]) self._out[name] = node diff --git a/tools/stronghold/src/api/checker.py b/tools/stronghold/src/api/checker.py index 7024e9ae37..97863a4059 100644 --- a/tools/stronghold/src/api/checker.py +++ b/tools/stronghold/src/api/checker.py @@ -12,28 +12,28 @@ def run() -> None: parser = argparse.ArgumentParser(prog=sys.argv[0], description=__doc__) - parser.add_argument('--base-commit', type=str, required=True) - parser.add_argument('--head-commit', type=str, required=True) + parser.add_argument("--base-commit", type=str, required=True) + parser.add_argument("--head-commit", type=str, required=True) parser.add_argument( - '--suppressed', + "--suppressed", default=False, required=False, - action='store_true', - help='Failures are suppressed' - '(alternative to #suppress-api-compatibility-check commit message tag).', + action="store_true", + help="Failures are suppressed" + "(alternative to #suppress-api-compatibility-check commit message tag).", ) args = parser.parse_args(sys.argv[1:]) - repo = api.git.Repository(pathlib.Path('.')) + repo = api.git.Repository(pathlib.Path(".")) # By default, our GitHub jobs only fetch to a depth of one. This # means that the base commit will not be known to our local # clone. We must fetch it in order to compare head and base. # # The fetch is a smidge noisy, hide it by default. - print('::group::fetch github.event.pull_request.base.sha') - repo.run(['fetch', 'origin', args.base_commit], check=True) - print('::endgroup::') + print("::group::fetch github.event.pull_request.base.sha") + repo.run(["fetch", "origin", args.base_commit], check=True) + print("::endgroup::") violations = api.compatibility.check_range( repo, head=args.head_commit, base=args.base_commit @@ -43,18 +43,18 @@ def run() -> None: pinfo = repo.run( [ - 'show', + "show", # Don't show the file contents. - '--no-patch', + "--no-patch", # Show the title and the full commit message. - '--pretty=format:%B', + "--pretty=format:%B", ], check=True, stdout=subprocess.PIPE, ) - suppression_tags = ['#suppress-api-compatibility-check', '#suppress-bc-linter'] + suppression_tags = ["#suppress-api-compatibility-check", "#suppress-bc-linter"] suppressed = args.suppressed or any(tag in pinfo.stdout for tag in suppression_tags) - level = 'notice' if suppressed else 'warning' + level = "notice" if suppressed else "warning" for file, file_violations in violations.items(): for violation in file_violations: diff --git a/tools/stronghold/src/api/compatibility.py b/tools/stronghold/src/api/compatibility.py index 55d7933e84..5dcfc488a2 100644 --- a/tools/stronghold/src/api/compatibility.py +++ b/tools/stronghold/src/api/compatibility.py @@ -5,7 +5,6 @@ import difflib import pathlib import tempfile - from collections.abc import Iterable, Mapping, Sequence import api @@ -18,28 +17,28 @@ def check_range( repo: api.git.Repository, *, head: str, base: str ) -> Mapping[pathlib.Path, Sequence[api.violations.Violation]]: result = {} - for file in repo.get_files_in_range(f'{base}..{head}'): + for file in repo.get_files_in_range(f"{base}..{head}"): # Someday, we'll want to customize the filters we use to # ignore files. - if file.suffix != '.py': + if file.suffix != ".py": # Only consider Python files. continue - if any(dir.name.startswith('_') for dir in file.parents): + if any(dir.name.startswith("_") for dir in file.parents): # Ignore any internal packages. continue - if any(dir.name.startswith('.') for dir in file.parents): + if any(dir.name.startswith(".") for dir in file.parents): # Ignore any internal packages and ci modules continue - if file.name.startswith('_'): + if file.name.startswith("_"): # Ignore internal modules. continue - if any(dir.name == 'test' for dir in file.parents): + if any(dir.name == "test" for dir in file.parents): # Ignore tests (not part of PyTorch package). continue - if any(dir.name == 'benchmarks' for dir in file.parents): + if any(dir.name == "benchmarks" for dir in file.parents): # Ignore benchmarks (not part of PyTorch package). continue - if file.name.startswith('test_') or file.stem.endswith('_test'): + if file.name.startswith("test_") or file.stem.endswith("_test"): # Ignore test files. continue @@ -47,8 +46,8 @@ def check_range( # # Note that if the file doesn't exist, it is equivalent to it # being empty. - after = repo.get_contents(file, commit_id=head) or '' - before = repo.get_contents(file, commit_id=base) or '' + after = repo.get_contents(file, commit_id=head) or "" + before = repo.get_contents(file, commit_id=base) or "" with tempfile.NamedTemporaryFile() as before_file: before_path = pathlib.Path(before_file.name) @@ -74,7 +73,7 @@ def check( violations: list[api.violations.Violation] = [] for name, before_def in before_api.items(): - if any(token.startswith('_') for token in name.split('.')): + if any(token.startswith("_") for token in name.split(".")): continue after_def = after_api.get(name) @@ -182,9 +181,9 @@ def _check_by_position( matcher = difflib.SequenceMatcher(a=before_param_names, b=after_param_names) for tag, i1, i2, j1, j2 in matcher.get_opcodes(): - if tag == 'equal': + if tag == "equal": continue - if tag == 'replace': + if tag == "replace": yield api.violations.ParameterRenamed( func=func, parameter=before_param_names[i1], @@ -192,7 +191,7 @@ def _check_by_position( line=after.line, ) continue - if tag == 'insert': + if tag == "insert": after_param = after_params[j1] if after_param.required: yield api.violations.ParameterNowRequired( @@ -201,7 +200,7 @@ def _check_by_position( line=after_param.line, ) continue - if tag == 'delete': + if tag == "delete": yield api.violations.ParameterRemoved( func=func, parameter=before_params[i1].name, diff --git a/tools/stronghold/src/api/git.py b/tools/stronghold/src/api/git.py index 0e93ed2ff6..452021f0e2 100644 --- a/tools/stronghold/src/api/git.py +++ b/tools/stronghold/src/api/git.py @@ -4,7 +4,6 @@ import pathlib import subprocess - from collections.abc import Iterable from typing import Any, Optional, Union @@ -24,21 +23,21 @@ def dir(self, /) -> pathlib.Path: def get_files_in_range(self, /, range: str) -> Iterable[pathlib.Path]: """Gets files modified in a range of commits.""" pinfo = self.run( - ['diff-tree', '--name-only', '-r', range], + ["diff-tree", "--name-only", "-r", range], check=True, stdout=subprocess.PIPE, ) return [pathlib.Path(p) for p in pinfo.stdout.splitlines()] def get_contents( - self, path: pathlib.Path, *, commit_id: str = 'HEAD' + self, path: pathlib.Path, *, commit_id: str = "HEAD" ) -> Optional[str]: """Gets the contents of a file at a specified commit. Defaults to the most recent commit. """ proc = self.run( - ['show', f'{commit_id}:{path}'], + ["show", f"{commit_id}:{path}"], stdout=subprocess.PIPE, ) if proc.returncode == 128: @@ -52,7 +51,7 @@ def run( self, args: list[Union[pathlib.Path, str]], /, **kwargs: Any ) -> subprocess.CompletedProcess[str]: """Runs a git command in the repository.""" - args.insert(0, 'git') + args.insert(0, "git") return subprocess.run( args, cwd=self._dir, diff --git a/tools/stronghold/src/api/github.py b/tools/stronghold/src/api/github.py index 92756239b8..bb01f2da64 100644 --- a/tools/stronghold/src/api/github.py +++ b/tools/stronghold/src/api/github.py @@ -10,6 +10,6 @@ def render_violation( level: str, file: pathlib.Path, violation: api.violations.Violation ) -> str: return ( - f'::{level} file={file},line={violation.line}::' - f'Function {violation.func}: {violation.message}' + f"::{level} file={file},line={violation.line}::" + f"Function {violation.func}: {violation.message}" ) diff --git a/tools/stronghold/src/api/types.py b/tools/stronghold/src/api/types.py index 0c703301f3..3d50e508d9 100644 --- a/tools/stronghold/src/api/types.py +++ b/tools/stronghold/src/api/types.py @@ -35,11 +35,11 @@ class Generic: Represents a generic type, like `List[int]` or `Dict[str, int]`. """ - base: Union[TypeName, 'Attribute'] + base: Union[TypeName, "Attribute"] arguments: List[TypeHint] def __str__(self) -> str: - arguments_str = ', '.join(str(arg) for arg in self.arguments) + arguments_str = ", ".join(str(arg) for arg in self.arguments) return f"{str(self.base)}[{arguments_str}]" @@ -52,7 +52,7 @@ class Tuple: arguments: List[TypeHint] def __str__(self) -> str: - return ', '.join(str(arg) for arg in self.arguments) + return ", ".join(str(arg) for arg in self.arguments) @dataclass @@ -61,7 +61,7 @@ class Attribute: Represents an attribute, like `foo.bar` or `foo.bar.baz`. """ - value: Union[TypeName, 'Attribute'] + value: Union[TypeName, "Attribute"] attr: str def __str__(self) -> str: diff --git a/tools/stronghold/src/api/violations.py b/tools/stronghold/src/api/violations.py index 1b209326bc..53342d7d84 100644 --- a/tools/stronghold/src/api/violations.py +++ b/tools/stronghold/src/api/violations.py @@ -25,7 +25,7 @@ class FunctionDeleted(Violation): a shim in the interim? """ - message: str = 'function deleted' + message: str = "function deleted" # ==================================== @@ -34,14 +34,14 @@ class FunctionDeleted(Violation): class VarArgsDeleted(Violation): """Represents when *varargs has been deleted""" - message: str = '*varargs was removed' + message: str = "*varargs was removed" @dataclass class KwArgsDeleted(Violation): """Represents when **kwargs has been deleted""" - message: str = '**kwargs was removed' + message: str = "**kwargs was removed" # ==================================== @@ -49,14 +49,14 @@ class KwArgsDeleted(Violation): @dataclass class ParameterViolation(Violation): # name of the parameter that was invovled in the violation - parameter: str = '' + parameter: str = "" @dataclass class ParameterRemoved(ParameterViolation): """Represents when a public function has a parameter that's been removed""" - message: str = '' + message: str = "" def __post_init__(self) -> None: self.message = f"{self.parameter} was removed" @@ -66,20 +66,20 @@ def __post_init__(self) -> None: class ParameterBecameRequired(ParameterViolation): """Represents when a public function has a parameter that became required""" - message: str = '' + message: str = "" def __post_init__(self) -> None: - self.message = f'{self.parameter} became now required' + self.message = f"{self.parameter} became now required" @dataclass class ParameterNowRequired(ParameterViolation): """Represents when a public function has a parameter is now required""" - message: str = '' + message: str = "" def __post_init__(self) -> None: - self.message = f'{self.parameter} was added and is now required' + self.message = f"{self.parameter} was added and is now required" @dataclass @@ -99,12 +99,12 @@ class ParameterRenamed(ParameterViolation): """Represents when a parameter has been renamed to a different parameter""" # Parameter after it was renamed - parameter_after: str = '' + parameter_after: str = "" - message: str = '' + message: str = "" def __post_init__(self) -> None: - self.message = f'{self.parameter} was renamed to {self.parameter_after}' + self.message = f"{self.parameter} was renamed to {self.parameter_after}" @dataclass @@ -112,14 +112,14 @@ class ParameterTypeChanged(ParameterViolation): """Represents when a parameter type has changed in a non-compatible way""" # Type before it was changed - type_before: str = '' + type_before: str = "" # Type after it was changed - type_after: str = '' + type_after: str = "" - message: str = '' + message: str = "" def __post_init__(self) -> None: self.message = ( - f'{self.parameter} changed from {self.type_before} to {self.type_after}' + f"{self.parameter} changed from {self.type_before} to {self.type_after}" ) diff --git a/tools/stronghold/tests/api/conftest.py b/tools/stronghold/tests/api/conftest.py index bbdc322f8e..3bc9368cd5 100644 --- a/tools/stronghold/tests/api/conftest.py +++ b/tools/stronghold/tests/api/conftest.py @@ -1,7 +1,6 @@ import pathlib import api.git - import pytest @@ -9,8 +8,8 @@ def git_repo(tmp_path: pathlib.Path) -> api.git.Repository: """pytest fixture providing an empty initialized git repository.""" repo = api.git.Repository(tmp_path) - repo.run(['init'], check=True) + repo.run(["init"], check=True) # Set the user for this repository only. - repo.run(['config', 'user.email', 'user@mcuserface.test'], check=True) - repo.run(['config', 'user.name', 'User McUserface'], check=True) + repo.run(["config", "user.email", "user@mcuserface.test"], check=True) + repo.run(["config", "user.name", "User McUserface"], check=True) return repo diff --git a/tools/stronghold/tests/api/test_ast.py b/tools/stronghold/tests/api/test_ast.py index 9cba72a299..9410ba877e 100644 --- a/tools/stronghold/tests/api/test_ast.py +++ b/tools/stronghold/tests/api/test_ast.py @@ -5,7 +5,6 @@ import api import api.ast import api.types - from testing import source @@ -15,7 +14,7 @@ def func() -> None: funcs = api.ast.extract(source.make_file(tmp_path, func)) assert funcs == { - 'func': api.Parameters( + "func": api.Parameters( parameters=[], variadic_args=False, variadic_kwargs=False, line=1 ) } @@ -27,15 +26,15 @@ def func(x: int, /) -> None: funcs = api.ast.extract(source.make_file(tmp_path, func)) assert funcs == { - 'func': api.Parameters( + "func": api.Parameters( parameters=[ api.Parameter( - name='x', + name="x", positional=True, keyword=False, required=True, line=1, - type_annotation=api.types.TypeName('int'), + type_annotation=api.types.TypeName("int"), ) ], variadic_args=False, @@ -51,15 +50,15 @@ def func(x: int = 0, /) -> None: funcs = api.ast.extract(source.make_file(tmp_path, func)) assert funcs == { - 'func': api.Parameters( + "func": api.Parameters( parameters=[ api.Parameter( - name='x', + name="x", positional=True, keyword=False, required=False, line=1, - type_annotation=api.types.TypeName('int'), + type_annotation=api.types.TypeName("int"), ) ], variadic_args=False, @@ -75,15 +74,15 @@ def func(x: int) -> None: funcs = api.ast.extract(source.make_file(tmp_path, func)) assert funcs == { - 'func': api.Parameters( + "func": api.Parameters( parameters=[ api.Parameter( - name='x', + name="x", positional=True, keyword=True, required=True, line=1, - type_annotation=api.types.TypeName('int'), + type_annotation=api.types.TypeName("int"), ) ], variadic_args=False, @@ -99,15 +98,15 @@ def func(x: int = 0) -> None: funcs = api.ast.extract(source.make_file(tmp_path, func)) assert funcs == { - 'func': api.Parameters( + "func": api.Parameters( parameters=[ api.Parameter( - name='x', + name="x", positional=True, keyword=True, required=False, line=1, - type_annotation=api.types.TypeName('int'), + type_annotation=api.types.TypeName("int"), ) ], variadic_args=False, @@ -123,15 +122,15 @@ def func(*, x: int) -> None: funcs = api.ast.extract(source.make_file(tmp_path, func)) assert funcs == { - 'func': api.Parameters( + "func": api.Parameters( parameters=[ api.Parameter( - name='x', + name="x", positional=False, keyword=True, required=True, line=1, - type_annotation=api.types.TypeName('int'), + type_annotation=api.types.TypeName("int"), ) ], variadic_args=False, @@ -147,15 +146,15 @@ def func(*, x: int = 0) -> None: funcs = api.ast.extract(source.make_file(tmp_path, func)) assert funcs == { - 'func': api.Parameters( + "func": api.Parameters( parameters=[ api.Parameter( - name='x', + name="x", positional=False, keyword=True, required=False, line=1, - type_annotation=api.types.TypeName('int'), + type_annotation=api.types.TypeName("int"), ) ], variadic_args=False, @@ -171,7 +170,7 @@ def func(*args: int) -> None: funcs = api.ast.extract(source.make_file(tmp_path, func)) assert funcs == { - 'func': api.Parameters( + "func": api.Parameters( parameters=[], variadic_args=True, variadic_kwargs=False, line=1 ) } @@ -183,7 +182,7 @@ def func(**kwargs: int) -> None: funcs = api.ast.extract(source.make_file(tmp_path, func)) assert funcs == { - 'func': api.Parameters( + "func": api.Parameters( parameters=[], variadic_args=False, variadic_kwargs=True, line=1 ) } @@ -196,10 +195,10 @@ def func(self, /) -> None: funcs = api.ast.extract(source.make_file(tmp_path, Class)) assert funcs == { - 'Class.func': api.Parameters( + "Class.func": api.Parameters( parameters=[ api.Parameter( - name='self', + name="self", positional=True, keyword=False, required=True, @@ -222,38 +221,38 @@ def func( funcs = api.ast.extract(source.make_file(tmp_path, Class)) assert funcs == { - 'Class.func': api.Parameters( + "Class.func": api.Parameters( parameters=[ api.Parameter( - name='self', + name="self", positional=True, keyword=False, required=True, line=3, ), api.Parameter( - name='a', + name="a", positional=True, keyword=False, required=True, line=3, - type_annotation=api.types.TypeName('int'), + type_annotation=api.types.TypeName("int"), ), api.Parameter( - name='b', + name="b", positional=True, keyword=True, required=False, line=3, - type_annotation=api.types.TypeName('float'), + type_annotation=api.types.TypeName("float"), ), api.Parameter( - name='c', + name="c", positional=False, keyword=True, required=True, line=3, - type_annotation=api.types.TypeName('int'), + type_annotation=api.types.TypeName("int"), ), ], variadic_args=True, diff --git a/tools/stronghold/tests/api/test_ast_param_compatibility.py b/tools/stronghold/tests/api/test_ast_param_compatibility.py index 8529e61e7a..53ce695221 100644 --- a/tools/stronghold/tests/api/test_ast_param_compatibility.py +++ b/tools/stronghold/tests/api/test_ast_param_compatibility.py @@ -5,12 +5,12 @@ def test_none() -> None: - assert _check_type_compatibility(TypeName('int'), None) is True + assert _check_type_compatibility(TypeName("int"), None) is True assert ( _check_type_compatibility( None, - TypeName('int'), + TypeName("int"), ) is True ) @@ -27,16 +27,16 @@ def test_none() -> None: def test_simple_types() -> None: assert ( _check_type_compatibility( - TypeName('int'), - TypeName('int'), + TypeName("int"), + TypeName("int"), ) is True ) assert ( _check_type_compatibility( - TypeName('int'), - TypeName('str'), + TypeName("int"), + TypeName("str"), ) is False ) @@ -44,12 +44,12 @@ def test_simple_types() -> None: assert ( _check_type_compatibility( Attribute( - value=TypeName('types'), - attr='Test', + value=TypeName("types"), + attr="Test", ), Attribute( - value=TypeName('types'), - attr='Test', + value=TypeName("types"), + attr="Test", ), ) is True @@ -58,12 +58,12 @@ def test_simple_types() -> None: assert ( _check_type_compatibility( Attribute( - value=TypeName('types'), - attr='Test', + value=TypeName("types"), + attr="Test", ), Attribute( - value=TypeName('types'), - attr='Test2', + value=TypeName("types"), + attr="Test2", ), ) is False @@ -73,16 +73,16 @@ def test_simple_types() -> None: def test_unknown_types() -> None: assert ( _check_type_compatibility( - TypeName('int'), - Unknown('?'), + TypeName("int"), + Unknown("?"), ) is True ) assert ( _check_type_compatibility( - Unknown('?'), - TypeName('int'), + Unknown("?"), + TypeName("int"), ) is True ) @@ -91,32 +91,32 @@ def test_unknown_types() -> None: def test_constant_types() -> None: assert ( _check_type_compatibility( - Constant('None'), - Constant('None'), + Constant("None"), + Constant("None"), ) is True ) assert ( _check_type_compatibility( - Constant('None'), - Constant('True'), + Constant("None"), + Constant("True"), ) is False ) assert ( _check_type_compatibility( - Constant('None'), - Constant('False'), + Constant("None"), + Constant("False"), ) is False ) assert ( _check_type_compatibility( - Constant('True'), - TypeName('bool'), + Constant("True"), + TypeName("bool"), ) is True ) @@ -124,8 +124,8 @@ def test_constant_types() -> None: # note: asymmetry assert ( _check_type_compatibility( - TypeName('bool'), - Constant('True'), + TypeName("bool"), + Constant("True"), ) is False ) @@ -134,8 +134,8 @@ def test_constant_types() -> None: # thus it is compatible with any type assert ( _check_type_compatibility( - Constant('True'), - TypeName('int'), + Constant("True"), + TypeName("int"), ) is True ) @@ -145,12 +145,12 @@ def test_generic_types() -> None: assert ( _check_type_compatibility( Generic( - base=TypeName('List'), - arguments=[TypeName('int')], + base=TypeName("List"), + arguments=[TypeName("int")], ), Generic( - base=TypeName('List'), - arguments=[TypeName('int')], + base=TypeName("List"), + arguments=[TypeName("int")], ), ) is True @@ -159,12 +159,12 @@ def test_generic_types() -> None: assert ( _check_type_compatibility( Generic( - base=TypeName('List'), - arguments=[TypeName('int')], + base=TypeName("List"), + arguments=[TypeName("int")], ), Generic( - base=TypeName('List'), - arguments=[TypeName('str')], + base=TypeName("List"), + arguments=[TypeName("str")], ), ) is False @@ -173,12 +173,12 @@ def test_generic_types() -> None: assert ( _check_type_compatibility( Generic( - base=TypeName('List'), - arguments=[TypeName('int')], + base=TypeName("List"), + arguments=[TypeName("int")], ), Generic( - base=TypeName('Tuple'), - arguments=[TypeName('int')], + base=TypeName("Tuple"), + arguments=[TypeName("int")], ), ) is False @@ -187,12 +187,12 @@ def test_generic_types() -> None: assert ( _check_type_compatibility( Generic( - base=TypeName('List'), - arguments=[TypeName('int'), TypeName('str')], + base=TypeName("List"), + arguments=[TypeName("int"), TypeName("str")], ), Generic( - base=TypeName('List'), - arguments=[TypeName('str'), TypeName('int')], + base=TypeName("List"), + arguments=[TypeName("str"), TypeName("int")], ), ) is False @@ -201,12 +201,12 @@ def test_generic_types() -> None: assert ( _check_type_compatibility( Generic( - base=TypeName('List'), - arguments=[TypeName('int'), TypeName('str')], + base=TypeName("List"), + arguments=[TypeName("int"), TypeName("str")], ), Generic( - base=TypeName('List'), - arguments=[TypeName('int'), TypeName('str')], + base=TypeName("List"), + arguments=[TypeName("int"), TypeName("str")], ), ) is True @@ -215,12 +215,12 @@ def test_generic_types() -> None: assert ( _check_type_compatibility( Generic( - base=TypeName('List'), - arguments=[TypeName('int'), TypeName('str')], + base=TypeName("List"), + arguments=[TypeName("int"), TypeName("str")], ), Generic( - base=TypeName('List'), - arguments=[TypeName('int')], + base=TypeName("List"), + arguments=[TypeName("int")], ), ) is False @@ -229,12 +229,12 @@ def test_generic_types() -> None: assert ( _check_type_compatibility( Generic( - base=TypeName('List'), - arguments=[TypeName('int')], + base=TypeName("List"), + arguments=[TypeName("int")], ), Generic( - base=TypeName('List'), - arguments=[Unknown('?')], + base=TypeName("List"), + arguments=[Unknown("?")], ), ) is True @@ -243,12 +243,12 @@ def test_generic_types() -> None: assert ( _check_type_compatibility( Generic( - base=TypeName('List'), - arguments=[Unknown('?')], + base=TypeName("List"), + arguments=[Unknown("?")], ), Generic( - base=TypeName('List'), - arguments=[TypeName('int')], + base=TypeName("List"), + arguments=[TypeName("int")], ), ) is True @@ -258,20 +258,20 @@ def test_generic_types() -> None: assert ( _check_type_compatibility( Generic( - base=TypeName('List'), + base=TypeName("List"), arguments=[ Generic( - base=TypeName('List'), - arguments=[TypeName('int')], + base=TypeName("List"), + arguments=[TypeName("int")], ) ], ), Generic( - base=TypeName('List'), + base=TypeName("List"), arguments=[ Generic( - base=TypeName('List'), - arguments=[TypeName('int')], + base=TypeName("List"), + arguments=[TypeName("int")], ) ], ), @@ -282,20 +282,20 @@ def test_generic_types() -> None: assert ( _check_type_compatibility( Generic( - base=TypeName('List'), + base=TypeName("List"), arguments=[ Generic( - base=TypeName('List'), - arguments=[TypeName('int')], + base=TypeName("List"), + arguments=[TypeName("int")], ) ], ), Generic( - base=TypeName('List'), + base=TypeName("List"), arguments=[ Generic( - base=TypeName('List'), - arguments=[TypeName('str')], + base=TypeName("List"), + arguments=[TypeName("str")], ) ], ), diff --git a/tools/stronghold/tests/api/test_ast_param_types.py b/tools/stronghold/tests/api/test_ast_param_types.py index 645f0b8ec9..063709de98 100644 --- a/tools/stronghold/tests/api/test_ast_param_types.py +++ b/tools/stronghold/tests/api/test_ast_param_types.py @@ -6,7 +6,6 @@ import api import api.ast import api.types - from testing import source @@ -36,9 +35,9 @@ def func(a: int, b: float, c: List, /) -> None: # type: ignore params = extract_parameter_types(source.make_file(tmp_path, func)) assert params == [ - api.types.TypeName('int'), - api.types.TypeName('float'), - api.types.TypeName('List'), + api.types.TypeName("int"), + api.types.TypeName("float"), + api.types.TypeName("List"), ] @@ -48,9 +47,9 @@ def func(a: None, b: True, c: False, /) -> None: # type: ignore params = extract_parameter_types(source.make_file(tmp_path, func)) assert params == [ - api.types.Constant('None'), - api.types.Constant('True'), - api.types.Constant('False'), + api.types.Constant("None"), + api.types.Constant("True"), + api.types.Constant("False"), ] @@ -63,23 +62,23 @@ def func( params = extract_parameter_types(source.make_file(tmp_path, func)) assert params == [ api.types.Generic( - base=api.types.TypeName('List'), - arguments=[api.types.TypeName('int')], + base=api.types.TypeName("List"), + arguments=[api.types.TypeName("int")], ), api.types.Generic( - base=api.types.TypeName('Dict'), - arguments=[api.types.TypeName('str'), api.types.TypeName('int')], + base=api.types.TypeName("Dict"), + arguments=[api.types.TypeName("str"), api.types.TypeName("int")], ), api.types.Generic( - base=api.types.TypeName('Tuple'), - arguments=[api.types.TypeName('int'), api.types.TypeName('str')], + base=api.types.TypeName("Tuple"), + arguments=[api.types.TypeName("int"), api.types.TypeName("str")], ), api.types.Generic( - base=api.types.TypeName('List'), + base=api.types.TypeName("List"), arguments=[ api.types.Generic( - base=api.types.TypeName('Dict'), - arguments=[api.types.TypeName('str'), api.types.TypeName('int')], + base=api.types.TypeName("Dict"), + arguments=[api.types.TypeName("str"), api.types.TypeName("int")], ) ], ), @@ -94,17 +93,17 @@ def func(a: api.types.TypeName, b: api.types.Attribute, /) -> None: assert params == [ api.types.Attribute( value=api.types.Attribute( - value=api.types.TypeName('api'), - attr='types', + value=api.types.TypeName("api"), + attr="types", ), - attr='TypeName', + attr="TypeName", ), api.types.Attribute( value=api.types.Attribute( - value=api.types.TypeName('api'), - attr='types', + value=api.types.TypeName("api"), + attr="types", ), - attr='Attribute', + attr="Attribute", ), ] diff --git a/tools/stronghold/tests/api/test_compatibility.py b/tools/stronghold/tests/api/test_compatibility.py index 6c65784cc9..dc72bec361 100644 --- a/tools/stronghold/tests/api/test_compatibility.py +++ b/tools/stronghold/tests/api/test_compatibility.py @@ -4,9 +4,7 @@ import api.compatibility import api.violations - import pytest - from testing import git, source @@ -19,7 +17,7 @@ def func() -> None: after = source.make_file(tmp_path, lambda: None) assert api.compatibility.check(before, after) == [ - api.violations.FunctionDeleted(func='func', line=1) + api.violations.FunctionDeleted(func="func", line=1) ] @@ -39,7 +37,7 @@ def rose_by_any_other_name( after = source.make_file(tmp_path, rose_by_any_other_name) assert api.compatibility.check(before, after) == [ - api.violations.FunctionDeleted(func='rose', line=1) + api.violations.FunctionDeleted(func="rose", line=1) ] @@ -53,7 +51,7 @@ def func(self, /) -> None: after = source.make_file(tmp_path, lambda: None) assert api.compatibility.check(before, after) == [ - api.violations.FunctionDeleted(func='Class.func', line=1) + api.violations.FunctionDeleted(func="Class.func", line=1) ] @@ -69,7 +67,7 @@ def func() -> None: # type: ignore[no-redef] after = source.make_file(tmp_path, func) assert api.compatibility.check(before, after) == [ - api.violations.VarArgsDeleted(func='func', line=1) + api.violations.VarArgsDeleted(func="func", line=1) ] @@ -85,7 +83,7 @@ def func() -> None: # type: ignore[no-redef] after = source.make_file(tmp_path, func) assert api.compatibility.check(before, after) == [ - api.violations.KwArgsDeleted(func='func', line=1) + api.violations.KwArgsDeleted(func="func", line=1) ] @@ -433,8 +431,8 @@ def func(a: str, b: int, /) -> None: # type: ignore[no-redef] func=func.__name__, parameter="a", line=1, - type_before='int', - type_after='str', + type_before="int", + type_after="str", ) ] @@ -455,8 +453,8 @@ def func(*, b: int, a: str) -> None: # type: ignore[no-redef] func=func.__name__, parameter="a", line=1, - type_before='int', - type_after='str', + type_before="int", + type_after="str", ) ] @@ -476,14 +474,14 @@ def func(*, b: List[int], a: List[int]) -> None: # type: ignore[no-redef] @pytest.mark.parametrize( - 'path', + "path", [ - 'python.cpp', - '_internal/module.py', - '_module.py', - 'test/module.py', - 'test_module.py', - 'module_test.py', + "python.cpp", + "_internal/module.py", + "_module.py", + "test/module.py", + "test_module.py", + "module_test.py", ], ) def test_check_range_skips(path: str, git_repo: api.git.Repository) -> None: @@ -491,34 +489,34 @@ def test_check_range_skips(path: str, git_repo: api.git.Repository) -> None: git_repo, pathlib.Path(path), textwrap.dedent( - ''' + """ def will_be_deleted(): pass - ''' + """ ), ) - git.commit_file(git_repo, pathlib.Path(path), '') - violations = api.compatibility.check_range(git_repo, head='HEAD', base='HEAD~') + git.commit_file(git_repo, pathlib.Path(path), "") + violations = api.compatibility.check_range(git_repo, head="HEAD", base="HEAD~") assert violations == {} def test_check_range(git_repo: api.git.Repository) -> None: git.commit_file( git_repo, - pathlib.Path('module.py'), + pathlib.Path("module.py"), textwrap.dedent( - ''' + """ def will_be_deleted(): pass - ''' + """ ), ) - git.commit_file(git_repo, pathlib.Path('module.py'), '') + git.commit_file(git_repo, pathlib.Path("module.py"), "") - violations = api.compatibility.check_range(git_repo, head='HEAD', base='HEAD~') + violations = api.compatibility.check_range(git_repo, head="HEAD", base="HEAD~") assert violations == { - pathlib.Path('module.py'): [ - api.violations.FunctionDeleted(func='will_be_deleted', line=1) + pathlib.Path("module.py"): [ + api.violations.FunctionDeleted(func="will_be_deleted", line=1) ], } diff --git a/tools/stronghold/tests/api/test_git.py b/tools/stronghold/tests/api/test_git.py index dbbe7061ce..f7e474d241 100644 --- a/tools/stronghold/tests/api/test_git.py +++ b/tools/stronghold/tests/api/test_git.py @@ -1,45 +1,44 @@ import pathlib import api.git - from testing import git def test_get_files_in_range(git_repo: api.git.Repository) -> None: - file = pathlib.Path('meh.txt') + file = pathlib.Path("meh.txt") # Check-in the file initially. - git.commit_file(git_repo, file, 'contents') + git.commit_file(git_repo, file, "contents") # The diff-tree command only works if there is a second commit. - git.commit_file(git_repo, file, 'contents\n') + git.commit_file(git_repo, file, "contents\n") - assert git_repo.get_files_in_range('HEAD~..HEAD') == [file] + assert git_repo.get_files_in_range("HEAD~..HEAD") == [file] def test_get_contents(git_repo: api.git.Repository) -> None: - file = pathlib.Path('meh.txt') + file = pathlib.Path("meh.txt") # Check-in the file initially. - git.commit_file(git_repo, file, 'contents\n') + git.commit_file(git_repo, file, "contents\n") - assert git_repo.get_contents(file) == 'contents\n' + assert git_repo.get_contents(file) == "contents\n" def test_get_contents_missing_file(git_repo: api.git.Repository) -> None: # Check-in the file initially. - git.commit_file(git_repo, pathlib.Path('meh.txt'), 'contents\n') + git.commit_file(git_repo, pathlib.Path("meh.txt"), "contents\n") - assert git_repo.get_contents(pathlib.Path('non_existent_file.txt')) is None + assert git_repo.get_contents(pathlib.Path("non_existent_file.txt")) is None def test_custom_commit_id(git_repo: api.git.Repository) -> None: - file = pathlib.Path('meh.txt') + file = pathlib.Path("meh.txt") # Check-in the file initially. - git.commit_file(git_repo, file, 'contents') + git.commit_file(git_repo, file, "contents") # The diff-tree command only works if there is a second commit. - git.commit_file(git_repo, file, 'contents\n') + git.commit_file(git_repo, file, "contents\n") # Add third commit to have multiple valid commit ids. - git.commit_file(git_repo, file, 'new contents\n') + git.commit_file(git_repo, file, "new contents\n") - assert git_repo.get_contents(file, commit_id='HEAD~') == 'contents\n' + assert git_repo.get_contents(file, commit_id="HEAD~") == "contents\n" diff --git a/tools/stronghold/tests/api/test_github.py b/tools/stronghold/tests/api/test_github.py index 6380e59a63..ee96f10996 100644 --- a/tools/stronghold/tests/api/test_github.py +++ b/tools/stronghold/tests/api/test_github.py @@ -3,20 +3,19 @@ import api.compatibility import api.github import api.violations - import pytest -@pytest.mark.parametrize('level', ['notice', 'warning']) +@pytest.mark.parametrize("level", ["notice", "warning"]) def test_render_violation(level: str) -> None: assert ( api.github.render_violation( level, - pathlib.Path('test.py'), + pathlib.Path("test.py"), api.violations.KwArgsDeleted( - func='foo', + func="foo", line=3, ), ) - == f'::{level} file=test.py,line=3::Function foo: **kwargs was removed' + == f"::{level} file=test.py,line=3::Function foo: **kwargs was removed" ) diff --git a/tools/stronghold/tests/lib/testing/git.py b/tools/stronghold/tests/lib/testing/git.py index aa598b8176..71593a4cb7 100644 --- a/tools/stronghold/tests/lib/testing/git.py +++ b/tools/stronghold/tests/lib/testing/git.py @@ -13,9 +13,9 @@ def commit_file( message: Optional[str] = None, ) -> None: """Creates a commit with the file and contents in the repository.""" - message = message or f'setting {file} to:\n{contents}' + message = message or f"setting {file} to:\n{contents}" file = git_repo.dir / file file.parent.mkdir(parents=True, exist_ok=True) file.write_text(contents) - git_repo.run(['add', '--intent-to-add', os.fspath(file)], check=True) - git_repo.run(['commit', f'--message={message}', os.fspath(file)], check=True) + git_repo.run(["add", "--intent-to-add", os.fspath(file)], check=True) + git_repo.run(["commit", f"--message={message}", os.fspath(file)], check=True) diff --git a/tools/tests/test_fetch_latest_green_commit.py b/tools/tests/test_fetch_latest_green_commit.py index 153238dfe4..e4f11de938 100644 --- a/tools/tests/test_fetch_latest_green_commit.py +++ b/tools/tests/test_fetch_latest_green_commit.py @@ -3,6 +3,7 @@ from tools.scripts.fetch_latest_green_commit import is_green, WorkflowCheck + workflow_names = [ "pull", "trunk", diff --git a/tools/tests/test_generate_binary_build_matrix.py b/tools/tests/test_generate_binary_build_matrix.py index d1cae0adbc..5ab30698ba 100644 --- a/tools/tests/test_generate_binary_build_matrix.py +++ b/tools/tests/test_generate_binary_build_matrix.py @@ -1,12 +1,12 @@ +import argparse import json import os - -import argparse import sys from unittest import main, TestCase from tools.scripts.generate_binary_build_matrix import generate_build_matrix + ASSETS_DIR = "tools/tests/assets" diff --git a/tools/torchci/check_alerts.py b/tools/torchci/check_alerts.py index 4ca20831ac..717d26b557 100755 --- a/tools/torchci/check_alerts.py +++ b/tools/torchci/check_alerts.py @@ -13,6 +13,7 @@ import requests from setuptools import distutils # type: ignore[import] + ALL_SKIPPED_THRESHOLD = 100 SIMILARITY_THRESHOLD = 0.75 FAILURE_CHAIN_THRESHOLD = 2 @@ -273,9 +274,9 @@ def generate_failed_job_issue( ) -> Any: failed_jobs.sort(key=lambda status: status.job_name) issue = {} - issue[ - "title" - ] = f"[Pytorch] There are {len(failed_jobs)} Recurrently Failing Jobs on {repo} {branch}" + issue["title"] = ( + f"[Pytorch] There are {len(failed_jobs)} Recurrently Failing Jobs on {repo} {branch}" + ) body = "Within the last 50 commits, there are the following failures on the main branch of pytorch: \n" for job in failed_jobs: failing_sha = job.failure_chain[-1]["sha"] @@ -326,9 +327,9 @@ def gen_update_comment(original_issue: Dict[str, Any], jobs: List[JobStatus]) -> def generate_no_flaky_tests_issue() -> Any: issue = {} - issue[ - "title" - ] = f"[Pytorch][Warning] No flaky test issues have been detected in the past {FLAKY_TESTS_SEARCH_PERIOD_DAYS} days!" + issue["title"] = ( + f"[Pytorch][Warning] No flaky test issues have been detected in the past {FLAKY_TESTS_SEARCH_PERIOD_DAYS} days!" + ) issue["body"] = ( f"No issues have been filed in the past {FLAKY_TESTS_SEARCH_PERIOD_DAYS} days for " f"the repository {REPO_OWNER}/{TEST_INFRA_REPO_NAME}.\n" diff --git a/tools/torchci/download_logs.py b/tools/torchci/download_logs.py index a270bfcb19..e8ad8a9433 100644 --- a/tools/torchci/download_logs.py +++ b/tools/torchci/download_logs.py @@ -8,6 +8,7 @@ import requests from torchci.clickhouse import query_clickhouse + REPO_ROOT = Path(__file__).resolve().parent.parent.parent diff --git a/tools/torchci/queue_alert.py b/tools/torchci/queue_alert.py index 33d6e7252d..51c1249075 100644 --- a/tools/torchci/queue_alert.py +++ b/tools/torchci/queue_alert.py @@ -5,10 +5,10 @@ from typing import Any, Dict, List, NamedTuple import requests - from setuptools import distutils # type: ignore[import] from torchci.check_alerts import clear_alerts, create_issue, fetch_alerts, update_issue + REPO_ROOT = Path(__file__).resolve().parent.parent.parent QUEUE_ALERT_LABEL = "queue-alert" diff --git a/tools/torchci/reverts.py b/tools/torchci/reverts.py index b0c68948b6..fcdfc30fc2 100644 --- a/tools/torchci/reverts.py +++ b/tools/torchci/reverts.py @@ -10,6 +10,7 @@ from torchci.clickhouse import query_clickhouse from torchci.github_analyze import GitCommit, GitRepo # type: ignore[import] + # Should match the contents produced by trymerge on revert RE_REVERT_COMMIT_BODY = r"Reverted .* on behalf of .* due to .* \(\[comment\]\((.*)\)\)" diff --git a/tools/torchci/td/historical_class_failure_correlation.py b/tools/torchci/td/historical_class_failure_correlation.py index aa0288a948..ebfc52d8a5 100644 --- a/tools/torchci/td/historical_class_failure_correlation.py +++ b/tools/torchci/td/historical_class_failure_correlation.py @@ -1,7 +1,6 @@ import json from torchci.clickhouse import query_clickhouse - from torchci.td.utils import ( calculate_generic_test_ratings, evaluate, @@ -9,6 +8,7 @@ get_merge_bases_dict, ) + FAILED_TESTS_QUERY = """ SELECT distinct REPLACE(t.invoking_file, '.', '/') as invoking_file, diff --git a/tools/torchci/td/historical_file_failure_correlation.py b/tools/torchci/td/historical_file_failure_correlation.py index e1c418efcd..9618a9628a 100644 --- a/tools/torchci/td/historical_file_failure_correlation.py +++ b/tools/torchci/td/historical_file_failure_correlation.py @@ -2,13 +2,13 @@ from collections import defaultdict from torchci.clickhouse import query_clickhouse - from torchci.td.utils import ( calculate_generic_test_ratings, evaluate, get_merge_bases_dict, ) + FAILED_TESTS_QUERY = """ select w.head_sha, diff --git a/tools/torchci/td/td_heuristic_historical_edited_files.py b/tools/torchci/td/td_heuristic_historical_edited_files.py index 4a656a4d7e..c6abc7f643 100644 --- a/tools/torchci/td/td_heuristic_historical_edited_files.py +++ b/tools/torchci/td/td_heuristic_historical_edited_files.py @@ -12,6 +12,7 @@ list_past_year_shas, ) + CHANGED_FILES_QUERY = """ select sha, diff --git a/tools/torchci/td/td_heuristic_profiling.py b/tools/torchci/td/td_heuristic_profiling.py index b9688511e9..bd5938b305 100644 --- a/tools/torchci/td/td_heuristic_profiling.py +++ b/tools/torchci/td/td_heuristic_profiling.py @@ -1,7 +1,6 @@ import json import requests - from torchci.td.utils import evaluate, get_filtered_failed_tests, get_merge_bases_dict diff --git a/tools/torchci/td/utils.py b/tools/torchci/td/utils.py index 83cfe6240c..149f0abde5 100644 --- a/tools/torchci/td/utils.py +++ b/tools/torchci/td/utils.py @@ -5,7 +5,6 @@ from typing import Any, Dict, List import requests - from torchci.clickhouse import query_clickhouse from torchci.utils import cache_json, run_command diff --git a/tools/torchci/tests/test_check_alerts.py b/tools/torchci/tests/test_check_alerts.py index 97aa0ceb5f..62a69a700a 100644 --- a/tools/torchci/tests/test_check_alerts.py +++ b/tools/torchci/tests/test_check_alerts.py @@ -14,6 +14,7 @@ PYTORCH_ALERT_LABEL, ) + JOB_NAME = "periodic / linux-xenial-cuda10.2-py3-gcc7-slow-gradcheck / test (default, 2, 2, linux.4xlarge.nvidia.gpu)" DISABLED_JOB_NAMES = [ "linux-focal-rocm5.3-py3.8-slow / test (slow, 1, 1, linux.rocm.gpu, rerun_disabled_tests)", @@ -164,9 +165,9 @@ def test_update_comment_empty(self): self.assertFalse(update_comment) jobs = [JobStatus("job1", [{}]), JobStatus("job2", [{}])] - original_issue[ - "body" - ] = "- [job1](a) failed consecutively starting with commit []()" + original_issue["body"] = ( + "- [job1](a) failed consecutively starting with commit []()" + ) update_comment = gen_update_comment(original_issue, jobs) self.assertTrue("started failing" in update_comment) self.assertTrue("job2" in update_comment) diff --git a/tools/torchci/update_test_times.py b/tools/torchci/update_test_times.py index 9ec28a6a15..4e4e052c9a 100644 --- a/tools/torchci/update_test_times.py +++ b/tools/torchci/update_test_times.py @@ -4,6 +4,7 @@ import requests from torchci.clickhouse import query_clickhouse_saved + TEST_TIMES_URL = "https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/test-times.json" TEST_CLASS_TIMES_URL = "https://raw.githubusercontent.com/pytorch/test-infra/generated-stats/stats/test-class-times.json" diff --git a/tools/torchci/utils.py b/tools/torchci/utils.py index f7c17f28d7..eaaf842427 100644 --- a/tools/torchci/utils.py +++ b/tools/torchci/utils.py @@ -6,6 +6,7 @@ from hashlib import sha256 from typing import List, Union + FILE_CACHE_LIFESPAN_SECONDS = 60 * 60 * 24 # 1 day REPO_ROOT = pathlib.Path(__file__).parent.parent.parent CACHE_FOLDER = REPO_ROOT / "_logs" / ".torchci_python_utils_cache"