Skip to content

Commit

Permalink
tmp
Browse files Browse the repository at this point in the history
  • Loading branch information
EdenWuyifan committed Apr 5, 2024
1 parent 460c44a commit c9250eb
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions alpha_automl/pipeline_search/RlLib.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from datetime import datetime

import ray
from ray.rllib.policy import Policy
from ray.rllib.utils.checkpoints import get_checkpoint_info
from ray.tune.logger import pretty_print
from ray.tune.registry import get_trainable_cls
Expand Down Expand Up @@ -70,9 +71,11 @@ def load_rllib_checkpoint(game, num_rollout_workers):
return config.build()
else:
algo = config.build()

weights = load_rllib_policy_weights()

algo.set_weights(weights)
# Restore the old (checkpointed) state.
algo.restore(PATH_TO_CHECKPOINT)
# algo.restore(PATH_TO_CHECKPOINT)
# checkpoint_info = get_checkpoint_info(PATH_TO_CHECKPOINT)
return algo

Expand All @@ -91,6 +94,8 @@ def train_rllib_model(algo, time_bound, save_checkpoint=False):
):
logger.info(f"[RlLib] Train Timeout")
break
weights = load_rllib_policy_weights()
algo.set_weights(weights)
result = algo.train()
logger.info(pretty_print(result))
# stop training of the target train steps or reward are reached
Expand All @@ -104,6 +109,15 @@ def train_rllib_model(algo, time_bound, save_checkpoint=False):
algo.stop()


def load_rllib_policy_weights():
logger.info(f"[RlLib] Synchronizing model weights...")
policy = Policy.from_checkpoint(PATH_TO_CHECKPOINT)
policy = policy['default_policy']
weights = policy.get_weights()

weights = {'default_policy': weights}
return weights

def save_rllib_checkpoint(algo):
save_result = algo.save(checkpoint_dir=PATH_TO_CHECKPOINT)
path_to_checkpoint = save_result.checkpoint.path
Expand Down

0 comments on commit c9250eb

Please sign in to comment.