Skip to content

Commit

Permalink
Avoid repeated classifiers
Browse files Browse the repository at this point in the history
  • Loading branch information
roquelopez committed Apr 5, 2024
1 parent 75faf92 commit 90adcaf
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 6 deletions.
9 changes: 6 additions & 3 deletions alpha_automl/pipeline_search/pipeline/PipelineLogic.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,17 +95,20 @@ def has_legal_moves(self):

def next_state(self, action):
s = self.valid_moves[action]
nt = self.non_terminals[s[:s.index('-')].strip()]
nt = self.non_terminals[s[:s.index('->')].strip()]
r = [self.non_terminals[p] if p in self.non_terminals.keys() else
self.terminals[p] for p in s[s.index('-')+2:].strip().split(' ')]
self.terminals[p] for p in s[s.index('->')+2:].strip().split(' ')]
r = [x for x in r if x != 0]
s = []
not_used = True

for p in self.pieces_p:
if p == 0:
continue

if p == nt:
if p == nt and not_used: # Chose one primitive at the time
s += r
not_used = False
else:
s.append(p)

Expand Down
2 changes: 1 addition & 1 deletion alpha_automl/pipeline_synthesis/pipeline_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def make_primitive_objects(self, primitives):

change_default_hyperparams(primitive_object)

if primitive_type in nonnumeric_columns: # Create a new transformer and add it to the list
if primitive_type in nonnumeric_columns: # Create a new transformer and add it to the list
transformers += self.create_transformers(primitive_object, primitive_name, primitive_type)
else:
if len(transformers) > 0: # Add previous transformers to the pipeline
Expand Down
43 changes: 43 additions & 0 deletions alpha_automl/pipeline_synthesis/setup_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,57 @@ def signal_handler(queue, signum):
sys.exit(0)


def check_repeated_classifiers(pipeline_primitives, all_primitives, ensemble_pipelines_hash):
    """Tell whether an ensemble pipeline should be ignored as redundant.

    A pipeline is considered redundant when it contains a 'MULTI_ENSEMBLER'
    primitive and either:

    * the classifiers seen before an ensembler contain duplicates, or
    * a pipeline with the same non-classifier primitives (in the same order)
      and the same set of classifiers (regardless of their order) has already
      been registered in ``ensemble_pipelines_hash``.

    Pipelines without a 'MULTI_ENSEMBLER' primitive are never reported and
    are not registered.

    :param pipeline_primitives: Iterable with the primitive names of the pipeline.
    :param all_primitives: Mapping from primitive name to a dict that has a
        'type' entry (e.g. 'CLASSIFIER', 'MULTI_ENSEMBLER').
    :param ensemble_pipelines_hash: Set of keys of the ensemble pipelines seen
        so far. It is updated in place when a new ensemble pipeline is found.
        The keys are opaque and must only be produced by this function.
    :return: True if the pipeline must be skipped, False otherwise.
    :raises KeyError: If a primitive name is not present in ``all_primitives``.
    """
    classifiers = []
    other_primitives = []
    has_ensemble_primitive = False
    has_repeated_classifiers = False

    for primitive_name in pipeline_primitives:
        primitive_type = all_primitives[primitive_name]['type']

        if primitive_type == 'CLASSIFIER':
            classifiers.append(primitive_name)
            continue

        # Every non-classifier primitive (the ensembler included) is part of
        # the order-sensitive portion of the key.
        other_primitives.append(primitive_name)

        if primitive_type == 'MULTI_ENSEMBLER':
            has_ensemble_primitive = True
            # All classifiers collected up to the ensembler should be different
            if len(classifiers) != len(set(classifiers)):
                has_repeated_classifiers = True

    if not has_ensemble_primitive:
        return False

    if has_repeated_classifiers:
        return True

    # A structured key instead of a plain string concatenation: joining names
    # without a delimiter makes different pipelines collide (e.g. 'ab' + 'c'
    # and 'a' + 'bc'), which would wrongly discard valid pipelines.
    # Classifiers are sorted so that their order in the ensemble is irrelevant.
    pipeline_key = (tuple(other_primitives), tuple(sorted(classifiers)))

    if pipeline_key in ensemble_pipelines_hash:
        return True

    ensemble_pipelines_hash.add(pipeline_key)
    return False


def search_pipelines(X, y, scoring, splitting_strategy, task_name, automl_hyperparams, metadata, output_folder, verbose,
queue):
signal.signal(signal.SIGTERM, lambda signum, frame: signal_handler(queue, signum))
hide_logs(verbose) # Hide logs here too, since multiprocessing has some issues with loggers

builder = BaseBuilder(metadata, automl_hyperparams)
all_primitives = builder.all_primitives
ensemble_pipelines_hash = set()

def evaluate_pipeline(primitives, origin):
has_repeated_classifiers = check_repeated_classifiers(primitives, all_primitives, ensemble_pipelines_hash)

if has_repeated_classifiers:
logger.debug('Repeated classifiers detected in ensembles, ignoring pipeline')
return None

pipeline = builder.make_pipeline(primitives)
score = None

Expand Down
4 changes: 2 additions & 2 deletions alpha_automl/resource/base_grammar.bnf
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ SEMISUPERVISED_TASK -> IMPUTER ENCODERS FEATURE_SCALER SEMISUPERVISED_CLASSIFIER
NA_TASK -> CLASSIFICATION_TASK | REGRESSION_TASK | SEMISUPERVISED_TASK
ENCODERS -> TEXT_ENCODER DATETIME_ENCODER CATEGORICAL_ENCODER IMAGE_ENCODER
ENSEMBLER -> SINGLE_ENSEMBLER | CLASSIFIER CLASSIFIER MULTI_ENSEMBLER | E
SEMISUPERVISED_CLASSIFIER -> CLASSIFIER SEMISUPERVISED_SELFTRAINER| SEMISUPERVISED_LABELPROPAGATOR
SEMISUPERVISED_CLASSIFIER -> CLASSIFIER SEMISUPERVISED_SELFTRAINER | SEMISUPERVISED_LABELPROPAGATOR
IMPUTER -> 'primitive_terminal'
FEATURE_SCALER -> 'primitive_terminal' | 'E'
FEATURE_SELECTOR -> 'primitive_terminal' | 'E'
Expand All @@ -22,4 +22,4 @@ REGRESSOR -> 'primitive_terminal'
CLUSTERER -> 'primitive_terminal'
TIME_SERIES_FORECASTER -> 'primitive_terminal'
SEMISUPERVISED_SELFTRAINER -> 'primitive_terminal'
SEMISUPERVISED_LABELPROPAGATOR -> 'primitive_terminal'
SEMISUPERVISED_LABELPROPAGATOR -> 'primitive_terminal'

0 comments on commit 90adcaf

Please sign in to comment.