Commit 162bd45
update game env and configs, limit results to 20
EdenWuyifan committed Apr 17, 2024
1 parent 439dbfe commit 162bd45
Showing 3 changed files with 62 additions and 41 deletions.
84 changes: 52 additions & 32 deletions alpha_automl/pipeline_search/AlphaAutoMLEnv.py
@@ -25,7 +25,7 @@ class AlphaAutoMLEnv(gym.Env):
def __init__(self, config: EnvContext):
self.game = config["game"]
self.board = self.game.getInitBoard()
# self.step_stack = ["S"]
self.step_stack = ["S"]
self.metadata = self.board[: self.game.m]
self.observation_space = Dict(
{
@@ -34,52 +34,62 @@ def __init__(self, config: EnvContext):
), # board
}
)
self.action_space = Discrete(85) # primitives to choose from
# self.action_space = Discrete(85) # primitives to choose from
self.max_actions = 24
self.action_spaces = self.generate_action_spaces()
self.action_offsets = self.generate_action_offsets()
self.action_space = Discrete(self.max_actions)


self.cur_player = 1 # NEVER USED - ONLY ONE PLAYER

def reset(self, *, seed=None, options=None):
# init number of steps
self.num_steps = 0

# self.step_stack = ["S"]
self.step_stack = ["S"]
self.board = self.game.getInitBoard()
self.metadata = self.board[: self.game.m]
self.found = set()

# print(f"metadata: {self.metadata}\n board: {self.board}")
return {"board": np.array(self.board).astype(np.uint8)}, {}

def step(self, action):
curr_step = self.step_stack.pop()
offseted_action = self.action_offsets[curr_step]+action
valid_action_size = self.action_spaces[curr_step]
# Check whether the action is illegal
valid_moves = self.game.getValidMoves(self.board, self.cur_player)
if action >= len(valid_moves) or valid_moves[action] != 1:
if action >= valid_action_size or valid_moves[offseted_action-1] != 1:
return (
{"board": np.array(self.board).astype(np.uint8)},
-100,
-1,
True,
False,
{},
)

# Check whether the action is out of order
# move_type, non_terminals_moves = self.extract_action_details(action)
# if move_type != self.step_stack[-1]:
# return (
# {"board": np.array(self.board).astype(np.uint8)},
# -100,
# True,
# False,
# {},
# )
move_type, non_terminals_moves = self.extract_action_details(offseted_action)
# logger.critical(f"offseted_action: {offseted_action} ===> curr_step: {curr_step}")
if move_type != curr_step:
return (
{"board": np.array(self.board).astype(np.uint8)},
-100,
True,
False,
{},
)
if non_terminals_moves[0] != "E" and non_terminals_moves[0].upper() == non_terminals_moves[0]:
self.step_stack.extend(non_terminals_moves[::-1])


# update number of steps
self.num_steps += 1

# update board with new action
# print(f"action: {action}\n board: {self.board}")
self.board, _ = self.game.getNextState(self.board, self.cur_player, action)
self.board, _ = self.game.getNextState(self.board, self.cur_player, offseted_action-1)

if self.num_steps > 9:
logger.info(f"[YFW]================={self.board[self.game.m:]}")
@@ -92,32 +102,19 @@ def step(self, action):
reward = 10 + (100 / self.game.getEvaluation(self.board))
else:
reward = 10 + (self.game.getEvaluation(self.board)) ** 2 * 100
if tuple(self.board[self.game.m :]) not in self.found:
self.found.add(tuple(self.board[self.game.m :]))
logger.debug(
f"[PIPELINE FOUND] {self.board[self.game.m:]} -> {reward}"
)
except:
logger.critical(f"[PIPELINE FOUND] Error happened")
elif game_end == 2: # finished but invalid
reward = 10
elif action == 0:
reward = 1

else:
# popped_step = self.step_stack.pop()
# if move_type == "ENSEMBLER":
# reward = 1
# if non_terminals_moves[0] != "E" and non_terminals_moves[0].upper() == non_terminals_moves[0]:
# self.step_stack.extend(non_terminals_moves)
# if move_type == "S":
# reward = 1
# elif move_type == "ENCODERS":
# reward = 1
# else:
# if move_type == "IMPUTER" or move_type == "CATEGORICAL_ENCODER":
# reward = 1
reward = random.uniform(0, 1)
reward = 1
# if move_string.upper() != move_string:
# reward = random.uniform(0, 1)
# else:
@@ -135,7 +132,7 @@


# done & truncated
truncated = self.num_steps >= 200
truncated = self.num_steps >= 20
done = game_end or truncated

return (
@@ -148,8 +145,31 @@

def extract_action_details(self, action):
rules = self.game.grammar["RULES"]
move_string = list(rules.keys())[list(rules.values()).index(action + 1)]
move_string = list(rules.keys())[list(rules.values()).index(action)]
split_move = move_string.split("->")
move_type = split_move[0].strip()
non_terminals_moves = split_move[1].strip().split(" ")
return move_type, non_terminals_moves

def generate_action_spaces(self):
action_spaces = {}
for action in self.game.grammar["RULES"].values():
move_type, non_terminals_moves = self.extract_action_details(action)

if move_type not in action_spaces:
action_spaces[move_type] = 1
else:
action_spaces[move_type] += 1

return action_spaces

def generate_action_offsets(self):
action_offsets = {}
for action in self.game.grammar["RULES"].values():
move_type, non_terminals_moves = self.extract_action_details(action)

if move_type not in action_offsets:
action_offsets[move_type] = action

return action_offsets
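
Note on the new action encoding: instead of exposing all 85 grammar rules as one flat `Discrete(85)` space, the environment now samples a local index from `Discrete(24)` and shifts it by a per-nonterminal offset to recover the global rule index, popping the current nonterminal from `step_stack`. Below is a minimal standalone sketch of that decoding; the toy grammar is hypothetical and only mimics the `RULES` format (rule string → 1-based index) used by `self.game.grammar`.

```python
# Toy grammar in the same shape as game.grammar["RULES"]; the rule
# strings and indices here are made up for illustration.
toy_rules = {
    "S -> IMPUTER ENCODERS ESTIMATOR": 1,
    "IMPUTER -> SimpleImputer": 2,
    "IMPUTER -> KNNImputer": 3,
    "ENCODERS -> OneHotEncoder": 4,
    "ESTIMATOR -> RandomForestClassifier": 5,
    "ESTIMATOR -> SVC": 6,
}

def extract_action_details(rules, action):
    # Inverse lookup: 1-based rule index -> ("LHS", ["RHS", "tokens"]).
    move_string = list(rules.keys())[list(rules.values()).index(action)]
    lhs, rhs = (part.strip() for part in move_string.split("->"))
    return lhs, rhs.split(" ")

def generate_action_spaces(rules):
    # How many productions each nonterminal owns (the valid local range).
    spaces = {}
    for action in rules.values():
        lhs, _ = extract_action_details(rules, action)
        spaces[lhs] = spaces.get(lhs, 0) + 1
    return spaces

def generate_action_offsets(rules):
    # First rule index seen per nonterminal; assumes each nonterminal's
    # rules are numbered contiguously and in ascending order.
    offsets = {}
    for action in rules.values():
        lhs, _ = extract_action_details(rules, action)
        offsets.setdefault(lhs, action)
    return offsets

spaces = generate_action_spaces(toy_rules)
offsets = generate_action_offsets(toy_rules)
curr_step, local_action = "ESTIMATOR", 1         # agent output in [0, max_actions)
assert local_action < spaces[curr_step]          # mirrors the illegal-action check
global_rule = offsets[curr_step] + local_action  # 5 + 1 -> "ESTIMATOR -> SVC"
print(extract_action_details(toy_rules, global_rule))
```

With this scheme, `max_actions = 24` only needs to upper-bound the largest per-nonterminal rule count; a local index at or beyond `action_spaces[curr_step]` is treated as illegal and ends the episode with a -1 reward.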

15 changes: 8 additions & 7 deletions alpha_automl/pipeline_search/RlLib.py
@@ -10,6 +10,7 @@
from ray.rllib.utils.checkpoints import get_checkpoint_info
from ray.tune.logger import pretty_print
from ray.tune.registry import get_trainable_cls
from ray import tune

from alpha_automl.pipeline_search.AlphaAutoMLEnv import AlphaAutoMLEnv

@@ -23,13 +24,12 @@ def pipeline_search_rllib(game, time_bound, save_checkpoint=False):
"""
Search for pipelines using Rllib
"""
ray.init(local_mode=True, num_cpus=4)
ray.init(local_mode=True, num_cpus=8)
num_cpus = int(ray.available_resources()["CPU"])
logger.debug("[RlLib] Ready")

# load checkpoint or create a new one
algo = load_rllib_checkpoint(game, num_rollout_workers=3)
# algo = load_rllib_checkpoint_dqn(game, num_rollout_workers=3)
algo = load_rllib_checkpoint(game, num_rollout_workers=7)
logger.debug("[RlLib] Create Algo object done")

# train model
@@ -58,7 +58,7 @@ def load_rllib_checkpoint(game, num_rollout_workers):
gamma=0.99,
clip_param=0.3,
kl_coeff=0.3,
entropy_coeff=0.01,
entropy_coeff=0.05,
train_batch_size=10000,
)
)
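
The hunk above shows only the training kwargs that changed (notably `entropy_coeff` raised from 0.01 to 0.05). For context, here is a hedged sketch of how the whole trainer might be assembled with RLlib's 2.x builder API; the `.environment()` and `.rollouts()` wiring is an assumption inferred from the surrounding code, not a verbatim copy of `load_rllib_checkpoint`.

```python
from ray.rllib.algorithms.ppo import PPOConfig

from alpha_automl.pipeline_search.AlphaAutoMLEnv import AlphaAutoMLEnv


def build_ppo(game):
    # Sketch only: mirrors the tuned values in the diff; the worker count
    # matches the num_rollout_workers=7 passed in pipeline_search_rllib.
    config = (
        PPOConfig()
        .environment(AlphaAutoMLEnv, env_config={"game": game})
        .rollouts(num_rollout_workers=7)
        .training(
            gamma=0.99,
            clip_param=0.3,
            kl_coeff=0.3,
            entropy_coeff=0.05,  # more exploration pressure than the old 0.01
            train_batch_size=10000,
        )
    )
    return config.build()
```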
@@ -80,6 +80,7 @@
# checkpoint_info = get_checkpoint_info(PATH_TO_CHECKPOINT)
return algo


def train_rllib_model(algo, time_bound, save_checkpoint=False):
timeout = time.time() + time_bound
result = algo.train()
@@ -129,7 +130,7 @@ def save_rllib_checkpoint(algo):
)


def dump_result_to_json(primitives, task_start, output_folder=None):
def dump_result_to_json(primitives, task_start, score, output_folder=None):
output_path = generate_json_path(output_folder)
# Read JSON data from input file
if not os.path.exists(output_path) or os.path.getsize(output_path) == 0:
@@ -145,7 +146,7 @@ def dump_result_to_json(primitives, task_start, score, output_folder=None):
# Check for duplicate elements
if primitives in data.values():
return
data[timestamp] = primitives
data[score] = primitives

# Write unique elements to output file
with open(output_path, "w") as f:
Expand All @@ -166,7 +167,7 @@ def read_result_to_pipeline(builder, output_folder=None):
data = json.load(f)

# Check for duplicate elements
for primitives in data.values():
for score, primitives in sorted(data.items(), reverse=True):
pipeline = builder.make_pipeline(primitives)
if pipeline:
pipelines.append(pipeline)
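
Note on the storage change: results are now keyed by score rather than by timestamp, and `read_result_to_pipeline` walks the entries in descending key order. Since JSON object keys are strings, `sorted(data.items(), reverse=True)` compares scores lexicographically, which coincides with numeric order only when scores share a comparable format; two distinct pipelines with an identical score also collide on the same key. A minimal round-trip sketch (helper names are hypothetical):

```python
import json
import os


def dump_scored_result(path, score, primitives):
    # Distills the new dump_result_to_json behavior: skip duplicate
    # pipelines, then store them under their score.
    data = {}
    if os.path.exists(path) and os.path.getsize(path) > 0:
        with open(path) as f:
            data = json.load(f)
    if primitives in data.values():
        return
    data[str(score)] = primitives  # JSON keys are always strings
    with open(path, "w") as f:
        json.dump(data, f)


def read_best_first(path):
    # Lexicographic sort on the string keys; for these sample scores it
    # matches numeric order, but mixed-width keys would need care.
    with open(path) as f:
        data = json.load(f)
    return [primitives for _, primitives in sorted(data.items(), reverse=True)]


dump_scored_result("results.json", 0.8123, ["SimpleImputer", "SVC"])
dump_scored_result("results.json", 0.7910, ["SimpleImputer", "RandomForestClassifier"])
print(read_best_first("results.json"))  # SVC pipeline first
```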
4 changes: 2 additions & 2 deletions alpha_automl/pipeline_synthesis/setup_search.py
@@ -120,7 +120,7 @@ def evaluate_pipeline(primitives, origin):
if alphaautoml_pipeline is not None:
score = alphaautoml_pipeline.get_score()
if score is not None:
dump_result_to_json(primitives, task_start, output_folder)
dump_result_to_json(primitives, task_start, score, output_folder)
return score

if task_name is None:
@@ -168,7 +168,7 @@

logger.debug('Search completed')

return read_result_to_pipeline(builder, output_folder)
return read_result_to_pipeline(builder, output_folder)[:20]
# queue.put('DONE')


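
Taken together with the sorted read above, the new `[:20]` slice means the search now returns at most the 20 best-scoring pipelines. A toy illustration of the cap (names made up):

```python
# read_result_to_pipeline yields pipelines best-first, so slicing keeps
# the top 20; shorter result lists pass through unchanged.
ranked_pipelines = [f"pipeline_{rank}" for rank in range(35)]
top_20 = ranked_pipelines[:20]
assert len(top_20) == 20 and top_20[0] == "pipeline_0"
```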
