Skip to content

Commit

Permalink
⬆️ upgrade NEPattern to v1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Oct 26, 2024
1 parent a1fe56f commit 87f1bbe
Show file tree
Hide file tree
Showing 15 changed files with 74 additions and 80 deletions.
2 changes: 1 addition & 1 deletion benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __repr__(self):
alc = Alconna(
["."],
"test",
Args["bar", ANY]
Args.bar(ANY)
)

analyser = command_manager.require(alc)
Expand Down
2 changes: 1 addition & 1 deletion commander/__main__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from arclet.alconna import Alconna, Args, command_manager
from arclet.alconna.v1 import Alconna, Args, command_manager
from commander import Commands

command = Commands()
Expand Down
16 changes: 8 additions & 8 deletions exam7.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Literal, Optional, overload
from typing_extensions import Self

from arclet.alconna import Args
from arclet.alconna.args import ARGS_PARAM, Args, _Args, handle_args
from arclet.alconna.action import Action, store, store_true


Expand All @@ -20,7 +20,7 @@ class Scope:
@dataclass(eq=True, unsafe_hash=True)
class Node:
name: str
args: Args = field(default_factory=Args)
args: _Args = field(default_factory=lambda : _Args([]))
action: Action = field(default=store)
help_text: str = field(default="unknown")
dest: str = field(default="")
Expand All @@ -43,7 +43,7 @@ def path(self):
return f"{self.scope}.{self.name}" if self.scope != "$" else self.name

@overload
def assign(self, path: Literal[":args"], *, args: Args) -> Self:
def assign(self, path: Literal[":args"], *, args: ARGS_PARAM) -> Self:
...

@overload
Expand All @@ -59,7 +59,7 @@ def assign(self, path: Literal[":dest"], *, dest: str) -> Self:
...

@overload
def assign(self, path: str, spec: Literal[":args"], *, args: Args) -> Self:
def assign(self, path: str, spec: Literal[":args"], *, args: ARGS_PARAM) -> Self:
...

@overload
Expand All @@ -79,7 +79,7 @@ def assign(
self,
path: str,
*,
args: Optional[Args] = None,
args: Optional[ARGS_PARAM] = None,
action: Optional[Action] = None,
help_text: Optional[str] = None,
dest: Optional[str] = None
Expand All @@ -90,7 +90,7 @@ def assign(
self,
path: str,
spec: Optional[str] = None,
args: Optional[Args] = None,
args: Optional[ARGS_PARAM] = None,
action: Optional[Action] = None,
help_text: Optional[str] = None,
dest: Optional[str] = None
Expand All @@ -112,7 +112,7 @@ def assign(
if part not in NodeMap[prev.path].substance:
raise ValueError(f"Unknown node {part}")
prev = NodeMap[prev.path].substance[part]
new = Node(parts[-1], args, action, help_text, dest, prev.path)
new = Node(parts[-1], handle_args(args), action, help_text, dest, prev.path)
NodeMap[prev.path].substance[new.name] = new
return self

Expand All @@ -131,7 +131,7 @@ def select(self, path: str) -> "Node":

node = Node("root")
node.assign("foo")
node.assign("foo", ":args", args=Args["foo", int]["bar", str])
node.assign("foo", ":args", args=Args.foo(int).bar(str))
foo = node.select("foo")
foo.assign(":action", action=store_true)
bar = foo.assign("bar", help_text="bar").select("bar")
Expand Down
24 changes: 12 additions & 12 deletions pdm.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ authors = [
]
dependencies = [
"typing-extensions>=4.5.0",
"nepattern<1.0.0,>=0.7.6",
"nepattern<1.1.0,>=1.0.0",
"tarina<0.7.0,>=0.6.1",
"elaina-segment>=0.4.0",
"elaina-flywheel>=0.6.0",
Expand Down
4 changes: 2 additions & 2 deletions src/arclet/alconna/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import Any, Callable, Generic, Literal, TypeVar, ClassVar, ForwardRef, Final, TYPE_CHECKING, get_origin, get_args
from typing_extensions import dataclass_transform, ParamSpec, Concatenate, TypeAlias

from nepattern import NONE, BasePattern, RawStr, UnionPattern, parser
from nepattern import NONE, Pattern, RawStr, UnionPattern, parser
from tarina import Empty, lang

from ._dcls import safe_dcls_kw, safe_field_kw
Expand Down Expand Up @@ -105,7 +105,7 @@ class Arg(Generic[_T]):

name: str = dc.field(compare=True, hash=True)
"""参数单元的名称"""
type_: BasePattern[_T, Any, Any] = dc.field(compare=False, hash=True)
type_: Pattern[_T] = dc.field(compare=False, hash=True)
"""参数单元的类型"""
field: Field[_T] = dc.field(compare=False, hash=False)
"""参数单元的字段"""
Expand Down
22 changes: 12 additions & 10 deletions src/arclet/alconna/ingedia/_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import re
from typing import TYPE_CHECKING, Any, Iterable, Literal

from nepattern import ANY, STRING, AnyString, BasePattern
from nepattern import ANY, STRING, AnyString, Pattern
from tarina import Empty, lang, safe_eval, split_once

from ..action import Action
Expand Down Expand Up @@ -53,7 +53,7 @@ def _context(argv: Argv, target: Arg[Any], _arg: str):
)


def _validate(argv: Argv, target: Arg[Any], value: BasePattern[Any, Any, Any], result: dict[str, Any], arg: Any, _str: bool):
def _validate(argv: Argv, target: Arg[Any], value: Pattern[Any], result: dict[str, Any], arg: Any, _str: bool):
_arg = arg
if _str and argv.context_style:
_arg = _context(argv, target, _arg)
Expand All @@ -64,13 +64,15 @@ def _validate(argv: Argv, target: Arg[Any], value: BasePattern[Any, Any, Any], r
result[target.name] = str(_arg)
return
default_val = target.field.default
res = value.validate(_arg, default_val)
if res.flag != "valid":
res = value.execute(_arg)
if res._value is Empty:
argv.rollback(arg)
if res.flag == "error":
if default_val is not Empty:
result[target.name] = default_val
return
if target.field.optional:
return
raise InvalidParam(target.field.get_unmatch_tips(arg, res.error().args[0]), target)
raise InvalidParam(target.field.get_unmatch_tips(arg, res.error().args[0]), target) # type: ignore
result[target.name] = res._value # noqa


Expand All @@ -96,7 +98,7 @@ def step_varpos(argv: Argv, args: _Args, slot: tuple[int | Literal["+", "*", "st
if _str and args.vars_keyword and args.vars_keyword[0][1].field.kw_sep in may_arg:
argv.rollback(may_arg)
break
if (res := value.validate(may_arg)).flag != "valid":
if not (res := value.execute(may_arg)).success:
argv.rollback(may_arg)
break
_result.append(res._value) # noqa
Expand Down Expand Up @@ -140,7 +142,7 @@ def step_varkey(argv: Argv, slot: tuple[int | Literal["+", "*", "str"], Arg], re
key = _kwarg[1]
if not (_m_arg := _kwarg[2]):
_m_arg, _ = argv.next(arg.field.seps)
if (res := value.validate(_m_arg)).flag != "valid":
if not (res := value.execute(_m_arg)).success:
argv.rollback(may_arg)
break
_result[key] = res._value # noqa
Expand Down Expand Up @@ -185,7 +187,7 @@ def step_keyword(argv: Argv, args: _Args, result: dict[str, Any]):
):
break
for arg in args.keyword_only.values():
if arg.type_.validate(may_arg).flag == "valid":
if arg.type_.execute(may_arg).success:
raise InvalidParam(lang.require("args", "key_missing").format(target=may_arg, key=arg.name), arg)
for name in args.keyword_only:
if levenshtein(_key, name) >= argv.fuzzy_threshold:
Expand Down Expand Up @@ -253,7 +255,7 @@ def analyse_args(argv: Argv, args: _Args) -> dict[str, Any]:
else:
data = [
d for d in argv.release(no_split=True)
if (res := value.validate(d)).flag == "valid" or (not value.ignore and _raise(arg, d, res))
if (res := value.execute(d)).success or (not value.ignore and _raise(arg, d, res))
]
result[arg.name] = argv.converter(data)
argv.current_index = argv.ndata
Expand Down
12 changes: 6 additions & 6 deletions src/arclet/alconna/sistana/fragment.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .some import Some

if TYPE_CHECKING:
from nepattern import BasePattern
from nepattern import Pattern


@dataclass(**safe_dcls_kw(slots=True))
Expand Down Expand Up @@ -62,14 +62,14 @@ def _transform(v: Segment):

return self

def apply_nepattern(self, pat: BasePattern | None = None, capture_mode: bool = False):
def apply_nepattern(self, pat: Pattern | None = None, capture_mode: bool = False):
if pat is None:
if self.type is None:
return self

from nepattern import BasePattern
from nepattern import parser

pat = BasePattern.to(self.type.value)
pat = parser(self.type.value)
assert pat is not None

def _validate(v: Segment):
Expand All @@ -78,7 +78,7 @@ def _validate(v: Segment):
v = str(v)
else:
v = v.ref[0]
return pat.validate(v).success
return pat.execute(v).success

self.validator = _validate
if self.cast:
Expand All @@ -90,7 +90,7 @@ def _transform(v: Segment):
else:
v = v.ref[0]

return pat.validate(v).value()
return pat.execute(v).value()

self.transformer = _transform
return self
12 changes: 6 additions & 6 deletions src/arclet/alconna/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
)
from typing_extensions import TypeAlias

from nepattern import BasePattern, MatchFailed, MatchMode
from nepattern import Pattern, MatchFailed
from tarina import generic_isinstance, lang


Expand Down Expand Up @@ -66,15 +66,15 @@ def __len__(self) -> int: ...
TDC = TypeVar("TDC", bound=DataCollection[Any])
T = TypeVar("T")
T1 = TypeVar("T1")
TAValue: TypeAlias = Union[BasePattern[T, Any, Any], Type[T], T, Callable[..., T], Dict[Any, T], List[T]]
TAValue: TypeAlias = Union[Pattern[T], Type[T], T, Callable[..., T], Dict[Any, T], List[T]]


@final
class _AllParamPattern(BasePattern[T, T, Literal[MatchMode.KEEP]], Generic[T]):
class _AllParamPattern(Pattern[T]):
def __init__(self, types: tuple[type[T1], ...] = (), ignore: bool = True):
self.types = types
self.ignore = ignore
super().__init__(mode=MatchMode.KEEP, origin=Any, alias="*")
super().__init__(alias="*")

def match(self, input_: Any) -> Any: # pragma: no cover
if not self.types:
Expand All @@ -96,14 +96,14 @@ def __call__(self, *types: type[T1], ignore: bool = True) -> _AllParamPattern[T1
def __call__(self, *types: type[T1], ignore: bool = True) -> _AllParamPattern[T1]:
return _AllParamPattern(types, ignore)

def __calc_eq__(self, other): # pragma: no cover
def __eq__(self, other): # pragma: no cover
return other.__class__ is _AllParamPattern


AllParam: _AllParamPattern[Any] = _AllParamPattern()


class KWBool(BasePattern):
class KWBool(Pattern):
"""对布尔参数的包装"""


Expand Down
3 changes: 2 additions & 1 deletion src/arclet/alconna/v1/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from enum import Enum
from typing import Any, Final, Iterable

from nepattern import Pattern
from tarina import Empty
from typing_extensions import Self, deprecated

Expand Down Expand Up @@ -50,7 +51,7 @@ def build(self):
arg.field.kw_sep = value.sep
arg.type_ = value.base
elif isinstance(value, UnpackVar):
arg.type_ = value.of(value.origin)
arg.type_ = Pattern(value.origin)
return super().build()

def __truediv__(self, other) -> Self:
Expand Down
4 changes: 2 additions & 2 deletions src/arclet/alconna/v1/stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Any, Generic, TypeVar
from typing_extensions import Self, deprecated

from nepattern import ANY, BasePattern
from nepattern import ANY, Pattern

from arclet.alconna.args import _Args
from arclet.alconna.base import Option, Subcommand, OptionResult, SubcommandResult
Expand Down Expand Up @@ -53,7 +53,7 @@ def __post_init__(self):
key = arg.name
if arg.type_ in (AllParam, ANY):
self.__annotations__[key] = Any
elif isinstance(arg.type_, BasePattern):
elif isinstance(arg.type_, Pattern):
self.__annotations__[key] = arg.type_.origin
else:
self.__annotations__[key] = arg.type_
Expand Down
Loading

0 comments on commit 87f1bbe

Please sign in to comment.