Skip to content

Commit

Permalink
added rewrite for eig when input matrix is identity
Browse files Browse the repository at this point in the history
  • Loading branch information
tanish1729 committed Dec 18, 2024
1 parent fc29a91 commit bdd1679
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 3 deletions.
36 changes: 35 additions & 1 deletion pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1016,12 +1016,46 @@ def slogdet_specialization(fgraph, node):
return replacements


@register_canonicalize
@register_stabilize
@node_rewriter([eig])
def rewrite_eig_eye(fgraph, node):
"""
This rewrite takes advantage of the fact that for any identity matrix, all the eigenvalues are 1 and the eigenvectors are the standard basis.
Parameters
----------
fgraph: FunctionGraph
Function graph being optimized
node: Apply
Node of the function graph to be optimized
Returns
-------
list of Variable, optional
List of optimized variables, or None if no optimization was performed
"""
# Check whether input to Eig is Eye and the 1's are on main diagonal
potential_eye = node.inputs[0]
if not (
potential_eye.owner
and isinstance(potential_eye.owner.op, Eye)
and getattr(potential_eye.owner.inputs[-1], "data", -1).item() == 0
):
return None

eigval_rewritten = pt.ones(potential_eye.shape[-1])
eigvec_rewritten = pt.eye(potential_eye.shape[-1])

return [eigval_rewritten, eigvec_rewritten]


@register_canonicalize
@register_stabilize
@node_rewriter([eig])
def rewrite_eig_diag(fgraph, node):
"""
This rewrite takes advantage of the fact that for a diagonal matrix, the eigenvalues are simply the diagonal elements and the eigenvectors are the identity matrix.
This rewrite takes advantage of the fact that for a diagonal matrix, the eigenvalues are simply the diagonal elements and the eigenvectors are the standard basis.
The presence of a diagonal matrix is detected by inspecting the graph. This rewrite can identify diagonal matrices
that arise as the result of elementwise multiplication with an identity matrix. Specialized computation is used to
Expand Down
30 changes: 28 additions & 2 deletions tests/tensor/rewriting/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1047,14 +1047,40 @@ def test_eig_diag_from_eye_mul(shape):
)


def test_eig_diag_from_diag():
def test_eig_eye():
x = pt.eye(10)
eigval, eigvec = pt.linalg.eig(x)

# REWRITE TEST
f_rewritten = function([], [eigval, eigvec], mode="FAST_RUN")
nodes = f_rewritten.maker.fgraph.apply_nodes
assert not any(isinstance(node.op, Eig) for node in nodes)

# NUMERIC VALUE TEST
x_test = np.eye(10)
eigval, eigvec = np.linalg.eig(x_test)
rewritten_eigval, rewritten_eigvec = f_rewritten()
assert_allclose(
eigval,
rewritten_eigval,
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)
assert_allclose(
eigvec,
rewritten_eigvec,
atol=1e-3 if config.floatX == "float32" else 1e-8,
rtol=1e-3 if config.floatX == "float32" else 1e-8,
)


def test_eig_diag():
x = pt.tensor("x", shape=(None,))
x_diag = pt.diag(x)
eigval, eigvec = pt.linalg.eig(x_diag)

# REWRITE TEST
f_rewritten = function([x], [eigval, eigvec], mode="FAST_RUN")
f_rewritten.dprint()
nodes = f_rewritten.maker.fgraph.apply_nodes
assert not any(isinstance(node.op, Eig) for node in nodes)

Expand Down

0 comments on commit bdd1679

Please sign in to comment.