Skip to content

Commit

Permalink
seems good now
Browse files Browse the repository at this point in the history
  • Loading branch information
chaen committed Sep 17, 2024
1 parent 366e8ed commit 171d908
Show file tree
Hide file tree
Showing 16 changed files with 453 additions and 341 deletions.
1 change: 0 additions & 1 deletion diracx-cli/src/diracx/cli/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

@app.async_command()
async def dump():
breakpoint()
async with DiracClient() as api:
config = await api.config.serve_config()
display(config)
Expand Down
4 changes: 2 additions & 2 deletions diracx-client/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@ requires = ["setuptools>=61", "wheel", "setuptools_scm>=8"]
build-backend = "setuptools.build_meta"

[project.entry-points."diracx"]
client_class = "diracx.client.patches:DiracClient"
aio_client_class = "diracx.client.patches.aio:DiracClient"
client_class = "diracx.client._client:Dirac"
aio_client_class = "diracx.client.aio._client:Dirac"


[tool.setuptools_scm]
Expand Down
20 changes: 1 addition & 19 deletions diracx-client/src/diracx/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,12 @@
# Changes may cause incorrect behavior and will be lost if the code is regenerated.
# --------------------------------------------------------------------------

import sys
from importlib.abc import MetaPathFinder


class MyMetaPathFinder(MetaPathFinder):
def find_spec(self, fullname, path, target=None):
# print parameters for more information
print(
"find_spec(fullname={}, path={}, target={})".format(fullname, path, target)
)
# return None to tell the python this finder can't find the module
return None


# insert a MyMetaPathFinder instance at the start of the meta_path list
# sys.meta_path.insert(0, MyMetaPathFinder())


from ._client import Dirac

try:
from ._patch import __all__ as _patch_all
from ._patch import * # pylint: disable=unused-wildcard-import
except ImportError:
except ValueError:
_patch_all = []
from ._patch import patch_sdk as _patch_sdk

Expand Down
89 changes: 85 additions & 4 deletions diracx-client/src/diracx/client/_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from __future__ import annotations

from datetime import datetime, timezone
import importlib.util
import json
import jwt
import requests
Expand All @@ -25,7 +26,90 @@
from diracx.core.models import TokenResponse as CoreTokenResponse
from diracx.core.preferences import DiracxPreferences, get_diracx_preferences


import sys
import importlib
from importlib.abc import MetaPathFinder, Loader


class MyLoader(Loader):
# MyLoader contructor to store txt filepath
def __init__(self, filepath):
self.filepath = filepath

def create_module(self, spec):
# try to read txt file content
try:
with open(self.filepath) as inp_file:
self.data = inp_file.read()
except:
# raise ImportError if there was an error loading the file
raise ImportError
# when returning None, default module creation will be called

def exec_module(self, module):
# here we update module and add our custom members
module.__dict__.update({"data": self.data})


# print("CHRIS I DEFINE IT HERE")
class MyMetaPathFinder(MetaPathFinder):
def find_spec(self, fullname, path, target=None):
# print parameters for more information
# print(
# "find_spec(fullname={}, path={}, target={})".format(fullname, path, target)
# )

# if (
# "gubbins.client." in fullname
# and "_patch" in fullname
# and fullname != "gubbins.client._patch"
# ):
if any(
[
fullname.startswith(prefix)
for prefix in [
"gubbins.client.operations._patch",
"gubbins.client.models._patch",
"gubbins.client.aio.operations._patch",
]
]
):
# print(f"CHRIS COULD OVERWRITE {fullname=} {path=}")

try:
overwritten = importlib.util.find_spec(
fullname.replace("gubbins", "diracx")
)
# print(f"CHRIS MANAGED ! {fullname=} {overwritten=}")
except Exception as e:
# print(f"CHRIS EXCEPT: {e!r}")
overwritten = None
return overwritten

# return None to tell the python this finder can't find the module
# importlib.util.spec_from_file_location
return None


# insert a MyMetaPathFinder instance at the start of the meta_path list
if not isinstance(sys.meta_path[0], MyMetaPathFinder):
sys.meta_path.insert(0, MyMetaPathFinder())
print("CHRIS STARTING RELOAD")
for module_name, module in sys.modules.copy().items():
if (
"diracx.client" in module_name
and module_name
not in (
"diracx.client",
"diracx.client._patch",
)
and "_patch" in module_name
):
print(f"CHRIS Reloading {module_name=}")
importlib.reload(module)
print("CHRIS FINISHED RELOADING")


__all__: List[str] = [
"DiracClient",
Expand All @@ -41,7 +125,4 @@ def patch_sdk():
"""


from diracx.core.extensions import select_from_extension

real_client = select_from_extension(group="diracx", name="client_class")[0]
DiracClient = real_client.load()
from .patches import DiracClient
6 changes: 5 additions & 1 deletion diracx-client/src/diracx/client/aio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,14 @@

from ._client import Dirac

print(f"CHRIS HERE")

try:
from ._patch import __all__ as _patch_all

print(f"CHRIS HERE {_patch_all=}")
from ._patch import * # pylint: disable=unused-wildcard-import
except ImportError:
except ValueError as e:
_patch_all = []
from ._patch import patch_sdk as _patch_sdk

Expand Down
5 changes: 1 addition & 4 deletions diracx-client/src/diracx/client/aio/_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,4 @@ def patch_sdk():
"""


from diracx.core.extensions import select_from_extension

real_client = select_from_extension(group="diracx", name="aio_client_class")[0]
DiracClient = real_client.load()
from ..patches.aio import DiracClient
145 changes: 9 additions & 136 deletions diracx-client/src/diracx/client/patches/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import requests

from pathlib import Path
from typing import Any, Dict, List, Optional, cast
from typing import Any, Dict, List, Optional, cast, TypeAlias
from urllib import parse
from azure.core.credentials import AccessToken
from azure.core.credentials import TokenCredential
Expand All @@ -17,145 +17,18 @@
from diracx.core.models import TokenResponse as CoreTokenResponse
from diracx.core.preferences import DiracxPreferences, get_diracx_preferences

from .._client import Dirac as DiracGenerated
from .utils import (
get_openid_configuration,
get_token,
refresh_token,
is_refresh_token_valid,
)
from .utils import DiracClientMixin


__all__: List[str] = [
"DiracClient",
] # Add all objects you want publicly available to users at this package level


class DiracTokenCredential(TokenCredential):
"""Tailor get_token() for our context"""

def __init__(
self,
location: Path,
token_endpoint: str,
client_id: str,
*,
verify: bool | str = True,
) -> None:
self.location = location
self.verify = verify
self.token_endpoint = token_endpoint
self.client_id = client_id

def get_token(
self,
*scopes: str,
claims: Optional[str] = None,
tenant_id: Optional[str] = None,
**kwargs: Any,
) -> AccessToken:
"""Refresh the access token using the refresh_token flow.
:param str scopes: The type of access needed.
:keyword str claims: Additional claims required in the token, such as those returned in a resource
provider's claims challenge following an authorization failure.
:keyword str tenant_id: Optional tenant to include in the token request.
:rtype: AccessToken
:return: An AccessToken instance containing the token string and its expiration time in Unix time.
"""
return refresh_token(
self.location,
self.token_endpoint,
self.client_id,
kwargs["refresh_token"],
verify=self.verify,
)


class DiracBearerTokenCredentialPolicy(BearerTokenCredentialPolicy):
"""Custom BearerTokenCredentialPolicy tailored for our use case.
* It does not ensure the connection is done through https.
* It does not ensure that an access token is available.
"""
from diracx.core.extensions import select_from_extension

def __init__(
self, credential: DiracTokenCredential, *scopes: str, **kwargs: Any
) -> None:
super().__init__(credential, *scopes, **kwargs)
real_client = select_from_extension(group="diracx", name="client_class")[0].load()
DiracGenerated = real_client

def on_request(
self, request: PipelineRequest
) -> None: # pylint:disable=invalid-overridden-method
"""Authorization Bearer is optional here.
:param request: The pipeline request object to be modified.
:type request: ~azure.core.pipeline.PipelineRequest
:raises: :class:`~azure.core.exceptions.ServiceRequestError`
"""
self._token: AccessToken | None
self._credential: DiracTokenCredential
credentials: dict[str, Any]

try:
self._token = get_token(self._credential.location, self._token)
except RuntimeError:
# If we are here, it means the credentials path does not exist
# we suppose it is not needed to perform the request
return

if not self._token:
credentials = json.loads(self._credential.location.read_text())
refresh_token = credentials["refresh_token"]
if not is_refresh_token_valid(refresh_token):
# If we are here, it means the refresh token is not valid anymore
# we suppose it is not needed to perform the request
return
self._token = self._credential.get_token("", refresh_token=refresh_token)

request.http_request.headers["Authorization"] = f"Bearer {self._token.token}"


class DiracClient(DiracGenerated):
"""This class inherits from the generated Dirac client and adds support for tokens,
so that the caller does not need to configure it by itself.
"""

def __init__(
self,
endpoint: str | None = None,
client_id: str | None = None,
diracx_preferences: DiracxPreferences | None = None,
verify: bool | str = True,
**kwargs: Any,
) -> None:
diracx_preferences = diracx_preferences or get_diracx_preferences()
self._endpoint = str(endpoint or diracx_preferences.url)
if verify is True and diracx_preferences.ca_path:
verify = str(diracx_preferences.ca_path)
kwargs["connection_verify"] = verify
self._client_id = client_id or "myDIRACClientID"

# Get .well-known configuration
openid_configuration = get_openid_configuration(self._endpoint, verify=verify)

# Initialize Dirac with a Dirac-specific token credential policy
super().__init__(
endpoint=self._endpoint,
authentication_policy=DiracBearerTokenCredentialPolicy(
DiracTokenCredential(
location=diracx_preferences.credentials_path,
token_endpoint=openid_configuration["token_endpoint"],
client_id=self._client_id,
verify=verify,
),
),
**kwargs,
)
__all__: List[str] = [
"DiracClient",
] # Add all objects you want publicly available to users at this package level

@property
def client_id(self):
return self._client_id

def __aenter__(self) -> "DiracClient":
"""Redefined to provide the patched Dirac client in the managed context"""
self._client.__enter__()
return self
class DiracClient(DiracClientMixin, DiracGenerated): ... # type: ignore
Loading

0 comments on commit 171d908

Please sign in to comment.