Skip to content

Commit

Permalink
Fixes bug with extent of rotated data (#373)
Browse files Browse the repository at this point in the history
  • Loading branch information
LucaMarconato authored Oct 10, 2023
1 parent b229099 commit 1395fb5
Show file tree
Hide file tree
Showing 3 changed files with 276 additions and 112 deletions.
186 changes: 112 additions & 74 deletions src/spatialdata/_core/data_extent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from collections import defaultdict
from functools import singledispatch
from typing import Union

import numpy as np
import pandas as pd
Expand All @@ -15,7 +14,6 @@

from spatialdata._core.operations.transform import transform
from spatialdata._core.spatialdata import SpatialData
from spatialdata._types import ArrayLike
from spatialdata.models import get_axes_names
from spatialdata.models._utils import SpatialElement
from spatialdata.models.models import PointsModel
Expand Down Expand Up @@ -86,32 +84,45 @@ def _get_extent_of_polygons_multipolygons(
return extent


def _get_extent_of_points(e: DaskDataFrame) -> BoundingBoxDescription:
axes = get_axes_names(e)
min_coordinates = np.array([e[ax].min().compute() for ax in axes])
max_coordinates = np.array([e[ax].max().compute() for ax in axes])
extent = {}
for i, ax in enumerate(axes):
extent[ax] = (min_coordinates[i], max_coordinates[i])
return extent


def _get_extent_of_data_array(e: DataArray, coordinate_system: str) -> BoundingBoxDescription:
# lightweight conversion to SpatialImage just to fix the type of the single-dispatch
_check_element_has_coordinate_system(element=SpatialImage(e), coordinate_system=coordinate_system)
# also here
data_axes = get_axes_names(SpatialImage(e))
min_coordinates = []
max_coordinates = []
axes = []
extent: BoundingBoxDescription = {}
for ax in ["z", "y", "x"]:
if ax in data_axes:
i = data_axes.index(ax)
axes.append(ax)
min_coordinates.append(0)
max_coordinates.append(e.shape[i])
extent[ax] = (0, e.shape[i])
return _compute_extent_in_coordinate_system(
# and here
element=SpatialImage(e),
coordinate_system=coordinate_system,
min_coordinates=np.array(min_coordinates),
max_coordinates=np.array(max_coordinates),
axes=tuple(axes),
extent=extent,
)


@singledispatch
def get_extent(e: SpatialData | SpatialElement, coordinate_system: str = "global") -> BoundingBoxDescription:
def get_extent(
e: SpatialData | SpatialElement,
coordinate_system: str = "global",
exact: bool = True,
has_images: bool = True,
has_labels: bool = True,
has_points: bool = True,
has_shapes: bool = True,
elements: list[str] | None = None,
) -> BoundingBoxDescription:
"""
Get the extent (bounding box) of a SpatialData object or a SpatialElement.
Expand All @@ -128,6 +139,37 @@ def get_extent(e: SpatialData | SpatialElement, coordinate_system: str = "global
The maximum coordinate of the bounding box.
axes
The names of the dimensions of the bounding box
exact
If True, the extent is computed exactly. If False, an approximation faster to compute is given. The
approximation is guaranteed to contain all the data, see notes for details.
has_images
If True, images are included in the computation of the extent.
has_labels
If True, labels are included in the computation of the extent.
has_points
If True, points are included in the computation of the extent.
has_shapes
If True, shapes are included in the computation of the extent.
elements
If not None, only the elements with the given names are included in the computation of the extent.
Notes
-----
The extent of a SpatialData object is the extent of the union of the extents of all its elements. The extent of a
SpatialElement is the extent of the element in the coordinate system specified by the argument `coordinate_system`.
If `exact` is False, first the extent of the SpatialElement before any transformation is computed. Then, the extent
is transformed to the target coordinate system. This is faster than computing the extent after the transformation,
since the transformation is applied to extent of the untransformed data, as opposed to transforming the data and
then computing the extent.
The exact and approximate extent are the same if the transformation doesn't contain any rotation or shear, or in the
case in which the transformation is affine but all the corners of the extent of the untransformed data
(bounding box corners) are part of the dataset itself. Note that this is always the case for raster data.
An extreme case is a dataset composed of the two points (0, 0) and (1, 1), rotated anticlockwise by 45 degrees. The
exact extent is the bounding box [minx, miny, maxx, maxy] = [0, 0, 0, 1.414], while the approximate extent is the
box [minx, miny, maxx, maxy] = [-0.707, 0, 0.707, 1.414].
"""
raise ValueError("The object type is not supported.")

Expand All @@ -136,11 +178,12 @@ def get_extent(e: SpatialData | SpatialElement, coordinate_system: str = "global
def _(
e: SpatialData,
coordinate_system: str = "global",
exact: bool = True,
has_images: bool = True,
has_labels: bool = True,
has_points: bool = True,
has_shapes: bool = True,
elements: Union[list[str], None] = None,
elements: list[str] | None = None,
) -> BoundingBoxDescription:
"""
Get the extent (bounding box) of a SpatialData object: the extent of the union of the extents of all its elements.
Expand Down Expand Up @@ -174,7 +217,10 @@ def _(
assert isinstance(transformations, dict)
coordinate_systems = list(transformations.keys())
if coordinate_system in coordinate_systems:
extent = get_extent(element_obj, coordinate_system=coordinate_system)
if isinstance(element_obj, (DaskDataFrame, GeoDataFrame)):
extent = get_extent(element_obj, coordinate_system=coordinate_system, exact=exact)
else:
extent = get_extent(element_obj, coordinate_system=coordinate_system)
axes = list(extent.keys())
for ax in axes:
new_min_coordinates_dict[ax] += [extent[ax][0]]
Expand All @@ -183,8 +229,14 @@ def _(
raise ValueError(
f"The SpatialData object does not contain any element in the "
f" coordinate system {coordinate_system!r}, "
f"please pass a different coordinate system wiht the argument 'coordinate_system'."
f"please pass a different coordinate system with the argument 'coordinate_system'."
)
if len(new_min_coordinates_dict) == 0:
raise ValueError(
f"The SpatialData object does not contain any element in the coordinate system {coordinate_system!r}, "
"please pass a different coordinate system with the argument 'coordinate_system'."
)
axes = list(new_min_coordinates_dict.keys())
new_min_coordinates = np.array([min(new_min_coordinates_dict[ax]) for ax in axes])
new_max_coordinates = np.array([max(new_max_coordinates_dict[ax]) for ax in axes])
extent = {}
Expand All @@ -193,8 +245,21 @@ def _(
return extent


def _get_extent_of_shapes(e: GeoDataFrame) -> BoundingBoxDescription:
# remove potentially empty geometries
e_temp = e[e["geometry"].apply(lambda geom: not geom.is_empty)]
assert len(e_temp) > 0, "Cannot compute extent of an empty collection of geometries."

# separate points from (multi-)polygons
first_geometry = e_temp["geometry"].iloc[0]
if isinstance(first_geometry, Point):
return _get_extent_of_circles(e)
assert isinstance(first_geometry, (Polygon, MultiPolygon))
return _get_extent_of_polygons_multipolygons(e)


@get_extent.register
def _(e: GeoDataFrame, coordinate_system: str = "global") -> BoundingBoxDescription:
def _(e: GeoDataFrame, coordinate_system: str = "global", exact: bool = True) -> BoundingBoxDescription:
"""
Compute the extent (bounding box) of a set of shapes.
Expand All @@ -203,57 +268,33 @@ def _(e: GeoDataFrame, coordinate_system: str = "global") -> BoundingBoxDescript
The bounding box description.
"""
_check_element_has_coordinate_system(element=e, coordinate_system=coordinate_system)
# remove potentially empty geometries
e_temp = e[e["geometry"].apply(lambda geom: not geom.is_empty)]

# separate points from (multi-)polygons
e_points = e_temp[e_temp["geometry"].apply(lambda geom: isinstance(geom, Point))]
e_polygons = e_temp[e_temp["geometry"].apply(lambda geom: isinstance(geom, (Polygon, MultiPolygon)))]
extent = None
if len(e_points) > 0:
assert "radius" in e_points.columns, "Shapes that are points must have a 'radius' column."
extent = _get_extent_of_circles(e_points)
if len(e_polygons) > 0:
extent_polygons = _get_extent_of_polygons_multipolygons(e_polygons)
if extent is None:
extent = extent_polygons
else:
# case when there are points AND (multi-)polygons in the GeoDataFrame
extent["y"] = (min(extent["y"][0], extent_polygons["y"][0]), max(extent["y"][1], extent_polygons["y"][1]))
extent["x"] = (min(extent["x"][0], extent_polygons["x"][0]), max(extent["x"][1], extent_polygons["x"][1]))

if extent is None:
raise ValueError(
"Unable to compute extent of GeoDataFrame. It needs to contain at least one non-empty "
"Point or Polygon or Multipolygon."
if not exact:
extent = _get_extent_of_shapes(e)
return _compute_extent_in_coordinate_system(
element=e,
coordinate_system=coordinate_system,
extent=extent,
)

min_coordinates = [extent["y"][0], extent["x"][0]]
max_coordinates = [extent["y"][1], extent["x"][1]]
axes = tuple(extent.keys())

return _compute_extent_in_coordinate_system(
element=e_temp,
coordinate_system=coordinate_system,
min_coordinates=np.array(min_coordinates),
max_coordinates=np.array(max_coordinates),
axes=axes,
)
t = get_transformation(e, to_coordinate_system=coordinate_system)
assert isinstance(t, BaseTransformation)
transformed = transform(e, t)
return _get_extent_of_shapes(transformed)


@get_extent.register
def _(e: DaskDataFrame, coordinate_system: str = "global") -> BoundingBoxDescription:
def _(e: DaskDataFrame, coordinate_system: str = "global", exact: bool = True) -> BoundingBoxDescription:
_check_element_has_coordinate_system(element=e, coordinate_system=coordinate_system)
axes = get_axes_names(e)
min_coordinates = np.array([e[ax].min().compute() for ax in axes])
max_coordinates = np.array([e[ax].max().compute() for ax in axes])
return _compute_extent_in_coordinate_system(
element=e,
coordinate_system=coordinate_system,
min_coordinates=min_coordinates,
max_coordinates=max_coordinates,
axes=axes,
)
if not exact:
extent = _get_extent_of_points(e)
return _compute_extent_in_coordinate_system(
element=e,
coordinate_system=coordinate_system,
extent=extent,
)
t = get_transformation(e, to_coordinate_system=coordinate_system)
assert isinstance(t, BaseTransformation)
transformed = transform(e, t)
return _get_extent_of_points(transformed)


@get_extent.register
Expand All @@ -275,16 +316,12 @@ def _check_element_has_coordinate_system(element: SpatialElement, coordinate_sys
if coordinate_system not in coordinate_systems:
raise ValueError(
f"The element does not contain any coordinate system named {coordinate_system!r}, "
f"please pass a different coordinate system wiht the argument 'coordinate_system'."
f"please pass a different coordinate system with the argument 'coordinate_system'."
)


def _compute_extent_in_coordinate_system(
element: SpatialElement | DataArray,
coordinate_system: str,
min_coordinates: ArrayLike,
max_coordinates: ArrayLike,
axes: tuple[str, ...],
element: SpatialElement | DataArray, coordinate_system: str, extent: BoundingBoxDescription
) -> BoundingBoxDescription:
"""
Transform the extent from the intrinsic coordinates of the element to the given coordinate system.
Expand All @@ -295,12 +332,8 @@ def _compute_extent_in_coordinate_system(
The SpatialElement.
coordinate_system
The coordinate system to transform the extent to.
min_coordinates
Min coordinates of the extent in the intrinsic coordinates of the element, expects [y_min, x_min].
max_coordinates
Max coordinates of the extent in the intrinsic coordinates of the element, expects [y_max, x_max].
axes
The min and max coordinates refer to.
extent
The extent in the intrinsic coordinates of the element.
Returns
-------
Expand All @@ -310,6 +343,11 @@ def _compute_extent_in_coordinate_system(
assert isinstance(transformation, BaseTransformation)
from spatialdata._core.query._utils import get_bounding_box_corners

axes = get_axes_names(element)
if "c" in axes:
axes = tuple(ax for ax in axes if ax != "c")
min_coordinates = np.array([extent[ax][0] for ax in axes])
max_coordinates = np.array([extent[ax][1] for ax in axes])
corners = get_bounding_box_corners(
axes=axes,
min_coordinate=min_coordinates,
Expand Down
7 changes: 3 additions & 4 deletions src/spatialdata/models/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

import dask.dataframe as dd
import geopandas
from anndata import AnnData
from dask.dataframe import DataFrame as DaskDataFrame
from geopandas import GeoDataFrame
from multiscale_spatial_image import MultiscaleSpatialImage
Expand Down Expand Up @@ -167,15 +166,15 @@ def _(e: MultiscaleSpatialImage) -> tuple[str, ...]:

@get_axes_names.register(GeoDataFrame)
def _(e: GeoDataFrame) -> tuple[str, ...]:
all_dims = (Z, Y, X)
all_dims = (X, Y, Z)
n = e.geometry.iloc[0]._ndim
dims = all_dims[-n:]
dims = all_dims[:n]
_validate_dims(dims)
return dims


@get_axes_names.register(DaskDataFrame)
def _(e: AnnData) -> tuple[str, ...]:
def _(e: DaskDataFrame) -> tuple[str, ...]:
valid_dims = (X, Y, Z)
dims = tuple([c for c in valid_dims if c in e.columns])
_validate_dims(dims)
Expand Down
Loading

0 comments on commit 1395fb5

Please sign in to comment.