Skip to content

Commit

Permalink
✨ NEW: Add orm.Entity.fields interface for QueryBuilder
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisjsewell committed Feb 17, 2022
1 parent da5f9b0 commit 02bed54
Show file tree
Hide file tree
Showing 72 changed files with 1,293 additions and 50 deletions.
5 changes: 5 additions & 0 deletions aiida/orm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from .comments import *
from .computers import *
from .entities import *
from .fields import *
from .groups import *
from .logs import *
from .nodes import *
Expand Down Expand Up @@ -78,6 +79,10 @@
'OrmEntityLoader',
'ProcessNode',
'ProjectionData',
'QbAttrField',
'QbField',
'QbFieldFilters',
'QbFields',
'QueryBuilder',
'RemoteData',
'RemoteStashData',
Expand Down
13 changes: 11 additions & 2 deletions aiida/orm/authinfos.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from aiida.plugins import TransportFactory

from . import entities, users
from .fields import QbField

if TYPE_CHECKING:
from aiida.orm import Computer, User
Expand Down Expand Up @@ -43,11 +44,19 @@ def delete(self, pk: int) -> None:
class AuthInfo(entities.Entity['BackendAuthInfo']):
"""ORM class that models the authorization information that allows a `User` to connect to a `Computer`."""

__qb_fields__ = (
QbField('enabled', dtype=bool, doc='Whether the instance is enabled'),
QbField('auth_params', dtype=Dict[str, Any], doc='Dictionary of authentication parameters'),
QbField('metadata', dtype=Dict[str, Any], doc='Dictionary of metadata'),
QbField('computer_pk', 'dbcomputer_id', dtype=int, doc='The PK of the computer'),
QbField('user_pk', 'aiidauser_id', dtype=int, doc='The PK of the user'),
)

Collection = AuthInfoCollection

@classproperty
def objects(cls: Type['AuthInfo']) -> AuthInfoCollection: # type: ignore[misc] # pylint: disable=no-self-argument
return AuthInfoCollection.get_cached(cls, get_manager().get_profile_storage())
def objects(cls: Type['AuthInfo']) -> AuthInfoCollection: # type: ignore # pylint: disable=no-self-argument
return AuthInfoCollection.get_cached(cls, get_manager().get_profile_storage()) # type: ignore

PROPERTY_WORKDIR = 'workdir'

Expand Down
14 changes: 12 additions & 2 deletions aiida/orm/comments.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from aiida.manage import get_manager

from . import entities, users
from .fields import QbField

if TYPE_CHECKING:
from aiida.orm import Node, User
Expand Down Expand Up @@ -66,11 +67,20 @@ def delete_many(self, filters: dict) -> List[int]:
class Comment(entities.Entity['BackendComment']):
"""Base class to map a DbComment that represents a comment attached to a certain Node."""

__qb_fields__ = (
QbField('uuid', dtype=str, doc='The UUID of the comment'),
QbField('ctime', dtype=datetime, doc='Creation time of the comment'),
QbField('mtime', dtype=datetime, doc='Modified time of the comment'),
QbField('content', dtype=str, doc='Content of the comment'),
QbField('user_pk', 'user_id', dtype=int, doc='User PK that created the comment'),
QbField('node_pk', 'dbnode_id', dtype=int, doc='Node PK that the comment is attached to'),
)

Collection = CommentCollection

@classproperty
def objects(cls: Type['Comment']) -> CommentCollection: # type: ignore[misc] # pylint: disable=no-self-argument
return CommentCollection.get_cached(cls, get_manager().get_profile_storage())
def objects(cls: Type['Comment']) -> CommentCollection: # type: ignore # pylint: disable=no-self-argument
return CommentCollection.get_cached(cls, get_manager().get_profile_storage()) # type: ignore

def __init__(self, node: 'Node', user: 'User', content: Optional[str] = None, backend: Optional['Backend'] = None):
"""Create a Comment for a given node and user
Expand Down
15 changes: 13 additions & 2 deletions aiida/orm/computers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from aiida.plugins import SchedulerFactory, TransportFactory

from . import entities, users
from .fields import QbField

if TYPE_CHECKING:
from aiida.orm import AuthInfo, User
Expand Down Expand Up @@ -75,11 +76,21 @@ class Computer(entities.Entity['BackendComputer']):
PROPERTY_WORKDIR = 'workdir'
PROPERTY_SHEBANG = 'shebang'

__qb_fields__ = (
QbField('uuid', dtype=str, doc='The UUID of the computer'),
QbField('label', dtype=str, doc='Label for the computer'),
QbField('description', dtype=str, doc='Description of the computer'),
QbField('hostname', dtype=str, doc='Hostname of the computer'),
QbField('transport_type', dtype=str, doc='Transport type of the computer'),
QbField('scheduler_type', dtype=str, doc='Scheduler type of the computer'),
QbField('metadata', dtype=Dict[str, Any], doc='Metadata of the computer'),
)

Collection = ComputerCollection

@classproperty
def objects(cls: Type['Computer']) -> ComputerCollection: # type: ignore[misc] # pylint: disable=no-self-argument
return ComputerCollection.get_cached(cls, get_manager().get_profile_storage())
def objects(cls: Type['Computer']) -> ComputerCollection: # type: ignore # pylint: disable=no-self-argument
return ComputerCollection.get_cached(cls, get_manager().get_profile_storage()) # type: ignore

def __init__( # pylint: disable=too-many-arguments
self,
Expand Down
10 changes: 8 additions & 2 deletions aiida/orm/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,16 @@
import copy
from enum import Enum
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, Protocol, Type, TypeVar, cast
from typing import TYPE_CHECKING, Any, Dict, Generic, List, Optional, Protocol, Sequence, Type, TypeVar, cast

from plumpy.base.utils import call_with_super_check, super_check

from aiida.common import exceptions
from aiida.common.lang import classproperty, type_check
from aiida.manage import get_manager

from .fields import EntityFieldMeta, QbField, QbFields

if TYPE_CHECKING:
from aiida.orm.implementation import Backend, BackendEntity
from aiida.orm.querybuilder import FilterType, OrderByType, QueryBuilder
Expand Down Expand Up @@ -161,9 +163,13 @@ def count(self, filters: Optional['FilterType'] = None) -> int:
return self.query(filters=filters).count()


class Entity(abc.ABC, Generic[BackendEntityType]):
class Entity(Generic[BackendEntityType], metaclass=EntityFieldMeta):
"""An AiiDA entity"""

fields: QbFields = QbFields()

__qb_fields__: Sequence[QbField] = (QbField('pk', 'id', dtype=int, doc='The primary key of the entity'),)

@classproperty
@abc.abstractmethod
def objects(cls: EntityType) -> Collection[EntityType]: # pylint: disable=no-self-argument,disable=no-self-use
Expand Down
Loading

0 comments on commit 02bed54

Please sign in to comment.