diff --git a/ci/tools/validate-release-wheels b/ci/tools/validate-release-wheels index 5757ca17bc..52d24df473 100755 --- a/ci/tools/validate-release-wheels +++ b/ci/tools/validate-release-wheels @@ -9,11 +9,12 @@ from __future__ import annotations import argparse -import re import sys from collections import defaultdict from pathlib import Path +from check_release_notes import parse_version_from_tag + COMPONENT_TO_DISTRIBUTIONS: dict[str, set[str]] = { "cuda-core": {"cuda_core"}, "cuda-bindings": {"cuda_bindings"}, @@ -22,11 +23,13 @@ COMPONENT_TO_DISTRIBUTIONS: dict[str, set[str]] = { "all": {"cuda_core", "cuda_bindings", "cuda_pathfinder", "cuda_python"}, } -TAG_PATTERNS = ( - re.compile(r"^v(?P\d+\.\d+\.\d+)"), - re.compile(r"^cuda-core-v(?P\d+\.\d+\.\d+)"), - re.compile(r"^cuda-pathfinder-v(?P\d+\.\d+\.\d+)"), -) +COMPONENT_TO_TAG_COMPONENTS: dict[str, tuple[str, ...]] = { + "cuda-core": ("cuda-core",), + "cuda-bindings": ("cuda-bindings",), + "cuda-pathfinder": ("cuda-pathfinder",), + "cuda-python": ("cuda-python",), + "all": ("cuda-core", "cuda-bindings", "cuda-pathfinder", "cuda-python"), +} def parse_args() -> argparse.Namespace: @@ -42,15 +45,18 @@ def parse_args() -> argparse.Namespace: return parser.parse_args() -def version_from_tag(tag: str) -> str: - for pattern in TAG_PATTERNS: - match = pattern.match(tag) - if match: - return match.group("version") +def version_from_tag(tag: str, component: str) -> str: + versions = { + version + for tag_component in COMPONENT_TO_TAG_COMPONENTS[component] + if (version := parse_version_from_tag(tag, tag_component)) is not None + } + if len(versions) == 1: + return versions.pop() raise ValueError( "Unsupported git tag format " - f"{tag!r}; expected tags beginning with vX.Y.Z, cuda-core-vX.Y.Z, " - "or cuda-pathfinder-vX.Y.Z." + f"{tag!r} for component {component!r}; expected vX.Y.Z[.postN], " + "cuda-core-vX.Y.Z[.postN], or cuda-pathfinder-vX.Y.Z[.postN]." ) @@ -64,7 +70,12 @@ def parse_wheel_dist_and_version(path: Path) -> tuple[str, str]: def main() -> int: args = parse_args() - expected_version = version_from_tag(args.git_tag) + try: + expected_version = version_from_tag(args.git_tag, args.component) + except ValueError as exc: + print(f"Error: {exc}", file=sys.stderr) + return 1 + expected_distributions = COMPONENT_TO_DISTRIBUTIONS[args.component] wheel_dir = Path(args.wheel_dir)