From 16bda9dfe71d3b55335bc5f4da94b0bb16ff44df Mon Sep 17 00:00:00 2001 From: Roque Lopez Date: Fri, 16 Aug 2024 15:18:05 -0400 Subject: [PATCH] Apply MathFeatures within ColumnTransformer --- alpha_automl/pipeline_synthesis/pipeline_builder.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/alpha_automl/pipeline_synthesis/pipeline_builder.py b/alpha_automl/pipeline_synthesis/pipeline_builder.py index 0fd9461..7b7477d 100644 --- a/alpha_automl/pipeline_synthesis/pipeline_builder.py +++ b/alpha_automl/pipeline_synthesis/pipeline_builder.py @@ -120,7 +120,7 @@ def make_primitive_objects(self, primitives): primitive_object = create_object(primitive_name, {'estimators': estimators}) elif "feature_engine.creation" in primitive_name: primitive_name_type = primitive_name.split('-')[1] - primitive_object = create_math_features(primitive_name_type, numeric_columns) + primitive_object = create_math_features(primitive_name_type, list(range(len(numeric_columns)))) elif self.all_primitives[primitive_name]['origin'] == NATIVE_PRIMITIVE: # It's an installed primitive primitive_object = create_object(primitive_name, EXTRA_PARAMS.get(primitive_name, None)) else: @@ -130,6 +130,8 @@ def make_primitive_objects(self, primitives): if primitive_type in nonnumeric_columns: # Create a new transformer and add it to the list transformers += self.create_transformers(primitive_object, primitive_name, primitive_type) + elif primitive_type == 'FEATURE_GENERATOR': + transformers += self.create_transformers(primitive_object, primitive_name, primitive_type) else: if len(transformers) > 0: # Add previous transformers to the pipeline if len(useless_columns) > 0: @@ -145,6 +147,7 @@ def make_primitive_objects(self, primitives): def create_transformers(self, primitive_object, primitive_name, primitive_type): column_transformers = [] nonnumeric_columns = self.metadata['nonnumeric_columns'] + numeric_columns = self.metadata['numeric_columns'] if primitive_type == 'TEXT_ENCODER': column_transformers = [(f'{primitive_name}-{col_name}', primitive_object, col_index) for @@ -152,6 +155,9 @@ def create_transformers(self, primitive_object, primitive_name, primitive_type): elif primitive_type == 'CATEGORICAL_ENCODER' or primitive_type == 'DATETIME_ENCODER' or primitive_type == 'IMAGE_ENCODER': column_transformers = [(primitive_name, primitive_object, [col_index for col_index, _ in nonnumeric_columns[primitive_type]])] + elif primitive_type == 'FEATURE_GENERATOR': + column_transformers = [(primitive_name, primitive_object, [col_index for col_index, _ + in numeric_columns])] return column_transformers