diff --git a/databroker/mongo_normalized.py b/databroker/mongo_normalized.py index 3ad914510..15d258a8f 100644 --- a/databroker/mongo_normalized.py +++ b/databroker/mongo_normalized.py @@ -4,6 +4,7 @@ import copy from datetime import datetime, timedelta import functools +import inspect import itertools import logging import os @@ -160,7 +161,8 @@ def structure_from_descriptor(descriptor, sub_dict, max_seq_num, unicode_columns numpy_dtype = dtype.to_numpy_dtype() if "chunks" in field_metadata: # If the Event Descriptor tells us a preferred chunking, use that. - suggested_chunks = tuple(tuple(chunks) for chunks in field_metadata["chunks"]) + suggested_chunks = [tuple(chunk) if isinstance(chunk, list) + else chunk for chunk in field_metadata['chunks']] elif (0 in shape) or (numpy_dtype.itemsize == 0): # special case to avoid warning from dask suggested_chunks = shape @@ -931,6 +933,9 @@ def populate_columns(keys, min_seq_num, max_seq_num): map( lambda item: self.validate_shape( key, numpy.asarray(item), expected_shape + ) if 'uid' in inspect.signature(self.validate_shape).parameters + else self.validate_shape( + key, numpy.asarray(item), expected_shape, uid=self._run.metadata()['start']['uid'] ), result[key], ) @@ -2204,14 +2209,14 @@ class BadShapeMetadata(Exception): pass -def default_validate_shape(key, data, expected_shape): +def default_validate_shape(key, data, expected_shape, uid=None): """ Check that data.shape == expected.shape. * If number of dimensions differ, raise BadShapeMetadata * If any dimension differs by more than MAX_SIZE_DIFF, raise BadShapeMetadata. * If some dimensions are smaller than expected,, pad "right" edge of each - dimension that falls short with NaN. + dimension that falls short with zeros. """ MAX_SIZE_DIFF = 2 if data.shape == expected_shape: @@ -2241,6 +2246,11 @@ def default_validate_shape(key, data, expected_shape): else: # margin == 0 padding.append((0, 0)) padded = numpy.pad(data, padding, "edge") + + logger.warning(f"The data.shape: {data.shape} did not match the expected_shape: " + f"{expected_shape} for key: '{key}'. This data has been zero-padded " + "to match the expected_shape! RunStart UID: {uid}") + return padded diff --git a/databroker/tests/conftest.py b/databroker/tests/conftest.py index 7d8468bd1..d7864a569 100644 --- a/databroker/tests/conftest.py +++ b/databroker/tests/conftest.py @@ -1,5 +1,3 @@ -import contextlib -import getpass import os import pytest import sys @@ -114,22 +112,3 @@ def delete_dm(): @pytest.fixture(params=['scalar', 'image', 'external_image']) def detector(request, hw): return getattr(hw, SIM_DETECTORS[request.param]) - - -@pytest.fixture -def enter_password(monkeypatch): - """ - Return a context manager that overrides getpass, used like: - - >>> with enter_password(...): - ... # Run code that calls getpass.getpass(). - """ - - @contextlib.contextmanager - def f(password): - original = getpass.getpass - monkeypatch.setattr("getpass.getpass", lambda: password) - yield - monkeypatch.setattr("getpass.getpass", original) - - return f diff --git a/databroker/tests/test_access_policy.py b/databroker/tests/test_access_policy.py index 6a18a5985..3c2dd97b2 100644 --- a/databroker/tests/test_access_policy.py +++ b/databroker/tests/test_access_policy.py @@ -2,7 +2,7 @@ from bluesky.plans import count from tiled.client import Context, from_context from tiled.server.app import build_app_from_config - +from tiled._tests.utils import enter_username_password from ..mongo_normalized import MongoAdapter, SimpleAccessPolicy @@ -22,7 +22,7 @@ def __init__(self, *args, **kwargs): InstrumentedMongoAdapter.from_mongomock(access_policy=access_policy) -def test_access_policy_example(tmpdir, enter_password): +def test_access_policy_example(tmpdir): config = { "authentication": { @@ -52,8 +52,8 @@ def test_access_policy_example(tmpdir, enter_password): ], } with Context.from_app(build_app_from_config(config), token_cache=tmpdir) as context: - with enter_password("secret"): - client = from_context(context, username="alice", prompt_for_reauthentication=True) + with enter_username_password("alice", "secret"): + client = from_context(context, prompt_for_reauthentication=True) def post_document(name, doc): client.post_document(name, doc) diff --git a/databroker/tests/test_validate_shape.py b/databroker/tests/test_validate_shape.py index df563e384..0f2ca4c89 100644 --- a/databroker/tests/test_validate_shape.py +++ b/databroker/tests/test_validate_shape.py @@ -10,42 +10,61 @@ import pytest -def test_validate_shape(tmpdir): - # custom_validate_shape will mutate this to show it has been called - shapes = [] - - def custom_validate_shape(key, data, expected_shape): - shapes.append(expected_shape) - return data +@pytest.mark.parametrize( + "shape,expected_shape", + [ + ((10,), (11,)), # Short by 1, 1d. + ((10, 20), (10, 21)), # Short by 1, 2d. + ((10, 20, 30), (10, 21, 30)), # Short by 1, 3d. + ((10, 20, 30), (10, 20, 31)), # Short by 1, 3d. + ((10, 20), (10, 19)), # Too-big by 1, 2d. + ((20, 20, 20, 20), (20, 21, 20, 22)), # 4d example. + ], +) +def test_padding(tmpdir, shape, expected_shape): + adapter = MongoAdapter.from_mongomock() - adapter = MongoAdapter.from_mongomock(validate_shape=custom_validate_shape) + direct_img = DirectImage( + func=lambda: np.array(np.ones(shape)), name="direct", labels={"detectors"} + ) + direct_img.img.name = "img" with Context.from_app(build_app(adapter), token_cache=tmpdir) as context: client = from_context(context) def post_document(name, doc): + if name == "descriptor": + doc["data_keys"]["img"]["shape"] = expected_shape + client.post_document(name, doc) RE = RunEngine() RE.subscribe(post_document) - (uid,) = RE(count([img])) - assert not shapes - client[uid]["primary"]["data"]["img"][:] - assert shapes + (uid,) = RE(count([direct_img])) + assert client[uid]["primary"]["data"]["img"][0].shape == expected_shape @pytest.mark.parametrize( - "shape,expected_shape", + "chunks,shape,expected_chunks", [ - ((10,), (11,)), - ((10, 20), (10, 21)), - ((10, 20), (10, 19)), - ((10, 20, 30), (10, 21, 30)), - ((10, 20, 30), (10, 20, 31)), - ((20, 20, 20, 20), (20, 21, 20, 22)), + ([1, 2], (10,), ((1,), (2, 2, 2, 2, 2))), # 1D image + ([1, 3], (10,), ((1,), (3, 3, 3, 1))), # not evenly divisible. + ([1, 2, 2], (10, 10), ((1,), (2, 2, 2, 2, 2), (2, 2, 2, 2, 2))), # 2D + ([1, 2, -1], (10, 10), ((1,), (2, 2, 2, 2, 2), (10,))), # -1 for max size. + ([1, 2, "auto"], (10, 10), ((1,), (2, 2, 2, 2, 2), (10,))), # auto + ( + ((1,), (2, 2, 2, 2, 2), (2, 2, 2, 2, 2)), + (10, 10), + ((1,), (2, 2, 2, 2, 2), (2, 2, 2, 2, 2)), + ), # normalized chunks + ( + [1, 5, "auto", -1, 5], + (10, 10, 10, 10), + ((1,), (5, 5), (10,), (10,), (5, 5)) + ), # mixture of things. ], ) -def test_padding(tmpdir, shape, expected_shape): +def test_custom_chunking(tmpdir, chunks, shape, expected_chunks): adapter = MongoAdapter.from_mongomock() direct_img = DirectImage( @@ -54,30 +73,30 @@ def test_padding(tmpdir, shape, expected_shape): direct_img.img.name = "img" with Context.from_app(build_app(adapter), token_cache=tmpdir) as context: - client = from_context(context) + client = from_context(context, "dask") def post_document(name, doc): if name == "descriptor": - doc["data_keys"]["img"]["shape"] = expected_shape + doc["data_keys"]["img"]["chunks"] = chunks client.post_document(name, doc) RE = RunEngine() RE.subscribe(post_document) (uid,) = RE(count([direct_img])) - assert client[uid]["primary"]["data"]["img"][0].shape == expected_shape + assert client[uid]["primary"]["data"]["img"].chunks == expected_chunks @pytest.mark.parametrize( "shape,expected_shape", [ - ((10,), (11, 12)), - ((10, 20), (10, 200)), - ((20, 20, 20, 20), (20, 21, 20, 200)), - ((10, 20), (5, 20)), + ((10,), (11, 12)), # Different number of dimensions. + ((10, 20), (10, 200)), # Dimension sizes differ by more than 2. + ((20, 20, 20, 20), (20, 21, 20, 200)), # Dimension sizes differ by more than 2. + ((10, 20), (5, 20)), # Data is bigger than expected. ], ) -def test_default_validate_shape(tmpdir, shape, expected_shape): +def test_validate_shape_exceptions(tmpdir, shape, expected_shape): adapter = MongoAdapter.from_mongomock() direct_img = DirectImage( @@ -99,3 +118,27 @@ def post_document(name, doc): (uid,) = RE(count([direct_img])) with pytest.raises(BadShapeMetadata): client[uid]["primary"]["data"]["img"][:] + + +def test_custom_validate_shape(tmpdir): + # custom_validate_shape will mutate this to show it has been called + shapes = [] + + def custom_validate_shape(key, data, expected_shape): + shapes.append(expected_shape) + return data + + adapter = MongoAdapter.from_mongomock(validate_shape=custom_validate_shape) + + with Context.from_app(build_app(adapter), token_cache=tmpdir) as context: + client = from_context(context) + + def post_document(name, doc): + client.post_document(name, doc) + + RE = RunEngine() + RE.subscribe(post_document) + (uid,) = RE(count([img])) + assert not shapes + client[uid]["primary"]["data"]["img"][:] + assert shapes