diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 975d373a..28c5f8e5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -33,11 +33,25 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, macos-latest] - python-version: ["3.9", "3.10", "3.11", "3.12"] + python-version: ["3.10", "3.11"] + + test-array-libs: + uses: pyapp-kit/workflows/.github/workflows/test-pyrepo.yml@v2 + with: + os: ${{ matrix.os }} + python-version: ${{ matrix.python-version }} + extras: "test,third_party_arrays" + coverage-upload: artifact + qt: pyqt6 + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest] + python-version: ["3.9", "3.12"] upload_coverage: if: always() - needs: [test] + needs: [test, test-array-libs] uses: pyapp-kit/workflows/.github/workflows/upload-coverage.yml@v2 secrets: codecov_token: ${{ secrets.CODECOV_TOKEN }} diff --git a/README.md b/README.md index 48e36161..86180248 100644 --- a/README.md +++ b/README.md @@ -9,14 +9,12 @@ Simple, fast-loading, asynchronous, n-dimensional viewer for Qt, with minimal dependencies. ```python -from qtpy import QtWidgets -from ndv import NDViewer -from skimage import data # just for example data here - -qapp = QtWidgets.QApplication([]) -v = NDViewer(data.cells3d()) -v.show() -qapp.exec() +import ndv + +data = ndv.data.cells3d() +# or ndv.data.nd_sine_wave() +# or *any* arraylike object (see support below) +ndv.imshow(data) ``` ![Montage](https://github.com/pyapp-kit/ndv/assets/1609449/712861f7-ddcb-4ecd-9a4c-ba5f0cc1ee2c) @@ -27,12 +25,22 @@ qapp.exec() - sliders support integer as well as slice (range)-based slicing - colormaps provided by [cmap](https://github.com/tlambert03/cmap) - supports [vispy](https://github.com/vispy/vispy) and [pygfx](https://github.com/pygfx/pygfx) backends -- supports any numpy-like duck arrays, with special support for features in: - - `xarray.DataArray` +- supports any numpy-like duck arrays, including (but not limited to): + - `numpy.ndarray` + - `cupy.ndarray` - `dask.array.Array` - - `tensorstore.TensorStore` - - `zarr` - - `dask` + - `jax.Array` + - `pyopencl.array.Array` + - `sparse.COO` + - `tensorstore.TensorStore` (supports named dimensions) + - `torch.Tensor` (supports named dimensions) + - `xarray.DataArray` (supports named dimensions) + - `zarr` (supports named dimensions) +- You can add support for your own storage class by subclassing `ndv.DataWrapper` + and implementing a couple methods. (This doesn't require modifying ndv, + but contributions of new wrappers are welcome!) + +See examples for each of these array types in [examples](./examples/) ## Installation diff --git a/examples/custom_store.py b/examples/custom_store.py new file mode 100644 index 00000000..9d3fbff6 --- /dev/null +++ b/examples/custom_store.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +import numpy as np + +import ndv + +if TYPE_CHECKING: + from ndv import Indices, Sizes + + +class MyArrayThing: + def __init__(self, shape: tuple[int, ...]) -> None: + self.shape = shape + self._data = np.random.randint(0, 256, shape) + + def __getitem__(self, item: Any) -> np.ndarray: + return self._data[item] # type: ignore [no-any-return] + + +class MyWrapper(ndv.DataWrapper[MyArrayThing]): + @classmethod + def supports(cls, data: Any) -> bool: + if isinstance(data, MyArrayThing): + return True + return False + + def sizes(self) -> Sizes: + """Return a mapping of {dim: size} for the data""" + return {f"dim_{k}": v for k, v in enumerate(self.data.shape)} + + def isel(self, indexers: Indices) -> Any: + """Convert mapping of {dim: index} to conventional indexing""" + idx = tuple(indexers.get(k, slice(None)) for k in range(len(self.data.shape))) + return self.data[idx] + + +data = MyArrayThing((10, 3, 512, 512)) +ndv.imshow(data) diff --git a/examples/dask_arr.py b/examples/dask_arr.py index a9c514e1..402f406c 100644 --- a/examples/dask_arr.py +++ b/examples/dask_arr.py @@ -6,6 +6,7 @@ from dask.array.core import map_blocks except ImportError: raise ImportError("Please `pip install dask[array]` to run this example.") +import ndv frame_size = (1024, 1024) @@ -21,12 +22,4 @@ def _dask_block(block_id: tuple[int, int, int, int, int]) -> np.ndarray | None: chunks += [(x,) for x in frame_size] dask_arr = map_blocks(_dask_block, chunks=chunks, dtype=np.uint8) -if __name__ == "__main__": - from qtpy import QtWidgets - - from ndv import NDViewer - - qapp = QtWidgets.QApplication([]) - v = NDViewer(dask_arr) - v.show() - qapp.exec() +v = ndv.imshow(dask_arr) diff --git a/examples/jax_arr.py b/examples/jax_arr.py index ed0e3208..dc803622 100644 --- a/examples/jax_arr.py +++ b/examples/jax_arr.py @@ -4,17 +4,7 @@ import jax.numpy as jnp except ImportError: raise ImportError("Please install jax to run this example") -from numpy_arr import generate_5d_sine_wave -from qtpy import QtWidgets +import ndv -from ndv import NDViewer - -# Example usage -array_shape = (10, 3, 5, 512, 512) # Specify the desired dimensions -sine_wave_5d = jnp.asarray(generate_5d_sine_wave(array_shape)) - -if __name__ == "__main__": - qapp = QtWidgets.QApplication([]) - v = NDViewer(sine_wave_5d, channel_axis=1) - v.show() - qapp.exec() +jax_arr = jnp.asarray(ndv.data.nd_sine_wave()) +v = ndv.imshow(jax_arr) diff --git a/examples/numpy_arr.py b/examples/numpy_arr.py index d9b0fb86..7a4ee67a 100644 --- a/examples/numpy_arr.py +++ b/examples/numpy_arr.py @@ -1,64 +1,11 @@ from __future__ import annotations -import numpy as np - - -def generate_5d_sine_wave( - shape: tuple[int, int, int, int, int], - amplitude: float = 240, - base_frequency: float = 5, -) -> np.ndarray: - """5D dataset.""" - # Unpack the dimensions - angle_dim, freq_dim, phase_dim, ny, nx = shape - - # Create an empty array to hold the data - output = np.zeros(shape) - - # Define spatial coordinates for the last two dimensions - half_per = base_frequency * np.pi - x = np.linspace(-half_per, half_per, nx) - y = np.linspace(-half_per, half_per, ny) - y, x = np.meshgrid(y, x) - - # Iterate through each parameter in the higher dimensions - for phase_idx in range(phase_dim): - for freq_idx in range(freq_dim): - for angle_idx in range(angle_dim): - # Calculate phase and frequency - phase = np.pi / phase_dim * phase_idx - frequency = 1 + (freq_idx * 0.1) # Increasing frequency with each step - - # Calculate angle - angle = np.pi / angle_dim * angle_idx - # Rotate x and y coordinates - xr = np.cos(angle) * x - np.sin(angle) * y - np.sin(angle) * x + np.cos(angle) * y - - # Compute the sine wave - sine_wave = (amplitude * 0.5) * np.sin(frequency * xr + phase) - sine_wave += amplitude * 0.5 - - # Assign to the output array - output[angle_idx, freq_idx, phase_idx] = sine_wave - - return output - +import ndv try: - from skimage import data - - img = data.cells3d() -except Exception: - img = generate_5d_sine_wave((10, 3, 8, 512, 512)) - - -if __name__ == "__main__": - from qtpy import QtWidgets - - from ndv import NDViewer + img = ndv.data.cells3d() +except Exception as e: + print(e) + img = ndv.data.nd_sine_wave((10, 3, 8, 512, 512)) - qapp = QtWidgets.QApplication([]) - v = NDViewer(img) - v.show() - qapp.exec() +ndv.imshow(img) diff --git a/examples/pyopencl_arr.py b/examples/pyopencl_arr.py new file mode 100644 index 00000000..149e9f4c --- /dev/null +++ b/examples/pyopencl_arr.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +try: + import pyopencl as cl + import pyopencl.array as cl_array +except ImportError: + raise ImportError("Please install pyopencl to run this example") +import ndv + +# Set up OpenCL context and queue +context = cl.create_some_context(interactive=False) +queue = cl.CommandQueue(context) + + +gpu_data = cl_array.to_device(queue, ndv.data.nd_sine_wave()) + +ndv.imshow(gpu_data) diff --git a/examples/sparse_arr.py b/examples/sparse_arr.py new file mode 100644 index 00000000..9739800a --- /dev/null +++ b/examples/sparse_arr.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +try: + import sparse +except ImportError: + raise ImportError("Please install sparse to run this example") + +import numpy as np + +import ndv + +shape = (256, 4, 512, 512) +N = int(np.prod(shape) * 0.001) +coords = np.random.randint(low=0, high=shape, size=(N, len(shape))).T +data = np.random.randint(0, 256, N) + + +# Create the sparse array from the coordinates and data +sparse_array = sparse.COO(coords, data, shape=shape) + +ndv.imshow(sparse_array) diff --git a/examples/tensorstore_arr.py b/examples/tensorstore_arr.py index 9ac30a90..fdb9505d 100644 --- a/examples/tensorstore_arr.py +++ b/examples/tensorstore_arr.py @@ -1,23 +1,28 @@ from __future__ import annotations -import numpy as np -import tensorstore as ts -from qtpy import QtWidgets +try: + import tensorstore as ts +except ImportError: + raise ImportError("Please install tensorstore to run this example") -from ndv import NDViewer -shape = (10, 4, 3, 512, 512) +import ndv + +data = ndv.data.cells3d() + ts_array = ts.open( - {"driver": "zarr", "kvstore": {"driver": "memory"}}, + { + "driver": "zarr", + "kvstore": {"driver": "memory"}, + "transform": { + # tensorstore supports labeled dimensions + "input_labels": ["z", "c", "y", "x"], + }, + }, create=True, - shape=shape, - dtype=ts.uint8, + shape=data.shape, + dtype=data.dtype, ).result() -ts_array[:] = np.random.randint(0, 255, size=shape, dtype=np.uint8) -ts_array = ts_array[ts.d[:].label["t", "c", "z", "y", "x"]] +ts_array[:] = ndv.data.cells3d() -if __name__ == "__main__": - qapp = QtWidgets.QApplication([]) - v = NDViewer(ts_array) - v.show() - qapp.exec() +ndv.imshow(ts_array) diff --git a/examples/torch_arr.py b/examples/torch_arr.py new file mode 100644 index 00000000..0f7c9174 --- /dev/null +++ b/examples/torch_arr.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +try: + import torch +except ImportError: + raise ImportError("Please install torch to run this example") + +import warnings + +import ndv + +warnings.filterwarnings("ignore", "Named tensors") # Named tensors are experimental + +# Example usage +try: + torch_data = torch.tensor(ndv.data.nd_sine_wave(), names=("t", "c", "z", "y", "x")) +except TypeError: + print("Named tensors are not supported in your version of PyTorch") + torch_data = torch.tensor(ndv.data.nd_sine_wave()) + +ndv.imshow(torch_data) diff --git a/examples/xarray_arr.py b/examples/xarray_arr.py index 05eaac09..b19e139f 100644 --- a/examples/xarray_arr.py +++ b/examples/xarray_arr.py @@ -1,14 +1,10 @@ from __future__ import annotations -import xarray as xr -from qtpy import QtWidgets - -from ndv import NDViewer +try: + import xarray as xr +except ImportError: + raise ImportError("Please install xarray to run this example") +import ndv da = xr.tutorial.open_dataset("air_temperature").air - -if __name__ == "__main__": - qapp = QtWidgets.QApplication([]) - v = NDViewer(da, colormaps=["thermal"], channel_mode="composite") - v.show() - qapp.exec() +ndv.imshow(da, cmap="thermal") diff --git a/examples/zarr_arr.py b/examples/zarr_arr.py index ab31d9c8..e3a759bc 100644 --- a/examples/zarr_arr.py +++ b/examples/zarr_arr.py @@ -1,16 +1,15 @@ from __future__ import annotations -import zarr -import zarr.storage -from qtpy import QtWidgets +import ndv + +try: + import zarr + import zarr.storage +except ImportError: + raise ImportError("Please `pip install zarr aiohttp` to run this example") -from ndv import NDViewer URL = "https://s3.embl.de/i2k-2020/ngff-example-data/v0.4/tczyx.ome.zarr" zarr_arr = zarr.open(URL, mode="r") -if __name__ == "__main__": - qapp = QtWidgets.QApplication([]) - v = NDViewer(zarr_arr["s0"]) - v.show() - qapp.exec() +ndv.imshow(zarr_arr["s0"]) diff --git a/pyproject.toml b/pyproject.toml index 9ba6bf47..a2f179d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,15 +40,26 @@ pyqt = ["pyqt6"] vispy = ["vispy", "pyopengl"] pyside = ["pyside6"] pygfx = ["pygfx"] -test = ["pytest", "pytest-cov", "pytest-qt", "dask", "vispy", "pyopengl"] -dev = [ - "ipython", - "mypy", - "pdbpp", # https://github.com/pdbpp/pdbpp - "pre-commit", - "rich", # https://github.com/Textualize/rich - "ruff", +third_party_arrays = [ + "aiohttp", # for zarr example + "jax[cpu]", + "pooch", # for xarray example + "pyopencl[pocl]; sys_platform == 'linux'", + "sparse", + "tensorstore", + "torch", + "xarray", + "zarr", ] +test = [ + "ndv[vispy]", + "dask[array]", + "imageio[tifffile]", + "pytest-cov", + "pytest-qt", + "pytest", +] +dev = ["ipython", "mypy", "pdbpp", "pre-commit", "rich", "ruff"] [project.urls] homepage = "https://github.com/pyapp-kit/ndv" @@ -103,9 +114,11 @@ pretty = true # https://docs.pytest.org/ [tool.pytest.ini_options] +addopts = ["-v"] minversion = "7.0" testpaths = ["tests"] filterwarnings = ["error"] +markers = ["allow_leaks: mark test to allow widget leaks"] # https://coverage.readthedocs.io/ [tool.coverage.report] diff --git a/src/ndv/__init__.py b/src/ndv/__init__.py index 305e64e7..7faa6ea5 100644 --- a/src/ndv/__init__.py +++ b/src/ndv/__init__.py @@ -9,7 +9,21 @@ __author__ = "Talley Lambert" __email__ = "talley.lambert@example.com" -from .viewer._indexing import DataWrapper -from .viewer._stack_viewer import NDViewer +from typing import TYPE_CHECKING -__all__ = ["NDViewer", "DataWrapper"] +from . import data +from .util import imshow +from .viewer._data_wrapper import DataWrapper +from .viewer._viewer import NDViewer + +__all__ = ["NDViewer", "DataWrapper", "imshow", "data"] + + +if TYPE_CHECKING: + # these may be used externally, but are not guaranteed to be available at runtime + # they must be used inside a TYPE_CHECKING block + + from .viewer._dims_slider import DimKey as DimKey + from .viewer._dims_slider import Index as Index + from .viewer._dims_slider import Indices as Indices + from .viewer._dims_slider import Sizes as Sizes diff --git a/src/ndv/data.py b/src/ndv/data.py new file mode 100644 index 00000000..9f7a24e5 --- /dev/null +++ b/src/ndv/data.py @@ -0,0 +1,62 @@ +"""Sample data for testing and examples.""" + +import numpy as np + +__all__ = ["nd_sine_wave", "cells3d"] + + +def nd_sine_wave( + shape: tuple[int, int, int, int, int] = (10, 3, 5, 512, 512), + amplitude: float = 240, + base_frequency: float = 5, +) -> np.ndarray: + """5D dataset.""" + # Unpack the dimensions + if not len(shape) == 5: + raise ValueError("Shape must have 5 dimensions") + angle_dim, freq_dim, phase_dim, ny, nx = shape + + # Create an empty array to hold the data + output = np.zeros(shape) + + # Define spatial coordinates for the last two dimensions + half_per = base_frequency * np.pi + x = np.linspace(-half_per, half_per, nx) + y = np.linspace(-half_per, half_per, ny) + y, x = np.meshgrid(y, x) + + # Iterate through each parameter in the higher dimensions + for phase_idx in range(phase_dim): + for freq_idx in range(freq_dim): + for angle_idx in range(angle_dim): + # Calculate phase and frequency + phase = np.pi / phase_dim * phase_idx + frequency = 1 + (freq_idx * 0.1) # Increasing frequency with each step + + # Calculate angle + angle = np.pi / angle_dim * angle_idx + # Rotate x and y coordinates + xr = np.cos(angle) * x - np.sin(angle) * y + np.sin(angle) * x + np.cos(angle) * y + + # Compute the sine wave + sine_wave = (amplitude * 0.5) * np.sin(frequency * xr + phase) + sine_wave += amplitude * 0.5 + + # Assign to the output array + output[angle_idx, freq_idx, phase_idx] = sine_wave + + return output + + +def cells3d() -> np.ndarray: + """Load cells3d data from scikit-image.""" + try: + from imageio.v2 import volread + except ImportError as e: + raise ImportError( + "Please `pip install imageio[tifffile]` to load cells3d" + ) from e + + url = "https://gitlab.com/scikit-image/data/-/raw/2cdc5ce89b334d28f06a58c9f0ca21aa6992a5ba/cells3d.tif" + return volread(url) # type: ignore [no-any-return] diff --git a/src/ndv/util.py b/src/ndv/util.py new file mode 100644 index 00000000..ac72d194 --- /dev/null +++ b/src/ndv/util.py @@ -0,0 +1,64 @@ +"""Utility and convenience functions.""" + +from __future__ import annotations + +import sys +from typing import TYPE_CHECKING, Any, Literal + +from qtpy.QtWidgets import QApplication + +from .viewer._viewer import NDViewer + +if TYPE_CHECKING: + from qtpy.QtCore import QCoreApplication + + from .viewer._data_wrapper import DataWrapper + + +def imshow( + data: Any | DataWrapper, + cmap: Any | None = None, + *, + channel_mode: Literal["mono", "composite", "auto"] = "auto", +) -> NDViewer: + """Display an array or DataWrapper in a new NDViewer window. + + Parameters + ---------- + data : Any | DataWrapper + The data to be displayed. If not a DataWrapper, it will be wrapped in one. + cmap : Any | None, optional + The colormap(s) to use for displaying the data. + channel_mode : Literal['mono', 'composite'], optional + The initial mode for displaying the channels. By default "mono" will be + used unless a cmap is provided, in which case "composite" will be used. + + Returns + ------- + NDViewer + The viewer window. + """ + app, should_exec = _get_app() + if cmap is not None: + channel_mode = "composite" + if not isinstance(cmap, (list, tuple)): + cmap = [cmap] + elif channel_mode == "auto": + channel_mode = "mono" + viewer = NDViewer(data, colormaps=cmap, channel_mode=channel_mode) + viewer.show() + viewer.raise_() + if should_exec: + app.exec() + return viewer + + +def _get_app() -> tuple[QCoreApplication, bool]: + is_ipython = False + if (app := QApplication.instance()) is None: + app = QApplication([]) + app.setApplicationName("ndv") + elif (ipy := sys.modules.get("IPython")) and (shell := ipy.get_ipython()): + is_ipython = str(shell.active_eventloop).startswith("qt") + + return app, not is_ipython diff --git a/src/ndv/viewer/_data_wrapper.py b/src/ndv/viewer/_data_wrapper.py new file mode 100644 index 00000000..1d062d8c --- /dev/null +++ b/src/ndv/viewer/_data_wrapper.py @@ -0,0 +1,376 @@ +"""In this module, we provide built-in support for many array types.""" + +from __future__ import annotations + +import logging +import sys +from abc import abstractmethod +from collections.abc import Container, Hashable, Iterable, Iterator, Mapping, Sequence +from concurrent.futures import Future, ThreadPoolExecutor +from contextlib import suppress +from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar + +import numpy as np + +if TYPE_CHECKING: + from pathlib import Path + from typing import Any, Protocol, TypeAlias, TypeGuard + + import dask.array as da + import numpy.typing as npt + import pyopencl.array as cl_array + import sparse + import tensorstore as ts + import torch + import xarray as xr + import zarr + from torch._tensor import Tensor + + from ._dims_slider import Index, Indices, Sizes + + _T_contra = TypeVar("_T_contra", contravariant=True) + + class SupportsIndexing(Protocol): + def __getitem__(self, key: Index | tuple[Index, ...]) -> npt.ArrayLike: ... + @property + def shape(self) -> tuple[int, ...]: ... + + class SupportsDunderLT(Protocol[_T_contra]): + def __lt__(self, other: _T_contra, /) -> bool: ... + + class SupportsDunderGT(Protocol[_T_contra]): + def __gt__(self, other: _T_contra, /) -> bool: ... + + SupportsRichComparison: TypeAlias = SupportsDunderLT[Any] | SupportsDunderGT[Any] + + +ArrayT = TypeVar("ArrayT") +_T = TypeVar("_T", bound=type) + +# Global executor for slice requests +_EXECUTOR = ThreadPoolExecutor(max_workers=2) + + +def _recurse_subclasses(cls: _T) -> Iterator[_T]: + for subclass in cls.__subclasses__(): + yield subclass + yield from _recurse_subclasses(subclass) + + +class DataWrapper(Generic[ArrayT]): + """Interface for wrapping different array-like data types. + + `DataWrapper.create` is a factory method that returns a DataWrapper instance + for the given data type. If your datastore type is not supported, you may implement + a new DataWrapper subclass to handle your data type. To do this, import and + subclass DataWrapper, and (minimally) implement the supports and isel methods. + Ensure that your class is imported before the DataWrapper.create method is called, + and it will be automatically detected and used to wrap your data. + """ + + # Order in which subclasses are checked for support. + # Lower numbers are checked first, and the first supporting subclass is used. + # Default is 50, and fallback to numpy-like duckarray is 100. + # Subclasses can override this to change the priority in which they are checked + PRIORITY: ClassVar[SupportsRichComparison] = 50 + # These names will be checked when looking for a channel axis + COMMON_CHANNEL_NAMES: ClassVar[Container[str]] = ("channel", "ch", "c") + # Maximum dimension size consider when guessing the channel axis + MAX_CHANNELS = 16 + + @classmethod + def create(cls, data: ArrayT) -> DataWrapper[ArrayT]: + if isinstance(data, DataWrapper): + return data + # check subclasses for support + # This allows users to define their own DataWrapper subclasses which will + # be automatically detected (assuming they have been imported by this point) + for subclass in sorted(_recurse_subclasses(cls), key=lambda x: x.PRIORITY): + with suppress(Exception): + if subclass.supports(data): + logging.debug(f"Using {subclass.__name__} to wrap {type(data)}") + return subclass(data) + raise NotImplementedError(f"Don't know how to wrap type {type(data)}") + + def __init__(self, data: ArrayT) -> None: + self._data = data + + @property + def data(self) -> ArrayT: + return self._data + + @classmethod + @abstractmethod + def supports(cls, obj: Any) -> bool: + """Return True if this wrapper can handle the given object. + + Any exceptions raised by this method will be suppressed, so it is safe to + directly import necessary dependencies without a try/except block. + """ + raise NotImplementedError + + @abstractmethod + def isel(self, indexers: Indices) -> np.ndarray: + """Select a slice from a data store using (possibly) named indices. + + This follows the xarray-style indexing, where indexers is a mapping of + dimension names to indices or slices. Subclasses should implement this + method to return a numpy array. + """ + raise NotImplementedError + + def isel_async( + self, indexers: list[Indices] + ) -> Future[Iterable[tuple[Indices, np.ndarray]]]: + """Asynchronous version of isel.""" + return _EXECUTOR.submit(lambda: [(idx, self.isel(idx)) for idx in indexers]) + + def guess_channel_axis(self) -> Hashable | None: + """Return the (best guess) axis name for the channel dimension.""" + # for arrays with labeled dimensions, + # see if any of the dimensions are named "channel" + for dimkey, val in self.sizes().items(): + if str(dimkey).lower() in self.COMMON_CHANNEL_NAMES: + if val <= self.MAX_CHANNELS: + return dimkey + + # for shaped arrays, use the smallest dimension as the channel axis + shape = getattr(self._data, "shape", None) + if isinstance(shape, Sequence): + with suppress(ValueError): + smallest_dim = min(shape) + if smallest_dim <= self.MAX_CHANNELS: + return shape.index(smallest_dim) + return None + + def save_as_zarr(self, save_loc: str | Path) -> None: + raise NotImplementedError("save_as_zarr not implemented for this data type.") + + def sizes(self) -> Sizes: + """Return a mapping of {dimkey: size} for the data. + + The default implementation uses the shape attribute of the data, and + tries to find dimension names in the `dims`, `names`, or `labels` attributes. + (`dims` is used by xarray, `names` is used by torch, etc...). If no labels + are found, the dimensions are just named by their integer index. + """ + shape = getattr(self._data, "shape", None) + if not isinstance(shape, Sequence) or not all( + isinstance(x, int) for x in shape + ): + raise NotImplementedError(f"Cannot determine sizes for {type(self._data)}") + dims = range(len(shape)) + return {dim: int(size) for dim, size in zip(dims, shape)} + + def summary_info(self) -> str: + """Return info label with information about the data.""" + package = getattr(self._data, "__module__", "").split(".")[0] + info = f"{package}.{getattr(type(self._data), '__qualname__', '')}" + + if sizes := self.sizes(): + # if all of the dimension keys are just integers, omit them from size_str + if all(isinstance(x, int) for x in sizes): + size_str = repr(tuple(sizes.values())) + # otherwise, include the keys in the size_str + else: + size_str = ", ".join(f"{k}:{v}" for k, v in sizes.items()) + size_str = f"({size_str})" + info += f" {size_str}" + if dtype := getattr(self._data, "dtype", ""): + info += f", {dtype}" + if nbytes := getattr(self._data, "nbytes", 0) / 1e6: + info += f", {nbytes:.2f}MB" + return info + + +class XarrayWrapper(DataWrapper["xr.DataArray"]): + """Wrapper for xarray DataArray objects.""" + + def isel(self, indexers: Indices) -> np.ndarray: + return np.asarray(self._data.isel(indexers)) + + def sizes(self) -> Mapping[Hashable, int]: + return {k: int(v) for k, v in self._data.sizes.items()} + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[xr.DataArray]: + if (xr := sys.modules.get("xarray")) and isinstance(obj, xr.DataArray): + return True + return False + + def save_as_zarr(self, save_loc: str | Path) -> None: + self._data.to_zarr(save_loc) + + +class TensorstoreWrapper(DataWrapper["ts.TensorStore"]): + """Wrapper for tensorstore.TensorStore objects.""" + + def __init__(self, data: Any) -> None: + super().__init__(data) + import tensorstore as ts + + self._ts = ts + + def sizes(self) -> Mapping[Hashable, int]: + return {dim.label: dim.size for dim in self._data.domain} + + def isel(self, indexers: Indices) -> np.ndarray: + result = ( + self._data[self._ts.d[tuple(indexers)][tuple(indexers.values())]] + .read() + .result() + ) + return np.asarray(result) + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[ts.TensorStore]: + if (ts := sys.modules.get("tensorstore")) and isinstance(obj, ts.TensorStore): + return True + return False + + +class ArrayLikeWrapper(DataWrapper, Generic[ArrayT]): + """Wrapper for numpy duck array-like objects.""" + + PRIORITY = 100 + + def isel(self, indexers: Indices) -> np.ndarray: + idx = tuple(indexers.get(k, slice(None)) for k in range(len(self._data.shape))) + return self._asarray(self._data[idx]) + + def _asarray(self, data: npt.ArrayLike) -> np.ndarray: + return np.asarray(data) + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[SupportsIndexing]: + if ( + ( + isinstance(obj, np.ndarray) + or hasattr(obj, "__array_function__") + or hasattr(obj, "__array_namespace__") + or hasattr(obj, "__array__") + ) + and hasattr(obj, "__getitem__") + and hasattr(obj, "shape") + ): + return True + return False + + def save_as_zarr(self, save_loc: str | Path) -> None: + try: + import zarr + except ImportError: + raise ImportError("zarr is required to save this data type.") from None + + if isinstance(self._data, zarr.Array): + self._data.store = zarr.DirectoryStore(save_loc) + else: + zarr.save(str(save_loc), self._data) + + +class DaskWrapper(DataWrapper["da.Array"]): + """Wrapper for dask array objects.""" + + def isel(self, indexers: Indices) -> np.ndarray: + idx = tuple(indexers.get(k, slice(None)) for k in range(len(self._data.shape))) + return np.asarray(self._data[idx].compute()) + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[da.Array]: + if (da := sys.modules.get("dask.array")) and isinstance(obj, da.Array): + return True + return False + + def save_as_zarr(self, save_loc: str | Path) -> None: + self._data.to_zarr(url=str(save_loc)) + + +class CLArrayWrapper(ArrayLikeWrapper["cl_array.Array"]): + """Wrapper for pyopencl array objects.""" + + PRIORITY = 50 + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[cl_array.Array]: + if (cl_array := sys.modules.get("pyopencl.array")) and isinstance( + obj, cl_array.Array + ): + return True + return False + + def _asarray(self, data: cl_array.Array) -> np.ndarray: + return np.asarray(data.get()) + + +class SparseArrayWrapper(ArrayLikeWrapper["sparse.Array"]): + PRIORITY = 50 + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[sparse.COO]: + if (sparse := sys.modules.get("sparse")) and isinstance(obj, sparse.COO): + return True + return False + + def _asarray(self, data: sparse.COO) -> np.ndarray: + return np.asarray(data.todense()) + + +class ZarrArrayWrapper(ArrayLikeWrapper["zarr.Array"]): + """Wrapper for zarr array objects.""" + + PRIORITY = 50 + + def __init__(self, data: Any) -> None: + super().__init__(data) + self._name2index: dict[Hashable, int] + if "_ARRAY_DIMENSIONS" in data.attrs: + self._name2index = { + name: i for i, name in enumerate(data.attrs["_ARRAY_DIMENSIONS"]) + } + else: + self._name2index = {i: i for i in range(data.ndim)} + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[zarr.Array]: + if (zarr := sys.modules.get("zarr")) and isinstance(obj, zarr.Array): + return True + return False + + def sizes(self) -> Sizes: + return dict(zip(self._name2index, self.data.shape)) + + def isel(self, indexers: Indices) -> np.ndarray: + # convert possibly named indices to integer indices + real_indexers = {self._name2index.get(k, k): v for k, v in indexers.items()} + return super().isel(real_indexers) + + +class TorchTensorWrapper(DataWrapper["torch.Tensor"]): + """Wrapper for torch tensor objects.""" + + def __init__(self, data: Tensor) -> None: + super().__init__(data) + self._name2index: dict[Hashable, int] + if names := getattr(data, "names", None): + # names may be something like (None, None, None)... + self._name2index = { + (i if name is None else name): i for i, name in enumerate(names) + } + else: + self._name2index = {i: i for i in range(data.ndim)} + + def sizes(self) -> Sizes: + return dict(zip(self._name2index, self.data.shape)) + + def isel(self, indexers: Indices) -> np.ndarray: + # convert possibly named indices to integer indices + real_indexers = {self._name2index.get(k, k): v for k, v in indexers.items()} + # convert to tuple of slices + idx = tuple(real_indexers.get(i, slice(None)) for i in range(self.data.ndim)) + return self.data[idx].numpy(force=True) # type: ignore [no-any-return] + + @classmethod + def supports(cls, obj: Any) -> TypeGuard[torch.Tensor]: + if (torch := sys.modules.get("torch")) and isinstance(obj, torch.Tensor): + return True + return False diff --git a/src/ndv/viewer/_dims_slider.py b/src/ndv/viewer/_dims_slider.py index 81fa6c1d..9161e79d 100644 --- a/src/ndv/viewer/_dims_slider.py +++ b/src/ndv/viewer/_dims_slider.py @@ -19,7 +19,7 @@ QVBoxLayout, QWidget, ) -from superqt import QElidingLabel, QLabeledRangeSlider +from superqt import QLabeledRangeSlider from superqt.iconify import QIconifyIcon from superqt.utils import signals_blocked @@ -27,11 +27,11 @@ from collections.abc import Hashable, Mapping from typing import TypeAlias - from PyQt6.QtGui import QResizeEvent + from qtpy.QtGui import QResizeEvent - # any hashable represent a single dimension in a AND array + # any hashable represent a single dimension in an ND array DimKey: TypeAlias = Hashable - # any object that can be used to index a single dimension in an AND array + # any object that can be used to index a single dimension in an ND array Index: TypeAlias = int | slice # a mapping from dimension keys to indices (eg. {"x": 0, "y": slice(5, 10)}) # this object is used frequently to query or set the currently displayed slice @@ -167,7 +167,10 @@ def __init__(self, dimension_key: DimKey, parent: QWidget | None = None) -> None self._play_btn.toggled.connect(self._toggle_animation) self._dim_key = dimension_key - self._dim_label = QElidingLabel(str(dimension_key).upper()) + self._dim_label = QLabel(str(dimension_key)) + self._dim_label.setSizePolicy( + QSizePolicy.Policy.Preferred, QSizePolicy.Policy.Preferred + ) self._dim_label.setToolTip("Double-click to toggle slice mode") # note, this lock button only prevents the slider from updating programmatically @@ -486,8 +489,6 @@ def set_dimension_visible(self, key: DimKey, visible: bool) -> None: self._invisible_dims.discard(key) if key in self._sliders: self._current_index[key] = self._sliders[key].value() - else: - self.add_dimension(key) else: self._invisible_dims.add(key) self._current_index.pop(key, None) diff --git a/src/ndv/viewer/_indexing.py b/src/ndv/viewer/_indexing.py deleted file mode 100644 index ffe4fa54..00000000 --- a/src/ndv/viewer/_indexing.py +++ /dev/null @@ -1,306 +0,0 @@ -"""In this module, we provide built-in support for many array types.""" - -from __future__ import annotations - -import sys -import warnings -from abc import abstractmethod -from collections.abc import Hashable, Iterable, Mapping, Sequence -from concurrent.futures import Future, ThreadPoolExecutor -from contextlib import suppress -from typing import ( - TYPE_CHECKING, - Generic, - TypeVar, - cast, -) - -import numpy as np - -if TYPE_CHECKING: - from pathlib import Path - from typing import Any, Protocol, TypeGuard - - import dask.array as da - import numpy.typing as npt - import tensorstore as ts - import xarray as xr - from pymmcore_plus.mda.handlers import TensorStoreHandler - from pymmcore_plus.mda.handlers._5d_writer_base import _5DWriterBase - - from ._dims_slider import Index, Indices - - class SupportsIndexing(Protocol): - def __getitem__(self, key: Index | tuple[Index, ...]) -> npt.ArrayLike: ... - @property - def shape(self) -> tuple[int, ...]: ... - - -ArrayT = TypeVar("ArrayT") -MAX_CHANNELS = 16 -# Create a global executor -_EXECUTOR = ThreadPoolExecutor(max_workers=1) - - -class DataWrapper(Generic[ArrayT]): - """Interface for wrapping different array-like data types. - - If DataWrapper.create(your_obj) raises an exception, you can implement a new - DataWrapper subclass to handle your data type. - - It can be passed to NDViewer. - """ - - def __init__(self, data: ArrayT) -> None: - self._data = data - - @classmethod - def create(cls, data: ArrayT) -> DataWrapper[ArrayT]: - if isinstance(data, DataWrapper): - return data - if MMTensorStoreWrapper.supports(data): - return MMTensorStoreWrapper(data) - if MM5DWriter.supports(data): - return MM5DWriter(data) - if XarrayWrapper.supports(data): - return XarrayWrapper(data) - if DaskWrapper.supports(data): - return DaskWrapper(data) - if TensorstoreWrapper.supports(data): - return TensorstoreWrapper(data) - if ArrayLikeWrapper.supports(data): - return ArrayLikeWrapper(data) - raise NotImplementedError(f"Don't know how to wrap type {type(data)}") - - @abstractmethod - def isel(self, indexers: Indices) -> np.ndarray: - """Select a slice from a data store using (possibly) named indices. - - For xarray.DataArray, use the built-in isel method. - For any other duck-typed array, use numpy-style indexing, where indexers - is a mapping of axis to slice objects or indices. - """ - raise NotImplementedError - - def isel_async( - self, indexers: list[Indices] - ) -> Future[Iterable[tuple[Indices, np.ndarray]]]: - """Asynchronous version of isel.""" - return _EXECUTOR.submit(lambda: [(idx, self.isel(idx)) for idx in indexers]) - - @classmethod - @abstractmethod - def supports(cls, obj: Any) -> bool: - """Return True if this wrapper can handle the given object.""" - raise NotImplementedError - - def guess_channel_axis(self) -> Hashable | None: - """Return the (best guess) axis name for the channel dimension.""" - if isinstance(shp := getattr(self._data, "shape", None), Sequence): - # for numpy arrays, use the smallest dimension as the channel axis - if min(shp) <= MAX_CHANNELS: - return shp.index(min(shp)) - return None - - def save_as_zarr(self, save_loc: str | Path) -> None: - raise NotImplementedError("save_as_zarr not implemented for this data type.") - - def sizes(self) -> Mapping[Hashable, int]: - if (shape := getattr(self._data, "shape", None)) and isinstance(shape, tuple): - _sizes: dict[Hashable, int] = {} - for i, val in enumerate(shape): - if isinstance(val, int): - _sizes[i] = val - elif isinstance(val, Sequence) and len(val) == 2: - _sizes[val[0]] = int(val[1]) - else: - raise ValueError( - f"Invalid size: {val}. Must be an int or a 2-tuple." - ) - return _sizes - raise NotImplementedError(f"Cannot determine sizes for {type(self._data)}") - - def summary_info(self) -> str: - """Return info label with information about the data.""" - package = getattr(self._data, "__module__", "").split(".")[0] - info = f"{package}.{getattr(type(self._data), '__qualname__', '')}" - - if sizes := self.sizes(): - # if all of the dimension keys are just integers, omit them from size_str - if all(isinstance(x, int) for x in sizes): - size_str = repr(tuple(sizes.values())) - # otherwise, include the keys in the size_str - else: - size_str = ", ".join(f"{k}:{v}" for k, v in sizes.items()) - size_str = f"({size_str})" - info += f" {size_str}" - if dtype := getattr(self._data, "dtype", ""): - info += f", {dtype}" - if nbytes := getattr(self._data, "nbytes", 0) / 1e6: - info += f", {nbytes:.2f}MB" - return info - - -class MMTensorStoreWrapper(DataWrapper["TensorStoreHandler"]): - def sizes(self) -> Mapping[Hashable, int]: - with suppress(Exception): - return self._data.current_sequence.sizes # type: ignore [no-any-return] - return {} - - def guess_channel_axis(self) -> Hashable | None: - return "c" - - @classmethod - def supports(cls, obj: Any) -> TypeGuard[TensorStoreHandler]: - with suppress(ImportError): - from pymmcore_plus.mda.handlers import TensorStoreHandler - - return isinstance(obj, TensorStoreHandler) - return False - - def isel(self, indexers: Indices) -> np.ndarray: - return self._data.isel(indexers) # type: ignore [no-any-return] - - def save_as_zarr(self, save_loc: str | Path) -> None: - if (store := self._data.store) is None: - return - import tensorstore as ts - - new_spec = store.spec().to_json() - new_spec["kvstore"] = {"driver": "file", "path": str(save_loc)} - new_ts = ts.open(new_spec, create=True).result() - new_ts[:] = store.read().result() - - -class MM5DWriter(DataWrapper["_5DWriterBase"]): - def guess_channel_axis(self) -> Hashable | None: - return "c" - - @classmethod - def supports(cls, obj: Any) -> TypeGuard[_5DWriterBase]: - with suppress(ImportError): - try: - from pymmcore_plus.mda.handlers._5d_writer_base import _5DWriterBase - except ImportError: - from pymmcore_plus.mda.handlers import OMETiffWriter, OMEZarrWriter - - _5DWriterBase = (OMETiffWriter, OMEZarrWriter) - if isinstance(obj, _5DWriterBase): - return True - return False - - def save_as_zarr(self, save_loc: str | Path) -> None: - import zarr - from pymmcore_plus.mda.handlers import OMEZarrWriter - - if isinstance(self._data, OMEZarrWriter): - zarr.copy_store(self._data.group.store, zarr.DirectoryStore(save_loc)) - raise NotImplementedError(f"Cannot save {type(self._data)} data to Zarr.") - - def isel(self, indexers: Indices) -> np.ndarray: - p_index = indexers.get("p", 0) - if isinstance(p_index, slice): - warnings.warn("Cannot slice over position index", stacklevel=2) # TODO - p_index = p_index.start - p_index = cast(int, p_index) - - try: - sizes = [*list(self._data.position_sizes[p_index]), "y", "x"] - except IndexError as e: - raise IndexError( - f"Position index {p_index} out of range for " - f"{len(self._data.position_sizes)}" - ) from e - - data = self._data.position_arrays[self._data.get_position_key(p_index)] - full = slice(None, None) - index = tuple(indexers.get(k, full) for k in sizes) - return data[index] # type: ignore [no-any-return] - - -class XarrayWrapper(DataWrapper["xr.DataArray"]): - def isel(self, indexers: Indices) -> np.ndarray: - return np.asarray(self._data.isel(indexers)) - - def sizes(self) -> Mapping[Hashable, int]: - return {k: int(v) for k, v in self._data.sizes.items()} - - @classmethod - def supports(cls, obj: Any) -> TypeGuard[xr.DataArray]: - if (xr := sys.modules.get("xarray")) and isinstance(obj, xr.DataArray): - return True - return False - - def guess_channel_axis(self) -> Hashable | None: - for d in self._data.dims: - if str(d).lower() in ("channel", "ch", "c"): - return cast("Hashable", d) - return None - - def save_as_zarr(self, save_loc: str | Path) -> None: - self._data.to_zarr(save_loc) - - -class DaskWrapper(DataWrapper["da.Array"]): - def isel(self, indexers: Indices) -> np.ndarray: - idx = tuple(indexers.get(k, slice(None)) for k in range(len(self._data.shape))) - return np.asarray(self._data[idx].compute()) - - @classmethod - def supports(cls, obj: Any) -> TypeGuard[da.Array]: - if (da := sys.modules.get("dask.array")) and isinstance(obj, da.Array): - return True - return False - - def save_as_zarr(self, save_loc: str | Path) -> None: - self._data.to_zarr(url=str(save_loc)) - - -class TensorstoreWrapper(DataWrapper["ts.TensorStore"]): - def __init__(self, data: Any) -> None: - super().__init__(data) - import tensorstore as ts - - self._ts = ts - - def sizes(self) -> Mapping[Hashable, int]: - return {dim.label: dim.size for dim in self._data.domain} - - def isel(self, indexers: Indices) -> np.ndarray: - result = ( - self._data[self._ts.d[tuple(indexers)][tuple(indexers.values())]] - .read() - .result() - ) - return np.asarray(result) - - @classmethod - def supports(cls, obj: Any) -> TypeGuard[ts.TensorStore]: - if (ts := sys.modules.get("tensorstore")) and isinstance(obj, ts.TensorStore): - return True - return False - - -class ArrayLikeWrapper(DataWrapper): - def isel(self, indexers: Indices) -> np.ndarray: - idx = tuple(indexers.get(k, slice(None)) for k in range(len(self._data.shape))) - return np.asarray(self._data[idx]) - - @classmethod - def supports(cls, obj: Any) -> TypeGuard[SupportsIndexing]: - if ( - isinstance(obj, np.ndarray) - or hasattr(obj, "__array_function__") - or hasattr(obj, "__array_namespace__") - or (hasattr(obj, "__getitem__") and hasattr(obj, "__array__")) - ): - return True - return False - - def save_as_zarr(self, save_loc: str | Path) -> None: - import zarr - - if isinstance(self._data, zarr.Array): - self._data.store = zarr.DirectoryStore(save_loc) - else: - zarr.save(str(save_loc), self._data) diff --git a/src/ndv/viewer/_save_button.py b/src/ndv/viewer/_save_button.py index 85520641..0ce45116 100644 --- a/src/ndv/viewer/_save_button.py +++ b/src/ndv/viewer/_save_button.py @@ -7,7 +7,7 @@ from superqt.iconify import QIconifyIcon if TYPE_CHECKING: - from ._indexing import DataWrapper + from ._data_wrapper import DataWrapper class SaveButton(QPushButton): diff --git a/src/ndv/viewer/_stack_viewer.py b/src/ndv/viewer/_viewer.py similarity index 89% rename from src/ndv/viewer/_stack_viewer.py rename to src/ndv/viewer/_viewer.py index a394a02d..08a108ac 100644 --- a/src/ndv/viewer/_stack_viewer.py +++ b/src/ndv/viewer/_viewer.py @@ -13,8 +13,8 @@ from superqt.utils import qthrottled, signals_blocked from ._backends import get_canvas +from ._data_wrapper import DataWrapper from ._dims_slider import DimsSliders -from ._indexing import DataWrapper from ._lut_control import LutControl if TYPE_CHECKING: @@ -61,7 +61,7 @@ def __init__(self, parent: QWidget | None = None): self.toggled.connect(self.next_mode) # set minimum width to the width of the larger string 'composite' - self.setMinimumWidth(92) # FIXME: magic number + self.setMinimumWidth(92) # magic number :/ def next_mode(self) -> None: if self.isChecked(): @@ -88,31 +88,6 @@ def __init__(self, parent: QWidget | None = None): self.setChecked(True) -# @dataclass -# class LutModel: -# name: str = "" -# autoscale: bool = True -# min: float = 0.0 -# max: float = 1.0 -# colormap: cmap.Colormap = GRAYS -# visible: bool = True - - -# @dataclass -# class ViewerModel: -# data: Any = None -# # dimensions of the data that will *not* be sliced. -# visualized_dims: Container[DimKey] = (-2, -1) -# # the axis that represents the channels in the data -# channel_axis: DimKey | None = None -# # the mode for displaying the channels -# # if MONO, only the current selection of channel_axis is displayed -# # if COMPOSITE, the full channel_axis is sliced, and luts determine display -# channel_mode: ChannelMode = ChannelMode.MONO -# # map of index in the channel_axis to LutModel -# luts: Mapping[int, LutModel] = {} - - class NDViewer(QWidget): """A viewer for ND arrays. @@ -140,8 +115,8 @@ class NDViewer(QWidget): - `_update_data_for_index` is an asynchronous method that retrieves the data for the given index from the datastore (using `_isel`) and queues the `_on_data_slice_ready` method to be called when the data is ready. The logic - for extracting data from the datastore is defined in `_indexing.py`, which handles - idiosyncrasies of different datastores (e.g. xarray, tensorstore, etc). + for extracting data from the datastore is defined in `_data_wrapper.py`, which + handles idiosyncrasies of different datastores (e.g. xarray, tensorstore, etc). - `_on_data_slice_ready` is called when the data is ready, and updates the image. Note that if the slice is multidimensional, the data will be reduced to 2D using max intensity projection (and double-clicking on any given dimension slider will @@ -154,8 +129,10 @@ class NDViewer(QWidget): Parameters ---------- data : Any - The data to display. This can be an ND array, an xarray DataArray, or any - object that supports numpy-style indexing. + The data to display. This can be any duck-like ND array, including numpy, dask, + xarray, jax, tensorstore, zarr, etc. You can add support for new datastores by + subclassing `DataWrapper` and implementing the required methods. See + `DataWrapper` for more information. parent : QWidget, optional The parent widget of this widget. channel_axis : Hashable, optional @@ -168,7 +145,7 @@ class NDViewer(QWidget): def __init__( self, - data: Any, + data: DataWrapper | Any, *, colormaps: Iterable[cmap._colormap.ColorStopsLike] | None = None, parent: QWidget | None = None, @@ -179,8 +156,6 @@ def __init__( # ATTRIBUTES ---------------------------------------------------- - # dimensions of the data in the datastore - self._sizes: Sizes = {} # mapping of key to a list of objects that control image nodes in the canvas self._img_handles: defaultdict[ImgKey, list[PImageHandle]] = defaultdict(list) # mapping of same keys to the LutControl objects control image display props @@ -274,30 +249,29 @@ def __init__( self.set_data(data) # ------------------- PUBLIC API ---------------------------- + @property + def dims_sliders(self) -> DimsSliders: + """Return the DimsSliders widget.""" + return self._dims_sliders + + @property + def data_wrapper(self) -> DataWrapper: + """Return the DataWrapper object around the datastore.""" + return self._data_wrapper + @property def data(self) -> Any: """Return the data backing the view.""" - return self._data_wrapper._data + return self._data_wrapper.data @data.setter def data(self, data: Any) -> None: """Set the data backing the view.""" raise AttributeError("Cannot set data directly. Use `set_data` method.") - @property - def dims_sliders(self) -> DimsSliders: - """Return the DimsSliders widget.""" - return self._dims_sliders - - @property - def sizes(self) -> Sizes: - """Return sizes {dimkey: int} of the dimensions in the datastore.""" - return self._sizes - def set_data( self, - data: Any, - sizes: SizesLike | None = None, + data: DataWrapper | Any, channel_axis: int | None = None, visualized_dims: Iterable[DimKey] | None = None, ) -> None: @@ -305,9 +279,6 @@ def set_data( # store the data self._data_wrapper = DataWrapper.create(data) - # determine sizes of the data - self._sizes = self._data_wrapper.sizes() if sizes is None else _to_sizes(sizes) - # set channel axis if channel_axis is not None: self._channel_axis = channel_axis @@ -316,7 +287,8 @@ def set_data( # update the dimensions we are visualizing if visualized_dims is None: - visualized_dims = list(self._sizes)[-self._ndims :] + sizes = self._data_wrapper.sizes() + visualized_dims = list(sizes)[-self._ndims :] self.set_visualized_dims(visualized_dims) # update the range of all the sliders to match the sizes we set above @@ -347,8 +319,9 @@ def update_slider_ranges( This is mostly here as a public way to reset the """ if maxes is None: - maxes = self._sizes - maxes = _to_sizes(maxes) + maxes = self._data_wrapper.sizes() + else: + maxes = _to_sizes(maxes) self._dims_sliders.setMaxima({k: v - 1 for k, v in maxes.items()}) if mins is not None: self._dims_sliders.setMinima(_to_sizes(mins)) @@ -366,7 +339,7 @@ def set_ndim(self, ndim: Literal[2, 3]) -> None: self._canvas.set_ndim(ndim) # set the visibility of the last non-channel dimension - sizes = list(self._sizes) + sizes = list(self._data_wrapper.sizes()) if self._channel_axis is not None: sizes = [x for x in sizes if x != self._channel_axis] if len(sizes) >= 3: @@ -436,10 +409,11 @@ def _update_data_for_index(self, index: Indices) -> None: if ( self._channel_axis is not None and self._channel_mode == ChannelMode.COMPOSITE + and self._channel_axis in (sizes := self._data_wrapper.sizes()) ): indices: list[Indices] = [ {**index, self._channel_axis: i} - for i in range(self._sizes[self._channel_axis]) + for i in range(sizes[self._channel_axis]) ] else: indices = [index] @@ -452,7 +426,11 @@ def _update_data_for_index(self, index: Indices) -> None: {k: v for k, v in idx.items() if k not in self._visualized_dims} for idx in indices ] - self._last_future = f = self._isel(indices) + try: + self._last_future = f = self._data_wrapper.isel_async(indices) + except Exception as e: + raise type(e)(f"Failed to index data with {index}: {e}") from e + f.add_done_callback(self._on_data_slice_ready) def closeEvent(self, a0: QCloseEvent | None) -> None: @@ -461,15 +439,6 @@ def closeEvent(self, a0: QCloseEvent | None) -> None: self._last_future = None super().closeEvent(a0) - def _isel( - self, indices: list[Indices] - ) -> Future[Iterable[tuple[Indices, np.ndarray]]]: - """Select data from the datastore using the given index.""" - try: - return self._data_wrapper.isel_async(indices) - except Exception as e: - raise type(e)(f"Failed to index data with {indices}: {e}") from e - @ensure_main_thread # type: ignore def _on_data_slice_ready( self, future: Future[Iterable[tuple[Indices, np.ndarray]]] @@ -482,14 +451,11 @@ def _on_data_slice_ready( # because the future has a reference to this widget in its _done_callbacks # which will prevent the widget from being garbage collected if the future self._last_future = None + if future.cancelled(): return - data = future.result() - # FIXME: - # `self._channel_axis: i` is a bug; we assume channel indices start at 0 - # but the actual values used for indices are up to the user. - for idx, datum in data: + for idx, datum in future.result(): self._update_canvas_data(datum, idx) self._canvas.refresh() diff --git a/tests/conftest.py b/tests/conftest.py index cb8c3d4d..42a4a61c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,6 +18,11 @@ def _find_leaks(request: "FixtureRequest", qapp: "QApplication") -> Iterator[Non `functools.partial(self._method)` or `lambda: self._method` being used in that widget's code. """ + # check for the "allow_leaks" marker + if "allow_leaks" in request.node.keywords: + yield + return + nbefore = len(qapp.topLevelWidgets()) failures_before = request.session.testsfailed yield diff --git a/tests/test_examples.py b/tests/test_examples.py new file mode 100644 index 00000000..5250ddaa --- /dev/null +++ b/tests/test_examples.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +import runpy +from pathlib import Path + +import pytest +from qtpy.QtWidgets import QApplication + +EXAMPLES = Path(__file__).parent.parent / "examples" +EXAMPLES_PY = list(EXAMPLES.glob("*.py")) + + +@pytest.fixture +def no_qapp_exec(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr(QApplication, "exec", lambda *_: None) + + +@pytest.mark.allow_leaks +@pytest.mark.usefixtures("no_qapp_exec") +@pytest.mark.parametrize("example", EXAMPLES_PY, ids=lambda x: x.name) +def test_example(qapp: QApplication, example: Path) -> None: + try: + runpy.run_path(str(example)) + except ImportError as e: + pytest.skip(str(e))