From 5d7de3ecbb7599aa9d5500d3273cd15d96e81a56 Mon Sep 17 00:00:00 2001 From: Alexis DUBURCQ Date: Sat, 18 Jan 2025 10:39:02 +0000 Subject: [PATCH] Improve 'set_wrapper_attr'. --- gymnasium/core.py | 38 +++++++++++++++++++------------------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/gymnasium/core.py b/gymnasium/core.py index aaf9476f7..c0c5ee68c 100644 --- a/gymnasium/core.py +++ b/gymnasium/core.py @@ -273,9 +273,12 @@ def get_wrapper_attr(self, name: str) -> Any: """Gets the attribute `name` from the environment.""" return getattr(self, name) - def set_wrapper_attr(self, name: str, value: Any): + def set_wrapper_attr(self, name: str, value: Any, *, force: bool = False): """Sets the attribute `name` on the environment with `value`.""" - setattr(self, name, value) + if force or hasattr(self, name): + setattr(self, name, value) + else: + raise AttributeError(f"{self} has no attribute {name!r}") WrapperObsType = TypeVar("WrapperObsType") @@ -425,30 +428,27 @@ def get_wrapper_attr(self, name: str) -> Any: f"wrapper {self.class_name()} has no attribute {name!r}" ) from e - def set_wrapper_attr(self, name: str, value: Any): + def set_wrapper_attr(self, name: str, value: Any, *, force: bool = True): """Sets an attribute on this wrapper or lower environment if `name` is already defined. Args: name: The variable name value: The new variable value + force: Whether to create the attribute on this wrapper if it does not exists on the + lower environment instead of raising an exception """ - sub_env = self - - # loop through all the wrappers, checking if it has the variable name then setting it - # otherwise stripping the wrapper to check the next. - # end when the core env is reached - while isinstance(sub_env, Wrapper): - if hasattr(sub_env, name): - setattr(sub_env, name, value) - return - - sub_env = sub_env.env - - # check if the base environment has the wrapper, otherwise, we set it on the top (this) wrapper - if hasattr(sub_env, name): - setattr(sub_env, name, value) - else: + if hasattr(self, name): setattr(self, name, value) + else: + try: + self.env.set_wrapper_attr(name, value, force=False) + except AttributeError as e: + if force: + setattr(self, name, value) + else: + raise AttributeError( + f"wrapper {self.class_name()} has no attribute {name!r}" + ) from e def __str__(self): """Returns the wrapper name and the :attr:`env` representation string."""