Skip to content

Commit

Permalink
Update Categorical_DQN.py
Browse files Browse the repository at this point in the history
fix map_location
  • Loading branch information
XinJingHao authored Jun 8, 2024
1 parent 987e83b commit a464e0a
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions 2.4_Categorical-DQN_C51/Categorical_DQN.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,8 @@ def save(self,algo,EnvName,steps):
torch.save(self.q_net.state_dict(), "./model/{}_{}_{}k.pth".format(algo,EnvName,steps))

def load(self,algo,EnvName,steps):
self.q_net.load_state_dict(torch.load("./model/{}_{}_{}k.pth".format(algo,EnvName,steps)))
self.q_target.load_state_dict(torch.load("./model/{}_{}_{}k.pth".format(algo,EnvName,steps)))
self.q_net.load_state_dict(torch.load("./model/{}_{}_{}k.pth".format(algo,EnvName,steps), map_location=self.dvc))
self.q_target.load_state_dict(torch.load("./model/{}_{}_{}k.pth".format(algo,EnvName,steps), map_location=self.dvc))


class ReplayBuffer(object):
Expand Down

0 comments on commit a464e0a

Please sign in to comment.