Merge branch 'main' into do_patch
DSuveges authored Dec 12, 2023
2 parents 6a112a6 + 243d8f6 commit 7f94aed
Showing 2 changed files with 68 additions and 21 deletions.
88 changes: 67 additions & 21 deletions src/otg/method/l2g/model.py
@@ -14,6 +14,7 @@
 from pyspark.ml.feature import StringIndexer, VectorAssembler
 from pyspark.ml.tuning import ParamGridBuilder
 from wandb.wandb_run import Run
+from xgboost.spark.core import SparkXGBClassifierModel

 from otg.dataset.l2g_feature_matrix import L2GFeatureMatrix
 from otg.method.l2g.evaluator import WandbEvaluator
@@ -31,6 +32,7 @@ class LocusToGeneModel:
     estimator: Any = None
     pipeline: Pipeline = Pipeline(stages=[])
     model: PipelineModel | None = None
+    wandb_l2g_project_name: str = "otg_l2g"

     def __post_init__(self: LocusToGeneModel) -> None:
         """Post init that adds the model to the ML pipeline."""
@@ -98,29 +100,39 @@ def features_vector_assembler(features_cols: list[str]) -> VectorAssembler:
             .setOutputCol("features")
         )

-    @staticmethod
     def log_to_wandb(
+        self: LocusToGeneModel,
         results: DataFrame,
-        binary_evaluator: BinaryClassificationEvaluator,
-        multi_evaluator: MulticlassClassificationEvaluator,
+        training_data: L2GFeatureMatrix,
+        evaluators: list[
+            BinaryClassificationEvaluator | MulticlassClassificationEvaluator
+        ],
         wandb_run: Run,
     ) -> None:
-        """Perform evaluation of the model by applying it to a test set and tracking the results with W&B.
+        """Log evaluation results and feature importance to W&B.

         Args:
             results (DataFrame): Dataframe containing the predictions
-            binary_evaluator (BinaryClassificationEvaluator): Binary evaluator
-            multi_evaluator (MulticlassClassificationEvaluator): Multiclass evaluator
+            training_data (L2GFeatureMatrix): Training data used for the model. If provided, the table and the number of positive and negative labels will be logged to W&B
+            evaluators (list[BinaryClassificationEvaluator | MulticlassClassificationEvaluator]): List of Spark ML evaluators to use for evaluation
             wandb_run (Run): W&B run to log the results to
         """
-        binary_wandb_evaluator = WandbEvaluator(
-            spark_ml_evaluator=binary_evaluator, wandb_run=wandb_run
-        )
-        binary_wandb_evaluator.evaluate(results)
-        multi_wandb_evaluator = WandbEvaluator(
-            spark_ml_evaluator=multi_evaluator, wandb_run=wandb_run
-        )
-        multi_wandb_evaluator.evaluate(results)
+        ## Track evaluation metrics
+        for evaluator in evaluators:
+            wandb_evaluator = WandbEvaluator(
+                spark_ml_evaluator=evaluator, wandb_run=wandb_run
+            )
+            wandb_evaluator.evaluate(results)
+        ## Track feature importance
+        wandb_run.log({"importances": self.get_feature_importance()})
+        ## Track training set metadata
+        gs_counts_dict = {
+            "goldStandard" + row["goldStandardSet"].capitalize(): row["count"]
+            for row in training_data.df.groupBy("goldStandardSet").count().collect()
+        }
+        wandb_run.log(gs_counts_dict)
+        training_table = wandb.Table(dataframe=training_data.df.toPandas())
+        wandb_run.log({"trainingSet": training_table})

     @classmethod
     def load_from_disk(
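
The refactored log_to_wandb replaces the two hard-coded evaluator blocks with a loop over an arbitrary list of Spark ML evaluators (each wrapped in the project's WandbEvaluator) and additionally logs feature importances and training-set metadata. Below is a minimal sketch of the underlying evaluator-loop pattern using plain Spark evaluators and a hypothetical predictions DataFrame; it omits the W&B wrapping, so it is an illustration rather than the project's exact code path.

# Illustrative sketch only: evaluating one predictions DataFrame with a list of
# Spark ML evaluators; the data below is hypothetical.
from pyspark.ml.evaluation import (
    BinaryClassificationEvaluator,
    MulticlassClassificationEvaluator,
)
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
results = spark.createDataFrame(
    [(1.0, 0.9, 1.0), (0.0, 0.2, 0.0), (1.0, 0.4, 0.0)],
    ["label", "rawPrediction", "prediction"],
)
evaluators = [
    BinaryClassificationEvaluator(rawPredictionCol="rawPrediction", labelCol="label"),
    MulticlassClassificationEvaluator(predictionCol="prediction", labelCol="label"),
]
for evaluator in evaluators:
    # Default metrics: areaUnderROC for the binary evaluator, f1 for the multiclass one.
    print(type(evaluator).__name__, evaluator.evaluate(results))
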
Expand Down Expand Up @@ -189,13 +201,15 @@ def evaluate(
results: DataFrame,
hyperparameters: dict[str, Any],
wandb_run_name: str | None,
training_data: L2GFeatureMatrix | None = None,
) -> None:
"""Perform evaluation of the model predictions for the test set and track the results with W&B.
Args:
results (DataFrame): Dataframe containing the predictions
hyperparameters (dict[str, Any]): Hyperparameters used for the model
wandb_run_name (str | None): Descriptive name for the run to be tracked with W&B
training_data (L2GFeatureMatrix | None): Training data used for the model. If provided, the ratio of positive to negative labels will be logged to W&B
"""
binary_evaluator = BinaryClassificationEvaluator(
rawPredictionCol="rawPrediction", labelCol="label"
@@ -226,20 +240,52 @@ def evaluate(
             multi_evaluator.evaluate(results, {multi_evaluator.metricName: "f1"}),
         )

-        if wandb_run_name:
+        if wandb_run_name and training_data:
             print("Logging to W&B...")
             run = wandb.init(
-                project="otg_l2g", config=hyperparameters, name=wandb_run_name
+                project=self.wandb_l2g_project_name,
+                config=hyperparameters,
+                name=wandb_run_name,
             )
             if isinstance(run, Run):
-                LocusToGeneModel.log_to_wandb(
-                    results, binary_evaluator, multi_evaluator, run
+                self.log_to_wandb(
+                    results, training_data, [binary_evaluator, multi_evaluator], run
                 )
                 run.finish()

-    def plot_importance(self: LocusToGeneModel) -> None:
-        """Plot the feature importance of the model."""
-        # xgb_plot_importance(self)  # FIXME: What is the attribute that stores the model?
+    @property
+    def feature_name_map(self: LocusToGeneModel) -> dict[str, str]:
+        """Return a dictionary mapping encoded feature names to the original names.
+
+        Returns:
+            dict[str, str]: Feature name map of the model
+
+        Raises:
+            ValueError: If the model has not been fitted yet
+        """
+        if not self.model:
+            raise ValueError("Model not fitted yet. `fit()` has to be called first.")
+        elif isinstance(self.model.stages[1], VectorAssembler):
+            feature_names = self.model.stages[1].getInputCols()
+        return {f"f{i}": feature_name for i, feature_name in enumerate(feature_names)}
+
+    def get_feature_importance(self: LocusToGeneModel) -> dict[str, float]:
+        """Return dictionary with relative importances of every feature in the model. Feature names are encoded and have to be mapped back to their original names.
+
+        Returns:
+            dict[str, float]: Dictionary mapping feature names to their importance
+
+        Raises:
+            ValueError: If the model has not been fitted yet or is not an XGBoost model
+        """
+        if not self.model or not isinstance(
+            self.model.stages[-1], SparkXGBClassifierModel
+        ):
+            raise ValueError(
+                f"Model type {type(self.model)} not supported for feature importance."
+            )
+        importance_map = self.model.stages[-1].get_feature_importances()
+        return {self.feature_name_map[k]: v for k, v in importance_map.items()}

     def fit(
         self: LocusToGeneModel,
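
The new feature_name_map and get_feature_importance methods rely on SparkXGBClassifierModel.get_feature_importances() returning scores keyed by positional feature names ("f0", "f1", ...), which follow the order of the VectorAssembler's input columns. A self-contained sketch of that mapping step, with hypothetical feature names and scores (not part of the commit):

# Illustrative sketch only: map XGBoost's encoded feature keys back to column names.
feature_cols = ["distanceTssMean", "eqtlColocClppMaximum", "vepMaximum"]  # hypothetical
importance_map = {"f0": 0.52, "f1": 0.31, "f2": 0.17}  # hypothetical importances

feature_name_map = {f"f{i}": name for i, name in enumerate(feature_cols)}
named_importances = {feature_name_map[k]: v for k, v in importance_map.items()}
print(named_importances)
# {'distanceTssMean': 0.52, 'eqtlColocClppMaximum': 0.31, 'vepMaximum': 0.17}
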
1 change: 1 addition & 0 deletions src/otg/method/l2g/trainer.py
@@ -53,6 +53,7 @@ def train(
             results=model.predict(test),
             hyperparameters=hyperparams,
             wandb_run_name=wandb_run_name,
+            training_data=train,
         )
         if model_path:
             l2g_model.save(model_path)
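
Passing training_data=train here is what lets evaluate() forward the training feature matrix to log_to_wandb, which logs the gold-standard label counts alongside the metrics. A minimal sketch of that counting step with a hypothetical DataFrame (the real call works on the underlying df of the L2GFeatureMatrix wrapper):

# Illustrative sketch only: deriving goldStandardSet label counts before logging.
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
train_df = spark.createDataFrame(
    [("positive",), ("positive",), ("negative",)], ["goldStandardSet"]
)
gs_counts_dict = {
    "goldStandard" + row["goldStandardSet"].capitalize(): row["count"]
    for row in train_df.groupBy("goldStandardSet").count().collect()
}
print(gs_counts_dict)  # e.g. {'goldStandardPositive': 2, 'goldStandardNegative': 1}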
