Skip to content

Commit

Permalink
Resample when evaluating
Browse files Browse the repository at this point in the history
  • Loading branch information
EdenWuyifan committed Apr 22, 2024
1 parent 70fccd4 commit 64deb84
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
3 changes: 2 additions & 1 deletion alpha_automl/automl_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,11 @@ def _search_pipelines(self, automl_hyperparams):
found_pipelines = 0

pipeline_threshold = 20
X, y, _ = sample_dataset(self.X, self.y, SAMPLE_SIZE, self.task)
while pipelines and found_pipelines < pipeline_threshold:
pipeline = pipelines.pop()
try:
alphaautoml_pipeline = score_pipeline(pipeline, self.X, self.y, self.scoring,
alphaautoml_pipeline = score_pipeline(pipeline, X, y, self.scoring,
self.splitting_strategy, self.task,
self.verbose)

Expand Down
5 changes: 3 additions & 2 deletions alpha_automl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np
import pandas as pd
import torch
from datetime import datetime
from enum import Enum
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import LabelEncoder
Expand Down Expand Up @@ -66,10 +67,10 @@ def sample_dataset(X, y, sample_size, task):
if original_size > sample_size:
ratio = sample_size / original_size
try:
_, X_test, _, y_test = train_test_split(X, y, random_state=RANDOM_SEED, test_size=ratio, stratify=y, shuffle=shuffle)
_, X_test, _, y_test = train_test_split(X, y, random_state=datetime.now().timestamp(), test_size=ratio, stratify=y, shuffle=shuffle)
except Exception:
# Not using stratified sampling when the minority class has few instances, not enough for all the folds
_, X_test, _, y_test = train_test_split(X, y, random_state=RANDOM_SEED, test_size=ratio, shuffle=shuffle)
_, X_test, _, y_test = train_test_split(X, y, random_state=datetime.now().timestamp(), test_size=ratio, shuffle=shuffle)
logger.debug(f'Sampling down data from {original_size} to {len(X_test)}')
if isinstance(X_test, pd.DataFrame):
X_test = X_test.reset_index(drop=True)
Expand Down

0 comments on commit 64deb84

Please sign in to comment.