
Commit

Update algorithms
takuseno committed Oct 18, 2024
1 parent 741d40f commit 198f9e8
Showing 25 changed files with 193 additions and 154 deletions.
2 changes: 1 addition & 1 deletion d3rlpy/algos/qlearning/torch/awac_impl.py
@@ -49,7 +49,7 @@ def __init__(
         self._n_action_samples = n_action_samples
 
     def compute_actor_loss(
-        self, batch: TorchMiniBatch, action: ActionOutput
+        self, batch: TorchMiniBatch, action: ActionOutput, grad_step: int
     ) -> SACActorLoss:
         # compute log probability
         dist = build_gaussian_distribution(action)
20 changes: 13 additions & 7 deletions d3rlpy/algos/qlearning/torch/bc_impl.py
@@ -6,9 +6,11 @@
 from torch.optim import Optimizer
 
 from ....dataclass_utils import asdict_as_float
+from ....models import OptimizerWrapper
 from ....models.torch import (
     CategoricalPolicy,
     DeterministicPolicy,
+    DiscreteImitationLoss,
     ImitationLoss,
     NormalPolicy,
     Policy,
@@ -25,7 +27,7 @@
 
 @dataclasses.dataclass(frozen=True)
 class BCBaseModules(Modules):
-    optim: Optimizer
+    optim: OptimizerWrapper
 
 
 class BCBaseImpl(QLearningAlgoImplBase, metaclass=ABCMeta):
@@ -45,13 +47,15 @@ def __init__(
             device=device,
         )
 
-    def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]:
+    def update_imitator(
+        self, batch: TorchMiniBatch, grad_step: int
+    ) -> Dict[str, float]:
         self._modules.optim.zero_grad()
 
         loss = self.compute_loss(batch.observations, batch.actions)
 
         loss.loss.backward()
-        self._modules.optim.step()
+        self._modules.optim.step(grad_step)
 
         return asdict_as_float(loss)
 
@@ -72,7 +76,7 @@ def inner_predict_value(
     def inner_update(
         self, batch: TorchMiniBatch, grad_step: int
     ) -> Dict[str, float]:
-        return self.update_imitator(batch)
+        return self.update_imitator(batch, grad_step)
 
 
 @dataclasses.dataclass(frozen=True)
@@ -105,12 +109,14 @@ def inner_predict_best_action(self, x: TorchObservation) -> torch.Tensor:
 
     def compute_loss(
         self, obs_t: TorchObservation, act_t: torch.Tensor
-    ) -> torch.Tensor:
+    ) -> ImitationLoss:
         if self._policy_type == "deterministic":
+            assert isinstance(self._modules.imitator, DeterministicPolicy)
             return compute_deterministic_imitation_loss(
                 self._modules.imitator, obs_t, act_t
             )
         elif self._policy_type == "stochastic":
+            assert isinstance(self._modules.imitator, NormalPolicy)
             return compute_stochastic_imitation_loss(
                 self._modules.imitator, obs_t, act_t
             )
@@ -123,7 +129,7 @@ def policy(self) -> Policy:
 
     @property
     def policy_optim(self) -> Optimizer:
-        return self._modules.optim
+        return self._modules.optim.optim
 
 
 @dataclasses.dataclass(frozen=True)
@@ -156,7 +162,7 @@ def inner_predict_best_action(self, x: TorchObservation) -> torch.Tensor:
 
     def compute_loss(
         self, obs_t: TorchObservation, act_t: torch.Tensor
-    ) -> torch.Tensor:
+    ) -> DiscreteImitationLoss:
         return compute_discrete_imitation_loss(
             policy=self._modules.imitator,
             x=obs_t,
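
Note on the pattern above: throughout this commit, bare torch.optim.Optimizer fields become OptimizerWrapper, update methods receive the global grad_step, optimizer calls change from optim.step() to optim.step(grad_step), and policy_optim unwraps the raw optimizer via .optim. The snippet below is a minimal sketch of what such a wrapper could look like, written only from what the diff shows (zero_grad(), step(grad_step), and a .optim attribute); the gradient clipping and LR scheduler are assumptions, and this is not d3rlpy's actual OptimizerWrapper.

```python
# Illustrative sketch only, not d3rlpy's OptimizerWrapper. It assumes the
# wrapper exposes zero_grad(), step(grad_step), and the raw optimizer as
# .optim, which is all the diff above relies on; the clipping and scheduler
# behaviour here is hypothetical.
from __future__ import annotations

from typing import Optional, Sequence

import torch
from torch import nn
from torch.optim import Optimizer


class OptimizerWrapperSketch:
    def __init__(
        self,
        params: Sequence[nn.Parameter],
        optim: Optimizer,
        clip_grad_norm: Optional[float] = None,
        lr_scheduler: Optional[torch.optim.lr_scheduler.LRScheduler] = None,
    ):
        self._params = list(params)
        self.optim = optim  # exposed so callers can unwrap it (cf. policy_optim)
        self._clip_grad_norm = clip_grad_norm
        self._lr_scheduler = lr_scheduler

    def zero_grad(self) -> None:
        self.optim.zero_grad()

    def step(self, grad_step: int) -> None:
        # grad_step is accepted so that step-dependent behaviour (warmup,
        # schedules, logging) can live inside the wrapper; this sketch only
        # clips, applies the update, and advances the scheduler once per call.
        if self._clip_grad_norm is not None:
            nn.utils.clip_grad_norm_(self._params, self._clip_grad_norm)
        self.optim.step()
        if self._lr_scheduler is not None:
            self._lr_scheduler.step()
```

With a wrapper of this shape, call sites such as update_imitator stay one-liners (self._modules.optim.step(grad_step)) while any step-dependent optimizer behaviour is centralised in one place.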
18 changes: 10 additions & 8 deletions d3rlpy/algos/qlearning/torch/bcq_impl.py
@@ -4,8 +4,8 @@
 
 import torch
 import torch.nn.functional as F
-from torch.optim import Optimizer
 
+from ....models import OptimizerWrapper
 from ....models.torch import (
     ActionOutput,
     CategoricalPolicy,
@@ -44,7 +44,7 @@ class BCQModules(DDPGBaseModules):
     targ_policy: DeterministicResidualPolicy
     vae_encoder: VAEEncoder
     vae_decoder: VAEDecoder
-    vae_optim: Optimizer
+    vae_optim: OptimizerWrapper
 
 
 class BCQImpl(DDPGBaseImpl):
@@ -88,14 +88,16 @@ def __init__(
         self._rl_start_step = rl_start_step
 
     def compute_actor_loss(
-        self, batch: TorchMiniBatch, action: ActionOutput
+        self, batch: TorchMiniBatch, action: ActionOutput, grad_step: int
     ) -> DDPGBaseActorLoss:
         value = self._q_func_forwarder.compute_expected_q(
             batch.observations, action.squashed_mu, "none"
         )
         return DDPGBaseActorLoss(-value[0].mean())
 
-    def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]:
+    def update_imitator(
+        self, batch: TorchMiniBatch, grad_step: int
+    ) -> Dict[str, float]:
         self._modules.vae_optim.zero_grad()
         loss = compute_vae_error(
             vae_encoder=self._modules.vae_encoder,
@@ -105,7 +107,7 @@ def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]:
             beta=self._beta,
         )
         loss.backward()
-        self._modules.vae_optim.step()
+        self._modules.vae_optim.step(grad_step)
         return {"vae_loss": float(loss.cpu().detach().numpy())}
 
     def _repeat_observation(self, x: TorchObservation) -> TorchObservation:
@@ -184,7 +186,7 @@ def inner_update(
     ) -> Dict[str, float]:
         metrics = {}
 
-        metrics.update(self.update_imitator(batch))
+        metrics.update(self.update_imitator(batch, grad_step))
         if grad_step < self._rl_start_step:
             return metrics
 
@@ -201,8 +203,8 @@ def inner_update(
             action = self._modules.policy(batch.observations, sampled_action)
 
         # update models
-        metrics.update(self.update_critic(batch))
-        metrics.update(self.update_actor(batch, action))
+        metrics.update(self.update_critic(batch, grad_step))
+        metrics.update(self.update_actor(batch, action, grad_step))
         self.update_critic_target()
         self.update_actor_target()
         return metrics
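
The BCQ diff above also shows why grad_step has to be threaded all the way down: inner_update forwards the same counter to update_imitator, update_critic, and update_actor, and uses it to gate the RL updates behind rl_start_step. Below is a self-contained toy of that calling pattern; the class, the linear "policy", and the placeholder loss are all hypothetical and only illustrate the control flow, not d3rlpy's algorithms.

```python
# Self-contained toy, not d3rlpy code: it mimics the BCQ-style pattern above,
# where inner_update receives the global gradient step and uses it to decide
# which sub-updates run.
from typing import Dict

import torch


class ToyImpl:
    def __init__(self, rl_start_step: int):
        self._policy = torch.nn.Linear(4, 2)
        self._optim = torch.optim.Adam(self._policy.parameters(), lr=3e-4)
        self._rl_start_step = rl_start_step

    def update_actor(self, obs: torch.Tensor, grad_step: int) -> Dict[str, float]:
        # grad_step is threaded through for parity with the API above; this
        # toy does not use it further.
        self._optim.zero_grad()
        loss = self._policy(obs).pow(2).mean()  # placeholder loss, not a real RL objective
        loss.backward()
        self._optim.step()
        return {"actor_loss": float(loss.detach())}

    def inner_update(self, obs: torch.Tensor, grad_step: int) -> Dict[str, float]:
        # mirrors `if grad_step < self._rl_start_step: return metrics` above
        if grad_step < self._rl_start_step:
            return {}
        return self.update_actor(obs, grad_step)


impl = ToyImpl(rl_start_step=2)
for grad_step in range(4):
    print(grad_step, impl.inner_update(torch.randn(8, 4), grad_step))
```

Running it prints empty metrics for the first rl_start_step gradient steps and an actor_loss entry afterwards.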
36 changes: 20 additions & 16 deletions d3rlpy/algos/qlearning/torch/bear_impl.py
@@ -2,8 +2,8 @@
 from typing import Dict, Optional
 
 import torch
-from torch.optim import Optimizer
 
+from ....models import OptimizerWrapper
 from ....models.torch import (
     ActionOutput,
     ContinuousEnsembleQFunctionForwarder,
@@ -47,8 +47,8 @@ class BEARModules(SACModules):
     vae_encoder: VAEEncoder
     vae_decoder: VAEDecoder
     log_alpha: Parameter
-    vae_optim: Optimizer
-    alpha_optim: Optional[Optimizer]
+    vae_optim: OptimizerWrapper
+    alpha_optim: Optional[OptimizerWrapper]
 
 
 @dataclasses.dataclass(frozen=True)
@@ -110,12 +110,12 @@ def __init__(
         self._warmup_steps = warmup_steps
 
     def compute_actor_loss(
-        self, batch: TorchMiniBatch, action: ActionOutput
+        self, batch: TorchMiniBatch, action: ActionOutput, grad_step: int
     ) -> BEARActorLoss:
-        loss = super().compute_actor_loss(batch, action)
+        loss = super().compute_actor_loss(batch, action, grad_step)
         mmd_loss = self._compute_mmd_loss(batch.observations)
         if self._modules.alpha_optim:
-            self.update_alpha(mmd_loss)
+            self.update_alpha(mmd_loss, grad_step)
         return BEARActorLoss(
             actor_loss=loss.actor_loss + mmd_loss,
             temp_loss=loss.temp_loss,
@@ -124,23 +124,27 @@ def compute_actor_loss(
             alpha=get_parameter(self._modules.log_alpha).exp()[0][0],
         )
 
-    def warmup_actor(self, batch: TorchMiniBatch) -> Dict[str, float]:
+    def warmup_actor(
+        self, batch: TorchMiniBatch, grad_step: int
+    ) -> Dict[str, float]:
         self._modules.actor_optim.zero_grad()
         loss = self._compute_mmd_loss(batch.observations)
         loss.backward()
-        self._modules.actor_optim.step()
+        self._modules.actor_optim.step(grad_step)
         return {"actor_loss": float(loss.cpu().detach().numpy())}
 
     def _compute_mmd_loss(self, obs_t: TorchObservation) -> torch.Tensor:
         mmd = self._compute_mmd(obs_t)
         alpha = get_parameter(self._modules.log_alpha).exp()
         return (alpha * (mmd - self._alpha_threshold)).mean()
 
-    def update_imitator(self, batch: TorchMiniBatch) -> Dict[str, float]:
+    def update_imitator(
+        self, batch: TorchMiniBatch, grad_step: int
+    ) -> Dict[str, float]:
         self._modules.vae_optim.zero_grad()
         loss = self.compute_imitator_loss(batch)
         loss.backward()
-        self._modules.vae_optim.step()
+        self._modules.vae_optim.step(grad_step)
         return {"imitator_loss": float(loss.cpu().detach().numpy())}
 
     def compute_imitator_loss(self, batch: TorchMiniBatch) -> torch.Tensor:
@@ -152,12 +156,12 @@ def compute_imitator_loss(self, batch: TorchMiniBatch) -> torch.Tensor:
             beta=self._vae_kl_weight,
         )
 
-    def update_alpha(self, mmd_loss: torch.Tensor) -> None:
+    def update_alpha(self, mmd_loss: torch.Tensor, grad_step: int) -> None:
         assert self._modules.alpha_optim
         self._modules.alpha_optim.zero_grad()
         loss = -mmd_loss
         loss.backward(retain_graph=True)
-        self._modules.alpha_optim.step()
+        self._modules.alpha_optim.step(grad_step)
         # clip for stability
         get_parameter(self._modules.log_alpha).data.clamp_(-5.0, 10.0)
 
@@ -274,13 +278,13 @@ def inner_update(
         self, batch: TorchMiniBatch, grad_step: int
     ) -> Dict[str, float]:
         metrics = {}
-        metrics.update(self.update_imitator(batch))
-        metrics.update(self.update_critic(batch))
+        metrics.update(self.update_imitator(batch, grad_step))
+        metrics.update(self.update_critic(batch, grad_step))
         if grad_step < self._warmup_steps:
-            actor_loss = self.warmup_actor(batch)
+            actor_loss = self.warmup_actor(batch, grad_step)
         else:
             action = self._modules.policy(batch.observations)
-            actor_loss = self.update_actor(batch, action)
+            actor_loss = self.update_actor(batch, action, grad_step)
         metrics.update(actor_loss)
         self.update_critic_target()
         return metrics
16 changes: 9 additions & 7 deletions d3rlpy/algos/qlearning/torch/cql_impl.py
@@ -4,8 +4,8 @@
 
 import torch
 import torch.nn.functional as F
-from torch.optim import Optimizer
 
+from ....models import OptimizerWrapper
 from ....models.torch import (
     ContinuousEnsembleQFunctionForwarder,
     DiscreteEnsembleQFunctionForwarder,
@@ -29,7 +29,7 @@
 @dataclasses.dataclass(frozen=True)
 class CQLModules(SACModules):
     log_alpha: Parameter
-    alpha_optim: Optional[Optimizer]
+    alpha_optim: Optional[OptimizerWrapper]
 
 
 @dataclasses.dataclass(frozen=True)
@@ -79,30 +79,32 @@ def __init__(
         self._max_q_backup = max_q_backup
 
     def compute_critic_loss(
-        self, batch: TorchMiniBatch, q_tpn: torch.Tensor
+        self, batch: TorchMiniBatch, q_tpn: torch.Tensor, grad_step: int
    ) -> CQLCriticLoss:
-        loss = super().compute_critic_loss(batch, q_tpn)
+        loss = super().compute_critic_loss(batch, q_tpn, grad_step)
         conservative_loss = self._compute_conservative_loss(
             obs_t=batch.observations,
             act_t=batch.actions,
             obs_tp1=batch.next_observations,
             returns_to_go=batch.returns_to_go,
         )
         if self._modules.alpha_optim:
-            self.update_alpha(conservative_loss)
+            self.update_alpha(conservative_loss, grad_step)
         return CQLCriticLoss(
             critic_loss=loss.critic_loss + conservative_loss.sum(),
             conservative_loss=conservative_loss.sum(),
             alpha=get_parameter(self._modules.log_alpha).exp()[0][0],
         )
 
-    def update_alpha(self, conservative_loss: torch.Tensor) -> None:
+    def update_alpha(
+        self, conservative_loss: torch.Tensor, grad_step: int
+    ) -> None:
         assert self._modules.alpha_optim
         self._modules.alpha_optim.zero_grad()
         # the original implementation does scale the loss value
         loss = -conservative_loss.mean()
         loss.backward(retain_graph=True)
-        self._modules.alpha_optim.step()
+        self._modules.alpha_optim.step(grad_step)
 
     def _compute_policy_is_values(
         self,
6 changes: 3 additions & 3 deletions d3rlpy/algos/qlearning/torch/crr_impl.py
@@ -76,7 +76,7 @@ def __init__(
         self._target_update_interval = target_update_interval
 
     def compute_actor_loss(
-        self, batch: TorchMiniBatch, action: ActionOutput
+        self, batch: TorchMiniBatch, action: ActionOutput, grad_step: int
     ) -> DDPGBaseActorLoss:
         # compute log probability
         dist = build_gaussian_distribution(action)
@@ -187,8 +187,8 @@ def inner_update(
     ) -> Dict[str, float]:
         metrics = {}
         action = self._modules.policy(batch.observations)
-        metrics.update(self.update_critic(batch))
-        metrics.update(self.update_actor(batch, action))
+        metrics.update(self.update_critic(batch, grad_step))
+        metrics.update(self.update_actor(batch, action, grad_step))
 
         if self._target_update_type == "hard":
             if grad_step % self._target_update_interval == 0: