├── coat
│   ├── __init__.py
│   ├── utils.py
│   ├── config.py
│   ├── screen_utils.py
│   ├── evaluate.py
│   ├── action_utils.py
│   ├── model.py
│   ├── agent.py
│   └── config.yaml
├── assets
│   ├── cmp.png
│   └── intro-total.png
├── data-example
│   └── GOOGLE_APPS-523638528775825151
│       ├── GOOGLE_APPS-523638528775825151_0.png
│       ├── GOOGLE_APPS-523638528775825151_1.png
│       ├── GOOGLE_APPS-523638528775825151_2.png
│       ├── GOOGLE_APPS-523638528775825151_3.png
│       └── GOOGLE_APPS-523638528775825151.json
├── README.md
└── run_coat.py
/coat/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/cmp.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IMNearth/CoAT/HEAD/assets/cmp.png
--------------------------------------------------------------------------------
/assets/intro-total.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IMNearth/CoAT/HEAD/assets/intro-total.png
--------------------------------------------------------------------------------
/data-example/GOOGLE_APPS-523638528775825151/GOOGLE_APPS-523638528775825151_0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IMNearth/CoAT/HEAD/data-example/GOOGLE_APPS-523638528775825151/GOOGLE_APPS-523638528775825151_0.png
--------------------------------------------------------------------------------
/data-example/GOOGLE_APPS-523638528775825151/GOOGLE_APPS-523638528775825151_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IMNearth/CoAT/HEAD/data-example/GOOGLE_APPS-523638528775825151/GOOGLE_APPS-523638528775825151_1.png
--------------------------------------------------------------------------------
/data-example/GOOGLE_APPS-523638528775825151/GOOGLE_APPS-523638528775825151_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IMNearth/CoAT/HEAD/data-example/GOOGLE_APPS-523638528775825151/GOOGLE_APPS-523638528775825151_2.png
--------------------------------------------------------------------------------
/data-example/GOOGLE_APPS-523638528775825151/GOOGLE_APPS-523638528775825151_3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/IMNearth/CoAT/HEAD/data-example/GOOGLE_APPS-523638528775825151/GOOGLE_APPS-523638528775825151_3.png
--------------------------------------------------------------------------------
/coat/utils.py:
--------------------------------------------------------------------------------
1 | from colorama import init, Fore, Style
2 |
3 |
4 | def print_with_color(text, color):
5 | colors = {
6 | 'red': Fore.RED,
7 | 'green': Fore.GREEN,
8 | 'yellow': Fore.YELLOW,
9 | 'blue': Fore.BLUE,
10 | 'magenta': Fore.MAGENTA,
11 | 'cyan': Fore.CYAN,
12 | 'white': Fore.WHITE,
13 | 'reset': Fore.RESET
14 | }
15 |
16 | color_code = colors.get(color.lower(), Fore.RESET)
17 |     print(color_code + text + Style.RESET_ALL)  # reset the color so later output is unaffected
18 |
19 |
--------------------------------------------------------------------------------
/coat/config.py:
--------------------------------------------------------------------------------
1 | from yacs.config import CfgNode as CN
2 |
3 | _C = CN()
4 | _C.DEMO_NAME = ""  # name of the current experiment
5 | _C.DEMO_MODE = ""  # demo mode
6 | _C.OUTPUT_DIR = ""  # directory to save demo outputs
7 |
8 |
9 | _C.DATA = CN()  # ----- dataset settings -----
10 | _C.DATA.DATA_DIR = ""  # path to the dataset
11 | _C.DATA.SPLIT = ""  # dataset split used for testing
12 |
13 |
14 | _C.MODEL = CN()  # ----- model API settings -----
15 | _C.MODEL.NAME = ""
16 | _C.MODEL.OPENAI_API_KEY = ""
17 | _C.MODEL.OPENAI_API_URL = ""
18 | _C.MODEL.GEMINI_API_KEY = ""
19 | _C.MODEL.GEMINI_MODEL = "gemini-pro-vision"
20 | _C.MODEL.DASHSCOPE_API_KEY = ""
21 | _C.MODEL.DASHSCOPE_MODEL = "qwen-vl-max"
22 |
23 |
24 | _C.PROMPTS = CN()  # ----- prompt settings -----
25 |
26 | _C.PROMPTS.SCREEN_DESC = CN()
27 | _C.PROMPTS.SCREEN_DESC.SYSTEM = ""
28 | _C.PROMPTS.SCREEN_DESC.USER = ""
29 |
30 | _C.PROMPTS.ACTION_THINK = CN()
31 | _C.PROMPTS.ACTION_THINK.SYSTEM = ""
32 | _C.PROMPTS.ACTION_THINK.USER = ""
33 |
34 | _C.PROMPTS.ACTION_DESC = CN()
35 | _C.PROMPTS.ACTION_DESC.SYSTEM = ""
36 | _C.PROMPTS.ACTION_DESC.USER = ""
37 |
38 | _C.PROMPTS.ACTION_PREDICT = CN()
39 | _C.PROMPTS.ACTION_PREDICT.SYSTEM = ""
40 | _C.PROMPTS.ACTION_PREDICT.USER = ""
41 |
42 | _C.PROMPTS.ACTION_RESULT = CN()
43 | _C.PROMPTS.ACTION_RESULT.SYSTEM = ""
44 | _C.PROMPTS.ACTION_RESULT.USER = ""
45 |
46 |
47 | def get_cfg_defaults():
48 | """Get a yacs CfgNode object with default values for my_project."""
49 | # Return a clone so that the defaults will not be altered
50 | # This is for the "local variable" use pattern
51 | return _C.clone()
52 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Android in the Zoo: Chain-of-Action-Thought for GUI Agents
2 | 
3 | 
4 | Jiwen Zhang<sup>1,2</sup>, Jihao Wu<sup>2</sup>, Yihua Teng<sup>2</sup>, Minghui Liao<sup>2</sup>, Nuo Xu<sup>2</sup>, Xiao Xiao<sup>2</sup>, Zhongyu Wei<sup>1</sup>, Duyu Tang<sup>2</sup>
5 | 
6 | <sup>1</sup>Fudan University&emsp;<sup>2</sup>Huawei Inc.
7 | 
23 | --------------
24 |
25 | This work presents **Chain-of-Action-Thought** (dubbed **CoAT**), which takes into account the description of the previous actions, the current screen, and, more importantly, the thinking about which actions should be performed and the outcomes that the chosen action leads to. To enable adaptive learning of the CoAT process, we construct the benchmark **Android-In-The-Zoo**, which contains 18,643 screen-action pairs together with CoAT annotations.
26 |
27 |
28 |

29 |
30 |
31 |
32 | ## 📣 Update
33 |
34 | - **[2024-10-15]** Evaluation code has been released!
35 |
36 | - **[2024-09-20]** Our work has been accepted to EMNLP2024 Findings!
37 |
38 | - **[2024-07-16]** We have added demo code for using CoAT with proprietary models (GPT-4V, Gemini-Pro, and Qwen-VL-Max)!
39 |
40 | - **[2024-03-31]** We release the first version of our AiTZ dataset!
41 |
42 | - **[2024-03-05]** Our paper is now on arXiv; you can access it by clicking [here](https://arxiv.org/abs/2403.02713)!
43 |
44 |
45 |
46 | ## Android-in-the-Zoo
47 |
48 | The AiTZ data contains 18,643 screens together with 2,500+ instructions, all annotated with CoAT-driven semantic labels. The sample format for each time step is:
49 |
50 | ```json
51 | {
52 | "episode_id": "523638528775825151",
53 | "episode_length": 4,
54 | "step_id": 0,
55 | "coat_screen_desc": "[observe]",
56 | "coat_action_think": "[action think]",
57 | "coat_action_desc": "[next action description]",
58 | "coat_action_result": "[action result]",
59 | ...
60 | }
61 | ```
62 |
63 | You can refer to the `data-example` folder for a concrete example.
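
As a quick illustration, here is a minimal sketch (the path follows the repository layout above) that loads the bundled example episode and prints its CoAT annotations step by step:

```python
import json

# Example episode shipped with this repository (see data-example/ above).
episode_path = "data-example/GOOGLE_APPS-523638528775825151/GOOGLE_APPS-523638528775825151.json"

with open(episode_path, "r") as f:
    episode = json.load(f)  # a list of per-step dicts, ordered by step_id

for step in episode:
    print(step["step_id"], "|", step["coat_action_desc"])
    print("  think :", step["coat_action_think"])
    print("  result:", step["coat_action_result"])
```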
64 |
65 |
66 | ### Download
67 |
68 | Our dataset ([GoogleDrive](https://drive.google.com/file/d/12xOV2m62fBUFLhMcWIsFiC6zwV7a2RhI/view?usp=sharing) or [BaiduNetdisk](https://pan.baidu.com/s/1dHG-4L0RE1aYINzMSA4dCw?pwd=7g82)) contains both the screens (.png) and the annotations (.json), taking up about 2.6 GB of disk space.
69 |
70 |
71 | ### Statistics
72 |
73 | | Subset      | Train \#Episodes | Train \#Screens | Test \#Episodes | Test \#Screens |
74 | | ----------- | ---------------- | --------------- | --------------- | -------------- |
75 | | General     | 323              | 2405            | 156             | 1202           |
76 | | Install     | 286              | 2519            | 134             | 1108           |
77 | | GoogleApps  | 166              | 1268            | 76              | 621            |
78 | | Single      | 844              | 2594            | 0               | 0              |
79 | | WebShopping | 379              | 5133            | 140             | 1793           |
80 | | **Total**   | **1998**         | **13919**       | **506**         | **4724**       |
82 |
83 |
84 |
85 | ## Chain-of-Action-Thought
86 |
87 | ### Comparison with other context modeling methods
88 |
89 | We validate the effectiveness of CoAT through a preliminary experiment on 50 episodes randomly sampled from the AITW dataset.
90 |
91 | The compared baselines are [Chain-of-Thought](https://arxiv.org/abs/2201.11903) (CoT) and [Chain-of-Actions](https://arxiv.org/abs/2309.11436) (CoA).
92 |
93 | | Prompt | Metric | QwenVL | Gemini-PV | GPT-4V |
94 | | ------ | ------ | ------ | --------- | ------ |
95 | | CoA | hit | 94.5 | 99.8 | 99.3 |
96 | | | acc | 44.4 | 47.7 | 62.8 |
97 | | CoT | hit | 95.6 | 97.5 | 97.1 |
98 | | | acc | 49.4 | 52.0 | 64.1 |
99 | | CoAT | hit | 96.3 | 96.4 | 98.2 |
100 | | | acc | 52.4 | 54.5 | 73.5 |
101 |
102 | where “hit” means the format hit rate (whether the model response can be parsed into a valid action) and “acc” means the action type prediction accuracy. (One can refer to Table 8 in our paper for more details.)
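
As a rough illustration (this is not the evaluation script itself; see `coat/evaluate.py` for the full logic), both rates reduce to simple counts over per-step records:

```python
# Hypothetical per-step records, using the field names emitted by coat/evaluate.py:
# "format_hit" marks whether the response parsed into a valid action,
# "type_match" marks whether the predicted action type equals the ground truth.
records = [
    {"format_hit": True,  "type_match": True},
    {"format_hit": True,  "type_match": False},
    {"format_hit": False, "type_match": False},
]

hit = sum(r["format_hit"] for r in records) / len(records)  # format hit rate
acc = sum(r["type_match"] for r in records) / len(records)  # action type accuracy
print(f"hit = {hit:.1%}, acc = {acc:.1%}")                   # hit = 66.7%, acc = 33.3%
```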
103 |
104 |
105 |
106 |
107 | ### CoAT demo usage
108 |
109 | Here we provide demo code for anyone who wants to try CoAT with GPT-4V, Qwen-VL-Max, and Gemini-1.0-Pro-Vision.
110 |
111 | First, go to `coat/config.yaml` and add your own API keys and URLs.
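
If you are unsure which fields the YAML file needs, the defaults in `coat/config.py` enumerate all of them. One option (a sketch, assuming the `coat` package is importable from the repository root) is to dump a template and then fill in your keys:

```python
# Dump the default config defined in coat/config.py as a YAML template.
# The output filename is arbitrary; merge its contents into coat/config.yaml as needed.
from coat.config import get_cfg_defaults

cfg = get_cfg_defaults()
with open("config.template.yaml", "w") as f:
    f.write(cfg.dump())  # yacs CfgNode.dump() serializes the config to a YAML string
```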
112 |
113 | Second, run the following command to generate the semantic components of the CoAT framework:
114 |
115 | ```shell
116 | python run_coat.py --task "flow" --DEMO_MODE "COAT" --MODEL.NAME "openai/gemini/qwenvl" --num-threads 3
117 | ```
118 |
119 | Then, you can obtain the action prediction results by running:
120 |
121 | ```shell
122 | python run_coat.py --task "predict" --DEMO_MODE "COAT" --MODEL.NAME "openai/gemini/qwenvl" --num-threads 3
123 | ```
124 |
125 |
126 |
127 |
128 |
129 | ## Citation
130 |
131 | If you find our work helpful, please consider citing our paper.
132 |
133 | ```
134 | @misc{zhang2024android,
135 | title={Android in the Zoo: Chain-of-Action-Thought for GUI Agents},
136 | author={Jiwen Zhang and Jihao Wu and Yihua Teng and Minghui Liao and Nuo Xu and Xiao Xiao and Zhongyu Wei and Duyu Tang},
137 | year={2024},
138 | eprint={2403.02713},
139 | archivePrefix={arXiv},
140 | primaryClass={cs.CL}
141 | }
142 | ```
143 |
144 |
--------------------------------------------------------------------------------
/run_coat.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import yaml
4 | import time
5 | import random
6 | import pandas as pd
7 | from collections import defaultdict
8 | from yacs.config import CfgNode as CN
9 | from tqdm import tqdm
10 |
11 | from coat.model import get_model
12 | from coat.action_utils import extract_gt_action
13 | from coat.agent import ScreenAgent
14 | import threading
15 |
16 |
17 | class AITZDataset(object):
18 | """
19 | Creating a custom dataset for reading the processed AITW data
20 | with GPT-4V labeled detail semantic annotations.
21 | """
22 | DATASET_DIR = {
23 | 'general': '{}/general',
24 | 'google_apps': '{}/google_apps',
25 | 'install': '{}/install',
26 | 'single': '{}/single',
27 | 'web_shopping': '{}/web_shopping',
28 | }
29 |
30 | def __init__(self, split="test", data_dir="android-in-the-zoo", ratio=1.0, double_sample=False) -> None:
31 | self.ratio = ratio
32 | self.double_sample = double_sample
33 | self.data_dir = os.path.join(data_dir, split)
34 |
35 | self.episode_data = self._load_data_()
36 | self.data = self._split_to_steps_(self.episode_data)
37 |
38 | def _load_data_(self):
39 | valid_paths = defaultdict(list)
40 | for subset in self.DATASET_DIR:
41 | subdata_dir = self.DATASET_DIR[subset].format(self.data_dir)
42 | if os.path.exists(subdata_dir):
43 | sequence_names = os.listdir(subdata_dir)
44 | for seq_name in sequence_names:
45 | seq_dir = os.path.join(subdata_dir, seq_name)
46 | if not os.path.isdir(seq_dir): continue
47 | episode_path = os.path.join(seq_dir, f"{seq_name}.json")
48 | valid_paths[subset].append(episode_path)
49 |
50 | sampled_paths = []
51 | for subset, v_paths in valid_paths.items():
52 | N = len(v_paths)
53 | k=int(self.ratio*N)
54 | # random.sample(v_paths, k) if self.ratio < 1.0 else v_paths
55 | sampled_paths += random.sample(v_paths, k) if self.ratio < 1.0 else v_paths
56 |
57 | ep_data = []
58 | for episode_path in sampled_paths:
59 | episode_data = json.load(open(episode_path, "r"))
60 | ep_data.append(episode_data)
61 | return ep_data
62 |
63 | def _split_to_steps_(self, episode_data):
64 | data = []
65 | for edx, episode in enumerate(episode_data):
66 | history_plain_actions, history_coat_actions = [], []
67 | for idx, step in enumerate(episode):
68 | step['subset'] = step['image_path'].split('/')[0]
69 | step['image_full_path'] = os.path.join(self.data_dir, step['image_path'])
70 | step['prev_step_id'] = episode[idx-1]['step_id'] if idx > 0 else None
71 | next_img_path = os.path.join(self.data_dir, episode[idx+1]['image_path']) \
72 | if idx + 1 < len(episode) else None
73 | step['next_image_full_path'] = next_img_path
74 | step['history_actions'] = history_plain_actions[:]
75 | step['history_coat_actions'] = history_coat_actions[:]
76 | step['result_action'] = extract_gt_action(step)[1]
77 | for ui_key in ['ui_positions', 'ui_text', 'ui_types']:
78 | step[ui_key] = json.loads(step[ui_key])
79 | data.append(step)
80 | history_plain_actions.append(step['result_action'])
81 | history_coat_actions.append(step['coat_action_desc'])
82 |
83 | return data
84 |
85 | def __len__(self, ): return len(self.data)
86 |
87 | def __getitem__(self, index): return self.data[index]
88 |
89 | # ==========================================================================================
90 |
91 | def single_run_flow(agent:ScreenAgent, step_data:dict, save_dir:str):
92 | agent.flow(step_data, save_dir=save_dir)
93 |
94 |
95 | def single_run_predict(agent:ScreenAgent, step_data:dict, save_dir:str):
96 | agent.predict(step_data, save_dir=save_dir)
97 |
98 |
99 | def collect(cfg, num_threads=2, task="flow"):
100 | # cfg.MODEL.NAME = "openai"
101 |     if task == "flow": todo_task = single_run_flow
102 |     if task == "predict": todo_task = single_run_predict
103 |
104 | if cfg.MODEL.NAME in ["gemini", "openai"]:
105 | print("using proxy ... ")
106 | os.environ['http_proxy'] = "your_proxy"
107 | os.environ['https_proxy'] = "your_proxy"
108 |
109 | aitz = AITZDataset(cfg.DATA.SPLIT, cfg.DATA.DATA_DIR, ratio=0.1, double_sample=False)
110 | print(len(aitz), len(aitz.episode_data))
111 |
112 | save_dir = os.path.join(cfg.OUTPUT_DIR, cfg.MODEL.NAME)
113 | agent = ScreenAgent(config=cfg)
114 |
115 | before_count = threading.active_count() + 1
116 | threads = []
117 | last_time = time.time()
118 | pbar = tqdm(aitz.data)
119 | for idx, step_data in enumerate(pbar):
120 | thread = threading.Thread(target=todo_task, args=(agent, step_data, save_dir))
121 | time.sleep(max(0.01-(time.time()-last_time), 0))
122 | thread.start()
123 | last_time = time.time()
124 | threads.append(thread)
125 | pbar.set_description(f"Active threads [{threading.active_count()-before_count}]")
126 | while threading.active_count() - before_count >= num_threads: time.sleep(1)
127 |
128 | if len(threads) == 100:
129 | for thr in threads: thr.join()
130 | threads = []
131 |
132 | print()
133 | while threading.active_count() - before_count >0: time.sleep(1)
134 | for thr in threads: thr.join()
135 |
136 | # ==========================================================================================
137 |
138 | def try_model(cfg):
139 | cfg.MODEL.NAME = "qwenvl"
140 |
141 | if cfg.MODEL.NAME in ["gemini", "openai"]:
142 | os.environ['http_proxy'] = "your_proxy"
143 | os.environ['https_proxy'] = "your_proxy"
144 |
145 | prompt = "Describe the image."
146 | image_path = "android-in-the-wild/aitw_with_gpt/test/general/GENERAL-1359994677477286277/GENERAL-1359994677477286277_2.png"
147 | model = get_model(cfg.MODEL, seed=2024)
148 | res_json, res_state = model.get_response(image_path, prompt)
149 | print(res_json, res_state)
150 | pass
151 |
152 | # ==========================================================================================
153 |
154 | if __name__ == "__main__":
155 | import argparse
156 |
157 | parser = argparse.ArgumentParser(description="CoAT Demo")
158 | parser.add_argument("--config-file", default="coat/config.yaml", metavar="FILE", help="path to config file",)
159 | parser.add_argument("--task", default="eval", type=str, choices=['try', 'flow', 'predict', 'eval'])
160 | parser.add_argument("--num-threads", default=1, type=int, help="number of threads")
161 | parser.add_argument("--seed", default=2020, type=int, help="random seed")
162 | parser.add_argument("opts", help="Modify config options using the command-line",
163 | default=None, nargs=argparse.REMAINDER, )
164 | args = parser.parse_args()
165 | print(args)
166 | random.seed(args.seed)
167 |
168 | cfg = CN(yaml.safe_load(open(args.config_file)))
169 | cfg.merge_from_list(args.opts)
170 |
171 | if args.task == "try":
172 | try_model(cfg)
173 | if args.task in ['flow', 'predict']:
174 | collect(cfg, num_threads=args.num_threads, task=args.task)
175 |
176 |
--------------------------------------------------------------------------------
/coat/screen_utils.py:
--------------------------------------------------------------------------------
1 | from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator, Dict
2 | from PIL import Image, ImageDraw, ImageFont
3 | from functools import cmp_to_key
4 | import numpy as np
5 |
6 | FONT_PATH = "/System/Library/Fonts/Helvetica.ttc"  # macOS system font; adjust this path on other platforms
7 |
8 | # ========================================================================================
9 |
10 | def node_compare(node1, node2):
11 | if node1 is None and node2 is None:
12 | return 0
13 | if node1 is None:
14 | return 1
15 | if node2 is None:
16 | return -1
17 |
18 | l1, t1, r1, b1 = node1['bounds'] # top_left, right_bot
19 | l2, t2, r2, b2 = node2['bounds']
20 |
21 | base_height = 0
22 |     if t1 > t2:  # node1's top edge is below node2's top edge
23 |         base_height = (t2 + b2) / 2
24 |         if t1 > base_height: return 1  # node1's top is below node2's vertical center -> node1 belongs to a later row
25 |     elif t1 < t2:  # node1's top edge is above node2's top edge
26 |         base_height = (t1 + b1) / 2
27 |         if t2 > base_height: return -1  # node2's top is below node1's vertical center -> node2 belongs to a later row
28 | 
29 |     width_diff = l1 - l2
30 |     if width_diff > 0: return 1  # node1 is to the right of node2
31 | elif width_diff < 0: return -1
32 |
33 | return 0
34 |
35 |
36 | def row_col_sort(nodes):
37 |     """ Sort nodes from left to right and from top to bottom.
38 | 
39 |     Scan from top to bottom, first grouping elements that lie on the same row, then sort only the elements within each row from left to right.
40 |     """
41 | if len(nodes) <= 1: return nodes
42 |
43 |     # First sort along the y axis, considering both the minimum y and the y coordinate of the center point
44 |     # nodes.sort(key=lambda x: (x['bounds'][1], (x['bounds'][1] + x['bounds'][3])*0.5))
45 | 
46 |     # Updated sorting rule -- follows an OCR-style reading order
47 | nodes.sort(key=cmp_to_key(node_compare))
48 |
49 | first_node = nodes[0]
50 | other_node = nodes[1:]
51 | sort_dots = [[first_node]]
52 |
53 | line_index = 0
54 | for node in other_node:
55 | center_node_y = (node['bounds'][1] + node['bounds'][3]) * 0.5
56 |
57 | line_nodes = sort_dots[line_index]
58 | prev_avg_center_y = sum([(x['bounds'][1] + x['bounds'][3])*0.5 for x in line_nodes]) / len(line_nodes)
59 |
60 | if (node['bounds'][1] < prev_avg_center_y < node['bounds'][3]) \
61 | and (line_nodes[-1]['bounds'][1] < center_node_y < line_nodes[-1]['bounds'][3]):
62 |             # the average y-center of the current row lies within this node's y range, and
63 |             # this node's y-center lies between the previous node's y-min and y-max => same row
64 | sort_dots[line_index].append(node)
65 |         else:  # the node starts a new row
66 | line_index += 1
67 | sort_dots.append([node])
68 |
69 |     for dot in sort_dots:  # sort each row by its minimum x coordinate
70 | dot.sort(key=lambda x: x['bounds'][0])
71 |
72 | new_nodes = [dot for dots in sort_dots for dot in dots]
73 | return new_nodes
74 |
75 | # ========================================================================================
76 |
77 |
78 | def draw_bbox(image_path, bboxes:List[Tuple[float, float, float, float]],
79 | texts=None, rgba_color=(0, 0, 255, 0), thickness=1, ret_corrds=False):
80 | """ Draw the bounding boxes with corresponding texts """
81 | image = Image.open(image_path)
82 | w, h = image.size
83 |
84 | text_coords = []
85 |
86 | with image.convert('RGBA') as base:
87 | tmp = Image.new("RGBA", base.size, (0, 0, 0, 0))
88 | # get a drawing context
89 | draw = ImageDraw.Draw(tmp)
90 | for idx, bbox in enumerate(bboxes):
91 | xmin, ymin, xmax, ymax = bbox[:4]
92 | xmin = min(max(0, xmin), w)
93 | xmax = min(max(xmin, xmax), w)
94 | ymin = min(max(0, ymin), h)
95 | ymax = min(max(ymin, ymax), h)
96 |             # draw the bounding box
97 | draw.rectangle((xmin, ymin, xmax, ymax), outline=rgba_color, width=thickness)
98 |
99 | # draw text if any
100 | if texts:
101 | text = texts[idx]
102 | box_height, box_width = ymax - ymin, xmax - xmin
103 | font_size = int(min(max(int(box_height * 0.7), 14), 36))
104 |
105 | font = ImageFont.truetype(FONT_PATH, font_size, encoding="utf-8")
106 | left, top, right, bot = font.getbbox(text)
107 | coords = [
108 | xmin, ymin,
109 | xmin + right*1.1, ymin,
110 | xmin + right*1.1, ymin - bot*1.1,
111 | xmin, ymin - bot*1.1
112 | ]
113 | draw.polygon(coords, fill=rgba_color)
114 | draw.text((xmin, ymin - bot*1.05), text, (255,255,255,255), font=font)
115 | text_coords.append([coords[0], coords[5], coords[2], coords[1]])
116 |
117 | out = Image.alpha_composite(base, tmp)
118 |
119 | if ret_corrds: return out, text_coords
120 | return out
121 |
122 |
123 | def enlarge_bbox(bbox_list, scale_factor=1.2)->np.ndarray:
124 |     """
125 |     Enlarge each bounding box by a given factor.
126 | 
127 |     :param bbox_list: list of bounding boxes, each a tuple/list of four values (xmin, ymin, xmax, ymax)
128 |     :param scale_factor: enlargement factor
129 |     :return: array of enlarged bounding boxes
130 |     """
131 | bbox_array = np.array(bbox_list)
132 | x_min, y_min, x_max, y_max = \
133 | bbox_array[:, 0], bbox_array[:, 1], bbox_array[:, 2], bbox_array[:, 3]
134 |
135 |     # compute the center of each bounding box
136 | x_center = (x_min + x_max) / 2
137 | y_center = (y_min + y_max) / 2
138 |
139 |     # compute the scaled width and height of each bounding box
140 | width = (x_max - x_min) * scale_factor
141 | height = (y_max - y_min) * scale_factor
142 |
143 |     # compute the new coordinates of the enlarged bounding boxes
144 | new_x_min = x_center - width / 2
145 | new_y_min = y_center - height / 2
146 | new_x_max = x_center + width / 2
147 | new_y_max = y_center + height / 2
148 |
149 |     # stack the new coordinates back into an array of bounding boxes
150 | enlarged_bbox_list = np.vstack((new_x_min, new_y_min, new_x_max, new_y_max)).T
151 |
152 | return enlarged_bbox_list
153 |
154 |
155 |
156 | def check_inside(x, y, bbox_array):
157 |     """
158 |     Check whether a point (x, y) lies inside any bounding box in a list, using NumPy for efficiency.
159 |     Also return the coordinates of all bounding boxes that contain the point.
160 | 
161 |     :param x: x coordinate of the point
162 |     :param y: y coordinate of the point
163 |     :param bbox_array: array of bounding boxes, each a tuple/list of four values (xmin, ymin, xmax, ymax)
164 |     :return: a tuple whose first element is a boolean, True if the point lies inside any bounding box, otherwise False;
165 |              the second element is the array of coordinates of all bounding boxes that contain the point
166 |     """
167 | x_min, y_min, x_max, y_max = bbox_array[:, 0], bbox_array[:, 1], bbox_array[:, 2], bbox_array[:, 3]
168 |
169 |     # check whether (x, y) lies inside any of the bounding boxes
170 | within_x = (x_min <= x) & (x <= x_max)
171 | within_y = (y_min <= y) & (y <= y_max)
172 | within_bbox = within_x & within_y
173 |
174 | if np.any(within_bbox):
175 | within_bbox_coords = bbox_array[within_bbox]
176 | return True, within_bbox_coords
177 | else:
178 | return False, None
179 |
180 |
181 | def intersect_iou(can_bbox, ref_bboxes):
182 |     """
183 |     Compute the IoU between one bounding box and a set of bounding boxes.
184 | 
185 |     Args:
186 |     - can_bbox: NumPy array or list of shape [4,], a single bounding box [x_min, y_min, x_max, y_max]
187 |     - ref_bboxes: NumPy array of shape [N, 4], N bounding boxes
188 | 
189 |     Returns:
190 |     - ious: NumPy array of shape [N,], the IoU between the input box and each reference box
191 |     """
192 |     # coordinates of the intersection rectangle
193 | inter_xmin = np.maximum(can_bbox[0], ref_bboxes[:, 0])
194 | inter_ymin = np.maximum(can_bbox[1], ref_bboxes[:, 1])
195 | inter_xmax = np.minimum(can_bbox[2], ref_bboxes[:, 2])
196 | inter_ymax = np.minimum(can_bbox[3], ref_bboxes[:, 3])
197 |
198 |     # area of the intersection
199 | inter_area = np.maximum(0, inter_xmax - inter_xmin) * \
200 | np.maximum(0, inter_ymax - inter_ymin)
201 |
202 |     # area of the candidate bounding box
203 | can_bbox_area = np.maximum((can_bbox[2] - can_bbox[0]) * \
204 | (can_bbox[3] - can_bbox[1]), 1)
205 |     # areas of the reference bounding boxes
206 | ref_bboxes_area = np.maximum((ref_bboxes[:, 2] - ref_bboxes[:, 0]) * \
207 | (ref_bboxes[:, 3] - ref_bboxes[:, 1]), 1)
208 |
209 |     # area of the union
210 | union_area = can_bbox_area + ref_bboxes_area - inter_area
211 |
212 |     # compute the IoU
213 | ious = inter_area / union_area
214 | return ious
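
# Illustrative usage sketch (example values only; a reference box with twice the
# candidate's area and full overlap yields an IoU of 0.5):
#   >>> intersect_iou(np.array([0, 0, 10, 10]),
#   ...               np.array([[0, 0, 10, 10], [0, 0, 20, 10]]))
#   array([1. , 0.5])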
--------------------------------------------------------------------------------
/coat/evaluate.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import json
4 | import imagesize
5 | import numpy as np
6 | import pandas as pd
7 | import Levenshtein
8 | from collections import defaultdict
9 | from .screen_utils import row_col_sort, enlarge_bbox, check_inside, intersect_iou
10 |
11 |
12 | class ActionEvaluator(object):
13 | BBOX_PATTERN = re.compile(r'\[ *(\d+) *, *(\d+) *, *(\d+) *, *(\d+) *\]')
14 |
15 | def __init__(self, demo_mode, screen_mode) -> None:
16 | self.demo_mode = demo_mode
17 | self.screen_mode = screen_mode
18 |
19 | self._all_action_types_ = [
20 | "click",
21 | "scroll",
22 | "type",
23 | "press",
24 | "stop"
25 | ]
26 |
27 | def action_map(self, action_api:str):
28 | if not action_api: return None
29 | action_api = action_api.lower()
30 | if action_api == "input": return "type"
31 | for act_type in self._all_action_types_:
32 | if act_type in action_api: return act_type
33 | return None
34 |
35 | def _parse_action_(self, pred, w, h):
36 | pr = pred.get('action_predict', {})
37 | if self.demo_mode not in pr: return (None,) * 6
38 |
39 | action = pr[self.demo_mode].get(self.screen_mode, {})
40 | if not action: return (None,) * 6
41 |
42 | pd_action_type = self.action_map(action.get('ACTION', None))
43 | # if pd_action_type is None: print(action)
44 |
45 | pd_action_args = action.get('ARGS', {})
46 | if not isinstance(pd_action_args, dict): pd_action_args = {}
47 | pd_action_bbox = pd_action_args.get('bbox', None)
48 | if pd_action_bbox is not None:
49 | xmin, ymin, xmax, ymax = pd_action_bbox[:4]
50 | xmin = round(xmin/1000 * w)
51 | ymin = round(ymin/1000 * h)
52 | xmax = round(xmax/1000 * w)
53 | ymax = round(ymax/1000 * h)
54 | pd_action_bbox = [xmin, ymin, xmax, ymax]
55 |
56 | pd_action_idx = pd_action_args.get('idx', None)
57 | if pd_action_idx:
58 | try: pd_action_idx = int(pd_action_idx)
59 |             except (ValueError, TypeError): pd_action_idx = None
60 | pd_action_direction = pd_action_args.get('direction', None)
61 | pd_action_text = pd_action_args.get('text', "")
62 | pd_action_button = None if pd_action_type != "press" else \
63 | action['ACTION'].split("_")[1].lower()
64 |
65 | return pd_action_type, pd_action_bbox, pd_action_idx, \
66 | pd_action_direction, pd_action_text, pd_action_button
67 |
68 | def _parse_answer_(self, gt):
69 | gt_words = gt['coat_action_desc'].split(' ')
70 |
71 | gt_action_type = self.action_map(gt_words[0])
72 | if gt_action_type is None: print(gt['subset'], gt['episode_id'])
73 | gt_action_text = gt['result_action_text']
74 | gt_action_direction = "" if gt_action_type != "scroll" else gt_words[1].strip()
75 | gt_action_button = ""
76 | if gt_action_type == "press":
77 | for button in ['enter', 'back', 'home']:
78 | if button in gt['coat_action_desc']:
79 | gt_action_button = button
80 | break
81 |
82 | w, h = imagesize.get(gt['image_full_path'])
83 | gt_action_xy = [0, 0]
84 | if gt_action_type == "scroll":
85 | rel_y, rel_x = json.loads(gt['result_touch_yx'])
86 | abs_y, abs_x = int(rel_y*h), int(rel_x*w)
87 | gt_action_xy = [abs_x, abs_y]
88 |
89 | gt_cand_nodes = []
90 | for org_bbox, txt, ui_class in zip(gt['ui_positions'], gt['ui_text'], gt['ui_types']):
91 |             ymin, xmin, box_h, box_w = org_bbox  # (y, x, height, width); avoid shadowing the image-level w, h
92 |             bbox = [xmin, ymin, xmin + box_w, ymin + box_h]
93 | gt_cand_nodes.append({"bounds": bbox, "text": txt, "type": ui_class})
94 | gt_cand_nodes = row_col_sort(gt_cand_nodes)
95 |
96 | return gt_action_type, gt_action_xy, gt_cand_nodes, \
97 | gt_action_text, gt_action_button, gt_action_direction
98 |
99 | def _check_click_(self, pred_bbox, gt_xy, gt_nodes):
100 | # gt_xy is within pred_bbox
101 | if not pred_bbox: return False
102 | pred_bbox = enlarge_bbox([pred_bbox])[0]
103 | xmin, ymin, xmax, ymax = pred_bbox
104 | gt_x, gt_y = gt_xy
105 | is_correct = (xmin <= gt_x <= xmax and ymin <= gt_y <= ymax)
106 | if is_correct: return True
107 |
108 | # gt_xy is within any bbox
109 | bbox_array = enlarge_bbox([x['bounds'] for x in gt_nodes], scale_factor=1.2)
110 | is_inside, bbox_inside = check_inside(gt_x, gt_y, bbox_array)
111 | if is_inside:
112 | ious = intersect_iou(pred_bbox, bbox_inside)
113 | if np.any(ious > 0.5): return True
114 |
115 | return False
116 |
117 | def __call__(self, gt, pred):
118 | """ eval_single_step """
119 | subset, episode_id, step_id = gt['subset'], gt['episode_id'], gt['step_id']
120 | w, h = imagesize.get(gt['image_full_path'])
121 |
122 | # get ground truth information
123 | gt_action_type, gt_action_xy, gt_cand_nodes, \
124 | gt_action_text, gt_action_button, gt_action_direction = self._parse_answer_(gt)
125 | if not gt_action_type: print(gt['coat_action_desc'])
126 | gt_action_detail = {
127 | "click": gt_action_xy, "scroll": gt_action_direction,
128 | "type": gt_action_text, "press": gt_action_button, "stop": "stop"
129 | }.get(gt_action_type, None)
130 |
131 | # get predict action information
132 | pd_action_type, pd_action_bbox, pd_action_idx, \
133 | pd_action_direction, pd_action_text, pd_action_button = self._parse_action_(pred, w, h)
134 |
135 | # compute metrics
136 | hit_format = True if pd_action_type is not None else False
137 | type_match = (pd_action_type is not None and gt_action_type == pd_action_type)
138 |
139 | exact_match = False
140 | pd_action_detail = None
141 | text_dist = None
142 | if type_match and pd_action_type == "click":
143 | if self.screen_mode == "tag" and pd_action_idx: # transform idx into bbox
144 | if 0 <= pd_action_idx < len(gt_cand_nodes):
145 | pd_action_bbox = gt_cand_nodes[pd_action_idx]['bounds']
146 | pd_action_detail = pd_action_bbox
147 | exact_match = self._check_click_(pd_action_bbox, gt_action_xy, gt_cand_nodes)
148 |
149 | if type_match and pd_action_type == "scroll":
150 | pd_action_detail = pd_action_direction
151 | exact_match = (pd_action_direction == gt_action_direction)
152 |
153 | if type_match and pd_action_type == "type":
154 | pd_action_detail = pd_action_text
155 | text_dist = Levenshtein.ratio(pd_action_text, gt_action_text)
156 | exact_match = (pd_action_text in gt_action_text or \
157 | gt_action_text in pd_action_text or \
158 | text_dist > 0.8)
159 |
160 | if type_match and pd_action_type == "press":
161 | pd_action_detail = pd_action_button
162 | exact_match = (pd_action_button == gt_action_button)
163 |
164 | if type_match and pd_action_type == "stop":
165 | pd_action_detail = "stop"
166 | exact_match = True
167 |
168 | return {
169 | "subset": subset, "episode_id": episode_id, "step_id": step_id,
170 | "answer": {"action_type": gt_action_type, "action_detail": gt_action_detail},
171 | "pred": {"action_type": pd_action_type, "action_detail": pd_action_detail},
172 | "type_match": type_match, "exact_match": exact_match,
173 | "text_dist": text_dist, "format_hit": hit_format
174 | }
175 |
176 | def compute_episode_metrics(self, episode_results):
177 | success, progress = [], []
178 | for __, eplist in episode_results.items():
179 | ep_success, ep_progress = True, 0
180 | for ex in eplist:
181 | if ex['exact_match'] is True: ep_progress += 1
182 | else: ep_success = False
183 | if not ep_success: break
184 | success.append(ep_success)
185 | progress.append(ep_progress/len(eplist)*1.0)
186 |
187 | return {"success_rate": round(sum(success) / len(success), 4),
188 | "goal_progress": round(sum(progress) / len(progress), 4)}
189 |
190 | def compute_atomic_metrics(self, step_results):
191 | recorder = {
192 | 'total': {'count':0, 'type_match':0, 'exact_match':0, "hit": 0},
193 | # -------------------------------------------
194 | 'CLICK': {'count':0, 'type_match':0, 'exact_match':0},
195 | 'TYPE': {'count':0, 'type_match':0, 'exact_match':0, 'text_dist': []},
196 | 'SCROLL': {'count':0, 'type_match':0, 'exact_match':0},
197 | 'PRESS': {'count':0, 'type_match':0, 'exact_match':0},
198 | 'STOP': {'count':0, 'type_match':0, 'exact_match':0},
199 | }
200 |
201 | for step in step_results:
202 | recorder['total']['count'] += 1
203 | recorder['total']['hit'] += step['format_hit']
204 |
205 | action_type = step['answer']['action_type'].upper()
206 | recorder[action_type]['count'] += 1
207 | recorder[action_type]['type_match'] += step['type_match']
208 | recorder['total']['type_match'] += step['type_match']
209 | recorder[action_type]['exact_match'] += step['exact_match']
210 | recorder['total']['exact_match'] += step['exact_match']
211 | if 'text_dist' in recorder[action_type] and step['text_dist'] is not None:
212 | recorder[action_type]['text_dist'].append(step['text_dist'])
213 |
214 | scores = {metric_key:{} for metric_key in ['total', 'CLICK', 'SCROLL', 'PRESS', 'STOP', 'TYPE']}
215 | scores['total']['hit_rate'] = round(recorder['total']['hit']/recorder['total']['count'], 4)
216 | for metric_key in ['total', 'CLICK', 'SCROLL', 'PRESS', 'STOP', "TYPE"]:
217 | scores[metric_key]['type_acc'] = round(recorder[metric_key]['type_match']/recorder[metric_key]['count'], 4)
218 | scores[metric_key]['exact_acc'] = round(recorder[metric_key]['exact_match']/recorder[metric_key]['count'], 4)
219 | if recorder['TYPE']['text_dist']:
220 | scores['TYPE']['text_dist'] = round(sum(recorder['TYPE']['text_dist'])/len(recorder['TYPE']['text_dist']), 4)
221 | return scores
222 |
223 |
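# Illustrative usage sketch for the episode-level metrics (the step dicts below are
# hand-written stand-ins for the records returned by __call__, and the demo_mode /
# screen_mode values are placeholders):
#   >>> evaluator = ActionEvaluator(demo_mode="COAT", screen_mode="txt")
#   >>> evaluator.compute_episode_metrics(
#   ...     {"episode-1": [{"exact_match": True}, {"exact_match": False}]})
#   {'success_rate': 0.0, 'goal_progress': 0.5}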
--------------------------------------------------------------------------------
/data-example/GOOGLE_APPS-523638528775825151/GOOGLE_APPS-523638528775825151.json:
--------------------------------------------------------------------------------
1 | [
2 | {
3 | "episode_id": "523638528775825151",
4 | "episode_length": 4,
5 | "step_id": 0,
6 | "instruction": "open app \"Clock\" (install if not already installed)",
7 | "ui_positions": "[[54, 17, 8, 12], [82, 15, 8, 17], [82, 37, 10, 46], [120, 15, 20, 10], [128, 43, 8, 24], [160, 15, 8, 13], [162, 43, 7, 96], [189, 17, 14, 7], [194, 43, 5, 22], [225, 15, 10, 12], [227, 43, 8, 91], [256, 15, 15, 9], [261, 43, 8, 20], [574, 51, 15, 6], [574, 207, 16, 7]]",
8 | "ui_text": "[\"M\", \"Set\", \"up email\", \"G\", \"Coogle\", \"o\", \"Outlook, Homail, and Live\", \"\", \"Yatpe\", \"O\", \"Exchange and Ofioe 365\", \"\", \"Uthe\", \"\", \"\"]",
9 | "ui_types": "[\"TEXT\", \"TEXT\", \"TEXT\", \"ICON_GOOGLE\", \"TEXT\", \"TEXT\", \"TEXT\", \"ICON_ENVELOPE\", \"TEXT\", \"TEXT\", \"TEXT\", \"ICON_ENVELOPE\", \"TEXT\", \"ICON_V_BACKWARD\", \"ICON_NAV_BAR_RECT\"]",
10 | "result_action_type": 6,
11 | "result_action_text": "",
12 | "result_touch_yx": "[-1.0, -1.0]",
13 | "result_lift_yx": "[-1.0, -1.0]",
14 | "image_path": "google_apps/GOOGLE_APPS-523638528775825151/GOOGLE_APPS-523638528775825151_0.png",
15 | "image_full_path": "android-in-the-wild/aitw_with_gpt/train/google_apps/GOOGLE_APPS-523638528775825151/GOOGLE_APPS-523638528775825151_0.png",
16 | "coat_screen_desc": "This screenshot shows the \"Set up email\" screen on a smartphone, likely within an email application setup process. The screen offers a selection of popular email service providers including Google, Outlook (Hotmail and Live), Yahoo, Exchange and Office 365, as well as an \"Other\" option for email services not listed. Users can presumably tap on one of these options to begin the process of adding their email account to the device. The top of the screen displays standard status icons such as signal strength, battery, and time, indicating it's 9:35. The bottom of the screen shows navigation buttons for back, home, and recent applications, common to Android operating systems.",
17 | "coat_action_think": "The screen does not display the \"Clock\" app or any relevant features to access it; it is focused on email setup. Possible actions are to exit the email setup screen by tapping the home or back button, and then proceed to open the \"Clock\" app from the home screen or app drawer, or install it if it is not present.",
18 | "coat_action_desc": "press the home button",
19 | "coat_action_result": "By doing so, the home screen is displayed with app icons visible. This allows access to the app drawer or search function, where the Clock app can be located and opened."
20 | },
21 | {
22 | "episode_id": "523638528775825151",
23 | "episode_length": 4,
24 | "step_id": 1,
25 | "instruction": "open app \"Clock\" (install if not already installed)",
26 | "ui_positions": "[[56, 24, 8, 52], [366, 33, 28, 15], [368, 86, 30, 31], [373, 223, 15, 7], [408, 155, 5, 22], [409, 24, 5, 31], [409, 94, 5, 18], [409, 215, 5, 28], [464, 33, 23, 13], [465, 94, 21, 15], [526, 34, 21, 10], [529, 24, 23, 21], [573, 207, 16, 7], [575, 51, 14, 6]]",
27 | "ui_text": "[\"Man, Aug 8\", \"\", \"M\", \"\", \"Fhotca\", \"Pey Stoe\", \"imal\", \"auTuba\", \"\", \"\", \"\", \"G\", \"\", \"\"]",
28 | "ui_types": "[\"TEXT\", \"ICON_PLAY\", \"TEXT\", \"ICON_PLAY\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\", \"ICON_CALL\", \"ICON_CHAT\", \"ICON_GOOGLE\", \"TEXT\", \"ICON_NAV_BAR_RECT\", \"ICON_V_BACKWARD\"]",
29 | "result_action_type": 4,
30 | "result_action_text": "",
31 | "result_touch_yx": "[0.541063666343689, 0.5073748230934143]",
32 | "result_lift_yx": "[0.001115699764341116, 0.5788536071777344]",
33 | "image_path": "google_apps/GOOGLE_APPS-523638528775825151/GOOGLE_APPS-523638528775825151_1.png",
34 | "image_full_path": "android-in-the-wild/aitw_with_gpt/train/google_apps/GOOGLE_APPS-523638528775825151/GOOGLE_APPS-523638528775825151_1.png",
35 | "coat_screen_desc": "This is a screenshot of a smartphone home screen displaying a clean and simple layout. The top of the screen shows the time as \"5:35 AM\" and the date \"Mon, Aug 8\". There are five app icons visible: Play Store, Gmail, Photos, YouTube, along with a bottom row featuring the Phone, Chrome, Messages, and Camera apps, suggesting these are likely the user's most frequently used applications. A Google search bar is at the bottom, emphasizing easy access to web searches. The overall appearance is typical of an Android interface.",
36 | "coat_action_think": "Since the \"Clock\" app is not among the visible icons on the home screen, the next step would be to search for it by opening the app drawer or using the search feature. Possible actions are to swipe up on the home screen to access the app drawer or tap the Google search bar to type in \"Clock\" and search for the app.",
37 | "coat_action_desc": "scroll up",
38 | "coat_action_result": "By doing so, the screen now displays a list of applications including the Clock app which was not visible on the home screen before the action. This action enables the user to access and open the Clock app directly from the app drawer where it is located."
39 | },
40 | {
41 | "episode_id": "523638528775825151",
42 | "episode_length": 4,
43 | "step_id": 2,
44 | "instruction": "open app \"Clock\" (install if not already installed)",
45 | "ui_positions": "[[34, 236, 16, 5], [40, 24, 8, 138], [82, 36, 30, 15], [88, 160, 16, 9], [88, 212, 18, 24], [124, 27, 5, 33], [124, 93, 5, 22], [124, 150, 5, 28], [124, 216, 5, 18], [160, 121, 5, 7], [161, 133, 5, 15], [195, 216, 23, 17], [205, 91, 7, 24], [211, 34, 7, 18], [235, 36, 5, 17], [235, 83, 7, 43], [235, 150, 5, 28], [235, 210, 5, 21], [281, 34, 27, 15], [283, 218, 20, 10], [321, 37, 5, 15], [321, 91, 5, 15], [321, 156, 5, 18], [321, 206, 5, 7], [405, 138, 8, 45], [406, 217, 7, 15], [407, 33, 5, 21], [407, 96, 5, 13], [449, 94, 32, 18], [449, 191, 37, 43], [451, 217, 25, 15], [456, 31, 17, 24], [491, 149, 7, 31], [492, 34, 5, 17], [492, 93, 7, 22], [492, 209, 5, 33], [536, 37, 28, 12], [540, 94, 21, 15], [577, 87, 5, 34], [577, 215, 7, 21], [578, 36, 5, 24], [578, 155, 5, 20]]",
46 | "ui_text": "[\"\", \"Search your phone and more\", \"\", \"\", \"M\", \"Plaw S1ae\", \"Photcs\", \"YouTube\", \"Cmall\", \"All\", \"apps\", \"\", \"Bocking\", \"oirtel\", \"Airtel\", \"Hrokinc cor\", \"Laendar\", \"Cema\", \"\", \"\", \"Chal\", \"Chru\", \"Cleck\", \"C\", \"File Mener\", \"Fles\", \"Dagher\", \"Drve\", \"G\", \"O\", \"\", \"M\", \"HRO MAX\", \"mal\", \"Coagla\", \"Irataamy\", \"\", \"\", \"Masscces\", \"hotes\", \"Maps\", \"Phanc\"]",
47 | "ui_types": "[\"ICON_THREE_DOTS\", \"TEXT\", \"ICON_PLAY\", \"ICON_PLAY\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\", \"ICON_TAKE_PHOTO\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\", \"ICON_CHAT\", \"ICON_PERSON\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\", \"ICON_GOOGLE\", \"TEXT\", \"ICON_TAKE_PHOTO\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\", \"ICON_LOCATION\", \"ICON_CHAT\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\"]",
48 | "result_action_type": 4,
49 | "result_action_text": "",
50 | "result_touch_yx": "[0.49836206436157227, 0.6069772839546204]",
51 | "result_lift_yx": "[0.49669790267944336, 0.6069772839546204]",
52 | "image_path": "google_apps/GOOGLE_APPS-523638528775825151/GOOGLE_APPS-523638528775825151_2.png",
53 | "image_full_path": "android-in-the-wild/aitw_with_gpt/train/google_apps/GOOGLE_APPS-523638528775825151/GOOGLE_APPS-523638528775825151_2.png",
54 | "coat_screen_desc": "This is a screenshot of a smartphone home screen displaying various application icons, suggesting a user interface likely from an Android device. The top of the screen shows the notification bar with the time, battery status, and connectivity icons. Below that is a search bar, followed by frequently used apps such as Play Store, Photos, YouTube, and Gmail. Further down, there's a categorized section labeled \"All apps\" showcasing a grid of additional applications including Airtel, Booking.com, Calendar, and others, indicating the user has access to a range of services and tools for productivity, communication, and entertainment.",
55 | "coat_action_think": "The Clock app is displayed in the list of apps under the \"All apps\" section on the home screen. Possible actions are to tap the Clock app icon to open it.",
56 | "coat_action_desc": "click on the Clock app located at the upper middle right side of the screen.",
57 | "coat_action_result": "By doing so, the Clock app has been successfully opened, providing access to features such as alarms, world clocks, and timers. This action enables the user to manage time-related functions, which is the likely intention behind the query."
58 | },
59 | {
60 | "episode_id": "523638528775825151",
61 | "episode_length": 4,
62 | "step_id": 3,
63 | "instruction": "open app \"Clock\" (install if not already installed)",
64 | "ui_positions": "[[20, 86, 16, 9], [23, 254, 15, 3], [43, 21, 5, 20], [43, 83, 5, 18], [43, 144, 7, 20], [43, 197, 7, 37], [82, 84, 28, 96], [123, 111, 8, 45], [573, 130, 16, 7], [574, 51, 15, 6], [574, 207, 16, 7]]",
65 | "ui_text": "[\"\", \"\", \"larm\", \"Clock\", \"Tirmer\", \"Stopwatch\", \"5:35AM\", \"Mon, Aug 8\", \"\", \"\", \"\"]",
66 | "ui_types": "[\"ICON_TIME\", \"ICON_THREE_DOTS\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\", \"ICON_SUN\", \"ICON_V_BACKWARD\", \"ICON_NAV_BAR_RECT\"]",
67 | "result_action_type": 10,
68 | "result_action_text": "",
69 | "result_touch_yx": "[-1.0, -1.0]",
70 | "result_lift_yx": "[-1.0, -1.0]",
71 | "image_path": "google_apps/GOOGLE_APPS-523638528775825151/GOOGLE_APPS-523638528775825151_3.png",
72 | "image_full_path": "android-in-the-wild/aitw_with_gpt/train/google_apps/GOOGLE_APPS-523638528775825151/GOOGLE_APPS-523638528775825151_3.png",
73 | "coat_screen_desc": "The screenshot shows the clock application of a smartphone displayed at 5:35 AM, Mon, Aug 8. It offers four functionalities: Alarm, Clock, Timer, and Stopwatch, indicated by icons at the top of the screen. The main display is digital, with large numbers for the hours and minutes. At the bottom, there's a button with a clock icon, presumably for setting or accessing alarm functions. The status bar at the very top displays the current battery level, signal strength, and indicates alarm and night mode are active.",
74 | "coat_action_think": "The screen is displaying the Clock application with its four main functions visible at the top: Alarm, Clock, Timer, and Stopwatch. The task appears to be in relation to accessing the Clock app, which is already on the screen. Possible actions are to stop and set the query as completed.",
75 | "coat_action_desc": "stop and set the query as completed",
76 | "coat_action_result": "By opening the \"Clock\" app, the user can interact with features such as setting alarms, checking the time, using a timer, or a stopwatch. The reason for opening this app is to provide the user with access to time-related functions that they have requested to use."
77 | }
78 | ]
--------------------------------------------------------------------------------
/coat/action_utils.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2023 The Google Research Authors.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | import json
17 | import enum
18 | import imagesize
19 | # import jax
20 | import jax.numpy as jnp
21 | import numpy as np
22 |
23 |
24 | def extract_gt_action(example):
25 | ex_action_type = example['result_action_type']
26 |
27 | if ex_action_type == ActionType.DUAL_POINT:
28 | lift_yx = json.loads(example['result_lift_yx'])
29 | touch_yx = json.loads(example['result_touch_yx'])
30 | if is_tap_action(np.array(touch_yx), np.array(lift_yx)):
31 | action_type = 'CLICK'
32 | w, h = imagesize.get(example['image_full_path'])
33 | click_y, click_x = round(lift_yx[0] * h), round(lift_yx[1] * w)
34 | action = (click_y, click_x)
35 | else:
36 | action_type = 'SCROLL'
37 | v_change = abs(touch_yx[0] - lift_yx[0])
38 | h_change = abs(lift_yx[1] - touch_yx[1])
39 |             is_scroll_up = lift_yx[0] < touch_yx[0]    # lift point ends above the touch point
40 |             is_scroll_left = lift_yx[1] < touch_yx[1]  # lift point ends left of the touch point
41 | if v_change >= 0.9*h_change: # vertical
42 | action = "scroll up" if is_scroll_up else "scroll down"
43 |             else:  # horizontal
44 | action = "scroll left" if is_scroll_left else "scroll right"
45 | elif ex_action_type in (
46 | ActionType.PRESS_BACK,
47 | ActionType.PRESS_HOME,
48 | ):
49 | button = ActionType(ex_action_type).name.split('_')[1].lower()
50 | action = f'press the {button} button'
51 | action_type = f'PRESS {button}'.upper()
52 | elif ex_action_type == ActionType.PRESS_ENTER:
53 | action = "press enter"
54 | action_type = 'PRESS ENTER'
55 | elif ex_action_type == ActionType.TYPE:
56 | action_text = example['result_action_text']
57 | action = f'input text "{action_text}"'
58 | action_type = 'INPUT'
59 | elif ex_action_type == ActionType.STATUS_TASK_COMPLETE:
60 | action = 'stop and set the query as completed'
61 | action_type = 'STOP'
62 | elif ex_action_type == ActionType.STATUS_TASK_IMPOSSIBLE:
63 | action = 'stop and set the query as impossible'
64 | action_type = 'STOP'
65 | else:
66 | raise NotImplementedError
67 |
68 | return action, action_type
69 |
70 |
71 | '''======================================
72 | Global Args
73 | ======================================'''
74 |
75 |
76 | _TAP_DISTANCE_THRESHOLD = 0.14 # Fraction of the screen
77 | ANNOTATION_WIDTH_AUGMENT_FRACTION = 1.4
78 | ANNOTATION_HEIGHT_AUGMENT_FRACTION = 1.4
79 |
80 | # Interval determining if an action is a tap or a swipe.
81 | _SWIPE_DISTANCE_THRESHOLD = 0.04
82 |
83 |
84 | def is_tap_action(normalized_start_yx, normalized_end_yx):
85 | distance = jnp.linalg.norm(
86 | jnp.array(normalized_start_yx) - jnp.array(normalized_end_yx))
87 | return distance <= _SWIPE_DISTANCE_THRESHOLD
88 |
89 |
90 | def _is_non_dual_point_action(action_type):
91 | return jnp.not_equal(action_type, ActionType.DUAL_POINT)
92 |
93 |
94 | '''======================================
95 | === AndroidInTheWild action types. ===
96 | ======================================'''
97 |
98 |
99 | class ActionType(enum.IntEnum):
100 | """Integer values for each supported action type in AndroidInTheWild."""
101 |
102 | # Placeholders for unused enum values
103 | UNUSED_0 = 0
104 | UNUSED_1 = 1
105 | UNUSED_2 = 2
106 | UNUSED_8 = 8
107 | UNUSED_9 = 9
108 |
109 | ########### Agent actions ###########
110 |
111 | # A type action that sends text to the emulator. Note that this simply sends
112 | # text and does not perform any clicks for element focus or enter presses for
113 | # submitting text.
114 | TYPE = 3
115 |
116 | # The dual point action used to represent all gestures.
117 | DUAL_POINT = 4
118 |
119 | # These actions differentiate pressing the home and back button from touches.
120 | # They represent explicit presses of back and home performed using ADB.
121 | PRESS_BACK = 5
122 | PRESS_HOME = 6
123 |
124 | # An action representing that ADB command for hitting enter was performed.
125 | PRESS_ENTER = 7
126 |
127 | ########### Episode status actions ###########
128 |
129 | # An action used to indicate the desired task has been completed and resets
130 | # the environment. This action should also be used in the case that the task
131 | # has already been completed and there is nothing to do.
132 | # e.g. The task is to turn on the Wi-Fi when it is already on
133 | STATUS_TASK_COMPLETE = 10
134 |
135 | # An action used to indicate that desired task is impossible to complete and
136 | # resets the environment. This can be a result of many different things
137 | # including UI changes, Android version differences, etc.
138 | STATUS_TASK_IMPOSSIBLE = 11
139 |
140 |
141 | '''===============================================================================
142 | === Utilites for performing action matching on AndroidInTheWild data. ===
143 | === Note: this code is implemented using JAX so it can be 'vmap'ed and ===
144 | === efficiently applied over a batch of data. ===
145 | ==============================================================================='''
146 |
147 |
148 | def _yx_in_bounding_boxes(
149 | yx, bounding_boxes
150 | ):
151 | """Check if the (y,x) point is contained in each bounding box.
152 |
153 | Args:
154 | yx: The (y, x) coordinate in pixels of the point.
155 | bounding_boxes: A 2D int array of shape (num_bboxes, 4), where each row
156 | represents a bounding box: (y_top_left, x_top_left, box_height,
157 | box_width). Note: containment is inclusive of the bounding box edges.
158 |
159 | Returns:
160 | is_inside: A 1D bool array where each element specifies if the point is
161 | contained within the respective box.
162 | """
163 | y, x = yx
164 |
165 | # `bounding_boxes` has shape (n_elements, 4); we extract each array along the
166 | # last axis into shape (n_elements, 1), then squeeze unneeded dimension.
167 | top, left, height, width = [
168 | jnp.squeeze(v, axis=-1) for v in jnp.split(bounding_boxes, 4, axis=-1)
169 | ]
170 |
171 | # The y-axis is inverted for AndroidEnv, so bottom = top + height.
172 | bottom, right = top + height, left + width
173 |
174 | return jnp.logical_and(y >= top, y <= bottom) & jnp.logical_and(x >= left, x <= right)
175 |
176 |
177 | def _resize_annotation_bounding_boxes(
178 | annotation_positions, annotation_width_augment_fraction,
179 | annotation_height_augment_fraction):
180 | """Resize the bounding boxes by the given fractions.
181 |
182 | Args:
183 | annotation_positions: Array of shape (N, 4), where each row represents the
184 | (y, x, height, width) of the bounding boxes.
185 | annotation_width_augment_fraction: The fraction to augment the box widths,
186 | E.g., 1.4 == 240% total increase.
187 | annotation_height_augment_fraction: Same as described for width, but for box
188 | height.
189 |
190 | Returns:
191 | Resized bounding box.
192 |
193 | """
194 | height_change = (
195 | annotation_height_augment_fraction * annotation_positions[:, 2])
196 | width_change = (
197 | annotation_width_augment_fraction * annotation_positions[:, 3])
198 |
199 | # Limit bounding box positions to the screen.
200 | resized_annotations = jnp.stack([
201 | jnp.maximum(0, annotation_positions[:, 0] - (height_change / 2)),
202 | jnp.maximum(0, annotation_positions[:, 1] - (width_change / 2)),
203 | jnp.minimum(1, annotation_positions[:, 2] + height_change),
204 | jnp.minimum(1, annotation_positions[:, 3] + width_change),
205 | ],axis=1)
206 |
207 | return resized_annotations
208 |
209 |
210 | def is_tap_action(normalized_start_yx, normalized_end_yx):
211 | distance = jnp.linalg.norm(
212 | jnp.array(normalized_start_yx) - jnp.array(normalized_end_yx))
213 | return distance <= _SWIPE_DISTANCE_THRESHOLD
214 |
215 |
216 | def _is_non_dual_point_action(action_type):
217 | return jnp.not_equal(action_type, ActionType.DUAL_POINT)
218 |
219 |
220 | def _check_tap_actions_match(
221 | tap_1_yx,
222 | tap_2_yx,
223 | annotation_positions,
224 | matching_tap_distance_threshold_screen_percentage,
225 | annotation_width_augment_fraction,
226 | annotation_height_augment_fraction,
227 | ):
228 | """Determines if two tap actions are the same."""
229 | resized_annotation_positions = _resize_annotation_bounding_boxes(
230 | annotation_positions,
231 | annotation_width_augment_fraction,
232 | annotation_height_augment_fraction,
233 | )
234 |
235 | # Check if the ground truth tap action falls in an annotation's bounding box.
236 | tap1_in_box = _yx_in_bounding_boxes(tap_1_yx, resized_annotation_positions)
237 | tap2_in_box = _yx_in_bounding_boxes(tap_2_yx, resized_annotation_positions)
238 | both_in_box = jnp.max(tap1_in_box & tap2_in_box)
239 |
240 | # If the ground-truth tap action falls outside any of the annotation
241 | # bounding boxes or one of the actions is inside a bounding box and the other
242 | # is outside bounding box or vice versa, compare the points using Euclidean
243 | # distance.
244 | within_threshold = (
245 | jnp.linalg.norm(jnp.array(tap_1_yx) - jnp.array(tap_2_yx))
246 | <= matching_tap_distance_threshold_screen_percentage
247 | )
248 | return jnp.logical_or(both_in_box, within_threshold)
249 |
250 |
251 | def _check_drag_actions_match(
252 | drag_1_touch_yx,
253 | drag_1_lift_yx,
254 | drag_2_touch_yx,
255 | drag_2_lift_yx,
256 | ):
257 | """Determines if two drag actions are the same."""
258 | # Store drag deltas (the change in the y and x coordinates from touch to
259 | # lift), magnitudes, and the index of the main axis, which is the axis with
260 | # the greatest change in coordinate value (e.g. a drag starting at (0, 0) and
261 | # ending at (0.3, 0.5) has a main axis index of 1).
262 | drag_1_deltas = drag_1_lift_yx - drag_1_touch_yx
263 | drag_1_magnitudes = jnp.abs(drag_1_deltas)
264 | drag_1_main_axis = np.argmax(drag_1_magnitudes)
265 | drag_2_deltas = drag_2_lift_yx - drag_2_touch_yx
266 | drag_2_magnitudes = jnp.abs(drag_2_deltas)
267 | drag_2_main_axis = np.argmax(drag_2_magnitudes)
268 |
269 | return jnp.equal(drag_1_main_axis, drag_2_main_axis)
270 |
271 |
272 | def check_actions_match(
273 | action_1_touch_yx,
274 | action_1_lift_yx,
275 | action_1_action_type,
276 | action_2_touch_yx,
277 | action_2_lift_yx,
278 | action_2_action_type,
279 | annotation_positions,
280 | tap_distance_threshold = _TAP_DISTANCE_THRESHOLD,
281 | annotation_width_augment_fraction = ANNOTATION_WIDTH_AUGMENT_FRACTION,
282 | annotation_height_augment_fraction = ANNOTATION_HEIGHT_AUGMENT_FRACTION,
283 | ):
284 | """Determines if two actions are considered to be the same.
285 |
286 | Two actions being "the same" is defined here as two actions that would result
287 | in a similar screen state.
288 |
289 | Args:
290 | action_1_touch_yx: The (y, x) coordinates of the first action's touch.
291 | action_1_lift_yx: The (y, x) coordinates of the first action's lift.
292 | action_1_action_type: The action type of the first action.
293 | action_2_touch_yx: The (y, x) coordinates of the second action's touch.
294 | action_2_lift_yx: The (y, x) coordinates of the second action's lift.
295 | action_2_action_type: The action type of the second action.
296 | annotation_positions: The positions of the UI annotations for the screen. It
297 |       is a 2D int array of shape (num_bboxes, 4), where each row represents a
298 | bounding box: (y_top_left, x_top_left, box_height, box_width). Note that
299 | containment is inclusive of the bounding box edges.
300 | tap_distance_threshold: The threshold that determines if two taps result in
301 |       a matching screen state if they don't fall in the same bounding boxes.
302 | annotation_width_augment_fraction: The fraction to increase the width of the
303 | bounding box by.
304 | annotation_height_augment_fraction: The fraction to increase the height of
305 |       the bounding box by.
306 |
307 | Returns:
308 | A boolean representing whether the two given actions are the same or not.
309 | """
310 | action_1_touch_yx = jnp.asarray(action_1_touch_yx)
311 | action_1_lift_yx = jnp.asarray(action_1_lift_yx)
312 | action_2_touch_yx = jnp.asarray(action_2_touch_yx)
313 | action_2_lift_yx = jnp.asarray(action_2_lift_yx)
314 |
315 | # Checks if at least one of the actions is global (i.e. not DUAL_POINT),
316 | # because if that is the case, only the actions' types need to be compared.
317 | has_non_dual_point_action = jnp.logical_or(
318 | _is_non_dual_point_action(action_1_action_type),
319 | _is_non_dual_point_action(action_2_action_type),
320 | )
321 |
322 | different_dual_point_types = jnp.logical_xor(
323 | is_tap_action(action_1_touch_yx, action_1_lift_yx),
324 | is_tap_action(action_2_touch_yx, action_2_lift_yx),
325 | )
326 |
327 | is_tap = jnp.logical_and(
328 | is_tap_action(action_1_touch_yx, action_1_lift_yx),
329 | is_tap_action(action_2_touch_yx, action_2_lift_yx),
330 | )
331 |
332 | taps_match = _check_tap_actions_match(
333 | action_1_touch_yx,
334 | action_2_touch_yx,
335 | annotation_positions,
336 | tap_distance_threshold,
337 | annotation_width_augment_fraction,
338 | annotation_height_augment_fraction,
339 | )
340 |
341 | taps_match = jnp.logical_and(is_tap, taps_match)
342 |
343 | drags_match = _check_drag_actions_match(
344 | action_1_touch_yx, action_1_lift_yx, action_2_touch_yx, action_2_lift_yx
345 | )
346 | drags_match = jnp.where(is_tap, False, drags_match)
347 |
348 | return jnp.where(
349 | has_non_dual_point_action,
350 | jnp.equal(action_1_action_type, action_2_action_type),
351 | jnp.where(
352 | different_dual_point_types,
353 | False,
354 | jnp.logical_or(taps_match, drags_match),
355 | ),
356 | )
357 |
358 |
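A minimal usage sketch of the matcher above (not from the repository): the coordinates, the annotation box, and the variable names are hypothetical, and `ActionType.DUAL_POINT` plus the default thresholds are assumed to be defined earlier in this file, since the code above references them.

```python
import jax.numpy as jnp

# Hypothetical ground-truth vs. predicted taps, in normalized (y, x) coordinates.
gt_touch, gt_lift = jnp.array([0.52, 0.30]), jnp.array([0.52, 0.30])      # touch == lift -> a tap
pred_touch, pred_lift = jnp.array([0.55, 0.32]), jnp.array([0.55, 0.32])

# One UI annotation: (y_top_left, x_top_left, box_height, box_width), normalized.
annotation_positions = jnp.array([[0.48, 0.25, 0.10, 0.20]])

same = check_actions_match(
    gt_touch, gt_lift, ActionType.DUAL_POINT,
    pred_touch, pred_lift, ActionType.DUAL_POINT,
    annotation_positions,
)
print(bool(same))  # True here: both taps land inside the same (augmented) bounding box
```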
--------------------------------------------------------------------------------
/coat/model.py:
--------------------------------------------------------------------------------
1 | from typing import Union, List, Dict, Optional, Any
2 | from abc import abstractmethod
3 | import os
4 | import io
5 | import json
6 | import requests
7 | import random
8 | import base64
9 | import tempfile
10 | import numpy as np
11 | from PIL import Image
12 | from functools import partial
13 | from http import HTTPStatus
14 |
15 |
16 | class BaseModel(object):
17 |
18 | def __init__(self, **kwargs):
19 | pass
20 |
21 | @abstractmethod
22 | def format_query(self, *args: Any, **kwargs: Any):
23 | pass
24 |
25 | @abstractmethod
26 | def get_response(self, *args: Any, **kwargs: Any):
27 | pass
28 |
29 | @abstractmethod
30 | def chat(self, *args: Any, **kwargs: Any):
31 | pass
32 |
33 | def __call__(self, *args: Any, **kwargs: Any) -> Any:
34 | return self.get_response(*args, **kwargs)
35 |
36 |
37 | class OpenAIModel(BaseModel):
38 |
39 | def __init__(self, configs, seed=2024, **kwargs) -> None:
40 | super().__init__()
41 |
42 | self.api_key = configs['OPENAI_API_KEY']
43 | self.api_url = configs['OPENAI_API_URL']
44 | self.seed = seed
45 |
46 | def encode_image(self, image=None, image_path=None):
47 | assert image is not None or image_path is not None, \
48 | "Please specify at least an image array or a path to image."
49 | if image is not None:# convert image to bytes
50 | with io.BytesIO() as output_bytes:
51 | if isinstance(image, Image.Image): PIL_image = image
52 | elif isinstance(image, np.ndarray): PIL_image = Image.fromarray(image)
53 | else: raise TypeError(f"Unsupported image type {type(image)}")
54 | PIL_image.save(output_bytes, 'PNG')
55 | bytes_data = output_bytes.getvalue()
56 | return base64.b64encode(bytes_data).decode('utf-8')
57 | elif image_path is not None:
58 | with open(image_path, "rb") as image_file:
59 | return base64.b64encode(image_file.read()).decode('utf-8')
60 | else: raise NotImplementedError
61 |
62 | def format_query(self, image_paths:List[str], prompt:Union[str, dict, tuple], detail="high"):
63 | if isinstance(image_paths, str): image_paths = [image_paths]
64 |
65 | if isinstance(prompt, str):
66 | message = [{'role': 'user', 'content': [{'type': 'text', 'text': prompt},]}]
67 | if len(image_paths) != 0:
68 | for img_path in image_paths:
69 | base64_image = self.encode_image(image_path=img_path)
70 | message[0]['content'].append(
71 |                         {'type': 'image_url', 'image_url': {"url": f"data:image/jpeg;base64,{base64_image}", "detail": detail}}
72 | )
73 |
74 | if isinstance(prompt, list):
75 | message = []
76 | for role, msg in prompt:
77 | if isinstance(msg, dict) and "img_index" in msg:
78 | img_index = msg.pop('img_index')
79 | base64_image = self.encode_image(image_path=image_paths[img_index])
80 | message.append({'role': role, 'content': [
81 | {'type': 'text', 'text': msg['text']},
82 |                     {'type': 'image_url', 'image_url': {"url": f"data:image/jpeg;base64,{base64_image}", "detail": detail}}
83 | ]})
84 | else:
85 | if isinstance(msg, dict):
86 | message.append({'role': role, 'content': [{'type': 'text', 'text': msg['text']},]})
87 | else:
88 | message.append({'role': role, 'content': [{'type': 'text', 'text': msg},]})
89 |
90 |         if isinstance(prompt, dict):
91 |             message = []
92 |             for role in ['system', 'user', 'assistant']:
93 |                 if role in prompt:
94 |                     msg = {'role': role, 'content': []}
95 |                     if isinstance(prompt[role], str):
96 |                         msg['content'] = prompt[role]
97 |                     else:
98 |                         msg['content'].append({'type': 'text', 'text': prompt[role]['text']})
99 |                         if "img_index" in prompt[role]:
100 |                             img_index = prompt[role].pop('img_index')
101 |                             base64_image = self.encode_image(image_path=image_paths[img_index])
102 |                             msg['content'].append({'type': 'image_url', 'image_url': {"url": f"data:image/jpeg;base64,{base64_image}", "detail": detail}})
103 |                     message.append(msg)
104 | # message += [
105 | # {'role': 'user', 'content': [
106 | # {'type': 'text', 'text': user_prompt},
107 | # {'type': 'image_url', 'image_url': {"url": f"data:image/jepg;base64,{base64_image}", "detail": detail}}
108 | # ]}
109 | # ]
110 | return message
111 |
112 | def get_response(self, image_path:str, prompt:Union[str, dict], temperature=1, max_new_tokens=1024, **kwargs: Any):
113 | response_str, response_state, history = self.chat(
114 | image_path, prompt, history=[], temperature=temperature, max_new_tokens=max_new_tokens, **kwargs)
115 |
116 | return response_str, response_state
117 |
118 | def chat(self, image_path:str, prompt:Union[str, dict], history:list=[], temperature=1, max_new_tokens=1024, **kwargs: Any):
119 | message = history + self.format_query(image_path, prompt)
120 |
121 | headers = {
122 | 'Content-Type': 'application/json',
123 | 'Authorization': f'Bearer {self.api_key}'
124 | }
125 | payload = {
126 | # 'model': 'gpt-4-turbo',
127 | 'model': 'gpt-4-vision-preview',
128 | 'messages': message,
129 | 'max_tokens': max_new_tokens,
130 | 'seed': self.seed,
131 | }
132 |
133 | response = requests.post(self.api_url, headers=headers, json=payload)
134 | response_state = response.status_code
135 |
136 | try:
137 | response_json = response.json()
138 | response_str = response_json['choices'][0]['message']['content']
139 | except:
140 | response_json, response_str = {}, ""
141 |
142 | if response_state == 200:
143 | history += message
144 | history += [{'role': 'assistant', 'content': response_str}]
145 |
146 | return response_str, response_state, history
147 |
148 |
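A hedged usage sketch for OpenAIModel above. The config keys mirror the MODEL block in coat/config.yaml; the endpoint URL, API key, file name, and prompt text are placeholders rather than values taken from the repository.

```python
# Hypothetical configuration; OPENAI_API_URL is left blank in config.yaml, so the
# standard chat-completions endpoint is assumed here.
cfg = {
    "OPENAI_API_KEY": "sk-...",
    "OPENAI_API_URL": "https://api.openai.com/v1/chat/completions",
}
vlm = OpenAIModel(cfg)

# format_query accepts a plain string, a list of (role, message) pairs, or a
# {role: message} dict; a message dict may carry an "img_index" into the image list.
prompt = [
    ("system", "You describe smartphone screenshots."),
    ("user", {"text": "Describe this screen.", "img_index": 0}),
]
text, status = vlm.get_response(["screenshot.png"], prompt)
print(status, text)
```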
149 | class GeminiModel(BaseModel):
150 |
151 | def __init__(self, configs, **kwargs) -> None:
152 | super().__init__()
153 | import google.generativeai as genai
154 |
155 | genai.configure(api_key=configs['GEMINI_API_KEY'], transport='rest')
156 | self.model = genai.GenerativeModel(configs['GEMINI_MODEL'])
157 |
158 | def format_query(self, image_paths:str, prompt:Union[str, list, dict]):
159 | if isinstance(prompt, str):
160 | image = Image.open(image_paths)
161 | return [prompt, image]
162 |
163 | if isinstance(image_paths, str): image_paths = [image_paths]
164 |
165 | if isinstance(prompt, list):
166 | message = []
167 | for role, msg in prompt:
168 | if isinstance(msg, dict) and "img_index" in msg:
169 | img_index = msg.pop('img_index')
170 | image = Image.open(image_paths[img_index])
171 | message.extend([msg['text'], image])
172 | else:
173 | if isinstance(msg, dict): message.append(msg['text'])
174 | else: message.append(msg)
175 | return message
176 |
177 | def get_response(self, image_path:str, prompt:str, temperature=1, max_new_tokens=1024, **kwargs: Any):
178 | """
179 | Return:
180 | - response_str: (string) the decoded response
181 | - response_state: (int)
182 | STOP (1): Natural stop point of the model or provided stop sequence.
183 | MAX_TOKENS (2): The maximum number of tokens as specified in the request was reached.
184 | SAFETY (3): The candidate content was flagged for safety reasons.
185 | RECITATION (4): The candidate content was flagged for recitation reasons.
186 | OTHER (5): Unknown reason.
187 | """
188 | import google.generativeai as genai
189 |
190 | message = self.format_query(image_path, prompt)
191 | response = self.model.generate_content(message,
192 | generation_config=genai.types.GenerationConfig(
193 | candidate_count=1,
194 | temperature=temperature,
195 | max_output_tokens=max_new_tokens
196 | )
197 | )
198 | response_state = response.candidates[0].finish_reason
199 |
200 | try:
201 | response_str = response.candidates[0].content.parts[0].text
202 | if response_state == 1: response_state = HTTPStatus.OK
203 | except: response_str = ""
204 |
205 | return response_str.strip(), response_state
206 |
207 |
208 |
209 | qwenvl_special_str = """
210 | Your response should be strictly structured in JSON format, consisting of the following keys and corresponding content:
211 | {
212 | "THINK": ,
213 | "NEXT": ,
214 | "ACTION": ,
215 | "ARGS":
216 | }
217 | You should output only one next single-step action.
218 |
219 | :
220 | """
221 |
222 | class QwenVLModel(BaseModel):
223 |
224 | def __init__(self, configs, **kwargs):
225 | super().__init__()
226 | from dashscope import MultiModalConversation
227 |
228 | self.api_key = configs['DASHSCOPE_API_KEY']
229 | self.api_version = configs['DASHSCOPE_MODEL']
230 | self.model = partial(
231 | MultiModalConversation.call,
232 | model=self.api_version, api_key=self.api_key
233 | )
234 |
235 | def encode_image(self, image=None, image_path=None):
236 | assert image is not None or image_path is not None, \
237 | "Please specify at least an image array or a path to image."
238 | if image is not None:# convert image to bytes
239 | if isinstance(image, Image.Image): PIL_image = image
240 | elif isinstance(image, np.ndarray): PIL_image = Image.fromarray(image)
241 | else: raise TypeError(f"Unsupported image type {type(image)}")
242 | temp_directory = tempfile.gettempdir()
243 | unique_suffix = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5))
244 | filename = f"image{unique_suffix}.png"
245 | temp_image_path = os.path.join(temp_directory, filename)
246 | temp_image_url = f"file:///{temp_image_path}"
247 | temp_image_url = temp_image_url.replace('\\', '/')
248 | PIL_image.save(temp_image_path)
249 | image_url = temp_image_url
250 | elif image_path is not None: image_url = f'file://{os.path.abspath(image_path)}'
251 | else: raise NotImplementedError
252 | return image_url
253 |
254 | def format_query(self, image_paths:str, prompt:Union[str, dict], detail="high"):
255 | if isinstance(image_paths, str): image_paths = [image_paths]
256 |
257 | if isinstance(prompt, str):
258 | message = [{'role': 'user', 'content': [{'text': prompt},]}]
259 | if len(image_paths) != 0:
260 | for img_path in image_paths:
261 | image_url = self.encode_image(image_path=img_path)
262 | message[0]['content'].append({'image': image_url})
263 |
264 | if isinstance(prompt, list):
265 | message = []
266 | for role, msg in prompt:
267 | if isinstance(msg, dict) and "img_index" in msg:
268 | img_index = msg.pop('img_index')
269 | image_url = self.encode_image(image_path=image_paths[img_index])
270 | message.append({'role': role, 'content': [
271 | {'text': msg['text']}, {'image': image_url}]})
272 | else:
273 | if isinstance(msg, dict):
274 | message.append({'role': role, 'content': [{'text': msg['text']},]})
275 | else:
276 | message.append({'role': role, 'content': [{'text': msg},]})
277 | if role == "assistant":
278 | message.append({"role": "user", "content": [{"text": qwenvl_special_str}]})
279 |
280 | if isinstance(prompt, dict):
281 | message = []
282 | for role in ['system', 'user', 'assistant']:
283 | if role in prompt:
284 |                     msg = {'role': role, 'content': []}
285 | if isinstance(prompt[role], str):
286 | msg['content'].append({'text': prompt[role]})
287 | else:
288 | msg['content'].append({'text': prompt[role]['text']})
289 | if "img_index" in prompt[role]:
290 | img_index = prompt[role].pop('img_index')
291 | image_url = self.encode_image(image_path=image_paths[img_index])
292 | msg['content'].append({'image': image_url})
293 | message.append(msg)
294 |
295 | return message
296 |
297 | def chat(self, image_path:str, prompt:Union[str, dict], history:list=[], top_p=0.001, **kwargs: Any):
298 | message = history + self.format_query(image_path, prompt)
299 |
300 | response = self.model(messages=message, top_p=top_p)
301 | response_state = response.status_code
302 |
303 | try:
304 | response_str = response.output.choices[0].message.content[0]['text']
305 |             if response_state == HTTPStatus.OK:  # if the API call succeeded
306 | history += message
307 | history += [{
308 | 'role': 'assistant',
309 | 'content': [{'text': response_str}]
310 | }]
311 | except: response_str = ""
312 |
313 | return response_str.strip(), response_state, history
314 |
315 | def get_response(self, image_path:str, prompt:Union[str, dict], history:list=[], top_p=0.1, **kwargs: Any):
316 | response_str, response_state, history = self.chat(
317 | image_path, prompt, history=[], top_p=top_p, **kwargs)
318 |
319 | return response_str, response_state
320 |
321 |
322 | def get_model(model_config, seed=2024) -> BaseModel:
323 | if model_config['NAME'].lower() == "openai": return OpenAIModel(model_config, seed)
324 | if model_config['NAME'].lower() == "gemini": return GeminiModel(model_config)
325 | if model_config['NAME'].lower() == "qwenvl": return QwenVLModel(model_config)
326 | raise NotImplementedError
327 |
328 |
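A hedged sketch of backend selection through get_model. The dict mirrors the MODEL block in coat/config.yaml; the API key and image path are placeholders.

```python
# Hypothetical config; only NAME plus the matching credentials are read by the
# constructors defined above.
model_config = {
    "NAME": "gemini",
    "GEMINI_API_KEY": "...",
    "GEMINI_MODEL": "gemini-pro-vision",
}
vlm = get_model(model_config)
text, status = vlm.get_response("screenshot.png", "Describe this screen in one sentence.")
```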
--------------------------------------------------------------------------------
/coat/agent.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import json
4 | import traceback
5 | import imagesize
6 | import numpy as np
7 | from PIL import Image
8 | from http import HTTPStatus
9 | from .model import get_model
10 | from .screen_utils import row_col_sort, draw_bbox
11 |
12 |
13 | def json_parser(json_string:str):
14 | if json_string.startswith("```json"):
15 | json_string = json_string[7:]
16 | if json_string.startswith("```JSON"):
17 | json_string = json_string[7:]
18 | if json_string.endswith("```"):
19 | json_string = json_string[:-3]
20 |
21 | if "```json" in json_string:
22 | json_string = json_string.split("```json")[1]
23 | if "```JSON" in json_string:
24 | json_string = json_string.split("```JSON")[1]
25 | if "```" in json_string:
26 | json_string = json_string.split("```")[0]
27 |
28 | return json.loads(json_string)
29 |
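A quick illustration of the fenced replies json_parser is meant to tolerate (the reply text is hypothetical).

```python
fence = "```"
raw = fence + 'json\n{"ACTION": "scroll", "ARGS": {"direction": "down"}}\n' + fence
print(json_parser(raw))   # -> {'ACTION': 'scroll', 'ARGS': {'direction': 'down'}}
```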
30 |
31 |
32 | class BaseAgent(object):
33 |
34 | def __init__(self, config, *args, **kwargs) -> None:
35 | self.cfg = config
36 | self.prompts = config['PROMPTS']
37 | self.vlm = get_model(config['MODEL'])
38 |
39 | def observe(self, user_request, screen_path, **kwargs):
40 |         sys_prompt = self.prompts['SCREEN_DESC']['SYSTEM']
41 | usr_prompt = self.prompts['SCREEN_DESC']['USER']
42 | usr_prompt = [x.strip() for x in usr_prompt.split("{screenshot}")]
43 |
44 | prompt = [
45 |             ("system", sys_prompt),
46 | ("user", {"text": usr_prompt[0], "img_index": 0})
47 | ]
48 | for txt in usr_prompt[1:]: prompt.append(("user", {"text": txt}))
49 |
50 | res, state = self.vlm.get_response(screen_path, prompt)
51 | return res, state
52 |
53 | def think_action(self, user_request, screen_path, screen_desc, history_actions, **kwargs):
54 |         sys_prompt = self.prompts['ACTION_THINK_DESC']['SYSTEM']
55 |
56 | if history_actions: history_actions = ", ".join(history_actions)
57 | else: history_actions = "None"
58 | history_actions = history_actions + "."
59 |
60 | usr_prompt = self.prompts['ACTION_THINK_DESC']['USER']
61 | usr_prompt = usr_prompt.replace("{screen_desc}", screen_desc)
62 | usr_prompt = usr_prompt.replace("{history_actions}", history_actions)
63 | usr_prompt = usr_prompt.replace("{user_request}", user_request)
64 | usr_prompt = [x.strip() for x in usr_prompt.split("{screenshot}")]
65 |
66 | prompt = [
67 |             ("system", sys_prompt),
68 | ("user", {"text": usr_prompt[0], "img_index": 0})
69 | ]
70 | for txt in usr_prompt[1:]: prompt.append(("user", {"text": txt}))
71 |
72 | res, state = self.vlm.get_response(screen_path, prompt)
73 | return res, state
74 |
75 | def reflect_result(self, user_request, screen_path, next_screen_path, last_action, **kwargs):
76 |         sys_prompt = self.prompts['ACTION_RESULT']['SYSTEM']
77 |
78 | usr_prompt = self.prompts['ACTION_RESULT']['USER']
79 | usr_prompt = usr_prompt.replace("{last_action}", last_action)
80 | usr_prompt = usr_prompt.replace("{user_request}", user_request)
81 |
82 | usr_prompt = [x.strip() for x in usr_prompt.split("{before_screenshot}")]
83 | usr_prompt = [usr_prompt[0]] + [x.strip() for x in usr_prompt[1].split("{after_screenshot}")]
84 |
85 | prompt = [
86 |             ("system", sys_prompt),
87 | ("user", {"text": usr_prompt[0], "img_index": 0}),
88 | ("user", {"text": usr_prompt[1], "img_index": 1})
89 | ]
90 | for txt in usr_prompt[2:]: prompt.append(("user", {"text": txt}))
91 |
92 | res, state = self.vlm.get_response([screen_path, next_screen_path], prompt)
93 | return res, state
94 |
95 | def flow(self, step_data, save_dir, max_trials=5):
96 | subset, episode_id, step_id = step_data['subset'], step_data['episode_id'], step_data['step_id']
97 | user_request, image_path = step_data['instruction'], step_data['image_full_path']
98 |
99 | save_dir = os.path.join(save_dir, f"{subset}-{episode_id}")
100 | if not os.path.exists(save_dir): os.makedirs(save_dir)
101 | save_path = os.path.join(save_dir, f"{subset}-{episode_id}_{step_id}.json")
102 | if not os.path.exists(save_path): json.dump({}, open(save_path, "w", encoding="utf-8"))
103 |
104 | prev_anno = json.load(open(save_path, "r"))
105 | # if 'action_result' not in prev_anno: return
106 | # del prev_anno['action_result']
107 | # json.dump(prev_anno, open(save_path, "w", encoding="utf-8"), indent=4)
108 | # return
109 |
110 | update = False
111 | if 'screen_desc' not in prev_anno:
112 |             print(f"Generating [screen_desc] -> img_path {image_path}")
113 | try_num = 0
114 | while try_num < max_trials:
115 | res, state = self.observe(user_request, image_path)
116 | if state == HTTPStatus.OK:
117 | prev_anno['screen_desc'], update = res.strip(), True
118 | break
119 | try_num += 1
120 | time.sleep(5)
121 | if update: print(f"Updating [screen_desc] {save_path} ...")
122 | if update: json.dump(prev_anno, open(save_path, "w", encoding="utf-8"), indent=4)
123 |
124 | update = False
125 | if 'action_think' not in prev_anno:
126 |             print(f"Generating [action_think] -> img_path {image_path}")
127 | try_num = 0
128 | while try_num < max_trials:
129 | res, state = self.think_action(user_request, image_path,
130 | screen_desc=prev_anno['screen_desc'],
131 | history_actions=step_data['history_coat_actions'])
132 | if state == HTTPStatus.OK:
133 | try:
134 | response = json_parser(res)
135 | action_think = response['Thought']
136 | action_plan = response['Future Action Plan']
137 | action_desc = response['Next Single Step Action']
138 | prev_anno['action_think'] = action_think
139 | prev_anno['action_plan'] = action_plan
140 | prev_anno['action_desc'] = action_desc
141 | update = True
142 | break
143 | except Exception as e:
144 | if not isinstance(e, json.decoder.JSONDecodeError): print(traceback.format_exc())
145 | else: print(res,"\n", save_path, "\n", sep="")
146 | break
147 |                 print(f"Trial {try_num} failed! -- Status {state}")
148 | try_num += 1
149 | time.sleep(5)
150 | if update: print(f"Updating [action_think] {save_path} ...")
151 | if update: json.dump(prev_anno, open(save_path, "w", encoding="utf-8"), indent=4)
152 |
153 | update = False
154 | if 'action_result' not in prev_anno:
155 |             print(f"Generating [action_result] -> img_path {image_path}")
156 | if step_data['result_action_type'] not in [10, 11]:
157 | try_num = 0
158 | cur_image_path, next_image_path = image_path, step_data['next_image_full_path']
159 | while try_num < max_trials:
160 | cur_image_height = imagesize.get(cur_image_path)[1]
161 | next_image_height = imagesize.get(next_image_path)[1]
162 | res, state = self.reflect_result(user_request,
163 | screen_path=cur_image_path,
164 | next_screen_path=next_image_path,
165 | last_action=step_data['coat_action_desc'])
166 | if state == HTTPStatus.OK:
167 | prev_anno['action_result'], update = res, True
168 | break
169 | elif state == HTTPStatus.REQUEST_ENTITY_TOO_LARGE:
170 | cur_image_path = self.compress_image(image_path, max_height=int(0.8*cur_image_height))
171 | next_image_path = self.compress_image(step_data['next_image_full_path'], max_height=int(0.8*next_image_height))
172 | else: print(res)
173 |                     print(f"Trial {try_num} failed! -- Status {state}")
174 | try_num += 1
175 | time.sleep(5)
176 | else:
177 | prev_anno['action_result'] = "The execution of user request is stopped."
178 | update = True
179 | if update: print(f"Updating [action_result] {save_path} ...")
180 | if update: json.dump(prev_anno, open(save_path, "w", encoding="utf-8"), indent=4)
181 |
182 | def compress_image(self, image_path, max_height=1024, mode="scale"):
183 | org_img_path, ext = os.path.splitext(image_path)
184 |
185 | screen_img = Image.open(image_path)
186 | image_width, image_height = screen_img.size
187 |
188 | if image_height > max_height:
189 | scale_percent = max_height / image_height
190 | height = int(scale_percent * image_height)
191 | width = int(scale_percent * image_width)
192 | screen_img = screen_img.resize((width, height))
193 |
194 | if mode == "scale":
195 | save_img_path = org_img_path+"_tmp"+ext
196 | screen_img.save(save_img_path)
197 | elif mode == "jpg":
198 | screen_img = screen_img.convert("RGB")
199 | save_img_path = org_img_path+"_tmp"+".jpeg"
200 | screen_img.save(save_img_path)
201 | else:
202 | save_img_path = image_path
203 |
204 | return save_img_path
205 |
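A hedged sketch of the per-step record that BaseAgent.flow() consumes. The field names come from the accesses in the method above; every value, path, and id below is made up, and MODEL.NAME in config.yaml must first be set to one of the supported backends (openai, gemini, qwenvl).

```python
import yaml

# config.yaml is parsed as a plain dict here, which is all BaseAgent itself needs
# (it only reads config['PROMPTS'] and config['MODEL']).
with open("coat/config.yaml") as f:
    config = yaml.safe_load(f)

step_data = {
    "subset": "GENERAL",                       # hypothetical subset label, used only in the save path
    "episode_id": "demo-episode",              # hypothetical id
    "step_id": 0,
    "instruction": "open the clock app",
    "image_full_path": "screens/step_0.png",   # hypothetical screenshot paths
    "next_image_full_path": "screens/step_1.png",
    "history_coat_actions": [],
    "coat_action_desc": "Tap the Clock icon on the home screen.",
    "result_action_type": 4,                   # types 10/11 skip the before/after reflection call
}

agent = BaseAgent(config)
agent.flow(step_data, save_dir="outputs")
```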
206 |
207 | class ScreenAgent(BaseAgent):
208 |
209 | def __init__(self, config) -> None:
210 | super().__init__(config)
211 |
212 | self.use_screen_tag = self.cfg.DATA.USE_SCREEN_TAG
213 | self.use_screen_txt = (not self.use_screen_tag)
214 | self.screen_mode = "txt" if self.use_screen_txt else "tag"
215 | pass
216 |
217 | def format_bbox(self, bbox_xyxy, width, height):
218 | coord_sys, rep_type = self.cfg.DATA.BBOX_FORMAT.split("_")
219 |
220 | if coord_sys == "relative":
221 | bbox = np.array(bbox_xyxy) / np.array([width, height, width, height])
222 | if rep_type == "int":
223 | bbox = (bbox*1000).astype(np.int32).tolist()
224 | else:
225 | bbox = [round(x, 3) for x in bbox.tolist()]
226 | elif coord_sys == "absolute": bbox = bbox_xyxy
227 | else: raise NotImplementedError
228 | box_str = json.dumps(bbox)
229 | return box_str
230 |
231 | def add_screen_tag(self, step_data, save_dir, update=False):
232 | """ Set of Mark Prompting """
233 | src_image_path = step_data['image_full_path']
234 | base_name, ext = os.path.splitext(src_image_path)
235 | dst_image_path = base_name + "_tag" + ext
236 | # image_name = os.path.basename(src_image_path)
237 | # dst_image_path = os.path.join(save_dir, image_name)
238 |
239 | if not os.path.exists(dst_image_path) or update:
240 | ui_bboxes = []
241 | for ui_node in step_data['ui_elements']:
242 | ui_bboxes.append(ui_node['bounds'])
243 |
244 | tag_image = draw_bbox(src_image_path, bboxes=ui_bboxes,
245 | texts=list(map(str, range(len(step_data['ui_positions'])))),
246 | rgba_color=(255, 0, 0, 180), thickness=5)
247 | tag_image.save(dst_image_path)
248 | return dst_image_path
249 |
250 | def add_screen_txt(self, step_data, indent=2):
251 | """ Textual Representation of Screen Elements """
252 | image_path = step_data['image_full_path']
253 | w, h = imagesize.get(image_path)
254 |
255 | screen_str = []
256 | for ui_node in step_data['ui_elements']:
257 | ui_str = " "*indent + ui_node['type'].upper()
258 | if ui_node['text']: ui_str += " " + ui_node['text']
259 | bbox_str = self.format_bbox(ui_node['bounds'], w, h)
260 | ui_str += " " + bbox_str
261 | screen_str.append(ui_str)
262 |
263 | return "\n".join(screen_str)
264 |
265 | def make_action(self, step_data, usr_prompt, save_dir, asst_prompt=None):
266 | action_mode = self.cfg.DEMO_MODE.upper()
267 | image_path = step_data['image_full_path']
268 |
269 | ui_elements = []
270 | for org_bbox, txt, ui_class in zip(
271 | step_data['ui_positions'], step_data['ui_text'], step_data['ui_types']):
272 | ymin, xmin, h, w = org_bbox
273 | bbox = [xmin, ymin, xmin+w, ymin+h]
274 | ui_elements.append({"bounds": bbox, "text": txt, "type": ui_class})
275 | ui_elements = row_col_sort(ui_elements)
276 | step_data['ui_elements'] = ui_elements
277 |
278 | if self.use_screen_tag:
279 | sys_prompt = self.prompts['ACTION_PREDICT'][action_mode]['SYSTEM']['SCREEN_TAG']
280 | action_space = self.prompts['ACTION_PREDICT']['ACTION_SPACE']['SCREEN_TAG']
281 | sys_prompt = sys_prompt.replace("{action_space}", action_space)
282 | tag_image_path = self.add_screen_tag(step_data, save_dir=save_dir)
283 |
284 | if self.cfg.MODEL.NAME == "openai":
285 | image_path = self.compress_image(image_path, max_height=1440, mode="jpg")
286 | tag_image_path = self.compress_image(tag_image_path, max_height=1440, mode="scale")
287 |
288 | prompt = [
289 | ("system", sys_prompt),
290 | ("user", {"text": usr_prompt[0], "img_index": 0}),
291 | ("user", {"text": "", "img_index": 1})
292 | ]
293 | for txt in usr_prompt[1:]: prompt.append(("user", {"text": txt}))
294 | if asst_prompt: prompt.append(("assistant", {"text": asst_prompt}))
295 | asst_res, state = self.vlm.get_response([image_path, tag_image_path], prompt)
296 |
297 | if self.use_screen_txt:
298 | sys_prompt = self.prompts['ACTION_PREDICT'][action_mode]['SYSTEM']['SCREEN_TXT']
299 | action_space = self.prompts['ACTION_PREDICT']['ACTION_SPACE']['SCREEN_TXT']
300 | sys_prompt = sys_prompt.replace("{action_space}", action_space)
301 | screen_txt = self.add_screen_txt(step_data)
302 |
303 | prompt = [
304 | ("system", sys_prompt),
305 | ("user", {"text": usr_prompt[0], "img_index": 0}),
306 | ("user", {"text": f": {screen_txt}"})
307 | ]
308 | for txt in usr_prompt[1:]: prompt.append(("user", {"text": txt}))
309 | if asst_prompt: prompt.append(("assistant", {"text": asst_prompt}))
310 | asst_res, state = self.vlm.get_response([image_path], prompt)
311 |
312 | asst_head = asst_prompt if asst_prompt else ""
313 | return asst_head, asst_res, state
314 |
315 | def coa_action(self, step_data, *args, **kwargs):
316 | """ Chain of Action """
317 | user_request, image_path = step_data['instruction'], step_data['image_full_path']
318 | history_actions = step_data['history_actions']
319 |
320 | if history_actions: history_actions = ", ".join(history_actions)
321 | else: history_actions = "None"
322 | history_actions = history_actions + "."
323 |
324 | usr_prompt = self.prompts['ACTION_PREDICT']['COA']['USER']
325 | usr_prompt = usr_prompt.replace("{history_actions}", history_actions)
326 | usr_prompt = usr_prompt.replace("{user_request}", user_request)
327 | usr_prompt = [x for x in usr_prompt.split("{screenshot}")]
328 |
329 | return self.make_action(step_data, usr_prompt, kwargs['save_dir'])
330 |
331 | def cot_action(self, step_data, *args, **kwargs):
332 | """ Chain of Thought """
333 | user_request, image_path = step_data['instruction'], step_data['image_full_path']
334 |         action_think = kwargs['action_think']
335 |
336 | usr_prompt = self.prompts['ACTION_PREDICT']['COT']['USER']
337 | usr_prompt = usr_prompt.replace("{user_request}", user_request)
338 | usr_prompt = [x for x in usr_prompt.split("{screenshot}")]
339 |
340 | if self.cfg.MODEL.NAME == "qwenvl":
341 |             asst_prompt = json.dumps({"THINK": action_think}, indent=4)
342 | else:
343 | asst_prompt = self.prompts['ACTION_PREDICT']['COT']['ASST']
344 |             asst_prompt = asst_prompt.replace("{action_thought}", json.dumps(action_think))
345 |
346 | return self.make_action(step_data, usr_prompt, kwargs['save_dir'], asst_prompt=asst_prompt)
347 |
348 | def coat_action(self, step_data, *args, **kwargs):
349 | """ Chain of Action Thought """
350 | user_request, image_path = step_data['instruction'], step_data['image_full_path']
351 | history_actions = step_data['history_coat_actions']
352 |
353 | screen_desc = kwargs['screen_desc']
354 | prev_action_result = kwargs.get('prev_action_result', None)
355 |         action_think, next_action_desc = kwargs['action_think'], kwargs['action_desc']
356 |
357 | usr_prompt = self.prompts['ACTION_PREDICT']['COAT']['USER']
358 | usr_prompt = usr_prompt.split("\n")
359 | if not history_actions: usr_prompt = [x for x in usr_prompt if 'history_actions' not in x]
360 | if not prev_action_result: usr_prompt = [x for x in usr_prompt if 'prev_action_result' not in x]
361 | usr_prompt = "\n".join(usr_prompt)
362 |
363 | usr_prompt = usr_prompt.replace("{screen_desc}", screen_desc)
364 | if history_actions: history_actions = ", ".join(history_actions)
365 | else: history_actions = "None"
366 | history_actions = history_actions + "."
367 | usr_prompt = usr_prompt.replace("{history_actions}", history_actions)
368 | if prev_action_result:
369 | usr_prompt = usr_prompt.replace("{prev_action_result}", prev_action_result)
370 | usr_prompt = usr_prompt.replace("{user_request}", user_request)
371 | usr_prompt = [x for x in usr_prompt.split("{screenshot}")]
372 |
373 | if self.cfg.MODEL.NAME == "qwenvl":
374 |             asst_prompt = json.dumps({"THINK": action_think, "NEXT": next_action_desc}, indent=4)
375 | else:
376 | asst_prompt = self.prompts['ACTION_PREDICT']['COAT']['ASST']
377 |             asst_prompt = asst_prompt.replace("{action_thought}", json.dumps(action_think))
378 | asst_prompt = asst_prompt.replace("{next_single_action}", json.dumps(next_action_desc))
379 |
380 | return self.make_action(step_data, usr_prompt, kwargs['save_dir'], asst_prompt=asst_prompt)
381 |
382 | def predict(self, step_data, save_dir, max_trials=5):
383 | subset, episode_id = step_data['subset'], step_data['episode_id']
384 | prev_step_id, step_id = step_data['prev_step_id'], step_data['step_id']
385 |
386 | save_dir = os.path.join(save_dir, f"{subset}-{episode_id}")
387 |
388 | cur_save_path = os.path.join(save_dir, f"{subset}-{episode_id}_{step_id}.json")
389 | cur_anno = json.load(open(cur_save_path, "r"))
390 | if prev_step_id:
391 | prev_save_path = os.path.join(save_dir, f"{subset}-{episode_id}_{prev_step_id}.json")
392 | prev_anno = json.load(open(prev_save_path, "r"))
393 | cur_anno['prev_action_result'] = prev_anno['action_result']
394 |
395 | action_mode = self.cfg.DEMO_MODE.upper()
396 | if action_mode == "COA": func_handler = self.coa_action
397 | elif action_mode == "COT": func_handler = self.cot_action
398 | elif action_mode == "COAT": func_handler = self.coat_action
399 | else: raise NotImplementedError
400 |
401 | if 'action_predict' not in cur_anno: cur_anno['action_predict'] = {}
402 | if action_mode not in cur_anno['action_predict']: cur_anno['action_predict'][action_mode] = {}
403 | # if action_mode in cur_anno['action_predict']:
404 | # anno = cur_anno['action_predict'][action_mode]
405 | # cur_anno['action_predict'][action_mode] = {}
406 | # cur_anno['action_predict'][action_mode][self.screen_mode] = anno
407 | # json.dump(cur_anno, open(cur_save_path, "w", encoding="utf-8"), indent=4)
408 |
409 | if self.screen_mode not in cur_anno['action_predict'][action_mode]:
410 |             print(f"[Mode {action_mode}][{self.screen_mode}] Generating action ... ({cur_save_path})")
411 | try_num = 0
412 | while try_num < max_trials:
413 | asst_head, asst_res, state = func_handler(step_data, **cur_anno, save_dir=save_dir)
414 | if state == HTTPStatus.OK:
415 | try:
416 | try: response = json_parser(asst_res)
417 | except json.decoder.JSONDecodeError as e:
418 | try: response = json_parser(asst_head + asst_res)
419 | except:
420 | try:response = json_parser(asst_res + "}")
421 | except: response = json_parser(asst_head + asst_res + "}")
422 | cur_anno['action_predict'][action_mode][self.screen_mode] = {
423 | "ACTION": response["ACTION"], "ARGS": response["ARGS"]}
424 |
425 | print(f"Updating [{action_mode}][{self.screen_mode}] {cur_save_path} ...")
426 | json.dump(cur_anno, open(cur_save_path, "w", encoding="utf-8"), indent=4)
427 | except json.decoder.JSONDecodeError as e:
428 | print('-'*50 + "\n", asst_head, sep=" ")
429 | print('-'*50 + "\n", asst_res, sep=" ")
430 | break
431 | elif state == HTTPStatus.REQUEST_ENTITY_TOO_LARGE:
432 |                     print(f"Trial {try_num} failed! -- Status {state}")
433 | break
434 |                 print(f"Trial {try_num} failed! -- Status {state}")
435 | try_num += 1
436 | time.sleep(10)
437 |
438 | return None
439 |
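A hedged end-to-end sketch (run_coat.py, the actual entry point, is not shown in this section, so the driver code below is an assumption): flow() first caches the screen description, action thought, and action result for each step, then predict() generates the next action under the configured DEMO_MODE. EasyDict and the episode_steps iterator are assumptions; a dict/attribute hybrid is needed because ScreenAgent reads the config both as cfg['PROMPTS'] and as cfg.DATA.USE_SCREEN_TAG.

```python
import yaml
from easydict import EasyDict                 # assumption: any dict/attribute hybrid would do
from coat.agent import ScreenAgent

with open("coat/config.yaml") as f:
    config = EasyDict(yaml.safe_load(f))

agent = ScreenAgent(config)
for step_data in episode_steps:               # episode_steps: assumed iterator over step records
    agent.flow(step_data, save_dir=config.OUTPUT_DIR)
for step_data in episode_steps:
    agent.predict(step_data, save_dir=config.OUTPUT_DIR)
```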
--------------------------------------------------------------------------------
/coat/config.yaml:
--------------------------------------------------------------------------------
1 | DEMO_NAME: "CoATAgent"
2 | DEMO_MODE: "COA"
3 | OUTPUT_DIR: "android-in-the-zoo/api"
4 |
5 | DATA:
6 | DATA_DIR: "android-in-the-zoo"
7 | SPLIT: "test"
8 | USE_SCREEN_TAG: True
9 | USE_SCREEN_TXT: False
10 | BBOX_FORMAT: "relative_int"
11 |
12 | MODEL:
13 | NAME: "command_line"
14 | OPENAI_API_URL: ""
15 | OPENAI_API_KEY: ""
16 | GEMINI_API_KEY: ""
17 | GEMINI_MODEL: gemini-pro-vision
18 | DASHSCOPE_API_KEY: ""
19 | DASHSCOPE_MODEL: qwen-vl-max
20 |
21 | PROMPTS:
22 | SCREEN_DESC:
23 | SYSTEM: |-
24 | You are a smart and helpful visual assistant that is well trained to describe smartphone screenshots.
25 | - You are provided with a screenshot of the current mobile phone.
26 |       - You are required to describe this screen, covering its main content and functionality. The output must be fewer than five sentences.
27 |       - You are required to keep the description as concise as possible.
28 | USER: |-
29 | : {screenshot}
30 | :
31 |
32 | ACTION_THINK_DESC:
33 | SYSTEM: |-
34 | You are a smart and helpful visual assistant that is well trained to manipulate mobile phones.
35 | Your task is to navigate and take action on the current screen step-by-step to complete the user request.
36 | - You are provided with a screenshot of the current mobile phone, together with the textual screen description.
37 |       - You are provided with your history actions to decide on your next action. You can backtrack to revise the previous actions when necessary.
38 | - You are required to analyze the task status and detail a reasonable future action plan to accomplish the user request.
39 | - You are required to select the next single-step action based on your analysis and action plan.
40 |
41 | ## Analysis Guidelines
42 |       - You should check whether the history actions have accomplished the user request.
43 | - You should check the apps, icons, and buttons that are visible on the current screen and might pertain to the user request.
44 | - You should combine the above information and describe your future action plan. But DO NOT include any additional actions beyond the completion of the user request.
45 |
46 | ## Output Format
47 |       - You are required to respond in JSON format, consisting of 3 distinct parts with the following keys and corresponding content:
48 | {
49 | "Thought": ,
50 | "Future Action Plan": ,
51 | "Next Single Step Action":
52 | }
53 |       - You cannot output anything except the above JSON.
54 |
55 | ## Output Example
56 | {
57 | "Thought": "...",
58 | "Future Action Plan": ["...", "..."],
59 | "Next Single Step Action": "..."
60 | }
61 |
62 | USER: |-
63 | : {screenshot}
64 | : {screen_desc}
65 | : {history_actions}
66 | : {user_request}
67 | :
68 |
69 | ACTION_PREDICT:
70 | COA:
71 | SYSTEM:
72 | SCREEN_TXT: |-
73 | You are a smart and helpful visual assistant that is well trained to manipulate mobile phones.
74 | Your task is to navigate on the current screen to complete the user request.
75 | - You are provided with a screenshot of the current mobile phone.
76 | - You are provided with a textual description of elements on the screen. Each element is assigned a category represented in uppercase letters. For "TEXT" element, the corresponding text on the screen is also included. Each element also has a bounding box, composed of four coordinates [xmin, ymin, xmax, ymax] ranging from 0 to 999, representing the proportion in width or height.
77 |           - You are provided with history actions trying to accomplish the user request.
78 | - You are required to decide on the next single-step valid action to conduct on the current screen.
79 |
80 | ## Valid Actions on the screen
81 | {action_space}
82 |
83 | ## Output Format
84 |           - You must choose one of the valid APIs provided above and respond in the corresponding API call format.
85 | - Your response should be strictly structured in JSON format, consisting of the following keys and corresponding content:
86 | {
87 | "ACTION": ,
88 | "ARGS":
89 | }
90 |
91 | ## Output Example 1:
92 | {
93 | "ACTION": "click_element",
94 | "ARGS": {"bbox": [100, 345, 219, 826]}
95 | }
96 |
97 | ## Output Example 2:
98 | {
99 | "ACTION": "scroll",
100 | "ARGS": {"direction": "down"}
101 | }
102 |
103 | ## Output Example 3:
104 | {
105 | "ACTION": "press_home",
106 | "ARGS": {}
107 | }
108 |
109 | SCREEN_TAG: |-
110 | You are a smart and helpful visual assistant that is well trained to manipulate mobile phones.
111 | Your task is to navigate on the current screen to complete the user request.
112 |           - You are provided with two screenshots of the current mobile phone, one without annotations (first) and one with UI elements annotated (second).
113 |           - You are provided with history actions trying to accomplish the user request.
114 | - You are required to decide on the next single-step valid action to conduct on the current screen.
115 |
116 | ## Valid Actions on the screen
117 | {action_space}
118 |
119 | ## Output Format
120 |           - You must choose one of the valid APIs provided above and respond in the corresponding API call format.
121 | - Your response should be strictly structured in JSON format, consisting of the following keys and corresponding content:
122 | {
123 | "ACTION": ,
124 | "ARGS":
125 | }
126 |
127 | ## Output Example 1:
128 | {
129 | "ACTION": "click_element",
130 | "ARGS": {"idx": 6}
131 | }
132 |
133 | ## Output Example 2:
134 | {
135 | "ACTION": "scroll",
136 | "ARGS": {"direction": "down"}
137 | }
138 |
139 | ## Output Example 3:
140 | {
141 | "ACTION": "press_home",
142 | "ARGS": {}
143 | }
144 |
145 | USER: |-
146 | : {screenshot}
147 | : {history_actions}
148 | : {user_request}
149 | :
150 |
151 | HINT:
152 | SYSTEM: |-
153 |
154 | USER: |-
155 | : {screenshot}
156 | : {next_single_action}
157 | : {user_request}
158 | :
159 |
160 | COT:
161 | SYSTEM:
162 | SCREEN_TXT: |-
163 | You are a smart and helpful visual assistant that is well trained to manipulate mobile phones.
164 | Your task is to navigate on the current screen to complete the user request.
165 | - You are provided with a screenshot of the current mobile phone.
166 | - You are provided with a textual description of elements on the screen. Each element is assigned a category represented in uppercase letters. For "TEXT" element, the corresponding text on the screen is also included. Each element also has a bounding box, composed of four coordinates [xmin, ymin, xmax, ymax] ranging from 0 to 999, representing the proportion in width or height.
167 | - You are required to decide on the next single-step valid action to conduct on the current screen.
168 |
169 | ## Valid Actions on the screen
170 | {action_space}
171 |
172 | ## Output Format
173 |           - You must choose one of the valid APIs provided above and respond in the corresponding API call format.
174 | - Your response should be strictly structured in JSON format, consisting of the following keys and corresponding content:
175 | {
176 | "THINK": ,
177 | "ACTION": ,
178 | "ARGS":
179 | }
180 |
181 | ## Output Example 1:
182 | {
183 | "THINK": "...",
184 | "ACTION": "click_element",
185 | "ARGS": {"bbox": [100, 345, 219, 826]}
186 | }
187 |
188 | ## Output Example 2:
189 | {
190 | "THINK": "...",
191 | "ACTION": "scroll",
192 | "ARGS": {"direction": "down"}
193 | }
194 |
195 | ## Output Example 3:
196 | {
197 | "THINK": "...",
198 | "ACTION": "press_home",
199 | "ARGS": {}
200 | }
201 |
202 | SCREEN_TAG: |-
203 | You are a smart and helpful visual assistant that is well trained to manipulate mobile phones.
204 | Your task is to navigate on the current screen to complete the user request.
205 |           - You are provided with two screenshots of the current mobile phone, one without annotations (first) and one with UI elements annotated (second).
206 | - You are required to decide on the next single-step valid action to conduct on the current screen.
207 |
208 | ## Valid Actions on the screen
209 | {action_space}
210 |
211 | ## Output Format
212 |           - You must choose one of the valid APIs provided above and respond in the corresponding API call format.
213 | - Your response should be strictly structured in JSON format, consisting of the following keys and corresponding content:
214 | {
215 | "THINK": ,
216 | "ACTION": ,
217 | "ARGS":
218 | }
219 |
220 | ## Output Example 1:
221 | {
222 | "THINK": "...",
223 | "ACTION": "click_element",
224 | "ARGS": {"idx": 6}
225 | }
226 |
227 | ## Output Example 2:
228 | {
229 | "THINK": "...",
230 | "ACTION": "scroll",
231 | "ARGS": {"direction": "down"}
232 | }
233 |
234 | ## Output Example 3:
235 | {
236 | "THINK": "...",
237 | "ACTION": "press_home",
238 | "ARGS": {}
239 | }
240 |
241 | USER: |-
242 | : {screenshot}
243 | : {user_request}
244 | :
245 |
246 | ASST: |-
247 | {
248 | "THINK": {action_thought},
249 | "ACTION":
250 |
251 | COAT:
252 | SYSTEM:
253 | SCREEN_TXT: |-
254 | You are a smart and helpful visual assistant that is well trained to manipulate mobile phones.
255 | Your task is to navigate on the current screen to complete the user request.
256 | - You are provided with a screenshot of the current mobile phone.
257 | - You are provided with a textual description of elements on the screen. Each element is assigned a category represented in uppercase letters. For "TEXT" element, the corresponding text on the screen is also included. Each element also has a bounding box, composed of four coordinates [xmin, ymin, xmax, ymax] ranging from 0 to 999, representing the proportion in width or height.
258 |           - You are provided with a brief summary of the screen content.
259 |           - You are provided with history actions trying to accomplish the user request, together with the previous action result that indicates how the current screenshot was obtained.
260 | - You are required to decide on the next single-step valid action to be conducted on the current screen so as to fulfill the user request.
261 |
262 | ## Valid Actions on the screen
263 | {action_space}
264 |
265 | ## Output Format
266 |           - You must choose one of the valid APIs provided above and respond in the corresponding API call format.
267 | - Your response should be strictly structured in JSON format, consisting of the following keys and corresponding content:
268 | {
269 | "THINK": ,
270 | "NEXT": ,
271 | "ACTION": ,
272 | "ARGS":
273 | }
274 |
275 | ## Output Example 1:
276 | {
277 | "THINK": "...",
278 | "NEXT": "...",
279 | "ACTION": "click_element",
280 | "ARGS": {"bbox": [100, 345, 219, 826]}
281 | }
282 |
283 | ## Output Example 2:
284 | {
285 | "THINK": "...",
286 | "NEXT": "...",
287 | "ACTION": "scroll",
288 | "ARGS": {"direction": "down"}
289 | }
290 |
291 | ## Output Example 3:
292 | {
293 | "THINK": "...",
294 | "NEXT": "...",
295 | "ACTION": "press_home",
296 | "ARGS": {}
297 | }
298 |
299 | SCREEN_TAG: |-
300 | You are a smart and helpful visual assistant that is well trained to manipulate mobile phones.
301 | Your task is to navigate on the current screen to complete the user request.
302 |           - You are provided with two screenshots of the current mobile phone, one without annotations (first) and one with UI elements annotated (second).
303 |           - You are provided with a brief summary of the screen content.
304 |           - You are provided with history actions trying to accomplish the user request, together with the previous action result that indicates how the current screenshot was obtained.
305 | - You are required to decide on the next single-step valid action to be conducted on the current screen so as to fulfill the user request.
306 |
307 | ## Valid Actions on the screen
308 | {action_space}
309 |
310 | ## Output Format
311 |           - You must choose one of the valid APIs provided above and respond in the corresponding API call format.
312 | - Your response should be strictly structured in JSON format, consisting of the following keys and corresponding content:
313 | {
314 | "THINK": ,
315 | "NEXT": ,
316 | "ACTION": ,
317 | "ARGS":
318 | }
319 |
320 | ## Output Example 1:
321 | {
322 | "THINK": "...",
323 | "NEXT": "...",
324 | "ACTION": "click_element",
325 | "ARGS": {"idx": 6}
326 | }
327 |
328 | ## Output Example 2:
329 | {
330 | "THINK": "...",
331 | "NEXT": "...",
332 | "ACTION": "scroll",
333 | "ARGS": {"direction": "down"}
334 | }
335 |
336 | ## Output Example 3:
337 | {
338 | "THINK": "...",
339 | "NEXT": "...",
340 | "ACTION": "press_home",
341 | "ARGS": {}
342 | }
343 |
344 | USER: |-
345 | : {screenshot}
346 | : {screen_desc}
347 | : {history_actions}
348 | : {prev_action_result}
349 | : {user_request}
350 | :
351 |
352 | ASST: |-
353 | {
354 | "THINK": {action_thought},
355 | "NEXT": {next_single_action},
356 | "ACTION":
357 |
358 | ACTION_SPACE:
359 |
360 | SCREEN_TXT: |-
361 | CLICK_ELEMENT:
362 | summary: click on the visible element on the screen
363 | usage:
364 | [1] API call: click_element(bbox=)
365 | [2] Args:
366 | - bbox: 'The bounding box of the element to be clicked, formed as [x1, y1, x2, y2].'
367 | [3] Example: click_element(bbox=[10, 20, 30, 40])
368 | [4] Return: None
369 |
370 | SCROLL:
371 | summary: move the scrollable content, or open the app drawer
372 | explanation:
373 | - Scrolling down typically means moving scrollable content to see what's further down. If content is moving up, you are scrolling down.
374 | - Scrolling up could either open the app drawer, or move back to view the previous content on the same screen.
375 | usage:
376 | [1] API call: scroll(direction=)
377 | [2] Args:
378 | - direction: 'Scroll direction. One of ''up'', ''down'', ''left'', ''right''.'
379 | [3] Example: scroll(direction="up")
380 | [4] Return: None
381 |
382 | INPUT:
383 | summary: input text to an editable input area
384 | usage:
385 | [1] API call: input(text="")
386 | [2] Args:
387 | - text: 'The text input to the editable input area.'
388 | [3] Example: input(text="Hello World")
389 | [4] Return: None
390 |
391 | PRESS_ENTER:
392 | summary: confirm the input, or submit the input, or start a new line of text
393 | usage:
394 | [1] API call: press_enter()
395 | [2] Args: None
396 | [3] Example: press_enter()
397 | [4] Return: None
398 |
399 | PRESS_HOME:
400 | summary: directly move back to the home screen
401 | usage:
402 | [1] API call: press_home()
403 | [2] Args: None
404 | [3] Example: press_home()
405 | [4] Return: None
406 |
407 | PRESS_BACK:
408 | summary: return to the most recently visited screen or interface
409 | usage:
410 | [1] API call: press_back()
411 | [2] Args: None
412 | [3] Example: press_back()
413 | [4] Return: None
414 |
415 | STOP:
416 | summary: stop and set the state of the task
417 | usage:
418 | [1] API call: stop(task_status=)
419 | [2] Args:
420 | - task_status: 'Decide on whether the task is a success or failure. Choose one of ''success'', ''failure''.'
421 | [3] Example: stop(task_status="success")
422 | [4] Return: None
423 |
424 | SCREEN_TAG: |-
425 | CLICK_ELEMENT:
426 | summary: click on the visible ui element on the screen
427 | usage:
428 | [1] API call: click_element(idx=)
429 | [2] Args:
430 |               - idx: 'The index of the element to be clicked, as shown on the screen.'
431 | [3] Example: click_element(idx=10)
432 | [4] Return: None
433 |
434 | SCROLL:
435 | summary: move the scrollable content, or open the app drawer
436 | explanation:
437 | - Scrolling down typically means moving scrollable content to see what's further down. If content is moving up, you are scrolling down.
438 | - Scrolling up could either open the app drawer, or move back to view the previous content on the same screen.
439 | usage:
440 | [1] API call: scroll(direction=)
441 | [2] Args:
442 | - direction: 'Scroll direction. One of ''up'', ''down'', ''left'', ''right''.'
443 | [3] Example: scroll(direction="up")
444 | [4] Return: None
445 |
446 | INPUT:
447 | summary: input text to an editable input area
448 | usage:
449 | [1] API call: input(text="")
450 | [2] Args:
451 | - text: 'The text input to the editable input area.'
452 | [3] Example: input(text="Hello World")
453 | [4] Return: None
454 |
455 | PRESS_ENTER:
456 | summary: confirm the input, or submit the input, or start a new line of text
457 | usage:
458 | [1] API call: press_enter()
459 | [2] Args: None
460 | [3] Example: press_enter()
461 | [4] Return: None
462 |
463 | PRESS_HOME:
464 | summary: directly move back to the home screen
465 | usage:
466 | [1] API call: press_home()
467 | [2] Args: None
468 | [3] Example: press_home()
469 | [4] Return: None
470 |
471 | PRESS_BACK:
472 | summary: return to the most recently visited screen or interface
473 | usage:
474 | [1] API call: press_back()
475 | [2] Args: None
476 | [3] Example: press_back()
477 | [4] Return: None
478 |
479 | STOP:
480 | summary: stop and set the state of the task
481 | usage:
482 | [1] API call: stop(task_status=)
483 | [2] Args:
484 | - task_status: 'Decide on whether the task is a success or failure. Choose one of ''success'', ''failure''.'
485 | [3] Example: stop(task_status="success")
486 | [4] Return: None
487 |
488 | ACTION_RESULT:
489 | SYSTEM: |-
490 | You are a smart and helpful visual assistant that is well trained to manipulate mobile phones.
491 |       Your task is to explain the result of the last action, and its influence on the completion of the user request.
492 |       - You are provided with two screenshots of the mobile phone, one from the last step and one from the current screen.
493 | - You are provided with the details of the last action you performed on the last screenshot.
494 | - You are required to describe the direct consequences of the last action by comparing the two screenshots.
495 |       - You are required to judge whether this action has made progress towards the user request.
496 |
497 |       ## Example 1: "By clicking on the search bar, now I can enter the query into the input area. This helps to complete the user request, because one has to use a search engine to find the cheapest chair in Walmart."
498 |       ## Example 2: "By pressing the home button, I have moved back to the home screen. This is necessary because, to install the app required by the user, we must find the App Store first."
499 |
500 | USER: |-
501 | : {before_screenshot}
502 | : {after_screenshot}
503 | : {last_action}
504 | : {user_request}
505 | :
506 |
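A hedged sketch of how these templates are consumed (it mirrors BaseAgent.observe in coat/agent.py above): the USER block is split on the {screenshot} placeholder and the pieces are interleaved with an image index, producing the (role, message) list that the model classes' format_query methods expect.

```python
import yaml

with open("coat/config.yaml") as f:
    prompts = yaml.safe_load(f)["PROMPTS"]

usr = prompts["SCREEN_DESC"]["USER"]                     # the template defined above
parts = [x.strip() for x in usr.split("{screenshot}")]

prompt = [("system", prompts["SCREEN_DESC"]["SYSTEM"]),
          ("user", {"text": parts[0], "img_index": 0})]
prompt += [("user", {"text": t}) for t in parts[1:]]
```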
--------------------------------------------------------------------------------