From 86665b7eee8f1254203070807c04b1fb4b105ead Mon Sep 17 00:00:00 2001
From: giovp
Date: Sun, 5 Feb 2023 21:34:53 +0100
Subject: [PATCH 01/24] don't remove coords

---
 spatialdata/_core/core_utils.py | 16 +++-------------
 spatialdata/_core/models.py     |  2 --
 2 files changed, 3 insertions(+), 15 deletions(-)

diff --git a/spatialdata/_core/core_utils.py b/spatialdata/_core/core_utils.py
index 5d5a342a..1f89f0ed 100644
--- a/spatialdata/_core/core_utils.py
+++ b/spatialdata/_core/core_utils.py
@@ -259,7 +259,7 @@ def get_default_coordinate_system(dims: tuple[str, ...]) -> NgffCoordinateSystem
 @singledispatch
 def get_dims(e: SpatialElement) -> tuple[str, ...]:
     """
-    Get the dimensions of a spatial element
+    Get the dimensions of a spatial element.

     Parameters
     ----------
@@ -268,8 +268,7 @@ def get_dims(e: SpatialElement) -> tuple[str, ...]:

     Returns
     -------
-    dims
-        Dimensions of the spatial element (e.g. ("z", "y", "x"))
+    Dimensions of the spatial element (e.g. ("z", "y", "x"))
     """
     raise TypeError(f"Unsupported type: {type(e)}")

@@ -282,16 +281,7 @@ def _(e: SpatialImage) -> tuple[str, ...]:

 @get_dims.register(MultiscaleSpatialImage)
 def _(e: MultiscaleSpatialImage) -> tuple[str, ...]:
-    # luca: I prefer this first method
-    d = dict(e["scale0"])
-    assert len(d) == 1
-    dims0 = d.values().__iter__().__next__().dims
-    assert isinstance(dims0, tuple)
-    # still, let's do a runtime check against the other method
-    variables = list(e[list(e.keys())[0]].variables)
-    dims1 = e[list(e.keys())[0]][variables[0]].dims
-    assert dims0 == dims1
-    return dims0
+    return tuple(i for i in e["scale0"].dims.keys())

 @get_dims.register(GeoDataFrame)
diff --git a/spatialdata/_core/models.py b/spatialdata/_core/models.py
index cafe4724..f61716fe 100644
--- a/spatialdata/_core/models.py
+++ b/spatialdata/_core/models.py
@@ -180,8 +180,6 @@ def parse(
         data = to_spatial_image(array_like=data, dims=cls.dims.dims, **kwargs)
         assert isinstance(data, SpatialImage)
-        # TODO(giovp): drop coordinates for now until solution with IO.
-        data = data.drop(data.coords.keys())
         _parse_transformations(data, transformations)
         if multiscale_factors is not None:
             # check that the image pyramid doesn't contain axes that get collapsed and eventually truncates the list

From 632c112365b1e335f5b3f534db4506c85ab05485 Mon Sep 17 00:00:00 2001
From: giovp
Date: Fri, 17 Feb 2023 15:25:25 +0100
Subject: [PATCH 02/24] update iter multiscale

---
 spatialdata/_io/write.py | 22 ++++++++--------------
 1 file changed, 8 insertions(+), 14 deletions(-)

diff --git a/spatialdata/_io/write.py b/spatialdata/_io/write.py
index ffd84373..f8147945 100644
--- a/spatialdata/_io/write.py
+++ b/spatialdata/_io/write.py
@@ -1,5 +1,5 @@
 import os
-from collections.abc import Mapping
+from collections.abc import Mapping, Set
 from typing import Any, Literal, Optional, Union

 import pyarrow as pa
@@ -189,6 +189,7 @@ def _get_group_for_writing_transformations() -> zarr.Group:
     assert transformations is not None
     assert len(transformations) > 0
     chunks = _iter_multiscale(raster_data, "chunks")
+    # coords = _iter_multiscale(raster_data, "coords")
     parsed_axes = _get_valid_axes(axes=list(input_axes), fmt=fmt)
     storage_options = [{"chunks": chunk} for chunk in chunks]
     write_multi_scale_ngff(
@@ -361,19 +362,12 @@ def write_table(
 def _iter_multiscale(
     data: MultiscaleSpatialImage,
     attr: str,
-    key: Optional[str] = None,
 ) -> list[Any]:
     # TODO: put this check also in the validator for raster multiscales
-    name = None
     for i in data.keys():
-        variables = list(data[i].variables)
-        if len(variables) != 1:
-            raise ValueError("MultiscaleSpatialImage must have exactly one variable (the variable name is arbitrary)")
-        if name is not None:
-            if name != variables[0]:
-                raise ValueError("MultiscaleSpatialImage must have the same variable name across all levels")
-        name = variables[0]
-    if key is None:
-        return [getattr(data[i][name], attr) for i in data.keys()]
-    else:
-        return [getattr(data[i][name], attr).get(key) for i in data.keys()]
+        variables = set(data[i].variables.keys())
+        names: Set[str] = variables.difference({"c", "z", "y", "x"})
+        if len(names) != 1:
+            raise ValueError(f"Invalid variable name: `{names}`.")
+        name: str = next(iter(names))
+    return [getattr(data[i][name], attr) for i in data.keys()]

From 6e4a05084a0c2806a1118295e30f6614c57c5e79 Mon Sep 17 00:00:00 2001
From: giovp
Date: Fri, 17 Feb 2023 15:43:47 +0100
Subject: [PATCH 03/24] update spatial-image and multiscale-spatial-image versions

---
 pyproject.toml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index ef39ae5a..dbebe3cf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,8 +26,8 @@ dependencies = [
     "zarr",
     # "ome_zarr",
     "ome_zarr@git+https://github.com/LucaMarconato/ome-zarr-py@bug_fix_io",
-    "spatial_image",
-    "multiscale_spatial_image",
+    "spatial_image>=0.3.0",
+    "multiscale_spatial_image>=0.11.2",
     "xarray-schema",
     "pygeos",
     "geopandas",

From 6565b12a55902dee912f413f8615dbe2d44a9823 Mon Sep 17 00:00:00 2001
From: giovp
Date: Sun, 19 Feb 2023 17:01:47 +0100
Subject: [PATCH 04/24] remove check that is now in multiscale spatial image

---
 spatialdata/_core/models.py | 19 +------------------
 1 file changed, 1 insertion(+), 18 deletions(-)

diff --git a/spatialdata/_core/models.py b/spatialdata/_core/models.py
index 22b5afe7..0dece645 100644
--- a/spatialdata/_core/models.py
+++ b/spatialdata/_core/models.py
@@ -183,28 +183,11 @@ def parse(
         assert isinstance(data, SpatialImage)
         _parse_transformations(data, transformations)
         if multiscale_factors is not None:
-            # check that the image pyramid doesn't contain axes that get collapsed and eventually truncates the list
-            # of downscaling factors to avoid this
-            adjusted_multiscale_factors: list[int] = []
-            assert isinstance(data, DataArray)
-            current_shape: ArrayLike = np.array(data.shape, dtype=float)
-            # multiscale_factors could be a dict, we don't support this case here (in the future this code and the
-            # more general case will be handled by multiscale-spatial-image)
-            assert isinstance(multiscale_factors, list)
-            for factor in multiscale_factors:
-                scale_vector = np.array([1.0 if ax == "c" else factor for ax in data.dims])
-                current_shape /= scale_vector
-                if current_shape.min() < 1:
-                    logger.warning(
-                        f"Detected a multiscale factor that would collapse an axis: truncating list of factors from {multiscale_factors} to {adjusted_multiscale_factors}"
-                    )
-                    break
-                adjusted_multiscale_factors.append(factor)
             parsed_transform = _get_transformations(data)
             del data.attrs["transform"]
             data = to_multiscale(
                 data,
-                scale_factors=adjusted_multiscale_factors,
+                scale_factors=multiscale_factors,
                 method=method,
                 chunks=chunks,
             )

From 4833e638596a4b4dc7677982984e5cab64a30dae Mon Sep 17 00:00:00 2001
From: giovp
Date: Sun, 19 Feb 2023 18:12:34 +0100
Subject: [PATCH 05/24] add coordinates assignment

---
 spatialdata/_core/core_utils.py | 66 ++++++++++++++++++++++++++++++++-
 1 file changed, 64 insertions(+), 2 deletions(-)

diff --git a/spatialdata/_core/core_utils.py b/spatialdata/_core/core_utils.py
index c4ba9044..1521fe2d 100644
--- a/spatialdata/_core/core_utils.py
+++ b/spatialdata/_core/core_utils.py
@@ -1,6 +1,6 @@
 import copy
 from functools import singledispatch
-from typing import Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union

 import numpy as np
 from anndata import AnnData
@@ -277,7 +277,10 @@ def _(e: SpatialImage) -> tuple[str, ...]:

 @get_dims.register(MultiscaleSpatialImage)
 def _(e: MultiscaleSpatialImage) -> tuple[str, ...]:
-    return tuple(i for i in e["scale0"].dims.keys())
+    if "scale0" in e:
+        return tuple(i for i in e["scale0"].dims.keys())
+    else:
+        return tuple(i for i in e.dims.keys())

 @get_dims.register(GeoDataFrame)
@@ -299,3 +302,62 @@ def _(e: AnnData) -> tuple[str, ...]:
     valid_dims = (X, Y, Z)
     dims = [c for c in valid_dims if c in e.columns]
     return tuple(dims)
+
+
+@singledispatch
+def compute_coordinates(data: Union[SpatialImage, MultiscaleSpatialImage]) -> tuple[str, ...]:
+    """
+    Compute and assign coordinates to a (Multiscale)SpatialImage.
+
+    Parameters
+    ----------
+    data
+        :class:`SpatialImage` or :class:`MultiscaleSpatialImage`.
+
+    Returns
+    -------
+    :class:`SpatialImage` or :class:`MultiscaleSpatialImage` with coordinates assigned.
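+
+    Example
+    -------
+    (editor's sketch, derived from the implementations registered below; ``si``
+    is a hypothetical parsed 2D image with sizes ``c=3, y=64, x=64``)
+
+    >>> si = compute_coordinates(si)
+    >>> si.coords["x"].values[:3]
+    array([0., 1., 2.])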
+ """ + raise TypeError(f"Unsupported type: {type(data)}") + + +@compute_coordinates.register(SpatialImage) +def _(data: SpatialImage) -> SpatialImage: + coords: dict[str, ArrayLike] = { + d: np.arange(data.sizes[d], dtype=np.float_) for d in data.sizes.keys() if d in ["x", "y", "z"] + } + return data.assign_coords(coords) + + +@compute_coordinates.register(MultiscaleSpatialImage) +def _(data: MultiscaleSpatialImage) -> MultiscaleSpatialImage: + def _get_scale(transforms: dict[str, Any]) -> Optional[ArrayLike]: + for t in transforms["global"].transformations: + if hasattr(t, "scale"): + if TYPE_CHECKING: + assert isinstance(t.scale, np.ndarray) + return t.scale + + def _compute_coords(max_: int, scale_f: Union[int, float]) -> ArrayLike: + return ( # type: ignore[no-any-return] + DataArray(np.linspace(0, max_, max_, endpoint=False, dtype=np.float_)) + .coarsen(dim_0=scale_f, boundary="trim", side="right") + .mean() + .values + ) + + max_scale0 = {d: s for d, s in data["scale0"].sizes.items() if d in ["x", "y", "z"]} + out = {} + + for name, dt in data.items(): + max_scale = {d: s for d, s in data["scale0"].sizes.items() if d in ["x", "y", "z"]} + if name == "scale0": + coords: dict[str, ArrayLike] = {d: np.arange(max_scale[d], dtype=np.float_) for d in max_scale.keys()} + out[name] = dt["image"].assign_coords(coords) + else: + scalef = _get_scale(dt["image"].attrs["transform"]) + assert len(max_scale.keys()) == len(scalef), "Mismatch between coordinates and scales." # type: ignore[arg-type] + out[name] = dt["image"].assign_coords( + {k: _compute_coords(max_scale0[k], round(s)) for k, s in zip(max_scale.keys(), scalef)} # type: ignore[arg-type] + ) + return MultiscaleSpatialImage.from_dict(d=out) From 2d8bbd25d39b9f5b885d0287ef5d9903cc107160 Mon Sep 17 00:00:00 2001 From: giovp Date: Sun, 19 Feb 2023 18:25:50 +0100 Subject: [PATCH 06/24] add coordinates to parser --- spatialdata/_core/models.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/spatialdata/_core/models.py b/spatialdata/_core/models.py index 0dece645..0f357d8c 100644 --- a/spatialdata/_core/models.py +++ b/spatialdata/_core/models.py @@ -47,6 +47,7 @@ _get_transformations, _set_transformations, _validate_mapping_to_coordinate_system_type, + compute_coordinates, get_dims, ) from spatialdata._core.transformations import BaseTransformation, Identity @@ -193,6 +194,7 @@ def parse( ) _parse_transformations(data, parsed_transform) assert isinstance(data, MultiscaleSpatialImage) + data = compute_coordinates(data) return data def validate(self, data: Union[SpatialImage, MultiscaleSpatialImage]) -> None: From c99c5dd57ea110d1c1450c5ff774458d5088b187 Mon Sep 17 00:00:00 2001 From: giovp Date: Mon, 20 Feb 2023 10:07:01 +0100 Subject: [PATCH 07/24] fix tests --- spatialdata/_core/core_utils.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/spatialdata/_core/core_utils.py b/spatialdata/_core/core_utils.py index 1521fe2d..ecd68461 100644 --- a/spatialdata/_core/core_utils.py +++ b/spatialdata/_core/core_utils.py @@ -305,7 +305,9 @@ def _(e: AnnData) -> tuple[str, ...]: @singledispatch -def compute_coordinates(data: Union[SpatialImage, MultiscaleSpatialImage]) -> tuple[str, ...]: +def compute_coordinates( + data: Union[SpatialImage, MultiscaleSpatialImage] +) -> Union[SpatialImage, MultiscaleSpatialImage]: """ Computes and assign coordinates to a (Multiscale)SpatialImage. 
@@ -347,17 +349,18 @@ def _compute_coords(max_: int, scale_f: Union[int, float]) -> ArrayLike:
         )

     max_scale0 = {d: s for d, s in data["scale0"].sizes.items() if d in ["x", "y", "z"]}
+    img_name = list(data["scale0"].data_vars.keys())[0]
     out = {}

     for name, dt in data.items():
         max_scale = {d: s for d, s in data["scale0"].sizes.items() if d in ["x", "y", "z"]}
         if name == "scale0":
             coords: dict[str, ArrayLike] = {d: np.arange(max_scale[d], dtype=np.float_) for d in max_scale.keys()}
-            out[name] = dt["image"].assign_coords(coords)
+            out[name] = dt[img_name].assign_coords(coords)
         else:
-            scalef = _get_scale(dt["image"].attrs["transform"])
+            scalef = _get_scale(dt[img_name].attrs["transform"])
             assert len(max_scale.keys()) == len(scalef), "Mismatch between coordinates and scales."  # type: ignore[arg-type]
-            out[name] = dt["image"].assign_coords(
+            out[name] = dt[img_name].assign_coords(
                 {k: _compute_coords(max_scale0[k], round(s)) for k, s in zip(max_scale.keys(), scalef)}  # type: ignore[arg-type]
             )
     return MultiscaleSpatialImage.from_dict(d=out)

From f562a3bf3d5a19b5b9d66976e9c8fecad147076d Mon Sep 17 00:00:00 2001
From: giovp
Date: Mon, 20 Feb 2023 10:41:59 +0100
Subject: [PATCH 08/24] update tests for 3d

---
 spatialdata/_core/models.py | 20 ++++++-------
 tests/conftest.py           | 56 ++++++++++++++++++-------------------
 2 files changed, 38 insertions(+), 38 deletions(-)

diff --git a/spatialdata/_core/models.py b/spatialdata/_core/models.py
index a857f1b8..d52404b6 100644
--- a/spatialdata/_core/models.py
+++ b/spatialdata/_core/models.py
@@ -112,7 +112,7 @@ def parse(
         data: Union[ArrayLike, DataArray, DaskArray],
         dims: Optional[Sequence[str]] = None,
         transformations: Optional[MappingToCoordinateSystem_t] = None,
-        multiscale_factors: Optional[ScaleFactors_t] = None,
+        scale_factors: Optional[ScaleFactors_t] = None,
         method: Optional[Methods] = None,
         chunks: Optional[Chunks_t] = None,
         **kwargs: Any,
@@ -130,7 +130,7 @@ def parse(
             Transformations to apply to the data.
-        multiscale_factors
+        scale_factors
             Scale factors to apply for multiscale.
-            If not None, a :class:`multiscale_spatial_image.multiscale_spatial_image.MultiscaleSpatialImage` is returned.
+            If not None, a :class:`multiscale_spatial_image.MultiscaleSpatialImage` is returned.
         method
             Method to use for multiscale.
         chunks
@@ -139,23 +139,23 @@ def parse(
         Returns
        -------
         :class:`spatial_image.SpatialImage` or
-        :class:`multiscale_spatial_image.multiscale_spatial_image.MultiscaleSpatialImage`.
+        :class:`multiscale_spatial_image.MultiscaleSpatialImage`.
         """
-        # check if dims is specified and if it has correct values
-        # if dims is specified inside the data, get the value of dims from the data
         if isinstance(data, DataArray) or isinstance(data, SpatialImage):
             if not isinstance(data.data, DaskArray):  # numpy -> dask
                 data.data = from_array(data.data)
             if dims is not None:
-                if dims != data.dims:
+                if set(dims).difference(data.dims):
                     raise ValueError(
                         f"`dims`: {dims} does not match `data.dims`: {data.dims}, please specify the dims only once."
                     )
                 else:
-                    logger.info("`dims` is specified redundantly: found also inside `data`")
+                    logger.info("`dims` is specified redundantly: found also inside `data`.")
             else:
-                dims = data.dims  # type: ignore[assignment]
+                dims = data.dims
+            if set(dims).difference(cls.dims.dims):
+                raise ValueError(f"Wrong `dims`: {dims}. Expected {cls.dims.dims}.")
             _reindex = lambda d: d
         elif isinstance(data, np.ndarray) or isinstance(data, DaskArray):
             if not isinstance(data, DaskArray):  # numpy -> dask
                 data = from_array(data)
diff --git a/tests/conftest.py b/tests/conftest.py
index e152e474..cb909c4a 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -10,10 +10,12 @@
 from numpy.random import default_rng
 from shapely.geometry import MultiPolygon, Polygon
 from spatial_image import SpatialImage
+from xarray import DataArray

 from spatialdata import SpatialData
 from spatialdata._core.models import (
     Image2DModel,
+    Image3DModel,
     Labels2DModel,
     Labels3DModel,
     PointsModel,
@@ -117,35 +119,33 @@ def sdata(request) -> SpatialData:
 def _get_images() -> dict[str, Union[SpatialImage, MultiscaleSpatialImage]]:
     out = {}
     dims_2d = ("c", "y", "x")
-
+    dims_3d = ("z", "y", "x", "c")
     out["image2d"] = Image2DModel.parse(RNG.normal(size=(3, 64, 64)), name="image2d", dims=dims_2d)
     out["image2d_multiscale"] = Image2DModel.parse(
-        RNG.normal(size=(3, 64, 64)), name="image2d_multiscale", multiscale_factors=[2, 2], dims=dims_2d
+        RNG.normal(size=(3, 64, 64)), name="image2d_multiscale", scale_factors=[2, 2], dims=dims_2d
+    )
+    out["image2d_xarray"] = Image2DModel.parse(
+        DataArray(RNG.normal(size=(3, 64, 64)), dims=dims_2d), name="image2d_xarray", dims=None
+    )
+    out["image2d_multiscale_xarray"] = Image2DModel.parse(
+        DataArray(RNG.normal(size=(3, 64, 64)), dims=dims_2d),
+        name="image2d_multiscale_xarray",
+        scale_factors=[2, 4],
+        dims=None,
+    )
+    out["image3d_numpy"] = Image3DModel.parse(RNG.normal(size=(2, 64, 64, 3)), name="image3d_numpy", dims=dims_3d)
+    out["image3d_multiscale_numpy"] = Image3DModel.parse(
+        RNG.normal(size=(2, 64, 64, 3)), name="image3d_multiscale_numpy", scale_factors=[2], dims=dims_3d
+    )
+    out["image3d_xarray"] = Image3DModel.parse(
+        DataArray(RNG.normal(size=(2, 64, 64, 3)), dims=dims_3d), name="image3d_xarray", dims=None
+    )
+    out["image3d_multiscale_xarray"] = Image3DModel.parse(
+        DataArray(RNG.normal(size=(2, 64, 64, 3)), dims=dims_3d),
+        name="image3d_multiscale_xarray",
+        scale_factors=[2],
+        dims=None,
     )
-    # TODO: (BUG) https://github.com/scverse/spatialdata/issues/59
-    # out["image2d_xarray"] = Image2DModel.parse(
-    #     DataArray(RNG.normal(size=(3, 64, 64)), dims=dims_2d), name="image2d_xarray", dims=None
-    # )
-    # out["image2d_multiscale_xarray"] = Image2DModel.parse(
-    #     DataArray(RNG.normal(size=(3, 64, 64)), dims=dims_2d),
-    #     name="image2d_multiscale_xarray",
-    #     multiscale_factors=[2, 4],
-    #     dims=None,
-    # )
-    # # TODO: not supported atm.
- # out["image3d_numpy"] = Image3DModel.parse(RNG.normal(size=(2, 64, 64, 3)), name="image3d_numpy", dims=dims_3d) - # out["image3d_multiscale_numpy"] = Image3DModel.parse( - # RNG.normal(size=(2, 64, 64, 3)), name="image3d_multiscale_numpy", scale_factors=[2, 4], dims=dims_3d - # ) - # out["image3d_xarray"] = Image3DModel.parse( - # DataArray(RNG.normal(size=(2, 64, 64, 3)), dims=dims_3d), name="image3d_xarray", dims=None - # ) - # out["image3d_multiscale_xarray"] = Image3DModel.parse( - # DataArray(RNG.normal(size=(2, 64, 64, 3)), dims=dims_3d), - # name="image3d_multiscale_xarray", - # scale_factors=[2, 4], - # dims=None, - # ) return out @@ -156,7 +156,7 @@ def _get_labels() -> dict[str, Union[SpatialImage, MultiscaleSpatialImage]]: out["labels2d"] = Labels2DModel.parse(RNG.normal(size=(64, 64)), name="labels2d", dims=dims_2d) out["labels2d_multiscale"] = Labels2DModel.parse( - RNG.normal(size=(64, 64)), name="labels2d_multiscale", multiscale_factors=[2, 4], dims=dims_2d + RNG.normal(size=(64, 64)), name="labels2d_multiscale", scale_factors=[2, 4], dims=dims_2d ) # TODO: (BUG) https://github.com/scverse/spatialdata/issues/59 @@ -171,7 +171,7 @@ def _get_labels() -> dict[str, Union[SpatialImage, MultiscaleSpatialImage]]: # ) out["labels3d_numpy"] = Labels3DModel.parse(RNG.normal(size=(10, 64, 64)), name="labels3d_numpy", dims=dims_3d) out["labels3d_multiscale_numpy"] = Labels3DModel.parse( - RNG.normal(size=(10, 64, 64)), name="labels3d_multiscale_numpy", multiscale_factors=[2, 4], dims=dims_3d + RNG.normal(size=(10, 64, 64)), name="labels3d_multiscale_numpy", scale_factors=[2, 4], dims=dims_3d ) # TODO: (BUG) https://github.com/scverse/spatialdata/issues/59 # out["labels3d_xarray"] = Labels3DModel.parse( From aef337bfb60d90a36a115c8e7d83e9caa778a0a6 Mon Sep 17 00:00:00 2001 From: giovp Date: Mon, 20 Feb 2023 10:44:46 +0100 Subject: [PATCH 09/24] add tests for labels --- tests/conftest.py | 39 ++++++++++++++++++--------------------- 1 file changed, 18 insertions(+), 21 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index cb909c4a..29d37d90 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -158,31 +158,28 @@ def _get_labels() -> dict[str, Union[SpatialImage, MultiscaleSpatialImage]]: out["labels2d_multiscale"] = Labels2DModel.parse( RNG.normal(size=(64, 64)), name="labels2d_multiscale", scale_factors=[2, 4], dims=dims_2d ) - - # TODO: (BUG) https://github.com/scverse/spatialdata/issues/59 - # out["labels2d_xarray"] = Labels2DModel.parse( - # DataArray(RNG.normal(size=(64, 64)), dims=dims_2d), name="labels2d_xarray", dims=None - # ) - # out["labels2d_multiscale_xarray"] = Labels2DModel.parse( - # DataArray(RNG.normal(size=(64, 64)), dims=dims_2d), - # name="labels2d_multiscale_xarray", - # multiscale_factors=[2, 4], - # dims=None, - # ) + out["labels2d_xarray"] = Labels2DModel.parse( + DataArray(RNG.normal(size=(64, 64)), dims=dims_2d), name="labels2d_xarray", dims=None + ) + out["labels2d_multiscale_xarray"] = Labels2DModel.parse( + DataArray(RNG.normal(size=(64, 64)), dims=dims_2d), + name="labels2d_multiscale_xarray", + scale_factors=[2, 4], + dims=None, + ) out["labels3d_numpy"] = Labels3DModel.parse(RNG.normal(size=(10, 64, 64)), name="labels3d_numpy", dims=dims_3d) out["labels3d_multiscale_numpy"] = Labels3DModel.parse( RNG.normal(size=(10, 64, 64)), name="labels3d_multiscale_numpy", scale_factors=[2, 4], dims=dims_3d ) - # TODO: (BUG) https://github.com/scverse/spatialdata/issues/59 - # out["labels3d_xarray"] = Labels3DModel.parse( - # 
     # )
     # out["labels3d_multiscale_xarray"] = Labels3DModel.parse(
     #     DataArray(RNG.normal(size=(10, 64, 64)), dims=dims_3d),
     #     name="labels3d_multiscale_xarray",
     #     multiscale_factors=[2, 4],
     #     dims=None,
     # )

From aef337bfb60d90a36a115c8e7d83e9caa778a0a6 Mon Sep 17 00:00:00 2001
From: giovp
Date: Mon, 20 Feb 2023 10:44:46 +0100
Subject: [PATCH 09/24] add tests for labels

---
 tests/conftest.py | 39 ++++++++++++++++++---------------------
 1 file changed, 18 insertions(+), 21 deletions(-)

diff --git a/tests/conftest.py b/tests/conftest.py
index cb909c4a..29d37d90 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -158,31 +158,28 @@ def _get_labels() -> dict[str, Union[SpatialImage, MultiscaleSpatialImage]]:
     out["labels2d_multiscale"] = Labels2DModel.parse(
         RNG.normal(size=(64, 64)), name="labels2d_multiscale", scale_factors=[2, 4], dims=dims_2d
     )
-
-    # TODO: (BUG) https://github.com/scverse/spatialdata/issues/59
-    # out["labels2d_xarray"] = Labels2DModel.parse(
-    #     DataArray(RNG.normal(size=(64, 64)), dims=dims_2d), name="labels2d_xarray", dims=None
-    # )
-    # out["labels2d_multiscale_xarray"] = Labels2DModel.parse(
-    #     DataArray(RNG.normal(size=(64, 64)), dims=dims_2d),
-    #     name="labels2d_multiscale_xarray",
-    #     multiscale_factors=[2, 4],
-    #     dims=None,
-    # )
+    out["labels2d_xarray"] = Labels2DModel.parse(
+        DataArray(RNG.normal(size=(64, 64)), dims=dims_2d), name="labels2d_xarray", dims=None
+    )
+    out["labels2d_multiscale_xarray"] = Labels2DModel.parse(
+        DataArray(RNG.normal(size=(64, 64)), dims=dims_2d),
+        name="labels2d_multiscale_xarray",
+        scale_factors=[2, 4],
+        dims=None,
+    )
     out["labels3d_numpy"] = Labels3DModel.parse(RNG.normal(size=(10, 64, 64)), name="labels3d_numpy", dims=dims_3d)
     out["labels3d_multiscale_numpy"] = Labels3DModel.parse(
         RNG.normal(size=(10, 64, 64)), name="labels3d_multiscale_numpy", scale_factors=[2, 4], dims=dims_3d
     )
-    # TODO: (BUG) https://github.com/scverse/spatialdata/issues/59
-    # out["labels3d_xarray"] = Labels3DModel.parse(
-    #     DataArray(RNG.normal(size=(10, 64, 64)), dims=dims_3d), name="labels3d_xarray", dims=None
-    # )
-    # out["labels3d_multiscale_xarray"] = Labels3DModel.parse(
-    #     DataArray(RNG.normal(size=(10, 64, 64)), dims=dims_3d),
-    #     name="labels3d_multiscale_xarray",
-    #     multiscale_factors=[2, 4],
-    #     dims=None,
-    # )
+    out["labels3d_xarray"] = Labels3DModel.parse(
+        DataArray(RNG.normal(size=(10, 64, 64)), dims=dims_3d), name="labels3d_xarray", dims=None
+    )
+    out["labels3d_multiscale_xarray"] = Labels3DModel.parse(
+        DataArray(RNG.normal(size=(10, 64, 64)), dims=dims_3d),
+        name="labels3d_multiscale_xarray",
+        scale_factors=[2, 4],
+        dims=None,
+    )
     return out

From 0dc0459ee62caf9d31c4505fd1fd4abefebd96f7 Mon Sep 17 00:00:00 2001
From: giovp
Date: Mon, 20 Feb 2023 10:48:43 +0100
Subject: [PATCH 10/24] update shapely

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index dbebe3cf..df46cfef 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -31,7 +31,7 @@ dependencies = [
     "xarray-schema",
     "pygeos",
     "geopandas",
-    "shapely==2.0rc2",
+    "shapely>=2.0.1",
     "rich",
     "pyarrow",
     "tqdm",

From 97f8d7e04f850453a237e7ab60da7498b9900cee Mon Sep 17 00:00:00 2001
From: giovp
Date: Mon, 20 Feb 2023 10:57:27 +0100
Subject: [PATCH 11/24] add some comments

---
 spatialdata/_core/models.py | 13 +++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/spatialdata/_core/models.py b/spatialdata/_core/models.py
index d52404b6..a9137537 100644
--- a/spatialdata/_core/models.py
+++ b/spatialdata/_core/models.py
@@ -146,7 +146,7 @@ def parse(
             if not isinstance(data.data, DaskArray):  # numpy -> dask
                 data.data = from_array(data.data)
             if dims is not None:
-                if set(dims).difference(data.dims):
+                if set(dims).symmetric_difference(data.dims):
                     raise ValueError(
                         f"`dims`: {dims} does not match `data.dims`: {data.dims}, please specify the dims only once."
                     )
@@ -154,9 +154,11 @@ def parse(
                     logger.info("`dims` is specified redundantly: found also inside `data`.")
             else:
                 dims = data.dims
-            if set(dims).difference(cls.dims.dims):
+            # but if dims don't match the model's dims, throw error
+            if set(dims).symmetric_difference(cls.dims.dims):
                 raise ValueError(f"Wrong `dims`: {dims}. Expected {cls.dims.dims}.")
             _reindex = lambda d: d
+        # if there are no dims in the data, use the model's dims or provided dims
         elif isinstance(data, np.ndarray) or isinstance(data, DaskArray):
             if not isinstance(data, DaskArray):  # numpy -> dask
                 data = from_array(data)
@@ -180,11 +182,14 @@ def parse(
             except ValueError:
                 raise ValueError(f"Cannot transpose arrays to match `dims`: {dims}. Try to reshape `data` or `dims`.")
+        # finally convert to spatial image
         data = to_spatial_image(array_like=data, dims=cls.dims.dims, **kwargs)
-        assert isinstance(data, SpatialImage)
+        # parse transformations
         _parse_transformations(data, transformations)
+        # convert to multiscale if needed
         if scale_factors is not None:
             parsed_transform = _get_transformations(data)
+            # delete transforms
             del data.attrs["transform"]
             data = to_multiscale(
                 data,
@@ -193,7 +198,7 @@ def parse(
                 chunks=chunks,
             )
             _parse_transformations(data, parsed_transform)
-        assert isinstance(data, MultiscaleSpatialImage)
+        # recompute coordinates for (multiscale) spatial image
         data = compute_coordinates(data)
         return data

From 90e0cc359a85dbf5f0b6ae53fc74c5ac1e4d7db7 Mon Sep 17 00:00:00 2001
From: giovp
Date: Mon, 20 Feb 2023 11:16:01 +0100
Subject: [PATCH 12/24] improve validation per #115

---
 spatialdata/_core/models.py | 53 ++++++++++++++++++++++++-------------
 1 file changed, 35 insertions(+), 18 deletions(-)

diff --git a/spatialdata/_core/models.py b/spatialdata/_core/models.py
index a9137537..96d986f8 100644
--- a/spatialdata/_core/models.py
+++ b/spatialdata/_core/models.py
@@ -202,16 +202,39 @@ def parse(
         data = compute_coordinates(data)
         return data

-    def validate(self, data: Union[SpatialImage, MultiscaleSpatialImage]) -> None:
-        if isinstance(data, SpatialImage):
-            super().validate(data)
-        elif isinstance(data, MultiscaleSpatialImage):
-            name = {list(data[i].data_vars.keys())[0] for i in data.keys()}
-            if len(name) > 1:
-                raise ValueError(f"Wrong name for datatree: {name}.")
-            name = list(name)[0]
-            for d in data:
-                super().validate(data[d][name])
+    @singledispatchmethod
+    def validate(self, data: Any) -> None:
+        """
+        Validate data.
+
+        Parameters
+        ----------
+        data
+            Data to validate.
+
+        Raises
+        ------
+        ValueError
+            If data is not valid.
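+
+        Example
+        -------
+        (editor's sketch; ``msi`` is any parsed multiscale image)
+
+        >>> Image2DModel().validate(msi)
+
+        This passes silently when ``msi`` is valid and raises a ``ValueError``
+        if the keys are not ``scale0, scale1, ...`` or if the variable name
+        differs across scales (see the checks registered below).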
+ """ + + raise ValueError(f"Unsupported data type: {type(data)}.") + + @validate.register(SpatialImage) + def _(self, data: SpatialImage) -> None: + super().validate(data) + + @validate.register(MultiscaleSpatialImage) + def _(self, data: MultiscaleSpatialImage) -> None: + for j, k in zip(data.keys(), [f"scale{i}" for i in np.arange(len(data.keys()))]): + if j != k: + raise ValueError(f"Wrong key for multiscale data, found: `{j}`, expected: `{k}`.") + name = {list(data[i].data_vars.keys())[0] for i in data.keys()} + if len(name) > 1: + raise ValueError(f"Wrong name for datatree: `{name}`.") + name = list(name)[0] + for d in data: + super().validate(data[d][name]) class Labels2DModel(RasterSchema): @@ -223,8 +246,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__( dims=self.dims, array_type=self.array_type, - # suppressing the check of .attrs['transform']; see https://github.com/scverse/spatialdata/issues/115 - # attrs=self.attrs, *args, **kwargs, ) @@ -239,8 +260,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__( dims=self.dims, array_type=self.array_type, - # suppressing the check of .attrs['transform']; see https://github.com/scverse/spatialdata/issues/115 - # attrs=self.attrs, *args, **kwargs, ) @@ -255,8 +274,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__( dims=self.dims, array_type=self.array_type, - # suppressing the check of .attrs['transform']; see https://github.com/scverse/spatialdata/issues/115 - # attrs=self.attrs, + attrs=self.attrs, *args, **kwargs, ) @@ -271,8 +289,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__( dims=self.dims, array_type=self.array_type, - # suppressing the check of .attrs['transform']; see https://github.com/scverse/spatialdata/issues/115 - # attrs=self.attrs, + attrs=self.attrs, *args, **kwargs, ) From 6140832373494fbebd54ee1acea4e778678ee970 Mon Sep 17 00:00:00 2001 From: giovp Date: Mon, 20 Feb 2023 11:22:59 +0100 Subject: [PATCH 13/24] remove multiscale_factors and add sclae_factors --- spatialdata/_core/_transform_elements.py | 2 +- spatialdata/_core/models.py | 2 +- spatialdata/utils.py | 2 +- tests/test_utils.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/spatialdata/_core/_transform_elements.py b/spatialdata/_core/_transform_elements.py index 67d4bbfe..28c8e6ac 100644 --- a/spatialdata/_core/_transform_elements.py +++ b/spatialdata/_core/_transform_elements.py @@ -185,7 +185,7 @@ def _(data: MultiscaleSpatialImage, transformation: BaseTransformation) -> Multi except OverflowError as e: raise e # mypy thinks that schema could be ShapesModel, PointsModel, ... - transformed_data = schema.parse(transformed_dask, dims=axes, multiscale_factors=multiscale_factors) # type: ignore[call-arg,arg-type] + transformed_data = schema.parse(transformed_dask, dims=axes, scale_factors=multiscale_factors) # type: ignore[call-arg,arg-type] print( "TODO: compose the transformation!!!! we need to put the previous one concatenated with the translation showen above. The translation operates before the other transformation" ) diff --git a/spatialdata/_core/models.py b/spatialdata/_core/models.py index 96d986f8..da0acf86 100644 --- a/spatialdata/_core/models.py +++ b/spatialdata/_core/models.py @@ -128,7 +128,7 @@ def parse( Dimensions of the data. transformations Transformations to apply to the data. - multiscale_factors + scale_factors Scale factors to apply for multiscale. 
             If not None, a :class:`multiscale_spatial_image.MultiscaleSpatialImage` is returned.
         method
diff --git a/spatialdata/utils.py b/spatialdata/utils.py
index 28670a17..465950f1 100644
--- a/spatialdata/utils.py
+++ b/spatialdata/utils.py
@@ -133,7 +133,7 @@ def _unpad_axis(data: DataArray, axis: str) -> tuple[DataArray, float]:
         unpadded = unpad_raster(SpatialImage(xdata))
         # TODO: here I am using some arbitrary scaling factors, I think that we need an automatic initialization of multiscale. See discussion: https://github.com/scverse/spatialdata/issues/108
         # mypy thinks that the schema could be a ShapeModel, ... but it's not
-        unpadded_multiscale = get_schema(raster).parse(unpadded, multiscale_factors=[2, 2])  # type: ignore[call-arg]
+        unpadded_multiscale = get_schema(raster).parse(unpadded, scale_factors=[2, 2])  # type: ignore[call-arg]
         return unpadded_multiscale
     else:
         raise TypeError(f"Unsupported type: {type(raster)}")
diff --git a/tests/test_utils.py b/tests/test_utils.py
index d48a7edd..19a0c767 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -46,7 +46,7 @@ def test_unpad_raster(images, labels) -> None:
         padded = schema.parse(padded, dims=data.dims)
     elif isinstance(raster, MultiscaleSpatialImage):
         # some arbitrary scaling factors
-        padded = schema.parse(padded, dims=data.dims, multiscale_factors=[2, 2])
+        padded = schema.parse(padded, dims=data.dims, scale_factors=[2, 2])
     else:
         raise ValueError(f"Unknown type: {type(raster)}")
     unpadded = unpad_raster(padded)

From 238fb8ca606d55625722521d86288abd3178c1a3 Mon Sep 17 00:00:00 2001
From: giovp
Date: Mon, 20 Feb 2023 11:47:28 +0100
Subject: [PATCH 14/24] try fixing tests

---
 spatialdata/_core/_transform_elements.py | 2 +-
 spatialdata/utils.py                     | 6 +++++-
 tests/test_utils.py                      | 2 +-
 3 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/spatialdata/_core/_transform_elements.py b/spatialdata/_core/_transform_elements.py
index 28c8e6ac..6bfb4346 100644
--- a/spatialdata/_core/_transform_elements.py
+++ b/spatialdata/_core/_transform_elements.py
@@ -182,7 +182,7 @@ def _(data: MultiscaleSpatialImage, transformation: BaseTransformation) -> Multi
     #     assert np.allclose(almost_zero, np.zeros_like(almost_zero), rtol=2.)
     try:
         multiscale_factors.append(round(factors[0]))
-    except OverflowError as e:
+    except ValueError as e:
         raise e
     # mypy thinks that schema could be ShapesModel, PointsModel, ...
     transformed_data = schema.parse(transformed_dask, dims=axes, scale_factors=multiscale_factors)  # type: ignore[call-arg,arg-type]
diff --git a/spatialdata/utils.py b/spatialdata/utils.py
index 465950f1..27eb665c 100644
--- a/spatialdata/utils.py
+++ b/spatialdata/utils.py
@@ -133,7 +133,11 @@ def _unpad_axis(data: DataArray, axis: str) -> tuple[DataArray, float]:
         unpadded = unpad_raster(SpatialImage(xdata))
         # TODO: here I am using some arbitrary scaling factors, I think that we need an automatic initialization of multiscale. See discussion: https://github.com/scverse/spatialdata/issues/108
         # mypy thinks that the schema could be a ShapeModel, ... but it's not
-        unpadded_multiscale = get_schema(raster).parse(unpadded, scale_factors=[2, 2])  # type: ignore[call-arg]
+        if "z" in axes:
+            scale_factors = [2]
+        else:
+            scale_factors = [2, 2]
+        unpadded_multiscale = get_schema(raster).parse(unpadded, scale_factors=scale_factors)  # type: ignore[call-arg]
         return unpadded_multiscale
     else:
         raise TypeError(f"Unsupported type: {type(raster)}")
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 19a0c767..d48a7edd 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -46,7 +46,7 @@ def test_unpad_raster(images, labels) -> None:
         padded = schema.parse(padded, dims=data.dims)
     elif isinstance(raster, MultiscaleSpatialImage):
         # some arbitrary scaling factors
-        padded = schema.parse(padded, dims=data.dims, scale_factors=[2, 2])
+        padded = schema.parse(padded, dims=data.dims, multiscale_factors=[2, 2])
     else:
         raise ValueError(f"Unknown type: {type(raster)}")
     unpadded = unpad_raster(padded)

From c67435a1f93dd6ec398a6add8fa73541638b0cc Mon Sep 17 00:00:00 2001
From: giovp
Date: Mon, 20 Feb 2023 13:02:43 +0100
Subject: [PATCH 15/24] updates

---
 spatialdata/_core/core_utils.py | 30 ++++++++++++++++++++++++++++++
 spatialdata/_io/format.py       | 24 ++++++++++++++++++++++++
 spatialdata/_io/read.py         |  1 +
 spatialdata/_io/write.py        |  2 ++
 4 files changed, 57 insertions(+)

diff --git a/spatialdata/_core/core_utils.py b/spatialdata/_core/core_utils.py
index ecd68461..2a54b2b7 100644
--- a/spatialdata/_core/core_utils.py
+++ b/spatialdata/_core/core_utils.py
@@ -364,3 +364,33 @@ def _compute_coords(max_: int, scale_f: Union[int, float]) -> ArrayLike:
                 {k: _compute_coords(max_scale0[k], round(s)) for k, s in zip(max_scale.keys(), scalef)}  # type: ignore[arg-type]
             )
     return MultiscaleSpatialImage.from_dict(d=out)
+
+
+@singledispatch
+def get_channels(data: Any) -> list[Any]:
+    """Get channels from data.
+
+    Parameters
+    ----------
+    data
+        data to get channels from
+
+    Returns
+    -------
+    List of channels
+    """
+    raise ValueError(f"Cannot get channels from {type(data)}")
+
+
+@get_channels.register
+def _(data: SpatialImage) -> list[Any]:
+    return data.coords["c"].values.tolist()
+
+
+@get_channels.register
+def _(data: MultiscaleSpatialImage) -> list[Any]:
+    name = list({list(data[i].data_vars.keys())[0] for i in data.keys()})[0]
+    channels = {tuple(data[i][name].coords["c"].values) for i in data.keys()}
+    if len(channels) > 1:
+        raise ValueError("Channels are not the same across the multiscale levels.")
+    return [i for i in next(iter(channels))]
diff --git a/spatialdata/_io/format.py b/spatialdata/_io/format.py
index fcd50015..4fbc36bb 100644
--- a/spatialdata/_io/format.py
+++ b/spatialdata/_io/format.py
@@ -6,6 +6,9 @@
 from shapely import GeometryType

 from spatialdata._core.models import PointsModel, PolygonsModel, ShapesModel
+from multiscale_spatial_image.multiscale_spatial_image import MultiscaleSpatialImage
+from spatial_image import SpatialImage
+from spatialdata._core.core_utils import get_channels

 CoordinateTransform_t = list[dict[str, Any]]
@@ -90,6 +93,27 @@ def validate_coordinate_transformations(
             assert np.all([j0 == j1 for j0, j1 in zip(json0, json1)])

+    def channels_to_metadata(
+        self, data: Union[SpatialImage, MultiscaleSpatialImage], channels_metadata: Optional[dict[str, Any]] = None
+    ) -> dict[str, Union[int, str]]:
+        """Convert channels to metadata."""
+        channels = get_channels(data)
+        metadata: dict[str, Any] = {"channels": []}
+        if channels_metadata is not None:
+            if set(channels_metadata.keys()).symmetric_difference(set(channels)):
+                raise ValueError("Channels metadata must contain all channels.")
+            for c in channels:
+                metadata["channels"].append({"labels": c} | channels_metadata[c])
+        else:
+            for c in channels:
+                metadata["channels"].append({"labels": c})
+        return metadata
+
+    def channels_from_metadata(self, omero_metadata: dict[str, Any]) -> list[Any]:
+        """Convert channels from metadata."""
+        return [d["labels"] for d in omero_metadata["channels"]]

 class PolygonsFormat(SpatialDataFormatV01):
     """Formatter for polygons."""
diff --git a/spatialdata/_io/read.py b/spatialdata/_io/read.py
index e504b958..4e0e25ed 100644
--- a/spatialdata/_io/read.py
+++ b/spatialdata/_io/read.py
@@ -166,6 +166,7 @@ def _read_multiscale(
     encoded_ngff_transformations = multiscales[0]["coordinateTransformations"]
     transformations = _get_transformations_from_ngff_dict(encoded_ngff_transformations)
     name = node.metadata["name"]
+    channels_metadata = fmt.channels_from_metadata(node.metadata["omero"])
     if type(name) == list:
         assert len(name) == 1
         name = name[0]
diff --git a/spatialdata/_io/write.py b/spatialdata/_io/write.py
index b125e597..fc27da09 100644
--- a/spatialdata/_io/write.py
+++ b/spatialdata/_io/write.py
@@ -118,6 +118,7 @@ def _write_raster(
     fmt: Format = SpatialDataFormatV01(),
     storage_options: Optional[Union[JSONDict, list[JSONDict]]] = None,
     label_metadata: Optional[JSONDict] = None,
+    channels_metadata: Optional[JSONDict] = None,
     **metadata: Union[str, JSONDict, list[JSONDict]],
 ) -> None:
     assert raster_type in ["image", "labels"]
@@ -162,6 +163,7 @@ def _get_group_for_writing_transformations() -> zarr.Group:
     # We need this because the argument of write_image_ngff is called image while the argument of
     # write_labels_ngff is called label.
     metadata[raster_type] = data
+    metadata["omero"] = fmt.channels_to_metadata(raster_data, channels_metadata)
     write_single_scale_ngff(
         group=_get_group_for_writing_data(),
         scaler=None,

From 3240659999865b3c4f8c15eca851731f5600c1d8 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 20 Feb 2023 12:03:26 +0000
Subject: [PATCH 16/24] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 spatialdata/_io/format.py | 6 +++---
 spatialdata/_io/read.py   | 2 +-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/spatialdata/_io/format.py b/spatialdata/_io/format.py
index 4fbc36bb..1fd8cb4e 100644
--- a/spatialdata/_io/format.py
+++ b/spatialdata/_io/format.py
@@ -1,14 +1,14 @@
 from typing import Any, Optional, Union

 from anndata import AnnData
+from multiscale_spatial_image.multiscale_spatial_image import MultiscaleSpatialImage
 from ome_zarr.format import CurrentFormat
 from pandas.api.types import is_categorical_dtype
 from shapely import GeometryType
-
-from spatialdata._core.models import PointsModel, PolygonsModel, ShapesModel
-from multiscale_spatial_image.multiscale_spatial_image import MultiscaleSpatialImage
 from spatial_image import SpatialImage
+
 from spatialdata._core.core_utils import get_channels
+from spatialdata._core.models import PointsModel, PolygonsModel, ShapesModel

 CoordinateTransform_t = list[dict[str, Any]]
diff --git a/spatialdata/_io/read.py b/spatialdata/_io/read.py
index 4e0e25ed..230df95c 100644
--- a/spatialdata/_io/read.py
+++ b/spatialdata/_io/read.py
@@ -166,7 +166,7 @@ def _read_multiscale(
     encoded_ngff_transformations = multiscales[0]["coordinateTransformations"]
     transformations = _get_transformations_from_ngff_dict(encoded_ngff_transformations)
     name = node.metadata["name"]
-    channels_metadata = fmt.channels_from_metadata(node.metadata["omero"])
+    fmt.channels_from_metadata(node.metadata["omero"])
     if type(name) == list:
         assert len(name) == 1
         name = name[0]

From 4214abccb2d9d4ae1aa4501deec91ec3ee8446aa Mon Sep 17 00:00:00 2001
From: giovp
Date: Mon, 20 Feb 2023 18:14:25 +0100
Subject: [PATCH 17/24] add channels to IO

---
 spatialdata/_io/format.py |  4 ++--
 spatialdata/_io/read.py   | 24 ++++++++++++------------
 spatialdata/_io/write.py  |  4 +++-
 3 files changed, 17 insertions(+), 15 deletions(-)

diff --git a/spatialdata/_io/format.py b/spatialdata/_io/format.py
index 2df87c95..44d8f608 100644
--- a/spatialdata/_io/format.py
+++ b/spatialdata/_io/format.py
@@ -95,7 +95,7 @@ def channels_to_metadata(
         self, data: Union[SpatialImage, MultiscaleSpatialImage], channels_metadata: Optional[dict[str, Any]] = None
     ) -> dict[str, Union[int, str]]:
-        """Convert channels to metadata."""
+        """Convert channels to omero metadata."""
         channels = get_channels(data)
         metadata: dict[str, Any] = {"channels": []}
         if channels_metadata is not None:
@@ -110,7 +110,7 @@ def channels_to_metadata(
         return metadata

     def channels_from_metadata(self, omero_metadata: dict[str, Any]) -> list[Any]:
-        """Convert channels from metadata."""
+        """Convert channels from omero metadata."""
         return [d["labels"] for d in omero_metadata["channels"]]

diff --git a/spatialdata/_io/read.py b/spatialdata/_io/read.py
index a25fb677..b126d5bc 100644
--- a/spatialdata/_io/read.py
+++ b/spatialdata/_io/read.py
@@ -21,13 +21,13 @@
 from spatialdata._core.core_utils import (
     MappingToCoordinateSystem_t,
     _set_transformations,
+    compute_coordinates,
 )
 from spatialdata._core.models import TableModel
 from spatialdata._core.ngff.ngff_transformations import NgffBaseTransformation
 from spatialdata._core.transformations import BaseTransformation
 from spatialdata._io._utils import ome_zarr_logger
 from spatialdata._io.format import PointsFormat, ShapesFormat, SpatialDataFormatV01
-from spatialdata._logging import logger

 def read_zarr(store: Union[str, Path, zarr.Group]) -> SpatialData:
@@ -148,17 +148,15 @@ def _read_multiscale(
     datasets = node.load(Multiscales).datasets
     multiscales = node.load(Multiscales).zarr.root_attrs["multiscales"]
     assert len(multiscales) == 1
+    # checking for multiscales[0]["coordinateTransformations"] would fail for
+    # data that doesn't have coordinateTransformations in the top level,
+    # which can happen with the current version of the spec
+    # (for instance in the xenium example)
    encoded_ngff_transformations = multiscales[0]["coordinateTransformations"]
     transformations = _get_transformations_from_ngff_dict(encoded_ngff_transformations)
-    name = node.metadata["name"]
-    fmt.channels_from_metadata(node.metadata["omero"])
-    if type(name) == list:
-        assert len(name) == 1
-        name = name[0]
-    logger.warning(
-        "omero metadata is not fully supported yet, using a workaround. If you encounter bugs related "
-        "to omero metadata please follow the discussion at https://github.com/scverse/spatialdata/issues/60"
-    )
+    node.metadata["name"]
+    omero = multiscales[0]["omero"]
+    channels = fmt.channels_from_metadata(omero)
     axes = [i["name"] for i in node.metadata["axes"]]
     if len(datasets) > 1:
         multiscale_image = {}
@@ -170,11 +168,12 @@ def _read_multiscale(
             name="image",
             # name=name,
             dims=axes,
+            coords={"c": channels},
             # attrs={"transform": t},
         )
         msi = MultiscaleSpatialImage.from_dict(multiscale_image)
         _set_transformations(msi, transformations)
-        return msi
+        return compute_coordinates(msi)
     else:
         data = node.load(Multiscales).array(resolution=datasets[0], version=fmt.version)
         si = SpatialImage(
             data,
             # any name
             name="image",
             # name=name,
             dims=axes,
+            coords={"c": channels},
             # attrs={TRANSFORM_KEY: t},
         )
         _set_transformations(si, transformations)
-        return si
+        return compute_coordinates(si)
diff --git a/spatialdata/_io/write.py b/spatialdata/_io/write.py
index 81e56604..22ee5086 100644
--- a/spatialdata/_io/write.py
+++ b/spatialdata/_io/write.py
@@ -143,6 +143,9 @@ def _get_group_for_writing_transformations() -> zarr.Group:
     else:
         return group["labels"][name]

+    # convert channel names to channel metadata
+    metadata["omero"] = fmt.channels_to_metadata(raster_data, channels_metadata)
+
     if isinstance(raster_data, SpatialImage):
         data = raster_data.data
         transformations = _get_transformations(raster_data)
@@ -158,7 +161,6 @@ def _get_group_for_writing_transformations() -> zarr.Group:
     # We need this because the argument of write_image_ngff is called image while the argument of
     # write_labels_ngff is called label.
     metadata[raster_type] = data
-    metadata["omero"] = fmt.channels_to_metadata(raster_data, channels_metadata)
     write_single_scale_ngff(
         group=_get_group_for_writing_data(),
         scaler=None,

From f65151e9a237e6172d25525a2b17f5a3aa40faa7 Mon Sep 17 00:00:00 2001
From: giovp
Date: Mon, 20 Feb 2023 18:25:40 +0100
Subject: [PATCH 18/24] update omero excluding it from labels

---
 spatialdata/_io/read.py     | 10 ++++++----
 spatialdata/_io/write.py    |  3 ++-
 tests/_io/test_readwrite.py | 10 ++--------
 3 files changed, 10 insertions(+), 13 deletions(-)

diff --git a/spatialdata/_io/read.py b/spatialdata/_io/read.py
index b126d5bc..aedee159 100644
--- a/spatialdata/_io/read.py
+++ b/spatialdata/_io/read.py
@@ -155,8 +155,10 @@ def _read_multiscale(
     encoded_ngff_transformations = multiscales[0]["coordinateTransformations"]
     transformations = _get_transformations_from_ngff_dict(encoded_ngff_transformations)
     node.metadata["name"]
-    omero = multiscales[0]["omero"]
-    channels = fmt.channels_from_metadata(omero)
+    # if image, read channels metadata
+    if raster_type == "image":
+        omero = multiscales[0]["omero"]
+        channels = fmt.channels_from_metadata(omero)
     axes = [i["name"] for i in node.metadata["axes"]]
     if len(datasets) > 1:
         multiscale_image = {}
@@ -170,7 +170,7 @@ def _read_multiscale(
             name="image",
             # name=name,
             dims=axes,
-            coords={"c": channels},
+            coords={"c": channels} if raster_type == "image" else {},
             # attrs={"transform": t},
         )
         msi = MultiscaleSpatialImage.from_dict(multiscale_image)
@@ -182,7 +184,7 @@ def _read_multiscale(
             name="image",
             # name=name,
             dims=axes,
-            coords={"c": channels},
+            coords={"c": channels} if raster_type == "image" else {},
             # attrs={TRANSFORM_KEY: t},
         )
         _set_transformations(si, transformations)
diff --git a/spatialdata/_io/write.py b/spatialdata/_io/write.py
index 22ee5086..2429379a 100644
--- a/spatialdata/_io/write.py
+++ b/spatialdata/_io/write.py
@@ -144,7 +144,8 @@ def _get_group_for_writing_transformations() -> zarr.Group:
         return group["labels"][name]

     # convert channel names to channel metadata
-    metadata["omero"] = fmt.channels_to_metadata(raster_data, channels_metadata)
+    if raster_type == "image":
+        metadata["omero"] = fmt.channels_to_metadata(raster_data, channels_metadata)

     if isinstance(raster_data, SpatialImage):
diff --git a/tests/_io/test_readwrite.py b/tests/_io/test_readwrite.py
index 48cc3803..7e7a2460 100644
--- a/tests/_io/test_readwrite.py
+++ b/tests/_io/test_readwrite.py
@@ -26,10 +26,7 @@ def test_images(self, tmp_path: str, images: SpatialData) -> None:
         sdata = SpatialData.read(tmpdir)
         assert images.images.keys() == sdata.images.keys()
         for k in images.images.keys():
-            if isinstance(sdata.images[k], SpatialImage):
-                assert images.images[k].equals(sdata.images[k])
-            elif isinstance(images.images[k], MultiscaleSpatialImage):
-                assert images.images[k].equals(sdata.images[k])
+            assert images.images[k].equals(sdata.images[k])

     def test_labels(self, tmp_path: str, labels: SpatialData) -> None:
         """Test read/write."""
@@ -38,10 +35,7 @@ def test_labels(self, tmp_path: str, labels: SpatialData) -> None:
         sdata = SpatialData.read(tmpdir)
         assert labels.labels.keys() == sdata.labels.keys()
         for k in labels.labels.keys():
-            if isinstance(sdata.labels[k], SpatialImage):
-                assert labels.labels[k].equals(sdata.labels[k])
-            elif isinstance(sdata.labels[k], MultiscaleSpatialImage):
-                assert labels.labels[k].equals(sdata.labels[k])
+            assert labels.labels[k].equals(sdata.labels[k])

     def test_shapes(self, tmp_path: str, shapes: SpatialData) -> None:
         """Test read/write."""
read/write.""" From 448da01c3808852b77b60e931790bf6a1c4eb943 Mon Sep 17 00:00:00 2001 From: giovp Date: Wed, 22 Feb 2023 14:27:32 +0100 Subject: [PATCH 19/24] use isel --- spatialdata/_core/_spatial_query.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spatialdata/_core/_spatial_query.py b/spatialdata/_core/_spatial_query.py index 00f11f86..ad9621dd 100644 --- a/spatialdata/_core/_spatial_query.py +++ b/spatialdata/_core/_spatial_query.py @@ -131,7 +131,7 @@ def _bounding_box_query_image( # add the selection[axis_name] = slice(min_value, max_value) - query_result = image.sel(selection) + query_result = image.isel(selection) # update the transform # currently, this assumes the existing transforms input coordinate system From 51d865d50b1af248b2b1e5bf8d7166683066ab20 Mon Sep 17 00:00:00 2001 From: Luca Marconato <2664412+LucaMarconato@users.noreply.github.com> Date: Thu, 23 Feb 2023 15:50:44 +0100 Subject: [PATCH 20/24] read name from raster data --- spatialdata/_core/core_utils.py | 3 ++- spatialdata/_io/read.py | 10 +++------- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/spatialdata/_core/core_utils.py b/spatialdata/_core/core_utils.py index 5f05ec6d..10265099 100644 --- a/spatialdata/_core/core_utils.py +++ b/spatialdata/_core/core_utils.py @@ -280,7 +280,8 @@ def _(e: MultiscaleSpatialImage) -> tuple[str, ...]: if "scale0" in e: return tuple(i for i in e["scale0"].dims.keys()) else: - return tuple(i for i in e.dims.keys()) + raise ValueError("MultiscaleSpatialImage does not contain the scale0 key") + # return tuple(i for i in e.dims.keys()) @get_dims.register(GeoDataFrame) diff --git a/spatialdata/_io/read.py b/spatialdata/_io/read.py index aedee159..48537046 100644 --- a/spatialdata/_io/read.py +++ b/spatialdata/_io/read.py @@ -154,7 +154,7 @@ def _read_multiscale( # and for instance in the xenium example encoded_ngff_transformations = multiscales[0]["coordinateTransformations"] transformations = _get_transformations_from_ngff_dict(encoded_ngff_transformations) - node.metadata["name"] + name = node.metadata["name"] # if image, read channels metadata if raster_type == "image": omero = multiscales[0]["omero"] @@ -166,9 +166,7 @@ def _read_multiscale( data = node.load(Multiscales).array(resolution=d, version=fmt.version) multiscale_image[f"scale{i}"] = DataArray( data, - # any name - name="image", - # name=name, + name=name, dims=axes, coords={"c": channels} if raster_type == "image" else {}, # attrs={"transform": t}, @@ -180,9 +178,7 @@ def _read_multiscale( data = node.load(Multiscales).array(resolution=datasets[0], version=fmt.version) si = SpatialImage( data, - # any name - name="image", - # name=name, + name=name, dims=axes, coords={"c": channels} if raster_type == "image" else {}, # attrs={TRANSFORM_KEY: t}, From 524525919789ed3e4dbd922e103eb374896cbe49 Mon Sep 17 00:00:00 2001 From: Luca Marconato <2664412+LucaMarconato@users.noreply.github.com> Date: Thu, 23 Feb 2023 18:40:04 +0100 Subject: [PATCH 21/24] fix unpad_raster() --- spatialdata/_core/core_utils.py | 21 ++++++-- spatialdata/_core/models.py | 9 ++-- spatialdata/utils.py | 87 +++++++++++++++++++++------------ tests/_core/test_models.py | 2 +- tests/test_utils.py | 12 +++-- 5 files changed, 89 insertions(+), 42 deletions(-) diff --git a/spatialdata/_core/core_utils.py b/spatialdata/_core/core_utils.py index 10265099..20c43b89 100644 --- a/spatialdata/_core/core_utils.py +++ b/spatialdata/_core/core_utils.py @@ -11,7 +11,7 @@ from xarray import DataArray from 
-from spatialdata._core.transformations import BaseTransformation
+from spatialdata._core.transformations import BaseTransformation, Sequence
 from spatialdata._types import ArrayLike

 SpatialElement = Union[SpatialImage, MultiscaleSpatialImage, GeoDataFrame, DaskDataFrame]
@@ -176,6 +176,8 @@ def _(e: MultiscaleSpatialImage, transformations: MappingToCoordinateSystem_t) -
     scale_factors = old_shape / new_shape
     filtered_scale_factors = [scale_factors[i] for i, ax in enumerate(dims) if ax != "c"]
     filtered_axes = [ax for ax in dims if ax != "c"]
+    if not np.isfinite(filtered_scale_factors).all():
+        raise ValueError("Scale factors must be finite.")
     scale = Scale(scale=filtered_scale_factors, axes=tuple(filtered_axes))
     assert transformations is not None
     new_transformations = {}
@@ -335,11 +337,24 @@ def _(data: SpatialImage) -> SpatialImage:
 @compute_coordinates.register(MultiscaleSpatialImage)
 def _(data: MultiscaleSpatialImage) -> MultiscaleSpatialImage:
     def _get_scale(transforms: dict[str, Any]) -> Optional[ArrayLike]:
-        for t in transforms["global"].transformations:
+        all_scale_vectors = []
+        for transformation in transforms.values():
+            assert isinstance(transformation, Sequence)
+            # the first transformation is the scale
+            t = transformation.transformations[0]
             if hasattr(t, "scale"):
                 if TYPE_CHECKING:
                     assert isinstance(t.scale, np.ndarray)
-                return t.scale
+                all_scale_vectors.append(tuple(t.scale.tolist()))
+            else:
+                raise ValueError(f"Unsupported transformation: {t}")
+        # all the scales should be the same since they all refer to the mapping of the level of the multiscale to the
+        # base level, with respect to the intrinsic coordinate system
+        assert len(set(all_scale_vectors)) == 1
+        scalef = np.array(all_scale_vectors[0])
+        if not np.isfinite(scalef).all():
+            raise ValueError(f"Invalid scale factor: {scalef}")
+        return scalef

     def _compute_coords(max_: int, scale_f: Union[int, float]) -> ArrayLike:
         return (  # type: ignore[no-any-return]
             DataArray(np.linspace(0, max_, max_, endpoint=False, dtype=np.float_))
             .coarsen(dim_0=scale_f, boundary="trim", side="right")
             .mean()
             .values
         )
diff --git a/spatialdata/_core/models.py b/spatialdata/_core/models.py
index a64f53af..d4b4bdd9 100644
--- a/spatialdata/_core/models.py
+++ b/spatialdata/_core/models.py
@@ -174,9 +174,12 @@ def parse(
         # transpose if possible
         if dims != cls.dims.dims:
             try:
-                assert isinstance(data, DaskArray) or isinstance(data, DataArray)
-                # mypy complains that data has no .transpose but I have asserted right above that data is a DaskArray...
-                data = data.transpose(*[_reindex(d) for d in cls.dims.dims])  # type: ignore[attr-defined]
+                if isinstance(data, DataArray):
+                    data = data.transpose(*list(cls.dims.dims))
+                elif isinstance(data, DaskArray):
+                    data = data.transpose(*[_reindex(d) for d in cls.dims.dims])
+                else:
+                    raise ValueError(f"Unsupported data type: {type(data)}.")
                 logger.info(f"Transposing `data` of type: {type(data)} to {cls.dims.dims}.")
             except ValueError:
                 raise ValueError(f"Cannot transpose arrays to match `dims`: {dims}. Try to reshape `data` or `dims`.")
diff --git a/spatialdata/utils.py b/spatialdata/utils.py
index 0af50d44..9b2ed124 100644
--- a/spatialdata/utils.py
+++ b/spatialdata/utils.py
@@ -97,9 +97,9 @@ def unpad_raster(raster: Union[SpatialImage, MultiscaleSpatialImage]) -> Union[S
     -------
     The unpadded raster.
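
     Example
     -------
     (editor's sketch; ``img`` is a hypothetical parsed image whose border
     rows/columns are entirely zero)

     >>> unpadded = unpad_raster(img)

     The zero borders are cropped away and a ``Translation`` is prepended to the
     existing transformations so the result stays aligned (see the code below).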
""" - from spatialdata._core.models import get_schema + from spatialdata._core.core_utils import compute_coordinates, get_dims - def _unpad_axis(data: DataArray, axis: str) -> tuple[DataArray, float]: + def _compute_paddings(data: DataArray, axis: str) -> tuple[int, int]: others = list(data.dims) others.remove(axis) # mypy (luca's pycharm config) can't see the isclose method of dask array @@ -108,50 +108,73 @@ def _unpad_axis(data: DataArray, axis: str) -> tuple[DataArray, float]: x = s.compute() non_zero = np.where(x == 0)[0] if len(non_zero) == 0: - return data, 0 + min_coordinate, max_coordinate = data.coords[axis].min().item(), data.coords[axis].max().item() + if not min_coordinate != 0: + raise ValueError( + f"Expected minimum coordinate for axis {axis} to be 0, but got {min_coordinate}. Please report this bug." + ) + if max_coordinate != data.shape[data.dims.index(axis)] - 1: + raise ValueError( + f"Expected maximum coordinate for axis {axis} to be {data.shape[data.dims.index(axis)] - 1}, but got {max_coordinate}. Please report this bug." + ) + return 0, data.shape[data.dims.index(axis)] else: left_pad = non_zero[0] right_pad = non_zero[-1] + 1 - unpadded = data.isel({axis: slice(left_pad, right_pad)}) - return unpadded, left_pad - - from spatialdata._core.core_utils import get_dims + return left_pad, right_pad axes = get_dims(raster) - if isinstance(raster, SpatialImage): - unpadded = raster - translation_axes = [] - translation_values: list[float] = [] + translation_axes = [] + translation_values: list[float] = [] + unpadded = raster + + if isinstance(unpadded, SpatialImage): + for ax in axes: + if ax != "c": + left_pad, right_pad = _compute_paddings(data=unpadded, axis=ax) + unpadded = unpadded.isel({ax: slice(left_pad, right_pad)}) + translation_axes.append(ax) + translation_values.append(left_pad) + elif isinstance(unpadded, MultiscaleSpatialImage): for ax in axes: if ax != "c": - unpadded, left_pad = _unpad_axis(unpadded, axis=ax) + # let's just operate on the highest resolution. This is not an efficient implementation but we can always optimize later + d = dict(unpadded["scale0"]) + assert len(d) == 1 + xdata = d.values().__iter__().__next__() + + left_pad, right_pad = _compute_paddings(data=xdata, axis=ax) + EPS = 1e-6 + unpadded = unpadded.sel({ax: slice(left_pad, right_pad - EPS)}) translation_axes.append(ax) translation_values.append(left_pad) - translation = Translation(translation_values, axes=tuple(translation_axes)) - old_transformations = get_transformation(element=raster, get_all=True) - assert isinstance(old_transformations, dict) - for target_cs, old_transform in old_transformations.items(): - assert old_transform is not None - sequence = Sequence([translation, old_transform]) - set_transformation(element=unpadded, transformation=sequence, to_coordinate_system=target_cs) - return unpadded - elif isinstance(raster, MultiscaleSpatialImage): - # let's just operate on the highest resolution. This is not an efficient implementation but we can always optimize later - d = dict(raster["scale0"]) - assert len(d) == 1 - xdata = d.values().__iter__().__next__() - unpadded = unpad_raster(SpatialImage(xdata)) + d = {} + for k, v in unpadded.items(): + assert len(v.values()) == 1 + xdata = v.values().__iter__().__next__() + d[k] = xdata + unpadded = MultiscaleSpatialImage.from_dict(d) + # left_pad, right_pad = _compute_paddings(SpatialImage(xdata), axis=ax) # TODO: here I am using some arbitrary scalingfactors, I think that we need an automatic initialization of multiscale. 
See discussion: https://github.com/scverse/spatialdata/issues/108 # mypy thinks that the schema could be a ShapeModel, ... but it's not - if "z" in axes: - scale_factors = [2] - else: - scale_factors = [2, 2] - unpadded_multiscale = get_schema(raster).parse(unpadded, scale_factors=scale_factors) # type: ignore[call-arg] - return unpadded_multiscale + # if "z" in axes: + # scale_factors = [2] + # else: + # scale_factors = [2, 2] + # unpadded_multiscale = get_schema(unpadded).parse(unpadded, scale_factors=scale_factors) # type: ignore[call-arg] + # return compute_coordinates(unpadded_multiscale) else: raise TypeError(f"Unsupported type: {type(raster)}") + translation = Translation(translation_values, axes=tuple(translation_axes)) + old_transformations = get_transformation(element=raster, get_all=True) + assert isinstance(old_transformations, dict) + for target_cs, old_transform in old_transformations.items(): + assert old_transform is not None + sequence = Sequence([translation, old_transform]) + set_transformation(element=unpadded, transformation=sequence, to_coordinate_system=target_cs) + return compute_coordinates(unpadded) + def get_table_mapping_metadata(table: AnnData) -> dict[str, Union[Optional[Union[str, list[str]]], Optional[str]]]: """ diff --git a/tests/_core/test_models.py b/tests/_core/test_models.py index e208683e..2e4a73f3 100644 --- a/tests/_core/test_models.py +++ b/tests/_core/test_models.py @@ -52,7 +52,7 @@ RNG = default_rng() # should be set to False for pre-commit and CI; useful to set to True for are fixing/debugging tests -SHORT_TESTS = False +SHORT_TESTS = True class TestModels: diff --git a/tests/test_utils.py b/tests/test_utils.py index d48a7edd..1fc07ee1 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -46,17 +46,23 @@ def test_unpad_raster(images, labels) -> None: padded = schema.parse(padded, dims=data.dims) elif isinstance(raster, MultiscaleSpatialImage): # some arbitrary scaling factors - padded = schema.parse(padded, dims=data.dims, multiscale_factors=[2, 2]) + padded = schema.parse(padded, dims=data.dims, scale_factors=[2, 2]) else: raise ValueError(f"Unknown type: {type(raster)}") unpadded = unpad_raster(padded) if isinstance(raster, SpatialImage): - xarray.testing.assert_equal(raster, unpadded) + try: + xarray.testing.assert_equal(raster, unpadded) + except AssertionError as e: + raise e elif isinstance(raster, MultiscaleSpatialImage): d0 = dict(raster["scale0"]) assert len(d0) == 1 d1 = dict(unpadded["scale0"]) assert len(d1) == 1 - xarray.testing.assert_equal(d0.values().__iter__().__next__(), d1.values().__iter__().__next__()) + try: + xarray.testing.assert_equal(d0.values().__iter__().__next__(), d1.values().__iter__().__next__()) + except AssertionError as e: + raise e else: raise ValueError(f"Unknown type: {type(raster)}") From 0a752fe5e59bf59b03ec2b666d7b30850be89504 Mon Sep 17 00:00:00 2001 From: Luca Marconato <2664412+LucaMarconato@users.noreply.github.com> Date: Thu, 23 Feb 2023 21:42:56 +0100 Subject: [PATCH 22/24] all tests passing --- spatialdata/_core/_transform_elements.py | 73 ++++++++------ spatialdata/_core/core_utils.py | 119 +++++++++++++++-------- spatialdata/_io/read.py | 3 +- spatialdata/utils.py | 18 +--- tests/_core/test_models.py | 2 +- 5 files changed, 131 insertions(+), 84 deletions(-) diff --git a/spatialdata/_core/_transform_elements.py b/spatialdata/_core/_transform_elements.py index 5046d3eb..fb329a23 100644 --- a/spatialdata/_core/_transform_elements.py +++ b/spatialdata/_core/_transform_elements.py @@ -20,6 
+20,8 @@ from spatialdata._core.core_utils import ( DEFAULT_COORDINATE_SYSTEM, SpatialElement, + _get_scale, + compute_coordinates, get_dims, ) from spatialdata._core.models import get_schema @@ -208,14 +210,21 @@ def _(data: SpatialImage, transformation: BaseTransformation, maintain_positioni get_transformation, set_transformation, ) - from spatialdata._core.models import Labels2DModel, Labels3DModel + from spatialdata._core.models import ( + Image2DModel, + Image3DModel, + Labels2DModel, + Labels3DModel, + ) # labels need to be preserved after the resizing of the image if schema == Labels2DModel or schema == Labels3DModel: # TODO: this should work, test better kwargs = {"prefilter": False} - else: + elif schema == Image2DModel or schema == Image3DModel: kwargs = {} + else: + raise ValueError(f"Unsupported schema {schema}") axes = get_dims(data) transformed_dask, raster_translation = _transform_raster( @@ -232,6 +241,8 @@ def _(data: SpatialImage, transformation: BaseTransformation, maintain_positioni raster_translation=raster_translation, maintain_positioning=maintain_positioning, ) + transformed_data = compute_coordinates(transformed_data) + schema().validate(transformed_data) return transformed_data @@ -244,43 +255,43 @@ def _( get_transformation, set_transformation, ) - from spatialdata._core.models import Labels2DModel, Labels3DModel + from spatialdata._core.models import ( + Image2DModel, + Image3DModel, + Labels2DModel, + Labels3DModel, + ) + from spatialdata._core.transformations import BaseTransformation, Sequence # labels need to be preserved after the resizing of the image if schema == Labels2DModel or schema == Labels3DModel: # TODO: this should work, test better kwargs = {"prefilter": False} - else: + elif schema == Image2DModel or schema == Image3DModel: kwargs = {} + else: + raise ValueError(f"MultiscaleSpatialImage with schema {schema} not supported") - axes = get_dims(data) - scale0 = dict(data["scale0"]) - assert len(scale0) == 1 - scale0_data = scale0.values().__iter__().__next__() - transformed_dask, raster_translation = _transform_raster( - data=scale0_data.data, axes=scale0_data.dims, transformation=transformation, **kwargs - ) + get_dims(data) + transformed_dict = {} + for k, v in data.items(): + assert len(v) == 1 + xdata = v.values().__iter__().__next__() + + composed: BaseTransformation + if k == "scale0": + composed = transformation + else: + scale = _get_scale(xdata.attrs["transform"]) + composed = Sequence([scale, transformation, scale.inverse()]) + + transformed_dask, raster_translation = _transform_raster( + data=xdata.data, axes=xdata.dims, transformation=composed, **kwargs + ) + transformed_dict[k] = SpatialImage(transformed_dask, dims=xdata.dims, name=xdata.name) - # this code is temporary and doens't work in all cases (in particular it breaks when the data is not similar - # to a square but has sides of very different lengths). I would remove it an implement (inside the parser) - # the logic described in https://github.com/scverse/spatialdata/issues/108) - shapes = [] - for level in range(len(data)): - dims = data[f"scale{level}"].dims.values() - shape = np.array([dict(dims._mapping)[k] for k in axes if k != "c"]) - shapes.append(shape) - multiscale_factors = [] - shape0 = shapes[0] - for shape in shapes[1:]: - factors = shape0 / shape - factors - min(factors) - # assert np.allclose(almost_zero, np.zeros_like(almost_zero), rtol=2.) 
-        try:
-            multiscale_factors.append(round(factors[0]))
-        except ValueError as e:
-            raise e
     # mypy thinks that schema could be ShapesModel, PointsModel, ...
-    transformed_data = schema.parse(transformed_dask, dims=axes, scale_factors=multiscale_factors)  # type: ignore[call-arg,arg-type]
+    transformed_data = MultiscaleSpatialImage.from_dict(transformed_dict)
     old_transformations = get_transformation(data, get_all=True)
     assert isinstance(old_transformations, dict)
     set_transformation(transformed_data, old_transformations, set_all=True)
@@ -290,6 +301,8 @@ def _(
         raster_translation=raster_translation,
         maintain_positioning=maintain_positioning,
     )
+    transformed_data = compute_coordinates(transformed_data)
+    schema().validate(transformed_data)
     return transformed_data
 
 
diff --git a/spatialdata/_core/core_utils.py b/spatialdata/_core/core_utils.py
index 20c43b89..165039fd 100644
--- a/spatialdata/_core/core_utils.py
+++ b/spatialdata/_core/core_utils.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import copy
 from functools import singledispatch
 from typing import TYPE_CHECKING, Any, Optional, Union
@@ -14,6 +16,9 @@
 from spatialdata._core.transformations import BaseTransformation, Sequence
 from spatialdata._types import ArrayLike
 
+if TYPE_CHECKING:
+    from spatialdata._core.transformations import Scale
+
 SpatialElement = Union[SpatialImage, MultiscaleSpatialImage, GeoDataFrame, DaskDataFrame]
 
 __all__ = [
@@ -178,11 +183,11 @@ def _(e: MultiscaleSpatialImage, transformations: MappingToCoordinateSystem_t) -
             filtered_axes = [ax for ax in dims if ax != "c"]
             if not np.isfinite(filtered_scale_factors).all():
                 raise ValueError("Scale factors must be finite.")
-            scale = Scale(scale=filtered_scale_factors, axes=tuple(filtered_axes))
+            scale_transformation = Scale(scale=filtered_scale_factors, axes=tuple(filtered_axes))
             assert transformations is not None
             new_transformations = {}
             for k, v in transformations.items():
-                sequence: BaseTransformation = Sequence([scale, v])
+                sequence: BaseTransformation = Sequence([scale_transformation, v])
                 new_transformations[k] = sequence
             _set_transformations_xarray(xdata, new_transformations)
         else:
@@ -254,6 +259,14 @@ def get_default_coordinate_system(dims: tuple[str, ...]) -> NgffCoordinateSystem
     return NgffCoordinateSystem(name="".join(dims), axes=axes)
 
 
+def _validate_dims(dims: tuple[str, ...]) -> None:
+    for c in dims:
+        if c not in (X, Y, Z, C):
+            raise ValueError(f"Invalid dimension: {c}")
+    if dims not in [(X,), (Y,), (Z,), (C,), (X, Y), (X, Y, Z), (Y, X), (Z, Y, X), (C, Y, X), (C, Z, Y, X)]:
+        raise ValueError(f"Invalid dimensions: {dims}")
+
+
 @singledispatch
 def get_dims(e: SpatialElement) -> tuple[str, ...]:
     """
@@ -274,13 +287,33 @@ def get_dims(e: SpatialElement) -> tuple[str, ...]:
 @get_dims.register(SpatialImage)
 def _(e: SpatialImage) -> tuple[str, ...]:
     dims = e.dims
+    # dims_sizes = tuple(list(e.sizes.keys()))
+    # # we check that the following values are the same otherwise we could incur subtle bugs downstream
+    # if dims != dims_sizes:
+    #     raise ValueError(f"SpatialImage has inconsistent dimensions: {dims}, {dims_sizes}")
+    _validate_dims(dims)
     return dims  # type: ignore
 
 
 @get_dims.register(MultiscaleSpatialImage)
 def _(e: MultiscaleSpatialImage) -> tuple[str, ...]:
     if "scale0" in e:
-        return tuple(i for i in e["scale0"].dims.keys())
+        # dims_coordinates = tuple(i for i in e["scale0"].dims.keys())
+
+        assert len(e["scale0"].values()) == 1
+        xdata = e["scale0"].values().__iter__().__next__()
+        dims_data = xdata.dims
+        assert isinstance(dims_data, tuple)
+
+        # dims_sizes = tuple(list(xdata.sizes.keys()))
+
+        # # we check that all the following values are the same otherwise we could incur subtle bugs downstream
+        # if dims_coordinates != dims_data or dims_coordinates != dims_sizes:
+        #     raise ValueError(
+        #         f"MultiscaleSpatialImage has inconsistent dimensions: {dims_coordinates}, {dims_data}, {dims_sizes}"
+        #     )
+        _validate_dims(dims_data)
+        return dims_data
     else:
         raise ValueError("MultiscaleSpatialImage does not contain the scale0 key")
     # return tuple(i for i in e.dims.keys())
@@ -288,23 +321,19 @@ def _(e: MultiscaleSpatialImage) -> tuple[str, ...]:
 
 @get_dims.register(GeoDataFrame)
 def _(e: GeoDataFrame) -> tuple[str, ...]:
-    dims = (X, Y, Z)
+    all_dims = (X, Y, Z)
     n = e.geometry.iloc[0]._ndim
-    return dims[:n]
-
-
-@get_dims.register(AnnData)
-def _(e: AnnData) -> tuple[str, ...]:
-    dims = (X, Y, Z)
-    n = e.obsm["spatial"].shape[1]
-    return dims[:n]
+    dims = all_dims[:n]
+    _validate_dims(dims)
+    return dims
 
 
 @get_dims.register(DaskDataFrame)
 def _(e: DaskDataFrame) -> tuple[str, ...]:
     valid_dims = (X, Y, Z)
-    dims = [c for c in valid_dims if c in e.columns]
-    return tuple(dims)
+    dims = tuple([c for c in valid_dims if c in e.columns])
+    _validate_dims(dims)
+    return dims
 
 
 @singledispatch
@@ -334,28 +363,37 @@ def _(data: SpatialImage) -> SpatialImage:
     return data.assign_coords(coords)
 
 
+def _get_scale(transforms: dict[str, Any]) -> Scale:
+    from spatialdata._core.transformations import Scale
+
+    all_scale_vectors = []
+    all_scale_axes = []
+    for transformation in transforms.values():
+        assert isinstance(transformation, Sequence)
+        # the first transformation is the scale
+        t = transformation.transformations[0]
+        if hasattr(t, "scale"):
+            if TYPE_CHECKING:
+                assert isinstance(t.scale, np.ndarray)
+            all_scale_vectors.append(tuple(t.scale.tolist()))
+            assert isinstance(t, Scale)
+            all_scale_axes.append(tuple(t.axes))
+        else:
+            raise ValueError(f"Unsupported transformation: {t}")
+    # all the scales should be the same since they all refer to the mapping of the level of the multiscale to the
+    # base level, with respect to the intrinsic coordinate system
+    assert len(set(all_scale_vectors)) == 1
+    assert len(set(all_scale_axes)) == 1
+    scalef = np.array(all_scale_vectors[0])
+    if not np.isfinite(scalef).all():
+        raise ValueError(f"Invalid scale factor: {scalef}")
+    scale_axes = all_scale_axes[0]
+    scale = Scale(scalef, axes=scale_axes)
+    return scale
+
+
 @compute_coordinates.register(MultiscaleSpatialImage)
 def _(data: MultiscaleSpatialImage) -> MultiscaleSpatialImage:
-    def _get_scale(transforms: dict[str, Any]) -> Optional[ArrayLike]:
-        all_scale_vectors = []
-        for transformation in transforms.values():
-            assert isinstance(transformation, Sequence)
-            # the first transformation is the scale
-            t = transformation.transformations[0]
-            if hasattr(t, "scale"):
-                if TYPE_CHECKING:
-                    assert isinstance(t.scale, np.ndarray)
-                all_scale_vectors.append(tuple(t.scale.tolist()))
-            else:
-                raise ValueError(f"Unsupported transformation: {t}")
-        # all the scales should be the same since they all refer to the mapping of the level of the multiscale to the
-        # base level, with respect to the intrinsic coordinate system
-        assert len(set(all_scale_vectors)) == 1
-        scalef = np.array(all_scale_vectors[0])
-        if not np.isfinite(scalef).all():
-            raise ValueError(f"Invalid scale factor: {scalef}")
-        return scalef
-
     def _compute_coords(max_: int, scale_f: Union[int, float]) -> ArrayLike:
         return (  # type: ignore[no-any-return]
             DataArray(np.linspace(0, max_, max_, endpoint=False, dtype=np.float_))
@@ -374,12 +412,15 @@ def _compute_coords(max_: int, scale_f: Union[int, float]) -> ArrayLike:
             coords: dict[str, ArrayLike] = {d: np.arange(max_scale[d], dtype=np.float_) for d in max_scale.keys()}
             out[name] = dt[img_name].assign_coords(coords)
         else:
-            scalef = _get_scale(dt[img_name].attrs["transform"])
+            scale = _get_scale(dt[img_name].attrs["transform"])
+            scalef = scale.scale
             assert len(max_scale.keys()) == len(scalef), "Mismatch between coordinates and scales."  # type: ignore[arg-type]
-            out[name] = dt[img_name].assign_coords(
-                {k: _compute_coords(max_scale0[k], round(s)) for k, s in zip(max_scale.keys(), scalef)}  # type: ignore[arg-type]
-            )
-    return MultiscaleSpatialImage.from_dict(d=out)
+            new_coords = {k: _compute_coords(max_scale0[k], round(s)) for k, s in zip(max_scale.keys(), scalef)}  # type: ignore[arg-type]
+            out[name] = dt[img_name].assign_coords(new_coords)
+    msi = MultiscaleSpatialImage.from_dict(d=out)
+    # this is to trigger the validation of the dims
+    _ = get_dims(msi)
+    return msi
 
 
 @singledispatch
diff --git a/spatialdata/_io/read.py b/spatialdata/_io/read.py
index 48537046..80831ed8 100644
--- a/spatialdata/_io/read.py
+++ b/spatialdata/_io/read.py
@@ -1,4 +1,5 @@
 import logging
+import os
 from collections.abc import MutableMapping
 from pathlib import Path
 from typing import Any, Literal, Optional, Union
@@ -154,7 +155,7 @@ def _read_multiscale(
     # and for instance in the xenium example
     encoded_ngff_transformations = multiscales[0]["coordinateTransformations"]
     transformations = _get_transformations_from_ngff_dict(encoded_ngff_transformations)
-    name = node.metadata["name"]
+    name = os.path.basename(node.metadata["name"])
     # if image, read channels metadata
     if raster_type == "image":
         omero = multiscales[0]["omero"]
diff --git a/spatialdata/utils.py b/spatialdata/utils.py
index 9b2ed124..76164ac0 100644
--- a/spatialdata/utils.py
+++ b/spatialdata/utils.py
@@ -138,14 +138,14 @@ def _compute_paddings(data: DataArray, axis: str) -> tuple[int, int]:
     elif isinstance(unpadded, MultiscaleSpatialImage):
         for ax in axes:
             if ax != "c":
-                # let's just operate on the highest resolution. This is not an efficient implementation but we can always optimize later
+                # let's just operate on the highest resolution. This is not an efficient implementation but we can
+                # always optimize later
                 d = dict(unpadded["scale0"])
                 assert len(d) == 1
                 xdata = d.values().__iter__().__next__()
 
                 left_pad, right_pad = _compute_paddings(data=xdata, axis=ax)
-                EPS = 1e-6
-                unpadded = unpadded.sel({ax: slice(left_pad, right_pad - EPS)})
+                unpadded = unpadded.sel({ax: slice(left_pad, right_pad - 1e-6)})
                 translation_axes.append(ax)
                 translation_values.append(left_pad)
             d = {}
@@ -154,15 +154,6 @@ def _compute_paddings(data: DataArray, axis: str) -> tuple[int, int]:
         for k, v in unpadded.items():
             assert len(v.values()) == 1
             xdata = v.values().__iter__().__next__()
             d[k] = xdata
         unpadded = MultiscaleSpatialImage.from_dict(d)
-        # left_pad, right_pad = _compute_paddings(SpatialImage(xdata), axis=ax)
-        # TODO: here I am using some arbitrary scaling factors, I think that we need an automatic initialization of multiscale. See discussion: https://github.com/scverse/spatialdata/issues/108
-        # mypy thinks that the schema could be a ShapeModel, ... but it's not
-        # if "z" in axes:
-        #     scale_factors = [2]
-        # else:
-        #     scale_factors = [2, 2]
-        # unpadded_multiscale = get_schema(unpadded).parse(unpadded, scale_factors=scale_factors)  # type: ignore[call-arg]
-        # return compute_coordinates(unpadded_multiscale)
     else:
         raise TypeError(f"Unsupported type: {type(raster)}")
@@ -173,7 +164,8 @@ def _compute_paddings(data: DataArray, axis: str) -> tuple[int, int]:
         assert old_transform is not None
         sequence = Sequence([translation, old_transform])
         set_transformation(element=unpadded, transformation=sequence, to_coordinate_system=target_cs)
-    return compute_coordinates(unpadded)
+    unpadded = compute_coordinates(unpadded)
+    return unpadded
 
 
 def get_table_mapping_metadata(table: AnnData) -> dict[str, Union[Optional[Union[str, list[str]]], Optional[str]]]:
diff --git a/tests/_core/test_models.py b/tests/_core/test_models.py
index 2e4a73f3..e208683e 100644
--- a/tests/_core/test_models.py
+++ b/tests/_core/test_models.py
@@ -52,7 +52,7 @@ RNG = default_rng()
 
 
 # should be set to False for pre-commit and CI; useful to set to True when fixing/debugging tests
-SHORT_TESTS = True
+SHORT_TESTS = False
 
 
 class TestModels:

From f20582a260507d9a53df3799b91898ccf8a4e769 Mon Sep 17 00:00:00 2001
From: Luca Marconato <2664412+LucaMarconato@users.noreply.github.com>
Date: Thu, 23 Feb 2023 21:56:33 +0100
Subject: [PATCH 23/24] fixed precommit

---
 spatialdata/_compat.py      | 1 +
 spatialdata/_core/models.py | 2 ++
 2 files changed, 3 insertions(+)

diff --git a/spatialdata/_compat.py b/spatialdata/_compat.py
index b39f58ac..ae04073e 100644
--- a/spatialdata/_compat.py
+++ b/spatialdata/_compat.py
@@ -22,4 +22,5 @@ def _check_geopandas_using_shapely() -> None:
             "If you intended to use PyGEOS, set the option to False."
         ),
         UserWarning,
+        stacklevel=2,
     )
diff --git a/spatialdata/_core/models.py b/spatialdata/_core/models.py
index d4b4bdd9..49f9912b 100644
--- a/spatialdata/_core/models.py
+++ b/spatialdata/_core/models.py
@@ -248,6 +248,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(
             dims=self.dims,
             array_type=self.array_type,
+            attrs=self.attrs,
             *args,
             **kwargs,
         )
@@ -262,6 +263,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(
             dims=self.dims,
             array_type=self.array_type,
+            attrs=self.attrs,
             *args,
             **kwargs,
         )

From 7a9f140dda0e4cd2d440f35284286221c4a84b28 Mon Sep 17 00:00:00 2001
From: giovp
Date: Fri, 24 Feb 2023 12:52:17 +0100
Subject: [PATCH 24/24] remove reference to name

---
 spatialdata/_core/models.py |  2 ++
 tests/_core/test_models.py  | 17 ++++++++++++----
 tests/conftest.py           | 40 +++++++++++--------------------------
 3 files changed, 27 insertions(+), 32 deletions(-)

diff --git a/spatialdata/_core/models.py b/spatialdata/_core/models.py
index 49f9912b..8b25e11f 100644
--- a/spatialdata/_core/models.py
+++ b/spatialdata/_core/models.py
@@ -140,6 +140,8 @@ def parse(
             :class:`spatial_image.SpatialImage` or :class:`multiscale_spatial_image.MultiscaleSpatialImage`.
""" + if "name" in kwargs: + raise ValueError("The `name` argument is not (yet) supported for raster data.") # if dims is specified inside the data, get the value of dims from the data if isinstance(data, DataArray) or isinstance(data, SpatialImage): if not isinstance(data.data, DaskArray): # numpy -> dask diff --git a/tests/_core/test_models.py b/tests/_core/test_models.py index e208683e..a205f15a 100644 --- a/tests/_core/test_models.py +++ b/tests/_core/test_models.py @@ -40,6 +40,7 @@ get_schema, ) from spatialdata._core.transformations import Scale +from spatialdata._types import ArrayLike from tests._core.conftest import MULTIPOLYGON_PATH, POINT_PATH, POLYGON_PATH from tests.conftest import ( _get_images, @@ -143,9 +144,12 @@ def _passes_validation_after_io(self, model: Any, element: Any, element_type: st model.validate(element_read) @pytest.mark.parametrize("converter", [lambda _: _, from_array, DataArray, to_spatial_image]) - @pytest.mark.parametrize("model", [Image2DModel, Labels2DModel, Labels3DModel]) # TODO: Image3DModel once fixed. + @pytest.mark.parametrize("model", [Image2DModel, Labels2DModel, Labels3DModel, Image3DModel]) @pytest.mark.parametrize("permute", [True, False]) - def test_raster_schema(self, converter: Callable[..., Any], model: RasterSchema, permute: bool) -> None: + @pytest.mark.parametrize("kwargs", [None, {"name": "test"}]) + def test_raster_schema( + self, converter: Callable[..., Any], model: RasterSchema, permute: bool, kwargs: Optional[dict[str, str]] + ) -> None: dims = np.array(model.dims.dims).tolist() if permute: RNG.shuffle(dims) @@ -156,9 +160,11 @@ def test_raster_schema(self, converter: Callable[..., Any], model: RasterSchema, elif converter is to_spatial_image: converter = partial(converter, dims=model.dims.dims) if n_dims == 2: - image: np.ndarray = np.random.rand(10, 10) + image: ArrayLike = np.random.rand(10, 10) elif n_dims == 3: - image = np.random.rand(3, 10, 10) + image: ArrayLike = np.random.rand(3, 10, 10) + elif n_dims == 4: + image: ArrayLike = np.random.rand(2, 3, 10, 10) image = converter(image) self._parse_transformation_from_multiple_places(model, image) spatial_image = model.parse(image) @@ -179,6 +185,9 @@ def test_raster_schema(self, converter: Callable[..., Any], model: RasterSchema, assert set(spatial_image.shape) == set(image.shape) assert set(spatial_image.data.shape) == set(image.shape) assert spatial_image.data.dtype == image.dtype + if kwargs is not None: + with pytest.raises(ValueError): + model.parse(image, **kwargs) @pytest.mark.parametrize("model", [ShapesModel]) @pytest.mark.parametrize("path", [POLYGON_PATH, MULTIPOLYGON_PATH, POINT_PATH]) diff --git a/tests/conftest.py b/tests/conftest.py index 8164a6b4..dcb552af 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -124,29 +124,21 @@ def _get_images() -> dict[str, Union[SpatialImage, MultiscaleSpatialImage]]: out = {} dims_2d = ("c", "y", "x") dims_3d = ("z", "y", "x", "c") - out["image2d"] = Image2DModel.parse(RNG.normal(size=(3, 64, 64)), name="image2d", dims=dims_2d) - out["image2d_multiscale"] = Image2DModel.parse( - RNG.normal(size=(3, 64, 64)), name="image2d_multiscale", scale_factors=[2, 2], dims=dims_2d - ) - out["image2d_xarray"] = Image2DModel.parse( - DataArray(RNG.normal(size=(3, 64, 64)), dims=dims_2d), name="image2d_xarray", dims=None - ) + out["image2d"] = Image2DModel.parse(RNG.normal(size=(3, 64, 64)), dims=dims_2d) + out["image2d_multiscale"] = Image2DModel.parse(RNG.normal(size=(3, 64, 64)), scale_factors=[2, 2], dims=dims_2d) + 
out["image2d_xarray"] = Image2DModel.parse(DataArray(RNG.normal(size=(3, 64, 64)), dims=dims_2d), dims=None) out["image2d_multiscale_xarray"] = Image2DModel.parse( DataArray(RNG.normal(size=(3, 64, 64)), dims=dims_2d), - name="image2d_multiscale_xarray", scale_factors=[2, 4], dims=None, ) - out["image3d_numpy"] = Image3DModel.parse(RNG.normal(size=(2, 64, 64, 3)), name="image3d_numpy", dims=dims_3d) + out["image3d_numpy"] = Image3DModel.parse(RNG.normal(size=(2, 64, 64, 3)), dims=dims_3d) out["image3d_multiscale_numpy"] = Image3DModel.parse( - RNG.normal(size=(2, 64, 64, 3)), name="image3d_multiscale_numpy", scale_factors=[2], dims=dims_3d - ) - out["image3d_xarray"] = Image3DModel.parse( - DataArray(RNG.normal(size=(2, 64, 64, 3)), dims=dims_3d), name="image3d_xarray", dims=None + RNG.normal(size=(2, 64, 64, 3)), scale_factors=[2], dims=dims_3d ) + out["image3d_xarray"] = Image3DModel.parse(DataArray(RNG.normal(size=(2, 64, 64, 3)), dims=dims_3d), dims=None) out["image3d_multiscale_xarray"] = Image3DModel.parse( DataArray(RNG.normal(size=(2, 64, 64, 3)), dims=dims_3d), - name="image3d_multiscale_xarray", scale_factors=[2], dims=None, ) @@ -158,29 +150,21 @@ def _get_labels() -> dict[str, Union[SpatialImage, MultiscaleSpatialImage]]: dims_2d = ("y", "x") dims_3d = ("z", "y", "x") - out["labels2d"] = Labels2DModel.parse(RNG.normal(size=(64, 64)), name="labels2d", dims=dims_2d) - out["labels2d_multiscale"] = Labels2DModel.parse( - RNG.normal(size=(64, 64)), name="labels2d_multiscale", scale_factors=[2, 4], dims=dims_2d - ) - out["labels2d_xarray"] = Labels2DModel.parse( - DataArray(RNG.normal(size=(64, 64)), dims=dims_2d), name="labels2d_xarray", dims=None - ) + out["labels2d"] = Labels2DModel.parse(RNG.normal(size=(64, 64)), dims=dims_2d) + out["labels2d_multiscale"] = Labels2DModel.parse(RNG.normal(size=(64, 64)), scale_factors=[2, 4], dims=dims_2d) + out["labels2d_xarray"] = Labels2DModel.parse(DataArray(RNG.normal(size=(64, 64)), dims=dims_2d), dims=None) out["labels2d_multiscale_xarray"] = Labels2DModel.parse( DataArray(RNG.normal(size=(64, 64)), dims=dims_2d), - name="labels2d_multiscale_xarray", scale_factors=[2, 4], dims=None, ) - out["labels3d_numpy"] = Labels3DModel.parse(RNG.normal(size=(10, 64, 64)), name="labels3d_numpy", dims=dims_3d) + out["labels3d_numpy"] = Labels3DModel.parse(RNG.normal(size=(10, 64, 64)), dims=dims_3d) out["labels3d_multiscale_numpy"] = Labels3DModel.parse( - RNG.normal(size=(10, 64, 64)), name="labels3d_multiscale_numpy", scale_factors=[2, 4], dims=dims_3d - ) - out["labels3d_xarray"] = Labels3DModel.parse( - DataArray(RNG.normal(size=(10, 64, 64)), dims=dims_3d), name="labels3d_xarray", dims=None + RNG.normal(size=(10, 64, 64)), scale_factors=[2, 4], dims=dims_3d ) + out["labels3d_xarray"] = Labels3DModel.parse(DataArray(RNG.normal(size=(10, 64, 64)), dims=dims_3d), dims=None) out["labels3d_multiscale_xarray"] = Labels3DModel.parse( DataArray(RNG.normal(size=(10, 64, 64)), dims=dims_3d), - name="labels3d_multiscale_xarray", scale_factors=[2, 4], dims=None, )