implementation of zolo pd
Hoppe committed Jan 20, 2025
1 parent 5e19ccd commit 105a48d
Showing 2 changed files with 126 additions and 54 deletions.
137 changes: 98 additions & 39 deletions heat/core/linalg/pd.py
@@ -7,7 +7,7 @@
import torch
from typing import Type, Callable, Dict, Any, TypeVar, Union, Tuple

from ..communication import MPICommunication
from ..communication import MPICommunication, MPI
from ..dndarray import DNDarray
from .. import factories
from .. import types
@@ -19,7 +19,6 @@
from ..manipulations import vstack, hstack, concatenate, diag, balance
from ..exponential import sqrt
from .. import statistics
from mpi4py import MPI

from scipy.special import ellipj
from scipy.special import ellipkm1
@@ -131,7 +130,7 @@ def _in_place_qr_with_q_only(A: DNDarray, procs_to_merge: int = 2) -> None:

def pd(
A: DNDarray,
r: int = 8,
r: int = None,
calcH: bool = True,
condition_estimate: float = 0.0,
silent: bool = True,
@@ -145,9 +144,10 @@ def pd(
A : ht.DNDarray,
The input matrix for which the polar decomposition is computed;
must be two-dimensional, of data type float32 or float64, and must have at least as many rows as columns.
r : int, optional, default: 8
The parameter r used in the Zolotarev-PD algorithm; must be an integer between 1 and 8.
r : int, optional, default: None
The parameter r used in the Zolotarev-PD algorithm; if provided, must be an integer between 1 and 8 that divides the number of MPI processes.
Higher values of r lead to faster convergence, but memory consumption is proportional to r.
If not provided, the largest r between 1 and 8 that divides the number of MPI processes non-trivially is chosen.
calcH : bool, optional, default: True
If True, the function returns the symmetric, positive definite matrix H. If False, only the orthogonal matrix U is returned.
condition_estimate : float, optional, default: 0.
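For orientation, a minimal usage sketch of the routine documented above; it assumes the function is exported as ht.pd (as in the tests further down), and the tolerances are purely illustrative:

import heat as ht
import numpy as np

# tall-skinny random matrix, distributed along the rows
A = ht.random.randn(1000, 100, split=0, dtype=ht.float32)

# polar decomposition A = U @ H; r is chosen automatically when not provided
U, H = ht.pd(A, calcH=True)

# gather the factors and check the defining properties
U_np, H_np = U.numpy(), H.numpy()
print(np.allclose(U_np.T @ U_np, np.eye(U_np.shape[1]), atol=1e-4))  # U has orthonormal columns
print(np.allclose(H_np, H_np.T, atol=1e-4))                          # H is symmetric
print(np.allclose(U_np @ H_np, A.numpy(), atol=1e-3))                # U @ H reconstructs A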
@@ -175,6 +175,7 @@ def pd(
raise ValueError(
f"Input ``A`` must have at least as many rows as columns, but has shape {A.shape}."
)

# check if A is a real floating point matrix and choose tolerances tol accordingly
if A.dtype == types.float32:
tol = 1.19e-7
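The float32 tolerance above is single-precision machine epsilon; presumably the collapsed float64 branch uses the double-precision analogue. A quick torch-only check of the two values (a sketch, not part of the module):

import torch

print(torch.finfo(torch.float32).eps)  # 1.1920929e-07
print(torch.finfo(torch.float64).eps)  # 2.220446049250313e-16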
@@ -186,17 +187,42 @@
)

# check if input for r is reasonable
if not isinstance(r, int) or r < 1 or r > 8:
raise ValueError(
f"If specified, input ``r`` must be an integer between 1 and 8, but is {r} of data type {type(r)}."
)
if r is not None:
if not isinstance(r, int) or r < 1 or r > 8:
raise ValueError(
f"If specified, input ``r`` must be an integer between 1 and 8, but is {r} of data type {type(r)}."
)
if A.is_distributed() and (A.comm.size % r != 0 or A.comm.size == r):
raise ValueError(
f"If specified, input ``r`` must be a non-trivial divisor of the number MPI processes , but r={r} and A.comm.size={A.comm.size}."
)
else:
for i in range(8, 0, -1):
if A.comm.size % i == 0 and A.comm.size // i > 1:
r = i
break
if not silent:
if A.comm.rank == 0:
print(f"Automatically chosen r={r}.")

# check if input for condition_estimate is reasonable
if not isinstance(condition_estimate, float):
raise TypeError(
f"If specified, input ``condition_estimate`` must be a float but is {type(condition_estimate)}."
)

# early out for the non-distributed case
if not A.is_distributed():
U, s, vh = torch.linalg.svd(A.larray, full_matrices=False)
U @= vh
H = vh.T @ torch.diag(s) @ vh
if calcH:
return factories.array(U, is_split=None, comm=A.comm), factories.array(
H, is_split=None, comm=A.comm
)
else:
return factories.array(U, is_split=None, comm=A.comm)
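The early out above relies on the standard identity behind the polar decomposition: for a thin SVD A = W Σ Vᵀ, the orthogonal polar factor is U = W Vᵀ and H = V Σ Vᵀ. A self-contained torch sketch of the same computation (illustrative sizes):

import torch

A = torch.randn(50, 10, dtype=torch.float64)
W, s, Vh = torch.linalg.svd(A, full_matrices=False)
U = W @ Vh                     # orthogonal polar factor (orthonormal columns)
H = Vh.T @ torch.diag(s) @ Vh  # symmetric positive semi-definite polar factor
assert torch.allclose(U @ H, A, atol=1e-12)
assert torch.allclose(U.T @ U, torch.eye(10, dtype=torch.float64), atol=1e-12)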

alpha = _estimate_largest_singularvalue(A).item()

if condition_estimate <= 1.0:
@@ -212,6 +238,30 @@ def pd(
# initialize X for the iteration: input ``A``, normalized by largest singular value
X = A / alpha

# each of these communicators has size r; along them we parallelize the r QR decompositions that are performed in parallel
horizontal_comm = A.comm.Split(A.comm.rank // r, A.comm.rank)

# each of these communicators has size MPI_WORLD.size / r and will carry a full copy of X for the QR decomposition
vertical_comm = A.comm.Split(A.comm.rank % r, A.comm.rank)

# in each horizontal communicator, collect the local array of X from all processes
local_shapes = horizontal_comm.allgather(A.lshape[A.split])
new_local_shape = (
(sum(local_shapes), A.shape[1]) if A.split == 0 else (A.shape[0], sum(local_shapes))
)
counts = tuple(local_shapes)
displacements = tuple(np.cumsum([0] + list(local_shapes))[:-1])
X_collected_local = torch.zeros(
new_local_shape, dtype=A.dtype.torch_type(), device=A.device.torch_device
)
horizontal_comm.Allgatherv(
X.larray, (X_collected_local, counts, displacements), recv_axis=A.split
)
del X

X = factories.array(X_collected_local, is_split=A.split, comm=vertical_comm)
X.balance_()

# iteration counter and maximum number of iterations
it = 0
itmax = _zolopd_n_iterations(r, kappa)
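The two comm.Split calls earlier in this hunk arrange the processes in an r-by-(size/r) grid: the r members of each horizontal group each compute one of the r QR-based terms, while every vertical group of size size/r holds a full copy of X. A minimal standalone mpi4py sketch of the same splitting pattern (r = 2 is an arbitrary illustrative choice and must divide the number of processes non-trivially):

from mpi4py import MPI

comm = MPI.COMM_WORLD
r = 2  # hypothetical value, for illustration only

# processes with the same value of rank // r form one horizontal group of size r
horizontal = comm.Split(color=comm.rank // r, key=comm.rank)
# processes with the same value of rank % r form one vertical group of size comm.size // r
vertical = comm.Split(color=comm.rank % r, key=comm.rank)

print(f"world rank {comm.rank}: "
      f"horizontal {horizontal.Get_rank()}/{horizontal.Get_size()}, "
      f"vertical {vertical.Get_rank()}/{vertical.Get_size()}")

Run with, e.g., mpirun -n 8 python grid_split.py to see the 2 x 4 layout.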
@@ -220,6 +270,7 @@ def pd(
ell = 1.0 / kappa
c, a, Mhat = _compute_zolotarev_coefficients(r, ell, A.device, dtype=A.dtype)

itmax = _zolopd_n_iterations(r, kappa)
while it < itmax:
it += 1
if not silent:
@@ -228,39 +279,23 @@ def pd(
# remember current X for later convergence check
X_old = X.copy()

# repeat X r times and create a (repeated) identity matrix
# this allows computing the r QR decompositions and matrix multiplications in a batch-parallel manner
X = factories.array(
X.larray.repeat(r, 1, 1),
is_split=X.split + 1 if X.split is not None else None,
comm=A.comm,
)
cId = factories.eye(A.shape[1], dtype=A.dtype, comm=A.comm, split=A.split, device=A.device)
cId = factories.array(
cId.larray.repeat(r, 1, 1),
is_split=cId.split + 1 if cId.split is not None else None,
comm=A.comm,
)
cId *= c[0::2].reshape(-1, 1, 1) ** 0.5
X = concatenate([X, cId], axis=1)
cId = factories.eye(X.shape[1], dtype=X.dtype, comm=X.comm, split=X.split, device=X.device)
cId *= c[2 * horizontal_comm.rank].item() ** 0.5
X = concatenate([X, cId], axis=0)
del cId
# if A.split == 0:
Q, R = qr(X)
del R
Q1 = Q[:, : A.shape[0], :].balance()
Q2 = Q[:, A.shape[0] :, :].transpose([0, 2, 1]).balance()
del Q
# elif A.split == 1:
# _in_place_qr_with_q_only(X)
# Q1 = X[:, : A.shape[0], : ]
# Q2 = X[:, A.shape[0] :, : ].transpose([0, 2, 1])
Q1 = Q[: A.shape[0], :].balance()
Q2 = Q[A.shape[0] :, :].transpose().balance()
Q1Q2 = matmul(Q1, Q2)
del Q1, Q2
X = X[:, : A.shape[0], :].balance() / r
X = Mhat * (X + a.reshape(-1, 1, 1) / c[0::2].reshape(-1, 1, 1) ** 0.5 * Q1Q2)
X = X[: A.shape[0], :].balance()
X /= r
X += a[horizontal_comm.rank].item() / c[2 * horizontal_comm.rank].item() ** 0.5 * Q1Q2
del Q1Q2
# finally, sum over the batch-dimension to get back the result of the iteration
X = X.sum(axis=0)
X *= Mhat.item()
# finally, sum the r contributions across each horizontal communicator
horizontal_comm.Allreduce(MPI.IN_PLACE, X.larray, op=MPI.SUM)

# check for convergence and break if tolerance is reached
if it > 1 and matrix_norm(X - X_old, ord="fro") / matrix_norm(X, ord="fro") <= tol ** (
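Reconstructed from the loop body above, each pass applies the QR-based Zolotarev update

X_{k+1} = \hat{M} \Big( X_k + \sum_{j=1}^{r} \frac{a_j}{\sqrt{c_{2j-1}}} \, Q_1^{(j)} \big(Q_2^{(j)}\big)^{T} \Big),
\qquad
\begin{bmatrix} X_k \\ \sqrt{c_{2j-1}} \, I \end{bmatrix} = \begin{bmatrix} Q_1^{(j)} \\ Q_2^{(j)} \end{bmatrix} R_j ,

where the j-th term is computed by the process with horizontal rank j - 1, and the concluding Allreduce forms the sum; the r local contributions X_k / r add back up to the single X_k term.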
@@ -284,10 +319,34 @@ def pd(
print(
f"Zolotarev-PD iteration did not reach the convergence criterion after {itmax} iterations, which is most likely due to limited numerical accuracy and/or poor estimation of the condition number. The result may still be useful, but should be handeled with care!"
)

# as every process has much more data than required, we need to split the result into the parts that are actually needed on each process
counts = [
X.lshape[X.split] // horizontal_comm.size + (r < X.lshape[X.split] % horizontal_comm.size)
for r in range(horizontal_comm.size)
]
displacements = [sum(counts[:r]) for r in range(horizontal_comm.size)]

if A.split == 1:
U_local = X.larray[
:,
displacements[horizontal_comm.rank] : displacements[horizontal_comm.rank]
+ counts[horizontal_comm.rank],
]
else:
U_local = X.larray[
displacements[horizontal_comm.rank] : displacements[horizontal_comm.rank]
+ counts[horizontal_comm.rank],
:,
]
U = factories.array(U_local, is_split=A.split, comm=A.comm)
del X
U.balance_()

# postprocessing: compute H if requested
if calcH:
H = matmul(X.T, A)
H = matmul(U.T, A)
H = 0.5 * (H + H.T.resplit(H.split))
return X, H.resplit(A.split)
return U, H.resplit(A.split)
else:
return X
return U
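The post-processing above has to undo the replication: after the Allreduce every member of a horizontal group holds the same local block, so each process keeps only its share along the split axis before the result is reassembled on the original communicator. A tiny worked example of the counts/displacements arithmetic used there (values are illustrative):

# split 10 local rows among a horizontal group of 4 processes as evenly as possible
n_local, group_size = 10, 4
counts = [n_local // group_size + (k < n_local % group_size) for k in range(group_size)]
displacements = [sum(counts[:k]) for k in range(group_size)]
print(counts)         # [3, 3, 2, 2]
print(displacements)  # [0, 3, 6, 8]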
43 changes: 28 additions & 15 deletions heat/core/linalg/tests/test_pd.py
@@ -53,18 +53,24 @@ def test_catch_wrong_inputs(self):

def test_pd_split0(self):
# split=0, float32, no condition estimate provided, silent mode
ht.random.seed(18112024)
for r in range(1, 9):
A = ht.random.randn(100, 10 * r, split=0, dtype=ht.float32)
U, H = ht.pd(A, r=r)
dtypetol = 1e-4

self._check_pd(A, U, H, dtypetol)
with self.subTest(r=r):
ht.random.seed(18112024)
A = ht.random.randn(100, 10 * r, split=0, dtype=ht.float32)
if (
ht.MPI_WORLD.size % r == 0 and ht.MPI_WORLD.size != r
) or ht.MPI_WORLD.size == 1:
U, H = ht.pd(A, r=r)
dtypetol = 1e-4
self._check_pd(A, U, H, dtypetol)
else:
with self.assertRaises(ValueError):
U, H = ht.pd(A, r=r)

# cases not covered so far
A = ht.random.randn(100, 100, split=0, dtype=ht.float64)
U, H = ht.pd(A, condition_estimate=1.0e16, silent=False)
dtypetol = 1e-8
dtypetol = 1e-7

self._check_pd(A, U, H, dtypetol)

@@ -80,25 +86,32 @@ def test_pd_split0(self):

def test_pd_split1(self):
# split=1, float64, condition estimate provided, non-silent mode
ht.random.seed(623)
for r in range(1, 9):
A = ht.random.randn(100, 99, split=1, dtype=ht.float64)
U, H = ht.pd(A, r=r, silent=False, condition_estimate=1.0e16)
dtypetol = 1e-8
with self.subTest(r=r):
ht.random.seed(623)
A = ht.random.randn(100, 99, split=1, dtype=ht.float64)
if (
ht.MPI_WORLD.size % r == 0 and ht.MPI_WORLD.size != r
) or ht.MPI_WORLD.size == 1:
U, H = ht.pd(A, r=r, silent=False, condition_estimate=1.0e16)
dtypetol = 1e-7

self._check_pd(A, U, H, dtypetol)
self._check_pd(A, U, H, dtypetol)
else:
with self.assertRaises(ValueError):
U, H = ht.pd(A, r=r)

# cases not covered so far
A = ht.random.randn(100, 99, split=1, dtype=ht.float32)
U, H = ht.pd(A)
U, H = ht.pd(A, silent=False, condition_estimate=1.0e16)
dtypetol = 1e-4
self._check_pd(A, U, H, dtypetol)

# case without calculating H
A = ht.random.randn(100, 100, split=1, dtype=ht.float64)
U = ht.pd(A, calcH=False)
U = ht.pd(A, calcH=False, condition_estimate=1.0e16)
U_np = U.numpy()
self.assertTrue(np.allclose(U_np.T @ U_np, np.eye(U_np.shape[1]), atol=1e-8, rtol=1e-8))
self.assertTrue(np.allclose(U_np.T @ U_np, np.eye(U_np.shape[1]), atol=1e-7, rtol=1e-7))
H_np = U_np.T @ A.numpy()
self.assertTrue(np.allclose(H_np.T, H_np, atol=1e-8, rtol=1e-8))
self.assertTrue((np.linalg.eigvalsh(H_np) > 0).all())
