Skip to content

Commit

Permalink
Refactor LinearOperator for method overriding (#735)
Browse files Browse the repository at this point in the history
  • Loading branch information
timweiland authored Nov 8, 2022
1 parent 3999850 commit ce4bbba
Show file tree
Hide file tree
Showing 10 changed files with 538 additions and 316 deletions.
32 changes: 20 additions & 12 deletions docs/source/tutorials/linops/linear_operators_quickstart.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@
{
"data": {
"text/plain": [
"<SumLinearOperator with shape=(5, 5) and dtype=float64>"
"<Matrix with shape=(5, 5) and dtype=float64>"
]
},
"execution_count": 9,
Expand All @@ -269,7 +269,10 @@
{
"data": {
"text/plain": [
"<ProductLinearOperator with shape=(5, 5) and dtype=float64>"
"ProductLinearOperator [\n",
"\t<Matrix with shape=(5, 5) and dtype=float64>, \n",
"\t<Matrix with shape=(5, 5) and dtype=float64>, \n",
"]"
]
},
"execution_count": 10,
Expand Down Expand Up @@ -311,10 +314,10 @@
{
"data": {
"text/plain": [
"array([-1.39282086, -2.09807924, -1.01469708, -0.74204673, -3.26963901,\n",
" -0.92439367, -0.65638407, 0.43823505, 0.66964627, -0.316306 ,\n",
" 5.7153326 , 0.43495681, 0.46390134, -2.66045433, 0.62615866,\n",
" 0.00715237, -0.83637837, -0.95389845, -0.41350942, -1.23499484])"
"array([ 1.49421769, -1.35451937, 1.05551543, -0.41823967, 0.42934955,\n",
" -0.82155968, -1.93141743, -4.31860989, -1.70475714, 4.36385187,\n",
" 2.36850628, -2.94034717, 0.39821307, -1.08656905, 0.36490375,\n",
" -0.86441656, -0.44778464, -0.44155178, 0.55687361, 0.17178464])"
]
},
"execution_count": 11,
Expand Down Expand Up @@ -361,7 +364,7 @@
{
"data": {
"text/plain": [
"<LinearOperator with shape=(5, 5) and dtype=float64>"
"<LambdaLinearOperator with shape=(5, 5) and dtype=float64>"
]
},
"execution_count": 12,
Expand All @@ -370,14 +373,14 @@
}
],
"source": [
"from probnum.linops import LinearOperator\n",
"from probnum.linops import LinearOperator, LambdaLinearOperator\n",
"\n",
"@LinearOperator.broadcast_matvec\n",
"def mv(v):\n",
" return np.roll(v, 1)\n",
"\n",
"n = 5\n",
"P_op = LinearOperator(shape=(n, n), dtype=np.float_, matmul=mv)\n",
"P_op = LambdaLinearOperator(shape=(n, n), dtype=np.float_, matmul=mv)\n",
"x = np.arange(0., n, 1)\n",
"\n",
"P_op"
Expand Down Expand Up @@ -509,7 +512,7 @@
"def mv(v):\n",
" return v[:n-1]\n",
"\n",
"Pr = LinearOperator(shape=(n-1, n), dtype=np.float_, matmul=mv)\n",
"Pr = LambdaLinearOperator(shape=(n-1, n), dtype=np.float_, matmul=mv)\n",
"\n",
"# Apply the operator to the 3D normal random variable\n",
"rv_projected = Pr @ rv"
Expand Down Expand Up @@ -602,7 +605,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": "Python 3.10.6 (conda)",
"language": "python",
"name": "python3"
},
Expand All @@ -616,7 +619,12 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.5"
"version": "3.10.6"
},
"vscode": {
"interpreter": {
"hash": "0457b12441837086dec1b475e0008c28e5fc37f4ffe0e5ee9f2b481cc28bc3c9"
}
}
},
"nbformat": 4,
Expand Down
18 changes: 9 additions & 9 deletions src/probnum/linalg/solvers/matrixbased.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,14 +109,14 @@ def _matmul(M):

Ainv0_mean = linops.Scaling(
alpha, shape=(self.n, self.n)
) + 2 / bx0 * linops.LinearOperator(
) + 2 / bx0 * linops.LambdaLinearOperator(
shape=(self.n, self.n),
dtype=np.result_type(x0.dtype, alpha.dtype, b.dtype),
matmul=_matmul,
)
A0_mean = linops.Scaling(1 / alpha, shape=(self.n, self.n)) - 1 / (
alpha * np.squeeze((x0 - alpha * b).T @ x0)
) * linops.LinearOperator(
) * linops.LambdaLinearOperator(
shape=(self.n, self.n),
dtype=np.result_type(x0.dtype, alpha.dtype, b.dtype),
matmul=_matmul,
Expand Down Expand Up @@ -632,7 +632,7 @@ def null_space_proj(x):

# Compute calibration term in the A view as a linear operator with scaling from
# degrees of freedom
calibration_term_A = linops.LinearOperator(
calibration_term_A = linops.LambdaLinearOperator(
shape=(self.n, self.n),
dtype=S.dtype,
matmul=linops.LinearOperator.broadcast_matvec(
Expand All @@ -642,7 +642,7 @@ def null_space_proj(x):

# Compute calibration term in the Ainv view as a linear operator with scaling
# from degrees of freedom
calibration_term_Ainv = linops.LinearOperator(
calibration_term_Ainv = linops.LambdaLinearOperator(
shape=(self.n, self.n),
dtype=S.dtype,
matmul=linops.LinearOperator.broadcast_matvec(
Expand All @@ -669,7 +669,7 @@ def _matmul(x):
# First term of calibration covariance class: AS(S'AS)^{-1}S'A
return (Y * sy**-1) @ (Y.T @ x.ravel())

_A_covfactor0 = linops.LinearOperator(
_A_covfactor0 = linops.LambdaLinearOperator(
shape=(self.n, self.n),
dtype=np.result_type(Y, sy),
matmul=_matmul,
Expand All @@ -686,7 +686,7 @@ def _matmul(x):
)
return self.Ainv_mean0 @ (Y @ YAinv0Y_inv_YAinv0x)

_Ainv_covfactor0 = linops.LinearOperator(
_Ainv_covfactor0 = linops.LambdaLinearOperator(
shape=(self.n, self.n),
dtype=np.result_type(Y, self.Ainv_mean0),
matmul=_matmul,
Expand Down Expand Up @@ -733,7 +733,7 @@ def _matmul(x):
def _matmul(x):
return 0.5 * (bWb * _Ainv_covfactor @ x + Wb @ (Wb.T @ x))

cov_op = linops.LinearOperator(
cov_op = linops.LambdaLinearOperator(
shape=(self.n, self.n),
dtype=np.result_type(Wb.dtype, bWb.dtype),
matmul=_matmul,
Expand All @@ -755,7 +755,7 @@ def _mean_update(self, u, v):
def _matmul(x):
return u @ (v.T @ x) + v @ (u.T @ x)

return linops.LinearOperator(
return linops.LambdaLinearOperator(
shape=(self.n, self.n),
dtype=np.result_type(u.dtype, v.dtype),
matmul=_matmul,
Expand All @@ -768,7 +768,7 @@ def _covariance_update(self, u, Ws):
def _matmul(x):
return Ws @ (u.T @ x)

return linops.LinearOperator(
return linops.LambdaLinearOperator(
shape=(self.n, self.n),
dtype=np.result_type(u.dtype, Ws.dtype),
matmul=_matmul,
Expand Down
11 changes: 10 additions & 1 deletion src/probnum/linops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,14 @@
"""

from ._kronecker import IdentityKronecker, Kronecker, SymmetricKronecker, Symmetrize
from ._linear_operator import Embedding, Identity, LinearOperator, Matrix, Selection
from ._linear_operator import (
Embedding,
Identity,
LambdaLinearOperator,
LinearOperator,
Matrix,
Selection,
)
from ._scaling import Scaling, Zero
from ._utils import LinearOperatorLike, aslinop

Expand All @@ -22,6 +29,7 @@
"aslinop",
"Embedding",
"LinearOperator",
"LambdaLinearOperator",
"Matrix",
"Identity",
"IdentityKronecker",
Expand All @@ -35,6 +43,7 @@

# Set correct module paths. Corrects links and module paths in documentation.
LinearOperator.__module__ = "probnum.linops"
LambdaLinearOperator.__module__ = "probnum.linops"
Embedding.__module__ = "probnum.linops"
Matrix.__module__ = "probnum.linops"
Identity.__module__ = "probnum.linops"
Expand Down
8 changes: 4 additions & 4 deletions src/probnum/linops/_arithmetic_fallbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
from probnum.typing import NotImplementedType, ScalarLike
import probnum.utils

from ._linear_operator import BinaryOperandType, LinearOperator
from ._linear_operator import BinaryOperandType, LambdaLinearOperator, LinearOperator

########################################################################################
# Generic Linear Operator Arithmetic (Fallbacks)
########################################################################################


class ScaledLinearOperator(LinearOperator):
class ScaledLinearOperator(LambdaLinearOperator):
"""Linear operator scaled with a scalar."""

def __init__(self, linop: LinearOperator, scalar: ScalarLike):
Expand Down Expand Up @@ -81,7 +81,7 @@ def __repr__(self) -> str:
return f"-{self._linop}"


class SumLinearOperator(LinearOperator):
class SumLinearOperator(LambdaLinearOperator):
"""Sum of linear operators."""

def __init__(self, *summands: LinearOperator):
Expand Down Expand Up @@ -166,7 +166,7 @@ def _mul_fallback(
return res


class ProductLinearOperator(LinearOperator):
class ProductLinearOperator(LambdaLinearOperator):
"""(Operator) Product of linear operators."""

def __init__(self, *factors: LinearOperator):
Expand Down
64 changes: 7 additions & 57 deletions src/probnum/linops/_kronecker.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from . import _linear_operator, _utils


class Symmetrize(_linear_operator.LinearOperator):
class Symmetrize(_linear_operator.LambdaLinearOperator):
r"""Symmetrizes a vector in its matrix representation.
Given a vector :math:`x=\operatorname{vec}(X)`
Expand Down Expand Up @@ -65,7 +65,7 @@ def _matmul(self, x: np.ndarray) -> np.ndarray:
)


class Kronecker(_linear_operator.LinearOperator):
class Kronecker(_linear_operator.LambdaLinearOperator):
"""Kronecker product of two linear operators.
The Kronecker product [1]_ :math:`A \\otimes B` of two linear operators :math:`A`
Expand Down Expand Up @@ -136,7 +136,6 @@ def __init__(self, A: LinearOperatorLike, B: LinearOperatorLike):
self.A.shape[1] * self.B.shape[1],
),
matmul=lambda x: _kronecker_matmul(self.A, self.B, x),
rmatmul=lambda x: _kronecker_rmatmul(self.A, self.B, x),
todense=lambda: np.kron(
self.A.todense(cache=False), self.B.todense(cache=False)
),
Expand Down Expand Up @@ -260,29 +259,7 @@ def _kronecker_matmul(
return y


def _kronecker_rmatmul(
A: _linear_operator.LinearOperator,
B: _linear_operator.LinearOperator,
x: np.ndarray,
) -> np.ndarray:
# Reshape into stack of matrices
y = x

if not y.flags.c_contiguous:
y = y.copy(order="C")

y = y.reshape(y.shape[:-1] + (A.shape[0], B.shape[0]))

# ((A.T) @ X) @ (B.T).T
y = (A.T @ y) @ B

# Revert to stack of vectorized matrices
y = y.reshape(y.shape[:-2] + (-1,))

return y


class SymmetricKronecker(_linear_operator.LinearOperator):
class SymmetricKronecker(_linear_operator.LambdaLinearOperator):
"""Symmetric Kronecker product of two linear operators.
The symmetric Kronecker product [1]_ :math:`A \\otimes_{s} B` of two square linear
Expand Down Expand Up @@ -337,7 +314,6 @@ def __init__(

dtype = self.A.dtype
matmul = lambda x: _kronecker_matmul(self.A, self.A, x)
rmatmul = lambda x: _kronecker_rmatmul(self.A, self.A, x)
todense = self._todense_identical_factors
# (A (x)_s A)^T = A^T (x)_s A^T
transpose = lambda: SymmetricKronecker(A=self.A.T)
Expand All @@ -357,7 +333,6 @@ def __init__(

dtype = np.result_type(self.A.dtype, self.B.dtype, 0.5)
matmul = self._matmul_different_factors
rmatmul = self._rmatmul_different_factors
todense = self._todense_different_factors
# (A (x)_s B)^T = A^T (x)_s B^T
transpose = lambda: SymmetricKronecker(A=self.A.T, B=self.B.T)
Expand All @@ -371,7 +346,6 @@ def __init__(
dtype=dtype,
shape=2 * (self._n**2,),
matmul=matmul,
rmatmul=rmatmul,
todense=todense,
transpose=transpose,
inverse=inverse,
Expand Down Expand Up @@ -441,29 +415,6 @@ def _matmul_different_factors(self, x: np.ndarray) -> np.ndarray:

return y

def _rmatmul_different_factors(self, x: np.ndarray) -> np.ndarray:
# Reshape into stack of matrices
y = x

if not y.flags.c_contiguous:
y = y.copy(order="C")

y = y.reshape(y.shape[:-1] + (self._n, self._n))

# (A.T) @ X @ (B.T).T
y1 = (self.A.T @ y) @ self.B

# (B.T) @ X @ (A.T).T
y2 = (self.B.T @ y) @ self.A

# 1/2 ((A^T)X(B^T)^T + (B^T)X(A^T)^T)
y = 0.5 * (y1 + y2)

# Revert to stack of vectorized matrices
y = y.reshape(y.shape[:-2] + (-1,))

return y

def _todense_identical_factors(self) -> np.ndarray:
"""Dense representation of the symmetric Kronecker product."""
# 1/2 (A (x) B + B (x) A)
Expand Down Expand Up @@ -498,7 +449,7 @@ def _symmetrize(self) -> SymmetricKronecker:
return SymmetricKronecker(A=self.A.symmetrize(), B=self.B.symmetrize())


class IdentityKronecker(_linear_operator.LinearOperator):
class IdentityKronecker(_linear_operator.LambdaLinearOperator):
"""Block-diagonal linear operator.
Parameters
Expand Down Expand Up @@ -533,9 +484,6 @@ def __init__(self, num_blocks: int, B: LinearOperatorLike):
matmul=lambda x: _kronecker_matmul(
self.A, self.B, x
), # TODO: can be implemented more efficiently
rmatmul=lambda x: _kronecker_rmatmul(
self.A, self.B, x
), # TODO: can be implemented more efficiently
todense=lambda: np.kron(
self.A.todense(cache=False), self.B.todense(cache=False)
),
Expand Down Expand Up @@ -589,7 +537,9 @@ def _sub_idkronecker(

return NotImplemented

def _cond(self, p) -> np.inexact:
def _cond(
self, p: Optional[Union[None, int, str, np.floating]] = None
) -> np.number:
if p is None or p in (2, 1, np.inf, "fro", -2, -1, -np.inf):
return self.A.cond(p=p) * self.B.cond(p=p)

Expand Down
Loading

0 comments on commit ce4bbba

Please sign in to comment.