From 1395fb5e8ebf03f56e24851cf2ddaad0fa4d1317 Mon Sep 17 00:00:00 2001 From: LucaMarconato <2664412+LucaMarconato@users.noreply.github.com> Date: Tue, 10 Oct 2023 15:58:42 +0200 Subject: [PATCH] Fixes bug with extent of rotated data (#373) --- src/spatialdata/_core/data_extent.py | 186 +++++++++++++++---------- src/spatialdata/models/_utils.py | 7 +- tests/core/test_data_extent.py | 195 ++++++++++++++++++++++----- 3 files changed, 276 insertions(+), 112 deletions(-) diff --git a/src/spatialdata/_core/data_extent.py b/src/spatialdata/_core/data_extent.py index 491dad3f..eb77c771 100644 --- a/src/spatialdata/_core/data_extent.py +++ b/src/spatialdata/_core/data_extent.py @@ -2,7 +2,6 @@ from collections import defaultdict from functools import singledispatch -from typing import Union import numpy as np import pandas as pd @@ -15,7 +14,6 @@ from spatialdata._core.operations.transform import transform from spatialdata._core.spatialdata import SpatialData -from spatialdata._types import ArrayLike from spatialdata.models import get_axes_names from spatialdata.models._utils import SpatialElement from spatialdata.models.models import PointsModel @@ -86,32 +84,45 @@ def _get_extent_of_polygons_multipolygons( return extent +def _get_extent_of_points(e: DaskDataFrame) -> BoundingBoxDescription: + axes = get_axes_names(e) + min_coordinates = np.array([e[ax].min().compute() for ax in axes]) + max_coordinates = np.array([e[ax].max().compute() for ax in axes]) + extent = {} + for i, ax in enumerate(axes): + extent[ax] = (min_coordinates[i], max_coordinates[i]) + return extent + + def _get_extent_of_data_array(e: DataArray, coordinate_system: str) -> BoundingBoxDescription: # lightweight conversion to SpatialImage just to fix the type of the single-dispatch _check_element_has_coordinate_system(element=SpatialImage(e), coordinate_system=coordinate_system) # also here data_axes = get_axes_names(SpatialImage(e)) - min_coordinates = [] - max_coordinates = [] - axes = [] + extent: BoundingBoxDescription = {} for ax in ["z", "y", "x"]: if ax in data_axes: i = data_axes.index(ax) - axes.append(ax) - min_coordinates.append(0) - max_coordinates.append(e.shape[i]) + extent[ax] = (0, e.shape[i]) return _compute_extent_in_coordinate_system( # and here element=SpatialImage(e), coordinate_system=coordinate_system, - min_coordinates=np.array(min_coordinates), - max_coordinates=np.array(max_coordinates), - axes=tuple(axes), + extent=extent, ) @singledispatch -def get_extent(e: SpatialData | SpatialElement, coordinate_system: str = "global") -> BoundingBoxDescription: +def get_extent( + e: SpatialData | SpatialElement, + coordinate_system: str = "global", + exact: bool = True, + has_images: bool = True, + has_labels: bool = True, + has_points: bool = True, + has_shapes: bool = True, + elements: list[str] | None = None, +) -> BoundingBoxDescription: """ Get the extent (bounding box) of a SpatialData object or a SpatialElement. @@ -128,6 +139,37 @@ def get_extent(e: SpatialData | SpatialElement, coordinate_system: str = "global The maximum coordinate of the bounding box. axes The names of the dimensions of the bounding box + exact + If True, the extent is computed exactly. If False, an approximation faster to compute is given. The + approximation is guaranteed to contain all the data, see notes for details. + has_images + If True, images are included in the computation of the extent. + has_labels + If True, labels are included in the computation of the extent. + has_points + If True, points are included in the computation of the extent. + has_shapes + If True, shapes are included in the computation of the extent. + elements + If not None, only the elements with the given names are included in the computation of the extent. + + Notes + ----- + The extent of a SpatialData object is the extent of the union of the extents of all its elements. The extent of a + SpatialElement is the extent of the element in the coordinate system specified by the argument `coordinate_system`. + + If `exact` is False, first the extent of the SpatialElement before any transformation is computed. Then, the extent + is transformed to the target coordinate system. This is faster than computing the extent after the transformation, + since the transformation is applied to extent of the untransformed data, as opposed to transforming the data and + then computing the extent. + + The exact and approximate extent are the same if the transformation doesn't contain any rotation or shear, or in the + case in which the transformation is affine but all the corners of the extent of the untransformed data + (bounding box corners) are part of the dataset itself. Note that this is always the case for raster data. + + An extreme case is a dataset composed of the two points (0, 0) and (1, 1), rotated anticlockwise by 45 degrees. The + exact extent is the bounding box [minx, miny, maxx, maxy] = [0, 0, 0, 1.414], while the approximate extent is the + box [minx, miny, maxx, maxy] = [-0.707, 0, 0.707, 1.414]. """ raise ValueError("The object type is not supported.") @@ -136,11 +178,12 @@ def get_extent(e: SpatialData | SpatialElement, coordinate_system: str = "global def _( e: SpatialData, coordinate_system: str = "global", + exact: bool = True, has_images: bool = True, has_labels: bool = True, has_points: bool = True, has_shapes: bool = True, - elements: Union[list[str], None] = None, + elements: list[str] | None = None, ) -> BoundingBoxDescription: """ Get the extent (bounding box) of a SpatialData object: the extent of the union of the extents of all its elements. @@ -174,7 +217,10 @@ def _( assert isinstance(transformations, dict) coordinate_systems = list(transformations.keys()) if coordinate_system in coordinate_systems: - extent = get_extent(element_obj, coordinate_system=coordinate_system) + if isinstance(element_obj, (DaskDataFrame, GeoDataFrame)): + extent = get_extent(element_obj, coordinate_system=coordinate_system, exact=exact) + else: + extent = get_extent(element_obj, coordinate_system=coordinate_system) axes = list(extent.keys()) for ax in axes: new_min_coordinates_dict[ax] += [extent[ax][0]] @@ -183,8 +229,14 @@ def _( raise ValueError( f"The SpatialData object does not contain any element in the " f" coordinate system {coordinate_system!r}, " - f"please pass a different coordinate system wiht the argument 'coordinate_system'." + f"please pass a different coordinate system with the argument 'coordinate_system'." ) + if len(new_min_coordinates_dict) == 0: + raise ValueError( + f"The SpatialData object does not contain any element in the coordinate system {coordinate_system!r}, " + "please pass a different coordinate system with the argument 'coordinate_system'." + ) + axes = list(new_min_coordinates_dict.keys()) new_min_coordinates = np.array([min(new_min_coordinates_dict[ax]) for ax in axes]) new_max_coordinates = np.array([max(new_max_coordinates_dict[ax]) for ax in axes]) extent = {} @@ -193,8 +245,21 @@ def _( return extent +def _get_extent_of_shapes(e: GeoDataFrame) -> BoundingBoxDescription: + # remove potentially empty geometries + e_temp = e[e["geometry"].apply(lambda geom: not geom.is_empty)] + assert len(e_temp) > 0, "Cannot compute extent of an empty collection of geometries." + + # separate points from (multi-)polygons + first_geometry = e_temp["geometry"].iloc[0] + if isinstance(first_geometry, Point): + return _get_extent_of_circles(e) + assert isinstance(first_geometry, (Polygon, MultiPolygon)) + return _get_extent_of_polygons_multipolygons(e) + + @get_extent.register -def _(e: GeoDataFrame, coordinate_system: str = "global") -> BoundingBoxDescription: +def _(e: GeoDataFrame, coordinate_system: str = "global", exact: bool = True) -> BoundingBoxDescription: """ Compute the extent (bounding box) of a set of shapes. @@ -203,57 +268,33 @@ def _(e: GeoDataFrame, coordinate_system: str = "global") -> BoundingBoxDescript The bounding box description. """ _check_element_has_coordinate_system(element=e, coordinate_system=coordinate_system) - # remove potentially empty geometries - e_temp = e[e["geometry"].apply(lambda geom: not geom.is_empty)] - - # separate points from (multi-)polygons - e_points = e_temp[e_temp["geometry"].apply(lambda geom: isinstance(geom, Point))] - e_polygons = e_temp[e_temp["geometry"].apply(lambda geom: isinstance(geom, (Polygon, MultiPolygon)))] - extent = None - if len(e_points) > 0: - assert "radius" in e_points.columns, "Shapes that are points must have a 'radius' column." - extent = _get_extent_of_circles(e_points) - if len(e_polygons) > 0: - extent_polygons = _get_extent_of_polygons_multipolygons(e_polygons) - if extent is None: - extent = extent_polygons - else: - # case when there are points AND (multi-)polygons in the GeoDataFrame - extent["y"] = (min(extent["y"][0], extent_polygons["y"][0]), max(extent["y"][1], extent_polygons["y"][1])) - extent["x"] = (min(extent["x"][0], extent_polygons["x"][0]), max(extent["x"][1], extent_polygons["x"][1])) - - if extent is None: - raise ValueError( - "Unable to compute extent of GeoDataFrame. It needs to contain at least one non-empty " - "Point or Polygon or Multipolygon." + if not exact: + extent = _get_extent_of_shapes(e) + return _compute_extent_in_coordinate_system( + element=e, + coordinate_system=coordinate_system, + extent=extent, ) - - min_coordinates = [extent["y"][0], extent["x"][0]] - max_coordinates = [extent["y"][1], extent["x"][1]] - axes = tuple(extent.keys()) - - return _compute_extent_in_coordinate_system( - element=e_temp, - coordinate_system=coordinate_system, - min_coordinates=np.array(min_coordinates), - max_coordinates=np.array(max_coordinates), - axes=axes, - ) + t = get_transformation(e, to_coordinate_system=coordinate_system) + assert isinstance(t, BaseTransformation) + transformed = transform(e, t) + return _get_extent_of_shapes(transformed) @get_extent.register -def _(e: DaskDataFrame, coordinate_system: str = "global") -> BoundingBoxDescription: +def _(e: DaskDataFrame, coordinate_system: str = "global", exact: bool = True) -> BoundingBoxDescription: _check_element_has_coordinate_system(element=e, coordinate_system=coordinate_system) - axes = get_axes_names(e) - min_coordinates = np.array([e[ax].min().compute() for ax in axes]) - max_coordinates = np.array([e[ax].max().compute() for ax in axes]) - return _compute_extent_in_coordinate_system( - element=e, - coordinate_system=coordinate_system, - min_coordinates=min_coordinates, - max_coordinates=max_coordinates, - axes=axes, - ) + if not exact: + extent = _get_extent_of_points(e) + return _compute_extent_in_coordinate_system( + element=e, + coordinate_system=coordinate_system, + extent=extent, + ) + t = get_transformation(e, to_coordinate_system=coordinate_system) + assert isinstance(t, BaseTransformation) + transformed = transform(e, t) + return _get_extent_of_points(transformed) @get_extent.register @@ -275,16 +316,12 @@ def _check_element_has_coordinate_system(element: SpatialElement, coordinate_sys if coordinate_system not in coordinate_systems: raise ValueError( f"The element does not contain any coordinate system named {coordinate_system!r}, " - f"please pass a different coordinate system wiht the argument 'coordinate_system'." + f"please pass a different coordinate system with the argument 'coordinate_system'." ) def _compute_extent_in_coordinate_system( - element: SpatialElement | DataArray, - coordinate_system: str, - min_coordinates: ArrayLike, - max_coordinates: ArrayLike, - axes: tuple[str, ...], + element: SpatialElement | DataArray, coordinate_system: str, extent: BoundingBoxDescription ) -> BoundingBoxDescription: """ Transform the extent from the intrinsic coordinates of the element to the given coordinate system. @@ -295,12 +332,8 @@ def _compute_extent_in_coordinate_system( The SpatialElement. coordinate_system The coordinate system to transform the extent to. - min_coordinates - Min coordinates of the extent in the intrinsic coordinates of the element, expects [y_min, x_min]. - max_coordinates - Max coordinates of the extent in the intrinsic coordinates of the element, expects [y_max, x_max]. - axes - The min and max coordinates refer to. + extent + The extent in the intrinsic coordinates of the element. Returns ------- @@ -310,6 +343,11 @@ def _compute_extent_in_coordinate_system( assert isinstance(transformation, BaseTransformation) from spatialdata._core.query._utils import get_bounding_box_corners + axes = get_axes_names(element) + if "c" in axes: + axes = tuple(ax for ax in axes if ax != "c") + min_coordinates = np.array([extent[ax][0] for ax in axes]) + max_coordinates = np.array([extent[ax][1] for ax in axes]) corners = get_bounding_box_corners( axes=axes, min_coordinate=min_coordinates, diff --git a/src/spatialdata/models/_utils.py b/src/spatialdata/models/_utils.py index 2b93204c..cf139d5e 100644 --- a/src/spatialdata/models/_utils.py +++ b/src/spatialdata/models/_utils.py @@ -5,7 +5,6 @@ import dask.dataframe as dd import geopandas -from anndata import AnnData from dask.dataframe import DataFrame as DaskDataFrame from geopandas import GeoDataFrame from multiscale_spatial_image import MultiscaleSpatialImage @@ -167,15 +166,15 @@ def _(e: MultiscaleSpatialImage) -> tuple[str, ...]: @get_axes_names.register(GeoDataFrame) def _(e: GeoDataFrame) -> tuple[str, ...]: - all_dims = (Z, Y, X) + all_dims = (X, Y, Z) n = e.geometry.iloc[0]._ndim - dims = all_dims[-n:] + dims = all_dims[:n] _validate_dims(dims) return dims @get_axes_names.register(DaskDataFrame) -def _(e: AnnData) -> tuple[str, ...]: +def _(e: DaskDataFrame) -> tuple[str, ...]: valid_dims = (X, Y, Z) dims = tuple([c for c in valid_dims if c in e.columns]) _validate_dims(dims) diff --git a/tests/core/test_data_extent.py b/tests/core/test_data_extent.py index 5082b461..d94a2b88 100644 --- a/tests/core/test_data_extent.py +++ b/tests/core/test_data_extent.py @@ -1,26 +1,35 @@ +import math + import numpy as np import pandas as pd import pytest from geopandas import GeoDataFrame -from shapely.geometry import Polygon +from shapely.geometry import MultiPolygon, Point, Polygon from spatialdata import SpatialData, get_extent, transform from spatialdata._utils import _deepcopy_geodataframe from spatialdata.datasets import blobs from spatialdata.models import PointsModel, ShapesModel -from spatialdata.transformations import Translation, remove_transformation, set_transformation +from spatialdata.transformations import Affine, Translation, remove_transformation, set_transformation # for faster tests; we will pay attention not to modify the original data sdata = blobs() -def check_test_results(extent, min_coordinates, max_coordinates, axes): - assert np.allclose([extent["x"][0], extent["y"][0]], min_coordinates) - assert np.allclose([extent["x"][1], extent["y"][1]], max_coordinates) +def check_test_results0(extent, min_coordinates, max_coordinates, axes): + for i, ax in enumerate(axes): + assert np.isclose(extent[ax][0], min_coordinates[i]) + assert np.isclose(extent[ax][1], max_coordinates[i]) extend_axes = list(extent.keys()) extend_axes.sort() assert tuple(extend_axes) == axes +def check_test_results1(extent0, extent1): + assert extent0.keys() == extent1.keys() + for ax in extent0: + assert np.allclose(extent0[ax], extent1[ax]) + + @pytest.mark.parametrize("shape_type", ["circles", "polygons", "multipolygons"]) def test_get_extent_shapes(shape_type): extent = get_extent(sdata[f"blobs_{shape_type}"]) @@ -35,7 +44,7 @@ def test_get_extent_shapes(shape_type): min_coordinates = np.array([291.06219195, 197.06539872]) max_coordinates = np.array([389.3319439, 375.89584037]) - check_test_results( + check_test_results0( extent, min_coordinates=min_coordinates, max_coordinates=max_coordinates, @@ -46,7 +55,7 @@ def test_get_extent_shapes(shape_type): def test_get_extent_points(): # 2d case extent = get_extent(sdata["blobs_points"]) - check_test_results( + check_test_results0( extent, min_coordinates=np.array([12.0, 13.0]), max_coordinates=np.array([500.0, 498.0]), @@ -58,7 +67,7 @@ def test_get_extent_points(): df = pd.DataFrame(data, columns=["zeta", "x", "y"]) points_3d = PointsModel.parse(df, coordinates={"x": "x", "y": "y", "z": "zeta"}) extent_3d = get_extent(points_3d) - check_test_results( + check_test_results0( extent_3d, min_coordinates=np.array([2, 3, 1]), max_coordinates=np.array([5, 6, 4]), @@ -72,7 +81,7 @@ def test_get_extent_raster(raster_type, multiscale): raster = sdata[f"blobs_multiscale_{raster_type}"] if multiscale else sdata[f"blobs_{raster_type}"] extent = get_extent(raster) - check_test_results( + check_test_results0( extent, min_coordinates=np.array([0, 0]), max_coordinates=np.array([512, 512]), @@ -83,7 +92,7 @@ def test_get_extent_raster(raster_type, multiscale): def test_get_extent_spatialdata(): sdata2 = SpatialData(shapes={"circles": sdata["blobs_circles"], "polygons": sdata["blobs_polygons"]}) extent = get_extent(sdata2) - check_test_results( + check_test_results0( extent, min_coordinates=np.array([98.92618679, 137.62348969]), max_coordinates=np.array([446.70264371, 461.85209239]), @@ -100,7 +109,127 @@ def test_get_extent_invalid_coordinate_system(): _ = get_extent(sdata, coordinate_system="invalid") +def _rotate_point(point: tuple[float, float], angle_degrees=45) -> tuple[float, float]: + angle_radians = math.radians(angle_degrees) + x, y = point + + x_prime = x * math.cos(angle_radians) - y * math.sin(angle_radians) + y_prime = x * math.sin(angle_radians) + y * math.cos(angle_radians) + + return (x_prime, y_prime) + + +@pytest.mark.parametrize("exact", [True, False]) +def test_rotate_vector_data(exact): + """ + To test for the ability to correctly compute the exact and approximate extent of vector datasets. + In particular tests for the solution to this issue: https://github.com/scverse/spatialdata/issues/353 + """ + circles = [] + for p in [[0.5, 0.1], [0.9, 0.5], [0.5, 0.9], [0.1, 0.5]]: + circles.append(Point(p)) + circles_gdf = GeoDataFrame(geometry=circles) + circles_gdf["radius"] = 0.1 + circles_gdf = ShapesModel.parse(circles_gdf) + + polygons = [] + polygons.append(Polygon([(0.5, 0.5), (0.5, 0), (0.6, 0.1), (0.5, 0.5)])) + polygons.append(Polygon([(0.5, 0.5), (1, 0.5), (0.9, 0.6), (0.5, 0.5)])) + polygons.append(Polygon([(0.5, 0.5), (0.5, 1), (0.4, 0.9), (0.5, 0.5)])) + polygons.append(Polygon([(0.5, 0.5), (0, 0.5), (0.1, 0.4), (0.5, 0.5)])) + polygons_gdf = GeoDataFrame(geometry=polygons) + polygons_gdf = ShapesModel.parse(polygons_gdf) + + multipolygons = [] + multipolygons.append(MultiPolygon([polygons[0], Polygon([(0.7, 0.1), (0.9, 0.1), (0.9, 0.3), (0.7, 0.1)])])) + multipolygons.append(MultiPolygon([polygons[1], Polygon([(0.9, 0.7), (0.9, 0.9), (0.7, 0.9), (0.9, 0.7)])])) + multipolygons.append(MultiPolygon([polygons[2], Polygon([(0.3, 0.9), (0.1, 0.9), (0.1, 0.7), (0.3, 0.9)])])) + multipolygons.append(MultiPolygon([polygons[3], Polygon([(0.1, 0.3), (0.1, 0.1), (0.3, 0.1), (0.1, 0.3)])])) + multipolygons_gdf = GeoDataFrame(geometry=multipolygons) + multipolygons_gdf = ShapesModel.parse(multipolygons_gdf) + + points_df = PointsModel.parse(np.array([[0.5, 0], [1, 0.5], [0.5, 1], [0, 0.5]])) + + sdata = SpatialData( + shapes={"circles": circles_gdf, "polygons": polygons_gdf, "multipolygons": multipolygons_gdf}, + points={"points": points_df}, + ) + + theta = math.pi / 4 + rotation = Affine( + [ + [math.cos(theta), -math.sin(theta), 0], + [math.sin(theta), math.cos(theta), 0], + [0, 0, 1], + ], + input_axes=("x", "y"), + output_axes=("x", "y"), + ) + for element_name in ["circles", "polygons", "multipolygons", "points"]: + set_transformation(element=sdata[element_name], transformation=rotation, to_coordinate_system="transformed") + + # manually computing the extent results and verifying it is correct + for e in [sdata, circles_gdf, polygons_gdf, multipolygons_gdf, points_df]: + extent = get_extent(e, coordinate_system="global") + check_test_results1(extent, {"x": (0.0, 1.0), "y": (0.0, 1.0)}) + + EXPECTED_NON_EXACT = {"x": (-math.sqrt(2) / 2, math.sqrt(2) / 2), "y": (0.0, math.sqrt(2))} + extent = get_extent(circles_gdf, coordinate_system="transformed", exact=exact) + if exact: + expected = { + "x": (_rotate_point((0.1, 0.5))[0] - 0.1, _rotate_point((0.5, 0.1))[0] + 0.1), + "y": (_rotate_point((0.5, 0.1))[1] - 0.1, _rotate_point((0.9, 0.5))[1] + 0.1), + } + else: + expected = EXPECTED_NON_EXACT + check_test_results1(extent, expected) + + extent = get_extent(polygons_gdf, coordinate_system="transformed", exact=exact) + if exact: + expected = { + "x": (_rotate_point((0, 0.5))[0], _rotate_point((0.5, 0))[0]), + "y": (_rotate_point((0.5, 0))[1], _rotate_point((1, 0.5))[1]), + } + else: + expected = EXPECTED_NON_EXACT + check_test_results1(extent, expected) + + extent = get_extent(multipolygons_gdf, coordinate_system="transformed", exact=exact) + if exact: + expected = { + "x": (_rotate_point((0.1, 0.9))[0], _rotate_point((0.9, 0.1))[0]), + "y": (_rotate_point((0.1, 0.1))[1], _rotate_point((0.9, 0.9))[1]), + } + else: + expected = EXPECTED_NON_EXACT + check_test_results1(extent, expected) + + extent = get_extent(points_df, coordinate_system="transformed", exact=exact) + if exact: + expected = { + "x": (_rotate_point((0, 0.5))[0], _rotate_point((0.5, 0))[0]), + "y": (_rotate_point((0.5, 0))[1], _rotate_point((1, 0.5))[1]), + } + else: + expected = EXPECTED_NON_EXACT + check_test_results1(extent, expected) + + extent = get_extent(sdata, coordinate_system="transformed", exact=exact) + if exact: + expected = { + "x": (_rotate_point((0.1, 0.9))[0], _rotate_point((0.9, 0.1))[0]), + "y": (_rotate_point((0.1, 0.1))[1], _rotate_point((0.9, 0.9))[1]), + } + else: + expected = EXPECTED_NON_EXACT + check_test_results1(extent, expected) + + def test_get_extent_affine_circles(): + """ + Verify that the extent of the transformed circles, computed with exact = False, gives the same result as + transforming the bounding box of the original circles + """ from tests.core.operations.test_transform import _get_affine affine = _get_affine(small_translation=True) @@ -111,21 +240,21 @@ def test_get_extent_affine_circles(): set_transformation(element=circles, transformation=affine, to_coordinate_system="transformed") extent = get_extent(circles) - transformed_extent = get_extent(circles, coordinate_system="transformed") - - assert extent[2] == transformed_extent[2] - assert not np.allclose(extent[0], transformed_extent[0]) - assert not np.allclose(extent[1], transformed_extent[1]) + transformed_extent = get_extent(circles, coordinate_system="transformed", exact=False) - min_coordinates, max_coordinates, axes = extent + axes = list(extent.keys()) + transformed_axes = list(extent.keys()) + assert axes == transformed_axes + for ax in axes: + assert not np.allclose(extent[ax], transformed_extent[ax]) # Create a list of points points = [ - (min_coordinates[0], min_coordinates[1]), # lower left corner - (min_coordinates[0], max_coordinates[1]), # upper left corner - (max_coordinates[0], max_coordinates[1]), # upper right corner - (max_coordinates[0], min_coordinates[1]), # lower right corner - (min_coordinates[0], min_coordinates[1]), # back to start to close the polygon + (extent["x"][0], extent["y"][0]), # lower left corner + (extent["x"][0], extent["y"][1]), # upper left corner + (extent["x"][1], extent["y"][1]), # upper right corner + (extent["x"][1], extent["y"][0]), # lower right corner + (extent["x"][0], extent["y"][0]), # back to start to close the polygon ] # Create a Polygon from the points @@ -134,12 +263,11 @@ def test_get_extent_affine_circles(): gdf = ShapesModel.parse(gdf) transformed_bounding_box = transform(gdf, affine) - min_coordinates0, max_coordinates0, axes0 = transformed_extent - min_coordinates1, max_coordinates1, axes1 = get_extent(transformed_bounding_box) + transformed_bounding_box_extent = get_extent(transformed_bounding_box) - assert np.allclose(min_coordinates0, min_coordinates1) - assert np.allclose(max_coordinates0, max_coordinates1) - assert axes0 == axes1 + assert transformed_axes == list(transformed_bounding_box_extent.keys()) + for ax in transformed_axes: + assert np.allclose(transformed_extent[ax], transformed_bounding_box_extent[ax]) def test_get_extent_affine_points3d(): @@ -159,16 +287,15 @@ def test_get_extent_affine_points3d(): transformed_extent_2d = get_extent(points_2d, coordinate_system="transformed") transformed_extent_3d = get_extent(points_3d, coordinate_system="transformed") - assert transformed_extent_2d[2] == ("x", "y") - assert transformed_extent_3d[2] == ("x", "y", "z") + assert list(transformed_extent_2d.keys()) == ["x", "y"] + assert list(transformed_extent_3d.keys()) == ["x", "y", "z"] # the x and y extent for the 2d and 3d points are identical - assert np.allclose(transformed_extent_2d[0], transformed_extent_3d[0][:2]) - assert np.allclose(transformed_extent_2d[1], transformed_extent_3d[1][:2]) + for ax in ["x", "y"]: + assert np.allclose(transformed_extent_2d[ax], transformed_extent_3d[ax]) # the z extent for the 3d points didn't get transformed, so it's the same as the original - assert np.allclose(transformed_extent_3d[0][2], extent_3d[0][2]) - assert np.allclose(transformed_extent_3d[1][2], extent_3d[1][2]) + assert np.allclose(transformed_extent_3d["z"], extent_3d["z"]) def test_get_extent_affine_sdata(): @@ -193,14 +320,14 @@ def test_get_extent_affine_sdata(): min_coordinates1 = np.array([149.92618679, 188.62348969]) + np.array([1000.0, 0.0]) max_coordinates1 = np.array([446.70264371, 461.85209239]) + np.array([1000.0, 0.0]) - check_test_results( + check_test_results0( extent0, min_coordinates=min_coordinates0, max_coordinates=max_coordinates0, axes=("x", "y"), ) - check_test_results( + check_test_results0( extent1, min_coordinates=min_coordinates1, max_coordinates=max_coordinates1,