Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: RGB support #41

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions examples/rgb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import math

import numpy

import ndv

img = numpy.zeros((256, 256, 4), dtype=numpy.uint8)

for x in range(256):
for y in range(256):
img[x, y, 0] = x
img[x, y, 1] = y
img[x, y, 2] = 255 - x
img[x, y, 3] = int(math.sqrt((x - 128) ** 2 + (y - 128) ** 2))

n = ndv.imshow(img)
16 changes: 7 additions & 9 deletions src/ndv/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import sys
from typing import TYPE_CHECKING, Any, Literal
from typing import TYPE_CHECKING, Any

from qtpy.QtWidgets import QApplication

Expand All @@ -12,14 +12,16 @@
if TYPE_CHECKING:
from qtpy.QtCore import QCoreApplication

from .viewer._components import ChannelMode
from .viewer._data_wrapper import DataWrapper
from .viewer._viewer import ChannelModeStr


def imshow(
data: Any | DataWrapper,
cmap: Any | None = None,
*,
channel_mode: Literal["mono", "composite", "auto"] = "auto",
channel_mode: ChannelModeStr | ChannelMode = "auto",
) -> NDViewer:
"""Display an array or DataWrapper in a new NDViewer window.

Expand All @@ -29,7 +31,7 @@ def imshow(
The data to be displayed. If not a DataWrapper, it will be wrapped in one.
cmap : Any | None, optional
The colormap(s) to use for displaying the data.
channel_mode : Literal['mono', 'composite'], optional
channel_mode : Literal['mono', 'composite', 'rgb', 'rgba', 'auto'], optional
The initial mode for displaying the channels. By default "mono" will be
used unless a cmap is provided, in which case "composite" will be used.

Expand All @@ -39,12 +41,8 @@ def imshow(
The viewer window.
"""
app, should_exec = _get_app()
if cmap is not None:
channel_mode = "composite"
if not isinstance(cmap, (list, tuple)):
cmap = [cmap]
elif channel_mode == "auto":
channel_mode = "mono"
if cmap is not None and not isinstance(cmap, (list, tuple)):
cmap = [cmap]
viewer = NDViewer(data, colormaps=cmap, channel_mode=channel_mode)
viewer.show()
viewer.raise_()
Expand Down
27 changes: 18 additions & 9 deletions src/ndv/viewer/_backends/_pygfx.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,31 @@
)


# TODO Combine with similar function in _vispy.py
def _coerce_rgb(data: np.ndarray | None) -> np.ndarray | None:
if data is not None and data.ndim == 3:
# PyGFX expects (A)RGB data to be X, Y, C
for i, s in enumerate(data.shape):
if s in [3, 4]:
return np.moveaxis(data, i, -1)

Check warning on line 41 in src/ndv/viewer/_backends/_pygfx.py

View check run for this annotation

Codecov / codecov/patch

src/ndv/viewer/_backends/_pygfx.py#L39-L41

Added lines #L39 - L41 were not covered by tests
return data


class PyGFXImageHandle:
def __init__(self, image: pygfx.Image | pygfx.Volume, render: Callable) -> None:
self._image = image
self._render = render
self._grid = cast("Texture", image.geometry.grid)
self._material = cast("ImageBasicMaterial", image.material)
self._cmap = cmap.Colormap("gray")

@property
def data(self) -> np.ndarray:
return self._grid.data # type: ignore [no-any-return]

@data.setter
def data(self, data: np.ndarray) -> None:
self._grid.data[:] = data
self._grid.data[:] = _coerce_rgb(data)

Check warning on line 59 in src/ndv/viewer/_backends/_pygfx.py

View check run for this annotation

Codecov / codecov/patch

src/ndv/viewer/_backends/_pygfx.py#L59

Added line #L59 was not covered by tests
self._grid.update_range((0, 0, 0), self._grid.size)

@property
Expand Down Expand Up @@ -85,7 +96,9 @@
@cmap.setter
def cmap(self, cmap: cmap.Colormap) -> None:
self._cmap = cmap
self._material.map = cmap.to_pygfx()
# RGB image special case
if self.data.ndim != 3:
self._material.map = cmap.to_pygfx()
self._render()

def start_move(self, pos: Sequence[float]) -> None:
Expand Down Expand Up @@ -409,7 +422,6 @@
"""pygfx-based canvas wrapper."""

def __init__(self) -> None:
self._current_shape: tuple[int, ...] = ()
self._last_state: dict[Literal[2, 3], Any] = {}

self._canvas = _QWgpuCanvas(size=(600, 600))
Expand Down Expand Up @@ -470,6 +482,7 @@
self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None
) -> PyGFXImageHandle:
"""Add a new Image node to the scene."""
data = _coerce_rgb(data)
tex = pygfx.Texture(data, dim=2)
image = pygfx.Image(
pygfx.Geometry(grid=tex),
Expand All @@ -479,9 +492,7 @@
self._scene.add(image)

if data is not None:
self._current_shape, prev_shape = data.shape, self._current_shape
if not prev_shape:
self.set_range()
self.set_range()

# FIXME: I suspect there are more performant ways to refresh the canvas
# look into it.
Expand All @@ -504,9 +515,7 @@

if data is not None:
vol.local_position = [-0.5 * i for i in data.shape[::-1]]
self._current_shape, prev_shape = data.shape, self._current_shape
if len(prev_shape) != 3:
self.set_range()
self.set_range()

Check warning on line 518 in src/ndv/viewer/_backends/_pygfx.py

View check run for this annotation

Codecov / codecov/patch

src/ndv/viewer/_backends/_pygfx.py#L518

Added line #L518 was not covered by tests

# FIXME: I suspect there are more performant ways to refresh the canvas
# look into it.
Expand Down
15 changes: 13 additions & 2 deletions src/ndv/viewer/_backends/_vispy.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,16 @@
DEFAULT_QUATERNION = Quaternion(turn, turn, 0, 0)


# TODO Combine with similar function in _pygfx.py
def _coerce_rgb(data: np.ndarray | None) -> np.ndarray | None:
if data is not None and data.ndim == 3:
# PyGFX expects (A)RGB data to be X, Y, C
for i, s in enumerate(data.shape):
if s in [3, 4]:
return np.moveaxis(data, i, -1)
return data


class Handle(scene.visuals.Markers):
"""A Marker that allows specific ROI alterations."""

Expand Down Expand Up @@ -248,7 +258,7 @@ def cursor_at(self, pos: Sequence[float]) -> Qt.CursorShape | None:
class VispyImageHandle:
def __init__(self, visual: scene.visuals.Image | scene.visuals.Volume) -> None:
self._visual = visual
self._ndim = 2 if isinstance(visual, scene.visuals.Image) else 3
self._ndim = self.data.ndim

@property
def data(self) -> np.ndarray:
Expand All @@ -266,7 +276,7 @@ def data(self, data: np.ndarray) -> None:
stacklevel=2,
)
return
self._visual.set_data(data)
self._visual.set_data(_coerce_rgb(data))

@property
def visible(self) -> bool:
Expand Down Expand Up @@ -490,6 +500,7 @@ def add_image(
self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None
) -> VispyImageHandle:
"""Add a new Image node to the scene."""
data = _coerce_rgb(data)
img = scene.visuals.Image(data, parent=self._view.scene)
img.set_gl_state("additive", depth_test=False)
img.interactive = True
Expand Down
60 changes: 35 additions & 25 deletions src/ndv/viewer/_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,15 @@

from enum import Enum
from pathlib import Path
from typing import TYPE_CHECKING

from qtpy.QtCore import QSize
from qtpy.QtCore import QSize, Qt
from qtpy.QtGui import QMovie
from qtpy.QtWidgets import QLabel, QPushButton, QWidget
from superqt import QIconifyIcon
from superqt import QEnumComboBox, QIconifyIcon

if TYPE_CHECKING:
from qtpy.QtGui import QStandardItem, QStandardItemModel

SPIN_GIF = str(Path(__file__).parent / "spin.gif")

Expand Down Expand Up @@ -35,35 +39,41 @@

class ChannelMode(str, Enum):
COMPOSITE = "composite"
RGBA = "rgba"
MONO = "mono"

@classmethod
def _missing_(cls, value: object) -> ChannelMode | None:
if value == "rgb":
return ChannelMode.RGBA
return None

Check warning on line 49 in src/ndv/viewer/_components.py

View check run for this annotation

Codecov / codecov/patch

src/ndv/viewer/_components.py#L47-L49

Added lines #L47 - L49 were not covered by tests

def __str__(self) -> str:
return self.value


class ChannelModeButton(QPushButton):
def __init__(self, parent: QWidget | None = None):
super().__init__(parent)
self.setCheckable(True)
self.toggled.connect(self.next_mode)

# set minimum width to the width of the larger string 'composite'
self.setMinimumWidth(92) # magic number :/

def next_mode(self) -> None:
if self.isChecked():
self.setMode(ChannelMode.MONO)
else:
self.setMode(ChannelMode.COMPOSITE)

def mode(self) -> ChannelMode:
return ChannelMode.MONO if self.isChecked() else ChannelMode.COMPOSITE

def setMode(self, mode: ChannelMode) -> None:
# we show the name of the next mode, not the current one
other = ChannelMode.COMPOSITE if mode is ChannelMode.MONO else ChannelMode.MONO
self.setText(str(other))
self.setChecked(mode == ChannelMode.MONO)
class ChannelModeCombo(QEnumComboBox):
"""A ComboBox for ChannelMode, where the RGBA enum can be removed."""

def __init__(self, parent: QWidget | None = None, allow_rgba: bool = False):
super().__init__(parent, enum_class=ChannelMode)
# Find the RGBA item
idx = list(ChannelMode.__members__.keys()).index("RGBA")
model: QStandardItemModel = self.model()
self._rgba_item: QStandardItem = model.item(idx)

self.allow_rgba(allow_rgba)

def allow_rgba(self, enable: bool) -> None:
flags = self._rgba_item.flags()
self._rgba_item.setFlags(
flags | Qt.ItemFlag.ItemIsEnabled
if enable
else flags & ~Qt.ItemFlag.ItemIsEnabled
)
if self.currentEnum() == ChannelMode.RGBA and not enable:
# Arbitrary fallback mode
self.setCurrentEnum(ChannelMode.COMPOSITE)

Check warning on line 76 in src/ndv/viewer/_components.py

View check run for this annotation

Codecov / codecov/patch

src/ndv/viewer/_components.py#L76

Added line #L76 was not covered by tests


class ROIButton(QPushButton):
Expand Down
12 changes: 4 additions & 8 deletions src/ndv/viewer/_data_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,19 +134,15 @@ def guess_channel_axis(self) -> Hashable | None:
"""Return the (best guess) axis name for the channel dimension."""
# for arrays with labeled dimensions,
# see if any of the dimensions are named "channel"
for dimkey, val in self.sizes().items():
sizes = self.sizes()
for dimkey, val in sizes.items():
if str(dimkey).lower() in self.COMMON_CHANNEL_NAMES:
if val <= self.MAX_CHANNELS:
return dimkey

# for shaped arrays, use the smallest dimension as the channel axis
shape = getattr(self._data, "shape", None)
if isinstance(shape, Sequence):
with suppress(ValueError):
smallest_dim = min(shape)
if smallest_dim <= self.MAX_CHANNELS:
return shape.index(smallest_dim)
return None
min_key = min(sizes, key=sizes.get) # type: ignore
return min_key if sizes[min_key] <= self.MAX_CHANNELS else None

def save_as_zarr(self, save_loc: str | Path) -> None:
raise NotImplementedError("save_as_zarr not implemented for this data type.")
Expand Down
17 changes: 17 additions & 0 deletions src/ndv/viewer/_lut_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,20 @@ def update_autoscale(self) -> None:
self._clims.setMinimum(min(mi, self._clims.minimum()))
self._clims.setMaximum(max(ma, self._clims.maximum()))
self._clims.setValue((mi, ma))


class RGBAControl(LutControl):
def __init__(
self,
name: str = "",
handles: Iterable[PImageHandle] = (),
parent: QWidget | None = None,
cmaplist: Iterable[Any] = (),
auto_clim: bool = True,
) -> None:
super().__init__("RGB", handles, parent, cmaplist, auto_clim)
self._cmap.setVisible(False)

def _on_cmap_changed(self, cmap: cmap.Colormap) -> None:
# NB: No-op
pass
Loading
Loading