diff --git a/ome_zarr/reader.py b/ome_zarr/reader.py index 79886178..0c6d0e7b 100644 --- a/ome_zarr/reader.py +++ b/ome_zarr/reader.py @@ -57,7 +57,7 @@ def __init__( self.specs.append(PlateLabels(self)) elif Plate.matches(zarr): self.specs.append(Plate(self)) - self.add(zarr, plate_labels=True) + # self.add(zarr, plate_labels=True) if Well.matches(zarr): self.specs.append(Well(self)) @@ -412,6 +412,8 @@ def __init__(self, node: Node) -> None: # Use first Field for rendering settings, shape etc. image_zarr = self.zarr.create(image_paths[0]) image_node = Node(image_zarr, node) + x_index = len(image_node.metadata["axes"]) - 1 + y_index = len(image_node.metadata["axes"]) - 2 level = 0 # load full resolution image self.numpy_type = image_node.data[level].dtype self.img_metadata = image_node.metadata @@ -448,8 +450,8 @@ def get_lazy_well() -> da.Array: dtype=self.numpy_type, ) lazy_row.append(lazy_tile) - lazy_rows.append(da.concatenate(lazy_row, axis=4)) - return da.concatenate(lazy_rows, axis=3) + lazy_rows.append(da.concatenate(lazy_row, axis=x_index)) + return da.concatenate(lazy_rows, axis=y_index) node.data = [get_lazy_well()] node.metadata = image_node.metadata @@ -470,7 +472,7 @@ def get_pyramid_lazy(self, node: Node) -> None: stitched full-resolution images. """ self.plate_data = self.lookup("plate", {}) - LOGGER.info("plate_data", self.plate_data) + LOGGER.info("plate_data: %s", self.plate_data) self.rows = self.plate_data.get("rows") self.columns = self.plate_data.get("columns") self.first_field = "0" @@ -491,10 +493,11 @@ def get_pyramid_lazy(self, node: Node) -> None: raise Exception("could not find first well") self.numpy_type = well_spec.numpy_type - LOGGER.debug("img_pyramid_shapes", well_spec.img_pyramid_shapes) + LOGGER.debug(f"img_pyramid_shapes: {well_spec.img_pyramid_shapes}") - size_y = well_spec.img_shape[3] - size_x = well_spec.img_shape[4] + self.axes = well_spec.img_metadata["axes"] + size_y = well_spec.img_shape[len(self.axes) - 2] + size_x = well_spec.img_shape[len(self.axes) - 1] # FIXME - if only returning a single stiched plate (not a pyramid) # need to decide optimal size. E.g. longest side < 1500 @@ -511,7 +514,7 @@ def get_pyramid_lazy(self, node: Node) -> None: if longest_side <= TARGET_SIZE: break - LOGGER.debug("target_level", target_level) + LOGGER.debug(f"target_level: {target_level}") pyramid = [] @@ -541,16 +544,19 @@ def get_tile_path(self, level: int, row: int, col: int) -> str: ) def get_stitched_grid(self, level: int, tile_shape: tuple) -> da.core.Array: + LOGGER.debug(f"get_stitched_grid() level: {level}, tile_shape: {tile_shape}") + def get_tile(tile_name: str) -> np.ndarray: """tile_name is 'level,z,c,t,row,col'""" row, col = (int(n) for n in tile_name.split(",")) path = self.get_tile_path(level, row, col) - LOGGER.debug(f"LOADING tile... {path}") + LOGGER.debug(f"LOADING tile... {path} with shape: {tile_shape}") try: data = self.zarr.load(path) - except ValueError: + except ValueError as e: LOGGER.error(f"Failed to load {path}") + LOGGER.debug(f"{e}") data = np.zeros(tile_shape, dtype=self.numpy_type) return data @@ -566,12 +572,12 @@ def get_tile(tile_name: str) -> np.ndarray: lazy_reader(tile_name), shape=tile_shape, dtype=self.numpy_type ) lazy_row.append(lazy_tile) - lazy_rows.append(da.concatenate(lazy_row, axis=4)) - return da.concatenate(lazy_rows, axis=3) + lazy_rows.append(da.concatenate(lazy_row, axis=len(self.axes) - 1)) + return da.concatenate(lazy_rows, axis=len(self.axes) - 2) class PlateLabels(Plate): - def get_tile_path(self, level: int, row: int, col: int) -> str: + def get_tile_path(self, level: int, row: int, col: int) -> str: # pragma: no cover """251.zarr/A/1/0/labels/0/3/""" path = ( f"{self.row_names[row]}/{self.col_names[col]}/" @@ -579,10 +585,16 @@ def get_tile_path(self, level: int, row: int, col: int) -> str: ) return path - def get_pyramid_lazy(self, node: Node) -> None: + def get_pyramid_lazy(self, node: Node) -> None: # pragma: no cover super().get_pyramid_lazy(node) # pyramid data may be multi-channel, but we only have 1 labels channel - node.data[0] = node.data[0][:, 0:1, :, :, :] + # TODO: when PlateLabels are re-enabled, update the logic to handle + # 0.4 axes (list of dictionaries) + if "c" in self.axes: + c_index = self.axes.index("c") + idx = [slice(None)] * len(self.axes) + idx[c_index] = slice(0, 1) + node.data[0] = node.data[0][tuple(idx)] # remove image metadata node.metadata = {} @@ -602,7 +614,7 @@ def get_pyramid_lazy(self, node: Node) -> None: del properties[label_val]["label-value"] node.metadata["properties"] = properties - def get_numpy_type(self, image_node: Node) -> np.dtype: + def get_numpy_type(self, image_node: Node) -> np.dtype: # pragma: no cover # FIXME - don't assume Well A1 is valid path = self.get_tile_path(0, 0, 0) label_zarr = self.zarr.load(path) diff --git a/tests/test_node.py b/tests/test_node.py index aede5b99..a538c7c7 100644 --- a/tests/test_node.py +++ b/tests/test_node.py @@ -109,7 +109,6 @@ def test_multiwells_plate(self, fmt): for wp in empty_wells: assert parse_url(str(self.path / wp)) is None - @pytest.mark.xfail(reason="https://github.com/ome/ome-zarr-py/issues/145") @pytest.mark.parametrize( "axes, dims", ( diff --git a/tests/test_reader.py b/tests/test_reader.py index f82eb59e..f5d4a3fd 100644 --- a/tests/test_reader.py +++ b/tests/test_reader.py @@ -4,7 +4,7 @@ from ome_zarr.data import create_zarr from ome_zarr.io import parse_url -from ome_zarr.reader import Node, Plate, PlateLabels, Reader +from ome_zarr.reader import Node, Plate, Reader from ome_zarr.writer import write_image, write_plate_metadata, write_well_metadata @@ -50,11 +50,12 @@ def test_minimal_plate(self): reader = Reader(parse_url(str(self.path))) nodes = list(reader()) - assert len(nodes) == 2 + # currently reading plate labels disabled. Only 1 node + assert len(nodes) == 1 assert len(nodes[0].specs) == 1 assert isinstance(nodes[0].specs[0], Plate) - assert len(nodes[1].specs) == 1 - assert isinstance(nodes[1].specs[0], PlateLabels) + # assert len(nodes[1].specs) == 1 + # assert isinstance(nodes[1].specs[0], PlateLabels) def test_multiwells_plate(self): row_names = ["A", "B", "C"] @@ -72,8 +73,9 @@ def test_multiwells_plate(self): reader = Reader(parse_url(str(self.path))) nodes = list(reader()) - assert len(nodes) == 2 + # currently reading plate labels disabled. Only 1 node + assert len(nodes) == 1 assert len(nodes[0].specs) == 1 assert isinstance(nodes[0].specs[0], Plate) - assert len(nodes[1].specs) == 1 - assert isinstance(nodes[1].specs[0], PlateLabels) + # assert len(nodes[1].specs) == 1 + # assert isinstance(nodes[1].specs[0], PlateLabels)