diff --git a/4.2 TD3/TD3.py b/4.2 TD3/TD3.py index 06293fe..5498fb8 100644 --- a/4.2 TD3/TD3.py +++ b/4.2 TD3/TD3.py @@ -80,8 +80,8 @@ def save(self,EnvName, timestep): torch.save(self.q_critic.state_dict(), "./model/{}_q_critic{}.pth".format(EnvName,timestep)) def load(self,EnvName, timestep): - self.actor.load_state_dict(torch.load("./model/{}_actor{}.pth".format(EnvName, timestep))) - self.q_critic.load_state_dict(torch.load("./model/{}_q_critic{}.pth".format(EnvName, timestep))) + self.actor.load_state_dict(torch.load("./model/{}_actor{}.pth".format(EnvName, timestep), map_location=self.dvc)) + self.q_critic.load_state_dict(torch.load("./model/{}_q_critic{}.pth".format(EnvName, timestep), map_location=self.dvc)) class ReplayBuffer():