diff --git a/pyproject.toml b/pyproject.toml
index 213c9056..85487b0c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -14,7 +14,9 @@ dependencies = [
     "anndata",
     "attrs>=22.1",
     "numpy>=1.21",
+    "pandas",
     "pyarrow",
+    "scipy",
     "typing-extensions",
 ]
 requires-python = "~=3.7"
@@ -43,5 +45,5 @@ single_line_exclusions = ["typing", "typing_extensions"]

 [[tool.mypy.overrides]]
 # These dependencies do not currently have canonical type stubs.
-module = ["anndata", "pyarrow"]
+module = ["anndata", "pandas", "pyarrow", "scipy"]
 ignore_missing_imports = true
diff --git a/python-spec/requirements-py3.10.txt b/python-spec/requirements-py3.10.txt
index d366e6d9..5ae9fc28 100644
--- a/python-spec/requirements-py3.10.txt
+++ b/python-spec/requirements-py3.10.txt
@@ -1,4 +1,13 @@
+anndata==0.8.0
 attrs==22.2.0
+h5py==3.7.0
+natsort==8.2.0
 numpy==1.24.1
+packaging==23.0
+pandas==1.5.2
 pyarrow==10.0.1
+python-dateutil==2.8.2
+pytz==2022.7
+scipy==1.10.0
+six==1.16.0
 typing_extensions==4.4.0
diff --git a/python-spec/requirements-py3.7-lint.txt b/python-spec/requirements-py3.7-lint.txt
index 1dd9b1c2..cd133536 100644
--- a/python-spec/requirements-py3.7-lint.txt
+++ b/python-spec/requirements-py3.7-lint.txt
@@ -1,8 +1,19 @@
+anndata==0.8.0
 attrs==22.2.0
+h5py==3.7.0
+importlib-metadata==6.0.0
 mypy==0.991
 mypy-extensions==0.4.3
+natsort==8.2.0
 numpy==1.21.6
+packaging==23.0
+pandas==1.3.5
 pyarrow==10.0.1
+python-dateutil==2.8.2
+pytz==2022.7
+scipy==1.7.3
+six==1.16.0
 tomli==2.0.1
 typed-ast==1.5.4
 typing_extensions==4.4.0
+zipp==3.11.0
diff --git a/python-spec/requirements-py3.7.txt b/python-spec/requirements-py3.7.txt
index 8525a0a9..cc6a77b6 100644
--- a/python-spec/requirements-py3.7.txt
+++ b/python-spec/requirements-py3.7.txt
@@ -1,4 +1,15 @@
+anndata==0.8.0
 attrs==22.2.0
+h5py==3.7.0
+importlib-metadata==6.0.0
+natsort==8.2.0
 numpy==1.21.6
+packaging==23.0
+pandas==1.3.5
 pyarrow==10.0.1
+python-dateutil==2.8.2
+pytz==2022.7
+scipy==1.7.3
+six==1.16.0
 typing_extensions==4.4.0
+zipp==3.11.0
diff --git a/python-spec/requirements-py3.8.txt b/python-spec/requirements-py3.8.txt
index d366e6d9..5ae9fc28 100644
--- a/python-spec/requirements-py3.8.txt
+++ b/python-spec/requirements-py3.8.txt
@@ -1,4 +1,13 @@
+anndata==0.8.0
 attrs==22.2.0
+h5py==3.7.0
+natsort==8.2.0
 numpy==1.24.1
+packaging==23.0
+pandas==1.5.2
 pyarrow==10.0.1
+python-dateutil==2.8.2
+pytz==2022.7
+scipy==1.10.0
+six==1.16.0
 typing_extensions==4.4.0
diff --git a/python-spec/requirements-py3.9.txt b/python-spec/requirements-py3.9.txt
index d366e6d9..5ae9fc28 100644
--- a/python-spec/requirements-py3.9.txt
+++ b/python-spec/requirements-py3.9.txt
@@ -1,4 +1,13 @@
+anndata==0.8.0
 attrs==22.2.0
+h5py==3.7.0
+natsort==8.2.0
 numpy==1.24.1
+packaging==23.0
+pandas==1.5.2
 pyarrow==10.0.1
+python-dateutil==2.8.2
+pytz==2022.7
+scipy==1.10.0
+six==1.16.0
 typing_extensions==4.4.0
diff --git a/python-spec/src/somacore/__init__.py b/python-spec/src/somacore/__init__.py
index 6565c3d4..beac3359 100644
--- a/python-spec/src/somacore/__init__.py
+++ b/python-spec/src/somacore/__init__.py
@@ -9,6 +9,7 @@
 from somacore import ephemeral
 from somacore import options
 from somacore.query import axis
+from somacore.query import query

 try:
     # This trips up mypy since it's a generated file:
@@ -37,6 +38,7 @@
 ResultOrder = options.ResultOrder

 AxisQuery = axis.AxisQuery
+ExperimentAxisQuery = query.ExperimentAxisQuery

 __all__ = (
     "SOMAObject",
diff --git a/python-spec/src/somacore/composed.py b/python-spec/src/somacore/composed.py
index fae9528d..0aa790a2 100644
--- a/python-spec/src/somacore/composed.py
+++ b/python-spec/src/somacore/composed.py
@@ -1,7 +1,5 @@
 """Implementations of the composed SOMA data types."""

-from typing import Optional
-
 from typing_extensions import Final

 from somacore import _wrap
@@ -76,13 +74,15 @@ def axis_query(
         self,
         measurement_name: str,
         *,
-        obs_query: Optional[axis.AxisQuery] = None,
-        var_query: Optional[axis.AxisQuery] = None,
-    ) -> query.ExperimentAxisQuery:
+        obs_query: axis.AxisQuery = axis.AxisQuery(),
+        var_query: axis.AxisQuery = axis.AxisQuery(),
+    ) -> "query.ExperimentAxisQuery":
         """Creates an axis query over this experiment.

         See :class:`query.ExperimentAxisQuery` for details on usage.
         """
-        raise NotImplementedError()
+        return query.ExperimentAxisQuery(
+            self, measurement_name, obs_query=obs_query, var_query=var_query
+        )

     soma_type: Final = "SOMAExperiment"
diff --git a/python-spec/src/somacore/query/query.py b/python-spec/src/somacore/query/query.py
index f1faf577..d3a10ee9 100644
--- a/python-spec/src/somacore/query/query.py
+++ b/python-spec/src/somacore/query/query.py
@@ -1,12 +1,19 @@
-import abc
-import contextlib
-from typing import Any, Optional, Sequence
+import enum
+from concurrent import futures
+from typing import Any, Dict, Optional, Sequence, Tuple, Union

 import anndata
+import attrs
+import numpy as np
+import numpy.typing as npt
+import pandas as pd
 import pyarrow as pa
-from typing_extensions import TypedDict
+from scipy import sparse
+from typing_extensions import Literal, TypedDict, assert_never

+from somacore import composed
 from somacore import data
+from somacore.query import axis


 class AxisColumnNames(TypedDict, total=False):
@@ -18,60 +25,119 @@ class AxisColumnNames(TypedDict, total=False):
     """var columns to use. All columns if ``None`` or not present."""


-class ExperimentAxisQuery(contextlib.AbstractContextManager, metaclass=abc.ABCMeta):
-    @abc.abstractmethod
+class ExperimentAxisQuery:
+    """Axis-based query against a SOMA Experiment. [lifecycle: experimental]
+
+    ExperimentAxisQuery allows easy selection and extraction of data from a
+    single soma.Measurement in a soma.Experiment, by obs/var (axis) coordinates
+    and/or value filter.
+
+    The primary use for this class is slicing Experiment ``X`` layers by obs or
+    var value and/or coordinates. Slicing on SparseNDArray ``X`` matrices is
+    supported; DenseNDArray is not supported at this time.
+
+    IMPORTANT: this class is not thread-safe.
+
+    IMPORTANT: this query class assumes it can store the full result of both
+    axis dataframe queries in memory, and only provides incremental access to
+    the underlying X NDArray. Properties such as ``n_obs`` and ``n_vars``
+    codify this assumption.
+
+    IMPORTANT: you must call ``close()`` on any instance of this class to
+    release underlying resources. The ExperimentAxisQuery is a context manager,
+    and it is recommended that you use the following pattern to make this easy
+    and safe::
+
+        with ExperimentAxisQuery(...) as query:
+            ...
+
+    This base query implementation is designed to work against any SOMA
+    implementation that fulfills the basic APIs. A SOMA implementation may
+    include a custom query implementation optimized for its own use.
+ """ + + def __init__( + self, + experiment: "composed.Experiment", + measurement_name: str, + *, + obs_query: axis.AxisQuery = axis.AxisQuery(), + var_query: axis.AxisQuery = axis.AxisQuery(), + ): + if measurement_name not in experiment.ms: + raise ValueError("Measurement does not exist in the experiment") + + self.experiment = experiment + self.measurement_name = measurement_name + + self._matrix_axis_query = _MatrixAxisQuery(obs=obs_query, var=var_query) + self._joinids = _JoinIDCache(self) + self._indexer = _AxisIndexer(self) + self._threadpool_: Optional[futures.ThreadPoolExecutor] = None + def obs( self, *, column_names: Optional[Sequence[str]] = None ) -> data.ReadIter[pa.Table]: - """Returns ``obs`` as an Arrow table iterator.""" - raise NotImplementedError() + """Returns ``obs`` as an Arrow table iterator. [lifecycle: experimental]""" + obs_query = self._matrix_axis_query.obs + return self._obs_df.read( + ids=obs_query.coords, + value_filter=obs_query.value_filter, + column_names=column_names, + ) - @abc.abstractmethod def var( self, *, column_names: Optional[Sequence[str]] = None ) -> data.ReadIter[pa.Table]: - """Returns ``var`` as an Arrow table iterator.""" - raise NotImplementedError() + """Returns ``var`` as an Arrow table iterator. [lifecycle: experimental]""" + var_query = self._matrix_axis_query.var + return self._var_df.read( + ids=var_query.coords, + value_filter=var_query.value_filter, + column_names=column_names, + ) - @abc.abstractmethod def obs_joinids(self) -> pa.Array: - """Returns ``obs`` ``soma_joinids`` as an Arrow array.""" - raise NotImplementedError() + """Returns ``obs`` ``soma_joinids`` as an Arrow array. [lifecycle: experimental]""" + return self._joinids.obs - @abc.abstractmethod def var_joinids(self) -> pa.Array: - """Returns ``var`` ``soma_joinids`` as an Arrow array.""" - raise NotImplementedError() + """Returns ``var`` ``soma_joinids`` as an Arrow array. [lifecycle: experimental]""" + return self._joinids.var @property def n_obs(self) -> int: - """The number of ``obs`` axis query results.""" + """The number of ``obs`` axis query results. [lifecycle: experimental]""" return len(self.obs_joinids()) @property def n_vars(self) -> int: - """The number of ``var`` axis query results.""" + """The number of ``var`` axis query results. [lifecycle: experimental]""" return len(self.var_joinids()) - @abc.abstractmethod def X(self, layer_name: str) -> data.SparseRead: - """Returns an ``X`` layer as ``SparseRead`` data. + """Returns an ``X`` layer as ``SparseRead`` data. [lifecycle: experimental] :param layer_name: The X layer name to return. """ - raise NotImplementedError() + try: + x_layer = self._ms.X[layer_name] + except KeyError as ke: + raise KeyError(f"{layer_name} is not present in X") from ke + if not isinstance(x_layer, data.SparseNDArray): + raise TypeError("X layers may only be sparse arrays") + + self._joinids.preload(self._threadpool) + return x_layer.read((self._joinids.obs, self._joinids.var)) - @abc.abstractmethod def obsp(self, layer: str) -> data.SparseRead: - """Return an ``obsp`` layer as a SparseNDArrayRead""" - raise NotImplementedError() + """Return an ``obsp`` layer as a sparse read. [lifecycle: experimental]""" + return self._axisp_inner(_Axis.OBS, layer) - @abc.abstractmethod def varp(self, layer: str) -> data.SparseRead: - """Return an ``varp`` layer as a SparseNDArrayRead""" - raise NotImplementedError() + """Return an ``varp`` layer as a sparse read. 
[lifecycle: experimental]""" + return self._axisp_inner(_Axis.VAR, layer) - @abc.abstractmethod def to_anndata( self, X_name: str, @@ -81,6 +147,7 @@ def to_anndata( ) -> anndata.AnnData: """ Execute the query and return result as an ``AnnData`` in-memory object. + [lifecycle: experimental] :param X_name: The name of the X layer to read and return in the ``X`` slot. @@ -89,17 +156,34 @@ def to_anndata( :param X_layers: Additional X layers to read and return in the ``layers`` slot. """ - raise NotImplementedError() + query_result = self._read( + X_name, + column_names=column_names or AxisColumnNames(obs=None, var=None), + X_layers=X_layers, + ) + + # AnnData uses positional indexing + return self._indexer.rewrite(query_result).to_anndata() # Context management - @abc.abstractmethod def close(self) -> None: - """Releases resources associated with this query. + """Releases resources associated with this query. [lifecycle: experimental] This method must be idempotent. """ - raise NotImplementedError() + # Because this may be called during `__del__` when we might be getting + # disassembled, sometimes `_threadpool_` is simply missing. + # Only try to shut it down if it still exists. + pool = getattr(self, "_threadpool_", None) + if pool is None: + return + pool.shutdown() + self._threadpool_ = None + + # TODO: This should be "Self" once mypy supports that. + def __enter__(self) -> "ExperimentAxisQuery": + return self def __exit__(self, *_: Any) -> None: self.close() @@ -110,3 +194,377 @@ def __del__(self) -> None: sdel = getattr(super(), "__del__", lambda: None) sdel() self.close() + + # Internals + + def _read( + self, + X_name: str, + *, + column_names: AxisColumnNames, + X_layers: Sequence[str], + ) -> "_AxisQueryResult": + """Reads the entire query result into in-memory Arrow tables. + + This is a low-level routine intended to be used by loaders for other + in-core formats, such as AnnData, which can be created from the + resulting Tables. + + :param X_name: The name of the X layer to read and return + in the ``AnnData.X`` slot + :param column_names: Specify which column names in ``var`` and ``obs`` + dataframes to read and return. 
+        :param X_layers: Additional X layers to read and return in the
+            ``AnnData.layers`` slot
+
+        """
+        x_collection = self._ms.X
+        all_x_names = [X_name] + list(X_layers)
+        all_x_arrays: Dict[str, data.SparseNDArray] = {}
+        for _xname in all_x_names:
+            if not isinstance(_xname, str) or not _xname:
+                raise ValueError("X layer names must be specified as a string.")
+            if _xname not in x_collection:
+                raise ValueError("Unknown X layer name")
+            x_array = x_collection[_xname]
+            if not isinstance(x_array, data.SparseNDArray):
+                raise NotImplementedError("Dense array unsupported")
+            all_x_arrays[_xname] = x_array
+
+        obs_table, var_table = self._read_both_axes(column_names)
+
+        x_tables = {
+            # TODO: could also be done concurrently
+            _xname: all_x_arrays[_xname]
+            .read((self.obs_joinids(), self.var_joinids()))
+            .tables()
+            .concat()
+            for _xname in all_x_arrays
+        }
+
+        x = x_tables.pop(X_name)
+        return _AxisQueryResult(obs=obs_table, var=var_table, X=x, X_layers=x_tables)
+
+    def _read_both_axes(
+        self,
+        column_names: AxisColumnNames,
+    ) -> Tuple[pa.Table, pa.Table]:
+        """Reads both axes in their entirety, ensuring soma_joinid is retained."""
+        obs_ft = self._threadpool.submit(
+            self._read_axis_dataframe,
+            _Axis.OBS,
+            column_names,
+        )
+        var_ft = self._threadpool.submit(
+            self._read_axis_dataframe,
+            _Axis.VAR,
+            column_names,
+        )
+        return obs_ft.result(), var_ft.result()
+
+    def _read_axis_dataframe(
+        self,
+        axis: "_Axis",
+        axis_column_names: AxisColumnNames,
+    ) -> pa.Table:
+        """Reads the specified axis. Will cache join IDs if not present."""
+        # mypy is not currently clever enough to figure out the type of the
+        # column names here, so we have to help it out.
+        column_names: Optional[Sequence[str]] = axis_column_names.get(axis.value)
+        if axis is _Axis.OBS:
+            axis_df = self._obs_df
+            axis_query = self._matrix_axis_query.obs
+        elif axis is _Axis.VAR:
+            axis_df = self._var_df
+            axis_query = self._matrix_axis_query.var
+        else:
+            assert_never(axis)  # must be obs or var
+
+        # If we can cache join IDs, prepare to add them to the cache.
+        joinids_cached = self._joinids._is_cached(axis)
+        query_columns = column_names
+        if (
+            not joinids_cached
+            and column_names is not None
+            and "soma_joinid" not in column_names
+        ):
+            # If we want to fill the join ID cache, ensure that we query the
+            # soma_joinid column so that it is included in the results.
+            # We'll filter it out later.
+            query_columns = ["soma_joinid"] + list(column_names)
+
+        # Do the actual query.
+        arrow_table = axis_df.read(
+            ids=axis_query.coords,
+            value_filter=axis_query.value_filter,
+            column_names=query_columns,
+        ).concat()
+
+        # Update the cache if needed. We can do this because no matter what
+        # other columns are queried for, the contents of the `soma_joinid`
+        # column will be the same and can be safely stored.
+        if not joinids_cached:
+            setattr(
+                self._joinids,
+                axis.value,
+                arrow_table.column("soma_joinid").combine_chunks(),
+            )
+
+        # Ensure that we return the exact columns the caller was expecting,
+        # even if we added our own above.
+        if column_names is not None:
+            arrow_table = arrow_table.select(column_names)
+        return arrow_table
+
+    def _axisp_inner(
+        self,
+        axis: "_Axis",
+        layer: str,
+    ) -> data.SparseRead:
+        key = axis.value + "p"
+
+        if key not in self._ms:
+            raise ValueError(f"Measurement does not contain {key} data")
+
+        axisp = self._ms.obsp if axis is _Axis.OBS else self._ms.varp
+        if not (layer and layer in axisp):
+            raise ValueError(f"Must specify '{key}' layer")
+        if not isinstance(axisp[layer], data.SparseNDArray):
+            raise TypeError(f"Unexpected SOMA type stored in '{key}' layer")
+
+        joinids = getattr(self._joinids, axis.value)
+        return axisp[layer].read((joinids, joinids))
+
+    @property
+    def _obs_df(self) -> data.DataFrame:
+        return self.experiment.obs
+
+    @property
+    def _ms(self) -> composed.Measurement:
+        return self.experiment.ms[self.measurement_name]
+
+    @property
+    def _var_df(self) -> data.DataFrame:
+        return self._ms.var
+
+    @property
+    def _threadpool(self) -> futures.ThreadPoolExecutor:
+        """Creates a thread pool just in time."""
+        if self._threadpool_ is None:
+            # TODO: the user should be able to set their own threadpool, a la
+            # asyncio's loop.set_default_executor(). This is important for
+            # managing the level of concurrency, etc.
+            self._threadpool_ = futures.ThreadPoolExecutor()
+        return self._threadpool_
+
+
+# Private internal data structures
+
+
+@attrs.define(frozen=True)
+class _AxisQueryResult:
+    """Return type for the ExperimentAxisQuery._read() method"""
+
+    obs: pa.Table
+    """Experiment.obs query slice, as an Arrow Table"""
+    var: pa.Table
+    """Experiment.ms[...].var query slice, as an Arrow Table"""
+    X: pa.Table
+    """Experiment.ms[...].X[...] query slice, as an Arrow Table"""
+    X_layers: Dict[str, pa.Table] = attrs.field(factory=dict)
+    """Any additional X layers requested, as Arrow Table(s)"""
+
+    def to_anndata(self) -> anndata.AnnData:
+        """Convert to AnnData"""
+        obs = self.obs.to_pandas()
+        obs.index = obs.index.map(str)
+
+        var = self.var.to_pandas()
+        var.index = var.index.map(str)
+
+        shape = (len(obs), len(var))
+
+        x = self.X
+        if x is not None:
+            x = _arrow_to_scipy_csr(x, shape)
+
+        layers = {
+            name: _arrow_to_scipy_csr(table, shape)
+            for name, table in self.X_layers.items()
+        }
+        return anndata.AnnData(X=x, obs=obs, var=var, layers=(layers or None))
+
+
+class _Axis(enum.Enum):
+    OBS = "obs"
+    VAR = "var"
+
+    @property
+    def value(self) -> Literal["obs", "var"]:
+        return super().value
+
+
+@attrs.define(frozen=True)
+class _MatrixAxisQuery:
+    """Private: store per-axis user query definition"""
+
+    obs: axis.AxisQuery
+    var: axis.AxisQuery
+
+
+@attrs.define
+class _JoinIDCache:
+    """Private: cache per-axis join ids in the query"""
+
+    owner: ExperimentAxisQuery
+
+    _cached_obs: Optional[pa.Array] = None
+    _cached_var: Optional[pa.Array] = None
+
+    def _is_cached(self, axis: _Axis) -> bool:
+        field = "_cached_" + axis.value
+        return getattr(self, field) is not None
+
+    def preload(self, pool: futures.ThreadPoolExecutor) -> None:
+        if self._cached_obs is not None and self._cached_var is not None:
+            return
+        obs_ft = pool.submit(lambda: self.obs)
+        var_ft = pool.submit(lambda: self.var)
+        # Wait for them and raise in case of error.
+        obs_ft.result()
+        var_ft.result()
+
+    @property
+    def obs(self) -> pa.Array:
+        """Join IDs for the obs axis. Will load and cache if not already."""
+        if not self._cached_obs:
+            self._cached_obs = _load_joinids(
+                self.owner._obs_df, self.owner._matrix_axis_query.obs
+            )
+        return self._cached_obs
+
+    @obs.setter
+    def obs(self, val: pa.Array) -> None:
+        self._cached_obs = val
+
+    @property
+    def var(self) -> pa.Array:
+        """Join IDs for the var axis. Will load and cache if not already."""
+        if not self._cached_var:
+            self._cached_var = _load_joinids(
+                self.owner._var_df, self.owner._matrix_axis_query.var
+            )
+        return self._cached_var
+
+    @var.setter
+    def var(self, val: pa.Array) -> None:
+        self._cached_var = val
+
+
+def _load_joinids(df: data.DataFrame, axq: axis.AxisQuery) -> pa.Array:
+    tbl = df.read(
+        ids=axq.coords,
+        value_filter=axq.value_filter,
+        column_names=["soma_joinid"],
+    ).concat()
+    return tbl.column("soma_joinid").combine_chunks()
+
+
+_Numpyable = Union[pa.Array, pa.ChunkedArray, npt.NDArray[np.int64]]
+"""Things that can be converted to a NumPy array."""
+
+
+@attrs.define
+class _AxisIndexer:
+    """Given a query, provides index-building services for the obs/var axes."""
+
+    query: ExperimentAxisQuery
+    _cached_obs: Optional[pd.Index] = None
+    _cached_var: Optional[pd.Index] = None
+
+    @property
+    def _obs_index(self) -> pd.Index:
+        if self._cached_obs is None:
+            self._cached_obs = pd.Index(data=self.query.obs_joinids().to_numpy())
+        return self._cached_obs
+
+    @property
+    def _var_index(self) -> pd.Index:
+        if self._cached_var is None:
+            self._cached_var = pd.Index(data=self.query.var_joinids().to_numpy())
+        return self._cached_var
+
+    def by_obs(self, coords: _Numpyable) -> npt.NDArray[np.intp]:
+        return self._obs_index.get_indexer(_to_numpy(coords))
+
+    def by_var(self, coords: _Numpyable) -> npt.NDArray[np.intp]:
+        return self._var_index.get_indexer(_to_numpy(coords))
+
+    def rewrite(self, qr: _AxisQueryResult) -> _AxisQueryResult:
+        """Rewrite the result to prepare for AnnData positional indexing."""
+        return attrs.evolve(
+            qr,
+            X=self._rewrite_matrix(qr.X),
+            X_layers={
+                name: self._rewrite_matrix(matrix)
+                for name, matrix in qr.X_layers.items()
+            },
+        )
+
+    def _rewrite_matrix(self, x_table: pa.Table) -> pa.Table:
+        """
+        Private convenience function to convert axis dataframe-to-X matrix joins
+        from ``soma_joinid``-based joins to positionally indexed joins
+        (like AnnData uses).
+
+        Input is organized as:
+            obs[i] annotates X[ obs[i].soma_joinid, : ]
+        and
+            var[j] annotates X[ :, var[j].soma_joinid ]
+
+        Output is organized as:
+            obs[i] annotates X[i, :]
+        and
+            var[j] annotates X[:, j]
+
+        In addition, the ``soma_joinid`` column is dropped from axis dataframes.
+        """
+
+        return pa.Table.from_arrays(
+            (
+                self.by_obs(x_table["soma_dim_0"]),
+                self.by_var(x_table["soma_dim_1"]),
+                # This consolidates chunks as a side effect.
+                x_table["soma_data"].to_numpy(),
+            ),
+            names=("_dim_0", "_dim_1", "soma_data"),
+        )
+
+
+def _to_numpy(it: _Numpyable) -> np.ndarray:
+    if isinstance(it, np.ndarray):
+        return it
+    return it.to_numpy()
+
+
+def _arrow_to_scipy_csr(
+    arrow_table: pa.Table, shape: Tuple[int, int]
+) -> sparse.csr_matrix:
+    """
+    Private utility which converts a table representation of X to a CSR matrix.
+
+    IMPORTANT: by convention, assumes that the data is positionally indexed
+    (hence the use of ``_dim_{n}`` rather than ``soma_dim_{n}``).
+
+    See :meth:`_AxisIndexer._rewrite_matrix` for more info.
+ """ + assert "_dim_0" in arrow_table.column_names, "X must be positionally indexed" + assert "_dim_1" in arrow_table.column_names, "X must be positionally indexed" + + return sparse.csr_matrix( + ( + arrow_table["soma_data"].to_numpy(), + (arrow_table["_dim_0"].to_numpy(), arrow_table["_dim_1"].to_numpy()), + ), + shape=shape, + )