Skip to content

Commit

Permalink
updating wrappers to have checkpointing capabiliites
Browse files Browse the repository at this point in the history
  • Loading branch information
reginald-mclean committed Nov 8, 2024
1 parent 3f1432c commit 69d3515
Showing 1 changed file with 106 additions and 0 deletions.
106 changes: 106 additions & 0 deletions metaworld/wrappers.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from __future__ import annotations

import base64

import gymnasium as gym
import numpy as np
from gymnasium import Env
from numpy.typing import NDArray

from metaworld import SawyerXYZEnv
from metaworld.types import Task


Expand Down Expand Up @@ -62,6 +65,25 @@ def sample_tasks(self, *, seed: int | None = None, options: dict | None = None):
self._set_random_task()
return self.env.reset(seed=seed, options=options)

def get_checkpoint(self) -> dict:
return {
"tasks": self.tasks,
"current_task_idx": self.current_task_idx,
"sample_tasks_on_reset": self.sample_tasks_on_reset,
"env_rng_state": get_env_rng_checkpoint(self.unwrapped),
}

def load_checkpoint(self, ckpt: dict):
assert "tasks" in ckpt
assert "current_task_idx" in ckpt
assert "sample_tasks_on_reset" in ckpt
assert "env_rng_state" in ckpt

self.tasks = ckpt["tasks"]
self.current_task_idx = ckpt["current_task_idx"]
self.sample_tasks_on_reset = ckpt["sample_tasks_on_reset"]
set_env_rng(self.unwrapped, ckpt["env_rng_state"])


class PseudoRandomTaskSelectWrapper(gym.Wrapper):
"""A Gymnasium Wrapper to automatically reset the environment to a *pseudo*random task when explicitly called.
Expand Down Expand Up @@ -105,6 +127,25 @@ def sample_tasks(self, *, seed: int | None = None, options: dict | None = None):
self._set_pseudo_random_task()
return self.env.reset(seed=seed, options=options)

def get_checkpoint(self) -> dict:
return {
"tasks": self.tasks,
"current_task_idx": self.current_task_idx,
"sample_tasks_on_reset": self.sample_tasks_on_reset,
"env_rng_state": get_env_rng_checkpoint(self.unwrapped),
}

def load_checkpoint(self, ckpt: dict):
assert "tasks" in ckpt
assert "current_task_idx" in ckpt
assert "sample_tasks_on_reset" in ckpt
assert "env_rng_state" in ckpt

self.tasks = ckpt["tasks"]
self.current_task_idx = ckpt["current_task_idx"]
self.sample_tasks_on_reset = ckpt["sample_tasks_on_reset"]
set_env_rng(self.unwrapped, ckpt["env_rng_state"])


class AutoTerminateOnSuccessWrapper(gym.Wrapper):
"""A Gymnasium Wrapper to automatically output a termination signal when the environment's task is solved.
Expand All @@ -130,3 +171,68 @@ def step(self, action):
if self.terminate_on_success:
terminated = info["success"] == 1.0
return obs, reward, terminated, truncated, info


def get_env_rng_checkpoint(env: SawyerXYZEnv) -> dict[str, dict]:
return { # pyright: ignore [reportReturnType]
"np_random_state": env.np_random.__getstate__(),
"action_space_rng_state": env.action_space.np_random.__getstate__(),
"obs_space_rng_state": env.observation_space.np_random.__getstate__(),
"goal_space_rng_state": env.goal_space.np_random.__getstate__(), # type: ignore
}


def set_env_rng(env: SawyerXYZEnv, state: dict[str, dict]) -> None:
assert "np_random_state" in state
assert "action_space_rng_state" in state
assert "obs_space_rng_state" in state
assert "goal_space_rng_state" in state

env.np_random.__setstate__(state["np_random_state"])
env.action_space.np_random.__setstate__(state["action_space_rng_state"])
env.observation_space.np_random.__setstate__(state["obs_space_rng_state"])
env.goal_space.np_random.__setstate__(state["goal_space_rng_state"]) # type: ignore


class CheckpointWrapper(gym.Wrapper):
env_id: str

def __init__(self, env: gym.Env, env_id: str):
super().__init__(env)
assert hasattr(self.env, "get_checkpoint") and callable(self.env.get_checkpoint)
assert hasattr(self.env, "load_checkpoint") and callable(
self.env.load_checkpoint
)
self.env_id = env_id

def get_checkpoint(self) -> tuple[str, dict]:
ckpt: dict = self.env.get_checkpoint()
return (self.env_id, ckpt)

def load_checkpoint(self, ckpts: list[tuple[str, dict]]) -> None:
my_ckpt = None
for env_id, ckpt in ckpts:
if env_id == self.env_id:
my_ckpt = ckpt
break
if my_ckpt is None:
raise ValueError(
f"Could not load checkpoint, no checkpoint found with id {self.env_id}. Checkpoint IDs: ",
[env_id for env_id, _ in ckpts],
)
self.env.load_checkpoint(my_ckpt)


def _serialize_task(task: Task) -> dict:
return {
"env_name": task.env_name,
"data": base64.b64encode(task.data).decode("ascii"),
}


def _deserialize_task(task_dict: dict[str, str]) -> Task:
assert "env_name" in task_dict and "data" in task_dict

return Task(
env_name=task_dict["env_name"], data=base64.b64decode(task_dict["data"])
)

0 comments on commit 69d3515

Please sign in to comment.