diff --git a/diracx-cli/src/diracx/cli/config.py b/diracx-cli/src/diracx/cli/config.py index 0035398bf..4c32caf15 100644 --- a/diracx-cli/src/diracx/cli/config.py +++ b/diracx-cli/src/diracx/cli/config.py @@ -17,7 +17,6 @@ @app.async_command() async def dump(): - breakpoint() async with DiracClient() as api: config = await api.config.serve_config() display(config) diff --git a/diracx-client/pyproject.toml b/diracx-client/pyproject.toml index 01f02d1ec..516857795 100644 --- a/diracx-client/pyproject.toml +++ b/diracx-client/pyproject.toml @@ -27,8 +27,8 @@ requires = ["setuptools>=61", "wheel", "setuptools_scm>=8"] build-backend = "setuptools.build_meta" [project.entry-points."diracx"] -client_class = "diracx.client.patches:DiracClient" -aio_client_class = "diracx.client.patches.aio:DiracClient" +client_class = "diracx.client._client:Dirac" +aio_client_class = "diracx.client.aio._client:Dirac" [tool.setuptools_scm] diff --git a/diracx-client/src/diracx/client/__init__.py b/diracx-client/src/diracx/client/__init__.py index 7cf5cf063..897989c30 100644 --- a/diracx-client/src/diracx/client/__init__.py +++ b/diracx-client/src/diracx/client/__init__.py @@ -4,30 +4,12 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- -import sys -from importlib.abc import MetaPathFinder - - -class MyMetaPathFinder(MetaPathFinder): - def find_spec(self, fullname, path, target=None): - # print parameters for more information - print( - "find_spec(fullname={}, path={}, target={})".format(fullname, path, target) - ) - # return None to tell the python this finder can't find the module - return None - - -# insert a MyMetaPathFinder instance at the start of the meta_path list -# sys.meta_path.insert(0, MyMetaPathFinder()) - - from ._client import Dirac try: from ._patch import __all__ as _patch_all from ._patch import * # pylint: disable=unused-wildcard-import -except ImportError: +except ValueError: _patch_all = [] from ._patch import patch_sdk as _patch_sdk diff --git a/diracx-client/src/diracx/client/_patch.py b/diracx-client/src/diracx/client/_patch.py index f98b0d72c..1892c10d8 100644 --- a/diracx-client/src/diracx/client/_patch.py +++ b/diracx-client/src/diracx/client/_patch.py @@ -9,6 +9,7 @@ from __future__ import annotations from datetime import datetime, timezone +import importlib.util import json import jwt import requests @@ -25,7 +26,90 @@ from diracx.core.models import TokenResponse as CoreTokenResponse from diracx.core.preferences import DiracxPreferences, get_diracx_preferences + import sys +import importlib +from importlib.abc import MetaPathFinder, Loader + + +class MyLoader(Loader): + # MyLoader contructor to store txt filepath + def __init__(self, filepath): + self.filepath = filepath + + def create_module(self, spec): + # try to read txt file content + try: + with open(self.filepath) as inp_file: + self.data = inp_file.read() + except: + # raise ImportError if there was an error loading the file + raise ImportError + # when returning None, default module creation will be called + + def exec_module(self, module): + # here we update module and add our custom members + module.__dict__.update({"data": self.data}) + + +# print("CHRIS I DEFINE IT HERE") +class MyMetaPathFinder(MetaPathFinder): + def find_spec(self, fullname, path, target=None): + # print parameters for more information + # print( + # "find_spec(fullname={}, path={}, target={})".format(fullname, path, target) + # ) + + # if ( + # "gubbins.client." in fullname + # and "_patch" in fullname + # and fullname != "gubbins.client._patch" + # ): + if any( + [ + fullname.startswith(prefix) + for prefix in [ + "gubbins.client.operations._patch", + "gubbins.client.models._patch", + "gubbins.client.aio.operations._patch", + ] + ] + ): + # print(f"CHRIS COULD OVERWRITE {fullname=} {path=}") + + try: + overwritten = importlib.util.find_spec( + fullname.replace("gubbins", "diracx") + ) + # print(f"CHRIS MANAGED ! {fullname=} {overwritten=}") + except Exception as e: + # print(f"CHRIS EXCEPT: {e!r}") + overwritten = None + return overwritten + + # return None to tell the python this finder can't find the module + # importlib.util.spec_from_file_location + return None + + +# insert a MyMetaPathFinder instance at the start of the meta_path list +if not isinstance(sys.meta_path[0], MyMetaPathFinder): + sys.meta_path.insert(0, MyMetaPathFinder()) + print("CHRIS STARTING RELOAD") + for module_name, module in sys.modules.copy().items(): + if ( + "diracx.client" in module_name + and module_name + not in ( + "diracx.client", + "diracx.client._patch", + ) + and "_patch" in module_name + ): + print(f"CHRIS Reloading {module_name=}") + importlib.reload(module) + print("CHRIS FINISHED RELOADING") + __all__: List[str] = [ "DiracClient", @@ -41,7 +125,4 @@ def patch_sdk(): """ -from diracx.core.extensions import select_from_extension - -real_client = select_from_extension(group="diracx", name="client_class")[0] -DiracClient = real_client.load() +from .patches import DiracClient diff --git a/diracx-client/src/diracx/client/aio/__init__.py b/diracx-client/src/diracx/client/aio/__init__.py index cc37da18a..08c2206ca 100644 --- a/diracx-client/src/diracx/client/aio/__init__.py +++ b/diracx-client/src/diracx/client/aio/__init__.py @@ -6,10 +6,14 @@ from ._client import Dirac +print(f"CHRIS HERE") + try: from ._patch import __all__ as _patch_all + + print(f"CHRIS HERE {_patch_all=}") from ._patch import * # pylint: disable=unused-wildcard-import -except ImportError: +except ValueError as e: _patch_all = [] from ._patch import patch_sdk as _patch_sdk diff --git a/diracx-client/src/diracx/client/aio/_patch.py b/diracx-client/src/diracx/client/aio/_patch.py index a746ce6e2..31059e7d5 100644 --- a/diracx-client/src/diracx/client/aio/_patch.py +++ b/diracx-client/src/diracx/client/aio/_patch.py @@ -32,7 +32,4 @@ def patch_sdk(): """ -from diracx.core.extensions import select_from_extension - -real_client = select_from_extension(group="diracx", name="aio_client_class")[0] -DiracClient = real_client.load() +from ..patches.aio import DiracClient diff --git a/diracx-client/src/diracx/client/patches/__init__.py b/diracx-client/src/diracx/client/patches/__init__.py index 9bf166818..38eecb471 100644 --- a/diracx-client/src/diracx/client/patches/__init__.py +++ b/diracx-client/src/diracx/client/patches/__init__.py @@ -6,7 +6,7 @@ import requests from pathlib import Path -from typing import Any, Dict, List, Optional, cast +from typing import Any, Dict, List, Optional, cast, TypeAlias from urllib import parse from azure.core.credentials import AccessToken from azure.core.credentials import TokenCredential @@ -17,145 +17,18 @@ from diracx.core.models import TokenResponse as CoreTokenResponse from diracx.core.preferences import DiracxPreferences, get_diracx_preferences -from .._client import Dirac as DiracGenerated -from .utils import ( - get_openid_configuration, - get_token, - refresh_token, - is_refresh_token_valid, -) +from .utils import DiracClientMixin -__all__: List[str] = [ - "DiracClient", -] # Add all objects you want publicly available to users at this package level - - -class DiracTokenCredential(TokenCredential): - """Tailor get_token() for our context""" - - def __init__( - self, - location: Path, - token_endpoint: str, - client_id: str, - *, - verify: bool | str = True, - ) -> None: - self.location = location - self.verify = verify - self.token_endpoint = token_endpoint - self.client_id = client_id - - def get_token( - self, - *scopes: str, - claims: Optional[str] = None, - tenant_id: Optional[str] = None, - **kwargs: Any, - ) -> AccessToken: - """Refresh the access token using the refresh_token flow. - :param str scopes: The type of access needed. - :keyword str claims: Additional claims required in the token, such as those returned in a resource - provider's claims challenge following an authorization failure. - :keyword str tenant_id: Optional tenant to include in the token request. - :rtype: AccessToken - :return: An AccessToken instance containing the token string and its expiration time in Unix time. - """ - return refresh_token( - self.location, - self.token_endpoint, - self.client_id, - kwargs["refresh_token"], - verify=self.verify, - ) - - -class DiracBearerTokenCredentialPolicy(BearerTokenCredentialPolicy): - """Custom BearerTokenCredentialPolicy tailored for our use case. - - * It does not ensure the connection is done through https. - * It does not ensure that an access token is available. - """ +from diracx.core.extensions import select_from_extension - def __init__( - self, credential: DiracTokenCredential, *scopes: str, **kwargs: Any - ) -> None: - super().__init__(credential, *scopes, **kwargs) +real_client = select_from_extension(group="diracx", name="client_class")[0].load() +DiracGenerated = real_client - def on_request( - self, request: PipelineRequest - ) -> None: # pylint:disable=invalid-overridden-method - """Authorization Bearer is optional here. - :param request: The pipeline request object to be modified. - :type request: ~azure.core.pipeline.PipelineRequest - :raises: :class:`~azure.core.exceptions.ServiceRequestError` - """ - self._token: AccessToken | None - self._credential: DiracTokenCredential - credentials: dict[str, Any] - try: - self._token = get_token(self._credential.location, self._token) - except RuntimeError: - # If we are here, it means the credentials path does not exist - # we suppose it is not needed to perform the request - return - - if not self._token: - credentials = json.loads(self._credential.location.read_text()) - refresh_token = credentials["refresh_token"] - if not is_refresh_token_valid(refresh_token): - # If we are here, it means the refresh token is not valid anymore - # we suppose it is not needed to perform the request - return - self._token = self._credential.get_token("", refresh_token=refresh_token) - - request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" - - -class DiracClient(DiracGenerated): - """This class inherits from the generated Dirac client and adds support for tokens, - so that the caller does not need to configure it by itself. - """ - - def __init__( - self, - endpoint: str | None = None, - client_id: str | None = None, - diracx_preferences: DiracxPreferences | None = None, - verify: bool | str = True, - **kwargs: Any, - ) -> None: - diracx_preferences = diracx_preferences or get_diracx_preferences() - self._endpoint = str(endpoint or diracx_preferences.url) - if verify is True and diracx_preferences.ca_path: - verify = str(diracx_preferences.ca_path) - kwargs["connection_verify"] = verify - self._client_id = client_id or "myDIRACClientID" - - # Get .well-known configuration - openid_configuration = get_openid_configuration(self._endpoint, verify=verify) - - # Initialize Dirac with a Dirac-specific token credential policy - super().__init__( - endpoint=self._endpoint, - authentication_policy=DiracBearerTokenCredentialPolicy( - DiracTokenCredential( - location=diracx_preferences.credentials_path, - token_endpoint=openid_configuration["token_endpoint"], - client_id=self._client_id, - verify=verify, - ), - ), - **kwargs, - ) +__all__: List[str] = [ + "DiracClient", +] # Add all objects you want publicly available to users at this package level - @property - def client_id(self): - return self._client_id - def __aenter__(self) -> "DiracClient": - """Redefined to provide the patched Dirac client in the managed context""" - self._client.__enter__() - return self +class DiracClient(DiracClientMixin, DiracGenerated): ... # type: ignore diff --git a/diracx-client/src/diracx/client/patches/aio/__init__.py b/diracx-client/src/diracx/client/patches/aio/__init__.py index 0fd61defb..548cc6fc4 100644 --- a/diracx-client/src/diracx/client/patches/aio/__init__.py +++ b/diracx-client/src/diracx/client/patches/aio/__init__.py @@ -17,171 +17,17 @@ from diracx.core.preferences import get_diracx_preferences, DiracxPreferences -from ...aio._client import Dirac as DiracGenerated -from ..utils import ( - get_openid_configuration, - get_token, - refresh_token, - is_refresh_token_valid, -) +from .utils import DiracClientMixin __all__: List[str] = [ "DiracClient", ] # Add all objects you want publicly available to users at this package level -def patch_sdk(): - """Do not remove from this file. +from diracx.core.extensions import select_from_extension - `patch_sdk` is a last resort escape hatch that allows you to do customizations - you can't accomplish using the techniques described in - https://aka.ms/azsdk/python/dpcodegen/python/customize - """ +real_client = select_from_extension(group="diracx", name="aio_client_class")[0].load() +DiracGenerated = real_client -class DiracTokenCredential(AsyncTokenCredential): - """Tailor get_token() for our context""" - - def __init__( - self, - location: Path, - token_endpoint: str, - client_id: str, - *, - verify: bool | str = True, - ) -> None: - self.location = location - self.verify = verify - self.token_endpoint = token_endpoint - self.client_id = client_id - - async def get_token( - self, - *scopes: str, - claims: Optional[str] = None, - tenant_id: Optional[str] = None, - **kwargs: Any, - ) -> AccessToken: - """Refresh the access token using the refresh_token flow. - :param str scopes: The type of access needed. - :keyword str claims: Additional claims required in the token, such as those returned in a resource - provider's claims challenge following an authorization failure. - :keyword str tenant_id: Optional tenant to include in the token request. - :rtype: AccessToken - :return: An AccessToken instance containing the token string and its expiration time in Unix time. - """ - return refresh_token( - self.location, - self.token_endpoint, - self.client_id, - kwargs["refresh_token"], - verify=self.verify, - ) - - async def close(self) -> None: - """AsyncTokenCredential is a protocol: we need to 'implement' close()""" - pass - - async def __aenter__(self): - """AsyncTokenCredential is a protocol: we need to 'implement' __aenter__()""" - pass - - async def __aexit__( - self, - exc_type: type[BaseException] | None = ..., - exc_value: BaseException | None = ..., - traceback: TracebackType | None = ..., - ) -> None: - """AsyncTokenCredential is a protocol: we need to 'implement' __aexit__()""" - pass - - -class DiracBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy): - """Custom AsyncBearerTokenCredentialPolicy tailored for our use case. - - * It does not ensure the connection is done through https. - * It does not ensure that an access token is available. - """ - - def __init__( - self, credential: DiracTokenCredential, *scopes: str, **kwargs: Any - ) -> None: - super().__init__(credential, *scopes, **kwargs) - - async def on_request( - self, request: PipelineRequest - ) -> None: # pylint:disable=invalid-overridden-method - """Authorization Bearer is optional here. - :param request: The pipeline request object to be modified. - :type request: ~azure.core.pipeline.PipelineRequest - :raises: :class:`~azure.core.exceptions.ServiceRequestError` - """ - self._token: AccessToken | None - self._credential: DiracTokenCredential - credentials: dict[str, Any] - try: - self._token = get_token(self._credential.location, self._token) - except RuntimeError: - # If we are here, it means the credentials path does not exist - # we suppose it is not needed to perform the request - return - - if not self._token: - credentials = json.loads(self._credential.location.read_text()) - refresh_token = credentials["refresh_token"] - if not is_refresh_token_valid(refresh_token): - # If we are here, it means the refresh token is not valid anymore - # we suppose it is not needed to perform the request - return - self._token = await self._credential.get_token( - "", refresh_token=refresh_token - ) - - request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" - - -class DiracClient(DiracGenerated): - """This class inherits from the generated Dirac client and adds support for tokens, - so that the caller does not need to configure it by itself. - """ - - def __init__( - self, - endpoint: str | None = None, - client_id: str | None = None, - diracx_preferences: DiracxPreferences | None = None, - verify: bool | str = True, - **kwargs: Any, - ) -> None: - diracx_preferences = diracx_preferences or get_diracx_preferences() - if verify is True and diracx_preferences.ca_path: - verify = str(diracx_preferences.ca_path) - kwargs["connection_verify"] = verify - self._endpoint = str(endpoint or diracx_preferences.url) - self._client_id = client_id or "myDIRACClientID" - - # Get .well-known configuration - openid_configuration = get_openid_configuration(self._endpoint, verify=verify) - - # Initialize Dirac with a Dirac-specific token credential policy - super().__init__( - endpoint=self._endpoint, - authentication_policy=DiracBearerTokenCredentialPolicy( - DiracTokenCredential( - location=diracx_preferences.credentials_path, - token_endpoint=openid_configuration["token_endpoint"], - client_id=self._client_id, - verify=verify, - ), - ), - **kwargs, - ) - - @property - def client_id(self): - return self._client_id - - async def __aenter__(self) -> "DiracClient": - """Redefined to provide the patched Dirac client in the managed context""" - await self._client.__aenter__() - return self +class DiracClient(DiracClientMixin, DiracGenerated): ... # type: ignore diff --git a/diracx-client/src/diracx/client/patches/aio/utils.py b/diracx-client/src/diracx/client/patches/aio/utils.py new file mode 100644 index 000000000..d039031d5 --- /dev/null +++ b/diracx-client/src/diracx/client/patches/aio/utils.py @@ -0,0 +1,184 @@ +# ------------------------------------ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +# ------------------------------------ +"""Customize generated code here. + +Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize +""" +from __future__ import annotations + +import abc +import json +from types import TracebackType +from pathlib import Path +from typing import Any, List, Optional +from azure.core.credentials import AccessToken +from azure.core.credentials_async import AsyncTokenCredential +from azure.core.pipeline import PipelineRequest +from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy + +from diracx.core.preferences import get_diracx_preferences, DiracxPreferences + +from ..utils import ( + get_openid_configuration, + get_token, + refresh_token, + is_refresh_token_valid, +) + +__all__: List[str] = [ + "DiracClient", +] # Add all objects you want publicly available to users at this package level + + +class DiracTokenCredential(AsyncTokenCredential): + """Tailor get_token() for our context""" + + def __init__( + self, + location: Path, + token_endpoint: str, + client_id: str, + *, + verify: bool | str = True, + ) -> None: + self.location = location + self.verify = verify + self.token_endpoint = token_endpoint + self.client_id = client_id + + async def get_token( + self, + *scopes: str, + claims: Optional[str] = None, + tenant_id: Optional[str] = None, + **kwargs: Any, + ) -> AccessToken: + """Refresh the access token using the refresh_token flow. + :param str scopes: The type of access needed. + :keyword str claims: Additional claims required in the token, such as those returned in a resource + provider's claims challenge following an authorization failure. + :keyword str tenant_id: Optional tenant to include in the token request. + :rtype: AccessToken + :return: An AccessToken instance containing the token string and its expiration time in Unix time. + """ + return refresh_token( + self.location, + self.token_endpoint, + self.client_id, + kwargs["refresh_token"], + verify=self.verify, + ) + + async def close(self) -> None: + """AsyncTokenCredential is a protocol: we need to 'implement' close()""" + pass + + async def __aenter__(self): + """AsyncTokenCredential is a protocol: we need to 'implement' __aenter__()""" + pass + + async def __aexit__( + self, + exc_type: type[BaseException] | None = ..., + exc_value: BaseException | None = ..., + traceback: TracebackType | None = ..., + ) -> None: + """AsyncTokenCredential is a protocol: we need to 'implement' __aexit__()""" + pass + + +class DiracBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy): + """Custom AsyncBearerTokenCredentialPolicy tailored for our use case. + + * It does not ensure the connection is done through https. + * It does not ensure that an access token is available. + """ + + def __init__( + self, credential: DiracTokenCredential, *scopes: str, **kwargs: Any + ) -> None: + super().__init__(credential, *scopes, **kwargs) + + async def on_request( + self, request: PipelineRequest + ) -> None: # pylint:disable=invalid-overridden-method + """Authorization Bearer is optional here. + :param request: The pipeline request object to be modified. + :type request: ~azure.core.pipeline.PipelineRequest + :raises: :class:`~azure.core.exceptions.ServiceRequestError` + """ + self._token: AccessToken | None + self._credential: DiracTokenCredential + credentials: dict[str, Any] + try: + self._token = get_token(self._credential.location, self._token) + except RuntimeError: + # If we are here, it means the credentials path does not exist + # we suppose it is not needed to perform the request + return + + if not self._token: + credentials = json.loads(self._credential.location.read_text()) + refresh_token = credentials["refresh_token"] + if not is_refresh_token_valid(refresh_token): + # If we are here, it means the refresh token is not valid anymore + # we suppose it is not needed to perform the request + return + self._token = await self._credential.get_token( + "", refresh_token=refresh_token + ) + + request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" + + +class DiracClientMixin(metaclass=abc.ABCMeta): + """This class inherits from the generated Dirac client and adds support for tokens, + so that the caller does not need to configure it by itself. + """ + + def __init__( + self, + endpoint: str | None = None, + client_id: str | None = None, + diracx_preferences: DiracxPreferences | None = None, + verify: bool | str = True, + **kwargs: Any, + ) -> None: + diracx_preferences = diracx_preferences or get_diracx_preferences() + if verify is True and diracx_preferences.ca_path: + verify = str(diracx_preferences.ca_path) + kwargs["connection_verify"] = verify + self._endpoint = str(endpoint or diracx_preferences.url) + self._client_id = client_id or "myDIRACClientID" + + # Get .well-known configuration + openid_configuration = get_openid_configuration(self._endpoint, verify=verify) + + # Initialize Dirac with a Dirac-specific token credential policy + # We need to ignore types here because mypy complains that we give + # too many arguments to "object" constructor as this is a mixin + + super().__init__( # type: ignore + endpoint=self._endpoint, + authentication_policy=DiracBearerTokenCredentialPolicy( + DiracTokenCredential( + location=diracx_preferences.credentials_path, + token_endpoint=openid_configuration["token_endpoint"], + client_id=self._client_id, + verify=verify, + ), + ), + **kwargs, + ) + + @property + def client_id(self): + return self._client_id + + async def __aenter__(self) -> "DiracClient": # type: ignore + """Redefined to provide the patched Dirac client in the managed context""" + # _client comes from the generated class + await self._client.__aenter__() # type: ignore + return self diff --git a/diracx-client/src/diracx/client/patches/utils.py b/diracx-client/src/diracx/client/patches/utils.py index 331ffba06..617be5364 100644 --- a/diracx-client/src/diracx/client/patches/utils.py +++ b/diracx-client/src/diracx/client/patches/utils.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from datetime import datetime, timezone import json import jwt @@ -110,3 +112,136 @@ def is_token_valid(token: AccessToken) -> bool: datetime.fromtimestamp(token.expires_on, tz=timezone.utc) - datetime.now(tz=timezone.utc) ).total_seconds() > 300 + + +class DiracTokenCredential(TokenCredential): + """Tailor get_token() for our context""" + + def __init__( + self, + location: Path, + token_endpoint: str, + client_id: str, + *, + verify: bool | str = True, + ) -> None: + self.location = location + self.verify = verify + self.token_endpoint = token_endpoint + self.client_id = client_id + + def get_token( + self, + *scopes: str, + claims: Optional[str] = None, + tenant_id: Optional[str] = None, + **kwargs: Any, + ) -> AccessToken: + """Refresh the access token using the refresh_token flow. + :param str scopes: The type of access needed. + :keyword str claims: Additional claims required in the token, such as those returned in a resource + provider's claims challenge following an authorization failure. + :keyword str tenant_id: Optional tenant to include in the token request. + :rtype: AccessToken + :return: An AccessToken instance containing the token string and its expiration time in Unix time. + """ + return refresh_token( + self.location, + self.token_endpoint, + self.client_id, + kwargs["refresh_token"], + verify=self.verify, + ) + + +class DiracBearerTokenCredentialPolicy(BearerTokenCredentialPolicy): + """Custom BearerTokenCredentialPolicy tailored for our use case. + + * It does not ensure the connection is done through https. + * It does not ensure that an access token is available. + """ + + def __init__( + self, credential: DiracTokenCredential, *scopes: str, **kwargs: Any + ) -> None: + super().__init__(credential, *scopes, **kwargs) + + def on_request( + self, request: PipelineRequest + ) -> None: # pylint:disable=invalid-overridden-method + """Authorization Bearer is optional here. + :param request: The pipeline request object to be modified. + :type request: ~azure.core.pipeline.PipelineRequest + :raises: :class:`~azure.core.exceptions.ServiceRequestError` + """ + self._token: AccessToken | None + self._credential: DiracTokenCredential + credentials: dict[str, Any] + + try: + self._token = get_token(self._credential.location, self._token) + except RuntimeError: + # If we are here, it means the credentials path does not exist + # we suppose it is not needed to perform the request + return + + if not self._token: + credentials = json.loads(self._credential.location.read_text()) + refresh_token = credentials["refresh_token"] + if not is_refresh_token_valid(refresh_token): + # If we are here, it means the refresh token is not valid anymore + # we suppose it is not needed to perform the request + return + self._token = self._credential.get_token("", refresh_token=refresh_token) + + request.http_request.headers["Authorization"] = f"Bearer {self._token.token}" + + +class DiracClientMixin: + """This class inherits from the generated Dirac client and adds support for tokens, + so that the caller does not need to configure it by itself. + """ + + def __init__( + self, + endpoint: str | None = None, + client_id: str | None = None, + diracx_preferences: DiracxPreferences | None = None, + verify: bool | str = True, + **kwargs: Any, + ) -> None: + diracx_preferences = diracx_preferences or get_diracx_preferences() + self._endpoint = str(endpoint or diracx_preferences.url) + if verify is True and diracx_preferences.ca_path: + verify = str(diracx_preferences.ca_path) + kwargs["connection_verify"] = verify + self._client_id = client_id or "myDIRACClientID" + + # Get .well-known configuration + openid_configuration = get_openid_configuration(self._endpoint, verify=verify) + + # Initialize Dirac with a Dirac-specific token credential policy + # We need to ignore types here because mypy complains that we give + # too many arguments to "object" constructor as this is a mixin + + super().__init__( # type: ignore + endpoint=self._endpoint, + authentication_policy=DiracBearerTokenCredentialPolicy( + DiracTokenCredential( + location=diracx_preferences.credentials_path, + token_endpoint=openid_configuration["token_endpoint"], + client_id=self._client_id, + verify=verify, + ), + ), + **kwargs, + ) + + @property + def client_id(self): + return self._client_id + + def __aenter__(self) -> "DiracClient": # type: ignore + """Redefined to provide the patched Dirac client in the managed context""" + self._client.__enter__() # type: ignore + return self diff --git a/extensions/gubbins/gubbins-client/pyproject.toml b/extensions/gubbins/gubbins-client/pyproject.toml index 117186363..2ab6c4de5 100644 --- a/extensions/gubbins/gubbins-client/pyproject.toml +++ b/extensions/gubbins/gubbins-client/pyproject.toml @@ -28,8 +28,8 @@ build-backend = "setuptools.build_meta" [project.entry-points."diracx"] -client_class = "gubbins.client.patches:GubbinsClient" -aio_client_class = "gubbins.client.patches.aio:GubbinsClient" +client_class = "gubbins.client._client:Dirac" +aio_client_class = "gubbins.client.aio._client:Dirac" [tool.setuptools_scm] diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_patch.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_patch.py index abf561200..8ee1e4c66 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_patch.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_patch.py @@ -8,9 +8,9 @@ """ from typing import List -__all__: List[str] = ( - [] -) # Add all objects you want publicly available to users at this package level +__all__: List[str] = [ + "GubbinsClient" +] # Add all objects you want publicly available to users at this package level def patch_sdk(): @@ -20,3 +20,6 @@ def patch_sdk(): you can't accomplish using the techniques described in https://aka.ms/azsdk/python/dpcodegen/python/customize """ + + +from .patches import GubbinsClient diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/aio/_patch.py b/extensions/gubbins/gubbins-client/src/gubbins/client/aio/_patch.py index abf561200..5956b0c68 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/aio/_patch.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/aio/_patch.py @@ -8,9 +8,9 @@ """ from typing import List -__all__: List[str] = ( - [] -) # Add all objects you want publicly available to users at this package level +__all__: List[str] = [ + "GubbinsClient" +] # Add all objects you want publicly available to users at this package level def patch_sdk(): @@ -20,3 +20,6 @@ def patch_sdk(): you can't accomplish using the techniques described in https://aka.ms/azsdk/python/dpcodegen/python/customize """ + + +from ..patches.aio import GubbinsClient diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/patches/__init__.py b/extensions/gubbins/gubbins-client/src/gubbins/client/patches/__init__.py index 37f0470f1..a2a148a2b 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/patches/__init__.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/patches/__init__.py @@ -1,5 +1,5 @@ -from diracx.client.patches import DiracClient +from diracx.client.patches.utils import DiracClientMixin +from gubbins.client._client import Dirac as GubbinsGenerated -class GubbinsClient(DiracClient): - pass +class GubbinsClient(DiracClientMixin, GubbinsGenerated): ... diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/patches/aio/__init__.py b/extensions/gubbins/gubbins-client/src/gubbins/client/patches/aio/__init__.py index c0f194384..96a5669cc 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/patches/aio/__init__.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/patches/aio/__init__.py @@ -1,5 +1,11 @@ -from diracx.client.patches.aio import DiracClient +# from diracx.client.patches.aio import DiracClient -class GubbinsClient(DiracClient): - pass +# class GubbinsClient(DiracClient): +# pass + +from diracx.client.patches.aio.utils import DiracClientMixin +from gubbins.client.aio._client import Dirac as GubbinsGenerated + + +class GubbinsClient(DiracClientMixin, GubbinsGenerated): ... diff --git a/extensions/gubbins/gubbins-client/tests/test_regenerate.py b/extensions/gubbins/gubbins-client/tests/test_regenerate.py index f0522190f..894d9346e 100644 --- a/extensions/gubbins/gubbins-client/tests/test_regenerate.py +++ b/extensions/gubbins/gubbins-client/tests/test_regenerate.py @@ -36,7 +36,6 @@ def test_regenerate_client(test_client, tmp_path): openapi_spec = tmp_path / "openapi.json" openapi_spec.write_text(r.text) - breakpoint() output_folder = Path(gubbins.client.__file__).parent.parent assert (output_folder / "client").is_dir() repo_root = output_folder.parent.parent.parent.parent.parent