diff --git a/alpha_automl/resource/primitives_hierarchy.json b/alpha_automl/resource/primitives_hierarchy.json index bcbccc1a..8d452c04 100644 --- a/alpha_automl/resource/primitives_hierarchy.json +++ b/alpha_automl/resource/primitives_hierarchy.json @@ -69,7 +69,8 @@ "sklearn.linear_model.RidgeCV", "sklearn.linear_model.TheilSenRegressor", "xgboost.XGBRegressor", - "lightgbm.LGBMRegressor" + "lightgbm.LGBMRegressor", + "catboost.CatBoostRegressor" ], "TEXT_ENCODER": [ "sklearn.feature_extraction.text.CountVectorizer", diff --git a/alpha_automl/wrapper_primitives/tabnet.py b/alpha_automl/wrapper_primitives/tabnet.py new file mode 100644 index 00000000..1c498854 --- /dev/null +++ b/alpha_automl/wrapper_primitives/tabnet.py @@ -0,0 +1,36 @@ +import torch +import numpy as np +import pandas as pd +from alpha_automl.base_primitive import BasePrimitive +from alpha_automl._optional_dependency import check_optional_dependency + +ml_task = 'regression' +check_optional_dependency('pytorch_tabnet', ml_task) +from pytorch_tabnet.tab_model import TabNetRegressor + +class PytorchTabNetRegressor(BasePrimitive): + """ + This is a pyTorch implementation of Tabnet + (Arik, S. O., & Pfister, T. (2019). TabNet: Attentive Interpretable Tabular Learning. arXiv preprint arXiv:1908.07442.) + https://arxiv.org/pdf/1908.07442.pdf. + Please note that some different choices have been made overtime to improve the library + which can differ from the orginal paper. + """ + + def __init__(self): + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.model = TabNetRegressor(n_d=32, verbose=False) + + def fit(self, X, y): + if isinstance(X, pd.DataFrame): + self.model.fit(X.values, y.values) + else: + self.model.fit(np.asarray(X), np.asarray(y)) + return self + + def predict(self, X): + # Load fasttext model + if isinstance(X, pd.DataFrame): + return self.model.predict(X.values) + else: + return self.model.predict(np.asarray(X)) \ No newline at end of file diff --git a/extra_requirements.txt b/extra_requirements.txt index af879346..09d11b54 100644 --- a/extra_requirements.txt +++ b/extra_requirements.txt @@ -6,4 +6,6 @@ mxnet==1.9.1: timeseries pmdarima==2.0.3: timeseries fasttext-wheel: nlp transformers: nlp, image -scikit-image: image \ No newline at end of file +scikit-image: image +catboost: regression +pytorch_tabnet: regression \ No newline at end of file