Skip to content

Commit

Permalink
TabPFN: Add Regression and option to have no labels for test data (#1567
Browse files Browse the repository at this point in the history
)

* add regression and plots

* add filter

* fix linting

* update help

* fix linting

* move boolean param down
  • Loading branch information
anuprulez authored Jan 17, 2025
1 parent 0544107 commit e87b82b
Show file tree
Hide file tree
Showing 10 changed files with 563 additions and 37 deletions.
111 changes: 91 additions & 20 deletions tools/tabpfn/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,15 @@
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, average_precision_score, precision_recall_curve
from tabpfn import TabPFNClassifier
from sklearn.metrics import (
average_precision_score,
precision_recall_curve,
r2_score,
root_mean_squared_error
)
from tabpfn import TabPFNClassifier, TabPFNRegressor


def separate_features_labels(data):
Expand All @@ -17,38 +23,103 @@ def separate_features_labels(data):
return features, labels


def classification_plot(xval, yval, leg_label, title, xlabel, ylabel):
    """
    Draw a single labelled curve and save it as "output_plot.png".

    xval, yval: coordinates of the curve.
    leg_label: text shown in the legend for the curve.
    title, xlabel, ylabel: figure title and axis captions.
    """
    figure, axes = plt.subplots(figsize=(8, 6))
    axes.plot(xval, yval, label=leg_label)
    axes.set_xlabel(xlabel)
    axes.set_ylabel(ylabel)
    axes.set_title(title)
    axes.legend(loc="lower left")
    axes.grid(True)
    figure.savefig("output_plot.png")


def regression_plot(xval, yval, title, xlabel, ylabel):
    """
    Scatter one set of values against another and save "output_plot.png".

    A dashed red "y = x" reference line is drawn across the range of the
    plotted values so that points on the line mark identical x and y.

    xval: values placed on the x axis.
    yval: values placed on the y axis.
    title, xlabel, ylabel: figure title and axis captions.
    """
    x_points = np.asarray(xval, dtype=float).ravel()
    y_points = np.asarray(yval, dtype=float).ravel()

    plt.figure(figsize=(8, 6))
    plt.scatter(x_points, y_points, alpha=0.8)

    # The reference line has to span the plotted *values*; using sample
    # indices (0..n-1) only lines up by coincidence.
    combined = np.concatenate([x_points, y_points])
    finite = combined[np.isfinite(combined)]
    if finite.size:
        low, high = finite.min(), finite.max()
        plt.plot(
            [low, high], [low, high], color="red", linestyle="--", label="y = x"
        )
        # The legend only picks up artists that already exist, so it must be
        # created after the labelled line has been drawn.
        plt.legend(loc="lower left")

    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.grid(True)
    plt.savefig("output_plot.png")
    # Release the figure once it is on disk.
    plt.close()


def train_evaluate(args):
    """
    Train TabPFN and predict

    args: dict of command-line options. The keys read here are
        "train_data", "test_data", "testhaslabels" and "selected_task".

    The test features together with a "predicted_labels" column are written
    tab-separated to "output_predicted_data". When the test data carries
    labels, a plot is additionally produced through classification_plot
    (precision-recall curve) or regression_plot (true vs predicted values).
    """
    # prepare train data
    tr_features, tr_labels = separate_features_labels(args["train_data"])
    # prepare test data
    if args["testhaslabels"] == "haslabels":
        te_features, te_labels = separate_features_labels(args["test_data"])
    else:
        # Unlabelled test data: every column is a feature.
        te_features = pd.read_csv(args["test_data"], sep="\t")
        te_labels = []

    s_time = time.time()
    if args["selected_task"] == "Classification":
        classifier = TabPFNClassifier(device="cpu")
        classifier.fit(tr_features, tr_labels)
        y_eval = classifier.predict(te_features)
        pred_probas_test = classifier.predict_proba(te_features)
        # Metrics and plots need ground truth, so skip them without labels.
        if len(te_labels) > 0:
            # NOTE(review): column 1 is taken as the positive-class score,
            # which presumes a two-class problem - confirm for other data.
            positive_scores = pred_probas_test[:, 1]
            precision, recall, thresholds = precision_recall_curve(
                te_labels, positive_scores
            )
            average_precision = average_precision_score(
                te_labels, positive_scores
            )
            classification_plot(
                recall,
                precision,
                f"Precision-Recall Curve (AP={average_precision:.2f})",
                "Precision-Recall Curve",
                "Recall",
                "Precision",
            )
    else:
        regressor = TabPFNRegressor(device="cpu")
        regressor.fit(tr_features, tr_labels)
        y_eval = regressor.predict(te_features)
        if len(te_labels) > 0:
            score = root_mean_squared_error(te_labels, y_eval)
            r2_metric_score = r2_score(te_labels, y_eval)
            regression_plot(
                te_labels,
                y_eval,
                f"Scatter plot: True vs predicted values. RMSE={score:.2f}, R2={r2_metric_score:.2f}",
                "True values",
                "Predicted values",
            )
    e_time = time.time()
    # The timed span covers fitting, prediction and (if any) plotting.
    print(
        f"Time taken by TabPFN for training and prediction: {e_time - s_time} seconds"
    )

    te_features["predicted_labels"] = y_eval
    te_features.to_csv("output_predicted_data", sep="\t", index=False)


if __name__ == "__main__":
    # Command-line entry point: collect the four mandatory options and hand
    # them to train_evaluate as a plain dict.
    cli_parser = argparse.ArgumentParser()
    required_options = (
        # (short flag, long flag, help text)
        ("-trdata", "--train_data", "Train data"),
        ("-tedata", "--test_data", "Test data"),
        ("-testhaslabels", "--testhaslabels", "if test data contain labels"),
        ("-selectedtask", "--selected_task", "Type of machine learning task"),
    )
    for short_flag, long_flag, help_text in required_options:
        cli_parser.add_argument(
            short_flag, long_flag, required=True, help=help_text
        )
    # get argument values
    args = vars(cli_parser.parse_args())
    train_evaluate(args)
65 changes: 49 additions & 16 deletions tools/tabpfn/tabpfn.xml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
<description>with PyTorch</description>
<macros>
<token name="@TOOL_VERSION@">2.0.3</token>
<token name="@VERSION_SUFFIX@">0</token>
<token name="@VERSION_SUFFIX@">1</token>
</macros>
<creator>
<organization name="European Galaxy Team" url="https://galaxyproject.org/eu/" />
Expand All @@ -17,48 +17,81 @@
<version_command>echo "@VERSION@"</version_command>
<command detect_errors="aggressive">
<![CDATA[
python '$__tool_directory__/main.py'
python '$__tool_directory__/main.py'
--selected_task '$selected_task'
--train_data '$train_data'
--testhaslabels '$testhaslabels'
--test_data '$test_data'
]]>
</command>
<inputs>
<param name="train_data" type="data" format="tabular" label="Train data" help="Please provide training data for training model."/>
<param name="test_data" type="data" format="tabular" label="Test data" help="Please provide test data for evaluating model."/>
<param name="selected_task" type="select" label="Select a machine learning task">
<option value="Classification" selected="true"></option>
<option value="Regression" selected="false"></option>
</param>
<param name="train_data" type="data" format="tabular" label="Train data" help="Please provide training data for training model. It should contain labels/class/target in the last column" />
<param name="test_data" type="data" format="tabular" label="Test data" help="Please provide test data for evaluating model. It may or may not contain labels/class/target in the last column" />
<param name="testhaslabels" type="boolean" truevalue="haslabels" falsevalue="" checked="false" label="Does test data contain labels?" help="Set this parameter when test data contains labels" />
</inputs>
<outputs>
<data format="tabular" name="output_predicted_data" from_work_dir="output_predicted_data" label="Predicted data"></data>
<data format="png" name="output_prec_recall_curve" from_work_dir="output_prec_recall_curve.png" label="Precision-recall curve"></data>
<data format="png" name="output_plot" from_work_dir="output_plot.png" label="Prediction plot on test data">
<filter>testhaslabels is True</filter>
</data>
</outputs>
<tests>
<test>
<param name="train_data" value="local_train_rows.tabular" ftype="tabular" />
<param name="test_data" value="local_test_rows.tabular" ftype="tabular" />
<test expect_num_outputs="1">
<param name="selected_task" value="Classification" />
<param name="train_data" value="classification_local_train_rows.tabular" ftype="tabular" />
<param name="test_data" value="classification_local_test_rows.tabular" ftype="tabular" />
<param name="testhaslabels" value="" />
<output name="output_predicted_data">
<assert_contents>
<has_n_columns n="42" />
<has_n_lines n="3" />
</assert_contents>
</output>
</test>
<test>
<param name="train_data" value="local_train_rows.tabular" ftype="tabular" />
<param name="test_data" value="local_test_rows.tabular" ftype="tabular" />
<output name="output_prec_recall_curve" file="pr_curve.png" compare="sim_size" />
<test expect_num_outputs="2">
<param name="selected_task" value="Classification" />
<param name="train_data" value="classification_local_train_rows.tabular" ftype="tabular" />
<param name="test_data" value="classification_local_test_rows_labels.tabular" ftype="tabular" />
<param name="testhaslabels" value="haslabels" />
<output name="output_plot" file="pr_curve.png" compare="sim_size" />
</test>
<test expect_num_outputs="2">
<param name="selected_task" value="Regression" />
<param name="train_data" value="regression_local_train_rows.tabular" ftype="tabular" />
<param name="test_data" value="regression_local_test_rows_labels.tabular" ftype="tabular" />
<param name="testhaslabels" value="haslabels" />
<output name="output_plot" file="r2_curve.png" compare="sim_size" />
</test>
<test expect_num_outputs="1">
<param name="selected_task" value="Regression" />
<param name="train_data" value="regression_local_train_rows.tabular" ftype="tabular" />
<param name="test_data" value="regression_local_test_rows.tabular" ftype="tabular" />
<param name="testhaslabels" value="" />
<output name="output_predicted_data">
<assert_contents>
<has_n_columns n="14" />
<has_n_lines n="105" />
</assert_contents>
</output>
</test>
</tests>
<help>
<![CDATA[
**What it does**
Classification on tabular data by TabPFN
Classification and Regression on tabular data by TabPFN
**Input files**
- Training data: the training data should contain features and the last column should be the class labels. It could either be tabular or in CSV format.
- Test data: the test data should also contain the same features as the training data and the last column should be the class labels. It could either be tabular or in CSV format.
- Training data: the training data should contain features and the last column should be the class labels. It should be in tabular format.
- Test data: the test data should also contain the same features as the training data and the last column should be the class labels if labels are available. It should be in tabular format. It is not required for the test data to have labels.
**Output files**
- Predicted data along with predicted labels
- Predicted data along with predicted labels.
- Prediction plot (when test data has labels available).
]]>
</help>
<citations>
Expand Down
3 changes: 3 additions & 0 deletions tools/tabpfn/test-data/classification_local_test_rows.tabular
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
SpMax_L J_Dz(e) nHM F01[N-N] F04[C-N] NssssC nCb- C% nCp nO F03[C-N] SdssC HyWi_B(m) LOC SM6_L F03[C-O] Me Mi nN-N nArNO2 nCRX3 SpPosA_B(p) nCIR B01[C-Br] B03[C-Cl] N-073 SpMax_A Psi_i_1d B04[C-Br] SdO TI2_L nCrt C-026 F02[C-N] nHDon SpMax_B(m) Psi_i_A nN SM6_B(m) nArCOOR nX
3.919 2.6909 0 0 0 0 0 31.4 2 0 0 0 3.106 2.55 9.002 0 0.96 1.142 0 0 0 1.201 0 0 0 0 1.932 0.011 0 0 4.489 0 0 0 0 2.949 1.591 0 7.253 0 0
4.17 2.1144 0 0 0 0 0 30.8 1 1 0 0 2.461 1.393 8.723 1 0.989 1.144 0 0 0 1.104 1 0 0 0 2.214 -0.204 0 0 1.542 0 0 0 0 3.315 1.967 0 7.257 0 0
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
SpMax_L J_Dz(e) nHM F01[N-N] F04[C-N] NssssC nCb- C% nCp nO F03[C-N] SdssC HyWi_B(m) LOC SM6_L F03[C-O] Me Mi nN-N nArNO2 nCRX3 SpPosA_B(p) nCIR B01[C-Br] B03[C-Cl] N-073 SpMax_A Psi_i_1d B04[C-Br] SdO TI2_L nCrt C-026 F02[C-N] nHDon SpMax_B(m) Psi_i_A nN SM6_B(m) nArCOOR nX predicted_labels
SpMax_L J_Dz(e) nHM F01[N-N] F04[C-N] NssssC nCb- C% nCp nO F03[C-N] SdssC HyWi_B(m) LOC SM6_L F03[C-O] Me Mi nN-N nArNO2 nCRX3 SpPosA_B(p) nCIR B01[C-Br] B03[C-Cl] N-073 SpMax_A Psi_i_1d B04[C-Br] SdO TI2_L nCrt C-026 F02[C-N] nHDon SpMax_B(m) Psi_i_A nN SM6_B(m) nArCOOR nX Class
3.919 2.6909 0 0 0 0 0 31.4 2 0 0 0 3.106 2.55 9.002 0 0.96 1.142 0 0 0 1.201 0 0 0 0 1.932 0.011 0 0 4.489 0 0 0 0 2.949 1.591 0 7.253 0 0 1
4.17 2.1144 0 0 0 0 0 30.8 1 1 0 0 2.461 1.393 8.723 1 0.989 1.144 0 0 0 1.104 1 0 0 0 2.214 -0.204 0 0 1.542 0 0 0 0 3.315 1.967 0 7.257 0 0 1
Binary file modified tools/tabpfn/test-data/pr_curve.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added tools/tabpfn/test-data/r2_curve.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit e87b82b

Please sign in to comment.