Skip to content

Commit

Permalink
fix handling for when Annotated is inside another type; namely list[A…
Browse files Browse the repository at this point in the history
…nnotated[...]].
  • Loading branch information
BrianPugh committed Jan 7, 2025
1 parent 2b6c5ea commit a898634
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 5 deletions.
3 changes: 2 additions & 1 deletion cyclopts/_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,8 @@ def _convert(
out = convert_tuple(type_, token, converter=converter)
else:
out = convert_tuple(type_, *token, converter=converter)
elif origin_type in ITERABLE_TYPES: # NOT including tuple
elif origin_type in ITERABLE_TYPES:
# NOT including tuple; handled in ``origin_type is tuple`` body above.
count, _ = token_count(inner_types[0])
if not isinstance(token, Sequence):
raise ValueError
Expand Down
33 changes: 29 additions & 4 deletions cyclopts/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import cyclopts.utils
from cyclopts._convert import (
ITERABLE_TYPES,
convert,
token_count,
)
Expand Down Expand Up @@ -112,6 +113,16 @@ def _identity_converter(type_, token):
return token


def _get_parameters(hint: Any) -> tuple[Any, list[Parameter]]:
"""At root level, checks for cyclopts.Parameter annotations."""
if is_annotated(hint):
inner = get_args(hint)
hint = inner[0]
return hint, [x for x in inner[1:] if isinstance(x, Parameter)]
else:
return hint, []


class ArgumentCollection(list["Argument"]):
"""A list-like container for :class:`Argument`."""

Expand Down Expand Up @@ -206,10 +217,23 @@ def _from_type(
cyclopts_parameters_no_group = []

hint = field_info.hint
if is_annotated(hint):
annotations = hint.__metadata__ # pyright: ignore
hint = get_args(hint)[0]
cyclopts_parameters_no_group.extend(x for x in annotations if isinstance(x, Parameter))
hint, hint_parameters = _get_parameters(hint)
cyclopts_parameters_no_group.extend(hint_parameters)

# Handle annotations where ``Annotated`` is not at the root level; e.g. ``list[Annotated[...]]``.
# Multiple inner Parameter Annotations only make sense if providing specific converter/validators.
origin = get_origin(hint)
if origin is tuple:
# handled in _convert.py
pass
elif origin in ITERABLE_TYPES:
inner_hints = get_args(hint)
if len(inner_hints) > 1:
raise NotImplementedError(f"Did not expect multiple inner type arguments: {inner_hints}.")
elif len(inner_hints) == 1:
inner_hint = inner_hints[0]
_, hint_parameters = _get_parameters(inner_hint)
cyclopts_parameters_no_group.extend(hint_parameters)

if not keys: # root hint annotation
if field_info.kind is field_info.VAR_KEYWORD:
Expand Down Expand Up @@ -245,6 +269,7 @@ def _from_type(
# if not immediate_parameter.parse:
# return out

# resolve/derive the parameter name
if keys:
cparam = Parameter.combine(
upstream_parameter,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ convention = "numpy"
"D106",
"D107",
"D205",
"D400",
"D404",
"S102", # use of "exec"
"S106", # possible hardcoded password.
Expand Down
22 changes: 22 additions & 0 deletions tests/types/test_types_number.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

import pytest

from cyclopts.exceptions import ValidationError
Expand All @@ -18,3 +20,23 @@ def default(color: tuple[UInt8, UInt8, UInt8] = (0x00, 0x00, 0x00)):
with pytest.raises(ValidationError) as e:
app.parse_args("--color 100 200 300", exit_on_error=False)
assert str(e.value) == 'Invalid value "300" for "--color". Must be <= 255.'


def test_nested_list_annotated_validator(app, assert_parse_args):
@app.default
def default(color: Optional[list[tuple[UInt8, UInt8, UInt8]]] = None):
pass

assert_parse_args(
default,
"0x12 0x34 0x56 0x78 0x90 0xAB",
[(0x12, 0x34, 0x56), (0x78, 0x90, 0xAB)],
)

with pytest.raises(ValidationError) as e:
app.parse_args("100 200 300", exit_on_error=False)
assert str(e.value) == 'Invalid value "300" for "COLOR". Must be <= 255.'

with pytest.raises(ValidationError) as e:
app.parse_args("--color 100 200 300", exit_on_error=False)
assert str(e.value) == 'Invalid value "300" for "--color". Must be <= 255.'
22 changes: 22 additions & 0 deletions tests/types/test_types_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,28 @@ def test_types_existing_file(convert, tmp_file):
assert tmp_file == convert(ct.ExistingFile, tmp_file)


def test_types_existing_file_app(app):
"""https://github.com/BrianPugh/cyclopts/issues/287"""

@app.default
def main(f: ct.ExistingFile):
pass

with pytest.raises(ValidationError):
app(["this-file-does-not-exist"], exit_on_error=False)


def test_types_existing_file_app_list(app):
"""https://github.com/BrianPugh/cyclopts/issues/287"""

@app.default
def main(f: list[ct.ExistingFile]):
pass

with pytest.raises(ValidationError):
app(["this-file-does-not-exist"], exit_on_error=False)


def test_types_existing_file_validation_error(convert, tmp_path):
with pytest.raises(ValidationError):
convert(ct.ExistingFile, tmp_path)
Expand Down

0 comments on commit a898634

Please sign in to comment.