diff --git a/examples/rgb.py b/examples/rgb.py index eff61a03..7191472f 100644 --- a/examples/rgb.py +++ b/examples/rgb.py @@ -1,15 +1,18 @@ +import math + import numpy import ndv -img = numpy.zeros((256, 256, 3), dtype=numpy.uint8) +img = numpy.zeros((256, 256, 4), dtype=numpy.float32) for x in range(256): for y in range(256): img[x, y, 0] = x img[x, y, 1] = y img[x, y, 2] = 255 - x + img[x, y, 3] = math.sqrt((x - 128) ** 2 + (y - 128) ** 2) -img = img.transpose(2, 0, 1) +# img = img.transpose(2, 0, 1) -ndv.imshow(img, channel_mode="rgb") +ndv.imshow(img) diff --git a/src/ndv/util.py b/src/ndv/util.py index e61aeb1a..e33f93f7 100644 --- a/src/ndv/util.py +++ b/src/ndv/util.py @@ -19,7 +19,7 @@ def imshow( data: Any | DataWrapper, cmap: Any | None = None, *, - channel_mode: Literal["mono", "composite", "rgb", "auto"] = "auto", + channel_mode: Literal["mono", "composite", "rgba", "auto"] = "auto", ) -> NDViewer: """Display an array or DataWrapper in a new NDViewer window. @@ -29,7 +29,7 @@ def imshow( The data to be displayed. If not a DataWrapper, it will be wrapped in one. cmap : Any | None, optional The colormap(s) to use for displaying the data. - channel_mode : Literal['mono', 'composite', 'rgb'], optional + channel_mode : Literal['mono', 'composite', 'rgba'], optional The initial mode for displaying the channels. By default "mono" will be used unless a cmap is provided, in which case "composite" will be used. @@ -39,12 +39,8 @@ def imshow( The viewer window. """ app, should_exec = _get_app() - if cmap is not None: - channel_mode = "composite" - if not isinstance(cmap, (list, tuple)): - cmap = [cmap] - elif channel_mode == "auto": - channel_mode = "mono" + if cmap is not None and not isinstance(cmap, (list, tuple)): + cmap = [cmap] viewer = NDViewer(data, colormaps=cmap, channel_mode=channel_mode) viewer.show() viewer.raise_() diff --git a/src/ndv/viewer/_backends/_vispy.py b/src/ndv/viewer/_backends/_vispy.py index 45fc5fb6..31f875ce 100755 --- a/src/ndv/viewer/_backends/_vispy.py +++ b/src/ndv/viewer/_backends/_vispy.py @@ -498,7 +498,7 @@ def add_image( ) -> VispyImageHandle: """Add a new Image node to the scene.""" if data is not None and data.ndim == 3: - # VisPy expects (A)RGB data to be X, Y, C + # VisPy expects RGB(A) data to be X, Y, C for i, s in enumerate(data.shape): if s in [3, 4]: data = np.moveaxis(data, i, -1) diff --git a/src/ndv/viewer/_components.py b/src/ndv/viewer/_components.py index 59aa3367..93f9051e 100644 --- a/src/ndv/viewer/_components.py +++ b/src/ndv/viewer/_components.py @@ -36,7 +36,7 @@ def __init__(self, parent: QWidget | None = None): class ChannelMode(str, Enum): COMPOSITE = "composite" - RGB = "rgb" + RGBA = "rgba" MONO = "mono" def __str__(self) -> str: @@ -47,12 +47,12 @@ class ChannelModeCombo(QEnumComboBox): def __init__(self, parent: QWidget | None = None): super().__init__(parent, enum_class=ChannelMode) - def enable_rgb(self, enable: bool) -> None: + def enable_rgba(self, enable: bool) -> None: with signals_blocked(self): current = self.currentEnum() self.setEnumClass(ChannelMode) if not enable: - idx = list(ChannelMode.__members__.keys()).index("RGB") + idx = list(ChannelMode.__members__.keys()).index("RGBA") self.removeItem(idx) if current: self.setCurrentEnum(current) diff --git a/src/ndv/viewer/_viewer.py b/src/ndv/viewer/_viewer.py index 956e8375..b18f5ad4 100755 --- a/src/ndv/viewer/_viewer.py +++ b/src/ndv/viewer/_viewer.py @@ -120,6 +120,11 @@ def __init__( channel_mode: ChannelMode | str = ChannelMode.MONO, ): super().__init__(parent=parent) + channel_mode = ( + self._guess_channel_mode(colormaps, data) + if channel_mode == "auto" + else channel_mode + ) # ATTRIBUTES ---------------------------------------------------- @@ -296,8 +301,10 @@ def set_data( visualized_dims = visualized_dims[-self._ndims :] self.set_visualized_dims(visualized_dims) - is_rgb = (self._channel_axis is not None) and (sizes[self._channel_axis] == 3) - self._channel_mode_combo.enable_rgb(is_rgb) + is_rgba = (self._channel_axis is not None) and ( + sizes[self._channel_axis] in [3, 4] + ) + self._channel_mode_combo.enable_rgba(is_rgba) # update the range of all the sliders to match the sizes we set above with signals_blocked(self._dims_sliders): @@ -401,7 +408,8 @@ def set_channel_mode(self, mode: ChannelMode | str | None = None) -> None: if self._channel_axis is not None: # set the visibility of the channel slider self._dims_sliders.set_dimension_visible( - self._channel_axis, mode not in [ChannelMode.COMPOSITE, ChannelMode.RGB] + self._channel_axis, + mode not in [ChannelMode.COMPOSITE, ChannelMode.RGBA], ) if self._img_handles: @@ -436,6 +444,21 @@ def set_current_index(self, index: Indices | None = None) -> None: # ------------------- PRIVATE METHODS ---------------------------- + def _guess_channel_mode( + self, + colormaps: Iterable[cmap._colormap.ColorStopsLike] | None = None, + data: DataWrapper | Any | None = None, + ) -> ChannelMode: + # Users who provider colormaps generally expect composite images + if colormaps is not None: + return ChannelMode.COMPOSITE + # Data shaped [Y, X, 3] or [Y, X, 4], are usually RGB images + if (shape := getattr(data, "shape", None)) is not None: + if len(shape) == 3 and shape[-1] in [3, 4]: + return ChannelMode.RGBA + # Default + return ChannelMode.MONO + def _toggle_3d(self) -> None: self.set_ndim(3 if self._ndims == 2 else 2) @@ -494,7 +517,7 @@ def _update_data_for_index(self, index: Indices) -> None: ] elif ( self._channel_axis is not None - and self._channel_mode == ChannelMode.RGB + and self._channel_mode == ChannelMode.RGBA and self._channel_axis in (sizes := self._data_wrapper.sizes()) ): indices = [{k: v for k, v in index.items() if k != self._channel_axis}] @@ -564,13 +587,13 @@ def _update_canvas_data(self, data: np.ndarray, index: Indices) -> None: if self._channel_mode == ChannelMode.COMPOSITE else GRAYS ) - if datum.ndim == 2 or self._channel_mode == ChannelMode.RGB: + if datum.ndim == 2 or self._channel_mode == ChannelMode.RGBA: handles.append(self._canvas.add_image(datum, cmap=cm)) elif datum.ndim == 3: handles.append(self._canvas.add_volume(datum, cmap=cm)) if imkey not in self._lut_ctrls: cls = ( - RGBControl if self._channel_mode == ChannelMode.RGB else LutControl + RGBControl if self._channel_mode == ChannelMode.RGBA else LutControl ) ch_index = index.get(self._channel_axis, 0) self._lut_ctrls[imkey] = c = cls( @@ -608,10 +631,10 @@ def _reduce_data_for_display( if extra_dims := data.ndim - len(visualized_dims): shapes = sorted(enumerate(data.shape), key=lambda x: x[1]) # HACK: Preserve channels in RGB mode - if self._channel_mode == ChannelMode.RGB: + if self._channel_mode == ChannelMode.RGBA: # There should be one dimension of size 3 that we need to preserve for i, (_dim, pos) in enumerate(shapes): - if pos == 3: + if pos in [3, 4]: shapes.pop(i) extra_dims -= 1 if pos >= 3: @@ -624,7 +647,7 @@ def _reduce_data_for_display( data = data.astype(np.int32) else: data = data.astype(np.float32) - if self._channel_mode == ChannelMode.RGB: + if self._channel_mode == ChannelMode.RGBA: data = data.astype(np.uint8) return data