From 7f0ac8846865605a3a347e5ecdbcf9c2605a02ed Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Fri, 10 Jan 2025 16:34:36 -0500 Subject: [PATCH] [python] Add region columns and metadata to SpatialData (#3537) (#3556) * Inline SpatialData scene generation to ExperimentAxisQuery This will make iterating over the code changes faster/easier. * (WIP) Add regions to spatialdata table * Add regions to table * (WIP) Fix tests to not have overlapping obs_id * Update spatialdata test so soma_joinid don't overlap * Test regions are correctly set when exporting to SpatialData * Use SOMA_JOINID constant and fix typo Co-authored-by: Julia Dark <24235303+jp-dark@users.noreply.github.com> --- apis/python/src/tiledbsoma/_query.py | 179 ++++++++++++++++-- .../src/tiledbsoma/io/spatial/outgest.py | 10 +- .../tests/test_experiment_query_spatial.py | 85 +++++++-- 3 files changed, 239 insertions(+), 35 deletions(-) diff --git a/apis/python/src/tiledbsoma/_query.py b/apis/python/src/tiledbsoma/_query.py index 37ac9aacfb..e91958ada9 100644 --- a/apis/python/src/tiledbsoma/_query.py +++ b/apis/python/src/tiledbsoma/_query.py @@ -57,7 +57,7 @@ if TYPE_CHECKING: from ._experiment import Experiment -from ._constants import SPATIAL_DISCLAIMER +from ._constants import SOMA_JOINID, SPATIAL_DISCLAIMER from ._fastercsx import CompressedMatrix from ._measurement import Measurement from ._sparse_nd_array import SparseNDArray @@ -572,23 +572,141 @@ def to_spatialdata( # type: ignore[no-untyped-def] Defaults to ``obs``. """ - from spatialdata import SpatialData - - from .io.spatial.outgest import _add_scene_to_spatialdata + import spatialdata as sd + + from ._multiscale_image import MultiscaleImage + from ._point_cloud_dataframe import PointCloudDataFrame + from .io.spatial.outgest import ( + _convert_axis_names, + _get_transform_from_collection, + to_spatialdata_image, + to_spatialdata_multiscale_image, + to_spatialdata_points, + to_spatialdata_shapes, + ) warnings.warn(SPATIAL_DISCLAIMER) # Get a list of scenes to add to SpatialData object. if scene_presence_mode == "obs": - scene_ids = self.obs_scene_ids() + scene_names = tuple(str(scene_name) for scene_name in self.obs_scene_ids()) elif scene_presence_mode == "var": - scene_ids = self.var_scene_ids() + scene_names = tuple(str(scene_name) for scene_name in self.var_scene_ids()) else: raise ValueError( f"Invalid scene presence mode '{scene_presence_mode}'. Valid options " f"are 'obs' and 'var'." ) + # Create empty SpatialData instance and dict to store region/instance keys. + sdata = sd.SpatialData() + region_joinids: Dict[str, Any] = {} + + # Add data from linked scenes. + for scene_name in scene_names: + scene = self.experiment.spatial[scene_name] + + # Cannot have spatial data if no coordinate space. + if scene.coordinate_space is None: + continue + + # Get the map from Scene dimension names to SpatialData dimension names. + input_axis_names = scene.coordinate_space.axis_names + _, scene_dim_map = _convert_axis_names(input_axis_names, input_axis_names) + + # Export obsl data to SpatialData. + if "obsl" in scene: + for key, df in scene.obsl.items(): + output_key = f"{scene_name}_{key}" + transform = _get_transform_from_collection(key, scene.obsl.metadata) + if isinstance(df, PointCloudDataFrame): + if "soma_geometry" in df.metadata: + sdata.shapes[output_key] = to_spatialdata_shapes( + df, + key=output_key, + scene_id=scene_name, + scene_dim_map=scene_dim_map, + transform=transform, + ) + region_joinids[output_key] = sdata.shapes[output_key][ + SOMA_JOINID + ] + + else: + sdata.points[output_key] = to_spatialdata_points( + df, + key=output_key, + scene_id=scene_name, + scene_dim_map=scene_dim_map, + transform=transform, + ) + region_joinids[output_key] = sdata.points[output_key][ + SOMA_JOINID + ] + + else: + warnings.warn( + f"Skipping obsl[{key}] in Scene {scene_name}; unexpected " + f"datatype {type(df).__name__}." + ) + + # Export varl data to SpatialData. + if "varl" in scene and self.measurement_name in scene.varl: + subcoll = scene.varl[self.measurement_name] + for key, df in subcoll.items(): + output_key = f"{scene_name}_{self.measurement_name}_{key}" + transform = _get_transform_from_collection(key, subcoll.metadata) + if isinstance(df, PointCloudDataFrame): + if "soma_geometry" in df.metadata: + sdata.shapes[output_key] = to_spatialdata_shapes( + df, + key=output_key, + scene_id=scene_name, + scene_dim_map=scene_dim_map, + transform=transform, + ) + else: + sdata.points[output_key] = to_spatialdata_points( + df, + key=output_key, + scene_id=scene_name, + scene_dim_map=scene_dim_map, + transform=transform, + ) + else: + warnings.warn( + f"Skipping varl[{self.measurement_name}][{key}] in Scene " + f"{scene_name}; unexpected datatype {type(df).__name__}." + ) + + # Export img data to SpatialData. + if "img" in scene: + for key, image in scene.img.items(): + output_key = f"{scene_name}_{key}" + transform = _get_transform_from_collection(key, scene.img.metadata) + if not isinstance(image, MultiscaleImage): + warnings.warn( # type: ignore[unreachable] + f"Skipping img[{image}] in Scene {scene_name}; unexpected " + f"datatype {type(image).__name__}." + ) + if image.level_count == 1: + sdata.images[output_key] = to_spatialdata_image( + image, + 0, + key=output_key, + scene_id=scene_name, + scene_dim_map=scene_dim_map, + transform=transform, + ) + else: + sdata.images[output_key] = to_spatialdata_multiscale_image( + image, + key=output_key, + scene_id=scene_name, + scene_dim_map=scene_dim_map, + transform=transform, + ) + # Get the anndata table. ad = self.to_anndata( X_name, @@ -600,18 +718,45 @@ def to_spatialdata( # type: ignore[no-untyped-def] varp_layers=varp_layers, drop_levels=drop_levels, ) - sdata = SpatialData(tables={self.measurement_name: ad}) - - for scene_id in scene_ids: - scene = self.experiment.spatial[str(scene_id)] - _add_scene_to_spatialdata( - sdata, - scene_id=str(scene_id), - scene=scene, - obs_id_name="soma_joinid", - var_id_name="soma_joinid", - measurement_names=(self.measurement_name,), + + # Add joinids to region dataframe. Verify no overwrites. + regions: list[str] | None = None + region_key: str | None = None + instance_key: str | None = None + if region_joinids: + region_df = pd.concat( + [ + pd.DataFrame.from_dict( + { + SOMA_JOINID: joinid_series, + "region_key": key, + "instance_key": joinid_series.index, + } + ) + for key, joinid_series in region_joinids.items() + ] ) + if not region_df.empty: + try: + ad.obs = pd.merge( + ad.obs, + region_df, + how="left", + on=SOMA_JOINID, + validate="many_to_one", + ) + except pd.errors.MergeError as err: + raise NotImplementedError( + "Unable to export to SpatialData; exported assets have " + "overlapping observations." + ) from err + regions = list(region_joinids.keys()) + region_key = "region_key" + instance_key = "instance_key" + + sdata.tables[self.measurement_name] = sd.models.TableModel.parse( + ad, regions, region_key, instance_key + ) return sdata diff --git a/apis/python/src/tiledbsoma/io/spatial/outgest.py b/apis/python/src/tiledbsoma/io/spatial/outgest.py index 032018700e..86b42c8422 100644 --- a/apis/python/src/tiledbsoma/io/spatial/outgest.py +++ b/apis/python/src/tiledbsoma/io/spatial/outgest.py @@ -112,7 +112,7 @@ def to_spatialdata_points( scene_id: str, scene_dim_map: Dict[str, str], transform: somacore.CoordinateTransform | None, - soma_joinid_name: str, + soma_joinid_name: str = SOMA_JOINID, ) -> dd.DataFrame: """Export a :class:`PointCloudDataFrame` to a :class:`spatialdata.ShapesModel. @@ -145,7 +145,8 @@ def to_spatialdata_points( # Read the pandas dataframe, rename SOMA_JOINID, add metadata, and return. df: pd.DataFrame = points.read().concat().to_pandas() - df.rename(columns={SOMA_JOINID: soma_joinid_name}, inplace=True) + if soma_joinid_name != SOMA_JOINID: + df.rename(columns={SOMA_JOINID: soma_joinid_name}, inplace=True) return sd.models.PointsModel.parse(df, transformations=transforms) @@ -156,7 +157,7 @@ def to_spatialdata_shapes( scene_id: str, scene_dim_map: Dict[str, str], transform: somacore.CoordinateTransform | None, - soma_joinid_name: str, + soma_joinid_name: str = SOMA_JOINID, ) -> gpd.GeoDataFrame: """Export a :class:`PointCloudDataFrame` to a :class:`spatialdata.ShapesModel. @@ -206,7 +207,8 @@ def to_spatialdata_shapes( } data = points.read().concat().to_pandas() - data.rename(columns={SOMA_JOINID: soma_joinid_name}, inplace=True) + if soma_joinid_name != SOMA_JOINID: + data.rename(columns={SOMA_JOINID: soma_joinid_name}, inplace=True) data.insert(len(data.columns), "radius", radius) ndim = len(orig_axis_names) if ndim == 2: diff --git a/apis/python/tests/test_experiment_query_spatial.py b/apis/python/tests/test_experiment_query_spatial.py index 7689458c0d..92fe3eee43 100644 --- a/apis/python/tests/test_experiment_query_spatial.py +++ b/apis/python/tests/test_experiment_query_spatial.py @@ -150,9 +150,16 @@ def soma_spatial_experiment(tmp_path_factory) -> soma.Experiment: np.linspace(-1.0, 1.0, num=4), np.linspace(-1.0, 1.0, num=4) ) point_df = { - "x": x.flatten(), - "y": y.flatten(), - "soma_joinid": np.arange(index * 16, index * 16 + 16, dtype=np.int64), + "x": x.flatten()[:8], + "y": y.flatten()[:8], + "soma_joinid": np.arange(index * 16, index * 16 + 8, dtype=np.int64), + } + circle_df = { + "x": x.flatten()[8:], + "y": y.flatten()[8:], + "soma_joinid": np.arange( + index * 16 + 8, (index + 1) * 16, dtype=np.int64 + ), } add_scene( spatial, @@ -163,9 +170,9 @@ def soma_spatial_experiment(tmp_path_factory) -> soma.Experiment: (("varl", "other"), "points3"): point_df, }, circles={ - ("obsl", "shape1"): point_df, - (("varl", "RNA"), "shape2"): point_df, - (("varl", "other"), "shape3"): point_df, + ("obsl", "shapes1"): circle_df, + (("varl", "RNA"), "shapes2"): circle_df, + (("varl", "other"), "shapes3"): circle_df, }, images={"tissue": ((3, 16, 8),)}, ) @@ -217,17 +224,21 @@ def check_for_scene_data(sdata, has_scenes: List[bool]): x, y = np.meshgrid(np.linspace(-1.0, 1.0, num=4), np.linspace(-1.0, 1.0, num=4)) expected_points = pd.DataFrame.from_dict( { - "x": x.flatten(), - "y": y.flatten(), - "soma_joinid": np.arange(index * 16, (index + 1) * 16, dtype=np.int64), + "x": x.flatten()[:8], + "y": y.flatten()[:8], + "soma_joinid": np.arange( + index * 16, (index + 1) * 16 - 8, dtype=np.int64 + ), } ) expected_shapes = pd.DataFrame.from_dict( { - "soma_joinid": np.arange(index * 16, (index + 1) * 16, dtype=np.int64), - "radius": 16 * [2.0], + "soma_joinid": np.arange( + index * 16 + 8, (index + 1) * 16, dtype=np.int64 + ), + "radius": 8 * [2.0], "geometry": shapely.points( - list(zip(x.flatten(), y.flatten())) + list(zip(x.flatten()[8:], y.flatten()[8:])) ).tolist(), } ) @@ -237,9 +248,9 @@ def check_for_scene_data(sdata, has_scenes: List[bool]): assert all(points1 == expected_points) points2 = sdata.points[f"{scene_id}_RNA_points2"] assert all(points2 == expected_points) - shapes1 = sdata.shapes[f"{scene_id}_shape1"] + shapes1 = sdata.shapes[f"{scene_id}_shapes1"] assert all(shapes1 == expected_shapes) - shapes2 = sdata.shapes[f"{scene_id}_RNA_shape2"] + shapes2 = sdata.shapes[f"{scene_id}_RNA_shapes2"] assert all(shapes2 == expected_shapes) image = sdata.images[f"{scene_id}_tissue"] assert image.shape == (3, 16, 8) @@ -265,6 +276,11 @@ def test_spatial_experiment_query_none(soma_spatial_experiment): ad = sdata["RNA"] assert ad.n_obs == 0 and ad.n_vars == 0 + # Check no region columns/metadata. + assert "region_key" not in ad.obs + assert "instance_key" not in ad.obs + assert "spatialdata_attrs" not in ad.uns + def test_spatial_experiment_query_all(soma_spatial_experiment): with soma_spatial_experiment.axis_query("RNA") as query: @@ -288,6 +304,47 @@ def test_spatial_experiment_query_all(soma_spatial_experiment): check_for_scene_data(sdata, 4 * [True]) + # Check table. + ad = sdata.tables["RNA"] + + sd_attrs = ad.uns["spatialdata_attrs"] + assert sd_attrs["region"] == [ + "scene0_points1", + "scene0_shapes1", + "scene1_points1", + "scene1_shapes1", + "scene2_points1", + "scene2_shapes1", + "scene3_points1", + "scene3_shapes1", + ] + assert sd_attrs["region_key"] == "region_key" + assert sd_attrs["instance_key"] == "instance_key" + + region_df = ad.obs[["soma_joinid", "region_key", "instance_key"]] + for scene_index in range(4): + # Filter on points key and get the joinids and instance keys for the points. + points_key = f"scene{scene_index}_points1" + points_region_df = region_df[region_df["region_key"] == points_key] + instance_keys = points_region_df["instance_key"] + region_joinids = points_region_df["soma_joinid"].tolist() + + # Check the joinids at the points instance key match expected joinids. + points = sdata.points[points_key].compute() + points_joinids = points.iloc[instance_keys]["soma_joinid"].tolist() + assert region_joinids == points_joinids + + # Filter on shape key and get the joinids and instance keys for the shapes. + shapes_key = f"scene{scene_index}_shapes1" + shapes_region_df = region_df[region_df["region_key"] == shapes_key] + instance_keys = shapes_region_df["instance_key"] + region_joinids = shapes_region_df["soma_joinid"].tolist() + + # Check the joinids at the shapes instance key match expected joinids. + shapes = sdata.shapes[shapes_key] + shapes_joinids = shapes.iloc[instance_keys]["soma_joinid"].to_list() + assert region_joinids == shapes_joinids + @pytest.mark.parametrize( "obs_slice,has_scene",