diff --git a/lib/tool_shed/util/repository_util.py b/lib/tool_shed/util/repository_util.py index 2f242a5d66f0..1cc1449bc40b 100644 --- a/lib/tool_shed/util/repository_util.py +++ b/lib/tool_shed/util/repository_util.py @@ -9,9 +9,12 @@ ) from markupsafe import escape -from sqlalchemy import false +from sqlalchemy import ( + delete, + false, + select, +) from sqlalchemy.orm import joinedload -from sqlalchemy.sql import select import tool_shed.dependencies.repository from galaxy import ( @@ -232,7 +235,7 @@ def create_repository( if category_ids: # Create category associations for category_id in category_ids: - category = sa_session.query(app.model.Category).get(app.security.decode_id(category_id)) + category = sa_session.get(app.model.Category, app.security.decode_id(category_id)) rca = app.model.RepositoryCategoryAssociation(repository, category) sa_session.add(rca) flush_needed = True @@ -335,39 +338,20 @@ def get_repo_info_dict(trans: "ProvidesRepositoriesContext", repository_id, chan def get_repositories_by_category( app: "ToolShedApp", category_id, installable=False, sort_order="asc", sort_key="name", page=None, per_page=25 ): - sa_session = app.model.session - query = ( - sa_session.query(app.model.Repository) - .join( - app.model.RepositoryCategoryAssociation, - app.model.Repository.id == app.model.RepositoryCategoryAssociation.repository_id, - ) - .join(app.model.User, app.model.User.id == app.model.Repository.user_id) - .filter(app.model.RepositoryCategoryAssociation.category_id == category_id) - ) - if installable: - subquery = select(app.model.RepositoryMetadata.table.c.repository_id) - query = query.filter(app.model.Repository.id.in_(subquery)) - if sort_key == "owner": - query = ( - query.order_by(app.model.User.username) - if sort_order == "asc" - else query.order_by(app.model.User.username.desc()) - ) - else: - query = ( - query.order_by(app.model.Repository.name) - if sort_order == "asc" - else query.order_by(app.model.Repository.name.desc()) - ) - if page is not None: - page = int(page) - query = query.limit(per_page) - if page > 1: - query = query.offset((page - 1) * per_page) - resultset = query.all() repositories = [] - for repository in resultset: + for repository in get_repositories( + app.model.session, + app.model.Repository, + app.model.RepositoryCategoryAssociation, + app.model.User, + app.model.RepositoryMetadata, + category_id, + installable, + sort_order, + sort_key, + page, + per_page, + ): default_value_mapper = { "id": app.security.encode_id, "user_id": app.security.encode_id, @@ -396,7 +380,7 @@ def handle_role_associations(app: "ToolShedApp", role, repository, **kwd): repository_owner = repository.user if kwd.get("manage_role_associations_button", False): in_users_list = util.listify(kwd.get("in_users", [])) - in_users = [sa_session.query(app.model.User).get(x) for x in in_users_list] + in_users = [sa_session.get(app.model.User, x) for x in in_users_list] # Make sure the repository owner is always associated with the repostory's admin role. owner_associated = False for user in in_users: @@ -408,7 +392,7 @@ def handle_role_associations(app: "ToolShedApp", role, repository, **kwd): message += "The repository owner must always be associated with the repository's administrator role. " status = "error" in_groups_list = util.listify(kwd.get("in_groups", [])) - in_groups = [sa_session.query(app.model.Group).get(x) for x in in_groups_list] + in_groups = [sa_session.get(app.model.Group, x) for x in in_groups_list] in_repositories = [repository] app.security_agent.set_entity_role_associations( roles=[role], users=in_users, groups=in_groups, repositories=in_repositories @@ -424,20 +408,12 @@ def handle_role_associations(app: "ToolShedApp", role, repository, **kwd): out_users = [] in_groups = [] out_groups = [] - for user in ( - sa_session.query(app.model.User) - .filter(app.model.User.table.c.deleted == false()) - .order_by(app.model.User.table.c.email) - ): + for user in get_current_users(sa_session, app.model.User): if user in [x.user for x in role.users]: in_users.append((user.id, user.email)) else: out_users.append((user.id, user.email)) - for group in ( - sa_session.query(app.model.Group) - .filter(app.model.Group.table.c.deleted == false()) - .order_by(app.model.Group.table.c.name) - ): + for group in get_current_groups(sa_session, app.model.Group): if group in [x.group for x in role.groups]: in_groups.append((group.id, group.name)) else: @@ -467,7 +443,7 @@ def update_repository(trans: "ProvidesUserContext", id: str, **kwds) -> Tuple[Op message = None flush_needed = False sa_session = app.model.session - repository = sa_session.query(app.model.Repository).get(app.security.decode_id(id)) + repository = sa_session.get(app.model.Repository, app.security.decode_id(id)) if repository is None: return None, "Unknown repository ID" @@ -483,17 +459,14 @@ def update_repository(trans: "ProvidesUserContext", id: str, **kwds) -> Tuple[Op flush_needed = True if "category_ids" in kwds and isinstance(kwds["category_ids"], list): - # Get existing category associations - category_associations = sa_session.query(app.model.RepositoryCategoryAssociation).filter( - app.model.RepositoryCategoryAssociation.table.c.repository_id == app.security.decode_id(id) + # Remove existing category associations + delete_repository_category_associations( + sa_session, app.model.RepositoryCategoryAssociation, app.security.decode_id(id) ) - # Remove all of them - for rca in category_associations: - sa_session.delete(rca) # Then (re)create category associations for category_id in kwds["category_ids"]: - category = sa_session.query(app.model.Category).get(app.security.decode_id(category_id)) + category = sa_session.get(app.model.Category, app.security.decode_id(category_id)) if category: rca = app.model.RepositoryCategoryAssociation(repository, category) sa_session.add(rca) @@ -562,6 +535,66 @@ def validate_repository_name(app: "ToolShedApp", name, user): return "" +def get_repositories( + session, + repository_model, + repository_category_assoc_model, + user_model, + repository_metadata_model, + category_id, + installable, + sort_order, + sort_key, + page, + per_page, +): + Repository = repository_model + RepositoryCategoryAssociation = repository_category_assoc_model + User = user_model + RepositoryMetadata = repository_metadata_model + + stmt = ( + select(Repository) + .join( + RepositoryCategoryAssociation, + Repository.id == RepositoryCategoryAssociation.repository_id, + ) + .join(User, User.id == Repository.user_id) + .where(RepositoryCategoryAssociation.category_id == category_id) + ) + if installable: + stmt1 = select(RepositoryMetadata.repository_id) + stmt = stmt.where(Repository.id.in_(stmt1)) + if sort_key == "owner": + stmt = stmt.order_by(User.username) + else: + stmt = stmt.order_by(Repository.name) + if sort_order == "desc": + stmt = stmt.desc() + if page is not None: + page = int(page) + stmt = stmt.limit(per_page) + if page > 1: + stmt = stmt.offset((page - 1) * per_page) + + return session.scalars(stmt) + + +def get_current_users(session, user_model): + stmt = select(user_model).where(user_model.deleted == false()).order_by(user_model.email) + return session.scalars(stmt) + + +def get_current_groups(session, group_model): + stmt = select(group_model).where(group_model.deleted == false()).order_by(group_model.name) + return session.scalars(stmt) + + +def delete_repository_category_associations(session, repository_category_assoc_model, repository_id): + stmt = delete(repository_category_assoc_model).where(repository_category_assoc_model.repository_id == repository_id) + return session.execute(stmt) + + __all__ = ( "change_repository_name_in_hgrc_file", "create_or_update_tool_shed_repository",