Skip to content

Commit

Permalink
coverage for last line
Browse files Browse the repository at this point in the history
  • Loading branch information
Hoppe committed Aug 13, 2024
1 parent e5dc9f4 commit 5b8ecdb
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 7 deletions.
10 changes: 3 additions & 7 deletions heat/decomposition/pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __init__(
raise NotImplementedError("Whitening is not yet supported. Please set whiten=False.")
if not (svd_solver == "full" or svd_solver == "hierarchical" or svd_solver == "randomized"):
raise ValueError(
"At the moment, only svd_solver='full' (for tall-skinny or short-fat data) and svd_solver='hierarchical' are supported. \n An implementation of the 'full' option for arbitrarily shaped data as well as the option 'randomized' are already planned."
"At the moment, only svd_solver='full' (for tall-skinny or short-fat data), svd_solver='hierarchical', and svd_solver='randomized' are supported. \n An implementation of the 'full' option for arbitrarily shaped data is already planned."
)
if not isinstance(iterated_power, int):
raise TypeError(
Expand Down Expand Up @@ -222,7 +222,8 @@ def fit(self, X: ht.DNDarray, y=None) -> Self:
self.components_ = V.T
self.total_explained_variance_ratio_ = 1 - info.larray.item() ** 2

elif self.svd_solver == "randomized":
else:
# compute SVD via "randomized" SVD
_, S, V = ht.linalg.rsvd(
X_centered,
self.n_components_,
Expand All @@ -231,11 +232,6 @@ def fit(self, X: ht.DNDarray, y=None) -> Self:
)
self.components_ = V.T
self.n_components_ = V.shape[1]
else:
# here one could add other computational backends
raise NotImplementedError(
f"The chosen svd_solver {self.svd_solver} is not yet implemented."
)

self.n_samples_ = X.shape[0]
self.noise_variance_ = None # not yet implemented
Expand Down
3 changes: 3 additions & 0 deletions heat/decomposition/tests/test_pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,6 @@ def test_pca_randomized(self):
self.assertEqual(pca.explained_variance_, None)
self.assertEqual(pca.explained_variance_ratio_, None)
self.assertEqual(pca.singular_values_, None)

pca = ht.decomposition.PCA(n_components=None, svd_solver="randomized", random_state=1234)
self.assertEqual(ht.random.get_state()[1], 1234)

0 comments on commit 5b8ecdb

Please sign in to comment.