Skip to content

Commit

Permalink
Fix: updated finemapping methods
Browse files Browse the repository at this point in the history
  • Loading branch information
hlnicholls committed Oct 17, 2023
1 parent 20aeeea commit 2d15f5a
Show file tree
Hide file tree
Showing 8 changed files with 775 additions and 215 deletions.
31 changes: 23 additions & 8 deletions src/otg/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,20 +316,34 @@ class FMDataExtractionStepConfig:


@dataclass
class FinemappingStepConfig:
"""Fine-mapping step requirements.
class DentistStepConfig:
"""DENTIST outlier detection step requirements.
Attributes:
fm_filtered_LDMatrix_path (str): Path for extracted and filtered LD matrix for locus.
fm_filtered_StudyLocus_path (str): Path for extracted and filtered summary statistics for locus.
finemapped_locus_out (str): Output path for fine-mapping results for locus.
fm_filtered_StudyLocus_out (str): Output path for the DENTIST outlier-annotated summary statistics for locus.
"""

_target_: str = "otg.methods.finemapping.FinemappingStep"
fm_filtered_LDMatrix_path: str = MISSING
_target_: str = "otg.methods.dentist.DentistStep"
fm_filtered_StudyLocus_path: str = MISSING
finemapped_locus_out: str = MISSING
fm_filtered_StudyLocus_out: str = MISSING


@dataclass
class SuSiEStepConfig:
    """SuSiE fine-mapping step requirements.

    Attributes:
        fm_filtered_StudyLocus_path (str): Path for extracted and filtered summary statistics for locus.
        fm_filtered_LDMatrix_path (str): Path for extracted and filtered LD matrix for locus.
        n_sample (int): Sample size handed to the SuSiE step (presumably the GWAS sample size - confirm against SuSiEStep).
    """

    # Hydra instantiation target for this step.
    _target_: str = "otg.methods.susie.SuSiEStep"
    # All remaining fields are mandatory: MISSING forces the caller to supply them.
    fm_filtered_StudyLocus_path: str = MISSING
    fm_filtered_LDMatrix_path: str = MISSING
    n_sample: int = MISSING


# Register all configs
Expand All @@ -354,4 +368,5 @@ def register_configs() -> None:
)
cs.store(name="study_locus_overlap", group="step", node=StudyLocusOverlapStepConfig)
cs.store(name="fm_data_extraction", group="step", node=FMDataExtractionStepConfig)
cs.store(name="finemapping", group="step", node=FinemappingStepConfig)
cs.store(name="dentist", group="step", node=DentistStepConfig)
cs.store(name="susie", group="step", node=SuSiEStepConfig)
27 changes: 27 additions & 0 deletions src/otg/dataset/fm_data_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,30 @@ def get_matching_snps(
filtered_LDMatrix = filtered_LDMatrix.filter(filtered_LDMatrix['variantIdCol'].cast('string').isin(unique_variant_ids))
fm_filtered_LDMatrix = filtered_LDMatrix.drop("variantIdCol")
return fm_filtered_LDMatrix, fm_filtered_StudyLocus

def allele_flip_check(
    self: FMDataExtraction,
    fm_filtered_StudyLocus: DataFrame,
    SNP_ids_38: list,
) -> DataFrame:
    """Align summary-statistic z-scores with the alleles used by the LD matrix.

    The LD variant identifiers are split on ``_`` and the third and fourth
    fields are taken as the LD reference / alternate alleles. They are joined
    to the summary statistics on ``ID`` and the sign of ``z`` is inverted for
    every variant whose ``ref``/``alt`` pair disagrees with the LD alleles.

    Currently unused: not needed with gnomAD data.

    Args:
        fm_filtered_StudyLocus (DataFrame): Filtered summary statistics for the locus.
            Assumed to expose ``ID``, ``ref``, ``alt`` and ``z`` columns - confirm against the caller.
        SNP_ids_38 (list): Variant identifiers of the LD matrix, with the alleles in the
            third and fourth ``_``-separated fields.

    Returns:
        DataFrame: Summary statistics restricted to variants present in the LD matrix
            (inner join), with ``z`` sign-flipped where the alleles are discordant.
    """
    # Function-scope import so the block does not depend on which pyspark
    # aliases the module happens to import at the top of the file.
    from pyspark.sql import functions as F
    from pyspark.sql.types import StringType

    # One row per LD-matrix variant, annotated with the alleles encoded in its ID.
    ld_alleles = (
        self.session.createDataFrame(SNP_ids_38, StringType())
        .toDF("ID")
        .withColumn("id_parts", F.split(F.col("ID"), "_"))
        .withColumn("ref_LD", F.col("id_parts").getItem(2))
        .withColumn("alt_LD", F.col("id_parts").getItem(3))
        .drop("id_parts")
    )

    # Inner join: variants missing from the LD matrix cannot be checked and are dropped.
    joint_df = fm_filtered_StudyLocus.join(ld_alleles, "ID", "inner")

    # NOTE(review): any mismatch triggers a flip; a stricter check would flip
    # only when ref/alt are exactly swapped and discard other mismatches.
    discordant = (F.col("ref") != F.col("ref_LD")) | (F.col("alt") != F.col("alt_LD"))

    flipped = joint_df.withColumn(
        "z", F.when(discordant, -F.col("z")).otherwise(F.col("z"))
    )
    # Drop the helper columns so the output keeps the input schema.
    return flipped.drop("ref_LD", "alt_LD")
33 changes: 33 additions & 0 deletions src/otg/dentist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
"""Step to run DENTIST outlier detection on a study locus."""

from __future__ import annotations

from dataclasses import dataclass

from otg.common.session import Session
from otg.config import DentistStepConfig
from otg.method.dentist import Dentist


@dataclass
class DentistStep(DentistStepConfig):
    """DENTIST outlier detection for an input locus.

    Reads the filtered study-locus summary statistics, annotates them with the
    DENTIST statistics and outlier flag, and writes the result as parquet.

    Attributes:
        session (Session): Session used for logging, reading and writing.
        n_sample (int): Sample size forwarded to ``Dentist.calculate_dentist``.
        r2_threshold (float): Minimum ``R2`` with the lead SNP for a variant to be eligible as an outlier.
        lead_snp_ID (str | None): ``ID`` of the lead SNP. When ``None`` the variant with the
            largest ``|beta / se|`` in the locus is used.
        nlog10p_dentist_s_threshold (float): Minimum ``-log10(p)`` of the DENTIST statistic for an outlier.
    """

    session: Session = Session()
    # The extra knobs are defaulted so existing step configs keep instantiating.
    # NOTE(review): 0.6 / 4.0 are the thresholds commonly quoted for DENTIST-S - confirm.
    n_sample: int = 1
    r2_threshold: float = 0.6
    lead_snp_ID: str | None = None
    nlog10p_dentist_s_threshold: float = 4.0

    def run(self: DentistStep) -> None:
        """Run DENTIST outlier detection step.

        Raises:
            ValueError: If no lead SNP is configured and the input locus is empty.
        """
        self.session.logger.info(self.fm_filtered_StudyLocus_path)
        self.session.logger.info(self.fm_filtered_StudyLocus_out)

        StudyLocus_file_paths = [f"{self.fm_filtered_StudyLocus_path}*.snappy.parquet"]
        fm_filtered_StudyLocus = self.session.read.parquet(*StudyLocus_file_paths)

        lead_snp_ID = self.lead_snp_ID
        if lead_snp_ID is None:
            # Fall back to the most significant variant (largest absolute z-score).
            lead_row = (
                fm_filtered_StudyLocus.selectExpr("ID", "abs(beta / se) AS abs_z")
                .orderBy("abs_z", ascending=False)
                .first()
            )
            if lead_row is None:
                raise ValueError(
                    "Cannot determine the lead SNP: the input study locus is empty."
                )
            lead_snp_ID = lead_row["ID"]

        # calculate_dentist has no defaults, so every argument is passed
        # explicitly (by keyword, to be robust to parameter reordering).
        fm_filtered_Locus_outliers = Dentist.calculate_dentist(
            fm_filtered_StudyLocus,
            n_sample=self.n_sample,
            r2_threshold=self.r2_threshold,
            lead_snp_ID=lead_snp_ID,
            nlog10p_dentist_s_threshold=self.nlog10p_dentist_s_threshold,
        )

        # Write the output. calculate_dentist returns a plain Spark DataFrame,
        # so it is written directly (there is no ``.df`` wrapper attribute).
        fm_filtered_Locus_outliers.write.mode(self.session.write_mode).parquet(
            self.fm_filtered_StudyLocus_out
        )
47 changes: 0 additions & 47 deletions src/otg/finemapping.py

This file was deleted.

91 changes: 91 additions & 0 deletions src/otg/method/dentist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
"""DENTIST outlier detection method for study-locus summary statistics."""

from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

from otg.common.session import Session

if TYPE_CHECKING:
from pyspark.sql import DataFrame

import numpy as np
import pyspark.sql.functions as f
from pyspark.sql.types import DoubleType
from scipy import stats

from otg.common.session import Session


@dataclass
class Dentist:
    """DENTIST outlier detection.

    Flags variants whose z-score is inconsistent with the lead SNP's z-score
    given their LD with the lead SNP.

    NOTE(review): untested end to end, as it needs a study locus with an ``R2``
    column (LD for all variants with the lead SNP).
    """

    session: Session = Session()

    @staticmethod
    def calculate_dentist(
        filtered_StudyLocus: DataFrame,
        n_sample: int,
        r2_threshold: float,
        lead_snp_ID: str,
        nlog10p_dentist_s_threshold: float,
    ) -> DataFrame:
        """Perform outlier detection using DENTIST.

        For every variant the statistic
        ``t = (z - r * z_lead) ** 2 / (1 - r ** 2)`` is computed, converted to
        ``-log10(p)`` under a chi-square distribution with one degree of
        freedom, and the variant is flagged when both its ``R2`` and its
        ``-log10(p)`` exceed the supplied thresholds.

        Args:
            filtered_StudyLocus (DataFrame): Summary statistics with ``ID``, ``beta``, ``se`` and ``R2`` columns.
            n_sample (int): Sample size. Kept for interface compatibility; it cancelled out of
                the previous weighting expression and is not used for a single cohort.
            r2_threshold (float): Minimum ``R2`` with the lead SNP for a variant to be flagged.
            lead_snp_ID (str): ``ID`` of the lead SNP.
            nlog10p_dentist_s_threshold (float): Minimum ``-log10(p)`` for a variant to be flagged.

        Returns:
            DataFrame: Input rows with added ``r``, ``t_dentist_s``, ``nlog10p_dentist_s``
                and ``dentist_outlier`` (1/0) columns.

        Raises:
            ValueError: If ``lead_snp_ID`` is not present in ``filtered_StudyLocus``.
        """
        # Per-variant correlation with the lead SNP. Aggregates such as
        # f.sum / f.count are not valid inside withColumn without a window, and
        # a single locus-wide value would make the statistic meaningless.
        # NOTE(review): sqrt(R2) loses the sign of r; switch to a signed
        # correlation column if one becomes available.
        filtered_StudyLocus = filtered_StudyLocus.withColumn(
            "r", f.sqrt(f.col("R2"))
        )

        lead_rows = filtered_StudyLocus.filter(f.col("ID") == lead_snp_ID).take(1)
        if not lead_rows:
            raise ValueError(
                f"Lead SNP {lead_snp_ID!r} not found in the study locus."
            )
        lead_z = lead_rows[0]["beta"] / lead_rows[0]["se"]

        # DENTIST statistic; negative values (only possible when |r| > 1) are
        # mapped to +inf so that they always count as extreme.
        z_score = f.col("beta") / f.col("se")
        t_stat = ((z_score - f.col("r") * lead_z) ** 2) / (1 - f.col("r") ** 2)
        filtered_StudyLocus = filtered_StudyLocus.withColumn("t_dentist_s", t_stat)
        filtered_StudyLocus = filtered_StudyLocus.withColumn(
            "t_dentist_s",
            f.when(f.col("t_dentist_s") < 0, float("inf")).otherwise(
                f.col("t_dentist_s")
            ),
        )

        def calc_nlog10p_dentist_s(t_dentist_s):
            # The statistic is null where 1 - r**2 == 0 (e.g. the lead SNP itself).
            if t_dentist_s is None:
                return None
            # Cast to a builtin float: Spark cannot serialise numpy scalars
            # returned from a UDF declared as DoubleType.
            return float(stats.chi2.logsf(t_dentist_s, df=1) / -np.log(10))

        calc_nlog10p_dentist_s_udf = f.udf(calc_nlog10p_dentist_s, DoubleType())
        filtered_StudyLocus = filtered_StudyLocus.withColumn(
            "nlog10p_dentist_s", calc_nlog10p_dentist_s_udf("t_dentist_s")
        )

        # Single definition of the outlier rule, shared by the count and the flag.
        is_outlier = (f.col("R2") > r2_threshold) & (
            f.col("nlog10p_dentist_s") > nlog10p_dentist_s_threshold
        )

        n_dentist_s_outlier = filtered_StudyLocus.filter(is_outlier).count()
        print(f"Number of DENTIST outliers detected: {n_dentist_s_outlier}")

        return filtered_StudyLocus.withColumn(
            "dentist_outlier", f.when(is_outlier, 1).otherwise(0)
        )
Loading

0 comments on commit 2d15f5a

Please sign in to comment.