Skip to content

Commit

Permalink
💥 change AllParam to Field.wildcard
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Nov 18, 2024
1 parent 96d1912 commit 1fce043
Show file tree
Hide file tree
Showing 9 changed files with 117 additions and 105 deletions.
1 change: 0 additions & 1 deletion src/arclet/alconna/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,5 @@
from .formatter import TextFormatter as TextFormatter
from .shortcut import ShortcutArgs as ShortcutArgs
from .manager import command_manager as command_manager
from .utils import AllParam as AllParam

__version__ = "1.8.31"
9 changes: 7 additions & 2 deletions src/arclet/alconna/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ class Field(Generic[_T]):
kw_only: bool = dc.field(default=False, compare=False, hash=False)
multiple: bool | int | Literal["+", "*", "str"] = dc.field(default=False, compare=False, hash=False)
kw_sep: str = dc.field(default="=", compare=False, hash=False)
wildcard: bool = dc.field(default=False, compare=False, hash=False)

@property
def display(self):
Expand Down Expand Up @@ -100,8 +101,9 @@ def arg_field(
kw_sep: str = "=",
optional: bool = False,
hidden: bool = False,
wildcard: bool = False,
) -> "Any":
return Field(default, default_factory, alias, completion, unmatch_tips, missing_tips, notice, seps, optional, hidden, kw_only, multiple, kw_sep)
return Field(default, default_factory, alias, completion, unmatch_tips, missing_tips, notice, seps, optional, hidden, kw_only, multiple, kw_sep, wildcard)


@dc.dataclass(**safe_dcls_kw(init=False, eq=True, unsafe_hash=True, slots=True))
Expand Down Expand Up @@ -153,7 +155,10 @@ def __init__(
setattr(self.field, k, v)

def __str__(self):
n, v = f"'{self.name_display}'", self.type_display
if self.field.wildcard:
v = n = f"...{self.name}"
else:
n, v = f"'{self.name_display}'", self.type_display
return (n if n == v else f"{n}: {v}") + (f" = '{self.field.display}'" if self.field.display is not Empty else "")

def __add__(self, other) -> "ArgsBuilder":
Expand Down
3 changes: 1 addition & 2 deletions src/arclet/alconna/formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from .i18n import i18n
from .args import Arg, _Args
from .base import Option, Subcommand
from .utils import AllParam
from .shortcut import InnerShortcutArgs

if TYPE_CHECKING:
Expand Down Expand Up @@ -154,7 +153,7 @@ def param(self, parameter: Arg) -> str:
return f"[{name}]" if parameter.field.optional else name
if parameter.field.hidden:
return f"[{name}]" if parameter.field.optional else f"<{name}>"
if parameter.type_ is AllParam:
if parameter.field.wildcard:
return f"<...{name}>"
arg = f"[{name}" if parameter.field.optional else f"<{name}"
if parameter.type_ not in (ANY, AnyString):
Expand Down
98 changes: 51 additions & 47 deletions src/arclet/alconna/ingedia/_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
PauseTriggered,
ParamsUnmatched,
)
from ..utils import _AllParamPattern, levenshtein
from ..utils import levenshtein

if TYPE_CHECKING:
from ._analyser import Analyser
Expand Down Expand Up @@ -51,10 +51,6 @@ def _context(argv: Argv, target: Arg[Any], _arg: str):
) from e


def _raise(target: Arg, arg: Any, res: Any):
raise InvalidParam(target.field.get_unmatch_tips(arg, res.error().args[0]), arg)


def _handle_arg(argv: Argv, target: Arg[Any], arg: Any, _str: bool):
value = target.type_
_arg = arg
Expand All @@ -76,38 +72,6 @@ def _handle_arg(argv: Argv, target: Arg[Any], arg: Any, _str: bool):
return res._value # noqa


def step(argv: Argv, ana: Analyser, arg: Arg[Any], result: dict[str, Any]):
field = arg.field
may_arg, _str = argv.next(field.seps)
if _str and may_arg in ana._unvisited and ((slot := ana._unvisited[may_arg])[1] not in ana.value_result and not slot[0].soft_keyword):
argv.rollback(may_arg)
may_arg = None
if may_arg is None or (_str and not may_arg):
if (de := arg.field.get_default()) is not Empty:
result[arg.name] = de
elif not field.optional:
raise ArgumentMissing(field.get_missing_tips(i18n.require("args.missing").format(key=arg.name)), arg)
return True
if field.kw_only:
if not _str:
raise InvalidParam(i18n.require("args.key_missing").format(target=may_arg, key=arg.name), arg)
key, _m_arg = split_once(may_arg, field.kw_sep, argv.filter_crlf)
key: str = pat.fullmatch(key)["name"] # type: ignore
if key != arg.name:
if levenshtein(key, arg.name) >= argv.fuzzy_threshold:
raise FuzzyMatchSuccess(i18n.require("fuzzy.matched").format(source=arg.name, target=key))
raise InvalidParam(i18n.require("args.key_not_found").format(name=key), arg)
if _m_arg:
may_arg = _m_arg
else:
may_arg, _str = argv.next(field.seps)
ans = _handle_arg(argv, arg, may_arg, _str)
if ans is Empty:
return True
result[arg.name] = ans
return True


def step_multiple(argv: Argv, ana: Analyser, arg: Arg[Any], result: dict[str, Any]):
field = arg.field
may_arg, _str = argv.next(field.seps)
Expand Down Expand Up @@ -154,6 +118,21 @@ def step_multiple(argv: Argv, ana: Analyser, arg: Arg[Any], result: dict[str, An
return False


def _handle_arg_wild(target: Arg[Any], arg: Any):
value = target.type_
_str = isinstance(arg, str)
if value is ANY or (value is STRING and _str):
return arg
if value is AnyString:
return str(arg)
res = value.execute(arg)
if res._value is Empty:
if target.field.optional:
return Empty
raise InvalidParam(target.field.get_unmatch_tips(arg, res.error().args[0]), target) # type: ignore
return res._value # noqa


def analyse_args(analyser: Analyser, argv: Argv, args: _Args) -> dict[str, Any]:
"""
分析 `_Args` 部分
Expand All @@ -170,18 +149,43 @@ def analyse_args(analyser: Analyser, argv: Argv, args: _Args) -> dict[str, Any]:
index = 0
while index < args.count:
arg = args.data[index]
if arg.type_.alias == "*":
if TYPE_CHECKING:
assert isinstance(arg.type_, _AllParamPattern)
if not arg.type_.types:
result[arg.name] = argv.converter(argv.release(no_split=True))
else:
data = [d for d in argv.release(no_split=True) if (res := arg.type_.execute(d)).success or (not arg.type_.ignore and _raise(arg, d, res))]
result[arg.name] = argv.converter(data)
field = arg.field
if field.wildcard:
data = [_handle_arg_wild(arg, d) for d in argv.release(no_split=True)]
data = [d for d in data if d is not Empty]
result[arg.name] = argv.converter(data)
argv.current_index = argv.ndata
return result
if arg.field.multiple is False:
index += step(argv, analyser, arg, result)
if field.multiple is False:
may_arg, _str = argv.next(field.seps)
if _str and may_arg in analyser._unvisited and ((slot := analyser._unvisited[may_arg])[1] not in analyser.value_result and not slot[0].soft_keyword):
argv.rollback(may_arg)
may_arg = None
if may_arg is None or (_str and not may_arg):
if (de := arg.field.get_default()) is not Empty:
result[arg.name] = de
elif not field.optional:
raise ArgumentMissing(field.get_missing_tips(i18n.require("args.missing").format(key=arg.name)), arg)
index += 1
continue
if field.kw_only:
if not _str:
raise InvalidParam(i18n.require("args.key_missing").format(target=may_arg, key=arg.name), arg)
key, _m_arg = split_once(may_arg, field.kw_sep, argv.filter_crlf)
key: str = pat.fullmatch(key)["name"] # type: ignore
if key != arg.name:
if levenshtein(key, arg.name) >= argv.fuzzy_threshold:
raise FuzzyMatchSuccess(i18n.require("fuzzy.matched").format(source=arg.name, target=key))
raise InvalidParam(i18n.require("args.key_not_found").format(name=key), arg)
if _m_arg:
may_arg = _m_arg
else:
may_arg, _str = argv.next(field.seps)
ans = _handle_arg(argv, arg, may_arg, _str)
if ans is not Empty:
result[arg.name] = ans
index += 1
continue
elif step_multiple(argv, analyser, arg, result):
if arg.field.kw_only:
result[arg.name] = dict(result[arg.name])
Expand Down
34 changes: 0 additions & 34 deletions src/arclet/alconna/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,40 +68,6 @@ def __len__(self) -> int: ...
TAValue: TypeAlias = Union[Pattern[T], Type[T], T, Callable[..., T], Dict[Any, T], List[T]]


@final
class _AllParamPattern(Pattern[T]):
def __init__(self, types: tuple[type[T1], ...] = (), ignore: bool = True):
self.types = types
self.ignore = ignore
super().__init__(alias="*")

def match(self, input_: Any) -> Any: # pragma: no cover
if not self.types:
return input_
if generic_isinstance(input_, self.types): # type: ignore
return input_
raise MatchFailed(
lang.require("nepattern", "error.type").format(
type=input_.__class__.__name__, target=input_, expected=" | ".join(map(lambda t: t.__name__, self.types))
)
)

@overload
def __call__(self, *, ignore: bool = True) -> _AllParamPattern[Any]: ...

@overload
def __call__(self, *types: type[T1], ignore: bool = True) -> _AllParamPattern[T1]: ...

def __call__(self, *types: type[T1], ignore: bool = True) -> _AllParamPattern[T1]:
return _AllParamPattern(types, ignore)

def __eq__(self, other): # pragma: no cover
return other.__class__ is _AllParamPattern


AllParam: _AllParamPattern[Any] = _AllParamPattern()


class KWBool(Pattern):
"""对布尔参数的包装"""

Expand Down
2 changes: 1 addition & 1 deletion src/arclet/alconna/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
from arclet.alconna.formatter import TextFormatter as TextFormatter # noqa: F401
from arclet.alconna.manager import ShortcutArgs as ShortcutArgs # noqa: F401
from arclet.alconna.manager import command_manager as command_manager # noqa: F401
from arclet.alconna.utils import AllParam as AllParam # noqa: F401

from .args import ArgFlag as ArgFlag
from .args import Args as Args
Expand All @@ -49,6 +48,7 @@
from .stub import OptionStub as OptionStub
from .stub import SubcommandStub as SubcommandStub

from .typing import AllParam as AllParam
from .typing import KeyWordVar as KeyWordVar
from .typing import Kw as Kw
from .typing import MultiVar as MultiVar
Expand Down
8 changes: 5 additions & 3 deletions src/arclet/alconna/v1/args.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
from __future__ import annotations

import warnings
from enum import Enum
from typing import Any, Final, Iterable

from nepattern import Pattern
from nepattern import Pattern, ANY, UnionPattern
from tarina import Empty
from typing_extensions import Self, deprecated

from arclet.alconna.args import ArgsBuilder, Arg
from arclet.alconna.utils import TAValue

from .typing import KeyWordVar, MultiVar, _StrMulti, UnpackVar
from .typing import KeyWordVar, MultiVar, _StrMulti, UnpackVar, _AllParamPattern


class ArgFlag(str, Enum):
Expand Down Expand Up @@ -52,6 +51,9 @@ def build(self):
arg.type_ = value.base
elif isinstance(value, UnpackVar):
arg.type_ = Pattern(value.origin)
elif isinstance(value, _AllParamPattern):
arg.type_ = UnionPattern.of(value.types) if value.types else ANY # type: ignore
arg.field.wildcard = True
return super().build()

def __truediv__(self, other) -> Self:
Expand Down
42 changes: 40 additions & 2 deletions src/arclet/alconna/v1/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,54 @@
from typing import (
Any,
Literal,
TypeVar,
TypeVar, overload, final,
)

from tarina import lang
from typing_extensions import deprecated

from nepattern import Pattern, parser
from nepattern import Pattern, parser, MatchFailed

from arclet.alconna.utils import TAValue
from arclet.alconna.utils import KWBool as KWBool # type: ignore[misc]

T = TypeVar("T")
T1 = TypeVar("T1")


@final
@deprecated("AllParam is deprecated, use `Field(wildcard=True)` instead", category=DeprecationWarning, stacklevel=1)
class _AllParamPattern(Pattern[T]):
def __init__(self, types: tuple[type[T1], ...] = (), ignore: bool = True):
self.types = types
self.ignore = ignore
super().__init__(alias="*")

def match(self, input_: Any) -> Any: # pragma: no cover
if not self.types:
return input_
if generic_isinstance(input_, self.types): # type: ignore
return input_
raise MatchFailed(
lang.require("nepattern", "error.type").format(
type=input_.__class__.__name__, target=input_, expected=" | ".join(map(lambda t: t.__name__, self.types))
)
)

@overload
def __call__(self, *, ignore: bool = True) -> _AllParamPattern[Any]: ...

@overload
def __call__(self, *types: type[T1], ignore: bool = True) -> _AllParamPattern[T1]: ...

def __call__(self, *types: type[T1], ignore: bool = True) -> _AllParamPattern[T1]:
return _AllParamPattern(types, ignore)

def __eq__(self, other): # pragma: no cover
return other.__class__ is _AllParamPattern


AllParam: _AllParamPattern[Any] = _AllParamPattern()


@deprecated("KeyWordVar is deprecated, use `Field(kw_only=True)` instead", category=DeprecationWarning, stacklevel=1)
Expand Down
Loading

0 comments on commit 1fce043

Please sign in to comment.