Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Partial reading #765

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 60 additions & 1 deletion src/spatialdata/_io/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@
import logging
import os.path
import re
import sys
import tempfile
import traceback
import warnings
from collections.abc import Generator, Mapping, Sequence
from contextlib import contextmanager
from enum import Enum
from functools import singledispatch
from pathlib import Path
from typing import Any
from typing import Any, Literal

import zarr
from anndata import AnnData
Expand Down Expand Up @@ -384,3 +387,59 @@ def save_transformations(sdata: SpatialData) -> None:
stacklevel=2,
)
sdata.write_transformations()


class BadFileHandleMethod(Enum):
    """Closed set of strategies for reacting to a file that fails to be read.

    The member values are the lowercase strings accepted for ``on_bad_files``,
    so ``BadFileHandleMethod("warn")`` maps a string onto its member.
    """

    # Let the exception raised while reading propagate to the caller.
    ERROR = "error"
    # Turn expected read exceptions into warnings and carry on.
    WARN = "warn"


@contextmanager
def handle_read_errors(
    on_bad_files: Literal[
        "error",
        "warn",
        BadFileHandleMethod.ERROR,
        BadFileHandleMethod.WARN,
    ],
    location: str,
    exc_types: tuple[type[Exception], ...],
) -> Generator[None, None, None]:
    """
    Handle read errors according to parameter `on_bad_files`.

    Parameters
    ----------
    on_bad_files
        Specifies what to do upon encountering an exception. Either a
        :class:`BadFileHandleMethod` member or its string value.
        Allowed values are:

        - 'error', let the exception be raised.
        - 'warn', convert the exception into a warning if it is one of the expected exception types.
    location
        String identifying the function call where the exception happened
    exc_types
        A tuple of expected exception classes that should be converted into warnings.

    Raises
    ------
    ValueError
        If `on_bad_files` is not a valid :class:`BadFileHandleMethod` value.
    Exception
        If `on_bad_files="error"`, all encountered exceptions are raised.
        If `on_bad_files="warn"`, any encountered exceptions not matching the `exc_types` are raised.
    """
    method = BadFileHandleMethod(on_bad_files)  # str to enum
    if method == BadFileHandleMethod.WARN:
        try:
            yield
        except exc_types as e:
            # Extract the original filename and line number from the exception and
            # create a warning from it, so that the warning points at the innermost
            # frame where the read actually failed rather than at this helper.
            exc_traceback = sys.exc_info()[-1]
            last_frame, lineno = list(traceback.walk_tb(exc_traceback))[-1]
            filename = last_frame.f_code.co_filename
            # Include the location (element path) in the warning message.
            # An exception raised without arguments has an empty ``args`` tuple;
            # indexing it would raise IndexError from inside this handler and mask
            # the original problem, so only append the detail when there is one.
            detail = f": {e.args[0]}" if e.args else ""
            message = f"{location}: {e.__class__.__name__}{detail}"
            warnings.warn_explicit(
                message=message,
                category=UserWarning,
                filename=filename,
                lineno=lineno,
            )
            # The exception is swallowed here: execution continues after the `with` block.
    else:  # method == BadFileHandleMethod.ERROR
        # Let it raise exceptions
        yield
65 changes: 40 additions & 25 deletions src/spatialdata/_io/io_table.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,29 @@
from __future__ import annotations

import os
from json import JSONDecodeError
from typing import Literal

import numpy as np
import zarr
from anndata import AnnData
from anndata import read_zarr as read_anndata_zarr
from anndata._io.specs import write_elem as write_adata
from ome_zarr.format import Format
from zarr.errors import ArrayNotFoundError

from spatialdata._io._utils import BadFileHandleMethod, handle_read_errors
from spatialdata._io.format import CurrentTablesFormat, TablesFormats, _parse_version
from spatialdata._logging import logger
from spatialdata.models import TableModel


def _read_table(
zarr_store_path: str, group: zarr.Group, subgroup: zarr.Group, tables: dict[str, AnnData]
zarr_store_path: str,
group: zarr.Group,
subgroup: zarr.Group,
tables: dict[str, AnnData],
on_bad_files: Literal[BadFileHandleMethod.ERROR, BadFileHandleMethod.WARN] = BadFileHandleMethod.ERROR,
) -> dict[str, AnnData]:
"""
Read in tables in the tables Zarr.group of a SpatialData Zarr store.
Expand All @@ -30,6 +38,8 @@ def _read_table(
The subgroup containing the tables.
tables
A dictionary of tables.
on_bad_files
Specifies what to do upon encountering a bad file, e.g. corrupted, invalid or missing files.

Returns
-------
Expand All @@ -40,33 +50,38 @@ def _read_table(
f_elem = subgroup[table_name]
f_elem_store = os.path.join(zarr_store_path, f_elem.path)

tables[table_name] = read_anndata_zarr(f_elem_store)
with handle_read_errors(
on_bad_files=on_bad_files,
location=f"{subgroup.path}/{table_name}",
exc_types=(JSONDecodeError, KeyError, ValueError, ArrayNotFoundError),
):
tables[table_name] = read_anndata_zarr(f_elem_store)

f = zarr.open(f_elem_store, mode="r")
version = _parse_version(f, expect_attrs_key=False)
assert version is not None
# since have just one table format, we currently read it but do not use it; if we ever change the format
# we can rename the two _ to format and implement the per-format read logic (as we do for shapes)
_ = TablesFormats[version]
f.store.close()
f = zarr.open(f_elem_store, mode="r")
version = _parse_version(f, expect_attrs_key=False)
assert version is not None
# since we have just one table format, we currently read it but do not use it; if we ever change the format
# we can rename the two _ to format and implement the per-format read logic (as we do for shapes)
_ = TablesFormats[version]
f.store.close()

# # replace with format from above
# version = "0.1"
# format = TablesFormats[version]
if TableModel.ATTRS_KEY in tables[table_name].uns:
# fill out eventual missing attributes that has been omitted because their value was None
attrs = tables[table_name].uns[TableModel.ATTRS_KEY]
if "region" not in attrs:
attrs["region"] = None
if "region_key" not in attrs:
attrs["region_key"] = None
if "instance_key" not in attrs:
attrs["instance_key"] = None
# fix type for region
if "region" in attrs and isinstance(attrs["region"], np.ndarray):
attrs["region"] = attrs["region"].tolist()
# # replace with format from above
# version = "0.1"
# format = TablesFormats[version]
if TableModel.ATTRS_KEY in tables[table_name].uns:
# fill in any missing attributes that were omitted because their value was None
attrs = tables[table_name].uns[TableModel.ATTRS_KEY]
if "region" not in attrs:
attrs["region"] = None
if "region_key" not in attrs:
attrs["region_key"] = None
if "instance_key" not in attrs:
attrs["instance_key"] = None
# fix type for region
if "region" in attrs and isinstance(attrs["region"], np.ndarray):
attrs["region"] = attrs["region"].tolist()

count += 1
count += 1

logger.debug(f"Found {count} elements in {subgroup}")
return tables
Expand Down
Loading
Loading