From be7330a0bb53814e9c1f86c49b49f532505a2707 Mon Sep 17 00:00:00 2001 From: RF-Tar-Railt Date: Sat, 2 Nov 2024 00:23:17 +0800 Subject: [PATCH] :bug: fix conflict arg if it is visited param token --- src/arclet/alconna/_internal/_analyser.py | 2 +- src/arclet/alconna/_internal/_argv.py | 8 ++++++-- src/arclet/alconna/_internal/_handlers.py | 22 ++++++++++------------ 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/src/arclet/alconna/_internal/_analyser.py b/src/arclet/alconna/_internal/_analyser.py index 977772d7..32600ee1 100644 --- a/src/arclet/alconna/_internal/_analyser.py +++ b/src/arclet/alconna/_internal/_analyser.py @@ -205,7 +205,7 @@ def process(self, argv: Argv[TDC]) -> Self: if levenshtein(name, al) >= argv.fuzzy_threshold: raise FuzzyMatchSuccess(lang.require("fuzzy", "matched").format(source=al, target=name)) raise InvalidParam(lang.require("subcommand", "name_error").format(source=sub.dest, target=name)) - + argv.visited_param_ids.update(sub.aliases) self.value_result = sub.action.value return self.analyse(argv) diff --git a/src/arclet/alconna/_internal/_argv.py b/src/arclet/alconna/_internal/_argv.py index 91f20cf4..913ce7f4 100644 --- a/src/arclet/alconna/_internal/_argv.py +++ b/src/arclet/alconna/_internal/_argv.py @@ -32,6 +32,8 @@ class Argv(Generic[TDC]): """检查传入命令""" param_ids: set[str] = field(default_factory=set) """节点名集合""" + visited_param_ids: set[str] = field(default_factory=set) + """已访问的节点名集合""" fuzzy_match: bool = field(init=False) """当前命令是否模糊匹配""" @@ -105,6 +107,7 @@ def reset(self): self.origin = "None" # type: ignore self._sep = None self.current_node = None + self.visited_param_ids = set() @staticmethod def generate_token(data: list) -> int: @@ -275,11 +278,12 @@ def release(self, separate: str | None = None, recover: bool = False, no_split: return _result def data_set(self): - return self.raw_data.copy(), self.current_index + return self.raw_data.copy(), self.current_index, self.visited_param_ids.copy() - def data_reset(self, data: list[str | Any], index: int): + def data_reset(self, data: list[str | Any], index: int, visited_param_ids: set[str]): self.raw_data = data self.current_index = index + self.visited_param_ids = visited_param_ids def enter(self, ctx: dict[str, Any] | None = None) -> Self: """进入上下文""" diff --git a/src/arclet/alconna/_internal/_handlers.py b/src/arclet/alconna/_internal/_handlers.py index 0902210a..96074b6a 100644 --- a/src/arclet/alconna/_internal/_handlers.py +++ b/src/arclet/alconna/_internal/_handlers.py @@ -82,7 +82,7 @@ def step_varpos(argv: Argv, args: Args, slot: tuple[MultiVar, Arg], result: dict if _str and may_arg in argv.special: if argv.special[may_arg] not in argv.namespace.disable_builtin_options: raise SpecialOptionTriggered(argv.special[may_arg]) - if not may_arg or (_str and may_arg in argv.param_ids): + if not may_arg or (_str and may_arg in argv.param_ids and may_arg not in argv.visited_param_ids): argv.rollback(may_arg) break if _str and may_arg in config.remainders: @@ -127,7 +127,7 @@ def step_varkey(argv: Argv, slot: tuple[MultiKeyWordVar, Arg], result: dict[str, if _str and may_arg in argv.special: if argv.special[may_arg] not in argv.namespace.disable_builtin_options: raise SpecialOptionTriggered(argv.special[may_arg]) - if not may_arg or (_str and may_arg in argv.param_ids) or not _str: + if not may_arg or (_str and may_arg in argv.param_ids and may_arg not in argv.visited_param_ids) or not _str: argv.rollback(may_arg) break if _str and may_arg in config.remainders: @@ -180,7 +180,7 @@ def step_keyword(argv: Argv, args: Args, result: dict[str, Any]): _key = key if _key not in args.argument.keyword_only: argv.rollback(may_arg) - if args.argument.vars_keyword or (_str and may_arg in argv.param_ids): + if args.argument.vars_keyword or (_str and may_arg in argv.param_ids and may_arg not in argv.visited_param_ids): break for arg in args.argument.keyword_only.values(): if arg.value.base.validate(may_arg).flag == "valid": # type: ignore @@ -231,13 +231,10 @@ def analyse_args(argv: Argv, args: Args) -> dict[str, Any]: if _str and may_arg in argv.special: if argv.special[may_arg] not in argv.namespace.disable_builtin_options: raise SpecialOptionTriggered(argv.special[may_arg]) - if _str and may_arg in argv.param_ids and arg.optional: - if (de := arg.field.default) is not Empty: - result[arg.name] = de + if _str and may_arg in argv.param_ids and may_arg not in argv.visited_param_ids and arg.optional: argv.rollback(may_arg) - continue + may_arg = None if may_arg is None or (_str and not may_arg): - # argv.rollback(may_arg) if (de := arg.field.default) is not Empty: result[arg.name] = de elif not arg.optional: @@ -312,6 +309,7 @@ def handle_option(argv: Argv, opt: Option) -> tuple[str, OptionResult]: if levenshtein(name, al) >= argv.fuzzy_threshold: raise FuzzyMatchSuccess(lang.require("fuzzy", "matched").format(source=al, target=name)) raise InvalidParam(lang.require("option", "name_error").format(source=opt.dest, target=name)) + argv.visited_param_ids.update(opt.aliases) name = opt.dest if opt.nargs: return name, OptionResult(None, analyse_args(argv, opt.args)) @@ -366,7 +364,7 @@ def analyse_compact_params(analyser: SubAnalyser, argv: Argv): argv (Argv): 命令行参数 """ for param in analyser.compact_params: - _data, _index = argv.data_set() + _data, _index, _visited = argv.data_set() try: if param.__class__ is Option: oparam: Option = param # type: ignore @@ -402,7 +400,7 @@ def analyse_compact_params(analyser: SubAnalyser, argv: Argv): except InvalidParam as e: if argv.current_node.__class__ is Arg: raise e - argv.data_reset(_data, _index) + argv.data_reset(_data, _index, _visited) else: return False @@ -459,7 +457,7 @@ def analyse_param(analyser: SubAnalyser, argv: Argv, seps: str | None = None): exc: Exception | None = None lparam: list[Option] = _param # type: ignore for opt in lparam: - _data, _index = argv.data_set() + _data, _index, _visited = argv.data_set() try: if opt.requires and analyser.sentences != opt.requires: raise InvalidParam( @@ -474,7 +472,7 @@ def analyse_param(analyser: SubAnalyser, argv: Argv, seps: str | None = None): break except Exception as e: exc = e - argv.data_reset(_data, _index) + argv.data_reset(_data, _index, _visited) if exc: raise exc # type: ignore # noqa analyser.sentences.clear()