Skip to content

Commit

Permalink
feat: RGB support
Browse files Browse the repository at this point in the history
For provenance, VisPy allows RGB data to come from unsigned integers
and floating point data pre-nomalized within the range [0, 1].
  • Loading branch information
gselzer committed Oct 9, 2024
1 parent 10abe13 commit cf46ae2
Show file tree
Hide file tree
Showing 8 changed files with 201 additions and 82 deletions.
16 changes: 16 additions & 0 deletions examples/rgb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import math

import numpy

import ndv

img = numpy.zeros((256, 256, 4), dtype=numpy.uint8)

for x in range(256):
for y in range(256):
img[x, y, 0] = x
img[x, y, 1] = y
img[x, y, 2] = 255 - x
img[x, y, 3] = int(math.sqrt((x - 128) ** 2 + (y - 128) ** 2))

n = ndv.imshow(img)
16 changes: 7 additions & 9 deletions src/ndv/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import sys
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any

from qtpy.QtWidgets import QApplication

Expand All @@ -12,14 +12,16 @@
if TYPE_CHECKING:
from qtpy.QtCore import QCoreApplication

from .viewer._components import ChannelMode
from .viewer._data_wrapper import DataWrapper
from .viewer._viewer import ChannelModeStr


def imshow(
data: Any | DataWrapper,
cmap: Any | None = None,
*,
channel_mode: Literal["mono", "composite", "auto"] = "auto",
channel_mode: ChannelModeStr | ChannelMode = "auto",
) -> NDViewer:
"""Display an array or DataWrapper in a new NDViewer window.
Expand All @@ -29,7 +31,7 @@ def imshow(
The data to be displayed. If not a DataWrapper, it will be wrapped in one.
cmap : Any | None, optional
The colormap(s) to use for displaying the data.
channel_mode : Literal['mono', 'composite'], optional
channel_mode : Literal['mono', 'composite', 'rgb', 'rgba', 'auto'], optional
The initial mode for displaying the channels. By default "mono" will be
used unless a cmap is provided, in which case "composite" will be used.
Expand All @@ -39,12 +41,8 @@ def imshow(
The viewer window.
"""
app, should_exec = _get_app()
if cmap is not None:
channel_mode = "composite"
if not isinstance(cmap, (list, tuple)):
cmap = [cmap]
elif channel_mode == "auto":
channel_mode = "mono"
if cmap is not None and not isinstance(cmap, (list, tuple)):
cmap = [cmap]
viewer = NDViewer(data, colormaps=cmap, channel_mode=channel_mode)
viewer.show()
viewer.raise_()
Expand Down
27 changes: 18 additions & 9 deletions src/ndv/viewer/_backends/_pygfx.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,31 @@ def _is_inside(bounding_box: np.ndarray, pos: Sequence[float]) -> bool:
)


# TODO Combine with similar function in _vispy.py
def _coerce_rgb(data: np.ndarray | None) -> np.ndarray | None:
if data is not None and data.ndim == 3:
# PyGFX expects (A)RGB data to be X, Y, C
for i, s in enumerate(data.shape):
if s in [3, 4]:
return np.moveaxis(data, i, -1)
return data


class PyGFXImageHandle:
def __init__(self, image: pygfx.Image | pygfx.Volume, render: Callable) -> None:
self._image = image
self._render = render
self._grid = cast("Texture", image.geometry.grid)
self._material = cast("ImageBasicMaterial", image.material)
self._cmap = cmap.Colormap("gray")

@property
def data(self) -> np.ndarray:
return self._grid.data # type: ignore [no-any-return]

@data.setter
def data(self, data: np.ndarray) -> None:
self._grid.data[:] = data
self._grid.data[:] = _coerce_rgb(data)
self._grid.update_range((0, 0, 0), self._grid.size)

@property
Expand Down Expand Up @@ -85,7 +96,9 @@ def cmap(self) -> cmap.Colormap:
@cmap.setter
def cmap(self, cmap: cmap.Colormap) -> None:
self._cmap = cmap
self._material.map = cmap.to_pygfx()
# RGB image special case
if self.data.ndim != 3:
self._material.map = cmap.to_pygfx()
self._render()

def start_move(self, pos: Sequence[float]) -> None:
Expand Down Expand Up @@ -409,7 +422,6 @@ class PyGFXViewerCanvas(PCanvas):
"""pygfx-based canvas wrapper."""

def __init__(self) -> None:
self._current_shape: tuple[int, ...] = ()
self._last_state: dict[Literal[2, 3], Any] = {}

self._canvas = _QWgpuCanvas(size=(600, 600))
Expand Down Expand Up @@ -470,6 +482,7 @@ def add_image(
self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None
) -> PyGFXImageHandle:
"""Add a new Image node to the scene."""
data = _coerce_rgb(data)
tex = pygfx.Texture(data, dim=2)
image = pygfx.Image(
pygfx.Geometry(grid=tex),
Expand All @@ -479,9 +492,7 @@ def add_image(
self._scene.add(image)

if data is not None:
self._current_shape, prev_shape = data.shape, self._current_shape
if not prev_shape:
self.set_range()
self.set_range()

# FIXME: I suspect there are more performant ways to refresh the canvas
# look into it.
Expand All @@ -504,9 +515,7 @@ def add_volume(

if data is not None:
vol.local_position = [-0.5 * i for i in data.shape[::-1]]
self._current_shape, prev_shape = data.shape, self._current_shape
if len(prev_shape) != 3:
self.set_range()
self.set_range()

# FIXME: I suspect there are more performant ways to refresh the canvas
# look into it.
Expand Down
15 changes: 13 additions & 2 deletions src/ndv/viewer/_backends/_vispy.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,16 @@
DEFAULT_QUATERNION = Quaternion(turn, turn, 0, 0)


# TODO Combine with similar function in _pygfx.py
def _coerce_rgb(data: np.ndarray | None) -> np.ndarray | None:
if data is not None and data.ndim == 3:
# PyGFX expects (A)RGB data to be X, Y, C
for i, s in enumerate(data.shape):
if s in [3, 4]:
return np.moveaxis(data, i, -1)
return data


class Handle(scene.visuals.Markers):
"""A Marker that allows specific ROI alterations."""

Expand Down Expand Up @@ -248,7 +258,7 @@ def cursor_at(self, pos: Sequence[float]) -> Qt.CursorShape | None:
class VispyImageHandle:
def __init__(self, visual: scene.visuals.Image | scene.visuals.Volume) -> None:
self._visual = visual
self._ndim = 2 if isinstance(visual, scene.visuals.Image) else 3
self._ndim = self.data.ndim

@property
def data(self) -> np.ndarray:
Expand All @@ -266,7 +276,7 @@ def data(self, data: np.ndarray) -> None:
stacklevel=2,
)
return
self._visual.set_data(data)
self._visual.set_data(_coerce_rgb(data))

@property
def visible(self) -> bool:
Expand Down Expand Up @@ -490,6 +500,7 @@ def add_image(
self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None
) -> VispyImageHandle:
"""Add a new Image node to the scene."""
data = _coerce_rgb(data)
img = scene.visuals.Image(data, parent=self._view.scene)
img.set_gl_state("additive", depth_test=False)
img.interactive = True
Expand Down
60 changes: 35 additions & 25 deletions src/ndv/viewer/_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@

from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING

from qtpy.QtCore import QSize
from qtpy.QtCore import QSize, Qt
from qtpy.QtGui import QMovie
from qtpy.QtWidgets import QLabel, QPushButton, QWidget
from superqt import QIconifyIcon
from superqt import QEnumComboBox, QIconifyIcon

if TYPE_CHECKING:
from qtpy.QtGui import QStandardItem, QStandardItemModel

SPIN_GIF = str(Path(__file__).parent / "spin.gif")

Expand Down Expand Up @@ -35,35 +39,41 @@ def __init__(self, parent: QWidget | None = None):

class ChannelMode(str, Enum):
COMPOSITE = "composite"
RGBA = "rgba"
MONO = "mono"

@classmethod
def _missing_(cls, value: object) -> ChannelMode | None:
if value == "rgb":
return ChannelMode.RGBA
return None

def __str__(self) -> str:
return self.value


class ChannelModeButton(QPushButton):
def __init__(self, parent: QWidget | None = None):
super().__init__(parent)
self.setCheckable(True)
self.toggled.connect(self.next_mode)

# set minimum width to the width of the larger string 'composite'
self.setMinimumWidth(92) # magic number :/

def next_mode(self) -> None:
if self.isChecked():
self.setMode(ChannelMode.MONO)
else:
self.setMode(ChannelMode.COMPOSITE)

def mode(self) -> ChannelMode:
return ChannelMode.MONO if self.isChecked() else ChannelMode.COMPOSITE

def setMode(self, mode: ChannelMode) -> None:
# we show the name of the next mode, not the current one
other = ChannelMode.COMPOSITE if mode is ChannelMode.MONO else ChannelMode.MONO
self.setText(str(other))
self.setChecked(mode == ChannelMode.MONO)
class ChannelModeCombo(QEnumComboBox):
"""A ComboBox for ChannelMode, where the RGBA enum can be removed."""

def __init__(self, parent: QWidget | None = None, allow_rgba: bool = False):
super().__init__(parent, enum_class=ChannelMode)
# Find the RGBA item
idx = list(ChannelMode.__members__.keys()).index("RGBA")
model: QStandardItemModel = self.model()
self._rgba_item: QStandardItem = model.item(idx)

self.allow_rgba(allow_rgba)

def allow_rgba(self, enable: bool) -> None:
flags = self._rgba_item.flags()
self._rgba_item.setFlags(
flags | Qt.ItemFlag.ItemIsEnabled
if enable
else flags & ~Qt.ItemFlag.ItemIsEnabled
)
if self.currentEnum() == ChannelMode.RGBA and not enable:
# Arbitrary fallback mode
self.setCurrentEnum(ChannelMode.COMPOSITE)


class ROIButton(QPushButton):
Expand Down
12 changes: 4 additions & 8 deletions src/ndv/viewer/_data_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,19 +134,15 @@ def guess_channel_axis(self) -> Hashable | None:
"""Return the (best guess) axis name for the channel dimension."""
# for arrays with labeled dimensions,
# see if any of the dimensions are named "channel"
for dimkey, val in self.sizes().items():
sizes = self.sizes()
for dimkey, val in sizes.items():
if str(dimkey).lower() in self.COMMON_CHANNEL_NAMES:
if val <= self.MAX_CHANNELS:
return dimkey

# for shaped arrays, use the smallest dimension as the channel axis
shape = getattr(self._data, "shape", None)
if isinstance(shape, Sequence):
with suppress(ValueError):
smallest_dim = min(shape)
if smallest_dim <= self.MAX_CHANNELS:
return shape.index(smallest_dim)
return None
min_key = min(sizes, key=sizes.get) # type: ignore
return min_key if sizes[min_key] <= self.MAX_CHANNELS else None

def save_as_zarr(self, save_loc: str | Path) -> None:
raise NotImplementedError("save_as_zarr not implemented for this data type.")
Expand Down
17 changes: 17 additions & 0 deletions src/ndv/viewer/_lut_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,20 @@ def update_autoscale(self) -> None:
self._clims.setMinimum(min(mi, self._clims.minimum()))
self._clims.setMaximum(max(ma, self._clims.maximum()))
self._clims.setValue((mi, ma))


class RGBAControl(LutControl):
def __init__(
self,
name: str = "",
handles: Iterable[PImageHandle] = (),
parent: QWidget | None = None,
cmaplist: Iterable[Any] = (),
auto_clim: bool = True,
) -> None:
super().__init__("RGB", handles, parent, cmaplist, auto_clim)
self._cmap.setVisible(False)

def _on_cmap_changed(self, cmap: cmap.Colormap) -> None:
# NB: No-op
pass
Loading

0 comments on commit cf46ae2

Please sign in to comment.