diff --git a/alpha_automl/pipeline_search/agent_lab.py b/alpha_automl/pipeline_search/agent_lab.py index 1efd9297..a0661f2d 100644 --- a/alpha_automl/pipeline_search/agent_lab.py +++ b/alpha_automl/pipeline_search/agent_lab.py @@ -21,11 +21,11 @@ def pipeline_search_rllib(game, time_bound, checkpoint_load_folder, checkpoint_s ray.init(local_mode=True, logging_level=logging.CRITICAL) num_cpus = int(ray.available_resources()["CPU"]) - # load checkpoint or create a new one + # Load checkpoint or create a new one algo = load_rllib_checkpoint(game, checkpoint_load_folder, num_rollout_workers=1) logger.debug("Create Algo object done") - # train model + # Train model train_rllib_model(algo, time_bound, checkpoint_load_folder, checkpoint_save_folder) logger.debug("Training done") ray.shutdown() @@ -83,12 +83,12 @@ def train_rllib_model(algo, time_bound, checkpoint_load_folder, checkpoint_save_ if ( time.time() > timeout or (best_unchanged_iter >= 10 and result["episode_reward_mean"] >= 0) - # or result["episode_reward_mean"] >= 70 ): logger.debug(f"Training timeout reached") break if contain_checkpoints(checkpoint_save_folder): + # Load the most recent weights weights = load_rllib_policy_weights(checkpoint_save_folder) algo.set_weights(weights) elif contain_checkpoints(checkpoint_load_folder): @@ -96,7 +96,8 @@ def train_rllib_model(algo, time_bound, checkpoint_load_folder, checkpoint_save_ algo.set_weights(weights) result = algo.train() logger.debug(pretty_print(result)) - # stop training of the target train steps or reward are reached + + # Stop training of the target train steps or reward are reached if result["episode_reward_mean"] > last_best: last_best = result["episode_reward_mean"] best_unchanged_iter = 1 diff --git a/alpha_automl/pipeline_synthesis/setup_search.py b/alpha_automl/pipeline_synthesis/setup_search.py index d4911105..de19a054 100644 --- a/alpha_automl/pipeline_synthesis/setup_search.py +++ b/alpha_automl/pipeline_synthesis/setup_search.py @@ -92,7 +92,7 @@ def evaluate_pipeline(primitives): checkpoint_save_folder = ( checkpoints_folder if checkpoints_folder is not None - else DEFAULT_CHECKPOINT_PATH + else output_folder ) game = PipelineGame(config_updated, evaluate_pipeline) pipeline_search_rllib(