Skip to content

Commit

Permalink
Merge pull request #148 from sbesson/hcs_axes_fix
Browse files Browse the repository at this point in the history
Fix remaining assumptions on 5D dimensions
  • Loading branch information
sbesson authored Jan 19, 2022
2 parents 804dccf + 962031f commit 0a3e927
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 24 deletions.
44 changes: 28 additions & 16 deletions ome_zarr/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(
self.specs.append(PlateLabels(self))
elif Plate.matches(zarr):
self.specs.append(Plate(self))
self.add(zarr, plate_labels=True)
# self.add(zarr, plate_labels=True)
if Well.matches(zarr):
self.specs.append(Well(self))

Expand Down Expand Up @@ -412,6 +412,8 @@ def __init__(self, node: Node) -> None:
# Use first Field for rendering settings, shape etc.
image_zarr = self.zarr.create(image_paths[0])
image_node = Node(image_zarr, node)
x_index = len(image_node.metadata["axes"]) - 1
y_index = len(image_node.metadata["axes"]) - 2
level = 0 # load full resolution image
self.numpy_type = image_node.data[level].dtype
self.img_metadata = image_node.metadata
Expand Down Expand Up @@ -448,8 +450,8 @@ def get_lazy_well() -> da.Array:
dtype=self.numpy_type,
)
lazy_row.append(lazy_tile)
lazy_rows.append(da.concatenate(lazy_row, axis=4))
return da.concatenate(lazy_rows, axis=3)
lazy_rows.append(da.concatenate(lazy_row, axis=x_index))
return da.concatenate(lazy_rows, axis=y_index)

node.data = [get_lazy_well()]
node.metadata = image_node.metadata
Expand All @@ -470,7 +472,7 @@ def get_pyramid_lazy(self, node: Node) -> None:
stitched full-resolution images.
"""
self.plate_data = self.lookup("plate", {})
LOGGER.info("plate_data", self.plate_data)
LOGGER.info("plate_data: %s", self.plate_data)
self.rows = self.plate_data.get("rows")
self.columns = self.plate_data.get("columns")
self.first_field = "0"
Expand All @@ -491,10 +493,11 @@ def get_pyramid_lazy(self, node: Node) -> None:
raise Exception("could not find first well")
self.numpy_type = well_spec.numpy_type

LOGGER.debug("img_pyramid_shapes", well_spec.img_pyramid_shapes)
LOGGER.debug(f"img_pyramid_shapes: {well_spec.img_pyramid_shapes}")

size_y = well_spec.img_shape[3]
size_x = well_spec.img_shape[4]
self.axes = well_spec.img_metadata["axes"]
size_y = well_spec.img_shape[len(self.axes) - 2]
size_x = well_spec.img_shape[len(self.axes) - 1]

# FIXME - if only returning a single stiched plate (not a pyramid)
# need to decide optimal size. E.g. longest side < 1500
Expand All @@ -511,7 +514,7 @@ def get_pyramid_lazy(self, node: Node) -> None:
if longest_side <= TARGET_SIZE:
break

LOGGER.debug("target_level", target_level)
LOGGER.debug(f"target_level: {target_level}")

pyramid = []

Expand Down Expand Up @@ -541,16 +544,19 @@ def get_tile_path(self, level: int, row: int, col: int) -> str:
)

def get_stitched_grid(self, level: int, tile_shape: tuple) -> da.core.Array:
LOGGER.debug(f"get_stitched_grid() level: {level}, tile_shape: {tile_shape}")

def get_tile(tile_name: str) -> np.ndarray:
"""tile_name is 'level,z,c,t,row,col'"""
row, col = (int(n) for n in tile_name.split(","))
path = self.get_tile_path(level, row, col)
LOGGER.debug(f"LOADING tile... {path}")
LOGGER.debug(f"LOADING tile... {path} with shape: {tile_shape}")

try:
data = self.zarr.load(path)
except ValueError:
except ValueError as e:
LOGGER.error(f"Failed to load {path}")
LOGGER.debug(f"{e}")
data = np.zeros(tile_shape, dtype=self.numpy_type)
return data

Expand All @@ -566,23 +572,29 @@ def get_tile(tile_name: str) -> np.ndarray:
lazy_reader(tile_name), shape=tile_shape, dtype=self.numpy_type
)
lazy_row.append(lazy_tile)
lazy_rows.append(da.concatenate(lazy_row, axis=4))
return da.concatenate(lazy_rows, axis=3)
lazy_rows.append(da.concatenate(lazy_row, axis=len(self.axes) - 1))
return da.concatenate(lazy_rows, axis=len(self.axes) - 2)


class PlateLabels(Plate):
def get_tile_path(self, level: int, row: int, col: int) -> str:
def get_tile_path(self, level: int, row: int, col: int) -> str: # pragma: no cover
"""251.zarr/A/1/0/labels/0/3/"""
path = (
f"{self.row_names[row]}/{self.col_names[col]}/"
f"{self.first_field}/labels/0/{level}"
)
return path

def get_pyramid_lazy(self, node: Node) -> None:
def get_pyramid_lazy(self, node: Node) -> None: # pragma: no cover
super().get_pyramid_lazy(node)
# pyramid data may be multi-channel, but we only have 1 labels channel
node.data[0] = node.data[0][:, 0:1, :, :, :]
# TODO: when PlateLabels are re-enabled, update the logic to handle
# 0.4 axes (list of dictionaries)
if "c" in self.axes:
c_index = self.axes.index("c")
idx = [slice(None)] * len(self.axes)
idx[c_index] = slice(0, 1)
node.data[0] = node.data[0][tuple(idx)]
# remove image metadata
node.metadata = {}

Expand All @@ -602,7 +614,7 @@ def get_pyramid_lazy(self, node: Node) -> None:
del properties[label_val]["label-value"]
node.metadata["properties"] = properties

def get_numpy_type(self, image_node: Node) -> np.dtype:
def get_numpy_type(self, image_node: Node) -> np.dtype: # pragma: no cover
# FIXME - don't assume Well A1 is valid
path = self.get_tile_path(0, 0, 0)
label_zarr = self.zarr.load(path)
Expand Down
1 change: 0 additions & 1 deletion tests/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ def test_multiwells_plate(self, fmt):
for wp in empty_wells:
assert parse_url(str(self.path / wp)) is None

@pytest.mark.xfail(reason="https://github.com/ome/ome-zarr-py/issues/145")
@pytest.mark.parametrize(
"axes, dims",
(
Expand Down
16 changes: 9 additions & 7 deletions tests/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ome_zarr.data import create_zarr
from ome_zarr.io import parse_url
from ome_zarr.reader import Node, Plate, PlateLabels, Reader
from ome_zarr.reader import Node, Plate, Reader
from ome_zarr.writer import write_image, write_plate_metadata, write_well_metadata


Expand Down Expand Up @@ -50,11 +50,12 @@ def test_minimal_plate(self):

reader = Reader(parse_url(str(self.path)))
nodes = list(reader())
assert len(nodes) == 2
# currently reading plate labels disabled. Only 1 node
assert len(nodes) == 1
assert len(nodes[0].specs) == 1
assert isinstance(nodes[0].specs[0], Plate)
assert len(nodes[1].specs) == 1
assert isinstance(nodes[1].specs[0], PlateLabels)
# assert len(nodes[1].specs) == 1
# assert isinstance(nodes[1].specs[0], PlateLabels)

def test_multiwells_plate(self):
row_names = ["A", "B", "C"]
Expand All @@ -72,8 +73,9 @@ def test_multiwells_plate(self):

reader = Reader(parse_url(str(self.path)))
nodes = list(reader())
assert len(nodes) == 2
# currently reading plate labels disabled. Only 1 node
assert len(nodes) == 1
assert len(nodes[0].specs) == 1
assert isinstance(nodes[0].specs[0], Plate)
assert len(nodes[1].specs) == 1
assert isinstance(nodes[1].specs[0], PlateLabels)
# assert len(nodes[1].specs) == 1
# assert isinstance(nodes[1].specs[0], PlateLabels)

0 comments on commit 0a3e927

Please sign in to comment.