Fix DiscreteSAC
takuseno committed Nov 2, 2024
1 parent 54443e4 commit 303ed52
Showing 1 changed file with 3 additions and 5 deletions.
d3rlpy/algos/qlearning/torch/sac_impl.py (8 changes: 3 additions & 5 deletions)
```diff
@@ -244,7 +244,7 @@ def compute_actor_grad(
     def update_actor(self, batch: TorchMiniBatch) -> Dict[str, float]:
         # Q function should be inference mode for stability
         self._modules.q_funcs.eval()
-        loss = self._compute_critic_grad(batch)
+        loss = self._compute_actor_grad(batch)
         self._modules.actor_optim.step()
         return {"actor_loss": float(loss["loss"].cpu().detach().numpy())}

@@ -259,7 +259,7 @@ def compute_actor_loss(

         loss = {}
         if self._modules.temp_optim:
-            loss.update(self.update_temp(batch, dist))
+            loss.update(self.update_temp(dist))

         log_probs = dist.logits
         probs = dist.probs
@@ -271,9 +271,7 @@ def compute_actor_loss(
         loss["loss"] = (probs * (entropy - q_t)).sum(dim=1).mean()
         return loss

-    def update_temp(
-        self, batch: TorchMiniBatch, dist: Categorical
-    ) -> Dict[str, torch.Tensor]:
+    def update_temp(self, dist: Categorical) -> Dict[str, torch.Tensor]:
         assert self._modules.temp_optim
         assert self._modules.log_temp is not None
         self._modules.temp_optim.zero_grad()
```
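
For context on the fix: in the first hunk, `update_actor` was computing gradients with `_compute_critic_grad` and then stepping the actor optimizer, so the actor was effectively updated with the critic's gradients (and the reported `actor_loss` was actually a critic loss). The remaining hunks drop the unused `batch` argument from `update_temp`: for a discrete policy, the relevant expectations are exact sums over the `Categorical` action distribution, so the temperature update needs only `dist`. Below is a minimal sketch of the two losses involved; it is not d3rlpy's actual implementation, and the function names, the standalone `log_temp` tensor, and the `target_entropy` parameter are illustrative assumptions.

```python
# Minimal sketch of the discrete SAC actor and temperature losses
# touched by this commit. Illustrative only, not d3rlpy's code:
# `log_temp` is assumed to be a learnable scalar log-temperature and
# `target_entropy` a fixed hyperparameter.
import torch
from torch.distributions import Categorical


def discrete_sac_actor_loss(
    dist: Categorical, q_t: torch.Tensor, log_temp: torch.Tensor
) -> torch.Tensor:
    # For a Categorical, dist.logits are normalized log-probabilities,
    # so the expectation over actions is an explicit sum rather than a
    # Monte Carlo estimate from sampled actions.
    log_probs = dist.logits  # (batch_size, n_actions)
    probs = dist.probs       # (batch_size, n_actions)
    entropy_term = log_temp.exp() * log_probs
    # Matches the loss line visible in the diff:
    # (probs * (entropy - q_t)).sum(dim=1).mean()
    return (probs * (entropy_term - q_t)).sum(dim=1).mean()


def discrete_sac_temp_loss(
    dist: Categorical, log_temp: torch.Tensor, target_entropy: float
) -> torch.Tensor:
    # One common temperature objective for discrete SAC. The policy
    # distribution alone determines the loss, which is why update_temp
    # no longer needs a TorchMiniBatch argument.
    with torch.no_grad():
        # E_a[log pi(a|s)] computed exactly over the action set.
        expected_log_probs = (dist.probs * dist.logits).sum(dim=1)
    return -(log_temp.exp() * (expected_log_probs + target_entropy)).mean()
```

For example, with a uniform policy whose entropy already exceeds `target_entropy`, the gradient of `discrete_sac_temp_loss` with respect to `log_temp` is positive, so gradient descent lowers the temperature, which is the intended behavior when the policy is more random than the target.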
