Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

suggestion #3

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions examples/rgb.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import math

import numpy

import ndv

img = numpy.zeros((256, 256, 4), dtype=numpy.float32)

for x in range(256):
for y in range(256):
img[x, y, 0] = x
img[x, y, 1] = y
img[x, y, 2] = 255 - x
img[x, y, 3] = math.sqrt((x - 128) ** 2 + (y - 128) ** 2)

# ndv assumes that a trailing dimension of size 3 or 4 is an RGB(A) image
ndv.imshow(img)
10 changes: 2 additions & 8 deletions src/ndv/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def imshow(
data: Any | DataWrapper,
cmap: Any | None = None,
*,
channel_mode: Literal["mono", "composite", "auto"] = "auto",
channel_mode: Literal["mono", "composite", "rgb", "auto"] = "auto",
) -> NDViewer:
"""Display an array or DataWrapper in a new NDViewer window.

Expand All @@ -29,7 +29,7 @@ def imshow(
The data to be displayed. If not a DataWrapper, it will be wrapped in one.
cmap : Any | None, optional
The colormap(s) to use for displaying the data.
channel_mode : Literal['mono', 'composite'], optional
channel_mode : Literal['mono', 'composite', 'rgb'], optional
The initial mode for displaying the channels. By default "mono" will be
used unless a cmap is provided, in which case "composite" will be used.

Expand All @@ -39,12 +39,6 @@ def imshow(
The viewer window.
"""
app, should_exec = _get_app()
if cmap is not None:
channel_mode = "composite"
if not isinstance(cmap, (list, tuple)):
cmap = [cmap]
elif channel_mode == "auto":
channel_mode = "mono"
viewer = NDViewer(data, colormaps=cmap, channel_mode=channel_mode)
viewer.show()
viewer.raise_()
Expand Down
17 changes: 16 additions & 1 deletion src/ndv/viewer/_backends/_pygfx.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,20 @@ def __init__(self, image: pygfx.Image | pygfx.Volume, render: Callable) -> None:
self._render = render
self._grid = cast("Texture", image.geometry.grid)
self._material = cast("ImageBasicMaterial", image.material)
self._cmap = cmap.Colormap("gray")

@property
def data(self) -> np.ndarray:
return self._grid.data # type: ignore [no-any-return]

@data.setter
def data(self, data: np.ndarray) -> None:
if data is not None and data.ndim == 3:
# PyGFX expects (A)RGB data to be X, Y, C
for i, s in enumerate(data.shape):
if s in [3, 4]:
data = np.moveaxis(data, i, -1)
break
self._grid.data[:] = data
self._grid.update_range((0, 0, 0), self._grid.size)

Expand Down Expand Up @@ -85,7 +92,9 @@ def cmap(self) -> cmap.Colormap:
@cmap.setter
def cmap(self, cmap: cmap.Colormap) -> None:
self._cmap = cmap
self._material.map = cmap.to_pygfx()
# FIXME: RGB image special case
if self.data.ndim != 3:
self._material.map = cmap.to_pygfx()
self._render()

def start_move(self, pos: Sequence[float]) -> None:
Expand Down Expand Up @@ -470,6 +479,12 @@ def add_image(
self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None
) -> PyGFXImageHandle:
"""Add a new Image node to the scene."""
if data is not None and data.ndim == 3:
# PyGFX expects (A)RGB data to be X, Y, C
for i, s in enumerate(data.shape):
if s in [3, 4]:
data = np.moveaxis(data, i, -1)
break
tex = pygfx.Texture(data, dim=2)
image = pygfx.Image(
pygfx.Geometry(grid=tex),
Expand Down
16 changes: 14 additions & 2 deletions src/ndv/viewer/_backends/_vispy.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def cursor_at(self, pos: Sequence[float]) -> Qt.CursorShape | None:
class VispyImageHandle:
def __init__(self, visual: scene.visuals.Image | scene.visuals.Volume) -> None:
self._visual = visual
self._ndim = 2 if isinstance(visual, scene.visuals.Image) else 3
self._ndim = self.data.ndim

@property
def data(self) -> np.ndarray:
Expand All @@ -266,6 +266,12 @@ def data(self, data: np.ndarray) -> None:
stacklevel=2,
)
return
if data is not None and data.ndim == 3:
# VisPy expects (A)RGB data to be X, Y, C
for i, s in enumerate(data.shape):
if s in [3, 4]:
data = np.moveaxis(data, i, -1)
break
self._visual.set_data(data)

@property
Expand Down Expand Up @@ -491,11 +497,17 @@ def add_image(
self, data: np.ndarray | None = None, cmap: cmap.Colormap | None = None
) -> VispyImageHandle:
"""Add a new Image node to the scene."""
if data is not None and data.ndim == 3:
# VisPy expects RGB(A) data to be X, Y, C
for i, s in enumerate(data.shape):
if s in [3, 4]:
data = np.moveaxis(data, i, -1)
break
img = scene.visuals.Image(data, parent=self._view.scene)
img.set_gl_state("additive", depth_test=False)
img.interactive = True
if data is not None:
self._current_shape, prev_shape = data.shape, self._current_shape
self._current_shape, prev_shape = data.shape[:2], self._current_shape
if not prev_shape:
self.set_range()
handle = VispyImageHandle(img)
Expand Down
38 changes: 15 additions & 23 deletions src/ndv/viewer/_components.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from qtpy.QtCore import QSize
from qtpy.QtGui import QMovie
from qtpy.QtWidgets import QLabel, QPushButton, QWidget
from superqt import QIconifyIcon
from superqt import QEnumComboBox, QIconifyIcon
from superqt.utils import signals_blocked

SPIN_GIF = str(Path(__file__).parent / "spin.gif")

Expand Down Expand Up @@ -35,35 +36,26 @@ def __init__(self, parent: QWidget | None = None):

class ChannelMode(str, Enum):
COMPOSITE = "composite"
RGB = "rgb"
MONO = "mono"

def __str__(self) -> str:
return self.value


class ChannelModeButton(QPushButton):
class ChannelModeCombo(QEnumComboBox):
def __init__(self, parent: QWidget | None = None):
super().__init__(parent)
self.setCheckable(True)
self.toggled.connect(self.next_mode)

# set minimum width to the width of the larger string 'composite'
self.setMinimumWidth(92) # magic number :/

def next_mode(self) -> None:
if self.isChecked():
self.setMode(ChannelMode.MONO)
else:
self.setMode(ChannelMode.COMPOSITE)

def mode(self) -> ChannelMode:
return ChannelMode.MONO if self.isChecked() else ChannelMode.COMPOSITE

def setMode(self, mode: ChannelMode) -> None:
# we show the name of the next mode, not the current one
other = ChannelMode.COMPOSITE if mode is ChannelMode.MONO else ChannelMode.MONO
self.setText(str(other))
self.setChecked(mode == ChannelMode.MONO)
super().__init__(parent, enum_class=ChannelMode)

def enable_rgb(self, enable: bool) -> None:
with signals_blocked(self):
current = self.currentEnum()
self.setEnumClass(ChannelMode)
if not enable:
idx = list(ChannelMode.__members__.keys()).index("RGB")
self.removeItem(idx)
if current:
self.setCurrentEnum(current)


class ROIButton(QPushButton):
Expand Down
12 changes: 4 additions & 8 deletions src/ndv/viewer/_data_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,19 +134,15 @@ def guess_channel_axis(self) -> Hashable | None:
"""Return the (best guess) axis name for the channel dimension."""
# for arrays with labeled dimensions,
# see if any of the dimensions are named "channel"
for dimkey, val in self.sizes().items():
sizes = self.sizes()
for dimkey, val in sizes.items():
if str(dimkey).lower() in self.COMMON_CHANNEL_NAMES:
if val <= self.MAX_CHANNELS:
return dimkey

# for shaped arrays, use the smallest dimension as the channel axis
shape = getattr(self._data, "shape", None)
if isinstance(shape, Sequence):
with suppress(ValueError):
smallest_dim = min(shape)
if smallest_dim <= self.MAX_CHANNELS:
return shape.index(smallest_dim)
return None
min_key = min(sizes, key=sizes.get) # type: ignore
return min_key if sizes[min_key] <= self.MAX_CHANNELS else None

def save_as_zarr(self, save_loc: str | Path) -> None:
raise NotImplementedError("save_as_zarr not implemented for this data type.")
Expand Down
17 changes: 17 additions & 0 deletions src/ndv/viewer/_lut_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,20 @@ def update_autoscale(self) -> None:
self._clims.setMinimum(min(mi, self._clims.minimum()))
self._clims.setMaximum(max(ma, self._clims.maximum()))
self._clims.setValue((mi, ma))


class RGBControl(LutControl):
def __init__(
self,
name: str = "",
handles: Iterable[PImageHandle] = (),
parent: QWidget | None = None,
cmaplist: Iterable[Any] = (),
auto_clim: bool = True,
) -> None:
super().__init__("RGB", handles, parent, cmaplist, auto_clim)
self._cmap.setVisible(False)

def _on_cmap_changed(self, cmap: cmap.Colormap) -> None:
# NB: No-op
pass
Loading