Merge pull request #788 from jmaruland/define-validate-shape-of-descriptor

Define validate shape of descriptor
jmaruland authored Dec 14, 2023
2 parents e05403a + aedd05d commit dc50a3f
Showing 2 changed files with 54 additions and 3 deletions.
26 changes: 23 additions & 3 deletions databroker/mongo_normalized.py
@@ -514,6 +514,7 @@ def __init__(
event_collection,
root_map,
sub_dict,
validate_shape,
):
self._run = run
self._stream_name = stream_name
@@ -522,6 +523,7 @@ def __init__(
self._event_collection = event_collection
self._sub_dict = sub_dict
self.root_map = root_map
self.validate_shape = validate_shape

# metadata should look like
# {
@@ -851,7 +853,7 @@ def populate_columns(keys, min_seq_num, max_seq_num):
if expected_shape and (not is_external):
validated_column = list(
map(
lambda item: _validate_shape(
lambda item: self.validate_shape(
key, numpy.asarray(item), expected_shape
),
result[key],
@@ -936,7 +938,7 @@ def populate_columns(keys, min_seq_num, max_seq_num):
last_datum_id=None,
)
filled_data = filled_mock_event["data"][key]
validated_filled_data = _validate_shape(
validated_filled_data = self.validate_shape(
key, filled_data, expected_shape
)
filled_column.append(validated_filled_data)
@@ -1047,6 +1049,7 @@ def from_uri(
access_policy=None,
cache_ttl_complete=60, # seconds
cache_ttl_partial=2, # seconds
validate_shape=None
):
"""
Create a MongoAdapter from MongoDB with the "normalized" (original) layout.
@@ -1094,6 +1097,9 @@
cache_ttl_complete : float
Time (in seconds) to cache a *complete* BlueskyRun before checking
the database for updates. Default 60.
validate_shape: func
function that will be used to validate that the shape of the data matches
the shape in the descriptor document
"""
metadatastore_db = _get_database(uri)
if asset_registry_uri is None:
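
The hook is invoked as self.validate_shape(key, data, expected_shape) in the populate_columns hunks above, and its return value is what lands in the materialized column, so a custom validator must return the (possibly coerced) array. A minimal sketch of a strict validator with that signature; the name is illustrative and not part of this change, and it assumes BadShapeMetadata is imported from this module:

import numpy
from databroker.mongo_normalized import BadShapeMetadata

def strict_validate_shape(key, data, expected_shape):
    # Reject mismatches outright instead of trying to coerce the data.
    data = numpy.asarray(data)
    if tuple(data.shape) != tuple(expected_shape):
        raise BadShapeMetadata(
            f"{key}: data shape {data.shape} does not match "
            f"descriptor shape {tuple(expected_shape)}"
        )
    return data
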
@@ -1122,6 +1128,7 @@
cache_of_partial_bluesky_runs=cache_of_partial_bluesky_runs,
metadata=metadata,
access_policy=access_policy,
validate_shape=validate_shape,
)

@classmethod
@@ -1135,6 +1142,7 @@ def from_mongomock(
access_policy=None,
cache_ttl_complete=60, # seconds
cache_ttl_partial=2, # seconds
validate_shape=None
):
"""
Create a transient MongoAdapter from backed by "mongomock".
@@ -1178,6 +1186,9 @@
cache_ttl_complete : float
Time (in seconds) to cache a *complete* BlueskyRun before checking
the database for updates. Default 60.
validate_shape: func
function that will be used to validate that the shape of the data matches
the shape in the descriptor document
"""
import mongomock

@@ -1205,6 +1216,7 @@
cache_of_partial_bluesky_runs=cache_of_partial_bluesky_runs,
metadata=metadata,
access_policy=access_policy,
validate_shape=validate_shape,
)

def __init__(
@@ -1220,6 +1232,7 @@ def __init__(
queries=None,
sorting=None,
access_policy=None,
validate_shape=None,
):
"This is not user-facing. Use MongoAdapter.from_uri."
self._run_start_collection = metadatastore_db.get_collection("run_start")
@@ -1249,6 +1262,11 @@ def __init__(
self._sorting = sorting
self.access_policy = access_policy
self._serializer = None
if validate_shape is None:
validate_shape = default_validate_shape
elif isinstance(validate_shape, str):
validate_shape = import_object(validate_shape)
self.validate_shape = validate_shape
super().__init__()

@property
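
Per the branching above, validate_shape may be left as None (falling back to default_validate_shape), passed as a callable, or given as a string that is resolved with import_object. A sketch of the two explicit forms; the URI and dotted module path are placeholders, and strict_validate_shape is the illustrative validator sketched earlier:

from databroker.mongo_normalized import MongoAdapter

# Pass a callable directly...
adapter = MongoAdapter.from_uri(
    "mongodb://localhost:27017/metadatastore",
    validate_shape=strict_validate_shape,
)

# ...or an importable dotted path, resolved at construction time.
adapter = MongoAdapter.from_uri(
    "mongodb://localhost:27017/metadatastore",
    validate_shape="my_package.validators.strict_validate_shape",
)

from_mongomock accepts the same keyword argument, which the new test below relies on.
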
@@ -1441,6 +1459,7 @@ def _build_event_stream(self, *, run_start_uid, stream_name, is_complete):
event_collection=self._event_collection,
root_map=self.root_map,
sub_dict="data",
validate_shape=self.validate_shape,
),
"timestamps": lambda: DatasetFromDocuments(
run=run,
@@ -1450,6 +1469,7 @@ def _build_event_stream(self, *, run_start_uid, stream_name, is_complete):
event_collection=self._event_collection,
root_map=self.root_map,
sub_dict="timestamps",
validate_shape=self.validate_shape,
),
"config": lambda: Config(
OneShotCachedMap(
@@ -2095,7 +2115,7 @@ class BadShapeMetadata(Exception):
pass


def _validate_shape(key, data, expected_shape):
def default_validate_shape(key, data, expected_shape):
"""
Check that data.shape == expected.shape.
31 changes: 31 additions & 0 deletions databroker/tests/test_validate_shape.py
@@ -0,0 +1,31 @@
from bluesky import RunEngine
from bluesky.plans import count
from ophyd.sim import img
from tiled.client import Context, from_context
from tiled.server.app import build_app

from ..mongo_normalized import MongoAdapter


def test_validate_shape(tmpdir):
# custom_validate_shape will mutate this to show it has been called
shapes = []

def custom_validate_shape(key, data, expected_shape):
shapes.append(expected_shape)
return data

adapter = MongoAdapter.from_mongomock(validate_shape=custom_validate_shape)

with Context.from_app(build_app(adapter), token_cache=tmpdir) as context:
client = from_context(context)

def post_document(name, doc):
client.post_document(name, doc)

RE = RunEngine()
RE.subscribe(post_document)
(uid,) = RE(count([img]))
assert not shapes
client[uid]["primary"]["data"]["img"][:]
assert shapes
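
The assertions show the hook is lazy: shapes stays empty after the documents are posted and only fills once the "img" column is actually read through the client. Because the hook's return value is stored, a validator can also repair data rather than just record or reject it. A purely illustrative sketch of a lenient variant in the same spirit as custom_validate_shape above (not part of this change):

import numpy

def lenient_validate_shape(key, data, expected_shape):
    data = numpy.asarray(data)
    if tuple(data.shape) == tuple(expected_shape):
        return data
    if data.size == int(numpy.prod(expected_shape)):
        # Same element count, different layout: reshape to the shape
        # recorded in the descriptor document.
        return data.reshape(tuple(expected_shape))
    # Otherwise hand the data back unchanged and let downstream code decide.
    return data
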
