Skip to content

Commit

Permalink
refactor: minimize public API surface (#13)
Browse files Browse the repository at this point in the history
* refactor: minimize public API

* misc
  • Loading branch information
tlambert03 authored Jun 9, 2024
1 parent 2e2b53f commit b83af28
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 71 deletions.
2 changes: 1 addition & 1 deletion src/ndv/viewer/_dims_slider.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def __init__(self, dimension_key: DimKey, parent: QWidget | None = None) -> None
layout.addWidget(self._pos_label)
layout.addWidget(self._out_of_label)
layout.addWidget(self._lock_btn)
self.setMinimumHeight(22)
self.setMinimumHeight(26)

def resizeEvent(self, a0: QResizeEvent | None) -> None:
if isinstance(par := self.parent(), DimsSliders):
Expand Down
148 changes: 78 additions & 70 deletions src/ndv/viewer/_viewer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

from collections import defaultdict
from collections.abc import Iterable, Mapping, Sequence
from itertools import cycle
from typing import TYPE_CHECKING, Literal, cast

Expand All @@ -24,7 +23,7 @@
from ._lut_control import LutControl

if TYPE_CHECKING:
from collections.abc import Hashable
from collections.abc import Hashable, Iterable, Sequence
from concurrent.futures import Future
from typing import Any, Callable, TypeAlias

Expand Down Expand Up @@ -132,7 +131,6 @@ def __init__(
self._channel_axis = channel_axis
self._channel_mode: ChannelMode = None # type: ignore # set in set_channel_mode
# colormaps that will be cycled through when displaying composite images
# TODO: allow user to set this
if colormaps is not None:
self._cmaps = [cmap.Colormap(c) for c in colormaps]
else:
Expand All @@ -157,7 +155,7 @@ def __init__(

# button to change number of displayed dimensions
self._ndims_btn = DimToggleButton(self)
self._ndims_btn.clicked.connect(self.toggle_3d)
self._ndims_btn.clicked.connect(self._toggle_3d)

# place to display dataset summary
self._data_info_label = QElidingLabel("", parent=self)
Expand Down Expand Up @@ -242,32 +240,46 @@ def data(self, data: Any) -> None:
raise AttributeError("Cannot set data directly. Use `set_data` method.")

def set_data(
self,
data: DataWrapper | Any,
channel_axis: int | None = None,
visualized_dims: Iterable[DimKey] | None = None,
self, data: DataWrapper | Any, *, initial_index: Indices | None = None
) -> None:
"""Set the datastore, and, optionally, the sizes of the data."""
"""Set the datastore, and, optionally, the sizes of the data.
Properties
----------
data : DataWrapper | Any
The data to display. This can be any duck-like ND array, including numpy,
dask, xarray, jax, tensorstore, zarr, etc. You can add support for new
datastores by subclassing `DataWrapper` and implementing the required
methods. If a `DataWrapper` instance is passed, it is used directly.
See `DataWrapper` for more information.
initial_index : Indices | None
The initial index to display. This is a mapping of dimensions to integers
or slices that define the slice of the data to display. If not provided,
the initial index will be set to the middle of the data.
"""
# store the data
self._data_wrapper = DataWrapper.create(data)

# set channel axis
if channel_axis is not None:
self._channel_axis = channel_axis
elif self._channel_axis is None:
self._channel_axis = self._data_wrapper.guess_channel_axis()
self._channel_axis = self._data_wrapper.guess_channel_axis()

# update the dimensions we are visualizing
if visualized_dims is None:
sizes = self._data_wrapper.sizes()
visualized_dims = list(sizes)[-self._ndims :]
sizes = self._data_wrapper.sizes()
visualized_dims = list(sizes)[-self._ndims :]
self.set_visualized_dims(visualized_dims)

# update the range of all the sliders to match the sizes we set above
with signals_blocked(self._dims_sliders):
self.update_slider_ranges()
self._update_slider_ranges()

# redraw
self.setIndex({})
if initial_index is None:
idx = {k: int(v // 2) for k, v in sizes.items()}
else:
if not isinstance(initial_index, dict): # pragma: no cover
raise TypeError("initial_index must be a dict")
idx = initial_index
self.set_current_index(idx)
# update the data info label
self._data_info_label.setText(self._data_wrapper.summary_info())

Expand All @@ -282,31 +294,11 @@ def set_visualized_dims(self, dims: Iterable[DimKey]) -> None:
for d in self._visualized_dims:
self._dims_sliders.set_dimension_visible(d, False)

def update_slider_ranges(
self, mins: SizesLike | None = None, maxes: SizesLike | None = None
) -> None:
"""Set the maximum values of the sliders.
If `sizes` is not provided, sizes will be inferred from the datastore.
This is mostly here as a public way to reset the
"""
if maxes is None:
maxes = self._data_wrapper.sizes()
else:
maxes = _to_sizes(maxes)
self._dims_sliders.setMaxima({k: v - 1 for k, v in maxes.items()})
if mins is not None:
self._dims_sliders.setMinima(_to_sizes(mins))

# FIXME: this needs to be moved and made user-controlled
for dim in list(maxes.keys())[-self._ndims :]:
self._dims_sliders.set_dimension_visible(dim, False)

def toggle_3d(self) -> None:
self.set_ndim(3 if self._ndims == 2 else 2)

def set_ndim(self, ndim: Literal[2, 3]) -> None:
"""Set the number of dimensions to display."""
if ndim not in (2, 3):
raise ValueError("ndim must be 2 or 3")

self._ndims = ndim
self._canvas.set_ndim(ndim)

Expand All @@ -330,13 +322,19 @@ def set_channel_mode(self, mode: ChannelMode | str | None = None) -> None:
self._channel_axis as the channel axis. In "grayscale" mode, each channel is
displayed separately. (If mode is None, the current value of the
channel_mode_picker button is used)
Parameters
----------
mode : ChannelMode | str | None
The mode to set, must be one of 'composite' or 'mono'.
"""
# bool may happen when called from the button clicked signal
if mode is None or isinstance(mode, bool):
mode = self._channel_mode_btn.mode()
else:
mode = ChannelMode(mode)
self._channel_mode_btn.setMode(mode)
if mode == getattr(self, "_channel_mode", None):
if mode == self._channel_mode:
return

self._channel_mode = mode
Expand All @@ -351,12 +349,45 @@ def set_channel_mode(self, mode: ChannelMode | str | None = None) -> None:
self._clear_images()
self._update_data_for_index(self._dims_sliders.value())

def setIndex(self, index: Indices) -> None:
"""Set the index of the displayed image."""
self._dims_sliders.setValue(index)
def set_current_index(self, index: Indices | None = None) -> None:
"""Set the index of the displayed image.
`index` is a mapping of dimensions to integers or slices that define the slice
of the data to display. For example, a numpy slice of `[0, 1, 5:10]` would be
represented as `{0: 0, 1: 1, 2: slice(5, 10)}`, but dimensions can also be
named, e.g. `{'t': 0, 'c': 1, 'z': slice(5, 10)}` if the data has named
dimensions.
Note, calling `.set_current_index()` with no arguments will force the widget
to redraw the current slice.
"""
self._dims_sliders.setValue(index or {})

# camelCase aliases

dimsSliders = dims_sliders
setChannelMode = set_channel_mode
setData = set_data
setCurrentIndex = set_current_index
setVisualizedDims = set_visualized_dims

# ------------------- PRIVATE METHODS ----------------------------

def _toggle_3d(self) -> None:
self.set_ndim(3 if self._ndims == 2 else 2)

def _update_slider_ranges(self) -> None:
"""Set the maximum values of the sliders.
If `sizes` is not provided, sizes will be inferred from the datastore.
"""
maxes = self._data_wrapper.sizes()
self._dims_sliders.setMaxima({k: v - 1 for k, v in maxes.items()})

# FIXME: this needs to be moved and made user-controlled
for dim in list(maxes.keys())[-self._ndims :]:
self._dims_sliders.set_dimension_visible(dim, False)

def _on_set_range_clicked(self) -> None:
# using method to swallow the parameter passed by _set_range_btn.clicked
self._canvas.set_range()
Expand Down Expand Up @@ -457,19 +488,15 @@ def _update_canvas_data(self, data: np.ndarray, index: Indices) -> None:
elif datum.ndim == 3:
handles.append(self._canvas.add_volume(datum, cmap=cm))
if imkey not in self._lut_ctrls:
channel_name = self._get_channel_name(index)
ch_index = index.get(self._channel_axis, 0)
self._lut_ctrls[imkey] = c = LutControl(
channel_name,
f"Ch {ch_index}",
handles,
self,
cmaplist=self._cmaps + DEFAULT_COLORMAPS,
)
self._lut_drop.addWidget(c)

def _get_channel_name(self, index: Indices) -> str:
c = index.get(self._channel_axis, 0)
return f"Ch {c}" # TODO: get name from user

def _reduce_data_for_display(
self, data: np.ndarray, reductor: Callable[..., np.ndarray] = np.max
) -> np.ndarray:
Expand Down Expand Up @@ -511,22 +538,3 @@ def _clear_images(self) -> None:
cast("QVBoxLayout", self.layout()).removeWidget(c)
c.deleteLater()
self._lut_ctrls.clear()


def _to_sizes(sizes: SizesLike | None) -> Sizes:
"""Coerce `sizes` to a {dimKey -> int} mapping."""
if sizes is None:
return {}
if isinstance(sizes, Mapping):
return {k: int(v) for k, v in sizes.items()}
if not isinstance(sizes, Iterable):
raise TypeError(f"SizeLike must be an iterable or mapping, not: {type(sizes)}")
_sizes: dict[Hashable, int] = {}
for i, val in enumerate(sizes):
if isinstance(val, int):
_sizes[i] = val
elif isinstance(val, Sequence) and len(val) == 2:
_sizes[val[0]] = int(val[1])
else:
raise ValueError(f"Invalid size: {val}. Must be an int or a 2-tuple.")
return _sizes

0 comments on commit b83af28

Please sign in to comment.