Skip to content

Commit

Permalink
Merge branch 'main' into features/1707-Batched_QR
Browse files Browse the repository at this point in the history
  • Loading branch information
ClaudiaComito authored Dec 2, 2024
2 parents 960e472 + c23428e commit c8702a8
Show file tree
Hide file tree
Showing 9 changed files with 587 additions and 37 deletions.
17 changes: 17 additions & 0 deletions benchmarks/cb/decomposition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# flake8: noqa
import heat as ht
from mpi4py import MPI
from perun import monitor
from heat.decomposition import IncrementalPCA


@monitor()
def incremental_pca_split0(list_of_X, n_components):
    """
    Benchmark: fit an ``IncrementalPCA`` batch by batch.

    Parameters
    ----------
    list_of_X : iterable
        Data batches; each one is passed to ``partial_fit`` in iteration order.
    n_components : int
        Forwarded to the ``IncrementalPCA`` constructor as ``n_components``.
    """
    estimator = IncrementalPCA(n_components=n_components)
    for batch in list_of_X:
        estimator.partial_fit(batch)


def run_decomposition_benchmarks():
    """
    Run the decomposition benchmarks.

    Creates 10 random arrays of shape (50000, 500), each requested with ``split=0``,
    and feeds them to the incremental PCA benchmark with 50 components.
    """
    n_batches, n_rows, n_cols = 10, 50000, 500
    batches = []
    for _ in range(n_batches):
        batches.append(ht.random.rand(n_rows, n_cols, split=0))
    incremental_pca_split0(batches, 50)
2 changes: 2 additions & 0 deletions benchmarks/cb/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
from cluster import run_cluster_benchmarks
from manipulations import run_manipulation_benchmarks
from preprocessing import run_preprocessing_benchmarks
from decomposition import run_decomposition_benchmarks

# Execute the benchmark suites one after another.
run_linalg_benchmarks()
run_cluster_benchmarks()
run_manipulation_benchmarks()
run_preprocessing_benchmarks()
run_decomposition_benchmarks()
4 changes: 2 additions & 2 deletions heat/core/linalg/qr.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ def qr(
for i in range(last_row_reached + 1):
# this loop goes through all the column-blocks (i.e. local arrays) of the matrix
# this corresponds to the loop over all columns in classical Gram-Schmidt

if i < nprocs - 1:
k_loc_i = min(A.shape[-2], A.lshape_map[i, -1])
Q_buf = torch.zeros(
Expand All @@ -163,8 +164,7 @@ def qr(

if i < nprocs - 1:
# broadcast the orthogonalized block of columns to all other processes
req = A.comm.Ibcast(Q_buf, root=i)
req.Wait()
A.comm.Bcast(Q_buf, root=i)

if A.comm.rank > i:
# subtract the contribution of the current block of columns from the remaining columns
Expand Down
240 changes: 206 additions & 34 deletions heat/core/linalg/svdtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,18 @@
from ..linalg import matmul, vector_norm, qr, svd
from ..indexing import where
from ..random import randn

from ..sanitation import sanitize_in_nd_realfloating
from ..manipulations import vstack, hstack, diag, balance

from .. import statistics
from math import log, ceil, floor, sqrt


__all__ = ["hsvd_rank", "hsvd_rtol", "hsvd", "rsvd"]


def _check_SVD_input(A):
    """
    Validate that ``A`` is acceptable input for the SVD routines.

    Parameters
    ----------
    A : DNDarray
        Candidate input matrix.

    Raises
    ------
    TypeError
        If ``A`` is not a DNDarray, or if its data type is not accepted by
        ``types.heat_type_is_realfloating``.
    ValueError
        If ``A`` is not two-dimensional.
    """
    if not isinstance(A, DNDarray):
        raise TypeError(f"Argument needs to be a DNDarray but is {type(A)}.")
    if A.ndim != 2:
        raise ValueError("A needs to be a 2D matrix")
    if types.heat_type_is_realfloating(A.dtype):
        # all checks passed
        return None
    raise TypeError(
        "Argument needs to be a DNDarray with datatype float32 or float64, but data type is {}.".format(
            A.dtype
        )
    )
__all__ = ["hsvd_rank", "hsvd_rtol", "hsvd", "rsvd", "isvd"]


#######################################################################################
# user-friendly versions of hSVD
# hierarchical SVD "hSVD"
#######################################################################################


Expand Down Expand Up @@ -99,7 +85,7 @@ def hsvd_rank(
[1] Iwen, Ong. A distributed and incremental SVD algorithm for agglomerative data analysis on large networks. SIAM J. Matrix Anal. Appl., 37(4), 2016.
[2] Himpe, Leibner, Rave. Hierarchical approximate proper orthogonal decomposition. SIAM J. Sci. Comput., 40 (5), 2018.
"""
_check_SVD_input(A) # check if A is suitable input
sanitize_in_nd_realfloating(A, "A", [2])
A_local_size = max(A.lshape_map[:, 1])

if maxmergedim is not None and maxmergedim < 2 * (maxrank + safetyshift) + 1:
Expand Down Expand Up @@ -202,7 +188,7 @@ def hsvd_rtol(
[1] Iwen, Ong. A distributed and incremental SVD algorithm for agglomerative data analysis on large networks. SIAM J. Matrix Anal. Appl., 37(4), 2016.
[2] Himpe, Leibner, Rave. Hierarchical approximate proper orthogonal decomposition. SIAM J. Sci. Comput., 40 (5), 2018.
"""
_check_SVD_input(A) # check if A is suitable input
sanitize_in_nd_realfloating(A, "A", [2])
A_local_size = max(A.lshape_map[:, 1])

if maxmergedim is not None and maxrank is None:
Expand Down Expand Up @@ -248,11 +234,6 @@ def hsvd_rtol(
)


################################################################################################
# hSVD - "full" routine for the experts
################################################################################################


def hsvd(
A: DNDarray,
maxrank: Optional[int] = None,
Expand Down Expand Up @@ -334,7 +315,7 @@ def hsvd(
"\t\t".join(["%d" % an for an in active_nodes]),
)

U_loc, sigma_loc, err_squared_loc = compute_local_truncated_svd(
U_loc, sigma_loc, err_squared_loc = _compute_local_truncated_svd(
level, A.comm.rank, A.larray, maxrank, loc_atol, safetyshift
)
U_loc = torch.matmul(U_loc, torch.diag(sigma_loc))
Expand Down Expand Up @@ -412,7 +393,7 @@ def hsvd(

if len(future_nodes) == 1:
safetyshift = 0
U_loc, sigma_loc, err_squared_loc_new = compute_local_truncated_svd(
U_loc, sigma_loc, err_squared_loc_new = _compute_local_truncated_svd(
level, A.comm.rank, U_loc, maxrank, loc_atol, safetyshift
)

Expand Down Expand Up @@ -466,12 +447,7 @@ def hsvd(
return U, rel_error_estimate


##############################################################################################
# AUXILIARY ROUTINES
##############################################################################################


def compute_local_truncated_svd(
def _compute_local_truncated_svd(
level: int,
proc_id: int,
U_loc: torch.Tensor,
Expand Down Expand Up @@ -528,7 +504,7 @@ def compute_local_truncated_svd(


##############################################################################################
# Randomized SVD
# Randomized SVD "rSVD"
##############################################################################################


Expand Down Expand Up @@ -568,7 +544,7 @@ def rsvd(
-----------
[1] Halko, N., Martinsson, P. G., & Tropp, J. A. (2011). Finding structure with randomness: Probabilistic algorithms for constructing approximate matrix decompositions. SIAM review, 53(2), 217-288.
"""
_check_SVD_input(A) # check if A is suitable input
sanitize_in_nd_realfloating(A, "A", [2])
if not isinstance(rank, int):
raise TypeError(f"rank must be an integer, but is {type(rank)}.")
if rank < 1:
Expand Down Expand Up @@ -614,3 +590,199 @@ def rsvd(
V = V[:, :rank]
V.balance_()
return U, S, V


##############################################################################################
# Incremental SVD "iSVD"
##############################################################################################


def _isvd(
    new_data: DNDarray,
    U_old: DNDarray,
    S_old: DNDarray,
    V_old: Optional[DNDarray] = None,
    maxrank: Optional[int] = None,
    old_matrix_size: Optional[int] = None,
    old_rowwise_mean: Optional[DNDarray] = None,
) -> Union[Tuple[DNDarray, DNDarray, DNDarray], Tuple[DNDarray, DNDarray, DNDarray, DNDarray]]:
    """
    Helper function for iSVD and iPCA; follows roughly the "incremental PCA with mean update", Fig.1 in:
    David A. Ross, Jongwoo Lim, Ruei-Sung Lin, Ming-Hsuan Yang. Incremental Learning for Robust Visual Tracking. IJCV, 2008.
    Either incremental SVD / PCA or incremental SVD / PCA with mean subtraction is performed.

    Parameters
    -----------
    new_data: DNDarray
        new data as DNDarray (columns to be appended to the "old" matrix)
    U_old, S_old, V_old: DNDarrays
        "old" SVD-factors
        if no V_old is provided, only U and S are computed (PCA)
    maxrank: int, optional
        rank to which new SVD should be truncated
    old_matrix_size: int, optional
        size of the old matrix; this does not need to be identical to V_old.shape[0] as "old" SVD might have been truncated.
        Required if `V_old` is None or if `old_rowwise_mean` is given.
    old_rowwise_mean: DNDarray, optional
        row-wise mean of the old matrix; if not provided, no mean subtraction is performed

    Returns
    -------
    Tuple[DNDarray, DNDarray, DNDarray]
        ``(U_new, S_new, V_new)`` if `V_old` is given (SVD use-case), otherwise
        ``(U_new, S_new, new_rowwise_mean)`` (PCA use-case). If both `V_old` and `old_rowwise_mean`
        are given, the SVD variant is returned and the updated mean is not.

    Raises
    ------
    ValueError
        If neither `V_old` nor `old_rowwise_mean` is given (there would be nothing meaningful to return),
        or if `old_matrix_size` is missing although it is needed.
    """
    # fail early with a clear message instead of computing everything and implicitly returning None
    if V_old is None and old_rowwise_mean is None:
        raise ValueError(
            "Either V_old (SVD use-case) or old_rowwise_mean (PCA use-case) must be provided."
        )
    # old_matrix_size enters the mean update below and replaces V_old.shape[0] if V_old is absent
    if old_rowwise_mean is not None and old_matrix_size is None:
        raise ValueError("old_matrix_size must be provided if old_rowwise_mean is given.")

    # old SVD is SVD of a matrix of dimension m x n and has rank r
    # new data have shape m x d
    d = new_data.shape[1]
    n = V_old.shape[0] if V_old is not None else old_matrix_size
    r = S_old.shape[0]
    # the rank can never exceed the dimensions of the extended matrix
    rank_bound = min(n + d, U_old.shape[0])
    maxrank = rank_bound if maxrank is None else min(maxrank, rank_bound)

    if old_rowwise_mean is not None:
        new_data_rowwise_mean = statistics.mean(new_data, axis=1)
        new_rowwise_mean = (old_matrix_size * old_rowwise_mean + d * new_data_rowwise_mean) / (
            old_matrix_size + d
        )
        # out-of-place subtraction so that the array passed in by the caller is never modified
        new_data = new_data - new_data_rowwise_mean.reshape(-1, 1)
        # append one extra column accounting for the shift between old and new mean
        new_data = hstack(
            [
                new_data,
                (new_data_rowwise_mean - old_rowwise_mean)
                * (d * old_matrix_size / (d + old_matrix_size)) ** 0.5,
            ]
        )
        d += 1

    # orthogonalize and decompose new_data
    UtC = U_old.T @ new_data
    if U_old.split is not None:
        # NOTE(review): resplit_ acts in place; without mean subtraction this is still the
        # caller's array, whose distribution may therefore change -- confirm this is acceptable
        new_data = new_data.resplit_(U_old.split) - U_old @ UtC
    else:
        new_data = new_data - (U_old @ UtC).resplit_(new_data.split)
    P, Rc = qr(new_data)

    # prepare one component of "new" V-factor
    if V_old is not None:
        V_new = vstack(
            [
                V_old,
                factories.zeros(
                    (d, r),
                    device=V_old.device,
                    dtype=V_old.dtype,
                    split=V_old.split,
                    comm=V_old.comm,
                ),
            ]
        )
        helper = vstack(
            [
                factories.zeros(
                    (n, d),
                    device=V_old.device,
                    dtype=V_old.dtype,
                    split=V_old.split,
                    comm=V_old.comm,
                ),
                factories.eye(
                    d, device=V_old.device, dtype=V_old.dtype, split=V_old.split, comm=V_old.comm
                ),
            ]
        )
        V_new = hstack([V_new, helper])
        del helper

    # prepare one component of "new" U-factor
    U_new = hstack([U_old, P])

    # prepare "inner" matrix that needs to be decomposed, decompose it
    helper1 = vstack(
        [
            diag(S_old),
            factories.zeros(
                (Rc.shape[0] + UtC.shape[0] - r, r),
                device=S_old.device,
                dtype=S_old.dtype,
                split=S_old.split,
                comm=S_old.comm,
            ),
        ]
    )
    # bring UtC and Rc to a common split before stacking them
    if r > d:
        Rc = Rc.resplit_(UtC.split)
    else:
        UtC = UtC.resplit_(Rc.split)
    helper2 = vstack([UtC, Rc])
    innermat = hstack([helper1, helper2])
    del helper1, helper2
    # as innermat is small enough to fit into memory of a single process, we can use torch svd
    u, s, v = svd.svd(innermat.resplit_(None))
    del innermat

    # truncate if desired
    if maxrank < s.shape[0]:
        u = u[:, :maxrank]
        s = s[:maxrank]
        v = v[:, :maxrank]

    U_new = U_new @ u

    if V_old is not None:  # use-case: SVD
        V_new = V_new @ v
        return U_new, s, V_new
    # use-case: PCA; old_rowwise_mean is not None here due to the argument check at the top
    return U_new, s, new_rowwise_mean


def isvd(
    new_data: DNDarray,
    U_old: DNDarray,
    S_old: DNDarray,
    V_old: DNDarray,
    maxrank: Optional[int] = None,
) -> Tuple[DNDarray, DNDarray, DNDarray]:
    r"""Incremental SVD (iSVD) for the addition of new data to an existing SVD.
    Given the SVD of an "old" matrix, :math:`X_\textnormal{old} = U_\textnormal{old} \cdot S_\textnormal{old} \cdot V_\textnormal{old}^T`, and additional columns :math:`N` ("`new_data`"), this routine computes
    (a possibly approximate) SVD of the extended matrix :math:`X_\textnormal{new} = [ X_\textnormal{old} | N]`.

    Parameters
    ----------
    new_data : DNDarray
        2D-array (float32/64) of columns that are added to the "old" SVD. It must hold `new_data.split != 1` if `U_old.split = 0`.
    U_old : DNDarray
        U-factor of the SVD of the "old" matrix, 2D-array (float32/64). It must hold `U_old.split != 0` if `new_data.split = 1`.
    S_old : DNDarray
        Sigma-factor of the SVD of the "old" matrix, 1D-array (float32/64)
    V_old : DNDarray
        V-factor of the SVD of the "old" matrix, 2D-array (float32/64)
    maxrank : int, optional
        truncation rank of the SVD of the extended matrix; must be at least 1 if given. The default is None, i.e., no bound on the maximal rank is imposed.

    Returns
    -------
    Tuple[DNDarray, DNDarray, DNDarray]
        The factors `U`, `S`, `V` of the (possibly truncated) SVD of the extended matrix.

    Raises
    ------
    ValueError
        If `maxrank` is smaller than 1, if the number of columns of `U_old` or `V_old` differs from the
        number of elements in `S_old`, or if `new_data` and `U_old` have a different number of rows.

    Notes
    -----------
    Inexactness may arise due to truncation to maximal rank `maxrank` if rank of the data to be processed exceeds this rank.
    If you set `maxrank` to a high number (or None) in order to avoid inexactness, you may encounter memory issues.
    The implementation follows the approach described in Ref. [1], Sect. 2.

    References
    ------------
    [1] Brand, M. (2006). Fast low-rank modifications of the thin singular value decomposition. Linear algebra and its applications, 415(1), 20-30.
    """
    # check if new_data, U_old, V_old are 2D DNDarrays (S_old: 1D) and float32/64
    sanitize_in_nd_realfloating(new_data, "new_data", [2])
    sanitize_in_nd_realfloating(U_old, "U_old", [2])
    sanitize_in_nd_realfloating(S_old, "S_old", [1])
    sanitize_in_nd_realfloating(V_old, "V_old", [2])
    # a non-positive maxrank would be used as a slice bound during truncation and silently
    # yield empty factors (0) or drop trailing singular values (negative) -> reject it explicitly
    if maxrank is not None and maxrank < 1:
        raise ValueError(f"maxrank must be at least 1 (or None), but is {maxrank}.")
    # check if number of columns of U_old and V_old match the number of elements in S_old
    if U_old.shape[1] != S_old.shape[0]:
        raise ValueError(
            "The number of columns of U_old must match the number of elements in S_old."
        )
    if V_old.shape[1] != S_old.shape[0]:
        raise ValueError(
            "The number of columns of V_old must match the number of elements in S_old."
        )
    # check if the number of rows of new_data matches the number of rows of U_old
    if new_data.shape[0] != U_old.shape[0]:
        raise ValueError("The number of rows of new_data must match the number of rows of U_old.")

    return _isvd(new_data, U_old, S_old, V_old, maxrank)
Loading

0 comments on commit c8702a8

Please sign in to comment.