Skip to content

Commit

Permalink
refactor(agent): clean agent part code (#40)
Browse files Browse the repository at this point in the history
Co-authored-by: Isaac Jin <[email protected]>
  • Loading branch information
dandansamax and WHALEEYE authored Oct 29, 2024
1 parent 71e95fb commit 48f2452
Show file tree
Hide file tree
Showing 38 changed files with 1,997 additions and 1,157 deletions.
4 changes: 4 additions & 0 deletions crab-benchmark-v0/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,7 @@ After setting up the environment, you can start the experiment. A brief overview
2. Start the CRAB server in the Ubuntu environment and get its IP address and port. Let's say they are `192.168.122.72` and `8000`.
3. Choose a task. As an example, we take the task with ID `a3476778-e512-40ca-b1c0-d7aab0c7f18b` from [handmade_tasks](./dataset/handmade_tasks.py). The task is: "Open the 'Tasks' app on Android, check the first incomplete task, then perform the task according to its description."
4. Run [main.py](./main.py) with the command `poetry run python -m crab-benchmark-v0.main --model gpt4o --policy single --remote-url http://192.168.122.72:8000 --task-id a3476778-e512-40ca-b1c0-d7aab0c7f18b`. In this command, `--model gpt4o` and `--policy single` determine the agent system, `--remote-url` specifies the Ubuntu environment interface, and `--task-id` indicates the task to be performed.

#### Model

For open source models, we use [VLLM](https://github.com/vllm-project/vllm) to host Pixtral model, check [here](https://docs.vllm.ai/en/latest/models/vlm.html#online-inference) for the setup commands; [SGLang](https://github.com/sgl-project/sglang) to host LLaVa-OneVision model, check [here](https://github.com/sgl-project/sglang?tab=readme-ov-file#supported-models) for the setup commands.
3 changes: 2 additions & 1 deletion crab-benchmark-v0/android_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from crab import EnvironmentConfig
from crab.actions.android_actions import (
key_press,
long_tap,
open_app_drawer,
screenshot,
setup,
Expand All @@ -24,7 +25,7 @@

ANDROID_ENV = EnvironmentConfig(
name="android",
action_space=[tap, key_press, write_text, swipe, open_app_drawer],
action_space=[tap, key_press, long_tap, write_text, swipe, open_app_drawer],
observation_space=[screenshot],
description="""A Google Pixel smartphone runs on the Android operating system. \
The interface displays a current screenshot at each step and primarily \
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{
"description": "In Android, Using Google Map app, Find the city name of corresponding post code \"1010021\" in the country \"Japan\".",
"tasks": [
{
"task": "51b2463c-9904-4a32-81ba-507bfb89d61f",
"attribute": {
"country": "Japan",
"number": "101-0021"
},
"output": "Tokyo"
}
],
"adjlist": "0",
"id": "4190c90c-b28c-4bb3-ab5c-af3c4fde0a3d"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
{
"description": "In the Android system, use the calendar app to find the title of an event on the date \"16 July 2024,\".",
"tasks": [
{
"task": "2394b768-2ca7-45e9-b41e-2aa4e9573192",
"attribute": {
"date": "16 July 2024"
},
"output": "Japan"
}
],
"adjlist": "0",
"id": "4893a9b0-6477-495d-a73c-32503326e24a"
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
{
"description": "In Android, use the \"Google Map\" app to find the city name corresponding to the postcode \"110151\" in Colombia.",
"tasks": [
{
"task": "51b2463c-9904-4a32-81ba-507bfb89d61f",
"attribute": {
"number": "110151",
"country": "Columbia"
},
"output": "Bogota"
}
],
"adjlist": "0",
"id": "e55d7a39-7b6b-4852-8711-844cebc88cb8"
}
2 changes: 2 additions & 0 deletions crab-benchmark-v0/dataset/android_subtasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,8 @@ def check_event(date: str, env) -> bool:
event_nodes = root.xpath('//node[@class="android.support.v7.widget.RecyclerView"]')
if event_nodes is None:
return False
if not event_nodes:
return False
for node in event_nodes[0]:
text = node.get("content-desc")
if date in text:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
{
"description": "In Android, use the \"Google Map\" app to find the city name corresponding to the postal code \"1010021\" in Japan, then paste the name into LibreOffice Writer on an Ubuntu system and save it as an ODT file at \"/home/crab/Desktop\".",
"description": "In Android, use the \"Google Map\" app to find the city name corresponding to the postal code \"1010021\" in Japan, then paste the name into LibreOffice Writer on an Ubuntu system and save it as an ODT file at \"/home/crab/Desktop/target.opt\".",
"tasks": [
{
"task": "51b2463c-9904-4a32-81ba-507bfb89d61f",
Expand Down
224 changes: 200 additions & 24 deletions crab-benchmark-v0/dataset/handmade_tasks.py

Large diffs are not rendered by default.

94 changes: 81 additions & 13 deletions crab-benchmark-v0/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@
TaskGenerator,
create_benchmark,
)
from crab.actions.crab_actions import complete
from crab.actions.crab_actions import complete, wait
from crab.actions.visual_prompt_actions import (
get_elements_prompt,
groundingdino_easyocr,
)
from crab.agents.backend_models import ClaudeModel, GeminiModel, OpenAIModel
from crab.agents.backend_models import BackendModelConfig
from crab.agents.policies import (
MultiAgentByEnvPolicy,
MultiAgentByFuncPolicy,
Expand Down Expand Up @@ -96,7 +96,7 @@ def get_benchmark(env: str, ubuntu_url: str):
tasks=[],
environments=[ubuntu_env],
prompting_tools=prompting_tools,
root_action_space=[complete],
root_action_space=[complete, wait],
multienv=True,
)
elif env == "android":
Expand All @@ -106,7 +106,7 @@ def get_benchmark(env: str, ubuntu_url: str):
tasks=[],
environments=[ANDROID_ENV],
prompting_tools=prompting_tools,
root_action_space=[complete],
root_action_space=[complete, wait],
multienv=True,
)
elif env == "cross":
Expand All @@ -119,7 +119,7 @@ def get_benchmark(env: str, ubuntu_url: str):
tasks=[],
environments=[ubuntu_env, ANDROID_ENV],
prompting_tools=prompting_tools,
root_action_space=[complete],
root_action_space=[complete, wait],
multienv=True,
)
else:
Expand All @@ -137,7 +137,7 @@ def get_benchmark(env: str, ubuntu_url: str):
# Load from handmade tasks
benchmark_config.tasks.extend(handmade_tasks)

benchmark_config.step_limit = 15
benchmark_config.step_limit = 20
return create_benchmark(benchmark_config)


Expand All @@ -158,7 +158,7 @@ def get_benchmark(env: str, ubuntu_url: str):
default="single",
)
parser.add_argument(
"--remote-url",
"--ubuntu-url",
type=str,
help="remote url of Ubunutu environment",
default="http://127.0.0.1:8000",
Expand All @@ -170,29 +170,97 @@ def get_benchmark(env: str, ubuntu_url: str):
default="cross",
)
parser.add_argument("--task-id", type=str, help="task id")
parser.add_argument(
"--model-base-url",
type=str,
help="URL of the model API",
default="http://127.0.0.1:8000/v1",
)
parser.add_argument(
"--model-api-key",
type=str,
help="API key of the model API",
default="EMPTY",
)
parser.add_argument(
"--loglevel",
type=str,
help="logger level, debug, info, warning, or error",
default="warning",
)
parser.add_argument(
"--history-messages-len",
type=int,
help="The number of rounds of chat history to provide to the model",
default=2,
)
args = parser.parse_args()
loglevel = args.loglevel
numeric_level = getattr(logging, loglevel.upper(), None)
if not isinstance(numeric_level, int):
raise ValueError("Invalid log level: %s" % loglevel)
logging.basicConfig(level=numeric_level)

benchmark = get_benchmark(args.env, args.remote_url)
benchmark = get_benchmark(args.env, args.ubuntu_url)

if args.model == "human":
expeirment = CrabBenchmarkV0(
benchmark=benchmark,
task_id=args.task_id,
agent_policy="human",
)
expeirment.start_benchmark()
exit()

if args.model == "gpt4o":
model = OpenAIModel(model="gpt-4o", history_messages_len=2)
model = BackendModelConfig(
model_class="openai",
model_name="gpt-4o",
history_messages_len=args.history_messages_len,
)
elif args.model == "gpt4turbo":
model = OpenAIModel(model="gpt-4-turbo", history_messages_len=2)
model = BackendModelConfig(
model_class="openai",
model_name="gpt-4-turbo",
history_messages_len=args.history_messages_len,
)
elif args.model == "gemini":
model = GeminiModel(model="gemini-1.5-pro-latest", history_messages_len=2)
model = BackendModelConfig(
model_class="gemini",
model_name="gemini-1.5-pro-latest",
history_messages_len=args.history_messages_len,
)
elif args.model == "claude":
model = ClaudeModel(model="claude-3-opus-20240229", history_messages_len=2)
model = BackendModelConfig(
model_class="claude",
model_name="claude-3-opus-20240229",
history_messages_len=args.history_messages_len,
)
elif args.model == "pixtral":
model = BackendModelConfig(
model_class="openai",
model_name="mistralai/Pixtral-12B-2409",
json_structre_output=True,
history_messages_len=args.history_messages_len,
base_url=args.model_base_url,
api_key=args.model_api_key,
)
elif args.model == "gpt4o-wofc":
model = BackendModelConfig(
model_class="openai",
model_name="gpt-4o",
json_structre_output=True,
history_messages_len=args.history_messages_len,
)
elif args.model == "llava-ov72b":
model = BackendModelConfig(
model_class="sglang",
model_name="lmms-lab/llava-onevision-qwen2-72b-ov-chat",
json_structre_output=True,
history_messages_len=args.history_messages_len,
base_url=args.model_base_url,
api_key=args.model_api_key,
)
else:
print("Unsupported model: ", args.model)
exit()
Expand All @@ -211,7 +279,7 @@ def get_benchmark(env: str, ubuntu_url: str):
print("Unsupported policy: ", args.policy)
exit()

log_dir = (Path(__file__).parent / "logs").resolve()
log_dir = (Path(__file__).parent / "tianqi_logs").resolve()
expeirment = CrabBenchmarkV0(
benchmark=benchmark,
task_id=args.task_id,
Expand Down
2 changes: 2 additions & 0 deletions crab-benchmark-v0/ubuntu_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
from crab.actions.desktop_actions import (
click,
double_click,
key_press,
press_hotkey,
right_click,
Expand All @@ -31,6 +32,7 @@
press_hotkey,
search_application,
right_click,
double_click,
],
observation_space=[screenshot],
description="""An Ubuntu 22.04 Linux desktop operating system. The interface \
Expand Down
10 changes: 10 additions & 0 deletions crab/actions/crab_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# =========== Copyright 2024 @ CAMEL-AI.org. All Rights Reserved. ===========
from time import sleep

from crab import action, evaluator


Expand Down Expand Up @@ -42,6 +44,14 @@ def complete() -> bool:
pass


@action(env_name="root")
def wait() -> bool:
"""If the environment is still processing your action and you have nothing to do in
this step, you can use wait().
"""
sleep(5)


def get_element_position(element_id, env):
"""Get element position provided by function `zs_object_detection`"""
box = env.element_position_map[element_id]
Expand Down
30 changes: 29 additions & 1 deletion crab/actions/desktop_actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def right_click(element: int, env) -> None:
"""
Right-click an UI element shown on the desktop screen using the mouse, which is
usually used for opening the menu of the element. A simple use case can be
rght_click(5), which right-clicks the UI element labeled with the number 5 to open
right_click(5), which right-clicks the UI element labeled with the number 5 to open
up menu on it.
Args:
Expand All @@ -80,6 +80,34 @@ def right_click(element: int, env) -> None:
time.sleep(DELAY)


@action
def double_click_position(x: int, y: int) -> None:
"""
Double-click on the current desktop screen.
Args:
x: The X coordinate, as a floating-point number in the range [0.0, 1.0].
y: The Y coordinate, as a floating-point number in the range [0.0, 1.0].
"""
pyautogui.click(x, y, duration=DURATION, clicks=2, interval=0.2)


@action(local=True)
def double_click(element: int, env) -> None:
"""
Double-click an UI element shown on the desktop screen using the mouse, which is
usually used for opening a folder or a file. A simple use case can be
double_click(5), which double-clicks the UI element labeled with the number 5 to
open it.
Args:
element: A numeric tag assigned to an UI element shown on the screenshot.
"""
x, y = get_element_position(element, env)
env._action_endpoint(double_click_position, {"x": x, "y": y})
time.sleep(DELAY)


@action
def mouse_scroll(click: int = 1) -> None:
"""
Expand Down
Loading

0 comments on commit 48f2452

Please sign in to comment.