Skip to content

Commit

Permalink
Fix typing issues and add more execution steps
Browse files Browse the repository at this point in the history
  • Loading branch information
tandemdude committed Apr 4, 2024
1 parent 8f02135 commit 81d798d
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 35 deletions.
2 changes: 2 additions & 0 deletions lightbulb/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@
execution.ExecutionSteps.MAX_CONCURRENCY,
execution.ExecutionSteps.CHECKS,
execution.ExecutionSteps.COOLDOWNS,
execution.ExecutionSteps.INVOKE,
execution.ExecutionSteps.POST_INVOKE,
)


Expand Down
63 changes: 31 additions & 32 deletions lightbulb/commands/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@

__all__ = ["ExecutionStep", "ExecutionSteps", "ExecutionHook", "ExecutionPipeline", "hook", "invoke"]

ExecutionHookFuncT: t.TypeAlias = t.Callable[["ExecutionPipeline", "context_.Context"], types.MaybeAwaitable[None]]
ExecutionHookFunc: t.TypeAlias = t.Callable[
t.Concatenate["ExecutionPipeline", "context_.Context", ...], types.MaybeAwaitable[None]
]


@dataclasses.dataclass(frozen=True, slots=True, eq=True)
Expand Down Expand Up @@ -71,6 +73,10 @@ class ExecutionSteps:
"""Step for execution of command check logic."""
COOLDOWNS = ExecutionStep("COOLDOWNS")
"""Step for execution of command cooldown logic."""
INVOKE = ExecutionStep("INVOKE")
"""Step for command invocation. No hooks should ever use this step."""
POST_INVOKE = ExecutionStep("POST_INVOKE")
"""Step for post-invocation logic."""


@dataclasses.dataclass(frozen=True, slots=True, eq=True)
Expand All @@ -87,7 +93,7 @@ class ExecutionHook:

step: ExecutionStep
"""The step that this hook should be run during."""
func: ExecutionHookFuncT
func: ExecutionHookFunc
"""The function that this hook executes."""

async def __call__(self, pipeline: ExecutionPipeline, context: context_.Context) -> None:
Expand Down Expand Up @@ -138,6 +144,14 @@ def _next_step(self) -> ExecutionStep | None:
return self._remaining.pop(0)
return None

def _fail(self, exc: Exception) -> None:
assert self._current_step is not None
assert self._current_hook is not None

hook_exc = exceptions.HookFailedException(exc, self._current_hook)

self._failure = hook_exc

async def _run(self) -> None:
"""
Run the pipeline. Does not reset the state if called multiple times.
Expand All @@ -153,13 +167,20 @@ async def _run(self) -> None:
"""
self._current_step = self._next_step()
while self._current_step is not None:
if self._current_step == ExecutionSteps.INVOKE:
try:
await getattr(self._context.command, self._context.command_data.invoke_method)(self._context)
continue
except Exception as e:
raise exceptions.InvocationFailedException(e, self._context)

step_hooks = list(self._hooks.get(self._current_step, []))
while step_hooks and not self.failed:
self._current_hook = step_hooks.pop(0)
try:
await self._current_hook(self, self._context)
except Exception as e:
self.fail(e)
self._fail(e)

if self.failed:
break
Expand All @@ -170,34 +191,8 @@ async def _run(self) -> None:
assert self._failure is not None
raise self._failure

try:
await getattr(self._context.command, self._context.command_data.invoke_method)(self._context)
except Exception as e:
raise exceptions.InvocationFailedException(e, self._context)

def fail(self, exc: str | Exception) -> None:
"""
Notify the pipeline of a failure in an execution hook.
Args:
exc (:obj:`~typing.Union` [ :obj:`str`, :obj:`Exception` ]): Message or exception to include
with the failure.
Returns:
:obj:`None`
"""
if not isinstance(exc, Exception):
exc = RuntimeError(exc)

assert self._current_step is not None
assert self._current_hook is not None

hook_exc = exceptions.HookFailedException(exc, self._current_hook)

self._failure = hook_exc


def hook(step: ExecutionStep) -> t.Callable[[ExecutionHookFuncT], ExecutionHook]:
def hook(step: ExecutionStep) -> t.Callable[[ExecutionHookFunc], ExecutionHook]:
"""
Second order decorator to convert a function into an execution hook for the given
step. Also enables dependency injection on the decorated function.
Expand Down Expand Up @@ -231,14 +226,18 @@ def only_on_mondays(pl: lightbulb.ExecutionPipeline, _: lightbulb.Context) -> No
# Fail the pipeline execution
pl.fail("This command can only be used on mondays")
"""
if step == ExecutionSteps.INVOKE:
raise ValueError("hooks cannot be registered for the 'INVOKE' execution step")

def inner(func: ExecutionHookFuncT) -> ExecutionHook:
def inner(func: ExecutionHookFunc) -> ExecutionHook:
return ExecutionHook(step, di.with_di(func)) # type: ignore[reportArgumentType]

return inner


def invoke(func: t.Callable[..., t.Awaitable[t.Any]]) -> t.Callable[[context_.Context], t.Awaitable[t.Any]]:
def invoke(
func: t.Callable[t.Concatenate[context_.Context, ...], t.Awaitable[t.Any]],
) -> t.Callable[[context_.Context], t.Awaitable[t.Any]]:
"""
First order decorator to mark a method as the invocation method to be used for the command. Also enables
dependency injection on the decorated method. The decorated method **must** have the first parameter (non-self)
Expand Down
8 changes: 6 additions & 2 deletions lightbulb/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,9 @@ def _inner(command_: CommandOrGroupT) -> CommandOrGroupT:

def listener(
self, event_type: EventT
) -> t.Callable[[t.Callable[..., t.Awaitable[None]]], t.Callable[[EventT], t.Awaitable[None]]]:
) -> t.Callable[
[t.Callable[t.Concatenate[EventT, ...], t.Awaitable[None]]], t.Callable[[EventT], t.Awaitable[None]]
]:
"""
Decorator to register a listener with this loader. Also enables dependency injection on the listener
callback.
Expand All @@ -233,7 +235,9 @@ async def message_create_listener(event: hikari.MessageCreateEvent) -> None:
...
"""

def _inner(callback: t.Callable[..., t.Awaitable[None]]) -> t.Callable[[EventT], t.Awaitable[None]]:
def _inner(
callback: t.Callable[t.Concatenate[EventT, ...], t.Awaitable[None]],
) -> t.Callable[[EventT], t.Awaitable[None]]:
di_enabled = t.cast(t.Callable[[EventT], t.Awaitable[None]], di.with_di(callback))
self._loadables.append(_ListenerLoadable(di_enabled, event_type))
return di_enabled
Expand Down
4 changes: 3 additions & 1 deletion scripts/docs/api_reference_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,9 @@ def write(self) -> bool:
]

if package_lines:
lines.extend(["**Subpackages:**", "", ".. toctree::", " :maxdepth: 1", "", *sorted(package_lines), ""])
lines.extend(
["**Subpackages:**", "", ".. toctree::", " :maxdepth: 1", "", *sorted(package_lines), ""]
)
if module_lines:
lines.extend(["**Submodules:**", "", ".. toctree::", " :maxdepth: 1", "", *sorted(module_lines), ""])

Expand Down

0 comments on commit 81d798d

Please sign in to comment.