diff --git a/examples/rgb.py b/examples/rgb.py new file mode 100644 index 00000000..bfff0901 --- /dev/null +++ b/examples/rgb.py @@ -0,0 +1,16 @@ +import math + +import numpy + +import ndv + +img = numpy.zeros((256, 256, 4), dtype=numpy.uint8) + +for x in range(256): + for y in range(256): + img[x, y, 0] = x + img[x, y, 1] = y + img[x, y, 2] = 255 - x + img[x, y, 3] = int(math.sqrt((x - 128) ** 2 + (y - 128) ** 2)) + +n = ndv.imshow(img) diff --git a/src/ndv/util.py b/src/ndv/util.py index ac72d194..11b65e30 100644 --- a/src/ndv/util.py +++ b/src/ndv/util.py @@ -3,7 +3,7 @@ from __future__ import annotations import sys -from typing import TYPE_CHECKING, Any, Literal +from typing import TYPE_CHECKING, Any from qtpy.QtWidgets import QApplication @@ -12,14 +12,16 @@ if TYPE_CHECKING: from qtpy.QtCore import QCoreApplication + from .viewer._components import ChannelMode from .viewer._data_wrapper import DataWrapper + from .viewer._viewer import ChannelModeStr def imshow( data: Any | DataWrapper, cmap: Any | None = None, *, - channel_mode: Literal["mono", "composite", "auto"] = "auto", + channel_mode: ChannelModeStr | ChannelMode = "auto", ) -> NDViewer: """Display an array or DataWrapper in a new NDViewer window. @@ -29,7 +31,7 @@ def imshow( The data to be displayed. If not a DataWrapper, it will be wrapped in one. cmap : Any | None, optional The colormap(s) to use for displaying the data. - channel_mode : Literal['mono', 'composite'], optional + channel_mode : Literal['mono', 'composite', 'rgb', 'rgba', 'auto'], optional The initial mode for displaying the channels. By default "mono" will be used unless a cmap is provided, in which case "composite" will be used. @@ -39,12 +41,8 @@ def imshow( The viewer window. """ app, should_exec = _get_app() - if cmap is not None: - channel_mode = "composite" - if not isinstance(cmap, (list, tuple)): - cmap = [cmap] - elif channel_mode == "auto": - channel_mode = "mono" + if cmap is not None and not isinstance(cmap, (list, tuple)): + cmap = [cmap] viewer = NDViewer(data, colormaps=cmap, channel_mode=channel_mode) viewer.show() viewer.raise_() diff --git a/src/ndv/viewer/_backends/_pygfx.py b/src/ndv/viewer/_backends/_pygfx.py index b6307b03..d374f3fb 100755 --- a/src/ndv/viewer/_backends/_pygfx.py +++ b/src/ndv/viewer/_backends/_pygfx.py @@ -32,12 +32,23 @@ def _is_inside(bounding_box: np.ndarray, pos: Sequence[float]) -> bool: ) +# TODO Combine with similar function in _vispy.py +def _coerce_rgb(data: np.ndarray | None) -> np.ndarray | None: + if data is not None and data.ndim == 3: + # PyGFX expects (A)RGB data to be X, Y, C + for i, s in enumerate(data.shape): + if s in [3, 4]: + return np.moveaxis(data, i, -1) + return data + + class PyGFXImageHandle: def __init__(self, image: pygfx.Image | pygfx.Volume, render: Callable) -> None: self._image = image self._render = render self._grid = cast("Texture", image.geometry.grid) self._material = cast("ImageBasicMaterial", image.material) + self._cmap = cmap.Colormap("gray") @property def data(self) -> np.ndarray: @@ -45,7 +56,7 @@ def data(self) -> np.ndarray: @data.setter def data(self, data: np.ndarray) -> None: - self._grid.data[:] = data + self._grid.data[:] = _coerce_rgb(data) self._grid.update_range((0, 0, 0), self._grid.size) @property @@ -85,7 +96,9 @@ def cmap(self) -> cmap.Colormap: @cmap.setter def cmap(self, cmap: cmap.Colormap) -> None: self._cmap = cmap - self._material.map = cmap.to_pygfx() + # RGB image special case + if self.data.ndim != 3: + self._material.map = cmap.to_pygfx() self._render() def start_move(self, pos: Sequence[float]) -> None: @@ -409,7 +422,6 @@ class PyGFXViewerCanvas(PCanvas): """pygfx-based canvas wrapper.""" def __init__(self) -> None: - self._current_shape: tuple[int, ...] = () self._last_state: dict[Literal[2, 3], Any] = {} self._canvas = _QWgpuCanvas(size=(600, 600)) @@ -470,6 +482,7 @@ def add_image( self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None ) -> PyGFXImageHandle: """Add a new Image node to the scene.""" + data = _coerce_rgb(data) tex = pygfx.Texture(data, dim=2) image = pygfx.Image( pygfx.Geometry(grid=tex), @@ -479,9 +492,7 @@ def add_image( self._scene.add(image) if data is not None: - self._current_shape, prev_shape = data.shape, self._current_shape - if not prev_shape: - self.set_range() + self.set_range() # FIXME: I suspect there are more performant ways to refresh the canvas # look into it. @@ -504,9 +515,7 @@ def add_volume( if data is not None: vol.local_position = [-0.5 * i for i in data.shape[::-1]] - self._current_shape, prev_shape = data.shape, self._current_shape - if len(prev_shape) != 3: - self.set_range() + self.set_range() # FIXME: I suspect there are more performant ways to refresh the canvas # look into it. diff --git a/src/ndv/viewer/_backends/_vispy.py b/src/ndv/viewer/_backends/_vispy.py index e74a05e7..c06e27d6 100755 --- a/src/ndv/viewer/_backends/_vispy.py +++ b/src/ndv/viewer/_backends/_vispy.py @@ -29,6 +29,16 @@ DEFAULT_QUATERNION = Quaternion(turn, turn, 0, 0) +# TODO Combine with similar function in _pygfx.py +def _coerce_rgb(data: np.ndarray | None) -> np.ndarray | None: + if data is not None and data.ndim == 3: + # PyGFX expects (A)RGB data to be X, Y, C + for i, s in enumerate(data.shape): + if s in [3, 4]: + return np.moveaxis(data, i, -1) + return data + + class Handle(scene.visuals.Markers): """A Marker that allows specific ROI alterations.""" @@ -248,7 +258,7 @@ def cursor_at(self, pos: Sequence[float]) -> Qt.CursorShape | None: class VispyImageHandle: def __init__(self, visual: scene.visuals.Image | scene.visuals.Volume) -> None: self._visual = visual - self._ndim = 2 if isinstance(visual, scene.visuals.Image) else 3 + self._ndim = self.data.ndim @property def data(self) -> np.ndarray: @@ -266,7 +276,7 @@ def data(self, data: np.ndarray) -> None: stacklevel=2, ) return - self._visual.set_data(data) + self._visual.set_data(_coerce_rgb(data)) @property def visible(self) -> bool: @@ -490,6 +500,7 @@ def add_image( self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None ) -> VispyImageHandle: """Add a new Image node to the scene.""" + data = _coerce_rgb(data) img = scene.visuals.Image(data, parent=self._view.scene) img.set_gl_state("additive", depth_test=False) img.interactive = True diff --git a/src/ndv/viewer/_components.py b/src/ndv/viewer/_components.py index 9dc7450b..92e2b622 100644 --- a/src/ndv/viewer/_components.py +++ b/src/ndv/viewer/_components.py @@ -2,11 +2,15 @@ from enum import Enum from pathlib import Path +from typing import TYPE_CHECKING -from qtpy.QtCore import QSize +from qtpy.QtCore import QSize, Qt from qtpy.QtGui import QMovie from qtpy.QtWidgets import QLabel, QPushButton, QWidget -from superqt import QIconifyIcon +from superqt import QEnumComboBox, QIconifyIcon + +if TYPE_CHECKING: + from qtpy.QtGui import QStandardItem, QStandardItemModel SPIN_GIF = str(Path(__file__).parent / "spin.gif") @@ -35,35 +39,41 @@ def __init__(self, parent: QWidget | None = None): class ChannelMode(str, Enum): COMPOSITE = "composite" + RGBA = "rgba" MONO = "mono" + @classmethod + def _missing_(cls, value: object) -> ChannelMode | None: + if value == "rgb": + return ChannelMode.RGBA + return None + def __str__(self) -> str: return self.value -class ChannelModeButton(QPushButton): - def __init__(self, parent: QWidget | None = None): - super().__init__(parent) - self.setCheckable(True) - self.toggled.connect(self.next_mode) - - # set minimum width to the width of the larger string 'composite' - self.setMinimumWidth(92) # magic number :/ - - def next_mode(self) -> None: - if self.isChecked(): - self.setMode(ChannelMode.MONO) - else: - self.setMode(ChannelMode.COMPOSITE) - - def mode(self) -> ChannelMode: - return ChannelMode.MONO if self.isChecked() else ChannelMode.COMPOSITE - - def setMode(self, mode: ChannelMode) -> None: - # we show the name of the next mode, not the current one - other = ChannelMode.COMPOSITE if mode is ChannelMode.MONO else ChannelMode.MONO - self.setText(str(other)) - self.setChecked(mode == ChannelMode.MONO) +class ChannelModeCombo(QEnumComboBox): + """A ComboBox for ChannelMode, where the RGBA enum can be removed.""" + + def __init__(self, parent: QWidget | None = None, allow_rgba: bool = False): + super().__init__(parent, enum_class=ChannelMode) + # Find the RGBA item + idx = list(ChannelMode.__members__.keys()).index("RGBA") + model: QStandardItemModel = self.model() + self._rgba_item: QStandardItem = model.item(idx) + + self.allow_rgba(allow_rgba) + + def allow_rgba(self, enable: bool) -> None: + flags = self._rgba_item.flags() + self._rgba_item.setFlags( + flags | Qt.ItemFlag.ItemIsEnabled + if enable + else flags & ~Qt.ItemFlag.ItemIsEnabled + ) + if self.currentEnum() == ChannelMode.RGBA and not enable: + # Arbitrary fallback mode + self.setCurrentEnum(ChannelMode.COMPOSITE) class ROIButton(QPushButton): diff --git a/src/ndv/viewer/_data_wrapper.py b/src/ndv/viewer/_data_wrapper.py index 502389c2..5b10b7f1 100644 --- a/src/ndv/viewer/_data_wrapper.py +++ b/src/ndv/viewer/_data_wrapper.py @@ -134,19 +134,15 @@ def guess_channel_axis(self) -> Hashable | None: """Return the (best guess) axis name for the channel dimension.""" # for arrays with labeled dimensions, # see if any of the dimensions are named "channel" - for dimkey, val in self.sizes().items(): + sizes = self.sizes() + for dimkey, val in sizes.items(): if str(dimkey).lower() in self.COMMON_CHANNEL_NAMES: if val <= self.MAX_CHANNELS: return dimkey # for shaped arrays, use the smallest dimension as the channel axis - shape = getattr(self._data, "shape", None) - if isinstance(shape, Sequence): - with suppress(ValueError): - smallest_dim = min(shape) - if smallest_dim <= self.MAX_CHANNELS: - return shape.index(smallest_dim) - return None + min_key = min(sizes, key=sizes.get) # type: ignore + return min_key if sizes[min_key] <= self.MAX_CHANNELS else None def save_as_zarr(self, save_loc: str | Path) -> None: raise NotImplementedError("save_as_zarr not implemented for this data type.") diff --git a/src/ndv/viewer/_lut_control.py b/src/ndv/viewer/_lut_control.py index 75e823aa..ed925dae 100644 --- a/src/ndv/viewer/_lut_control.py +++ b/src/ndv/viewer/_lut_control.py @@ -135,3 +135,20 @@ def update_autoscale(self) -> None: self._clims.setMinimum(min(mi, self._clims.minimum())) self._clims.setMaximum(max(ma, self._clims.maximum())) self._clims.setValue((mi, ma)) + + +class RGBAControl(LutControl): + def __init__( + self, + name: str = "", + handles: Iterable[PImageHandle] = (), + parent: QWidget | None = None, + cmaplist: Iterable[Any] = (), + auto_clim: bool = True, + ) -> None: + super().__init__("RGB", handles, parent, cmaplist, auto_clim) + self._cmap.setVisible(False) + + def _on_cmap_changed(self, cmap: cmap.Colormap) -> None: + # NB: No-op + pass diff --git a/src/ndv/viewer/_viewer.py b/src/ndv/viewer/_viewer.py index 8b976559..daa46fcb 100755 --- a/src/ndv/viewer/_viewer.py +++ b/src/ndv/viewer/_viewer.py @@ -1,5 +1,6 @@ from __future__ import annotations +import warnings from collections import defaultdict from itertools import cycle from typing import TYPE_CHECKING, Literal, cast @@ -14,7 +15,7 @@ from ndv.viewer._components import ( ChannelMode, - ChannelModeButton, + ChannelModeCombo, DimToggleButton, QSpinner, ROIButton, @@ -23,7 +24,7 @@ from ._backends import get_canvas_class from ._data_wrapper import DataWrapper from ._dims_slider import DimsSliders -from ._lut_control import LutControl +from ._lut_control import LutControl, RGBAControl if TYPE_CHECKING: from collections.abc import Hashable, Iterable, Sequence @@ -39,6 +40,7 @@ ImgKey: TypeAlias = Hashable # any mapping of dimensions to sizes SizesLike: TypeAlias = Sizes | Iterable[int | tuple[DimKey, int] | Sequence] + ChannelModeStr: TypeAlias = Literal["mono", "composite", "rgb", "rgba", "auto"] MID_GRAY = "#888888" GRAYS = cmap.Colormap("gray") @@ -117,7 +119,7 @@ def __init__( colormaps: Iterable[cmap._colormap.ColorStopsLike] | None = None, parent: QWidget | None = None, channel_axis: DimKey | None = None, - channel_mode: ChannelMode | str = ChannelMode.MONO, + channel_mode: ChannelMode | ChannelModeStr = ChannelMode.MONO, ): super().__init__(parent=parent) @@ -136,6 +138,7 @@ def __init__( # the axis that represents the channels in the data self._channel_axis = channel_axis self._channel_mode: ChannelMode = None # type: ignore # set in set_channel_mode + self._allow_rgb: bool = False # colormaps that will be cycled through when displaying composite images if colormaps is not None: self._cmaps = [cmap.Colormap(c) for c in colormaps] @@ -156,8 +159,8 @@ def __init__( # WIDGETS ---------------------------------------------------- # the button that controls the display mode of the channels - self._channel_mode_btn = ChannelModeButton(self) - self._channel_mode_btn.clicked.connect(self.set_channel_mode) + self._channel_mode_combo = ChannelModeCombo(self, self._allow_rgb) + self._channel_mode_combo.currentEnumChanged.connect(self.set_channel_mode) # button to reset the zoom of the canvas self._set_range_btn = QPushButton( QIconifyIcon("fluent:full-screen-maximize-24-filled"), "", self @@ -210,7 +213,7 @@ def __init__( btns.setContentsMargins(0, 0, 0, 0) btns.setSpacing(0) btns.addStretch() - btns.addWidget(self._channel_mode_btn) + btns.addWidget(self._channel_mode_combo) btns.addWidget(self._ndims_btn) btns.addWidget(self._set_range_btn) btns.addWidget(self._add_roi_btn) @@ -235,9 +238,12 @@ def __init__( # SETUP ------------------------------------------------------ - self.set_channel_mode(channel_mode) self.set_data(data) + if channel_mode == "auto": + channel_mode = self._guess_channel_mode(colormaps, data) + self.set_channel_mode(channel_mode) + # ------------------- PUBLIC API ---------------------------- @property def dims_sliders(self) -> DimsSliders: @@ -280,23 +286,37 @@ def set_data( the initial index will be set to the middle of the data. """ # clear current data + self._data_wrapper = None + self._clear_images() + # NB This is particularly important to be done on every set_data call + # because when going from a N+M-dimensional image to an N-dimensional + # image, we need to remove the sliders for dims N+1, N+2, ... + self._dims_sliders.clear() + self._data_info_label.setText("") if data is None: - self._data_wrapper = None - self._clear_images() - self._dims_sliders.clear() - self._data_info_label.setText("") return # store the data - self._data_wrapper = DataWrapper.create(data) + self._data_wrapper: DataWrapper = DataWrapper.create(data) # type: ignore + + # update the dimensions we are visualizing + sizes = dict(self._data_wrapper.sizes().items()) # set channel axis self._channel_axis = self._data_wrapper.guess_channel_axis() - - # update the dimensions we are visualizing - sizes = self._data_wrapper.sizes() - visualized_dims = list(sizes)[-self._ndims :] - self.set_visualized_dims(visualized_dims) + self._allow_rgb = self._channel_axis is not None and sizes[ + self._channel_axis + ] in [3, 4] + self._channel_mode_combo.allow_rgba(self._allow_rgb) + + visualized_dims = list(sizes) + # Channel axis should never be a visualized dimension + # This axis could be at the end (common for RGB) + # or at the front (common for high-dim data) + if self._channel_axis is not None: + visualized_dims.remove(self._channel_axis) + # By convention, visualize the final dimensions + self.set_visualized_dims(visualized_dims[-self._ndims :]) # update the range of all the sliders to match the sizes we set above with signals_blocked(self._dims_sliders): @@ -388,21 +408,27 @@ def set_channel_mode(self, mode: ChannelMode | str | None = None) -> None: mode : ChannelMode | str | None The mode to set, must be one of 'composite' or 'mono'. """ - # bool may happen when called from the button clicked signal - if mode is None or isinstance(mode, bool): - mode = self._channel_mode_btn.mode() + if mode is None: + mode = self._channel_mode_combo.currentEnum() else: mode = ChannelMode(mode) - self._channel_mode_btn.setMode(mode) + if mode == ChannelMode.RGBA and not self._allow_rgb: + warnings.warn( + "Cannot set_channel_mode to RGBA without RGBA data!", + stacklevel=2, + ) + return + self._channel_mode_combo.setCurrentEnum(mode) if mode == self._channel_mode: return - self._channel_mode = mode + self._channel_mode = cast(ChannelMode, mode) self._cmap_cycle = cycle(self._cmaps) # reset the colormap cycle if self._channel_axis is not None: # set the visibility of the channel slider self._dims_sliders.set_dimension_visible( - self._channel_axis, mode != ChannelMode.COMPOSITE + self._channel_axis, + mode not in [ChannelMode.COMPOSITE, ChannelMode.RGBA], ) self.refresh() @@ -436,6 +462,21 @@ def set_current_index(self, index: Indices | None = None) -> None: # ------------------- PRIVATE METHODS ---------------------------- + def _guess_channel_mode( + self, + colormaps: Iterable[cmap._colormap.ColorStopsLike] | None = None, + data: DataWrapper | Any | None = None, + ) -> ChannelMode: + # Users who provider colormaps generally expect composite images + if colormaps is not None: + return ChannelMode.COMPOSITE + # Data shaped [Y, X, 3] or [Y, X, 4], are usually RGB images + if (shape := getattr(data, "shape", None)) is not None: + if shape[-1] in [3, 4]: + return ChannelMode.RGBA + # Default + return ChannelMode.MONO + def _toggle_3d(self) -> None: self.set_ndim(3 if self._ndims == 2 else 2) @@ -457,9 +498,15 @@ def _update_slider_ranges(self) -> None: self._dims_sliders.setMaxima({k: v - 1 for k, v in maxes.items()}) # FIXME: this needs to be moved and made user-controlled - for dim in list(maxes.keys())[-self._ndims :]: + for dim in self._visualized_dims: self._dims_sliders.set_dimension_visible(dim, False) + if self._channel_axis is not None: + self._dims_sliders.set_dimension_visible( + self._channel_axis, + self._channel_mode not in [ChannelMode.RGBA, ChannelMode.COMPOSITE], + ) + def _on_set_range_clicked(self) -> None: # using method to swallow the parameter passed by _set_range_btn.clicked self._canvas.set_range() @@ -488,15 +535,22 @@ def _update_data_for_index(self, index: Indices) -> None: """ if self._data_wrapper is None: return + indices: list[Indices] if ( self._channel_axis is not None and self._channel_mode == ChannelMode.COMPOSITE and self._channel_axis in (sizes := self._data_wrapper.sizes()) ): - indices: list[Indices] = [ + indices = [ {**index, self._channel_axis: i} for i in range(sizes[self._channel_axis]) ] + elif ( + self._channel_axis is not None + and self._channel_mode == ChannelMode.RGBA + and self._channel_axis in (sizes := self._data_wrapper.sizes()) + ): + indices = [{k: v for k, v in index.items() if k != self._channel_axis}] else: indices = [index] @@ -562,13 +616,18 @@ def _update_canvas_data(self, data: np.ndarray, index: Indices) -> None: if self._channel_mode == ChannelMode.COMPOSITE else GRAYS ) - if datum.ndim == 2: + if datum.ndim == 2 or self._channel_mode == ChannelMode.RGBA: handles.append(self._canvas.add_image(datum, cmap=cm)) elif datum.ndim == 3: handles.append(self._canvas.add_volume(datum, cmap=cm)) if imkey not in self._lut_ctrls: + cls = ( + RGBAControl + if self._channel_mode == ChannelMode.RGBA + else LutControl + ) ch_index = index.get(self._channel_axis, 0) - self._lut_ctrls[imkey] = c = LutControl( + self._lut_ctrls[imkey] = c = cls( f"Ch {ch_index}", handles, self, @@ -598,9 +657,12 @@ def _reduce_data_for_display( # - for better way to determine which dims need to be reduced (currently just # the smallest dims) data = data.squeeze() - visualized_dims = self._ndims - if extra_dims := data.ndim - visualized_dims: + if extra_dims := data.ndim - len(self._visualized_dims): shapes = sorted(enumerate(data.shape), key=lambda x: x[1]) + # Preserve channels in RGB mode + if self._channel_mode == ChannelMode.RGBA: + shapes = [s for s in shapes if s[0] != self._channel_axis] + extra_dims -= 1 smallest_dims = tuple(i for i, _ in shapes[:extra_dims]) data = reductor(data, axis=smallest_dims)