Skip to content

Commit

Permalink
🐛 fix conflict arg if it is visited param token
Browse files Browse the repository at this point in the history
  • Loading branch information
RF-Tar-Railt committed Nov 1, 2024
1 parent 3ae7775 commit be7330a
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 15 deletions.
2 changes: 1 addition & 1 deletion src/arclet/alconna/_internal/_analyser.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def process(self, argv: Argv[TDC]) -> Self:
if levenshtein(name, al) >= argv.fuzzy_threshold:
raise FuzzyMatchSuccess(lang.require("fuzzy", "matched").format(source=al, target=name))
raise InvalidParam(lang.require("subcommand", "name_error").format(source=sub.dest, target=name))

argv.visited_param_ids.update(sub.aliases)
self.value_result = sub.action.value
return self.analyse(argv)

Expand Down
8 changes: 6 additions & 2 deletions src/arclet/alconna/_internal/_argv.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ class Argv(Generic[TDC]):
"""检查传入命令"""
param_ids: set[str] = field(default_factory=set)
"""节点名集合"""
visited_param_ids: set[str] = field(default_factory=set)
"""已访问的节点名集合"""

fuzzy_match: bool = field(init=False)
"""当前命令是否模糊匹配"""
Expand Down Expand Up @@ -105,6 +107,7 @@ def reset(self):
self.origin = "None" # type: ignore
self._sep = None
self.current_node = None
self.visited_param_ids = set()

@staticmethod
def generate_token(data: list) -> int:
Expand Down Expand Up @@ -275,11 +278,12 @@ def release(self, separate: str | None = None, recover: bool = False, no_split:
return _result

def data_set(self):
return self.raw_data.copy(), self.current_index
return self.raw_data.copy(), self.current_index, self.visited_param_ids.copy()

def data_reset(self, data: list[str | Any], index: int):
def data_reset(self, data: list[str | Any], index: int, visited_param_ids: set[str]):
self.raw_data = data
self.current_index = index
self.visited_param_ids = visited_param_ids

def enter(self, ctx: dict[str, Any] | None = None) -> Self:
"""进入上下文"""
Expand Down
22 changes: 10 additions & 12 deletions src/arclet/alconna/_internal/_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def step_varpos(argv: Argv, args: Args, slot: tuple[MultiVar, Arg], result: dict
if _str and may_arg in argv.special:
if argv.special[may_arg] not in argv.namespace.disable_builtin_options:
raise SpecialOptionTriggered(argv.special[may_arg])
if not may_arg or (_str and may_arg in argv.param_ids):
if not may_arg or (_str and may_arg in argv.param_ids and may_arg not in argv.visited_param_ids):
argv.rollback(may_arg)
break
if _str and may_arg in config.remainders:
Expand Down Expand Up @@ -127,7 +127,7 @@ def step_varkey(argv: Argv, slot: tuple[MultiKeyWordVar, Arg], result: dict[str,
if _str and may_arg in argv.special:
if argv.special[may_arg] not in argv.namespace.disable_builtin_options:
raise SpecialOptionTriggered(argv.special[may_arg])
if not may_arg or (_str and may_arg in argv.param_ids) or not _str:
if not may_arg or (_str and may_arg in argv.param_ids and may_arg not in argv.visited_param_ids) or not _str:
argv.rollback(may_arg)
break
if _str and may_arg in config.remainders:
Expand Down Expand Up @@ -180,7 +180,7 @@ def step_keyword(argv: Argv, args: Args, result: dict[str, Any]):
_key = key
if _key not in args.argument.keyword_only:
argv.rollback(may_arg)
if args.argument.vars_keyword or (_str and may_arg in argv.param_ids):
if args.argument.vars_keyword or (_str and may_arg in argv.param_ids and may_arg not in argv.visited_param_ids):
break
for arg in args.argument.keyword_only.values():
if arg.value.base.validate(may_arg).flag == "valid": # type: ignore
Expand Down Expand Up @@ -231,13 +231,10 @@ def analyse_args(argv: Argv, args: Args) -> dict[str, Any]:
if _str and may_arg in argv.special:
if argv.special[may_arg] not in argv.namespace.disable_builtin_options:
raise SpecialOptionTriggered(argv.special[may_arg])
if _str and may_arg in argv.param_ids and arg.optional:
if (de := arg.field.default) is not Empty:
result[arg.name] = de
if _str and may_arg in argv.param_ids and may_arg not in argv.visited_param_ids and arg.optional:
argv.rollback(may_arg)
continue
may_arg = None
if may_arg is None or (_str and not may_arg):
# argv.rollback(may_arg)
if (de := arg.field.default) is not Empty:
result[arg.name] = de
elif not arg.optional:
Expand Down Expand Up @@ -312,6 +309,7 @@ def handle_option(argv: Argv, opt: Option) -> tuple[str, OptionResult]:
if levenshtein(name, al) >= argv.fuzzy_threshold:
raise FuzzyMatchSuccess(lang.require("fuzzy", "matched").format(source=al, target=name))
raise InvalidParam(lang.require("option", "name_error").format(source=opt.dest, target=name))
argv.visited_param_ids.update(opt.aliases)
name = opt.dest
if opt.nargs:
return name, OptionResult(None, analyse_args(argv, opt.args))
Expand Down Expand Up @@ -366,7 +364,7 @@ def analyse_compact_params(analyser: SubAnalyser, argv: Argv):
argv (Argv): 命令行参数
"""
for param in analyser.compact_params:
_data, _index = argv.data_set()
_data, _index, _visited = argv.data_set()
try:
if param.__class__ is Option:
oparam: Option = param # type: ignore
Expand Down Expand Up @@ -402,7 +400,7 @@ def analyse_compact_params(analyser: SubAnalyser, argv: Argv):
except InvalidParam as e:
if argv.current_node.__class__ is Arg:
raise e
argv.data_reset(_data, _index)
argv.data_reset(_data, _index, _visited)
else:
return False

Expand Down Expand Up @@ -459,7 +457,7 @@ def analyse_param(analyser: SubAnalyser, argv: Argv, seps: str | None = None):
exc: Exception | None = None
lparam: list[Option] = _param # type: ignore
for opt in lparam:
_data, _index = argv.data_set()
_data, _index, _visited = argv.data_set()
try:
if opt.requires and analyser.sentences != opt.requires:
raise InvalidParam(
Expand All @@ -474,7 +472,7 @@ def analyse_param(analyser: SubAnalyser, argv: Argv, seps: str | None = None):
break
except Exception as e:
exc = e
argv.data_reset(_data, _index)
argv.data_reset(_data, _index, _visited)
if exc:
raise exc # type: ignore # noqa
analyser.sentences.clear()
Expand Down

0 comments on commit be7330a

Please sign in to comment.