Skip to content

Commit

Permalink
Merge pull request #33 from mancusolab/dev
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
zeyunlu authored Nov 20, 2023
2 parents 7d9fa02 + 93806a1 commit b64f0aa
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 38 deletions.
17 changes: 14 additions & 3 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,9 @@
:target: https://pyscaffold.org/


======
SuShiE
======
========
SuShiE🍣
========
SuShiE (Sum of SHared sIngle Effect) is a Python software to fine-map causal SNPs, compute prediction weights, and infer effect size correlation across multiple ancestries. **The manuscript is in progress.**

.. code:: diff
Expand Down Expand Up @@ -106,6 +106,17 @@ It can perform:

See `here <https://mancusolab.github.io/sushie/>`_ for more details on how to use SuShiE.

If you want to use in-software SuShiE inference function, you can use following code as an example:

.. code:: python
from sushie.infer import infer_sushie
# Xs is for genotype data, and it should be a list of numpy array whose length is the number of ancestry.
# ys is for phenotype data, and it should also be a list of numpy array whose length is the number of ancestry.
infer_sushie(Xs=X, ys=y)
You can play it with your own ideas!

.. _Notes:
.. |Notes| replace:: **Notes**

Expand Down
5 changes: 5 additions & 0 deletions docs/manual.rst
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,11 @@ Parameters
- False
- ``--rint``
- Indicator to perform rank inverse normalization transformation (rint) for each phenotype data. Default is False (do not transform). Specify --rint will store 'True' value. We suggest users to do this QC during data preparation.
* - ``--no-reorder``
- Boolean
- False
- ``--no-reorder``
- Indicator to re-order single effects based on Frobenius norm of alpha-weighted posterior mean square. Default is False (to re-order). Specify --no-reorder will store 'True' value.
* - ``--meta``
- Boolean
- False
Expand Down
28 changes: 19 additions & 9 deletions sushie/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,7 +396,7 @@ def parameter_check(
)

if n_geno > 1:
log.logger.warning(
log.logger.info(
f"Detect {n_geno} genotypes, will only use one genotypes in the order of 'plink, vcf, and bgen'"
)

Expand Down Expand Up @@ -495,11 +495,6 @@ def parameter_check(
"The number of folds in cross validation is invalid."
+ " Choose some number greater than 1.",
)
elif args.cv_num > 5:
log.logger.warning(
"The number of folds in cross validation is too large."
+ " It may increase running time.",
)

if args.maf <= 0 or args.maf > 0.5:
raise ValueError(
Expand All @@ -508,7 +503,7 @@ def parameter_check(
)

if (args.meta or args.mega) and n_pop == 1:
log.logger.warning(
log.logger.info(
"The number of ancestry is 1, but --meta or --mega is specified. Will skip meta or mega SuSiE."
)

Expand Down Expand Up @@ -636,7 +631,7 @@ def process_raw(
)

if imp_num != 0:
log.logger.warning(
log.logger.debug(
f"Ancestry {idx + 1}: Impute {imp_num} out of {old_snp_num} SNPs with NAN value based on allele"
+ " frequency."
)
Expand Down Expand Up @@ -910,6 +905,7 @@ def sushie_wrapper(
purity=args.purity,
max_select=args.max_select,
min_snps=args.min_snps,
no_reorder=args.no_reorder,
seed=args.seed,
)
pips_all.append(tmp_result.pip_all[:, jnp.newaxis])
Expand All @@ -919,6 +915,7 @@ def sushie_wrapper(
pips_all = utils.make_pip(jnp.concatenate(pips_all, axis=1).T)
pips_cs = utils.make_pip(jnp.concatenate(pips_cs, axis=1).T)
else:
# normal sushie and mega sushie can use the same wrapper function
if mega:
log.logger.info(
f"Start fine-mapping using Mega SuSiE with {args.L} effects because --mega is specified."
Expand All @@ -944,6 +941,7 @@ def sushie_wrapper(
purity=args.purity,
max_select=args.max_select,
min_snps=args.min_snps,
no_reorder=args.no_reorder,
seed=args.seed,
)
result.append(tmp_result)
Expand Down Expand Up @@ -1337,7 +1335,7 @@ def build_finemap_parser(subp):
type=int,
help=(
"The minimum number of SNPs to fine-map. Default is 100.",
" It has to be positive integer number. A smaller number may produce weird results.",
" It has to be positive integer number.",
),
)

Expand All @@ -1363,6 +1361,18 @@ def build_finemap_parser(subp):
),
)

finemap.add_argument(
"--no-reorder",
default=False,
action="store_true",
help=(
"Indicator to re-order single effects based on Frobenius norm of alpha-weighted",
" posterior mean square.",
" Default is False (to re-order).",
" Specify --no-reorder will store 'True' value.",
),
)

# I/O option
finemap.add_argument(
"--meta",
Expand Down
72 changes: 49 additions & 23 deletions sushie/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,10 @@ class SushieResult(NamedTuple):
sample_size: The sample size for each ancestry in the inference.
elbo: The final ELBO.
elbo_increase: A boolean to indicate whether ELBO increases during the optimizations.
l_order: The original order that SuShiE infers. For example, if L=3 and it is 0,2,1,
then the original SuShiE's second effect (0-based index 1) is now third,
and the original SuShiE's third effect (0-based index 2) is now second
after sorting use Frobenius norm.
"""

Expand All @@ -82,6 +86,7 @@ class SushieResult(NamedTuple):
sample_size: Array
elbo: Array
elbo_increase: bool
l_order: Array


class _PriorAdjustor(NamedTuple):
Expand Down Expand Up @@ -172,6 +177,7 @@ def infer_sushie(
purity: float = 0.5,
max_select: int = 500,
min_snps: int = 100,
no_reorder: bool = False,
seed: int = 12345,
) -> SushieResult:
"""The main inference function for running SuShiE.
Expand All @@ -198,6 +204,8 @@ def infer_sushie(
purity: The minimum pairwise correlation across SNPs to be eligible as output credible set.
max_select: The maximum number of selected SNPs to compute purity.
min_snps: The minimum number of SNPs to fine-map.
no_reorder: Do not re-order single effects based on Frobenius norm of alpha-weighted posterior mean square.
Default is to re-order.
seed: The randomization seed for selecting SNPs in the credible set to compute purity.
Returns:
Expand Down Expand Up @@ -311,7 +319,6 @@ def infer_sushie(
raise ValueError(
f"The number of common SNPs across ancestries ({n_snps}) is less than minimum common "
+ "number of SNPs specified. Please expand the genomic window."
+ "We do not recommend to set --min-snps less than 50 as it may give weird results."
)

param_effect_var = effect_var
Expand Down Expand Up @@ -490,9 +497,12 @@ def infer_sushie(
+ f" Reach maximum iteration threshold {max_iter}.",
)

l_order = jnp.arange(L)
if not no_reorder:
priors, posteriors, l_order = _reorder_l(priors, posteriors)

cs, full_alphas, pip_all, pip_cs = make_cs(
posteriors.alpha,
posteriors.weighted_sum_covar,
Xs,
ns,
threshold,
Expand All @@ -511,6 +521,7 @@ def infer_sushie(
ns,
elbo_tracker,
elbo_increase,
l_order,
)


Expand Down Expand Up @@ -737,9 +748,31 @@ def _erss(X: ArrayLike, y: ArrayLike, beta: ArrayLike, beta_sq: ArrayLike) -> Ar
return term_1 + term_2


def _reorder_l(priors: Prior, posteriors: Posterior) -> Tuple[Prior, Posterior, Array]:

frob_norm = jnp.sum(
jnp.linalg.svd(posteriors.weighted_sum_covar, compute_uv=False), axis=1
)

# we want to reorder them based on the Frobenius norm
l_order = jnp.argsort(-frob_norm)

# priors effect_covar
priors = priors._replace(effect_covar=priors.effect_covar[l_order])

posteriors = posteriors._replace(
alpha=posteriors.alpha[l_order],
post_mean=posteriors.post_mean[l_order],
post_mean_sq=posteriors.post_mean_sq[l_order],
weighted_sum_covar=posteriors.weighted_sum_covar[l_order],
kl=posteriors.kl[l_order],
)

return priors, posteriors, l_order


def make_cs(
alpha: ArrayLike,
prior_covar: ArrayLike,
Xs: ArrayLike,
ns: ArrayLike,
threshold: float = 0.9,
Expand All @@ -752,8 +785,6 @@ def make_cs(
Args:
alpha: :math:`L \\times p` matrix that contains posterior probability for SNP to be causal
(i.e., :math:`\\alpha` in :ref:`Model`).
prior_covar: :math:`L \\times k \\times k` matrix that contains prior covariance for each credible set
(i.e., :math:`C` in :ref:`Model`).
Xs: Genotype data for multiple ancestries.
ns: Sample size for each ancestry.
threshold: The credible set threshold.
Expand All @@ -773,19 +804,11 @@ def make_cs(
rng_key = random.PRNGKey(seed)
n_l, n_snps = alpha.shape
t_alpha = pd.DataFrame(alpha.T).reset_index()
frob_norm = jnp.sum(jnp.linalg.svd(prior_covar, compute_uv=False), axis=1)

# we want to reorder them based on the Frobenius norm
new_order = jnp.argsort(-frob_norm)

cs = pd.DataFrame(columns=["CSIndex", "SNPIndex", "alpha", "c_alpha"])
full_alphas = t_alpha[["index"]]

# new CS index name
new_ldx = 0
for ldx in new_order.tolist():
new_ldx += 1

for ldx in range(n_l):
# select original index and alpha
tmp_pd = (
t_alpha[["index", ldx]]
Expand All @@ -801,9 +824,10 @@ def make_cs(
else:
select_idx = jnp.arange(n_row + 1)

# output CS Index is 1-based
tmp_cs = (
tmp_pd.iloc[select_idx, :]
.assign(CSIndex=new_ldx)
.assign(CSIndex=(ldx + 1))
.rename(columns={"csum": "c_alpha", "index": "SNPIndex", ldx: "alpha"})
)

Expand All @@ -812,8 +836,8 @@ def make_cs(
# prepare alphas table's entries
tmp_pd = tmp_pd.drop(["csum"], axis=1).rename(
columns={
"in_cs": f"in_cs_l{new_ldx}",
ldx: f"alpha_l{new_ldx}",
"in_cs": f"in_cs_l{ldx + 1}",
ldx: f"alpha_l{ldx + 1}",
}
)

Expand All @@ -838,17 +862,19 @@ def make_cs(
jnp.min(jnp.abs(ld), axis=(1, 2))[:, jnp.newaxis] * ss_weight
)

full_alphas[f"purity_l{new_ldx}"] = avg_corr
full_alphas[f"purity_l{ldx + 1}"] = avg_corr

if avg_corr > purity:
cs = pd.concat([cs, tmp_cs], ignore_index=True)
full_alphas[f"kept_l{new_ldx}"] = 1
full_alphas[f"kept_l{ldx + 1}"] = 1
else:
full_alphas[f"kept_l{new_ldx}"] = 0
full_alphas[f"kept_l{ldx + 1}"] = 0

pip_all = utils.make_pip(alpha)

# CSIndex is now 1-based
pip_cs = utils.make_pip(
alpha[new_order][
alpha[
(cs.CSIndex.unique().astype(int) - 1),
]
)
Expand All @@ -858,8 +884,8 @@ def make_cs(

if len(n_snp_cs) != len(n_snp_cs_unique):
log.logger.warning(
"Same SNPs appear in different credible set."
+ " This is considered weired results. It may be due to not enough SNPs to fine-map."
"Same SNPs appear in different credible set, which is very unusual."
+ " You may want to check this gene in details."
)

cs["pip_all"] = jnp.array([pip_all[idx] for idx in cs.SNPIndex.values.astype(int)])
Expand Down
20 changes: 17 additions & 3 deletions sushie/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,8 @@ def read_data(
)
else:
covar = None
# it has some warnings. It's okay to ingore them.

# it has some pycharm warnings. It's okay to ingore them.
# I couldn't think of a way to remove these warnings other than pre-specify them before for loops
# but the codes will look silly
tmp_bim = bim
Expand Down Expand Up @@ -372,11 +373,24 @@ def output_weights(
tmp_weights[cname_pip_all] = result[idx].pip_all
tmp_weights[cname_pip_cs] = result[idx].pip_cs
weights = pd.concat([weights, tmp_weights], axis=1)

df_cs = (
result[idx]
.cs[["SNPIndex", "CSIndex"]]
.groupby("SNPIndex")["CSIndex"]
.agg(lambda x: ",".join(x.astype(str)))
.reset_index()
)

# although for super rare cases, we have the same snp in more credible sets
# to record this situation in the weights file (we introduce WARNING in the inference function),
# we just concatenate the CS index with comma by creating this tmp_cs pandas data frame
tmp_cs = (
weights[["SNPIndex"]]
.merge(result[idx].cs[["SNPIndex", "CSIndex"]], on="SNPIndex", how="left")
.fillna(0)
.merge(df_cs, on="SNPIndex", how="left")
.fillna("No CS")
)

weights = weights.merge(
tmp_cs.rename(columns={"CSIndex": cname_cs}), on="SNPIndex"
)
Expand Down

0 comments on commit b64f0aa

Please sign in to comment.