From c9e7d99e063c9da247529aa70a0a68001e3e436e Mon Sep 17 00:00:00 2001 From: Giovanni Palla <25887487+giovp@users.noreply.github.com> Date: Thu, 14 Dec 2023 16:14:18 +0100 Subject: [PATCH] refactor data loader (#299) Co-authored-by: LucaMarconato <2664412+LucaMarconato@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Luca Marconato --- .readthedocs.yaml | 2 +- docs/_templates/autosummary/class.rst | 8 - docs/api.md | 14 +- pyproject.toml | 1 - src/spatialdata/_core/concatenate.py | 8 +- src/spatialdata/_core/data_extent.py | 64 ++- src/spatialdata/_core/operations/aggregate.py | 7 +- src/spatialdata/_core/operations/rasterize.py | 2 - .../_core/query/relational_query.py | 6 +- src/spatialdata/_core/query/spatial_query.py | 1 - src/spatialdata/_core/spatialdata.py | 30 +- src/spatialdata/_io/_utils.py | 8 +- src/spatialdata/_io/io_zarr.py | 2 +- src/spatialdata/_utils.py | 5 +- src/spatialdata/dataloader/__init__.py | 8 +- src/spatialdata/dataloader/datasets.py | 531 +++++++++++++----- src/spatialdata/datasets.py | 2 +- src/spatialdata/models/_utils.py | 1 - src/spatialdata/models/models.py | 8 +- src/spatialdata/transformations/operations.py | 2 +- tests/conftest.py | 2 +- tests/core/operations/test_aggregations.py | 3 +- .../operations/test_spatialdata_operations.py | 2 +- tests/core/operations/test_transform.py | 3 +- tests/core/query/test_spatial_query.py | 2 +- tests/dataloader/test_datasets.py | 151 +++-- tests/dataloader/test_transforms.py | 0 tests/io/test_readwrite.py | 2 +- tests/models/test_models.py | 2 +- 29 files changed, 573 insertions(+), 304 deletions(-) delete mode 100644 tests/dataloader/test_transforms.py diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 690bf115..b59dfb7b 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -6,7 +6,7 @@ build: python: "3.10" sphinx: configuration: docs/conf.py - fail_on_warning: false + fail_on_warning: true python: install: - method: pip diff --git a/docs/_templates/autosummary/class.rst b/docs/_templates/autosummary/class.rst index d4668a41..e4665dfc 100644 --- a/docs/_templates/autosummary/class.rst +++ b/docs/_templates/autosummary/class.rst @@ -12,11 +12,8 @@ Attributes table ~~~~~~~~~~~~~~~~~~ .. autosummary:: - {% for item in attributes %} - ~{{ fullname }}.{{ item }} - {%- endfor %} {% endif %} {% endblock %} @@ -27,13 +24,10 @@ Methods table ~~~~~~~~~~~~~ .. autosummary:: - {% for item in methods %} - {%- if item != '__init__' %} ~{{ fullname }}.{{ item }} {%- endif -%} - {%- endfor %} {% endif %} {% endblock %} @@ -46,7 +40,6 @@ Attributes {% for item in attributes %} .. autoattribute:: {{ [objname, item] | join(".") }} - {%- endfor %} {% endif %} @@ -61,7 +54,6 @@ Methods {%- if item != '__init__' %} .. automethod:: {{ [objname, item] | join(".") }} - {%- endif -%} {%- endfor %} diff --git a/docs/api.md b/docs/api.md index 696c92e0..9034b0d9 100644 --- a/docs/api.md +++ b/docs/api.md @@ -30,11 +30,10 @@ Operations on `SpatialData` objects. match_table_to_element concatenate rasterize - transform aggregate ``` -### Utilities +### Operations Utilities ```{eval-rst} .. autosummary:: @@ -49,6 +48,7 @@ The elements (building-blocks) that consitute `SpatialData`. ```{eval-rst} .. currentmodule:: spatialdata.models + .. autosummary:: :toctree: generated @@ -61,9 +61,11 @@ The elements (building-blocks) that consitute `SpatialData`. TableModel ``` -### Utilities +### Models Utilities ```{eval-rst} +.. currentmodule:: spatialdata.models + .. autosummary:: :toctree: generated @@ -94,9 +96,11 @@ The transformations that can be defined between elements and coordinate systems Sequence ``` -### Utilities +### Transformations Utilities ```{eval-rst} +.. currentmodule:: spatialdata.transformations + .. autosummary:: :toctree: generated @@ -119,7 +123,7 @@ The transformations that can be defined between elements and coordinate systems ImageTilesDataset ``` -## Input/output +## Input/Output ```{eval-rst} .. currentmodule:: spatialdata diff --git a/pyproject.toml b/pyproject.toml index acd3e191..0904668a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,6 @@ dev = [ docs = [ "sphinx>=4.5", "sphinx-book-theme>=1.0.0", - "sphinx_rtd_theme", "myst-nb", "sphinxcontrib-bibtex>=1.0.0", "sphinx-autodoc-typehints", diff --git a/src/spatialdata/_core/concatenate.py b/src/spatialdata/_core/concatenate.py index 77f82c53..eadc1d6b 100644 --- a/src/spatialdata/_core/concatenate.py +++ b/src/spatialdata/_core/concatenate.py @@ -2,14 +2,12 @@ from copy import copy # Should probably go up at the top from itertools import chain -from typing import TYPE_CHECKING, Any +from typing import Any import numpy as np from anndata import AnnData -if TYPE_CHECKING: - from spatialdata._core.spatialdata import SpatialData - +from spatialdata._core.spatialdata import SpatialData from spatialdata.models import TableModel __all__ = [ @@ -94,8 +92,6 @@ def concatenate( ------- The concatenated :class:`spatialdata.SpatialData` object. """ - from spatialdata import SpatialData - merged_images = {**{k: v for sdata in sdatas for k, v in sdata.images.items()}} if len(merged_images) != np.sum([len(sdata.images) for sdata in sdatas]): raise KeyError("Images must have unique names across the SpatialData objects to concatenate") diff --git a/src/spatialdata/_core/data_extent.py b/src/spatialdata/_core/data_extent.py index 251a9e7b..3947fe5f 100644 --- a/src/spatialdata/_core/data_extent.py +++ b/src/spatialdata/_core/data_extent.py @@ -115,9 +115,9 @@ def get_extent( has_labels: bool = True, has_points: bool = True, has_shapes: bool = True, - # python 3.9 tests fail if we don't use Union here, see - # https://github.com/scverse/spatialdata/pull/318#issuecomment-1755714287 - elements: Union[list[str], None] = None, # noqa: UP007 + elements: Union[ # noqa: UP007 # https://github.com/scverse/spatialdata/pull/318#issuecomment-1755714287 + list[str], None + ] = None, ) -> BoundingBoxDescription: """ Get the extent (bounding box) of a SpatialData object or a SpatialElement. @@ -129,43 +129,50 @@ def get_extent( Returns ------- + The bounding box description. + min_coordinate The minimum coordinate of the bounding box. max_coordinate The maximum coordinate of the bounding box. axes - The names of the dimensions of the bounding box + The names of the dimensions of the bounding box. exact - If True, the extent is computed exactly. If False, an approximation faster to compute is given. The - approximation is guaranteed to contain all the data, see notes for details. + Whether the extent is computed exactly or not. + + - If `True`, the extent is computed exactly. + - If `False`, an approximation faster to compute is given. + + The approximation is guaranteed to contain all the data, see notes for details. has_images - If True, images are included in the computation of the extent. + If `True`, images are included in the computation of the extent. has_labels - If True, labels are included in the computation of the extent. + If `True`, labels are included in the computation of the extent. has_points - If True, points are included in the computation of the extent. + If `True`, points are included in the computation of the extent. has_shapes - If True, shapes are included in the computation of the extent. + If `True`, shapes are included in the computation of the extent. elements - If not None, only the elements with the given names are included in the computation of the extent. + If not `None`, only the elements with the given names are included in the computation of the extent. Notes ----- - The extent of a SpatialData object is the extent of the union of the extents of all its elements. The extent of a - SpatialElement is the extent of the element in the coordinate system specified by the argument `coordinate_system`. + The extent of a `SpatialData` object is the extent of the union of the extents of all its elements. + The extent of a `SpatialElement` is the extent of the element in the coordinate system + specified by the argument `coordinate_system`. - If `exact` is False, first the extent of the SpatialElement before any transformation is computed. Then, the extent - is transformed to the target coordinate system. This is faster than computing the extent after the transformation, - since the transformation is applied to extent of the untransformed data, as opposed to transforming the data and - then computing the extent. + If `exact` is `False`, first the extent of the `SpatialElement` before any transformation is computed. + Then, the extent is transformed to the target coordinate system. This is faster than computing the extent + after the transformation, since the transformation is applied to extent of the untransformed data, + as opposed to transforming the data and then computing the extent. - The exact and approximate extent are the same if the transformation doesn't contain any rotation or shear, or in the - case in which the transformation is affine but all the corners of the extent of the untransformed data + The exact and approximate extent are the same if the transformation does not contain any rotation or shear, or in + the case in which the transformation is affine but all the corners of the extent of the untransformed data (bounding box corners) are part of the dataset itself. Note that this is always the case for raster data. - An extreme case is a dataset composed of the two points (0, 0) and (1, 1), rotated anticlockwise by 45 degrees. The - exact extent is the bounding box [minx, miny, maxx, maxy] = [0, 0, 0, 1.414], while the approximate extent is the - box [minx, miny, maxx, maxy] = [-0.707, 0, 0.707, 1.414]. + An extreme case is a dataset composed of the two points `(0, 0)` and `(1, 1)`, rotated anticlockwise by 45 degrees. + The exact extent is the bounding box `[minx, miny, maxx, maxy] = [0, 0, 0, 1.414]`, while the approximate extent is + the box `[minx, miny, maxx, maxy] = [-0.707, 0, 0.707, 1.414]`. """ raise ValueError("The object type is not supported.") @@ -184,7 +191,9 @@ def _( elements: Union[list[str], None] = None, # noqa: UP007 ) -> BoundingBoxDescription: """ - Get the extent (bounding box) of a SpatialData object: the extent of the union of the extents of all its elements. + Get the extent (bounding box) of a SpatialData object. + + The resulting extent is the union of the extents of all its elements. Parameters ---------- @@ -259,7 +268,14 @@ def _get_extent_of_shapes(e: GeoDataFrame) -> BoundingBoxDescription: @get_extent.register def _(e: GeoDataFrame, coordinate_system: str = "global", exact: bool = True) -> BoundingBoxDescription: """ - Compute the extent (bounding box) of a set of shapes. + Get the extent (bounding box) of a SpatialData object. + + The resulting extent is the union of the extents of all its elements. + + Parameters + ---------- + e + The SpatialData object. Returns ------- diff --git a/src/spatialdata/_core/operations/aggregate.py b/src/spatialdata/_core/operations/aggregate.py index 9881dc7b..e51c848c 100644 --- a/src/spatialdata/_core/operations/aggregate.py +++ b/src/spatialdata/_core/operations/aggregate.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Any +from typing import Any import anndata as ad import dask as da @@ -20,6 +20,7 @@ from spatialdata._core.operations.transform import transform from spatialdata._core.query._utils import circles_to_polygons from spatialdata._core.query.relational_query import get_values +from spatialdata._core.spatialdata import SpatialData from spatialdata._types import ArrayLike from spatialdata.models import ( Image2DModel, @@ -32,9 +33,6 @@ ) from spatialdata.transformations import BaseTransformation, Identity, get_transformation -if TYPE_CHECKING: - from spatialdata import SpatialData - __all__ = ["aggregate"] @@ -236,7 +234,6 @@ def _create_sdata_from_table_and_shapes( instance_key: str, deepcopy: bool, ) -> SpatialData: - from spatialdata import SpatialData from spatialdata._utils import _deepcopy_geodataframe table.obs[instance_key] = table.obs_names.copy() diff --git a/src/spatialdata/_core/operations/rasterize.py b/src/spatialdata/_core/operations/rasterize.py index d1a30c46..e3ec4774 100644 --- a/src/spatialdata/_core/operations/rasterize.py +++ b/src/spatialdata/_core/operations/rasterize.py @@ -207,8 +207,6 @@ def _( target_height: Optional[float] = None, target_depth: Optional[float] = None, ) -> SpatialData: - from spatialdata import SpatialData - min_coordinate = _parse_list_into_array(min_coordinate) max_coordinate = _parse_list_into_array(max_coordinate) diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index 14d6e88c..d75897e7 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import Any import dask.array as da import numpy as np @@ -11,6 +11,7 @@ from multiscale_spatial_image import MultiscaleSpatialImage from spatial_image import SpatialImage +from spatialdata._core.spatialdata import SpatialData from spatialdata._utils import _inplace_fix_subset_categorical_obs from spatialdata.models import ( Labels2DModel, @@ -22,9 +23,6 @@ get_model, ) -if TYPE_CHECKING: - from spatialdata import SpatialData - def _filter_table_by_coordinate_system(table: AnnData | None, coordinate_system: str | list[str]) -> AnnData | None: """ diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index 2cbd02f3..e7be09bc 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -248,7 +248,6 @@ def _( target_coordinate_system: str, filter_table: bool = True, ) -> SpatialData: - from spatialdata import SpatialData from spatialdata._core.query.relational_query import _filter_table_by_elements min_coordinate = _parse_list_into_array(min_coordinate) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index 57e06339..36f1692b 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -18,27 +18,17 @@ from ome_zarr.types import JSONDict from spatial_image import SpatialImage -from spatialdata._io import ( - write_image, - write_labels, - write_points, - write_shapes, - write_table, -) -from spatialdata._io._utils import get_backing_files from spatialdata._logging import logger from spatialdata._types import ArrayLike -from spatialdata._utils import _natural_keys -from spatialdata.models import ( +from spatialdata.models._utils import SpatialElement, get_axes_names +from spatialdata.models.models import ( Image2DModel, Image3DModel, Labels2DModel, Labels3DModel, PointsModel, ShapesModel, - SpatialElement, TableModel, - get_axes_names, get_model, ) @@ -709,6 +699,9 @@ def add_image( ----- If the SpatialData object is backed by a Zarr storage, the image will be written to the Zarr storage. """ + from spatialdata._io._utils import get_backing_files + from spatialdata._io.io_raster import write_image + if self.is_backed(): files = get_backing_files(image) assert self.path is not None @@ -792,6 +785,9 @@ def add_labels( ----- If the SpatialData object is backed by a Zarr storage, the image will be written to the Zarr storage. """ + from spatialdata._io._utils import get_backing_files + from spatialdata._io.io_raster import write_labels + if self.is_backed(): files = get_backing_files(labels) assert self.path is not None @@ -876,6 +872,9 @@ def add_points( ----- If the SpatialData object is backed by a Zarr storage, the image will be written to the Zarr storage. """ + from spatialdata._io._utils import get_backing_files + from spatialdata._io.io_points import write_points + if self.is_backed(): files = get_backing_files(points) assert self.path is not None @@ -958,6 +957,8 @@ def add_shapes( ----- If the SpatialData object is backed by a Zarr storage, the image will be written to the Zarr storage. """ + from spatialdata._io.io_shapes import write_shapes + self._add_shapes_in_memory(name=name, shapes=shapes, overwrite=overwrite) if self.is_backed(): elem_group = self._init_add_element(name=name, element_type="shapes", overwrite=overwrite) @@ -975,6 +976,8 @@ def write( overwrite: bool = False, consolidate_metadata: bool = True, ) -> None: + from spatialdata._io import write_image, write_labels, write_points, write_shapes, write_table + """Write the SpatialData object to Zarr.""" if isinstance(file_path, str): file_path = Path(file_path) @@ -1176,6 +1179,8 @@ def table(self, table: AnnData) -> None: The table needs to pass validation (see :class:`~spatialdata.TableModel`). If the SpatialData object is backed by a Zarr storage, the table will be written to the Zarr storage. """ + from spatialdata._io.io_table import write_table + TableModel().validate(table) if self.table is not None: raise ValueError("The table already exists. Use del sdata.table to remove it first.") @@ -1276,6 +1281,7 @@ def _gen_repr( ------- The string representation of the SpatialData object. """ + from spatialdata._utils import _natural_keys def rreplace(s: str, old: str, new: str, occurrence: int) -> str: li = s.rsplit(old, occurrence) diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py index bfa12721..c2d44114 100644 --- a/src/spatialdata/_io/_utils.py +++ b/src/spatialdata/_io/_utils.py @@ -8,7 +8,7 @@ from collections.abc import Generator, Mapping from contextlib import contextmanager from functools import singledispatch -from typing import TYPE_CHECKING, Any +from typing import Any import zarr from dask.dataframe.core import DataFrame as DaskDataFrame @@ -18,6 +18,7 @@ from spatial_image import SpatialImage from xarray import DataArray +from spatialdata._core.spatialdata import SpatialData from spatialdata._utils import iterate_pyramid_levels from spatialdata.models._utils import ( MappingToCoordinateSystem_t, @@ -30,9 +31,6 @@ _get_current_output_axes, ) -if TYPE_CHECKING: - from spatialdata import SpatialData - # suppress logger debug from ome_zarr with context manager @contextmanager @@ -196,8 +194,6 @@ def _are_directories_identical( def _compare_sdata_on_disk(a: SpatialData, b: SpatialData) -> bool: - from spatialdata import SpatialData - if not isinstance(a, SpatialData) or not isinstance(b, SpatialData): return False # TODO: if the sdata object is backed on disk, don't create a new zarr file diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index 7b4f286c..b1926822 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -9,7 +9,7 @@ from anndata import read_zarr as read_anndata_zarr from anndata.experimental import read_elem -from spatialdata import SpatialData +from spatialdata._core.spatialdata import SpatialData from spatialdata._io._utils import ome_zarr_logger from spatialdata._io.io_points import _read_points from spatialdata._io.io_raster import _read_multiscale diff --git a/src/spatialdata/_utils.py b/src/spatialdata/_utils.py index 59eaec6c..42e02264 100644 --- a/src/spatialdata/_utils.py +++ b/src/spatialdata/_utils.py @@ -3,7 +3,7 @@ import re from collections.abc import Generator from copy import deepcopy -from typing import TYPE_CHECKING, Union +from typing import Union import numpy as np import pandas as pd @@ -26,9 +26,6 @@ # I was using "from numbers import Number" but this led to mypy errors, so I switched to the following: Number = Union[int, float] -if TYPE_CHECKING: - pass - def _parse_list_into_array(array: list[Number] | ArrayLike) -> ArrayLike: if isinstance(array, list): diff --git a/src/spatialdata/dataloader/__init__.py b/src/spatialdata/dataloader/__init__.py index f9262f85..819ab58e 100644 --- a/src/spatialdata/dataloader/__init__.py +++ b/src/spatialdata/dataloader/__init__.py @@ -1,6 +1,4 @@ -import contextlib - -with contextlib.suppress(ImportError): +try: from spatialdata.dataloader.datasets import ImageTilesDataset - -__all__ = ["ImageTilesDataset"] +except ImportError: + ImageTilesDataset = None # type: ignore[assignment, misc] diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index e31750f2..388db612 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -1,15 +1,20 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable +from collections.abc import Mapping +from functools import partial +from itertools import chain +from types import MappingProxyType +from typing import Any, Callable import numpy as np +import pandas as pd +from anndata import AnnData from geopandas import GeoDataFrame from multiscale_spatial_image import MultiscaleSpatialImage -from shapely import MultiPolygon, Point, Polygon -from spatial_image import SpatialImage +from scipy.sparse import issparse from torch.utils.data import Dataset -from spatialdata._core.operations.rasterize import rasterize +from spatialdata._core.spatialdata import SpatialData from spatialdata._utils import _affine_matrix_multiplication from spatialdata.models import ( Image2DModel, @@ -17,175 +22,389 @@ Labels2DModel, Labels3DModel, ShapesModel, + TableModel, get_axes_names, get_model, ) -from spatialdata.transformations.operations import get_transformation +from spatialdata.transformations import get_transformation from spatialdata.transformations.transformations import BaseTransformation -if TYPE_CHECKING: - from spatialdata import SpatialData +__all__ = ["ImageTilesDataset"] class ImageTilesDataset(Dataset): + """ + Dataloader for SpatialData. + + :class:`torch.utils.data.Dataset` for loading tiles from a :class:`spatialdata.SpatialData` object. + + By default, the dataset returns spatialdata object, but when `return_image` and `return_annot` + are set, the dataset returns a tuple containing: + + - the tile image, centered in the target coordinate system of the region. + - a vector or scalar value from the table. + + Parameters + ---------- + sdata + The SpatialData object. + regions_to_images + A mapping between region and images. The regions are used to compute the tile centers, while the images are + used to get the pixel values. + regions_to_coordinate_systems + A mapping between regions and coordinate systems. The coordinate systems are used to transform both + regions coordinates for tiles as well as images. + tile_scale + The scale of the tiles. This is used only if the `regions` are `shapes`. + It is a scaling factor applied to either the radius (spots) or length (polygons) of the `shapes` + according to the geometry type of the `shapes` element: + + - if `shapes` are circles (spots), the radius is scaled by `tile_scale`. + - if `shapes` are polygons, the length of the polygon is scaled by `tile_scale`. + + If `tile_dim_in_units` is passed, `tile_scale` is ignored. + tile_dim_in_units + The dimension of the requested tile in the units of the target coordinate system. + This specifies the extent of the tile. This is not related the size in pixel of each returned tile. + rasterize + If True, the images are rasterized using :func:`spatialdata.rasterize`. + If False, they are queried using :func:`spatialdata.bounding_box_query`. + return_annotations + If not None, a value from the table is returned together with the image tile. + Only columns in :attr:`anndata.AnnData.obs` and :attr:`anndata.AnnData.X` + can be returned. If None, it will return a `SpatialData` object with only the tuple + containing the image and the table value. + transform + A callable that takes as input the tuple (image, table_value) and returns a new tuple (when + `return_annotations` is not None); a callable that takes as input the `SpatialData` object and + returns a tuple when `return_annotations` is `None`. + This parameter can be used to apply data transformations (for instance a normalization operation) to the + image and the table value. + rasterize_kwargs + Keyword arguments passed to :func:`spatialdata.rasterize` if `rasterize` is True. + This argument can be used for instance to choose the pixel dimension of the image tile. + + Returns + ------- + :class:`torch.utils.data.Dataset` for loading tiles from a :class:`spatialdata.SpatialData`. + """ + + INSTANCE_KEY = "instance_id" + CS_KEY = "cs" + REGION_KEY = "region" + IMAGE_KEY = "image" + def __init__( self, sdata: SpatialData, regions_to_images: dict[str, str], - tile_dim_in_units: float, - tile_dim_in_pixels: int, - target_coordinate_system: str = "global", - # unused at the moment, see - transform: Callable[[SpatialData], Any] | None = None, + regions_to_coordinate_systems: dict[str, str], + tile_scale: float = 1.0, + tile_dim_in_units: float | None = None, + rasterize: bool = False, + return_annotations: str | list[str] | None = None, + transform: Callable[[Any], Any] | None = None, + rasterize_kwargs: Mapping[str, Any] = MappingProxyType({}), ): - """ - Torch Dataset that returns image tiles around regions from a SpatialData object. - - Parameters - ---------- - sdata - The SpatialData object containing the regions and images from which to extract the tiles from. - regions_to_images - A dictionary mapping the regions element key we want to extract the tiles around to the images element key - we want to get the image data from. - tile_dim_in_units - The dimension of the requested tile in the units of the target coordinate system. This specifies the extent - of the image each tile is querying. This is not related he size in pixel of each returned tile. - tile_dim_in_pixels - The dimension of the requested tile in pixels. This specifies the size of the output tiles that we will get, - independently of which extent of the image the tile is covering. - target_coordinate_system - The coordinate system in which the tile_dim_in_units is specified. - """ - # TODO: we can extend this code to support: - # - automatic dermination of the tile_dim_in_pixels to match the image resolution (prevent down/upscaling) - # - use the bounding box query instead of the raster function if the user wants - self.sdata = sdata - self.regions_to_images = regions_to_images - self.tile_dim_in_units = tile_dim_in_units - self.tile_dim_in_pixels = tile_dim_in_pixels + from spatialdata import bounding_box_query + from spatialdata._core.operations.rasterize import rasterize as rasterize_fn + + self._validate(sdata, regions_to_images, regions_to_coordinate_systems) + self._preprocess(tile_scale, tile_dim_in_units) + + self._crop_image: Callable[..., Any] = ( + partial( + rasterize_fn, + **dict(rasterize_kwargs), + ) + if rasterize + else bounding_box_query # type: ignore[assignment] + ) + self._return = self._get_return(return_annotations) self.transform = transform - self.target_coordinate_system = target_coordinate_system - - self.n_spots_dict = self._compute_n_spots_dict() - self.n_spots = sum(self.n_spots_dict.values()) - - def _validate_regions_to_images(self) -> None: - for region_key, image_key in self.regions_to_images.items(): - regions_element = self.sdata[region_key] - images_element = self.sdata[image_key] - # we could allow also for points - if get_model(regions_element) not in [ShapesModel, Labels2DModel, Labels3DModel]: - raise ValueError("regions_element must be a shapes element or a labels element") - if get_model(images_element) not in [Image2DModel, Image3DModel]: - raise ValueError("images_element must be an image element") - - def _compute_n_spots_dict(self) -> dict[str, int]: - n_spots_dict = {} - for region_key in self.regions_to_images: - element = self.sdata[region_key] - # we could allow also points - if isinstance(element, GeoDataFrame): - n_spots_dict[region_key] = len(element) - elif isinstance(element, (SpatialImage, MultiscaleSpatialImage)): - raise NotImplementedError("labels not supported yet") - else: - raise ValueError("element must be a geodataframe or a spatial image") - return n_spots_dict - - def _get_region_info_for_index(self, index: int) -> tuple[str, int]: - # TODO: this implmenetation can be improved - i = 0 - for region_key, n_spots in self.n_spots_dict.items(): - if index < i + n_spots: - return region_key, index - i - i += n_spots - raise ValueError(f"index {index} is out of range") - def __len__(self) -> int: - return self.n_spots + def _validate( + self, + sdata: SpatialData, + regions_to_images: dict[str, str], + regions_to_coordinate_systems: dict[str, str], + ) -> None: + """Validate input parameters.""" + self._region_key = sdata.table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY] + self._instance_key = sdata.table.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY] + available_regions = sdata.table.obs[self._region_key].cat.categories + cs_region_image = [] # list of tuples (coordinate_system, region, image) - def __getitem__(self, idx: int) -> Any | SpatialData: - from spatialdata import SpatialData - - if idx >= self.n_spots: - raise IndexError() - regions_name, region_index = self._get_region_info_for_index(idx) - regions = self.sdata[regions_name] - # TODO: here we just need to compute the centroids, - # we probably want to move this functionality to a different file - if isinstance(regions, GeoDataFrame): - dims = get_axes_names(regions) - region = regions.iloc[region_index] - shape = regions.geometry.iloc[0] - if isinstance(shape, Polygon): - xy = region.geometry.centroid.coords.xy - centroid = np.array([[xy[0][0], xy[1][0]]]) - elif isinstance(shape, MultiPolygon): - raise NotImplementedError("MultiPolygon not supported yet") - elif isinstance(shape, Point): - xy = region.geometry.coords.xy - centroid = np.array([[xy[0][0], xy[1][0]]]) - else: - raise RuntimeError(f"Unsupported type: {type(shape)}") - - t = get_transformation(regions, self.target_coordinate_system) + # check unique matching between regions and images and coordinate systems + assert len(set(regions_to_images.values())) == len( + regions_to_images.keys() + ), "One region cannot be paired to multiple images." + assert len(set(regions_to_coordinate_systems.values())) == len( + regions_to_coordinate_systems.keys() + ), "One region cannot be paired to multiple coordinate systems." + + for region_key, image_key in regions_to_images.items(): + # get elements + region_elem = sdata[region_key] + image_elem = sdata[image_key] + + # check that the elements are supported + if get_model(region_elem) in [Labels2DModel, Labels3DModel]: + raise NotImplementedError("labels elements are not implemented yet.") + if get_model(region_elem) not in [ShapesModel]: + raise ValueError("`regions_element` must be a shapes element.") + if get_model(image_elem) not in [Image2DModel, Image3DModel]: + raise ValueError("`images_element` must be an image element.") + if isinstance(image_elem, MultiscaleSpatialImage): + raise NotImplementedError("Multiscale images are not implemented yet.") + + if region_key not in available_regions: + raise ValueError(f"region {region_key} not found in the spatialdata object.") + + # check that the coordinate systems are valid for the elements + try: + cs = regions_to_coordinate_systems[region_key] + region_trans = get_transformation(region_elem, cs) + image_trans = get_transformation(image_elem, cs) + if isinstance(region_trans, BaseTransformation) and isinstance(image_trans, BaseTransformation): + cs_region_image.append((cs, region_key, image_key)) + except KeyError as e: + raise KeyError(f"region {region_key} not found in `regions_to_coordinate_systems`") from e + + self.regions = list(regions_to_coordinate_systems.keys()) # all regions for the dataloader + self.sdata = sdata + self.dataset_table = self.sdata.table[ + self.sdata.table.obs[self._region_key].isin(self.regions) + ] # filtered table for the data loader + self._cs_region_image = tuple(cs_region_image) # tuple of tuples (coordinate_system, region_key, image_key) + + def _preprocess( + self, + tile_scale: float = 1.0, + tile_dim_in_units: float | None = None, + ) -> None: + """Preprocess the dataset.""" + index_df = [] + tile_coords_df = [] + dims_l = [] + shapes_l = [] + + for cs, region, image in self._cs_region_image: + # get dims and transformations for the region element + dims = get_axes_names(self.sdata[region]) + dims_l.append(dims) + t = get_transformation(self.sdata[region], cs) assert isinstance(t, BaseTransformation) - aff = t.to_affine_matrix(input_axes=dims, output_axes=dims) - transformed_centroid = np.squeeze(_affine_matrix_multiplication(aff, centroid), 0) - elif isinstance(regions, (SpatialImage, MultiscaleSpatialImage)): - raise NotImplementedError("labels not supported yet") - else: - raise ValueError("element must be shapes or labels") - min_coordinate = np.array(transformed_centroid) - self.tile_dim_in_units / 2 - max_coordinate = np.array(transformed_centroid) + self.tile_dim_in_units / 2 - - raster = self.sdata[self.regions_to_images[regions_name]] - tile = rasterize( - raster, - axes=dims, - min_coordinate=min_coordinate, - max_coordinate=max_coordinate, - target_coordinate_system=self.target_coordinate_system, - target_width=self.tile_dim_in_pixels, + + # get instances from region + inst = self.sdata.table.obs[self.sdata.table.obs[self._region_key] == region][self._instance_key].values + + # subset the regions by instances + subset_region = self.sdata[region].iloc[inst] + # get coordinates of centroids and extent for tiles + tile_coords = _get_tile_coords(subset_region, t, dims, tile_scale, tile_dim_in_units) + tile_coords_df.append(tile_coords) + + # get shapes + shapes_l.append(self.sdata[region]) + + # get index dictionary, with `instance_id`, `cs`, `region`, and `image` + df = pd.DataFrame({self.INSTANCE_KEY: inst}) + df[self.CS_KEY] = cs + df[self.REGION_KEY] = region + df[self.IMAGE_KEY] = image + index_df.append(df) + + # concatenate and assign to self + self.dataset_index = pd.concat(index_df).reset_index(drop=True) + self.tiles_coords = pd.concat(tile_coords_df).reset_index(drop=True) + # get table filtered by regions + self.filtered_table = self.sdata.table.obs[self.sdata.table.obs[self._region_key].isin(self.regions)] + + assert len(self.tiles_coords) == len(self.dataset_index) + dims_ = set(chain(*dims_l)) + assert np.all([i in self.tiles_coords for i in dims_]) + self.dims = list(dims_) + + def _get_return( + self, + return_annot: str | list[str] | None, + ) -> Callable[[int, Any], tuple[Any, Any] | SpatialData]: + """Get function to return values from the table of the dataset.""" + if return_annot is not None: + # table is always returned as array shape (1, len(return_annot)) + # where return_table can be a single column or a list of columns + return_annot = [return_annot] if isinstance(return_annot, str) else return_annot + # return tuple of (tile, table) + if np.all([i in self.dataset_table.obs for i in return_annot]): + return lambda x, tile: (tile, self.dataset_table.obs[return_annot].iloc[x].values.reshape(1, -1)) + if np.all([i in self.dataset_table.var_names for i in return_annot]): + if issparse(self.dataset_table.X): + return lambda x, tile: (tile, self.dataset_table[x, return_annot].X.A) + return lambda x, tile: (tile, self.dataset_table[x, return_annot].X) + raise ValueError( + f"`return_annot` must be a column name in the table or a variable name in the table. " + f"Got {return_annot}." + ) + # return spatialdata consisting of the image tile and the associated table + return lambda x, tile: SpatialData( + images={self.dataset_index.iloc[x][self.IMAGE_KEY]: tile}, + table=self.dataset_table[x], ) - tile_regions = regions.iloc[region_index : region_index + 1] - # TODO: as explained in the TODO in the __init__(), we want to let the - # user also use the bounding box query instaed of the rasterization - # the return function of this function would change, so we need to - # decide if instead having an extra Tile dataset class - # from spatialdata._core._spatial_query import BoundingBoxRequest - # request = BoundingBoxRequest( - # target_coordinate_system=self.target_coordinate_system, - # axes=dims, - # min_coordinate=min_coordinate, - # max_coordinate=max_coordinate, - # ) - # sdata_item = self.sdata.query.bounding_box(**request.to_dict()) - table = self.sdata.table - filter_table = False - if table is not None: - region = table.uns["spatialdata_attrs"]["region"] - region_key = table.uns["spatialdata_attrs"]["region_key"] - instance_key = table.uns["spatialdata_attrs"]["instance_key"] - if isinstance(region, str): - if regions_name == region: - filter_table = True - elif isinstance(region, list): - if regions_name in region: - filter_table = True - else: - raise ValueError("region must be a string or a list of strings") - # TODO: maybe slow, we should check if there is a better way to do this - if filter_table: - instance = self.sdata[regions_name].iloc[region_index].name - row = table[(table.obs[region_key] == regions_name) & (table.obs[instance_key] == instance)].copy() - tile_table = row - else: - tile_table = None - tile_sdata = SpatialData( - images={self.regions_to_images[regions_name]: tile}, shapes={regions_name: tile_regions}, table=tile_table + + def __len__(self) -> int: + return len(self.dataset_index) + + def __getitem__(self, idx: int) -> Any | SpatialData: + """Get item from the dataset.""" + # get the row from the index + row = self.dataset_index.iloc[idx] + # get the tile coordinates + t_coords = self.tiles_coords.iloc[idx] + + image = self.sdata[row["image"]] + tile = self._crop_image( + image, + axes=self.dims, + min_coordinate=t_coords[[f"min{i}" for i in self.dims]].values, + max_coordinate=t_coords[[f"max{i}" for i in self.dims]].values, + target_coordinate_system=row["cs"], ) + if self.transform is not None: - return self.transform(tile_sdata) - return tile_sdata + out = self._return(idx, tile) + return self.transform(out) + return self._return(idx, tile) + + @property + def regions(self) -> list[str]: + """List of regions in the dataset.""" + return self._regions + + @regions.setter + def regions(self, regions: list[str]) -> None: # D102 + self._regions = regions + + @property + def sdata(self) -> SpatialData: + """The original SpatialData object.""" + return self._sdata + + @sdata.setter + def sdata(self, sdata: SpatialData) -> None: # D102 + self._sdata = sdata + + @property + def coordinate_systems(self) -> list[str]: + """List of coordinate systems in the dataset.""" + return self._coordinate_systems + + @coordinate_systems.setter + def coordinate_systems(self, coordinate_systems: list[str]) -> None: # D102 + self._coordinate_systems = coordinate_systems + + @property + def tiles_coords(self) -> pd.DataFrame: + """DataFrame with the index of tiles. + + It contains axis coordinates of the centroids, and extent of the tiles. + For example, for a 2D image, it contains the following columns: + + - `x`: the x coordinate of the centroid. + - `y`: the y coordinate of the centroid. + - `extent`: the extent of the tile. + - `minx`: the minimum x coordinate of the tile. + - `miny`: the minimum y coordinate of the tile. + - `maxx`: the maximum x coordinate of the tile. + - `maxy`: the maximum y coordinate of the tile. + """ + return self._tiles_coords + + @tiles_coords.setter + def tiles_coords(self, tiles: pd.DataFrame) -> None: + self._tiles_coords = tiles + + @property + def dataset_index(self) -> pd.DataFrame: + """DataFrame with the metadata of the tiles. + + It contains the following columns: + + - `instance`: the name of the instance in the region. + - `cs`: the coordinate system of the region-image pair. + - `region`: the name of the region. + - `image`: the name of the image. + """ + return self._dataset_index + + @dataset_index.setter + def dataset_index(self, dataset_index: pd.DataFrame) -> None: + self._dataset_index = dataset_index + + @property + def dataset_table(self) -> AnnData: + """AnnData table filtered by the `region` and `cs` present in the dataset.""" + return self._dataset_table + + @dataset_table.setter + def dataset_table(self, dataset_table: AnnData) -> None: + self._dataset_table = dataset_table + + @property + def dims(self) -> list[str]: + """Dimensions of the dataset.""" + return self._dims + + @dims.setter + def dims(self, dims: list[str]) -> None: + self._dims = dims + + +def _get_tile_coords( + elem: GeoDataFrame, + transformation: BaseTransformation, + dims: tuple[str, ...], + tile_scale: float | None = None, + tile_dim_in_units: float | None = None, +) -> pd.DataFrame: + """Get the (transformed) centroid of the region and the extent.""" + # get centroids and transform them + centroids = elem.centroid.get_coordinates().values + aff = transformation.to_affine_matrix(input_axes=dims, output_axes=dims) + centroids = _affine_matrix_multiplication(aff, centroids) + + # get extent, first by checking shape defaults, then by using the `tile_dim_in_units` + if tile_dim_in_units is None: + if elem.iloc[0, 0].geom_type == "Point": + extent = elem[ShapesModel.RADIUS_KEY].values * tile_scale + elif elem.iloc[0, 0].geom_type in ["Polygon", "MultiPolygon"]: + extent = elem[ShapesModel.GEOMETRY_KEY].length * tile_scale + else: + raise ValueError("Only point and polygon shapes are supported.") + if tile_dim_in_units is not None: + if isinstance(tile_dim_in_units, (float, int)): + extent = np.repeat(tile_dim_in_units, len(centroids)) + else: + raise TypeError( + f"`tile_dim_in_units` must be a `float`, `int`, `list`, `tuple` or `np.ndarray`, " + f"not {type(tile_dim_in_units)}." + ) + if len(extent) != len(centroids): + raise ValueError( + f"the number of elements in the region ({len(extent)}) does not match" + f" the number of instances ({len(centroids)})." + ) + + # transform extent + aff = transformation.to_affine_matrix(input_axes=tuple(dims[0]), output_axes=tuple(dims[0])) + extent = _affine_matrix_multiplication(aff, np.array(extent)[:, np.newaxis]) + + # get min and max coordinates + min_coordinates = np.array(centroids) - extent / 2 + max_coordinates = np.array(centroids) + extent / 2 + + # return a dataframe with columns e.g. ["x", "y", "extent", "minx", "miny", "maxx", "maxy"] + return pd.DataFrame( + np.hstack([centroids, extent, min_coordinates, max_coordinates]), + columns=list(dims) + ["extent"] + ["min" + dim for dim in dims] + ["max" + dim for dim in dims], + ) diff --git a/src/spatialdata/datasets.py b/src/spatialdata/datasets.py index 0d811ad3..3b9f6726 100644 --- a/src/spatialdata/datasets.py +++ b/src/spatialdata/datasets.py @@ -13,8 +13,8 @@ from skimage.segmentation import slic from spatial_image import SpatialImage -from spatialdata import SpatialData from spatialdata._core.operations.aggregate import aggregate +from spatialdata._core.spatialdata import SpatialData from spatialdata._logging import logger from spatialdata._types import ArrayLike from spatialdata.models import ( diff --git a/src/spatialdata/models/_utils.py b/src/spatialdata/models/_utils.py index cf139d5e..d78dc8b9 100644 --- a/src/spatialdata/models/_utils.py +++ b/src/spatialdata/models/_utils.py @@ -16,7 +16,6 @@ SpatialElement = Union[SpatialImage, MultiscaleSpatialImage, GeoDataFrame, DaskDataFrame] TRANSFORM_KEY = "transform" DEFAULT_COORDINATE_SYSTEM = "global" -# ValidAxis_t = Literal["c", "x", "y", "z"] ValidAxis_t = str MappingToCoordinateSystem_t = dict[str, BaseTransformation] C = "c" diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index f7155a3b..e27a08c3 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -18,7 +18,7 @@ from multiscale_spatial_image import to_multiscale from multiscale_spatial_image.multiscale_spatial_image import MultiscaleSpatialImage from multiscale_spatial_image.to_multiscale.to_multiscale import Methods -from pandas.api.types import is_categorical_dtype +from pandas import CategoricalDtype from shapely._geometry import GeometryType from shapely.geometry import MultiPolygon, Point, Polygon from shapely.geometry.collection import GeometryCollection @@ -470,7 +470,7 @@ def validate(cls, data: DaskDataFrame) -> None: raise ValueError(f":attr:`dask.dataframe.core.DataFrame.attrs` does not contain `{cls.TRANSFORM_KEY}`.") if cls.ATTRS_KEY in data.attrs and "feature_key" in data.attrs[cls.ATTRS_KEY]: feature_key = data.attrs[cls.ATTRS_KEY][cls.FEATURE_KEY] - if not is_categorical_dtype(data[feature_key]): + if not isinstance(data[feature_key], CategoricalDtype): logger.info(f"Feature key `{feature_key}`could be of type `pd.Categorical`. Consider casting it.") @singledispatchmethod @@ -624,7 +624,7 @@ def _add_metadata_and_validate( # Here we are explicitly importing the categories # but it is a convenient way to ensure that the categories are known. # It also just changes the state of the series, so it is not a big deal. - if is_categorical_dtype(data[c]) and not data[c].cat.known: + if isinstance(data[c], CategoricalDtype) and not data[c].cat.known: try: data[c] = data[c].cat.set_categories(data[c].head(1).cat.categories) except ValueError: @@ -729,7 +729,7 @@ def parse( region_: list[str] = region if isinstance(region, list) else [region] if not adata.obs[region_key].isin(region_).all(): raise ValueError(f"`adata.obs[{region_key}]` values do not match with `{cls.REGION_KEY}` values.") - if not is_categorical_dtype(adata.obs[region_key]): + if not isinstance(adata.obs[region_key], CategoricalDtype): warnings.warn( f"Converting `{cls.REGION_KEY_KEY}: {region_key}` to categorical dtype.", UserWarning, stacklevel=2 ) diff --git a/src/spatialdata/transformations/operations.py b/src/spatialdata/transformations/operations.py index 165551c4..61792777 100644 --- a/src/spatialdata/transformations/operations.py +++ b/src/spatialdata/transformations/operations.py @@ -15,7 +15,7 @@ ) if TYPE_CHECKING: - from spatialdata import SpatialData + from spatialdata._core.spatialdata import SpatialData from spatialdata.models import SpatialElement from spatialdata.transformations import Affine, BaseTransformation diff --git a/tests/conftest.py b/tests/conftest.py index 045f88bf..50d8f8d1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,7 +19,7 @@ from numpy.random import default_rng from shapely.geometry import MultiPolygon, Point, Polygon from spatial_image import SpatialImage -from spatialdata import SpatialData +from spatialdata._core.spatialdata import SpatialData from spatialdata.models import ( Image2DModel, Image3DModel, diff --git a/tests/core/operations/test_aggregations.py b/tests/core/operations/test_aggregations.py index 51ffa98e..36609464 100644 --- a/tests/core/operations/test_aggregations.py +++ b/tests/core/operations/test_aggregations.py @@ -9,8 +9,9 @@ from anndata.tests.helpers import assert_equal from geopandas import GeoDataFrame from numpy.random import default_rng -from spatialdata import SpatialData, aggregate +from spatialdata import aggregate from spatialdata._core.query._utils import circles_to_polygons +from spatialdata._core.spatialdata import SpatialData from spatialdata._utils import _deepcopy_geodataframe from spatialdata.models import Image2DModel, Labels2DModel, PointsModel, TableModel from spatialdata.transformations import Affine, Identity, set_transformation diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py index c551461c..10ace62b 100644 --- a/tests/core/operations/test_spatialdata_operations.py +++ b/tests/core/operations/test_spatialdata_operations.py @@ -9,8 +9,8 @@ from geopandas import GeoDataFrame from multiscale_spatial_image import MultiscaleSpatialImage from spatial_image import SpatialImage -from spatialdata import SpatialData from spatialdata._core.concatenate import _concatenate_tables, concatenate +from spatialdata._core.spatialdata import SpatialData from spatialdata.datasets import blobs from spatialdata.models import ( Image2DModel, diff --git a/tests/core/operations/test_transform.py b/tests/core/operations/test_transform.py index f28b345f..601159ef 100644 --- a/tests/core/operations/test_transform.py +++ b/tests/core/operations/test_transform.py @@ -8,7 +8,8 @@ from geopandas.testing import geom_almost_equals from multiscale_spatial_image import MultiscaleSpatialImage from spatial_image import SpatialImage -from spatialdata import SpatialData, transform +from spatialdata import transform +from spatialdata._core.spatialdata import SpatialData from spatialdata._utils import unpad_raster from spatialdata.models import Image2DModel, PointsModel, ShapesModel, get_axes_names from spatialdata.transformations.operations import ( diff --git a/tests/core/query/test_spatial_query.py b/tests/core/query/test_spatial_query.py index 6db7e904..58d27823 100644 --- a/tests/core/query/test_spatial_query.py +++ b/tests/core/query/test_spatial_query.py @@ -6,13 +6,13 @@ from multiscale_spatial_image import MultiscaleSpatialImage from shapely import Polygon from spatial_image import SpatialImage -from spatialdata import SpatialData from spatialdata._core.query.spatial_query import ( BaseSpatialRequest, BoundingBoxRequest, bounding_box_query, polygon_query, ) +from spatialdata._core.spatialdata import SpatialData from spatialdata.models import ( Image2DModel, Image3DModel, diff --git a/tests/dataloader/test_datasets.py b/tests/dataloader/test_datasets.py index 59126996..dac01e80 100644 --- a/tests/dataloader/test_datasets.py +++ b/tests/dataloader/test_datasets.py @@ -4,59 +4,112 @@ import pandas as pd import pytest from anndata import AnnData -from spatialdata.dataloader.datasets import ImageTilesDataset +from spatialdata._core.spatialdata import SpatialData +from spatialdata.dataloader import ImageTilesDataset from spatialdata.models import TableModel -@pytest.mark.parametrize("image_element", ["blobs_image", "blobs_multiscale_image"]) -@pytest.mark.parametrize( - "regions_element", - ["blobs_labels", "blobs_multiscale_labels", "blobs_circles", "blobs_polygons", "blobs_multipolygons"], -) -def test_tiles_dataset(sdata_blobs, image_element, regions_element): - if regions_element in ["blobs_labels", "blobs_multipolygons", "blobs_multiscale_labels"]: - cm = pytest.raises(NotImplementedError) - else: - cm = contextlib.nullcontext() - with cm: +class TestImageTilesDataset: + @pytest.mark.parametrize("image_element", ["blobs_image", "blobs_multiscale_image"]) + @pytest.mark.parametrize( + "regions_element", + ["blobs_labels", "blobs_multiscale_labels", "blobs_circles", "blobs_polygons", "blobs_multipolygons"], + ) + def test_validation(self, sdata_blobs, image_element, regions_element): + if regions_element in ["blobs_labels", "blobs_multiscale_labels"] or image_element == "blobs_multiscale_image": + cm = pytest.raises(NotImplementedError) + elif regions_element in ["blobs_circles", "blobs_polygons", "blobs_multipolygons"]: + cm = pytest.raises(ValueError) + else: + cm = contextlib.nullcontext() + with cm: + _ = ImageTilesDataset( + sdata=sdata_blobs, + regions_to_images={regions_element: image_element}, + regions_to_coordinate_systems={regions_element: "global"}, + ) + + @pytest.mark.parametrize("regions_element", ["blobs_circles", "blobs_polygons", "blobs_multipolygons"]) + @pytest.mark.parametrize("raster", [True, False]) + def test_default(self, sdata_blobs, regions_element, raster): + raster_kwargs = {"target_unit_to_pixels": 2} if raster else {} + + sdata = self._annotate_shapes(sdata_blobs, regions_element) ds = ImageTilesDataset( - sdata=sdata_blobs, - regions_to_images={regions_element: image_element}, - tile_dim_in_units=10, - tile_dim_in_pixels=32, - target_coordinate_system="global", + sdata=sdata, + rasterize=raster, + regions_to_images={regions_element: "blobs_image"}, + regions_to_coordinate_systems={regions_element: "global"}, + rasterize_kwargs=raster_kwargs, ) - tile = ds[0].images.values().__iter__().__next__() - assert tile.shape == (3, 32, 32) + sdata_tile = ds[0] + tile = sdata_tile.images.values().__iter__().__next__() -def test_tiles_table(sdata_blobs): - new_table = AnnData( - X=np.random.default_rng().random((3, 10)), - obs=pd.DataFrame({"region": "blobs_circles", "instance_id": np.array([0, 1, 2])}), - ) - new_table = TableModel.parse(new_table, region="blobs_circles", region_key="region", instance_key="instance_id") - del sdata_blobs.table - sdata_blobs.table = new_table - ds = ImageTilesDataset( - sdata=sdata_blobs, - regions_to_images={"blobs_circles": "blobs_image"}, - tile_dim_in_units=10, - tile_dim_in_pixels=32, - target_coordinate_system="global", - ) - assert len(ds) == 3 - assert len(ds[0].table) == 1 - assert np.all(ds[0].table.X == new_table[0].X) - - -def test_tiles_multiple_elements(sdata_blobs): - ds = ImageTilesDataset( - sdata=sdata_blobs, - regions_to_images={"blobs_circles": "blobs_image", "blobs_polygons": "blobs_multiscale_image"}, - tile_dim_in_units=10, - tile_dim_in_pixels=32, - target_coordinate_system="global", - ) - assert len(ds) == 6 - _ = ds[0] + if regions_element == "blobs_circles": + if raster: + assert tile.shape == (3, 50, 50) + else: + assert tile.shape == (3, 25, 25) + elif regions_element == "blobs_polygons": + if raster: + assert tile.shape == (3, 164, 164) + else: + assert tile.shape == (3, 82, 82) + elif regions_element == "blobs_multipolygons": + if raster: + assert tile.shape == (3, 329, 329) + else: + assert tile.shape == (3, 165, 164) + else: + raise ValueError(f"Unexpected regions_element: {regions_element}") + + # extent has units in pixel so should be the same as tile shape + if raster: + assert round(ds.tiles_coords.extent.unique()[0] * 2) == tile.shape[1] + else: + if regions_element != "blobs_multipolygons": + assert int(ds.tiles_coords.extent.unique()[0]) == tile.shape[1] + else: + assert int(ds.tiles_coords.extent.unique()[0]) + 1 == tile.shape[1] + assert np.all(sdata_tile.table.obs.columns == ds.sdata.table.obs.columns) + assert list(sdata_tile.images.keys())[0] == "blobs_image" + + @pytest.mark.parametrize("regions_element", ["blobs_circles", "blobs_polygons", "blobs_multipolygons"]) + @pytest.mark.parametrize("return_annot", ["region", ["region", "instance_id"]]) + def test_return_annot(self, sdata_blobs, regions_element, return_annot): + sdata = self._annotate_shapes(sdata_blobs, regions_element) + ds = ImageTilesDataset( + sdata=sdata, + regions_to_images={regions_element: "blobs_image"}, + regions_to_coordinate_systems={regions_element: "global"}, + return_annotations=return_annot, + ) + + tile, annot = ds[0] + if regions_element == "blobs_circles": + assert tile.shape == (3, 25, 25) + elif regions_element == "blobs_polygons": + assert tile.shape == (3, 82, 82) + elif regions_element == "blobs_multipolygons": + assert tile.shape == (3, 165, 164) + else: + raise ValueError(f"Unexpected regions_element: {regions_element}") + # extent has units in pixel so should be the same as tile shape + if regions_element != "blobs_multipolygons": + assert int(ds.tiles_coords.extent.unique()[0]) == tile.shape[1] + else: + assert round(ds.tiles_coords.extent.unique()[0]) + 1 == tile.shape[1] + return_annot = [return_annot] if isinstance(return_annot, str) else return_annot + assert annot.shape[1] == len(return_annot) + + # TODO: consider adding this logic to blobs, to generate blobs with arbitrary table annotation + def _annotate_shapes(self, sdata: SpatialData, shape: str) -> SpatialData: + new_table = AnnData( + X=np.random.default_rng().random((len(sdata[shape]), 10)), + obs=pd.DataFrame({"region": shape, "instance_id": sdata[shape].index.values}), + ) + new_table = TableModel.parse(new_table, region=shape, region_key="region", instance_key="instance_id") + del sdata.table + sdata.table = new_table + return sdata diff --git a/tests/dataloader/test_transforms.py b/tests/dataloader/test_transforms.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index 046baf3b..ddeefb37 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -12,7 +12,7 @@ from numpy.random import default_rng from shapely.geometry import Point from spatial_image import SpatialImage -from spatialdata import SpatialData +from spatialdata._core.spatialdata import SpatialData from spatialdata._io._utils import _are_directories_identical from spatialdata.models import TableModel from spatialdata.transformations.operations import ( diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 34035c7d..f1dfe2df 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -22,7 +22,7 @@ from pandas.api.types import is_categorical_dtype from shapely.io import to_ragged_array from spatial_image import SpatialImage, to_spatial_image -from spatialdata import SpatialData +from spatialdata._core.spatialdata import SpatialData from spatialdata._types import ArrayLike from spatialdata.models import ( Image2DModel,