From 5b8ecdb4b3602e9d0387c6df67b20e7b383b5904 Mon Sep 17 00:00:00 2001 From: Hoppe Date: Tue, 13 Aug 2024 13:08:25 +0200 Subject: [PATCH] coverage for last line --- heat/decomposition/pca.py | 10 +++------- heat/decomposition/tests/test_pca.py | 3 +++ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/heat/decomposition/pca.py b/heat/decomposition/pca.py index a807bfacc..5b84098d0 100644 --- a/heat/decomposition/pca.py +++ b/heat/decomposition/pca.py @@ -99,7 +99,7 @@ def __init__( raise NotImplementedError("Whitening is not yet supported. Please set whiten=False.") if not (svd_solver == "full" or svd_solver == "hierarchical" or svd_solver == "randomized"): raise ValueError( - "At the moment, only svd_solver='full' (for tall-skinny or short-fat data) and svd_solver='hierarchical' are supported. \n An implementation of the 'full' option for arbitrarily shaped data as well as the option 'randomized' are already planned." + "At the moment, only svd_solver='full' (for tall-skinny or short-fat data), svd_solver='hierarchical', and svd_solver='randomized' are supported. \n An implementation of the 'full' option for arbitrarily shaped data is already planned." ) if not isinstance(iterated_power, int): raise TypeError( @@ -222,7 +222,8 @@ def fit(self, X: ht.DNDarray, y=None) -> Self: self.components_ = V.T self.total_explained_variance_ratio_ = 1 - info.larray.item() ** 2 - elif self.svd_solver == "randomized": + else: + # compute SVD via "randomized" SVD _, S, V = ht.linalg.rsvd( X_centered, self.n_components_, @@ -231,11 +232,6 @@ def fit(self, X: ht.DNDarray, y=None) -> Self: ) self.components_ = V.T self.n_components_ = V.shape[1] - else: - # here one could add other computational backends - raise NotImplementedError( - f"The chosen svd_solver {self.svd_solver} is not yet implemented." - ) self.n_samples_ = X.shape[0] self.noise_variance_ = None # not yet implemented diff --git a/heat/decomposition/tests/test_pca.py b/heat/decomposition/tests/test_pca.py index 5521c0d89..58fc361ce 100644 --- a/heat/decomposition/tests/test_pca.py +++ b/heat/decomposition/tests/test_pca.py @@ -207,3 +207,6 @@ def test_pca_randomized(self): self.assertEqual(pca.explained_variance_, None) self.assertEqual(pca.explained_variance_ratio_, None) self.assertEqual(pca.singular_values_, None) + + pca = ht.decomposition.PCA(n_components=None, svd_solver="randomized", random_state=1234) + self.assertEqual(ht.random.get_state()[1], 1234)