From f07c12e04613e63bc6522929a078d3458ae88eb4 Mon Sep 17 00:00:00 2001 From: Hanjin Liu Date: Tue, 27 Feb 2024 23:15:19 +0900 Subject: [PATCH] log scale wip --- whitecanvas/backend/bokeh/_labels.py | 5 +++- whitecanvas/backend/matplotlib/_labels.py | 21 +++++++++++-- whitecanvas/backend/mock/canvas.py | 3 ++ whitecanvas/backend/plotly/_labels.py | 32 ++++++++++++++++---- whitecanvas/backend/pyqtgraph/_labels.py | 10 ++++++- whitecanvas/backend/vispy/_label.py | 11 ++++++- whitecanvas/canvas/_namespaces.py | 36 ++++++++++++++++++++++- whitecanvas/protocols/canvas_protocol.py | 5 +++- whitecanvas/types/__init__.py | 2 ++ whitecanvas/types/_enums.py | 5 ++++ 10 files changed, 118 insertions(+), 12 deletions(-) diff --git a/whitecanvas/backend/bokeh/_labels.py b/whitecanvas/backend/bokeh/_labels.py index 78f137f1..f0d4cf58 100644 --- a/whitecanvas/backend/bokeh/_labels.py +++ b/whitecanvas/backend/bokeh/_labels.py @@ -8,7 +8,7 @@ from cmap import Color from whitecanvas.backend.bokeh._base import to_bokeh_line_style -from whitecanvas.types import LineStyle +from whitecanvas.types import AxisScale, LineStyle if TYPE_CHECKING: from bokeh.models import Axis as BokehAxis @@ -94,6 +94,9 @@ def _plt_get_color(self): def _plt_set_color(self, color): self._plt_get_axis().axis_label_text_color = Color(color).hex + def _plt_set_scale(self, scale: AxisScale) -> None: + raise NotImplementedError("Bokeh does not support dynamically changing scale.") + class Label(_CanvasComponent): def __init__(self, canvas: Canvas): diff --git a/whitecanvas/backend/matplotlib/_labels.py b/whitecanvas/backend/matplotlib/_labels.py index 3b4f492e..f0df043d 100644 --- a/whitecanvas/backend/matplotlib/_labels.py +++ b/whitecanvas/backend/matplotlib/_labels.py @@ -4,9 +4,8 @@ from typing import TYPE_CHECKING from matplotlib import pyplot as plt -from matplotlib.ticker import AutoLocator, AutoMinorLocator -from whitecanvas.types import LineStyle +from whitecanvas.types import AxisScale, LineStyle if TYPE_CHECKING: from whitecanvas.backend.matplotlib.canvas import Canvas @@ -287,6 +286,15 @@ def _plt_set_grid_state(self, visible: bool, color, width: float, style: LineSty linewidth=width, ) + def _plt_set_scale(self, scale: AxisScale): + if scale is AxisScale.LINEAR: + self._canvas()._axes.set_xscale("linear") + elif scale is AxisScale.LOG: + self._canvas()._axes.set_xscale("log") + else: + raise ValueError(f"Invalid scale: {scale}") + self._canvas()._plt_draw() + class YAxis(AxisBase): def __init__(self, canvas: Canvas): @@ -326,3 +334,12 @@ def _plt_set_grid_state(self, visible: bool, color, width: float, style: LineSty linestyle=style.value, linewidth=width, ) + + def _plt_set_scale(self, scale: AxisScale): + if scale is AxisScale.LINEAR: + self._canvas()._axes.set_yscale("linear") + elif scale is AxisScale.LOG: + self._canvas()._axes.set_yscale("log") + else: + raise ValueError(f"Invalid scale: {scale}") + self._canvas()._plt_draw() diff --git a/whitecanvas/backend/mock/canvas.py b/whitecanvas/backend/mock/canvas.py index 6a43de38..88322042 100644 --- a/whitecanvas/backend/mock/canvas.py +++ b/whitecanvas/backend/mock/canvas.py @@ -216,6 +216,9 @@ def _plt_get_limits(self) -> tuple[float, float]: def _plt_set_limits(self, limits: tuple[float, float]): self._limits = limits + def _plt_set_scale(self, scale): + pass + class Ticks(_SupportsText): def __init__(self): diff --git a/whitecanvas/backend/plotly/_labels.py b/whitecanvas/backend/plotly/_labels.py index 1a7c7ce0..4077af1f 100644 --- a/whitecanvas/backend/plotly/_labels.py +++ b/whitecanvas/backend/plotly/_labels.py @@ -5,7 +5,7 @@ import numpy as np -from whitecanvas.types import LineStyle +from whitecanvas.types import AxisScale, LineStyle from whitecanvas.utils.normalize import rgba_str_color if TYPE_CHECKING: @@ -96,13 +96,27 @@ def _plt_get_axis(self): return getattr(self._canvas()._subplot_layout(), self._axis) def _plt_get_limits(self) -> tuple[float, float]: - lim = self._plt_get_axis().range + axis = self._plt_get_axis() + lim = axis.range + typ = axis.type if lim is None: - lim = (0, 1) # TODO: how to get the limits? - return lim + # default value + if typ == "linear": + return (0, 1) + else: + return (0.1, 1) + else: + if typ == "linear": + return lim + else: + return 10 ** lim[0], 10 ** lim[1] def _plt_set_limits(self, limits: tuple[float, float]): - self._plt_get_axis().range = limits + axis = self._plt_get_axis() + if axis.type == "linear": + axis.range = limits + else: + axis.range = np.log10(limits) def _plt_get_color(self): # color of the axis itself @@ -123,6 +137,14 @@ def _plt_set_grid_state(self, visible: bool, color, width: float, style: LineSty axis.gridcolor = rgba_str_color(color) axis.gridwidth = width + def _plt_set_scale(self, scale: AxisScale): + if scale is AxisScale.LINEAR: + self._plt_get_axis().type = "linear" + elif scale is AxisScale.LOG: + self._plt_get_axis().type = "log" + else: + raise ValueError(f"Invalid scale: {scale}") + class Ticks(_CanvasComponent): def __init__(self, canvas: Canvas, axis: str): diff --git a/whitecanvas/backend/pyqtgraph/_labels.py b/whitecanvas/backend/pyqtgraph/_labels.py index e1de924e..59aca97b 100644 --- a/whitecanvas/backend/pyqtgraph/_labels.py +++ b/whitecanvas/backend/pyqtgraph/_labels.py @@ -9,7 +9,7 @@ from qtpy.QtGui import QFont, QPen from whitecanvas.backend.pyqtgraph._qt_utils import array_to_qcolor -from whitecanvas.types import LineStyle +from whitecanvas.types import AxisScale, LineStyle if TYPE_CHECKING: import pyqtgraph as pg @@ -160,6 +160,14 @@ def _plt_set_grid_state(self, visible: bool, color, width: float, style: LineSty axis.setGrid(grid) # tick disappears by unknown reason. + def _plt_set_scale(self, scale: AxisScale): + if scale is AxisScale.LINEAR: + self._plt_get_axis().setLogMode(False) + elif scale is AxisScale.LOG: + self._plt_get_axis().setLogMode(True) + else: + raise ValueError(f"Unknown scale: {scale}") + class Ticks(_CanvasComponent): def __init__(self, canvas: Canvas, axis: str): diff --git a/whitecanvas/backend/vispy/_label.py b/whitecanvas/backend/vispy/_label.py index 18185dc7..3a1540be 100644 --- a/whitecanvas/backend/vispy/_label.py +++ b/whitecanvas/backend/vispy/_label.py @@ -7,7 +7,7 @@ from vispy import scene from vispy.visuals.axis import AxisVisual, Ticker -from whitecanvas.types import LineStyle +from whitecanvas.types import AxisScale, LineStyle if TYPE_CHECKING: from vispy.visuals import LineVisual, TextVisual @@ -108,6 +108,15 @@ def _plt_set_grid_state(self, visible: bool, color, width: float, style: LineSty # self._canvas()._gridlines.visible = False pass # TODO: implement this + def _plt_set_scale(self, scale: AxisScale): + if scale is AxisScale.LINEAR: + self.axis.scale_type = "linear" + elif scale is AxisScale.LOG: + # NOTE: not implemented on vispy side yet + self.axis.scale_type = "logarithmic" + else: + raise ValueError(f"Invalid scale: {scale}") + class Ticks: def __init__(self, axis: Axis): diff --git a/whitecanvas/canvas/_namespaces.py b/whitecanvas/canvas/_namespaces.py index 7d1403b7..ae98ca18 100644 --- a/whitecanvas/canvas/_namespaces.py +++ b/whitecanvas/canvas/_namespaces.py @@ -9,7 +9,7 @@ from whitecanvas import protocols from whitecanvas._exceptions import ReferenceDeletedError -from whitecanvas.types import ColorType, LineStyle +from whitecanvas.types import AxisScale, ColorType, LineStyle from whitecanvas.utils.normalize import arr_color if TYPE_CHECKING: @@ -269,6 +269,7 @@ def __init__(self, canvas: CanvasBase | None = None): super().__init__(canvas) self.events = AxisSignals() self._flipped = False + self._scale = AxisScale.LINEAR def _get_object(self) -> protocols.AxisProtocol: raise NotImplementedError @@ -318,17 +319,50 @@ def flipped(self, flipped: bool): def set_gridlines( self, + *, visible: bool = True, color: ColorType = "gray", width: float = 1.0, style: str | LineStyle = LineStyle.SOLID, ): + """ + Update the properties of grid lines for the axis. + + Parameters + ---------- + visible : bool, default True + Whether to show the grid lines. + color : color, default "gray" + The color of the grid lines. + width : float, default 1.0 + The width of the grid lines. + style : str or LineStyle, default "solid" + The style of the grid lines. + """ color = arr_color(color) style = LineStyle(style) if width < 0: raise ValueError("width must be non-negative.") self._get_object()._plt_set_grid_state(visible, color, width, style) + @property + def scale(self) -> AxisScale: + """Scale (linear or log) of the axis.""" + return self._scale + + @scale.setter + def scale(self, scale): + scale = AxisScale(scale) + if scale is AxisScale.LOG: + _min, _max = self.lim + if _min <= 0: + if _max <= 0: + self.lim = (0.01, 1) # default limits + else: + self.lim = (_max / 100, _max) + self._get_object()._plt_set_scale(scale) + self._scale = scale + class XAxisNamespace(AxisNamespace): label = XLabelNamespace() diff --git a/whitecanvas/protocols/canvas_protocol.py b/whitecanvas/protocols/canvas_protocol.py index 0efc5912..73eeb140 100644 --- a/whitecanvas/protocols/canvas_protocol.py +++ b/whitecanvas/protocols/canvas_protocol.py @@ -9,7 +9,7 @@ from numpy.typing import NDArray from whitecanvas.layers._legend import LegendItem - from whitecanvas.types import LegendLocation, LineStyle, MouseEvent + from whitecanvas.types import AxisScale, LegendLocation, LineStyle, MouseEvent @runtime_checkable @@ -181,6 +181,9 @@ def _plt_flip(self) -> None: def _plt_set_grid_state(self, visible: bool, color, width: float, style: LineStyle): """Set the grid line.""" + def _plt_set_scale(self, scale: AxisScale) -> None: + """Set scale of axis""" + @runtime_checkable class AxisGridProtocol(Protocol): diff --git a/whitecanvas/types/__init__.py b/whitecanvas/types/__init__.py index 26386525..bd65b05a 100644 --- a/whitecanvas/types/__init__.py +++ b/whitecanvas/types/__init__.py @@ -8,6 +8,7 @@ ) from whitecanvas.types._enums import ( Alignment, + AxisScale, Hatch, HistogramKind, HistogramShape, @@ -44,6 +45,7 @@ "MouseEvent", "MouseEventType", "Alignment", + "AxisScale", "XYData", "XYYData", "XYTextData", diff --git a/whitecanvas/types/_enums.py b/whitecanvas/types/_enums.py index 7f7ece81..be57b39c 100644 --- a/whitecanvas/types/_enums.py +++ b/whitecanvas/types/_enums.py @@ -239,3 +239,8 @@ class HistogramKind(_StrEnum): probability = "probability" frequency = "frequency" percent = "percent" + + +class AxisScale(_StrEnum): + LINEAR = "linear" + LOG = "log"