diff --git a/whitecanvas/backend/bokeh/_labels.py b/whitecanvas/backend/bokeh/_labels.py index 1ae6ecbb..1d5fe62b 100644 --- a/whitecanvas/backend/bokeh/_labels.py +++ b/whitecanvas/backend/bokeh/_labels.py @@ -8,7 +8,7 @@ from cmap import Color from whitecanvas.backend.bokeh._base import to_bokeh_line_style -from whitecanvas.types import LineStyle +from whitecanvas.types import AxisScale, LineStyle if TYPE_CHECKING: from bokeh.models import Axis as BokehAxis @@ -94,6 +94,9 @@ def _plt_get_color(self): def _plt_set_color(self, color): self._plt_get_axis().axis_label_text_color = Color(color).hex + def _plt_set_scale(self, scale: AxisScale) -> None: + raise NotImplementedError("Bokeh does not support dynamically changing scale.") + class Label(_CanvasComponent): def __init__(self, canvas: Canvas): diff --git a/whitecanvas/backend/matplotlib/_labels.py b/whitecanvas/backend/matplotlib/_labels.py index 3b4f492e..f0df043d 100644 --- a/whitecanvas/backend/matplotlib/_labels.py +++ b/whitecanvas/backend/matplotlib/_labels.py @@ -4,9 +4,8 @@ from typing import TYPE_CHECKING from matplotlib import pyplot as plt -from matplotlib.ticker import AutoLocator, AutoMinorLocator -from whitecanvas.types import LineStyle +from whitecanvas.types import AxisScale, LineStyle if TYPE_CHECKING: from whitecanvas.backend.matplotlib.canvas import Canvas @@ -287,6 +286,15 @@ def _plt_set_grid_state(self, visible: bool, color, width: float, style: LineSty linewidth=width, ) + def _plt_set_scale(self, scale: AxisScale): + if scale is AxisScale.LINEAR: + self._canvas()._axes.set_xscale("linear") + elif scale is AxisScale.LOG: + self._canvas()._axes.set_xscale("log") + else: + raise ValueError(f"Invalid scale: {scale}") + self._canvas()._plt_draw() + class YAxis(AxisBase): def __init__(self, canvas: Canvas): @@ -326,3 +334,12 @@ def _plt_set_grid_state(self, visible: bool, color, width: float, style: LineSty linestyle=style.value, linewidth=width, ) + + def _plt_set_scale(self, scale: AxisScale): + if scale is AxisScale.LINEAR: + self._canvas()._axes.set_yscale("linear") + elif scale is AxisScale.LOG: + self._canvas()._axes.set_yscale("log") + else: + raise ValueError(f"Invalid scale: {scale}") + self._canvas()._plt_draw() diff --git a/whitecanvas/backend/mock/canvas.py b/whitecanvas/backend/mock/canvas.py index 64a98586..7a6f0e8a 100644 --- a/whitecanvas/backend/mock/canvas.py +++ b/whitecanvas/backend/mock/canvas.py @@ -253,6 +253,9 @@ def _plt_set_limits(self, limits: tuple[float, float]): def _plt_set_grid_state(self, *args, **kwargs): pass + def _plt_set_scale(self, scale): + pass + class Ticks(_SupportsText): def __init__(self): diff --git a/whitecanvas/backend/plotly/_labels.py b/whitecanvas/backend/plotly/_labels.py index e620f2d8..e6ccb318 100644 --- a/whitecanvas/backend/plotly/_labels.py +++ b/whitecanvas/backend/plotly/_labels.py @@ -5,7 +5,7 @@ import numpy as np -from whitecanvas.types import LineStyle +from whitecanvas.types import AxisScale, LineStyle from whitecanvas.utils.normalize import rgba_str_color if TYPE_CHECKING: @@ -96,13 +96,27 @@ def _plt_get_axis(self): return getattr(self._canvas()._subplot_layout(), self._axis) def _plt_get_limits(self) -> tuple[float, float]: - lim = self._plt_get_axis().range + axis = self._plt_get_axis() + lim = axis.range + typ = axis.type if lim is None: - lim = (0, 1) # TODO: how to get the limits? - return lim + # default value + if typ == "linear": + return (0, 1) + else: + return (0.1, 1) + else: + if typ == "linear": + return lim + else: + return 10 ** lim[0], 10 ** lim[1] def _plt_set_limits(self, limits: tuple[float, float]): - self._plt_get_axis().range = limits + axis = self._plt_get_axis() + if axis.type == "linear": + axis.range = limits + else: + axis.range = np.log10(limits) def _plt_get_color(self): # color of the axis itself @@ -123,6 +137,14 @@ def _plt_set_grid_state(self, visible: bool, color, width: float, style: LineSty axis.gridcolor = rgba_str_color(color) axis.gridwidth = width + def _plt_set_scale(self, scale: AxisScale): + if scale is AxisScale.LINEAR: + self._plt_get_axis().type = "linear" + elif scale is AxisScale.LOG: + self._plt_get_axis().type = "log" + else: + raise ValueError(f"Invalid scale: {scale}") + class Ticks(_CanvasComponent): def __init__(self, canvas: Canvas, axis: str): diff --git a/whitecanvas/backend/pyqtgraph/_labels.py b/whitecanvas/backend/pyqtgraph/_labels.py index a6530229..4b509611 100644 --- a/whitecanvas/backend/pyqtgraph/_labels.py +++ b/whitecanvas/backend/pyqtgraph/_labels.py @@ -10,7 +10,7 @@ from whitecanvas.backend.pyqtgraph._base import PyQtAxis from whitecanvas.backend.pyqtgraph._qt_utils import array_to_qcolor -from whitecanvas.types import LineStyle +from whitecanvas.types import AxisScale, LineStyle if TYPE_CHECKING: import pyqtgraph as pg @@ -163,6 +163,14 @@ def _plt_set_grid_state(self, visible: bool, color, width: float, style: LineSty axis.setGrid(grid) # tick disappears by unknown reason. + def _plt_set_scale(self, scale: AxisScale): + if scale is AxisScale.LINEAR: + self._plt_get_axis().setLogMode(False) + elif scale is AxisScale.LOG: + self._plt_get_axis().setLogMode(True) + else: + raise ValueError(f"Unknown scale: {scale}") + class Ticks(_CanvasComponent): def __init__(self, canvas: Canvas, axis: str): diff --git a/whitecanvas/backend/vispy/_label.py b/whitecanvas/backend/vispy/_label.py index 41134355..550c8d02 100644 --- a/whitecanvas/backend/vispy/_label.py +++ b/whitecanvas/backend/vispy/_label.py @@ -7,7 +7,7 @@ from vispy import scene from vispy.visuals.axis import AxisVisual, Ticker -from whitecanvas.types import LineStyle +from whitecanvas.types import AxisScale, LineStyle if TYPE_CHECKING: from vispy.visuals import LineVisual, TextVisual @@ -109,6 +109,15 @@ def _plt_set_grid_state(self, visible: bool, color, width: float, style: LineSty else: grid_lines.set_x_grid_lines(visible, color, width, style) + def _plt_set_scale(self, scale: AxisScale): + if scale is AxisScale.LINEAR: + self.axis.scale_type = "linear" + elif scale is AxisScale.LOG: + # NOTE: not implemented on vispy side yet + self.axis.scale_type = "logarithmic" + else: + raise ValueError(f"Invalid scale: {scale}") + class Ticks: def __init__(self, axis: Axis): diff --git a/whitecanvas/canvas/_namespaces.py b/whitecanvas/canvas/_namespaces.py index 2453d2e3..cc34e9fe 100644 --- a/whitecanvas/canvas/_namespaces.py +++ b/whitecanvas/canvas/_namespaces.py @@ -18,6 +18,7 @@ from whitecanvas._exceptions import ReferenceDeletedError from whitecanvas._signal import MouseMoveSignal, MouseSignal from whitecanvas.types import ( + AxisScale, ColorType, LineStyle, Modifier, @@ -286,6 +287,7 @@ def __init__(self, canvas: CanvasBase | None = None): super().__init__(canvas) self.events = AxisSignals() self._flipped = False + self._scale = AxisScale.LINEAR def _get_object(self) -> protocols.AxisProtocol: raise NotImplementedError @@ -335,17 +337,50 @@ def flipped(self, flipped: bool): def set_gridlines( self, + *, visible: bool = True, color: ColorType = "gray", width: float = 1.0, style: str | LineStyle = LineStyle.SOLID, ): + """ + Update the properties of grid lines for the axis. + + Parameters + ---------- + visible : bool, default True + Whether to show the grid lines. + color : color, default "gray" + The color of the grid lines. + width : float, default 1.0 + The width of the grid lines. + style : str or LineStyle, default "solid" + The style of the grid lines. + """ color = arr_color(color) style = LineStyle(style) if width < 0: raise ValueError("width must be non-negative.") self._get_object()._plt_set_grid_state(visible, color, width, style) + @property + def scale(self) -> AxisScale: + """Scale (linear or log) of the axis.""" + return self._scale + + @scale.setter + def scale(self, scale): + scale = AxisScale(scale) + if scale is AxisScale.LOG: + _min, _max = self.lim + if _min <= 0: + if _max <= 0: + self.lim = (0.01, 1) # default limits + else: + self.lim = (_max / 100, _max) + self._get_object()._plt_set_scale(scale) + self._scale = scale + class XAxisNamespace(AxisNamespace): label = XLabelNamespace() diff --git a/whitecanvas/protocols/canvas_protocol.py b/whitecanvas/protocols/canvas_protocol.py index b5c4b401..eb9a8aae 100644 --- a/whitecanvas/protocols/canvas_protocol.py +++ b/whitecanvas/protocols/canvas_protocol.py @@ -9,7 +9,7 @@ from numpy.typing import NDArray from whitecanvas.layers._legend import LegendItem - from whitecanvas.types import LineStyle, Location, MouseEvent + from whitecanvas.types import AxisScale, LineStyle, Location, MouseEvent @runtime_checkable @@ -187,6 +187,9 @@ def _plt_flip(self) -> None: def _plt_set_grid_state(self, visible: bool, color, width: float, style: LineStyle): """Set the grid line.""" + def _plt_set_scale(self, scale: AxisScale) -> None: + """Set scale of axis""" + @runtime_checkable class AxisGridProtocol(Protocol): diff --git a/whitecanvas/types/__init__.py b/whitecanvas/types/__init__.py index ea0822a7..154c0911 100644 --- a/whitecanvas/types/__init__.py +++ b/whitecanvas/types/__init__.py @@ -8,6 +8,7 @@ ) from whitecanvas.types._enums import ( Alignment, + AxisScale, Hatch, HistogramKind, HistogramShape, @@ -47,6 +48,7 @@ "Point", "MouseEventType", "Alignment", + "AxisScale", "XYData", "XYYData", "XYTextData", diff --git a/whitecanvas/types/_enums.py b/whitecanvas/types/_enums.py index 578a5b85..357afe63 100644 --- a/whitecanvas/types/_enums.py +++ b/whitecanvas/types/_enums.py @@ -248,3 +248,8 @@ class HistogramKind(_StrEnum): probability = "probability" frequency = "frequency" percent = "percent" + + +class AxisScale(_StrEnum): + LINEAR = "linear" + LOG = "log"