├── .DS_Store ├── README.md ├── collection ├── README.md ├── genesis_rm.py ├── mobile_runner.py ├── random_walk_aw.py ├── random_walk_web.py ├── run_mobile_runner.py └── run_random_walk_aw.py ├── evaluation ├── .DS_Store ├── analysis │ └── diversity_analysis.ipynb ├── android_control │ ├── ac_eval.py │ ├── internvl2_inference.py │ ├── qwen2vl_inference.py │ └── run_ac_inference.sh ├── android_world │ └── README.md └── eval_json_files │ ├── ac_high_processing.jsonl │ ├── ac_low_processing.jsonl │ ├── android_control_test_data.jsonl │ └── android_control_test_subsplits.json ├── faq.md └── static ├── OS-Genesis-Badge.png └── OS-Genesis.png /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OS-Copilot/OS-Genesis/9ecbe594352f254b9a9228468f9ca9a77b2388a2/.DS_Store -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # OS-Genesis 2 | 3 | 4 | overview 5 | 6 | 7 | [![arXiv](https://img.shields.io/badge/arXiv-2412.19723-b31b1b.svg)](https://arxiv.org/abs/2412.19723) 8 | ![License](https://img.shields.io/badge/License-MIT-blue) 9 | [![Paper page](https://huggingface.co/datasets/huggingface/badges/resolve/main/paper-page-sm.svg)](https://huggingface.co/papers/2412.19723) 10 | [![Generic badge](https://img.shields.io/badge/WeChat-机器之心-green.svg?logo=wechat)](https://mp.weixin.qq.com/s/_gu3NSCpAbAE1A8mEhGD7Q) 11 | 12 | 15 | 16 | 17 | This repository contains the code and data for the ACL 2025 paper [OS-Genesis: Automating GUI Agent Trajectory Construction via Reverse Task Synthesis](https://arxiv.org/abs/2412.19723). 18 | > We are uploading the data and checkpoints. Due to bandwidth limitations, this will take some time. Stay tuned! 
19 | 20 | 21 | ## Overview 22 | 23 | We introduce OS-Genesis, an interaction-driven pipeline for synthesizing high-quality and diverse GUI agent trajectory data without human supervision or predefined tasks. By leveraging reverse task synthesis and a trajectory reward model, OS-Genesis enables effective end-to-end training of GUI agents. 24 | 25 | 26 | 27 | overview 28 | 29 | 30 | ## Training 31 | 32 | For training details and instructions, please refer to the [InternVL2 documentation](https://internvl.readthedocs.io/en/latest/get_started/installation.html) and the [Qwen2-VL repository](https://github.com/QwenLM/Qwen2-VL). 33 | 34 | ## Evaluation 35 | ### AndroidControl 36 | To evaluate on the AndroidControl benchmark, follow the steps below: 37 | 38 | 1. **Clone the GitHub Repository:** 39 | 40 | ``` 41 | git clone https://github.com/OS-Copilot/OS-Genesis.git 42 | ``` 43 | 44 | 2. **Inference:** 45 | ``` 46 | cd OS-Genesis/evaluation/android_control 47 | bash run_ac_inference.sh $dataset $checkpoint 48 | ``` 49 | 50 | 3.
**Evaluation:** 51 | ``` 52 | python ac_eval.py 53 | ``` 54 | 55 | ## Mobile 56 | ### AndroidControl 57 | 58 | | Model Name | Base Model | Training Data | HF Link | 59 | | :-------------: | :-------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------: | :---------------------------------------------------------: | 60 | | OS-Genesis-4B-AC | [InternVL2-4B](https://huggingface.co/OpenGVLab/InternVL2-4B) | [OS-Genesis-ac-training-data](https://huggingface.co/datasets/OS-Copilot/OS-Genesis-mobile-data/blob/main/os_genesis_ac_training_data.jsonl) | [🤗 link](https://huggingface.co/OS-Copilot/OS-Genesis-4B-AC) | 61 | | OS-Genesis-7B-AC | [Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) | [OS-Genesis-ac-training-data](https://huggingface.co/datasets/OS-Copilot/OS-Genesis-mobile-data/blob/main/os_genesis_ac_training_data.jsonl) | [🤗 link](https://huggingface.co/OS-Copilot/OS-Genesis-7B-AC) | 62 | | OS-Genesis-8B-AC | [InternVL2-8B](https://huggingface.co/OpenGVLab/InternVL2-8B) | [OS-Genesis-ac-training-data](https://huggingface.co/datasets/OS-Copilot/OS-Genesis-mobile-data/blob/main/os_genesis_ac_training_data.jsonl) | [🤗 link](https://huggingface.co/OS-Copilot/OS-Genesis-8B-AC) | 63 | 64 | ### AndroidWorld 65 | 66 | | Model Name | Base Model | Training Data | HF Link | 67 | | :-------------: | :-------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------: | :---------------------------------------------------------: | 68 | | OS-Genesis-4B-AW | [InternVL2-4B](https://huggingface.co/OpenGVLab/InternVL2-4B) | [OS-Genesis-aw-training-data](https://huggingface.co/datasets/OS-Copilot/OS-Genesis-mobile-data/blob/main/os_genesis_aw_training_data.jsonl) | [🤗 link](https://huggingface.co/OS-Copilot/OS-Genesis-4B-AW) | 69 | | OS-Genesis-7B-AW |
[Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) | [OS-Genesis-aw-training-data](https://huggingface.co/datasets/OS-Copilot/OS-Genesis-mobile-data/blob/main/os_genesis_aw_training_data.jsonl) | [🤗 link](https://huggingface.co/OS-Copilot/OS-Genesis-7B-AW) | 70 | | OS-Genesis-8B-AW | [InternVL2-8B](https://huggingface.co/OpenGVLab/InternVL2-8B) | [OS-Genesis-aw-training-data](https://huggingface.co/datasets/OS-Copilot/OS-Genesis-mobile-data/blob/main/os_genesis_aw_training_data.jsonl) | [🤗 link](https://huggingface.co/OS-Copilot/OS-Genesis-8B-AW) | 71 | 72 | 73 | ## Web 74 | 75 | | Model Name | Base Model | Training Data | HF Link | 76 | | :-------------: | :-------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------: | :---------------------------------------------------------: | 77 | | OS-Genesis-4B-WA | [InternVL2-4B](https://huggingface.co/OpenGVLab/InternVL2-4B) | [OS-Genesis-web-training-data](https://huggingface.co/datasets/OS-Copilot/OS-Genesis-web-data/blob/main/os_genesis_web_training.jsonl) | [🤗 link](https://huggingface.co/OS-Copilot/OS-Genesis-4B-WA) | 78 | | OS-Genesis-7B-WA | [Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) | [OS-Genesis-web-training-data](https://huggingface.co/datasets/OS-Copilot/OS-Genesis-web-data/blob/main/os_genesis_web_training.jsonl) | [🤗 link](https://huggingface.co/OS-Copilot/OS-Genesis-7B-WA) | 79 | | OS-Genesis-8B-WA | [InternVL2-8B](https://huggingface.co/OpenGVLab/InternVL2-8B) | [OS-Genesis-web-training-data](https://huggingface.co/datasets/OS-Copilot/OS-Genesis-web-data/blob/main/os_genesis_web_training.jsonl) | [🤗 link](https://huggingface.co/OS-Copilot/OS-Genesis-8B-WA) | 80 | 81 | 82 | ## More Resources 83 | 84 | ### Raw collected triples 85 | 86 | In addition to our complete trajectory data on HuggingFace, we also provide collected raw `` triples. 
You can use them to reproduce the process of reverse task synthesis directly, without re-collecting them from emulators yourself 😄. The screenshots and corresponding texts (with SoM info contained) are provided below: 87 | 88 | | Data Type | Screenshots | Data JSON | 89 | | :-------------: | :-------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------: | 90 | | Mobile | [Screenshots](https://drive.google.com/file/d/1ILyz_-DDOdAk32kue1lEPaV50YzQ5c4v/view?usp=sharing) | [Data JSON](https://drive.google.com/file/d/1dSxNf-co4LGh93NoiUgWKdbcf8Mo_VWG/view?usp=sharing) | 91 | | Web | [Screenshots](https://drive.google.com/file/d/1X2QktZ51OUofZ43vDGB4RuAPlXbdf5ua/view?usp=sharing) | [Data JSON](https://drive.google.com/file/d/1mDxhonGnd3wZbNQgWMVpYEkPW26_FVg8/view?usp=sharing) | 92 | 93 | Feel free to email me if you require additional data of this kind. 94 | 95 | ## FAQ ❓ 96 | 97 | We have collected some questions from emails, Hugging Face, and WeChat communications. 
Please check the [FAQ](https://github.com/OS-Copilot/OS-Genesis/blob/main/faq.md) 🤖 98 | 99 | ## Citation 📖 100 | 101 | 🫶 If you are interested in our work or find this repository / our data helpful, please consider using the following citation format when referencing our paper: 102 | 103 | ```bibtex 104 | @article{sun2024genesis, 105 | title={OS-Genesis: Automating GUI Agent Trajectory Construction via Reverse Task Synthesis}, 106 | author={Sun, Qiushi and Cheng, Kanzhi and Ding, Zichen and Jin, Chuanyang and Wang, Yian and Xu, Fangzhi and Wu, Zhenyu and Jia, Chengyou and Chen, Liheng and Liu, Zhoumianze and others}, 107 | journal={arXiv preprint arXiv:2412.19723}, 108 | year={2024} 109 | } 110 | ``` 111 | -------------------------------------------------------------------------------- /collection/README.md: -------------------------------------------------------------------------------- 1 | # README 2 | 3 | In addition to the collected and synthesized data, we provide the scripts we used to collect interaction data from the environment for Reverse Task Synthesis, so you can extend OS-Genesis to more scenarios or synthesize more data as desired. 4 | 5 | # Mobile 6 | 7 | ## Random walk in the AndroidWorld Environment 8 | First, clone the [AndroidWorld](https://github.com/google-research/android_world) repository, then place `random_walk_aw.py` and `run_random_walk_aw.py` in its directory. 9 | 10 | `random_walk_aw.py` implements our random-walk logic for exploring the environment and obtaining `` triples. You can run `python run_random_walk_aw.py` to collect interaction triples at scale. 11 | ## Reverse Task Synthesis 12 | 13 | Work in progress. 14 | 15 | ## Trajectory Construction 16 | 17 | 1. Install the AndroidWorld Environment as described in: https://github.com/google-research/android_world 18 | 2. Move the scripts to the AndroidWorld directory: ``android_env/android_world`` 19 | 3.
Run the following command to collect the data: 20 | ```bash 21 | python mobile_runner.py 22 | ``` 23 | 24 | # Desktop 25 | 26 | ## Random walk in the WebArena Environment 27 | First, configure [WebArena](https://github.com/web-arena-x/webarena) and open the specified ports to access the website, then place `random_walk_web.py` in its directory. 28 | 29 | `random_walk_web.py` implements our random-walk logic for exploring the environment and obtaining triples. You can run `python random_walk_web.py` to collect interaction triples at scale. 30 | 31 | # Reward Model 32 | We provide an example of the Reward Model we use in `genesis_rm.py`. For more information, please refer to the original paper. -------------------------------------------------------------------------------- /collection/genesis_rm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import requests 4 | import time 5 | from PIL import Image 6 | import base64 7 | import io 8 | import json 9 | import re 10 | from tqdm import tqdm 11 | import numpy as np 12 | 13 | 14 | def encode_image(image_content): 15 | return base64.b64encode(image_content).decode('utf-8') 16 | 17 | 18 | def convert_image_to_base64(image_path): 19 | # Open the image file 20 | with open(image_path, 'rb') as f: 21 | # Read the raw image bytes 22 | image_bytes = f.read() 23 | # Encode the image bytes to base64 24 | encoded_image = encode_image(image_bytes) 25 | return encoded_image 26 | 27 | 28 | def call_llm(model_name, payload): 29 | headers = { 30 | "Content-Type": "application/json", 31 | "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}" 32 | } 33 | print("Generating content with GPT model: {}".format(model_name)) 34 | response = requests.post( 35 | "https://api.openai.com/v1/chat/completions", 36 | headers=headers, 37 | json=payload 38 | ) 39 | if response.status_code != 200: 40 | if response.json().get('error', {}).get('code') == "context_length_exceeded": 41 |
print("Context length exceeded. Retrying with a smaller context.") 42 | payload["messages"] = [payload["messages"][0]] + payload["messages"][-1:] 43 | retry_response = requests.post( 44 | "https://api.openai.com/v1/chat/completions", 45 | headers=headers, 46 | json=payload 47 | ) 48 | if retry_response.status_code != 200: 49 | print( 50 | "Failed to call LLM even after shortening the history: " + retry_response.text) 51 | return "" 52 | return retry_response.json()['choices'][0]['message']['content']  # Retry succeeded 53 | print("Failed to call LLM: " + response.text) 54 | time.sleep(2) 55 | return "" 56 | else: 57 | return response.json()['choices'][0]['message']['content'] 58 | 59 | 60 | gpt_annot_traj = json.load(open('/Users/cckevin/Desktop/gpt_annot_traj_v2.json', 'r')) 61 | imgs_dir = '/Users/cckevin/Desktop/gpt_annot_traj_v2/gpt_screenshots_v2' 62 | 63 | system_prompt = """ 64 | You are an expert in evaluating Android GUI agent task trajectories. Your task is to assess the quality and effectiveness of task trajectories for GUI manipulation tasks. 65 | 66 | A trajectory consists of the following components: 67 | 1. High-level Instruction: Describes the user's intended task (e.g., "Create a new travel itinerary document in a folder"). 68 | 2. Action History: Includes two key parts: 69 | - Low-level Actions & Summaries: A sequence of actions, where each step includes: 70 | - The executed action. 71 | - A summary of the action, indicating the effect after the action is executed. 72 | - GUI Screenshots: Screenshots captured when the last three actions are executed: the third-to-last, second-to-last, and final actions (if there are at least three actions; otherwise, include all actions). 73 | 74 | When evaluating a trajectory, consider these key aspects: 75 | 76 | ### Evaluation Criteria: 77 | 1. Trajectory Coherence: 78 | - Do the low-level steps and corresponding actions follow a logical sequence toward the goal? 79 | - Are the actions clearly described and specific? 80 | - Are there redundant or unnecessary actions? 81 | 82 | 2.
Task Completion: 83 | - Does the trajectory successfully achieve the instructed task? 84 | - Are all necessary interactions completed? 85 | - Are error cases handled appropriately? 86 | 87 | ### Scoring Guidelines: 88 | Rate the trajectory on a scale of 1 to 5 based on the evaluation criteria: 89 | 90 | - 5: The task is perfectly completed, successfully executing multiple actions to achieve the goal. The sequence is logically clear with no noticeable redundancies. 91 | - 4: The task is mostly completed, successfully executing multiple actions. However, due to challenges or ambiguities in the instructions, the completion is not perfect, or there are inefficiencies in the process. 92 | - 3: The task is partially completed, with some successful actions executed. However, due to task or environmental constraints, the goal is not fully achieved, or the sequence ends in a loop or error. 93 | - 2: Only a few actions are executed. Although there is an attempt to complete the task, the trajectory deviates from the goal early on or demonstrates significant inefficiencies in execution and logic. 94 | - 1: The task fails completely, with no meaningful actions executed at the start. The sequence either falls into an immediate deadlock, a repetitive loop, or demonstrates no value in completing the task. 95 | 96 | Note: If the task is relatively complex, but the trajectory demonstrates valuable attempts, even if the task is not fully completed, consider adjusting the score upward. However, if the task is complex but the trajectory fails to perform actions that contribute meaningfully to task completion, no extra points should be awarded. 
97 | 98 | ### Response Format: 99 | Format your response into two lines as shown below: 100 | Reason: 101 | Score: 102 | """ 103 | 104 | for i, item in enumerate(gpt_annot_traj[:]): 105 | 106 | if "reward" in item: 107 | print("processed") 108 | continue 109 | 110 | instruction = item["instruction"] 111 | 112 | action_history = [] 113 | for j, action in enumerate(item["steps"]): 114 | if "summary" in action: 115 | summary = action["summary"] 116 | else: 117 | summary = action["reason"] 118 | summary = f"Step {j+1}: {summary}" 119 | action_history.append(summary) 120 | action_history_text = '\n'.join(action_history) 121 | 122 | action_screenshots = [] 123 | for action in item["steps"][-3:]: 124 | screenshot_path = os.path.join(imgs_dir, action["screen_before"]) 125 | screenshot = convert_image_to_base64(screenshot_path) 126 | action_screenshots.append(screenshot) 127 | 128 | traj_prompt = f"Instruction: {instruction}\nAction History:\n{action_history_text}\nThe last three screenshots are provided."
129 | 130 | messages = [] 131 | 132 | messages.append({ 133 | "role": "system", 134 | "content": [ 135 | { 136 | "type": "text", 137 | "text": system_prompt 138 | }, 139 | ] 140 | }) 141 | 142 | # User message: the last (up to three) screenshots followed by the trajectory text 143 | action_text_image = [] 144 | for img in action_screenshots: 145 | action_text_image.append( 146 | { 147 | "type": "image_url", 148 | "image_url": { 149 | "url": f"data:image/png;base64,{img}", 150 | "detail": "high" 151 | } 152 | } 153 | ) 154 | 155 | action_text_image.append( 156 | { 157 | "type": "text", 158 | "text": traj_prompt 159 | } 160 | ) 161 | 162 | messages.append({ 163 | "role": "user", 164 | "content": action_text_image 165 | }) 166 | 167 | print(traj_prompt) 168 | 169 | model_name = "gpt-4o-2024-08-06" 170 | try_num = 0 171 | answer = None 172 | while try_num < 5: 173 | try_num += 1 174 | try: 175 | response = call_llm(model_name, { 176 | "model": model_name, 177 | "messages": messages, 178 | "max_tokens": 1500, 179 | "top_p": 0.9, 180 | "temperature": 0.5 181 | }) 182 | except Exception: 183 | print("error call") 184 | time.sleep(1.0) 185 | continue 186 | try: 187 | print(response) 188 | reason_match = re.search(r"Reason:\s*(.+?)\s*Score:", response, re.DOTALL) 189 | score_match = re.search(r"Score:\s*(\d+)", response) 190 | reason = reason_match.group(1).strip() if reason_match else None 191 | score = int(score_match.group(1)) if score_match else None 192 | 193 | if reason and score and 1 <= score <= 5: 194 | item["reward_reason"] = reason 195 | item["reward"] = score 196 | break # Successfully parsed, exit loop 197 | else: 198 | print("Invalid response format or score out of range, retrying...") 199 | time.sleep(1.0) 200 | 201 | except (TypeError, ValueError): 202 | # The response could not be parsed (e.g., it is not a string); retry 203 | print("Invalid response received, retrying...") 204 | time.sleep(1.0) 205 | 206 | num_processed = len([item for item in gpt_annot_traj if ("reward" in item)]) 207 | print("Num of total: {} Num of success:
{}".format(len(gpt_annot_traj), num_processed)) 208 | if i % 20 == 0: 209 | json.dump(gpt_annot_traj, open('/Users/cckevin/Desktop/gpt_annot_traj_v2_reward.json', 'w')) 210 | 211 | json.dump(gpt_annot_traj, open('/Users/cckevin/Desktop/gpt_annot_traj_v2_reward.json', 'w')) 212 | print("Save") -------------------------------------------------------------------------------- /collection/mobile_runner.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The android_world Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Build trajectories of OS-Genesis. 
16 | 17 | """ 18 | 19 | from collections.abc import Sequence 20 | import os 21 | import random 22 | from typing import Type 23 | import json 24 | import time 25 | import uuid 26 | from PIL import Image 27 | import numpy as np 28 | import re 29 | import pickle 30 | 31 | from absl import app 32 | from absl import flags 33 | from absl import logging 34 | from android_world import registry 35 | from android_world.agents import infer 36 | from android_world.agents import t3a, m3a 37 | # from android_world.agents import m3a_origin 38 | from android_world.agents import m3a_utils 39 | from android_world.agents.t3a import _generate_ui_elements_description_list_full 40 | from android_world.env import env_launcher, json_action 41 | from android_world.task_evals import task_eval 42 | 43 | logging.set_verbosity(logging.WARNING) 44 | 45 | os.environ['GRPC_VERBOSITY'] = 'ERROR' # Only show errors 46 | os.environ['GRPC_TRACE'] = 'none' # Disable tracing 47 | 48 | 49 | def _find_adb_directory() -> str: 50 | """Returns the directory where adb is located.""" 51 | potential_paths = [ 52 | os.path.expanduser('~/Library/Android/sdk/platform-tools/adb'), 53 | os.path.expanduser('~/Android/Sdk/platform-tools/adb'), 54 | ] 55 | for path in potential_paths: 56 | if os.path.isfile(path): 57 | return path 58 | raise EnvironmentError( 59 | 'adb not found in the common Android SDK paths. Please install Android' 60 | " SDK and ensure adb is in one of the expected directories. If it's" 61 | ' already installed, point to the installed location.' 62 | ) 63 | 64 | 65 | _ADB_PATH = flags.DEFINE_string( 66 | 'adb_path', 67 | _find_adb_directory(), 68 | 'Path to adb. Set if not installed through SDK.', 69 | ) 70 | _EMULATOR_SETUP = flags.DEFINE_boolean( 71 | 'perform_emulator_setup', 72 | False, 73 | 'Whether to perform emulator setup. This must be done once and only once' 74 | ' before running Android World. 
After an emulator is setup, this flag' 75 | ' should always be False.', 76 | ) 77 | _DEVICE_CONSOLE_PORT = flags.DEFINE_integer( 78 | 'console_port', 79 | 5554, 80 | 'The console port of the running Android device. This can usually be' 81 | ' retrieved by looking at the output of `adb devices`. In general, the' 82 | ' first connected device is port 5554, the second is 5556, and' 83 | ' so on.', 84 | ) 85 | 86 | _TASK = flags.DEFINE_string( 87 | 'task', 88 | None, 89 | 'A specific task to run.', 90 | ) 91 | 92 | 93 | def save_image(image, directory): 94 | """Save an image to a file and return the file name.""" 95 | unique_id = str(uuid.uuid4()) 96 | image_name = f"{unique_id}.png" 97 | image_path = os.path.join(directory, image_name) 98 | if isinstance(image, np.ndarray): 99 | image = Image.fromarray(np.uint8(image)) 100 | image.save(image_path) 101 | return image_name 102 | 103 | 104 | def get_state(env_state, logical_screen_size, ui_elements): 105 | element_list_text = _generate_ui_elements_description_list_full( 106 | ui_elements, 107 | logical_screen_size, 108 | ) 109 | screen = env_state.pixels.copy() 110 | screen = Image.fromarray(screen.astype('uint8')) 111 | return screen, element_list_text 112 | 113 | 114 | def element_to_identifier(element): 115 | """Converts an element to a JSON-serializable identifier.""" 116 | bbox = getattr(element, 'bbox_pixels', None) 117 | bbox_dict = {'x_min': bbox.x_min, 'x_max': bbox.x_max, 'y_min': bbox.y_min, 'y_max': bbox.y_max} if bbox else None 118 | identifier = { 119 | 'resource_id': getattr(element, 'resource_id', None), 120 | 'text': getattr(element, 'text', None), 121 | 'content_description': getattr(element, 'content_description', None), 122 | 'class_name': getattr(element, 'class_name', None), 123 | 'bbox_pixels': bbox_dict, 124 | 'hint_text': getattr(element, 'hint_text', None), 125 | 'is_checkable': getattr(element, 'is_checkable', None), 126 | 'is_enabled': getattr(element, 'is_enabled', None), 127 | 'is_visible':
getattr(element, 'is_visible', None), 128 | 'is_clickable': getattr(element, 'is_clickable', None), 129 | 'is_editable': getattr(element, 'is_editable', None), 130 | 'is_focused': getattr(element, 'is_focused', None), 131 | 'is_focusable': getattr(element, 'is_focusable', None), 132 | 'is_long_clickable': getattr(element, 'is_long_clickable', None), 133 | 'is_scrollable': getattr(element, 'is_scrollable', None), 134 | 'is_selected': getattr(element, 'is_selected', None), 135 | 'package_name': getattr(element, 'package_name', None), 136 | 'resource_name': getattr(element, 'resource_name', None), 137 | } 138 | return identifier 139 | 140 | 141 | def _main() -> None: 142 | 143 | instruction_path = './aw_instructions.json' 144 | aw_instrcutions = json.load(open(instruction_path, 'r')) 145 | 146 | SCREEN_GPT_DIR = './screenshots_gpt_v2' 147 | if not os.path.exists(SCREEN_GPT_DIR): 148 | os.mkdir(SCREEN_GPT_DIR) 149 | 150 | """Initialize Env.""" 151 | env = env_launcher.load_and_setup_env( 152 | console_port=_DEVICE_CONSOLE_PORT.value, 153 | emulator_setup=_EMULATOR_SETUP.value, 154 | adb_path=_ADB_PATH.value, 155 | ) 156 | env_launcher.verify_api_level(env) 157 | 158 | for task_item in aw_instrcutions: 159 | if "task_fail" in task_item: 160 | del task_item["task_fail"] 161 | 162 | for task_item in aw_instrcutions: 163 | 164 | total_tasks = len(aw_instrcutions) 165 | annotated_tasks = len([item for item in aw_instrcutions if "gpt_traj" in item]) 166 | print(f"Total task: {total_tasks} --- Annotated task: {annotated_tasks}") 167 | failed_tasks = len([item for item in aw_instrcutions if "task_fail" in item]) 168 | print(f"Total task: {total_tasks} --- Failed task: {failed_tasks}") 169 | if "gpt_traj" in task_item or "task_fail" in task_item: 170 | continue 171 | 172 | try: 173 | env.reset(go_home=True) 174 | task_registry = registry.TaskRegistry() 175 | aw_registry = task_registry.get_registry(task_registry.ANDROID_WORLD_FAMILY) 176 | 177 | # Initialize based on the task 
sampled and open the corresponding app. 178 | app_name = task_item["app_name"] 179 | task_name = task_item["task_name"] if "task_name" in task_item else task_item["task_task"] 180 | instrcution = task_item["refine_task"] 181 | 182 | if task_name and task_name != "default": 183 | if task_name not in aw_registry: 184 | raise ValueError('Task {} not found in registry.'.format(task_name)) 185 | task_type: Type[task_eval.TaskEval] = aw_registry[task_name] 186 | else: 187 | task_type: Type[task_eval.TaskEval] = random.choice( 188 | list(aw_registry.values()) 189 | ) 190 | print("unknown task name") 191 | input()  # Pause for manual inspection before running a randomly chosen task 192 | print(task_type) 193 | 194 | # load params 195 | task_id = task_item["task_id"] 196 | params_dir = './params_new' 197 | params_path = os.path.join(params_dir, task_id + "_params.pkl") 198 | with open(params_path, 'rb') as f: 199 | params = pickle.load(f) 200 | print(params) 201 | #params = task_type.generate_random_params() 202 | 203 | task = task_type(params) 204 | 205 | task.initialize_task(env) 206 | # agent = t3a.T3A(env, infer.Gpt4Wrapper('gpt-4-turbo-2024-04-09')) 207 | # agent = m3a_origin.M3A(env, infer.Gpt4Wrapper('gpt-4o-2024-08-06')) 208 | agent = m3a.M3A(env, infer.Gpt4Wrapper('gpt-4o-2024-08-06')) 209 | 210 | 211 | # Open the corresponding app after initializing the task. 212 | open_app = True 213 | if open_app: 214 | open_app_action = {"action_type": "open_app", "app_name": app_name} 215 | converted_action = json_action.JSONAction(**open_app_action) 216 | agent.env.execute_action(converted_action) 217 | time.sleep(3.0) 218 | 219 | print('Goal: ' + str(instrcution)) 220 | is_done = False 221 | gpt_traj = [] 222 | for i in range(15): 223 | 224 | # Obtain the state of the environment before execution to synchronize with our training setup.
225 | env_state = agent.get_post_transition_state() 226 | logical_screen_size = agent.env.logical_screen_size 227 | ui_elements = env_state.ui_elements 228 | screen, element_list_text = get_state(env_state, logical_screen_size, ui_elements) 229 | screen_before = save_image(screen, SCREEN_GPT_DIR) 230 | # Note: Here, following the implementation of M3A, a state interface representation consistent with the agent’s observation is saved, ensuring that the actions generated by the model can locate the corresponding elements. 231 | ui_elements_before_identifiers = [element_to_identifier(elem) for elem in ui_elements if m3a_utils.validate_ui_element(elem, logical_screen_size)] 232 | 233 | # take one step 234 | response = agent.step(instrcution) 235 | 236 | # Extract the screen, prompt, and generated action from the response. 237 | screen_before_som = save_image(response.data["before_screenshot_with_som"], SCREEN_GPT_DIR) 238 | action_prompt = response.data["action_prompt"] 239 | action_output = response.data["action_output"] 240 | action_reason = response.data["action_reason"] 241 | summary_prompt = response.data["summary_prompt"] 242 | summary = response.data["summary"] 243 | 244 | match = re.search(r'Action:\s*(\{.*\})', action_output) 245 | action_json = match.group(1) if match else "action_not_match" 246 | # Exit if the same action is performed three times consecutively. 
247 | if app_name != "Simple Calendar Pro": 248 | if i >= 2 and (action_json == gpt_traj[i-1]["action_json"] == gpt_traj[i-2]["action_json"]): 249 | break 250 | 251 | step_data = { 252 | "screen_before": screen_before, 253 | "screen_before_som": screen_before_som, 254 | "ui_elements_before_text": element_list_text, 255 | "ui_elements_before": ui_elements_before_identifiers, 256 | "action_prompt": action_prompt, 257 | "action_output": action_output, 258 | "action_json": action_json, 259 | "action_reason": action_reason, 260 | "summary_prompt": summary_prompt, 261 | "summary": summary 262 | } 263 | gpt_traj.append(step_data) 264 | 265 | if response.done: 266 | is_done = True 267 | break 268 | 269 | """ 270 | agent_successful = is_done and task.is_successful(env) == 1 271 | print( 272 | f'{"Task Successful ✅" if agent_successful else "Task Failed ❌"};' 273 | f' {task.goal}' 274 | ) 275 | """ 276 | 277 | # env.close() 278 | 279 | # Update the annotations to the original aw_instructions file at the end of each trajectory. 280 | task_item["gpt_traj"] = gpt_traj 281 | json.dump(aw_instrcutions, open(instruction_path, 'w')) 282 | 283 | except Exception as e: 284 | print(f"An error occurred: {e}") 285 | task_item["task_fail"] = "fail" 286 | json.dump(aw_instrcutions, open(instruction_path, 'w')) 287 | time.sleep(10) 288 | break # Exit and restart after an error occurs. 289 | 290 | 291 | def main(argv: Sequence[str]) -> None: 292 | del argv 293 | _main() 294 | 295 | 296 | if __name__ == '__main__': 297 | app.run(main) -------------------------------------------------------------------------------- /collection/random_walk_aw.py: -------------------------------------------------------------------------------- 1 | # Copyright 2024 The android_world Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Random walk in AndroidWorld environment, collecting triples. 17 | Maintain an unclickable element pool: during random walk, if the action does not change the screen, add the element to the pool. 18 | """ 19 | 20 | from collections.abc import Sequence 21 | import os 22 | import requests 23 | import random 24 | from typing import Type 25 | from PIL import Image 26 | import time 27 | import json 28 | import uuid 29 | import pickle 30 | 31 | from absl import app 32 | from absl import flags 33 | from absl import logging 34 | from android_world import registry 35 | from android_world.agents import infer 36 | from android_world.agents import t3a 37 | from android_world.agents import agent_utils 38 | from android_world.agents.t3a import _generate_ui_elements_description_list_full 39 | from android_world.env import env_launcher, json_action 40 | from android_world.env import representation_utils 41 | from android_world.task_evals import task_eval 42 | 43 | 44 | logging.set_verbosity(logging.WARNING) 45 | 46 | os.environ['GRPC_VERBOSITY'] = 'ERROR' # Only show errors 47 | os.environ['GRPC_TRACE'] = 'none' # Disable tracing 48 | 49 | 50 | def _find_adb_directory() -> str: 51 | """Returns the directory where adb is located.""" 52 | potential_paths = [ 53 | os.path.expanduser('~/Library/Android/sdk/platform-tools/adb'), 54 | os.path.expanduser('~/Android/Sdk/platform-tools/adb'), 55 | ] 56 | for path in potential_paths: 57 | if os.path.isfile(path): 58 | return path 59 | raise EnvironmentError( 60 | 'adb not found in the 
common Android SDK paths. Please install Android' 61 | " SDK and ensure adb is in one of the expected directories. If it's" 62 | ' already installed, point to the installed location.' 63 | ) 64 | 65 | 66 | _ADB_PATH = flags.DEFINE_string( 67 | 'adb_path', 68 | _find_adb_directory(), 69 | 'Path to adb. Set if not installed through SDK.', 70 | ) 71 | _EMULATOR_SETUP = flags.DEFINE_boolean( 72 | 'perform_emulator_setup', 73 | False, 74 | 'Whether to perform emulator setup. This must be done once and only once' 75 | ' before running Android World. After an emulator is set up, this flag' 76 | ' should always be False.', 77 | ) 78 | _DEVICE_CONSOLE_PORT = flags.DEFINE_integer( 79 | 'console_port', 80 | 5554, 81 | 'The console port of the running Android device. This can usually be' 82 | ' retrieved by looking at the output of `adb devices`. In general, the' 83 | ' first connected device is port 5554, the second is 5556, and' 84 | ' so on.', 85 | ) 86 | 87 | _TASK = flags.DEFINE_string( 88 | 'task', 89 | None, 90 | 'A specific task to run.', 91 | ) 92 | 93 | 94 | def generate_text_input(element_list_text, interactive_element, max_retries=3): 95 | # A valid response must be a single line of text 96 | def is_valid_response(response_text): 97 | if "\n" in response_text: 98 | return False 99 | # Optionally, add more validation rules here 100 | return True 101 | 102 | def element_to_text(element): 103 | """Convert the element into a text description.""" 104 | description = "" 105 | if getattr(element, 'resource_id', None): 106 | description += f"Resource ID: {element.resource_id}\n" 107 | if getattr(element, 'text', None): 108 | description += f"Text: {element.text}\n" 109 | if getattr(element, 'content_description', None): 110 | description += f"Content Description: {element.content_description}\n" 111 | if getattr(element, 'class_name', None): 112 | description += f"Class Name: {element.class_name}\n" 113 | if getattr(element, 'hint_text', None): 114 | description += f"Hint
Text: {element.hint_text}\n" 115 | return description or "No additional information." 116 | 117 | prompt = f""" 118 | You are an intelligent input assistant. The current UI elements are as follows: 119 | {element_list_text} 120 | The selected editable element information is as follows: 121 | {element_to_text(interactive_element)} 122 | Based on the above information, please randomly generate a text content that a user might input into this element. The text should be contextually appropriate. For example, if it's a search box, you might generate a search query; if it's a username input field, you might generate a username. 123 | **Please return only the generated text without any additional explanation. Do not include any prefixes or suffixes.** 124 | 125 | If you understand, please provide the text input now. 126 | """ 127 | 128 | headers = { 129 | "Content-Type": "application/json", 130 | "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}" 131 | } 132 | 133 | payload = { 134 | "model": "gpt-4o-mini-2024-07-18", 135 | "messages": [ 136 | {"role": "user", "content": prompt} 137 | ], 138 | "max_tokens": 20, 139 | "temperature": 0.7, 140 | "n": 1, 141 | } 142 | 143 | retries = 0 144 | while retries < max_retries: 145 | try: 146 | # Send POST request to OpenAI API 147 | response = requests.post( 148 | "https://api.openai.com/v1/chat/completions", 149 | headers=headers, 150 | json=payload 151 | ) 152 | response.raise_for_status() 153 | result = response.json() 154 | text_input = result['choices'][0]['message']['content'].strip() 155 | 156 | # Validate the response 157 | if is_valid_response(text_input): 158 | return text_input 159 | else: 160 | print(f"Invalid response format: '{text_input}'. 
Retrying...") 161 | retries += 1 162 | time.sleep(1) # Wait a bit before retrying 163 | except Exception as e: 164 | print(f"Error generating text input: {e}") 165 | retries += 1 166 | time.sleep(1) 167 | 168 | # If all retries fail, return a default text 169 | print("Failed to get valid text input after retries. Returning default text.") 170 | return "Test Input" 171 | 172 | 173 | def get_state(env_state, logical_screen_size, ui_elements): 174 | element_list_text = _generate_ui_elements_description_list_full( 175 | ui_elements, 176 | logical_screen_size, 177 | ) 178 | screen = env_state.pixels.copy() 179 | screen = Image.fromarray(screen.astype('uint8')) 180 | return screen, element_list_text 181 | 182 | 183 | def element_to_identifier(element): 184 | """Converts an element to a JSON-serializable identifier.""" 185 | bbox = getattr(element, 'bbox_pixels', None) 186 | bbox_dict = {'x_min': bbox.x_min, 'x_max': bbox.x_max, 'y_min': bbox.y_min, 'y_max': bbox.y_max} if bbox else None 187 | identifier = { 188 | 'resource_id': getattr(element, 'resource_id', None), 189 | 'text': getattr(element, 'text', None), 190 | 'content_description': getattr(element, 'content_description', None), 191 | 'class_name': getattr(element, 'class_name', None), 192 | 'bbox_pixels': bbox_dict, 193 | 'hint_text': getattr(element, 'hint_text', None), 194 | 'is_checkable': getattr(element, 'is_checkable', None), 195 | 'is_enabled': getattr(element, 'is_enabled', None), 196 | 'is_visible': getattr(element, 'is_visible', None), 197 | 'is_clickable': getattr(element, 'is_clickable', None), 198 | 'is_editable': getattr(element, 'is_editable', None), 199 | 'is_focused': getattr(element, 'is_focused', None), 200 | 'is_focusable': getattr(element, 'is_focusable', None), 201 | 'is_long_clickable': getattr(element, 'is_long_clickable', None), 202 | 'is_scrollable': getattr(element, 'is_scrollable', None), 203 | 'is_selected': getattr(element, 'is_selected', None), 204 | 'package_name': getattr(element, 
'package_name', None), 205 | 'resource_name': getattr(element, 'resource_name', None), 206 | } 207 | return identifier 208 | 209 | 210 | def filter_interactive_elements(elements, screen_width_height_px, unc_elem_pool): 211 | interactive_elements = [] 212 | screen_width, screen_height = screen_width_height_px 213 | 214 | # Excluded packages; adding the package names of other keyboard apps would improve filtering 215 | # Keyboard and system UI (e.g., the battery indicator, which is not actually clickable) 216 | excluded_packages = {'com.google.android.inputmethod.latin', 'com.android.systemui'} 217 | 218 | for index, element in enumerate(elements): 219 | 220 | if element.package_name in excluded_packages: 221 | continue 222 | 223 | if element.is_enabled and element.is_visible: 224 | if element.bbox_pixels: 225 | x_min = element.bbox_pixels.x_min 226 | x_max = element.bbox_pixels.x_max 227 | y_min = element.bbox_pixels.y_min 228 | y_max = element.bbox_pixels.y_max 229 | 230 | # Keep only non-empty bounding boxes that lie at least partially on screen 231 | if not (x_min >= x_max or x_min >= screen_width or x_max <= 0 or 232 | y_min >= y_max or y_min >= screen_height or y_max <= 0): 233 | 234 | # Skip elements already recorded in the unclickable element pool 235 | element_identifier = element_to_identifier(element) 236 | element_identifier_str = json.dumps(element_identifier, sort_keys=True) 237 | if element_identifier_str not in unc_elem_pool: 238 | interactive_elements.append([index, element]) 239 | 240 | return interactive_elements 241 | 242 | 243 | def sample_action_element(element, element_list_text): 244 | index, interactive_element = element 245 | actions = [] 246 | # If element is editable, prioritize text input 247 | if interactive_element.is_editable: 248 | # Use GPT to generate appropriate text input based on current UI 249 | text_input = generate_text_input(element_list_text, interactive_element) 250 | return {"action_type": "input_text", "text": text_input, "index": index} 251 | # Assume all
elements are clickable, add click action to action list 252 | actions.append({"action_type": "click", "index": index}) 253 | # If element can be long pressed, add long press action to action list 254 | if interactive_element.is_long_clickable: 255 | actions.append({"action_type": "long_press", "index": index}) 256 | # If multiple actions are available, choose based on given probabilities 257 | if actions: 258 | if len(actions) == 1: 259 | return actions[0] 260 | else: 261 | # Choose click or long press with 90% and 10% probability (assuming both actions are available) 262 | return random.choices(actions, weights=[9, 1], k=1)[0] 263 | 264 | 265 | def get_task_app(task_name): 266 | # Choose which app to open based on task name 267 | task_2_app = { 268 | "AudioRecorder": "Audio Recorder", 269 | "Browser": "Files", 270 | "SimpleCalendar": "Simple Calendar Pro", 271 | "Camera": "Camera", 272 | "Clock": "Clock", 273 | "Contacts": "Contacts", 274 | "Expense": "Pro Expense", 275 | "ExpenseAddMultipleFromMarkor": "Pro Expense", 276 | "Files": "Files", 277 | "Markor": "Markor", 278 | "Osm": "OsmAnd", 279 | "Recipe": "Broccoli - Recipe App", 280 | "RecipeAddMultipleRecipesFromMarkor": "Broccoli - Recipe App", 281 | "RecipeAddMultipleRecipesFromMarkor2": "Broccoli - Recipe App", 282 | "Retro": "Retro Music", 283 | "SimpleDraw": "Simple Draw Pro", 284 | "SaveCopyOfReceipt": "Simple Gallery Pro", 285 | "SimpleSms": "Simple SMS Messenger", 286 | "System": "Settings", 287 | "Turn": "Settings", 288 | "Vlc": "VLC", 289 | "Tasks": "Tasks", 290 | "Notes": "Joplin", 291 | "Sports": "OpenTracks", 292 | } 293 | if task_name in task_2_app: 294 | return task_2_app[task_name] 295 | else: 296 | for task, app in task_2_app.items(): 297 | if task in task_name: 298 | return app 299 | return "Home" 300 | 301 | 302 | def has_screen_changed(before_elements, after_elements): 303 | # Convert UI elements to identifiers, excluding system UI elements 304 | before_set = set( 305 | 
json.dumps(element_to_identifier(elem), sort_keys=True) 306 | for elem in before_elements 307 | if elem.package_name != 'com.android.systemui' 308 | ) 309 | after_set = set( 310 | json.dumps(element_to_identifier(elem), sort_keys=True) 311 | for elem in after_elements 312 | if elem.package_name != 'com.android.systemui' 313 | ) 314 | return before_set != after_set 315 | 316 | 317 | # Random walk pipeline: 318 | # 1. Initialize environment 319 | # 2. Randomly select an app to sample based on task config 320 | # 3. Load unclickable element pool 321 | # 4. Execute num_step steps of random walk within an app 322 | # a. For each step, get a set of currently executable actions based on current state 323 | # b. Randomly select an action 324 | # c. Execute action, if screen doesn't change it means it's an unclickable element, add to pool; otherwise record triples 325 | 326 | def _main() -> None: 327 | """Random walk in AndroidWorld for sampling""" 328 | SCREEN_DIR = './screenshots_camera' 329 | if not os.path.exists(SCREEN_DIR): 330 | os.mkdir(SCREEN_DIR) 331 | 332 | TRAJECTORY_DIR = './trajectories_camera' 333 | if not os.path.exists(TRAJECTORY_DIR): 334 | os.mkdir(TRAJECTORY_DIR) 335 | 336 | PARAMS_DIR = './params_camera' 337 | if not os.path.exists(PARAMS_DIR): 338 | os.mkdir(PARAMS_DIR) 339 | 340 | # Launch Android emulator (ADB) and return to home screen 341 | env = env_launcher.load_and_setup_env( 342 | console_port=_DEVICE_CONSOLE_PORT.value, 343 | emulator_setup=_EMULATOR_SETUP.value, 344 | adb_path=_ADB_PATH.value, 345 | ) 346 | env_launcher.verify_api_level(env) 347 | env.reset(go_home=True) 348 | 349 | # Task collection 350 | task_registry = registry.TaskRegistry() 351 | aw_registry = task_registry.get_registry(task_registry.ANDROID_WORLD_FAMILY) # All AW tasks, total 116 352 | print(aw_registry) 353 | print("There are total {} tasks.".format(len(aw_registry))) 354 | 355 | aw_list = list(aw_registry.items()) 356 | random.shuffle(aw_list) 357 | aw_registry = 
dict(aw_list) 358 | 359 | # Record unclickable elements for each app 360 | if not os.path.exists('./unclickable_elem_pool'): 361 | os.mkdir('./unclickable_elem_pool') 362 | 363 | for task_id in range(len(aw_registry)): 364 | 365 | task_uuid = str(uuid.uuid4()) 366 | 367 | # Select a task 368 | task_name, task_type = list(aw_registry.items())[task_id] 369 | 370 | # Open different apps as starting state based on task 371 | app_name = get_task_app(task_name) 372 | 373 | # Initialize and return to home screen, this initialization initializes the corresponding app based on task's app snapshot 374 | params = task_type.generate_random_params() 375 | 376 | # Record random params 377 | with open(os.path.join(PARAMS_DIR, task_uuid+"_params.pkl"), "wb") as f: 378 | pickle.dump(params, f) 379 | 380 | task = task_type(params) 381 | task.initialize_task(env) 382 | env.reset(go_home=True) 383 | 384 | agent = t3a.T3A(env, infer.Gpt4Wrapper('gpt-4o-2024-08-06')) # Only used to adapt environment interface, actually no need for agent 385 | 386 | print("Open App: {}".format(app_name)) 387 | print("Goal: {}".format(task.goal)) 388 | if app_name != "Home": 389 | open_app_action = {"action_type": "open_app", "app_name": app_name} 390 | converted_action = json_action.JSONAction(**open_app_action) 391 | agent.env.execute_action(converted_action) 392 | time.sleep(3.0) 393 | 394 | # Load unclick_elem_pool 395 | unc_elem_pool_path = os.path.join('./unclickable_elem_pool', str(app_name)+".json") 396 | if not os.path.exists(unc_elem_pool_path): 397 | unc_elem_pool = set() 398 | else: 399 | with open(unc_elem_pool_path, 'r') as f: 400 | unc_elem_pool_list = json.load(f) 401 | unc_elem_pool = set(unc_elem_pool_list) 402 | 403 | # Random walk sampling for num_step steps 404 | trajectory = [] 405 | num_step = 10 406 | for i in range(num_step): 407 | # Get current state 408 | env_state = agent.get_post_transition_state() 409 | logical_screen_size = agent.env.logical_screen_size 410 | ui_elements = 
env_state.ui_elements 411 | screen, element_list_text = get_state(env_state, logical_screen_size, ui_elements) 412 | print(element_list_text) 413 | 414 | # Get all executable actions on current screen: 1. Different sampling ratios & logic for different actions; 2. Keyboard elements should not be sampled, filter these elements 415 | # TODO: (questionable) Some samples seem to occasionally not match up? 416 | # Interactive elements on current screen 417 | interactive_elements = filter_interactive_elements(ui_elements, logical_screen_size, unc_elem_pool) 418 | # Other available actions on current screen 419 | addition_actions = [{"action_type": "scroll", "direction": "down"}, {"action_type": "scroll", "direction": "up"}, 420 | {"action_type": "navigate_back"}] 421 | 422 | # Sample an action/element 423 | weight_interactive = 10 # Weight for interactive elements 424 | weight_addition = 1 # Weight for other actions 425 | action_element = random.choice(interactive_elements*weight_interactive+addition_actions*weight_addition) 426 | if "action_type" in action_element: # If it's a direct action 427 | action_sample = action_element 428 | else: # If it's an element, sample an action that can be executed on it based on element properties (e.g., is_clickable) 429 | action_sample = sample_action_element(action_element, element_list_text) 430 | print("Cand Elements: {}".format([index for (index, elem) in interactive_elements])) 431 | print(action_sample) 432 | if "index" in action_sample: 433 | print(ui_elements[action_sample['index']]) 434 | 435 | # Execute action 436 | converted_action = json_action.JSONAction(**action_sample) 437 | agent.env.execute_action(converted_action) # TODO: (questionable) Some actions seem to not be fully executed? 
Theoretically execute_action waits for action completion 438 | time.sleep(2.0) 439 | 440 | # Check if action caused screen change 441 | env_state_after = agent.get_post_transition_state() 442 | logical_screen_size = agent.env.logical_screen_size 443 | ui_elements_after = env_state_after.ui_elements 444 | if not has_screen_changed(ui_elements, ui_elements_after): 445 | print("The screen did not change") 446 | # If the screen didn't change and the action targeted an element, add that element to the unclickable element pool 447 | if "index" in action_sample: 448 | element = ui_elements[action_sample['index']] 449 | element_identifier = element_to_identifier(element) 450 | unc_elem_pool.add(json.dumps(element_identifier, sort_keys=True)) 451 | else: 452 | # Record this step (screen before, action, screen after) 453 | screen_before_uuid = str(uuid.uuid4()) 454 | screen_after_uuid = str(uuid.uuid4()) 455 | screen_before_filename = os.path.join(SCREEN_DIR, f'{screen_before_uuid}.png') 456 | screen.save(screen_before_filename) 457 | 458 | screen_after, element_list_text_after = get_state(env_state_after, logical_screen_size, ui_elements_after) 459 | screen_after_filename = os.path.join(SCREEN_DIR, f'{screen_after_uuid}.png') 460 | screen_after.save(screen_after_filename) 461 | 462 | ui_elements_before_identifiers = [element_to_identifier(elem) for elem in ui_elements] 463 | interactive_elements_before = [element_to_identifier(elem[1]) for elem in interactive_elements] 464 | 465 | ui_elements_after_identifiers = [element_to_identifier(elem) for elem in ui_elements_after] 466 | interactive_elements_after_full = filter_interactive_elements(ui_elements_after, logical_screen_size, unc_elem_pool) 467 | interactive_elements_after = [element_to_identifier(elem[1]) for elem in interactive_elements_after_full] 468 | 469 | if "index" in action_sample: 470 | action_element = element_to_identifier(ui_elements[action_sample['index']]) 471 | else: 472 | action_element = None 473 | 474 | step_data = { 475 | 'task_uuid': task_uuid, 476 |
'task': task_name, 477 | 'app': app_name, 478 | 'screen_before': screen_before_filename, 479 | 'element_list_text_before': element_list_text, 480 | 'ui_elements_before': ui_elements_before_identifiers, 481 | 'interactive_elements_before': interactive_elements_before, 482 | 'screen_after': screen_after_filename, 483 | 'element_list_text_after': element_list_text_after, 484 | 'ui_elements_after': ui_elements_after_identifiers, 485 | 'interactive_elements_after': interactive_elements_after, 486 | 'action': action_sample, 487 | 'action_element': action_element 488 | } 489 | 490 | trajectory.append(step_data) 491 | print(f"Recorded step {i + 1} for task {task_name}") 492 | 493 | # Save the updated unc_elem_pool 494 | with open(unc_elem_pool_path, 'w') as f: 495 | json.dump(list(unc_elem_pool), f, indent=2) 496 | 497 | # Save the trajectory to a JSON file 498 | trajectory_uuid = str(uuid.uuid4()) 499 | trajectory_filename = os.path.join(TRAJECTORY_DIR, f'{task_name}_{trajectory_uuid}.json') 500 | with open(trajectory_filename, 'w') as f: 501 | json.dump(trajectory, f, indent=2) 502 | print(f"Saved trajectory for task {task_name} to {trajectory_filename}") 503 | 504 | env.reset(go_home=True) 505 | 506 | env.close() 507 | 508 | 509 | def main(argv: Sequence[str]) -> None: 510 | del argv 511 | _main() 512 | 513 | 514 | if __name__ == '__main__': 515 | app.run(main) 516 | -------------------------------------------------------------------------------- /collection/random_walk_web.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # type: ignore 3 | 4 | """ 5 | Random walk in WebArena environment, collecting triples. 6 | Maintain unclickable element pool, clickable element pool, explored element pool, and implement exploration of the website through heuristic rules. 
7 | """ 8 | 9 | import json 10 | import os 11 | import random 12 | import re 13 | import subprocess 14 | import time 15 | import copy 16 | import numpy as np 17 | import uuid 18 | from PIL import Image 19 | import requests 20 | from tqdm import tqdm 21 | from browser_env.utils import ( 22 | DetachedPage, 23 | ) 24 | 25 | 26 | def get_state(env): 27 | """Get the current state of the environment""" 28 | observation = env._get_obs() 29 | observation_metadata = env._get_obs_metadata() 30 | info = { 31 | "page": DetachedPage(env.page.url, ""), 32 | "fail_error": "", 33 | "observation_metadata": observation_metadata, 34 | } 35 | # Return a deep copy of info to ensure it doesn't update with the environment 36 | return (copy.deepcopy(observation), copy.deepcopy(info)) 37 | 38 | 39 | def extract_state_ele(observation_metadata): 40 | # Extract each element represented by 'union_bound' and 'text' from the current interface for comparing changes between interfaces 41 | return [(tuple(value.get('union_bound')), re.sub(r'^\[\d+\]\s*', '', value.get('text'))) for value in observation_metadata.values()] 42 | 43 | 44 | def are_screen_identical(screen_before, screen_after): 45 | return np.array_equal(screen_before, screen_after) 46 | 47 | 48 | def load_element_pool(file_path): 49 | """Load element pool""" 50 | if not os.path.exists(file_path): 51 | return set() 52 | else: 53 | with open(file_path, 'r') as f: 54 | data_list = json.load(f) 55 | element_pool = set((tuple(item[0]), item[1]) for item in data_list) 56 | return element_pool 57 | 58 | 59 | def save_element_pool(element_pool, file_path): 60 | """Save element pool""" 61 | with open(file_path, 'w') as f: 62 | json.dump(list(element_pool), f, indent=2) 63 | 64 | 65 | def save_image(image_array, directory): 66 | """Save image and return filename""" 67 | unique_id = str(uuid.uuid4()) 68 | image_name = f"{unique_id}.png" 69 | image_path = os.path.join(directory, image_name) 70 | image = Image.fromarray(image_array) 71 | 
image.save(image_path) 72 | return image_name 73 | 74 | 75 | def select_element(state_elements, actree_obs, unclick_elem_pool, new_elements, explored_elem_pool): 76 | """Randomly select an element based on weights and return corresponding information""" 77 | elements_weights = [] 78 | for action_element_id, action_element in state_elements: 79 | element_key = (tuple(action_element['union_bound']), re.sub(r'^\[\d+\]\s*', '', action_element['text'])) 80 | # Check if element is in unclickable element set and if it's visible in the interface 81 | if element_key in unclick_elem_pool or (f"[{action_element_id}]" not in actree_obs) or ("statictext" in element_key[1].lower()): 82 | continue 83 | # Determine element status 84 | is_new = element_key in new_elements 85 | is_explored = element_key in explored_elem_pool 86 | # Assign weights 87 | if (not is_explored) and is_new: 88 | weight = 4 # Unexplored and newly appeared, highest weight 89 | elif is_explored and is_new: 90 | weight = 3 # Explored but newly appeared, second highest weight 91 | elif (not is_explored) and (not is_new): 92 | weight = 3 # Unexplored but not newly appeared, second highest weight 93 | else: 94 | weight = 1 # Explored and not newly appeared, lowest weight 95 | elements_weights.append((action_element_id, action_element, element_key, weight)) 96 | 97 | # If no selectable elements, return None 98 | if len(elements_weights) == 0: 99 | return None 100 | 101 | print(elements_weights) 102 | 103 | # Randomly select an element based on weights 104 | elem_infos = [item[:3] for item in elements_weights] 105 | weights = [item[3] for item in elements_weights] 106 | selected_elem_info = random.choices(elem_infos, weights=weights, k=1)[0] 107 | return selected_elem_info 108 | 109 | 110 | def generate_text_input(screen_text, selected_element_text, max_retries=5): 111 | """Call GPT API to determine if element is an input field and generate input content""" 112 | def is_valid_response(response_text): 113 | try: 114 
| # Clean up the response text 115 | response_text = response_text.strip() 116 | 117 | # Remove code block markers if present 118 | if response_text.startswith("```"): 119 | # Remove starting ``` or ```json 120 | response_text = re.sub(r'^```[a-zA-Z]*\n', '', response_text) 121 | # Remove ending ``` 122 | response_text = re.sub(r'\n```$', '', response_text) 123 | 124 | # Extract JSON object 125 | json_match = re.search(r'\{.*\}', response_text, re.DOTALL) 126 | if json_match: 127 | json_str = json_match.group(0) 128 | # Replace single quotes with double quotes 129 | json_str = json_str.replace("'", '"') 130 | # Replace True/False with true/false 131 | json_str = json_str.replace('True', 'true').replace('False', 'false') 132 | response_json = json.loads(json_str) 133 | # Check if 'is_input' exists and is boolean 134 | if 'is_input' in response_json and isinstance(response_json['is_input'], bool): 135 | if response_json['is_input']: 136 | # Ensure 'input_content' exists and is a string 137 | return 'input_content' in response_json and isinstance(response_json['input_content'], str) 138 | else: 139 | return True 140 | else: 141 | return False 142 | else: 143 | return False 144 | except json.JSONDecodeError as e: 145 | print(f"JSONDecodeError: {e}") 146 | return False 147 | 148 | prompt = f""" 149 | You are an intelligent assistant. Based on the current UI elements and the selected element information, carefully determine whether the selected element is an input field. If it is, generate text content that a user might input into this element. The text should be contextually appropriate. For example, if it's a search box, you might generate a search query; if it's a username input field, you might generate a username; if it's a location, you might try the name of a university, museum, airport, etc. Use imagination to generate diverse and appropriate input content.
150 | 151 | Current UI elements: 152 | {screen_text} 153 | 154 | Selected element information: 155 | {selected_element_text} 156 | 157 | Please output a JSON object in the following format, without adding any extra text or comments: 158 | 159 | If the selected element is an input field: 160 | 161 | {{ 162 | "is_input": true, 163 | "input_content": "the content to input" 164 | }} 165 | 166 | If the selected element is not an input field: 167 | 168 | {{ 169 | "is_input": false 170 | }} 171 | 172 | Ensure that the JSON is properly formatted and parsable. Use lowercase `true` or `false` for boolean values, and double quotes for strings. 173 | 174 | If you understand, please provide the JSON object now, without adding any extra text or markers. 175 | """ 176 | 177 | headers = { 178 | "Content-Type": "application/json", 179 | "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}" 180 | } 181 | 182 | payload = { 183 | "model": "gpt-4o-mini-2024-07-18", 184 | "messages": [ 185 | {"role": "user", "content": prompt} 186 | ], 187 | "max_tokens": 100, 188 | "temperature": 0.9, 189 | "n": 1, 190 | } 191 | 192 | retries = 0 193 | while retries < max_retries: 194 | try: 195 | # Send POST request to OpenAI API 196 | response = requests.post( 197 | "https://api.openai.com/v1/chat/completions", 198 | headers=headers, 199 | json=payload 200 | ) 201 | response.raise_for_status() 202 | result = response.json() 203 | response_text = result['choices'][0]['message']['content'].strip() 204 | 205 | # Validate, then parse the same normalized JSON substring that is_valid_response accepted (the raw text may contain code fences or single quotes) 206 | if is_valid_response(response_text): 207 | response_json = json.loads(re.search(r'\{.*\}', response_text, re.DOTALL).group(0).replace("'", '"').replace('True', 'true').replace('False', 'false')) 208 | return response_json 209 | else: 210 | print(f"Invalid response format: '{response_text}'.
Retrying...") 211 | retries += 1 212 | time.sleep(1) # Wait a bit before retrying 213 | except Exception as e: 214 | print(f"Error generating text input: {e}") 215 | retries += 1 216 | time.sleep(1) 217 | 218 | # If all retries fail, return default action 219 | print("Failed to get valid response after retries. Assuming it's not an input field.") 220 | return {"is_input": False} 221 | 222 | 223 | def save_trajectory(unique_id_traj, trajectory, traj_dir, screen_dir, website_name, url): 224 | """Save trajectory data""" 225 | traj_save = [] 226 | for i, item in enumerate(trajectory): 227 | if isinstance(item, dict): 228 | continue 229 | elif isinstance(item, tuple) and i+1 < len(trajectory): 230 | screen_before_name = save_image(trajectory[i-1]['observation']['image'], screen_dir) 231 | screen_after_name = save_image(trajectory[i+1]['observation']['image'], screen_dir) 232 | step_data = { 233 | "website_name": website_name, 234 | "url": url, 235 | "screen_before": screen_before_name, 236 | "a11y_before": trajectory[i-1]['observation']['text'], 237 | "state_before": trajectory[i-1]['info']['observation_metadata'], 238 | "screen_after": screen_after_name, 239 | "a11y_after": trajectory[i+1]['observation']['text'], 240 | "state_after": trajectory[i+1]['info']['observation_metadata'], 241 | "action": item 242 | } 243 | traj_save.append(step_data) 244 | 245 | traj_filename = f"{unique_id_traj}.json" 246 | traj_path = os.path.join(traj_dir, traj_filename) 247 | with open(traj_path, 'w') as f: 248 | json.dump(traj_save, f, indent=2) 249 | 250 | 251 | def random_walk_episode(config_file, traj_dir, screen_dir, element_dir, config_randomwalk): 252 | """ 253 | config_file: Initialize a specific website 254 | """ 255 | 256 | if not os.path.exists(traj_dir): 257 | os.mkdir(traj_dir) 258 | if not os.path.exists(screen_dir): 259 | os.mkdir(screen_dir) 260 | if not os.path.exists(element_dir): 261 | os.mkdir(element_dir) 262 | if not os.path.exists(config_randomwalk): 263 | 
os.mkdir(config_randomwalk) 264 | 265 | SLEEP = 3 266 | # set the URLs of each website 267 | os.environ[ 268 | "SHOPPING" 269 | ] = "http://ec2-18-220-173-105.us-east-2.compute.amazonaws.com:7770" 270 | os.environ[ 271 | "SHOPPING_ADMIN" 272 | ] = "http://ec2-18-220-173-105.us-east-2.compute.amazonaws.com:7780/admin" 273 | os.environ[ 274 | "REDDIT" 275 | ] = "http://ec2-18-220-173-105.us-east-2.compute.amazonaws.com:9999" 276 | os.environ[ 277 | "GITLAB" 278 | ] = "http://ec2-18-220-173-105.us-east-2.compute.amazonaws.com:8023" 279 | os.environ[ 280 | "MAP" 281 | ] = "http://ec2-18-220-173-105.us-east-2.compute.amazonaws.com:3000" 282 | os.environ[ 283 | "WIKIPEDIA" 284 | ] = "http://ec2-18-220-173-105.us-east-2.compute.amazonaws.com:8888/wikipedia_en_all_maxi_2022-05/A/User:The_other_Kiwix_guy/Landing" 285 | os.environ[ 286 | "HOMEPAGE" 287 | ] = "PASS" # The home page is not currently hosted in the demo site 288 | print("Done setting up URLs") 289 | 290 | 291 | # Check if configuration file is correct 292 | assert os.path.exists(config_file) 293 | with open(config_file, "r") as f: 294 | config = json.load(f) 295 | 296 | # Check which URLs are available for the corresponding website and randomly select one 297 | configs_dir = "/Users/cckevin/Desktop/webarena/config_files" 298 | website_url = dict() 299 | for config_item in os.listdir(configs_dir): 300 | if config_item == "examples": 301 | continue 302 | config_path = os.path.join(configs_dir, config_item) 303 | config_item = json.load(open(config_path, 'r')) 304 | if not isinstance(config_item, dict): 305 | continue 306 | try: 307 | if len(config_item["sites"]) != 1: 308 | continue 309 | website_name = config_item["sites"][0] 310 | if website_name not in website_url: 311 | website_url[website_name] = set() 312 | 313 | start_url = config_item["start_url"] 314 | website_url[website_name].add(start_url) 315 | except Exception as e: 316 | print(f"An error occurred: {e}") 317 | print(config_item) 318 | 319 | web_urls 
= website_url[config["sites"][0]] 320 | random_url = random.choice(list(web_urls)) 321 | print(f"Random URL: {random_url}") 322 | config["start_url"] = random_url 323 | # Save the config with the sampled start URL; env.reset() loads it below 324 | unique_id_traj = str(uuid.uuid4()) 325 | config_randomwalk_path = os.path.join(config_randomwalk, f"{unique_id_traj}.json") 326 | with open(config_randomwalk_path, 'w') as f: json.dump(config, f) 327 | 328 | """ 329 | # run bash prepare.sh to save all account cookies, this only needs to be done once 330 | subprocess.run(["bash", "prepare.sh"]) 331 | print("Done saving account cookies") 332 | """ 333 | 334 | # Init an environment 335 | from browser_env import ( 336 | Action, 337 | ActionTypes, 338 | ObservationMetadata, 339 | ScriptBrowserEnv, 340 | StateInfo, 341 | Trajectory, 342 | action2str, 343 | create_id_based_action, 344 | create_stop_action, 345 | ) 346 | from evaluation_harness.evaluators import evaluator_router 347 | 348 | # maintain a trajectory 349 | trajectory: Trajectory = [] 350 | 351 | # Maintain element pools 352 | website_name = config["sites"][0] 353 | unclick_elem_pool_path = os.path.join(element_dir, website_name + "_unclick.json") 354 | click_elem_pool_path = os.path.join(element_dir, website_name + "_click.json") 355 | explored_elem_pool_path = os.path.join(element_dir, website_name + "_explored.json") 356 | 357 | unclick_elem_pool = load_element_pool(unclick_elem_pool_path) 358 | click_elem_pool = load_element_pool(click_elem_pool_path) 359 | explored_elem_pool = load_element_pool(explored_elem_pool_path) 360 | 361 | env = None 362 | 363 | try: 364 | # Init the environment 365 | env = ScriptBrowserEnv( 366 | headless=False, 367 | slow_mo=100, 368 | observation_type="accessibility_tree", 369 | current_viewport_only=True, 370 | viewport_size={"width": 1280, "height": 720}, 371 | ) 372 | 373 | # set the environment for the current example (website) 374 | env.reset(options={"config_file": config_randomwalk_path}) 375 | obs, info = get_state(env) 376 | state_info: StateInfo
= {"observation": obs, "info": info} # Save each interface state using StateInfo 377 | trajectory.append(state_info) # Record initial interface 378 | 379 | # Record elements from previous interface for comparing newly appeared elements 380 | prev_state_elements = set() 381 | 382 | # Number of random-walk steps per episode 383 | num_step = 5 384 | 385 | for i in range(num_step): 386 | 387 | # Get current interface state 388 | obs_before, info_before = get_state(env) 389 | actree_obs = obs_before["text"] 390 | print(actree_obs) 391 | 392 | # Get candidate interactive elements from current interface 393 | state_elements = list(info_before['observation_metadata']['text']['obs_nodes_info'].items()) 394 | 395 | # Extract elements from current interface 396 | current_state_elements = set(extract_state_ele(info_before['observation_metadata']['text']['obs_nodes_info'])) 397 | # Find newly appeared elements 398 | new_elements = current_state_elements - prev_state_elements 399 | 400 | # Select element 401 | selected_elem_info = select_element( 402 | state_elements, 403 | actree_obs, 404 | unclick_elem_pool, 405 | new_elements, 406 | explored_elem_pool 407 | ) 408 | 409 | # If no element can be selected, end this episode 410 | if selected_elem_info is None: 411 | print("No clickable elements found.") 412 | break 413 | 414 | action_element_id, action_element, element_key = selected_elem_info 415 | 416 | # Use GPT to decide whether the selected element accepts text input and choose the action (type or click) 417 | gpt_response = generate_text_input(actree_obs, action_element['text']) 418 | if gpt_response.get('is_input'): 419 | type_content = gpt_response.get('input_content', 'Test Input') 420 | next_action_str = f"type [{action_element_id}] [{type_content}]" 421 | next_action = create_id_based_action(next_action_str) 422 | else: 423 | next_action_str = f"click [{action_element_id}]" 424 | next_action = create_id_based_action(next_action_str) 425 | 426 | print(f"Step {i}: {next_action_str}") 427 | 428 | 
# Execute action 429 | env.step(next_action) 430 | time.sleep(SLEEP) 431 | 432 | # Add element to explored set 433 | explored_elem_pool.add(element_key) 434 | 435 | # Get execution result 436 | obs_after, info_after = get_state(env) 437 | actree_obs = obs_after["text"] 438 | 439 | # Determine whether the action changed the interface by comparing screenshots 440 | if are_screen_identical(obs_before['image'], obs_after['image']): 441 | # If the interface is unchanged, add the element to the unavailable element pool 442 | print("The pages are identical. Added to unclick_elem_pool.") 443 | if element_key not in click_elem_pool: # If not in clickable set 444 | unclick_elem_pool.add(element_key) 445 | else: 446 | # Interface changed: record the action and the resulting interface in the trajectory 447 | print("The pages have differences.") 448 | click_elem_pool.add(element_key) 449 | trajectory.append((next_action_str, action_element_id, action_element)) 450 | state_info = {"observation": obs_after, "info": info_after} 451 | trajectory.append(state_info) 452 | 453 | # Update previous interface element set, regardless of whether page changed 454 | prev_state_elements = current_state_elements 455 | 456 | except Exception as e: 457 | print(f"An error occurred: {e}") 458 | time.sleep(1) 459 | # The environment is closed exactly once, in the finally block below 460 | 461 | finally: 462 | if env is not None: 463 | env.close() 464 | time.sleep(1) 465 | 466 | # Record element pools 467 | save_element_pool(unclick_elem_pool, unclick_elem_pool_path) 468 | save_element_pool(click_elem_pool, click_elem_pool_path) 469 | save_element_pool(explored_elem_pool, explored_elem_pool_path) 470 | 471 | # Save trajectory data 472 | save_trajectory(unique_id_traj, trajectory, traj_dir, screen_dir, website_name, random_url) 473 | 474 | return trajectory 475 | 476 | 477 | num_episode = 50 478 | config_file = "config_files/60.json" # Choose the config file that matches the website to be sampled 479 | config_randomwalk = 
"/Users/cckevin/Desktop/config_randomwalk" 480 | traj_dir = "/Users/cckevin/Desktop/traj_results" 481 | screen_dir = "/Users/cckevin/Desktop/screen_results" 482 | element_dir = "/Users/cckevin/Desktop/click_elements" 483 | 484 | for i in tqdm(range(num_episode)): 485 | print(f"Episode {i} random walk") 486 | try: 487 | random_walk_episode(config_file, traj_dir, screen_dir, element_dir, config_randomwalk) 488 | except Exception as e: 489 | print(f"An error occurred: {e}") 490 | -------------------------------------------------------------------------------- /collection/run_mobile_runner.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | 5 | def run_gpt_task(): 6 | os.system("python mobile_runner.py") 7 | 8 | while True: 9 | run_gpt_task() 10 | time.sleep(20) -------------------------------------------------------------------------------- /collection/run_random_walk_aw.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | 4 | 5 | def run_random_walk(): 6 | os.system("python random_walk_aw.py") 7 | 8 | while True: 9 | run_random_walk() 10 | time.sleep(20) -------------------------------------------------------------------------------- /evaluation/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OS-Copilot/OS-Genesis/9ecbe594352f254b9a9228468f9ca9a77b2388a2/evaluation/.DS_Store -------------------------------------------------------------------------------- /evaluation/android_control/ac_eval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json 3 | from collections import defaultdict 4 | import argparse 5 | import re 6 | import math 7 | import os 8 | import ast 9 | import logging 10 | import openpyxl 11 | 12 | logger = logging.getLogger(__name__) 13 | logger.setLevel(logging.INFO) 14 | 15 | 
TYPE="high" 16 | MODEL="qwen2vl" 17 | PRED_FILE_PATH=f"results/{MODEL}/predictions.jsonl" 18 | def get_in_domain_ids(jsonl_path): 19 | in_domain_ids = [] 20 | with open(jsonl_path, 'r') as f: 21 | for line in f: 22 | data = json.loads(line) 23 | in_domain_ids.append(data["id"]) 24 | return in_domain_ids 25 | 26 | 27 | def extract_json(s): 28 | match = re.search(r'\{.*?\}', s) 29 | if match: 30 | json_str = match.group(0) 31 | try: 32 | return json.loads(json_str) 33 | except json.JSONDecodeError: 34 | return json_str # return the raw string if it is not valid JSON 35 | return None 36 | 37 | 38 | def extract_dict_from_string(text): 39 | start_index = text.find("Accessibility tree: ") 40 | if start_index == -1: 41 | return None # "Accessibility tree: " not found 42 | 43 | # Locate where the dict starts 44 | dict_start = start_index + len("Accessibility tree: ") 45 | 46 | # Take the substring that contains the dict 47 | dict_str = text[dict_start:] 48 | 49 | # Find the first '{' and its matching '}' 50 | open_brace_index = dict_str.find('{') 51 | if open_brace_index == -1: 52 | return None # no '{' found 53 | 54 | stack = [] 55 | for i, char in enumerate(dict_str[open_brace_index:], start=open_brace_index): 56 | if char == '{': 57 | stack.append(char) 58 | elif char == '}': 59 | stack.pop() 60 | if not stack: 61 | dict_end = i + 1 62 | break 63 | else: 64 | return None # no matching '}' found 65 | 66 | # Extract the dict substring 67 | dict_substr = dict_str[open_brace_index:dict_end] 68 | 69 | return dict_substr 70 | 71 | def get_key_by_position(text, x, y): 72 | position = f"({x}, {y})" 73 | dict_str = extract_dict_from_string(text) 74 | if dict_str is None: 75 | return None # no dict found; check before json.loads, which rejects None 76 | extracted_dict = json.loads(dict_str) 77 | for key in extracted_dict: 78 | if extracted_dict[key] == position: 79 | return key 80 | 81 | return None # no element at this position 82 | 83 | def calculate_f1_score(predicted_str, ground_truth_str): 84 | predicted_tokens = set(predicted_str.lower().split()) 85 | ground_truth_tokens = set(ground_truth_str.lower().split()) 86 | 87 | common_tokens =
predicted_tokens.intersection(ground_truth_tokens) 88 | if len(predicted_tokens) == 0: 89 | precision = 0 90 | else: 91 | precision = len(common_tokens) / len(predicted_tokens) 92 | if len(ground_truth_tokens) == 0: 93 | recall = 0 94 | else: 95 | recall = len(common_tokens) / len(ground_truth_tokens) 96 | 97 | if precision + recall == 0: 98 | f1_score = 0 99 | else: 100 | f1_score = 2 * (precision * recall) / (precision + recall) 101 | return f1_score 102 | 103 | 104 | def evaluate(args): 105 | prediction_file_path = args.prediction_file_path 106 | prediction = [] 107 | with open(prediction_file_path) as file: 108 | for line in file: 109 | prediction.append(json.loads(line)) 110 | 111 | ground_truth = [] 112 | with open("eval_json_files/android_control_test_data.jsonl") as file: 113 | for line in file: 114 | ground_truth.append(json.loads(line)) 115 | 116 | with open(f"eval_json_files/android_control_test_subsplits.json","r") as file: 117 | test_subsplits = json.load(file) 118 | print(test_subsplits.keys()) 119 | print(len(ground_truth)) 120 | 121 | 122 | # ======================================================================== # 123 | # Results on Low-level 124 | # ======================================================================== # 125 | mis_click_wait_num = 0 126 | step_acc_res_dict = defaultdict(int) 127 | sample_number_dict = defaultdict(int) 128 | for pred, gt in zip(prediction, ground_truth): 129 | gt_action = json.loads(gt["conversations"][1]["value"].split("actions:\n")[1]) 130 | episode_id = int(pred["image_id"].split("/")[-1].split("_")[1]) # parse out the episode index 131 | subsplit_type = next((category for category, ids in test_subsplits.items() if episode_id in ids), None) 132 | gt_action_type = gt_action["action_type"] 133 | sample_number_dict[subsplit_type+"_LL"] += 1 134 | sample_number_dict["full_LL"] += 1 135 | sample_number_dict["Type_LL"] += 1 136 | sample_number_dict[gt_action_type+"_LL"] += 1 137 | 138 | if 
len(pred["pred"].split("action: "))==2: 139 | try: 140 | pred_action = json.loads(pred["pred"].split("action: ")[1]) 141 | if len(pred_action) == 0: 142 | continue 143 | except json.JSONDecodeError as e: 144 | continue 145 | else: 146 | pred_action = extract_json(pred["pred"]) 147 | if pred_action is None: 148 | continue 149 | try: 150 | pred_action_type = pred_action["action_type"] 151 | except Exception as e: 152 | continue 153 | 154 | # calculate step acc based on types 155 | if gt_action_type==pred_action_type or (gt_action_type == "type" and pred_action_type == "input_text"): 156 | step_acc_res_dict["Type_LL"] += 1 157 | step_acc_res_dict[gt_action_type+"_type_match_LL"] += 1 158 | if gt_action_type in ["click","long_press"]: # evaluate click type 159 | try: 160 | pred_x, pred_y = int(pred_action["x"]), int(pred_action["y"]) 161 | except: 162 | pred_x, pred_y = -100, -100 163 | gt_x, gt_y = int(gt_action["x"]), int(gt_action["y"]) 164 | 165 | if math.sqrt((pred_x - gt_x)**2 + (pred_y - gt_y)**2) <= math.sqrt((1080*0.14)**2 + (2400*0.14)**2): # set 14 % of screen size as the ratio 166 | step_acc_res_dict[subsplit_type+"_LL"] += 1 167 | step_acc_res_dict["full_LL"] += 1 168 | step_acc_res_dict[gt_action_type+"_all_match_LL"] += 1 169 | 170 | 171 | elif gt_action_type == "type" and pred_action_type in ["input_text", "type"]: 172 | if gt_action["text"]==pred_action["text"] or calculate_f1_score(pred_action["text"], gt_action["text"])>0.5: 173 | step_acc_res_dict[subsplit_type+"_LL"] += 1 174 | step_acc_res_dict["full_LL"] += 1 175 | step_acc_res_dict[gt_action_type+"_all_match_LL"] += 1 176 | 177 | elif gt_action_type == "scroll": 178 | if "Scroll up" in pred["prompt"] and pred_action["direction"] == "up": 179 | step_acc_res_dict[subsplit_type+"_LL"] += 1 180 | step_acc_res_dict["full_LL"] += 1 181 | step_acc_res_dict[gt_action_type+"_all_match_LL"] += 1 182 | elif "Scroll down" in pred["prompt"] and pred_action["direction"] == "down": 183 | 
step_acc_res_dict[subsplit_type+"_LL"] += 1 184 | step_acc_res_dict["full_LL"] += 1 185 | step_acc_res_dict[gt_action_type+"_all_match_LL"] += 1 186 | elif pred_action==gt_action: 187 | step_acc_res_dict[subsplit_type+"_LL"] += 1 188 | step_acc_res_dict["full_LL"] += 1 189 | step_acc_res_dict[gt_action_type+"_all_match_LL"] += 1 190 | 191 | elif gt_action==pred_action: # evaluate other types 192 | step_acc_res_dict[subsplit_type+"_LL"] += 1 193 | step_acc_res_dict["full_LL"] += 1 194 | step_acc_res_dict[gt_action_type+"_all_match_LL"] += 1 195 | 196 | 197 | # Print the low-level results 198 | logger.info("="*30 + " Step Acc " + "="*30) 199 | logger.info("Full-LL: %f\n" % (step_acc_res_dict["full_LL"] / sample_number_dict["full_LL"])) 200 | logger.info("Type-LL: %f\n" % (step_acc_res_dict["Type_LL"] / sample_number_dict["Type_LL"])) 201 | # Save the results to an Excel sheet 202 | SR_acc = round((step_acc_res_dict["full_LL"] / sample_number_dict["full_LL"]) * 100, 2) 203 | Type_acc = round((step_acc_res_dict["Type_LL"] / sample_number_dict["Type_LL"]) * 100, 2) 204 | # Open the Excel file (it must already exist) 205 | file_path = "android_control_eval.xlsx" 206 | wb = openpyxl.load_workbook(file_path) 207 | 208 | # Select the active worksheet 209 | sheet = wb.active 210 | 211 | # Find the next empty row 212 | next_row = sheet.max_row + 1 213 | 214 | # Write the data 215 | sheet.cell(row=next_row, column=1, value=MODEL) 216 | sheet.cell(row=next_row, column=2, value=SR_acc) 217 | sheet.cell(row=next_row, column=3, value=Type_acc) 218 | 219 | # Save the file 220 | wb.save(file_path) 221 | 222 | logger.info("IDD-LL: %f\n" % (step_acc_res_dict["IDD_LL"] / sample_number_dict["IDD_LL"])) 223 | logger.info("app_unseen-LL: %f\n" % (step_acc_res_dict["app_unseen_LL"] / sample_number_dict["app_unseen_LL"])) 224 | logger.info("task_unseen-LL: %f\n" % (step_acc_res_dict["task_unseen_LL"] / sample_number_dict["task_unseen_LL"])) 225 | logger.info("category_unseen-LL: %f\n" % (step_acc_res_dict["category_unseen_LL"] / sample_number_dict["category_unseen_LL"])) 226 | logger.info("="*30 + "
Detailed Acc of Each Type " + "="*30) 227 | for action_type in sample_number_dict: 228 | action_type = action_type.split("_LL")[0] 229 | if action_type not in ["full","Type","IDD","app_unseen","task_unseen","category_unseen"]: 230 | logger.info(f"{action_type}_all_match-LL: %f\n" % (step_acc_res_dict[f"{action_type}_all_match_LL"] / sample_number_dict[f"{action_type}_LL"])) 231 | logger.info(f"{action_type}_type_match-LL: %f\n" % (step_acc_res_dict[f"{action_type}_type_match_LL"] / sample_number_dict[f"{action_type}_LL"])) 232 | 233 | 234 | 235 | if __name__ == '__main__': 236 | parser = argparse.ArgumentParser() 237 | 238 | parser.add_argument('--prediction_file_path', type=str, default=PRED_FILE_PATH) 239 | parser.add_argument('--datasets', type=str, default='') 240 | parser.add_argument('--output_path', type=str, default='results/score/') 241 | parser.add_argument('--eval_HH', action='store_true') 242 | parser.add_argument('--seed', type=int, default=0) 243 | args = parser.parse_args() 244 | 245 | if not os.path.exists(args.output_path): 246 | os.makedirs(args.output_path) 247 | 248 | file_handler = logging.FileHandler(args.output_path + f"ac_{TYPE}_score_{MODEL}.log", mode="w") 249 | file_handler.setLevel(logging.INFO) 250 | 251 | console_handler = logging.StreamHandler() 252 | console_handler.setLevel(logging.INFO) 253 | 254 | logger.addHandler(file_handler) 255 | logger.addHandler(console_handler) 256 | 257 | evaluate(args) -------------------------------------------------------------------------------- /evaluation/android_control/internvl2_inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import json 4 | import os 5 | import random 6 | import time 7 | from functools import partial 8 | import torch 9 | from internvl.model.internvl_chat import InternVLChatModel 10 | from internvl.train.dataset import build_transform, dynamic_preprocess 11 | from PIL import Image 12 | from 
tqdm import tqdm 13 | from transformers import AutoTokenizer 14 | import logging 15 | 16 | logger = logging.getLogger(__name__) 17 | logger.setLevel(logging.INFO) 18 | 19 | ds_collections = { 20 | 'ac_high': { 21 | 'root': '/path/to/images', 22 | 'annotation': 'eval_json_files/ac_high_processing.jsonl', 23 | 'max_new_tokens': 999, 24 | 'min_new_tokens': 1, 25 | }, 26 | 'ac_low': { 27 | 'root': '/path/to/images', 28 | 'annotation': 'eval_json_files/ac_low_processing.jsonl', 29 | 'max_new_tokens': 999, 30 | 'min_new_tokens': 1, 31 | } 32 | 33 | } 34 | 35 | 36 | def collate_fn(batches, tokenizer): 37 | pixel_values = torch.cat([_['pixel_values'] for _ in batches], dim=0) 38 | questions = [_['question'] for _ in batches] 39 | items = [_['item'] for _ in batches] 40 | return pixel_values, questions, items 41 | 42 | 43 | class AndroidControlDataset(torch.utils.data.Dataset): 44 | 45 | def __init__(self, root, annotation, input_size=224, dynamic_image_size=False, 46 | use_thumbnail=False, max_num=6): 47 | self.root = root 48 | self.items = [] 49 | f = open(annotation) 50 | data = f.readlines() 51 | for data_line in data: 52 | data_line = json.loads(data_line) 53 | self.items.append(data_line) 54 | self.input_size = input_size # input size?? 
55 | self.dynamic_image_size = dynamic_image_size 56 | self.use_thumbnail = use_thumbnail 57 | self.max_num = max_num 58 | self.transform = build_transform(is_train=False, input_size=input_size) 59 | 60 | def __len__(self): 61 | return len(self.items) 62 | 63 | def __getitem__(self, idx): 64 | item = self.items[idx] 65 | image_path, question = item['image'], item['conversations'][0]['value'] 66 | image = Image.open(image_path).convert('RGB') 67 | if self.dynamic_image_size: 68 | images = dynamic_preprocess(image, image_size=self.input_size, 69 | use_thumbnail=self.use_thumbnail, 70 | max_num=self.max_num) 71 | else: 72 | images = [image] 73 | pixel_values = [self.transform(image) for image in images] 74 | pixel_values = torch.stack(pixel_values) 75 | 76 | return { 77 | 'question': question, 78 | 'pixel_values': pixel_values, 79 | 'item': item, 80 | } 81 | 82 | 83 | 84 | class InferenceSampler(torch.utils.data.sampler.Sampler): 85 | 86 | def __init__(self, size): 87 | self._size = int(size) 88 | assert size > 0 89 | self._rank = torch.distributed.get_rank() 90 | self._world_size = torch.distributed.get_world_size() 91 | self._local_indices = self._get_local_indices(size, self._world_size, self._rank) 92 | 93 | @staticmethod 94 | def _get_local_indices(total_size, world_size, rank): 95 | shard_size = total_size // world_size 96 | left = total_size % world_size 97 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)] 98 | 99 | begin = sum(shard_sizes[:rank]) 100 | end = min(sum(shard_sizes[:rank + 1]), total_size) 101 | return range(begin, end) 102 | 103 | def __iter__(self): 104 | yield from self._local_indices 105 | 106 | def __len__(self): 107 | return len(self._local_indices) 108 | 109 | 110 | 111 | def evaluate_chat_model(): 112 | random.seed(args.seed) 113 | 114 | if args.ds_name_list == None: 115 | ds_names = ds_collections.keys() 116 | else: 117 | ds_names = args.ds_name_list 118 | for ds_name in ds_names: 119 | dataset = 
AndroidControlDataset( 120 | root=ds_collections[ds_name]['root'], 121 | annotation=ds_collections[ds_name]['annotation'], 122 | input_size=image_size, 123 | dynamic_image_size=args.dynamic, 124 | use_thumbnail=use_thumbnail, 125 | max_num=args.max_num 126 | ) 127 | dataloader = torch.utils.data.DataLoader( 128 | dataset=dataset, 129 | sampler=InferenceSampler(len(dataset)), 130 | batch_size=args.batch_size, 131 | num_workers=args.num_workers, 132 | pin_memory=True, 133 | drop_last=False, 134 | collate_fn=partial(collate_fn, tokenizer=tokenizer), 135 | ) 136 | 137 | 138 | logger.info(f'Evaluating {ds_name} ...') 139 | 140 | outputs = [] 141 | for _, (pixel_values, questions, items) in tqdm(enumerate(dataloader)): 142 | pixel_values = pixel_values.to(torch.bfloat16).cuda() 143 | generation_config = dict( 144 | num_beams=args.num_beams, 145 | max_new_tokens=ds_collections[ds_name]['max_new_tokens'], 146 | min_new_tokens=ds_collections[ds_name]['min_new_tokens'], 147 | do_sample=True if args.temperature > 0 else False, 148 | temperature=args.temperature, 149 | ) 150 | pred = model.chat( 151 | tokenizer=tokenizer, 152 | pixel_values=pixel_values, 153 | question=questions[0], 154 | generation_config=generation_config 155 | ) 156 | preds = [pred] 157 | 158 | for question, answer, item in zip(questions, preds, items): 159 | question_id = item['id'] 160 | image_id = item['image'].split("/")[-1].replace(".png", "") 161 | text = question 162 | output = { 163 | 'question_id': question_id, 164 | 'image_id': image_id, 165 | 'prompt': text, 166 | 'pred': answer, 167 | 'model_id': model_id, 168 | 'metadata': {}, 169 | "task_name": "task2action", 170 | "string_format": "CSV_String", 171 | "position_format": "related" 172 | } 173 | outputs.append(output) 174 | 175 | torch.distributed.barrier() 176 | world_size = torch.distributed.get_world_size() 177 | merged_outputs = [None for _ in range(world_size)] 178 | torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))
179 | 180 | merged_outputs = [json.loads(_) for _ in merged_outputs] 181 | merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)] 182 | 183 | torch.distributed.barrier() 184 | 185 | if torch.distributed.get_rank() == 0: 186 | print(len(merged_outputs)) 187 | results_file = f'{model_id}_{ds_name}.jsonl' 188 | results_file = os.path.join(args.out_dir, results_file) 189 | with open(results_file, 'w') as f: 190 | for output in merged_outputs: 191 | json.dump(output, f) 192 | f.write('\n') 193 | print('Results saved to {}'.format(results_file)) 194 | 195 | 196 | if __name__ == '__main__': 197 | parser = argparse.ArgumentParser() 198 | parser.add_argument('--checkpoint', type=str, default='') 199 | parser.add_argument('--datasets', type=str, default='') 200 | parser.add_argument('--batch-size', type=int, default=1) 201 | parser.add_argument('--num-workers', type=int, default=1) 202 | parser.add_argument('--num-beams', type=int, default=1) 203 | parser.add_argument('--temperature', type=float, default=0.0) 204 | parser.add_argument('--out-dir', type=str, default='results') 205 | parser.add_argument('--seed', type=int, default=0) 206 | parser.add_argument('--dynamic', action='store_true') 207 | parser.add_argument('--max-num', type=int, default=6) 208 | parser.add_argument('--load-in-8bit', action='store_true') 209 | parser.add_argument('--auto', action='store_true') 210 | parser.add_argument('--ds_name_list', type=str, nargs='*', default=None, help='List of dataset names') 211 | args = parser.parse_args() 212 | 213 | args.datasets = args.datasets.split(',') 214 | logger.info('datasets: %s', args.datasets) 215 | assert args.batch_size == 1, 'Only batch size 1 is supported' 216 | 217 | torch.distributed.init_process_group( 218 | backend='nccl', 219 | world_size=int(os.getenv('WORLD_SIZE', '1')), 220 | rank=int(os.getenv('RANK', '0')), 221 | ) 222 | 223 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0))) 224 | 225 | 226 | if args.auto: 227 |
os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 228 | kwargs = {'device_map': 'auto'} if args.auto else {} 229 | tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True, use_fast=False) 230 | model = InternVLChatModel.from_pretrained( 231 | args.checkpoint, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16, 232 | load_in_8bit=args.load_in_8bit, **kwargs).eval() 233 | if not args.load_in_8bit and not args.auto: 234 | model = model.cuda() 235 | image_size = model.config.force_image_size or model.config.vision_config.image_size 236 | use_thumbnail = model.config.use_thumbnail 237 | 238 | total_params = sum(p.numel() for p in model.parameters()) / 1e9 239 | if total_params > 20 or args.dynamic: 240 | args.num_beams = 1 241 | logger.info(f'[test] total_params: {total_params}B, use num_beams: {args.num_beams}') 242 | else: 243 | logger.info(f'[test] total_params: {total_params}B') 244 | logger.info(f'[test] image_size: {image_size}') 245 | logger.info(f'[test] template: {model.config.template}') 246 | logger.info(f'[test] dynamic_image_size: {args.dynamic}') 247 | logger.info(f'[test] use_thumbnail: {use_thumbnail}') 248 | logger.info(f'[test] max_num: {args.max_num}') 249 | if torch.distributed.get_rank() == 0: 250 | if not os.path.exists(args.out_dir): 251 | os.makedirs(args.out_dir) 252 | 253 | model_id = '_'.join(args.checkpoint.split('/')[-1:]) 254 | evaluate_chat_model() -------------------------------------------------------------------------------- /evaluation/android_control/qwen2vl_inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import itertools 3 | import json 4 | import os 5 | import random 6 | import time 7 | from functools import partial 8 | 9 | import torch 10 | # from internvl.model.internvl_chat import InternVLChatModel 11 | # from internvl.train.dataset import build_transform, dynamic_preprocess 12 | from transformers import Qwen2VLForConditionalGeneration, 
AutoTokenizer, AutoProcessor 13 | from qwen_vl_utils import process_vision_info 14 | from PIL import Image 15 | from tqdm import tqdm 16 | from transformers import AutoTokenizer 17 | import logging 18 | 19 | logger = logging.getLogger(__name__) 20 | logger.setLevel(logging.INFO) 21 | max_pixels = 1024 * 1024 22 | 23 | import torch.multiprocessing as mp 24 | mp.set_start_method('spawn', force=True) 25 | 26 | ds_collections = { 27 | 'ac_high': { 28 | 'root': '/path/to/images', 29 | 'annotation': 'eval_json_files/ac_high_processing.jsonl', 30 | 'max_new_tokens': 999, 31 | 'min_new_tokens': 1, 32 | }, 33 | 'ac_low': { 34 | 'root': '/path/to/images', 35 | 'annotation': 'eval_json_files/ac_low_processing.jsonl', 36 | 'max_new_tokens': 999, 37 | 'min_new_tokens': 1, 38 | } 39 | 40 | } 41 | 42 | 43 | def collate_fn(batches): 44 | inputs = [_['inputs'] for _ in batches] 45 | questions = [_['question'] for _ in batches] 46 | items = [_['item'] for _ in batches] 47 | return inputs, questions, items 48 | 49 | 50 | 51 | class AndroidControlDataset(torch.utils.data.Dataset): 52 | 53 | def __init__(self, root, annotation, input_size=224, dynamic_image_size=False, 54 | use_thumbnail=False, max_num=6, device="cuda:0"): 55 | self.root = root 56 | self.items = [] 57 | f = open(annotation) 58 | data = f.readlines() 59 | for data_line in data: 60 | data_line = json.loads(data_line) 61 | self.items.append(data_line) 62 | self.input_size = input_size # input size?? 
63 | self.dynamic_image_size = dynamic_image_size 64 | self.use_thumbnail = use_thumbnail 65 | self.max_num = max_num 66 | self.device = device 67 | max_pixels = 1024 * 1024 68 | self.processor = AutoProcessor.from_pretrained("/nas/shared/NLP_A100/wuzhenyu/LLMs/Qwen2-VL-7B-Instruct", max_pixels=max_pixels) 69 | def __len__(self): 70 | return len(self.items) 71 | 72 | def __getitem__(self, idx): 73 | item = self.items[idx] 74 | image_path, question = item['image'], item['conversations'][0]['value'] 75 | messages = [ 76 | { 77 | "role": "user", 78 | "content": [ 79 | {"type": "text", "text": question.split("")[0]}, 80 | { 81 | "type": "image", 82 | "image": image_path, 83 | }, 84 | {"type": "text", "text": question.split("")[1]}, 85 | ], 86 | } 87 | ] 88 | text = self.processor.apply_chat_template( 89 | messages, tokenize=False, add_generation_prompt=True 90 | ) 91 | image_inputs, video_inputs = process_vision_info(messages) 92 | inputs = self.processor( 93 | text=[text], 94 | images=image_inputs, 95 | videos=video_inputs, 96 | padding=True, 97 | return_tensors="pt", 98 | ) 99 | return { 100 | 'question': question, 101 | 'inputs': inputs, 102 | 'item': item, 103 | } 104 | 105 | 106 | class InferenceSampler(torch.utils.data.sampler.Sampler): 107 | 108 | def __init__(self, size): 109 | self._size = int(size) 110 | assert size > 0 111 | self._rank = torch.distributed.get_rank() 112 | self._world_size = torch.distributed.get_world_size() 113 | self._local_indices = self._get_local_indices(size, self._world_size, self._rank) 114 | 115 | @staticmethod 116 | def _get_local_indices(total_size, world_size, rank): 117 | shard_size = total_size // world_size 118 | left = total_size % world_size 119 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)] 120 | 121 | begin = sum(shard_sizes[:rank]) 122 | end = min(sum(shard_sizes[:rank + 1]), total_size) 123 | return range(begin, end) 124 | 125 | def __iter__(self): 126 | yield from self._local_indices 127 | 128 | 
def __len__(self): 129 | return len(self._local_indices) 130 | 131 | processor = AutoProcessor.from_pretrained("/nas/shared/NLP_A100/wuzhenyu/LLMs/Qwen2-VL-7B-Instruct", max_pixels=max_pixels) 132 | 133 | def evaluate_chat_model(model): 134 | random.seed(args.seed) 135 | 136 | if args.ds_name_list == None: 137 | ds_names = ds_collections.keys() 138 | else: 139 | ds_names = args.ds_name_list 140 | for ds_name in ds_names: 141 | dataset = AndroidControlDataset( 142 | root=ds_collections[ds_name]['root'], 143 | annotation=ds_collections[ds_name]['annotation'], 144 | input_size=224, 145 | dynamic_image_size=args.dynamic, 146 | use_thumbnail=False, 147 | max_num=args.max_num, 148 | device=model.device 149 | ) 150 | dataloader = torch.utils.data.DataLoader( 151 | dataset=dataset, 152 | sampler=InferenceSampler(len(dataset)), 153 | batch_size=args.batch_size, 154 | num_workers=args.num_workers, 155 | pin_memory=False, 156 | drop_last=False, 157 | collate_fn=partial(collate_fn), 158 | ) 159 | 160 | 161 | logger.info(f'Evaluating {ds_name} ...') 162 | 163 | outputs = [] 164 | for _, (inputs, questions, items) in tqdm(enumerate(dataloader)): 165 | 166 | inputs = inputs[0] 167 | inputs = inputs.to(model.device) 168 | print(f"inputs: {inputs}") 169 | generated_ids = model.generate( 170 | **{k: v.to(model.device) for k, v in inputs.items()}, 171 | max_new_tokens=ds_collections[ds_name]['max_new_tokens'] 172 | ) 173 | generated_ids_trimmed = [ 174 | out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) 175 | ] 176 | output_text = processor.batch_decode( 177 | generated_ids_trimmed, skip_special_tokens=False, clean_up_tokenization_spaces=False 178 | ) 179 | preds = [output_text[0]] 180 | print(f"preds: {preds}") 181 | for question, answer, item in zip(questions, preds, items): 182 | question_id = item['id'] 183 | image_id = item['image'].split("/")[-1].replace(".png", "") 184 | text = question 185 | output = { 186 | 'question_id': question_id, 187 |
'image_id': image_id, 188 | 'prompt': text, 189 | 'pred': answer, 190 | 'model_id': model_id, 191 | 'metadata': {}, 192 | "task_name": "task2action", 193 | "string_format": "CSV_String", 194 | "position_format": "related" 195 | } 196 | outputs.append(output) 197 | 198 | 199 | torch.distributed.barrier() 200 | world_size = torch.distributed.get_world_size() 201 | merged_outputs = [None for _ in range(world_size)] 202 | torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs)) 203 | 204 | merged_outputs = [json.loads(_) for _ in merged_outputs] 205 | merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)] 206 | 207 | torch.distributed.barrier() 208 | 209 | if torch.distributed.get_rank() == 0: 210 | print(len(merged_outputs)) 211 | results_file = f'{model_id}_{ds_name}.jsonl' 212 | results_file = os.path.join(args.out_dir, results_file) 213 | # json.dump(merged_outputs, open(results_file, 'w')) 214 | with open(results_file, 'w') as f: 215 | for output in merged_outputs: 216 | json.dump(output, f) 217 | f.write('\n') 218 | print('Results saved to {}'.format(results_file)) 219 | 220 | 221 | if __name__ == '__main__': 222 | parser = argparse.ArgumentParser() 223 | parser.add_argument('--checkpoint', type=str, default='') 224 | parser.add_argument('--datasets', type=str, default='') 225 | parser.add_argument('--batch-size', type=int, default=1) 226 | parser.add_argument('--num-workers', type=int, default=1) 227 | parser.add_argument('--num-beams', type=int, default=1) 228 | parser.add_argument('--temperature', type=float, default=0.0) 229 | parser.add_argument('--out-dir', type=str, default='results') 230 | parser.add_argument('--seed', type=int, default=0) 231 | parser.add_argument('--dynamic', action='store_true') 232 | parser.add_argument('--max-num', type=int, default=6) 233 | parser.add_argument('--load-in-8bit', action='store_true') 234 | parser.add_argument('--auto', action='store_true') 235 | 
    parser.add_argument('--ds_name_list', type=str, nargs='*', default=None, help='List of dataset names')
    args = parser.parse_args()

    args.datasets = args.datasets.split(',')
    logger.info('datasets: %s', args.datasets)
    assert args.batch_size == 1, 'Only batch size 1 is supported'

    # One process per GPU; torchrun supplies WORLD_SIZE, RANK and LOCAL_RANK
    torch.distributed.init_process_group(
        backend='nccl',
        world_size=int(os.getenv('WORLD_SIZE', '1')),
        rank=int(os.getenv('RANK', '0')),
    )

    torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))

    if args.auto:
        os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
    # NOTE: kwargs is currently unused; the model is moved to the local GPU with .cuda() below
    kwargs = {'device_map': 'auto'} if args.auto else {}

    model = Qwen2VLForConditionalGeneration.from_pretrained(
        args.checkpoint,
        torch_dtype=torch.bfloat16
    ).cuda().eval()
    total_params = sum(p.numel() for p in model.parameters()) / 1e9
    if total_params > 20 or args.dynamic:
        args.num_beams = 1
        logger.info(f'[test] total_params: {total_params}B, use num_beams: {args.num_beams}')
    else:
        logger.info(f'[test] total_params: {total_params}B')
    logger.info(f'[test] dynamic_image_size: {args.dynamic}')
    logger.info(f'[test] max_num: {args.max_num}')
    if torch.distributed.get_rank() == 0:
        os.makedirs(args.out_dir, exist_ok=True)

    # Use the last component of the checkpoint path as the model id (robust to a trailing '/')
    model_id = os.path.basename(os.path.normpath(args.checkpoint))
    evaluate_chat_model(model)
--------------------------------------------------------------------------------
/evaluation/android_control/run_ac_inference.sh:
--------------------------------------------------------------------------------
set -x

DATASET=${1}
CHECKPOINT=${2}
echo "DATASET: ${DATASET}"
echo "CHECKPOINT: ${CHECKPOINT}"

# Check that the dataset name is valid
options=("ac_high" "ac_low")
if [[ " ${options[@]} " =~ " ${DATASET} " ]]; then
  echo "Input '$DATASET' is a valid option"
else
  echo "Error: '$DATASET' is not a valid option"
  echo "Valid options are: ${options[@]}"
  exit 1
fi

PATCH=24

TEST_SET_NAME=${DATASET}
echo "Test set key: ${TEST_SET_NAME}"

MASTER_PORT=${MASTER_PORT:-63604}
PORT=${PORT:-63604}
PARTITION=${PARTITION:-"INTERN2"}

export MASTER_PORT=${MASTER_PORT}
export PORT=${PORT}
# Number of GPUs (one process each); override via the GPUS environment variable
GPUS=${GPUS:-1}
echo "GPUS: ${GPUS}"

# Set GPUS_PER_NODE before computing NODES, since it is the divisor
GPUS_PER_NODE=${GPUS_PER_NODE:-${GPUS}}
NODES=$((GPUS / GPUS_PER_NODE))
QUOTA_TYPE=${QUOTA_TYPE:-"reserved"}
SRUN_ARGS=${SRUN_ARGS:-""}
# Output directory for the prediction .jsonl files (same default as the Python script)
OUT_DIR=${OUT_DIR:-results}

# NOTE: replace with the torchrun binary from your own environment
your/conda/path/torchrun \
  --nnodes=1 \
  --node_rank=0 \
  --master_addr=127.0.0.1 \
  --nproc_per_node=${GPUS} \
  --master_port=${MASTER_PORT} \
  internvl2_inference.py --checkpoint "${CHECKPOINT}" --datasets ${DATASET} \
  --dynamic \
  --max-num ${PATCH} \
  --ds_name_list ${TEST_SET_NAME} \
  --out-dir "${OUT_DIR}"

--------------------------------------------------------------------------------
/evaluation/android_world/README.md:
--------------------------------------------------------------------------------
# AndroidWorld Evaluation

For AndroidWorld, please refer to the official benchmark repository: https://github.com/google-research/android_world
--------------------------------------------------------------------------------
/evaluation/eval_json_files/android_control_test_subsplits.json:
--------------------------------------------------------------------------------
{"IDD": [4790, 8333, 1589, 14287, 2860, 16124, 19349, 17863, 880, 4537, 3506, 10986, 10235, 12166, 7838, 11512, 1489, 1719, 12303, 10558, 11410, 697, 18103, 2775, 13134, 5366, 16951, 14103, 14607, 16467, 19199, 2426, 13314, 1665, 1569, 3195, 8886, 10988, 14804, 4622, 13806, 14495, 4123, 12182, 4807, 6645, 19887, 6158, 18654, 14525, 16402, 2706, 13189, 17036, 1260, 8583, 10864, 3484, 6714, 2005, 11550, 19946, 9283, 15527, 7201, 6304, 7320, 6472, 18544, 3061, 3613, 18932, 16588, 10058, 1764, 16503, 8197, 12803, 18557, 5147, 3281, 11672, 7961, 12758, 14815, 2676, 470, 6240, 4045, 18875, 27, 6424,
3240, 15892, 4550, 3833, 14723, 18491, 5718, 9961, 11673, 16022, 2139, 9711, 17522, 16526, 11198, 15865, 9823, 2010, 18231, 16674, 14928, 19477, 19315, 7195, 17598, 7609, 2666, 1542, 11128, 10925, 1727, 12926, 13274, 20116, 5602, 640, 9837, 19666, 13674, 13459, 15204, 4564, 8064, 19602, 787, 4377, 2221, 19277, 1581, 15057, 6119, 4047, 6323, 4724, 16169, 13752, 15594, 971, 6708, 1089, 8717, 8484, 9195, 631, 16072, 13529, 7646, 16004, 5704, 15070, 15540, 2748, 15785, 19510, 7903, 13428, 17280, 11726, 2777, 6346, 6443, 907, 5002, 11848, 4782, 8262, 6532, 3476, 3566, 12338, 1969, 1821, 18038, 5278, 19382, 3124, 5072, 15055, 19817, 2802, 12538, 6871, 13546, 5205, 1082, 13805, 3176, 15718, 7194, 19833, 10844, 9747, 16527, 18703, 6236, 16036, 10944, 2471, 13228, 4745, 904, 12336, 9072, 3814, 15267, 1815, 19565, 1045, 18771, 17880, 10956, 3041, 9812, 13061, 7430, 16579, 16760, 15597, 8500, 15658, 5595, 14172, 4328, 15036, 2685, 777, 17611, 11549, 8730, 1933, 9144, 20016, 10292, 14129, 4015, 5407, 19971, 4936, 2165, 10443, 938, 14297, 17102, 610, 7382, 8202, 14779, 11306, 9755, 12076, 1280, 14076, 12614, 4797, 9143, 9120, 13217, 11277, 8864, 17334, 13248, 6760, 5996, 3256, 2754, 13036, 7119, 3128, 15032, 19088, 17452, 16464, 11800, 18310, 5316, 10081, 6041, 15100, 10167, 18521, 7472, 17317, 13385, 1438, 12069, 15020, 15832, 18237, 9835, 5836, 18513, 261, 2682, 6589, 19880, 15419, 18944, 140, 12863, 1959, 16837, 12163, 5215, 1632, 5241, 2673, 5791, 13441, 17981, 14639, 7456, 12391, 19551, 13440, 9168, 3898, 12594, 8074, 1394, 13256, 7829, 17033, 6759, 2124, 15653, 8963, 423, 11843, 55, 10232, 2912, 8748, 12829, 6153, 10294, 17169, 18723, 14514, 16208, 2078, 16441, 2612, 1741, 11222, 2588, 3945, 19585, 19394, 15225, 12678, 13436, 4796, 5897, 8726, 16514, 13754, 20122, 10273, 6707, 19877, 17994, 6404, 10476, 4699, 16353, 7958, 6402, 8271, 3347, 12743, 15966, 16669, 5264, 4833, 11342, 15349, 6207, 17842, 6048, 13182, 6408, 9938, 12473, 11632, 13975, 15749, 19268, 4413, 19719, 
15752, 8324, 3419, 11912, 8155, 10608, 2250, 3760, 1672, 14401, 794, 5956, 8629, 5906, 3031, 5837, 2406, 3708, 8537, 7748, 3319, 11300, 5455, 15017, 4680, 10726, 6623, 17479, 11260, 7860, 11714, 127, 10142, 12739, 229, 7866, 12390, 17422, 7399, 14343, 3292, 5574, 14212, 13554, 18181, 3000, 2630, 5902, 9783, 394, 13970, 2936, 11450, 3496, 11035, 3747, 7993, 11494, 14776, 9467, 15241, 12577, 10183, 3948, 10690, 17355, 14358, 9698, 1892, 5385, 4976, 4943, 1330, 18124, 4756, 719, 18574, 16140, 2302, 7821, 13613, 6606, 19680, 13745, 6211, 2822, 16860, 7633, 17085, 6353, 16452, 2387, 17111, 13522, 1904, 7806, 4501, 19813, 18185, 13349, 9934, 19653, 3373, 1381, 2158, 18099, 11404, 10323, 3806, 15405, 1895, 8908, 10037, 5076, 759, 14307, 10318, 5693, 12642, 916, 13468, 12508, 8216, 8282, 1099, 6689, 4090, 11040, 15350, 13429, 11217, 18991, 11682, 5563, 1027, 12187, 18938, 936, 3026, 7587, 8740, 6932, 13018, 14109, 9112, 3425, 7774, 15348, 13841, 14178, 19928, 18950, 3738, 6919, 7100, 8804, 19552, 3241, 10241, 14319, 15583, 2535, 387, 19190, 4876, 11967, 12978, 2327, 11274, 12429, 20097, 16712, 19460, 7548, 15999, 10531, 6876, 8084, 18488, 9593, 7249, 7455, 8709, 17840, 15167, 14940, 19033, 4998, 5615, 12479, 5611, 2121, 7538, 14768, 19242, 12243, 2771, 14942, 12334, 16768, 10501, 943, 8651, 16608, 17114, 4437, 14320, 4692, 15316, 10710, 2592, 3250, 8389, 84, 15908, 19339, 11520, 10699, 18679, 4603, 19537, 1306, 15778, 4902, 14953, 12364, 3338, 10265, 16154, 15537, 2862, 9128, 19676, 15857, 8899, 9231, 7765, 4159, 14157, 13733, 2743, 7239, 18973, 3317, 4190, 12289, 54, 15646, 237, 5990, 3507, 7955, 16696, 19251, 17831, 10571, 13526, 14537, 14230, 1194, 1019, 7869, 2269, 3438, 407, 20055, 3100, 4358, 17953, 2488, 2626, 10724, 6670, 15703, 5200, 9893, 6904, 2921, 2553, 18078, 7889, 2093, 11885, 16663, 4988, 13846, 12498, 15476, 12823, 11448, 6121, 20086, 11042, 9719, 7884, 13596, 300, 13697, 13800, 17237, 19170, 3191, 803, 4640, 12709, 14114, 19657, 15386, 11412, 15634, 
18894, 11900, 12248, 6779, 10786, 5676, 9767, 1624, 8697, 13686], "category_unseen": [53, 62, 73, 104, 122, 136, 167, 172, 185, 190, 219, 227, 228, 231, 270, 282, 296, 318, 382, 406, 417, 495, 548, 567, 569, 575, 586, 591, 636, 660, 662, 672, 674, 723, 724, 727, 828, 886, 905, 937, 945, 1024, 1049, 1074, 1168, 1170, 1193, 1198, 1203, 1206, 1293, 1300, 1309, 1311, 1329, 1336, 1344, 1359, 1360, 1364, 1378, 1458, 1464, 1474, 1480, 1501, 1654, 1704, 1736, 1809, 1811, 1819, 1820, 1836, 1842, 1908, 2060, 2079, 2086, 2098, 2122, 2146, 2174, 2195, 2200, 2201, 2232, 2288, 2405, 2412, 2473, 2484, 2507, 2544, 2582, 2609, 2644, 2669, 2697, 2814, 2836, 2856, 2886, 2889, 2916, 2961, 2991, 3007, 3010, 3023, 3034, 3099, 3113, 3148, 3152, 3159, 3198, 3224, 3238, 3242, 3297, 3380, 3449, 3519, 3530, 3559, 3591, 3592, 3700, 3715, 3751, 3761, 3792, 3793, 3812, 3845, 3849, 3875, 4093, 4158, 4178, 4183, 4194, 4204, 4216, 4226, 4244, 4260, 4297, 4299, 4330, 4336, 4348, 4406, 4414, 4451, 4491, 4535, 4545, 4560, 4579, 4581, 4608, 4618, 4635, 4657, 4670, 4674, 4740, 4803, 4812, 4858, 4885, 4889, 4964, 4977, 4979, 5016, 5019, 5029, 5122, 5148, 5154, 5185, 5199, 5234, 5256, 5271, 5276, 5288, 5312, 5332, 5342, 5355, 5373, 5431, 5446, 5510, 5564, 5636, 5673, 5677, 5684, 5691, 5707, 5732, 5741, 5742, 5776, 5794, 5806, 5850, 5852, 5899, 5918, 5926, 5932, 5961, 5968, 5973, 6003, 6045, 6046, 6081, 6147, 6177, 6178, 6181, 6258, 6271, 6328, 6332, 6338, 6405, 6413, 6418, 6430, 6445, 6446, 6449, 6463, 6471, 6505, 6541, 6555, 6562, 6574, 6578, 6638, 6640, 6666, 6681, 6683, 6710, 6792, 6809, 6844, 6888, 6903, 6917, 6930, 6934, 6979, 6996, 7014, 7020, 7030, 7032, 7101, 7184, 7205, 7247, 7298, 7329, 7337, 7499, 7509, 7565, 7585, 7607, 7608, 7614, 7643, 7682, 7733, 7753, 7826, 7827, 7878, 7917, 7969, 7999, 8044, 8049, 8146, 8158, 8165, 8188, 8193, 8248, 8261, 8307, 8336, 8455, 8488, 8489, 8501, 8506, 8532, 8579, 8658, 8703, 8725, 8738, 8770, 8780, 8795, 8826, 8871, 8885, 8898, 8914, 8917, 8939, 8982, 8985, 
8997, 8998, 9074, 9131, 9208, 9298, 9301, 9314, 9318, 9320, 9335, 9357, 9373, 9401, 9420, 9431, 9452, 9464, 9466, 9476, 9547, 9588, 9609, 9683, 9721, 9722, 9753, 9788, 9793, 9804, 9806, 9921, 9935, 9963, 10033, 10053, 10062, 10067, 10089, 10095, 10098, 10137, 10140, 10206, 10223, 10255, 10319, 10366, 10441, 10464, 10520, 10527, 10559, 10590, 10612, 10672, 10693, 10711, 10738, 10739, 10807, 10809, 10841, 10911, 10912, 10928, 10966, 11014, 11028, 11091, 11132, 11135, 11159, 11185, 11187, 11218, 11234, 11239, 11242, 11294, 11296, 11305, 11326, 11373, 11379, 11390, 11417, 11429, 11454, 11488, 11490, 11501, 11551, 11552, 11559, 11571, 11595, 11622, 11623, 11695, 11705, 11728, 11777, 11816, 11829, 11831, 11840, 11847, 11852, 11870, 11890, 11926, 11935, 11946, 11968, 12015, 12045, 12123, 12176, 12217, 12223, 12245, 12321, 12356, 12434, 12483, 12499, 12514, 12546, 12571, 12581, 12624, 12633, 12648, 12649, 12736, 12816, 12856, 12885, 12901, 12904, 12933, 12946, 12962, 12965, 13019, 13070, 13072, 13108, 13126, 13176, 13203, 13247, 13307, 13350, 13355, 13367, 13405, 13437, 13470, 13490, 13497, 13498, 13502, 13512, 13560, 13574, 13605, 13619, 13651, 13659, 13668, 13676, 13688, 13704, 13731, 13746, 13885, 13886, 13887, 13934, 13992, 14027, 14107, 14154, 14219, 14220, 14251, 14259, 14265, 14269, 14310, 14321, 14339, 14345, 14390, 14395, 14403, 14454, 14493, 14523, 14554, 14562, 14620, 14628, 14651, 14654, 14670, 14681, 14735, 14764, 14774, 14787, 14843, 14902, 14981, 14985, 14989, 14998, 15048, 15063, 15076, 15149, 15168, 15344, 15402, 15425, 15442, 15447, 15457, 15485, 15518, 15541, 15613, 15690, 15746, 15768, 15801, 15802, 15814, 15843, 15886, 15890, 15975, 15979, 15981, 15993, 16002, 16035, 16101, 16128, 16178, 16212, 16223, 16228, 16259, 16275, 16321, 16379, 16426, 16454, 16491, 16501, 16532, 16562, 16584, 16593, 16656, 16668, 16671, 16679, 16703, 16705, 16787, 16805, 16829, 16830, 16874, 16924, 17082, 17098, 17127, 17136, 17180, 17181, 17190, 17207, 17271, 17293, 17325, 
17402, 17410, 17444, 17481, 17506, 17510, 17517, 17533, 17538, 17563, 17588, 17621, 17628, 17629, 17647, 17661, 17677, 17712, 17748, 17807, 17824, 17841, 17853, 17929, 17959, 17977, 18046, 18066, 18068, 18087, 18141, 18173, 18245, 18298, 18302, 18306, 18335, 18340, 18374, 18438, 18444, 18577, 18598, 18619, 18643, 18645, 18661, 18673, 18741, 18773, 18800, 18828, 18883, 18963, 18983, 19003, 19038, 19189, 19197, 19269, 19300, 19347, 19375, 19461, 19572, 19593, 19636, 19660, 19697, 19721, 19725, 19766, 19805, 19809, 19812, 19846, 19897, 19906, 19927, 19939, 19956, 19979, 20131, 20189], "app_unseen": [53, 62, 73, 104, 122, 167, 172, 185, 190, 216, 219, 227, 228, 231, 270, 282, 296, 318, 348, 382, 406, 417, 495, 548, 567, 569, 575, 586, 636, 660, 672, 723, 724, 886, 905, 937, 945, 1024, 1049, 1074, 1170, 1193, 1198, 1203, 1206, 1293, 1300, 1302, 1309, 1311, 1329, 1359, 1360, 1364, 1378, 1458, 1464, 1474, 1480, 1501, 1654, 1693, 1704, 1736, 1809, 1811, 1819, 1820, 1836, 1842, 1908, 2060, 2079, 2086, 2098, 2122, 2146, 2174, 2195, 2200, 2201, 2232, 2288, 2405, 2412, 2473, 2484, 2507, 2544, 2609, 2644, 2669, 2697, 2814, 2856, 2886, 2889, 2916, 2961, 2991, 3007, 3010, 3023, 3034, 3099, 3113, 3148, 3152, 3159, 3198, 3224, 3238, 3242, 3297, 3380, 3449, 3519, 3559, 3592, 3613, 3700, 3715, 3761, 3792, 3793, 3812, 3845, 3849, 3875, 4093, 4158, 4178, 4194, 4226, 4244, 4260, 4297, 4299, 4330, 4348, 4406, 4414, 4451, 4491, 4535, 4560, 4579, 4581, 4608, 4618, 4635, 4657, 4674, 4740, 4803, 4812, 4858, 4885, 4889, 4964, 4977, 4979, 5016, 5019, 5029, 5122, 5148, 5185, 5199, 5234, 5235, 5256, 5271, 5276, 5312, 5355, 5373, 5446, 5564, 5636, 5673, 5676, 5677, 5680, 5684, 5691, 5707, 5732, 5741, 5742, 5776, 5794, 5806, 5850, 5852, 5899, 5918, 5926, 5932, 5961, 6003, 6045, 6046, 6081, 6147, 6178, 6258, 6271, 6332, 6338, 6405, 6413, 6418, 6430, 6445, 6446, 6449, 6471, 6505, 6541, 6555, 6562, 6574, 6578, 6638, 6640, 6666, 6681, 6683, 6710, 6809, 6844, 6888, 6903, 6917, 6930, 6934, 6979, 6996, 
7014, 7020, 7030, 7101, 7184, 7205, 7247, 7298, 7329, 7334, 7499, 7509, 7565, 7607, 7608, 7614, 7643, 7682, 7705, 7733, 7753, 7826, 7878, 7917, 7999, 8146, 8158, 8165, 8188, 8193, 8248, 8261, 8307, 8336, 8455, 8488, 8489, 8501, 8506, 8532, 8579, 8658, 8703, 8738, 8770, 8780, 8795, 8826, 8871, 8885, 8898, 8914, 8917, 8939, 8982, 8997, 8998, 9074, 9131, 9298, 9301, 9314, 9318, 9320, 9335, 9357, 9373, 9401, 9420, 9431, 9452, 9464, 9466, 9588, 9609, 9683, 9721, 9788, 9793, 9804, 9806, 9921, 9935, 9963, 10033, 10053, 10067, 10089, 10098, 10137, 10147, 10206, 10223, 10255, 10319, 10384, 10441, 10464, 10520, 10527, 10559, 10590, 10612, 10693, 10711, 10807, 10809, 10841, 10911, 10912, 10928, 10966, 11014, 11028, 11077, 11132, 11135, 11159, 11185, 11187, 11218, 11234, 11239, 11242, 11294, 11296, 11305, 11326, 11332, 11373, 11377, 11379, 11390, 11417, 11429, 11454, 11488, 11490, 11501, 11551, 11552, 11559, 11571, 11595, 11622, 11623, 11695, 11705, 11728, 11777, 11816, 11829, 11831, 11840, 11847, 11852, 11870, 11890, 11935, 11946, 11968, 12015, 12045, 12123, 12176, 12223, 12321, 12356, 12434, 12483, 12499, 12514, 12546, 12571, 12581, 12624, 12633, 12649, 12736, 12816, 12856, 12885, 12901, 12933, 12962, 12965, 13019, 13070, 13072, 13108, 13126, 13176, 13307, 13327, 13350, 13355, 13367, 13405, 13437, 13490, 13498, 13502, 13512, 13560, 13574, 13605, 13619, 13651, 13659, 13668, 13676, 13688, 13704, 13746, 13885, 13886, 13887, 13934, 13992, 14027, 14107, 14154, 14219, 14220, 14259, 14265, 14269, 14296, 14310, 14321, 14339, 14390, 14395, 14403, 14454, 14493, 14523, 14562, 14620, 14628, 14651, 14670, 14681, 14735, 14764, 14774, 14787, 14843, 14902, 14928, 14981, 14985, 14989, 14998, 15048, 15076, 15149, 15344, 15402, 15425, 15442, 15447, 15457, 15485, 15613, 15746, 15775, 15801, 15802, 15814, 15843, 15886, 15890, 15935, 15975, 15979, 15981, 15993, 16002, 16035, 16101, 16128, 16178, 16212, 16223, 16228, 16259, 16275, 16321, 16379, 16454, 16491, 16501, 16532, 16584, 16593, 16656, 
16668, 16671, 16679, 16703, 16705, 16787, 16805, 16829, 16830, 16874, 16957, 17098, 17127, 17136, 17154, 17180, 17181, 17190, 17207, 17293, 17325, 17402, 17410, 17481, 17510, 17517, 17533, 17538, 17563, 17588, 17621, 17628, 17629, 17661, 17677, 17748, 17807, 17824, 17841, 17853, 17929, 17959, 17977, 18046, 18066, 18068, 18087, 18141, 18173, 18245, 18302, 18335, 18340, 18351, 18374, 18438, 18444, 18577, 18598, 18619, 18643, 18645, 18661, 18673, 18741, 18828, 18883, 18963, 18983, 19003, 19038, 19189, 19197, 19269, 19300, 19347, 19366, 19375, 19461, 19572, 19593, 19697, 19721, 19725, 19805, 19809, 19812, 19846, 19897, 19906, 19939, 19956, 20109, 20131, 20189], "task_unseen": [45, 53, 62, 73, 104, 122, 136, 167, 172, 185, 190, 219, 227, 228, 231, 270, 282, 296, 318, 382, 406, 417, 480, 495, 548, 567, 569, 575, 586, 591, 636, 660, 662, 672, 674, 723, 724, 727, 828, 841, 886, 892, 898, 905, 937, 945, 1024, 1049, 1074, 1168, 1170, 1193, 1198, 1203, 1206, 1293, 1300, 1309, 1311, 1329, 1336, 1344, 1359, 1360, 1364, 1378, 1458, 1464, 1474, 1480, 1482, 1501, 1512, 1586, 1654, 1704, 1736, 1809, 1811, 1819, 1820, 1836, 1842, 1853, 1908, 2060, 2079, 2086, 2098, 2122, 2146, 2158, 2174, 2195, 2200, 2201, 2232, 2288, 2313, 2405, 2412, 2442, 2473, 2484, 2507, 2527, 2544, 2582, 2609, 2644, 2669, 2697, 2757, 2814, 2836, 2856, 2886, 2889, 2916, 2947, 2961, 2991, 3007, 3010, 3023, 3034, 3099, 3113, 3148, 3152, 3159, 3198, 3224, 3238, 3242, 3297, 3315, 3380, 3449, 3493, 3519, 3530, 3559, 3591, 3592, 3700, 3715, 3751, 3761, 3770, 3792, 3793, 3812, 3845, 3849, 3875, 4093, 4114, 4158, 4178, 4183, 4194, 4195, 4204, 4216, 4226, 4244, 4247, 4260, 4297, 4299, 4330, 4336, 4348, 4406, 4414, 4422, 4451, 4491, 4516, 4535, 4545, 4560, 4579, 4581, 4602, 4608, 4618, 4635, 4657, 4662, 4670, 4674, 4740, 4803, 4806, 4812, 4858, 4881, 4883, 4885, 4889, 4916, 4964, 4977, 4979, 5016, 5019, 5025, 5029, 5122, 5148, 5154, 5185, 5199, 5234, 5251, 5256, 5271, 5276, 5288, 5312, 5332, 5342, 5355, 5373, 5431, 5446, 
5459, 5497, 5510, 5564, 5636, 5664, 5673, 5677, 5684, 5691, 5707, 5732, 5741, 5742, 5773, 5776, 5779, 5794, 5806, 5850, 5852, 5899, 5918, 5926, 5932, 5958, 5961, 5968, 5973, 6003, 6045, 6046, 6081, 6147, 6177, 6178, 6181, 6239, 6258, 6271, 6312, 6328, 6332, 6338, 6385, 6405, 6413, 6418, 6430, 6445, 6446, 6449, 6463, 6471, 6505, 6541, 6555, 6562, 6574, 6578, 6638, 6640, 6666, 6681, 6683, 6710, 6714, 6768, 6792, 6809, 6844, 6888, 6903, 6917, 6930, 6934, 6944, 6979, 6996, 7014, 7020, 7030, 7032, 7045, 7101, 7184, 7204, 7205, 7247, 7271, 7298, 7329, 7337, 7356, 7379, 7499, 7509, 7565, 7585, 7607, 7608, 7614, 7643, 7644, 7682, 7733, 7753, 7767, 7826, 7827, 7878, 7917, 7935, 7969, 7999, 8044, 8049, 8141, 8146, 8158, 8165, 8188, 8193, 8242, 8248, 8261, 8307, 8336, 8455, 8488, 8489, 8501, 8506, 8532, 8579, 8658, 8703, 8725, 8738, 8770, 8780, 8795, 8826, 8871, 8885, 8898, 8914, 8917, 8939, 8982, 8985, 8997, 8998, 9062, 9074, 9131, 9208, 9298, 9301, 9307, 9314, 9318, 9320, 9335, 9357, 9373, 9376, 9394, 9401, 9420, 9431, 9452, 9464, 9466, 9476, 9547, 9588, 9609, 9628, 9634, 9683, 9721, 9722, 9753, 9788, 9793, 9804, 9806, 9921, 9935, 9963, 10033, 10053, 10062, 10067, 10089, 10095, 10098, 10137, 10140, 10206, 10223, 10255, 10319, 10366, 10441, 10464, 10520, 10527, 10554, 10559, 10590, 10612, 10672, 10693, 10711, 10738, 10739, 10807, 10809, 10841, 10911, 10912, 10928, 10966, 11014, 11028, 11091, 11132, 11135, 11159, 11185, 11187, 11218, 11234, 11239, 11242, 11294, 11296, 11304, 11305, 11326, 11373, 11379, 11390, 11417, 11429, 11454, 11482, 11488, 11490, 11501, 11551, 11552, 11559, 11571, 11595, 11622, 11623, 11695, 11705, 11728, 11777, 11816, 11829, 11831, 11840, 11847, 11852, 11870, 11890, 11896, 11926, 11935, 11946, 11968, 11972, 12015, 12045, 12057, 12123, 12176, 12217, 12223, 12227, 12245, 12321, 12356, 12434, 12483, 12499, 12514, 12546, 12571, 12581, 12624, 12625, 12633, 12648, 12649, 12653, 12736, 12754, 12816, 12856, 12885, 12901, 12904, 12933, 12946, 12962, 12965, 13019, 
13070, 13072, 13089, 13108, 13126, 13176, 13203, 13247, 13307, 13327, 13350, 13355, 13367, 13390, 13405, 13437, 13470, 13490, 13497, 13498, 13502, 13512, 13560, 13574, 13605, 13619, 13651, 13659, 13668, 13669, 13676, 13688, 13704, 13731, 13735, 13746, 13821, 13838, 13885, 13886, 13887, 13934, 13992, 14027, 14107, 14154, 14219, 14220, 14251, 14259, 14265, 14269, 14310, 14321, 14339, 14345, 14390, 14395, 14403, 14454, 14493, 14523, 14554, 14556, 14562, 14620, 14628, 14651, 14654, 14670, 14681, 14735, 14764, 14774, 14787, 14843, 14902, 14981, 14985, 14989, 14998, 15048, 15063, 15076, 15149, 15168, 15244, 15252, 15283, 15344, 15402, 15409, 15425, 15438, 15442, 15447, 15457, 15485, 15518, 15541, 15591, 15613, 15625, 15690, 15746, 15768, 15801, 15802, 15814, 15843, 15886, 15890, 15968, 15975, 15979, 15981, 15993, 16002, 16035, 16101, 16128, 16178, 16212, 16223, 16228, 16259, 16275, 16321, 16324, 16379, 16426, 16454, 16491, 16501, 16532, 16562, 16584, 16593, 16614, 16656, 16668, 16671, 16679, 16703, 16705, 16787, 16805, 16810, 16829, 16830, 16853, 16874, 16924, 16979, 17037, 17065, 17082, 17098, 17127, 17136, 17180, 17181, 17190, 17207, 17258, 17271, 17293, 17325, 17402, 17410, 17444, 17457, 17481, 17506, 17510, 17517, 17533, 17538, 17563, 17588, 17621, 17628, 17629, 17647, 17661, 17677, 17712, 17748, 17807, 17824, 17841, 17853, 17900, 17929, 17959, 17969, 17977, 18046, 18066, 18068, 18087, 18141, 18173, 18182, 18245, 18298, 18302, 18306, 18335, 18340, 18360, 18374, 18438, 18444, 18577, 18598, 18619, 18643, 18645, 18661, 18673, 18741, 18773, 18800, 18828, 18883, 18963, 18983, 19003, 19012, 19019, 19038, 19061, 19189, 19197, 19269, 19300, 19347, 19375, 19453, 19461, 19572, 19593, 19636, 19660, 19697, 19721, 19725, 19766, 19805, 19809, 19812, 19846, 19897, 19906, 19927, 19939, 19956, 19979, 20033, 20131, 20189]} -------------------------------------------------------------------------------- /faq.md: 
--------------------------------------------------------------------------------
# FAQ

Thank you for your interest in OS-Genesis. Below are some questions we have collected from email, Hugging Face, and WeChat. We hope you find them helpful.

## When will the checkpoints and data be available?

We have already uploaded the checkpoints and evaluation code for AndroidControl. ~~The remaining checkpoints will be uploaded in the coming days~~ (done). Due to server bandwidth limitations, uploads may take some time. The data will also be available shortly.


## What About Desktop?

Q: Why haven't you worked on PC/desktop data? Is there a particular reason?

A:
We originally intended to cover PC, mobile, and web. In fact, our high-level reverse task synthesis process can also run on PC (we used [OSWorld](https://os-world.github.io/) as the dynamic environment). However, we decided not to continue on the PC side for the following reasons:
1. Data collection on PC is too difficult for a model-based approach.
   For instance, in [OSWorld](https://os-world.github.io/), GPT-4o's success rate in most scenarios is below 10%, so the proportion of high-quality trajectories would be low. Ensuring quality would require a massive amount of data and a more rigorous trajectory reward model (TRM), making it costly.

2. Even after collecting trajectories, there are significant challenges in training:
   1. Length of the a11ytree:
      We use the accessibility tree (a11ytree), and on desktop it is much longer than the mobile a11ytree or the web DOM. In training that also involves multimodal inputs, it exceeds the context window of models like InternVL and Qwen.
   2. Instruction-following issues:
      Currently, open-source VLMs have major problems following instructions in PC environments.
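
As a rough illustration of point 1 (a back-of-the-envelope sketch, not a measurement from OS-Genesis): if each model-driven rollout independently yields a usable trajectory with probability p, collecting N usable trajectories takes N / p rollouts on average, so a success rate below 10% means more than ten rollouts per usable trajectory. The function name and the target of 1,000 trajectories are hypothetical.

```python
def expected_rollouts(n_target: int, success_rate: float) -> int:
    """Average number of rollouts needed to collect n_target successful trajectories,
    treating each rollout as an independent trial that succeeds with probability success_rate."""
    if not 0.0 < success_rate <= 1.0:
        raise ValueError("success_rate must be in (0, 1]")
    # Trials until the k-th success follow a negative binomial distribution with mean k / p
    return round(n_target / success_rate)


if __name__ == "__main__":
    # Hypothetical target of 1,000 usable trajectories at a few illustrative success rates
    for p in (0.5, 0.1, 0.05):
        print(f"success rate {p:.0%}: ~{expected_rollouts(1000, p):,} rollouts")
```

This ignores the extra filtering a stricter TRM would add, which only increases the number of rollouts needed.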
--------------------------------------------------------------------------------
/static/OS-Genesis-Badge.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OS-Copilot/OS-Genesis/9ecbe594352f254b9a9228468f9ca9a77b2388a2/static/OS-Genesis-Badge.png
--------------------------------------------------------------------------------
/static/OS-Genesis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OS-Copilot/OS-Genesis/9ecbe594352f254b9a9228468f9ca9a77b2388a2/static/OS-Genesis.png
--------------------------------------------------------------------------------
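
The `android_control_test_subsplits.json` file above maps four subsplit names (`IDD`, `category_unseen`, `app_unseen`, `task_unseen`) to lists of integer sample indices, and the unseen subsplits overlap (index 53, for instance, appears in all three). The sketch below shows one way to report per-subsplit accuracy from such a file. It assumes, hypothetically, that you already have a `correct_by_index` dict mapping those same integers to a per-sample correctness flag; how the indices line up with the rows of `android_control_test_data.jsonl` and with `ac_eval.py` is not shown here, so verify that mapping first. The helper names are illustrative and not part of the repository.

```python
import json
from collections import defaultdict
from typing import Dict, List


def load_subsplits(path: str) -> Dict[str, List[int]]:
    """Read a JSON object mapping subsplit name -> list of integer sample indices."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def memberships(subsplits: Dict[str, List[int]]) -> Dict[int, List[str]]:
    """Map each index to the sorted names of every subsplit that contains it."""
    member = defaultdict(list)
    for name, indices in subsplits.items():
        for i in indices:
            member[i].append(name)
    return {i: sorted(names) for i, names in member.items()}


def subsplit_accuracy(correct_by_index: Dict[int, bool],
                      subsplits: Dict[str, List[int]]) -> Dict[str, float]:
    """Fraction of correct predictions per subsplit.

    ASSUMPTION (hypothetical input): correct_by_index maps the same integer indices
    used in the subsplit lists to True/False. Indices with no prediction are skipped;
    a subsplit with no scored samples gets NaN.
    """
    result = {}
    for name, indices in subsplits.items():
        scored = [correct_by_index[i] for i in indices if i in correct_by_index]
        result[name] = sum(scored) / len(scored) if scored else float("nan")
    return result


if __name__ == "__main__":
    # Toy data with the same shape as android_control_test_subsplits.json
    toy = {"IDD": [0, 1, 2], "app_unseen": [3, 4], "task_unseen": [3, 4, 5]}
    correct = {0: True, 1: False, 2: True, 3: False, 4: True, 5: True}
    print(memberships(toy)[3])  # index 3 is in both unseen subsplits
    for name, acc in subsplit_accuracy(correct, toy).items():
        print(f"{name}: {acc:.3f}")
```

Because a sample can belong to several unseen subsplits, the per-subsplit numbers should not be summed or averaged as if the subsplits were disjoint.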