diff --git a/tools/wheel_resolver/BUILD b/tools/wheel_resolver/BUILD index 7e425e87..baff4ff7 100644 --- a/tools/wheel_resolver/BUILD +++ b/tools/wheel_resolver/BUILD @@ -1,22 +1,22 @@ subinclude("//build_defs:python") python_binary( - name = "wheel_resolver", - main = "main.py", - visibility = ["PUBLIC"], - deps = [ + name="wheel_resolver", + main="main.py", + visibility=["PUBLIC"], + deps=[ ":wheel", ], ) python_library( - name = "wheel", - srcs = [ + name="wheel", + srcs=[ "__init__.py", "wheel.py", "output.py", ], - deps = [ + deps=[ "//third_party/python:click", "//third_party/python:click-log", "//third_party/python:distlib", @@ -26,14 +26,14 @@ python_library( ) python_test( - name = "test", - timeout = 600, - srcs = [ + name="test", + timeout=600, + srcs=[ "__init___test.py", "wheel_test.py", ], - test_runner = "pytest", - deps = [ + test_runner="pytest", + deps=[ ":wheel", "//third_party/python:pytest", ], diff --git a/tools/wheel_resolver/__init__.py b/tools/wheel_resolver/__init__.py index 6052551a..910d637b 100644 --- a/tools/wheel_resolver/__init__.py +++ b/tools/wheel_resolver/__init__.py @@ -1,6 +1,5 @@ import click import typing -import requests import logging import click_log import sys @@ -71,7 +70,7 @@ ) @click_log.simple_verbosity_option(_LOGGER) def main( - url: typing.List[str], + url: typing.Tuple[str], package_name: str, package_version: typing.Optional[str], interpreter: typing.Tuple[str, ...], @@ -85,39 +84,39 @@ def main( PyPI for PACKAGE with VERSION. """ - for u in url: - response = requests.head(u) - if response.status_code != requests.codes.ok: - _LOGGER.warning( - "%s-%s is not available, tried %r", package_name, package_version, u - ) - else: - click.echo(u) - return + try: + output_name = output.get() + except output.OutputNotSetError: + _LOGGER.error("could not get $OUTS") + sys.exit(1) + + if output.download(package_name, package_version, url, output_name): + return - # We're currently hardcoding PyPI but we should consider allowing other - # repositories - # TODO (tm-jdelapuente): allow downloads from other package repositories locator = distlib.locators.SimpleScrapingLocator(url="https://pypi.org/simple") locator.wheel_tags = list(itertools.product(interpreter, abi, platform)) - u = wheel.url( - package_name=package_name, - package_version=package_version, - tags=[ - str(x) - for i in interpreter - for x in tags.generic_tags( - interpreter=i, - abis=set(abi), - platforms=set(platform).union({"any"}), - ) - ], - locator=locator, - prereleases=prereleases, - ) - - if not output.try_download(u): - _LOGGER.error("Could not download from %r", u) + try: + u = wheel.url( + package_name=package_name, + package_version=package_version, + tags=[ + str(x) + for i in interpreter + for x in tags.generic_tags( + interpreter=i, + abis=set(abi), + platforms=set(platform).union({"any"}), + ) + ], + locator=locator, + prereleases=prereleases, + ) + pypi_url = (u,) + except Exception as error: + _LOGGER.error(error) + _LOGGER.error(f"could not find PyPI URL for {package_name}-{package_version}") sys.exit(1) - click.echo(u) + if not output.download(package_name, package_version, pypi_url, output_name): + _LOGGER.error("could not download %s-%s", package_name, package_version) + sys.exit(1) diff --git a/tools/wheel_resolver/__init___test.py b/tools/wheel_resolver/__init___test.py index f00b3b0b..28ec7c76 100644 --- a/tools/wheel_resolver/__init___test.py +++ b/tools/wheel_resolver/__init___test.py @@ -1,8 +1,5 @@ import click.testing -import pytest import unittest.mock -import requests - import tools.wheel_resolver as sut @@ -12,15 +9,7 @@ def test_help(self) -> None: result = runner.invoke(cli=sut.main, args=["--help"]) assert result.exit_code == 0 - @unittest.mock.patch.object(sut.output, "try_download") - @unittest.mock.patch.object(sut.requests, "head") - def test_any_in_platforms( - self, _mock_requests_head: unittest.mock.MagicMock, _mock_try_download: unittest.mock.MagicMock - ) -> None: - _mock_try_download.return_value = True - _mock_requests_head.return_value = requests.Response() - _mock_requests_head.return_value.status_code = requests.codes.ok - + def test_any_in_platforms(self) -> None: runner = click.testing.CliRunner() with unittest.mock.patch.object(sut.wheel, "url") as mock_url: result = runner.invoke( @@ -38,4 +27,60 @@ def test_any_in_platforms( # Due to tags being a required keyword argument if "tags" in kwargs: assert any([t for t in kwargs["tags"] if t.endswith("any")]) - assert result.exit_code == 0 + + @unittest.mock.patch.object(sut, "_LOGGER") + @unittest.mock.patch.object(sut.output, "get") + def test_output_not_set_error( + self, + mock_output_get: unittest.mock.MagicMock, + mock_logger: unittest.mock.MagicMock, + ) -> None: + mock_output_get.side_effect = sut.output.OutputNotSetError + + runner = click.testing.CliRunner() + result = runner.invoke(cli=sut.main, args=["--package-name", "some-package"]) + + mock_logger.error.assert_called_once_with("could not get $OUTS") + assert result.exit_code == 1 + + @unittest.mock.patch.object(sut.output, "get") + @unittest.mock.patch.object(sut.output, "download") + @unittest.mock.patch.object(sut.wheel, "url") + @unittest.mock.patch.object(sut, "_LOGGER") + def test_wheel_url_error( + self, + mock_logger: unittest.mock.MagicMock, + mock_wheel_url: unittest.mock.MagicMock, + mock_output_download: unittest.mock.MagicMock, + mock_output_get: unittest.mock.MagicMock, + ) -> None: + + # Setup variables + package_name = "test-package" + package_version = "1.0.0" + exception_message = "Test Exception" + + # Set up mocks + mock_output_get.return_value = "output_name" + mock_output_download.return_value = False + mock_wheel_url.side_effect = Exception(exception_message) + + # Run the command + runner = click.testing.CliRunner() + result = runner.invoke( + cli=sut.main, + args=[ + "--package-name", + f"{package_name}", + "--package-version", + f"{package_version}", + ], + ) + + # Check that the error was logged correctly + mock_logger.error.assert_any_call( + f"could not find PyPI URL for {package_name}-{package_version}", + ) + + # Check that the program exited with an error code + assert result.exit_code == 1 diff --git a/tools/wheel_resolver/output.py b/tools/wheel_resolver/output.py index 1628c823..46f0b6ed 100644 --- a/tools/wheel_resolver/output.py +++ b/tools/wheel_resolver/output.py @@ -1,27 +1,37 @@ import os -import sys import urllib.request import logging +import typing _LOGGER = logging.getLogger(__name__) + class OutputNotSetError(RuntimeError): pass -def try_download(url): - """ - Try to download url to $OUTS. Returns false if - it failed. - """ + +def get() -> str: output = os.environ.get("OUTS") if output is None: raise OutputNotSetError() - - try: - urllib.request.urlretrieve(url, output) - except urllib.error.HTTPError: - return False - - return True + return output +def download( + package_name: str, + package_version: typing.Optional[str], + url: typing.Tuple[str], + download_output: str, +) -> bool: + """Download url to $OUTS.""" + for u in url: + try: + urllib.request.urlretrieve(u, download_output) + except urllib.error.HTTPError as error: + _LOGGER.warning( + f"download {package_name}-{package_version} from {u}: {error}", + ) + else: + _LOGGER.info(f"downloaded {package_name}-{package_version} from {u}") + return True + return False diff --git a/tools/wheel_resolver/wheel.py b/tools/wheel_resolver/wheel.py index 93134a11..0a38c168 100644 --- a/tools/wheel_resolver/wheel.py +++ b/tools/wheel_resolver/wheel.py @@ -1,7 +1,6 @@ import typing import distlib.locators import logging -import itertools _LOGGER = logging.getLogger(__name__)