From c9250eba5c60a0bd02ba234af27b47f655582e5e Mon Sep 17 00:00:00 2001
From: EdenWuyifan
Date: Fri, 5 Apr 2024 13:45:45 -0400
Subject: [PATCH] tmp

---
 alpha_automl/pipeline_search/RlLib.py | 18 ++++++++++++++++--
 1 file changed, 16 insertions(+), 2 deletions(-)

diff --git a/alpha_automl/pipeline_search/RlLib.py b/alpha_automl/pipeline_search/RlLib.py
index 9b70f7df..0aa77e68 100644
--- a/alpha_automl/pipeline_search/RlLib.py
+++ b/alpha_automl/pipeline_search/RlLib.py
@@ -6,6 +6,7 @@
 from datetime import datetime
 
 import ray
+from ray.rllib.policy import Policy
 from ray.rllib.utils.checkpoints import get_checkpoint_info
 from ray.tune.logger import pretty_print
 from ray.tune.registry import get_trainable_cls
@@ -70,9 +71,11 @@ def load_rllib_checkpoint(game, num_rollout_workers):
         return config.build()
     else:
         algo = config.build()
-
+        weights = load_rllib_policy_weights()
+
+        algo.set_weights(weights)
         # Restore the old (checkpointed) state.
-        algo.restore(PATH_TO_CHECKPOINT)
+        # algo.restore(PATH_TO_CHECKPOINT)
         # checkpoint_info = get_checkpoint_info(PATH_TO_CHECKPOINT)
 
         return algo
@@ -91,6 +94,8 @@ def train_rllib_model(algo, time_bound, save_checkpoint=False):
         ):
             logger.info(f"[RlLib] Train Timeout")
             break
+        weights = load_rllib_policy_weights()
+        algo.set_weights(weights)
         result = algo.train()
         logger.info(pretty_print(result))
         # stop training of the target train steps or reward are reached
@@ -104,6 +109,15 @@ def train_rllib_model(algo, time_bound, save_checkpoint=False):
     algo.stop()
 
 
+def load_rllib_policy_weights():
+    logger.info(f"[RlLib] Synchronizing model weights...")
+    policy = Policy.from_checkpoint(PATH_TO_CHECKPOINT)
+    policy = policy['default_policy']
+    weights = policy.get_weights()
+
+    weights = {'default_policy': weights}
+    return weights
+
 def save_rllib_checkpoint(algo):
     save_result = algo.save(checkpoint_dir=PATH_TO_CHECKPOINT)
     path_to_checkpoint = save_result.checkpoint.path
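
For context on the pattern this patch introduces: instead of algo.restore(PATH_TO_CHECKPOINT), which restores the full Algorithm state, the new code builds a fresh Algorithm and copies in only the policy weights read from the checkpoint. A minimal standalone sketch of that weights-only sync, assuming Ray RLlib 2.x and a hypothetical PPO/CartPole setup and checkpoint path (none of these names appear in the patch itself):

    from ray.rllib.algorithms.ppo import PPOConfig
    from ray.rllib.policy import Policy

    CHECKPOINT_PATH = "/tmp/rllib_checkpoint"  # hypothetical; stands in for PATH_TO_CHECKPOINT

    # Build a fresh Algorithm rather than restoring the full checkpointed state.
    algo = PPOConfig().environment("CartPole-v1").build()

    # Called on an Algorithm checkpoint, Policy.from_checkpoint() returns a
    # dict mapping policy IDs to Policy objects.
    policies = Policy.from_checkpoint(CHECKPOINT_PATH)
    weights = policies["default_policy"].get_weights()

    # Algorithm.set_weights() expects a {policy_id: weights} mapping.
    algo.set_weights({"default_policy": weights})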
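
The two lines added inside train_rllib_model() perform the same sync at the top of every training iteration, so a long-running trainer keeps following the latest shared checkpoint. A rough sketch of that loop shape, assuming time_bound is in seconds and reusing the load_rllib_policy_weights() helper added by this patch:

    import time

    from ray.tune.logger import pretty_print

    def train_with_periodic_sync(algo, time_bound):
        start = time.time()
        while time.time() - start < time_bound:
            # Re-sync policy weights from the shared checkpoint each iteration.
            algo.set_weights(load_rllib_policy_weights())
            result = algo.train()
            print(pretty_print(result))
        algo.stop()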