From 81d798d2205e35faecb6bd95b9a3968db6a4e5cb Mon Sep 17 00:00:00 2001 From: tandemdude <43570299+tandemdude@users.noreply.github.com> Date: Thu, 4 Apr 2024 01:43:41 +0100 Subject: [PATCH] Fix typing issues and add more execution steps --- lightbulb/client.py | 2 + lightbulb/commands/execution.py | 63 ++++++++++++------------- lightbulb/loaders.py | 8 +++- scripts/docs/api_reference_generator.py | 4 +- 4 files changed, 42 insertions(+), 35 deletions(-) diff --git a/lightbulb/client.py b/lightbulb/client.py index 0392ff98..9a1b6af2 100644 --- a/lightbulb/client.py +++ b/lightbulb/client.py @@ -60,6 +60,8 @@ execution.ExecutionSteps.MAX_CONCURRENCY, execution.ExecutionSteps.CHECKS, execution.ExecutionSteps.COOLDOWNS, + execution.ExecutionSteps.INVOKE, + execution.ExecutionSteps.POST_INVOKE, ) diff --git a/lightbulb/commands/execution.py b/lightbulb/commands/execution.py index f82ee7b8..11e5e39b 100644 --- a/lightbulb/commands/execution.py +++ b/lightbulb/commands/execution.py @@ -35,7 +35,9 @@ __all__ = ["ExecutionStep", "ExecutionSteps", "ExecutionHook", "ExecutionPipeline", "hook", "invoke"] -ExecutionHookFuncT: t.TypeAlias = t.Callable[["ExecutionPipeline", "context_.Context"], types.MaybeAwaitable[None]] +ExecutionHookFunc: t.TypeAlias = t.Callable[ + t.Concatenate["ExecutionPipeline", "context_.Context", ...], types.MaybeAwaitable[None] +] @dataclasses.dataclass(frozen=True, slots=True, eq=True) @@ -71,6 +73,10 @@ class ExecutionSteps: """Step for execution of command check logic.""" COOLDOWNS = ExecutionStep("COOLDOWNS") """Step for execution of command cooldown logic.""" + INVOKE = ExecutionStep("INVOKE") + """Step for command invocation. No hooks should ever use this step.""" + POST_INVOKE = ExecutionStep("POST_INVOKE") + """Step for post-invocation logic.""" @dataclasses.dataclass(frozen=True, slots=True, eq=True) @@ -87,7 +93,7 @@ class ExecutionHook: step: ExecutionStep """The step that this hook should be run during.""" - func: ExecutionHookFuncT + func: ExecutionHookFunc """The function that this hook executes.""" async def __call__(self, pipeline: ExecutionPipeline, context: context_.Context) -> None: @@ -138,6 +144,14 @@ def _next_step(self) -> ExecutionStep | None: return self._remaining.pop(0) return None + def _fail(self, exc: Exception) -> None: + assert self._current_step is not None + assert self._current_hook is not None + + hook_exc = exceptions.HookFailedException(exc, self._current_hook) + + self._failure = hook_exc + async def _run(self) -> None: """ Run the pipeline. Does not reset the state if called multiple times. @@ -153,13 +167,20 @@ async def _run(self) -> None: """ self._current_step = self._next_step() while self._current_step is not None: + if self._current_step == ExecutionSteps.INVOKE: + try: + await getattr(self._context.command, self._context.command_data.invoke_method)(self._context) + continue + except Exception as e: + raise exceptions.InvocationFailedException(e, self._context) + step_hooks = list(self._hooks.get(self._current_step, [])) while step_hooks and not self.failed: self._current_hook = step_hooks.pop(0) try: await self._current_hook(self, self._context) except Exception as e: - self.fail(e) + self._fail(e) if self.failed: break @@ -170,34 +191,8 @@ async def _run(self) -> None: assert self._failure is not None raise self._failure - try: - await getattr(self._context.command, self._context.command_data.invoke_method)(self._context) - except Exception as e: - raise exceptions.InvocationFailedException(e, self._context) - - def fail(self, exc: str | Exception) -> None: - """ - Notify the pipeline of a failure in an execution hook. - - Args: - exc (:obj:`~typing.Union` [ :obj:`str`, :obj:`Exception` ]): Message or exception to include - with the failure. - - Returns: - :obj:`None` - """ - if not isinstance(exc, Exception): - exc = RuntimeError(exc) - - assert self._current_step is not None - assert self._current_hook is not None - - hook_exc = exceptions.HookFailedException(exc, self._current_hook) - - self._failure = hook_exc - -def hook(step: ExecutionStep) -> t.Callable[[ExecutionHookFuncT], ExecutionHook]: +def hook(step: ExecutionStep) -> t.Callable[[ExecutionHookFunc], ExecutionHook]: """ Second order decorator to convert a function into an execution hook for the given step. Also enables dependency injection on the decorated function. @@ -231,14 +226,18 @@ def only_on_mondays(pl: lightbulb.ExecutionPipeline, _: lightbulb.Context) -> No # Fail the pipeline execution pl.fail("This command can only be used on mondays") """ + if step == ExecutionSteps.INVOKE: + raise ValueError("hooks cannot be registered for the 'INVOKE' execution step") - def inner(func: ExecutionHookFuncT) -> ExecutionHook: + def inner(func: ExecutionHookFunc) -> ExecutionHook: return ExecutionHook(step, di.with_di(func)) # type: ignore[reportArgumentType] return inner -def invoke(func: t.Callable[..., t.Awaitable[t.Any]]) -> t.Callable[[context_.Context], t.Awaitable[t.Any]]: +def invoke( + func: t.Callable[t.Concatenate[context_.Context, ...], t.Awaitable[t.Any]], +) -> t.Callable[[context_.Context], t.Awaitable[t.Any]]: """ First order decorator to mark a method as the invocation method to be used for the command. Also enables dependency injection on the decorated method. The decorated method **must** have the first parameter (non-self) diff --git a/lightbulb/loaders.py b/lightbulb/loaders.py index a12df258..51f002c0 100644 --- a/lightbulb/loaders.py +++ b/lightbulb/loaders.py @@ -211,7 +211,9 @@ def _inner(command_: CommandOrGroupT) -> CommandOrGroupT: def listener( self, event_type: EventT - ) -> t.Callable[[t.Callable[..., t.Awaitable[None]]], t.Callable[[EventT], t.Awaitable[None]]]: + ) -> t.Callable[ + [t.Callable[t.Concatenate[EventT, ...], t.Awaitable[None]]], t.Callable[[EventT], t.Awaitable[None]] + ]: """ Decorator to register a listener with this loader. Also enables dependency injection on the listener callback. @@ -233,7 +235,9 @@ async def message_create_listener(event: hikari.MessageCreateEvent) -> None: ... """ - def _inner(callback: t.Callable[..., t.Awaitable[None]]) -> t.Callable[[EventT], t.Awaitable[None]]: + def _inner( + callback: t.Callable[t.Concatenate[EventT, ...], t.Awaitable[None]], + ) -> t.Callable[[EventT], t.Awaitable[None]]: di_enabled = t.cast(t.Callable[[EventT], t.Awaitable[None]], di.with_di(callback)) self._loadables.append(_ListenerLoadable(di_enabled, event_type)) return di_enabled diff --git a/scripts/docs/api_reference_generator.py b/scripts/docs/api_reference_generator.py index cfe3b3c8..2d198985 100644 --- a/scripts/docs/api_reference_generator.py +++ b/scripts/docs/api_reference_generator.py @@ -126,7 +126,9 @@ def write(self) -> bool: ] if package_lines: - lines.extend(["**Subpackages:**", "", ".. toctree::", " :maxdepth: 1", "", *sorted(package_lines), ""]) + lines.extend( + ["**Subpackages:**", "", ".. toctree::", " :maxdepth: 1", "", *sorted(package_lines), ""] + ) if module_lines: lines.extend(["**Submodules:**", "", ".. toctree::", " :maxdepth: 1", "", *sorted(module_lines), ""])