Add optional ``single_action_space`` / ``single_observation_space`` arguments to
the ``__init__`` of the vector transform wrappers in ``vectorize_action.py`` and
``vectorize_observation.py`` (exercised by the new tests through
``wrappers.vector.TransformAction`` / ``wrappers.vector.TransformObservation``).

Behaviour introduced by this patch:
  * If only the single space is given, the batched space is derived with
    ``batch_space(single_space, self.num_envs)``.
  * If the batched space is given, it is used as-is and the single space is
    overridden only when it is also given.
  * After both spaces are resolved, a warning is emitted when the wrapper's
    batched space differs from ``batch_space(single_space, self.num_envs)``.
    The warning reports the wrapper's own (``self``) spaces, i.e. the exact
    operands of the failed comparison.

diff --git a/gymnasium/wrappers/vector/vectorize_action.py b/gymnasium/wrappers/vector/vectorize_action.py
index f0f0b8e57..8fc607107 100644
--- a/gymnasium/wrappers/vector/vectorize_action.py
+++ b/gymnasium/wrappers/vector/vectorize_action.py
@@ -9,6 +9,7 @@
 
 from gymnasium import Space
 from gymnasium.core import ActType, Env
+from gymnasium.logger import warn
 from gymnasium.vector import VectorActionWrapper, VectorEnv
 from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate
 from gymnasium.wrappers import transform_action
@@ -61,18 +62,31 @@ def __init__(
         env: VectorEnv,
         func: Callable[[ActType], Any],
         action_space: Space | None = None,
+        single_action_space: Space | None = None,
     ):
         """Constructor for the lambda action wrapper.
 
         Args:
             env: The vector environment to wrap
             func: A function that will transform an action. If this transformed action is outside the action space of ``env.action_space`` then provide an ``action_space``.
-            action_space: The action spaces of the wrapper, if None, then it is assumed the same as ``env.action_space``.
+            action_space: The action spaces of the wrapper. If None, then it is computed from ``single_action_space``. If ``single_action_space`` is not provided either, then it is assumed to be the same as ``env.action_space``.
+            single_action_space: The action space of the non-vectorized environment. If None, then it is assumed the same as ``env.single_action_space``.
         """
         super().__init__(env)
 
-        if action_space is not None:
+        if action_space is None:
+            if single_action_space is not None:
+                self.single_action_space = single_action_space
+                self.action_space = batch_space(single_action_space, self.num_envs)
+        else:
             self.action_space = action_space
+            if single_action_space is not None:
+                self.single_action_space = single_action_space
+            # TODO: We could compute single_action_space from the action_space if only the latter is provided and avoid the warning below.
+        if self.action_space != batch_space(self.single_action_space, self.num_envs):
+            warn(
+                f"For {env}, the action space and the batched single action space don't match as expected, action_space={self.action_space}, batched single_action_space={batch_space(self.single_action_space, self.num_envs)}"
+            )
 
         self.func = func
 
diff --git a/gymnasium/wrappers/vector/vectorize_observation.py b/gymnasium/wrappers/vector/vectorize_observation.py
index 52b5b9a07..5ca07a1e4 100644
--- a/gymnasium/wrappers/vector/vectorize_observation.py
+++ b/gymnasium/wrappers/vector/vectorize_observation.py
@@ -57,18 +57,35 @@ def __init__(
         env: VectorEnv,
         func: Callable[[ObsType], Any],
         observation_space: Space | None = None,
+        single_observation_space: Space | None = None,
     ):
         """Constructor for the transform observation wrapper.
 
         Args:
             env: The vector environment to wrap
             func: A function that will transform the vector observation. If this transformed observation is outside the observation space of ``env.observation_space`` then provide an ``observation_space``.
-            observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as ``env.observation_space``.
+            observation_space: The observation spaces of the wrapper. If None, then it is computed from ``single_observation_space``. If ``single_observation_space`` is not provided either, then it is assumed to be the same as ``env.observation_space``.
+            single_observation_space: The observation space of the non-vectorized environment. If None, then it is assumed the same as ``env.single_observation_space``.
         """
         super().__init__(env)
 
-        if observation_space is not None:
+        if observation_space is None:
+            if single_observation_space is not None:
+                self.single_observation_space = single_observation_space
+                self.observation_space = batch_space(
+                    single_observation_space, self.num_envs
+                )
+        else:
             self.observation_space = observation_space
+            if single_observation_space is not None:
+                self.single_observation_space = single_observation_space
+            # TODO: We could compute single_observation_space from the observation_space if only the latter is provided and avoid the warning below.
+        if self.observation_space != batch_space(
+            self.single_observation_space, self.num_envs
+        ):
+            warn(
+                f"For {env}, the observation space and the batched single observation space don't match as expected, observation_space={self.observation_space}, batched single_observation_space={batch_space(self.single_observation_space, self.num_envs)}"
+            )
 
         self.func = func
 
diff --git a/tests/wrappers/vector/test_transform_action.py b/tests/wrappers/vector/test_transform_action.py
new file mode 100644
index 000000000..5b0d3987d
--- /dev/null
+++ b/tests/wrappers/vector/test_transform_action.py
@@ -0,0 +1,73 @@
+"""Test suite for vector TransformAction wrapper."""
+
+import numpy as np
+import pytest
+
+from gymnasium import spaces, wrappers
+from gymnasium.vector import SyncVectorEnv
+from tests.testing_env import GenericTestEnv
+
+
+def create_env():
+    return GenericTestEnv(
+        action_space=spaces.Box(
+            low=np.array([0, -10, -5], dtype=np.float32),
+            high=np.array([10, -5, 10], dtype=np.float32),
+        )
+    )
+
+
+def test_action_space_from_single_action_space(
+    n_envs: int = 5,
+):
+    vec_env = SyncVectorEnv([create_env for _ in range(n_envs)])
+    vec_env = wrappers.vector.TransformAction(
+        vec_env,
+        func=lambda x: x + 100,
+        single_action_space=spaces.Box(
+            low=np.array([0, -10, -5], dtype=np.float32) + 100,
+            high=np.array([10, -5, 10], dtype=np.float32) + 100,
+        ),
+    )
+
+    # Check action space
+    assert isinstance(vec_env.action_space, spaces.Box)
+    assert vec_env.action_space.shape == (n_envs, 3)
+    assert vec_env.action_space.dtype == np.float32
+    assert (
+        vec_env.action_space.low == np.array([[100, 90, 95]] * n_envs, dtype=np.float32)
+    ).all()
+    assert (
+        vec_env.action_space.high
+        == np.array([[110, 95, 110]] * n_envs, dtype=np.float32)
+    ).all()
+
+    # Check single action space
+    assert isinstance(vec_env.single_action_space, spaces.Box)
+    assert vec_env.single_action_space.shape == (3,)
+    assert vec_env.single_action_space.dtype == np.float32
+    assert (
+        vec_env.single_action_space.low == np.array([100, 90, 95], dtype=np.float32)
+    ).all()
+    assert (
+        vec_env.single_action_space.high == np.array([110, 95, 110], dtype=np.float32)
+    ).all()
+
+
+def test_warning_on_mismatched_single_action_space(
+    n_envs: int = 2,
+):
+    vec_env = SyncVectorEnv([create_env for _ in range(n_envs)])
+    # We only specify action_space without single_action_space, so single_action_space inherits its value from the wrapped env which would not match. This mismatch should give us a warning.
+    with pytest.warns(
+        Warning,
+        match=r"the action space and the batched single action space don't match as expected",
+    ):
+        vec_env = wrappers.vector.TransformAction(
+            vec_env,
+            func=lambda x: x + 100,
+            action_space=spaces.Box(
+                low=np.array([[0, -10, -5]] * n_envs, dtype=np.float32) + 100,
+                high=np.array([[10, -5, 10]] * n_envs, dtype=np.float32) + 100,
+            ),
+        )
diff --git a/tests/wrappers/vector/test_transform_observation.py b/tests/wrappers/vector/test_transform_observation.py
new file mode 100644
index 000000000..98b39b129
--- /dev/null
+++ b/tests/wrappers/vector/test_transform_observation.py
@@ -0,0 +1,100 @@
+"""Test suite for vector TransformObservation wrapper."""
+
+import numpy as np
+import pytest
+
+from gymnasium import spaces, wrappers
+from gymnasium.vector import SyncVectorEnv
+from tests.testing_env import GenericTestEnv
+
+
+def create_env():
+    return GenericTestEnv(
+        observation_space=spaces.Box(
+            low=np.array([0, -10, -5], dtype=np.float32),
+            high=np.array([10, -5, 10], dtype=np.float32),
+        )
+    )
+
+
+def test_transform(n_envs: int = 2):
+    vec_env = SyncVectorEnv([create_env for _ in range(n_envs)])
+    vec_env = wrappers.vector.TransformObservation(
+        vec_env,
+        func=lambda x: x + 100,
+        single_observation_space=spaces.Box(
+            low=np.array([0, -10, -5], dtype=np.float32) + 100,
+            high=np.array([10, -5, 10], dtype=np.float32) + 100,
+        ),
+    )
+
+    obs, _ = vec_env.reset(seed=123)
+    vec_env.observation_space.seed(123)
+    vec_env.action_space.seed(123)
+
+    assert (obs >= np.array([100, 90, 95], dtype=np.float32)).all()
+    assert (obs <= np.array([110, 95, 110], dtype=np.float32)).all()
+
+    obs, *_ = vec_env.step(vec_env.action_space.sample())
+
+    assert (obs >= np.array([100, 90, 95], dtype=np.float32)).all()
+    assert (obs <= np.array([110, 95, 110], dtype=np.float32)).all()
+
+
+def test_observation_space_from_single_observation_space(
+    n_envs: int = 5,
+):
+    vec_env = SyncVectorEnv([create_env for _ in range(n_envs)])
+    vec_env = wrappers.vector.TransformObservation(
+        vec_env,
+        func=lambda x: x + 100,
+        single_observation_space=spaces.Box(
+            low=np.array([0, -10, -5], dtype=np.float32) + 100,
+            high=np.array([10, -5, 10], dtype=np.float32) + 100,
+        ),
+    )
+
+    # Check observation space
+    assert isinstance(vec_env.observation_space, spaces.Box)
+    assert vec_env.observation_space.shape == (n_envs, 3)
+    assert vec_env.observation_space.dtype == np.float32
+    assert (
+        vec_env.observation_space.low
+        == np.array([[100, 90, 95]] * n_envs, dtype=np.float32)
+    ).all()
+    assert (
+        vec_env.observation_space.high
+        == np.array([[110, 95, 110]] * n_envs, dtype=np.float32)
+    ).all()
+
+    # Check single observation space
+    assert isinstance(vec_env.single_observation_space, spaces.Box)
+    assert vec_env.single_observation_space.shape == (3,)
+    assert vec_env.single_observation_space.dtype == np.float32
+    assert (
+        vec_env.single_observation_space.low
+        == np.array([100, 90, 95], dtype=np.float32)
+    ).all()
+    assert (
+        vec_env.single_observation_space.high
+        == np.array([110, 95, 110], dtype=np.float32)
+    ).all()
+
+
+def test_warning_on_mismatched_single_observation_space(
+    n_envs: int = 2,
+):
+    vec_env = SyncVectorEnv([create_env for _ in range(n_envs)])
+    # We only specify observation_space without single_observation_space, so single_observation_space inherits its value from the wrapped env which would not match. This mismatch should give us a warning.
+    with pytest.warns(
+        Warning,
+        match=r"the observation space and the batched single observation space don't match as expected",
+    ):
+        vec_env = wrappers.vector.TransformObservation(
+            vec_env,
+            func=lambda x: x + 100,
+            observation_space=spaces.Box(
+                low=np.array([[0, -10, -5]] * n_envs, dtype=np.float32) + 100,
+                high=np.array([[10, -5, 10]] * n_envs, dtype=np.float32) + 100,
+            ),
+        )