├── coat ├── __init__.py ├── .DS_Store ├── utils.py ├── config.py ├── screen_utils.py ├── evaluate.py ├── action_utils.py ├── model.py ├── agent.py └── config.yaml ├── assets ├── cmp.png └── intro-total.png ├── data-example └── GOOGLE_APPS-523638528775825151 │ ├── GOOGLE_APPS-523638528775825151_0.png │ ├── GOOGLE_APPS-523638528775825151_1.png │ ├── GOOGLE_APPS-523638528775825151_2.png │ ├── GOOGLE_APPS-523638528775825151_3.png │ └── GOOGLE_APPS-523638528775825151.json ├── README.md └── run_coat.py /coat/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/cmp.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMNearth/CoAT/HEAD/assets/cmp.png -------------------------------------------------------------------------------- /coat/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMNearth/CoAT/HEAD/coat/.DS_Store -------------------------------------------------------------------------------- /assets/intro-total.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMNearth/CoAT/HEAD/assets/intro-total.png -------------------------------------------------------------------------------- /data-example/GOOGLE_APPS-523638528775825151/GOOGLE_APPS-523638528775825151_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMNearth/CoAT/HEAD/data-example/GOOGLE_APPS-523638528775825151/GOOGLE_APPS-523638528775825151_0.png -------------------------------------------------------------------------------- /data-example/GOOGLE_APPS-523638528775825151/GOOGLE_APPS-523638528775825151_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMNearth/CoAT/HEAD/data-example/GOOGLE_APPS-523638528775825151/GOOGLE_APPS-523638528775825151_1.png -------------------------------------------------------------------------------- /data-example/GOOGLE_APPS-523638528775825151/GOOGLE_APPS-523638528775825151_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMNearth/CoAT/HEAD/data-example/GOOGLE_APPS-523638528775825151/GOOGLE_APPS-523638528775825151_2.png -------------------------------------------------------------------------------- /data-example/GOOGLE_APPS-523638528775825151/GOOGLE_APPS-523638528775825151_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IMNearth/CoAT/HEAD/data-example/GOOGLE_APPS-523638528775825151/GOOGLE_APPS-523638528775825151_3.png -------------------------------------------------------------------------------- /coat/utils.py: -------------------------------------------------------------------------------- 1 | from colorama import init, Fore, Style 2 | 3 | 4 | def print_with_color(text, color): 5 | colors = { 6 | 'red': Fore.RED, 7 | 'green': Fore.GREEN, 8 | 'yellow': Fore.YELLOW, 9 | 'blue': Fore.BLUE, 10 | 'magenta': Fore.MAGENTA, 11 | 'cyan': Fore.CYAN, 12 | 'white': Fore.WHITE, 13 | 'reset': Fore.RESET 14 | } 15 | 16 | color_code = colors.get(color.lower(), Fore.RESET) 17 | print(color_code + text) 18 | 19 | -------------------------------------------------------------------------------- 
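A minimal usage sketch for the `print_with_color` helper above (assuming the repository root is on `PYTHONPATH` so that `coat.utils` is importable; note the module imports `colorama.init` but never calls it, so on Windows you may want to call `init()` yourself):

```python
from colorama import init
from coat.utils import print_with_color

init()  # optional on Linux/macOS, needed on Windows for ANSI colors
print_with_color("Running the CoAT demo ...", "cyan")            # colored output
print_with_color("unknown color names fall back to default", "pink")  # falls back to Fore.RESET
```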
/coat/config.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | _C = CN() 4 | _C.DEMO_NAME = "" # name of the current experiment 5 | _C.DEMO_MODE = "" # demo mode 6 | _C.OUTPUT_DIR = "" # directory for saving demo outputs 7 | 8 | 9 | _C.DATA = CN() # ----- dataset-related configs ----- 10 | _C.DATA.DATA_DIR = "" # path to the dataset 11 | _C.DATA.SPLIT = "" # dataset split used for testing 12 | 13 | 14 | _C.MODEL = CN() # ----- model-API-related configs ----- 15 | _C.MODEL.NAME = "" 16 | _C.MODEL.OPENAI_API_KEY = "" 17 | _C.MODEL.OPENAI_API_URL = "" 18 | _C.MODEL.GEMINI_API_KEY = "" 19 | _C.MODEL.GEMINI_MODEL = "gemini-pro-vision" 20 | _C.MODEL.DASHSCOPE_API_KEY = "" 21 | _C.MODEL.DASHSCOPE_MODEL = "qwen-vl-max" 22 | 23 | 24 | _C.PROMPTS = CN() # ----- prompt-related configs ----- 25 | 26 | _C.PROMPTS.SCREEN_DESC = CN() 27 | _C.PROMPTS.SCREEN_DESC.SYSTEM = "" 28 | _C.PROMPTS.SCREEN_DESC.USER = "" 29 | 30 | _C.PROMPTS.ACTION_THINK = CN() 31 | _C.PROMPTS.ACTION_THINK.SYSTEM = "" 32 | _C.PROMPTS.ACTION_THINK.USER = "" 33 | 34 | _C.PROMPTS.ACTION_DESC = CN() 35 | _C.PROMPTS.ACTION_DESC.SYSTEM = "" 36 | _C.PROMPTS.ACTION_DESC.USER = "" 37 | 38 | _C.PROMPTS.ACTION_PREDICT = CN() 39 | _C.PROMPTS.ACTION_PREDICT.SYSTEM = "" 40 | _C.PROMPTS.ACTION_PREDICT.USER = "" 41 | 42 | _C.PROMPTS.ACTION_RESULT = CN() 43 | _C.PROMPTS.ACTION_RESULT.SYSTEM = "" 44 | _C.PROMPTS.ACTION_RESULT.USER = "" 45 | 46 | 47 | def get_cfg_defaults(): 48 | """Get a yacs CfgNode object with default values for the CoAT demo.""" 49 | # Return a clone so that the defaults will not be altered 50 | # This is for the "local variable" use pattern 51 | return _C.clone() 52 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 |

Android in the Zoo:
Chain-of-Action-Thought for GUI Agents

3 |
4 |

Jiwen Zhang1,2, Jihao Wu2, Yihua Teng2, Minghui Liao2, Nuo Xu2, Xiao Xiao2, Zhongyu Wei1, Duyu Tang2. 5 |

6 |

1Fudan University 2Huawei Inc.

7 |

8 | 9 | 10 | 11 | 12 |

13 |

14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 |

22 | 23 | -------------- 24 | 25 | This work presents **Chain-of-Action-Thought** (dubbed **CoAT**), which takes as context the description of the previous actions, the current screen, and, more importantly, the reasoning about which actions should be performed and the outcomes that the chosen action leads to. To enable adaptive learning of the CoAT process, we construct the benchmark **Android-In-The-Zoo**, which contains 18,643 screen-action pairs together with CoAT annotations. 26 | 27 |
28 | 29 |
30 | 31 | 32 | ## 📣 Update 33 | 34 | - **[2024-10-15]** Evaluation code has been released! 35 | 36 | - **[2024-09-20]** Our work has been accepted to EMNLP 2024 Findings! 37 | 38 | - **[2024-07-16]** We have added demo code for using CoAT with proprietary models (GPT-4V, Gemini-Pro and Qwen-VL-Max)! 39 | 40 | - **[2024-03-31]** We have released the first version of our AiTZ dataset! 41 | 42 | - **[2024-03-05]** Our paper is now on arXiv; you can access it by clicking [here](https://arxiv.org/abs/2403.02713)! 43 | 44 | 45 | 46 | ## Android-in-the-Zoo 47 | 48 | AiTZ contains 18,643 screens together with 2,500+ instructions, all annotated with CoAT-driven semantic labels. The sample format for each time step is 49 | 50 | ```json 51 | { 52 | "episode_id": "523638528775825151", 53 | "episode_length": 4, 54 | "step_id": 0, 55 | "coat_screen_desc": "[observe]", 56 | "coat_action_think": "[action think]", 57 | "coat_action_desc": "[next action description]", 58 | "coat_action_result": "[action result]", 59 | ... 60 | } 61 | ``` 62 | 63 | You can refer to the `data-example` folder for a more specific example. 64 | 65 | 66 | ### Download 67 | 68 | Our dataset ([GoogleDrive](https://drive.google.com/file/d/12xOV2m62fBUFLhMcWIsFiC6zwV7a2RhI/view?usp=sharing) or [BaiduNetdisk](https://pan.baidu.com/s/1dHG-4L0RE1aYINzMSA4dCw?pwd=7g82)) contains both the screens (.png) and the annotations (.json), taking about 2.6 GB of disk space. 69 | 70 | 71 | ### Statistics 72 | 73 | | Subset | Train | | Test | | 74 | | ----------- | ---------- | --------- | ---------- | --------- | 75 | | | \#Episodes | \#Screens | \#Episodes | \#Screens | 76 | | General | 323 | 2405 | 156 | 1202 | 77 | | Install | 286 | 2519 | 134 | 1108 | 78 | | GoogleApps | 166 | 1268 | 76 | 621 | 79 | | Single | 844 | 2594 | 0 | 0 | 80 | | WebShopping | 379 | 5133 | 140 | 1793 | 81 | | **Total** | **1998** | **13919** | **506** | **4724** | 82 | 83 | 84 | 85 | ## Chain-of-Action-Thought 86 | 87 | ### Comparison with other context modeling methods 88 | 89 | We validate the effectiveness of CoAT by conducting a preliminary experiment on 50 episodes randomly sampled from the AITW dataset. 90 | 91 | The compared baselines are [Chain-of-Thought](https://arxiv.org/abs/2201.11903) (CoT) and [Chain-of-Actions](https://arxiv.org/abs/2309.11436) (CoA). 92 | 93 | | Prompt | Metric | QwenVL | Gemini-PV | GPT-4V | 94 | | ------ | ------ | ------ | --------- | ------ | 95 | | CoA | hit | 94.5 | 99.8 | 99.3 | 96 | | | acc | 44.4 | 47.7 | 62.8 | 97 | | CoT | hit | 95.6 | 97.5 | 97.1 | 98 | | | acc | 49.4 | 52.0 | 64.1 | 99 | | CoAT | hit | 96.3 | 96.4 | 98.2 | 100 | | | acc | 52.4 | 54.5 | 73.5 | 101 | 102 | where “hit” means format hit rate, and “acc” means action type prediction accuracy. (One can refer to Table 8 in our paper for more details.) 103 | 104 | 105 | 106 | 107 | ### CoAT demo usage 108 | 109 | Here we provide demo code for anyone who wants to try CoAT on GPT-4V, Qwen-VL-Max and Gemini-1.0-Pro-Vision. 110 | 111 | First, go to `coat/config.yaml` and add your own API keys and URLs.
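The exact `config.yaml` shipped with the repository is not reproduced here, but `coat/config.py` and `run_coat.py` show which fields the demo reads. The following is a minimal sketch for sanity-checking a filled-in config before launching the demo; the field names are taken from `coat/config.py`, and the loading call mirrors `run_coat.py`.

```python
import yaml
from yacs.config import CfgNode as CN

# Load the YAML exactly the way run_coat.py does.
cfg = CN(yaml.safe_load(open("coat/config.yaml")))

# MODEL.NAME selects the backend and therefore which credentials are needed.
assert cfg.MODEL.NAME in ("openai", "gemini", "qwenvl")
if cfg.MODEL.NAME == "openai":
    assert cfg.MODEL.OPENAI_API_KEY and cfg.MODEL.OPENAI_API_URL
elif cfg.MODEL.NAME == "gemini":
    assert cfg.MODEL.GEMINI_API_KEY     # GEMINI_MODEL defaults to "gemini-pro-vision" in coat/config.py
else:
    assert cfg.MODEL.DASHSCOPE_API_KEY  # DASHSCOPE_MODEL defaults to "qwen-vl-max" in coat/config.py

print("config OK:", cfg.DEMO_MODE, cfg.OUTPUT_DIR, cfg.DATA.DATA_DIR, cfg.DATA.SPLIT)
```

Note that `run_coat.py` loads `config.yaml` directly rather than merging it into the defaults from `get_cfg_defaults()`, so any key the demo accesses must actually be present in the YAML; a missing key surfaces as an `AttributeError` from yacs.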
112 | 113 | Second, run the following command to generate the semantic components of the CoAT framework: 114 | 115 | ```shell 116 | python run_coat.py --task "flow" --DEMO_MODE "COAT" --MODEL.NAME "openai/gemini/qwenvl" --num-threads 3 117 | ``` 118 | 119 | Then, you can obtain the action prediction results by running: 120 | 121 | ```shell 122 | python run_coat.py --task "predict" --DEMO_MODE "COAT" --MODEL.NAME "openai/gemini/qwenvl" --num-threads 3 123 | ``` 124 | 125 | 126 | 127 | 128 | 129 | ## Citation 130 | 131 | If you find our work helpful, please consider citing our paper. 132 | 133 | ``` 134 | @misc{zhang2024android, 135 | title={Android in the Zoo: Chain-of-Action-Thought for GUI Agents}, 136 | author={Jiwen Zhang and Jihao Wu and Yihua Teng and Minghui Liao and Nuo Xu and Xiao Xiao and Zhongyu Wei and Duyu Tang}, 137 | year={2024}, 138 | eprint={2403.02713}, 139 | archivePrefix={arXiv}, 140 | primaryClass={cs.CL} 141 | } 142 | ``` 143 | 144 | -------------------------------------------------------------------------------- /run_coat.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import yaml 4 | import time 5 | import random 6 | import pandas as pd 7 | from collections import defaultdict 8 | from yacs.config import CfgNode as CN 9 | from tqdm import tqdm 10 | 11 | from coat.model import get_model 12 | from coat.action_utils import extract_gt_action 13 | from coat.agent import ScreenAgent 14 | import threading 15 | 16 | 17 | class AITZDataset(object): 18 | """ 19 | Creating a custom dataset for reading the processed AITW data 20 | with GPT-4V-labeled detailed semantic annotations. 21 | """ 22 | DATASET_DIR = { 23 | 'general': '{}/general', 24 | 'google_apps': '{}/google_apps', 25 | 'install': '{}/install', 26 | 'single': '{}/single', 27 | 'web_shopping': '{}/web_shopping', 28 | } 29 | 30 | def __init__(self, split="test", data_dir="android-in-the-zoo", ratio=1.0, double_sample=False) -> None: 31 | self.ratio = ratio 32 | self.double_sample = double_sample 33 | self.data_dir = os.path.join(data_dir, split) 34 | 35 | self.episode_data = self._load_data_() 36 | self.data = self._split_to_steps_(self.episode_data) 37 | 38 | def _load_data_(self): 39 | valid_paths = defaultdict(list) 40 | for subset in self.DATASET_DIR: 41 | subdata_dir = self.DATASET_DIR[subset].format(self.data_dir) 42 | if os.path.exists(subdata_dir): 43 | sequence_names = os.listdir(subdata_dir) 44 | for seq_name in sequence_names: 45 | seq_dir = os.path.join(subdata_dir, seq_name) 46 | if not os.path.isdir(seq_dir): continue 47 | episode_path = os.path.join(seq_dir, f"{seq_name}.json") 48 | valid_paths[subset].append(episode_path) 49 | 50 | sampled_paths = [] 51 | for subset, v_paths in valid_paths.items(): 52 | N = len(v_paths) 53 | k = int(self.ratio * N) 54 | # random.sample(v_paths, k) if self.ratio < 1.0 else v_paths 55 | sampled_paths += random.sample(v_paths, k) if self.ratio < 1.0 else v_paths 56 | 57 | ep_data = [] 58 | for episode_path in sampled_paths: 59 | episode_data = json.load(open(episode_path, "r")) 60 | ep_data.append(episode_data) 61 | return ep_data 62 | 63 | def _split_to_steps_(self, episode_data): 64 | data = [] 65 | for edx, episode in enumerate(episode_data): 66 | history_plain_actions, history_coat_actions = [], [] 67 | for idx, step in enumerate(episode): 68 | step['subset'] = step['image_path'].split('/')[0] 69 | step['image_full_path'] = os.path.join(self.data_dir, step['image_path']) 70 | step['prev_step_id'] =
episode[idx-1]['step_id'] if idx > 0 else None 71 | next_img_path = os.path.join(self.data_dir, episode[idx+1]['image_path']) \ 72 | if idx + 1 < len(episode) else None 73 | step['next_image_full_path'] = next_img_path 74 | step['history_actions'] = history_plain_actions[:] 75 | step['history_coat_actions'] = history_coat_actions[:] 76 | step['result_action'] = extract_gt_action(step)[1] 77 | for ui_key in ['ui_positions', 'ui_text', 'ui_types']: 78 | step[ui_key] = json.loads(step[ui_key]) 79 | data.append(step) 80 | history_plain_actions.append(step['result_action']) 81 | history_coat_actions.append(step['coat_action_desc']) 82 | 83 | return data 84 | 85 | def __len__(self, ): return len(self.data) 86 | 87 | def __getitem__(self, index): return self.data[index] 88 | 89 | # ========================================================================================== 90 | 91 | def single_run_flow(agent:ScreenAgent, step_data:dict, save_dir:str): 92 | agent.flow(step_data, save_dir=save_dir) 93 | 94 | 95 | def single_run_predict(agent:ScreenAgent, step_data:dict, save_dir:str): 96 | agent.predict(step_data, save_dir=save_dir) 97 | 98 | 99 | def collect(cfg, num_threads=2, task="flow"): 100 | # cfg.MODEL.NAME = "openai" 101 | if task == "flow": todo_task = single_run_flow 102 | if task == "predict": todo_task = single_run_predict 103 | 104 | if cfg.MODEL.NAME in ["gemini", "openai"]: 105 | print("using proxy ... ") 106 | os.environ['http_proxy'] = "your_proxy" 107 | os.environ['https_proxy'] = "your_proxy" 108 | 109 | aitz = AITZDataset(cfg.DATA.SPLIT, cfg.DATA.DATA_DIR, ratio=0.1, double_sample=False) 110 | print(len(aitz), len(aitz.episode_data)) 111 | 112 | save_dir = os.path.join(cfg.OUTPUT_DIR, cfg.MODEL.NAME) 113 | agent = ScreenAgent(config=cfg) 114 | 115 | before_count = threading.active_count() + 1 116 | threads = [] 117 | last_time = time.time() 118 | pbar = tqdm(aitz.data) 119 | for idx, step_data in enumerate(pbar): 120 | thread = threading.Thread(target=todo_task, args=(agent, step_data, save_dir)) 121 | time.sleep(max(0.01-(time.time()-last_time), 0)) 122 | thread.start() 123 | last_time = time.time() 124 | threads.append(thread) 125 | pbar.set_description(f"Active threads [{threading.active_count()-before_count}]") 126 | while threading.active_count() - before_count >= num_threads: time.sleep(1) 127 | 128 | if len(threads) == 100: 129 | for thr in threads: thr.join() 130 | threads = [] 131 | 132 | print() 133 | while threading.active_count() - before_count > 0: time.sleep(1) 134 | for thr in threads: thr.join() 135 | 136 | # ========================================================================================== 137 | 138 | def try_model(cfg): 139 | cfg.MODEL.NAME = "qwenvl" 140 | 141 | if cfg.MODEL.NAME in ["gemini", "openai"]: 142 | os.environ['http_proxy'] = "your_proxy" 143 | os.environ['https_proxy'] = "your_proxy" 144 | 145 | prompt = "Describe the image."
146 | image_path = "android-in-the-wild/aitw_with_gpt/test/general/GENERAL-1359994677477286277/GENERAL-1359994677477286277_2.png" 147 | model = get_model(cfg.MODEL, seed=2024) 148 | res_json, res_state = model.get_response(image_path, prompt) 149 | print(res_json, res_state) 150 | pass 151 | 152 | # ========================================================================================== 153 | 154 | if __name__ == "__main__": 155 | import argparse 156 | 157 | parser = argparse.ArgumentParser(description="CoAT Demo") 158 | parser.add_argument("--config-file", default="coat/config.yaml", metavar="FILE", help="path to config file",) 159 | parser.add_argument("--task", default="eval", type=str, choices=['try', 'flow', 'predict', 'eval']) 160 | parser.add_argument("--num-threads", default=1, type=int, help="number of threads") 161 | parser.add_argument("--seed", default=2020, type=int, help="random seed") 162 | parser.add_argument("opts", help="Modify config options using the command-line", 163 | default=None, nargs=argparse.REMAINDER, ) 164 | args = parser.parse_args() 165 | print(args) 166 | random.seed(args.seed) 167 | 168 | cfg = CN(yaml.safe_load(open(args.config_file))) 169 | cfg.merge_from_list(args.opts) 170 | 171 | if args.task == "try": 172 | try_model(cfg) 173 | if args.task in ['flow', 'predict']: 174 | collect(cfg, num_threads=args.num_threads, task=args.task) 175 | 176 | -------------------------------------------------------------------------------- /coat/screen_utils.py: -------------------------------------------------------------------------------- 1 | from typing import TYPE_CHECKING, Optional, Tuple, Union, Callable, List, Any, Generator, Dict 2 | from PIL import Image, ImageDraw, ImageFont 3 | from functools import cmp_to_key 4 | import numpy as np 5 | 6 | FONT_PATH = "/System/Library/Fonts/Helvetica.ttc" 7 | 8 | # ======================================================================================== 9 | 10 | def node_compare(node1, node2): 11 | if node1 is None and node2 is None: 12 | return 0 13 | if node1 is None: 14 | return 1 15 | if node2 is None: 16 | return -1 17 | 18 | l1, t1, r1, b1 = node1['bounds'] # top_left, right_bot 19 | l2, t2, r2, b2 = node2['bounds'] 20 | 21 | base_height = 0 22 | if t1 > t2: # node1's top edge is lower than node2's top edge 23 | base_height = (t2 + b2) / 2 24 | if t1 > base_height: return 1 # node1's top-left corner is below node2's center --> node1 belongs to a lower row than node2 25 | elif t1 < t2: # node1's top edge is higher than node2's top edge 26 | base_height = (t1 + b1) / 2 27 | if t2 > base_height: return -1 # node2's top-left corner is below node1's center --> node2 belongs to a lower row than node1 28 | 29 | width_diff = l1 - l2 30 | if width_diff > 0: return 1 # node1 is to the right of node2 31 | elif width_diff < 0: return -1 32 | 33 | return 0 34 | 35 | 36 | def row_col_sort(nodes): 37 | """ Sort nodes from left to right and from top to bottom. 38 | 39 | Going from top to bottom, first decide whether elements lie in the same row; only the elements within a row are then sorted from left to right. 40 | """ 41 | if len(nodes) <= 1: return nodes 42 | 43 | # First sort along the y-axis, considering both the minimum y value and the y coordinate of the center point 44 | # nodes.sort(key=lambda x: (x['bounds'][1], (x['bounds'][1] + x['bounds'][3])*0.5)) 45 | 46 | # Updated sorting rule -- follows the OCR reading order 47 | nodes.sort(key=cmp_to_key(node_compare)) 48 | 49 | first_node = nodes[0] 50 | other_node = nodes[1:] 51 | sort_dots = [[first_node]] 52 | 53 | line_index = 0 54 | for node in other_node: 55 | center_node_y = (node['bounds'][1] + node['bounds'][3]) * 0.5 56 | 57 | line_nodes = sort_dots[line_index] 58 | prev_avg_center_y = sum([(x['bounds'][1] + x['bounds'][3])*0.5 for x in line_nodes]) / len(line_nodes) 59 | 60 | if (node['bounds'][1] < prev_avg_center_y < node['bounds'][3]) \ 61 | and (line_nodes[-1]['bounds'][1] <
center_node_y < line_nodes[-1]['bounds'][3]): 62 | # the average y-center of the current row lies within the y range of the current node, and 63 | # the node's y-center lies between the previous node's minimum and maximum y => they are on the same row 64 | sort_dots[line_index].append(node) 65 | else: # start a new row 66 | line_index += 1 67 | sort_dots.append([node]) 68 | 69 | for dot in sort_dots: # within each row, sort by the minimum x coordinate 70 | dot.sort(key=lambda x: x['bounds'][0]) 71 | 72 | new_nodes = [dot for dots in sort_dots for dot in dots] 73 | return new_nodes 74 | 75 | # ======================================================================================== 76 | 77 | 78 | def draw_bbox(image_path, bboxes:List[Tuple[float, float, float, float]], 79 | texts=None, rgba_color=(0, 0, 255, 0), thickness=1, ret_corrds=False): 80 | """ Draw the bounding boxes with corresponding texts """ 81 | image = Image.open(image_path) 82 | w, h = image.size 83 | 84 | text_coords = [] 85 | 86 | with image.convert('RGBA') as base: 87 | tmp = Image.new("RGBA", base.size, (0, 0, 0, 0)) 88 | # get a drawing context 89 | draw = ImageDraw.Draw(tmp) 90 | for idx, bbox in enumerate(bboxes): 91 | xmin, ymin, xmax, ymax = bbox[:4] 92 | xmin = min(max(0, xmin), w) 93 | xmax = min(max(xmin, xmax), w) 94 | ymin = min(max(0, ymin), h) 95 | ymax = min(max(ymin, ymax), h) 96 | # draw the bounding box 97 | draw.rectangle((xmin, ymin, xmax, ymax), outline=rgba_color, width=thickness) 98 | 99 | # draw text if any 100 | if texts: 101 | text = texts[idx] 102 | box_height, box_width = ymax - ymin, xmax - xmin 103 | font_size = int(min(max(int(box_height * 0.7), 14), 36)) 104 | 105 | font = ImageFont.truetype(FONT_PATH, font_size, encoding="utf-8") 106 | left, top, right, bot = font.getbbox(text) 107 | coords = [ 108 | xmin, ymin, 109 | xmin + right*1.1, ymin, 110 | xmin + right*1.1, ymin - bot*1.1, 111 | xmin, ymin - bot*1.1 112 | ] 113 | draw.polygon(coords, fill=rgba_color) 114 | draw.text((xmin, ymin - bot*1.05), text, (255,255,255,255), font=font) 115 | text_coords.append([coords[0], coords[5], coords[2], coords[1]]) 116 | 117 | out = Image.alpha_composite(base, tmp) 118 | 119 | if ret_corrds: return out, text_coords 120 | return out 121 | 122 | 123 | def enlarge_bbox(bbox_list, scale_factor=1.2)->np.ndarray: 124 | """ 125 | Enlarge each bounding box by a given factor. 126 | 127 | :param bbox_list: list of bounding boxes, each a tuple or list of four values (xmin, ymin, xmax, ymax) 128 | :param scale_factor: enlargement factor 129 | :return: list of enlarged bounding boxes 130 | """ 131 | bbox_array = np.array(bbox_list) 132 | x_min, y_min, x_max, y_max = \ 133 | bbox_array[:, 0], bbox_array[:, 1], bbox_array[:, 2], bbox_array[:, 3] 134 | 135 | # center of each bounding box 136 | x_center = (x_min + x_max) / 2 137 | y_center = (y_min + y_max) / 2 138 | 139 | # scaled width and height of each bounding box 140 | width = (x_max - x_min) * scale_factor 141 | height = (y_max - y_min) * scale_factor 142 | 143 | # new coordinates of the enlarged bounding box 144 | new_x_min = x_center - width / 2 145 | new_y_min = y_center - height / 2 146 | new_x_max = x_center + width / 2 147 | new_y_max = y_center + height / 2 148 | 149 | # stack the new coordinates back into a list of bounding boxes 150 | enlarged_bbox_list = np.vstack((new_x_min, new_y_min, new_x_max, new_y_max)).T 151 | 152 | return enlarged_bbox_list 153 | 154 | 155 | 156 | def check_inside(x, y, bbox_array): 157 | """ 158 | Check whether a point (x, y) falls inside any bounding box in a list, using NumPy for efficiency. 159 | Also return the coordinates of every bounding box that contains the point. 160 | 161 | :param x: x value of the point 162 | :param y: y value of the point 163 | :param bbox_array: list of bounding boxes, each a tuple or list of four values (xmin, ymin, xmax, ymax) 164 | :return: a tuple whose first element is a boolean, True if the point lies inside any bounding box and False otherwise; 165 | the second element is the list of coordinates of all the containing
bounding boxes 166 | """ 167 | x_min, y_min, x_max, y_max = bbox_array[:, 0], bbox_array[:, 1], bbox_array[:, 2], bbox_array[:, 3] 168 | 169 | # check whether (x, y) lies inside any bounding box 170 | within_x = (x_min <= x) & (x <= x_max) 171 | within_y = (y_min <= y) & (y <= y_max) 172 | within_bbox = within_x & within_y 173 | 174 | if np.any(within_bbox): 175 | within_bbox_coords = bbox_array[within_bbox] 176 | return True, within_bbox_coords 177 | else: 178 | return False, None 179 | 180 | 181 | def intersect_iou(can_bbox, ref_bboxes): 182 | """ 183 | Compute the IoU between one bounding box and a set of bounding boxes. 184 | 185 | Args: 186 | - can_bbox: NumPy array or list of shape [4,], one bounding box [x_min, y_min, x_max, y_max] 187 | - ref_bboxes: NumPy array of shape [N, 4], representing N bounding boxes 188 | 189 | Returns: 190 | - ious: NumPy array of shape [N,], the IoU between the input box and each reference box 191 | """ 192 | # coordinates of the intersection 193 | inter_xmin = np.maximum(can_bbox[0], ref_bboxes[:, 0]) 194 | inter_ymin = np.maximum(can_bbox[1], ref_bboxes[:, 1]) 195 | inter_xmax = np.minimum(can_bbox[2], ref_bboxes[:, 2]) 196 | inter_ymax = np.minimum(can_bbox[3], ref_bboxes[:, 3]) 197 | 198 | # area of the intersection 199 | inter_area = np.maximum(0, inter_xmax - inter_xmin) * \ 200 | np.maximum(0, inter_ymax - inter_ymin) 201 | 202 | # area of the candidate bounding box 203 | can_bbox_area = np.maximum((can_bbox[2] - can_bbox[0]) * \ 204 | (can_bbox[3] - can_bbox[1]), 1) 205 | # areas of the reference bounding boxes 206 | ref_bboxes_area = np.maximum((ref_bboxes[:, 2] - ref_bboxes[:, 0]) * \ 207 | (ref_bboxes[:, 3] - ref_bboxes[:, 1]), 1) 208 | 209 | # area of the union 210 | union_area = can_bbox_area + ref_bboxes_area - inter_area 211 | 212 | # IoU 213 | ious = inter_area / union_area 214 | return ious -------------------------------------------------------------------------------- /coat/evaluate.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import json 4 | import imagesize 5 | import numpy as np 6 | import pandas as pd 7 | import Levenshtein 8 | from collections import defaultdict 9 | from .screen_utils import row_col_sort, enlarge_bbox, check_inside, intersect_iou 10 | 11 | 12 | class ActionEvaluator(object): 13 | BBOX_PATTERN = re.compile(r'\[ *(\d+) *, *(\d+) *, *(\d+) *, *(\d+) *\]') 14 | 15 | def __init__(self, demo_mode, screen_mode) -> None: 16 | self.demo_mode = demo_mode 17 | self.screen_mode = screen_mode 18 | 19 | self._all_action_types_ = [ 20 | "click", 21 | "scroll", 22 | "type", 23 | "press", 24 | "stop" 25 | ] 26 | 27 | def action_map(self, action_api:str): 28 | if not action_api: return None 29 | action_api = action_api.lower() 30 | if action_api == "input": return "type" 31 | for act_type in self._all_action_types_: 32 | if act_type in action_api: return act_type 33 | return None 34 | 35 | def _parse_action_(self, pred, w, h): 36 | pr = pred.get('action_predict', {}) 37 | if self.demo_mode not in pr: return (None,) * 6 38 | 39 | action = pr[self.demo_mode].get(self.screen_mode, {}) 40 | if not action: return (None,) * 6 41 | 42 | pd_action_type = self.action_map(action.get('ACTION', None)) 43 | # if pd_action_type is None: print(action) 44 | 45 | pd_action_args = action.get('ARGS', {}) 46 | if not isinstance(pd_action_args, dict): pd_action_args = {} 47 | pd_action_bbox = pd_action_args.get('bbox', None) 48 | if pd_action_bbox is not None: 49 | xmin, ymin, xmax, ymax = pd_action_bbox[:4] 50 | xmin = round(xmin/1000 * w) 51 | ymin = round(ymin/1000 * h) 52 | xmax = round(xmax/1000 * w) 53 | ymax = round(ymax/1000 * h) 54 | pd_action_bbox = [xmin, ymin, xmax, ymax] 55 | 56 | pd_action_idx = pd_action_args.get('idx', None)
57 | if pd_action_idx: 58 | try: pd_action_idx = int(pd_action_idx) 59 | except: pd_action_idx = None 60 | pd_action_direction = pd_action_args.get('direction', None) 61 | pd_action_text = pd_action_args.get('text', "") 62 | pd_action_button = None if pd_action_type != "press" else \ 63 | action['ACTION'].split("_")[1].lower() 64 | 65 | return pd_action_type, pd_action_bbox, pd_action_idx, \ 66 | pd_action_direction, pd_action_text, pd_action_button 67 | 68 | def _parse_answer_(self, gt): 69 | gt_words = gt['coat_action_desc'].split(' ') 70 | 71 | gt_action_type = self.action_map(gt_words[0]) 72 | if gt_action_type is None: print(gt['subset'], gt['episode_id']) 73 | gt_action_text = gt['result_action_text'] 74 | gt_action_direction = "" if gt_action_type != "scroll" else gt_words[1].strip() 75 | gt_action_button = "" 76 | if gt_action_type == "press": 77 | for button in ['enter', 'back', 'home']: 78 | if button in gt['coat_action_desc']: 79 | gt_action_button = button 80 | break 81 | 82 | w, h = imagesize.get(gt['image_full_path']) 83 | gt_action_xy = [0, 0] 84 | if gt_action_type == "scroll": 85 | rel_y, rel_x = json.loads(gt['result_touch_yx']) 86 | abs_y, abs_x = int(rel_y*h), int(rel_x*w) 87 | gt_action_xy = [abs_x, abs_y] 88 | 89 | gt_cand_nodes = [] 90 | for org_bbox, txt, ui_class in zip(gt['ui_positions'], gt['ui_text'], gt['ui_types']): 91 | ymin, xmin, h, w = org_bbox 92 | bbox = [xmin, ymin, xmin+w, ymin+h] 93 | gt_cand_nodes.append({"bounds": bbox, "text": txt, "type": ui_class}) 94 | gt_cand_nodes = row_col_sort(gt_cand_nodes) 95 | 96 | return gt_action_type, gt_action_xy, gt_cand_nodes, \ 97 | gt_action_text, gt_action_button, gt_action_direction 98 | 99 | def _check_click_(self, pred_bbox, gt_xy, gt_nodes): 100 | # gt_xy is within pred_bbox 101 | if not pred_bbox: return False 102 | pred_bbox = enlarge_bbox([pred_bbox])[0] 103 | xmin, ymin, xmax, ymax = pred_bbox 104 | gt_x, gt_y = gt_xy 105 | is_correct = (xmin <= gt_x <= xmax and ymin <= gt_y <= ymax) 106 | if is_correct: return True 107 | 108 | # gt_xy is within any bbox 109 | bbox_array = enlarge_bbox([x['bounds'] for x in gt_nodes], scale_factor=1.2) 110 | is_inside, bbox_inside = check_inside(gt_x, gt_y, bbox_array) 111 | if is_inside: 112 | ious = intersect_iou(pred_bbox, bbox_inside) 113 | if np.any(ious > 0.5): return True 114 | 115 | return False 116 | 117 | def __call__(self, gt, pred): 118 | """ eval_single_step """ 119 | subset, episode_id, step_id = gt['subset'], gt['episode_id'], gt['step_id'] 120 | w, h = imagesize.get(gt['image_full_path']) 121 | 122 | # get ground truth information 123 | gt_action_type, gt_action_xy, gt_cand_nodes, \ 124 | gt_action_text, gt_action_button, gt_action_direction = self._parse_answer_(gt) 125 | if not gt_action_type: print(gt['coat_action_desc']) 126 | gt_action_detail = { 127 | "click": gt_action_xy, "scroll": gt_action_direction, 128 | "type": gt_action_text, "press": gt_action_button, "stop": "stop" 129 | }.get(gt_action_type, None) 130 | 131 | # get predict action information 132 | pd_action_type, pd_action_bbox, pd_action_idx, \ 133 | pd_action_direction, pd_action_text, pd_action_button = self._parse_action_(pred, w, h) 134 | 135 | # compute metrics 136 | hit_format = True if pd_action_type is not None else False 137 | type_match = (pd_action_type is not None and gt_action_type == pd_action_type) 138 | 139 | exact_match = False 140 | pd_action_detail = None 141 | text_dist = None 142 | if type_match and pd_action_type == "click": 143 | if self.screen_mode == "tag" and 
pd_action_idx: # transform idx into bbox 144 | if 0 <= pd_action_idx < len(gt_cand_nodes): 145 | pd_action_bbox = gt_cand_nodes[pd_action_idx]['bounds'] 146 | pd_action_detail = pd_action_bbox 147 | exact_match = self._check_click_(pd_action_bbox, gt_action_xy, gt_cand_nodes) 148 | 149 | if type_match and pd_action_type == "scroll": 150 | pd_action_detail = pd_action_direction 151 | exact_match = (pd_action_direction == gt_action_direction) 152 | 153 | if type_match and pd_action_type == "type": 154 | pd_action_detail = pd_action_text 155 | text_dist = Levenshtein.ratio(pd_action_text, gt_action_text) 156 | exact_match = (pd_action_text in gt_action_text or \ 157 | gt_action_text in pd_action_text or \ 158 | text_dist > 0.8) 159 | 160 | if type_match and pd_action_type == "press": 161 | pd_action_detail = pd_action_button 162 | exact_match = (pd_action_button == gt_action_button) 163 | 164 | if type_match and pd_action_type == "stop": 165 | pd_action_detail = "stop" 166 | exact_match = True 167 | 168 | return { 169 | "subset": subset, "episode_id": episode_id, "step_id": step_id, 170 | "answer": {"action_type": gt_action_type, "action_detail": gt_action_detail}, 171 | "pred": {"action_type": pd_action_type, "action_detail": pd_action_detail}, 172 | "type_match": type_match, "exact_match": exact_match, 173 | "text_dist": text_dist, "format_hit": hit_format 174 | } 175 | 176 | def compute_episode_metrics(self, episode_results): 177 | success, progress = [], [] 178 | for __, eplist in episode_results.items(): 179 | ep_success, ep_progress = True, 0 180 | for ex in eplist: 181 | if ex['exact_match'] is True: ep_progress += 1 182 | else: ep_success = False 183 | if not ep_success: break 184 | success.append(ep_success) 185 | progress.append(ep_progress/len(eplist)*1.0) 186 | 187 | return {"success_rate": round(sum(success) / len(success), 4), 188 | "goal_progress": round(sum(progress) / len(progress), 4)} 189 | 190 | def compute_atomic_metrics(self, step_results): 191 | recorder = { 192 | 'total': {'count':0, 'type_match':0, 'exact_match':0, "hit": 0}, 193 | # ------------------------------------------- 194 | 'CLICK': {'count':0, 'type_match':0, 'exact_match':0}, 195 | 'TYPE': {'count':0, 'type_match':0, 'exact_match':0, 'text_dist': []}, 196 | 'SCROLL': {'count':0, 'type_match':0, 'exact_match':0}, 197 | 'PRESS': {'count':0, 'type_match':0, 'exact_match':0}, 198 | 'STOP': {'count':0, 'type_match':0, 'exact_match':0}, 199 | } 200 | 201 | for step in step_results: 202 | recorder['total']['count'] += 1 203 | recorder['total']['hit'] += step['format_hit'] 204 | 205 | action_type = step['answer']['action_type'].upper() 206 | recorder[action_type]['count'] += 1 207 | recorder[action_type]['type_match'] += step['type_match'] 208 | recorder['total']['type_match'] += step['type_match'] 209 | recorder[action_type]['exact_match'] += step['exact_match'] 210 | recorder['total']['exact_match'] += step['exact_match'] 211 | if 'text_dist' in recorder[action_type] and step['text_dist'] is not None: 212 | recorder[action_type]['text_dist'].append(step['text_dist']) 213 | 214 | scores = {metric_key:{} for metric_key in ['total', 'CLICK', 'SCROLL', 'PRESS', 'STOP', 'TYPE']} 215 | scores['total']['hit_rate'] = round(recorder['total']['hit']/recorder['total']['count'], 4) 216 | for metric_key in ['total', 'CLICK', 'SCROLL', 'PRESS', 'STOP', "TYPE"]: 217 | scores[metric_key]['type_acc'] = round(recorder[metric_key]['type_match']/recorder[metric_key]['count'], 4) 218 | scores[metric_key]['exact_acc'] = 
round(recorder[metric_key]['exact_match']/recorder[metric_key]['count'], 4) 219 | if recorder['TYPE']['text_dist']: 220 | scores['TYPE']['text_dist'] = round(sum(recorder['TYPE']['text_dist'])/len(recorder['TYPE']['text_dist']), 4) 221 | return scores 222 | 223 | -------------------------------------------------------------------------------- /data-example/GOOGLE_APPS-523638528775825151/GOOGLE_APPS-523638528775825151.json: -------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "episode_id": "523638528775825151", 4 | "episode_length": 4, 5 | "step_id": 0, 6 | "instruction": "open app \"Clock\" (install if not already installed)", 7 | "ui_positions": "[[54, 17, 8, 12], [82, 15, 8, 17], [82, 37, 10, 46], [120, 15, 20, 10], [128, 43, 8, 24], [160, 15, 8, 13], [162, 43, 7, 96], [189, 17, 14, 7], [194, 43, 5, 22], [225, 15, 10, 12], [227, 43, 8, 91], [256, 15, 15, 9], [261, 43, 8, 20], [574, 51, 15, 6], [574, 207, 16, 7]]", 8 | "ui_text": "[\"M\", \"Set\", \"up email\", \"G\", \"Coogle\", \"o\", \"Outlook, Homail, and Live\", \"\", \"Yatpe\", \"O\", \"Exchange and Ofioe 365\", \"\", \"Uthe\", \"\", \"\"]", 9 | "ui_types": "[\"TEXT\", \"TEXT\", \"TEXT\", \"ICON_GOOGLE\", \"TEXT\", \"TEXT\", \"TEXT\", \"ICON_ENVELOPE\", \"TEXT\", \"TEXT\", \"TEXT\", \"ICON_ENVELOPE\", \"TEXT\", \"ICON_V_BACKWARD\", \"ICON_NAV_BAR_RECT\"]", 10 | "result_action_type": 6, 11 | "result_action_text": "", 12 | "result_touch_yx": "[-1.0, -1.0]", 13 | "result_lift_yx": "[-1.0, -1.0]", 14 | "image_path": "google_apps/GOOGLE_APPS-523638528775825151/GOOGLE_APPS-523638528775825151_0.png", 15 | "image_full_path": "android-in-the-wild/aitw_with_gpt/train/google_apps/GOOGLE_APPS-523638528775825151/GOOGLE_APPS-523638528775825151_0.png", 16 | "coat_screen_desc": "This screenshot shows the \"Set up email\" screen on a smartphone, likely within an email application setup process. The screen offers a selection of popular email service providers including Google, Outlook (Hotmail and Live), Yahoo, Exchange and Office 365, as well as an \"Other\" option for email services not listed. Users can presumably tap on one of these options to begin the process of adding their email account to the device. The top of the screen displays standard status icons such as signal strength, battery, and time, indicating it's 9:35. The bottom of the screen shows navigation buttons for back, home, and recent applications, common to Android operating systems.", 17 | "coat_action_think": "The screen does not display the \"Clock\" app or any relevant features to access it; it is focused on email setup. Possible actions are to exit the email setup screen by tapping the home or back button, and then proceed to open the \"Clock\" app from the home screen or app drawer, or install it if it is not present.", 18 | "coat_action_desc": "press the home button", 19 | "coat_action_result": "By doing so, the home screen is displayed with app icons visible. This allows access to the app drawer or search function, where the Clock app can be located and opened." 
20 | }, 21 | { 22 | "episode_id": "523638528775825151", 23 | "episode_length": 4, 24 | "step_id": 1, 25 | "instruction": "open app \"Clock\" (install if not already installed)", 26 | "ui_positions": "[[56, 24, 8, 52], [366, 33, 28, 15], [368, 86, 30, 31], [373, 223, 15, 7], [408, 155, 5, 22], [409, 24, 5, 31], [409, 94, 5, 18], [409, 215, 5, 28], [464, 33, 23, 13], [465, 94, 21, 15], [526, 34, 21, 10], [529, 24, 23, 21], [573, 207, 16, 7], [575, 51, 14, 6]]", 27 | "ui_text": "[\"Man, Aug 8\", \"\", \"M\", \"\", \"Fhotca\", \"Pey Stoe\", \"imal\", \"auTuba\", \"\", \"\", \"\", \"G\", \"\", \"\"]", 28 | "ui_types": "[\"TEXT\", \"ICON_PLAY\", \"TEXT\", \"ICON_PLAY\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\", \"ICON_CALL\", \"ICON_CHAT\", \"ICON_GOOGLE\", \"TEXT\", \"ICON_NAV_BAR_RECT\", \"ICON_V_BACKWARD\"]", 29 | "result_action_type": 4, 30 | "result_action_text": "", 31 | "result_touch_yx": "[0.541063666343689, 0.5073748230934143]", 32 | "result_lift_yx": "[0.001115699764341116, 0.5788536071777344]", 33 | "image_path": "google_apps/GOOGLE_APPS-523638528775825151/GOOGLE_APPS-523638528775825151_1.png", 34 | "image_full_path": "android-in-the-wild/aitw_with_gpt/train/google_apps/GOOGLE_APPS-523638528775825151/GOOGLE_APPS-523638528775825151_1.png", 35 | "coat_screen_desc": "This is a screenshot of a smartphone home screen displaying a clean and simple layout. The top of the screen shows the time as \"5:35 AM\" and the date \"Mon, Aug 8\". There are five app icons visible: Play Store, Gmail, Photos, YouTube, along with a bottom row featuring the Phone, Chrome, Messages, and Camera apps, suggesting these are likely the user's most frequently used applications. A Google search bar is at the bottom, emphasizing easy access to web searches. The overall appearance is typical of an Android interface.", 36 | "coat_action_think": "Since the \"Clock\" app is not among the visible icons on the home screen, the next step would be to search for it by opening the app drawer or using the search feature. Possible actions are to swipe up on the home screen to access the app drawer or tap the Google search bar to type in \"Clock\" and search for the app.", 37 | "coat_action_desc": "scroll up", 38 | "coat_action_result": "By doing so, the screen now displays a list of applications including the Clock app which was not visible on the home screen before the action. This action enables the user to access and open the Clock app directly from the app drawer where it is located." 
39 | }, 40 | { 41 | "episode_id": "523638528775825151", 42 | "episode_length": 4, 43 | "step_id": 2, 44 | "instruction": "open app \"Clock\" (install if not already installed)", 45 | "ui_positions": "[[34, 236, 16, 5], [40, 24, 8, 138], [82, 36, 30, 15], [88, 160, 16, 9], [88, 212, 18, 24], [124, 27, 5, 33], [124, 93, 5, 22], [124, 150, 5, 28], [124, 216, 5, 18], [160, 121, 5, 7], [161, 133, 5, 15], [195, 216, 23, 17], [205, 91, 7, 24], [211, 34, 7, 18], [235, 36, 5, 17], [235, 83, 7, 43], [235, 150, 5, 28], [235, 210, 5, 21], [281, 34, 27, 15], [283, 218, 20, 10], [321, 37, 5, 15], [321, 91, 5, 15], [321, 156, 5, 18], [321, 206, 5, 7], [405, 138, 8, 45], [406, 217, 7, 15], [407, 33, 5, 21], [407, 96, 5, 13], [449, 94, 32, 18], [449, 191, 37, 43], [451, 217, 25, 15], [456, 31, 17, 24], [491, 149, 7, 31], [492, 34, 5, 17], [492, 93, 7, 22], [492, 209, 5, 33], [536, 37, 28, 12], [540, 94, 21, 15], [577, 87, 5, 34], [577, 215, 7, 21], [578, 36, 5, 24], [578, 155, 5, 20]]", 46 | "ui_text": "[\"\", \"Search your phone and more\", \"\", \"\", \"M\", \"Plaw S1ae\", \"Photcs\", \"YouTube\", \"Cmall\", \"All\", \"apps\", \"\", \"Bocking\", \"oirtel\", \"Airtel\", \"Hrokinc cor\", \"Laendar\", \"Cema\", \"\", \"\", \"Chal\", \"Chru\", \"Cleck\", \"C\", \"File Mener\", \"Fles\", \"Dagher\", \"Drve\", \"G\", \"O\", \"\", \"M\", \"HRO MAX\", \"mal\", \"Coagla\", \"Irataamy\", \"\", \"\", \"Masscces\", \"hotes\", \"Maps\", \"Phanc\"]", 47 | "ui_types": "[\"ICON_THREE_DOTS\", \"TEXT\", \"ICON_PLAY\", \"ICON_PLAY\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\", \"ICON_TAKE_PHOTO\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\", \"ICON_CHAT\", \"ICON_PERSON\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\", \"ICON_GOOGLE\", \"TEXT\", \"ICON_TAKE_PHOTO\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\", \"ICON_LOCATION\", \"ICON_CHAT\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\"]", 48 | "result_action_type": 4, 49 | "result_action_text": "", 50 | "result_touch_yx": "[0.49836206436157227, 0.6069772839546204]", 51 | "result_lift_yx": "[0.49669790267944336, 0.6069772839546204]", 52 | "image_path": "google_apps/GOOGLE_APPS-523638528775825151/GOOGLE_APPS-523638528775825151_2.png", 53 | "image_full_path": "android-in-the-wild/aitw_with_gpt/train/google_apps/GOOGLE_APPS-523638528775825151/GOOGLE_APPS-523638528775825151_2.png", 54 | "coat_screen_desc": "This is a screenshot of a smartphone home screen displaying various application icons, suggesting a user interface likely from an Android device. The top of the screen shows the notification bar with the time, battery status, and connectivity icons. Below that is a search bar, followed by frequently used apps such as Play Store, Photos, YouTube, and Gmail. Further down, there's a categorized section labeled \"All apps\" showcasing a grid of additional applications including Airtel, Booking.com, Calendar, and others, indicating the user has access to a range of services and tools for productivity, communication, and entertainment.", 55 | "coat_action_think": "The Clock app is displayed in the list of apps under the \"All apps\" section on the home screen. Possible actions are to tap the Clock app icon to open it.", 56 | "coat_action_desc": "click on the Clock app located at the upper middle right side of the screen.", 57 | "coat_action_result": "By doing so, the Clock app has been successfully opened, providing access to features such as alarms, world clocks, and timers. 
This action enables the user to manage time-related functions, which is the likely intention behind the query." 58 | }, 59 | { 60 | "episode_id": "523638528775825151", 61 | "episode_length": 4, 62 | "step_id": 3, 63 | "instruction": "open app \"Clock\" (install if not already installed)", 64 | "ui_positions": "[[20, 86, 16, 9], [23, 254, 15, 3], [43, 21, 5, 20], [43, 83, 5, 18], [43, 144, 7, 20], [43, 197, 7, 37], [82, 84, 28, 96], [123, 111, 8, 45], [573, 130, 16, 7], [574, 51, 15, 6], [574, 207, 16, 7]]", 65 | "ui_text": "[\"\", \"\", \"larm\", \"Clock\", \"Tirmer\", \"Stopwatch\", \"5:35AM\", \"Mon, Aug 8\", \"\", \"\", \"\"]", 66 | "ui_types": "[\"ICON_TIME\", \"ICON_THREE_DOTS\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\", \"TEXT\", \"ICON_SUN\", \"ICON_V_BACKWARD\", \"ICON_NAV_BAR_RECT\"]", 67 | "result_action_type": 10, 68 | "result_action_text": "", 69 | "result_touch_yx": "[-1.0, -1.0]", 70 | "result_lift_yx": "[-1.0, -1.0]", 71 | "image_path": "google_apps/GOOGLE_APPS-523638528775825151/GOOGLE_APPS-523638528775825151_3.png", 72 | "image_full_path": "android-in-the-wild/aitw_with_gpt/train/google_apps/GOOGLE_APPS-523638528775825151/GOOGLE_APPS-523638528775825151_3.png", 73 | "coat_screen_desc": "The screenshot shows the clock application of a smartphone displayed at 5:35 AM, Mon, Aug 8. It offers four functionalities: Alarm, Clock, Timer, and Stopwatch, indicated by icons at the top of the screen. The main display is digital, with large numbers for the hours and minutes. At the bottom, there's a button with a clock icon, presumably for setting or accessing alarm functions. The status bar at the very top displays the current battery level, signal strength, and indicates alarm and night mode are active.", 74 | "coat_action_think": "The screen is displaying the Clock application with its four main functions visible at the top: Alarm, Clock, Timer, and Stopwatch. The task appears to be in relation to accessing the Clock app, which is already on the screen. Possible actions are to stop and set the query as completed.", 75 | "coat_action_desc": "stop and set the query as completed", 76 | "coat_action_result": "By opening the \"Clock\" app, the user can interact with features such as setting alarms, checking the time, using a timer, or a stopwatch. The reason for opening this app is to provide the user with access to time-related functions that they have requested to use." 77 | } 78 | ] -------------------------------------------------------------------------------- /coat/action_utils.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 
15 | 16 | import json 17 | import enum 18 | import imagesize 19 | # import jax 20 | import jax.numpy as jnp 21 | import numpy as np 22 | 23 | 24 | def extract_gt_action(example): 25 | ex_action_type = example['result_action_type'] 26 | 27 | if ex_action_type == ActionType.DUAL_POINT: 28 | lift_yx = json.loads(example['result_lift_yx']) 29 | touch_yx = json.loads(example['result_touch_yx']) 30 | if is_tap_action(np.array(touch_yx), np.array(lift_yx)): 31 | action_type = 'CLICK' 32 | w, h = imagesize.get(example['image_full_path']) 33 | click_y, click_x = round(lift_yx[0] * h), round(lift_yx[1] * w) 34 | action = (click_y, click_x) 35 | else: 36 | action_type = 'SCROLL' 37 | v_change = abs(touch_yx[0] - lift_yx[0]) 38 | h_change = abs(lift_yx[1] - touch_yx[1]) 39 | is_scroll_up = lift_yx[0] < touch_yx[0] # touch is lower 40 | is_scroll_left = lift_yx[1] < touch_yx[1] # touch is bigger 41 | if v_change >= 0.9*h_change: # vertical 42 | action = "scroll up" if is_scroll_up else "scroll down" 43 | else: # horizonal 44 | action = "scroll left" if is_scroll_left else "scroll right" 45 | elif ex_action_type in ( 46 | ActionType.PRESS_BACK, 47 | ActionType.PRESS_HOME, 48 | ): 49 | button = ActionType(ex_action_type).name.split('_')[1].lower() 50 | action = f'press the {button} button' 51 | action_type = f'PRESS {button}'.upper() 52 | elif ex_action_type == ActionType.PRESS_ENTER: 53 | action = "press enter" 54 | action_type = 'PRESS ENTER' 55 | elif ex_action_type == ActionType.TYPE: 56 | action_text = example['result_action_text'] 57 | action = f'input text "{action_text}"' 58 | action_type = 'INPUT' 59 | elif ex_action_type == ActionType.STATUS_TASK_COMPLETE: 60 | action = 'stop and set the query as completed' 61 | action_type = 'STOP' 62 | elif ex_action_type == ActionType.STATUS_TASK_IMPOSSIBLE: 63 | action = 'stop and set the query as impossible' 64 | action_type = 'STOP' 65 | else: 66 | raise NotImplementedError 67 | 68 | return action, action_type 69 | 70 | 71 | '''====================================== 72 | Global Args 73 | ======================================''' 74 | 75 | 76 | _TAP_DISTANCE_THRESHOLD = 0.14 # Fraction of the screen 77 | ANNOTATION_WIDTH_AUGMENT_FRACTION = 1.4 78 | ANNOTATION_HEIGHT_AUGMENT_FRACTION = 1.4 79 | 80 | # Interval determining if an action is a tap or a swipe. 81 | _SWIPE_DISTANCE_THRESHOLD = 0.04 82 | 83 | 84 | def is_tap_action(normalized_start_yx, normalized_end_yx): 85 | distance = jnp.linalg.norm( 86 | jnp.array(normalized_start_yx) - jnp.array(normalized_end_yx)) 87 | return distance <= _SWIPE_DISTANCE_THRESHOLD 88 | 89 | 90 | def _is_non_dual_point_action(action_type): 91 | return jnp.not_equal(action_type, ActionType.DUAL_POINT) 92 | 93 | 94 | '''====================================== 95 | === AndroidInTheWild action types. === 96 | ======================================''' 97 | 98 | 99 | class ActionType(enum.IntEnum): 100 | """Integer values for each supported action type in AndroidInTheWild.""" 101 | 102 | # Placeholders for unused enum values 103 | UNUSED_0 = 0 104 | UNUSED_1 = 1 105 | UNUSED_2 = 2 106 | UNUSED_8 = 8 107 | UNUSED_9 = 9 108 | 109 | ########### Agent actions ########### 110 | 111 | # A type action that sends text to the emulator. Note that this simply sends 112 | # text and does not perform any clicks for element focus or enter presses for 113 | # submitting text. 114 | TYPE = 3 115 | 116 | # The dual point action used to represent all gestures. 
117 | DUAL_POINT = 4 118 | 119 | # These actions differentiate pressing the home and back button from touches. 120 | # They represent explicit presses of back and home performed using ADB. 121 | PRESS_BACK = 5 122 | PRESS_HOME = 6 123 | 124 | # An action representing that ADB command for hitting enter was performed. 125 | PRESS_ENTER = 7 126 | 127 | ########### Episode status actions ########### 128 | 129 | # An action used to indicate the desired task has been completed and resets 130 | # the environment. This action should also be used in the case that the task 131 | # has already been completed and there is nothing to do. 132 | # e.g. The task is to turn on the Wi-Fi when it is already on 133 | STATUS_TASK_COMPLETE = 10 134 | 135 | # An action used to indicate that desired task is impossible to complete and 136 | # resets the environment. This can be a result of many different things 137 | # including UI changes, Android version differences, etc. 138 | STATUS_TASK_IMPOSSIBLE = 11 139 | 140 | 141 | '''=============================================================================== 142 | === Utilites for performing action matching on AndroidInTheWild data. === 143 | === Note: this code is implemented using JAX so it can be 'vmap'ed and === 144 | === efficiently appied over a batch of data. === 145 | ===============================================================================''' 146 | 147 | 148 | def _yx_in_bounding_boxes( 149 | yx, bounding_boxes 150 | ): 151 | """Check if the (y,x) point is contained in each bounding box. 152 | 153 | Args: 154 | yx: The (y, x) coordinate in pixels of the point. 155 | bounding_boxes: A 2D int array of shape (num_bboxes, 4), where each row 156 | represents a bounding box: (y_top_left, x_top_left, box_height, 157 | box_width). Note: containment is inclusive of the bounding box edges. 158 | 159 | Returns: 160 | is_inside: A 1D bool array where each element specifies if the point is 161 | contained within the respective box. 162 | """ 163 | y, x = yx 164 | 165 | # `bounding_boxes` has shape (n_elements, 4); we extract each array along the 166 | # last axis into shape (n_elements, 1), then squeeze unneeded dimension. 167 | top, left, height, width = [ 168 | jnp.squeeze(v, axis=-1) for v in jnp.split(bounding_boxes, 4, axis=-1) 169 | ] 170 | 171 | # The y-axis is inverted for AndroidEnv, so bottom = top + height. 172 | bottom, right = top + height, left + width 173 | 174 | return jnp.logical_and(y >= top, y <= bottom) & jnp.logical_and(x >= left, x <= right) 175 | 176 | 177 | def _resize_annotation_bounding_boxes( 178 | annotation_positions, annotation_width_augment_fraction, 179 | annotation_height_augment_fraction): 180 | """Resize the bounding boxes by the given fractions. 181 | 182 | Args: 183 | annotation_positions: Array of shape (N, 4), where each row represents the 184 | (y, x, height, width) of the bounding boxes. 185 | annotation_width_augment_fraction: The fraction to augment the box widths, 186 | E.g., 1.4 == 240% total increase. 187 | annotation_height_augment_fraction: Same as described for width, but for box 188 | height. 189 | 190 | Returns: 191 | Resized bounding box. 192 | 193 | """ 194 | height_change = ( 195 | annotation_height_augment_fraction * annotation_positions[:, 2]) 196 | width_change = ( 197 | annotation_width_augment_fraction * annotation_positions[:, 3]) 198 | 199 | # Limit bounding box positions to the screen. 
200 | resized_annotations = jnp.stack([ 201 | jnp.maximum(0, annotation_positions[:, 0] - (height_change / 2)), 202 | jnp.maximum(0, annotation_positions[:, 1] - (width_change / 2)), 203 | jnp.minimum(1, annotation_positions[:, 2] + height_change), 204 | jnp.minimum(1, annotation_positions[:, 3] + width_change), 205 | ],axis=1) 206 | 207 | return resized_annotations 208 | 209 | 210 | def is_tap_action(normalized_start_yx, normalized_end_yx): 211 | distance = jnp.linalg.norm( 212 | jnp.array(normalized_start_yx) - jnp.array(normalized_end_yx)) 213 | return distance <= _SWIPE_DISTANCE_THRESHOLD 214 | 215 | 216 | def _is_non_dual_point_action(action_type): 217 | return jnp.not_equal(action_type, ActionType.DUAL_POINT) 218 | 219 | 220 | def _check_tap_actions_match( 221 | tap_1_yx, 222 | tap_2_yx, 223 | annotation_positions, 224 | matching_tap_distance_threshold_screen_percentage, 225 | annotation_width_augment_fraction, 226 | annotation_height_augment_fraction, 227 | ): 228 | """Determines if two tap actions are the same.""" 229 | resized_annotation_positions = _resize_annotation_bounding_boxes( 230 | annotation_positions, 231 | annotation_width_augment_fraction, 232 | annotation_height_augment_fraction, 233 | ) 234 | 235 | # Check if the ground truth tap action falls in an annotation's bounding box. 236 | tap1_in_box = _yx_in_bounding_boxes(tap_1_yx, resized_annotation_positions) 237 | tap2_in_box = _yx_in_bounding_boxes(tap_2_yx, resized_annotation_positions) 238 | both_in_box = jnp.max(tap1_in_box & tap2_in_box) 239 | 240 | # If the ground-truth tap action falls outside any of the annotation 241 | # bounding boxes or one of the actions is inside a bounding box and the other 242 | # is outside bounding box or vice versa, compare the points using Euclidean 243 | # distance. 244 | within_threshold = ( 245 | jnp.linalg.norm(jnp.array(tap_1_yx) - jnp.array(tap_2_yx)) 246 | <= matching_tap_distance_threshold_screen_percentage 247 | ) 248 | return jnp.logical_or(both_in_box, within_threshold) 249 | 250 | 251 | def _check_drag_actions_match( 252 | drag_1_touch_yx, 253 | drag_1_lift_yx, 254 | drag_2_touch_yx, 255 | drag_2_lift_yx, 256 | ): 257 | """Determines if two drag actions are the same.""" 258 | # Store drag deltas (the change in the y and x coordinates from touch to 259 | # lift), magnitudes, and the index of the main axis, which is the axis with 260 | # the greatest change in coordinate value (e.g. a drag starting at (0, 0) and 261 | # ending at (0.3, 0.5) has a main axis index of 1). 262 | drag_1_deltas = drag_1_lift_yx - drag_1_touch_yx 263 | drag_1_magnitudes = jnp.abs(drag_1_deltas) 264 | drag_1_main_axis = np.argmax(drag_1_magnitudes) 265 | drag_2_deltas = drag_2_lift_yx - drag_2_touch_yx 266 | drag_2_magnitudes = jnp.abs(drag_2_deltas) 267 | drag_2_main_axis = np.argmax(drag_2_magnitudes) 268 | 269 | return jnp.equal(drag_1_main_axis, drag_2_main_axis) 270 | 271 | 272 | def check_actions_match( 273 | action_1_touch_yx, 274 | action_1_lift_yx, 275 | action_1_action_type, 276 | action_2_touch_yx, 277 | action_2_lift_yx, 278 | action_2_action_type, 279 | annotation_positions, 280 | tap_distance_threshold = _TAP_DISTANCE_THRESHOLD, 281 | annotation_width_augment_fraction = ANNOTATION_WIDTH_AUGMENT_FRACTION, 282 | annotation_height_augment_fraction = ANNOTATION_HEIGHT_AUGMENT_FRACTION, 283 | ): 284 | """Determines if two actions are considered to be the same. 285 | 286 | Two actions being "the same" is defined here as two actions that would result 287 | in a similar screen state. 
288 | 289 | Args: 290 | action_1_touch_yx: The (y, x) coordinates of the first action's touch. 291 | action_1_lift_yx: The (y, x) coordinates of the first action's lift. 292 | action_1_action_type: The action type of the first action. 293 | action_2_touch_yx: The (y, x) coordinates of the second action's touch. 294 | action_2_lift_yx: The (y, x) coordinates of the second action's lift. 295 | action_2_action_type: The action type of the second action. 296 | annotation_positions: The positions of the UI annotations for the screen. It 297 | is A 2D int array of shape (num_bboxes, 4), where each row represents a 298 | bounding box: (y_top_left, x_top_left, box_height, box_width). Note that 299 | containment is inclusive of the bounding box edges. 300 | tap_distance_threshold: The threshold that determines if two taps result in 301 | a matching screen state if they don't fall the same bounding boxes. 302 | annotation_width_augment_fraction: The fraction to increase the width of the 303 | bounding box by. 304 | annotation_height_augment_fraction: The fraction to increase the height of 305 | of the bounding box by. 306 | 307 | Returns: 308 | A boolean representing whether the two given actions are the same or not. 309 | """ 310 | action_1_touch_yx = jnp.asarray(action_1_touch_yx) 311 | action_1_lift_yx = jnp.asarray(action_1_lift_yx) 312 | action_2_touch_yx = jnp.asarray(action_2_touch_yx) 313 | action_2_lift_yx = jnp.asarray(action_2_lift_yx) 314 | 315 | # Checks if at least one of the actions is global (i.e. not DUAL_POINT), 316 | # because if that is the case, only the actions' types need to be compared. 317 | has_non_dual_point_action = jnp.logical_or( 318 | _is_non_dual_point_action(action_1_action_type), 319 | _is_non_dual_point_action(action_2_action_type), 320 | ) 321 | 322 | different_dual_point_types = jnp.logical_xor( 323 | is_tap_action(action_1_touch_yx, action_1_lift_yx), 324 | is_tap_action(action_2_touch_yx, action_2_lift_yx), 325 | ) 326 | 327 | is_tap = jnp.logical_and( 328 | is_tap_action(action_1_touch_yx, action_1_lift_yx), 329 | is_tap_action(action_2_touch_yx, action_2_lift_yx), 330 | ) 331 | 332 | taps_match = _check_tap_actions_match( 333 | action_1_touch_yx, 334 | action_2_touch_yx, 335 | annotation_positions, 336 | tap_distance_threshold, 337 | annotation_width_augment_fraction, 338 | annotation_height_augment_fraction, 339 | ) 340 | 341 | taps_match = jnp.logical_and(is_tap, taps_match) 342 | 343 | drags_match = _check_drag_actions_match( 344 | action_1_touch_yx, action_1_lift_yx, action_2_touch_yx, action_2_lift_yx 345 | ) 346 | drags_match = jnp.where(is_tap, False, drags_match) 347 | 348 | return jnp.where( 349 | has_non_dual_point_action, 350 | jnp.equal(action_1_action_type, action_2_action_type), 351 | jnp.where( 352 | different_dual_point_types, 353 | False, 354 | jnp.logical_or(taps_match, drags_match), 355 | ), 356 | ) 357 | 358 | -------------------------------------------------------------------------------- /coat/model.py: -------------------------------------------------------------------------------- 1 | from typing import Union, List, Dict, Optional, Any 2 | from abc import abstractmethod 3 | import os 4 | import io 5 | import json 6 | import requests 7 | import random 8 | import base64 9 | import tempfile 10 | import numpy as np 11 | from PIL import Image 12 | from functools import partial 13 | from http import HTTPStatus 14 | 15 | 16 | class BaseModel(object): 17 | 18 | def __init__(self, **kwargs): 19 | pass 20 | 21 | @abstractmethod 22 | def 
format_query(self, *args: Any, **kwargs: Any):
 23 |         pass
 24 | 
 25 |     @abstractmethod
 26 |     def get_response(self, *args: Any, **kwargs: Any):
 27 |         pass
 28 | 
 29 |     @abstractmethod
 30 |     def chat(self, *args: Any, **kwargs: Any):
 31 |         pass
 32 | 
 33 |     def __call__(self, *args: Any, **kwargs: Any) -> Any:
 34 |         return self.get_response(*args, **kwargs)
 35 | 
 36 | 
 37 | class OpenAIModel(BaseModel):
 38 | 
 39 |     def __init__(self, configs, seed=2024, **kwargs) -> None:
 40 |         super().__init__()
 41 | 
 42 |         self.api_key = configs['OPENAI_API_KEY']
 43 |         self.api_url = configs['OPENAI_API_URL']
 44 |         self.seed = seed
 45 | 
 46 |     def encode_image(self, image=None, image_path=None):
 47 |         assert image is not None or image_path is not None, \
 48 |             "Please specify at least an image array or a path to image."
 49 |         if image is not None:  # convert the in-memory image to bytes
 50 |             with io.BytesIO() as output_bytes:
 51 |                 if isinstance(image, Image.Image): PIL_image = image
 52 |                 elif isinstance(image, np.ndarray): PIL_image = Image.fromarray(image)
 53 |                 else: raise TypeError(f"Unsupported image type {type(image)}")
 54 |                 PIL_image.save(output_bytes, 'PNG')
 55 |                 bytes_data = output_bytes.getvalue()
 56 |             return base64.b64encode(bytes_data).decode('utf-8')
 57 |         elif image_path is not None:
 58 |             with open(image_path, "rb") as image_file:
 59 |                 return base64.b64encode(image_file.read()).decode('utf-8')
 60 |         else: raise NotImplementedError
 61 | 
 62 |     def format_query(self, image_paths:List[str], prompt:Union[str, dict, tuple], detail="high"):
 63 |         if isinstance(image_paths, str): image_paths = [image_paths]
 64 | 
 65 |         if isinstance(prompt, str):
 66 |             message = [{'role': 'user', 'content': [{'type': 'text', 'text': prompt},]}]
 67 |             if len(image_paths) != 0:
 68 |                 for img_path in image_paths:
 69 |                     base64_image = self.encode_image(image_path=img_path)
 70 |                     message[0]['content'].append(
 71 |                         {'type': 'image_url', 'image_url': {"url": f"data:image/jpeg;base64,{base64_image}", "detail": detail}}
 72 |                     )
 73 | 
 74 |         if isinstance(prompt, list):
 75 |             message = []
 76 |             for role, msg in prompt:
 77 |                 if isinstance(msg, dict) and "img_index" in msg:
 78 |                     img_index = msg.pop('img_index')
 79 |                     base64_image = self.encode_image(image_path=image_paths[img_index])
 80 |                     message.append({'role': role, 'content': [
 81 |                         {'type': 'text', 'text': msg['text']},
 82 |                         {'type': 'image_url', 'image_url': {"url": f"data:image/jpeg;base64,{base64_image}", "detail": detail}}
 83 |                     ]})
 84 |                 else:
 85 |                     if isinstance(msg, dict):
 86 |                         message.append({'role': role, 'content': [{'type': 'text', 'text': msg['text']},]})
 87 |                     else:
 88 |                         message.append({'role': role, 'content': [{'type': 'text', 'text': msg},]})
 89 | 
 90 |         if isinstance(prompt, dict):
 91 |             message = []
 92 |             for role in ['system', 'user', 'assistant']:
 93 |                 if role in prompt:
 94 |                     msg = {'role': role, 'content': []}  # one message dict per role
 95 |                     if isinstance(prompt[role], str):
 96 |                         msg['content'].append({'type': 'text', 'text': prompt[role]})
 97 |                     else:
 98 |                         msg['content'].append({'type': 'text', 'text': prompt[role]['text']})
 99 |                         if "img_index" in prompt[role]:
100 |                             img_index = prompt[role].pop('img_index')
101 |                             base64_image = self.encode_image(image_path=image_paths[img_index])
102 |                             msg['content'].append({'type': 'image_url', 'image_url': {"url": f"data:image/jpeg;base64,{base64_image}", "detail": detail}})
103 |                     message.append(msg)
104 |                     # message += [
105 |                     #     {'role': 'user', 'content': [
106 |                     #         {'type': 'text', 'text': user_prompt},
107 |                     #         {'type': 'image_url', 'image_url': {"url": f"data:image/jpeg;base64,{base64_image}", "detail": detail}}
108 |                     #     ]}
109 |                     # ]
110 | 
return message 111 | 112 | def get_response(self, image_path:str, prompt:Union[str, dict], temperature=1, max_new_tokens=1024, **kwargs: Any): 113 | response_str, response_state, history = self.chat( 114 | image_path, prompt, history=[], temperature=temperature, max_new_tokens=max_new_tokens, **kwargs) 115 | 116 | return response_str, response_state 117 | 118 | def chat(self, image_path:str, prompt:Union[str, dict], history:list=[], temperature=1, max_new_tokens=1024, **kwargs: Any): 119 | message = history + self.format_query(image_path, prompt) 120 | 121 | headers = { 122 | 'Content-Type': 'application/json', 123 | 'Authorization': f'Bearer {self.api_key}' 124 | } 125 | payload = { 126 | # 'model': 'gpt-4-turbo', 127 | 'model': 'gpt-4-vision-preview', 128 | 'messages': message, 129 | 'max_tokens': max_new_tokens, 130 | 'seed': self.seed, 131 | } 132 | 133 | response = requests.post(self.api_url, headers=headers, json=payload) 134 | response_state = response.status_code 135 | 136 | try: 137 | response_json = response.json() 138 | response_str = response_json['choices'][0]['message']['content'] 139 | except: 140 | response_json, response_str = {}, "" 141 | 142 | if response_state == 200: 143 | history += message 144 | history += [{'role': 'assistant', 'content': response_str}] 145 | 146 | return response_str, response_state, history 147 | 148 | 149 | class GeminiModel(BaseModel): 150 | 151 | def __init__(self, configs, **kwargs) -> None: 152 | super().__init__() 153 | import google.generativeai as genai 154 | 155 | genai.configure(api_key=configs['GEMINI_API_KEY'], transport='rest') 156 | self.model = genai.GenerativeModel(configs['GEMINI_MODEL']) 157 | 158 | def format_query(self, image_paths:str, prompt:Union[str, list, dict]): 159 | if isinstance(prompt, str): 160 | image = Image.open(image_paths) 161 | return [prompt, image] 162 | 163 | if isinstance(image_paths, str): image_paths = [image_paths] 164 | 165 | if isinstance(prompt, list): 166 | message = [] 167 | for role, msg in prompt: 168 | if isinstance(msg, dict) and "img_index" in msg: 169 | img_index = msg.pop('img_index') 170 | image = Image.open(image_paths[img_index]) 171 | message.extend([msg['text'], image]) 172 | else: 173 | if isinstance(msg, dict): message.append(msg['text']) 174 | else: message.append(msg) 175 | return message 176 | 177 | def get_response(self, image_path:str, prompt:str, temperature=1, max_new_tokens=1024, **kwargs: Any): 178 | """ 179 | Return: 180 | - response_str: (string) the decoded response 181 | - response_state: (int) 182 | STOP (1): Natural stop point of the model or provided stop sequence. 183 | MAX_TOKENS (2): The maximum number of tokens as specified in the request was reached. 184 | SAFETY (3): The candidate content was flagged for safety reasons. 185 | RECITATION (4): The candidate content was flagged for recitation reasons. 186 | OTHER (5): Unknown reason. 
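            Note that in the code below, a finish_reason of STOP (1) is remapped to
            HTTPStatus.OK (200) before returning, so that callers such as
            BaseAgent.flow can check `state == HTTPStatus.OK` uniformly across the
            OpenAI, Gemini and Qwen-VL backends.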
187 | """ 188 | import google.generativeai as genai 189 | 190 | message = self.format_query(image_path, prompt) 191 | response = self.model.generate_content(message, 192 | generation_config=genai.types.GenerationConfig( 193 | candidate_count=1, 194 | temperature=temperature, 195 | max_output_tokens=max_new_tokens 196 | ) 197 | ) 198 | response_state = response.candidates[0].finish_reason 199 | 200 | try: 201 | response_str = response.candidates[0].content.parts[0].text 202 | if response_state == 1: response_state = HTTPStatus.OK 203 | except: response_str = "" 204 | 205 | return response_str.strip(), response_state 206 | 207 | 208 | 209 | qwenvl_special_str = """ 210 | Your response should be strictly structured in JSON format, consisting of the following keys and corresponding content: 211 | { 212 | "THINK": , 213 | "NEXT": , 214 | "ACTION": , 215 | "ARGS": 216 | } 217 | You should output only one next single-step action. 218 | 219 | : 220 | """ 221 | 222 | class QwenVLModel(BaseModel): 223 | 224 | def __init__(self, configs, **kwargs): 225 | super().__init__() 226 | from dashscope import MultiModalConversation 227 | 228 | self.api_key = configs['DASHSCOPE_API_KEY'] 229 | self.api_version = configs['DASHSCOPE_MODEL'] 230 | self.model = partial( 231 | MultiModalConversation.call, 232 | model=self.api_version, api_key=self.api_key 233 | ) 234 | 235 | def encode_image(self, image=None, image_path=None): 236 | assert image is not None or image_path is not None, \ 237 | "Please specify at least an image array or a path to image." 238 | if image is not None:# convert image to bytes 239 | if isinstance(image, Image.Image): PIL_image = image 240 | elif isinstance(image, np.ndarray): PIL_image = Image.fromarray(image) 241 | else: raise TypeError(f"Unsupported image type {type(image)}") 242 | temp_directory = tempfile.gettempdir() 243 | unique_suffix = "_temp_" + ''.join(random.choice("abcdefghijklmnopqrstuvwxyz") for _ in range(5)) 244 | filename = f"image{unique_suffix}.png" 245 | temp_image_path = os.path.join(temp_directory, filename) 246 | temp_image_url = f"file:///{temp_image_path}" 247 | temp_image_url = temp_image_url.replace('\\', '/') 248 | PIL_image.save(temp_image_path) 249 | image_url = temp_image_url 250 | elif image_path is not None: image_url = f'file://{os.path.abspath(image_path)}' 251 | else: raise NotImplementedError 252 | return image_url 253 | 254 | def format_query(self, image_paths:str, prompt:Union[str, dict], detail="high"): 255 | if isinstance(image_paths, str): image_paths = [image_paths] 256 | 257 | if isinstance(prompt, str): 258 | message = [{'role': 'user', 'content': [{'text': prompt},]}] 259 | if len(image_paths) != 0: 260 | for img_path in image_paths: 261 | image_url = self.encode_image(image_path=img_path) 262 | message[0]['content'].append({'image': image_url}) 263 | 264 | if isinstance(prompt, list): 265 | message = [] 266 | for role, msg in prompt: 267 | if isinstance(msg, dict) and "img_index" in msg: 268 | img_index = msg.pop('img_index') 269 | image_url = self.encode_image(image_path=image_paths[img_index]) 270 | message.append({'role': role, 'content': [ 271 | {'text': msg['text']}, {'image': image_url}]}) 272 | else: 273 | if isinstance(msg, dict): 274 | message.append({'role': role, 'content': [{'text': msg['text']},]}) 275 | else: 276 | message.append({'role': role, 'content': [{'text': msg},]}) 277 | if role == "assistant": 278 | message.append({"role": "user", "content": [{"text": qwenvl_special_str}]}) 279 | 280 | if isinstance(prompt, dict): 281 | 
message = []
282 |             for role in ['system', 'user', 'assistant']:
283 |                 if role in prompt:
284 |                     msg = {'role': role, 'content': []}  # one message dict per role
285 |                     if isinstance(prompt[role], str):
286 |                         msg['content'].append({'text': prompt[role]})
287 |                     else:
288 |                         msg['content'].append({'text': prompt[role]['text']})
289 |                         if "img_index" in prompt[role]:
290 |                             img_index = prompt[role].pop('img_index')
291 |                             image_url = self.encode_image(image_path=image_paths[img_index])
292 |                             msg['content'].append({'image': image_url})
293 |                     message.append(msg)
294 | 
295 |         return message
296 | 
297 |     def chat(self, image_path:str, prompt:Union[str, dict], history:list=[], top_p=0.001, **kwargs: Any):
298 |         message = history + self.format_query(image_path, prompt)
299 | 
300 |         response = self.model(messages=message, top_p=top_p)
301 |         response_state = response.status_code
302 | 
303 |         try:
304 |             response_str = response.output.choices[0].message.content[0]['text']
305 |             if response_state == HTTPStatus.OK:  # if the API call succeeded
306 |                 history += message
307 |                 history += [{
308 |                     'role': 'assistant',
309 |                     'content': [{'text': response_str}]
310 |                 }]
311 |         except: response_str = ""
312 | 
313 |         return response_str.strip(), response_state, history
314 | 
315 |     def get_response(self, image_path:str, prompt:Union[str, dict], history:list=[], top_p=0.1, **kwargs: Any):
316 |         response_str, response_state, history = self.chat(
317 |             image_path, prompt, history=[], top_p=top_p, **kwargs)
318 | 
319 |         return response_str, response_state
320 | 
321 | 
322 | def get_model(model_config, seed=2024) -> BaseModel:
323 |     if model_config['NAME'].lower() == "openai": return OpenAIModel(model_config, seed)
324 |     if model_config['NAME'].lower() == "gemini": return GeminiModel(model_config)
325 |     if model_config['NAME'].lower() == "qwenvl": return QwenVLModel(model_config)
326 |     raise NotImplementedError
327 | 
328 | 
--------------------------------------------------------------------------------
/coat/agent.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import time
 3 | import json
 4 | import traceback
 5 | import imagesize
 6 | import numpy as np
 7 | from PIL import Image
 8 | from http import HTTPStatus
 9 | from .model import get_model
10 | from .screen_utils import row_col_sort, draw_bbox
11 | 
12 | 
13 | def json_parser(json_string:str):
14 |     if json_string.startswith("```json"):
15 |         json_string = json_string[7:]
16 |     if json_string.startswith("```JSON"):
17 |         json_string = json_string[7:]
18 |     if json_string.endswith("```"):
19 |         json_string = json_string[:-3]
20 | 
21 |     if "```json" in json_string:
22 |         json_string = json_string.split("```json")[1]
23 |     if "```JSON" in json_string:
24 |         json_string = json_string.split("```JSON")[1]
25 |     if "```" in json_string:
26 |         json_string = json_string.split("```")[0]
27 | 
28 |     return json.loads(json_string)
29 | 
30 | 
31 | 
32 | class BaseAgent(object):
33 | 
34 |     def __init__(self, config, *args, **kwargs) -> None:
35 |         self.cfg = config
36 |         self.prompts = config['PROMPTS']
37 |         self.vlm = get_model(config['MODEL'])
38 | 
39 |     def observe(self, user_request, screen_path, **kwargs):
40 |         sys_rompt = self.prompts['SCREEN_DESC']['SYSTEM']
41 |         usr_prompt = self.prompts['SCREEN_DESC']['USER']
42 |         usr_prompt = [x.strip() for x in usr_prompt.split("{screenshot}")]
43 | 
44 |         prompt = [
45 |             ("system", sys_rompt),
46 |             ("user", {"text": usr_prompt[0], "img_index": 0})
47 |         ]
48 |         for txt in usr_prompt[1:]: prompt.append(("user", {"text": txt}))
49 | 
50 |         res, state = 
self.vlm.get_response(screen_path, prompt) 51 | return res, state 52 | 53 | def think_action(self, user_request, screen_path, screen_desc, history_actions, **kwargs): 54 | sys_rompt = self.prompts['ACTION_THINK_DESC']['SYSTEM'] 55 | 56 | if history_actions: history_actions = ", ".join(history_actions) 57 | else: history_actions = "None" 58 | history_actions = history_actions + "." 59 | 60 | usr_prompt = self.prompts['ACTION_THINK_DESC']['USER'] 61 | usr_prompt = usr_prompt.replace("{screen_desc}", screen_desc) 62 | usr_prompt = usr_prompt.replace("{history_actions}", history_actions) 63 | usr_prompt = usr_prompt.replace("{user_request}", user_request) 64 | usr_prompt = [x.strip() for x in usr_prompt.split("{screenshot}")] 65 | 66 | prompt = [ 67 | ("system", sys_rompt), 68 | ("user", {"text": usr_prompt[0], "img_index": 0}) 69 | ] 70 | for txt in usr_prompt[1:]: prompt.append(("user", {"text": txt})) 71 | 72 | res, state = self.vlm.get_response(screen_path, prompt) 73 | return res, state 74 | 75 | def reflect_result(self, user_request, screen_path, next_screen_path, last_action, **kwargs): 76 | sys_rompt = self.prompts['ACTION_RESULT']['SYSTEM'] 77 | 78 | usr_prompt = self.prompts['ACTION_RESULT']['USER'] 79 | usr_prompt = usr_prompt.replace("{last_action}", last_action) 80 | usr_prompt = usr_prompt.replace("{user_request}", user_request) 81 | 82 | usr_prompt = [x.strip() for x in usr_prompt.split("{before_screenshot}")] 83 | usr_prompt = [usr_prompt[0]] + [x.strip() for x in usr_prompt[1].split("{after_screenshot}")] 84 | 85 | prompt = [ 86 | ("system", sys_rompt), 87 | ("user", {"text": usr_prompt[0], "img_index": 0}), 88 | ("user", {"text": usr_prompt[1], "img_index": 1}) 89 | ] 90 | for txt in usr_prompt[2:]: prompt.append(("user", {"text": txt})) 91 | 92 | res, state = self.vlm.get_response([screen_path, next_screen_path], prompt) 93 | return res, state 94 | 95 | def flow(self, step_data, save_dir, max_trials=5): 96 | subset, episode_id, step_id = step_data['subset'], step_data['episode_id'], step_data['step_id'] 97 | user_request, image_path = step_data['instruction'], step_data['image_full_path'] 98 | 99 | save_dir = os.path.join(save_dir, f"{subset}-{episode_id}") 100 | if not os.path.exists(save_dir): os.makedirs(save_dir) 101 | save_path = os.path.join(save_dir, f"{subset}-{episode_id}_{step_id}.json") 102 | if not os.path.exists(save_path): json.dump({}, open(save_path, "w", encoding="utf-8")) 103 | 104 | prev_anno = json.load(open(save_path, "r")) 105 | # if 'action_result' not in prev_anno: return 106 | # del prev_anno['action_result'] 107 | # json.dump(prev_anno, open(save_path, "w", encoding="utf-8"), indent=4) 108 | # return 109 | 110 | update = False 111 | if 'screen_desc' not in prev_anno: 112 | print(f"Genearting [screen_desc] -> img_path {image_path}") 113 | try_num = 0 114 | while try_num < max_trials: 115 | res, state = self.observe(user_request, image_path) 116 | if state == HTTPStatus.OK: 117 | prev_anno['screen_desc'], update = res.strip(), True 118 | break 119 | try_num += 1 120 | time.sleep(5) 121 | if update: print(f"Updating [screen_desc] {save_path} ...") 122 | if update: json.dump(prev_anno, open(save_path, "w", encoding="utf-8"), indent=4) 123 | 124 | update = False 125 | if 'action_think' not in prev_anno: 126 | print(f"Genearting [action_think] -> img_path {image_path}") 127 | try_num = 0 128 | while try_num < max_trials: 129 | res, state = self.think_action(user_request, image_path, 130 | screen_desc=prev_anno['screen_desc'], 131 | 
history_actions=step_data['history_coat_actions'])
132 |                 if state == HTTPStatus.OK:
133 |                     try:
134 |                         response = json_parser(res)
135 |                         action_think = response['Thought']
136 |                         action_plan = response['Future Action Plan']
137 |                         action_desc = response['Next Single Step Action']
138 |                         prev_anno['action_think'] = action_think
139 |                         prev_anno['action_plan'] = action_plan
140 |                         prev_anno['action_desc'] = action_desc
141 |                         update = True
142 |                         break
143 |                     except Exception as e:
144 |                         if not isinstance(e, json.decoder.JSONDecodeError): print(traceback.format_exc())
145 |                         else: print(res,"\n", save_path, "\n", sep="")
146 |                         break
147 |                 print(f"Trial {try_num} failed! -- Status {state}")
148 |                 try_num += 1
149 |                 time.sleep(5)
150 |             if update: print(f"Updating [action_think] {save_path} ...")
151 |             if update: json.dump(prev_anno, open(save_path, "w", encoding="utf-8"), indent=4)
152 | 
153 |         update = False
154 |         if 'action_result' not in prev_anno:
155 |             print(f"Generating [action_result] -> img_path {image_path}")
156 |             if step_data['result_action_type'] not in [10, 11]:
157 |                 try_num = 0
158 |                 cur_image_path, next_image_path = image_path, step_data['next_image_full_path']
159 |                 while try_num < max_trials:
160 |                     cur_image_height = imagesize.get(cur_image_path)[1]
161 |                     next_image_height = imagesize.get(next_image_path)[1]
162 |                     res, state = self.reflect_result(user_request,
163 |                         screen_path=cur_image_path,
164 |                         next_screen_path=next_image_path,
165 |                         last_action=step_data['coat_action_desc'])
166 |                     if state == HTTPStatus.OK:
167 |                         prev_anno['action_result'], update = res, True
168 |                         break
169 |                     elif state == HTTPStatus.REQUEST_ENTITY_TOO_LARGE:
170 |                         cur_image_path = self.compress_image(image_path, max_height=int(0.8*cur_image_height))
171 |                         next_image_path = self.compress_image(step_data['next_image_full_path'], max_height=int(0.8*next_image_height))
172 |                     else: print(res)
173 |                     print(f"Trial {try_num} failed! -- Status {state}")
174 |                     try_num += 1
175 |                     time.sleep(5)
176 |             else:
177 |                 prev_anno['action_result'] = "The execution of the user request is stopped."
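                # Note: the `result_action_type not in [10, 11]` check above skips the
                # reflection call for terminal steps (these ids presumably correspond to
                # STATUS_TASK_COMPLETE / STATUS_TASK_IMPOSSIBLE in the AitW-style action
                # schema), since a terminal step has no meaningful next screenshot to
                # compare against; a fixed action_result string is stored instead.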
178 | update = True 179 | if update: print(f"Updating [action_result] {save_path} ...") 180 | if update: json.dump(prev_anno, open(save_path, "w", encoding="utf-8"), indent=4) 181 | 182 | def compress_image(self, image_path, max_height=1024, mode="scale"): 183 | org_img_path, ext = os.path.splitext(image_path) 184 | 185 | screen_img = Image.open(image_path) 186 | image_width, image_height = screen_img.size 187 | 188 | if image_height > max_height: 189 | scale_percent = max_height / image_height 190 | height = int(scale_percent * image_height) 191 | width = int(scale_percent * image_width) 192 | screen_img = screen_img.resize((width, height)) 193 | 194 | if mode == "scale": 195 | save_img_path = org_img_path+"_tmp"+ext 196 | screen_img.save(save_img_path) 197 | elif mode == "jpg": 198 | screen_img = screen_img.convert("RGB") 199 | save_img_path = org_img_path+"_tmp"+".jpeg" 200 | screen_img.save(save_img_path) 201 | else: 202 | save_img_path = image_path 203 | 204 | return save_img_path 205 | 206 | 207 | class ScreenAgent(BaseAgent): 208 | 209 | def __init__(self, config) -> None: 210 | super().__init__(config) 211 | 212 | self.use_screen_tag = self.cfg.DATA.USE_SCREEN_TAG 213 | self.use_screen_txt = (not self.use_screen_tag) 214 | self.screen_mode = "txt" if self.use_screen_txt else "tag" 215 | pass 216 | 217 | def format_bbox(self, bbox_xyxy, width, height): 218 | coord_sys, rep_type = self.cfg.DATA.BBOX_FORMAT.split("_") 219 | 220 | if coord_sys == "relative": 221 | bbox = np.array(bbox_xyxy) / np.array([width, height, width, height]) 222 | if rep_type == "int": 223 | bbox = (bbox*1000).astype(np.int32).tolist() 224 | else: 225 | bbox = [round(x, 3) for x in bbox.tolist()] 226 | elif coord_sys == "absolute": bbox = bbox_xyxy 227 | else: raise NotImplementedError 228 | box_str = json.dumps(bbox) 229 | return box_str 230 | 231 | def add_screen_tag(self, step_data, save_dir, update=False): 232 | """ Set of Mark Prompting """ 233 | src_image_path = step_data['image_full_path'] 234 | base_name, ext = os.path.splitext(src_image_path) 235 | dst_image_path = base_name + "_tag" + ext 236 | # image_name = os.path.basename(src_image_path) 237 | # dst_image_path = os.path.join(save_dir, image_name) 238 | 239 | if not os.path.exists(dst_image_path) or update: 240 | ui_bboxes = [] 241 | for ui_node in step_data['ui_elements']: 242 | ui_bboxes.append(ui_node['bounds']) 243 | 244 | tag_image = draw_bbox(src_image_path, bboxes=ui_bboxes, 245 | texts=list(map(str, range(len(step_data['ui_positions'])))), 246 | rgba_color=(255, 0, 0, 180), thickness=5) 247 | tag_image.save(dst_image_path) 248 | return dst_image_path 249 | 250 | def add_screen_txt(self, step_data, indent=2): 251 | """ Textual Representation of Screen Elements """ 252 | image_path = step_data['image_full_path'] 253 | w, h = imagesize.get(image_path) 254 | 255 | screen_str = [] 256 | for ui_node in step_data['ui_elements']: 257 | ui_str = " "*indent + ui_node['type'].upper() 258 | if ui_node['text']: ui_str += " " + ui_node['text'] 259 | bbox_str = self.format_bbox(ui_node['bounds'], w, h) 260 | ui_str += " " + bbox_str 261 | screen_str.append(ui_str) 262 | 263 | return "\n".join(screen_str) 264 | 265 | def make_action(self, step_data, usr_prompt, save_dir, asst_prompt=None): 266 | action_mode = self.cfg.DEMO_MODE.upper() 267 | image_path = step_data['image_full_path'] 268 | 269 | ui_elements = [] 270 | for org_bbox, txt, ui_class in zip( 271 | step_data['ui_positions'], step_data['ui_text'], step_data['ui_types']): 272 | ymin, xmin, h, w = 
org_bbox 273 | bbox = [xmin, ymin, xmin+w, ymin+h] 274 | ui_elements.append({"bounds": bbox, "text": txt, "type": ui_class}) 275 | ui_elements = row_col_sort(ui_elements) 276 | step_data['ui_elements'] = ui_elements 277 | 278 | if self.use_screen_tag: 279 | sys_prompt = self.prompts['ACTION_PREDICT'][action_mode]['SYSTEM']['SCREEN_TAG'] 280 | action_space = self.prompts['ACTION_PREDICT']['ACTION_SPACE']['SCREEN_TAG'] 281 | sys_prompt = sys_prompt.replace("{action_space}", action_space) 282 | tag_image_path = self.add_screen_tag(step_data, save_dir=save_dir) 283 | 284 | if self.cfg.MODEL.NAME == "openai": 285 | image_path = self.compress_image(image_path, max_height=1440, mode="jpg") 286 | tag_image_path = self.compress_image(tag_image_path, max_height=1440, mode="scale") 287 | 288 | prompt = [ 289 | ("system", sys_prompt), 290 | ("user", {"text": usr_prompt[0], "img_index": 0}), 291 | ("user", {"text": "", "img_index": 1}) 292 | ] 293 | for txt in usr_prompt[1:]: prompt.append(("user", {"text": txt})) 294 | if asst_prompt: prompt.append(("assistant", {"text": asst_prompt})) 295 | asst_res, state = self.vlm.get_response([image_path, tag_image_path], prompt) 296 | 297 | if self.use_screen_txt: 298 | sys_prompt = self.prompts['ACTION_PREDICT'][action_mode]['SYSTEM']['SCREEN_TXT'] 299 | action_space = self.prompts['ACTION_PREDICT']['ACTION_SPACE']['SCREEN_TXT'] 300 | sys_prompt = sys_prompt.replace("{action_space}", action_space) 301 | screen_txt = self.add_screen_txt(step_data) 302 | 303 | prompt = [ 304 | ("system", sys_prompt), 305 | ("user", {"text": usr_prompt[0], "img_index": 0}), 306 | ("user", {"text": f": {screen_txt}"}) 307 | ] 308 | for txt in usr_prompt[1:]: prompt.append(("user", {"text": txt})) 309 | if asst_prompt: prompt.append(("assistant", {"text": asst_prompt})) 310 | asst_res, state = self.vlm.get_response([image_path], prompt) 311 | 312 | asst_head = asst_prompt if asst_prompt else "" 313 | return asst_head, asst_res, state 314 | 315 | def coa_action(self, step_data, *args, **kwargs): 316 | """ Chain of Action """ 317 | user_request, image_path = step_data['instruction'], step_data['image_full_path'] 318 | history_actions = step_data['history_actions'] 319 | 320 | if history_actions: history_actions = ", ".join(history_actions) 321 | else: history_actions = "None" 322 | history_actions = history_actions + "." 
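        # Chain-of-Action prompting: the model only sees the raw screenshot, the
        # serialized history actions and the user request (no pre-filled assistant
        # prefix, unlike cot_action / coat_action below). As an illustrative example,
        # a history of ["open the app drawer", "click the Settings icon"] is
        # serialized above as "open the app drawer, click the Settings icon."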
323 | 324 | usr_prompt = self.prompts['ACTION_PREDICT']['COA']['USER'] 325 | usr_prompt = usr_prompt.replace("{history_actions}", history_actions) 326 | usr_prompt = usr_prompt.replace("{user_request}", user_request) 327 | usr_prompt = [x for x in usr_prompt.split("{screenshot}")] 328 | 329 | return self.make_action(step_data, usr_prompt, kwargs['save_dir']) 330 | 331 | def cot_action(self, step_data, *args, **kwargs): 332 | """ Chain of Thought """ 333 | user_request, image_path = step_data['instruction'], step_data['image_full_path'] 334 | acton_think = kwargs['action_think'] 335 | 336 | usr_prompt = self.prompts['ACTION_PREDICT']['COT']['USER'] 337 | usr_prompt = usr_prompt.replace("{user_request}", user_request) 338 | usr_prompt = [x for x in usr_prompt.split("{screenshot}")] 339 | 340 | if self.cfg.MODEL.NAME == "qwenvl": 341 | asst_prompt = json.dumps({"THINK": acton_think}, indent=4) 342 | else: 343 | asst_prompt = self.prompts['ACTION_PREDICT']['COT']['ASST'] 344 | asst_prompt = asst_prompt.replace("{action_thought}", json.dumps(acton_think)) 345 | 346 | return self.make_action(step_data, usr_prompt, kwargs['save_dir'], asst_prompt=asst_prompt) 347 | 348 | def coat_action(self, step_data, *args, **kwargs): 349 | """ Chain of Action Thought """ 350 | user_request, image_path = step_data['instruction'], step_data['image_full_path'] 351 | history_actions = step_data['history_coat_actions'] 352 | 353 | screen_desc = kwargs['screen_desc'] 354 | prev_action_result = kwargs.get('prev_action_result', None) 355 | acton_think, next_action_desc = kwargs['action_think'], kwargs['action_desc'] 356 | 357 | usr_prompt = self.prompts['ACTION_PREDICT']['COAT']['USER'] 358 | usr_prompt = usr_prompt.split("\n") 359 | if not history_actions: usr_prompt = [x for x in usr_prompt if 'history_actions' not in x] 360 | if not prev_action_result: usr_prompt = [x for x in usr_prompt if 'prev_action_result' not in x] 361 | usr_prompt = "\n".join(usr_prompt) 362 | 363 | usr_prompt = usr_prompt.replace("{screen_desc}", screen_desc) 364 | if history_actions: history_actions = ", ".join(history_actions) 365 | else: history_actions = "None" 366 | history_actions = history_actions + "." 
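        # Template lines that reference {history_actions} or {prev_action_result} were
        # dropped above when those fields are empty, so the replace() calls below only
        # fill the placeholders that survived the filtering.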
367 | usr_prompt = usr_prompt.replace("{history_actions}", history_actions) 368 | if prev_action_result: 369 | usr_prompt = usr_prompt.replace("{prev_action_result}", prev_action_result) 370 | usr_prompt = usr_prompt.replace("{user_request}", user_request) 371 | usr_prompt = [x for x in usr_prompt.split("{screenshot}")] 372 | 373 | if self.cfg.MODEL.NAME == "qwenvl": 374 | asst_prompt = json.dumps({"THINK": acton_think, "NEXT": next_action_desc,}, indent=4) 375 | else: 376 | asst_prompt = self.prompts['ACTION_PREDICT']['COAT']['ASST'] 377 | asst_prompt = asst_prompt.replace("{action_thought}", json.dumps(acton_think)) 378 | asst_prompt = asst_prompt.replace("{next_single_action}", json.dumps(next_action_desc)) 379 | 380 | return self.make_action(step_data, usr_prompt, kwargs['save_dir'], asst_prompt=asst_prompt) 381 | 382 | def predict(self, step_data, save_dir, max_trials=5): 383 | subset, episode_id = step_data['subset'], step_data['episode_id'] 384 | prev_step_id, step_id = step_data['prev_step_id'], step_data['step_id'] 385 | 386 | save_dir = os.path.join(save_dir, f"{subset}-{episode_id}") 387 | 388 | cur_save_path = os.path.join(save_dir, f"{subset}-{episode_id}_{step_id}.json") 389 | cur_anno = json.load(open(cur_save_path, "r")) 390 | if prev_step_id: 391 | prev_save_path = os.path.join(save_dir, f"{subset}-{episode_id}_{prev_step_id}.json") 392 | prev_anno = json.load(open(prev_save_path, "r")) 393 | cur_anno['prev_action_result'] = prev_anno['action_result'] 394 | 395 | action_mode = self.cfg.DEMO_MODE.upper() 396 | if action_mode == "COA": func_handler = self.coa_action 397 | elif action_mode == "COT": func_handler = self.cot_action 398 | elif action_mode == "COAT": func_handler = self.coat_action 399 | else: raise NotImplementedError 400 | 401 | if 'action_predict' not in cur_anno: cur_anno['action_predict'] = {} 402 | if action_mode not in cur_anno['action_predict']: cur_anno['action_predict'][action_mode] = {} 403 | # if action_mode in cur_anno['action_predict']: 404 | # anno = cur_anno['action_predict'][action_mode] 405 | # cur_anno['action_predict'][action_mode] = {} 406 | # cur_anno['action_predict'][action_mode][self.screen_mode] = anno 407 | # json.dump(cur_anno, open(cur_save_path, "w", encoding="utf-8"), indent=4) 408 | 409 | if self.screen_mode not in cur_anno['action_predict'][action_mode]: 410 | print(f"[Mode {action_mode}][{self.screen_mode}] Gnerating action ... ({cur_save_path})") 411 | try_num = 0 412 | while try_num < max_trials: 413 | asst_head, asst_res, state = func_handler(step_data, **cur_anno, save_dir=save_dir) 414 | if state == HTTPStatus.OK: 415 | try: 416 | try: response = json_parser(asst_res) 417 | except json.decoder.JSONDecodeError as e: 418 | try: response = json_parser(asst_head + asst_res) 419 | except: 420 | try:response = json_parser(asst_res + "}") 421 | except: response = json_parser(asst_head + asst_res + "}") 422 | cur_anno['action_predict'][action_mode][self.screen_mode] = { 423 | "ACTION": response["ACTION"], "ARGS": response["ARGS"]} 424 | 425 | print(f"Updating [{action_mode}][{self.screen_mode}] {cur_save_path} ...") 426 | json.dump(cur_anno, open(cur_save_path, "w", encoding="utf-8"), indent=4) 427 | except json.decoder.JSONDecodeError as e: 428 | print('-'*50 + "\n", asst_head, sep=" ") 429 | print('-'*50 + "\n", asst_res, sep=" ") 430 | break 431 | elif state == HTTPStatus.REQUEST_ENTITY_TOO_LARGE: 432 | print(f"Trial {try_num} failured! -- Status {state}") 433 | break 434 | print(f"Trial {try_num} failured! 
-- Status {state}") 435 | try_num += 1 436 | time.sleep(10) 437 | 438 | return None 439 | -------------------------------------------------------------------------------- /coat/config.yaml: -------------------------------------------------------------------------------- 1 | DEMO_NAME: "CoATAgent" 2 | DEMO_MODE: "COA" 3 | OUTPUT_DIR: "android-in-the-zoo/api" 4 | 5 | DATA: 6 | DATA_DIR: "android-in-the-zoo" 7 | SPLIT: "test" 8 | USE_SCREEN_TAG: True 9 | USE_SCREEN_TXT: False 10 | BBOX_FORMAT: "relative_int" 11 | 12 | MODEL: 13 | NAME: "command_line" 14 | OPENAI_API_URL: "" 15 | OPENAI_API_KEY: "" 16 | GEMINI_API_KEY: "" 17 | GEMINI_MODEL: gemini-pro-vision 18 | DASHSCOPE_API_KEY: "" 19 | DASHSCOPE_MODEL: qwen-vl-max 20 | 21 | PROMPTS: 22 | SCREEN_DESC: 23 | SYSTEM: |- 24 | You are a smart and helpful visual assistant that is well trained to describe smartphone screenshots. 25 | - You are provided with a screenshot of the current mobile phone. 26 | - You are required to describe this screen about its main content and its functionality. The output must be less than five sentences. 27 | - You are required to keep the description as concise and brief as possible. 28 | USER: |- 29 | : {screenshot} 30 | : 31 | 32 | ACTION_THINK_DESC: 33 | SYSTEM: |- 34 | You are a smart and helpful visual assistant that is well trained to manipulate mobile phones. 35 | Your task is to navigate and take action on the current screen step-by-step to complete the user request. 36 | - You are provided with a screenshot of the current mobile phone, together with the textual screen description. 37 | - You are provided with your history actions to decide on your next action. You can backtrack to revise the previous actions when neessary. 38 | - You are required to analyze the task status and detail a reasonable future action plan to accomplish the user request. 39 | - You are required to select the next single-step action based on your analysis and action plan. 40 | 41 | ## Analysis Guidelines 42 | - You should check whether the history actions have accomplish the user request. 43 | - You should check the apps, icons, and buttons that are visible on the current screen and might pertain to the user request. 44 | - You should combine the above information and describe your future action plan. But DO NOT include any additional actions beyond the completion of the user request. 45 | 46 | ## Output Format 47 | - You are required to response in a JSON format, consisting of 3 distinct parts with the following keys and corresponding content: 48 | { 49 | "Thought": , 50 | "Future Action Plan": , 51 | "Next Single Step Action": 52 | } 53 | - You can not output anything except for the above JSON. 54 | 55 | ## Output Example 56 | { 57 | "Thought": "...", 58 | "Future Action Plan": ["...", "..."], 59 | "Next Single Step Action": "..." 60 | } 61 | 62 | USER: |- 63 | : {screenshot} 64 | : {screen_desc} 65 | : {history_actions} 66 | : {user_request} 67 | : 68 | 69 | ACTION_PREDICT: 70 | COA: 71 | SYSTEM: 72 | SCREEN_TXT: |- 73 | You are a smart and helpful visual assistant that is well trained to manipulate mobile phones. 74 | Your task is to navigate on the current screen to complete the user request. 75 | - You are provided with a screenshot of the current mobile phone. 76 | - You are provided with a textual description of elements on the screen. Each element is assigned a category represented in uppercase letters. For "TEXT" element, the corresponding text on the screen is also included. 
Each element also has a bounding box, composed of four coordinates [xmin, ymin, xmax, ymax] ranging from 0 to 999, representing the proportion in width or height. 77 | - You are provided with history actions trying to accompolish the user request. 78 | - You are required to decide on the next single-step valid action to conduct on the current screen. 79 | 80 | ## Valid Actions on the screen 81 | {action_space} 82 | 83 | ## Output Format 84 | - You must choose one of the valid apis provided above and response in the corresponding API call format. 85 | - Your response should be strictly structured in JSON format, consisting of the following keys and corresponding content: 86 | { 87 | "ACTION": , 88 | "ARGS": 89 | } 90 | 91 | ## Output Example 1: 92 | { 93 | "ACTION": "click_element", 94 | "ARGS": {"bbox": [100, 345, 219, 826]} 95 | } 96 | 97 | ## Output Example 2: 98 | { 99 | "ACTION": "scroll", 100 | "ARGS": {"direction": "down"} 101 | } 102 | 103 | ## Output Example 3: 104 | { 105 | "ACTION": "press_home", 106 | "ARGS": {} 107 | } 108 | 109 | SCREEN_TAG: |- 110 | You are a smart and helpful visual assistant that is well trained to manipulate mobile phones. 111 | Your task is to navigate on the current screen to complete the user request. 112 | - You are provided with two screenshot of the current mobile phone, one without the annotation (first) and one with ui elements annotated (second). 113 | - You are provided with history actions trying to accompolish the user request. 114 | - You are required to decide on the next single-step valid action to conduct on the current screen. 115 | 116 | ## Valid Actions on the screen 117 | {action_space} 118 | 119 | ## Output Format 120 | - You must choose one of the valid apis provided above and response in the corresponding API call format. 121 | - Your response should be strictly structured in JSON format, consisting of the following keys and corresponding content: 122 | { 123 | "ACTION": , 124 | "ARGS": 125 | } 126 | 127 | ## Output Example 1: 128 | { 129 | "ACTION": "click_element", 130 | "ARGS": {"idx": 6} 131 | } 132 | 133 | ## Output Example 2: 134 | { 135 | "ACTION": "scroll", 136 | "ARGS": {"direction": "down"} 137 | } 138 | 139 | ## Output Example 3: 140 | { 141 | "ACTION": "press_home", 142 | "ARGS": {} 143 | } 144 | 145 | USER: |- 146 | : {screenshot} 147 | : {history_actions} 148 | : {user_request} 149 | : 150 | 151 | HINT: 152 | SYSTEM: |- 153 | 154 | USER: |- 155 | : {screenshot} 156 | : {next_single_action} 157 | : {user_request} 158 | : 159 | 160 | COT: 161 | SYSTEM: 162 | SCREEN_TXT: |- 163 | You are a smart and helpful visual assistant that is well trained to manipulate mobile phones. 164 | Your task is to navigate on the current screen to complete the user request. 165 | - You are provided with a screenshot of the current mobile phone. 166 | - You are provided with a textual description of elements on the screen. Each element is assigned a category represented in uppercase letters. For "TEXT" element, the corresponding text on the screen is also included. Each element also has a bounding box, composed of four coordinates [xmin, ymin, xmax, ymax] ranging from 0 to 999, representing the proportion in width or height. 167 | - You are required to decide on the next single-step valid action to conduct on the current screen. 168 | 169 | ## Valid Actions on the screen 170 | {action_space} 171 | 172 | ## Output Format 173 | - You must choose one of the valid apis provided above and response in the corresponding API call format. 
174 | - Your response should be strictly structured in JSON format, consisting of the following keys and corresponding content: 175 | { 176 | "THINK": , 177 | "ACTION": , 178 | "ARGS": 179 | } 180 | 181 | ## Output Example 1: 182 | { 183 | "THINK": "...", 184 | "ACTION": "click_element", 185 | "ARGS": {"bbox": [100, 345, 219, 826]} 186 | } 187 | 188 | ## Output Example 2: 189 | { 190 | "THINK": "...", 191 | "ACTION": "scroll", 192 | "ARGS": {"direction": "down"} 193 | } 194 | 195 | ## Output Example 3: 196 | { 197 | "THINK": "...", 198 | "ACTION": "press_home", 199 | "ARGS": {} 200 | } 201 | 202 | SCREEN_TAG: |- 203 | You are a smart and helpful visual assistant that is well trained to manipulate mobile phones. 204 | Your task is to navigate on the current screen to complete the user request. 205 | - You are provided with two screenshot of the current mobile phone, one without the annotation (first) and one with ui elements annotated (second). 206 | - You are required to decide on the next single-step valid action to conduct on the current screen. 207 | 208 | ## Valid Actions on the screen 209 | {action_space} 210 | 211 | ## Output Format 212 | - You must choose one of the valid apis provided above and response in the corresponding API call format. 213 | - Your response should be strictly structured in JSON format, consisting of the following keys and corresponding content: 214 | { 215 | "THINK": , 216 | "ACTION": , 217 | "ARGS": 218 | } 219 | 220 | ## Output Example 1: 221 | { 222 | "THINK": "...", 223 | "ACTION": "click_element", 224 | "ARGS": {"idx": 6} 225 | } 226 | 227 | ## Output Example 2: 228 | { 229 | "THINK": "...", 230 | "ACTION": "scroll", 231 | "ARGS": {"direction": "down"} 232 | } 233 | 234 | ## Output Example 3: 235 | { 236 | "THINK": "...", 237 | "ACTION": "press_home", 238 | "ARGS": {} 239 | } 240 | 241 | USER: |- 242 | : {screenshot} 243 | : {user_request} 244 | : 245 | 246 | ASST: |- 247 | { 248 | "THINK": {action_thought}, 249 | "ACTION": 250 | 251 | COAT: 252 | SYSTEM: 253 | SCREEN_TXT: |- 254 | You are a smart and helpful visual assistant that is well trained to manipulate mobile phones. 255 | Your task is to navigate on the current screen to complete the user request. 256 | - You are provided with a screenshot of the current mobile phone. 257 | - You are provided with a textual description of elements on the screen. Each element is assigned a category represented in uppercase letters. For "TEXT" element, the corresponding text on the screen is also included. Each element also has a bounding box, composed of four coordinates [xmin, ymin, xmax, ymax] ranging from 0 to 999, representing the proportion in width or height. 258 | - You are provided with a breif summarization of the screen content. 259 | - You are provided with history actions trying to accompolish the user request, together with the previous action result that indicates how current screenshot is obtained. 260 | - You are required to decide on the next single-step valid action to be conducted on the current screen so as to fulfill the user request. 261 | 262 | ## Valid Actions on the screen 263 | {action_space} 264 | 265 | ## Output Format 266 | - You must choose one of the valid apis provided above and response in the corresponding API call format. 
267 | - Your response should be strictly structured in JSON format, consisting of the following keys and corresponding content: 268 | { 269 | "THINK": , 270 | "NEXT": , 271 | "ACTION": , 272 | "ARGS": 273 | } 274 | 275 | ## Output Example 1: 276 | { 277 | "THINK": "...", 278 | "NEXT": "...", 279 | "ACTION": "click_element", 280 | "ARGS": {"bbox": [100, 345, 219, 826]} 281 | } 282 | 283 | ## Output Example 2: 284 | { 285 | "THINK": "...", 286 | "NEXT": "...", 287 | "ACTION": "scroll", 288 | "ARGS": {"direction": "down"} 289 | } 290 | 291 | ## Output Example 3: 292 | { 293 | "THINK": "...", 294 | "NEXT": "...", 295 | "ACTION": "press_home", 296 | "ARGS": {} 297 | } 298 | 299 | SCREEN_TAG: |- 300 | You are a smart and helpful visual assistant that is well trained to manipulate mobile phones. 301 | Your task is to navigate on the current screen to complete the user request. 302 | - You are provided with two screenshot of the current mobile phone, one without the annotation (first) and one with ui elements annotated (second). 303 | - You are provided with a breif summarization of the screen content. 304 | - You are provided with history actions trying to accompolish the user request, together with the previous action result that indicates how current screenshot is obtained. 305 | - You are required to decide on the next single-step valid action to be conducted on the current screen so as to fulfill the user request. 306 | 307 | ## Valid Actions on the screen 308 | {action_space} 309 | 310 | ## Output Format 311 | - You must choose one of the valid apis provided above and response in the corresponding API call format. 312 | - Your response should be strictly structured in JSON format, consisting of the following keys and corresponding content: 313 | { 314 | "THINK": , 315 | "NEXT": , 316 | "ACTION": , 317 | "ARGS": 318 | } 319 | 320 | ## Output Example 1: 321 | { 322 | "THINK": "...", 323 | "NEXT": "...", 324 | "ACTION": "click_element", 325 | "ARGS": {"idx": 6} 326 | } 327 | 328 | ## Output Example 2: 329 | { 330 | "THINK": "...", 331 | "NEXT": "...", 332 | "ACTION": "scroll", 333 | "ARGS": {"direction": "down"} 334 | } 335 | 336 | ## Output Example 3: 337 | { 338 | "THINK": "...", 339 | "NEXT": "...", 340 | "ACTION": "press_home", 341 | "ARGS": {} 342 | } 343 | 344 | USER: |- 345 | : {screenshot} 346 | : {screen_desc} 347 | : {history_actions} 348 | : {prev_action_result} 349 | : {user_request} 350 | : 351 | 352 | ASST: |- 353 | { 354 | "THINK": {action_thought}, 355 | "NEXT": {next_single_action}, 356 | "ACTION": 357 | 358 | ACTION_SPACE: 359 | 360 | SCREEN_TXT: |- 361 | CLICK_ELEMENT: 362 | summary: click on the visible element on the screen 363 | usage: 364 | [1] API call: click_element(bbox=) 365 | [2] Args: 366 | - bbox: 'The bounding box of the element to be clicked, formed as [x1, y1, x2, y2].' 367 | [3] Example: click_element(bbox=[10, 20, 30, 40]) 368 | [4] Return: None 369 | 370 | SCROLL: 371 | summary: move the scrollable content, or open the app drawer 372 | explanation: 373 | - Scrolling down typically means moving scrollable content to see what's further down. If content is moving up, you are scrolling down. 374 | - Scrolling up could either open the app drawer, or move back to view the previous content on the same screen. 375 | usage: 376 | [1] API call: scroll(direction=) 377 | [2] Args: 378 | - direction: 'Scroll direction. One of ''up'', ''down'', ''left'', ''right''.' 
379 | [3] Example: scroll(direction="up") 380 | [4] Return: None 381 | 382 | INPUT: 383 | summary: input text to an editable input area 384 | usage: 385 | [1] API call: input(text="") 386 | [2] Args: 387 | - text: 'The text input to the editable input area.' 388 | [3] Example: input(text="Hello World") 389 | [4] Return: None 390 | 391 | PRESS_ENTER: 392 | summary: confirm the input, or submit the input, or start a new line of text 393 | usage: 394 | [1] API call: press_enter() 395 | [2] Args: None 396 | [3] Example: press_enter() 397 | [4] Return: None 398 | 399 | PRESS_HOME: 400 | summary: directly move back to the home screen 401 | usage: 402 | [1] API call: press_home() 403 | [2] Args: None 404 | [3] Example: press_home() 405 | [4] Return: None 406 | 407 | PRESS_BACK: 408 | summary: return to the most recently visited screen or interface 409 | usage: 410 | [1] API call: press_back() 411 | [2] Args: None 412 | [3] Example: press_back() 413 | [4] Return: None 414 | 415 | STOP: 416 | summary: stop and set the state of the task 417 | usage: 418 | [1] API call: stop(task_status=) 419 | [2] Args: 420 | - task_status: 'Decide on whether the task is a success or failure. Choose one of ''success'', ''failure''.' 421 | [3] Example: stop(task_status="success") 422 | [4] Return: None 423 | 424 | SCREEN_TAG: |- 425 | CLICK_ELEMENT: 426 | summary: click on the visible ui element on the screen 427 | usage: 428 | [1] API call: click_element(idx=) 429 | [2] Args: 430 | - idx: 'The index the element to be clicked, as shown on the screen.' 431 | [3] Example: click_element(idx=10) 432 | [4] Return: None 433 | 434 | SCROLL: 435 | summary: move the scrollable content, or open the app drawer 436 | explanation: 437 | - Scrolling down typically means moving scrollable content to see what's further down. If content is moving up, you are scrolling down. 438 | - Scrolling up could either open the app drawer, or move back to view the previous content on the same screen. 439 | usage: 440 | [1] API call: scroll(direction=) 441 | [2] Args: 442 | - direction: 'Scroll direction. One of ''up'', ''down'', ''left'', ''right''.' 443 | [3] Example: scroll(direction="up") 444 | [4] Return: None 445 | 446 | INPUT: 447 | summary: input text to an editable input area 448 | usage: 449 | [1] API call: input(text="") 450 | [2] Args: 451 | - text: 'The text input to the editable input area.' 452 | [3] Example: input(text="Hello World") 453 | [4] Return: None 454 | 455 | PRESS_ENTER: 456 | summary: confirm the input, or submit the input, or start a new line of text 457 | usage: 458 | [1] API call: press_enter() 459 | [2] Args: None 460 | [3] Example: press_enter() 461 | [4] Return: None 462 | 463 | PRESS_HOME: 464 | summary: directly move back to the home screen 465 | usage: 466 | [1] API call: press_home() 467 | [2] Args: None 468 | [3] Example: press_home() 469 | [4] Return: None 470 | 471 | PRESS_BACK: 472 | summary: return to the most recently visited screen or interface 473 | usage: 474 | [1] API call: press_back() 475 | [2] Args: None 476 | [3] Example: press_back() 477 | [4] Return: None 478 | 479 | STOP: 480 | summary: stop and set the state of the task 481 | usage: 482 | [1] API call: stop(task_status=) 483 | [2] Args: 484 | - task_status: 'Decide on whether the task is a success or failure. Choose one of ''success'', ''failure''.' 
485 |             [3] Example: stop(task_status="success")
486 |             [4] Return: None
487 | 
488 |   ACTION_RESULT:
489 |     SYSTEM: |-
490 |       You are a smart and helpful visual assistant that is well trained to manipulate mobile phones.
491 |       Your task is to explain the results of the last action, and its influence towards the completion of the user request.
492 |       - You are provided with two screenshots of the mobile phone, one from the last step and one from the current screen.
493 |       - You are provided with the details of the last action you performed on the last screenshot.
494 |       - You are required to describe the direct consequences of the last action by comparing the two screenshots.
495 |       - You are required to judge whether this action has made progress towards the user request.
496 | 
497 |       ## Example 1: "By clicking on the search bar, now I can enter the query into the input area. This helps to complete the user query, because one has to use a search engine to find the cheapest chair in Walmart."
498 |       ## Example 2: "By pressing the home button, I have moved back to the home screen. This is necessary because, to install the app required by the user, we must find the App Store first."
499 | 
500 |     USER: |-
501 |       : {before_screenshot}
502 |       : {after_screenshot}
503 |       : {last_action}
504 |       : {user_request}
505 |       : 
506 | 
--------------------------------------------------------------------------------
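The repository's entry point run_coat.py is not reproduced above, so the snippet below is only a minimal sketch of how the pieces shown here (config.py, model.py, agent.py, config.yaml) could be wired together for a single episode. The layout of the episode JSON (a list of per-step dicts), the switch of MODEL.NAME from "command_line" to "openai", and the loop structure are assumptions; the real driver may differ, and the API keys in config.yaml must be filled in first.

import json
import yaml
from yacs.config import CfgNode as CN
from coat.agent import ScreenAgent

# Build the config directly from config.yaml (it declares keys, e.g. DATA.USE_SCREEN_TAG
# and PROMPTS.ACTION_PREDICT, that the defaults in coat/config.py do not).
cfg = CN(yaml.safe_load(open("coat/config.yaml")))
cfg.MODEL.NAME = "openai"   # config.yaml ships with "command_line"; get_model()
                            # only accepts "openai", "gemini" or "qwenvl"

agent = ScreenAgent(cfg)

# Assumed layout: the episode JSON is a list of per-step dicts carrying the fields
# accessed in agent.py (subset, episode_id, step_id, prev_step_id, instruction,
# image_full_path, next_image_full_path, ui_positions, ui_text, ui_types, ...).
episode = json.load(open(
    "data-example/GOOGLE_APPS-523638528775825151/GOOGLE_APPS-523638528775825151.json"))

for step_data in episode:
    agent.flow(step_data, save_dir=cfg.OUTPUT_DIR)      # screen_desc / action_think / action_result
    agent.predict(step_data, save_dir=cfg.OUTPUT_DIR)   # next-action prediction (COA / COT / COAT)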