From 72b43771513f4006025000a48fb24943c2951028 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Tue, 30 Nov 2021 17:11:13 +0100 Subject: [PATCH 01/14] =?UTF-8?q?=F0=9F=94=A7=20MAINTAIN:=20Improve=20typi?= =?UTF-8?q?ng=20`Entity.objects`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- aiida/orm/authinfos.py | 2 +- aiida/orm/comments.py | 2 +- aiida/orm/computers.py | 2 +- aiida/orm/groups.py | 2 +- aiida/orm/logs.py | 2 +- aiida/orm/users.py | 2 +- 6 files changed, 6 insertions(+), 6 deletions(-) diff --git a/aiida/orm/authinfos.py b/aiida/orm/authinfos.py index 7a398b20bd..1466f2df54 100644 --- a/aiida/orm/authinfos.py +++ b/aiida/orm/authinfos.py @@ -46,7 +46,7 @@ class AuthInfo(entities.Entity['BackendAuthInfo']): Collection = AuthInfoCollection @classproperty - def objects(cls) -> AuthInfoCollection: # pylint: disable=no-self-argument + def objects(cls: Type['AuthInfo']) -> AuthInfoCollection: # type: ignore[misc] # pylint: disable=no-self-argument return AuthInfoCollection.get_cached(cls, get_manager().get_backend()) PROPERTY_WORKDIR = 'workdir' diff --git a/aiida/orm/comments.py b/aiida/orm/comments.py index a08820738f..de7b74698d 100644 --- a/aiida/orm/comments.py +++ b/aiida/orm/comments.py @@ -69,7 +69,7 @@ class Comment(entities.Entity['BackendComment']): Collection = CommentCollection @classproperty - def objects(cls) -> CommentCollection: # pylint: disable=no-self-argument + def objects(cls: Type['Comment']) -> CommentCollection: # type: ignore[misc] # pylint: disable=no-self-argument return CommentCollection.get_cached(cls, get_manager().get_backend()) def __init__(self, node: 'Node', user: 'User', content: Optional[str] = None, backend: Optional['Backend'] = None): diff --git a/aiida/orm/computers.py b/aiida/orm/computers.py index 5af8102104..63d45129d5 100644 --- a/aiida/orm/computers.py +++ b/aiida/orm/computers.py @@ -78,7 +78,7 @@ class Computer(entities.Entity['BackendComputer']): Collection = ComputerCollection @classproperty - def objects(cls) -> ComputerCollection: # pylint: disable=no-self-argument + def objects(cls: Type['Computer']) -> ComputerCollection: # type: ignore[misc] # pylint: disable=no-self-argument return ComputerCollection.get_cached(cls, get_manager().get_backend()) def __init__( # pylint: disable=too-many-arguments diff --git a/aiida/orm/groups.py b/aiida/orm/groups.py index 731d820553..e19888cbcc 100644 --- a/aiida/orm/groups.py +++ b/aiida/orm/groups.py @@ -118,7 +118,7 @@ class Group(entities.Entity['BackendGroup'], entities.EntityExtrasMixin, metacla Collection = GroupCollection @classproperty - def objects(cls) -> GroupCollection: # pylint: disable=no-self-argument + def objects(cls: Type['Group']) -> GroupCollection: # type: ignore[misc] # pylint: disable=no-self-argument return GroupCollection.get_cached(cls, get_manager().get_backend()) def __init__( diff --git a/aiida/orm/logs.py b/aiida/orm/logs.py index b15d63d32f..4975559ec4 100644 --- a/aiida/orm/logs.py +++ b/aiida/orm/logs.py @@ -131,7 +131,7 @@ class Log(entities.Entity['BackendLog']): Collection = LogCollection @classproperty - def objects(cls) -> LogCollection: # pylint: disable=no-self-argument + def objects(cls: Type['Log']) -> LogCollection: # type: ignore[misc] # pylint: disable=no-self-argument return LogCollection.get_cached(cls, get_manager().get_backend()) def __init__( diff --git a/aiida/orm/users.py b/aiida/orm/users.py index 07e48f9b3f..8c8e4f3474 100644 --- a/aiida/orm/users.py +++ b/aiida/orm/users.py @@ -78,7 +78,7 @@ 
class User(entities.Entity['BackendUser']): Collection = UserCollection @classproperty - def objects(cls) -> UserCollection: # pylint: disable=no-self-argument + def objects(cls: Type['User']) -> UserCollection: # type: ignore[misc] # pylint: disable=no-self-argument return UserCollection.get_cached(cls, get_manager().get_backend()) def __init__( From 9f80c905aad082fcb26e271a713ced01bfa66d21 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Tue, 30 Nov 2021 17:15:35 +0100 Subject: [PATCH 02/14] =?UTF-8?q?=F0=9F=94=A7=20MAINTAIN:=20Remove=20dupli?= =?UTF-8?q?cate=20`EntityTypes`=20enum?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- aiida/orm/__init__.py | 1 + aiida/orm/entities.py | 2 +- aiida/orm/implementation/querybuilder.py | 29 ++++++++---------------- aiida/orm/querybuilder.py | 11 +++++---- 4 files changed, 18 insertions(+), 25 deletions(-) diff --git a/aiida/orm/__init__.py b/aiida/orm/__init__.py index 4fc360170f..5a2033f302 100644 --- a/aiida/orm/__init__.py +++ b/aiida/orm/__init__.py @@ -53,6 +53,7 @@ 'Entity', 'EntityAttributesMixin', 'EntityExtrasMixin', + 'EntityTypes', 'EnumData', 'Float', 'FolderData', diff --git a/aiida/orm/entities.py b/aiida/orm/entities.py index 5f07cd0b1d..527ad7d89d 100644 --- a/aiida/orm/entities.py +++ b/aiida/orm/entities.py @@ -29,7 +29,7 @@ from aiida.orm.implementation import Backend, BackendEntity from aiida.orm.querybuilder import FilterType, OrderByType, QueryBuilder -__all__ = ('Entity', 'Collection', 'EntityAttributesMixin', 'EntityExtrasMixin') +__all__ = ('Entity', 'Collection', 'EntityAttributesMixin', 'EntityExtrasMixin', 'EntityTypes') CollectionType = TypeVar('CollectionType', bound='Collection') EntityType = TypeVar('EntityType', bound='Entity') diff --git a/aiida/orm/implementation/querybuilder.py b/aiida/orm/implementation/querybuilder.py index d2fa23ed3f..9155bb1d07 100644 --- a/aiida/orm/implementation/querybuilder.py +++ b/aiida/orm/implementation/querybuilder.py @@ -9,11 +9,11 @@ ########################################################################### """Abstract `QueryBuilder` definition.""" import abc -from enum import Enum from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Set, Union from aiida.common.lang import type_check from aiida.common.log import AIIDA_LOGGER +from aiida.orm.entities import EntityTypes try: from typing import Literal, TypedDict # pylint: disable=ungrouped-imports @@ -28,29 +28,18 @@ QUERYBUILD_LOGGER = AIIDA_LOGGER.getChild('orm.querybuilder') - -class EntityTypes(Enum): - """The entity types and their allowed relationships.""" - AUTHINFO = 'authinfo' - COMMENT = 'comment' - COMPUTER = 'computer' - GROUP = 'group' - LOG = 'log' - NODE = 'node' - USER = 'user' - - EntityRelationships: Dict[str, Set[str]] = { - 'authinfo': {'with_computer', 'with_user'}, - 'comment': {'with_node', 'with_user'}, - 'computer': {'with_node'}, - 'group': {'with_node', 'with_user'}, - 'log': {'with_node'}, - 'node': { + EntityTypes.AUTHINFO.value: {'with_computer', 'with_user'}, + EntityTypes.COMMENT.value: {'with_node', 'with_user'}, + EntityTypes.COMPUTER.value: {'with_node'}, + EntityTypes.GROUP.value: {'with_node', 'with_user'}, + EntityTypes.LOG.value: {'with_node'}, + EntityTypes.NODE.value: { 'with_comment', 'with_log', 'with_incoming', 'with_outgoing', 'with_descendants', 'with_ancestors', 'with_computer', 'with_user', 'with_group' }, - 'user': {'with_authinfo', 'with_comment', 'with_group', 'with_node'} + EntityTypes.USER.value: {'with_authinfo', 
'with_comment', 'with_group', 'with_node'}, + EntityTypes.LINK.value: set(), } diff --git a/aiida/orm/querybuilder.py b/aiida/orm/querybuilder.py index d3c04ebc56..94bdce9cde 100644 --- a/aiida/orm/querybuilder.py +++ b/aiida/orm/querybuilder.py @@ -39,11 +39,11 @@ import warnings from aiida.manage.manager import get_manager +from aiida.orm.entities import EntityTypes from aiida.orm.implementation.querybuilder import ( GROUP_ENTITY_TYPE_PREFIX, BackendQueryBuilder, EntityRelationships, - EntityTypes, PathItemType, QueryDictType, ) @@ -1254,12 +1254,15 @@ def _get_ormclass_from_str(type_string: str) -> Tuple[EntityTypes, Classifier]: if type_string_lower.startswith(GROUP_ENTITY_TYPE_PREFIX): classifiers = Classifier('group.core') ormclass = EntityTypes.GROUP - elif type_string_lower == 'computer': + elif type_string_lower == EntityTypes.COMPUTER.value: classifiers = Classifier('computer') ormclass = EntityTypes.COMPUTER - elif type_string_lower == 'user': + elif type_string_lower == EntityTypes.USER.value: classifiers = Classifier('user') ormclass = EntityTypes.USER + elif type_string_lower == EntityTypes.LINK.value: + classifiers = Classifier('link') + ormclass = EntityTypes.LINK else: # At this point, we assume it is a node. The only valid type string then is a string # that matches exactly the _plugin_type_string of a node class @@ -1415,7 +1418,7 @@ def get(self, tag_or_cls: Union[str, EntityClsType]) -> str: if isinstance(tag_or_cls, str): if tag_or_cls in self: return tag_or_cls - raise ValueError(f'Tag {tag_or_cls} is not among my known tags: {list(self)}') + raise ValueError(f'Tag {tag_or_cls!r} is not among my known tags: {list(self)}') if self._cls_to_tag_map.get(tag_or_cls, None): if len(self._cls_to_tag_map[tag_or_cls]) != 1: raise ValueError( From a617829e03e085656064f70301c9cc79d120f978 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Tue, 30 Nov 2021 17:17:31 +0100 Subject: [PATCH 03/14] =?UTF-8?q?=F0=9F=91=8C=20IMPROVE:=20Add=20`AuthInfo?= =?UTF-8?q?`=20to=20`verdi=20shell`?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- aiida/cmdline/utils/shell.py | 1 + 1 file changed, 1 insertion(+) diff --git a/aiida/cmdline/utils/shell.py b/aiida/cmdline/utils/shell.py index a8c55bdd42..afe85feb2f 100644 --- a/aiida/cmdline/utils/shell.py +++ b/aiida/cmdline/utils/shell.py @@ -31,6 +31,7 @@ ('aiida.orm', 'Group', 'Group'), ('aiida.orm', 'QueryBuilder', 'QueryBuilder'), ('aiida.orm', 'User', 'User'), + ('aiida.orm', 'AuthInfo', 'AuthInfo'), ('aiida.orm', 'load_code', 'load_code'), ('aiida.orm', 'load_computer', 'load_computer'), ('aiida.orm', 'load_group', 'load_group'), From dcb95bc3027332d596b8297c4f17309664e34c20 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Tue, 30 Nov 2021 17:18:57 +0100 Subject: [PATCH 04/14] =?UTF-8?q?=F0=9F=A7=AA=20TESTS:=20yield=20`aiida=5F?= =?UTF-8?q?profile`=20from=20fixtures?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- aiida/manage/tests/pytest_fixtures.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/aiida/manage/tests/pytest_fixtures.py b/aiida/manage/tests/pytest_fixtures.py index 9b2be85c7d..0bbc93bcdb 100644 --- a/aiida/manage/tests/pytest_fixtures.py +++ b/aiida/manage/tests/pytest_fixtures.py @@ -60,7 +60,7 @@ def clear_database(clear_database_after_test): @pytest.fixture(scope='function') def clear_database_after_test(aiida_profile): """Clear the database after the test.""" - yield + yield aiida_profile 
aiida_profile.reset_db() @@ -68,7 +68,7 @@ def clear_database_after_test(aiida_profile): def clear_database_before_test(aiida_profile): """Clear the database before the test.""" aiida_profile.reset_db() - yield + yield aiida_profile @pytest.fixture(scope='class') From ff40dc92aaadab3153373e5498c2545a00bd3739 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Tue, 30 Nov 2021 17:36:57 +0100 Subject: [PATCH 05/14] =?UTF-8?q?=F0=9F=90=9B=20FIX:=20`SqlaBackend.bulk?= =?UTF-8?q?=5Finsert`=20AuthInfo=20metadata?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Convert from `_metadata` -> `metadata` --- aiida/orm/implementation/sqlalchemy/backend.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiida/orm/implementation/sqlalchemy/backend.py b/aiida/orm/implementation/sqlalchemy/backend.py index 01a34c125b..045f156b21 100644 --- a/aiida/orm/implementation/sqlalchemy/backend.py +++ b/aiida/orm/implementation/sqlalchemy/backend.py @@ -139,7 +139,7 @@ def bulk_insert(self, entity_type: EntityTypes, rows: List[dict], allow_defaults mapper, keys = self._get_mapper_from_entity(entity_type, False) if not rows: return [] - if entity_type in (EntityTypes.COMPUTER, EntityTypes.LOG): + if entity_type in (EntityTypes.COMPUTER, EntityTypes.LOG, EntityTypes.AUTHINFO): for row in rows: row['_metadata'] = row.pop('metadata') if allow_defaults: From 888df2ac1a408546dac8a4408fe2db24d73a4dc1 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Tue, 30 Nov 2021 17:44:14 +0100 Subject: [PATCH 06/14] =?UTF-8?q?=F0=9F=91=8C=20IMPROVE:=20Allow=20for=20n?= =?UTF-8?q?ode=20links=20to=20be=20queried?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Link queries return `LinkQuadruple` --- aiida/orm/convert.py | 4 ++++ aiida/orm/implementation/django/convert.py | 9 ++++++++ .../orm/implementation/sqlalchemy/convert.py | 11 +++++++++- .../sqlalchemy/querybuilder/joiner.py | 1 + .../sqlalchemy/querybuilder/main.py | 4 +++- tests/orm/test_querybuilder.py | 22 +++++++++++++++++++ 6 files changed, 49 insertions(+), 2 deletions(-) diff --git a/aiida/orm/convert.py b/aiida/orm/convert.py index f5bbe1b3a1..ea9dd36bd2 100644 --- a/aiida/orm/convert.py +++ b/aiida/orm/convert.py @@ -51,6 +51,10 @@ def _(backend_entity): Note that we do not register on `collections.abc.Sequence` because that will also match strings. 
""" + if hasattr(backend_entity, '_asdict'): + # it is a NamedTuple, so return as is + return backend_entity + converted = [] # Note that we cannot use a simple comprehension because raised `TypeError` should be caught here otherwise only diff --git a/aiida/orm/implementation/django/convert.py b/aiida/orm/implementation/django/convert.py index 0bfb836ee4..9e446b2532 100644 --- a/aiida/orm/implementation/django/convert.py +++ b/aiida/orm/implementation/django/convert.py @@ -223,3 +223,12 @@ def _(dbmodel, backend): metadata=dbmodel.metadata # pylint: disable=protected-access ) return logs.DjangoLog.from_dbmodel(djlog, backend) + + +@get_backend_entity.register(djmodels.DbLink.sa) +def _(dbmodel, backend): + """ + Convert a dblink to the backend entity + """ + from aiida.orm.utils.links import LinkQuadruple + return LinkQuadruple(dbmodel.input_id, dbmodel.output_id, dbmodel.type, dbmodel.label) diff --git a/aiida/orm/implementation/sqlalchemy/convert.py b/aiida/orm/implementation/sqlalchemy/convert.py index 5190cf3fa5..c4283474be 100644 --- a/aiida/orm/implementation/sqlalchemy/convert.py +++ b/aiida/orm/implementation/sqlalchemy/convert.py @@ -21,7 +21,7 @@ from aiida.backends.sqlalchemy.models.computer import DbComputer from aiida.backends.sqlalchemy.models.group import DbGroup from aiida.backends.sqlalchemy.models.log import DbLog -from aiida.backends.sqlalchemy.models.node import DbNode +from aiida.backends.sqlalchemy.models.node import DbLink, DbNode from aiida.backends.sqlalchemy.models.user import DbUser __all__ = ('get_backend_entity',) @@ -105,3 +105,12 @@ def _(dbmodel, backend): """ from . import logs return logs.SqlaLog.from_dbmodel(dbmodel, backend) + + +@get_backend_entity.register(DbLink) +def _(dbmodel, backend): + """ + Convert a dblink to the backend entity + """ + from aiida.orm.utils.links import LinkQuadruple + return LinkQuadruple(dbmodel.input_id, dbmodel.output_id, dbmodel.type, dbmodel.label) diff --git a/aiida/orm/implementation/sqlalchemy/querybuilder/joiner.py b/aiida/orm/implementation/sqlalchemy/querybuilder/joiner.py index c5cb11883f..f6b0cc4b54 100644 --- a/aiida/orm/implementation/sqlalchemy/querybuilder/joiner.py +++ b/aiida/orm/implementation/sqlalchemy/querybuilder/joiner.py @@ -116,6 +116,7 @@ def _entity_join_map(self) -> Dict[str, Dict[str, JoinFuncType]]: 'with_node': self._join_node_group, 'with_user': self._join_user_group, }, + 'link': {}, 'log': { 'with_node': self._join_node_log, }, diff --git a/aiida/orm/implementation/sqlalchemy/querybuilder/main.py b/aiida/orm/implementation/sqlalchemy/querybuilder/main.py index 099e4bd603..a6dd7a4ede 100644 --- a/aiida/orm/implementation/sqlalchemy/querybuilder/main.py +++ b/aiida/orm/implementation/sqlalchemy/querybuilder/main.py @@ -33,7 +33,8 @@ from sqlalchemy.types import Boolean, DateTime, Float, Integer, String from aiida.common.exceptions import NotExistent -from aiida.orm.implementation.querybuilder import QUERYBUILD_LOGGER, BackendQueryBuilder, EntityTypes, QueryDictType +from aiida.orm.entities import EntityTypes +from aiida.orm.implementation.querybuilder import QUERYBUILD_LOGGER, BackendQueryBuilder, QueryDictType from .joiner import SqlaJoiner @@ -271,6 +272,7 @@ def rebuild_aliases(self) -> None: EntityTypes.NODE.value: self.Node, EntityTypes.LOG.value: self.Log, EntityTypes.USER.value: self.User, + EntityTypes.LINK.value: self.Link, } self._tag_to_alias = {} for path in self._data['path']: diff --git a/tests/orm/test_querybuilder.py b/tests/orm/test_querybuilder.py index 6e70221e65..da8a3cec32 
100644 --- a/tests/orm/test_querybuilder.py +++ b/tests/orm/test_querybuilder.py @@ -21,6 +21,7 @@ from aiida.common.links import LinkType from aiida.manage import configuration from aiida.orm.querybuilder import _get_ormclass +from aiida.orm.utils.links import LinkQuadruple @pytest.mark.usefixtures('clear_database_before_test') @@ -672,6 +673,27 @@ def test_flat(): assert len(result) == 20 assert result == list(chain.from_iterable(zip(pks, uuids))) + def test_query_links(self): + """Test querying for links""" + d1, d2, d3, d4 = [orm.Data().store() for _ in range(4)] + c1, c2 = [orm.CalculationNode() for _ in range(2)] + c1.add_incoming(d1, link_type=LinkType.INPUT_CALC, link_label='link_d1c1') + c1.store() + d2.add_incoming(c1, link_type=LinkType.CREATE, link_label='link_c1d2') + d4.add_incoming(c1, link_type=LinkType.CREATE, link_label='link_c1d4') + c2.add_incoming(d2, link_type=LinkType.INPUT_CALC, link_label='link_d2c2') + c2.store() + d3.add_incoming(c2, link_type=LinkType.CREATE, link_label='link_c2d3') + + builder = orm.QueryBuilder().append(entity_type='link') + assert builder.count() == 5 + + builder = orm.QueryBuilder().append(entity_type='link', filters={'type': LinkType.CREATE.value}) + assert builder.count() == 3 + + builder = orm.QueryBuilder().append(entity_type='link', filters={'label': 'link_d2c2'}) + assert builder.one()[0] == LinkQuadruple(d2.id, c2.id, LinkType.INPUT_CALC.value, 'link_d2c2') + @pytest.mark.usefixtures('clear_database_before_test') class TestMultipleProjections: From c5d9b487c4b265883b296336905738ce1c4bb957 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Tue, 30 Nov 2021 17:34:31 +0100 Subject: [PATCH 07/14] =?UTF-8?q?=E2=99=BB=EF=B8=8F=20REFACTOR:=20Extract?= =?UTF-8?q?=20`get=5Fdatabase=5Fsummary`=20to=20common?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This will allow the code to be re-used with the new `verdi archive` commands --- aiida/cmdline/commands/cmd_storage.py | 40 ++------------------ aiida/cmdline/utils/common.py | 54 ++++++++++++++++++++++++++- 2 files changed, 56 insertions(+), 38 deletions(-) diff --git a/aiida/cmdline/commands/cmd_storage.py b/aiida/cmdline/commands/cmd_storage.py index 0430c06092..58b7c96163 100644 --- a/aiida/cmdline/commands/cmd_storage.py +++ b/aiida/cmdline/commands/cmd_storage.py @@ -84,42 +84,8 @@ def storage_integrity(): @click.option('--statistics', is_flag=True, help='Provides more in-detail statistically relevant data.') def storage_info(statistics): """Summarise the contents of the storage.""" - from aiida.orm import Comment, Computer, Group, Log, Node, QueryBuilder, User - data = {} - - # User - query_user = QueryBuilder().append(User, project=['email']) - data['Users'] = {'count': query_user.count()} - if statistics: - data['Users']['emails'] = query_user.distinct().all(flat=True) - - # Computer - query_comp = QueryBuilder().append(Computer, project=['label']) - data['Computers'] = {'count': query_comp.count()} - if statistics: - data['Computers']['labels'] = query_comp.distinct().all(flat=True) - - # Node - count = QueryBuilder().append(Node).count() - data['Nodes'] = {'count': count} - if statistics: - node_types = QueryBuilder().append(Node, project=['node_type']).distinct().all(flat=True) - data['Nodes']['node_types'] = node_types - process_types = QueryBuilder().append(Node, project=['process_type']).distinct().all(flat=True) - data['Nodes']['process_types'] = [p for p in process_types if p] - - # Group - query_group = QueryBuilder().append(Group, 
project=['type_string']) - data['Groups'] = {'count': query_group.count()} - if statistics: - data['Groups']['type_strings'] = query_group.distinct().all(flat=True) - - # Comment - count = QueryBuilder().append(Comment).count() - data['Comments'] = {'count': count} - - # Log - count = QueryBuilder().append(Log).count() - data['Logs'] = {'count': count} + from aiida.cmdline.utils.common import get_database_summary + from aiida.orm import QueryBuilder + data = get_database_summary(QueryBuilder, statistics) echo.echo_dictionary(data, sort_keys=False, fmt='yaml') diff --git a/aiida/cmdline/utils/common.py b/aiida/cmdline/utils/common.py index 9c57980ab9..cb869b78b0 100644 --- a/aiida/cmdline/utils/common.py +++ b/aiida/cmdline/utils/common.py @@ -11,7 +11,7 @@ import logging import os import sys -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Callable from tabulate import tabulate @@ -529,3 +529,55 @@ def check_worker_load(active_slots): echo.echo_report(f'Using {percent_load * 100:.0f}%% of the available daemon worker slots.') else: echo.echo_report('No active daemon workers.') + + +def get_database_summary(querybuilder: Callable, verbose: bool) -> dict: + """Generate a summary of the database.""" + from aiida.orm import Comment, Computer, Group, Log, Node, User + + data = {} + + # User + query_user = querybuilder().append(User, project=['email']) + data['Users'] = {'count': query_user.count()} + if verbose: + data['Users']['emails'] = sorted({email for email, in query_user.iterall() if email is not None}) + + # Computer + query_comp = querybuilder().append(Computer, project=['label']) + data['Computers'] = {'count': query_comp.count()} + if verbose: + data['Computers']['labels'] = sorted({comp for comp, in query_comp.iterall() if comp is not None}) + + # Node + count = querybuilder().append(Node).count() + data['Nodes'] = {'count': count} + if verbose: + node_types = sorted({ + typ for typ, in querybuilder().append(Node, project=['node_type']).iterall() if typ is not None + }) + data['Nodes']['node_types'] = node_types + process_types = sorted({ + typ for typ, in querybuilder().append(Node, project=['process_type']).iterall() if typ is not None + }) + data['Nodes']['process_types'] = [p for p in process_types if p] + + # Group + query_group = querybuilder().append(Group, project=['type_string']) + data['Groups'] = {'count': query_group.count()} + if verbose: + data['Groups']['type_strings'] = sorted({typ for typ, in query_group.iterall() if typ is not None}) + + # Comment + count = querybuilder().append(Comment).count() + data['Comments'] = {'count': count} + + # Log + count = querybuilder().append(Log).count() + data['Logs'] = {'count': count} + + # Links + count = querybuilder().append(entity_type='link').count() + data['Links'] = {'count': count} + + return data From 724bec5846c5462ad4fffe4a3878b47c878a75d7 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Tue, 30 Nov 2021 17:47:17 +0100 Subject: [PATCH 08/14] =?UTF-8?q?=F0=9F=91=8C=20IMPROVE:=20Add=20`glob`=20?= =?UTF-8?q?method=20for=20node=20repositories?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- aiida/orm/nodes/repository.py | 8 ++++++++ aiida/repository/backend/abstract.py | 2 +- aiida/repository/common.py | 8 ++++++++ tests/orm/node/test_repository.py | 13 +++++++++++++ 4 files changed, 30 insertions(+), 1 deletion(-) diff --git a/aiida/orm/nodes/repository.py b/aiida/orm/nodes/repository.py index f8608299fb..67dc7bb3db 100644 --- 
a/aiida/orm/nodes/repository.py +++ b/aiida/orm/nodes/repository.py @@ -215,6 +215,14 @@ def walk(self, path: FilePath = None) -> Iterable[Tuple[pathlib.PurePosixPath, L """ yield from self._repository.walk(path) + def glob(self) -> Iterable[pathlib.PurePosixPath]: + """Yield a recursive list of all paths (files and directories).""" + for dirpath, dirnames, filenames in self.walk(): + for dirname in dirnames: + yield dirpath / dirname + for filename in filenames: + yield dirpath / filename + def copy_tree(self, target: Union[str, pathlib.Path], path: FilePath = None) -> None: """Copy the contents of the entire node repository to another location on the local file system. diff --git a/aiida/repository/backend/abstract.py b/aiida/repository/backend/abstract.py index 63ba30dd32..7903cc60c6 100644 --- a/aiida/repository/backend/abstract.py +++ b/aiida/repository/backend/abstract.py @@ -75,7 +75,7 @@ def put_object_from_filelike(self, handle: BinaryIO) -> str: :return: the generated fully qualified identifier for the object within the repository. :raises TypeError: if the handle is not a byte stream. """ - if not isinstance(handle, io.BytesIO) and not self.is_readable_byte_stream(handle): + if not isinstance(handle, io.BufferedIOBase) and not self.is_readable_byte_stream(handle): raise TypeError(f'handle does not seem to be a byte stream: {type(handle)}.') return self._put_object_from_filelike(handle) diff --git a/aiida/repository/common.py b/aiida/repository/common.py index 4b1b5eb2c6..4382a9637b 100644 --- a/aiida/repository/common.py +++ b/aiida/repository/common.py @@ -107,6 +107,14 @@ def file_type(self) -> FileType: """Return the file type of the file object.""" return self._file_type + def is_file(self) -> bool: + """Return whether this instance is a file object.""" + return self.file_type == FileType.FILE + + def is_dir(self) -> bool: + """Return whether this instance is a directory object.""" + return self.file_type == FileType.DIRECTORY + @property def key(self) -> typing.Union[str, None]: """Return the key of the file object.""" diff --git a/tests/orm/node/test_repository.py b/tests/orm/node/test_repository.py index 0af06ebc60..ba185ba025 100644 --- a/tests/orm/node/test_repository.py +++ b/tests/orm/node/test_repository.py @@ -175,6 +175,8 @@ def test_get_object(): file_object = node.get_object(None) assert isinstance(file_object, File) assert file_object.file_type == FileType.DIRECTORY + assert file_object.is_file() is False + assert file_object.is_dir() is True file_object = node.get_object('relative') assert isinstance(file_object, File) @@ -185,6 +187,8 @@ def test_get_object(): assert isinstance(file_object, File) assert file_object.file_type == FileType.FILE assert file_object.name == 'file_b' + assert file_object.is_file() is True + assert file_object.is_dir() is False @pytest.mark.usefixtures('clear_database_before_test') @@ -215,6 +219,15 @@ def test_walk(): ] +@pytest.mark.usefixtures('clear_database_before_test') +def test_glob(): + """Test the ``NodeRepositoryMixin.glob`` method.""" + node = Data() + node.put_object_from_filelike(io.BytesIO(b'content'), 'relative/path') + + assert {path.as_posix() for path in node.glob()} == {'relative', 'relative/path'} + + @pytest.mark.usefixtures('clear_database_before_test') def test_copy_tree(tmp_path): """Test the ``Repository.copy_tree`` method.""" From 924af94baa9d5cada33a891f115f52e60999c962 Mon Sep 17 00:00:00 2001 From: Chris Sewell Date: Tue, 30 Nov 2021 17:48:52 +0100 Subject: [PATCH 09/14] 
=?UTF-8?q?=E2=99=BB=EF=B8=8F=20REFACTOR:=20New=20ar?= =?UTF-8?q?chive=20format?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Implement the new archive format, as discussed in `aiidateam/AEP/005_exportformat`. To address shortcomings in cpu/memory performance for export/import, the archive format has been re-designed. In particular, 1. The `data.json` has been replaced with an sqlite database, using the same schema as the sqla backend, meaning it is no longer required to be fully read into memory. 2. The archive utilises the repository redesign, with binary files stored by hash keys (removing duplication) 3. The archive is only saved as zip (not tar), meaning internal files can be decompressed+streamed independently, without the need to uncompress the entire archive file. 4. The archive is implemented as a full (read-only) backend, meaning it can be queried without the need to import to a profile. Additionally, the entire export/import code has been re-written to utilise these changes. These changes have reduced export times by a factor of ~2.5, export peak RAM by a factor of ~4, import times by a factor of ~4, and import peak RAM by a factor of ~5. The changes also allow for future push/pull mechanisms. --- .gitignore | 2 + .pre-commit-config.yaml | 4 +- aiida/backends/sqlalchemy/models/__init__.py | 3 +- aiida/backends/sqlalchemy/models/group.py | 1 + aiida/cmdline/commands/cmd_archive.py | 264 ++-- aiida/tools/__init__.py | 41 - .../common => archive}/__init__.py | 24 +- aiida/tools/archive/abstract.py | 296 +++++ .../common/utils.py => archive/common.py} | 74 +- aiida/tools/archive/create.py | 678 ++++++++ .../common => archive}/exceptions.py | 54 +- .../implementations}/__init__.py | 8 +- .../implementations/sqlite}/__init__.py | 5 +- .../archive/implementations/sqlite/backend.py | 405 ++++++ .../archive/implementations/sqlite/common.py | 158 +++ .../archive/implementations/sqlite/main.py | 116 ++ .../sqlite/migrations}/__init__.py | 1 + .../sqlite/migrations/legacy}/__init__.py | 23 +- .../sqlite/migrations/legacy}/v04_to_v05.py | 12 +- .../sqlite/migrations/legacy}/v05_to_v06.py | 15 +- .../sqlite/migrations/legacy}/v06_to_v07.py | 19 +- .../sqlite/migrations/legacy}/v07_to_v08.py | 13 +- .../sqlite/migrations/legacy}/v08_to_v09.py | 13 +- .../sqlite/migrations/legacy}/v09_to_v10.py | 12 +- .../sqlite/migrations/legacy/v10_to_v11.py} | 21 +- .../sqlite/migrations/legacy/v11_to_v12.py} | 21 +- .../sqlite/migrations/legacy_to_new.py | 275 ++++ .../implementations/sqlite/migrations/main.py | 187 +++ .../sqlite}/migrations/utils.py | 2 +- .../sqlite/migrations/v1_db_schema.py | 169 +++ .../archive/implementations/sqlite/reader.py | 116 ++ .../archive/implementations/sqlite/writer.py | 313 +++++ aiida/tools/archive/imports.py | 1122 +++++++++++++++++ aiida/tools/graph/graph_traversers.py | 12 +- aiida/tools/importexport/__init__.py | 71 -- aiida/tools/importexport/archive/__init__.py | 48 - aiida/tools/importexport/archive/common.py | 234 ---- .../archive/migrations/v01_to_v02.py | 76 -- .../archive/migrations/v02_to_v03.py | 140 -- .../archive/migrations/v03_to_v04.py | 510 -------- .../archive/migrations/v10_to_v11.py | 76 -- aiida/tools/importexport/archive/migrators.py | 280 ---- aiida/tools/importexport/archive/readers.py | 442 ------- aiida/tools/importexport/archive/writers.py | 506 -------- aiida/tools/importexport/common/config.py | 236 ---- aiida/tools/importexport/dbexport/main.py | 600 --------- aiida/tools/importexport/dbexport/utils.py | 289 ----- 
.../importexport/dbimport/backends/common.py | 146 --- .../importexport/dbimport/backends/django.py | 694 ---------- .../importexport/dbimport/backends/sqla.py | 683 ---------- aiida/tools/importexport/dbimport/main.py | 86 -- aiida/tools/importexport/dbimport/utils.py | 290 ----- docs/source/conf.py | 1 + docs/source/howto/share_data.rst | 109 +- .../howto/visualising_graphs/graph1.aiida | Bin 4472 -> 5455 bytes docs/source/internals/data_storage.rst | 2 + docs/source/nitpick-exceptions | 52 +- docs/source/reference/command_line.rst | 2 +- environment.yml | 2 +- pyproject.toml | 2 + requirements/requirements-py-3.7.txt | 2 +- requirements/requirements-py-3.8.txt | 2 +- requirements/requirements-py-3.9.txt | 2 +- setup.json | 2 +- .../{test_importexport.py => test_archive.py} | 8 +- ...chive_export.py => test_archive_create.py} | 126 +- tests/cmdline/commands/test_archive_import.py | 189 ++- tests/cmdline/commands/test_calcjob.py | 11 +- tests/orm/implementation/test_comments.py | 8 +- tests/orm/implementation/test_logs.py | 6 +- tests/orm/implementation/test_nodes.py | 8 +- tests/static/calcjob/arithmetic.add.aiida | Bin 9011 -> 7738 bytes tests/static/calcjob/arithmetic.add_old.aiida | Bin 7619 -> 6721 bytes tests/static/export/compare/django.aiida | Bin 2622 -> 2967 bytes tests/static/export/compare/sqlalchemy.aiida | Bin 2617 -> 2966 bytes .../export/migrate/export_v0.10_simple.aiida | Bin 56642 -> 48809 bytes .../export/migrate/export_v0.11_simple.aiida | Bin 83553 -> 47372 bytes .../export/migrate/export_v0.12_simple.aiida | Bin 83560 -> 47914 bytes .../export/migrate/export_v0.13_simple.aiida | Bin 83578 -> 0 bytes .../export/migrate/export_v0.4_no_Nodes.aiida | Bin 0 -> 1313 bytes .../export/migrate/export_v0.4_simple.tar.gz | Bin 0 -> 41232 bytes .../export/migrate/export_v1.0_simple.aiida | Bin 0 -> 47286 bytes tests/static/graphs/graph1.aiida | Bin 4452 -> 5444 bytes .../migration => archive}/__init__.py | 0 tests/tools/archive/conftest.py | 54 + tests/tools/archive/migration/__init__.py | 0 .../migration/conftest.py | 32 +- .../migration/test_legacy_funcs.py} | 25 +- .../tools/archive/migration/test_migration.py | 193 +++ .../archive/migration/test_prov_redesign.py | 273 ++++ .../migration/test_v04_to_v05.py | 2 +- .../migration/test_v05_to_v06.py | 4 +- .../migration/test_v06_to_v07.py | 11 +- .../migration/test_v07_to_v08.py | 5 +- .../migration/test_v08_to_v09.py | 5 +- .../{importexport => archive}/orm/__init__.py | 0 tests/tools/archive/orm/test_attributes.py | 39 + tests/tools/archive/orm/test_authinfo.py | 70 + tests/tools/archive/orm/test_calculations.py | 86 ++ tests/tools/archive/orm/test_codes.py | 107 ++ tests/tools/archive/orm/test_comments.py | 601 +++++++++ tests/tools/archive/orm/test_computers.py | 337 +++++ tests/tools/archive/orm/test_extras.py | 163 +++ tests/tools/archive/orm/test_groups.py | 213 ++++ tests/tools/archive/orm/test_links.py | 646 ++++++++++ tests/tools/archive/orm/test_logs.py | 382 ++++++ tests/tools/archive/orm/test_users.py | 165 +++ tests/tools/archive/test_abstract.py | 101 ++ tests/tools/archive/test_backend.py | 58 + tests/tools/archive/test_common.py | 51 + tests/tools/archive/test_complex.py | 198 +++ .../test_repository.py | 6 +- tests/tools/archive/test_simple.py | 146 +++ tests/tools/archive/test_specific_import.py | 182 +++ .../tools/{importexport => archive}/utils.py | 0 tests/tools/importexport/__init__.py | 38 - .../importexport/migration/test_migration.py | 200 --- .../importexport/migration/test_v02_to_v03.py | 142 --- 
.../importexport/migration/test_v03_to_v04.py | 372 ------ .../tools/importexport/orm/test_attributes.py | 59 - .../importexport/orm/test_calculations.py | 95 -- tests/tools/importexport/orm/test_codes.py | 121 -- tests/tools/importexport/orm/test_comments.py | 584 --------- .../tools/importexport/orm/test_computers.py | 360 ------ tests/tools/importexport/orm/test_extras.py | 171 --- tests/tools/importexport/orm/test_groups.py | 231 ---- tests/tools/importexport/orm/test_links.py | 734 ----------- tests/tools/importexport/orm/test_logs.py | 392 ------ tests/tools/importexport/orm/test_users.py | 175 --- tests/tools/importexport/test_complex.py | 207 --- .../tools/importexport/test_prov_redesign.py | 276 ---- tests/tools/importexport/test_reader.py | 82 -- tests/tools/importexport/test_simple.py | 175 --- .../importexport/test_specific_import.py | 186 --- tests/utils/archives.py | 6 +- tests/utils/configuration.py | 28 - utils/make_all.py | 2 + 137 files changed, 8536 insertions(+), 10675 deletions(-) rename aiida/tools/{importexport/common => archive}/__init__.py (68%) create mode 100644 aiida/tools/archive/abstract.py rename aiida/tools/{importexport/common/utils.py => archive/common.py} (54%) create mode 100644 aiida/tools/archive/create.py rename aiida/tools/{importexport/common => archive}/exceptions.py (63%) rename aiida/tools/{importexport/dbexport => archive/implementations}/__init__.py (86%) rename aiida/tools/{importexport/dbimport => archive/implementations/sqlite}/__init__.py (90%) create mode 100644 aiida/tools/archive/implementations/sqlite/backend.py create mode 100644 aiida/tools/archive/implementations/sqlite/common.py create mode 100644 aiida/tools/archive/implementations/sqlite/main.py rename aiida/tools/{importexport/dbimport/backends => archive/implementations/sqlite/migrations}/__init__.py (90%) rename aiida/tools/{importexport/archive/migrations => archive/implementations/sqlite/migrations/legacy}/__init__.py (66%) rename aiida/tools/{importexport/archive/migrations => archive/implementations/sqlite/migrations/legacy}/v04_to_v05.py (85%) rename aiida/tools/{importexport/archive/migrations => archive/implementations/sqlite/migrations/legacy}/v05_to_v06.py (93%) rename aiida/tools/{importexport/archive/migrations => archive/implementations/sqlite/migrations/legacy}/v06_to_v07.py (88%) rename aiida/tools/{importexport/archive/migrations => archive/implementations/sqlite/migrations/legacy}/v07_to_v08.py (83%) rename aiida/tools/{importexport/archive/migrations => archive/implementations/sqlite/migrations/legacy}/v08_to_v09.py (84%) rename aiida/tools/{importexport/archive/migrations => archive/implementations/sqlite/migrations/legacy}/v09_to_v10.py (77%) rename aiida/tools/{importexport/archive/migrations/v11_to_v12.py => archive/implementations/sqlite/migrations/legacy/v10_to_v11.py} (66%) rename aiida/tools/{importexport/archive/migrations/v12_to_v13.py => archive/implementations/sqlite/migrations/legacy/v11_to_v12.py} (90%) create mode 100644 aiida/tools/archive/implementations/sqlite/migrations/legacy_to_new.py create mode 100644 aiida/tools/archive/implementations/sqlite/migrations/main.py rename aiida/tools/{importexport/archive => archive/implementations/sqlite}/migrations/utils.py (98%) create mode 100644 aiida/tools/archive/implementations/sqlite/migrations/v1_db_schema.py create mode 100644 aiida/tools/archive/implementations/sqlite/reader.py create mode 100644 aiida/tools/archive/implementations/sqlite/writer.py create mode 100644 
aiida/tools/archive/imports.py delete mode 100644 aiida/tools/importexport/__init__.py delete mode 100644 aiida/tools/importexport/archive/__init__.py delete mode 100644 aiida/tools/importexport/archive/common.py delete mode 100644 aiida/tools/importexport/archive/migrations/v01_to_v02.py delete mode 100644 aiida/tools/importexport/archive/migrations/v02_to_v03.py delete mode 100644 aiida/tools/importexport/archive/migrations/v03_to_v04.py delete mode 100644 aiida/tools/importexport/archive/migrations/v10_to_v11.py delete mode 100644 aiida/tools/importexport/archive/migrators.py delete mode 100644 aiida/tools/importexport/archive/readers.py delete mode 100644 aiida/tools/importexport/archive/writers.py delete mode 100644 aiida/tools/importexport/common/config.py delete mode 100644 aiida/tools/importexport/dbexport/main.py delete mode 100644 aiida/tools/importexport/dbexport/utils.py delete mode 100644 aiida/tools/importexport/dbimport/backends/common.py delete mode 100644 aiida/tools/importexport/dbimport/backends/django.py delete mode 100644 aiida/tools/importexport/dbimport/backends/sqla.py delete mode 100644 aiida/tools/importexport/dbimport/main.py delete mode 100644 aiida/tools/importexport/dbimport/utils.py rename tests/benchmark/{test_importexport.py => test_archive.py} (95%) rename tests/cmdline/commands/{test_archive_export.py => test_archive_create.py} (58%) delete mode 100644 tests/static/export/migrate/export_v0.13_simple.aiida create mode 100644 tests/static/export/migrate/export_v0.4_no_Nodes.aiida create mode 100644 tests/static/export/migrate/export_v0.4_simple.tar.gz create mode 100644 tests/static/export/migrate/export_v1.0_simple.aiida rename tests/tools/{importexport/migration => archive}/__init__.py (100%) create mode 100644 tests/tools/archive/conftest.py create mode 100644 tests/tools/archive/migration/__init__.py rename tests/tools/{importexport => archive}/migration/conftest.py (66%) rename tests/tools/{importexport/migration/test_migration_funcs.py => archive/migration/test_legacy_funcs.py} (72%) create mode 100644 tests/tools/archive/migration/test_migration.py create mode 100644 tests/tools/archive/migration/test_prov_redesign.py rename tests/tools/{importexport => archive}/migration/test_v04_to_v05.py (95%) rename tests/tools/{importexport => archive}/migration/test_v05_to_v06.py (96%) rename tests/tools/{importexport => archive}/migration/test_v06_to_v07.py (93%) rename tests/tools/{importexport => archive}/migration/test_v07_to_v08.py (94%) rename tests/tools/{importexport => archive}/migration/test_v08_to_v09.py (94%) rename tests/tools/{importexport => archive}/orm/__init__.py (100%) create mode 100644 tests/tools/archive/orm/test_attributes.py create mode 100644 tests/tools/archive/orm/test_authinfo.py create mode 100644 tests/tools/archive/orm/test_calculations.py create mode 100644 tests/tools/archive/orm/test_codes.py create mode 100644 tests/tools/archive/orm/test_comments.py create mode 100644 tests/tools/archive/orm/test_computers.py create mode 100644 tests/tools/archive/orm/test_extras.py create mode 100644 tests/tools/archive/orm/test_groups.py create mode 100644 tests/tools/archive/orm/test_links.py create mode 100644 tests/tools/archive/orm/test_logs.py create mode 100644 tests/tools/archive/orm/test_users.py create mode 100644 tests/tools/archive/test_abstract.py create mode 100644 tests/tools/archive/test_backend.py create mode 100644 tests/tools/archive/test_common.py create mode 100644 tests/tools/archive/test_complex.py rename 
tests/tools/{importexport => archive}/test_repository.py (92%) create mode 100644 tests/tools/archive/test_simple.py create mode 100644 tests/tools/archive/test_specific_import.py rename tests/tools/{importexport => archive}/utils.py (100%) delete mode 100644 tests/tools/importexport/__init__.py delete mode 100644 tests/tools/importexport/migration/test_migration.py delete mode 100644 tests/tools/importexport/migration/test_v02_to_v03.py delete mode 100644 tests/tools/importexport/migration/test_v03_to_v04.py delete mode 100644 tests/tools/importexport/orm/test_attributes.py delete mode 100644 tests/tools/importexport/orm/test_calculations.py delete mode 100644 tests/tools/importexport/orm/test_codes.py delete mode 100644 tests/tools/importexport/orm/test_comments.py delete mode 100644 tests/tools/importexport/orm/test_computers.py delete mode 100644 tests/tools/importexport/orm/test_extras.py delete mode 100644 tests/tools/importexport/orm/test_groups.py delete mode 100644 tests/tools/importexport/orm/test_links.py delete mode 100644 tests/tools/importexport/orm/test_logs.py delete mode 100644 tests/tools/importexport/orm/test_users.py delete mode 100644 tests/tools/importexport/test_complex.py delete mode 100644 tests/tools/importexport/test_prov_redesign.py delete mode 100644 tests/tools/importexport/test_reader.py delete mode 100644 tests/tools/importexport/test_simple.py delete mode 100644 tests/tools/importexport/test_specific_import.py diff --git a/.gitignore b/.gitignore index 52e6f940fe..cc9eb2c818 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,5 @@ pip-wheel-metadata # Docs docs/build docs/source/reference/apidoc + +pplot_out/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0d58bc026a..e4d79f9b65 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -105,9 +105,7 @@ repos: aiida/repository/.*py| aiida/tools/graph/graph_traversers.py| aiida/tools/groups/paths.py| - aiida/tools/importexport/archive/.*py| - aiida/tools/importexport/dbexport/__init__.py| - aiida/tools/importexport/dbimport/backends/.*.py| + aiida/tools/archive/.*py| )$ - id: pylint diff --git a/aiida/backends/sqlalchemy/models/__init__.py b/aiida/backends/sqlalchemy/models/__init__.py index a3d39148fb..a61b2a1c66 100644 --- a/aiida/backends/sqlalchemy/models/__init__.py +++ b/aiida/backends/sqlalchemy/models/__init__.py @@ -9,6 +9,7 @@ ########################################################################### """Module to define the database models for the SqlAlchemy backend.""" import sqlalchemy as sa +from sqlalchemy.orm import mapper # SqlAlchemy does not set default values for table columns upon construction of a new instance, but will only do so # when storing the instance. 
Any attributes that do not have a value but have a defined default, will be populated with @@ -34,4 +35,4 @@ def instant_defaults_listener(target, _, __): setattr(target, key, column.default.arg) -sa.event.listen(sa.orm.mapper, 'init', instant_defaults_listener) +sa.event.listen(mapper, 'init', instant_defaults_listener) diff --git a/aiida/backends/sqlalchemy/models/group.py b/aiida/backends/sqlalchemy/models/group.py index 1fdb898987..8b3eda563a 100644 --- a/aiida/backends/sqlalchemy/models/group.py +++ b/aiida/backends/sqlalchemy/models/group.py @@ -33,6 +33,7 @@ class DbGroupNode(Base): """Class to store group to nodes relation using SQLA backend.""" + __tablename__ = table_groups_nodes.name __table__ = table_groups_nodes diff --git a/aiida/cmdline/commands/cmd_archive.py b/aiida/cmdline/commands/cmd_archive.py index 33d9519d18..36b651b32c 100644 --- a/aiida/cmdline/commands/cmd_archive.py +++ b/aiida/cmdline/commands/cmd_archive.py @@ -16,18 +16,20 @@ import urllib.request import click +from click_spinner import spinner import tabulate from aiida.cmdline.commands.cmd_verdi import verdi from aiida.cmdline.params import arguments, options from aiida.cmdline.params.types import GroupParamType, PathOrUrl from aiida.cmdline.utils import decorators, echo +from aiida.cmdline.utils.common import get_database_summary from aiida.common.links import GraphTraversalRules from aiida.common.log import AIIDA_LOGGER -EXTRAS_MODE_EXISTING = ['keep_existing', 'update_existing', 'mirror', 'none', 'ask'] +EXTRAS_MODE_EXISTING = ['keep_existing', 'update_existing', 'mirror', 'none'] EXTRAS_MODE_NEW = ['import', 'none'] -COMMENT_MODE = ['newest', 'overwrite'] +COMMENT_MODE = ['leave', 'newest', 'overwrite'] @verdi.group('archive') @@ -39,51 +41,73 @@ def verdi_archive(): @click.argument('archive', nargs=1, type=click.Path(exists=True, readable=True)) @click.option('-v', '--version', is_flag=True, help='Print the archive format version and exit.') @click.option('-m', '--meta-data', is_flag=True, help='Print the meta data contents and exit.') -def inspect(archive, version, meta_data): +@click.option('-d', '--database', is_flag=True, help='Include information on entities in the database.') +def inspect(archive, version, meta_data, database): """Inspect contents of an archive without importing it. - By default a summary of the archive contents will be printed. The various options can be used to change exactly what - information is displayed. + By default a summary of the archive contents will be printed. + The various options can be used to change exactly what information is displayed. 
""" - import dataclasses + from aiida.tools.archive.abstract import get_format + from aiida.tools.archive.exceptions import UnreadableArchiveError - from aiida.tools.importexport import CorruptArchive, detect_archive_type, get_reader + archive_format = get_format() + latest_version = archive_format.latest_version + try: + current_version = archive_format.read_version(archive) + except UnreadableArchiveError as exc: + echo.echo_critical(f'archive file of unknown format: {exc}') - reader_cls = get_reader(detect_archive_type(archive)) + if version: + echo.echo(current_version) + return - with reader_cls(archive) as reader: - try: - if version: - echo.echo(reader.export_version) - elif meta_data: - echo.echo_dictionary(dataclasses.asdict(reader.metadata)) - else: - statistics = { - 'Version aiida': reader.metadata.aiida_version, - 'Version format': reader.metadata.export_version, - 'Computers': reader.entity_count('Computer'), - 'Groups': reader.entity_count('Group'), - 'Links': reader.link_count, - 'Nodes': reader.entity_count('Node'), - 'Users': reader.entity_count('User'), - } - if reader.metadata.conversion_info: - statistics['Conversion info'] = '\n'.join(reader.metadata.conversion_info) - - echo.echo(tabulate.tabulate(statistics.items())) - except CorruptArchive as exception: - echo.echo_critical(f'corrupt archive: {exception}') + if current_version != latest_version: + echo.echo_critical( + f"Archive version is not the latest: '{current_version}' != '{latest_version}'. " + 'Use `verdi migrate` to upgrade to the latest version' + ) + + with archive_format.open(archive, 'r') as archive_reader: + metadata = archive_reader.get_metadata() + + if meta_data: + echo.echo_dictionary(metadata, sort_keys=False) + return + + statistics = { + name: metadata[key] for key, name in [ + ['export_version', 'Version archive'], + ['aiida_version', 'Version aiida'], + ['compression', 'Compression'], + ['ctime', 'Created'], + ['mtime', 'Modified'], + ] if key in metadata + } + if 'conversion_info' in metadata: + statistics['Conversion info'] = '\n'.join(metadata['conversion_info']) + + echo.echo(tabulate.tabulate(statistics.items())) + + if database: + echo.echo('') + echo.echo('Database statistics') + echo.echo('-------------------') + with spinner(): + with archive_format.open(archive, 'r') as archive_reader: + data = get_database_summary(archive_reader.querybuilder, True) + repo = archive_reader.get_backend().get_repository() + data['Repo Files'] = {'count': sum(1 for _ in repo.list_objects())} + echo.echo_dictionary(data, sort_keys=False, fmt='yaml') @verdi_archive.command('create') @arguments.OUTPUT_FILE(type=click.Path(exists=False)) +@options.ALL() @options.CODES() @options.COMPUTERS() @options.GROUPS() @options.NODES() -@options.ARCHIVE_FORMAT( - type=click.Choice(['zip', 'zip-uncompressed', 'zip-lowmemory', 'tar.gz', 'null']), -) @options.FORCE(help='Overwrite output file if it already exists.') @options.graph_traversal_rules(GraphTraversalRules.EXPORT.value) @click.option( @@ -98,23 +122,26 @@ def inspect(archive, version, meta_data): show_default=True, help='Include or exclude comments for node(s) in export. (Will also export extra users who commented).' ) -# will only be useful when moving to a new archive format, that does not store all data in memory -# @click.option( -# '-b', -# '--batch-size', -# default=1000, -# type=int, -# help='Batch database query results in sub-collections to reduce memory usage.' 
-# ) +@click.option( + '--include-authinfos/--exclude-authinfos', + default=False, + show_default=True, + help='Include or exclude authentication information for computer(s) in export.' +) +@click.option('--compress', default=6, show_default=True, type=int, help='Level of compression to use (0-9).') +@click.option( + '-b', '--batch-size', default=1000, type=int, help='Stream database rows in batches, to reduce memory usage.' +) +@click.option('--test-run', is_flag=True, help='Determine entities to export, but do not create the archive.') @decorators.with_dbenv() def create( - output_file, codes, computers, groups, nodes, archive_format, force, input_calc_forward, input_work_forward, - create_backward, return_backward, call_calc_backward, call_work_backward, include_comments, include_logs + output_file, all_entries, codes, computers, groups, nodes, force, input_calc_forward, input_work_forward, + create_backward, return_backward, call_calc_backward, call_work_backward, include_comments, include_logs, + include_authinfos, compress, batch_size, test_run ): - """ - Export subsets of the provenance graph to file for sharing. + """Write subsets of the provenance graph to a single file. - Besides Nodes of the provenance graph, you can export Groups, Codes, Computers, Comments and Logs. + Besides Nodes of the provenance graph, you can archive Groups, Codes, Computers, Comments and Logs. By default, the archive file will include not only the entities explicitly provided via the command line but also their provenance, according to the rules outlined in the documentation. @@ -122,22 +149,28 @@ def create( """ # pylint: disable=too-many-branches from aiida.common.progress_reporter import set_progress_bar_tqdm, set_progress_reporter - from aiida.tools.importexport import ExportFileFormat, export - from aiida.tools.importexport.common.exceptions import ArchiveExportError + from aiida.tools.archive.abstract import get_format + from aiida.tools.archive.create import create_archive + from aiida.tools.archive.exceptions import ArchiveExportError + + archive_format = get_format() - entities = [] + if all_entries: + entities = None + else: + entities = [] - if codes: - entities.extend(codes) + if codes: + entities.extend(codes) - if computers: - entities.extend(computers) + if computers: + entities.extend(computers) - if groups: - entities.extend(groups) + if groups: + entities.extend(groups) - if nodes: - entities.extend(nodes) + if nodes: + entities.extend(nodes) kwargs = { 'input_calc_forward': input_calc_forward, @@ -146,32 +179,22 @@ def create( 'return_backward': return_backward, 'call_calc_backward': call_calc_backward, 'call_work_backward': call_work_backward, + 'include_authinfos': include_authinfos, 'include_comments': include_comments, 'include_logs': include_logs, 'overwrite': force, + 'compression': compress, + 'batch_size': batch_size, + 'test_run': test_run } - if archive_format == 'zip': - export_format = ExportFileFormat.ZIP - kwargs.update({'writer_init': {'use_compression': True}}) - elif archive_format == 'zip-uncompressed': - export_format = ExportFileFormat.ZIP - kwargs.update({'writer_init': {'use_compression': False}}) - elif archive_format == 'zip-lowmemory': - export_format = ExportFileFormat.ZIP - kwargs.update({'writer_init': {'cache_zipinfo': True}}) - elif archive_format == 'tar.gz': - export_format = ExportFileFormat.TAR_GZIPPED - elif archive_format == 'null': - export_format = 'null' - if AIIDA_LOGGER.level <= logging.REPORT: # pylint: disable=no-member - 
set_progress_bar_tqdm(leave=(AIIDA_LOGGER.level == logging.DEBUG)) + set_progress_bar_tqdm(leave=(AIIDA_LOGGER.level <= logging.INFO)) else: set_progress_reporter(None) try: - export(entities, filename=output_file, file_format=export_format, **kwargs) + create_archive(entities, filename=output_file, archive_format=archive_format, **kwargs) except ArchiveExportError as exception: echo.echo_critical(f'failed to write the archive file. Exception: {exception}') else: @@ -181,11 +204,9 @@ def create( @verdi_archive.command('migrate') @arguments.INPUT_FILE() @arguments.OUTPUT_FILE(required=False) -@options.ARCHIVE_FORMAT() @options.FORCE(help='overwrite output file if it already exists') @click.option('-i', '--in-place', is_flag=True, help='Migrate the archive in place, overwriting the original file.') @click.option( - '-v', '--version', type=click.STRING, required=False, @@ -195,11 +216,10 @@ def create( # version inside the function when needed. help='Archive format version to migrate to (defaults to latest version).', ) -def migrate(input_file, output_file, force, in_place, archive_format, version): +def migrate(input_file, output_file, force, in_place, version): """Migrate an export archive to a more recent format version.""" from aiida.common.progress_reporter import set_progress_bar_tqdm, set_progress_reporter - from aiida.tools.importexport import EXPORT_VERSION, detect_archive_type - from aiida.tools.importexport.archive.migrators import get_migrator + from aiida.tools.archive.abstract import get_format if in_place: if output_file: @@ -212,18 +232,17 @@ def migrate(input_file, output_file, force, in_place, archive_format, version): ) if AIIDA_LOGGER.level <= logging.REPORT: # pylint: disable=no-member - set_progress_bar_tqdm(leave=(AIIDA_LOGGER.level == logging.DEBUG)) + set_progress_bar_tqdm(leave=(AIIDA_LOGGER.level <= logging.INFO)) else: set_progress_reporter(None) - if version is None: - version = EXPORT_VERSION + archive_format = get_format() - migrator_cls = get_migrator(detect_archive_type(input_file)) - migrator = migrator_cls(input_file) + if version is None: + version = archive_format.latest_version try: - migrator.migrate(version, output_file, force=force, out_compression=archive_format) + archive_format.migrate(input_file, output_file, version, force=force, compression=6) except Exception as error: # pylint: disable=broad-except if AIIDA_LOGGER.level <= logging.DEBUG: raise @@ -238,11 +257,10 @@ def migrate(input_file, output_file, force, in_place, archive_format, version): class ExtrasImportCode(Enum): """Exit codes for the verdi command line.""" # pylint: disable=invalid-name - keep_existing = 'kcl' - update_existing = 'kcu' - mirror = 'ncu' - none = 'knl' - ask = 'kca' + keep_existing = ('k', 'c', 'l') + update_existing = ('k', 'c', 'u') + mirror = ('n', 'c', 'u') + none = ('k', 'n', 'l') @verdi_archive.command('import') @@ -264,14 +282,13 @@ class ExtrasImportCode(Enum): '-e', '--extras-mode-existing', type=click.Choice(EXTRAS_MODE_EXISTING), - default='keep_existing', + default='none', help='Specify which extras from the export archive should be imported for nodes that are already contained in the ' 'database: ' - 'ask: import all extras and prompt what to do for existing extras. ' + 'none: do not import any extras.' 'keep_existing: import all extras and keep original value of existing extras. ' 'update_existing: import all extras and overwrite value of existing extras. ' 'mirror: import all extras and remove any existing extras that are not present in the archive. 
' - 'none: do not import any extras.' ) @click.option( '-n', @@ -285,22 +302,33 @@ class ExtrasImportCode(Enum): @click.option( '--comment-mode', type=click.Choice(COMMENT_MODE), - default='newest', + default='leave', help='Specify the way to import Comments with identical UUIDs: ' - 'newest: Only the newest Comments (based on mtime) (default).' + 'leave: Leave the existing Comments in the database (default).' + 'newest: Use only the newest Comments (based on mtime).' 'overwrite: Replace existing Comments with those from the import file.' ) +@click.option( + '--include-authinfos/--exclude-authinfos', + default=False, + show_default=True, + help='Include or exclude authentication information for computer(s) in import.' +) @click.option( '--migration/--no-migration', default=True, show_default=True, help='Force migration of archive file archives, if needed.' ) -@options.NON_INTERACTIVE() +@click.option( + '-b', '--batch-size', default=1000, type=int, help='Stream database rows in batches, to reduce memory usage.' +) +@click.option('--test-run', is_flag=True, help='Determine entities to import, but do not actually import them.') @decorators.with_dbenv() @click.pass_context def import_archive( - ctx, archives, webpages, group, extras_mode_existing, extras_mode_new, comment_mode, migration, non_interactive + ctx, archives, webpages, extras_mode_existing, extras_mode_new, comment_mode, include_authinfos, migration, + batch_size, group, test_run ): """Import data from an AiiDA archive file. @@ -310,7 +338,7 @@ def import_archive( from aiida.common.progress_reporter import set_progress_bar_tqdm, set_progress_reporter if AIIDA_LOGGER.level <= logging.REPORT: # pylint: disable=no-member - set_progress_bar_tqdm(leave=(AIIDA_LOGGER.level == logging.DEBUG)) + set_progress_bar_tqdm(leave=(AIIDA_LOGGER.level <= logging.INFO)) else: set_progress_reporter(None) @@ -322,14 +350,17 @@ def import_archive( # Shared import key-word arguments import_kwargs = { + 'import_new_extras': extras_mode_new == 'import', + 'merge_extras': ExtrasImportCode[extras_mode_existing].value, + 'merge_comments': comment_mode, + 'include_authinfos': include_authinfos, + 'batch_size': batch_size, 'group': group, - 'extras_mode_existing': ExtrasImportCode[extras_mode_existing].value, - 'extras_mode_new': extras_mode_new, - 'comment_mode': comment_mode, + 'test_run': test_run, } for archive, web_based in all_archives: - _import_archive(archive, web_based, import_kwargs, migration) + _import_archive_and_migrate(archive, web_based, import_kwargs, migration) def _echo_exception(msg: str, exception, warn_only: bool = False): @@ -340,12 +371,12 @@ def _echo_exception(msg: str, exception, warn_only: bool = False): :param warn_only: If True only print a warning, otherwise calls sys.exit with a non-zero exit status """ - from aiida.tools.importexport import IMPORT_LOGGER + from aiida.tools.archive.imports import IMPORT_LOGGER message = f'{msg}: {exception.__class__.__name__}: {str(exception)}' if warn_only: echo.echo_warning(message) else: - IMPORT_LOGGER.debug('%s', traceback.format_exc()) + IMPORT_LOGGER.info('%s', traceback.format_exc()) echo.echo_critical(message) @@ -355,7 +386,7 @@ def _gather_imports(archives, webpages) -> List[Tuple[str, bool]]: :returns: list of (archive path, whether it is web based) """ - from aiida.tools.importexport.common.utils import get_valid_import_links + from aiida.tools.archive.common import get_valid_import_links final_archives = [] @@ -383,7 +414,7 @@ def _gather_imports(archives, webpages) -> 
List[Tuple[str, bool]]: return final_archives -def _import_archive(archive: str, web_based: bool, import_kwargs: dict, try_migration: bool): +def _import_archive_and_migrate(archive: str, web_based: bool, import_kwargs: dict, try_migration: bool): """Perform the archive import. :param archive: the path or URL to the archive @@ -393,13 +424,11 @@ def _import_archive(archive: str, web_based: bool, import_kwargs: dict, try_migr """ from aiida.common.folders import SandboxFolder - from aiida.tools.importexport import ( - EXPORT_VERSION, - IncompatibleArchiveVersionError, - detect_archive_type, - import_data, - ) - from aiida.tools.importexport.archive.migrators import get_migrator + from aiida.tools.archive.abstract import get_format + from aiida.tools.archive.exceptions import IncompatibleArchiveVersionError + from aiida.tools.archive.imports import import_archive as _import_archive + + archive_format = get_format() with SandboxFolder() as temp_folder: @@ -418,22 +447,21 @@ def _import_archive(archive: str, web_based: bool, import_kwargs: dict, try_migr echo.echo_report(f'starting import: {archive}') try: - import_data(archive_path, **import_kwargs) + _import_archive(archive_path, archive_format=archive_format, **import_kwargs) except IncompatibleArchiveVersionError as exception: if try_migration: echo.echo_report(f'incompatible version detected for {archive}, trying migration') try: - migrator = get_migrator(detect_archive_type(archive_path))(archive_path) - archive_path = migrator.migrate( - EXPORT_VERSION, None, out_compression='none', work_dir=temp_folder.abspath - ) + new_path = temp_folder.get_abs_path('migrated_archive.aiida') + archive_format.migrate(archive_path, new_path, archive_format.latest_version, compression=0) + archive_path = new_path except Exception as exception: _echo_exception(f'an exception occurred while migrating the archive {archive}', exception) echo.echo_report('proceeding with import of migrated archive') try: - import_data(archive_path, **import_kwargs) + _import_archive(archive_path, archive_format=archive_format, **import_kwargs) except Exception as exception: _echo_exception( f'an exception occurred while trying to import the migrated archive {archive}', exception diff --git a/aiida/tools/__init__.py b/aiida/tools/__init__.py index cb4615adb7..9a055c8969 100644 --- a/aiida/tools/__init__.py +++ b/aiida/tools/__init__.py @@ -29,67 +29,26 @@ from .data import * from .graph import * from .groups import * -from .importexport import * from .visualization import * __all__ = ( - 'ARCHIVE_READER_LOGGER', - 'ArchiveExportError', - 'ArchiveImportError', - 'ArchiveMetadata', - 'ArchiveMigrationError', - 'ArchiveMigratorAbstract', - 'ArchiveMigratorJsonBase', - 'ArchiveMigratorJsonTar', - 'ArchiveMigratorJsonZip', - 'ArchiveReaderAbstract', - 'ArchiveWriterAbstract', - 'CacheFolder', 'CalculationTools', - 'CorruptArchive', 'DELETE_LOGGER', - 'DanglingLinkError', - 'EXPORT_LOGGER', - 'EXPORT_VERSION', - 'ExportFileFormat', - 'ExportImportException', - 'ExportValidationError', 'Graph', 'GroupNotFoundError', 'GroupNotUniqueError', 'GroupPath', - 'IMPORT_LOGGER', - 'ImportUniquenessError', - 'ImportValidationError', - 'IncompatibleArchiveVersionError', 'InvalidPath', - 'MIGRATE_LOGGER', - 'MigrationValidationError', 'NoGroupsInPathError', 'Orbital', - 'ProgressBarError', - 'ReaderJsonBase', - 'ReaderJsonFolder', - 'ReaderJsonTar', - 'ReaderJsonZip', 'RealhydrogenOrbital', - 'WriterJsonFolder', - 'WriterJsonTar', - 'WriterJsonZip', 'default_link_styles', 'default_node_styles', 
'default_node_sublabels', 'delete_group_nodes', 'delete_nodes', - 'detect_archive_type', - 'export', 'get_explicit_kpoints_path', 'get_kpoints_path', - 'get_migrator', - 'get_reader', - 'get_writer', - 'import_data', - 'null_callback', 'pstate_node_styles', 'spglib_tuple_to_structure', 'structure_to_spglib_tuple', diff --git a/aiida/tools/importexport/common/__init__.py b/aiida/tools/archive/__init__.py similarity index 68% rename from aiida/tools/importexport/common/__init__.py rename to aiida/tools/archive/__init__.py index f2755eade4..4252c80745 100644 --- a/aiida/tools/importexport/common/__init__.py +++ b/aiida/tools/archive/__init__.py @@ -7,30 +7,44 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -"""Common utility functions, classes, and exceptions""" +"""The AiiDA archive allows export/import, +of subsets of the provenance graph, to a single file +""" # AUTO-GENERATED # yapf: disable # pylint: disable=wildcard-import -from .config import * +from .abstract import * +from .common import * +from .create import * from .exceptions import * +from .implementations import * +from .imports import * __all__ = ( 'ArchiveExportError', + 'ArchiveFormatAbstract', + 'ArchiveFormatSqlZip', 'ArchiveImportError', 'ArchiveMigrationError', + 'ArchiveReaderAbstract', + 'ArchiveWriterAbstract', 'CorruptArchive', - 'DanglingLinkError', - 'EXPORT_VERSION', + 'EXPORT_LOGGER', 'ExportImportException', 'ExportValidationError', + 'IMPORT_LOGGER', + 'ImportTestRun', 'ImportUniquenessError', 'ImportValidationError', 'IncompatibleArchiveVersionError', + 'MIGRATE_LOGGER', 'MigrationValidationError', - 'ProgressBarError', + 'create_archive', + 'get_format', + 'import_archive', ) # yapf: enable diff --git a/aiida/tools/archive/abstract.py b/aiida/tools/archive/abstract.py new file mode 100644 index 0000000000..53281afa5e --- /dev/null +++ b/aiida/tools/archive/abstract.py @@ -0,0 +1,296 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Abstraction for an archive file format.""" +from abc import ABC, abstractmethod +from pathlib import Path +from typing import TYPE_CHECKING, Any, BinaryIO, Dict, List, Optional, Type, TypeVar, Union, overload + +try: + from typing import Literal # pylint: disable=ungrouped-imports +except ImportError: + # Python <3.8 backport + from typing_extensions import Literal # type: ignore + +if TYPE_CHECKING: + from aiida.orm import QueryBuilder + from aiida.orm.entities import Entity, EntityTypes + from aiida.orm.implementation import Backend + from aiida.tools.visualization.graph import Graph + +SelfType = TypeVar('SelfType') +EntityType = TypeVar('EntityType', bound='Entity') + +__all__ = ('ArchiveFormatAbstract', 'ArchiveReaderAbstract', 'ArchiveWriterAbstract', 'get_format') + + +class ArchiveWriterAbstract(ABC): + """Writer of an archive, that will be used as a context manager.""" + + def __init__( + self, + path: Union[str, Path], + fmt: 'ArchiveFormatAbstract', + *, + mode: Literal['x', 'w', 'a'] = 'x', + compression: int = 6, + **kwargs: Any + ): + """Initialise the writer. + + :param path: archive path + :param mode: mode to open the archive in: 'x' (exclusive), 'w' (write) or 'a' (append) + :param compression: default level of compression to use (integer from 0 to 9) + """ + self._path = Path(path) + if mode not in ('x', 'w', 'a'): + raise ValueError(f'mode not in x, w, a: {mode}') + self._mode = mode + if compression not in range(10): + raise ValueError(f'compression not in range 0-9: {compression}') + self._compression = compression + self._format = fmt + + @property + def path(self) -> Path: + """Return the path to the archive.""" + return self._path + + @property + def mode(self) -> Literal['x', 'w', 'a']: + """Return the mode of the archive.""" + return self._mode + + @property + def compression(self) -> int: + """Return the compression level.""" + return self._compression + + def __enter__(self: SelfType) -> SelfType: + """Start writing to the archive.""" + return self + + def __exit__(self, *args, **kwargs) -> None: + """Finalise the archive.""" + + @abstractmethod + def update_metadata(self, data: Dict[str, Any], overwrite: bool = False) -> None: + """Add key, values to the top-level metadata.""" + + @abstractmethod + def bulk_insert( + self, + entity_type: 'EntityTypes', + rows: List[Dict[str, Any]], + allow_defaults: bool = False, + ) -> None: + """Add multiple rows of entity data to the archive. + + :param entity_type: The type of the entity + :param data: A list of dictionaries, containing all fields of the backend model, + except the `id` field (a.k.a primary key), which will be generated dynamically + :param allow_defaults: If ``False``, assert that each row contains all fields, + otherwise, allow default values for missing fields. + + :raises: ``IntegrityError`` if the keys in a row are not a subset of the columns in the table + """ + + @abstractmethod + def put_object(self, stream: BinaryIO, *, buffer_size: Optional[int] = None, key: Optional[str] = None) -> str: + """Add an object to the archive. 
+ + :param stream: byte stream to read the object from + :param buffer_size: Number of bytes to buffer when read/writing + :param key: key to use for the object (if None will be auto-generated) + :return: the key of the object + """ + + @abstractmethod + def delete_object(self, key: str) -> None: + """Delete the object from the archive. + + :param key: fully qualified identifier for the object within the repository. + :raise IOError: if the file could not be deleted. + """ + + +class ArchiveReaderAbstract(ABC): + """Reader of an archive, that will be used as a context manager.""" + + def __init__(self, path: Union[str, Path], **kwargs: Any): + """Initialise the reader. + + :param path: archive path + """ + self._path = Path(path) + + @property + def path(self): + """Return the path to the archive.""" + return self._path + + def __enter__(self: SelfType) -> SelfType: + """Start reading from the archive.""" + return self + + def __exit__(self, *args, **kwargs) -> None: + """Finalise the archive.""" + + @abstractmethod + def get_metadata(self) -> Dict[str, Any]: + """Return the top-level metadata. + + :raises: ``UnreadableArchiveError`` if the top-level metadata cannot be read from the archive + """ + + @abstractmethod + def get_backend(self) -> 'Backend': + """Return a 'read-only' backend for the archive.""" + + # below are convenience methods for some common use cases + + def querybuilder(self, **kwargs: Any) -> 'QueryBuilder': + """Return a ``QueryBuilder`` instance, initialised with the archive backend.""" + from aiida.orm import QueryBuilder + return QueryBuilder(backend=self.get_backend(), **kwargs) + + def get(self, entity_cls: Type[EntityType], **filters: Any) -> EntityType: + """Return the entity for the given filters. + + Example:: + + reader.get(orm.Node, pk=1) + + :param entity_cls: The type of the front-end entity + :param filters: the filters identifying the object to get + """ + if 'pk' in filters: + filters['id'] = filters.pop('pk') + return self.querybuilder().append(entity_cls, filters=filters).one()[0] + + def graph(self, **kwargs: Any) -> 'Graph': + """Return a provenance graph generator for the archive.""" + from aiida.tools.visualization.graph import Graph + return Graph(backend=self.get_backend(), **kwargs) + + +class ArchiveFormatAbstract(ABC): + """Abstract class for an archive format.""" + + @property + @abstractmethod + def versions(self) -> List[str]: + """Return ordered list of versions of the archive format, oldest -> latest.""" + + @property + def latest_version(self) -> str: + """Return the latest version of the archive format.""" + return self.versions[-1] + + @property + @abstractmethod + def key_format(self) -> str: + """Return the format of repository keys.""" + + @abstractmethod + def read_version(self, path: Union[str, Path]) -> str: + """Read the version of the archive from a file. + + This method should account for reading all versions of the archive format. + + :param path: archive path + + :raises: ``FileNotFoundError`` if the file does not exist + :raises: ``UnreadableArchiveError`` if a version cannot be read from the archive + """ + + @overload + @abstractmethod + def open( + self, + path: Union[str, Path], + mode: Literal['r'], + *, + compression: int = 6, + **kwargs: Any + ) -> ArchiveReaderAbstract: + ... + + @overload + @abstractmethod + def open( + self, + path: Union[str, Path], + mode: Literal['x', 'w'], + *, + compression: int = 6, + **kwargs: Any + ) -> ArchiveWriterAbstract: + ... 
+ + @overload + @abstractmethod + def open( + self, + path: Union[str, Path], + mode: Literal['a'], + *, + compression: int = 6, + **kwargs: Any + ) -> ArchiveWriterAbstract: + ... + + @abstractmethod + def open( + self, + path: Union[str, Path], + mode: Literal['r', 'x', 'w', 'a'] = 'r', + *, + compression: int = 6, + **kwargs: Any + ) -> Union[ArchiveReaderAbstract, ArchiveWriterAbstract]: + """Open an archive (latest version only). + + :param path: archive path + :param mode: open mode: 'r' (read), 'x' (exclusive write), 'w' (write) or 'a' (append) + :param compression: default level of compression to use for writing (integer from 0 to 9) + + Note, in write mode, the writer is responsible for writing the format version. + """ + + @abstractmethod + def migrate( + self, + inpath: Union[str, Path], + outpath: Union[str, Path], + version: str, + *, + force: bool = False, + compression: int = 6 + ) -> None: + """Migrate an archive to a specific version. + + :param inpath: input archive path + :param outpath: output archive path + :param version: version to migrate to + :param force: allow overwrite of existing output archive path + :param compression: default level of compression to use for writing (integer from 0 to 9) + """ + + +def get_format(name: str = 'sqlitezip') -> ArchiveFormatAbstract: + """Get the archive format instance. + + :param name: name of the archive format + :return: archive format instance + """ + # to-do entry point for archive formats? + assert name == 'sqlitezip' + from aiida.tools.archive.implementations.sqlite.main import ArchiveFormatSqlZip + return ArchiveFormatSqlZip() diff --git a/aiida/tools/importexport/common/utils.py b/aiida/tools/archive/common.py similarity index 54% rename from aiida/tools/importexport/common/utils.py rename to aiida/tools/archive/common.py index 927546a027..a6bdce8094 100644 --- a/aiida/tools/importexport/common/utils.py +++ b/aiida/tools/archive/common.py @@ -7,52 +7,54 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -""" Utility functions for import/export of AiiDA entities """ -# pylint: disable=too-many-branches,too-many-return-statements,too-many-nested-blocks,too-many-locals +"""Shared resources for the archive.""" from html.parser import HTMLParser +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type import urllib.parse import urllib.request -from aiida.tools.importexport.common.config import ( - COMMENT_ENTITY_NAME, - COMPUTER_ENTITY_NAME, - GROUP_ENTITY_NAME, - LOG_ENTITY_NAME, - NODE_ENTITY_NAME, - USER_ENTITY_NAME, -) +from aiida.common.log import AIIDA_LOGGER +from aiida.orm import AuthInfo, Comment, Computer, Entity, Group, Log, Node, User +from aiida.orm.entities import EntityTypes +__all__ = ('MIGRATE_LOGGER',) -def schema_to_entity_names(class_string): - """ - Mapping from classes path to entity names (used by the SQLA import/export) - This could have been written much simpler if it is only for SQLA but there - is an attempt the SQLA import/export code to be used for Django too. 
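Note (illustrative, not part of the patch): a minimal sketch of how the archive format abstraction defined in abstract.py above is intended to be used, assuming a hypothetical archive file 'export.aiida'. get_format() currently always returns the sqlitezip implementation; reading, querying and migrating all go through the ArchiveFormatAbstract / ArchiveReaderAbstract interfaces shown above.

from aiida import orm
from aiida.tools.archive.abstract import get_format

fmt = get_format()  # the 'sqlitezip' ArchiveFormatSqlZip instance
print(fmt.latest_version)

# read-only access: open() in mode 'r' returns an ArchiveReaderAbstract context manager
with fmt.open('export.aiida', mode='r') as reader:
    metadata = reader.get_metadata()                  # top-level archive metadata
    node = reader.get(orm.Node, pk=1)                 # convenience lookup (pk is hypothetical)
    count = reader.querybuilder().append(orm.Node).count()  # QueryBuilder over the archive backend

# migrate an archive file to the latest format version
fmt.migrate('old.aiida', 'new.aiida', fmt.latest_version, compression=6)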
- """ - if class_string is None: - return None +MIGRATE_LOGGER = AIIDA_LOGGER.getChild('migrate') - if class_string in ('aiida.backends.djsite.db.models.DbNode', 'aiida.backends.sqlalchemy.models.node.DbNode'): - return NODE_ENTITY_NAME +# Mapping from entity names to AiiDA classes +entity_type_to_orm: Dict[EntityTypes, Type[Entity]] = { + EntityTypes.AUTHINFO: AuthInfo, + EntityTypes.GROUP: Group, + EntityTypes.COMPUTER: Computer, + EntityTypes.USER: User, + EntityTypes.LOG: Log, + EntityTypes.NODE: Node, + EntityTypes.COMMENT: Comment, +} - if class_string in ('aiida.backends.djsite.db.models.DbGroup', 'aiida.backends.sqlalchemy.models.group.DbGroup'): - return GROUP_ENTITY_NAME - if class_string in ( - 'aiida.backends.djsite.db.models.DbComputer', 'aiida.backends.sqlalchemy.models.computer.DbComputer' - ): - return COMPUTER_ENTITY_NAME +def batch_iter(iterable: Iterable[Any], + size: int, + transform: Optional[Callable[[Any], Any]] = None) -> Iterable[Tuple[int, List[Any]]]: + """Yield an iterable in batches of a set number of items. - if class_string in ('aiida.backends.djsite.db.models.DbUser', 'aiida.backends.sqlalchemy.models.user.DbUser'): - return USER_ENTITY_NAME + Note, the final yield may be less than this size. - if class_string in ('aiida.backends.djsite.db.models.DbLog', 'aiida.backends.sqlalchemy.models.log.DbLog'): - return LOG_ENTITY_NAME - - if class_string in ( - 'aiida.backends.djsite.db.models.DbComment', 'aiida.backends.sqlalchemy.models.comment.DbComment' - ): - return COMMENT_ENTITY_NAME + :param transform: a transform to apply to each item + :returns: (number of items, list of items) + """ + transform = transform or (lambda x: x) + current = [] + length = 0 + for item in iterable: + current.append(transform(item)) + length += 1 + if length >= size: + yield length, current + current = [] + length = 0 + if current: + yield length, current class HTMLGetLinksParser(HTMLParser): @@ -63,7 +65,7 @@ class HTMLGetLinksParser(HTMLParser): # pylint: disable=abstract-method - def __init__(self, filter_extension=None): # pylint: disable=super-on-old-class + def __init__(self, filter_extension=None): self.filter_extension = filter_extension self.links = [] super().__init__() diff --git a/aiida/tools/archive/create.py b/aiida/tools/archive/create.py new file mode 100644 index 0000000000..bd7fc49db1 --- /dev/null +++ b/aiida/tools/archive/create.py @@ -0,0 +1,678 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=too-many-locals,too-many-branches,too-many-statements +"""Create an AiiDA archive. + +The archive is a subset of the provenance graph, +stored in a single file. 
+""" +from datetime import datetime +from pathlib import Path +import shutil +import tempfile +from typing import Callable, Dict, Iterable, List, Optional, Sequence, Set, Tuple, Union + +from tabulate import tabulate + +from aiida import orm +from aiida.common.exceptions import LicensingException +from aiida.common.lang import type_check +from aiida.common.links import GraphTraversalRules +from aiida.common.log import AIIDA_LOGGER +from aiida.common.progress_reporter import get_progress_reporter +from aiida.manage.manager import get_manager +from aiida.orm.entities import EntityTypes +from aiida.orm.implementation import Backend +from aiida.orm.utils.links import LinkQuadruple +from aiida.tools.graph.graph_traversers import get_nodes_export, validate_traversal_rules + +from .abstract import ArchiveFormatAbstract, ArchiveWriterAbstract +from .common import batch_iter, entity_type_to_orm +from .exceptions import ArchiveExportError, ExportValidationError +from .implementations.sqlite import ArchiveFormatSqlZip + +__all__ = ('create_archive', 'EXPORT_LOGGER') + +EXPORT_LOGGER = AIIDA_LOGGER.getChild('export') +QbType = Callable[[], orm.QueryBuilder] + + +def create_archive( + entities: Optional[Iterable[Union[orm.Computer, orm.Node, orm.Group, orm.User]]], + filename: Union[None, str, Path] = None, + *, + archive_format: Optional[ArchiveFormatAbstract] = None, + overwrite: bool = False, + include_comments: bool = True, + include_logs: bool = True, + include_authinfos: bool = False, + allowed_licenses: Optional[Union[list, Callable]] = None, + forbidden_licenses: Optional[Union[list, Callable]] = None, + batch_size: int = 1000, + compression: int = 6, + test_run: bool = False, + backend: Optional[Backend] = None, + **traversal_rules: bool +) -> Path: + """Export AiiDA data to an archive file. + + The export follows the following logic: + + First gather all entity primary keys (per type) that needs to be exported. + This need to proceed in the "reverse" order of relationships: + + - groups: input groups + - group_to_nodes: from nodes in groups + - nodes & links: from graph_traversal(input nodes & group_to_nodes) + - computers: from input computers & computers of nodes + - authinfos: from authinfos of computers + - comments: from comments of nodes + - logs: from logs of nodes + - users: from users of nodes, groups, comments & authinfos + + Now stream the full entities (per type) to the archive writer, + in the order of relationships: + + - users + - computers + - authinfos + - groups + - nodes + - comments + - logs + - group_to_nodes + - links + + Finally stream the repository files, + for the exported nodes, to the archive writer. + + Note, the logging level and progress reporter should be set externally, for example:: + + from aiida.common.progress_reporter import set_progress_bar_tqdm + + EXPORT_LOGGER.setLevel('DEBUG') + set_progress_bar_tqdm(leave=True) + create_archive(...) + + :param entities: If ``None``, import all entities, + or a list of entity instances that can include Computers, Groups, and Nodes. + + :param filename: the filename (possibly including the absolute path) + of the file on which to export. + + :param overwrite: if True, overwrite the output file without asking, if it exists. + If False, raise an + :py:class:`~aiida.tools.archive.exceptions.ArchiveExportError` + if the output file already exists. + + :param allowed_licenses: List or function. + If a list, then checks whether all licenses of Data nodes are in the list. 
If a function, + then calls function for licenses of Data nodes expecting True if license is allowed, False + otherwise. + + :param forbidden_licenses: List or function. If a list, + then checks whether all licenses of Data nodes are in the list. If a function, + then calls function for licenses of Data nodes expecting True if license is allowed, False + otherwise. + + :param include_comments: In-/exclude export of comments for given node(s) in ``entities``. + Default: True, *include* comments in export (as well as relevant users). + + :param include_logs: In-/exclude export of logs for given node(s) in ``entities``. + Default: True, *include* logs in export. + + :param compression: level of compression to use (integer from 0 to 9) + + :param batch_size: batch database query results in sub-collections to reduce memory usage + + :param test_run: if True, do not write to file + + :param backend: the backend to export from. If not specified, the default backend is used. + + :param traversal_rules: graph traversal rules. See :const:`aiida.common.links.GraphTraversalRules` + what rule names are toggleable and what the defaults are. + + :raises `~aiida.tools.archive.exceptions.ArchiveExportError`: + if there are any internal errors when exporting. + :raises `~aiida.common.exceptions.LicensingException`: + if any node is licensed under forbidden license. + + """ + # check the backend + backend = backend or get_manager().get_backend() + type_check(backend, Backend) + # create a function to get a query builder instance for the backend + querybuilder = lambda: orm.QueryBuilder(backend=backend) + + # check/set archive file path + type_check(filename, (str, Path), allow_none=True) + if filename is None: + filename = Path.cwd() / 'export_data.aiida' + filename = Path(filename) + if not overwrite and filename.exists(): + raise ArchiveExportError(f"The output file '{filename}' already exists") + if filename.exists() and not filename.is_file(): + raise ArchiveExportError(f"The output file '{filename}' exists as a directory") + + if compression not in range(10): + raise ArchiveExportError('compression must be an integer between 0 and 9') + + # check file format + archive_format = archive_format or ArchiveFormatSqlZip() + type_check(archive_format, ArchiveFormatAbstract) + + # check traversal rules + validate_traversal_rules(GraphTraversalRules.EXPORT, **traversal_rules) + full_traversal_rules = { + name: traversal_rules.get(name, rule.default) for name, rule in GraphTraversalRules.EXPORT.value.items() + } + + initial_summary = get_init_summary( + archive_version=archive_format.latest_version, + outfile=filename, + collect_all=entities is None, + include_authinfos=include_authinfos, + include_comments=include_comments, + include_logs=include_logs, + traversal_rules=full_traversal_rules, + compression=compression + ) + EXPORT_LOGGER.report(initial_summary) + + # Store starting UUIDs, to write to metadata + starting_uuids: Dict[EntityTypes, Set[str]] = { + EntityTypes.USER: set(), + EntityTypes.COMPUTER: set(), + EntityTypes.GROUP: set(), + EntityTypes.NODE: set() + } + + # Store all entity IDs to be written to the archive + # Note, this is the order they will be written to the archive + entity_ids: Dict[EntityTypes, Set[int]] = { + ent: set() for ent in [ + EntityTypes.USER, + EntityTypes.COMPUTER, + EntityTypes.AUTHINFO, + EntityTypes.GROUP, + EntityTypes.NODE, + EntityTypes.COMMENT, + EntityTypes.LOG, + ] + } + + # extract ids/uuid from initial entities + type_check(entities, Iterable, allow_none=True) + 
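Note (illustrative, not part of the patch): a hedged sketch of calling create_archive directly with the parameters documented above; the group label, node pk and file names are hypothetical.

from aiida import orm
from aiida.tools.archive import create_archive

group = orm.load_group('my_group')   # hypothetical group label
node = orm.load_node(1234)           # hypothetical node pk

# archive a selection of entities, plus their provenance per the traversal rules
create_archive(
    [group, node],
    filename='export.aiida',
    include_comments=True,
    include_logs=False,
    compression=6,
)

# entities=None archives the full database contents
create_archive(None, filename='full_backup.aiida', overwrite=True)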
if entities is None: + group_nodes, link_data = _collect_all_entities( + querybuilder, entity_ids, include_authinfos, include_comments, include_logs, batch_size + ) + else: + for entry in entities: + if isinstance(entry, orm.Group): + starting_uuids[EntityTypes.GROUP].add(entry.uuid) + entity_ids[EntityTypes.GROUP].add(entry.pk) + elif isinstance(entry, orm.Node): + starting_uuids[EntityTypes.NODE].add(entry.uuid) + entity_ids[EntityTypes.NODE].add(entry.pk) + elif isinstance(entry, orm.Computer): + starting_uuids[EntityTypes.COMPUTER].add(entry.uuid) + entity_ids[EntityTypes.COMPUTER].add(entry.pk) + elif isinstance(entry, orm.User): + starting_uuids[EntityTypes.USER].add(entry.email) + entity_ids[EntityTypes.USER].add(entry.pk) + else: + raise ArchiveExportError( + f'I was given {entry} ({type(entry)}),' + ' which is not a User, Node, Computer, or Group instance' + ) + group_nodes, link_data = _collect_required_entities( + querybuilder, entity_ids, traversal_rules, include_authinfos, include_comments, include_logs, backend, + batch_size + ) + + # now all the nodes have been retrieved, perform some checks + if entity_ids[EntityTypes.NODE]: + EXPORT_LOGGER.report('Validating Nodes') + _check_unsealed_nodes(querybuilder, entity_ids[EntityTypes.NODE], batch_size) + _check_node_licenses( + querybuilder, entity_ids[EntityTypes.NODE], allowed_licenses, forbidden_licenses, batch_size + ) + + # get a count of entities, to report + entity_counts = {etype.value: len(ids) for etype, ids in entity_ids.items()} + entity_counts[EntityTypes.LINK.value] = len(link_data) + entity_counts[EntityTypes.GROUP_NODE.value] = len(group_nodes) + count_summary = [[(name + 's'), num] for name, num in entity_counts.items() if num] + + if test_run: + EXPORT_LOGGER.report('Test Run: Stopping before archive creation') + keys = set( + orm.Node.objects(backend).iter_repo_keys( + filters={'id': { + 'in': list(entity_ids[EntityTypes.NODE]) + }}, batch_size=batch_size + ) + ) + count_summary.append(['Repository Files', len(keys)]) + EXPORT_LOGGER.report(f'Archive would be created with:\n{tabulate(count_summary)}') + return filename + + EXPORT_LOGGER.report(f'Creating archive with:\n{tabulate(count_summary)}') + + # Create and open the archive for writing. 
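Note (illustrative, not part of the patch): the test_run branch above reports the entity and repository-file counts and returns without writing anything to disk, so it can serve as a dry run before a potentially large export:

from aiida.tools.archive import create_archive

# counts are reported via EXPORT_LOGGER; no archive file is created
path = create_archive(None, filename='would_be.aiida', test_run=True)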
+ # We create in a temp dir then move to final place at end, + # so that the user cannot end up with a half written archive on errors + with tempfile.TemporaryDirectory() as tmpdir: + tmp_filename = Path(tmpdir) / 'export.zip' + with archive_format.open(tmp_filename, mode='x', compression=compression) as writer: + # add metadata + writer.update_metadata({ + 'ctime': datetime.now().isoformat(), + 'creation_parameters': { + 'entities_starting_set': + {etype.value: list(unique) for etype, unique in starting_uuids.items() if unique}, + 'include_authinfos': include_authinfos, + 'include_comments': include_comments, + 'include_logs': include_logs, + 'graph_traversal_rules': full_traversal_rules, + 'entity_counts': dict(count_summary), # type: ignore + } + }) + # stream entity data to the archive + with get_progress_reporter()(desc='Archiving database: ', total=sum(entity_counts.values())) as progress: + for etype, ids in entity_ids.items(): + transform = lambda row: row['entity'] + progress.set_description_str(f'Archiving database: {etype.value}s') + if ids: + for nrows, rows in batch_iter( + querybuilder().append( + entity_type_to_orm[etype], filters={ + 'id': { + 'in': ids + } + }, tag='entity', project=['**'] + ).iterdict(batch_size=batch_size), batch_size, transform + ): + writer.bulk_insert(etype, rows) + progress.update(nrows) + + # stream links + progress.set_description_str(f'Archiving database: {EntityTypes.LINK.value}s') + transform = lambda d: { + 'input_id': d.source_id, + 'output_id': d.target_id, + 'label': d.link_label, + 'type': d.link_type + } + for nrows, rows in batch_iter(link_data, batch_size, transform): + writer.bulk_insert(EntityTypes.LINK, rows, allow_defaults=True) + progress.update(nrows) + del link_data # release memory + + # stream group_nodes + progress.set_description_str(f'Archiving database: {EntityTypes.GROUP_NODE.value}s') + transform = lambda d: {'dbgroup_id': d[0], 'dbnode_id': d[1]} + for nrows, rows in batch_iter(group_nodes, batch_size, transform): + writer.bulk_insert(EntityTypes.GROUP_NODE, rows, allow_defaults=True) + progress.update(nrows) + del group_nodes # release memory + + # stream node repository files to the archive + if entity_ids[EntityTypes.NODE]: + _stream_repo_files(archive_format.key_format, writer, entity_ids[EntityTypes.NODE], backend, batch_size) + + EXPORT_LOGGER.report('Finalizing archive creation...') + + if filename.exists(): + filename.unlink() + shutil.move(tmp_filename, filename) # type: ignore + + EXPORT_LOGGER.report('Archive created successfully') + + return filename + + +def _collect_all_entities( + querybuilder: QbType, entity_ids: Dict[EntityTypes, Set[int]], include_authinfos: bool, include_comments: bool, + include_logs: bool, batch_size: int +) -> Tuple[List[list], Set[LinkQuadruple]]: + """Collect all entities. 
+ + :returns: (group_id_to_node_id, link_data) and updates entity_ids + """ + progress_str = lambda name: f'Collecting entities: {name}' + with get_progress_reporter()(desc=progress_str(''), total=9) as progress: + + progress.set_description_str(progress_str('Nodes')) + entity_ids[EntityTypes.NODE].update( + querybuilder().append(orm.Node, project='id').all(batch_size=batch_size, flat=True) + ) + progress.update() + + progress.set_description_str(progress_str('Links')) + progress.update() + qbuilder = querybuilder().append(orm.Node, tag='incoming', project=[ + 'id' + ]).append(orm.Node, with_incoming='incoming', project=['id'], edge_project=['type', 'label']).distinct() + link_data = {LinkQuadruple(*row) for row in qbuilder.all(batch_size=batch_size)} + + progress.set_description_str(progress_str('Groups')) + progress.update() + entity_ids[EntityTypes.GROUP].update( + querybuilder().append(orm.Group, project='id').all(batch_size=batch_size, flat=True) + ) + progress.set_description_str(progress_str('Nodes-Groups')) + progress.update() + qbuilder = querybuilder().append(orm.Group, project='id', + tag='group').append(orm.Node, with_group='group', project='id').distinct() + group_nodes = qbuilder.all(batch_size=batch_size) + + progress.set_description_str(progress_str('Computers')) + progress.update() + entity_ids[EntityTypes.COMPUTER].update( + querybuilder().append(orm.Computer, project='id').all(batch_size=batch_size, flat=True) + ) + + progress.set_description_str(progress_str('AuthInfos')) + progress.update() + if include_authinfos: + entity_ids[EntityTypes.AUTHINFO].update( + querybuilder().append(orm.AuthInfo, project='id').all(batch_size=batch_size, flat=True) + ) + + progress.set_description_str(progress_str('Logs')) + progress.update() + if include_logs: + entity_ids[EntityTypes.LOG].update( + querybuilder().append(orm.Log, project='id').all(batch_size=batch_size, flat=True) + ) + + progress.set_description_str(progress_str('Comments')) + progress.update() + if include_comments: + entity_ids[EntityTypes.COMMENT].update( + querybuilder().append(orm.Comment, project='id').all(batch_size=batch_size, flat=True) + ) + + progress.set_description_str(progress_str('Users')) + progress.update() + entity_ids[EntityTypes.USER].update( + querybuilder().append(orm.User, project='id').all(batch_size=batch_size, flat=True) + ) + + return group_nodes, link_data + + +def _collect_required_entities( + querybuilder: QbType, entity_ids: Dict[EntityTypes, Set[int]], traversal_rules: Dict[str, bool], + include_authinfos: bool, include_comments: bool, include_logs: bool, backend: Backend, batch_size: int +) -> Tuple[List[list], Set[LinkQuadruple]]: + """Collect required entities, given a set of starting entities and provenance graph traversal rules. 
+ + :returns: (group_id_to_node_id, link_data) and updates entity_ids + """ + progress_str = lambda name: f'Collecting entities: {name}' + with get_progress_reporter()(desc=progress_str(''), total=7) as progress: + + # get all nodes from groups + progress.set_description_str(progress_str('Nodes (groups)')) + group_nodes = [] + if entity_ids[EntityTypes.GROUP]: + qbuilder = querybuilder() + qbuilder.append( + orm.Group, filters={'id': { + 'in': list(entity_ids[EntityTypes.GROUP]) + }}, project='id', tag='group' + ) + qbuilder.append(orm.Node, with_group='group', project='id') + qbuilder.distinct() + group_nodes = qbuilder.all(batch_size=batch_size) + entity_ids[EntityTypes.NODE].update(nid for _, nid in group_nodes) + + # get full set of nodes & links, following traversal rules + progress.set_description_str(progress_str('Nodes (traversal)')) + progress.update() + traverse_output = get_nodes_export( + starting_pks=entity_ids[EntityTypes.NODE], get_links=True, backend=backend, **traversal_rules + ) + entity_ids[EntityTypes.NODE].update(traverse_output.pop('nodes')) + link_data = traverse_output.pop('links') or set() # possible memory hog? + + progress.set_description_str(progress_str('Computers')) + progress.update() + + # get full set of computers + if entity_ids[EntityTypes.NODE]: + entity_ids[EntityTypes.COMPUTER].update( + pk for pk, in querybuilder().append( + orm.Node, filters={ + 'id': { + 'in': list(entity_ids[EntityTypes.NODE]) + } + }, tag='node' + ).append(orm.Computer, with_node='node', project='id').distinct().iterall(batch_size=batch_size) + ) + + # get full set of authinfos + progress.set_description_str(progress_str('AuthInfos')) + progress.update() + if include_authinfos and entity_ids[EntityTypes.COMPUTER]: + entity_ids[EntityTypes.AUTHINFO].update( + pk for pk, in querybuilder().append( + orm.Computer, filters={ + 'id': { + 'in': list(entity_ids[EntityTypes.COMPUTER]) + } + }, tag='comp' + ).append(orm.AuthInfo, with_computer='comp', project='id').distinct().iterall(batch_size=batch_size) + ) + + # get full set of logs + progress.set_description_str(progress_str('Logs')) + progress.update() + if include_logs and entity_ids[EntityTypes.NODE]: + entity_ids[EntityTypes.LOG].update( + pk for pk, in querybuilder().append( + orm.Node, filters={ + 'id': { + 'in': list(entity_ids[EntityTypes.NODE]) + } + }, tag='node' + ).append(orm.Log, with_node='node', project='id').distinct().iterall(batch_size=batch_size) + ) + + # get full set of comments + progress.set_description_str(progress_str('Comments')) + progress.update() + if include_comments and entity_ids[EntityTypes.NODE]: + entity_ids[EntityTypes.COMMENT].update( + pk for pk, in querybuilder().append( + orm.Node, filters={ + 'id': { + 'in': list(entity_ids[EntityTypes.NODE]) + } + }, tag='node' + ).append(orm.Comment, with_node='node', project='id').distinct().iterall(batch_size=batch_size) + ) + + # get full set of users + progress.set_description_str(progress_str('Users')) + progress.update() + if entity_ids[EntityTypes.NODE]: + entity_ids[EntityTypes.USER].update( + pk for pk, in querybuilder().append( + orm.Node, filters={ + 'id': { + 'in': list(entity_ids[EntityTypes.NODE]) + } + }, tag='node' + ).append(orm.User, with_node='node', project='id').distinct().iterall(batch_size=batch_size) + ) + if entity_ids[EntityTypes.GROUP]: + entity_ids[EntityTypes.USER].update( + pk for pk, in querybuilder().append( + orm.Group, filters={ + 'id': { + 'in': list(entity_ids[EntityTypes.GROUP]) + } + }, tag='group' + ).append(orm.User, 
with_group='group', project='id').distinct().iterall(batch_size=batch_size) + ) + if entity_ids[EntityTypes.COMMENT]: + entity_ids[EntityTypes.USER].update( + pk for pk, in querybuilder().append( + orm.Comment, filters={ + 'id': { + 'in': list(entity_ids[EntityTypes.COMMENT]) + } + }, tag='comment' + ).append(orm.User, with_comment='comment', project='id').distinct().iterall(batch_size=batch_size) + ) + if entity_ids[EntityTypes.AUTHINFO]: + entity_ids[EntityTypes.USER].update( + pk for pk, in querybuilder().append( + orm.AuthInfo, filters={ + 'id': { + 'in': list(entity_ids[EntityTypes.AUTHINFO]) + } + }, tag='auth' + ).append(orm.User, with_authinfo='auth', project='id').distinct().iterall(batch_size=batch_size) + ) + + progress.update() + + return group_nodes, link_data + + +def _stream_repo_files( + key_format: str, writer: ArchiveWriterAbstract, node_ids: Set[int], backend: Backend, batch_size: int +) -> None: + """Collect all repository object keys from the nodes, then stream the files to the archive.""" + keys = set(orm.Node.objects(backend).iter_repo_keys(filters={'id': {'in': list(node_ids)}}, batch_size=batch_size)) + + repository = backend.get_repository() + if not repository.key_format == key_format: + # Here we would have to go back and replace all the keys in the `Node.repository_metadata`s + raise NotImplementedError( + f'Backend repository key format incompatible: {repository.key_format!r} != {key_format!r}' + ) + with get_progress_reporter()(desc='Archiving files: ', total=len(keys)) as progress: + for key, stream in repository.iter_object_streams(keys): + # to-do should we use assume the key here is correct, or always re-compute and check? + writer.put_object(stream, key=key) + progress.update() + + +def _check_unsealed_nodes(querybuilder: QbType, node_ids: Set[int], batch_size: int) -> None: + """Check no process nodes are unsealed, i.e. all processes have completed.""" + qbuilder = querybuilder().append( + orm.ProcessNode, + filters={ + 'id': { + 'in': list(node_ids) + }, + 'attributes.sealed': { + '!in': [True] # better operator? + } + }, + project='id' + ).distinct() + unsealed_node_pks = qbuilder.all(batch_size=batch_size, flat=True) + if unsealed_node_pks: + raise ExportValidationError( + 'All ProcessNodes must be sealed before they can be exported. ' + f"Node(s) with PK(s): {', '.join(str(pk) for pk in unsealed_node_pks)} is/are not sealed." 
+ ) + + +def _check_node_licenses( + querybuilder: QbType, node_ids: Set[int], allowed_licenses: Union[None, Sequence[str], Callable], + forbidden_licenses: Union[None, Sequence[str], Callable], batch_size: int +) -> None: + """Check the nodes to be archived for disallowed licences.""" + if allowed_licenses is None and forbidden_licenses is None: + return None + + # set allowed function + if allowed_licenses is None: + check_allowed = lambda l: True + elif callable(allowed_licenses): + + def _check_allowed(name): + try: + return allowed_licenses(name) # type: ignore + except Exception as exc: + raise LicensingException('allowed_licenses function error') from exc + + check_allowed = _check_allowed + elif isinstance(allowed_licenses, Sequence): + check_allowed = lambda l: l in allowed_licenses # type: ignore + else: + raise TypeError('allowed_licenses not a list or function') + + # set forbidden function + if forbidden_licenses is None: + check_forbidden = lambda l: False + elif callable(forbidden_licenses): + + def _check_forbidden(name): + try: + return forbidden_licenses(name) # type: ignore + except Exception as exc: + raise LicensingException('forbidden_licenses function error') from exc + + check_forbidden = _check_forbidden + elif isinstance(forbidden_licenses, Sequence): + check_forbidden = lambda l: l in forbidden_licenses # type: ignore + else: + raise TypeError('forbidden_licenses not a list or function') + + # create query + qbuilder = querybuilder().append( + orm.Node, + project=['id', 'attributes.source.license'], + filters={'id': { + 'in': list(node_ids) + }}, + ) + + for node_id, name in qbuilder.iterall(batch_size=batch_size): + if name is None: + continue + if not check_allowed(name): + raise LicensingException( + f"Node {node_id} is licensed under '{name}' license, which is not in the list of allowed licenses" + ) + if check_forbidden(name): + raise LicensingException( + f"Node {node_id} is licensed under '{name}' license, which is in the list of forbidden licenses" + ) + + +def get_init_summary( + *, archive_version: str, outfile: Path, collect_all: bool, include_authinfos: bool, include_comments: bool, + include_logs: bool, traversal_rules: dict, compression: int +) -> str: + """Get summary for archive initialisation""" + parameters = [['Path', str(outfile)], ['Version', archive_version], ['Compression', compression]] + + result = f"\n{tabulate(parameters, headers=['Archive Parameters', ''])}" + + inclusions = [['Computers/Nodes/Groups/Users', 'All' if collect_all else 'Selected'], + ['Computer Authinfos', include_authinfos], ['Node Comments', include_comments], + ['Node Logs', include_logs]] + result += f"\n\n{tabulate(inclusions, headers=['Inclusion rules', ''])}" + + if not collect_all: + rules_table = [[f"Follow links {' '.join(name.split('_'))}s", value] for name, value in traversal_rules.items()] + result += f"\n\n{tabulate(rules_table, headers=['Traversal rules', ''])}" + + return result + '\n' diff --git a/aiida/tools/importexport/common/exceptions.py b/aiida/tools/archive/exceptions.py similarity index 63% rename from aiida/tools/importexport/common/exceptions.py rename to aiida/tools/archive/exceptions.py index 5db8fd1c0d..1ad358308f 100644 --- a/aiida/tools/importexport/common/exceptions.py +++ b/aiida/tools/archive/exceptions.py @@ -7,18 +7,26 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### 
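Note (illustrative, not part of the patch): a hedged sketch of the license filters checked by _check_node_licenses above; the license names and node selection are hypothetical. Per that implementation, allowed_licenses accepts a list or a callable returning True for permitted licenses, while a forbidden_licenses callable triggers a LicensingException when it returns True for a node's license.

from aiida import orm
from aiida.common.exceptions import LicensingException
from aiida.tools.archive import create_archive

nodes = [orm.load_node(1234)]  # hypothetical selection

try:
    create_archive(nodes, filename='export.aiida', allowed_licenses=['MIT', 'BSD-3-Clause'])
except LicensingException as exc:
    print(f'export rejected: {exc}')

# equivalently, reject licenses by predicate
create_archive(
    nodes,
    filename='export.aiida',
    forbidden_licenses=lambda name: name.lower().startswith('proprietary'),
)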
-"""Module that defines the exceptions thrown by AiiDA's export/import module. +"""Module that defines the exceptions thrown by AiiDA's archive module. -Note: In order to not override the built-in `ImportError`, both `ImportError` and `ExportError` are prefixed with - `Archive`. +Note: In order to not override the built-in `ImportError`, + both `ImportError` and `ExportError` are prefixed with `Archive`. """ from aiida.common.exceptions import AiidaException __all__ = ( - 'ExportImportException', 'ArchiveExportError', 'ArchiveImportError', 'CorruptArchive', - 'IncompatibleArchiveVersionError', 'ExportValidationError', 'ImportUniquenessError', 'ImportValidationError', - 'ArchiveMigrationError', 'MigrationValidationError', 'DanglingLinkError', 'ProgressBarError' + 'ExportImportException', + 'ArchiveExportError', + 'ExportValidationError', + 'CorruptArchive', + 'ArchiveMigrationError', + 'MigrationValidationError', + 'ArchiveImportError', + 'IncompatibleArchiveVersionError', + 'ImportValidationError', + 'ImportUniquenessError', + 'ImportTestRun', ) @@ -30,20 +38,24 @@ class ArchiveExportError(ExportImportException): """Base class for all AiiDA export exceptions.""" -class ArchiveImportError(ExportImportException): - """Base class for all AiiDA import exceptions.""" +class ExportValidationError(ArchiveExportError): + """Raised when validation fails during export, e.g. for non-sealed ``ProcessNode`` s.""" + + +class UnreadableArchiveError(ArchiveExportError): + """Raised when the version cannot be extracted from the archive.""" class CorruptArchive(ExportImportException): """Raised when an operation is applied to a corrupt export archive, e.g. missing files or invalid formats.""" -class IncompatibleArchiveVersionError(ExportImportException): - """Raised when trying to import an export archive with an incompatible schema version.""" +class ArchiveImportError(ExportImportException): + """Base class for all AiiDA import exceptions.""" -class ExportValidationError(ArchiveExportError): - """Raised when validation fails during export, e.g. for non-sealed ``ProcessNode`` s.""" +class IncompatibleArchiveVersionError(ExportImportException): + """Raised when trying to import an export archive with an incompatible schema version.""" class ImportUniquenessError(ArchiveImportError): @@ -57,6 +69,10 @@ class ImportValidationError(ArchiveImportError): """Raised when validation fails during import, e.g. 
for parameter types and values.""" +class ImportTestRun(ArchiveImportError): + """Raised during an import, before the transaction is commited.""" + + class ArchiveMigrationError(ExportImportException): """Base class for all AiiDA export archive migration exceptions.""" @@ -65,9 +81,15 @@ class MigrationValidationError(ArchiveMigrationError): """Raised when validation fails during migration of export archives.""" -class DanglingLinkError(MigrationValidationError): - """Raised when an export archive is detected to contain dangling links when importing.""" +class ReadOnlyError(IOError): + """Raised when a write operation is called on a read-only archive.""" + + def __init__(self, msg='Archive is read-only'): # pylint: disable=useless-super-delegation + super().__init__(msg) + +class ArchiveClosedError(IOError): + """Raised when the archive is closed.""" -class ProgressBarError(ExportImportException): - """Something is wrong with setting up the tqdm progress bar""" + def __init__(self, msg='Archive is closed'): # pylint: disable=useless-super-delegation + super().__init__(msg) diff --git a/aiida/tools/importexport/dbexport/__init__.py b/aiida/tools/archive/implementations/__init__.py similarity index 86% rename from aiida/tools/importexport/dbexport/__init__.py rename to aiida/tools/archive/implementations/__init__.py index 22173490f7..6f85411389 100644 --- a/aiida/tools/importexport/dbexport/__init__.py +++ b/aiida/tools/archive/implementations/__init__.py @@ -7,19 +7,17 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -"""Provides export functionalities.""" +"""Concrete implementations of an archive file format.""" # AUTO-GENERATED # yapf: disable # pylint: disable=wildcard-import -from .main import * +from .sqlite import * __all__ = ( - 'EXPORT_LOGGER', - 'ExportFileFormat', - 'export', + 'ArchiveFormatSqlZip', ) # yapf: enable diff --git a/aiida/tools/importexport/dbimport/__init__.py b/aiida/tools/archive/implementations/sqlite/__init__.py similarity index 90% rename from aiida/tools/importexport/dbimport/__init__.py rename to aiida/tools/archive/implementations/sqlite/__init__.py index ad987679f1..d26c0161a0 100644 --- a/aiida/tools/importexport/dbimport/__init__.py +++ b/aiida/tools/archive/implementations/sqlite/__init__.py @@ -7,7 +7,7 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -"""Provides import functionalities.""" +"""SQLite implementations of an archive file format.""" # AUTO-GENERATED @@ -17,8 +17,7 @@ from .main import * __all__ = ( - 'IMPORT_LOGGER', - 'import_data', + 'ArchiveFormatSqlZip', ) # yapf: enable diff --git a/aiida/tools/archive/implementations/sqlite/backend.py b/aiida/tools/archive/implementations/sqlite/backend.py new file mode 100644 index 0000000000..52fdeb5f36 --- /dev/null +++ b/aiida/tools/archive/implementations/sqlite/backend.py @@ -0,0 +1,405 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""The table models are dynamically generated from the sqlalchemy backend models.""" +from contextlib import contextmanager +from datetime import datetime +from functools import singledispatch +from pathlib import Path +from typing import BinaryIO, Iterable, Iterator, List, Optional, Sequence, Tuple, Type, cast +import zipfile +from zipfile import ZipFile + +import pytz +from sqlalchemy import CHAR, Text, orm, types +from sqlalchemy.dialects.postgresql import JSONB, UUID +from sqlalchemy.dialects.sqlite import JSON +from sqlalchemy.sql.schema import Table + +# we need to import all models, to ensure they are loaded on the SQLA Metadata +from aiida.backends.sqlalchemy.models import authinfo, base, comment, computer, group, log, node, user +from aiida.orm.entities import EntityTypes +from aiida.orm.implementation.backends import Backend as BackendAbstract +from aiida.orm.implementation.sqlalchemy import authinfos, comments, computers, entities, groups, logs, nodes, users +from aiida.orm.implementation.sqlalchemy.querybuilder import SqlaQueryBuilder +from aiida.repository.backend.abstract import AbstractRepositoryBackend +from aiida.tools.archive.exceptions import ArchiveClosedError, ReadOnlyError + +from .common import REPO_FOLDER + + +class SqliteModel: + """Represent a row in an sqlite database table""" + + def __repr__(self) -> str: + """Return a representation of the row columns""" + string = f'<{self.__class__.__name__}' + for col in self.__table__.columns: # type: ignore[attr-defined] # pylint: disable=no-member + # don't include columns with potentially large values + if isinstance(col.type, (JSON, Text)): + continue + string += f' {col.name}={getattr(self, col.name)}' + return string + '>' + + +class TZDateTime(types.TypeDecorator): # pylint: disable=abstract-method + """A timezone naive UTC ``DateTime`` implementation for SQLite. 
+ + see: https://docs.sqlalchemy.org/en/14/core/custom_types.html#store-timezone-aware-timestamps-as-timezone-naive-utc + """ + impl = types.DateTime + cache_ok = True + + def process_bind_param(self, value: Optional[datetime], dialect): + """Process before writing to database.""" + if value is None: + return value + if value.tzinfo is None: + value = value.astimezone(pytz.utc) + value = value.astimezone(pytz.utc).replace(tzinfo=None) + return value + + def process_result_value(self, value: Optional[datetime], dialect): + """Process when returning from database.""" + if value is None: + return value + if value.tzinfo is None: + return value.replace(tzinfo=pytz.utc) + return value.astimezone(pytz.utc) + + +ArchiveDbBase = orm.declarative_base(cls=SqliteModel, name='SqliteModel') + + +def pg_to_sqlite(pg_table: Table): + """Convert a model intended for PostGreSQL to one compatible with SQLite""" + new = pg_table.to_metadata(ArchiveDbBase.metadata) + for column in new.columns: + if isinstance(column.type, UUID): + column.type = CHAR(32) + elif isinstance(column.type, types.DateTime): + column.type = TZDateTime() + elif isinstance(column.type, JSONB): + column.type = JSON() + return new + + +def create_orm_cls(klass: base.Base) -> ArchiveDbBase: + """Create an ORM class from an existing table in the declarative meta""" + tbl = ArchiveDbBase.metadata.tables[klass.__tablename__] + return type( # type: ignore[return-value] + klass.__name__, + (ArchiveDbBase,), + { + '__tablename__': tbl.name, + '__table__': tbl, + **{col.name if col.name != 'metadata' else '_metadata': col for col in tbl.columns}, + }, + ) + + +for table in base.Base.metadata.sorted_tables: + pg_to_sqlite(table) + +DbUser = create_orm_cls(user.DbUser) +DbComputer = create_orm_cls(computer.DbComputer) +DbAuthInfo = create_orm_cls(authinfo.DbAuthInfo) +DbGroup = create_orm_cls(group.DbGroup) +DbNode = create_orm_cls(node.DbNode) +DbGroupNodes = create_orm_cls(group.DbGroupNode) +DbComment = create_orm_cls(comment.DbComment) +DbLog = create_orm_cls(log.DbLog) +DbLink = create_orm_cls(node.DbLink) + +# to-do This was the minimum for creating a graph, but really all relationships should be copied +DbNode.dbcomputer = orm.relationship('DbComputer', backref='dbnodes') # type: ignore[attr-defined] +DbGroup.dbnodes = orm.relationship( # type: ignore[attr-defined] + 'DbNode', secondary='db_dbgroup_dbnodes', backref='dbgroups', lazy='dynamic' +) + + +class ZipfileBackendRepository(AbstractRepositoryBackend): + """A read-only backend for an open zip file.""" + + def __init__(self, file: ZipFile): + self._zipfile = file + + @property + def zipfile(self) -> ZipFile: + if self._zipfile.fp is None: + raise ArchiveClosedError() + return self._zipfile + + @property + def uuid(self) -> Optional[str]: + return None + + @property + def key_format(self) -> Optional[str]: + return 'sha256' + + def initialise(self, **kwargs) -> None: + pass + + @property + def is_initialised(self) -> bool: + return True + + def erase(self) -> None: + raise ReadOnlyError() + + def _put_object_from_filelike(self, handle: BinaryIO) -> str: + raise ReadOnlyError() + + def has_object(self, key: str) -> bool: + try: + self.zipfile.getinfo(f'{REPO_FOLDER}/{key}') + except KeyError: + return False + return True + + def has_objects(self, keys: List[str]) -> List[bool]: + return [self.has_object(key) for key in keys] + + def list_objects(self) -> Iterable[str]: + for name in self.zipfile.namelist(): + if name.startswith(REPO_FOLDER + '/') and name[len(REPO_FOLDER) + 1:]: + yield 
name[len(REPO_FOLDER) + 1:] + + @contextmanager + def open(self, key: str) -> Iterator[BinaryIO]: + try: + handle = self.zipfile.open(f'{REPO_FOLDER}/{key}') + yield cast(BinaryIO, handle) + except KeyError: + raise FileNotFoundError(f'object with key `{key}` does not exist.') + finally: + handle.close() + + def iter_object_streams(self, keys: List[str]) -> Iterator[Tuple[str, BinaryIO]]: + for key in keys: + with self.open(key) as handle: # pylint: disable=not-context-manager + yield key, handle + + def delete_objects(self, keys: List[str]) -> None: + raise ReadOnlyError() + + def get_object_hash(self, key: str) -> str: + return key + + +class ArchiveBackendQueryBuilder(SqlaQueryBuilder): + """Archive query builder""" + + @property + def Node(self): + return DbNode + + @property + def Link(self): + return DbLink + + @property + def Computer(self): + return DbComputer + + @property + def User(self): + return DbUser + + @property + def Group(self): + return DbGroup + + @property + def AuthInfo(self): + return DbAuthInfo + + @property + def Comment(self): + return DbComment + + @property + def Log(self): + return DbLog + + @property + def table_groups_nodes(self): + return DbGroupNodes.__table__ # type: ignore[attr-defined] # pylint: disable=no-member + + +class ArchiveReadOnlyBackend(BackendAbstract): + """A read-only backend for the archive.""" + + def __init__(self, path: Path, session: orm.Session): + super().__init__() + self._path = path + self._session: Optional[orm.Session] = session + # lazy open the archive zipfile + self._zipfile: Optional[zipfile.ZipFile] = None + self._closed = False + + def close(self): + """Close the backend""" + if self._session: + self._session.close() + if self._zipfile: + self._zipfile.close() + self._session = None + self._zipfile = None + self._closed = True + + def get_session(self) -> orm.Session: + if not self._session: + raise ArchiveClosedError() + return self._session + + def get_repository(self) -> ZipfileBackendRepository: + if self._closed: + raise ArchiveClosedError() + if self._zipfile is None: + self._zipfile = ZipFile(self._path, mode='r') # pylint: disable=consider-using-with + return ZipfileBackendRepository(self._zipfile) + + def query(self) -> ArchiveBackendQueryBuilder: + return ArchiveBackendQueryBuilder(self) + + def get_backend_entity(self, res): # pylint: disable=no-self-use + """Return the backend entity that corresponds to the given Model instance.""" + klass = get_backend_entity(res) + return klass(self, res) + + @property + def authinfos(self): + return create_backend_collection(authinfos.SqlaAuthInfoCollection, self, authinfos.SqlaAuthInfo, DbAuthInfo) + + @property + def comments(self): + return create_backend_collection(comments.SqlaCommentCollection, self, comments.SqlaComment, DbComment) + + @property + def computers(self): + return create_backend_collection(computers.SqlaComputerCollection, self, computers.SqlaComputer, DbComputer) + + @property + def groups(self): + return create_backend_collection(groups.SqlaGroupCollection, self, groups.SqlaGroup, DbGroup) + + @property + def logs(self): + return create_backend_collection(logs.SqlaLogCollection, self, logs.SqlaLog, DbLog) + + @property + def nodes(self): + return create_backend_collection(nodes.SqlaNodeCollection, self, nodes.SqlaNode, DbNode) + + @property + def users(self): + return create_backend_collection(users.SqlaUserCollection, self, users.SqlaUser, DbUser) + + def migrate(self): + raise ReadOnlyError() + + def transaction(self): + raise ReadOnlyError() + + 
@property + def in_transaction(self) -> bool: + return False + + def bulk_insert(self, entity_type: EntityTypes, rows: List[dict], allow_defaults: bool = False) -> List[int]: + raise ReadOnlyError() + + def bulk_update(self, entity_type: EntityTypes, rows: List[dict]) -> None: + raise ReadOnlyError() + + def delete_nodes_and_connections(self, pks_to_delete: Sequence[int]): + raise ReadOnlyError() + + +def create_backend_cls(base_class, model_cls): + """Create an archive backend class for the given model class.""" + + class ReadOnlyEntityBackend(base_class): # type: ignore + """Backend class for the read-only archive.""" + + MODEL_CLASS = model_cls + + def __init__(self, _backend, model): + """Initialise the backend entity.""" + self._backend = _backend + # In the SQLA base classes, the SQLA model instance is wrapped in a proxy class, + # to handle attributes get/set on stored/unstored instances, and saving instances to the database. + # However, since the wrapper is currently tied to the global session (see #5172) + # and this is a read-only archive, we don't need to do that. + self._dbmodel = model + + @property + def dbmodel(self): + return self._dbmodel + + @classmethod + def from_dbmodel(cls, model, _backend): + return cls(_backend, model) + + @property + def is_stored(self): + return True + + def store(self): # pylint: disable=no-self-use + return ReadOnlyError() + + return ReadOnlyEntityBackend + + +def create_backend_collection(cls, _backend, entity_cls, model): + collection = cls(_backend) + new_cls = create_backend_cls(entity_cls, model) + collection.ENTITY_CLASS = new_cls + return collection + + +@singledispatch +def get_backend_entity(dbmodel) -> Type[entities.SqlaModelEntity]: # pylint: disable=unused-argument + raise TypeError(f'Cannot get backend entity for {dbmodel}') + + +@get_backend_entity.register(DbAuthInfo) +def _(dbmodel): + return create_backend_cls(authinfos.SqlaAuthInfo, dbmodel.__class__) + + +@get_backend_entity.register(DbComment) # type: ignore[no-redef] +def _(dbmodel): + return create_backend_cls(comments.SqlaComment, dbmodel.__class__) + + +@get_backend_entity.register(DbComputer) # type: ignore[no-redef] +def _(dbmodel): + return create_backend_cls(computers.SqlaComputer, dbmodel.__class__) + + +@get_backend_entity.register(DbGroup) # type: ignore[no-redef] +def _(dbmodel): + return create_backend_cls(groups.SqlaGroup, dbmodel.__class__) + + +@get_backend_entity.register(DbLog) # type: ignore[no-redef] +def _(dbmodel): + return create_backend_cls(logs.SqlaLog, dbmodel.__class__) + + +@get_backend_entity.register(DbNode) # type: ignore[no-redef] +def _(dbmodel): + return create_backend_cls(nodes.SqlaNode, dbmodel.__class__) + + +@get_backend_entity.register(DbUser) # type: ignore[no-redef] +def _(dbmodel): + return create_backend_cls(users.SqlaUser, dbmodel.__class__) diff --git a/aiida/tools/archive/implementations/sqlite/common.py b/aiida/tools/archive/implementations/sqlite/common.py new file mode 100644 index 0000000000..a375cf7c26 --- /dev/null +++ b/aiida/tools/archive/implementations/sqlite/common.py @@ -0,0 +1,158 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Common variables""" +import os +from pathlib import Path +import shutil +import tempfile +from typing import Callable, Sequence, Union + +from archive_path import TarPath, ZipPath +from sqlalchemy import event +from sqlalchemy.future.engine import Engine, create_engine + +from aiida.common import json +from aiida.common.progress_reporter import create_callback, get_progress_reporter + +META_FILENAME = 'metadata.json' +DB_FILENAME = 'db.sqlite3' +# folder to store repository files in +REPO_FOLDER = 'repo' + + +def sqlite_enforce_foreign_keys(dbapi_connection, _): + """Enforce foreign key constraints, when using sqlite backend (off by default)""" + cursor = dbapi_connection.cursor() + cursor.execute('PRAGMA foreign_keys=ON;') + cursor.close() + + +def create_sqla_engine(path: Union[str, Path], *, enforce_foreign_keys: bool = True, **kwargs) -> Engine: + """Create a new engine instance.""" + engine = create_engine( + f'sqlite:///{path}', + json_serializer=json.dumps, + json_deserializer=json.loads, + encoding='utf-8', + future=True, + **kwargs + ) + if enforce_foreign_keys: + event.listen(engine, 'connect', sqlite_enforce_foreign_keys) + return engine + + +def copy_zip_to_zip( + inpath: Path, + outpath: Path, + path_callback: Callable[[ZipPath, ZipPath], bool], + *, + compression: int = 6, + overwrite: bool = True, + title: str = 'Writing new zip file', + info_order: Sequence[str] = () +) -> None: + """Create a new zip file from an existing zip file. + + All files/folders are streamed directly to the new zip file, + with the ``path_callback`` allowing for per path modifications. + The new zip file is first created in a temporary directory, and then moved to the desired location. + + :param inpath: the path to the existing archive + :param outpath: the path to output the new archive + :param path_callback: a callback that is called for each path in the archive: ``(inpath, outpath) -> handled`` + If handled is ``True``, the path is assumed to already have been copied to the new zip file. + :param compression: the default compression level to use for the new zip file + :param overwrite: whether to overwrite the output file if it already exists + :param title: the title of the progress bar + :param info_order: ``ZipInfo`` for these file names will be written first to the zip central directory. + This allows for faster reading of these files, with ``archive_path.read_file_in_zip``. 
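+
+    A minimal, illustrative callback (``new_metadata`` here is a placeholder for whatever
+    dictionary should replace the file) rewrites ``metadata.json`` on the fly and lets every
+    other path be streamed unchanged::
+
+        def replace_metadata(inpath: ZipPath, outpath: ZipPath) -> bool:
+            if inpath.name == 'metadata.json':
+                outpath.write_text(json.dumps(new_metadata))
+                return True
+            return False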
+ """ + if (not overwrite) and outpath.exists() and outpath.is_file(): + raise FileExistsError(f'{outpath} already exists') + with tempfile.TemporaryDirectory() as tmpdirname: + temp_archive = Path(tmpdirname) / 'archive.zip' + with ZipPath(temp_archive, mode='w', compresslevel=compression, info_order=info_order) as new_path: + with ZipPath(inpath, mode='r') as path: + length = sum(1 for _ in path.glob('**/*', include_virtual=False)) + with get_progress_reporter()(desc=title, total=length) as progress: + for subpath in path.glob('**/*', include_virtual=False): + new_path_sub = new_path.joinpath(subpath.at) + if path_callback(subpath, new_path_sub): + pass + elif subpath.is_dir(): + new_path_sub.mkdir(exist_ok=True) + else: + new_path_sub.putfile(subpath) + progress.update() + if overwrite and outpath.exists() and outpath.is_file(): + outpath.unlink() + shutil.move(temp_archive, outpath) # type: ignore[arg-type] + + +def copy_tar_to_zip( + inpath: Path, + outpath: Path, + path_callback: Callable[[Path, ZipPath], bool], + *, + compression: int = 6, + overwrite: bool = True, + title: str = 'Writing new zip file', + info_order: Sequence[str] = () +) -> None: + """Create a new zip file from an existing tar file. + + The tar file is first extracted to a temporary directory, and then the new zip file is created, + with the ``path_callback`` allowing for per path modifications. + The new zip file is first created in a temporary directory, and then moved to the desired location. + + :param inpath: the path to the existing archive + :param outpath: the path to output the new archive + :param path_callback: a callback that is called for each path in the archive: ``(inpath, outpath) -> handled`` + If handled is ``True``, the path is assumed to already have been copied to the new zip file. + :param compression: the default compression level to use for the new zip file + :param overwrite: whether to overwrite the output file if it already exists + :param title: the title of the progress bar + :param info_order: ``ZipInfo`` for these file names will be written first to the zip central directory. + This allows for faster reading of these files, with ``archive_path.read_file_in_zip``. 
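+
+    Note that, unlike ``copy_zip_to_zip``, the first argument passed to ``path_callback`` here is a
+    plain ``pathlib.Path``, relative to the temporary directory into which the tar file is extracted.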
+ """ + if (not overwrite) and outpath.exists() and outpath.is_file(): + raise FileExistsError(f'{outpath} already exists') + with tempfile.TemporaryDirectory() as tmpdirname: + # for tar files we extract first, since the file is compressed as a single object + temp_extracted = Path(tmpdirname) / 'extracted' + with get_progress_reporter()(total=1) as progress: + callback = create_callback(progress) + TarPath(inpath, mode='r:*').extract_tree( + temp_extracted, + allow_dev=False, + allow_symlink=False, + callback=callback, + cb_descript=f'{title} (extracting tar)' + ) + temp_archive = Path(tmpdirname) / 'archive.zip' + with ZipPath(temp_archive, mode='w', compresslevel=compression, info_order=info_order) as new_path: + length = sum(1 for _ in temp_extracted.glob('**/*')) + with get_progress_reporter()(desc=title, total=length) as progress: + for subpath in temp_extracted.glob('**/*'): + new_path_sub = new_path.joinpath(subpath.relative_to(temp_extracted).as_posix()) + if path_callback(subpath.relative_to(temp_extracted), new_path_sub): + pass + elif subpath.is_dir(): + new_path_sub.mkdir(exist_ok=True) + else: + # files extracted from the tar do not include a modified time, yet zip requires one + os.utime(subpath, (subpath.stat().st_ctime, subpath.stat().st_ctime)) + new_path_sub.putfile(subpath) + progress.update() + + if overwrite and outpath.exists() and outpath.is_file(): + outpath.unlink() + shutil.move(temp_archive, outpath) # type: ignore[arg-type] diff --git a/aiida/tools/archive/implementations/sqlite/main.py b/aiida/tools/archive/implementations/sqlite/main.py new file mode 100644 index 0000000000..46682362ae --- /dev/null +++ b/aiida/tools/archive/implementations/sqlite/main.py @@ -0,0 +1,116 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""The file format implementation""" +from pathlib import Path +from typing import Any, List, Union, overload + +from aiida.tools.archive.abstract import ArchiveFormatAbstract + +from .migrations.main import ALL_VERSIONS, migrate +from .reader import ArchiveReaderSqlZip, read_version +from .writer import ArchiveAppenderSqlZip, ArchiveWriterSqlZip + +try: + from typing import Literal # pylint: disable=ungrouped-imports +except ImportError: + # Python <3.8 backport + from typing_extensions import Literal # type: ignore + +__all__ = ('ArchiveFormatSqlZip',) + + +class ArchiveFormatSqlZip(ArchiveFormatAbstract): + """Archive format, which uses a zip file, containing an SQLite database. + + The content of the zip file is:: + + |- archive.zip + |- metadata.json + |- db.sqlite3 + |- repo/ + |- hashkey + + Repository files are named by their SHA256 content hash. + + """ + + @property + def versions(self) -> List[str]: + return ALL_VERSIONS + + def read_version(self, path: Union[str, Path]) -> str: + return read_version(path) + + @property + def key_format(self) -> str: + return 'sha256' + + @overload + def open( + self, + path: Union[str, Path], + mode: Literal['r'], + *, + compression: int = 6, + **kwargs: Any + ) -> ArchiveReaderSqlZip: + ... 
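+
+    # Illustrative usage (``example.aiida`` is an arbitrary file name): the overloads let a type
+    # checker narrow the return type from the ``mode`` literal, e.g.
+    #
+    #     with ArchiveFormatSqlZip().open('example.aiida', mode='r') as reader:
+    #         metadata = reader.get_metadata()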
+ + @overload + def open( + self, + path: Union[str, Path], + mode: Literal['x', 'w'], + *, + compression: int = 6, + **kwargs: Any + ) -> ArchiveWriterSqlZip: + ... + + @overload + def open( + self, + path: Union[str, Path], + mode: Literal['a'], + *, + compression: int = 6, + **kwargs: Any + ) -> ArchiveAppenderSqlZip: + ... + + def open( + self, + path: Union[str, Path], + mode: Literal['r', 'x', 'w', 'a'] = 'r', + *, + compression: int = 6, + **kwargs: Any + ) -> Union[ArchiveReaderSqlZip, ArchiveWriterSqlZip, ArchiveAppenderSqlZip]: + if mode == 'r': + return ArchiveReaderSqlZip(path, **kwargs) + if mode == 'a': + return ArchiveAppenderSqlZip(path, self, mode=mode, compression=compression, **kwargs) + return ArchiveWriterSqlZip(path, self, mode=mode, compression=compression, **kwargs) + + def migrate( + self, + inpath: Union[str, Path], + outpath: Union[str, Path], + version: str, + *, + force: bool = False, + compression: int = 6 + ) -> None: + """Migrate an archive to a specific version. + + :param path: archive path + """ + current_version = self.read_version(inpath) + return migrate(inpath, outpath, current_version, version, force=force, compression=compression) diff --git a/aiida/tools/importexport/dbimport/backends/__init__.py b/aiida/tools/archive/implementations/sqlite/migrations/__init__.py similarity index 90% rename from aiida/tools/importexport/dbimport/backends/__init__.py rename to aiida/tools/archive/implementations/sqlite/migrations/__init__.py index 2776a55f97..84dbe1264d 100644 --- a/aiida/tools/importexport/dbimport/backends/__init__.py +++ b/aiida/tools/archive/implementations/sqlite/migrations/__init__.py @@ -7,3 +7,4 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### +"""Migration archive files from old export versions to newer ones.""" diff --git a/aiida/tools/importexport/archive/migrations/__init__.py b/aiida/tools/archive/implementations/sqlite/migrations/legacy/__init__.py similarity index 66% rename from aiida/tools/importexport/archive/migrations/__init__.py rename to aiida/tools/archive/implementations/sqlite/migrations/legacy/__init__.py index d7628eb6a4..5190ad4d96 100644 --- a/aiida/tools/importexport/archive/migrations/__init__.py +++ b/aiida/tools/archive/implementations/sqlite/migrations/legacy/__init__.py @@ -7,14 +7,13 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -"""Migration archive files from old export versions to the newest, used by `verdi export migrate` command.""" -from typing import Callable, Dict, Tuple +"""Legacy migrations, +using the old ``data.json`` format for storing the database. -from aiida.tools.importexport.archive.common import CacheFolder +These migrations simply manipulate the metadata and data in-place. 
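+
+For illustration, a single step is applied by looking up the current version in
+``LEGACY_MIGRATE_FUNCTIONS`` (defined below) and calling the migration function on the parsed
+``metadata.json`` / ``data.json`` dictionaries, which it mutates in-place::
+
+    new_version, migration = LEGACY_MIGRATE_FUNCTIONS['0.4']  # ('0.5', migrate_v4_to_v5)
+    migration(metadata, data)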
+""" +from typing import Callable, Dict, Tuple -from .v01_to_v02 import migrate_v1_to_v2 -from .v02_to_v03 import migrate_v2_to_v3 -from .v03_to_v04 import migrate_v3_to_v4 from .v04_to_v05 import migrate_v4_to_v5 from .v05_to_v06 import migrate_v5_to_v6 from .v06_to_v07 import migrate_v6_to_v7 @@ -23,14 +22,10 @@ from .v09_to_v10 import migrate_v9_to_v10 from .v10_to_v11 import migrate_v10_to_v11 from .v11_to_v12 import migrate_v11_to_v12 -from .v12_to_v13 import migrate_v12_to_v13 -# version from -> version to, function which acts on the cache folder -_vtype = Dict[str, Tuple[str, Callable[[CacheFolder], None]]] -MIGRATE_FUNCTIONS: _vtype = { - '0.1': ('0.2', migrate_v1_to_v2), - '0.2': ('0.3', migrate_v2_to_v3), - '0.3': ('0.4', migrate_v3_to_v4), +# version from -> version to, function which modifies metadata, data in-place +_vtype = Dict[str, Tuple[str, Callable[[dict, dict], None]]] +LEGACY_MIGRATE_FUNCTIONS: _vtype = { '0.4': ('0.5', migrate_v4_to_v5), '0.5': ('0.6', migrate_v5_to_v6), '0.6': ('0.7', migrate_v6_to_v7), @@ -39,5 +34,5 @@ '0.9': ('0.10', migrate_v9_to_v10), '0.10': ('0.11', migrate_v10_to_v11), '0.11': ('0.12', migrate_v11_to_v12), - '0.12': ('0.13', migrate_v12_to_v13), } +FINAL_LEGACY_VERSION = '0.12' diff --git a/aiida/tools/importexport/archive/migrations/v04_to_v05.py b/aiida/tools/archive/implementations/sqlite/migrations/legacy/v04_to_v05.py similarity index 85% rename from aiida/tools/importexport/archive/migrations/v04_to_v05.py rename to aiida/tools/archive/implementations/sqlite/migrations/legacy/v04_to_v05.py index 42401d49c5..bd70cec9da 100644 --- a/aiida/tools/importexport/archive/migrations/v04_to_v05.py +++ b/aiida/tools/archive/implementations/sqlite/migrations/legacy/v04_to_v05.py @@ -24,9 +24,7 @@ Where id is a SQLA id and migration-name is the name of the particular migration. 
""" # pylint: disable=invalid-name -from aiida.tools.importexport.archive.common import CacheFolder - -from .utils import remove_fields, update_metadata, verify_metadata_version # pylint: disable=no-name-in-module +from ..utils import remove_fields, update_metadata, verify_metadata_version # pylint: disable=no-name-in-module def migration_drop_node_columns_nodeversion_public(metadata, data): @@ -49,7 +47,7 @@ def migration_drop_computer_transport_params(metadata, data): remove_fields(metadata, data, [entity], [field]) -def migrate_v4_to_v5(folder: CacheFolder): +def migrate_v4_to_v5(metadata: dict, data: dict) -> None: """ Migration of archive files from v0.4 to v0.5 @@ -58,15 +56,9 @@ def migrate_v4_to_v5(folder: CacheFolder): old_version = '0.4' new_version = '0.5' - _, metadata = folder.load_json('metadata.json') - verify_metadata_version(metadata, old_version) update_metadata(metadata, new_version) - _, data = folder.load_json('data.json') # Apply migrations migration_drop_node_columns_nodeversion_public(metadata, data) migration_drop_computer_transport_params(metadata, data) - - folder.write_json('metadata.json', metadata) - folder.write_json('data.json', data) diff --git a/aiida/tools/importexport/archive/migrations/v05_to_v06.py b/aiida/tools/archive/implementations/sqlite/migrations/legacy/v05_to_v06.py similarity index 93% rename from aiida/tools/importexport/archive/migrations/v05_to_v06.py rename to aiida/tools/archive/implementations/sqlite/migrations/legacy/v05_to_v06.py index ef664e6c8c..6229ac9afb 100644 --- a/aiida/tools/importexport/archive/migrations/v05_to_v06.py +++ b/aiida/tools/archive/implementations/sqlite/migrations/legacy/v05_to_v06.py @@ -26,14 +26,12 @@ # pylint: disable=invalid-name from typing import Union -from aiida.tools.importexport.archive.common import CacheFolder - -from .utils import update_metadata, verify_metadata_version # pylint: disable=no-name-in-module +from ..utils import update_metadata, verify_metadata_version # pylint: disable=no-name-in-module def migrate_deserialized_datetime(data, conversion): """Deserialize datetime strings from export archives, meaning to reattach the UTC timezone information.""" - from aiida.tools.importexport.common.exceptions import ArchiveMigrationError + from aiida.tools.archive.exceptions import ArchiveMigrationError ret_data: Union[str, dict, list] @@ -136,21 +134,14 @@ def migration_migrate_legacy_job_calculation_data(data): values['process_label'] = 'Legacy JobCalculation' -def migrate_v5_to_v6(folder: CacheFolder): +def migrate_v5_to_v6(metadata: dict, data: dict) -> None: """Migration of archive files from v0.5 to v0.6""" old_version = '0.5' new_version = '0.6' - _, metadata = folder.load_json('metadata.json') - verify_metadata_version(metadata, old_version) update_metadata(metadata, new_version) - _, data = folder.load_json('data.json') - # Apply migrations migration_serialize_datetime_objects(data) migration_migrate_legacy_job_calculation_data(data) - - folder.write_json('metadata.json', metadata) - folder.write_json('data.json', data) diff --git a/aiida/tools/importexport/archive/migrations/v06_to_v07.py b/aiida/tools/archive/implementations/sqlite/migrations/legacy/v06_to_v07.py similarity index 88% rename from aiida/tools/importexport/archive/migrations/v06_to_v07.py rename to aiida/tools/archive/implementations/sqlite/migrations/legacy/v06_to_v07.py index 0c231ca61b..56bdd93816 100644 --- a/aiida/tools/importexport/archive/migrations/v06_to_v07.py +++ 
b/aiida/tools/archive/implementations/sqlite/migrations/legacy/v06_to_v07.py @@ -24,9 +24,7 @@ Where id is a SQLA id and migration-name is the name of the particular migration. """ # pylint: disable=invalid-name -from aiida.tools.importexport.archive.common import CacheFolder - -from .utils import update_metadata, verify_metadata_version # pylint: disable=no-name-in-module +from ..utils import update_metadata, verify_metadata_version # pylint: disable=no-name-in-module def data_migration_legacy_process_attributes(data): @@ -48,14 +46,14 @@ def data_migration_legacy_process_attributes(data): `process_state` attribute. If they have it, it is checked whether the state is active or not, if not, the `sealed` attribute is created and set to `True`. - :raises `~aiida.tools.importexport.common.exceptions.CorruptArchive`: if a Node, found to have attributes, + :raises `~aiida.tools.archive.exceptions.CorruptArchive`: if a Node, found to have attributes, cannot be found in the list of exported entities. - :raises `~aiida.tools.importexport.common.exceptions.CorruptArchive`: if the 'sealed' attribute does not exist and + :raises `~aiida.tools.archive.exceptions.CorruptArchive`: if the 'sealed' attribute does not exist and the ProcessNode is in an active state, i.e. `process_state` is one of ('created', 'running', 'waiting'). A log-file, listing all illegal ProcessNodes, will be produced in the current directory. """ from aiida.manage.database.integrity import write_database_integrity_violation - from aiida.tools.importexport.common.exceptions import CorruptArchive + from aiida.tools.archive.exceptions import CorruptArchive attrs_to_remove = ['_sealed', '_finished', '_failed', '_aborted', '_do_abort'] active_states = {'created', 'running', 'waiting'} @@ -113,21 +111,14 @@ def remove_attribute_link_metadata(metadata): metadata[dictionary].pop(entity, None) -def migrate_v6_to_v7(folder: CacheFolder): +def migrate_v6_to_v7(metadata: dict, data: dict) -> None: """Migration of archive files from v0.6 to v0.7""" old_version = '0.6' new_version = '0.7' - _, metadata = folder.load_json('metadata.json') - verify_metadata_version(metadata, old_version) update_metadata(metadata, new_version) - _, data = folder.load_json('data.json') - # Apply migrations data_migration_legacy_process_attributes(data) remove_attribute_link_metadata(metadata) - - folder.write_json('metadata.json', metadata) - folder.write_json('data.json', data) diff --git a/aiida/tools/importexport/archive/migrations/v07_to_v08.py b/aiida/tools/archive/implementations/sqlite/migrations/legacy/v07_to_v08.py similarity index 83% rename from aiida/tools/importexport/archive/migrations/v07_to_v08.py rename to aiida/tools/archive/implementations/sqlite/migrations/legacy/v07_to_v08.py index d3b7ab5696..4fea760391 100644 --- a/aiida/tools/importexport/archive/migrations/v07_to_v08.py +++ b/aiida/tools/archive/implementations/sqlite/migrations/legacy/v07_to_v08.py @@ -24,9 +24,7 @@ Where id is a SQLA id and migration-name is the name of the particular migration. 
""" # pylint: disable=invalid-name -from aiida.tools.importexport.archive.common import CacheFolder - -from .utils import update_metadata, verify_metadata_version # pylint: disable=no-name-in-module +from ..utils import update_metadata, verify_metadata_version # pylint: disable=no-name-in-module def migration_default_link_label(data: dict): @@ -39,20 +37,13 @@ def migration_default_link_label(data: dict): link['label'] = 'result' -def migrate_v7_to_v8(folder: CacheFolder): +def migrate_v7_to_v8(metadata: dict, data: dict) -> None: """Migration of archive files from v0.7 to v0.8.""" old_version = '0.7' new_version = '0.8' - _, metadata = folder.load_json('metadata.json') - verify_metadata_version(metadata, old_version) update_metadata(metadata, new_version) - _, data = folder.load_json('data.json') - # Apply migrations migration_default_link_label(data) - - folder.write_json('metadata.json', metadata) - folder.write_json('data.json', data) diff --git a/aiida/tools/importexport/archive/migrations/v08_to_v09.py b/aiida/tools/archive/implementations/sqlite/migrations/legacy/v08_to_v09.py similarity index 84% rename from aiida/tools/importexport/archive/migrations/v08_to_v09.py rename to aiida/tools/archive/implementations/sqlite/migrations/legacy/v08_to_v09.py index dc79d2781a..7da9bbe8ea 100644 --- a/aiida/tools/importexport/archive/migrations/v08_to_v09.py +++ b/aiida/tools/archive/implementations/sqlite/migrations/legacy/v08_to_v09.py @@ -24,9 +24,7 @@ Where id is a SQLA id and migration-name is the name of the particular migration. """ # pylint: disable=invalid-name -from aiida.tools.importexport.archive.common import CacheFolder - -from .utils import update_metadata, verify_metadata_version # pylint: disable=no-name-in-module +from ..utils import update_metadata, verify_metadata_version # pylint: disable=no-name-in-module def migration_dbgroup_type_string(data): @@ -47,20 +45,13 @@ def migration_dbgroup_type_string(data): attributes['type_string'] = new -def migrate_v8_to_v9(folder: CacheFolder): +def migrate_v8_to_v9(metadata: dict, data: dict) -> None: """Migration of archive files from v0.8 to v0.9.""" old_version = '0.8' new_version = '0.9' - _, metadata = folder.load_json('metadata.json') - verify_metadata_version(metadata, old_version) update_metadata(metadata, new_version) - _, data = folder.load_json('data.json') - # Apply migrations migration_dbgroup_type_string(data) - - folder.write_json('metadata.json', metadata) - folder.write_json('data.json', data) diff --git a/aiida/tools/importexport/archive/migrations/v09_to_v10.py b/aiida/tools/archive/implementations/sqlite/migrations/legacy/v09_to_v10.py similarity index 77% rename from aiida/tools/importexport/archive/migrations/v09_to_v10.py rename to aiida/tools/archive/implementations/sqlite/migrations/legacy/v09_to_v10.py index 08570a0ae7..a005837005 100644 --- a/aiida/tools/importexport/archive/migrations/v09_to_v10.py +++ b/aiida/tools/archive/implementations/sqlite/migrations/legacy/v09_to_v10.py @@ -8,24 +8,18 @@ # For further information please visit http://www.aiida.net # ########################################################################### """Migration from v0.9 to v0.10, used by `verdi export migrate` command.""" -# pylint: disable=invalid-name -from aiida.tools.importexport.archive.common import CacheFolder +# pylint: disable=invalid-name,unused-argument +from ..utils import update_metadata, verify_metadata_version # pylint: disable=no-name-in-module -from .utils import update_metadata, verify_metadata_version # 
pylint: disable=no-name-in-module - -def migrate_v9_to_v10(folder: CacheFolder): +def migrate_v9_to_v10(metadata: dict, data: dict) -> None: """Migration of archive files from v0.9 to v0.10.""" old_version = '0.9' new_version = '0.10' - _, metadata = folder.load_json('metadata.json') - verify_metadata_version(metadata, old_version) update_metadata(metadata, new_version) metadata['all_fields_info']['Node']['attributes'] = {'convert_type': 'jsonb'} metadata['all_fields_info']['Node']['extras'] = {'convert_type': 'jsonb'} metadata['all_fields_info']['Group']['extras'] = {'convert_type': 'jsonb'} - - folder.write_json('metadata.json', metadata) diff --git a/aiida/tools/importexport/archive/migrations/v11_to_v12.py b/aiida/tools/archive/implementations/sqlite/migrations/legacy/v10_to_v11.py similarity index 66% rename from aiida/tools/importexport/archive/migrations/v11_to_v12.py rename to aiida/tools/archive/implementations/sqlite/migrations/legacy/v10_to_v11.py index cd10fe266b..011a83d761 100644 --- a/aiida/tools/importexport/archive/migrations/v11_to_v12.py +++ b/aiida/tools/archive/implementations/sqlite/migrations/legacy/v10_to_v11.py @@ -7,27 +7,21 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -"""Migration from v0.11 to v0.12, used by ``verdi archive migrate`` command. +"""Migration from v0.10 to v0.11, used by ``verdi archive migrate`` command. This migration applies the name change of the ``Computer`` attribute ``name`` to ``label``. """ -from aiida.tools.importexport.archive.common import CacheFolder +from ..utils import update_metadata, verify_metadata_version # pylint: disable=no-name-in-module -from .utils import update_metadata, verify_metadata_version # pylint: disable=no-name-in-module - -def migrate_v11_to_v12(folder: CacheFolder): - """Migration of export files from v0.11 to v0.12.""" - old_version = '0.11' - new_version = '0.12' - - _, metadata = folder.load_json('metadata.json') +def migrate_v10_to_v11(metadata: dict, data: dict) -> None: + """Migration of export files from v0.10 to v0.11.""" + old_version = '0.10' + new_version = '0.11' verify_metadata_version(metadata, old_version) update_metadata(metadata, new_version) - _, data = folder.load_json('data.json') - # Apply migrations for attributes in data.get('export_data', {}).get('Computer', {}).values(): attributes['label'] = attributes.pop('name') @@ -36,6 +30,3 @@ def migrate_v11_to_v12(folder: CacheFolder): metadata['all_fields_info']['Computer']['label'] = metadata['all_fields_info']['Computer'].pop('name') except KeyError: pass - - folder.write_json('metadata.json', metadata) - folder.write_json('data.json', data) diff --git a/aiida/tools/importexport/archive/migrations/v12_to_v13.py b/aiida/tools/archive/implementations/sqlite/migrations/legacy/v11_to_v12.py similarity index 90% rename from aiida/tools/importexport/archive/migrations/v12_to_v13.py rename to aiida/tools/archive/implementations/sqlite/migrations/legacy/v11_to_v12.py index a6763a91f1..8144787a18 100644 --- a/aiida/tools/importexport/archive/migrations/v12_to_v13.py +++ b/aiida/tools/archive/implementations/sqlite/migrations/legacy/v11_to_v12.py @@ -7,13 +7,11 @@ # For further information on the license, see the LICENSE.txt file # # For further information please visit http://www.aiida.net # ########################################################################### -"""Migration from v0.12 
to v0.13, used by ``verdi archive migrate`` command. +"""Migration from v0.11 to v0.12, used by ``verdi archive migrate`` command. This migration is necessary after the `core.` prefix was added to entry points shipped with `aiida-core`. """ -from aiida.tools.importexport.archive.common import CacheFolder - -from .utils import update_metadata, verify_metadata_version # pylint: disable=no-name-in-module +from ..utils import update_metadata, verify_metadata_version # pylint: disable=no-name-in-module MAPPING_DATA = { 'data.array.ArrayData.': 'data.core.array.ArrayData.', @@ -67,19 +65,15 @@ } -def migrate_v12_to_v13(folder: CacheFolder): - """Migration of export files from v0.12 to v0.13.""" +def migrate_v11_to_v12(metadata: dict, data: dict) -> None: + """Migration of export files from v0.11 to v0.12.""" # pylint: disable=too-many-branches - old_version = '0.12' - new_version = '0.13' - - _, metadata = folder.load_json('metadata.json') + old_version = '0.11' + new_version = '0.12' verify_metadata_version(metadata, old_version) update_metadata(metadata, new_version) - _, data = folder.load_json('data.json') - # Migrate data entry point names for values in data.get('export_data', {}).get('Node', {}).values(): if 'node_type' in values and values['node_type'].startswith('data.'): @@ -126,6 +120,3 @@ def migrate_v12_to_v13(folder: CacheFolder): pass else: values['scheduler_type'] = new_scheduler_type - - folder.write_json('metadata.json', metadata) - folder.write_json('data.json', data) diff --git a/aiida/tools/archive/implementations/sqlite/migrations/legacy_to_new.py b/aiida/tools/archive/implementations/sqlite/migrations/legacy_to_new.py new file mode 100644 index 0000000000..c2dc6db5ba --- /dev/null +++ b/aiida/tools/archive/implementations/sqlite/migrations/legacy_to_new.py @@ -0,0 +1,275 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""Migration from legacy JSON format.""" +from datetime import datetime +from hashlib import sha256 +import json +from pathlib import Path, PurePosixPath +import shutil +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union + +from archive_path import TarPath, ZipPath +from sqlalchemy import insert, select +from sqlalchemy.exc import IntegrityError + +from aiida.common.hashing import chunked_file_hash +from aiida.common.progress_reporter import get_progress_reporter +from aiida.repository.common import File, FileType +from aiida.tools.archive.common import MIGRATE_LOGGER, batch_iter +from aiida.tools.archive.exceptions import CorruptArchive, MigrationValidationError + +from . 
import v1_db_schema as db +from ..common import DB_FILENAME, META_FILENAME, REPO_FOLDER, create_sqla_engine +from .utils import update_metadata + +_NODE_ENTITY_NAME = 'Node' +_GROUP_ENTITY_NAME = 'Group' +_COMPUTER_ENTITY_NAME = 'Computer' +_USER_ENTITY_NAME = 'User' +_LOG_ENTITY_NAME = 'Log' +_COMMENT_ENTITY_NAME = 'Comment' + +file_fields_to_model_fields: Dict[str, Dict[str, str]] = { + _NODE_ENTITY_NAME: { + 'dbcomputer': 'dbcomputer_id', + 'user': 'user_id' + }, + _GROUP_ENTITY_NAME: { + 'user': 'user_id' + }, + _COMPUTER_ENTITY_NAME: {}, + _LOG_ENTITY_NAME: { + 'dbnode': 'dbnode_id' + }, + _COMMENT_ENTITY_NAME: { + 'dbnode': 'dbnode_id', + 'user': 'user_id' + } +} + +aiida_orm_to_backend = { + _USER_ENTITY_NAME: db.DbUser, + _GROUP_ENTITY_NAME: db.DbGroup, + _NODE_ENTITY_NAME: db.DbNode, + _COMMENT_ENTITY_NAME: db.DbComment, + _COMPUTER_ENTITY_NAME: db.DbComputer, + _LOG_ENTITY_NAME: db.DbLog, +} + + +def perform_v1_migration( # pylint: disable=too-many-locals + inpath: Path, working: Path, archive_name: str, is_tar: bool, metadata: dict, data: dict, compression: int +) -> str: + """Perform the repository and JSON to SQLite migration. + + 1. Iterate though the repository paths in the archive + 2. If a file, hash its contents and, if not already present, stream it to the new archive + 3. Store a mapping of the node UUIDs to a list of (path, hashkey or None if a directory) tuples + + :param inpath: the input path to the old archive + :param metadata: the metadata to migrate + :param data: the data to migrate + """ + # to-do streaming tar files is a lot slower than streaming zip files, + # it would be most performant to extract the entire archive to disk first and then stream it + + MIGRATE_LOGGER.report('Initialising new archive...') + node_repos: Dict[str, List[Tuple[str, Optional[str]]]] = {} + central_dir: Dict[str, Any] = {} + in_archive: Union[TarPath, ZipPath] = TarPath(inpath) if is_tar else ZipPath(inpath) + with ZipPath( + working / archive_name, + mode='w', + compresslevel=compression, + name_to_info=central_dir, + info_order=(META_FILENAME, DB_FILENAME) + ) as new_path: + with in_archive as path: + length = sum(1 for _ in path.glob('**/*')) + with get_progress_reporter()(desc='Converting repo', total=length) as progress: + for subpath in path.glob('**/*'): + progress.update() + parts = subpath.parts + # repository file are stored in the legacy archive as `nodes/uuid[0:2]/uuid[2:4]/uuid[4:]/path/...` + if len(parts) < 6 or parts[0] != 'nodes' or parts[4] not in ('raw_input', 'path'): + continue + uuid = ''.join(parts[1:4]) + posix_rel = PurePosixPath(*parts[5:]) + hashkey = None + if subpath.is_file(): + with subpath.open('rb') as handle: + hashkey = chunked_file_hash(handle, sha256) + if f'{REPO_FOLDER}/{hashkey}' not in central_dir: + with subpath.open('rb') as handle: + with (new_path / f'{REPO_FOLDER}/{hashkey}').open(mode='wb') as handle2: + shutil.copyfileobj(handle, handle2) + node_repos.setdefault(uuid, []).append((posix_rel.as_posix(), hashkey)) + MIGRATE_LOGGER.report(f'Unique files written: {len(central_dir)}') + + _json_to_sqlite(working / DB_FILENAME, data, node_repos) + + MIGRATE_LOGGER.report('Finalising archive') + with (working / DB_FILENAME).open('rb') as handle: + with (new_path / DB_FILENAME).open(mode='wb') as handle2: + shutil.copyfileobj(handle, handle2) + + # remove legacy keys from metadata and store + metadata.pop('unique_identifiers', None) + metadata.pop('all_fields_info', None) + # remove legacy key nesting + metadata['creation_parameters'] = 
metadata.pop('export_parameters', {}) + metadata['compression'] = compression + metadata['key_format'] = 'sha256' + metadata['mtime'] = datetime.now().isoformat() + update_metadata(metadata, '1.0') + (new_path / META_FILENAME).write_text(json.dumps(metadata)) + + return '1.0' + + +def _json_to_sqlite( + outpath: Path, data: dict, node_repos: Dict[str, List[Tuple[str, Optional[str]]]], batch_size: int = 100 +) -> None: + """Convert a JSON archive format to SQLite.""" + MIGRATE_LOGGER.report('Converting DB to SQLite') + + engine = create_sqla_engine(outpath) + db.ArchiveV1Base.metadata.create_all(engine) + + with engine.begin() as connection: + # proceed in order of relationships + for entity_type in ( + _USER_ENTITY_NAME, _COMPUTER_ENTITY_NAME, _GROUP_ENTITY_NAME, _NODE_ENTITY_NAME, _LOG_ENTITY_NAME, + _COMMENT_ENTITY_NAME + ): + if not data['export_data'].get(entity_type, {}): + continue + length = len(data['export_data'].get(entity_type, {})) + backend_cls = aiida_orm_to_backend[entity_type] + with get_progress_reporter()(desc=f'Adding {entity_type}s', total=length) as progress: + for nrows, rows in batch_iter(_iter_entity_fields(data, entity_type, node_repos), batch_size): + # to-do check for unused keys? + try: + connection.execute(insert(backend_cls.__table__), rows) # type: ignore + except IntegrityError as exc: + raise MigrationValidationError(f'Database integrity error: {exc}') from exc + progress.update(nrows) + + if not (data['groups_uuid'] or data['links_uuid']): + return None + + with engine.begin() as connection: + + # get mapping of node IDs to node UUIDs + node_uuid_map = {uuid: pk for uuid, pk in connection.execute(select(db.DbNode.uuid, db.DbNode.id))} # pylint: disable=unnecessary-comprehension + + # links + if data['links_uuid']: + + def _transform_link(link_row): + return { + 'input_id': node_uuid_map[link_row['input']], + 'output_id': node_uuid_map[link_row['output']], + 'label': link_row['label'], + 'type': link_row['type'] + } + + with get_progress_reporter()(desc='Adding Links', total=len(data['links_uuid'])) as progress: + for nrows, rows in batch_iter(data['links_uuid'], batch_size, transform=_transform_link): + connection.execute(insert(db.DbLink.__table__), rows) + progress.update(nrows) + + # groups to nodes + if data['groups_uuid']: + # get mapping of node IDs to node UUIDs + group_uuid_map = {uuid: pk for uuid, pk in connection.execute(select(db.DbGroup.uuid, db.DbGroup.id))} # pylint: disable=unnecessary-comprehension + length = sum(len(uuids) for uuids in data['groups_uuid'].values()) + with get_progress_reporter()(desc='Adding Group-Nodes', total=length) as progress: + for group_uuid, node_uuids in data['groups_uuid'].items(): + group_id = group_uuid_map[group_uuid] + connection.execute( + insert(db.DbGroupNodes.__table__), [{ + 'dbnode_id': node_uuid_map[uuid], + 'dbgroup_id': group_id + } for uuid in node_uuids] + ) + progress.update(len(node_uuids)) + + +def _convert_datetime(key, value): + if key in ('time', 'ctime', 'mtime'): + return datetime.strptime(value, '%Y-%m-%dT%H:%M:%S.%f') + return value + + +def _iter_entity_fields( + data, + name: str, + node_repos: Dict[str, List[Tuple[str, Optional[str]]]], +) -> Iterator[Dict[str, Any]]: + """Iterate through entity fields.""" + keys = file_fields_to_model_fields.get(name, {}) + if name == _NODE_ENTITY_NAME: + # here we merge in the attributes and extras before yielding + attributes = data.get('node_attributes', {}) + extras = data.get('node_extras', {}) + for pk, all_fields in 
data['export_data'].get(name, {}).items(): + if pk not in attributes: + raise CorruptArchive(f'Unable to find attributes info for Node with Pk={pk}') + if pk not in extras: + raise CorruptArchive(f'Unable to find extra info for Node with Pk={pk}') + uuid = all_fields['uuid'] + repository_metadata = _create_repo_metadata(node_repos[uuid]) if uuid in node_repos else {} + yield { + **{keys.get(key, key): _convert_datetime(key, val) for key, val in all_fields.items()}, + **{ + 'id': pk, + 'attributes': attributes[pk], + 'extras': extras[pk], + 'repository_metadata': repository_metadata + } + } + else: + for pk, all_fields in data['export_data'].get(name, {}).items(): + yield {**{keys.get(key, key): _convert_datetime(key, val) for key, val in all_fields.items()}, **{'id': pk}} + + +def _create_repo_metadata(paths: List[Tuple[str, Optional[str]]]) -> Dict[str, Any]: + """Create the repository metadata. + + :param paths: list of (path, hashkey) tuples + :return: the repository metadata + """ + top_level = File() + for _path, hashkey in paths: + path = PurePosixPath(_path) + if hashkey is None: + _create_directory(top_level, path) + else: + directory = _create_directory(top_level, path.parent) + directory.objects[path.name] = File(path.name, FileType.FILE, hashkey) + return top_level.serialize() + + +def _create_directory(top_level: File, path: PurePosixPath) -> File: + """Create a new directory with the given path. + + :param path: the relative path of the directory. + :return: the created directory. + """ + directory = top_level + + for part in path.parts: + if part not in directory.objects: + directory.objects[part] = File(part) + + directory = directory.objects[part] + + return directory diff --git a/aiida/tools/archive/implementations/sqlite/migrations/main.py b/aiida/tools/archive/implementations/sqlite/migrations/main.py new file mode 100644 index 0000000000..c2328e3213 --- /dev/null +++ b/aiida/tools/archive/implementations/sqlite/migrations/main.py @@ -0,0 +1,187 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""AiiDA archive migrator implementation.""" +from pathlib import Path +import shutil +import tarfile +import tempfile +from typing import Any, Dict, List, Optional, Union +import zipfile + +from archive_path import open_file_in_tar, open_file_in_zip + +from aiida.common import json +from aiida.common.progress_reporter import get_progress_reporter +from aiida.tools.archive.common import MIGRATE_LOGGER +from aiida.tools.archive.exceptions import ArchiveMigrationError, CorruptArchive + +from ..common import copy_tar_to_zip, copy_zip_to_zip +from .legacy import FINAL_LEGACY_VERSION, LEGACY_MIGRATE_FUNCTIONS +from .legacy_to_new import perform_v1_migration + +ALL_VERSIONS = ['0.4', '0.5', '0.6', '0.7', '0.8', '0.9', '0.10', '0.11', '0.12', '1.0'] + + +def migrate( # pylint: disable=too-many-branches,too-many-statements + inpath: Union[str, Path], + outpath: Union[str, Path], + current_version: str, + version: str, + *, + force: bool = False, + compression: int = 6 +) -> None: + """Migrate an archive to a specific version. 
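+
+    For example (an illustrative call; the file names are arbitrary), migrating a legacy ``0.4``
+    archive applies the in-place ``data.json`` migrations up to ``0.12`` and then the conversion
+    to the new SQLite/zip based ``1.0`` format::
+
+        migrate('export_v0.4.aiida', 'migrated.aiida', '0.4', '1.0')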
+ + :param path: archive path + """ + inpath = Path(inpath) + outpath = Path(outpath) + + if outpath.exists() and not force: + raise IOError('Output path already exists and force=False') + if outpath.exists() and not outpath.is_file(): + raise IOError('Existing output path is not a file') + + # check versions are valid + # versions 0.1, 0.2, 0.3 are no longer supported, + # since 0.3 -> 0.4 requires costly migrations of repo files (you would need to unpack all of them) + if current_version in ('0.1', '0.2', '0.3') or version in ('0.1', '0.2', '0.3'): + raise ArchiveMigrationError( + f"Migration from '{current_version}' -> '{version}' is not supported in aiida-core v2" + ) + if current_version not in ALL_VERSIONS: + raise ArchiveMigrationError(f"Unknown current version '{current_version}'") + if version not in ALL_VERSIONS: + raise ArchiveMigrationError(f"Unknown target version '{version}'") + + # if we are already at the desired version, then no migration is required + if current_version == version: + if inpath != outpath: + if outpath.exists() and force: + outpath.unlink() + shutil.copyfile(inpath, outpath) + return + + # the file should be either a tar (legacy only) or zip file + if tarfile.is_tarfile(str(inpath)): + is_tar = True + elif zipfile.is_zipfile(str(inpath)): + is_tar = False + else: + raise CorruptArchive(f'The input file is neither a tar nor a zip file: {inpath}') + + # read the metadata.json which should always be present + metadata = _read_json(inpath, 'metadata.json', is_tar) + # data.json will only be read from legacy archives + data: Optional[Dict[str, Any]] = None + + # if the archive is a "legacy" format, i.e. has a data.json file, migrate to latest one + if current_version in LEGACY_MIGRATE_FUNCTIONS: + MIGRATE_LOGGER.report('Legacy migrations required') + MIGRATE_LOGGER.report('Extracting data.json ...') + # read the data.json file + data = _read_json(inpath, 'data.json', is_tar) + to_version = FINAL_LEGACY_VERSION if version not in LEGACY_MIGRATE_FUNCTIONS else version + current_version = _perform_legacy_migrations(current_version, to_version, metadata, data) + + if current_version == version: + # create new legacy archive with updated metadata & data + def path_callback(inpath, outpath) -> bool: + if inpath.name == 'metadata.json': + outpath.write_text(json.dumps(metadata)) + return True + if inpath.name == 'data.json': + outpath.write_text(json.dumps(data)) + return True + return False + + func = copy_tar_to_zip if is_tar else copy_zip_to_zip + + func( + inpath, + outpath, + path_callback, + overwrite=force, + compression=compression, + title='Writing migrated legacy archive', + info_order=('metadata.json', 'data.json') + ) + return + + with tempfile.TemporaryDirectory() as tmpdirname: + + if current_version == FINAL_LEGACY_VERSION: + MIGRATE_LOGGER.report('aiida-core v1 -> v2 migration required') + if data is None: + MIGRATE_LOGGER.report('Extracting data.json ...') + data = _read_json(inpath, 'data.json', is_tar) + current_version = perform_v1_migration( + inpath, Path(tmpdirname), 'new.zip', is_tar, metadata, data, compression + ) + + if not current_version == version: + raise ArchiveMigrationError(f"Migration from '{current_version}' -> '{version}' failed") + + if outpath.exists() and force: + outpath.unlink() + shutil.move(Path(tmpdirname) / 'new.zip', outpath) # type: ignore + + +def _read_json(inpath: Path, filename: str, is_tar: bool) -> Dict[str, Any]: + """Read a JSON file from the archive.""" + if is_tar: + with open_file_in_tar(inpath, filename) as 
handle: + data = json.load(handle) + else: + with open_file_in_zip(inpath, filename) as handle: + data = json.load(handle) + return data + + +def _perform_legacy_migrations(current_version: str, to_version: str, metadata: dict, data: dict) -> str: + """Perform legacy migrations from the current version to the desired version. + + Legacy archives use the old ``data.json`` format for storing the database. + These migrations simply manipulate the metadata and data in-place. + + :param current_version: current version of the archive + :param to_version: version to migrate to + :param metadata: the metadata to migrate + :param data: the data to migrate + :return: the new version of the archive + """ + # compute the migration pathway + prev_version = current_version + pathway: List[str] = [] + while prev_version != to_version: + if prev_version not in LEGACY_MIGRATE_FUNCTIONS: + raise ArchiveMigrationError(f"No migration pathway available for '{current_version}' to '{to_version}'") + if prev_version in pathway: + raise ArchiveMigrationError( + f'cyclic migration pathway encountered: {" -> ".join(pathway + [prev_version])}' + ) + pathway.append(prev_version) + prev_version = LEGACY_MIGRATE_FUNCTIONS[prev_version][0] + + if not pathway: + MIGRATE_LOGGER.report('No migration required') + return to_version + + MIGRATE_LOGGER.report('Legacy migration pathway: %s', ' -> '.join(pathway + [to_version])) + + with get_progress_reporter()(total=len(pathway), desc='Performing migrations: ') as progress: + for from_version in pathway: + to_version = LEGACY_MIGRATE_FUNCTIONS[from_version][0] + progress.set_description_str(f'Performing migrations: {from_version} -> {to_version}', refresh=True) + LEGACY_MIGRATE_FUNCTIONS[from_version][1](metadata, data) + progress.update() + + return to_version diff --git a/aiida/tools/importexport/archive/migrations/utils.py b/aiida/tools/archive/implementations/sqlite/migrations/utils.py similarity index 98% rename from aiida/tools/importexport/archive/migrations/utils.py rename to aiida/tools/archive/implementations/sqlite/migrations/utils.py index ecdb7b076b..e769de1bd4 100644 --- a/aiida/tools/importexport/archive/migrations/utils.py +++ b/aiida/tools/archive/implementations/sqlite/migrations/utils.py @@ -9,7 +9,7 @@ ########################################################################### """Utility functions for migration of export-files.""" -from aiida.tools.importexport.common import exceptions +from aiida.tools.archive import exceptions def verify_metadata_version(metadata, version=None): diff --git a/aiida/tools/archive/implementations/sqlite/migrations/v1_db_schema.py b/aiida/tools/archive/implementations/sqlite/migrations/v1_db_schema.py new file mode 100644 index 0000000000..8fb14e3c28 --- /dev/null +++ b/aiida/tools/archive/implementations/sqlite/migrations/v1_db_schema.py @@ -0,0 +1,169 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""This is the sqlite DB schema, coresponding to the 34a831f4286d main DB revision. 
+ +For normal operation of the archive, +we auto-generate the schema from the models in ``aiida.backends.sqlalchemy.models``. +However, when migrating an archive from the old format, we require a fixed revision of the schema. + +The only difference between the PostGreSQL schema and SQLite one, +is the replacement of ``JSONB`` with ``JSON``, and ``UUID`` with ``CHAR(36)``. +""" +from sqlalchemy import ForeignKey, orm +from sqlalchemy.dialects.sqlite import JSON +from sqlalchemy.schema import Column, Index, UniqueConstraint +from sqlalchemy.types import CHAR, Boolean, DateTime, Integer, String, Text + +ArchiveV1Base = orm.declarative_base() + + +class DbAuthInfo(ArchiveV1Base): + """Class that keeps the authernification data.""" + + __tablename__ = 'db_dbauthinfo' + __table_args__ = (UniqueConstraint('aiidauser_id', 'dbcomputer_id'),) + + id = Column(Integer, primary_key=True) # pylint: disable=invalid-name + aiidauser_id = Column( + Integer, ForeignKey('db_dbuser.id', ondelete='CASCADE', deferrable=True, initially='DEFERRED') + ) + dbcomputer_id = Column( + Integer, ForeignKey('db_dbcomputer.id', ondelete='CASCADE', deferrable=True, initially='DEFERRED') + ) + _metadata = Column('metadata', JSON) + auth_params = Column(JSON) + enabled = Column(Boolean, default=True) + + +class DbComment(ArchiveV1Base): + """Class to store comments.""" + + __tablename__ = 'db_dbcomment' + + id = Column(Integer, primary_key=True) # pylint: disable=invalid-name + uuid = Column(CHAR(36), unique=True) + dbnode_id = Column(Integer, ForeignKey('db_dbnode.id', ondelete='CASCADE', deferrable=True, initially='DEFERRED')) + ctime = Column(DateTime(timezone=True)) + mtime = Column(DateTime(timezone=True)) + user_id = Column(Integer, ForeignKey('db_dbuser.id', ondelete='CASCADE', deferrable=True, initially='DEFERRED')) + content = Column(Text, nullable=True) + + +class DbComputer(ArchiveV1Base): + """Class to store computers.""" + __tablename__ = 'db_dbcomputer' + + id = Column(Integer, primary_key=True) # pylint: disable=invalid-name + uuid = Column(CHAR(36), unique=True) + label = Column(String(255), unique=True, nullable=False) + hostname = Column(String(255)) + description = Column(Text, nullable=True) + scheduler_type = Column(String(255)) + transport_type = Column(String(255)) + _metadata = Column('metadata', JSON) + + +class DbGroupNodes(ArchiveV1Base): + """Class to store join table for group -> nodes.""" + + __tablename__ = 'db_dbgroup_dbnodes' + __table_args__ = (UniqueConstraint('dbgroup_id', 'dbnode_id', name='db_dbgroup_dbnodes_dbgroup_id_dbnode_id_key'),) + + id = Column(Integer, primary_key=True) # pylint: disable=invalid-name + dbnode_id = Column(Integer, ForeignKey('db_dbnode.id', deferrable=True, initially='DEFERRED')) + dbgroup_id = Column(Integer, ForeignKey('db_dbgroup.id', deferrable=True, initially='DEFERRED')) + + +class DbGroup(ArchiveV1Base): + """Class to store groups.""" + + __tablename__ = 'db_dbgroup' + __table_args__ = (UniqueConstraint('label', 'type_string'),) + + id = Column(Integer, primary_key=True) # pylint: disable=invalid-name + uuid = Column(CHAR(36), unique=True) + label = Column(String(255), index=True) + type_string = Column(String(255), default='', index=True) + time = Column(DateTime(timezone=True)) + description = Column(Text, nullable=True) + extras = Column(JSON, default=dict, nullable=False) + user_id = Column(Integer, ForeignKey('db_dbuser.id', ondelete='CASCADE', deferrable=True, initially='DEFERRED')) + + Index('db_dbgroup_dbnodes_dbnode_id_idx', 
DbGroupNodes.dbnode_id) + Index('db_dbgroup_dbnodes_dbgroup_id_idx', DbGroupNodes.dbgroup_id) + + +class DbLog(ArchiveV1Base): + """Class to store logs.""" + + __tablename__ = 'db_dblog' + + id = Column(Integer, primary_key=True) # pylint: disable=invalid-name + uuid = Column(CHAR(36), unique=True) + time = Column(DateTime(timezone=True)) + loggername = Column(String(255), index=True) + levelname = Column(String(255), index=True) + dbnode_id = Column( + Integer, ForeignKey('db_dbnode.id', deferrable=True, initially='DEFERRED', ondelete='CASCADE'), nullable=False + ) + message = Column(Text(), nullable=True) + _metadata = Column('metadata', JSON) + + +class DbNode(ArchiveV1Base): + """Class to store nodes.""" + + __tablename__ = 'db_dbnode' + + id = Column(Integer, primary_key=True) # pylint: disable=invalid-name + uuid = Column(CHAR(36), unique=True) + node_type = Column(String(255), index=True) + process_type = Column(String(255), index=True) + label = Column(String(255), index=True, nullable=True, default='') + description = Column(Text(), nullable=True, default='') + ctime = Column(DateTime(timezone=True)) + mtime = Column(DateTime(timezone=True)) + attributes = Column(JSON) + extras = Column(JSON) + repository_metadata = Column(JSON, nullable=False, default=dict, server_default='{}') + dbcomputer_id = Column( + Integer, + ForeignKey('db_dbcomputer.id', deferrable=True, initially='DEFERRED', ondelete='RESTRICT'), + nullable=True + ) + user_id = Column( + Integer, ForeignKey('db_dbuser.id', deferrable=True, initially='DEFERRED', ondelete='restrict'), nullable=False + ) + + +class DbLink(ArchiveV1Base): + """Class to store links between nodes.""" + + __tablename__ = 'db_dblink' + + id = Column(Integer, primary_key=True) # pylint: disable=invalid-name + input_id = Column(Integer, ForeignKey('db_dbnode.id', deferrable=True, initially='DEFERRED'), index=True) + output_id = Column( + Integer, ForeignKey('db_dbnode.id', ondelete='CASCADE', deferrable=True, initially='DEFERRED'), index=True + ) + label = Column(String(255), index=True, nullable=False) + type = Column(String(255), index=True) + + +class DbUser(ArchiveV1Base): + """Class to store users.""" + + __tablename__ = 'db_dbuser' + + id = Column(Integer, primary_key=True) # pylint: disable=invalid-name + email = Column(String(254), unique=True, index=True) + first_name = Column(String(254), nullable=True) + last_name = Column(String(254), nullable=True) + institution = Column(String(254), nullable=True) diff --git a/aiida/tools/archive/implementations/sqlite/reader.py b/aiida/tools/archive/implementations/sqlite/reader.py new file mode 100644 index 0000000000..169312316d --- /dev/null +++ b/aiida/tools/archive/implementations/sqlite/reader.py @@ -0,0 +1,116 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""AiiDA archive reader implementation.""" +import json +from pathlib import Path +import shutil +import tarfile +import tempfile +from typing import Any, Dict, Optional, Union +import zipfile + +from archive_path import extract_file_in_zip, read_file_in_tar, read_file_in_zip +from sqlalchemy import orm + +from aiida.tools.archive.abstract import ArchiveReaderAbstract +from aiida.tools.archive.exceptions import CorruptArchive, UnreadableArchiveError + +from . import backend as db +from .common import DB_FILENAME, META_FILENAME, create_sqla_engine + + +class ArchiveReaderSqlZip(ArchiveReaderAbstract): + """An archive reader for the SQLite format.""" + + def __init__(self, path: Union[str, Path], **kwargs: Any): + super().__init__(path, **kwargs) + self._in_context = False + # we lazily create the temp dir / session when needed, then clean up on exit + self._temp_dir: Optional[Path] = None + self._backend: Optional[db.ArchiveReadOnlyBackend] = None + + def __enter__(self) -> 'ArchiveReaderSqlZip': + self._in_context = True + return self + + def __exit__(self, *args, **kwargs) -> None: + """Finalise the archive.""" + super().__exit__(*args, **kwargs) + if self._backend: + self._backend.close() + self._backend = None + if self._temp_dir: + shutil.rmtree(self._temp_dir, ignore_errors=False) + self._temp_dir = None + self._in_context = False + + def get_metadata(self) -> Dict[str, Any]: + try: + return extract_metadata(self.path) + except Exception as exc: + raise CorruptArchive('metadata could not be read') from exc + + def get_backend(self) -> db.ArchiveReadOnlyBackend: + if not self._in_context: + raise AssertionError('Not in context') + if self._backend is not None: + return self._backend + if not self._temp_dir: + # create the work folder + self._temp_dir = Path(tempfile.mkdtemp()) + db_file = self._temp_dir / DB_FILENAME + if not db_file.exists(): + # extract the database to the work folder + with db_file.open('wb') as handle: + try: + extract_file_in_zip(self.path, DB_FILENAME, handle, search_limit=4) + except Exception as exc: + raise CorruptArchive(f'database could not be read: {exc}') from exc + engine = create_sqla_engine(db_file) + self._backend = db.ArchiveReadOnlyBackend(self.path, orm.Session(engine)) + return self._backend + + +def extract_metadata(path: Union[str, Path], search_limit: Optional[int] = 10) -> Dict[str, Any]: + """Extract the metadata dictionary from the archive""" + # we fail if not one of the first record in central directory (as expected) + # so we don't have to iter all repo files to fail + return json.loads(read_file_in_zip(path, META_FILENAME, 'utf8', search_limit=search_limit)) + + +def read_version(path: Union[str, Path]) -> str: + """Read the version of the archive from the file. + + Intended to work for all versions of the archive format. 
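For example, a quick probe of an archive's version could look like the following sketch (the file name is a placeholder):

from aiida.tools.archive.implementations.sqlite.reader import read_version

version = read_version('my_archive.aiida')  # handles both legacy tar/zip archives and the new sqlite-zip format
print(version)  # e.g. '0.9' for an old legacy archive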
+ + :param path: archive path + + :raises: ``FileNotFoundError`` if the file does not exist + :raises: ``UnreadableArchiveError`` if a version cannot be read from the archive + """ + path = Path(path) + if not path.is_file(): + raise FileNotFoundError('archive file not found') + # check the file is at least a zip or tar file + if zipfile.is_zipfile(path): + try: + metadata = extract_metadata(path, search_limit=None) + except Exception as exc: + raise UnreadableArchiveError(f'Could not read metadata for version: {exc}') from exc + elif tarfile.is_tarfile(path): + try: + metadata = json.loads(read_file_in_tar(path, META_FILENAME)) + except Exception as exc: + raise UnreadableArchiveError(f'Could not read metadata for version: {exc}') from exc + else: + raise UnreadableArchiveError('Not a zip or tar file') + if 'export_version' in metadata: + return metadata['export_version'] + raise UnreadableArchiveError("Metadata does not contain 'export_version' key") diff --git a/aiida/tools/archive/implementations/sqlite/writer.py b/aiida/tools/archive/implementations/sqlite/writer.py new file mode 100644 index 0000000000..0a31bcf471 --- /dev/null +++ b/aiida/tools/archive/implementations/sqlite/writer.py @@ -0,0 +1,313 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. # +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +"""AiiDA archive writer implementation.""" +from datetime import datetime +import functools +import hashlib +from io import BytesIO +import json +from pathlib import Path +import shutil +import tempfile +from typing import Any, BinaryIO, Dict, List, Optional, Set, Union +import zipfile + +from archive_path import NOTSET, ZipPath, extract_file_in_zip, read_file_in_zip +from sqlalchemy import insert, inspect +from sqlalchemy.exc import IntegrityError as SqlaIntegrityError +from sqlalchemy.future.engine import Connection + +from aiida import get_version +from aiida.common.exceptions import IntegrityError +from aiida.common.hashing import chunked_file_hash +from aiida.common.progress_reporter import get_progress_reporter +from aiida.orm.entities import EntityTypes +from aiida.tools.archive.abstract import ArchiveFormatAbstract, ArchiveWriterAbstract +from aiida.tools.archive.exceptions import CorruptArchive, IncompatibleArchiveVersionError + +from . 
import backend as db +from .common import DB_FILENAME, META_FILENAME, REPO_FOLDER, create_sqla_engine + +try: + from typing import Literal # pylint: disable=ungrouped-imports +except ImportError: + # Python <3.8 backport + from typing_extensions import Literal # type: ignore + + +@functools.lru_cache(maxsize=10) +def _get_model_from_entity(entity_type: EntityTypes): + """Return the Sqlalchemy model and column names corresponding to the given entity.""" + model = { + EntityTypes.USER: db.DbUser, + EntityTypes.AUTHINFO: db.DbAuthInfo, + EntityTypes.GROUP: db.DbGroup, + EntityTypes.NODE: db.DbNode, + EntityTypes.COMMENT: db.DbComment, + EntityTypes.COMPUTER: db.DbComputer, + EntityTypes.LOG: db.DbLog, + EntityTypes.LINK: db.DbLink, + EntityTypes.GROUP_NODE: db.DbGroupNodes + }[entity_type] + mapper = inspect(model).mapper + column_names = {col.name for col in mapper.c.values()} + return model, column_names + + +class ArchiveWriterSqlZip(ArchiveWriterAbstract): + """AiiDA archive writer implementation.""" + + meta_name = META_FILENAME + db_name = DB_FILENAME + + def __init__( + self, + path: Union[str, Path], + fmt: ArchiveFormatAbstract, + *, + mode: Literal['x', 'w', 'a'] = 'x', + compression: int = 6, + work_dir: Optional[Path] = None, + _debug: bool = False, + _enforce_foreign_keys: bool = True, + ): + super().__init__(path, fmt, mode=mode, compression=compression) + self._init_work_dir = work_dir + self._in_context = False + self._enforce_foreign_keys = _enforce_foreign_keys + self._debug = _debug + self._metadata: Dict[str, Any] = {} + self._central_dir: Dict[str, Any] = {} + self._deleted_paths: Set[str] = set() + self._zip_path: Optional[ZipPath] = None + self._work_dir: Optional[Path] = None + self._conn: Optional[Connection] = None + + def _assert_in_context(self): + if not self._in_context: + raise AssertionError('Not in context') + + def __enter__(self) -> 'ArchiveWriterSqlZip': + """Start writing to the archive""" + self._metadata = { + 'export_version': self._format.latest_version, + 'aiida_version': get_version(), + 'key_format': 'sha256', + 'compression': self._compression, + } + self._work_dir = Path(tempfile.mkdtemp()) if self._init_work_dir is None else Path(self._init_work_dir) + self._central_dir = {} + self._zip_path = ZipPath( + self._path, + mode=self._mode, + compression=zipfile.ZIP_DEFLATED if self._compression else zipfile.ZIP_STORED, + compresslevel=self._compression, + info_order=(self.meta_name, self.db_name), + name_to_info=self._central_dir, + ) + engine = create_sqla_engine( + self._work_dir / self.db_name, enforce_foreign_keys=self._enforce_foreign_keys, echo=self._debug + ) + db.ArchiveDbBase.metadata.create_all(engine) + self._conn = engine.connect() + self._in_context = True + return self + + def __exit__(self, *args, **kwargs): + """Finalise the archive""" + if self._conn: + self._conn.commit() + self._conn.close() + assert self._work_dir is not None + with (self._work_dir / self.db_name).open('rb') as handle: + self._stream_binary(self.db_name, handle) + self._stream_binary( + self.meta_name, + BytesIO(json.dumps(self._metadata).encode('utf8')), + compression=0, # the metadata is small, so no benefit for compression + ) + if self._zip_path: + self._zip_path.close() + self._central_dir = {} + if self._work_dir is not None and self._init_work_dir is None: + shutil.rmtree(self._work_dir, ignore_errors=True) + self._zip_path = self._work_dir = self._conn = None + self._in_context = False + + def update_metadata(self, data: Dict[str, Any], overwrite: bool = 
False) -> None: + if not overwrite and set(self._metadata).intersection(set(data)): + raise ValueError(f'Cannot overwrite existing keys: {set(self._metadata).intersection(set(data))}') + self._metadata.update(data) + + def bulk_insert( + self, + entity_type: EntityTypes, + rows: List[Dict[str, Any]], + allow_defaults: bool = False, + ) -> None: + if not rows: + return + self._assert_in_context() + assert self._conn is not None + model, col_keys = _get_model_from_entity(entity_type) + if allow_defaults: + for row in rows: + if not col_keys.issuperset(row): + raise IntegrityError( + f'Incorrect fields given for {entity_type}: {set(row)} not subset of {col_keys}' + ) + else: + for row in rows: + if set(row) != col_keys: + raise IntegrityError(f'Incorrect fields given for {entity_type}: {set(row)} != {col_keys}') + try: + self._conn.execute(insert(model.__table__), rows) + except SqlaIntegrityError as exc: + raise IntegrityError(f'Inserting {entity_type}: {exc}') from exc + + def _stream_binary( + self, + name: str, + handle: BinaryIO, + *, + buffer_size: Optional[int] = None, + compression: Optional[int] = None, + comment: Optional[bytes] = None, + ) -> None: + """Add a binary stream to the archive. + + :param buffer_size: Number of bytes to buffer + :param compression: Override global compression level + :param comment: A binary meta comment about the object + """ + self._assert_in_context() + assert self._zip_path is not None + kwargs: Dict[str, Any] = {'comment': NOTSET if comment is None else comment} + if compression is not None: + kwargs['compression'] = zipfile.ZIP_DEFLATED if compression else zipfile.ZIP_STORED + kwargs['level'] = compression + with self._zip_path.joinpath(name).open('wb', **kwargs) as zip_handle: + if buffer_size is None: + shutil.copyfileobj(handle, zip_handle) + else: + shutil.copyfileobj(handle, zip_handle, length=buffer_size) + + def put_object(self, stream: BinaryIO, *, buffer_size: Optional[int] = None, key: Optional[str] = None) -> str: + if key is None: + key = chunked_file_hash(stream, hashlib.sha256) + stream.seek(0) + if f'{REPO_FOLDER}/{key}' not in self._central_dir: + self._stream_binary(f'{REPO_FOLDER}/{key}', stream, buffer_size=buffer_size) + return key + + def delete_object(self, key: str) -> None: + raise IOError(f'Cannot delete objects in {self._mode!r} mode') + + +class ArchiveAppenderSqlZip(ArchiveWriterSqlZip): + """AiiDA archive appender implementation.""" + + def delete_object(self, key: str) -> None: + self._assert_in_context() + if f'{REPO_FOLDER}/{key}' in self._central_dir: + raise IOError(f'Cannot delete object {key!r} that has been added in the same append context') + self._deleted_paths.add(f'{REPO_FOLDER}/{key}') + + def __enter__(self) -> 'ArchiveAppenderSqlZip': + """Start appending to the archive""" + # the file should already exist + if not self._path.exists(): + raise FileNotFoundError(f'Archive {self._path} does not exist') + # the file should be an archive with the correct version + version = self._format.read_version(self._path) + if not version == self._format.latest_version: + raise IncompatibleArchiveVersionError( + f'Archive is version {version!r} but expected {self._format.latest_version!r}' + ) + # load the metadata + self._metadata = json.loads(read_file_in_zip(self._path, META_FILENAME, 'utf8', search_limit=4)) + # overwrite metadata + self._metadata['mtime'] = datetime.now().isoformat() + self._metadata['compression'] = self._compression + # create the work folder + self._work_dir = Path(tempfile.mkdtemp()) if 
self._init_work_dir is None else Path(self._init_work_dir) + # create a new zip file in the work folder + self._central_dir = {} + self._deleted_paths = set() + self._zip_path = ZipPath( + self._work_dir / 'archive.zip', + mode='w', + compression=zipfile.ZIP_DEFLATED if self._compression else zipfile.ZIP_STORED, + compresslevel=self._compression, + info_order=(self.meta_name, self.db_name), + name_to_info=self._central_dir, + ) + # extract the database to the work folder + db_file = self._work_dir / self.db_name + with db_file.open('wb') as handle: + try: + extract_file_in_zip(self.path, DB_FILENAME, handle, search_limit=4) + except Exception as exc: + raise CorruptArchive(f'database could not be read: {exc}') from exc + # open a connection to the database + engine = create_sqla_engine( + self._work_dir / self.db_name, enforce_foreign_keys=self._enforce_foreign_keys, echo=self._debug + ) + # to-do could check that the database has correct schema: + # https://docs.sqlalchemy.org/en/14/core/reflection.html#reflecting-all-tables-at-once + self._conn = engine.connect() + self._in_context = True + return self + + def __exit__(self, *args, **kwargs): + """Finalise the archive""" + if self._conn: + self._conn.commit() + self._conn.close() + assert self._work_dir is not None + # write the database and metadata to the new archive + with (self._work_dir / self.db_name).open('rb') as handle: + self._stream_binary(self.db_name, handle) + self._stream_binary( + self.meta_name, + BytesIO(json.dumps(self._metadata).encode('utf8')), + compression=0, + ) + # finalise the new archive + self._copy_old_zip_files() + if self._zip_path is not None: + self._zip_path.close() + self._central_dir = {} + self._deleted_paths = set() + # now move it to the original location + self._path.unlink() + shutil.move(self._work_dir / 'archive.zip', self._path) # type: ignore[arg-type] + if self._init_work_dir is None: + shutil.rmtree(self._work_dir, ignore_errors=True) + self._zip_path = self._work_dir = self._conn = None + self._in_context = False + + def _copy_old_zip_files(self): + """Copy the old archive content to the new one (omitting any amended or deleted files)""" + assert self._zip_path is not None + with ZipPath(self._path, mode='r') as old_archive: + length = sum(1 for _ in old_archive.glob('**/*', include_virtual=False)) + with get_progress_reporter()(desc='Writing amended archive', total=length) as progress: + for subpath in old_archive.glob('**/*', include_virtual=False): + if subpath.at in self._central_dir or subpath.at in self._deleted_paths: + continue + new_path_sub = self._zip_path.joinpath(subpath.at) + if subpath.is_dir(): + new_path_sub.mkdir(exist_ok=True) + else: + with subpath.open('rb') as handle: + with new_path_sub.open('wb') as new_handle: + shutil.copyfileobj(handle, new_handle) + progress.update() diff --git a/aiida/tools/archive/imports.py b/aiida/tools/archive/imports.py new file mode 100644 index 0000000000..5d09d01d81 --- /dev/null +++ b/aiida/tools/archive/imports.py @@ -0,0 +1,1122 @@ +# -*- coding: utf-8 -*- +########################################################################### +# Copyright (c), The AiiDA team. All rights reserved. # +# This file is part of the AiiDA code. 
# +# # +# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # +# For further information on the license, see the LICENSE.txt file # +# For further information please visit http://www.aiida.net # +########################################################################### +# pylint: disable=too-many-branches,too-many-lines,too-many-locals,too-many-statements +"""Import an archive.""" +from pathlib import Path +from typing import Callable, Dict, Optional, Set, Tuple, Union + +from tabulate import tabulate + +from aiida import orm +from aiida.common import timezone +from aiida.common.lang import type_check +from aiida.common.links import LinkType +from aiida.common.log import AIIDA_LOGGER +from aiida.common.progress_reporter import get_progress_reporter +from aiida.manage.manager import get_manager +from aiida.orm.entities import EntityTypes +from aiida.orm.implementation import Backend +from aiida.orm.querybuilder import QueryBuilder +from aiida.repository import Repository + +from .abstract import ArchiveFormatAbstract +from .common import batch_iter, entity_type_to_orm +from .exceptions import ImportTestRun, ImportUniquenessError, ImportValidationError, IncompatibleArchiveVersionError +from .implementations.sqlite import ArchiveFormatSqlZip + +try: + from typing import Literal +except ImportError: + from typing_extensions import Literal # type: ignore + +__all__ = ('IMPORT_LOGGER', 'import_archive') + +IMPORT_LOGGER = AIIDA_LOGGER.getChild('export') + +MergeExtrasType = Tuple[Literal['k', 'n'], Literal['c', 'n'], Literal['l', 'u', 'd']] +MergeExtraDescs = ({ + 'k': '(k)eep', + 'n': 'do (n)ot keep' +}, { + 'c': '(c)reate', + 'n': 'do (n)ot create' +}, { + 'l': '(l)eave existing', + 'u': '(u)pdate with new', + 'd': '(d)elete' +}) +MergeCommentsType = Literal['leave', 'newest', 'overwrite'] + +DUPLICATE_LABEL_MAX = 100 +DUPLICATE_LABEL_TEMPLATE = '{0} (Imported #{1})' + + +def import_archive( + path: Union[str, Path], + *, + archive_format: Optional[ArchiveFormatAbstract] = None, + batch_size: int = 1000, + import_new_extras: bool = True, + merge_extras: MergeExtrasType = ('k', 'n', 'l'), + merge_comments: MergeCommentsType = 'leave', + include_authinfos: bool = False, + group: Optional[orm.Group] = None, + test_run: bool = False, + backend: Optional[Backend] = None, +) -> Optional[int]: + """Import an archive into the AiiDA backend. + + :param path: the path to the archive + :param archive_format: The class for interacting with the archive + :param batch_size: Batch size for streaming database rows + :param import_new_extras: Keep extras on new nodes (except private aiida keys), else strip + :param merge_extras: Rules for merging extras into existing nodes. + The first letter acts on extras that are present in the original node and not present in the imported node. + Can be either: + 'k' (keep it) or + 'n' (do not keep it). + The second letter acts on the imported extras that are not present in the original node. + Can be either: + 'c' (create it) or + 'n' (do not create it). + The third letter defines what to do in case of a name collision. + Can be either: + 'l' (leave the old value), + 'u' (update with a new value), + 'd' (delete the extra) + :param group: Group wherein all imported Nodes will be placed. + If None, one will be auto-generated. + :param test_run: if True, do not write to file + :param backend: the backend to import to. If not specified, the default backend is used. 
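A typical call, shown as an illustrative sketch (the archive file name is a placeholder):

from aiida.tools.archive.imports import import_archive

# keep existing extras, do not add archive-only extras, leave values untouched on collision (the defaults)
group_pk = import_archive('calculations.aiida', merge_extras=('k', 'n', 'l'), merge_comments='leave')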
+
+    :returns: Primary Key of the import Group
+
+    :raises `~aiida.tools.archive.exceptions.IncompatibleArchiveVersionError`: if the provided archive's
+        version is not equal to the version of AiiDA at the moment of import.
+    :raises `~aiida.tools.archive.exceptions.ImportValidationError`: if parameters or the contents of
+        the archive are invalid.
+    :raises `~aiida.tools.archive.exceptions.CorruptArchive`: if the provided archive cannot be read.
+    :raises `~aiida.tools.archive.exceptions.ImportUniquenessError`: if a new unique entity can not be created.
+    """
+    archive_format = archive_format or ArchiveFormatSqlZip()
+    type_check(path, (str, Path))
+    type_check(archive_format, ArchiveFormatAbstract)
+    type_check(batch_size, int)
+    type_check(import_new_extras, bool)
+    type_check(merge_extras, tuple)
+    if len(merge_extras) != 3:
+        raise ValueError('merge_extras not of length 3')
+    if not (merge_extras[0] in ['k', 'n'] and merge_extras[1] in ['c', 'n'] and merge_extras[2] in ['l', 'u', 'd']):
+        raise ValueError('merge_extras contains invalid values')
+    if merge_comments not in ('leave', 'newest', 'overwrite'):
+        raise ValueError(f"merge_comments not in {('leave', 'newest', 'overwrite')!r}")
+    type_check(group, orm.Group, allow_none=True)
+    type_check(test_run, bool)
+    backend = backend or get_manager().get_backend()
+    type_check(backend, Backend)
+
+    if group and not group.is_stored:
+        group.store()
+
+    # check the version is latest
+    # to-do we should have a way to check the version against aiida-core
+    # i.e. it's not whether the version is the latest that matters, it is that it is compatible with the backend version
+    # it's a bit weird at the moment because django/sqlalchemy have different versioning
+    if not archive_format.read_version(path) == archive_format.latest_version:
+        raise IncompatibleArchiveVersionError(
+            f'The archive version {archive_format.read_version(path)} '
+            f'is not the latest version {archive_format.latest_version}'
+        )
+
+    IMPORT_LOGGER.report(
+        str(
+            tabulate([
+                ['Archive', Path(path).name],
+                ['New Node Extras', 'keep' if import_new_extras else 'strip'],
+                ['Merge Node Extras (in database)', MergeExtraDescs[0][merge_extras[0]]],
+                ['Merge Node Extras (in archive)', MergeExtraDescs[1][merge_extras[1]]],
+                ['Merge Node Extras (in both)', MergeExtraDescs[2][merge_extras[2]]],
+                ['Merge Comments', merge_comments],
+                ['Computer Authinfos', 'include' if include_authinfos else 'exclude'],
+            ],
+                     headers=['Parameters', ''])
+        ) + '\n'
+    )
+
+    if test_run:
+        IMPORT_LOGGER.report('Test run: nothing will be added to the profile')
+
+    with archive_format.open(path, mode='r') as reader:
+
+        backend_from = reader.get_backend()
+
+        # To ensure we do not corrupt the backend database on a faulty import,
+        # every addition/update is made in a single transaction, which is committed on exit
+        with backend.transaction():
+
+            user_ids_archive_backend = _import_users(backend_from, backend, batch_size)
+            computer_ids_archive_backend = _import_computers(backend_from, backend, batch_size)
+            if include_authinfos:
+                _import_authinfos(
+                    backend_from, backend, batch_size, user_ids_archive_backend, computer_ids_archive_backend
+                )
+            node_ids_archive_backend = _import_nodes(
+                backend_from, backend, batch_size, user_ids_archive_backend, computer_ids_archive_backend,
+                import_new_extras, merge_extras
+            )
+            _import_logs(backend_from, backend, batch_size, node_ids_archive_backend)
+            _import_comments(
+                backend_from, backend, batch_size, user_ids_archive_backend, node_ids_archive_backend, merge_comments
+            )
+
_import_links(backend_from, backend, batch_size, node_ids_archive_backend) + group_labels = _import_groups( + backend_from, backend, batch_size, user_ids_archive_backend, node_ids_archive_backend + ) + import_group_id = _make_import_group(group, group_labels, node_ids_archive_backend, backend, batch_size) + new_repo_keys = _get_new_object_keys(archive_format.key_format, backend_from, backend, batch_size) + + if test_run: + # exit before we write anything to the database or repository + raise ImportTestRun('test run complete') + + # now the transaction has been successfully populated, but not committed, we add the repository files + # if the commit fails, this is not so much an issue, since the files can be removed on repo maintenance + _add_files_to_repo(backend_from, backend, new_repo_keys) + + IMPORT_LOGGER.report('Committing transaction to database...') + + return import_group_id + + +def _add_new_entities( + etype: EntityTypes, total: int, unique_field: str, backend_unique_id: dict, backend_from: Backend, + backend_to: Backend, batch_size: int, transform: Callable[[dict], dict] +) -> None: + """Add new entities to the output backend and update the mapping of unique field -> id.""" + IMPORT_LOGGER.report(f'Adding {total} new {etype.value}(s)') + iterator = QueryBuilder(backend=backend_from).append( + entity_type_to_orm[etype], + filters={ + unique_field: { + '!in': list(backend_unique_id) + } + } if backend_unique_id else {}, + project=['**'], + tag='entity' + ).iterdict(batch_size=batch_size) + with get_progress_reporter()(desc=f'Adding new {etype.value}(s)', total=total) as progress: + for nrows, rows in batch_iter(iterator, batch_size, transform): + new_ids = backend_to.bulk_insert(etype, rows) + backend_unique_id.update({row[unique_field]: pk for pk, row in zip(new_ids, rows)}) + progress.update(nrows) + + +def _import_users(backend_from: Backend, backend_to: Backend, batch_size: int) -> Dict[int, int]: + """Import users from one backend to another. + + :returns: mapping of input backend id to output backend id + """ + # get the records from the input backend + qbuilder = QueryBuilder(backend=backend_from) + input_id_email = dict( + qbuilder.append(orm.User, project=['id', 'email']).all(batch_size=batch_size) # type: ignore[arg-type] + ) + + # get matching emails from the backend + output_email_id = {} + if input_id_email: + output_email_id = dict( + orm.QueryBuilder( + backend=backend_to + ).append(orm.User, filters={ + 'email': { + 'in': list(input_id_email.values()) + } + }, project=['email', 'id']).all(batch_size=batch_size) + ) + + new_users = len(input_id_email) - len(output_email_id) + existing_users = len(output_email_id) + + if existing_users: + IMPORT_LOGGER.report(f'Skipping {existing_users} existing User(s)') + if new_users: + # add new users and update output_email_id with their email -> id mapping + transform = lambda row: {k: v for k, v in row['entity'].items() if k != 'id'} + _add_new_entities( + EntityTypes.USER, new_users, 'email', output_email_id, backend_from, backend_to, batch_size, transform + ) + + # generate mapping of input backend id to output backend id + return {i: output_email_id[email] for i, email in input_id_email.items()} + + +def _import_computers(backend_from: Backend, backend_to: Backend, batch_size: int) -> Dict[int, int]: + """Import computers from one backend to another. 
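All of these ``_import_*`` helpers follow the same pattern as ``_add_new_entities`` above: query the archive backend with ``project=['**']`` and ``tag='entity'``, drop the archive primary key, remap any foreign keys, and bulk-insert the rest. A sketch of the user transform on a hypothetical row:

row = {'entity': {'id': 7, 'email': 'alice@example.com', 'first_name': 'Alice', 'last_name': '', 'institution': ''}}
transform = lambda row: {k: v for k, v in row['entity'].items() if k != 'id'}
assert transform(row) == {'email': 'alice@example.com', 'first_name': 'Alice', 'last_name': '', 'institution': ''}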
+ + :returns: mapping of input backend id to output backend id + """ + # get the records from the input backend + qbuilder = QueryBuilder(backend=backend_from) + input_id_uuid = dict( + qbuilder.append(orm.Computer, project=['id', 'uuid']).all(batch_size=batch_size) # type: ignore[arg-type] + ) + + # get matching uuids from the backend + backend_uuid_id = {} + if input_id_uuid: + backend_uuid_id = dict( + orm.QueryBuilder( + backend=backend_to + ).append(orm.Computer, filters={ + 'uuid': { + 'in': list(input_id_uuid.values()) + } + }, project=['uuid', 'id']).all(batch_size=batch_size) + ) + + new_computers = len(input_id_uuid) - len(backend_uuid_id) + existing_computers = len(backend_uuid_id) + + if existing_computers: + IMPORT_LOGGER.report(f'Skipping {existing_computers} existing Computer(s)') + if new_computers: + # add new computers and update backend_uuid_id with their uuid -> id mapping + + # Labels should be unique, so we create new labels on clashes + labels = { + label for label, in orm.QueryBuilder(backend=backend_to).append(orm.Computer, project='label' + ).iterall(batch_size=batch_size) + } + relabelled = 0 + + def transform(row: dict) -> dict: + data = row['entity'] + pk = data.pop('id') + nonlocal labels + if data['label'] in labels: + for i in range(DUPLICATE_LABEL_MAX): + new_label = DUPLICATE_LABEL_TEMPLATE.format(data['label'], i) + if new_label not in labels: + data['label'] = new_label + break + else: + raise ImportUniquenessError( + f'Archive Computer {pk} has existing label {data["label"]!r} and re-labelling failed' + ) + nonlocal relabelled + relabelled += 1 + labels.add(data['label']) + return data + + _add_new_entities( + EntityTypes.COMPUTER, new_computers, 'uuid', backend_uuid_id, backend_from, backend_to, batch_size, + transform + ) + + if relabelled: + IMPORT_LOGGER.report(f'Re-labelled {relabelled} new Computer(s)') + + # generate mapping of input backend id to output backend id + return {i: backend_uuid_id[uuid] for i, uuid in input_id_uuid.items()} + + +def _import_authinfos( + backend_from: Backend, backend_to: Backend, batch_size: int, user_ids_archive_backend: Dict[int, int], + computer_ids_archive_backend: Dict[int, int] +) -> None: + """Import logs from one backend to another. 
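The re-labelling above is driven by ``DUPLICATE_LABEL_TEMPLATE = '{0} (Imported #{1})'``; roughly, a clashing computer label is renamed as in this sketch (the label values are hypothetical):

labels = {'localhost', 'localhost (Imported #0)'}  # labels already present in the profile
label = 'localhost'
for i in range(DUPLICATE_LABEL_MAX):
    new_label = DUPLICATE_LABEL_TEMPLATE.format(label, i)
    if new_label not in labels:
        label = new_label  # -> 'localhost (Imported #1)'
        break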
+
+    :returns: mapping of input backend id to output backend id
+    """
+    # get the records from the input backend
+    qbuilder = QueryBuilder(backend=backend_from)
+    input_id_user_comp = (
+        qbuilder.append(
+            orm.AuthInfo,
+            project=['id', 'aiidauser_id', 'dbcomputer_id'],
+        ).all(batch_size=batch_size)
+    )
+
+    # translate user_id / computer_id, from -> to
+    try:
+        to_user_id_comp_id = [(user_ids_archive_backend[_user_id], computer_ids_archive_backend[_comp_id])
+                              for _, _user_id, _comp_id in input_id_user_comp]
+    except KeyError as exception:
+        raise ImportValidationError(f'Archive AuthInfo has unknown User/Computer: {exception}') from exception
+
+    # retrieve existing user_id / computer_id
+    backend_id_user_comp = []
+    if to_user_id_comp_id:
+        qbuilder = orm.QueryBuilder(backend=backend_to)
+        qbuilder.append(
+            orm.AuthInfo,
+            filters={
+                'aiidauser_id': {
+                    'in': [_user_id for _user_id, _ in to_user_id_comp_id]
+                },
+                'dbcomputer_id': {
+                    'in': [_comp_id for _, _comp_id in to_user_id_comp_id]
+                }
+            },
+            project=['id', 'aiidauser_id', 'dbcomputer_id']
+        )
+        backend_id_user_comp = [(user_id, comp_id)
+                                for _, user_id, comp_id in qbuilder.all(batch_size=batch_size)
+                                if (user_id, comp_id) in to_user_id_comp_id]
+
+    new_authinfos = len(input_id_user_comp) - len(backend_id_user_comp)
+    existing_authinfos = len(backend_id_user_comp)
+
+    if existing_authinfos:
+        IMPORT_LOGGER.report(f'Skipping {existing_authinfos} existing AuthInfo(s)')
+    if not new_authinfos:
+        return
+
+    # import new authinfos
+    IMPORT_LOGGER.report(f'Adding {new_authinfos} new {EntityTypes.AUTHINFO.value}(s)')
+    new_ids = [
+        _id for _id, _user_id, _comp_id in input_id_user_comp
+        if (user_ids_archive_backend[_user_id], computer_ids_archive_backend[_comp_id]) not in backend_id_user_comp
+    ]
+    qbuilder = QueryBuilder(backend=backend_from
+                            ).append(orm.AuthInfo, filters={'id': {
+                                'in': new_ids
+                            }}, project=['**'], tag='entity')
+    iterator = qbuilder.iterdict()
+
+    def transform(row: dict) -> dict:
+        data = row['entity']
+        data.pop('id')
+        data['aiidauser_id'] = user_ids_archive_backend[data['aiidauser_id']]
+        data['dbcomputer_id'] = computer_ids_archive_backend[data['dbcomputer_id']]
+        return data
+
+    with get_progress_reporter()(
+        desc=f'Adding new {EntityTypes.AUTHINFO.value}(s)', total=qbuilder.count()
+    ) as progress:
+        for nrows, rows in batch_iter(iterator, batch_size, transform):
+            backend_to.bulk_insert(EntityTypes.AUTHINFO, rows)
+            progress.update(nrows)
+
+
+def _import_nodes(
+    backend_from: Backend, backend_to: Backend, batch_size: int, user_ids_archive_backend: Dict[int, int],
+    computer_ids_archive_backend: Dict[int, int], import_new_extras: bool, merge_extras: MergeExtrasType
+) -> Dict[int, int]:
+    """Import nodes from one backend to another.
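The user/computer translation in ``_import_authinfos`` above simply routes the archive foreign keys through the two id maps built earlier, e.g. on hypothetical values:

user_ids_archive_backend = {1: 4}      # archive User pk -> profile User pk (hypothetical)
computer_ids_archive_backend = {2: 9}  # archive Computer pk -> profile Computer pk (hypothetical)
# an archive AuthInfo row with (aiidauser_id=1, dbcomputer_id=2) is matched or inserted as (4, 9)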
+ + :returns: mapping of input backend id to output backend id + """ + IMPORT_LOGGER.report('Collecting Node(s) ...') + # get the records from the input backend + qbuilder = QueryBuilder(backend=backend_from) + input_id_uuid = dict( + qbuilder.append(orm.Node, project=['id', 'uuid']).all(batch_size=batch_size) # type: ignore[arg-type] + ) + + # get matching uuids from the backend + backend_uuid_id = {} + if input_id_uuid: + backend_uuid_id = dict( + orm.QueryBuilder( + backend=backend_to + ).append(orm.Node, filters={ + 'uuid': { + 'in': list(input_id_uuid.values()) + } + }, project=['uuid', 'id']).all(batch_size=batch_size) + ) + + new_nodes = len(input_id_uuid) - len(backend_uuid_id) + + if backend_uuid_id: + _merge_node_extras(backend_from, backend_to, batch_size, backend_uuid_id, merge_extras) + + if new_nodes: + # add new nodes and update backend_uuid_id with their uuid -> id mapping + transform = NodeTransform(user_ids_archive_backend, computer_ids_archive_backend, import_new_extras) + _add_new_entities( + EntityTypes.NODE, new_nodes, 'uuid', backend_uuid_id, backend_from, backend_to, batch_size, transform + ) + + # generate mapping of input backend id to output backend id + return {i: backend_uuid_id[uuid] for i, uuid in input_id_uuid.items()} + + +class NodeTransform: + """Callable to transform a Node DB row, between the source archive and target backend.""" + + def __init__( + self, user_ids_archive_backend: Dict[int, int], computer_ids_archive_backend: Dict[int, int], + import_new_extras: bool + ): + self.user_ids_archive_backend = user_ids_archive_backend + self.computer_ids_archive_backend = computer_ids_archive_backend + self.import_new_extras = import_new_extras + + def __call__(self, row: dict) -> dict: + """Perform the transform.""" + data = row['entity'] + pk = data.pop('id') + try: + data['user_id'] = self.user_ids_archive_backend[data['user_id']] + except KeyError as exc: + raise ImportValidationError(f'Archive Node {pk} has unknown User: {exc}') + if data['dbcomputer_id'] is not None: + try: + data['dbcomputer_id'] = self.computer_ids_archive_backend[data['dbcomputer_id']] + except KeyError as exc: + raise ImportValidationError(f'Archive Node {pk} has unknown Computer: {exc}') + if self.import_new_extras: + # Remove node hashing and other aiida "private" extras + data['extras'] = {k: v for k, v in data['extras'].items() if not k.startswith('_aiida_')} + if data.get('node_type', '').endswith('code.Code.'): + data['extras'].pop('hidden', None) + else: + data['extras'] = {} + if data.get('node_type', '').startswith('process.'): + # remove checkpoint from attributes of process nodes + data['attributes'].pop(orm.ProcessNode.CHECKPOINT_KEY, None) + return data + + +def _import_logs(backend_from: Backend, backend_to: Backend, batch_size: int, + node_ids_archive_backend: Dict[int, int]) -> Dict[int, int]: + """Import logs from one backend to another. 
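To illustrate the extras handling in ``NodeTransform`` above (the keys are hypothetical): private ``_aiida_*`` extras are always stripped from new nodes, and all extras are dropped when ``import_new_extras`` is false.

extras = {'_aiida_hash': 'abc123', 'my_tag': 'converged'}
import_new_extras = True
kept = {k: v for k, v in extras.items() if not k.startswith('_aiida_')} if import_new_extras else {}
assert kept == {'my_tag': 'converged'}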
+ + :returns: mapping of input backend id to output backend id + """ + # get the records from the input backend + qbuilder = QueryBuilder(backend=backend_from) + input_id_uuid = dict( + qbuilder.append(orm.Log, project=['id', 'uuid']).all(batch_size=batch_size) # type: ignore[arg-type] + ) + + # get matching uuids from the backend + backend_uuid_id = {} + if input_id_uuid: + backend_uuid_id = dict( + orm.QueryBuilder( + backend=backend_to + ).append(orm.Log, filters={ + 'uuid': { + 'in': list(input_id_uuid.values()) + } + }, project=['uuid', 'id']).all(batch_size=batch_size) + ) + + new_logs = len(input_id_uuid) - len(backend_uuid_id) + existing_logs = len(backend_uuid_id) + + if existing_logs: + IMPORT_LOGGER.report(f'Skipping {existing_logs} existing Log(s)') + if new_logs: + # add new logs and update backend_uuid_id with their uuid -> id mapping + def transform(row: dict) -> dict: + data = row['entity'] + pk = data.pop('id') + try: + data['dbnode_id'] = node_ids_archive_backend[data['dbnode_id']] + except KeyError as exc: + raise ImportValidationError(f'Archive Log {pk} has unknown Node: {exc}') + return data + + _add_new_entities( + EntityTypes.LOG, new_logs, 'uuid', backend_uuid_id, backend_from, backend_to, batch_size, transform + ) + + # generate mapping of input backend id to output backend id + return {i: backend_uuid_id[uuid] for i, uuid in input_id_uuid.items()} + + +def _merge_node_extras( + backend_from: Backend, backend_to: Backend, batch_size: int, backend_uuid_id: Dict[str, int], mode: MergeExtrasType +) -> None: + """Merge extras from the input backend with the ones in the output backend. + + :param backend_uuid_id: mapping of uuid to output backend id + :param mode: tuple of merge modes for extras + """ + num_existing = len(backend_uuid_id) + + if mode == ('k', 'n', 'l'): + # 'none': keep old extras, do not add imported ones + IMPORT_LOGGER.report(f'Skipping {num_existing} existing Node(s)') + return + + input_extras = QueryBuilder( + backend=backend_from + ).append(orm.Node, tag='node', filters={ + 'uuid': { + 'in': list(backend_uuid_id.keys()) + } + }, project=['uuid', 'extras']).order_by([{ + 'node': 'uuid' + }]) + + if mode == ('n', 'c', 'u'): + # 'mirror' operation: remove old extras, put only the new ones + IMPORT_LOGGER.report(f'Replacing {num_existing} existing Node extras') + transform = lambda row: {'id': backend_uuid_id[row[0]], 'extras': row[1]} + with get_progress_reporter()(desc='Replacing extras', total=input_extras.count()) as progress: + for nrows, rows in batch_iter(input_extras.iterall(batch_size=batch_size), batch_size, transform): + backend_to.bulk_update(EntityTypes.NODE, rows) + progress.update(nrows) + return + + # run (slower) generic merge operation + backend_extras = QueryBuilder( + backend=backend_to + ).append(orm.Node, tag='node', filters={ + 'uuid': { + 'in': list(backend_uuid_id.keys()) + } + }, project=['uuid', 'extras']).order_by([{ + 'node': 'uuid' + }]) + + IMPORT_LOGGER.report(f'Merging {num_existing} existing Node extras') + + if not input_extras.count() == backend_extras.count(): + raise ImportValidationError( + f'Number of Nodes in archive ({input_extras.count()}) and backend ({backend_extras.count()}) do not match' + ) + + def _transform(data: Tuple[Tuple[str, dict], Tuple[str, dict]]) -> dict: + """Transform the new and existing extras into a dict that can be passed to bulk_update.""" + new_uuid, new_extras = data[0] + old_uuid, old_extras = data[1] + if new_uuid != old_uuid: + raise ImportValidationError(f'UUID mismatch when 
merging node extras: {new_uuid} != {old_uuid}') + backend_id = backend_uuid_id[new_uuid] + old_keys = set(old_extras.keys()) + new_keys = set(new_extras.keys()) + collided_keys = old_keys.intersection(new_keys) + old_keys_only = old_keys.difference(collided_keys) + new_keys_only = new_keys.difference(collided_keys) + + final_extras = {} + + if mode == ('k', 'c', 'u'): + # 'update_existing' operation: if an extra already exists, + # overwrite its new value with a new one + final_extras = new_extras + for key in old_keys_only: + final_extras[key] = old_extras[key] + return {'id': backend_id, 'extras': final_extras} + + if mode == ('k', 'c', 'l'): + # 'keep_existing': if an extra already exists, keep its original value + final_extras = old_extras + for key in new_keys_only: + final_extras[key] = new_extras[key] + return {'id': backend_id, 'extras': final_extras} + + if mode[0] == 'k': + for key in old_keys_only: + final_extras[key] = old_extras[key] + elif mode[0] != 'n': + raise ImportValidationError( + f"Unknown first letter of the update extras mode: '{mode}'. Should be either 'k' or 'n'" + ) + if mode[1] == 'c': + for key in new_keys_only: + final_extras[key] = new_extras[key] + elif mode[1] != 'n': + raise ImportValidationError( + f"Unknown second letter of the update extras mode: '{mode}'. Should be either 'c' or 'n'" + ) + if mode[2] == 'u': + for key in collided_keys: + final_extras[key] = new_extras[key] + elif mode[2] == 'l': + for key in collided_keys: + final_extras[key] = old_extras[key] + elif mode[2] != 'd': + raise ImportValidationError( + f"Unknown third letter of the update extras mode: '{mode}'. Should be one of 'u'/'l'/'a'/'d'" + ) + return {'id': backend_id, 'extras': final_extras} + + with get_progress_reporter()(desc='Merging extras', total=input_extras.count()) as progress: + for nrows, rows in batch_iter( + zip(input_extras.iterall(batch_size=batch_size), backend_extras.iterall(batch_size=batch_size)), batch_size, + _transform + ): + backend_to.bulk_update(EntityTypes.NODE, rows) + progress.update(nrows) + + +class CommentTransform: + """Callable to transform a Comment DB row, between the source archive and target backend.""" + + def __init__( + self, + user_ids_archive_backend: Dict[int, int], + node_ids_archive_backend: Dict[int, int], + ): + self.user_ids_archive_backend = user_ids_archive_backend + self.node_ids_archive_backend = node_ids_archive_backend + + def __call__(self, row: dict) -> dict: + """Perform the transform.""" + data = row['entity'] + pk = data.pop('id') + try: + data['user_id'] = self.user_ids_archive_backend[data['user_id']] + except KeyError as exc: + raise ImportValidationError(f'Archive Comment {pk} has unknown User: {exc}') + try: + data['dbnode_id'] = self.node_ids_archive_backend[data['dbnode_id']] + except KeyError as exc: + raise ImportValidationError(f'Archive Comment {pk} has unknown Node: {exc}') + return data + + +def _import_comments( + backend_from: Backend, + backend: Backend, + batch_size: int, + user_ids_archive_backend: Dict[int, int], + node_ids_archive_backend: Dict[int, int], + merge_comments: MergeCommentsType, +) -> Dict[int, int]: + """Import comments from one backend to another. 
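To make the three-letter modes handled above concrete, a small worked example for the 'keep existing' mode ``('k', 'c', 'l')`` on hypothetical extras:

old_extras = {'a': 1, 'b': 2}   # extras already stored on the node in the profile
new_extras = {'b': 20, 'c': 3}  # extras coming from the archive
# 'k': keep keys only in old, 'c': create keys only in new, 'l': leave the old value on collision
final_extras = dict(old_extras)
final_extras.update({k: v for k, v in new_extras.items() if k not in old_extras})
assert final_extras == {'a': 1, 'b': 2, 'c': 3}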
+ + :returns: mapping of archive id to backend id + """ + # get the records from the input backend + qbuilder = QueryBuilder(backend=backend_from) + input_id_uuid = dict( + qbuilder.append(orm.Comment, project=['id', 'uuid']).all(batch_size=batch_size) # type: ignore[arg-type] + ) + + # get matching uuids from the backend + backend_uuid_id = {} + if input_id_uuid: + backend_uuid_id = dict( + orm.QueryBuilder( + backend=backend + ).append(orm.Comment, filters={ + 'uuid': { + 'in': list(input_id_uuid.values()) + } + }, project=['uuid', 'id']).all(batch_size=batch_size) + ) + + new_comments = len(input_id_uuid) - len(backend_uuid_id) + existing_comments = len(backend_uuid_id) + + archive_comments = QueryBuilder( + backend=backend_from + ).append(orm.Comment, filters={'uuid': { + 'in': list(backend_uuid_id.keys()) + }}, project=['uuid', 'mtime', 'content']) + + if existing_comments: + if merge_comments == 'leave': + IMPORT_LOGGER.report(f'Skipping {existing_comments} existing Comment(s)') + elif merge_comments == 'overwrite': + IMPORT_LOGGER.report(f'Overwriting {existing_comments} existing Comment(s)') + + def _transform(row): + data = {'id': backend_uuid_id[row[0]], 'mtime': row[1], 'content': row[2]} + return data + + with get_progress_reporter()(desc='Overwriting comments', total=archive_comments.count()) as progress: + for nrows, rows in batch_iter(archive_comments.iterall(batch_size=batch_size), batch_size, _transform): + backend.bulk_update(EntityTypes.COMMENT, rows) + progress.update(nrows) + + elif merge_comments == 'newest': + IMPORT_LOGGER.report(f'Updating {existing_comments} existing Comment(s)') + + def _transform(row): + # to-do this is probably not the most efficient way to do this + uuid, new_mtime, new_comment = row + cmt = orm.Comment.objects.get(uuid=uuid) + if cmt.mtime < new_mtime: + cmt.set_mtime(new_mtime) + cmt.set_content(new_comment) + + with get_progress_reporter()(desc='Updating comments', total=archive_comments.count()) as progress: + for nrows, rows in batch_iter(archive_comments.iterall(batch_size=batch_size), batch_size, _transform): + progress.update(nrows) + + else: + raise ImportValidationError(f'Unknown merge_comments value: {merge_comments}.') + if new_comments: + # add new comments and update backend_uuid_id with their uuid -> id mapping + _add_new_entities( + EntityTypes.COMMENT, new_comments, 'uuid', backend_uuid_id, backend_from, backend, batch_size, + CommentTransform(user_ids_archive_backend, node_ids_archive_backend) + ) + + # generate mapping of input backend id to output backend id + return {i: backend_uuid_id[uuid] for i, uuid in input_id_uuid.items()} + + +def _import_links( + backend_from: Backend, backend_to: Backend, batch_size: int, node_ids_archive_backend: Dict[int, int] +) -> None: + """Import links from one backend to another.""" + + # initial variables + calculation_node_types = 'process.calculation.' + workflow_node_types = 'process.workflow.' + data_node_types = 'data.' 
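These prefixes are compared against each node's ``node_type`` with ``str.startswith``; a sketch of the per-link check performed further below (the concrete ``node_type`` strings are examples):

in_type = 'process.workflow.workchain.WorkChainNode.'    # example source node_type
out_type = 'process.calculation.calcjob.CalcJobNode.'    # example target node_type
allowed_in_type, allowed_out_type = workflow_node_types, calculation_node_types  # i.e. a CALL_CALC link
assert in_type.startswith(allowed_in_type) and out_type.startswith(allowed_out_type)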
+ allowed_link_nodes = { + LinkType.CALL_CALC: (workflow_node_types, calculation_node_types), + LinkType.CALL_WORK: (workflow_node_types, workflow_node_types), + LinkType.CREATE: (calculation_node_types, data_node_types), + LinkType.INPUT_CALC: (data_node_types, calculation_node_types), + LinkType.INPUT_WORK: (data_node_types, workflow_node_types), + LinkType.RETURN: (workflow_node_types, data_node_types), + } + link_type_uniqueness = { + LinkType.CALL_CALC: ('out_id',), + LinkType.CALL_WORK: ('out_id',), + LinkType.CREATE: ( + 'in_id_label', + 'out_id', + ), + LinkType.INPUT_CALC: ('out_id_label',), + LinkType.INPUT_WORK: ('out_id_label',), + LinkType.RETURN: ('in_id_label',), + } + + # Batch by type, to reduce memory load + # to-do check no extra types in archive? + for link_type in LinkType: + + # get validation parameters + allowed_in_type, allowed_out_type = allowed_link_nodes[link_type] + link_uniqueness = link_type_uniqueness[link_type] + + # count links of this type in archive + archive_query = QueryBuilder(backend=backend_from + ).append(orm.Node, tag='incoming', project=['id', 'node_type']).append( + orm.Node, + with_incoming='incoming', + project=['id', 'node_type'], + edge_filters={'type': link_type.value}, + edge_project=['id', 'label'] + ) + total = archive_query.count() + + if not total: + continue # nothing to add + + # get existing links set, to check existing + IMPORT_LOGGER.report(f'Gathering existing {link_type.value!r} Link(s)') + existing_links = { + tuple(link) for link in orm.QueryBuilder(backend=backend_to). + append(entity_type='link', filters={ + 'type': link_type.value + }, project=['input_id', 'output_id', 'label']).iterall(batch_size=batch_size) + } + # create additional validators + # note, we only populate them when required, to reduce memory usage + existing_in_id_label = {(l[0], l[2]) for l in existing_links} if 'in_id_label' in link_uniqueness else set() + existing_out_id = {l[1] for l in existing_links} if 'out_id' in link_uniqueness else set() + existing_out_id_label = {(l[1], l[2]) for l in existing_links} if 'out_id_label' in link_uniqueness else set() + + # loop through archive links; validate and add new + new_count = existing_count = 0 + insert_rows = [] + with get_progress_reporter()(desc=f'Processing {link_type.value!r} Link(s)', total=total) as progress: + for in_id, in_type, out_id, out_type, link_id, link_label in archive_query.iterall(batch_size=batch_size): + + progress.update() + + # convert ids: archive -> profile + try: + in_id = node_ids_archive_backend[in_id] + except KeyError as exc: + raise ImportValidationError(f'Archive Link {link_id} has unknown input Node: {exc}') + try: + out_id = node_ids_archive_backend[out_id] + except KeyError as exc: + raise ImportValidationError(f'Archive Link {link_id} has unknown output Node: {exc}') + + # skip existing links + if (in_id, out_id, link_label) in existing_links: + existing_count += 1 + continue + + # validation + if in_id == out_id: + raise ImportValidationError(f'Cannot add a link to oneself: {in_id}') + if not in_type.startswith(allowed_in_type): + raise ImportValidationError( + f'Cannot add a {link_type.value!r} link from {in_type} (link {link_id})' + ) + if not out_type.startswith(allowed_out_type): + raise ImportValidationError(f'Cannot add a {link_type.value!r} link to {out_type} (link {link_id})') + if 'in_id_label' in link_uniqueness and (in_id, link_label) in existing_in_id_label: + raise ImportUniquenessError( + f'Node {in_id} already has an outgoing {link_type.value!r} link with 
label {link_label!r}'
+                    )
+                if 'out_id' in link_uniqueness and out_id in existing_out_id:
+                    raise ImportUniquenessError(f'Node {out_id} already has an incoming {link_type.value!r} link')
+                if 'out_id_label' in link_uniqueness and (out_id, link_label) in existing_out_id_label:
+                    raise ImportUniquenessError(
+                        f'Node {out_id} already has an incoming {link_type.value!r} link with label {link_label!r}'
+                    )
+
+                # update variables
+                new_count += 1
+                insert_rows.append({
+                    'input_id': in_id,
+                    'output_id': out_id,
+                    'type': link_type.value,
+                    'label': link_label,
+                })
+                existing_links.add((in_id, out_id, link_label))
+                existing_in_id_label.add((in_id, link_label))
+                existing_out_id.add(out_id)
+                existing_out_id_label.add((out_id, link_label))
+
+                # flush new rows, once batch size is reached
+                if (new_count % batch_size) == 0:
+                    backend_to.bulk_insert(EntityTypes.LINK, insert_rows)
+                    insert_rows = []
+
+            # flush remaining new rows
+            if insert_rows:
+                backend_to.bulk_insert(EntityTypes.LINK, insert_rows)
+
+        # report counts
+        if existing_count:
+            IMPORT_LOGGER.report(f'Skipped {existing_count} existing {link_type.value!r} Link(s)')
+        if new_count:
+            IMPORT_LOGGER.report(f'Added {new_count} new {link_type.value!r} Link(s)')
+
+
+class GroupTransform:
+    """Callable to transform a Group DB row, between the source archive and target backend."""
+
+    def __init__(self, user_ids_archive_backend: Dict[int, int], labels: Set[str]):
+        self.user_ids_archive_backend = user_ids_archive_backend
+        self.labels = labels
+        self.relabelled = 0
+
+    def __call__(self, row: dict) -> dict:
+        """Perform the transform."""
+        data = row['entity']
+        pk = data.pop('id')
+        try:
+            data['user_id'] = self.user_ids_archive_backend[data['user_id']]
+        except KeyError as exc:
+            raise ImportValidationError(f'Archive Group {pk} has unknown User: {exc}')
+        # Labels should be unique, so we create new labels on clashes
+        if data['label'] in self.labels:
+            for i in range(DUPLICATE_LABEL_MAX):
+                new_label = DUPLICATE_LABEL_TEMPLATE.format(data['label'], i)
+                if new_label not in self.labels:
+                    data['label'] = new_label
+                    break
+            else:
+                raise ImportUniquenessError(
+                    f'Archive Group {pk} has existing label {data["label"]!r} and re-labelling failed'
+                )
+            self.relabelled += 1
+        self.labels.add(data['label'])
+        return data
+
+
+def _import_groups(
+    backend_from: Backend, backend_to: Backend, batch_size: int, user_ids_archive_backend: Dict[int, int],
+    node_ids_archive_backend: Dict[int, int]
+) -> Set[str]:
+    """Import groups from the input backend, and add group -> node records.
+ + :returns: Set of labels + """ + # get the records from the input backend + qbuilder = QueryBuilder(backend=backend_from) + input_id_uuid = dict( + qbuilder.append(orm.Group, project=['id', 'uuid']).all(batch_size=batch_size) # type: ignore[arg-type] + ) + + # get matching uuids from the backend + backend_uuid_id = {} + if input_id_uuid: + backend_uuid_id = dict( + orm.QueryBuilder( + backend=backend_to + ).append(orm.Group, filters={ + 'uuid': { + 'in': list(input_id_uuid.values()) + } + }, project=['uuid', 'id']).all(batch_size=batch_size) + ) + + # get all labels + labels = { + label for label, in orm.QueryBuilder(backend=backend_to).append(orm.Group, project='label' + ).iterall(batch_size=batch_size) + } + + new_groups = len(input_id_uuid) - len(backend_uuid_id) + new_uuids = set(input_id_uuid.values()).difference(backend_uuid_id.keys()) + existing_groups = len(backend_uuid_id) + + if existing_groups: + IMPORT_LOGGER.report(f'Skipping {existing_groups} existing Group(s)') + if new_groups: + # add new groups and update backend_uuid_id with their uuid -> id mapping + + transform = GroupTransform(user_ids_archive_backend, labels) + + _add_new_entities( + EntityTypes.GROUP, new_groups, 'uuid', backend_uuid_id, backend_from, backend_to, batch_size, transform + ) + + if transform.relabelled: + IMPORT_LOGGER.report(f'Re-labelled {transform.relabelled} new Group(s)') + + # generate mapping of input backend id to output backend id + group_id_archive_backend = {i: backend_uuid_id[uuid] for i, uuid in input_id_uuid.items()} + # Add nodes to new groups + iterator = QueryBuilder(backend=backend_from + ).append(orm.Group, project='id', filters={ + 'uuid': { + 'in': new_uuids + } + }, tag='group').append(orm.Node, project='id', with_group='group') + total = iterator.count() + if total: + IMPORT_LOGGER.report(f'Adding {total} Node(s) to new Group(s)') + + def group_node_transform(row): + group_id = group_id_archive_backend[row[0]] + try: + node_id = node_ids_archive_backend[row[1]] + except KeyError as exc: + raise ImportValidationError(f'Archive Group {group_id} has unknown Node: {exc}') + return {'dbgroup_id': group_id, 'dbnode_id': node_id} + + with get_progress_reporter()(desc=f'Adding new {EntityTypes.GROUP_NODE.value}(s)', total=total) as progress: + for nrows, rows in batch_iter( + iterator.iterall(batch_size=batch_size), batch_size, group_node_transform + ): + backend_to.bulk_insert(EntityTypes.GROUP_NODE, rows) + progress.update(nrows) + + return labels + + +def _make_import_group( + group: Optional[orm.Group], labels: Set[str], node_ids_archive_backend: Dict[int, int], backend_to: Backend, + batch_size: int +) -> Optional[int]: + """Make an import group containing all imported nodes. 
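The import group created by ``_make_import_group`` below is labelled with the local timestamp, so a freshly created group gets a name of roughly this shape (the value shown is illustrative):

from aiida.common import timezone

label = timezone.localtime(timezone.now()).strftime('%Y%m%d-%H%M%S')  # e.g. '20211130-171530'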
+ + :param group: Use an existing group + :param labels: All existing group labels on the backend + :param node_ids_archive_backend: node pks to add to the group + + :returns: The id of the group + + """ + # Do not create an empty group + if not node_ids_archive_backend: + IMPORT_LOGGER.debug('No nodes to import, so no import group created') + return None + + # Get the Group id + if group is None: + # Get an unique name for the import group, based on the current (local) time + label = timezone.localtime(timezone.now()).strftime('%Y%m%d-%H%M%S') + if label in labels: + for i in range(DUPLICATE_LABEL_MAX): + new_label = DUPLICATE_LABEL_TEMPLATE.format(label, i) + if new_label not in labels: + label = new_label + break + else: + raise ImportUniquenessError(f'New import Group has existing label {label!r} and re-labelling failed') + dummy_orm = orm.ImportGroup(label) + row = { + 'label': label, + 'description': 'Group generated by archive import', + 'type_string': dummy_orm.type_string, + 'user_id': dummy_orm.user.pk, + } + group_id, = backend_to.bulk_insert(EntityTypes.GROUP, [row], allow_defaults=True) + IMPORT_LOGGER.report(f'Created new import Group: PK={group_id}, label={label}') + group_node_ids = set() + else: + group_id = group.pk + IMPORT_LOGGER.report(f'Using existing import Group: PK={group_id}, label={group.label}') + group_node_ids = { + pk for pk, in orm.QueryBuilder(backend=backend_to).append(orm.Group, filters={ + 'id': group_id + }, tag='group').append(orm.Node, with_group='group', project='id').iterall(batch_size=batch_size) + } + + # Add all the nodes to the Group + with get_progress_reporter()( + desc='Adding all Node(s) to the import Group', total=len(node_ids_archive_backend) + ) as progress: + iterator = ({ + 'dbgroup_id': group_id, + 'dbnode_id': node_id + } for node_id in node_ids_archive_backend.values() if node_id not in group_node_ids) + for nrows, rows in batch_iter(iterator, batch_size): + backend_to.bulk_insert(EntityTypes.GROUP_NODE, rows) + progress.update(nrows) + + return group_id + + +def _get_new_object_keys(key_format: str, backend_from: Backend, backend_to: Backend, batch_size: int) -> Set[str]: + """Return the object keys that need to be added to the backend.""" + archive_hashkeys: Set[str] = set() + query = QueryBuilder(backend=backend_from).append(orm.Node, project='repository_metadata') + with get_progress_reporter()(desc='Collecting archive Node file keys', total=query.count()) as progress: + for repository_metadata, in query.iterall(batch_size=batch_size): + archive_hashkeys.update(key for key in Repository.flatten(repository_metadata).values() if key is not None) + progress.update() + + IMPORT_LOGGER.report('Checking keys against repository ...') + + repository = backend_to.get_repository() + if not repository.key_format == key_format: + raise NotImplementedError( + f'Backend repository key format incompatible: {repository.key_format!r} != {key_format!r}' + ) + new_hashkeys = archive_hashkeys.difference(repository.list_objects()) + + existing_count = len(archive_hashkeys) - len(new_hashkeys) + if existing_count: + IMPORT_LOGGER.report(f'Skipping {existing_count} existing repository files') + if new_hashkeys: + IMPORT_LOGGER.report(f'Adding {len(new_hashkeys)} new repository files') + + return new_hashkeys + + +def _add_files_to_repo(backend_from: Backend, backend_to: Backend, new_keys: Set[str]) -> None: + """Add the new files to the repository.""" + if not new_keys: + return None + + repository_to = backend_to.get_repository() + repository_from = 
backend_from.get_repository() + with get_progress_reporter()(desc='Adding archive files to repository', total=len(new_keys)) as progress: + for key, handle in repository_from.iter_object_streams(new_keys): + backend_key = repository_to.put_object_from_filelike(handle) + if backend_key != key: + raise ImportValidationError( + f'Archive repository key is different to backend key: {key!r} != {backend_key!r}' + ) + progress.update() diff --git a/aiida/tools/graph/graph_traversers.py b/aiida/tools/graph/graph_traversers.py index 8f9f0c0f6d..991332a817 100644 --- a/aiida/tools/graph/graph_traversers.py +++ b/aiida/tools/graph/graph_traversers.py @@ -69,13 +69,13 @@ def get_nodes_delete( missing_callback=missing_callback, ) - function_output = { + function_output: TraverseGraphOutput = { 'nodes': traverse_output['nodes'], 'links': traverse_output['links'], 'rules': traverse_links['rules_applied'] } - return cast(TraverseGraphOutput, function_output) + return function_output def get_nodes_export( @@ -112,13 +112,13 @@ def get_nodes_export( links_backward=traverse_links['backward'] ) - function_output = { + function_output: TraverseGraphOutput = { 'nodes': traverse_output['nodes'], 'links': traverse_output['links'], 'rules': traverse_links['rules_applied'] } - return cast(TraverseGraphOutput, function_output) + return function_output def validate_traversal_rules( @@ -296,10 +296,10 @@ def traverse_graph( results = rulesequence.run(basket) - output = {} + output: TraverseGraphOutput = {} output['nodes'] = results.nodes.keyset output['links'] = None if get_links: output['links'] = results['nodes_nodes'].keyset - return cast(TraverseGraphOutput, output) + return output diff --git a/aiida/tools/importexport/__init__.py b/aiida/tools/importexport/__init__.py deleted file mode 100644 index 0d545768a0..0000000000 --- a/aiida/tools/importexport/__init__.py +++ /dev/null @@ -1,71 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Provides import/export functionalities. 
- -To see history/git blame prior to the move to aiida.tools.importexport, -explore tree: https://github.com/aiidateam/aiida-core/tree/eebef392c81e8b130834a92e1d7abf5e2e30b3ce -Functionality: /aiida/orm/importexport.py -Tests: /aiida/backends/tests/test_export_and_import.py -""" - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .archive import * -from .common import * -from .dbexport import * -from .dbimport import * - -__all__ = ( - 'ARCHIVE_READER_LOGGER', - 'ArchiveExportError', - 'ArchiveImportError', - 'ArchiveMetadata', - 'ArchiveMigrationError', - 'ArchiveMigratorAbstract', - 'ArchiveMigratorJsonBase', - 'ArchiveMigratorJsonTar', - 'ArchiveMigratorJsonZip', - 'ArchiveReaderAbstract', - 'ArchiveWriterAbstract', - 'CacheFolder', - 'CorruptArchive', - 'DanglingLinkError', - 'EXPORT_LOGGER', - 'EXPORT_VERSION', - 'ExportFileFormat', - 'ExportImportException', - 'ExportValidationError', - 'IMPORT_LOGGER', - 'ImportUniquenessError', - 'ImportValidationError', - 'IncompatibleArchiveVersionError', - 'MIGRATE_LOGGER', - 'MigrationValidationError', - 'ProgressBarError', - 'ReaderJsonBase', - 'ReaderJsonFolder', - 'ReaderJsonTar', - 'ReaderJsonZip', - 'WriterJsonFolder', - 'WriterJsonTar', - 'WriterJsonZip', - 'detect_archive_type', - 'export', - 'get_migrator', - 'get_reader', - 'get_writer', - 'import_data', - 'null_callback', -) - -# yapf: enable diff --git a/aiida/tools/importexport/archive/__init__.py b/aiida/tools/importexport/archive/__init__.py deleted file mode 100644 index b2dd149d7c..0000000000 --- a/aiida/tools/importexport/archive/__init__.py +++ /dev/null @@ -1,48 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -# type: ignore -"""Readers and writers for archive formats, that work independently of a connection to an AiiDA profile.""" - -# AUTO-GENERATED - -# yapf: disable -# pylint: disable=wildcard-import - -from .common import * -from .migrators import * -from .readers import * -from .writers import * - -__all__ = ( - 'ARCHIVE_READER_LOGGER', - 'ArchiveMetadata', - 'ArchiveMigratorAbstract', - 'ArchiveMigratorJsonBase', - 'ArchiveMigratorJsonTar', - 'ArchiveMigratorJsonZip', - 'ArchiveReaderAbstract', - 'ArchiveWriterAbstract', - 'CacheFolder', - 'MIGRATE_LOGGER', - 'ReaderJsonBase', - 'ReaderJsonFolder', - 'ReaderJsonTar', - 'ReaderJsonZip', - 'WriterJsonFolder', - 'WriterJsonTar', - 'WriterJsonZip', - 'detect_archive_type', - 'get_migrator', - 'get_reader', - 'get_writer', - 'null_callback', -) - -# yapf: enable diff --git a/aiida/tools/importexport/archive/common.py b/aiida/tools/importexport/archive/common.py deleted file mode 100644 index 583115fbf5..0000000000 --- a/aiida/tools/importexport/archive/common.py +++ /dev/null @@ -1,234 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Shared resources for the archive.""" -from collections import OrderedDict -import copy -import dataclasses -import os -from pathlib import Path -import tarfile -from types import TracebackType -from typing import Any, Dict, List, Optional, Tuple, Type, Union -import zipfile - -from aiida.common import json # handles byte dumps -from aiida.common.log import AIIDA_LOGGER - -__all__ = ('ArchiveMetadata', 'detect_archive_type', 'null_callback', 'CacheFolder') - -ARCHIVE_LOGGER = AIIDA_LOGGER.getChild('archive') - - -@dataclasses.dataclass -class ArchiveMetadata: - """Class for storing metadata about this archive. - - Required fields are necessary for importing the data back into AiiDA, - whereas optional fields capture information about the export/migration process(es) - """ - export_version: str - aiida_version: str - # Entity type -> database ID key - unique_identifiers: Dict[str, str] = dataclasses.field(repr=False) - # Entity type -> database key -> meta parameters - all_fields_info: Dict[str, Dict[str, Dict[str, str]]] = dataclasses.field(repr=False) - - # optional data - graph_traversal_rules: Optional[Dict[str, bool]] = dataclasses.field(default=None) - # Entity type -> UUID list - entities_starting_set: Optional[Dict[str, List[str]]] = dataclasses.field(default=None) - include_comments: Optional[bool] = dataclasses.field(default=None) - include_logs: Optional[bool] = dataclasses.field(default=None) - # list of migration event notifications - conversion_info: List[str] = dataclasses.field(default_factory=list, repr=False) - - -def null_callback(action: str, value: Any): # pylint: disable=unused-argument - """A null callback function.""" - - -def detect_archive_type(in_path: str) -> str: - """For back-compatibility, but should be replaced with direct comparison of classes. - - :param in_path: the path to the file - :returns: the archive type identifier (currently one of 'zip', 'tar.gz', 'folder') - - """ - from aiida.tools.importexport.common.config import ExportFileFormat - from aiida.tools.importexport.common.exceptions import ImportValidationError - - if os.path.isdir(in_path): - return 'folder' - if tarfile.is_tarfile(in_path): - return ExportFileFormat.TAR_GZIPPED - if zipfile.is_zipfile(in_path): - return ExportFileFormat.ZIP - raise ImportValidationError( - 'Unable to detect the input file format, it is neither a ' - 'folder, tar file, nor a (possibly compressed) zip file.' - ) - - -class CacheFolder: - """A class to encapsulate a folder path with cached read/writes. - - The class can be used as a context manager, and will flush the cache on exit:: - - with CacheFolder(path) as folder: - # these are stored in memory (no disk write) - folder.write_text('path/to/file.txt', 'content') - folder.write_json('path/to/data.json', {'a': 1}) - # these will be read from memory - text = folder.read_text('path/to/file.txt') - text = folder.load_json('path/to/data.json') - - # all files will now have been written to disk - - """ - - def __init__(self, path: Union[Path, str], *, encoding: str = 'utf8'): - """Initialise cached folder. 
- - :param path: folder path to cache - :param encoding: encoding of text to read/write - - """ - self._path = Path(path) - # dict mapping path -> (type, content) - self._cache = OrderedDict() # type: ignore - self._encoding = encoding - self._max_items = 100 # maximum limit of files to store in memory - - def _write_object(self, path: str, ctype: str, content: Any): - """Write an object from the cache to disk. - - :param path: relative path of file - :param ctype: the type of the content - :param content: the content to write - - """ - if ctype == 'text': - (self._path / path).write_text(content, encoding=self._encoding) - elif ctype == 'json': - with (self._path / path).open(mode='wb') as handle: - json.dump(content, handle) - else: - raise TypeError(f'Unknown content type: {ctype}') - - def flush(self): - """Flush the cache.""" - for path, (ctype, content) in self._cache.items(): - self._write_object(path, ctype, content) - - def _limit_cache(self): - """Ensure the cache does not exceed a set limit. - - Content is uncached on a First-In-First-Out basis. - - """ - while len(self._cache) > self._max_items: - path, (ctype, content) = self._cache.popitem(last=False) - self._write_object(path, ctype, content) - - def get_path(self, flush=True) -> Path: - """Return the path. - - :param flush: flush the cache before returning - - """ - if flush: - self.flush() - return self._path - - def write_text(self, path: str, content: str): - """write text to the cache. - - :param path: path relative to base folder - - """ - assert isinstance(content, str) - self._cache[path] = ('text', content) - self._limit_cache() - - def read_text(self, path) -> str: - """write text from the cache or base folder. - - :param path: path relative to base folder - - """ - if path not in self._cache: - return (self._path / path).read_text(self._encoding) - ctype, content = self._cache[path] - if ctype == 'text': - return content - if ctype == 'json': - return json.dumps(content) - - raise TypeError(f"content of type '{ctype}' could not be converted to text") - - def write_json(self, path: str, data: dict): - """Write dict to the folder, to be serialized as json. - - The dictionary is stored in memory, until the cache is flushed, - at which point the dictionary is serialized to json and written to disk. - - :param path: path relative to base folder - - """ - assert isinstance(data, dict) - # json.dumps(data) # make sure that the data can be converted to json (increases memory usage) - self._cache[path] = ('json', data) - self._limit_cache() - - def load_json(self, path: str, ensure_copy: bool = False) -> Tuple[bool, dict]: - """Load a json file from the cache folder. - - Important: if the dict is returned directly from the cache, any mutations will affect the cached dict. - - :param path: path relative to base folder - :param ensure_copy: ensure the dict is a copy of that from the cache - - :returns: (from cache, the content) - If from cache, mutations will directly affect the cache - - """ - if path not in self._cache: - return False, json.loads((self._path / path).read_text(self._encoding)) - - ctype, content = self._cache[path] - if ctype == 'text': - return False, json.loads(content) - if ctype == 'json': - if ensure_copy: - return False, copy.deepcopy(content) - return True, content - - raise TypeError(f"content of type '{ctype}' could not be converted to a dict") - - def remove_file(self, path): - """Remove a file from both the cache and base folder (if present). 
- - :param path: path relative to base folder - - """ - self._cache.pop(path, None) - if (self._path / path).exists(): - (self._path / path).unlink() - - def __enter__(self): - """Enter the contextmanager.""" - return self - - def __exit__( - self, exctype: Optional[Type[BaseException]], excinst: Optional[BaseException], exctb: Optional[TracebackType] - ): - """Exit the contextmanager.""" - self.flush() - return False diff --git a/aiida/tools/importexport/archive/migrations/v01_to_v02.py b/aiida/tools/importexport/archive/migrations/v01_to_v02.py deleted file mode 100644 index ee59aea9ca..0000000000 --- a/aiida/tools/importexport/archive/migrations/v01_to_v02.py +++ /dev/null @@ -1,76 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Migration from v0.1 to v0.2, used by `verdi export migrate` command.""" -from aiida.tools.importexport.archive.common import CacheFolder - -from .utils import update_metadata, verify_metadata_version # pylint: disable=no-name-in-module - - -def migrate_v1_to_v2(folder: CacheFolder): - """ - Migration of archive files from v0.1 to v0.2, which means generalizing the - field names with respect to the database backend - - :param metadata: the content of an export archive metadata.json file - :param data: the content of an export archive data.json file - """ - old_version = '0.1' - new_version = '0.2' - - old_start = 'aiida.djsite' - new_start = 'aiida.backends.djsite' - - _, metadata = folder.load_json('metadata.json') - - verify_metadata_version(metadata, old_version) - update_metadata(metadata, new_version) - - _, data = folder.load_json('data.json') - - for field in ['export_data']: - for key in list(data[field]): - if key.startswith(old_start): - new_key = get_new_string(key, old_start, new_start) - data[field][new_key] = data[field][key] - del data[field][key] - - for field in ['unique_identifiers', 'all_fields_info']: - for key in list(metadata[field].keys()): - if key.startswith(old_start): - new_key = get_new_string(key, old_start, new_start) - metadata[field][new_key] = metadata[field][key] - del metadata[field][key] - - metadata['all_fields_info'] = replace_requires(metadata['all_fields_info'], old_start, new_start) - - folder.write_json('metadata.json', metadata) - folder.write_json('data.json', data) - - -def get_new_string(old_string, old_start, new_start): - """Replace the old module prefix with the new.""" - if old_string.startswith(old_start): - return f'{new_start}{old_string[len(old_start):]}' - - return old_string - - -def replace_requires(data, old_start, new_start): - """Replace the requires keys with new module path.""" - if isinstance(data, dict): - new_data = {} - for key, value in data.items(): - if key == 'requires' and value.startswith(old_start): - new_data[key] = get_new_string(value, old_start, new_start) - else: - new_data[key] = replace_requires(value, old_start, new_start) - return new_data - - return data diff --git a/aiida/tools/importexport/archive/migrations/v02_to_v03.py b/aiida/tools/importexport/archive/migrations/v02_to_v03.py deleted file mode 100644 index 
07e3b339bd..0000000000 --- a/aiida/tools/importexport/archive/migrations/v02_to_v03.py +++ /dev/null @@ -1,140 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Migration from v0.2 to v0.3, used by `verdi export migrate` command.""" -# pylint: disable=too-many-locals,too-many-statements,too-many-branches,unused-argument -import enum - -from aiida.tools.importexport.archive.common import CacheFolder -from aiida.tools.importexport.common.exceptions import DanglingLinkError - -from .utils import update_metadata, verify_metadata_version # pylint: disable=no-name-in-module - - -def migrate_v2_to_v3(folder: CacheFolder): - """ - Migration of archive files from v0.2 to v0.3, which means adding the link - types to the link entries and making the entity key names backend agnostic - by effectively removing the prefix 'aiida.backends.djsite.db.models' - - :param data: the content of an export archive data.json file - :param metadata: the content of an export archive metadata.json file - """ - - old_version = '0.2' - new_version = '0.3' - - class LinkType(enum.Enum): - """This was the state of the `aiida.common.links.LinkType` enum before aiida-core v1.0.0a5""" - - UNSPECIFIED = 'unspecified' - CREATE = 'createlink' - RETURN = 'returnlink' - INPUT = 'inputlink' - CALL = 'calllink' - - class NodeType(enum.Enum): - """A simple enum of relevant node types""" - - NONE = 'none' - CALC = 'calculation' - CODE = 'code' - DATA = 'data' - WORK = 'work' - - entity_map = { - 'aiida.backends.djsite.db.models.DbNode': 'Node', - 'aiida.backends.djsite.db.models.DbLink': 'Link', - 'aiida.backends.djsite.db.models.DbGroup': 'Group', - 'aiida.backends.djsite.db.models.DbComputer': 'Computer', - 'aiida.backends.djsite.db.models.DbUser': 'User', - 'aiida.backends.djsite.db.models.DbAttribute': 'Attribute' - } - - _, metadata = folder.load_json('metadata.json') - - verify_metadata_version(metadata, old_version) - update_metadata(metadata, new_version) - - _, data = folder.load_json('data.json') - - # Create a mapping from node uuid to node type - mapping = {} - for nodes in data['export_data'].values(): - for node in nodes.values(): - - try: - node_uuid = node['uuid'] - node_type_string = node['type'] - except KeyError: - continue - - if node_type_string.startswith('calculation.job.'): - node_type = NodeType.CALC - elif node_type_string.startswith('calculation.inline.'): - node_type = NodeType.CALC - elif node_type_string.startswith('code.Code'): - node_type = NodeType.CODE - elif node_type_string.startswith('data.'): - node_type = NodeType.DATA - elif node_type_string.startswith('calculation.work.'): - node_type = NodeType.WORK - else: - node_type = NodeType.NONE - - mapping[node_uuid] = node_type - - # For each link, deduce the link type and insert it in place - for link in data['links_uuid']: - - try: - input_type = NodeType(mapping[link['input']]) - output_type = NodeType(mapping[link['output']]) - except KeyError: - raise DanglingLinkError(f"Unknown node UUID {link['input']} or {link['output']}") - - # The following table demonstrates the logic for inferring the link 
type - # (CODE, DATA) -> (WORK, CALC) : INPUT - # (CALC) -> (DATA) : CREATE - # (WORK) -> (DATA) : RETURN - # (WORK) -> (CALC, WORK) : CALL - if input_type in [NodeType.CODE, NodeType.DATA] and output_type in [NodeType.CALC, NodeType.WORK]: - link['type'] = LinkType.INPUT.value - elif input_type == NodeType.CALC and output_type == NodeType.DATA: - link['type'] = LinkType.CREATE.value - elif input_type == NodeType.WORK and output_type == NodeType.DATA: - link['type'] = LinkType.RETURN.value - elif input_type == NodeType.WORK and output_type in [NodeType.CALC, NodeType.WORK]: - link['type'] = LinkType.CALL.value - else: - link['type'] = LinkType.UNSPECIFIED.value - - # Now we migrate the entity key names i.e. removing the 'aiida.backends.djsite.db.models' prefix - for field in ['unique_identifiers', 'all_fields_info']: - for old_key, new_key in entity_map.items(): - if old_key in metadata[field]: - metadata[field][new_key] = metadata[field][old_key] - del metadata[field][old_key] - - # Replace the 'requires' keys in the nested dictionaries in 'all_fields_info' - for entity in metadata['all_fields_info'].values(): - for prop in entity.values(): - for key, value in prop.items(): - if key == 'requires' and value in entity_map: - prop[key] = entity_map[value] - - # Replace any present keys in the data.json - for field in ['export_data']: - for old_key, new_key in entity_map.items(): - if old_key in data[field]: - data[field][new_key] = data[field][old_key] - del data[field][old_key] - - folder.write_json('metadata.json', metadata) - folder.write_json('data.json', data) diff --git a/aiida/tools/importexport/archive/migrations/v03_to_v04.py b/aiida/tools/importexport/archive/migrations/v03_to_v04.py deleted file mode 100644 index 63a6042ea7..0000000000 --- a/aiida/tools/importexport/archive/migrations/v03_to_v04.py +++ /dev/null @@ -1,510 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Migration from v0.3 to v0.4, used by `verdi export migrate` command. - -The migration steps are named similarly to the database migrations for Django and SQLAlchemy. -In the description of each migration, a revision number is given, which refers to the Django migrations. -The individual Django database migrations may be found at: - - `aiida.backends.djsite.db.migrations.00XX_.py` - -Where XX are the numbers in the migrations' documentation: REV. 1.0.XX -And migration-name is the name of the particular migration. -The individual SQLAlchemy database migrations may be found at: - - `aiida.backends.sqlalchemy.migrations.versions._.py` - -Where id is a SQLA id and migration-name is the name of the particular migration. -""" -# pylint: disable=invalid-name -import copy -import os - -import numpy as np - -from aiida.common import json -from aiida.tools.importexport.archive.common import CacheFolder -from aiida.tools.importexport.common.exceptions import ArchiveMigrationError - -from .utils import remove_fields, update_metadata, verify_metadata_version # pylint: disable=no-name-in-module - - -def migration_base_data_plugin_type_string(data): - """Apply migration: 0009 - REV. 
1.0.9 - `DbNode.type` content changes: - 'data.base.Bool.' -> 'data.bool.Bool.' - 'data.base.Float.' -> 'data.float.Float.' - 'data.base.Int.' -> 'data.int.Int.' - 'data.base.Str.' -> 'data.str.Str.' - 'data.base.List.' -> 'data.list.List.' - """ - for content in data['export_data'].get('Node', {}).values(): - if content.get('type', '').startswith('data.base.'): - type_str = content['type'].replace('data.base.', '') - type_str = f'data.{type_str.lower()}{type_str}' - content['type'] = type_str - - -def migration_process_type(metadata, data): - """Apply migrations: 0010 - REV. 1.0.10 - Add `DbNode.process_type` column - """ - # For data.json - for content in data['export_data'].get('Node', {}).values(): - if 'process_type' not in content: - content['process_type'] = '' - # For metadata.json - metadata['all_fields_info']['Node']['process_type'] = {} - - -def migration_code_sub_class_of_data(data): - """Apply migrations: 0016 - REV. 1.0.16 - The Code class used to be just a sub class of Node, but was changed to act like a Data node. - code.Code. -> data.code.Code. - """ - for content in data['export_data'].get('Node', {}).values(): - if content.get('type', '') == 'code.Code.': - content['type'] = 'data.code.Code.' - - -def migration_add_node_uuid_unique_constraint(data): - """Apply migration: 0014 - REV. 1.0.14, 0018 - REV. 1.0.18 - Check that no entries with the same uuid are present in the archive file - if yes - stop the import process - """ - for entry_type in ['Group', 'Computer', 'Node']: - if entry_type not in data['export_data']: # if a particular entry type is not present - skip - continue - all_uuids = [content['uuid'] for content in data['export_data'][entry_type].values()] - unique_uuids = set(all_uuids) - if len(all_uuids) != len(unique_uuids): - raise ArchiveMigrationError(f"""{entry_type}s with exactly the same UUID found, cannot proceed further.""") - - -def migration_migrate_builtin_calculations(data): - """Apply migrations: 0019 - REV. 1.0.19 - Remove 'simpleplugin' from ArithmeticAddCalculation and TemplatereplacerCalculation type - - ATTENTION: - - The 'process_type' column did not exist before migration 0010, consequently, it could not be present in any - export archive of the currently existing stable releases (0.12.*). Here, however, the migration acts - on the content of the 'process_type' column, which could only be introduced in alpha releases of AiiDA 1.0. - Assuming that 'add' and 'templateplacer' calculations are expected to have both 'type' and 'process_type' columns, - they will be added based solely on the 'type' column content (unlike the way it is done in the DB migration, - where the 'type_string' content was also checked). - """ - for key, content in data['export_data'].get('Node', {}).items(): - if content.get('type', '') == 'calculation.job.simpleplugins.arithmetic.add.ArithmeticAddCalculation.': - content['type'] = 'calculation.job.arithmetic.add.ArithmeticAddCalculation.' - content['process_type'] = 'aiida.calculations:arithmetic.add' - elif content.get('type', '') == 'calculation.job.simpleplugins.templatereplacer.TemplatereplacerCalculation.': - content['type'] = 'calculation.job.templatereplacer.TemplatereplacerCalculation.' 
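            # Illustrative before/after for this branch, using a hypothetical node entry:
            #   before: {'type': 'calculation.job.simpleplugins.templatereplacer.TemplatereplacerCalculation.'}
            #   after:  {'type': 'calculation.job.templatereplacer.TemplatereplacerCalculation.',
            #            'process_type': 'aiida.calculations:templatereplacer'}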
- content['process_type'] = 'aiida.calculations:templatereplacer' - elif content.get('type', '') == 'data.code.Code.': - if data['node_attributes'][key]['input_plugin'] == 'simpleplugins.arithmetic.add': - data['node_attributes'][key]['input_plugin'] = 'arithmetic.add' - - elif data['node_attributes'][key]['input_plugin'] == 'simpleplugins.templatereplacer': - data['node_attributes'][key]['input_plugin'] = 'templatereplacer' - - -def migration_provenance_redesign(data): # pylint: disable=too-many-locals,too-many-branches,too-many-statements - """Apply migration: 0020 - REV. 1.0.20 - Provenance redesign - """ - from aiida.manage.database.integrity import write_database_integrity_violation - from aiida.manage.database.integrity.plugins import infer_calculation_entry_point - from aiida.plugins.entry_point import ENTRY_POINT_STRING_SEPARATOR - - fallback_cases = [] - calcjobs_to_migrate = {} - - for key, value in data['export_data'].get('Node', {}).items(): - if value.get('type', '').startswith('calculation.job.'): - calcjobs_to_migrate[key] = value - - if calcjobs_to_migrate: - # step1: rename the type column of process nodes - mapping_node_entry = infer_calculation_entry_point( - type_strings=[e['type'] for e in calcjobs_to_migrate.values()] - ) - for uuid, content in calcjobs_to_migrate.items(): - type_string = content['type'] - entry_point_string = mapping_node_entry[type_string] - - # If the entry point string does not contain the entry point string separator, - # the mapping function was not able to map the type string onto a known entry point string. - # As a fallback it uses the modified type string itself. - # All affected entries should be logged to file that the user can consult. - if ENTRY_POINT_STRING_SEPARATOR not in entry_point_string: - fallback_cases.append([uuid, type_string, entry_point_string]) - - content['process_type'] = entry_point_string - - if fallback_cases: - headers = ['UUID', 'type (old)', 'process_type (fallback)'] - warning_message = 'found calculation nodes with a type string ' \ - 'that could not be mapped onto a known entry point' - action_message = 'inferred `process_type` for all calculation nodes, ' \ - 'using fallback for unknown entry points' - write_database_integrity_violation(fallback_cases, headers, warning_message, action_message) - - # step2: detect and delete unexpected links - action_message = 'the link was deleted' - headers = ['UUID source', 'UUID target', 'link type', 'link label'] - - def delete_wrong_links(node_uuids, link_type, headers, warning_message, action_message): - """delete links that are matching link_type and are going from nodes listed in node_uuids""" - violations = [] - new_links_list = [] - for link in data['links_uuid']: - if link['input'] in node_uuids and link['type'] == link_type: - violations.append([link['input'], link['output'], link['type'], link['label']]) - else: - new_links_list.append(link) - data['links_uuid'] = new_links_list - if violations: - write_database_integrity_violation(violations, headers, warning_message, action_message) - - # calculations with outgoing CALL links - calculation_uuids = { - value['uuid'] for value in data['export_data'].get('Node', {}).values() if ( - value.get('type', '').startswith('calculation.job.') or - value.get('type', '').startswith('calculation.inline.') - ) - } - warning_message = 'detected calculation nodes with outgoing `call` links.' 
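    # Rough sketch of what delete_wrong_links does here, with hypothetical UUIDs:
    #   data['links_uuid'] = [
    #       {'input': 'calc-uuid', 'output': 'wf-uuid', 'type': 'calllink', 'label': 'CALL'},   # dropped and logged
    #       {'input': 'data-uuid', 'output': 'calc-uuid', 'type': 'inputlink', 'label': 'x'},   # kept
    #   ]
    #   Only links whose 'input' is in calculation_uuids and whose 'type' equals 'calllink' are removed.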
- delete_wrong_links(calculation_uuids, 'calllink', headers, warning_message, action_message) - - # calculations with outgoing RETURN links - warning_message = 'detected calculation nodes with outgoing `return` links.' - delete_wrong_links(calculation_uuids, 'returnlink', headers, warning_message, action_message) - - # outgoing CREATE links from FunctionCalculation and WorkCalculation nodes - warning_message = 'detected outgoing `create` links from FunctionCalculation and/or WorkCalculation nodes.' - work_uuids = { - value['uuid'] for value in data['export_data'].get('Node', {}).values() if ( - value.get('type', '').startswith('calculation.function') or - value.get('type', '').startswith('calculation.work') - ) - } - delete_wrong_links(work_uuids, 'createlink', headers, warning_message, action_message) - - for node_id, node in data['export_data'].get('Node', {}).items(): - # migrate very old `ProcessCalculation` to `WorkCalculation` - if node.get('type', '') == 'calculation.process.ProcessCalculation.': - node['type'] = 'calculation.work.WorkCalculation.' - - # WorkCalculations that have a `function_name` attribute are FunctionCalculations - if node.get('type', '') == 'calculation.work.WorkCalculation.': - if ( - 'function_name' in data['node_attributes'][node_id] and - data['node_attributes'][node_id]['function_name'] is not None - ): - # for some reason for the workchains the 'function_name' attribute is present but has None value - node['type'] = 'node.process.workflow.workfunction.WorkFunctionNode.' - else: - node['type'] = 'node.process.workflow.workchain.WorkChainNode.' - - # update type for JobCalculation nodes - if node.get('type', '').startswith('calculation.job.'): - node['type'] = 'node.process.calculation.calcjob.CalcJobNode.' - - # update type for InlineCalculation nodes - if node.get('type', '') == 'calculation.inline.InlineCalculation.': - node['type'] = 'node.process.calculation.calcfunction.CalcFunctionNode.' - - # update type for FunctionCalculation nodes - if node.get('type', '') == 'calculation.function.FunctionCalculation.': - node['type'] = 'node.process.workflow.workfunction.WorkFunctionNode.' - - uuid_node_type_mapping = { - node['uuid']: node['type'] for node in data['export_data'].get('Node', {}).values() if 'type' in node - } - for link in data['links_uuid']: - inp_uuid = link['output'] - # rename `createlink` to `create` - if link['type'] == 'createlink': - link['type'] = 'create' - # rename `returnlink` to `return` - elif link['type'] == 'returnlink': - link['type'] = 'return' - - elif link['type'] == 'inputlink': - # rename `inputlink` to `input_calc` if the target node is a calculation type node - if uuid_node_type_mapping[inp_uuid].startswith('node.process.calculation'): - link['type'] = 'input_calc' - # rename `inputlink` to `input_work` if the target node is a workflow type node - elif uuid_node_type_mapping[inp_uuid].startswith('node.process.workflow'): - link['type'] = 'input_work' - - elif link['type'] == 'calllink': - # rename `calllink` to `call_calc` if the target node is a calculation type node - if uuid_node_type_mapping[inp_uuid].startswith('node.process.calculation'): - link['type'] = 'call_calc' - # rename `calllink` to `call_work` if the target node is a workflow type node - elif uuid_node_type_mapping[inp_uuid].startswith('node.process.workflow'): - link['type'] = 'call_work' - - -def migration_dbgroup_name_to_label_type_to_type_string(metadata, data): - """Apply migrations: 0021 - REV. 
1.0.21 - Rename dbgroup fields: - name -> label - type -> type_string - """ - # For data.json - for content in data['export_data'].get('Group', {}).values(): - if 'name' in content: - content['label'] = content.pop('name') - if 'type' in content: - content['type_string'] = content.pop('type') - # For metadata.json - metadata_group = metadata['all_fields_info']['Group'] - if 'name' in metadata_group: - metadata_group['label'] = metadata_group.pop('name') - if 'type' in metadata_group: - metadata_group['type_string'] = metadata_group.pop('type') - - -def migration_dbgroup_type_string_change_content(data): - """Apply migrations: 0022 - REV. 1.0.22 - Change type_string according to the following rule: - '' -> 'user' - 'data.upf.family' -> 'data.upf' - 'aiida.import' -> 'auto.import' - 'autogroup.run' -> 'auto.run' - """ - for content in data['export_data'].get('Group', {}).values(): - key_mapper = { - '': 'user', - 'data.upf.family': 'data.upf', - 'aiida.import': 'auto.import', - 'autogroup.run': 'auto.run' - } - if content['type_string'] in key_mapper: - content['type_string'] = key_mapper[content['type_string']] - - -def migration_calc_job_option_attribute_keys(data): - """Apply migrations: 0023 - REV. 1.0.23 - `custom_environment_variables` -> `environment_variables` - `jobresource_params` -> `resources` - `_process_label` -> `process_label` - `parser` -> `parser_name` - """ - - # Helper function - def _migration_calc_job_option_attribute_keys(attr_id, content): - """Apply migration 0023 - REV. 1.0.23 for both `node_attributes*` dicts in `data.json`""" - # For CalcJobNodes only - if data['export_data']['Node'][attr_id]['type'] == 'node.process.calculation.calcjob.CalcJobNode.': - key_mapper = { - 'custom_environment_variables': 'environment_variables', - 'jobresource_params': 'resources', - 'parser': 'parser_name' - } - # Need to loop over a clone because the `content` needs to be modified in place - for key in copy.deepcopy(content): - if key in key_mapper: - content[key_mapper[key]] = content.pop(key) - - # For all processes - if data['export_data']['Node'][attr_id]['type'].startswith('node.process.'): - if '_process_label' in content: - content['process_label'] = content.pop('_process_label') - - # Update node_attributes and node_attributes_conversion - attribute_dicts = ['node_attributes', 'node_attributes_conversion'] - for attribute_dict in attribute_dicts: - for attr_id, content in data[attribute_dict].items(): - if 'type' in data['export_data'].get('Node', {}).get(attr_id, {}): - _migration_calc_job_option_attribute_keys(attr_id, content) - - -def migration_move_data_within_node_module(data): - """Apply migrations: 0025 - REV. 1.0.25 - The type string for `Data` nodes changed from `data.*` to `node.data.*`. - """ - for value in data['export_data'].get('Node', {}).values(): - if value.get('type', '').startswith('data.'): - value['type'] = value['type'].replace('data.', 'node.data.', 1) - - -def migration_trajectory_symbols_to_attribute(data: dict, folder: CacheFolder): - """Apply migrations: 0026 - REV. 1.0.26 and 0027 - REV. 1.0.27 - Create the symbols attribute from the repository array for all `TrajectoryData` nodes. 
- """ - path = folder.get_path(flush=False) - - for node_id, content in data['export_data'].get('Node', {}).items(): - if content.get('type', '') == 'node.data.array.trajectory.TrajectoryData.': - uuid = content['uuid'] - symbols_path = path.joinpath('nodes', uuid[0:2], uuid[2:4], uuid[4:], 'path', 'symbols.npy') - symbols = np.load(os.path.abspath(symbols_path)).tolist() - symbols_path.unlink() - # Update 'node_attributes' - data['node_attributes'][node_id].pop('array|symbols', None) - data['node_attributes'][node_id]['symbols'] = symbols - # Update 'node_attributes_conversion' - data['node_attributes_conversion'][node_id].pop('array|symbols', None) - data['node_attributes_conversion'][node_id]['symbols'] = [None] * len(symbols) - - -def migration_remove_node_prefix(data): - """Apply migrations: 0028 - REV. 1.0.28 - Change node type strings: - 'node.data.' -> 'data.' - 'node.process.' -> 'process.' - """ - for value in data['export_data'].get('Node', {}).values(): - if value.get('type', '').startswith('node.data.'): - value['type'] = value['type'].replace('node.data.', 'data.', 1) - elif value.get('type', '').startswith('node.process.'): - value['type'] = value['type'].replace('node.process.', 'process.', 1) - - -def migration_rename_parameter_data_to_dict(data): - """Apply migration: 0029 - REV. 1.0.29 - Update ParameterData to Dict - """ - for value in data['export_data'].get('Node', {}).values(): - if value.get('type', '') == 'data.parameter.ParameterData.': - value['type'] = 'data.dict.Dict.' - - -def migration_dbnode_type_to_dbnode_node_type(metadata, data): - """Apply migration: 0030 - REV. 1.0.30 - Renaming DbNode.type to DbNode.node_type - """ - # For data.json - for content in data['export_data'].get('Node', {}).values(): - if 'type' in content: - content['node_type'] = content.pop('type') - # For metadata.json - if 'type' in metadata['all_fields_info']['Node']: - metadata['all_fields_info']['Node']['node_type'] = metadata['all_fields_info']['Node'].pop('type') - - -def migration_remove_dbcomputer_enabled(metadata, data): - """Apply migration: 0031 - REV. 1.0.31 - Remove DbComputer.enabled - """ - remove_fields(metadata, data, ['Computer'], ['enabled']) - - -def migration_replace_text_field_with_json_field(data): - """Apply migration 0033 - REV. 1.0.33 - Store dict-values as JSON serializable dicts instead of strings - NB! Specific for Django backend - """ - for content in data['export_data'].get('Computer', {}).values(): - for value in ['metadata', 'transport_params']: - if isinstance(content[value], str): - content[value] = json.loads(content[value]) - for content in data['export_data'].get('Log', {}).values(): - if isinstance(content['metadata'], str): - content['metadata'] = json.loads(content['metadata']) - - -def add_extras(data): - """Update data.json with the new Extras - Since Extras were not available previously and usually only include hashes, - the Node ids will be added, but included as empty dicts - """ - node_extras: dict = {} - node_extras_conversion: dict = {} - - for node_id in data['export_data'].get('Node', {}): - node_extras[node_id] = {} - node_extras_conversion[node_id] = {} - data.update({'node_extras': node_extras, 'node_extras_conversion': node_extras_conversion}) - - -def migrate_v3_to_v4(folder: CacheFolder): - """ - Migration of archive files from v0.3 to v0.4 - - Note concerning migration 0032 - REV. 1.0.32: - Remove legacy workflow tables: DbWorkflow, DbWorkflowData, DbWorkflowStep - These were (according to Antimo Marrazzo) never exported. 
- """ - old_version = '0.3' - new_version = '0.4' - - _, metadata = folder.load_json('metadata.json') - - verify_metadata_version(metadata, old_version) - update_metadata(metadata, new_version) - - _, data = folder.load_json('data.json') - - # Apply migrations in correct sequential order - migration_base_data_plugin_type_string(data) - migration_process_type(metadata, data) - migration_code_sub_class_of_data(data) - migration_add_node_uuid_unique_constraint(data) - migration_migrate_builtin_calculations(data) - migration_provenance_redesign(data) - migration_dbgroup_name_to_label_type_to_type_string(metadata, data) - migration_dbgroup_type_string_change_content(data) - migration_calc_job_option_attribute_keys(data) - migration_move_data_within_node_module(data) - migration_trajectory_symbols_to_attribute(data, folder) - migration_remove_node_prefix(data) - migration_rename_parameter_data_to_dict(data) - migration_dbnode_type_to_dbnode_node_type(metadata, data) - migration_remove_dbcomputer_enabled(metadata, data) - migration_replace_text_field_with_json_field(data) - - # Add Node Extras - add_extras(data) - - # Update metadata.json with the new Log and Comment entities - new_entities = { - 'Log': { - 'uuid': {}, - 'time': { - 'convert_type': 'date' - }, - 'loggername': {}, - 'levelname': {}, - 'message': {}, - 'metadata': {}, - 'dbnode': { - 'related_name': 'dblogs', - 'requires': 'Node' - } - }, - 'Comment': { - 'uuid': {}, - 'ctime': { - 'convert_type': 'date' - }, - 'mtime': { - 'convert_type': 'date' - }, - 'content': {}, - 'dbnode': { - 'related_name': 'dbcomments', - 'requires': 'Node' - }, - 'user': { - 'related_name': 'dbcomments', - 'requires': 'User' - } - } - } - metadata['all_fields_info'].update(new_entities) - metadata['unique_identifiers'].update({'Log': 'uuid', 'Comment': 'uuid'}) - - folder.write_json('metadata.json', metadata) - folder.write_json('data.json', data) diff --git a/aiida/tools/importexport/archive/migrations/v10_to_v11.py b/aiida/tools/importexport/archive/migrations/v10_to_v11.py deleted file mode 100644 index 6eef4f0ec6..0000000000 --- a/aiida/tools/importexport/archive/migrations/v10_to_v11.py +++ /dev/null @@ -1,76 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Migration from v0.10 to v0.11, used by `verdi export migrate` command. - -This migration deals with the file repository. 
In the old version, the -""" -import os -import shutil - -from aiida.tools.importexport.archive.common import CacheFolder - -from .utils import update_metadata, verify_metadata_version # pylint: disable=no-name-in-module - - -def migrate_repository(metadata, data, folder): - """Migrate the file repository to a disk object store container.""" - from disk_objectstore import Container - - from aiida.repository import File, Repository - from aiida.repository.backend import DiskObjectStoreRepositoryBackend - - container = Container(os.path.join(folder.get_path(), 'container')) - container.init_container() - backend = DiskObjectStoreRepositoryBackend(container=container) - repository = Repository(backend=backend) - - for values in data.get('export_data', {}).get('Node', {}).values(): - uuid = values['uuid'] - dirpath_calc = os.path.join(folder.get_path(), 'nodes', uuid[:2], uuid[2:4], uuid[4:], 'raw_input') - dirpath_data = os.path.join(folder.get_path(), 'nodes', uuid[:2], uuid[2:4], uuid[4:], 'path') - - if os.path.isdir(dirpath_calc): - dirpath = dirpath_calc - elif os.path.isdir(dirpath_data): - dirpath = dirpath_data - else: - raise AssertionError('node repository contains neither `raw_input` nor `path` subfolder.') - - if not os.listdir(dirpath): - continue - - repository.put_object_from_tree(dirpath) - values['repository_metadata'] = repository.serialize() - # Artificially reset the metadata - repository._directory = File() # pylint: disable=protected-access - - container.pack_all_loose(compress=False) - shutil.rmtree(os.path.join(folder.get_path(), 'nodes')) - - metadata['all_fields_info']['Node']['repository_metadata'] = {} - - -def migrate_v10_to_v11(folder: CacheFolder): - """Migration of export files from v0.10 to v0.11.""" - old_version = '0.10' - new_version = '0.11' - - _, metadata = folder.load_json('metadata.json') - - verify_metadata_version(metadata, old_version) - update_metadata(metadata, new_version) - - _, data = folder.load_json('data.json') - - # Apply migrations - migrate_repository(metadata, data, folder) - - folder.write_json('metadata.json', metadata) - folder.write_json('data.json', data) diff --git a/aiida/tools/importexport/archive/migrators.py b/aiida/tools/importexport/archive/migrators.py deleted file mode 100644 index 181c80ca73..0000000000 --- a/aiida/tools/importexport/archive/migrators.py +++ /dev/null @@ -1,280 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. 
# -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Archive migration classes, for migrating an archive to different versions.""" -from abc import ABC, abstractmethod -import json -import os -from pathlib import Path -import shutil -import tarfile -import tempfile -from typing import Any, Callable, List, Optional, Type, Union, cast -import zipfile - -from archive_path import TarPath, ZipPath, read_file_in_tar, read_file_in_zip - -from aiida.common.log import AIIDA_LOGGER -from aiida.common.progress_reporter import create_callback, get_progress_reporter -from aiida.tools.importexport.archive.common import CacheFolder -from aiida.tools.importexport.archive.migrations import MIGRATE_FUNCTIONS -from aiida.tools.importexport.common.config import ExportFileFormat -from aiida.tools.importexport.common.exceptions import ArchiveMigrationError, CorruptArchive, DanglingLinkError - -__all__ = ( - 'ArchiveMigratorAbstract', 'ArchiveMigratorJsonBase', 'ArchiveMigratorJsonZip', 'ArchiveMigratorJsonTar', - 'MIGRATE_LOGGER', 'get_migrator' -) - -MIGRATE_LOGGER = AIIDA_LOGGER.getChild('migrate') - - -def get_migrator(file_format: str) -> Type['ArchiveMigratorAbstract']: - """Return the available archive migrator classes.""" - migrators = { - ExportFileFormat.ZIP: ArchiveMigratorJsonZip, - ExportFileFormat.TAR_GZIPPED: ArchiveMigratorJsonTar, - } - - if file_format not in migrators: - raise ValueError( - f'Can only migrate in the formats: {tuple(migrators.keys())}, please specify one for "file_format".' - ) - - return cast(Type[ArchiveMigratorAbstract], migrators[file_format]) - - -class ArchiveMigratorAbstract(ABC): - """An abstract base class to define an archive migrator.""" - - def __init__(self, filepath: str): - """Initialise the migrator - - :param filepath: the path to the archive file - :param version: the version of the archive file or, if None, the version will be auto-retrieved. - - """ - self._filepath = filepath - - @property - def filepath(self) -> str: - """Return the input file path.""" - return self._filepath - - @abstractmethod - def migrate( - self, - version: str, - filename: Optional[Union[str, Path]], - *, - force: bool = False, - work_dir: Optional[Path] = None, - **kwargs: Any - ) -> Optional[Path]: - """Migrate the archive to another version - - :param version: the version to migrate to - :param filename: the file path to migrate to. - If None, the migrated archive will not be copied from the work_dir. - :param force: overwrite output file if it already exists - :param work_dir: The directory in which to perform the migration. - If None, a temporary folder will be created and destroyed at the end of the process. 
- :param kwargs: key-word arguments specific to the concrete migrator implementation - - :returns: path to the migrated archive or None if no migration performed - (if filename is None, this will point to a path in the work_dir) - - :raises: :class:`~aiida.tools.importexport.common.exceptions.CorruptArchive`: - if the archive cannot be read - :raises: :class:`~aiida.tools.importexport.common.exceptions.ArchiveMigrationError`: - if the archive cannot migrated to the requested version - - """ - - -class ArchiveMigratorJsonBase(ArchiveMigratorAbstract): - """A migrator base for the JSON compressed formats.""" - - # pylint: disable=arguments-differ - def migrate( - self, - version: str, - filename: Optional[Union[str, Path]], - *, - force: bool = False, - work_dir: Optional[Path] = None, - out_compression: str = 'zip', - **kwargs - ) -> Optional[Path]: - # pylint: disable=too-many-branches - - if not isinstance(version, str): - raise TypeError('version must be a string') - - if filename and Path(filename).exists() and not force: - raise IOError(f'the output path already exists and force=False: {filename}') - - allowed_compressions = ['zip', 'zip-uncompressed', 'tar.gz', 'none'] - if out_compression not in allowed_compressions: - raise ValueError(f'Output compression must be in: {allowed_compressions}') - - MIGRATE_LOGGER.report('Reading archive version') - current_version = self._retrieve_version() - - # compute the migration pathway - prev_version = current_version - pathway: List[str] = [] - while prev_version != version: - if prev_version not in MIGRATE_FUNCTIONS: - raise ArchiveMigrationError(f"No migration pathway available for '{current_version}' to '{version}'") - if prev_version in pathway: - raise ArchiveMigrationError( - f'cyclic migration pathway encountered: {" -> ".join(pathway + [prev_version])}' - ) - pathway.append(prev_version) - prev_version = MIGRATE_FUNCTIONS[prev_version][0] - - if not pathway: - MIGRATE_LOGGER.report('No migration required') - return None - - MIGRATE_LOGGER.report('Migration pathway: %s', ' -> '.join(pathway + [version])) - - # perform migrations - if work_dir is not None: - migrated_path = self._perform_migration(Path(work_dir), pathway, out_compression, filename) - else: - with tempfile.TemporaryDirectory() as tmpdirname: - migrated_path = self._perform_migration(Path(tmpdirname), pathway, out_compression, filename) - MIGRATE_LOGGER.debug('Cleaning temporary folder') - - return migrated_path - - def _perform_migration( - self, work_dir: Path, pathway: List[str], out_compression: str, out_path: Optional[Union[str, Path]] - ) -> Path: - """Perform the migration(s) in the work directory, compress (if necessary), - then move to the out_path (if not None). 
- """ - MIGRATE_LOGGER.report('Extracting archive to work directory') - - extracted = Path(work_dir) / 'extracted' - extracted.mkdir(parents=True) - - with get_progress_reporter()(total=1) as progress: - callback = create_callback(progress) - self._extract_archive(extracted, callback) - - with CacheFolder(extracted) as folder: - with get_progress_reporter()(total=len(pathway), desc='Performing migrations: ') as progress: - for from_version in pathway: - to_version = MIGRATE_FUNCTIONS[from_version][0] - progress.set_description_str(f'Performing migrations: {from_version} -> {to_version}', refresh=True) - try: - MIGRATE_FUNCTIONS[from_version][1](folder) - except DanglingLinkError: - raise ArchiveMigrationError('Archive file is invalid because it contains dangling links') - progress.update() - MIGRATE_LOGGER.debug('Flushing cache') - - # re-compress archive - if out_compression != 'none': - MIGRATE_LOGGER.report(f"Re-compressing archive as '{out_compression}'") - migrated = work_dir / 'compressed' - else: - migrated = extracted - - if out_compression == 'zip': - self._compress_archive_zip(extracted, migrated, zipfile.ZIP_DEFLATED) - elif out_compression == 'zip-uncompressed': - self._compress_archive_zip(extracted, migrated, zipfile.ZIP_STORED) - elif out_compression == 'tar.gz': - self._compress_archive_tar(extracted, migrated) - - if out_path is not None: - # move to final location - MIGRATE_LOGGER.report('Moving archive to: %s', out_path) - self._move_file(migrated, Path(out_path)) - - return Path(out_path) if out_path else migrated - - @staticmethod - def _move_file(in_path: Path, out_path: Path): - """Move a file to a another path, deleting the target path first if it exists.""" - if out_path.exists(): - if os.path.samefile(str(in_path), str(out_path)): - return - if out_path.is_file(): - out_path.unlink() - else: - shutil.rmtree(out_path) - shutil.move(in_path, out_path) # type: ignore - - def _retrieve_version(self) -> str: - """Retrieve the version of the input archive.""" - raise NotImplementedError() - - def _extract_archive(self, filepath: Path, callback: Callable[[str, Any], None]): - """Extract the archive to a filepath.""" - raise NotImplementedError() - - @staticmethod - def _compress_archive_zip(in_path: Path, out_path: Path, compression: int): - """Create a new zip compressed zip from a folder.""" - with get_progress_reporter()(total=1, desc='Compressing to zip') as progress: - _callback = create_callback(progress) - with ZipPath(out_path, mode='w', compression=compression, allow_zip64=True) as path: - path.puttree(in_path, check_exists=False, callback=_callback, cb_descript='Compressing to zip') - - @staticmethod - def _compress_archive_tar(in_path: Path, out_path: Path): - """Create a new zip compressed tar from a folder.""" - with get_progress_reporter()(total=1, desc='Compressing to tar') as progress: - _callback = create_callback(progress) - with TarPath(out_path, mode='w:gz', dereference=True) as path: - path.puttree(in_path, check_exists=False, callback=_callback, cb_descript='Compressing to tar') - - -class ArchiveMigratorJsonZip(ArchiveMigratorJsonBase): - """A migrator for a JSON zip compressed format.""" - - def _retrieve_version(self) -> str: - try: - metadata = json.loads(read_file_in_zip(self.filepath, 'metadata.json')) - except (IOError, FileNotFoundError) as error: - raise CorruptArchive(str(error)) - if 'export_version' not in metadata: - raise CorruptArchive("metadata.json doest not contain an 'export_version' key") - return metadata['export_version'] - - 
def _extract_archive(self, filepath: Path, callback: Callable[[str, Any], None]): - try: - ZipPath(self.filepath, mode='r', allow_zip64=True).extract_tree(filepath, callback=callback) - except zipfile.BadZipfile as error: - raise CorruptArchive(f'The input file cannot be read: {error}') - - -class ArchiveMigratorJsonTar(ArchiveMigratorJsonBase): - """A migrator for a JSON tar compressed format.""" - - def _retrieve_version(self) -> str: - try: - metadata = json.loads(read_file_in_tar(self.filepath, 'metadata.json')) - except (IOError, FileNotFoundError) as error: - raise CorruptArchive(str(error)) - if 'export_version' not in metadata: - raise CorruptArchive("metadata.json doest not contain an 'export_version' key") - return metadata['export_version'] - - def _extract_archive(self, filepath: Path, callback: Callable[[str, Any], None]): - try: - TarPath(self.filepath, mode='r:*', pax_format=tarfile.PAX_FORMAT - ).extract_tree(filepath, allow_dev=False, allow_symlink=False, callback=callback) - except tarfile.ReadError as error: - raise CorruptArchive(f'The input file cannot be read: {error}') diff --git a/aiida/tools/importexport/archive/readers.py b/aiida/tools/importexport/archive/readers.py deleted file mode 100644 index 747ff077e1..0000000000 --- a/aiida/tools/importexport/archive/readers.py +++ /dev/null @@ -1,442 +0,0 @@ -# -*- coding: utf-8 -*- -########################################################################### -# Copyright (c), The AiiDA team. All rights reserved. # -# This file is part of the AiiDA code. # -# # -# The code is hosted on GitHub at https://github.com/aiidateam/aiida-core # -# For further information on the license, see the LICENSE.txt file # -# For further information please visit http://www.aiida.net # -########################################################################### -"""Archive reader classes.""" -from abc import ABC, abstractmethod -from distutils.version import StrictVersion -import json -from pathlib import Path -import tarfile -from types import TracebackType -from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Tuple, Type, cast -import zipfile - -from archive_path import TarPath, ZipPath, read_file_in_tar, read_file_in_zip -from disk_objectstore import Container - -from aiida.common.exceptions import InvalidOperation -from aiida.common.folders import SandboxFolder -from aiida.common.log import AIIDA_LOGGER -from aiida.tools.importexport.archive.common import ArchiveMetadata, null_callback -from aiida.tools.importexport.common.config import EXPORT_VERSION, GROUP_ENTITY_NAME, NODE_ENTITY_NAME, ExportFileFormat -from aiida.tools.importexport.common.exceptions import CorruptArchive, IncompatibleArchiveVersionError - -__all__ = ( - 'ArchiveReaderAbstract', - 'ARCHIVE_READER_LOGGER', - 'ReaderJsonBase', - 'ReaderJsonFolder', - 'ReaderJsonTar', - 'ReaderJsonZip', - 'get_reader', -) - -ARCHIVE_READER_LOGGER = AIIDA_LOGGER.getChild('archive.reader') - - -def get_reader(file_format: str) -> Type['ArchiveReaderAbstract']: - """Return the available writer classes.""" - readers = { - ExportFileFormat.ZIP: ReaderJsonZip, - ExportFileFormat.TAR_GZIPPED: ReaderJsonTar, - 'folder': ReaderJsonFolder, - } - - if file_format not in readers: - raise ValueError( - f'Can only read in the formats: {tuple(readers.keys())}, please specify one for "file_format".' - ) - - return cast(Type[ArchiveReaderAbstract], readers[file_format]) - - -class ArchiveReaderAbstract(ABC): - """An abstract interface for AiiDA archive readers. 
-
-    An ``ArchiveReader`` implementation is intended to be used with a context::
-
-        with ArchiveReader(filename) as reader:
-            reader.entity_count('Node')
-
-    """
-
-    def __init__(self, filename: str, **kwargs: Any):
-        """An archive reader
-
-        :param filename: the filename (possibly including the absolute path)
-            of the file to import.
-
-        """
-        # pylint: disable=unused-argument
-        self._filename = filename
-        self._in_context = False
-
-    @property
-    def filename(self) -> str:
-        """Return the name of the file that is being read from."""
-        return self._filename
-
-    @property
-    @abstractmethod
-    def file_format_verbose(self) -> str:
-        """The file format name."""
-
-    @property
-    @abstractmethod
-    def compatible_export_version(self) -> str:
-        """Return the export version that this reader is compatible with."""
-
-    def __enter__(self) -> 'ArchiveReaderAbstract':
-        self._in_context = True
-        return self
-
-    def __exit__(
-        self, exctype: Optional[Type[BaseException]], excinst: Optional[BaseException], exctb: Optional[TracebackType]
-    ):
-        self._in_context = False
-
-    def assert_within_context(self):
-        """Assert that the method is called within a context.
-
-        :raises: `~aiida.common.exceptions.InvalidOperation`: if not called within a context
-        """
-        if not self._in_context:
-            raise InvalidOperation('the ArchiveReader method should be used within a context')
-
-    @property
-    @abstractmethod
-    def export_version(self) -> str:
-        """Return the export version.
-
-        :raises `~aiida.tools.importexport.common.exceptions.CorruptArchive`: If the version cannot be retrieved.
-        """
-        # this should be able to be returned independent of any metadata validation
-
-    def check_version(self):
-        """Check the version compatibility of the archive.
-
-        :raises: `~aiida.tools.importexport.common.exceptions.IncompatibleArchiveVersionError`:
-            If the version is not compatible
-
-        """
-        file_version = StrictVersion(self.export_version)
-        expected_version = StrictVersion(self.compatible_export_version)
-
-        try:
-            if file_version != expected_version:
-                msg = f'Archive file version is {file_version}, can read only version {expected_version}'
-                if file_version < expected_version:
-                    msg += "\nUse 'verdi export migrate' to update this archive file."
-                else:
-                    msg += '\nUpdate your AiiDA version in order to import this file.'
-
-                raise IncompatibleArchiveVersionError(msg)
-        except AttributeError:
-            msg = (
-                f'Archive file version is {self.export_version}, '
-                f'can read only version {self.compatible_export_version}'
-            )
-            raise IncompatibleArchiveVersionError(msg)
-
-    @property
-    @abstractmethod
-    def metadata(self) -> ArchiveMetadata:
-        """Return the full (validated) archive metadata."""
-
-    @property
-    def entity_names(self) -> List[str]:
-        """Return list of all entity names."""
-        return list(self.metadata.all_fields_info.keys())
-
-    @abstractmethod
-    def entity_count(self, name: str) -> int:
-        """Return the count of an entity or None if not contained in the archive."""
-
-    @property
-    @abstractmethod
-    def link_count(self) -> int:
-        """Return the count of links."""
-
-    @abstractmethod
-    def iter_entity_fields(self,
-                           name: str,
-                           fields: Optional[Tuple[str, ...]] = None) -> Iterator[Tuple[int, Dict[str, Any]]]:
-        """Iterate over entities and yield their pk and database fields."""
-
-    @abstractmethod
-    def iter_node_uuids(self) -> Iterator[str]:
-        """Iterate over node UUIDs."""
-
-    @abstractmethod
-    def iter_group_uuids(self) -> Iterator[Tuple[str, Set[str]]]:
-        """Iterate over group UUIDs and the set of node UUIDs they contain."""
-
-    @abstractmethod
-    def iter_link_data(self) -> Iterator[dict]:
-        """Iterate over links: {'input': , 'output': , 'label':