Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

mild CLI cleanup #16560

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
238 changes: 73 additions & 165 deletions src/prefect/cli/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,61 +2,51 @@
Custom Prefect CLI types
"""

import asyncio
import functools
import sys
from typing import Any, Callable, List, Optional
from typing import (
Any,
Callable,
Coroutine,
Optional,
Protocol,
Type,
TypeVar,
)

import typer
from rich.console import Console
from rich.theme import Theme
from typer.core import TyperCommand
from typing_extensions import Concatenate, ParamSpec

from prefect._internal.compatibility.deprecated import generate_deprecation_message
from prefect.cli._utilities import with_cli_exception_handling
from prefect.settings import PREFECT_CLI_COLORS, Setting
from prefect.utilities.asyncutils import is_async_fn


def SettingsOption(setting: Setting, *args: Any, **kwargs: Any) -> Any:
"""Custom `typer.Option` factory to load the default value from settings"""

return typer.Option(
# The default is dynamically retrieved
setting.value,
*args,
# Typer shows "(dynamic)" by default. We'd like to actually show the value
# that would be used if the parameter is not specified and a reference if the
# source is from the environment or profile, but typer does not support this
# yet. See https://github.com/tiangolo/typer/issues/354
show_default=f"from {setting.name}",
**kwargs,
)


def SettingsArgument(setting: Setting, *args: Any, **kwargs: Any) -> Any:
"""Custom `typer.Argument` factory to load the default value from settings"""

# See comments in `SettingsOption`
return typer.Argument(
setting.value,
*args,
show_default=f"from {setting.name}",
**kwargs,
)


def with_deprecated_message(
warning: str,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
@functools.wraps(fn)
def wrapper(*args: Any, **kwargs: Any) -> Any:
print("WARNING:", warning, file=sys.stderr, flush=True)
return fn(*args, **kwargs)
from prefect.utilities.asyncutils import sync

return wrapper
P = ParamSpec("P")

R = TypeVar("R")
T = TypeVar("T")


def with_settings(
func: Callable[Concatenate[Any, P], T],
) -> Callable[Concatenate[Setting, P], T]:
@functools.wraps(func)
def wrapper(setting: Setting, *args: P.args, **kwargs: P.kwargs) -> T:
kwargs.update({"show_default": f"from {setting.name}"})
return func(setting.value, *args, **kwargs)

return wrapper


SettingsOption = with_settings(typer.Option)

return decorator

class WrappedCallable(Protocol[P, T]):
__wrapped__: Callable[P, T]

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
...


class PrefectTyper(typer.Typer):
Expand All @@ -66,134 +56,52 @@ class PrefectTyper(typer.Typer):

console: Console

def __init__(
self,
*args: Any,
deprecated: bool = False,
deprecated_start_date: Optional[str] = None,
deprecated_help: str = "",
deprecated_name: str = "",
**kwargs: Any,
):
super().__init__(*args, **kwargs)

self.deprecated = deprecated
if self.deprecated:
if not deprecated_name:
raise ValueError("Provide the name of the deprecated command group.")
self.deprecated_message: str = generate_deprecation_message(
name=f"The {deprecated_name!r} command group",
start_date=deprecated_start_date,
help=deprecated_help,
)

self.console = Console(
highlight=False,
theme=Theme({"prompt.choices": "bold blue"}),
color_system="auto" if PREFECT_CLI_COLORS else None,
)

def add_typer(
self,
typer_instance: "PrefectTyper",
*args: Any,
no_args_is_help: bool = True,
aliases: Optional[list[str]] = None,
**kwargs: Any,
) -> None:
"""
This will cause help to be default command for all sub apps unless specifically stated otherwise, opposite of before.
"""
if aliases:
for alias in aliases:
super().add_typer(
typer_instance,
*args,
name=alias,
no_args_is_help=no_args_is_help,
hidden=True,
**kwargs,
)

return super().add_typer(
typer_instance, *args, no_args_is_help=no_args_is_help, **kwargs
)

def command(
def acommand(
self,
name: Optional[str] = None,
*args: Any,
aliases: Optional[List[str]] = None,
*,
cls: Optional[Type[TyperCommand]] = None,
context_settings: Optional[dict[str, Any]] = None,
help: Optional[str] = None,
epilog: Optional[str] = None,
short_help: Optional[str] = None,
options_metavar: str = "[OPTIONS]",
add_help_option: bool = True,
no_args_is_help: bool = False,
hidden: bool = False,
deprecated: bool = False,
deprecated_start_date: Optional[str] = None,
deprecated_help: str = "",
deprecated_name: str = "",
**kwargs: Any,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
) -> Callable[
[Callable[P, Coroutine[Any, Any, T]]],
Callable[P, T],
]:
"""
Create a new command. If aliases are provided, the same command function
will be registered with multiple names.
Decorator for registering a command on this Typer app that MUST be async.

If the decorated function is async, it will be wrapped via your custom
sync(...) method so Typer sees a synchronous function.

Provide `deprecated=True` to mark the command as deprecated. If `deprecated=True`,
`deprecated_name` and `deprecated_start_date` must be provided.
If the decorated function is NOT async, a TypeError is raised.
"""

def wrapper(original_fn: Callable[..., Any]) -> Callable[..., Any]:
# click doesn't support async functions, so we wrap them in
# asyncio.run(). This has the advantage of keeping the function in
# the main thread, which means signal handling works for e.g. the
# server and workers. However, it means that async CLI commands can
# not directly call other async CLI commands (because asyncio.run()
# can not be called nested). In that (rare) circumstance, refactor
# the CLI command so its business logic can be invoked separately
# from its entrypoint.
if is_async_fn(original_fn):
async_fn = original_fn

@functools.wraps(original_fn)
def sync_fn(*args: Any, **kwargs: Any) -> Any:
return asyncio.run(async_fn(*args, **kwargs))

setattr(sync_fn, "aio", async_fn)
wrapped_fn = sync_fn
else:
wrapped_fn = original_fn

wrapped_fn = with_cli_exception_handling(wrapped_fn)
if deprecated:
if not deprecated_name or not deprecated_start_date:
raise ValueError(
"Provide the name of the deprecated command and a deprecation start date."
)
command_deprecated_message = generate_deprecation_message(
name=f"The {deprecated_name!r} command",
start_date=deprecated_start_date,
help=deprecated_help,
)
wrapped_fn = with_deprecated_message(command_deprecated_message)(
wrapped_fn
)
elif self.deprecated:
wrapped_fn = with_deprecated_message(self.deprecated_message)(
wrapped_fn
)

# register fn with its original name
def wrapper(fn: Callable[P, Coroutine[Any, Any, T]]) -> Callable[P, T]:
@functools.wraps(fn)
def sync_fn(*args: P.args, **kwargs: P.kwargs) -> T:
return sync(fn, *args, **kwargs)

command_decorator = super(PrefectTyper, self).command(
name=name, *args, **kwargs
name=name,
cls=cls,
context_settings=context_settings,
help=help,
epilog=epilog,
short_help=short_help,
options_metavar=options_metavar,
add_help_option=add_help_option,
no_args_is_help=no_args_is_help,
hidden=hidden,
deprecated=deprecated,
)
original_command = command_decorator(wrapped_fn)

# register fn for each alias, e.g. @marvin_app.command(aliases=["r"])
if aliases:
for alias in aliases:
super(PrefectTyper, self).command(
name=alias,
*args,
**{k: v for k, v in kwargs.items() if k != "aliases"},
)(wrapped_fn)

return original_command
return command_decorator(sync_fn)

return wrapper

Expand Down
6 changes: 3 additions & 3 deletions src/prefect/cli/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
app.add_typer(artifact_app)


@artifact_app.command("ls")
@artifact_app.acommand("ls")
async def list_artifacts(
limit: int = typer.Option(
100,
Expand Down Expand Up @@ -83,7 +83,7 @@ async def list_artifacts(
app.console.print(table)


@artifact_app.command("inspect")
@artifact_app.acommand("inspect")
async def inspect(
key: str,
limit: int = typer.Option(
Expand Down Expand Up @@ -142,7 +142,7 @@ async def inspect(
app.console.print(Pretty(artifacts))


@artifact_app.command("delete")
@artifact_app.acommand("delete")
async def delete(
key: Optional[str] = typer.Argument(
None, help="The key of the artifact to delete."
Expand Down
29 changes: 15 additions & 14 deletions src/prefect/cli/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,13 @@
if TYPE_CHECKING:
from prefect.client.schemas.objects import BlockDocument, BlockType

blocks_app: PrefectTyper = PrefectTyper(name="block", help="Manage blocks.")
blocktypes_app: PrefectTyper = PrefectTyper(
name="type", help="Inspect and delete block types."
)
app.add_typer(blocks_app, aliases=["blocks"])
blocks_app.add_typer(blocktypes_app, aliases=["types"])

blocks_app = PrefectTyper(name="block", help="Manage blocks.")
blocktypes_app = PrefectTyper(name="type", help="Inspect and delete block types.")
app.add_typer(blocks_app, no_args_is_help=True)
app.add_typer(blocks_app, name="blocks", hidden=True, no_args_is_help=True)
blocks_app.add_typer(blocktypes_app, no_args_is_help=True)
blocks_app.add_typer(blocktypes_app, name="types", hidden=True, no_args_is_help=True)


def display_block(block_document: "BlockDocument") -> Table:
Expand Down Expand Up @@ -155,7 +156,7 @@ def _build_registered_blocks_table(registered_blocks: list[type[Block]]) -> Tabl
return table


@blocks_app.command()
@blocks_app.acommand()
async def register(
module_name: Optional[str] = typer.Option(
None,
Expand Down Expand Up @@ -250,7 +251,7 @@ async def register(
app.console.print(msg, soft_wrap=True)


@blocks_app.command("ls")
@blocks_app.acommand("ls")
async def block_ls():
"""
View all configured blocks.
Expand Down Expand Up @@ -279,7 +280,7 @@ async def block_ls():
app.console.print(table)


@blocks_app.command("delete")
@blocks_app.acommand("delete")
async def block_delete(
slug: Optional[str] = typer.Argument(
None, help="A block slug. Formatted as '<BLOCK_TYPE_SLUG>/<BLOCK_NAME>'"
Expand Down Expand Up @@ -324,7 +325,7 @@ async def block_delete(
exit_with_error("Must provide a block slug or id")


@blocks_app.command("create")
@blocks_app.acommand("create")
async def block_create(
block_type_slug: str = typer.Argument(
...,
Expand Down Expand Up @@ -357,7 +358,7 @@ async def block_create(
)


@blocks_app.command("inspect")
@blocks_app.acommand("inspect")
async def block_inspect(
slug: Optional[str] = typer.Argument(
None, help="A Block slug: <BLOCK_TYPE_SLUG>/<BLOCK_NAME>"
Expand Down Expand Up @@ -394,7 +395,7 @@ async def block_inspect(
app.console.print(display_block(block_document))


@blocktypes_app.command("ls")
@blocktypes_app.acommand("ls")
async def list_types():
"""
List all block types.
Expand Down Expand Up @@ -427,7 +428,7 @@ async def list_types():
app.console.print(table)


@blocktypes_app.command("inspect")
@blocktypes_app.acommand("inspect")
async def blocktype_inspect(
slug: str = typer.Argument(..., help="A block type slug"),
):
Expand Down Expand Up @@ -461,7 +462,7 @@ async def blocktype_inspect(
)


@blocktypes_app.command("delete")
@blocktypes_app.acommand("delete")
async def blocktype_delete(
slug: str = typer.Argument(..., help="A Block type slug"),
):
Expand Down
Loading
Loading