Skip to content

Commit

Permalink
Merge pull request #1020 from jmmshn/jmmshn/cli
Browse files Browse the repository at this point in the history
[Feature] Allow Different Azure Authentication Methods
  • Loading branch information
rkingsbury authored Dec 28, 2024
2 parents 9589881 + 5f8f988 commit 4ed05b8
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 4 deletions.
34 changes: 30 additions & 4 deletions src/maggma/stores/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Advanced Stores for connecting to Microsoft Azure data.
"""

import importlib
import os
import threading
import warnings
Expand All @@ -11,7 +12,7 @@
from concurrent.futures.thread import ThreadPoolExecutor
from hashlib import sha1
from json import dumps
from typing import Optional, Union
from typing import Literal, Optional, Union

import msgpack # type: ignore
from monty.msgpack import default as monty_default
Expand All @@ -23,15 +24,34 @@
import azure
import azure.storage.blob as azure_blob
from azure.core.exceptions import ResourceExistsError
from azure.identity import DefaultAzureCredential
from azure.storage.blob import BlobServiceClient, ContainerClient


except (ImportError, ModuleNotFoundError):
azure_blob = None # type: ignore
ContainerClient = None


AZURE_KEY_SANITIZE = {"-": "_", ".": "_"}

CredentialType = Literal[
"DefaultAzureCredential",
"AzureCliCredential",
"ManagedIdentityCredential",
]


def _get_azure_credential(credential_class):
"""Import the azure.identity module and return the credential class.
If the credential_class is a class, return an instance of it.
If the credential_class is a string, import the module first
"""
if isinstance(credential_class, str):
module_name = "azure.identity"
credential_class = getattr(importlib.import_module(module_name), credential_class)
return credential_class()


class AzureBlobStore(Store):
"""
Expand All @@ -45,6 +65,7 @@ def __init__(
index: Store,
container_name: str,
azure_client_info: Optional[Union[str, dict]] = None,
credential_type: CredentialType = "DefaultAzureCredential",
compress: bool = False,
sub_dir: Optional[str] = None,
workers: int = 1,
Expand All @@ -69,6 +90,10 @@ def __init__(
BlobServiceClient.
Currently supported keywords:
- connection_string: a connection string for the Azure blob
credential_type: the type of credential to use to authenticate with Azure.
Default is "DefaultAzureCredential". For serializable stores, provide
a string representation of the credential class. Otherwises, you may
provide the class itself.
compress: compress files inserted into the store
sub_dir: (optional) subdirectory of the container to store the data.
When defined, a final "/" will be added if not already present.
Expand Down Expand Up @@ -104,6 +129,7 @@ def __init__(
key_sanitize_dict = AZURE_KEY_SANITIZE
self.key_sanitize_dict = key_sanitize_dict
self.create_container = create_container
self.credential_type = credential_type

# Force the key to be the same as the index
assert isinstance(
Expand Down Expand Up @@ -351,8 +377,8 @@ def _get_service_client(self):
if not hasattr(self._thread_local, "container"):
if isinstance(self.azure_client_info, str):
# assume it is the account_url and that the connection is passwordless
default_credential = DefaultAzureCredential()
return BlobServiceClient(self.azure_client_info, credential=default_credential)
credentials_ = _get_azure_credential(self.credential_type)
return BlobServiceClient(self.azure_client_info, credential=credentials_)

if isinstance(self.azure_client_info, dict):
connection_string = self.azure_client_info.get("connection_string")
Expand Down
31 changes: 31 additions & 0 deletions tests/stores/test_azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,3 +420,34 @@ def test_no_login():

with pytest.raises(RuntimeError, match=r".*Could not instantiate BlobServiceClient.*"):
store.connect()


def test_credential_type_valid():
credential_type = "DefaultAzureCredential"
index = MemoryStore("index")
store = AzureBlobStore(
index,
AZURITE_CONTAINER_NAME,
azure_client_info="client_url",
credential_type=credential_type,
)
assert store.credential_type == credential_type

# tricks the store into thinking you already
# provided the blob service client so it skips
# the connection checks. We are only testing that
# the credential import works properly
store.service = True
store.connect()

from azure.identity import DefaultAzureCredential

credential_type = DefaultAzureCredential
index = MemoryStore("index")
store = AzureBlobStore(
index,
AZURITE_CONTAINER_NAME,
azure_client_info="client_url",
credential_type=credential_type,
)
assert not isinstance(store.credential_type, str)

0 comments on commit 4ed05b8

Please sign in to comment.