Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

seqFISH: update seqfish reader #227

Merged
merged 14 commits into from
Dec 12, 2024
13 changes: 8 additions & 5 deletions src/spatialdata_io/_constants/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,14 +66,15 @@ class SeqfishKeys(ModeEnum):
# file extensions
CSV_FILE = ".csv"
TIFF_FILE = ".tiff"
OME_TIFF_FILE = ".ome.tiff"
GEOJSON_FILE = ".geojson"
# file identifiers
SECTION = "section"
TRANSCRIPT_COORDINATES = "TranscriptCoordinates"
ROI = "Roi"
TRANSCRIPT_COORDINATES = "TranscriptList"
DAPI = "DAPI"
COUNTS_FILE = "CxG"
CELL_MASK_FILE = "CellMask"
COUNTS_FILE = "CellxGene"
SEGMENTATION = "Segmentation"
CELL_COORDINATES = "CellCoordinates"
BOUNDARIES = "Boundaries"
# transcripts
TRANSCRIPTS_X = "x"
TRANSCRIPTS_Y = "y"
Expand All @@ -87,6 +88,8 @@ class SeqfishKeys(ModeEnum):
SPATIAL_KEY = "spatial"
REGION_KEY = "region"
INSTANCE_KEY_TABLE = "instance_id"
SCALEFEFACTOR_X = "PhysicalSizeX"
SCALEFEFACTOR_Y = "PhysicalSizeY"


@unique
Expand Down
237 changes: 149 additions & 88 deletions src/spatialdata_io/readers/seqfish.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import os
import re
import xml.etree.ElementTree as ET
from collections.abc import Mapping
from pathlib import Path
from types import MappingProxyType
Expand All @@ -10,6 +11,7 @@
import anndata as ad
import numpy as np
import pandas as pd
import tifffile
from dask_image.imread import imread
from spatialdata import SpatialData
from spatialdata.models import (
Expand All @@ -19,7 +21,7 @@
ShapesModel,
TableModel,
)
from spatialdata.transformations import Identity
from spatialdata.transformations.transformations import Identity, Scale

from spatialdata_io._constants._constants import SeqfishKeys as SK
from spatialdata_io._docs import inject_docs
Expand All @@ -33,19 +35,22 @@ def seqfish(
load_images: bool = True,
load_labels: bool = True,
load_points: bool = True,
sections: list[int] | None = None,
load_shapes: bool = True,
cells_as_circles: bool = False,
rois: list[int] | None = None,
imread_kwargs: Mapping[str, Any] = MappingProxyType({}),
raster_models_scale_factors: list[float] | None = None,
) -> SpatialData:
"""
Read *seqfish* formatted dataset.

This function reads the following files:

- ```{vx.COUNTS_FILE!r}{vx.SECTION!r}{vx.CSV_FILE!r}```: Counts and metadata file.
- ```{vx.CELL_COORDINATES!r}{vx.SECTION!r}{vx.CSV_FILE!r}```: Cell coordinates file.
- ```{vx.DAPI!r}{vx.SECTION!r}{vx.OME_TIFF_FILE!r}```: High resolution tiff image.
- ```{vx.CELL_MASK_FILE!r}{vx.SECTION!r}{vx.TIFF_FILE!r}```: Cell mask file.
- ```{vx.TRANSCRIPT_COORDINATES!r}{vx.SECTION!r}{vx.CSV_FILE!r}```: Transcript coordinates file.
- ```{vx.ROI!r}{vx.COUNTS_FILE!r}{vx.CSV_FILE!r}```: Counts and metadata file.
- ```{vx.ROI!r}{vx.CELL_COORDINATES!r}{vx.CSV_FILE!r}```: Cell coordinates file.
- ```{vx.ROI!r}{vx.DAPI!r}{vx.TIFF_FILE!r}```: High resolution tiff image.
- ```{vx.ROI!r}{vx.SEGMENTATION!r}{vx.TIFF_FILE!r}```: Cell mask file.
- ```{vx.ROI!r}{vx.TRANSCRIPT_COORDINATES!r}{vx.CSV_FILE!r}```: Transcript coordinates file.

.. seealso::

Expand All @@ -58,133 +63,189 @@ def seqfish(
load_images
Whether to load the images.
load_labels
Whether to load the labels.
Whether to load cell segmentation.
load_points
Whether to load the points.
sections
Which sections (specified as integers) to load. By default, all sections are loaded.
Whether to load the transcript locations.
load_shapes
Whether to load cells as shape.
cells_as_circles
Whether to read cells also as circles instead of labels.
rois
Which ROIs (specified as integers) to load. Only necessary if multiple ROIs present.
imread_kwargs
Keyword arguments to pass to :func:`dask_image.imread.imread`.

Returns
-------
:class:`spatialdata.SpatialData`

Examples
--------
This code shows how to change the annotation target of the table from the cell labels to the cell boundaries.
Please check that the string Roi1 is used in the naming of your dataset, otherwise adjust the code below.
>>> from spatialdata_io import seqfish
>>> sdata = seqfish("path/to/raw/data")
>>> sdata["table_Roi1"].obs["region"] = "Roi1_Boundaries"
>>> sdata.set_table_annotates_spatialelement(
... table_name="table_Roi1", region="Roi1_Boundaries", region_key="region", instance_key="instance_id"
... )
>>> sdata.write("path/to/data.zarr")
"""
path = Path(path)
count_file_pattern = re.compile(rf"(.*?)_{SK.CELL_COORDINATES}_{SK.SECTION}[0-9]+" + re.escape(SK.CSV_FILE))
count_files = [i for i in os.listdir(path) if count_file_pattern.match(i)]
count_file_pattern = re.compile(rf"(.*?){re.escape(SK.CELL_COORDINATES)}{re.escape(SK.CSV_FILE)}$")
count_files = [f for f in os.listdir(path) if count_file_pattern.match(f)]
if not count_files:
# no file matching the pattern found
raise ValueError(
f"No files matching the pattern {count_file_pattern} were found. Cannot infer the naming scheme."
)
matched = count_file_pattern.match(count_files[0])
if matched is None:
raise ValueError(f"File {count_files[0]} does not match the pattern {count_file_pattern}")
prefix = matched.group(1)

n = len(count_files)
all_sections = list(range(1, n + 1))
if sections is None:
sections = all_sections

roi_pattern = re.compile(f"^{SK.ROI}(\\d+)")
found_rois = {m.group(1) for i in os.listdir(path) if (m := roi_pattern.match(i))}
if rois is None:
rois_str = [f"{SK.ROI}{roi}" for roi in found_rois]
elif isinstance(rois, list):
for roi in rois:
if str(roi) not in found_rois:
raise ValueError(f"ROI{roi} not found.")
rois_str = [f"{SK.ROI}{roi}" for roi in rois]
else:
for section in sections:
if section not in all_sections:
raise ValueError(f"Section {section} not found in the data.")
sections_str = [f"{SK.SECTION}{x}" for x in sections]
raise ValueError("Invalid type for 'roi'. Must be list[int] or None.")

def get_cell_file(roi: str) -> str:
return f"{roi}_{SK.CELL_COORDINATES}{SK.CSV_FILE}"

def get_cell_file(section: str) -> str:
return f"{prefix}_{SK.CELL_COORDINATES}_{section}{SK.CSV_FILE}"
def get_count_file(roi: str) -> str:
return f"{roi}_{SK.COUNTS_FILE}{SK.CSV_FILE}"

def get_count_file(section: str) -> str:
return f"{prefix}_{SK.COUNTS_FILE}_{section}{SK.CSV_FILE}"
def get_dapi_file(roi: str) -> str:
return f"{roi}_{SK.DAPI}{SK.TIFF_FILE}"

def get_dapi_file(section: str) -> str:
return f"{prefix}_{SK.DAPI}_{section}{SK.OME_TIFF_FILE}"
def get_cell_segmentation_labels_file(roi: str) -> str:
return f"{roi}_{SK.SEGMENTATION}{SK.TIFF_FILE}"

def get_cell_mask_file(section: str) -> str:
return f"{prefix}_{SK.CELL_MASK_FILE}_{section}{SK.TIFF_FILE}"
def get_cell_segmentation_shapes_file(roi: str) -> str:
return f"{roi}_{SK.BOUNDARIES}{SK.GEOJSON_FILE}"

def get_transcript_file(section: str) -> str:
return f"{prefix}_{SK.TRANSCRIPT_COORDINATES}_{section}{SK.CSV_FILE}"
def get_transcript_file(roi: str) -> str:
return f"{roi}_{SK.TRANSCRIPT_COORDINATES}{SK.CSV_FILE}"

adatas: dict[str, ad.AnnData] = {}
for section in sections_str: # type: ignore[assignment]
assert isinstance(section, str)
cell_file = get_cell_file(section)
count_matrix = get_count_file(section)
adata = ad.read_csv(path / count_matrix, delimiter=",")
# parse table information
tables: dict[str, ad.AnnData] = {}
for roi_str in rois_str:
# parse cell gene expression data
count_matrix = get_count_file(roi_str)
df = pd.read_csv(path / count_matrix, delimiter=",")
instance_id = df.iloc[:, 0].astype(str)
expression = df.drop(columns=["Unnamed: 0"])
expression.set_index(instance_id, inplace=True)
adata = ad.AnnData(expression)

# parse cell spatial information
cell_file = get_cell_file(roi_str)
cell_info = pd.read_csv(path / cell_file, delimiter=",")
cell_info["label"] = cell_info["label"].astype("str")
# below, the obsm are assigned by position, not by index. Here we check that we can do it
assert cell_info["label"].to_numpy().tolist() == adata.obs.index.to_numpy().tolist()
cell_info.set_index("label", inplace=True)
adata.obs[SK.AREA] = cell_info[SK.AREA]
adata.obsm[SK.SPATIAL_KEY] = cell_info[[SK.CELL_X, SK.CELL_Y]].to_numpy()
adata.obs[SK.AREA] = np.reshape(cell_info[SK.AREA].to_numpy(), (-1, 1))
region = f"cells_{section}"

# map tables to cell labels (defined later)
region = os.path.splitext(get_cell_segmentation_labels_file(roi_str))[0]
adata.obs[SK.REGION_KEY] = region
adata.obs[SK.INSTANCE_KEY_TABLE] = adata.obs.index.astype(int)
adatas[section] = adata
adata.obs[SK.REGION_KEY] = adata.obs[SK.REGION_KEY].astype("category")
adata.obs[SK.INSTANCE_KEY_TABLE] = instance_id.to_numpy().astype(np.uint16)
adata.obs = adata.obs.reset_index(drop=True)
tables[f"table_{roi_str}"] = TableModel.parse(
adata,
region=region,
region_key=SK.REGION_KEY.value,
instance_key=SK.INSTANCE_KEY_TABLE.value,
)

scale_factors = [2, 2, 2, 2]
# parse scale factors to scale images and labels
scaled = {}
for roi_str in rois_str:
scaled[roi_str] = Scale(
np.array(_get_scale_factors(path / get_dapi_file(roi_str), SK.SCALEFEFACTOR_X, SK.SCALEFEFACTOR_Y)),
axes=("y", "x"),
)

if load_images:
images = {
f"image_{x}": Image2DModel.parse(
f"{os.path.splitext(get_dapi_file(x))[0]}": Image2DModel.parse(
imread(path / get_dapi_file(x), **imread_kwargs),
dims=("c", "y", "x"),
scale_factors=scale_factors,
transformations={x: Identity()},
scale_factors=raster_models_scale_factors,
transformations={"global": scaled[x]},
)
for x in sections_str
for x in rois_str
}
else:
images = {}

if load_labels:
labels = {
f"labels_{x}": Labels2DModel.parse(
imread(path / get_cell_mask_file(x), **imread_kwargs).squeeze(),
f"{os.path.splitext(get_cell_segmentation_labels_file(x))[0]}": Labels2DModel.parse(
imread(path / get_cell_segmentation_labels_file(x), **imread_kwargs).squeeze(),
dims=("y", "x"),
scale_factors=scale_factors,
transformations={x: Identity()},
scale_factors=raster_models_scale_factors,
transformations={"global": scaled[x]},
)
for x in sections_str
for x in rois_str
}
else:
labels = {}

points = {}
if load_points:
points = {
f"transcripts_{x}": PointsModel.parse(
pd.read_csv(path / get_transcript_file(x), delimiter=","),
for x in rois_str:

# prepare data
name = f"{os.path.splitext(get_transcript_file(x))[0]}"
p = pd.read_csv(path / get_transcript_file(x), delimiter=",")
instance_key_points = SK.INSTANCE_KEY_POINTS.value if SK.INSTANCE_KEY_POINTS.value in p.columns else None

# call parser
points[name] = PointsModel.parse(
p,
coordinates={"x": SK.TRANSCRIPTS_X, "y": SK.TRANSCRIPTS_Y},
feature_key=SK.FEATURE_KEY.value,
instance_key=SK.INSTANCE_KEY_POINTS.value,
transformations={x: Identity()},
instance_key=instance_key_points,
transformations={"global": Identity()},
)

shapes = {}
if cells_as_circles:
for x, adata in zip(rois_str, tables.values()):
shapes[f"{os.path.splitext(get_cell_file(x))[0]}"] = ShapesModel.parse(
adata.obsm[SK.SPATIAL_KEY],
geometry=0,
radius=np.sqrt(adata.obs[SK.AREA].to_numpy() / np.pi),
index=adata.obs[SK.INSTANCE_KEY_TABLE].copy(),
transformations={"global": Identity()},
)
if load_shapes:
for x in rois_str:
# this assumes that the index matches the instance key of the table. A more robust approach could be
# implemented, as described here https://github.com/scverse/spatialdata-io/issues/249
shapes[f"{os.path.splitext(get_cell_segmentation_shapes_file(x))[0]}"] = ShapesModel.parse(
path / get_cell_segmentation_shapes_file(x),
transformations={"global": scaled[x]},
index=adata.obs[SK.INSTANCE_KEY_TABLE].copy(),
)
for x in sections_str
}
else:
points = {}

adata = ad.concat(adatas.values())
adata.obs[SK.REGION_KEY] = adata.obs[SK.REGION_KEY].astype("category")
adata.obs = adata.obs.reset_index(drop=True)
table = TableModel.parse(
adata,
region=[f"cells_{x}" for x in sections_str],
region_key=SK.REGION_KEY.value,
instance_key=SK.INSTANCE_KEY_TABLE.value,
)

shapes = {
f"cells_{x}": ShapesModel.parse(
adata.obsm[SK.SPATIAL_KEY],
geometry=0,
radius=np.sqrt(adata.obs[SK.AREA].to_numpy() / np.pi),
index=adata.obs[SK.INSTANCE_KEY_TABLE].copy(),
transformations={x: Identity()},
)
for x, adata in adatas.items()
}

sdata = SpatialData(images=images, labels=labels, points=points, table=table, shapes=shapes)
sdata = SpatialData(images=images, labels=labels, points=points, tables=tables, shapes=shapes)

return sdata


def _get_scale_factors(DAPI_path: Path, scalefactor_x_key: str, scalefactor_y_key: str) -> list[float]:
    """Read the physical pixel sizes (scale factors) from an OME-TIFF file's metadata.

    Parameters
    ----------
    DAPI_path
        Path to the OME-TIFF image whose OME-XML metadata is parsed.
    scalefactor_x_key
        Attribute name of the x scale factor in the OME-XML (e.g. ``"PhysicalSizeX"``).
    scalefactor_y_key
        Attribute name of the y scale factor in the OME-XML (e.g. ``"PhysicalSizeY"``).

    Returns
    -------
    ``[scalefactor_x, scalefactor_y]`` as floats.

    Raises
    ------
    ValueError
        If no element in the OME-XML metadata carries ``scalefactor_x_key``
        (previously this silently returned ``None``, breaking callers that
        expect a ``list[float]``).
    """
    with tifffile.TiffFile(DAPI_path) as tif:
        ome_metadata = tif.ome_metadata
    root = ET.fromstring(ome_metadata)
    for element in root.iter():
        # assumes the element holding the x scale factor also holds the y one
        if scalefactor_x_key in element.attrib:
            scalefactor_x = element.attrib[scalefactor_x_key]
            scalefactor_y = element.attrib[scalefactor_y_key]
            return [float(scalefactor_x), float(scalefactor_y)]
    raise ValueError(
        f"Could not find {scalefactor_x_key!r} in the OME metadata of {DAPI_path}; "
        "cannot determine the image scale factors."
    )
29 changes: 29 additions & 0 deletions tests/test_seqfish.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import math
from pathlib import Path

import pytest

from spatialdata_io.readers.seqfish import seqfish
from tests._utils import skip_if_below_python_version


# See https://github.com/scverse/spatialdata-io/blob/main/.github/workflows/prepare_test_data.yaml for instructions on
# how to download and place the data on disk
@skip_if_below_python_version()
@pytest.mark.parametrize(
    "dataset,expected", [("seqfish-2-test-dataset/instrument 2 official", "{'y': (0, 108), 'x': (0, 108)}")]
)
@pytest.mark.parametrize("rois", [[1], None])
@pytest.mark.parametrize("cells_as_circles", [False, True])
def test_example_data(dataset: str, expected: str, rois: list[int] | None, cells_as_circles: bool) -> None:
    """Read the example seqFISH dataset and check the spatial extent of the parsed data."""
    data_dir = Path("./data") / dataset
    assert data_dir.is_dir()
    sdata = seqfish(data_dir, cells_as_circles=cells_as_circles, rois=rois)
    from spatialdata import get_extent

    raw_extent = get_extent(sdata, exact=False)
    rounded = {axis: (math.floor(lo), math.ceil(hi)) for axis, (lo, hi) in raw_extent.items()}
    if cells_as_circles:
        # manual correction required to take into account for the circle radii
        expected = "{'y': (-2, 109), 'x': (-2, 109)}"
    assert str(rounded) == expected
8 changes: 2 additions & 6 deletions tests/test_xenium.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,8 @@ def test_roundtrip_with_data_limits() -> None:
assert np.array_equal(cell_id_str, f0(*f1(cell_id_str)))


# The datasets should be downloaded from
# https://www.10xgenomics.com/support/software/xenium-onboard-analysis/latest/resources/xenium-example-data#test-data
# and placed in the "data" directory; if you run the tests locally you may need to create a symlink in "tests/data"
# pointing to "data".
# The GitHub workflow "prepare_test_data.yaml" takes care of downloading the datasets and uploading an artifact for the
# tests to use
# See https://github.com/scverse/spatialdata-io/blob/main/.github/workflows/prepare_test_data.yaml for instructions on
# how to download and place the data on disk
@skip_if_below_python_version()
@pytest.mark.parametrize(
"dataset,expected",
Expand Down
Loading