Skip to content

Commit

Permalink
Fix SA2.0 usage in tool_shed.util.repository_util
Browse files Browse the repository at this point in the history
  • Loading branch information
jdavcs committed Oct 12, 2023
1 parent 43500cf commit bf919c0
Showing 1 changed file with 88 additions and 55 deletions.
143 changes: 88 additions & 55 deletions lib/tool_shed/util/repository_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@
)

from markupsafe import escape
from sqlalchemy import false
from sqlalchemy import (
delete,
false,
select,
)
from sqlalchemy.orm import joinedload
from sqlalchemy.sql import select

import tool_shed.dependencies.repository
from galaxy import (
Expand Down Expand Up @@ -232,7 +235,7 @@ def create_repository(
if category_ids:
# Create category associations
for category_id in category_ids:
category = sa_session.query(app.model.Category).get(app.security.decode_id(category_id))
category = sa_session.get(app.model.Category, app.security.decode_id(category_id))
rca = app.model.RepositoryCategoryAssociation(repository, category)
sa_session.add(rca)
flush_needed = True
Expand Down Expand Up @@ -335,39 +338,20 @@ def get_repo_info_dict(trans: "ProvidesRepositoriesContext", repository_id, chan
def get_repositories_by_category(
app: "ToolShedApp", category_id, installable=False, sort_order="asc", sort_key="name", page=None, per_page=25
):
sa_session = app.model.session
query = (
sa_session.query(app.model.Repository)
.join(
app.model.RepositoryCategoryAssociation,
app.model.Repository.id == app.model.RepositoryCategoryAssociation.repository_id,
)
.join(app.model.User, app.model.User.id == app.model.Repository.user_id)
.filter(app.model.RepositoryCategoryAssociation.category_id == category_id)
)
if installable:
subquery = select(app.model.RepositoryMetadata.table.c.repository_id)
query = query.filter(app.model.Repository.id.in_(subquery))
if sort_key == "owner":
query = (
query.order_by(app.model.User.username)
if sort_order == "asc"
else query.order_by(app.model.User.username.desc())
)
else:
query = (
query.order_by(app.model.Repository.name)
if sort_order == "asc"
else query.order_by(app.model.Repository.name.desc())
)
if page is not None:
page = int(page)
query = query.limit(per_page)
if page > 1:
query = query.offset((page - 1) * per_page)
resultset = query.all()
repositories = []
for repository in resultset:
for repository in get_repositories(
app.model.session,
app.model.Repository,
app.model.RepositoryCategoryAssociation,
app.model.User,
app.model.RepositoryMetadata,
category_id,
installable,
sort_order,
sort_key,
page,
per_page,
):
default_value_mapper = {
"id": app.security.encode_id,
"user_id": app.security.encode_id,
Expand Down Expand Up @@ -396,7 +380,7 @@ def handle_role_associations(app: "ToolShedApp", role, repository, **kwd):
repository_owner = repository.user
if kwd.get("manage_role_associations_button", False):
in_users_list = util.listify(kwd.get("in_users", []))
in_users = [sa_session.query(app.model.User).get(x) for x in in_users_list]
in_users = [sa_session.get(app.model.User, x) for x in in_users_list]
# Make sure the repository owner is always associated with the repostory's admin role.
owner_associated = False
for user in in_users:
Expand All @@ -408,7 +392,7 @@ def handle_role_associations(app: "ToolShedApp", role, repository, **kwd):
message += "The repository owner must always be associated with the repository's administrator role. "
status = "error"
in_groups_list = util.listify(kwd.get("in_groups", []))
in_groups = [sa_session.query(app.model.Group).get(x) for x in in_groups_list]
in_groups = [sa_session.get(app.model.Group, x) for x in in_groups_list]
in_repositories = [repository]
app.security_agent.set_entity_role_associations(
roles=[role], users=in_users, groups=in_groups, repositories=in_repositories
Expand All @@ -424,20 +408,12 @@ def handle_role_associations(app: "ToolShedApp", role, repository, **kwd):
out_users = []
in_groups = []
out_groups = []
for user in (
sa_session.query(app.model.User)
.filter(app.model.User.table.c.deleted == false())
.order_by(app.model.User.table.c.email)
):
for user in get_current_users(sa_session, app.model.User):
if user in [x.user for x in role.users]:
in_users.append((user.id, user.email))
else:
out_users.append((user.id, user.email))
for group in (
sa_session.query(app.model.Group)
.filter(app.model.Group.table.c.deleted == false())
.order_by(app.model.Group.table.c.name)
):
for group in get_current_groups(sa_session, app.model.Group):
if group in [x.group for x in role.groups]:
in_groups.append((group.id, group.name))
else:
Expand Down Expand Up @@ -467,7 +443,7 @@ def update_repository(trans: "ProvidesUserContext", id: str, **kwds) -> Tuple[Op
message = None
flush_needed = False
sa_session = app.model.session
repository = sa_session.query(app.model.Repository).get(app.security.decode_id(id))
repository = sa_session.get(app.model.Repository, app.security.decode_id(id))
if repository is None:
return None, "Unknown repository ID"

Expand All @@ -483,17 +459,14 @@ def update_repository(trans: "ProvidesUserContext", id: str, **kwds) -> Tuple[Op
flush_needed = True

if "category_ids" in kwds and isinstance(kwds["category_ids"], list):
# Get existing category associations
category_associations = sa_session.query(app.model.RepositoryCategoryAssociation).filter(
app.model.RepositoryCategoryAssociation.table.c.repository_id == app.security.decode_id(id)
# Remove existing category associations
delete_repository_category_associations(
sa_session, app.model.RepositoryCategoryAssociation, app.security.decode_id(id)
)
# Remove all of them
for rca in category_associations:
sa_session.delete(rca)

# Then (re)create category associations
for category_id in kwds["category_ids"]:
category = sa_session.query(app.model.Category).get(app.security.decode_id(category_id))
category = sa_session.get(app.model.Category, app.security.decode_id(category_id))
if category:
rca = app.model.RepositoryCategoryAssociation(repository, category)
sa_session.add(rca)
Expand Down Expand Up @@ -562,6 +535,66 @@ def validate_repository_name(app: "ToolShedApp", name, user):
return ""


def get_repositories(
session,
repository_model,
repository_category_assoc_model,
user_model,
repository_metadata_model,
category_id,
installable,
sort_order,
sort_key,
page,
per_page,
):
Repository = repository_model
RepositoryCategoryAssociation = repository_category_assoc_model
User = user_model
RepositoryMetadata = repository_metadata_model

stmt = (
select(Repository)
.join(
RepositoryCategoryAssociation,
Repository.id == RepositoryCategoryAssociation.repository_id,
)
.join(User, User.id == Repository.user_id)
.where(RepositoryCategoryAssociation.category_id == category_id)
)
if installable:
stmt1 = select(RepositoryMetadata.repository_id)
stmt = stmt.where(Repository.id.in_(stmt1))
if sort_key == "owner":
stmt = stmt.order_by(User.username)
else:
stmt = stmt.order_by(Repository.name)
if sort_order == "desc":
stmt = stmt.desc()
if page is not None:
page = int(page)
stmt = stmt.limit(per_page)
if page > 1:
stmt = stmt.offset((page - 1) * per_page)

return session.scalars(stmt)


def get_current_users(session, user_model):
stmt = select(user_model).where(user_model.deleted == false()).order_by(user_model.email)
return session.scalars(stmt)


def get_current_groups(session, group_model):
stmt = select(group_model).where(group_model.deleted == false()).order_by(group_model.name)
return session.scalars(stmt)


def delete_repository_category_associations(session, repository_category_assoc_model, repository_id):
stmt = delete(repository_category_assoc_model).where(repository_category_assoc_model.repository_id == repository_id)
return session.execute(stmt)


__all__ = (
"change_repository_name_in_hgrc_file",
"create_or_update_tool_shed_repository",
Expand Down

0 comments on commit bf919c0

Please sign in to comment.