Skip to content

Commit

Permalink
[WIP] Define framework
Browse files Browse the repository at this point in the history
  • Loading branch information
dandansamax committed Nov 13, 2024
1 parent 6d1be33 commit 4ae656b
Show file tree
Hide file tree
Showing 9 changed files with 704 additions and 0 deletions.
32 changes: 32 additions & 0 deletions refactor_demo/core/benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
from abc import ABC, abstractmethod
from typing import Iterable

from .environment import Environment
from .evaluator import Evaluator


class Benchmark(ABC):
    """Abstract interface that every benchmark implementation has to provide.

    All three methods are abstract, so a subclass must override each of them
    before it can be instantiated.
    """

    @abstractmethod
    def get_corresponding_env(self) -> Environment:
        """Return the `Environment` that belongs to this benchmark."""

    @abstractmethod
    def get_task_by_id(self, id: str) -> str:
        """Return the task string stored under the identifier ``id``.

        NOTE(review): presumably the returned string is the task description;
        confirm against a concrete implementation.
        """

    @abstractmethod
    def tasks(self) -> Iterable[tuple[str, Evaluator]]:
        """Return an iterable of ``(str, Evaluator)`` pairs.

        NOTE(review): presumably one pair per task, with the string being the
        task text; confirm against a concrete implementation.
        """
61 changes: 61 additions & 0 deletions refactor_demo/core/environment.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
from abc import ABC, abstractmethod
from typing import Any

import gymnasium as gym
from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam


class Environment(gym.Env, ABC):
    """Base class for environments that agents interact with in CRAB.

    A CRAB environment is a `gymnasium.Env` subclass that can also describe
    its action space in the OpenAI tool-calling format. Concrete environments
    must implement `get_description`, `get_action_schema` and
    `convert_tool_call_to_action`.
    """

    @abstractmethod
    def get_description(self) -> str:
        """Describe the environment in text usable as part of an agent prompt.

        Returns:
            A string description of the environment.
        """

    @abstractmethod
    def get_action_schema(self) -> list[ChatCompletionToolParam]:
        """Return the tool schema covering the environment's action space.

        The schema describes every possible action and its parameters in the
        tool-calling format, so that it can be passed directly to the OpenAI
        API. It should be comprehensive and unambiguous to a human reader.

        Returns:
            A list of tool schemas.
        """

    @abstractmethod
    def convert_tool_call_to_action(self, tool_name: str, parameters: dict) -> Any:
        """Translate a tool call into an action of the environment.

        Args:
            tool_name: The name of the tool.
            parameters: The parameters of the tool call.

        Returns:
            The action, expressed in the environment's actual action space.
        """
20 changes: 20 additions & 0 deletions refactor_demo/core/evaluator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
from abc import ABC, abstractmethod


class Evaluator(ABC):
    """Abstract base class for objects that evaluate a task in an environment."""

    @abstractmethod
    def step(self, environment: "Environment", task: "Task") -> "Any":
        """Run one evaluation step.

        The annotations are string forward references on purpose: this module
        imports none of `Environment`, `Task` or `Any`, so evaluating them
        eagerly raises `NameError` as soon as the class body is executed.
        Strings also avoid a circular import between the task and evaluator
        modules.

        Args:
            environment: The environment the evaluation is performed against.
            task: The task being evaluated.

        Returns:
            The evaluation result; its type is defined by the implementation.
        """
13 changes: 13 additions & 0 deletions refactor_demo/core/policy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
27 changes: 27 additions & 0 deletions refactor_demo/core/task.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
from abc import ABC, abstractmethod
from pydantic import BaseModel, ConfigDict
from typing import Any

from .environment import Environment


class Task(BaseModel):
    # Pydantic model bundling a task's identifier, description, evaluator and
    # the optional `setup` / `extra_action` entries.

    # `arbitrary_types_allowed` lets pydantic accept field types that are not
    # pydantic-native (presumably needed for `Evaluator`; confirm).
    model_config = ConfigDict(arbitrary_types_allowed=True)
    id: str
    description: str
    # NOTE(review): `Evaluator` is not imported in the visible part of this
    # module; as written this annotation raises NameError at import time.
    evaluator: Evaluator
    # NOTE(review): the annotation `setup` is the field's own name and is not
    # defined anywhere in view; the intended type needs to be confirmed.
    # The `[]` default is copied per instance by pydantic, not shared.
    setup: setup = []
    # NOTE(review): `Action` is not imported or defined in the visible part of
    # this module; confirm where it is supposed to come from.
    extra_action: list[Action] = []
76 changes: 76 additions & 0 deletions refactor_demo/core/task_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
from typing import Any

import gymnasium as gym
from gymnasium import Wrapper
from gymnasium.core import ActType, ObsType, WrapperObsType
from gymnasium.spaces import Dict, Space, Text, Tuple


class TaskWrapper(Wrapper[WrapperObsType, ActType, ObsType, ActType]):
    """Wrap an environment so that every observation carries a task description.

    The task description is merged into the wrapped environment's observation
    and the reward of each step is replaced by the task evaluator's result.

    How the description is attached depends on the wrapped observation space:

    * `Dict`: it is added under ``dict_task_key``.
    * `Tuple`: it is appended as the last element.
    * anything else: the observation becomes ``{"obs": ..., "task": ...}``.
    """

    def __init__(
        self,
        env: gym.Env[ObsType, ActType],
        task: "Task",
        *,
        dict_task_key: str = "task",
    ):
        """Build the combined observation space for ``env`` and ``task``.

        Args:
            env: The environment to wrap.
            task: The task whose description is attached to observations and
                whose evaluator produces the reward. The annotation is a
                string because `Task` is not imported in this module; a bare
                name would raise `NameError` when the class is defined.
            dict_task_key: Key used for the description when the wrapped
                observation space is a `Dict`.

        Raises:
            AssertionError: If ``dict_task_key`` already exists in a `Dict`
                observation space of ``env``.
        """
        # `Wrapper.__init__` stores the environment as `self.env`.
        super().__init__(env)
        self.task = task

        # NOTE(review): `Text(500)` caps the description at 500 characters and
        # gymnasium's `Text` defaults to an alphanumeric charset, so typical
        # descriptions may fall outside this space; confirm the intent.
        task_space = Text(500)

        # Pick the observation space and the matching merge function.
        if isinstance(env.observation_space, Dict):
            assert (
                dict_task_key not in env.observation_space.keys()
            ), f"Observation space already contains the key {dict_task_key!r}."
            observation_space = Dict(
                {dict_task_key: task_space, **env.observation_space.spaces}
            )
            self._append_data_func = lambda obs, task: {dict_task_key: task, **obs}
        elif isinstance(env.observation_space, Tuple):
            observation_space = Tuple(env.observation_space.spaces + (task_space,))
            self._append_data_func = lambda obs, task: obs + (task,)
        else:
            observation_space = Dict(obs=env.observation_space, task=task_space)
            self._append_data_func = lambda obs, task: {"obs": obs, "task": task}

        self.observation_space: gym.Space[WrapperObsType] = observation_space

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[WrapperObsType, dict[str, Any]]:
        """Reset the wrapped environment and attach the task description.

        Args:
            seed: Forwarded to the wrapped environment's ``reset``.
            options: Forwarded to the wrapped environment's ``reset``.

        Returns:
            The modified initial observation and the unchanged info dict.
        """
        obs, info = self.env.reset(seed=seed, options=options)
        return self.observation(obs), info

    def step(
        self, action: ActType
    ) -> tuple[WrapperObsType, float, bool, bool, dict[str, Any]]:
        """Step the wrapped environment and score the result with the task.

        Args:
            action: The action forwarded to the wrapped environment.

        Returns:
            The modified observation, the evaluator's result as reward, and
            the wrapped environment's terminated flag, truncated flag and
            info dict.
        """
        # Delegate to the wrapped environment; calling `self.step` here would
        # recurse without end.
        observation, _, terminated, truncated, info = self.env.step(action)
        # The environment's own reward is discarded in favour of the task's
        # evaluator, which is called as `Evaluator.step(environment, task)`.
        reward = self.task.evaluator.step(self.env, self.task)
        return self.observation(observation), reward, terminated, truncated, info

    def observation(self, observation: ObsType):
        """Return ``observation`` with the task description merged in.

        Args:
            observation: An observation of the wrapped environment.

        Returns:
            The modified observation.
        """
        return self._append_data_func(observation, self.task.description)
20 changes: 20 additions & 0 deletions refactor_demo/core/workflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
from abc import ABC, abstractmethod


class Workflow(ABC):
    """Abstract base class for workflows; subclasses must implement `step`."""

    @abstractmethod
    def step(self):
        """Perform one step of the workflow, as defined by the subclass."""
106 changes: 106 additions & 0 deletions refactor_demo/envs/multi_env.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
# Licensed under the Apache License, Version 2.0 (the “License”);
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an “AS IS” BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
import gymnasium as gym
import numpy as np
from gymnasium import spaces


class MultiEnv(gym.Env):
    """Expose several gymnasium environments as one environment.

    The action space is a `OneOf` over the sub-environments' action spaces,
    so an action is an ``(environment index, sub-action)`` pair. The
    observation space is a `Dict` with one ``"env_<i>"`` entry per
    sub-environment.
    """

    def __init__(self, envs):
        """Initialize the MultiEnv environment.

        Args:
            envs (list): A list of gymnasium environments to integrate.

        Raises:
            ValueError: If ``envs`` is empty.
        """
        super().__init__()

        if not envs:
            raise ValueError("MultiEnv requires at least one environment.")

        # Store the environments
        self.envs = envs

        # An action selects one sub-environment and carries that
        # sub-environment's own action.
        self.action_space = spaces.OneOf([env.action_space for env in envs])

        # One observation entry per sub-environment.
        self.observation_space = spaces.Dict(
            {f"env_{i}": env.observation_space for i, env in enumerate(envs)}
        )

        # Per-environment state carried between steps: the latest observation
        # and whether the sub-environment has terminated or been truncated.
        keys = [f"env_{i}" for i in range(len(envs))]
        self._last_observations = dict.fromkeys(keys)
        self._terminated = dict.fromkeys(keys, False)
        self._truncated = dict.fromkeys(keys, False)

    def reset(self, *, seed=None, options=None):
        """Reset all environments and return their initial observations.

        Follows the Gymnasium ``reset`` contract so the environment can be
        used with wrappers that unpack ``obs, info``.

        Args:
            seed: Optional seed, forwarded to every sub-environment.
            options: Optional reset options, forwarded to every
                sub-environment.

        Returns:
            tuple: ``(observations, infos)``, two dictionaries keyed by
            ``"env_<i>"``.
        """
        super().reset(seed=seed)

        observations = {}
        infos = {}
        for i, env in enumerate(self.envs):
            key = f"env_{i}"
            observations[key], infos[key] = env.reset(seed=seed, options=options)

        self._last_observations = dict(observations)
        self._terminated = dict.fromkeys(observations, False)
        self._truncated = dict.fromkeys(observations, False)
        return observations, infos

    def step(self, action):
        """Take a step in the sub-environment selected by the action.

        Args:
            action: An ``(environment index, sub-action)`` pair, i.e. an
                element of the `OneOf` action space.

        Returns:
            tuple: ``(observations, rewards, terminated, truncated, infos)``.
            ``observations``, ``rewards`` and ``infos`` are dictionaries keyed
            by ``"env_<i>"``. Environments that did not act keep their most
            recent observation, get reward ``0`` and an empty info dict.
            ``terminated`` is true once every sub-environment has terminated.
            ``truncated`` is true once every sub-environment has finished but
            at least one of them was truncated rather than terminated.

        Raises:
            ValueError: If ``action`` is not a pair or the index is out of
                range.
        """
        try:
            env_index, sub_action = action
        except (TypeError, ValueError) as err:
            raise ValueError(
                "Action must be an (environment index, sub-action) pair."
            ) from err

        env_index = int(env_index)
        if not 0 <= env_index < len(self.envs):
            raise ValueError("Invalid action for environment selection.")

        # Only the selected environment advances; it receives its own
        # sub-action, not the environment index.
        obs, reward, terminated, truncated, info = self.envs[env_index].step(
            sub_action
        )

        acting = f"env_{env_index}"
        self._last_observations[acting] = obs
        self._terminated[acting] = bool(terminated)
        self._truncated[acting] = bool(truncated)

        # Non-acting environments report their previous observation.
        observations = dict(self._last_observations)
        rewards = dict.fromkeys(observations, 0)
        rewards[acting] = reward
        infos = {key: {} for key in observations}
        infos[acting] = info

        all_terminated = all(self._terminated.values())
        all_finished = all(
            self._terminated[key] or self._truncated[key] for key in self._terminated
        )
        return (
            observations,
            rewards,
            all_terminated,
            all_finished and not all_terminated,
            infos,
        )

    def render(self, mode="human"):
        """Render all environments.

        Args:
            mode: Kept for backward compatibility and ignored. Gymnasium's
                ``Env.render`` takes no arguments; the render mode is fixed
                when each sub-environment is created.

        Returns:
            dict: The value returned by each sub-environment's ``render``,
            keyed by ``"env_<i>"``.
        """
        return {f"env_{i}": env.render() for i, env in enumerate(self.envs)}

    def close(self):
        """Close all environments."""
        for env in self.envs:
            env.close()
Loading

0 comments on commit 4ae656b

Please sign in to comment.