Skip to content

Commit

Permalink
add functions
Browse files Browse the repository at this point in the history
  • Loading branch information
EdenWuyifan committed Jun 6, 2024
1 parent 8540387 commit 899df27
Showing 1 changed file with 31 additions and 1 deletion.
32 changes: 31 additions & 1 deletion alpha_automl/automl_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ def __init__(self, time_bound=15, metric=None, split_strategy='holdout', time_bo
self.output_folder = setup_output_folder(output_folder)
self.pipelines = {}
self.new_primitives = {}
self.include_primitives = {}
self.exclude_primitives = {}
self.X = None
self.y = None
self.leaderboard = None
Expand All @@ -79,7 +81,11 @@ def fit(self, X, y):
"""
self.X = X
self.y = y
automl_hyperparams = {'new_primitives': self.new_primitives}
automl_hyperparams = {
'new_primitives': self.new_primitives,
'include_primitives': self.include_primitives,
'exclude_primitives': self.exclude_primitives
}
pipelines = []
start_time = datetime.datetime.utcnow()

Expand Down Expand Up @@ -206,6 +212,30 @@ def add_primitives(self, new_primitives):
self.new_primitives[primitive_name] = {'primitive_object': primitive_object,
'primitive_type': primitive_type}

def whitelist_primitives(self, include_primitives):
"""
Whitelist primitives to the search space.
include_primitives: [('FEATURE_GENERATOR', 'sklearn.preprocessing.PolynomialFeatures'), ...]
"""

for primitive_type, primitive_name in include_primitives:
if primitive_type not in self.include_primitives:
self.include_primitives[primitive_type] = [primitive_name]
else:
self.include_primitives[primitive_type].append(primitive_name)

def blacklist_primitives(self, exclude_primitives):
"""
Blacklist primitives to the search space.
exclude_primitives: [('FEATURE_GENERATOR', 'sklearn.preprocessing.PolynomialFeatures'), ...]
"""

for primitive_type, primitive_name in exclude_primitives:
if primitive_type not in self.exclude_primitives:
self.exclude_primitives[primitive_type] = [primitive_name]
else:
self.exclude_primitives[primitive_type].append(primitive_name)

def get_leaderboard(self):
"""
Return the leaderboard.
Expand Down

0 comments on commit 899df27

Please sign in to comment.