├── .gitignore
├── LICENSE
├── README.md
├── app.py
├── assets
│ ├── Teaser.gif
│ ├── examples
│ │ ├── init_states
│ │ │ ├── amazon.png
│ │ │ ├── booking.png
│ │ │ ├── honkai_star_rail.png
│ │ │ ├── honkai_star_rail_showui.png
│ │ │ ├── ign.png
│ │ │ ├── powerpoint.png
│ │ │ └── powerpoint_homepage.png
│ │ └── ootb_examples.json
│ ├── gradio_interface.png
│ ├── ootb_icon.png
│ ├── ootb_logo.png
│ └── wechat_3.jpg
├── computer_use_demo
│ ├── __init__.py
│ ├── executor
│ │ ├── anthropic_executor.py
│ │ └── showui_executor.py
│ ├── gui_agent
│ │ ├── actor
│ │ │ ├── showui_agent.py
│ │ │ └── uitars_agent.py
│ │ ├── llm_utils
│ │ │ ├── llm_utils.py
│ │ │ ├── oai.py
│ │ │ ├── qwen.py
│ │ │ └── run_llm.py
│ │ └── planner
│ │   ├── anthropic_agent.py
│ │   ├── api_vlm_planner.py
│ │   └── local_vlm_planner.py
│ ├── loop.py
│ ├── remote_inference.py
│ └── tools
│   ├── __init__.py
│   ├── base.py
│   ├── bash.py
│   ├── collection.py
│   ├── colorful_text.py
│   ├── computer.py
│   ├── edit.py
│   ├── logger.py
│   ├── run.py
│   └── screen_capture.py
├── docs
│ └── README_cn.md
├── install_tools
│ ├── install_showui-awq-4bit.py
│ ├── install_showui.py
│ ├── install_uitars-2b-sft.py
│ └── test_ui-tars_server.py
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | .venv
2 | .ruff_cache
3 | __pycache__
4 | .pytest_cache
5 | .cache
6 | .ipynb_checkpoints
7 | .ipynb
8 | .DS_Store
9 | /tmp
10 | /.gradio
11 | /.zed
12 | /showui*
13 | /ui-tars*
14 | /demo
15 | /Qwen*
16 | /install_tools/install_qwen*
17 | /dev_tools*
18 | test.ipynb
19 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | Copyright [2024] [Show Lab Computer-Use-OOTB Team]
179 |
180 | Licensed under the Apache License, Version 2.0 (the "License");
181 | you may not use this file except in compliance with the License.
182 | You may obtain a copy of the License at
183 |
184 | http://www.apache.org/licenses/LICENSE-2.0
185 |
186 | Unless required by applicable law or agreed to in writing, software
187 | distributed under the License is distributed on an "AS IS" BASIS,
188 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
189 | See the License for the specific language governing permissions and
190 | limitations under the License.
191 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
17 |
18 | ## Overview
19 | **Computer Use OOTB** is an out-of-the-box (OOTB) solution for Desktop GUI Agents, supporting both API-based models (**Claude 3.5 Computer Use**) and locally running models (**ShowUI**, **UI-TARS**).
20 |
21 | **No Docker** is required, and it supports both **Windows** and **macOS**. OOTB provides a user-friendly interface based on Gradio.🎨
22 |
23 | Visit our study on GUI Agent of Claude 3.5 Computer Use [[project page]](https://computer-use-ootb.github.io). 🌐
24 |
25 | ## Update
26 | - **[2025/02/08]** We've added support for [**UI-TARS**](https://github.com/bytedance/UI-TARS). Follow the [Cloud Deployment](https://github.com/bytedance/UI-TARS?tab=readme-ov-file#cloud-deployment) or [vLLM deployment](https://github.com/bytedance/UI-TARS?tab=readme-ov-file#local-deployment-vllm) guide to deploy UI-TARS and run it locally in OOTB.
27 | - **Major Update! [2024/12/04]** **Local Run🔥** is now live! Say hello to [**ShowUI**](https://github.com/showlab/ShowUI), an open-source 2B vision-language-action (VLA) model for GUI Agents. Now compatible with `"gpt-4o + ShowUI" (~200x cheaper)`* & `"Qwen2-VL + ShowUI" (~30x cheaper)`* for only a few cents per task💰! *compared to Claude Computer Use.
28 | - **[2024/11/20]** We've added some examples to help you get hands-on experience with Claude 3.5 Computer Use.
29 | - **[2024/11/19]** Forget about the single-display limit set by Anthropic - you can now use **multiple displays** 🎉!
30 | - **[2024/11/18]** We've released a deep analysis of Claude 3.5 Computer Use: [https://arxiv.org/abs/2411.10323](https://arxiv.org/abs/2411.10323).
31 | - **[2024/11/11]** Forget about the low-resolution display limit set by Anthropic — you can now use *any resolution you like* and still keep the **screenshot token cost low** 🎉!
32 | - **[2024/11/11]** Now both **Windows** and **macOS** platforms are supported 🎉!
33 | - **[2024/10/25]** Now you can **Remotely Control** your computer 💻 through your mobile device 📱 — **No Mobile App Installation** required! Give it a try and have fun 🎉.
34 |
35 |
36 | ## Demo Video
37 |
38 | https://github.com/user-attachments/assets/f50b7611-2350-4712-af9e-3d31e30020ee
39 |
40 |
48 |
49 |
50 | ## 🚀 Getting Started
51 |
52 | ### 0. Prerequisites
53 | - Install Miniconda on your system through this [link](https://www.anaconda.com/download?utm_source=anacondadocs&utm_medium=documentation&utm_campaign=download&utm_content=topnavalldocs). (**Python Version: >= 3.12**).
54 | - Hardware Requirements (optional, for ShowUI local-run):
55 | - **Windows (CUDA-enabled):** A compatible NVIDIA GPU with CUDA support, >=6GB GPU memory
56 | - **macOS (Apple Silicon):** M1 chip (or newer), >=16GB unified RAM
57 |
58 |
59 | ### 1. Clone the Repository 📂
60 | Open the Conda Terminal. (After installing Miniconda, it will appear in the Start menu.)
61 | Run the following commands in the **Conda Terminal**.
62 | ```bash
63 | git clone https://github.com/showlab/computer_use_ootb.git
64 | cd computer_use_ootb
65 | ```
66 |
67 | ### 2.1 Install Dependencies 🔧
68 | ```bash
69 | pip install -r requirements.txt
70 | ```
71 |
72 | ### 2.2 (Optional) Get Prepared for **ShowUI** Local-Run
73 |
74 | 1. Download all files of the ShowUI-2B model via the following command. Ensure the `ShowUI-2B` folder is under the `computer_use_ootb` folder.
75 |
76 | ```bash
77 | python install_tools/install_showui.py
78 | ```
79 |
80 | 2. Make sure to install the correct GPU version of PyTorch (CUDA, MPS, etc.) on your machine. See [install guide and verification](https://pytorch.org/get-started/locally/) (a quick check is also sketched after this list).
81 |
82 | 3. Get API keys for [GPT-4o](https://platform.openai.com/docs/quickstart) or [Qwen-VL](https://help.aliyun.com/zh/dashscope/developer-reference/acquisition-and-configuration-of-api-key). For mainland China users, a Qwen API free trial for the first 1 million tokens is [available](https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-qianwen-vl-plus-api).
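
To quickly verify the PyTorch backend from step 2, the following one-liner (a convenience check, not part of the official guide) should print `True` for your accelerator: CUDA on NVIDIA GPUs, MPS on Apple Silicon.

```bash
# Prints whether PyTorch can see a CUDA device and/or the Apple MPS backend
python -c "import torch; print('cuda:', torch.cuda.is_available(), 'mps:', torch.backends.mps.is_available())"
```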
83 |
84 | ### 2.3 (Optional) Get Prepared for **UI-TARS** Local-Run
85 |
86 | 1. Follow [Cloud Deployment](https://github.com/bytedance/UI-TARS?tab=readme-ov-file#cloud-deployment) or [VLLM deployment](https://github.com/bytedance/UI-TARS?tab=readme-ov-file#local-deployment-vllm) guides to deploy your UI-TARS server.
87 |
88 | 2. Test your UI-TARS server with the script `.\install_tools\test_ui-tars_server.py` (a minimal manual check is also sketched below).
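
   If you prefer to check the endpoint by hand, the sketch below assumes your deployment exposes an OpenAI-compatible chat endpoint (as a vLLM deployment typically does); the host, port, and model name are placeholders, so adjust them to your setup:

   ```bash
   # Hypothetical smoke test against an OpenAI-compatible /v1/chat/completions endpoint
   curl http://127.0.0.1:8000/v1/chat/completions \
     -H "Content-Type: application/json" \
     -d '{"model": "ui-tars", "messages": [{"role": "user", "content": "Hello"}]}'
   ```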
89 |
90 | ### 2.4 (Optional) Deploy a Qwen Model as the Planner on an SSH Server
91 | 1. Clone this project onto your SSH server: `git clone https://github.com/showlab/computer_use_ootb.git`
92 | 
93 | 2. Start the remote inference service on the server: `python computer_use_demo/remote_inference.py`. Then connect OOTB on your local machine to it (see the sketch below).
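
A minimal sketch of connecting from your local machine, assuming the remote inference service listens on port 8000 (the user, host, and port below are placeholders; adjust them to your deployment). You can either enter the server's `host:port` directly in the Planner API Key field when an `(ssh)` planner model is selected, or forward the port to your local machine first:

```bash
# Hypothetical port forward; afterwards enter localhost:8000 in the Planner API Key field
ssh -N -L 8000:localhost:8000 user@your-ssh-server
```
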
94 | ### 3. Start the Interface ▶️
95 |
96 | **Start the OOTB interface:**
97 | ```bash
98 | python app.py
99 | ```
100 | If you successfully start the interface, you will see two URLs in the terminal:
101 | ```bash
102 | * Running on local URL: http://127.0.0.1:7860
103 | * Running on public URL: https://xxxxxxxxxxxxxxxx.gradio.live (Do not share this link with others, or they will be able to control your computer.)
104 | ```
105 |
106 |
107 | > For convenience, we recommend running one or more of the following commands to set the API keys as environment variables before starting the interface, so you don't need to enter the keys manually on each run. On Windows PowerShell (use the `set` command instead if you are on cmd):
108 | > ```bash
109 | > $env:ANTHROPIC_API_KEY="sk-xxxxx"  # Replace with your own key
110 | > $env:QWEN_API_KEY="sk-xxxxx"
111 | > $env:OPENAI_API_KEY="sk-xxxxx"
112 | > ```
113 | > On macOS/Linux, replace `$env:ANTHROPIC_API_KEY` with `export ANTHROPIC_API_KEY` in the above commands, for example:
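> ```bash
> export ANTHROPIC_API_KEY="sk-xxxxx"  # Replace with your own key
> export QWEN_API_KEY="sk-xxxxx"
> export OPENAI_API_KEY="sk-xxxxx"
> ```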
114 |
115 |
116 | ### 4. Control Your Computer from Any Device with Internet Access
117 | - **Computer to be controlled**: The machine on which the software is installed.
118 | - **Device sending commands**: The device that opens the web interface.
119 |
120 | Open the website at http://localhost:7860/ (if you're controlling the computer itself) or https://xxxxxxxxxxxxxxxxx.gradio.live in your mobile browser for remote control.
121 |
122 | Enter the Anthropic API key (you can obtain it through this [website](https://console.anthropic.com/settings/keys)), then give commands to let the AI perform your tasks.
123 |
124 | ### ShowUI Advanced Settings
125 |
126 | We provide a 4-bit quantized ShowUI-2B model for cost-efficient inference (currently **only supports CUDA devices**). To download the 4-bit quantized ShowUI-2B model:
127 | ```bash
128 | python install_tools/install_showui-awq-4bit.py
129 | ```
130 | Then, enable the quantized setting in the 'ShowUI Advanced Settings' dropdown menu.
131 |
132 | We also provide a slider to quickly adjust the `max_pixels` parameter of the ShowUI model. This controls the model's visual input size and strongly affects memory usage and inference speed.
133 |
134 | ## 📊 GUI Agent Model Zoo
135 |
136 | Now, OOTB supports customizing the GUI Agent via the following models:
137 |
138 | - **Unified Model**: A unified planner & actor that handles both high-level planning and low-level control.
139 | - **Planner**: General-purpose LLMs that handle high-level planning and decision-making.
140 | - **Actor**: Vision-language-action models that handle low-level control and action command generation.
141 |
142 |
143 |
203 |
204 | > [API] models call remotely hosted LLMs through an API,
205 | and [Local] models run inference on your own device with no API costs.
206 |
207 |
208 |
209 | ## 🖥️ Supported Systems
210 | - **Windows** (Claude ✅, ShowUI ✅)
211 | - **macOS** (Claude ✅, ShowUI ✅)
212 |
213 | ## 👓 OOTB Interface
214 |
215 |
216 |
217 |
218 |
219 |
220 |
221 | ## ⚠️ Risks
222 | - **Potentially Dangerous Operations by the Model**: The models' capabilities are still limited and they may generate unintended or potentially harmful outputs. We recommend continuously monitoring the AI's actions.
223 | - **Cost Control**: Each task may cost a few dollars for Claude 3.5 Computer Use.💸
224 |
225 | ## 📅 Roadmap
226 | - [ ] **Explore available features**
227 | - [ ] The Claude API seems to be unstable when solving tasks. We are investigating possible causes: resolution, types of actions required, OS platforms, or planning mechanisms. Any thoughts or comments are welcome.
228 | - [ ] **Interface Design**
229 | - [x] **Support for Gradio** ✨
230 | - [ ] **Simpler Installation**
231 | - [ ] **More Features**... 🚀
232 | - [ ] **Platform**
233 | - [x] **Windows**
234 | - [x] **macOS**
235 | - [x] **Mobile** (Send command)
236 | - [ ] **Mobile** (Be controlled)
237 | - [ ] **Support for More MLLMs**
238 | - [x] **Claude 3.5 Sonnet** 🎵
239 | - [x] **GPT-4o**
240 | - [x] **Qwen2-VL**
241 | - [ ] **Local MLLMs**
242 | - [ ] ...
243 | - [ ] **Improved Prompting Strategy**
244 | - [ ] Optimize prompts for cost-efficiency. 💡
245 | - [x] **Improved Inference Speed**
246 | - [x] Support int4 Quantization.
247 |
248 | ## Join Discussion
249 | You are welcome to join the discussion with us and help continuously improve the user experience of Computer Use OOTB. Reach us via this [**Discord Channel**](https://discord.gg/vMMJTSew37) or the WeChat QR code below!
250 |
251 |
252 |
253 |
254 |
255 |
256 |
257 |
258 |
259 |
260 |
261 |
262 |
263 |
264 |
265 |
266 |
267 |
--------------------------------------------------------------------------------
/app.py:
--------------------------------------------------------------------------------
1 | """
2 | Entrypoint for Gradio, see https://gradio.app/
3 | """
4 |
5 | import platform
6 | import asyncio
7 | import base64
8 | import os
9 | import io
10 | import json
11 | from datetime import datetime
12 | from enum import StrEnum
13 | from functools import partial
14 | from pathlib import Path
15 | from typing import cast, Dict
16 | from PIL import Image
17 |
18 | import gradio as gr
19 | from anthropic import APIResponse
20 | from anthropic.types import TextBlock
21 | from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock
22 | from anthropic.types.tool_use_block import ToolUseBlock
23 |
24 | from screeninfo import get_monitors
25 | from computer_use_demo.tools.logger import logger, truncate_string
26 |
27 | logger.info("Starting the gradio app")
28 |
29 | screens = get_monitors()
30 | logger.info(f"Found {len(screens)} screens")
31 |
32 | from computer_use_demo.loop import APIProvider, sampling_loop_sync
33 |
34 | from computer_use_demo.tools import ToolResult
35 | from computer_use_demo.tools.computer import get_screen_details
36 | SCREEN_NAMES, SELECTED_SCREEN_INDEX = get_screen_details()
37 |
38 | API_KEY_FILE = "./api_keys.json"
39 |
40 | WARNING_TEXT = "⚠️ Security Alert: Do not provide access to sensitive accounts or data, as malicious web content can hijack the Agent's behavior. Keep monitoring the Agent's actions."
41 |
42 |
43 | def setup_state(state):
44 |
45 | if "messages" not in state:
46 | state["messages"] = []
47 | # -------------------------------
48 | if "planner_model" not in state:
49 | state["planner_model"] = "gpt-4o" # default
50 | if "actor_model" not in state:
51 | state["actor_model"] = "ShowUI" # default
52 | if "planner_provider" not in state:
53 | state["planner_provider"] = "openai" # default
54 | if "actor_provider" not in state:
55 | state["actor_provider"] = "local" # default
56 |
57 | # Fetch API keys from environment variables
58 | if "openai_api_key" not in state:
59 | state["openai_api_key"] = os.getenv("OPENAI_API_KEY", "")
60 | if "anthropic_api_key" not in state:
61 | state["anthropic_api_key"] = os.getenv("ANTHROPIC_API_KEY", "")
62 | if "qwen_api_key" not in state:
63 | state["qwen_api_key"] = os.getenv("QWEN_API_KEY", "")
64 | if "ui_tars_url" not in state:
65 | state["ui_tars_url"] = ""
66 |
67 | # Set the initial api_key based on the provider
68 | if "planner_api_key" not in state:
69 | if state["planner_provider"] == "openai":
70 | state["planner_api_key"] = state["openai_api_key"]
71 | elif state["planner_provider"] == "anthropic":
72 | state["planner_api_key"] = state["anthropic_api_key"]
73 | elif state["planner_provider"] == "qwen":
74 | state["planner_api_key"] = state["qwen_api_key"]
75 | else:
76 | state["planner_api_key"] = ""
77 |
78 | logger.info(f"loaded initial api_key for {state['planner_provider']}: {state['planner_api_key']}")
79 |
80 | if not state["planner_api_key"]:
81 | logger.warning("Planner API key not found. Please set it in the environment or paste in textbox.")
82 |
83 |
84 | if "selected_screen" not in state:
85 | state['selected_screen'] = SELECTED_SCREEN_INDEX if SCREEN_NAMES else 0
86 |
87 | if "auth_validated" not in state:
88 | state["auth_validated"] = False
89 | if "responses" not in state:
90 | state["responses"] = {}
91 | if "tools" not in state:
92 | state["tools"] = {}
93 | if "only_n_most_recent_images" not in state:
94 | state["only_n_most_recent_images"] = 10 # 10
95 | if "custom_system_prompt" not in state:
96 | state["custom_system_prompt"] = ""
97 | # remove if want to use default system prompt
98 | device_os_name = "Windows" if platform.system() == "Windows" else "Mac" if platform.system() == "Darwin" else "Linux"
99 | state["custom_system_prompt"] += f"\n\nNOTE: you are operating a {device_os_name} machine"
100 | if "hide_images" not in state:
101 | state["hide_images"] = False
102 | if 'chatbot_messages' not in state:
103 | state['chatbot_messages'] = []
104 |
105 | if "showui_config" not in state:
106 | state["showui_config"] = "Default"
107 | if "max_pixels" not in state:
108 | state["max_pixels"] = 1344
109 | if "awq_4bit" not in state:
110 | state["awq_4bit"] = False
111 |
112 |
113 | async def main(state):
114 | """Render loop for Gradio"""
115 | setup_state(state)
116 | return "Setup completed"
117 |
118 |
119 | def validate_auth(provider: APIProvider, api_key: str | None):
120 | if provider == APIProvider.ANTHROPIC:
121 | if not api_key:
122 | return "Enter your Anthropic API key to continue."
123 | if provider == APIProvider.BEDROCK:
124 | import boto3
125 |
126 | if not boto3.Session().get_credentials():
127 | return "You must have AWS credentials set up to use the Bedrock API."
128 | if provider == APIProvider.VERTEX:
129 | import google.auth
130 | from google.auth.exceptions import DefaultCredentialsError
131 |
132 | if not os.environ.get("CLOUD_ML_REGION"):
133 | return "Set the CLOUD_ML_REGION environment variable to use the Vertex API."
134 | try:
135 | google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"])
136 | except DefaultCredentialsError:
137 | return "Your google cloud credentials are not set up correctly."
138 |
139 |
140 | def _api_response_callback(response: APIResponse[BetaMessage], response_state: dict):
141 | response_id = datetime.now().isoformat()
142 | response_state[response_id] = response
143 |
144 |
145 | def _tool_output_callback(tool_output: ToolResult, tool_id: str, tool_state: dict):
146 | tool_state[tool_id] = tool_output
147 |
148 |
149 | def chatbot_output_callback(message, chatbot_state, hide_images=False, sender="bot"):
150 |
151 | def _render_message(message: str | BetaTextBlock | BetaToolUseBlock | ToolResult, hide_images=False):
152 |
153 | logger.info(f"_render_message: {str(message)[:100]}")
154 |
155 | if isinstance(message, str):
156 | return message
157 |
158 | is_tool_result = not isinstance(message, str) and (
159 | isinstance(message, ToolResult)
160 | or message.__class__.__name__ == "ToolResult"
161 | or message.__class__.__name__ == "CLIResult"
162 | )
163 | if not message or (
164 | is_tool_result
165 | and hide_images
166 | and not hasattr(message, "error")
167 | and not hasattr(message, "output")
168 | ): # return None if hide_images is True
169 | return
170 | # render tool result
171 | if is_tool_result:
172 | message = cast(ToolResult, message)
173 | if message.output:
174 | return message.output
175 | if message.error:
176 | return f"Error: {message.error}"
177 | if message.base64_image and not hide_images:
178 | # somehow can't display via gr.Image
179 | # image_data = base64.b64decode(message.base64_image)
180 | # return gr.Image(value=Image.open(io.BytesIO(image_data)))
181 |                 return f'<img src="data:image/png;base64,{message.base64_image}">'  # render the screenshot as an inline HTML image
182 |
183 | elif isinstance(message, BetaTextBlock) or isinstance(message, TextBlock):
184 | return message.text
185 | elif isinstance(message, BetaToolUseBlock) or isinstance(message, ToolUseBlock):
186 | return f"Tool Use: {message.name}\nInput: {message.input}"
187 | else:
188 | return message
189 |
190 |
191 | # processing Anthropic messages
192 | message = _render_message(message, hide_images)
193 |
194 | if sender == "bot":
195 | chatbot_state.append((None, message))
196 | else:
197 | chatbot_state.append((message, None))
198 |
199 | # Create a concise version of the chatbot state for logging
200 | concise_state = [(truncate_string(user_msg), truncate_string(bot_msg)) for user_msg, bot_msg in chatbot_state]
201 | logger.info(f"chatbot_output_callback chatbot_state: {concise_state} (truncated)")
202 |
203 |
204 | def process_input(user_input, state):
205 |
206 | setup_state(state)
207 |
208 | # Append the user message to state["messages"]
209 | state["messages"].append(
210 | {
211 | "role": "user",
212 | "content": [TextBlock(type="text", text=user_input)],
213 | }
214 | )
215 |
216 | # Append the user's message to chatbot_messages with None for the assistant's reply
217 | state['chatbot_messages'].append((user_input, None))
218 | yield state['chatbot_messages'] # Yield to update the chatbot UI with the user's message
219 |
220 | # Run sampling_loop_sync with the chatbot_output_callback
221 | for loop_msg in sampling_loop_sync(
222 | system_prompt_suffix=state["custom_system_prompt"],
223 | planner_model=state["planner_model"],
224 | planner_provider=state["planner_provider"],
225 | actor_model=state["actor_model"],
226 | actor_provider=state["actor_provider"],
227 | messages=state["messages"],
228 | output_callback=partial(chatbot_output_callback, chatbot_state=state['chatbot_messages'], hide_images=state["hide_images"]),
229 | tool_output_callback=partial(_tool_output_callback, tool_state=state["tools"]),
230 | api_response_callback=partial(_api_response_callback, response_state=state["responses"]),
231 | api_key=state["planner_api_key"],
232 | only_n_most_recent_images=state["only_n_most_recent_images"],
233 | selected_screen=state['selected_screen'],
234 | showui_max_pixels=state['max_pixels'],
235 | showui_awq_4bit=state['awq_4bit']
236 | ):
237 | if loop_msg is None:
238 | yield state['chatbot_messages']
239 | logger.info("End of task. Close the loop.")
240 | break
241 |
242 |
243 | yield state['chatbot_messages'] # Yield the updated chatbot_messages to update the chatbot UI
244 |
245 |
246 | with gr.Blocks(theme=gr.themes.Soft()) as demo:
247 |
248 | state = gr.State({}) # Use Gradio's state management
249 | setup_state(state.value) # Initialize the state
250 |
251 | # Retrieve screen details
252 | gr.Markdown("# Computer Use OOTB")
253 |
254 | if not os.getenv("HIDE_WARNING", False):
255 | gr.Markdown(WARNING_TEXT)
256 |
257 | with gr.Accordion("Settings", open=True):
258 | with gr.Row():
259 | with gr.Column():
260 | # --------------------------
261 | # Planner
262 | planner_model = gr.Dropdown(
263 | label="Planner Model",
264 | choices=["gpt-4o",
265 | "gpt-4o-mini",
266 | "qwen2-vl-max",
267 | "qwen2-vl-2b (local)",
268 | "qwen2-vl-7b (local)",
269 | "qwen2-vl-2b (ssh)",
270 | "qwen2-vl-7b (ssh)",
271 | "qwen2.5-vl-7b (ssh)",
272 | "claude-3-5-sonnet-20241022"],
273 | value="gpt-4o",
274 | interactive=True,
275 | )
276 | with gr.Column():
277 | planner_api_provider = gr.Dropdown(
278 | label="API Provider",
279 | choices=[option.value for option in APIProvider],
280 | value="openai",
281 | interactive=False,
282 | )
283 | with gr.Column():
284 | planner_api_key = gr.Textbox(
285 | label="Planner API Key",
286 | type="password",
287 | value=state.value.get("planner_api_key", ""),
288 | placeholder="Paste your planner model API key",
289 | interactive=True,
290 | )
291 |
292 | with gr.Column():
293 | actor_model = gr.Dropdown(
294 | label="Actor Model",
295 | choices=["ShowUI", "UI-TARS"],
296 | value="ShowUI",
297 | interactive=True,
298 | )
299 |
300 | with gr.Column():
301 | custom_prompt = gr.Textbox(
302 | label="System Prompt Suffix",
303 | value="",
304 | interactive=True,
305 | )
306 | with gr.Column():
307 | screen_options, primary_index = get_screen_details()
308 | SCREEN_NAMES = screen_options
309 | SELECTED_SCREEN_INDEX = primary_index
310 | screen_selector = gr.Dropdown(
311 | label="Select Screen",
312 | choices=screen_options,
313 | value=screen_options[primary_index] if screen_options else None,
314 | interactive=True,
315 | )
316 | with gr.Column():
317 | only_n_images = gr.Slider(
318 | label="N most recent screenshots",
319 | minimum=0,
320 | maximum=10,
321 | step=1,
322 | value=2,
323 | interactive=True,
324 | )
325 |
326 | with gr.Accordion("ShowUI Advanced Settings", open=False):
327 |
328 | gr.Markdown("""
329 | **Note:** Adjust these settings to fine-tune the resource (**memory** and **infer time**) and performance trade-offs of ShowUI. \\
330 | Quantization model requires additional download. Please refer to [Computer Use OOTB - #ShowUI Advanced Settings guide](https://github.com/showlab/computer_use_ootb?tab=readme-ov-file#showui-advanced-settings) for preparation for this feature.
331 | """)
332 |
333 | # New configuration for ShowUI
334 | with gr.Row():
335 | with gr.Column():
336 | showui_config = gr.Dropdown(
337 | label="ShowUI Preset Configuration",
338 | choices=["Default (Maximum)", "Medium", "Minimal", "Custom"],
339 | value="Default (Maximum)",
340 | interactive=True,
341 | )
342 | with gr.Column():
343 | max_pixels = gr.Slider(
344 | label="Max Visual Tokens",
345 | minimum=720,
346 | maximum=1344,
347 | step=16,
348 | value=1344,
349 | interactive=False,
350 | )
351 | with gr.Column():
352 | awq_4bit = gr.Checkbox(
353 | label="Enable AWQ-4bit Model",
354 | value=False,
355 | interactive=False
356 | )
357 |
358 | # Define the merged dictionary with task mappings
359 | merged_dict = json.load(open("assets/examples/ootb_examples.json", "r"))
360 |
361 | def update_only_n_images(only_n_images_value, state):
362 | state["only_n_most_recent_images"] = only_n_images_value
363 |
364 | # Callback to update the second dropdown based on the first selection
365 | def update_second_menu(selected_category):
366 | return gr.update(choices=list(merged_dict.get(selected_category, {}).keys()))
367 |
368 | # Callback to update the third dropdown based on the second selection
369 | def update_third_menu(selected_category, selected_option):
370 | return gr.update(choices=list(merged_dict.get(selected_category, {}).get(selected_option, {}).keys()))
371 |
372 | # Callback to update the textbox based on the third selection
373 | def update_textbox(selected_category, selected_option, selected_task):
374 | task_data = merged_dict.get(selected_category, {}).get(selected_option, {}).get(selected_task, {})
375 | prompt = task_data.get("prompt", "")
376 | preview_image = task_data.get("initial_state", "")
377 | task_hint = "Task Hint: " + task_data.get("hint", "")
378 | return prompt, preview_image, task_hint
379 |
380 | # Function to update the global variable when the dropdown changes
381 | def update_selected_screen(selected_screen_name, state):
382 | global SCREEN_NAMES
383 | global SELECTED_SCREEN_INDEX
384 | SELECTED_SCREEN_INDEX = SCREEN_NAMES.index(selected_screen_name)
385 | logger.info(f"Selected screen updated to: {SELECTED_SCREEN_INDEX}")
386 | state['selected_screen'] = SELECTED_SCREEN_INDEX
387 |
388 |
389 | def update_planner_model(model_selection, state):
390 | state["model"] = model_selection
391 | # Update planner_model
392 | state["planner_model"] = model_selection
393 | logger.info(f"Model updated to: {state['planner_model']}")
394 |
395 | if model_selection == "qwen2-vl-max":
396 | provider_choices = ["qwen"]
397 | provider_value = "qwen"
398 | provider_interactive = False
399 | api_key_interactive = True
400 | api_key_placeholder = "qwen API key"
401 | actor_model_choices = ["ShowUI", "UI-TARS"]
402 | actor_model_value = "ShowUI"
403 | actor_model_interactive = True
404 | api_key_type = "password" # Display API key in password form
405 |
406 | elif model_selection in ["qwen2-vl-2b (local)", "qwen2-vl-7b (local)"]:
407 | # Set provider to "openai", make it unchangeable
408 | provider_choices = ["local"]
409 | provider_value = "local"
410 | provider_interactive = False
411 | api_key_interactive = False
412 | api_key_placeholder = "not required"
413 | actor_model_choices = ["ShowUI", "UI-TARS"]
414 | actor_model_value = "ShowUI"
415 | actor_model_interactive = True
416 | api_key_type = "password" # Maintain consistency
417 |
418 | elif "ssh" in model_selection:
419 | provider_choices = ["ssh"]
420 | provider_value = "ssh"
421 | provider_interactive = False
422 | api_key_interactive = True
423 | api_key_placeholder = "ssh host and port (e.g. localhost:8000)"
424 | actor_model_choices = ["ShowUI", "UI-TARS"]
425 | actor_model_value = "ShowUI"
426 | actor_model_interactive = True
427 | api_key_type = "text" # Display SSH connection info in plain text
428 | # If SSH connection info already exists, keep it
429 | if "planner_api_key" in state and state["planner_api_key"]:
430 | state["api_key"] = state["planner_api_key"]
431 | else:
432 | state["api_key"] = ""
433 |
434 | elif model_selection == "gpt-4o" or model_selection == "gpt-4o-mini":
435 | # Set provider to "openai", make it unchangeable
436 | provider_choices = ["openai"]
437 | provider_value = "openai"
438 | provider_interactive = False
439 | api_key_interactive = True
440 | api_key_type = "password" # Display API key in password form
441 |
442 | api_key_placeholder = "openai API key"
443 | actor_model_choices = ["ShowUI", "UI-TARS"]
444 | actor_model_value = "ShowUI"
445 | actor_model_interactive = True
446 |
447 | elif model_selection == "claude-3-5-sonnet-20241022":
448 | # Provider can be any of the current choices except 'openai'
449 | provider_choices = [option.value for option in APIProvider if option.value != "openai"]
450 | provider_value = "anthropic" # Set default to 'anthropic'
451 | state['actor_provider'] = "anthropic"
452 | provider_interactive = True
453 | api_key_interactive = True
454 | api_key_placeholder = "claude API key"
455 | actor_model_choices = ["claude-3-5-sonnet-20241022"]
456 | actor_model_value = "claude-3-5-sonnet-20241022"
457 | actor_model_interactive = False
458 | api_key_type = "password" # Display API key in password form
459 |
460 | else:
461 | raise ValueError(f"Model {model_selection} not supported")
462 |
463 | # Update the provider in state
464 | state["planner_api_provider"] = provider_value
465 |
466 | # Update api_key in state based on the provider
467 | if provider_value == "openai":
468 | state["api_key"] = state.get("openai_api_key", "")
469 | elif provider_value == "anthropic":
470 | state["api_key"] = state.get("anthropic_api_key", "")
471 | elif provider_value == "qwen":
472 | state["api_key"] = state.get("qwen_api_key", "")
473 | elif provider_value == "local":
474 | state["api_key"] = ""
475 |         # The SSH case has already been handled above, so no need to handle it again here
476 |
477 | provider_update = gr.update(
478 | choices=provider_choices,
479 | value=provider_value,
480 | interactive=provider_interactive
481 | )
482 |
483 | # Update the API Key textbox
484 | api_key_update = gr.update(
485 | placeholder=api_key_placeholder,
486 | value=state["api_key"],
487 | interactive=api_key_interactive,
488 | type=api_key_type # 添加 type 参数的更新
489 | )
490 |
491 | actor_model_update = gr.update(
492 | choices=actor_model_choices,
493 | value=actor_model_value,
494 | interactive=actor_model_interactive
495 | )
496 |
497 | logger.info(f"Updated state: model={state['planner_model']}, provider={state['planner_api_provider']}, api_key={state['api_key']}")
498 | return provider_update, api_key_update, actor_model_update
499 |
500 | def update_actor_model(actor_model_selection, state):
501 | state["actor_model"] = actor_model_selection
502 | logger.info(f"Actor model updated to: {state['actor_model']}")
503 |
504 | def update_api_key_placeholder(provider_value, model_selection):
505 | if model_selection == "claude-3-5-sonnet-20241022":
506 | if provider_value == "anthropic":
507 | return gr.update(placeholder="anthropic API key")
508 | elif provider_value == "bedrock":
509 | return gr.update(placeholder="bedrock API key")
510 | elif provider_value == "vertex":
511 | return gr.update(placeholder="vertex API key")
512 | else:
513 | return gr.update(placeholder="")
514 | elif model_selection == "gpt-4o + ShowUI":
515 | return gr.update(placeholder="openai API key")
516 | else:
517 | return gr.update(placeholder="")
518 |
519 | def update_system_prompt_suffix(system_prompt_suffix, state):
520 | state["custom_system_prompt"] = system_prompt_suffix
521 |
522 | # When showui_config changes, we set the max_pixels and awq_4bit accordingly.
523 | def handle_showui_config_change(showui_config_val, state):
524 | if showui_config_val == "Default (Maximum)":
525 | state["max_pixels"] = 1344
526 | state["awq_4bit"] = False
527 | return (
528 | gr.update(value=1344, interactive=False),
529 | gr.update(value=False, interactive=False)
530 | )
531 | elif showui_config_val == "Medium":
532 | state["max_pixels"] = 1024
533 | state["awq_4bit"] = False
534 | return (
535 | gr.update(value=1024, interactive=False),
536 | gr.update(value=False, interactive=False)
537 | )
538 | elif showui_config_val == "Minimal":
539 | state["max_pixels"] = 1024
540 | state["awq_4bit"] = True
541 | return (
542 | gr.update(value=1024, interactive=False),
543 | gr.update(value=True, interactive=False)
544 | )
545 | elif showui_config_val == "Custom":
546 | # Do not overwrite the current user values, just make them interactive
547 | return (
548 | gr.update(interactive=True),
549 | gr.update(interactive=True)
550 | )
551 |
552 | def update_api_key(api_key_value, state):
553 | """Handle API key updates"""
554 | state["planner_api_key"] = api_key_value
555 | if state["planner_provider"] == "ssh":
556 | state["api_key"] = api_key_value
557 | logger.info(f"API key updated: provider={state['planner_provider']}, api_key={state['api_key']}")
558 |
559 |     with gr.Accordion("Quick Start Prompt", open=False):  # open=False collapses the accordion by default
560 | # Initialize Gradio interface with the dropdowns
561 | with gr.Row():
562 | # Set initial values
563 | initial_category = "Game Play"
564 | initial_second_options = list(merged_dict[initial_category].keys())
565 | initial_third_options = list(merged_dict[initial_category][initial_second_options[0]].keys())
566 | initial_text_value = merged_dict[initial_category][initial_second_options[0]][initial_third_options[0]]
567 |
568 | with gr.Column(scale=2):
569 | # First dropdown for Task Category
570 | first_menu = gr.Dropdown(
571 | choices=list(merged_dict.keys()), label="Task Category", interactive=True, value=initial_category
572 | )
573 |
574 | # Second dropdown for Software
575 | second_menu = gr.Dropdown(
576 | choices=initial_second_options, label="Software", interactive=True, value=initial_second_options[0]
577 | )
578 |
579 | # Third dropdown for Task
580 | third_menu = gr.Dropdown(
581 | choices=initial_third_options, label="Task", interactive=True, value=initial_third_options[0]
582 | # choices=["Please select a task"]+initial_third_options, label="Task", interactive=True, value="Please select a task"
583 | )
584 |
585 | with gr.Column(scale=1):
586 | initial_image_value = "./assets/examples/init_states/honkai_star_rail_showui.png" # default image path
587 | image_preview = gr.Image(value=initial_image_value, label="Reference Initial State", height=260-(318.75-280))
588 | hintbox = gr.Markdown("Task Hint: Selected options will appear here.")
589 |
590 | # Textbox for displaying the mapped value
591 | # textbox = gr.Textbox(value=initial_text_value, label="Action")
592 |
593 | # api_key.change(fn=lambda key: save_to_storage(API_KEY_FILE, key), inputs=api_key)
594 |
595 | with gr.Row():
596 | # submit_button = gr.Button("Submit") # Add submit button
597 | with gr.Column(scale=8):
598 | chat_input = gr.Textbox(show_label=False, placeholder="Type a message to send to Computer Use OOTB...", container=False)
599 | with gr.Column(scale=1, min_width=50):
600 | submit_button = gr.Button(value="Send", variant="primary")
601 |
602 | chatbot = gr.Chatbot(label="Chatbot History", type="tuples", autoscroll=True, height=580, group_consecutive_messages=False)
603 |
604 | planner_model.change(fn=update_planner_model, inputs=[planner_model, state], outputs=[planner_api_provider, planner_api_key, actor_model])
605 | planner_api_provider.change(fn=update_api_key_placeholder, inputs=[planner_api_provider, planner_model], outputs=planner_api_key)
606 | actor_model.change(fn=update_actor_model, inputs=[actor_model, state], outputs=None)
607 |
608 | screen_selector.change(fn=update_selected_screen, inputs=[screen_selector, state], outputs=None)
609 | only_n_images.change(fn=update_only_n_images, inputs=[only_n_images, state], outputs=None)
610 |
611 | # When showui_config changes, we update max_pixels and awq_4bit automatically.
612 | showui_config.change(fn=handle_showui_config_change,
613 | inputs=[showui_config, state],
614 | outputs=[max_pixels, awq_4bit])
615 |
616 | # Link callbacks to update dropdowns based on selections
617 | first_menu.change(fn=update_second_menu, inputs=first_menu, outputs=second_menu)
618 | second_menu.change(fn=update_third_menu, inputs=[first_menu, second_menu], outputs=third_menu)
619 | third_menu.change(fn=update_textbox, inputs=[first_menu, second_menu, third_menu], outputs=[chat_input, image_preview, hintbox])
620 |
621 | # chat_input.submit(process_input, [chat_input, state], chatbot)
622 | submit_button.click(process_input, [chat_input, state], chatbot)
623 |
624 | planner_api_key.change(
625 | fn=update_api_key,
626 | inputs=[planner_api_key, state],
627 | outputs=None
628 | )
629 |
630 | demo.launch(share=False,
631 | allowed_paths=["./"],
632 | server_port=7888) # TODO: allowed_paths
--------------------------------------------------------------------------------
/assets/Teaser.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/showlab/computer_use_ootb/21b05ae06de89700b998dc85d8ef64454f0d0291/assets/Teaser.gif
--------------------------------------------------------------------------------
/assets/examples/init_states/amazon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/showlab/computer_use_ootb/21b05ae06de89700b998dc85d8ef64454f0d0291/assets/examples/init_states/amazon.png
--------------------------------------------------------------------------------
/assets/examples/init_states/booking.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/showlab/computer_use_ootb/21b05ae06de89700b998dc85d8ef64454f0d0291/assets/examples/init_states/booking.png
--------------------------------------------------------------------------------
/assets/examples/init_states/honkai_star_rail.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/showlab/computer_use_ootb/21b05ae06de89700b998dc85d8ef64454f0d0291/assets/examples/init_states/honkai_star_rail.png
--------------------------------------------------------------------------------
/assets/examples/init_states/honkai_star_rail_showui.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/showlab/computer_use_ootb/21b05ae06de89700b998dc85d8ef64454f0d0291/assets/examples/init_states/honkai_star_rail_showui.png
--------------------------------------------------------------------------------
/assets/examples/init_states/ign.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/showlab/computer_use_ootb/21b05ae06de89700b998dc85d8ef64454f0d0291/assets/examples/init_states/ign.png
--------------------------------------------------------------------------------
/assets/examples/init_states/powerpoint.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/showlab/computer_use_ootb/21b05ae06de89700b998dc85d8ef64454f0d0291/assets/examples/init_states/powerpoint.png
--------------------------------------------------------------------------------
/assets/examples/init_states/powerpoint_homepage.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/showlab/computer_use_ootb/21b05ae06de89700b998dc85d8ef64454f0d0291/assets/examples/init_states/powerpoint_homepage.png
--------------------------------------------------------------------------------
/assets/examples/ootb_examples.json:
--------------------------------------------------------------------------------
1 | {
2 | "Web Navigation": {
3 | "Shopping": {
4 | "Search Gift Card": {
5 | "hint": "Search for 'You are Amazing' congrats gift card",
6 | "prompt": "Search for 'You are Amazing' congrats gift card",
7 | "initial_state": ".\\assets\\examples\\init_states\\amazon.png"
8 | },
9 | "Add Headphones": {
10 | "hint": "Add a set of wireless headphones to your cart",
11 | "prompt": "Add a set of wireless headphones to your cart",
12 | "initial_state": ".\\assets\\examples\\init_states\\amazon.png"
13 | }
14 | },
15 | "Accommodation": {
16 | "Find Private Room": {
17 | "hint": "Find a private room in New York",
18 | "prompt": "Find a private room in New York",
19 | "initial_state": ".\\assets\\examples\\init_states\\booking.png"
20 | }
21 | },
22 | "Gaming": {
23 | "Walk-through Guide": {
24 | "hint": "Find a walk-through guide for the game 'Black Myth: Wukong'",
25 | "prompt": "Find a walk-through guide for the game 'Black Myth: Wukong'",
26 | "initial_state": ".\\assets\\examples\\init_states\\ign.png"
27 | }
28 | }
29 | },
30 | "Productivity": {
31 | "Presentations": {
32 | "Create Presentation": {
33 | "hint": "Create a new presentation and set the title to 'Hail Computer Use OOTB!'",
34 | "prompt": "Create a new presentation and edit the title to 'Hail Computer Use OOTB!'",
35 | "initial_state": ".\\assets\\examples\\init_states\\powerpoint_homepage.png"
36 | },
37 | "Duplicate First Slide": {
38 | "hint": "Duplicate the first slide in PowerPoint",
39 | "prompt": "Duplicate the first slide in PowerPoint",
40 | "initial_state": ".\\assets\\examples\\init_states\\powerpoint.png"
41 | },
42 | "Insert Picture": {
43 | "hint": "Insert a picture from my device into the current slide, selecting the first image in the photo browser",
44 | "prompt": "Insert a picture from my device into the current slide",
45 | "initial_state": ".\\assets\\examples\\init_states\\powerpoint.png"
46 | },
47 | "Apply Morph Transition": {
48 | "hint": "Apply the Morph transition to all slides",
49 | "prompt": "Apply the Morph transition to all slides",
50 | "initial_state": ".\\assets\\examples\\init_states\\powerpoint.png"
51 | }
52 | }
53 | },
54 | "Game Play": {
55 | "Honkai: Star Rail": {
56 | "Daily Task (ShowUI)": {
57 | "hint": "Complete the daily task",
58 | "prompt": "1. Escape on the keyboard to open the menu. 2. Click 'Interastral Guide'. 3. Then click 'calyx golden for exp' entry. 4. Then click on the 'Teleport of Buds of MEMORIES'. 5. Press the 'bottom plus + button, the one below'. 6. Then click Challenge 7. Then click Start Challenge. 8. Then click on exit when the battle is completed.",
59 | "initial_state": ".\\assets\\examples\\init_states\\honkai_star_rail_showui.png"
60 | },
61 | "Daily Task (Claude 3.5 Computer Use)": {
62 | "hint": "Complete the daily task",
63 | "prompt": "You are currently playing Honkai: Star Rail, your objective is to finish a daily game task for me. Press escape on the keyboard to open the menu, then click interastral guide, then click 'calyx golden for exp' entry on the left side of the popped up game window. Only then click on the teleport button on the same line of the first entry named 'buds of MEMORIES' (you need to carefully check the name), then click 'plus +' button 5 times to increase attempts to 6, then click challenge, then click start challenge. Then click the auto-battle button at the right-up corner - carefully count from the right to the left, it should be the second icon, it is near the 'pause' icon, it looks like an 'infinite' symbol. Then click on exit when the battle is completed.",
64 | "initial_state": ".\\assets\\examples\\init_states\\honkai_star_rail.png"
65 | },
66 | "Warp": {
67 | "hint": "Perform a warp (gacha pull)",
68 | "prompt": "You are currently playing Honkai: Star Rail, your objective is to perform a 10-warp pull for me. Press escape on the keyboard to open the menu, then click warp. It should open the warp page, and the first entry on the left side would be 'Words of Yore', this would be the destination pool. Then click on 'warp x10' to perform a 10-warp pull, then click at the blank space at the right-up corner to reveal the arrow at the right-up corner, then click on the arrow to skip the animation. Always click on the arrow to continue skipping the animation if there is an arrow at the right-up corner. Only when all animations are skipped by clicking on the arrows, the pull summary page will appear and there would be a cross there, click on the cross to finish the pull. Good luck!",
69 | "initial_state": ".\\assets\\examples\\init_states\\honkai_star_rail.png"
70 | }
71 | }
72 | }
73 | }
74 |
--------------------------------------------------------------------------------
/assets/gradio_interface.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/showlab/computer_use_ootb/21b05ae06de89700b998dc85d8ef64454f0d0291/assets/gradio_interface.png
--------------------------------------------------------------------------------
/assets/ootb_icon.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/showlab/computer_use_ootb/21b05ae06de89700b998dc85d8ef64454f0d0291/assets/ootb_icon.png
--------------------------------------------------------------------------------
/assets/ootb_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/showlab/computer_use_ootb/21b05ae06de89700b998dc85d8ef64454f0d0291/assets/ootb_logo.png
--------------------------------------------------------------------------------
/assets/wechat_3.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/showlab/computer_use_ootb/21b05ae06de89700b998dc85d8ef64454f0d0291/assets/wechat_3.jpg
--------------------------------------------------------------------------------
/computer_use_demo/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/showlab/computer_use_ootb/21b05ae06de89700b998dc85d8ef64454f0d0291/computer_use_demo/__init__.py
--------------------------------------------------------------------------------
/computer_use_demo/executor/anthropic_executor.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | from typing import Any, Dict, cast
3 | from collections.abc import Callable
4 | from anthropic.types.beta import (
5 | BetaContentBlock,
6 | BetaContentBlockParam,
7 | BetaImageBlockParam,
8 | BetaMessage,
9 | BetaMessageParam,
10 | BetaTextBlockParam,
11 | BetaToolResultBlockParam,
12 | )
13 | from anthropic.types import TextBlock
14 | from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock
15 | from ..tools import BashTool, ComputerTool, EditTool, ToolCollection, ToolResult
16 |
17 |
18 | class AnthropicExecutor:
19 | def __init__(
20 | self,
21 | output_callback: Callable[[BetaContentBlockParam], None],
22 | tool_output_callback: Callable[[Any, str], None],
23 | selected_screen: int = 0
24 | ):
25 | self.tool_collection = ToolCollection(
26 | ComputerTool(selected_screen=selected_screen),
27 | BashTool(),
28 | EditTool(),
29 | )
30 | self.output_callback = output_callback
31 | self.tool_output_callback = tool_output_callback
32 |
33 | def __call__(self, response: BetaMessage, messages: list[BetaMessageParam]):
34 | new_message = {
35 | "role": "assistant",
36 | "content": cast(list[BetaContentBlockParam], response.content),
37 | }
38 | if new_message not in messages:
39 | messages.append(new_message)
40 | else:
41 | print("new_message already in messages, there are duplicates.")
42 |
43 | tool_result_content: list[BetaToolResultBlockParam] = []
44 | for content_block in cast(list[BetaContentBlock], response.content):
45 |
46 | self.output_callback(content_block, sender="bot")
47 | # Execute the tool
48 | if content_block.type == "tool_use":
49 | # Run the asynchronous tool execution in a synchronous context
50 | result = asyncio.run(self.tool_collection.run(
51 | name=content_block.name,
52 | tool_input=cast(dict[str, Any], content_block.input),
53 | ))
54 |
55 | self.output_callback(result, sender="bot")
56 |
57 | tool_result_content.append(
58 | _make_api_tool_result(result, content_block.id)
59 | )
60 | self.tool_output_callback(result, content_block.id)
61 |
62 | # Craft messages based on the content_block
63 |         # Note: to display messages in Gradio, organize them as (user message, bot message) pairs
64 |
65 | display_messages = _message_display_callback(messages)
66 | # display_messages = []
67 |
68 | # Send the messages to the gradio
69 | for user_msg, bot_msg in display_messages:
70 | yield [user_msg, bot_msg], tool_result_content
71 |
72 | if not tool_result_content:
73 | return messages
74 |
75 | return tool_result_content
76 |
77 | def _message_display_callback(messages):
78 | display_messages = []
79 | for msg in messages:
80 | try:
81 | if isinstance(msg["content"][0], TextBlock):
82 | display_messages.append((msg["content"][0].text, None)) # User message
83 | elif isinstance(msg["content"][0], BetaTextBlock):
84 | display_messages.append((None, msg["content"][0].text)) # Bot message
85 | elif isinstance(msg["content"][0], BetaToolUseBlock):
86 | display_messages.append((None, f"Tool Use: {msg['content'][0].name}\nInput: {msg['content'][0].input}")) # Bot message
87 | elif isinstance(msg["content"][0], Dict) and msg["content"][0]["content"][-1]["type"] == "image":
88 | display_messages.append((None, f'')) # Bot message
89 | else:
90 | print(msg["content"][0])
91 | except Exception as e:
92 | print("error", e)
93 | pass
94 | return display_messages
95 |
96 | def _make_api_tool_result(
97 | result: ToolResult, tool_use_id: str
98 | ) -> BetaToolResultBlockParam:
99 | """Convert an agent ToolResult to an API ToolResultBlockParam."""
100 | tool_result_content: list[BetaTextBlockParam | BetaImageBlockParam] | str = []
101 | is_error = False
102 | if result.error:
103 | is_error = True
104 | tool_result_content = _maybe_prepend_system_tool_result(result, result.error)
105 | else:
106 | if result.output:
107 | tool_result_content.append(
108 | {
109 | "type": "text",
110 | "text": _maybe_prepend_system_tool_result(result, result.output),
111 | }
112 | )
113 | if result.base64_image:
114 | tool_result_content.append(
115 | {
116 | "type": "image",
117 | "source": {
118 | "type": "base64",
119 | "media_type": "image/png",
120 | "data": result.base64_image,
121 | },
122 | }
123 | )
124 | return {
125 | "type": "tool_result",
126 | "content": tool_result_content,
127 | "tool_use_id": tool_use_id,
128 | "is_error": is_error,
129 | }
130 |
131 |
132 | def _maybe_prepend_system_tool_result(result: ToolResult, result_text: str):
133 | if result.system:
134 | result_text = f"{result.system}\n{result_text}"
135 | return result_text
--------------------------------------------------------------------------------
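For reference, a minimal sketch of the tool-result shape that `_make_api_tool_result` above produces for a successful, text-only result. The `FakeToolResult` dataclass and the `toolu_123` id are stand-ins invented for this sketch; only the fields the function actually reads are modeled.

```python
from dataclasses import dataclass

# Hypothetical stand-in for computer_use_demo.tools.ToolResult; only the fields
# read by _make_api_tool_result are modeled here.
@dataclass
class FakeToolResult:
    output: str | None = None
    error: str | None = None
    base64_image: str | None = None
    system: str | None = None

def make_api_tool_result_sketch(result: FakeToolResult, tool_use_id: str) -> dict:
    # Mirrors the logic of _make_api_tool_result above, trimmed to the text-only path.
    if result.error:
        text = f"{result.system}\n{result.error}" if result.system else result.error
        return {"type": "tool_result", "content": text,
                "tool_use_id": tool_use_id, "is_error": True}
    content = []
    if result.output:
        text = f"{result.system}\n{result.output}" if result.system else result.output
        content.append({"type": "text", "text": text})
    return {"type": "tool_result", "content": content,
            "tool_use_id": tool_use_id, "is_error": False}

print(make_api_tool_result_sketch(FakeToolResult(output="Moved mouse to (1216, 88)"), "toolu_123"))
# {'type': 'tool_result', 'content': [{'type': 'text', 'text': 'Moved mouse to (1216, 88)'}],
#  'tool_use_id': 'toolu_123', 'is_error': False}
```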
/computer_use_demo/executor/showui_executor.py:
--------------------------------------------------------------------------------
1 | import ast
2 | import asyncio
3 | from typing import Any, Dict, cast, List, Union
4 | from collections.abc import Callable
5 | import uuid
6 | from anthropic.types.beta import (
7 | BetaContentBlock,
8 | BetaContentBlockParam,
9 | BetaImageBlockParam,
10 | BetaMessage,
11 | BetaMessageParam,
12 | BetaTextBlockParam,
13 | BetaToolResultBlockParam,
14 | )
15 | from anthropic.types import TextBlock
16 | from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock
17 | from computer_use_demo.tools import BashTool, ComputerTool, EditTool, ToolCollection, ToolResult
18 | from computer_use_demo.tools.colorful_text import colorful_text_showui, colorful_text_vlm
19 |
20 |
21 | class ShowUIExecutor:
22 | def __init__(
23 | self,
24 | output_callback: Callable[[BetaContentBlockParam], None],
25 | tool_output_callback: Callable[[Any, str], None],
26 | selected_screen: int = 0
27 | ):
28 | self.output_callback = output_callback
29 | self.tool_output_callback = tool_output_callback
30 | self.selected_screen = selected_screen
31 | self.screen_bbox = self._get_screen_resolution()
32 | print("Screen BBox:", self.screen_bbox)
33 |
34 | self.tool_collection = ToolCollection(
35 | ComputerTool(selected_screen=selected_screen, is_scaling=False)
36 | )
37 |
38 |         self.supported_action_type = {
39 |             # maps each ShowUI action to the computer-tool action(s) it is expanded into below
40 |             "CLICK": "mouse_move + left_click",
41 |             "INPUT": "type",
42 |             "ENTER": "key", "ESC": "key", "ESCAPE": "key",
43 |             "HOVER": "mouse_move",
44 |             "SCROLL": "key",
45 |             "PRESS": "mouse_move + left_press",
46 |         }
47 |
48 | def __call__(self, response: str, messages: list[BetaMessageParam]):
49 | # response is expected to be :
50 | # {'content': "{'action': 'CLICK', 'value': None, 'position': [0.83, 0.15]}, ...", 'role': 'assistant'},
51 |
52 | action_dict = self._format_actor_output(response) # str -> dict
53 |
54 | actions = action_dict["content"]
55 | role = action_dict["role"]
56 |
57 | # Parse the actions from showui
58 | action_list = self._parse_showui_output(actions)
59 | print("Parsed Action List:", action_list)
60 |
61 | tool_result_content = None
62 |
63 | if action_list is not None and len(action_list) > 0:
64 |
65 | for action in action_list: # Execute the tool (adapting the code from anthropic_executor.py)
66 |
67 | tool_result_content: list[BetaToolResultBlockParam] = []
68 |
69 | self.output_callback(f"{colorful_text_showui}:\n{action}", sender="bot")
70 | print("Converted Action:", action)
71 |
72 | sim_content_block = BetaToolUseBlock(id=f'toolu_{uuid.uuid4()}',
73 | input={'action': action["action"], 'text': action["text"], 'coordinate': action["coordinate"]},
74 | name='computer', type='tool_use')
75 |
76 | # update messages
77 | new_message = {
78 | "role": "assistant",
79 | "content": cast(list[BetaContentBlockParam], [sim_content_block]),
80 | }
81 | if new_message not in messages:
82 | messages.append(new_message)
83 |
84 | # Run the asynchronous tool execution in a synchronous context
85 | result = self.tool_collection.sync_call(
86 | name=sim_content_block.name,
87 | tool_input=cast(dict[str, Any], sim_content_block.input),
88 | )
89 |
90 | tool_result_content.append(
91 | _make_api_tool_result(result, sim_content_block.id)
92 | )
93 | # print(f"executor: tool_result_content: {tool_result_content}")
94 | self.tool_output_callback(result, sim_content_block.id)
95 |
96 | # Craft messages based on the content_block
97 |             # Note: to display messages in Gradio, organize them as (user message, bot message) pairs
98 | display_messages = _message_display_callback(messages)
99 | # Send the messages to the gradio
100 | for user_msg, bot_msg in display_messages:
101 | yield [user_msg, bot_msg], tool_result_content
102 |
103 | return tool_result_content
104 |
105 |
106 |     def _format_actor_output(self, action_output: str | dict) -> Dict[str, Any]:
107 |         if isinstance(action_output, dict):
108 |             return action_output
109 |         else:
110 |             try:
111 |                 # ast.literal_eval parses the single-quoted dict string directly
112 |                 action_dict = ast.literal_eval(action_output)
113 |                 return action_dict
114 |             except Exception as e:
115 |                 print(f"Error parsing action output: {e}")
116 |                 return None
117 |
118 |
119 | def _parse_showui_output(self, output_text: str) -> Union[List[Dict[str, Any]], None]:
120 | try:
121 | output_text = output_text.strip()
122 |
123 | # process single dictionary
124 | if output_text.startswith("{") and output_text.endswith("}"):
125 | output_text = f"[{output_text}]"
126 |
127 | # Validate if the output resembles a list of dictionaries
128 | if not (output_text.startswith("[") and output_text.endswith("]")):
129 | raise ValueError("Output does not look like a valid list or dictionary.")
130 |
131 | print("Output Text:", output_text)
132 |
133 | parsed_output = ast.literal_eval(output_text)
134 |
135 | print("Parsed Output:", parsed_output)
136 |
137 | if isinstance(parsed_output, dict):
138 | parsed_output = [parsed_output]
139 | elif not isinstance(parsed_output, list):
140 | raise ValueError("Parsed output is neither a dictionary nor a list.")
141 |
142 | if not all(isinstance(item, dict) for item in parsed_output):
143 | raise ValueError("Not all items in the parsed output are dictionaries.")
144 |
145 | # refine key: value pairs, mapping to the Anthropic's format
146 | refined_output = []
147 |
148 | for action_item in parsed_output:
149 |
150 | print("Action Item:", action_item)
151 | # sometime showui returns lower case action names
152 | action_item["action"] = action_item["action"].upper()
153 |
154 | if action_item["action"] not in self.supported_action_type:
155 | raise ValueError(f"Action {action_item['action']} not supported. Check the output from ShowUI: {output_text}")
156 | # continue
157 |
158 | elif action_item["action"] == "CLICK": # 1. click -> mouse_move + left_click
159 | x, y = action_item["position"]
160 | action_item["position"] = (int(x * (self.screen_bbox[2] - self.screen_bbox[0])),
161 | int(y * (self.screen_bbox[3] - self.screen_bbox[1])))
162 | refined_output.append({"action": "mouse_move", "text": None, "coordinate": tuple(action_item["position"])})
163 | refined_output.append({"action": "left_click", "text": None, "coordinate": None})
164 |
165 | elif action_item["action"] == "INPUT": # 2. input -> type
166 | refined_output.append({"action": "type", "text": action_item["value"], "coordinate": None})
167 |
168 | elif action_item["action"] == "ENTER": # 3. enter -> key, enter
169 | refined_output.append({"action": "key", "text": "Enter", "coordinate": None})
170 |
171 |                 elif action_item["action"] == "ESC" or action_item["action"] == "ESCAPE": # 4. esc/escape -> key, Escape
172 | refined_output.append({"action": "key", "text": "Escape", "coordinate": None})
173 |
174 | elif action_item["action"] == "HOVER": # 5. hover -> mouse_move
175 | x, y = action_item["position"]
176 | action_item["position"] = (int(x * (self.screen_bbox[2] - self.screen_bbox[0])),
177 | int(y * (self.screen_bbox[3] - self.screen_bbox[1])))
178 | refined_output.append({"action": "mouse_move", "text": None, "coordinate": tuple(action_item["position"])})
179 |
180 |                 elif action_item["action"] == "SCROLL": # 6. scroll -> key: pageup/pagedown
181 | if action_item["value"] == "up":
182 | refined_output.append({"action": "key", "text": "pageup", "coordinate": None})
183 | elif action_item["value"] == "down":
184 | refined_output.append({"action": "key", "text": "pagedown", "coordinate": None})
185 | else:
186 | raise ValueError(f"Scroll direction {action_item['value']} not supported.")
187 |
188 | elif action_item["action"] == "PRESS": # 7. press
189 | x, y = action_item["position"]
190 | action_item["position"] = (int(x * (self.screen_bbox[2] - self.screen_bbox[0])),
191 | int(y * (self.screen_bbox[3] - self.screen_bbox[1])))
192 | refined_output.append({"action": "mouse_move", "text": None, "coordinate": tuple(action_item["position"])})
193 | refined_output.append({"action": "left_press", "text": None, "coordinate": None})
194 |
195 | return refined_output
196 |
197 | except Exception as e:
198 | print(f"Error parsing output: {e}")
199 | return None
200 |
201 |
202 |     def _get_screen_resolution(self):
203 |         from screeninfo import get_monitors
204 |         import platform
205 |         import subprocess  # needed for the xrandr fallback on Linux below
206 |         if platform.system() == "Darwin":
207 |             import Quartz  # only available (and required) on macOS
208 | 
209 |         # Detect platform
210 |         system = platform.system()
211 |
212 | if system == "Windows":
213 | # Windows: Use screeninfo to get monitor details
214 | screens = get_monitors()
215 |
216 | # Sort screens by x position to arrange from left to right
217 | sorted_screens = sorted(screens, key=lambda s: s.x)
218 |
219 | if self.selected_screen < 0 or self.selected_screen >= len(screens):
220 | raise IndexError("Invalid screen index.")
221 |
222 | screen = sorted_screens[self.selected_screen]
223 | bbox = (screen.x, screen.y, screen.x + screen.width, screen.y + screen.height)
224 |
225 | elif system == "Darwin": # macOS
226 | # macOS: Use Quartz to get monitor details
227 | max_displays = 32 # Maximum number of displays to handle
228 | active_displays = Quartz.CGGetActiveDisplayList(max_displays, None, None)[1]
229 |
230 | # Get the display bounds (resolution) for each active display
231 | screens = []
232 | for display_id in active_displays:
233 | bounds = Quartz.CGDisplayBounds(display_id)
234 | screens.append({
235 | 'id': display_id,
236 | 'x': int(bounds.origin.x),
237 | 'y': int(bounds.origin.y),
238 | 'width': int(bounds.size.width),
239 | 'height': int(bounds.size.height),
240 | 'is_primary': Quartz.CGDisplayIsMain(display_id) # Check if this is the primary display
241 | })
242 |
243 | # Sort screens by x position to arrange from left to right
244 | sorted_screens = sorted(screens, key=lambda s: s['x'])
245 |
246 | if self.selected_screen < 0 or self.selected_screen >= len(screens):
247 | raise IndexError("Invalid screen index.")
248 |
249 | screen = sorted_screens[self.selected_screen]
250 | bbox = (screen['x'], screen['y'], screen['x'] + screen['width'], screen['y'] + screen['height'])
251 |
252 | else: # Linux or other OS
253 | cmd = "xrandr | grep ' primary' | awk '{print $4}'"
254 | try:
255 | output = subprocess.check_output(cmd, shell=True).decode()
256 | resolution = output.strip()
257 | # Parse the resolution format like "1920x1080+1920+0"
258 | # The format is "WIDTHxHEIGHT+X+Y"
259 | parts = resolution.split('+')[0] # Get just the "1920x1080" part
260 | width, height = map(int, parts.split('x'))
261 | # Get the X, Y offset if needed
262 | x_offset = int(resolution.split('+')[1]) if len(resolution.split('+')) > 1 else 0
263 | y_offset = int(resolution.split('+')[2]) if len(resolution.split('+')) > 2 else 0
264 | bbox = (x_offset, y_offset, x_offset + width, y_offset + height)
265 | except subprocess.CalledProcessError:
266 | raise RuntimeError("Failed to get screen resolution on Linux.")
267 |
268 | return bbox
269 |
270 |
271 |
272 | def _message_display_callback(messages):
273 | display_messages = []
274 | for msg in messages:
275 | try:
276 | if isinstance(msg["content"][0], TextBlock):
277 | display_messages.append((msg["content"][0].text, None)) # User message
278 | elif isinstance(msg["content"][0], BetaTextBlock):
279 | display_messages.append((None, msg["content"][0].text)) # Bot message
280 | elif isinstance(msg["content"][0], BetaToolUseBlock):
281 | display_messages.append((None, f"Tool Use: {msg['content'][0].name}\nInput: {msg['content'][0].input}")) # Bot message
282 | elif isinstance(msg["content"][0], Dict) and msg["content"][0]["content"][-1]["type"] == "image":
283 | display_messages.append((None, f'')) # Bot message
284 | else:
285 | pass
286 | # print(msg["content"][0])
287 | except Exception as e:
288 | print("error", e)
289 | pass
290 | return display_messages
291 |
292 |
293 | def _make_api_tool_result(
294 | result: ToolResult, tool_use_id: str
295 | ) -> BetaToolResultBlockParam:
296 | """Convert an agent ToolResult to an API ToolResultBlockParam."""
297 | tool_result_content: list[BetaTextBlockParam | BetaImageBlockParam] | str = []
298 | is_error = False
299 | if result.error:
300 | is_error = True
301 | tool_result_content = _maybe_prepend_system_tool_result(result, result.error)
302 | else:
303 | if result.output:
304 | tool_result_content.append(
305 | {
306 | "type": "text",
307 | "text": _maybe_prepend_system_tool_result(result, result.output),
308 | }
309 | )
310 | if result.base64_image:
311 | tool_result_content.append(
312 | {
313 | "type": "image",
314 | "source": {
315 | "type": "base64",
316 | "media_type": "image/png",
317 | "data": result.base64_image,
318 | },
319 | }
320 | )
321 | return {
322 | "type": "tool_result",
323 | "content": tool_result_content,
324 | "tool_use_id": tool_use_id,
325 | "is_error": is_error,
326 | }
327 |
328 |
329 | def _maybe_prepend_system_tool_result(result: ToolResult, result_text: str):
330 | if result.system:
331 | result_text = f"{result.system}\n{result_text}"
332 | return result_text
333 |
334 |
335 |
336 | # Testing main function
337 | if __name__ == "__main__":
338 | def output_callback(content_block):
339 | # print("Output Callback:", content_block)
340 | pass
341 |
342 | def tool_output_callback(result, action):
343 | print("[showui_executor] Tool Output Callback:", result, action)
344 | pass
345 |
346 | # Instantiate the executor
347 | executor = ShowUIExecutor(
348 | output_callback=output_callback,
349 | tool_output_callback=tool_output_callback,
350 | selected_screen=0
351 | )
352 |
353 | # test inputs
354 | response_content = "{'content': \"{'action': 'CLICK', 'value': None, 'position': [0.49, 0.18]}\", 'role': 'assistant'}"
355 | # response_content = {'content': "{'action': 'CLICK', 'value': None, 'position': [0.49, 0.39]}", 'role': 'assistant'}
356 | # response_content = "{'content': \"{'action': 'CLICK', 'value': None, 'position': [0.49, 0.42]}, {'action': 'INPUT', 'value': 'weather for New York city', 'position': [0.49, 0.42]}, {'action': 'ENTER', 'value': None, 'position': None}\", 'role': 'assistant'}"
357 |
358 | # Initialize messages
359 | messages = []
360 |
361 | # Call the executor
362 | print("Testing ShowUIExecutor with response content:", response_content)
363 | for message, tool_result_content in executor(response_content, messages):
364 | print("Message:", message)
365 | print("Tool Result Content:", tool_result_content)
366 |
367 | # Display final messages
368 | print("\nFinal messages:")
369 | for msg in messages:
370 | print(msg)
371 |
372 |
373 |
374 | [
375 | {'role': 'user', 'content': ['open a new tab and go to amazon.com', 'tmp/outputs/screenshot_b4a1b7e60a5c47359bedbd8707573966.png']},
376 | {'role': 'assistant', 'content': ["History Action: {'action': 'mouse_move', 'text': None, 'coordinate': (1216, 88)}"]},
377 | {'role': 'assistant', 'content': ["History Action: {'action': 'left_click', 'text': None, 'coordinate': None}"]},
378 | {'content': [
379 | {'type': 'tool_result', 'content': [{'type': 'text', 'text': 'Moved mouse to (1216, 88)'}], 'tool_use_id': 'toolu_ae4f2886-366c-4789-9fa6-ec13461cef12', 'is_error': False},
380 | {'type': 'tool_result', 'content': [{'type': 'text', 'text': 'Performed left_click'}], 'tool_use_id': 'toolu_a7377954-e1b7-4746-9757-b2eb4dcddc82', 'is_error': False}
381 | ], 'role': 'user'}
382 | ]
--------------------------------------------------------------------------------
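A minimal sketch of the CLICK mapping performed by `_parse_showui_output` above: ShowUI emits positions as relative coordinates in [0, 1], and the executor rescales them to the selected screen's bounding box and expands each CLICK into a `mouse_move` followed by a `left_click`. The single 1920x1080 screen bbox used here is an assumption for illustration.

```python
# Sketch of the CLICK branch in ShowUIExecutor._parse_showui_output.
# The bbox value is an assumption (one 1920x1080 screen at the origin).
screen_bbox = (0, 0, 1920, 1080)          # (left, top, right, bottom)
action_item = {"action": "CLICK", "value": None, "position": [0.83, 0.15]}

x_rel, y_rel = action_item["position"]
x_abs = int(x_rel * (screen_bbox[2] - screen_bbox[0]))
y_abs = int(y_rel * (screen_bbox[3] - screen_bbox[1]))

refined = [
    {"action": "mouse_move", "text": None, "coordinate": (x_abs, y_abs)},
    {"action": "left_click", "text": None, "coordinate": None},
]
print(refined)  # mouse_move to (1593, 162), then left_click
```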
/computer_use_demo/gui_agent/actor/showui_agent.py:
--------------------------------------------------------------------------------
1 | import os
2 | import ast
3 | import base64
4 | from io import BytesIO
5 | from pathlib import Path
6 | from uuid import uuid4
7 |
8 | import pyautogui
9 | import requests
10 | import torch
11 | from PIL import Image, ImageDraw
12 | from qwen_vl_utils import process_vision_info
13 | from transformers import AutoProcessor, Qwen2VLForConditionalGeneration
14 |
15 | from computer_use_demo.gui_agent.llm_utils.oai import encode_image
16 | from computer_use_demo.tools.colorful_text import colorful_text_showui, colorful_text_vlm
17 | from computer_use_demo.tools.screen_capture import get_screenshot
18 |
19 | os.environ["TOKENIZERS_PARALLELISM"] = "false"
20 |
21 |
22 | class ShowUIActor:
23 | _NAV_SYSTEM = """
24 | You are an assistant trained to navigate the {_APP} screen.
25 | Given a task instruction, a screen observation, and an action history sequence,
26 | output the next action and wait for the next observation.
27 | Here is the action space:
28 | {_ACTION_SPACE}
29 | """
30 |
31 | _NAV_FORMAT = """
32 | Format the action as a dictionary with the following keys:
33 | {'action': 'ACTION_TYPE', 'value': 'element', 'position': [x,y]}
34 |
35 | If value or position is not applicable, set it as None.
36 | Position might be [[x1,y1], [x2,y2]] if the action requires a start and end position.
37 | Position represents the relative coordinates on the screenshot and should be scaled to a range of 0-1.
38 | """
39 |
40 | action_map = {
41 | 'desktop': """
42 | 1. CLICK: Click on an element, value is not applicable and the position [x,y] is required.
43 | 2. INPUT: Type a string into an element, value is a string to type and the position [x,y] is required.
44 | 3. HOVER: Hover on an element, value is not applicable and the position [x,y] is required.
45 | 4. ENTER: Enter operation, value and position are not applicable.
46 | 5. SCROLL: Scroll the screen, value is the direction to scroll and the position is not applicable.
47 | 6. ESC: ESCAPE operation, value and position are not applicable.
48 | 7. PRESS: Long click on an element, value is not applicable and the position [x,y] is required.
49 | """,
50 | 'phone': """
51 | 1. INPUT: Type a string into an element, value is not applicable and the position [x,y] is required.
52 | 2. SWIPE: Swipe the screen, value is not applicable and the position [[x1,y1], [x2,y2]] is the start and end position of the swipe operation.
53 | 3. TAP: Tap on an element, value is not applicable and the position [x,y] is required.
54 | 4. ANSWER: Answer the question, value is the status (e.g., 'task complete') and the position is not applicable.
55 | 5. ENTER: Enter operation, value and position are not applicable.
56 | """
57 | }
58 |
59 | def __init__(self, model_path, output_callback, device=torch.device("cpu"), split='desktop', selected_screen=0,
60 | max_pixels=1344, awq_4bit=False):
61 | self.device = device
62 | self.split = split
63 | self.selected_screen = selected_screen
64 | self.output_callback = output_callback
65 |
66 | if not model_path or not os.path.exists(model_path) or not os.listdir(model_path):
67 | if awq_4bit:
68 | model_path = "showlab/ShowUI-2B-AWQ-4bit"
69 | else:
70 | model_path = "showlab/ShowUI-2B"
71 |
72 | self.model = Qwen2VLForConditionalGeneration.from_pretrained(
73 | model_path,
74 | torch_dtype=torch.float16,
75 | device_map="cpu"
76 | ).to(self.device)
77 | self.model.eval()
78 |
79 | self.min_pixels = 256 * 28 * 28
80 | self.max_pixels = max_pixels * 28 * 28
81 | # self.max_pixels = 1344 * 28 * 28
82 |
83 | self.processor = AutoProcessor.from_pretrained(
84 | "Qwen/Qwen2-VL-2B-Instruct",
85 | # "./Qwen2-VL-2B-Instruct",
86 | min_pixels=self.min_pixels,
87 | max_pixels=self.max_pixels
88 | )
89 | self.system_prompt = self._NAV_SYSTEM.format(
90 | _APP=split,
91 | _ACTION_SPACE=self.action_map[split]
92 | )
93 | self.action_history = '' # Initialize action history
94 |
95 | def __call__(self, messages):
96 |
97 | task = messages
98 |
99 | # screenshot
100 | screenshot, screenshot_path = get_screenshot(selected_screen=self.selected_screen, resize=True, target_width=1920, target_height=1080)
101 | screenshot_path = str(screenshot_path)
102 | image_base64 = encode_image(screenshot_path)
103 | self.output_callback(f'Screenshot for {colorful_text_showui}:\n', sender="bot")
104 |
105 | # Use system prompt, task, and action history to build the messages
106 | if len(self.action_history) == 0:
107 | messages_for_processor = [
108 | {
109 | "role": "user",
110 | "content": [
111 | {"type": "text", "text": self.system_prompt},
112 | {"type": "image", "image": screenshot_path, "min_pixels": self.min_pixels, "max_pixels": self.max_pixels},
113 | {"type": "text", "text": f"Task: {task}"}
114 | ],
115 | }
116 | ]
117 | else:
118 | # https://github.com/showlab/ShowUI/issues/5
119 | messages_for_processor = [
120 | {
121 | "role": "user",
122 | "content": [
123 | {"type": "text", "text": self.system_prompt},
124 | {"type": "image", "image": screenshot_path, "min_pixels": self.min_pixels, "max_pixels": self.max_pixels},
125 | {"type": "text", "text": f"Task: {task}"},
126 | {"type": "text", "text": self.action_history},
127 | ],
128 | }
129 | ]
130 |
131 | text = self.processor.apply_chat_template(
132 | messages_for_processor, tokenize=False, add_generation_prompt=True,
133 | )
134 | image_inputs, video_inputs = process_vision_info(messages_for_processor)
135 | inputs = self.processor(
136 | text=[text],
137 | images=image_inputs,
138 | videos=video_inputs,
139 | padding=True,
140 | return_tensors="pt",
141 | )
142 | inputs = inputs.to(self.device)
143 |
144 | with torch.no_grad():
145 | generated_ids = self.model.generate(**inputs, max_new_tokens=128)
146 |
147 | generated_ids_trimmed = [
148 | out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
149 | ]
150 | output_text = self.processor.batch_decode(
151 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
152 | )[0]
153 |
154 | # dummy output test
155 | # output_text = "{'action': 'CLICK', 'value': None, 'position': [0.49, 0.42]}"
156 |
157 | # Update action history
158 | self.action_history += output_text + '\n'
159 |
160 | # Return response in expected format
161 | response = {'content': output_text, 'role': 'assistant'}
162 | return response
163 |
164 |
165 | def parse_showui_output(self, output_text):
166 | try:
167 | # Ensure the output is stripped of any extra spaces
168 | output_text = output_text.strip()
169 |
170 | # Wrap the input in brackets if it looks like a single dictionary
171 | if output_text.startswith("{") and output_text.endswith("}"):
172 | output_text = f"[{output_text}]"
173 |
174 | # Validate if the output resembles a list of dictionaries
175 | if not (output_text.startswith("[") and output_text.endswith("]")):
176 | raise ValueError("Output does not look like a valid list or dictionary.")
177 |
178 | # Parse the output using ast.literal_eval
179 | parsed_output = ast.literal_eval(output_text)
180 |
181 | # Ensure the result is a list
182 | if isinstance(parsed_output, dict):
183 | parsed_output = [parsed_output]
184 | elif not isinstance(parsed_output, list):
185 | raise ValueError("Parsed output is neither a dictionary nor a list.")
186 |
187 | # Ensure all elements in the list are dictionaries
188 | if not all(isinstance(item, dict) for item in parsed_output):
189 | raise ValueError("Not all items in the parsed output are dictionaries.")
190 |
191 | return parsed_output
192 |
193 | except Exception as e:
194 | print(f"Error parsing output: {e}")
195 | return None
196 |
--------------------------------------------------------------------------------
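As a small illustration of the normalization in `parse_showui_output` above: a single-dictionary string from the model is wrapped in brackets so that `ast.literal_eval` always yields a list of action dictionaries. The example string below is the dummy output already noted in the code.

```python
import ast

# Sketch of the single-dict normalization in ShowUIActor.parse_showui_output.
output_text = "{'action': 'CLICK', 'value': None, 'position': [0.49, 0.42]}"
if output_text.startswith("{") and output_text.endswith("}"):
    output_text = f"[{output_text}]"

parsed = ast.literal_eval(output_text)
assert isinstance(parsed, list) and all(isinstance(item, dict) for item in parsed)
print(parsed)  # [{'action': 'CLICK', 'value': None, 'position': [0.49, 0.42]}]
```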
/computer_use_demo/gui_agent/actor/uitars_agent.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 | from openai import OpenAI
4 |
5 | from computer_use_demo.gui_agent.llm_utils.oai import encode_image
6 | from computer_use_demo.tools.screen_capture import get_screenshot
7 | from computer_use_demo.tools.logger import logger, truncate_string
8 |
9 |
10 | class UITARS_Actor:
11 | """
12 |     In OOTB, we use the default grounding system prompt from the UI-TARS repo, and then convert its actions to our action format.
13 | """
14 |
15 | _NAV_SYSTEM_GROUNDING = """
16 | You are a GUI agent. You are given a task and your action history, with screenshots. You need to perform the next action to complete the task.
17 |
18 | ## Output Format
19 | ```Action: ...```
20 |
21 | ## Action Space
22 | click(start_box='<|box_start|>(x1,y1)<|box_end|>')
23 | hotkey(key='')
24 | type(content='') #If you want to submit your input, use \"\" at the end of `content`.
25 | scroll(start_box='<|box_start|>(x1,y1)<|box_end|>', direction='down or up or right or left')
26 | wait() #Sleep for 5s and take a screenshot to check for any changes.
27 | finished()
28 | call_user() # Submit the task and call the user when the task is unsolvable, or when you need the user's help.
29 |
30 | ## Note
31 | - Do not generate any other text.
32 | """
33 |
34 | def __init__(self, ui_tars_url, output_callback, api_key="", selected_screen=0):
35 |
36 | self.ui_tars_url = ui_tars_url
37 | self.ui_tars_client = OpenAI(base_url=self.ui_tars_url, api_key=api_key)
38 | self.selected_screen = selected_screen
39 | self.output_callback = output_callback
40 |
41 | self.grounding_system_prompt = self._NAV_SYSTEM_GROUNDING.format()
42 |
43 |
44 | def __call__(self, messages):
45 |
46 | task = messages
47 |
48 | # take screenshot
49 | screenshot, screenshot_path = get_screenshot(selected_screen=self.selected_screen, resize=True, target_width=1920, target_height=1080)
50 | screenshot_path = str(screenshot_path)
51 | screenshot_base64 = encode_image(screenshot_path)
52 |
53 | logger.info(f"Sending messages to UI-TARS on {self.ui_tars_url}: {task}, screenshot: {screenshot_path}")
54 |
55 | response = self.ui_tars_client.chat.completions.create(
56 | model="ui-tars",
57 | messages=[
58 | {"role": "system", "content": self.grounding_system_prompt},
59 | {"role": "user", "content": [
60 | {"type": "text", "text": task},
61 | {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{screenshot_base64}"}}
62 | ]
63 | },
64 | ],
65 | max_tokens=256,
66 | temperature=0
67 | )
68 |
69 | ui_tars_action = response.choices[0].message.content
70 | converted_action = convert_ui_tars_action_to_json(ui_tars_action)
71 | response = str(converted_action)
72 |
73 | response = {'content': response, 'role': 'assistant'}
74 | return response
75 |
76 |
77 |
78 | def convert_ui_tars_action_to_json(action_str: str) -> str:
79 | """
80 | Converts an action line such as:
81 | Action: click(start_box='(153,97)')
82 | into a JSON string of the form:
83 | {
84 | "action": "CLICK",
85 | "value": null,
86 | "position": [153, 97]
87 | }
88 | """
89 |
90 | # Strip leading/trailing whitespace and remove "Action: " prefix if present
91 | action_str = action_str.strip()
92 | if action_str.startswith("Action:"):
93 | action_str = action_str[len("Action:"):].strip()
94 |
95 | # Mappings from old action names to the new action schema
96 | ACTION_MAP = {
97 | "click": "CLICK",
98 | "type": "INPUT",
99 | "scroll": "SCROLL",
100 | "wait": "STOP", # TODO: deal with "wait()"
101 | "finished": "STOP",
102 | "call_user": "STOP",
103 | "hotkey": "HOTKEY", # We break down the actual key below (Enter, Esc, etc.)
104 | }
105 |
106 | # Prepare a structure for the final JSON
107 | # Default to no position and null value
108 | output_dict = {
109 | "action": None,
110 | "value": None,
111 | "position": None
112 | }
113 |
114 | # 1) CLICK(...) e.g. click(start_box='(153,97)')
115 | match_click = re.match(r"^click\(start_box='\(?(\d+),\s*(\d+)\)?'\)$", action_str)
116 | if match_click:
117 | x, y = match_click.groups()
118 | output_dict["action"] = ACTION_MAP["click"]
119 | output_dict["position"] = [int(x), int(y)]
120 | return json.dumps(output_dict)
121 |
122 | # 2) HOTKEY(...) e.g. hotkey(key='Enter')
123 | match_hotkey = re.match(r"^hotkey\(key='([^']+)'\)$", action_str)
124 | if match_hotkey:
125 | key = match_hotkey.group(1).lower()
126 | if key == "enter":
127 | output_dict["action"] = "ENTER"
128 | elif key == "esc":
129 | output_dict["action"] = "ESC"
130 | else:
131 | # Otherwise treat it as some generic hotkey
132 | output_dict["action"] = ACTION_MAP["hotkey"]
133 | output_dict["value"] = key
134 | return json.dumps(output_dict)
135 |
136 | # 3) TYPE(...) e.g. type(content='some text')
137 | match_type = re.match(r"^type\(content='([^']*)'\)$", action_str)
138 | if match_type:
139 | typed_content = match_type.group(1)
140 | output_dict["action"] = ACTION_MAP["type"]
141 | output_dict["value"] = typed_content
142 | # If you want a position (x,y) you need it in your string. Otherwise it's omitted.
143 | return json.dumps(output_dict)
144 |
145 | # 4) SCROLL(...) e.g. scroll(start_box='(153,97)', direction='down')
146 | # or scroll(start_box='...', direction='down')
147 | match_scroll = re.match(
148 | r"^scroll\(start_box='[^']*'\s*,\s*direction='(down|up|left|right)'\)$",
149 | action_str
150 | )
151 | if match_scroll:
152 | direction = match_scroll.group(1)
153 | output_dict["action"] = ACTION_MAP["scroll"]
154 | output_dict["value"] = direction
155 | return json.dumps(output_dict)
156 |
157 | # 5) WAIT() or FINISHED() or CALL_USER() etc.
158 | if action_str in ["wait()", "finished()", "call_user()"]:
159 | base_action = action_str.replace("()", "")
160 | if base_action in ACTION_MAP:
161 | output_dict["action"] = ACTION_MAP[base_action]
162 | else:
163 | output_dict["action"] = "STOP"
164 | return json.dumps(output_dict)
165 |
166 | # If none of the above patterns matched, you can decide how to handle
167 | # unknown or unexpected action lines:
168 | output_dict["action"] = "STOP"
169 | return json.dumps(output_dict)
--------------------------------------------------------------------------------
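A short usage sketch for `convert_ui_tars_action_to_json`, showing grounding-model action strings like those in the prompt above being converted into the OOTB action schema (outputs shown as comments; the import path assumes the package layout in this repo).

```python
from computer_use_demo.gui_agent.actor.uitars_agent import convert_ui_tars_action_to_json

print(convert_ui_tars_action_to_json("Action: click(start_box='(153,97)')"))
# {"action": "CLICK", "value": null, "position": [153, 97]}

print(convert_ui_tars_action_to_json("Action: type(content='weather for New York city')"))
# {"action": "INPUT", "value": "weather for New York city", "position": null}

print(convert_ui_tars_action_to_json("Action: scroll(start_box='(500,500)', direction='down')"))
# {"action": "SCROLL", "value": "down", "position": null}

print(convert_ui_tars_action_to_json("Action: finished()"))
# {"action": "STOP", "value": null, "position": null}
```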
/computer_use_demo/gui_agent/llm_utils/llm_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import ast
4 | import base64
5 |
6 |
7 | def is_image_path(text):
8 | # Checking if the input text ends with typical image file extensions
9 | image_extensions = (".jpg", ".jpeg", ".png", ".gif", ".bmp", ".tiff", ".tif")
10 | if text.endswith(image_extensions):
11 | return True
12 | else:
13 | return False
14 |
15 |
16 | def encode_image(image_path):
17 | """Encode image file to base64."""
18 | with open(image_path, "rb") as image_file:
19 | return base64.b64encode(image_file.read()).decode("utf-8")
20 |
21 |
22 | def is_url_or_filepath(input_string):
23 | # Check if input_string is a URL
24 | url_pattern = re.compile(
25 | r"http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*\\(\\),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+"
26 | )
27 | if url_pattern.match(input_string):
28 | return "URL"
29 |
30 | # Check if input_string is a file path
31 | file_path = os.path.abspath(input_string)
32 | if os.path.exists(file_path):
33 | return "File path"
34 |
35 | return "Invalid"
36 |
37 |
38 | def extract_data(input_string, data_type):
39 |     # Regular expression to extract the content of a ```<data_type> fence, tolerating a missing closing fence
40 | pattern = f"```{data_type}" + r"(.*?)(```|$)"
41 | # Extract content
42 | # re.DOTALL allows '.' to match newlines as well
43 | matches = re.findall(pattern, input_string, re.DOTALL)
44 | # Return the first match if exists, trimming whitespace and ignoring potential closing backticks
45 | return matches[0][0].strip() if matches else input_string
46 |
47 |
48 | def parse_input(code):
49 | """Use AST to parse the input string and extract the function name, arguments, and keyword arguments."""
50 |
51 | def get_target_names(target):
52 | """Recursively get all variable names from the assignment target."""
53 | if isinstance(target, ast.Name):
54 | return [target.id]
55 | elif isinstance(target, ast.Tuple):
56 | names = []
57 | for elt in target.elts:
58 | names.extend(get_target_names(elt))
59 | return names
60 | return []
61 |
62 | def extract_value(node):
63 |         """Extract the actual value from an AST node."""
64 | if isinstance(node, ast.Constant):
65 | return node.value
66 | elif isinstance(node, ast.Name):
67 | # TODO: a better way to handle variables
68 | raise ValueError(
69 | f"Arguments should be a Constant, got a variable {node.id} instead."
70 | )
71 |         # Add handling for other AST node types here if needed
72 | return None
73 |
74 | try:
75 | tree = ast.parse(code)
76 | for node in ast.walk(tree):
77 | if isinstance(node, ast.Assign):
78 | targets = []
79 | for t in node.targets:
80 | targets.extend(get_target_names(t))
81 | if isinstance(node.value, ast.Call):
82 | func_name = node.value.func.id
83 | args = [ast.dump(arg) for arg in node.value.args]
84 | kwargs = {
85 | kw.arg: extract_value(kw.value) for kw in node.value.keywords
86 | }
87 | print(f"Input: {code.strip()}")
88 | print(f"Output Variables: {targets}")
89 | print(f"Function Name: {func_name}")
90 | print(f"Arguments: {args}")
91 | print(f"Keyword Arguments: {kwargs}")
92 | elif isinstance(node, ast.Expr) and isinstance(node.value, ast.Call):
93 | targets = []
94 |                 func_name = node.value.func.id
95 | args = [extract_value(arg) for arg in node.value.args]
96 | kwargs = {kw.arg: extract_value(kw.value) for kw in node.value.keywords}
97 |
98 | except SyntaxError:
99 | print(f"Input: {code.strip()}")
100 | print("No match found")
101 |
102 | return targets, func_name, args, kwargs
103 |
104 |
105 | if __name__ == "__main__":
106 | import json
107 |     s = '{"Thinking": "The Docker icon has been successfully clicked, and the Docker application should now be opening. No further actions are required.", "Next Action": null}'  # JSON uses null, not None
108 | json_str = json.loads(s)
109 | print(json_str)
--------------------------------------------------------------------------------
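Two quick usage sketches for the helpers above; the reply string and the screenshot filename are made-up examples.

```python
from computer_use_demo.gui_agent.llm_utils.llm_utils import extract_data, is_image_path

# extract_data pulls the body of a ```json ... ``` fence out of an LLM reply,
# tolerating a missing closing fence.
reply = 'Here is the plan:\n```json\n{"Next Action": "CLICK"}\n```\nDone.'
print(extract_data(reply, "json"))   # {"Next Action": "CLICK"}

# is_image_path is a simple extension check used to decide whether a message
# item should be encoded as an image.
print(is_image_path("tmp/outputs/screenshot_example.png"))    # True
print(is_image_path("open a new tab and go to amazon.com"))   # False
```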
/computer_use_demo/gui_agent/llm_utils/oai.py:
--------------------------------------------------------------------------------
1 | import os
2 | import logging
3 | import base64
4 | import requests
5 | from computer_use_demo.gui_agent.llm_utils.llm_utils import is_image_path, encode_image
6 |
7 |
8 |
9 | def run_oai_interleaved(messages: list, system: str, llm: str, api_key: str, max_tokens=256, temperature=0):
10 |
11 | api_key = api_key or os.environ.get("OPENAI_API_KEY")
12 | if not api_key:
13 | raise ValueError("OPENAI_API_KEY is not set")
14 |
15 | headers = {"Content-Type": "application/json",
16 | "Authorization": f"Bearer {api_key}"}
17 |
18 | final_messages = [{"role": "system", "content": system}]
19 |
20 | # image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
21 | if type(messages) == list:
22 | for item in messages:
23 | print(f"item: {item}")
24 | contents = []
25 | if isinstance(item, dict):
26 | for cnt in item["content"]:
27 | if isinstance(cnt, str):
28 | if is_image_path(cnt):
29 | base64_image = encode_image(cnt)
30 | content = {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
31 | else:
32 | content = {"type": "text", "text": cnt}
33 |
34 | # if isinstance(cnt, list):
35 |
36 | contents.append(content)
37 | message = {"role": item["role"], "content": contents}
38 |
39 | elif isinstance(item, str):
40 | if is_image_path(item):
41 | base64_image = encode_image(item)
42 | contents.append({"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}})
43 | message = {"role": "user", "content": contents}
44 | else:
45 | contents.append({"type": "text", "text": item})
46 | message = {"role": "user", "content": contents}
47 |
48 | else: # str
49 | contents.append({"type": "text", "text": item})
50 | message = {"role": "user", "content": contents}
51 |
52 | final_messages.append(message)
53 |
54 |
55 | elif isinstance(messages, str):
56 | final_messages.append({"role": "user", "content": messages})
57 |
58 |     print("[oai] sending messages:", final_messages)
59 |
60 | payload = {
61 | "model": llm,
62 | "messages": final_messages,
63 | "max_tokens": max_tokens,
64 | "temperature": temperature,
65 | # "stop": stop,
66 | }
67 |
68 | # from IPython.core.debugger import Pdb; Pdb().set_trace()
69 |
70 | response = requests.post(
71 | "https://api.openai.com/v1/chat/completions", headers=headers, json=payload
72 | )
73 |
74 | try:
75 | text = response.json()['choices'][0]['message']['content']
76 | token_usage = int(response.json()['usage']['total_tokens'])
77 | return text, token_usage
78 |
79 | # return error message if the response is not successful
80 | except Exception as e:
81 |         print(f"Error in interleaved OpenAI call: {e}. This may be due to an invalid OPENAI_API_KEY. Please check the response: {response.json()}")
82 | return response.json()
83 |
84 | def run_ssh_llm_interleaved(messages: list, system: str, llm: str, ssh_host: str, ssh_port: int, max_tokens=256, temperature=0.7, do_sample=True):
85 | """Send chat completion request to SSH remote server"""
86 | from PIL import Image
87 | from io import BytesIO
88 | def encode_image(image_path: str, max_size=1024) -> str:
89 | """Convert image to base64 encoding with preprocessing"""
90 | try:
91 | with Image.open(image_path) as img:
92 | # Convert to RGB format
93 | img = img.convert('RGB')
94 |
95 | # Scale down if image is too large
96 | if max(img.size) > max_size:
97 | ratio = max_size / max(img.size)
98 | new_size = tuple(int(dim * ratio) for dim in img.size)
99 | img = img.resize(new_size, Image.LANCZOS)
100 |
101 | # Convert processed image to base64
102 | buffered = BytesIO()
103 | img.save(buffered, format="JPEG", quality=85)
104 | img_str = base64.b64encode(buffered.getvalue()).decode()
105 | return img_str
106 | except Exception as e:
107 | print(f"Image processing failed: {str(e)}")
108 | raise
109 |
110 |
111 | try:
112 | # Verify SSH connection info
113 | if not ssh_host or not ssh_port:
114 | raise ValueError("SSH_HOST and SSH_PORT are not set")
115 |
116 | # Build API URL
117 | api_url = f"http://{ssh_host}:{ssh_port}"
118 |
119 | # Prepare message list
120 | final_messages = []
121 |
122 | # Add system message
123 | if system:
124 | final_messages.append({
125 | "role": "system",
126 | "content": system
127 | })
128 |
129 | # Process user messages
130 | if type(messages) == list:
131 | for item in messages:
132 | contents = []
133 | if isinstance(item, dict):
134 | for cnt in item["content"]:
135 | if isinstance(cnt, str):
136 | if is_image_path(cnt):
137 | base64_image = encode_image(cnt)
138 | content = {
139 | "type": "image_url",
140 | "image_url": {
141 | "url": f"data:image/jpeg;base64,{base64_image}"
142 | }
143 | }
144 | else:
145 | content = {
146 | "type": "text",
147 | "text": cnt
148 | }
149 | contents.append(content)
150 | message = {"role": item["role"], "content": contents}
151 | else: # str
152 | contents.append({"type": "text", "text": item})
153 | message = {"role": "user", "content": contents}
154 | final_messages.append(message)
155 | elif isinstance(messages, str):
156 | final_messages.append({
157 | "role": "user",
158 | "content": messages
159 | })
160 |
161 | # Prepare request data
162 | data = {
163 | "model": llm,
164 | "messages": final_messages,
165 | "temperature": temperature,
166 | "max_tokens": max_tokens,
167 | "do_sample": do_sample
168 | }
169 |
170 | print(f"[ssh] Sending chat completion request to model: {llm}")
171 | print(f"[ssh] sending messages:", final_messages)
172 |
173 | # Send request
174 | response = requests.post(
175 | f"{api_url}/v1/chat/completions",
176 | json=data,
177 | headers={"Content-Type": "application/json"},
178 | timeout=30
179 | )
180 |
181 | result = response.json()
182 |
183 | if response.status_code == 200:
184 | content = result['choices'][0]['message']['content']
185 | token_usage = int(result['usage']['total_tokens'])
186 | print(f"[ssh] Generation successful: {content}")
187 | return content, token_usage
188 | else:
189 | print(f"[ssh] Request failed: {result}")
190 | raise Exception(f"API request failed: {result}")
191 |
192 | except Exception as e:
193 | print(f"[ssh] Chat completion request failed: {str(e)}")
194 | raise
195 |
196 |
197 |
198 | if __name__ == "__main__":
199 |
200 | api_key = os.environ.get("OPENAI_API_KEY")
201 | if not api_key:
202 | raise ValueError("OPENAI_API_KEY is not set")
203 |
204 | # text, token_usage = run_oai_interleaved(
205 | # messages= [{"content": [
206 | # "What is in the screenshot?",
207 | # "./tmp/outputs/screenshot_0b04acbb783d4706bc93873d17ba8c05.png"],
208 | # "role": "user"
209 | # }],
210 | # llm="gpt-4o-mini",
211 | # system="You are a helpful assistant",
212 | # api_key=api_key,
213 | # max_tokens=256,
214 | # temperature=0)
215 |
216 | # print(text, token_usage)
217 | text, token_usage = run_ssh_llm_interleaved(
218 | messages= [{"content": [
219 | "What is in the screenshot?",
220 | "tmp/outputs/screenshot_5a26d36c59e84272ab58c1b34493d40d.png"],
221 | "role": "user"
222 | }],
223 | llm="Qwen2.5-VL-7B-Instruct",
224 | ssh_host="10.245.92.68",
225 | ssh_port=9192,
226 | max_tokens=256,
227 | temperature=0.7
228 | )
229 | print(text, token_usage)
230 | # There is an introduction describing the Calyx... 36986
231 |
--------------------------------------------------------------------------------
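For reference, a sketch of the interleaved message shape that `run_oai_interleaved` assembles before POSTing to the chat completions endpoint; the screenshot path is hypothetical and the base64 helper mirrors the `encode_image` used above.

```python
import base64

def encode_image(image_path: str) -> str:
    # Same behaviour as encode_image above: read the file and base64-encode it.
    with open(image_path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")

screenshot_path = "tmp/outputs/screenshot_example.png"  # hypothetical path
message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in the screenshot?"},
        {
            "type": "image_url",
            "image_url": {"url": f"data:image/jpeg;base64,{encode_image(screenshot_path)}"},
        },
    ],
}
# `message` is appended after the system message and sent inside the
# {"model": ..., "messages": [...], "max_tokens": ..., "temperature": ...} payload.
```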
/computer_use_demo/gui_agent/llm_utils/qwen.py:
--------------------------------------------------------------------------------
1 |
2 | import os
3 | import logging
4 | import base64
5 | import requests
6 |
7 | import dashscope
8 | # from computer_use_demo.gui_agent.llm_utils import is_image_path, encode_image
9 |
10 | def is_image_path(text):
11 | return False
12 |
13 | def encode_image(image_path):
14 | return ""
15 |
16 |
17 | def run_qwen(messages: list, system: str, llm: str, api_key: str, max_tokens=256, temperature=0):
18 |
19 | api_key = api_key or os.environ.get("QWEN_API_KEY")
20 | if not api_key:
21 | raise ValueError("QWEN_API_KEY is not set")
22 |
23 | dashscope.api_key = api_key
24 |
25 | # from IPython.core.debugger import Pdb; Pdb().set_trace()
26 |
27 | final_messages = [{"role": "system", "content": [{"text": system}]}]
28 | # image_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
29 | if type(messages) == list:
30 | for item in messages:
31 | contents = []
32 | if isinstance(item, dict):
33 | for cnt in item["content"]:
34 | if isinstance(cnt, str):
35 | if is_image_path(cnt):
36 | # base64_image = encode_image(cnt)
37 | content = [{"image": cnt}]
38 | # content = {"type": "image_url", "image_url": {"url": image_url}}
39 | else:
40 | content = {"text": cnt}
41 | contents.append(content)
42 |
43 | message = {"role": item["role"], "content": contents}
44 | else: # str
45 | contents.append({"text": item})
46 | message = {"role": "user", "content": contents}
47 |
48 | final_messages.append(message)
49 |
50 | print("[qwen-vl] sending messages:", final_messages)
51 |
52 | response = dashscope.MultiModalConversation.call(
53 | model='qwen-vl-max-latest',
54 | # model='qwen-vl-max-0809',
55 | messages=final_messages
56 | )
57 |
58 | # from IPython.core.debugger import Pdb; Pdb().set_trace()
59 |
60 | try:
61 | text = response.output.choices[0].message.content[0]['text']
62 | usage = response.usage
63 |
64 | if "total_tokens" not in usage:
65 | token_usage = int(usage["input_tokens"] + usage["output_tokens"])
66 | else:
67 | token_usage = int(usage["total_tokens"])
68 |
69 | return text, token_usage
70 | # return response.json()['choices'][0]['message']['content']
71 | # return error message if the response is not successful
72 | except Exception as e:
73 |         print(f"Error in Qwen-VL call: {e}. This may be due to an invalid QWEN_API_KEY. Please check the response: {response}")
74 |         return response
75 |
76 |
77 |
78 | if __name__ == "__main__":
79 | api_key = os.environ.get("QWEN_API_KEY")
80 | if not api_key:
81 | raise ValueError("QWEN_API_KEY is not set")
82 |
83 | dashscope.api_key = api_key
84 |
85 | final_messages = [{"role": "user",
86 | "content": [
87 | {"text": "What is in the screenshot?"},
88 | {"image": "./tmp/outputs/screenshot_0b04acbb783d4706bc93873d17ba8c05.png"}
89 | ]
90 | }
91 | ]
92 | response = dashscope.MultiModalConversation.call(model='qwen-vl-max-0809', messages=final_messages)
93 |
94 | print(response)
95 |
96 | text = response.output.choices[0].message.content[0]['text']
97 | usage = response.usage
98 |
99 | if "total_tokens" not in usage:
100 | if "image_tokens" in usage:
101 | token_usage = usage["input_tokens"] + usage["output_tokens"] + usage["image_tokens"]
102 | else:
103 | token_usage = usage["input_tokens"] + usage["output_tokens"]
104 | else:
105 | token_usage = usage["total_tokens"]
106 |
107 | print(text, token_usage)
108 | # The screenshot is from a video game... 1387
--------------------------------------------------------------------------------
/computer_use_demo/gui_agent/llm_utils/run_llm.py:
--------------------------------------------------------------------------------
1 | import base64
2 | import logging
3 | from .oai import run_oai_interleaved
4 | from .gemini import run_gemini_interleaved
5 |
6 | def run_llm(prompt, llm="gpt-4o-mini", max_tokens=256, temperature=0, stop=None):
7 | log_prompt(prompt)
8 |
9 | # turn string prompt into list
10 | if isinstance(prompt, str):
11 | prompt = [prompt]
12 | elif isinstance(prompt, list):
13 | pass
14 | else:
15 | raise ValueError(f"Invalid prompt type: {type(prompt)}")
16 |
17 | if llm.startswith("gpt"): # gpt series
18 | out = run_oai_interleaved(
19 | prompt,
20 | llm,
21 | max_tokens,
22 | temperature,
23 | stop
24 | )
25 | elif llm.startswith("gemini"): # gemini series
26 | out = run_gemini_interleaved(
27 | prompt,
28 | llm,
29 | max_tokens,
30 | temperature,
31 | stop
32 | )
33 | else:
34 | raise ValueError(f"Invalid llm: {llm}")
35 | logging.info(
36 | f"========Output for {llm}=======\n{out}\n============================")
37 | return out
38 |
39 | def log_prompt(prompt):
40 | prompt_display = [prompt] if isinstance(prompt, str) else prompt
41 | prompt_display = "\n\n".join(prompt_display)
42 | logging.info(
43 | f"========Prompt=======\n{prompt_display}\n============================")
44 |
--------------------------------------------------------------------------------
/computer_use_demo/gui_agent/planner/anthropic_agent.py:
--------------------------------------------------------------------------------
1 | """
2 | Agentic sampling loop that calls the Anthropic API and local implementation of anthropic-defined computer use tools.
3 | """
4 | import asyncio
5 | import platform
6 | from collections.abc import Callable
7 | from datetime import datetime
8 | from enum import StrEnum
9 | from typing import Any, cast
10 |
11 | from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex, APIResponse
12 | from anthropic.types import (
13 | ToolResultBlockParam,
14 | )
15 | from anthropic.types.beta import (
16 | BetaContentBlock,
17 | BetaContentBlockParam,
18 | BetaImageBlockParam,
19 | BetaMessage,
20 | BetaMessageParam,
21 | BetaTextBlockParam,
22 | BetaToolResultBlockParam,
23 | )
24 | from anthropic.types import TextBlock
25 | from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock
26 |
27 | from computer_use_demo.tools import BashTool, ComputerTool, EditTool, ToolCollection, ToolResult
28 |
29 | from PIL import Image
30 | from io import BytesIO
31 | import gradio as gr
32 | from typing import Dict
33 |
34 |
35 | BETA_FLAG = "computer-use-2024-10-22"
36 |
37 |
38 | class APIProvider(StrEnum):
39 | ANTHROPIC = "anthropic"
40 | BEDROCK = "bedrock"
41 | VERTEX = "vertex"
42 |
43 |
44 | PROVIDER_TO_DEFAULT_MODEL_NAME: dict[APIProvider, str] = {
45 | APIProvider.ANTHROPIC: "claude-3-5-sonnet-20241022",
46 | APIProvider.BEDROCK: "anthropic.claude-3-5-sonnet-20241022-v2:0",
47 | APIProvider.VERTEX: "claude-3-5-sonnet-v2@20241022",
48 | }
49 |
50 |
51 | # Check OS
52 | SYSTEM_PROMPT = f"""
53 | * You are utilizing a Windows system with internet access.
54 | * The current date is {datetime.today().strftime('%A, %B %d, %Y')}.
55 |
56 | """
57 |
58 |
59 | class AnthropicActor:
60 | def __init__(
61 | self,
62 | model: str,
63 | provider: APIProvider,
64 | system_prompt_suffix: str,
65 | api_key: str,
66 | api_response_callback: Callable[[APIResponse[BetaMessage]], None],
67 | max_tokens: int = 4096,
68 | only_n_most_recent_images: int | None = None,
69 | selected_screen: int = 0,
70 | print_usage: bool = True,
71 | ):
72 | self.model = model
73 | self.provider = provider
74 | self.system_prompt_suffix = system_prompt_suffix
75 | self.api_key = api_key
76 | self.api_response_callback = api_response_callback
77 | self.max_tokens = max_tokens
78 | self.only_n_most_recent_images = only_n_most_recent_images
79 | self.selected_screen = selected_screen
80 |
81 | self.tool_collection = ToolCollection(
82 | ComputerTool(selected_screen=selected_screen),
83 | BashTool(),
84 | EditTool(),
85 | )
86 |
87 | self.system = (
88 | f"{SYSTEM_PROMPT}{' ' + system_prompt_suffix if system_prompt_suffix else ''}"
89 | )
90 |
91 | self.total_token_usage = 0
92 | self.total_cost = 0
93 | self.print_usage = print_usage
94 |
95 | # Instantiate the appropriate API client based on the provider
96 | print("provider:", provider)
97 | if provider == APIProvider.ANTHROPIC:
98 | self.client = Anthropic(api_key=api_key)
99 | elif provider == APIProvider.VERTEX:
100 | self.client = AnthropicVertex()
101 | elif provider == APIProvider.BEDROCK:
102 | self.client = AnthropicBedrock()
103 | else:
104 | raise ValueError(f"Provider {provider} not supported")
105 |
106 | def __call__(
107 | self,
108 | *,
109 | messages: list[BetaMessageParam]
110 | ):
111 | """
112 | Generate a response given history messages.
113 | """
114 | if self.only_n_most_recent_images:
115 | _maybe_filter_to_n_most_recent_images(messages, self.only_n_most_recent_images)
116 |
117 | # Call the API synchronously
118 | raw_response = self.client.beta.messages.with_raw_response.create(
119 | max_tokens=self.max_tokens,
120 | messages=messages,
121 | model=self.model,
122 | system=self.system,
123 | tools=self.tool_collection.to_params(),
124 | betas=["computer-use-2024-10-22"],
125 | )
126 |
127 | self.api_response_callback(cast(APIResponse[BetaMessage], raw_response))
128 |
129 | response = raw_response.parse()
130 | print(f"AnthropicActor response: {response}")
131 |
132 | self.total_token_usage += response.usage.input_tokens + response.usage.output_tokens
133 | self.total_cost += (response.usage.input_tokens * 3 / 1000000 + response.usage.output_tokens * 15 / 1000000)
134 |
135 | if self.print_usage:
136 |             print(f"Claude total token usage so far: {self.total_token_usage}, total cost so far: ${self.total_cost:.4f} USD")
137 |
138 | return response
139 |
140 |
141 | def _maybe_filter_to_n_most_recent_images(
142 | messages: list[BetaMessageParam],
143 | images_to_keep: int,
144 | min_removal_threshold: int = 10,
145 | ):
146 | """
147 | With the assumption that images are screenshots that are of diminishing value as
148 | the conversation progresses, remove all but the final `images_to_keep` tool_result
149 | images in place, with a chunk of min_removal_threshold to reduce the amount we
150 | break the implicit prompt cache.
151 | """
152 | if images_to_keep is None:
153 | return messages
154 |
155 | tool_result_blocks = cast(
156 | list[ToolResultBlockParam],
157 | [
158 | item
159 | for message in messages
160 | for item in (
161 | message["content"] if isinstance(message["content"], list) else []
162 | )
163 | if isinstance(item, dict) and item.get("type") == "tool_result"
164 | ],
165 | )
166 |
167 | total_images = sum(
168 | 1
169 | for tool_result in tool_result_blocks
170 | for content in tool_result.get("content", [])
171 | if isinstance(content, dict) and content.get("type") == "image"
172 | )
173 |
174 | images_to_remove = total_images - images_to_keep
175 | # for better cache behavior, we want to remove in chunks
176 | images_to_remove -= images_to_remove % min_removal_threshold
177 |
178 | for tool_result in tool_result_blocks:
179 | if isinstance(tool_result.get("content"), list):
180 | new_content = []
181 | for content in tool_result.get("content", []):
182 | if isinstance(content, dict) and content.get("type") == "image":
183 | if images_to_remove > 0:
184 | images_to_remove -= 1
185 | continue
186 | new_content.append(content)
187 | tool_result["content"] = new_content
188 |
189 |
190 |
191 | if __name__ == "__main__":
192 | pass
193 | # client = Anthropic(api_key="")
194 | # response = client.beta.messages.with_raw_response.create(
195 | # max_tokens=4096,
196 | # model="claude-3-5-sonnet-20241022",
197 | # system=SYSTEM_PROMPT,
198 | # # tools=ToolCollection(
199 | # # ComputerTool(selected_screen=0),
200 | # # BashTool(),
201 | # # EditTool(),
202 | # # ).to_params(),
203 | # betas=["computer-use-2024-10-22"],
204 | # messages=[
205 | # {"role": "user", "content": "click on (199, 199)."}
206 | # ],
207 | # )
208 |
209 | # print(f"AnthropicActor response: {response.parse().usage.input_tokens+response.parse().usage.output_tokens}")
--------------------------------------------------------------------------------
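To make the chunked image trimming in `_maybe_filter_to_n_most_recent_images` concrete, a small worked example (the counts are hypothetical):

```python
# With 27 screenshots embedded in tool_result blocks, a request to keep only the
# 2 most recent, and the default min_removal_threshold of 10:
total_images = 27
images_to_keep = 2
min_removal_threshold = 10

images_to_remove = total_images - images_to_keep               # 25
images_to_remove -= images_to_remove % min_removal_threshold   # rounded down to 20

print(images_to_remove)  # 20 -> the 20 oldest images are dropped, 7 survive
# Removing in multiples of the threshold keeps the trimmed prefix of the
# conversation stable, so the implicit prompt cache is broken less often.
```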
/computer_use_demo/gui_agent/planner/api_vlm_planner.py:
--------------------------------------------------------------------------------
1 | import json
2 | import asyncio
3 | import platform
4 | from collections.abc import Callable
5 | from datetime import datetime
6 | from enum import StrEnum
7 | from typing import Any, cast, Dict  # Callable is already imported from collections.abc above
8 |
9 | from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex, APIResponse
10 | from anthropic.types import TextBlock, ToolResultBlockParam
11 | from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock, BetaMessageParam
12 |
13 | from computer_use_demo.tools.screen_capture import get_screenshot
14 | from computer_use_demo.gui_agent.llm_utils.oai import run_oai_interleaved, run_ssh_llm_interleaved
15 | from computer_use_demo.gui_agent.llm_utils.qwen import run_qwen
16 | from computer_use_demo.gui_agent.llm_utils.llm_utils import extract_data, encode_image
17 | from computer_use_demo.tools.colorful_text import colorful_text_showui, colorful_text_vlm
18 |
19 |
20 | class APIVLMPlanner:
21 | def __init__(
22 | self,
23 | model: str,
24 | provider: str,
25 | system_prompt_suffix: str,
26 | api_key: str,
27 | output_callback: Callable,
28 | api_response_callback: Callable,
29 | max_tokens: int = 4096,
30 | only_n_most_recent_images: int | None = None,
31 | selected_screen: int = 0,
32 | print_usage: bool = True,
33 | ):
34 | if model == "gpt-4o":
35 | self.model = "gpt-4o-2024-11-20"
36 | elif model == "gpt-4o-mini":
37 | self.model = "gpt-4o-mini" # "gpt-4o-mini"
38 | elif model == "qwen2-vl-max":
39 | self.model = "qwen2-vl-max"
40 | elif model == "qwen2-vl-2b (ssh)":
41 | self.model = "Qwen2-VL-2B-Instruct"
42 | elif model == "qwen2-vl-7b (ssh)":
43 | self.model = "Qwen2-VL-7B-Instruct"
44 | elif model == "qwen2.5-vl-7b (ssh)":
45 | self.model = "Qwen2.5-VL-7B-Instruct"
46 | else:
47 | raise ValueError(f"Model {model} not supported")
48 |
49 | self.provider = provider
50 | self.system_prompt_suffix = system_prompt_suffix
51 | self.api_key = api_key
52 | self.api_response_callback = api_response_callback
53 | self.max_tokens = max_tokens
54 | self.only_n_most_recent_images = only_n_most_recent_images
55 | self.selected_screen = selected_screen
56 | self.output_callback = output_callback
57 | self.system_prompt = self._get_system_prompt() + self.system_prompt_suffix
58 |
59 |
60 | self.print_usage = print_usage
61 | self.total_token_usage = 0
62 | self.total_cost = 0
63 |
64 |
65 | def __call__(self, messages: list):
66 |
67 | # drop looping actions msg, byte image etc
68 | planner_messages = _message_filter_callback(messages)
69 | print(f"filtered_messages: {planner_messages}")
70 |
71 | if self.only_n_most_recent_images:
72 | _maybe_filter_to_n_most_recent_images(planner_messages, self.only_n_most_recent_images)
73 |
74 | # Take a screenshot
75 | screenshot, screenshot_path = get_screenshot(selected_screen=self.selected_screen)
76 | screenshot_path = str(screenshot_path)
77 | image_base64 = encode_image(screenshot_path)
78 | self.output_callback(f'Screenshot for {colorful_text_vlm}:\n',
79 | sender="bot")
80 |
81 | # if isinstance(planner_messages[-1], dict):
82 | # if not isinstance(planner_messages[-1]["content"], list):
83 | # planner_messages[-1]["content"] = [planner_messages[-1]["content"]]
84 | # planner_messages[-1]["content"].append(screenshot_path)
85 | # elif isinstance(planner_messages[-1], str):
86 | # planner_messages[-1] = {"role": "user", "content": [{"type": "text", "text": planner_messages[-1]}]}
87 |
88 | # append screenshot
89 | # planner_messages.append({"role": "user", "content": [{"type": "image", "image": screenshot_path}]})
90 |
91 | planner_messages.append(screenshot_path)
92 |
93 | print(f"Sending messages to VLMPlanner: {planner_messages}")
94 |
95 | if self.model in ("gpt-4o-2024-11-20", "gpt-4o-mini"):
96 | vlm_response, token_usage = run_oai_interleaved(
97 | messages=planner_messages,
98 | system=self.system_prompt,
99 | llm=self.model,
100 | api_key=self.api_key,
101 | max_tokens=self.max_tokens,
102 | temperature=0,
103 | )
104 | print(f"oai token usage: {token_usage}")
105 | self.total_token_usage += token_usage
106 | self.total_cost += (token_usage * 0.15 / 1000000) # https://openai.com/api/pricing/
107 |
108 | elif self.model == "qwen2-vl-max":
109 | vlm_response, token_usage = run_qwen(
110 | messages=planner_messages,
111 | system=self.system_prompt,
112 | llm=self.model,
113 | api_key=self.api_key,
114 | max_tokens=self.max_tokens,
115 | temperature=0,
116 | )
117 | print(f"qwen token usage: {token_usage}")
118 | self.total_token_usage += token_usage
119 | self.total_cost += (token_usage * 0.02 / 7.25 / 1000) # 1USD=7.25CNY, https://help.aliyun.com/zh/dashscope/developer-reference/tongyi-qianwen-vl-plus-api
120 | elif "Qwen" in self.model:
121 | # Parse the SSH host and port from the api_key field
122 | try:
123 | ssh_host, ssh_port = self.api_key.split(":")
124 | ssh_port = int(ssh_port)
125 | except ValueError:
126 | raise ValueError("Invalid SSH connection string. Expected format: host:port")
127 |
128 | vlm_response, token_usage = run_ssh_llm_interleaved(
129 | messages=planner_messages,
130 | system=self.system_prompt,
131 | llm=self.model,
132 | ssh_host=ssh_host,
133 | ssh_port=ssh_port,
134 | max_tokens=self.max_tokens,
135 | )
136 | else:
137 | raise ValueError(f"Model {self.model} not supported")
138 |
139 | print(f"VLMPlanner response: {vlm_response}")
140 |
141 | if self.print_usage:
142 | print(f"VLMPlanner total token usage so far: {self.total_token_usage}. Total cost so far: $USD{self.total_cost:.5f}")
143 |
144 | vlm_response_json = extract_data(vlm_response, "json")
145 |
146 | # vlm_plan_str = '\n'.join([f'{key}: {value}' for key, value in json.loads(response).items()])
147 | vlm_plan_str = ""
148 | for key, value in json.loads(vlm_response_json).items():
149 | if key == "Thinking":
150 | vlm_plan_str += f'{value}'
151 | else:
152 | vlm_plan_str += f'\n{key}: {value}'
153 |
154 | self.output_callback(f"{colorful_text_vlm}:\n{vlm_plan_str}", sender="bot")
155 |
156 | return vlm_response_json
157 |
158 |
159 | def _api_response_callback(self, response: APIResponse):
160 | self.api_response_callback(response)
161 |
162 |
163 | def reformat_messages(self, messages: list):
164 | pass
165 |
166 | def _get_system_prompt(self):
167 | os_name = platform.system()
168 | return f"""
169 | You are using an {os_name} device.
170 | You are able to use a mouse and keyboard to interact with the computer based on the given task and screenshot.
171 | You can only interact with the desktop GUI (no terminal or application menu access).
172 |
173 | You may be given a history of plans and actions; this is the response from the previous loop.
174 | You should carefully consider your plan based on the task, the screenshot, and the history of actions.
175 | 
176 | Your available "Next Action" types only include:
177 | - ENTER: Press the Enter key.
178 | - ESCAPE: Press the Escape key.
179 | - INPUT: Input a string of text.
180 | - CLICK: Describe the UI element to be clicked.
181 | - HOVER: Describe the UI element to be hovered.
182 | - SCROLL: Scroll the screen; you must specify up or down.
183 | - PRESS: Describe the UI element to be pressed.
184 |
185 |
186 | Output format:
187 | ```json
188 | {{
189 | "Thinking": str, # describe your thoughts on how to achieve the task; choose one action from the available actions at a time.
190 | "Next Action": "action_type, action description" | "None" # one action at a time; describe it briefly and precisely.
191 | }}
192 | ```
193 |
194 | One Example:
195 | ```json
196 | {{
197 | "Thinking": "I need to search and navigate to amazon.com.",
198 | "Next Action": "CLICK 'Search Google or type a URL'."
199 | }}
200 | ```
201 |
202 | IMPORTANT NOTES:
203 | 1. Carefully observe the screenshot to understand the current state and review the history of actions.
204 | 2. You should only give a single action at a time. For example, INPUT text and ENTER cannot be in one Next Action.
205 | 3. Attach the text to the Next Action if there is text or any description for the button.
206 | 4. You should not include other actions, such as keyboard shortcuts.
207 | 5. When the task is completed, you should output "Next Action": "None" in the JSON field.
208 | """
209 |
210 |
211 |
212 | def _maybe_filter_to_n_most_recent_images(
213 | messages: list[BetaMessageParam],
214 | images_to_keep: int,
215 | min_removal_threshold: int = 10,
216 | ):
217 | """
218 | On the assumption that images are screenshots whose value diminishes as the
219 | conversation progresses, remove all but the final `images_to_keep` tool_result
220 | images in place, removing in chunks of `min_removal_threshold` so that the
221 | implicit prompt cache is broken as rarely as possible.
222 | """
223 | if images_to_keep is None:
224 | return messages
225 |
226 | tool_result_blocks = cast(
227 | list[ToolResultBlockParam],
228 | [
229 | item
230 | for message in messages
231 | for item in (
232 | message["content"] if isinstance(message["content"], list) else []
233 | )
234 | if isinstance(item, dict) and item.get("type") == "tool_result"
235 | ],
236 | )
237 |
238 | total_images = sum(
239 | 1
240 | for tool_result in tool_result_blocks
241 | for content in tool_result.get("content", [])
242 | if isinstance(content, dict) and content.get("type") == "image"
243 | )
244 |
245 | images_to_remove = total_images - images_to_keep
246 | # for better cache behavior, we want to remove in chunks
247 | images_to_remove -= images_to_remove % min_removal_threshold
248 |
249 | for tool_result in tool_result_blocks:
250 | if isinstance(tool_result.get("content"), list):
251 | new_content = []
252 | for content in tool_result.get("content", []):
253 | if isinstance(content, dict) and content.get("type") == "image":
254 | if images_to_remove > 0:
255 | images_to_remove -= 1
256 | continue
257 | new_content.append(content)
258 | tool_result["content"] = new_content
259 |
260 |
261 | def _message_filter_callback(messages):
262 | filtered_list = []
263 | try:
264 | for msg in messages:
265 | if msg.get('role') in ['user']:
266 | if not isinstance(msg["content"], list):
267 | msg["content"] = [msg["content"]]
268 | if isinstance(msg["content"][0], TextBlock):
269 | filtered_list.append(str(msg["content"][0].text)) # User message
270 | elif isinstance(msg["content"][0], str):
271 | filtered_list.append(msg["content"][0]) # User message
272 | else:
273 | print("[_message_filter_callback]: drop message", msg)
274 | continue
275 |
276 | # elif msg.get('role') in ['assistant']:
277 | # if isinstance(msg["content"][0], TextBlock):
278 | # msg["content"][0] = str(msg["content"][0].text)
279 | # elif isinstance(msg["content"][0], BetaTextBlock):
280 | # msg["content"][0] = str(msg["content"][0].text)
281 | # elif isinstance(msg["content"][0], BetaToolUseBlock):
282 | # msg["content"][0] = str(msg['content'][0].input)
283 | # elif isinstance(msg["content"][0], Dict) and msg["content"][0]["content"][-1]["type"] == "image":
284 | # msg["content"][0] = f''
285 | # else:
286 | # print("[_message_filter_callback]: drop message", msg)
287 | # continue
288 | # filtered_list.append(msg["content"][0]) # User message
289 |
290 | else:
291 | print("[_message_filter_callback]: drop message", msg)
292 | continue
293 |
294 | except Exception as e:
295 | print("[_message_filter_callback]: error", e)
296 |
297 | return filtered_list
--------------------------------------------------------------------------------
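`APIVLMPlanner.__call__` returns the planner's answer as a JSON string whose `Next Action` field drives the rest of the loop. A small sketch of that contract with a hand-written response; the reply text is made up, and in the repo `extract_data(vlm_response, "json")` is what strips the code fence before `json.loads`:

```python
import json

# A made-up planner reply, already reduced to its JSON payload.
vlm_response_json = (
    '{"Thinking": "I need to open the browser first.", '
    '"Next Action": "CLICK \'Google Chrome icon on the desktop\'."}'
)

plan = json.loads(vlm_response_json)
next_action = plan.get("Next Action")

print(plan["Thinking"])   # shown to the user via output_callback
print(next_action)        # handed to the actor (ShowUI / UI-TARS)
print(not next_action or next_action in ("None", ""))  # True would end the loop in loop.py
```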
/computer_use_demo/gui_agent/planner/local_vlm_planner.py:
--------------------------------------------------------------------------------
1 | import json
2 | import asyncio
3 | import platform
4 | from collections.abc import Callable
5 | from datetime import datetime
6 | from enum import StrEnum
7 | from typing import Any, cast, Dict  # Callable is already imported from collections.abc above
8 |
9 | from anthropic import Anthropic, AnthropicBedrock, AnthropicVertex, APIResponse
10 | from anthropic.types import TextBlock, ToolResultBlockParam
11 | from anthropic.types.beta import BetaMessage, BetaTextBlock, BetaToolUseBlock, BetaMessageParam
12 |
13 | from computer_use_demo.tools.screen_capture import get_screenshot
14 | from computer_use_demo.gui_agent.llm_utils.llm_utils import extract_data, encode_image
15 | from computer_use_demo.tools.colorful_text import colorful_text_showui, colorful_text_vlm
16 |
17 | import torch
18 | from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
19 | from qwen_vl_utils import process_vision_info
20 |
21 | SYSTEM_PROMPT = f"""
22 | * You are utilizing a Windows system with internet access.
23 | * The current date is {datetime.today().strftime('%A, %B %d, %Y')}.
24 |
25 | """
26 |
27 | MODEL_TO_HF_PATH = {
28 | "qwen2-vl-7b-instruct": "Qwen/Qwen2-VL-7B-Instruct",
29 | "qwen2-vl-2b-instruct": "Qwen/Qwen2-VL-2B-Instruct",
30 | "qwen2.5-vl-3b-instruct": "Qwen/Qwen2.5-VL-3B-Instruct",
31 | "qwen2.5-vl-7b-instruct": "Qwen/Qwen2.5-VL-7B-Instruct",
32 | }
33 |
34 |
35 | class LocalVLMPlanner:
36 | def __init__(
37 | self,
38 | model: str,
39 | provider: str,
40 | system_prompt_suffix: str,
41 | output_callback: Callable,
42 | api_response_callback: Callable,
43 | max_tokens: int = 4096,
44 | only_n_most_recent_images: int | None = None,
45 | selected_screen: int = 0,
46 | print_usage: bool = True,
47 | device: torch.device = torch.device("cpu"),
48 | ):
49 | self.device = device
50 | self.min_pixels = 256 * 28 * 28
51 | self.max_pixels = 1344 * 28 * 28
52 | self.model_name = model
53 | if model in MODEL_TO_HF_PATH:
54 | self.hf_path = MODEL_TO_HF_PATH[model]
55 | else:
56 | raise ValueError(f"Model {model} not supported for local VLM planner")
57 |
58 | self.model = Qwen2VLForConditionalGeneration.from_pretrained(
59 | self.hf_path,
60 | torch_dtype=torch.float16,
61 | device_map="cpu"
62 | ).to(self.device)
63 | self.processor = AutoProcessor.from_pretrained(
64 | self.hf_path,
65 | min_pixels=self.min_pixels,
66 | max_pixels=self.max_pixels
67 | )
68 |
69 | self.provider = provider
70 | self.system_prompt_suffix = system_prompt_suffix
71 | self.api_response_callback = api_response_callback
72 | self.max_tokens = max_tokens
73 | self.only_n_most_recent_images = only_n_most_recent_images
74 | self.selected_screen = selected_screen
75 | self.output_callback = output_callback
76 | self.system_prompt = self._get_system_prompt() + self.system_prompt_suffix
77 |
78 | self.print_usage = print_usage
79 | self.total_token_usage = 0
80 | self.total_cost = 0
81 |
82 |
83 | def __call__(self, messages: list):
84 |
85 | # drop looping actions msg, byte image etc
86 | planner_messages = _message_filter_callback(messages)
87 | print(f"filtered_messages: {planner_messages}")
88 |
89 | # Take a screenshot
90 | screenshot, screenshot_path = get_screenshot(selected_screen=self.selected_screen)
91 | screenshot_path = str(screenshot_path)
92 | image_base64 = encode_image(screenshot_path)
93 | self.output_callback(f'Screenshot for {colorful_text_vlm}:\n',
94 | sender="bot")
95 |
96 | if isinstance(planner_messages[-1], dict):
97 | if not isinstance(planner_messages[-1]["content"], list):
98 | planner_messages[-1]["content"] = [planner_messages[-1]["content"]]
99 | planner_messages[-1]["content"].append(screenshot_path)
100 |
101 | print(f"Sending messages to VLMPlanner: {planner_messages}")
102 |
103 | messages_for_processor = [
104 | {
105 | "role": "system",
106 | "content": [{"type": "text", "text": self.system_prompt}]
107 | },
108 | {
109 | "role": "user",
110 | "content": [
111 | {"type": "image", "image": screenshot_path, "min_pixels": self.min_pixels, "max_pixels": self.max_pixels},
112 | {"type": "text", "text": f"Task: {''.join(planner_messages)}"}
113 | ],
114 | }]
115 |
116 | text = self.processor.apply_chat_template(
117 | messages_for_processor, tokenize=False, add_generation_prompt=True
118 | )
119 | image_inputs, video_inputs = process_vision_info(messages_for_processor)
120 |
121 | inputs = self.processor(
122 | text=[text],
123 | images=image_inputs,
124 | videos=video_inputs,
125 | padding=True,
126 | return_tensors="pt",
127 | )
128 | inputs = inputs.to(self.device)
129 |
130 | generated_ids = self.model.generate(**inputs, max_new_tokens=128)
131 | generated_ids_trimmed = [
132 | out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
133 | ]
134 | vlm_response = self.processor.batch_decode(
135 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
136 | )[0]
137 |
138 | print(f"VLMPlanner response: {vlm_response}")
139 |
140 | vlm_response_json = extract_data(vlm_response, "json")
141 |
142 | # vlm_plan_str = '\n'.join([f'{key}: {value}' for key, value in json.loads(response).items()])
143 | vlm_plan_str = ""
144 | for key, value in json.loads(vlm_response_json).items():
145 | if key == "Thinking":
146 | vlm_plan_str += f'{value}'
147 | else:
148 | vlm_plan_str += f'\n{key}: {value}'
149 |
150 | self.output_callback(f"{colorful_text_vlm}:\n{vlm_plan_str}", sender="bot")
151 |
152 | return vlm_response_json
153 |
154 |
155 | def _api_response_callback(self, response: APIResponse):
156 | self.api_response_callback(response)
157 |
158 |
159 | def reformat_messages(self, messages: list):
160 | pass
161 |
162 | def _get_system_prompt(self):
163 | os_name = platform.system()
164 | return f"""
165 | You are using an {os_name} device.
166 | You are able to use a mouse and keyboard to interact with the computer based on the given task and screenshot.
167 | You can only interact with the desktop GUI (no terminal or application menu access).
168 |
169 | You may be given a history of plans and actions; this is the response from the previous loop.
170 | You should carefully consider your plan based on the task, the screenshot, and the history of actions.
171 | 
172 | Your available "Next Action" types only include:
173 | - ENTER: Press the Enter key.
174 | - ESCAPE: Press the Escape key.
175 | - INPUT: Input a string of text.
176 | - CLICK: Describe the UI element to be clicked.
177 | - HOVER: Describe the UI element to be hovered.
178 | - SCROLL: Scroll the screen; you must specify up or down.
179 | - PRESS: Describe the UI element to be pressed.
180 |
181 | Output format:
182 | ```json
183 | {{
184 | "Thinking": str, # describe your thoughts on how to achieve the task; choose one action from the available actions at a time.
185 | "Next Action": "action_type, action description" | "None" # one action at a time; describe it briefly and precisely.
186 | }}
187 | ```
188 |
189 | One Example:
190 | ```json
191 | {{
192 | "Thinking": "I need to search and navigate to amazon.com.",
193 | "Next Action": "CLICK 'Search Google or type a URL'."
194 | }}
195 | ```
196 |
197 | IMPORTANT NOTES:
198 | 1. Carefully observe the screenshot to understand the current state and review the history of actions.
199 | 2. You should only give a single action at a time. For example, INPUT text and ENTER cannot be in one Next Action.
200 | 3. Attach the text to the Next Action if there is text or any description for the button.
201 | 4. You should not include other actions, such as keyboard shortcuts.
202 | 5. When the task is completed, you should output "Next Action": "None" in the JSON field.
203 | """
204 |
205 | def _message_filter_callback(messages):
206 | filtered_list = []
207 | try:
208 | for msg in messages:
209 | if msg.get('role') in ['user']:
210 | if not isinstance(msg["content"], list):
211 | msg["content"] = [msg["content"]]
212 | if isinstance(msg["content"][0], TextBlock):
213 | filtered_list.append(str(msg["content"][0].text)) # User message
214 | elif isinstance(msg["content"][0], str):
215 | filtered_list.append(msg["content"][0]) # User message
216 | else:
217 | print("[_message_filter_callback]: drop message", msg)
218 | continue
219 |
220 | else:
221 | print("[_message_filter_callback]: drop message", msg)
222 | continue
223 |
224 | except Exception as e:
225 | print("[_message_filter_callback]: error", e)
226 |
227 | return filtered_list
--------------------------------------------------------------------------------
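A minimal sketch of constructing `LocalVLMPlanner` directly, outside of `loop.py`. The callbacks are stand-ins, the model key must be one of the `MODEL_TO_HF_PATH` entries, and it assumes the Qwen2-VL weights, `transformers`, and `qwen_vl_utils` are available locally, plus a desktop session for the screenshot:

```python
import torch
from computer_use_demo.gui_agent.planner.local_vlm_planner import LocalVLMPlanner

# Stand-in callbacks; the Gradio app wires these into the chat UI.
def output_callback(text, sender="bot"):
    print(f"[{sender}] {text}")

def api_response_callback(response):
    pass

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

planner = LocalVLMPlanner(
    model="qwen2-vl-2b-instruct",   # key from MODEL_TO_HF_PATH
    provider=None,
    system_prompt_suffix="",
    output_callback=output_callback,
    api_response_callback=api_response_callback,
    selected_screen=0,
    device=device,
)

# One planning step: takes a screenshot and returns the JSON plan string,
# assuming the model answers in the format requested by the system prompt.
plan_json = planner(messages=[{"role": "user", "content": ["Open the browser."]}])
print(plan_json)
```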
/computer_use_demo/loop.py:
--------------------------------------------------------------------------------
1 | """
2 | Agentic sampling loop that calls the Anthropic API and local implementation of computer use tools.
3 | """
4 | import time
5 | import json
6 | from collections.abc import Callable
7 | from enum import StrEnum
8 |
9 | from anthropic import APIResponse
10 | from anthropic.types.beta import BetaContentBlock, BetaMessage, BetaMessageParam
11 | from computer_use_demo.tools import ToolResult
12 |
13 | from computer_use_demo.tools.colorful_text import colorful_text_showui, colorful_text_vlm
14 | from computer_use_demo.tools.screen_capture import get_screenshot
15 | from computer_use_demo.gui_agent.llm_utils.oai import encode_image
16 | from computer_use_demo.tools.logger import logger
17 |
18 |
19 |
20 | class APIProvider(StrEnum):
21 | ANTHROPIC = "anthropic"
22 | BEDROCK = "bedrock"
23 | VERTEX = "vertex"
24 | OPENAI = "openai"
25 | QWEN = "qwen"
26 | SSH = "ssh"
27 |
28 |
29 | PROVIDER_TO_DEFAULT_MODEL_NAME: dict[APIProvider, str] = {
30 | APIProvider.ANTHROPIC: "claude-3-5-sonnet-20241022",
31 | APIProvider.BEDROCK: "anthropic.claude-3-5-sonnet-20241022-v2:0",
32 | APIProvider.VERTEX: "claude-3-5-sonnet-v2@20241022",
33 | APIProvider.OPENAI: "gpt-4o",
34 | APIProvider.QWEN: "qwen2vl",
35 | APIProvider.SSH: "qwen2-vl-2b",
36 | }
37 |
38 | PLANNER_MODEL_CHOICES_MAPPING = {
39 | "claude-3-5-sonnet-20241022": "claude-3-5-sonnet-20241022",
40 | "gpt-4o": "gpt-4o",
41 | "gpt-4o-mini": "gpt-4o-mini",
42 | "qwen2-vl-max": "qwen2-vl-max",
43 | "qwen2-vl-2b (local)": "qwen2-vl-2b-instruct",
44 | "qwen2-vl-7b (local)": "qwen2-vl-7b-instruct",
45 | "qwen2.5-vl-3b (local)": "qwen2.5-vl-3b-instruct",
46 | "qwen2.5-vl-7b (local)": "qwen2.5-vl-7b-instruct",
47 | "qwen2-vl-2b (ssh)": "qwen2-vl-2b (ssh)",
48 | "qwen2-vl-7b (ssh)": "qwen2-vl-7b (ssh)",
49 | }
50 |
51 |
52 | def sampling_loop_sync(
53 | *,
54 | planner_model: str,
55 | planner_provider: APIProvider | None,
56 | actor_model: str,
57 | actor_provider: APIProvider | None,
58 | system_prompt_suffix: str,
59 | messages: list[BetaMessageParam],
60 | output_callback: Callable[[BetaContentBlock], None],
61 | tool_output_callback: Callable[[ToolResult, str], None],
62 | api_response_callback: Callable[[APIResponse[BetaMessage]], None],
63 | api_key: str,
64 | only_n_most_recent_images: int | None = None,
65 | max_tokens: int = 4096,
66 | selected_screen: int = 0,
67 | showui_max_pixels: int = 1344,
68 | showui_awq_4bit: bool = False,
69 | ui_tars_url: str = ""
70 | ):
71 | """
72 | Synchronous agentic sampling loop for the assistant/tool interaction of computer use.
73 | """
74 |
75 | # ---------------------------
76 | # Initialize Planner
77 | # ---------------------------
78 |
79 | if planner_model in PLANNER_MODEL_CHOICES_MAPPING:
80 | planner_model = PLANNER_MODEL_CHOICES_MAPPING[planner_model]
81 | else:
82 | raise ValueError(f"Planner Model {planner_model} not supported")
83 |
84 | if planner_model == "claude-3-5-sonnet-20241022":
85 |
86 | from computer_use_demo.gui_agent.planner.anthropic_agent import AnthropicActor
87 | from computer_use_demo.executor.anthropic_executor import AnthropicExecutor
88 |
89 | # Register Actor and Executor
90 | actor = AnthropicActor(
91 | model=planner_model,
92 | provider=actor_provider,
93 | system_prompt_suffix=system_prompt_suffix,
94 | api_key=api_key,
95 | api_response_callback=api_response_callback,
96 | max_tokens=max_tokens,
97 | only_n_most_recent_images=only_n_most_recent_images,
98 | selected_screen=selected_screen
99 | )
100 |
101 | executor = AnthropicExecutor(
102 | output_callback=output_callback,
103 | tool_output_callback=tool_output_callback,
104 | selected_screen=selected_screen
105 | )
106 |
107 | loop_mode = "unified"
108 |
109 | elif planner_model in ["gpt-4o", "gpt-4o-mini", "qwen2-vl-max"]:
110 |
111 | from computer_use_demo.gui_agent.planner.api_vlm_planner import APIVLMPlanner
112 |
113 | planner = APIVLMPlanner(
114 | model=planner_model,
115 | provider=planner_provider,
116 | system_prompt_suffix=system_prompt_suffix,
117 | api_key=api_key,
118 | api_response_callback=api_response_callback,
119 | selected_screen=selected_screen,
120 | output_callback=output_callback,
121 | )
122 | loop_mode = "planner + actor"
123 |
124 | elif planner_model in ["qwen2-vl-2b-instruct", "qwen2-vl-7b-instruct", "qwen2.5-vl-3b-instruct", "qwen2.5-vl-7b-instruct"]:
125 |
126 | import torch
127 | from computer_use_demo.gui_agent.planner.local_vlm_planner import LocalVLMPlanner
128 | if torch.cuda.is_available(): device = torch.device("cuda")
129 | elif torch.backends.mps.is_available(): device = torch.device("mps")
130 | else: device = torch.device("cpu") # support: 'cpu', 'mps', 'cuda'
131 | logger.info(f"Planner model {planner_model} inited on device: {device}.")
132 |
133 | planner = LocalVLMPlanner(
134 | model=planner_model,
135 | provider=planner_provider,
136 | system_prompt_suffix=system_prompt_suffix,
137 | # no api_key here: LocalVLMPlanner runs the model locally and does not accept one
138 | api_response_callback=api_response_callback,
139 | selected_screen=selected_screen,
140 | output_callback=output_callback,
141 | device=device
142 | )
143 | loop_mode = "planner + actor"
144 |
145 | elif "ssh" in planner_model:
146 | from computer_use_demo.gui_agent.planner.api_vlm_planner import APIVLMPlanner  # imported here as well, since the import above is local to the gpt-4o branch
146 | planner = APIVLMPlanner(
147 | model=planner_model,
148 | provider=planner_provider,
149 | system_prompt_suffix=system_prompt_suffix,
150 | api_key=api_key,
151 | api_response_callback=api_response_callback,
152 | selected_screen=selected_screen,
153 | output_callback=output_callback,
154 | )
155 | loop_mode = "planner + actor"
156 | else:
157 | logger.error(f"Planner Model {planner_model} not supported")
158 | raise ValueError(f"Planner Model {planner_model} not supported")
159 |
160 |
161 | # ---------------------------
162 | # Initialize Actor, Executor
163 | # ---------------------------
164 | if actor_model == "ShowUI":
165 |
166 | from computer_use_demo.executor.showui_executor import ShowUIExecutor
167 | from computer_use_demo.gui_agent.actor.showui_agent import ShowUIActor
168 | if showui_awq_4bit:
169 | showui_model_path = "./showui-2b-awq-4bit/"
170 | else:
171 | showui_model_path = "./showui-2b/"
172 |
173 | import torch
174 | if torch.cuda.is_available(): device = torch.device("cuda")
175 | elif torch.backends.mps.is_available(): device = torch.device("mps")
176 | else: device = torch.device("cpu") # support: 'cpu', 'mps', 'cuda'
177 | logger.info(f"Actor model {actor_model} inited on device: {device}.")
178 |
179 | actor = ShowUIActor(
180 | model_path=showui_model_path,
181 | device=device,
182 | split='desktop', # 'desktop' or 'phone'
183 | selected_screen=selected_screen,
184 | output_callback=output_callback,
185 | max_pixels=showui_max_pixels,
186 | awq_4bit=showui_awq_4bit
187 | )
188 |
189 | executor = ShowUIExecutor(
190 | output_callback=output_callback,
191 | tool_output_callback=tool_output_callback,
192 | selected_screen=selected_screen
193 | )
194 |
195 | elif actor_model == "UI-TARS":
196 |
197 | from computer_use_demo.executor.showui_executor import ShowUIExecutor
198 | from computer_use_demo.gui_agent.actor.uitars_agent import UITARS_Actor
199 |
200 | actor = UITARS_Actor(
201 | ui_tars_url=ui_tars_url,
202 | output_callback=output_callback,
203 | selected_screen=selected_screen
204 | )
205 |
206 | executor = ShowUIExecutor(
207 | output_callback=output_callback,
208 | tool_output_callback=tool_output_callback,
209 | selected_screen=selected_screen
210 | )
211 |
212 | elif actor_model == "claude-3-5-sonnet-20241022":
213 | loop_mode = "unified"
214 |
215 | else:
216 | raise ValueError(f"Actor Model {actor_model} not supported")
217 |
218 |
219 | tool_result_content = None
220 | showui_loop_count = 0
221 |
222 | logger.info(f"Start the message loop. User messages: {messages}")
223 |
224 | if loop_mode == "unified":
225 | # ------------------------------
226 | # Unified loop:
227 | # 1) repeatedly call actor -> executor -> check tool_result -> maybe end
228 | # ------------------------------
229 | while True:
230 | # Call the actor with current messages
231 | response = actor(messages=messages)
232 |
233 | # Let the executor process that response, yielding any intermediate messages
234 | for message, tool_result_content in executor(response, messages):
235 | yield message
236 |
237 | # If executor didn't produce further content, we're done
238 | if not tool_result_content:
239 | return messages
240 |
241 | # If there is more tool content, treat that as user input
242 | messages.append({
243 | "content": tool_result_content,
244 | "role": "user"
245 | })
246 |
247 | elif loop_mode == "planner + actor":
248 | # ------------------------------------------------------
249 | # Planner + actor loop:
250 | # 1) planner => get next_action
251 | # 2) If no next_action -> end
252 | # 3) Otherwise actor => executor
253 | # 4) repeat
254 | # ------------------------------------------------------
255 | while True:
256 | # Step 1: Planner (VLM) response
257 | vlm_response = planner(messages=messages)
258 |
259 | # Step 2: Extract the "Next Action" from the planner output
260 | next_action = json.loads(vlm_response).get("Next Action")
261 |
262 | # Yield the next_action string, in case the UI or logs want to show it
263 | yield next_action
264 |
265 | # Step 3: Check if there are no further actions
266 | if not next_action or next_action in ("None", ""):
267 | final_sc, final_sc_path = get_screenshot(selected_screen=selected_screen)
268 | final_image_b64 = encode_image(str(final_sc_path))
269 |
270 | output_callback(
271 | (
272 | f"No more actions from {colorful_text_vlm}. End of task. Final State:\n"
273 | f'<img src="data:image/png;base64,{final_image_b64}">'
274 | ),
275 | sender="bot"
276 | )
277 | yield None
278 | break
279 |
280 | # Step 4: Output an action message
281 | output_callback(
282 | f"{colorful_text_vlm} sending action to {colorful_text_showui}:\n{next_action}",
283 | sender="bot"
284 | )
285 |
286 | # Step 5: Actor response
287 | actor_response = actor(messages=next_action)
288 | yield actor_response
289 |
290 | # Step 6: Execute the actor response
291 | for message, tool_result_content in executor(actor_response, messages):
292 | time.sleep(0.5) # optional small delay
293 | yield message
294 |
295 | # Step 7: Update conversation with embedding history of plan and actions
296 | messages.append({
297 | "role": "user",
298 | "content": [
299 | "History plan:" + str(json.loads(vlm_response)),
300 | "History actions:" + str(actor_response["content"])
301 | ]
302 | })
303 |
304 | logger.info(
305 | f"End of loop. Total cost: $USD{planner.total_cost:.5f}"
306 | )
307 |
308 |
309 | # Increment loop counter
310 | showui_loop_count += 1
311 |
--------------------------------------------------------------------------------
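`sampling_loop_sync` is a generator; the Gradio app in `app.py` drives it by iterating and rendering each yielded message. A minimal driver sketch with stand-in callbacks; the API key is a placeholder, and the `ShowUI` actor assumes the weights were fetched with the scripts in `install_tools/`:

```python
from computer_use_demo.loop import sampling_loop_sync, APIProvider

def output_callback(block, sender="bot"):
    print(f"[{sender}] {block}")

def tool_output_callback(tool_result, tool_id):
    print(f"[tool {tool_id}] {tool_result}")

def api_response_callback(response):
    pass

messages = [{"role": "user", "content": "Open Chrome and search for the weather."}]

for step in sampling_loop_sync(
    planner_model="gpt-4o",
    planner_provider=APIProvider.OPENAI,
    actor_model="ShowUI",
    actor_provider=None,
    system_prompt_suffix="",
    messages=messages,
    output_callback=output_callback,
    tool_output_callback=tool_output_callback,
    api_response_callback=api_response_callback,
    api_key="<your-openai-key>",        # placeholder
    only_n_most_recent_images=3,
    selected_screen=0,
):
    # The planner+actor loop yields None once the planner reports "Next Action": "None".
    if step is None:
        break
```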
/computer_use_demo/remote_inference.py:
--------------------------------------------------------------------------------
1 | from contextlib import asynccontextmanager
2 | from fastapi import FastAPI, HTTPException
3 | from fastapi.middleware.cors import CORSMiddleware
4 | from pydantic import BaseModel, field_validator
5 | from typing import Optional, List, Union, Dict, Any
6 | import torch
7 | from transformers import (
8 | Qwen2_5_VLForConditionalGeneration,
9 | Qwen2VLForConditionalGeneration,
10 | AutoProcessor,
11 | BitsAndBytesConfig
12 | )
13 | from qwen_vl_utils import process_vision_info
14 | import uvicorn
15 | import json
16 | from datetime import datetime
17 | import logging
18 | import time
19 | import psutil
20 | import GPUtil
21 | import base64
22 | from PIL import Image
23 | import io
24 | import os
25 | import threading
26 |
27 | # Environment variables: keep CUDA kernel launches asynchronous and restrict CUDA extension builds to compute capability 8.0 (A5000-class GPUs)
28 | os.environ["CUDA_LAUNCH_BLOCKING"] = "0"
29 | os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0" # Compatible with A5000
30 |
31 | # Model configuration
32 | MODELS = {
33 | "Qwen2.5-VL-7B-Instruct": {
34 | "path": "Qwen/Qwen2.5-VL-7B-Instruct",
35 | "model_class": Qwen2_5_VLForConditionalGeneration,
36 | },
37 | "Qwen2-VL-7B-Instruct": {
38 | "path": "Qwen/Qwen2-VL-7B-Instruct",
39 | "model_class": Qwen2VLForConditionalGeneration,
40 | },
41 | "Qwen2-VL-2B-Instruct": {
42 | "path": "Qwen/Qwen2-VL-2B-Instruct",
43 | "model_class": Qwen2VLForConditionalGeneration,
44 | }
45 | }
46 |
47 | # Configure logging
48 | logging.basicConfig(
49 | level=logging.INFO,
50 | format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
51 | )
52 | logger = logging.getLogger(__name__)
53 |
54 | # Global variables
55 | models = {}
56 | processors = {}
57 | model_locks = {} # Thread locks for model loading
58 | last_used = {} # Record last use time of models
59 |
60 | # Set default CUDA device
61 | if torch.cuda.is_available():
62 | # Get GPU information and select the device with maximum memory
63 | gpus = GPUtil.getGPUs()
64 | if gpus:
65 | max_memory_gpu = max(gpus, key=lambda g: g.memoryTotal)
66 | selected_device = max_memory_gpu.id
67 | torch.cuda.set_device(selected_device)
68 | device = torch.device(f"cuda:{selected_device}")
69 | logger.info(f"Selected GPU {selected_device} ({max_memory_gpu.name}) with {max_memory_gpu.memoryTotal}MB memory")
70 | else:
71 | device = torch.device("cuda:0")
72 | else:
73 | device = torch.device("cpu")
74 | logger.info(f"Using device: {device}")
75 |
76 | class ImageURL(BaseModel):
77 | url: str
78 |
79 | class MessageContent(BaseModel):
80 | type: str
81 | text: Optional[str] = None
82 | image_url: Optional[Dict[str, str]] = None
83 |
84 | @field_validator('type')
85 | @classmethod
86 | def validate_type(cls, v: str) -> str:
87 | if v not in ['text', 'image_url']:
88 | raise ValueError(f"Invalid content type: {v}")
89 | return v
90 |
91 | class ChatMessage(BaseModel):
92 | role: str
93 | content: Union[str, List[MessageContent]]
94 |
95 | @field_validator('role')
96 | @classmethod
97 | def validate_role(cls, v: str) -> str:
98 | if v not in ['system', 'user', 'assistant']:
99 | raise ValueError(f"Invalid role: {v}")
100 | return v
101 |
102 | @field_validator('content')
103 | @classmethod
104 | def validate_content(cls, v: Union[str, List[Any]]) -> Union[str, List[MessageContent]]:
105 | if isinstance(v, str):
106 | return v
107 | if isinstance(v, list):
108 | return [MessageContent(**item) if isinstance(item, dict) else item for item in v]
109 | raise ValueError("Content must be either a string or a list of content items")
110 |
111 | class ChatCompletionRequest(BaseModel):
112 | model: str
113 | messages: List[ChatMessage]
114 | temperature: Optional[float] = 0.7
115 | top_p: Optional[float] = 0.95
116 | max_tokens: Optional[int] = 2048
117 | stream: Optional[bool] = False
118 | response_format: Optional[Dict[str, str]] = None
119 |
120 | class ChatCompletionResponse(BaseModel):
121 | id: str
122 | object: str
123 | created: int
124 | model: str
125 | choices: List[Dict[str, Any]]
126 | usage: Dict[str, int]
127 |
128 | class ModelCard(BaseModel):
129 | id: str
130 | created: int
131 | owned_by: str
132 | permission: List[Dict[str, Any]] = []
133 | root: Optional[str] = None
134 | parent: Optional[str] = None
135 | capabilities: Optional[Dict[str, bool]] = None
136 | context_window: Optional[int] = None
137 | max_tokens: Optional[int] = None
138 |
139 | class ModelList(BaseModel):
140 | object: str = "list"
141 | data: List[ModelCard]
142 |
143 | def process_base64_image(base64_string: str) -> Image.Image:
144 | """Process base64 image data and return PIL Image"""
145 | try:
146 | # Remove data URL prefix if present
147 | if 'base64,' in base64_string:
148 | base64_string = base64_string.split('base64,')[1]
149 |
150 | image_data = base64.b64decode(base64_string)
151 | image = Image.open(io.BytesIO(image_data))
152 |
153 | # Convert to RGB if necessary
154 | if image.mode not in ('RGB', 'L'):
155 | image = image.convert('RGB')
156 |
157 | return image
158 | except Exception as e:
159 | logger.error(f"Error processing base64 image: {str(e)}")
160 | raise ValueError(f"Invalid base64 image data: {str(e)}")
161 |
162 | def log_system_info():
163 | """Log system resource information"""
164 | try:
165 | cpu_percent = psutil.cpu_percent(interval=1)
166 | memory = psutil.virtual_memory()
167 | gpu_info = []
168 | if torch.cuda.is_available():
169 | for gpu in GPUtil.getGPUs():
170 | gpu_info.append({
171 | 'id': gpu.id,
172 | 'name': gpu.name,
173 | 'load': f"{gpu.load*100}%",
174 | 'memory_used': f"{gpu.memoryUsed}MB/{gpu.memoryTotal}MB",
175 | 'temperature': f"{gpu.temperature}°C"
176 | })
177 | logger.info(f"System Info - CPU: {cpu_percent}%, RAM: {memory.percent}%, "
178 | f"Available RAM: {memory.available/1024/1024/1024:.1f}GB")
179 | if gpu_info:
180 | logger.info(f"GPU Info: {gpu_info}")
181 | except Exception as e:
182 | logger.warning(f"Failed to log system info: {str(e)}")
183 |
184 | def get_or_initialize_model(model_name: str):
185 | """Get or initialize a model if not already loaded"""
186 | global models, processors, model_locks, last_used
187 |
188 | if model_name not in MODELS:
189 | available_models = list(MODELS.keys())
190 | raise ValueError(f"Unsupported model: {model_name}\nAvailable models: {available_models}")
191 |
192 | # Initialize lock for the model (if not already done)
193 | if model_name not in model_locks:
194 | model_locks[model_name] = threading.Lock()
195 |
196 | with model_locks[model_name]:
197 | if model_name not in models or model_name not in processors:
198 | try:
199 | start_time = time.time()
200 | logger.info(f"Starting {model_name} initialization...")
201 | log_system_info()
202 |
203 | model_config = MODELS[model_name]
204 |
205 | # Configure 8-bit quantization (the bnb_4bit_* options below only apply to 4-bit loading and are ignored in 8-bit mode)
206 | quantization_config = BitsAndBytesConfig(
207 | load_in_8bit=True,
208 | bnb_4bit_compute_dtype=torch.float16,
209 | bnb_4bit_use_double_quant=False,
210 | bnb_4bit_quant_type="nf4",
211 | )
212 |
213 | logger.info(f"Loading {model_name} with 8-bit quantization...")
214 | model = model_config["model_class"].from_pretrained(
215 | model_config["path"],
216 | quantization_config=quantization_config,
217 | device_map={"": device.index if device.type == "cuda" else "cpu"},
218 | local_files_only=False
219 | ).eval()
220 |
221 | processor = AutoProcessor.from_pretrained(
222 | model_config["path"],
223 | local_files_only=False
224 | )
225 |
226 | models[model_name] = model
227 | processors[model_name] = processor
228 |
229 | end_time = time.time()
230 | logger.info(f"Model {model_name} initialized in {end_time - start_time:.2f} seconds")
231 | log_system_info()
232 |
233 | except Exception as e:
234 | logger.error(f"Model initialization error for {model_name}: {str(e)}", exc_info=True)
235 | raise RuntimeError(f"Failed to initialize model {model_name}: {str(e)}")
236 |
237 | # Update last use time
238 | last_used[model_name] = time.time()
239 |
240 | return models[model_name], processors[model_name]
241 |
242 | @asynccontextmanager
243 | async def lifespan(app: FastAPI):
244 | logger.info("Starting application initialization...")
245 | try:
246 | yield
247 | finally:
248 | logger.info("Shutting down application...")
249 | global models, processors
250 | for model_name, model in models.items():
251 | try:
252 | del model
253 | logger.info(f"Model {model_name} unloaded")
254 | except Exception as e:
255 | logger.error(f"Error during cleanup of {model_name}: {str(e)}")
256 |
257 | if torch.cuda.is_available():
258 | torch.cuda.empty_cache()
259 | logger.info("CUDA cache cleared")
260 |
261 | models = {}
262 | processors = {}
263 | logger.info("Shutdown complete")
264 |
265 | app = FastAPI(
266 | title="Qwen2.5-VL API",
267 | description="OpenAI-compatible API for Qwen2.5-VL vision-language model",
268 | version="1.0.0",
269 | lifespan=lifespan
270 | )
271 |
272 | app.add_middleware(
273 | CORSMiddleware,
274 | allow_origins=["*"],
275 | allow_credentials=True,
276 | allow_methods=["*"],
277 | allow_headers=["*"],
278 | )
279 |
280 | @app.get("/v1/models", response_model=ModelList)
281 | async def list_models():
282 | """List available models"""
283 | model_cards = []
284 | for model_name in MODELS.keys():
285 | model_cards.append(
286 | ModelCard(
287 | id=model_name,
288 | created=1709251200,
289 | owned_by="Qwen",
290 | permission=[{
291 | "id": f"modelperm-{model_name}",
292 | "created": 1709251200,
293 | "allow_create_engine": False,
294 | "allow_sampling": True,
295 | "allow_logprobs": True,
296 | "allow_search_indices": False,
297 | "allow_view": True,
298 | "allow_fine_tuning": False,
299 | "organization": "*",
300 | "group": None,
301 | "is_blocking": False
302 | }],
303 | capabilities={
304 | "vision": True,
305 | "chat": True,
306 | "embeddings": False,
307 | "text_completion": True
308 | },
309 | context_window=4096,
310 | max_tokens=2048
311 | )
312 | )
313 | return ModelList(data=model_cards)
314 |
315 | @app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
316 | async def chat_completions(request: ChatCompletionRequest):
317 | """Handle chat completion requests with vision support"""
318 | try:
319 | # Get or initialize requested model
320 | model, processor = get_or_initialize_model(request.model)
321 |
322 | request_start_time = time.time()
323 | logger.info(f"Received chat completion request for model: {request.model}")
324 | logger.info(f"Request content: {request.model_dump_json()}")
325 |
326 | messages = []
327 | for msg in request.messages:
328 | if isinstance(msg.content, str):
329 | messages.append({"role": msg.role, "content": msg.content})
330 | else:
331 | processed_content = []
332 | for content_item in msg.content:
333 | if content_item.type == "text":
334 | processed_content.append({
335 | "type": "text",
336 | "text": content_item.text
337 | })
338 | elif content_item.type == "image_url":
339 | if "url" in content_item.image_url:
340 | if content_item.image_url["url"].startswith("data:image"):
341 | processed_content.append({
342 | "type": "image",
343 | "image": process_base64_image(content_item.image_url["url"])
344 | })
345 | messages.append({"role": msg.role, "content": processed_content})
346 |
347 | text = processor.apply_chat_template(
348 | messages,
349 | tokenize=False,
350 | add_generation_prompt=True
351 | )
352 |
353 | image_inputs, video_inputs = process_vision_info(messages)
354 |
355 | # Ensure input data is on the correct device
356 | inputs = processor(
357 | text=[text],
358 | images=image_inputs,
359 | videos=video_inputs,
360 | padding=True,
361 | return_tensors="pt"
362 | )
363 |
364 | # Move all tensors to specified device
365 | input_tensors = {k: v.to(device) if hasattr(v, 'to') else v for k, v in inputs.items()}
366 |
367 | with torch.inference_mode():
368 | generated_ids = model.generate(
369 | **input_tensors,
370 | max_new_tokens=request.max_tokens,
371 | temperature=request.temperature,
372 | top_p=request.top_p,
373 | pad_token_id=processor.tokenizer.pad_token_id,
374 | eos_token_id=processor.tokenizer.eos_token_id
375 | )
376 |
377 | # Get input length and trim generated IDs
378 | input_length = input_tensors['input_ids'].shape[1]
379 | generated_ids_trimmed = generated_ids[:, input_length:]
380 |
381 | response = processor.batch_decode(
382 | generated_ids_trimmed,
383 | skip_special_tokens=True,
384 | clean_up_tokenization_spaces=False
385 | )[0]
386 |
387 | if request.response_format and request.response_format.get("type") == "json_object":
388 | try:
389 | if response.startswith('```'):
390 | response = '\n'.join(response.split('\n')[1:-1])
391 | if response.startswith('json'):
392 | response = response[4:].lstrip()
393 | content = json.loads(response)
394 | response = json.dumps(content)
395 | except json.JSONDecodeError as e:
396 | logger.error(f"JSON parsing error: {str(e)}")
397 | raise HTTPException(status_code=400, detail=f"Invalid JSON response: {str(e)}")
398 |
399 | total_time = time.time() - request_start_time
400 | logger.info(f"Request completed in {total_time:.2f} seconds")
401 |
402 | return ChatCompletionResponse(
403 | id=f"chatcmpl-{datetime.now().strftime('%Y%m%d%H%M%S')}",
404 | object="chat.completion",
405 | created=int(datetime.now().timestamp()),
406 | model=request.model,
407 | choices=[{
408 | "index": 0,
409 | "message": {
410 | "role": "assistant",
411 | "content": response
412 | },
413 | "finish_reason": "stop"
414 | }],
415 | usage={
416 | "prompt_tokens": input_length,
417 | "completion_tokens": len(generated_ids_trimmed[0]),
418 | "total_tokens": input_length + len(generated_ids_trimmed[0])
419 | }
420 | )
421 | except Exception as e:
422 | logger.error(f"Request error: {str(e)}", exc_info=True)
423 | if isinstance(e, HTTPException):
424 | raise
425 | raise HTTPException(status_code=500, detail=str(e))
426 |
427 | @app.get("/health")
428 | async def health_check():
429 | """Health check endpoint"""
430 | log_system_info()
431 | return {
432 | "status": "healthy",
433 | "loaded_models": list(models.keys()),
434 | "device": str(device),
435 | "cuda_available": torch.cuda.is_available(),
436 | "cuda_device_count": torch.cuda.device_count() if torch.cuda.is_available() else 0,
437 | "timestamp": datetime.now().isoformat()
438 | }
439 |
440 | @app.get("/model_status")
441 | async def model_status():
442 | """Get the status of all models"""
443 | status = {}
444 | for model_name in MODELS:
445 | status[model_name] = {
446 | "loaded": model_name in models,
447 | "last_used": last_used.get(model_name, None),
448 | "available": model_name in MODELS
449 | }
450 | return status
451 |
452 | if __name__ == "__main__":
453 | uvicorn.run(app, host="0.0.0.0", port=9192)
--------------------------------------------------------------------------------
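Because the service mimics the OpenAI chat-completions schema on port 9192, any HTTP client can exercise it. An illustrative sketch using `requests`; the host, screenshot file, and prompt are placeholders, and the model name must be one of the keys in `MODELS`:

```python
import base64
import requests

# Placeholder screenshot; any local PNG works.
with open("screenshot.png", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode()

payload = {
    "model": "Qwen2-VL-2B-Instruct",   # must match a key in MODELS
    "messages": [
        {"role": "system", "content": "You are a GUI grounding assistant."},
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Describe what is on this screen."},
                {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_b64}"}},
            ],
        },
    ],
    "max_tokens": 256,
}

resp = requests.post("http://localhost:9192/v1/chat/completions", json=payload, timeout=300)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```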
/computer_use_demo/tools/__init__.py:
--------------------------------------------------------------------------------
1 | from .base import CLIResult, ToolResult
2 | from .bash import BashTool
3 | from .collection import ToolCollection
4 | from .computer import ComputerTool
5 | from .edit import EditTool
6 | from .screen_capture import get_screenshot
7 |
8 | __all__ = [
9 | "BashTool",
10 | "CLIResult",
11 | "ComputerTool",
12 | "EditTool",
13 | "ToolCollection",
14 | "ToolResult",
15 | "get_screenshot",
16 | ]
17 |
--------------------------------------------------------------------------------
/computer_use_demo/tools/base.py:
--------------------------------------------------------------------------------
1 | from abc import ABCMeta, abstractmethod
2 | from dataclasses import dataclass, fields, replace
3 | from typing import Any
4 |
5 | from anthropic.types.beta import BetaToolUnionParam
6 |
7 |
8 | class BaseAnthropicTool(metaclass=ABCMeta):
9 | """Abstract base class for Anthropic-defined tools."""
10 |
11 | @abstractmethod
12 | def __call__(self, **kwargs) -> Any:
13 | """Executes the tool with the given arguments."""
14 | ...
15 |
16 | @abstractmethod
17 | def to_params(
18 | self,
19 | ) -> BetaToolUnionParam:
20 | raise NotImplementedError
21 |
22 |
23 | @dataclass(kw_only=True, frozen=True)
24 | class ToolResult:
25 | """Represents the result of a tool execution."""
26 |
27 | output: str | None = None
28 | error: str | None = None
29 | base64_image: str | None = None
30 | system: str | None = None
31 |
32 | def __bool__(self):
33 | return any(getattr(self, field.name) for field in fields(self))
34 |
35 | def __add__(self, other: "ToolResult"):
36 | def combine_fields(
37 | field: str | None, other_field: str | None, concatenate: bool = True
38 | ):
39 | if field and other_field:
40 | if concatenate:
41 | return field + other_field
42 | raise ValueError("Cannot combine tool results")
43 | return field or other_field
44 |
45 | return ToolResult(
46 | output=combine_fields(self.output, other.output),
47 | error=combine_fields(self.error, other.error),
48 | base64_image=combine_fields(self.base64_image, other.base64_image, False),
49 | system=combine_fields(self.system, other.system),
50 | )
51 |
52 | def replace(self, **kwargs):
53 | """Returns a new ToolResult with the given fields replaced."""
54 | return replace(self, **kwargs)
55 |
56 |
57 | class CLIResult(ToolResult):
58 | """A ToolResult that can be rendered as a CLI output."""
59 |
60 |
61 | class ToolFailure(ToolResult):
62 | """A ToolResult that represents a failure."""
63 |
64 |
65 | class ToolError(Exception):
66 | """Raised when a tool encounters an error."""
67 |
68 | def __init__(self, message):
69 | self.message = message
70 |
--------------------------------------------------------------------------------
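The `ToolResult` dataclass above supports truthiness, merging with `+`, and field replacement. A short illustrative sketch of those semantics (the strings are dummies):

```python
from computer_use_demo.tools.base import ToolResult

empty = ToolResult()
print(bool(empty))                      # False: no field is set

stdout = ToolResult(output="listing: 3 files")
stderr = ToolResult(error="warning: slow disk")
merged = stdout + stderr                # distinct fields are carried over; shared text fields concatenate
print(merged.output, "|", merged.error)

# base64_image is combined with concatenate=False, so two images cannot be merged.
try:
    ToolResult(base64_image="aaaa") + ToolResult(base64_image="bbbb")
except ValueError as e:
    print(e)                            # "Cannot combine tool results"

print(stdout.replace(output="patched").output)  # frozen dataclass, so replace() returns a new instance
```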
/computer_use_demo/tools/bash.py:
--------------------------------------------------------------------------------
1 | import asyncio
2 | import os
3 | from typing import ClassVar, Literal
4 |
5 | from anthropic.types.beta import BetaToolBash20241022Param
6 |
7 | from .base import BaseAnthropicTool, CLIResult, ToolError, ToolResult
8 |
9 |
10 | class _BashSession:
11 | """A session of a bash shell."""
12 |
13 | _started: bool
14 | _process: asyncio.subprocess.Process
15 |
16 | command: str = "/bin/bash"
17 | _output_delay: float = 0.2 # seconds
18 | _timeout: float = 120.0 # seconds
19 | _sentinel: str = "<<exit>>"  # unique end-of-output marker echoed after each command
20 |
21 | def __init__(self):
22 | self._started = False
23 | self._timed_out = False
24 |
25 | async def start(self):
26 | if self._started:
27 | return
28 |
29 | self._process = await asyncio.create_subprocess_shell(
30 | self.command,
31 | # no shell kwarg: create_subprocess_shell always runs the command through the shell and raises ValueError if shell=False is passed
32 | stdin=asyncio.subprocess.PIPE,
33 | stdout=asyncio.subprocess.PIPE,
34 | stderr=asyncio.subprocess.PIPE,
35 | )
36 |
37 | self._started = True
38 |
39 | def stop(self):
40 | """Terminate the bash shell."""
41 | if not self._started:
42 | raise ToolError("Session has not started.")
43 | if self._process.returncode is not None:
44 | return
45 | self._process.terminate()
46 |
47 | async def run(self, command: str):
48 | """Execute a command in the bash shell."""
49 | if not self._started:
50 | raise ToolError("Session has not started.")
51 | if self._process.returncode is not None:
52 | return ToolResult(
53 | system="tool must be restarted",
54 | error=f"bash has exited with returncode {self._process.returncode}",
55 | )
56 | if self._timed_out:
57 | raise ToolError(
58 | f"timed out: bash has not returned in {self._timeout} seconds and must be restarted",
59 | )
60 |
61 | # we know these are not None because we created the process with PIPEs
62 | assert self._process.stdin
63 | assert self._process.stdout
64 | assert self._process.stderr
65 |
66 | # send command to the process
67 | self._process.stdin.write(
68 | command.encode() + f"; echo '{self._sentinel}'\n".encode()
69 | )
70 | await self._process.stdin.drain()
71 |
72 | # read output from the process, until the sentinel is found
73 | output = ""
74 | try:
75 | async with asyncio.timeout(self._timeout):
76 | while True:
77 | await asyncio.sleep(self._output_delay)
78 | data = await self._process.stdout.readline()
79 | if not data:
80 | break
81 | line = data.decode()
82 | output += line
83 | if self._sentinel in line:
84 | output = output.replace(self._sentinel, "")
85 | break
86 | except asyncio.TimeoutError:
87 | self._timed_out = True
88 | raise ToolError(
89 | f"timed out: bash has not returned in {self._timeout} seconds and must be restarted",
90 | ) from None
91 |
92 | error = self._process.stderr._buffer.decode()  # read what is buffered; stderr.read() would block until the shell exits
93 | self._process.stderr._buffer.clear()
94 |
95 | return CLIResult(output=output.strip(), error=error.strip())
96 |
97 |
98 | class BashTool(BaseAnthropicTool):
99 | """
100 | A tool that allows the agent to run bash commands.
101 | The tool parameters are defined by Anthropic and are not editable.
102 | """
103 |
104 | _session: _BashSession | None
105 | name: ClassVar[Literal["bash"]] = "bash"
106 | api_type: ClassVar[Literal["bash_20241022"]] = "bash_20241022"
107 |
108 | def __init__(self):
109 | self._session = None
110 | super().__init__()
111 |
112 | async def __call__(
113 | self, command: str | None = None, restart: bool = False, **kwargs
114 | ):
115 | if restart:
116 | if self._session:
117 | self._session.stop()
118 | self._session = _BashSession()
119 | await self._session.start()
120 |
121 | return ToolResult(system="tool has been restarted.")
122 |
123 | if self._session is None:
124 | self._session = _BashSession()
125 | await self._session.start()
126 |
127 | if command is not None:
128 | return await self._session.run(command)
129 |
130 | raise ToolError("no command provided.")
131 |
132 | def to_params(self) -> BetaToolBash20241022Param:
133 | return {
134 | "type": self.api_type,
135 | "name": self.name,
136 | }
--------------------------------------------------------------------------------
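`BashTool` keeps a single persistent `/bin/bash` session alive between calls, so state such as the working directory survives across commands. A minimal usage sketch, POSIX-only since it spawns `/bin/bash`:

```python
import asyncio
from computer_use_demo.tools.bash import BashTool

async def main():
    bash = BashTool()

    # The first call lazily starts the session; the sentinel echo marks end-of-output.
    result = await bash(command="echo hello && pwd")
    print(result.output)

    # State persists across calls within the same session.
    await bash(command="cd /tmp")
    print((await bash(command="pwd")).output)    # /tmp

    # restart=True tears down the old shell and starts a fresh one.
    print((await bash(restart=True)).system)     # "tool has been restarted."

asyncio.run(main())
```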
/computer_use_demo/tools/collection.py:
--------------------------------------------------------------------------------
1 | """Collection classes for managing multiple tools."""
2 |
3 | from typing import Any
4 |
5 | from anthropic.types.beta import BetaToolUnionParam
6 |
7 | from .base import (
8 | BaseAnthropicTool,
9 | ToolError,
10 | ToolFailure,
11 | ToolResult,
12 | )
13 |
14 |
15 | class ToolCollection:
16 | """A collection of anthropic-defined tools."""
17 |
18 | def __init__(self, *tools: BaseAnthropicTool):
19 | self.tools = tools
20 | self.tool_map = {tool.to_params()["name"]: tool for tool in tools}
21 |
22 | def to_params(
23 | self,
24 | ) -> list[BetaToolUnionParam]:
25 | return [tool.to_params() for tool in self.tools]
26 |
27 | async def run(self, *, name: str, tool_input: dict[str, Any]) -> ToolResult:
28 | tool = self.tool_map.get(name)
29 | if not tool:
30 | return ToolFailure(error=f"Tool {name} is invalid")
31 | try:
32 | return await tool(**tool_input)
33 | except ToolError as e:
34 | return ToolFailure(error=e.message)
35 |
36 | def sync_call(self, *, name: str, tool_input: dict[str, Any]) -> ToolResult:
37 | print(f"sync_call: {name} {tool_input}")
38 | tool = self.tool_map.get(name)
39 | if not tool:
40 | return ToolFailure(error=f"Tool {name} is invalid")
41 | return tool.sync_call(**tool_input)
42 |
--------------------------------------------------------------------------------
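A short sketch of how the collection is used by the Anthropic loop: build it from tool instances, pass `to_params()` as the `tools=` argument of the API call, then dispatch each `tool_use` block by name. The tool input shown is illustrative, and the bash example is POSIX-only:

```python
import asyncio
from computer_use_demo.tools import BashTool, EditTool, ToolCollection

collection = ToolCollection(BashTool(), EditTool())

# This list is what gets handed to the Anthropic messages API as `tools=`.
print([params["name"] for params in collection.to_params()])   # ['bash', 'str_replace_editor']

# Dispatch by name, mirroring how the executor handles a tool_use block.
result = asyncio.run(collection.run(name="bash", tool_input={"command": "echo hi"}))
print(result.output or result.error)

# Unknown tool names come back as a ToolFailure instead of raising.
print(asyncio.run(collection.run(name="does-not-exist", tool_input={})).error)
```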
/computer_use_demo/tools/colorful_text.py:
--------------------------------------------------------------------------------
1 | """
2 | Define some colorful stuffs for better visualization in the chat.
3 | """
4 |
5 | # Define the RGB colors for each letter
6 | colors = {
7 | 'S': 'rgb(106, 158, 210)',
8 | 'h': 'rgb(111, 163, 82)',
9 | 'o': 'rgb(209, 100, 94)',
10 | 'w': 'rgb(238, 171, 106)',
11 | 'U': 'rgb(0, 0, 0)',
12 | 'I': 'rgb(0, 0, 0)',
13 | }
14 |
15 | # Construct the colorful "ShowUI" word
16 | colorful_text_showui = "**"+''.join(
17 | f'<span style="color:{colors[letter]}">{letter}</span>'
18 | for letter in "ShowUI"
19 | )+"**"
20 |
21 |
22 | colorful_text_vlm = "**VLMPlanner**"
23 |
24 | colorful_text_user = "**User**"
25 |
26 | # print(f"colorful_text_showui: {colorful_text_showui}")
27 | # **ShowUI**
--------------------------------------------------------------------------------
/computer_use_demo/tools/edit.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from pathlib import Path
3 | from typing import Literal, get_args
4 |
5 | from anthropic.types.beta import BetaToolTextEditor20241022Param
6 |
7 | from .base import BaseAnthropicTool, CLIResult, ToolError, ToolResult
8 | from .run import maybe_truncate, run
9 |
10 | Command = Literal[
11 | "view",
12 | "create",
13 | "str_replace",
14 | "insert",
15 | "undo_edit",
16 | ]
17 | SNIPPET_LINES: int = 4
18 |
19 |
20 | class EditTool(BaseAnthropicTool):
21 | """
22 | A filesystem editor tool that allows the agent to view, create, and edit files.
23 | The tool parameters are defined by Anthropic and are not editable.
24 | """
25 |
26 | api_type: Literal["text_editor_20241022"] = "text_editor_20241022"
27 | name: Literal["str_replace_editor"] = "str_replace_editor"
28 |
29 | _file_history: dict[Path, list[str]]
30 |
31 | def __init__(self):
32 | self._file_history = defaultdict(list)
33 | super().__init__()
34 |
35 | def to_params(self) -> BetaToolTextEditor20241022Param:
36 | return {
37 | "name": self.name,
38 | "type": self.api_type,
39 | }
40 |
41 | async def __call__(
42 | self,
43 | *,
44 | command: Command,
45 | path: str,
46 | file_text: str | None = None,
47 | view_range: list[int] | None = None,
48 | old_str: str | None = None,
49 | new_str: str | None = None,
50 | insert_line: int | None = None,
51 | **kwargs,
52 | ):
53 | _path = Path(path)
54 | self.validate_path(command, _path)
55 | if command == "view":
56 | return await self.view(_path, view_range)
57 | elif command == "create":
58 | if not file_text:
59 | raise ToolError("Parameter `file_text` is required for command: create")
60 | self.write_file(_path, file_text)
61 | self._file_history[_path].append(file_text)
62 | return ToolResult(output=f"File created successfully at: {_path}")
63 | elif command == "str_replace":
64 | if not old_str:
65 | raise ToolError(
66 | "Parameter `old_str` is required for command: str_replace"
67 | )
68 | return self.str_replace(_path, old_str, new_str)
69 | elif command == "insert":
70 | if insert_line is None:
71 | raise ToolError(
72 | "Parameter `insert_line` is required for command: insert"
73 | )
74 | if not new_str:
75 | raise ToolError("Parameter `new_str` is required for command: insert")
76 | return self.insert(_path, insert_line, new_str)
77 | elif command == "undo_edit":
78 | return self.undo_edit(_path)
79 | raise ToolError(
80 | f'Unrecognized command {command}. The allowed commands for the {self.name} tool are: {", ".join(get_args(Command))}'
81 | )
82 |
83 | def validate_path(self, command: str, path: Path):
84 | """
85 | Check that the path/command combination is valid.
86 | """
87 |         # Check if it's an absolute path
88 |         if not path.is_absolute():
89 |             suggested_path = Path("/") / path
90 |             raise ToolError(
91 |                 f"The path {path} is not an absolute path, it should start with `/`. Maybe you meant {suggested_path}?"
92 | )
93 | # Check if path exists
94 | if not path.exists() and command != "create":
95 | raise ToolError(
96 | f"The path {path} does not exist. Please provide a valid path."
97 | )
98 | if path.exists() and command == "create":
99 | raise ToolError(
100 | f"File already exists at: {path}. Cannot overwrite files using command `create`."
101 | )
102 | # Check if the path points to a directory
103 | if path.is_dir():
104 | if command != "view":
105 | raise ToolError(
106 | f"The path {path} is a directory and only the `view` command can be used on directories"
107 | )
108 |
109 | async def view(self, path: Path, view_range: list[int] | None = None):
110 | """Implement the view command"""
111 | if path.is_dir():
112 | if view_range:
113 | raise ToolError(
114 | "The `view_range` parameter is not allowed when `path` points to a directory."
115 | )
116 |
117 | _, stdout, stderr = await run(
118 | rf"find {path} -maxdepth 2 -not -path '*/\.*'"
119 | )
120 | if not stderr:
121 |                 stdout = f"Here are the files and directories up to 2 levels deep in {path}, excluding hidden items:\n{stdout}\n"
122 | return CLIResult(output=stdout, error=stderr)
123 |
124 | file_content = self.read_file(path)
125 | init_line = 1
126 | if view_range:
127 | if len(view_range) != 2 or not all(isinstance(i, int) for i in view_range):
128 | raise ToolError(
129 | "Invalid `view_range`. It should be a list of two integers."
130 | )
131 | file_lines = file_content.split("\n")
132 | n_lines_file = len(file_lines)
133 | init_line, final_line = view_range
134 | if init_line < 1 or init_line > n_lines_file:
135 | raise ToolError(
136 | f"Invalid `view_range`: {view_range}. It's first element `{init_line}` should be within the range of lines of the file: {[1, n_lines_file]}"
137 | )
138 | if final_line > n_lines_file:
139 | raise ToolError(
140 | f"Invalid `view_range`: {view_range}. It's second element `{final_line}` should be smaller than the number of lines in the file: `{n_lines_file}`"
141 | )
142 | if final_line != -1 and final_line < init_line:
143 | raise ToolError(
144 | f"Invalid `view_range`: {view_range}. It's second element `{final_line}` should be larger or equal than its first `{init_line}`"
145 | )
146 |
147 | if final_line == -1:
148 | file_content = "\n".join(file_lines[init_line - 1 :])
149 | else:
150 | file_content = "\n".join(file_lines[init_line - 1 : final_line])
151 |
152 | return CLIResult(
153 | output=self._make_output(file_content, str(path), init_line=init_line)
154 | )
155 |
156 | def str_replace(self, path: Path, old_str: str, new_str: str | None):
157 | """Implement the str_replace command, which replaces old_str with new_str in the file content"""
158 | # Read the file content
159 | file_content = self.read_file(path).expandtabs()
160 | old_str = old_str.expandtabs()
161 | new_str = new_str.expandtabs() if new_str is not None else ""
162 |
163 | # Check if old_str is unique in the file
164 | occurrences = file_content.count(old_str)
165 | if occurrences == 0:
166 | raise ToolError(
167 | f"No replacement was performed, old_str `{old_str}` did not appear verbatim in {path}."
168 | )
169 | elif occurrences > 1:
170 | file_content_lines = file_content.split("\n")
171 | lines = [
172 | idx + 1
173 | for idx, line in enumerate(file_content_lines)
174 | if old_str in line
175 | ]
176 | raise ToolError(
177 | f"No replacement was performed. Multiple occurrences of old_str `{old_str}` in lines {lines}. Please ensure it is unique"
178 | )
179 |
180 | # Replace old_str with new_str
181 | new_file_content = file_content.replace(old_str, new_str)
182 |
183 | # Write the new content to the file
184 | self.write_file(path, new_file_content)
185 |
186 | # Save the content to history
187 | self._file_history[path].append(file_content)
188 |
189 | # Create a snippet of the edited section
190 | replacement_line = file_content.split(old_str)[0].count("\n")
191 | start_line = max(0, replacement_line - SNIPPET_LINES)
192 | end_line = replacement_line + SNIPPET_LINES + new_str.count("\n")
193 | snippet = "\n".join(new_file_content.split("\n")[start_line : end_line + 1])
194 |
195 | # Prepare the success message
196 | success_msg = f"The file {path} has been edited. "
197 | success_msg += self._make_output(
198 | snippet, f"a snippet of {path}", start_line + 1
199 | )
200 | success_msg += "Review the changes and make sure they are as expected. Edit the file again if necessary."
201 |
202 | return CLIResult(output=success_msg)
203 |
204 | def insert(self, path: Path, insert_line: int, new_str: str):
205 | """Implement the insert command, which inserts new_str at the specified line in the file content."""
206 | file_text = self.read_file(path).expandtabs()
207 | new_str = new_str.expandtabs()
208 | file_text_lines = file_text.split("\n")
209 | n_lines_file = len(file_text_lines)
210 |
211 | if insert_line < 0 or insert_line > n_lines_file:
212 | raise ToolError(
213 | f"Invalid `insert_line` parameter: {insert_line}. It should be within the range of lines of the file: {[0, n_lines_file]}"
214 | )
215 |
216 | new_str_lines = new_str.split("\n")
217 | new_file_text_lines = (
218 | file_text_lines[:insert_line]
219 | + new_str_lines
220 | + file_text_lines[insert_line:]
221 | )
222 | snippet_lines = (
223 | file_text_lines[max(0, insert_line - SNIPPET_LINES) : insert_line]
224 | + new_str_lines
225 | + file_text_lines[insert_line : insert_line + SNIPPET_LINES]
226 | )
227 |
228 | new_file_text = "\n".join(new_file_text_lines)
229 | snippet = "\n".join(snippet_lines)
230 |
231 | self.write_file(path, new_file_text)
232 | self._file_history[path].append(file_text)
233 |
234 | success_msg = f"The file {path} has been edited. "
235 | success_msg += self._make_output(
236 | snippet,
237 | "a snippet of the edited file",
238 | max(1, insert_line - SNIPPET_LINES + 1),
239 | )
240 | success_msg += "Review the changes and make sure they are as expected (correct indentation, no duplicate lines, etc). Edit the file again if necessary."
241 | return CLIResult(output=success_msg)
242 |
243 | def undo_edit(self, path: Path):
244 | """Implement the undo_edit command."""
245 | if not self._file_history[path]:
246 | raise ToolError(f"No edit history found for {path}.")
247 |
248 | old_text = self._file_history[path].pop()
249 | self.write_file(path, old_text)
250 |
251 | return CLIResult(
252 | output=f"Last edit to {path} undone successfully. {self._make_output(old_text, str(path))}"
253 | )
254 |
255 | def read_file(self, path: Path):
256 | """Read the content of a file from a given path; raise a ToolError if an error occurs."""
257 | try:
258 | return path.read_text()
259 | except Exception as e:
260 | raise ToolError(f"Ran into {e} while trying to read {path}") from None
261 |
262 | def write_file(self, path: Path, file: str):
263 | """Write the content of a file to a given path; raise a ToolError if an error occurs."""
264 | try:
265 | path.write_text(file)
266 | except Exception as e:
267 | raise ToolError(f"Ran into {e} while trying to write to {path}") from None
268 |
269 | def _make_output(
270 | self,
271 | file_content: str,
272 | file_descriptor: str,
273 | init_line: int = 1,
274 | expand_tabs: bool = True,
275 | ):
276 | """Generate output for the CLI based on the content of a file."""
277 | file_content = maybe_truncate(file_content)
278 | if expand_tabs:
279 | file_content = file_content.expandtabs()
280 | file_content = "\n".join(
281 | [
282 | f"{i + init_line:6}\t{line}"
283 | for i, line in enumerate(file_content.split("\n"))
284 | ]
285 | )
286 | return (
287 | f"Here's the result of running `cat -n` on {file_descriptor}:\n"
288 | + file_content
289 | + "\n"
290 | )
291 |
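A short, hypothetical exercise of the command dispatch above; the path and strings are illustrative, and the tool only accepts absolute paths (to files that do not yet exist when using `create`).

import asyncio

from computer_use_demo.tools.edit import EditTool

async def main():
    editor = EditTool()
    await editor(command="create", path="/tmp/demo.txt", file_text="hello world\n")
    await editor(command="str_replace", path="/tmp/demo.txt", old_str="world", new_str="there")
    result = await editor(command="view", path="/tmp/demo.txt")
    print(result.output)  # numbered listing of the edited file

asyncio.run(main())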
--------------------------------------------------------------------------------
/computer_use_demo/tools/logger.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 |
4 | def truncate_string(s, max_length=500):
5 | """Truncate long strings for concise printing."""
6 | if isinstance(s, str) and len(s) > max_length:
7 | return s[:max_length] + "..."
8 | return s
9 |
10 | # Configure logger
11 | logger = logging.getLogger(__name__)
12 | logger.setLevel(logging.INFO) # Choose your default level (INFO, DEBUG, etc.)
13 |
14 |
15 | # Optionally add a console handler if you don't have one already
16 | if not logger.handlers:
17 | console_handler = logging.StreamHandler()
18 | console_handler.setLevel(logging.INFO)
19 | formatter = logging.Formatter("[%(levelname)s] %(name)s - %(message)s")
20 | console_handler.setFormatter(formatter)
21 | logger.addHandler(console_handler)
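A brief usage sketch with an illustrative payload, showing how `truncate_string` keeps log lines short:

from computer_use_demo.tools.logger import logger, truncate_string

payload = "x" * 2000  # illustrative long string
logger.info("model response: %s", truncate_string(payload))  # logs the first 500 chars + "..."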
--------------------------------------------------------------------------------
/computer_use_demo/tools/run.py:
--------------------------------------------------------------------------------
1 | """Utility to run shell commands asynchronously with a timeout."""
2 |
3 | import asyncio
4 |
5 | TRUNCATED_MESSAGE: str = "To save on context only part of this file has been shown to you. You should retry this tool after you have searched inside the file with `grep -n` in order to find the line numbers of what you are looking for."
6 | MAX_RESPONSE_LEN: int = 16000
7 |
8 |
9 | def maybe_truncate(content: str, truncate_after: int | None = MAX_RESPONSE_LEN):
10 | """Truncate content and append a notice if content exceeds the specified length."""
11 | return (
12 | content
13 | if not truncate_after or len(content) <= truncate_after
14 | else content[:truncate_after] + TRUNCATED_MESSAGE
15 | )
16 |
17 |
18 | async def run(
19 | cmd: str,
20 | timeout: float | None = 120.0, # seconds
21 | truncate_after: int | None = MAX_RESPONSE_LEN,
22 | ):
23 | """Run a shell command asynchronously with a timeout."""
24 | process = await asyncio.create_subprocess_shell(
25 | cmd, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
26 | )
27 |
28 | try:
29 | stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=timeout)
30 | return (
31 | process.returncode or 0,
32 | maybe_truncate(stdout.decode(), truncate_after=truncate_after),
33 | maybe_truncate(stderr.decode(), truncate_after=truncate_after),
34 | )
35 | except asyncio.TimeoutError as exc:
36 | try:
37 | process.kill()
38 | except ProcessLookupError:
39 | pass
40 | raise TimeoutError(
41 | f"Command '{cmd}' timed out after {timeout} seconds"
42 | ) from exc
43 |
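A hypothetical call site for the helper above (the command and timeout are illustrative):

import asyncio

from computer_use_demo.tools.run import run

async def main():
    returncode, stdout, stderr = await run("echo hello", timeout=10.0)
    print(returncode, stdout.strip(), stderr.strip())

asyncio.run(main())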
--------------------------------------------------------------------------------
/computer_use_demo/tools/screen_capture.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import base64
3 | from pathlib import Path
4 | from PIL import ImageGrab
5 | from uuid import uuid4
6 | from screeninfo import get_monitors
7 | import platform
8 | if platform.system() == "Darwin":
9 |     import Quartz  # only available on macOS; guarded by the platform check above
10 | 
11 | from functools import partial
12 | 
13 | from .base import BaseAnthropicTool, ToolError, ToolResult
14 |
15 |
16 | OUTPUT_DIR = "./tmp/outputs"
17 |
18 | def get_screenshot(selected_screen: int = 0, resize: bool = True, target_width: int = 1920, target_height: int = 1080):
19 |     """Take a screenshot of the selected screen and return the PIL image together with the path it was saved to."""
20 |     # print(f"get_screenshot selected_screen: {selected_screen}")
21 | 
22 |     # Determine the size of the selected screen
23 |     offset_x = 0
24 |     offset_y = 0
25 |     width, height = _get_screen_size(selected_screen)
26 | 
27 | 
28 |     # Prepare the output directory and a unique file path for the screenshot
29 | output_dir = Path(OUTPUT_DIR)
30 | output_dir.mkdir(parents=True, exist_ok=True)
31 | path = output_dir / f"screenshot_{uuid4().hex}.png"
32 |
33 | ImageGrab.grab = partial(ImageGrab.grab, all_screens=True)
34 |
35 | # Detect platform
36 | system = platform.system()
37 |
38 | if system == "Windows":
39 | # Windows: Use screeninfo to get monitor details
40 | screens = get_monitors()
41 |
42 | # Sort screens by x position to arrange from left to right
43 | sorted_screens = sorted(screens, key=lambda s: s.x)
44 |
45 | if selected_screen < 0 or selected_screen >= len(screens):
46 | raise IndexError("Invalid screen index.")
47 |
48 | screen = sorted_screens[selected_screen]
49 | bbox = (screen.x, screen.y, screen.x + screen.width, screen.y + screen.height)
50 |
51 | elif system == "Darwin": # macOS
52 | # macOS: Use Quartz to get monitor details
53 | max_displays = 32 # Maximum number of displays to handle
54 | active_displays = Quartz.CGGetActiveDisplayList(max_displays, None, None)[1]
55 |
56 | # Get the display bounds (resolution) for each active display
57 | screens = []
58 | for display_id in active_displays:
59 | bounds = Quartz.CGDisplayBounds(display_id)
60 | screens.append({
61 | 'id': display_id,
62 | 'x': int(bounds.origin.x),
63 | 'y': int(bounds.origin.y),
64 | 'width': int(bounds.size.width),
65 | 'height': int(bounds.size.height),
66 | 'is_primary': Quartz.CGDisplayIsMain(display_id) # Check if this is the primary display
67 | })
68 |
69 | # Sort screens by x position to arrange from left to right
70 | sorted_screens = sorted(screens, key=lambda s: s['x'])
71 | # print(f"Darwin sorted_screens: {sorted_screens}")
72 |
73 | if selected_screen < 0 or selected_screen >= len(screens):
74 | raise IndexError("Invalid screen index.")
75 |
76 | screen = sorted_screens[selected_screen]
77 |
78 | bbox = (screen['x'], screen['y'], screen['x'] + screen['width'], screen['y'] + screen['height'])
79 |
80 | else: # Linux or other OS
81 | cmd = "xrandr | grep ' primary' | awk '{print $4}'"
82 | try:
83 | output = subprocess.check_output(cmd, shell=True).decode()
84 | resolution = output.strip()
85 | # Parse the resolution format like "1920x1080+1920+0"
86 | parts = resolution.split('+')[0] # Get just the "1920x1080" part
87 | width, height = map(int, parts.split('x'))
88 |
89 | # Create a screen object/dictionary similar to what's used in other platforms
90 | # Extract the offsets from the resolution string
91 | x_offset = int(resolution.split('+')[1]) if len(resolution.split('+')) > 1 else 0
92 | y_offset = int(resolution.split('+')[2]) if len(resolution.split('+')) > 2 else 0
93 |
94 | # Create a screen object with attributes similar to the screeninfo.Monitor object
95 | class LinuxScreen:
96 | def __init__(self, x, y, width, height):
97 | self.x = x
98 | self.y = y
99 | self.width = width
100 | self.height = height
101 |
102 | screen = LinuxScreen(x_offset, y_offset, width, height)
103 | bbox = (x_offset, y_offset, x_offset + width, y_offset + height)
104 | except subprocess.CalledProcessError:
105 | raise RuntimeError("Failed to get screen resolution on Linux.")
106 |
107 | # Take screenshot using the bounding box
108 | screenshot = ImageGrab.grab(bbox=bbox)
109 |
110 | # Set offsets (for potential future use)
111 | offset_x = screen['x'] if system == "Darwin" else screen.x
112 | offset_y = screen['y'] if system == "Darwin" else screen.y
113 |
114 |     # Resize to the target resolution if requested
115 | if resize:
116 | screenshot = screenshot.resize((target_width, target_height))
117 |
118 | # Save the screenshot
119 | screenshot.save(str(path))
120 |
121 | if path.exists():
122 |         # Return the PIL image together with the path where it was saved
123 | return screenshot, path
124 |
125 | raise ToolError(f"Failed to take screenshot: {path} does not exist.")
126 |
127 |
128 |
129 |
130 | def _get_screen_size(selected_screen: int = 0):
131 | if platform.system() == "Windows":
132 | # Use screeninfo to get primary monitor on Windows
133 | screens = get_monitors()
134 |
135 | # Sort screens by x position to arrange from left to right
136 | sorted_screens = sorted(screens, key=lambda s: s.x)
137 | if selected_screen is None:
138 | primary_monitor = next((m for m in get_monitors() if m.is_primary), None)
139 | return primary_monitor.width, primary_monitor.height
140 | elif selected_screen < 0 or selected_screen >= len(screens):
141 | raise IndexError("Invalid screen index.")
142 | else:
143 | screen = sorted_screens[selected_screen]
144 | return screen.width, screen.height
145 | elif platform.system() == "Darwin":
146 | # macOS part using Quartz to get screen information
147 | max_displays = 32 # Maximum number of displays to handle
148 | active_displays = Quartz.CGGetActiveDisplayList(max_displays, None, None)[1]
149 |
150 | # Get the display bounds (resolution) for each active display
151 | screens = []
152 | for display_id in active_displays:
153 | bounds = Quartz.CGDisplayBounds(display_id)
154 | screens.append({
155 | 'id': display_id,
156 | 'x': int(bounds.origin.x),
157 | 'y': int(bounds.origin.y),
158 | 'width': int(bounds.size.width),
159 | 'height': int(bounds.size.height),
160 | 'is_primary': Quartz.CGDisplayIsMain(display_id) # Check if this is the primary display
161 | })
162 |
163 | # Sort screens by x position to arrange from left to right
164 | sorted_screens = sorted(screens, key=lambda s: s['x'])
165 |
166 | if selected_screen is None:
167 | # Find the primary monitor
168 | primary_monitor = next((screen for screen in screens if screen['is_primary']), None)
169 | if primary_monitor:
170 | return primary_monitor['width'], primary_monitor['height']
171 | else:
172 | raise RuntimeError("No primary monitor found.")
173 | elif selected_screen < 0 or selected_screen >= len(screens):
174 | raise IndexError("Invalid screen index.")
175 | else:
176 | # Return the resolution of the selected screen
177 | screen = sorted_screens[selected_screen]
178 | return screen['width'], screen['height']
179 |
180 | else: # Linux or other OS
181 | cmd = "xrandr | grep ' primary' | awk '{print $4}'"
182 | try:
183 | output = subprocess.check_output(cmd, shell=True).decode()
184 | resolution = output.strip().split()[0]
185 | parts = resolution.split('+')[0] # Get just the "1920x1080" part
186 | width, height = map(int, parts.split('x'))
187 | return width, height
188 | except subprocess.CalledProcessError:
189 | raise RuntimeError("Failed to get screen resolution on Linux.")
190 |
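A minimal, hypothetical caller for `get_screenshot`; the screen index and target size are illustrative. The function saves a PNG under `./tmp/outputs` and returns the PIL image together with that path, which can then be base64-encoded for an API payload:

import base64

from computer_use_demo.tools.screen_capture import get_screenshot

screenshot, path = get_screenshot(selected_screen=0, resize=True,
                                  target_width=1920, target_height=1080)
image_b64 = base64.b64encode(path.read_bytes()).decode()
print(path, screenshot.size, len(image_b64))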
--------------------------------------------------------------------------------
/docs/README_cn.md:
--------------------------------------------------------------------------------
1 |