diff --git a/examples/rgb.py b/examples/rgb.py new file mode 100644 index 00000000..eff61a03 --- /dev/null +++ b/examples/rgb.py @@ -0,0 +1,15 @@ +import numpy + +import ndv + +img = numpy.zeros((256, 256, 3), dtype=numpy.uint8) + +for x in range(256): + for y in range(256): + img[x, y, 0] = x + img[x, y, 1] = y + img[x, y, 2] = 255 - x + +img = img.transpose(2, 0, 1) + +ndv.imshow(img, channel_mode="rgb") diff --git a/src/ndv/util.py b/src/ndv/util.py index ac72d194..e61aeb1a 100644 --- a/src/ndv/util.py +++ b/src/ndv/util.py @@ -19,7 +19,7 @@ def imshow( data: Any | DataWrapper, cmap: Any | None = None, *, - channel_mode: Literal["mono", "composite", "auto"] = "auto", + channel_mode: Literal["mono", "composite", "rgb", "auto"] = "auto", ) -> NDViewer: """Display an array or DataWrapper in a new NDViewer window. @@ -29,7 +29,7 @@ def imshow( The data to be displayed. If not a DataWrapper, it will be wrapped in one. cmap : Any | None, optional The colormap(s) to use for displaying the data. - channel_mode : Literal['mono', 'composite'], optional + channel_mode : Literal['mono', 'composite', 'rgb'], optional The initial mode for displaying the channels. By default "mono" will be used unless a cmap is provided, in which case "composite" will be used. diff --git a/src/ndv/viewer/_backends/_pygfx.py b/src/ndv/viewer/_backends/_pygfx.py index 29886434..1260ef42 100755 --- a/src/ndv/viewer/_backends/_pygfx.py +++ b/src/ndv/viewer/_backends/_pygfx.py @@ -38,6 +38,7 @@ def __init__(self, image: pygfx.Image | pygfx.Volume, render: Callable) -> None: self._render = render self._grid = cast("Texture", image.geometry.grid) self._material = cast("ImageBasicMaterial", image.material) + self._cmap = cmap.Colormap("gray") @property def data(self) -> np.ndarray: @@ -45,6 +46,12 @@ def data(self) -> np.ndarray: @data.setter def data(self, data: np.ndarray) -> None: + if data is not None and data.ndim == 3: + # PyGFX expects (A)RGB data to be X, Y, C + for i, s in enumerate(data.shape): + if s in [3, 4]: + data = np.moveaxis(data, i, -1) + break self._grid.data[:] = data self._grid.update_range((0, 0, 0), self._grid.size) @@ -85,7 +92,9 @@ def cmap(self) -> cmap.Colormap: @cmap.setter def cmap(self, cmap: cmap.Colormap) -> None: self._cmap = cmap - self._material.map = cmap.to_pygfx() + # FIXME: RGB image special case + if self.data.ndim != 3: + self._material.map = cmap.to_pygfx() self._render() def start_move(self, pos: Sequence[float]) -> None: @@ -470,6 +479,12 @@ def add_image( self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None ) -> PyGFXImageHandle: """Add a new Image node to the scene.""" + if data is not None and data.ndim == 3: + # PyGFX expects (A)RGB data to be X, Y, C + for i, s in enumerate(data.shape): + if s in [3, 4]: + data = np.moveaxis(data, i, -1) + break tex = pygfx.Texture(data, dim=2) image = pygfx.Image( pygfx.Geometry(grid=tex), diff --git a/src/ndv/viewer/_backends/_vispy.py b/src/ndv/viewer/_backends/_vispy.py index 7b46634b..5ace2ebd 100755 --- a/src/ndv/viewer/_backends/_vispy.py +++ b/src/ndv/viewer/_backends/_vispy.py @@ -247,7 +247,7 @@ def cursor_at(self, pos: Sequence[float]) -> Qt.CursorShape | None: class VispyImageHandle: def __init__(self, visual: scene.visuals.Image | scene.visuals.Volume) -> None: self._visual = visual - self._ndim = 2 if isinstance(visual, scene.visuals.Image) else 3 + self._ndim = self.data.ndim @property def data(self) -> np.ndarray: @@ -265,6 +265,12 @@ def data(self, data: np.ndarray) -> None: stacklevel=2, ) return + if data is not None and data.ndim == 3: + # VisPy expects (A)RGB data to be X, Y, C + for i, s in enumerate(data.shape): + if s in [3, 4]: + data = np.moveaxis(data, i, -1) + break self._visual.set_data(data) @property @@ -490,11 +496,17 @@ def add_image( self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None ) -> VispyImageHandle: """Add a new Image node to the scene.""" + if data is not None and data.ndim == 3: + # VisPy expects (A)RGB data to be X, Y, C + for i, s in enumerate(data.shape): + if s in [3, 4]: + data = np.moveaxis(data, i, -1) + break img = scene.visuals.Image(data, parent=self._view.scene) img.set_gl_state("additive", depth_test=False) img.interactive = True if data is not None: - self._current_shape, prev_shape = data.shape, self._current_shape + self._current_shape, prev_shape = data.shape[:2], self._current_shape if not prev_shape: self.set_range() handle = VispyImageHandle(img) diff --git a/src/ndv/viewer/_components.py b/src/ndv/viewer/_components.py index 9dc7450b..59aa3367 100644 --- a/src/ndv/viewer/_components.py +++ b/src/ndv/viewer/_components.py @@ -6,7 +6,8 @@ from qtpy.QtCore import QSize from qtpy.QtGui import QMovie from qtpy.QtWidgets import QLabel, QPushButton, QWidget -from superqt import QIconifyIcon +from superqt import QEnumComboBox, QIconifyIcon +from superqt.utils import signals_blocked SPIN_GIF = str(Path(__file__).parent / "spin.gif") @@ -35,35 +36,26 @@ def __init__(self, parent: QWidget | None = None): class ChannelMode(str, Enum): COMPOSITE = "composite" + RGB = "rgb" MONO = "mono" def __str__(self) -> str: return self.value -class ChannelModeButton(QPushButton): +class ChannelModeCombo(QEnumComboBox): def __init__(self, parent: QWidget | None = None): - super().__init__(parent) - self.setCheckable(True) - self.toggled.connect(self.next_mode) - - # set minimum width to the width of the larger string 'composite' - self.setMinimumWidth(92) # magic number :/ - - def next_mode(self) -> None: - if self.isChecked(): - self.setMode(ChannelMode.MONO) - else: - self.setMode(ChannelMode.COMPOSITE) - - def mode(self) -> ChannelMode: - return ChannelMode.MONO if self.isChecked() else ChannelMode.COMPOSITE - - def setMode(self, mode: ChannelMode) -> None: - # we show the name of the next mode, not the current one - other = ChannelMode.COMPOSITE if mode is ChannelMode.MONO else ChannelMode.MONO - self.setText(str(other)) - self.setChecked(mode == ChannelMode.MONO) + super().__init__(parent, enum_class=ChannelMode) + + def enable_rgb(self, enable: bool) -> None: + with signals_blocked(self): + current = self.currentEnum() + self.setEnumClass(ChannelMode) + if not enable: + idx = list(ChannelMode.__members__.keys()).index("RGB") + self.removeItem(idx) + if current: + self.setCurrentEnum(current) class ROIButton(QPushButton): diff --git a/src/ndv/viewer/_data_wrapper.py b/src/ndv/viewer/_data_wrapper.py index dd4983e4..db3f1784 100644 --- a/src/ndv/viewer/_data_wrapper.py +++ b/src/ndv/viewer/_data_wrapper.py @@ -139,19 +139,15 @@ def guess_channel_axis(self) -> Hashable | None: """Return the (best guess) axis name for the channel dimension.""" # for arrays with labeled dimensions, # see if any of the dimensions are named "channel" - for dimkey, val in self.sizes().items(): + sizes = self.sizes() + for dimkey, val in sizes.items(): if str(dimkey).lower() in self.COMMON_CHANNEL_NAMES: if val <= self.MAX_CHANNELS: return dimkey # for shaped arrays, use the smallest dimension as the channel axis - shape = getattr(self._data, "shape", None) - if isinstance(shape, Sequence): - with suppress(ValueError): - smallest_dim = min(shape) - if smallest_dim <= self.MAX_CHANNELS: - return shape.index(smallest_dim) - return None + min_key = min(sizes, key=sizes.get) # type: ignore + return min_key if sizes[min_key] <= self.MAX_CHANNELS else None def save_as_zarr(self, save_loc: str | Path) -> None: raise NotImplementedError("save_as_zarr not implemented for this data type.") diff --git a/src/ndv/viewer/_lut_control.py b/src/ndv/viewer/_lut_control.py index e270e00e..90bd11af 100644 --- a/src/ndv/viewer/_lut_control.py +++ b/src/ndv/viewer/_lut_control.py @@ -135,3 +135,20 @@ def update_autoscale(self) -> None: self._clims.setMinimum(min(mi, self._clims.minimum())) self._clims.setMaximum(max(ma, self._clims.maximum())) self._clims.setValue((mi, ma)) + + +class RGBControl(LutControl): + def __init__( + self, + name: str = "", + handles: Iterable[PImageHandle] = (), + parent: QWidget | None = None, + cmaplist: Iterable[Any] = (), + auto_clim: bool = True, + ) -> None: + LutControl.__init__(self, "RGB", handles, parent, cmaplist, auto_clim) + self._cmap.setVisible(False) + + def _on_cmap_changed(self, cmap: cmap.Colormap) -> None: + # NB: No-op + pass diff --git a/src/ndv/viewer/_viewer.py b/src/ndv/viewer/_viewer.py index 5f0092dc..a004d6b9 100755 --- a/src/ndv/viewer/_viewer.py +++ b/src/ndv/viewer/_viewer.py @@ -14,7 +14,7 @@ from ndv.viewer._components import ( ChannelMode, - ChannelModeButton, + ChannelModeCombo, DimToggleButton, QSpinner, ROIButton, @@ -23,7 +23,7 @@ from ._backends import get_canvas_class from ._data_wrapper import DataWrapper from ._dims_slider import DimsSliders -from ._lut_control import LutControl +from ._lut_control import LutControl, RGBControl if TYPE_CHECKING: from concurrent.futures import Future @@ -154,8 +154,8 @@ def __init__( # WIDGETS ---------------------------------------------------- # the button that controls the display mode of the channels - self._channel_mode_btn = ChannelModeButton(self) - self._channel_mode_btn.clicked.connect(self.set_channel_mode) + self._channel_mode_combo = ChannelModeCombo(self) + self._channel_mode_combo.currentEnumChanged.connect(self.set_channel_mode) # button to reset the zoom of the canvas self._set_range_btn = QPushButton( QIconifyIcon("fluent:full-screen-maximize-24-filled"), "", self @@ -208,7 +208,7 @@ def __init__( btns.setContentsMargins(0, 0, 0, 0) btns.setSpacing(0) btns.addStretch() - btns.addWidget(self._channel_mode_btn) + btns.addWidget(self._channel_mode_combo) btns.addWidget(self._ndims_btn) btns.addWidget(self._set_range_btn) btns.addWidget(self._add_roi_btn) @@ -276,6 +276,10 @@ def set_data( or slices that define the slice of the data to display. If not provided, the initial index will be set to the middle of the data. """ + if hasattr(self, "_data_wrapper"): + for old_dim in list(self._data_wrapper.sizes()): + self.dims_sliders.remove_dimension(old_dim) + # store the data self._data_wrapper = DataWrapper.create(data) @@ -283,19 +287,25 @@ def set_data( self._channel_axis = self._data_wrapper.guess_channel_axis() # update the dimensions we are visualizing - sizes = self._data_wrapper.sizes() - visualized_dims = list(sizes)[-self._ndims :] + sizes = dict(self._data_wrapper.sizes().items()) + + visualized_dims = list(sizes) + if self._channel_axis is not None: + visualized_dims.remove(self._channel_axis) + visualized_dims = visualized_dims[-self._ndims :] self.set_visualized_dims(visualized_dims) + is_rgb = (self._channel_axis is not None) and (sizes[self._channel_axis] == 3) + self._channel_mode_combo.enable_rgb(is_rgb) + # update the range of all the sliders to match the sizes we set above with signals_blocked(self._dims_sliders): self._update_slider_ranges() # redraw if initial_index is None: - idx = self._dims_sliders.value() or { - k: int(v // 2) for k, v in sizes.items() - } + idx = {k: int(v // 2) for k, v in sizes.items()} + idx.update(self._dims_sliders.value() or {}) else: if not isinstance(initial_index, dict): # pragma: no cover raise TypeError("initial_index must be a dict") @@ -377,19 +387,20 @@ def set_channel_mode(self, mode: ChannelMode | str | None = None) -> None: """ # bool may happen when called from the button clicked signal if mode is None or isinstance(mode, bool): - mode = self._channel_mode_btn.mode() + mode = self._channel_mode_combo.currentEnum() else: mode = ChannelMode(mode) - self._channel_mode_btn.setMode(mode) + with signals_blocked(self._channel_mode_combo): + self._channel_mode_combo.setCurrentEnum(mode) if mode == self._channel_mode: return - self._channel_mode = mode + self._channel_mode = cast(ChannelMode, mode) self._cmap_cycle = cycle(self._cmaps) # reset the colormap cycle if self._channel_axis is not None: # set the visibility of the channel slider self._dims_sliders.set_dimension_visible( - self._channel_axis, mode != ChannelMode.COMPOSITE + self._channel_axis, mode not in [ChannelMode.COMPOSITE, ChannelMode.RGB] ) if self._img_handles: @@ -441,8 +452,7 @@ def _update_slider_ranges(self) -> None: maxes = self._data_wrapper.sizes() self._dims_sliders.setMaxima({k: v - 1 for k, v in maxes.items()}) - # FIXME: this needs to be moved and made user-controlled - for dim in list(maxes.keys())[-self._ndims :]: + for dim in self._visualized_dims: self._dims_sliders.set_dimension_visible(dim, False) def _on_set_range_clicked(self) -> None: @@ -471,15 +481,22 @@ def _update_data_for_index(self, index: Indices) -> None: makes a request for the new data slice and queues _on_data_future_done to be called when the data is ready. """ + indices: list[Indices] if ( self._channel_axis is not None and self._channel_mode == ChannelMode.COMPOSITE and self._channel_axis in (sizes := self._data_wrapper.sizes()) ): - indices: list[Indices] = [ + indices = [ {**index, self._channel_axis: i} for i in range(sizes[self._channel_axis]) ] + elif ( + self._channel_axis is not None + and self._channel_mode == ChannelMode.RGB + and self._channel_axis in (sizes := self._data_wrapper.sizes()) + ): + indices = [{k: v for k, v in index.items() if k != self._channel_axis}] else: indices = [index] @@ -539,19 +556,23 @@ def _update_canvas_data(self, data: np.ndarray, index: Indices) -> None: handle.data = datum if ctrl := self._lut_ctrls.get(imkey, None): ctrl.update_autoscale() + else: cm = ( next(self._cmap_cycle) if self._channel_mode == ChannelMode.COMPOSITE else GRAYS ) - if datum.ndim == 2: + if datum.ndim == 2 or self._channel_mode == ChannelMode.RGB: handles.append(self._canvas.add_image(datum, cmap=cm)) elif datum.ndim == 3: handles.append(self._canvas.add_volume(datum, cmap=cm)) if imkey not in self._lut_ctrls: + cls = ( + RGBControl if self._channel_mode == ChannelMode.RGB else LutControl + ) ch_index = index.get(self._channel_axis, 0) - self._lut_ctrls[imkey] = c = LutControl( + self._lut_ctrls[imkey] = c = cls( f"Ch {ch_index}", handles, self, @@ -574,16 +595,26 @@ def _reduce_data_for_display( the max allowed for display. The default behavior is to reduce the smallest dimensions, using np.max. This can be improved in the future. - This also coerces 64-bit data to 32-bit data. + This also coerces 64-bit data to 32-bit data, and RGB data to unsigned + 8-bit data """ # TODO # - allow dimensions to control how they are reduced (as opposed to just max) # - for better way to determine which dims need to be reduced (currently just # the smallest dims) data = data.squeeze() - visualized_dims = self._ndims - if extra_dims := data.ndim - visualized_dims: + visualized_dims = self._visualized_dims.copy() + if extra_dims := data.ndim - len(visualized_dims): shapes = sorted(enumerate(data.shape), key=lambda x: x[1]) + # HACK: Preserve channels in RGB mode + if self._channel_mode == ChannelMode.RGB: + # There should be one dimension of size 3 that we need to preserve + for i, (_dim, pos) in enumerate(shapes): + if pos == 3: + shapes.pop(i) + extra_dims -= 1 + if pos >= 3: + break smallest_dims = tuple(i for i, _ in shapes[:extra_dims]) data = reductor(data, axis=smallest_dims) @@ -592,6 +623,8 @@ def _reduce_data_for_display( data = data.astype(np.int32) else: data = data.astype(np.float32) + if self._channel_mode == ChannelMode.RGB: + data = data.astype(np.uint8) return data def _clear_images(self) -> None: