diff --git a/ome_zarr/writer.py b/ome_zarr/writer.py index d705cc7a..0e94a467 100644 --- a/ome_zarr/writer.py +++ b/ome_zarr/writer.py @@ -40,34 +40,36 @@ def _validate_axes_names( if isinstance(axes, str): axes = list(axes) - if axes is not None: - if len(axes) != ndim: - raise ValueError("axes length must match number of dimensions") - # from https://github.com/constantinpape/ome-ngff-implementations/ - val_axes = tuple(axes) - if ndim == 2: - if val_axes != ("y", "x"): - raise ValueError(f"2D data must have axes ('y', 'x') {val_axes}") - elif ndim == 3: - if val_axes not in [("z", "y", "x"), ("c", "y", "x"), ("t", "y", "x")]: - raise ValueError( - "3D data must have axes ('z', 'y', 'x') or ('c', 'y', 'x')" - " or ('t', 'y', 'x'), not %s" % (val_axes,) - ) - elif ndim == 4: - if val_axes not in [ - ("t", "z", "y", "x"), - ("c", "z", "y", "x"), - ("t", "c", "y", "x"), - ]: - raise ValueError("4D data must have axes tzyx or czyx or tcyx") - else: - if val_axes != ("t", "c", "z", "y", "x"): - raise ValueError("5D data must have axes ('t', 'c', 'z', 'y', 'x')") - + if len(axes) != ndim: + raise ValueError("axes length must match number of dimensions") + _validate_axes(axes) return axes +def _validate_axes(axes: List[str], fmt: Format = CurrentFormat()) -> None: + + val_axes = tuple(axes) + if len(val_axes) == 2: + if val_axes != ("y", "x"): + raise ValueError(f"2D data must have axes ('y', 'x') {val_axes}") + elif len(val_axes) == 3: + if val_axes not in [("z", "y", "x"), ("c", "y", "x"), ("t", "y", "x")]: + raise ValueError( + "3D data must have axes ('z', 'y', 'x') or ('c', 'y', 'x')" + " or ('t', 'y', 'x'), not %s" % (val_axes,) + ) + elif len(val_axes) == 4: + if val_axes not in [ + ("t", "z", "y", "x"), + ("c", "z", "y", "x"), + ("t", "c", "y", "x"), + ]: + raise ValueError("4D data must have axes tzyx or czyx or tcyx") + else: + if val_axes != ("t", "c", "z", "y", "x"): + raise ValueError("5D data must have axes ('t', 'c', 'z', 'y', 'x')") + + def write_multiscale( pyramid: List, group: zarr.Group, @@ -103,11 +105,45 @@ def write_multiscale( for path, dataset in enumerate(pyramid): # TODO: chunks here could be different per layer group.create_dataset(str(path), data=dataset, chunks=chunks) - paths.append({"path": str(path)}) + paths.append(str(path)) + write_multiscales_metadata(group, paths, fmt, axes) + + +def write_multiscales_metadata( + group: zarr.Group, + paths: List[str], + fmt: Format = CurrentFormat(), + axes: List[str] = None, +) -> None: + """ + Write the multiscales metadata in the group. - multiscales = [{"version": fmt.version, "datasets": paths}] + Parameters + ---------- + group: zarr.Group + the group within the zarr store to write the metadata in. + paths: list of str + The list of paths to the datasets for this multiscale image. + fmt: Format + The format of the ome_zarr data which should be used. + Defaults to the most current. + axes: list of str + the names of the axes. e.g. ["t", "c", "z", "y", "x"]. + Ignored for versions 0.1 and 0.2. Required for version 0.3 or greater. + """ + + multiscales = [ + { + "version": fmt.version, + "datasets": [{"path": str(p)} for p in paths], + } + ] if axes is not None: - multiscales[0]["axes"] = axes + if fmt.version in ("0.1", "0.2"): + LOGGER.info("axes ignored for version 0.1 or 0.2") + else: + _validate_axes(axes, fmt) + multiscales[0]["axes"] = axes group.attrs["multiscales"] = multiscales diff --git a/tests/test_writer.py b/tests/test_writer.py index cd992338..5110b3d4 100644 --- a/tests/test_writer.py +++ b/tests/test_writer.py @@ -8,7 +8,11 @@ from ome_zarr.io import parse_url from ome_zarr.reader import Multiscales, Reader from ome_zarr.scale import Scaler -from ome_zarr.writer import _validate_axes_names, write_image +from ome_zarr.writer import ( + _validate_axes_names, + write_image, + write_multiscales_metadata, +) class TestWriter: @@ -125,3 +129,74 @@ def test_dim_names(self): fmt=v03, axes="xyz", ) + + +class TestMultiscalesMetadata: + @pytest.fixture(autouse=True) + def initdir(self, tmpdir): + self.path = pathlib.Path(tmpdir.mkdir("data")) + self.store = parse_url(self.path, mode="w").store + self.root = zarr.group(store=self.store) + + def test_single_level(self): + write_multiscales_metadata(self.root, ["0"]) + assert "multiscales" in self.root.attrs + assert "version" in self.root.attrs["multiscales"][0] + assert self.root.attrs["multiscales"][0]["datasets"] == [{"path": "0"}] + + def test_multi_levels(self): + write_multiscales_metadata(self.root, ["0", "1", "2"]) + assert "multiscales" in self.root.attrs + assert "version" in self.root.attrs["multiscales"][0] + assert self.root.attrs["multiscales"][0]["datasets"] == [ + {"path": "0"}, + {"path": "1"}, + {"path": "2"}, + ] + + @pytest.mark.parametrize("fmt", (FormatV01(), FormatV02(), FormatV03())) + def test_version(self, fmt): + write_multiscales_metadata(self.root, ["0"], fmt=fmt) + assert "multiscales" in self.root.attrs + assert self.root.attrs["multiscales"][0]["version"] == fmt.version + assert self.root.attrs["multiscales"][0]["datasets"] == [{"path": "0"}] + + @pytest.mark.parametrize( + "axes", + ( + ["y", "x"], + ["c", "y", "x"], + ["z", "y", "x"], + ["t", "y", "x"], + ["t", "c", "y", "x"], + ["t", "z", "y", "x"], + ["c", "z", "y", "x"], + ["t", "c", "z", "y", "x"], + ), + ) + def test_axes(self, axes): + write_multiscales_metadata(self.root, ["0"], axes=axes) + assert "multiscales" in self.root.attrs + assert self.root.attrs["multiscales"][0]["axes"] == axes + + @pytest.mark.parametrize("fmt", (FormatV01(), FormatV02())) + def test_axes_ignored(self, fmt): + write_multiscales_metadata( + self.root, ["0"], fmt=fmt, axes=["t", "c", "z", "y", "x"] + ) + assert "multiscales" in self.root.attrs + assert "axes" not in self.root.attrs["multiscales"][0] + + @pytest.mark.parametrize( + "axes", + ( + [], + ["i", "j"], + ["x", "y"], + ["y", "x", "c"], + ["x", "y", "z", "c", "t"], + ), + ) + def test_invalid_0_3_axes(self, axes): + with pytest.raises(ValueError): + write_multiscales_metadata(self.root, ["0"], fmt=FormatV03(), axes=axes)