From 1d28d0b220cf4a875cdee1c280dfb68e76ce97cf Mon Sep 17 00:00:00 2001 From: takuseno Date: Mon, 4 Nov 2024 11:28:09 +0900 Subject: [PATCH] Add SimBaEncoderFactory --- d3rlpy/models/encoders.py | 53 ++++++++++++++++++++++ d3rlpy/models/torch/encoders.py | 68 ++++++++++++++++++++++++++++ tests/models/test_encoders.py | 33 ++++++++++++++ tests/models/torch/test_encoders.py | 70 +++++++++++++++++++++++++++++ 4 files changed, 224 insertions(+) diff --git a/d3rlpy/models/encoders.py b/d3rlpy/models/encoders.py index b26e549e..3230b458 100644 --- a/d3rlpy/models/encoders.py +++ b/d3rlpy/models/encoders.py @@ -12,6 +12,7 @@ VectorEncoder, VectorEncoderWithAction, ) +from .torch.encoders import SimBaEncoder, SimBaEncoderWithAction from .utility import create_activation __all__ = [ @@ -19,6 +20,7 @@ "PixelEncoderFactory", "VectorEncoderFactory", "DefaultEncoderFactory", + "SimBaEncoderFactory", "register_encoder_factory", "make_encoder_field", ] @@ -263,6 +265,56 @@ def get_type() -> str: return "default" +@dataclass() +class SimBaEncoderFactory(EncoderFactory): + """SimBa encoder factory class. + + This class implements SimBa encoder architecture. + + References: + * `Lee et al., SimBa: Simplicity Bias for Scaling Up Parameters in Deep + Reinforcement Learning, <https://arxiv.org/abs/2410.09754>`_ + + Args: + feature_size (int): Feature unit size. + hidden_size (int): Hidden expansion layer unit size. + n_blocks (int): Number of SimBa blocks. 
+ """ + + feature_size: int = 256 + hidden_size: int = 1024 + n_blocks: int = 1 + + def create(self, observation_shape: Shape) -> SimBaEncoder: + assert len(observation_shape) == 1 + return SimBaEncoder( + observation_shape=cast_flat_shape(observation_shape), + hidden_size=self.hidden_size, + output_size=self.feature_size, + n_blocks=self.n_blocks, + ) + + def create_with_action( + self, + observation_shape: Shape, + action_size: int, + discrete_action: bool = False, + ) -> SimBaEncoderWithAction: + assert len(observation_shape) == 1 + return SimBaEncoderWithAction( + observation_shape=cast_flat_shape(observation_shape), + action_size=action_size, + hidden_size=self.hidden_size, + output_size=self.feature_size, + n_blocks=self.n_blocks, + discrete_action=discrete_action, + ) + + @staticmethod + def get_type() -> str: + return "simba" + + register_encoder_factory, make_encoder_field = generate_config_registration( EncoderFactory, lambda: DefaultEncoderFactory() ) @@ -271,3 +323,4 @@ def get_type() -> str: register_encoder_factory(VectorEncoderFactory) register_encoder_factory(PixelEncoderFactory) register_encoder_factory(DefaultEncoderFactory) +register_encoder_factory(SimBaEncoderFactory) diff --git a/d3rlpy/models/torch/encoders.py b/d3rlpy/models/torch/encoders.py index e57be61e..2fc9caaf 100644 --- a/d3rlpy/models/torch/encoders.py +++ b/d3rlpy/models/torch/encoders.py @@ -15,6 +15,8 @@ "PixelEncoderWithAction", "VectorEncoder", "VectorEncoderWithAction", + "SimBaEncoder", + "SimBaEncoderWithAction", "compute_output_size", ] @@ -290,6 +292,72 @@ def forward( return self._layers(x) +class SimBaBlock(nn.Module): # type: ignore + def __init__(self, input_size: int, hidden_size: int, out_size: int): + super().__init__() + layers = [ + nn.LayerNorm(input_size), + nn.Linear(input_size, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, out_size) + ] + self._layers = nn.Sequential(*layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + 
self._layers(x) + + +class SimBaEncoder(Encoder): + def __init__( + self, + observation_shape: Sequence[int], + hidden_size: int, + output_size: int, + n_blocks: int, + ): + super().__init__() + layers = [ + nn.Linear(observation_shape[0], output_size), + *[SimBaBlock(output_size, hidden_size, output_size) for _ in range(n_blocks)], + nn.LayerNorm(output_size), + ] + self._layers = nn.Sequential(*layers) + + def forward(self, x: TorchObservation) -> torch.Tensor: + assert isinstance(x, torch.Tensor) + return self._layers(x) + + +class SimBaEncoderWithAction(EncoderWithAction): + def __init__( + self, + observation_shape: Sequence[int], + action_size: int, + hidden_size: int, + output_size: int, + n_blocks: int, + discrete_action: bool, + ): + super().__init__() + layers = [ + nn.Linear(observation_shape[0] + action_size, output_size), + *[SimBaBlock(output_size, hidden_size, output_size) for _ in range(n_blocks)], + nn.LayerNorm(output_size), + ] + self._layers = nn.Sequential(*layers) + self._action_size = action_size + self._discrete_action = discrete_action + + def forward(self, x: TorchObservation, action: torch.Tensor) -> torch.Tensor: + assert isinstance(x, torch.Tensor) + if self._discrete_action: + action = F.one_hot( + action.view(-1).long(), num_classes=self._action_size + ).float() + h = torch.cat([x, action], dim=1) + return self._layers(h) + + def compute_output_size( input_shapes: Sequence[Shape], encoder: nn.Module ) -> int: diff --git a/tests/models/test_encoders.py b/tests/models/test_encoders.py index e931b56e..e19417c3 100644 --- a/tests/models/test_encoders.py +++ b/tests/models/test_encoders.py @@ -6,11 +6,14 @@ from d3rlpy.models.encoders import ( DefaultEncoderFactory, PixelEncoderFactory, + SimBaEncoderFactory, VectorEncoderFactory, ) from d3rlpy.models.torch.encoders import ( PixelEncoder, PixelEncoderWithAction, + SimBaEncoder, + SimBaEncoderWithAction, VectorEncoder, VectorEncoderWithAction, ) @@ -104,3 +107,33 @@ def 
test_default_encoder_factory( # check serization and deserialization DefaultEncoderFactory.deserialize(factory.serialize()) + + +@pytest.mark.parametrize("observation_shape", [(100,)]) +@pytest.mark.parametrize("action_size", [2]) +@pytest.mark.parametrize("discrete_action", [False, True]) +def test_simba_encoder_factory( + observation_shape: Sequence[int], + action_size: int, + discrete_action: bool, +) -> None: + factory = SimBaEncoderFactory() + + # test state encoder + encoder = factory.create(observation_shape) + assert isinstance(encoder, SimBaEncoder) + + # test state-action encoder + encoder = factory.create_with_action( + observation_shape, action_size, discrete_action + ) + assert isinstance(encoder, SimBaEncoderWithAction) + assert encoder._discrete_action == discrete_action + + assert factory.get_type() == "simba" + + # check serialization and deserialization + new_factory = SimBaEncoderFactory.deserialize(factory.serialize()) + assert new_factory.hidden_size == factory.hidden_size + assert new_factory.feature_size == factory.feature_size + assert new_factory.n_blocks == factory.n_blocks diff --git a/tests/models/torch/test_encoders.py b/tests/models/torch/test_encoders.py index 92c9af9e..ac9ab5dc 100644 --- a/tests/models/torch/test_encoders.py +++ b/tests/models/torch/test_encoders.py @@ -7,6 +7,8 @@ from d3rlpy.models.torch.encoders import ( PixelEncoder, PixelEncoderWithAction, + SimBaEncoder, + SimBaEncoderWithAction, VectorEncoder, VectorEncoderWithAction, ) @@ -212,3 +214,71 @@ def test_vector_encoder_with_action( # check layer connection check_parameter_updates(encoder, (x, action)) + + +@pytest.mark.parametrize("observation_shape", [(100,)]) +@pytest.mark.parametrize("hidden_size", [128]) +@pytest.mark.parametrize("output_size", [256]) +@pytest.mark.parametrize("n_blocks", [2]) +@pytest.mark.parametrize("batch_size", [32]) +def test_simba_encoder( + observation_shape: Sequence[int], + hidden_size: int, + output_size: int, + n_blocks: int, + 
batch_size: int +) -> None: + encoder = SimBaEncoder( + observation_shape=observation_shape, + hidden_size=hidden_size, + output_size=output_size, + n_blocks=n_blocks, + ) + + x = torch.rand((batch_size, *observation_shape)) + y = encoder(x) + + # check output shape + assert y.shape == (batch_size, output_size) + + # check layer connection + check_parameter_updates(encoder, (x,)) + + +@pytest.mark.parametrize("observation_shape", [(100,)]) +@pytest.mark.parametrize("action_size", [2]) +@pytest.mark.parametrize("hidden_size", [128]) +@pytest.mark.parametrize("output_size", [256]) +@pytest.mark.parametrize("n_blocks", [2]) +@pytest.mark.parametrize("batch_size", [32]) +@pytest.mark.parametrize("discrete_action", [False, True]) +def test_simba_encoder_with_action( + observation_shape: Sequence[int], + action_size: int, + hidden_size: int, + output_size: int, + n_blocks: int, + batch_size: int, + discrete_action: bool, +) -> None: + encoder = SimBaEncoderWithAction( + observation_shape=observation_shape, + action_size=action_size, + hidden_size=hidden_size, + output_size=output_size, + n_blocks=n_blocks, + discrete_action=discrete_action, + ) + + x = torch.rand((batch_size, *observation_shape)) + if discrete_action: + action = torch.randint(0, action_size, size=(batch_size, 1)) + else: + action = torch.rand(batch_size, action_size) + y = encoder(x, action) + + # check output shape + assert y.shape == (batch_size, output_size) + + # check layer connection + check_parameter_updates(encoder, (x, action))