-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add catboost and tabnet for tabular regression tasks
- Loading branch information
1 parent
3d1f80a
commit 79d6f16
Showing
3 changed files
with
41 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import torch | ||
import numpy as np | ||
import pandas as pd | ||
from alpha_automl.base_primitive import BasePrimitive | ||
from alpha_automl._optional_dependency import check_optional_dependency | ||
|
||
ml_task = 'regression' | ||
check_optional_dependency('pytorch_tabnet', ml_task) | ||
from pytorch_tabnet.tab_model import TabNetRegressor | ||
|
||
class PytorchTabNetRegressor(BasePrimitive): | ||
""" | ||
This is a pyTorch implementation of Tabnet | ||
(Arik, S. O., & Pfister, T. (2019). TabNet: Attentive Interpretable Tabular Learning. arXiv preprint arXiv:1908.07442.) | ||
https://arxiv.org/pdf/1908.07442.pdf. | ||
Please note that some different choices have been made overtime to improve the library | ||
which can differ from the orginal paper. | ||
""" | ||
|
||
def __init__(self): | ||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | ||
self.model = TabNetRegressor(n_d=32, verbose=False) | ||
|
||
def fit(self, X, y): | ||
if isinstance(X, pd.DataFrame): | ||
self.model.fit(X.values, y.values) | ||
else: | ||
self.model.fit(np.asarray(X), np.asarray(y)) | ||
return self | ||
|
||
def predict(self, X): | ||
# Load fasttext model | ||
if isinstance(X, pd.DataFrame): | ||
return self.model.predict(X.values) | ||
else: | ||
return self.model.predict(np.asarray(X)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters