Skip to content

Commit

Permalink
add catboost and tabnet for tabular regression tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
EdenWuyifan committed Nov 13, 2023
1 parent 3d1f80a commit 79d6f16
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 2 deletions.
3 changes: 2 additions & 1 deletion alpha_automl/resource/primitives_hierarchy.json
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@
"sklearn.linear_model.RidgeCV",
"sklearn.linear_model.TheilSenRegressor",
"xgboost.XGBRegressor",
"lightgbm.LGBMRegressor"
"lightgbm.LGBMRegressor",
"catboost.CatBoostRegressor"
],
"TEXT_ENCODER": [
"sklearn.feature_extraction.text.CountVectorizer",
Expand Down
36 changes: 36 additions & 0 deletions alpha_automl/wrapper_primitives/tabnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import torch
import numpy as np
import pandas as pd
from alpha_automl.base_primitive import BasePrimitive
from alpha_automl._optional_dependency import check_optional_dependency

# pytorch_tabnet is an optional dependency registered under the 'regression'
# task. The check runs before the import below, so the order of these three
# statements matters.
# NOTE(review): check_optional_dependency is defined elsewhere; presumably it
# reports a missing installation in a user-friendly way -- confirm.
ml_task = 'regression'
check_optional_dependency('pytorch_tabnet', ml_task)
from pytorch_tabnet.tab_model import TabNetRegressor

class PytorchTabNetRegressor(BasePrimitive):
    """Wrapper primitive around ``pytorch_tabnet``'s ``TabNetRegressor``.

    This is a PyTorch implementation of TabNet
    (Arik, S. O., & Pfister, T. (2019). TabNet: Attentive Interpretable
    Tabular Learning. arXiv preprint arXiv:1908.07442.)
    https://arxiv.org/pdf/1908.07442.pdf.
    Please note that some different choices have been made over time to
    improve the library, which can differ from the original paper.
    """

    def __init__(self):
        """Create the underlying TabNet regressor with ``n_d=32``, non-verbose."""
        # Kept as an attribute for callers that inspect it; it is not passed
        # to TabNetRegressor here.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model = TabNetRegressor(n_d=32, verbose=False)

    @staticmethod
    def _as_array(data):
        """Return *data* as a NumPy array.

        pandas DataFrames/Series are unwrapped through ``.values``; anything
        else (lists, arrays, ...) goes through ``np.asarray``.
        """
        if isinstance(data, (pd.DataFrame, pd.Series)):
            return data.values
        return np.asarray(data)

    def fit(self, X, y):
        """Train the TabNet model and return ``self``.

        Args:
            X: Feature matrix (pandas DataFrame or array-like).
            y: Targets (pandas DataFrame/Series or array-like). X and y are
                converted independently, so their container types may differ.

        Returns:
            This primitive, to allow call chaining.
        """
        features = self._as_array(X)
        targets = self._as_array(y)
        # pytorch_tabnet's regressor expects targets shaped
        # (n_samples, n_targets); promote a flat target vector to one column.
        if targets.ndim == 1:
            targets = targets.reshape(-1, 1)
        self.model.fit(features, targets)
        return self

    def predict(self, X):
        """Return the model's predictions for *X*.

        Args:
            X: Feature matrix (pandas DataFrame or array-like).

        Returns:
            Whatever ``TabNetRegressor.predict`` returns for the converted
            input; the result is passed through unchanged.
        """
        return self.model.predict(self._as_array(X))
4 changes: 3 additions & 1 deletion extra_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@ mxnet==1.9.1: timeseries
pmdarima==2.0.3: timeseries
fasttext-wheel: nlp
transformers: nlp, image
scikit-image: image
scikit-image: image
catboost: regression
pytorch_tabnet: regression

0 comments on commit 79d6f16

Please sign in to comment.