refactor data loader #299

Merged
merged 47 commits into main from feat/dataloader on Dec 14, 2023
Changes from 27 commits

47 commits
45ab2df
fix docs
giovp Jun 15, 2023
0a7eecb
update
giovp Jun 15, 2023
6cff536
get tile centroid and extent outside of function
giovp Jun 16, 2023
638dd33
add return type
giovp Jun 17, 2023
1736377
move return table out of init
giovp Jun 17, 2023
1ee6abb
add comments
giovp Jun 17, 2023
288e28a
merge branch 'main' into feat/dataloader
giovp Jun 20, 2023
9eb6854
Merge branch 'main' into feat/dataloader
giovp Jun 20, 2023
9079490
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 20, 2023
38cfcff
update precommit
giovp Jun 20, 2023
5dbd876
Merge branch 'main' into feat/dataloader
giovp Jun 21, 2023
b9a5b9e
simplify return
giovp Jun 21, 2023
1d2f3cc
update tests and simplify
giovp Jun 21, 2023
dd03cdf
add tests
giovp Jun 22, 2023
f846caf
fix tests for pandas
giovp Jun 22, 2023
58b15c8
fix import and docs
giovp Jun 22, 2023
43065a9
update api
giovp Jun 22, 2023
556b87b
update import
giovp Jun 22, 2023
ffe73ec
fix imports
giovp Jun 23, 2023
05f5db7
update api
giovp Jun 23, 2023
bba86ef
fix api
giovp Jun 23, 2023
b37f40c
try fix docs
giovp Jun 23, 2023
a96b6fd
try fix docs
giovp Jun 23, 2023
ceb6f5b
add optional import of dataloader
giovp Jun 23, 2023
d290845
minor fixes
giovp Jun 23, 2023
7878379
fix test
giovp Jun 27, 2023
2f0b5e5
Merge branch 'main' into feat/dataloader
giovp Jun 29, 2023
9b84084
fixed typos
LucaMarconato Jun 29, 2023
13da142
Update src/spatialdata/dataloader/datasets.py
giovp Jul 13, 2023
13dc9be
update
giovp Jul 13, 2023
6ef34a5
update with more comments
giovp Jul 14, 2023
a65a855
fix precommit
giovp Jul 14, 2023
8ef16eb
modified docstring for transform in ImageTilesDataset
LucaMarconato Jul 14, 2023
d24bd7c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 14, 2023
61f347b
Merge branch 'main' into feat/dataloader
giovp Jul 19, 2023
f32d472
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 19, 2023
fc0528a
Merge branch 'main' into feat/dataloader
giovp Jul 26, 2023
854360d
Merge branch 'main' into feat/dataloader
giovp Nov 2, 2023
de79554
Merge branch 'main' into feat/dataloader
LucaMarconato Nov 2, 2023
77a590f
Merge branch 'main' into feat/dataloader
giovp Nov 27, 2023
95b5b49
try fixing docs
giovp Nov 27, 2023
16977e0
fix tests again
giovp Nov 27, 2023
4923ee6
fix tests
giovp Nov 27, 2023
63066ea
update
giovp Nov 27, 2023
000f83b
fix docs
giovp Nov 27, 2023
fa539f5
update
giovp Nov 27, 2023
63e761d
fix tests and docs and remove
giovp Nov 27, 2023
2 changes: 1 addition & 1 deletion .readthedocs.yaml
@@ -6,7 +6,7 @@ build:
python: "3.10"
sphinx:
configuration: docs/conf.py
fail_on_warning: false
fail_on_warning: true
python:
install:
- method: pip
6 changes: 5 additions & 1 deletion docs/api.md
@@ -29,7 +29,6 @@ Operations on `SpatialData` objects.
match_table_to_element
concatenate
rasterize
transform
Member comment: restoring this, since it's imported from spatialdata._core.operations

aggregate
```

@@ -48,6 +47,7 @@ The elements (building-blocks) that consitute `SpatialData`.

```{eval-rst}
.. currentmodule:: spatialdata.models

.. autosummary::
:toctree: generated

@@ -63,6 +63,8 @@ The elements (building-blocks) that consitute `SpatialData`.
### Utilities

```{eval-rst}
.. currentmodule:: spatialdata.models

.. autosummary::
:toctree: generated

@@ -96,6 +98,8 @@ The transformations that can be defined between elements and coordinate systems
### Utilities

```{eval-rst}
.. currentmodule:: spatialdata.transformations

.. autosummary::
:toctree: generated

1 change: 0 additions & 1 deletion pyproject.toml
@@ -49,7 +49,6 @@ dev = [
docs = [
"sphinx>=4.5",
"sphinx-book-theme>=1.0.0",
"sphinx_rtd_theme",
"myst-nb",
"sphinxcontrib-bibtex>=1.0.0",
"sphinx-autodoc-typehints",
8 changes: 2 additions & 6 deletions src/spatialdata/_core/concatenate.py
@@ -2,14 +2,12 @@

from copy import copy # Should probably go up at the top
from itertools import chain
from typing import TYPE_CHECKING, Any
from typing import Any

import numpy as np
from anndata import AnnData

if TYPE_CHECKING:
from spatialdata._core.spatialdata import SpatialData

from spatialdata._core.spatialdata import SpatialData
from spatialdata.models import TableModel

__all__ = [
@@ -94,8 +92,6 @@ def concatenate(
-------
The concatenated :class:`spatialdata.SpatialData` object.
"""
from spatialdata import SpatialData

merged_images = {**{k: v for sdata in sdatas for k, v in sdata.images.items()}}
if len(merged_images) != np.sum([len(sdata.images) for sdata in sdatas]):
raise KeyError("Images must have unique names across the SpatialData objects to concatenate")
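
A note on the recurring pattern in this and the following diffs: the `if TYPE_CHECKING:` guard and the per-function local imports of `SpatialData` are replaced by a single direct import from the defining submodule. The sketch below illustrates the idea with a simplified `concatenate` stub; the stub signature and body are placeholders, not the real function.

```python
from __future__ import annotations

# Before (sketched): SpatialData existed only for the type checker, and every
# function that needed it at runtime repeated a local import to avoid the
# circular import between spatialdata/__init__.py and this module.
#
#     if TYPE_CHECKING:
#         from spatialdata import SpatialData
#
#     def concatenate(sdatas: list[SpatialData]) -> SpatialData:
#         from spatialdata import SpatialData
#         ...

# After: import straight from the submodule that defines the class. This skips
# the top-level spatialdata/__init__.py, so no import cycle is created and the
# local imports can be dropped.
from spatialdata._core.spatialdata import SpatialData


def concatenate(sdatas: list[SpatialData]) -> SpatialData:  # placeholder stub
    ...
```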
7 changes: 2 additions & 5 deletions src/spatialdata/_core/operations/aggregate.py
@@ -1,7 +1,7 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any
from typing import Any

import anndata as ad
import dask as da
@@ -20,6 +20,7 @@
from spatialdata._core.operations.transform import transform
from spatialdata._core.query._utils import circles_to_polygons
from spatialdata._core.query.relational_query import get_values
from spatialdata._core.spatialdata import SpatialData
from spatialdata._types import ArrayLike
from spatialdata.models import (
Image2DModel,
@@ -32,9 +33,6 @@
)
from spatialdata.transformations import BaseTransformation, Identity, get_transformation

if TYPE_CHECKING:
from spatialdata import SpatialData

__all__ = ["aggregate"]


@@ -236,7 +234,6 @@ def _create_sdata_from_table_and_shapes(
instance_key: str,
deepcopy: bool,
) -> SpatialData:
from spatialdata import SpatialData
from spatialdata._utils import _deepcopy_geodataframe

table.obs[instance_key] = table.obs_names.copy()
2 changes: 0 additions & 2 deletions src/spatialdata/_core/operations/rasterize.py
@@ -207,8 +207,6 @@ def _(
target_height: Optional[float] = None,
target_depth: Optional[float] = None,
) -> SpatialData:
from spatialdata import SpatialData

min_coordinate = _parse_list_into_array(min_coordinate)
max_coordinate = _parse_list_into_array(max_coordinate)

6 changes: 2 additions & 4 deletions src/spatialdata/_core/query/relational_query.py
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from typing import Any

import dask.array as da
import numpy as np
@@ -11,6 +11,7 @@
from multiscale_spatial_image import MultiscaleSpatialImage
from spatial_image import SpatialImage

from spatialdata._core.spatialdata import SpatialData
from spatialdata._utils import _inplace_fix_subset_categorical_obs
from spatialdata.models import (
Labels2DModel,
@@ -22,9 +23,6 @@
get_model,
)

if TYPE_CHECKING:
from spatialdata import SpatialData


def _filter_table_by_coordinate_system(table: AnnData | None, coordinate_system: str | list[str]) -> AnnData | None:
"""
1 change: 0 additions & 1 deletion src/spatialdata/_core/query/spatial_query.py
@@ -305,7 +305,6 @@ def _(
target_coordinate_system: str,
filter_table: bool = True,
) -> SpatialData:
from spatialdata import SpatialData
from spatialdata._core.query.relational_query import _filter_table_by_elements

min_coordinate = _parse_list_into_array(min_coordinate)
30 changes: 18 additions & 12 deletions src/spatialdata/_core/spatialdata.py
@@ -18,27 +18,17 @@
from pyarrow.parquet import read_table
from spatial_image import SpatialImage

from spatialdata._io import (
write_image,
write_labels,
write_points,
write_shapes,
write_table,
)
from spatialdata._io._utils import get_backing_files
from spatialdata._logging import logger
from spatialdata._types import ArrayLike
from spatialdata._utils import _natural_keys
from spatialdata.models import (
from spatialdata.models._utils import SpatialElement, get_axes_names
from spatialdata.models.models import (
Image2DModel,
Image3DModel,
Labels2DModel,
Labels3DModel,
PointsModel,
ShapesModel,
SpatialElement,
TableModel,
get_axes_names,
get_model,
)

@@ -653,6 +643,9 @@ def add_image(
-----
If the SpatialData object is backed by a Zarr storage, the image will be written to the Zarr storage.
"""
from spatialdata._io._utils import get_backing_files
from spatialdata._io.io_raster import write_image

if self.is_backed():
files = get_backing_files(image)
assert self.path is not None
@@ -736,6 +729,9 @@ def add_labels(
-----
If the SpatialData object is backed by a Zarr storage, the image will be written to the Zarr storage.
"""
from spatialdata._io._utils import get_backing_files
from spatialdata._io.io_raster import write_labels

if self.is_backed():
files = get_backing_files(labels)
assert self.path is not None
@@ -820,6 +816,9 @@ def add_points(
-----
If the SpatialData object is backed by a Zarr storage, the image will be written to the Zarr storage.
"""
from spatialdata._io._utils import get_backing_files
from spatialdata._io.io_points import write_points

if self.is_backed():
files = get_backing_files(points)
assert self.path is not None
@@ -902,6 +901,8 @@ def add_shapes(
-----
If the SpatialData object is backed by a Zarr storage, the image will be written to the Zarr storage.
"""
from spatialdata._io.io_shapes import write_shapes

self._add_shapes_in_memory(name=name, shapes=shapes, overwrite=overwrite)
if self.is_backed():
elem_group = self._init_add_element(name=name, element_type="shapes", overwrite=overwrite)
@@ -918,6 +919,8 @@ def write(
storage_options: JSONDict | list[JSONDict] | None = None,
overwrite: bool = False,
) -> None:
from spatialdata._io import write_image, write_labels, write_points, write_shapes, write_table

"""Write the SpatialData object to Zarr."""
if isinstance(file_path, str):
file_path = Path(file_path)
@@ -1113,6 +1116,8 @@ def table(self, table: AnnData) -> None:
The table needs to pass validation (see :class:`~spatialdata.TableModel`).
If the SpatialData object is backed by a Zarr storage, the table will be written to the Zarr storage.
"""
from spatialdata._io.io_table import write_table

TableModel().validate(table)
if self.table is not None:
raise ValueError("The table already exists. Use del sdata.table to remove it first.")
@@ -1199,6 +1204,7 @@ def _gen_repr(
-------
The string representation of the SpatialData object.
"""
from spatialdata._utils import _natural_keys

def rreplace(s: str, old: str, new: str, occurrence: int) -> str:
li = s.rsplit(old, occurrence)
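
The changes to spatialdata.py are the other half of the same refactor: the write helpers from spatialdata._io are now imported inside the methods that use them rather than at module level, so spatialdata._core can be imported without pulling in the I/O layer. Below is a minimal two-file sketch of why deferring the import breaks the cycle; pkg/core.py and pkg/io.py are hypothetical stand-ins, not the real modules.

```python
# pkg/core.py  (hypothetical stand-in for spatialdata/_core/spatialdata.py)
class Store:
    """Owns the in-memory data; delegates serialisation to the I/O layer."""

    def write(self, path: str) -> None:
        # Deferred import: pkg.io imports Store from pkg.core at module level,
        # so importing pkg.io at the top of this file would be a circular
        # import. Importing at call time means pkg.core is already initialised.
        from pkg.io import write_store

        write_store(self, path)


# pkg/io.py  (hypothetical counterpart)
#
# from pkg.core import Store   # safe: pkg.core defers its import of pkg.io
#
# def write_store(store: Store, path: str) -> None:
#     ...  # serialise `store` to disk at `path`
```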
8 changes: 2 additions & 6 deletions src/spatialdata/_io/_utils.py
@@ -8,7 +8,7 @@
from collections.abc import Generator, Mapping
from contextlib import contextmanager
from functools import singledispatch
from typing import TYPE_CHECKING, Any
from typing import Any

import zarr
from dask.dataframe.core import DataFrame as DaskDataFrame
@@ -18,6 +18,7 @@
from spatial_image import SpatialImage
from xarray import DataArray

from spatialdata._core.spatialdata import SpatialData
from spatialdata._utils import iterate_pyramid_levels
from spatialdata.models._utils import (
MappingToCoordinateSystem_t,
@@ -30,9 +31,6 @@
_get_current_output_axes,
)

if TYPE_CHECKING:
from spatialdata import SpatialData


# suppress logger debug from ome_zarr with context manager
@contextmanager
@@ -196,8 +194,6 @@ def _are_directories_identical(


def _compare_sdata_on_disk(a: SpatialData, b: SpatialData) -> bool:
from spatialdata import SpatialData

if not isinstance(a, SpatialData) or not isinstance(b, SpatialData):
return False
# TODO: if the sdata object is backed on disk, don't create a new zarr file
2 changes: 1 addition & 1 deletion src/spatialdata/_io/io_zarr.py
@@ -7,7 +7,7 @@
from anndata import AnnData
from anndata import read_zarr as read_anndata_zarr

from spatialdata import SpatialData
from spatialdata._core.spatialdata import SpatialData
from spatialdata._io._utils import ome_zarr_logger
from spatialdata._io.io_points import _read_points
from spatialdata._io.io_raster import _read_multiscale
5 changes: 1 addition & 4 deletions src/spatialdata/_utils.py
@@ -3,7 +3,7 @@
import re
from collections.abc import Generator
from copy import deepcopy
from typing import TYPE_CHECKING, Union
from typing import Union

import numpy as np
import pandas as pd
@@ -26,9 +26,6 @@
# I was using "from numbers import Number" but this led to mypy errors, so I switched to the following:
Number = Union[int, float]

if TYPE_CHECKING:
pass


def _parse_list_into_array(array: list[Number] | ArrayLike) -> ArrayLike:
if isinstance(array, list):
8 changes: 3 additions & 5 deletions src/spatialdata/dataloader/__init__.py
@@ -1,6 +1,4 @@
import contextlib

with contextlib.suppress(ImportError):
try:
from spatialdata.dataloader.datasets import ImageTilesDataset

__all__ = ["ImageTilesDataset"]
except ImportError:
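
The dataloader/__init__.py change replaces contextlib.suppress(ImportError) with an explicit try/except that binds ImageTilesDataset to None when the optional dependency (assumed here to be PyTorch) is missing. One practical consequence, sketched below, is that the name is always importable, so downstream code can test it instead of catching an ImportError of its own:

```python
from spatialdata.dataloader import ImageTilesDataset

# With the old contextlib.suppress(ImportError) version, the name was simply
# absent when the optional dependency could not be imported, so the import
# above raised ImportError itself. After this PR the name always exists and is
# None when unavailable, allowing an explicit, friendlier check:
if ImageTilesDataset is None:
    raise ImportError(
        "ImageTilesDataset is unavailable; install the optional PyTorch "
        "dependency to use the spatialdata data loader."
    )
```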
ImageTilesDataset = None # type: ignore[assignment, misc]
Loading