Skip to content

Commit

Permalink
[python] Add region columns and metadata to SpatialData (#3537)
Browse files Browse the repository at this point in the history
* Inline SpatialData scene generation to ExperimentAxisQuery

This will make iterating over the code changes faster/easier.

* (WIP) Add regions to spatialdata table

* Add regions to table

* (WIP) Fix tests to not have overlapping obs_id

* Update spatialdata test so soma_joinid values don't overlap

* Test regions are correctly set when exporting to SpatialData

* Use SOMA_JOINID constant and fix typo
  • Loading branch information
jp-dark authored Jan 10, 2025
1 parent 1c8b1bd commit dc118ee
Show file tree
Hide file tree
Showing 3 changed files with 239 additions and 35 deletions.
179 changes: 162 additions & 17 deletions apis/python/src/tiledbsoma/_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@

if TYPE_CHECKING:
from ._experiment import Experiment
from ._constants import SPATIAL_DISCLAIMER
from ._constants import SOMA_JOINID, SPATIAL_DISCLAIMER
from ._fastercsx import CompressedMatrix
from ._measurement import Measurement
from ._sparse_nd_array import SparseNDArray
Expand Down Expand Up @@ -572,23 +572,141 @@ def to_spatialdata( # type: ignore[no-untyped-def]
Defaults to ``obs``.
"""

from spatialdata import SpatialData

from .io.spatial.outgest import _add_scene_to_spatialdata
import spatialdata as sd

from ._multiscale_image import MultiscaleImage
from ._point_cloud_dataframe import PointCloudDataFrame
from .io.spatial.outgest import (
_convert_axis_names,
_get_transform_from_collection,
to_spatialdata_image,
to_spatialdata_multiscale_image,
to_spatialdata_points,
to_spatialdata_shapes,
)

warnings.warn(SPATIAL_DISCLAIMER)

# Get a list of scenes to add to SpatialData object.
if scene_presence_mode == "obs":
scene_ids = self.obs_scene_ids()
scene_names = tuple(str(scene_name) for scene_name in self.obs_scene_ids())
elif scene_presence_mode == "var":
scene_ids = self.var_scene_ids()
scene_names = tuple(str(scene_name) for scene_name in self.var_scene_ids())
else:
raise ValueError(
f"Invalid scene presence mode '{scene_presence_mode}'. Valid options "
f"are 'obs' and 'var'."
)

# Create empty SpatialData instance and dict to store region/instance keys.
sdata = sd.SpatialData()
region_joinids: Dict[str, Any] = {}

# Add data from linked scenes.
for scene_name in scene_names:
scene = self.experiment.spatial[scene_name]

# Cannot have spatial data if no coordinate space.
if scene.coordinate_space is None:
continue

# Get the map from Scene dimension names to SpatialData dimension names.
input_axis_names = scene.coordinate_space.axis_names
_, scene_dim_map = _convert_axis_names(input_axis_names, input_axis_names)

# Export obsl data to SpatialData.
if "obsl" in scene:
for key, df in scene.obsl.items():
output_key = f"{scene_name}_{key}"
transform = _get_transform_from_collection(key, scene.obsl.metadata)
if isinstance(df, PointCloudDataFrame):
if "soma_geometry" in df.metadata:
sdata.shapes[output_key] = to_spatialdata_shapes(
df,
key=output_key,
scene_id=scene_name,
scene_dim_map=scene_dim_map,
transform=transform,
)
region_joinids[output_key] = sdata.shapes[output_key][
SOMA_JOINID
]

else:
sdata.points[output_key] = to_spatialdata_points(
df,
key=output_key,
scene_id=scene_name,
scene_dim_map=scene_dim_map,
transform=transform,
)
region_joinids[output_key] = sdata.points[output_key][
SOMA_JOINID
]

else:
warnings.warn(
f"Skipping obsl[{key}] in Scene {scene_name}; unexpected "
f"datatype {type(df).__name__}."
)

# Export varl data to SpatialData.
if "varl" in scene and self.measurement_name in scene.varl:
subcoll = scene.varl[self.measurement_name]
for key, df in subcoll.items():
output_key = f"{scene_name}_{self.measurement_name}_{key}"
transform = _get_transform_from_collection(key, subcoll.metadata)
if isinstance(df, PointCloudDataFrame):
if "soma_geometry" in df.metadata:
sdata.shapes[output_key] = to_spatialdata_shapes(
df,
key=output_key,
scene_id=scene_name,
scene_dim_map=scene_dim_map,
transform=transform,
)
else:
sdata.points[output_key] = to_spatialdata_points(
df,
key=output_key,
scene_id=scene_name,
scene_dim_map=scene_dim_map,
transform=transform,
)
else:
warnings.warn(
f"Skipping varl[{self.measurement_name}][{key}] in Scene "
f"{scene_name}; unexpected datatype {type(df).__name__}."
)

# Export img data to SpatialData.
if "img" in scene:
for key, image in scene.img.items():
output_key = f"{scene_name}_{key}"
transform = _get_transform_from_collection(key, scene.img.metadata)
if not isinstance(image, MultiscaleImage):
warnings.warn( # type: ignore[unreachable]
f"Skipping img[{image}] in Scene {scene_name}; unexpected "
f"datatype {type(image).__name__}."
)
if image.level_count == 1:
sdata.images[output_key] = to_spatialdata_image(
image,
0,
key=output_key,
scene_id=scene_name,
scene_dim_map=scene_dim_map,
transform=transform,
)
else:
sdata.images[output_key] = to_spatialdata_multiscale_image(
image,
key=output_key,
scene_id=scene_name,
scene_dim_map=scene_dim_map,
transform=transform,
)

# Get the anndata table.
ad = self.to_anndata(
X_name,
Expand All @@ -600,18 +718,45 @@ def to_spatialdata( # type: ignore[no-untyped-def]
varp_layers=varp_layers,
drop_levels=drop_levels,
)
sdata = SpatialData(tables={self.measurement_name: ad})

for scene_id in scene_ids:
scene = self.experiment.spatial[str(scene_id)]
_add_scene_to_spatialdata(
sdata,
scene_id=str(scene_id),
scene=scene,
obs_id_name="soma_joinid",
var_id_name="soma_joinid",
measurement_names=(self.measurement_name,),

# Merge the region and instance keys into obs on the joinid. Fail if any joinid
# appears more than once across the exported regions.
regions: list[str] | None = None
region_key: str | None = None
instance_key: str | None = None
if region_joinids:
region_df = pd.concat(
[
pd.DataFrame.from_dict(
{
SOMA_JOINID: joinid_series,
"region_key": key,
"instance_key": joinid_series.index,
}
)
for key, joinid_series in region_joinids.items()
]
)
if not region_df.empty:
try:
ad.obs = pd.merge(
ad.obs,
region_df,
how="left",
on=SOMA_JOINID,
validate="many_to_one",
)
except pd.errors.MergeError as err:
raise NotImplementedError(
"Unable to export to SpatialData; exported assets have "
"overlapping observations."
) from err
regions = list(region_joinids.keys())
region_key = "region_key"
instance_key = "instance_key"

sdata.tables[self.measurement_name] = sd.models.TableModel.parse(
ad, regions, region_key, instance_key
)

return sdata

Expand Down
10 changes: 6 additions & 4 deletions apis/python/src/tiledbsoma/io/spatial/outgest.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def to_spatialdata_points(
scene_id: str,
scene_dim_map: Dict[str, str],
transform: somacore.CoordinateTransform | None,
soma_joinid_name: str,
soma_joinid_name: str = SOMA_JOINID,
) -> dd.DataFrame:
"""Export a :class:`PointCloudDataFrame` to a :class:`spatialdata.models.PointsModel`.
Expand Down Expand Up @@ -145,7 +145,8 @@ def to_spatialdata_points(

# Read the pandas dataframe, rename SOMA_JOINID, add metadata, and return.
df: pd.DataFrame = points.read().concat().to_pandas()
df.rename(columns={SOMA_JOINID: soma_joinid_name}, inplace=True)
if soma_joinid_name != SOMA_JOINID:
df.rename(columns={SOMA_JOINID: soma_joinid_name}, inplace=True)
return sd.models.PointsModel.parse(df, transformations=transforms)


Expand All @@ -156,7 +157,7 @@ def to_spatialdata_shapes(
scene_id: str,
scene_dim_map: Dict[str, str],
transform: somacore.CoordinateTransform | None,
soma_joinid_name: str,
soma_joinid_name: str = SOMA_JOINID,
) -> gpd.GeoDataFrame:
"""Export a :class:`PointCloudDataFrame` to a :class:`spatialdata.ShapesModel`.
Expand Down Expand Up @@ -206,7 +207,8 @@ def to_spatialdata_shapes(
}

data = points.read().concat().to_pandas()
data.rename(columns={SOMA_JOINID: soma_joinid_name}, inplace=True)
if soma_joinid_name != SOMA_JOINID:
data.rename(columns={SOMA_JOINID: soma_joinid_name}, inplace=True)
data.insert(len(data.columns), "radius", radius)
ndim = len(orig_axis_names)
if ndim == 2:
Expand Down
85 changes: 71 additions & 14 deletions apis/python/tests/test_experiment_query_spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,16 @@ def soma_spatial_experiment(tmp_path_factory) -> soma.Experiment:
np.linspace(-1.0, 1.0, num=4), np.linspace(-1.0, 1.0, num=4)
)
point_df = {
"x": x.flatten(),
"y": y.flatten(),
"soma_joinid": np.arange(index * 16, index * 16 + 16, dtype=np.int64),
"x": x.flatten()[:8],
"y": y.flatten()[:8],
"soma_joinid": np.arange(index * 16, index * 16 + 8, dtype=np.int64),
}
circle_df = {
"x": x.flatten()[8:],
"y": y.flatten()[8:],
"soma_joinid": np.arange(
index * 16 + 8, (index + 1) * 16, dtype=np.int64
),
}
add_scene(
spatial,
Expand All @@ -163,9 +170,9 @@ def soma_spatial_experiment(tmp_path_factory) -> soma.Experiment:
(("varl", "other"), "points3"): point_df,
},
circles={
("obsl", "shape1"): point_df,
(("varl", "RNA"), "shape2"): point_df,
(("varl", "other"), "shape3"): point_df,
("obsl", "shapes1"): circle_df,
(("varl", "RNA"), "shapes2"): circle_df,
(("varl", "other"), "shapes3"): circle_df,
},
images={"tissue": ((3, 16, 8),)},
)
Expand Down Expand Up @@ -217,17 +224,21 @@ def check_for_scene_data(sdata, has_scenes: List[bool]):
x, y = np.meshgrid(np.linspace(-1.0, 1.0, num=4), np.linspace(-1.0, 1.0, num=4))
expected_points = pd.DataFrame.from_dict(
{
"x": x.flatten(),
"y": y.flatten(),
"soma_joinid": np.arange(index * 16, (index + 1) * 16, dtype=np.int64),
"x": x.flatten()[:8],
"y": y.flatten()[:8],
"soma_joinid": np.arange(
index * 16, (index + 1) * 16 - 8, dtype=np.int64
),
}
)
expected_shapes = pd.DataFrame.from_dict(
{
"soma_joinid": np.arange(index * 16, (index + 1) * 16, dtype=np.int64),
"radius": 16 * [2.0],
"soma_joinid": np.arange(
index * 16 + 8, (index + 1) * 16, dtype=np.int64
),
"radius": 8 * [2.0],
"geometry": shapely.points(
list(zip(x.flatten(), y.flatten()))
list(zip(x.flatten()[8:], y.flatten()[8:]))
).tolist(),
}
)
Expand All @@ -237,9 +248,9 @@ def check_for_scene_data(sdata, has_scenes: List[bool]):
assert all(points1 == expected_points)
points2 = sdata.points[f"{scene_id}_RNA_points2"]
assert all(points2 == expected_points)
shapes1 = sdata.shapes[f"{scene_id}_shape1"]
shapes1 = sdata.shapes[f"{scene_id}_shapes1"]
assert all(shapes1 == expected_shapes)
shapes2 = sdata.shapes[f"{scene_id}_RNA_shape2"]
shapes2 = sdata.shapes[f"{scene_id}_RNA_shapes2"]
assert all(shapes2 == expected_shapes)
image = sdata.images[f"{scene_id}_tissue"]
assert image.shape == (3, 16, 8)
Expand All @@ -265,6 +276,11 @@ def test_spatial_experiment_query_none(soma_spatial_experiment):
ad = sdata["RNA"]
assert ad.n_obs == 0 and ad.n_vars == 0

# Check no region columns/metadata.
assert "region_key" not in ad.obs
assert "instance_key" not in ad.obs
assert "spatialdata_attrs" not in ad.uns


def test_spatial_experiment_query_all(soma_spatial_experiment):
with soma_spatial_experiment.axis_query("RNA") as query:
Expand All @@ -288,6 +304,47 @@ def test_spatial_experiment_query_all(soma_spatial_experiment):

check_for_scene_data(sdata, 4 * [True])

# Check table.
ad = sdata.tables["RNA"]

sd_attrs = ad.uns["spatialdata_attrs"]
assert sd_attrs["region"] == [
"scene0_points1",
"scene0_shapes1",
"scene1_points1",
"scene1_shapes1",
"scene2_points1",
"scene2_shapes1",
"scene3_points1",
"scene3_shapes1",
]
assert sd_attrs["region_key"] == "region_key"
assert sd_attrs["instance_key"] == "instance_key"

region_df = ad.obs[["soma_joinid", "region_key", "instance_key"]]
for scene_index in range(4):
# Filter on points key and get the joinids and instance keys for the points.
points_key = f"scene{scene_index}_points1"
points_region_df = region_df[region_df["region_key"] == points_key]
instance_keys = points_region_df["instance_key"]
region_joinids = points_region_df["soma_joinid"].tolist()

# Check the joinids at the points instance key match expected joinids.
points = sdata.points[points_key].compute()
points_joinids = points.iloc[instance_keys]["soma_joinid"].tolist()
assert region_joinids == points_joinids

# Filter on shape key and get the joinids and instance keys for the shapes.
shapes_key = f"scene{scene_index}_shapes1"
shapes_region_df = region_df[region_df["region_key"] == shapes_key]
instance_keys = shapes_region_df["instance_key"]
region_joinids = shapes_region_df["soma_joinid"].tolist()

# Check the joinids at the shapes instance key match expected joinids.
shapes = sdata.shapes[shapes_key]
shapes_joinids = shapes.iloc[instance_keys]["soma_joinid"].to_list()
assert region_joinids == shapes_joinids


@pytest.mark.parametrize(
"obs_slice,has_scene",
Expand Down

0 comments on commit dc118ee

Please sign in to comment.