Skip to content

Commit

Permalink
💥 Arparma changes
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Oct 29, 2024
1 parent ef73698 commit bc5c635
Show file tree
Hide file tree
Showing 22 changed files with 274 additions and 313 deletions.
2 changes: 1 addition & 1 deletion devtool.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from arclet.alconna.ingedia._argv import Argv
from arclet.alconna.base import Option, Subcommand, Header, Config
from arclet.alconna.config import Namespace
from arclet.alconna.typing import DataCollection
from arclet.alconna.utils import DataCollection


class AnalyseError(Exception):
Expand Down
2 changes: 1 addition & 1 deletion src/arclet/alconna/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,6 @@
from .formatter import TextFormatter as TextFormatter
from .shortcut import ShortcutArgs as ShortcutArgs
from .manager import command_manager as command_manager
from .typing import AllParam as AllParam
from .utils import AllParam as AllParam

__version__ = "1.8.31"
2 changes: 1 addition & 1 deletion src/arclet/alconna/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from ._dcls import safe_dcls_kw, safe_field_kw
from .exceptions import InvalidArgs
from .typing import TAValue, parent_frame_namespace, merge_cls_and_parent_ns
from .utils import TAValue, parent_frame_namespace, merge_cls_and_parent_ns

_T = TypeVar("_T")

Expand Down
205 changes: 97 additions & 108 deletions src/arclet/alconna/arparma.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from warnings import warn
import inspect
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass, field
Expand All @@ -8,56 +9,18 @@
from typing import Any, Callable, ClassVar, Generic, TypeVar, cast, overload, Literal
from typing_extensions import Self

from tarina import Empty, generic_isinstance, lang, safe_eval
from tarina import Empty, generic_isinstance, safe_eval

from .exceptions import BehaveCancelled, OutBoundsBehave
from .base import HeadResult, OptionResult, SubcommandResult
from .typing import TDC
from .utils import TDC
from .args import ArgsMeta, ArgsBase

T = TypeVar("T")
T1 = TypeVar("T1")
D = TypeVar("D")


def _handle_opt(_pf: str, _parts: list[str], _opts: dict[str, OptionResult]):
"""处理 `options.xxx.yyy.zzz` 形式的参数"""
if _pf == "options":
_pf = _parts.pop(0)
if not _parts: # options.foo or foo
return _opts, _pf
elif not (__src := _opts.get(_pf)): # options.foo.bar or foo.bar
return _opts, _pf
if (_end := _parts.pop(0)) == "value":
return __src, _end
if _end == "args":
return (__src.args, _parts.pop(0)) if _parts else (__src, _end)
return __src.args, _end


def _handle_sub(_pf: str, _parts: list[str], _subs: dict[str, SubcommandResult]):
"""处理 `subcommands.xxx.yyy.zzz` 形式的参数"""
if _pf == "subcommands":
_pf = _parts.pop(0)
if not _parts:
return _subs, _pf
elif not (__src := _subs.get(_pf)):
return _subs, _pf
if (_end := _parts.pop(0)) == "value":
return __src, _end
if _end == "args":
return (__src.args, _parts.pop(0)) if _parts else (__src, _end)
if _end == "options" and (_end in __src.options or not _parts):
raise RuntimeError(lang.require("arparma", "ambiguous_name").format(target=f"{_pf}.{_end}"))
if _end == "options" or _end in __src.options:
return _handle_opt(_end, _parts, __src.options)
if _end == "subcommands" and (_end in __src.subcommands or not _parts):
raise RuntimeError(lang.require("arparma", "ambiguous_name").format(target=f"{_pf}.{_end}"))
if _end == "subcommands" or _end in __src.subcommands:
return _handle_sub(_end, _parts, __src.subcommands)
return __src.args, _end


class _Query(Generic[T]):
source: Arparma

Expand Down Expand Up @@ -104,7 +67,7 @@ def __call__(self, path: str, default: D | None = None, *, force_return: bool =
raise KeyError(path)
return default
return MappingProxyType(source) # type: ignore
if endpoint:
if isinstance(endpoint, str) and endpoint:
try:
return getattr(source, endpoint)
except AttributeError:
Expand All @@ -123,17 +86,13 @@ class Arparma(Generic[TDC]):
header_match (HeadResult): 命令头匹配结果
error_info (type[BaseException] | BaseException | str): 错误信息
error_data (list[str | Any]): 错误数据
main_args (dict[str, Any]): 主参数匹配结果
other_args (dict[str, Any]): 其他参数匹配结果
options (dict[str, OptionResult]): 选项匹配结果
subcommands (dict[str, SubcommandResult]): 子命令匹配结果
value_result (dict[tuple[str, ...], Any]): 值匹配结果
args_result (dict[tuple[str, ...], dict[str, Any]]): 参数匹配结果
context (dict[str, Any]): 上下文
output (str | None): 输出信息
"""

header_match: HeadResult
options: dict[str, OptionResult]
subcommands: dict[str, SubcommandResult]
output: str | None

def __init__(
Expand All @@ -144,9 +103,8 @@ def __init__(
header_match: HeadResult | None = None,
error_info: type[Exception] | Exception | None = None,
error_data: list[str | Any] | None = None,
main_args: dict[str, Any] | None = None,
options: dict[str, OptionResult] | None = None,
subcommands: dict[str, SubcommandResult] | None = None,
value_result: dict[tuple[str, ...], Any] | None = None,
args_result: dict[tuple[str, ...], dict[str, Any]] | None = None,
ctx: dict[str, Any] | None = None,
):
"""初始化 `Arparma`
Expand All @@ -157,9 +115,8 @@ def __init__(
header_match (HeadResult | None, optional): 命令头匹配结果
error_info (type[Exception] | Exception | None, optional): 错误信息
error_data (list[str | Any] | None, optional): 错误数据
main_args (dict[str, Any] | None, optional): 主参数匹配结果
options (dict[str, OptionResult] | None, optional): 选项匹配结果
subcommands (dict[str, SubcommandResult] | None, optional): 子命令匹配结果
value_result (dict[tuple[str, ...], Any] | None, optional): 值匹配结果
args_result (dict[tuple[str, ...], dict[str, Any]] | None, optional): 参数匹配结果
ctx (dict[str, Any] | None, optional): 上下文
"""
self._id = _id
Expand All @@ -168,10 +125,8 @@ def __init__(
self.header_match = header_match or HeadResult()
self.error_info = error_info
self.error_data = error_data or []
self.main_args = main_args or {}
self.other_args = {}
self.options = options or {}
self.subcommands = subcommands or {}
self.value_result = value_result or {}
self.args_result = args_result or {}
self.context = ctx or {}
self.output = None

Expand All @@ -181,10 +136,8 @@ def __init__(
def _clr(self):
self.context.clear()
self.error_data.clear()
self.main_args.clear()
self.other_args.clear()
self.options.clear()
self.subcommands.clear()
self.value_result.clear()
self.args_result.clear()
ks = list(self.__dict__.keys())
for k in ks:
delattr(self, k)
Expand All @@ -197,17 +150,22 @@ def head_matched(self):
@property
def non_component(self) -> bool:
"""返回是否没有解析到任何组件"""
return not self.subcommands and not self.options
return not self.value_result

@property
def main_args(self) -> dict[str, Any]:
"""返回 Alconna 中主要 Args 解析到的值"""
return self.args_result.get((), {})

@property
def components(self) -> dict[str, OptionResult | SubcommandResult]:
"""返回解析到的组件"""
return {**self.options, **self.subcommands}
def other_args(self) -> dict[str, Any]:
"""返回 Alconna 中其他 Args 解析到的值"""
return {k: v for path, args in self.args_result.items() for k, v in args.items() if path}

@property
def all_matched_args(self) -> dict[str, Any]:
"""返回 Alconna 中所有 Args 解析到的值"""
return {**self.main_args, **self.other_args}
return {k: v for args in self.args_result.values() for k, v in args.items()}

@property
def token(self) -> int:
Expand All @@ -222,22 +180,41 @@ def source(self):

return command_manager._resolve(self._id)

def _unpack_opts(self, _data):
for _v in _data.values():
self.other_args = {**self.other_args, **_v.args}

def _unpack_subs(self, _data):
for _v in _data.values():
self.other_args = {**self.other_args, **_v.args}
if _v.options:
self._unpack_opts(_v.options)
if _v.subcommands:
self._unpack_subs(_v.subcommands)

def unpack(self) -> None:
"""处理 `Arparma` 中的数据"""
self._unpack_opts(self.options)
self._unpack_subs(self.subcommands)
@property
def options(self):
result = {}
for path, v in self.value_result.items():
if path == ():
continue
prefixes, key = path[:-1], path[-1]
if not prefixes:
if key in result:
result[key].value = v
else:
result[key] = SubcommandResult(v)
else:
sub = result.setdefault(prefixes[0], SubcommandResult())
for part in prefixes[1:]:
sub = sub.subcommands.setdefault(part, SubcommandResult())
sub.subcommands[key] = SubcommandResult(v)
for path, v in self.args_result.items():
if path == ():
continue
prefixes, key = path[:-1], path[-1]
if not prefixes:
if key in result:
result[key].args = v
else:
result[key] = SubcommandResult(..., v)
else:
sub = result.setdefault(prefixes[0], SubcommandResult())
for part in prefixes[1:]:
sub = sub.subcommands.setdefault(part, SubcommandResult())
if key in sub.subcommands:
sub.subcommands[key].args = v
else:
sub.subcommands[key] = SubcommandResult(..., v)
return result

@staticmethod
def behave_cancel(*msg: str):
Expand Down Expand Up @@ -289,7 +266,6 @@ def call(self, target: Callable[..., T]) -> T:
"args": self.main_args,
"all_args": self.all_matched_args,
"options": self.options,
"subcommands": self.subcommands,
}

sig = inspect.signature(target)
Expand Down Expand Up @@ -321,33 +297,42 @@ def fail(self, exc: type[Exception] | Exception) -> Self:
"""生成一个失败的 `Arparma`"""
return Arparma(self._id, self.origin, False, self.header_match, error_info=exc) # type: ignore

def __require__(self, parts: list[str]) -> tuple[dict[str, Any] | OptionResult | SubcommandResult | None, str]:
def __require__(self, parts: list[str]) -> tuple[Any, tuple[str, ...] | str | None]:
"""如果能够返回, 除开基本信息, 一定返回该path所在的dict"""
if len(parts) == 1:
part = parts[0]
if part in {"options", "subcommands", "main_args", "other_args", "context"}:
bak = parts.copy()
if len(bak) > 1:
if parts[-1] == "value":
parts.pop()
if tuple(parts) in self.value_result:
return self.value_result, tuple(parts)
elif parts[-1] == "args":
parts.pop()
if tuple(parts) in self.args_result:
return self.args_result, tuple(parts)
if tuple(parts) in self.value_result:
return OptionResult(self.value_result[tuple(parts)], self.args_result.get(tuple(parts), {})), None
# return self.value_result, tuple(parts)
may_arg = parts.pop()
if tuple(parts) in self.args_result:
return self.args_result[tuple(parts)], may_arg
else:
part = bak[0]
if part in {"main_args", "other_args", "context"}:
return getattr(self, part, {}), ""
for src in (self.main_args, self.other_args, self.options, self.subcommands, self.context):
for src in (self.main_args, self.other_args, self.context):
if part in src:
return src, part
return (self.all_matched_args, "") if part == "args" else (None, part)
prefix = parts.pop(0) # parts[0]
if prefix in {"options", "subcommands"} and prefix in self.components:
raise RuntimeError(lang.require("arparma", "ambiguous_name").format(target=prefix))
if prefix == "options" or prefix in self.options:
return _handle_opt(prefix, parts, self.options)
if prefix == "subcommands" or prefix in self.subcommands:
return _handle_sub(prefix, parts, self.subcommands)
prefix = prefix.replace("$main", "main_args").replace("$other", "other_args")
if prefix in {"main_args", "other_args"}:
return getattr(self, prefix, {}), parts.pop(0)
path = ".".join([prefix] + parts)
if (part,) in self.value_result:
return OptionResult(self.value_result[(part,)], self.args_result.get((part,), {})), None
if part == "args":
return self.all_matched_args, ""
path = ".".join(bak)
if path in self.context:
return self.context, path
try:
return safe_eval(path, self.context), "" # type: ignore
except Exception:
return None, prefix
return None, path

def query_with(self, arg_type: type[T], *args):
return self.query[arg_type](*args)
Expand Down Expand Up @@ -395,6 +380,11 @@ def __getitem__(self, item: str | type[T] | tuple[type[T], int]) -> T | Any | No
return next(i for i in self.all_matched_args.values() if generic_isinstance(i, item))

def __getattr__(self, item: str):
warn(
f"`Arparma.{item}` is deprecated, use `Arparma.query({item!r})` or `Arparma[{item!r}]` instead",
category=DeprecationWarning,
stacklevel=2
)
return self.all_matched_args.get(item, self.query(item.replace("_", ".")))

def __repr__(self):
Expand All @@ -405,8 +395,7 @@ def __repr__(self):
attrs = {
"matched": self.matched,
"header_match": self.header_match,
"options": self.options,
"subcommands": self.subcommands,
"value_result": {k: v for k, v in sorted(self.value_result.items(), key=lambda x: x[0])},
"main_args": self.main_args,
"other_args": self.other_args,
}
Expand All @@ -423,7 +412,6 @@ class ArparmaBehavior(metaclass=ABCMeta):

requires: list[ArparmaBehavior] = field(init=False, hash=False, repr=False)


@abstractmethod
def operate(self, interface: Arparma):
"""对解析结果进行操作"""
Expand All @@ -438,20 +426,21 @@ def update(self, interface: Arparma, path: str, value: Any):
value (Any): 要更新的值
"""

def _update(tkn, src, pth, ep, val):
def _update(tkn, src, ep, val):
if isinstance(src, dict):
src[ep] = val
else:
setattr(src, ep, val)

source, end = interface.__require__(path.split("."))
parts = path.split(".")
source, end = interface.__require__(parts)
if source is None:
return
if end:
_update(interface.token, source, path, end, value)
_update(interface.token, source, end, value)
elif isinstance(value, dict):
for k, v in value.items():
_update(interface.token, source, f"{path}.{k}", k, v)
_update(interface.token, source, (*parts, k), v)


@lru_cache(4096)
Expand Down
Loading

0 comments on commit bc5c635

Please sign in to comment.