Skip to content

Commit

Permalink
Apply MathFeatures within ColumnTransformer
Browse files Browse the repository at this point in the history
  • Loading branch information
roquelopez committed Aug 16, 2024
1 parent b8c07ab commit 16bda9d
Showing 1 changed file with 7 additions and 1 deletion.
8 changes: 7 additions & 1 deletion alpha_automl/pipeline_synthesis/pipeline_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def make_primitive_objects(self, primitives):
primitive_object = create_object(primitive_name, {'estimators': estimators})
elif "feature_engine.creation" in primitive_name:
primitive_name_type = primitive_name.split('-')[1]
primitive_object = create_math_features(primitive_name_type, numeric_columns)
primitive_object = create_math_features(primitive_name_type, list(range(len(numeric_columns))))
elif self.all_primitives[primitive_name]['origin'] == NATIVE_PRIMITIVE: # It's an installed primitive
primitive_object = create_object(primitive_name, EXTRA_PARAMS.get(primitive_name, None))
else:
Expand All @@ -130,6 +130,8 @@ def make_primitive_objects(self, primitives):

if primitive_type in nonnumeric_columns: # Create a new transformer and add it to the list
transformers += self.create_transformers(primitive_object, primitive_name, primitive_type)
elif primitive_type == 'FEATURE_GENERATOR':
transformers += self.create_transformers(primitive_object, primitive_name, primitive_type)
else:
if len(transformers) > 0: # Add previous transformers to the pipeline
if len(useless_columns) > 0:
Expand All @@ -145,13 +147,17 @@ def make_primitive_objects(self, primitives):
def create_transformers(self, primitive_object, primitive_name, primitive_type):
column_transformers = []
nonnumeric_columns = self.metadata['nonnumeric_columns']
numeric_columns = self.metadata['numeric_columns']

if primitive_type == 'TEXT_ENCODER':
column_transformers = [(f'{primitive_name}-{col_name}', primitive_object, col_index) for
col_index, col_name in nonnumeric_columns[primitive_type]]
elif primitive_type == 'CATEGORICAL_ENCODER' or primitive_type == 'DATETIME_ENCODER' or primitive_type == 'IMAGE_ENCODER':
column_transformers = [(primitive_name, primitive_object, [col_index for col_index, _
in nonnumeric_columns[primitive_type]])]
elif primitive_type == 'FEATURE_GENERATOR':
column_transformers = [(primitive_name, primitive_object, [col_index for col_index, _
in numeric_columns])]

return column_transformers

0 comments on commit 16bda9d

Please sign in to comment.