From 89f5cc4c416dad5b25ea55af10989a52961b1b4e Mon Sep 17 00:00:00 2001 From: antonio Date: Wed, 13 Nov 2024 22:34:18 -0500 Subject: [PATCH] update flags to support strict/port options --- src/luxos/cli/flags.py | 26 ++++++++++++++++++-------- tests/conftest.py | 15 +++++++++++++++ tests/test_cli_flags.py | 17 +++++++---------- 3 files changed, 40 insertions(+), 18 deletions(-) diff --git a/src/luxos/cli/flags.py b/src/luxos/cli/flags.py index 79179a1..b9c9459 100644 --- a/src/luxos/cli/flags.py +++ b/src/luxos/cli/flags.py @@ -45,26 +45,36 @@ class type_ipaddress(ArgumentTypeBase): options = parser.parse_args() ... - assert options.x == ("host", 9999) + assert options.x == ("1.2.3.4", 9999) shell:: - file.py -x host:9999 + file.py -x 1.2.3.4:9999 """ + def __init__(self, port=None, strict=True): + super().__init__() + self.strict = strict + self.port = port + def validate(self, txt) -> None | tuple[str, None | int]: from luxos import ips if txt is None: return None try: - result = ips.parse_expr(txt) or ("", "", None) - if result[1]: - raise argparse.ArgumentTypeError("cannot use a range as expression") - return (result[0], result[2]) - except ips.AddressParsingError as exc: - raise argparse.ArgumentTypeError(f"failed to parse {txt=}: {exc.args[0]}") + if txt.count(":") not in {0, 1}: + raise ValueError("too many ':' (none or one)") + if self.strict: + ip, port = ips.splitip(txt) or ("", self.port) + else: + ip, _, port = txt.partition(":") + return ip, int(port) if port else self.port + except (RuntimeError, ValueError) as exc: + raise argparse.ArgumentTypeError( + f"failed to convert to a strict ip address: {exc.args[0]}" + ) def type_range(txt: str) -> Sequence[tuple[str, int | None]]: diff --git a/tests/conftest.py b/tests/conftest.py index 2f56010..581107f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -157,6 +157,21 @@ def port(miner_host_port): return miner_host_port[1] +@pytest.fixture(scope="function") +def cli_generator(): + from luxos.cli.shared import LuxosParserBase + + class MyParser(LuxosParserBase): + def __init__(self): + super().__init__([], exit_on_error=False) + + def add_argument(self, *args, **kwargs): + super().add_argument(*args, **kwargs) + return self + + return lambda: MyParser() + + def pytest_addoption(parser): parser.addoption( "--manual", diff --git a/tests/test_cli_flags.py b/tests/test_cli_flags.py index bce4c7d..09402fc 100644 --- a/tests/test_cli_flags.py +++ b/tests/test_cli_flags.py @@ -42,15 +42,12 @@ def test_type_range(resolver): def test_type_hhmm(): - assert flags.type_hhmm("12:34").default == datetime.time(12, 34) + assert flags.type_hhmm().validate("12:34") == datetime.time(12, 34) - pytest.raises(RuntimeError, flags.type_hhmm, "12") - pytest.raises(argparse.ArgumentTypeError, flags.type_hhmm(), "12") + with pytest.raises(argparse.ArgumentTypeError) as e: + flags.type_hhmm().validate("12") + assert e.value.args[-1] == "failed conversion into HH:MM for '12'" - -def test_type_ipaddress(): - assert flags.type_ipaddress("hello").default == ("hello", None) - assert flags.type_ipaddress("hello:123").default == ("hello", 123) - - pytest.raises(RuntimeError, flags.type_ipaddress, "12:dwedwe") - pytest.raises(argparse.ArgumentTypeError, flags.type_ipaddress(), "12:dwedwe") + with pytest.raises(argparse.ArgumentTypeError) as e: + flags.type_hhmm().validate("hello") + assert e.value.args[-1] == "failed conversion into HH:MM for 'hello'"