From 677e12b4ad45afa1819e863561eb12dc5eec0cce Mon Sep 17 00:00:00 2001
From: XinJingHao <75819608+XinJingHao@users.noreply.github.com>
Date: Wed, 23 Oct 2024 09:41:57 +0800
Subject: [PATCH] Update SACD.py

simplify training of actor
---
 5.1 SAC-Discrete/SACD.py | 9 +--------
 1 file changed, 1 insertion(+), 8 deletions(-)

diff --git a/5.1 SAC-Discrete/SACD.py b/5.1 SAC-Discrete/SACD.py
index 34ef353..b36f625 100644
--- a/5.1 SAC-Discrete/SACD.py
+++ b/5.1 SAC-Discrete/SACD.py
@@ -60,25 +60,18 @@ def train(self):
 		self.q_critic_optimizer.step()
 
 		#------------------------------------------ Train Actor ----------------------------------------#
-		for params in self.q_critic.parameters():
-			#Freeze Q net, so you don't waste time on computing its gradient while updating Actor.
-			params.requires_grad = False
-
 		probs = self.actor(s) #[b,a_dim]
 		log_probs = torch.log(probs + 1e-8) #[b,a_dim]
 		with torch.no_grad():
 			q1_all, q2_all = self.q_critic(s) #[b,a_dim]
 		min_q_all = torch.min(q1_all, q2_all)
 
-		a_loss = torch.sum(probs * (self.alpha*log_probs - min_q_all), dim=1, keepdim=True) #[b,1]
+		a_loss = torch.sum(probs * (self.alpha*log_probs - min_q_all), dim=1, keepdim=False) #[b,]
 
 		self.actor_optimizer.zero_grad()
 		a_loss.mean().backward()
 		self.actor_optimizer.step()
 
-		for params in self.q_critic.parameters():
-			params.requires_grad = True
-
 		#------------------------------------------ Train Alpha ----------------------------------------#
 		if self.adaptive_alpha:
 			with torch.no_grad():
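
Note on the simplification: the retained code evaluates the Q networks inside torch.no_grad(), so no gradients can flow into q_critic during the actor update and the explicit requires_grad freeze/unfreeze loops are redundant. Below is a minimal, illustrative sketch of the "Train Actor" step as it reads after this patch, rewritten as a standalone function; the function name actor_update and its signature are assumed for illustration, while the tensor names, shapes, and loss form follow the surrounding SACD.py.

import torch

def actor_update(actor, q_critic, actor_optimizer, alpha, s):
    # Actor forward pass: action probabilities over the discrete action set.
    probs = actor(s)                          # [b, a_dim]
    log_probs = torch.log(probs + 1e-8)       # [b, a_dim]

    # Critic evaluated without gradient tracking; freezing q_critic's
    # parameters is unnecessary because no gradient flows out of no_grad.
    with torch.no_grad():
        q1_all, q2_all = q_critic(s)          # [b, a_dim] each
    min_q_all = torch.min(q1_all, q2_all)     # elementwise min of the two critics

    # Per-state loss: sum_a pi(a|s) * (alpha * log pi(a|s) - min_i Q_i(s,a))
    a_loss = torch.sum(probs * (alpha * log_probs - min_q_all), dim=1)  # [b]

    actor_optimizer.zero_grad()
    a_loss.mean().backward()
    actor_optimizer.step()
    return a_loss.mean().item()

With keepdim=False the per-sample loss has shape [b] and a_loss.mean() reduces it directly to a scalar; the previous keepdim=True form gave [b, 1], whose mean is the same scalar, so the change only simplifies the shape bookkeeping.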