Skip to content

Commit

Permalink
[ancestry] Issue #421 #422: Add support for assembly on ancestry module (#427)
Browse files Browse the repository at this point in the history

* Dynamically load model based on the assembly passed in the job request
* Cache up to 2 models to improve startup time

Stacked on
159bdcf

The commit for this PR:
1e64548

Also addresses #422
  • Loading branch information
akotlar authored Mar 12, 2024
2 parents e503af6 + 70e03fb commit fdc2832
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 53 deletions.
126 changes: 83 additions & 43 deletions python/python/bystro/ancestry/listener.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
"""Provide a worker for the ancestry model."""

import argparse
from collections.abc import Callable
import logging
from pathlib import Path

import boto3 # type: ignore
from botocore.exceptions import ClientError # type: ignore
import msgspec
import pandas as pd
import pyarrow.dataset as ds # type: ignore
Expand All @@ -17,6 +17,8 @@
from bystro.beanstalkd.messages import BaseMessage, CompletedJobMessage, SubmittedJobMessage
from bystro.beanstalkd.worker import ProgressPublisher, QueueConf, get_progress_reporter, listen

from bystro.utils.timer import Timer

logging.basicConfig(
filename="ancestry_listener.log",
level=logging.DEBUG,
Expand All @@ -30,19 +32,66 @@
PCA_FILE = "pca.csv"
RFC_FILE = "rfc.skop"

models_cache: dict[str, AncestryModel] = {}


def _get_model_from_s3(assembly: str) -> AncestryModel:
    """Return the ancestry model for ``assembly``, loading it from S3 on a cache miss.

    On a cache miss, ``{assembly}/{assembly}_pca.csv`` and
    ``{assembly}/{assembly}_rfc.skop`` are downloaded from ``ANCESTRY_BUCKET``
    into the current working directory, loaded, and wrapped in an
    ``AncestryModel``. At most two models are kept in ``models_cache``; when a
    third is loaded, the earliest-inserted entry is evicted.

    Args:
        assembly: Genome assembly name used to build the S3 keys (e.g. "hg38").

    Returns:
        The cached or freshly loaded ``AncestryModel``.

    Raises:
        ValueError: If the PCA or RFC file for ``assembly`` is not in the bucket.
        ClientError: For any other S3 failure.
    """
    if assembly in models_cache:
        logger.info("Model for assembly %s found in cache.", assembly)
        return models_cache[assembly]

    s3_client = boto3.client("s3")

    pca_local_key = f"{assembly}_pca.csv"
    rfc_local_key = f"{assembly}_rfc.skop"

    pca_file_key = f"{assembly}/{pca_local_key}"
    rfc_file_key = f"{assembly}/{rfc_local_key}"

    # download_file issues a HeadObject request before transferring, so a
    # missing key is reported with error code "404"; "NoSuchKey" is what a
    # plain GetObject reports. Accept both so the friendly error is raised.
    missing_key_codes = ("404", "NoSuchKey")

    logger.info("Downloading PCA file %s", pca_file_key)

    with Timer() as timer:
        try:
            s3_client.download_file(Bucket=ANCESTRY_BUCKET, Key=pca_file_key, Filename=pca_local_key)
        except ClientError as e:
            if e.response.get("Error", {}).get("Code") in missing_key_codes:
                raise ValueError(
                    f"{assembly} ancestry PCA file not found. This assembly is not supported."
                ) from e
            raise  # Re-raise anything that is not a missing-key error

        try:
            s3_client.download_file(Bucket=ANCESTRY_BUCKET, Key=rfc_file_key, Filename=rfc_local_key)
        except ClientError as e:
            if e.response.get("Error", {}).get("Code") in missing_key_codes:
                raise ValueError(
                    f"{assembly} ancestry model not found. This assembly is not supported."
                ) from e
            raise

    logger.debug("Downloaded PCA file and RFC file in %f seconds", timer.elapsed_time)

    with Timer() as timer:
        logger.info("Loading PCA file %s", pca_local_key)
        pca_loadings_df = pd.read_csv(pca_local_key, index_col=0)

        logger.info("Loading RFC file %s", rfc_local_key)
        rfc = skops_load(rfc_local_key)

    logger.debug("Loaded PCA and RFC files in %f seconds", timer.elapsed_time)

    model = AncestryModel(pca_loadings_df, rfc)

    # Keep the cache at two entries: dicts iterate in insertion order, so the
    # first key is the model that was loaded earliest.
    if len(models_cache) >= 2:
        oldest_assembly = next(iter(models_cache))
        del models_cache[oldest_assembly]
    models_cache[assembly] = model

    return model


class AncestryJobData(BaseMessage, frozen=True, rename="camel"):
Expand All @@ -57,10 +106,13 @@ class AncestryJobData(BaseMessage, frozen=True, rename="camel"):
The path to the dosage matrix file.
out_dir: str
The directory to write the results to.
assembly: str
The genome assembly used for the dosage matrix.
"""

dosage_matrix_path: str
out_dir: str
assembly: str


class AncestryJobCompleteMessage(CompletedJobMessage, frozen=True, kw_only=True, rename="camel"):
Expand All @@ -76,25 +128,20 @@ def _load_queue_conf(queue_conf_path: str) -> QueueConf:
return QueueConf(addresses=beanstalk_conf["addresses"], tubes=beanstalk_conf["tubes"])


def handler_fn_factory(
    ancestry_model: AncestryModel,
) -> Callable[[ProgressPublisher, AncestryJobData], AncestryResults]:
    """Return a handler bound to an already-loaded ``ancestry_model``.

    The returned callable ignores ``job_data.assembly`` and always runs
    inference with the model passed here.
    """

    def handler_fn(publisher: ProgressPublisher, job_data: AncestryJobData) -> AncestryResults:
        """Do ancestry job, wrapping infer_ancestry for beanstalk."""
        # Separating handler_fn from infer_ancestry in order to separate ML from infra concerns,
        # and especially to keep infer_ancestry eager.

        # not doing anything with this reporter at the moment, we're
        # simply threading it through for later.
        _reporter = get_progress_reporter(publisher)

        dataset = ds.dataset(job_data.dosage_matrix_path, format="arrow")

        return infer_ancestry(ancestry_model, dataset)

    return handler_fn


def handler_fn(publisher: ProgressPublisher, job_data: AncestryJobData) -> AncestryResults:
    """Do ancestry job, wrapping infer_ancestry for beanstalk.

    The model is looked up per job via ``_get_model_from_s3(job_data.assembly)``,
    so each request is served by the model for its own assembly.
    """
    # Separating handler_fn from infer_ancestry in order to separate ML from infra concerns,
    # and especially to keep infer_ancestry eager.

    # not doing anything with this reporter at the moment, we're
    # simply threading it through for later.
    _reporter = get_progress_reporter(publisher)

    dataset = ds.dataset(job_data.dosage_matrix_path, format="arrow")

    ancestry_model = _get_model_from_s3(job_data.assembly)

    return infer_ancestry(ancestry_model, dataset)


def submit_msg_fn(ancestry_job_data: AncestryJobData) -> SubmittedJobMessage:
Expand All @@ -121,22 +168,6 @@ def completed_msg_fn(
)


def main(ancestry_model: AncestryModel, queue_conf: QueueConf) -> None:
    """Listen on the ancestry tube, serving every job with ``ancestry_model``."""
    job_handler = handler_fn_factory(ancestry_model)

    logger.info(
        "Ancestry worker is listening on addresses: %s, tube: %s...", queue_conf.addresses, ANCESTRY_TUBE
    )

    listen(AncestryJobData, job_handler, submit_msg_fn, completed_msg_fn, queue_conf, ANCESTRY_TUBE)


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Process some config files.")
parser.add_argument(
Expand All @@ -146,8 +177,17 @@ def main(ancestry_model: AncestryModel, queue_conf: QueueConf) -> None:
required=True,
)
args = parser.parse_args()

ancestry_model = _get_model_from_s3()
queue_conf = _load_queue_conf(args.queue_conf)

main(ancestry_model, queue_conf)
logger.info(
"Ancestry worker is listening on addresses: %s, tube: %s...", queue_conf.addresses, ANCESTRY_TUBE
)

listen(
job_data_type=AncestryJobData,
handler_fn=handler_fn,
submit_msg_fn=submit_msg_fn,
completed_msg_fn=completed_msg_fn,
queue_conf=queue_conf,
tube=ANCESTRY_TUBE,
)
2 changes: 1 addition & 1 deletion python/python/bystro/ancestry/tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def test_infer_ancestry():

@pytest.mark.integration()
def test_infer_ancestry_from_model():
ancestry_model = _get_model_from_s3()
ancestry_model = _get_model_from_s3("hg38")

# Generate an arrow table that contains genotype dosages for 1000 samples
variants = list(ancestry_model.pca_loadings_df.index)
Expand Down
24 changes: 15 additions & 9 deletions python/python/bystro/ancestry/tests/test_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

from bystro.ancestry.listener import (
AncestryJobData,
handler_fn_factory,
handler_fn,
submit_msg_fn,
completed_msg_fn,
SubmittedJobMessage,
AncestryJobCompleteMessage,
AncestryResults,
AncestryResults
)
from bystro.ancestry.tests.test_inference import (
ANCESTRY_MODEL,
Expand All @@ -20,21 +20,20 @@
from bystro.beanstalkd.worker import ProgressPublisher


handler_fn = handler_fn_factory(ANCESTRY_MODEL)


def test_submit_fn():
    """submit_msg_fn should turn ancestry job data into a SubmittedJobMessage."""
    job_data = AncestryJobData(
        submission_id="my_submission_id2",
        dosage_matrix_path="some_dosage.feather",
        out_dir="/path/to/some/dir",
        assembly="hg38",
    )

    assert isinstance(submit_msg_fn(job_data), SubmittedJobMessage)


def test_handler_fn_happy_path(tmpdir):
def test_handler_fn_happy_path(mocker, tmpdir):
mocker.patch("bystro.ancestry.listener._get_model_from_s3", return_value=ANCESTRY_MODEL)
dosage_path = "some_dosage.feather"
f1 = tmpdir.join(dosage_path)

Expand All @@ -45,7 +44,7 @@ def test_handler_fn_happy_path(tmpdir):
host="127.0.0.1", port=1234, queue="my_queue", message=progress_message
)
ancestry_job_data = AncestryJobData(
submission_id="my_submission_id2", dosage_matrix_path=f1, out_dir=str(tmpdir)
submission_id="my_submission_id2", dosage_matrix_path=f1, out_dir=str(tmpdir), assembly="hg38"
)
ancestry_response = handler_fn(publisher, ancestry_job_data)

Expand All @@ -62,7 +61,10 @@ def test_handler_fn_happy_path(tmpdir):

def test_completion_fn(tmpdir):
ancestry_job_data = AncestryJobData(
submission_id="my_submission_id2", dosage_matrix_path="some_dosage.feather", out_dir=str(tmpdir)
submission_id="my_submission_id2",
dosage_matrix_path="some_dosage.feather",
out_dir=str(tmpdir),
assembly="hg38",
)

ancestry_results, _ = _infer_ancestry()
Expand Down Expand Up @@ -93,14 +95,18 @@ def test_completion_message():

def test_job_data_from_beanstalkd():
ancestry_job_data = AncestryJobData(
submission_id="my_submission_id2", dosage_matrix_path="some_dosage.feather", out_dir="/foo"
submission_id="my_submission_id2",
dosage_matrix_path="some_dosage.feather",
out_dir="/foo",
assembly="hg38",
)

serialized_values = json.encode(ancestry_job_data)
expected_value = {
"submissionId": "my_submission_id2",
"dosageMatrixPath": "some_dosage.feather",
"outDir": "/foo",
"assembly": "hg38",
}
serialized_expected_value = json.encode(expected_value)

Expand Down

0 comments on commit fdc2832

Please sign in to comment.