From 439c530fcd5d4504bd9dcf98a8fec62d0d00ad0a Mon Sep 17 00:00:00 2001 From: John Kerl Date: Mon, 6 Jan 2025 18:53:05 -0700 Subject: [PATCH] [python] Check for list/tuple arguments in registrars (#3518) --- .../_registration/ambient_label_mappings.py | 18 +++++- apis/python/src/tiledbsoma/io/ingest.py | 4 +- .../tests/test_registration_mappings.py | 62 +++++++++++++++++++ 3 files changed, 80 insertions(+), 4 deletions(-) diff --git a/apis/python/src/tiledbsoma/io/_registration/ambient_label_mappings.py b/apis/python/src/tiledbsoma/io/_registration/ambient_label_mappings.py index a08f3ea784..386d66792f 100644 --- a/apis/python/src/tiledbsoma/io/_registration/ambient_label_mappings.py +++ b/apis/python/src/tiledbsoma/io/_registration/ambient_label_mappings.py @@ -391,7 +391,7 @@ def _acquire_experiment_mappings( def from_anndata_appends_on_experiment( cls, experiment_uri: str | None, - adatas: Sequence[ad.AnnData], + adatas: Sequence[ad.AnnData] | ad.AnnData, *, measurement_name: str, obs_field_name: str, @@ -404,6 +404,13 @@ def from_anndata_appends_on_experiment( is ``None`` then you will be computing registrations only for the input ``AnnData`` objects. If ``experiment_uri`` is not ``None`` then it is an error if the experiment is not accessible.""" + # typeguard doesn't help at runtime. Check this crucial user-facing API. + if isinstance(adatas, ad.AnnData): + adatas = [adatas] + elif not isinstance(adatas, (list, tuple)): + raise ValueError( + f"adatas must be list or tuple of AnnData, or a single AnnData; got {type(adatas)}" + ) registration_data = cls._acquire_experiment_mappings( experiment_uri, @@ -455,7 +462,7 @@ def from_h5ad_append_on_experiment( def from_h5ad_appends_on_experiment( cls, experiment_uri: str | None, - h5ad_file_names: Sequence[str], + h5ad_file_names: Sequence[str] | str, *, measurement_name: str, obs_field_name: str, @@ -465,6 +472,13 @@ def from_h5ad_appends_on_experiment( ) -> Self: """Extends registration data from the baseline, already-written SOMA experiment to include multiple H5AD input files.""" + # typeguard doesn't help at runtime. Check this crucial user-facing API. + if isinstance(h5ad_file_names, str): + h5ad_file_names = [h5ad_file_names] + elif not isinstance(h5ad_file_names, (list, tuple)): + raise ValueError( + f"h5ad_file_names must be list or tuple of string, or a single string; got {type(h5ad_file_names)}" + ) registration_data = cls._acquire_experiment_mappings( experiment_uri, diff --git a/apis/python/src/tiledbsoma/io/ingest.py b/apis/python/src/tiledbsoma/io/ingest.py index 183a8d5236..f22c7f46b5 100644 --- a/apis/python/src/tiledbsoma/io/ingest.py +++ b/apis/python/src/tiledbsoma/io/ingest.py @@ -188,7 +188,7 @@ def __init__( # entrypoints for append-mode soma_joinid registration. def register_h5ads( experiment_uri: str | None, - h5ad_file_names: Sequence[str], + h5ad_file_names: Sequence[str] | str, *, measurement_name: str, obs_field_name: str, @@ -212,7 +212,7 @@ def register_h5ads( def register_anndatas( experiment_uri: str | None, - adatas: Sequence[ad.AnnData], + adatas: Sequence[ad.AnnData] | ad.AnnData, *, measurement_name: str, obs_field_name: str, diff --git a/apis/python/tests/test_registration_mappings.py b/apis/python/tests/test_registration_mappings.py index c3c58c0025..55fb7c8d87 100644 --- a/apis/python/tests/test_registration_mappings.py +++ b/apis/python/tests/test_registration_mappings.py @@ -1342,3 +1342,65 @@ def test_multimodal_names(tmp_path, conftest_pbmc3k_adata): assert exp.obs.count == len(adata_protein.obs) assert exp.ms["RNA"].var.count == len(adata_rna.var) assert exp.ms["protein"].var.count == len(adata_protein.var) + + +def test_registration_lists_and_tuples(tmp_path): + obs_field_name = "cell_id" + var_field_name = "gene_id" + + exp_uri = create_soma_canned(1, obs_field_name, var_field_name) + adata = create_anndata_canned(2, obs_field_name, var_field_name) + h5ad_file_name = create_h5ad_canned(2, obs_field_name, var_field_name) + + rd1 = tiledbsoma.io.register_anndatas( + experiment_uri=exp_uri, + adatas=[adata], + measurement_name="measname", + obs_field_name=obs_field_name, + var_field_name=var_field_name, + ) + + rd2 = tiledbsoma.io.register_anndatas( + experiment_uri=exp_uri, + adatas=(adata,), + measurement_name="measname", + obs_field_name=obs_field_name, + var_field_name=var_field_name, + ) + + rd3 = tiledbsoma.io.register_anndatas( + experiment_uri=exp_uri, + adatas=adata, + measurement_name="measname", + obs_field_name=obs_field_name, + var_field_name=var_field_name, + ) + assert rd1 == rd2 + assert rd2 == rd3 + + rd4 = tiledbsoma.io.register_h5ads( + experiment_uri=exp_uri, + h5ad_file_names=[h5ad_file_name], + measurement_name="measname", + obs_field_name=obs_field_name, + var_field_name=var_field_name, + ) + + rd5 = tiledbsoma.io.register_h5ads( + experiment_uri=exp_uri, + h5ad_file_names=(h5ad_file_name,), + measurement_name="measname", + obs_field_name=obs_field_name, + var_field_name=var_field_name, + ) + + rd6 = tiledbsoma.io.register_h5ads( + experiment_uri=exp_uri, + h5ad_file_names=h5ad_file_name, + measurement_name="measname", + obs_field_name=obs_field_name, + var_field_name=var_field_name, + ) + + assert rd4 == rd5 + assert rd5 == rd6