From 4ae656bfdbc8b8371c3cdfbcc76fbe528ede953a Mon Sep 17 00:00:00 2001 From: Tianqi Xu Date: Wed, 13 Nov 2024 21:16:27 +0300 Subject: [PATCH] [WIP] Define framework --- refactor_demo/core/benchmark.py | 32 +++ refactor_demo/core/environment.py | 61 +++++ refactor_demo/core/evaluator.py | 20 ++ refactor_demo/core/policy.py | 13 + refactor_demo/core/task.py | 27 ++ refactor_demo/core/task_wrapper.py | 76 ++++++ refactor_demo/core/workflow.py | 20 ++ refactor_demo/envs/multi_env.py | 106 +++++++ refactor_demo/envs/multi_env_test.ipynb | 349 ++++++++++++++++++++++++ 9 files changed, 704 insertions(+) create mode 100644 refactor_demo/core/benchmark.py create mode 100644 refactor_demo/core/environment.py create mode 100644 refactor_demo/core/evaluator.py create mode 100644 refactor_demo/core/policy.py create mode 100644 refactor_demo/core/task.py create mode 100644 refactor_demo/core/task_wrapper.py create mode 100644 refactor_demo/core/workflow.py create mode 100644 refactor_demo/envs/multi_env.py create mode 100644 refactor_demo/envs/multi_env_test.ipynb diff --git a/refactor_demo/core/benchmark.py b/refactor_demo/core/benchmark.py new file mode 100644 index 0000000..88aa52d --- /dev/null +++ b/refactor_demo/core/benchmark.py @@ -0,0 +1,32 @@ +# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== +# Licensed under the Apache License, Version 2.0 (the “License”); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an “AS IS” BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== +from abc import ABC, abstractmethod +from typing import Iterable + +from .environment import Environment +from .evaluator import Evaluator + + +class Benchmark(ABC): + @abstractmethod + def get_corresponding_env(self) -> Environment: + pass + + @abstractmethod + def get_task_by_id(self, id: str) -> str: + pass + + @abstractmethod + def tasks(self) -> Iterable[tuple[str, Evaluator]]: + pass diff --git a/refactor_demo/core/environment.py b/refactor_demo/core/environment.py new file mode 100644 index 0000000..b5c33b8 --- /dev/null +++ b/refactor_demo/core/environment.py @@ -0,0 +1,61 @@ +# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== +# Licensed under the Apache License, Version 2.0 (the “License”); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an “AS IS” BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== +from abc import ABC, abstractmethod +from typing import Any + +import gymnasium as gym +from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam + + +class Environment(gym.Env, ABC): + """The base environment class for agents to interact with in the CRAB framework. + + Crab Environment is a subclass of `gymnasium.Env` and is designed to be a base class + for all environments in the CRAB. Your must implement two functions + `get_action_schema` and `convert_tool_call_to_action` to make the environment + compatible with OpenAI tool use API. + """ + + @abstractmethod + def get_description(self) -> str: + """Get the description of the environment, which can be used as a part of the + agent prompt. + + Returns: + A string description of the environment. + """ + + @abstractmethod + def get_action_schema(self) -> list[ChatCompletionToolParam]: + """Get the tool schema for the action space of the environment. + + The schema provides detailed descriptions of the whole actions space and their + parameters that represent all the possible actions in the tool calling format, + which can be directly used in the OpenAI API. It should be comprehensive and do + not produce any misunderstanding for a human user. + + Returns: + A list of tool schema. + """ + ... + + @abstractmethod + def convert_tool_call_to_action(self, tool_name: str, parameters: dict) -> Any: + """Convert a tool call to the actual action space in the environment. + + Args: + tool_name: The name of the tool. + parameters: The parameters of the tool call. + """ + ... diff --git a/refactor_demo/core/evaluator.py b/refactor_demo/core/evaluator.py new file mode 100644 index 0000000..82cb0f9 --- /dev/null +++ b/refactor_demo/core/evaluator.py @@ -0,0 +1,20 @@ +# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== +# Licensed under the Apache License, Version 2.0 (the “License”); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an “AS IS” BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== +from abc import ABC, abstractmethod + + +class Evaluator(ABC): + @abstractmethod + def step(self, environment: Environment, task: Task) -> Any: + pass diff --git a/refactor_demo/core/policy.py b/refactor_demo/core/policy.py new file mode 100644 index 0000000..66e0731 --- /dev/null +++ b/refactor_demo/core/policy.py @@ -0,0 +1,13 @@ +# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== +# Licensed under the Apache License, Version 2.0 (the “License”); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an “AS IS” BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== diff --git a/refactor_demo/core/task.py b/refactor_demo/core/task.py new file mode 100644 index 0000000..a8f2f62 --- /dev/null +++ b/refactor_demo/core/task.py @@ -0,0 +1,27 @@ +# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== +# Licensed under the Apache License, Version 2.0 (the “License”); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an “AS IS” BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== +from abc import ABC, abstractmethod +from pydantic import BaseModel, ConfigDict +from typing import Any + +from .environment import Environment + + +class Task(BaseModel): + model_config = ConfigDict(arbitrary_types_allowed=True) + id: str + description: str + evaluator: Evaluator + setup: setup = [] + extra_action: list[Action] = [] diff --git a/refactor_demo/core/task_wrapper.py b/refactor_demo/core/task_wrapper.py new file mode 100644 index 0000000..2ae112a --- /dev/null +++ b/refactor_demo/core/task_wrapper.py @@ -0,0 +1,76 @@ +# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== +# Licensed under the Apache License, Version 2.0 (the “License”); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an “AS IS” BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== +from typing import Any + +import gymnasium as gym +from gymnasium import Wrapper +from gymnasium.core import ActType, ObsType, WrapperObsType +from gymnasium.spaces import Dict, Space, Text, Tuple + + +class TaskWrapper(Wrapper[WrapperObsType, ActType, ObsType, ActType]): + def __init__( + self, + env: gym.Env[ObsType, ActType], + task: Task, + *, + dict_task_key: str = "task", + ): + super().__init__(env) + self.env = env + self.task = task + + task_space = Text(500) + + # Observation space in different situations + if isinstance(env.observation_space, Dict): + assert dict_task_key not in env.observation_space.keys() + observation_space = Dict( + {dict_task_key: task_space, **env.observation_space.spaces} + ) + self._append_data_func = lambda obs, task: {dict_task_key: task, **obs} + elif isinstance(env.observation_space, Tuple): + observation_space = Tuple(env.observation_space.spaces + (task_space,)) + self._append_data_func = lambda obs, task: obs + (task,) + else: + observation_space = Dict(obs=env.observation_space, task=task_space) + self._append_data_func = lambda obs, task: {"obs": obs, "task": task} + + self.observation_space: gym.Space[WrapperObsType] = observation_space + + def reset( + self, *, seed: int | None = None, options: dict[str, Any] | None = None + ) -> tuple[Dict, dict[str, Any]]: + """Modifies the :attr:`env` after calling :meth:`reset`, returning a modified + observation using :meth:`self.observation`.""" + obs, info = self.env.reset(seed=seed, options=options) + return self.observation(obs), info + + def step( + self, action: ActType + ) -> tuple[WrapperObsType, float, bool, bool, dict[str, Any]]: + observation, reward, terminal, truncated, info = self.step(action) + reward = self.task.evaluate(self.env) + return self.observation(observation), reward, terminal, truncated, info + + def observation(self, observation: ObsType): + """Returns a modified observation. + + Args: + observation: The :attr:`env` observation + + Returns: + The modified observation + """ + return self._append_data_func(observation, self.task.description) diff --git a/refactor_demo/core/workflow.py b/refactor_demo/core/workflow.py new file mode 100644 index 0000000..68e170d --- /dev/null +++ b/refactor_demo/core/workflow.py @@ -0,0 +1,20 @@ +# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== +# Licensed under the Apache License, Version 2.0 (the “License”); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an “AS IS” BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== +from abc import ABC, abstractmethod + + +class Workflow(ABC): + @abstractmethod + def step(self): + pass diff --git a/refactor_demo/envs/multi_env.py b/refactor_demo/envs/multi_env.py new file mode 100644 index 0000000..0d13895 --- /dev/null +++ b/refactor_demo/envs/multi_env.py @@ -0,0 +1,106 @@ +# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== +# Licensed under the Apache License, Version 2.0 (the “License”); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an “AS IS” BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. =========== +import gymnasium as gym +import numpy as np +from gymnasium import spaces + + +class MultiEnv(gym.Env): + def __init__(self, envs): + """ + Initialize the MultiEnv environment. + + Args: + envs (list): A list of gymnasium environments to integrate. + """ + super().__init__() + + # Store the environments + self.envs = envs + + # Create action space using OneOf with the action spaces of each environment + self.action_space = spaces.OneOf([env.action_space for env in envs]) + + # Create observation space as a Dict space containing each environment's observation space + self.observation_space = spaces.Dict( + {f"env_{i}": env.observation_space for i, env in enumerate(envs)} + ) + + def reset(self): + """ + Reset all environments and return initial observations. + + Returns: + dict: A dictionary with initial observations from each environment. + """ + observations = {} + for i, env in enumerate(self.envs): + observations[f"env_{i}"], _ = env.reset() + return observations + + def step(self, action): + """ + Take a step in the selected environment based on the action. + + Args: + action (int): The index of the environment to take a step in. + + Returns: + tuple: A tuple containing the observations, rewards, done flags, and info. + """ + assert 0 <= action < len(self.envs), "Invalid action for environment selection." + + # Initialize dictionaries to store results + observations = {} + rewards = {} + dones = {} + infos = {} + + # Perform a step in the selected environment + obs, reward, done, truncated, info = self.envs[action].step(action) + + # Populate results for the selected environment + observations[f"env_{action}"] = obs + rewards[f"env_{action}"] = reward + dones[f"env_{action}"] = done + infos[f"env_{action}"] = info + + # For other environments, simply pass their previous observations + for i, env in enumerate(self.envs): + if i != action: + observations[f"env_{i}"] = ( + None # No new observation for non-acting environments + ) + rewards[f"env_{i}"] = 0 + dones[f"env_{i}"] = False + infos[f"env_{i}"] = {} + + # Set done if all environments are done + all_done = all(dones.values()) + + return observations, rewards, all_done, infos + + def render(self, mode="human"): + """ + Render all environments (optional implementation). + """ + for i, env in enumerate(self.envs): + env.render(mode=mode) + + def close(self): + """ + Close all environments. + """ + for env in self.envs: + env.close() diff --git a/refactor_demo/envs/multi_env_test.ipynb b/refactor_demo/envs/multi_env_test.ipynb new file mode 100644 index 0000000..c818e95 --- /dev/null +++ b/refactor_demo/envs/multi_env_test.ipynb @@ -0,0 +1,349 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from abc import ABC, abstractmethod\n", + "from typing import Any\n", + "\n", + "import gymnasium as gym\n", + "from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam\n", + "\n", + "\n", + "class Environment(gym.Env, ABC):\n", + " \"\"\"The base environment class for agents to interact with in the CRAB framework.\n", + "\n", + " Crab Environment is a subclass of `gymnasium.Env` and is designed to be a base class\n", + " for all environments in the CRAB. Your must implement two functions\n", + " `get_action_schema` and `convert_tool_call_to_action` to make the environment\n", + " compatible with OpenAI tool use API.\n", + " \"\"\"\n", + "\n", + " @abstractmethod\n", + " def get_description(self) -> str:\n", + " \"\"\"Get the description of the environment, which can be used as a part of the\n", + " agent prompt.\n", + "\n", + " Returns:\n", + " A string description of the environment.\n", + " \"\"\"\n", + "\n", + " @abstractmethod\n", + " def get_action_schema(self) -> list[ChatCompletionToolParam]:\n", + " \"\"\"Get the tool schema for the action space of the environment.\n", + "\n", + " The schema provides detailed descriptions of the whole actions space and their\n", + " parameters that represent all the possible actions in the tool calling format,\n", + " which can be directly used in the OpenAI API. It should be comprehensive and do\n", + " not produce any misunderstanding for a human user.\n", + "\n", + " Returns:\n", + " A list of tool schema.\n", + " \"\"\"\n", + " ...\n", + "\n", + " @abstractmethod\n", + " def convert_tool_call_to_action(self, tool_name: str, parameters: dict) -> Any:\n", + " \"\"\"Convert a tool call to the actual action space in the environment.\n", + "\n", + " Args:\n", + " tool_name: The name of the tool.\n", + " parameters: The parameters of the tool call.\n", + " \"\"\"\n", + " ..." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import gymnasium as gym\n", + "from gymnasium.envs.classic_control.acrobot import AcrobotEnv\n", + "\n", + "from crab.core.decorators import action\n", + "from refactor_demo.envs.multi_env import MultiEnv\n", + "\n", + "\n", + "@action\n", + "def left():\n", + " \"\"\"apply -1 torque to the actuated joint\"\"\"\n", + "\n", + "\n", + "@action\n", + "def right():\n", + " \"\"\"apply +1 torque to the actuated joint\"\"\"\n", + "\n", + "\n", + "@action\n", + "def no_torque():\n", + " \"\"\"apply 0 torque to the actuated joint\"\"\"\n", + "\n", + "\n", + "class CrabAcrobotEnv(AcrobotEnv, Environment):\n", + " def get_description(self) -> str:\n", + " \"\"\"Get the description of the environment, which can be used as a part of the\n", + " agent prompt.\n", + "\n", + " Returns:\n", + " A string description of the environment.\n", + " \"\"\"\n", + " return \"\"\"The system consists of two links connected linearly to form a chain, with one end of \\\n", + "the chain fixed. The joint between the two links is actuated. The goal is to apply \\\n", + "torques on the actuated joint to swing the free end of the linear chain above a \\\n", + "given height while starting from the initial state of hanging downwards.\n", + "\n", + " ## Observation Space\n", + "\n", + " The observation is a `ndarray` with shape `(6,)` that provides information about the\n", + " two rotational joint angles as well as their angular velocities:\n", + "\n", + " | Num | Observation | Min | Max |\n", + " |-----|------------------------------|---------------------|-------------------|\n", + " | 0 | Cosine of `theta1` | -1 | 1 |\n", + " | 1 | Sine of `theta1` | -1 | 1 |\n", + " | 2 | Cosine of `theta2` | -1 | 1 |\n", + " | 3 | Sine of `theta2` | -1 | 1 |\n", + " | 4 | Angular velocity of `theta1` | ~ -12.567 (-4 * pi) | ~ 12.567 (4 * pi) |\n", + " | 5 | Angular velocity of `theta2` | ~ -28.274 (-9 * pi) | ~ 28.274 (9 * pi) |\n", + "\"\"\"\n", + "\n", + " def get_action_schema(self) -> list[ChatCompletionToolParam]:\n", + " \"\"\"Get the tool schema for the action space of the environment.\n", + "\n", + " The schema provides detailed descriptions of the whole actions space and their\n", + " parameters that represent all the possible actions in the tool calling format,\n", + " which can be directly used in the OpenAI API. It should be comprehensive and do\n", + " not produce any misunderstanding for a human user.\n", + "\n", + " Returns:\n", + " A list of tool schema.\n", + " \"\"\"\n", + " result = []\n", + " result.append(left.to_openai_json_schema())\n", + " result.append(right.to_openai_json_schema())\n", + " result.append(no_torque.to_openai_json_schema())\n", + " return result\n", + "\n", + " MAP = {\"left\": 0, \"no_torque\": 1, \"right\": 2}\n", + "\n", + " def convert_tool_call_to_action(self, tool_name: str, parameters: dict) -> Any:\n", + " \"\"\"Convert a tool call to the actual action space in the environment.\n", + "\n", + " Args:\n", + " tool_name: The name of the tool.\n", + " parameters: The parameters of the tool call.\n", + " \"\"\"\n", + " return self.MAP[tool_name]\n", + "\n", + "\n", + "env = CrabAcrobotEnv()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "from dataclasses import dataclass\n", + "\n", + "\n", + "@dataclass\n", + "class Task:\n", + " description: str\n", + " evaluate: callable\n", + "\n", + "\n", + "task = Task(\n", + " description=\"apply torques on the actuated joint to swing the free end of the linear chain above a given height while starting from the initial state of hanging downwards.\",\n", + " evaluate=lambda env: True,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Generic\n", + "\n", + "from gymnasium import Wrapper\n", + "from gymnasium.core import ActType, ObsType, WrapperObsType\n", + "from gymnasium.spaces import Dict, Space, Text, Tuple\n", + "\n", + "\n", + "class TaskWrapper(Wrapper[WrapperObsType, ActType, ObsType, ActType]):\n", + " def __init__(\n", + " self,\n", + " env: gym.Env[ObsType, ActType],\n", + " task: Task,\n", + " *,\n", + " dict_task_key: str = \"task\",\n", + " ):\n", + " super().__init__(env)\n", + " self.env = env\n", + " self.task = task\n", + "\n", + " task_space = Text(500)\n", + "\n", + " # Observation space in different situations\n", + " if isinstance(env.observation_space, Dict):\n", + " assert dict_task_key not in env.observation_space.keys()\n", + " observation_space = Dict(\n", + " {dict_task_key: task_space, **env.observation_space.spaces}\n", + " )\n", + " self._append_data_func = lambda obs, task: {dict_task_key: task, **obs}\n", + " elif isinstance(env.observation_space, Tuple):\n", + " observation_space = Tuple(env.observation_space.spaces + (task_space,))\n", + " self._append_data_func = lambda obs, task: obs + (task,)\n", + " else:\n", + " observation_space = Dict(obs=env.observation_space, task=task_space)\n", + " self._append_data_func = lambda obs, task: {\"obs\": obs, \"task\": task}\n", + "\n", + " self.observation_space: gym.Space[WrapperObsType] = observation_space\n", + "\n", + " def reset(\n", + " self, *, seed: int | None = None, options: dict[str, Any] | None = None\n", + " ) -> tuple[Dict, dict[str, Any]]:\n", + " \"\"\"Modifies the :attr:`env` after calling :meth:`reset`, returning a modified\n", + " observation using :meth:`self.observation`.\"\"\"\n", + " obs, info = self.env.reset(seed=seed, options=options)\n", + " return self.observation(obs), info\n", + "\n", + " def step(\n", + " self, action: ActType\n", + " ) -> tuple[WrapperObsType, float, bool, bool, dict[str, Any]]:\n", + " observation, reward, terminal, truncated, info = self.step(action)\n", + " reward = self.task.evaluate(self.env)\n", + " return self.observation(observation), reward, terminal, truncated, info\n", + "\n", + " def observation(self, observation: ObsType):\n", + " \"\"\"Returns a modified observation.\n", + "\n", + " Args:\n", + " observation: The :attr:`env` observation\n", + "\n", + " Returns:\n", + " The modified observation\n", + " \"\"\"\n", + " return self._append_data_func(observation, self.task.description)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "task_env = TaskWrapper(env, task)\n", + "o, i = task_env.reset()" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'obs': array([ 9.9810892e-01, 6.1470319e-02, 1.0000000e+00, -2.1458303e-05,\n", + " -9.0955026e-02, -7.1539722e-02], dtype=float32),\n", + " 'task': 'apply torques on the actuated joint to swing the free end of the linear chain above a given height while starting from the initial state of hanging downwards.'}" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "o" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "ename": "", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[1;31mThe Kernel crashed while executing code in the current cell or a previous cell. \n", + "\u001b[1;31mPlease review the code in the cell(s) to identify a possible cause of the failure. \n", + "\u001b[1;31mClick here for more info. \n", + "\u001b[1;31mView Jupyter log for further details." + ] + } + ], + "source": [ + "import openai\n", + "\n", + "client = openai.Client()\n", + "o, _ = env.reset()\n", + "\n", + "\n", + "result = client.chat.completions.create(\n", + " model=\"gpt-4-0613\",\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": env.get_description()},\n", + " {\"role\": \"user\", \"content\": str(o) + \"Tell me next step\"},\n", + " ],\n", + " tools=[{\"function\": tool, \"type\": \"function\"} for tool in env.get_action_schema()],\n", + " tool_choice=\"required\",\n", + ")\n", + "print(result.choices[0].message.tool_calls)" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([ 0.999577 , 0.02908438, 0.9999982 , -0.00189753, 0.08006953,\n", + " 0.06967726], dtype=float32)" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "o" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "crab-framework-MZtbDDSz-py3.10", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}