Skip to content

Commit

Permalink
typo in jax dispatch
Browse files Browse the repository at this point in the history
  • Loading branch information
jessegrabowski committed Dec 28, 2024
1 parent 4af40c4 commit 9313e6d
Showing 1 changed file with 18 additions and 0 deletions.
18 changes: 18 additions & 0 deletions tests/tensor/test_interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
InterpolationMethod,
interp,
interpolate1d,
polynomial_interpolate1d,
valid_methods,
)

Expand Down Expand Up @@ -105,3 +106,20 @@ def test_interpolate_scalar_extrapolate(method: InterpolationMethod):
# and last should take the right.
interior_point = x[3] + 0.1
assert f(interior_point) == (y[4] if method == "last" else y[3])


def test_polynomial_interpolate1d():
x = np.linspace(-2, 6, 10)
y = np.sin(x)

f_op = polynomial_interpolate1d(x, y)
x_hat_pt = pt.dvector("x_hat")
degree = pt.iscalar("degree")

f = pytensor.function(
[x_hat_pt, degree], f_op(x_hat_pt, degree, True), mode="FAST_RUN"
)
x_grid = np.linspace(-2, 6, 100)
y_hat = f(x_grid, 0)

assert_allclose(y_hat, np.mean(y))

0 comments on commit 9313e6d

Please sign in to comment.