diff --git a/.github/workflows/make_release.py b/.github/workflows/make_release.py index 88a04c1d..7eb5ddb8 100755 --- a/.github/workflows/make_release.py +++ b/.github/workflows/make_release.py @@ -1,4 +1,6 @@ #!/usr/bin/env python +from __future__ import annotations + import argparse from pathlib import Path diff --git a/diracx-cli/pyproject.toml b/diracx-cli/pyproject.toml index 06d52667..8f14d7f0 100644 --- a/diracx-cli/pyproject.toml +++ b/diracx-cli/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ "gitpython", "pydantic>=2.10", "rich", - "typer", + "typer>=0.12.4", "pyyaml", ] dynamic = ["version"] diff --git a/diracx-cli/src/diracx/cli/__init__.py b/diracx-cli/src/diracx/cli/__init__.py index d86aa660..63eb968c 100644 --- a/diracx-cli/src/diracx/cli/__init__.py +++ b/diracx-cli/src/diracx/cli/__init__.py @@ -1,144 +1,13 @@ -import asyncio -import json -import os -from datetime import datetime, timedelta, timezone -from typing import Annotated, Optional +from __future__ import annotations -import typer - -from diracx.client.aio import DiracClient -from diracx.client.models import DeviceFlowErrorResponse from diracx.core.extensions import select_from_extension -from diracx.core.preferences import get_diracx_preferences -from diracx.core.utils import read_credentials, write_credentials - -from .utils import AsyncTyper - -app = AsyncTyper() - - -async def installation_metadata(): - async with DiracClient() as api: - return await api.well_known.installation_metadata() - - -def vo_callback(vo: str | None) -> str: - metadata = asyncio.run(installation_metadata()) - vos = list(metadata.virtual_organizations) - if not vo: - raise typer.BadParameter( - f"VO must be specified, available options are: {' '.join(vos)}" - ) - if vo not in vos: - raise typer.BadParameter( - f"Unknown VO {vo}, available options are: {' '.join(vos)}" - ) - return vo - - -@app.async_command() -async def login( - vo: Annotated[ - Optional[str], - typer.Argument(callback=vo_callback, help="Virtual Organization name"), - ] = None, - group: Optional[str] = typer.Option( - None, - help="Group name within the VO. If not provided, the default group for the VO will be used.", - ), - property: Optional[list[str]] = typer.Option( - None, - help=( - "List of properties to add to the default properties of the group. " - "If not provided, default properties of the group will be used." - ), - ), -): - """Login to the DIRAC system using the device flow. - - - If only VO is provided: Uses the default group and its properties for the VO. - - - If VO and group are provided: Uses the specified group and its properties for the VO. - - If VO and properties are provided: Uses the default group and combines its properties with the - provided properties. +from .auth import app - - If VO, group, and properties are provided: Uses the specified group and combines its properties with the - provided properties. - """ - scopes = [f"vo:{vo}"] - if group: - scopes.append(f"group:{group}") - if property: - scopes += [f"property:{p}" for p in property] - - print(f"Logging in with scopes: {scopes}") - async with DiracClient() as api: - data = await api.auth.initiate_device_flow( - client_id=api.client_id, - scope=" ".join(scopes), - ) - print("Now go to:", data.verification_uri_complete) - expires = datetime.now(tz=timezone.utc) + timedelta( - seconds=data.expires_in - 30 - ) - while expires > datetime.now(tz=timezone.utc): - print(".", end="", flush=True) - response = await api.auth.token(device_code=data.device_code, client_id=api.client_id) # type: ignore - if isinstance(response, DeviceFlowErrorResponse): - if response.error == "authorization_pending": - # TODO: Setting more than 5 seconds results in an error - # Related to keep-alive disconnects from uvicon (--timeout-keep-alive) - await asyncio.sleep(2) - continue - raise RuntimeError(f"Device flow failed with {response}") - break - else: - raise RuntimeError("Device authorization flow expired") - - # Save credentials - write_credentials(response) - credentials_path = get_diracx_preferences().credentials_path - print(f"Saved credentials to {credentials_path}") - print("\nLogin successful!") - - -@app.async_command() -async def whoami(): - async with DiracClient() as api: - user_info = await api.auth.userinfo() - # TODO: Add a RICH output format - print(json.dumps(user_info.as_dict(), indent=2)) - - -@app.async_command() -async def logout(): - async with DiracClient() as api: - credentials_path = get_diracx_preferences().credentials_path - if credentials_path.exists(): - credentials = read_credentials(credentials_path) - - # Revoke refresh token - try: - await api.auth.revoke_refresh_token(credentials.refresh_token) - except Exception as e: - print(f"Error revoking the refresh token {e!r}") - pass - - # Remove credentials - credentials_path.unlink(missing_ok=True) - print(f"Removed credentials from {credentials_path}") - print("\nLogout successful!") - - -@app.callback() -def callback(output_format: Optional[str] = None): - if output_format is not None: - os.environ["DIRACX_OUTPUT_FORMAT"] = output_format +__all__ = ("app",) # Load all the sub commands - cli_names = set( [entry_point.name for entry_point in select_from_extension(group="diracx.cli")] ) diff --git a/diracx-cli/src/diracx/cli/__main__.py b/diracx-cli/src/diracx/cli/__main__.py index 916bd38c..6c808d7f 100644 --- a/diracx-cli/src/diracx/cli/__main__.py +++ b/diracx-cli/src/diracx/cli/__main__.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from . import app if __name__ == "__main__": diff --git a/diracx-cli/src/diracx/cli/auth.py b/diracx-cli/src/diracx/cli/auth.py new file mode 100644 index 00000000..2f220c10 --- /dev/null +++ b/diracx-cli/src/diracx/cli/auth.py @@ -0,0 +1,142 @@ +# Can't using PEP-604 with typer: https://github.com/tiangolo/typer/issues/348 +# from __future__ import annotations +from __future__ import annotations + +__all__ = ("app",) + +import asyncio +import json +import os +from datetime import datetime, timedelta, timezone +from typing import Annotated, Optional + +import typer + +from diracx.client.aio import DiracClient +from diracx.client.models import DeviceFlowErrorResponse +from diracx.core.preferences import get_diracx_preferences +from diracx.core.utils import read_credentials, write_credentials + +from .utils import AsyncTyper + +app = AsyncTyper() + + +async def installation_metadata(): + async with DiracClient() as api: + return await api.well_known.installation_metadata() + + +def vo_callback(vo: str | None) -> str: + metadata = asyncio.run(installation_metadata()) + vos = list(metadata.virtual_organizations) + if not vo: + raise typer.BadParameter( + f"VO must be specified, available options are: {' '.join(vos)}" + ) + if vo not in vos: + raise typer.BadParameter( + f"Unknown VO {vo}, available options are: {' '.join(vos)}" + ) + return vo + + +@app.async_command() +async def login( + vo: Annotated[ + Optional[str], + typer.Argument(callback=vo_callback, help="Virtual Organization name"), + ] = None, + group: Optional[str] = typer.Option( + None, + help="Group name within the VO. If not provided, the default group for the VO will be used.", + ), + property: Optional[list[str]] = typer.Option( + None, + help=( + "List of properties to add to the default properties of the group. " + "If not provided, default properties of the group will be used." + ), + ), +): + """Login to the DIRAC system using the device flow. + + - If only VO is provided: Uses the default group and its properties for the VO. + + - If VO and group are provided: Uses the specified group and its properties for the VO. + + - If VO and properties are provided: Uses the default group and combines its properties with the + provided properties. + + - If VO, group, and properties are provided: Uses the specified group and combines its properties with the + provided properties. + """ + scopes = [f"vo:{vo}"] + if group: + scopes.append(f"group:{group}") + if property: + scopes += [f"property:{p}" for p in property] + + print(f"Logging in with scopes: {scopes}") + async with DiracClient() as api: + data = await api.auth.initiate_device_flow( + client_id=api.client_id, + scope=" ".join(scopes), + ) + print("Now go to:", data.verification_uri_complete) + expires = datetime.now(tz=timezone.utc) + timedelta( + seconds=data.expires_in - 30 + ) + while expires > datetime.now(tz=timezone.utc): + print(".", end="", flush=True) + response = await api.auth.token(device_code=data.device_code, client_id=api.client_id) # type: ignore + if isinstance(response, DeviceFlowErrorResponse): + if response.error == "authorization_pending": + # TODO: Setting more than 5 seconds results in an error + # Related to keep-alive disconnects from uvicon (--timeout-keep-alive) + await asyncio.sleep(2) + continue + raise RuntimeError(f"Device flow failed with {response}") + break + else: + raise RuntimeError("Device authorization flow expired") + + # Save credentials + write_credentials(response) + credentials_path = get_diracx_preferences().credentials_path + print(f"Saved credentials to {credentials_path}") + print("\nLogin successful!") + + +@app.async_command() +async def whoami(): + async with DiracClient() as api: + user_info = await api.auth.userinfo() + # TODO: Add a RICH output format + print(json.dumps(user_info.as_dict(), indent=2)) + + +@app.async_command() +async def logout(): + async with DiracClient() as api: + credentials_path = get_diracx_preferences().credentials_path + if credentials_path.exists(): + credentials = read_credentials(credentials_path) + + # Revoke refresh token + try: + await api.auth.revoke_refresh_token(credentials.refresh_token) + except Exception as e: + print(f"Error revoking the refresh token {e!r}") + pass + + # Remove credentials + credentials_path.unlink(missing_ok=True) + print(f"Removed credentials from {credentials_path}") + print("\nLogout successful!") + + +@app.callback() +def callback(output_format: Optional[str] = None): + if output_format is not None: + os.environ["DIRACX_OUTPUT_FORMAT"] = output_format diff --git a/diracx-cli/src/diracx/cli/config.py b/diracx-cli/src/diracx/cli/config.py index 4c32caf1..99ea4296 100644 --- a/diracx-cli/src/diracx/cli/config.py +++ b/diracx-cli/src/diracx/cli/config.py @@ -1,5 +1,6 @@ # Can't using PEP-604 with typer: https://github.com/tiangolo/typer/issues/348 # from __future__ import annotations +from __future__ import annotations __all__ = ("dump",) diff --git a/diracx-cli/src/diracx/cli/internal/__init__.py b/diracx-cli/src/diracx/cli/internal/__init__.py index 8fab8a68..d89a2e44 100644 --- a/diracx-cli/src/diracx/cli/internal/__init__.py +++ b/diracx-cli/src/diracx/cli/internal/__init__.py @@ -1,192 +1,8 @@ -from pathlib import Path -from typing import Annotated, Optional +from __future__ import annotations -import git -import typer -import yaml -from pydantic import TypeAdapter - -from diracx.core.config import ConfigSource, ConfigSourceUrl -from diracx.core.config.schema import ( - Config, - DIRACConfig, - GroupConfig, - IdpConfig, - OperationsConfig, - RegistryConfig, - UserConfig, -) - -from ..utils import AsyncTyper from . import legacy +from .config import app -app = AsyncTyper() -app.add_typer(legacy.app, name="legacy") - - -@app.command() -def generate_cs(config_repo: str): - """Generate a minimal DiracX configuration repository.""" - # TODO: The use of TypeAdapter should be moved in to typer itself - config_repo = TypeAdapter(ConfigSourceUrl).validate_python(config_repo) - if config_repo.scheme != "git+file" or config_repo.path is None: - raise NotImplementedError("Only git+file:// URLs are supported") - repo_path = Path(config_repo.path) - if repo_path.exists() and list(repo_path.iterdir()): - typer.echo(f"ERROR: Directory {repo_path} already exists", err=True) - raise typer.Exit(1) - - config = Config( - Registry={}, - DIRAC=DIRACConfig(), - Operations={"Defaults": OperationsConfig()}, - ) - - git.Repo.init(repo_path, initial_branch="master") - update_config_and_commit( - repo_path=repo_path, config=config, message="Initial commit" - ) - typer.echo(f"Successfully created repo in {config_repo}", err=True) - - -@app.command() -def add_vo( - config_repo: str, - *, - vo: Annotated[str, typer.Option()], - default_group: Optional[str] = "user", - idp_url: Annotated[str, typer.Option()], - idp_client_id: Annotated[str, typer.Option()], -): - """Add a registry entry (vo) to an existing configuration repository.""" - # TODO: The use of TypeAdapter should be moved in to typer itself - config_repo = TypeAdapter(ConfigSourceUrl).validate_python(config_repo) - if config_repo.scheme != "git+file" or config_repo.path is None: - raise NotImplementedError("Only git+file:// URLs are supported") - repo_path = Path(config_repo.path) - - # A VO should at least contain a default group - new_registry = RegistryConfig( - IdP=IdpConfig(URL=idp_url, ClientID=idp_client_id), - DefaultGroup=default_group, - Users={}, - Groups={ - default_group: GroupConfig( - Properties={"NormalUser"}, Quota=None, Users=set() - ) - }, - ) - - config = ConfigSource.create_from_url(backend_url=repo_path).read_config() - - if vo in config.Registry: - typer.echo(f"ERROR: VO {vo} already exists", err=True) - raise typer.Exit(1) - - config.Registry[vo] = new_registry - - update_config_and_commit( - repo_path=repo_path, - config=config, - message=f"Added vo {vo} registry (default group {default_group} and idp {idp_url})", - ) - typer.echo(f"Successfully added vo to {config_repo}", err=True) - - -@app.command() -def add_group( - config_repo: str, - *, - vo: Annotated[str, typer.Option()], - group: Annotated[str, typer.Option()], - properties: list[str] = ["NormalUser"], -): - """Add a group to an existing vo in the configuration repository.""" - # TODO: The use of TypeAdapter should be moved in to typer itself - config_repo = TypeAdapter(ConfigSourceUrl).validate_python(config_repo) - if config_repo.scheme != "git+file" or config_repo.path is None: - raise NotImplementedError("Only git+file:// URLs are supported") - repo_path = Path(config_repo.path) - - new_group = GroupConfig(Properties=set(properties), Quota=None, Users=set()) - - config = ConfigSource.create_from_url(backend_url=repo_path).read_config() +__all__ = ("app",) - if vo not in config.Registry: - typer.echo(f"ERROR: Virtual Organization {vo} does not exist", err=True) - raise typer.Exit(1) - - if group in config.Registry[vo].Groups.keys(): - typer.echo(f"ERROR: Group {group} already exists in {vo}", err=True) - raise typer.Exit(1) - - config.Registry[vo].Groups[group] = new_group - - update_config_and_commit( - repo_path=repo_path, config=config, message=f"Added group {group} in {vo}" - ) - typer.echo(f"Successfully added group to {config_repo}", err=True) - - -@app.command() -def add_user( - config_repo: str, - *, - vo: Annotated[str, typer.Option()], - groups: Annotated[Optional[list[str]], typer.Option("--group")] = None, - sub: Annotated[str, typer.Option()], - preferred_username: Annotated[str, typer.Option()], -): - """Add a user to an existing vo and group.""" - # TODO: The use of TypeAdapter should be moved in to typer itself - config_repo = TypeAdapter(ConfigSourceUrl).validate_python(config_repo) - if config_repo.scheme != "git+file" or config_repo.path is None: - raise NotImplementedError("Only git+file:// URLs are supported") - - repo_path = Path(config_repo.path) - - new_user = UserConfig(PreferedUsername=preferred_username) - - config = ConfigSource.create_from_url(backend_url=repo_path).read_config() - - if vo not in config.Registry: - typer.echo(f"ERROR: Virtual Organization {vo} does not exist", err=True) - raise typer.Exit(1) - - if sub in config.Registry[vo].Users: - typer.echo(f"ERROR: User {sub} already exists", err=True) - raise typer.Exit(1) - - config.Registry[vo].Users[sub] = new_user - - if not groups: - groups = [config.Registry[vo].DefaultGroup] - - for group in set(groups): - if group not in config.Registry[vo].Groups: - typer.echo(f"ERROR: Group {group} does not exist in {vo}", err=True) - raise typer.Exit(1) - if sub in config.Registry[vo].Groups[group].Users: - typer.echo(f"ERROR: User {sub} already exists in group {group}", err=True) - raise typer.Exit(1) - - config.Registry[vo].Groups[group].Users.add(sub) - - update_config_and_commit( - repo_path=repo_path, - config=config, - message=f"Added user {sub} ({preferred_username}) to vo {vo} and groups {groups}", - ) - typer.echo(f"Successfully added user to {config_repo}", err=True) - - -def update_config_and_commit(repo_path: Path, config: Config, message: str): - """Update the yaml file in the repo and commit it.""" - repo = git.Repo(repo_path) - yaml_path = repo_path / "default.yml" - typer.echo(f"Writing back configuration to {yaml_path}", err=True) - yaml_path.write_text( - yaml.safe_dump(config.model_dump(exclude_unset=True, mode="json")) - ) - repo.index.add([yaml_path.relative_to(repo_path)]) - repo.index.commit(message) +app.add_typer(legacy.app, name="legacy") diff --git a/diracx-cli/src/diracx/cli/internal/config.py b/diracx-cli/src/diracx/cli/internal/config.py new file mode 100644 index 00000000..1c373072 --- /dev/null +++ b/diracx-cli/src/diracx/cli/internal/config.py @@ -0,0 +1,192 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Annotated, Optional + +import git +import typer +import yaml +from pydantic import TypeAdapter + +from diracx.core.config import ConfigSource, ConfigSourceUrl +from diracx.core.config.schema import ( + Config, + DIRACConfig, + GroupConfig, + IdpConfig, + OperationsConfig, + RegistryConfig, + UserConfig, +) + +from ..utils import AsyncTyper + +app = AsyncTyper() + + +@app.command() +def generate_cs(config_repo: str): + """Generate a minimal DiracX configuration repository.""" + # TODO: The use of TypeAdapter should be moved in to typer itself + config_repo = TypeAdapter(ConfigSourceUrl).validate_python(config_repo) + if config_repo.scheme != "git+file" or config_repo.path is None: + raise NotImplementedError("Only git+file:// URLs are supported") + repo_path = Path(config_repo.path) + if repo_path.exists() and list(repo_path.iterdir()): + typer.echo(f"ERROR: Directory {repo_path} already exists", err=True) + raise typer.Exit(1) + + config = Config( + Registry={}, + DIRAC=DIRACConfig(), + Operations={"Defaults": OperationsConfig()}, + ) + + git.Repo.init(repo_path, initial_branch="master") + update_config_and_commit( + repo_path=repo_path, config=config, message="Initial commit" + ) + typer.echo(f"Successfully created repo in {config_repo}", err=True) + + +@app.command() +def add_vo( + config_repo: str, + *, + vo: Annotated[str, typer.Option()], + default_group: Optional[str] = "user", + idp_url: Annotated[str, typer.Option()], + idp_client_id: Annotated[str, typer.Option()], +): + """Add a registry entry (vo) to an existing configuration repository.""" + # TODO: The use of TypeAdapter should be moved in to typer itself + config_repo = TypeAdapter(ConfigSourceUrl).validate_python(config_repo) + if config_repo.scheme != "git+file" or config_repo.path is None: + raise NotImplementedError("Only git+file:// URLs are supported") + repo_path = Path(config_repo.path) + + # A VO should at least contain a default group + new_registry = RegistryConfig( + IdP=IdpConfig(URL=idp_url, ClientID=idp_client_id), + DefaultGroup=default_group, + Users={}, + Groups={ + default_group: GroupConfig( + Properties={"NormalUser"}, Quota=None, Users=set() + ) + }, + ) + + config = ConfigSource.create_from_url(backend_url=repo_path).read_config() + + if vo in config.Registry: + typer.echo(f"ERROR: VO {vo} already exists", err=True) + raise typer.Exit(1) + + config.Registry[vo] = new_registry + + update_config_and_commit( + repo_path=repo_path, + config=config, + message=f"Added vo {vo} registry (default group {default_group} and idp {idp_url})", + ) + typer.echo(f"Successfully added vo to {config_repo}", err=True) + + +@app.command() +def add_group( + config_repo: str, + *, + vo: Annotated[str, typer.Option()], + group: Annotated[str, typer.Option()], + properties: list[str] = ["NormalUser"], +): + """Add a group to an existing vo in the configuration repository.""" + # TODO: The use of TypeAdapter should be moved in to typer itself + config_repo = TypeAdapter(ConfigSourceUrl).validate_python(config_repo) + if config_repo.scheme != "git+file" or config_repo.path is None: + raise NotImplementedError("Only git+file:// URLs are supported") + repo_path = Path(config_repo.path) + + new_group = GroupConfig(Properties=set(properties), Quota=None, Users=set()) + + config = ConfigSource.create_from_url(backend_url=repo_path).read_config() + + if vo not in config.Registry: + typer.echo(f"ERROR: Virtual Organization {vo} does not exist", err=True) + raise typer.Exit(1) + + if group in config.Registry[vo].Groups.keys(): + typer.echo(f"ERROR: Group {group} already exists in {vo}", err=True) + raise typer.Exit(1) + + config.Registry[vo].Groups[group] = new_group + + update_config_and_commit( + repo_path=repo_path, config=config, message=f"Added group {group} in {vo}" + ) + typer.echo(f"Successfully added group to {config_repo}", err=True) + + +@app.command() +def add_user( + config_repo: str, + *, + vo: Annotated[str, typer.Option()], + groups: Annotated[Optional[list[str]], typer.Option("--group")] = None, + sub: Annotated[str, typer.Option()], + preferred_username: Annotated[str, typer.Option()], +): + """Add a user to an existing vo and group.""" + # TODO: The use of TypeAdapter should be moved in to typer itself + config_repo = TypeAdapter(ConfigSourceUrl).validate_python(config_repo) + if config_repo.scheme != "git+file" or config_repo.path is None: + raise NotImplementedError("Only git+file:// URLs are supported") + + repo_path = Path(config_repo.path) + + new_user = UserConfig(PreferedUsername=preferred_username) + + config = ConfigSource.create_from_url(backend_url=repo_path).read_config() + + if vo not in config.Registry: + typer.echo(f"ERROR: Virtual Organization {vo} does not exist", err=True) + raise typer.Exit(1) + + if sub in config.Registry[vo].Users: + typer.echo(f"ERROR: User {sub} already exists", err=True) + raise typer.Exit(1) + + config.Registry[vo].Users[sub] = new_user + + if not groups: + groups = [config.Registry[vo].DefaultGroup] + + for group in set(groups): + if group not in config.Registry[vo].Groups: + typer.echo(f"ERROR: Group {group} does not exist in {vo}", err=True) + raise typer.Exit(1) + if sub in config.Registry[vo].Groups[group].Users: + typer.echo(f"ERROR: User {sub} already exists in group {group}", err=True) + raise typer.Exit(1) + + config.Registry[vo].Groups[group].Users.add(sub) + + update_config_and_commit( + repo_path=repo_path, + config=config, + message=f"Added user {sub} ({preferred_username}) to vo {vo} and groups {groups}", + ) + typer.echo(f"Successfully added user to {config_repo}", err=True) + + +def update_config_and_commit(repo_path: Path, config: Config, message: str): + """Update the yaml file in the repo and commit it.""" + repo = git.Repo(repo_path) + yaml_path = repo_path / "default.yml" + typer.echo(f"Writing back configuration to {yaml_path}", err=True) + yaml_path.write_text( + yaml.safe_dump(config.model_dump(exclude_unset=True, mode="json")) + ) + repo.index.add([yaml_path.relative_to(repo_path)]) + repo.index.commit(message) diff --git a/diracx-cli/src/diracx/cli/internal/legacy.py b/diracx-cli/src/diracx/cli/internal/legacy.py index 509b89c6..03648e7c 100644 --- a/diracx-cli/src/diracx/cli/internal/legacy.py +++ b/diracx-cli/src/diracx/cli/internal/legacy.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import hashlib import json @@ -40,17 +42,6 @@ class ConversionConfig(BaseModel): VOs: dict[str, VOConfig] -# def parse_args(): -# parser = argparse.ArgumentParser("Convert the legacy DIRAC CS to the new format") -# parser.add_argument("old_file", type=Path) -# parser.add_argument("conversion_config", type=Path) -# parser.add_argument("repo", type=Path) -# args = parser.parse_args() - - -# main(args.old_file, args.conversion_config, args.repo / DEFAULT_CONFIG_FILE) - - @app.command() def cs_sync(old_file: Path, new_file: Path): """Load the old CS and convert it to the new YAML format.""" diff --git a/diracx-cli/src/diracx/cli/jobs.py b/diracx-cli/src/diracx/cli/jobs.py index 87fdc99b..fa4dfc8a 100644 --- a/diracx-cli/src/diracx/cli/jobs.py +++ b/diracx-cli/src/diracx/cli/jobs.py @@ -1,5 +1,6 @@ # Can't using PEP-604 with typer: https://github.com/tiangolo/typer/issues/348 # from __future__ import annotations +from __future__ import annotations __all__ = ("app",) diff --git a/diracx-cli/tests/legacy/cs_sync/test_cssync.py b/diracx-cli/tests/legacy/cs_sync/test_cssync.py index a4f67717..25febe53 100644 --- a/diracx-cli/tests/legacy/cs_sync/test_cssync.py +++ b/diracx-cli/tests/legacy/cs_sync/test_cssync.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from pathlib import Path import yaml diff --git a/diracx-cli/tests/legacy/test_legacy.py b/diracx-cli/tests/legacy/test_legacy.py index f8c8c195..1efacf6f 100644 --- a/diracx-cli/tests/legacy/test_legacy.py +++ b/diracx-cli/tests/legacy/test_legacy.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from pathlib import Path import yaml diff --git a/diracx-cli/tests/test_login.py b/diracx-cli/tests/test_login.py index 4bfcf97b..a629fa20 100644 --- a/diracx-cli/tests/test_login.py +++ b/diracx-cli/tests/test_login.py @@ -19,7 +19,7 @@ async def test_logout(monkeypatch, capfd, cli_env, with_cli_login): assert expected_credentials_path.exists() # Run the logout command - await cli.logout() + await cli.auth.logout() captured = capfd.readouterr() assert "Removed credentials from" in captured.out assert "Logout successful!" in captured.out @@ -29,7 +29,7 @@ async def test_logout(monkeypatch, capfd, cli_env, with_cli_login): assert not expected_credentials_path.exists() # Rerun the logout command, it should not fail - await cli.logout() + await cli.auth.logout() captured = capfd.readouterr() assert "Removed credentials from" not in captured.out assert "Logout successful!" in captured.out diff --git a/diracx-client/src/diracx/client/aio.py b/diracx-client/src/diracx/client/aio.py index 25ed7778..f5fb6725 100644 --- a/diracx-client/src/diracx/client/aio.py +++ b/diracx-client/src/diracx/client/aio.py @@ -1 +1,3 @@ from .patches.aio import DiracClient + +__all__ = ("DiracClient",) diff --git a/diracx-client/src/diracx/client/models.py b/diracx-client/src/diracx/client/models.py index 15dd42a0..2f4e3577 100644 --- a/diracx-client/src/diracx/client/models.py +++ b/diracx-client/src/diracx/client/models.py @@ -3,3 +3,5 @@ # TODO: replace with postprocess from .generated.models import DeviceFlowErrorResponse + +__all__ = ("DeviceFlowErrorResponse",) diff --git a/diracx-client/tests/test_auth.py b/diracx-client/tests/test_auth.py index 75cb05af..05f0079d 100644 --- a/diracx-client/tests/test_auth.py +++ b/diracx-client/tests/test_auth.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import fcntl import json from datetime import datetime, timedelta, timezone diff --git a/diracx-client/tests/test_regenerate.py b/diracx-client/tests/test_regenerate.py index ab3686bc..e0e032ac 100644 --- a/diracx-client/tests/test_regenerate.py +++ b/diracx-client/tests/test_regenerate.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import subprocess from pathlib import Path diff --git a/diracx-core/src/diracx/core/config/__init__.py b/diracx-core/src/diracx/core/config/__init__.py index 0283767c..37ff457c 100644 --- a/diracx-core/src/diracx/core/config/__init__.py +++ b/diracx-core/src/diracx/core/config/__init__.py @@ -1,262 +1,21 @@ -"""This module implements the logic of the configuration server side. - -This is where all the backend abstraction and the caching logic takes place. -""" +"""Configuration module: Provides tools for managing backend configurations.""" from __future__ import annotations -__all__ = ("Config", "ConfigSource", "LocalGitConfigSource", "RemoteGitConfigSource") - -import asyncio -import logging -import os -from abc import ABCMeta, abstractmethod -from datetime import datetime, timezone -from pathlib import Path -from tempfile import TemporaryDirectory -from typing import Annotated - -import sh -import yaml -from cachetools import Cache, LRUCache, TTLCache, cachedmethod -from pydantic import AnyUrl, BeforeValidator, TypeAdapter, UrlConstraints - -from ..exceptions import BadConfigurationVersionError -from ..extensions import select_from_extension from .schema import Config - -DEFAULT_CONFIG_FILE = "default.yml" -DEFAULT_GIT_BRANCH = "master" -DEFAULT_CS_CACHE_TTL = 5 -MAX_CS_CACHED_VERSIONS = 1 -DEFAULT_PULL_CACHE_TTL = 5 -MAX_PULL_CACHED_VERSIONS = 1 - -logger = logging.getLogger(__name__) - - -def is_running_in_async_context(): - try: - asyncio.get_running_loop() - return True - except RuntimeError: - return False - - -def _apply_default_scheme(value: str) -> str: - """Applies the default git+file:// scheme if not present.""" - if isinstance(value, str) and "://" not in value: - value = f"git+file://{value}" - return value - - -class AnyUrlWithoutHost(AnyUrl): - - _constraints = UrlConstraints(host_required=False) - - -ConfigSourceUrl = Annotated[AnyUrlWithoutHost, BeforeValidator(_apply_default_scheme)] - - -class ConfigSource(metaclass=ABCMeta): - """This class is the abstract base class that should be used everywhere - throughout the code. - It acts as a factory for concrete implementations - See the abstractmethods to implement a concrete class. - """ - - # Keep a mapping between the scheme and the class - __registry: dict[str, type[ConfigSource]] = {} - scheme: str - - @abstractmethod - def __init__(self, *, backend_url: ConfigSourceUrl) -> None: ... - - @abstractmethod - def latest_revision(self) -> tuple[str, datetime]: - """Must return: - * a unique hash as a string, representing the last version - * a datetime object corresponding to when the version dates. - """ - ... - - @abstractmethod - def read_raw(self, hexsha: str, modified: datetime) -> Config: - """Return the Config object that corresponds to the - specific hash - The `modified` parameter is just added as a attribute to the config. - - """ - ... - - def __init_subclass__(cls) -> None: - """Keep a record of .""" - if cls.scheme in cls.__registry: - raise TypeError(f"{cls.scheme=} is already define") - cls.__registry[cls.scheme] = cls - - @classmethod - def create(cls): - return cls.create_from_url(backend_url=os.environ["DIRACX_CONFIG_BACKEND_URL"]) - - @classmethod - def create_from_url( - cls, *, backend_url: ConfigSourceUrl | Path | str - ) -> ConfigSource: - """Factory method to produce a concrete instance depending on - the backend URL scheme. - - """ - url = TypeAdapter(ConfigSourceUrl).validate_python(str(backend_url)) - return cls.__registry[url.scheme](backend_url=url) - - def read_config(self) -> Config: - """:raises: - git.exc.BadName if version does not exist - """ - hexsha, modified = self.latest_revision() - return self.read_raw(hexsha, modified) - - @abstractmethod - def clear_caches(self): ... - - -class BaseGitConfigSource(ConfigSource): - """Base class for the git based config source - The caching is based on 2 caches: - * TTL to find the latest commit hashes - * LRU to keep in memory the last few versions. - """ - - repo_location: Path - - # Needed because of the ConfigSource.__init_subclass__ - scheme = "basegit" - - def __init__(self, *, backend_url: ConfigSourceUrl) -> None: - self._latest_revision_cache: Cache = TTLCache( - MAX_CS_CACHED_VERSIONS, DEFAULT_CS_CACHE_TTL - ) - self._read_raw_cache: Cache = LRUCache(MAX_CS_CACHED_VERSIONS) - - @cachedmethod(lambda self: self._latest_revision_cache) - def latest_revision(self) -> tuple[str, datetime]: - try: - rev = sh.git( - "rev-parse", - DEFAULT_GIT_BRANCH, - _cwd=self.repo_location, - _tty_out=False, - _async=is_running_in_async_context(), - ).strip() - commit_info = sh.git.show( - "-s", - "--format=%ct", - rev, - _cwd=self.repo_location, - _tty_out=False, - _async=is_running_in_async_context(), - ).strip() - modified = datetime.fromtimestamp(int(commit_info), tz=timezone.utc) - except sh.ErrorReturnCode as e: - raise BadConfigurationVersionError( - f"Error parsing latest revision: {e}" - ) from e - logger.debug("Latest revision for %s is %s with mtime %s", self, rev, modified) - return rev, modified - - @cachedmethod(lambda self: self._read_raw_cache) - def read_raw(self, hexsha: str, modified: datetime) -> Config: - """:param: hexsha commit hash""" - logger.debug("Reading %s for %s with mtime %s", self, hexsha, modified) - try: - blob = sh.git.show( - f"{hexsha}:{DEFAULT_CONFIG_FILE}", - _cwd=self.repo_location, - _tty_out=False, - _async=False, - ) - raw_obj = yaml.safe_load(blob) - except sh.ErrorReturnCode as e: - raise BadConfigurationVersionError( - f"Error reading configuration: {e}" - ) from e - - config_class: Config = select_from_extension(group="diracx", name="config")[ - 0 - ].load() - config = config_class.model_validate(raw_obj) - config._hexsha = hexsha - config._modified = modified - return config - - def clear_caches(self): - self._latest_revision_cache.clear() - self._read_raw_cache.clear() - - -class LocalGitConfigSource(BaseGitConfigSource): - """The configuration is stored on a local git repository - When running on multiple servers, the filesystem must be shared. - """ - - scheme = "git+file" - - def __init__(self, *, backend_url: ConfigSourceUrl) -> None: - super().__init__(backend_url=backend_url) - if not backend_url.path: - raise ValueError("Empty path for LocalGitConfigSource") - - self.repo_location = Path(backend_url.path) - # Check if it's a valid git repository - try: - sh.git( - "rev-parse", - "--git-dir", - _cwd=self.repo_location, - _tty_out=False, - _async=False, - ) - except sh.ErrorReturnCode as e: - raise ValueError( - f"{self.repo_location} is not a valid git repository" - ) from e - - def __hash__(self): - return hash(self.repo_location) - - -class RemoteGitConfigSource(BaseGitConfigSource): - """Use a remote directory as a config source.""" - - scheme = "git+https" - - def __init__(self, *, backend_url: ConfigSourceUrl) -> None: - super().__init__(backend_url=backend_url) - if not backend_url: - raise ValueError("No remote url for RemoteGitConfigSource") - - # git does not understand `git+https`, so we remove the `git+` part - self.remote_url = str(backend_url).replace("git+", "") - self._temp_dir = TemporaryDirectory() - self.repo_location = Path(self._temp_dir.name) - sh.git.clone(self.remote_url, self.repo_location, _async=False) - self._pull_cache: Cache = TTLCache( - MAX_PULL_CACHED_VERSIONS, DEFAULT_PULL_CACHE_TTL - ) - - def clear_caches(self): - super().clear_caches() - self._pull_cache.clear() - - def __hash__(self): - return hash(self.repo_location) - - @cachedmethod(lambda self: self._pull_cache) - def _pull(self): - """Git pull from remote repo.""" - sh.git.pull(_cwd=self.repo_location, _async=False) - - def latest_revision(self) -> tuple[str, datetime]: - self._pull() - return super().latest_revision() +from .sources import ( + ConfigSource, + ConfigSourceUrl, + LocalGitConfigSource, + RemoteGitConfigSource, + is_running_in_async_context, +) + +__all__ = ( + "Config", + "ConfigSource", + "ConfigSourceUrl", + "LocalGitConfigSource", + "RemoteGitConfigSource", + "is_running_in_async_context", +) diff --git a/diracx-core/src/diracx/core/config/sources.py b/diracx-core/src/diracx/core/config/sources.py new file mode 100644 index 00000000..2dd6eaef --- /dev/null +++ b/diracx-core/src/diracx/core/config/sources.py @@ -0,0 +1,260 @@ +"""This module implements the logic of the configuration server side. + +This is where all the backend abstraction and the caching logic takes place. +""" + +from __future__ import annotations + +import asyncio +import logging +import os +from abc import ABCMeta, abstractmethod +from datetime import datetime, timezone +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import Annotated + +import sh +import yaml +from cachetools import Cache, LRUCache, TTLCache, cachedmethod +from pydantic import AnyUrl, BeforeValidator, TypeAdapter, UrlConstraints + +from ..exceptions import BadConfigurationVersionError +from ..extensions import select_from_extension +from .schema import Config + +DEFAULT_CONFIG_FILE = "default.yml" +DEFAULT_GIT_BRANCH = "master" +DEFAULT_CS_CACHE_TTL = 5 +MAX_CS_CACHED_VERSIONS = 1 +DEFAULT_PULL_CACHE_TTL = 5 +MAX_PULL_CACHED_VERSIONS = 1 + +logger = logging.getLogger(__name__) + + +def is_running_in_async_context(): + try: + asyncio.get_running_loop() + return True + except RuntimeError: + return False + + +def _apply_default_scheme(value: str) -> str: + """Applies the default git+file:// scheme if not present.""" + if isinstance(value, str) and "://" not in value: + value = f"git+file://{value}" + return value + + +class AnyUrlWithoutHost(AnyUrl): + + _constraints = UrlConstraints(host_required=False) + + +ConfigSourceUrl = Annotated[AnyUrlWithoutHost, BeforeValidator(_apply_default_scheme)] + + +class ConfigSource(metaclass=ABCMeta): + """This class is the abstract base class that should be used everywhere + throughout the code. + It acts as a factory for concrete implementations + See the abstractmethods to implement a concrete class. + """ + + # Keep a mapping between the scheme and the class + __registry: dict[str, type["ConfigSource"]] = {} + scheme: str + + @abstractmethod + def __init__(self, *, backend_url: ConfigSourceUrl) -> None: ... + + @abstractmethod + def latest_revision(self) -> tuple[str, datetime]: + """Must return: + * a unique hash as a string, representing the last version + * a datetime object corresponding to when the version dates. + """ + ... + + @abstractmethod + def read_raw(self, hexsha: str, modified: datetime) -> Config: + """Return the Config object that corresponds to the + specific hash + The `modified` parameter is just added as a attribute to the config. + + """ + ... + + def __init_subclass__(cls) -> None: + """Keep a record of .""" + if cls.scheme in cls.__registry: + raise TypeError(f"{cls.scheme=} is already define") + cls.__registry[cls.scheme] = cls + + @classmethod + def create(cls): + return cls.create_from_url(backend_url=os.environ["DIRACX_CONFIG_BACKEND_URL"]) + + @classmethod + def create_from_url( + cls, *, backend_url: ConfigSourceUrl | Path | str + ) -> "ConfigSource": + """Factory method to produce a concrete instance depending on + the backend URL scheme. + + """ + url = TypeAdapter(ConfigSourceUrl).validate_python(str(backend_url)) + return cls.__registry[url.scheme](backend_url=url) + + def read_config(self) -> Config: + """:raises: + git.exc.BadName if version does not exist + """ + hexsha, modified = self.latest_revision() + return self.read_raw(hexsha, modified) + + @abstractmethod + def clear_caches(self): ... + + +class BaseGitConfigSource(ConfigSource): + """Base class for the git based config source + The caching is based on 2 caches: + * TTL to find the latest commit hashes + * LRU to keep in memory the last few versions. + """ + + repo_location: Path + + # Needed because of the ConfigSource.__init_subclass__ + scheme = "basegit" + + def __init__(self, *, backend_url: ConfigSourceUrl) -> None: + self._latest_revision_cache: Cache = TTLCache( + MAX_CS_CACHED_VERSIONS, DEFAULT_CS_CACHE_TTL + ) + self._read_raw_cache: Cache = LRUCache(MAX_CS_CACHED_VERSIONS) + + @cachedmethod(lambda self: self._latest_revision_cache) + def latest_revision(self) -> tuple[str, datetime]: + try: + rev = sh.git( + "rev-parse", + DEFAULT_GIT_BRANCH, + _cwd=self.repo_location, + _tty_out=False, + _async=is_running_in_async_context(), + ).strip() + commit_info = sh.git.show( + "-s", + "--format=%ct", + rev, + _cwd=self.repo_location, + _tty_out=False, + _async=is_running_in_async_context(), + ).strip() + modified = datetime.fromtimestamp(int(commit_info), tz=timezone.utc) + except sh.ErrorReturnCode as e: + raise BadConfigurationVersionError( + f"Error parsing latest revision: {e}" + ) from e + logger.debug("Latest revision for %s is %s with mtime %s", self, rev, modified) + return rev, modified + + @cachedmethod(lambda self: self._read_raw_cache) + def read_raw(self, hexsha: str, modified: datetime) -> Config: + """:param: hexsha commit hash""" + logger.debug("Reading %s for %s with mtime %s", self, hexsha, modified) + try: + blob = sh.git.show( + f"{hexsha}:{DEFAULT_CONFIG_FILE}", + _cwd=self.repo_location, + _tty_out=False, + _async=False, + ) + raw_obj = yaml.safe_load(blob) + except sh.ErrorReturnCode as e: + raise BadConfigurationVersionError( + f"Error reading configuration: {e}" + ) from e + + config_class: Config = select_from_extension(group="diracx", name="config")[ + 0 + ].load() + config = config_class.model_validate(raw_obj) + config._hexsha = hexsha + config._modified = modified + return config + + def clear_caches(self): + self._latest_revision_cache.clear() + self._read_raw_cache.clear() + + +class LocalGitConfigSource(BaseGitConfigSource): + """The configuration is stored on a local git repository + When running on multiple servers, the filesystem must be shared. + """ + + scheme = "git+file" + + def __init__(self, *, backend_url: ConfigSourceUrl) -> None: + super().__init__(backend_url=backend_url) + if not backend_url.path: + raise ValueError("Empty path for LocalGitConfigSource") + + self.repo_location = Path(backend_url.path) + # Check if it's a valid git repository + try: + sh.git( + "rev-parse", + "--git-dir", + _cwd=self.repo_location, + _tty_out=False, + _async=False, + ) + except sh.ErrorReturnCode as e: + raise ValueError( + f"{self.repo_location} is not a valid git repository" + ) from e + + def __hash__(self): + return hash(self.repo_location) + + +class RemoteGitConfigSource(BaseGitConfigSource): + """Use a remote directory as a config source.""" + + scheme = "git+https" + + def __init__(self, *, backend_url: ConfigSourceUrl) -> None: + super().__init__(backend_url=backend_url) + if not backend_url: + raise ValueError("No remote url for RemoteGitConfigSource") + + # git does not understand `git+https`, so we remove the `git+` part + self.remote_url = str(backend_url).replace("git+", "") + self._temp_dir = TemporaryDirectory() + self.repo_location = Path(self._temp_dir.name) + sh.git.clone(self.remote_url, self.repo_location, _async=False) + self._pull_cache: Cache = TTLCache( + MAX_PULL_CACHED_VERSIONS, DEFAULT_PULL_CACHE_TTL + ) + + def clear_caches(self): + super().clear_caches() + self._pull_cache.clear() + + def __hash__(self): + return hash(self.repo_location) + + @cachedmethod(lambda self: self._pull_cache) + def _pull(self): + """Git pull from remote repo.""" + sh.git.pull(_cwd=self.repo_location, _async=False) + + def latest_revision(self) -> tuple[str, datetime]: + self._pull() + return super().latest_revision() diff --git a/diracx-core/src/diracx/core/exceptions.py b/diracx-core/src/diracx/core/exceptions.py index 79834b1c..79eb0fbd 100644 --- a/diracx-core/src/diracx/core/exceptions.py +++ b/diracx-core/src/diracx/core/exceptions.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from http import HTTPStatus diff --git a/diracx-core/src/diracx/core/extensions.py b/diracx-core/src/diracx/core/extensions.py index 100ed936..28a02570 100644 --- a/diracx-core/src/diracx/core/extensions.py +++ b/diracx-core/src/diracx/core/extensions.py @@ -1,3 +1,5 @@ +from __future__ import annotations + __all__ = ("select_from_extension",) import os diff --git a/diracx-core/tests/test_config_source.py b/diracx-core/tests/test_config_source.py index bb4a2e3b..a889690f 100644 --- a/diracx-core/tests/test_config_source.py +++ b/diracx-core/tests/test_config_source.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import datetime from urllib import request @@ -23,7 +25,7 @@ def github_is_down(): def test_remote_git_config_source(monkeypatch): monkeypatch.setattr( - "diracx.core.config.DEFAULT_CONFIG_FILE", + "diracx.core.config.sources.DEFAULT_CONFIG_FILE", "k3s/examples/cs.yaml", ) remote_conf = ConfigSource.create_from_url(backend_url=TEST_REPO) diff --git a/diracx-core/tests/test_extensions.py b/diracx-core/tests/test_extensions.py index 3bff1927..67c2d041 100644 --- a/diracx-core/tests/test_extensions.py +++ b/diracx-core/tests/test_extensions.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from diracx.core.extensions import extensions_by_priority diff --git a/diracx-db/src/diracx/db/exceptions.py b/diracx-db/src/diracx/db/exceptions.py index 0a163f92..36a3e278 100644 --- a/diracx-db/src/diracx/db/exceptions.py +++ b/diracx-db/src/diracx/db/exceptions.py @@ -1,2 +1,5 @@ +from __future__ import annotations + + class DBUnavailableError(Exception): pass diff --git a/diracx-db/src/diracx/db/sql/auth/schema.py b/diracx-db/src/diracx/db/sql/auth/schema.py index 8d7dddc7..d4397cc8 100644 --- a/diracx-db/src/diracx/db/sql/auth/schema.py +++ b/diracx-db/src/diracx/db/sql/auth/schema.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from enum import Enum, auto from sqlalchemy import ( diff --git a/diracx-db/src/diracx/db/sql/dummy/schema.py b/diracx-db/src/diracx/db/sql/dummy/schema.py index a0c11c09..5379de94 100644 --- a/diracx-db/src/diracx/db/sql/dummy/schema.py +++ b/diracx-db/src/diracx/db/sql/dummy/schema.py @@ -1,5 +1,7 @@ # The utils class define some boilerplate types that should be used # in place of the SQLAlchemy one. Have a look at them +from __future__ import annotations + from sqlalchemy import ForeignKey, Integer, String, Uuid from sqlalchemy.orm import declarative_base diff --git a/diracx-db/src/diracx/db/sql/job/schema.py b/diracx-db/src/diracx/db/sql/job/schema.py index eea1e3a1..a9f3a4bb 100644 --- a/diracx-db/src/diracx/db/sql/job/schema.py +++ b/diracx-db/src/diracx/db/sql/job/schema.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from sqlalchemy import ( DateTime, Enum, diff --git a/diracx-db/src/diracx/db/sql/job_logging/schema.py b/diracx-db/src/diracx/db/sql/job_logging/schema.py index 1c229bb7..b99503c0 100644 --- a/diracx-db/src/diracx/db/sql/job_logging/schema.py +++ b/diracx-db/src/diracx/db/sql/job_logging/schema.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from sqlalchemy import ( Integer, Numeric, diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/schema.py b/diracx-db/src/diracx/db/sql/pilot_agents/schema.py index 76cd5c89..bff7c460 100644 --- a/diracx-db/src/diracx/db/sql/pilot_agents/schema.py +++ b/diracx-db/src/diracx/db/sql/pilot_agents/schema.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from sqlalchemy import ( DateTime, Double, diff --git a/diracx-db/src/diracx/db/sql/sandbox_metadata/db.py b/diracx-db/src/diracx/db/sql/sandbox_metadata/db.py index 28462778..9eb9dc22 100644 --- a/diracx-db/src/diracx/db/sql/sandbox_metadata/db.py +++ b/diracx-db/src/diracx/db/sql/sandbox_metadata/db.py @@ -5,7 +5,7 @@ import sqlalchemy from diracx.core.models import SandboxInfo, SandboxType, UserInfo -from diracx.db.sql.utils import BaseSQLDB, UTCNow +from diracx.db.sql.utils import BaseSQLDB, utcnow from .schema import Base as SandboxMetadataDBBase from .schema import SandBoxes, SBEntityMapping, SBOwners @@ -58,8 +58,8 @@ async def insert_sandbox( SEName=se_name, SEPFN=pfn, Bytes=size, - RegistrationTime=UTCNow(), - LastAccessTime=UTCNow(), + RegistrationTime=utcnow(), + LastAccessTime=utcnow(), ) try: result = await self.conn.execute(stmt) @@ -72,7 +72,7 @@ async def update_sandbox_last_access_time(self, se_name: str, pfn: str) -> None: stmt = ( sqlalchemy.update(SandBoxes) .where(SandBoxes.SEName == se_name, SandBoxes.SEPFN == pfn) - .values(LastAccessTime=UTCNow()) + .values(LastAccessTime=utcnow()) ) result = await self.conn.execute(stmt) assert result.rowcount == 1 diff --git a/diracx-db/src/diracx/db/sql/sandbox_metadata/schema.py b/diracx-db/src/diracx/db/sql/sandbox_metadata/schema.py index 5864ea42..1c1133ff 100644 --- a/diracx-db/src/diracx/db/sql/sandbox_metadata/schema.py +++ b/diracx-db/src/diracx/db/sql/sandbox_metadata/schema.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from sqlalchemy import ( BigInteger, Boolean, diff --git a/diracx-db/src/diracx/db/sql/task_queue/schema.py b/diracx-db/src/diracx/db/sql/task_queue/schema.py index 6cfe2adc..0a3c0f03 100644 --- a/diracx-db/src/diracx/db/sql/task_queue/schema.py +++ b/diracx-db/src/diracx/db/sql/task_queue/schema.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from sqlalchemy import ( BigInteger, Boolean, diff --git a/diracx-db/src/diracx/db/sql/utils/__init__.py b/diracx-db/src/diracx/db/sql/utils/__init__.py index eafc4d3b..cd82d3c7 100644 --- a/diracx-db/src/diracx/db/sql/utils/__init__.py +++ b/diracx-db/src/diracx/db/sql/utils/__init__.py @@ -1,453 +1,24 @@ from __future__ import annotations -__all__ = ("utcnow", "Column", "NullColumn", "DateNowColumn", "BaseSQLDB") - -import contextlib -import logging -import os -import re -from abc import ABCMeta -from collections.abc import AsyncIterator -from contextvars import ContextVar -from datetime import datetime, timedelta, timezone -from functools import partial -from typing import TYPE_CHECKING, Self, cast - -import sqlalchemy.types as types -from pydantic import TypeAdapter -from sqlalchemy import Column as RawColumn -from sqlalchemy import DateTime, Enum, MetaData, func, select -from sqlalchemy.exc import OperationalError -from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine -from sqlalchemy.ext.compiler import compiles -from sqlalchemy.sql import expression - -from diracx.core.exceptions import InvalidQueryError -from diracx.core.extensions import select_from_extension -from diracx.core.models import SortDirection -from diracx.core.settings import SqlalchemyDsn -from diracx.db.exceptions import DBUnavailableError - -if TYPE_CHECKING: - from sqlalchemy.types import TypeEngine - -logger = logging.getLogger(__name__) - - -class UTCNow(expression.FunctionElement): - type: TypeEngine = DateTime() - inherit_cache: bool = True - - -@compiles(UTCNow, "postgresql") -def pg_utcnow(element, compiler, **kw) -> str: - return "TIMEZONE('utc', CURRENT_TIMESTAMP)" - - -@compiles(UTCNow, "mssql") -def ms_utcnow(element, compiler, **kw) -> str: - return "GETUTCDATE()" - - -@compiles(UTCNow, "mysql") -def mysql_utcnow(element, compiler, **kw) -> str: - return "(UTC_TIMESTAMP)" - - -@compiles(UTCNow, "sqlite") -def sqlite_utcnow(element, compiler, **kw) -> str: - return "DATETIME('now')" - - -class DateTrunc(expression.FunctionElement): - """Sqlalchemy function to truncate a date to a given resolution. - - Primarily used to be able to query for a specific resolution of a date e.g. - - select * from table where date_trunc('day', date_column) = '2021-01-01' - select * from table where date_trunc('year', date_column) = '2021' - select * from table where date_trunc('minute', date_column) = '2021-01-01 12:00' - """ - - type = DateTime() - inherit_cache = True - - def __init__(self, *args, time_resolution, **kwargs) -> None: - super().__init__(*args, **kwargs) - self._time_resolution = time_resolution - - -@compiles(DateTrunc, "postgresql") -def pg_date_trunc(element, compiler, **kw): - res = { - "SECOND": "second", - "MINUTE": "minute", - "HOUR": "hour", - "DAY": "day", - "MONTH": "month", - "YEAR": "year", - }[element._time_resolution] - return f"date_trunc('{res}', {compiler.process(element.clauses)})" - - -@compiles(DateTrunc, "mysql") -def mysql_date_trunc(element, compiler, **kw): - pattern = { - "SECOND": "%Y-%m-%d %H:%i:%S", - "MINUTE": "%Y-%m-%d %H:%i", - "HOUR": "%Y-%m-%d %H", - "DAY": "%Y-%m-%d", - "MONTH": "%Y-%m", - "YEAR": "%Y", - }[element._time_resolution] - - (dt_col,) = list(element.clauses) - return compiler.process(func.date_format(dt_col, pattern)) - - -@compiles(DateTrunc, "sqlite") -def sqlite_date_trunc(element, compiler, **kw): - pattern = { - "SECOND": "%Y-%m-%d %H:%M:%S", - "MINUTE": "%Y-%m-%d %H:%M", - "HOUR": "%Y-%m-%d %H", - "DAY": "%Y-%m-%d", - "MONTH": "%Y-%m", - "YEAR": "%Y", - }[element._time_resolution] - (dt_col,) = list(element.clauses) - return compiler.process( - func.strftime( - pattern, - dt_col, - ) - ) - - -def substract_date(**kwargs: float) -> datetime: - return datetime.now(tz=timezone.utc) - timedelta(**kwargs) - - -Column: partial[RawColumn] = partial(RawColumn, nullable=False) -NullColumn: partial[RawColumn] = partial(RawColumn, nullable=True) -DateNowColumn = partial(Column, type_=DateTime(timezone=True), server_default=UTCNow()) - - -def EnumColumn(name, enum_type, **kwargs): # noqa: N802 - return Column(name, Enum(enum_type, native_enum=False, length=16), **kwargs) - - -class EnumBackedBool(types.TypeDecorator): - """Maps a ``EnumBackedBool()`` column to True/False in Python.""" - - impl = types.Enum - cache_ok: bool = True - - def __init__(self) -> None: - super().__init__("True", "False") - - def process_bind_param(self, value, dialect) -> str: - if value is True: - return "True" - elif value is False: - return "False" - else: - raise NotImplementedError(value, dialect) - - def process_result_value(self, value, dialect) -> bool: - if value == "True": - return True - elif value == "False": - return False - else: - raise NotImplementedError(f"Unknown {value=}") - - -class SQLDBError(Exception): - pass - - -class SQLDBUnavailableError(DBUnavailableError, SQLDBError): - """Used whenever we encounter a problem with the B connection.""" - - -class BaseSQLDB(metaclass=ABCMeta): - """This should be the base class of all the SQL DiracX DBs. - - The details covered here should be handled automatically by the service and - task machinery of DiracX and this documentation exists for informational - purposes. - - The available databases are discovered by calling `BaseSQLDB.available_urls`. - This method returns a mapping of database names to connection URLs. The - available databases are determined by the `diracx.dbs.sql` entrypoint in the - `pyproject.toml` file and the connection URLs are taken from the environment - variables of the form `DIRACX_DB_URL_`. - - If extensions to DiracX are being used, there can be multiple implementations - of the same database. To list the available implementations use - `BaseSQLDB.available_implementations(db_name)`. The first entry in this list - will be the preferred implementation and it can be initialized by calling - it's `__init__` function with a URL perviously obtained from - `BaseSQLDB.available_urls`. - - To control the lifetime of the SQLAlchemy engine used for connecting to the - database, which includes the connection pool, the `BaseSQLDB.engine_context` - asynchronous context manager should be entered. When inside this context - manager, the engine can be accessed with `BaseSQLDB.engine`. - - Upon entering, the DB class can then be used as an asynchronous context - manager to enter transactions. If an exception is raised the transaction is - rolled back automatically. If the inner context exits peacefully, the - transaction is committed automatically. When inside this context manager, - the DB connection can be accessed with `BaseSQLDB.conn`. - - For example: - - ```python - db_name = ... - url = BaseSQLDB.available_urls()[db_name] - MyDBClass = BaseSQLDB.available_implementations(db_name)[0] - - db = MyDBClass(url) - async with db.engine_context: - async with db: - # Do something in the first transaction - # Commit will be called automatically - - async with db: - # This transaction will be rolled back due to the exception - raise Exception(...) - ``` - """ - - # engine: AsyncEngine - # TODO: Make metadata an abstract property - metadata: MetaData - - def __init__(self, db_url: str) -> None: - # We use a ContextVar to make sure that self._conn - # is specific to each context, and avoid parallel - # route executions to overlap - self._conn: ContextVar[AsyncConnection | None] = ContextVar( - "_conn", default=None - ) - self._db_url = db_url - self._engine: AsyncEngine | None = None - - @classmethod - def available_implementations(cls, db_name: str) -> list[type[BaseSQLDB]]: - """Return the available implementations of the DB in reverse priority order.""" - db_classes: list[type[BaseSQLDB]] = [ - entry_point.load() - for entry_point in select_from_extension( - group="diracx.db.sql", name=db_name - ) - ] - if not db_classes: - raise NotImplementedError(f"Could not find any matches for {db_name=}") - return db_classes - - @classmethod - def available_urls(cls) -> dict[str, str]: - """Return a dict of available database urls. - - The list of available URLs is determined by environment variables - prefixed with ``DIRACX_DB_URL_{DB_NAME}``. - """ - db_urls: dict[str, str] = {} - for entry_point in select_from_extension(group="diracx.db.sql"): - db_name = entry_point.name - var_name = f"DIRACX_DB_URL_{entry_point.name.upper()}" - if var_name in os.environ: - try: - db_url = os.environ[var_name] - if db_url == "sqlite+aiosqlite:///:memory:": - db_urls[db_name] = db_url - else: - db_urls[db_name] = str( - TypeAdapter(SqlalchemyDsn).validate_python(db_url) - ) - except Exception: - logger.error("Error loading URL for %s", db_name) - raise - return db_urls - - @classmethod - def transaction(cls) -> Self: - raise NotImplementedError("This should never be called") - - @property - def engine(self) -> AsyncEngine: - """The engine to use for database operations. - - It is normally not necessary to use the engine directly, unless you are - doing something special, like writing a test fixture that gives you a db. - - Requires that the engine_context has been entered. - """ - assert self._engine is not None, "engine_context must be entered" - return self._engine - - @contextlib.asynccontextmanager - async def engine_context(self) -> AsyncIterator[None]: - """Context manage to manage the engine lifecycle. - - This is called once at the application startup (see ``lifetime_functions``). - """ - assert self._engine is None, "engine_context cannot be nested" - - # Set the pool_recycle to 30mn - # That should prevent the problem of MySQL expiring connection - # after 60mn by default - engine = create_async_engine(self._db_url, pool_recycle=60 * 30) - self._engine = engine - try: - yield - finally: - self._engine = None - await engine.dispose() - - @property - def conn(self) -> AsyncConnection: - if self._conn.get() is None: - raise RuntimeError(f"{self.__class__} was used before entering") - return cast(AsyncConnection, self._conn.get()) - - async def __aenter__(self) -> Self: - """Create a connection. - - This is called by the Dependency mechanism (see ``db_transaction``), - It will create a new connection/transaction for each route call. - """ - assert self._conn.get() is None, "BaseSQLDB context cannot be nested" - try: - self._conn.set(await self.engine.connect().__aenter__()) - except Exception as e: - raise SQLDBUnavailableError( - f"Cannot connect to {self.__class__.__name__}" - ) from e - - return self - - async def __aexit__(self, exc_type, exc, tb): - """This is called when exiting a route. - - If there was no exception, the changes in the DB are committed. - Otherwise, they are rolled back. - """ - if exc_type is None: - await self._conn.get().commit() - await self._conn.get().__aexit__(exc_type, exc, tb) - self._conn.set(None) - - async def ping(self): - """Check whether the connection to the DB is still working. - - We could enable the ``pre_ping`` in the engine, but this would be ran at - every query. - """ - try: - await self.conn.scalar(select(1)) - except OperationalError as e: - raise SQLDBUnavailableError("Cannot ping the DB") from e - - -def find_time_resolution(value): - if isinstance(value, datetime): - return None, value - if match := re.fullmatch( - r"\d{4}(-\d{2}(-\d{2}(([ T])\d{2}(:\d{2}(:\d{2}(\.\d{6}Z?)?)?)?)?)?)?", value - ): - if match.group(6): - precision, pattern = "SECOND", r"\1-\2-\3 \4:\5:\6" - elif match.group(5): - precision, pattern = "MINUTE", r"\1-\2-\3 \4:\5" - elif match.group(3): - precision, pattern = "HOUR", r"\1-\2-\3 \4" - elif match.group(2): - precision, pattern = "DAY", r"\1-\2-\3" - elif match.group(1): - precision, pattern = "MONTH", r"\1-\2" - else: - precision, pattern = "YEAR", r"\1" - return ( - precision, - re.sub( - r"^(\d{4})-?(\d{2})?-?(\d{2})?[ T]?(\d{2})?:?(\d{2})?:?(\d{2})?\.?(\d{6})?Z?$", - pattern, - value, - ), - ) - - raise InvalidQueryError(f"Cannot parse {value=}") - - -def apply_search_filters(column_mapping, stmt, search): - for query in search: - try: - column = column_mapping(query["parameter"]) - except KeyError as e: - raise InvalidQueryError(f"Unknown column {query['parameter']}") from e - - if isinstance(column.type, DateTime): - if "value" in query and isinstance(query["value"], str): - resolution, value = find_time_resolution(query["value"]) - if resolution: - column = DateTrunc(column, time_resolution=resolution) - query["value"] = value - - if query.get("values"): - resolutions, values = zip( - *map(find_time_resolution, query.get("values")) - ) - if len(set(resolutions)) != 1: - raise InvalidQueryError( - f"Cannot mix different time resolutions in {query=}" - ) - if resolution := resolutions[0]: - column = DateTrunc(column, time_resolution=resolution) - query["values"] = values - - if query["operator"] == "eq": - expr = column == query["value"] - elif query["operator"] == "neq": - expr = column != query["value"] - elif query["operator"] == "gt": - expr = column > query["value"] - elif query["operator"] == "lt": - expr = column < query["value"] - elif query["operator"] == "in": - expr = column.in_(query["values"]) - elif query["operator"] == "not in": - expr = column.notin_(query["values"]) - elif query["operator"] in "like": - expr = column.like(query["value"]) - elif query["operator"] in "ilike": - expr = column.ilike(query["value"]) - else: - raise InvalidQueryError(f"Unknown filter {query=}") - stmt = stmt.where(expr) - return stmt - - -def apply_sort_constraints(column_mapping, stmt, sorts): - sort_columns = [] - for sort in sorts or []: - try: - column = column_mapping(sort["parameter"]) - except KeyError as e: - raise InvalidQueryError( - f"Cannot sort by {sort['parameter']}: unknown column" - ) from e - sorted_column = None - if sort["direction"] == SortDirection.ASC: - sorted_column = column.asc() - elif sort["direction"] == SortDirection.DESC: - sorted_column = column.desc() - else: - raise InvalidQueryError(f"Unknown sort {sort['direction']=}") - sort_columns.append(sorted_column) - if sort_columns: - stmt = stmt.order_by(*sort_columns) - return stmt +from .base import ( + BaseSQLDB, + SQLDBUnavailableError, + apply_search_filters, + apply_sort_constraints, +) +from .functions import substract_date, utcnow +from .types import Column, DateNowColumn, EnumBackedBool, EnumColumn, NullColumn + +__all__ = ( + "utcnow", + "Column", + "NullColumn", + "DateNowColumn", + "BaseSQLDB", + "EnumBackedBool", + "EnumColumn", + "apply_search_filters", + "apply_sort_constraints", + "substract_date", + "SQLDBUnavailableError", +) diff --git a/diracx-db/src/diracx/db/sql/utils/base.py b/diracx-db/src/diracx/db/sql/utils/base.py new file mode 100644 index 00000000..86f3be71 --- /dev/null +++ b/diracx-db/src/diracx/db/sql/utils/base.py @@ -0,0 +1,316 @@ +from __future__ import annotations + +import contextlib +import logging +import os +import re +from abc import ABCMeta +from collections.abc import AsyncIterator +from contextvars import ContextVar +from datetime import datetime +from typing import Self, cast + +from pydantic import TypeAdapter +from sqlalchemy import DateTime, MetaData, select +from sqlalchemy.exc import OperationalError +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, create_async_engine + +from diracx.core.exceptions import InvalidQueryError +from diracx.core.extensions import select_from_extension +from diracx.core.models import SortDirection +from diracx.core.settings import SqlalchemyDsn +from diracx.db.exceptions import DBUnavailableError + +from .functions import date_trunc + +logger = logging.getLogger(__name__) + + +class SQLDBError(Exception): + pass + + +class SQLDBUnavailableError(DBUnavailableError, SQLDBError): + """Used whenever we encounter a problem with the B connection.""" + + +class BaseSQLDB(metaclass=ABCMeta): + """This should be the base class of all the SQL DiracX DBs. + + The details covered here should be handled automatically by the service and + task machinery of DiracX and this documentation exists for informational + purposes. + + The available databases are discovered by calling `BaseSQLDB.available_urls`. + This method returns a mapping of database names to connection URLs. The + available databases are determined by the `diracx.dbs.sql` entrypoint in the + `pyproject.toml` file and the connection URLs are taken from the environment + variables of the form `DIRACX_DB_URL_`. + + If extensions to DiracX are being used, there can be multiple implementations + of the same database. To list the available implementations use + `BaseSQLDB.available_implementations(db_name)`. The first entry in this list + will be the preferred implementation and it can be initialized by calling + it's `__init__` function with a URL perviously obtained from + `BaseSQLDB.available_urls`. + + To control the lifetime of the SQLAlchemy engine used for connecting to the + database, which includes the connection pool, the `BaseSQLDB.engine_context` + asynchronous context manager should be entered. When inside this context + manager, the engine can be accessed with `BaseSQLDB.engine`. + + Upon entering, the DB class can then be used as an asynchronous context + manager to enter transactions. If an exception is raised the transaction is + rolled back automatically. If the inner context exits peacefully, the + transaction is committed automatically. When inside this context manager, + the DB connection can be accessed with `BaseSQLDB.conn`. + + For example: + + ```python + db_name = ... + url = BaseSQLDB.available_urls()[db_name] + MyDBClass = BaseSQLDB.available_implementations(db_name)[0] + + db = MyDBClass(url) + async with db.engine_context: + async with db: + # Do something in the first transaction + # Commit will be called automatically + + async with db: + # This transaction will be rolled back due to the exception + raise Exception(...) + ``` + """ + + # engine: AsyncEngine + # TODO: Make metadata an abstract property + metadata: MetaData + + def __init__(self, db_url: str) -> None: + # We use a ContextVar to make sure that self._conn + # is specific to each context, and avoid parallel + # route executions to overlap + self._conn: ContextVar[AsyncConnection | None] = ContextVar( + "_conn", default=None + ) + self._db_url = db_url + self._engine: AsyncEngine | None = None + + @classmethod + def available_implementations(cls, db_name: str) -> list[type["BaseSQLDB"]]: + """Return the available implementations of the DB in reverse priority order.""" + db_classes: list[type[BaseSQLDB]] = [ + entry_point.load() + for entry_point in select_from_extension( + group="diracx.db.sql", name=db_name + ) + ] + if not db_classes: + raise NotImplementedError(f"Could not find any matches for {db_name=}") + return db_classes + + @classmethod + def available_urls(cls) -> dict[str, str]: + """Return a dict of available database urls. + + The list of available URLs is determined by environment variables + prefixed with ``DIRACX_DB_URL_{DB_NAME}``. + """ + db_urls: dict[str, str] = {} + for entry_point in select_from_extension(group="diracx.db.sql"): + db_name = entry_point.name + var_name = f"DIRACX_DB_URL_{entry_point.name.upper()}" + if var_name in os.environ: + try: + db_url = os.environ[var_name] + if db_url == "sqlite+aiosqlite:///:memory:": + db_urls[db_name] = db_url + else: + db_urls[db_name] = str( + TypeAdapter(SqlalchemyDsn).validate_python(db_url) + ) + except Exception: + logger.error("Error loading URL for %s", db_name) + raise + return db_urls + + @classmethod + def transaction(cls) -> Self: + raise NotImplementedError("This should never be called") + + @property + def engine(self) -> AsyncEngine: + """The engine to use for database operations. + + It is normally not necessary to use the engine directly, unless you are + doing something special, like writing a test fixture that gives you a db. + + Requires that the engine_context has been entered. + """ + assert self._engine is not None, "engine_context must be entered" + return self._engine + + @contextlib.asynccontextmanager + async def engine_context(self) -> AsyncIterator[None]: + """Context manage to manage the engine lifecycle. + + This is called once at the application startup (see ``lifetime_functions``). + """ + assert self._engine is None, "engine_context cannot be nested" + + # Set the pool_recycle to 30mn + # That should prevent the problem of MySQL expiring connection + # after 60mn by default + engine = create_async_engine(self._db_url, pool_recycle=60 * 30) + self._engine = engine + try: + yield + finally: + self._engine = None + await engine.dispose() + + @property + def conn(self) -> AsyncConnection: + if self._conn.get() is None: + raise RuntimeError(f"{self.__class__} was used before entering") + return cast(AsyncConnection, self._conn.get()) + + async def __aenter__(self) -> Self: + """Create a connection. + + This is called by the Dependency mechanism (see ``db_transaction``), + It will create a new connection/transaction for each route call. + """ + assert self._conn.get() is None, "BaseSQLDB context cannot be nested" + try: + self._conn.set(await self.engine.connect().__aenter__()) + except Exception as e: + raise SQLDBUnavailableError( + f"Cannot connect to {self.__class__.__name__}" + ) from e + + return self + + async def __aexit__(self, exc_type, exc, tb): + """This is called when exiting a route. + + If there was no exception, the changes in the DB are committed. + Otherwise, they are rolled back. + """ + if exc_type is None: + await self._conn.get().commit() + await self._conn.get().__aexit__(exc_type, exc, tb) + self._conn.set(None) + + async def ping(self): + """Check whether the connection to the DB is still working. + + We could enable the ``pre_ping`` in the engine, but this would be ran at + every query. + """ + try: + await self.conn.scalar(select(1)) + except OperationalError as e: + raise SQLDBUnavailableError("Cannot ping the DB") from e + + +def find_time_resolution(value): + if isinstance(value, datetime): + return None, value + if match := re.fullmatch( + r"\d{4}(-\d{2}(-\d{2}(([ T])\d{2}(:\d{2}(:\d{2}(\.\d{6}Z?)?)?)?)?)?)?", value + ): + if match.group(6): + precision, pattern = "SECOND", r"\1-\2-\3 \4:\5:\6" + elif match.group(5): + precision, pattern = "MINUTE", r"\1-\2-\3 \4:\5" + elif match.group(3): + precision, pattern = "HOUR", r"\1-\2-\3 \4" + elif match.group(2): + precision, pattern = "DAY", r"\1-\2-\3" + elif match.group(1): + precision, pattern = "MONTH", r"\1-\2" + else: + precision, pattern = "YEAR", r"\1" + return ( + precision, + re.sub( + r"^(\d{4})-?(\d{2})?-?(\d{2})?[ T]?(\d{2})?:?(\d{2})?:?(\d{2})?\.?(\d{6})?Z?$", + pattern, + value, + ), + ) + + raise InvalidQueryError(f"Cannot parse {value=}") + + +def apply_search_filters(column_mapping, stmt, search): + for query in search: + try: + column = column_mapping(query["parameter"]) + except KeyError as e: + raise InvalidQueryError(f"Unknown column {query['parameter']}") from e + + if isinstance(column.type, DateTime): + if "value" in query and isinstance(query["value"], str): + resolution, value = find_time_resolution(query["value"]) + if resolution: + column = date_trunc(column, time_resolution=resolution) + query["value"] = value + + if query.get("values"): + resolutions, values = zip( + *map(find_time_resolution, query.get("values")) + ) + if len(set(resolutions)) != 1: + raise InvalidQueryError( + f"Cannot mix different time resolutions in {query=}" + ) + if resolution := resolutions[0]: + column = date_trunc(column, time_resolution=resolution) + query["values"] = values + + if query["operator"] == "eq": + expr = column == query["value"] + elif query["operator"] == "neq": + expr = column != query["value"] + elif query["operator"] == "gt": + expr = column > query["value"] + elif query["operator"] == "lt": + expr = column < query["value"] + elif query["operator"] == "in": + expr = column.in_(query["values"]) + elif query["operator"] == "not in": + expr = column.notin_(query["values"]) + elif query["operator"] in "like": + expr = column.like(query["value"]) + elif query["operator"] in "ilike": + expr = column.ilike(query["value"]) + else: + raise InvalidQueryError(f"Unknown filter {query=}") + stmt = stmt.where(expr) + return stmt + + +def apply_sort_constraints(column_mapping, stmt, sorts): + sort_columns = [] + for sort in sorts or []: + try: + column = column_mapping(sort["parameter"]) + except KeyError as e: + raise InvalidQueryError( + f"Cannot sort by {sort['parameter']}: unknown column" + ) from e + sorted_column = None + if sort["direction"] == SortDirection.ASC: + sorted_column = column.asc() + elif sort["direction"] == SortDirection.DESC: + sorted_column = column.desc() + else: + raise InvalidQueryError(f"Unknown sort {sort['direction']=}") + sort_columns.append(sorted_column) + if sort_columns: + stmt = stmt.order_by(*sort_columns) + return stmt diff --git a/diracx-db/src/diracx/db/sql/utils/functions.py b/diracx-db/src/diracx/db/sql/utils/functions.py new file mode 100644 index 00000000..17e1d155 --- /dev/null +++ b/diracx-db/src/diracx/db/sql/utils/functions.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +from datetime import datetime, timedelta, timezone +from typing import TYPE_CHECKING + +from sqlalchemy import DateTime, func +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.sql import expression + +if TYPE_CHECKING: + from sqlalchemy.types import TypeEngine + + +class utcnow(expression.FunctionElement): # noqa: N801 + type: TypeEngine = DateTime() + inherit_cache: bool = True + + +@compiles(utcnow, "postgresql") +def pg_utcnow(element, compiler, **kw) -> str: + return "TIMEZONE('utc', CURRENT_TIMESTAMP)" + + +@compiles(utcnow, "mssql") +def ms_utcnow(element, compiler, **kw) -> str: + return "GETUTCDATE()" + + +@compiles(utcnow, "mysql") +def mysql_utcnow(element, compiler, **kw) -> str: + return "(UTC_TIMESTAMP)" + + +@compiles(utcnow, "sqlite") +def sqlite_utcnow(element, compiler, **kw) -> str: + return "DATETIME('now')" + + +class date_trunc(expression.FunctionElement): # noqa: N801 + """Sqlalchemy function to truncate a date to a given resolution. + + Primarily used to be able to query for a specific resolution of a date e.g. + + select * from table where date_trunc('day', date_column) = '2021-01-01' + select * from table where date_trunc('year', date_column) = '2021' + select * from table where date_trunc('minute', date_column) = '2021-01-01 12:00' + """ + + type = DateTime() + inherit_cache = True + + def __init__(self, *args, time_resolution, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._time_resolution = time_resolution + + +@compiles(date_trunc, "postgresql") +def pg_date_trunc(element, compiler, **kw): + res = { + "SECOND": "second", + "MINUTE": "minute", + "HOUR": "hour", + "DAY": "day", + "MONTH": "month", + "YEAR": "year", + }[element._time_resolution] + return f"date_trunc('{res}', {compiler.process(element.clauses)})" + + +@compiles(date_trunc, "mysql") +def mysql_date_trunc(element, compiler, **kw): + pattern = { + "SECOND": "%Y-%m-%d %H:%i:%S", + "MINUTE": "%Y-%m-%d %H:%i", + "HOUR": "%Y-%m-%d %H", + "DAY": "%Y-%m-%d", + "MONTH": "%Y-%m", + "YEAR": "%Y", + }[element._time_resolution] + + (dt_col,) = list(element.clauses) + return compiler.process(func.date_format(dt_col, pattern)) + + +@compiles(date_trunc, "sqlite") +def sqlite_date_trunc(element, compiler, **kw): + pattern = { + "SECOND": "%Y-%m-%d %H:%M:%S", + "MINUTE": "%Y-%m-%d %H:%M", + "HOUR": "%Y-%m-%d %H", + "DAY": "%Y-%m-%d", + "MONTH": "%Y-%m", + "YEAR": "%Y", + }[element._time_resolution] + (dt_col,) = list(element.clauses) + return compiler.process( + func.strftime( + pattern, + dt_col, + ) + ) + + +def substract_date(**kwargs: float) -> datetime: + return datetime.now(tz=timezone.utc) - timedelta(**kwargs) diff --git a/diracx-db/src/diracx/db/sql/utils/job.py b/diracx-db/src/diracx/db/sql/utils/job.py index 87763d45..3ffc587a 100644 --- a/diracx-db/src/diracx/db/sql/utils/job.py +++ b/diracx-db/src/diracx/db/sql/utils/job.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import asyncio from collections import defaultdict from copy import deepcopy diff --git a/diracx-db/src/diracx/db/sql/utils/types.py b/diracx-db/src/diracx/db/sql/utils/types.py new file mode 100644 index 00000000..58e56994 --- /dev/null +++ b/diracx-db/src/diracx/db/sql/utils/types.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from functools import partial + +import sqlalchemy.types as types +from sqlalchemy import Column as RawColumn +from sqlalchemy import DateTime, Enum + +from .functions import utcnow + +Column: partial[RawColumn] = partial(RawColumn, nullable=False) +NullColumn: partial[RawColumn] = partial(RawColumn, nullable=True) +DateNowColumn = partial(Column, type_=DateTime(timezone=True), server_default=utcnow()) + + +def EnumColumn(name, enum_type, **kwargs): # noqa: N802 + return Column(name, Enum(enum_type, native_enum=False, length=16), **kwargs) + + +class EnumBackedBool(types.TypeDecorator): + """Maps a ``EnumBackedBool()`` column to True/False in Python.""" + + impl = types.Enum + cache_ok: bool = True + + def __init__(self) -> None: + super().__init__("True", "False") + + def process_bind_param(self, value, dialect) -> str: + if value is True: + return "True" + elif value is False: + return "False" + else: + raise NotImplementedError(value, dialect) + + def process_result_value(self, value, dialect) -> bool: + if value == "True": + return True + elif value == "False": + return False + else: + raise NotImplementedError(f"Unknown {value=}") diff --git a/diracx-db/tests/jobs/test_job_logging_db.py b/diracx-db/tests/jobs/test_job_logging_db.py index e720cec8..0e2f815f 100644 --- a/diracx-db/tests/jobs/test_job_logging_db.py +++ b/diracx-db/tests/jobs/test_job_logging_db.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from datetime import datetime, timezone import pytest diff --git a/diracx-routers/src/diracx/routers/__init__.py b/diracx-routers/src/diracx/routers/__init__.py index d17fbd8f..17785560 100644 --- a/diracx-routers/src/diracx/routers/__init__.py +++ b/diracx-routers/src/diracx/routers/__init__.py @@ -7,512 +7,6 @@ from __future__ import annotations -import inspect -import logging -import os -from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable, Sequence -from functools import partial -from http import HTTPStatus -from importlib.metadata import EntryPoint, EntryPoints, entry_points -from logging import Formatter, StreamHandler -from typing import ( - Any, - TypeVar, - cast, -) +from .factory import DIRACX_MIN_CLIENT_VERSION, create_app, create_app_inner -import dotenv -from cachetools import TTLCache -from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, status -from fastapi.dependencies.models import Dependant -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import JSONResponse, Response -from fastapi.routing import APIRoute -from packaging.version import InvalidVersion, parse -from pydantic import TypeAdapter -from starlette.middleware.base import BaseHTTPMiddleware -from uvicorn.logging import AccessFormatter, DefaultFormatter - -from diracx.core.config import ConfigSource -from diracx.core.exceptions import DiracError, DiracHttpResponseError -from diracx.core.extensions import select_from_extension -from diracx.core.settings import ServiceSettingsBase -from diracx.core.utils import dotenv_files_from_environment -from diracx.db.exceptions import DBUnavailableError -from diracx.db.os.utils import BaseOSDB -from diracx.db.sql.utils import BaseSQLDB -from diracx.routers.access_policies import BaseAccessPolicy, check_permissions - -from .fastapi_classes import DiracFastAPI, DiracxRouter -from .otel import instrument_otel -from .utils.users import verify_dirac_access_token - -T = TypeVar("T") -T2 = TypeVar("T2", bound=BaseSQLDB | BaseOSDB) - - -logger = logging.getLogger(__name__) - - -DIRACX_MIN_CLIENT_VERSION = "0.0.1a1" - -###########################################3 - - -def configure_logger(): - """Configure the console logger. - - Access logs come from uvicorn, which configure its logger in a certain way - (https://github.com/tiangolo/fastapi/discussions/7457) - This method adds a timestamp to the uvicorn output, - and define a console handler for all the diracx loggers - We cannot configure just the root handler, as uvicorn - attaches handler to the `uvicorn` logger - """ - diracx_handler = StreamHandler() - diracx_handler.setFormatter(Formatter("%(asctime)s - %(levelname)s - %(message)s")) - logging.getLogger("diracx").addHandler(diracx_handler) - logging.getLogger("diracx").setLevel("INFO") - - # Recreate the formatters for the uvicorn loggers adding the timestamp - uvicorn_access_logger = logging.getLogger("uvicorn.access") - try: - previous_fmt = uvicorn_access_logger.handlers[0].formatter._fmt - new_format = f"%(asctime)s - {previous_fmt}" - uvicorn_access_logger.handlers[0].setFormatter(AccessFormatter(new_format)) - # There may not be any handler defined, like in the CI - except IndexError: - pass - - uvicorn_logger = logging.getLogger("uvicorn") - try: - previous_fmt = uvicorn_logger.handlers[0].formatter._fmt - new_format = f"%(asctime)s - {previous_fmt}" - uvicorn_logger.handlers[0].setFormatter(DefaultFormatter(new_format)) - # There may not be any handler defined, like in the CI - except IndexError: - pass - - -# Rules: -# All routes must have tags (needed for auto gen of client) -# Form headers must have a description (autogen) -# methods name should follow the generate_unique_id_function pattern -# All routes should have a policy mechanism - - -def create_app_inner( - *, - enabled_systems: set[str], - all_service_settings: Iterable[ServiceSettingsBase], - database_urls: dict[str, str], - os_database_conn_kwargs: dict[str, Any], - config_source: ConfigSource, - all_access_policies: dict[str, Sequence[BaseAccessPolicy]], -) -> DiracFastAPI: - """This method does the heavy lifting work of putting all the pieces together. - - When starting the application normaly, this method is called by create_app, - and the values of the parameters are taken from environment variables or - entrypoints. - - When running tests, the parameters are mocks or test settings. - - We rely on the dependency_override mechanism to implement - the actual behavior we are interested in for settings, DBs or policy. - This allows an extension to override any of these components - - - :param enabled_system: - this contains the name of all the routers we have to load - :param all_service_settings: - list of instance of each Settings type required - :param database_urls: - dict . When testing, sqlite urls are used - :param os_database_conn_kwargs: - containing all the parameters the OpenSearch client takes - :param config_source: - Source of the configuration to use - :param all_access_policies: - - - - """ - app = DiracFastAPI() - - # Find which settings classes are available and add them to dependency_overrides - # We use a single instance of each Setting classes for performance reasons, - # since it avoids recreating a pydantic model every time - # We add the Settings lifetime_function to the application lifetime_function, - # Please see ServiceSettingsBase for more details - - available_settings_classes: set[type[ServiceSettingsBase]] = set() - - for service_settings in all_service_settings: - cls = type(service_settings) - assert cls not in available_settings_classes - available_settings_classes.add(cls) - app.lifetime_functions.append(service_settings.lifetime_function) - # We always return the same setting instance for perf reasons - app.dependency_overrides[cls.create] = partial(lambda x: x, service_settings) - - # Override the ConfigSource.create by the actual reading of the config - app.dependency_overrides[ConfigSource.create] = config_source.read_config - - all_access_policies_used = {} - - for access_policy_name, access_policy_classes in all_access_policies.items(): - - # The first AccessPolicy is the highest priority one - access_policy_used = access_policy_classes[0].policy - all_access_policies_used[access_policy_name] = access_policy_classes[0] - - # app.lifetime_functions.append(access_policy.lifetime_function) - # Add overrides for all the AccessPolicy classes, including those from extensions - # This means vanilla DiracX routers get an instance of the extension's AccessPolicy - for access_policy_class in access_policy_classes: - # Here we do not check that access_policy_class.check is - # not already in the dependency_overrides becaue the same - # policy could be used for multiple purpose - # (e.g. open access) - # assert access_policy_class.check not in app.dependency_overrides - app.dependency_overrides[access_policy_class.check] = partial( - check_permissions, access_policy_used, access_policy_name - ) - - app.dependency_overrides[BaseAccessPolicy.all_used_access_policies] = ( - lambda: all_access_policies_used - ) - - fail_startup = True - # Add the SQL DBs to the application - available_sql_db_classes: set[type[BaseSQLDB]] = set() - - for db_name, db_url in database_urls.items(): - - try: - sql_db_classes = BaseSQLDB.available_implementations(db_name) - - # The first DB is the highest priority one - sql_db = sql_db_classes[0](db_url=db_url) - - app.lifetime_functions.append(sql_db.engine_context) - # Add overrides for all the DB classes, including those from extensions - # This means vanilla DiracX routers get an instance of the extension's DB - for sql_db_class in sql_db_classes: - assert sql_db_class.transaction not in app.dependency_overrides - available_sql_db_classes.add(sql_db_class) - - app.dependency_overrides[sql_db_class.transaction] = partial( - db_transaction, sql_db - ) - - # At least one DB works, so we do not fail the startup - fail_startup = False - except Exception: - logger.exception("Failed to initialize DB %s", db_name) - - if fail_startup: - raise Exception("No SQL database could be initialized, aborting") - - # Add the OpenSearch DBs to the application - available_os_db_classes: set[type[BaseOSDB]] = set() - for db_name, connection_kwargs in os_database_conn_kwargs.items(): - os_db_classes = BaseOSDB.available_implementations(db_name) - # The first DB is the highest priority one - os_db = os_db_classes[0](connection_kwargs=connection_kwargs) - app.lifetime_functions.append(os_db.client_context) - # Add overrides for all the DB classes, including those from extensions - # This means vanilla DiracX routers get an instance of the extension's DB - for os_db_class in os_db_classes: - assert os_db_class.session not in app.dependency_overrides - available_os_db_classes.add(os_db_class) - app.dependency_overrides[os_db_class.session] = partial( - db_transaction, os_db - ) - - # Load the requested routers - routers: dict[str, APIRouter] = {} - # The enabled systems must be sorted to ensure the openapi.json is deterministic - # Without this AutoREST generates different client sources for each ordering - for system_name in sorted(enabled_systems): - assert system_name not in routers - for entry_point in select_from_extension( - group="diracx.services", name=system_name - ): - routers[system_name] = entry_point.load() - break - else: - raise NotImplementedError(f"Could not find {system_name=}") - - # Add routers ensuring that all the required settings are available - for system_name, router in routers.items(): - # Ensure required settings are available - for cls in find_dependents(router, ServiceSettingsBase): - if cls not in available_settings_classes: - raise NotImplementedError( - f"Cannot enable {system_name=} as it requires {cls=}" - ) - - # Ensure required DBs are available - missing_sql_dbs = ( - set(find_dependents(router, BaseSQLDB)) - available_sql_db_classes - ) - - if missing_sql_dbs: - raise NotImplementedError( - f"Cannot enable {system_name=} as it requires {missing_sql_dbs=}" - ) - missing_os_dbs = ( - set(find_dependents(router, BaseOSDB)) # type: ignore[type-abstract] - - available_os_db_classes - ) - if missing_os_dbs: - raise NotImplementedError( - f"Cannot enable {system_name=} as it requires {missing_os_dbs=}" - ) - - # Add the router to the application - dependencies = [] - if isinstance(router, DiracxRouter) and router.diracx_require_auth: - dependencies.append(Depends(verify_dirac_access_token)) - # Most routers are mounted under /api/ - path_root = getattr(router, "diracx_path_root", "/api") - app.include_router( - router, - prefix=f"{path_root}/{system_name}", - tags=[system_name], - dependencies=dependencies, - ) - - # Add exception handlers - # We need to cast because callables are contravariant and we define our exception handlers - # with a subclass of Exception (https://mypy.readthedocs.io/en/latest/generics.html#variance-of-generic-types) - handler_signature = Callable[[Request, Exception], Response | Awaitable[Response]] - app.add_exception_handler(DiracError, cast(handler_signature, dirac_error_handler)) - app.add_exception_handler( - DiracHttpResponseError, cast(handler_signature, http_response_handler) - ) - app.add_exception_handler( - DBUnavailableError, cast(handler_signature, route_unavailable_error_hander) - ) - - # TODO: remove the CORSMiddleware once we figure out how to launch - # diracx and diracx-web under the same origin - origins = [ - "http://localhost:8000", - ] - - app.add_middleware(ClientMinVersionCheckMiddleware) - - app.add_middleware( - CORSMiddleware, - allow_origins=origins, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], - ) - - configure_logger() - instrument_otel(app) - - return app - - -def create_app() -> DiracFastAPI: - """Load settings from the environment and create the application object. - - The configuration may be placed in .env files pointed to by - environment variables DIRACX_SERVICE_DOTENV. - They can be followed by "_X" where X is a number, and the order - is respected. - - We then loop over all the diracx.services definitions. - A specific route can be disabled with an environment variable - DIRACX_SERVICE__ENABLED=false - For each of the enabled route, we inspect which Setting classes - are needed. - - We attempt to load each setting classes to make sure that the - settings are correctly defined. - """ - for env_file in dotenv_files_from_environment("DIRACX_SERVICE_DOTENV"): - logger.debug("Loading dotenv file: %s", env_file) - if not dotenv.load_dotenv(env_file): - raise NotImplementedError(f"Could not load dotenv file {env_file}") - - # Load all available routers - enabled_systems = set() - settings_classes = set() - for entry_point in select_from_extension(group="diracx.services"): - env_var = f"DIRACX_SERVICE_{entry_point.name.upper()}_ENABLED" - enabled = TypeAdapter(bool).validate_json(os.environ.get(env_var, "true")) - logger.debug("Found service %r: enabled=%s", entry_point, enabled) - if not enabled: - continue - router: APIRouter = entry_point.load() - enabled_systems.add(entry_point.name) - dependencies = set(find_dependents(router, ServiceSettingsBase)) - logger.debug("Found dependencies for %r: enabled=%s", entry_point, dependencies) - settings_classes |= dependencies - - # Load settings classes required by the routers - all_service_settings = [settings_class() for settings_class in settings_classes] - - # Find all the access policies - - available_access_policy_names = { - entry_point.name - for entry_point in select_from_extension(group="diracx.access_policies") - } - - all_access_policies = {} - - for access_policy_name in available_access_policy_names: - - access_policy_classes = BaseAccessPolicy.available_implementations( - access_policy_name - ) - all_access_policies[access_policy_name] = access_policy_classes - - return create_app_inner( - enabled_systems=enabled_systems, - all_service_settings=all_service_settings, - database_urls=BaseSQLDB.available_urls(), - os_database_conn_kwargs=BaseOSDB.available_urls(), - config_source=ConfigSource.create(), - all_access_policies=all_access_policies, - ) - - -def dirac_error_handler(request: Request, exc: DiracError) -> Response: - return JSONResponse( - status_code=exc.http_status_code, - content={"detail": exc.detail}, - headers=exc.http_headers, - ) - - -def http_response_handler(request: Request, exc: DiracHttpResponseError) -> Response: - return JSONResponse(status_code=exc.status_code, content=exc.data) - - -def route_unavailable_error_hander(request: Request, exc: DBUnavailableError): - return JSONResponse( - status_code=status.HTTP_503_SERVICE_UNAVAILABLE, - headers={"Retry-After": "10"}, - content={"detail": str(exc.args)}, - ) - - -def find_dependents( - obj: APIRouter | Iterable[Dependant], cls: type[T] -) -> Iterable[type[T]]: - if isinstance(obj, APIRouter): - # TODO: Support dependencies of the router itself - # yield from find_dependents(obj.dependencies, cls) - for route in obj.routes: - if isinstance(route, APIRoute): - yield from find_dependents(route.dependant.dependencies, cls) - return - - for dependency in obj: - bound_class = getattr(dependency.call, "__self__", None) - if inspect.isclass(bound_class) and issubclass(bound_class, cls): - yield bound_class - yield from find_dependents(dependency.dependencies, cls) - - -_db_alive_cache: TTLCache = TTLCache(maxsize=1024, ttl=10) - - -async def is_db_unavailable(db: BaseSQLDB | BaseOSDB) -> str: - """Cache the result of pinging the DB - (exceptions are not cachable). - """ - if db not in _db_alive_cache: - try: - await db.ping() - _db_alive_cache[db] = "" - - except DBUnavailableError as e: - _db_alive_cache[db] = e.args[0] - - return _db_alive_cache[db] - - -async def db_transaction(db: T2) -> AsyncGenerator[T2]: - """Initiate a DB transaction.""" - # Entering the context already triggers a connection to the DB - # that may fail - async with db: - # Check whether the connection still works before executing the query - if reason := await is_db_unavailable(db): - raise DBUnavailableError(reason) - yield db - - -class ClientMinVersionCheckMiddleware(BaseHTTPMiddleware): - """Custom FastAPI middleware to verify that - the client has the required minimum version. - """ - - def __init__(self, app: FastAPI): - super().__init__(app) - self.min_client_version = get_min_client_version() - self.parsed_min_client_version = parse(self.min_client_version) - - async def dispatch(self, request: Request, call_next) -> Response: - client_version = request.headers.get("DiracX-Client-Version") - - try: - if client_version and self.is_version_too_old(client_version): - # When comes from Swagger or Web, there is no client version header. - # This is not managed here. - - raise HTTPException( - status_code=HTTPStatus.UPGRADE_REQUIRED, - detail=f"Client version ({client_version})" - f"not recent enough (>= {self.min_client_version})." - "Upgrade.", - ) - except HTTPException as exc: - # Return a JSONResponse because the HTTPException - # is not handled nicely in the middleware - logger.error("Error checking client version %s", client_version) - return JSONResponse( - status_code=exc.status_code, - content={"detail": exc.detail}, - ) - # If the version is not given - except Exception: # noqa: S110 - pass - - response = await call_next(request) - return response - - def is_version_too_old(self, client_version: str) -> bool | None: - """Verify that client version is ge than min.""" - try: - return parse(client_version) < self.parsed_min_client_version - except InvalidVersion as iv_exc: - raise HTTPException( - status_code=HTTPStatus.BAD_REQUEST, - detail=f"Invalid version string: '{client_version}'", - ) from iv_exc - - -def get_min_client_version(): - """Extracting min client version from entry_points and searching for extension.""" - matched_entry_points: EntryPoints = entry_points(group="diracx.min_client_version") - # Searching for an extension: - entry_points_dict: dict[str, EntryPoint] = { - ep.name: ep for ep in matched_entry_points - } - for ep_name, ep in entry_points_dict.items(): - if ep_name != "diracx": - return ep.load() - - # Taking diracx if no extension: - if "diracx" in entry_points_dict: - return entry_points_dict["diracx"].load() +__all__ = ("create_app", "create_app_inner", "DIRACX_MIN_CLIENT_VERSION") diff --git a/diracx-routers/src/diracx/routers/access_policies.py b/diracx-routers/src/diracx/routers/access_policies.py index 9276b05f..a2bf007b 100644 --- a/diracx-routers/src/diracx/routers/access_policies.py +++ b/diracx-routers/src/diracx/routers/access_policies.py @@ -2,7 +2,7 @@ We define a set of Policy classes (WMS, DFC, etc). They have a default implementation in diracx. -If an extension wants to change it, it can be overwriten in the entry point +If an extension wants to change it, it can be overwritten in the entry point diracx.access_policies Each route should either: @@ -30,11 +30,13 @@ from diracx.routers.dependencies import DevelopmentSettings from diracx.routers.utils.users import AuthorizedUserInfo, verify_dirac_access_token -# FastAPI bug: -# We normally would use `from __future__ import annotations` -# but a bug in FastAPI prevents us from doing so -# https://github.com/tiangolo/fastapi/pull/11355 -# Until it is merged, we can work around it by using strings. +if "annotations" in globals(): + raise NotImplementedError( + "FastAPI bug: We normally would use `from __future__ import annotations` " + "but a bug in FastAPI prevents us from doing so " + "https://github.com/tiangolo/fastapi/pull/11355 " + "Until it is merged, we can work around it by using strings." + ) class BaseAccessPolicy(metaclass=ABCMeta): diff --git a/diracx-routers/src/diracx/routers/auth/authorize_code_flow.py b/diracx-routers/src/diracx/routers/auth/authorize_code_flow.py index f57aaa43..b52172ed 100644 --- a/diracx-routers/src/diracx/routers/auth/authorize_code_flow.py +++ b/diracx-routers/src/diracx/routers/auth/authorize_code_flow.py @@ -32,6 +32,8 @@ * The client can then use the access token to access the DIRAC services. """ +from __future__ import annotations + from typing import Literal from fastapi import ( diff --git a/diracx-routers/src/diracx/routers/auth/device_flow.py b/diracx-routers/src/diracx/routers/auth/device_flow.py index 3fa47bf6..886f1cf2 100644 --- a/diracx-routers/src/diracx/routers/auth/device_flow.py +++ b/diracx-routers/src/diracx/routers/auth/device_flow.py @@ -52,6 +52,8 @@ * The client can then use the access token to access the DIRAC services. """ +from __future__ import annotations + from fastapi import ( HTTPException, Request, diff --git a/diracx-routers/src/diracx/routers/auth/management.py b/diracx-routers/src/diracx/routers/auth/management.py index 7bd7c1b9..04b52f6c 100644 --- a/diracx-routers/src/diracx/routers/auth/management.py +++ b/diracx-routers/src/diracx/routers/auth/management.py @@ -4,6 +4,8 @@ to get information about the user's identity. """ +from __future__ import annotations + from typing import Annotated, Any from fastapi import ( diff --git a/diracx-routers/src/diracx/routers/auth/token.py b/diracx-routers/src/diracx/routers/auth/token.py index d21416ea..d171d4ac 100644 --- a/diracx-routers/src/diracx/routers/auth/token.py +++ b/diracx-routers/src/diracx/routers/auth/token.py @@ -1,5 +1,7 @@ """Token endpoint implementation.""" +from __future__ import annotations + import base64 import hashlib import os diff --git a/diracx-routers/src/diracx/routers/auth/utils.py b/diracx-routers/src/diracx/routers/auth/utils.py index 7ca8b523..4b1ca649 100644 --- a/diracx-routers/src/diracx/routers/auth/utils.py +++ b/diracx-routers/src/diracx/routers/auth/utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import hashlib import json diff --git a/diracx-routers/src/diracx/routers/factory.py b/diracx-routers/src/diracx/routers/factory.py new file mode 100644 index 00000000..c48aa2c7 --- /dev/null +++ b/diracx-routers/src/diracx/routers/factory.py @@ -0,0 +1,515 @@ +"""Logic for creating and configuring the FastAPI application.""" + +from __future__ import annotations + +import inspect +import logging +import os +from collections.abc import AsyncGenerator, Awaitable, Callable, Iterable, Sequence +from functools import partial +from http import HTTPStatus +from importlib.metadata import EntryPoint, EntryPoints, entry_points +from logging import Formatter, StreamHandler +from typing import ( + Any, + TypeVar, + cast, +) + +import dotenv +from cachetools import TTLCache +from fastapi import APIRouter, Depends, FastAPI, HTTPException, Request, status +from fastapi.dependencies.models import Dependant +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import JSONResponse, Response +from fastapi.routing import APIRoute +from packaging.version import InvalidVersion, parse +from pydantic import TypeAdapter +from starlette.middleware.base import BaseHTTPMiddleware +from uvicorn.logging import AccessFormatter, DefaultFormatter + +from diracx.core.config import ConfigSource +from diracx.core.exceptions import DiracError, DiracHttpResponseError +from diracx.core.extensions import select_from_extension +from diracx.core.settings import ServiceSettingsBase +from diracx.core.utils import dotenv_files_from_environment +from diracx.db.exceptions import DBUnavailableError +from diracx.db.os.utils import BaseOSDB +from diracx.db.sql.utils import BaseSQLDB +from diracx.routers.access_policies import BaseAccessPolicy, check_permissions + +from .fastapi_classes import DiracFastAPI, DiracxRouter +from .otel import instrument_otel +from .utils.users import verify_dirac_access_token + +T = TypeVar("T") +T2 = TypeVar("T2", bound=BaseSQLDB | BaseOSDB) + + +logger = logging.getLogger(__name__) + + +DIRACX_MIN_CLIENT_VERSION = "0.0.1a1" + +###########################################3 + + +def configure_logger(): + """Configure the console logger. + + Access logs come from uvicorn, which configure its logger in a certain way + (https://github.com/tiangolo/fastapi/discussions/7457) + This method adds a timestamp to the uvicorn output, + and define a console handler for all the diracx loggers + We cannot configure just the root handler, as uvicorn + attaches handler to the `uvicorn` logger + """ + diracx_handler = StreamHandler() + diracx_handler.setFormatter(Formatter("%(asctime)s - %(levelname)s - %(message)s")) + logging.getLogger("diracx").addHandler(diracx_handler) + logging.getLogger("diracx").setLevel("INFO") + + # Recreate the formatters for the uvicorn loggers adding the timestamp + uvicorn_access_logger = logging.getLogger("uvicorn.access") + try: + previous_fmt = uvicorn_access_logger.handlers[0].formatter._fmt + new_format = f"%(asctime)s - {previous_fmt}" + uvicorn_access_logger.handlers[0].setFormatter(AccessFormatter(new_format)) + # There may not be any handler defined, like in the CI + except IndexError: + pass + + uvicorn_logger = logging.getLogger("uvicorn") + try: + previous_fmt = uvicorn_logger.handlers[0].formatter._fmt + new_format = f"%(asctime)s - {previous_fmt}" + uvicorn_logger.handlers[0].setFormatter(DefaultFormatter(new_format)) + # There may not be any handler defined, like in the CI + except IndexError: + pass + + +# Rules: +# All routes must have tags (needed for auto gen of client) +# Form headers must have a description (autogen) +# methods name should follow the generate_unique_id_function pattern +# All routes should have a policy mechanism + + +def create_app_inner( + *, + enabled_systems: set[str], + all_service_settings: Iterable[ServiceSettingsBase], + database_urls: dict[str, str], + os_database_conn_kwargs: dict[str, Any], + config_source: ConfigSource, + all_access_policies: dict[str, Sequence[BaseAccessPolicy]], +) -> DiracFastAPI: + """This method does the heavy lifting work of putting all the pieces together. + + When starting the application normaly, this method is called by create_app, + and the values of the parameters are taken from environment variables or + entrypoints. + + When running tests, the parameters are mocks or test settings. + + We rely on the dependency_override mechanism to implement + the actual behavior we are interested in for settings, DBs or policy. + This allows an extension to override any of these components + + + :param enabled_system: + this contains the name of all the routers we have to load + :param all_service_settings: + list of instance of each Settings type required + :param database_urls: + dict . When testing, sqlite urls are used + :param os_database_conn_kwargs: + containing all the parameters the OpenSearch client takes + :param config_source: + Source of the configuration to use + :param all_access_policies: + + + + """ + app = DiracFastAPI() + + # Find which settings classes are available and add them to dependency_overrides + # We use a single instance of each Setting classes for performance reasons, + # since it avoids recreating a pydantic model every time + # We add the Settings lifetime_function to the application lifetime_function, + # Please see ServiceSettingsBase for more details + + available_settings_classes: set[type[ServiceSettingsBase]] = set() + + for service_settings in all_service_settings: + cls = type(service_settings) + assert cls not in available_settings_classes + available_settings_classes.add(cls) + app.lifetime_functions.append(service_settings.lifetime_function) + # We always return the same setting instance for perf reasons + app.dependency_overrides[cls.create] = partial(lambda x: x, service_settings) + + # Override the ConfigSource.create by the actual reading of the config + app.dependency_overrides[ConfigSource.create] = config_source.read_config + + all_access_policies_used = {} + + for access_policy_name, access_policy_classes in all_access_policies.items(): + + # The first AccessPolicy is the highest priority one + access_policy_used = access_policy_classes[0].policy + all_access_policies_used[access_policy_name] = access_policy_classes[0] + + # app.lifetime_functions.append(access_policy.lifetime_function) + # Add overrides for all the AccessPolicy classes, including those from extensions + # This means vanilla DiracX routers get an instance of the extension's AccessPolicy + for access_policy_class in access_policy_classes: + # Here we do not check that access_policy_class.check is + # not already in the dependency_overrides becaue the same + # policy could be used for multiple purpose + # (e.g. open access) + # assert access_policy_class.check not in app.dependency_overrides + app.dependency_overrides[access_policy_class.check] = partial( + check_permissions, + policy=access_policy_used, + policy_name=access_policy_name, + ) + + app.dependency_overrides[BaseAccessPolicy.all_used_access_policies] = ( + lambda: all_access_policies_used + ) + + fail_startup = True + # Add the SQL DBs to the application + available_sql_db_classes: set[type[BaseSQLDB]] = set() + + for db_name, db_url in database_urls.items(): + + try: + sql_db_classes = BaseSQLDB.available_implementations(db_name) + + # The first DB is the highest priority one + sql_db = sql_db_classes[0](db_url=db_url) + + app.lifetime_functions.append(sql_db.engine_context) + # Add overrides for all the DB classes, including those from extensions + # This means vanilla DiracX routers get an instance of the extension's DB + for sql_db_class in sql_db_classes: + assert sql_db_class.transaction not in app.dependency_overrides + available_sql_db_classes.add(sql_db_class) + + app.dependency_overrides[sql_db_class.transaction] = partial( + db_transaction, sql_db + ) + + # At least one DB works, so we do not fail the startup + fail_startup = False + except Exception: + logger.exception("Failed to initialize DB %s", db_name) + + if fail_startup: + raise Exception("No SQL database could be initialized, aborting") + + # Add the OpenSearch DBs to the application + available_os_db_classes: set[type[BaseOSDB]] = set() + for db_name, connection_kwargs in os_database_conn_kwargs.items(): + os_db_classes = BaseOSDB.available_implementations(db_name) + # The first DB is the highest priority one + os_db = os_db_classes[0](connection_kwargs=connection_kwargs) + app.lifetime_functions.append(os_db.client_context) + # Add overrides for all the DB classes, including those from extensions + # This means vanilla DiracX routers get an instance of the extension's DB + for os_db_class in os_db_classes: + assert os_db_class.session not in app.dependency_overrides + available_os_db_classes.add(os_db_class) + app.dependency_overrides[os_db_class.session] = partial( + db_transaction, os_db + ) + + # Load the requested routers + routers: dict[str, APIRouter] = {} + # The enabled systems must be sorted to ensure the openapi.json is deterministic + # Without this AutoREST generates different client sources for each ordering + for system_name in sorted(enabled_systems): + assert system_name not in routers + for entry_point in select_from_extension( + group="diracx.services", name=system_name + ): + routers[system_name] = entry_point.load() + break + else: + raise NotImplementedError(f"Could not find {system_name=}") + + # Add routers ensuring that all the required settings are available + for system_name, router in routers.items(): + # Ensure required settings are available + for cls in find_dependents(router, ServiceSettingsBase): + if cls not in available_settings_classes: + raise NotImplementedError( + f"Cannot enable {system_name=} as it requires {cls=}" + ) + + # Ensure required DBs are available + missing_sql_dbs = ( + set(find_dependents(router, BaseSQLDB)) - available_sql_db_classes + ) + + if missing_sql_dbs: + raise NotImplementedError( + f"Cannot enable {system_name=} as it requires {missing_sql_dbs=}" + ) + missing_os_dbs = ( + set(find_dependents(router, BaseOSDB)) # type: ignore[type-abstract] + - available_os_db_classes + ) + if missing_os_dbs: + raise NotImplementedError( + f"Cannot enable {system_name=} as it requires {missing_os_dbs=}" + ) + + # Add the router to the application + dependencies = [] + if isinstance(router, DiracxRouter) and router.diracx_require_auth: + dependencies.append(Depends(verify_dirac_access_token)) + # Most routers are mounted under /api/ + path_root = getattr(router, "diracx_path_root", "/api") + app.include_router( + router, + prefix=f"{path_root}/{system_name}", + tags=[system_name], + dependencies=dependencies, + ) + + # Add exception handlers + # We need to cast because callables are contravariant and we define our exception handlers + # with a subclass of Exception (https://mypy.readthedocs.io/en/latest/generics.html#variance-of-generic-types) + handler_signature = Callable[[Request, Exception], Response | Awaitable[Response]] + app.add_exception_handler(DiracError, cast(handler_signature, dirac_error_handler)) + app.add_exception_handler( + DiracHttpResponseError, cast(handler_signature, http_response_handler) + ) + app.add_exception_handler( + DBUnavailableError, cast(handler_signature, route_unavailable_error_hander) + ) + + # TODO: remove the CORSMiddleware once we figure out how to launch + # diracx and diracx-web under the same origin + origins = [ + "http://localhost:8000", + ] + + app.add_middleware(ClientMinVersionCheckMiddleware) + + app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + configure_logger() + instrument_otel(app) + + return app + + +def create_app() -> DiracFastAPI: + """Load settings from the environment and create the application object. + + The configuration may be placed in .env files pointed to by + environment variables DIRACX_SERVICE_DOTENV. + They can be followed by "_X" where X is a number, and the order + is respected. + + We then loop over all the diracx.services definitions. + A specific route can be disabled with an environment variable + DIRACX_SERVICE__ENABLED=false + For each of the enabled route, we inspect which Setting classes + are needed. + + We attempt to load each setting classes to make sure that the + settings are correctly defined. + """ + for env_file in dotenv_files_from_environment("DIRACX_SERVICE_DOTENV"): + logger.debug("Loading dotenv file: %s", env_file) + if not dotenv.load_dotenv(env_file): + raise NotImplementedError(f"Could not load dotenv file {env_file}") + + # Load all available routers + enabled_systems = set() + settings_classes = set() + for entry_point in select_from_extension(group="diracx.services"): + env_var = f"DIRACX_SERVICE_{entry_point.name.upper()}_ENABLED" + enabled = TypeAdapter(bool).validate_json(os.environ.get(env_var, "true")) + logger.debug("Found service %r: enabled=%s", entry_point, enabled) + if not enabled: + continue + router: APIRouter = entry_point.load() + enabled_systems.add(entry_point.name) + dependencies = set(find_dependents(router, ServiceSettingsBase)) + logger.debug("Found dependencies for %r: enabled=%s", entry_point, dependencies) + settings_classes |= dependencies + + # Load settings classes required by the routers + all_service_settings = [settings_class() for settings_class in settings_classes] + + # Find all the access policies + + available_access_policy_names = { + entry_point.name + for entry_point in select_from_extension(group="diracx.access_policies") + } + + all_access_policies = {} + + for access_policy_name in available_access_policy_names: + + access_policy_classes = BaseAccessPolicy.available_implementations( + access_policy_name + ) + all_access_policies[access_policy_name] = access_policy_classes + + return create_app_inner( + enabled_systems=enabled_systems, + all_service_settings=all_service_settings, + database_urls=BaseSQLDB.available_urls(), + os_database_conn_kwargs=BaseOSDB.available_urls(), + config_source=ConfigSource.create(), + all_access_policies=all_access_policies, + ) + + +def dirac_error_handler(request: Request, exc: DiracError) -> Response: + return JSONResponse( + status_code=exc.http_status_code, + content={"detail": exc.detail}, + headers=exc.http_headers, + ) + + +def http_response_handler(request: Request, exc: DiracHttpResponseError) -> Response: + return JSONResponse(status_code=exc.status_code, content=exc.data) + + +def route_unavailable_error_hander(request: Request, exc: DBUnavailableError): + return JSONResponse( + status_code=status.HTTP_503_SERVICE_UNAVAILABLE, + headers={"Retry-After": "10"}, + content={"detail": str(exc.args)}, + ) + + +def find_dependents( + obj: APIRouter | Iterable[Dependant], cls: type[T] +) -> Iterable[type[T]]: + if isinstance(obj, APIRouter): + # TODO: Support dependencies of the router itself + # yield from find_dependents(obj.dependencies, cls) + for route in obj.routes: + if isinstance(route, APIRoute): + yield from find_dependents(route.dependant.dependencies, cls) + return + + for dependency in obj: + bound_class = getattr(dependency.call, "__self__", None) + if inspect.isclass(bound_class) and issubclass(bound_class, cls): + yield bound_class + yield from find_dependents(dependency.dependencies, cls) + + +_db_alive_cache: TTLCache = TTLCache(maxsize=1024, ttl=10) + + +async def is_db_unavailable(db: BaseSQLDB | BaseOSDB) -> str: + """Cache the result of pinging the DB + (exceptions are not cachable). + """ + if db not in _db_alive_cache: + try: + await db.ping() + _db_alive_cache[db] = "" + + except DBUnavailableError as e: + _db_alive_cache[db] = e.args[0] + + return _db_alive_cache[db] + + +async def db_transaction(db: T2) -> AsyncGenerator[T2]: + """Initiate a DB transaction.""" + # Entering the context already triggers a connection to the DB + # that may fail + async with db: + # Check whether the connection still works before executing the query + if reason := await is_db_unavailable(db): + raise DBUnavailableError(reason) + yield db + + +class ClientMinVersionCheckMiddleware(BaseHTTPMiddleware): + """Custom FastAPI middleware to verify that + the client has the required minimum version. + """ + + def __init__(self, app: FastAPI): + super().__init__(app) + self.min_client_version = get_min_client_version() + self.parsed_min_client_version = parse(self.min_client_version) + + async def dispatch(self, request: Request, call_next) -> Response: + client_version = request.headers.get("DiracX-Client-Version") + + try: + if client_version and self.is_version_too_old(client_version): + # When comes from Swagger or Web, there is no client version header. + # This is not managed here. + + raise HTTPException( + status_code=HTTPStatus.UPGRADE_REQUIRED, + detail=f"Client version ({client_version})" + f"not recent enough (>= {self.min_client_version})." + "Upgrade.", + ) + except HTTPException as exc: + # Return a JSONResponse because the HTTPException + # is not handled nicely in the middleware + logger.error("Error checking client version %s", client_version) + return JSONResponse( + status_code=exc.status_code, + content={"detail": exc.detail}, + ) + # If the version is not given + except Exception: # noqa: S110 + pass + + response = await call_next(request) + return response + + def is_version_too_old(self, client_version: str) -> bool | None: + """Verify that client version is ge than min.""" + try: + return parse(client_version) < self.parsed_min_client_version + except InvalidVersion as iv_exc: + raise HTTPException( + status_code=HTTPStatus.BAD_REQUEST, + detail=f"Invalid version string: '{client_version}'", + ) from iv_exc + + +def get_min_client_version(): + """Extracting min client version from entry_points and searching for extension.""" + matched_entry_points: EntryPoints = entry_points(group="diracx.min_client_version") + # Searching for an extension: + entry_points_dict: dict[str, EntryPoint] = { + ep.name: ep for ep in matched_entry_points + } + for ep_name, ep in entry_points_dict.items(): + if ep_name != "diracx": + return ep.load() + + # Taking diracx if no extension: + if "diracx" in entry_points_dict: + return entry_points_dict["diracx"].load() diff --git a/diracx-routers/src/diracx/routers/otel.py b/diracx-routers/src/diracx/routers/otel.py index cddf601e..2cb256d7 100644 --- a/diracx-routers/src/diracx/routers/otel.py +++ b/diracx-routers/src/diracx/routers/otel.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import logging import os diff --git a/diracx-routers/src/diracx/routers/utils/__init__.py b/diracx-routers/src/diracx/routers/utils/__init__.py index ac655cf1..e69de29b 100644 --- a/diracx-routers/src/diracx/routers/utils/__init__.py +++ b/diracx-routers/src/diracx/routers/utils/__init__.py @@ -1,8 +0,0 @@ -from asyncio import TaskGroup - - -class ForgivingTaskGroup(TaskGroup): - # Hacky way, check https://stackoverflow.com/questions/75250788/how-to-prevent-python3-11-taskgroup-from-canceling-all-the-tasks - # Basically e're using this because we want to wait for all tasks to finish, even if one of them raises an exception - def _abort(self): - return None diff --git a/diracx-routers/src/diracx/routers/utils/users.py b/diracx-routers/src/diracx/routers/utils/users.py index 6346509d..bfd5f9b5 100644 --- a/diracx-routers/src/diracx/routers/utils/users.py +++ b/diracx-routers/src/diracx/routers/utils/users.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import re from typing import Annotated, Any from uuid import UUID diff --git a/diracx-routers/tests/auth/test_legacy_exchange.py b/diracx-routers/tests/auth/test_legacy_exchange.py index 551e3133..eaf5599b 100644 --- a/diracx-routers/tests/auth/test_legacy_exchange.py +++ b/diracx-routers/tests/auth/test_legacy_exchange.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import hashlib import json diff --git a/diracx-routers/tests/auth/test_standard.py b/diracx-routers/tests/auth/test_standard.py index 8b2003fc..d883e7bf 100644 --- a/diracx-routers/tests/auth/test_standard.py +++ b/diracx-routers/tests/auth/test_standard.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import base64 import hashlib import secrets diff --git a/diracx-routers/tests/jobs/test_wms_access_policy.py b/diracx-routers/tests/jobs/test_wms_access_policy.py index 0746317c..6df6e675 100644 --- a/diracx-routers/tests/jobs/test_wms_access_policy.py +++ b/diracx-routers/tests/jobs/test_wms_access_policy.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from uuid import uuid4 import pytest diff --git a/diracx-routers/tests/test_config_manager.py b/diracx-routers/tests/test_config_manager.py index dbd69c7e..6175b22c 100644 --- a/diracx-routers/tests/test_config_manager.py +++ b/diracx-routers/tests/test_config_manager.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest from fastapi import status diff --git a/diracx-routers/tests/test_generic.py b/diracx-routers/tests/test_generic.py index a83ea439..659f0b5d 100644 --- a/diracx-routers/tests/test_generic.py +++ b/diracx-routers/tests/test_generic.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from http import HTTPStatus import pytest diff --git a/diracx-routers/tests/test_job_manager.py b/diracx-routers/tests/test_job_manager.py index 4a81d7e9..62ed1442 100644 --- a/diracx-routers/tests/test_job_manager.py +++ b/diracx-routers/tests/test_job_manager.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from datetime import datetime, timezone from http import HTTPStatus diff --git a/diracx-routers/tests/test_policy.py b/diracx-routers/tests/test_policy.py index 2ce91515..d9f80e75 100644 --- a/diracx-routers/tests/test_policy.py +++ b/diracx-routers/tests/test_policy.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import inspect from collections import defaultdict from typing import TYPE_CHECKING diff --git a/diracx-testing/src/diracx/testing/__init__.py b/diracx-testing/src/diracx/testing/__init__.py index 59ebca1d..b4ba7aa5 100644 --- a/diracx-testing/src/diracx/testing/__init__.py +++ b/diracx-testing/src/diracx/testing/__init__.py @@ -1,756 +1,47 @@ from __future__ import annotations -# TODO: this needs a lot of documentation, in particular what will matter for users -# are the enabled_dependencies markers -import asyncio -import contextlib -import os -import re -import ssl -import subprocess -import tomllib -from collections import defaultdict -from datetime import datetime, timedelta, timezone -from functools import partial -from html.parser import HTMLParser -from importlib.metadata import PackageNotFoundError, distribution, entry_points -from pathlib import Path -from typing import TYPE_CHECKING -from urllib.parse import parse_qs, urljoin, urlparse -from uuid import uuid4 - -import pytest -import requests - -if TYPE_CHECKING: - from diracx.core.settings import DevelopmentSettings - from diracx.routers.jobs.sandboxes import SandboxStoreSettings - from diracx.routers.utils.users import AuthorizedUserInfo, AuthSettings - - -# to get a string like this run: -# openssl rand -hex 32 -ALGORITHM = "HS256" -ISSUER = "http://lhcbdirac.cern.ch/" -AUDIENCE = "dirac" -ACCESS_TOKEN_EXPIRE_MINUTES = 30 - - -def pytest_addoption(parser): - parser.addoption( - "--regenerate-client", - action="store_true", - default=False, - help="Regenerate the AutoREST client", - ) - parser.addoption( - "--demo-dir", - type=Path, - default=None, - help="Path to a diracx-charts directory with the demo running", - ) - - -def pytest_collection_modifyitems(config, items): - """Disable the test_regenerate_client if not explicitly asked for.""" - if config.getoption("--regenerate-client"): - # --regenerate-client given in cli: allow client re-generation - return - skip_regen = pytest.mark.skip(reason="need --regenerate-client option to run") - for item in items: - if item.name == "test_regenerate_client": - item.add_marker(skip_regen) - - -@pytest.fixture(scope="session") -def private_key_pem() -> str: - from cryptography.hazmat.primitives import serialization - from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey - - private_key = Ed25519PrivateKey.generate() - return private_key.private_bytes( - encoding=serialization.Encoding.PEM, - format=serialization.PrivateFormat.PKCS8, - encryption_algorithm=serialization.NoEncryption(), - ).decode() - - -@pytest.fixture(scope="session") -def fernet_key() -> str: - from cryptography.fernet import Fernet - - return Fernet.generate_key().decode() - - -@pytest.fixture(scope="session") -def test_dev_settings() -> DevelopmentSettings: - from diracx.core.settings import DevelopmentSettings - - yield DevelopmentSettings() - - -@pytest.fixture(scope="session") -def test_auth_settings(private_key_pem, fernet_key) -> AuthSettings: - from diracx.routers.utils.users import AuthSettings - - yield AuthSettings( - token_algorithm="EdDSA", - token_key=private_key_pem, - state_key=fernet_key, - allowed_redirects=[ - "http://diracx.test.invalid:8000/api/docs/oauth2-redirect", - ], - ) - - -@pytest.fixture(scope="session") -def aio_moto(worker_id): - """Start the moto server in a separate thread and return the base URL. - - The mocking provided by moto doesn't play nicely with aiobotocore so we use - the server directly. See https://github.com/aio-libs/aiobotocore/issues/755 - """ - from moto.server import ThreadedMotoServer - - port = 27132 - if worker_id != "master": - port += int(worker_id.replace("gw", "")) + 1 - server = ThreadedMotoServer(port=port) - server.start() - yield { - "endpoint_url": f"http://localhost:{port}", - "aws_access_key_id": "testing", - "aws_secret_access_key": "testing", - } - server.stop() - - -@pytest.fixture(scope="session") -def test_sandbox_settings(aio_moto) -> SandboxStoreSettings: - from diracx.routers.jobs.sandboxes import SandboxStoreSettings - - yield SandboxStoreSettings( - bucket_name="sandboxes", - s3_client_kwargs=aio_moto, - auto_create_bucket=True, - ) - - -class UnavailableDependency: - def __init__(self, key): - self.key = key - - def __call__(self): - raise NotImplementedError( - f"{self.key} has not been made available to this test!" - ) - - -class ClientFactory: - - def __init__( - self, - tmp_path_factory, - with_config_repo, - test_auth_settings, - test_sandbox_settings, - test_dev_settings, - ): - from diracx.core.config import ConfigSource - from diracx.core.extensions import select_from_extension - from diracx.core.settings import ServiceSettingsBase - from diracx.db.os.utils import BaseOSDB - from diracx.db.sql.utils import BaseSQLDB - from diracx.routers import create_app_inner - from diracx.routers.access_policies import BaseAccessPolicy - - from .mock_osdb import fake_available_osdb_implementations - - class AlwaysAllowAccessPolicy(BaseAccessPolicy): - """Dummy access policy.""" - - async def policy( - policy_name: str, # noqa: N805 - user_info: AuthorizedUserInfo, - /, - **kwargs, - ): - pass - - def enrich_tokens( - access_payload: dict, refresh_payload: dict # noqa: N805 - ): - - return {"PolicySpecific": "OpenAccessForTest"}, {} - - enabled_systems = { - e.name for e in select_from_extension(group="diracx.services") - } - database_urls = { - e.name: "sqlite+aiosqlite:///:memory:" - for e in select_from_extension(group="diracx.db.sql") - } - # TODO: Monkeypatch this in a less stupid way - # TODO: Only use this if opensearch isn't available - os_database_conn_kwargs = { - e.name: {"sqlalchemy_dsn": "sqlite+aiosqlite:///:memory:"} - for e in select_from_extension(group="diracx.db.os") - } - BaseOSDB.available_implementations = partial( - fake_available_osdb_implementations, - real_available_implementations=BaseOSDB.available_implementations, - ) - - self._cache_dir = tmp_path_factory.mktemp("empty-dbs") - - self.test_auth_settings = test_auth_settings - self.test_dev_settings = test_dev_settings - - all_access_policies = { - e.name: [AlwaysAllowAccessPolicy] - + BaseAccessPolicy.available_implementations(e.name) - for e in select_from_extension(group="diracx.access_policies") - } - - self.app = create_app_inner( - enabled_systems=enabled_systems, - all_service_settings=[ - test_auth_settings, - test_sandbox_settings, - test_dev_settings, - ], - database_urls=database_urls, - os_database_conn_kwargs=os_database_conn_kwargs, - config_source=ConfigSource.create_from_url( - backend_url=f"git+file://{with_config_repo}" - ), - all_access_policies=all_access_policies, - ) - - self.all_dependency_overrides = self.app.dependency_overrides.copy() - self.app.dependency_overrides = {} - for obj in self.all_dependency_overrides: - assert issubclass( - obj.__self__, - ( - ServiceSettingsBase, - BaseSQLDB, - BaseOSDB, - ConfigSource, - BaseAccessPolicy, - ), - ), obj - - self.all_lifetime_functions = self.app.lifetime_functions[:] - self.app.lifetime_functions = [] - for obj in self.all_lifetime_functions: - assert isinstance( - obj.__self__, (ServiceSettingsBase, BaseSQLDB, BaseOSDB, ConfigSource) - ), obj - - @contextlib.contextmanager - def configure(self, enabled_dependencies): - - assert ( - self.app.dependency_overrides == {} and self.app.lifetime_functions == [] - ), "configure cannot be nested" - for k, v in self.all_dependency_overrides.items(): - - class_name = k.__self__.__name__ - - if class_name in enabled_dependencies: - self.app.dependency_overrides[k] = v - else: - self.app.dependency_overrides[k] = UnavailableDependency(class_name) - - for obj in self.all_lifetime_functions: - # TODO: We should use the name of the entry point instead of the class name - if obj.__self__.__class__.__name__ in enabled_dependencies: - self.app.lifetime_functions.append(obj) - - # Add create_db_schemas to the end of the lifetime_functions so that the - # other lifetime_functions (i.e. those which run db.engine_context) have - # already been ran - self.app.lifetime_functions.append(self.create_db_schemas) - - try: - yield - finally: - self.app.dependency_overrides = {} - self.app.lifetime_functions = [] - - @contextlib.asynccontextmanager - async def create_db_schemas(self): - """Create DB schema's based on the DBs available in app.dependency_overrides.""" - import aiosqlite - import sqlalchemy - from sqlalchemy.util.concurrency import greenlet_spawn - - from diracx.db.sql.utils import BaseSQLDB - - for k, v in self.app.dependency_overrides.items(): - # Ignore dependency overrides which aren't BaseSQLDB.transaction - if ( - isinstance(v, UnavailableDependency) - or k.__func__ != BaseSQLDB.transaction.__func__ - ): - continue - # The first argument of the overridden BaseSQLDB.transaction is the DB object - db = v.args[0] - assert isinstance(db, BaseSQLDB), (k, db) - - # set PRAGMA foreign_keys=ON if sqlite - if db.engine.url.drivername.startswith("sqlite"): - - def set_sqlite_pragma(dbapi_connection, connection_record): - cursor = dbapi_connection.cursor() - cursor.execute("PRAGMA foreign_keys=ON") - cursor.close() - - sqlalchemy.event.listen( - db.engine.sync_engine, "connect", set_sqlite_pragma - ) - - # We maintain a cache of the populated DBs in empty_db_dir so that - # we don't have to recreate them for every test. This speeds up the - # tests by a considerable amount. - ref_db = self._cache_dir / f"{k.__self__.__name__}.db" - if ref_db.exists(): - async with aiosqlite.connect(ref_db) as ref_conn: - conn = await db.engine.raw_connection() - await ref_conn.backup(conn.driver_connection) - await greenlet_spawn(conn.close) - else: - async with db.engine.begin() as conn: - await conn.run_sync(db.metadata.create_all) - - async with aiosqlite.connect(ref_db) as ref_conn: - conn = await db.engine.raw_connection() - await conn.driver_connection.backup(ref_conn) - await greenlet_spawn(conn.close) - - yield - - @contextlib.contextmanager - def unauthenticated(self): - from fastapi.testclient import TestClient - - with TestClient(self.app) as client: - yield client - - @contextlib.contextmanager - def normal_user(self): - from diracx.core.properties import NORMAL_USER - from diracx.routers.auth.token import create_token - - with self.unauthenticated() as client: - payload = { - "sub": "testingVO:yellow-sub", - "exp": datetime.now(tz=timezone.utc) - + timedelta(self.test_auth_settings.access_token_expire_minutes), - "iss": ISSUER, - "dirac_properties": [NORMAL_USER], - "jti": str(uuid4()), - "preferred_username": "preferred_username", - "dirac_group": "test_group", - "vo": "lhcb", - } - token = create_token(payload, self.test_auth_settings) - - client.headers["Authorization"] = f"Bearer {token}" - client.dirac_token_payload = payload - yield client - - @contextlib.contextmanager - def admin_user(self): - from diracx.core.properties import JOB_ADMINISTRATOR - from diracx.routers.auth.token import create_token - - with self.unauthenticated() as client: - payload = { - "sub": "testingVO:yellow-sub", - "iss": ISSUER, - "dirac_properties": [JOB_ADMINISTRATOR], - "jti": str(uuid4()), - "preferred_username": "preferred_username", - "dirac_group": "test_group", - "vo": "lhcb", - } - token = create_token(payload, self.test_auth_settings) - client.headers["Authorization"] = f"Bearer {token}" - client.dirac_token_payload = payload - yield client - - -@pytest.fixture(scope="session") -def session_client_factory( +from .entrypoints import verify_entry_points +from .utils import ( + ClientFactory, + aio_moto, + cli_env, + client_factory, + demo_dir, + demo_kubectl_env, + demo_urls, + do_device_flow_with_dex, + fernet_key, + private_key_pem, + pytest_addoption, + pytest_collection_modifyitems, + session_client_factory, test_auth_settings, + test_dev_settings, + test_login, test_sandbox_settings, + with_cli_login, with_config_repo, - tmp_path_factory, - test_dev_settings, -): - """TODO. - ---- - - """ - yield ClientFactory( - tmp_path_factory, - with_config_repo, - test_auth_settings, - test_sandbox_settings, - test_dev_settings, - ) - - -@pytest.fixture -def client_factory(session_client_factory, request): - marker = request.node.get_closest_marker("enabled_dependencies") - if marker is None: - raise RuntimeError("This test requires the enabled_dependencies marker") - (enabled_dependencies,) = marker.args - with session_client_factory.configure(enabled_dependencies=enabled_dependencies): - yield session_client_factory - - -@pytest.fixture(scope="session") -def with_config_repo(tmp_path_factory): - from git import Repo - - from diracx.core.config import Config - - tmp_path = tmp_path_factory.mktemp("cs-repo") - - repo = Repo.init(tmp_path, initial_branch="master") - cs_file = tmp_path / "default.yml" - example_cs = Config.model_validate( - { - "DIRAC": {}, - "Registry": { - "lhcb": { - "DefaultGroup": "lhcb_user", - "DefaultProxyLifeTime": 432000, - "DefaultStorageQuota": 2000, - "IdP": { - "URL": "https://idp-server.invalid", - "ClientID": "test-idp", - }, - "Users": { - "b824d4dc-1f9d-4ee8-8df5-c0ae55d46041": { - "PreferedUsername": "chaen", - "Email": None, - }, - "c935e5ed-2g0e-5ff9-9eg6-d1bf66e57152": { - "PreferedUsername": "albdr", - "Email": None, - }, - }, - "Groups": { - "lhcb_user": { - "Properties": ["NormalUser", "PrivateLimitedDelegation"], - "Users": [ - "b824d4dc-1f9d-4ee8-8df5-c0ae55d46041", - "c935e5ed-2g0e-5ff9-9eg6-d1bf66e57152", - ], - }, - "lhcb_prmgr": { - "Properties": ["NormalUser", "ProductionManagement"], - "Users": ["b824d4dc-1f9d-4ee8-8df5-c0ae55d46041"], - }, - "lhcb_tokenmgr": { - "Properties": ["NormalUser", "ProxyManagement"], - "Users": ["c935e5ed-2g0e-5ff9-9eg6-d1bf66e57152"], - }, - }, - } - }, - "Operations": {"Defaults": {}}, - "Systems": { - "WorkloadManagement": { - "Production": { - "Databases": { - "JobDB": { - "DBName": "xyz", - "Host": "xyz", - "Port": 9999, - "MaxRescheduling": 3, - }, - "JobLoggingDB": { - "DBName": "xyz", - "Host": "xyz", - "Port": 9999, - }, - "PilotAgentsDB": { - "DBName": "xyz", - "Host": "xyz", - "Port": 9999, - }, - "SandboxMetadataDB": { - "DBName": "xyz", - "Host": "xyz", - "Port": 9999, - }, - "TaskQueueDB": { - "DBName": "xyz", - "Host": "xyz", - "Port": 9999, - }, - "ElasticJobParametersDB": { - "DBName": "xyz", - "Host": "xyz", - "Port": 9999, - }, - "VirtualMachineDB": { - "DBName": "xyz", - "Host": "xyz", - "Port": 9999, - }, - }, - }, - }, - }, - } - ) - cs_file.write_text(example_cs.model_dump_json()) - repo.index.add([cs_file]) # add it to the index - repo.index.commit("Added a new file") - yield tmp_path - - -@pytest.fixture(scope="session") -def demo_dir(request) -> Path: - demo_dir = request.config.getoption("--demo-dir") - if demo_dir is None: - pytest.skip("Requires a running instance of the DiracX demo") - demo_dir = (demo_dir / ".demo").resolve() - yield demo_dir - - -@pytest.fixture(scope="session") -def demo_urls(demo_dir): - import yaml - - helm_values = yaml.safe_load((demo_dir / "values.yaml").read_text()) - yield helm_values["developer"]["urls"] - - -@pytest.fixture(scope="session") -def demo_kubectl_env(demo_dir): - """Get the dictionary of environment variables for kubectl to control the demo.""" - kube_conf = demo_dir / "kube.conf" - if not kube_conf.exists(): - raise RuntimeError(f"Could not find {kube_conf}, is the demo running?") - - env = { - **os.environ, - "KUBECONFIG": str(kube_conf), - "PATH": f"{demo_dir}:{os.environ['PATH']}", - } - - # Check that we can run kubectl - pods_result = subprocess.check_output( - ["kubectl", "get", "pods"], env=env, text=True - ) - assert "diracx" in pods_result - - yield env - - -@pytest.fixture -def cli_env(monkeypatch, tmp_path, demo_urls, demo_dir): - """Set up the environment for the CLI.""" - import httpx - - from diracx.core.preferences import get_diracx_preferences - - diracx_url = demo_urls["diracx"] - ca_path = demo_dir / "demo-ca.pem" - if not ca_path.exists(): - raise RuntimeError(f"Could not find {ca_path}, is the demo running?") - - # Ensure the demo is working - - r = httpx.get( - f"{diracx_url}/api/openapi.json", - verify=ssl.create_default_context(cafile=ca_path), - ) - r.raise_for_status() - assert r.json()["info"]["title"] == "Dirac" - - env = { - "DIRACX_URL": diracx_url, - "DIRACX_CA_PATH": str(ca_path), - "HOME": str(tmp_path), - } - for key, value in env.items(): - monkeypatch.setenv(key, value) - yield env - - # The DiracX preferences are cached however when testing this cache is invalid - get_diracx_preferences.cache_clear() - - -@pytest.fixture -async def with_cli_login(monkeypatch, capfd, cli_env, tmp_path): - try: - credentials = await test_login(monkeypatch, capfd, cli_env) - except Exception as e: - pytest.skip(f"Login failed, fix test_login to re-enable this test: {e!r}") - - credentials_path = tmp_path / "credentials.json" - credentials_path.write_text(credentials) - monkeypatch.setenv("DIRACX_CREDENTIALS_PATH", str(credentials_path)) - yield - - -async def test_login(monkeypatch, capfd, cli_env): - from diracx import cli - - poll_attempts = 0 - - def fake_sleep(*args, **kwargs): - nonlocal poll_attempts - - # Keep track of the number of times this is called - poll_attempts += 1 - - # After polling 5 times, do the actual login - if poll_attempts == 5: - # The login URL should have been printed to stdout - captured = capfd.readouterr() - match = re.search(rf"{cli_env['DIRACX_URL']}[^\n]+", captured.out) - assert match, captured - - do_device_flow_with_dex(match.group(), cli_env["DIRACX_CA_PATH"]) - - # Ensure we don't poll forever - assert poll_attempts <= 100 - - # Reduce the sleep duration to zero to speed up the test - return unpatched_sleep(0) - - # We monkeypatch asyncio.sleep to provide a hook to run the actions that - # would normally be done by a user. This includes capturing the login URL - # and doing the actual device flow with dex. - unpatched_sleep = asyncio.sleep - - expected_credentials_path = Path( - cli_env["HOME"], ".cache", "diracx", "credentials.json" - ) - # Ensure the credentials file does not exist before logging in - assert not expected_credentials_path.exists() - - # Run the login command - with monkeypatch.context() as m: - m.setattr("asyncio.sleep", fake_sleep) - await cli.login(vo="diracAdmin", group=None, property=None) - captured = capfd.readouterr() - assert "Login successful!" in captured.out - assert captured.err == "" - - # Ensure the credentials file exists after logging in - assert expected_credentials_path.exists() - - # Return the credentials so this test can also be used by the - # "with_cli_login" fixture - return expected_credentials_path.read_text() - - -def do_device_flow_with_dex(url: str, ca_path: str) -> None: - """Do the device flow with dex.""" - - class DexLoginFormParser(HTMLParser): - def handle_starttag(self, tag, attrs): - nonlocal action_url - if "form" in str(tag): - assert action_url is None - action_url = urljoin(login_page_url, dict(attrs)["action"]) - - # Get the login page - r = requests.get(url, verify=ca_path) - r.raise_for_status() - login_page_url = r.url # This is not the same as URL as we redirect to dex - login_page_body = r.text - - # Search the page for the login form so we know where to post the credentials - action_url = None - DexLoginFormParser().feed(login_page_body) - assert action_url is not None, login_page_body - - # Do the actual login - r = requests.post( - action_url, - data={"login": "admin@example.com", "password": "password"}, - verify=ca_path, - ) - r.raise_for_status() - approval_url = r.url # This is not the same as URL as we redirect to dex - # Do the actual approval - r = requests.post( - approval_url, - {"approval": "approve", "req": parse_qs(urlparse(r.url).query)["req"][0]}, - verify=ca_path, - ) - - # This should have redirected to the DiracX page that shows the login is complete - assert "Please close the window" in r.text - - -def get_installed_entry_points(): - """Retrieve the installed entry points from the environment.""" - entry_pts = entry_points() - diracx_eps = defaultdict(dict) - for group in entry_pts.groups: - if "diracx" in group: - for ep in entry_pts.select(group=group): - diracx_eps[group][ep.name] = ep.value - return dict(diracx_eps) - - -def get_entry_points_from_toml(toml_file): - """Parse entry points from pyproject.toml.""" - with open(toml_file, "rb") as f: - pyproject = tomllib.load(f) - package_name = pyproject["project"]["name"] - return package_name, pyproject.get("project", {}).get("entry-points", {}) - - -def get_current_entry_points(repo_base) -> bool: - """Create current entry points dict for comparison.""" - current_eps = {} - for toml_file in repo_base.glob("diracx-*/pyproject.toml"): - package_name, entry_pts = get_entry_points_from_toml(f"{toml_file}") - # Ignore packages that are not installed - try: - distribution(package_name) - except PackageNotFoundError: - continue - # Merge the entry points - for key, value in entry_pts.items(): - current_eps[key] = current_eps.get(key, {}) | value - return current_eps - - -@pytest.fixture(scope="session", autouse=True) -def verify_entry_points(request, pytestconfig): - try: - ini_toml_name = tomllib.loads(pytestconfig.inipath.read_text())["project"][ - "name" - ] - except tomllib.TOMLDecodeError: - return - if ini_toml_name == "diracx": - repo_base = pytestconfig.inipath.parent - elif ini_toml_name.startswith("diracx-"): - repo_base = pytestconfig.inipath.parent.parent - else: - return - - installed_eps = get_installed_entry_points() - current_eps = get_current_entry_points(repo_base) - - if installed_eps != current_eps: - pytest.fail( - "Project and installed entry-points are not consistent. " - "You should run `pip install -r requirements-dev.txt`", - ) +) + +__all__ = ( + "verify_entry_points", + "ClientFactory", + "do_device_flow_with_dex", + "test_login", + "pytest_addoption", + "pytest_collection_modifyitems", + "private_key_pem", + "fernet_key", + "test_dev_settings", + "test_auth_settings", + "aio_moto", + "test_sandbox_settings", + "session_client_factory", + "client_factory", + "with_config_repo", + "demo_dir", + "demo_urls", + "demo_kubectl_env", + "cli_env", + "with_cli_login", +) diff --git a/diracx-testing/src/diracx/testing/entrypoints.py b/diracx-testing/src/diracx/testing/entrypoints.py new file mode 100644 index 00000000..17ad2077 --- /dev/null +++ b/diracx-testing/src/diracx/testing/entrypoints.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +import tomllib +from collections import defaultdict +from importlib.metadata import PackageNotFoundError, distribution, entry_points + +import pytest + + +def get_installed_entry_points(): + """Retrieve the installed entry points from the environment.""" + entry_pts = entry_points() + diracx_eps = defaultdict(dict) + for group in entry_pts.groups: + if "diracx" in group: + for ep in entry_pts.select(group=group): + diracx_eps[group][ep.name] = ep.value + return dict(diracx_eps) + + +def get_entry_points_from_toml(toml_file): + """Parse entry points from pyproject.toml.""" + with open(toml_file, "rb") as f: + pyproject = tomllib.load(f) + package_name = pyproject["project"]["name"] + return package_name, pyproject.get("project", {}).get("entry-points", {}) + + +def get_current_entry_points(repo_base) -> bool: + """Create current entry points dict for comparison.""" + current_eps = {} + for toml_file in repo_base.glob("diracx-*/pyproject.toml"): + package_name, entry_pts = get_entry_points_from_toml(f"{toml_file}") + # Ignore packages that are not installed + try: + distribution(package_name) + except PackageNotFoundError: + continue + # Merge the entry points + for key, value in entry_pts.items(): + current_eps[key] = current_eps.get(key, {}) | value + return current_eps + + +@pytest.fixture(scope="session", autouse=True) +def verify_entry_points(request, pytestconfig): + try: + ini_toml_name = tomllib.loads(pytestconfig.inipath.read_text())["project"][ + "name" + ] + except tomllib.TOMLDecodeError: + return + if ini_toml_name == "diracx": + repo_base = pytestconfig.inipath.parent + elif ini_toml_name.startswith("diracx-"): + repo_base = pytestconfig.inipath.parent.parent + else: + return + + installed_eps = get_installed_entry_points() + current_eps = get_current_entry_points(repo_base) + + if installed_eps != current_eps: + pytest.fail( + "Project and installed entry-points are not consistent. " + "You should run `pip install -r requirements-dev.txt`", + ) diff --git a/diracx-testing/src/diracx/testing/utils.py b/diracx-testing/src/diracx/testing/utils.py new file mode 100644 index 00000000..c895f4d4 --- /dev/null +++ b/diracx-testing/src/diracx/testing/utils.py @@ -0,0 +1,694 @@ +"""Utilities for testing DiracX.""" + +from __future__ import annotations + +# TODO: this needs a lot of documentation, in particular what will matter for users +# are the enabled_dependencies markers +import asyncio +import contextlib +import os +import re +import ssl +import subprocess +from datetime import datetime, timedelta, timezone +from functools import partial +from html.parser import HTMLParser +from pathlib import Path +from typing import TYPE_CHECKING, Generator +from urllib.parse import parse_qs, urljoin, urlparse +from uuid import uuid4 + +import pytest +import requests + +if TYPE_CHECKING: + from diracx.core.settings import DevelopmentSettings + from diracx.routers.jobs.sandboxes import SandboxStoreSettings + from diracx.routers.utils.users import AuthorizedUserInfo, AuthSettings + + +# to get a string like this run: +# openssl rand -hex 32 +ALGORITHM = "HS256" +ISSUER = "http://lhcbdirac.cern.ch/" +AUDIENCE = "dirac" +ACCESS_TOKEN_EXPIRE_MINUTES = 30 + + +def pytest_addoption(parser): + parser.addoption( + "--regenerate-client", + action="store_true", + default=False, + help="Regenerate the AutoREST client", + ) + parser.addoption( + "--demo-dir", + type=Path, + default=None, + help="Path to a diracx-charts directory with the demo running", + ) + + +def pytest_collection_modifyitems(config, items): + """Disable the test_regenerate_client if not explicitly asked for.""" + if config.getoption("--regenerate-client"): + # --regenerate-client given in cli: allow client re-generation + return + skip_regen = pytest.mark.skip(reason="need --regenerate-client option to run") + for item in items: + if item.name == "test_regenerate_client": + item.add_marker(skip_regen) + + +@pytest.fixture(scope="session") +def private_key_pem() -> str: + from cryptography.hazmat.primitives import serialization + from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey + + private_key = Ed25519PrivateKey.generate() + return private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ).decode() + + +@pytest.fixture(scope="session") +def fernet_key() -> str: + from cryptography.fernet import Fernet + + return Fernet.generate_key().decode() + + +@pytest.fixture(scope="session") +def test_dev_settings() -> Generator[DevelopmentSettings, None, None]: + from diracx.core.settings import DevelopmentSettings + + yield DevelopmentSettings() + + +@pytest.fixture(scope="session") +def test_auth_settings( + private_key_pem, fernet_key +) -> Generator[AuthSettings, None, None]: + from diracx.routers.utils.users import AuthSettings + + yield AuthSettings( + token_algorithm="EdDSA", + token_key=private_key_pem, + state_key=fernet_key, + allowed_redirects=[ + "http://diracx.test.invalid:8000/api/docs/oauth2-redirect", + ], + ) + + +@pytest.fixture(scope="session") +def aio_moto(worker_id): + """Start the moto server in a separate thread and return the base URL. + + The mocking provided by moto doesn't play nicely with aiobotocore so we use + the server directly. See https://github.com/aio-libs/aiobotocore/issues/755 + """ + from moto.server import ThreadedMotoServer + + port = 27132 + if worker_id != "master": + port += int(worker_id.replace("gw", "")) + 1 + server = ThreadedMotoServer(port=port) + server.start() + yield { + "endpoint_url": f"http://localhost:{port}", + "aws_access_key_id": "testing", + "aws_secret_access_key": "testing", + } + server.stop() + + +@pytest.fixture(scope="session") +def test_sandbox_settings(aio_moto) -> SandboxStoreSettings: + from diracx.routers.jobs.sandboxes import SandboxStoreSettings + + yield SandboxStoreSettings( + bucket_name="sandboxes", + s3_client_kwargs=aio_moto, + auto_create_bucket=True, + ) + + +class UnavailableDependency: + def __init__(self, key): + self.key = key + + def __call__(self): + raise NotImplementedError( + f"{self.key} has not been made available to this test!" + ) + + +class ClientFactory: + + def __init__( + self, + tmp_path_factory, + with_config_repo, + test_auth_settings, + test_sandbox_settings, + test_dev_settings, + ): + from diracx.core.config import ConfigSource + from diracx.core.extensions import select_from_extension + from diracx.core.settings import ServiceSettingsBase + from diracx.db.os.utils import BaseOSDB + from diracx.db.sql.utils import BaseSQLDB + from diracx.routers import create_app_inner + from diracx.routers.access_policies import BaseAccessPolicy + + from .mock_osdb import fake_available_osdb_implementations + + class AlwaysAllowAccessPolicy(BaseAccessPolicy): + """Dummy access policy.""" + + @staticmethod + async def policy( + policy_name: str, user_info: AuthorizedUserInfo, /, **kwargs + ): + pass + + @staticmethod + def enrich_tokens(access_payload: dict, refresh_payload: dict): + + return {"PolicySpecific": "OpenAccessForTest"}, {} + + enabled_systems = { + e.name for e in select_from_extension(group="diracx.services") + } + database_urls = { + e.name: "sqlite+aiosqlite:///:memory:" + for e in select_from_extension(group="diracx.db.sql") + } + # TODO: Monkeypatch this in a less stupid way + # TODO: Only use this if opensearch isn't available + os_database_conn_kwargs = { + e.name: {"sqlalchemy_dsn": "sqlite+aiosqlite:///:memory:"} + for e in select_from_extension(group="diracx.db.os") + } + BaseOSDB.available_implementations = partial( + fake_available_osdb_implementations, + real_available_implementations=BaseOSDB.available_implementations, + ) + + self._cache_dir = tmp_path_factory.mktemp("empty-dbs") + + self.test_auth_settings = test_auth_settings + self.test_dev_settings = test_dev_settings + + all_access_policies = { + e.name: [AlwaysAllowAccessPolicy] + + BaseAccessPolicy.available_implementations(e.name) + for e in select_from_extension(group="diracx.access_policies") + } + + self.app = create_app_inner( + enabled_systems=enabled_systems, + all_service_settings=[ + test_auth_settings, + test_sandbox_settings, + test_dev_settings, + ], + database_urls=database_urls, + os_database_conn_kwargs=os_database_conn_kwargs, + config_source=ConfigSource.create_from_url( + backend_url=f"git+file://{with_config_repo}" + ), + all_access_policies=all_access_policies, + ) + + self.all_dependency_overrides = self.app.dependency_overrides.copy() + self.app.dependency_overrides = {} + for obj in self.all_dependency_overrides: + assert issubclass( + obj.__self__, + ( + ServiceSettingsBase, + BaseSQLDB, + BaseOSDB, + ConfigSource, + BaseAccessPolicy, + ), + ), obj + + self.all_lifetime_functions = self.app.lifetime_functions[:] + self.app.lifetime_functions = [] + for obj in self.all_lifetime_functions: + assert isinstance( + obj.__self__, (ServiceSettingsBase, BaseSQLDB, BaseOSDB, ConfigSource) + ), obj + + @contextlib.contextmanager + def configure(self, enabled_dependencies): + + assert ( + self.app.dependency_overrides == {} and self.app.lifetime_functions == [] + ), "configure cannot be nested" + for k, v in self.all_dependency_overrides.items(): + + class_name = k.__self__.__name__ + + if class_name in enabled_dependencies: + self.app.dependency_overrides[k] = v + else: + self.app.dependency_overrides[k] = UnavailableDependency(class_name) + + for obj in self.all_lifetime_functions: + # TODO: We should use the name of the entry point instead of the class name + if obj.__self__.__class__.__name__ in enabled_dependencies: + self.app.lifetime_functions.append(obj) + + # Add create_db_schemas to the end of the lifetime_functions so that the + # other lifetime_functions (i.e. those which run db.engine_context) have + # already been ran + self.app.lifetime_functions.append(self.create_db_schemas) + + try: + yield + finally: + self.app.dependency_overrides = {} + self.app.lifetime_functions = [] + + @contextlib.asynccontextmanager + async def create_db_schemas(self): + """Create DB schema's based on the DBs available in app.dependency_overrides.""" + import aiosqlite + import sqlalchemy + from sqlalchemy.util.concurrency import greenlet_spawn + + from diracx.db.sql.utils import BaseSQLDB + + for k, v in self.app.dependency_overrides.items(): + # Ignore dependency overrides which aren't BaseSQLDB.transaction + if ( + isinstance(v, UnavailableDependency) + or k.__func__ != BaseSQLDB.transaction.__func__ + ): + continue + # The first argument of the overridden BaseSQLDB.transaction is the DB object + db = v.args[0] + assert isinstance(db, BaseSQLDB), (k, db) + + # set PRAGMA foreign_keys=ON if sqlite + if db.engine.url.drivername.startswith("sqlite"): + + def set_sqlite_pragma(dbapi_connection, connection_record): + cursor = dbapi_connection.cursor() + cursor.execute("PRAGMA foreign_keys=ON") + cursor.close() + + sqlalchemy.event.listen( + db.engine.sync_engine, "connect", set_sqlite_pragma + ) + + # We maintain a cache of the populated DBs in empty_db_dir so that + # we don't have to recreate them for every test. This speeds up the + # tests by a considerable amount. + ref_db = self._cache_dir / f"{k.__self__.__name__}.db" + if ref_db.exists(): + async with aiosqlite.connect(ref_db) as ref_conn: + conn = await db.engine.raw_connection() + await ref_conn.backup(conn.driver_connection) + await greenlet_spawn(conn.close) + else: + async with db.engine.begin() as conn: + await conn.run_sync(db.metadata.create_all) + + async with aiosqlite.connect(ref_db) as ref_conn: + conn = await db.engine.raw_connection() + await conn.driver_connection.backup(ref_conn) + await greenlet_spawn(conn.close) + + yield + + @contextlib.contextmanager + def unauthenticated(self): + from fastapi.testclient import TestClient + + with TestClient(self.app) as client: + yield client + + @contextlib.contextmanager + def normal_user(self): + from diracx.core.properties import NORMAL_USER + from diracx.routers.auth.token import create_token + + with self.unauthenticated() as client: + payload = { + "sub": "testingVO:yellow-sub", + "exp": datetime.now(tz=timezone.utc) + + timedelta(self.test_auth_settings.access_token_expire_minutes), + "iss": ISSUER, + "dirac_properties": [NORMAL_USER], + "jti": str(uuid4()), + "preferred_username": "preferred_username", + "dirac_group": "test_group", + "vo": "lhcb", + } + token = create_token(payload, self.test_auth_settings) + + client.headers["Authorization"] = f"Bearer {token}" + client.dirac_token_payload = payload + yield client + + @contextlib.contextmanager + def admin_user(self): + from diracx.core.properties import JOB_ADMINISTRATOR + from diracx.routers.auth.token import create_token + + with self.unauthenticated() as client: + payload = { + "sub": "testingVO:yellow-sub", + "iss": ISSUER, + "dirac_properties": [JOB_ADMINISTRATOR], + "jti": str(uuid4()), + "preferred_username": "preferred_username", + "dirac_group": "test_group", + "vo": "lhcb", + } + token = create_token(payload, self.test_auth_settings) + client.headers["Authorization"] = f"Bearer {token}" + client.dirac_token_payload = payload + yield client + + +@pytest.fixture(scope="session") +def session_client_factory( + test_auth_settings, + test_sandbox_settings, + with_config_repo, + tmp_path_factory, + test_dev_settings, +): + """TODO. + ---- + + """ + yield ClientFactory( + tmp_path_factory, + with_config_repo, + test_auth_settings, + test_sandbox_settings, + test_dev_settings, + ) + + +@pytest.fixture +def client_factory(session_client_factory, request): + marker = request.node.get_closest_marker("enabled_dependencies") + if marker is None: + raise RuntimeError("This test requires the enabled_dependencies marker") + (enabled_dependencies,) = marker.args + with session_client_factory.configure(enabled_dependencies=enabled_dependencies): + yield session_client_factory + + +@pytest.fixture(scope="session") +def with_config_repo(tmp_path_factory): + from git import Repo + + from diracx.core.config import Config + + tmp_path = tmp_path_factory.mktemp("cs-repo") + + repo = Repo.init(tmp_path, initial_branch="master") + cs_file = tmp_path / "default.yml" + example_cs = Config.model_validate( + { + "DIRAC": {}, + "Registry": { + "lhcb": { + "DefaultGroup": "lhcb_user", + "DefaultProxyLifeTime": 432000, + "DefaultStorageQuota": 2000, + "IdP": { + "URL": "https://idp-server.invalid", + "ClientID": "test-idp", + }, + "Users": { + "b824d4dc-1f9d-4ee8-8df5-c0ae55d46041": { + "PreferedUsername": "chaen", + "Email": None, + }, + "c935e5ed-2g0e-5ff9-9eg6-d1bf66e57152": { + "PreferedUsername": "albdr", + "Email": None, + }, + }, + "Groups": { + "lhcb_user": { + "Properties": ["NormalUser", "PrivateLimitedDelegation"], + "Users": [ + "b824d4dc-1f9d-4ee8-8df5-c0ae55d46041", + "c935e5ed-2g0e-5ff9-9eg6-d1bf66e57152", + ], + }, + "lhcb_prmgr": { + "Properties": ["NormalUser", "ProductionManagement"], + "Users": ["b824d4dc-1f9d-4ee8-8df5-c0ae55d46041"], + }, + "lhcb_tokenmgr": { + "Properties": ["NormalUser", "ProxyManagement"], + "Users": ["c935e5ed-2g0e-5ff9-9eg6-d1bf66e57152"], + }, + }, + } + }, + "Operations": {"Defaults": {}}, + "Systems": { + "WorkloadManagement": { + "Production": { + "Databases": { + "JobDB": { + "DBName": "xyz", + "Host": "xyz", + "Port": 9999, + "MaxRescheduling": 3, + }, + "JobLoggingDB": { + "DBName": "xyz", + "Host": "xyz", + "Port": 9999, + }, + "PilotAgentsDB": { + "DBName": "xyz", + "Host": "xyz", + "Port": 9999, + }, + "SandboxMetadataDB": { + "DBName": "xyz", + "Host": "xyz", + "Port": 9999, + }, + "TaskQueueDB": { + "DBName": "xyz", + "Host": "xyz", + "Port": 9999, + }, + "ElasticJobParametersDB": { + "DBName": "xyz", + "Host": "xyz", + "Port": 9999, + }, + "VirtualMachineDB": { + "DBName": "xyz", + "Host": "xyz", + "Port": 9999, + }, + }, + }, + }, + }, + } + ) + cs_file.write_text(example_cs.model_dump_json()) + repo.index.add([cs_file]) # add it to the index + repo.index.commit("Added a new file") + yield tmp_path + + +@pytest.fixture(scope="session") +def demo_dir(request) -> Path: + demo_dir = request.config.getoption("--demo-dir") + if demo_dir is None: + pytest.skip("Requires a running instance of the DiracX demo") + demo_dir = (demo_dir / ".demo").resolve() + yield demo_dir + + +@pytest.fixture(scope="session") +def demo_urls(demo_dir): + import yaml + + helm_values = yaml.safe_load((demo_dir / "values.yaml").read_text()) + yield helm_values["developer"]["urls"] + + +@pytest.fixture(scope="session") +def demo_kubectl_env(demo_dir): + """Get the dictionary of environment variables for kubectl to control the demo.""" + kube_conf = demo_dir / "kube.conf" + if not kube_conf.exists(): + raise RuntimeError(f"Could not find {kube_conf}, is the demo running?") + + env = { + **os.environ, + "KUBECONFIG": str(kube_conf), + "PATH": f"{demo_dir}:{os.environ['PATH']}", + } + + # Check that we can run kubectl + pods_result = subprocess.check_output( + ["kubectl", "get", "pods"], env=env, text=True + ) + assert "diracx" in pods_result + + yield env + + +@pytest.fixture +def cli_env(monkeypatch, tmp_path, demo_urls, demo_dir): + """Set up the environment for the CLI.""" + import httpx + + from diracx.core.preferences import get_diracx_preferences + + diracx_url = demo_urls["diracx"] + ca_path = demo_dir / "demo-ca.pem" + if not ca_path.exists(): + raise RuntimeError(f"Could not find {ca_path}, is the demo running?") + + # Ensure the demo is working + + r = httpx.get( + f"{diracx_url}/api/openapi.json", + verify=ssl.create_default_context(cafile=ca_path), + ) + r.raise_for_status() + assert r.json()["info"]["title"] == "Dirac" + + env = { + "DIRACX_URL": diracx_url, + "DIRACX_CA_PATH": str(ca_path), + "HOME": str(tmp_path), + } + for key, value in env.items(): + monkeypatch.setenv(key, value) + yield env + + # The DiracX preferences are cached however when testing this cache is invalid + get_diracx_preferences.cache_clear() + + +@pytest.fixture +async def with_cli_login(monkeypatch, capfd, cli_env, tmp_path): + try: + credentials = await test_login(monkeypatch, capfd, cli_env) + except Exception as e: + pytest.skip(f"Login failed, fix test_login to re-enable this test: {e!r}") + + credentials_path = tmp_path / "credentials.json" + credentials_path.write_text(credentials) + monkeypatch.setenv("DIRACX_CREDENTIALS_PATH", str(credentials_path)) + yield + + +async def test_login(monkeypatch, capfd, cli_env): + from diracx import cli + + poll_attempts = 0 + + def fake_sleep(*args, **kwargs): + nonlocal poll_attempts + + # Keep track of the number of times this is called + poll_attempts += 1 + + # After polling 5 times, do the actual login + if poll_attempts == 5: + # The login URL should have been printed to stdout + captured = capfd.readouterr() + match = re.search(rf"{cli_env['DIRACX_URL']}[^\n]+", captured.out) + assert match, captured + + do_device_flow_with_dex(match.group(), cli_env["DIRACX_CA_PATH"]) + + # Ensure we don't poll forever + assert poll_attempts <= 100 + + # Reduce the sleep duration to zero to speed up the test + return unpatched_sleep(0) + + # We monkeypatch asyncio.sleep to provide a hook to run the actions that + # would normally be done by a user. This includes capturing the login URL + # and doing the actual device flow with dex. + unpatched_sleep = asyncio.sleep + + expected_credentials_path = Path( + cli_env["HOME"], ".cache", "diracx", "credentials.json" + ) + # Ensure the credentials file does not exist before logging in + assert not expected_credentials_path.exists() + + # Run the login command + with monkeypatch.context() as m: + m.setattr("asyncio.sleep", fake_sleep) + await cli.auth.login(vo="diracAdmin", group=None, property=None) + captured = capfd.readouterr() + assert "Login successful!" in captured.out + assert captured.err == "" + + # Ensure the credentials file exists after logging in + assert expected_credentials_path.exists() + + # Return the credentials so this test can also be used by the + # "with_cli_login" fixture + return expected_credentials_path.read_text() + + +def do_device_flow_with_dex(url: str, ca_path: str) -> None: + """Do the device flow with dex.""" + + class DexLoginFormParser(HTMLParser): + def handle_starttag(self, tag, attrs): + nonlocal action_url + if "form" in str(tag): + assert action_url is None + action_url = urljoin(login_page_url, dict(attrs)["action"]) + + # Get the login page + r = requests.get(url, verify=ca_path) + r.raise_for_status() + login_page_url = r.url # This is not the same as URL as we redirect to dex + login_page_body = r.text + + # Search the page for the login form so we know where to post the credentials + action_url = None + DexLoginFormParser().feed(login_page_body) + assert action_url is not None, login_page_body + + # Do the actual login + r = requests.post( + action_url, + data={"login": "admin@example.com", "password": "password"}, + verify=ca_path, + ) + r.raise_for_status() + approval_url = r.url # This is not the same as URL as we redirect to dex + # Do the actual approval + r = requests.post( + approval_url, + {"approval": "approve", "req": parse_qs(urlparse(r.url).query)["req"][0]}, + verify=ca_path, + ) + + # This should have redirected to the DiracX page that shows the login is complete + assert "Please close the window" in r.text diff --git a/extensions/gubbins/gubbins-core/tests/test_config.py b/extensions/gubbins/gubbins-core/tests/test_config.py index d4b2e453..6a226979 100644 --- a/extensions/gubbins/gubbins-core/tests/test_config.py +++ b/extensions/gubbins/gubbins-core/tests/test_config.py @@ -26,7 +26,7 @@ def github_is_down(): def test_remote_git_config_source(monkeypatch): monkeypatch.setattr( - "diracx.core.config.DEFAULT_CONFIG_FILE", + "diracx.core.config.sources.DEFAULT_CONFIG_FILE", "k3s/examples/cs.yaml", ) remote_conf = ConfigSource.create_from_url(backend_url=TEST_REPO) diff --git a/pyproject.toml b/pyproject.toml index 77998d3c..2429d06d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,11 +69,17 @@ ignore = [ "D404", ] +[tool.ruff.lint.isort] +required-imports = ["from __future__ import annotations"] + [tool.ruff.lint.per-file-ignores] # Ignore Bandit security checks in the test directories "diracx-testing/*" = ["S"] "diracx-*/tests/*" = ["S"] +[tool.ruff.lint.extend-per-file-ignores] +"diracx-routers/src/diracx/routers/access_policies.py" = ["I002"] + [tool.ruff.lint.flake8-bugbear] # Allow default arguments like, e.g., `data: List[str] = fastapi.Query(None)`. extend-immutable-calls = [ diff --git a/tests/make_token_local.py b/tests/make_token_local.py index bcbc4a07..3ddcd204 100755 --- a/tests/make_token_local.py +++ b/tests/make_token_local.py @@ -1,4 +1,6 @@ #!/usr/bin/env python +from __future__ import annotations + import argparse import uuid from datetime import datetime, timedelta, timezone diff --git a/tests/test_generic.py b/tests/test_generic.py index 56c4c17f..0885ef02 100644 --- a/tests/test_generic.py +++ b/tests/test_generic.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import pytest