Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: remove gene_index step #946

Draft
wants to merge 4 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 0 additions & 9 deletions docs/python_api/datasets/gene_index.md

This file was deleted.

9 changes: 9 additions & 0 deletions docs/python_api/datasets/target_index.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
---
title: Target Index
---

::: gentropy.dataset.target_index.TargetIndex

## Schema

--8<-- "assets/schemas/target_index.md"
5 changes: 0 additions & 5 deletions docs/python_api/steps/gene_index.md

This file was deleted.

5 changes: 5 additions & 0 deletions docs/python_api/steps/target_index.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
title: target_index
---

::: gentropy.target_index.TargetIndexStep
2 changes: 1 addition & 1 deletion src/gentropy/biosample_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(
cell_ontology_input_path (str): Input cell ontology dataset path.
uberon_input_path (str): Input uberon dataset path.
efo_input_path (str): Input efo dataset path.
biosample_index_path (str): Output gene index dataset path.
biosample_index_path (str): Output biosample index dataset path.
"""
cell_ontology_index = extract_ontology_from_json(
cell_ontology_input_path, session.spark
Expand Down
12 changes: 6 additions & 6 deletions src/gentropy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,12 @@ class ColocalisationConfig(StepConfig):


@dataclass
class GeneIndexConfig(StepConfig):
"""Gene index step configuration."""
class TargetIndexConfig(StepConfig):
"""Target index step configuration."""

target_path: str = MISSING
gene_index_path: str = MISSING
_target_: str = "gentropy.gene_index.GeneIndexStep"
target_index_path: str = MISSING
_target_: str = "gentropy.target_index.TargetIndexStep"


@dataclass
Expand Down Expand Up @@ -305,7 +305,7 @@ class LocusToGeneFeatureMatrixConfig(StepConfig):
variant_index_path: str | None = None
colocalisation_path: str | None = None
study_index_path: str | None = None
gene_index_path: str | None = None
target_index_path: str | None = None
feature_matrix_path: str = MISSING
features_list: list[str] = field(
default_factory=lambda: [
Expand Down Expand Up @@ -694,7 +694,7 @@ def register_config() -> None:
cs.store(group="step/session", name="base_session", node=SessionConfig)
cs.store(group="step", name="colocalisation", node=ColocalisationConfig)
cs.store(group="step", name="eqtl_catalogue", node=EqtlCatalogueConfig)
cs.store(group="step", name="gene_index", node=GeneIndexConfig)
cs.store(group="step", name="target_index", node=TargetIndexConfig)
cs.store(group="step", name="biosample_index", node=BiosampleIndexConfig)
cs.store(
group="step",
Expand Down
8 changes: 4 additions & 4 deletions src/gentropy/dataset/intervals.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from gentropy.common.Liftover import LiftOverSpark
from gentropy.common.schemas import parse_spark_schema
from gentropy.dataset.dataset import Dataset
from gentropy.dataset.gene_index import GeneIndex
from gentropy.dataset.target_index import TargetIndex

if TYPE_CHECKING:
from pyspark.sql import SparkSession
Expand All @@ -35,7 +35,7 @@ def from_source(
spark: SparkSession,
source_name: str,
source_path: str,
gene_index: GeneIndex,
target_index: TargetIndex,
lift: LiftOverSpark,
) -> Intervals:
"""Collect interval data for a particular source.
Expand All @@ -44,7 +44,7 @@ def from_source(
spark (SparkSession): Spark session
source_name (str): Name of the interval source
source_path (str): Path to the interval source file
gene_index (GeneIndex): Gene index
target_index (TargetIndex): Target index
lift (LiftOverSpark): LiftOverSpark instance to convert coordinates from hg37 to hg38

Returns:
Expand All @@ -70,4 +70,4 @@ def from_source(

source_class = source_to_class[source_name]
data = source_class.read(spark, source_path) # type: ignore
return source_class.parse(data, gene_index, lift) # type: ignore
return source_class.parse(data, target_index, lift) # type: ignore
28 changes: 14 additions & 14 deletions src/gentropy/dataset/l2g_features/colocalisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@

from gentropy.common.spark_helpers import convert_from_wide_to_long
from gentropy.dataset.colocalisation import Colocalisation
from gentropy.dataset.gene_index import GeneIndex
from gentropy.dataset.l2g_features.l2g_feature import L2GFeature
from gentropy.dataset.l2g_gold_standard import L2GGoldStandard
from gentropy.dataset.study_index import StudyIndex
from gentropy.dataset.study_locus import StudyLocus
from gentropy.dataset.target_index import TargetIndex
from gentropy.dataset.variant_index import VariantIndex

if TYPE_CHECKING:
Expand Down Expand Up @@ -74,7 +74,7 @@ def extend_missing_colocalisation_to_neighbourhood_genes(
feature_name: str,
local_features: DataFrame,
variant_index: VariantIndex,
gene_index: GeneIndex,
target_index: TargetIndex,
study_locus: StudyLocus,
) -> DataFrame:
"""This function creates an artificial dataset of features that represents the missing colocalisation to the neighbourhood genes.
Expand All @@ -83,7 +83,7 @@ def extend_missing_colocalisation_to_neighbourhood_genes(
feature_name (str): The name of the feature to extend
local_features (DataFrame): The dataframe of features to extend
variant_index (VariantIndex): Variant index containing all variant/gene relationships
gene_index (GeneIndex): Gene index to fetch the gene information
target_index (TargetIndex): Target index to fetch the gene information
study_locus (StudyLocus): Study locus to traverse between colocalisation and variant index

Returns:
Expand All @@ -94,7 +94,7 @@ def extend_missing_colocalisation_to_neighbourhood_genes(
"variantId", f.explode("transcriptConsequences").alias("tc")
)
.select(f.col("tc.targetId").alias("geneId"), "variantId")
.join(gene_index.df.select("geneId", "biotype"), "geneId", "left")
.join(target_index.df.select("geneId", "biotype"), "geneId", "left")
.filter(f.col("biotype") == "protein_coding")
.drop("biotype")
.distinct()
Expand Down Expand Up @@ -127,7 +127,7 @@ def common_neighbourhood_colocalisation_feature_logic(
*,
colocalisation: Colocalisation,
study_index: StudyIndex,
gene_index: GeneIndex,
target_index: TargetIndex,
study_locus: StudyLocus,
variant_index: VariantIndex,
) -> DataFrame:
Expand All @@ -141,7 +141,7 @@ def common_neighbourhood_colocalisation_feature_logic(
qtl_types (list[str] | str): The types of QTL to filter the data by
colocalisation (Colocalisation): Dataset with the colocalisation results
study_index (StudyIndex): Study index to fetch study type and gene
gene_index (GeneIndex): Gene index to add gene type
target_index (TargetIndex): Target index to add gene type
study_locus (StudyLocus): Study locus to traverse between colocalisation and study index
variant_index (VariantIndex): Variant index to annotate all overlapping genes

Expand All @@ -165,15 +165,15 @@ def common_neighbourhood_colocalisation_feature_logic(
local_feature_name,
local_max,
variant_index,
gene_index,
target_index,
study_locus,
)
)
return (
extended_local_max.join(
# Compute average score in the vicinity (feature will be the same for any gene associated with a studyLocus)
# (non protein coding genes in the vicinity are excluded see #3552)
gene_index.df.filter(f.col("biotype") == "protein_coding").select("geneId"),
target_index.df.filter(f.col("biotype") == "protein_coding").select("geneId"),
"geneId",
"inner",
)
Expand Down Expand Up @@ -242,7 +242,7 @@ class EQtlColocClppMaximumNeighbourhoodFeature(L2GFeature):
feature_dependency_type = [
Colocalisation,
StudyIndex,
GeneIndex,
TargetIndex,
StudyLocus,
VariantIndex,
]
Expand Down Expand Up @@ -333,7 +333,7 @@ class PQtlColocClppMaximumNeighbourhoodFeature(L2GFeature):
feature_dependency_type = [
Colocalisation,
StudyIndex,
GeneIndex,
TargetIndex,
StudyLocus,
VariantIndex,
]
Expand Down Expand Up @@ -423,7 +423,7 @@ class SQtlColocClppMaximumNeighbourhoodFeature(L2GFeature):
feature_dependency_type = [
Colocalisation,
StudyIndex,
GeneIndex,
TargetIndex,
StudyLocus,
VariantIndex,
]
Expand Down Expand Up @@ -513,7 +513,7 @@ class EQtlColocH4MaximumNeighbourhoodFeature(L2GFeature):
feature_dependency_type = [
Colocalisation,
StudyIndex,
GeneIndex,
TargetIndex,
StudyLocus,
VariantIndex,
]
Expand Down Expand Up @@ -603,7 +603,7 @@ class PQtlColocH4MaximumNeighbourhoodFeature(L2GFeature):
feature_dependency_type = [
Colocalisation,
StudyIndex,
GeneIndex,
TargetIndex,
StudyLocus,
VariantIndex,
]
Expand Down Expand Up @@ -693,7 +693,7 @@ class SQtlColocH4MaximumNeighbourhoodFeature(L2GFeature):
feature_dependency_type = [
Colocalisation,
StudyIndex,
GeneIndex,
TargetIndex,
StudyLocus,
VariantIndex,
]
Expand Down
16 changes: 8 additions & 8 deletions src/gentropy/dataset/l2g_features/distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
from pyspark.sql import Window

from gentropy.common.spark_helpers import convert_from_wide_to_long
from gentropy.dataset.gene_index import GeneIndex
from gentropy.dataset.l2g_features.l2g_feature import L2GFeature
from gentropy.dataset.l2g_gold_standard import L2GGoldStandard
from gentropy.dataset.study_locus import StudyLocus
from gentropy.dataset.target_index import TargetIndex
from gentropy.dataset.variant_index import VariantIndex

if TYPE_CHECKING:
Expand Down Expand Up @@ -85,7 +85,7 @@ def common_neighbourhood_distance_feature_logic(
variant_index: VariantIndex,
feature_name: str,
distance_type: str,
gene_index: GeneIndex,
target_index: TargetIndex,
genomic_window: int = 500_000,
) -> DataFrame:
"""Calculate the distance feature that correlates any variant in a credible set with any protein coding gene nearby the locus. The distance is weighted by the posterior probability of the variant to factor in its contribution to the trait.
Expand All @@ -95,7 +95,7 @@ def common_neighbourhood_distance_feature_logic(
variant_index (VariantIndex): The dataset containing distance to gene information
feature_name (str): The name of the feature
distance_type (str): The type of distance to gene
gene_index (GeneIndex): The dataset containing gene information
target_index (TargetIndex): The dataset containing gene information
genomic_window (int): The maximum window size to consider

Returns:
Expand All @@ -113,7 +113,7 @@ def common_neighbourhood_distance_feature_logic(
return (
# Then compute mean distance in the vicinity (feature will be the same for any gene associated with a studyLocus)
local_metric.join(
gene_index.df.filter(f.col("biotype") == "protein_coding").select("geneId"),
target_index.df.filter(f.col("biotype") == "protein_coding").select("geneId"),
"geneId",
"inner",
)
Expand Down Expand Up @@ -185,7 +185,7 @@ def compute(
class DistanceTssMeanNeighbourhoodFeature(L2GFeature):
"""Minimum mean distance to TSS for all genes in the vicinity of a studyLocus."""

feature_dependency_type = [VariantIndex, GeneIndex]
feature_dependency_type = [VariantIndex, TargetIndex]
feature_name = "distanceTssMeanNeighbourhood"

@classmethod
Expand Down Expand Up @@ -261,7 +261,7 @@ def compute(
class DistanceSentinelTssNeighbourhoodFeature(L2GFeature):
"""Distance between the sentinel variant and a gene TSS as a relation of the distance with all the genes in the vicinity of a studyLocus. This is not weighted by the causal probability."""

feature_dependency_type = [VariantIndex, GeneIndex]
feature_dependency_type = [VariantIndex, TargetIndex]
feature_name = "distanceSentinelTssNeighbourhood"

@classmethod
Expand Down Expand Up @@ -342,7 +342,7 @@ def compute(
class DistanceFootprintMeanNeighbourhoodFeature(L2GFeature):
"""Minimum mean distance to footprint for all genes in the vicinity of a studyLocus."""

feature_dependency_type = [VariantIndex, GeneIndex]
feature_dependency_type = [VariantIndex, TargetIndex]
feature_name = "distanceFootprintMeanNeighbourhood"

@classmethod
Expand Down Expand Up @@ -418,7 +418,7 @@ def compute(
class DistanceSentinelFootprintNeighbourhoodFeature(L2GFeature):
"""Distance between the sentinel variant and a gene footprint as a relation of the distance with all the genes in the vicinity of a studyLocus. This is not weighted by the causal probability."""

feature_dependency_type = [VariantIndex, GeneIndex]
feature_dependency_type = [VariantIndex, TargetIndex]
feature_name = "distanceSentinelFootprintNeighbourhood"

@classmethod
Expand Down
Loading
Loading