Skip to content

Commit

Permalink
Merge pull request #44 from Farama-Foundation/feature/ezpickle
Browse files Browse the repository at this point in the history
EzPickle for Environments
  • Loading branch information
ffelten authored Apr 30, 2024
2 parents 9d264c5 + fececf8 commit f38e4f3
Show file tree
Hide file tree
Showing 17 changed files with 178 additions and 23 deletions.
16 changes: 13 additions & 3 deletions momaland/envs/beach/beach.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,12 @@

import functools
import random

# from gymnasium.utils import EzPickle
from typing_extensions import override

import numpy as np
from gymnasium.logger import warn
from gymnasium.spaces import Box, Discrete
from gymnasium.utils import EzPickle
from pettingzoo.utils import wrappers

from momaland.utils.conversions import mo_parallel_to_aec
Expand Down Expand Up @@ -52,7 +51,7 @@ def raw_env(**kwargs):
return MOBeachDomain(**kwargs)


class MOBeachDomain(MOParallelEnv):
class MOBeachDomain(MOParallelEnv, EzPickle):
"""A `Parallel` 2-objective environment of the Beach problem domain.
## Observation Space
Expand Down Expand Up @@ -124,6 +123,17 @@ def __init__(
render_mode: render mode
reward_scheme: the reward scheme to use ('local', or 'global'). Default: local
"""
EzPickle.__init__(
self,
num_timesteps,
num_agents,
reward_scheme,
sections,
capacity,
type_distribution,
position_distribution,
render_mode,
)
self.reward_scheme = reward_scheme
self.sections = sections
# TODO Extend to distinct capacities per section?
Expand Down
10 changes: 9 additions & 1 deletion momaland/envs/breakthrough/breakthrough.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import numpy as np
from gymnasium import spaces
from gymnasium.logger import warn
from gymnasium.utils import EzPickle
from pettingzoo.utils import agent_selector, wrappers

from momaland.utils.env import MOAECEnv
Expand Down Expand Up @@ -54,7 +55,7 @@ def raw_env(**kwargs):
return MOBreakthrough(**kwargs)


class MOBreakthrough(MOAECEnv):
class MOBreakthrough(MOAECEnv, EzPickle):
"""Multi-objective Breakthrough.
MO-Breakthrough is a multi-objective variant of the two-player, single-objective turn-based board game Breakthrough.
Expand Down Expand Up @@ -125,6 +126,13 @@ def __init__(self, board_width: int = 8, board_height: int = 8, num_objectives:
num_objectives: The number of objectives (from 1 to 4)
render_mode: The render mode.
"""
EzPickle.__init__(
self,
board_width,
board_height,
num_objectives,
render_mode,
)
if not (3 <= board_width <= 20):
raise ValueError("Config parameter board_width must be between 3 and 20.")

Expand Down
12 changes: 11 additions & 1 deletion momaland/envs/congestion/congestion.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import numpy as np
from gymnasium.logger import warn
from gymnasium.spaces import Box, Discrete
from gymnasium.utils import EzPickle
from pettingzoo.utils import wrappers
from sympy import diff, lambdify, sympify

Expand Down Expand Up @@ -49,7 +50,7 @@ def raw_env(**kwargs):
return MOCongestion(**kwargs)


class MOCongestion(MOParallelEnv):
class MOCongestion(MOParallelEnv, EzPickle):
"""A `Parallel` environment where drivers learn to travel from a source to a destination while avoiding congestion.
Multi-objective version of Braess' Paradox where drivers have two objectives: travel time and monetary cost.
Expand Down Expand Up @@ -108,6 +109,15 @@ def __init__(
num_timesteps: number of timesteps (stateless, therefore always 1 timestep)
render_mode: render mode
"""
EzPickle.__init__(
self,
problem_name,
num_agents,
toll_mode,
random_toll_percentage,
num_timesteps,
render_mode,
)
# Read in the problem from the corresponding .json file in the networks directory
self.graph, self.od, self.routes, self._max_route_length = self._read_problem(problem_name)
# Keep track of the current flow on each link the network
Expand Down
9 changes: 8 additions & 1 deletion momaland/envs/connect4/connect4.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,14 @@ def __init__(
board_height: The height of the board (from 4 to 20)
column_objectives: Whether to use column objectives or not (without them, there are 2 objectives. With them, there are 2+board_width objectives)
"""
EzPickle.__init__(self, render_mode, screen_scaling)
EzPickle.__init__(
self,
render_mode,
screen_scaling,
board_width,
board_height,
column_objectives,
)
self.env = super().__init__()

if not (4 <= board_width <= 20):
Expand Down
5 changes: 3 additions & 2 deletions momaland/envs/crazyrl/catch/catch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing_extensions import override

import numpy as np
from gymnasium.utils import EzPickle
from pettingzoo.utils.wrappers import AssertOutOfBoundsWrapper

from momaland.envs.crazyrl.crazyRL_base import FPS, CrazyRLBaseParallelEnv
Expand Down Expand Up @@ -49,7 +50,7 @@ def raw_env(*args, **kwargs):
return Catch(*args, **kwargs)


class Catch(CrazyRLBaseParallelEnv):
class Catch(CrazyRLBaseParallelEnv, EzPickle):
"""A `Parallel` environment where drones learn how to surround a moving target trying to escape.
## Observation Space
Expand Down Expand Up @@ -108,8 +109,8 @@ def __init__(self, *args, target_speed=0.1, **kwargs):
init_target_location (nparray, optional): Array of the initial position of the moving target
target_speed (float, optional): Distance traveled by the target at each timestep
"""

super().__init__(*args, **kwargs)
EzPickle.__init__(self, *args, target_speed=0.1, **kwargs)
self.target_speed = target_speed

def _move_target(self):
Expand Down
9 changes: 6 additions & 3 deletions momaland/envs/crazyrl/escort/escort.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing_extensions import override

import numpy as np
from gymnasium.utils import EzPickle
from pettingzoo.utils.wrappers import AssertOutOfBoundsWrapper

from momaland.envs.crazyrl.crazyRL_base import FPS, CrazyRLBaseParallelEnv
Expand Down Expand Up @@ -49,7 +50,7 @@ def raw_env(*args, **kwargs):
return Escort(*args, **kwargs)


class Escort(CrazyRLBaseParallelEnv):
class Escort(CrazyRLBaseParallelEnv, EzPickle):
"""A `Parallel` environment where drones learn how to escort a moving target.
## Observation Space
Expand Down Expand Up @@ -110,9 +111,11 @@ def __init__(self, *args, num_intermediate_points: int = 50, final_target_locati
final_target_location (nparray[float], optional): Array of the final position of the moving target
num_intermediate_points (int, optional): Number of intermediate points in the target trajectory
"""
self.final_target_location = final_target_location

EzPickle.__init__(
self, *args, num_intermediate_points=num_intermediate_points, final_target_location=final_target_location, **kwargs
)
super().__init__(*args, **kwargs)
self.final_target_location = final_target_location

# There are two more ref points than intermediate points, one for the initial and final target locations
self.num_ref_points = num_intermediate_points + 2
Expand Down
8 changes: 7 additions & 1 deletion momaland/envs/crazyrl/surround/surround.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing_extensions import override

import numpy as np
from gymnasium.utils import EzPickle
from pettingzoo.utils.wrappers import AssertOutOfBoundsWrapper

from momaland.envs.crazyrl.crazyRL_base import FPS, CrazyRLBaseParallelEnv
Expand Down Expand Up @@ -48,7 +49,7 @@ def raw_env(*args, **kwargs):
return Surround(*args, **kwargs)


class Surround(CrazyRLBaseParallelEnv):
class Surround(CrazyRLBaseParallelEnv, EzPickle):
"""A `Parallel` environment where drones learn how to surround a static target point.
## Observation Space
Expand Down Expand Up @@ -99,6 +100,11 @@ class Surround(CrazyRLBaseParallelEnv):
@override
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
EzPickle.__init__(
self,
*args,
**kwargs,
)

@override
def _transition_state(self, actions):
Expand Down
20 changes: 19 additions & 1 deletion momaland/envs/gem_mining/gem_mining.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import numpy as np
from gymnasium.logger import warn
from gymnasium.spaces import Box, Discrete
from gymnasium.utils import EzPickle
from pettingzoo.utils import wrappers

from momaland.utils.conversions import mo_parallel_to_aec
Expand Down Expand Up @@ -49,7 +50,7 @@ def raw_env(**kwargs):
return MOGemMining(**kwargs)


class MOGemMining(MOParallelEnv):
class MOGemMining(MOParallelEnv, EzPickle):
"""Environment for MO-GemMining domain.
## Observation Space
Expand Down Expand Up @@ -134,6 +135,23 @@ def __init__(
render_mode: render mode
seed: This environment is generated randomly using the provided seed. Defaults to 42.
"""
EzPickle.__init__(
self,
num_agents,
num_objectives,
min_connectivity,
max_connectivity,
min_workers,
max_workers,
min_prob,
max_prob,
trunc_probability,
w_bonus,
correlated_objectives,
num_timesteps,
render_mode,
seed,
)
self.num_timesteps = num_timesteps
self.episode_num = 0
self.render_mode = render_mode
Expand Down
15 changes: 12 additions & 3 deletions momaland/envs/ingenious/ingenious.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,12 @@

import functools
import random

# from gymnasium.utils import EzPickle
from typing_extensions import override

import numpy as np
from gymnasium.logger import warn
from gymnasium.spaces import Box, Dict, Discrete
from gymnasium.utils import EzPickle
from pettingzoo.utils import wrappers

from momaland.envs.ingenious.ingenious_base import ALL_COLORS, IngeniousBase
Expand Down Expand Up @@ -99,7 +98,7 @@ def raw_env(**kwargs):
return Ingenious(**kwargs)


class Ingenious(MOAECEnv):
class Ingenious(MOAECEnv, EzPickle):
"""Environment for the Ingenious board game."""

metadata = {"render_modes": ["human"], "name": "moingenious_v0", "is_parallelizable": False}
Expand Down Expand Up @@ -128,6 +127,16 @@ def __init__(
fully_obs (bool): Fully observable game mode, i.e. the racks of all players are visible. Default is False.
render_mode (str): The rendering mode. Default: None
"""
EzPickle.__init__(
self,
num_agents,
rack_size,
num_colors,
board_size,
reward_mode,
fully_obs,
render_mode,
)
self.num_colors = num_colors
self.init_draw = rack_size
self.max_score = 18 # max score in score board for one certain color.
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
20 changes: 14 additions & 6 deletions momaland/envs/item_gathering/item_gathering.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,16 @@
import random
from copy import deepcopy
from os import path

# from gymnasium.utils import EzPickle
from typing_extensions import override

import numpy as np
import pygame
from gymnasium.logger import warn
from gymnasium.spaces import Box, Discrete, Tuple
from gymnasium.utils import EzPickle
from pettingzoo.utils import wrappers

from momaland.envs.item_gathering.asset_utils import del_colored, get_colored
from momaland.envs.item_gathering.asset_utils import get_colored
from momaland.envs.item_gathering.map_utils import DEFAULT_MAP, randomise_map
from momaland.utils.conversions import mo_parallel_to_aec
from momaland.utils.env import MOParallelEnv
Expand Down Expand Up @@ -72,7 +71,7 @@ def raw_env(**kwargs):
return MOItemGathering(**kwargs)


class MOItemGathering(MOParallelEnv):
class MOItemGathering(MOParallelEnv, EzPickle):
"""A `Parallel` multi-objective environment of the Item Gathering problem.
## Observation Space
Expand Down Expand Up @@ -111,6 +110,7 @@ class MOItemGathering(MOParallelEnv):
"name": "moitem_gathering_v0",
"is_parallelizable": True,
"central_observation": True,
"render_fps": 30,
}

def __init__(
Expand All @@ -128,6 +128,13 @@ def __init__(
randomise: whether to randomise the map, at each episode
render_mode: render mode for the environment
"""
EzPickle.__init__(
self,
num_timesteps,
initial_map,
randomise,
render_mode,
)
self.num_timesteps = num_timesteps
self.current_timestep = 0
self.render_mode = render_mode
Expand Down Expand Up @@ -297,9 +304,10 @@ def render(self):

@override
def close(self):
if self.render_mode is not None:
del_colored()
pass
# This breaks the pickle tests
# if self.render_mode is not None:
# del_colored()

@override
def reset(self, seed=None, options=None):
Expand Down
13 changes: 12 additions & 1 deletion momaland/envs/samegame/same_game.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import numpy as np
from gymnasium import spaces
from gymnasium.logger import warn
from gymnasium.utils import EzPickle
from pettingzoo.utils import agent_selector, wrappers

from momaland.utils.env import MOAECEnv
Expand Down Expand Up @@ -70,7 +71,7 @@ def raw_env(**kwargs):
return MOSameGame(**kwargs)


class MOSameGame(MOAECEnv):
class MOSameGame(MOAECEnv, EzPickle):
"""Multi-objective Multi-agent SameGame.
MO-SameGame is a multi-objective, multi-agent variant of the single-player, single-objective turn-based puzzle
Expand Down Expand Up @@ -152,6 +153,16 @@ def __init__(
color_rewards: True = agents get separate rewards for each color, False = agents get a single reward accumulating all colors
render_mode: The render mode
"""
EzPickle.__init__(
self,
board_width,
board_height,
num_colors,
num_agents,
team_rewards,
color_rewards,
render_mode,
)
self.env = super().__init__()

self.rng = np.random.default_rng()
Expand Down
Binary file added momaland/videos/walker_mid.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added momaland/videos/walker_stable.gif
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit f38e4f3

Please sign in to comment.