Skip to content

Commit

Permalink
alpha support
Browse files Browse the repository at this point in the history
  • Loading branch information
gselzer committed Oct 4, 2024
1 parent 7e07137 commit 339607c
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 24 deletions.
9 changes: 6 additions & 3 deletions examples/rgb.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,18 @@
import math

import numpy

import ndv

img = numpy.zeros((256, 256, 3), dtype=numpy.uint8)
img = numpy.zeros((256, 256, 4), dtype=numpy.float32)

for x in range(256):
for y in range(256):
img[x, y, 0] = x
img[x, y, 1] = y
img[x, y, 2] = 255 - x
img[x, y, 3] = math.sqrt((x - 128) ** 2 + (y - 128) ** 2)

img = img.transpose(2, 0, 1)
# img = img.transpose(2, 0, 1)

ndv.imshow(img, channel_mode="rgb")
ndv.imshow(img)
12 changes: 4 additions & 8 deletions src/ndv/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def imshow(
data: Any | DataWrapper,
cmap: Any | None = None,
*,
channel_mode: Literal["mono", "composite", "rgb", "auto"] = "auto",
channel_mode: Literal["mono", "composite", "rgba", "auto"] = "auto",
) -> NDViewer:
"""Display an array or DataWrapper in a new NDViewer window.
Expand All @@ -29,7 +29,7 @@ def imshow(
The data to be displayed. If not a DataWrapper, it will be wrapped in one.
cmap : Any | None, optional
The colormap(s) to use for displaying the data.
channel_mode : Literal['mono', 'composite', 'rgb'], optional
channel_mode : Literal['mono', 'composite', 'rgba'], optional
The initial mode for displaying the channels. By default "mono" will be
used unless a cmap is provided, in which case "composite" will be used.
Expand All @@ -39,12 +39,8 @@ def imshow(
The viewer window.
"""
app, should_exec = _get_app()
if cmap is not None:
channel_mode = "composite"
if not isinstance(cmap, (list, tuple)):
cmap = [cmap]
elif channel_mode == "auto":
channel_mode = "mono"
if cmap is not None and not isinstance(cmap, (list, tuple)):
cmap = [cmap]
viewer = NDViewer(data, colormaps=cmap, channel_mode=channel_mode)
viewer.show()
viewer.raise_()
Expand Down
2 changes: 1 addition & 1 deletion src/ndv/viewer/_backends/_vispy.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ def add_image(
) -> VispyImageHandle:
"""Add a new Image node to the scene."""
if data is not None and data.ndim == 3:
# VisPy expects (A)RGB data to be X, Y, C
# VisPy expects RGB(A) data to be X, Y, C
for i, s in enumerate(data.shape):
if s in [3, 4]:
data = np.moveaxis(data, i, -1)
Expand Down
6 changes: 3 additions & 3 deletions src/ndv/viewer/_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(self, parent: QWidget | None = None):

class ChannelMode(str, Enum):
COMPOSITE = "composite"
RGB = "rgb"
RGBA = "rgba"
MONO = "mono"

def __str__(self) -> str:
Expand All @@ -47,12 +47,12 @@ class ChannelModeCombo(QEnumComboBox):
def __init__(self, parent: QWidget | None = None):
super().__init__(parent, enum_class=ChannelMode)

def enable_rgb(self, enable: bool) -> None:
def enable_rgba(self, enable: bool) -> None:
with signals_blocked(self):
current = self.currentEnum()
self.setEnumClass(ChannelMode)
if not enable:
idx = list(ChannelMode.__members__.keys()).index("RGB")
idx = list(ChannelMode.__members__.keys()).index("RGBA")
self.removeItem(idx)
if current:
self.setCurrentEnum(current)
Expand Down
41 changes: 32 additions & 9 deletions src/ndv/viewer/_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,11 @@ def __init__(
channel_mode: ChannelMode | str = ChannelMode.MONO,
):
super().__init__(parent=parent)
channel_mode = (
self._guess_channel_mode(colormaps, data)
if channel_mode == "auto"
else channel_mode
)

# ATTRIBUTES ----------------------------------------------------

Expand Down Expand Up @@ -296,8 +301,10 @@ def set_data(
visualized_dims = visualized_dims[-self._ndims :]
self.set_visualized_dims(visualized_dims)

is_rgb = (self._channel_axis is not None) and (sizes[self._channel_axis] == 3)
self._channel_mode_combo.enable_rgb(is_rgb)
is_rgba = (self._channel_axis is not None) and (
sizes[self._channel_axis] in [3, 4]
)
self._channel_mode_combo.enable_rgba(is_rgba)

# update the range of all the sliders to match the sizes we set above
with signals_blocked(self._dims_sliders):
Expand Down Expand Up @@ -401,7 +408,8 @@ def set_channel_mode(self, mode: ChannelMode | str | None = None) -> None:
if self._channel_axis is not None:
# set the visibility of the channel slider
self._dims_sliders.set_dimension_visible(
self._channel_axis, mode not in [ChannelMode.COMPOSITE, ChannelMode.RGB]
self._channel_axis,
mode not in [ChannelMode.COMPOSITE, ChannelMode.RGBA],
)

if self._img_handles:
Expand Down Expand Up @@ -436,6 +444,21 @@ def set_current_index(self, index: Indices | None = None) -> None:

# ------------------- PRIVATE METHODS ----------------------------

def _guess_channel_mode(
self,
colormaps: Iterable[cmap._colormap.ColorStopsLike] | None = None,
data: DataWrapper | Any | None = None,
) -> ChannelMode:
# Users who provider colormaps generally expect composite images
if colormaps is not None:
return ChannelMode.COMPOSITE
# Data shaped [Y, X, 3] or [Y, X, 4], are usually RGB images
if (shape := getattr(data, "shape", None)) is not None:
if len(shape) == 3 and shape[-1] in [3, 4]:
return ChannelMode.RGBA
# Default
return ChannelMode.MONO

def _toggle_3d(self) -> None:
self.set_ndim(3 if self._ndims == 2 else 2)

Expand Down Expand Up @@ -494,7 +517,7 @@ def _update_data_for_index(self, index: Indices) -> None:
]
elif (
self._channel_axis is not None
and self._channel_mode == ChannelMode.RGB
and self._channel_mode == ChannelMode.RGBA
and self._channel_axis in (sizes := self._data_wrapper.sizes())
):
indices = [{k: v for k, v in index.items() if k != self._channel_axis}]
Expand Down Expand Up @@ -564,13 +587,13 @@ def _update_canvas_data(self, data: np.ndarray, index: Indices) -> None:
if self._channel_mode == ChannelMode.COMPOSITE
else GRAYS
)
if datum.ndim == 2 or self._channel_mode == ChannelMode.RGB:
if datum.ndim == 2 or self._channel_mode == ChannelMode.RGBA:
handles.append(self._canvas.add_image(datum, cmap=cm))
elif datum.ndim == 3:
handles.append(self._canvas.add_volume(datum, cmap=cm))
if imkey not in self._lut_ctrls:
cls = (
RGBControl if self._channel_mode == ChannelMode.RGB else LutControl
RGBControl if self._channel_mode == ChannelMode.RGBA else LutControl
)
ch_index = index.get(self._channel_axis, 0)
self._lut_ctrls[imkey] = c = cls(
Expand Down Expand Up @@ -608,10 +631,10 @@ def _reduce_data_for_display(
if extra_dims := data.ndim - len(visualized_dims):
shapes = sorted(enumerate(data.shape), key=lambda x: x[1])
# HACK: Preserve channels in RGB mode
if self._channel_mode == ChannelMode.RGB:
if self._channel_mode == ChannelMode.RGBA:
# There should be one dimension of size 3 that we need to preserve
for i, (_dim, pos) in enumerate(shapes):
if pos == 3:
if pos in [3, 4]:
shapes.pop(i)
extra_dims -= 1
if pos >= 3:
Expand All @@ -624,7 +647,7 @@ def _reduce_data_for_display(
data = data.astype(np.int32)
else:
data = data.astype(np.float32)
if self._channel_mode == ChannelMode.RGB:
if self._channel_mode == ChannelMode.RGBA:
data = data.astype(np.uint8)
return data

Expand Down

0 comments on commit 339607c

Please sign in to comment.