From 63e761df077b998a60e44b565696a9b9b79a46d0 Mon Sep 17 00:00:00 2001 From: giovp Date: Mon, 27 Nov 2023 16:27:19 +0100 Subject: [PATCH] fix tests and docs and remove --- src/spatialdata/dataloader/datasets.py | 4 ++-- src/spatialdata/models/models.py | 8 ++++---- tests/dataloader/test_datasets.py | 15 +++++++++++---- 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/src/spatialdata/dataloader/datasets.py b/src/spatialdata/dataloader/datasets.py index 52bf04e6..388db612 100644 --- a/src/spatialdata/dataloader/datasets.py +++ b/src/spatialdata/dataloader/datasets.py @@ -375,9 +375,9 @@ def _get_tile_coords( # get extent, first by checking shape defaults, then by using the `tile_dim_in_units` if tile_dim_in_units is None: - if elem.iloc[0][0].geom_type == "Point": + if elem.iloc[0, 0].geom_type == "Point": extent = elem[ShapesModel.RADIUS_KEY].values * tile_scale - elif elem.iloc[0][0].geom_type in ["Polygon", "MultiPolygon"]: + elif elem.iloc[0, 0].geom_type in ["Polygon", "MultiPolygon"]: extent = elem[ShapesModel.GEOMETRY_KEY].length * tile_scale else: raise ValueError("Only point and polygon shapes are supported.") diff --git a/src/spatialdata/models/models.py b/src/spatialdata/models/models.py index f7155a3b..e27a08c3 100644 --- a/src/spatialdata/models/models.py +++ b/src/spatialdata/models/models.py @@ -18,7 +18,7 @@ from multiscale_spatial_image import to_multiscale from multiscale_spatial_image.multiscale_spatial_image import MultiscaleSpatialImage from multiscale_spatial_image.to_multiscale.to_multiscale import Methods -from pandas.api.types import is_categorical_dtype +from pandas import CategoricalDtype from shapely._geometry import GeometryType from shapely.geometry import MultiPolygon, Point, Polygon from shapely.geometry.collection import GeometryCollection @@ -470,7 +470,7 @@ def validate(cls, data: DaskDataFrame) -> None: raise ValueError(f":attr:`dask.dataframe.core.DataFrame.attrs` does not contain `{cls.TRANSFORM_KEY}`.") if 
cls.ATTRS_KEY in data.attrs and "feature_key" in data.attrs[cls.ATTRS_KEY]: feature_key = data.attrs[cls.ATTRS_KEY][cls.FEATURE_KEY] - if not is_categorical_dtype(data[feature_key]): + if not isinstance(data[feature_key].dtype, CategoricalDtype): logger.info(f"Feature key `{feature_key}`could be of type `pd.Categorical`. Consider casting it.") @singledispatchmethod @@ -624,7 +624,7 @@ def _add_metadata_and_validate( # Here we are explicitly importing the categories # but it is a convenient way to ensure that the categories are known. # It also just changes the state of the series, so it is not a big deal. - if is_categorical_dtype(data[c]) and not data[c].cat.known: + if isinstance(data[c].dtype, CategoricalDtype) and not data[c].cat.known: try: data[c] = data[c].cat.set_categories(data[c].head(1).cat.categories) except ValueError: @@ -729,7 +729,7 @@ def parse( region_: list[str] = region if isinstance(region, list) else [region] if not adata.obs[region_key].isin(region_).all(): raise ValueError(f"`adata.obs[{region_key}]` values do not match with `{cls.REGION_KEY}` values.") - if not is_categorical_dtype(adata.obs[region_key]): + if not isinstance(adata.obs[region_key].dtype, CategoricalDtype): warnings.warn( f"Converting `{cls.REGION_KEY_KEY}: {region_key}` to categorical dtype.", UserWarning, stacklevel=2 ) diff --git a/tests/dataloader/test_datasets.py b/tests/dataloader/test_datasets.py index 3b99b104..dac01e80 100644 --- a/tests/dataloader/test_datasets.py +++ b/tests/dataloader/test_datasets.py @@ -60,14 +60,18 @@ def test_default(self, sdata_blobs, regions_element, raster): if raster: assert tile.shape == (3, 329, 329) else: - assert tile.shape == (3, 164, 164) + assert tile.shape == (3, 165, 164) else: raise ValueError(f"Unexpected regions_element: {regions_element}") + # extent has units in pixel so should be the same as tile shape if raster: assert round(ds.tiles_coords.extent.unique()[0] * 2) == tile.shape[1] else: - assert int(ds.tiles_coords.extent.unique()[0]) ==
tile.shape[1] + if regions_element != "blobs_multipolygons": + assert int(ds.tiles_coords.extent.unique()[0]) == tile.shape[1] + else: + assert int(ds.tiles_coords.extent.unique()[0]) + 1 == tile.shape[1] assert np.all(sdata_tile.table.obs.columns == ds.sdata.table.obs.columns) assert list(sdata_tile.images.keys())[0] == "blobs_image" @@ -88,11 +92,14 @@ def test_return_annot(self, sdata_blobs, regions_element, return_annot): elif regions_element == "blobs_polygons": assert tile.shape == (3, 82, 82) elif regions_element == "blobs_multipolygons": - assert tile.shape == (3, 164, 164) + assert tile.shape == (3, 165, 164) else: raise ValueError(f"Unexpected regions_element: {regions_element}") # extent has units in pixel so should be the same as tile shape - assert int(ds.tiles_coords.extent.unique()[0]) == tile.shape[1] + if regions_element != "blobs_multipolygons": + assert int(ds.tiles_coords.extent.unique()[0]) == tile.shape[1] + else: + assert round(ds.tiles_coords.extent.unique()[0]) + 1 == tile.shape[1] return_annot = [return_annot] if isinstance(return_annot, str) else return_annot assert annot.shape[1] == len(return_annot)