Skip to content

Commit

Permalink
Merge pull request #829 from gwbischof/chunking
Browse files Browse the repository at this point in the history
Chunking and Padding
  • Loading branch information
danielballan authored Oct 29, 2024
2 parents 2550bd8 + 81c122b commit c3bd898
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 56 deletions.
16 changes: 13 additions & 3 deletions databroker/mongo_normalized.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import copy
from datetime import datetime, timedelta
import functools
import inspect
import itertools
import logging
import os
Expand Down Expand Up @@ -160,7 +161,8 @@ def structure_from_descriptor(descriptor, sub_dict, max_seq_num, unicode_columns
numpy_dtype = dtype.to_numpy_dtype()
if "chunks" in field_metadata:
# If the Event Descriptor tells us a preferred chunking, use that.
suggested_chunks = tuple(tuple(chunks) for chunks in field_metadata["chunks"])
suggested_chunks = [tuple(chunk) if isinstance(chunk, list)
else chunk for chunk in field_metadata['chunks']]
elif (0 in shape) or (numpy_dtype.itemsize == 0):
# special case to avoid warning from dask
suggested_chunks = shape
Expand Down Expand Up @@ -931,6 +933,9 @@ def populate_columns(keys, min_seq_num, max_seq_num):
map(
lambda item: self.validate_shape(
key, numpy.asarray(item), expected_shape
) if 'uid' in inspect.signature(self.validate_shape).parameters
else self.validate_shape(
key, numpy.asarray(item), expected_shape, uid=self._run.metadata()['start']['uid']
),
result[key],
)
Expand Down Expand Up @@ -2204,14 +2209,14 @@ class BadShapeMetadata(Exception):
pass


def default_validate_shape(key, data, expected_shape):
def default_validate_shape(key, data, expected_shape, uid=None):
"""
Check that data.shape == expected.shape.
* If number of dimensions differ, raise BadShapeMetadata
* If any dimension differs by more than MAX_SIZE_DIFF, raise BadShapeMetadata.
* If some dimensions are smaller than expected, pad the "right" edge of each
dimension that falls short with NaN.
dimension that falls short by repeating the edge values.
"""
MAX_SIZE_DIFF = 2
if data.shape == expected_shape:
Expand Down Expand Up @@ -2241,6 +2246,11 @@ def default_validate_shape(key, data, expected_shape):
else: # margin == 0
padding.append((0, 0))
padded = numpy.pad(data, padding, "edge")

logger.warning(f"The data.shape: {data.shape} did not match the expected_shape: "
f"{expected_shape} for key: '{key}'. This data has been zero-padded "
"to match the expected_shape! RunStart UID: {uid}")

return padded


Expand Down
21 changes: 0 additions & 21 deletions databroker/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import contextlib
import getpass
import os
import pytest
import sys
Expand Down Expand Up @@ -114,22 +112,3 @@ def delete_dm():
@pytest.fixture(params=['scalar', 'image', 'external_image'])
def detector(request, hw):
    """
    Parametrized fixture returning one attribute of ``hw`` per parameter.

    The attribute name is looked up in ``SIM_DETECTORS`` using the current
    fixture parameter ('scalar', 'image' or 'external_image').
    """
    attribute_name = SIM_DETECTORS[request.param]
    return getattr(hw, attribute_name)


@pytest.fixture
def enter_password(monkeypatch):
    """
    Return a context manager that overrides getpass, used like:

    >>> with enter_password(...):
    ...     # Run code that calls getpass.getpass().

    Inside the ``with`` block, ``getpass.getpass`` returns the given password
    instead of prompting. On exit the previous ``getpass.getpass`` is put
    back, even if the body of the ``with`` block raised.
    """

    @contextlib.contextmanager
    def f(password):
        original = getpass.getpass

        def fake_getpass(*args, **kwargs):
            # The real getpass.getpass takes optional ``prompt`` and
            # ``stream`` arguments; accept and ignore them so callers that
            # pass a prompt do not hit a TypeError.
            return password

        monkeypatch.setattr("getpass.getpass", fake_getpass)
        try:
            yield
        finally:
            # Restore on every exit path, not just the successful one.
            monkeypatch.setattr("getpass.getpass", original)

    return f
8 changes: 4 additions & 4 deletions databroker/tests/test_access_policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from bluesky.plans import count
from tiled.client import Context, from_context
from tiled.server.app import build_app_from_config

from tiled._tests.utils import enter_username_password
from ..mongo_normalized import MongoAdapter, SimpleAccessPolicy


Expand All @@ -22,7 +22,7 @@ def __init__(self, *args, **kwargs):
InstrumentedMongoAdapter.from_mongomock(access_policy=access_policy)


def test_access_policy_example(tmpdir, enter_password):
def test_access_policy_example(tmpdir):

config = {
"authentication": {
Expand Down Expand Up @@ -52,8 +52,8 @@ def test_access_policy_example(tmpdir, enter_password):
],
}
with Context.from_app(build_app_from_config(config), token_cache=tmpdir) as context:
with enter_password("secret"):
client = from_context(context, username="alice", prompt_for_reauthentication=True)
with enter_username_password("alice", "secret"):
client = from_context(context, prompt_for_reauthentication=True)

def post_document(name, doc):
client.post_document(name, doc)
Expand Down
99 changes: 71 additions & 28 deletions databroker/tests/test_validate_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,42 +10,61 @@
import pytest


def test_validate_shape(tmpdir):
# custom_validate_shape will mutate this to show it has been called
shapes = []

def custom_validate_shape(key, data, expected_shape):
shapes.append(expected_shape)
return data
@pytest.mark.parametrize(
"shape,expected_shape",
[
((10,), (11,)), # Short by 1, 1d.
((10, 20), (10, 21)), # Short by 1, 2d.
((10, 20, 30), (10, 21, 30)), # Short by 1, 3d.
((10, 20, 30), (10, 20, 31)), # Short by 1, 3d.
((10, 20), (10, 19)), # Too-big by 1, 2d.
((20, 20, 20, 20), (20, 21, 20, 22)), # 4d example.
],
)
def test_padding(tmpdir, shape, expected_shape):
adapter = MongoAdapter.from_mongomock()

adapter = MongoAdapter.from_mongomock(validate_shape=custom_validate_shape)
direct_img = DirectImage(
func=lambda: np.array(np.ones(shape)), name="direct", labels={"detectors"}
)
direct_img.img.name = "img"

with Context.from_app(build_app(adapter), token_cache=tmpdir) as context:
client = from_context(context)

def post_document(name, doc):
if name == "descriptor":
doc["data_keys"]["img"]["shape"] = expected_shape

client.post_document(name, doc)

RE = RunEngine()
RE.subscribe(post_document)
(uid,) = RE(count([img]))
assert not shapes
client[uid]["primary"]["data"]["img"][:]
assert shapes
(uid,) = RE(count([direct_img]))
assert client[uid]["primary"]["data"]["img"][0].shape == expected_shape


@pytest.mark.parametrize(
"shape,expected_shape",
"chunks,shape,expected_chunks",
[
((10,), (11,)),
((10, 20), (10, 21)),
((10, 20), (10, 19)),
((10, 20, 30), (10, 21, 30)),
((10, 20, 30), (10, 20, 31)),
((20, 20, 20, 20), (20, 21, 20, 22)),
([1, 2], (10,), ((1,), (2, 2, 2, 2, 2))), # 1D image
([1, 3], (10,), ((1,), (3, 3, 3, 1))), # not evenly divisible.
([1, 2, 2], (10, 10), ((1,), (2, 2, 2, 2, 2), (2, 2, 2, 2, 2))), # 2D
([1, 2, -1], (10, 10), ((1,), (2, 2, 2, 2, 2), (10,))), # -1 for max size.
([1, 2, "auto"], (10, 10), ((1,), (2, 2, 2, 2, 2), (10,))), # auto
(
((1,), (2, 2, 2, 2, 2), (2, 2, 2, 2, 2)),
(10, 10),
((1,), (2, 2, 2, 2, 2), (2, 2, 2, 2, 2)),
), # normalized chunks
(
[1, 5, "auto", -1, 5],
(10, 10, 10, 10),
((1,), (5, 5), (10,), (10,), (5, 5))
), # mixture of things.
],
)
def test_padding(tmpdir, shape, expected_shape):
def test_custom_chunking(tmpdir, chunks, shape, expected_chunks):
adapter = MongoAdapter.from_mongomock()

direct_img = DirectImage(
Expand All @@ -54,30 +73,30 @@ def test_padding(tmpdir, shape, expected_shape):
direct_img.img.name = "img"

with Context.from_app(build_app(adapter), token_cache=tmpdir) as context:
client = from_context(context)
client = from_context(context, "dask")

def post_document(name, doc):
if name == "descriptor":
doc["data_keys"]["img"]["shape"] = expected_shape
doc["data_keys"]["img"]["chunks"] = chunks

client.post_document(name, doc)

RE = RunEngine()
RE.subscribe(post_document)
(uid,) = RE(count([direct_img]))
assert client[uid]["primary"]["data"]["img"][0].shape == expected_shape
assert client[uid]["primary"]["data"]["img"].chunks == expected_chunks


@pytest.mark.parametrize(
"shape,expected_shape",
[
((10,), (11, 12)),
((10, 20), (10, 200)),
((20, 20, 20, 20), (20, 21, 20, 200)),
((10, 20), (5, 20)),
((10,), (11, 12)), # Different number of dimensions.
((10, 20), (10, 200)), # Dimension sizes differ by more than 2.
((20, 20, 20, 20), (20, 21, 20, 200)), # Dimension sizes differ by more than 2.
((10, 20), (5, 20)), # Data is bigger than expected.
],
)
def test_default_validate_shape(tmpdir, shape, expected_shape):
def test_validate_shape_exceptions(tmpdir, shape, expected_shape):
adapter = MongoAdapter.from_mongomock()

direct_img = DirectImage(
Expand All @@ -99,3 +118,27 @@ def post_document(name, doc):
(uid,) = RE(count([direct_img]))
with pytest.raises(BadShapeMetadata):
client[uid]["primary"]["data"]["img"][:]


def test_custom_validate_shape(tmpdir):
    """
    Check that a user-supplied ``validate_shape`` callable gets used.

    The callable has not run after the documents are posted, and has run
    once the "img" column is read back through the client.
    """
    # The custom validator appends here, leaving evidence that it ran.
    recorded_shapes = []

    def custom_validate_shape(key, data, expected_shape):
        recorded_shapes.append(expected_shape)
        return data

    adapter = MongoAdapter.from_mongomock(validate_shape=custom_validate_shape)
    app = build_app(adapter)

    with Context.from_app(app, token_cache=tmpdir) as context:
        client = from_context(context)

        def post_document(name, doc):
            client.post_document(name, doc)

        run_engine = RunEngine()
        run_engine.subscribe(post_document)
        [uid] = run_engine(count([img]))

        # Nothing has been read yet, so the validator must not have run.
        assert len(recorded_shapes) == 0
        img_column = client[uid]["primary"]["data"]["img"]
        img_column[:]
        assert len(recorded_shapes) > 0

0 comments on commit c3bd898

Please sign in to comment.