diff --git a/d3rlpy/algos/qlearning/torch/awac_impl.py b/d3rlpy/algos/qlearning/torch/awac_impl.py index 09eed606..b489126c 100644 --- a/d3rlpy/algos/qlearning/torch/awac_impl.py +++ b/d3rlpy/algos/qlearning/torch/awac_impl.py @@ -49,7 +49,7 @@ def __init__( self._n_action_samples = n_action_samples def compute_actor_loss( - self, batch: TorchMiniBatch, action: ActionOutput + self, batch: TorchMiniBatch, action: ActionOutput, grad_step: int ) -> SACActorLoss: # compute log probability dist = build_gaussian_distribution(action) diff --git a/d3rlpy/algos/qlearning/torch/bc_impl.py b/d3rlpy/algos/qlearning/torch/bc_impl.py index 022fd2fd..cbea8f33 100644 --- a/d3rlpy/algos/qlearning/torch/bc_impl.py +++ b/d3rlpy/algos/qlearning/torch/bc_impl.py @@ -6,9 +6,11 @@ from torch.optim import Optimizer from ....dataclass_utils import asdict_as_float +from ....models import OptimizerWrapper from ....models.torch import ( CategoricalPolicy, DeterministicPolicy, + DiscreteImitationLoss, ImitationLoss, NormalPolicy, Policy, @@ -25,7 +27,7 @@ @dataclasses.dataclass(frozen=True) class BCBaseModules(Modules): - optim: Optimizer + optim: OptimizerWrapper class BCBaseImpl(QLearningAlgoImplBase, metaclass=ABCMeta): @@ -45,13 +47,15 @@ def __init__( device=device, ) - def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]: + def update_imitator( + self, batch: TorchMiniBatch, grad_step: int + ) -> Dict[str, float]: self._modules.optim.zero_grad() loss = self.compute_loss(batch.observations, batch.actions) loss.loss.backward() - self._modules.optim.step() + self._modules.optim.step(grad_step) return asdict_as_float(loss) @@ -72,7 +76,7 @@ def inner_predict_value( def inner_update( self, batch: TorchMiniBatch, grad_step: int ) -> Dict[str, float]: - return self.update_imitator(batch) + return self.update_imitator(batch, grad_step) @dataclasses.dataclass(frozen=True) @@ -105,12 +109,14 @@ def inner_predict_best_action(self, x: TorchObservation) -> torch.Tensor: def compute_loss( self, obs_t: TorchObservation, act_t: torch.Tensor - ) -> torch.Tensor: + ) -> ImitationLoss: if self._policy_type == "deterministic": + assert isinstance(self._modules.imitator, DeterministicPolicy) return compute_deterministic_imitation_loss( self._modules.imitator, obs_t, act_t ) elif self._policy_type == "stochastic": + assert isinstance(self._modules.imitator, NormalPolicy) return compute_stochastic_imitation_loss( self._modules.imitator, obs_t, act_t ) @@ -123,7 +129,7 @@ def policy(self) -> Policy: @property def policy_optim(self) -> Optimizer: - return self._modules.optim + return self._modules.optim.optim @dataclasses.dataclass(frozen=True) @@ -156,7 +162,7 @@ def inner_predict_best_action(self, x: TorchObservation) -> torch.Tensor: def compute_loss( self, obs_t: TorchObservation, act_t: torch.Tensor - ) -> torch.Tensor: + ) -> DiscreteImitationLoss: return compute_discrete_imitation_loss( policy=self._modules.imitator, x=obs_t, diff --git a/d3rlpy/algos/qlearning/torch/bcq_impl.py b/d3rlpy/algos/qlearning/torch/bcq_impl.py index ce1226ab..3ee3672b 100644 --- a/d3rlpy/algos/qlearning/torch/bcq_impl.py +++ b/d3rlpy/algos/qlearning/torch/bcq_impl.py @@ -4,8 +4,8 @@ import torch import torch.nn.functional as F -from torch.optim import Optimizer +from ....models import OptimizerWrapper from ....models.torch import ( ActionOutput, CategoricalPolicy, @@ -44,7 +44,7 @@ class BCQModules(DDPGBaseModules): targ_policy: DeterministicResidualPolicy vae_encoder: VAEEncoder vae_decoder: VAEDecoder - vae_optim: Optimizer + 
vae_optim: OptimizerWrapper class BCQImpl(DDPGBaseImpl): @@ -88,14 +88,16 @@ def __init__( self._rl_start_step = rl_start_step def compute_actor_loss( - self, batch: TorchMiniBatch, action: ActionOutput + self, batch: TorchMiniBatch, action: ActionOutput, grad_step: int ) -> DDPGBaseActorLoss: value = self._q_func_forwarder.compute_expected_q( batch.observations, action.squashed_mu, "none" ) return DDPGBaseActorLoss(-value[0].mean()) - def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]: + def update_imitator( + self, batch: TorchMiniBatch, grad_step: int + ) -> Dict[str, float]: self._modules.vae_optim.zero_grad() loss = compute_vae_error( vae_encoder=self._modules.vae_encoder, @@ -105,7 +107,7 @@ def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]: beta=self._beta, ) loss.backward() - self._modules.vae_optim.step() + self._modules.vae_optim.step(grad_step) return {"vae_loss": float(loss.cpu().detach().numpy())} def _repeat_observation(self, x: TorchObservation) -> TorchObservation: @@ -184,7 +186,7 @@ def inner_update( ) -> Dict[str, float]: metrics = {} - metrics.update(self.update_imitator(batch)) + metrics.update(self.update_imitator(batch, grad_step)) if grad_step < self._rl_start_step: return metrics @@ -201,8 +203,8 @@ def inner_update( action = self._modules.policy(batch.observations, sampled_action) # update models - metrics.update(self.update_critic(batch)) - metrics.update(self.update_actor(batch, action)) + metrics.update(self.update_critic(batch, grad_step)) + metrics.update(self.update_actor(batch, action, grad_step)) self.update_critic_target() self.update_actor_target() return metrics diff --git a/d3rlpy/algos/qlearning/torch/bear_impl.py b/d3rlpy/algos/qlearning/torch/bear_impl.py index 49e16784..e6d7a70d 100644 --- a/d3rlpy/algos/qlearning/torch/bear_impl.py +++ b/d3rlpy/algos/qlearning/torch/bear_impl.py @@ -2,8 +2,8 @@ from typing import Dict, Optional import torch -from torch.optim import Optimizer +from ....models import OptimizerWrapper from ....models.torch import ( ActionOutput, ContinuousEnsembleQFunctionForwarder, @@ -47,8 +47,8 @@ class BEARModules(SACModules): vae_encoder: VAEEncoder vae_decoder: VAEDecoder log_alpha: Parameter - vae_optim: Optimizer - alpha_optim: Optional[Optimizer] + vae_optim: OptimizerWrapper + alpha_optim: Optional[OptimizerWrapper] @dataclasses.dataclass(frozen=True) @@ -110,12 +110,12 @@ def __init__( self._warmup_steps = warmup_steps def compute_actor_loss( - self, batch: TorchMiniBatch, action: ActionOutput + self, batch: TorchMiniBatch, action: ActionOutput, grad_step: int ) -> BEARActorLoss: - loss = super().compute_actor_loss(batch, action) + loss = super().compute_actor_loss(batch, action, grad_step) mmd_loss = self._compute_mmd_loss(batch.observations) if self._modules.alpha_optim: - self.update_alpha(mmd_loss) + self.update_alpha(mmd_loss, grad_step) return BEARActorLoss( actor_loss=loss.actor_loss + mmd_loss, temp_loss=loss.temp_loss, @@ -124,11 +124,13 @@ def compute_actor_loss( alpha=get_parameter(self._modules.log_alpha).exp()[0][0], ) - def warmup_actor(self, batch: TorchMiniBatch) -> Dict[str, float]: + def warmup_actor( + self, batch: TorchMiniBatch, grad_step: int + ) -> Dict[str, float]: self._modules.actor_optim.zero_grad() loss = self._compute_mmd_loss(batch.observations) loss.backward() - self._modules.actor_optim.step() + self._modules.actor_optim.step(grad_step) return {"actor_loss": float(loss.cpu().detach().numpy())} def _compute_mmd_loss(self, obs_t: TorchObservation) -> 
torch.Tensor: @@ -136,11 +138,13 @@ def _compute_mmd_loss(self, obs_t: TorchObservation) -> torch.Tensor: alpha = get_parameter(self._modules.log_alpha).exp() return (alpha * (mmd - self._alpha_threshold)).mean() - def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]: + def update_imitator( + self, batch: TorchMiniBatch, grad_step: int + ) -> Dict[str, float]: self._modules.vae_optim.zero_grad() loss = self.compute_imitator_loss(batch) loss.backward() - self._modules.vae_optim.step() + self._modules.vae_optim.step(grad_step) return {"imitator_loss": float(loss.cpu().detach().numpy())} def compute_imitator_loss(self, batch: TorchMiniBatch) -> torch.Tensor: @@ -152,12 +156,12 @@ def compute_imitator_loss(self, batch: TorchMiniBatch) -> torch.Tensor: beta=self._vae_kl_weight, ) - def update_alpha(self, mmd_loss: torch.Tensor) -> None: + def update_alpha(self, mmd_loss: torch.Tensor, grad_step: int) -> None: assert self._modules.alpha_optim self._modules.alpha_optim.zero_grad() loss = -mmd_loss loss.backward(retain_graph=True) - self._modules.alpha_optim.step() + self._modules.alpha_optim.step(grad_step) # clip for stability get_parameter(self._modules.log_alpha).data.clamp_(-5.0, 10.0) @@ -274,13 +278,13 @@ def inner_update( self, batch: TorchMiniBatch, grad_step: int ) -> Dict[str, float]: metrics = {} - metrics.update(self.update_imitator(batch)) - metrics.update(self.update_critic(batch)) + metrics.update(self.update_imitator(batch, grad_step)) + metrics.update(self.update_critic(batch, grad_step)) if grad_step < self._warmup_steps: - actor_loss = self.warmup_actor(batch) + actor_loss = self.warmup_actor(batch, grad_step) else: action = self._modules.policy(batch.observations) - actor_loss = self.update_actor(batch, action) + actor_loss = self.update_actor(batch, action, grad_step) metrics.update(actor_loss) self.update_critic_target() return metrics diff --git a/d3rlpy/algos/qlearning/torch/cql_impl.py b/d3rlpy/algos/qlearning/torch/cql_impl.py index 5dfd059a..ce29f385 100644 --- a/d3rlpy/algos/qlearning/torch/cql_impl.py +++ b/d3rlpy/algos/qlearning/torch/cql_impl.py @@ -4,8 +4,8 @@ import torch import torch.nn.functional as F -from torch.optim import Optimizer +from ....models import OptimizerWrapper from ....models.torch import ( ContinuousEnsembleQFunctionForwarder, DiscreteEnsembleQFunctionForwarder, @@ -29,7 +29,7 @@ @dataclasses.dataclass(frozen=True) class CQLModules(SACModules): log_alpha: Parameter - alpha_optim: Optional[Optimizer] + alpha_optim: Optional[OptimizerWrapper] @dataclasses.dataclass(frozen=True) @@ -79,9 +79,9 @@ def __init__( self._max_q_backup = max_q_backup def compute_critic_loss( - self, batch: TorchMiniBatch, q_tpn: torch.Tensor + self, batch: TorchMiniBatch, q_tpn: torch.Tensor, grad_step: int ) -> CQLCriticLoss: - loss = super().compute_critic_loss(batch, q_tpn) + loss = super().compute_critic_loss(batch, q_tpn, grad_step) conservative_loss = self._compute_conservative_loss( obs_t=batch.observations, act_t=batch.actions, @@ -89,20 +89,22 @@ def compute_critic_loss( returns_to_go=batch.returns_to_go, ) if self._modules.alpha_optim: - self.update_alpha(conservative_loss) + self.update_alpha(conservative_loss, grad_step) return CQLCriticLoss( critic_loss=loss.critic_loss + conservative_loss.sum(), conservative_loss=conservative_loss.sum(), alpha=get_parameter(self._modules.log_alpha).exp()[0][0], ) - def update_alpha(self, conservative_loss: torch.Tensor) -> None: + def update_alpha( + self, conservative_loss: torch.Tensor, grad_step: int + ) 
-> None: assert self._modules.alpha_optim self._modules.alpha_optim.zero_grad() # the original implementation does scale the loss value loss = -conservative_loss.mean() loss.backward(retain_graph=True) - self._modules.alpha_optim.step() + self._modules.alpha_optim.step(grad_step) def _compute_policy_is_values( self, diff --git a/d3rlpy/algos/qlearning/torch/crr_impl.py b/d3rlpy/algos/qlearning/torch/crr_impl.py index 2e398a1b..b726c9ff 100644 --- a/d3rlpy/algos/qlearning/torch/crr_impl.py +++ b/d3rlpy/algos/qlearning/torch/crr_impl.py @@ -76,7 +76,7 @@ def __init__( self._target_update_interval = target_update_interval def compute_actor_loss( - self, batch: TorchMiniBatch, action: ActionOutput + self, batch: TorchMiniBatch, action: ActionOutput, grad_step: int ) -> DDPGBaseActorLoss: # compute log probability dist = build_gaussian_distribution(action) @@ -187,8 +187,8 @@ def inner_update( ) -> Dict[str, float]: metrics = {} action = self._modules.policy(batch.observations) - metrics.update(self.update_critic(batch)) - metrics.update(self.update_actor(batch, action)) + metrics.update(self.update_critic(batch, grad_step)) + metrics.update(self.update_actor(batch, action, grad_step)) if self._target_update_type == "hard": if grad_step % self._target_update_interval == 0: diff --git a/d3rlpy/algos/qlearning/torch/ddpg_impl.py b/d3rlpy/algos/qlearning/torch/ddpg_impl.py index 1ce7181f..70dac10c 100644 --- a/d3rlpy/algos/qlearning/torch/ddpg_impl.py +++ b/d3rlpy/algos/qlearning/torch/ddpg_impl.py @@ -6,6 +6,8 @@ from torch import nn from torch.optim import Optimizer +from d3rlpy.models.optimizers import OptimizerWrapper + from ....dataclass_utils import asdict_as_float from ....models.torch import ( ActionOutput, @@ -32,8 +34,8 @@ class DDPGBaseModules(Modules): policy: Policy q_funcs: nn.ModuleList targ_q_funcs: nn.ModuleList - actor_optim: Optimizer - critic_optim: Optimizer + actor_optim: OptimizerWrapper + critic_optim: OptimizerWrapper @dataclasses.dataclass(frozen=True) @@ -78,16 +80,18 @@ def __init__( self._targ_q_func_forwarder = targ_q_func_forwarder hard_sync(self._modules.targ_q_funcs, self._modules.q_funcs) - def update_critic(self, batch: TorchMiniBatch) -> Dict[str, float]: + def update_critic( + self, batch: TorchMiniBatch, grad_step: int + ) -> Dict[str, float]: self._modules.critic_optim.zero_grad() q_tpn = self.compute_target(batch) - loss = self.compute_critic_loss(batch, q_tpn) + loss = self.compute_critic_loss(batch, q_tpn, grad_step) loss.critic_loss.backward() - self._modules.critic_optim.step() + self._modules.critic_optim.step(grad_step) return asdict_as_float(loss) def compute_critic_loss( - self, batch: TorchMiniBatch, q_tpn: torch.Tensor + self, batch: TorchMiniBatch, q_tpn: torch.Tensor, grad_step: int ) -> DDPGBaseCriticLoss: loss = self._q_func_forwarder.compute_error( observations=batch.observations, @@ -100,14 +104,14 @@ def compute_critic_loss( return DDPGBaseCriticLoss(loss) def update_actor( - self, batch: TorchMiniBatch, action: ActionOutput + self, batch: TorchMiniBatch, action: ActionOutput, grad_step: int ) -> Dict[str, float]: # Q function should be inference mode for stability self._modules.q_funcs.eval() self._modules.actor_optim.zero_grad() - loss = self.compute_actor_loss(batch, action) + loss = self.compute_actor_loss(batch, action, grad_step) loss.actor_loss.backward() - self._modules.actor_optim.step() + self._modules.actor_optim.step(grad_step) return asdict_as_float(loss) def inner_update( @@ -115,14 +119,14 @@ def inner_update( ) -> Dict[str, 
float]: metrics = {} action = self._modules.policy(batch.observations) - metrics.update(self.update_critic(batch)) - metrics.update(self.update_actor(batch, action)) + metrics.update(self.update_critic(batch, grad_step)) + metrics.update(self.update_actor(batch, action, grad_step)) self.update_critic_target() return metrics @abstractmethod def compute_actor_loss( - self, batch: TorchMiniBatch, action: ActionOutput + self, batch: TorchMiniBatch, action: ActionOutput, grad_step: int ) -> DDPGBaseActorLoss: pass @@ -146,7 +150,7 @@ def policy(self) -> Policy: @property def policy_optim(self) -> Optimizer: - return self._modules.actor_optim + return self._modules.actor_optim.optim @property def q_function(self) -> nn.ModuleList: @@ -154,7 +158,7 @@ def q_function(self) -> nn.ModuleList: @property def q_function_optim(self) -> Optimizer: - return self._modules.critic_optim + return self._modules.critic_optim.optim @dataclasses.dataclass(frozen=True) @@ -189,7 +193,7 @@ def __init__( hard_sync(self._modules.targ_policy, self._modules.policy) def compute_actor_loss( - self, batch: TorchMiniBatch, action: ActionOutput + self, batch: TorchMiniBatch, action: ActionOutput, grad_step: int ) -> DDPGBaseActorLoss: q_t = self._q_func_forwarder.compute_expected_q( batch.observations, action.squashed_mu, "none" diff --git a/d3rlpy/algos/qlearning/torch/dqn_impl.py b/d3rlpy/algos/qlearning/torch/dqn_impl.py index e1e2d7fa..86279cb4 100644 --- a/d3rlpy/algos/qlearning/torch/dqn_impl.py +++ b/d3rlpy/algos/qlearning/torch/dqn_impl.py @@ -6,6 +6,7 @@ from torch.optim import Optimizer from ....dataclass_utils import asdict_as_float +from ....models.optimizers import OptimizerWrapper from ....models.torch import DiscreteEnsembleQFunctionForwarder from ....torch_utility import Modules, TorchMiniBatch, hard_sync from ....types import Shape, TorchObservation @@ -19,7 +20,7 @@ class DQNModules(Modules): q_funcs: nn.ModuleList targ_q_funcs: nn.ModuleList - optim: Optimizer + optim: OptimizerWrapper @dataclasses.dataclass(frozen=True) @@ -67,7 +68,7 @@ def inner_update( loss = self.compute_loss(batch, q_tpn) loss.loss.backward() - self._modules.optim.step() + self._modules.optim.step(grad_step) if grad_step % self._target_update_interval == 0: self.update_target() @@ -116,7 +117,7 @@ def q_function(self) -> nn.ModuleList: @property def q_function_optim(self) -> Optimizer: - return self._modules.optim + return self._modules.optim.optim class DoubleDQNImpl(DQNImpl): diff --git a/d3rlpy/algos/qlearning/torch/iql_impl.py b/d3rlpy/algos/qlearning/torch/iql_impl.py index 1dfd29d0..fa1bb445 100644 --- a/d3rlpy/algos/qlearning/torch/iql_impl.py +++ b/d3rlpy/algos/qlearning/torch/iql_impl.py @@ -68,7 +68,7 @@ def __init__( self._max_weight = max_weight def compute_critic_loss( - self, batch: TorchMiniBatch, q_tpn: torch.Tensor + self, batch: TorchMiniBatch, q_tpn: torch.Tensor, grad_step: int ) -> IQLCriticLoss: q_loss = self._q_func_forwarder.compute_error( observations=batch.observations, @@ -90,7 +90,7 @@ def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: return self._modules.value_func(batch.next_observations) def compute_actor_loss( - self, batch: TorchMiniBatch, action: ActionOutput + self, batch: TorchMiniBatch, action: ActionOutput, grad_step: int ) -> DDPGBaseActorLoss: # compute log probability dist = build_gaussian_distribution(action) diff --git a/d3rlpy/algos/qlearning/torch/plas_impl.py b/d3rlpy/algos/qlearning/torch/plas_impl.py index 8d61f08c..66733201 100644 --- 
a/d3rlpy/algos/qlearning/torch/plas_impl.py +++ b/d3rlpy/algos/qlearning/torch/plas_impl.py @@ -2,8 +2,8 @@ from typing import Dict import torch -from torch.optim import Optimizer +from ....models import OptimizerWrapper from ....models.torch import ( ActionOutput, ContinuousEnsembleQFunctionForwarder, @@ -31,7 +31,7 @@ class PLASModules(DDPGBaseModules): targ_policy: DeterministicPolicy vae_encoder: VAEEncoder vae_decoder: VAEDecoder - vae_optim: Optimizer + vae_optim: OptimizerWrapper class PLASImpl(DDPGBaseImpl): @@ -68,7 +68,9 @@ def __init__( self._beta = beta self._warmup_steps = warmup_steps - def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]: + def update_imitator( + self, batch: TorchMiniBatch, grad_step: int + ) -> Dict[str, float]: self._modules.vae_optim.zero_grad() loss = compute_vae_error( vae_encoder=self._modules.vae_encoder, @@ -78,11 +80,11 @@ def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]: beta=self._beta, ) loss.backward() - self._modules.vae_optim.step() + self._modules.vae_optim.step(grad_step) return {"vae_loss": float(loss.cpu().detach().numpy())} def compute_actor_loss( - self, batch: TorchMiniBatch, action: ActionOutput + self, batch: TorchMiniBatch, action: ActionOutput, grad_step: int ) -> DDPGBaseActorLoss: latent_actions = 2.0 * action.squashed_mu actions = self._modules.vae_decoder(batch.observations, latent_actions) @@ -123,11 +125,11 @@ def inner_update( metrics = {} if grad_step < self._warmup_steps: - metrics.update(self.update_imitator(batch)) + metrics.update(self.update_imitator(batch, grad_step)) else: action = self._modules.policy(batch.observations) - metrics.update(self.update_critic(batch)) - metrics.update(self.update_actor(batch, action)) + metrics.update(self.update_critic(batch, grad_step)) + metrics.update(self.update_actor(batch, action, grad_step)) self.update_actor_target() self.update_critic_target() @@ -172,7 +174,7 @@ def __init__( ) def compute_actor_loss( - self, batch: TorchMiniBatch, action: ActionOutput + self, batch: TorchMiniBatch, action: ActionOutput, grad_step: int ) -> DDPGBaseActorLoss: latent_actions = 2.0 * action.squashed_mu actions = self._modules.vae_decoder(batch.observations, latent_actions) diff --git a/d3rlpy/algos/qlearning/torch/rebrac_impl.py b/d3rlpy/algos/qlearning/torch/rebrac_impl.py index ae8262ac..b3480508 100644 --- a/d3rlpy/algos/qlearning/torch/rebrac_impl.py +++ b/d3rlpy/algos/qlearning/torch/rebrac_impl.py @@ -48,7 +48,7 @@ def __init__( self._critic_beta = critic_beta def compute_actor_loss( - self, batch: TorchMiniBatch, action: ActionOutput + self, batch: TorchMiniBatch, action: ActionOutput, grad_step: int ) -> TD3PlusBCActorLoss: q_t = self._q_func_forwarder.compute_expected_q( batch.observations, diff --git a/d3rlpy/algos/qlearning/torch/sac_impl.py b/d3rlpy/algos/qlearning/torch/sac_impl.py index 551549af..c6ac7be8 100644 --- a/d3rlpy/algos/qlearning/torch/sac_impl.py +++ b/d3rlpy/algos/qlearning/torch/sac_impl.py @@ -7,6 +7,7 @@ from torch import nn from torch.optim import Optimizer +from ....models import OptimizerWrapper from ....models.torch import ( ActionOutput, CategoricalPolicy, @@ -37,7 +38,7 @@ class SACModules(DDPGBaseModules): policy: NormalPolicy log_temp: Parameter - temp_optim: Optional[Optimizer] + temp_optim: Optional[OptimizerWrapper] @dataclasses.dataclass(frozen=True) @@ -72,13 +73,13 @@ def __init__( ) def compute_actor_loss( - self, batch: TorchMiniBatch, action: ActionOutput + self, batch: TorchMiniBatch, action: ActionOutput, 
grad_step: int ) -> SACActorLoss: dist = build_squashed_gaussian_distribution(action) sampled_action, log_prob = dist.sample_with_log_prob() if self._modules.temp_optim: - temp_loss = self.update_temp(log_prob) + temp_loss = self.update_temp(log_prob, grad_step) else: temp_loss = torch.tensor( 0.0, dtype=torch.float32, device=sampled_action.device @@ -94,14 +95,16 @@ def compute_actor_loss( temp=get_parameter(self._modules.log_temp).exp()[0][0], ) - def update_temp(self, log_prob: torch.Tensor) -> torch.Tensor: + def update_temp( + self, log_prob: torch.Tensor, grad_step: int + ) -> torch.Tensor: assert self._modules.temp_optim self._modules.temp_optim.zero_grad() with torch.no_grad(): targ_temp = log_prob - self._action_size loss = -(get_parameter(self._modules.log_temp).exp() * targ_temp).mean() loss.backward() - self._modules.temp_optim.step() + self._modules.temp_optim.step(grad_step) return loss def compute_target(self, batch: TorchMiniBatch) -> torch.Tensor: @@ -129,9 +132,9 @@ class DiscreteSACModules(Modules): q_funcs: nn.ModuleList targ_q_funcs: nn.ModuleList log_temp: Optional[Parameter] - actor_optim: Optimizer - critic_optim: Optimizer - temp_optim: Optional[Optimizer] + actor_optim: OptimizerWrapper + critic_optim: OptimizerWrapper + temp_optim: Optional[OptimizerWrapper] class DiscreteSACImpl(DiscreteQFunctionMixin, QLearningAlgoImplBase): @@ -163,14 +166,16 @@ def __init__( self._target_update_interval = target_update_interval hard_sync(modules.targ_q_funcs, modules.q_funcs) - def update_critic(self, batch: TorchMiniBatch) -> Dict[str, float]: + def update_critic( + self, batch: TorchMiniBatch, grad_step: int + ) -> Dict[str, float]: self._modules.critic_optim.zero_grad() q_tpn = self.compute_target(batch) loss = self.compute_critic_loss(batch, q_tpn) loss.backward() - self._modules.critic_optim.step() + self._modules.critic_optim.step(grad_step) return {"critic_loss": float(loss.cpu().detach().numpy())} @@ -208,7 +213,9 @@ def compute_critic_loss( gamma=self._gamma**batch.intervals, ) - def update_actor(self, batch: TorchMiniBatch) -> Dict[str, float]: + def update_actor( + self, batch: TorchMiniBatch, grad_step: int + ) -> Dict[str, float]: # Q function should be inference mode for stability self._modules.q_funcs.eval() @@ -217,7 +224,7 @@ def update_actor(self, batch: TorchMiniBatch) -> Dict[str, float]: loss = self.compute_actor_loss(batch) loss.backward() - self._modules.actor_optim.step() + self._modules.actor_optim.step(grad_step) return {"actor_loss": float(loss.cpu().detach().numpy())} @@ -236,7 +243,9 @@ def compute_actor_loss(self, batch: TorchMiniBatch) -> torch.Tensor: entropy = temp * log_probs return (probs * (entropy - q_t)).sum(dim=1).mean() - def update_temp(self, batch: TorchMiniBatch) -> Dict[str, float]: + def update_temp( + self, batch: TorchMiniBatch, grad_step: int + ) -> Dict[str, float]: assert self._modules.temp_optim assert self._modules.log_temp is not None self._modules.temp_optim.zero_grad() @@ -252,7 +261,7 @@ def update_temp(self, batch: TorchMiniBatch) -> Dict[str, float]: loss = -(get_parameter(self._modules.log_temp).exp() * targ_temp).mean() loss.backward() - self._modules.temp_optim.step() + self._modules.temp_optim.step(grad_step) # current temperature value log_temp = get_parameter(self._modules.log_temp) @@ -270,9 +279,9 @@ def inner_update( # lagrangian parameter update for SAC temeprature if self._modules.temp_optim: - metrics.update(self.update_temp(batch)) - metrics.update(self.update_critic(batch)) - 
metrics.update(self.update_actor(batch)) + metrics.update(self.update_temp(batch, grad_step)) + metrics.update(self.update_critic(batch, grad_step)) + metrics.update(self.update_actor(batch, grad_step)) if grad_step % self._target_update_interval == 0: self.update_target() @@ -296,7 +305,7 @@ def policy(self) -> Policy: @property def policy_optim(self) -> Optimizer: - return self._modules.actor_optim + return self._modules.actor_optim.optim @property def q_function(self) -> nn.ModuleList: @@ -304,4 +313,4 @@ def q_function(self) -> nn.ModuleList: @property def q_function_optim(self) -> Optimizer: - return self._modules.critic_optim + return self._modules.critic_optim.optim diff --git a/d3rlpy/algos/qlearning/torch/td3_impl.py b/d3rlpy/algos/qlearning/torch/td3_impl.py index 33fdd235..c739f026 100644 --- a/d3rlpy/algos/qlearning/torch/td3_impl.py +++ b/d3rlpy/algos/qlearning/torch/td3_impl.py @@ -65,12 +65,12 @@ def inner_update( ) -> Dict[str, float]: metrics = {} - metrics.update(self.update_critic(batch)) + metrics.update(self.update_critic(batch, grad_step)) # delayed policy update if grad_step % self._update_actor_interval == 0: action = self._modules.policy(batch.observations) - metrics.update(self.update_actor(batch, action)) + metrics.update(self.update_actor(batch, action, grad_step)) self.update_critic_target() self.update_actor_target() diff --git a/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py b/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py index 6f73b0de..c6103e7b 100644 --- a/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py +++ b/d3rlpy/algos/qlearning/torch/td3_plus_bc_impl.py @@ -51,7 +51,7 @@ def __init__( self._alpha = alpha def compute_actor_loss( - self, batch: TorchMiniBatch, action: ActionOutput + self, batch: TorchMiniBatch, action: ActionOutput, grad_step: int ) -> TD3PlusBCActorLoss: q_t = self._q_func_forwarder.compute_expected_q( batch.observations, action.squashed_mu, "none" diff --git a/d3rlpy/algos/transformer/decision_transformer.py b/d3rlpy/algos/transformer/decision_transformer.py index e7e51f01..0a29b3fa 100644 --- a/d3rlpy/algos/transformer/decision_transformer.py +++ b/d3rlpy/algos/transformer/decision_transformer.py @@ -64,7 +64,6 @@ class DecisionTransformerConfig(TransformerConfig): position_encoding_type (d3rlpy.PositionEncodingType): Type of positional encoding (``SIMPLE`` or ``GLOBAL``). warmup_steps (int): Warmup steps for learning rate scheduler. - clip_grad_norm (float): Norm of gradient clipping. compile (bool): (experimental) Flag to enable JIT compilation. """ @@ -80,7 +79,6 @@ class DecisionTransformerConfig(TransformerConfig): activation_type: str = "relu" position_encoding_type: PositionEncodingType = PositionEncodingType.SIMPLE warmup_steps: int = 10000 - clip_grad_norm: float = 0.25 compile: bool = False def create( @@ -136,7 +134,6 @@ def inner_create_impl( action_size=action_size, modules=modules, scheduler=scheduler, - clip_grad_norm=self._config.clip_grad_norm, device=self._device, ) @@ -179,7 +176,6 @@ class DiscreteDecisionTransformerConfig(TransformerConfig): Type of positional encoding (``SIMPLE`` or ``GLOBAL``). warmup_tokens (int): Number of tokens to warmup learning rate scheduler. final_tokens (int): Final number of tokens for learning rate scheduler. - clip_grad_norm (float): Norm of gradient clipping. compile (bool): (experimental) Flag to enable JIT compilation. 
""" @@ -197,7 +193,6 @@ class DiscreteDecisionTransformerConfig(TransformerConfig): position_encoding_type: PositionEncodingType = PositionEncodingType.GLOBAL warmup_tokens: int = 10240 final_tokens: int = 30000000 - clip_grad_norm: float = 1.0 compile: bool = False def create( @@ -251,7 +246,6 @@ def inner_create_impl( observation_shape=observation_shape, action_size=action_size, modules=modules, - clip_grad_norm=self._config.clip_grad_norm, warmup_tokens=self._config.warmup_tokens, final_tokens=self._config.final_tokens, initial_learning_rate=self._config.learning_rate, diff --git a/d3rlpy/algos/transformer/torch/decision_transformer_impl.py b/d3rlpy/algos/transformer/torch/decision_transformer_impl.py index b9b01e5b..23b7b5c0 100644 --- a/d3rlpy/algos/transformer/torch/decision_transformer_impl.py +++ b/d3rlpy/algos/transformer/torch/decision_transformer_impl.py @@ -4,8 +4,8 @@ import torch import torch.nn.functional as F -from torch.optim import Optimizer +from ....models import OptimizerWrapper from ....models.torch import ( ContinuousDecisionTransformer, DiscreteDecisionTransformer, @@ -26,13 +26,12 @@ @dataclasses.dataclass(frozen=True) class DecisionTransformerModules(Modules): transformer: ContinuousDecisionTransformer - optim: Optimizer + optim: OptimizerWrapper class DecisionTransformerImpl(TransformerAlgoImplBase): _modules: DecisionTransformerModules _scheduler: torch.optim.lr_scheduler.LRScheduler - _clip_grad_norm: float def __init__( self, @@ -40,7 +39,6 @@ def __init__( action_size: int, modules: DecisionTransformerModules, scheduler: torch.optim.lr_scheduler.LRScheduler, - clip_grad_norm: float, device: str, ): super().__init__( @@ -50,7 +48,6 @@ def __init__( device=device, ) self._scheduler = scheduler - self._clip_grad_norm = clip_grad_norm def inner_predict(self, inpt: TorchTransformerInput) -> torch.Tensor: # (1, T, A) @@ -64,14 +61,9 @@ def inner_update( self, batch: TorchTrajectoryMiniBatch, grad_step: int ) -> Dict[str, float]: self._modules.optim.zero_grad() - loss = self.compute_loss(batch) - loss.backward() - torch.nn.utils.clip_grad_norm_( - self._modules.transformer.parameters(), self._clip_grad_norm - ) - self._modules.optim.step() + self._modules.optim.step(grad_step) self._scheduler.step() return {"loss": float(loss.cpu().detach().numpy())} @@ -91,12 +83,11 @@ def compute_loss(self, batch: TorchTrajectoryMiniBatch) -> torch.Tensor: @dataclasses.dataclass(frozen=True) class DiscreteDecisionTransformerModules(Modules): transformer: DiscreteDecisionTransformer - optim: Optimizer + optim: OptimizerWrapper class DiscreteDecisionTransformerImpl(TransformerAlgoImplBase): _modules: DiscreteDecisionTransformerModules - _clip_grad_norm: float _warmup_tokens: int _final_tokens: int _initial_learning_rate: float @@ -107,7 +98,6 @@ def __init__( observation_shape: Shape, action_size: int, modules: DiscreteDecisionTransformerModules, - clip_grad_norm: float, warmup_tokens: int, final_tokens: int, initial_learning_rate: float, @@ -119,7 +109,6 @@ def __init__( modules=modules, device=device, ) - self._clip_grad_norm = clip_grad_norm self._warmup_tokens = warmup_tokens self._final_tokens = final_tokens self._initial_learning_rate = initial_learning_rate @@ -138,14 +127,9 @@ def inner_update( self, batch: TorchTrajectoryMiniBatch, grad_step: int ) -> Dict[str, float]: self._modules.optim.zero_grad() - loss = self.compute_loss(batch) - loss.backward() - torch.nn.utils.clip_grad_norm_( - self._modules.transformer.parameters(), self._clip_grad_norm - ) - 
self._modules.optim.step() + self._modules.optim.step(grad_step) # schedule learning rate self._tokens += int(batch.masks.sum().cpu().detach().numpy()) @@ -159,7 +143,7 @@ def inner_update( ) lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress))) new_learning_rate = lr_mult * self._initial_learning_rate - for param_group in self._modules.optim.param_groups: + for param_group in self._modules.optim.optim.param_groups: param_group["lr"] = new_learning_rate return { diff --git a/d3rlpy/models/optimizers.py b/d3rlpy/models/optimizers.py index f8cf5d09..a020bb39 100644 --- a/d3rlpy/models/optimizers.py +++ b/d3rlpy/models/optimizers.py @@ -1,5 +1,5 @@ import dataclasses -from typing import Iterable, Sequence, Tuple, Optional +from typing import Iterable, Optional, Sequence, Tuple from torch import nn from torch.optim import SGD, Adam, AdamW, Optimizer, RMSprop @@ -32,7 +32,7 @@ def _get_parameters_from_named_modules( class OptimizerWrapper: - """OptimizerWrapper class + """OptimizerWrapper class. This class wraps PyTorch optimizer to add additional steps such as gradient clipping. @@ -42,11 +42,17 @@ class OptimizerWrapper: optim: PyTorch optimizer. clip_grad_norm: Maximum norm value of gradients to clip. """ + _params: Sequence[nn.Parameter] _optim: Optimizer _clip_grad_norm: Optional[float] - def __init__(self, params: Sequence[nn.Parameter], optim: Optimizer, clip_grad_norm: Optional[float] = None): + def __init__( + self, + params: Sequence[nn.Parameter], + optim: Optimizer, + clip_grad_norm: Optional[float] = None, + ): self._params = params self._optim = optim self._clip_grad_norm = clip_grad_norm @@ -62,7 +68,9 @@ def step(self, grad_step: int) -> None: schedulers. """ if self._clip_grad_norm: - nn.utils.clip_grad_norm_(self._params, max_norm=self._clip_grad_norm) + nn.utils.clip_grad_norm_( + self._params, max_norm=self._clip_grad_norm + ) self._optim.step() @property @@ -76,6 +84,7 @@ class OptimizerFactory(DynamicConfig): The optimizers in algorithms can be configured through this factory class. """ + clip_grad_norm: Optional[float] = None def create( @@ -90,6 +99,7 @@ def create( Returns: Updater: Updater object. 
""" + named_modules = list(named_modules) params = _get_parameters_from_named_modules(named_modules) optim = self.create_optimizer(named_modules, lr) return OptimizerWrapper( @@ -98,7 +108,9 @@ def create( clip_grad_norm=self.clip_grad_norm, ) - def create_optimizer(self, named_modules: Iterable[Tuple[str, nn.Module]], lr: float) -> Optimizer: + def create_optimizer( + self, named_modules: Iterable[Tuple[str, nn.Module]], lr: float + ) -> Optimizer: raise NotImplementedError diff --git a/d3rlpy/models/torch/imitators.py b/d3rlpy/models/torch/imitators.py index b46cd6dd..ba91dde8 100644 --- a/d3rlpy/models/torch/imitators.py +++ b/d3rlpy/models/torch/imitators.py @@ -27,6 +27,7 @@ "compute_deterministic_imitation_loss", "compute_stochastic_imitation_loss", "ImitationLoss", + "DiscreteImitationLoss", ] diff --git a/d3rlpy/ope/torch/fqe_impl.py b/d3rlpy/ope/torch/fqe_impl.py index 9e89d01a..648924ea 100644 --- a/d3rlpy/ope/torch/fqe_impl.py +++ b/d3rlpy/ope/torch/fqe_impl.py @@ -3,13 +3,13 @@ import torch from torch import nn -from torch.optim import Optimizer from ...algos.qlearning.base import QLearningAlgoImplBase from ...algos.qlearning.torch.utility import ( ContinuousQFunctionMixin, DiscreteQFunctionMixin, ) +from ...models import OptimizerWrapper from ...models.torch import ( ContinuousEnsembleQFunctionForwarder, DiscreteEnsembleQFunctionForwarder, @@ -24,7 +24,7 @@ class FQEBaseModules(Modules): q_funcs: nn.ModuleList targ_q_funcs: nn.ModuleList - optim: Optimizer + optim: OptimizerWrapper class FQEBaseImpl(QLearningAlgoImplBase): @@ -111,7 +111,7 @@ def inner_update( self._modules.optim.zero_grad() loss.backward() - self._modules.optim.step() + self._modules.optim.step(grad_step) if grad_step % self._target_update_interval == 0: self.update_target() diff --git a/d3rlpy/torch_utility.py b/d3rlpy/torch_utility.py index 61fdf871..b16bf26e 100644 --- a/d3rlpy/torch_utility.py +++ b/d3rlpy/torch_utility.py @@ -23,8 +23,12 @@ from .dataclass_utils import asdict_without_copy from .dataset import TrajectoryMiniBatch, TransitionMiniBatch from .preprocessing import ActionScaler, ObservationScaler, RewardScaler -from .types import Float32NDArray, NDArray, TorchObservation -from .models import OptimizerWrapper +from .types import ( + Float32NDArray, + NDArray, + OptimizerWrapperProto, + TorchObservation, +) __all__ = [ "soft_sync", @@ -326,11 +330,13 @@ def unwrap_ddp_model(model: _TModule) -> _TModule: class Checkpointer: - _modules: Dict[str, Union[nn.Module, OptimizerWrapper]] + _modules: Dict[str, Union[nn.Module, OptimizerWrapperProto]] _device: str def __init__( - self, modules: Dict[str, Union[nn.Module, OptimizerWrapper]], device: str + self, + modules: Dict[str, Union[nn.Module, OptimizerWrapperProto]], + device: str, ): self._modules = modules self._device = device @@ -353,7 +359,7 @@ def load(self, f: BinaryIO) -> None: v.optim.load_state_dict(chkpt[k]) @property - def modules(self) -> Dict[str, Union[nn.Module, OptimizerWrapper]]: + def modules(self) -> Dict[str, Union[nn.Module, OptimizerWrapperProto]]: return self._modules @@ -363,7 +369,7 @@ def create_checkpointer(self, device: str) -> Checkpointer: modules = { k: v for k, v in asdict_without_copy(self).items() - if isinstance(v, (nn.Module, OptimizerWrapper)) + if isinstance(v, (nn.Module, OptimizerWrapperProto)) } return Checkpointer(modules=modules, device=device) @@ -391,7 +397,7 @@ def set_train(self) -> None: def reset_optimizer_states(self) -> None: for v in asdict_without_copy(self).values(): - if isinstance(v, 
OptimizerWrapper): + if isinstance(v, OptimizerWrapperProto): v.optim.state = collections.defaultdict(dict) def get_torch_modules(self) -> Dict[str, nn.Module]: diff --git a/d3rlpy/types.py b/d3rlpy/types.py index 293f328b..2d532c34 100644 --- a/d3rlpy/types.py +++ b/d3rlpy/types.py @@ -5,6 +5,8 @@ import numpy as np import numpy.typing as npt import torch +from torch.optim import Optimizer +from typing_extensions import Protocol, runtime_checkable __all__ = [ "NDArray", @@ -17,6 +19,7 @@ "Shape", "TorchObservation", "GymEnv", + "OptimizerWrapperProto", ] @@ -32,3 +35,10 @@ TorchObservation = Union[torch.Tensor, Sequence[torch.Tensor]] GymEnv = Union[gym.Env[Any, Any], gymnasium.Env[Any, Any]] + + +@runtime_checkable +class OptimizerWrapperProto(Protocol): + @property + def optim(self) -> Optimizer: + raise NotImplementedError diff --git a/reproductions/finetuning/iql_finetune.py b/reproductions/finetuning/iql_finetune.py index e1048b6b..caa70e09 100644 --- a/reproductions/finetuning/iql_finetune.py +++ b/reproductions/finetuning/iql_finetune.py @@ -37,7 +37,7 @@ def main() -> None: iql.build_with_dataset(dataset) assert iql.impl scheduler = CosineAnnealingLR( - iql.impl._modules.actor_optim, # pylint: disable=protected-access + iql.impl._modules.actor_optim.optim, # pylint: disable=protected-access 1000000, ) @@ -56,7 +56,7 @@ def callback(algo: d3rlpy.algos.IQL, epoch: int, total_step: int) -> None: ) # reset learning rate - for g in iql.impl._modules.actor_optim.param_groups: + for g in iql.impl._modules.actor_optim.optim.param_groups: g["lr"] = iql.config.actor_learning_rate # prepare FIFO buffer filled with dataset episodes diff --git a/reproductions/offline/decision_transformer.py b/reproductions/offline/decision_transformer.py index 9452dd7e..d9b014ae 100644 --- a/reproductions/offline/decision_transformer.py +++ b/reproductions/offline/decision_transformer.py @@ -28,7 +28,9 @@ def main() -> None: dt = d3rlpy.algos.DecisionTransformerConfig( batch_size=64, learning_rate=1e-4, - optim_factory=d3rlpy.models.AdamWFactory(weight_decay=1e-4), + optim_factory=d3rlpy.models.AdamWFactory( + weight_decay=1e-4, clip_grad_norm=0.25 + ), encoder_factory=d3rlpy.models.VectorEncoderFactory( [128], exclude_last_activation=True, diff --git a/reproductions/offline/discrete_decision_transformer.py b/reproductions/offline/discrete_decision_transformer.py index 8e015af8..0ccebb4a 100644 --- a/reproductions/offline/discrete_decision_transformer.py +++ b/reproductions/offline/discrete_decision_transformer.py @@ -63,8 +63,8 @@ def main() -> None: optim_factory=d3rlpy.models.GPTAdamWFactory( betas=(0.9, 0.95), weight_decay=0.1, + clip_grad_norm=1.0, ), - clip_grad_norm=1.0, warmup_tokens=512 * 20, final_tokens=2 * 500000 * context_size * 3, observation_scaler=d3rlpy.preprocessing.PixelObservationScaler(), diff --git a/tests/models/test_optimizers.py b/tests/models/test_optimizers.py index 198a2b6d..2b7b040c 100644 --- a/tests/models/test_optimizers.py +++ b/tests/models/test_optimizers.py @@ -19,8 +19,8 @@ def test_sgd_factory(lr: float, module: torch.nn.Module) -> None: optim = factory.create(module.named_modules(), lr) - assert isinstance(optim, SGD) - assert optim.defaults["lr"] == lr + assert isinstance(optim.optim, SGD) + assert optim.optim.defaults["lr"] == lr # check serialization and deserialization SGDFactory.deserialize(factory.serialize()) @@ -33,8 +33,8 @@ def test_adam_factory(lr: float, module: torch.nn.Module) -> None: optim = factory.create(module.named_modules(), lr) - assert 
isinstance(optim, Adam) - assert optim.defaults["lr"] == lr + assert isinstance(optim.optim, Adam) + assert optim.optim.defaults["lr"] == lr # check serialization and deserialization AdamFactory.deserialize(factory.serialize()) @@ -47,8 +47,8 @@ def test_adam_w_factory(lr: float, module: torch.nn.Module) -> None: optim = factory.create(module.named_modules(), lr) - assert isinstance(optim, AdamW) - assert optim.defaults["lr"] == lr + assert isinstance(optim.optim, AdamW) + assert optim.optim.defaults["lr"] == lr # check serialization and deserialization AdamWFactory.deserialize(factory.serialize()) @@ -61,8 +61,8 @@ def test_rmsprop_factory(lr: float, module: torch.nn.Module) -> None: optim = factory.create(module.named_modules(), lr) - assert isinstance(optim, RMSprop) - assert optim.defaults["lr"] == lr + assert isinstance(optim.optim, RMSprop) + assert optim.optim.defaults["lr"] == lr # check serialization and deserialization RMSpropFactory.deserialize(factory.serialize()) @@ -83,11 +83,11 @@ def __init__(self) -> None: optim = factory.create(module.named_modules(), lr) - assert isinstance(optim, AdamW) - assert optim.defaults["lr"] == lr - assert len(optim.param_groups) == 2 - assert optim.param_groups[0]["weight_decay"] == weight_decay - assert optim.param_groups[1]["weight_decay"] == 0.0 + assert isinstance(optim.optim, AdamW) + assert optim.optim.defaults["lr"] == lr + assert len(optim.optim.param_groups) == 2 + assert optim.optim.param_groups[0]["weight_decay"] == weight_decay + assert optim.optim.param_groups[1]["weight_decay"] == 0.0 # check serialization and deserialization GPTAdamWFactory.deserialize(factory.serialize())
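
A minimal, self-contained sketch of the `OptimizerWrapper` pattern introduced in `d3rlpy/models/optimizers.py`, written against plain PyTorch rather than the real class; the `zero_grad` forwarding is assumed from the call sites in the patch, since that method falls outside the changed hunks.

```python
from typing import Optional, Sequence

import torch
from torch import nn
from torch.optim import Optimizer


class OptimizerWrapperSketch:
    """Sketch of the wrapper: it keeps the parameter list so gradient
    clipping can happen inside step(), and step() receives the global
    gradient step so callers no longer clip or schedule by hand."""

    def __init__(
        self,
        params: Sequence[nn.Parameter],
        optim: Optimizer,
        clip_grad_norm: Optional[float] = None,
    ):
        self._params = list(params)
        self._optim = optim
        self._clip_grad_norm = clip_grad_norm

    def zero_grad(self) -> None:
        # assumed to simply forward to the wrapped optimizer
        self._optim.zero_grad()

    def step(self, grad_step: int) -> None:
        # grad_step is unused here but kept in the signature so logic
        # keyed on the global step (e.g. schedulers) can hook in later
        if self._clip_grad_norm:
            nn.utils.clip_grad_norm_(self._params, max_norm=self._clip_grad_norm)
        self._optim.step()

    @property
    def optim(self) -> Optimizer:
        return self._optim


# usage: one gradient update with clipping handled by the wrapper
model = nn.Linear(4, 2)
wrapper = OptimizerWrapperSketch(
    params=list(model.parameters()),
    optim=torch.optim.Adam(model.parameters(), lr=1e-3),
    clip_grad_norm=0.25,
)
loss = model(torch.randn(8, 4)).pow(2).mean()
wrapper.zero_grad()
loss.backward()
wrapper.step(grad_step=0)
```

Keeping clipping next to the optimizer state means every algorithm that goes through an `OptimizerFactory` picks up `clip_grad_norm` without touching its `inner_update`.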
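
The algorithm-side edits mostly thread `grad_step` from `inner_update` through the `update_*` helpers down to `optim.step(grad_step)`. A hedged sketch of that calling convention; `DummyActorImpl` is a hypothetical stand-in, not a d3rlpy class:

```python
from typing import Any, Dict

import torch
from torch import nn


class DummyActorImpl:
    """Hypothetical stand-in showing the new calling convention:
    update helpers take grad_step and hand it to the optimizer
    wrapper instead of calling a bare optim.step()."""

    def __init__(self, policy: nn.Module, actor_optim: Any) -> None:
        # actor_optim is anything exposing zero_grad() and step(grad_step),
        # e.g. an OptimizerWrapper built by an OptimizerFactory
        self._policy = policy
        self._actor_optim = actor_optim

    def update_actor(
        self, batch: torch.Tensor, grad_step: int
    ) -> Dict[str, float]:
        self._actor_optim.zero_grad()
        loss = self._policy(batch).pow(2).mean()
        loss.backward()
        self._actor_optim.step(grad_step)  # previously: .step()
        return {"actor_loss": float(loss.detach().cpu().numpy())}

    def inner_update(
        self, batch: torch.Tensor, grad_step: int
    ) -> Dict[str, float]:
        # grad_step now rides along with the batch all the way down
        return self.update_actor(batch, grad_step)
```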
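
On the configuration side, gradient clipping moves from the Decision Transformer configs onto the optimizer factory, as the reproduction scripts show. A hedged usage sketch with the remaining config fields left at their defaults (`create(device=False)` is assumed to build a CPU instance, as in the released d3rlpy API):

```python
import d3rlpy

# clip_grad_norm=0.25 on the factory replaces the removed
# DecisionTransformerConfig(clip_grad_norm=0.25) field
config = d3rlpy.algos.DecisionTransformerConfig(
    batch_size=64,
    learning_rate=1e-4,
    optim_factory=d3rlpy.models.AdamWFactory(
        weight_decay=1e-4,
        clip_grad_norm=0.25,
    ),
)
dt = config.create(device=False)
```

Because clipping now happens inside `OptimizerWrapper.step()`, the DT-specific `clip_grad_norm` field and the manual `torch.nn.utils.clip_grad_norm_` calls in the impls become redundant, which is why the configs and impl constructors drop them.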
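
Finally, `torch_utility.py` swaps its `isinstance` checks from the concrete `OptimizerWrapper` to the runtime-checkable `OptimizerWrapperProto` added in `types.py`, presumably so `torch_utility` no longer needs to import from `d3rlpy.models`. A small sketch of how the structural check behaves; `LocalWrapper` is hypothetical:

```python
from torch import nn
from torch.optim import Adam, Optimizer
from typing_extensions import Protocol, runtime_checkable


@runtime_checkable
class OptimizerWrapperProto(Protocol):
    @property
    def optim(self) -> Optimizer:
        raise NotImplementedError


class LocalWrapper:
    """Hypothetical wrapper: it never imports the protocol, yet the
    runtime-checkable isinstance() matches it structurally because it
    exposes an `optim` attribute."""

    def __init__(self, optim: Optimizer) -> None:
        self._optim = optim

    @property
    def optim(self) -> Optimizer:
        return self._optim


wrapper = LocalWrapper(Adam(nn.Linear(2, 2).parameters()))
assert isinstance(wrapper, OptimizerWrapperProto)
```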