Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: data padding bug in mongo_normalized #825

Merged
merged 5 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 6 additions & 14 deletions databroker/mongo_normalized.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,7 +977,7 @@ def populate_columns(keys, min_seq_num, max_seq_num):
# their size if we need to squeeze more performance out here. But maybe
# we can get away with never adding that complexity.
for key, est_row_bytesize in zip(nonscalars, estimated_nonscalar_row_bytesizes):
page_size = TARGET_PAGE_BYTESIZE // est_row_bytesize
page_size = max(1, TARGET_PAGE_BYTESIZE // est_row_bytesize)
boundaries = list(range(min_seq_num, 1 + max_seq_num, page_size))
if boundaries[-1] != max_seq_num:
boundaries.append(max_seq_num)
Expand Down Expand Up @@ -2209,10 +2209,11 @@ def default_validate_shape(key, data, expected_shape):
Check that data.shape == expected.shape.

* If number of dimensions differ, raise BadShapeMetadata
* If any dimension is larger than expected, raise BadShapeMetadata.
* If any dimension differs by more than MAX_SIZE_DIFF, raise BadShapeMetadata.
* If some dimensions are smaller than expected, pad "right" edge of each
dimension that falls short with NaN.
"""
MAX_SIZE_DIFF = 2
if data.shape == expected_shape:
return data
if len(data.shape) != len(expected_shape):
Expand All @@ -2224,32 +2225,23 @@ def default_validate_shape(key, data, expected_shape):
)
# Pad at the "end" along any dimension that is too short.
padding = []
trimming = []
for actual, expected in zip(data.shape, expected_shape):
margin = expected - actual
# Limit how much padding or trimming we are willing to do.
SOMEWHAT_ARBITRARY_LIMIT_OF_WHAT_IS_REASONABLE = 2
if abs(margin) > SOMEWHAT_ARBITRARY_LIMIT_OF_WHAT_IS_REASONABLE:
if abs(margin) > MAX_SIZE_DIFF:
raise BadShapeMetadata(
f"For data key {key} "
f"shape {data.shape} does not "
f"match expected shape {expected_shape}."
)
if margin > 0:
padding.append((0, margin))
trimming.append(slice(None, None))
elif margin < 0:
padding.append((0, 0))
trimming.append(slice(None))
else: # margin == 0
padding.append((0, 0))
trimming.append(slice(None, None))
# TODO Rethink this!
# We cannot do NaN because that does not work for integers
# and it is too late to change our mind about the data type.
padded = numpy.pad(data, padding, "edge")
padded_and_trimmed = padded[tuple(trimming)]
return padded_and_trimmed
padded = numpy.pad(data, padding, "edge")
return padded


def build_summary(run_start_doc, run_stop_doc, stream_names):
Expand Down
74 changes: 72 additions & 2 deletions databroker/tests/test_validate_shape.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from bluesky import RunEngine
from bluesky.plans import count
from ophyd.sim import img
from ophyd.sim import img, DirectImage
from tiled.client import Context, from_context
from tiled.server.app import build_app

from ..mongo_normalized import MongoAdapter
from ..mongo_normalized import MongoAdapter, BadShapeMetadata

import numpy as np
import pytest


def test_validate_shape(tmpdir):
Expand All @@ -29,3 +32,70 @@ def post_document(name, doc):
assert not shapes
client[uid]["primary"]["data"]["img"][:]
assert shapes


@pytest.mark.parametrize(
    ("shape", "expected_shape"),
    [
        ((10,), (11,)),
        ((10, 20), (10, 21)),
        ((10, 20), (10, 19)),
        ((10, 20, 30), (10, 21, 30)),
        ((10, 20, 30), (10, 20, 31)),
        ((20, 20, 20, 20), (20, 21, 20, 22)),
    ],
)
def test_padding(tmpdir, shape, expected_shape):
    """Check that a frame read back has the shape declared in the descriptor.

    A simulated detector emits arrays of ``shape`` while the descriptor is
    rewritten on its way to the server to advertise ``expected_shape``.
    Every parametrized pair has the same number of dimensions and differs
    by at most two elements along any axis.
    """

    def make_frame():
        # All-ones array with the *actual* (undeclared) shape.
        return np.array(np.ones(shape))

    mongo_adapter = MongoAdapter.from_mongomock()

    detector = DirectImage(func=make_frame, name="direct", labels={"detectors"})
    detector.img.name = "img"

    app = build_app(mongo_adapter)
    with Context.from_app(app, token_cache=tmpdir) as context:
        client = from_context(context)

        def forward_document(name, doc):
            # Overwrite the declared shape before the descriptor is posted
            # so that it disagrees with the data actually produced.
            if name == "descriptor":
                doc["data_keys"]["img"]["shape"] = expected_shape

            client.post_document(name, doc)

        run_engine = RunEngine()
        run_engine.subscribe(forward_document)
        (uid,) = run_engine(count([detector]))

        first_frame = client[uid]["primary"]["data"]["img"][0]
        assert first_frame.shape == expected_shape


@pytest.mark.parametrize(
    ("shape", "expected_shape"),
    [
        ((10,), (11, 12)),
        ((10, 20), (10, 200)),
        ((20, 20, 20, 20), (20, 21, 20, 200)),
        ((10, 20), (5, 20)),
    ],
)
def test_default_validate_shape(tmpdir, shape, expected_shape):
    """Check that reading badly mismatched data raises ``BadShapeMetadata``.

    A simulated detector emits arrays of ``shape`` while the descriptor is
    rewritten on its way to the server to advertise ``expected_shape``.
    The parametrized pairs differ either in number of dimensions or by a
    large amount along at least one axis.
    """

    def make_frame():
        # All-ones array with the *actual* (undeclared) shape.
        return np.array(np.ones(shape))

    mongo_adapter = MongoAdapter.from_mongomock()

    detector = DirectImage(func=make_frame, name="direct", labels={"detectors"})
    detector.img.name = "img"

    app = build_app(mongo_adapter)
    with Context.from_app(app, token_cache=tmpdir) as context:
        client = from_context(context)

        def forward_document(name, doc):
            # Overwrite the declared shape before the descriptor is posted
            # so that it disagrees with the data actually produced.
            if name == "descriptor":
                doc["data_keys"]["img"]["shape"] = expected_shape

            client.post_document(name, doc)

        run_engine = RunEngine()
        run_engine.subscribe(forward_document)
        (uid,) = run_engine(count([detector]))

        image_node = client[uid]["primary"]["data"]["img"]
        with pytest.raises(BadShapeMetadata):
            image_node[:]
Loading