test: adding tests for trans qtl flagging
DSuveges committed Jan 15, 2025
1 parent 45ff1be commit fc16fee
Showing 5 changed files with 196 additions and 3 deletions.
6 changes: 6 additions & 0 deletions src/gentropy/assets/schemas/study_locus.json
@@ -247,6 +247,12 @@
"name": "confidence",
"nullable": true,
"type": "string"
},
{
"metadata": {},
"name": "isTransQtl",
"nullable": true,
"type": "boolean"
}
],
"type": "struct"
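For context, a minimal sketch (not part of this commit) of checking that the new field is exposed by the parsed schema; it assumes StudyLocus.get_schema() can be called on the class, as the study_locus.py diff below does via self.get_schema():

# Illustrative only: the new nullable boolean should appear among the schema fields.
from gentropy.dataset.study_locus import StudyLocus

schema = StudyLocus.get_schema()
assert "isTransQtl" in schema.fieldNames()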
2 changes: 2 additions & 0 deletions src/gentropy/config.py
@@ -667,9 +667,11 @@ class StudyLocusValidationStepConfig(StepConfig):

study_index_path: str = MISSING
study_locus_path: list[str] = MISSING
target_index_path: str = MISSING
valid_study_locus_path: str = MISSING
invalid_study_locus_path: str = MISSING
invalid_qc_reasons: list[str] = MISSING
trans_qtl_threshold: int = MISSING
_target_: str = "gentropy.study_locus_validation.StudyLocusValidationStep"


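A hypothetical sketch of supplying the two new config fields; the values are placeholders, and the dataclass-style construction is an assumption about how StepConfig subclasses are instantiated:

from gentropy.config import StudyLocusValidationStepConfig

# Hypothetical sketch: only the new fields are shown; the other required paths are omitted.
config = StudyLocusValidationStepConfig(
    target_index_path="path/to/target_index",  # placeholder path
    trans_qtl_threshold=5_000_000,  # distance in bp above which a QTL is flagged as trans
)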
96 changes: 95 additions & 1 deletion src/gentropy/dataset/study_locus.py
@@ -8,7 +8,7 @@

import numpy as np
import pyspark.sql.functions as f
from pyspark.sql.types import ArrayType, FloatType, StringType
from pyspark.sql.types import ArrayType, FloatType, LongType, StringType

from gentropy.common.genomic_region import GenomicRegion, KnownGenomicRegions
from gentropy.common.schemas import parse_spark_schema
@@ -681,6 +681,100 @@ def get_QC_mappings(cls: type[StudyLocus]) -> dict[str, str]:
"""
return {member.name: member.value for member in StudyLocusQualityCheck}

def flag_trans_qtls(
self: StudyLocus,
study_index: StudyIndex,
target_index: DataFrame,
trans_threshold: int = 5_000_000,
) -> StudyLocus:
"""Flagging transQTL credible sets based on genomic location of the measured gene.
Process:
1. Enrich study-locus dataset with geneId based on study metadata. (only QTL studies are considered)
2. Enrich with transcription start site and chromosome of the studied gegne.
3. Flagging any tagging variant of QTL credible sets, if chromosome is different from the gene or distance is above the threshold.
4. Propagate flags to credible sets where any tags are considered as trans.
5. Return study locus object with annotation stored in 'isTransQtl` boolean column, where gwas credible sets will be `null`
Args:
study_index (StudyIndex): Study index to extract the identifier of the measured gene.
target_index (DataFrame): Target index dataframe.
trans_threshold (int): Distance above which the QTL is considered trans. Default: 5_000_000 bp.
Returns:
StudyLocus: New column added indicating whether the QTL credible sets are trans.
"""
# As the `geneId` column in the study index is optional, we have to test for that:
if "geneId" not in study_index.df.columns:
return self

# Process study index:
processed_studies = (
study_index.df
# Dropping GWAS studies. This ensures that only QTLs will have the "isTransQtl" annotation:
.filter(f.col("studyType") != "gwas").select(
"studyId", "geneId", "projectId"
)
)

# Process study locus:
processed_credible_set = (
self.df
# Exploding locus to test all tag variants:
.withColumn("locus", f.explode("locus")).select(
"studyLocusId",
"studyId",
f.split("locus.variantId", "_")[0].alias("chromosome"),
f.split("locus.variantId", "_")[1].cast(LongType()).alias("position"),
)
)

# Process target index:
processed_targets = target_index.select(
f.col("id").alias("geneId"),
# Depending on the orientation of the transcript the transcription start site is either the start or end position:
f.when(
f.col("canonicalTranscript.strand") == "+",
f.col("canonicalTranscript.start"),
)
.when(
f.col("canonicalTranscript.strand") == "-",
f.col("canonicalTranscript.end"),
)
.alias("tss"),
f.col("canonicalTranscript.chromosome").alias("geneChromosome"),
)

# Pool datasets:
joined_data = (
processed_credible_set
# Join processed studies:
.join(processed_studies, on="studyId", how="inner")
# Join processed targets:
.join(processed_targets, on="geneId", how="left")
# Assign True/False for QTL studies:
.withColumn(
"isTagTrans",
# The QTL signal is considered trans if the locus is on a different chromosome than the measured gene.
# OR the distance from the gene's transcription start site is > threshold.
f.when(
(f.col("chromosome") != f.col("geneChromosome"))
| (f.abs(f.col("tss") - f.col("position")) > trans_threshold),
f.lit(True),
).otherwise(f.lit(False)),
)
.groupby("studyLocusId")
.agg(
# If any of the tags of a locus is in trans position, the QTL is considered trans:
f.array_contains(f.collect_set("isTagTrans"), True).alias("isTransQtl")
)
)
# Adding new column, where the value is null for gwas loci:
return StudyLocus(
_df=self.df.join(joined_data, on="studyLocusId", how="left"),
_schema=self.get_schema(),
)

def filter_credible_set(
self: StudyLocus,
credible_interval: CredibleInterval,
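A minimal usage sketch of the new method; the session variable and dataset paths are placeholders, and the from_parquet constructors mirror how the validation step below loads its inputs:

# Illustrative usage (paths are placeholders):
credible_sets = StudyLocus.from_parquet(session, "path/to/credible_sets")
study_index = StudyIndex.from_parquet(session, "path/to/study_index")
target_index = session.spark.read.parquet("path/to/target_index")

flagged = credible_sets.flag_trans_qtls(study_index, target_index, trans_threshold=5_000_000)
# QTL credible sets get True/False; GWAS credible sets keep a null flag:
flagged.df.select("studyLocusId", "isTransQtl").show()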
11 changes: 9 additions & 2 deletions src/gentropy/study_locus_validation.py
@@ -17,24 +17,29 @@ class StudyLocusValidationStep:
def __init__(
self,
session: Session,
study_index_path: str,
study_locus_path: list[str],
study_index_path: str,
target_index_path: str,
valid_study_locus_path: str,
invalid_study_locus_path: str,
trans_qtl_threshold: int,
invalid_qc_reasons: list[str] = [],
) -> None:
"""Initialize step.
Args:
session (Session): Session object.
study_index_path (str): Path to study index file.
study_locus_path (list[str]): Path to study locus dataset.
study_index_path (str): Path to study index file.
target_index_path (str): Path to the target index.
valid_study_locus_path (str): Path to write the valid records.
invalid_study_locus_path (str): Path to write the invalid records.
trans_qtl_threshold (int): Genomic distance above which a QTL is considered trans.
invalid_qc_reasons (list[str]): List of invalid quality check reason names from `StudyLocusQualityCheck` (e.g. ['SUBSIGNIFICANT_FLAG']).
"""
# Reading datasets:
study_index = StudyIndex.from_parquet(session, study_index_path)
target_index = session.spark.read.parquet(target_index_path)

# Running validation then writing output:
study_locus_with_qc = (
@@ -54,6 +59,8 @@ def __init__(
)
# Annotate credible set confidence:
.assign_confidence()
# Flagging trans qtls:
.flag_trans_qtls(study_index, target_index, trans_qtl_threshold)
).persist() # we will need this for 2 types of outputs

# Valid study locus partitioned to simplify the finding of overlaps
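For completeness, a hypothetical end-to-end invocation of the updated step with the new arguments; every path is a placeholder:

StudyLocusValidationStep(
    session=session,
    study_locus_path=["path/to/credible_sets"],
    study_index_path="path/to/study_index",
    target_index_path="path/to/target_index",
    valid_study_locus_path="output/valid_credible_sets",
    invalid_study_locus_path="output/invalid_credible_sets",
    trans_qtl_threshold=5_000_000,
    invalid_qc_reasons=["SUBSIGNIFICANT_FLAG"],
)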
84 changes: 84 additions & 0 deletions tests/gentropy/dataset/test_study_locus.py
@@ -1189,3 +1189,87 @@ def test_duplication_flag_correctness(
assert self.validated.df.filter(f.size("qualityControls") == 0).count() == 2

assert self.validated.df.filter(f.size("qualityControls") > 0).count() == 2


class TestTransQtlFlagging:
"""Test flagging trans qtl credible sets."""

THRESHOLD = 30
STUDY_LOCUS_DATA = [
# QTL in cis position -> flag: False
("sl1", "c1_50", "s1"),
# QTL in trans position (by distance) -> flag: True
("sl2", "c1_100", "s1"),
# QTL in trans position (by chromosome) -> flag: True
("sl3", "c2_50", "s1"),
# Not a QTL -> flag: Null
("sl4", "c1_50", "s2"),
]

STUDY_LOCUS_COLUMNS = ["studyLocusId", "variantId", "studyId"]

STUDY_DATA = [
("s1", "p1", "qtl", "g1"),
("s2", "p2", "gwas", None),
]

STUDY_COLUMNS = ["studyId", "projectId", "studyType", "geneId"]

GENE_DATA = [("g1", "-", "10", "30", "c1")]
GENE_COLUMNS = ["id", "strand", "start", "end", "chromosome"]

@pytest.fixture(autouse=True)
def _setup(self: TestTransQtlFlagging, spark: SparkSession) -> None:
"""Setup study locus for testing."""
self.study_locus = StudyLocus(
_df=(
spark.createDataFrame(
self.STUDY_LOCUS_DATA, self.STUDY_LOCUS_COLUMNS
).withColumn("locus", f.array(f.struct("variantId")))
)
)
self.study_index = StudyIndex(
_df=spark.createDataFrame(self.STUDY_DATA, self.STUDY_COLUMNS)
)

self.target_index = spark.createDataFrame(
self.GENE_DATA, self.GENE_COLUMNS
).withColumn(
"canonicalTranscript", f.struct("strand", "start", "end", "chromosome")
)

self.qtl_flagged = self.study_locus.flag_trans_qtls(
self.study_index, self.target_index, self.THRESHOLD
)

def test_return_type(self: TestTransQtlFlagging) -> None:
"""Test duplication flagging return type."""
assert isinstance(self.qtl_flagged, StudyLocus)

def test_number_of_rows(self: TestTransQtlFlagging) -> None:
"""Test duplication flagging no data loss."""
assert self.qtl_flagged.df.count() == self.study_locus.df.count()

def test_column_added(self: TestTransQtlFlagging) -> None:
"""Test duplication flagging no data loss."""
assert "isTransQtl" in self.qtl_flagged.df.columns

def test_correctness_no_gwas_flagged(self: TestTransQtlFlagging) -> None:
"""Make sure the flag is null for gwas credible sets."""
gwas_studies = self.study_index.df.filter(f.col("studyId") == "s2")

assert (
self.qtl_flagged.df.join(gwas_studies, on="studyId", how="inner")
.filter(f.col("isTransQtl").isNotNull())
.count()
) == 0

def test_correctness_all_qtls_are_flagged(self: TestTransQtlFlagging) -> None:
"""Make sure all QTLs have non-null flags."""
assert self.qtl_flagged.df.filter(f.col("isTransQtl").isNotNull()).count() == 3

def test_correctness_found_trans(self: TestTransQtlFlagging) -> None:
"""Make sure trans qtls are flagged."""
assert (
self.qtl_flagged.df.filter(f.col("isTransQtl")).count() == 2
), "Expected number of rows differ from observed."
