From bdd167912395974f87c2e6d6f67c6c46bdebcaae Mon Sep 17 00:00:00 2001 From: Tanish Date: Wed, 18 Dec 2024 21:53:12 +0530 Subject: [PATCH] added rewrite for eig when input matrix is identity --- pytensor/tensor/rewriting/linalg.py | 36 ++++++++++++++++++++++++++- tests/tensor/rewriting/test_linalg.py | 30 ++++++++++++++++++++-- 2 files changed, 63 insertions(+), 3 deletions(-) diff --git a/pytensor/tensor/rewriting/linalg.py b/pytensor/tensor/rewriting/linalg.py index c6a094a8a2..da16acf1a4 100644 --- a/pytensor/tensor/rewriting/linalg.py +++ b/pytensor/tensor/rewriting/linalg.py @@ -1016,12 +1016,46 @@ def slogdet_specialization(fgraph, node): return replacements +@register_canonicalize +@register_stabilize +@node_rewriter([eig]) +def rewrite_eig_eye(fgraph, node): + """ + This rewrite takes advantage of the fact that for any identity matrix, all the eigenvalues are 1 and the eigenvectors are the standard basis. + + Parameters + ---------- + fgraph: FunctionGraph + Function graph being optimized + node: Apply + Node of the function graph to be optimized + + Returns + ------- + list of Variable, optional + List of optimized variables, or None if no optimization was performed + """ + # Check whether input to Eig is Eye and the 1's are on main diagonal + potential_eye = node.inputs[0] + if not ( + potential_eye.owner + and isinstance(potential_eye.owner.op, Eye) + and getattr(potential_eye.owner.inputs[-1], "data", -1).item() == 0 + ): + return None + + eigval_rewritten = pt.ones(potential_eye.shape[-1]) + eigvec_rewritten = pt.eye(potential_eye.shape[-1]) + + return [eigval_rewritten, eigvec_rewritten] + + @register_canonicalize @register_stabilize @node_rewriter([eig]) def rewrite_eig_diag(fgraph, node): """ - This rewrite takes advantage of the fact that for a diagonal matrix, the eigenvalues are simply the diagonal elements and the eigenvectors are the identity matrix. + This rewrite takes advantage of the fact that for a diagonal matrix, the eigenvalues are simply the diagonal elements and the eigenvectors are the standard basis. The presence of a diagonal matrix is detected by inspecting the graph. This rewrite can identify diagonal matrices that arise as the result of elementwise multiplication with an identity matrix. Specialized computation is used to diff --git a/tests/tensor/rewriting/test_linalg.py b/tests/tensor/rewriting/test_linalg.py index a9afa5a0e5..fa9c5f84e6 100644 --- a/tests/tensor/rewriting/test_linalg.py +++ b/tests/tensor/rewriting/test_linalg.py @@ -1047,14 +1047,40 @@ def test_eig_diag_from_eye_mul(shape): ) -def test_eig_diag_from_diag(): +def test_eig_eye(): + x = pt.eye(10) + eigval, eigvec = pt.linalg.eig(x) + + # REWRITE TEST + f_rewritten = function([], [eigval, eigvec], mode="FAST_RUN") + nodes = f_rewritten.maker.fgraph.apply_nodes + assert not any(isinstance(node.op, Eig) for node in nodes) + + # NUMERIC VALUE TEST + x_test = np.eye(10) + eigval, eigvec = np.linalg.eig(x_test) + rewritten_eigval, rewritten_eigvec = f_rewritten() + assert_allclose( + eigval, + rewritten_eigval, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) + assert_allclose( + eigvec, + rewritten_eigvec, + atol=1e-3 if config.floatX == "float32" else 1e-8, + rtol=1e-3 if config.floatX == "float32" else 1e-8, + ) + + +def test_eig_diag(): x = pt.tensor("x", shape=(None,)) x_diag = pt.diag(x) eigval, eigvec = pt.linalg.eig(x_diag) # REWRITE TEST f_rewritten = function([x], [eigval, eigvec], mode="FAST_RUN") - f_rewritten.dprint() nodes = f_rewritten.maker.fgraph.apply_nodes assert not any(isinstance(node.op, Eig) for node in nodes)