├── .gitignore ├── AndroidWorld ├── README.md ├── agents │ ├── aria_ui_utils.py │ └── m3a_aria_ui.py ├── env │ └── json_action.py └── run_aria_ui.py ├── README.md ├── aria_ui_hf.py ├── aria_ui_vllm.py ├── assets ├── aria_ui_framework_v4.pdf ├── aria_ui_logo.png ├── logo_long.png ├── overall.png ├── performance_spider.pdf └── seo.png ├── examples └── aria.png └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # UV 100 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | #uv.lock 104 | 105 | # poetry 106 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 107 | # This is especially recommended for binary packages to ensure reproducibility, and is more 108 | # commonly ignored for libraries. 109 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 110 | #poetry.lock 111 | 112 | # pdm 113 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 114 | #pdm.lock 115 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 116 | # in version control. 117 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 118 | .pdm.toml 119 | .pdm-python 120 | .pdm-build/ 121 | 122 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 123 | __pypackages__/ 124 | 125 | # Celery stuff 126 | celerybeat-schedule 127 | celerybeat.pid 128 | 129 | # SageMath parsed files 130 | *.sage.py 131 | 132 | # Environments 133 | .env 134 | .venv 135 | # env/ 136 | venv/ 137 | ENV/ 138 | env.bak/ 139 | venv.bak/ 140 | 141 | # Spyder project settings 142 | .spyderproject 143 | .spyproject 144 | 145 | # Rope project settings 146 | .ropeproject 147 | 148 | # mkdocs documentation 149 | /site 150 | 151 | # mypy 152 | .mypy_cache/ 153 | .dmypy.json 154 | dmypy.json 155 | 156 | # Pyre type checker 157 | .pyre/ 158 | 159 | # pytype static type analyzer 160 | .pytype/ 161 | 162 | # Cython debug symbols 163 | cython_debug/ 164 | 165 | # PyCharm 166 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 167 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 168 | # and can be added to the global gitignore or merged into this file. For a more nuclear 169 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 170 | #.idea/ 171 | 172 | # PyPI configuration file 173 | .pypirc 174 | -------------------------------------------------------------------------------- /AndroidWorld/README.md: -------------------------------------------------------------------------------- 1 | # 🚀 Welcome to the M3A Agent Powered by Aria-UI for AndroidWorld! 2 | 3 | We’re thrilled to release the **M3A Agent powered by Aria-UI**! 🎉 This integration enhances task success rates and brings seamless grounding instruction understanding to **AndroidWorld**. Follow the steps below to get started. 🚀 4 | 5 | ## 🛠️ How to Use M3A Agent with AndroidWorld 6 | 7 | ### 1️⃣ Clone AndroidWorld and Set Up the Environment 8 | 9 | First, clone the latest version of **AndroidWorld** and install the required dependencies: 10 | 11 | ```bash 12 | git clone https://github.com/google-research/android_world.git 13 | cd android_world 14 | ``` 15 | For more details, check out the official [AndroidWorld GitHub repository](https://github.com/google-research/android_world?tab=readme-ov-file). 16 | 17 | ### 2️⃣ Merge M3A Agent Files with AndroidWorld 18 | 19 | We’ve organized our files to match AndroidWorld’s directory structure. To integrate: 20 | 21 | - Merge the provided `agents/` and `env/` directories with the respective directories in AndroidWorld. 22 | - Place `run_aria_ui.py` in the root directory of AndroidWorld. 23 | 24 | After merging, your directory structure should look like this: 25 | ``` 26 | AndroidWorld/ 27 | ├── agents/ 28 | │ └── aria_ui_utils.py 29 | │ └── m3a_aria_ui.py 30 | ├── env/ 31 | │ └── json_action.py 32 | ├── run_aria_ui.py 33 | ... 34 | ``` 35 | 36 | ### 3️⃣ Deploy Aria-UI API 🌐 37 | Under AndroidWorld/agents/aria_ui_utils.py, you’ll find how we connect to Aria-UI using an API. Follow these steps to set it up: 38 | - Deploy Aria-UI with vLLM to serve an OpenAI-style API. 39 | Reference the [vLLM OpenAI-Compatible Server](https://docs.vllm.ai/en/latest/serving/openai_compatible_server.html). 40 | - Update the API key and base URL in `aria_ui_utils.py` to match your deployment. Replace the placeholders with your values: 41 | ```python 42 | ariaui_api_key = "ariaui_api_key" 43 | ariaui_api_base = "ariaui_api_base" 44 | ``` 45 | 46 | **Note**: You may also set up a simple API server with the vLLM inference script and [FastAPI](https://github.com/fastapi/fastapi). 
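If you go the FastAPI route, the sketch below shows roughly what such a server could look like. It is a minimal illustration under assumptions: the `/ground` endpoint, the request schema, and the module name are placeholders rather than part of this repo, and the vLLM calls mirror the inference snippet in the top-level README. Serving Aria-UI this way (instead of via the OpenAI-compatible server) also means adapting `request_aria_ui` in `aria_ui_utils.py` to call your endpoint.

```python
# Hypothetical minimal FastAPI wrapper around the vLLM inference shown in the main README.
# Endpoint name, request schema, and model path are illustrative assumptions, not part of this repo.
import base64
from io import BytesIO

from fastapi import FastAPI
from PIL import Image
from pydantic import BaseModel
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams

MODEL_PATH = "Aria-UI/Aria-UI-base"  # assumed model id

app = FastAPI()
llm = LLM(model=MODEL_PATH, tokenizer_mode="slow", dtype="bfloat16", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True, use_fast=False)


class GroundingRequest(BaseModel):
    prompt: str        # full grounding prompt (e.g. built from ARIA_UI_PROMPT_TEMPLATE)
    image_base64: str  # screenshot encoded as base64 JPEG/PNG


@app.post("/ground")
def ground(req: GroundingRequest) -> dict:
    # Decode the screenshot and build the chat-style message expected by Aria-UI.
    image = Image.open(BytesIO(base64.b64decode(req.image_base64))).convert("RGB")
    messages = [{
        "role": "user",
        "content": [{"type": "image"}, {"type": "text", "text": req.prompt}],
    }]
    prompt_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
    outputs = llm.generate(
        {
            "prompt_token_ids": prompt_ids,
            "multi_modal_data": {"image": [image], "max_image_size": 980, "split_image": True},
        },
        sampling_params=SamplingParams(max_tokens=50, top_k=1, stop=["<|im_end|>"]),
    )
    # Return the raw model output, e.g. "(892, 925)", for the client to parse.
    return {"response": tokenizer.decode(outputs[0].outputs[0].token_ids, skip_special_tokens=True)}
```

You could then launch it with, for example, `uvicorn aria_ui_server:app --port 8000` (module name assumed).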
### 4️⃣ Run the M3A Agent 🎉 48 | Once everything is set up, run the agent with the following command: 49 | ```bash 50 | python run_aria_ui.py 51 | ``` 52 | Sit back and watch **Aria-UI** in action! 🚀 53 | 54 | ### 💡 Additional Notes 55 | For troubleshooting or further customization, explore `aria_ui_utils.py` and `m3a_aria_ui.py` to understand how Aria-UI is integrated. 56 | 57 | Enjoy using the M3A Agent powered by **Aria-UI**! 🎉 -------------------------------------------------------------------------------- /AndroidWorld/agents/aria_ui_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Aria-UI Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import base64 16 | from io import BytesIO 17 | import requests 18 | from PIL import Image 19 | from openai import OpenAI 20 | import numpy as np 21 | import cv2 22 | from android_world.agents.m3a_utils import _logical_to_physical 23 | 24 | """ 25 | Deploy Aria-UI with vLLM, then get the api_key and api_base from the deployment for direct API calls. 26 | """ 27 | ariaui_api_key = "ariaui_api_key" 28 | ariaui_api_base = "ariaui_api_base" 29 | 30 | client = OpenAI( 31 | api_key=ariaui_api_key, 32 | base_url=ariaui_api_base, 33 | ) 34 | 35 | models = client.models.list() 36 | model = models.data[0].id 37 | 38 | def encode_image_to_base64(image_path): 39 | pil_image = Image.open(image_path).convert('RGB') 40 | buffered = BytesIO() 41 | pil_image.save(buffered, format="JPEG") 42 | base64_str = base64.b64encode(buffered.getvalue()).decode('utf-8') 43 | return base64_str 44 | 45 | def encode_numpy_image_to_base64(image: np.ndarray) -> str: 46 | """Converts a numpy array image to base64 string.
47 | 48 | Args: 49 | image: Numpy array representing an image (height, width, channels) 50 | 51 | Returns: 52 | Base64 encoded string of the image 53 | """ 54 | # Convert numpy array to bytes 55 | success, buffer = cv2.imencode('.jpg', image) 56 | if not success: 57 | raise ValueError("Failed to encode image to jpg format") 58 | 59 | # Convert bytes to base64 string 60 | image_bytes = buffer.tobytes() 61 | base64_string = base64.b64encode(image_bytes).decode('utf-8') 62 | 63 | return base64_string 64 | 65 | def request_aria_ui(image: np.ndarray, prompt: str) -> str: 66 | image_base64 = encode_numpy_image_to_base64(image) 67 | chat_completion_from_url = client.chat.completions.create( 68 | messages=[{ 69 | "role": 70 | "user", 71 | "content": [ 72 | { 73 | "type": "text", 74 | "text": prompt 75 | }, 76 | { 77 | "type": "image_url", 78 | "image_url": { 79 | "url": f"data:image/jpeg;base64,{image_base64}" 80 | }, 81 | }, 82 | ], 83 | }], 84 | model=model, 85 | max_tokens=512, 86 | stop=["<|im_end|>"], 87 | extra_body= { 88 | "split_image": True, 89 | "image_max_size": 980 90 | } 91 | ) 92 | 93 | result = chat_completion_from_url.choices[0].message.content 94 | print(f"Chat completion output:{result}") 95 | return result 96 | 97 | 98 | def add_ui_element_mark_coords( 99 | screenshot: np.ndarray, 100 | coords: tuple[int, int], # Normalized coordinates in [0, 1000] 101 | logical_screen_size: tuple[int, int], 102 | physical_frame_boundary: tuple[int, int, int, int], 103 | orientation: int, 104 | ): 105 | """Add a red circle marker at the specified normalized coordinates. 106 | 107 | Args: 108 | screenshot: The screenshot as a numpy ndarray. 109 | coords: Normalized coordinates (x, y) in range [0, 1000]. 110 | logical_screen_size: The logical screen size. 111 | physical_frame_boundary: The physical coordinates in portrait orientation 112 | for the upper left and lower right corner for the frame. 113 | orientation: The current screen orientation. 114 | """ 115 | # Convert normalized coordinates to logical coordinates 116 | logical_point = ( 117 | coords[0] * logical_screen_size[0] // 1000, 118 | coords[1] * logical_screen_size[1] // 1000 119 | ) 120 | 121 | # Convert to physical coordinates 122 | physical_point = _logical_to_physical( 123 | logical_point, 124 | logical_screen_size, 125 | physical_frame_boundary, 126 | orientation, 127 | ) 128 | 129 | # Draw a large red circle 130 | radius = 30 # Adjust size as needed 131 | cv2.circle( 132 | screenshot, 133 | physical_point, 134 | radius, 135 | color=(0, 0, 255), # BGR format - Red 136 | thickness=3 137 | ) 138 | 139 | def convert_coords_to_physical(coords: tuple[int, int], logical_screen_size: tuple[int, int], physical_frame_boundary: tuple[int, int, int, int], orientation: int) -> tuple[int, int]: 140 | logical_point = ( 141 | coords[0] * logical_screen_size[0] // 1000, 142 | coords[1] * logical_screen_size[1] // 1000 143 | ) 144 | physical_point = _logical_to_physical( 145 | logical_point, 146 | logical_screen_size, 147 | physical_frame_boundary, 148 | orientation, 149 | ) 150 | return physical_point -------------------------------------------------------------------------------- /AndroidWorld/agents/m3a_aria_ui.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Aria-UI Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2024 The android_world Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 28 | 29 | """A Multimodal Autonomous Agent for Android (Aria-UI under M3A).""" 30 | 31 | import time 32 | from android_world.agents import agent_utils 33 | from android_world.agents import base_agent 34 | from android_world.agents import infer 35 | from android_world.agents import m3a_utils 36 | from android_world.env import interface 37 | from android_world.env import json_action 38 | from android_world.env import representation_utils 39 | from android_world.agents import aria_ui_utils 40 | import ast 41 | from tenacity import retry, stop_after_attempt, wait_fixed, retry_if_exception_type 42 | 43 | 44 | PROMPT_PREFIX = """ 45 | You are an agent who can operate an Android phone on behalf of a user. Based on the user's goal/request, you may: 46 | 47 | - Answer back if the request/goal is a question (or a chat message), like user asks "What is my schedule for today?". 48 | - Complete some tasks described in the requests/goals by performing actions (step by step) on the phone. 49 | 50 | When given a user request, you will try to complete it step by step. At each step, you will be given the current screenshot and a history of what you have done (in text). Based on these pieces of information and the goal, you must choose to perform one of the actions in the following list (action description followed by the JSON format) by outputting the action in the correct JSON format: 51 | 52 | - If you think the task has been completed, finish the task by using the status action with complete as goal_status: 53 | {{ 54 | "action_type": "status", 55 | "goal_status": "complete" 56 | }} 57 | 58 | - If you think the task is not feasible (including cases like you don't have enough information or cannot perform some necessary actions), finish by using the `status` action with infeasible as goal_status: 59 | {{ 60 | "action_type": "status", 61 | "goal_status": "infeasible" 62 | }} 63 | 64 | - Answer user's question: 65 | {{ 66 | "action_type": "answer", 67 | "text": "answer_text" 68 | }} 69 | 70 | - Click/tap on an element on the screen. 
Please describe the element you want to click using natural language: 71 | {{ 72 | "action_type": "click", 73 | "instruction": the step-wise instruction in short, 74 | "target": target_element_description 75 | }} 76 | 77 | - Long press on an element on the screen, similar to the click action above, use the semantic description to indicate the element: 78 | {{ 79 | "action_type": "long_press", 80 | "instruction": the step-wise instruction in short, 81 | "target": target_element_description 82 | }} 83 | 84 | - Type text into a text field (this action contains clicking the text field, typing in the text, and pressing Enter, so no need to click on the target field to start). Use the semantic description to indicate the target text field: 85 | {{ 86 | "action_type": "input_text", 87 | "text": text_input, 88 | "instruction": the step-wise instruction in short, 89 | "target": target_element_description 90 | }} 91 | 92 | - Press the Enter key: 93 | {{ 94 | "action_type": "keyboard_enter" 95 | }} 96 | 97 | - Navigate to the home screen: 98 | {{ 99 | "action_type": "navigate_home" 100 | }} 101 | 102 | - Navigate back: 103 | {{ 104 | "action_type": "navigate_back" 105 | }} 106 | 107 | - Scroll the screen or a scrollable UI element in one of the four directions, use the same semantic description as above if you want to scroll a specific UI element, leave it empty when scrolling the whole screen: 108 | {{ 109 | "action_type": "scroll", 110 | "direction": "up/down/left/right", 111 | "instruction": the step-wise instruction in short, 112 | "element": optional_target_element_description 113 | }} 114 | 115 | - Open an app (nothing will happen if the app is not installed): 116 | {{ 117 | "action_type": "open_app", 118 | "app_name": name 119 | }} 120 | 121 | - Wait for the screen to update: 122 | {{ 123 | "action_type": "wait" 124 | }} 125 | """ 126 | 127 | GUIDANCE = """ 128 | Here are some useful guidelines you need to follow: 129 | 130 | General: 131 | - Usually there will be multiple ways to complete a task, pick the easiest one. Also when something does not work as expected (due to various reasons), sometimes a simple retry can solve the problem, but if it doesn’t (you can see that from the history), SWITCH to other solutions. 132 | - Sometimes you may need to navigate the phone to gather information needed to complete the task, for example if user asks "what is my schedule tomorrow", then you may want to open the calendar app (using the ‘open_app‘ action), look up information there, answer user’s question (using the ‘answer‘ action) and finish (using the ‘status‘ action with complete as goal_status). 133 | - For requests that are questions (or chat messages), remember to use the ‘answer‘ action to reply to user explicitly before finish! Merely displaying the answer on the screen is NOT sufficient (unless the goal is something like "show me ..."). 134 | - If the desired state is already achieved (e.g., enabling Wi-Fi when it’s already on), you can just complete the task. 135 | 136 | Action Related: 137 | - Use the ‘open_app‘ action whenever you want to open an app (nothing will happen if the app is not installed), do not use the app drawer to open an app unless all other ways have failed. 138 | - Use the ‘input_text‘ action whenever you want to type something (including passwords) instead of clicking characters on the keyboard one by one. Sometimes there is some default text in the text field you want to type in, remember to delete them before typing. 
139 | - For ‘click‘, ‘long_press‘ and ‘input_text‘, the target_element.description parameter you choose must be based on a VISIBLE element in the screenshot. 140 | - Consider exploring the screen by using the ‘scroll‘ action with different directions to reveal additional content. 141 | - The direction parameter for the ‘scroll‘ action can be confusing sometimes as it’s opposite to swipe, for example, to view content at the bottom, the ‘scroll‘ direction should be set to "down". It has been observed that you have difficulties in choosing the correct direction, so if one does not work, try the opposite as well. 142 | 143 | Text Related Operations: 144 | - Normally to select certain text on the screen: (i) Enter text selection mode by long pressing the area where the text is, then some of the words near the long press point will be selected (highlighted with two pointers indicating the range) and usually a text selection bar will also appear with options like "copy", "paste", "select all", etc. (ii) Select the exact text you need. Usually the text selected from the previous step is NOT the one you want, you need to adjust the range by dragging the two pointers. If you want to select all text in the text field, simply click the "select all" button in the bar. 145 | - At this point, you don’t have the ability to drag something around the screen, so in general you can not select arbitrary text. 146 | - To delete some text: the most traditional way is to place the cursor at the right place and use the backspace button in the keyboard to delete the characters one by one (can long press the backspace to accelerate if there are many to delete). Another approach is to first select the text you want to delete, then click the backspace button in the keyboard. 147 | - To copy some text: first select the exact text you want to copy, which usually also brings up the text selection bar, then click the `copy` button in bar. 148 | - To paste text into a text box, first long press the text box, then usually the text selection bar will appear with a `paste` button in it. 149 | - When typing into a text field, sometimes an auto-complete dropdown list will appear. This usually indicating this is a enum field and you should try to select the best match by clicking the corresponding one in the list. 150 | """ 151 | 152 | 153 | ACTION_SELECTION_PROMPT_TEMPLATE = ( 154 | PROMPT_PREFIX + "\nThe current user goal/request is: {goal}\n\n" 155 | "Here is a history of what you have done so far:\n{history}\n\n" 156 | "The current screenshot is also given to you.\n" 157 | + GUIDANCE 158 | + "{additional_guidelines}" 159 | + "\nNow output an action from the above list in the correct JSON format," 160 | " following the reason why you do that. 
Your answer should look like:\n" 161 | 'Reason: ...\nAction: {{"action_type":...}}\n\n' 162 | "Your Answer:\n" 163 | ) 164 | 165 | 166 | SUMMARY_PROMPT_TEMPLATE = ( 167 | PROMPT_PREFIX + "\nThe (overall) user goal/request is: {goal}\n" 168 | "Now I want you to summerize the latest step.\n" 169 | "You will be given the screenshot before you performed the action (which" 170 | ' has a text label "before" on the bottom right), the action you chose' 171 | " (together with the reason) and the screenshot after the action was" 172 | ' performed (which has a text label "after" on the bottom right).\n' 173 | "Also here is the list of detailed information for some UI elements" 174 | " in the before screenshot:\n{before_elements}\n" 175 | "Here is the list for the after screenshot:\n{after_elements}\n" 176 | "This is the action you picked: {action}\n" 177 | "Based on the reason: {reason}\n\n" 178 | "By comparing the two screenshots (plus the UI element lists) and the" 179 | " action performed, give a brief summary of this step. This summary" 180 | " will be added to action history and used in future action selection," 181 | " so try to include essential information you think that will be most" 182 | " useful for future action selections like what you" 183 | " intended to do, why, if it worked as expected, if not" 184 | " what might be the reason (be critical, the action/reason might be" 185 | " wrong), what should/should not be done next and so on. Some more" 186 | " rules/tips you should follow:\n" 187 | "- Keep it short (better less than 50 words) and in a single line\n" 188 | "- Some actions (like `answer`, `wait`) don't involve screen change," 189 | " you can just assume they work as expected.\n" 190 | "- Given this summary will be added into action history, it can be used as" 191 | " memory to include information that needs to be remembered, or shared" 192 | " between different apps.\n\n" 193 | "Summary of this step: " 194 | ) 195 | 196 | 197 | ARIA_UI_PROMPT_TEMPLATE = """The agent is performing the ultimate task: {ultimate_task}. 198 | History of the agent's steps:\n{history_list}. 199 | Step {step_idx}. Given a GUI image, what are the relative (0-1000) pixel point coordinates for the element corresponding to the following instruction or description: {instruction}""" 200 | 201 | ARIA_UI_PROMPT_TEMPLATE_MINIWOB = """Given a GUI image, what are the relative (0-1000) pixel point coordinates for the element corresponding to the following instruction or description: {instruction}""" 202 | 203 | 204 | def _generate_ui_element_description( 205 | ui_element: representation_utils.UIElement, index: int 206 | ) -> str: 207 | """Generate a description for a given UI element with important information. 208 | 209 | Args: 210 | ui_element: UI elements for the current screen. 211 | index: The numeric index for the UI element. 212 | 213 | Returns: 214 | The description for the UI element. 
215 | """ 216 | element_description = f'UI element {index}: {{"index": {index}, ' 217 | if ui_element.text: 218 | element_description += f'"text": "{ui_element.text}", ' 219 | if ui_element.content_description: 220 | element_description += ( 221 | f'"content_description": "{ui_element.content_description}", ' 222 | ) 223 | if ui_element.hint_text: 224 | element_description += f'"hint_text": "{ui_element.hint_text}", ' 225 | if ui_element.tooltip: 226 | element_description += f'"tooltip": "{ui_element.tooltip}", ' 227 | element_description += ( 228 | f'"is_clickable": {"True" if ui_element.is_clickable else "False"}, ' 229 | ) 230 | element_description += ( 231 | '"is_long_clickable":' 232 | f' {"True" if ui_element.is_long_clickable else "False"}, ' 233 | ) 234 | element_description += ( 235 | f'"is_editable": {"True" if ui_element.is_editable else "False"}, ' 236 | ) 237 | if ui_element.is_scrollable: 238 | element_description += '"is_scrollable": True, ' 239 | if ui_element.is_focusable: 240 | element_description += '"is_focusable": True, ' 241 | element_description += ( 242 | f'"is_selected": {"True" if ui_element.is_selected else "False"}, ' 243 | ) 244 | element_description += ( 245 | f'"is_checked": {"True" if ui_element.is_checked else "False"}, ' 246 | ) 247 | return element_description[:-2] + "}" 248 | 249 | 250 | def _generate_ui_elements_description_list( 251 | ui_elements: list[representation_utils.UIElement], 252 | screen_width_height_px: tuple[int, int], 253 | ) -> str: 254 | """Generate concise information for a list of UIElement. 255 | 256 | Args: 257 | ui_elements: UI elements for the current screen. 258 | screen_width_height_px: The height and width of the screen in pixels. 259 | 260 | Returns: 261 | Concise information for each UIElement. 262 | """ 263 | tree_info = "" 264 | for index, ui_element in enumerate(ui_elements): 265 | if m3a_utils.validate_ui_element(ui_element, screen_width_height_px): 266 | tree_info += _generate_ui_element_description(ui_element, index) + "\n" 267 | return tree_info 268 | 269 | 270 | def _pvision_action_selection_prompt( 271 | goal: str, 272 | history: list[str], 273 | additional_guidelines: list[str] | None = None, 274 | ) -> str: 275 | """Generate the prompt for the action selection. 276 | 277 | Args: 278 | goal: The current goal. 279 | history: Summaries for previous steps. 280 | ui_elements: A list of descriptions for the UI elements. 281 | additional_guidelines: Task specific guidelines. 282 | 283 | Returns: 284 | The text prompt for action selection that will be sent to gpt4v. 285 | """ 286 | if history: 287 | history = "\n".join(history) 288 | else: 289 | history = "You just started, no action has been performed yet." 290 | 291 | extra_guidelines = "" 292 | if additional_guidelines: 293 | extra_guidelines = "For The Current Task:\n" 294 | for guideline in additional_guidelines: 295 | extra_guidelines += f"- {guideline}\n" 296 | 297 | return ACTION_SELECTION_PROMPT_TEMPLATE.format( 298 | goal=goal, 299 | history=history, 300 | additional_guidelines=extra_guidelines, 301 | ) 302 | 303 | 304 | def _summarize_prompt( 305 | action: str, 306 | reason: str, 307 | goal: str, 308 | before_elements: str, 309 | after_elements: str, 310 | ) -> str: 311 | """Generate the prompt for the summarization step. 312 | 313 | Args: 314 | action: Action picked. 315 | reason: The reason to pick the action. 316 | goal: The overall goal. 317 | before_elements: Information for UI elements on the before screenshot. 
318 | after_elements: Information for UI elements on the after screenshot. 319 | 320 | Returns: 321 | The text prompt for summarization that will be sent to gpt4v. 322 | """ 323 | return SUMMARY_PROMPT_TEMPLATE.format( 324 | goal=goal, 325 | before_elements=before_elements, 326 | after_elements=after_elements, 327 | action=action, 328 | reason=reason, 329 | ) 330 | 331 | 332 | def _extract_coords_from_response(response: str) -> tuple[int, int]: 333 | """Extract coordinate tuple from LLM response string. 334 | 335 | Args: 336 | response: String containing coordinates like "(892,925)" or "[892, 925]" 337 | 338 | Returns: 339 | Tuple of (x,y) coordinates as integers 340 | 341 | Raises: 342 | ValueError: If exactly 2 numbers are not found in the response 343 | """ 344 | # Clean up the response string 345 | resp = response.replace("```", "").strip() 346 | 347 | # Extract numbers using regex 348 | import re 349 | 350 | numbers = re.findall(r"\d+", resp) 351 | if len(numbers) != 2: 352 | raise ValueError( 353 | f"Expected exactly 2 coordinates, found {len(numbers)} numbers in response: {response}" 354 | ) 355 | 356 | return (int(numbers[0]), int(numbers[1])) 357 | 358 | 359 | @retry( 360 | stop=stop_after_attempt(3), 361 | wait=wait_fixed(5), 362 | retry=retry_if_exception_type(ValueError), 363 | reraise=True, 364 | ) 365 | def call_grounding_llm(screenshot, goal, history, elem_description, elem_instruction): 366 | history_list = "\n".join( 367 | [f"\t{j+1}. " + step_info["summary"] for j, step_info in enumerate(history)] 368 | ) 369 | """ 370 | AndroidWorld 371 | """ 372 | prompt = ARIA_UI_PROMPT_TEMPLATE.format( 373 | ultimate_task=goal, 374 | history_list=history_list, 375 | step_idx=len(history) + 1, 376 | instruction=f"description: {elem_description}; instruction: {elem_instruction}", 377 | ) 378 | """ 379 | MiniWob++ 380 | """ 381 | # prompt = ARIA_UI_PROMPT_TEMPLATE_MINIWOB.format( 382 | # instruction=f"{elem_description}", 383 | # ) 384 | 385 | response = aria_ui_utils.request_aria_ui(screenshot, prompt) 386 | 387 | coords = _extract_coords_from_response(response) 388 | return coords 389 | 390 | 391 | class M3A(base_agent.EnvironmentInteractingAgent): 392 | """M3A which stands for Multimodal Autonomous Agent for Android.""" 393 | 394 | def __init__( 395 | self, 396 | env: interface.AsyncEnv, 397 | llm: infer.MultimodalLlmWrapper, 398 | name: str = "M3A", 399 | wait_after_action_seconds: float = 2.0, 400 | ): 401 | """Initializes a M3A Agent. 402 | 403 | Args: 404 | env: The environment. 405 | llm: The multimodal LLM wrapper. 406 | name: The agent name. 407 | wait_after_action_seconds: Seconds to wait for the screen to stablize 408 | after executing an action 409 | """ 410 | super().__init__(env, name) 411 | self.llm = llm 412 | self.history = [] 413 | self.additional_guidelines = None 414 | self.wait_after_action_seconds = wait_after_action_seconds 415 | 416 | def set_task_guidelines(self, task_guidelines: list[str]) -> None: 417 | self.additional_guidelines = task_guidelines 418 | 419 | def reset(self, go_home_on_reset: bool = False): 420 | super().reset(go_home_on_reset) 421 | # Hide the coordinates on screen which might affect the vision model. 
422 | self.env.hide_automation_ui() 423 | self.history = [] 424 | 425 | def step(self, goal: str) -> base_agent.AgentInteractionResult: 426 | step_data = { 427 | "raw_screenshot": None, 428 | "before_screenshot_with_som": None, 429 | "before_ui_elements": [], 430 | "after_screenshot_with_som": None, 431 | "action_prompt": None, 432 | "action_output": None, 433 | "action_output_json": None, 434 | "action_reason": None, 435 | "action_raw_response": None, 436 | "summary_prompt": None, 437 | "summary": None, 438 | "summary_raw_response": None, 439 | } 440 | print("----------step " + str(len(self.history) + 1)) 441 | 442 | state = self.get_post_transition_state() 443 | logical_screen_size = self.env.logical_screen_size 444 | orientation = self.env.orientation 445 | physical_frame_boundary = self.env.physical_frame_boundary 446 | 447 | before_ui_elements = state.ui_elements 448 | step_data["before_ui_elements"] = before_ui_elements 449 | before_ui_elements_list = _generate_ui_elements_description_list( 450 | before_ui_elements, logical_screen_size 451 | ) 452 | step_data["raw_screenshot"] = state.pixels.copy() 453 | before_screenshot = state.pixels.copy() 454 | for index, ui_element in enumerate(before_ui_elements): 455 | if m3a_utils.validate_ui_element(ui_element, logical_screen_size): 456 | m3a_utils.add_ui_element_mark( 457 | before_screenshot, 458 | ui_element, 459 | index, 460 | logical_screen_size, 461 | physical_frame_boundary, 462 | orientation, 463 | ) 464 | step_data["before_screenshot_with_som"] = before_screenshot.copy() 465 | 466 | action_prompt = _pvision_action_selection_prompt( 467 | goal, 468 | [ 469 | "Step " + str(i + 1) + "- " + step_info["summary"] 470 | for i, step_info in enumerate(self.history) 471 | ], 472 | self.additional_guidelines, 473 | ) 474 | step_data["action_prompt"] = action_prompt 475 | action_output, is_safe, raw_response = self.llm.predict_mm( 476 | action_prompt, 477 | [ 478 | step_data["raw_screenshot"], 479 | ], 480 | ) 481 | 482 | if is_safe == False: # pylint: disable=singleton-comparison 483 | # is_safe could be None 484 | action_output = f"""Reason: {m3a_utils.TRIGGER_SAFETY_CLASSIFIER} 485 | Action: {{"action_type": "status", "goal_status": "infeasible"}}""" 486 | 487 | if not raw_response: 488 | raise RuntimeError("Error calling LLM in action selection phase.") 489 | step_data["action_output"] = action_output 490 | step_data["action_raw_response"] = raw_response 491 | 492 | reason, action = m3a_utils.parse_reason_action_output(action_output) 493 | 494 | # If the output is not in the right format, add it to step summary which 495 | # will be passed to next step and return. 496 | if (not reason) or (not action): 497 | print("Action prompt output is not in the correct format.") 498 | step_data["summary"] = ( 499 | "Output for action selection is not in the correct format, so no" 500 | " action is performed." 
501 | ) 502 | self.history.append(step_data) 503 | 504 | return base_agent.AgentInteractionResult( 505 | False, 506 | step_data, 507 | ) 508 | 509 | print("Action: " + action) 510 | print("Reason: " + reason) 511 | step_data["action_reason"] = reason 512 | 513 | try: 514 | action_json = agent_utils.extract_json(action) 515 | converted_action = json_action.JSONAction( 516 | **action_json, 517 | ) 518 | step_data["action_output_json"] = converted_action 519 | except Exception as e: # pylint: disable=broad-exception-caught 520 | print("Failed to convert the output to a valid action.") 521 | print(str(e)) 522 | step_data["summary"] = ( 523 | "Can not parse the output to a valid action. Please make sure to pick" 524 | " the action from the list with required parameters (if any) in the" 525 | " correct JSON format!" 526 | ) 527 | self.history.append(step_data) 528 | 529 | return base_agent.AgentInteractionResult( 530 | False, 531 | step_data, 532 | ) 533 | 534 | action_index = converted_action.index 535 | elem_description = converted_action.target 536 | elem_instruction = converted_action.instruction 537 | 538 | num_ui_elements = len(before_ui_elements) 539 | if ( 540 | converted_action.action_type 541 | in ["click", "long_press", "input_text", "scroll"] 542 | and elem_description is not None 543 | ): 544 | aria_ui_coords = call_grounding_llm( 545 | step_data["raw_screenshot"], 546 | goal, 547 | self.history, 548 | elem_description, 549 | elem_instruction, 550 | ) 551 | physical_coords = aria_ui_utils.convert_coords_to_physical( 552 | aria_ui_coords, 553 | logical_screen_size, 554 | physical_frame_boundary, 555 | orientation, 556 | ) 557 | 558 | converted_action.x = physical_coords[0] 559 | converted_action.y = physical_coords[1] 560 | 561 | # Add mark to the target element. 562 | aria_ui_utils.add_ui_element_mark_coords( 563 | step_data["raw_screenshot"], 564 | aria_ui_coords, 565 | logical_screen_size, 566 | physical_frame_boundary, 567 | orientation, 568 | ) 569 | 570 | if converted_action.action_type == "status": 571 | if converted_action.goal_status == "infeasible": 572 | print("Agent stopped since it thinks mission impossible.") 573 | step_data["summary"] = "Agent thinks the request has been completed." 574 | self.history.append(step_data) 575 | return base_agent.AgentInteractionResult( 576 | True, 577 | step_data, 578 | ) 579 | 580 | if converted_action.action_type == "answer": 581 | print("Agent answered with: " + converted_action.text) 582 | 583 | try: 584 | self.env.execute_action(converted_action) 585 | except Exception as e: # pylint: disable=broad-exception-caught 586 | print("Failed to execute action.") 587 | print(str(e)) 588 | step_data["summary"] = ( 589 | "Can not execute the action, make sure to select the action with" 590 | " the required parameters (if any) in the correct JSON format!" 
591 | ) 592 | return base_agent.AgentInteractionResult( 593 | False, 594 | step_data, 595 | ) 596 | 597 | time.sleep(self.wait_after_action_seconds) 598 | 599 | state = self.env.get_state(wait_to_stabilize=False) 600 | logical_screen_size = self.env.logical_screen_size 601 | orientation = self.env.orientation 602 | physical_frame_boundary = self.env.physical_frame_boundary 603 | after_ui_elements = state.ui_elements 604 | after_ui_elements_list = _generate_ui_elements_description_list( 605 | after_ui_elements, logical_screen_size 606 | ) 607 | after_screenshot = state.pixels.copy() 608 | for index, ui_element in enumerate(after_ui_elements): 609 | if m3a_utils.validate_ui_element(ui_element, logical_screen_size): 610 | m3a_utils.add_ui_element_mark( 611 | after_screenshot, 612 | ui_element, 613 | index, 614 | logical_screen_size, 615 | physical_frame_boundary, 616 | orientation, 617 | ) 618 | 619 | m3a_utils.add_screenshot_label( 620 | step_data["before_screenshot_with_som"], "before" 621 | ) 622 | m3a_utils.add_screenshot_label(after_screenshot, "after") 623 | step_data["after_screenshot_with_som"] = after_screenshot.copy() 624 | 625 | summary_prompt = _summarize_prompt( 626 | action, 627 | reason, 628 | goal, 629 | before_ui_elements_list, 630 | after_ui_elements_list, 631 | ) 632 | summary, is_safe, raw_response = self.llm.predict_mm( 633 | summary_prompt, 634 | [ 635 | before_screenshot, 636 | after_screenshot, 637 | ], 638 | ) 639 | 640 | if is_safe == False: # pylint: disable=singleton-comparison 641 | # is_safe could be None 642 | summary = """Summary triggered LLM safety classifier.""" 643 | 644 | if not raw_response: 645 | print( 646 | "Error calling LLM in summarization phase. This should not happen: " 647 | f"{summary}" 648 | ) 649 | step_data["summary"] = ( 650 | "Some error occurred calling LLM during summarization phase: %s" 651 | % summary 652 | ) 653 | self.history.append(step_data) 654 | return base_agent.AgentInteractionResult( 655 | False, 656 | step_data, 657 | ) 658 | 659 | step_data["summary_prompt"] = summary_prompt 660 | step_data["summary"] = f"Action selected: {action}. {summary}" 661 | print("Summary: " + summary) 662 | step_data["summary_raw_response"] = raw_response 663 | 664 | self.history.append(step_data) 665 | return base_agent.AgentInteractionResult( 666 | False, 667 | step_data, 668 | ) 669 | -------------------------------------------------------------------------------- /AndroidWorld/env/json_action.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The android_world Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Represents an action for Android interaction, parsed from a JSON format.""" 16 | 17 | import dataclasses 18 | import json 19 | from typing import Optional 20 | 21 | 22 | _JSON_SEPARATORS = (',', ':') 23 | 24 | ANSWER = 'answer' 25 | CLICK = 'click' 26 | DOUBLE_TAP = 'double_tap' 27 | INPUT_TEXT = 'input_text' 28 | KEYBOARD_ENTER = 'keyboard_enter' 29 | LONG_PRESS = 'long_press' 30 | NAVIGATE_BACK = 'navigate_back' 31 | NAVIGATE_HOME = 'navigate_home' 32 | OPEN_APP = 'open_app' 33 | SCROLL = 'scroll' 34 | STATUS = 'status' 35 | SWIPE = 'swipe' 36 | UNKNOWN = 'unknown' 37 | WAIT = 'wait' 38 | 39 | _ACTION_TYPES = ( 40 | CLICK, 41 | DOUBLE_TAP, 42 | SCROLL, 43 | SWIPE, 44 | INPUT_TEXT, 45 | NAVIGATE_HOME, 46 | NAVIGATE_BACK, 47 | KEYBOARD_ENTER, 48 | OPEN_APP, 49 | STATUS, 50 | WAIT, 51 | LONG_PRESS, 52 | ANSWER, 53 | UNKNOWN, 54 | ) 55 | 56 | _SCROLL_DIRECTIONS = ('left', 'right', 'down', 'up') 57 | 58 | # Keys of JSON action. 59 | ACTION_TYPE = 'action_type' 60 | INDEX = 'index' 61 | X = 'x' 62 | Y = 'y' 63 | TEXT = 'text' 64 | DIRECTION = 'direction' 65 | APP_NAME = 'app_name' 66 | GOAL_STATUS = 'goal_status' 67 | 68 | 69 | @dataclasses.dataclass() 70 | class JSONAction: 71 | """Represents a parsed JSON action. 72 | 73 | # Example 74 | result_json = {'action_type': 'click', 'x': %d, 'y': %d} 75 | action = JSONAction(**result_json) 76 | 77 | Attributes: 78 | action_type: The action type. 79 | index: The index to click, if action is a click. Either an index or an 80 | (x, y) pair should be provided. See x, y attributes below. 81 | x: The x position to click, if the action is a click. 82 | y: The y position to click, if the action is a click. 83 | text: The text to type, if action is type. 84 | direction: The direction to scroll, if action is scroll. 85 | goal_status: If the status is a 'status' type, indicates the status of the 86 | goal. 87 | app_name: The app name to launch, if the action type is 'open_app'. 88 | keycode: Keycode actions are necessary for an agent to interact with complex 89 | UI elements (like large textareas) that can't be accessed or controlled by 90 | simply tapping, ensuring precise control over navigation and selection in 91 | the interface.
92 | """ 93 | 94 | action_type: Optional[str] = None 95 | index: Optional[str | int] = None 96 | x: Optional[int] = None 97 | y: Optional[int] = None 98 | text: Optional[str] = None 99 | direction: Optional[str] = None 100 | goal_status: Optional[str] = None 101 | app_name: Optional[str] = None 102 | keycode: Optional[str] = None 103 | 104 | # for aria-UI 105 | target: Optional[str] = None 106 | instruction: Optional[str] = None 107 | coords: Optional[tuple[int, int]] = None 108 | 109 | def __post_init__(self): 110 | if self.action_type not in _ACTION_TYPES: 111 | raise ValueError(f'Invalid action type: {self.action_type}') 112 | if self.index is not None: 113 | self.index = int(self.index) 114 | if self.x is not None or self.y is not None: 115 | raise ValueError('Either an index or an (x, y) pair should be provided.') 116 | if self.direction and self.direction not in _SCROLL_DIRECTIONS: 117 | raise ValueError(f'Invalid scroll direction: {self.direction}') 118 | if self.text is not None and not isinstance(self.text, str): 119 | self.text = str(self.text) 120 | if self.keycode is not None and not self.keycode.startswith('KEYCODE_'): 121 | raise ValueError(f'Invalid keycode: {self.keycode}') 122 | 123 | def __repr__(self) -> str: 124 | properties = [] 125 | for key, value in self.__dict__.items(): 126 | if value is not None: 127 | if isinstance(value, float): 128 | value = f'{value:.3f}' 129 | properties.append(f'{key}={value!r}') 130 | return f"JSONAction({', '.join(properties)})" 131 | 132 | def __eq__(self, other): 133 | if isinstance(other, JSONAction): 134 | return _compare_actions(self, other) 135 | return False 136 | 137 | def __ne__(self, other): 138 | return not self.__eq__(other) 139 | 140 | def json_str(self) -> str: 141 | non_null = {} 142 | for key, value in self.__dict__.items(): 143 | if value is not None: 144 | non_null[key] = value 145 | return json.dumps(non_null, separators=_JSON_SEPARATORS) 146 | 147 | 148 | def _compare_actions(a: JSONAction, b: JSONAction) -> bool: 149 | """Compares two JSONActions. 150 | 151 | Args: 152 | a: The first action. 153 | b: The second action. 154 | 155 | Returns: 156 | If the actions are equal. 157 | """ 158 | # Ignore cases. 159 | if a.app_name is not None and b.app_name is not None: 160 | app_name_match = a.app_name.lower() == b.app_name.lower() 161 | else: 162 | app_name_match = a.app_name == b.app_name 163 | 164 | if a.text is not None and b.text is not None: 165 | text_match = a.text.lower() == b.text.lower() 166 | else: 167 | text_match = a.text == b.text 168 | 169 | # Compare the non-metadata fields. 170 | return ( 171 | app_name_match 172 | and text_match 173 | and a.action_type == b.action_type 174 | and a.index == b.index 175 | and a.x == b.x 176 | and a.y == b.y 177 | and a.keycode == b.keycode 178 | and a.direction == b.direction 179 | and a.goal_status == b.goal_status 180 | ) 181 | -------------------------------------------------------------------------------- /AndroidWorld/run_aria_ui.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The Aria-UI Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Copyright 2024 The android_world Authors. 16 | # 17 | # Licensed under the Apache License, Version 2.0 (the "License"); 18 | # you may not use this file except in compliance with the License. 19 | # You may obtain a copy of the License at 20 | # 21 | # http://www.apache.org/licenses/LICENSE-2.0 22 | # 23 | # Unless required by applicable law or agreed to in writing, software 24 | # distributed under the License is distributed on an "AS IS" BASIS, 25 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 26 | # See the License for the specific language governing permissions and 27 | # limitations under the License. 28 | 29 | """Run eval suite. 30 | 31 | The run.py module is used to run a suite of tasks, with configurable task 32 | combinations, environment setups, and agent configurations. You can run specific 33 | tasks or all tasks in the suite and customize various settings using the 34 | command-line flags. 35 | """ 36 | 37 | from collections.abc import Sequence 38 | import os 39 | 40 | from absl import app 41 | from absl import flags 42 | from absl import logging 43 | from android_world import checkpointer as checkpointer_lib 44 | from android_world import registry 45 | from android_world import suite_utils 46 | from android_world.agents import base_agent 47 | from android_world.agents import human_agent 48 | from android_world.agents import infer 49 | from android_world.agents import m3a 50 | from android_world.agents import random_agent 51 | from android_world.agents import seeact 52 | from android_world.agents import t3a 53 | from android_world.env import env_launcher 54 | from android_world.env import interface 55 | 56 | from android_world.agents import m3a_aria_ui 57 | 58 | logging.set_verbosity(logging.WARNING) 59 | 60 | os.environ["GRPC_VERBOSITY"] = "ERROR" # Only show errors 61 | os.environ["GRPC_TRACE"] = "none" # Disable tracing 62 | 63 | 64 | def _find_adb_directory() -> str: 65 | """Returns the directory where adb is located.""" 66 | potential_paths = [ 67 | os.path.expanduser("~/Library/Android/sdk/platform-tools/adb"), 68 | os.path.expanduser("~/Android/Sdk/platform-tools/adb"), 69 | ] 70 | for path in potential_paths: 71 | if os.path.isfile(path): 72 | return path 73 | raise EnvironmentError( 74 | "adb not found in the common Android SDK paths. Please install Android" 75 | " SDK and ensure adb is in one of the expected directories. If it's" 76 | " already installed, point to the installed location." 77 | ) 78 | 79 | 80 | _ADB_PATH = flags.DEFINE_string( 81 | "adb_path", 82 | _find_adb_directory(), 83 | "Path to adb. Set if not installed through SDK.", 84 | ) 85 | _EMULATOR_SETUP = flags.DEFINE_boolean( 86 | "perform_emulator_setup", 87 | False, 88 | "Whether to perform emulator setup. This must be done once and only once" 89 | " before running Android World. After an emulator is setup, this flag" 90 | " should always be False.", 91 | ) 92 | _DEVICE_CONSOLE_PORT = flags.DEFINE_integer( 93 | "console_port", 94 | 5554, 95 | "The console port of the running Android device. 
This can usually be" 96 | " retrieved by looking at the output of `adb devices`. In general, the" 97 | " first connected device is port 5554, the second is 5556, and" 98 | " so on.", 99 | ) 100 | 101 | _SUITE_FAMILY = flags.DEFINE_enum( 102 | "suite_family", 103 | registry.TaskRegistry.ANDROID_WORLD_FAMILY, 104 | [ 105 | # Families from the paper. 106 | registry.TaskRegistry.ANDROID_WORLD_FAMILY, 107 | registry.TaskRegistry.MINIWOB_FAMILY_SUBSET, 108 | # Other families for more testing. 109 | registry.TaskRegistry.MINIWOB_FAMILY, 110 | registry.TaskRegistry.ANDROID_FAMILY, 111 | registry.TaskRegistry.INFORMATION_RETRIEVAL_FAMILY, 112 | ], 113 | "Suite family to run. See registry.py for more information.", 114 | ) 115 | _TASK_RANDOM_SEED = flags.DEFINE_integer( 116 | "task_random_seed", 30, "Random seed for task randomness." 117 | ) 118 | 119 | _TASKS = flags.DEFINE_list( 120 | "tasks", 121 | None, 122 | "List of specific tasks to run in the given suite family. If None, run all" 123 | " tasks in the suite family.", 124 | ) 125 | _N_TASK_COMBINATIONS = flags.DEFINE_integer( 126 | "n_task_combinations", 127 | 1, 128 | "Number of task instances to run for each task template.", 129 | ) 130 | 131 | _CHECKPOINT_DIR = flags.DEFINE_string( 132 | "checkpoint_dir", 133 | "", 134 | "The directory to save checkpoints and resume evaluation from. If the" 135 | " directory contains existing checkpoint files, evaluation will resume from" 136 | " the latest checkpoint. If the directory is empty or does not exist, a new" 137 | " directory will be created.", 138 | ) 139 | _OUTPUT_PATH = flags.DEFINE_string( 140 | "output_path", 141 | os.path.expanduser("~/android_world/runs"), 142 | "The path to save results to if not resuming from a checkpoint is not" " provided.", 143 | ) 144 | 145 | # Agent specific. 146 | _AGENT_NAME = flags.DEFINE_string("agent_name", "m3a_aria_ui", help="Agent name.") 147 | 148 | _FIXED_TASK_SEED = flags.DEFINE_boolean( 149 | "fixed_task_seed", 150 | False, 151 | "Whether to use the same task seed when running multiple task combinations" 152 | " (n_task_combinations > 1).", 153 | ) 154 | 155 | 156 | # MiniWoB is very lightweight and new screens/View Hierarchy load quickly. 157 | _MINIWOB_TRANSITION_PAUSE = 0.2 158 | 159 | # Additional guidelines for the MiniWob tasks. 160 | _MINIWOB_ADDITIONAL_GUIDELINES = [ 161 | ( 162 | "This task is running in a mock app, you must stay in this app and" 163 | " DO NOT use the `navigate_home` action." 
164 | ), 165 | ] 166 | 167 | 168 | def _get_agent( 169 | env: interface.AsyncEnv, 170 | family: str | None = None, 171 | ) -> base_agent.EnvironmentInteractingAgent: 172 | """Gets agent.""" 173 | print("Initializing agent...") 174 | agent = None 175 | 176 | agent = m3a_aria_ui.M3A(env, infer.Gpt4Wrapper("gpt-4o")) 177 | 178 | if ( 179 | agent.name in ["M3A", "T3A", "SeeAct"] 180 | and family 181 | and family.startswith("miniwob") 182 | and hasattr(agent, "set_task_guidelines") 183 | ): 184 | agent.set_task_guidelines(_MINIWOB_ADDITIONAL_GUIDELINES) 185 | agent.name = _AGENT_NAME.value 186 | 187 | return agent 188 | 189 | 190 | def _main() -> None: 191 | """Runs eval suite and gets rewards back.""" 192 | env = env_launcher.load_and_setup_env( 193 | console_port=_DEVICE_CONSOLE_PORT.value, 194 | emulator_setup=_EMULATOR_SETUP.value, 195 | adb_path=_ADB_PATH.value, 196 | ) 197 | 198 | n_task_combinations = _N_TASK_COMBINATIONS.value 199 | task_registry = registry.TaskRegistry() 200 | suite = suite_utils.create_suite( 201 | task_registry.get_registry(family=_SUITE_FAMILY.value), 202 | n_task_combinations=n_task_combinations, 203 | seed=_TASK_RANDOM_SEED.value, 204 | tasks=_TASKS.value, 205 | use_identical_params=_FIXED_TASK_SEED.value, 206 | ) 207 | suite.suite_family = _SUITE_FAMILY.value 208 | 209 | agent = _get_agent(env, _SUITE_FAMILY.value) 210 | 211 | if _SUITE_FAMILY.value.startswith("miniwob"): 212 | # MiniWoB pages change quickly, don't need to wait for screen to stabilize. 213 | agent.transition_pause = _MINIWOB_TRANSITION_PAUSE 214 | else: 215 | agent.transition_pause = None 216 | 217 | if _CHECKPOINT_DIR.value: 218 | checkpoint_dir = _CHECKPOINT_DIR.value 219 | else: 220 | checkpoint_dir = checkpointer_lib.create_run_directory(_OUTPUT_PATH.value) 221 | 222 | print( 223 | f"Starting eval with agent {_AGENT_NAME.value} and writing to" 224 | f" {checkpoint_dir}" 225 | ) 226 | suite_utils.run( 227 | suite, 228 | agent, 229 | checkpointer=checkpointer_lib.IncrementalCheckpointer(checkpoint_dir), 230 | demo_mode=False, 231 | ) 232 | print( 233 | f"Finished running agent {_AGENT_NAME.value} on {_SUITE_FAMILY.value}" 234 | f" family. Wrote to {checkpoint_dir}." 235 | ) 236 | env.close() 237 | 238 | 239 | def main(argv: Sequence[str]) -> None: 240 | del argv 241 | _main() 242 | 243 | 244 | if __name__ == "__main__": 245 | app.run(main) 246 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | Project Logo 3 |
4 | 5 |
6 | 7 | [🤗 Aria-UI Demo (Try it out!)](https://huggingface.co/spaces/Aria-UI/Aria-UI) • 8 | [🤗 Aria-UI Models](https://huggingface.co/Aria-UI/Aria-UI-base) • 9 | [🤗 Aria-UI Context-aware Models](https://huggingface.co/Aria-UI/Aria-UI-context-aware) • 10 | [🤗 Aria-UI Datasets](https://huggingface.co/datasets/Aria-UI/Aria-UI_Data) • 11 | 12 | [🤗 Aria-UI Context-aware Datasets](https://huggingface.co/datasets/Aria-UI/Aria-UI_Context-aware_Data) • 13 | [🌐 Project Page](https://ariaui.github.io) • 14 | [📝 Paper](https://arxiv.org/abs/2412.16256) • 15 | [🗃️ Aria-UI at ModelScope](https://modelscope.cn/models/AriaUI/Aria-UI/) 16 |
17 | 18 | --- 19 | 20 | 21 | ## 📰 News 22 | - **[2025-02-08]** We released all context-aware episode training data of Aria-UI! It has around **992K** instruction-output pairs. Try it in your existing projects at [🤗 Aria-UI Context-aware Datasets](https://huggingface.co/datasets/Aria-UI/Aria-UI_Context-aware_Data). 23 | 24 | - **[2025-01-23]** We released the context-aware version of Aria-UI! Check it at [🤗 Aria-UI Context-aware Models](https://huggingface.co/Aria-UI/Aria-UI-context-aware). It typically brings stronger performance on dynamic agent tasks like `AndroidWorld` and `OSWorld`. 25 | 26 | - **[2025-01-10]** We are excited to release the **M3A Agent powered by Aria-UI** for **AndroidWorld**! Experience enhanced task success rates and seamless integration with the latest in grounding instruction understanding. Check it out under `AndroidWorld/`. 27 | 28 | 29 | https://github.com/user-attachments/assets/48c61813-7f63-4985-a3c9-0e325ca764fe 30 | 31 | ## 🌇 Overview 32 | 33 | ✨ **Versatile Grounding Instruction Understanding:** 34 | Aria-UI handles diverse grounding instructions, excelling in interpreting varied formats, ensuring robust adaptability across dynamic scenarios or when paired with diverse planning agents. 35 | 36 | 📝 **Context-aware Grounding:** 37 | Aria-UI effectively leverages historical input, whether in pure text or text-image-interleaved formats, to improve grounding accuracy (see the prompt sketch at the end of this section). 38 | 39 | ⚡ **Lightweight and Fast:** 40 | Aria-UI is a mixture-of-experts model with 3.9B activated parameters per token. It efficiently encodes GUI input of variable sizes and aspect ratios, with ultra-resolution support. 41 | 42 | 🎉 **Superior Performance:** 43 | Aria-UI sets new state-of-the-art results on offline and online agent benchmarks. 44 | 🏆 **1st place** on **AndroidWorld** with **44.8%** task success rate and 45 | 🥉 **3rd place** on **OSWorld** with **15.2%** task success rate (Dec. 2024). 46 | 47 |
48 | Aria-UI Overview 49 |
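Aria-UI takes a GUI screenshot plus a natural-language instruction and answers with a target point on a relative 0-1000 scale (see the prompt template in the Quick Start snippets below and `draw_coord` in `utils.py`). As a minimal sketch of how such a prediction maps back to absolute pixels — the point value here is a made-up example, not a real model output:

```python
from PIL import Image

def to_absolute(point, image):
    # Aria-UI points are expressed on a 0-1000 scale relative to the image
    # width and height (the same convention used by draw_coord in utils.py).
    width, height = image.size
    return point[0] / 1000 * width, point[1] / 1000 * height

image = Image.open("examples/aria.png")
point = [503, 377]  # hypothetical parsed output, e.g. from ast.literal_eval(response)
x, y = to_absolute(point, image)
print(f"Click target at ({x:.0f}, {y:.0f}) px")
```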
50 | 51 | ## 🚀 Quick Start 52 | 53 | ### Installation 54 | ```bash 55 | pip install transformers==4.45.0 accelerate==0.34.1 sentencepiece==0.2.0 torchvision requests torch Pillow 56 | pip install flash-attn --no-build-isolation 57 | # For better inference performance, you can install grouped-gemm, which may take 3-5 minutes to install 58 | pip install grouped_gemm==0.1.6 59 | ``` 60 | 61 | ### Inference with vLLM (strongly recommended) 62 | First, make sure you install the appropriate version (for example, `vllm==0.6.6.dev3+g866fa455`) of vLLM so that it supports Aria-UI: 63 | ```bash 64 | export VLLM_COMMIT=866fa4550d572f4ff3521ccf503e0df2e76591a1 # use full commit hash from the main branch 65 | pip install https://wheels.vllm.ai/${VLLM_COMMIT}/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl 66 | ``` 67 | 71 | 72 | Here is a code snippet for Aria-UI with vLLM. 73 | ```python 74 | from PIL import Image, ImageDraw 75 | from transformers import AutoTokenizer 76 | from vllm import LLM, SamplingParams 77 | import ast 78 | model_path = "Aria-UI/Aria-UI-base" 79 | def main(): 80 | llm = LLM( 81 | model=model_path, 82 | tokenizer_mode="slow", 83 | dtype="bfloat16", 84 | trust_remote_code=True, 85 | ) 86 | tokenizer = AutoTokenizer.from_pretrained( 87 | model_path, trust_remote_code=True, use_fast=False 88 | ) 89 | instruction = "Try Aria." 90 | messages = [ 91 | { 92 | "role": "user", 93 | "content": [ 94 | {"type": "image"}, 95 | { 96 | "type": "text", 97 | "text": "Given a GUI image, what are the relative (0-1000) pixel point coordinates for the element corresponding to the following instruction or description: " + instruction, 98 | } 99 | ], 100 | } 101 | ] 102 | message = tokenizer.apply_chat_template(messages, add_generation_prompt=True) 103 | outputs = llm.generate( 104 | { 105 | "prompt_token_ids": message, 106 | "multi_modal_data": { 107 | "image": [ 108 | Image.open("examples/aria.png"), 109 | ], 110 | "max_image_size": 980, # [Optional] The max image patch size, default `980` 111 | "split_image": True, # [Optional] whether to split the images, default `True` 112 | }, 113 | }, 114 | sampling_params=SamplingParams(max_tokens=50, top_k=1, stop=["<|im_end|>"]), 115 | ) 116 | for o in outputs: 117 | generated_tokens = o.outputs[0].token_ids 118 | response = tokenizer.decode(generated_tokens, skip_special_tokens=True) 119 | print(response) 120 | coords = ast.literal_eval(response.replace("<|im_end|>", "").replace("```", "").replace(" ", "").strip()) 121 | return coords 122 | if __name__ == "__main__": 123 | main() 124 | ``` 125 | ### Inference with Transformers (not recommended) 126 | You can also use the original `transformers` API for Aria-UI. For instance: 127 | ```python 128 | import argparse 129 | import torch 130 | import os 131 | import json 132 | from tqdm import tqdm 133 | import time 134 | from PIL import Image, ImageDraw 135 | from transformers import AutoModelForCausalLM, AutoProcessor 136 | import ast 137 | 138 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 139 | 140 | model_path = "Aria-UI/Aria-UI-base" 141 | model = AutoModelForCausalLM.from_pretrained( 142 | model_path, 143 | device_map="auto", 144 | torch_dtype=torch.bfloat16, 145 | trust_remote_code=True, 146 | ) 147 | processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) 148 | image_file = "./examples/aria.png" 149 | instruction = "Try Aria."
150 | image = Image.open(image_file).convert("RGB") 151 | 152 | messages = [ 153 | { 154 | "role": "user", 155 | "content": [ 156 | {"text": None, "type": "image"}, 157 | {"text": instruction, "type": "text"}, 158 | ], 159 | } 160 | ] 161 | text = processor.apply_chat_template(messages, add_generation_prompt=True) 162 | inputs = processor(text=text, images=image, return_tensors="pt") 163 | inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype) 164 | inputs = {k: v.to(model.device) for k, v in inputs.items()} 165 | with torch.inference_mode(), torch.amp.autocast("cuda", dtype=torch.bfloat16): 166 | output = model.generate( 167 | **inputs, 168 | max_new_tokens=50, 169 | stop_strings=["<|im_end|>"], 170 | tokenizer=processor.tokenizer, 171 | # do_sample=True, 172 | # temperature=0.9, 173 | ) 174 | output_ids = output[0][inputs["input_ids"].shape[1] :] 175 | response = processor.decode(output_ids, skip_special_tokens=True) 176 | print(response) 177 | 178 | coords = ast.literal_eval(response.replace("<|im_end|>", "").replace("```", "").replace(" ", "").strip()) 179 | ``` 180 | 181 | ## Citation 182 | 183 | If you find our work helpful, please consider citing: 184 | 185 | ```bibtex 186 | @article{ariaui, 187 | title={Aria-UI: Visual Grounding for GUI Instructions}, 188 | author={Yuhao Yang and Yue Wang and Dongxu Li and Ziyang Luo and Bei Chen and Chao Huang and Junnan Li}, 189 | year={2024}, 190 | journal={arXiv preprint arXiv:2412.16256}, 191 | } 192 | ``` 193 | 194 | ## Acknowledgments 195 | 196 | We thank [Tianbao Xie](https://tianbaoxie.com), [Yiheng Xu](https://yihengxu.com) for their valuable discussion and suggestions. 197 | 198 | ## More demos 199 | 200 | 201 | https://github.com/user-attachments/assets/cf7f26bf-d5ad-4146-9334-bb64a3ab48a6 202 | 203 | https://github.com/user-attachments/assets/1a5bfd18-0a1d-49d7-99a6-597329b0812d 204 | -------------------------------------------------------------------------------- /aria_ui_hf.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | from PIL import Image 4 | from transformers import AutoModelForCausalLM, AutoProcessor 5 | import ast 6 | from utils import draw_coord, resize_image 7 | 8 | 9 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" 10 | 11 | model_path = "Aria-UI/Aria-UI-base" 12 | 13 | model = AutoModelForCausalLM.from_pretrained( 14 | model_path, 15 | device_map="auto", 16 | torch_dtype=torch.bfloat16, 17 | trust_remote_code=True, 18 | ) 19 | processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) 20 | 21 | image_file = "./examples/aria.png" 22 | instruction = "Try Aria." 23 | 24 | image = Image.open(image_file).convert("RGB") 25 | 26 | # NOTE: using huggingface on a single 80GB GPU, we resize the image to 1920px on the long side to prevent OOM. this is unnecessary with vllm. 
27 | image = resize_image(image, long_size=1920) 28 | 29 | messages = [ 30 | { 31 | "role": "user", 32 | "content": [ 33 | {"text": None, "type": "image"}, 34 | {"text": instruction, "type": "text"}, 35 | ], 36 | } 37 | ] 38 | 39 | text = processor.apply_chat_template(messages, add_generation_prompt=True) 40 | inputs = processor(text=text, images=image, return_tensors="pt") 41 | inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype) 42 | inputs = {k: v.to(model.device) for k, v in inputs.items()} 43 | 44 | with torch.inference_mode(), torch.amp.autocast("cuda", dtype=torch.bfloat16): 45 | output = model.generate( 46 | **inputs, 47 | max_new_tokens=50, 48 | stop_strings=["<|im_end|>"], 49 | tokenizer=processor.tokenizer, 50 | # do_sample=True, 51 | # temperature=0.9, 52 | ) 53 | 54 | output_ids = output[0][inputs["input_ids"].shape[1] :] 55 | response = processor.decode(output_ids, skip_special_tokens=True) 56 | print(response) 57 | 58 | coords = ast.literal_eval(response.replace("<|im_end|>", "").replace("```", "").replace(" ", "").strip()) 59 | # Save the image with the predicted coordinates 60 | image = draw_coord(image, coords) 61 | image.save("output.png") -------------------------------------------------------------------------------- /aria_ui_vllm.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | from transformers import AutoTokenizer 3 | from vllm import LLM, SamplingParams 4 | import ast 5 | from utils import draw_coord 6 | 7 | 8 | model_path = "Aria-UI/Aria-UI-base" 9 | 10 | def main(): 11 | llm = LLM( 12 | model=model_path, 13 | tokenizer_mode="slow", 14 | dtype="bfloat16", 15 | trust_remote_code=True, 16 | ) 17 | 18 | tokenizer = AutoTokenizer.from_pretrained( 19 | model_path, trust_remote_code=True, use_fast=False 20 | ) 21 | 22 | instruction = "Try Aria." 
23 | image_path = "examples/aria.png" 24 | 25 | messages = [ 26 | { 27 | "role": "user", 28 | "content": [ 29 | {"type": "image"}, 30 | { 31 | "type": "text", 32 | "text": "Given a GUI image, what are the relative (0-1000) pixel point coordinates for the element corresponding to the following instruction or description: " + instruction, 33 | } 34 | ], 35 | } 36 | ] 37 | 38 | message = tokenizer.apply_chat_template(messages, add_generation_prompt=True) 39 | 40 | outputs = llm.generate( 41 | { 42 | "prompt_token_ids": message, 43 | "multi_modal_data": { 44 | "image": [ 45 | Image.open(image_path), 46 | ], 47 | "max_image_size": 980, # [Optional] The max image patch size, default `980`, maximum `980`, the image size for splitted blocks 48 | "split_image": True, # [Optional] whether to split the images, default `True` 49 | }, 50 | }, 51 | sampling_params=SamplingParams(max_tokens=50, top_k=1, stop=["<|im_end|>"]), 52 | ) 53 | 54 | for o in outputs: 55 | generated_tokens = o.outputs[0].token_ids 56 | response = tokenizer.decode(generated_tokens, skip_special_tokens=True) 57 | print(response) 58 | coords = ast.literal_eval(response.replace("<|im_end|>", "").replace("```", "").replace(" ", "").strip()) 59 | image = draw_coord(Image.open("examples/aria.png"), coords) 60 | image.save("output.png") 61 | 62 | 63 | if __name__ == "__main__": 64 | main() 65 | -------------------------------------------------------------------------------- /assets/aria_ui_framework_v4.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AriaUI/Aria-UI/2f87eb7a9b8b910c9f709ac2fe7725afd6799fad/assets/aria_ui_framework_v4.pdf -------------------------------------------------------------------------------- /assets/aria_ui_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AriaUI/Aria-UI/2f87eb7a9b8b910c9f709ac2fe7725afd6799fad/assets/aria_ui_logo.png -------------------------------------------------------------------------------- /assets/logo_long.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AriaUI/Aria-UI/2f87eb7a9b8b910c9f709ac2fe7725afd6799fad/assets/logo_long.png -------------------------------------------------------------------------------- /assets/overall.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AriaUI/Aria-UI/2f87eb7a9b8b910c9f709ac2fe7725afd6799fad/assets/overall.png -------------------------------------------------------------------------------- /assets/performance_spider.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AriaUI/Aria-UI/2f87eb7a9b8b910c9f709ac2fe7725afd6799fad/assets/performance_spider.pdf -------------------------------------------------------------------------------- /assets/seo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AriaUI/Aria-UI/2f87eb7a9b8b910c9f709ac2fe7725afd6799fad/assets/seo.png -------------------------------------------------------------------------------- /examples/aria.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/AriaUI/Aria-UI/2f87eb7a9b8b910c9f709ac2fe7725afd6799fad/examples/aria.png -------------------------------------------------------------------------------- 
/utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageDraw 2 | 3 | def draw_coord(image, coords): 4 | width, height = image.size 5 | draw = ImageDraw.Draw(image) 6 | 7 | abs_coords = [coords[0]/1000*width, coords[1]/1000*height] 8 | 9 | draw.circle(abs_coords, 10, fill="red") 10 | return image 11 | 12 | def resize_image(image, long_size=1920): 13 | width, height = image.size 14 | if width > height: 15 | new_width = long_size 16 | new_height = int(height * long_size / width) 17 | else: 18 | new_height = long_size 19 | new_width = int(width * long_size / height) 20 | return image.resize((new_width, new_height)) --------------------------------------------------------------------------------
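For reference, a minimal usage sketch of the two helpers in `utils.py` above; the coordinate value is a hypothetical Aria-UI output, and note that `draw_coord` relies on `ImageDraw.circle`, which requires a recent Pillow release (10.4+).

```python
from PIL import Image
from utils import draw_coord, resize_image

image = Image.open("examples/aria.png").convert("RGB")
image = resize_image(image, long_size=1920)  # cap the longer side at 1920 px, preserving aspect ratio
coords = [503, 377]                          # hypothetical model output on the 0-1000 relative scale
image = draw_coord(image, coords)            # rescale to absolute pixels and mark the point in red
image.save("output_annotated.png")
```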