Skip to content

Commit

Permalink
fix clipping
Browse files Browse the repository at this point in the history
  • Loading branch information
hasan-yaman committed Oct 17, 2024
1 parent 7a57ab3 commit f69e2ac
Showing 1 changed file with 5 additions and 8 deletions.
13 changes: 5 additions & 8 deletions d3rlpy/algos/qlearning/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,14 +160,11 @@ def reset_optimizer_states(self) -> None:

def clip_gradients(self) -> None:
if self.clip_gradient_norm is not None:
parameters = [
param.parameters()
for param in self.modules.get_torch_modules().values()
]
torch.nn.utils.clip_grad_norm_(
*parameters,
torch.tensor(data=self.clip_gradient_norm, device=self.device),
)
for module in self.modules.get_torch_modules().values():
torch.nn.utils.clip_grad_norm_(
parameters=module.parameters(),
max_norm=self.clip_gradient_norm,
)


TQLearningImpl = TypeVar("TQLearningImpl", bound=QLearningAlgoImplBase)
Expand Down

0 comments on commit f69e2ac

Please sign in to comment.