Skip to content

Commit

Permalink
add pspline penalty
Browse files Browse the repository at this point in the history
  • Loading branch information
jobrachem committed Aug 26, 2024
1 parent 5aae977 commit 8829c8c
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
22 changes: 22 additions & 0 deletions liesel/splines.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,3 +189,25 @@ def basis_matrix(
design_matrix = vmap(lambda x: _build_basis_vector(x, knots, order))(x)

return design_matrix


def pspline_penalty(d: int, diff: int = 2):
"""
Builds an (n_param x n_param) P-spline penalty matrix.
Parameters
----------
d
Integer, dimension of the matrix. Corresponds to the number of parameters \
in a P-spline.
diff
Order of the differences used in constructing the penalty matrix. The default \
of ``diff=2`` corresponds to the common P-spline default of penalizing second \
differences.
Returns
-------
A 2d array, the penalty matrix.
"""
D = jnp.diff(jnp.identity(d), diff, axis=0)
return D.T @ D
16 changes: 15 additions & 1 deletion tests/test_splines.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest
from scipy.interpolate import BSpline

from liesel.splines import basis_matrix, equidistant_knots
from liesel.splines import basis_matrix, equidistant_knots, pspline_penalty


def test_knots_creation():
Expand Down Expand Up @@ -109,3 +109,17 @@ def test_vectorize(self):
B = basis_matrix_vec(x, knots, order, True)

assert B.shape == (2, n, n_params)


class TestPsplinePenalty:
def test_shape(self):
d = 10

P = pspline_penalty(d)
assert P.shape == (d, d)

def test_rank(self):
d = 10
for diff in [1, 2, 3]:
P = pspline_penalty(d, diff=diff)
assert jnp.linalg.matrix_rank(P) == d - diff

0 comments on commit 8829c8c

Please sign in to comment.