Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Chunking and Padding #829

Merged
merged 14 commits into from
Oct 29, 2024
10 changes: 8 additions & 2 deletions databroker/mongo_normalized.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,8 @@ def structure_from_descriptor(descriptor, sub_dict, max_seq_num, unicode_columns
numpy_dtype = dtype.to_numpy_dtype()
if "chunks" in field_metadata:
# If the Event Descriptor tells us a preferred chunking, use that.
suggested_chunks = tuple(tuple(chunks) for chunks in field_metadata["chunks"])
suggested_chunks = [tuple(chunk) if isinstance(chunk, list)
else chunk for chunk in field_metadata['chunks']]
elif (0 in shape) or (numpy_dtype.itemsize == 0):
# special case to avoid warning from dask
suggested_chunks = shape
Expand Down Expand Up @@ -2211,7 +2212,7 @@ def default_validate_shape(key, data, expected_shape):
* If number of dimensions differ, raise BadShapeMetadata
* If any dimension differs by more than MAX_SIZE_DIFF, raise BadShapeMetadata.
* If some dimensions are smaller than expected,, pad "right" edge of each
dimension that falls short with NaN.
dimension that falls short with zeros..
"""
MAX_SIZE_DIFF = 2
if data.shape == expected_shape:
Expand Down Expand Up @@ -2241,6 +2242,11 @@ def default_validate_shape(key, data, expected_shape):
else: # margin == 0
padding.append((0, 0))
padded = numpy.pad(data, padding, "edge")

logger.warning(f"The data.shape: {data.shape} did not match the expected_shape: "
danielballan marked this conversation as resolved.
Show resolved Hide resolved
f"{expected_shape} for key: '{key}'. This data has been zero-padded "
"to match the expected_shape!")

return padded


Expand Down
100 changes: 72 additions & 28 deletions databroker/tests/test_validate_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,42 +10,61 @@
import pytest


def test_validate_shape(tmpdir):
# custom_validate_shape will mutate this to show it has been called
shapes = []

def custom_validate_shape(key, data, expected_shape):
shapes.append(expected_shape)
return data
@pytest.mark.parametrize(
"shape,expected_shape",
[
((10,), (11,)), # Short by 1, 1d.
((10, 20), (10, 21)), # Short by 1, 2d.
((10, 20, 30), (10, 21, 30)), # Short by 1, 3d.
((10, 20, 30), (10, 20, 31)), # Short by 1, 3d.
((10, 20), (10, 19)), # Too-big by 1, 2d.
((20, 20, 20, 20), (20, 21, 20, 22)), # 4d example.
],
)
def test_padding(tmpdir, shape, expected_shape):
adapter = MongoAdapter.from_mongomock()

adapter = MongoAdapter.from_mongomock(validate_shape=custom_validate_shape)
direct_img = DirectImage(
func=lambda: np.array(np.ones(shape)), name="direct", labels={"detectors"}
)
direct_img.img.name = "img"

with Context.from_app(build_app(adapter), token_cache=tmpdir) as context:
client = from_context(context)

def post_document(name, doc):
if name == "descriptor":
doc["data_keys"]["img"]["shape"] = expected_shape

client.post_document(name, doc)

RE = RunEngine()
RE.subscribe(post_document)
(uid,) = RE(count([img]))
assert not shapes
client[uid]["primary"]["data"]["img"][:]
assert shapes
(uid,) = RE(count([direct_img]))
assert client[uid]["primary"]["data"]["img"][0].shape == expected_shape


@pytest.mark.parametrize(
"shape,expected_shape",
"chunks,shape,expected_chunks",
[
((10,), (11,)),
((10, 20), (10, 21)),
((10, 20), (10, 19)),
((10, 20, 30), (10, 21, 30)),
((10, 20, 30), (10, 20, 31)),
((20, 20, 20, 20), (20, 21, 20, 22)),
([1, 2], (10,), ((1,), (2, 2, 2, 2, 2))), # 1D image
([1, 3], (10,), ((1,), (3, 3, 3, 1))), # not evenly divisible.
([1, 2, 2], (10, 10), ((1,), (2, 2, 2, 2, 2), (2, 2, 2, 2, 2))), # 2D
([1, 2, -1], (10, 10), ((1,), (2, 2, 2, 2, 2), (10,))), # -1 for max size.
([1, 2, "auto"], (10, 10), ((1,), (2, 2, 2, 2, 2), (10,))), # auto
(
((1,), (2, 2, 2, 2, 2), (2, 2, 2, 2, 2)),
(10, 10),
((1,), (2, 2, 2, 2, 2), (2, 2, 2, 2, 2)),
), # normalized chunks
(
[1, 5, "auto", -1, 5],
(10, 10, 10, 10),
((1,), (5, 5), (10,), (10,), (5, 5))
), # mixture of things.
],
)
def test_padding(tmpdir, shape, expected_shape):
def test_custom_chunking(tmpdir, chunks, shape, expected_chunks):
adapter = MongoAdapter.from_mongomock()

direct_img = DirectImage(
Expand All @@ -54,30 +73,31 @@ def test_padding(tmpdir, shape, expected_shape):
direct_img.img.name = "img"

with Context.from_app(build_app(adapter), token_cache=tmpdir) as context:
client = from_context(context)
client = from_context(context, "dask")

def post_document(name, doc):
if name == "descriptor":
doc["data_keys"]["img"]["shape"] = expected_shape
doc["data_keys"]["img"]["chunks"] = chunks

client.post_document(name, doc)

RE = RunEngine()
RE.subscribe(post_document)
(uid,) = RE(count([direct_img]))
assert client[uid]["primary"]["data"]["img"][0].shape == expected_shape
assert client[uid]["primary"]["data"]["img"].chunks == expected_chunks
# assert client[uid]["primary"]["data"]["img"][0].shape == shape


@pytest.mark.parametrize(
"shape,expected_shape",
[
((10,), (11, 12)),
((10, 20), (10, 200)),
((20, 20, 20, 20), (20, 21, 20, 200)),
((10, 20), (5, 20)),
((10,), (11, 12)), # Different number of dimensions.
((10, 20), (10, 200)), # Dimension sizes differ by more than 2.
((20, 20, 20, 20), (20, 21, 20, 200)), # Dimension sizes differ by more than 2.
((10, 20), (5, 20)), # Data is bigger than expected.
],
)
def test_default_validate_shape(tmpdir, shape, expected_shape):
def test_validate_shape_exceptions(tmpdir, shape, expected_shape):
adapter = MongoAdapter.from_mongomock()

direct_img = DirectImage(
Expand All @@ -99,3 +119,27 @@ def post_document(name, doc):
(uid,) = RE(count([direct_img]))
with pytest.raises(BadShapeMetadata):
client[uid]["primary"]["data"]["img"][:]


def test_custom_validate_shape(tmpdir):
# custom_validate_shape will mutate this to show it has been called
shapes = []

def custom_validate_shape(key, data, expected_shape):
shapes.append(expected_shape)
return data

adapter = MongoAdapter.from_mongomock(validate_shape=custom_validate_shape)

with Context.from_app(build_app(adapter), token_cache=tmpdir) as context:
client = from_context(context)

def post_document(name, doc):
client.post_document(name, doc)

RE = RunEngine()
RE.subscribe(post_document)
(uid,) = RE(count([img]))
assert not shapes
client[uid]["primary"]["data"]["img"][:]
assert shapes
5 changes: 5 additions & 0 deletions requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
area-detector-handlers
bluesky
boltons
codecov
coverage
doct
flake8
glue-core <1.18
glueviz
matplotlib
mongomock
mongoquery
msgpack >=1.0.0
numpy >=1.16.0 # for astropy (via glueviz)
ophyd
pims
pyqt5 !=5.14.1
pytest >=4.4,!=5.4.0
pytest-rerunfailures
Expand All @@ -18,5 +22,6 @@ suitcase-jsonl >=0.1.0b2
suitcase-mongo >=0.5.0
suitcase-msgpack >=0.2.2
tiled[all] >=0.1.0b9
tzlocal
ujson
vcrpy
Loading