Commit b213bb3: Fix tests
takuseno committed Oct 18, 2024
1 parent 198f9e8
Showing 2 changed files with 16 additions and 10 deletions.

d3rlpy/algos/transformer/decision_transformer.py (3 changes: 2 additions & 1 deletion)

@@ -117,7 +117,8 @@ def inner_create_impl(
             transformer.named_modules(), lr=self._config.learning_rate
         )
         scheduler = torch.optim.lr_scheduler.LambdaLR(
-            optim, lambda steps: min((steps + 1) / self._config.warmup_steps, 1)
+            optim.optim,
+            lambda steps: min((steps + 1) / self._config.warmup_steps, 1),
         )
 
         # JIT compile
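The change above suggests that the object returned by the optimizer factory is now a wrapper that carries the raw torch optimizer on an `.optim` attribute, while `torch.optim.lr_scheduler.LambdaLR` still requires a plain `torch.optim.Optimizer`. The sketch below illustrates that constraint; `_WrappedOptim` and the `10_000` warmup value are made up for the example and are not d3rlpy's actual classes or defaults.

import torch

# Minimal stand-in for a wrapper that exposes the raw optimizer as `.optim`,
# mirroring the access pattern used in the diff.
class _WrappedOptim:
    def __init__(self, params, optim):
        self._params = params
        self.optim = optim  # the underlying torch.optim.Optimizer

fc = torch.nn.Linear(4, 4)
params = list(fc.parameters())
wrapped = _WrappedOptim(params, torch.optim.Adam(params))

# LambdaLR type-checks its first argument and rejects anything that is not a
# torch.optim.Optimizer, so the inner optimizer must be passed via `.optim`.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    wrapped.optim,
    lambda steps: min((steps + 1) / 10_000, 1),  # 10_000 stands in for warmup_steps
)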
tests/test_torch_utility.py (23 changes: 14 additions & 9 deletions)

@@ -9,6 +9,7 @@
 import torch
 
 from d3rlpy.dataset import TrajectoryMiniBatch, Transition, TransitionMiniBatch
+from d3rlpy.models import OptimizerWrapper
 from d3rlpy.torch_utility import (
     GEGLU,
     Checkpointer,
@@ -114,7 +115,8 @@ class DummyImpl:
     def __init__(self) -> None:
         self.fc1 = torch.nn.Linear(100, 100)
         self.fc2 = torch.nn.Linear(100, 100)
-        self.optim = torch.optim.Adam(self.fc1.parameters())
+        params = list(self.fc1.parameters())
+        self.optim = OptimizerWrapper(params, torch.optim.Adam(params))
         self.modules = DummyModules(self.fc1, self.optim)
         self.device = "cpu:0"
@@ -143,16 +145,16 @@ def test_reset_optimizer_states() -> None:
     # instantiate optimizer state
     y = impl.fc1(torch.rand(100)).sum()
     y.backward()
-    impl.optim.step()
+    impl.optim.step(0)
 
     # check if state is not empty
-    state = copy.deepcopy(impl.optim.state)
+    state = copy.deepcopy(impl.optim.optim.state)
     assert state
 
     impl.modules.reset_optimizer_states()
 
     # check if state is empty
-    reset_state = impl.optim.state
+    reset_state = impl.optim.optim.state
     assert not reset_state
@@ -183,12 +185,13 @@ def test_get_batch_size() -> None:
 @dataclasses.dataclass(frozen=True)
 class DummyModules(Modules):
     fc: torch.nn.Linear
-    optim: torch.optim.Adam
+    optim: OptimizerWrapper
 
 
 def test_modules() -> None:
     fc = torch.nn.Linear(100, 200)
-    optim = torch.optim.Adam(fc.parameters())
+    params = list(fc.parameters())
+    optim = OptimizerWrapper(params, torch.optim.Adam(params))
     modules = DummyModules(fc, optim)
 
     # check checkpointer
@@ -398,7 +401,8 @@ def test_torch_trajectory_mini_batch(
 def test_checkpointer() -> None:
     fc1 = torch.nn.Linear(100, 100)
     fc2 = torch.nn.Linear(100, 100)
-    optim = torch.optim.Adam(fc1.parameters())
+    params = list(fc1.parameters())
+    optim = OptimizerWrapper(params, torch.optim.Adam(params))
     checkpointer = Checkpointer(
         modules={"fc1": fc1, "fc2": fc2, "optim": optim}, device="cpu:0"
     )
@@ -408,7 +412,7 @@ def test_checkpointer() -> None:
     states = {
         "fc1": fc1.state_dict(),
         "fc2": fc2.state_dict(),
-        "optim": optim.state_dict(),
+        "optim": optim.optim.state_dict(),
     }
     torch.save(states, ref_bytes)
@@ -419,7 +423,8 @@ def test_checkpointer() -> None:
 
     fc1_2 = torch.nn.Linear(100, 100)
     fc2_2 = torch.nn.Linear(100, 100)
-    optim_2 = torch.optim.Adam(fc1_2.parameters())
+    params_2 = list(fc1_2.parameters())
+    optim_2 = OptimizerWrapper(params_2, torch.optim.Adam(params_2))
     checkpointer = Checkpointer(
         modules={"fc1": fc1_2, "fc2": fc2_2, "optim": optim_2}, device="cpu:0"
     )
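Taken together, the test changes show the interface the wrapper is expected to provide: it is constructed as `OptimizerWrapper(params, torch_optimizer)`, its `step` takes the current gradient step (the tests pass `0`), and optimizer state and checkpoints are read through the inner optimizer via `.optim`. The following sketch reproduces only that inferred surface; `OptimizerWrapperSketch` is a hypothetical stand-in, not d3rlpy's `OptimizerWrapper`, which likely does more (the real class lives in `d3rlpy.models`).

import torch

# Hypothetical stand-in illustrating the interface the updated tests rely on.
class OptimizerWrapperSketch:
    def __init__(self, params, optim):
        self._params = list(params)
        self.optim = optim  # raw torch optimizer, exposed for state/state_dict access

    def step(self, grad_step):
        # the tests call step(0), i.e. the wrapper receives the current gradient step
        self.optim.step()

fc = torch.nn.Linear(100, 100)
params = list(fc.parameters())
optim = OptimizerWrapperSketch(params, torch.optim.Adam(params))

loss = fc(torch.rand(100)).sum()
loss.backward()
optim.step(0)

# State and serialization go through the inner optimizer, which is why the
# tests read `optim.optim.state` and save `optim.optim.state_dict()`.
assert optim.optim.state
states = {"optim": optim.optim.state_dict()}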
