diff --git a/alpha_automl/utils.py b/alpha_automl/utils.py index 66eedc10..73313108 100644 --- a/alpha_automl/utils.py +++ b/alpha_automl/utils.py @@ -8,6 +8,7 @@ import numpy as np import pandas as pd import torch +from enum import Enum from sklearn.compose import ColumnTransformer from sklearn.preprocessing import LabelEncoder from sklearn.model_selection import ShuffleSplit, train_test_split @@ -24,6 +25,25 @@ RANDOM_SEED = 0 +class PrimitiveType(Enum): + IMPUTER = 'IMPUTER' + CATEGORICAL_ENCODER = 'CATEGORICAL_ENCODER' + DATETIME_ENCODER = 'DATETIME_ENCODER' + TEXT_ENCODER = 'TEXT_ENCODER' + IMAGE_ENCODER = 'IMAGE_ENCODER' + FEATURE_SCALER = 'FEATURE_SCALER' + FEATURE_SELECTOR = 'FEATURE_SELECTOR' + COLUMN_TRANSFORMER = 'COLUMN_TRANSFORMER' + TIME_SERIES_FORECASTER = 'TIME_SERIES_FORECASTER' + CLASSIFIER = 'CLASSIFIER' + REGRESSOR = 'REGRESSOR' + CLUSTERER = 'CLUSTERER' + SINGLE_ENSEMBLER = 'SINGLE_ENSEMBLER' + MULTI_ENSEMBLER = 'MULTI_ENSEMBLER' + SEMISUPERVISED_SELFTRAINER = 'SEMISUPERVISED_SELFTRAINER' + SEMISUPERVISED_LABELPROPAGATOR = 'SEMISUPERVISED_LABELPROPAGATOR' + + def create_object(import_path, class_params=None): if class_params is None: class_params = {}