diff --git a/.github/workflows/test_and_deploy.yaml b/.github/workflows/test_and_deploy.yaml index 40882d26..d11ca125 100644 --- a/.github/workflows/test_and_deploy.yaml +++ b/.github/workflows/test_and_deploy.yaml @@ -6,7 +6,7 @@ on: tags: - "v*" # Push events to matching v*, i.e. v1.0, v20.15.10 pull_request: - branches: [main] + branches: "*" jobs: test: @@ -63,7 +63,7 @@ jobs: PLATFORM: ${{ matrix.os }} DISPLAY: :42 run: | - pytest -v --cov --color=yes --cov-report=xml + pytest --cov --color=yes --cov-report=xml - name: Upload coverage to Codecov uses: codecov/codecov-action@v3.1.1 with: diff --git a/.gitignore b/.gitignore index c0a79ef9..666248e4 100644 --- a/.gitignore +++ b/.gitignore @@ -45,3 +45,6 @@ spatialdata-sandbox # version file _version.py + +# other +node_modules/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 89e11456..9af87581 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -9,11 +9,11 @@ ci: skip: [] repos: - repo: https://github.com/psf/black - rev: 23.10.1 + rev: 23.12.1 hooks: - id: black - repo: https://github.com/pre-commit/mirrors-prettier - rev: v3.0.3 + rev: v4.0.0-alpha.8 hooks: - id: prettier - repo: https://github.com/asottile/blacken-docs @@ -21,13 +21,13 @@ repos: hooks: - id: blacken-docs - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.6.1 + rev: v1.8.0 hooks: - id: mypy additional_dependencies: [numpy, types-requests] exclude: tests/|docs/ - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.1.3 + rev: v0.1.13 hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 690bf115..b59dfb7b 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -6,7 +6,7 @@ build: python: "3.10" sphinx: configuration: docs/conf.py - fail_on_warning: false + fail_on_warning: true python: install: - method: pip diff --git a/CHANGELOG.md b/CHANGELOG.md index f2fe3358..d913b652 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,10 +8,58 @@ and this project adheres to [Semantic Versioning][]. [keep a changelog]: https://keepachangelog.com/en/1.0.0/ [semantic versioning]: https://semver.org/spec/v2.0.0.html -## [0.0.15] - tbd +## [0.1.0] - tbd ### Added +#### Major + +- Implemented support in SpatialData for storing multiple tables. These tables + can annotate a SpatialElement but not necessarily so. +- Increased in-memory vs on-disk control: changes performed in-memory (e.g. adding a new image) are not automatically performed on-disk. + +#### Minor + +- Added public helper function get_table_keys in spatialdata.models to retrieve annotation information of a given table. +- Added public helper function check_target_region_column_symmetry in spatialdata.models to check whether annotation + metadata in table.uns['spatialdata_attrs'] corresponds with respective columns in table.obs. +- Added function validate_table_in_spatialdata in SpatialData to validate the annotation target of a table being + present in the SpatialData object. +- Added function get_annotated_regions in SpatialData to get the regions annotated by a given table. +- Added function get_region_key_column in SpatialData to get the region_key column in table.obs. +- Added function get_instance_key_column in SpatialData to get the instance_key column in table.obs. +- Added function set_table_annotates_spatialelement in SpatialData to either set or change the annotation metadata of + a table in a given SpatialData object. +- Added tables property in SpatialData. +- Added tables setter in SpatialData. 
+- Added gen_spatial_elements generator in SpatialData to generate the SpatialElements in a given SpatialData object. +- Added gen_elements generator in SpatialData to generate elements of a SpatialData object including tables. + +### Changed + +#### Minor + +- Changed the string representation of SpatialData to reflect the changes in regard to multiple tables. + +## [0.0.x] - tbd + +### Changed + +#### Minor + +- Improved usability and robustness of sdata.write() when overwrite=True @aeisenbarth + +### Added + +- Added SpatialData.subset() API +- Added SpatialData.locate_element() API + +### Fixed + +- Generalized queries to any combination of 2D/3D data and 2D/3D query region #409 + +#### Minor + +- Refactored data loader for deep learning + ## [0.0.14] - 2023-10-11 ### Added diff --git a/docs/_templates/autosummary/class.rst b/docs/_templates/autosummary/class.rst index d4668a41..e4665dfc 100644 --- a/docs/_templates/autosummary/class.rst +++ b/docs/_templates/autosummary/class.rst @@ -12,11 +12,8 @@ Attributes table ~~~~~~~~~~~~~~~~~~ .. autosummary:: - {% for item in attributes %} - ~{{ fullname }}.{{ item }} - {%- endfor %} {% endif %} {% endblock %} @@ -27,13 +24,10 @@ Methods table ~~~~~~~~~~~~~ .. autosummary:: - {% for item in methods %} - {%- if item != '__init__' %} ~{{ fullname }}.{{ item }} {%- endif -%} - {%- endfor %} {% endif %} {% endblock %} @@ -46,7 +40,6 @@ Attributes {% for item in attributes %} .. autoattribute:: {{ [objname, item] | join(".") }} - {%- endfor %} {% endif %} @@ -61,7 +54,6 @@ Methods {%- if item != '__init__' %} .. automethod:: {{ [objname, item] | join(".") }} - {%- endif -%} {%- endfor %} diff --git a/docs/api.md b/docs/api.md index 696c92e0..93509ffd 100644 --- a/docs/api.md +++ b/docs/api.md @@ -29,12 +29,12 @@ Operations on `SpatialData` objects. get_extent match_table_to_element concatenate - rasterize transform + rasterize aggregate ``` -### Utilities +### Operations Utilities ```{eval-rst} .. autosummary:: @@ -49,6 +49,7 @@ The elements (building-blocks) that consitute `SpatialData`. ```{eval-rst} .. currentmodule:: spatialdata.models + .. autosummary:: :toctree: generated @@ -61,9 +62,11 @@ The elements (building-blocks) that consitute `SpatialData`. TableModel ``` -### Utilities +### Models Utilities ```{eval-rst} +.. currentmodule:: spatialdata.models + .. autosummary:: :toctree: generated @@ -94,9 +97,11 @@ The transformations that can be defined between elements and coordinate systems Sequence ``` -### Utilities +### Transformations Utilities ```{eval-rst} +.. currentmodule:: spatialdata.transformations + .. autosummary:: :toctree: generated @@ -119,7 +124,7 @@ The transformations that can be defined between elements and coordinate systems ImageTilesDataset ``` -## Input/output +## Input/Output ```{eval-rst} ..
currentmodule:: spatialdata @@ -129,4 +134,5 @@ The transformations that can be defined between elements and coordinate systems read_zarr save_transformations + get_dask_backing_files ``` diff --git a/pyproject.toml b/pyproject.toml index acd3e191..0904668a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,7 +50,6 @@ dev = [ docs = [ "sphinx>=4.5", "sphinx-book-theme>=1.0.0", - "sphinx_rtd_theme", "myst-nb", "sphinxcontrib-bibtex>=1.0.0", "sphinx-autodoc-typehints", diff --git a/src/spatialdata/__init__.py b/src/spatialdata/__init__.py index 0541c491..e09f42c0 100644 --- a/src/spatialdata/__init__.py +++ b/src/spatialdata/__init__.py @@ -28,6 +28,7 @@ "read_zarr", "unpad_raster", "save_transformations", + "get_dask_backing_files", ] from spatialdata import dataloader, models, transformations @@ -40,6 +41,6 @@ from spatialdata._core.query.relational_query import get_values, match_table_to_element from spatialdata._core.query.spatial_query import bounding_box_query, polygon_query from spatialdata._core.spatialdata import SpatialData -from spatialdata._io._utils import save_transformations +from spatialdata._io._utils import get_dask_backing_files, save_transformations from spatialdata._io.io_zarr import read_zarr from spatialdata._utils import unpad_raster diff --git a/src/spatialdata/_core/_elements.py b/src/spatialdata/_core/_elements.py new file mode 100644 index 00000000..023128ae --- /dev/null +++ b/src/spatialdata/_core/_elements.py @@ -0,0 +1,116 @@ +"""SpatialData elements.""" +from __future__ import annotations + +from collections import UserDict +from collections.abc import Iterable +from typing import Any +from warnings import warn + +from anndata import AnnData +from dask.dataframe.core import DataFrame as DaskDataFrame +from datatree import DataTree +from geopandas import GeoDataFrame + +from spatialdata._types import Raster_T +from spatialdata._utils import multiscale_spatial_image_from_data_tree +from spatialdata.models import ( + Image2DModel, + Image3DModel, + Labels2DModel, + Labels3DModel, + PointsModel, + ShapesModel, + TableModel, + get_axes_names, + get_model, +) + + +class Elements(UserDict[str, Any]): + def __init__(self, shared_keys: set[str | None]) -> None: + self._shared_keys = shared_keys + super().__init__() + + @staticmethod + def _check_key(key: str, element_keys: Iterable[str], shared_keys: set[str | None]) -> None: + if key in element_keys: + warn(f"Key `{key}` already exists. 
Overwriting it.", UserWarning, stacklevel=2) + else: + if key in shared_keys: + raise KeyError(f"Key `{key}` already exists.") + + def __setitem__(self, key: str, value: Any) -> None: + self._shared_keys.add(key) + super().__setitem__(key, value) + + def __delitem__(self, key: str) -> None: + self._shared_keys.remove(key) + super().__delitem__(key) + + +class Images(Elements): + def __setitem__(self, key: str, value: Raster_T) -> None: + self._check_key(key, self.keys(), self._shared_keys) + if isinstance(value, (DataTree)): + value = multiscale_spatial_image_from_data_tree(value) + schema = get_model(value) + if schema not in (Image2DModel, Image3DModel): + raise TypeError(f"Unknown element type with schema: {schema!r}.") + ndim = len(get_axes_names(value)) + if ndim == 3: + Image2DModel().validate(value) + super().__setitem__(key, value) + elif ndim == 4: + Image3DModel().validate(value) + super().__setitem__(key, value) + else: + raise NotImplementedError("TODO: implement for ndim > 4.") + + +class Labels(Elements): + def __setitem__(self, key: str, value: Raster_T) -> None: + self._check_key(key, self.keys(), self._shared_keys) + if isinstance(value, (DataTree)): + value = multiscale_spatial_image_from_data_tree(value) + schema = get_model(value) + if schema not in (Labels2DModel, Labels3DModel): + raise TypeError(f"Unknown element type with schema: {schema!r}.") + ndim = len(get_axes_names(value)) + if ndim == 2: + Labels2DModel().validate(value) + super().__setitem__(key, value) + elif ndim == 3: + Labels3DModel().validate(value) + super().__setitem__(key, value) + else: + raise NotImplementedError("TODO: implement for ndim > 3.") + + +class Shapes(Elements): + def __setitem__(self, key: str, value: GeoDataFrame) -> None: + self._check_key(key, self.keys(), self._shared_keys) + schema = get_model(value) + if schema != ShapesModel: + raise TypeError(f"Unknown element type with schema: {schema!r}.") + ShapesModel().validate(value) + super().__setitem__(key, value) + + +class Points(Elements): + def __setitem__(self, key: str, value: DaskDataFrame) -> None: + self._check_key(key, self.keys(), self._shared_keys) + schema = get_model(value) + if schema != PointsModel: + raise TypeError(f"Unknown element type with schema: {schema!r}.") + PointsModel().validate(value) + super().__setitem__(key, value) + + +class Tables(Elements): + def __setitem__(self, key: str, value: AnnData) -> None: + self._check_key(key, self.keys(), self._shared_keys) + schema = get_model(value) + if schema != TableModel: + raise TypeError(f"Unknown element type with schema: {schema!r}.") + TableModel().validate(value) + super().__setitem__(key, value) diff --git a/src/spatialdata/_core/_utils.py b/src/spatialdata/_core/_utils.py new file mode 100644 index 00000000..1c22c802 --- /dev/null +++ b/src/spatialdata/_core/_utils.py @@ -0,0 +1,22 @@ +from spatialdata._core.spatialdata import SpatialData + + +def _find_common_table_keys(sdatas: list[SpatialData]) -> set[str]: + """ + Find table keys present in all of the SpatialData objects. + + Parameters + ---------- + sdatas + A list of SpatialData objects. + + Returns + ------- + A set of table keys that are present in the tables of all of the SpatialData objects.
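The renaming scheme that `concatenate` applies in the next hunk when `concatenate_tables=False` can be checked with a standalone sketch; plain dictionaries stand in for each SpatialData object's `tables` mapping, and all names are illustrative:

```python
from collections import defaultdict


def merge_table_names(table_dicts: list[dict]) -> dict:
    # Mirrors the default (concatenate_tables=False) branch: a name that occurs
    # in more than one object is disambiguated with an integer suffix.
    key_counts: dict[str, int] = defaultdict(int)
    for d in table_dicts:
        for k in d:
            key_counts[k] += 1
    merged: dict = {}
    seen: dict[str, int] = defaultdict(int)
    for d in table_dicts:
        for k, v in d.items():
            new_key = f"{k}_{seen[k]}" if key_counts[k] > 1 else k
            seen[k] += 1
            merged[new_key] = v
    return merged


print(merge_table_names([{"table": "t0"}, {"table": "t1"}, {"extra": "t2"}]))
# {'table_0': 't0', 'table_1': 't1', 'extra': 't2'}
```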
+ """ + common_keys = set(sdatas[0].tables.keys()) + + for sdata in sdatas[1:]: + common_keys.intersection_update(sdata.tables.keys()) + + return common_keys diff --git a/src/spatialdata/_core/concatenate.py b/src/spatialdata/_core/concatenate.py index 77f82c53..8312d660 100644 --- a/src/spatialdata/_core/concatenate.py +++ b/src/spatialdata/_core/concatenate.py @@ -1,15 +1,16 @@ from __future__ import annotations +from collections import defaultdict from copy import copy # Should probably go up at the top from itertools import chain -from typing import TYPE_CHECKING, Any +from typing import Any +from warnings import warn import numpy as np from anndata import AnnData -if TYPE_CHECKING: - from spatialdata._core.spatialdata import SpatialData - +from spatialdata._core._utils import _find_common_table_keys +from spatialdata._core.spatialdata import SpatialData from spatialdata.models import TableModel __all__ = [ @@ -25,6 +26,8 @@ def _concatenate_tables( ) -> AnnData: import anndata as ad + if not all(TableModel.ATTRS_KEY in table.uns for table in tables): + raise ValueError("Not all tables are annotating a spatial element") region_keys = [table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY] for table in tables] instance_keys = [table.uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY] for table in tables] regions = [table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] for table in tables] @@ -73,6 +76,7 @@ def concatenate( sdatas: list[SpatialData], region_key: str | None = None, instance_key: str | None = None, + concatenate_tables: bool = False, **kwargs: Any, ) -> SpatialData: """ @@ -87,6 +91,8 @@ def concatenate( If all region_keys are the same, the `region_key` is used. instance_key The key to use for the instance column in the concatenated object. + concatenate_tables + Whether to merge the tables in case of having the same element name. kwargs See :func:`anndata.concat` for more details. @@ -94,8 +100,6 @@ def concatenate( ------- The concatenated :class:`spatialdata.SpatialData` object. """ - from spatialdata import SpatialData - merged_images = {**{k: v for sdata in sdatas for k, v in sdata.images.items()}} if len(merged_images) != np.sum([len(sdata.images) for sdata in sdatas]): raise KeyError("Images must have unique names across the SpatialData objects to concatenate") @@ -112,16 +116,43 @@ def concatenate( assert isinstance(sdatas, list), "sdatas must be a list" assert len(sdatas) > 0, "sdatas must be a non-empty list" - merged_table = _concatenate_tables( - [sdata.table for sdata in sdatas if sdata.table is not None], region_key, instance_key, **kwargs - ) + if not concatenate_tables: + key_counts: dict[str, int] = defaultdict(int) + for sdata in sdatas: + for k in sdata.tables: + key_counts[k] += 1 + + if any(value > 1 for value in key_counts.values()): + warn( + "Duplicate table names found. Tables will be added with integer suffix. 
Set concatenate_tables to True " "if concatenation is desired instead.", + UserWarning, + stacklevel=2, + ) + merged_tables = {} + count_dict: dict[str, int] = defaultdict(int) + + for sdata in sdatas: + for k, v in sdata.tables.items(): + new_key = f"{k}_{count_dict[k]}" if key_counts[k] > 1 else k + count_dict[k] += 1 + merged_tables[new_key] = v + else: + common_keys = _find_common_table_keys(sdatas) + merged_tables = {} + for sdata in sdatas: + for k, v in sdata.tables.items(): + if k in common_keys and merged_tables.get(k) is not None: + merged_tables[k] = _concatenate_tables([merged_tables[k], v], region_key, instance_key, **kwargs) + else: + merged_tables[k] = v return SpatialData( images=merged_images, labels=merged_labels, points=merged_points, shapes=merged_shapes, - table=merged_table, + tables=merged_tables, ) diff --git a/src/spatialdata/_core/data_extent.py b/src/spatialdata/_core/data_extent.py index 251a9e7b..3947fe5f 100644 --- a/src/spatialdata/_core/data_extent.py +++ b/src/spatialdata/_core/data_extent.py @@ -115,9 +115,9 @@ def get_extent( has_labels: bool = True, has_points: bool = True, has_shapes: bool = True, - # python 3.9 tests fail if we don't use Union here, see - # https://github.com/scverse/spatialdata/pull/318#issuecomment-1755714287 - elements: Union[list[str], None] = None, # noqa: UP007 + elements: Union[ # noqa: UP007 # https://github.com/scverse/spatialdata/pull/318#issuecomment-1755714287 + list[str], None + ] = None, ) -> BoundingBoxDescription: """ Get the extent (bounding box) of a SpatialData object or a SpatialElement. @@ -129,43 +129,50 @@ def get_extent( Returns ------- + The bounding box description. + min_coordinate The minimum coordinate of the bounding box. max_coordinate The maximum coordinate of the bounding box. axes - The names of the dimensions of the bounding box + The names of the dimensions of the bounding box. exact - If True, the extent is computed exactly. If False, an approximation faster to compute is given. The - approximation is guaranteed to contain all the data, see notes for details. + Whether the extent is computed exactly. + + - If `True`, the extent is computed exactly. + - If `False`, an approximation faster to compute is given. + + The approximation is guaranteed to contain all the data, see notes for details. has_images - If True, images are included in the computation of the extent. + If `True`, images are included in the computation of the extent. has_labels - If True, labels are included in the computation of the extent. + If `True`, labels are included in the computation of the extent. has_points - If True, points are included in the computation of the extent. + If `True`, points are included in the computation of the extent. has_shapes - If True, shapes are included in the computation of the extent. + If `True`, shapes are included in the computation of the extent. elements - If not None, only the elements with the given names are included in the computation of the extent. + If not `None`, only the elements with the given names are included in the computation of the extent. Notes ----- - The extent of a SpatialData object is the extent of the union of the extents of all its elements. The extent of a - SpatialElement is the extent of the element in the coordinate system specified by the argument `coordinate_system`. + The extent of a `SpatialData` object is the extent of the union of the extents of all its elements.
+ The extent of a `SpatialElement` is the extent of the element in the coordinate system + specified by the argument `coordinate_system`. - If `exact` is False, first the extent of the SpatialElement before any transformation is computed. Then, the extent - is transformed to the target coordinate system. This is faster than computing the extent after the transformation, - since the transformation is applied to extent of the untransformed data, as opposed to transforming the data and - then computing the extent. + If `exact` is `False`, first the extent of the `SpatialElement` before any transformation is computed. + Then, the extent is transformed to the target coordinate system. This is faster than computing the extent + after the transformation, since the transformation is applied to the extent of the untransformed data, + as opposed to transforming the data and then computing the extent. - The exact and approximate extent are the same if the transformation doesn't contain any rotation or shear, or in the - case in which the transformation is affine but all the corners of the extent of the untransformed data + The exact and approximate extent are the same if the transformation does not contain any rotation or shear, or in + the case in which the transformation is affine but all the corners of the extent of the untransformed data (bounding box corners) are part of the dataset itself. Note that this is always the case for raster data. - An extreme case is a dataset composed of the two points (0, 0) and (1, 1), rotated anticlockwise by 45 degrees. The - exact extent is the bounding box [minx, miny, maxx, maxy] = [0, 0, 0, 1.414], while the approximate extent is the - box [minx, miny, maxx, maxy] = [-0.707, 0, 0.707, 1.414]. + An extreme case is a dataset composed of the two points `(0, 0)` and `(1, 1)`, rotated anticlockwise by 45 degrees. + The exact extent is the bounding box `[minx, miny, maxx, maxy] = [0, 0, 0, 1.414]`, while the approximate extent is + the box `[minx, miny, maxx, maxy] = [-0.707, 0, 0.707, 1.414]`. """ raise ValueError("The object type is not supported.") @@ -184,7 +191,9 @@ def _( elements: Union[list[str], None] = None, # noqa: UP007 ) -> BoundingBoxDescription: """ - Get the extent (bounding box) of a SpatialData object: the extent of the union of the extents of all its elements. + Get the extent (bounding box) of a SpatialData object. + + The resulting extent is the union of the extents of all its elements. Parameters ---------- @@ -259,7 +268,14 @@ def _get_extent_of_shapes(e: GeoDataFrame) -> BoundingBoxDescription: @get_extent.register def _(e: GeoDataFrame, coordinate_system: str = "global", exact: bool = True) -> BoundingBoxDescription: """ - Compute the extent (bounding box) of a set of shapes. + Get the extent (bounding box) of a set of shapes. + + Parameters + ---------- + e + The GeoDataFrame containing the shapes.
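The extreme case described in the Notes above checks out numerically; a plain NumPy verification, independent of spatialdata:

```python
import numpy as np

pts = np.array([[0.0, 0.0], [1.0, 1.0]])
theta = np.deg2rad(45)
R = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])

# Exact: rotate the data, then take the extent.
exact = R @ pts.T
print(exact.min(axis=1), exact.max(axis=1))  # ~[0, 0] and ~[0, 1.414]

# Approximate: take the extent first, rotate its four corners, then re-take the extent.
corners = np.array([[0, 0], [1, 0], [0, 1], [1, 1]], dtype=float)
approx = R @ corners.T
print(approx.min(axis=1), approx.max(axis=1))  # ~[-0.707, 0] and ~[0.707, 1.414]
```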
Returns ------- diff --git a/src/spatialdata/_core/operations/aggregate.py b/src/spatialdata/_core/operations/aggregate.py index 9881dc7b..29a78b3f 100644 --- a/src/spatialdata/_core/operations/aggregate.py +++ b/src/spatialdata/_core/operations/aggregate.py @@ -1,7 +1,7 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Any +from typing import Any import anndata as ad import dask as da @@ -20,6 +20,7 @@ from spatialdata._core.operations.transform import transform from spatialdata._core.query._utils import circles_to_polygons from spatialdata._core.query.relational_query import get_values +from spatialdata._core.spatialdata import SpatialData from spatialdata._types import ArrayLike from spatialdata.models import ( Image2DModel, @@ -32,9 +33,6 @@ ) from spatialdata.transformations import BaseTransformation, Identity, get_transformation -if TYPE_CHECKING: - from spatialdata import SpatialData - __all__ = ["aggregate"] @@ -236,7 +234,6 @@ def _create_sdata_from_table_and_shapes( instance_key: str, deepcopy: bool, ) -> SpatialData: - from spatialdata import SpatialData from spatialdata._utils import _deepcopy_geodataframe table.obs[instance_key] = table.obs_names.copy() @@ -250,7 +247,7 @@ def _create_sdata_from_table_and_shapes( if deepcopy: shapes = _deepcopy_geodataframe(shapes) - return SpatialData.from_elements_dict({shapes_name: shapes, "": table}) + return SpatialData.from_elements_dict({shapes_name: shapes, "table": table}) def _aggregate_image_by_labels( diff --git a/src/spatialdata/_core/operations/rasterize.py b/src/spatialdata/_core/operations/rasterize.py index d1a30c46..a850542f 100644 --- a/src/spatialdata/_core/operations/rasterize.py +++ b/src/spatialdata/_core/operations/rasterize.py @@ -207,8 +207,6 @@ def _( target_height: Optional[float] = None, target_depth: Optional[float] = None, ) -> SpatialData: - from spatialdata import SpatialData - min_coordinate = _parse_list_into_array(min_coordinate) max_coordinate = _parse_list_into_array(max_coordinate) @@ -232,7 +230,7 @@ def _( ) new_name = f"{name}_rasterized_{element_type}" new_images[new_name] = rasterized - return SpatialData(images=new_images, table=sdata.table) + return SpatialData(images=new_images, tables=sdata.tables) # get xdata diff --git a/src/spatialdata/_core/query/_utils.py b/src/spatialdata/_core/query/_utils.py index 25e8caa9..15fbe5c9 100644 --- a/src/spatialdata/_core/query/_utils.py +++ b/src/spatialdata/_core/query/_utils.py @@ -1,8 +1,13 @@ from __future__ import annotations +from typing import Any + import geopandas as gpd +from anndata import AnnData from xarray import DataArray +from spatialdata._core._elements import Tables +from spatialdata._core.spatialdata import SpatialData from spatialdata._types import ArrayLike from spatialdata._utils import Number, _parse_list_into_array @@ -78,3 +83,39 @@ def get_bounding_box_corners( ], coords={"corner": range(8), "axis": list(axes)}, ) + + +def _get_filtered_or_unfiltered_tables( + filter_table: bool, elements: dict[str, Any], sdata: SpatialData +) -> dict[str, AnnData] | Tables: + """ + Get the tables in a SpatialData object. + + The tables of the SpatialData object can either be filtered to only include the tables that annotate an element in + elements or all tables are returned. + + Parameters + ---------- + filter_table + Specifies whether to filter the tables to only include tables that annotate elements in the retrieved + SpatialData object of the query. 
+ elements + A dictionary containing the elements to use for filtering the tables. + sdata + The SpatialData object that contains the tables to filter. + + Returns + ------- + A dictionary containing the filtered or unfiltered tables based on the value of the 'filter_table' parameter. + + """ + if filter_table: + from spatialdata._core.query.relational_query import _filter_table_by_elements + + return { + name: filtered_table + for name, table in sdata.tables.items() + if (filtered_table := _filter_table_by_elements(table, elements)) and len(filtered_table) != 0 + } + + return sdata.tables diff --git a/src/spatialdata/_core/query/relational_query.py b/src/spatialdata/_core/query/relational_query.py index 14d6e88c..9beb4f16 100644 --- a/src/spatialdata/_core/query/relational_query.py +++ b/src/spatialdata/_core/query/relational_query.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import Any import dask.array as da import numpy as np @@ -11,6 +11,7 @@ from multiscale_spatial_image import MultiscaleSpatialImage from spatial_image import SpatialImage +from spatialdata._core.spatialdata import SpatialData from spatialdata._utils import _inplace_fix_subset_categorical_obs from spatialdata.models import ( Labels2DModel, @@ -22,11 +23,8 @@ get_model, ) -if TYPE_CHECKING: - from spatialdata import SpatialData - -def _filter_table_by_coordinate_system(table: AnnData | None, coordinate_system: str | list[str]) -> AnnData | None: +def _filter_table_by_element_names(table: AnnData | None, element_names: str | list[str]) -> AnnData | None: """ Filter an AnnData table to keep only the rows that are in the coordinate system. @@ -34,19 +32,19 @@ def _filter_table_by_coordinate_system(table: AnnData | None, coordinate_system: ---------- table The table to filter; if None, returns None - coordinate_system - The coordinate system to keep + element_names + The element_names to keep in the tables obs.region column Returns ------- The filtered table, or None if the input table was None """ - if table is None: + if table is None or not table.uns.get(TableModel.ATTRS_KEY): return None table_mapping_metadata = table.uns[TableModel.ATTRS_KEY] region_key = table_mapping_metadata[TableModel.REGION_KEY_KEY] table.obs = pd.DataFrame(table.obs) - table = table[table.obs[region_key].isin(coordinate_system)].copy() + table = table[table.obs[region_key].isin(element_names)].copy() table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = table.obs[region_key].unique().tolist() return table diff --git a/src/spatialdata/_core/query/spatial_query.py b/src/spatialdata/_core/query/spatial_query.py index 2cbd02f3..ef14c465 100644 --- a/src/spatialdata/_core/query/spatial_query.py +++ b/src/spatialdata/_core/query/spatial_query.py @@ -17,7 +17,7 @@ from tqdm import tqdm from xarray import DataArray -from spatialdata._core.query._utils import get_bounding_box_corners +from spatialdata._core.query._utils import _get_filtered_or_unfiltered_tables, get_bounding_box_corners from spatialdata._core.spatialdata import SpatialData from spatialdata._logging import logger from spatialdata._types import ArrayLike @@ -61,7 +61,9 @@ def _get_bounding_box_corners_in_intrinsic_coordinates( target_coordinate_system The coordinate system the bounding box is defined in. - Returns ------- All the corners of the bounding box in the intrinsic coordinate system of the element. 
The shape + Returns + ------- + All the corners of the bounding box in the intrinsic coordinate system of the element. The shape is (2, 4) when axes has 2 spatial dimensions, and (2, 8) when axes has 3 spatial dimensions. The axes of the intrinsic coordinate system. @@ -73,6 +75,12 @@ def _get_bounding_box_corners_in_intrinsic_coordinates( # get the transformation from the element's intrinsic coordinate system # to the query coordinate space transform_to_query_space = get_transformation(element, to_coordinate_system=target_coordinate_system) + m_without_c, input_axes_without_c, output_axes_without_c = _get_axes_of_tranformation( + element, target_coordinate_system + ) + axes, min_coordinate, max_coordinate = _adjust_bounding_box_to_real_axes( + axes, min_coordinate, max_coordinate, output_axes_without_c + ) # get the coordinates of the bounding box corners bounding_box_corners = get_bounding_box_corners( @@ -155,7 +163,7 @@ def _bounding_box_mask_points( min_coordinate: list[Number] | ArrayLike, max_coordinate: list[Number] | ArrayLike, ) -> da.Array: - """Compute a mask that is true for the points inside of an axis-aligned bounding box.. + """Compute a mask that is true for the points inside an axis-aligned bounding box. Parameters ---------- @@ -164,23 +172,26 @@ def _bounding_box_mask_points( axes The axes that min_coordinate and max_coordinate refer to. min_coordinate - The upper left hand corner of the bounding box (i.e., minimum coordinates - along all dimensions). + The upper left hand corner of the bounding box (i.e., minimum coordinates along all dimensions). max_coordinate - The lower right hand corner of the bounding box (i.e., the maximum coordinates - along all dimensions + The lower right hand corner of the bounding box (i.e., the maximum coordinates along all dimensions). Returns ------- - The mask for the points inside of the bounding box. + The mask for the points inside the bounding box.
""" + element_axes = get_axes_names(points) min_coordinate = _parse_list_into_array(min_coordinate) max_coordinate = _parse_list_into_array(max_coordinate) in_bounding_box_masks = [] for axis_index, axis_name in enumerate(axes): + if axis_name not in element_axes: + continue min_value = min_coordinate[axis_index] in_bounding_box_masks.append(points[axis_name].gt(min_value).to_dask_array(lengths=True)) for axis_index, axis_name in enumerate(axes): + if axis_name not in element_axes: + continue max_value = max_coordinate[axis_index] in_bounding_box_masks.append(points[axis_name].lt(max_value).to_dask_array(lengths=True)) in_bounding_box_masks = da.stack(in_bounding_box_masks, axis=-1) @@ -248,9 +259,6 @@ def _( target_coordinate_system: str, filter_table: bool = True, ) -> SpatialData: - from spatialdata import SpatialData - from spatialdata._core.query.relational_query import _filter_table_by_elements - min_coordinate = _parse_list_into_array(min_coordinate) max_coordinate = _parse_list_into_array(max_coordinate) new_elements = {} @@ -266,8 +274,80 @@ def _( ) new_elements[element_type] = queried_elements - table = _filter_table_by_elements(sdata.table, new_elements) if filter_table else sdata.table - return SpatialData(**new_elements, table=table) + tables = _get_filtered_or_unfiltered_tables(filter_table, new_elements, sdata) + + return SpatialData(**new_elements, tables=tables) + + +def _get_axes_of_tranformation( + element: SpatialElement, target_coordinate_system: str +) -> tuple[ArrayLike, tuple[str, ...], tuple[str, ...]]: + """ + Get the transformation matrix and the transformation's axes (ignoring `c`). + + The transformation is the one from the element's intrinsic coordinate system to the query coordinate space. + Note that the axes which specify the query shape are not necessarily the same as the axes that are the output of the + transformation. + + Parameters + ---------- + element + SpatialData element to be transformed. + target_coordinate_system + The target coordinate system for the transformation. + + Returns + ------- + m_without_c + The transformation from the element's intrinsic coordinate system to the query coordinate space, without the + "c" axis. + input_axes_without_c + The axes of the element's intrinsic coordinate system, without the "c" axis. + output_axes_without_c + The axes of the query coordinate system, without the "c" axis. + + """ + from spatialdata.transformations import get_transformation + + transform_to_query_space = get_transformation(element, to_coordinate_system=target_coordinate_system) + assert isinstance(transform_to_query_space, BaseTransformation) + m = _get_affine_for_element(element, transform_to_query_space) + input_axes_without_c = tuple([ax for ax in m.input_axes if ax != "c"]) + output_axes_without_c = tuple([ax for ax in m.output_axes if ax != "c"]) + m_without_c = m.to_affine_matrix(input_axes=input_axes_without_c, output_axes=output_axes_without_c) + return m_without_c, input_axes_without_c, output_axes_without_c + + +def _adjust_bounding_box_to_real_axes( + axes: tuple[str, ...], + min_coordinate: ArrayLike, + max_coordinate: ArrayLike, + output_axes_without_c: tuple[str, ...], +) -> tuple[tuple[str, ...], ArrayLike, ArrayLike]: + """ + Adjust the bounding box to the real axes of the transformation. + + The bounding box is defined by the user and its axes may not coincide with the axes of the transformation.
""" + if set(axes) != set(output_axes_without_c): + axes_only_in_bb = set(axes) - set(output_axes_without_c) + axes_only_in_output = set(output_axes_without_c) - set(axes) + + # let's remove from the bounding box the axes that are not in the output axes (e.g. querying 2D points with a + # 3D bounding box) + indices_to_remove_from_bb = [axes.index(ax) for ax in axes_only_in_bb] + axes = tuple([ax for ax in axes if ax not in axes_only_in_bb]) + min_coordinate = np.delete(min_coordinate, indices_to_remove_from_bb) + max_coordinate = np.delete(max_coordinate, indices_to_remove_from_bb) + + # if there are axes in the output axes that are not in the bounding box, we need to add them to the bounding box + # with a range that includes everything (e.g. querying 3D points with a 2D bounding box) + for ax in axes_only_in_output: + axes = axes + (ax,) + M = np.finfo(np.float32).max - 1 + min_coordinate = np.append(min_coordinate, -M) + max_coordinate = np.append(max_coordinate, M) + return axes, min_coordinate, max_coordinate @bounding_box_query.register(SpatialImage) @bounding_box_query.register(MultiscaleSpatialImage) def _( @@ -283,7 +363,6 @@ def _( Notes ----- - _____ See https://github.com/scverse/spatialdata/pull/151 for a detailed overview of the logic of this code, and for the cases the comments refer to. """ @@ -300,15 +379,10 @@ def _( max_coordinate=max_coordinate, ) - # get the transformation from the element's intrinsic coordinate system to the query coordinate space - transform_to_query_space = get_transformation(image, to_coordinate_system=target_coordinate_system) - assert isinstance(transform_to_query_space, BaseTransformation) - m = _get_affine_for_element(image, transform_to_query_space) - input_axes_without_c = tuple([ax for ax in m.input_axes if ax != "c"]) - output_axes_without_c = tuple([ax for ax in m.output_axes if ax != "c"]) - m_without_c = m.to_affine_matrix(input_axes=input_axes_without_c, output_axes=output_axes_without_c) + m_without_c, input_axes_without_c, output_axes_without_c = _get_axes_of_tranformation( + image, target_coordinate_system + ) m_without_c_linear = m_without_c[:-1, :-1] - transform_dimension = np.linalg.matrix_rank(m_without_c_linear) transform_coordinate_length = len(output_axes_without_c) data_dim = len(input_axes_without_c) @@ -336,24 +410,13 @@ def _( error_message = ( f"This case is not supported (data with dimension" f"{data_dim} but transformation with rank {transform_dimension}." - f"Please open a GitHub issue if you want to discuss a case." + f"Please open a GitHub issue if you want to discuss a use case." ) raise ValueError(error_message) - if set(axes) != set(output_axes_without_c): - if set(axes).issubset(output_axes_without_c): - logger.warning( - f"The element has axes {output_axes_without_c}, but the query has axes {axes}. Excluding the element " - f"from the query result. In the future we can add support for this case. If you are interested, " - f"please open a GitHub issue." - ) - return None - error_messeage = ( - f"Invalid case.
The bounding box axes are {axes}," - f"the spatial axes in {target_coordinate_system} are" - f"{output_axes_without_c}" - ) - raise ValueError(error_messeage) + axes, min_coordinate, max_coordinate = _adjust_bounding_box_to_real_axes( + axes, min_coordinate, max_coordinate, output_axes_without_c + ) spatial_transform = Affine(m_without_c, input_axes=input_axes_without_c, output_axes=output_axes_without_c) spatial_transform_bb_axes = Affine( @@ -370,7 +433,7 @@ def _( ) else: assert case == 2 - # TODO: we need to intersect the plane in the extrinsic coordiante system with the 3D bounding box. The + # TODO: we need to intersect the plane in the extrinsic coordinate system with the 3D bounding box. The # vertices of this polygons needs to be transformed to the intrinsic coordinate system raise NotImplementedError( "Case 2 (the transformation is embedding 2D data in the 3D space, is not " @@ -570,7 +633,6 @@ def _polygon_query( labels: bool, ) -> SpatialData: from spatialdata._core.query._utils import circles_to_polygons - from spatialdata._core.query.relational_query import _filter_table_by_elements from spatialdata.models import ( PointsModel, ShapesModel, @@ -640,11 +702,10 @@ def _polygon_query( "issue and we will prioritize the implementation." ) - if filter_table and sdata.table is not None: - table = _filter_table_by_elements(sdata.table, {"shapes": new_shapes, "points": new_points}) - else: - table = sdata.table - return SpatialData(shapes=new_shapes, points=new_points, images=new_images, table=table) + elements = {"shapes": new_shapes, "points": new_points} + tables = _get_filtered_or_unfiltered_tables(filter_table, elements, sdata) + + return SpatialData(shapes=new_shapes, points=new_points, images=new_images, tables=tables) # this function is currently excluded from the API documentation. TODO: add it after the refactoring @@ -669,6 +730,9 @@ def polygon_query( The polygon (or list of polygons) to query by target_coordinate_system The coordinate system of the polygon + filter_table + Specifies whether to filter the tables to only include tables that annotate elements in the retrieved + SpatialData object of the query. shapes Whether to filter shapes points @@ -685,8 +749,6 @@ def polygon_query( making this function more general and ergonomic. 
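Looping back to `_bounding_box_mask_points` earlier in this file: its dask logic, including the new behavior of skipping query axes that the element lacks, reduces to the following NumPy sketch:

```python
import numpy as np


def bbox_mask(points: dict, axes: tuple, mins, maxs) -> np.ndarray:
    # One strict inequality per axis and bound, ANDed together; query axes the
    # element does not have (e.g. a z bound against 2D points) are skipped.
    masks = [
        (points[ax] > mn) & (points[ax] < mx)
        for ax, mn, mx in zip(axes, mins, maxs)
        if ax in points
    ]
    return np.logical_and.reduce(masks)


pts = {"x": np.array([0.5, 2.0]), "y": np.array([0.5, 0.5])}
print(bbox_mask(pts, ("z", "y", "x"), [0, 0, 0], [1, 1, 1]))  # [ True False]
```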
""" - from spatialdata._core.query.relational_query import _filter_table_by_elements - # adjust coordinate transformation (this implementation can be made faster) sdata = sdata.transform_to_coordinate_system(target_coordinate_system) @@ -749,6 +811,7 @@ def polygon_query( vv = vv[~vv.index.duplicated(keep="first")] geodataframes[k] = vv - table = _filter_table_by_elements(sdata.table, {"shapes": geodataframes}) if filter_table else sdata.table + elements = {"shapes": geodataframes} + tables = _get_filtered_or_unfiltered_tables(filter_table, elements, sdata) - return SpatialData(shapes=geodataframes, table=table) + return SpatialData(shapes=geodataframes, tables=tables) diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index d0506c48..3cdf91d2 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -2,11 +2,13 @@ import hashlib import os +import warnings from collections.abc import Generator +from itertools import chain from pathlib import Path -from types import MappingProxyType -from typing import TYPE_CHECKING, Any, Union +from typing import TYPE_CHECKING, Any, Literal +import pandas as pd import zarr from anndata import AnnData from dask.dataframe import read_parquet @@ -18,17 +20,10 @@ from ome_zarr.types import JSONDict from spatial_image import SpatialImage -from spatialdata._io import ( - write_image, - write_labels, - write_points, - write_shapes, - write_table, -) -from spatialdata._io._utils import get_backing_files +from spatialdata._core._elements import Images, Labels, Points, Shapes, Tables from spatialdata._logging import logger -from spatialdata._types import ArrayLike -from spatialdata._utils import _natural_keys +from spatialdata._types import ArrayLike, Raster_T +from spatialdata._utils import _error_message_add_element, deprecation_alias from spatialdata.models import ( Image2DModel, Image3DModel, @@ -36,11 +31,12 @@ Labels3DModel, PointsModel, ShapesModel, - SpatialElement, TableModel, - get_axes_names, + check_target_region_column_symmetry, get_model, + get_table_keys, ) +from spatialdata.models._utils import SpatialElement, get_axes_names if TYPE_CHECKING: from spatialdata._core.query.spatial_query import BaseSpatialRequest @@ -54,9 +50,6 @@ Point_s = PointsModel() Table_s = TableModel() -# create a shorthand for raster image types -Raster_T = Union[SpatialImage, MultiscaleSpatialImage] - class SpatialData: """ @@ -116,53 +109,91 @@ class SpatialData: """ - _images: dict[str, Raster_T] = MappingProxyType({}) # type: ignore[assignment] - _labels: dict[str, Raster_T] = MappingProxyType({}) # type: ignore[assignment] - _points: dict[str, DaskDataFrame] = MappingProxyType({}) # type: ignore[assignment] - _shapes: dict[str, GeoDataFrame] = MappingProxyType({}) # type: ignore[assignment] - _table: AnnData | None = None - path: str | None = None - + @deprecation_alias(table="tables") def __init__( self, - images: dict[str, Raster_T] = MappingProxyType({}), # type: ignore[assignment] - labels: dict[str, Raster_T] = MappingProxyType({}), # type: ignore[assignment] - points: dict[str, DaskDataFrame] = MappingProxyType({}), # type: ignore[assignment] - shapes: dict[str, GeoDataFrame] = MappingProxyType({}), # type: ignore[assignment] - table: AnnData | None = None, + images: dict[str, Raster_T] | None = None, + labels: dict[str, Raster_T] | None = None, + points: dict[str, DaskDataFrame] | None = None, + shapes: dict[str, GeoDataFrame] | None = None, + tables: dict[str, AnnData] | Tables | None = 
None, ) -> None: - self.path = None + self._path: Path | None = None + + self._shared_keys: set[str | None] = set() + self._images: Images = Images(shared_keys=self._shared_keys) + self._labels: Labels = Labels(shared_keys=self._shared_keys) + self._points: Points = Points(shared_keys=self._shared_keys) + self._shapes: Shapes = Shapes(shared_keys=self._shared_keys) + self._tables: Tables = Tables(shared_keys=self._shared_keys) + + # Workaround to allow for backward compatibility + if isinstance(tables, AnnData): + tables = {"table": tables} self._validate_unique_element_names( - list(images.keys()) + list(labels.keys()) + list(points.keys()) + list(shapes.keys()) + list(chain.from_iterable([e.keys() for e in [images, labels, points, shapes] if e is not None])) ) if images is not None: - self._images: dict[str, SpatialImage | MultiscaleSpatialImage] = {} for k, v in images.items(): - self._add_image_in_memory(name=k, image=v) + self.images[k] = v if labels is not None: - self._labels: dict[str, SpatialImage | MultiscaleSpatialImage] = {} for k, v in labels.items(): - self._add_labels_in_memory(name=k, labels=v) + self.labels[k] = v if shapes is not None: - self._shapes: dict[str, GeoDataFrame] = {} for k, v in shapes.items(): - self._add_shapes_in_memory(name=k, shapes=v) + self.shapes[k] = v if points is not None: - self._points: dict[str, DaskDataFrame] = {} for k, v in points.items(): - self._add_points_in_memory(name=k, points=v) + self.points[k] = v - if table is not None: - Table_s.validate(table) - self._table = table + if tables is not None: + for k, v in tables.items(): + self.validate_table_in_spatialdata(v) + self.tables[k] = v self._query = QueryManager(self) + def validate_table_in_spatialdata(self, data: AnnData) -> None: + """ + Validate the presence of the annotation target of a SpatialData table in the SpatialData object. + + This method validates a table in the SpatialData object to ensure that if annotation metadata is present, the + annotation target (SpatialElement) is present in the SpatialData object. Otherwise, a warning is raised. + + Parameters + ---------- + data + The table potentially annotating a SpatialElement. + + Warns + ----- + UserWarning + If the table annotates elements not present in the SpatialData object.
""" + TableModel().validate(data) + element_names = [ + element_name for element_type, element_name, _ in self._gen_elements() if element_type != "tables" + ] + if TableModel.ATTRS_KEY in data.uns: + attrs = data.uns[TableModel.ATTRS_KEY] + regions = ( + attrs[TableModel.REGION_KEY] + if isinstance(attrs[TableModel.REGION_KEY], list) + else [attrs[TableModel.REGION_KEY]] + ) + # TODO: check throwing error + if not all(element_name in element_names for element_name in regions): + warnings.warn( + "The table annotates elements not present in the SpatialData object", + UserWarning, + stacklevel=2, + ) + @staticmethod def from_elements_dict(elements_dict: dict[str, SpatialElement | AnnData]) -> SpatialData: """ @@ -183,7 +214,7 @@ def from_elements_dict(elements_dict: dict[str, SpatialElement | AnnData]) -> Sp "labels": {}, "points": {}, "shapes": {}, - "table": None, + "tables": {}, } for k, e in elements_dict.items(): schema = get_model(e) @@ -200,13 +231,200 @@ def from_elements_dict(elements_dict: dict[str, SpatialElement | AnnData]) -> Sp assert isinstance(d["shapes"], dict) d["shapes"][k] = e elif schema == TableModel: - if d["table"] is not None: - raise ValueError("Only one table can be present in the dataset.") - d["table"] = e + assert isinstance(d["tables"], dict) + d["tables"][k] = e else: raise ValueError(f"Unknown schema {schema}") return SpatialData(**d) # type: ignore[arg-type] + @staticmethod + def get_annotated_regions(table: AnnData) -> str | list[str]: + """ + Get the regions annotated by a table. + + Parameters + ---------- + table + The AnnData table for which to retrieve annotated regions. + + Returns + ------- + The annotated regions. + """ + regions, _, _ = get_table_keys(table) + return regions + + @staticmethod + def get_region_key_column(table: AnnData) -> pd.Series: + """Get the column of table.obs that contains, for each row, the region annotated by that row. + + Parameters + ---------- + table + The AnnData table. + + Returns + ------- + The region key column. + + Raises + ------ + KeyError + If the region key column is not found in table.obs. + """ + _, region_key, _ = get_table_keys(table) + if region_key in table.obs: + return table.obs[region_key] + raise KeyError(f"{region_key} is set as region key column. However, the column is not found in table.obs.") + + @staticmethod + def get_instance_key_column(table: AnnData) -> pd.Series: + """ + Return the column of table.obs that contains, for each row, the instance id of that row. + + Parameters + ---------- + table + The AnnData table. + + Returns + ------- + The instance key column. + + Raises + ------ + KeyError + If the instance key column is not found in table.obs. + + """ + _, _, instance_key = get_table_keys(table) + if instance_key in table.obs: + return table.obs[instance_key] + raise KeyError(f"{instance_key} is set as instance key column. However, the column is not found in table.obs.") + + @staticmethod + def _set_table_annotation_target( + table: AnnData, + region: str | pd.Series, + region_key: str, + instance_key: str, + ) -> None: + """ + Set the SpatialElement annotation target of an AnnData table. + + This method sets the target annotation element of a table based on the specified parameters. It creates the + `attrs` dictionary for `table.uns` and updates the annotation metadata of the table only after validating + that the regions are present in the region_key column of table.obs. + + Parameters + ---------- + table + The AnnData object containing the data table.
+ region + The name of the target element for the table annotation. + region_key + The key for the region annotation column in `table.obs`. + instance_key + The key for the instance annotation column in `table.obs`. + + Raises + ------ + ValueError + If `region_key` is not present in the `table.obs` columns. + ValueError + If `instance_key` is not present in the `table.obs` columns. + """ + TableModel()._validate_set_region_key(table, region_key) + TableModel()._validate_set_instance_key(table, instance_key) + attrs = { + TableModel.REGION_KEY: region, + TableModel.REGION_KEY_KEY: region_key, + TableModel.INSTANCE_KEY: instance_key, + } + check_target_region_column_symmetry(table, region_key, region) + table.uns[TableModel.ATTRS_KEY] = attrs + + @staticmethod + def _change_table_annotation_target( + table: AnnData, + region: str | pd.Series, + region_key: None | str = None, + instance_key: None | str = None, + ) -> None: + """Change the annotation target of a table currently having annotation metadata already. + + Parameters + ---------- + table + The table already annotating a SpatialElement. + region + The name of the target SpatialElement for which the table annotation will be changed. + region_key + The name of the region key column in the table. If not provided, it will be extracted from the table's uns + attribute. If present here but also given as argument, the value in the table's uns attribute will be + overwritten. + instance_key + The name of the instance key column in the table. If not provided, it will be extracted from the table's uns + attribute. If present here but also given as argument, the value in the table's uns attribute will be + overwritten. + + Raises + ------ + ValueError + If no region_key is provided, and it is not present in both table.uns['spatialdata_attrs'] and table.obs. + ValueError + If provided region_key is not present in table.obs. + """ + attrs = table.uns[TableModel.ATTRS_KEY] + table_region_key = region_key if region_key else attrs.get(TableModel.REGION_KEY_KEY) + + TableModel()._validate_set_region_key(table, region_key) + TableModel()._validate_set_instance_key(table, instance_key) + check_target_region_column_symmetry(table, table_region_key, region) + attrs[TableModel.REGION_KEY] = region + + def set_table_annotates_spatialelement( + self, + table_name: str, + region: str | pd.Series, + region_key: None | str = None, + instance_key: None | str = None, + ) -> None: + """ + Set the SpatialElement annotation target of a given AnnData table. + + Parameters + ---------- + table_name + The name of the table to set the annotation target for. + region + The name of the target element for the annotation. This can either be a string or a pandas Series object. + region_key + The region key for the annotation. If not specified, defaults to None which means the currently set region + key is reused. + instance_key + The instance key for the annotation. If not specified, defaults to None which means the currently set + instance key is reused. + + Raises + ------ + ValueError + If the annotation SpatialElement target is not present in the SpatialData object. + TypeError + If no current annotation metadata is found and both region_key and instance_key are not specified. 
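To make the annotation workflow concrete, two hedged sketches; every element and column name below is invented for illustration. First, the `spatialdata_attrs` metadata that these helpers read and write, built by hand:

```python
import pandas as pd
from anndata import AnnData

# A hand-built table annotating a hypothetical element "labels1".
obs = pd.DataFrame(
    {
        "region": pd.Categorical(["labels1", "labels1"]),  # annotated element per row
        "instance_id": [1, 2],                             # instance id per row
    }
)
table = AnnData(obs=obs)
table.uns["spatialdata_attrs"] = {
    "region": "labels1",            # key name given by TableModel.REGION_KEY
    "region_key": "region",         # key name given by TableModel.REGION_KEY_KEY
    "instance_key": "instance_id",  # key name given by TableModel.INSTANCE_KEY
}
```

Second, re-targeting such a table inside an existing SpatialData object via the new public method:

```python
# Assumes `sdata` contains a table named "table" and a SpatialElement "labels2",
# and that table.obs["region"] already reads "labels2" for every row
# (check_target_region_column_symmetry enforces this).
sdata.set_table_annotates_spatialelement(
    table_name="table",
    region="labels2",
    region_key="region",
    instance_key="instance_id",
)
```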
+ """ + table = self.tables[table_name] + element_names = {element[1] for element in self._gen_elements()} + if region not in element_names: + raise ValueError(f"Annotation target '{region}' not present as SpatialElement in " f"SpatialData object.") + + if table.uns.get(TableModel.ATTRS_KEY): + self._change_table_annotation_target(table, region, region_key, instance_key) + elif isinstance(region_key, str) and isinstance(instance_key, str): + self._set_table_annotation_target(table, region, region_key, instance_key) + else: + raise TypeError("No current annotation metadata found. Please specify both region_key and instance_key.") + @property def query(self) -> QueryManager: return self._query @@ -231,8 +449,8 @@ def aggregate( Notes ----- - This function calls :func:`spatialdata.aggregate` with the convenience that values and by can be string - without having to specify the values_sdata and by_sdata, which in that case will be replaced by `self`. + This function calls :func:`spatialdata.aggregate` with the convenience that `values` and `by` can be string + without having to specify the `values_sdata` and `by_sdata`, which in that case will be replaced by `self`. Please see :func:`spatialdata.aggregate` for the complete docstring. @@ -263,114 +481,18 @@ def aggregate( def _validate_unique_element_names(element_names: list[str]) -> None: if len(element_names) != len(set(element_names)): duplicates = {x for x in element_names if element_names.count(x) > 1} - raise ValueError( + raise KeyError( f"Element names must be unique. The following element names are used multiple times: {duplicates}" ) - def _add_image_in_memory( - self, name: str, image: SpatialImage | MultiscaleSpatialImage, overwrite: bool = False - ) -> None: - """Add an image element to the SpatialData object. - - Parameters - ---------- - name - name of the image - image - the image element to be added - overwrite - whether to overwrite the image if the name already exists. - """ - self._validate_unique_element_names( - list(self.labels.keys()) + list(self.points.keys()) + list(self.shapes.keys()) + [name] - ) - if name in self._images and not overwrite: - raise KeyError(f"Image {name} already exists in the dataset.") - ndim = len(get_axes_names(image)) - if ndim == 3: - Image2D_s.validate(image) - self._images[name] = image - elif ndim == 4: - Image3D_s.validate(image) - self._images[name] = image - else: - raise ValueError("Only czyx and cyx images supported") - - def _add_labels_in_memory( - self, name: str, labels: SpatialImage | MultiscaleSpatialImage, overwrite: bool = False - ) -> None: - """ - Add a labels element to the SpatialData object. - - Parameters - ---------- - name - name of the labels - labels - the labels element to be added - overwrite - whether to overwrite the labels if the name already exists. - """ - self._validate_unique_element_names( - list(self.images.keys()) + list(self.points.keys()) + list(self.shapes.keys()) + [name] - ) - if name in self._labels and not overwrite: - raise KeyError(f"Labels {name} already exists in the dataset.") - ndim = len(get_axes_names(labels)) - if ndim == 2: - Label2D_s.validate(labels) - self._labels[name] = labels - elif ndim == 3: - Label3D_s.validate(labels) - self._labels[name] = labels - else: - raise ValueError(f"Only yx and zyx labels supported, got {ndim} dimensions") - - def _add_shapes_in_memory(self, name: str, shapes: GeoDataFrame, overwrite: bool = False) -> None: - """ - Add a shapes element to the SpatialData object. 
- - Parameters - ---------- - name - name of the shapes - shapes - the shapes element to be added - overwrite - whether to overwrite the shapes if the name already exists. - """ - self._validate_unique_element_names( - list(self.images.keys()) + list(self.points.keys()) + list(self.labels.keys()) + [name] - ) - if name in self._shapes and not overwrite: - raise KeyError(f"Shapes {name} already exists in the dataset.") - Shape_s.validate(shapes) - self._shapes[name] = shapes - - def _add_points_in_memory(self, name: str, points: DaskDataFrame, overwrite: bool = False) -> None: - """ - Add a points element to the SpatialData object. - - Parameters - ---------- - name - name of the points element - points - the points to be added - overwrite - whether to overwrite the points if the name already exists. - """ - self._validate_unique_element_names( - list(self.images.keys()) + list(self.labels.keys()) + list(self.shapes.keys()) + [name] - ) - if name in self._points and not overwrite: - raise KeyError(f"Points {name} already exists in the dataset.") - Point_s.validate(points) - self._points[name] = points - def is_backed(self) -> bool: """Check if the data is backed by a Zarr storage or if it is in-memory.""" - return self.path is not None + return self._path is not None + + @property + def path(self) -> Path | None: + """Path to the Zarr storage.""" + return self._path # TODO: from a commennt from Giovanni: consolite somewhere in # a future PR (luca: also _init_add_element could be cleaned) @@ -387,7 +509,7 @@ def _get_group_for_element(self, name: str, element_type: str) -> zarr.Group: Returns ------- - either the existing Zarr sub-group or a new one + either the existing Zarr sub-group or a new one. """ store = parse_url(self.path, mode="r+").store root = zarr.group(store=store) @@ -396,13 +518,6 @@ def _get_group_for_element(self, name: str, element_type: str) -> zarr.Group: return element_type_group.require_group(name) def _init_add_element(self, name: str, element_type: str, overwrite: bool) -> zarr.Group: - if self.path is None: - # in the future we can relax this, but this ensures that we don't have objects that are partially backed - # and partially in memory - raise RuntimeError( - "The data is not backed by a Zarr storage. In order to add new elements after " - "initializing a SpatialData object you need to call SpatialData.write() first" - ) store = parse_url(self.path, mode="r+").store root = zarr.group(store=store) assert element_type in ["images", "labels", "points", "shapes"] @@ -430,24 +545,19 @@ def _init_add_element(self, name: str, element_type: str, overwrite: bool) -> za return elem_group return root - def _locate_spatial_element(self, element: SpatialElement) -> tuple[str, str]: + def locate_element(self, element: SpatialElement) -> list[str] | None: """ - Find the SpatialElement within the SpatialData object. + Locate a SpatialElement within the SpatialData object and, if found, returns its Zarr path relative to the root. Parameters ---------- element The queried SpatialElement - Returns ------- - name and type of the element - - Raises - ------ - ValueError - the element is not found or found multiple times in the SpatialData object + A list of Zarr paths of the element relative to the root (multiple copies of the same element are allowed), or + None if the element is not found. 
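The new return convention of `locate_element` can be exercised on a tiny constructed object; a sketch using only public spatialdata APIs (the element name is arbitrary):

```python
import numpy as np
from spatialdata import SpatialData
from spatialdata.models import Image2DModel

img = Image2DModel.parse(np.zeros((1, 10, 10)))  # a (c, y, x) image
sdata = SpatialData(images={"img": img})
print(sdata.locate_element(sdata.images["img"]))  # ['images/img']
# An element that is not part of the object now yields None instead of raising.
```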
""" found: list[SpatialElement] = [] found_element_type: list[str] = [] @@ -459,39 +569,8 @@ def _locate_spatial_element(self, element: SpatialElement) -> tuple[str, str]: found_element_type.append(element_type) found_element_name.append(element_name) if len(found) == 0: - raise ValueError("Element not found in the SpatialData object.") - if len(found) > 1: - raise ValueError( - f"Element found multiple times in the SpatialData object." - f"Found {len(found)} elements with names: {found_element_name}," - f" and types: {found_element_type}" - ) - assert len(found_element_name) == 1 - assert len(found_element_type) == 1 - return found_element_name[0], found_element_type[0] - - def contains_element(self, element: SpatialElement, raise_exception: bool = False) -> bool: - """ - Check if the SpatialElement is contained in the SpatialData object. - - Parameters - ---------- - element - The SpatialElement to check - raise_exception - If True, raise an exception if the element is not found. If False, return False if the element is not found. - - Returns - ------- - True if the element is found; False otherwise (if raise_exception is False). - """ - try: - self._locate_spatial_element(element) - return True - except ValueError as e: - if raise_exception: - raise e - return False + return None + return [f"{found_element_type[i]}/{found_element_name[i]}" for i in range(len(found))] def _write_transformations_to_disk(self, element: SpatialElement) -> None: """ @@ -506,27 +585,37 @@ def _write_transformations_to_disk(self, element: SpatialElement) -> None: transformations = get_transformation(element, get_all=True) assert isinstance(transformations, dict) - found_element_name, found_element_type = self._locate_spatial_element(element) - + located = self.locate_element(element) + if located is None: + raise ValueError( + "Cannot save the transformation to the element as it has not been found in the SpatialData object" + ) if self.path is not None: - group = self._get_group_for_element(name=found_element_name, element_type=found_element_type) - axes = get_axes_names(element) - if isinstance(element, (SpatialImage, MultiscaleSpatialImage)): - from spatialdata._io._utils import ( - overwrite_coordinate_transformations_raster, - ) + for path in located: + found_element_type, found_element_name = path.split("/") + group = self._get_group_for_element(name=found_element_name, element_type=found_element_type) + axes = get_axes_names(element) + if isinstance(element, (SpatialImage, MultiscaleSpatialImage)): + from spatialdata._io._utils import ( + overwrite_coordinate_transformations_raster, + ) - overwrite_coordinate_transformations_raster(group=group, axes=axes, transformations=transformations) - elif isinstance(element, (DaskDataFrame, GeoDataFrame, AnnData)): - from spatialdata._io._utils import ( - overwrite_coordinate_transformations_non_raster, - ) + overwrite_coordinate_transformations_raster(group=group, axes=axes, transformations=transformations) + elif isinstance(element, (DaskDataFrame, GeoDataFrame, AnnData)): + from spatialdata._io._utils import ( + overwrite_coordinate_transformations_non_raster, + ) - overwrite_coordinate_transformations_non_raster(group=group, axes=axes, transformations=transformations) - else: - raise ValueError("Unknown element type") + overwrite_coordinate_transformations_non_raster( + group=group, axes=axes, transformations=transformations + ) + else: + raise ValueError("Unknown element type") - def filter_by_coordinate_system(self, coordinate_system: str | list[str], 
filter_table: bool = True) -> SpatialData: + @deprecation_alias(filter_table="filter_tables") + def filter_by_coordinate_system( + self, coordinate_system: str | list[str], filter_tables: bool = True, include_orphan_tables: bool = False + ) -> SpatialData: """ Filter the SpatialData by one (or a list of) coordinate system. @@ -537,37 +626,104 @@ def filter_by_coordinate_system(self, coordinate_system: str | list[str], filter ---------- coordinate_system The coordinate system(s) to filter by. - filter_table - If True (default), the table will be filtered to only contain regions + filter_tables + If True (default), the tables will be filtered to only contain regions of an element belonging to the specified coordinate system(s). + include_orphan_tables + If True (not default), include tables that do not annotate SpatialElement(s). Only has an effect if + filter_tables is also set to True. Returns ------- The filtered SpatialData. """ - from spatialdata._core.query.relational_query import _filter_table_by_coordinate_system + # TODO: decide whether to add parameter to filter only specific table. + from spatialdata.transformations.operations import get_transformation elements: dict[str, dict[str, SpatialElement]] = {} - element_paths_in_coordinate_system = [] + element_names_in_coordinate_system = [] if isinstance(coordinate_system, str): coordinate_system = [coordinate_system] for element_type, element_name, element in self._gen_elements(): - transformations = get_transformation(element, get_all=True) - assert isinstance(transformations, dict) - for cs in coordinate_system: - if cs in transformations: - if element_type not in elements: - elements[element_type] = {} - elements[element_type][element_name] = element - element_paths_in_coordinate_system.append(element_name) - - if filter_table: - table = _filter_table_by_coordinate_system(self.table, element_paths_in_coordinate_system) + if element_type != "tables": + transformations = get_transformation(element, get_all=True) + assert isinstance(transformations, dict) + for cs in coordinate_system: + if cs in transformations: + if element_type not in elements: + elements[element_type] = {} + elements[element_type][element_name] = element + element_names_in_coordinate_system.append(element_name) + tables = self._filter_tables( + set(), filter_tables, "cs", include_orphan_tables, element_names=element_names_in_coordinate_system + ) + + return SpatialData(**elements, tables=tables) + + # TODO: move to relational query with refactor + def _filter_tables( + self, + names_tables_to_keep: set[str], + filter_tables: bool = True, + by: Literal["cs", "elements"] | None = None, + include_orphan_tables: bool = False, + element_names: str | list[str] | None = None, + elements_dict: dict[str, dict[str, Any]] | None = None, + ) -> Tables | dict[str, AnnData]: + """ + Filter tables by coordinate system or elements or return tables. + + Parameters + ---------- + names_tables_to_keep + The names of the tables to keep even when filter_tables is True. + filter_tables + If True (default), the tables will be filtered to only contain regions + of an element belonging to the specified coordinate system(s) or including only rows annotating specified + elements. + by + Filter mode. Valid values are "cs" or "elements". Default is None. + include_orphan_tables + Flag indicating whether to include orphan tables. Default is False. + element_names + Element names of elements present in specific coordinate system. 
+ elements_dict + Dictionary of elements for filtering the tables. Default is None. + + Returns + ------- + The filtered tables if filter_tables was True, otherwise tables of the SpatialData object. + + """ + if filter_tables: + tables: dict[str, AnnData] | Tables = {} + for table_name, table in self._tables.items(): + if include_orphan_tables and not table.uns.get(TableModel.ATTRS_KEY): + tables[table_name] = table + continue + if table_name in names_tables_to_keep: + tables[table_name] = table + continue + # each mode here requires paths or elements, using assert here to avoid mypy errors. + if by == "cs": + from spatialdata._core.query.relational_query import _filter_table_by_element_names + + assert element_names is not None + table = _filter_table_by_element_names(table, element_names) + if len(table) != 0: + tables[table_name] = table + elif by == "elements": + from spatialdata._core.query.relational_query import _filter_table_by_elements + + assert elements_dict is not None + table = _filter_table_by_elements(table, elements_dict=elements_dict) + if len(table) != 0: + tables[table_name] = table else: - table = self.table + tables = self.tables - return SpatialData(**elements, table=table) + return tables def rename_coordinate_systems(self, rename_dict: dict[str, str]) -> None: """ @@ -599,7 +755,7 @@ def rename_coordinate_systems(self, rename_dict: dict[str, str]) -> None: new_names.append(new_cs) # rename the coordinate systems - for element in self._gen_elements_values(): + for element in self._gen_spatial_element_values(): # get the transformations transformations = get_transformation(element, get_all=True) assert isinstance(transformations, dict) @@ -673,300 +829,15 @@ def transform_to_coordinate_system( ------- The transformed SpatialData. """ - sdata = self.filter_by_coordinate_system(target_coordinate_system, filter_table=False) + sdata = self.filter_by_coordinate_system(target_coordinate_system, filter_tables=False) elements: dict[str, dict[str, SpatialElement]] = {} for element_type, element_name, element in sdata._gen_elements(): - transformed = sdata.transform_element_to_coordinate_system(element, target_coordinate_system) - if element_type not in elements: - elements[element_type] = {} - elements[element_type][element_name] = transformed - return SpatialData(**elements, table=sdata.table) - - def add_image( - self, - name: str, - image: SpatialImage | MultiscaleSpatialImage, - storage_options: JSONDict | list[JSONDict] | None = None, - overwrite: bool = False, - ) -> None: - """ - Add an image to the SpatialData object. - - Parameters - ---------- - name - Key to the element inside the SpatialData object. - image - The image to add, the object needs to pass validation - (see :class:`~spatialdata.Image2DModel` and :class:`~spatialdata.Image3DModel`). - storage_options - Storage options for the Zarr storage. - See https://zarr.readthedocs.io/en/stable/api/storage.html for more details. - overwrite - If True, overwrite the element if it already exists. - - Notes - ----- - If the SpatialData object is backed by a Zarr storage, the image will be written to the Zarr storage. - """ - if self.is_backed(): - files = get_backing_files(image) - assert self.path is not None - target_path = os.path.realpath(os.path.join(self.path, "images", name)) - if target_path in files: - raise ValueError( - "Cannot add the image to the SpatialData object because it would overwrite an element that it is" - "using for backing. 
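As a sketch of how these two filter modes surface in the public API; the coordinate-system name `"global"` is an assumption:

```python
# Keep only elements in the "global" coordinate system; tables are filtered
# down to the rows annotating those elements (the "cs" mode above).
filtered = sdata.filter_by_coordinate_system("global")

# Also carry over orphan tables, i.e. tables annotating no SpatialElement.
with_orphans = sdata.filter_by_coordinate_system(
    "global", filter_tables=True, include_orphan_tables=True
)

# Skip table filtering entirely: every table is passed through unchanged.
unfiltered = sdata.filter_by_coordinate_system("global", filter_tables=False)
```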
See more here: https://github.com/scverse/spatialdata/pull/138" - ) - self._add_image_in_memory(name=name, image=image, overwrite=overwrite) - # old code to support overwriting the backing file - # with tempfile.TemporaryDirectory() as tmpdir: - # store = parse_url(Path(tmpdir) / "data.zarr", mode="w").store - # root = zarr.group(store=store) - # write_image( - # image=self.images[name], - # group=root, - # name=name, - # storage_options=storage_options, - # ) - # src_element_path = Path(store.path) / name - # assert isinstance(self.path, str) - # tgt_element_path = Path(self.path) / "images" / name - # if os.path.isdir(tgt_element_path) and overwrite: - # element_store = parse_url(tgt_element_path, mode="w").store - # _ = zarr.group(store=element_store, overwrite=True) - # element_store.close() - # pathlib.Path(tgt_element_path).mkdir(parents=True, exist_ok=True) - # for file in os.listdir(str(src_element_path)): - # src_file = src_element_path / file - # tgt_file = tgt_element_path / file - # os.rename(src_file, tgt_file) - # from spatialdata._io.read import _read_multiscale - # - # # reload the image from the Zarr storage so that now the element is lazy loaded, and most importantly, - # # from the correct storage - # image = _read_multiscale(str(tgt_element_path), raster_type="image") - # self._add_image_in_memory(name=name, image=image, overwrite=True) - elem_group = self._init_add_element(name=name, element_type="images", overwrite=overwrite) - write_image( - image=self.images[name], - group=elem_group, - name=name, - storage_options=storage_options, - ) - from spatialdata._io.io_raster import _read_multiscale - - # reload the image from the Zarr storage so that now the element is lazy loaded, and most importantly, - # from the correct storage - assert elem_group.path == "images" - path = Path(elem_group.store.path) / "images" / name - image = _read_multiscale(path, raster_type="image") - self._add_image_in_memory(name=name, image=image, overwrite=True) - else: - self._add_image_in_memory(name=name, image=image, overwrite=overwrite) - - def add_labels( - self, - name: str, - labels: SpatialImage | MultiscaleSpatialImage, - storage_options: JSONDict | list[JSONDict] | None = None, - overwrite: bool = False, - ) -> None: - """ - Add labels to the SpatialData object. - - Parameters - ---------- - name - Key to the element inside the SpatialData object. - labels - The labels (masks) to add, the object needs to pass validation - (see :class:`~spatialdata.Labels2DModel` and :class:`~spatialdata.Labels3DModel`). - storage_options - Storage options for the Zarr storage. - See https://zarr.readthedocs.io/en/stable/api/storage.html for more details. - overwrite - If True, overwrite the element if it already exists. - - Notes - ----- - If the SpatialData object is backed by a Zarr storage, the image will be written to the Zarr storage. - """ - if self.is_backed(): - files = get_backing_files(labels) - assert self.path is not None - target_path = os.path.realpath(os.path.join(self.path, "labels", name)) - if target_path in files: - raise ValueError( - "Cannot add the image to the SpatialData object because it would overwrite an element that it is" - "using for backing. We are considering changing this behavior to allow the overwriting of " - "elements used for backing. 
If you would like to support this use case please leave a comment on " - "https://github.com/scverse/spatialdata/pull/138" - ) - self._add_labels_in_memory(name=name, labels=labels, overwrite=overwrite) - # old code to support overwriting the backing file - # with tempfile.TemporaryDirectory() as tmpdir: - # store = parse_url(Path(tmpdir) / "data.zarr", mode="w").store - # root = zarr.group(store=store) - # write_labels( - # labels=self.labels[name], - # group=root, - # name=name, - # storage_options=storage_options, - # ) - # src_element_path = Path(store.path) / "labels" / name - # assert isinstance(self.path, str) - # tgt_element_path = Path(self.path) / "labels" / name - # if os.path.isdir(tgt_element_path) and overwrite: - # element_store = parse_url(tgt_element_path, mode="w").store - # _ = zarr.group(store=element_store, overwrite=True) - # element_store.close() - # pathlib.Path(tgt_element_path).mkdir(parents=True, exist_ok=True) - # for file in os.listdir(str(src_element_path)): - # src_file = src_element_path / file - # tgt_file = tgt_element_path / file - # os.rename(src_file, tgt_file) - # from spatialdata._io.read import _read_multiscale - # - # # reload the labels from the Zarr storage so that now the element is lazy loaded, and most importantly, - # # from the correct storage - # labels = _read_multiscale(str(tgt_element_path), raster_type="labels") - # self._add_labels_in_memory(name=name, labels=labels, overwrite=True) - elem_group = self._init_add_element(name=name, element_type="labels", overwrite=overwrite) - write_labels( - labels=self.labels[name], - group=elem_group, - name=name, - storage_options=storage_options, - ) - # reload the labels from the Zarr storage so that now the element is lazy loaded, and most importantly, - # from the correct storage - from spatialdata._io.io_raster import _read_multiscale - - # just a check to make sure that things go as expected - assert elem_group.path == "" - path = Path(elem_group.store.path) / "labels" / name - labels = _read_multiscale(path, raster_type="labels") - self._add_labels_in_memory(name=name, labels=labels, overwrite=True) - else: - self._add_labels_in_memory(name=name, labels=labels, overwrite=overwrite) - - def add_points( - self, - name: str, - points: DaskDataFrame, - overwrite: bool = False, - ) -> None: - """ - Add points to the SpatialData object. - - Parameters - ---------- - name - Key to the element inside the SpatialData object. - points - The points to add, the object needs to pass validation (see :class:`spatialdata.PointsModel`). - storage_options - Storage options for the Zarr storage. - See https://zarr.readthedocs.io/en/stable/api/storage.html for more details. - overwrite - If True, overwrite the element if it already exists. - - Notes - ----- - If the SpatialData object is backed by a Zarr storage, the image will be written to the Zarr storage. - """ - if self.is_backed(): - files = get_backing_files(points) - assert self.path is not None - target_path = os.path.realpath(os.path.join(self.path, "points", name, "points.parquet")) - if target_path in files: - raise ValueError( - "Cannot add the image to the SpatialData object because it would overwrite an element that it is" - "using for backing. We are considering changing this behavior to allow the overwriting of " - "elements used for backing. 
If you would like to support this use case please leave a comment on " - "https://github.com/scverse/spatialdata/pull/138" - ) - self._add_points_in_memory(name=name, points=points, overwrite=overwrite) - # old code to support overwriting the backing file - # with tempfile.TemporaryDirectory() as tmpdir: - # store = parse_url(Path(tmpdir) / "data.zarr", mode="w").store - # root = zarr.group(store=store) - # write_points( - # points=self.points[name], - # group=root, - # name=name, - # ) - # src_element_path = Path(store.path) / name - # assert isinstance(self.path, str) - # tgt_element_path = Path(self.path) / "points" / name - # if os.path.isdir(tgt_element_path) and overwrite: - # element_store = parse_url(tgt_element_path, mode="w").store - # _ = zarr.group(store=element_store, overwrite=True) - # element_store.close() - # pathlib.Path(tgt_element_path).mkdir(parents=True, exist_ok=True) - # for file in os.listdir(str(src_element_path)): - # src_file = src_element_path / file - # tgt_file = tgt_element_path / file - # os.rename(src_file, tgt_file) - # from spatialdata._io.read import _read_points - # - # # reload the points from the Zarr storage so that now the element is lazy loaded, and most importantly, - # # from the correct storage - # points = _read_points(str(tgt_element_path)) - # self._add_points_in_memory(name=name, points=points, overwrite=True) - elem_group = self._init_add_element(name=name, element_type="points", overwrite=overwrite) - write_points( - points=self.points[name], - group=elem_group, - name=name, - ) - # reload the points from the Zarr storage so that now the element is lazy loaded, and most importantly, - # from the correct storage - from spatialdata._io.io_points import _read_points - - assert elem_group.path == "points" - - path = Path(elem_group.store.path) / "points" / name - points = _read_points(path) - self._add_points_in_memory(name=name, points=points, overwrite=True) - else: - self._add_points_in_memory(name=name, points=points, overwrite=overwrite) - - def add_shapes( - self, - name: str, - shapes: GeoDataFrame, - overwrite: bool = False, - ) -> None: - """ - Add shapes to the SpatialData object. - - Parameters - ---------- - name - Key to the element inside the SpatialData object. - shapes - The shapes to add, the object needs to pass validation (see :class:`~spatialdata.ShapesModel`). - storage_options - Storage options for the Zarr storage. - See https://zarr.readthedocs.io/en/stable/api/storage.html for more details. - overwrite - If True, overwrite the element if it already exists. - - Notes - ----- - If the SpatialData object is backed by a Zarr storage, the image will be written to the Zarr storage. 
- """ - self._add_shapes_in_memory(name=name, shapes=shapes, overwrite=overwrite) - if self.is_backed(): - elem_group = self._init_add_element(name=name, element_type="shapes", overwrite=overwrite) - write_shapes( - shapes=self.shapes[name], - group=elem_group, - name=name, - ) - # no reloading of the file storage since the AnnData is not lazy loaded + if element_type != "tables": + transformed = sdata.transform_element_to_coordinate_system(element, target_coordinate_system) + if element_type not in elements: + elements[element_type] = {} + elements[element_type][element_name] = transformed + return SpatialData(**elements, tables=sdata.tables) def write( self, @@ -975,17 +846,21 @@ def write( overwrite: bool = False, consolidate_metadata: bool = True, ) -> None: + from spatialdata._io import write_image, write_labels, write_points, write_shapes, write_table + from spatialdata._io._utils import get_dask_backing_files + """Write the SpatialData object to Zarr.""" if isinstance(file_path, str): file_path = Path(file_path) assert isinstance(file_path, Path) - if self.is_backed() and self.path != file_path: + if self.is_backed() and str(self.path) != str(file_path): logger.info(f"The Zarr file used for backing will now change from {self.path} to {file_path}") # old code to support overwriting the backing file # target_path = None # tmp_zarr_file = None + if os.path.exists(file_path): if parse_url(file_path, mode="r") is None: raise ValueError( @@ -993,14 +868,22 @@ def write( "a Zarr store. Overwriting non-Zarr stores is not supported to prevent accidental " "data loss." ) - if not overwrite and self.path != str(file_path): + if not overwrite: raise ValueError("The Zarr store already exists. Use `overwrite=True` to overwrite the store.") - raise ValueError( - "The file path specified is the same as the one used for backing. " - "Overwriting the backing file is not supported to prevent accidental data loss." - "We are discussing how to support this use case in the future, if you would like us to " - "support it please leave a comment on https://github.com/scverse/spatialdata/pull/138" - ) + if self.is_backed() and str(self.path) == str(file_path): + raise ValueError( + "The file path specified is the same as the one used for backing. " + "Overwriting the backing file is not supported to prevent accidental data loss." + "We are discussing how to support this use case in the future, if you would like us to " + "support it please leave a comment on https://github.com/scverse/spatialdata/pull/138" + ) + if any(Path(fp).resolve().is_relative_to(file_path.resolve()) for fp in get_dask_backing_files(self)): + raise ValueError( + "The file path specified is a parent directory of one or more files used for backing for one or " + "more elements in the SpatialData object. You can either load every element of the SpatialData " + "object in memory, or save the current spatialdata object to a different path." + ) + # old code to support overwriting the backing file # else: # target_path = tempfile.TemporaryDirectory() @@ -1023,14 +906,13 @@ def write( # self.path = str(file_path) # else: # self.path = str(tmp_zarr_file) - self.path = str(file_path) + self._path = Path(file_path) try: if len(self.images): root.create_group(name="images") # add_image_in_memory will delete and replace the same key in self.images, # so we need to make a copy of the keys. 
Same for the other elements keys = self.images.keys() - from spatialdata._io.io_raster import _read_multiscale for name in keys: elem_group = self._init_add_element(name=name, element_type="images", overwrite=overwrite) @@ -1041,17 +923,16 @@ def write( storage_options=storage_options, ) + # TODO(giovp): fix or remove # reload the image from the Zarr storage so that now the element is lazy loaded, # and most importantly, from the correct storage - element_path = Path(self.path) / "images" / name - image = _read_multiscale(element_path, raster_type="image") - self._add_image_in_memory(name=name, image=image, overwrite=True) + # element_path = Path(self.path) / "images" / name + # _read_multiscale(element_path, raster_type="image") if len(self.labels): root.create_group(name="labels") # keys = list(self.labels.keys()) keys = self.labels.keys() - from spatialdata._io.io_raster import _read_multiscale for name in keys: elem_group = self._init_add_element(name=name, element_type="labels", overwrite=overwrite) @@ -1062,17 +943,16 @@ def write( storage_options=storage_options, ) + # TODO(giovp): fix or remove # reload the labels from the Zarr storage so that now the element is lazy loaded, # and most importantly, from the correct storage - element_path = Path(self.path) / "labels" / name - labels = _read_multiscale(element_path, raster_type="labels") - self._add_labels_in_memory(name=name, labels=labels, overwrite=True) + # element_path = Path(self.path) / "labels" / name + # _read_multiscale(element_path, raster_type="labels") if len(self.points): root.create_group(name="points") # keys = list(self.points.keys()) keys = self.points.keys() - from spatialdata._io.io_points import _read_points for name in keys: elem_group = self._init_add_element(name=name, element_type="points", overwrite=overwrite) @@ -1081,12 +961,12 @@ def write( group=elem_group, name=name, ) - element_path = Path(self.path) / "points" / name + # TODO(giovp): fix or remove + # element_path = Path(self.path) / "points" / name - # reload the points from the Zarr storage so that the element is lazy loaded, - # and most importantly, from the correct storage - points = _read_points(element_path) - self._add_points_in_memory(name=name, points=points, overwrite=True) + # # reload the points from the Zarr storage so that the element is lazy loaded, + # # and most importantly, from the correct storage + # _read_points(element_path) if len(self.shapes): root.create_group(name="shapes") @@ -1099,14 +979,14 @@ def write( group=elem_group, name=name, ) - # no reloading of the file storage since the AnnData is not lazy loaded - if self.table is not None: - elem_group = root.create_group(name="table") - write_table(table=self.table, group=elem_group, name="table") + if len(self.tables): + elem_group = root.create_group(name="tables") + for key in self.tables: + write_table(table=self.tables[key], group=elem_group, name=key) except Exception as e: # noqa: B902 - self.path = None + self._path = None raise e if consolidate_metadata: @@ -1147,56 +1027,76 @@ def write( # else: # raise ValueError(f"Unknown element type {element_type}") # self.__getattribute__(element_type)[name] = element - assert isinstance(self.path, str) + assert isinstance(self.path, Path) @property - def table(self) -> AnnData: + def tables(self) -> Tables: """ - Return the table. + Return tables dictionary. Returns ------- - The table. 
+        dict[str, AnnData]
+            Either an empty dictionary or a dictionary mapping table names (as keys) to the AnnData tables
+            themselves (as values).
         """
-        return self._table
+        return self._tables

-    @table.setter
-    def table(self, table: AnnData) -> None:
-        """
-        Set the table of a SpatialData object in a object that doesn't contain a table.
+    @tables.setter
+    def tables(self, tables: dict[str, AnnData]) -> None:
+        """Set tables."""
+        self._shared_keys = self._shared_keys - set(self._tables.keys())
+        self._tables = Tables(shared_keys=self._shared_keys)
+        for k, v in tables.items():
+            self._tables[k] = v

-        Parameters
-        ----------
-        table
-            The table to set.
+    @property
+    def table(self) -> None | AnnData:
+        """
+        Return the table named "table" from `tables`, if it exists.

-        Notes
-        -----
-        If a table is already present, it needs to be removed first.
-        The table needs to pass validation (see :class:`~spatialdata.TableModel`).
-        If the SpatialData object is backed by a Zarr storage, the table will be written to the Zarr storage.
+        Returns
+        -------
+        The table.
         """
+        warnings.warn(
+            "Table accessor will be deprecated with SpatialData version 0.1, use sdata.tables instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        # The isinstance check still returns the table even if the AnnData has 0 rows.
+        if isinstance(self.tables.get("table"), AnnData):
+            return self.tables["table"]
+        return None
+
+    @table.setter
+    def table(self, table: AnnData) -> None:
+        warnings.warn(
+            "Table setter will be deprecated with SpatialData version 0.1, use tables instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
         TableModel().validate(table)
-        if self.table is not None:
-            raise ValueError("The table already exists. Use del sdata.table to remove it first.")
-        self._table = table
-        if self.is_backed():
-            store = parse_url(self.path, mode="r+").store
-            root = zarr.group(store=store)
-            elem_group = root.require_group(name="table")
-            write_table(table=self.table, group=elem_group, name="table")
+        if self.tables.get("table") is not None:
+            raise ValueError("The table already exists. Use del sdata.tables['table'] to remove it first.")
+        self.tables["table"] = table

     @table.deleter
     def table(self) -> None:
         """Delete the table."""
-        self._table = None
-        if self.is_backed():
-            store = parse_url(self.path, mode="r+").store
-            root = zarr.group(store=store)
-            del root["table/table"]
+        warnings.warn(
+            "del sdata.table will be deprecated with SpatialData version 0.1, use del sdata.tables['table'] instead.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        if self.tables.get("table"):
+            del self.tables["table"]
+        else:
+            # More informative than the error in the zarr library.
+            raise KeyError("table with name 'table' not present in the SpatialData object.")

     @staticmethod
-    def read(file_path: str, selection: tuple[str] | None = None) -> SpatialData:
+    def read(file_path: Path | str, selection: tuple[str] | None = None) -> SpatialData:
         """
         Read a SpatialData object from a Zarr storage (on-disk or remote).
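In practice, the migration from the single-table API to the `tables` mapping looks like the following sketch; `sdata` and the table name are illustrative:

```python
import numpy as np
from anndata import AnnData

# New style: `tables` is a mapping from name to AnnData, and several tables
# (including orphan tables that annotate no element) can coexist.
sdata.tables["celltype_table"] = AnnData(X=np.zeros((3, 2)))
print(list(sdata.tables.keys()))

# Old style still works, but warns and only sees a table literally named "table".
legacy = sdata.table  # DeprecationWarning; None unless "table" exists
# del sdata.table     # DeprecationWarning; KeyError if "table" is absent
```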
@@ -1215,32 +1115,98 @@ def read(file_path: str, selection: tuple[str] | None = None) -> SpatialData: return read_zarr(file_path, selection=selection) + def add_image( + self, + name: str, + image: SpatialImage | MultiscaleSpatialImage, + storage_options: JSONDict | list[JSONDict] | None = None, + overwrite: bool = False, + ) -> None: + _error_message_add_element() + + def add_labels( + self, + name: str, + labels: SpatialImage | MultiscaleSpatialImage, + storage_options: JSONDict | list[JSONDict] | None = None, + overwrite: bool = False, + ) -> None: + _error_message_add_element() + + def add_points( + self, + name: str, + points: DaskDataFrame, + overwrite: bool = False, + ) -> None: + _error_message_add_element() + + def add_shapes( + self, + name: str, + shapes: GeoDataFrame, + overwrite: bool = False, + ) -> None: + _error_message_add_element() + @property - def images(self) -> dict[str, SpatialImage | MultiscaleSpatialImage]: + def images(self) -> Images: """Return images as a Dict of name to image data.""" return self._images + @images.setter + def images(self, images: dict[str, Raster_T]) -> None: + """Set images.""" + self._shared_keys = self._shared_keys - set(self._images.keys()) + self._images = Images(shared_keys=self._shared_keys) + for k, v in images.items(): + self._images[k] = v + @property - def labels(self) -> dict[str, SpatialImage | MultiscaleSpatialImage]: + def labels(self) -> Labels: """Return labels as a Dict of name to label data.""" return self._labels + @labels.setter + def labels(self, labels: dict[str, Raster_T]) -> None: + """Set labels.""" + self._shared_keys = self._shared_keys - set(self._labels.keys()) + self._labels = Labels(shared_keys=self._shared_keys) + for k, v in labels.items(): + self._labels[k] = v + @property - def points(self) -> dict[str, DaskDataFrame]: + def points(self) -> Points: """Return points as a Dict of name to point data.""" return self._points + @points.setter + def points(self, points: dict[str, DaskDataFrame]) -> None: + """Set points.""" + self._shared_keys = self._shared_keys - set(self._points.keys()) + self._points = Points(shared_keys=self._shared_keys) + for k, v in points.items(): + self._points[k] = v + @property - def shapes(self) -> dict[str, GeoDataFrame]: + def shapes(self) -> Shapes: """Return shapes as a Dict of name to shape data.""" return self._shapes + @shapes.setter + def shapes(self, shapes: dict[str, GeoDataFrame]) -> None: + """Set shapes.""" + self._shared_keys = self._shared_keys - set(self._shapes.keys()) + self._shapes = Shapes(shared_keys=self._shared_keys) + for k, v in shapes.items(): + self._shapes[k] = v + @property def coordinate_systems(self) -> list[str]: from spatialdata.transformations.operations import get_transformation all_cs = set() - gen = self._gen_elements_values() + gen = self._gen_spatial_element_values() for obj in gen: transformations = get_transformation(obj, get_all=True) assert isinstance(transformations, dict) @@ -1256,7 +1222,7 @@ def _non_empty_elements(self) -> list[str]: non_empty_elements The names of the elements that are not empty. """ - all_elements = ["images", "labels", "points", "shapes", "table"] + all_elements = ["images", "labels", "points", "shapes", "tables"] return [ element for element in all_elements @@ -1276,6 +1242,7 @@ def _gen_repr( ------- The string representation of the SpatialData object. 
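Together with the removed `add_*` methods (which now raise via `_error_message_add_element`), these setters make plain assignment the way to add elements. A hypothetical sketch, where `image2d` and `points_ddf` are assumed to be already-parsed, model-validated elements:

```python
# Instead of sdata.add_image("scan", image2d) / sdata.add_points(...):
sdata.images["scan"] = image2d
sdata.points["detections"] = points_ddf

# Assignment is in-memory only; persisting to Zarr is now an explicit step.
# sdata.write("data.zarr")

# Assigning a whole dict swaps the category while keeping names
# collision-checked through the shared-keys mechanism shown above.
sdata.points = {"detections": points_ddf}
```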
""" + from spatialdata._utils import _natural_keys def rreplace(s: str, old: str, new: str, occurrence: int) -> str: li = s.rsplit(old, occurrence) @@ -1293,71 +1260,64 @@ def h(s: str) -> str: attribute = getattr(self, attr) descr += f"\n{h('level0')}{attr.capitalize()}" - if isinstance(attribute, AnnData): + + unsorted_elements = attribute.items() + sorted_elements = sorted(unsorted_elements, key=lambda x: _natural_keys(x[0])) + for k, v in sorted_elements: descr += f"{h('empty_line')}" - descr_class = attribute.__class__.__name__ - descr += f"{h('level1.0')}{attribute!r}: {descr_class} {attribute.shape}" - descr = rreplace(descr, h("level1.0"), " └── ", 1) - else: - unsorted_elements = attribute.items() - sorted_elements = sorted(unsorted_elements, key=lambda x: _natural_keys(x[0])) - for k, v in sorted_elements: - descr += f"{h('empty_line')}" - descr_class = v.__class__.__name__ - if attr == "shapes": - descr += f"{h(attr + 'level1.1')}{k!r}: {descr_class} " f"shape: {v.shape} (2D shapes)" - elif attr == "points": - length: int | None = None - if len(v.dask.layers) == 1: - name, layer = v.dask.layers.items().__iter__().__next__() - if "read-parquet" in name: - t = layer.creation_info["args"] - assert isinstance(t, tuple) - assert len(t) == 1 - parquet_file = t[0] - table = read_parquet(parquet_file) - length = len(table) - else: - # length = len(v) - length = None + descr_class = v.__class__.__name__ + if attr == "shapes": + descr += f"{h(attr + 'level1.1')}{k!r}: {descr_class} " f"shape: {v.shape} (2D shapes)" + elif attr == "points": + length: int | None = None + if len(v.dask.layers) == 1: + name, layer = v.dask.layers.items().__iter__().__next__() + if "read-parquet" in name: + t = layer.creation_info["args"] + assert isinstance(t, tuple) + assert len(t) == 1 + parquet_file = t[0] + table = read_parquet(parquet_file) + length = len(table) else: + # length = len(v) length = None + else: + length = None - n = len(get_axes_names(v)) - dim_string = f"({n}D points)" + n = len(get_axes_names(v)) + dim_string = f"({n}D points)" - assert len(v.shape) == 2 - if length is not None: - shape_str = f"({length}, {v.shape[1]})" - else: - shape_str = ( - "(" - + ", ".join( - [str(dim) if not isinstance(dim, Delayed) else "" for dim in v.shape] - ) - + ")" - ) - descr += f"{h(attr + 'level1.1')}{k!r}: {descr_class} " f"with shape: {shape_str} {dim_string}" + assert len(v.shape) == 2 + if length is not None: + shape_str = f"({length}, {v.shape[1]})" else: - if isinstance(v, SpatialImage): - descr += f"{h(attr + 'level1.1')}{k!r}: {descr_class}[{''.join(v.dims)}] {v.shape}" - elif isinstance(v, MultiscaleSpatialImage): - shapes = [] - dims: str | None = None - for pyramid_level in v: - dataset_names = list(v[pyramid_level].keys()) - assert len(dataset_names) == 1 - dataset_name = dataset_names[0] - vv = v[pyramid_level][dataset_name] - shape = vv.shape - if dims is None: - dims = "".join(vv.dims) - shapes.append(shape) - descr += ( - f"{h(attr + 'level1.1')}{k!r}: {descr_class}[{dims}] " f"{', '.join(map(str, shapes))}" - ) - else: - raise TypeError(f"Unknown type {type(v)}") + shape_str = ( + "(" + + ", ".join([str(dim) if not isinstance(dim, Delayed) else "" for dim in v.shape]) + + ")" + ) + descr += f"{h(attr + 'level1.1')}{k!r}: {descr_class} " f"with shape: {shape_str} {dim_string}" + elif attr == "tables": + descr += f"{h(attr + 'level1.1')}{k!r}: {descr_class} {v.shape}" + else: + if isinstance(v, SpatialImage): + descr += f"{h(attr + 'level1.1')}{k!r}: {descr_class}[{''.join(v.dims)}] 
{v.shape}" + elif isinstance(v, MultiscaleSpatialImage): + shapes = [] + dims: str | None = None + for pyramid_level in v: + dataset_names = list(v[pyramid_level].keys()) + assert len(dataset_names) == 1 + dataset_name = dataset_names[0] + vv = v[pyramid_level][dataset_name] + shape = vv.shape + if dims is None: + dims = "".join(vv.dims) + shapes.append(shape) + descr += f"{h(attr + 'level1.1')}{k!r}: {descr_class}[{dims}] " f"{', '.join(map(str, shapes))}" + else: + raise TypeError(f"Unknown type {type(v)}") if last_attr is True: descr = descr.replace(h("empty_line"), "\n ") else: @@ -1366,7 +1326,7 @@ def h(s: str) -> str: descr = rreplace(descr, h("level0"), "└── ", 1) descr = descr.replace(h("level0"), "├── ") - for attr in ["images", "labels", "points", "table", "shapes"]: + for attr in ["images", "labels", "points", "tables", "shapes"]: descr = rreplace(descr, h(attr + "level1.1"), " └── ", 1) descr = descr.replace(h(attr + "level1.1"), " ├── ") @@ -1380,13 +1340,14 @@ def h(s: str) -> str: gen = self._gen_elements() elements_in_cs: dict[str, list[str]] = {} for k, name, obj in gen: - transformations = get_transformation(obj, get_all=True) - assert isinstance(transformations, dict) - target_css = transformations.keys() - if cs in target_css: - if k not in elements_in_cs: - elements_in_cs[k] = [] - elements_in_cs[k].append(name) + if not isinstance(obj, AnnData): + transformations = get_transformation(obj, get_all=True) + assert isinstance(transformations, dict) + target_css = transformations.keys() + if cs in target_css: + if k not in elements_in_cs: + elements_in_cs[k] = [] + elements_in_cs[k].append(name) for element_names in elements_in_cs.values(): element_names.sort(key=_natural_keys) if len(elements_in_cs) > 0: @@ -1402,26 +1363,97 @@ def h(s: str) -> str: descr += "\n" return descr - def _gen_elements_values(self) -> Generator[SpatialElement, None, None]: + def _gen_spatial_element_values(self) -> Generator[SpatialElement, None, None]: + """ + Generate spatial element objects contained in the SpatialData instance. + + Returns + ------- + Generator[SpatialElement, None, None] + A generator that yields spatial element objects contained in the SpatialData instance. + + """ for element_type in ["images", "labels", "points", "shapes"]: d = getattr(SpatialData, element_type).fget(self) yield from d.values() - def _gen_elements(self) -> Generator[tuple[str, str, SpatialElement], None, None]: - for element_type in ["images", "labels", "points", "shapes"]: + def _gen_elements( + self, include_table: bool = False + ) -> Generator[tuple[str, str, SpatialElement | AnnData], None, None]: + """ + Generate elements contained in the SpatialData instance. + + Parameters + ---------- + include_table + Whether to also generate table elements. + + Returns + ------- + A generator object that returns a tuple containing the type of the element, its name, and the element + itself. + """ + element_types = ["images", "labels", "points", "shapes"] + if include_table: + element_types.append("tables") + for element_type in element_types: d = getattr(SpatialData, element_type).fget(self) for k, v in d.items(): yield element_type, k, v - def _find_element(self, element_name: str) -> tuple[str, str, SpatialElement]: - for element_type, element_name_, element in self._gen_elements(): + def gen_spatial_elements(self) -> Generator[tuple[str, str, SpatialElement], None, None]: + """ + Generate spatial elements within the SpatialData object. 
+
+        This method generates spatial elements (images, labels, points and shapes).
+
+        Returns
+        -------
+        A generator that yields tuples of element type, element name, and the SpatialElement itself.
+        """
+        return self._gen_elements()
+
+    def gen_elements(self) -> Generator[tuple[str, str, SpatialElement | AnnData], None, None]:
+        """
+        Generate elements within the SpatialData object.
+
+        This method generates elements in the SpatialData object (images, labels, points, shapes and tables).
+
+        Returns
+        -------
+        A generator that yields tuples of element type, element name, and the element itself.
+        """
+        return self._gen_elements(include_table=True)
+
+    def _find_element(self, element_name: str) -> tuple[str, str, SpatialElement | AnnData]:
+        """
+        Retrieve the element from the SpatialData instance matching element_name.
+
+        Parameters
+        ----------
+        element_name
+            The name of the element to find.
+
+        Returns
+        -------
+        A tuple containing the element type, element name, and the retrieved element itself.
+
+        Raises
+        ------
+        KeyError
+            If the element with the given name cannot be found.
+        """
+        for element_type, element_name_, element in self.gen_elements():
             if element_name_ == element_name:
                 return element_type, element_name_, element
         else:
             raise KeyError(f"Could not find element with name {element_name!r}")

     @classmethod
-    def init_from_elements(cls, elements: dict[str, SpatialElement], table: AnnData | None = None) -> SpatialData:
+    @deprecation_alias(table="tables")
+    def init_from_elements(
+        cls, elements: dict[str, SpatialElement], tables: AnnData | dict[str, AnnData] | None = None
+    ) -> SpatialData:
         """
         Create a SpatialData object from a dict of named elements and an optional table.

@@ -1429,8 +1461,8 @@ def init_from_elements(cls, elements: dict[str, SpatialElement], table: AnnData
         ----------
         elements
             A dict of named elements.
-        table
-            An optional table.
+        tables
+            An optional table or a dictionary of tables.

         Returns
         -------
@@ -1449,7 +1481,46 @@ def init_from_elements(cls, elements: dict[str, SpatialElement], table: AnnData
             assert model == ShapesModel
             element_type = "shapes"
         elements_dict.setdefault(element_type, {})[name] = element
-        return cls(**elements_dict, table=table)
+        return cls(**elements_dict, tables=tables)
+
+    def subset(
+        self, element_names: list[str], filter_tables: bool = True, include_orphan_tables: bool = False
+    ) -> SpatialData:
+        """
+        Subset the SpatialData object.
+
+        Parameters
+        ----------
+        element_names
+            The names of the elements to include in the subset. If a name refers to a table, that table is
+            included in full, even if filter_tables is True.
+        filter_tables
+            If True (default), the tables are filtered to only contain rows annotating regions of the
+            selected elements.
+        include_orphan_tables
+            If True (not default), include tables that do not annotate SpatialElement(s). Only has an effect if
+            filter_tables is also set to True.
+
+        Returns
+        -------
+        The subsetted SpatialData object.
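For example, under the assumption that `sdata` holds elements named `"cells"` and `"image_hires"` and a table `"my_table"`:

```python
# Keep two elements plus the table rows that annotate them.
small = sdata.subset(["cells", "image_hires"])

# Naming a table keeps it in full, regardless of filtering.
with_table = sdata.subset(["cells", "my_table"])

# Orphan tables are dropped unless explicitly requested.
with_orphans = sdata.subset(["cells"], include_orphan_tables=True)
```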
+ """ + elements_dict: dict[str, SpatialElement] = {} + names_tables_to_keep: set[str] = set() + for element_type, element_name, element in self._gen_elements(include_table=True): + if element_name in element_names: + if element_type != "tables": + elements_dict.setdefault(element_type, {})[element_name] = element + else: + names_tables_to_keep.add(element_name) + tables = self._filter_tables( + names_tables_to_keep, + filter_tables, + "elements", + include_orphan_tables, + elements_dict=elements_dict, + ) + return SpatialData(**elements_dict, tables=tables) def __getitem__(self, item: str) -> SpatialElement: """ @@ -1480,17 +1551,17 @@ def __setitem__(self, key: str, value: SpatialElement | AnnData) -> None: """ schema = get_model(value) if schema in (Image2DModel, Image3DModel): - self.add_image(key, value) + self.images[key] = value elif schema in (Labels2DModel, Labels3DModel): - self.add_labels(key, value) + self.labels[key] = value elif schema == PointsModel: - self.add_points(key, value) + self.points[key] = value elif schema == ShapesModel: - self.add_shapes(key, value) + self.shapes[key] = value elif schema == TableModel: - raise TypeError("Use the table property to set the table (e.g. sdata.table = value)") + self.tables[key] = value else: - raise TypeError(f"Unknown element type with schema{schema!r}") + raise TypeError(f"Unknown element type with schema: {schema!r}.") class QueryManager: diff --git a/src/spatialdata/_io/__init__.py b/src/spatialdata/_io/__init__.py index fd72da5c..d9fc3cd6 100644 --- a/src/spatialdata/_io/__init__.py +++ b/src/spatialdata/_io/__init__.py @@ -1,3 +1,4 @@ +from spatialdata._io._utils import get_dask_backing_files from spatialdata._io.format import SpatialDataFormatV01 from spatialdata._io.io_points import write_points from spatialdata._io.io_raster import write_image, write_labels @@ -11,4 +12,5 @@ "write_shapes", "write_table", "SpatialDataFormatV01", + "get_dask_backing_files", ] diff --git a/src/spatialdata/_io/_utils.py b/src/spatialdata/_io/_utils.py index bfa12721..f5caa59d 100644 --- a/src/spatialdata/_io/_utils.py +++ b/src/spatialdata/_io/_utils.py @@ -8,17 +8,24 @@ from collections.abc import Generator, Mapping from contextlib import contextmanager from functools import singledispatch -from typing import TYPE_CHECKING, Any +from typing import Any +import numpy as np import zarr +from anndata import AnnData +from anndata import read_zarr as read_anndata_zarr +from anndata.experimental import read_elem +from dask.array.core import Array as DaskArray from dask.dataframe.core import DataFrame as DaskDataFrame from multiscale_spatial_image import MultiscaleSpatialImage from ome_zarr.format import Format from ome_zarr.writer import _get_valid_axes from spatial_image import SpatialImage -from xarray import DataArray +from spatialdata._core.spatialdata import SpatialData +from spatialdata._logging import logger from spatialdata._utils import iterate_pyramid_levels +from spatialdata.models import TableModel from spatialdata.models._utils import ( MappingToCoordinateSystem_t, ValidAxis_t, @@ -30,9 +37,6 @@ _get_current_output_axes, ) -if TYPE_CHECKING: - from spatialdata import SpatialData - # suppress logger debug from ome_zarr with context manager @contextmanager @@ -175,8 +179,8 @@ def _are_directories_identical( if _root_dir2 is None: _root_dir2 = dir2 if exclude_regexp is not None and ( - re.match(rf"{_root_dir1}/" + exclude_regexp, str(dir1)) - or re.match(rf"{_root_dir2}/" + exclude_regexp, str(dir2)) + 
re.match(rf"{re.escape(str(_root_dir1))}/" + exclude_regexp, str(dir1)) + or re.match(rf"{re.escape(str(_root_dir2))}/" + exclude_regexp, str(dir2)) ): return True @@ -196,8 +200,6 @@ def _are_directories_identical( def _compare_sdata_on_disk(a: SpatialData, b: SpatialData) -> bool: - from spatialdata import SpatialData - if not isinstance(a, SpatialData) or not isinstance(b, SpatialData): return False # TODO: if the sdata object is backed on disk, don't create a new zarr file @@ -207,37 +209,59 @@ def _compare_sdata_on_disk(a: SpatialData, b: SpatialData) -> bool: return _are_directories_identical(os.path.join(tmpdir, "a.zarr"), os.path.join(tmpdir, "b.zarr")) -def _get_backing_files_raster(raster: DataArray) -> list[str]: - files = [] - for k, v in raster.data.dask.layers.items(): - if k.startswith("original-from-zarr-"): - mapping = v.mapping[k] - path = mapping.store.path - files.append(os.path.realpath(path)) - return files +@singledispatch +def get_dask_backing_files(element: SpatialData | SpatialImage | MultiscaleSpatialImage | DaskDataFrame) -> list[str]: + """ + Get the backing files that appear in the Dask computational graph of an element/any element of a SpatialData object. + Parameters + ---------- + element + The element to get the backing files from. -@singledispatch -def get_backing_files(element: SpatialImage | MultiscaleSpatialImage | DaskDataFrame) -> list[str]: + Returns + ------- + List of backing files. + + Notes + ----- + It is possible for lazy objects to be constructed from multiple files. + """ raise TypeError(f"Unsupported type: {type(element)}") -@get_backing_files.register(SpatialImage) +@get_dask_backing_files.register(SpatialData) +def _(element: SpatialData) -> list[str]: + files: set[str] = set() + for e in element._gen_spatial_element_values(): + if isinstance(e, (SpatialImage, MultiscaleSpatialImage, DaskDataFrame)): + files = files.union(get_dask_backing_files(e)) + return list(files) + + +@get_dask_backing_files.register(SpatialImage) def _(element: SpatialImage) -> list[str]: - return _get_backing_files_raster(element) + return _get_backing_files(element.data) -@get_backing_files.register(MultiscaleSpatialImage) +@get_dask_backing_files.register(MultiscaleSpatialImage) def _(element: MultiscaleSpatialImage) -> list[str]: xdata0 = next(iter(iterate_pyramid_levels(element))) - return _get_backing_files_raster(xdata0) + return _get_backing_files(xdata0.data) -@get_backing_files.register(DaskDataFrame) +@get_dask_backing_files.register(DaskDataFrame) def _(element: DaskDataFrame) -> list[str]: + return _get_backing_files(element) + + +def _get_backing_files(element: DaskArray | DaskDataFrame) -> list[str]: files = [] - layers = element.dask.layers - for k, v in layers.items(): + for k, v in element.dask.layers.items(): + if k.startswith("original-from-zarr-"): + mapping = v.mapping[k] + path = mapping.store.path + files.append(os.path.realpath(path)) if k.startswith("read-parquet-"): t = v.creation_info["args"] assert isinstance(t, tuple) @@ -286,6 +310,57 @@ def save_transformations(sdata: SpatialData) -> None: """ from spatialdata.transformations import get_transformation, set_transformation - for element in sdata._gen_elements_values(): + for element in sdata._gen_spatial_element_values(): transformations = get_transformation(element, get_all=True) set_transformation(element, transformations, set_all=True, write_to_sdata=sdata) + + +def read_table_and_validate( + zarr_store_path: str, group: zarr.Group, subgroup: zarr.Group, tables: dict[str, AnnData] 
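A short usage sketch of the dispatch above; `sdata` is assumed to contain at least one lazily loaded element, and the points name is illustrative:

```python
from spatialdata._io import get_dask_backing_files

# Union of backing files over all lazy elements of a SpatialData object;
# this is what the new write() guard checks against.
print(get_dask_backing_files(sdata))

# Or for a single element, e.g. a points DaskDataFrame backed by Parquet,
# where the "read-parquet-" layer exposes the file path.
print(get_dask_backing_files(sdata.points["detections"]))
```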
+) -> dict[str, AnnData]: + """ + Read in tables in the tables Zarr.group of a SpatialData Zarr store. + + Parameters + ---------- + zarr_store_path + The path to the Zarr store. + group + The parent group containing the subgroup. + subgroup + The subgroup containing the tables. + tables + A dictionary of tables. + + Returns + ------- + The modified dictionary with the tables. + """ + count = 0 + for table_name in subgroup: + f_elem = subgroup[table_name] + f_elem_store = os.path.join(zarr_store_path, f_elem.path) + if isinstance(group.store, zarr.storage.ConsolidatedMetadataStore): + tables[table_name] = read_elem(f_elem) + # we can replace read_elem with read_anndata_zarr after this PR gets into a release (>= 0.6.5) + # https://github.com/scverse/anndata/pull/1057#pullrequestreview-1530623183 + # table = read_anndata_zarr(f_elem) + else: + tables[table_name] = read_anndata_zarr(f_elem_store) + if TableModel.ATTRS_KEY in tables[table_name].uns: + # fill out eventual missing attributes that has been omitted because their value was None + attrs = tables[table_name].uns[TableModel.ATTRS_KEY] + if "region" not in attrs: + attrs["region"] = None + if "region_key" not in attrs: + attrs["region_key"] = None + if "instance_key" not in attrs: + attrs["instance_key"] = None + # fix type for region + if "region" in attrs and isinstance(attrs["region"], np.ndarray): + attrs["region"] = attrs["region"].tolist() + + count += 1 + + logger.debug(f"Found {count} elements in {subgroup}") + return tables diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index 7fafb676..57a6069c 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -1,4 +1,3 @@ -import os from pathlib import Path from typing import Any, Literal, Optional, Union @@ -67,7 +66,8 @@ def _read_multiscale( # and for instance in the xenium example encoded_ngff_transformations = multiscales[0]["coordinateTransformations"] transformations = _get_transformations_from_ngff_dict(encoded_ngff_transformations) - name = os.path.basename(node.metadata["name"]) + # TODO: what to do with name? For now remove? 
+ # name = os.path.basename(node.metadata["name"]) # if image, read channels metadata channels: Optional[list[Any]] = None if raster_type == "image" and channels_metadata is not None: @@ -79,7 +79,7 @@ def _read_multiscale( data = node.load(Multiscales).array(resolution=d, version=fmt.version) multiscale_image[f"scale{i}"] = DataArray( data, - name=name, + name="image", dims=axes, coords={"c": channels} if channels is not None else {}, ) @@ -89,7 +89,7 @@ def _read_multiscale( data = node.load(Multiscales).array(resolution=datasets[0], version=fmt.version) si = SpatialImage( data, - name=name, + name="image", dims=axes, coords={"c": channels} if channels is not None else {}, ) diff --git a/src/spatialdata/_io/io_table.py b/src/spatialdata/_io/io_table.py index 72ae5f4c..ead604af 100644 --- a/src/spatialdata/_io/io_table.py +++ b/src/spatialdata/_io/io_table.py @@ -4,6 +4,7 @@ from ome_zarr.format import Format from spatialdata._io.format import CurrentTablesFormat +from spatialdata.models import TableModel def write_table( @@ -13,10 +14,13 @@ def write_table( group_type: str = "ngff:regions_table", fmt: Format = CurrentTablesFormat(), ) -> None: - region = table.uns["spatialdata_attrs"]["region"] - region_key = table.uns["spatialdata_attrs"].get("region_key", None) - instance_key = table.uns["spatialdata_attrs"].get("instance_key", None) - fmt.validate_table(table, region_key, instance_key) + if TableModel.ATTRS_KEY in table.uns: + region = table.uns["spatialdata_attrs"]["region"] + region_key = table.uns["spatialdata_attrs"].get("region_key", None) + instance_key = table.uns["spatialdata_attrs"].get("instance_key", None) + fmt.validate_table(table, region_key, instance_key) + else: + region, region_key, instance_key = (None, None, None) write_adata(group, name, table) # creates group[name] tables_group = group[name] tables_group.attrs["spatialdata-encoding-type"] = group_type diff --git a/src/spatialdata/_io/io_zarr.py b/src/spatialdata/_io/io_zarr.py index 7b4f286c..f5e378a2 100644 --- a/src/spatialdata/_io/io_zarr.py +++ b/src/spatialdata/_io/io_zarr.py @@ -1,21 +1,18 @@ import logging import os +import warnings from pathlib import Path from typing import Optional, Union -import numpy as np import zarr from anndata import AnnData -from anndata import read_zarr as read_anndata_zarr -from anndata.experimental import read_elem -from spatialdata import SpatialData -from spatialdata._io._utils import ome_zarr_logger +from spatialdata._core.spatialdata import SpatialData +from spatialdata._io._utils import ome_zarr_logger, read_table_and_validate from spatialdata._io.io_points import _read_points from spatialdata._io.io_raster import _read_multiscale from spatialdata._io.io_shapes import _read_shapes from spatialdata._logging import logger -from spatialdata.models import TableModel def _open_zarr_store(store: Union[str, Path, zarr.Group]) -> tuple[zarr.Group, str]: @@ -61,10 +58,11 @@ def read_zarr(store: Union[str, Path, zarr.Group], selection: Optional[tuple[str images = {} labels = {} points = {} - table: Optional[AnnData] = None + tables: dict[str, AnnData] = {} shapes = {} - selector = {"images", "labels", "points", "shapes", "table"} if not selection else set(selection or []) + # TODO: remove table once deprecated. 
+ selector = {"images", "labels", "points", "shapes", "tables", "table"} if not selection else set(selection or []) logger.debug(f"Reading selection {selector}") # read multiscale images @@ -123,36 +121,21 @@ def read_zarr(store: Union[str, Path, zarr.Group], selection: Optional[tuple[str shapes[subgroup_name] = _read_shapes(f_elem_store) count += 1 logger.debug(f"Found {count} elements in {group}") + if "tables" in selector and "tables" in f: + group = f["tables"] + tables = read_table_and_validate(f_store_path, f, group, tables) if "table" in selector and "table" in f: - group = f["table"] - count = 0 - for subgroup_name in group: - if Path(subgroup_name).name.startswith("."): - # skip hidden files like .zgroup or .zmetadata - continue - f_elem = group[subgroup_name] - f_elem_store = os.path.join(f_store_path, f_elem.path) - if isinstance(f.store, zarr.storage.ConsolidatedMetadataStore): - table = read_elem(f_elem) - # we can replace read_elem with read_anndata_zarr after this PR gets into a release (>= 0.6.5) - # https://github.com/scverse/anndata/pull/1057#pullrequestreview-1530623183 - # table = read_anndata_zarr(f_elem) - else: - table = read_anndata_zarr(f_elem_store) - if TableModel.ATTRS_KEY in table.uns: - # fill out eventual missing attributes that has been omitted because their value was None - attrs = table.uns[TableModel.ATTRS_KEY] - if "region" not in attrs: - attrs["region"] = None - if "region_key" not in attrs: - attrs["region_key"] = None - if "instance_key" not in attrs: - attrs["instance_key"] = None - # fix type for region - if "region" in attrs and isinstance(attrs["region"], np.ndarray): - attrs["region"] = attrs["region"].tolist() - count += 1 + warnings.warn( + f"Table group found in zarr store at location {f_store_path}. Please update the zarr store" + f"to use tables instead.", + DeprecationWarning, + stacklevel=2, + ) + subgroup_name = "table" + group = f[subgroup_name] + tables = read_table_and_validate(f_store_path, f, group, tables) + logger.debug(f"Found {count} elements in {group}") sdata = SpatialData( @@ -160,7 +143,7 @@ def read_zarr(store: Union[str, Path, zarr.Group], selection: Optional[tuple[str labels=labels, points=points, shapes=shapes, - table=table, + tables=tables, ) - sdata.path = str(store) + sdata._path = Path(store) return sdata diff --git a/src/spatialdata/_types.py b/src/spatialdata/_types.py index 6ae68e1a..30b235c1 100644 --- a/src/spatialdata/_types.py +++ b/src/spatialdata/_types.py @@ -1,8 +1,12 @@ from __future__ import annotations +from typing import Union + import numpy as np +from multiscale_spatial_image import MultiscaleSpatialImage +from spatial_image import SpatialImage -__all__ = ["ArrayLike", "DTypeLike"] +__all__ = ["ArrayLike", "DTypeLike", "Raster_T"] try: from numpy.typing import DTypeLike, NDArray @@ -11,3 +15,5 @@ except (ImportError, TypeError): ArrayLike = np.ndarray # type: ignore[misc] DTypeLike = np.dtype # type: ignore[misc] + +Raster_T = Union[SpatialImage, MultiscaleSpatialImage] diff --git a/src/spatialdata/_utils.py b/src/spatialdata/_utils.py index 59eaec6c..205308e8 100644 --- a/src/spatialdata/_utils.py +++ b/src/spatialdata/_utils.py @@ -1,9 +1,11 @@ from __future__ import annotations +import functools import re +import warnings from collections.abc import Generator from copy import deepcopy -from typing import TYPE_CHECKING, Union +from typing import Any, Callable, TypeVar, Union import numpy as np import pandas as pd @@ -25,9 +27,7 @@ # I was using "from numbers import Number" but this led to mypy 
diff --git a/src/spatialdata/_types.py b/src/spatialdata/_types.py
index 6ae68e1a..30b235c1 100644
--- a/src/spatialdata/_types.py
+++ b/src/spatialdata/_types.py
@@ -1,8 +1,12 @@
 from __future__ import annotations

+from typing import Union
+
 import numpy as np
+from multiscale_spatial_image import MultiscaleSpatialImage
+from spatial_image import SpatialImage

-__all__ = ["ArrayLike", "DTypeLike"]
+__all__ = ["ArrayLike", "DTypeLike", "Raster_T"]

 try:
     from numpy.typing import DTypeLike, NDArray
@@ -11,3 +15,5 @@ except (ImportError, TypeError):
     ArrayLike = np.ndarray  # type: ignore[misc]
     DTypeLike = np.dtype  # type: ignore[misc]
+
+Raster_T = Union[SpatialImage, MultiscaleSpatialImage]
diff --git a/src/spatialdata/_utils.py b/src/spatialdata/_utils.py
index 59eaec6c..205308e8 100644
--- a/src/spatialdata/_utils.py
+++ b/src/spatialdata/_utils.py
@@ -1,9 +1,11 @@
 from __future__ import annotations

+import functools
 import re
+import warnings
 from collections.abc import Generator
 from copy import deepcopy
-from typing import TYPE_CHECKING, Union
+from typing import Any, Callable, TypeVar, Union

 import numpy as np
 import pandas as pd
@@ -25,9 +27,7 @@
 # I was using "from numbers import Number" but this led to mypy errors, so I switched to the following:
 Number = Union[int, float]
-
-if TYPE_CHECKING:
-    pass
+RT = TypeVar("RT")


 def _parse_list_into_array(array: list[Number] | ArrayLike) -> ArrayLike:
@@ -231,3 +231,83 @@ def _deepcopy_geodataframe(gdf: GeoDataFrame) -> GeoDataFrame:
     new_attrs = deepcopy(gdf.attrs)
     new_gdf.attrs = new_attrs
     return new_gdf
+
+
+# TODO: change to paramspec as soon as we drop support for python 3.9, see https://stackoverflow.com/a/68290080
+def deprecation_alias(**aliases: str) -> Callable[[Callable[..., RT]], Callable[..., RT]]:
+    """
+    Decorate a function to warn the user when a deprecated argument is used.
+
+    Parameters
+    ----------
+    aliases
+        Deprecated argument aliases to be mapped to the new arguments.
+
+    Returns
+    -------
+    A decorator that can be used to mark an argument for deprecation and substitute it with the new argument.
+
+    Raises
+    ------
+    TypeError
+        If the provided aliases are not of string type.
+
+    Example
+    -------
+    Assuming we have an argument 'table' set for deprecation and we want to warn the user and substitute it with
+    'tables':
+
+    ```python
+    @deprecation_alias(table="tables")
+    def my_function(tables: AnnData | dict[str, AnnData]):
+        pass
+    ```
+    """
+
+    def deprecation_decorator(f: Callable[..., RT]) -> Callable[..., RT]:
+        @functools.wraps(f)
+        def wrapper(*args: Any, **kwargs: Any) -> RT:
+            class_name = f.__qualname__
+            rename_kwargs(f.__name__, kwargs, aliases, class_name)
+            return f(*args, **kwargs)
+
+        return wrapper
+
+    return deprecation_decorator
+
+
+def rename_kwargs(func_name: str, kwargs: dict[str, Any], aliases: dict[str, str], class_name: None | str) -> None:
+    """Rename deprecated function arguments and give a warning when they are used."""
+    for alias, new in aliases.items():
+        if alias in kwargs:
+            class_name = class_name + "." if class_name else ""
+            if new in kwargs:
+                raise TypeError(
+                    f"{class_name}{func_name} received both {alias} and {new} as arguments!"
+                    f" {alias} is being deprecated in SpatialData version 0.1, only use {new} instead."
+                )
+            warnings.warn(
+                message=(
+                    f"`{alias}` is being deprecated as an argument to `{class_name}{func_name}` in SpatialData "
+                    f"version 0.1, switch to `{new}` instead."
+                ),
+                category=DeprecationWarning,
+                stacklevel=3,
+            )
+            kwargs[new] = kwargs.pop(alias)
+
+
+def _error_message_add_element() -> None:
+    raise RuntimeError(
+        "The functions add_image(), add_labels(), add_points() and add_shapes() have been removed in favor of "
+        "dict-like access to the elements. Please use the following syntax to add an element:\n"
+        "\n"
+        '\tsdata.images["image_name"] = image\n'
+        '\tsdata.labels["labels_name"] = labels\n'
+        "\t...\n"
+        "\n"
+        "The new syntax does not automatically update the disk storage, so you need to call sdata.write() when "
+        "the in-memory object is ready to be saved.\n"
+        "To save only a new specific element to an existing Zarr storage please use the functions write_image(), "
+        "write_labels(), write_points(), write_shapes() and write_table(). We are going to make these calls more "
+        "ergonomic in a follow-up PR."
+    )
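A short demonstration of the decorator added above; the function name is illustrative and `deprecation_alias` is the internal helper defined in this patch:

```python
from anndata import AnnData
from spatialdata._utils import deprecation_alias

@deprecation_alias(table="tables")
def process(tables: dict[str, AnnData] | None = None) -> None:
    print(sorted(tables))

# The old keyword emits a DeprecationWarning and is forwarded under the new
# name, so the two calls below are equivalent.
process(table={"cells": AnnData()})
process(tables={"cells": AnnData()})
```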
diff --git a/src/spatialdata/dataloader/__init__.py b/src/spatialdata/dataloader/__init__.py
index f9262f85..819ab58e 100644
--- a/src/spatialdata/dataloader/__init__.py
+++ b/src/spatialdata/dataloader/__init__.py
@@ -1,6 +1,4 @@
-import contextlib
-
-with contextlib.suppress(ImportError):
+try:
     from spatialdata.dataloader.datasets import ImageTilesDataset
-
-__all__ = ["ImageTilesDataset"]
+except ImportError:
+    ImageTilesDataset = None  # type: ignore[assignment, misc]
diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py
index e31750f2..66fc5b4c 100644
--- a/src/spatialdata/dataloader/datasets.py
+++ b/src/spatialdata/dataloader/datasets.py
@@ -1,15 +1,20 @@
 from __future__ import annotations

-from typing import TYPE_CHECKING, Any, Callable
+from collections.abc import Mapping
+from functools import partial
+from itertools import chain
+from types import MappingProxyType
+from typing import Any, Callable

 import numpy as np
+import pandas as pd
+from anndata import AnnData
 from geopandas import GeoDataFrame
 from multiscale_spatial_image import MultiscaleSpatialImage
-from shapely import MultiPolygon, Point, Polygon
-from spatial_image import SpatialImage
+from scipy.sparse import issparse
 from torch.utils.data import Dataset

-from spatialdata._core.operations.rasterize import rasterize
+from spatialdata._core.spatialdata import SpatialData
 from spatialdata._utils import _affine_matrix_multiplication
 from spatialdata.models import (
     Image2DModel,
@@ -17,175 +22,389 @@
     Labels2DModel,
     Labels3DModel,
     ShapesModel,
+    TableModel,
     get_axes_names,
     get_model,
 )
-from spatialdata.transformations.operations import get_transformation
+from spatialdata.transformations import get_transformation
 from spatialdata.transformations.transformations import BaseTransformation

-if TYPE_CHECKING:
-    from spatialdata import SpatialData
+__all__ = ["ImageTilesDataset"]


 class ImageTilesDataset(Dataset):
+    """
+    Dataloader for SpatialData.
+
+    :class:`torch.utils.data.Dataset` for loading tiles from a :class:`spatialdata.SpatialData` object.
+
+    By default, the dataset returns a `SpatialData` object, but when `return_annotations`
+    is set, the dataset returns a tuple containing:
+
+        - the tile image, centered in the target coordinate system of the region.
+        - a vector or scalar value from the table.
+
+    Parameters
+    ----------
+    sdata
+        The SpatialData object.
+    regions_to_images
+        A mapping between regions and images. The regions are used to compute the tile centers, while the images are
+        used to get the pixel values.
+    regions_to_coordinate_systems
+        A mapping between regions and coordinate systems. The coordinate systems are used to transform both
+        the region coordinates used for the tiles and the images.
+    tile_scale
+        The scale of the tiles. This is used only if the `regions` are `shapes`.
+        It is a scaling factor applied to either the radius (spots) or length (polygons) of the `shapes`
+        according to the geometry type of the `shapes` element:
+
+            - if `shapes` are circles (spots), the radius is scaled by `tile_scale`.
+            - if `shapes` are polygons, the length of the polygon is scaled by `tile_scale`.
+
+        If `tile_dim_in_units` is passed, `tile_scale` is ignored.
+    tile_dim_in_units
+        The dimension of the requested tile in the units of the target coordinate system.
+        This specifies the extent of the tile. This is not related to the size in pixels of each returned tile.
+    rasterize
+        If True, the images are rasterized using :func:`spatialdata.rasterize`.
+        If False, they are queried using :func:`spatialdata.bounding_box_query`.
+    return_annotations
+        If not None, a value from the table is returned together with the image tile.
+        Only columns in :attr:`anndata.AnnData.obs` and :attr:`anndata.AnnData.X`
+        can be returned. If None, the dataset returns a `SpatialData` object containing the image tile
+        and the corresponding table row.
+    transform
+        A callable that takes as input the tuple (image, table_value) and returns a new tuple (when
+        `return_annotations` is not None); a callable that takes as input the `SpatialData` object and
+        returns a tuple when `return_annotations` is `None`.
+        This parameter can be used to apply data transformations (for instance a normalization operation) to the
+        image and the table value.
+    rasterize_kwargs
+        Keyword arguments passed to :func:`spatialdata.rasterize` if `rasterize` is True.
+        This argument can be used for instance to choose the pixel dimension of the image tile.
+
+    Returns
+    -------
+    :class:`torch.utils.data.Dataset` for loading tiles from a :class:`spatialdata.SpatialData`.
+    """
+
+    INSTANCE_KEY = "instance_id"
+    CS_KEY = "cs"
+    REGION_KEY = "region"
+    IMAGE_KEY = "image"
+
     def __init__(
         self,
         sdata: SpatialData,
         regions_to_images: dict[str, str],
-        tile_dim_in_units: float,
-        tile_dim_in_pixels: int,
-        target_coordinate_system: str = "global",
-        # unused at the moment, see
-        transform: Callable[[SpatialData], Any] | None = None,
+        regions_to_coordinate_systems: dict[str, str],
+        tile_scale: float = 1.0,
+        tile_dim_in_units: float | None = None,
+        rasterize: bool = False,
+        return_annotations: str | list[str] | None = None,
+        transform: Callable[[Any], Any] | None = None,
+        rasterize_kwargs: Mapping[str, Any] = MappingProxyType({}),
     ):
-        """
-        Torch Dataset that returns image tiles around regions from a SpatialData object.
-
-        Parameters
-        ----------
-        sdata
-            The SpatialData object containing the regions and images from which to extract the tiles from.
-        regions_to_images
-            A dictionary mapping the regions element key we want to extract the tiles around to the images element key
-            we want to get the image data from.
-        tile_dim_in_units
-            The dimension of the requested tile in the units of the target coordinate system. This specifies the extent
-            of the image each tile is querying. This is not related he size in pixel of each returned tile.
-        tile_dim_in_pixels
-            The dimension of the requested tile in pixels. This specifies the size of the output tiles that we will get,
-            independently of which extent of the image the tile is covering.
-        target_coordinate_system
-            The coordinate system in which the tile_dim_in_units is specified.
- """ - # TODO: we can extend this code to support: - # - automatic dermination of the tile_dim_in_pixels to match the image resolution (prevent down/upscaling) - # - use the bounding box query instead of the raster function if the user wants - self.sdata = sdata - self.regions_to_images = regions_to_images - self.tile_dim_in_units = tile_dim_in_units - self.tile_dim_in_pixels = tile_dim_in_pixels + from spatialdata import bounding_box_query + from spatialdata._core.operations.rasterize import rasterize as rasterize_fn + + self._validate(sdata, regions_to_images, regions_to_coordinate_systems) + self._preprocess(tile_scale, tile_dim_in_units) + + self._crop_image: Callable[..., Any] = ( + partial( + rasterize_fn, + **dict(rasterize_kwargs), + ) + if rasterize + else bounding_box_query # type: ignore[assignment] + ) + self._return = self._get_return(return_annotations) self.transform = transform - self.target_coordinate_system = target_coordinate_system - - self.n_spots_dict = self._compute_n_spots_dict() - self.n_spots = sum(self.n_spots_dict.values()) - - def _validate_regions_to_images(self) -> None: - for region_key, image_key in self.regions_to_images.items(): - regions_element = self.sdata[region_key] - images_element = self.sdata[image_key] - # we could allow also for points - if get_model(regions_element) not in [ShapesModel, Labels2DModel, Labels3DModel]: - raise ValueError("regions_element must be a shapes element or a labels element") - if get_model(images_element) not in [Image2DModel, Image3DModel]: - raise ValueError("images_element must be an image element") - - def _compute_n_spots_dict(self) -> dict[str, int]: - n_spots_dict = {} - for region_key in self.regions_to_images: - element = self.sdata[region_key] - # we could allow also points - if isinstance(element, GeoDataFrame): - n_spots_dict[region_key] = len(element) - elif isinstance(element, (SpatialImage, MultiscaleSpatialImage)): - raise NotImplementedError("labels not supported yet") - else: - raise ValueError("element must be a geodataframe or a spatial image") - return n_spots_dict - - def _get_region_info_for_index(self, index: int) -> tuple[str, int]: - # TODO: this implmenetation can be improved - i = 0 - for region_key, n_spots in self.n_spots_dict.items(): - if index < i + n_spots: - return region_key, index - i - i += n_spots - raise ValueError(f"index {index} is out of range") - def __len__(self) -> int: - return self.n_spots + def _validate( + self, + sdata: SpatialData, + regions_to_images: dict[str, str], + regions_to_coordinate_systems: dict[str, str], + ) -> None: + """Validate input parameters.""" + self._region_key = sdata.tables["table"].uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY] + self._instance_key = sdata.tables["table"].uns[TableModel.ATTRS_KEY][TableModel.INSTANCE_KEY] + available_regions = sdata.tables["table"].obs[self._region_key].cat.categories + cs_region_image = [] # list of tuples (coordinate_system, region, image) - def __getitem__(self, idx: int) -> Any | SpatialData: - from spatialdata import SpatialData - - if idx >= self.n_spots: - raise IndexError() - regions_name, region_index = self._get_region_info_for_index(idx) - regions = self.sdata[regions_name] - # TODO: here we just need to compute the centroids, - # we probably want to move this functionality to a different file - if isinstance(regions, GeoDataFrame): - dims = get_axes_names(regions) - region = regions.iloc[region_index] - shape = regions.geometry.iloc[0] - if isinstance(shape, Polygon): - xy = 
-                xy = region.geometry.centroid.coords.xy
-                centroid = np.array([[xy[0][0], xy[1][0]]])
-            elif isinstance(shape, MultiPolygon):
-                raise NotImplementedError("MultiPolygon not supported yet")
-            elif isinstance(shape, Point):
-                xy = region.geometry.coords.xy
-                centroid = np.array([[xy[0][0], xy[1][0]]])
-            else:
-                raise RuntimeError(f"Unsupported type: {type(shape)}")
-
-            t = get_transformation(regions, self.target_coordinate_system)
+        # check unique matching between regions and images and coordinate systems
+        assert len(set(regions_to_images.values())) == len(
+            regions_to_images.keys()
+        ), "One region cannot be paired to multiple images."
+        assert len(set(regions_to_coordinate_systems.values())) == len(
+            regions_to_coordinate_systems.keys()
+        ), "One region cannot be paired to multiple coordinate systems."
+
+        for region_key, image_key in regions_to_images.items():
+            # get elements
+            region_elem = sdata[region_key]
+            image_elem = sdata[image_key]
+
+            # check that the elements are supported
+            if get_model(region_elem) in [Labels2DModel, Labels3DModel]:
+                raise NotImplementedError("labels elements are not implemented yet.")
+            if get_model(region_elem) not in [ShapesModel]:
+                raise ValueError("`regions_element` must be a shapes element.")
+            if get_model(image_elem) not in [Image2DModel, Image3DModel]:
+                raise ValueError("`images_element` must be an image element.")
+            if isinstance(image_elem, MultiscaleSpatialImage):
+                raise NotImplementedError("Multiscale images are not implemented yet.")
+
+            if region_key not in available_regions:
+                raise ValueError(f"region {region_key} not found in the spatialdata object.")
+
+            # check that the coordinate systems are valid for the elements
+            try:
+                cs = regions_to_coordinate_systems[region_key]
+                region_trans = get_transformation(region_elem, cs)
+                image_trans = get_transformation(image_elem, cs)
+                if isinstance(region_trans, BaseTransformation) and isinstance(image_trans, BaseTransformation):
+                    cs_region_image.append((cs, region_key, image_key))
+            except KeyError as e:
+                raise KeyError(f"region {region_key} not found in `regions_to_coordinate_systems`") from e
+
+        self.regions = list(regions_to_coordinate_systems.keys())  # all regions for the dataloader
+        self.sdata = sdata
+        self.dataset_table = self.sdata.tables["table"][
+            self.sdata.tables["table"].obs[self._region_key].isin(self.regions)
+        ]  # filtered table for the data loader
+        self._cs_region_image = tuple(cs_region_image)  # tuple of tuples (coordinate_system, region_key, image_key)
+
+    def _preprocess(
+        self,
+        tile_scale: float = 1.0,
+        tile_dim_in_units: float | None = None,
+    ) -> None:
+        """Preprocess the dataset."""
+        index_df = []
+        tile_coords_df = []
+        dims_l = []
+        shapes_l = []
+        table = self.sdata.tables["table"]
+        for cs, region, image in self._cs_region_image:
+            # get dims and transformations for the region element
+            dims = get_axes_names(self.sdata[region])
+            dims_l.append(dims)
+            t = get_transformation(self.sdata[region], cs)
             assert isinstance(t, BaseTransformation)
-            aff = t.to_affine_matrix(input_axes=dims, output_axes=dims)
-            transformed_centroid = np.squeeze(_affine_matrix_multiplication(aff, centroid), 0)
-        elif isinstance(regions, (SpatialImage, MultiscaleSpatialImage)):
-            raise NotImplementedError("labels not supported yet")
-        else:
-            raise ValueError("element must be shapes or labels")
-        min_coordinate = np.array(transformed_centroid) - self.tile_dim_in_units / 2
-        max_coordinate = np.array(transformed_centroid) + self.tile_dim_in_units / 2
-
-        raster = self.sdata[self.regions_to_images[regions_name]]
-        tile = rasterize(
-            raster,
-            axes=dims,
-            min_coordinate=min_coordinate,
-            max_coordinate=max_coordinate,
-            target_coordinate_system=self.target_coordinate_system,
-            target_width=self.tile_dim_in_pixels,
+
+            # get instances from region
+            inst = table.obs[table.obs[self._region_key] == region][self._instance_key].values
+
+            # subset the regions by instances
+            subset_region = self.sdata[region].iloc[inst]
+            # get coordinates of centroids and extent for tiles
+            tile_coords = _get_tile_coords(subset_region, t, dims, tile_scale, tile_dim_in_units)
+            tile_coords_df.append(tile_coords)
+
+            # get shapes
+            shapes_l.append(self.sdata[region])
+
+            # get index dictionary, with `instance_id`, `cs`, `region`, and `image`
+            df = pd.DataFrame({self.INSTANCE_KEY: inst})
+            df[self.CS_KEY] = cs
+            df[self.REGION_KEY] = region
+            df[self.IMAGE_KEY] = image
+            index_df.append(df)
+
+        # concatenate and assign to self
+        self.dataset_index = pd.concat(index_df).reset_index(drop=True)
+        self.tiles_coords = pd.concat(tile_coords_df).reset_index(drop=True)
+        # get table filtered by regions
+        self.filtered_table = table.obs[table.obs[self._region_key].isin(self.regions)]
+
+        assert len(self.tiles_coords) == len(self.dataset_index)
+        dims_ = set(chain(*dims_l))
+        assert np.all([i in self.tiles_coords for i in dims_])
+        self.dims = list(dims_)
+
+    def _get_return(
+        self,
+        return_annot: str | list[str] | None,
+    ) -> Callable[[int, Any], tuple[Any, Any] | SpatialData]:
+        """Get function to return values from the table of the dataset."""
+        if return_annot is not None:
+            # table is always returned as array shape (1, len(return_annot)),
+            # where return_annot can be a single column or a list of columns
+            return_annot = [return_annot] if isinstance(return_annot, str) else return_annot
+            # return tuple of (tile, table)
+            if np.all([i in self.dataset_table.obs for i in return_annot]):
+                return lambda x, tile: (tile, self.dataset_table.obs[return_annot].iloc[x].values.reshape(1, -1))
+            if np.all([i in self.dataset_table.var_names for i in return_annot]):
+                if issparse(self.dataset_table.X):
+                    return lambda x, tile: (tile, self.dataset_table[x, return_annot].X.A)
+                return lambda x, tile: (tile, self.dataset_table[x, return_annot].X)
+            raise ValueError(
+                f"`return_annot` must be a column in table.obs or a variable name in table.var. "
+                f"Got {return_annot}."
+            )
+        # return spatialdata consisting of the image tile and the associated table
+        return lambda x, tile: SpatialData(
+            images={self.dataset_index.iloc[x][self.IMAGE_KEY]: tile},
+            table=self.dataset_table[x],
         )
-        tile_regions = regions.iloc[region_index : region_index + 1]
-        # TODO: as explained in the TODO in the __init__(), we want to let the
-        # user also use the bounding box query instaed of the rasterization
-        # the return function of this function would change, so we need to
-        # decide if instead having an extra Tile dataset class
-        # from spatialdata._core._spatial_query import BoundingBoxRequest
-        # request = BoundingBoxRequest(
-        #     target_coordinate_system=self.target_coordinate_system,
-        #     axes=dims,
-        #     min_coordinate=min_coordinate,
-        #     max_coordinate=max_coordinate,
-        # )
-        # sdata_item = self.sdata.query.bounding_box(**request.to_dict())
-        table = self.sdata.table
-        filter_table = False
-        if table is not None:
-            region = table.uns["spatialdata_attrs"]["region"]
-            region_key = table.uns["spatialdata_attrs"]["region_key"]
-            instance_key = table.uns["spatialdata_attrs"]["instance_key"]
-            if isinstance(region, str):
-                if regions_name == region:
-                    filter_table = True
-            elif isinstance(region, list):
-                if regions_name in region:
-                    filter_table = True
-            else:
-                raise ValueError("region must be a string or a list of strings")
-        # TODO: maybe slow, we should check if there is a better way to do this
-        if filter_table:
-            instance = self.sdata[regions_name].iloc[region_index].name
-            row = table[(table.obs[region_key] == regions_name) & (table.obs[instance_key] == instance)].copy()
-            tile_table = row
-        else:
-            tile_table = None
-        tile_sdata = SpatialData(
-            images={self.regions_to_images[regions_name]: tile}, shapes={regions_name: tile_regions}, table=tile_table
+
+    def __len__(self) -> int:
+        return len(self.dataset_index)
+
+    def __getitem__(self, idx: int) -> Any | SpatialData:
+        """Get item from the dataset."""
+        # get the row from the index
+        row = self.dataset_index.iloc[idx]
+        # get the tile coordinates
+        t_coords = self.tiles_coords.iloc[idx]
+
+        image = self.sdata[row["image"]]
+        tile = self._crop_image(
+            image,
+            axes=self.dims,
+            min_coordinate=t_coords[[f"min{i}" for i in self.dims]].values,
+            max_coordinate=t_coords[[f"max{i}" for i in self.dims]].values,
+            target_coordinate_system=row["cs"],
         )
+
         if self.transform is not None:
-            return self.transform(tile_sdata)
-        return tile_sdata
+            out = self._return(idx, tile)
+            return self.transform(out)
+        return self._return(idx, tile)
+
+    @property
+    def regions(self) -> list[str]:
+        """List of regions in the dataset."""
+        return self._regions
+
+    @regions.setter
+    def regions(self, regions: list[str]) -> None:  # D102
+        self._regions = regions
+
+    @property
+    def sdata(self) -> SpatialData:
+        """The original SpatialData object."""
+        return self._sdata
+
+    @sdata.setter
+    def sdata(self, sdata: SpatialData) -> None:  # D102
+        self._sdata = sdata
+
+    @property
+    def coordinate_systems(self) -> list[str]:
+        """List of coordinate systems in the dataset."""
+        return self._coordinate_systems
+
+    @coordinate_systems.setter
+    def coordinate_systems(self, coordinate_systems: list[str]) -> None:  # D102
+        self._coordinate_systems = coordinate_systems
+
+    @property
+    def tiles_coords(self) -> pd.DataFrame:
+        """DataFrame with the index of tiles.
+
+        It contains axis coordinates of the centroids, and extent of the tiles.
+        For example, for a 2D image, it contains the following columns:
+
+            - `x`: the x coordinate of the centroid.
+            - `y`: the y coordinate of the centroid.
+            - `extent`: the extent of the tile.
+            - `minx`: the minimum x coordinate of the tile.
+            - `miny`: the minimum y coordinate of the tile.
+            - `maxx`: the maximum x coordinate of the tile.
+            - `maxy`: the maximum y coordinate of the tile.
+        """
+        return self._tiles_coords
+
+    @tiles_coords.setter
+    def tiles_coords(self, tiles: pd.DataFrame) -> None:
+        self._tiles_coords = tiles
+
+    @property
+    def dataset_index(self) -> pd.DataFrame:
+        """DataFrame with the metadata of the tiles.
+
+        It contains the following columns:
+
+            - `instance_id`: the name of the instance in the region.
+            - `cs`: the coordinate system of the region-image pair.
+            - `region`: the name of the region.
+            - `image`: the name of the image.
+        """
+        return self._dataset_index
+
+    @dataset_index.setter
+    def dataset_index(self, dataset_index: pd.DataFrame) -> None:
+        self._dataset_index = dataset_index
+
+    @property
+    def dataset_table(self) -> AnnData:
+        """AnnData table filtered by the `region` and `cs` present in the dataset."""
+        return self._dataset_table
+
+    @dataset_table.setter
+    def dataset_table(self, dataset_table: AnnData) -> None:
+        self._dataset_table = dataset_table
+
+    @property
+    def dims(self) -> list[str]:
+        """Dimensions of the dataset."""
+        return self._dims
+
+    @dims.setter
+    def dims(self, dims: list[str]) -> None:
+        self._dims = dims
+
+
+def _get_tile_coords(
+    elem: GeoDataFrame,
+    transformation: BaseTransformation,
+    dims: tuple[str, ...],
+    tile_scale: float | None = None,
+    tile_dim_in_units: float | None = None,
+) -> pd.DataFrame:
+    """Get the (transformed) centroid of the region and the extent."""
+    # get centroids and transform them
+    centroids = elem.centroid.get_coordinates().values
+    aff = transformation.to_affine_matrix(input_axes=dims, output_axes=dims)
+    centroids = _affine_matrix_multiplication(aff, centroids)
+
+    # get extent, first by checking shape defaults, then by using the `tile_dim_in_units`
+    if tile_dim_in_units is None:
+        if elem.iloc[0, 0].geom_type == "Point":
+            extent = elem[ShapesModel.RADIUS_KEY].values * tile_scale
+        elif elem.iloc[0, 0].geom_type in ["Polygon", "MultiPolygon"]:
+            extent = elem[ShapesModel.GEOMETRY_KEY].length * tile_scale
+        else:
+            raise ValueError("Only point and polygon shapes are supported.")
+    if tile_dim_in_units is not None:
+        if isinstance(tile_dim_in_units, (float, int)):
+            extent = np.repeat(tile_dim_in_units, len(centroids))
+        else:
+            raise TypeError(f"`tile_dim_in_units` must be a `float` or `int`, not {type(tile_dim_in_units)}.")
+    if len(extent) != len(centroids):
+        raise ValueError(
+            f"the number of elements in the region ({len(extent)}) does not match"
+            f" the number of instances ({len(centroids)})."
+        )
+
+    # transform extent
+    aff = transformation.to_affine_matrix(input_axes=tuple(dims[0]), output_axes=tuple(dims[0]))
+    extent = _affine_matrix_multiplication(aff, np.array(extent)[:, np.newaxis])
+
+    # get min and max coordinates
+    min_coordinates = np.array(centroids) - extent / 2
+    max_coordinates = np.array(centroids) + extent / 2
+
+    # return a dataframe with columns e.g. ["x", "y", "extent", "minx", "miny", "maxx", "maxy"]
+    return pd.DataFrame(
+        np.hstack([centroids, extent, min_coordinates, max_coordinates]),
+        columns=list(dims) + ["extent"] + ["min" + dim for dim in dims] + ["max" + dim for dim in dims],
+    )
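Putting the new dataset together, a hedged usage sketch: it assumes a SpatialData object `sdata` whose default table annotates a shapes element `"cells"`, with an image `"image"` available in the `"global"` coordinate system; the `"cell_type"` column is hypothetical:

```python
from spatialdata.dataloader import ImageTilesDataset

dataset = ImageTilesDataset(
    sdata,
    regions_to_images={"cells": "image"},
    regions_to_coordinate_systems={"cells": "global"},
    tile_scale=2.0,                  # tile extent = 2x the spot radius
    return_annotations="cell_type",  # returns (tile, value) tuples
)
tile, annotation = dataset[0]  # one tile per instance in the table
```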
["x", "y", "extent", "minx", "miny", "maxx", "maxy"] + return pd.DataFrame( + np.hstack([centroids, extent, min_coordinates, max_coordinates]), + columns=list(dims) + ["extent"] + ["min" + dim for dim in dims] + ["max" + dim for dim in dims], + ) diff --git a/src/spatialdata/datasets.py b/src/spatialdata/datasets.py index 0d811ad3..3b207e7b 100644 --- a/src/spatialdata/datasets.py +++ b/src/spatialdata/datasets.py @@ -13,8 +13,8 @@ from skimage.segmentation import slic from spatial_image import SpatialImage -from spatialdata import SpatialData from spatialdata._core.operations.aggregate import aggregate +from spatialdata._core.spatialdata import SpatialData from spatialdata._logging import logger from spatialdata._types import ArrayLike from spatialdata.models import ( @@ -153,7 +153,7 @@ def blobs( circles = self._circles_blobs(self.transformations, self.length, self.n_shapes) polygons = self._polygons_blobs(self.transformations, self.length, self.n_shapes) multipolygons = self._polygons_blobs(self.transformations, self.length, self.n_shapes, multipolygons=True) - adata = aggregate(values=image, by=labels).table + adata = aggregate(values=image, by=labels).tables["table"] adata.obs["region"] = pd.Categorical(["blobs_labels"] * len(adata)) adata.obs["instance_id"] = adata.obs_names.astype(int) del adata.uns[TableModel.ATTRS_KEY] @@ -164,7 +164,7 @@ def blobs( labels={"blobs_labels": labels, "blobs_multiscale_labels": multiscale_labels}, points={"blobs_points": points}, shapes={"blobs_circles": circles, "blobs_polygons": polygons, "blobs_multipolygons": multipolygons}, - table=table, + tables=table, ) def _image_blobs( diff --git a/src/spatialdata/models/__init__.py b/src/spatialdata/models/__init__.py index 9a6cf64b..df370e4a 100644 --- a/src/spatialdata/models/__init__.py +++ b/src/spatialdata/models/__init__.py @@ -21,7 +21,9 @@ PointsModel, ShapesModel, TableModel, + check_target_region_column_symmetry, get_model, + get_table_keys, ) __all__ = [ @@ -44,4 +46,6 @@ "get_axes_names", "points_geopandas_to_dask_dataframe", "points_dask_dataframe_to_geopandas", + "check_target_region_column_symmetry", + "get_table_keys", ] diff --git a/src/spatialdata/models/_utils.py b/src/spatialdata/models/_utils.py index cf139d5e..d78dc8b9 100644 --- a/src/spatialdata/models/_utils.py +++ b/src/spatialdata/models/_utils.py @@ -16,7 +16,6 @@ SpatialElement = Union[SpatialImage, MultiscaleSpatialImage, GeoDataFrame, DaskDataFrame] TRANSFORM_KEY = "transform" DEFAULT_COORDINATE_SYSTEM = "global" -# ValidAxis_t = Literal["c", "x", "y", "z"] ValidAxis_t = str MappingToCoordinateSystem_t = dict[str, BaseTransformation] C = "c" diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index f7155a3b..f36d91e9 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -18,7 +18,7 @@ from multiscale_spatial_image import to_multiscale from multiscale_spatial_image.multiscale_spatial_image import MultiscaleSpatialImage from multiscale_spatial_image.to_multiscale.to_multiscale import Methods -from pandas.api.types import is_categorical_dtype +from pandas import CategoricalDtype from shapely._geometry import GeometryType from shapely.geometry import MultiPolygon, Point, Polygon from shapely.geometry.collection import GeometryCollection @@ -465,12 +465,13 @@ def validate(cls, data: DaskDataFrame) -> None: """ for ax in [X, Y, Z]: if ax in data.columns: - assert data[ax].dtype in [np.float32, np.float64, np.int64] + # TODO: check why this can return int32 on windows. 
+                assert data[ax].dtype in [np.int32, np.float32, np.float64, np.int64]
         if cls.TRANSFORM_KEY not in data.attrs:
             raise ValueError(f":attr:`dask.dataframe.core.DataFrame.attrs` does not contain `{cls.TRANSFORM_KEY}`.")
         if cls.ATTRS_KEY in data.attrs and "feature_key" in data.attrs[cls.ATTRS_KEY]:
             feature_key = data.attrs[cls.ATTRS_KEY][cls.FEATURE_KEY]
-            if not is_categorical_dtype(data[feature_key]):
+            if not isinstance(data[feature_key].dtype, CategoricalDtype):
                 logger.info(f"Feature key `{feature_key}`could be of type `pd.Categorical`. Consider casting it.")

     @singledispatchmethod
@@ -624,7 +625,7 @@ def _add_metadata_and_validate(
         # Here we are explicitly importing the categories
         # but it is a convenient way to ensure that the categories are known.
         # It also just changes the state of the series, so it is not a big deal.
-        if is_categorical_dtype(data[c]) and not data[c].cat.known:
+        if isinstance(data[c].dtype, CategoricalDtype) and not data[c].cat.known:
             try:
                 data[c] = data[c].cat.set_categories(data[c].head(1).cat.categories)
             except ValueError:
@@ -642,24 +643,121 @@ class TableModel:
     REGION_KEY_KEY = "region_key"
     INSTANCE_KEY = "instance_key"

-    def validate(
-        self,
-        data: AnnData,
-    ) -> AnnData:
+    def _validate_set_region_key(self, data: AnnData, region_key: str | None = None) -> None:
         """
-        Validate the data.
+        Validate the region key in table.uns or set a new region key as the region key column.

         Parameters
         ----------
         data
-            The data to validate.
+            The AnnData table.
+        region_key
+            The region key to be validated and set in table.uns.
+
+        Raises
+        ------
+        ValueError
+            If no region_key is found in table.uns and no region_key is provided as an argument.
+        ValueError
+            If the region_key specified in table.uns is not present as a column in table.obs.
+        ValueError
+            If the region_key provided as an argument is not present as a column in table.obs.
+        """
+        attrs = data.uns.get(self.ATTRS_KEY)
+        if attrs is None:
+            data.uns[self.ATTRS_KEY] = attrs = {}
+        table_region_key = attrs.get(self.REGION_KEY_KEY)
+        if not region_key:
+            if not table_region_key:
+                raise ValueError(
+                    "No region_key in table.uns and no region_key provided as argument. Please specify 'region_key'."
+                )
+            if data.obs.get(attrs[TableModel.REGION_KEY_KEY]) is None:
+                raise ValueError(
+                    f"Specified region_key in table.uns '{table_region_key}' is not "
+                    f"present as column in table.obs. Please specify region_key."
+                )
+        else:
+            if region_key not in data.obs:
+                raise ValueError(f"'{region_key}' column not present in table.obs")
+            attrs[self.REGION_KEY_KEY] = region_key
+
+    def _validate_set_instance_key(self, data: AnnData, instance_key: str | None = None) -> None:
+        """
+        Validate the instance_key in table.uns or set a new instance_key as the instance_key column.
+
+        If no instance_key is provided as an argument, the presence of instance_key in table.uns is checked and
+        validated. If instance_key is provided, its presence in table.obs is validated and, if present, it is set
+        as the new instance_key in table.uns.
+
+        Parameters
+        ----------
+        data
+            The AnnData table.
+        instance_key
+            The instance_key to be validated and set in table.uns.
+
+        Raises
+        ------
+        ValueError
+            If no instance_key is provided as argument and no instance_key is found in the `uns` attribute of table.
+        ValueError
+            If no instance_key is provided and the instance_key in table.uns does not match any column in table.obs.
+        ValueError
+            If the provided instance_key is not present as a column in table.obs.
+ """ + attrs = data.uns.get(self.ATTRS_KEY) + if attrs is None: + data.uns[self.ATTRS_KEY] = {} + + if not instance_key: + if not attrs.get(TableModel.INSTANCE_KEY): + raise ValueError( + "No instance_key in table.uns and no instance_key provided as argument. Please " + "specify instance_key." + ) + if data.obs.get(attrs[self.INSTANCE_KEY]) is None: + raise ValueError( + f"Specified instance_key in table.uns '{attrs.get(self.INSTANCE_KEY)}' is not present" + f" as column in table.obs. Please specify instance_key." + ) + if instance_key: + if instance_key in data.obs: + attrs[self.INSTANCE_KEY] = instance_key + else: + raise ValueError(f"Instance key column '{instance_key}' not found in table.obs.") + + def _validate_table_annotation_metadata(self, data: AnnData) -> None: + """ + Validate annotation metadata. + + Parameters + ---------- + data + The AnnData object containing the table annotation data. + + Raises + ------ + ValueError + If any of the required metadata keys are not found in the `adata.uns` dictionary or the `adata.obs` + dataframe. + + - If "region" is not found in `adata.uns['ATTRS_KEY']`. + - If "region_key" is not found in `adata.uns['ATTRS_KEY']`. + - If "instance_key" is not found in `adata.uns['ATTRS_KEY']`. + - If `attr[self.REGION_KEY_KEY]` is not found in `adata.obs`, with attr = adata.uns['ATTRS_KEY'] + - If `attr[self.INSTANCE_KEY]` is not found in `adata.obs`. + - If the regions in `adata.uns['ATTRS_KEY']['self.REGION_KEY']` and the unique values of + `attr[self.REGION_KEY_KEY]` do not match. + + Notes + ----- + This does not check whether the annotation target of the table is present in a given SpatialData object. Rather + it is an internal validation of the annotation metadata of the table. - Returns - ------- - The validated data. """ - if self.ATTRS_KEY not in data.uns: - raise ValueError(f"`{self.ATTRS_KEY}` not found in `adata.uns`.") attr = data.uns[self.ATTRS_KEY] if "region" not in attr: @@ -678,6 +776,27 @@ def validate( if len(set(expected_regions).symmetric_difference(set(found_regions))) > 0: raise ValueError(f"Regions in the AnnData object and `{attr[self.REGION_KEY_KEY]}` do not match.") + def validate( + self, + data: AnnData, + ) -> AnnData: + """ + Validate the data. + + Parameters + ---------- + data + The data to validate. + + Returns + ------- + The validated data. + """ + if self.ATTRS_KEY not in data.uns: + return data + + self._validate_table_annotation_metadata(data) + return data @classmethod @@ -704,15 +823,17 @@ def parse( Returns ------- - :class:`anndata.AnnData`. + The parsed data. """ # either all live in adata.uns or all be passed in as argument n_args = sum([region is not None, region_key is not None, instance_key is not None]) + if n_args == 0: + return adata if n_args > 0: if cls.ATTRS_KEY in adata.uns: raise ValueError( - f"Either pass `{cls.REGION_KEY}`, `{cls.REGION_KEY_KEY}` and `{cls.INSTANCE_KEY}`" - f"as arguments or have them in `adata.uns[{cls.ATTRS_KEY!r}]`." + f"`{cls.REGION_KEY}`, `{cls.REGION_KEY_KEY}` and / or `{cls.INSTANCE_KEY}` is/has been passed as" + f"as argument(s). However, `adata.uns[{cls.ATTRS_KEY!r}]` has already been set." 
                )
        elif cls.ATTRS_KEY in adata.uns:
            attr = adata.uns[cls.ATTRS_KEY]
@@ -729,7 +850,7 @@ def parse(
         region_: list[str] = region if isinstance(region, list) else [region]
         if not adata.obs[region_key].isin(region_).all():
             raise ValueError(f"`adata.obs[{region_key}]` values do not match with `{cls.REGION_KEY}` values.")
-        if not is_categorical_dtype(adata.obs[region_key]):
+        if not isinstance(adata.obs[region_key].dtype, CategoricalDtype):
             warnings.warn(
                 f"Converting `{cls.REGION_KEY_KEY}: {region_key}` to categorical dtype.", UserWarning, stacklevel=2
             )
@@ -792,3 +913,73 @@ def _validate_and_return(
     if isinstance(e, AnnData):
         return _validate_and_return(TableModel, e)
     raise TypeError(f"Unsupported type {type(e)}")
+
+
+def get_table_keys(table: AnnData) -> tuple[str | list[str], str, str]:
+    """
+    Get the table keys giving information about what spatial element is annotated.
+
+    The first value returned gives the region(s), i.e. which spatial element(s) are annotated by the table. The
+    second value gives the column in table.obs that states, for each row, which spatial element that row annotates.
+    The third value gives the column in table.obs containing the instance id of each row.
+
+    Parameters
+    ----------
+    table
+        AnnData table for which to retrieve the spatialdata_attrs keys.
+
+    Returns
+    -------
+    The keys in table.uns['spatialdata_attrs'].
+    """
+    if table.uns.get(TableModel.ATTRS_KEY):
+        attrs = table.uns[TableModel.ATTRS_KEY]
+        return attrs[TableModel.REGION_KEY], attrs[TableModel.REGION_KEY_KEY], attrs[TableModel.INSTANCE_KEY]
+
+    raise ValueError(
+        "No spatialdata_attrs key found in table.uns, therefore, no table keys found. Please parse the table."
+    )
+
+
+def check_target_region_column_symmetry(table: AnnData, region_key: str, target: str | pd.Series) -> None:
+    """
+    Check region and region_key column symmetry.
+
+    This checks whether the specified targets are also present in the region_key column in obs and raises an error
+    if this is not the case.
+
+    Parameters
+    ----------
+    table
+        Table annotating specific SpatialElements.
+    region_key
+        The column in obs containing for each row which SpatialElement is annotated by that row.
+    target
+        Name of the target SpatialElement(s).
+
+    Raises
+    ------
+    ValueError
+        If there is a mismatch between specified target regions and regions in the region_key column of table.obs.
+
+    Example
+    -------
+    Assuming we have a table with a region column in obs given by `region_key` called 'region', for which we want to
+    check whether it contains the specified annotation targets in the `target` variable as
+    `pd.Series['region1', 'region2']`:
+
+    ```python
+    check_target_region_column_symmetry(table, region_key=region_key, target=target)
+    ```
+
+    If both specified targets are present in the region_key column of obs, this returns None and the annotation
+    targets can be safely set; otherwise a ValueError is raised listing the elements that are not shared between
+    the region_key column in obs and the specified targets.
+    """
+    found_regions = set(table.obs[region_key].unique().tolist())
+    target_element_set = [target] if isinstance(target, str) else target
+    symmetric_difference = found_regions.symmetric_difference(target_element_set)
+    if symmetric_difference:
+        raise ValueError(
+            f"Mismatch(es) found between regions in region column in obs and target element: "
+            f"{', '.join(symmetric_difference)}"
+        )
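To make the new table helpers concrete, a hedged sketch of parsing a table and reading its keys back (the element name and columns are illustrative):

```python
import numpy as np
import pandas as pd
from anndata import AnnData
from spatialdata.models import TableModel, get_table_keys

adata = AnnData(
    np.random.rand(5, 2),
    obs=pd.DataFrame({"region": ["labels2d"] * 5, "instance_id": range(5)}),
)
table = TableModel.parse(adata, region="labels2d", region_key="region", instance_key="instance_id")
region, region_key, instance_key = get_table_keys(table)
print(region, region_key, instance_key)  # labels2d region instance_id
```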
+ """ + found_regions = set(table.obs[region_key].unique().tolist()) + target_element_set = [target] if isinstance(target, str) else target + symmetric_difference = found_regions.symmetric_difference(target_element_set) + if symmetric_difference: + raise ValueError( + f"Mismatch(es) found between regions in region column in obs and target element: " + f"{', '.join(diff for diff in symmetric_difference)}" + ) diff --git a/src/spatialdata/transformations/operations.py b/src/spatialdata/transformations/operations.py index 165551c4..2e4a9e1f 100644 --- a/src/spatialdata/transformations/operations.py +++ b/src/spatialdata/transformations/operations.py @@ -15,7 +15,7 @@ ) if TYPE_CHECKING: - from spatialdata import SpatialData + from spatialdata._core.spatialdata import SpatialData from spatialdata.models import SpatialElement from spatialdata.transformations import Affine, BaseTransformation @@ -68,8 +68,8 @@ def set_transformation( assert to_coordinate_system is None _set_transformations(element, transformation) else: - if not write_to_sdata.contains_element(element, raise_exception=True): - raise RuntimeError("contains_element() failed without raising an exception.") + if write_to_sdata.locate_element(element) is None: + raise RuntimeError("The element is not found in the SpatialData object.") if not write_to_sdata.is_backed(): raise ValueError( "The SpatialData object is not backed. You can either set a transformation to an element " @@ -164,8 +164,8 @@ def remove_transformation( assert to_coordinate_system is None _set_transformations(element, {}) else: - if not write_to_sdata.contains_element(element, raise_exception=True): - raise RuntimeError("contains_element() failed without raising an exception.") + if write_to_sdata.locate_element(element) is None: + raise RuntimeError("The element is not found in the SpatialData object.") if not write_to_sdata.is_backed(): raise ValueError( "The SpatialData object is not backed. 
diff --git a/tests/conftest.py b/tests/conftest.py
index 490cd929..3fcfe005 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -21,7 +21,7 @@
 from numpy.random import default_rng
 from shapely.geometry import MultiPolygon, Point, Polygon
 from spatial_image import SpatialImage
-from spatialdata import SpatialData
+from spatialdata._core.spatialdata import SpatialData
 from spatialdata.models import (
     Image2DModel,
     Image3DModel,
@@ -66,12 +66,12 @@ def points() -> SpatialData:

 @pytest.fixture()
 def table_single_annotation() -> SpatialData:
-    return SpatialData(table=_get_table(region="sample1"))
+    return SpatialData(tables=_get_table(region="labels2d"))


 @pytest.fixture()
 def table_multiple_annotations() -> SpatialData:
-    return SpatialData(table=_get_table(region=["sample1", "sample2"]))
+    return SpatialData(table=_get_table(region=["labels2d", "poly"]))


 @pytest.fixture()
@@ -93,7 +93,7 @@ def full_sdata() -> SpatialData:
         labels=_get_labels(),
         shapes=_get_shapes(),
         points=_get_points(),
-        table=_get_table(region="sample1"),
+        tables=_get_table(region="labels2d"),
     )

@@ -128,7 +128,7 @@ def sdata(request) -> SpatialData:
             labels=_get_labels(),
             shapes=_get_shapes(),
             points=_get_points(),
-            table=_get_table("sample1"),
+            tables=_get_table("labels2d"),
         )
     if request.param == "empty":
         return SpatialData()
@@ -141,7 +141,10 @@ def _get_images() -> dict[str, SpatialImage | MultiscaleSpatialImage]:
     dims_3d = ("z", "y", "x", "c")
     out["image2d"] = Image2DModel.parse(RNG.normal(size=(3, 64, 64)), dims=dims_2d, c_coords=["r", "g", "b"])
     out["image2d_multiscale"] = Image2DModel.parse(
-        RNG.normal(size=(3, 64, 64)), scale_factors=[2, 2], dims=dims_2d, c_coords=["r", "g", "b"]
+        RNG.normal(size=(3, 64, 64)),
+        scale_factors=[2, 2],
+        dims=dims_2d,
+        c_coords=["r", "g", "b"],
     )
     out["image2d_xarray"] = Image2DModel.parse(DataArray(RNG.normal(size=(3, 64, 64)), dims=dims_2d), dims=None)
     out["image2d_multiscale_xarray"] = Image2DModel.parse(
@@ -277,11 +280,13 @@ def _get_points() -> dict[str, DaskDataFrame]:

 def _get_table(
-    region: str | list[str] = "sample1",
-    region_key: str = "region",
-    instance_key: str = "instance_id",
+    region: None | str | list[str] = "sample1",
+    region_key: None | str = "region",
+    instance_key: None | str = "instance_id",
 ) -> AnnData:
     adata = AnnData(RNG.normal(size=(100, 10)), obs=pd.DataFrame(RNG.normal(size=(100, 3)), columns=["a", "b", "c"]))
+    if not all((region, region_key, instance_key)):
+        return TableModel.parse(adata=adata)
     adata.obs[instance_key] = np.arange(adata.n_obs)
     if isinstance(region, str):
         adata.obs[region_key] = region
diff --git a/tests/core/operations/test_aggregations.py b/tests/core/operations/test_aggregations.py
index 51ffa98e..36609464 100644
--- a/tests/core/operations/test_aggregations.py
+++ b/tests/core/operations/test_aggregations.py
@@ -9,8 +9,9 @@
 from anndata.tests.helpers import assert_equal
 from geopandas import GeoDataFrame
 from numpy.random import default_rng
-from spatialdata import SpatialData, aggregate
+from spatialdata import aggregate
 from spatialdata._core.query._utils import circles_to_polygons
+from spatialdata._core.spatialdata import SpatialData
 from spatialdata._utils import _deepcopy_geodataframe
 from spatialdata.models import Image2DModel, Labels2DModel, PointsModel, TableModel
 from spatialdata.transformations import Affine, Identity, set_transformation
diff --git a/tests/core/operations/test_spatialdata_operations.py b/tests/core/operations/test_spatialdata_operations.py
index c551461c..3b2b3e6a 100644
--- a/tests/core/operations/test_spatialdata_operations.py
+++ b/tests/core/operations/test_spatialdata_operations.py
@@ -1,5 +1,4 @@
-import tempfile
-from pathlib import Path
+from __future__ import annotations

 import numpy as np
 import pytest
@@ -9,8 +8,8 @@
 from geopandas import GeoDataFrame
 from multiscale_spatial_image import MultiscaleSpatialImage
 from spatial_image import SpatialImage
-from spatialdata import SpatialData
 from spatialdata._core.concatenate import _concatenate_tables, concatenate
+from spatialdata._core.spatialdata import SpatialData
 from spatialdata.datasets import blobs
 from spatialdata.models import (
     Image2DModel,
@@ -25,34 +24,88 @@
 from tests.conftest import _get_table

-def test_element_names_unique():
+def test_element_names_unique() -> None:
     shapes = ShapesModel.parse(np.array([[0, 0]]), geometry=0, radius=1)
     points = PointsModel.parse(np.array([[0, 0]]))
     labels = Labels2DModel.parse(np.array([[0, 0], [0, 0]]), dims=["y", "x"])
     image = Image2DModel.parse(np.array([[[0, 0], [0, 0]]]), dims=["c", "y", "x"])

-    with pytest.raises(ValueError):
+    with pytest.raises(KeyError):
         SpatialData(images={"image": image}, points={"image": points})
-    with pytest.raises(ValueError):
+    with pytest.raises(KeyError):
         SpatialData(images={"image": image}, shapes={"image": shapes})
-    with pytest.raises(ValueError):
+    with pytest.raises(KeyError):
         SpatialData(images={"image": image}, labels={"image": labels})

     sdata = SpatialData(
         images={"image": image}, points={"points": points}, shapes={"shapes": shapes}, labels={"labels": labels}
     )

-    with pytest.raises(ValueError):
-        sdata.add_image(name="points", image=image)
-    with pytest.raises(ValueError):
-        sdata.add_points(name="image", points=points)
-    with pytest.raises(ValueError):
-        sdata.add_shapes(name="image", shapes=shapes)
-    with pytest.raises(ValueError):
-        sdata.add_labels(name="image", labels=labels)
+    # add elements with the same name
+    # of element of same type
+    with pytest.warns(UserWarning):
+        sdata.images["image"] = image
+    with pytest.warns(UserWarning):
+        sdata.points["points"] = points
+    with pytest.warns(UserWarning):
+        sdata.shapes["shapes"] = shapes
+    with pytest.warns(UserWarning):
+        sdata.labels["labels"] = labels
+
+    # add elements with the same name
+    # of element of different type
+    with pytest.raises(KeyError):
+        sdata.images["points"] = image
+    with pytest.raises(KeyError):
+        sdata.images["shapes"] = image
+    with pytest.raises(KeyError):
+        sdata.labels["points"] = labels
+    with pytest.raises(KeyError):
+        sdata.points["shapes"] = points
+    with pytest.raises(KeyError):
+        sdata.shapes["labels"] = shapes
+
+    assert sdata["image"].shape == image.shape
+    assert sdata["labels"].shape == labels.shape
+    assert len(sdata["points"]) == len(points)
+    assert sdata["shapes"].shape == shapes.shape
+
+    # add elements with the same name, test only couples of elements
+    with pytest.raises(KeyError):
+        sdata["labels"] = image
+    with pytest.warns(UserWarning):
+        sdata["points"] = points

+    # this should not raise warnings because it's a different (new) name
+    sdata["image2"] = image

-def _assert_elements_left_to_right_seem_identical(sdata0: SpatialData, sdata1: SpatialData):
+    # test replacing a complete attribute
+    sdata = SpatialData(
+        images={"image": image}, points={"points": points}, shapes={"shapes": shapes}, labels={"labels": labels}
+    )
+    # test for images
+    sdata.images = {"image2": image}
+    assert set(sdata.images.keys()) == {"image2"}
+    assert "image2" in sdata._shared_keys
+    assert "image" not in sdata._shared_keys
+    # test for labels
+    sdata.labels = {"labels2": labels}
+    assert set(sdata.labels.keys()) == {"labels2"}
+    assert "labels2" in sdata._shared_keys
+    assert "labels" not in sdata._shared_keys
+    # test for points
+    sdata.points = {"points2": points}
+    assert set(sdata.points.keys()) == {"points2"}
+    assert "points2" in sdata._shared_keys
+    assert "points" not in sdata._shared_keys
+    # test for shapes
+    sdata.shapes = {"shapes2": shapes}
+    assert set(sdata.shapes.keys()) == {"shapes2"}
+    assert "shapes2" in sdata._shared_keys
+    assert "shapes" not in sdata._shared_keys
+
+
+def _assert_elements_left_to_right_seem_identical(sdata0: SpatialData, sdata1: SpatialData) -> None:
     for element_type, element_name, element in sdata0._gen_elements():
         elements = sdata1.__getattribute__(element_type)
         assert element_name in elements
@@ -72,11 +125,11 @@ def _assert_elements_left_to_right_seem_identical(sdata0: SpatialData, sdata1: S
         raise TypeError(f"Unsupported type {type(element)}")

-def _assert_tables_seem_identical(table0: AnnData, table1: AnnData):
-    assert table0.shape == table1.shape
+def _assert_tables_seem_identical(table0: AnnData | None, table1: AnnData | None) -> None:
+    assert (table0 is None and table1 is None) or table0.shape == table1.shape

-def _assert_spatialdata_objects_seem_identical(sdata0: SpatialData, sdata1: SpatialData):
+def _assert_spatialdata_objects_seem_identical(sdata0: SpatialData, sdata1: SpatialData) -> None:
     # this is not a full comparison, but it's fine anyway
     assert len(list(sdata0._gen_elements())) == len(list(sdata1._gen_elements()))
     assert set(sdata0.coordinate_systems) == set(sdata1.coordinate_systems)
@@ -85,7 +138,7 @@
     _assert_tables_seem_identical(sdata0.table, sdata1.table)

-def test_filter_by_coordinate_system(full_sdata):
+def test_filter_by_coordinate_system(full_sdata: SpatialData) -> None:
     sdata = full_sdata.filter_by_coordinate_system(coordinate_system="global", filter_table=False)
     _assert_spatialdata_objects_seem_identical(sdata, full_sdata)

@@ -95,16 +148,16 @@
     set_transformation(full_sdata.shapes["poly"], Identity(), "my_space1")

     sdata_my_space = full_sdata.filter_by_coordinate_system(coordinate_system="my_space0", filter_table=False)
-    assert len(list(sdata_my_space._gen_elements())) == 2
+    assert len(list(sdata_my_space.gen_elements())) == 3
     _assert_tables_seem_identical(sdata_my_space.table, full_sdata.table)

     sdata_my_space1 = full_sdata.filter_by_coordinate_system(
         coordinate_system=["my_space0", "my_space1", "my_space2"], filter_table=False
     )
-    assert len(list(sdata_my_space1._gen_elements())) == 3
+    assert len(list(sdata_my_space1.gen_elements())) == 4

-def test_filter_by_coordinate_system_also_table(full_sdata):
+def test_filter_by_coordinate_system_also_table(full_sdata: SpatialData) -> None:
     from spatialdata.models import TableModel

     rng = np.random.default_rng(seed=0)
@@ -128,7 +181,7 @@
     assert len(filtered_sdata2.table) == len(full_sdata.table)

-def test_rename_coordinate_systems(full_sdata):
+def test_rename_coordinate_systems(full_sdata: SpatialData) -> None:
     # all the elements point to global, add new coordinate systems
     set_transformation(
         element=full_sdata.shapes["circles"], transformation=Identity(), to_coordinate_system="my_space0"
@@ -181,7 +234,7 @@ def test_rename_coordinate_systems(full_sdata: SpatialData) -> None:
     assert elements_in_global_before == elements_in_global_after

-def test_concatenate_tables():
+def test_concatenate_tables() -> None:
     """
     The concatenation uses AnnData.concatenate(), here we test the
     concatenation result on region, region_key, instance_key
@@ -226,7 +279,7 @@ def test_concatenate_tables():
     )

-def test_concatenate_sdatas(full_sdata):
+def test_concatenate_sdatas(full_sdata: SpatialData) -> None:
     with pytest.raises(KeyError):
         concatenate([full_sdata, SpatialData(images={"image2d": full_sdata.images["image2d"]})])
     with pytest.raises(KeyError):
@@ -241,7 +294,7 @@ def test_concatenate_sdatas(full_sdata):
     set_transformation(full_sdata.shapes["circles"], Identity(), "my_space0")
     set_transformation(full_sdata.shapes["poly"], Identity(), "my_space1")
     filtered = full_sdata.filter_by_coordinate_system(coordinate_system=["my_space0", "my_space1"], filter_table=False)
-    assert len(list(filtered._gen_elements())) == 2
+    assert len(list(filtered.gen_elements())) == 3
     filtered0 = filtered.filter_by_coordinate_system(coordinate_system="my_space0", filter_table=False)
     filtered1 = filtered.filter_by_coordinate_system(coordinate_system="my_space1", filter_table=False)
     # this is needed cause we can't handle regions with same name.
@@ -252,23 +305,22 @@ def test_concatenate_sdatas(full_sdata):
     filtered1.table = table_new
     filtered1.table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY] = new_region
     filtered1.table.obs[filtered1.table.uns[TableModel.ATTRS_KEY][TableModel.REGION_KEY_KEY]] = new_region
-    concatenated = concatenate([filtered0, filtered1])
-    assert len(list(concatenated._gen_elements())) == 2
+    concatenated = concatenate([filtered0, filtered1], concatenate_tables=True)
+    assert len(list(concatenated.gen_elements())) == 3

-def test_locate_spatial_element(full_sdata):
-    assert full_sdata._locate_spatial_element(full_sdata.images["image2d"]) == ("image2d", "images")
+def test_locate_spatial_element(full_sdata: SpatialData) -> None:
+    assert full_sdata.locate_element(full_sdata.images["image2d"])[0] == "images/image2d"
     im = full_sdata.images["image2d"]
     del full_sdata.images["image2d"]
-    with pytest.raises(ValueError, match="Element not found in the SpatialData object."):
-        full_sdata._locate_spatial_element(im)
+    assert full_sdata.locate_element(im) is None
     full_sdata.images["image2d"] = im
     full_sdata.images["image2d_again"] = im
-    with pytest.raises(ValueError):
-        full_sdata._locate_spatial_element(im)
+    paths = full_sdata.locate_element(im)
+    assert len(paths) == 2

-def test_get_item(points):
+def test_get_item(points: SpatialData) -> None:
     assert id(points["points_0"]) == id(points.points["points_0"])

     # removed this test after this change: https://github.com/scverse/spatialdata/pull/145#discussion_r1133122720
@@ -282,20 +334,14 @@ def test_get_item(points):
         _ = points["not_present"]

-def test_set_item(full_sdata):
+def test_set_item(full_sdata: SpatialData) -> None:
     for name in ["image2d", "labels2d", "points_0", "circles", "poly"]:
         full_sdata[name + "_again"] = full_sdata[name]
-        with pytest.raises(KeyError):
+        with pytest.warns(UserWarning):
             full_sdata[name] = full_sdata[name]
-    with tempfile.TemporaryDirectory() as tmpdir:
-        full_sdata.write(Path(tmpdir) / "test.zarr")
for name in ["image2d", "labels2d", "points_0"]: - # trying to overwrite the file used for backing (only for images, labels and points) - with pytest.raises(ValueError): - full_sdata[name] = full_sdata[name] -def test_no_shared_transformations(): +def test_no_shared_transformations() -> None: """Test transformation dictionary copy for transformations not to be shared.""" sdata = blobs() element_name = "blobs_image" @@ -303,15 +349,41 @@ def test_no_shared_transformations(): set_transformation(sdata.images[element_name], Identity(), to_coordinate_system=test_space) gen = sdata._gen_elements() - for _, name, obj in gen: - if name != element_name: - assert test_space not in get_transformation(obj, get_all=True) - else: - assert test_space in get_transformation(obj, get_all=True) + for element_type, name, obj in gen: + if element_type != "tables": + if name != element_name: + assert test_space not in get_transformation(obj, get_all=True) + else: + assert test_space in get_transformation(obj, get_all=True) -def test_init_from_elements(full_sdata): +def test_init_from_elements(full_sdata: SpatialData) -> None: all_elements = {name: el for _, name, el in full_sdata._gen_elements()} sdata = SpatialData.init_from_elements(all_elements, table=full_sdata.table) for element_type in ["images", "labels", "points", "shapes"]: assert set(getattr(sdata, element_type).keys()) == set(getattr(full_sdata, element_type).keys()) + + +def test_subset(full_sdata: SpatialData) -> None: + element_names = ["image2d", "points_0", "circles", "poly"] + subset0 = full_sdata.subset(element_names) + unique_names = set() + for _, k, _ in subset0.gen_spatial_elements(): + unique_names.add(k) + assert "image3d_xarray" in full_sdata.images + assert unique_names == set(element_names) + assert subset0.table is None + + adata = AnnData( + shape=(10, 0), + obs={"region": ["circles"] * 5 + ["poly"] * 5, "instance_id": [0, 1, 2, 3, 4, "a", "b", "c", "d", "e"]}, + ) + del full_sdata.table + sdata_table = TableModel.parse(adata, region=["circles", "poly"], region_key="region", instance_key="instance_id") + full_sdata.table = sdata_table + full_sdata.tables["second_table"] = sdata_table + subset1 = full_sdata.subset(["poly", "second_table"]) + assert subset1.table is not None + assert len(subset1.table) == 5 + assert subset1.table.obs["region"].unique().tolist() == ["poly"] + assert len(subset1["second_table"]) == 10 diff --git a/tests/core/operations/test_transform.py b/tests/core/operations/test_transform.py index f28b345f..5c11083d 100644 --- a/tests/core/operations/test_transform.py +++ b/tests/core/operations/test_transform.py @@ -8,7 +8,8 @@ from geopandas.testing import geom_almost_equals from multiscale_spatial_image import MultiscaleSpatialImage from spatial_image import SpatialImage -from spatialdata import SpatialData, transform +from spatialdata import transform +from spatialdata._core.spatialdata import SpatialData from spatialdata._utils import unpad_raster from spatialdata.models import Image2DModel, PointsModel, ShapesModel, get_axes_names from spatialdata.transformations.operations import ( @@ -462,7 +463,7 @@ def test_transform_elements_and_entire_spatial_data_object(sdata: SpatialData): # TODO: we are just applying the transformation, # we are not checking it is correct. 
We could improve this test scale = Scale([2], axes=("x",)) - for element in sdata._gen_elements_values(): + for element in sdata._gen_spatial_element_values(): set_transformation(element, scale, "my_space") sdata.transform_element_to_coordinate_system(element, "my_space") sdata.transform_to_coordinate_system("my_space") diff --git a/tests/core/query/test_spatial_query.py b/tests/core/query/test_spatial_query.py index 6db7e904..04ed6b11 100644 --- a/tests/core/query/test_spatial_query.py +++ b/tests/core/query/test_spatial_query.py @@ -2,17 +2,18 @@ import numpy as np import pytest +import xarray from anndata import AnnData from multiscale_spatial_image import MultiscaleSpatialImage from shapely import Polygon from spatial_image import SpatialImage -from spatialdata import SpatialData from spatialdata._core.query.spatial_query import ( BaseSpatialRequest, BoundingBoxRequest, bounding_box_query, polygon_query, ) +from spatialdata._core.spatialdata import SpatialData from spatialdata.models import ( Image2DModel, Image3DModel, @@ -101,23 +102,54 @@ def test_bounding_box_request_wrong_coordinate_order(): ) -def test_bounding_box_points(): +@pytest.mark.parametrize("is_3d", [True, False]) +@pytest.mark.parametrize("is_bb_3d", [True, False]) +def test_bounding_box_points(is_3d: bool, is_bb_3d: bool): """test the points bounding box_query""" - points_element = _make_points(np.array([[10, 10], [20, 20], [20, 30]])) - original_x = np.array(points_element["x"]) - original_y = np.array(points_element["y"]) + data_x = np.array([10, 20, 20, 20]) + data_y = np.array([10, 20, 30, 30]) + data_z = np.array([100, 200, 200, 300]) + + data = np.stack((data_x, data_y), axis=1) + if is_3d: + data = np.hstack((data, data_z.reshape(-1, 1))) + points_element = _make_points(data) + + original_x = points_element["x"] + original_y = points_element["y"] + if is_3d: + original_z = points_element["z"] + + if is_bb_3d: + _min_coordinate = np.array([18, 25, 250]) + _max_coordinate = np.array([22, 35, 350]) + _axes = ("x", "y", "z") + else: + _min_coordinate = np.array([18, 25]) + _max_coordinate = np.array([22, 35]) + _axes = ("x", "y") points_result = bounding_box_query( points_element, - axes=("x", "y"), - min_coordinate=np.array([18, 25]), - max_coordinate=np.array([22, 35]), + axes=_axes, + min_coordinate=_min_coordinate, + max_coordinate=_max_coordinate, target_coordinate_system="global", ) # Check that the correct point was selected - np.testing.assert_allclose(points_result["x"].compute(), [20]) - np.testing.assert_allclose(points_result["y"].compute(), [30]) + if is_3d: + if is_bb_3d: + np.testing.assert_allclose(points_result["x"].compute(), [20]) + np.testing.assert_allclose(points_result["y"].compute(), [30]) + np.testing.assert_allclose(points_result["z"].compute(), [300]) + else: + np.testing.assert_allclose(points_result["x"].compute(), [20, 20]) + np.testing.assert_allclose(points_result["y"].compute(), [30, 30]) + np.testing.assert_allclose(points_result["z"].compute(), [200, 300]) + else: + np.testing.assert_allclose(points_result["x"].compute(), [20, 20]) + np.testing.assert_allclose(points_result["y"].compute(), [30, 30]) # result should be valid points element PointsModel.validate(points_result) @@ -125,6 +157,8 @@ def test_bounding_box_points(): # original element should be unchanged np.testing.assert_allclose(points_element["x"].compute(), original_x) np.testing.assert_allclose(points_element["y"].compute(), original_y) + if is_3d: + np.testing.assert_allclose(points_element["z"].compute(), 
original_z) def test_bounding_box_points_no_points(): @@ -142,57 +176,74 @@ def test_bounding_box_points_no_points(): assert request is None +# @pytest.mark.parametrize("n_channels", [1, 2, 3]) @pytest.mark.parametrize("n_channels", [1, 2, 3]) -def test_bounding_box_image_2d(n_channels): - """Apply a bounding box to a 2D image""" - image = np.zeros((n_channels, 10, 10)) - # y: [5, 9], x: [0, 4] has value 1 - image[:, 5::, 0:5] = 1 - image_element = Image2DModel.parse(image) - image_element_multiscale = Image2DModel.parse(image, scale_factors=[2, 2]) +@pytest.mark.parametrize("is_labels", [True, False]) +@pytest.mark.parametrize("is_3d", [True, False]) +@pytest.mark.parametrize("is_bb_3d", [True, False]) +def test_bounding_box_raster(n_channels: int, is_labels: bool, is_3d: bool, is_bb_3d: bool): + """Apply a bounding box to a raster element.""" + if is_labels and n_channels > 1: + # labels cannot have multiple channels, let's ignore this combination of parameters + return + + shape = (10, 10) + if is_3d: + shape = (10,) + shape + shape = (n_channels,) + shape if not is_labels else (1,) + shape + + image = np.zeros(shape) + axes = ["y", "x"] + if is_3d: + image[:, 2:7, 5::, 0:5] = 1 + axes = ["z"] + axes + else: + image[:, 5::, 0:5] = 1 + + if is_labels: + image = np.squeeze(image, axis=0) + else: + axes = ["c"] + axes + + ximage = xarray.DataArray(image, dims=axes) + model = ( + Labels3DModel + if is_labels and is_3d + else Labels2DModel + if is_labels + else Image3DModel + if is_3d + else Image2DModel + ) - for image in [image_element, image_element_multiscale]: - # bounding box: y: [5, 10[, x: [0, 5[ - image_result = bounding_box_query( - image, - axes=("y", "x"), - min_coordinate=np.array([5, 0]), - max_coordinate=np.array([10, 5]), - target_coordinate_system="global", - ) - expected_image = np.ones((n_channels, 5, 5)) # c dimension is preserved - if isinstance(image, SpatialImage): - assert isinstance(image, SpatialImage) - np.testing.assert_allclose(image_result, expected_image) - elif isinstance(image, MultiscaleSpatialImage): - assert isinstance(image_result, MultiscaleSpatialImage) - v = image_result["scale0"].values() - assert len(v) == 1 - xdata = v.__iter__().__next__() - np.testing.assert_allclose(xdata, expected_image) - else: - raise ValueError("Unexpected type") + image_element = model.parse(image) + image_element_multiscale = model.parse(image, scale_factors=[2, 2]) + images = [image_element, image_element_multiscale] -@pytest.mark.parametrize("n_channels", [1, 2, 3]) -def test_bounding_box_image_3d(n_channels): - """Apply a bounding box to a 3D image""" - image = np.zeros((n_channels, 10, 10, 10)) - # z: [5, 9], y: [0, 4], x: [2, 6] has value 1 - image[:, 5::, 0:5, 2:7] = 1 - image_element = Image3DModel.parse(image) - image_element_multiscale = Image3DModel.parse(image, scale_factors=[2, 2]) + for image in images: + if is_bb_3d: + _min_coordinate = np.array([2, 5, 0]) + _max_coordinate = np.array([7, 10, 5]) + _axes = ("z", "y", "x") + else: + _min_coordinate = np.array([5, 0]) + _max_coordinate = np.array([10, 5]) + _axes = ("y", "x") - for image in [image_element, image_element_multiscale]: - # bounding box: z: [5, 10[, y: [0, 5[, x: [2, 7[ image_result = bounding_box_query( image, - axes=("z", "y", "x"), - min_coordinate=np.array([5, 0, 2]), - max_coordinate=np.array([10, 5, 7]), + axes=_axes, + min_coordinate=_min_coordinate, + max_coordinate=_max_coordinate, target_coordinate_system="global", ) - expected_image = np.ones((n_channels, 5, 5, 5)) # c dimension is 
preserved + + slices = {"y": slice(5, 10), "x": slice(0, 5)} + if is_bb_3d and is_3d: + slices["z"] = slice(2, 7) + expected_image = ximage.sel(**slices) + if isinstance(image, SpatialImage): assert isinstance(image, SpatialImage) np.testing.assert_allclose(image_result, expected_image) @@ -206,69 +257,6 @@ def test_bounding_box_image_3d(n_channels): raise ValueError("Unexpected type") -def test_bounding_box_labels_2d(): - """Apply a bounding box to a 2D label image""" - # in this test let's try some affine transformations, we could do that also for the other tests - image = np.zeros((10, 10)) - # y: [5, 9], x: [0, 4] has value 1 - image[5::, 0:5] = 1 - labels_element = Labels2DModel.parse(image) - labels_element_multiscale = Labels2DModel.parse(image, scale_factors=[2, 2]) - - for labels in [labels_element, labels_element_multiscale]: - # bounding box: y: [5, 10[, x: [0, 5[ - labels_result = bounding_box_query( - labels, - axes=("y", "x"), - min_coordinate=np.array([5, 0]), - max_coordinate=np.array([10, 5]), - target_coordinate_system="global", - ) - expected_image = np.ones((5, 5)) - if isinstance(labels, SpatialImage): - assert isinstance(labels, SpatialImage) - np.testing.assert_allclose(labels_result, expected_image) - elif isinstance(labels, MultiscaleSpatialImage): - assert isinstance(labels_result, MultiscaleSpatialImage) - v = labels_result["scale0"].values() - assert len(v) == 1 - xdata = v.__iter__().__next__() - np.testing.assert_allclose(xdata, expected_image) - else: - raise ValueError("Unexpected type") - - -def test_bounding_box_labels_3d(): - """Apply a bounding box to a 3D label image""" - image = np.zeros((10, 10, 10), dtype=int) - # z: [5, 9], y: [0, 4], x: [2, 6] has value 1 - image[5::, 0:5, 2:7] = 1 - labels_element = Labels3DModel.parse(image) - labels_element_multiscale = Labels3DModel.parse(image, scale_factors=[2, 2]) - - for labels in [labels_element, labels_element_multiscale]: - # bounding box: z: [5, 10[, y: [0, 5[, x: [2, 7[ - labels_result = bounding_box_query( - labels, - axes=("z", "y", "x"), - min_coordinate=np.array([5, 0, 2]), - max_coordinate=np.array([10, 5, 7]), - target_coordinate_system="global", - ) - expected_image = np.ones((5, 5, 5)) - if isinstance(labels, SpatialImage): - assert isinstance(labels, SpatialImage) - np.testing.assert_allclose(labels_result, expected_image) - elif isinstance(labels, MultiscaleSpatialImage): - assert isinstance(labels_result, MultiscaleSpatialImage) - v = labels_result["scale0"].values() - assert len(v) == 1 - xdata = v.__iter__().__next__() - np.testing.assert_allclose(xdata, expected_image) - else: - raise ValueError("Unexpected type") - - # TODO: more tests can be added for spatial queries after the cases 2, 3, 4 are implemented # (see https://github.com/scverse/spatialdata/pull/151, also for details on more tests) @@ -323,7 +311,7 @@ def test_bounding_box_spatial_data(full_sdata): _assert_spatialdata_objects_seem_identical(result, result2) - for element in result._gen_elements_values(): + for element in result._gen_spatial_element_values(): d = get_transformation(element, get_all=True) new_d = {k.replace("global", "cropped"): v for k, v in d.items()} set_transformation(element, new_d, set_all=True) @@ -338,7 +326,7 @@ def test_bounding_box_filter_table(): table.obs["region"] = ["circles0", "circles0", "circles1"] table.obs["instance"] = [0, 1, 0] table = TableModel.parse(table, region=["circles0", "circles1"], region_key="region", instance_key="instance") - sdata = SpatialData(shapes={"circles0": circles0, 
"circles1": circles1}, table=table) + sdata = SpatialData(shapes={"circles0": circles0, "circles1": circles1}, tables=table) queried0 = sdata.query.bounding_box( axes=("y", "x"), min_coordinate=np.array([15, 15]), @@ -364,7 +352,7 @@ def test_polygon_query_points(sdata_query_aggregation): queried = polygon_query(sdata, polygons=polygon, target_coordinate_system="global", shapes=False, points=True) points = queried["points"].compute() assert len(points) == 6 - assert len(queried.table) == 0 + assert queried.table is None # TODO: the case of querying points with multiple polygons is not currently implemented @@ -373,7 +361,7 @@ def test_polygon_query_shapes(sdata_query_aggregation): sdata = sdata_query_aggregation values_sdata = SpatialData( shapes={"values_polygons": sdata["values_polygons"], "values_circles": sdata["values_circles"]}, - table=sdata.table, + tables=sdata.table, ) polygon = sdata["by_polygons"].geometry.iloc[0] circle = sdata["by_circles"].geometry.iloc[0] @@ -427,7 +415,7 @@ def test_polygon_query_spatial_data(sdata_query_aggregation): "values_circles": sdata["values_circles"], }, points={"points": sdata["points"]}, - table=sdata.table, + tables=sdata.table, ) polygon = sdata["by_polygons"].geometry.iloc[0] queried = polygon_query(values_sdata, polygons=polygon, target_coordinate_system="global", shapes=True, points=True) diff --git a/tests/dataloader/test_datasets.py b/tests/dataloader/test_datasets.py index 59126996..dac01e80 100644 --- a/tests/dataloader/test_datasets.py +++ b/tests/dataloader/test_datasets.py @@ -4,59 +4,112 @@ import pandas as pd import pytest from anndata import AnnData -from spatialdata.dataloader.datasets import ImageTilesDataset +from spatialdata._core.spatialdata import SpatialData +from spatialdata.dataloader import ImageTilesDataset from spatialdata.models import TableModel -@pytest.mark.parametrize("image_element", ["blobs_image", "blobs_multiscale_image"]) -@pytest.mark.parametrize( - "regions_element", - ["blobs_labels", "blobs_multiscale_labels", "blobs_circles", "blobs_polygons", "blobs_multipolygons"], -) -def test_tiles_dataset(sdata_blobs, image_element, regions_element): - if regions_element in ["blobs_labels", "blobs_multipolygons", "blobs_multiscale_labels"]: - cm = pytest.raises(NotImplementedError) - else: - cm = contextlib.nullcontext() - with cm: +class TestImageTilesDataset: + @pytest.mark.parametrize("image_element", ["blobs_image", "blobs_multiscale_image"]) + @pytest.mark.parametrize( + "regions_element", + ["blobs_labels", "blobs_multiscale_labels", "blobs_circles", "blobs_polygons", "blobs_multipolygons"], + ) + def test_validation(self, sdata_blobs, image_element, regions_element): + if regions_element in ["blobs_labels", "blobs_multiscale_labels"] or image_element == "blobs_multiscale_image": + cm = pytest.raises(NotImplementedError) + elif regions_element in ["blobs_circles", "blobs_polygons", "blobs_multipolygons"]: + cm = pytest.raises(ValueError) + else: + cm = contextlib.nullcontext() + with cm: + _ = ImageTilesDataset( + sdata=sdata_blobs, + regions_to_images={regions_element: image_element}, + regions_to_coordinate_systems={regions_element: "global"}, + ) + + @pytest.mark.parametrize("regions_element", ["blobs_circles", "blobs_polygons", "blobs_multipolygons"]) + @pytest.mark.parametrize("raster", [True, False]) + def test_default(self, sdata_blobs, regions_element, raster): + raster_kwargs = {"target_unit_to_pixels": 2} if raster else {} + + sdata = self._annotate_shapes(sdata_blobs, regions_element) ds = 
ImageTilesDataset( - sdata=sdata_blobs, - regions_to_images={regions_element: image_element}, - tile_dim_in_units=10, - tile_dim_in_pixels=32, - target_coordinate_system="global", + sdata=sdata, + rasterize=raster, + regions_to_images={regions_element: "blobs_image"}, + regions_to_coordinate_systems={regions_element: "global"}, + rasterize_kwargs=raster_kwargs, ) - tile = ds[0].images.values().__iter__().__next__() - assert tile.shape == (3, 32, 32) + sdata_tile = ds[0] + tile = sdata_tile.images.values().__iter__().__next__() -def test_tiles_table(sdata_blobs): - new_table = AnnData( - X=np.random.default_rng().random((3, 10)), - obs=pd.DataFrame({"region": "blobs_circles", "instance_id": np.array([0, 1, 2])}), - ) - new_table = TableModel.parse(new_table, region="blobs_circles", region_key="region", instance_key="instance_id") - del sdata_blobs.table - sdata_blobs.table = new_table - ds = ImageTilesDataset( - sdata=sdata_blobs, - regions_to_images={"blobs_circles": "blobs_image"}, - tile_dim_in_units=10, - tile_dim_in_pixels=32, - target_coordinate_system="global", - ) - assert len(ds) == 3 - assert len(ds[0].table) == 1 - assert np.all(ds[0].table.X == new_table[0].X) - - -def test_tiles_multiple_elements(sdata_blobs): - ds = ImageTilesDataset( - sdata=sdata_blobs, - regions_to_images={"blobs_circles": "blobs_image", "blobs_polygons": "blobs_multiscale_image"}, - tile_dim_in_units=10, - tile_dim_in_pixels=32, - target_coordinate_system="global", - ) - assert len(ds) == 6 - _ = ds[0] + if regions_element == "blobs_circles": + if raster: + assert tile.shape == (3, 50, 50) + else: + assert tile.shape == (3, 25, 25) + elif regions_element == "blobs_polygons": + if raster: + assert tile.shape == (3, 164, 164) + else: + assert tile.shape == (3, 82, 82) + elif regions_element == "blobs_multipolygons": + if raster: + assert tile.shape == (3, 329, 329) + else: + assert tile.shape == (3, 165, 164) + else: + raise ValueError(f"Unexpected regions_element: {regions_element}") + + # extent has units in pixel so should be the same as tile shape + if raster: + assert round(ds.tiles_coords.extent.unique()[0] * 2) == tile.shape[1] + else: + if regions_element != "blobs_multipolygons": + assert int(ds.tiles_coords.extent.unique()[0]) == tile.shape[1] + else: + assert int(ds.tiles_coords.extent.unique()[0]) + 1 == tile.shape[1] + assert np.all(sdata_tile.table.obs.columns == ds.sdata.table.obs.columns) + assert list(sdata_tile.images.keys())[0] == "blobs_image" + + @pytest.mark.parametrize("regions_element", ["blobs_circles", "blobs_polygons", "blobs_multipolygons"]) + @pytest.mark.parametrize("return_annot", ["region", ["region", "instance_id"]]) + def test_return_annot(self, sdata_blobs, regions_element, return_annot): + sdata = self._annotate_shapes(sdata_blobs, regions_element) + ds = ImageTilesDataset( + sdata=sdata, + regions_to_images={regions_element: "blobs_image"}, + regions_to_coordinate_systems={regions_element: "global"}, + return_annotations=return_annot, + ) + + tile, annot = ds[0] + if regions_element == "blobs_circles": + assert tile.shape == (3, 25, 25) + elif regions_element == "blobs_polygons": + assert tile.shape == (3, 82, 82) + elif regions_element == "blobs_multipolygons": + assert tile.shape == (3, 165, 164) + else: + raise ValueError(f"Unexpected regions_element: {regions_element}") + # extent has units in pixel so should be the same as tile shape + if regions_element != "blobs_multipolygons": + assert int(ds.tiles_coords.extent.unique()[0]) == tile.shape[1] + else: + assert 
round(ds.tiles_coords.extent.unique()[0]) + 1 == tile.shape[1] + return_annot = [return_annot] if isinstance(return_annot, str) else return_annot + assert annot.shape[1] == len(return_annot) + + # TODO: consider adding this logic to blobs, to generate blobs with arbitrary table annotation + def _annotate_shapes(self, sdata: SpatialData, shape: str) -> SpatialData: + new_table = AnnData( + X=np.random.default_rng().random((len(sdata[shape]), 10)), + obs=pd.DataFrame({"region": shape, "instance_id": sdata[shape].index.values}), + ) + new_table = TableModel.parse(new_table, region=shape, region_key="region", instance_key="instance_id") + del sdata.table + sdata.table = new_table + return sdata diff --git a/tests/dataloader/test_transforms.py b/tests/dataloader/test_transforms.py deleted file mode 100644 index e69de29b..00000000 diff --git a/tests/io/test_multi_table.py b/tests/io/test_multi_table.py index f7d4671c..56755b79 100644 --- a/tests/io/test_multi_table.py +++ b/tests/io/test_multi_table.py @@ -1,94 +1,172 @@ from pathlib import Path -import anndata as ad -import numpy as np +import pytest from anndata import AnnData -from spatialdata import SpatialData +from anndata.tests.helpers import assert_equal +from spatialdata import SpatialData, concatenate +from spatialdata.models import TableModel -from tests.conftest import _get_new_table, _get_shapes +from tests.conftest import _get_shapes, _get_table # notes on paths: https://github.com/orgs/scverse/projects/17/views/1?pane=issue&itemId=44066734 -# notes for the people (to prettify) https://hackmd.io/wd7K4Eg1SlykKVN-nOP44w - -# shapes test_shapes = _get_shapes() -instance_id = np.array([str(i) for i in range(5)]) -table = _get_new_table(spatial_element="test_shapes", instance_id=instance_id) -adata0 = _get_new_table() -adata1 = _get_new_table() - # shuffle the indices of the dataframe -np.random.default_rng().shuffle(test_shapes["poly"].index) +# np.random.default_rng().shuffle(test_shapes["poly"].index) -# tables is a dict -SpatialData.tables - -# def get_table_keys(sdata: SpatialData) -> tuple[list[str], str, str]: -# d = sdata.table.uns[sd.models.TableModel.ATTRS_KEY] -# return d['region'], d['region_key'], d['instance_key'] -# -# @staticmethod -# def SpatialData.get_key_column(table: AnnData, key_column: str) -> ...: -# region, region_key, instance_key = sd.models.get_table_keys() -# if key_clumns == 'region_key': -# return table.obs[region_key] -# else: .... -# -# @staticmethod -# def SpatialData.get_region_key_column(table: AnnData | str): -# return get_key_column(...) -# @staticmethod -# def SpatialData.get_instance_key_column(table: AnnData | str): -# return get_key_column(...) 
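[Editor's note] The removed comments above sketch public accessors for a table's annotation metadata. As a hedged illustration (not part of the patch), the snippet below shows how that metadata can be read back from a parsed table, assuming the `TableModel.ATTRS_KEY` layout with 'region', 'region_key' and 'instance_key' entries that these tests rely on; `get_table_keys_sketch` is a hypothetical name, not the library API.

from __future__ import annotations

from anndata import AnnData
from spatialdata.models import TableModel


def get_table_keys_sketch(table: AnnData) -> tuple[str | list[str], str, str]:
    # a parsed table stores its annotation target(s), plus the names of the
    # obs columns identifying them, under table.uns[TableModel.ATTRS_KEY]
    attrs = table.uns[TableModel.ATTRS_KEY]
    return attrs["region"], attrs["region_key"], attrs["instance_key"]


# usage: the region column of a table is then simply
# _, region_key, _ = get_table_keys_sketch(sdata["table"])
# sdata["table"].obs[region_key]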
+class TestMultiTable: + def test_set_get_tables_from_spatialdata(self, full_sdata: SpatialData, tmp_path: str): + tmpdir = Path(tmp_path) / "tmp.zarr" + adata0 = _get_table(region="polygon") + adata1 = _get_table(region="multipolygon") + full_sdata["adata0"] = adata0 + full_sdata["adata1"] = adata1 + + adata2 = adata0.copy() + del adata2.obs["region"] + # fails because either none or all three of 'region', 'region_key', 'instance_key' must be provided + with pytest.raises(ValueError): + full_sdata["not_added_table"] = adata2 + + assert len(full_sdata.tables) == 3 + assert "adata0" in full_sdata.tables and "adata1" in full_sdata.tables + full_sdata.write(tmpdir) + + full_sdata = SpatialData.read(tmpdir) + assert_equal(adata0, full_sdata["adata0"]) + assert_equal(adata1, full_sdata["adata1"]) + assert "adata0" in full_sdata.tables and "adata1" in full_sdata.tables + + @pytest.mark.parametrize( + "region_key, instance_key, error_msg", + [ + ( + None, + None, + "Specified instance_key in table.uns 'instance_id' is not present as column in table.obs. " + "Please specify instance_key.", + ), + ( + "region", + None, + "Specified instance_key in table.uns 'instance_id' is not present as column in table.obs. " + "Please specify instance_key.", + ), + ("region", "instance_id", "Instance key column 'instance_id' not found in table.obs."), + (None, "instance_id", "Instance key column 'instance_id' not found in table.obs."), + ], + ) + def test_change_annotation_target(self, full_sdata, region_key, instance_key, error_msg): + n_obs = full_sdata["table"].n_obs + with pytest.raises( + ValueError, match=r"Mismatch\(es\) found between regions in region column in obs and target element: " + ): + # ValueError: Mismatch(es) found between regions in region column in obs and target element: labels2d, poly + full_sdata.set_table_annotates_spatialelement("table", "poly") + + del full_sdata["table"].obs["region"] + with pytest.raises( + ValueError, + match="Specified region_key in table.uns 'region' is not present as column in table.obs. " + "Please specify region_key.", + ): + full_sdata.set_table_annotates_spatialelement("table", "poly") + + del full_sdata["table"].obs["instance_id"] + full_sdata["table"].obs["region"] = ["poly"] * n_obs + with pytest.raises(ValueError, match=error_msg): + full_sdata.set_table_annotates_spatialelement( + "table", "poly", region_key=region_key, instance_key=instance_key + ) + + full_sdata["table"].obs["instance_id"] = range(n_obs) + full_sdata.set_table_annotates_spatialelement( + "table", "poly", instance_key="instance_id", region_key=region_key + ) -# we need also the two set_...() functions + with pytest.raises(ValueError, match="'not_existing' column not present in table.obs"): + full_sdata.set_table_annotates_spatialelement("table", "circles", region_key="not_existing") + + def test_set_table_nonexisting_target(self, full_sdata): + with pytest.raises( + ValueError, + match="Annotation target 'non_existing' not present as SpatialElement in " "SpatialData object.", + ): + full_sdata.set_table_annotates_spatialelement("table", "non_existing") + + def test_set_table_annotates_spatialelement(self, full_sdata): + del full_sdata["table"].uns[TableModel.ATTRS_KEY] + with pytest.raises( + TypeError, match="No current annotation metadata found. " "Please specify both region_key and instance_key."
+ ): + full_sdata.set_table_annotates_spatialelement("table", "labels2d", region_key="non_existent") + with pytest.raises(ValueError, match="Instance key column 'non_existent' not found in table.obs."): + full_sdata.set_table_annotates_spatialelement( + "table", "labels2d", region_key="region", instance_key="non_existent" + ) + with pytest.raises(ValueError, match="column not present"): + full_sdata.set_table_annotates_spatialelement( + "table", "labels2d", region_key="non_existing", instance_key="instance_id" + ) + full_sdata.set_table_annotates_spatialelement( + "table", "labels2d", region_key="region", instance_key="instance_id" + ) + def test_old_accessor_deprecation(self, full_sdata, tmp_path): + # To test self._backed + tmpdir = Path(tmp_path) / "tmp.zarr" + full_sdata.write(tmpdir) + adata0 = _get_table(region="polygon") -def get_annotation_target_of_table(table: AnnData) -> pd.Series: - return SpatialData.get_region_key_column(table) + with pytest.warns(DeprecationWarning): + _ = full_sdata.table + with pytest.raises(ValueError): + full_sdata.table = adata0 + with pytest.warns(DeprecationWarning): + del full_sdata.table + with pytest.raises(KeyError): + del full_sdata.table + with pytest.warns(DeprecationWarning): + full_sdata.table = adata0 # this gets placed in sdata['table'] + assert_equal(adata0, full_sdata.table) -def set_annotation_target_of_table(table: AnnData, spatial_element: str | pd.Series) -> None: - SpatialData.set_instance_key_column(table, spatial_element) + del full_sdata.table + full_sdata.tables["my_new_table0"] = adata0 + assert full_sdata.table is None -class TestMultiTable: - def test_set_get_tables_from_spatialdata(self, sdata): # sdata is form conftest - sdata["my_new_table0"] = adata0 - sdata["my_new_table1"] = adata1 - - def test_old_accessor_deprecation(self, sdata): - # assume no table is present - # this prints a deprecation warning - sdata.table = adata0 # this gets placed in sdata['table'] - # this prints a deprecation warning - _ = sdata.table # this returns sdata['table'] - # this prints a deprecation waring - del sdata.table - - sdata["my_new_table0"] = adata0 - # will fail, because there is no sdata['table'], even if another table is present - _ = sdata.table - - def test_single_table(self, tmp_path: str): - # shared table + @pytest.mark.parametrize("region", ["test_shapes", "non_existing"]) + def test_single_table(self, tmp_path: str, region: str): tmpdir = Path(tmp_path) / "tmp.zarr" + table = _get_table(region=region) + + # Create shapes dictionary + shapes_dict = { + "test_shapes": test_shapes["poly"], + } + + if region == "non_existing": + with pytest.warns( + UserWarning, match=r"The table is annotating an/some element\(s\) not present in the SpatialData object" + ): + SpatialData( + shapes=shapes_dict, + tables={"shape_annotate": table}, + ) test_sdata = SpatialData( - shapes={ - "test_shapes": test_shapes["poly"], - }, + shapes=shapes_dict, tables={"shape_annotate": table}, ) + test_sdata.write(tmpdir) sdata = SpatialData.read(tmpdir) - assert sdata.get("segmentation") - assert isinstance(sdata["segmentation"], AnnData) - from anndata.tests.helpers import assert_equal - - assert assert_equal(test_sdata["segmentation"], sdata["segmentation"]) + assert isinstance(sdata["shape_annotate"], AnnData) + assert_equal(test_sdata["shape_annotate"], sdata["shape_annotate"]) # note (to keep in the code): these tests here should simulate the interactions from the users; if the syntax # here we are matching the table to the shapes and vice versa (=
subset + reordering) @@ -107,28 +185,41 @@ def test_single_table(self, tmp_path: str): # assert ... def test_paired_elements_tables(self, tmp_path: str): - pass - - def test_elements_transfer_annotation(self, tmp_path: str): + tmpdir = Path(tmp_path) / "tmp.zarr" + table = _get_table(region="poly") + table2 = _get_table(region="multipoly") + table3 = _get_table(region="non_existing") + with pytest.warns( + UserWarning, match=r"The table is annotating an/some element\(s\) not present in the SpatialData object" + ): + SpatialData( + shapes={"poly": test_shapes["poly"], "multipoly": test_shapes["multipoly"]}, + tables={"poly_annotate": table, "multipoly_annotate": table3}, + ) test_sdata = SpatialData( - shapes={"test_shapes": test_shapes["poly"], "test_multipoly": test_shapes["multipoly"]}, - tables={"segmentation": table}, + shapes={"poly": test_shapes["poly"], "multipoly": test_shapes["multipoly"]}, + tables={"poly_annotate": table, "multipoly_annotate": table2}, ) - set_annotation_target_of_table(test_sdata["segmentation"], "test_multipoly") - assert get_annotation_target_of_table(test_sdata["segmentation"]) == "test_multipoly" + test_sdata.write(tmpdir) + test_sdata = SpatialData.read(tmpdir) + assert len(test_sdata.tables) == 2 def test_single_table_multiple_elements(self, tmp_path: str): tmpdir = Path(tmp_path) / "tmp.zarr" + table = _get_table(region=["poly", "multipoly"]) + subset = table[table.obs.region == "multipoly"] + with pytest.raises(ValueError, match="Regions in"): + TableModel().validate(subset) test_sdata = SpatialData( shapes={ - "test_shapes": test_shapes["poly"], - "test_multipoly": test_shapes["multi_poly"], + "poly": test_shapes["poly"], + "multipoly": test_shapes["multipoly"], }, - tables={"segmentation": table}, + tables=table, ) test_sdata.write(tmpdir) - # sdata = SpatialData.read(tmpdir) + SpatialData.read(tmpdir) # # use case example 1 # # sorting the shapes visium0 to match the order of the table @@ -141,41 +232,58 @@ def test_single_table_multiple_elements(self, tmp_path: str): # sub_table.obs[sdata["visium0"]] # assert ...
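[Editor's note] A minimal, self-contained sketch (not part of the patch) of the multi-region annotation pattern that test_single_table_multiple_elements above exercises; the element and column names mirror the fixtures, and the AnnData construction is only illustrative.

import numpy as np
from anndata import AnnData
from spatialdata.models import TableModel

adata = AnnData(
    X=np.zeros((4, 1)),
    obs={"region": ["poly", "poly", "multipoly", "multipoly"], "instance_id": [0, 1, 0, 1]},
)
# one table annotating two elements: `region` is a list, and table.obs["region"]
# records which row belongs to which element
table = TableModel.parse(adata, region=["poly", "multipoly"], region_key="region", instance_key="instance_id")
# slicing away one region leaves uns["spatialdata_attrs"]["region"] stale (it still
# lists both elements), which is presumably why validate() raises on the plain subset above
subset = table[table.obs["region"] == "multipoly"]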
- def test_concatenate_tables(self): - table_two = _get_new_table(spatial_element="test_multipoly", instance_id=np.array([str(i) for i in range(2)])) - concatenated_table = ad.concat([table, table_two]) - test_sdata = SpatialData( - shapes={ - "test_shapes": test_shapes["poly"], - "test_multipoly": test_shapes["multi_poly"], - }, - tables={"segmentation": concatenated_table}, - ) - # use case tests as above (we test only visium0) - - def test_multiple_table_without_element(self): - table = _get_new_table() - table_two = _get_new_table() + def test_multiple_table_without_element(self, tmp_path: str): + tmpdir = Path(tmp_path) / "tmp.zarr" + table = _get_table(region=None, region_key=None, instance_key=None) + table_two = _get_table(region=None, region_key=None, instance_key=None) - test_sdata = SpatialData( + sdata = SpatialData( tables={"table": table, "table_two": table_two}, ) + sdata.write(tmpdir) + SpatialData.read(tmpdir) def test_multiple_tables_same_element(self, tmp_path: str): tmpdir = Path(tmp_path) / "tmp.zarr" - table_two = _get_new_table(spatial_element="test_shapes", instance_id=instance_id) + table = _get_table(region="test_shapes") + table2 = _get_table(region="test_shapes") test_sdata = SpatialData( shapes={ "test_shapes": test_shapes["poly"], }, - tables={"segmentation": table, "segmentation_two": table_two}, + tables={"table": table, "table2": table2}, ) test_sdata.write(tmpdir) + SpatialData.read(tmpdir) -# -# # these use cases could be the preferred one for the users; we need to choose one/two preferred ones (either this, either helper function, ...) +def test_concatenate_sdata_multitables(): + sdatas = [ + SpatialData( + shapes={f"poly_{i + 1}": test_shapes["poly"], f"multipoly_{i + 1}": test_shapes["multipoly"]}, + tables={"table": _get_table(region=f"poly_{i + 1}"), "table2": _get_table(region=f"multipoly_{i + 1}")}, + ) + for i in range(3) + ] + + with pytest.warns( + UserWarning, + match="Duplicate table names found.", + ): + concatenate(sdatas) + + merged_sdata = concatenate(sdatas, concatenate_tables=True) + assert merged_sdata.tables["table"].n_obs == 300 + assert merged_sdata.tables["table2"].n_obs == 300 + assert all(merged_sdata.tables["table"].obs.region.unique() == ["poly_1", "poly_2", "poly_3"]) + assert all(merged_sdata.tables["table2"].obs.region.unique() == ["multipoly_1", "multipoly_2", "multipoly_3"]) + + +# The following use cases need to be put in the tutorial notebook; let's keep the comment here until we have the +# notebook ready. +# # these use cases could be the preferred ones for the users; we need to choose one/two preferred ones (either this +# or a helper function, ...) # # use cases # # use case example 1 # # sorting the shapes to match the order of the table @@ -187,6 +295,7 @@ def test_multiple_tables_same_element(self, tmp_path: str): # sdata.table.obs[sdata["visium0"]] # assert ... # +# We can postpone the implementation of this test until the functions "match_table_to_element" etc. are ready. # def test_partial_match(): # # the function spatialdata._core.query.relational_query.match_table_to_element(no s) needs to be modified (will be # # simpler), we also need a function match_element_to_table.
Maybe we can have just one function doing both the diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index 046baf3b..e629182d 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -12,7 +12,7 @@ from numpy.random import default_rng from shapely.geometry import Point from spatial_image import SpatialImage -from spatialdata import SpatialData +from spatialdata import SpatialData, read_zarr from spatialdata._io._utils import _are_directories_identical from spatialdata.models import TableModel from spatialdata.transformations.operations import ( @@ -110,7 +110,7 @@ def test_incremental_io( tmpdir = Path(tmp_path) / "tmp.zarr" sdata = full_sdata - sdata.add_image(name="sdata_not_saved_yet", image=_get_images().values().__iter__().__next__()) + sdata.images["sdata_not_saved_yet"] = _get_images().values().__iter__().__next__() sdata.write(tmpdir) for k, v in _get_images().items(): @@ -122,10 +122,10 @@ def test_incremental_io( assert len(names) == 1 name = names[0] v[scale] = v[scale].rename_vars({name: f"incremental_{k}"}) - sdata.add_image(name=f"incremental_{k}", image=v) - with pytest.raises(KeyError): - sdata.add_image(name=f"incremental_{k}", image=v) - sdata.add_image(name=f"incremental_{k}", image=v, overwrite=True) + sdata.images[f"incremental_{k}"] = v + with pytest.warns(UserWarning): + sdata.images[f"incremental_{k}"] = v + sdata[f"incremental_{k}"] = v for k, v in _get_labels().items(): if isinstance(v, SpatialImage): @@ -136,26 +136,26 @@ def test_incremental_io( assert len(names) == 1 name = names[0] v[scale] = v[scale].rename_vars({name: f"incremental_{k}"}) - sdata.add_labels(name=f"incremental_{k}", labels=v) - with pytest.raises(KeyError): - sdata.add_labels(name=f"incremental_{k}", labels=v) - sdata.add_labels(name=f"incremental_{k}", labels=v, overwrite=True) + sdata.labels[f"incremental_{k}"] = v + with pytest.warns(UserWarning): + sdata.labels[f"incremental_{k}"] = v + sdata[f"incremental_{k}"] = v for k, v in _get_shapes().items(): - sdata.add_shapes(name=f"incremental_{k}", shapes=v) - with pytest.raises(KeyError): - sdata.add_shapes(name=f"incremental_{k}", shapes=v) - sdata.add_shapes(name=f"incremental_{k}", shapes=v, overwrite=True) + sdata.shapes[f"incremental_{k}"] = v + with pytest.warns(UserWarning): + sdata.shapes[f"incremental_{k}"] = v + sdata[f"incremental_{k}"] = v break for k, v in _get_points().items(): - sdata.add_points(name=f"incremental_{k}", points=v) - with pytest.raises(KeyError): - sdata.add_points(name=f"incremental_{k}", points=v) - sdata.add_points(name=f"incremental_{k}", points=v, overwrite=True) + sdata.points[f"incremental_{k}"] = v + with pytest.warns(UserWarning): + sdata.points[f"incremental_{k}"] = v + sdata[f"incremental_{k}"] = v break - def test_incremental_io_table(self, table_single_annotation): + def test_incremental_io_table(self, table_single_annotation: SpatialData) -> None: s = table_single_annotation t = s.table[:10, :].copy() with pytest.raises(ValueError): @@ -182,8 +182,8 @@ def test_io_and_lazy_loading_points(self, points): f = os.path.join(td, "data.zarr") dask0 = points.points[elem_name] points.write(f) - dask1 = points.points[elem_name] assert all("read-parquet" not in key for key in dask0.dask.layers) + dask1 = read_zarr(f).points[elem_name] assert any("read-parquet" in key for key in dask1.dask.layers) def test_io_and_lazy_loading_raster(self, images, labels): @@ -198,6 +198,7 @@ def test_io_and_lazy_loading_raster(self, images, labels): sdata.write(f) dask1 = d[elem_name].data 
assert all("from-zarr" not in key for key in dask0.dask.layers) + dask1 = read_zarr(f)[elem_name].data assert any("from-zarr" in key for key in dask1.dask.layers) def test_replace_transformation_on_disk_raster(self, images, labels): @@ -238,12 +239,34 @@ def test_replace_transformation_on_disk_non_raster(self, shapes, points): t1 = get_transformation(SpatialData.read(f).__getattribute__(k)[elem_name]) assert type(t1) == Scale + def test_overwrite_files_without_backed_data(self, full_sdata): + with tempfile.TemporaryDirectory() as tmpdir: + f = os.path.join(tmpdir, "data.zarr") + old_data = SpatialData() + old_data.write(f) + # Since not backed, no risk of overwriting backing data. + # Should not raise "The file path specified is the same as the one used for backing." + full_sdata.write(f, overwrite=True) + + def test_not_overwrite_files_without_backed_data_but_with_dask_backed_data(self, full_sdata, points): + with tempfile.TemporaryDirectory() as tmpdir: + f = os.path.join(tmpdir, "data.zarr") + points.write(f) + points2 = SpatialData.read(f) + p = points2["points_0"] + full_sdata["points_0"] = p + with pytest.raises( + ValueError, + match="The file path specified is a parent directory of one or more files used for backing for one or ", + ): + full_sdata.write(f, overwrite=True) + def test_overwrite_files_with_backed_data(self, full_sdata): # addressing https://github.com/scverse/spatialdata/issues/137 with tempfile.TemporaryDirectory() as tmpdir: f = os.path.join(tmpdir, "data.zarr") full_sdata.write(f) - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="The file path specified is the same as the one used for backing."): full_sdata.write(f, overwrite=True) # support for overwriting backed sdata has been temporarily removed @@ -275,45 +298,6 @@ def test_overwrite_onto_non_zarr_file(self, full_sdata): with pytest.raises(ValueError): full_sdata.write(f1) - def test_incremental_io_with_backed_elements(self, full_sdata): - # addressing https://github.com/scverse/spatialdata/issues/137 - # we test also the non-backed case so that if we switch to the - # backed version in the future we already have the tests - - with tempfile.TemporaryDirectory() as tmpdir: - f = os.path.join(tmpdir, "data.zarr") - full_sdata.write(f) - - e = full_sdata.images.values().__iter__().__next__() - full_sdata.add_image("new_images", e, overwrite=True) - # support for overwriting backed images has been temporarily removed - with pytest.raises(ValueError): - full_sdata.add_image("new_images", full_sdata.images["new_images"], overwrite=True) - - e = full_sdata.labels.values().__iter__().__next__() - full_sdata.add_labels("new_labels", e, overwrite=True) - # support for overwriting backed labels has been temporarily removed - with pytest.raises(ValueError): - full_sdata.add_labels("new_labels", full_sdata.labels["new_labels"], overwrite=True) - - e = full_sdata.points.values().__iter__().__next__() - full_sdata.add_points("new_points", e, overwrite=True) - # support for overwriting backed points has been temporarily removed - with pytest.raises(ValueError): - full_sdata.add_points("new_points", full_sdata.points["new_points"], overwrite=True) - - e = full_sdata.shapes.values().__iter__().__next__() - full_sdata.add_shapes("new_shapes", e, overwrite=True) - full_sdata.add_shapes("new_shapes", full_sdata.shapes["new_shapes"], overwrite=True) - - # commenting out as it is failing - # f2 = os.path.join(tmpdir, "data2.zarr") - # sdata2 = SpatialData(table=full_sdata.table.copy()) - # sdata2.write(f2) - # 
del full_sdata.table - # full_sdata.table = sdata2.table - # full_sdata.write(f2, overwrite=True) - def test_io_table(shapes): adata = AnnData(X=RNG.normal(size=(5, 10))) diff --git a/tests/io/test_utils.py b/tests/io/test_utils.py index 3e2cb04f..d8f86c44 100644 --- a/tests/io/test_utils.py +++ b/tests/io/test_utils.py @@ -5,12 +5,13 @@ import numpy as np import pytest from spatialdata import read_zarr, save_transformations -from spatialdata._io._utils import get_backing_files +from spatialdata._io._utils import get_dask_backing_files from spatialdata._utils import multiscale_spatial_image_from_data_tree from spatialdata.transformations import Scale, get_transformation, set_transformation def test_backing_files_points(points): + """Test the ability to identify the backing files of a dask dataframe from examining its computational graph""" with tempfile.TemporaryDirectory() as tmp_dir: f0 = os.path.join(tmp_dir, "points0.zarr") f1 = os.path.join(tmp_dir, "points1.zarr") @@ -21,7 +22,7 @@ def test_backing_files_points(points): p0 = points0.points["points_0"] p1 = points1.points["points_0"] p2 = dd.concat([p0, p1], axis=0) - files = get_backing_files(p2) + files = get_dask_backing_files(p2) expected_zarr_locations = [ os.path.realpath(os.path.join(f, "points/points_0/points.parquet")) for f in [f0, f1] ] @@ -29,6 +30,10 @@ def test_backing_files_points(points): def test_backing_files_images(images): + """ + Test the ability to identify the backing files of single scale and multiscale images from examining their + computational graph + """ with tempfile.TemporaryDirectory() as tmp_dir: f0 = os.path.join(tmp_dir, "images0.zarr") f1 = os.path.join(tmp_dir, "images1.zarr") @@ -41,7 +46,7 @@ def test_backing_files_images(images): im0 = images0.images["image2d"] im1 = images1.images["image2d"] im2 = im0 + im1 - files = get_backing_files(im2) + files = get_dask_backing_files(im2) expected_zarr_locations = [os.path.realpath(os.path.join(f, "images/image2d")) for f in [f0, f1]] assert set(files) == set(expected_zarr_locations) @@ -49,13 +54,17 @@ def test_backing_files_images(images): im3 = images0.images["image2d_multiscale"] im4 = images1.images["image2d_multiscale"] im5 = multiscale_spatial_image_from_data_tree(im3 + im4) - files = get_backing_files(im5) + files = get_dask_backing_files(im5) expected_zarr_locations = [os.path.realpath(os.path.join(f, "images/image2d_multiscale")) for f in [f0, f1]] assert set(files) == set(expected_zarr_locations) # TODO: this function here below is very similar to the above, unify the test with the above or delete this todo def test_backing_files_labels(labels): + """ + Test the ability to identify the backing files of single scale and multiscale labels from examining their + computational graph + """ with tempfile.TemporaryDirectory() as tmp_dir: f0 = os.path.join(tmp_dir, "labels0.zarr") f1 = os.path.join(tmp_dir, "labels1.zarr") @@ -68,7 +77,7 @@ def test_backing_files_labels(labels): im0 = labels0.labels["labels2d"] im1 = labels1.labels["labels2d"] im2 = im0 + im1 - files = get_backing_files(im2) + files = get_dask_backing_files(im2) expected_zarr_locations = [os.path.realpath(os.path.join(f, "labels/labels2d")) for f in [f0, f1]] assert set(files) == set(expected_zarr_locations) @@ -76,11 +85,37 @@ def test_backing_files_labels(labels): im3 = labels0.labels["labels2d_multiscale"] im4 = labels1.labels["labels2d_multiscale"] im5 = multiscale_spatial_image_from_data_tree(im3 + im4) - files = get_backing_files(im5) + files = get_dask_backing_files(im5) 
expected_zarr_locations = [os.path.realpath(os.path.join(f, "labels/labels2d_multiscale")) for f in [f0, f1]] assert set(files) == set(expected_zarr_locations) +def test_backing_files_combining_points_and_images(points, images): + """ + Test the ability to identify the backing files of an object that depends both on dask dataframes and dask arrays + from examining its computational graph + """ + with tempfile.TemporaryDirectory() as tmp_dir: + f0 = os.path.join(tmp_dir, "points0.zarr") + f1 = os.path.join(tmp_dir, "images1.zarr") + points.write(f0) + images.write(f1) + points0 = read_zarr(f0) + images1 = read_zarr(f1) + + p0 = points0.points["points_0"] + im1 = images1.images["image2d"] + v = p0["x"].loc[0].values + v.compute_chunk_sizes() + im2 = v + im1 + files = get_dask_backing_files(im2) + expected_zarr_locations = [ + os.path.realpath(os.path.join(f0, "points/points_0/points.parquet")), + os.path.realpath(os.path.join(f1, "images/image2d")), + ] + assert set(files) == set(expected_zarr_locations) + + def test_save_transformations(labels): with tempfile.TemporaryDirectory() as tmp_dir: f0 = os.path.join(tmp_dir, "labels0.zarr") diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 34035c7d..8e2cd333 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -1,7 +1,6 @@ from __future__ import annotations import os -import pathlib import tempfile from copy import deepcopy from functools import partial @@ -22,7 +21,7 @@ from pandas.api.types import is_categorical_dtype from shapely.io import to_ragged_array from spatial_image import SpatialImage, to_spatial_image -from spatialdata import SpatialData +from spatialdata._core.spatialdata import SpatialData from spatialdata._types import ArrayLike from spatialdata.models import ( Image2DModel, @@ -119,7 +118,7 @@ def _parse_transformation_from_multiple_places(self, model: Any, element: Any, * str, np.ndarray, dask.array.core.Array, - pathlib.PosixPath, + Path, pd.DataFrame, ) ):
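[Editor's note] Closing with a short usage sketch (not part of the patch) of the renamed get_dask_backing_files helper exercised by the io utils tests above; the Zarr store paths are placeholders. Because the helper inspects the dask task graph, an object derived from several backed elements reports every file it still depends on.

import dask.dataframe as dd
from spatialdata import read_zarr
from spatialdata._io._utils import get_dask_backing_files

# hypothetical stores, written beforehand as in test_backing_files_points
points0 = read_zarr("points0.zarr")
points1 = read_zarr("points1.zarr")

combined = dd.concat([points0.points["points_0"], points1.points["points_0"]], axis=0)
# both backing parquet files are reported: the concatenated dataframe's task
# graph keeps one read-parquet layer per input store
print(get_dask_backing_files(combined))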