-
Notifications
You must be signed in to change notification settings - Fork 29
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
6d1be33
commit 4ae656b
Showing
9 changed files
with
704 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== | ||
# Licensed under the Apache License, Version 2.0 (the “License”); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an “AS IS” BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== | ||
from abc import ABC, abstractmethod | ||
from typing import Iterable | ||
|
||
from .environment import Environment | ||
from .evaluator import Evaluator | ||
|
||
|
||
class Benchmark(ABC):
    """Abstract interface that every benchmark implementation must provide.

    A concrete benchmark exposes the environment it runs in, lookup of a task
    by identifier, and iteration over its tasks.
    """

    @abstractmethod
    def get_corresponding_env(self) -> Environment:
        """Return the :class:`Environment` that belongs to this benchmark."""

    @abstractmethod
    def get_task_by_id(self, id: str) -> str:
        """Look up a task by its identifier.

        Args:
            id: Identifier of the requested task.

        Returns:
            A string for the task (presumably its description; confirm
            against concrete implementations).
        """

    @abstractmethod
    def tasks(self) -> Iterable[tuple[str, Evaluator]]:
        """Return an iterable of ``(str, Evaluator)`` pairs, one per task."""
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== | ||
# Licensed under the Apache License, Version 2.0 (the “License”); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an “AS IS” BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== | ||
from abc import ABC, abstractmethod | ||
from typing import Any | ||
|
||
import gymnasium as gym | ||
from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam | ||
|
||
|
||
class Environment(gym.Env, ABC):
    """Base environment class for agents to interact with in the CRAB framework.

    A CRAB environment is a ``gymnasium.Env`` subclass intended as the common
    base class of every environment in CRAB. Subclasses must implement
    `get_description`, `get_action_schema` and `convert_tool_call_to_action`;
    the latter two make the environment compatible with the OpenAI tool use
    API.
    """

    @abstractmethod
    def get_description(self) -> str:
        """Describe the environment in text suitable for an agent prompt.

        Returns:
            A string description of the environment.
        """

    @abstractmethod
    def get_action_schema(self) -> list[ChatCompletionToolParam]:
        """Return the tool schema covering the environment's action space.

        The schema describes every possible action and its parameters in the
        tool-calling format, so that it can be passed directly to the OpenAI
        API. It should be comprehensive and leave no room for
        misunderstanding by a human reader.

        Returns:
            A list of tool schemas.
        """

    @abstractmethod
    def convert_tool_call_to_action(self, tool_name: str, parameters: dict) -> Any:
        """Translate a tool call into an action of this environment.

        Args:
            tool_name: The name of the tool.
            parameters: The parameters of the tool call.

        Returns:
            The corresponding action in the environment's action space.
        """
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== | ||
# Licensed under the Apache License, Version 2.0 (the “License”); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an “AS IS” BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== | ||
from abc import ABC, abstractmethod | ||
|
||
|
||
class Evaluator(ABC):
    """Abstract base class for objects that evaluate a task in an environment."""

    # The annotations are string forward references on purpose: this module
    # only imports ``ABC``/``abstractmethod``, so bare ``Environment``, ``Task``
    # and ``Any`` would raise NameError when the ``def`` is executed. Keeping
    # them as strings also avoids a circular import with the task module,
    # whose ``Task`` model refers back to ``Evaluator``.
    @abstractmethod
    def step(self, environment: "Environment", task: "Task") -> "Any":
        """Perform one evaluation step.

        Args:
            environment: The environment being evaluated.
            task: The task being evaluated.

        Returns:
            An implementation-defined evaluation result.
        """
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== | ||
# Licensed under the Apache License, Version 2.0 (the “License”); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an “AS IS” BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== | ||
# Licensed under the Apache License, Version 2.0 (the “License”); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an “AS IS” BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== | ||
from abc import ABC, abstractmethod | ||
from pydantic import BaseModel, ConfigDict | ||
from typing import Any | ||
|
||
from .environment import Environment | ||
|
||
|
||
class Task(BaseModel):
    """Pydantic model bundling a task's identity, description and evaluator.

    NOTE(review): ``Evaluator`` and ``Action`` are not among the imports
    visible for this module; they must be importable here for the model to be
    defined. Confirm where they are supposed to come from.
    """

    # Allow field types that are plain Python classes rather than pydantic
    # models.
    model_config = ConfigDict(arbitrary_types_allowed=True)
    # Identifier of the task.
    id: str
    # Text description of the task.
    description: str
    # Evaluator attached to this task.
    evaluator: Evaluator
    # The original annotation was ``setup: setup``, which refers to an
    # undefined name and raises NameError when the class body runs. The
    # default is a list, so the field is typed as a list of arbitrary items.
    # NOTE(review): narrow the element type once the intended one is known.
    setup: list[Any] = []
    # Additional actions made available for this task.
    extra_action: list[Action] = []
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== | ||
# Licensed under the Apache License, Version 2.0 (the “License”); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an “AS IS” BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== | ||
from typing import Any | ||
|
||
import gymnasium as gym | ||
from gymnasium import Wrapper | ||
from gymnasium.core import ActType, ObsType, WrapperObsType | ||
from gymnasium.spaces import Dict, Space, Text, Tuple | ||
|
||
|
||
class TaskWrapper(Wrapper[WrapperObsType, ActType, ObsType, ActType]):
    """Gymnasium wrapper that attaches a task description to every observation.

    How the description is attached depends on the wrapped observation space:

    * ``Dict``: the description is added under ``dict_task_key``.
    * ``Tuple``: the description is appended as the last element.
    * anything else: the observation becomes
      ``{"obs": <original observation>, "task": <description>}``.
    """

    def __init__(
        self,
        env: gym.Env[ObsType, ActType],
        # Quoted: ``Task`` is not imported in this module's visible imports,
        # and a bare name would raise NameError when the ``def`` is executed.
        task: "Task",
        *,
        dict_task_key: str = "task",
    ):
        """Wrap ``env`` so that its observations carry ``task.description``.

        Args:
            env: The environment to wrap.
            task: The task whose description is attached to observations.
            dict_task_key: Key used for the description when the wrapped
                observation space is a ``Dict``. It must not already be a key
                of that space.
        """
        super().__init__(env)
        self.env = env
        self.task = task

        # Space of the task description: text of at most 500 characters.
        task_space = Text(500)

        # Observation space in different situations
        if isinstance(env.observation_space, Dict):
            assert dict_task_key not in env.observation_space.keys()
            observation_space = Dict(
                {dict_task_key: task_space, **env.observation_space.spaces}
            )
            self._append_data_func = lambda obs, task: {dict_task_key: task, **obs}
        elif isinstance(env.observation_space, Tuple):
            observation_space = Tuple(env.observation_space.spaces + (task_space,))
            self._append_data_func = lambda obs, task: obs + (task,)
        else:
            observation_space = Dict(obs=env.observation_space, task=task_space)
            self._append_data_func = lambda obs, task: {"obs": obs, "task": task}

        self.observation_space: gym.Space[WrapperObsType] = observation_space

    def reset(
        self, *, seed: int | None = None, options: dict[str, Any] | None = None
    ) -> tuple[WrapperObsType, dict[str, Any]]:
        """Reset the wrapped environment and attach the task description.

        Args:
            seed: Seed forwarded to the wrapped environment's ``reset``.
            options: Options forwarded to the wrapped environment's ``reset``.

        Returns:
            The modified initial observation and the info dict returned by the
            wrapped environment.
        """
        obs, info = self.env.reset(seed=seed, options=options)
        return self.observation(obs), info

    def step(
        self, action: ActType
    ) -> tuple[WrapperObsType, float, bool, bool, dict[str, Any]]:
        """Step the wrapped environment and replace its reward.

        Args:
            action: The action forwarded to the wrapped environment.

        Returns:
            ``(observation, reward, terminal, truncated, info)``, where the
            observation carries the task description and the reward comes from
            the task instead of the wrapped environment.
        """
        # Delegate to the wrapped environment. Calling ``self.step`` here, as
        # the original did, recurses without bound.
        observation, _, terminal, truncated, info = self.env.step(action)
        # NOTE(review): the ``Task`` model declared alongside this wrapper
        # exposes an ``evaluator`` field but no ``evaluate`` method; confirm
        # the intended evaluation entry point.
        reward = self.task.evaluate(self.env)
        return self.observation(observation), reward, terminal, truncated, info

    def observation(self, observation: ObsType):
        """Return ``observation`` with the task description attached.

        Args:
            observation: The :attr:`env` observation

        Returns:
            The modified observation
        """
        return self._append_data_func(observation, self.task.description)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== | ||
# Licensed under the Apache License, Version 2.0 (the “License”); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an “AS IS” BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== | ||
from abc import ABC, abstractmethod | ||
|
||
|
||
class Workflow(ABC):
    """Abstract interface for a workflow that advances one step at a time."""

    @abstractmethod
    def step(self):
        """Advance the workflow by a single step; subclasses must override."""
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== | ||
# Licensed under the Apache License, Version 2.0 (the “License”); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an “AS IS” BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== | ||
import gymnasium as gym | ||
import numpy as np | ||
from gymnasium import spaces | ||
|
||
|
||
class MultiEnv(gym.Env):
    """Expose several gymnasium environments as a single environment.

    The action space is a ``OneOf`` over the sub-environments' action spaces,
    so an action is an ``(env_index, env_action)`` pair. The observation space
    is a ``Dict`` keyed ``"env_<i>"`` holding each sub-environment's
    observation.
    """

    def __init__(self, envs):
        """Initialize the MultiEnv environment.

        Args:
            envs (list): A list of gymnasium environments to integrate.
        """
        super().__init__()

        # Store the environments
        self.envs = envs

        # Create action space using OneOf with the action spaces of each environment
        self.action_space = spaces.OneOf([env.action_space for env in envs])

        # Create observation space as a Dict space containing each
        # environment's observation space
        self.observation_space = spaces.Dict(
            {f"env_{i}": env.observation_space for i, env in enumerate(envs)}
        )

        # Per-environment state carried across steps: the most recent
        # observation, and whether the environment has terminated.
        self._last_observations = {f"env_{i}": None for i in range(len(envs))}
        self._terminated = {f"env_{i}": False for i in range(len(envs))}

    def reset(self, *, seed=None, options=None):
        """Reset all environments and return their initial observations.

        Args:
            seed: Optional seed forwarded to every sub-environment.
            options: Optional reset options forwarded to every sub-environment.

        Returns:
            tuple: ``(observations, infos)``, two dictionaries keyed
            ``"env_<i>"`` with each environment's initial observation and
            reset info.
        """
        super().reset(seed=seed)

        observations = {}
        infos = {}
        for i, env in enumerate(self.envs):
            key = f"env_{i}"
            observations[key], infos[key] = env.reset(seed=seed, options=options)

        # A new episode starts: remember the fresh observations and clear the
        # termination flags.
        self._last_observations = dict(observations)
        self._terminated = {key: False for key in observations}
        return observations, infos

    def step(self, action):
        """Take a step in the sub-environment selected by the action.

        Args:
            action: An ``(env_index, env_action)`` pair as sampled from the
                ``OneOf`` action space. A bare integer is still accepted for
                backward compatibility; it is then used both as the index and
                as the action passed to that environment.

        Returns:
            tuple: ``(observations, rewards, terminated, truncated, infos)``.
            ``observations``, ``rewards`` and ``infos`` are dictionaries keyed
            ``"env_<i>"``. Environments that did not act keep their previous
            observation, get reward ``0`` and an empty info dict.
            ``terminated`` is True once every sub-environment has terminated;
            ``truncated`` is the flag reported by the environment that acted.
        """
        if isinstance(action, (tuple, list)):
            env_index, env_action = action
        else:
            env_index = env_action = action
        # OneOf samples the index as a numpy integer; normalize for indexing.
        env_index = int(env_index)
        assert (
            0 <= env_index < len(self.envs)
        ), "Invalid action for environment selection."

        # Perform a step in the selected environment
        obs, reward, terminated, truncated, info = self.envs[env_index].step(
            env_action
        )

        acting_key = f"env_{env_index}"
        self._last_observations[acting_key] = obs
        self._terminated[acting_key] = bool(terminated)

        # Non-acting environments report their previous observation, a zero
        # reward and an empty info dict.
        observations = dict(self._last_observations)
        rewards = {key: 0 for key in observations}
        infos = {key: {} for key in observations}
        rewards[acting_key] = reward
        infos[acting_key] = info

        # The episode is over only when every environment has terminated.
        all_terminated = all(self._terminated.values())

        return observations, rewards, all_terminated, truncated, infos

    def render(self, mode="human"):
        """Render all environments.

        Args:
            mode: Kept for backward compatibility and ignored. gymnasium
                environments choose their render mode at construction time and
                ``render()`` takes no arguments.
        """
        for env in self.envs:
            env.render()

    def close(self):
        """Close all environments."""
        for env in self.envs:
            env.close()
Oops, something went wrong.