Skip to content

Commit

Permalink
Merge pull request #251 from BrianPugh/bugfix/list-of-bool
Browse files Browse the repository at this point in the history
Allow for list[bool] and similar (list of flags).
  • Loading branch information
BrianPugh authored Nov 19, 2024
2 parents 657b72b + a16ba52 commit 1d2a550
Show file tree
Hide file tree
Showing 4 changed files with 76 additions and 15 deletions.
17 changes: 14 additions & 3 deletions cyclopts/_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,22 @@


_implicit_iterable_type_mapping: dict[type, type] = {
Iterable: list[str],
Sequence: list[str],
frozenset: frozenset[str],
list: list[str],
set: set[str],
tuple: tuple[str, ...],
dict: dict[str, str],
}

ITERABLE_TYPES = {list, set, frozenset, Sequence, Iterable, tuple}
ITERABLE_TYPES = {
Iterable,
Sequence,
frozenset,
list,
set,
tuple,
}

NestedCliArgs = dict[str, Union[Sequence[str], "NestedCliArgs"]]

Expand Down Expand Up @@ -161,7 +170,9 @@ def _convert(
origin_type = get_origin(type_)
inner_types = [resolve(x) for x in get_args(type_)]

if type_ in _implicit_iterable_type_mapping:
if type_ is dict:
out = convert(dict[str, str], token)
elif type_ in _implicit_iterable_type_mapping:
out = convert(_implicit_iterable_type_mapping[type_], token)
elif origin_type in (collections.abc.Iterable, collections.abc.Sequence):
assert len(inner_types) == 1
Expand Down
16 changes: 11 additions & 5 deletions cyclopts/argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
get_field_infos,
)
from cyclopts.group import Group
from cyclopts.parameter import Parameter
from cyclopts.parameter import ITERATIVE_BOOL_IMPLICIT_VALUE, Parameter
from cyclopts.token import Token
from cyclopts.utils import UNSET, ParameterDict, grouper, is_builtin

Expand Down Expand Up @@ -827,7 +827,7 @@ def _match_name(
name = transform(name)
if _startswith(term, name):
trailing = term[len(name) :]
implicit_value = True if self.hint is bool else None
implicit_value = True if self.hint is bool or self.hint in ITERATIVE_BOOL_IMPLICIT_VALUE else None
if trailing:
if trailing[0] == delimiter:
trailing = trailing[1:]
Expand All @@ -843,7 +843,10 @@ def _match_name(
name = transform(name)
if term.startswith(name):
trailing = term[len(name) :]
implicit_value = (get_origin(self.hint) or self.hint)()
if self.hint in ITERATIVE_BOOL_IMPLICIT_VALUE:
implicit_value = False
else:
implicit_value = (get_origin(self.hint) or self.hint)()
if trailing:
if trailing[0] == delimiter:
trailing = trailing[1:]
Expand Down Expand Up @@ -917,8 +920,11 @@ def safe_converter(hint, tokens):
keyword = {}
for token in self.tokens:
if token.implicit_value is not UNSET:
assert len(self.tokens) == 1
return token.implicit_value
if self.hint in ITERATIVE_BOOL_IMPLICIT_VALUE:
return get_origin(self.hint)(x.implicit_value for x in self.tokens)
else:
assert len(self.tokens) == 1
return token.implicit_value

if token.keys:
lookup = keyword
Expand Down
33 changes: 26 additions & 7 deletions cyclopts/parameter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import inspect
from collections.abc import Iterable
from functools import partial
from typing import Any, Callable, Optional, Union, cast, get_args, get_origin
from typing import Any, Callable, List, Optional, Sequence, Tuple, Union, cast, get_args, get_origin

import attrs
from attrs import field, frozen
Expand All @@ -18,7 +18,19 @@
to_tuple_converter,
)

_NEGATIVE_FLAG_TYPES = frozenset([bool, *ITERABLE_TYPES])
ITERATIVE_BOOL_IMPLICIT_VALUE = frozenset(
{
Iterable[bool],
Sequence[bool],
List[bool],
list[bool],
Tuple[bool, ...],
tuple[bool, ...],
}
)


_NEGATIVE_FLAG_TYPES = frozenset([bool, *ITERABLE_TYPES, *ITERATIVE_BOOL_IMPLICIT_VALUE])


def _not_hyphen_validator(instance, attribute, values):
Expand Down Expand Up @@ -85,7 +97,7 @@ def main(foo: Annotated[int, Parameter(name="bar")]):
converter=lambda x: cast(tuple[Callable, ...], to_tuple_converter(x)),
)

# This can ONLY ever be a Tuple[str, ...]
# This can ONLY ever be ``None`` or ``Tuple[str, ...]``
negative: Union[None, str, Iterable[str]] = field(default=None, converter=optional_to_tuple_converter)

# This can ONLY ever be a Tuple[Union[Group, str], ...]
Expand Down Expand Up @@ -162,10 +174,14 @@ def get_negatives(self, type_) -> tuple[str, ...]:
if is_union(type_):
type_ = next(x for x in get_args(type_) if x is not None)

type_ = get_origin(type_) or type_
origin = get_origin(type_)

if (self.negative is not None and not self.negative) or type_ not in _NEGATIVE_FLAG_TYPES:
return ()
if type_ not in _NEGATIVE_FLAG_TYPES:
if origin:
if origin not in _NEGATIVE_FLAG_TYPES:
return ()
else:
return ()

out, user_negatives = [], []
if self.negative:
Expand All @@ -182,7 +198,10 @@ def get_negatives(self, type_) -> tuple[str, ...]:
name = name[2:]
name_components = name.split(".")

negative_prefixes = self.negative_bool if type_ is bool else self.negative_iterable
if type_ is bool or type_ in ITERATIVE_BOOL_IMPLICIT_VALUE:
negative_prefixes = self.negative_bool
else:
negative_prefixes = self.negative_iterable
name_prefix = ".".join(name_components[:-1])
if name_prefix:
name_prefix += "."
Expand Down
25 changes: 25 additions & 0 deletions tests/test_bind_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,31 @@ def foo(a: Optional[List[int]] = None):
assert_parse_args(foo, "foo")


@pytest.mark.parametrize(
"cmd_expected",
[
("", None),
("--verbose", [True]),
("--verbose --verbose", [True, True]),
("--verbose --verbose --no-verbose", [True, True, False]),
("--verbose --verbose=False", [True, False]),
("--verbose --no-verbose=False", [True, True]),
("--verbose --verbose=True", [True, True]),
],
)
def test_keyword_list_of_bool(app, assert_parse_args, cmd_expected):
cmd, expected = cmd_expected

@app.default
def foo(*, verbose: Optional[list[bool]] = None):
pass

if expected is None:
assert_parse_args(foo, cmd)
else:
assert_parse_args(foo, cmd, verbose=expected)


@pytest.mark.parametrize(
"cmd",
[
Expand Down

0 comments on commit 1d2a550

Please sign in to comment.