Skip to content

Commit

Permalink
[python] Re-enable tiledbsoma.ExperimentAxisQuery (#3476) (#3479)
Browse files Browse the repository at this point in the history
* [python] Re-enable `tiledbsoma.ExperimentAxisQuery`

* Update unit-test case to use the user-facing API

* code-review feedback

* `somacore.Axis` and `tiledbsoma._query.Axis` are not the same thing

* clarify a confusing internal name

Co-authored-by: John Kerl <[email protected]>
  • Loading branch information
github-actions[bot] and johnkerl authored Dec 19, 2024
1 parent 6145ed2 commit 241dccf
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 34 deletions.
11 changes: 6 additions & 5 deletions apis/python/src/tiledbsoma/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,16 +133,17 @@
except OSError:
# Otherwise try loading by name only.
ctypes.CDLL(libtiledbsoma_name)

from somacore import (
AffineTransform,
Axis,
AxisColumnNames,
AxisQuery,
CoordinateSpace,
AffineTransform,
ScaleTransform,
IdentityTransform,
ScaleTransform,
UniformScaleTransform,
AxisColumnNames,
AxisQuery,
)
from ._query import (
ExperimentAxisQuery,
)
from somacore.options import ResultOrder
Expand Down
38 changes: 19 additions & 19 deletions apis/python/src/tiledbsoma/_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def obs(self) -> _T_co: ...
def var(self) -> _T_co: ...


class Axis(enum.Enum):
class AxisName(enum.Enum):
OBS = "obs"
VAR = "var"

Expand Down Expand Up @@ -376,29 +376,29 @@ def obsp(self, layer: str) -> SparseRead:
Lifecycle: maturing
"""
joinids = self._joinids.obs
return self._axisp_get_array(Axis.OBS, layer).read((joinids, joinids))
return self._axisp_get_array(AxisName.OBS, layer).read((joinids, joinids))

def varp(self, layer: str) -> SparseRead:
"""Returns a ``varp`` layer as a sparse read.
Lifecycle: maturing
"""
joinids = self._joinids.var
return self._axisp_get_array(Axis.VAR, layer).read((joinids, joinids))
return self._axisp_get_array(AxisName.VAR, layer).read((joinids, joinids))

def obsm(self, layer: str) -> SparseRead:
"""Returns an ``obsm`` layer as a sparse read.
Lifecycle: maturing
"""
return self._axism_get_array(Axis.OBS, layer).read(
return self._axism_get_array(AxisName.OBS, layer).read(
(self._joinids.obs, slice(None))
)

def varm(self, layer: str) -> SparseRead:
"""Returns a ``varm`` layer as a sparse read.
Lifecycle: maturing
"""
return self._axism_get_array(Axis.VAR, layer).read(
return self._axism_get_array(AxisName.VAR, layer).read(
(self._joinids.var, slice(None))
)

Expand All @@ -421,7 +421,7 @@ def obs_scene_ids(self) -> pa.Array:
)

full_table = obs_scene.read(
coords=((Axis.OBS.getattr_from(self._joinids), slice(None))),
coords=((AxisName.OBS.getattr_from(self._joinids), slice(None))),
result_order=ResultOrder.COLUMN_MAJOR,
value_filter="data != 0",
).concat()
Expand All @@ -448,7 +448,7 @@ def var_scene_ids(self) -> pa.Array:
)

full_table = var_scene.read(
coords=((Axis.VAR.getattr_from(self._joinids), slice(None))),
coords=((AxisName.VAR.getattr_from(self._joinids), slice(None))),
result_order=ResultOrder.COLUMN_MAJOR,
value_filter="data != 0",
).concat()
Expand Down Expand Up @@ -625,7 +625,7 @@ def _read(

obs_table, var_table = tp.map(
self._read_axis_dataframe,
(Axis.OBS, Axis.VAR),
(AxisName.OBS, AxisName.VAR),
(column_names, column_names),
)
obs_joinids = self.obs_joinids()
Expand All @@ -645,19 +645,19 @@ def _read(
x_future = x_matrices.pop(X_name)

obsm_future = {
key: tp.submit(self._axism_inner_ndarray, Axis.OBS, key)
key: tp.submit(self._axism_inner_ndarray, AxisName.OBS, key)
for key in obsm_layers
}
varm_future = {
key: tp.submit(self._axism_inner_ndarray, Axis.VAR, key)
key: tp.submit(self._axism_inner_ndarray, AxisName.VAR, key)
for key in varm_layers
}
obsp_future = {
key: tp.submit(self._axisp_inner_sparray, Axis.OBS, key)
key: tp.submit(self._axisp_inner_sparray, AxisName.OBS, key)
for key in obsp_layers
}
varp_future = {
key: tp.submit(self._axisp_inner_sparray, Axis.VAR, key)
key: tp.submit(self._axisp_inner_sparray, AxisName.VAR, key)
for key in varp_layers
}

Expand All @@ -680,7 +680,7 @@ def _read(

def _read_axis_dataframe(
self,
axis: Axis,
axis: AxisName,
axis_column_names: AxisColumnNames,
) -> pa.Table:
"""Reads the specified axis. Will cache join IDs if not present."""
Expand Down Expand Up @@ -730,7 +730,7 @@ def _read_axis_dataframe(

def _axisp_get_array(
self,
axis: Axis,
axis: AxisName,
layer: str,
) -> SparseNDArray:
p_name = f"{axis.value}p"
Expand All @@ -754,7 +754,7 @@ def _axisp_get_array(

def _axism_get_array(
self,
axis: Axis,
axis: AxisName,
layer: str,
) -> SparseNDArray:
m_name = f"{axis.value}m"
Expand All @@ -776,7 +776,7 @@ def _axism_get_array(
return axism_layer

def _convert_to_ndarray(
self, axis: Axis, table: pa.Table, n_row: int, n_col: int
self, axis: AxisName, table: pa.Table, n_row: int, n_col: int
) -> npt.NDArray[np.float32]:
indexer = cast(
Callable[[Numpyable], npt.NDArray[np.intp]],
Expand All @@ -789,7 +789,7 @@ def _convert_to_ndarray(

def _axisp_inner_sparray(
self,
axis: Axis,
axis: AxisName,
layer: str,
) -> sp.csr_matrix:
joinids = axis.getattr_from(self._joinids)
Expand All @@ -803,7 +803,7 @@ def _axisp_inner_sparray(

def _axism_inner_ndarray(
self,
axis: Axis,
axis: AxisName,
layer: str,
) -> npt.NDArray[np.float32]:
joinids = axis.getattr_from(self._joinids)
Expand Down Expand Up @@ -856,7 +856,7 @@ class JoinIDCache:
_cached_obs: pa.IntegerArray | None = None
_cached_var: pa.IntegerArray | None = None

def _is_cached(self, axis: Axis) -> bool:
def _is_cached(self, axis: AxisName) -> bool:
field = "_cached_" + axis.value
return getattr(self, field) is not None

Expand Down
25 changes: 15 additions & 10 deletions apis/python/tests/test_experiment_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,15 @@
from somacore import AxisQuery, options

import tiledbsoma as soma
from tiledbsoma import SOMATileDBContext, _factory, pytiledbsoma
from tiledbsoma import (
ExperimentAxisQuery,
SOMATileDBContext,
_factory,
pytiledbsoma,
)
from tiledbsoma._collection import CollectionBase
from tiledbsoma._experiment import Experiment
from tiledbsoma._query import Axis, ExperimentAxisQuery
from tiledbsoma._query import AxisName
from tiledbsoma.experiment_query import X_as_series

from tests._util import raises_no_typeguard
Expand Down Expand Up @@ -944,12 +949,12 @@ class IHaveObsVarStuff:

def test_axis_helpers() -> None:
thing = IHaveObsVarStuff(obs=1, var=2, the_obs_suf="observe", the_var_suf="vary")
assert 1 == Axis.OBS.getattr_from(thing)
assert 2 == Axis.VAR.getattr_from(thing)
assert "observe" == Axis.OBS.getattr_from(thing, pre="the_", suf="_suf")
assert "vary" == Axis.VAR.getattr_from(thing, pre="the_", suf="_suf")
assert 1 == AxisName.OBS.getattr_from(thing)
assert 2 == AxisName.VAR.getattr_from(thing)
assert "observe" == AxisName.OBS.getattr_from(thing, pre="the_", suf="_suf")
assert "vary" == AxisName.VAR.getattr_from(thing, pre="the_", suf="_suf")
ovdict = {"obs": "erve", "var": "y", "i_obscure": "hide", "i_varcure": "???"}
assert "erve" == Axis.OBS.getitem_from(ovdict)
assert "y" == Axis.VAR.getitem_from(ovdict)
assert "hide" == Axis.OBS.getitem_from(ovdict, pre="i_", suf="cure")
assert "???" == Axis.VAR.getitem_from(ovdict, pre="i_", suf="cure")
assert "erve" == AxisName.OBS.getitem_from(ovdict)
assert "y" == AxisName.VAR.getitem_from(ovdict)
assert "hide" == AxisName.OBS.getitem_from(ovdict, pre="i_", suf="cure")
assert "???" == AxisName.VAR.getitem_from(ovdict, pre="i_", suf="cure")

0 comments on commit 241dccf

Please sign in to comment.