diff --git a/example.py b/example.py
index b98d7d1..edf4f6d 100644
--- a/example.py
+++ b/example.py
@@ -32,8 +32,6 @@ from mdp_playground.envs import RLToyEnv
 import numpy as np
 
-display_images = True
-
 
 def display_image(obs, mode="RGB"):
     # Display the image observation associated with the next state
     from PIL import Image
@@ -411,6 +409,8 @@ def atari_wrapper_example():
 
     from mdp_playground.envs import GymEnvWrapper
     import gymnasium as gym
+    import ale_py
+    gym.register_envs(ale_py)  # optional, helpful for IDEs or pre-commit
 
     ae = gym.make("QbertNoFrameskip-v4")
     env = GymEnvWrapper(ae, **config)
diff --git a/mdp_playground/envs/gym_env_wrapper.py b/mdp_playground/envs/gym_env_wrapper.py
index d940f1e..36d9735 100644
--- a/mdp_playground/envs/gym_env_wrapper.py
+++ b/mdp_playground/envs/gym_env_wrapper.py
@@ -11,6 +11,11 @@
 from PIL.Image import FLIP_LEFT_RIGHT, FLIP_TOP_BOTTOM
 import logging
 
+# Needed from Gymnasium v1.0.0 onwards
+import ale_py
+gym.register_envs(ale_py)  # optional, helpful for IDEs or pre-commit
+
+
 
 # def get_gym_wrapper(base_class):
diff --git a/mdp_playground/envs/rl_toy_env.py b/mdp_playground/envs/rl_toy_env.py
index 25b23cf..bc0f375 100644
--- a/mdp_playground/envs/rl_toy_env.py
+++ b/mdp_playground/envs/rl_toy_env.py
@@ -24,6 +24,8 @@
 
 
 class RLToyEnv(gym.Env):
+    metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}
+
     """
     The base toy environment in MDP Playground. It is parameterised by a config dict and can be instantiated to be an MDP with any of the possible dimensions from the accompanying research paper. The class extends OpenAI Gym's environment gym.Env.
 
@@ -428,61 +430,65 @@ def __init__(self, **config):
             self.image_representations = False
         else:
             self.image_representations = config["image_representations"]
-            if "image_transforms" in config:
-                assert config["state_space_type"] == "discrete", (
-                    "Image " "transforms are only applicable to discrete envs."
-                )
-                self.image_transforms = config["image_transforms"]
-            else:
-                self.image_transforms = "none"
-            if "image_width" in config:
-                self.image_width = config["image_width"]
-            else:
-                self.image_width = 100
+        # Moved these out of the image_representations block when adding render()
+        # because they are needed for the render() method even if image_representations
+        # is False.
+        if "image_transforms" in config:
+            assert config["state_space_type"] == "discrete", (
+                "Image " "transforms are only applicable to discrete envs."
+            )
+            self.image_transforms = config["image_transforms"]
+        else:
+            self.image_transforms = "none"
 
-            if "image_height" in config:
-                self.image_height = config["image_height"]
-            else:
-                self.image_height = 100
+        if "image_width" in config:
+            self.image_width = config["image_width"]
+        else:
+            self.image_width = 100
 
-            # The following transforms are only applicable in discrete envs:
-            if config["state_space_type"] == "discrete":
-                if "image_sh_quant" not in config:
-                    if "shift" in self.image_transforms:
-                        warnings.warn(
-                            "Setting image shift quantisation to the \
-                            default of 1, since no config value was provided for it."
-                        )
-                        self.image_sh_quant = 1
-                    else:
-                        self.image_sh_quant = None
+        if "image_height" in config:
+            self.image_height = config["image_height"]
+        else:
+            self.image_height = 100
+
+        # The following transforms are only applicable in discrete envs:
+        if config["state_space_type"] == "discrete":
+            if "image_sh_quant" not in config:
+                if "shift" in self.image_transforms:
+                    warnings.warn(
+                        "Setting image shift quantisation to the \
+                        default of 1, since no config value was provided for it."
+                    )
+                    self.image_sh_quant = 1
                 else:
-                    self.image_sh_quant = config["image_sh_quant"]
+                    self.image_sh_quant = None
+            else:
+                self.image_sh_quant = config["image_sh_quant"]
 
-                if "image_ro_quant" not in config:
-                    if "rotate" in self.image_transforms:
-                        warnings.warn(
-                            "Setting image rotate quantisation to the \
-                            default of 1, since no config value was provided for it."
-                        )
-                        self.image_ro_quant = 1
-                    else:
-                        self.image_ro_quant = None
+            if "image_ro_quant" not in config:
+                if "rotate" in self.image_transforms:
+                    warnings.warn(
+                        "Setting image rotate quantisation to the \
+                        default of 1, since no config value was provided for it."
+                    )
+                    self.image_ro_quant = 1
                 else:
-                    self.image_ro_quant = config["image_ro_quant"]
+                    self.image_ro_quant = None
+            else:
+                self.image_ro_quant = config["image_ro_quant"]
 
-                if "image_scale_range" not in config:
-                    if "scale" in self.image_transforms:
-                        warnings.warn(
-                            "Setting image scale range to the default \
-                            of (0.5, 1.5), since no config value was provided for it."
-                        )
-                        self.image_scale_range = (0.5, 1.5)
-                    else:
-                        self.image_scale_range = None
+            if "image_scale_range" not in config:
+                if "scale" in self.image_transforms:
+                    warnings.warn(
+                        "Setting image scale range to the default \
+                        of (0.5, 1.5), since no config value was provided for it."
+                    )
+                    self.image_scale_range = (0.5, 1.5)
                 else:
-                    self.image_scale_range = config["image_scale_range"]
+                    self.image_scale_range = None
+            else:
+                self.image_scale_range = config["image_scale_range"]
 
         # Defaults for the individual environment types:
         if config["state_space_type"] == "discrete":
@@ -827,6 +833,15 @@ def __init__(self, **config):
                 + ", "
                 + str(len(self.augmented_state))
             )
+
+        # Needed for rendering with pygame for use with Gymnasium.Env's render() method:
+        render_mode = config.get("render_mode", None)
+        assert render_mode is None or render_mode in self.metadata["render_modes"]
+        self.render_mode = render_mode
+
+        self.window = None
+        self.clock = None
+
         self.logger.debug(
             "MDP Playground toy env instantiated with config: " + str(self.config)
         )
@@ -1639,7 +1654,8 @@ def transition_function(self, state, action):
                         / factorial_array[j]
                     )
             # print('self.state_derivatives:', self.state_derivatives)
-            next_state = self.state_derivatives[0]
+            # copy to avoid modifying the original state which may be used by external code, e.g. to print the state
+            next_state = self.state_derivatives[0].copy()
 
         else:  # if action is from outside allowed action_space
             next_state = state
@@ -1684,7 +1700,8 @@ def transition_function(self, state, action):
             self.state_derivatives = [
                 zero_state.copy() for i in range(self.dynamics_order + 1)
             ]
-            self.state_derivatives[0] = next_state
+            # copy to avoid modifying the original state which may be used by external code, e.g. to print the state
+            self.state_derivatives[0] = next_state.copy()
 
         if self.config["reward_function"] == "move_to_a_point":
             next_state_rel = np.array(next_state, dtype=self.dtype_s)[
@@ -2126,7 +2143,7 @@ def get_augmented_state(self):
 
         return augmented_state_dict
 
-    def reset(self, seed=None):
+    def reset(self, seed=None, options=None):
         """Resets the environment for the beginning of an episode and samples a start state from rho_0. For discrete environments uses the defined rho_0 directly. For continuous environments, samples a state and resamples until a non-terminal state is sampled.
 
         Returns
@@ -2225,7 +2242,8 @@ def reset(self, seed=None):
                 zero_state.copy() for i in range(self.dynamics_order + 1)
             ]  # #####IMP to have copy()
             # otherwise it's the same array (in memory) at every position in the list
-            self.state_derivatives[0] = self.curr_state
+            # copy to avoid modifying the original state which may be used by external code, e.g. to print the state
+            self.state_derivatives[0] = self.curr_state.copy()
 
             self.augmented_state = [
                 [np.nan] * self.state_space_dim
@@ -2316,6 +2334,82 @@ def seed(self, seed=None):
         )
         return self.seed_
 
+    def render(self):
+        '''
+        Renders the environment using pygame if render_mode is "human" and returns the rendered
+        image if render_mode is "rgb_array".
+
+        Based on https://gymnasium.farama.org/tutorials/gymnasium_basics/environment_creation/
+        '''
+
+        import pygame
+
+        # Init stuff on first call. For non-image_representations based envs, it makes sense
+        # to only instantiate the render_space here and not in __init__ because it's only needed
+        # if render() is called.
+        if self.window is None:
+            if self.image_representations:
+                self.render_space = self.observation_space
+            else:
+                if self.config["state_space_type"] == "discrete":
+                    self.render_space = ImageMultiDiscrete(
+                        self.state_space_size,
+                        width=self.image_width,
+                        height=self.image_height,
+                        transforms=self.image_transforms,
+                        sh_quant=self.image_sh_quant,
+                        scale_range=self.image_scale_range,
+                        ro_quant=self.image_ro_quant,
+                        circle_radius=20,
+                        seed=self.seed_dict["image_representations"],
+                    )  # #seed
+                elif self.config["state_space_type"] == "continuous":
+                    self.render_space = ImageContinuous(
+                        self.feature_space,
+                        width=self.image_width,
+                        height=self.image_height,
+                        term_spaces=self.term_spaces,
+                        target_point=self.target_point,
+                        circle_radius=5,
+                        seed=self.seed_dict["image_representations"],
+                    )  # #seed
+                elif self.config["state_space_type"] == "grid":
+                    target_pt = list_to_float_np_array(self.target_point)
+                    self.render_space = ImageContinuous(
+                        self.feature_space,
+                        width=self.image_width,
+                        height=self.image_height,
+                        term_spaces=self.term_spaces,
+                        target_point=target_pt,
+                        circle_radius=5,
+                        grid_shape=self.grid_shape,
+                        seed=self.seed_dict["image_representations"],
+                    )  # #seed
+
+
+        if self.window is None and self.render_mode == "human":
+            pygame.init()
+            pygame.display.init()
+            self.window = pygame.display.set_mode(
+                (self.image_width, self.image_height)
+            )
+        if self.clock is None and self.render_mode == "human":
+            self.clock = pygame.time.Clock()
+
+        # ##TODO There are repeated calculations here in calling get_concatenated_image
+        # that can be taken from storing variables in step() or reset().
+        if self.render_mode == "human":
+            rgb_array = self.render_space.get_concatenated_image(self.curr_state)
+            pygame_surface = pygame.surfarray.make_surface(rgb_array)
+            self.window.blit(pygame_surface, pygame_surface.get_rect())
+            pygame.event.pump()
+            pygame.display.update()
+
+            # We need to ensure that human-rendering occurs at the predefined framerate.
+            # The following line will automatically add a delay to keep the framerate stable.
+            self.clock.tick(self.metadata["render_fps"])
+        elif self.render_mode == "rgb_array":
+            return self.render_space.get_concatenated_image(self.curr_state)
 
 
 def dist_of_pt_from_line(pt, ptA, ptB):
     """Returns shortest distance of a point from a line defined by 2 points - ptA and ptB.
diff --git a/mdp_playground/spaces/image_continuous.py b/mdp_playground/spaces/image_continuous.py
index 52e96cf..70740ab 100644
--- a/mdp_playground/spaces/image_continuous.py
+++ b/mdp_playground/spaces/image_continuous.py
@@ -212,7 +212,9 @@ def get_concatenated_image(self, obs):
         # image to have >=3 dims
 
     def convert_to_pixel(self, position):
-        """ """
+        """
+        Convert a continuous position to a pixel position in the image.
+        """
         # It's implicit that both relevant and irrelevant sub-spaces have the
         # same max and min here:
         max = self.feature_space.high[self.relevant_indices]
diff --git a/tests/test_gym_env_wrapper.py b/tests/test_gym_env_wrapper.py
index 1159f11..a55ebff 100644
--- a/tests/test_gym_env_wrapper.py
+++ b/tests/test_gym_env_wrapper.py
@@ -44,7 +44,7 @@ def test_r_delay(self):
 
         ae = gym.make("BeamRiderNoFrameskip-v4")
         aew = GymEnvWrapper(ae, **config)
-        ob = aew.reset()
+        ob, _ = aew.reset()
         print("observation_space.shape:", ob.shape)
         # print(ob)
         total_reward = 0.0
@@ -83,7 +83,7 @@ def test_r_shift(self):
 
         ae = gym.make("BeamRiderNoFrameskip-v4")
         aew = GymEnvWrapper(ae, **config)
-        ob = aew.reset()
+        ob, _ = aew.reset()
         print("observation_space.shape:", ob.shape)
         # print(ob)
         total_reward = 0.0
@@ -123,7 +123,7 @@ def test_r_scale(self):
 
         ae = gym.make("BeamRiderNoFrameskip-v4")
         aew = GymEnvWrapper(ae, **config)
-        ob = aew.reset()
+        ob, _ = aew.reset()
         print("observation_space.shape:", ob.shape)
         # print(ob)
         total_reward = 0.0
@@ -164,7 +164,7 @@ def test_r_scale(self):
 
    #     ae = gym.make("BeamRiderNoFrameskip-v4")
    #     aew = GymEnvWrapper(ae, **config)
-    #     ob = aew.reset()
+    #     ob, _ = aew.reset()
    #     print("observation_space.shape:", ob.shape)
    #     # print(ob)
    #     total_reward = 0.0
@@ -211,7 +211,7 @@ def test_r_scale(self):
    #     game = "".join([g.capitalize() for g in game.split("_")])
    #     ae = gym.make("{}NoFrameskip-v4".format(game))
    #     aew = GymEnvWrapper(ae, **config)
-    #     ob = aew.reset()
+    #     ob, _ = aew.reset()
    #     print("observation_space.shape:", ob.shape)
    #     # print(ob)
    #     total_reward = 0.0
@@ -253,7 +253,7 @@ def test_r_delay_p_noise_r_noise(self):
 
         ae = gym.make("BeamRiderNoFrameskip-v4")
         aew = GymEnvWrapper(ae, **config)
-        ob = aew.reset()
+        ob, _ = aew.reset()
         print("observation_space.shape:", ob.shape)
         # print(ob)
         total_reward = 0.0
@@ -316,7 +316,7 @@ def test_discrete_irr_features(self):
 
         ae = gym.make("BeamRiderNoFrameskip-v4")
         aew = GymEnvWrapper(ae, **config)
-        ob = aew.reset()
+        ob, _ = aew.reset()
         print("type(observation_space):", type(ob))
         # print(ob)
         total_reward = 0.0
@@ -364,7 +364,7 @@ def test_image_transforms(self):
 
         ae = gym.make("BeamRiderNoFrameskip-v4")
         aew = GymEnvWrapper(ae, **config)
-        ob = aew.reset()
+        ob, _ = aew.reset()
         print("observation_space.shape:", ob.shape)
         assert ob.shape == (100, 100, 3), "Observation shape of the env was unexpected."
         # print(ob)
@@ -420,7 +420,7 @@ def test_cont_irr_features(self):
 
         # register_env("HalfCheetahWrapper-v3", lambda config: HalfCheetahWrapperV3(**config))
         hc3w = GymEnvWrapper(hc3, **config)
-        ob = hc3w.reset()
+        ob, _ = hc3w.reset()
         print("obs shape, type(observation_space):", ob.shape, type(ob))
         print("initial obs: ", ob)
         assert (
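
For reference, the changes above move MDP Playground onto the standard Gymnasium loop: reset(seed=None, options=None) with (obs, info) unpacking in the tests, a render_mode key read from the config dict via config.get("render_mode", None), and a pygame-backed render() that draws a window for "human" and returns an image array for "rgb_array". The driver below is a minimal sketch, not part of the diff: it assumes the env passed in follows Gymnasium's (obs, info) reset and five-tuple step conventions (the updated tests confirm the reset convention for GymEnvWrapper; the rest is the standard Gymnasium contract), and the construction of the env itself (RLToyEnv config keys, wrapper config, render_mode value) is left to the caller, mirroring example.py.

import gymnasium as gym


def rollout(env: gym.Env, num_steps: int = 100, seed: int = 0):
    """Drive an env through the Gymnasium API exercised above (a sketch, not library code)."""
    # reset() now accepts seed/options keyword arguments and returns (obs, info).
    obs, info = env.reset(seed=seed)
    total_reward = 0.0
    for _ in range(num_steps):
        action = env.action_space.sample()
        obs, reward, terminated, truncated, info = env.step(action)
        total_reward += float(reward)
        # With render_mode="human" this draws to the pygame window at metadata["render_fps"];
        # with render_mode="rgb_array" it returns an image array instead.
        frame = env.render()
        if terminated or truncated:
            obs, info = env.reset()
    env.close()
    return total_reward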