Skip to content

Commit

Permalink
Refactors credentials management methods to only include db connectio…
Browse files Browse the repository at this point in the history
…n part

Simplifies and modularizes methods for adding user credentials, groups, variables, and secrets
Improves code clarity and maintainability by breaking down complex methods into smaller, focused functions
  • Loading branch information
arash77 committed Jan 16, 2025
1 parent f483a43 commit 7f6bd6f
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 84 deletions.
165 changes: 85 additions & 80 deletions lib/galaxy/managers/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,14 @@
)
from galaxy.model.base import transaction
from galaxy.model.scoped_session import galaxy_scoped_session
from galaxy.schema.credentials import (
CreateSourceCredentialsPayload,
SOURCE_TYPE,
)
from galaxy.schema.credentials import SOURCE_TYPE
from galaxy.schema.fields import DecodedDatabaseIdField
from galaxy.schema.schema import FlexibleUserIdType
from galaxy.security.vault import UserVaultWrapper
from galaxy.structured_app import StructuredApp


class CredentialsManager:
"""Manager object shared by controllers for interacting with credentials."""

def __init__(self, app: StructuredApp) -> None:
self._app = app

def get_user_credentials(
self,
trans: ProvidesUserContext,
Expand Down Expand Up @@ -87,80 +79,93 @@ def fetch_credentials(

return variables, secrets

def create_or_update_credentials(
def add_user_credentials(
self,
trans: ProvidesUserContext,
payload: CreateSourceCredentialsPayload,
session: galaxy_scoped_session,
db_user_credentials: List[Tuple[UserCredentials, CredentialsGroup]],
user_id: DecodedDatabaseIdField,
reference: str,
source_type: SOURCE_TYPE,
source_id: str,
) -> DecodedDatabaseIdField:
user_credentials = next((uc[0] for uc in db_user_credentials if uc[0].reference == reference), None)
if not user_credentials:
user_credentials = UserCredentials(
user_id=user_id,
reference=reference,
source_type=source_type,
source_id=source_id,
)
session.add(user_credentials)
session.flush()
return user_credentials.id

def add_group(
self,
session: galaxy_scoped_session,
db_user_credentials: List[Tuple[UserCredentials, CredentialsGroup]],
user_credentials_id: DecodedDatabaseIdField,
group_name: str,
reference: str,
) -> DecodedDatabaseIdField:
credentials_group = next(
(uc[1] for uc in db_user_credentials if uc[1].name == group_name and uc[0].reference == reference),
None,
)
if not credentials_group:
credentials_group = CredentialsGroup(name=group_name, user_credentials_id=user_credentials_id)
session.add(credentials_group)
session.flush()
return credentials_group.id

def add_variable(
self,
session: galaxy_scoped_session,
variables: List[Variable],
user_credential_group_id: DecodedDatabaseIdField,
variable_name: str,
variable_value: str,
) -> None:
variable = next(
(var for var in variables if var.name == variable_name),
None,
)
if variable:
variable.value = variable_value
else:
variable = Variable(
user_credential_group_id=user_credential_group_id,
name=variable_name,
value=variable_value,
)
session.add(variable)

def add_secret(
self,
session: galaxy_scoped_session,
secrets: List[Secret],
user_credential_group_id: DecodedDatabaseIdField,
secret_name: str,
secret_value: str,
) -> None:
secret = next(
(sec for sec in secrets if sec.name == secret_name),
None,
)
if secret:
secret.already_set = True if secret_value else False
else:
secret = Secret(
user_credential_group_id=user_credential_group_id,
name=secret_name,
already_set=True if secret_value else False,
)
session.add(secret)

def commit_session(
self,
session: galaxy_scoped_session,
) -> None:
session = trans.sa_session
for service_payload in payload.credentials:
reference = service_payload.reference
current_group_name = service_payload.current_group
if not current_group_name:
current_group_name = "default"
user_credentials = next((uc[0] for uc in db_user_credentials if uc[0].reference == reference), None)
if not user_credentials:
user_credentials = UserCredentials(
user_id=trans.user.id,
reference=reference,
source_type=payload.source_type,
source_id=payload.source_id,
)
session.add(user_credentials)
session.flush()
user_credentials_id = user_credentials.id

for group in service_payload.groups:
group_name = group.name
credentials_group = next(
(uc[1] for uc in db_user_credentials if uc[1].name == group_name and uc[0].reference == reference),
None,
)
if not credentials_group:
credentials_group = CredentialsGroup(name=group_name, user_credentials_id=user_credentials_id)
session.add(credentials_group)
session.flush()
user_credential_group_id = credentials_group.id
variables, secrets = self.fetch_credentials(trans.sa_session, user_credential_group_id)
user_vault = UserVaultWrapper(self._app.vault, trans.user)
for variable_payload in group.variables:
variable_name, variable_value = variable_payload.name, variable_payload.value
if variable_value is None:
continue
variable = next(
(var for var in variables if var.name == variable_name),
None,
)
if variable:
variable.value = variable_value
else:
variable = Variable(
user_credential_group_id=user_credential_group_id,
name=variable_name,
value=variable_value,
)
session.add(variable)
for secret_payload in group.secrets:
secret_name, secret_value = secret_payload.name, secret_payload.value
if secret_value is None:
continue
secret = next(
(sec for sec in secrets if sec.name == secret_name),
None,
)
if secret:
secret.already_set = True if secret_value else False
else:
secret = Secret(
user_credential_group_id=user_credential_group_id,
name=secret_name,
already_set=True if secret_value else False,
)
session.add(secret)
vault_ref = f"{payload.source_type}|{payload.source_id}|{reference}|{group_name}|{secret_name}"
user_vault.write_secret(vault_ref, secret_value)
self.update_current_group(trans, user_credentials_id, current_group_name)
with transaction(session):
session.commit()

Expand Down
54 changes: 50 additions & 4 deletions lib/galaxy/webapps/galaxy/services/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,19 @@
)
from galaxy.schema.fields import DecodedDatabaseIdField
from galaxy.schema.schema import FlexibleUserIdType
from galaxy.security.vault import UserVaultWrapper
from galaxy.structured_app import StructuredApp


class CredentialsService:
"""Service object shared by controllers for interacting with credentials."""

def __init__(
self,
app: StructuredApp,
credentials_manager: CredentialsManager,
) -> None:
self._app = app
self._credentials_manager = credentials_manager

def list_user_credentials(
Expand All @@ -55,10 +59,8 @@ def provide_credential(
payload: CreateSourceCredentialsPayload,
) -> UserCredentialsListResponse:
"""Allows users to provide credentials for a group of secrets and variables."""
source_type, source_id = payload.source_type, payload.source_id
db_user_credentials = self._credentials_manager.get_user_credentials(trans, user_id, source_type, source_id)
self._credentials_manager.create_or_update_credentials(trans, payload, db_user_credentials)
return self._list_user_credentials(trans, user_id, source_type, source_id)
self._create_or_update_credentials(trans, user_id, payload)
return self._list_user_credentials(trans, user_id, payload.source_type, payload.source_id)

def delete_credentials(
self,
Expand Down Expand Up @@ -140,3 +142,47 @@ def _map_user_credentials(
)

return user_credentials_dict

def _create_or_update_credentials(
self,
trans: ProvidesUserContext,
user_id: FlexibleUserIdType,
payload: CreateSourceCredentialsPayload,
) -> None:
session = trans.sa_session
source_type, source_id = payload.source_type, payload.source_id
db_user_credentials = self._credentials_manager.get_user_credentials(trans, user_id, source_type, source_id)
for service_payload in payload.credentials:
reference = service_payload.reference
current_group_name = service_payload.current_group
if not current_group_name:
current_group_name = "default"
user_credentials_id = self._credentials_manager.add_user_credentials(
session, db_user_credentials, trans.user.id, reference, source_type, source_id
)
for group in service_payload.groups:
group_name = group.name
user_credential_group_id = self._credentials_manager.add_group(
session, db_user_credentials, user_credentials_id, group_name, reference
)
variables, secrets = self._credentials_manager.fetch_credentials(session, user_credential_group_id)
user_vault = UserVaultWrapper(self._app.vault, trans.user)
for variable_payload in group.variables:
variable_name, variable_value = variable_payload.name, variable_payload.value
if variable_value is None:
continue
self._credentials_manager.add_variable(
session, variables, user_credential_group_id, variable_name, variable_value
)
for secret_payload in group.secrets:
secret_name, secret_value = secret_payload.name, secret_payload.value
if secret_value is None:
continue
vault_ref = f"{source_type}|{source_id}|{reference}|{group_name}|{secret_name}"
user_vault.write_secret(vault_ref, secret_value)
self._credentials_manager.add_secret(
session, secrets, user_credential_group_id, secret_name, secret_value
)

self._credentials_manager.update_current_group(trans, user_credentials_id, current_group_name)
self._credentials_manager.commit_session(session)

0 comments on commit 7f6bd6f

Please sign in to comment.