Update SACD.py
simplify training of actor
XinJingHao authored Oct 23, 2024
1 parent d3e6408 commit 677e12b
Showing 1 changed file with 1 addition and 8 deletions.
5.1 SAC-Discrete/SACD.py: 1 addition & 8 deletions
@@ -60,25 +60,18 @@ def train(self):
 		self.q_critic_optimizer.step()
 
 		#------------------------------------------ Train Actor ----------------------------------------#
-		for params in self.q_critic.parameters():
-			#Freeze Q net, so you don't waste time on computing its gradient while updating Actor.
-			params.requires_grad = False
-
 		probs = self.actor(s) #[b,a_dim]
 		log_probs = torch.log(probs + 1e-8) #[b,a_dim]
 		with torch.no_grad():
 			q1_all, q2_all = self.q_critic(s) #[b,a_dim]
 			min_q_all = torch.min(q1_all, q2_all)
 
-		a_loss = torch.sum(probs * (self.alpha*log_probs - min_q_all), dim=1, keepdim=True) #[b,1]
+		a_loss = torch.sum(probs * (self.alpha*log_probs - min_q_all), dim=1, keepdim=False) #[b,]
 
 		self.actor_optimizer.zero_grad()
 		a_loss.mean().backward()
 		self.actor_optimizer.step()
 
-		for params in self.q_critic.parameters():
-			params.requires_grad = True
-
 		#------------------------------------------ Train Alpha ----------------------------------------#
 		if self.adaptive_alpha:
 			with torch.no_grad():
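Why the freeze/unfreeze loop could be dropped: the twin-critic forward pass already runs under torch.no_grad(), so autograd never records a graph through self.q_critic during the actor update, and toggling requires_grad on its parameters is redundant. The keepdim change only alters the loss shape from [b,1] to [b,]; a_loss.mean() gives the same scalar either way. Below is a minimal, self-contained sketch of the simplified actor update under assumed stand-in networks (actor, q_critic) and a fixed alpha; the shapes mirror the diff, but these names and architectures are illustrative, not the repository's actual classes.

import torch
import torch.nn as nn

# Hypothetical stand-ins for the repo's networks (shapes only, not its real API).
b, s_dim, a_dim, alpha = 32, 8, 4, 0.2
actor = nn.Sequential(nn.Linear(s_dim, a_dim), nn.Softmax(dim=-1))  # probs: [b, a_dim]
q_critic = nn.Linear(s_dim, 2 * a_dim)                              # two Q heads, flattened
actor_optimizer = torch.optim.Adam(actor.parameters(), lr=3e-4)

s = torch.randn(b, s_dim)  # dummy batch of states

probs = actor(s)                     # [b, a_dim]
log_probs = torch.log(probs + 1e-8)  # [b, a_dim]

# The critic forward runs under no_grad, so no gradient can reach q_critic:
# this is what makes the old requires_grad freeze/unfreeze unnecessary.
with torch.no_grad():
    q1_all, q2_all = q_critic(s).chunk(2, dim=-1)  # each [b, a_dim]
    min_q_all = torch.min(q1_all, q2_all)          # [b, a_dim]

# Expected entropy-regularized loss over the discrete action distribution.
# keepdim=False yields shape [b,]; the .mean() below is identical either way.
a_loss = torch.sum(probs * (alpha * log_probs - min_q_all), dim=1, keepdim=False)

actor_optimizer.zero_grad()
a_loss.mean().backward()  # gradients flow only into the actor
actor_optimizer.step()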
