From 509ebc5e6ced4e08e02d7f1ecfad475209830436 Mon Sep 17 00:00:00 2001 From: Howard <174055+howardh@users.noreply.github.com> Date: Tue, 7 Jan 2025 14:53:57 -0500 Subject: [PATCH 1/5] Transform[Observation/Action] single_[observation/action]_space fix (Farama-Foundation/Gymnasium#1287) --- gymnasium/wrappers/vector/vectorize_action.py | 36 +++++++- .../wrappers/vector/vectorize_observation.py | 25 +++++- .../wrappers/vector/test_transform_action.py | 41 +++++++++ .../vector/test_transform_observation.py | 89 +++++++++++++++++++ 4 files changed, 187 insertions(+), 4 deletions(-) create mode 100644 tests/wrappers/vector/test_transform_action.py create mode 100644 tests/wrappers/vector/test_transform_observation.py diff --git a/gymnasium/wrappers/vector/vectorize_action.py b/gymnasium/wrappers/vector/vectorize_action.py index f0f0b8e57..7baeda0d9 100644 --- a/gymnasium/wrappers/vector/vectorize_action.py +++ b/gymnasium/wrappers/vector/vectorize_action.py @@ -61,18 +61,43 @@ def __init__( env: VectorEnv, func: Callable[[ActType], Any], action_space: Space | None = None, + single_action_space: Space | None = None, ): """Constructor for the lambda action wrapper. Args: env: The vector environment to wrap func: A function that will transform an action. If this transformed action is outside the action space of ``env.action_space`` then provide an ``action_space``. - action_space: The action spaces of the wrapper, if None, then it is assumed the same as ``env.action_space``. + action_space: The action spaces of the wrapper. If None, then it is computed from ``single_action_space``. If ``single_action_space`` is not provided either, then it is assumed to be the same as ``env.action_space``. + single_action_space: The action space of the non-vectorized environment. If None, then it is assumed the same as ``env.single_action_space``. """ super().__init__(env) - if action_space is not None: + """ + self._single_observation_space_error = None + self._single_observation_space = self.env.single_observation_space + if observation_space is None: + if single_observation_space is not None: + self.observation_space = batch_space(single_observation_space, self.num_envs) + else: + self.observation_space = observation_space + if single_observation_space is None: + # TODO: We could compute this from the observation_space. + self._single_observation_space_error = "`single_observation_space` not defined. A new observation space was provided to the TransformObservation wrapper, but not the single observation space." + else: + self._single_observation_space = single_observation_space + """ + self._single_action_space_error = None + self._single_action_space = self.env.single_action_space + if action_space is None: + if single_action_space is not None: + self.action_space = batch_space(single_action_space, self.num_envs) + else: self.action_space = action_space + if single_action_space is None: + self._single_action_space_error = "`single_action_space` not defined. A new action space was provided to the TransformAction wrapper, but not the single action space." + else: + self._single_action_space = single_action_space self.func = func @@ -80,6 +105,13 @@ def actions(self, actions: ActType) -> ActType: """Applies the :attr:`func` to the actions.""" return self.func(actions) + @property + def single_action_space(self) -> Space: + """The single observation space of the environment.""" + if self._single_action_space_error is not None: + raise AttributeError(self._single_action_space_error) + return self._single_action_space + class VectorizeTransformAction(VectorActionWrapper): """Vectorizes a single-agent transform action wrapper for vector environments. diff --git a/gymnasium/wrappers/vector/vectorize_observation.py b/gymnasium/wrappers/vector/vectorize_observation.py index 52b5b9a07..c23a9d257 100644 --- a/gymnasium/wrappers/vector/vectorize_observation.py +++ b/gymnasium/wrappers/vector/vectorize_observation.py @@ -57,18 +57,32 @@ def __init__( env: VectorEnv, func: Callable[[ObsType], Any], observation_space: Space | None = None, + single_observation_space: Space | None = None, ): """Constructor for the transform observation wrapper. Args: env: The vector environment to wrap func: A function that will transform the vector observation. If this transformed observation is outside the observation space of ``env.observation_space`` then provide an ``observation_space``. - observation_space: The observation spaces of the wrapper, if None, then it is assumed the same as ``env.observation_space``. + observation_space: The observation spaces of the wrapper. If None, then it is computed from ``single_observation_space``. If ``single_observation_space`` is not provided either, then it is assumed to be the same as ``env.observation_space``. + single_observation_space: The observation space of the non-vectorized environment. If None, then it is assumed the same as ``env.single_observation_space``. """ super().__init__(env) - if observation_space is not None: + self._single_observation_space_error = None + self._single_observation_space = self.env.single_observation_space + if observation_space is None: + if single_observation_space is not None: + self.observation_space = batch_space( + single_observation_space, self.num_envs + ) + else: self.observation_space = observation_space + if single_observation_space is None: + # TODO: We could compute this from the observation_space. + self._single_observation_space_error = "`single_observation_space` not defined. A new observation space was provided to the TransformObservation wrapper, but not the single observation space." + else: + self._single_observation_space = single_observation_space self.func = func @@ -76,6 +90,13 @@ def observations(self, observations: ObsType) -> ObsType: """Apply function to the vector observation.""" return self.func(observations) + @property + def single_observation_space(self) -> Space: + """Returns the single observation space.""" + if self._single_observation_space_error is not None: + raise AttributeError(self._single_observation_space_error) + return self._single_observation_space + class VectorizeTransformObservation(VectorObservationWrapper): """Vectorizes a single-agent transform observation wrapper for vector environments. diff --git a/tests/wrappers/vector/test_transform_action.py b/tests/wrappers/vector/test_transform_action.py new file mode 100644 index 000000000..a66eff178 --- /dev/null +++ b/tests/wrappers/vector/test_transform_action.py @@ -0,0 +1,41 @@ +"""Test suite for vector TransformObservation wrapper.""" + +import numpy as np + +from gymnasium import spaces, wrappers +from gymnasium.vector import SyncVectorEnv +from tests.testing_env import GenericTestEnv + + +def create_env(): + return GenericTestEnv( + action_space=spaces.Box( + low=np.array([0, -10, -5], dtype=np.float32), + high=np.array([10, -5, 10], dtype=np.float32), + ) + ) + + +def test_observation_space_from_single_observation_space( + n_envs: int = 5, +): + vec_env = SyncVectorEnv([create_env for _ in range(n_envs)]) + vec_env = wrappers.vector.TransformAction( + vec_env, + func=lambda x: x + 100, + single_action_space=spaces.Box( + low=np.array([0, -10, -5], dtype=np.float32) + 100, + high=np.array([10, -5, 10], dtype=np.float32) + 100, + ), + ) + + assert isinstance(vec_env.action_space, spaces.Box) + assert vec_env.action_space.shape == (n_envs, 3) + assert vec_env.action_space.dtype == np.float32 + assert ( + vec_env.action_space.low == np.array([[100, 90, 95]] * n_envs, dtype=np.float32) + ).all() + assert ( + vec_env.action_space.high + == np.array([[110, 95, 110]] * n_envs, dtype=np.float32) + ).all() diff --git a/tests/wrappers/vector/test_transform_observation.py b/tests/wrappers/vector/test_transform_observation.py new file mode 100644 index 000000000..7915fb768 --- /dev/null +++ b/tests/wrappers/vector/test_transform_observation.py @@ -0,0 +1,89 @@ +"""Test suite for vector TransformObservation wrapper.""" + +import numpy as np +import pytest + +from gymnasium import spaces, wrappers +from gymnasium.vector import SyncVectorEnv +from tests.testing_env import GenericTestEnv + + +def create_env(): + return GenericTestEnv( + observation_space=spaces.Box( + low=np.array([0, -10, -5], dtype=np.float32), + high=np.array([10, -5, 10], dtype=np.float32), + ) + ) + + +def test_transform(n_envs: int = 2): + vec_env = SyncVectorEnv([create_env for _ in range(n_envs)]) + vec_env = wrappers.vector.TransformObservation( + vec_env, + func=lambda x: x + 100, + single_observation_space=spaces.Box( + low=np.array([0, -10, -5], dtype=np.float32), + high=np.array([10, -5, 10], dtype=np.float32), + ), + ) + + obs, _ = vec_env.reset(seed=123) + vec_env.observation_space.seed(123) + vec_env.action_space.seed(123) + + assert (obs >= np.array([100, 90, 95], dtype=np.float32)).all() + assert (obs <= np.array([110, 95, 110], dtype=np.float32)).all() + + obs, *_ = vec_env.step(vec_env.action_space.sample()) + + assert (obs >= np.array([100, 90, 95], dtype=np.float32)).all() + assert (obs <= np.array([110, 95, 110], dtype=np.float32)).all() + + +def test_observation_space_from_single_observation_space( + n_envs: int = 5, +): + vec_env = SyncVectorEnv([create_env for _ in range(n_envs)]) + vec_env = wrappers.vector.TransformObservation( + vec_env, + func=lambda x: x + 100, + single_observation_space=spaces.Box( + low=np.array([0, -10, -5], dtype=np.float32) + 100, + high=np.array([10, -5, 10], dtype=np.float32) + 100, + ), + ) + + assert isinstance(vec_env.observation_space, spaces.Box) + assert vec_env.observation_space.shape == (n_envs, 3) + assert vec_env.observation_space.dtype == np.float32 + assert ( + vec_env.observation_space.low + == np.array([[100, 90, 95]] * n_envs, dtype=np.float32) + ).all() + assert ( + vec_env.observation_space.high + == np.array([[110, 95, 110]] * n_envs, dtype=np.float32) + ).all() + + +def test_error_on_unspecified_single_observation_space( + n_envs: int = 5, +): + vec_env = SyncVectorEnv([create_env for _ in range(n_envs)]) + vec_env = wrappers.vector.TransformObservation( + vec_env, + func=lambda x: x + 100, + observation_space=spaces.Box( + low=np.array([[0, -10, -5]] * n_envs, dtype=np.float32) + 100, + high=np.array([[10, -5, 10]] * n_envs, dtype=np.float32) + 100, + ), + ) + + # Environment should still work normally + obs, _ = vec_env.reset() + obs, *_ = vec_env.step(vec_env.action_space.sample()) + + # But if we try to access the single_observation_space, it should error + with pytest.raises(AttributeError): + vec_env.single_observation_space From d410f04a6aee991a0197b49bdf938eeba0e21db4 Mon Sep 17 00:00:00 2001 From: Howard <174055+howardh@users.noreply.github.com> Date: Tue, 7 Jan 2025 15:19:13 -0500 Subject: [PATCH 2/5] Fix --- gymnasium/wrappers/vector/vectorize_action.py | 1 + gymnasium/wrappers/vector/vectorize_observation.py | 1 + tests/wrappers/vector/test_transform_action.py | 12 ++++++++++++ .../wrappers/vector/test_transform_observation.py | 14 ++++++++++++++ 4 files changed, 28 insertions(+) diff --git a/gymnasium/wrappers/vector/vectorize_action.py b/gymnasium/wrappers/vector/vectorize_action.py index 7baeda0d9..15865f63e 100644 --- a/gymnasium/wrappers/vector/vectorize_action.py +++ b/gymnasium/wrappers/vector/vectorize_action.py @@ -92,6 +92,7 @@ def __init__( if action_space is None: if single_action_space is not None: self.action_space = batch_space(single_action_space, self.num_envs) + self._single_action_space = single_action_space else: self.action_space = action_space if single_action_space is None: diff --git a/gymnasium/wrappers/vector/vectorize_observation.py b/gymnasium/wrappers/vector/vectorize_observation.py index c23a9d257..7dfebb973 100644 --- a/gymnasium/wrappers/vector/vectorize_observation.py +++ b/gymnasium/wrappers/vector/vectorize_observation.py @@ -76,6 +76,7 @@ def __init__( self.observation_space = batch_space( single_observation_space, self.num_envs ) + self._single_observation_space = single_observation_space else: self.observation_space = observation_space if single_observation_space is None: diff --git a/tests/wrappers/vector/test_transform_action.py b/tests/wrappers/vector/test_transform_action.py index a66eff178..f8837c2cc 100644 --- a/tests/wrappers/vector/test_transform_action.py +++ b/tests/wrappers/vector/test_transform_action.py @@ -29,6 +29,7 @@ def test_observation_space_from_single_observation_space( ), ) + # Check action space assert isinstance(vec_env.action_space, spaces.Box) assert vec_env.action_space.shape == (n_envs, 3) assert vec_env.action_space.dtype == np.float32 @@ -39,3 +40,14 @@ def test_observation_space_from_single_observation_space( vec_env.action_space.high == np.array([[110, 95, 110]] * n_envs, dtype=np.float32) ).all() + + # Check single action space + assert isinstance(vec_env.single_action_space, spaces.Box) + assert vec_env.single_action_space.shape == (3,) + assert vec_env.single_action_space.dtype == np.float32 + assert ( + vec_env.single_action_space.low == np.array([100, 90, 95], dtype=np.float32) + ).all() + assert ( + vec_env.single_action_space.high == np.array([110, 95, 110], dtype=np.float32) + ).all() diff --git a/tests/wrappers/vector/test_transform_observation.py b/tests/wrappers/vector/test_transform_observation.py index 7915fb768..36b3ecc36 100644 --- a/tests/wrappers/vector/test_transform_observation.py +++ b/tests/wrappers/vector/test_transform_observation.py @@ -54,6 +54,7 @@ def test_observation_space_from_single_observation_space( ), ) + # Check observation space assert isinstance(vec_env.observation_space, spaces.Box) assert vec_env.observation_space.shape == (n_envs, 3) assert vec_env.observation_space.dtype == np.float32 @@ -66,6 +67,19 @@ def test_observation_space_from_single_observation_space( == np.array([[110, 95, 110]] * n_envs, dtype=np.float32) ).all() + # Check single observation space + assert isinstance(vec_env.single_observation_space, spaces.Box) + assert vec_env.single_observation_space.shape == (3,) + assert vec_env.single_observation_space.dtype == np.float32 + assert ( + vec_env.single_observation_space.low + == np.array([100, 90, 95], dtype=np.float32) + ).all() + assert ( + vec_env.single_observation_space.high + == np.array([110, 95, 110], dtype=np.float32) + ).all() + def test_error_on_unspecified_single_observation_space( n_envs: int = 5, From 702b21da17143daca42d655bd30596b10088f718 Mon Sep 17 00:00:00 2001 From: Howard <174055+howardh@users.noreply.github.com> Date: Wed, 8 Jan 2025 00:25:21 -0500 Subject: [PATCH 3/5] fix --- gymnasium/wrappers/vector/vectorize_action.py | 16 +--------------- tests/wrappers/vector/test_transform_action.py | 2 +- 2 files changed, 2 insertions(+), 16 deletions(-) diff --git a/gymnasium/wrappers/vector/vectorize_action.py b/gymnasium/wrappers/vector/vectorize_action.py index 15865f63e..7590bb023 100644 --- a/gymnasium/wrappers/vector/vectorize_action.py +++ b/gymnasium/wrappers/vector/vectorize_action.py @@ -73,20 +73,6 @@ def __init__( """ super().__init__(env) - """ - self._single_observation_space_error = None - self._single_observation_space = self.env.single_observation_space - if observation_space is None: - if single_observation_space is not None: - self.observation_space = batch_space(single_observation_space, self.num_envs) - else: - self.observation_space = observation_space - if single_observation_space is None: - # TODO: We could compute this from the observation_space. - self._single_observation_space_error = "`single_observation_space` not defined. A new observation space was provided to the TransformObservation wrapper, but not the single observation space." - else: - self._single_observation_space = single_observation_space - """ self._single_action_space_error = None self._single_action_space = self.env.single_action_space if action_space is None: @@ -108,7 +94,7 @@ def actions(self, actions: ActType) -> ActType: @property def single_action_space(self) -> Space: - """The single observation space of the environment.""" + """The single action space of the environment.""" if self._single_action_space_error is not None: raise AttributeError(self._single_action_space_error) return self._single_action_space diff --git a/tests/wrappers/vector/test_transform_action.py b/tests/wrappers/vector/test_transform_action.py index f8837c2cc..2d765e4ec 100644 --- a/tests/wrappers/vector/test_transform_action.py +++ b/tests/wrappers/vector/test_transform_action.py @@ -1,4 +1,4 @@ -"""Test suite for vector TransformObservation wrapper.""" +"""Test suite for vector TransformAction wrapper.""" import numpy as np From 2d98fc8bcc4a8b256e8595a9496e2970b641586a Mon Sep 17 00:00:00 2001 From: Howard <174055+howardh@users.noreply.github.com> Date: Wed, 8 Jan 2025 19:25:04 -0500 Subject: [PATCH 4/5] Fix --- gymnasium/wrappers/vector/vectorize_action.py | 23 ++++++--------- .../wrappers/vector/vectorize_observation.py | 23 ++++++--------- .../vector/test_transform_observation.py | 28 ++++++++----------- 3 files changed, 29 insertions(+), 45 deletions(-) diff --git a/gymnasium/wrappers/vector/vectorize_action.py b/gymnasium/wrappers/vector/vectorize_action.py index 7590bb023..a3e951d8e 100644 --- a/gymnasium/wrappers/vector/vectorize_action.py +++ b/gymnasium/wrappers/vector/vectorize_action.py @@ -9,6 +9,7 @@ from gymnasium import Space from gymnasium.core import ActType, Env +from gymnasium.logger import warn from gymnasium.vector import VectorActionWrapper, VectorEnv from gymnasium.vector.utils import batch_space, concatenate, create_empty_array, iterate from gymnasium.wrappers import transform_action @@ -73,18 +74,19 @@ def __init__( """ super().__init__(env) - self._single_action_space_error = None - self._single_action_space = self.env.single_action_space if action_space is None: if single_action_space is not None: + self.single_action_space = single_action_space self.action_space = batch_space(single_action_space, self.num_envs) - self._single_action_space = single_action_space else: self.action_space = action_space - if single_action_space is None: - self._single_action_space_error = "`single_action_space` not defined. A new action space was provided to the TransformAction wrapper, but not the single action space." - else: - self._single_action_space = single_action_space + if single_action_space is not None: + self.single_action_space = single_action_space + # TODO: We could compute single_action_space from the action_space if only the latter is provided and avoid the warning below. + if self.action_space != batch_space(self.single_action_space, self.num_envs): + warn( + "The action space and the batched single action space don't match as expected." + ) self.func = func @@ -92,13 +94,6 @@ def actions(self, actions: ActType) -> ActType: """Applies the :attr:`func` to the actions.""" return self.func(actions) - @property - def single_action_space(self) -> Space: - """The single action space of the environment.""" - if self._single_action_space_error is not None: - raise AttributeError(self._single_action_space_error) - return self._single_action_space - class VectorizeTransformAction(VectorActionWrapper): """Vectorizes a single-agent transform action wrapper for vector environments. diff --git a/gymnasium/wrappers/vector/vectorize_observation.py b/gymnasium/wrappers/vector/vectorize_observation.py index 7dfebb973..55fe10bb7 100644 --- a/gymnasium/wrappers/vector/vectorize_observation.py +++ b/gymnasium/wrappers/vector/vectorize_observation.py @@ -69,21 +69,23 @@ def __init__( """ super().__init__(env) - self._single_observation_space_error = None - self._single_observation_space = self.env.single_observation_space if observation_space is None: if single_observation_space is not None: + self.single_observation_space = single_observation_space self.observation_space = batch_space( single_observation_space, self.num_envs ) - self._single_observation_space = single_observation_space else: self.observation_space = observation_space - if single_observation_space is None: - # TODO: We could compute this from the observation_space. - self._single_observation_space_error = "`single_observation_space` not defined. A new observation space was provided to the TransformObservation wrapper, but not the single observation space." - else: + if single_observation_space is not None: self._single_observation_space = single_observation_space + # TODO: We could compute single_observation_space from the observation_space if only the latter is provided and avoid the warning below. + if self.observation_space != batch_space( + self.single_observation_space, self.num_envs + ): + warn( + "The observation space and the batched single observation space don't match as expected." + ) self.func = func @@ -91,13 +93,6 @@ def observations(self, observations: ObsType) -> ObsType: """Apply function to the vector observation.""" return self.func(observations) - @property - def single_observation_space(self) -> Space: - """Returns the single observation space.""" - if self._single_observation_space_error is not None: - raise AttributeError(self._single_observation_space_error) - return self._single_observation_space - class VectorizeTransformObservation(VectorObservationWrapper): """Vectorizes a single-agent transform observation wrapper for vector environments. diff --git a/tests/wrappers/vector/test_transform_observation.py b/tests/wrappers/vector/test_transform_observation.py index 36b3ecc36..c8698850e 100644 --- a/tests/wrappers/vector/test_transform_observation.py +++ b/tests/wrappers/vector/test_transform_observation.py @@ -81,23 +81,17 @@ def test_observation_space_from_single_observation_space( ).all() -def test_error_on_unspecified_single_observation_space( +def test_warning_on_mismatched_single_observation_space( n_envs: int = 5, ): vec_env = SyncVectorEnv([create_env for _ in range(n_envs)]) - vec_env = wrappers.vector.TransformObservation( - vec_env, - func=lambda x: x + 100, - observation_space=spaces.Box( - low=np.array([[0, -10, -5]] * n_envs, dtype=np.float32) + 100, - high=np.array([[10, -5, 10]] * n_envs, dtype=np.float32) + 100, - ), - ) - - # Environment should still work normally - obs, _ = vec_env.reset() - obs, *_ = vec_env.step(vec_env.action_space.sample()) - - # But if we try to access the single_observation_space, it should error - with pytest.raises(AttributeError): - vec_env.single_observation_space + # We only specify observation_space without single_observation_space, so single_observation_space inherits its value from the wrapped env which would not match. This mismatch should give us a warning. + with pytest.warns(Warning): + vec_env = wrappers.vector.TransformObservation( + vec_env, + func=lambda x: x + 100, + observation_space=spaces.Box( + low=np.array([[0, -10, -5]] * n_envs, dtype=np.float32) + 100, + high=np.array([[10, -5, 10]] * n_envs, dtype=np.float32) + 100, + ), + ) From afdbecbc0f925fb47bcde2707ed92ff0802e69ca Mon Sep 17 00:00:00 2001 From: Howard <174055+howardh@users.noreply.github.com> Date: Fri, 10 Jan 2025 14:29:32 -0500 Subject: [PATCH 5/5] Fix --- gymnasium/wrappers/vector/vectorize_action.py | 2 +- .../wrappers/vector/vectorize_observation.py | 2 +- .../wrappers/vector/test_transform_action.py | 22 ++++++++++++++++++- .../vector/test_transform_observation.py | 7 ++++-- 4 files changed, 28 insertions(+), 5 deletions(-) diff --git a/gymnasium/wrappers/vector/vectorize_action.py b/gymnasium/wrappers/vector/vectorize_action.py index a3e951d8e..8fc607107 100644 --- a/gymnasium/wrappers/vector/vectorize_action.py +++ b/gymnasium/wrappers/vector/vectorize_action.py @@ -85,7 +85,7 @@ def __init__( # TODO: We could compute single_action_space from the action_space if only the latter is provided and avoid the warning below. if self.action_space != batch_space(self.single_action_space, self.num_envs): warn( - "The action space and the batched single action space don't match as expected." + f"For {env}, the action space and the batched single action space don't match as expected, action_space={env.action_space}, batched single_action_space={batch_space(self.single_action_space, self.num_envs)}" ) self.func = func diff --git a/gymnasium/wrappers/vector/vectorize_observation.py b/gymnasium/wrappers/vector/vectorize_observation.py index 55fe10bb7..5ca07a1e4 100644 --- a/gymnasium/wrappers/vector/vectorize_observation.py +++ b/gymnasium/wrappers/vector/vectorize_observation.py @@ -84,7 +84,7 @@ def __init__( self.single_observation_space, self.num_envs ): warn( - "The observation space and the batched single observation space don't match as expected." + f"For {env}, the observation space and the batched single observation space don't match as expected, observation_space={env.observation_space}, batched single_observation_space={batch_space(self.single_observation_space, self.num_envs)}" ) self.func = func diff --git a/tests/wrappers/vector/test_transform_action.py b/tests/wrappers/vector/test_transform_action.py index 2d765e4ec..5b0d3987d 100644 --- a/tests/wrappers/vector/test_transform_action.py +++ b/tests/wrappers/vector/test_transform_action.py @@ -1,6 +1,7 @@ """Test suite for vector TransformAction wrapper.""" import numpy as np +import pytest from gymnasium import spaces, wrappers from gymnasium.vector import SyncVectorEnv @@ -16,7 +17,7 @@ def create_env(): ) -def test_observation_space_from_single_observation_space( +def test_action_space_from_single_action_space( n_envs: int = 5, ): vec_env = SyncVectorEnv([create_env for _ in range(n_envs)]) @@ -51,3 +52,22 @@ def test_observation_space_from_single_observation_space( assert ( vec_env.single_action_space.high == np.array([110, 95, 110], dtype=np.float32) ).all() + + +def test_warning_on_mismatched_single_action_space( + n_envs: int = 2, +): + vec_env = SyncVectorEnv([create_env for _ in range(n_envs)]) + # We only specify action_space without single_action_space, so single_action_space inherits its value from the wrapped env which would not match. This mismatch should give us a warning. + with pytest.warns( + Warning, + match=r"the action space and the batched single action space don't match as expected", + ): + vec_env = wrappers.vector.TransformAction( + vec_env, + func=lambda x: x + 100, + action_space=spaces.Box( + low=np.array([[0, -10, -5]] * n_envs, dtype=np.float32) + 100, + high=np.array([[10, -5, 10]] * n_envs, dtype=np.float32) + 100, + ), + ) diff --git a/tests/wrappers/vector/test_transform_observation.py b/tests/wrappers/vector/test_transform_observation.py index c8698850e..98b39b129 100644 --- a/tests/wrappers/vector/test_transform_observation.py +++ b/tests/wrappers/vector/test_transform_observation.py @@ -82,11 +82,14 @@ def test_observation_space_from_single_observation_space( def test_warning_on_mismatched_single_observation_space( - n_envs: int = 5, + n_envs: int = 2, ): vec_env = SyncVectorEnv([create_env for _ in range(n_envs)]) # We only specify observation_space without single_observation_space, so single_observation_space inherits its value from the wrapped env which would not match. This mismatch should give us a warning. - with pytest.warns(Warning): + with pytest.warns( + Warning, + match=r"the observation space and the batched single observation space don't match as expected", + ): vec_env = wrappers.vector.TransformObservation( vec_env, func=lambda x: x + 100,