Skip to content

Commit

Permalink
remove standardization in design matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
Dong555 committed Oct 27, 2023
1 parent 2b16690 commit 80f84b1
Showing 1 changed file with 5 additions and 13 deletions.
18 changes: 5 additions & 13 deletions susiepca/infer_design_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def _update_beta(G: ArrayLike, params: ModelParams_Design) -> ModelParams_Design
G_op = lx.MatrixLinearOperator(G)

# Use lineax's CG solver
solver = lx.NormalCG(rtol=1e-6, atol=1e-6)
solver = lx.NormalCG(rtol=1e-3, atol=1e-3)
out = _multi_linear_solve(G_op, params.mu_z, solver)

# Updated beta
Expand All @@ -289,7 +289,8 @@ def _update_beta(G: ArrayLike, params: ModelParams_Design) -> ModelParams_Design
@dispatch
def _update_beta(G: SparseMatrix, params: ModelParams_Design) -> ModelParams_Design:
# Use lineax's CG solver
solver = lx.NormalCG(rtol=1e-6, atol=1e-6)
solver = lx.NormalCG(rtol=1e-3, atol=1e-3)

out = jax.vmap(lambda b: lx.linear_solve(G, b, solver), in_axes=1)(params.mu_z)

# Updated beta
Expand Down Expand Up @@ -325,17 +326,12 @@ def compute_elbo(
# calculation tip: tr(A @ A.T) = tr(A.T @ A) = sum(A ** 2)
# (X.T @ E[Z] @ E[W]) is p x p (big!); compute (E[W] @ X.T @ E[Z]) (k x k)
E_ll = (-0.5 * params.tau) * (
# params.ssq
jnp.sum(X**2)
- 2 * jnp.einsum("kp,np,nk->", E_W, X, params.mu_z) # tr(E[W] @ X.T @ E[Z])
# - 2 * jnp.sum((E_W @ X.T) @ params.mu_z)
+ jnp.einsum("ij,ji->", E_ZZ, E_WW) # tr(E[Z.T @ Z] @ E[W @ W.T])
) + 0.5 * n_dim * p_dim * jnp.log(
params.tau
) # -0.5 * n * p * log(1 / tau) = 0.5 * n * p * log(tau)
) + 0.5 * n_dim * p_dim * jnp.log(params.tau)

# neg-KL for Z
# negKL_z = 0.5 * (-jnp.trace(E_ZZ) + n_dim * z_dim + n_dim * _logdet(params.var_z))
Z_pred = G @ params.beta
negKL_z = -0.5 * (
jnp.trace(E_ZZ)
Expand Down Expand Up @@ -804,12 +800,8 @@ def susie_pca(
X = SparseMatrix(X, scale=standardize)
if isinstance(G, ArrayLike):
G = jnp.asarray(G)
# option to center the data
G -= jnp.mean(G, axis=0)
if standardize:
G /= jnp.std(G, axis=0)
elif isinstance(G, sparse.JAXSparse):
G = SparseMatrix(G, scale=standardize)
G = SparseMatrix(G)

# initialize PRNGkey and params
rng_key = random.PRNGKey(seed)
Expand Down

0 comments on commit 80f84b1

Please sign in to comment.