Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enhance + bugfix of images and labels elements #127

Merged
merged 28 commits into from
Feb 24, 2023
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
86665b7
don't remove coords
giovp Feb 5, 2023
632c112
update iter multiscale
giovp Feb 17, 2023
0410cea
merge
giovp Feb 17, 2023
6e4a050
update spatial-image and multiscale-spatial_image version
giovp Feb 17, 2023
6565b12
remove check that is now in multiscale spatial image
giovp Feb 19, 2023
4833e63
add coordinates assignment
giovp Feb 19, 2023
2d8bbd2
add coordinates to parser
giovp Feb 19, 2023
22b18a9
Merge branch 'main' into models/images/add-coordinates
giovp Feb 19, 2023
c99c5dd
fix tests
giovp Feb 20, 2023
f562a3b
update tests for 3d
giovp Feb 20, 2023
aef337b
add tests for labels
giovp Feb 20, 2023
0dc0459
update shapely
giovp Feb 20, 2023
97f8d7e
add some comments
giovp Feb 20, 2023
90e0cc3
improve validation per #115
giovp Feb 20, 2023
6140832
remove multiscale_factors and add scale_factors
giovp Feb 20, 2023
238fb8c
try fixing tests
giovp Feb 20, 2023
c67435a
updates
giovp Feb 20, 2023
3240659
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 20, 2023
55c0e7c
fix mypy
giovp Feb 20, 2023
4214abc
add channels to IO
giovp Feb 20, 2023
f65151e
update omero excluding it from labels
giovp Feb 20, 2023
448da01
use isel
giovp Feb 22, 2023
51d865d
read name from raster data
LucaMarconato Feb 23, 2023
697eae1
Merge branch 'main' into models/images/add-coordinates
LucaMarconato Feb 23, 2023
5245259
fix unpad_raster()
LucaMarconato Feb 23, 2023
0a752fe
all tests passing
LucaMarconato Feb 23, 2023
f20582a
fixed precommit
LucaMarconato Feb 23, 2023
7a9f140
remove reference to name
giovp Feb 24, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,12 @@ dependencies = [
"zarr",
# "ome_zarr",
"ome_zarr@git+https://github.com/LucaMarconato/ome-zarr-py@bug_fix_io",
"spatial_image",
"multiscale_spatial_image",
"spatial_image>=0.3.0",
"multiscale_spatial_image>=0.11.2",
"xarray-schema",
"pygeos",
"geopandas",
"shapely==2.0rc2",
"shapely>=2.0.1",
"rich",
"pyarrow",
"tqdm",
Expand Down
4 changes: 2 additions & 2 deletions spatialdata/_core/_transform_elements.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,10 @@ def _(data: MultiscaleSpatialImage, transformation: BaseTransformation) -> Multi
# assert np.allclose(almost_zero, np.zeros_like(almost_zero), rtol=2.)
try:
multiscale_factors.append(round(factors[0]))
except OverflowError as e:
except ValueError as e:
raise e
# mypy thinks that schema could be ShapesModel, PointsModel, ...
transformed_data = schema.parse(transformed_dask, dims=axes, multiscale_factors=multiscale_factors) # type: ignore[call-arg,arg-type]
transformed_data = schema.parse(transformed_dask, dims=axes, scale_factors=multiscale_factors) # type: ignore[call-arg,arg-type]
print(
"TODO: compose the transformation!!!! we need to put the previous one concatenated with the translation showen above. The translation operates before the other transformation"
)
Expand Down
113 changes: 99 additions & 14 deletions spatialdata/_core/core_utils.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import copy
from functools import singledispatch
from typing import Any, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union

import numpy as np
from anndata import AnnData
Expand Down Expand Up @@ -255,7 +255,7 @@ def get_default_coordinate_system(dims: tuple[str, ...]) -> NgffCoordinateSystem
@singledispatch
def get_dims(e: SpatialElement) -> tuple[str, ...]:
"""
Get the dimensions of a spatial element
Get the dimensions of a spatial element.

Parameters
----------
Expand All @@ -264,8 +264,7 @@ def get_dims(e: SpatialElement) -> tuple[str, ...]:

Returns
-------
dims
Dimensions of the spatial element (e.g. ("z", "y", "x"))
Dimensions of the spatial element (e.g. ("z", "y", "x"))
"""
raise TypeError(f"Unsupported type: {type(e)}")

Expand All @@ -278,16 +277,10 @@ def _(e: SpatialImage) -> tuple[str, ...]:

@get_dims.register(MultiscaleSpatialImage)
def _(e: MultiscaleSpatialImage) -> tuple[str, ...]:
# luca: I prefer this first method
d = dict(e["scale0"])
assert len(d) == 1
dims0 = d.values().__iter__().__next__().dims
assert isinstance(dims0, tuple)
# still, let's do a runtime check against the other method
variables = list(e[list(e.keys())[0]].variables)
dims1 = e[list(e.keys())[0]][variables[0]].dims
assert dims0 == dims1
return dims0
if "scale0" in e:
return tuple(i for i in e["scale0"].dims.keys())
else:
return tuple(i for i in e.dims.keys())


@get_dims.register(GeoDataFrame)
Expand All @@ -309,3 +302,95 @@ def _(e: AnnData) -> tuple[str, ...]:
valid_dims = (X, Y, Z)
dims = [c for c in valid_dims if c in e.columns]
return tuple(dims)


@singledispatch
def compute_coordinates(
    data: Union[SpatialImage, MultiscaleSpatialImage]
) -> Union[SpatialImage, MultiscaleSpatialImage]:
    """
    Compute and assign coordinates to a (Multiscale)SpatialImage.

    Parameters
    ----------
    data
        :class:`SpatialImage` or :class:`MultiscaleSpatialImage`.

    Returns
    -------
    :class:`SpatialImage` or :class:`MultiscaleSpatialImage` with coordinates assigned.

    Raises
    ------
    TypeError
        Fallback for any type without a registered implementation.
    """
    raise TypeError(f"Unsupported type: {type(data)}")


@compute_coordinates.register(SpatialImage)
def _(data: SpatialImage) -> SpatialImage:
    """Assign integer-spaced float coordinates (0.0, 1.0, ...) to each spatial axis of a single-scale image."""
    # Only the spatial axes get coordinates; the channel axis ("c") is left untouched.
    # np.float64 replaces the np.float_ alias (same dtype; the alias was removed in NumPy 2.0).
    coords: dict[str, ArrayLike] = {
        d: np.arange(data.sizes[d], dtype=np.float64) for d in data.sizes.keys() if d in ["x", "y", "z"]
    }
    return data.assign_coords(coords)


@compute_coordinates.register(MultiscaleSpatialImage)
def _(data: MultiscaleSpatialImage) -> MultiscaleSpatialImage:
    # Assign coordinates to every scale of a multiscale image so that all scales
    # share the coordinate frame of "scale0": coarser scales get the mean of the
    # full-resolution coordinates they cover.
    def _get_scale(transforms: dict[str, Any]) -> Optional[ArrayLike]:
        # Return the scale vector of the first transformation in the "global"
        # coordinate system that has a `scale` attribute.
        # NOTE(review): implicitly returns None when no such transformation
        # exists; the caller below does not guard against that — confirm a
        # Scale transformation is always present.
        for t in transforms["global"].transformations:
            if hasattr(t, "scale"):
                if TYPE_CHECKING:
                    assert isinstance(t.scale, np.ndarray)
                return t.scale

    def _compute_coords(max_: int, scale_f: Union[int, float]) -> ArrayLike:
        # Downsample the scale0 coordinates [0, 1, ..., max_-1] by scale_f,
        # taking the mean of each window ("trim" drops a ragged tail window).
        return (  # type: ignore[no-any-return]
            DataArray(np.linspace(0, max_, max_, endpoint=False, dtype=np.float_))
            .coarsen(dim_0=scale_f, boundary="trim", side="right")
            .mean()
            .values
        )

    # Spatial sizes of the full-resolution scale; the "c" axis is excluded.
    max_scale0 = {d: s for d, s in data["scale0"].sizes.items() if d in ["x", "y", "z"]}
    # All scales store a single data variable under the same name.
    img_name = list(data["scale0"].data_vars.keys())[0]
    out = {}

    for name, dt in data.items():
        # NOTE(review): this always reads "scale0", so max_scale is identical to
        # max_scale0 on every iteration — presumably intentional, since the
        # per-scale factors below are relative to scale0; confirm.
        max_scale = {d: s for d, s in data["scale0"].sizes.items() if d in ["x", "y", "z"]}
        if name == "scale0":
            # Full resolution: plain integer-spaced float coordinates.
            coords: dict[str, ArrayLike] = {d: np.arange(max_scale[d], dtype=np.float_) for d in max_scale.keys()}
            out[name] = dt[img_name].assign_coords(coords)
        else:
            # Coarser scale: derive coordinates from the scale factors stored
            # in the element's transformation metadata.
            scalef = _get_scale(dt[img_name].attrs["transform"])
            assert len(max_scale.keys()) == len(scalef), "Mismatch between coordinates and scales."  # type: ignore[arg-type]
            out[name] = dt[img_name].assign_coords(
                {k: _compute_coords(max_scale0[k], round(s)) for k, s in zip(max_scale.keys(), scalef)}  # type: ignore[arg-type]
            )
    return MultiscaleSpatialImage.from_dict(d=out)


@singledispatch
def get_channels(data: Any) -> list[Any]:
    """
    Return the list of channels of a raster element.

    Parameters
    ----------
    data
        The raster element to extract channels from.

    Returns
    -------
    The channels of the element.

    Raises
    ------
    ValueError
        Fallback for any type without a registered implementation.
    """
    raise ValueError(f"Cannot get channels from {type(data)}")


@get_channels.register
def _(data: SpatialImage) -> list[Any]:
    # The channel names are stored on the "c" coordinate of the image.
    channel_coord = data.coords["c"]
    return channel_coord.values.tolist()  # type: ignore[no-any-return]


@get_channels.register
def _(data: MultiscaleSpatialImage) -> list[Any]:
    """Return the channels of a multiscale image, checking that all scales agree."""
    # All scales store a single data variable under the same name; grab it once.
    name = list({list(data[i].data_vars.keys())[0] for i in data.keys()})[0]
    # Collect the "c" coordinate of every scale; there must be exactly one distinct value.
    channels = {tuple(data[i][name].coords["c"].values) for i in data.keys()}
    if len(channels) > 1:
        # Replaces a placeholder ValueError("TODO") with an informative message.
        raise ValueError(f"Channel coordinates differ across scales: {channels}.")
    return list(next(iter(channels)))
105 changes: 55 additions & 50 deletions spatialdata/_core/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
_get_transformations,
_set_transformations,
_validate_mapping_to_coordinate_system_type,
compute_coordinates,
get_dims,
)
from spatialdata._core.transformations import BaseTransformation, Identity
Expand Down Expand Up @@ -110,7 +111,7 @@ def parse(
data: Union[ArrayLike, DataArray, DaskArray],
dims: Optional[Sequence[str]] = None,
transformations: Optional[MappingToCoordinateSystem_t] = None,
multiscale_factors: Optional[ScaleFactors_t] = None,
scale_factors: Optional[ScaleFactors_t] = None,
method: Optional[Methods] = None,
chunks: Optional[Chunks_t] = None,
**kwargs: Any,
Expand All @@ -126,9 +127,9 @@ def parse(
Dimensions of the data.
transformations
Transformations to apply to the data.
multiscale_factors
scale_factors
Scale factors to apply for multiscale.
If not None, a :class:`multiscale_spatial_image.multiscale_spatial_image.MultiscaleSpatialImage` is returned.
If not None, a :class:`multiscale_spatial_image.MultiscaleSpatialImage` is returned.
method
Method to use for multiscale.
chunks
Expand All @@ -137,24 +138,26 @@ def parse(
Returns
-------
:class:`spatial_image.SpatialImage` or
:class:`multiscale_spatial_image.multiscale_spatial_image.MultiscaleSpatialImage`.
:class:`multiscale_spatial_image.MultiscaleSpatialImage`.
"""
# check if dims is specified and if it has correct values

# if dims is specified inside the data, get the value of dims from the data
if isinstance(data, DataArray) or isinstance(data, SpatialImage):
if not isinstance(data.data, DaskArray): # numpy -> dask
data.data = from_array(data.data)
if dims is not None:
if dims != data.dims:
if set(dims).symmetric_difference(data.dims):
raise ValueError(
f"`dims`: {dims} does not match `data.dims`: {data.dims}, please specify the dims only once."
)
else:
logger.info("`dims` is specified redundantly: found also inside `data`")
logger.info("`dims` is specified redundantly: found also inside `data`.")
else:
dims = data.dims # type: ignore[assignment]
dims = data.dims
# but if dims don't match the model's dims, throw error
if set(dims).symmetric_difference(cls.dims.dims):
raise ValueError(f"Wrong `dims`: {dims}. Expected {cls.dims.dims}.")
_reindex = lambda d: d
# if there are no dims in the data, use the model's dims or provided dims
elif isinstance(data, np.ndarray) or isinstance(data, DaskArray):
if not isinstance(data, DaskArray): # numpy -> dask
data = from_array(data)
Expand All @@ -178,51 +181,59 @@ def parse(
except ValueError:
raise ValueError(f"Cannot transpose arrays to match `dims`: {dims}. Try to reshape `data` or `dims`.")

# finally convert to spatial image
data = to_spatial_image(array_like=data, dims=cls.dims.dims, **kwargs)
assert isinstance(data, SpatialImage)
# TODO(giovp): drop coordinates for now until solution with IO.
data = data.drop(data.coords.keys())
# parse transformations
_parse_transformations(data, transformations)
if multiscale_factors is not None:
# check that the image pyramid doesn't contain axes that get collapsed and eventually truncates the list
# of downscaling factors to avoid this
adjusted_multiscale_factors: list[int] = []
assert isinstance(data, DataArray)
current_shape: ArrayLike = np.array(data.shape, dtype=float)
# multiscale_factors could be a dict, we don't support this case here (in the future this code and the
# more general case will be handled by multiscale-spatial-image)
assert isinstance(multiscale_factors, list)
for factor in multiscale_factors:
scale_vector = np.array([1.0 if ax == "c" else factor for ax in data.dims])
current_shape /= scale_vector
if current_shape.min() < 1:
logger.warning(
f"Detected a multiscale factor that would collapse an axis: truncating list of factors from {multiscale_factors} to {adjusted_multiscale_factors}"
)
break
adjusted_multiscale_factors.append(factor)
# convert to multiscale if needed
if scale_factors is not None:
parsed_transform = _get_transformations(data)
# delete transforms
del data.attrs["transform"]
data = to_multiscale(
data,
scale_factors=adjusted_multiscale_factors,
scale_factors=scale_factors,
method=method,
chunks=chunks,
)
_parse_transformations(data, parsed_transform)
assert isinstance(data, MultiscaleSpatialImage)
# recompute coordinates for (multiscale) spatial image
data = compute_coordinates(data)
return data

def validate(self, data: Union[SpatialImage, MultiscaleSpatialImage]) -> None:
if isinstance(data, SpatialImage):
super().validate(data)
elif isinstance(data, MultiscaleSpatialImage):
name = {list(data[i].data_vars.keys())[0] for i in data.keys()}
if len(name) > 1:
raise ValueError(f"Wrong name for datatree: {name}.")
name = list(name)[0]
for d in data:
super().validate(data[d][name])
@singledispatchmethod
def validate(self, data: Any) -> None:
    """
    Validate raster data against this schema.

    Parameters
    ----------
    data
        Data to validate.

    Raises
    ------
    ValueError
        Fallback for any type without a registered implementation.
    """
    raise ValueError(f"Unsupported data type: {type(data)}.")

@validate.register(SpatialImage)
def _(self, data: SpatialImage) -> None:
    # Single-scale image: delegate to the parent schema's validation.
    super().validate(data)

@validate.register(MultiscaleSpatialImage)
def _(self, data: MultiscaleSpatialImage) -> None:
    # Scales must be named "scale0", "scale1", ... in order.
    # enumerate() replaces zip(..., np.arange(...)): no NumPy array allocation
    # for a plain index sequence, same comparison behavior.
    for i, j in enumerate(data.keys()):
        k = f"scale{i}"
        if j != k:
            raise ValueError(f"Wrong key for multiscale data, found: `{j}`, expected: `{k}`.")
    # Every scale must store its single data variable under the same name.
    name = {list(data[i].data_vars.keys())[0] for i in data.keys()}
    if len(name) > 1:
        raise ValueError(f"Wrong name for datatree: `{name}`.")
    name = list(name)[0]
    # Validate the image at every scale against the parent schema.
    for d in data:
        super().validate(data[d][name])


class Labels2DModel(RasterSchema):
Expand All @@ -234,8 +245,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(
dims=self.dims,
array_type=self.array_type,
# suppressing the check of .attrs['transform']; see https://github.com/scverse/spatialdata/issues/115
# attrs=self.attrs,
*args,
**kwargs,
)
Expand All @@ -250,8 +259,6 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(
dims=self.dims,
array_type=self.array_type,
# suppressing the check of .attrs['transform']; see https://github.com/scverse/spatialdata/issues/115
# attrs=self.attrs,
*args,
**kwargs,
)
Expand All @@ -266,8 +273,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(
dims=self.dims,
array_type=self.array_type,
# suppressing the check of .attrs['transform']; see https://github.com/scverse/spatialdata/issues/115
# attrs=self.attrs,
attrs=self.attrs,
*args,
**kwargs,
)
Expand All @@ -282,8 +288,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(
dims=self.dims,
array_type=self.array_type,
# suppressing the check of .attrs['transform']; see https://github.com/scverse/spatialdata/issues/115
# attrs=self.attrs,
attrs=self.attrs,
*args,
**kwargs,
)
Expand Down
Loading