Skip to content

Commit

Permalink
Refactor wheel_resolver to allow for non-PyPI downloads (#167)
Browse files Browse the repository at this point in the history
* Refactor wheel_resolver to download from non-PyPI URLs

* improve logging and error handling

* Improve logging

* Test new exit points of refactored main
  • Loading branch information
tm-jdelapuente authored Aug 20, 2024
1 parent 2a70346 commit 01767c8
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 72 deletions.
24 changes: 12 additions & 12 deletions tools/wheel_resolver/BUILD
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
subinclude("//build_defs:python")

python_binary(
name = "wheel_resolver",
main = "main.py",
visibility = ["PUBLIC"],
deps = [
name="wheel_resolver",
main="main.py",
visibility=["PUBLIC"],
deps=[
":wheel",
],
)

python_library(
name = "wheel",
srcs = [
name="wheel",
srcs=[
"__init__.py",
"wheel.py",
"output.py",
],
deps = [
deps=[
"//third_party/python:click",
"//third_party/python:click-log",
"//third_party/python:distlib",
Expand All @@ -26,14 +26,14 @@ python_library(
)

python_test(
name = "test",
timeout = 600,
srcs = [
name="test",
timeout=600,
srcs=[
"__init___test.py",
"wheel_test.py",
],
test_runner = "pytest",
deps = [
test_runner="pytest",
deps=[
":wheel",
"//third_party/python:pytest",
],
Expand Down
65 changes: 32 additions & 33 deletions tools/wheel_resolver/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import click
import typing
import requests
import logging
import click_log
import sys
Expand Down Expand Up @@ -71,7 +70,7 @@
)
@click_log.simple_verbosity_option(_LOGGER)
def main(
url: typing.List[str],
url: typing.Tuple[str],
package_name: str,
package_version: typing.Optional[str],
interpreter: typing.Tuple[str, ...],
Expand All @@ -85,39 +84,39 @@ def main(
PyPI for PACKAGE with VERSION.
"""
for u in url:
response = requests.head(u)
if response.status_code != requests.codes.ok:
_LOGGER.warning(
"%s-%s is not available, tried %r", package_name, package_version, u
)
else:
click.echo(u)
return
try:
output_name = output.get()
except output.OutputNotSetError:
_LOGGER.error("could not get $OUTS")
sys.exit(1)

if output.download(package_name, package_version, url, output_name):
return

# We're currently hardcoding PyPI but we should consider allowing other
# repositories
# TODO (tm-jdelapuente): allow downloads from other package repositories
locator = distlib.locators.SimpleScrapingLocator(url="https://pypi.org/simple")
locator.wheel_tags = list(itertools.product(interpreter, abi, platform))
u = wheel.url(
package_name=package_name,
package_version=package_version,
tags=[
str(x)
for i in interpreter
for x in tags.generic_tags(
interpreter=i,
abis=set(abi),
platforms=set(platform).union({"any"}),
)
],
locator=locator,
prereleases=prereleases,
)

if not output.try_download(u):
_LOGGER.error("Could not download from %r", u)
try:
u = wheel.url(
package_name=package_name,
package_version=package_version,
tags=[
str(x)
for i in interpreter
for x in tags.generic_tags(
interpreter=i,
abis=set(abi),
platforms=set(platform).union({"any"}),
)
],
locator=locator,
prereleases=prereleases,
)
pypi_url = (u,)
except Exception as error:
_LOGGER.error(error)
_LOGGER.error(f"could not find PyPI URL for {package_name}-{package_version}")
sys.exit(1)

click.echo(u)
if not output.download(package_name, package_version, pypi_url, output_name):
_LOGGER.error("could not download %s-%s", package_name, package_version)
sys.exit(1)
71 changes: 58 additions & 13 deletions tools/wheel_resolver/__init___test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
import click.testing
import pytest
import unittest.mock
import requests

import tools.wheel_resolver as sut


Expand All @@ -12,15 +9,7 @@ def test_help(self) -> None:
result = runner.invoke(cli=sut.main, args=["--help"])
assert result.exit_code == 0

@unittest.mock.patch.object(sut.output, "try_download")
@unittest.mock.patch.object(sut.requests, "head")
def test_any_in_platforms(
self, _mock_requests_head: unittest.mock.MagicMock, _mock_try_download: unittest.mock.MagicMock
) -> None:
_mock_try_download.return_value = True
_mock_requests_head.return_value = requests.Response()
_mock_requests_head.return_value.status_code = requests.codes.ok

def test_any_in_platforms(self) -> None:
runner = click.testing.CliRunner()
with unittest.mock.patch.object(sut.wheel, "url") as mock_url:
result = runner.invoke(
Expand All @@ -38,4 +27,60 @@ def test_any_in_platforms(
# Due to tags being a required keyword argument
if "tags" in kwargs:
assert any([t for t in kwargs["tags"] if t.endswith("any")])
assert result.exit_code == 0

@unittest.mock.patch.object(sut, "_LOGGER")
@unittest.mock.patch.object(sut.output, "get")
def test_output_not_set_error(
self,
mock_output_get: unittest.mock.MagicMock,
mock_logger: unittest.mock.MagicMock,
) -> None:
mock_output_get.side_effect = sut.output.OutputNotSetError

runner = click.testing.CliRunner()
result = runner.invoke(cli=sut.main, args=["--package-name", "some-package"])

mock_logger.error.assert_called_once_with("could not get $OUTS")
assert result.exit_code == 1

@unittest.mock.patch.object(sut.output, "get")
@unittest.mock.patch.object(sut.output, "download")
@unittest.mock.patch.object(sut.wheel, "url")
@unittest.mock.patch.object(sut, "_LOGGER")
def test_wheel_url_error(
self,
mock_logger: unittest.mock.MagicMock,
mock_wheel_url: unittest.mock.MagicMock,
mock_output_download: unittest.mock.MagicMock,
mock_output_get: unittest.mock.MagicMock,
) -> None:

# Setup variables
package_name = "test-package"
package_version = "1.0.0"
exception_message = "Test Exception"

# Set up mocks
mock_output_get.return_value = "output_name"
mock_output_download.return_value = False
mock_wheel_url.side_effect = Exception(exception_message)

# Run the command
runner = click.testing.CliRunner()
result = runner.invoke(
cli=sut.main,
args=[
"--package-name",
f"{package_name}",
"--package-version",
f"{package_version}",
],
)

# Check that the error was logged correctly
mock_logger.error.assert_any_call(
f"could not find PyPI URL for {package_name}-{package_version}",
)

# Check that the program exited with an error code
assert result.exit_code == 1
36 changes: 23 additions & 13 deletions tools/wheel_resolver/output.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,37 @@
import os
import sys
import urllib.request
import logging
import typing

_LOGGER = logging.getLogger(__name__)


class OutputNotSetError(RuntimeError):
pass

def try_download(url):
"""
Try to download url to $OUTS. Returns false if
it failed.
"""

def get() -> str:
output = os.environ.get("OUTS")
if output is None:
raise OutputNotSetError()

try:
urllib.request.urlretrieve(url, output)
except urllib.error.HTTPError:
return False

return True
return output


def download(
package_name: str,
package_version: typing.Optional[str],
url: typing.Tuple[str],
download_output: str,
) -> bool:
"""Download url to $OUTS."""
for u in url:
try:
urllib.request.urlretrieve(u, download_output)
except urllib.error.HTTPError as error:
_LOGGER.warning(
f"download {package_name}-{package_version} from {u}: {error}",
)
else:
_LOGGER.info(f"downloaded {package_name}-{package_version} from {u}")
return True
return False
1 change: 0 additions & 1 deletion tools/wheel_resolver/wheel.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import typing
import distlib.locators
import logging
import itertools

_LOGGER = logging.getLogger(__name__)

Expand Down

0 comments on commit 01767c8

Please sign in to comment.