Skip to content

Commit

Permalink
feat: add get_all_records DB method
Browse files Browse the repository at this point in the history
  • Loading branch information
jsstevenson committed Oct 26, 2023
1 parent fb9b7fd commit 9598b5a
Show file tree
Hide file tree
Showing 6 changed files with 199 additions and 43 deletions.
23 changes: 21 additions & 2 deletions disease/database/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@
from enum import Enum
from os import environ
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Union
from typing import Any, Dict, Generator, List, Optional, Set, Union

import click

from disease.schemas import RefType, SourceMeta, SourceName
from disease.schemas import RecordType, RefType, SourceMeta, SourceName


class DatabaseException(Exception): # noqa: N818
Expand Down Expand Up @@ -142,6 +142,25 @@ def get_all_concept_ids(self, source: Optional[SourceName] = None) -> Set[str]:
:return: Set of concept IDs as strings.
"""

@abc.abstractmethod
def get_all_records(self, record_type: RecordType) -> Generator[Dict, None, None]:
"""Retrieve all source or normalized records. Either return all source records,
or all records that qualify as "normalized" (i.e., merged groups + source
records that are otherwise ungrouped).
For example,
.. code-block::pycon
>>> from disease.database import create_db
>>> from disease.schemas import RecordType
>>> db = create_db()
>>> for record in db.get_all_records(RecordType.MERGER):
>>> pass # do something
:param record_type: type of result to return
:return: Generator that lazily provides records as they are retrieved
"""

@abc.abstractmethod
def add_source_metadata(self, src_name: SourceName, data: SourceMeta) -> None:
"""Add new source metadata entry.
Expand Down
61 changes: 52 additions & 9 deletions disease/database/dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import sys
from os import environ
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Union
from typing import Any, Dict, Generator, List, Optional, Set, Union

import boto3
import click
Expand All @@ -23,7 +23,7 @@
DatabaseWriteException,
confirm_aws_db_use,
)
from disease.schemas import RefType, SourceMeta, SourceName
from disease.schemas import RecordType, RefType, SourceMeta, SourceName

_logger = logging.getLogger()
_logger.setLevel(logging.DEBUG)
Expand Down Expand Up @@ -201,7 +201,7 @@ def check_tables_populated(self) -> bool:

records = self.diseases.query(
IndexName="item_type_index",
KeyConditionExpression=Key("item_type").eq("identity"),
KeyConditionExpression=Key("item_type").eq(RecordType.IDENTITY.value),
Limit=1,
)
if len(records.get("Items", [])) < 1:
Expand All @@ -210,7 +210,7 @@ def check_tables_populated(self) -> bool:

normalized_records = self.diseases.query(
IndexName="item_type_index",
KeyConditionExpression=Key("item_type").eq("merger"),
KeyConditionExpression=Key("item_type").eq(RecordType.MERGER.value),
Limit=1,
)
if len(normalized_records.get("Items", [])) < 1:
Expand Down Expand Up @@ -254,9 +254,9 @@ def get_record_by_id(
"""
try:
if merge:
pk = f"{concept_id.lower()}##merger"
pk = f"{concept_id.lower()}##{RecordType.MERGER.value}"
else:
pk = f"{concept_id.lower()}##identity"
pk = f"{concept_id.lower()}##{RecordType.IDENTITY.value}"
if case_sensitive:
response = self.diseases.get_item(
Key={"label_and_type": pk, "concept_id": concept_id}
Expand Down Expand Up @@ -329,6 +329,47 @@ def get_all_concept_ids(self, source: Optional[SourceName] = None) -> Set[str]:
break
return set(concept_ids)

def get_all_records(self, record_type: RecordType) -> Generator[Dict, None, None]:
"""Retrieve all source or normalized records. Either return all source records,
or all records that qualify as "normalized" (i.e., merged groups + source
records that are otherwise ungrouped).
For example,
.. code-block::pycon
>>> from disease.database import create_db
>>> from disease.schemas import RecordType
>>> db = create_db()
>>> for record in db.get_all_records(RecordType.MERGER):
>>> pass # do something
:param record_type: type of result to return
:return: Generator that lazily provides records as they are retrieved
"""
last_evaluated_key = None
while True:
if last_evaluated_key:
response = self.diseases.scan(
ExclusiveStartKey=last_evaluated_key,
)
else:
response = self.diseases.scan()
records = response.get("Items", [])
for record in records:
incoming_record_type = record.get("item_type")
if record_type == RecordType.IDENTITY:
if incoming_record_type == record_type:
yield record
else:
if (
incoming_record_type == RecordType.IDENTITY
and not record.get("merge_ref")
) or incoming_record_type == RecordType.MERGER:
yield record
last_evaluated_key = response.get("LastEvaluatedKey")
if not last_evaluated_key:
break

def add_source_metadata(self, src_name: SourceName, metadata: SourceMeta) -> None:
"""Add new source metadata entry.
Expand Down Expand Up @@ -383,9 +424,9 @@ def add_merged_record(self, record: Dict) -> None:
concept_id = record["concept_id"]
id_prefix = concept_id.split(":")[0].lower()
record["src_name"] = PREFIX_LOOKUP[id_prefix]
label_and_type = f"{concept_id.lower()}##merger"
label_and_type = f"{concept_id.lower()}##{RecordType.MERGER.value}"
record["label_and_type"] = label_and_type
record["item_type"] = "merger"
record["item_type"] = RecordType.MERGER.value
try:
self.batch.put_item(Item=record)
except ClientError as e:
Expand Down Expand Up @@ -467,7 +508,9 @@ def delete_normalized_concepts(self) -> None:
try:
response = self.diseases.query(
IndexName="item_type_index",
KeyConditionExpression=Key("item_type").eq("merger"),
KeyConditionExpression=Key("item_type").eq(
RecordType.MERGER.value
),
)
except ClientError as e:
raise DatabaseReadException(e)
Expand Down
119 changes: 95 additions & 24 deletions disease/database/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Union
from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union

import psycopg
import requests
Expand All @@ -18,7 +18,7 @@
)

from disease.database import AbstractDatabase, DatabaseException, DatabaseWriteException
from disease.schemas import RefType, SourceMeta, SourceName
from disease.schemas import RecordType, RefType, SourceMeta, SourceName

logger = logging.getLogger()

Expand Down Expand Up @@ -289,6 +289,25 @@ def get_source_metadata(self, src_name: Union[str, SourceName]) -> Optional[Dict

_record_query = b"SELECT * FROM record_lookup_view WHERE lower(concept_id) = %s;"

def _format_source_record(self, source_row: Tuple) -> Dict:
"""Restructure row from disease_concepts table as source record result object.
:param source_row: result tuple from psycopg
:return: reformatted dictionary keying gene properties to row values
"""
disease_record = {
"concept_id": source_row[0],
"label": source_row[1],
"aliases": source_row[2],
"associated_with": source_row[3],
"xrefs": source_row[4],
"src_name": source_row[5],
"merge_ref": source_row[6],
"pediatric_disease": source_row[7],
"item_type": RecordType.IDENTITY.value,
}
return {k: v for k, v in disease_record.items() if v}

def _get_record(self, concept_id: str, case_sensitive: bool) -> Optional[Dict]:
"""Retrieve non-merged record. The query is pretty different, so this method
is broken out for PostgreSQL.
Expand All @@ -306,21 +325,27 @@ def _get_record(self, concept_id: str, case_sensitive: bool) -> Optional[Dict]:
if not result:
return None

disease_record = {
"concept_id": result[0],
"label": result[1],
"aliases": result[2],
"associated_with": result[3],
"xrefs": result[4],
"src_name": result[5],
"merge_ref": result[6],
"pediatric_disease": result[7],
"item_type": "identity",
}
return {k: v for k, v in disease_record.items() if v}
return self._format_source_record(result)

_merged_record_query = b"SELECT * FROM disease_merged WHERE lower(concept_id) = %s;"

def _format_merged_record(self, merged_row: Tuple) -> Dict:
"""Restructure row from disease_merged table as normalized result object.
:param merged_row: result tuple from psycopg
:return: reformatted dictionary keying normalized gene properties to row values
"""
merged_record = {
"concept_id": merged_row[0],
"label": merged_row[1],
"aliases": merged_row[2],
"associated_with": merged_row[3],
"xrefs": merged_row[4],
"pediatric_disease": merged_row[5],
"item_type": RecordType.MERGER.value,
}
return {k: v for k, v in merged_record.items() if v}

def _get_merged_record(
self, concept_id: str, case_sensitive: bool
) -> Optional[Dict]:
Expand All @@ -338,16 +363,7 @@ def _get_merged_record(
if not result:
return None

merged_record = {
"concept_id": result[0],
"label": result[1],
"aliases": result[2],
"associated_with": result[3],
"xrefs": result[4],
"pediatric_disease": result[5],
"item_type": "merger",
}
return {k: v for k, v in merged_record.items() if v}
return self._format_merged_record(result)

def get_record_by_id(
self, concept_id: str, case_sensitive: bool = True, merge: bool = False
Expand Down Expand Up @@ -417,6 +433,61 @@ def get_all_concept_ids(self, source: Optional[SourceName] = None) -> Set[str]:
ids_tuple = cur.fetchall()
return {i[0] for i in ids_tuple}

_get_all_normalized_records_query = b"SELECT * FROM disease_merged;"
_get_all_unmerged_source_records_query = (
b"SELECT * FROM record_lookup_view WHERE merge_ref IS NULL;" # noqa: E501
)
_get_all_source_records_query = b"SELECT * FROM record_lookup_view;"

def get_all_records(self, record_type: RecordType) -> Generator[Dict, None, None]:
"""Retrieve all source or normalized records. Either return all source records,
or all records that qualify as "normalized" (i.e., merged groups + source
records that are otherwise ungrouped).
For example,
.. code-block::pycon
>>> from disease.database import create_db
>>> from disease.schemas import RecordType
>>> db = create_db()
>>> for record in db.get_all_records(RecordType.MERGER):
>>> pass # do something
Unlike DynamoDB, merged records are stored in a separate table from source
records. As a result, when fetching all normalized records, merged records are
return first, and iteration continues with all source records that don't
belong to a normalized concept group.
:param record_type: type of result to return
:return: Generator that lazily provides records as they are retrieved
"""
batch_size = 500

if record_type == RecordType.MERGER:
with self.conn.cursor() as cur:
results = cur.execute(self._get_all_normalized_records_query)
fetched = results.fetchmany(batch_size)
while fetched:
for row in fetched:
yield self._format_merged_record(row)
fetched = results.fetchmany(batch_size)
with self.conn.cursor() as cur:
results = cur.execute(self._get_all_unmerged_source_records_query)
fetched = results.fetchmany(batch_size)
while fetched:
for result in fetched:
yield self._format_source_record(result)
fetched = results.fetchmany(batch_size)
else:
with self.conn.cursor() as cur:
results = cur.execute(self._get_all_source_records_query)
fetched = results.fetchmany(batch_size)
while fetched:
for result in fetched:
yield self._format_source_record(result)
fetched = results.fetchmany(batch_size)

_add_source_metadata_query = b"""
INSERT INTO disease_sources(
name, data_license, data_license_url, version, data_url, rdp_url,
Expand Down
5 changes: 0 additions & 5 deletions disease/etl/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,6 @@ def get_latest_version(self) -> str:
"""
return bioversions.get_version(self._src_name)

@abstractmethod
def _download_data(self) -> None:
"""Download source data."""
raise NotImplementedError

def _zip_handler(self, dl_path: Path, outfile_path: Path) -> None:
"""Provide simple callback function to extract the largest file within a given
zipfile and save it within the appropriate data directory.
Expand Down
7 changes: 7 additions & 0 deletions disease/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,13 @@ class DataLicenseAttributes(BaseModel):
attribution: StrictBool


class RecordType(str, Enum):
"""Record item types."""

IDENTITY = "identity"
MERGER = "merger"


class RefType(str, Enum):
"""Reference item types."""

Expand Down
27 changes: 24 additions & 3 deletions tests/unit/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@

import pytest

from disease.schemas import RecordType

IS_DDB = not os.environ.get("DISEASE_NORM_DB_URL", "").lower().startswith("postgres")
IS_TEST_ENV = os.environ.get("DISEASE_TEST", "").lower() == "true"


def test_tables_created(database):
"""Check that required tables are created."""
Expand All @@ -22,9 +27,6 @@ def test_tables_created(database):
assert database.disease_metadata_table in existing_tables


IS_DDB = not os.environ.get("DISEASE_NORM_DB_URL", "").lower().startswith("postgres")


@pytest.mark.skipif(not IS_DDB, reason="only applies to DynamoDB in test env")
def test_item_type(database):
"""Check that objects are tagged with item_type attribute."""
Expand Down Expand Up @@ -59,3 +61,22 @@ def test_item_type(database):
item = database.diseases.query(KeyConditionExpression=filter_exp)["Items"][0]
assert "item_type" in item
assert item["item_type"] == "merger"


@pytest.mark.skipif(not IS_TEST_ENV, reason="not in test environment")
def database(db_fixture):
"""Perform basic test of get_all_records method.
It's probably overkill (and unmaintainable) to do exact checks against every
record, but fairly easy to check against expected counts and ensure that nothing
is getting sent twice.
"""
source_records = list(db_fixture.get_all_records(RecordType.IDENTITY))
assert len(source_records) == 1463
source_ids = {r["concept_id"] for r in source_records}
assert len(source_ids) == 1463

normalized_records = list(db_fixture.get_all_records(RecordType.MERGER))
assert len(normalized_records) == 1391
normalized_ids = {r["concept_id"] for r in normalized_records}
assert len(normalized_ids) == 1391

0 comments on commit 9598b5a

Please sign in to comment.