-
Notifications
You must be signed in to change notification settings - Fork 115
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Pytensor-native interpolation functions (#1141)
* add interpolate.py * Add jax dispatch for `searchsorted` * Import user-facing functions in `tensor.__init__`
- Loading branch information
1 parent
83c6b44
commit 4e85676
Showing
5 changed files
with
327 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,200 @@ | ||
from collections.abc import Callable | ||
from difflib import get_close_matches | ||
from typing import Literal, get_args | ||
|
||
from pytensor import Variable | ||
from pytensor.tensor.basic import as_tensor_variable, switch | ||
from pytensor.tensor.extra_ops import searchsorted | ||
from pytensor.tensor.functional import vectorize | ||
from pytensor.tensor.math import clip, eq, le | ||
from pytensor.tensor.sort import argsort | ||
|
||
|
||
InterpolationMethod = Literal["linear", "nearest", "first", "last", "mean"] | ||
valid_methods = get_args(InterpolationMethod) | ||
|
||
|
||
def pad_or_return(x, idx, output, left_pad, right_pad, extrapolate): | ||
if extrapolate: | ||
return output | ||
|
||
n = x.shape[0] | ||
|
||
return switch(eq(idx, 0), left_pad, switch(eq(idx, n), right_pad, output)) | ||
|
||
|
||
def _linear_interp1d(x, y, x_hat, idx, left_pad, right_pad, extrapolate=True): | ||
clip_idx = clip(idx, 1, x.shape[0] - 1) | ||
|
||
slope = (x_hat - x[clip_idx - 1]) / (x[clip_idx] - x[clip_idx - 1]) | ||
y_hat = y[clip_idx - 1] + slope * (y[clip_idx] - y[clip_idx - 1]) | ||
|
||
return pad_or_return(x, idx, y_hat, left_pad, right_pad, extrapolate) | ||
|
||
|
||
def _nearest_neighbor_interp1d(x, y, x_hat, idx, left_pad, right_pad, extrapolate=True): | ||
clip_idx = clip(idx, 1, x.shape[0] - 1) | ||
|
||
left_distance = x_hat - x[clip_idx - 1] | ||
right_distance = x[clip_idx] - x_hat | ||
y_hat = switch(le(left_distance, right_distance), y[clip_idx - 1], y[clip_idx]) | ||
|
||
return pad_or_return(x, idx, y_hat, left_pad, right_pad, extrapolate) | ||
|
||
|
||
def _stepwise_first_interp1d(x, y, x_hat, idx, left_pad, right_pad, extrapolate=True): | ||
clip_idx = clip(idx - 1, 0, x.shape[0] - 1) | ||
y_hat = y[clip_idx] | ||
|
||
return pad_or_return(x, idx, y_hat, left_pad, right_pad, extrapolate) | ||
|
||
|
||
def _stepwise_last_interp1d(x, y, x_hat, idx, left_pad, right_pad, extrapolate=True): | ||
clip_idx = clip(idx, 0, x.shape[0] - 1) | ||
y_hat = y[clip_idx] | ||
|
||
return pad_or_return(x, idx, y_hat, left_pad, right_pad, extrapolate) | ||
|
||
|
||
def _stepwise_mean_interp1d(x, y, x_hat, idx, left_pad, right_pad, extrapolate=True): | ||
clip_idx = clip(idx, 1, x.shape[0] - 1) | ||
y_hat = (y[clip_idx - 1] + y[clip_idx]) / 2 | ||
|
||
return pad_or_return(x, idx, y_hat, left_pad, right_pad, extrapolate) | ||
|
||
|
||
def interpolate1d( | ||
x: Variable, | ||
y: Variable, | ||
method: InterpolationMethod = "linear", | ||
left_pad: Variable | None = None, | ||
right_pad: Variable | None = None, | ||
extrapolate: bool = True, | ||
) -> Callable[[Variable], Variable]: | ||
""" | ||
Create a function to interpolate one-dimensional data. | ||
Parameters | ||
---------- | ||
x : TensorLike | ||
Input data used to create an interpolation function. Data will be sorted to be monotonically increasing. | ||
y: TensorLike | ||
Output data used to create an interpolation function. Must have the same shape as `x`. | ||
method : InterpolationMethod, optional | ||
Method for interpolation. The following methods are available: | ||
- 'linear': Linear interpolation | ||
- 'nearest': Nearest neighbor interpolation | ||
- 'first': Stepwise interpolation using the closest value to the left of the query point | ||
- 'last': Stepwise interpolation using the closest value to the right of the query point | ||
- 'mean': Stepwise interpolation using the mean of the two closest values to the query point | ||
left_pad: TensorLike, optional | ||
Value to return inputs `x_hat < x[0]`. Default is `y[0]`. Ignored if ``extrapolate == True``; in this | ||
case, values `x_hat < x[0]` will be extrapolated from the endpoints of `x` and `y`. | ||
right_pad: TensorLike, optional | ||
Value to return for inputs `x_hat > x[-1]`. Default is `y[-1]`. Ignored if ``extrapolate == True``; in this | ||
case, values `x_hat > x[-1]` will be extrapolated from the endpoints of `x` and `y`. | ||
extrapolate: bool | ||
Whether to extend the request interpolation function beyond the range of the input-output pairs specified in | ||
`x` and `y.` If False, constant values will be returned for such inputs. | ||
Returns | ||
------- | ||
interpolation_func: OpFromGraph | ||
A function that can be used to interpolate new data. The function takes a single input `x_hat` and returns | ||
the interpolated value `y_hat`. The input `x_hat` must be a 1d array. | ||
""" | ||
x = as_tensor_variable(x) | ||
y = as_tensor_variable(y) | ||
|
||
sort_idx = argsort(x) | ||
x = x[sort_idx] | ||
y = y[sort_idx] | ||
|
||
if left_pad is None: | ||
left_pad = y[0] # type: ignore | ||
else: | ||
left_pad = as_tensor_variable(left_pad) | ||
if right_pad is None: | ||
right_pad = y[-1] # type: ignore | ||
else: | ||
right_pad = as_tensor_variable(right_pad) | ||
|
||
def _scalar_interpolate1d(x_hat): | ||
idx = searchsorted(x, x_hat) | ||
|
||
if x.ndim != 1 or y.ndim != 1: | ||
raise ValueError("Inputs must be 1d") | ||
|
||
if method == "linear": | ||
y_hat = _linear_interp1d( | ||
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate | ||
) | ||
elif method == "nearest": | ||
y_hat = _nearest_neighbor_interp1d( | ||
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate | ||
) | ||
elif method == "first": | ||
y_hat = _stepwise_first_interp1d( | ||
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate | ||
) | ||
elif method == "mean": | ||
y_hat = _stepwise_mean_interp1d( | ||
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate | ||
) | ||
elif method == "last": | ||
y_hat = _stepwise_last_interp1d( | ||
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate | ||
) | ||
else: | ||
raise NotImplementedError( | ||
f"Unknown interpolation method: {method}. " | ||
f"Did you mean {get_close_matches(method, valid_methods)}?" | ||
) | ||
|
||
return y_hat | ||
|
||
return vectorize(_scalar_interpolate1d, signature="()->()") | ||
|
||
|
||
def interp(x, xp, fp, left=None, right=None, period=None): | ||
""" | ||
One-dimensional linear interpolation. Similar to ``pytensor.interpolate.interpolate1d``, but with a signature that | ||
matches ``np.interp`` | ||
Parameters | ||
---------- | ||
x : TensorLike | ||
The x-coordinates at which to evaluate the interpolated values. | ||
xp : TensorLike | ||
The x-coordinates of the data points, must be increasing if argument `period` is not specified. Otherwise, | ||
`xp` is internally sorted after normalizing the periodic boundaries with ``xp = xp % period``. | ||
fp : TensorLike | ||
The y-coordinates of the data points, same length as `xp`. | ||
left : float, optional | ||
Value to return for `x < xp[0]`. Default is `fp[0]`. | ||
right : float, optional | ||
Value to return for `x > xp[-1]`. Default is `fp[-1]`. | ||
period : None | ||
Not supported. Included to ensure the signature of this function matches ``numpy.interp``. | ||
Returns | ||
------- | ||
y : Variable | ||
The interpolated values, same shape as `x`. | ||
""" | ||
|
||
xp = as_tensor_variable(xp) | ||
fp = as_tensor_variable(fp) | ||
x = as_tensor_variable(x) | ||
|
||
f = interpolate1d( | ||
xp, fp, method="linear", left_pad=left, right_pad=right, extrapolate=False | ||
) | ||
|
||
return f(x) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,107 @@ | ||
import numpy as np | ||
import pytest | ||
from numpy.testing import assert_allclose | ||
|
||
import pytensor | ||
import pytensor.tensor as pt | ||
from pytensor.tensor.interpolate import ( | ||
InterpolationMethod, | ||
interp, | ||
interpolate1d, | ||
valid_methods, | ||
) | ||
|
||
|
||
floatX = pytensor.config.floatX | ||
|
||
|
||
def test_interp(): | ||
xp = [1.0, 2.0, 3.0] | ||
fp = [3.0, 2.0, 0.0] | ||
|
||
x = [0, 1, 1.5, 2.72, 3.14] | ||
|
||
out = interp(x, xp, fp).eval() | ||
np_out = np.interp(x, xp, fp) | ||
|
||
assert_allclose(out, np_out) | ||
|
||
|
||
def test_interp_padded(): | ||
xp = [1.0, 2.0, 3.0] | ||
fp = [3.0, 2.0, 0.0] | ||
|
||
assert interp(3.14, xp, fp, right=-99.0).eval() == -99.0 | ||
assert_allclose( | ||
interp([-1.0, -2.0, -3.0], xp, fp, left=1000.0).eval(), [1000.0, 1000.0, 1000.0] | ||
) | ||
assert_allclose( | ||
interp([-1.0, 10.0], xp, fp, left=-10, right=10).eval(), [-10, 10.0] | ||
) | ||
|
||
|
||
@pytest.mark.parametrize("method", valid_methods, ids=str) | ||
@pytest.mark.parametrize( | ||
"left_pad, right_pad", [(None, None), (None, 100), (-100, None), (-100, 100)] | ||
) | ||
def test_interpolate_scalar_no_extrapolate( | ||
method: InterpolationMethod, left_pad, right_pad | ||
): | ||
x = np.linspace(-2, 6, 10) | ||
y = np.sin(x) | ||
|
||
f_op = interpolate1d( | ||
x, y, method, extrapolate=False, left_pad=left_pad, right_pad=right_pad | ||
) | ||
x_hat_pt = pt.dscalar("x_hat") | ||
f = pytensor.function([x_hat_pt], f_op(x_hat_pt), mode="FAST_RUN") | ||
|
||
# Data points should be returned exactly, except when method == mean | ||
if method not in ["mean", "first"]: | ||
assert f(x[3]) == y[3] | ||
elif method == "first": | ||
assert f(x[3]) == y[2] | ||
else: | ||
# method == 'mean | ||
assert f(x[3]) == (y[2] + y[3]) / 2 | ||
|
||
# When extrapolate=False, points beyond the data envelope should be constant | ||
left_pad = y[0] if left_pad is None else left_pad | ||
right_pad = y[-1] if right_pad is None else right_pad | ||
|
||
assert f(-10) == left_pad | ||
assert f(100) == right_pad | ||
|
||
|
||
@pytest.mark.parametrize("method", valid_methods, ids=str) | ||
def test_interpolate_scalar_extrapolate(method: InterpolationMethod): | ||
x = np.linspace(-2, 6, 10) | ||
y = np.sin(x) | ||
|
||
f_op = interpolate1d(x, y, method) | ||
x_hat_pt = pt.dscalar("x_hat") | ||
f = pytensor.function([x_hat_pt], f_op(x_hat_pt), mode="FAST_RUN") | ||
|
||
left_test_point = -5 | ||
right_test_point = 100 | ||
if method == "linear": | ||
# Linear will compute a slope from the endpoints and continue it | ||
left_slope = (left_test_point - x[0]) / (x[1] - x[0]) | ||
right_slope = (right_test_point - x[-2]) / (x[-1] - x[-2]) | ||
assert f(left_test_point) == y[0] + left_slope * (y[1] - y[0]) | ||
assert f(right_test_point) == y[-2] + right_slope * (y[-1] - y[-2]) | ||
|
||
elif method == "mean": | ||
left_expected = (y[0] + y[1]) / 2 | ||
right_expected = (y[-1] + y[-2]) / 2 | ||
assert f(left_test_point) == left_expected | ||
assert f(right_test_point) == right_expected | ||
|
||
else: | ||
assert f(left_test_point) == y[0] | ||
assert f(right_test_point) == y[-1] | ||
|
||
# For interior points, "first" and "last" should disagree. First should take the left side of the interval, | ||
# and last should take the right. | ||
interior_point = x[3] + 0.1 | ||
assert f(interior_point) == (y[4] if method == "last" else y[3]) |