Skip to content

Commit

Permalink
Merge pull request #24 from OpenBioML/fix-checkpointing
Browse files Browse the repository at this point in the history
Fix checkpointing in fig 3 and 4
  • Loading branch information
zcqsntr authored May 12, 2023
2 parents 480cab5 + 8990fe6 commit 50cd5a9
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 75 deletions.
91 changes: 56 additions & 35 deletions RED/agents/continuous_agents/rt3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,42 +734,63 @@ def Q_update(self, recurrent=True, monte_carlo=False, policy=True, verbose=False
self.update_target_network(source=self.Q2_network, target=self.Q2_target, tau=self.polyak)
self.update_target_network(source=self.policy_network, target=self.policy_target, tau=self.polyak)

def save_network(self, save_path):
    '''
    Writes the three online networks and their three target networks into save_path,
    one "<name>.pth" file per network (each file holds the whole module object).
    :param save_path: directory to save networks to
    '''

    # (file name, attribute name) pairs, kept in the order the files are written
    network_files = (
        ("policy_network.pth", "policy_network"),
        ("Q1_network.pth", "Q1_network"),
        ("Q2_network.pth", "Q2_network"),
        ("policy_target.pth", "policy_target"),
        ("Q1_target.pth", "Q1_target"),
        ("Q2_target.pth", "Q2_target"),
    )

    for file_name, attribute in network_files:
        destination = os.path.join(save_path, file_name)
        torch.save(getattr(self, attribute), destination)

def load_network(self, load_path, load_target_networks=False):
    '''
    Loads networks from directory specified by load_path and creates a fresh Adam
    optimizer for each loaded online network.
    :param load_path: directory to load networks from (the "<name>.pth" files written by save_network)
    :param load_target_networks: whether to load target networks
    '''

    def _load_module(file_name):
        # The .pth files hold whole pickled modules (save_network passes the module
        # itself to torch.save), so the restricted weights-only unpickler cannot
        # read them; request full unpickling explicitly.
        # NOTE: full unpickling can execute arbitrary code - only load trusted files.
        return torch.load(os.path.join(load_path, file_name), weights_only=False)

    self.policy_network = _load_module("policy_network.pth")
    self.policy_network_opt = Adam(self.policy_network.parameters(), lr=self.pol_learning_rate)

    self.Q1_network = _load_module("Q1_network.pth")
    self.Q1_network_opt = Adam(self.Q1_network.parameters(), lr=self.val_learning_rate)

    self.Q2_network = _load_module("Q2_network.pth")
    # the optimizer has to be stored under Q2_network_opt; a misspelled attribute
    # would leave the previous optimizer attached to the replaced network's parameters
    self.Q2_network_opt = Adam(self.Q2_network.parameters(), lr=self.val_learning_rate)

    if load_target_networks:
        self.policy_target = _load_module("policy_target.pth")
        self.Q1_target = _load_module("Q1_target.pth")
        self.Q2_target = _load_module("Q2_target.pth")

def save_ckpt(self, save_path, additional_info=None):
    '''
    Writes one checkpoint file holding the state dicts of all networks and optimizers,
    the agent's buffers and any caller-supplied extra information.
    :param save_path: path to save the checkpoint to
    :param additional_info: additional information to save (Python dictionary)
    '''
    ### objects stored through their state_dict(), listed in checkpoint key order
    stateful_names = (
        "policy_network", "Q1_network", "Q2_network",
        "policy_target", "Q1_target", "Q2_target",
        "policy_network_opt", "Q1_network_opt", "Q2_network_opt",
    )
    ### buffers stored as-is
    buffer_names = (
        "memory", "values", "states", "next_states", "actions", "rewards", "dones",
        "sequences", "next_sequences", "all_returns",
    )

    checkpoint = {name: getattr(self, name).state_dict() for name in stateful_names}
    # a missing additional_info is stored as an empty dict so the key is always present
    checkpoint["additional_info"] = {} if additional_info is None else additional_info
    for name in buffer_names:
        checkpoint[name] = getattr(self, name)

    ### write everything to a single file
    torch.save(checkpoint, save_path)

def load_ckpt(self, load_path, load_target_networks=True):
    '''
    Loads a full checkpoint (networks, optimizers, memory buffers) from the specified path.
    The networks and optimizers must already exist on the agent; their state is
    restored in place with load_state_dict.
    :param load_path: path to load the checkpoint from (the single file written by save_ckpt)
    :param load_target_networks: whether to load the target networks as well
    :return: the loaded checkpoint dictionary (includes the "additional_info" entry)
    '''
    # save_ckpt stores the buffers as-is, so the file may contain arbitrary Python
    # objects that the restricted weights-only unpickler rejects; request full
    # unpickling explicitly.
    # NOTE: full unpickling can execute arbitrary code - only load trusted checkpoints.
    ckpt = torch.load(load_path, weights_only=False)

    ### load networks
    self.policy_network.load_state_dict(ckpt["policy_network"])
    self.Q1_network.load_state_dict(ckpt["Q1_network"])
    self.Q2_network.load_state_dict(ckpt["Q2_network"])

    ### load target networks
    if load_target_networks:
        # the targets live inside the checkpoint dict; load_path is a file, not a
        # directory containing separate "*_target.pth" files
        self.policy_target.load_state_dict(ckpt["policy_target"])
        self.Q1_target.load_state_dict(ckpt["Q1_target"])
        self.Q2_target.load_state_dict(ckpt["Q2_target"])
    else:
        print("[WARNING] Not loading target networks")

    ### load optimizers
    self.policy_network_opt.load_state_dict(ckpt["policy_network_opt"])
    self.Q1_network_opt.load_state_dict(ckpt["Q1_network_opt"])
    self.Q2_network_opt.load_state_dict(ckpt["Q2_network_opt"])

    ### load buffers
    for buffer in ("memory", "values", "states", "next_states", "actions", "rewards", "dones",
                   "sequences", "next_sequences", "all_returns"):
        setattr(self, buffer, ckpt[buffer])

    return ckpt

def reset_weights(self, policy=True):
'''
Expand Down
1 change: 1 addition & 0 deletions RED/configs/example/Figure_3_RT3D_chemostat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ explore_rate_mul: 1
test_episode: False
save_path: ${hydra:run.dir}
ckpt_freq: 50
load_ckpt_dir_path: null # directory containing agent's checkpoint to load ("agent.pt") + optionally "history.json" from which to resume training

model:
batch_size: ${eval:'${example.environment.N_control_intervals} * ${example.environment.n_parallel_experiments}'}
Expand Down
1 change: 1 addition & 0 deletions RED/configs/example/Figure_4_RT3D_chemostat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ explore_rate_mul: 1
test_episode: False
save_path: ${hydra:run.dir}
ckpt_freq: 50
load_ckpt_dir_path: null # directory containing agent's checkpoint to load ("agent.pt") + optionally "history.json" from which to resume training

model:
batch_size: ${eval:'${example.environment.N_control_intervals} * ${example.environment.n_parallel_experiments}'}
Expand Down
68 changes: 48 additions & 20 deletions examples/Figure_3_RT3D_chemostat/train_RT3D.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@

import json
import math
import os
import sys
Expand Down Expand Up @@ -48,12 +49,35 @@ def train_RT3D(cfg : DictConfig):
env, n_params = setup_env(cfg)
total_episodes = cfg.environment.n_episodes // cfg.environment.n_parallel_experiments
skip_first_n_episodes = cfg.environment.skip_first_n_experiments // cfg.environment.n_parallel_experiments

history = {k: [] for k in ["returns", "actions", "rewards", "us", "explore_rate"]}
update_count = 0
starting_episode = 0

history = {k: [] for k in ["returns", "actions", "rewards", "us", "explore_rate", "update_count"]}

### load ckpt
if cfg.load_ckpt_dir_path is not None:
print(f"Loading checkpoint from: {cfg.load_ckpt_dir_path}")
# load the agent
agent_path = os.path.join(cfg.load_ckpt_dir_path, "agent.pt")
print(f"Loading agent from: {agent_path}")
additional_info = agent.load_ckpt(
load_path=agent_path,
load_target_networks=True,
)["additional_info"]
# load history
history_path = os.path.join(cfg.load_ckpt_dir_path, "history.json")
if os.path.exists(history_path):
print(f"Loading history from: {history_path}")
with open(history_path, "r") as f:
history = json.load(f)
# load explore rate
if "explore_rate" in history and len(history["explore_rate"]) > 0:
explore_rate = history["explore_rate"][-1]
# load starting episode
if "episode" in additional_info:
starting_episode = additional_info["episode"] + 1

### training loop
for episode in range(total_episodes):
for episode in range(starting_episode, total_episodes):
actual_params = np.random.uniform(
low=cfg.environment.actual_params,
high=cfg.environment.actual_params,
Expand Down Expand Up @@ -108,11 +132,10 @@ def train_RT3D(cfg : DictConfig):
sequences[i].append(np.concatenate((state, action)))

### log episode data
e_us[i].append(u)
e_us[i].append(u.tolist())
next_states.append(next_state)
if reward != -1: # dont include the unstable trajectories as they override the true return
e_rewards[i].append(reward)
e_returns[i] += reward
e_rewards[i].append(reward)
e_returns[i] += reward
states = next_states

### do not memorize the test trajectory (the last one)
Expand All @@ -129,9 +152,11 @@ def train_RT3D(cfg : DictConfig):
### train agent
if episode > skip_first_n_episodes:
for _ in range(cfg.environment.n_parallel_experiments):
update_count += 1
update_policy = update_count % cfg.policy_delay == 0
history["update_count"].append(history["update_count"][-1] + 1 if len(history["update_count"]) > 0 else 1)
update_policy = history["update_count"][-1] % cfg.policy_delay == 0
agent.Q_update(policy=update_policy, recurrent=True)
else:
history["update_count"].append(history["update_count"][-1] if len(history["update_count"]) > 0 else 0)

### update explore rate
explore_rate = cfg.explore_rate_mul * agent.get_rate(
Expand All @@ -143,7 +168,7 @@ def train_RT3D(cfg : DictConfig):

### log results
history["returns"].extend(e_returns)
history["actions"].extend(np.array(e_actions).transpose(1, 0, 2))
history["actions"].extend(np.array(e_actions).transpose(1, 0, 2).tolist())
history["rewards"].extend(e_rewards)
history["us"].extend(e_us)
history["explore_rate"].append(explore_rate)
Expand All @@ -164,17 +189,20 @@ def train_RT3D(cfg : DictConfig):
)

### checkpoint
if cfg.ckpt_freq is not None and episode % cfg.ckpt_freq == 0:
if (cfg.ckpt_freq is not None and episode % cfg.ckpt_freq == 0) \
or episode == total_episodes - 1:
ckpt_dir = os.path.join(cfg.save_path, f"ckpt_{episode}")
os.makedirs(ckpt_dir, exist_ok=True)
agent.save_network(ckpt_dir)
for k in history.keys():
np.save(os.path.join(ckpt_dir, f"{k}.npy"), np.array(history[k]))

### save results and plot
agent.save_network(cfg.save_path)
for k in history.keys():
np.save(os.path.join(cfg.save_path, f"{k}.npy"), np.array(history[k]))
agent.save_ckpt(
save_path=os.path.join(ckpt_dir, "agent.pt"),
additional_info={
"episode": episode,
}
)
with open(os.path.join(ckpt_dir, "history.json"), "w") as f:
json.dump(history, f)

### plot
plot_returns(
returns=history["returns"],
explore_rates=history["explore_rate"],
Expand Down
68 changes: 48 additions & 20 deletions examples/Figure_4_RT3D_chemostat/train_RT3D.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@

import json
import math
import os
import sys
Expand Down Expand Up @@ -47,12 +48,35 @@ def train_RT3D(cfg : DictConfig):
env, n_params = setup_env(cfg)
total_episodes = cfg.environment.n_episodes // cfg.environment.n_parallel_experiments
skip_first_n_episodes = cfg.environment.skip_first_n_experiments // cfg.environment.n_parallel_experiments

history = {k: [] for k in ["returns", "actions", "rewards", "us", "explore_rate"]}
update_count = 0
starting_episode = 0

history = {k: [] for k in ["returns", "actions", "rewards", "us", "explore_rate", "update_count"]}

### load ckpt
if cfg.load_ckpt_dir_path is not None:
print(f"Loading checkpoint from: {cfg.load_ckpt_dir_path}")
# load the agent
agent_path = os.path.join(cfg.load_ckpt_dir_path, "agent.pt")
print(f"Loading agent from: {agent_path}")
additional_info = agent.load_ckpt(
load_path=agent_path,
load_target_networks=True,
)["additional_info"]
# load history
history_path = os.path.join(cfg.load_ckpt_dir_path, "history.json")
if os.path.exists(history_path):
print(f"Loading history from: {history_path}")
with open(history_path, "r") as f:
history = json.load(f)
# load explore rate
if "explore_rate" in history and len(history["explore_rate"]) > 0:
explore_rate = history["explore_rate"][-1]
# load starting episode
if "episode" in additional_info:
starting_episode = additional_info["episode"] + 1

### training loop
for episode in range(total_episodes):
for episode in range(starting_episode, total_episodes):
# sample params from uniform distribution
actual_params = np.random.uniform(
low=cfg.environment.lb,
Expand Down Expand Up @@ -108,11 +132,10 @@ def train_RT3D(cfg : DictConfig):
sequences[i].append(np.concatenate((state, action)))

### log episode data
e_us[i].append(u)
e_us[i].append(u.tolist())
next_states.append(next_state)
if reward != -1: # dont include the unstable trajectories as they override the true return
e_rewards[i].append(reward)
e_returns[i] += reward
e_rewards[i].append(reward)
e_returns[i] += reward
states = next_states

### do not memorize the test trajectory (the last one)
Expand All @@ -129,9 +152,11 @@ def train_RT3D(cfg : DictConfig):
### train agent
if episode > skip_first_n_episodes:
for _ in range(cfg.environment.n_parallel_experiments):
update_count += 1
update_policy = update_count % cfg.policy_delay == 0
history["update_count"].append(history["update_count"][-1] + 1 if len(history["update_count"]) > 0 else 1)
update_policy = history["update_count"][-1] % cfg.policy_delay == 0
agent.Q_update(policy=update_policy, recurrent=True)
else:
history["update_count"].append(history["update_count"][-1] if len(history["update_count"]) > 0 else 0)

### update explore rate
explore_rate = cfg.explore_rate_mul * agent.get_rate(
Expand All @@ -143,7 +168,7 @@ def train_RT3D(cfg : DictConfig):

### log results
history["returns"].extend(e_returns)
history["actions"].extend(np.array(e_actions).transpose(1, 0, 2))
history["actions"].extend(np.array(e_actions).transpose(1, 0, 2).tolist())
history["rewards"].extend(e_rewards)
history["us"].extend(e_us)
history["explore_rate"].append(explore_rate)
Expand All @@ -164,17 +189,20 @@ def train_RT3D(cfg : DictConfig):
)

### checkpoint
if cfg.ckpt_freq is not None and episode % cfg.ckpt_freq == 0:
if (cfg.ckpt_freq is not None and episode % cfg.ckpt_freq == 0) \
or episode == total_episodes - 1:
ckpt_dir = os.path.join(cfg.save_path, f"ckpt_{episode}")
os.makedirs(ckpt_dir, exist_ok=True)
agent.save_network(ckpt_dir)
for k in history.keys():
np.save(os.path.join(ckpt_dir, f"{k}.npy"), np.array(history[k]))

### save results and plot
agent.save_network(cfg.save_path)
for k in history.keys():
np.save(os.path.join(cfg.save_path, f"{k}.npy"), np.array(history[k]))
agent.save_ckpt(
save_path=os.path.join(ckpt_dir, "agent.pt"),
additional_info={
"episode": episode,
}
)
with open(os.path.join(ckpt_dir, "history.json"), "w") as f:
json.dump(history, f)

### plot
plot_returns(
returns=history["returns"],
explore_rates=history["explore_rate"],
Expand Down

0 comments on commit 50cd5a9

Please sign in to comment.