Skip to content

Commit

Permalink
Refactoring CredentialsService and Add CredentialsManager for seperat…
Browse files Browse the repository at this point in the history
…ing user credential management
  • Loading branch information
arash77 committed Jan 15, 2025
1 parent 8fbd065 commit 68fbd59
Show file tree
Hide file tree
Showing 3 changed files with 260 additions and 261 deletions.
191 changes: 191 additions & 0 deletions lib/galaxy/managers/credentials.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
from typing import (
Any,
Dict,
List,
Optional,
Tuple,
)

from sqlalchemy import select
from sqlalchemy.orm import aliased

from galaxy.exceptions import (
AuthenticationRequired,
ItemOwnershipException,
RequestParameterInvalidException,
)
from galaxy.managers.context import ProvidesUserContext
from galaxy.model import (
CredentialsGroup,
Secret,
UserCredentials,
Variable,
)
from galaxy.model.base import transaction
from galaxy.model.scoped_session import galaxy_scoped_session
from galaxy.schema.credentials import (
CreateSourceCredentialsPayload,
SOURCE_TYPE,
)
from galaxy.schema.fields import DecodedDatabaseIdField
from galaxy.schema.schema import FlexibleUserIdType
from galaxy.security.vault import UserVaultWrapper
from galaxy.structured_app import StructuredApp


class CredentialsManager:
"""Manager object shared by controllers for interacting with credentials."""

def __init__(self, app: StructuredApp) -> None:
self._app = app

def get_user_credentials(
self,
trans: ProvidesUserContext,
user_id: FlexibleUserIdType,
source_type: Optional[SOURCE_TYPE] = None,
source_id: Optional[str] = None,
group_name: Optional[str] = None,
user_credentials_id: Optional[DecodedDatabaseIdField] = None,
group_id: Optional[DecodedDatabaseIdField] = None,
) -> List[Tuple[UserCredentials, CredentialsGroup]]:
if trans.anonymous:
raise AuthenticationRequired("You need to be logged in to access your credentials.")
if user_id == "current":
user_id = trans.user.id
elif trans.user.id != user_id:
raise ItemOwnershipException("You can only access your own credentials.")
user_cred_alias, group_alias = aliased(UserCredentials), aliased(CredentialsGroup)
stmt = (
select(user_cred_alias, group_alias)
.join(group_alias, group_alias.user_credentials_id == user_cred_alias.id)
.where(user_cred_alias.user_id == user_id)
)
if source_type:
stmt = stmt.where(user_cred_alias.source_type == source_type)
if source_id:
if not source_type:
raise RequestParameterInvalidException("Source type is required when source ID is provided.")
stmt = stmt.where(user_cred_alias.source_id == source_id)
if group_name:
stmt = stmt.where(group_alias.name == group_name)
if user_credentials_id:
stmt = stmt.where(user_cred_alias.id == user_credentials_id)
if group_id:
stmt = stmt.where(group_alias.id == group_id)

result = trans.sa_session.execute(stmt).all()
return [(row[0], row[1]) for row in result]

def fetch_credentials(
self,
session: galaxy_scoped_session,
group_id: DecodedDatabaseIdField,
) -> Tuple[List[Variable], List[Secret]]:
variables = list(
session.execute(select(Variable).where(Variable.user_credential_group_id == group_id)).scalars().all()
)

secrets = list(
session.execute(select(Secret).where(Secret.user_credential_group_id == group_id)).scalars().all()
)

return variables, secrets

def create_or_update_credentials(
self,
trans: ProvidesUserContext,
payload: CreateSourceCredentialsPayload,
db_user_credentials: List[Tuple[UserCredentials, CredentialsGroup]],
credentials_dict: Dict[int, Dict[str, Any]],
) -> None:
session = trans.sa_session
existing_groups = {
cred["reference"]: {group["name"]: group["id"] for group in cred["groups"].values()}
for cred in credentials_dict.values()
}
for service_payload in payload.credentials:
reference = service_payload.reference
current_group_name = service_payload.current_group
current_group_id = existing_groups.get(reference, {}).get(current_group_name)

user_credentials = next((uc[0] for uc in db_user_credentials if uc[0].reference == reference), None)
if not user_credentials:
user_credentials = UserCredentials(
user_id=trans.user.id,
reference=reference,
source_type=payload.source_type,
source_id=payload.source_id,
)
session.add(user_credentials)
session.flush()
user_credentials_id = user_credentials.id

for group in service_payload.groups:
group_name = group.name
credentials_group = next(
(uc[1] for uc in db_user_credentials if uc[1].name == group_name and uc[0].reference == reference),
None,
)
if not credentials_group:
credentials_group = CredentialsGroup(name=group_name, user_credentials_id=user_credentials_id)
session.add(credentials_group)
session.flush()
user_credential_group_id = credentials_group.id
if current_group_name == group_name:
current_group_id = user_credential_group_id
variables, secrets = self.fetch_credentials(trans.sa_session, user_credential_group_id)
user_vault = UserVaultWrapper(self._app.vault, trans.user)
for variable_payload in group.variables:
variable_name, variable_value = variable_payload.name, variable_payload.value
if variable_value is None:
continue
variable = next(
(var for var in variables if var.name == variable_name),
None,
)
if variable:
variable.value = variable_value
else:
variable = Variable(
user_credential_group_id=user_credential_group_id,
name=variable_name,
value=variable_value,
)
session.add(variable)
for secret_payload in group.secrets:
secret_name, secret_value = secret_payload.name, secret_payload.value
if secret_value is None:
continue
secret = next(
(sec for sec in secrets if sec.name == secret_name),
None,
)
if secret:
secret.already_set = True if secret_value else False
else:
secret = Secret(
user_credential_group_id=user_credential_group_id,
name=secret_name,
already_set=True if secret_value else False,
)
session.add(secret)
vault_ref = f"{payload.source_type}|{payload.source_id}|{reference}|{group_name}|{secret_name}"
user_vault.write_secret(vault_ref, secret_value)
if not current_group_id:
raise RequestParameterInvalidException("No current group selected.")
user_credentials.current_group_id = current_group_id
session.add(user_credentials)

with transaction(session):
session.commit()

def delete_rows(
self,
session: galaxy_scoped_session,
rows_to_delete: List,
) -> None:
for row in rows_to_delete:
session.delete(row)
with transaction(session):
session.commit()
2 changes: 1 addition & 1 deletion lib/galaxy/webapps/galaxy/api/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def delete_service_credentials(
user_credentials_id: DecodedDatabaseIdField,
trans: ProvidesUserContext = DependsOnTrans,
):
self.service.delete_service_credentials(trans, user_id, user_credentials_id)
self.service.delete_credentials(trans, user_id, user_credentials_id)
return Response(status_code=status.HTTP_204_NO_CONTENT)

@router.delete(
Expand Down
Loading

0 comments on commit 68fbd59

Please sign in to comment.