Skip to content

Commit

Permalink
feat: RGB support
Browse files Browse the repository at this point in the history
  • Loading branch information
gselzer committed Sep 16, 2024
1 parent b79241c commit 7e6e69f
Show file tree
Hide file tree
Showing 8 changed files with 144 additions and 64 deletions.
15 changes: 15 additions & 0 deletions examples/rgb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import numpy

import ndv

img = numpy.zeros((256, 256, 3), dtype=numpy.uint8)

for x in range(256):
for y in range(256):
img[x, y, 0] = x
img[x, y, 1] = y
img[x, y, 2] = 255 - x

img = img.transpose(2, 0, 1)

ndv.imshow(img, channel_mode="rgb")
4 changes: 2 additions & 2 deletions src/ndv/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def imshow(
data: Any | DataWrapper,
cmap: Any | None = None,
*,
channel_mode: Literal["mono", "composite", "auto"] = "auto",
channel_mode: Literal["mono", "composite", "rgb", "auto"] = "auto",
) -> NDViewer:
"""Display an array or DataWrapper in a new NDViewer window.
Expand All @@ -29,7 +29,7 @@ def imshow(
The data to be displayed. If not a DataWrapper, it will be wrapped in one.
cmap : Any | None, optional
The colormap(s) to use for displaying the data.
channel_mode : Literal['mono', 'composite'], optional
channel_mode : Literal['mono', 'composite', 'rgb'], optional
The initial mode for displaying the channels. By default "mono" will be
used unless a cmap is provided, in which case "composite" will be used.
Expand Down
17 changes: 16 additions & 1 deletion src/ndv/viewer/_backends/_pygfx.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,20 @@ def __init__(self, image: pygfx.Image | pygfx.Volume, render: Callable) -> None:
self._render = render
self._grid = cast("Texture", image.geometry.grid)
self._material = cast("ImageBasicMaterial", image.material)
self._cmap = cmap.Colormap("gray")

@property
def data(self) -> np.ndarray:
return self._grid.data # type: ignore [no-any-return]

@data.setter
def data(self, data: np.ndarray) -> None:
if data is not None and data.ndim == 3:
# PyGFX expects (A)RGB data to be X, Y, C
for i, s in enumerate(data.shape):
if s in [3, 4]:
data = np.moveaxis(data, i, -1)
break
self._grid.data[:] = data
self._grid.update_range((0, 0, 0), self._grid.size)

Expand Down Expand Up @@ -85,7 +92,9 @@ def cmap(self) -> cmap.Colormap:
@cmap.setter
def cmap(self, cmap: cmap.Colormap) -> None:
self._cmap = cmap
self._material.map = cmap.to_pygfx()
# FIXME: RGB image special case
if self.data.ndim != 3:
self._material.map = cmap.to_pygfx()
self._render()

def start_move(self, pos: Sequence[float]) -> None:
Expand Down Expand Up @@ -470,6 +479,12 @@ def add_image(
self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None
) -> PyGFXImageHandle:
"""Add a new Image node to the scene."""
if data is not None and data.ndim == 3:
# PyGFX expects (A)RGB data to be X, Y, C
for i, s in enumerate(data.shape):
if s in [3, 4]:
data = np.moveaxis(data, i, -1)
break
tex = pygfx.Texture(data, dim=2)
image = pygfx.Image(
pygfx.Geometry(grid=tex),
Expand Down
16 changes: 14 additions & 2 deletions src/ndv/viewer/_backends/_vispy.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def cursor_at(self, pos: Sequence[float]) -> Qt.CursorShape | None:
class VispyImageHandle:
def __init__(self, visual: scene.visuals.Image | scene.visuals.Volume) -> None:
self._visual = visual
self._ndim = 2 if isinstance(visual, scene.visuals.Image) else 3
self._ndim = self.data.ndim

@property
def data(self) -> np.ndarray:
Expand All @@ -265,6 +265,12 @@ def data(self, data: np.ndarray) -> None:
stacklevel=2,
)
return
if data is not None and data.ndim == 3:
# VisPy expects (A)RGB data to be X, Y, C
for i, s in enumerate(data.shape):
if s in [3, 4]:
data = np.moveaxis(data, i, -1)
break
self._visual.set_data(data)

@property
Expand Down Expand Up @@ -490,11 +496,17 @@ def add_image(
self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None
) -> VispyImageHandle:
"""Add a new Image node to the scene."""
if data is not None and data.ndim == 3:
# VisPy expects (A)RGB data to be X, Y, C
for i, s in enumerate(data.shape):
if s in [3, 4]:
data = np.moveaxis(data, i, -1)
break
img = scene.visuals.Image(data, parent=self._view.scene)
img.set_gl_state("additive", depth_test=False)
img.interactive = True
if data is not None:
self._current_shape, prev_shape = data.shape, self._current_shape
self._current_shape, prev_shape = data.shape[:2], self._current_shape
if not prev_shape:
self.set_range()
handle = VispyImageHandle(img)
Expand Down
38 changes: 15 additions & 23 deletions src/ndv/viewer/_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from qtpy.QtCore import QSize
from qtpy.QtGui import QMovie
from qtpy.QtWidgets import QLabel, QPushButton, QWidget
from superqt import QIconifyIcon
from superqt import QEnumComboBox, QIconifyIcon
from superqt.utils import signals_blocked

SPIN_GIF = str(Path(__file__).parent / "spin.gif")

Expand Down Expand Up @@ -35,35 +36,26 @@ def __init__(self, parent: QWidget | None = None):

class ChannelMode(str, Enum):
COMPOSITE = "composite"
RGB = "rgb"
MONO = "mono"

def __str__(self) -> str:
return self.value


class ChannelModeButton(QPushButton):
class ChannelModeCombo(QEnumComboBox):
def __init__(self, parent: QWidget | None = None):
super().__init__(parent)
self.setCheckable(True)
self.toggled.connect(self.next_mode)

# set minimum width to the width of the larger string 'composite'
self.setMinimumWidth(92) # magic number :/

def next_mode(self) -> None:
if self.isChecked():
self.setMode(ChannelMode.MONO)
else:
self.setMode(ChannelMode.COMPOSITE)

def mode(self) -> ChannelMode:
return ChannelMode.MONO if self.isChecked() else ChannelMode.COMPOSITE

def setMode(self, mode: ChannelMode) -> None:
# we show the name of the next mode, not the current one
other = ChannelMode.COMPOSITE if mode is ChannelMode.MONO else ChannelMode.MONO
self.setText(str(other))
self.setChecked(mode == ChannelMode.MONO)
super().__init__(parent, enum_class=ChannelMode)

def enable_rgb(self, enable: bool) -> None:
with signals_blocked(self):
current = self.currentEnum()
self.setEnumClass(ChannelMode)
if not enable:
idx = list(ChannelMode.__members__.keys()).index("RGB")
self.removeItem(idx)
if current:
self.setCurrentEnum(current)


class ROIButton(QPushButton):
Expand Down
12 changes: 4 additions & 8 deletions src/ndv/viewer/_data_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,19 +139,15 @@ def guess_channel_axis(self) -> Hashable | None:
"""Return the (best guess) axis name for the channel dimension."""
# for arrays with labeled dimensions,
# see if any of the dimensions are named "channel"
for dimkey, val in self.sizes().items():
sizes = self.sizes()
for dimkey, val in sizes.items():
if str(dimkey).lower() in self.COMMON_CHANNEL_NAMES:
if val <= self.MAX_CHANNELS:
return dimkey

# for shaped arrays, use the smallest dimension as the channel axis
shape = getattr(self._data, "shape", None)
if isinstance(shape, Sequence):
with suppress(ValueError):
smallest_dim = min(shape)
if smallest_dim <= self.MAX_CHANNELS:
return shape.index(smallest_dim)
return None
min_key = min(sizes, key=sizes.get) # type: ignore
return min_key if sizes[min_key] <= self.MAX_CHANNELS else None

def save_as_zarr(self, save_loc: str | Path) -> None:
raise NotImplementedError("save_as_zarr not implemented for this data type.")
Expand Down
33 changes: 24 additions & 9 deletions src/ndv/viewer/_lut_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,27 @@ def update_autoscale(self) -> None:
clims[1] = max(clims[1], np.nanmax(handle.data))

mi, ma = tuple(int(x) for x in clims)
if mi != ma:
for handle in self._handles:
handle.clim = (mi, ma)

# set the slider values to the new clims
with signals_blocked(self._clims):
self._clims.setMinimum(min(mi, self._clims.minimum()))
self._clims.setMaximum(max(ma, self._clims.maximum()))
self._clims.setValue((mi, ma))
for handle in self._handles:
handle.clim = (mi, ma)

# set the slider values to the new clims
with signals_blocked(self._clims):
self._clims.setMinimum(min(mi, self._clims.minimum()))
self._clims.setMaximum(max(ma, self._clims.maximum()))
self._clims.setValue((mi, ma))


class RGBControl(LutControl):
def __init__(
self,
name: str = "",
handles: Iterable[PImageHandle] = (),
parent: QWidget | None = None,
cmaplist: Iterable[Any] = (),
) -> None:
LutControl.__init__(self, "RGB", handles, parent, cmaplist)
self._cmap.setVisible(False)

def _on_cmap_changed(self, cmap: cmap.Colormap) -> None:
# NB: No-op
pass
Loading

0 comments on commit 7e6e69f

Please sign in to comment.