Skip to content

Commit

Permalink
debug
Browse files Browse the repository at this point in the history
  • Loading branch information
EdenWuyifan committed Apr 12, 2024
1 parent a706691 commit 439dbfe
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 12 deletions.
64 changes: 60 additions & 4 deletions alpha_automl/pipeline_search/AlphaAutoMLEnv.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging

import random
import gymnasium as gym
import numpy as np
from gymnasium.spaces import Box, Dict, Discrete
Expand All @@ -25,22 +25,24 @@ class AlphaAutoMLEnv(gym.Env):
def __init__(self, config: EnvContext):
self.game = config["game"]
self.board = self.game.getInitBoard()
# self.step_stack = ["S"]
self.metadata = self.board[: self.game.m]
self.observation_space = Dict(
{
"board": Box(
0, 80, shape=(self.game.p + self.game.m,), dtype=np.uint8
0, 85, shape=(self.game.p + self.game.m,), dtype=np.uint8
), # board
}
)
self.action_space = Discrete(80) # primitives to choose from
self.action_space = Discrete(85) # primitives to choose from

self.cur_player = 1 # NEVER USED - ONLY ONE PLAYER

def reset(self, *, seed=None, options=None):
# init number of steps
self.num_steps = 0

# self.step_stack = ["S"]
self.board = self.game.getInitBoard()
self.metadata = self.board[: self.game.m]
self.found = set()
Expand All @@ -49,6 +51,7 @@ def reset(self, *, seed=None, options=None):
return {"board": np.array(self.board).astype(np.uint8)}, {}

def step(self, action):
# Check whether the action is illegal
valid_moves = self.game.getValidMoves(self.board, self.cur_player)
if action >= len(valid_moves) or valid_moves[action] != 1:
return (
Expand All @@ -59,12 +62,27 @@ def step(self, action):
{},
)

# Check the action is out of order
# move_type, non_terminals_moves = self.extract_action_details(action)
# if move_type != self.step_stack[-1]:
# return (
# {"board": np.array(self.board).astype(np.uint8)},
# -100,
# True,
# False,
# {},
# )


# update number of steps
self.num_steps += 1

# update board with new action
# print(f"action: {action}\n board: {self.board}")
self.board, _ = self.game.getNextState(self.board, self.cur_player, action)

if self.num_steps > 9:
logger.info(f"[YFW]================={self.board[self.game.m:]}")
# reward: win(1) - pipeline score, bad(2) - 10, not end(0) - set in the branches below
reward = 0
game_end = self.game.getGameEnded(self.board, self.cur_player)
Expand All @@ -83,8 +101,38 @@ def step(self, action):
logger.critical(f"[PIPELINE FOUND] Error happened")
elif game_end == 2: # finished but invalid
reward = 10
else:
elif action == 0:
reward = 1

else:
# popped_step = self.step_stack.pop()
# if move_type == "ENSEMBLER":
# reward = 1
# if non_terminals_moves[0] != "E" and non_terminals_moves[0].upper() == non_terminals_moves[0]:
# self.step_stack.extend(non_terminals_moves)
# if move_type == "S":
# reward = 1
# elif move_type == "ENCODERS":
# reward = 1
# else:
# if move_type == "IMPUTER" or move_type == "CATEGORICAL_ENCODER":
# reward = 1
reward = random.uniform(0, 1)
# if move_string.upper() != move_string:
# reward = random.uniform(0, 1)
# else:
# split_move = move_string.split("->")
# non_terminals_moves = move_string.split("->")[1].strip().split(" ")

# if split_move[0].strip() == "ENSEMBLER":
# if "E" in non_terminals_moves:
# rewards = 5
# else:
# rewards = 5 - len(non_terminals_moves)
# else:
# rewards = random.uniform(0, 1)



# done & truncated
truncated = self.num_steps >= 200
Expand All @@ -97,3 +145,11 @@ def step(self, action):
truncated,
{},
)

def extract_action_details(self, action):
    """Return the grammar rule that corresponds to ``action``.

    Rules are read from ``self.game.grammar["RULES"]``, a mapping from a
    rule string of the form ``"LHS -> SYM1 SYM2 ..."`` to an integer rule
    id. The rule selected is the first one whose id equals ``action + 1``.

    Args:
        action: Action index; the rule id looked up is ``action + 1``.

    Returns:
        A ``(move_type, non_terminals_moves)`` tuple: the stripped
        left-hand side of the rule and the list of whitespace-separated
        symbols on its right-hand side. The list is empty when the rule
        has no right-hand side.

    Raises:
        ValueError: If no rule is registered with id ``action + 1``.
    """
    rule_id = action + 1
    rules = self.game.grammar["RULES"]

    # Scan the mapping directly instead of materialising the key and value
    # lists on every call; the first matching rule wins.
    for move_string, rule_index in rules.items():
        if rule_index == rule_id:
            break
    else:
        raise ValueError(f"No grammar rule found for action {action}")

    # Split on the first arrow only, so the whole right-hand side is kept.
    move_type, _, expansion = move_string.partition("->")

    # split() without an argument collapses repeated whitespace, so no
    # empty-string symbols are produced.
    return move_type.strip(), expansion.split()
12 changes: 6 additions & 6 deletions alpha_automl/pipeline_search/RlLib.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def pipeline_search_rllib(game, time_bound, save_checkpoint=False):

# load checkpoint or create a new one
algo = load_rllib_checkpoint(game, num_rollout_workers=3)
# algo = load_rllib_checkpoint_dqn(game, num_rollout_workers=3)
logger.debug("[RlLib] Create Algo object done")

# train model
Expand All @@ -55,13 +56,13 @@ def load_rllib_checkpoint(game, num_rollout_workers):
.rollouts(num_rollout_workers=num_rollout_workers)
.training(
gamma=0.99,
clip_param=0.2,
kl_coeff=0.2,
clip_param=0.3,
kl_coeff=0.3,
entropy_coeff=0.01,
train_batch_size=5000,
train_batch_size=10000,
)
)
config.lr = 1e-4
config.lr = 1e-5
config.simple_optimizer = True
logger.debug("[RlLib] Create Config done")

Expand All @@ -79,7 +80,6 @@ def load_rllib_checkpoint(game, num_rollout_workers):
# checkpoint_info = get_checkpoint_info(PATH_TO_CHECKPOINT)
return algo


def train_rllib_model(algo, time_bound, save_checkpoint=False):
timeout = time.time() + time_bound
result = algo.train()
Expand All @@ -95,7 +95,7 @@ def train_rllib_model(algo, time_bound, save_checkpoint=False):
logger.info(f"[RlLib] Train Timeout")
break

if [f for f in os.listdir(PATH_TO_CHECKPOINT) if not f.startswith(".")] != []:
if save_checkpoint and [f for f in os.listdir(PATH_TO_CHECKPOINT) if not f.startswith(".")] != []:
weights = load_rllib_policy_weights()
algo.set_weights(weights)
result = algo.train()
Expand Down
6 changes: 4 additions & 2 deletions alpha_automl/pipeline_synthesis/setup_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
'SEMISUPERVISED': 6,
},
'DATA_TYPES': {'TABULAR': 1, 'GRAPH': 2, 'IMAGE': 3},
'PIPELINE_SIZE': 8,
'PIPELINE_SIZE': 10,
'ARGS': {
'numIters': 25,
'numEps': 5,
Expand Down Expand Up @@ -100,14 +100,16 @@ def search_pipelines(
) # Hide logs here too, since multiprocessing has some issues with loggers

builder = BaseBuilder(metadata, automl_hyperparams)
all_primitives = builder.all_primitives
ensemble_pipelines_hash = set()

task_start = datetime.now()

def evaluate_pipeline(primitives, origin):
has_repeated_classifiers = check_repeated_classifiers(primitives, all_primitives, ensemble_pipelines_hash)

if has_repeated_classifiers:
logger.debug('Repeated classifiers detected in ensembles, ignoring pipeline')
logger.info('Repeated classifiers detected in ensembles, ignoring pipeline')
return None

pipeline = builder.make_pipeline(primitives)
Expand Down

0 comments on commit 439dbfe

Please sign in to comment.