Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pytensor-native interpolation functions #1141

Merged
merged 5 commits into from
Dec 30, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions pytensor/link/jax/dispatch/extra_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
FillDiagonalOffset,
RavelMultiIndex,
Repeat,
SearchsortedOp,
Unique,
UnravelIndex,
)
Expand Down Expand Up @@ -130,3 +131,13 @@ def jax_funcify_FillDiagonalOffset(op, **kwargs):
# return filldiagonaloffset

raise NotImplementedError("flatiter not implemented in JAX")


@jax_funcify.register(SearchsortedOp)
def jax_funcify_SearchsortedOp(op, **kwargs):
side = op.side

def searchsorted(a, v, side=side, sorter=None):
return jnp.searchsorted(a=a, v=v, side=side, sorter=sorter)

return searchsorted
200 changes: 200 additions & 0 deletions pytensor/tensor/interpolate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
from collections.abc import Callable
from difflib import get_close_matches
from typing import Literal, get_args

from pytensor.tensor import TensorLike
from pytensor.tensor.basic import as_tensor_variable, switch
from pytensor.tensor.extra_ops import searchsorted
from pytensor.tensor.functional import vectorize
from pytensor.tensor.math import clip, eq, le
from pytensor.tensor.sort import argsort


InterpolationMethod = Literal["linear", "nearest", "first", "last", "mean"]
valid_methods = get_args(InterpolationMethod)


def pad_or_return(x, idx, output, left_pad, right_pad, extrapolate):
if extrapolate:
return output

n = x.shape[0]

return switch(eq(idx, 0), left_pad, switch(eq(idx, n), right_pad, output))


def _linear_interp1d(x, y, x_hat, idx, left_pad, right_pad, extrapolate=True):
clip_idx = clip(idx, 1, x.shape[0] - 1)

slope = (x_hat - x[clip_idx - 1]) / (x[clip_idx] - x[clip_idx - 1])
y_hat = y[clip_idx - 1] + slope * (y[clip_idx] - y[clip_idx - 1])

return pad_or_return(x, idx, y_hat, left_pad, right_pad, extrapolate)


def _nearest_neighbor_interp1d(x, y, x_hat, idx, left_pad, right_pad, extrapolate=True):
clip_idx = clip(idx, 1, x.shape[0] - 1)

left_distance = x_hat - x[clip_idx - 1]
right_distance = x[clip_idx] - x_hat
y_hat = switch(le(left_distance, right_distance), y[clip_idx - 1], y[clip_idx])

return pad_or_return(x, idx, y_hat, left_pad, right_pad, extrapolate)


def _stepwise_first_interp1d(x, y, x_hat, idx, left_pad, right_pad, extrapolate=True):
clip_idx = clip(idx - 1, 0, x.shape[0] - 1)
y_hat = y[clip_idx]

return pad_or_return(x, idx, y_hat, left_pad, right_pad, extrapolate)


def _stepwise_last_interp1d(x, y, x_hat, idx, left_pad, right_pad, extrapolate=True):
clip_idx = clip(idx, 0, x.shape[0] - 1)
y_hat = y[clip_idx]

return pad_or_return(x, idx, y_hat, left_pad, right_pad, extrapolate)


def _stepwise_mean_interp1d(x, y, x_hat, idx, left_pad, right_pad, extrapolate=True):
clip_idx = clip(idx, 1, x.shape[0] - 1)
y_hat = (y[clip_idx - 1] + y[clip_idx]) / 2

return pad_or_return(x, idx, y_hat, left_pad, right_pad, extrapolate)


def interpolate1d(
x: TensorLike,
y: TensorLike,
method: InterpolationMethod = "linear",
left_pad: TensorLike | None = None,
right_pad: TensorLike | None = None,
extrapolate: bool = True,
) -> Callable[[TensorLike], TensorLike]:
"""
Create a function to interpolate one-dimensional data.

Parameters
----------
x : TensorLike
Input data used to create an interpolation function. Data will be sorted to be monotonically increasing.
y: TensorLike
Output data used to create an interpolation function. Must have the same shape as `x`.
method : InterpolationMethod, optional
Method for interpolation. The following methods are available:
- 'linear': Linear interpolation
- 'nearest': Nearest neighbor interpolation
- 'first': Stepwise interpolation using the closest value to the left of the query point
- 'last': Stepwise interpolation using the closest value to the right of the query point
- 'mean': Stepwise interpolation using the mean of the two closest values to the query point
left_pad: TensorLike, optional
Value to return inputs `x_hat < x[0]`. Default is `y[0]`. Ignored if ``extrapolate == True``; in this
case, values `x_hat < x[0]` will be extrapolated from the endpoints of `x` and `y`.
right_pad: TensorLike, optional
Value to return for inputs `x_hat > x[-1]`. Default is `y[-1]`. Ignored if ``extrapolate == True``; in this
case, values `x_hat > x[-1]` will be extrapolated from the endpoints of `x` and `y`.
extrapolate: bool
Whether to extend the request interpolation function beyond the range of the input-output pairs specified in
`x` and `y.` If False, constant values will be returned for such inputs.

Returns
-------
interpolation_func: OpFromGraph
A function that can be used to interpolate new data. The function takes a single input `x_hat` and returns
the interpolated value `y_hat`. The input `x_hat` must be a 1d array.

"""
x = as_tensor_variable(x)
y = as_tensor_variable(y)

sort_idx = argsort(x)
x = x[sort_idx]
y = y[sort_idx]

if left_pad is None:
left_pad = y[0]
else:
left_pad = as_tensor_variable(left_pad)
if right_pad is None:
right_pad = y[-1]
else:
right_pad = as_tensor_variable(right_pad)

def _scalar_interpolate1d(x_hat):
idx = searchsorted(x, x_hat)

if x.ndim != 1 or y.ndim != 1:
raise ValueError("Inputs must be 1d")

Check warning on line 127 in pytensor/tensor/interpolate.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/interpolate.py#L127

Added line #L127 was not covered by tests

if method == "linear":
y_hat = _linear_interp1d(
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
)
elif method == "nearest":
y_hat = _nearest_neighbor_interp1d(
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
)
elif method == "first":
y_hat = _stepwise_first_interp1d(
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
)
elif method == "mean":
y_hat = _stepwise_mean_interp1d(
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
)
elif method == "last":
y_hat = _stepwise_last_interp1d(
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
)
else:
raise NotImplementedError(

Check warning on line 150 in pytensor/tensor/interpolate.py

View check run for this annotation

Codecov / codecov/patch

pytensor/tensor/interpolate.py#L150

Added line #L150 was not covered by tests
f"Unknown interpolation method: {method}. "
f"Did you mean {get_close_matches(method, valid_methods)}?"
)

return y_hat

return vectorize(_scalar_interpolate1d, signature="()->()")


def interp(x, xp, fp, left=None, right=None, period=None):
"""
One-dimensional linear interpolation. Similar to ``pytensor.interpolate.interpolate1d``, but with a signature that
matches ``np.interp``

Parameters
----------
x : TensorLike
The x-coordinates at which to evaluate the interpolated values.

xp : TensorLike
The x-coordinates of the data points, must be increasing if argument `period` is not specified. Otherwise,
`xp` is internally sorted after normalizing the periodic boundaries with ``xp = xp % period``.

fp : TensorLike
The y-coordinates of the data points, same length as `xp`.

left : float, optional
Value to return for `x < xp[0]`. Default is `fp[0]`.

right : float, optional
Value to return for `x > xp[-1]`. Default is `fp[-1]`.

period : None
Not supported. Included to ensure the signature of this function matches ``numpy.interp``.

Returns
-------
y : Variable
The interpolated values, same shape as `x`.
"""

xp = as_tensor_variable(xp)
fp = as_tensor_variable(fp)
x = as_tensor_variable(x)

f = interpolate1d(
xp, fp, method="linear", left_pad=left, right_pad=right, extrapolate=False
)

return f(x)
8 changes: 8 additions & 0 deletions tests/link/jax/test_extra_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.tensor import extra_ops as pt_extra_ops
from pytensor.tensor.sort import argsort
from pytensor.tensor.type import matrix, tensor
from tests.link.jax.test_basic import compare_jax_and_py

Expand Down Expand Up @@ -55,6 +56,13 @@ def test_extra_ops():
fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
)

v = ptb.as_tensor_variable(6.0)
sorted_idx = argsort(a.ravel())

out = pt_extra_ops.searchsorted(a.ravel()[sorted_idx], v)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [a_test])


@pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes")
def test_bartlett_dynamic_shape():
Expand Down
107 changes: 107 additions & 0 deletions tests/tensor/test_interpolate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import numpy as np
import pytest
from numpy.testing import assert_allclose

import pytensor
import pytensor.tensor as pt
from pytensor.tensor.interpolate import (
InterpolationMethod,
interp,
interpolate1d,
valid_methods,
)


floatX = pytensor.config.floatX


def test_interp():
xp = [1.0, 2.0, 3.0]
fp = [3.0, 2.0, 0.0]

x = [0, 1, 1.5, 2.72, 3.14]

out = interp(x, xp, fp).eval()
np_out = np.interp(x, xp, fp)

assert_allclose(out, np_out)


def test_interp_padded():
xp = [1.0, 2.0, 3.0]
fp = [3.0, 2.0, 0.0]

assert interp(3.14, xp, fp, right=-99.0).eval() == -99.0
assert_allclose(
interp([-1.0, -2.0, -3.0], xp, fp, left=1000.0).eval(), [1000.0, 1000.0, 1000.0]
)
assert_allclose(
interp([-1.0, 10.0], xp, fp, left=-10, right=10).eval(), [-10, 10.0]
)


@pytest.mark.parametrize("method", valid_methods, ids=str)
@pytest.mark.parametrize(
"left_pad, right_pad", [(None, None), (None, 100), (-100, None), (-100, 100)]
)
def test_interpolate_scalar_no_extrapolate(
method: InterpolationMethod, left_pad, right_pad
):
x = np.linspace(-2, 6, 10)
y = np.sin(x)

f_op = interpolate1d(
x, y, method, extrapolate=False, left_pad=left_pad, right_pad=right_pad
)
x_hat_pt = pt.dscalar("x_hat")
f = pytensor.function([x_hat_pt], f_op(x_hat_pt), mode="FAST_RUN")

# Data points should be returned exactly, except when method == mean
if method not in ["mean", "first"]:
assert f(x[3]) == y[3]
elif method == "first":
assert f(x[3]) == y[2]
else:
# method == 'mean
assert f(x[3]) == (y[2] + y[3]) / 2

# When extrapolate=False, points beyond the data envelope should be constant
left_pad = y[0] if left_pad is None else left_pad
right_pad = y[-1] if right_pad is None else right_pad

assert f(-10) == left_pad
assert f(100) == right_pad


@pytest.mark.parametrize("method", valid_methods, ids=str)
def test_interpolate_scalar_extrapolate(method: InterpolationMethod):
x = np.linspace(-2, 6, 10)
y = np.sin(x)

f_op = interpolate1d(x, y, method)
x_hat_pt = pt.dscalar("x_hat")
f = pytensor.function([x_hat_pt], f_op(x_hat_pt), mode="FAST_RUN")

left_test_point = -5
right_test_point = 100
if method == "linear":
# Linear will compute a slope from the endpoints and continue it
left_slope = (left_test_point - x[0]) / (x[1] - x[0])
right_slope = (right_test_point - x[-2]) / (x[-1] - x[-2])
assert f(left_test_point) == y[0] + left_slope * (y[1] - y[0])
assert f(right_test_point) == y[-2] + right_slope * (y[-1] - y[-2])

elif method == "mean":
left_expected = (y[0] + y[1]) / 2
right_expected = (y[-1] + y[-2]) / 2
assert f(left_test_point) == left_expected
assert f(right_test_point) == right_expected

else:
assert f(left_test_point) == y[0]
assert f(right_test_point) == y[-1]

# For interior points, "first" and "last" should disagree. First should take the left side of the interval,
# and last should take the right.
interior_point = x[3] + 0.1
assert f(interior_point) == (y[4] if method == "last" else y[3])
Loading