From 6d3a2a483c055987ae1ceed3016976fefbe8253f Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 28 Dec 2024 15:36:13 -0500 Subject: [PATCH 1/5] add interpolate.py --- pytensor/tensor/interpolate.py | 197 +++++++++++++++++++++++++++++++ tests/tensor/test_interpolate.py | 107 +++++++++++++++++ 2 files changed, 304 insertions(+) create mode 100644 pytensor/tensor/interpolate.py create mode 100644 tests/tensor/test_interpolate.py diff --git a/pytensor/tensor/interpolate.py b/pytensor/tensor/interpolate.py new file mode 100644 index 0000000000..f9623fe7aa --- /dev/null +++ b/pytensor/tensor/interpolate.py @@ -0,0 +1,197 @@ +from collections.abc import Callable +from difflib import get_close_matches +from typing import Literal, get_args + +from pytensor.compile.builders import OpFromGraph +from pytensor.tensor import TensorLike +from pytensor.tensor.basic import as_tensor_variable, switch +from pytensor.tensor.blockwise import Blockwise +from pytensor.tensor.extra_ops import searchsorted +from pytensor.tensor.math import clip, eq, le +from pytensor.tensor.sort import argsort +from pytensor.tensor.type import scalar + + +InterpolationMethod = Literal["linear", "nearest", "first", "last", "mean"] +valid_methods = get_args(InterpolationMethod) + + +def pad_or_return(x, idx, output, left_pad, right_pad, extrapolate): + if extrapolate: + return output + + n = x.shape[0] + + return switch(eq(idx, 0), left_pad, switch(eq(idx, n), right_pad, output)) + + +def _linear_interp1d(x, y, x_hat, idx, left_pad, right_pad, extrapolate=True): + clip_idx = clip(idx, 1, x.shape[0] - 1) + + slope = (x_hat - x[clip_idx - 1]) / (x[clip_idx] - x[clip_idx - 1]) + y_hat = y[clip_idx - 1] + slope * (y[clip_idx] - y[clip_idx - 1]) + + return pad_or_return(x, idx, y_hat, left_pad, right_pad, extrapolate) + + +def _nearest_neighbor_interp1d(x, y, x_hat, idx, left_pad, right_pad, extrapolate=True): + clip_idx = clip(idx, 1, x.shape[0] - 1) + + left_distance = x_hat - x[clip_idx - 1] + 
right_distance = x[clip_idx] - x_hat + y_hat = switch(le(left_distance, right_distance), y[clip_idx - 1], y[clip_idx]) + + return pad_or_return(x, idx, y_hat, left_pad, right_pad, extrapolate) + + +def _stepwise_first_interp1d(x, y, x_hat, idx, left_pad, right_pad, extrapolate=True): + clip_idx = clip(idx - 1, 0, x.shape[0] - 1) + y_hat = y[clip_idx] + + return pad_or_return(x, idx, y_hat, left_pad, right_pad, extrapolate) + + +def _stepwise_last_interp1d(x, y, x_hat, idx, left_pad, right_pad, extrapolate=True): + clip_idx = clip(idx, 0, x.shape[0] - 1) + y_hat = y[clip_idx] + + return pad_or_return(x, idx, y_hat, left_pad, right_pad, extrapolate) + + +def _stepwise_mean_interp1d(x, y, x_hat, idx, left_pad, right_pad, extrapolate=True): + clip_idx = clip(idx, 1, x.shape[0] - 1) + y_hat = (y[clip_idx - 1] + y[clip_idx]) / 2 + + return pad_or_return(x, idx, y_hat, left_pad, right_pad, extrapolate) + + +def interpolate1d( + x: TensorLike, + y: TensorLike, + method: InterpolationMethod = "linear", + left_pad: TensorLike | None = None, + right_pad: TensorLike | None = None, + extrapolate: bool = True, +) -> Callable[[TensorLike], TensorLike]: + """ + Create a function to interpolate one-dimensional data. + + Parameters + ---------- + x : TensorLike + Input data used to create an interpolation function. Data will be sorted to be monotonically increasing. + y: TensorLike + Output data used to create an interpolation function. Must have the same shape as `x`. + method : InterpolationMethod, optional + Method for interpolation. 
The following methods are available: + - 'linear': Linear interpolation + - 'nearest': Nearest neighbor interpolation + - 'first': Stepwise interpolation using the closest value to the left of the query point + - 'last': Stepwise interpolation using the closest value to the right of the query point + - 'mean': Stepwise interpolation using the mean of the two closest values to the query point + left_pad: TensorLike, optional + Value to return for inputs `x_hat < x[0]`. Default is `y[0]`. Ignored if ``extrapolate == True``; in this + case, values `x_hat < x[0]` will be extrapolated from the endpoints of `x` and `y`. + right_pad: TensorLike, optional + Value to return for inputs `x_hat > x[-1]`. Default is `y[-1]`. Ignored if ``extrapolate == True``; in this + case, values `x_hat > x[-1]` will be extrapolated from the endpoints of `x` and `y`. + extrapolate: bool + Whether to extend the requested interpolation function beyond the range of the input-output pairs specified in + `x` and `y`. If False, constant values will be returned for such inputs. + + Returns + ------- + interpolation_func: OpFromGraph + A function that can be used to interpolate new data. The function takes a single input `x_hat` and returns + the interpolated value `y_hat`. The input `x_hat` must be a 1d array. 
+ + """ + x = as_tensor_variable(x) + y = as_tensor_variable(y) + + sort_idx = argsort(x) + x = x[sort_idx] + y = y[sort_idx] + + if left_pad is None: + left_pad = y[0] + else: + left_pad = as_tensor_variable(left_pad) + if right_pad is None: + right_pad = y[-1] + else: + right_pad = as_tensor_variable(right_pad) + + x_hat = scalar("x_hat", dtype=x.dtype) + idx = searchsorted(x, x_hat) + + if x.ndim != 1 or y.ndim != 1: + raise ValueError("Inputs must be 1d") + + if method == "linear": + y_hat = _linear_interp1d( + x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate + ) + elif method == "nearest": + y_hat = _nearest_neighbor_interp1d( + x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate + ) + elif method == "first": + y_hat = _stepwise_first_interp1d( + x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate + ) + elif method == "mean": + y_hat = _stepwise_mean_interp1d( + x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate + ) + elif method == "last": + y_hat = _stepwise_last_interp1d( + x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate + ) + else: + raise NotImplementedError( + f"Unknown interpolation method: {method}. " + f"Did you mean {get_close_matches(method, valid_methods)}?" + ) + + return Blockwise( + OpFromGraph(inputs=[x_hat], outputs=[y_hat], inline=False), signature="()->()" + ) + + +def interp(x, xp, fp, left=None, right=None, period=None): + """ + One-dimensional linear interpolation. Similar to ``pytensor.interpolate.interpolate1d``, but with a signature that + matches ``np.interp`` + + Parameters + ---------- + x : TensorLike + The x-coordinates at which to evaluate the interpolated values. + + xp : TensorLike + The x-coordinates of the data points, must be increasing if argument `period` is not specified. Otherwise, + `xp` is internally sorted after normalizing the periodic boundaries with ``xp = xp % period``. + + fp : TensorLike + The y-coordinates of the data points, same length as `xp`. 
+ + left : float, optional + Value to return for `x < xp[0]`. Default is `fp[0]`. + + right : float, optional + Value to return for `x > xp[-1]`. Default is `fp[-1]`. + + period : None + Not supported. Included to ensure the signature of this function matches ``numpy.interp``. + + Returns + ------- + y : Variable + The interpolated values, same shape as `x`. + """ + + f = interpolate1d( + xp, fp, method="linear", left_pad=left, right_pad=right, extrapolate=False + ) + return f(x) diff --git a/tests/tensor/test_interpolate.py b/tests/tensor/test_interpolate.py new file mode 100644 index 0000000000..95ebae10e2 --- /dev/null +++ b/tests/tensor/test_interpolate.py @@ -0,0 +1,107 @@ +import numpy as np +import pytest +from numpy.testing import assert_allclose + +import pytensor +import pytensor.tensor as pt +from pytensor.tensor.interpolate import ( + InterpolationMethod, + interp, + interpolate1d, + valid_methods, +) + + +floatX = pytensor.config.floatX + + +def test_interp(): + xp = [1.0, 2.0, 3.0] + fp = [3.0, 2.0, 0.0] + + x = [0, 1, 1.5, 2.72, 3.14] + + out = interp(x, xp, fp).eval() + np_out = np.interp(x, xp, fp) + + assert_allclose(out, np_out) + + +def test_interp_padded(): + xp = [1.0, 2.0, 3.0] + fp = [3.0, 2.0, 0.0] + + assert interp(3.14, xp, fp, right=-99.0).eval() == -99.0 + assert_allclose( + interp([-1.0, -2.0, -3.0], xp, fp, left=1000.0).eval(), [1000.0, 1000.0, 1000.0] + ) + assert_allclose( + interp([-1.0, 10.0], xp, fp, left=-10, right=10).eval(), [-10, 10.0] + ) + + +@pytest.mark.parametrize("method", valid_methods, ids=str) +@pytest.mark.parametrize( + "left_pad, right_pad", [(None, None), (None, 100), (-100, None), (-100, 100)] +) +def test_interpolate_scalar_no_extrapolate( + method: InterpolationMethod, left_pad, right_pad +): + x = np.linspace(-2, 6, 10) + y = np.sin(x) + + f_op = interpolate1d( + x, y, method, extrapolate=False, left_pad=left_pad, right_pad=right_pad + ) + x_hat_pt = pt.dscalar("x_hat") + f = pytensor.function([x_hat_pt], 
f_op(x_hat_pt), mode="FAST_RUN") + + # Data points should be returned exactly, except when method == mean + if method not in ["mean", "first"]: + assert f(x[3]) == y[3] + elif method == "first": + assert f(x[3]) == y[2] + else: + # method == 'mean + assert f(x[3]) == (y[2] + y[3]) / 2 + + # When extrapolate=False, points beyond the data envelope should be constant + left_pad = y[0] if left_pad is None else left_pad + right_pad = y[-1] if right_pad is None else right_pad + + assert f(-10) == left_pad + assert f(100) == right_pad + + +@pytest.mark.parametrize("method", valid_methods, ids=str) +def test_interpolate_scalar_extrapolate(method: InterpolationMethod): + x = np.linspace(-2, 6, 10) + y = np.sin(x) + + f_op = interpolate1d(x, y, method) + x_hat_pt = pt.dscalar("x_hat") + f = pytensor.function([x_hat_pt], f_op(x_hat_pt), mode="FAST_RUN") + + left_test_point = -5 + right_test_point = 100 + if method == "linear": + # Linear will compute a slope from the endpoints and continue it + left_slope = (left_test_point - x[0]) / (x[1] - x[0]) + right_slope = (right_test_point - x[-2]) / (x[-1] - x[-2]) + assert f(left_test_point) == y[0] + left_slope * (y[1] - y[0]) + assert f(right_test_point) == y[-2] + right_slope * (y[-1] - y[-2]) + + elif method == "mean": + left_expected = (y[0] + y[1]) / 2 + right_expected = (y[-1] + y[-2]) / 2 + assert f(left_test_point) == left_expected + assert f(right_test_point) == right_expected + + else: + assert f(left_test_point) == y[0] + assert f(right_test_point) == y[-1] + + # For interior points, "first" and "last" should disagree. First should take the left side of the interval, + # and last should take the right. 
+ interior_point = x[3] + 0.1 + assert f(interior_point) == (y[4] if method == "last" else y[3]) From 0e03119cd334937a0a08feaf811d84d8b52e43bc Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 28 Dec 2024 16:50:13 -0500 Subject: [PATCH 2/5] Add jax dispatch for `searchsorted` --- pytensor/link/jax/dispatch/extra_ops.py | 11 +++++++++++ tests/link/jax/test_extra_ops.py | 8 ++++++++ tests/tensor/test_interpolate.py | 18 ++++++++++++++++++ 3 files changed, 37 insertions(+) diff --git a/pytensor/link/jax/dispatch/extra_ops.py b/pytensor/link/jax/dispatch/extra_ops.py index a9e36667ef..87e55f1007 100644 --- a/pytensor/link/jax/dispatch/extra_ops.py +++ b/pytensor/link/jax/dispatch/extra_ops.py @@ -10,6 +10,7 @@ FillDiagonalOffset, RavelMultiIndex, Repeat, + SearchsortedOp, Unique, UnravelIndex, ) @@ -130,3 +131,13 @@ def jax_funcify_FillDiagonalOffset(op, **kwargs): # return filldiagonaloffset raise NotImplementedError("flatiter not implemented in JAX") + + +@jax_funcify.register(SearchsortedOp) +def jax_funcify_SearchsortedOp(op, **kwargs): + side = op.side + + def searchsorted(a, v, side=side, sorter=None): + return jnp.searchsorted(a=a, v=v, side=side, sorter=sorter) + + return searchsorted diff --git a/tests/link/jax/test_extra_ops.py b/tests/link/jax/test_extra_ops.py index 1427413379..0c8fb92810 100644 --- a/tests/link/jax/test_extra_ops.py +++ b/tests/link/jax/test_extra_ops.py @@ -6,6 +6,7 @@ from pytensor.graph.fg import FunctionGraph from pytensor.graph.op import get_test_value from pytensor.tensor import extra_ops as pt_extra_ops +from pytensor.tensor.sort import argsort from pytensor.tensor.type import matrix, tensor from tests.link.jax.test_basic import compare_jax_and_py @@ -55,6 +56,13 @@ def test_extra_ops(): fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False ) + v = ptb.as_tensor_variable(6.0) + sorted_idx = argsort(a.ravel()) + + out = pt_extra_ops.searchsorted(a.ravel()[sorted_idx], v) + fgraph = FunctionGraph([a], 
[out]) + compare_jax_and_py(fgraph, [a_test]) + @pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes") def test_bartlett_dynamic_shape(): diff --git a/tests/tensor/test_interpolate.py b/tests/tensor/test_interpolate.py index 95ebae10e2..b98e0ce371 100644 --- a/tests/tensor/test_interpolate.py +++ b/tests/tensor/test_interpolate.py @@ -8,6 +8,7 @@ InterpolationMethod, interp, interpolate1d, + polynomial_interpolate1d, valid_methods, ) @@ -105,3 +106,20 @@ def test_interpolate_scalar_extrapolate(method: InterpolationMethod): # and last should take the right. interior_point = x[3] + 0.1 assert f(interior_point) == (y[4] if method == "last" else y[3]) + + +def test_polynomial_interpolate1d(): + x = np.linspace(-2, 6, 10) + y = np.sin(x) + + f_op = polynomial_interpolate1d(x, y) + x_hat_pt = pt.dvector("x_hat") + degree = pt.iscalar("degree") + + f = pytensor.function( + [x_hat_pt, degree], f_op(x_hat_pt, degree, True), mode="FAST_RUN" + ) + x_grid = np.linspace(-2, 6, 100) + y_hat = f(x_grid, 0) + + assert_allclose(y_hat, np.mean(y)) From b7a23d79c729058f2960cd2eb06c7f287891feae Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 28 Dec 2024 17:27:35 -0500 Subject: [PATCH 3/5] Refactor out `OpFromGraph` --- pytensor/tensor/interpolate.py | 79 +++++++++++++++++--------------- tests/tensor/test_interpolate.py | 18 -------- 2 files changed, 41 insertions(+), 56 deletions(-) diff --git a/pytensor/tensor/interpolate.py b/pytensor/tensor/interpolate.py index f9623fe7aa..b4d4173a69 100644 --- a/pytensor/tensor/interpolate.py +++ b/pytensor/tensor/interpolate.py @@ -2,14 +2,12 @@ from difflib import get_close_matches from typing import Literal, get_args -from pytensor.compile.builders import OpFromGraph from pytensor.tensor import TensorLike from pytensor.tensor.basic import as_tensor_variable, switch -from pytensor.tensor.blockwise import Blockwise from pytensor.tensor.extra_ops import searchsorted +from pytensor.tensor.functional import vectorize from 
pytensor.tensor.math import clip, eq, le from pytensor.tensor.sort import argsort -from pytensor.tensor.type import scalar InterpolationMethod = Literal["linear", "nearest", "first", "last", "mean"] @@ -122,41 +120,41 @@ def interpolate1d( else: right_pad = as_tensor_variable(right_pad) - x_hat = scalar("x_hat", dtype=x.dtype) - idx = searchsorted(x, x_hat) - - if x.ndim != 1 or y.ndim != 1: - raise ValueError("Inputs must be 1d") - - if method == "linear": - y_hat = _linear_interp1d( - x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate - ) - elif method == "nearest": - y_hat = _nearest_neighbor_interp1d( - x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate - ) - elif method == "first": - y_hat = _stepwise_first_interp1d( - x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate - ) - elif method == "mean": - y_hat = _stepwise_mean_interp1d( - x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate - ) - elif method == "last": - y_hat = _stepwise_last_interp1d( - x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate - ) - else: - raise NotImplementedError( - f"Unknown interpolation method: {method}. " - f"Did you mean {get_close_matches(method, valid_methods)}?" 
- ) - - return Blockwise( - OpFromGraph(inputs=[x_hat], outputs=[y_hat], inline=False), signature="()->()" - ) + def _scalar_interpolate1d(x_hat): + idx = searchsorted(x, x_hat) + + if x.ndim != 1 or y.ndim != 1: + raise ValueError("Inputs must be 1d") + + if method == "linear": + y_hat = _linear_interp1d( + x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate + ) + elif method == "nearest": + y_hat = _nearest_neighbor_interp1d( + x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate + ) + elif method == "first": + y_hat = _stepwise_first_interp1d( + x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate + ) + elif method == "mean": + y_hat = _stepwise_mean_interp1d( + x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate + ) + elif method == "last": + y_hat = _stepwise_last_interp1d( + x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate + ) + else: + raise NotImplementedError( + f"Unknown interpolation method: {method}. " + f"Did you mean {get_close_matches(method, valid_methods)}?" + ) + + return y_hat + + return vectorize(_scalar_interpolate1d, signature="()->()") def interp(x, xp, fp, left=None, right=None, period=None): @@ -191,7 +189,12 @@ def interp(x, xp, fp, left=None, right=None, period=None): The interpolated values, same shape as `x`. """ + xp = as_tensor_variable(xp) + fp = as_tensor_variable(fp) + x = as_tensor_variable(x) + f = interpolate1d( xp, fp, method="linear", left_pad=left, right_pad=right, extrapolate=False ) + return f(x) diff --git a/tests/tensor/test_interpolate.py b/tests/tensor/test_interpolate.py index b98e0ce371..95ebae10e2 100644 --- a/tests/tensor/test_interpolate.py +++ b/tests/tensor/test_interpolate.py @@ -8,7 +8,6 @@ InterpolationMethod, interp, interpolate1d, - polynomial_interpolate1d, valid_methods, ) @@ -106,20 +105,3 @@ def test_interpolate_scalar_extrapolate(method: InterpolationMethod): # and last should take the right. 
interior_point = x[3] + 0.1 assert f(interior_point) == (y[4] if method == "last" else y[3]) - - -def test_polynomial_interpolate1d(): - x = np.linspace(-2, 6, 10) - y = np.sin(x) - - f_op = polynomial_interpolate1d(x, y) - x_hat_pt = pt.dvector("x_hat") - degree = pt.iscalar("degree") - - f = pytensor.function( - [x_hat_pt, degree], f_op(x_hat_pt, degree, True), mode="FAST_RUN" - ) - x_grid = np.linspace(-2, 6, 100) - y_hat = f(x_grid, 0) - - assert_allclose(y_hat, np.mean(y)) From b484529298f52abc4034e1530d676dccea04ce1d Mon Sep 17 00:00:00 2001 From: jessegrabowski Date: Sat, 28 Dec 2024 18:19:26 -0500 Subject: [PATCH 4/5] Appease mypy --- pytensor/tensor/interpolate.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/pytensor/tensor/interpolate.py b/pytensor/tensor/interpolate.py index b4d4173a69..f598695784 100644 --- a/pytensor/tensor/interpolate.py +++ b/pytensor/tensor/interpolate.py @@ -2,7 +2,7 @@ from difflib import get_close_matches from typing import Literal, get_args -from pytensor.tensor import TensorLike +from pytensor import Variable from pytensor.tensor.basic import as_tensor_variable, switch from pytensor.tensor.extra_ops import searchsorted from pytensor.tensor.functional import vectorize @@ -64,13 +64,13 @@ def _stepwise_mean_interp1d(x, y, x_hat, idx, left_pad, right_pad, extrapolate=T def interpolate1d( - x: TensorLike, - y: TensorLike, + x: Variable, + y: Variable, method: InterpolationMethod = "linear", - left_pad: TensorLike | None = None, - right_pad: TensorLike | None = None, + left_pad: Variable | None = None, + right_pad: Variable | None = None, extrapolate: bool = True, -) -> Callable[[TensorLike], TensorLike]: +) -> Callable[[Variable], Variable]: """ Create a function to interpolate one-dimensional data. 
@@ -112,11 +112,11 @@ def interpolate1d( y = y[sort_idx] if left_pad is None: - left_pad = y[0] + left_pad = y[0] # type: ignore else: left_pad = as_tensor_variable(left_pad) if right_pad is None: - right_pad = y[-1] + right_pad = y[-1] # type: ignore else: right_pad = as_tensor_variable(right_pad) From d5daef1838758314868706608421f40d2f6d64ae Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Mon, 30 Dec 2024 21:06:23 +0800 Subject: [PATCH 5/5] Import user-facing functions in `tensor.__init__` --- pytensor/tensor/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytensor/tensor/__init__.py b/pytensor/tensor/__init__.py index 7385f02478..67b6ab071e 100644 --- a/pytensor/tensor/__init__.py +++ b/pytensor/tensor/__init__.py @@ -128,6 +128,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int: from pytensor.tensor.basic import * from pytensor.tensor.blas import batched_dot, batched_tensordot from pytensor.tensor.extra_ops import * +from pytensor.tensor.interpolate import interp, interpolate1d from pytensor.tensor.io import * from pytensor.tensor.math import * from pytensor.tensor.pad import pad