├── .DS_Store
├── README.md
├── collection
│   ├── README.md
│   ├── genesis_rm.py
│   ├── mobile_runner.py
│   ├── random_walk_aw.py
│   ├── random_walk_web.py
│   ├── run_mobile_runner.py
│   └── run_random_walk_aw.py
├── evaluation
│   ├── .DS_Store
│   ├── analysis
│   │   └── diversity_analysis.ipynb
│   ├── android_control
│   │   ├── ac_eval.py
│   │   ├── internvl2_inference.py
│   │   ├── qwen2vl_inference.py
│   │   └── run_ac_inference.sh
│   ├── android_world
│   │   └── README.md
│   └── eval_json_files
│       ├── ac_high_processing.jsonl
│       ├── ac_low_processing.jsonl
│       ├── android_control_test_data.jsonl
│       └── android_control_test_subsplits.json
├── faq.md
└── static
    ├── OS-Genesis-Badge.png
    └── OS-Genesis.png
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OS-Copilot/OS-Genesis/9ecbe594352f254b9a9228468f9ca9a77b2388a2/.DS_Store
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # OS-Genesis
2 |
3 |
4 |
5 |
6 |
7 | [arXiv](https://arxiv.org/abs/2412.19723)
8 |
9 | [Hugging Face Papers](https://huggingface.co/papers/2412.19723)
10 | [WeChat Article](https://mp.weixin.qq.com/s/_gu3NSCpAbAE1A8mEhGD7Q)
11 |
12 |
15 |
16 |
17 | This repository contains the code and data for the ACL 2025 paper [OS-Genesis: Automating GUI Agent Trajectory Construction via Reverse Task Synthesis](https://arxiv.org/abs/2412.19723).
18 | > We are uploading the data and checkpoints. Due to bandwidth limitations, this will take some time. Stay tuned!
19 |
20 |
21 | ## Overview
22 |
23 | We introduce OS-Genesis, an interaction-driven pipeline for synthesizing high-quality and diverse GUI agent trajectory data without human supervision or predefined tasks. By leveraging reverse task synthesis and a trajectory reward model, OS-Genesis enables effective end-to-end training of GUI agents.
24 |
25 |
26 |
27 |
28 |
29 |
30 | ## Training
31 |
32 | For training details, please refer to the [InternVL2 documentation](https://internvl.readthedocs.io/en/latest/get_started/installation.html) and the [Qwen2-VL repository](https://github.com/QwenLM/Qwen2-VL).
33 |
34 | ## Evaluation
35 | ### AndroidControl
36 | To evaluate on the AndroidControl benchmark, follow the steps below:
37 |
38 | 1. **Clone the GitHub Repository:**
39 |
40 | ```
41 | git clone https://github.com/OS-Copilot/OS-Genesis.git
42 | ```
43 |
44 | 2. **Inference:**
45 | ```
46 | cd OS-Genesis/evaluation/android_control
47 | bash run_ac_inference.sh $dataset $checkpoint
48 | ```
49 |
50 | 3. **Evaluation:**
51 | ```
52 | python ac_eval.py
53 | ```
54 |
55 | ## Mobile
56 | ### AndroidControl
57 |
58 | | Model Name | Base Model | Training Data | HF Link |
59 | | :-------------: | :-------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------: | :---------------------------------------------------------: |
60 | | OS-Genesis-4B-AC | [InternVL2-4B](https://huggingface.co/OpenGVLab/InternVL2-4B) | [OS-Genesis-ac-training-data](https://huggingface.co/datasets/OS-Copilot/OS-Genesis-mobile-data/blob/main/os_genesis_ac_training_data.jsonl) | [🤗 link](https://huggingface.co/OS-Copilot/OS-Genesis-4B-AC) |
61 | | OS-Genesis-7B-AC | [Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) | [OS-Genesis-ac-training-data](https://huggingface.co/datasets/OS-Copilot/OS-Genesis-mobile-data/blob/main/os_genesis_ac_training_data.jsonl) | [🤗 link](https://huggingface.co/OS-Copilot/OS-Genesis-7B-AC) |
62 | | OS-Genesis-8B-AC | [InternVL2-8B](https://huggingface.co/OpenGVLab/InternVL2-8B) | [OS-Genesis-ac-training-data](https://huggingface.co/datasets/OS-Copilot/OS-Genesis-mobile-data/blob/main/os_genesis_ac_training_data.jsonl) | [🤗 link](https://huggingface.co/OS-Copilot/OS-Genesis-8B-AC) |
63 |
64 | ### AndroidWorld
65 |
66 | | Model Name | Base Model | Training Data | HF Link |
67 | | :-------------: | :-------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------: | :---------------------------------------------------------: |
68 | | OS-Genesis-4B-AW | [InternVL2-4B](https://huggingface.co/OpenGVLab/InternVL2-4B) | [OS-Genesis-aw-training-data](https://huggingface.co/datasets/OS-Copilot/OS-Genesis-mobile-data/blob/main/os_genesis_aw_training_data.jsonl) | [🤗 link](https://huggingface.co/OS-Copilot/OS-Genesis-4B-AW) |
69 | | OS-Genesis-7B-AW | [Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) | [OS-Genesis-aw-training-data](https://huggingface.co/datasets/OS-Copilot/OS-Genesis-mobile-data/blob/main/os_genesis_aw_training_data.jsonl) | [🤗 link](https://huggingface.co/OS-Copilot/OS-Genesis-7B-AW) |
70 | | OS-Genesis-8B-AW | [InternVL2-8B](https://huggingface.co/OpenGVLab/InternVL2-8B) | [OS-Genesis-aw-training-data](https://huggingface.co/datasets/OS-Copilot/OS-Genesis-mobile-data/blob/main/os_genesis_aw_training_data.jsonl) | [🤗 link](https://huggingface.co/OS-Copilot/OS-Genesis-8B-AW) |
71 |
72 |
73 | ## Web
74 |
75 | | Model Name | Base Model | Training Data | HF Link |
76 | | :-------------: | :-------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------: | :---------------------------------------------------------: |
77 | | OS-Genesis-4B-WA | [InternVL2-4B](https://huggingface.co/OpenGVLab/InternVL2-4B) | [OS-Genesis-web-training-data](https://huggingface.co/datasets/OS-Copilot/OS-Genesis-web-data/blob/main/os_genesis_web_training.jsonl) | [🤗 link](https://huggingface.co/OS-Copilot/OS-Genesis-4B-WA) |
78 | | OS-Genesis-7B-WA | [Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct) | [OS-Genesis-web-training-data](https://huggingface.co/datasets/OS-Copilot/OS-Genesis-web-data/blob/main/os_genesis_web_training.jsonl) | [🤗 link](https://huggingface.co/OS-Copilot/OS-Genesis-7B-WA) |
79 | | OS-Genesis-8B-WA | [InternVL2-8B](https://huggingface.co/OpenGVLab/InternVL2-8B) | [OS-Genesis-web-training-data](https://huggingface.co/datasets/OS-Copilot/OS-Genesis-web-data/blob/main/os_genesis_web_training.jsonl) | [🤗 link](https://huggingface.co/OS-Copilot/OS-Genesis-8B-WA) |
80 |
81 |
82 | ## More Resources
83 |
84 | ### Raw collected triples
85 |
86 | In addition to our complete trajectory data on Hugging Face, we also provide the raw interaction triples we collected. You can use them to reproduce reverse task synthesis directly, without re-collecting them from emulators yourself 😄. The screenshots and the corresponding text data (including SoM information) are provided below:
87 |
88 | | Data Type | Screenshots | Data JSON |
89 | | :-------------: | :-------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------: |
90 | | Mobile | [Screenshots](https://drive.google.com/file/d/1ILyz_-DDOdAk32kue1lEPaV50YzQ5c4v/view?usp=sharing) | [Data JSON](https://drive.google.com/file/d/1dSxNf-co4LGh93NoiUgWKdbcf8Mo_VWG/view?usp=sharing) |
91 | | Web | [Screenshots](https://drive.google.com/file/d/1X2QktZ51OUofZ43vDGB4RuAPlXbdf5ua/view?usp=sharing) | [Data JSON](https://drive.google.com/file/d/1mDxhonGnd3wZbNQgWMVpYEkPW26_FVg8/view?usp=sharing) |
92 |
93 | Feel free to email me if you require additional data of this kind.
94 |
95 | ## FAQ ❓
96 |
97 | We have collected some questions from emails, Hugging Face, and WeChat communications. Please check the [FAQ](https://github.com/OS-Copilot/OS-Genesis/blob/main/faq.md) 🤖
98 |
99 | ## Citation 📖
100 |
101 | 🫶 If you are interested in our work or find this repository / our data helpful, please consider using the following citation format when referencing our paper:
102 |
103 | ```bibtex
104 | @article{sun2024genesis,
105 | title={OS-Genesis: Automating GUI Agent Trajectory Construction via Reverse Task Synthesis},
106 | author={Sun, Qiushi and Cheng, Kanzhi and Ding, Zichen and Jin, Chuanyang and Wang, Yian and Xu, Fangzhi and Wu, Zhenyu and Jia, Chengyou and Chen, Liheng and Liu, Zhoumianze and others},
107 | journal={arXiv preprint arXiv:2412.19723},
108 | year={2024}
109 | }
110 | ```
111 |
--------------------------------------------------------------------------------
/collection/README.md:
--------------------------------------------------------------------------------
1 | # README
2 |
3 | In addition to the collected and synthesized data, we provide the scripts we used to collect data from the environment for reverse task synthesis, to help you extend OS-Genesis to more scenarios or synthesize more data as desired.
4 |
5 | # Mobile
6 |
7 | ## Random walk in the AndroidWorld Environment
8 | First, clone the [AndroidWorld](https://github.com/google-research/android_world) repository, then place `random_walk_aw.py` and `run_random_walk_aw.py` in its directory.
9 |
10 | `random_walk_aw.py` implements our random-walk logic for collecting interaction triples from the environment. Run `python run_random_walk_aw.py` to collect interaction triples at scale.
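The docstring of `random_walk_aw.py` summarizes the core idea: walk the UI at random and keep a pool of "unclickable" elements whose action left the screen unchanged. The sketch below restates that idea against a toy environment, assuming hypothetical `get_elements`/`step` callbacks and a hypothetical `Element` type; it is not the AndroidWorld API, and a real implementation would compare screenshots or UI trees rather than string screen names. Keying the pool by `(screen, element)` is a choice made for this sketch; the original script may key it differently.

```python
import random
from dataclasses import dataclass


@dataclass(frozen=True)
class Element:
    """Hypothetical stand-in for a UI element (not AndroidWorld's UI element class)."""
    element_id: str
    clickable: bool = True


def random_walk(start_screen, get_elements, step, num_steps, seed=0):
    """Collect (screen_before, action, screen_after) triples by random clicks.

    get_elements(screen) -> elements visible on `screen`       (hypothetical callback)
    step(screen, element) -> screen reached after clicking it   (hypothetical callback)
    """
    rng = random.Random(seed)
    unclickable_pool = set()  # (screen, element) pairs whose click changed nothing
    triples = []
    screen = start_screen
    for _ in range(num_steps):
        candidates = [
            e for e in get_elements(screen)
            if e.clickable and (screen, e) not in unclickable_pool
        ]
        if not candidates:
            break  # nothing left to try on this screen
        element = rng.choice(candidates)
        next_screen = step(screen, element)
        if next_screen == screen:
            # No visible change: remember the element so it is not tried here again.
            unclickable_pool.add((screen, element))
            continue
        action = {"action_type": "click", "element": element.element_id}
        triples.append((screen, action, next_screen))
        screen = next_screen
    return triples, unclickable_pool


if __name__ == "__main__":
    # Toy app: "b" on the home screen is a dead control.
    graph = {
        "home": {Element("a"): "settings", Element("b"): "home"},
        "settings": {Element("c"): "home"},
    }
    triples, pool = random_walk(
        "home", lambda s: list(graph[s]), lambda s, e: graph[s][e], num_steps=10
    )
    for triple in triples:
        print(triple)
    print("unclickable:", pool)
```

Every iteration either yields a triple or grows the pool, and a pooled element is never retried on the same screen, so the walk does not keep spending steps on dead controls.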
11 | ## Reverse Task Synthesis
12 |
13 | Work in progress.
14 |
15 | ## Trajectory Construction
16 |
17 | 1. Install the AndroidWorld Environment as described in: https://github.com/google-research/android_world
18 | 2. Move the scripts to the AndroidWorld directory: `android_env/android_world`
19 | 3. Run the following command to collect the data:
20 | ```bash
21 | python mobile_runner.py
22 | ```
23 |
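`mobile_runner.py` caps each trajectory at 15 steps, extracts the action with the regex `Action:\s*(\{.*\})`, and stops early when the same action JSON appears three times in a row (the check is skipped for Simple Calendar Pro). A minimal, environment-free sketch of that bookkeeping follows; `collect_actions` is a hypothetical helper that replaces the real `agent.step` loop with a list of raw model outputs.

```python
import re

# The same pattern mobile_runner.py applies to the agent's raw output.
ACTION_RE = re.compile(r'Action:\s*(\{.*\})')


def extract_action_json(action_output: str) -> str:
    """Return the JSON action string, or the sentinel mobile_runner.py uses on failure."""
    match = ACTION_RE.search(action_output)
    return match.group(1) if match else "action_not_match"


def should_stop(history: list, current: str, repeats: int = 3) -> bool:
    """True if `current` would be the `repeats`-th identical action in a row."""
    if len(history) < repeats - 1:
        return False
    window = history[len(history) - (repeats - 1):]
    return all(prev == current for prev in window)


def collect_actions(outputs, max_steps: int = 15) -> list:
    """Record actions until the step cap or a triple repeat (hypothetical helper)."""
    history = []
    for output in outputs[:max_steps]:
        action = extract_action_json(output)
        if should_stop(history, action):
            break  # as in mobile_runner.py, the repeated action is not recorded
        history.append(action)
    return history
```

Because the check runs before the step is appended, a run of identical actions keeps only the first two, matching the `i >= 2` comparison against the two previous steps in the runner.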
24 | # Web
25 |
26 | ## Random walk in the WebArena Environment
27 | First, configure [WebArena](https://github.com/web-arena-x/webarena) and open the specified ports to access the website, then place `random_walk_web.py` in its directory.
28 |
29 | `random_walk_web.py` implements our random-walk logic for collecting interaction triples from the environment. Run `python random_walk_web.py` to collect interaction triples at scale.
30 |
31 | # Reward Model
32 | We provide an example of the Reward Model we use in `genesis_rm.py`. For more information, please refer to the original paper.
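`genesis_rm.py` numbers each step's summary (falling back to the step's reason), attaches the screenshots of the last three steps, sends everything to `gpt-4o-2024-08-06`, and accepts a reply only if it parses as `Reason: ... Score: N` with N between 1 and 5, retrying up to five times otherwise. The helpers below restate the prompt-building and parsing logic as standalone functions; their names are hypothetical, and the API call itself is omitted.

```python
import re
from typing import Optional


def build_history(steps: list) -> str:
    """Number each step, preferring 'summary' and falling back to 'reason'."""
    lines = []
    for j, step in enumerate(steps):
        text = step["summary"] if "summary" in step else step["reason"]
        lines.append(f"Step {j + 1}: {text}")
    return "\n".join(lines)


def last_screens(steps: list, k: int = 3) -> list:
    """Screenshots shown to the judge: the last k steps, or all of them if fewer."""
    return [s["screen_before"] for s in steps[-k:]]


def parse_reward(response: str) -> Optional[tuple]:
    """Parse 'Reason: ... Score: N'; return None unless both parts exist and 1 <= N <= 5."""
    reason_match = re.search(r"Reason:\s*(.+?)\s*Score:", response, re.DOTALL)
    score_match = re.search(r"Score:\s*(\d+)", response)
    if not (reason_match and score_match):
        return None
    score = int(score_match.group(1))
    if not 1 <= score <= 5:
        return None
    return reason_match.group(1).strip(), score
```

Returning `None` for malformed or out-of-range replies lets the caller treat every failure the same way, which is how the script's retry loop behaves: anything other than a valid reason and score triggers another attempt.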
--------------------------------------------------------------------------------
/collection/genesis_rm.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import requests
4 | import time
5 | from PIL import Image
6 | import base64
7 | import io
8 | import json
9 | import re
10 | from tqdm import tqdm
11 | import numpy as np
12 |
13 |
14 | def encode_image(image_content):
15 | return base64.b64encode(image_content).decode('utf-8')
16 |
17 |
18 | def convert_image_to_base64(image_path):
19 | # Open the image file
20 | with open(image_path, 'rb') as f:
21 | # Read the raw image bytes (no PIL decoding is needed for base64 encoding)
22 | image_bytes = f.read()
23 | # Encode the image bytes to base64
24 | encoded_image = encode_image(image_bytes)
25 | return encoded_image
26 |
27 |
28 | def call_llm(model_name, payload):
29 | headers = {
30 | "Content-Type": "application/json",
31 | "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"
32 | }
33 | print("Generating content with GPT model: {}".format(model_name))
34 | response = requests.post(
35 | "https://api.openai.com/v1/chat/completions",
36 | headers=headers,
37 | json=payload
38 | )
39 | if response.status_code != 200:
40 | if response.json().get('error', {}).get('code') == "context_length_exceeded":
41 | print("Context length exceeded. Retrying with a smaller context.")
42 | payload["messages"] = [payload["messages"][0]] + payload["messages"][-1:]
43 | retry_response = requests.post(
44 | "https://api.openai.com/v1/chat/completions",
45 | headers=headers,
46 | json=payload
47 | )
48 | if retry_response.status_code != 200:
49 | print(
50 | "Failed to call LLM even after attempt on shortening the history: " + retry_response.text)
51 | return ""
52 | return retry_response.json()['choices'][0]['message']['content']
53 | print("Failed to call LLM: " + response.text)
54 | time.sleep(2)
55 | return ""
56 | else:
57 | return response.json()['choices'][0]['message']['content']
58 |
59 |
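60 | gpt_annot_traj = json.load(open('/Users/cckevin/Desktop/gpt_annot_traj_v2.json', 'r'))  # input trajectories; replace with your own path
61 | imgs_dir = '/Users/cckevin/Desktop/gpt_annot_traj_v2/gpt_screenshots_v2'  # screenshot directory; replace with your own path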
62 |
63 | system_prompt = """
64 | You are an expert in evaluating Android GUI agent task trajectories. Your task is to assess the quality and effectiveness of task trajectories for GUI manipulation tasks.
65 |
66 | A trajectory consists of the following components:
67 | 1. High-level Instruction: Describes the user's intended task (e.g., "Create a new travel itinerary document in a folder").
68 | 2. Action History: Includes two key parts:
69 | - Low-level Actions & Summaries: A sequence of actions, where each step includes:
70 | - The executed action.
71 | - A summary of the action, indicating the effect after the action is executed.
72 | - GUI Screenshots: Screenshots captured when the last three actions are executed: the third-to-last, second-to-last, and final actions (if there are at least three actions; otherwise, include all actions).
73 |
74 | When evaluating a trajectory, consider these key aspects:
75 |
76 | ### Evaluation Criteria:
77 | 1. Trajectory Coherence:
78 | - Do the low-level steps and corresponding actions follow a logical sequence toward the goal?
79 | - Are the actions clearly described and specific?
80 | - Are there redundant or unnecessary actions?
81 |
82 | 2. Task Completion:
83 | - Does the trajectory successfully achieve the instructed task?
84 | - Are all necessary interactions completed?
85 | - Are error cases handled appropriately?
86 |
87 | ### Scoring Guidelines:
88 | Rate the trajectory on a scale of 1 to 5 based on the evaluation criteria:
89 |
90 | - 5: The task is perfectly completed, successfully executing multiple actions to achieve the goal. The sequence is logically clear with no noticeable redundancies.
91 | - 4: The task is mostly completed, successfully executing multiple actions. However, due to challenges or ambiguities in the instructions, the completion is not perfect, or there are inefficiencies in the process.
92 | - 3: The task is partially completed, with some successful actions executed. However, due to task or environmental constraints, the goal is not fully achieved, or the sequence ends in a loop or error.
93 | - 2: Only a few actions are executed. Although there is an attempt to complete the task, the trajectory deviates from the goal early on or demonstrates significant inefficiencies in execution and logic.
94 | - 1: The task fails completely, with no meaningful actions executed at the start. The sequence either falls into an immediate deadlock, a repetitive loop, or demonstrates no value in completing the task.
95 |
96 | Note: If the task is relatively complex, but the trajectory demonstrates valuable attempts, even if the task is not fully completed, consider adjusting the score upward. However, if the task is complex but the trajectory fails to perform actions that contribute meaningfully to task completion, no extra points should be awarded.
97 |
98 | ### Response Format:
99 | Format your response into two lines as shown below:
100 | Reason:
101 | Score:
102 | """
103 |
104 | for i, item in enumerate(gpt_annot_traj[:]):
105 |
106 | if "reward" in item:
107 | print("processed")
108 | continue
109 |
110 | instruction = item["instruction"]
111 |
112 | action_history = []
113 | for j, action in enumerate(item["steps"]):
114 | if "summary" in action:
115 | summary = action["summary"]
116 | else:
117 | summary = action["reason"]
118 | summary = f"Step {j+1}: {summary}"
119 | action_history.append(summary)
120 | action_history_text = '\n'.join(action_history)
121 |
122 | action_screenshots = []
123 | for action in item["steps"][-3:]:
124 | screenshot_path = os.path.join(imgs_dir, action["screen_before"])
125 | screenshot = convert_image_to_base64(screenshot_path)
126 | action_screenshots.append(screenshot)
127 |
128 | traj_prompt = f"Instruction: {instruction}\nAction History:\n{action_history_text}\nThe last three screenshots are provided."
129 |
130 | messages = []
131 |
132 | messages.append({
133 | "role": "system",
134 | "content": [
135 | {
136 | "type": "text",
137 | "text": system_prompt
138 | },
139 | ]
140 | })
141 |
142 | # Prediction example
143 | action_text_image = []
144 | for img in action_screenshots:
145 | action_text_image.append(
146 | {
147 | "type": "image_url",
148 | "image_url": {
149 | "url": f"data:image/png;base64,{img}",
150 | "detail": "high"
151 | }
152 | }
153 | )
154 |
155 | action_text_image.append(
156 | {
157 | "type": "text",
158 | "text": traj_prompt
159 | }
160 | )
161 |
162 | messages.append({
163 | "role": "user",
164 | "content": action_text_image
165 | })
166 |
167 | print(traj_prompt)
168 |
169 | model_name = "gpt-4o-2024-08-06"
170 | try_num = 0
171 | answer = None
172 | while try_num < 5:
173 | try_num += 1
174 | try:
175 | response = call_llm(model_name, {
176 | "model": model_name,
177 | "messages": messages,
178 | "max_tokens": 1500,
179 | "top_p": 0.9,
180 | "temperature": 0.5
181 | })
182 | except Exception:
183 | print("error call")
184 | time.sleep(1.0)
185 | continue
186 | try:
187 | print(response)
188 | reason_match = re.search(r"Reason:\s*(.+?)\s*Score:", response, re.DOTALL)
189 | score_match = re.search(r"Score:\s*(\d+)", response)
190 | reason = reason_match.group(1).strip() if reason_match else None
191 | score = int(score_match.group(1)) if score_match else None
192 |
193 | if reason and score and 1 <= score <= 5:
194 | item["reward_reason"] = reason
195 | item["reward"] = score
196 | break # Successfully parsed, exit loop
197 | else:
198 | print("Invalid response format or score out of range, retrying...")
199 | time.sleep(1.0)
200 |
201 | except json.JSONDecodeError:
202 | # If response is not valid JSON, continue to generate
203 | print("Invalid response received, retrying...")
204 | time.sleep(1.0)
205 |
206 | num_processed = len([item for item in gpt_annot_traj if ("reward" in item)])
207 | print("Num of total: {} Num of success: {}".format(len(gpt_annot_traj), num_processed))
208 | if i % 20 == 0:
209 | json.dump(gpt_annot_traj, open('/Users/cckevin/Desktop/gpt_annot_traj_v2_reward.json', 'w'))
210 |
211 | json.dump(gpt_annot_traj, open('/Users/cckevin/Desktop/gpt_annot_traj_v2_reward.json', 'w'))
212 | print("Save")
--------------------------------------------------------------------------------
/collection/mobile_runner.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The android_world Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Build trajectories of OS-Genesis.
16 |
17 | """
18 |
19 | from collections.abc import Sequence
20 | import os
21 | import random
22 | from typing import Type
23 | import json
24 | import time
25 | import uuid
26 | from PIL import Image
27 | import numpy as np
28 | import re
29 | import pickle
30 |
31 | from absl import app
32 | from absl import flags
33 | from absl import logging
34 | from android_world import registry
35 | from android_world.agents import infer
36 | from android_world.agents import t3a, m3a
37 | # from android_world.agents import m3a_origin
38 | from android_world.agents import m3a_utils
39 | from android_world.agents.t3a import _generate_ui_elements_description_list_full
40 | from android_world.env import env_launcher, json_action
41 | from android_world.task_evals import task_eval
42 |
43 | logging.set_verbosity(logging.WARNING)
44 |
45 | os.environ['GRPC_VERBOSITY'] = 'ERROR' # Only show errors
46 | os.environ['GRPC_TRACE'] = 'none' # Disable tracing
47 |
48 |
49 | def _find_adb_directory() -> str:
50 | """Returns the directory where adb is located."""
51 | potential_paths = [
52 | os.path.expanduser('~/Library/Android/sdk/platform-tools/adb'),
53 | os.path.expanduser('~/Android/Sdk/platform-tools/adb'),
54 | ]
55 | for path in potential_paths:
56 | if os.path.isfile(path):
57 | return path
58 | raise EnvironmentError(
59 | 'adb not found in the common Android SDK paths. Please install Android'
60 | " SDK and ensure adb is in one of the expected directories. If it's"
61 | ' already installed, point to the installed location.'
62 | )
63 |
64 |
65 | _ADB_PATH = flags.DEFINE_string(
66 | 'adb_path',
67 | _find_adb_directory(),
68 | 'Path to adb. Set if not installed through SDK.',
69 | )
70 | _EMULATOR_SETUP = flags.DEFINE_boolean(
71 | 'perform_emulator_setup',
72 | False,
73 | 'Whether to perform emulator setup. This must be done once and only once'
74 | ' before running Android World. After an emulator is setup, this flag'
75 | ' should always be False.',
76 | )
77 | _DEVICE_CONSOLE_PORT = flags.DEFINE_integer(
78 | 'console_port',
79 | 5554,
80 | 'The console port of the running Android device. This can usually be'
81 | ' retrieved by looking at the output of `adb devices`. In general, the'
82 | ' first connected device is port 5554, the second is 5556, and'
83 | ' so on.',
84 | )
85 |
86 | _TASK = flags.DEFINE_string(
87 | 'task',
88 | None,
89 | 'A specific task to run.',
90 | )
91 |
92 |
93 | def save_image(image, directory):
94 | """Same image to a file and return the file name."""
95 | unique_id = str(uuid.uuid4())
96 | image_name = f"{unique_id}.png"
97 | image_path = os.path.join(directory, image_name)
98 | if isinstance(image, np.ndarray):
99 | image = Image.fromarray(np.uint8(image))
100 | image.save(image_path)
101 | return image_name
102 |
103 |
104 | def get_state(env_state, logical_screen_size, ui_elements):
105 | element_list_text = _generate_ui_elements_description_list_full(
106 | ui_elements,
107 | logical_screen_size,
108 | )
109 | screen = env_state.pixels.copy()
110 | screen = Image.fromarray(screen.astype('uint8'))
111 | return screen, element_list_text
112 |
113 |
114 | def element_to_identifier(element):
115 | """Converts an element to a JSON-serializable identifier."""
116 | bbox = getattr(element, 'bbox_pixels', None)
117 | bbox_dict = {'x_min': bbox.x_min, 'x_max': bbox.x_max, 'y_min': bbox.y_min, 'y_max': bbox.y_max} if bbox else None
118 | identifier = {
119 | 'resource_id': getattr(element, 'resource_id', None),
120 | 'text': getattr(element, 'text', None),
121 | 'content_description': getattr(element, 'content_description', None),
122 | 'class_name': getattr(element, 'class_name', None),
123 | 'bbox_pixels': bbox_dict,
124 | 'hint_text': getattr(element, 'hint_text', None),
125 | 'is_checkable': getattr(element, 'is_checkable', None),
126 | 'is_enabled': getattr(element, 'is_enabled', None),
127 | 'is_visible': getattr(element, 'is_visible', None),
128 | 'is_clickable': getattr(element, 'is_clickable', None),
129 | 'is_editable': getattr(element, 'is_editable', None),
130 | 'is_focused': getattr(element, 'is_focused', None),
131 | 'is_focusable': getattr(element, 'is_focusable', None),
132 | 'is_long_clickable': getattr(element, 'is_long_clickable', None),
133 | 'is_scrollable': getattr(element, 'is_scrollable', None),
134 | 'is_selected': getattr(element, 'is_selected', None),
135 | 'package_name': getattr(element, 'package_name', None),
136 | 'resource_name': getattr(element, 'resource_name', None),
137 | }
138 | return identifier
139 |
140 |
141 | def _main() -> None:
142 |
143 | instruction_path = './aw_instructions.json'
144 | aw_instrcutions = json.load(open(instruction_path, 'r'))
145 |
146 | SCREEN_GPT_DIR = './screenshots_gpt_v2'
147 | if not os.path.exists(SCREEN_GPT_DIR):
148 | os.mkdir(SCREEN_GPT_DIR)
149 |
150 | """Initialize Env."""
151 | env = env_launcher.load_and_setup_env(
152 | console_port=_DEVICE_CONSOLE_PORT.value,
153 | emulator_setup=_EMULATOR_SETUP.value,
154 | adb_path=_ADB_PATH.value,
155 | )
156 | env_launcher.verify_api_level(env)
157 |
158 | for task_item in aw_instrcutions:
159 | if "task_fail" in task_item:
160 | del task_item["task_fail"]
161 |
162 | for task_item in aw_instrcutions:
163 |
164 | total_tasks = len(aw_instrcutions)
165 | annotated_tasks = len([item for item in aw_instrcutions if "gpt_traj" in item])
166 | print(f"Total task: {total_tasks} --- Annotated task: {annotated_tasks}")
167 | failed_tasks = len([item for item in aw_instrcutions if "task_fail" in item])
168 | print(f"Total task: {total_tasks} --- Failed task: {failed_tasks}")
169 | if "gpt_traj" in task_item or "task_fail" in task_item:
170 | continue
171 |
172 | try:
173 | env.reset(go_home=True)
174 | task_registry = registry.TaskRegistry()
175 | aw_registry = task_registry.get_registry(task_registry.ANDROID_WORLD_FAMILY)
176 |
177 | # Initialize based on the task sampled and open the corresponding app.
178 | app_name = task_item["app_name"]
179 | task_name = task_item["task_name"] if "task_name" in task_item else task_item["task_task"]
180 | instrcution = task_item["refine_task"]
181 |
182 | if task_name and task_name != "default":
183 | if task_name not in aw_registry:
184 | raise ValueError('Task {} not found in registry.'.format(task_name))
185 | task_type: Type[task_eval.TaskEval] = aw_registry[task_name]
186 | else:
187 | task_type: Type[task_eval.TaskEval] = random.choice(
188 | list(aw_registry.values())
189 | )
190 | print("unknown task name")
191 | input()
192 | print(task_type)
193 |
194 | # load params
195 | task_id = task_item["task_id"]
196 | params_dir = './params_new'
197 | params_path = os.path.join(params_dir, task_id + "_params.pkl")
198 | with open(params_path, 'rb') as f:
199 | params = pickle.load(f)
200 | print(params)
201 | #params = task_type.generate_random_params()
202 |
203 | task = task_type(params)
204 |
205 | task.initialize_task(env)
206 | # agent = t3a.T3A(env, infer.Gpt4Wrapper('gpt-4-turbo-2024-04-09'))
207 | # agent = m3a_origin.M3A(env, infer.Gpt4Wrapper('gpt-4o-2024-08-06'))
208 | agent = m3a.M3A(env, infer.Gpt4Wrapper('gpt-4o-2024-08-06'))
209 |
210 |
211 | # Open the corresponding app after initializing the task.
212 | open_app = True
213 | if open_app:
214 | open_app_action = {"action_type": "open_app", "app_name": app_name}
215 | converted_action = json_action.JSONAction(**open_app_action)
216 | agent.env.execute_action(converted_action)
217 | time.sleep(3.0)
218 |
219 | print('Goal: ' + str(instrcution))
220 | is_done = False
221 | gpt_traj = []
222 | for i in range(15):  # at most 15 steps per trajectory
223 |
224 | # Obtain the state of the environment before execution to synchronize with our training setup.
225 | env_state = agent.get_post_transition_state()
226 | logical_screen_size = agent.env.logical_screen_size
227 | ui_elements = env_state.ui_elements
228 | screen, element_list_text = get_state(env_state, logical_screen_size, ui_elements)
229 | screen_before = save_image(screen, SCREEN_GPT_DIR)
230 | # Note: Here, following the implementation of M3A, a state interface representation consistent with the agent’s observation is saved, ensuring that the actions generated by the model can locate the corresponding elements.
231 | ui_elements_before_identifiers = [element_to_identifier(elem) for elem in ui_elements if m3a_utils.validate_ui_element(elem, logical_screen_size)]
232 |
233 | # take one step
234 | response = agent.step(instrcution)
235 |
236 | # Extract the screen, prompt, and generated action from the response.
237 | screen_before_som = save_image(response.data["before_screenshot_with_som"], SCREEN_GPT_DIR)
238 | action_prompt = response.data["action_prompt"]
239 | action_output = response.data["action_output"]
240 | action_reason = response.data["action_reason"]
241 | summary_prompt = response.data["summary_prompt"]
242 | summary = response.data["summary"]
243 |
244 | match = re.search(r'Action:\s*(\{.*\})', action_output)
245 | action_json = match.group(1) if match else "action_not_match"
246 | # Exit if the same action is performed three times consecutively.
247 | if app_name != "Simple Calendar Pro":
248 | if i >= 2 and (action_json == gpt_traj[i-1]["action_json"] == gpt_traj[i-2]["action_json"]):
249 | break
250 |
251 | step_data = {
252 | "screen_before": screen_before,
253 | "screen_before_som": screen_before_som,
254 | "ui_elements_before_text": element_list_text,
255 | "ui_elements_before": ui_elements_before_identifiers,
256 | "action_prompt": action_prompt,
257 | "action_output": action_output,
258 | "action_json": action_json,
259 | "action_reason": action_reason,
260 | "summary_prompt": summary_prompt,
261 | "summary": summary
262 | }
263 | gpt_traj.append(step_data)
264 |
265 | if response.done:
266 | is_done = True
267 | break
268 |
269 | """
270 | agent_successful = is_done and task.is_successful(env) == 1
271 | print(
272 | f'{"Task Successful ✅" if agent_successful else "Task Failed ❌"};'
273 | f' {task.goal}'
274 | )
275 | """
276 |
277 | # env.close()
278 |
279 | # Update the annotations to the original aw_instructions file at the end of each trajectory.
280 | task_item["gpt_traj"] = gpt_traj
281 | json.dump(aw_instrcutions, open(instruction_path, 'w'))
282 |
283 | except Exception as e:
284 | print(f"An error occurred: {e}")
285 | task_item["task_fail"] = "fail"
286 | json.dump(aw_instrcutions, open(instruction_path, 'w'))
287 | time.sleep(10)
288 | break # Exit and restart after an error occurs.
289 |
290 |
291 | def main(argv: Sequence[str]) -> None:
292 | del argv
293 | _main()
294 |
295 |
296 | if __name__ == '__main__':
297 | app.run(main)
--------------------------------------------------------------------------------
/collection/random_walk_aw.py:
--------------------------------------------------------------------------------
1 | # Copyright 2024 The android_world Authors.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | Random walk in AndroidWorld environment, collecting triples.
17 | Maintain an unclickable element pool: during random walk, if the action does not change the screen, add the element to the pool.
18 | """
19 |
20 | from collections.abc import Sequence
21 | import os
22 | import requests
23 | import random
24 | from typing import Type
25 | from PIL import Image
26 | import time
27 | import json
28 | import uuid
29 | import pickle
30 |
31 | from absl import app
32 | from absl import flags
33 | from absl import logging
34 | from android_world import registry
35 | from android_world.agents import infer
36 | from android_world.agents import t3a
37 | from android_world.agents import agent_utils
38 | from android_world.agents.t3a import _generate_ui_elements_description_list_full
39 | from android_world.env import env_launcher, json_action
40 | from android_world.env import representation_utils
41 | from android_world.task_evals import task_eval
42 |
43 |
44 | logging.set_verbosity(logging.WARNING)
45 |
46 | os.environ['GRPC_VERBOSITY'] = 'ERROR' # Only show errors
47 | os.environ['GRPC_TRACE'] = 'none' # Disable tracing
48 |
49 |
50 | def _find_adb_directory() -> str:
51 | """Returns the directory where adb is located."""
52 | potential_paths = [
53 | os.path.expanduser('~/Library/Android/sdk/platform-tools/adb'),
54 | os.path.expanduser('~/Android/Sdk/platform-tools/adb'),
55 | ]
56 | for path in potential_paths:
57 | if os.path.isfile(path):
58 | return path
59 | raise EnvironmentError(
60 | 'adb not found in the common Android SDK paths. Please install Android'
61 | " SDK and ensure adb is in one of the expected directories. If it's"
62 | ' already installed, point to the installed location.'
63 | )
64 |
65 |
66 | _ADB_PATH = flags.DEFINE_string(
67 | 'adb_path',
68 | _find_adb_directory(),
69 | 'Path to adb. Set if not installed through SDK.',
70 | )
71 | _EMULATOR_SETUP = flags.DEFINE_boolean(
72 | 'perform_emulator_setup',
73 | False,
74 | 'Whether to perform emulator setup. This must be done once and only once'
75 | ' before running Android World. After an emulator is setup, this flag'
76 | ' should always be False.',
77 | )
78 | _DEVICE_CONSOLE_PORT = flags.DEFINE_integer(
79 | 'console_port',
80 | 5554,
81 | 'The console port of the running Android device. This can usually be'
82 | ' retrieved by looking at the output of `adb devices`. In general, the'
83 | ' first connected device is port 5554, the second is 5556, and'
84 | ' so on.',
85 | )
86 |
87 | _TASK = flags.DEFINE_string(
88 | 'task',
89 | None,
90 | 'A specific task to run.',
91 | )
92 |
93 |
94 | def generate_text_input(element_list_text, interactive_element, max_retries=3):
95 | # Construct the prompt with specific instructions
96 | def is_valid_response(response_text):
97 | if "\n" in response_text:
98 | return False
99 | # Optionally, add more validation rules here
100 | return True
101 |
102 | def element_to_text(element):
103 | """Convert the element into a text description."""
104 | description = ""
105 | if getattr(element, 'resource_id', None):
106 | description += f"Resource ID: {element.resource_id}\n"
107 | if getattr(element, 'text', None):
108 | description += f"Text: {element.text}\n"
109 | if getattr(element, 'content_description', None):
110 | description += f"Content Description: {element.content_description}\n"
111 | if getattr(element, 'class_name', None):
112 | description += f"Class Name: {element.class_name}\n"
113 | if getattr(element, 'hint_text', None):
114 | description += f"Hint Text: {element.hint_text}\n"
115 | return description or "No additional information."
116 |
117 | prompt = f"""
118 | You are an intelligent input assistant. The current UI elements are as follows:
119 | {element_list_text}
120 | The selected editable element information is as follows:
121 | {element_to_text(interactive_element)}
122 | Based on the above information, please randomly generate text that a user might enter into this element. The text should be contextually appropriate. For example, if it's a search box, you might generate a search query; if it's a username input field, you might generate a username.
123 | **Please return only the generated text without any additional explanation. Do not include any prefixes or suffixes.**
124 |
125 | If you understand, please provide the text input now.
126 | """
127 |
128 | headers = {
129 | "Content-Type": "application/json",
130 | "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"
131 | }
132 |
133 | payload = {
134 | "model": "gpt-4o-mini-2024-07-18",
135 | "messages": [
136 | {"role": "user", "content": prompt}
137 | ],
138 | "max_tokens": 20,
139 | "temperature": 0.7,
140 | "n": 1,
141 | }
142 |
143 | retries = 0
144 | while retries < max_retries:
145 | try:
146 | # Send POST request to OpenAI API
147 | response = requests.post(
148 | "https://api.openai.com/v1/chat/completions",
149 | headers=headers,
150 | json=payload
151 | )
152 | response.raise_for_status()
153 | result = response.json()
154 | text_input = result['choices'][0]['message']['content'].strip()
155 |
156 | # Validate the response
157 | if is_valid_response(text_input):
158 | return text_input
159 | else:
160 | print(f"Invalid response format: '{text_input}'. Retrying...")
161 | retries += 1
162 | time.sleep(1) # Wait a bit before retrying
163 | except Exception as e:
164 | print(f"Error generating text input: {e}")
165 | retries += 1
166 | time.sleep(1)
167 |
168 | # If all retries fail, return a default text
169 | print("Failed to get valid text input after retries. Returning default text.")
170 | return "Test Input"
171 |
172 |
173 | def get_state(env_state, logical_screen_size, ui_elements):
174 | element_list_text = _generate_ui_elements_description_list_full(
175 | ui_elements,
176 | logical_screen_size,
177 | )
178 | screen = env_state.pixels.copy()
179 | screen = Image.fromarray(screen.astype('uint8'))
180 | return screen, element_list_text
181 |
182 |
183 | def element_to_identifier(element):
184 | """Converts an element to a JSON-serializable identifier."""
185 | bbox = getattr(element, 'bbox_pixels', None)
186 | bbox_dict = {'x_min': bbox.x_min, 'x_max': bbox.x_max, 'y_min': bbox.y_min, 'y_max': bbox.y_max} if bbox else None
187 | identifier = {
188 | 'resource_id': getattr(element, 'resource_id', None),
189 | 'text': getattr(element, 'text', None),
190 | 'content_description': getattr(element, 'content_description', None),
191 | 'class_name': getattr(element, 'class_name', None),
192 | 'bbox_pixels': bbox_dict,
193 | 'hint_text': getattr(element, 'hint_text', None),
194 | 'is_checkable': getattr(element, 'is_checkable', None),
195 | 'is_enabled': getattr(element, 'is_enabled', None),
196 | 'is_visible': getattr(element, 'is_visible', None),
197 | 'is_clickable': getattr(element, 'is_clickable', None),
198 | 'is_editable': getattr(element, 'is_editable', None),
199 | 'is_focused': getattr(element, 'is_focused', None),
200 | 'is_focusable': getattr(element, 'is_focusable', None),
201 | 'is_long_clickable': getattr(element, 'is_long_clickable', None),
202 | 'is_scrollable': getattr(element, 'is_scrollable', None),
203 | 'is_selected': getattr(element, 'is_selected', None),
204 | 'package_name': getattr(element, 'package_name', None),
205 | 'resource_name': getattr(element, 'resource_name', None),
206 | }
207 | return identifier
208 |
209 |
210 | def filter_interactive_elements(elements, screen_width_height_px, unc_elem_pool):
211 | interactive_elements = []
212 | screen_width, screen_height = screen_width_height_px
213 |
214 | # List of excluded package names, adding other keyboard app package names can improve filtering effect
215 | # Keyboard and system UI (e.g., battery indicator is actually not clickable)
216 | excluded_packages = {'com.google.android.inputmethod.latin', 'com.android.systemui'}
217 |
218 | for index, element in enumerate(elements):
219 |
220 | if element.package_name in excluded_packages:
221 | continue
222 |
223 | if element.is_enabled and element.is_visible:
224 | if element.bbox_pixels:
225 | x_min = element.bbox_pixels.x_min
226 | x_max = element.bbox_pixels.x_max
227 | y_min = element.bbox_pixels.y_min
228 | y_max = element.bbox_pixels.y_max
229 |
230 | # Check if bounding box is within screen bounds and coordinates are valid
231 | if not (x_min >= x_max or x_min >= screen_width or x_max <= 0 or
232 | y_min >= y_max or y_min >= screen_height or y_max <= 0):
233 |
234 | # Compute element identifier
235 | element_identifier = element_to_identifier(element)
236 | element_identifier_str = json.dumps(element_identifier, sort_keys=True)
237 | if element_identifier_str not in unc_elem_pool:
238 | interactive_elements.append([index, element])
239 |
240 | return interactive_elements
241 |
242 |
243 | def sample_action_element(element, element_list_text):
244 | index, interactive_element = element
245 | actions = []
246 | # If element is editable, prioritize text input
247 | if interactive_element.is_editable:
248 | # Use GPT to generate appropriate text input based on current UI
249 | text_input = generate_text_input(element_list_text, interactive_element)
250 | return {"action_type": "input_text", "text": text_input, "index": index}
251 | # Assume all elements are clickable, add click action to action list
252 | actions.append({"action_type": "click", "index": index})
253 | # If element can be long pressed, add long press action to action list
254 | if interactive_element.is_long_clickable:
255 | actions.append({"action_type": "long_press", "index": index})
256 | # If multiple actions are available, choose based on given probabilities
257 | if actions:
258 | if len(actions) == 1:
259 | return actions[0]
260 | else:
261 | # Choose click or long press with 90% and 10% probability (assuming both actions are available)
262 | return random.choices(actions, weights=[9, 1], k=1)[0]
263 |
264 |
265 | def get_task_app(task_name):
266 | # Choose which app to open based on task name
267 | task_2_app = {
268 | "AudioRecorder": "Audio Recorder",
269 | "Browser": "Files",
270 | "SimpleCalendar": "Simple Calendar Pro",
271 | "Camera": "Camera",
272 | "Clock": "Clock",
273 | "Contacts": "Contacts",
274 | "Expense": "Pro Expense",
275 | "ExpenseAddMultipleFromMarkor": "Pro Expense",
276 | "Files": "Files",
277 | "Markor": "Markor",
278 | "Osm": "OsmAnd",
279 | "Recipe": "Broccoli - Recipe App",
280 | "RecipeAddMultipleRecipesFromMarkor": "Broccoli - Recipe App",
281 | "RecipeAddMultipleRecipesFromMarkor2": "Broccoli - Recipe App",
282 | "Retro": "Retro Music",
283 | "SimpleDraw": "Simple Draw Pro",
284 | "SaveCopyOfReceipt": "Simple Gallery Pro",
285 | "SimpleSms": "Simple SMS Messenger",
286 | "System": "Settings",
287 | "Turn": "Settings",
288 | "Vlc": "VLC",
289 | "Tasks": "Tasks",
290 | "Notes": "Joplin",
291 | "Sports": "OpenTracks",
292 | }
293 | if task_name in task_2_app:
294 | return task_2_app[task_name]
295 | else:
296 | for task, app in task_2_app.items():
297 | if task in task_name:
298 | return app
299 | return "Home"
300 |
301 |
302 | def has_screen_changed(before_elements, after_elements):
303 | # Convert UI elements to identifiers, excluding system UI elements
304 | before_set = set(
305 | json.dumps(element_to_identifier(elem), sort_keys=True)
306 | for elem in before_elements
307 | if elem.package_name != 'com.android.systemui'
308 | )
309 | after_set = set(
310 | json.dumps(element_to_identifier(elem), sort_keys=True)
311 | for elem in after_elements
312 | if elem.package_name != 'com.android.systemui'
313 | )
314 | return before_set != after_set
315 |
316 |
317 | # Random walk pipeline:
318 | # 1. Initialize the environment.
319 | # 2. Iterate over the shuffled tasks; each task determines which app to open.
320 | # 3. Load that app's unclickable element pool.
321 | # 4. Execute num_step random-walk steps within the app:
322 | # a. At each step, collect the executable actions for the current state.
323 | # b. Randomly sample one action.
324 | # c. Execute it. If the screen doesn't change, treat the element as unclickable and add it to the pool; otherwise record the triple.
325 |
326 | def _main() -> None:
327 | """Random walk in AndroidWorld for sampling"""
328 | SCREEN_DIR = './screenshots_camera'
329 | if not os.path.exists(SCREEN_DIR):
330 | os.mkdir(SCREEN_DIR)
331 |
332 | TRAJECTORY_DIR = './trajectories_camera'
333 | if not os.path.exists(TRAJECTORY_DIR):
334 | os.mkdir(TRAJECTORY_DIR)
335 |
336 | PARAMS_DIR = './params_camera'
337 | if not os.path.exists(PARAMS_DIR):
338 | os.mkdir(PARAMS_DIR)
339 |
340 | # Launch Android emulator (ADB) and return to home screen
341 | env = env_launcher.load_and_setup_env(
342 | console_port=_DEVICE_CONSOLE_PORT.value,
343 | emulator_setup=_EMULATOR_SETUP.value,
344 | adb_path=_ADB_PATH.value,
345 | )
346 | env_launcher.verify_api_level(env)
347 | env.reset(go_home=True)
348 |
349 | # Task collection
350 | task_registry = registry.TaskRegistry()
351 | aw_registry = task_registry.get_registry(task_registry.ANDROID_WORLD_FAMILY) # All AW tasks, total 116
352 | print(aw_registry)
353 | print("There are total {} tasks.".format(len(aw_registry)))
354 |
355 | aw_list = list(aw_registry.items())
356 | random.shuffle(aw_list)
357 | aw_registry = dict(aw_list)
358 |
359 | # Record unclickable elements for each app
360 | if not os.path.exists('./unclickable_elem_pool'):
361 | os.mkdir('./unclickable_elem_pool')
362 |
363 | for task_id in range(len(aw_registry)):
364 |
365 | task_uuid = str(uuid.uuid4())
366 |
367 | # Select a task
368 | task_name, task_type = list(aw_registry.items())[task_id]
369 |
370 | # Open different apps as starting state based on task
371 | app_name = get_task_app(task_name)
372 |
373 | # Initialize the task (which restores the corresponding app from the task's snapshot), then return to the home screen
374 | params = task_type.generate_random_params()
375 |
376 | # Record random params
377 | with open(os.path.join(PARAMS_DIR, task_uuid+"_params.pkl"), "wb") as f:
378 | pickle.dump(params, f)
379 |
380 | task = task_type(params)
381 | task.initialize_task(env)
382 | env.reset(go_home=True)
383 |
384 | agent = t3a.T3A(env, infer.Gpt4Wrapper('gpt-4o-2024-08-06')) # Used only as an adapter for the environment interface; the agent itself never acts
385 |
386 | print("Open App: {}".format(app_name))
387 | print("Goal: {}".format(task.goal))
388 | if app_name != "Home":
389 | open_app_action = {"action_type": "open_app", "app_name": app_name}
390 | converted_action = json_action.JSONAction(**open_app_action)
391 | agent.env.execute_action(converted_action)
392 | time.sleep(3.0)
393 |
394 | # Load unclick_elem_pool
395 | unc_elem_pool_path = os.path.join('./unclickable_elem_pool', str(app_name)+".json")
396 | if not os.path.exists(unc_elem_pool_path):
397 | unc_elem_pool = set()
398 | else:
399 | with open(unc_elem_pool_path, 'r') as f:
400 | unc_elem_pool_list = json.load(f)
401 | unc_elem_pool = set(unc_elem_pool_list)
402 |
403 | # Random walk sampling for num_step steps
404 | trajectory = []
405 | num_step = 10
406 | for i in range(num_step):
407 | # Get current state
408 | env_state = agent.get_post_transition_state()
409 | logical_screen_size = agent.env.logical_screen_size
410 | ui_elements = env_state.ui_elements
411 | screen, element_list_text = get_state(env_state, logical_screen_size, ui_elements)
412 | print(element_list_text)
413 |
414 | # Collect all executable actions on the current screen: (1) different action types use different sampling weights and logic; (2) keyboard elements must not be sampled, so they are filtered out
415 | # TODO: (unverified) samples occasionally seem not to match up
416 | # Interactive elements on current screen
417 | interactive_elements = filter_interactive_elements(ui_elements, logical_screen_size, unc_elem_pool)
418 | # Other available actions on current screen
419 | addition_actions = [{"action_type": "scroll", "direction": "down"}, {"action_type": "scroll", "direction": "up"},
420 | {"action_type": "navigate_back"}]
421 |
422 | # Sample an action/element
423 | weight_interactive = 10 # Weight for interactive elements
424 | weight_addition = 1 # Weight for other actions
425 | action_element = random.choice(interactive_elements*weight_interactive+addition_actions*weight_addition)
426 | if isinstance(action_element, dict): # A direct action (scroll or navigate_back)
427 | action_sample = action_element
428 | else: # An [index, element] pair: sample an action it supports based on its properties (e.g., is_editable, is_long_clickable)
429 | action_sample = sample_action_element(action_element, element_list_text)
430 | print("Cand Elements: {}".format([index for (index, elem) in interactive_elements]))
431 | print(action_sample)
432 | if "index" in action_sample:
433 | print(ui_elements[action_sample['index']])
434 |
435 | # Execute action
436 | converted_action = json_action.JSONAction(**action_sample)
437 | agent.env.execute_action(converted_action) # TODO: (unverified) some actions seem not to complete, although execute_action should wait for completion
438 | time.sleep(2.0)
439 |
440 | # Check if action caused screen change
441 | env_state_after = agent.get_post_transition_state()
442 | logical_screen_size = agent.env.logical_screen_size
443 | ui_elements_after = env_state_after.ui_elements
444 | if not has_screen_changed(ui_elements, ui_elements_after):
445 | print("The screen did not change")
446 | # If the screen didn't change and the action targeted an element, add that element to the unclickable element pool
447 | if "index" in action_sample:
448 | element = ui_elements[action_sample['index']]
449 | element_identifier = element_to_identifier(element)
450 | unc_elem_pool.add(json.dumps(element_identifier, sort_keys=True))
451 | else:
452 | # Record this action
453 | screen_before_uuid = str(uuid.uuid4())
454 | screen_after_uuid = str(uuid.uuid4())
455 | screen_before_filename = os.path.join(SCREEN_DIR, f'{screen_before_uuid}.png')
456 | screen.save(screen_before_filename)
457 |
458 | screen_after, element_list_text_after = get_state(env_state_after, logical_screen_size, ui_elements_after)
459 | screen_after_filename = os.path.join(SCREEN_DIR, f'{screen_after_uuid}.png')
460 | screen_after.save(screen_after_filename)
461 |
462 | ui_elements_before_identifiers = [element_to_identifier(elem) for elem in ui_elements]
463 | interactive_elements_before = [element_to_identifier(elem[1]) for elem in interactive_elements]
464 |
465 | ui_elements_after_identifiers = [element_to_identifier(elem) for elem in ui_elements_after]
466 | interactive_elements_after_full = filter_interactive_elements(ui_elements_after, logical_screen_size, unc_elem_pool)
467 | interactive_elements_after = [element_to_identifier(elem[1]) for elem in interactive_elements_after_full]
468 |
469 | if "index" in action_sample:
470 | action_element = element_to_identifier(ui_elements[action_sample['index']])
471 | else:
472 | action_element = None
473 |
474 | step_data = {
475 | 'task_uuid': task_uuid,
476 | 'task': task_name,
477 | 'app': app_name,
478 | 'screen_before': screen_before_filename,
479 | 'element_list_text_before': element_list_text,
480 | 'ui_elements_before': ui_elements_before_identifiers,
481 | 'interactive_elements_before': interactive_elements_before,
482 | 'screen_after': screen_after_filename,
483 | 'element_list_text_after': element_list_text_after,
484 | 'ui_elements_after': ui_elements_after_identifiers,
485 | 'interactive_elements_after': interactive_elements_after,
486 | 'action': action_sample,
487 | 'action_element': action_element
488 | }
489 |
490 | trajectory.append(step_data)
491 | print(f"Recorded step {i + 1} for task {task_name}")
492 |
493 | # Save the updated unc_elem_pool
494 | with open(unc_elem_pool_path, 'w') as f:
495 | json.dump(list(unc_elem_pool), f, indent=2)
496 |
497 | # Save the trajectory to a JSON file
498 | trajectory_uuid = str(uuid.uuid4())
499 | trajectory_filename = os.path.join(TRAJECTORY_DIR, f'{task_name}_{trajectory_uuid}.json')
500 | with open(trajectory_filename, 'w') as f:
501 | json.dump(trajectory, f, indent=2)
502 | print(f"Saved trajectory for task {task_name} to {trajectory_filename}")
503 |
504 | env.reset(go_home=True)
505 |
506 | env.close()
507 |
508 |
509 | def main(argv: Sequence[str]) -> None:
510 | del argv
511 | _main()
512 |
513 |
514 | if __name__ == '__main__':
515 | app.run(main)
516 |
--------------------------------------------------------------------------------
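The bookkeeping at the heart of `random_walk_aw.py` (canonical JSON identifiers, set-based screen-change detection, the unclickable element pool, and weighted action sampling) can be exercised without an emulator. The sketch below is a simplified, self-contained approximation, not the repository's code: `Elem` is a hypothetical stand-in for AndroidWorld's UI element type that keeps only a few of the fields recorded by `element_to_identifier`, the helper names are illustrative, and the original's bounding-box, enabled/visible, and editable (`input_text`) handling is omitted.

```python
import json
import random
from dataclasses import asdict, dataclass
from typing import Optional

SYSTEM_UI = "com.android.systemui"


@dataclass(frozen=True)
class Elem:
    """Hypothetical stand-in for an AndroidWorld UI element (subset of fields)."""
    text: Optional[str]
    class_name: Optional[str]
    package_name: str
    is_long_clickable: bool = False


def elem_key(elem: Elem) -> str:
    # Canonical JSON (sorted keys): equal elements always produce the same string.
    return json.dumps(asdict(elem), sort_keys=True)


def screen_changed(before: list[Elem], after: list[Elem]) -> bool:
    # Compare identifier sets, ignoring system UI such as the status-bar clock.
    b = {elem_key(e) for e in before if e.package_name != SYSTEM_UI}
    a = {elem_key(e) for e in after if e.package_name != SYSTEM_UI}
    return b != a


def candidates(elems: list[Elem], pool: set[str]) -> list[tuple[int, Elem]]:
    # (index, element) pairs that are neither system UI nor known to be unclickable.
    return [(i, e) for i, e in enumerate(elems)
            if e.package_name != SYSTEM_UI and elem_key(e) not in pool]


def sample_action(cands: list[tuple[int, Elem]], rng: random.Random) -> dict:
    # Each candidate element gets weight 10 and each global action weight 1,
    # mirroring weight_interactive / weight_addition in the script.
    extra = [{"action_type": "scroll", "direction": "down"},
             {"action_type": "scroll", "direction": "up"},
             {"action_type": "navigate_back"}]
    choice = rng.choice(cands * 10 + extra)
    if isinstance(choice, dict):
        return choice
    index, elem = choice
    actions = [{"action_type": "click", "index": index}]
    if elem.is_long_clickable:
        actions.append({"action_type": "long_press", "index": index})
    # Click vs. long press at 9:1 when both are available.
    return rng.choices(actions, weights=[9, 1][:len(actions)], k=1)[0]


if __name__ == "__main__":
    pool: set[str] = set()
    ok = Elem("OK", "Button", "com.example")
    before = [ok, Elem("12:00", "TextView", SYSTEM_UI)]
    after = [ok, Elem("12:01", "TextView", SYSTEM_UI)]
    # Only the status-bar clock changed, so the tapped "OK" button is recorded as unclickable.
    if not screen_changed(before, after):
        pool.add(elem_key(ok))
    print(candidates(before, pool))  # → []
```

In the script itself, the pool is stored per app as a JSON list under `./unclickable_elem_pool/<app_name>.json` and reloaded for every task that opens the same app, so elements found to be unclickable in one episode are skipped in later ones.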
/collection/random_walk_web.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # type: ignore
3 |
4 | """
5 | Random walk in the WebArena environment, collecting (state_before, action, state_after) triples.
6 | Maintains unclickable, clickable, and explored element pools, and explores the website using heuristic rules.
7 | """
8 |
9 | import json
10 | import os
11 | import random
12 | import re
13 | import subprocess
14 | import time
15 | import copy
16 | import numpy as np
17 | import uuid
18 | from PIL import Image
19 | import requests
20 | from tqdm import tqdm
21 | from browser_env.utils import (
22 | DetachedPage,
23 | )
24 |
25 |
26 | def get_state(env):
27 | """Get the current state of the environment"""
28 | observation = env._get_obs()
29 | observation_metadata = env._get_obs_metadata()
30 | info = {
31 | "page": DetachedPage(env.page.url, ""),
32 | "fail_error": "",
33 | "observation_metadata": observation_metadata,
34 | }
35 | # Return a deep copy of info to ensure it doesn't update with the environment
36 | return (copy.deepcopy(observation), copy.deepcopy(info))
37 |
38 |
39 | def extract_state_ele(observation_metadata):
40 | # Key each element on the page by its 'union_bound' and its 'text' (with the leading "[id] " stripped) so that pages can be compared
41 | return [(tuple(value.get('union_bound')), re.sub(r'^\[\d+\]\s*', '', value.get('text'))) for value in observation_metadata.values()]
42 |
43 |
44 | def are_screen_identical(screen_before, screen_after):
45 | return np.array_equal(screen_before, screen_after)
46 |
47 |
48 | def load_element_pool(file_path):
49 | """Load element pool"""
50 | if not os.path.exists(file_path):
51 | return set()
52 | else:
53 | with open(file_path, 'r') as f:
54 | data_list = json.load(f)
55 | element_pool = set((tuple(item[0]), item[1]) for item in data_list)
56 | return element_pool
57 |
58 |
59 | def save_element_pool(element_pool, file_path):
60 | """Save element pool"""
61 | with open(file_path, 'w') as f:
62 | json.dump(list(element_pool), f, indent=2)
63 |
64 |
65 | def save_image(image_array, directory):
66 | """Save image and return filename"""
67 | unique_id = str(uuid.uuid4())
68 | image_name = f"{unique_id}.png"
69 | image_path = os.path.join(directory, image_name)
70 | image = Image.fromarray(image_array)
71 | image.save(image_path)
72 | return image_name
73 |
74 |
75 | def select_element(state_elements, actree_obs, unclick_elem_pool, new_elements, explored_elem_pool):
76 | """Randomly select an element based on weights and return corresponding information"""
77 | elements_weights = []
78 | for action_element_id, action_element in state_elements:
79 | element_key = (tuple(action_element['union_bound']), re.sub(r'^\[\d+\]\s*', '', action_element['text']))
80 | # Skip elements that are in the unclickable pool, absent from the current accessibility tree, or StaticText nodes
81 | if element_key in unclick_elem_pool or (f"[{action_element_id}]" not in actree_obs) or ("statictext" in element_key[1].lower()):
82 | continue
83 | # Determine element status
84 | is_new = element_key in new_elements
85 | is_explored = element_key in explored_elem_pool
86 | # Assign weights
87 | if (not is_explored) and is_new:
88 | weight = 4 # Unexplored and newly appeared, highest weight
89 | elif is_explored and is_new:
90 | weight = 3 # Explored but newly appeared, second highest weight
91 | elif (not is_explored) and (not is_new):
92 | weight = 3 # Unexplored but not newly appeared, second highest weight
93 | else:
94 | weight = 1 # Explored and not newly appeared, lowest weight
95 | elements_weights.append((action_element_id, action_element, element_key, weight))
96 |
97 | # If no selectable elements, return None
98 | if len(elements_weights) == 0:
99 | return None
100 |
101 | print(elements_weights)
102 |
103 | # Randomly select an element based on weights
104 | elem_infos = [item[:3] for item in elements_weights]
105 | weights = [item[3] for item in elements_weights]
106 | selected_elem_info = random.choices(elem_infos, weights=weights, k=1)[0]
107 | return selected_elem_info
108 |
109 |
110 | def generate_text_input(screen_text, selected_element_text, max_retries=5):
111 | """Call GPT API to determine if element is an input field and generate input content"""
112 | def parse_response(response_text):
113 | """Parse the model response.
114 |
115 | Returns the JSON object as a dict if it is well-formed and complete, otherwise None.
116 | """
117 | try:
118 | response_text = response_text.strip()
119 |
120 | # Remove code block markers (``` or ```json) if present
121 | if response_text.startswith("```"):
122 | response_text = re.sub(r'^```[a-zA-Z]*\n', '', response_text)
123 | response_text = re.sub(r'\n```$', '', response_text)
124 |
125 | # Extract the JSON object
126 | json_match = re.search(r'\{.*\}', response_text, re.DOTALL)
127 | if not json_match:
128 | return None
129 | json_str = json_match.group(0)
130 | try:
131 | response_json = json.loads(json_str)
132 | except json.JSONDecodeError:
133 | # Fall back to repairing Python-style output (single quotes, True/False). This runs only
134 | # on failure, so apostrophes inside valid JSON strings (e.g. "it's") are left intact.
135 | json_str = json_str.replace("'", '"').replace('True', 'true').replace('False', 'false')
136 | response_json = json.loads(json_str)
137 | # 'is_input' must exist and be a boolean
138 | if not isinstance(response_json, dict) or not isinstance(response_json.get('is_input'), bool):
139 | return None
140 | # If the element is an input field, 'input_content' must exist and be a string
141 | if response_json['is_input'] and not isinstance(response_json.get('input_content'), str):
142 | return None
143 | return response_json
144 | except json.JSONDecodeError as e:
145 | print(f"JSONDecodeError: {e}")
146 | return None
147 |
148 | prompt = f"""
149 | You are an intelligent assistant. Based on the current UI elements and the selected element information, carefully determine whether the selected element is an input field. If it is, generate text content that a user might input into this element. The text should be contextually appropriate. For example, if it's a search box, you might generate a search query; if it's a username input field, you might generate a username; if it's a location, you might try the name of a university, museum, airport, etc. Use your imagination to generate diverse and appropriate input content.
150 |
151 | Current UI elements:
152 | {screen_text}
153 |
154 | Selected element information:
155 | {selected_element_text}
156 |
157 | Please output a JSON object in the following format, without adding any extra text or comments:
158 |
159 | If the selected element is an input field:
160 |
161 | {{
162 | "is_input": true,
163 | "input_content": "the content to input"
164 | }}
165 |
166 | If the selected element is not an input field:
167 |
168 | {{
169 | "is_input": false
170 | }}
171 |
172 | Ensure that the JSON is properly formatted and parsable. Use lowercase `true` or `false` for boolean values, and double quotes for strings.
173 |
174 | If you understand, please provide the JSON object now, without adding any extra text or markers.
175 | """
176 |
177 | headers = {
178 | "Content-Type": "application/json",
179 | "Authorization": f"Bearer {os.environ['OPENAI_API_KEY']}"
180 | }
181 |
182 | payload = {
183 | "model": "gpt-4o-mini-2024-07-18",
184 | "messages": [
185 | {"role": "user", "content": prompt}
186 | ],
187 | "max_tokens": 100,
188 | "temperature": 0.9,
189 | "n": 1,
190 | }
191 |
192 | retries = 0
193 | while retries < max_retries:
194 | try:
195 | # Send POST request to OpenAI API
196 | response = requests.post(
197 | "https://api.openai.com/v1/chat/completions",
198 | headers=headers,
199 | json=payload
200 | )
201 | response.raise_for_status()
202 | result = response.json()
203 | response_text = result['choices'][0]['message']['content'].strip()
204 |
205 | # Validate and parse the response (the raw text may still be wrapped in ``` fences)
206 | response_json = parse_response(response_text)
207 | if response_json is not None:
208 | return response_json
209 | else:
210 | print(f"Invalid response format: '{response_text}'. Retrying...")
211 | retries += 1
212 | time.sleep(1) # Wait a bit before retrying
213 | except Exception as e:
214 | print(f"Error generating text input: {e}")
215 | retries += 1
216 | time.sleep(1)
217 |
218 | # If all retries fail, return default action
219 | print("Failed to get valid response after retries. Assuming it's not an input field.")
220 | return {"is_input": False}
221 |
222 |
223 | def save_trajectory(unique_id_traj, trajectory, traj_dir, screen_dir, website_name, url):
224 | """Save trajectory data"""
225 | traj_save = []
226 | for i, item in enumerate(trajectory):
227 | if isinstance(item, dict):
228 | continue
229 | elif isinstance(item, tuple) and i+1 < len(trajectory):
230 | screen_before_name = save_image(trajectory[i-1]['observation']['image'], screen_dir)
231 | screen_after_name = save_image(trajectory[i+1]['observation']['image'], screen_dir)
232 | step_data = {
233 | "website_name": website_name,
234 | "url": url,
235 | "screen_before": screen_before_name,
236 | "a11y_before": trajectory[i-1]['observation']['text'],
237 | "state_before": trajectory[i-1]['info']['observation_metadata'],
238 | "screen_after": screen_after_name,
239 | "a11y_after": trajectory[i+1]['observation']['text'],
240 | "state_after": trajectory[i+1]['info']['observation_metadata'],
241 | "action": item
242 | }
243 | traj_save.append(step_data)
244 |
245 | traj_filename = f"{unique_id_traj}.json"
246 | traj_path = os.path.join(traj_dir, traj_filename)
247 | with open(traj_path, 'w') as f:
248 | json.dump(traj_save, f, indent=2)
249 |
250 |
251 | def random_walk_episode(config_file, traj_dir, screen_dir, element_dir, config_randomwalk):
252 | """
253 | config_file: WebArena config file specifying the website to initialize
254 | """
255 |
256 | if not os.path.exists(traj_dir):
257 | os.mkdir(traj_dir)
258 | if not os.path.exists(screen_dir):
259 | os.mkdir(screen_dir)
260 | if not os.path.exists(element_dir):
261 | os.mkdir(element_dir)
262 | if not os.path.exists(config_randomwalk):
263 | os.mkdir(config_randomwalk)
264 |
265 | SLEEP = 3
266 | # set the URLs of each website
267 | os.environ[
268 | "SHOPPING"
269 | ] = "http://ec2-18-220-173-105.us-east-2.compute.amazonaws.com:7770"
270 | os.environ[
271 | "SHOPPING_ADMIN"
272 | ] = "http://ec2-18-220-173-105.us-east-2.compute.amazonaws.com:7780/admin"
273 | os.environ[
274 | "REDDIT"
275 | ] = "http://ec2-18-220-173-105.us-east-2.compute.amazonaws.com:9999"
276 | os.environ[
277 | "GITLAB"
278 | ] = "http://ec2-18-220-173-105.us-east-2.compute.amazonaws.com:8023"
279 | os.environ[
280 | "MAP"
281 | ] = "http://ec2-18-220-173-105.us-east-2.compute.amazonaws.com:3000"
282 | os.environ[
283 | "WIKIPEDIA"
284 | ] = "http://ec2-18-220-173-105.us-east-2.compute.amazonaws.com:8888/wikipedia_en_all_maxi_2022-05/A/User:The_other_Kiwix_guy/Landing"
285 | os.environ[
286 | "HOMEPAGE"
287 | ] = "PASS" # The home page is not currently hosted in the demo site
288 | print("Done setting up URLs")
289 |
290 |
291 | # Make sure the configuration file exists
292 | assert os.path.exists(config_file)
293 | with open(config_file, "r") as f:
294 | config = json.load(f)
295 |
296 | # Check which URLs are available for the corresponding website and randomly select one
297 | configs_dir = "/Users/cckevin/Desktop/webarena/config_files" # Local path: point this at your WebArena config_files directory
298 | website_url = dict()
299 | for config_item in os.listdir(configs_dir):
300 | if config_item == "examples":
301 | continue
302 | config_path = os.path.join(configs_dir, config_item)
303 | with open(config_path, 'r') as f: config_item = json.load(f)
304 | if not isinstance(config_item, dict):
305 | continue
306 | try:
307 | if len(config_item["sites"]) != 1:
308 | continue
309 | website_name = config_item["sites"][0]
310 | if website_name not in website_url:
311 | website_url[website_name] = set()
312 |
313 | start_url = config_item["start_url"]
314 | website_url[website_name].add(start_url)
315 | except Exception as e:
316 | print(f"An error occurred: {e}")
317 | print(config_item)
318 |
319 | web_urls = website_url[config["sites"][0]]
320 | random_url = random.choice(list(web_urls))
321 | print(f"Random URL: {random_url}")
322 | config["start_url"] = random_url
323 | # Save config
324 | unique_id_traj = str(uuid.uuid4())
325 | config_randomwalk_path = os.path.join(config_randomwalk, f"{unique_id_traj}.json")
326 | with open(config_randomwalk_path, 'w') as f: json.dump(config, f) # Close (and flush) the file before env.reset reads it
327 |
328 | """
329 | # run bash prepare.sh to save all account cookies, this only needs to be done once
330 | subprocess.run(["bash", "prepare.sh"])
331 | print("Done saving account cookies")
332 | """
333 |
334 | # Init an environment
335 | from browser_env import (
336 | Action,
337 | ActionTypes,
338 | ObservationMetadata,
339 | ScriptBrowserEnv,
340 | StateInfo,
341 | Trajectory,
342 | action2str,
343 | create_id_based_action,
344 | create_stop_action,
345 | )
346 | from evaluation_harness.evaluators import evaluator_router
347 |
348 | # maintain a trajectory
349 | trajectory: Trajectory = []
350 |
351 | # Maintain element pools
352 | website_name = config["sites"][0]
353 | unclick_elem_pool_path = os.path.join(element_dir, website_name + "_unclick.json")
354 | click_elem_pool_path = os.path.join(element_dir, website_name + "_click.json")
355 | explored_elem_pool_path = os.path.join(element_dir, website_name + "_explored.json")
356 |
357 | unclick_elem_pool = load_element_pool(unclick_elem_pool_path)
358 | click_elem_pool = load_element_pool(click_elem_pool_path)
359 | explored_elem_pool = load_element_pool(explored_elem_pool_path)
360 |
361 | env = None
362 |
363 | try:
364 | # Init the environment
365 | env = ScriptBrowserEnv(
366 | headless=False,
367 | slow_mo=100,
368 | observation_type="accessibility_tree",
369 | current_viewport_only=True,
370 | viewport_size={"width": 1280, "height": 720},
371 | )
372 |
373 | # set the environment for the current example (website)
374 | env.reset(options={"config_file": config_randomwalk_path})
375 | obs, info = get_state(env)
376 | state_info: StateInfo = {"observation": obs, "info": info} # Save each interface state using StateInfo
377 | trajectory.append(state_info) # Record initial interface
378 |
379 | # Elements of the previous interface, used to detect newly appeared elements
380 | prev_state_elements = set()
381 |
382 | # Number of random-walk steps per episode
383 | num_step = 5
384 |
385 | for i in range(num_step):
386 |
387 | # Get current interface state
388 | obs_before, info_before = get_state(env)
389 | actree_obs = obs_before["text"]
390 | print(actree_obs)
391 |
392 | # Get candidate interactive elements from current interface
393 | state_elements = list(info_before['observation_metadata']['text']['obs_nodes_info'].items())
394 |
395 | # Extract elements from current interface
396 | current_state_elements = set(extract_state_ele(info_before['observation_metadata']['text']['obs_nodes_info']))
397 | # Find newly appeared elements
398 | new_elements = current_state_elements - prev_state_elements
399 |
400 | # Select element
401 | selected_elem_info = select_element(
402 | state_elements,
403 | actree_obs,
404 | unclick_elem_pool,
405 | new_elements,
406 | explored_elem_pool
407 | )
408 |
409 | # If no selectable elements, skip this iteration
410 | if selected_elem_info is None:
411 | print("No clickable elements found.")
412 | break
413 |
414 | action_element_id, action_element, element_key = selected_elem_info
415 |
416 | # Ask GPT whether the selected element accepts text input, then choose the matching action (type or click)
417 | gpt_response = generate_text_input(actree_obs, action_element['text'])
418 | if gpt_response.get('is_input'):
419 | type_content = gpt_response.get('input_content', 'Test Input')
420 | next_action_str = f"type [{action_element_id}] [{type_content}]"
421 | next_action = create_id_based_action(next_action_str)
422 | else:
423 | next_action_str = f"click [{action_element_id}]"
424 | next_action = create_id_based_action(next_action_str)
425 |
426 | print(f"Step {i}: {next_action_str}")
427 |
428 | # Execute action
429 | env.step(next_action)
430 | time.sleep(SLEEP)
431 |
432 | # Add element to explored set
433 | explored_elem_pool.add(element_key)
434 |
435 | # Get execution result
436 | obs_after, info_after = get_state(env)
437 | actree_obs = obs_after["text"]
438 |
439 | # Determine if action caused interface change through screenshot comparison
440 | if are_screen_identical(obs_before['image'], obs_after['image']):
441 | # If the interface is unchanged, add the element to the unclickable pool (unless it is already known to be clickable)
442 | print("The pages are identical. Added to unclick_elem_pool.")
443 | if element_key not in click_elem_pool: # If not in clickable set
444 | unclick_elem_pool.add(element_key)
445 | else:
446 | # Interface changed, record action and subsequent interface in trajectory
447 | print("The pages have differences.")
448 | click_elem_pool.add(element_key)
449 | trajectory.append((next_action_str, action_element_id, action_element))
450 | state_info = {"observation": obs_after, "info": info_after}
451 | trajectory.append(state_info)
452 |
453 | # Update previous interface element set, regardless of whether page changed
454 | prev_state_elements = current_state_elements
455 |
456 | except Exception as e:
457 | print(f"An error occurred: {e}")
458 | time.sleep(1)
459 | # The environment is closed in the finally block below; closing it here as well would close it twice
460 | time.sleep(1)
461 | finally:
462 | if env is not None:
463 | env.close()
464 | time.sleep(1)
465 |
466 | # Record element pools
467 | save_element_pool(unclick_elem_pool, unclick_elem_pool_path)
468 | save_element_pool(click_elem_pool, click_elem_pool_path)
469 | save_element_pool(explored_elem_pool, explored_elem_pool_path)
470 |
471 | # Save trajectory data
472 | save_trajectory(unique_id_traj, trajectory, traj_dir, screen_dir, website_name, random_url)
473 |
474 | return trajectory
475 |
476 |
477 | num_episode = 50
478 | config_file = "config_files/60.json" # Choose the corresponding config file based on the website to be sampled
479 | config_randomwalk = "/Users/cckevin/Desktop/config_randomwalk"
480 | traj_dir = "/Users/cckevin/Desktop/traj_results"
481 | screen_dir = "/Users/cckevin/Desktop/screen_results"
482 | element_dir = "/Users/cckevin/Desktop/click_elements"
483 |
484 | for i in tqdm(range(num_episode)):
485 | print(f"Episode {i} random walk")
486 | try:
487 | random_walk_episode(config_file, traj_dir, screen_dir, element_dir, config_randomwalk)
488 | except Exception as e:
489 | print(f"An error occurred: {e}")
490 |
--------------------------------------------------------------------------------
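The element-pool bookkeeping in `random_walk_episode` can be restated as two small pure functions, which makes the rules easy to check in isolation. This is a minimal sketch: `new_elements` and `update_pools` are hypothetical helpers, and the string keys in the usage stand in for whatever keys `extract_state_ele` and `select_element` actually produce.

```python
from typing import Hashable, Set


def new_elements(current: Set[Hashable], previous: Set[Hashable]) -> Set[Hashable]:
    """Elements present on the current screen but not on the previous one."""
    return current - previous


def update_pools(
    element_key: Hashable,
    screen_changed: bool,
    click_pool: Set[Hashable],
    unclick_pool: Set[Hashable],
    explored_pool: Set[Hashable],
) -> None:
    """Mirror the pool updates made after each random-walk step.

    Every acted-on element is marked as explored. If the screen changed, the
    element is recorded as clickable; otherwise it is recorded as unclickable,
    unless an earlier step already showed that it changes the screen.
    """
    explored_pool.add(element_key)
    if screen_changed:
        click_pool.add(element_key)
    elif element_key not in click_pool:
        unclick_pool.add(element_key)
```

As in the original loop, an element that is already in the unclickable pool is not removed when a later step shows it changes the screen; it simply ends up in both pools.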
/collection/run_mobile_runner.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 |
4 |
5 | def run_gpt_task():
6 | os.system("python mobile_runner.py")
7 |
8 | while True:
9 | run_gpt_task()
10 | time.sleep(20)
--------------------------------------------------------------------------------
/collection/run_random_walk_aw.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 |
4 |
5 | def run_random_walk():
6 | os.system("python random_walk_aw.py")
7 |
8 | while True:
9 | run_random_walk()
10 | time.sleep(20)
--------------------------------------------------------------------------------
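Both runner scripts restart a collection script indefinitely with `os.system`. That approach ignores the exit status and relies on whichever `python` happens to be on `PATH`. The sketch below is one possible alternative that uses `subprocess.run` with `sys.executable`. `run_repeatedly` and its `max_runs` parameter are hypothetical additions, and `max_runs` exists only so the loop can be bounded, for example in tests.

```python
import subprocess
import sys
import time
from typing import List, Optional


def run_repeatedly(script_args: List[str], interval_s: float = 20.0,
                   max_runs: Optional[int] = None) -> List[int]:
    """Run a Python script repeatedly with the current interpreter.

    Pauses interval_s seconds between runs and returns the exit codes of the
    completed runs. With max_runs=None the loop never ends, like the
    original while-True loop.
    """
    exit_codes: List[int] = []
    while max_runs is None or len(exit_codes) < max_runs:
        result = subprocess.run([sys.executable, *script_args])
        exit_codes.append(result.returncode)
        # Skip the pause after the final bounded run.
        if max_runs is None or len(exit_codes) < max_runs:
            time.sleep(interval_s)
    return exit_codes
```

For example, `run_repeatedly(["random_walk_aw.py"])` behaves like `run_random_walk_aw.py`. Passing `max_runs` bounds the loop and returns the collected exit codes.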
/evaluation/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OS-Copilot/OS-Genesis/9ecbe594352f254b9a9228468f9ca9a77b2388a2/evaluation/.DS_Store
--------------------------------------------------------------------------------
/evaluation/android_control/ac_eval.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import json
3 | from collections import defaultdict
4 | import argparse
5 | import re
6 | import math
7 | import os
8 | import ast
9 | import logging
10 | import openpyxl
11 |
12 | logger = logging.getLogger(__name__)
13 | logger.setLevel(logging.INFO)
14 |
15 | TYPE="high"
16 | MODEL="qwen2vl"
17 | PRED_FILE_PATH=f"results/{MODEL}/predictions.jsonl"
18 | def get_in_domain_ids(jsonl_path):
19 | in_domain_ids = []
20 | with open(jsonl_path, 'r') as f:
21 | for line in f:
22 | data = json.loads(line)
23 | in_domain_ids.append(data["id"])
24 | return in_domain_ids
25 |
26 |
27 | def extract_json(s):
28 | match = re.search(r'\{.*?\}', s)
29 | if match:
30 | json_str = match.group(0)
31 | try:
32 | return json.loads(json_str)
33 | except json.JSONDecodeError:
34 | return json_str # Fall back to the raw matched string
35 | return None
36 |
37 |
38 | def extract_dict_from_string(text):
39 | start_index = text.find("Accessibility tree: ")
40 | if start_index == -1:
41 | return None # "Accessibility tree: " not found
42 |
43 | # Locate the start of the dictionary
44 | dict_start = start_index + len("Accessibility tree: ")
45 |
46 | # Take the substring that follows the marker
47 | dict_str = text[dict_start:]
48 |
49 | # Find the first '{' and its matching '}'
50 | open_brace_index = dict_str.find('{')
51 | if open_brace_index == -1:
52 | return None # No '{' found
53 |
54 | stack = []
55 | for i, char in enumerate(dict_str[open_brace_index:], start=open_brace_index):
56 | if char == '{':
57 | stack.append(char)
58 | elif char == '}':
59 | stack.pop()
60 | if not stack:
61 | dict_end = i + 1
62 | break
63 | else:
64 | return None # No matching '}' found
65 |
66 | # Extract the dictionary substring
67 | dict_substr = dict_str[open_brace_index:dict_end]
68 |
69 | return dict_substr
70 |
71 | def get_key_by_position(text, x, y):
72 | position = f"({x}, {y})"
73 | dict_str = extract_dict_from_string(text)
74 | if dict_str is None:
75 | return None # No dictionary found; check before parsing, since json.loads(None) raises TypeError
76 | extracted_dict = json.loads(dict_str)
77 | for key in extracted_dict:
78 | if extracted_dict[key] == position:
79 | return key
80 |
81 | return None # No key at the given position
82 |
83 | def calculate_f1_score(predicted_str, ground_truth_str):
84 | predicted_tokens = set(predicted_str.lower().split())
85 | ground_truth_tokens = set(ground_truth_str.lower().split())
86 |
87 | common_tokens = predicted_tokens.intersection(ground_truth_tokens)
88 | if len(predicted_tokens) == 0:
89 | precision = 0
90 | else:
91 | precision = len(common_tokens) / len(predicted_tokens)
92 | if len(ground_truth_tokens) == 0:
93 | recall = 0
94 | else:
95 | recall = len(common_tokens) / len(ground_truth_tokens)
96 |
97 | if precision + recall == 0:
98 | f1_score = 0
99 | else:
100 | f1_score = 2 * (precision * recall) / (precision + recall)
101 | return f1_score
102 |
103 |
104 | def evaluate(args):
105 | prediction_file_path = args.prediction_file_path
106 | prediction = []
107 | with open(prediction_file_path) as file:
108 | for line in file:
109 | prediction.append(json.loads(line))
110 |
111 | ground_truth = []
112 | with open("eval_json_files/android_control_test_data.jsonl") as file:
113 | for line in file:
114 | ground_truth.append(json.loads(line))
115 |
116 | with open(f"eval_json_files/android_control_test_subsplits.json","r") as file:
117 | test_subsplits = json.load(file)
118 | print(test_subsplits.keys())
119 | print(len(ground_truth))
120 |
121 |
122 | # ======================================================================== #
123 | # Results on Low-level
124 | # ======================================================================== #
125 | mis_click_wait_num = 0
126 | step_acc_res_dict = defaultdict(int)
127 | sample_number_dict = defaultdict(int)
128 | for pred, gt in zip(prediction, ground_truth):
129 | gt_action = json.loads(gt["conversations"][1]["value"].split("actions:\n")[1])
130 | episode_id = int(pred["image_id"].split("/")[-1].split("_")[1]) # parse out the episode index
131 | subsplit_type = next((category for category, ids in test_subsplits.items() if episode_id in ids), None)
132 | gt_action_type = gt_action["action_type"]
133 | sample_number_dict[subsplit_type+"_LL"] += 1
134 | sample_number_dict["full_LL"] += 1
135 | sample_number_dict["Type_LL"] += 1
136 | sample_number_dict[gt_action_type+"_LL"] += 1
137 |
138 | if len(pred["pred"].split("action: "))==2:
139 | try:
140 | pred_action = json.loads(pred["pred"].split("action: ")[1])
141 | if len(pred_action) == 0:
142 | continue
143 | except json.JSONDecodeError as e:
144 | continue
145 | else:
146 | pred_action = extract_json(pred["pred"])
147 | if pred_action is None:
148 | continue
149 | try:
150 | pred_action_type = pred_action["action_type"]
151 | except Exception as e:
152 | continue
153 |
154 | # calculate step acc based on types
155 | if gt_action_type==pred_action_type or (gt_action_type == "type" and pred_action_type == "input_text"):
156 | step_acc_res_dict["Type_LL"] += 1
157 | step_acc_res_dict[gt_action_type+"_type_match_LL"] += 1
158 | if gt_action_type in ["click","long_press"]: # evaluate click type
159 | try:
160 | pred_x, pred_y = int(pred_action["x"]), int(pred_action["y"])
161 | except (KeyError, TypeError, ValueError):
162 | pred_x, pred_y = -100, -100
163 | gt_x, gt_y = int(gt_action["x"]), int(gt_action["y"])
164 |
165 | if math.sqrt((pred_x - gt_x)**2 + (pred_y - gt_y)**2) <= math.sqrt((1080*0.14)**2 + (2400*0.14)**2): # threshold: 14% of the 1080x2400 screen diagonal
166 | step_acc_res_dict[subsplit_type+"_LL"] += 1
167 | step_acc_res_dict["full_LL"] += 1
168 | step_acc_res_dict[gt_action_type+"_all_match_LL"] += 1
169 |
170 |
171 | elif gt_action_type == "type" and pred_action_type in ["input_text", "type"]:
172 | if gt_action["text"]==pred_action["text"] or calculate_f1_score(pred_action["text"], gt_action["text"])>0.5:
173 | step_acc_res_dict[subsplit_type+"_LL"] += 1
174 | step_acc_res_dict["full_LL"] += 1
175 | step_acc_res_dict[gt_action_type+"_all_match_LL"] += 1
176 |
177 | elif gt_action_type == "scroll":
178 | if "Scroll up" in pred["prompt"] and pred_action["direction"] == "up":
179 | step_acc_res_dict[subsplit_type+"_LL"] += 1
180 | step_acc_res_dict["full_LL"] += 1
181 | step_acc_res_dict[gt_action_type+"_all_match_LL"] += 1
182 | elif "Scroll down" in pred["prompt"] and pred_action["direction"] == "down":
183 | step_acc_res_dict[subsplit_type+"_LL"] += 1
184 | step_acc_res_dict["full_LL"] += 1
185 | step_acc_res_dict[gt_action_type+"_all_match_LL"] += 1
186 | elif pred_action==gt_action:
187 | step_acc_res_dict[subsplit_type+"_LL"] += 1
188 | step_acc_res_dict["full_LL"] += 1
189 | step_acc_res_dict[gt_action_type+"_all_match_LL"] += 1
190 |
191 | elif gt_action==pred_action: # evaluate other types
192 | step_acc_res_dict[subsplit_type+"_LL"] += 1
193 | step_acc_res_dict["full_LL"] += 1
194 | step_acc_res_dict[gt_action_type+"_all_match_LL"] += 1
195 |
196 |
197 | # Print the low-level results
198 | logger.info("="*30 + " Step Acc " + "="*30)
199 | logger.info("Full-LL: %f\n" % (step_acc_res_dict["full_LL"] / sample_number_dict["full_LL"]))
200 | logger.info("Type-LL: %f\n" % (step_acc_res_dict["Type_LL"] / sample_number_dict["Type_LL"]))
201 | # Save the results to an Excel file
202 | SR_acc = round((step_acc_res_dict["full_LL"] / sample_number_dict["full_LL"]) * 100, 2)
203 | Type_acc = round((step_acc_res_dict["Type_LL"] / sample_number_dict["Type_LL"]) * 100, 2)
204 | # Open the Excel file (it must already exist)
205 | file_path = "android_control_eval.xlsx"
206 | wb = openpyxl.load_workbook(file_path)
207 |
208 | # Select the active worksheet
209 | sheet = wb.active
210 |
211 | # Find the next empty row
212 | next_row = sheet.max_row + 1
213 |
214 | # Write the data
215 | sheet.cell(row=next_row, column=1, value=MODEL)
216 | sheet.cell(row=next_row, column=2, value=SR_acc)
217 | sheet.cell(row=next_row, column=3, value=Type_acc)
218 |
219 | # Save the file
220 | wb.save(file_path)
221 |
222 | logger.info("IDD-LL: %f\n" % (step_acc_res_dict["IDD_LL"] / sample_number_dict["IDD_LL"]))
223 | logger.info("app_unseen-LL: %f\n" % (step_acc_res_dict["app_unseen_LL"] / sample_number_dict["app_unseen_LL"]))
224 | logger.info("task_unseen-LL: %f\n" % (step_acc_res_dict["task_unseen_LL"] / sample_number_dict["task_unseen_LL"]))
225 | logger.info("category_unseen-LL: %f\n" % (step_acc_res_dict["category_unseen_LL"] / sample_number_dict["category_unseen_LL"]))
226 | logger.info("="*30 + " Detailed Acc of Each Type " + "="*30)
227 | for action_type in sample_number_dict:
228 | action_type = action_type.split("_LL")[0]
229 | if action_type not in ["full","Type","IDD","app_unseen","task_unseen","category_unseen"]:
230 | logger.info(f"{action_type}_all_match-LL: %f\n" % (step_acc_res_dict[f"{action_type}_all_match_LL"] / sample_number_dict[f"{action_type}_LL"]))
231 | logger.info(f"{action_type}_type_match-LL: %f\n" % (step_acc_res_dict[f"{action_type}_type_match_LL"] / sample_number_dict[f"{action_type}_LL"]))
232 |
233 |
234 |
235 | if __name__ == '__main__':
236 | parser = argparse.ArgumentParser()
237 |
238 | parser.add_argument('--prediction_file_path', type=str, default=PRED_FILE_PATH)
239 | parser.add_argument('--datasets', type=str, default='')
240 | parser.add_argument('--output_path', type=str, default='results/score/')
241 | parser.add_argument('--eval_HH', action='store_true')
242 | parser.add_argument('--seed', type=int, default=0)
243 | args = parser.parse_args()
244 |
245 | if not os.path.exists(args.output_path):
246 | os.makedirs(args.output_path)
247 |
248 | file_handler = logging.FileHandler(args.output_path + f"ac_{TYPE}_score_{MODEL}.log", mode="w")
249 | file_handler.setLevel(logging.INFO)
250 |
251 | console_handler = logging.StreamHandler()
252 | console_handler.setLevel(logging.INFO)
253 |
254 | logger.addHandler(file_handler)
255 | logger.addHandler(console_handler)
256 |
257 | evaluate(args)
--------------------------------------------------------------------------------
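`evaluate` counts a click or long press as correct when the predicted point lies within 14% of the screen diagonal of the target. For the hard-coded 1080x2400 screen, that works out to sqrt(151.2² + 336²), roughly 368 px. A typed action counts as correct when its text matches exactly or its token-level F1 exceeds 0.5. The sketch below expresses both rules as standalone functions. `click_matches` and `token_f1` are hypothetical names, and the default screen size copies the script's constants, which may not match every device.

```python
import math
from typing import Sequence


def click_matches(pred_xy: Sequence[float], gt_xy: Sequence[float],
                  width: int = 1080, height: int = 2400,
                  ratio: float = 0.14) -> bool:
    """True if the predicted point is within ratio * screen diagonal of the target."""
    threshold = math.hypot(width * ratio, height * ratio)
    return math.dist(pred_xy, gt_xy) <= threshold


def token_f1(predicted: str, ground_truth: str) -> float:
    """Set-based token F1 over lower-cased, whitespace-split text."""
    pred_tokens = set(predicted.lower().split())
    gt_tokens = set(ground_truth.lower().split())
    common = pred_tokens & gt_tokens
    # No overlap (including empty inputs) means precision + recall == 0.
    if not common:
        return 0.0
    precision = len(common) / len(pred_tokens)
    recall = len(common) / len(gt_tokens)
    return 2 * precision * recall / (precision + recall)
```

For example, `token_f1("open settings app", "open the settings")` shares two of three tokens on each side, so precision and recall are both 2/3 and the prediction passes the 0.5 threshold.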
/evaluation/android_control/internvl2_inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import itertools
3 | import json
4 | import os
5 | import random
6 | import time
7 | from functools import partial
8 | import torch
9 | from internvl.model.internvl_chat import InternVLChatModel
10 | from internvl.train.dataset import build_transform, dynamic_preprocess
11 | from PIL import Image
12 | from tqdm import tqdm
13 | from transformers import AutoTokenizer
14 | import logging
15 |
16 | logger = logging.getLogger(__name__)
17 | logger.setLevel(logging.INFO)
18 |
19 | ds_collections = {
20 | 'ac_high': {
21 | 'root': '/path/to/images',
22 | 'annotation': 'eval_json_files/ac_high_processing.jsonl',
23 | 'max_new_tokens': 999,
24 | 'min_new_tokens': 1,
25 | },
26 | 'ac_low': {
27 | 'root': '/path/to/images',
28 | 'annotation': 'eval_json_files/ac_low_processing.jsonl',
29 | 'max_new_tokens': 999,
30 | 'min_new_tokens': 1,
31 | }
32 |
33 | }
34 |
35 |
36 | def collate_fn(batches, tokenizer):
37 | pixel_values = torch.cat([_['pixel_values'] for _ in batches], dim=0)
38 | questions = [_['question'] for _ in batches]
39 | items = [_['item'] for _ in batches]
40 | return pixel_values, questions, items
41 |
42 |
43 | class AndroidControlDataset(torch.utils.data.Dataset):
44 |
45 | def __init__(self, root, annotation, input_size=224, dynamic_image_size=False,
46 | use_thumbnail=False, max_num=6):
47 | self.root = root
48 | self.items = []
49 | f = open(annotation)
50 | data = f.readlines()
51 | for data_line in data:
52 | data_line = json.loads(data_line)
53 | self.items.append(data_line)
54 | self.input_size = input_size # side length (pixels) of each square image tile fed to the vision encoder
55 | self.dynamic_image_size = dynamic_image_size
56 | self.use_thumbnail = use_thumbnail
57 | self.max_num = max_num
58 | self.transform = build_transform(is_train=False, input_size=input_size)
59 |
60 | def __len__(self):
61 | return len(self.items)
62 |
63 | def __getitem__(self, idx):
64 | item = self.items[idx]
65 | image_path, question = item['image'], item['conversations'][0]['value']
66 | image = Image.open(image_path).convert('RGB')
67 | if self.dynamic_image_size:
68 | images = dynamic_preprocess(image, image_size=self.input_size,
69 | use_thumbnail=self.use_thumbnail,
70 | max_num=self.max_num)
71 | else:
72 | images = [image]
73 | pixel_values = [self.transform(image) for image in images]
74 | pixel_values = torch.stack(pixel_values)
75 |
76 | return {
77 | 'question': question,
78 | 'pixel_values': pixel_values,
79 | 'item': item,
80 | }
81 |
82 |
83 |
84 | class InferenceSampler(torch.utils.data.sampler.Sampler):
85 |
86 | def __init__(self, size):
87 | self._size = int(size)
88 | assert size > 0
89 | self._rank = torch.distributed.get_rank()
90 | self._world_size = torch.distributed.get_world_size()
91 | self._local_indices = self._get_local_indices(size, self._world_size, self._rank)
92 |
93 | @staticmethod
94 | def _get_local_indices(total_size, world_size, rank):
95 | shard_size = total_size // world_size
96 | left = total_size % world_size
97 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
98 |
99 | begin = sum(shard_sizes[:rank])
100 | end = min(sum(shard_sizes[:rank + 1]), total_size)
101 | return range(begin, end)
102 |
103 | def __iter__(self):
104 | yield from self._local_indices
105 |
106 | def __len__(self):
107 | return len(self._local_indices)
108 |
109 |
110 |
111 | def evaluate_chat_model():
112 | random.seed(args.seed)
113 |
114 | if args.ds_name_list is None:
115 | ds_names = ds_collections.keys()
116 | else:
117 | ds_names = args.ds_name_list
118 | for ds_name in ds_names:
119 | dataset = AndroidControlDataset(
120 | root=ds_collections[ds_name]['root'],
121 | annotation=ds_collections[ds_name]['annotation'],
122 | input_size=image_size,
123 | dynamic_image_size=args.dynamic,
124 | use_thumbnail=use_thumbnail,
125 | max_num=args.max_num
126 | )
127 | dataloader = torch.utils.data.DataLoader(
128 | dataset=dataset,
129 | sampler=InferenceSampler(len(dataset)),
130 | batch_size=args.batch_size,
131 | num_workers=args.num_workers,
132 | pin_memory=True,
133 | drop_last=False,
134 | collate_fn=partial(collate_fn, tokenizer=tokenizer),
135 | )
136 |
137 |
138 | logger.info(f'Evaluating {ds_name} ...')
139 |
140 | outputs = []
141 | for _, (pixel_values, questions, items) in tqdm(enumerate(dataloader)):
142 | pixel_values = pixel_values.to(torch.bfloat16).cuda()
143 | generation_config = dict(
144 | num_beams=args.num_beams,
145 | max_new_tokens=ds_collections[ds_name]['max_new_tokens'],
146 | min_new_tokens=ds_collections[ds_name]['min_new_tokens'],
147 | do_sample=True if args.temperature > 0 else False,
148 | temperature=args.temperature,
149 | )
150 | pred = model.chat(
151 | tokenizer=tokenizer,
152 | pixel_values=pixel_values,
153 | question=questions[0],
154 | generation_config=generation_config
155 | )
156 | preds = [pred]
157 |
158 | for question, answer, item in zip(questions, preds, items):
159 | question_id = item['id']
160 | image_id = os.path.basename(item['image']).replace(".png", "")
161 | text = question
162 | output = {
163 | 'question_id': question_id,
164 | 'image_id': image_id,
165 | 'prompt': text,
166 | 'pred': answer,
167 | 'model_id': model_id,
168 | 'metadata': {},
169 | "task_name": "task2action",
170 | "string_format": "CSV_String",
171 | "position_format": "related"
172 | }
173 | outputs.append(output)
174 |
175 | torch.distributed.barrier()
176 | world_size = torch.distributed.get_world_size()
177 | merged_outputs = [None for _ in range(world_size)]
178 | torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))
179 |
180 | merged_outputs = [json.loads(_) for _ in merged_outputs]
181 | merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)]
182 |
183 | torch.distributed.barrier()
184 |
185 | if torch.distributed.get_rank() == 0:
186 | print(len(merged_outputs))
187 | results_file = f'{model_id}_{ds_name}.jsonl'
188 | results_file = os.path.join(args.out_dir, results_file)
189 | with open(results_file, 'w') as f:
190 | for output in merged_outputs:
191 | json.dump(output, f)
192 | f.write('\n')
193 | print('Results saved to {}'.format(results_file))
194 |
195 |
196 | if __name__ == '__main__':
197 | parser = argparse.ArgumentParser()
198 | parser.add_argument('--checkpoint', type=str, default='')
199 | parser.add_argument('--datasets', type=str, default='')
200 | parser.add_argument('--batch-size', type=int, default=1)
201 | parser.add_argument('--num-workers', type=int, default=1)
202 | parser.add_argument('--num-beams', type=int, default=1)
203 | parser.add_argument('--temperature', type=float, default=0.0)
204 | parser.add_argument('--out-dir', type=str, default='results')
205 | parser.add_argument('--seed', type=int, default=0)
206 | parser.add_argument('--dynamic', action='store_true')
207 | parser.add_argument('--max-num', type=int, default=6)
208 | parser.add_argument('--load-in-8bit', action='store_true')
209 | parser.add_argument('--auto', action='store_true')
210 | parser.add_argument('--ds_name_list', type=str, nargs='*', default=None, help='List of dataset names')
211 | args = parser.parse_args()
212 |
213 | args.datasets = args.datasets.split(',')
214 | logger.info('datasets: %s', args.datasets)
215 | assert args.batch_size == 1, 'Only batch size 1 is supported'
216 |
217 | torch.distributed.init_process_group(
218 | backend='nccl',
219 | world_size=int(os.getenv('WORLD_SIZE', '1')),
220 | rank=int(os.getenv('RANK', '0')),
221 | )
222 |
223 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))
224 |
225 |
226 | if args.auto:
227 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
228 | kwargs = {'device_map': 'auto'} if args.auto else {}
229 | tokenizer = AutoTokenizer.from_pretrained(args.checkpoint, trust_remote_code=True, use_fast=False)
230 | model = InternVLChatModel.from_pretrained(
231 | args.checkpoint, low_cpu_mem_usage=True, torch_dtype=torch.bfloat16,
232 | load_in_8bit=args.load_in_8bit, **kwargs).eval()
233 | if not args.load_in_8bit and not args.auto:
234 | model = model.cuda()
235 | image_size = model.config.force_image_size or model.config.vision_config.image_size
236 | use_thumbnail = model.config.use_thumbnail
237 |
238 | total_params = sum(p.numel() for p in model.parameters()) / 1e9
239 | if total_params > 20 or args.dynamic:
240 | args.num_beams = 1
241 | logger.info(f'[test] total_params: {total_params}B, use num_beams: {args.num_beams}')
242 | else:
243 | logger.info(f'[test] total_params: {total_params}B')
244 | logger.info(f'[test] image_size: {image_size}')
245 | logger.info(f'[test] template: {model.config.template}')
246 | logger.info(f'[test] dynamic_image_size: {args.dynamic}')
247 | logger.info(f'[test] use_thumbnail: {use_thumbnail}')
248 | logger.info(f'[test] max_num: {args.max_num}')
249 | if torch.distributed.get_rank() == 0:
250 | if not os.path.exists(args.out_dir):
251 | os.makedirs(args.out_dir)
252 |
253 | model_id = '_'.join(args.checkpoint.split('/')[-1:])
254 | evaluate_chat_model()
--------------------------------------------------------------------------------
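`InferenceSampler._get_local_indices` gives each rank one contiguous slice of the evaluation set, so every GPU generates answers for a disjoint subset. `all_gather_object` then returns the per-rank lists in rank order, so the merged output keeps the original dataset order. The sketch below reproduces the same arithmetic without the `torch.distributed` dependency. `shard_ranges` is a hypothetical helper, not part of the repository.

```python
from typing import List


def shard_ranges(total_size: int, world_size: int) -> List[range]:
    """Split range(total_size) into world_size contiguous, near-equal shards.

    The first (total_size % world_size) ranks each receive one extra item.
    """
    base, extra = divmod(total_size, world_size)
    ranges: List[range] = []
    begin = 0
    for rank in range(world_size):
        size = base + (1 if rank < extra else 0)
        ranges.append(range(begin, begin + size))
        begin += size
    return ranges
```

With 10 samples on 3 GPUs, rank 0 handles indices 0-3, rank 1 handles 4-6, and rank 2 handles 7-9.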
/evaluation/android_control/qwen2vl_inference.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import itertools
3 | import json
4 | import os
5 | import random
6 | import time
7 | from functools import partial
8 |
9 | import torch
10 | # from internvl.model.internvl_chat import InternVLChatModel
11 | # from internvl.train.dataset import build_transform, dynamic_preprocess
12 | from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
13 | from qwen_vl_utils import process_vision_info
14 | from PIL import Image
15 | from tqdm import tqdm
16 | from transformers import AutoTokenizer
17 | import logging
18 |
19 | logger = logging.getLogger(__name__)
20 | logger.setLevel(logging.INFO)
21 | max_pixels = 1024 * 1024
22 |
23 | import torch.multiprocessing as mp
24 | mp.set_start_method('spawn', force=True)
25 |
26 | ds_collections = {
27 | 'ac_high': {
28 | 'root': '/path/to/images',
29 | 'annotation': 'eval_json_files/ac_high_processing.jsonl',
30 | 'max_new_tokens': 999,
31 | 'min_new_tokens': 1,
32 | },
33 | 'ac_low': {
34 | 'root': '/path/to/images',
35 | 'annotation': 'eval_json_files/ac_low_processing.jsonl',
36 | 'max_new_tokens': 999,
37 | 'min_new_tokens': 1,
38 | }
39 |
40 | }
41 |
42 |
43 | def collate_fn(batches):
44 | inputs = [_['inputs'] for _ in batches]
45 | questions = [_['question'] for _ in batches]
46 | items = [_['item'] for _ in batches]
47 | return inputs, questions, items
48 |
49 |
50 |
51 | class AndroidControlDataset(torch.utils.data.Dataset):
52 |
53 | def __init__(self, root, annotation, input_size=224, dynamic_image_size=False,
54 | use_thumbnail=False, max_num=6, device="cuda:0"):
55 | self.root = root
56 | self.items = []
57 | f = open(annotation)
58 | data = f.readlines()
59 | for data_line in data:
60 | data_line = json.loads(data_line)
61 | self.items.append(data_line)
62 | self.input_size = input_size # not used by the Qwen2-VL processor; kept for parity with the InternVL dataset
63 | self.dynamic_image_size = dynamic_image_size
64 | self.use_thumbnail = use_thumbnail
65 | self.max_num = max_num
66 | self.device = device
67 | max_pixels = 1024 * 1024
68 | self.processor = AutoProcessor.from_pretrained("/nas/shared/NLP_A100/wuzhenyu/LLMs/Qwen2-VL-7B-Instruct", max_pixels=max_pixels)
69 | def __len__(self):
70 | return len(self.items)
71 |
72 | def __getitem__(self, idx):
73 | item = self.items[idx]
74 | image_path, question = item['image'], item['conversations'][0]['value']
75 | messages = [
76 | {
77 | "role": "user",
78 | "content": [
79 | {"type": "text", "text": question.split("<image>")[0]},
80 | {
81 | "type": "image",
82 | "image": image_path,
83 | },
84 | {"type": "text", "text": question.split("<image>")[1]},
85 | ],
86 | }
87 | ]
88 | text = self.processor.apply_chat_template(
89 | messages, tokenize=False, add_generation_prompt=True
90 | )
91 | image_inputs, video_inputs = process_vision_info(messages)
92 | inputs = self.processor(
93 | text=[text],
94 | images=image_inputs,
95 | videos=video_inputs,
96 | padding=True,
97 | return_tensors="pt",
98 | )
99 | return {
100 | 'question': question,
101 | 'inputs': inputs,
102 | 'item': item,
103 | }
104 |
105 |
106 | class InferenceSampler(torch.utils.data.sampler.Sampler):
107 |
108 | def __init__(self, size):
109 | self._size = int(size)
110 | assert size > 0
111 | self._rank = torch.distributed.get_rank()
112 | self._world_size = torch.distributed.get_world_size()
113 | self._local_indices = self._get_local_indices(size, self._world_size, self._rank)
114 |
115 | @staticmethod
116 | def _get_local_indices(total_size, world_size, rank):
117 | shard_size = total_size // world_size
118 | left = total_size % world_size
119 | shard_sizes = [shard_size + int(r < left) for r in range(world_size)]
120 |
121 | begin = sum(shard_sizes[:rank])
122 | end = min(sum(shard_sizes[:rank + 1]), total_size)
123 | return range(begin, end)
124 |
125 | def __iter__(self):
126 | yield from self._local_indices
127 |
128 | def __len__(self):
129 | return len(self._local_indices)
130 |
131 | processor = AutoProcessor.from_pretrained("/nas/shared/NLP_A100/wuzhenyu/LLMs/Qwen2-VL-7B-Instruct", max_pixels=max_pixels)
132 |
133 | def evaluate_chat_model(model):
134 | random.seed(args.seed)
135 |
136 | if args.ds_name_list is None:
137 | ds_names = ds_collections.keys()
138 | else:
139 | ds_names = args.ds_name_list
140 | for ds_name in ds_names:
141 | dataset = AndroidControlDataset(
142 | root=ds_collections[ds_name]['root'],
143 | annotation=ds_collections[ds_name]['annotation'],
144 | input_size=224,
145 | dynamic_image_size=args.dynamic,
146 | use_thumbnail=False,
147 | max_num=args.max_num,
148 | device=model.device
149 | )
150 | dataloader = torch.utils.data.DataLoader(
151 | dataset=dataset,
152 | sampler=InferenceSampler(len(dataset)),
153 | batch_size=args.batch_size,
154 | num_workers=args.num_workers,
155 | pin_memory=False,
156 | drop_last=False,
157 | collate_fn=partial(collate_fn),
158 | )
159 |
160 |
161 | logger.info(f'Evaluating {ds_name} ...')
162 |
163 | outputs = []
164 | for _, (inputs, questions, items) in tqdm(enumerate(dataloader)):
165 |
166 | inputs = inputs[0]
167 | inputs = inputs.to(model.device)
168 | print(f"inputs: {inputs}")
169 | generated_ids = model.generate(
170 | **{k: v.to(model.device) for k, v in inputs.items()},
171 | max_new_tokens=ds_collections[ds_name]['max_new_tokens']
172 | )
173 | generated_ids_trimmed = [
174 | out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
175 | ]
176 | output_text = processor.batch_decode(
177 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
178 | )
179 | preds = [output_text[0]]
180 | print(f"preds: {preds}")
181 | for question, answer, item in zip(questions, preds, items):
182 | question_id = item['id']
183 | image_id = os.path.basename(item['image']).replace(".png", "")
184 | text = question
185 | output = {
186 | 'question_id': question_id,
187 | 'image_id': image_id,
188 | 'prompt': text,
189 | 'pred': answer,
190 | 'model_id': model_id,
191 | 'metadata': {},
192 | "task_name": "task2action",
193 | "string_format": "CSV_String",
194 | "position_format": "related"
195 | }
196 | outputs.append(output)
197 |
198 |
199 | torch.distributed.barrier()
200 | world_size = torch.distributed.get_world_size()
201 | merged_outputs = [None for _ in range(world_size)]
202 | torch.distributed.all_gather_object(merged_outputs, json.dumps(outputs))
203 |
204 | merged_outputs = [json.loads(_) for _ in merged_outputs]
205 | merged_outputs = [_ for _ in itertools.chain.from_iterable(merged_outputs)]
206 |
207 | torch.distributed.barrier()
208 |
209 | if torch.distributed.get_rank() == 0:
210 | print(len(merged_outputs))
211 | results_file = f'{model_id}_{ds_name}.jsonl'
212 | results_file = os.path.join(args.out_dir, results_file)
213 | # json.dump(merged_outputs, open(results_file, 'w'))
214 | with open(results_file, 'w') as f:
215 | for output in merged_outputs:
216 | json.dump(output, f)
217 | f.write('\n')
218 | print('Results saved to {}'.format(results_file))
219 |
220 |
221 | if __name__ == '__main__':
222 | parser = argparse.ArgumentParser()
223 | parser.add_argument('--checkpoint', type=str, default='')
224 | parser.add_argument('--datasets', type=str, default='')
225 | parser.add_argument('--batch-size', type=int, default=1)
226 | parser.add_argument('--num-workers', type=int, default=1)
227 | parser.add_argument('--num-beams', type=int, default=1)
228 | parser.add_argument('--temperature', type=float, default=0.0)
229 | parser.add_argument('--out-dir', type=str, default='results')
230 | parser.add_argument('--seed', type=int, default=0)
231 | parser.add_argument('--dynamic', action='store_true')
232 | parser.add_argument('--max-num', type=int, default=6)
233 | parser.add_argument('--load-in-8bit', action='store_true')
234 | parser.add_argument('--auto', action='store_true')
235 | parser.add_argument('--ds_name_list', type=str, nargs='*', default=None, help='List of dataset names')
236 | args = parser.parse_args()
237 |
238 | args.datasets = args.datasets.split(',')
239 | logger.info(f'datasets: {args.datasets}')
240 | assert args.batch_size == 1, 'Only batch size 1 is supported'
241 |
242 | torch.distributed.init_process_group(
243 | backend='nccl',
244 | world_size=int(os.getenv('WORLD_SIZE', '1')),
245 | rank=int(os.getenv('RANK', '0')),
246 | )
247 |
248 | torch.cuda.set_device(int(os.getenv('LOCAL_RANK', 0)))
249 |
250 |
251 | if args.auto:
252 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
253 | kwargs = {'device_map': 'auto'} if args.auto else {}
254 |
255 | model = Qwen2VLForConditionalGeneration.from_pretrained(
256 | args.checkpoint, torch_dtype=torch.bfloat16, **kwargs
257 | ).eval()
258 | model = model if args.auto else model.cuda()  # device_map='auto' already places the weights
259 | total_params = sum(p.numel() for p in model.parameters()) / 1e9
260 | if total_params > 20 or args.dynamic:
261 | args.num_beams = 1
262 | logger.info(f'[test] total_params: {total_params}B, use num_beams: {args.num_beams}')
263 | else:
264 | logger.info(f'[test] total_params: {total_params}B')
265 | logger.info(f'[test] dynamic_image_size: {args.dynamic}')
266 | logger.info(f'[test] max_num: {args.max_num}')
267 | if torch.distributed.get_rank() == 0:
268 | if not os.path.exists(args.out_dir):
269 | os.makedirs(args.out_dir)
270 |
271 | model_id = '_'.join(args.checkpoint.split('/')[-1:])
272 | evaluate_chat_model(model)
273 |
--------------------------------------------------------------------------------
/evaluation/android_control/run_ac_inference.sh:
--------------------------------------------------------------------------------
1 | set -x
2 |
3 | DATASET=${1}
4 | CHECKPOINT=${2}
5 | echo "DATASET: ${DATASET}"
6 | echo "CHECKPOINT: ${CHECKPOINT}"
7 |
8 | # Check that the dataset name is valid
9 | options=("ac_high" "ac_low")
10 | if [[ " ${options[@]} " =~ " ${DATASET} " ]]; then
11 | echo "Input '$DATASET' is a valid option"
12 | else
13 | echo "Error: '$DATASET' is not a valid option"
14 | echo "Valid options are: ${options[@]}"
15 | exit 1
16 | fi
17 |
18 | PATCH=24
19 |
20 | TEST_SET_NAME=${DATASET}
21 | echo "Test set key: ${TEST_SET_NAME}"
22 |
23 | MASTER_PORT=${MASTER_PORT:-63604}
24 | PORT=${PORT:-63604}
25 | PARTITION=${PARTITION:-"INTERN2"}
26 |
27 | export MASTER_PORT=${MASTER_PORT}
28 | export PORT=${PORT}
29 | echo "GPUS: ${GPUS}"
30 |
31 | GPUS_PER_NODE=${GPUS_PER_NODE:-${GPUS}}
32 | NODES=$((GPUS / GPUS_PER_NODE))
33 | QUOTA_TYPE=${QUOTA_TYPE:-"reserved"}
34 | SRUN_ARGS=${SRUN_ARGS:-""}
35 |
36 |
37 | your/conda/path/torchrun \
38 | --nnodes=1 \
39 | --node_rank=0 \
40 | --master_addr=127.0.0.1 \
41 | --nproc_per_node=${GPUS} \
42 | --master_port=${MASTER_PORT} \
43 | internvl2_inference.py --checkpoint ${CHECKPOINT} --datasets ${DATASET} \
44 | --dynamic \
45 | --max-num ${PATCH} \
46 | --ds_name_list ${TEST_SET_NAME} \
47 | --out-dir ${OUT_DIR:-results}
48 |
49 |
--------------------------------------------------------------------------------
/evaluation/android_world/README.md:
--------------------------------------------------------------------------------
1 | # AndroidWorld Evaluation
2 |
3 | https://github.com/google-research/android_world
--------------------------------------------------------------------------------
/evaluation/eval_json_files/android_control_test_subsplits.json:
--------------------------------------------------------------------------------
1 | {"IDD": [4790, 8333, 1589, 14287, 2860, 16124, 19349, 17863, 880, 4537, 3506, 10986, 10235, 12166, 7838, 11512, 1489, 1719, 12303, 10558, 11410, 697, 18103, 2775, 13134, 5366, 16951, 14103, 14607, 16467, 19199, 2426, 13314, 1665, 1569, 3195, 8886, 10988, 14804, 4622, 13806, 14495, 4123, 12182, 4807, 6645, 19887, 6158, 18654, 14525, 16402, 2706, 13189, 17036, 1260, 8583, 10864, 3484, 6714, 2005, 11550, 19946, 9283, 15527, 7201, 6304, 7320, 6472, 18544, 3061, 3613, 18932, 16588, 10058, 1764, 16503, 8197, 12803, 18557, 5147, 3281, 11672, 7961, 12758, 14815, 2676, 470, 6240, 4045, 18875, 27, 6424, 3240, 15892, 4550, 3833, 14723, 18491, 5718, 9961, 11673, 16022, 2139, 9711, 17522, 16526, 11198, 15865, 9823, 2010, 18231, 16674, 14928, 19477, 19315, 7195, 17598, 7609, 2666, 1542, 11128, 10925, 1727, 12926, 13274, 20116, 5602, 640, 9837, 19666, 13674, 13459, 15204, 4564, 8064, 19602, 787, 4377, 2221, 19277, 1581, 15057, 6119, 4047, 6323, 4724, 16169, 13752, 15594, 971, 6708, 1089, 8717, 8484, 9195, 631, 16072, 13529, 7646, 16004, 5704, 15070, 15540, 2748, 15785, 19510, 7903, 13428, 17280, 11726, 2777, 6346, 6443, 907, 5002, 11848, 4782, 8262, 6532, 3476, 3566, 12338, 1969, 1821, 18038, 5278, 19382, 3124, 5072, 15055, 19817, 2802, 12538, 6871, 13546, 5205, 1082, 13805, 3176, 15718, 7194, 19833, 10844, 9747, 16527, 18703, 6236, 16036, 10944, 2471, 13228, 4745, 904, 12336, 9072, 3814, 15267, 1815, 19565, 1045, 18771, 17880, 10956, 3041, 9812, 13061, 7430, 16579, 16760, 15597, 8500, 15658, 5595, 14172, 4328, 15036, 2685, 777, 17611, 11549, 8730, 1933, 9144, 20016, 10292, 14129, 4015, 5407, 19971, 4936, 2165, 10443, 938, 14297, 17102, 610, 7382, 8202, 14779, 11306, 9755, 12076, 1280, 14076, 12614, 4797, 9143, 9120, 13217, 11277, 8864, 17334, 13248, 6760, 5996, 3256, 2754, 13036, 7119, 3128, 15032, 19088, 17452, 16464, 11800, 18310, 5316, 10081, 6041, 15100, 10167, 18521, 7472, 17317, 13385, 1438, 12069, 15020, 15832, 18237, 9835, 5836, 18513, 261, 2682, 6589, 19880, 15419, 
18944, 140, 12863, 1959, 16837, 12163, 5215, 1632, 5241, 2673, 5791, 13441, 17981, 14639, 7456, 12391, 19551, 13440, 9168, 3898, 12594, 8074, 1394, 13256, 7829, 17033, 6759, 2124, 15653, 8963, 423, 11843, 55, 10232, 2912, 8748, 12829, 6153, 10294, 17169, 18723, 14514, 16208, 2078, 16441, 2612, 1741, 11222, 2588, 3945, 19585, 19394, 15225, 12678, 13436, 4796, 5897, 8726, 16514, 13754, 20122, 10273, 6707, 19877, 17994, 6404, 10476, 4699, 16353, 7958, 6402, 8271, 3347, 12743, 15966, 16669, 5264, 4833, 11342, 15349, 6207, 17842, 6048, 13182, 6408, 9938, 12473, 11632, 13975, 15749, 19268, 4413, 19719, 15752, 8324, 3419, 11912, 8155, 10608, 2250, 3760, 1672, 14401, 794, 5956, 8629, 5906, 3031, 5837, 2406, 3708, 8537, 7748, 3319, 11300, 5455, 15017, 4680, 10726, 6623, 17479, 11260, 7860, 11714, 127, 10142, 12739, 229, 7866, 12390, 17422, 7399, 14343, 3292, 5574, 14212, 13554, 18181, 3000, 2630, 5902, 9783, 394, 13970, 2936, 11450, 3496, 11035, 3747, 7993, 11494, 14776, 9467, 15241, 12577, 10183, 3948, 10690, 17355, 14358, 9698, 1892, 5385, 4976, 4943, 1330, 18124, 4756, 719, 18574, 16140, 2302, 7821, 13613, 6606, 19680, 13745, 6211, 2822, 16860, 7633, 17085, 6353, 16452, 2387, 17111, 13522, 1904, 7806, 4501, 19813, 18185, 13349, 9934, 19653, 3373, 1381, 2158, 18099, 11404, 10323, 3806, 15405, 1895, 8908, 10037, 5076, 759, 14307, 10318, 5693, 12642, 916, 13468, 12508, 8216, 8282, 1099, 6689, 4090, 11040, 15350, 13429, 11217, 18991, 11682, 5563, 1027, 12187, 18938, 936, 3026, 7587, 8740, 6932, 13018, 14109, 9112, 3425, 7774, 15348, 13841, 14178, 19928, 18950, 3738, 6919, 7100, 8804, 19552, 3241, 10241, 14319, 15583, 2535, 387, 19190, 4876, 11967, 12978, 2327, 11274, 12429, 20097, 16712, 19460, 7548, 15999, 10531, 6876, 8084, 18488, 9593, 7249, 7455, 8709, 17840, 15167, 14940, 19033, 4998, 5615, 12479, 5611, 2121, 7538, 14768, 19242, 12243, 2771, 14942, 12334, 16768, 10501, 943, 8651, 16608, 17114, 4437, 14320, 4692, 15316, 10710, 2592, 3250, 8389, 84, 15908, 19339, 11520, 
10699, 18679, 4603, 19537, 1306, 15778, 4902, 14953, 12364, 3338, 10265, 16154, 15537, 2862, 9128, 19676, 15857, 8899, 9231, 7765, 4159, 14157, 13733, 2743, 7239, 18973, 3317, 4190, 12289, 54, 15646, 237, 5990, 3507, 7955, 16696, 19251, 17831, 10571, 13526, 14537, 14230, 1194, 1019, 7869, 2269, 3438, 407, 20055, 3100, 4358, 17953, 2488, 2626, 10724, 6670, 15703, 5200, 9893, 6904, 2921, 2553, 18078, 7889, 2093, 11885, 16663, 4988, 13846, 12498, 15476, 12823, 11448, 6121, 20086, 11042, 9719, 7884, 13596, 300, 13697, 13800, 17237, 19170, 3191, 803, 4640, 12709, 14114, 19657, 15386, 11412, 15634, 18894, 11900, 12248, 6779, 10786, 5676, 9767, 1624, 8697, 13686], "category_unseen": [53, 62, 73, 104, 122, 136, 167, 172, 185, 190, 219, 227, 228, 231, 270, 282, 296, 318, 382, 406, 417, 495, 548, 567, 569, 575, 586, 591, 636, 660, 662, 672, 674, 723, 724, 727, 828, 886, 905, 937, 945, 1024, 1049, 1074, 1168, 1170, 1193, 1198, 1203, 1206, 1293, 1300, 1309, 1311, 1329, 1336, 1344, 1359, 1360, 1364, 1378, 1458, 1464, 1474, 1480, 1501, 1654, 1704, 1736, 1809, 1811, 1819, 1820, 1836, 1842, 1908, 2060, 2079, 2086, 2098, 2122, 2146, 2174, 2195, 2200, 2201, 2232, 2288, 2405, 2412, 2473, 2484, 2507, 2544, 2582, 2609, 2644, 2669, 2697, 2814, 2836, 2856, 2886, 2889, 2916, 2961, 2991, 3007, 3010, 3023, 3034, 3099, 3113, 3148, 3152, 3159, 3198, 3224, 3238, 3242, 3297, 3380, 3449, 3519, 3530, 3559, 3591, 3592, 3700, 3715, 3751, 3761, 3792, 3793, 3812, 3845, 3849, 3875, 4093, 4158, 4178, 4183, 4194, 4204, 4216, 4226, 4244, 4260, 4297, 4299, 4330, 4336, 4348, 4406, 4414, 4451, 4491, 4535, 4545, 4560, 4579, 4581, 4608, 4618, 4635, 4657, 4670, 4674, 4740, 4803, 4812, 4858, 4885, 4889, 4964, 4977, 4979, 5016, 5019, 5029, 5122, 5148, 5154, 5185, 5199, 5234, 5256, 5271, 5276, 5288, 5312, 5332, 5342, 5355, 5373, 5431, 5446, 5510, 5564, 5636, 5673, 5677, 5684, 5691, 5707, 5732, 5741, 5742, 5776, 5794, 5806, 5850, 5852, 5899, 5918, 5926, 5932, 5961, 5968, 5973, 6003, 6045, 6046, 6081, 6147, 6177, 
6178, 6181, 6258, 6271, 6328, 6332, 6338, 6405, 6413, 6418, 6430, 6445, 6446, 6449, 6463, 6471, 6505, 6541, 6555, 6562, 6574, 6578, 6638, 6640, 6666, 6681, 6683, 6710, 6792, 6809, 6844, 6888, 6903, 6917, 6930, 6934, 6979, 6996, 7014, 7020, 7030, 7032, 7101, 7184, 7205, 7247, 7298, 7329, 7337, 7499, 7509, 7565, 7585, 7607, 7608, 7614, 7643, 7682, 7733, 7753, 7826, 7827, 7878, 7917, 7969, 7999, 8044, 8049, 8146, 8158, 8165, 8188, 8193, 8248, 8261, 8307, 8336, 8455, 8488, 8489, 8501, 8506, 8532, 8579, 8658, 8703, 8725, 8738, 8770, 8780, 8795, 8826, 8871, 8885, 8898, 8914, 8917, 8939, 8982, 8985, 8997, 8998, 9074, 9131, 9208, 9298, 9301, 9314, 9318, 9320, 9335, 9357, 9373, 9401, 9420, 9431, 9452, 9464, 9466, 9476, 9547, 9588, 9609, 9683, 9721, 9722, 9753, 9788, 9793, 9804, 9806, 9921, 9935, 9963, 10033, 10053, 10062, 10067, 10089, 10095, 10098, 10137, 10140, 10206, 10223, 10255, 10319, 10366, 10441, 10464, 10520, 10527, 10559, 10590, 10612, 10672, 10693, 10711, 10738, 10739, 10807, 10809, 10841, 10911, 10912, 10928, 10966, 11014, 11028, 11091, 11132, 11135, 11159, 11185, 11187, 11218, 11234, 11239, 11242, 11294, 11296, 11305, 11326, 11373, 11379, 11390, 11417, 11429, 11454, 11488, 11490, 11501, 11551, 11552, 11559, 11571, 11595, 11622, 11623, 11695, 11705, 11728, 11777, 11816, 11829, 11831, 11840, 11847, 11852, 11870, 11890, 11926, 11935, 11946, 11968, 12015, 12045, 12123, 12176, 12217, 12223, 12245, 12321, 12356, 12434, 12483, 12499, 12514, 12546, 12571, 12581, 12624, 12633, 12648, 12649, 12736, 12816, 12856, 12885, 12901, 12904, 12933, 12946, 12962, 12965, 13019, 13070, 13072, 13108, 13126, 13176, 13203, 13247, 13307, 13350, 13355, 13367, 13405, 13437, 13470, 13490, 13497, 13498, 13502, 13512, 13560, 13574, 13605, 13619, 13651, 13659, 13668, 13676, 13688, 13704, 13731, 13746, 13885, 13886, 13887, 13934, 13992, 14027, 14107, 14154, 14219, 14220, 14251, 14259, 14265, 14269, 14310, 14321, 14339, 14345, 14390, 14395, 14403, 14454, 14493, 14523, 14554, 14562, 14620, 
14628, 14651, 14654, 14670, 14681, 14735, 14764, 14774, 14787, 14843, 14902, 14981, 14985, 14989, 14998, 15048, 15063, 15076, 15149, 15168, 15344, 15402, 15425, 15442, 15447, 15457, 15485, 15518, 15541, 15613, 15690, 15746, 15768, 15801, 15802, 15814, 15843, 15886, 15890, 15975, 15979, 15981, 15993, 16002, 16035, 16101, 16128, 16178, 16212, 16223, 16228, 16259, 16275, 16321, 16379, 16426, 16454, 16491, 16501, 16532, 16562, 16584, 16593, 16656, 16668, 16671, 16679, 16703, 16705, 16787, 16805, 16829, 16830, 16874, 16924, 17082, 17098, 17127, 17136, 17180, 17181, 17190, 17207, 17271, 17293, 17325, 17402, 17410, 17444, 17481, 17506, 17510, 17517, 17533, 17538, 17563, 17588, 17621, 17628, 17629, 17647, 17661, 17677, 17712, 17748, 17807, 17824, 17841, 17853, 17929, 17959, 17977, 18046, 18066, 18068, 18087, 18141, 18173, 18245, 18298, 18302, 18306, 18335, 18340, 18374, 18438, 18444, 18577, 18598, 18619, 18643, 18645, 18661, 18673, 18741, 18773, 18800, 18828, 18883, 18963, 18983, 19003, 19038, 19189, 19197, 19269, 19300, 19347, 19375, 19461, 19572, 19593, 19636, 19660, 19697, 19721, 19725, 19766, 19805, 19809, 19812, 19846, 19897, 19906, 19927, 19939, 19956, 19979, 20131, 20189], "app_unseen": [53, 62, 73, 104, 122, 167, 172, 185, 190, 216, 219, 227, 228, 231, 270, 282, 296, 318, 348, 382, 406, 417, 495, 548, 567, 569, 575, 586, 636, 660, 672, 723, 724, 886, 905, 937, 945, 1024, 1049, 1074, 1170, 1193, 1198, 1203, 1206, 1293, 1300, 1302, 1309, 1311, 1329, 1359, 1360, 1364, 1378, 1458, 1464, 1474, 1480, 1501, 1654, 1693, 1704, 1736, 1809, 1811, 1819, 1820, 1836, 1842, 1908, 2060, 2079, 2086, 2098, 2122, 2146, 2174, 2195, 2200, 2201, 2232, 2288, 2405, 2412, 2473, 2484, 2507, 2544, 2609, 2644, 2669, 2697, 2814, 2856, 2886, 2889, 2916, 2961, 2991, 3007, 3010, 3023, 3034, 3099, 3113, 3148, 3152, 3159, 3198, 3224, 3238, 3242, 3297, 3380, 3449, 3519, 3559, 3592, 3613, 3700, 3715, 3761, 3792, 3793, 3812, 3845, 3849, 3875, 4093, 4158, 4178, 4194, 4226, 4244, 4260, 4297, 4299, 4330, 
4348, 4406, 4414, 4451, 4491, 4535, 4560, 4579, 4581, 4608, 4618, 4635, 4657, 4674, 4740, 4803, 4812, 4858, 4885, 4889, 4964, 4977, 4979, 5016, 5019, 5029, 5122, 5148, 5185, 5199, 5234, 5235, 5256, 5271, 5276, 5312, 5355, 5373, 5446, 5564, 5636, 5673, 5676, 5677, 5680, 5684, 5691, 5707, 5732, 5741, 5742, 5776, 5794, 5806, 5850, 5852, 5899, 5918, 5926, 5932, 5961, 6003, 6045, 6046, 6081, 6147, 6178, 6258, 6271, 6332, 6338, 6405, 6413, 6418, 6430, 6445, 6446, 6449, 6471, 6505, 6541, 6555, 6562, 6574, 6578, 6638, 6640, 6666, 6681, 6683, 6710, 6809, 6844, 6888, 6903, 6917, 6930, 6934, 6979, 6996, 7014, 7020, 7030, 7101, 7184, 7205, 7247, 7298, 7329, 7334, 7499, 7509, 7565, 7607, 7608, 7614, 7643, 7682, 7705, 7733, 7753, 7826, 7878, 7917, 7999, 8146, 8158, 8165, 8188, 8193, 8248, 8261, 8307, 8336, 8455, 8488, 8489, 8501, 8506, 8532, 8579, 8658, 8703, 8738, 8770, 8780, 8795, 8826, 8871, 8885, 8898, 8914, 8917, 8939, 8982, 8997, 8998, 9074, 9131, 9298, 9301, 9314, 9318, 9320, 9335, 9357, 9373, 9401, 9420, 9431, 9452, 9464, 9466, 9588, 9609, 9683, 9721, 9788, 9793, 9804, 9806, 9921, 9935, 9963, 10033, 10053, 10067, 10089, 10098, 10137, 10147, 10206, 10223, 10255, 10319, 10384, 10441, 10464, 10520, 10527, 10559, 10590, 10612, 10693, 10711, 10807, 10809, 10841, 10911, 10912, 10928, 10966, 11014, 11028, 11077, 11132, 11135, 11159, 11185, 11187, 11218, 11234, 11239, 11242, 11294, 11296, 11305, 11326, 11332, 11373, 11377, 11379, 11390, 11417, 11429, 11454, 11488, 11490, 11501, 11551, 11552, 11559, 11571, 11595, 11622, 11623, 11695, 11705, 11728, 11777, 11816, 11829, 11831, 11840, 11847, 11852, 11870, 11890, 11935, 11946, 11968, 12015, 12045, 12123, 12176, 12223, 12321, 12356, 12434, 12483, 12499, 12514, 12546, 12571, 12581, 12624, 12633, 12649, 12736, 12816, 12856, 12885, 12901, 12933, 12962, 12965, 13019, 13070, 13072, 13108, 13126, 13176, 13307, 13327, 13350, 13355, 13367, 13405, 13437, 13490, 13498, 13502, 13512, 13560, 13574, 13605, 13619, 13651, 13659, 13668, 13676, 13688, 
13704, 13746, 13885, 13886, 13887, 13934, 13992, 14027, 14107, 14154, 14219, 14220, 14259, 14265, 14269, 14296, 14310, 14321, 14339, 14390, 14395, 14403, 14454, 14493, 14523, 14562, 14620, 14628, 14651, 14670, 14681, 14735, 14764, 14774, 14787, 14843, 14902, 14928, 14981, 14985, 14989, 14998, 15048, 15076, 15149, 15344, 15402, 15425, 15442, 15447, 15457, 15485, 15613, 15746, 15775, 15801, 15802, 15814, 15843, 15886, 15890, 15935, 15975, 15979, 15981, 15993, 16002, 16035, 16101, 16128, 16178, 16212, 16223, 16228, 16259, 16275, 16321, 16379, 16454, 16491, 16501, 16532, 16584, 16593, 16656, 16668, 16671, 16679, 16703, 16705, 16787, 16805, 16829, 16830, 16874, 16957, 17098, 17127, 17136, 17154, 17180, 17181, 17190, 17207, 17293, 17325, 17402, 17410, 17481, 17510, 17517, 17533, 17538, 17563, 17588, 17621, 17628, 17629, 17661, 17677, 17748, 17807, 17824, 17841, 17853, 17929, 17959, 17977, 18046, 18066, 18068, 18087, 18141, 18173, 18245, 18302, 18335, 18340, 18351, 18374, 18438, 18444, 18577, 18598, 18619, 18643, 18645, 18661, 18673, 18741, 18828, 18883, 18963, 18983, 19003, 19038, 19189, 19197, 19269, 19300, 19347, 19366, 19375, 19461, 19572, 19593, 19697, 19721, 19725, 19805, 19809, 19812, 19846, 19897, 19906, 19939, 19956, 20109, 20131, 20189], "task_unseen": [45, 53, 62, 73, 104, 122, 136, 167, 172, 185, 190, 219, 227, 228, 231, 270, 282, 296, 318, 382, 406, 417, 480, 495, 548, 567, 569, 575, 586, 591, 636, 660, 662, 672, 674, 723, 724, 727, 828, 841, 886, 892, 898, 905, 937, 945, 1024, 1049, 1074, 1168, 1170, 1193, 1198, 1203, 1206, 1293, 1300, 1309, 1311, 1329, 1336, 1344, 1359, 1360, 1364, 1378, 1458, 1464, 1474, 1480, 1482, 1501, 1512, 1586, 1654, 1704, 1736, 1809, 1811, 1819, 1820, 1836, 1842, 1853, 1908, 2060, 2079, 2086, 2098, 2122, 2146, 2158, 2174, 2195, 2200, 2201, 2232, 2288, 2313, 2405, 2412, 2442, 2473, 2484, 2507, 2527, 2544, 2582, 2609, 2644, 2669, 2697, 2757, 2814, 2836, 2856, 2886, 2889, 2916, 2947, 2961, 2991, 3007, 3010, 3023, 3034, 3099, 3113, 
3148, 3152, 3159, 3198, 3224, 3238, 3242, 3297, 3315, 3380, 3449, 3493, 3519, 3530, 3559, 3591, 3592, 3700, 3715, 3751, 3761, 3770, 3792, 3793, 3812, 3845, 3849, 3875, 4093, 4114, 4158, 4178, 4183, 4194, 4195, 4204, 4216, 4226, 4244, 4247, 4260, 4297, 4299, 4330, 4336, 4348, 4406, 4414, 4422, 4451, 4491, 4516, 4535, 4545, 4560, 4579, 4581, 4602, 4608, 4618, 4635, 4657, 4662, 4670, 4674, 4740, 4803, 4806, 4812, 4858, 4881, 4883, 4885, 4889, 4916, 4964, 4977, 4979, 5016, 5019, 5025, 5029, 5122, 5148, 5154, 5185, 5199, 5234, 5251, 5256, 5271, 5276, 5288, 5312, 5332, 5342, 5355, 5373, 5431, 5446, 5459, 5497, 5510, 5564, 5636, 5664, 5673, 5677, 5684, 5691, 5707, 5732, 5741, 5742, 5773, 5776, 5779, 5794, 5806, 5850, 5852, 5899, 5918, 5926, 5932, 5958, 5961, 5968, 5973, 6003, 6045, 6046, 6081, 6147, 6177, 6178, 6181, 6239, 6258, 6271, 6312, 6328, 6332, 6338, 6385, 6405, 6413, 6418, 6430, 6445, 6446, 6449, 6463, 6471, 6505, 6541, 6555, 6562, 6574, 6578, 6638, 6640, 6666, 6681, 6683, 6710, 6714, 6768, 6792, 6809, 6844, 6888, 6903, 6917, 6930, 6934, 6944, 6979, 6996, 7014, 7020, 7030, 7032, 7045, 7101, 7184, 7204, 7205, 7247, 7271, 7298, 7329, 7337, 7356, 7379, 7499, 7509, 7565, 7585, 7607, 7608, 7614, 7643, 7644, 7682, 7733, 7753, 7767, 7826, 7827, 7878, 7917, 7935, 7969, 7999, 8044, 8049, 8141, 8146, 8158, 8165, 8188, 8193, 8242, 8248, 8261, 8307, 8336, 8455, 8488, 8489, 8501, 8506, 8532, 8579, 8658, 8703, 8725, 8738, 8770, 8780, 8795, 8826, 8871, 8885, 8898, 8914, 8917, 8939, 8982, 8985, 8997, 8998, 9062, 9074, 9131, 9208, 9298, 9301, 9307, 9314, 9318, 9320, 9335, 9357, 9373, 9376, 9394, 9401, 9420, 9431, 9452, 9464, 9466, 9476, 9547, 9588, 9609, 9628, 9634, 9683, 9721, 9722, 9753, 9788, 9793, 9804, 9806, 9921, 9935, 9963, 10033, 10053, 10062, 10067, 10089, 10095, 10098, 10137, 10140, 10206, 10223, 10255, 10319, 10366, 10441, 10464, 10520, 10527, 10554, 10559, 10590, 10612, 10672, 10693, 10711, 10738, 10739, 10807, 10809, 10841, 10911, 10912, 10928, 10966, 11014, 11028, 
11091, 11132, 11135, 11159, 11185, 11187, 11218, 11234, 11239, 11242, 11294, 11296, 11304, 11305, 11326, 11373, 11379, 11390, 11417, 11429, 11454, 11482, 11488, 11490, 11501, 11551, 11552, 11559, 11571, 11595, 11622, 11623, 11695, 11705, 11728, 11777, 11816, 11829, 11831, 11840, 11847, 11852, 11870, 11890, 11896, 11926, 11935, 11946, 11968, 11972, 12015, 12045, 12057, 12123, 12176, 12217, 12223, 12227, 12245, 12321, 12356, 12434, 12483, 12499, 12514, 12546, 12571, 12581, 12624, 12625, 12633, 12648, 12649, 12653, 12736, 12754, 12816, 12856, 12885, 12901, 12904, 12933, 12946, 12962, 12965, 13019, 13070, 13072, 13089, 13108, 13126, 13176, 13203, 13247, 13307, 13327, 13350, 13355, 13367, 13390, 13405, 13437, 13470, 13490, 13497, 13498, 13502, 13512, 13560, 13574, 13605, 13619, 13651, 13659, 13668, 13669, 13676, 13688, 13704, 13731, 13735, 13746, 13821, 13838, 13885, 13886, 13887, 13934, 13992, 14027, 14107, 14154, 14219, 14220, 14251, 14259, 14265, 14269, 14310, 14321, 14339, 14345, 14390, 14395, 14403, 14454, 14493, 14523, 14554, 14556, 14562, 14620, 14628, 14651, 14654, 14670, 14681, 14735, 14764, 14774, 14787, 14843, 14902, 14981, 14985, 14989, 14998, 15048, 15063, 15076, 15149, 15168, 15244, 15252, 15283, 15344, 15402, 15409, 15425, 15438, 15442, 15447, 15457, 15485, 15518, 15541, 15591, 15613, 15625, 15690, 15746, 15768, 15801, 15802, 15814, 15843, 15886, 15890, 15968, 15975, 15979, 15981, 15993, 16002, 16035, 16101, 16128, 16178, 16212, 16223, 16228, 16259, 16275, 16321, 16324, 16379, 16426, 16454, 16491, 16501, 16532, 16562, 16584, 16593, 16614, 16656, 16668, 16671, 16679, 16703, 16705, 16787, 16805, 16810, 16829, 16830, 16853, 16874, 16924, 16979, 17037, 17065, 17082, 17098, 17127, 17136, 17180, 17181, 17190, 17207, 17258, 17271, 17293, 17325, 17402, 17410, 17444, 17457, 17481, 17506, 17510, 17517, 17533, 17538, 17563, 17588, 17621, 17628, 17629, 17647, 17661, 17677, 17712, 17748, 17807, 17824, 17841, 17853, 17900, 17929, 17959, 17969, 17977, 18046, 18066, 
18068, 18087, 18141, 18173, 18182, 18245, 18298, 18302, 18306, 18335, 18340, 18360, 18374, 18438, 18444, 18577, 18598, 18619, 18643, 18645, 18661, 18673, 18741, 18773, 18800, 18828, 18883, 18963, 18983, 19003, 19012, 19019, 19038, 19061, 19189, 19197, 19269, 19300, 19347, 19375, 19453, 19461, 19572, 19593, 19636, 19660, 19697, 19721, 19725, 19766, 19805, 19809, 19812, 19846, 19897, 19906, 19927, 19939, 19956, 19979, 20033, 20131, 20189]}
--------------------------------------------------------------------------------
/faq.md:
--------------------------------------------------------------------------------
1 | # FAQ
2 |
3 | Thank you for your interest in OS-Genesis. Below are some questions we have collected from emails, Hugging Face, and WeChat communications. We hope these can be helpful to you.
4 |
5 | ## When will the checkpoints and data be available?
6 |
7 | We have already uploaded the checkpoints and evaluation code for AndroidControl. ~~The remaining checkpoints will be uploaded in the coming days~~ (done). Due to server bandwidth limitations, this may take some time. The data will also be available shortly.
8 |
9 |
10 | ## What About Desktop?
11 |
12 | Q: Why haven’t you worked on PC/Desktop data? Is there a particular reason?
13 |
14 | A:
15 | We originally intended to cover PC, mobile, and web. In fact, our high-level reverse-synthesis process can also run on PC (we used [OSWorld](https://os-world.github.io/) as the dynamic environment). However, we decided not to continue on the PC side for the following reasons:
16 | 1. Data collection on PC is too difficult for a model-based approach.
17 | For instance, in [OSWorld](https://os-world.github.io/), GPT-4o's success rate in most scenarios is below 10%, so the proportion of high-quality trajectories would be low. Ensuring quality would require collecting far more data and a more rigorous trajectory reward model (TRM), which would be costly.
18 |
19 | 2. Even after collecting trajectories, there are significant challenges in training:
20 | 1. Length of the a11ytree:
21 | We use the accessibility tree (a11ytree), and on desktop it is much longer than the mobile or web DOM. Combined with the multimodal inputs used in training, it exceeds the context window of models such as InternVL and Qwen.
22 | 2. Instruction-following issues:
23 | Currently, open-source VLMs have major problems following instructions in PC environments.
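Point 1 can be made concrete with a back-of-the-envelope estimate. The sketch below is not part of the OS-Genesis pipeline: it assumes, for illustration only, that each rollout succeeds independently with a fixed probability, and the helper `rollouts_needed` and the target of 1,000 trajectories are hypothetical. Under that assumption, the expected number of rollouts is the target divided by the success rate, before any additional filtering by the TRM.

```python
def rollouts_needed(target_trajectories: int, success_rate_pct: int) -> int:
    """Expected rollouts needed to obtain `target_trajectories` successes.

    Hypothetical helper: assumes every rollout succeeds independently with
    probability success_rate_pct / 100, so the expected count is n / p.
    """
    if not 0 < success_rate_pct <= 100:
        raise ValueError("success_rate_pct must be in (0, 100]")
    # Integer ceiling division keeps the result exact: ceil(n * 100 / pct).
    return -(-target_trajectories * 100 // success_rate_pct)


# Hypothetical target of 1,000 successful trajectories at different success rates.
for pct in (50, 10, 5):
    print(f"{pct:>3}% success -> {rollouts_needed(1000, pct):,} rollouts")
```

At a 10% success rate the hypothetical target already calls for 10,000 rollouts, five times as many as at 50%; a stricter TRM that discards some of the successful rollouts would push the number higher still.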
--------------------------------------------------------------------------------
/static/OS-Genesis-Badge.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OS-Copilot/OS-Genesis/9ecbe594352f254b9a9228468f9ca9a77b2388a2/static/OS-Genesis-Badge.png
--------------------------------------------------------------------------------
/static/OS-Genesis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/OS-Copilot/OS-Genesis/9ecbe594352f254b9a9228468f9ca9a77b2388a2/static/OS-Genesis.png
--------------------------------------------------------------------------------