From a08e413be461767f6d89572bada1da5780e1e30d Mon Sep 17 00:00:00 2001 From: Luca Marconato Date: Tue, 10 Oct 2023 16:27:32 +0200 Subject: [PATCH] applied code review, added test for xy swap bug --- src/spatialdata/_core/data_extent.py | 25 +++++++++---------------- tests/core/test_data_extent.py | 12 +++++++++++- 2 files changed, 20 insertions(+), 17 deletions(-) diff --git a/src/spatialdata/_core/data_extent.py b/src/spatialdata/_core/data_extent.py index eb77c771..5fcab46f 100644 --- a/src/spatialdata/_core/data_extent.py +++ b/src/spatialdata/_core/data_extent.py @@ -1,3 +1,4 @@ +# Functions to compute the bounding box describing the extent of a SpatialElement or region. from __future__ import annotations from collections import defaultdict @@ -32,6 +33,7 @@ def _get_extent_of_circles(circles: GeoDataFrame) -> BoundingBoxDescription: Parameters ---------- circles + The circles represented as a GeoDataFrame with a radius column. Returns ------- @@ -54,10 +56,7 @@ def _get_extent_of_circles(circles: GeoDataFrame) -> BoundingBoxDescription: bounds["maxx"] += circles["radius"] bounds["maxy"] += circles["radius"] - extent = {} - for ax in axes: - extent[ax] = (bounds[f"min{ax}"].min(), bounds[f"max{ax}"].max()) - return extent + return {ax: (bounds[f"min{ax}"].min(), bounds[f"max{ax}"].max()) for ax in axes} def _get_extent_of_polygons_multipolygons( @@ -69,6 +68,7 @@ def _get_extent_of_polygons_multipolygons( Parameters ---------- shapes + The shapes represented as a GeoDataFrame. Returns ------- @@ -77,21 +77,14 @@ def _get_extent_of_polygons_multipolygons( assert isinstance(shapes.geometry.iloc[0], (Polygon, MultiPolygon)) axes = get_axes_names(shapes) bounds = shapes["geometry"].bounds - # NOTE: this implies the order x, y (which is probably correct?) - extent = {} - for ax in axes: - extent[ax] = (bounds[f"min{ax}"].min(), bounds[f"max{ax}"].max()) - return extent + return {ax: (bounds[f"min{ax}"].min(), bounds[f"max{ax}"].max()) for ax in axes} def _get_extent_of_points(e: DaskDataFrame) -> BoundingBoxDescription: axes = get_axes_names(e) - min_coordinates = np.array([e[ax].min().compute() for ax in axes]) - max_coordinates = np.array([e[ax].max().compute() for ax in axes]) - extent = {} - for i, ax in enumerate(axes): - extent[ax] = (min_coordinates[i], max_coordinates[i]) - return extent + mins = dict(e[list(axes)].min().compute()) + maxs = dict(e[list(axes)].max().compute()) + return {ax: (mins[ax], maxs[ax]) for ax in axes} def _get_extent_of_data_array(e: DataArray, coordinate_system: str) -> BoundingBoxDescription: @@ -129,7 +122,7 @@ def get_extent( Parameters ---------- e - The SpatialData object or SpatialElement to computed the extent of. + The SpatialData object or SpatialElement to compute the extent of. Returns ------- diff --git a/tests/core/test_data_extent.py b/tests/core/test_data_extent.py index d94a2b88..204138c1 100644 --- a/tests/core/test_data_extent.py +++ b/tests/core/test_data_extent.py @@ -4,15 +4,17 @@ import pandas as pd import pytest from geopandas import GeoDataFrame +from numpy.random import default_rng from shapely.geometry import MultiPolygon, Point, Polygon from spatialdata import SpatialData, get_extent, transform from spatialdata._utils import _deepcopy_geodataframe from spatialdata.datasets import blobs -from spatialdata.models import PointsModel, ShapesModel +from spatialdata.models import Image2DModel, PointsModel, ShapesModel from spatialdata.transformations import Affine, Translation, remove_transformation, set_transformation # for faster tests; we will pay attention not to modify the original data sdata = blobs() +RNG = default_rng(seed=0) def check_test_results0(extent, min_coordinates, max_coordinates, axes): @@ -333,3 +335,11 @@ def test_get_extent_affine_sdata(): max_coordinates=max_coordinates1, axes=("x", "y"), ) + + +def test_bug_get_extent_swap_xy_for_images(): + # https://github.com/scverse/spatialdata/issues/335#issue-1842914360 + x = RNG.random((1, 10, 20)) + im = Image2DModel.parse(x, dims=("c", "x", "y")) + extent = get_extent(im) + check_test_results1(extent, {"x": (0.0, 10.0), "y": (0.0, 20.0)})