diff --git a/2.4_Categorical-DQN_C51/Categorical_DQN.py b/2.4_Categorical-DQN_C51/Categorical_DQN.py index 4ea59f6..fa41fd5 100644 --- a/2.4_Categorical-DQN_C51/Categorical_DQN.py +++ b/2.4_Categorical-DQN_C51/Categorical_DQN.py @@ -114,8 +114,8 @@ def save(self,algo,EnvName,steps): torch.save(self.q_net.state_dict(), "./model/{}_{}_{}k.pth".format(algo,EnvName,steps)) def load(self,algo,EnvName,steps): - self.q_net.load_state_dict(torch.load("./model/{}_{}_{}k.pth".format(algo,EnvName,steps))) - self.q_target.load_state_dict(torch.load("./model/{}_{}_{}k.pth".format(algo,EnvName,steps))) + self.q_net.load_state_dict(torch.load("./model/{}_{}_{}k.pth".format(algo,EnvName,steps), map_location=self.dvc)) + self.q_target.load_state_dict(torch.load("./model/{}_{}_{}k.pth".format(algo,EnvName,steps), map_location=self.dvc)) class ReplayBuffer(object):