Skip to content

Commit

Permalink
Merge pull request #303 from monarch-initiative/302-add-roc-curves-and-precision-recall-curves-to-benchmarking-output
Browse files Browse the repository at this point in the history

302 add roc curves and precision recall curves to benchmarking output
  • Loading branch information
julesjacobsen authored Mar 18, 2024
2 parents 384e308 + 7f6753b commit 08a2333
Show file tree
Hide file tree
Showing 4 changed files with 204 additions and 17 deletions.
30 changes: 29 additions & 1 deletion src/pheval/analyse/binary_classification_stats.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from dataclasses import dataclass
from dataclasses import dataclass, field
from math import sqrt
from typing import List, Union

Expand Down Expand Up @@ -29,6 +29,8 @@ class BinaryClassificationStats:
true_negatives: int = 0
false_positives: int = 0
false_negatives: int = 0
labels: List = field(default_factory=list)
scores: List = field(default_factory=list)

@staticmethod
def remove_relevant_ranks(
Expand Down Expand Up @@ -84,6 +86,31 @@ def add_classification_for_other_entities(self, ranks: List[int]) -> None:
elif rank != 1:
self.true_negatives += 1

def add_labels_and_scores(
    self,
    pheval_results: Union[
        List[RankedPhEvalGeneResult],
        List[RankedPhEvalVariantResult],
        List[RankedPhEvalDiseaseResult],
    ],
    relevant_ranks: List[int],
) -> None:
    """
    Add binary labels and prediction scores from the PhEval results.

    Each result appends its score to ``self.scores`` and a binary label to
    ``self.labels``: 1 if the result's rank matches a not-yet-matched relevant
    rank, otherwise 0.

    Args:
        pheval_results (Union[List[RankedPhEvalGeneResult], List[RankedPhEvalVariantResult],
            List[RankedPhEvalDiseaseResult]]):
            List of all PhEval results
        relevant_ranks (List[int]): A list of the ranks associated with the known entities.
    """
    remaining_relevant_ranks = relevant_ranks.copy()
    for result in pheval_results:
        self.scores.append(result.score)
        if result.rank in remaining_relevant_ranks:
            self.labels.append(1)
            # Consume the matched rank so a duplicate rank is labelled 1 only once.
            remaining_relevant_ranks.remove(result.rank)
        else:
            self.labels.append(0)

def add_classification(
self,
pheval_results: Union[
Expand All @@ -105,6 +132,7 @@ def add_classification(
self.add_classification_for_other_entities(
self.remove_relevant_ranks(pheval_results, relevant_ranks)
)
self.add_labels_and_scores(pheval_results, relevant_ranks)

def sensitivity(self) -> float:
"""
Expand Down
79 changes: 79 additions & 0 deletions src/pheval/analyse/generate_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from sklearn.metrics import auc, precision_recall_curve, roc_curve

from pheval.analyse.benchmark_generator import (
BenchmarkRunOutputGenerator,
Expand Down Expand Up @@ -357,6 +358,82 @@ def _generate_non_cumulative_bar_plot_data(
]
)

def generate_roc_curve(
    self,
    benchmarking_results: List[BenchmarkRunResults],
    benchmark_generator: BenchmarkRunOutputGenerator,
):
    """
    Generate and plot Receiver Operating Characteristic (ROC) curves for binary classification benchmark results.

    Plots one ROC curve (with its AUC in the legend) per benchmark run, plus the
    diagonal chance line, and saves the figure as an SVG.

    Args:
        benchmarking_results (List[BenchmarkRunResults]): List of benchmarking results for multiple runs.
        benchmark_generator (BenchmarkRunOutputGenerator): Object containing benchmarking output generation details.
    """
    # Start a fresh figure so curves from previously generated plots do not
    # bleed into this one (mirrors generate_precision_recall).
    plt.figure()
    for i, benchmark_result in enumerate(benchmarking_results):
        fpr, tpr, _ = roc_curve(
            benchmark_result.binary_classification_stats.labels,
            benchmark_result.binary_classification_stats.scores,
            pos_label=1,
        )
        roc_auc = auc(fpr, tpr)

        plt.plot(
            fpr,
            tpr,
            label=f"{self.return_benchmark_name(benchmark_result)} ROC Curve (AUC = {roc_auc:.2f})",
            color=self.palette_hex_codes[i],
        )

    # Diagonal chance line (AUC = 0.5). The original call passed no data
    # (plt.plot(linestyle="--", color="gray")) and therefore drew nothing.
    plt.plot([0, 1], [0, 1], linestyle="--", color="gray")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("Receiver Operating Characteristic (ROC) Curve")
    plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15))
    plt.savefig(
        f"{benchmark_generator.prioritisation_type_file_prefix}_roc_curve.svg",
        format="svg",
        bbox_inches="tight",
    )

def generate_precision_recall(
    self,
    benchmarking_results: List[BenchmarkRunResults],
    benchmark_generator: BenchmarkRunOutputGenerator,
):
    """
    Generate and plot Precision-Recall curves for binary classification benchmark results.

    Plots one Precision-Recall curve (with its AUC in the legend) per benchmark run
    and saves the figure as an SVG.

    Args:
        benchmarking_results (List[BenchmarkRunResults]): List of benchmarking results for multiple runs.
        benchmark_generator (BenchmarkRunOutputGenerator): Object containing benchmarking output generation details.
    """
    plt.figure()
    for i, benchmark_result in enumerate(benchmarking_results):
        precision, recall, _ = precision_recall_curve(
            benchmark_result.binary_classification_stats.labels,
            benchmark_result.binary_classification_stats.scores,
        )
        precision_recall_auc = auc(recall, precision)
        plt.plot(
            recall,
            precision,
            label=f"{self.return_benchmark_name(benchmark_result)} Precision-Recall Curve "
            f"(AUC = {precision_recall_auc:.2f})",
            color=self.palette_hex_codes[i],
        )

    # NOTE: the original plt.plot(linestyle="--", color="gray") passed no data and
    # drew nothing; unlike ROC, a PR plot has no fixed diagonal baseline, so the
    # dead call is simply removed.
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title("Precision-Recall Curve")
    plt.legend(loc="upper center", bbox_to_anchor=(0.5, -0.15))
    plt.savefig(
        f"{benchmark_generator.prioritisation_type_file_prefix}_precision_recall_curve.svg",
        format="svg",
        bbox_inches="tight",
    )

def generate_non_cumulative_bar(
self,
benchmarking_results: List[BenchmarkRunResults],
Expand Down Expand Up @@ -418,6 +495,8 @@ def generate_plots(
title (str, optional): Title for the generated plot. Defaults to None.
"""
plot_generator = PlotGenerator()
plot_generator.generate_roc_curve(benchmarking_results, benchmark_generator)
plot_generator.generate_precision_recall(benchmarking_results, benchmark_generator)
if plot_type == "bar_stacked":
plot_generator.generate_stacked_bar_plot(benchmarking_results, benchmark_generator, title)
elif plot_type == "bar_cumulative":
Expand Down
105 changes: 90 additions & 15 deletions tests/test_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,12 @@ def test_assess_gene_prioritisation_no_threshold(self):
self.assertEqual(
self.binary_classification_stats,
BinaryClassificationStats(
true_positives=1, true_negatives=3, false_positives=0, false_negatives=0
true_positives=1,
true_negatives=3,
false_positives=0,
false_negatives=0,
labels=[1, 0, 0, 0],
scores=[0.8764, 0.5777, 0.5777, 0.3765],
),
)

Expand Down Expand Up @@ -620,7 +625,12 @@ def test_assess_gene_prioritisation_threshold_fails_ascending_order_cutoff(self)
self.assertEqual(
self.binary_classification_stats,
BinaryClassificationStats(
true_positives=0, true_negatives=3, false_positives=1, false_negatives=1
true_positives=0,
true_negatives=3,
false_positives=1,
false_negatives=1,
labels=[0, 0, 0, 0],
scores=[0.3765, 0.5777, 0.5777, 0.8764],
),
)

Expand Down Expand Up @@ -661,7 +671,12 @@ def test_assess_gene_prioritisation_threshold_meets_ascending_order_cutoff(self)
self.assertEqual(
self.binary_classification_stats,
BinaryClassificationStats(
true_positives=0, true_negatives=2, false_positives=1, false_negatives=1
true_positives=0,
true_negatives=2,
false_positives=1,
false_negatives=1,
labels=[0, 0, 0, 1],
scores=[0.3765, 0.5777, 0.5777, 0.8764],
),
)

Expand Down Expand Up @@ -702,7 +717,12 @@ def test_assess_gene_prioritisation_threshold_fails_cutoff(self):
self.assertEqual(
self.binary_classification_stats,
BinaryClassificationStats(
true_positives=0, true_negatives=3, false_positives=1, false_negatives=1
true_positives=0,
true_negatives=3,
false_positives=1,
false_negatives=1,
labels=[0, 0, 0, 0],
scores=[0.8764, 0.5777, 0.5777, 0.3765],
),
)

Expand Down Expand Up @@ -743,7 +763,12 @@ def test_assess_gene_prioritisation_threshold_meets_cutoff(self):
self.assertEqual(
self.binary_classification_stats,
BinaryClassificationStats(
true_positives=1, true_negatives=3, false_positives=0, false_negatives=0
true_positives=1,
true_negatives=3,
false_positives=0,
false_negatives=0,
labels=[1, 0, 0, 0],
scores=[0.8764, 0.5777, 0.5777, 0.3765],
),
)

Expand Down Expand Up @@ -962,7 +987,12 @@ def test_assess_variant_prioritisation_no_threshold(self):
self.assertEqual(
self.binary_classification_stats,
BinaryClassificationStats(
true_positives=1, true_negatives=0, false_positives=2, false_negatives=0
true_positives=1,
true_negatives=0,
false_positives=2,
false_negatives=0,
labels=[1, 0, 0],
scores=[0.0484, 0.0484, 0.0484],
),
)

Expand Down Expand Up @@ -1003,7 +1033,12 @@ def test_assess_variant_prioritisation_fails_ascending_order_cutoff(self):
self.assertEqual(
self.binary_classification_stats,
BinaryClassificationStats(
true_positives=0, true_negatives=0, false_positives=3, false_negatives=1
true_positives=0,
true_negatives=0,
false_positives=3,
false_negatives=1,
labels=[0, 0, 0],
scores=[0.0484, 0.0484, 0.0484],
),
)

Expand Down Expand Up @@ -1044,7 +1079,12 @@ def test_assess_variant_prioritisation_meets_ascending_order_cutoff(self):
self.assertEqual(
self.binary_classification_stats,
BinaryClassificationStats(
true_positives=1, true_negatives=0, false_positives=2, false_negatives=0
true_positives=1,
true_negatives=0,
false_positives=2,
false_negatives=0,
labels=[1, 0, 0],
scores=[0.0484, 0.0484, 0.0484],
),
)

Expand Down Expand Up @@ -1085,7 +1125,12 @@ def test_assess_variant_prioritisation_fails_cutoff(self):
self.assertEqual(
self.binary_classification_stats,
BinaryClassificationStats(
true_positives=0, true_negatives=0, false_positives=3, false_negatives=1
true_positives=0,
true_negatives=0,
false_positives=3,
false_negatives=1,
labels=[0, 0, 0],
scores=[0.0484, 0.0484, 0.0484],
),
)

Expand Down Expand Up @@ -1126,7 +1171,12 @@ def test_assess_variant_prioritisation_meets_cutoff(self):
self.assertEqual(
self.binary_classification_stats,
BinaryClassificationStats(
true_positives=1, true_negatives=0, false_positives=2, false_negatives=0
true_positives=1,
true_negatives=0,
false_positives=2,
false_negatives=0,
labels=[1, 0, 0],
scores=[0.0484, 0.0484, 0.0484],
),
)

Expand Down Expand Up @@ -1368,7 +1418,12 @@ def test_assess_disease_prioritisation_no_threshold(self):
self.assertEqual(
self.binary_classification_stats,
BinaryClassificationStats(
true_positives=1, true_negatives=3, false_positives=0, false_negatives=0
true_positives=1,
true_negatives=3,
false_positives=0,
false_negatives=0,
labels=[1, 0, 0, 0],
scores=[1.0, 0.5, 0.5, 0.3],
),
)

Expand Down Expand Up @@ -1404,7 +1459,12 @@ def test_assess_disease_prioritisation_threshold_fails_ascending_order_cutoff(se
self.assertEqual(
self.binary_classification_stats,
BinaryClassificationStats(
true_positives=0, true_negatives=3, false_positives=1, false_negatives=1
true_positives=0,
true_negatives=3,
false_positives=1,
false_negatives=1,
labels=[0, 0, 0, 0],
scores=[0.3765, 0.5777, 0.5777, 0.8764],
),
)

Expand Down Expand Up @@ -1440,7 +1500,12 @@ def test_assess_disease_prioritisation_threshold_meets_ascending_order_cutoff(se
self.assertEqual(
self.binary_classification_stats,
BinaryClassificationStats(
true_positives=0, true_negatives=2, false_positives=1, false_negatives=1
true_positives=0,
true_negatives=2,
false_positives=1,
false_negatives=1,
labels=[0, 0, 0, 1],
scores=[0.3765, 0.5777, 0.5777, 0.8764],
),
)

Expand Down Expand Up @@ -1476,7 +1541,12 @@ def test_assess_disease_prioritisation_threshold_fails_cutoff(self):
self.assertEqual(
self.binary_classification_stats,
BinaryClassificationStats(
true_positives=0, true_negatives=3, false_positives=1, false_negatives=1
true_positives=0,
true_negatives=3,
false_positives=1,
false_negatives=1,
labels=[0, 0, 0, 0],
scores=[1.0, 0.5, 0.5, 0.3],
),
)

Expand Down Expand Up @@ -1512,6 +1582,11 @@ def test_assess_disease_prioritisation_threshold_meets_cutoff(self):
self.assertEqual(
self.binary_classification_stats,
BinaryClassificationStats(
true_positives=1, true_negatives=3, false_positives=0, false_negatives=0
true_positives=1,
true_negatives=3,
false_positives=0,
false_negatives=0,
labels=[1, 0, 0, 0],
scores=[1.0, 0.5, 0.5, 0.3],
),
)
7 changes: 6 additions & 1 deletion tests/test_binary_classification_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,12 @@ def test_add_classification(self):
self.assertEqual(
self.binary_classification_stats,
BinaryClassificationStats(
true_positives=1, true_negatives=2, false_positives=0, false_negatives=1
true_positives=1,
true_negatives=2,
false_positives=0,
false_negatives=1,
labels=[1, 1, 0, 0],
scores=[1.0, 0.5, 0.5, 0.3],
),
)

Expand Down

0 comments on commit 08a2333

Please sign in to comment.