├── assets
├── overview.png
└── first_image.png
├── scripts
└── server.sh
├── .gitmodules
├── train
└── sft.yaml
├── LICENSE
├── deploy
├── main.py
├── prompt.py
├── utils.py
├── env.py
└── agent.py
├── README.md
└── postprocess
├── prepare.py
├── prompt.py
├── boost.py
├── refinement.py
└── utils.py
/assets/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GAIR-NLP/PC-Agent-E/HEAD/assets/overview.png
--------------------------------------------------------------------------------
/assets/first_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GAIR-NLP/PC-Agent-E/HEAD/assets/first_image.png
--------------------------------------------------------------------------------
/scripts/server.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Serve the PC Agent-E model with vLLM (OpenAI-compatible API) on port 8030,
# sharded across 4 GPUs via tensor parallelism. deploy/main.py connects to this endpoint.

vllm serve "henryhe0123/PC-Agent-E" --tensor-parallel-size 4 --port 8030
4 |
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "train/LLaMA-Factory"]
2 | path = train/LLaMA-Factory
3 | url = git@github.com:hiyouga/LLaMA-Factory.git
4 |
--------------------------------------------------------------------------------
/train/sft.yaml:
--------------------------------------------------------------------------------
1 | ### model
2 | model_name_or_path: Qwen/Qwen2.5-VL-72B-Instruct
3 | image_max_pixels: 1000000
4 |
5 | ### method
6 | stage: sft
7 | do_train: true
8 | finetuning_type: full
9 | freeze_vision_tower: true
10 | deepspeed: examples/deepspeed/ds_z3_config.json
11 |
12 | ### dataset
13 | dataset: pc-agent-e
14 | template: qwen2_vl
15 | cutoff_len: 8192
16 | overwrite_cache: true
17 | preprocessing_num_workers: 64
18 |
19 | ### output
20 | output_dir: saves/pc-agent-e/Qwen2.5-VL-72B-sft
21 | logging_steps: 1
22 | save_steps: 100
23 | plot_loss: true
24 | overwrite_output_dir: true
25 |
26 | ### train
27 | per_device_train_batch_size: 2
28 | gradient_accumulation_steps: 2
29 | learning_rate: 2.0e-6
30 | num_train_epochs: 2
31 | lr_scheduler_type: cosine
32 | warmup_ratio: 0.05
33 | bf16: true
34 | ddp_timeout: 180000000
35 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 SII - Generative Artificial Intelligence Research Lab (GAIR)
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/deploy/main.py:
--------------------------------------------------------------------------------
1 | # main.py
2 |
3 | from openai import OpenAI
4 | from agent import PCAgentE
5 | from env import PCEnv
6 |
# OpenAI-compatible client pointed at the local vLLM server (see scripts/server.sh).
client = OpenAI(
    api_key="EMPTY",
    base_url="http://localhost:8030/v1",
)
# Model name must match the one served by vLLM.
model = "henryhe0123/PC-Agent-E"
12 |
13 |
def run(task_description, max_steps=30):
    """Drive the agent/environment interaction loop for a single task.

    Args:
        task_description: Natural-language description of the task to perform.
        max_steps: Upper bound on agent steps, forwarded to the agent.
    """
    agent = PCAgentE(client, model, max_steps)
    environment = PCEnv()

    # The initial observation is simply a screenshot of the current screen.
    observation = environment.reset()

    while True:
        # Ask the model for the next action(s) given the latest screenshot.
        actions, logs = agent.predict(task_description, observation)
        if not actions:
            print("Agent failed to generate valid actions, terminating execution")
            return

        # Apply every proposed action to the environment, stopping when done.
        for current_action in actions:
            print(f"Executing action: {current_action}")
            observation, done = environment.step(current_action)
            if done:
                return
36 |
37 |
if __name__ == "__main__":
    # Entry point: read the task from stdin and run it with the default step budget.
    task_description = input("Please enter task description: ")
    run(task_description)
41 |
--------------------------------------------------------------------------------
/deploy/prompt.py:
--------------------------------------------------------------------------------
1 | # prompt.py
2 |
# System prompt used at deployment time: defines the action space and the
# "{thought}\n\nAction: {action}" response format the model was trained on.
# NOTE(review): "of this task is infeasible" (action 12) looks like a typo for
# "or this task is infeasible" -- kept verbatim so it matches the training prompt.
AGENT_PROMPT = """You are a helpful assistant who can help users complete computer tasks, with **full permission** to make any operations on the user's computer.
Based on the provided current state, you need to suggest the next action to complete the task. Do not try to complete the entire task in one step. Break it down into smaller steps, and at each step you will get a new state to interact with.

IMPORTANT: You must strictly adhere to the following rules:
1. Choose ONLY ONE action from the list below for each response, DO NOT perform more than one action per step.
2. Follow the exact syntax format for the selected action, DO NOT create or use any actions other than those listed.
3. Once the task is completed, output action finish.

Valid actions:

1. click (x, y)
click the element at the position (x, y) on the screen

2. right click (x, y)
right click the element at the position (x, y) on the screen

3. double click (x, y)
double click the element at the position (x, y) on the screen

4. drag from (x1, y1) to (x2, y2)
drag the element from position (x1, y1) to (x2, y2).

5. scroll (x)
scroll the screen vertically with pixel offset x. Positive values of x: scroll up, negative values of x: scroll down.

6. press key: key_content
press the key key_content on the keyboard.

7. hotkey (key1, key2)
press the hotkey composed of key1 and key2.

8. hotkey (key1, key2, key3)
press the hotkey composed of key1, key2, and key3.

9. type text: text_content
type content text_content on the keyboard.

10. wait
wait for some time, usually for the system to respond, screen to refresh, advertisement to finish.

11. finish
indicating that the task has been completed.

12. fail
indicating that the task has failed, of this task is infeasible because not enough information is provided.

Response Format: {Your thought process}\n\nAction: {The specific action you choose to take}

--------------------------------------------

"""
54 |
--------------------------------------------------------------------------------
/deploy/utils.py:
--------------------------------------------------------------------------------
1 | # utils.py
2 |
3 | import io
4 | import base64
5 |
6 |
7 | KEYBOARD_KEYS = ['\t', '\n', '\r', ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~', 'accept', 'add', 'alt', 'altleft', 'altright', 'apps', 'backspace', 'browserback', 'browserfavorites', 'browserforward', 'browserhome', 'browserrefresh', 'browsersearch', 'browserstop', 'capslock', 'clear', 'convert', 'ctrl', 'ctrlleft', 'ctrlright', 'decimal', 'del', 'delete', 'divide', 'down', 'end', 'enter', 'esc', 'escape', 'execute', 'f1', 'f10', 'f11', 'f12', 'f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f2', 'f20', 'f21', 'f22', 'f23', 'f24', 'f3', 'f4', 'f5', 'f6', 'f7', 'f8', 'f9', 'final', 'fn', 'hanguel', 'hangul', 'hanja', 'help', 'home', 'insert', 'junja', 'kana', 'kanji', 'launchapp1', 'launchapp2', 'launchmail', 'launchmediaselect', 'left', 'modechange', 'multiply', 'nexttrack', 'nonconvert', 'num0', 'num1', 'num2', 'num3', 'num4', 'num5', 'num6', 'num7', 'num8', 'num9', 'numlock', 'pagedown', 'pageup', 'pause', 'pgdn', 'pgup', 'playpause', 'prevtrack', 'print', 'printscreen', 'prntscrn', 'prtsc', 'prtscr', 'return', 'right', 'scrolllock', 'select', 'separator', 'shift', 'shiftleft', 'shiftright', 'sleep', 'space', 'stop', 'subtract', 'tab', 'up', 'volumedown', 'volumemute', 'volumeup', 'win', 'winleft', 'winright', 'yen', 'command', 'option', 'optionleft', 'optionright']
8 |
9 |
def encode_image(image):
    """Serialize a PIL image to PNG and return it as a base64 string."""
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        raw_bytes = png_buffer.getvalue()
    return base64.b64encode(raw_bytes).decode('utf-8')
15 |
16 |
def save_screenshot(screenshot, path):
    """Write *screenshot* (a PIL image) to *path* as a PNG file."""
    fmt = "PNG"
    screenshot.save(path, format=fmt)
19 |
20 |
def get_mllm_messages(instruction, base64_image):
    """Build a single-turn multimodal message list: instruction text plus one inline PNG.

    Args:
        instruction: prompt text shown to the model.
        base64_image: base64-encoded PNG screenshot.
    Returns:
        A one-element list in OpenAI chat-completions message format.
    """
    text_part = {
        "type": "text",
        "text": instruction
    }
    image_part = {
        "type": "image_url",
        "image_url": {
            "url": f"data:image/png;base64,{base64_image}"
        },
    }
    return [{"role": "user", "content": [text_part, image_part]}]
40 |
--------------------------------------------------------------------------------
/deploy/env.py:
--------------------------------------------------------------------------------
1 | # env.py
2 |
3 | import time
4 | import pyautogui
5 | from io import BytesIO
6 | from PIL import ImageGrab
7 |
class PCEnv:
    """
    PC Environment class, encapsulates the local computer environment,
    supports executing pyautogui code and capturing screenshots
    """
    def __init__(self, screenshot_size=(1280, 720)):
        """
        Initialize the environment
        Args:
            screenshot_size: Expected screenshot dimensions (width, height);
                only used to warn when the real screen size differs.
        """
        self.screenshot_size = screenshot_size
        # Ensure pyautogui has failsafe measures (moving the mouse to a screen
        # corner aborts execution)
        pyautogui.FAILSAFE = True
        print("Initializing PC Environment...")

    def step(self, action):
        """
        Execute an action and return new observation
        Args:
            action: Action to execute (pyautogui code string), or one of the
                sentinel strings "WAIT" / "DONE" / "FAIL"
        Returns:
            obs: Observation containing new screenshot (PNG bytes)
            done: Whether the task is completed
        """
        done = False

        # Handle special actions
        if action == "WAIT":
            time.sleep(3)
        elif action == "DONE":
            print("Task completed, terminating execution")
            done = True
            return {"screenshot": self.get_screenshot()}, done
        elif action == "FAIL":
            print("Task failed, terminating execution")
            done = True
            return {"screenshot": self.get_screenshot()}, done
        else:
            # Execute pyautogui code
            try:
                # SECURITY NOTE(review): exec runs arbitrary model-generated
                # Python with full local privileges. This is intentional here
                # (the agent is granted full control of this machine), but the
                # string must never come from an untrusted third party.
                # Since we've imported pyautogui at the module level,
                # exec can directly execute strings like "pyautogui.click(1, 1)"
                # The pyautogui module is available in the exec's namespace
                exec(action)
                # Wait briefly to let UI respond
                time.sleep(1)
            except Exception as e:
                # Best-effort: a failed action is logged, not fatal; the loop
                # continues with a fresh screenshot.
                print(f"Action execution failed: {e}")

        # Return new observation (screenshot)
        return {"screenshot": self.get_screenshot()}, done

    def get_screenshot(self):
        """
        Capture current screen screenshot
        Returns:
            screenshot: Binary data of the screenshot (PNG-encoded bytes)
        """
        # Take screenshot of the real screen
        screenshot = ImageGrab.grab()

        # Warning if size is not as expected
        if screenshot.size != self.screenshot_size:
            print(f"Warning: Screenshot size is not as expected. Expected {self.screenshot_size}, got {screenshot.size}")

        # Convert to binary (PNG bytes)
        buffer = BytesIO()
        screenshot.save(buffer, format='PNG')
        return buffer.getvalue()

    def reset(self):
        """
        Reset the environment
        Returns:
            obs: Observation containing new screenshot
        """
        # Reset only needs to return current screenshot as initial observation;
        # no system state is modified.
        return {"screenshot": self.get_screenshot()}
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Efficient Agent Training for Computer Use
2 |
3 |
4 | 📄 Paper |
5 | 🌐 Website |
6 | 🤖 Model |
7 | 🤗 Dataset |
8 | ⚔️ WindowsAgentArena-V2
9 |
10 |
11 |
12 |
13 |
14 |
15 | ## Demo
16 |
17 | Check out our demo of PC Agent-E autonomously controlling a computer to complete tasks on Windows and Linux systems!
18 |
19 | https://github.com/user-attachments/assets/9540d8cb-630d-41e2-a108-a96ca3fcb32e
20 |
21 | https://github.com/user-attachments/assets/18b436e7-733f-49a5-8716-25c29a990766
22 |
23 | ## Introduction
24 |
25 | We introduce **PC Agent-E**, an efficient agent training framework that elicits strong computer use capabilities with remarkable **data efficiency**.
26 | This framework is implemented with four key components:
27 | 1. **Trajectory Collection**, gathering a small set of task trajectories from human annotators with [PC Tracker](https://github.com/GAIR-NLP/PC-Agent?tab=readme-ov-file#pc-tracker);
28 | 2. **Thought Completion**, reconstructing the latent human thought process before each action;
29 | 3. **Trajectory Boost**, synthesizing diverse alternative action decisions;
30 | 4. **Agent Training**, training native agent model with augmented trajectories.
31 |
32 | 
33 |
34 | ## Main Results
35 |
36 | Table: Results of successful rate (%) for different models on [WindowsAgentArena-V2](https://github.com/GAIR-NLP/WindowsAgentArena-V2), an improved benchmark we also released.
37 |
38 | | Models | LibreOffice | Chrome | Edge | System | VS Code | VLC | Utils | Total |
39 | |--------------------------|-------------|--------|-------|--------|---------|------|--------|-------|
40 | | **Number of Tasks** | 42 | 17 | 13 | 24 | 19 | 14 | 12 | 141 |
41 | | Qwen2.5-VL-72B | 0.0 | 34.7 | 15.4 | 20.8 | 26.3 | 7.6 | 16.7 | 14.9 |
42 | | UI-TARS-1.5-7B | **7.1** | 34.7 | 23.1 | 45.8 | 21.1 | 7.6 | 16.7 | 21.3 |
43 | | UI-TARS-72B-DPO | 0.0 | 40.6 | 38.5 | 58.3 | 36.8 | 7.6 | 25.0 | 26.2 |
44 | | Claude 3.7 Sonnet | 2.4 | 46.5 | **61.5** | 54.2 | 52.6 | 29.0 | 16.7 | 32.6 |
45 | | Claude 3.7 Sonnet (thinking) | 2.4 | **64.1** | 46.2 | **66.7** | 52.6 | 21.9 | 25.0 | 35.4 |
46 | | **PC Agent-E (Ours)** | 4.8 | **64.1** | 46.2 | 50.0 | **57.9**| **35.7** | **33.3** | **36.0** |
47 |
48 | ## Quick Start
49 |
50 | ### Trajectory Collection
51 |
52 | Collect raw human trajectory with PC Tracker. See usage [here](https://github.com/GAIR-NLP/PC-Agent?tab=readme-ov-file#pc-tracker).
53 |
54 | ### Post Processing
55 |
56 | To convert raw human trajectory into high-quality trajectories for training, follow these steps:
1. Place recorded trajectories in the `data/` directory.
58 | 2. Run post processing pipeline:
59 | ```bash
60 | # Data refinement
61 | python postprocess/refinement.py
62 |
63 | # Thought completion and Trajectory Boost
64 | python postprocess/boost.py
65 | ```
66 |
67 | Note: You need to prepare your API key in advance.
68 |
69 | ### Agent Training
70 |
You can use [our dataset](https://huggingface.co/datasets/henryhe0123/PC-Agent-E) or build the dataset with the above steps on your own. To prepare data for agent training, put the dataset in the `data/` directory, and run:
72 | ```bash
73 | python postprocess/prepare.py
74 | ```
75 |
76 | We recommend using [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory) for agent training. To launch distributed training across multiple nodes, you can run:
77 |
78 | ```bash
79 | FORCE_TORCHRUN=1 NNODES=4 NODE_RANK=${PET_NODE_RANK} MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train train/sft.yaml
80 | ```
81 |
82 | Replace PET_NODE_RANK with the rank of the current node (from 0 to 3).
83 |
84 | ### Agent Deployment
85 |
86 | We provide a reference implementation of our PC Agent-E scaffold in the `deploy/` directory. To deploy our agent on your computer, run:
87 |
88 | ```bash
89 | python deploy/main.py
90 | ```
91 |
92 | Reference scripts for model deployment can be found in `scripts/server.sh`.
93 |
94 | ## Acknowledgments
95 |
96 | We would like to express our sincere gratitude to Shijie Xia for his meticulous review and constructive
97 | suggestions, which significantly improved the quality of this paper. This project is supported by SJTU SEIEE - ByteDance Large Language Model Joint Laboratory, SII.
98 |
99 | ## Citation
100 |
101 | If you find this work helpful, please consider citing:
102 |
103 | ```
104 | @misc{he2025efficientagenttrainingcomputer,
105 | title={Efficient Agent Training for Computer Use},
106 | author={Yanheng He and Jiahe Jin and Pengfei Liu},
107 | year={2025},
108 | eprint={2505.13909},
109 | archivePrefix={arXiv},
110 | primaryClass={cs.AI},
111 | url={https://arxiv.org/abs/2505.13909},
112 | }
113 | ```
114 |
--------------------------------------------------------------------------------
/postprocess/prepare.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import copy
4 | import traceback
5 | from prompt import AGENT_PROMPT
6 | from utils import refine_response, refine_thought, combine_thought_action_to_response, get_history_str
7 |
8 |
# Output path for the aggregated LLaMA-Factory training file.
# Fix: removed a stray f-string prefix (the literal has no placeholders).
output_file = "train/LLaMA-Factory/data/pc-agent-e.json"
# Accumulates every formatted training example across all processed tasks.
all_data = []
BOOST = True   # include synthesized (Trajectory Boost) responses
HUMAN = True   # include the original human responses
# NOTE(review): name contains a typo ("FISISH" -> "FINISH"); kept unchanged
# because check_boost_response reads this module-level global by this name.
REMOVE_NO_FISISH = True
BOOST_CNT = 9  # max boost responses consumed per step
15 |
16 |
def get_instruction(task_description, action_history):
    """Compose the full user prompt for one training example.

    Args:
        task_description: natural-language task the trajectory accomplishes.
        action_history: formatted summary of prior thoughts/actions (may be empty).
    Returns:
        The agent system prompt followed by the task, the history, and the
        final question about the next step.
    """
    prompt = AGENT_PROMPT + f"Your task is: {task_description}\n\n"
    prompt += f"History of the previous actions and thoughts you have done to reach the current screen: {action_history}\n\n"
    prompt += "--------------------------------------------\n\n"
    # Fix: removed a stray f-string prefix (the literal has no placeholders).
    prompt += "Given the screenshot. What's the next step that you will do to help with the task?"
    return prompt
23 |
24 |
def check_boost_response(boost_response, action):
    """Return True when a synthesized (boost) response is usable for training.

    A response is rejected when it is missing, when the ground-truth action is
    "finish" but the response does not finish (if REMOVE_NO_FISISH is set), or
    when it still contains ungrounded coordinate placeholders like "(x, y)".
    """
    if boost_response is None:
        return False

    # The final step of a trajectory must also end with "finish" in the boost.
    # Fix: idiomatic `not in` instead of `not "finish" in ...`.
    if REMOVE_NO_FISISH and action == "finish" and "finish" not in boost_response:
        # print(f"last action for boost is not finish, remove it!")
        return False

    # Placeholder coordinates mean the model never grounded the click position.
    if "(x, y)" in boost_response or "(x,y)" in boost_response:
        return False

    return True
36 |
37 |
def process_task_jsonl_file(file_path, dir_path, task_description):
    """Convert one recorded task trajectory (.jsonl) into training examples.

    Each line of the file is one step: a screenshot path, the action taken,
    the completed thought, and (optionally) boost responses. Every usable
    response becomes one example appended to the module-level ``all_data``.

    Args:
        file_path: path to the task .jsonl file.
        dir_path: trajectory directory containing the screenshots.
        task_description: natural-language description from the sibling .md file.
    """
    with open(file_path, "r", encoding="utf-8") as file:
        lines = file.readlines()

    response_history = []  # action history in natural language

    for line in lines:
        formatted_task = {
            "messages": [],
            "images": "",
        }

        entry = json.loads(line)
        action = entry["action"]

        # Normalize "press key <k>" into the "press key: <k>" syntax used in
        # the prompts. Fix: slice off the prefix instead of str.replace, which
        # would also rewrite any later occurrence of "press key" in the content.
        # Also fix: dropped the unused `idx`/`id` loop variables (the latter
        # shadowed the `id` builtin).
        if action.startswith("press key"):
            action = "press key:" + action[len("press key"):]

        # Screenshot paths in the .jsonl are relative to the trajectory directory.
        screenshot_path = f"{dir_path}/{entry['screenshot']}"
        # Add image path
        formatted_task["images"] = [screenshot_path]

        # Add user message: agent prompt + task + history so far
        action_history = get_history_str(response_history)
        query = get_instruction(task_description, action_history)
        formatted_task["messages"].append({"role": "user", "content": query})

        # Add boost responses (synthesized alternatives from Trajectory Boost)
        if BOOST and "boost_responses" in entry:
            for boost_response in entry["boost_responses"][:BOOST_CNT]:
                if not check_boost_response(boost_response, action):
                    continue
                boost_response_cleaned = refine_response(boost_response)
                if boost_response_cleaned is None:
                    continue
                formatted_task_copy = copy.deepcopy(formatted_task)
                formatted_task_copy["messages"].append(
                    {"role": "assistant", "content": boost_response_cleaned}
                )
                all_data.append(formatted_task_copy)

        # Add assistant message (original human action with completed thought).
        # Note: response_history grows only when the thought survives refinement.
        thought = refine_thought(entry['thought'])
        if thought is not None:
            response = combine_thought_action_to_response(thought, action)
            formatted_task["messages"].append({"role": "assistant", "content": response})
            response_history.append(response)
            if HUMAN:
                all_data.append(formatted_task)
88 |
89 |
def process_events_directories():
    """Walk data/events* directories and convert every task .jsonl file.

    For each task file, the task description is read from the sibling .md file
    (second line, "**Description:** ..."), then the trajectory is converted
    into training examples via process_task_jsonl_file. Per-file failures are
    logged and skipped so one bad trajectory does not abort the whole run.
    """
    # Get the parent directory of the current script
    root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    # Build the path to the data folder
    data_dir = os.path.join(root_dir, 'data')
    if not os.path.exists(data_dir):
        print(f"error: {data_dir} directory does not exist")
        # Fix: explicit SystemExit instead of the bare site-builtin exit().
        raise SystemExit(1)

    # Events folder prefix
    events_prefix = "events"

    # Process each subdirectory under /data
    for item in os.listdir(data_dir):
        item_path = os.path.join(data_dir, item)

        # Check if it's a directory and starts with specified name
        if not (os.path.isdir(item_path) and item.startswith(events_prefix)):
            continue
        for filename in os.listdir(item_path):
            # Process each jsonl file under the directory
            if not (filename.endswith(".jsonl") and "task" in filename):
                continue
            file_path = os.path.join(item_path, filename)
            md_path = os.path.join(item_path, filename.replace(".jsonl", ".md"))
            with open(md_path, "r", encoding="utf-8") as file:
                lines = file.readlines()
            try:
                task_description = lines[1].replace("**Description:** ", "").strip()
            except IndexError:
                # Fix: was a bare `except:` that also swallowed KeyboardInterrupt;
                # only a too-short .md file (missing line 2) is expected here.
                print(f"Error: Unable to extract task description from {md_path}")
                continue
            try:
                process_task_jsonl_file(file_path, item_path, task_description)
            except Exception as e:
                error_traceback = traceback.format_exc()
                print(f"{file_path} encountered error: {e}")
                print(f"{error_traceback}")
127 |
128 |
if __name__ == "__main__":
    # Build the training set from all recorded trajectories and dump it as a
    # single JSON file consumable by LLaMA-Factory (see train/sft.yaml).
    process_events_directories()
    print(f"entries: {len(all_data)}")
    with open(output_file, "w", encoding="utf-8") as file:
        json.dump(all_data, file, indent=2, ensure_ascii=False)
134 |
--------------------------------------------------------------------------------
/postprocess/prompt.py:
--------------------------------------------------------------------------------
# Prompt for Thought Completion: asks a model to reconstruct, in first person,
# the latent thought process behind a single recorded human action.
# NOTE(review): the string contains typos ("decied" -> "decide", "when you
# analyzing" -> "when analyzing"); kept verbatim because it is the prompt used
# to build the released dataset.
THOUGHT_COMPLETION_PROMPT = """You are a helpful computer use agent designed to complete tasks on a computer. Your goal is to recreate your thought process behind a specific action.

You will be provided with:

1. The task you are attempting to complete.
2. A history of the steps you have already performed (up to 50, if any; none if it was the first action).
3. The specific action you chose to take.
4. The name of the element you clicked (if you clicked on an element). It might be too general or vague, you have to decied what to click based on the screenshot.
5. A screenshot of the computer screen at the moment you decided to take the action.
6. The red marks on the screenshot indicate the position of the click or drag action.


To formulate your thought process, consider:

1. What do you observe on the screen? Consider your task and previous action when you analyzing current screenshot.
2. Evaluate your previous action (if applicable):
- Did it achieve the intended effect? If not, identify possible reasons (e.g., misclick, inactive element).
Some typical examples for ineffective action:
- misclick in an empty space
- ineffective opening some elements without double click
- ineffective type text/ press key because of inactivated input box
- Did the result align with your previous plan, or did something unexpected happen?
Some typical examples for ineffective action:
- misclick in a wrong element
- forget to clear existing text in input bar
3. Based on your action history, assess your progress toward completing the overall task.
4. Consider if you're exploring how to finish the task because of failed attempts in history steps.


Present your thought process as a clear, natural first-person narrative that explains your reasoning at that moment.

Important requirements:
1. **DO NOT** mention the red marks in your response. These marks were **added after the fact** to indicate the position of your click or drag actions, and they were not on the screen when you made the decision. **DO NOT** mention "red box", "red square", "red circle", or "red arrow" in your response.
2. Write as if you are thinking in real-time before taking the action. Do not include post-action evaluation or hindsight.

--------------------------------------------
"""
38 |
39 |
# Prompt for Trajectory Boost: asks a model to propose an alternative next
# action (with explicit thought process) for a recorded state, producing the
# synthesized responses stored as "boost_responses" in the .jsonl files.
# NOTE(review): contains typos ("dentify" -> "identify", "of this task is
# infeasible" -> "or ..."); kept verbatim to match the released dataset.
TRAJECTORY_BOOST_PROMPT = """
You are a helpful assistant who can help users complete computer tasks, with **full permission** to make any operations on the user's computer. The operating system is windows.
Based on the provided current state, you need to suggest the next action to complete the task. Do not try to complete the entire task in one step. Break it down into smaller steps, and at each step you will get a new state to interact with.

IMPORTANT: You must strictly adhere to the following rules:

1. Choose ONLY ONE action from the list below for each response, DO NOT perform more than one action per step.
2. Follow the exact syntax format for the selected action, DO NOT create or use any actions other than those listed.
3. Once the task is completed, output action finish.

Valid actions:

1. click (x, y)
click the element at the position (x, y) on the screen

2. right click (x, y)
right click the element at the position (x, y) on the screen

3. double click (x, y)
double click the element at the position (x, y) on the screen

4. drag from (x1, y1) to (x2, y2)
drag the element from position (x1, y1) to (x2, y2).

5. scroll (x)
scroll the screen vertically with pixel offset x. Positive values of x: scroll up, negative values of x: scroll down.

6. press key: key_content
press the key key_content on the keyboard.

7. hotkey (key1, key2)
press the hotkey composed of key1 and key2.

8. hotkey (key1, key2, key3)
press the hotkey composed of key1, key2, and key3.

9. type text: text_content
type content text_content on the keyboard.
Note that before typing text, you need to ensure the text box or input field is active/focused first. If the text box is not yet activated, you should first click on it to activate it, and then use type text in a separate step.

10. wait
wait for some time, usually for the system to respond, screen to refresh, advertisement to finish.

11. finish
indicating that the task has been completed.

12. fail
indicating that the task has failed, of this task is infeasible because not enough information is provided.


Before deciding your next action, you should think carefully about the current state of the screen and your history steps. Contain the following points in your thought process:

1. What do you observe on the screen? Consider your task and previous action when you analyzing current screenshot.
2. What's your previous plan and action (if applicable)? Evaluate your previous plan and action in three conditions:
1. It didn't make any effect. You should dentify possible reasons (e.g., misclick, inactive element) and adjust your plan in this step.
Some typical examples for ineffective action:
- misclick in an empty space
- ineffective opening some elements without double click
- ineffective type text/ press key because of inactivated input box
2. It made some effect, but the result does not align with previous plan. You should dentify possible reasons (e.g., misclick, inactive element) and correct it in this step.
Some typical examples for ineffective action:
- misclick in a wrong element
- forget to clear existing text in input bar
3. It made some effect, and it successfully align with previous plan. You should progress to the next step based on the current state.
3. Based on your action history, assess your progress toward completing the overall task.
4. Exploring new ways to finish the task if there are already failed attempts in history steps. **DO NOT repeat** the history actions.


Response Format: Your thought process\n\nAction: The specific action you choose to take
"""
110 |
111 |
# System prompt baked into every training example by prepare.py; it defines
# the action space and the "{thought}\n\nAction: {action}" response format.
# It mirrors deploy/prompt.py's AGENT_PROMPT plus the Windows note, so the
# deployed model sees (nearly) the same prompt it was trained with.
# NOTE(review): "of this task is infeasible" (action 12) looks like a typo
# for "or ..."; kept verbatim since it is what the model was trained on.
AGENT_PROMPT = """You are a helpful assistant who can help users complete computer tasks, with **full permission** to make any operations on the user's computer. The operating system is windows.
Based on the provided current state, you need to suggest the next action to complete the task. Do not try to complete the entire task in one step. Break it down into smaller steps, and at each step you will get a new state to interact with.

IMPORTANT: You must strictly adhere to the following rules:
1. Choose ONLY ONE action from the list below for each response, DO NOT perform more than one action per step.
2. Follow the exact syntax format for the selected action, DO NOT create or use any actions other than those listed.
3. Once the task is completed, output action finish.

Valid actions:

1. click (x, y)
click the element at the position (x, y) on the screen

2. right click (x, y)
right click the element at the position (x, y) on the screen

3. double click (x, y)
double click the element at the position (x, y) on the screen

4. drag from (x1, y1) to (x2, y2)
drag the element from position (x1, y1) to (x2, y2).

5. scroll (x)
scroll the screen vertically with pixel offset x. Positive values of x: scroll up, negative values of x: scroll down.

6. press key: key_content
press the key key_content on the keyboard.

7. hotkey (key1, key2)
press the hotkey composed of key1 and key2.

8. hotkey (key1, key2, key3)
press the hotkey composed of key1, key2, and key3.

9. type text: text_content
type content text_content on the keyboard.

10. wait
wait for some time, usually for the system to respond, screen to refresh, advertisement to finish.

11. finish
indicating that the task has been completed.

12. fail
indicating that the task has failed, of this task is infeasible because not enough information is provided.

Response Format: {Your thought process}\n\nAction: {The specific action you choose to take}

--------------------------------------------

"""
163 |
164 |
--------------------------------------------------------------------------------
/deploy/agent.py:
--------------------------------------------------------------------------------
1 | # agent.py
2 |
3 | import re
4 | import time
5 | from typing import Dict, List
6 | from PIL import Image
7 | from io import BytesIO
8 | from utils import *
9 | from prompt import *
10 |
11 |
12 | class PCAgentE:
    def __init__(
        self, client, model, max_steps=30, screenshot_size=(1280, 720), prompt=AGENT_PROMPT
    ):
        """Set up the agent.

        Args:
            client: OpenAI-compatible client used for chat completions.
            model: Name of the model served by the backend.
            max_steps: Step budget; exceeding it makes predict() return "FAIL".
            screenshot_size: Initial screenshot size; overwritten per predict() call.
            prompt: System prompt defining the action space.
        """
        self.retry_click_elements = []  # elements that previously failed to click
        self.history = []               # "Plan/Action" strings of past steps
        self.history_cut_off = 10       # max history entries to keep in context
        self.client = client
        self.model = model
        self.max_steps = max_steps
        self.screenshot_size = screenshot_size
        self.prompt = prompt
        self.steps = 0                  # number of predict() calls made so far
        print(f"Using model: {model}")
26 |
27 | def predict(self, instruction: str, obs: Dict):
28 | """
29 | Predict the next action based on the current observation
30 | Args:
31 | instruction: the task description
32 | obs: the current observation (obs['screenshot'])
33 | Returns:
34 | actions: the code of next action
35 | logs: the logs of next action
36 | """
37 | logs = {}
38 | self.task_description = instruction
39 |
40 | # get and process the screenshot
41 | image_file = BytesIO(obs['screenshot'])
42 | view_image = Image.open(image_file)
43 |
44 | # call the visual language model for planning
45 | self.screenshot_size = view_image.size
46 | try_time = 5
47 | feedback = ""
48 | while try_time > 0:
49 | plan, action = self.get_plan(view_image, self.task_description, feedback)
50 | action_code = self.get_action_code(action)
51 | if action_code is None:
52 | print(f"Invalid action: {action}, Try again.")
53 | feedback = f"\n\nNote: You have provided an invalid action before: {action}, please try again."
54 | try_time -= 1
55 | if try_time == 0:
56 | raise ValueError(f"Fail to get valid action after 5 try: {action}")
57 | else:
58 | self.add_to_history(f"Plan: {plan}\n\nAction: {action}")
59 | break
60 |
61 | # check if the steps is greater than the max steps
62 | self.steps += 1
63 | if self.steps >= self.max_steps and action_code != "DONE":
64 | logs['plan_result'] = "Max steps reached"
65 | actions = ["FAIL"]
66 | else:
67 | logs['plan_result'] = f"Plan: {plan}\n\nAction: {action}"
68 | actions = [action_code]
69 |
70 | return actions, logs
71 |
72 | def reset(self):
73 | """Reset the agent state"""
74 | self.history.clear()
75 | pass
76 |
77 | def get_plan(self, screenshot, task_description, feedback=""):
78 | """
79 | get the next plan
80 | Args:
81 | screenshot: the screenshot
82 | task_description: task description
83 | retry_click_elements: the list of elements that failed to click before
84 | Returns:
85 | plan_str: plan description
86 | action_str: specific action
87 | """
88 | base64_image = encode_image(screenshot)
89 | try_time = 5
90 | while try_time > 0:
91 | try:
92 | instruction = self.get_plan_instruction(task_description, feedback)
93 | messages = get_mllm_messages(instruction, base64_image)
94 |
95 | completion = self.client.chat.completions.create(
96 | model=self.model,
97 | messages=messages,
98 | max_tokens=512,
99 | temperature=0.8
100 | )
101 | output_text = completion.choices[0].message.content
102 | print(f"Output from agent: {output_text}")
103 |
104 | if not "Action" in output_text:
105 | feedback = f"\n\nNote: You should provide an action after 'Action:' in the response."
106 |
107 | return self.split_output(output_text)
108 |
109 | except Exception as e:
110 | print(f"Failed to get the plan: {e}, try again.")
111 | time.sleep(1)
112 | if try_time == 1:
113 | raise Exception(f"Failed to get the plan: {e}")
114 |
115 | try_time -= 1
116 |
117 | def add_to_history(self, output):
118 | """
119 | add the output to the history
120 | """
121 | self.history.append(output)
122 |
123 | def get_action_history(self):
124 | if len(self.history) > self.history_cut_off:
125 | history_str = "\n\n".join(f"[{i+1}] {item}" for i, item in enumerate(self.history[-self.history_cut_off:]))
126 | else:
127 | history_str = "\n\n".join(f"[{i+1}] {item}" for i, item in enumerate(self.history))
128 |
129 | if history_str == '':
130 | history_str = "None"
131 |
132 | return history_str
133 |
134 | def get_plan_instruction(self, task_description, feedback=""):
135 | """
136 | generate the planning instruction
137 | """
138 | prompt = self.prompt + f"Your task is: {task_description}\n\n"
139 | history_str = self.get_action_history()
140 | prompt += f"History of the previous actions and thoughts you have done to reach the current screen: {history_str}\n\n"
141 | prompt += "--------------------------------------------\n\n"
142 | prompt += f"Given the screenshot. What's the next step that you will do to help with the task?"
143 | prompt += feedback
144 | return prompt
145 |
146 | def split_output(self, output):
147 | """
148 | split the output into plan and action
149 | """
150 | plan_str = output.split("Action:")[0].strip().strip('{}')
151 | action_str = output.split("Action:")[1].strip().strip('{}')
152 | return plan_str, action_str
153 |
154 | def get_action_code(self, action) -> str:
155 | screen_width, screen_height = self.screenshot_size
156 | # click
157 | match = re.match(r"click \((-?\d+), (-?\d+)\)", action)
158 | if match:
159 | x = int(match.group(1))
160 | y = int(match.group(2))
161 | if 0 <= x < screen_width and 0 <= y < screen_height:
162 | return f"pyautogui.click({x}, {y})"
163 | else:
164 | return None
165 |
166 | # right click
167 | match = re.match(r"right click \((-?\d+), (-?\d+)\)", action)
168 | if match:
169 | x = int(match.group(1))
170 | y = int(match.group(2))
171 | if 0 <= x < screen_width and 0 <= y < screen_height:
172 | return f"pyautogui.rightClick({x}, {y})"
173 | else:
174 | return None
175 |
176 | # double click
177 | match = re.match(r"double click \((-?\d+), (-?\d+)\)", action)
178 | if match:
179 | x = int(match.group(1))
180 | y = int(match.group(2))
181 | if 0 <= x < screen_width and 0 <= y < screen_height:
182 | return f"pyautogui.doubleClick({x}, {y})"
183 | else:
184 | return None
185 |
186 | # drag
187 | match = re.match(r"drag from \((-?\d+), (-?\d+)\) to \((-?\d+), (-?\d+)\)", action)
188 | if match:
189 | x1 = int(match.group(1)) # start x coordinate
190 | y1 = int(match.group(2)) # start y coordinate
191 | x2 = int(match.group(3)) # target x coordinate
192 | y2 = int(match.group(4)) # target y coordinate
193 | if 0 <= x1 < screen_width and 0 <= y1 < screen_height and 0 <= x2 < screen_width and 0 <= y2 < screen_height:
194 | return f"pyautogui.mouseDown({x1}, {y1})\npyautogui.dragTo({x2}, {y2}, duration=0.5)"
195 | else:
196 | return None
197 |
198 | # scroll
199 | match = re.match(r"scroll \((-?\d+)\)", action)
200 | if match:
201 | y = int(match.group(1)) # vertical scroll distance
202 | return f"pyautogui.scroll({y})" # positive: scroll up, negative: scroll down
203 |
204 | # press key
205 | match = re.match(r"press key: (.+)", action)
206 | if match:
207 | key_content = match.group(1).lower()
208 | # Format error
209 | if 'key' in key_content:
210 | return None
211 | # if key is not in the valid keyboard keys list
212 | if key_content not in KEYBOARD_KEYS:
213 | return None
214 | return f"pyautogui.press('{key_content}')"
215 |
216 | # hotkey
217 | match = re.match(r"hotkey \((.+), (.+), (.+)\)", action)
218 | if match:
219 | key1 = match.group(1).strip("'").lower()
220 | key2 = match.group(2).strip("'").lower()
221 | key3 = match.group(3).strip("'").lower()
222 | # Format error
223 | if 'key' in key1 or 'key' in key2 or 'key' in key3:
224 | return None
225 | return f"pyautogui.hotkey('{key1}', '{key2}', '{key3}')"
226 |
227 | match = re.match(r"hotkey \((.+), (.+)\)", action)
228 | if match:
229 | key1 = match.group(1).strip("'").lower()
230 | key2 = match.group(2).strip("'").lower()
231 | # Format error
232 | if 'key' in key1 or 'key' in key2:
233 | return None
234 | return f"pyautogui.hotkey('{key1}', '{key2}')"
235 |
236 | # type text
237 | match = re.match(r"type text: (.+)", action)
238 | if match:
239 | text_content = match.group(1).strip("'").strip("\"")
240 | text_content = text_content.replace("\"", "\\\"")
241 | text_content = text_content.replace("\'", "\\\'")
242 | # Format error
243 | if "text_content" in text_content:
244 | return None
245 | return f"pyautogui.write(\"{text_content}\")"
246 |
247 | # wait
248 | if action == "wait":
249 | return "WAIT"
250 |
251 | # finish
252 | if action == "finish":
253 | return "DONE"
254 |
255 | # fail
256 | if action == "fail":
257 | return "FAIL"
258 |
259 | return None
260 |
--------------------------------------------------------------------------------
/postprocess/boost.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import sys
4 | import random
5 | import concurrent.futures
6 | import argparse
7 | import traceback
8 | import time
9 | from datetime import datetime
10 | from openai import OpenAI
11 | from concurrent.futures import ThreadPoolExecutor
12 | from prompt import *
13 | from utils import *
14 |
THOUGHT = True            # generate/refresh the per-entry "thought" field
BOOST = True              # generate/refresh the per-entry "boost_responses" field
CONCURRENT_NUM = 18       # thread-pool size for file-level parallelism
RE_GENERATE = False       # True: discard existing fields and regenerate from scratch
MAX_CONTEXT_ENTRIES = 30  # max number of preceding steps passed as history context
DETAILED_OUTPUT = True    # verbose per-entry progress logging
BOOST_COUNT = 9           # target number of boost responses per entry


# NOTE(review): an Anthropic model name is used through the OpenAI client —
# presumably routed via an OpenAI-compatible proxy; confirm the endpoint.
client = OpenAI()
model = "claude-3-7-sonnet-20250219"
print(f"Using model: {model}")
27 |
28 |
def call_model(query, base64_image=None):
    """Send a single-turn multimodal query to the model and return the reply.

    Args:
        query: text prompt.
        base64_image: optional base64-encoded JPEG to attach to the message.
    Returns:
        The assistant reply text.
    Raises:
        The last API error if all 10 attempts fail.
    """
    # Fix: the original built the content list with a literal None element
    # when no image was supplied, producing an invalid message payload.
    content = []
    if base64_image:
        content.append({
            "type": "image_url",
            "image_url": {
                "url": f"data:image/jpeg;base64,{base64_image}"
            }
        })
    content.append({
        "type": "text",
        "text": query
    })
    messages = [{"role": "user", "content": content}]

    retry_time = 10
    while retry_time > 0:
        try:
            completion = client.chat.completions.create(
                model=model,
                messages=messages,
                max_tokens=1000
            )
            return completion.choices[0].message.content
        except Exception as e:
            if retry_time == 1:
                raise e
            time.sleep(1)  # brief backoff instead of hammering the API
        retry_time -= 1
65 |
66 |
def process_concurrently(data_dir, events_prefix, function):
    """Collect every task .jsonl under the matching event directories and run
    `function(file_path, task_description)` over them in a thread pool."""
    pending = []

    for name in os.listdir(data_dir):
        dir_path = os.path.join(data_dir, name)
        if not (os.path.isdir(dir_path) and name.startswith(events_prefix)):
            continue
        print(f'Processing directory: {dir_path}')
        for fname in os.listdir(dir_path):
            if not (fname.endswith('.jsonl') and 'task' in fname):
                continue
            jsonl_path = os.path.join(dir_path, fname)
            md_path = os.path.join(dir_path, fname.replace('.jsonl', '.md'))
            try:
                # the task description lives on the second line of the .md file
                with open(md_path, 'r', encoding='utf-8') as md_file:
                    description = md_file.readlines()[1].replace('**Description:** ', '').strip()
                pending.append((jsonl_path, description))
            except Exception as e:
                print(f"error: failed to extract task description from {md_path}: {e}")

    # randomize processing order so long tasks spread across workers
    random.shuffle(pending)
    with ThreadPoolExecutor(max_workers=CONCURRENT_NUM) as executor:
        futures = [executor.submit(function, path, desc) for path, desc in pending]
        concurrent.futures.wait(futures)
92 |
93 |
def get_history_str_for_boost(history_steps):
    """Render (step_id, content) pairs as a history block; the final step is
    highlighted as the previous step and carries no trailing separator."""
    parts = []
    last_index = len(history_steps) - 1
    for idx, (step_id, step_content) in enumerate(history_steps):
        if idx == last_index:
            parts.append(f"**Your Previous Step**: Step {step_id}: {step_content}")
        else:
            parts.append(f"Step {step_id}: {step_content}\n\n")
    return "".join(parts)
106 |
107 |
def get_thought(task_description, entry, history_steps, marked_screenshot_path):
    """
    Generate thought for the action.

    Args:
        task_description: overall task the trajectory is performing.
        entry: the JSONL record; reads entry["action"] and entry["element"].
        history_steps: list of (step_id, content) pairs used as context.
        marked_screenshot_path: path of the annotated screenshot to attach.
    Returns:
        A thought string accepted by refine_thought.
    """
    base64_image = encode_image(marked_screenshot_path)
    action = entry["action"]
    element_description = entry["element"]
    history_str = get_history_str_for_boost(history_steps)

    query = THOUGHT_COMPLETION_PROMPT \
        + f"The task you are attempting to complete: {task_description}\n\n" \
        + f"Your performing history:\n{history_str}\n\n" \
        + f"The specific action you chose to perform: {action}\n\n"

    # only mention the clicked element when it carries real information
    if element_description and element_description != "Unknown":
        query += f"The element you clicked: {element_description}\n\n"

    # NOTE(review): unbounded retry — loops forever if the model never yields
    # a thought that refine_thought accepts; consider adding a retry cap.
    while True:
        thought = call_model(query, base64_image)
        thought = refine_thought(thought)
        if thought is not None:
            return thought
130 |
131 |
def get_boost_responses(task_description, entry, history_steps, screenshot_path, num):
    """
    Generate `num` boost responses for one trajectory step.

    Args:
        task_description: overall task description.
        entry: the JSONL record (unused here; kept for signature symmetry
            with get_thought).
        history_steps: list of (step_id, content) pairs used as context.
        screenshot_path: path of the raw screenshot to attach.
        num: number of responses to produce.
    Returns:
        A list of length `num`: refined response strings, with None in any
        slot where 5 attempts failed.
    """
    base64_image = encode_image(screenshot_path)
    history_str = get_history_str_for_boost(history_steps)

    query = TRAJECTORY_BOOST_PROMPT \
        + f"The task you are attempting to complete: {task_description}\n\n" \
        + f"Your performing history:\n{history_str}\n\n" \
        + f"Given the screenshot as below. What's the next step that you will do to help with the task?"

    responses = []

    # Add more boost responses one by one
    for _ in range(num):
        try_time = 5
        while try_time > 0:
            response = refine_response(call_model(query, base64_image))
            if response is not None:
                responses.append(response)
                break
            try_time -= 1
        else:
            # Fix: the original appended None unconditionally after the retry
            # loop, so every successful slot also received a spurious None and
            # the list grew to roughly twice the requested length.
            responses.append(None)

    return responses
159 |
160 |
def add_entry_for_file(file_path, task_description):
    """
    Enrich every entry of one task .jsonl file in place.

    For each entry this (optionally) generates or refreshes the 'thought'
    field and the 'boost_responses' field, then rewrites the .jsonl file and
    regenerates its companion Markdown file. Corrupted JSONL files are
    deleted outright.

    Args:
        file_path: path of the task .jsonl file.
        task_description: task description extracted from the .md file.
    """
    print(f"begin add entry for {file_path}")
    entries = []

    try:
        with open(file_path, 'r', encoding='utf-8') as file:
            entries = [json.loads(line) for line in file]
    except Exception as e:
        print(f"error: failed to read file {file_path}: {e}")
        return

    try:
        # NOTE(review): `id` shadows the builtin of the same name
        for id, entry in enumerate(entries):
            # check marked screenshot available
            if 'marked_screenshot' not in entry:
                print(f"error: marked_screenshot field not found: {file_path}")
                continue

            marked_screenshot_path = os.path.join(os.path.dirname(file_path), entry['marked_screenshot'])
            screenshot_path = os.path.join(os.path.dirname(file_path), entry['screenshot'])
            if not os.path.isfile(marked_screenshot_path):
                print(f"error: screenshot file not found: {marked_screenshot_path}")
                continue

            # build history steps from up to MAX_CONTEXT_ENTRIES preceding
            # entries that already carry a thought
            history_steps = []
            start_idx = max(0, id - MAX_CONTEXT_ENTRIES)
            for hist_id in range(start_idx, id):
                hist_entry = entries[hist_id]
                if 'thought' in hist_entry:
                    history_steps.append((hist_id+1, combine_thought_action_to_response(hist_entry['thought'], hist_entry['action'])))

            # get thought completion
            if THOUGHT:
                try:
                    field = "thought"
                    if field in entry:
                        if RE_GENERATE:
                            entry[field] = get_thought(task_description, entry, history_steps, marked_screenshot_path)
                        else:
                            # try refine thought
                            thought = refine_thought(entry[field])
                            # re-generate if not qualified
                            if thought is None:
                                entry[field] = get_thought(task_description, entry, history_steps, marked_screenshot_path)
                    else:
                        entry[field] = get_thought(task_description, entry, history_steps, marked_screenshot_path)
                except Exception as e:
                    # thought failures are non-fatal: move on to the next entry
                    print(f"error: failed to add thought file {file_path}: {e}")

            # get boost responses
            if BOOST:
                try:
                    field = "boost_responses"
                    if field in entry:
                        if RE_GENERATE:
                            entry[field] = get_boost_responses(task_description, entry, history_steps, screenshot_path, BOOST_COUNT)
                        else:
                            responses = []
                            # append existing responses after refinement
                            for response in entry[field]:
                                # remove empty response
                                if response is None:
                                    continue
                                response = refine_response(response)
                                responses.append(response)
                            # add new responses if not enough
                            if len(responses) < BOOST_COUNT:
                                print(f"append new boost response\n")
                                new_reponses = get_boost_responses(task_description, entry, history_steps, screenshot_path, BOOST_COUNT - len(responses))
                                responses.extend(new_reponses)

                            entry[field] = responses
                    else:
                        entry[field] = get_boost_responses(task_description, entry, history_steps, screenshot_path, BOOST_COUNT)
                except Exception as e:
                    # boost failures abort the whole file (re-raised below)
                    print(f"error: failed to boost file {file_path}: {e}")
                    raise

            if DETAILED_OUTPUT:
                print(f"boost finished for entry {id} in file {file_path}")

        # persist the enriched entries and regenerate the Markdown view
        with open(file_path, 'w', encoding='utf-8') as file:
            for entry in entries:
                json.dump(entry, file, ensure_ascii=False)
                file.write('\n')

        rewrite_markdown_file_by_jsonl(file_path)
        print(f"finished adding thought for {file_path}")

    except Exception as e:
        traceback.print_exc()
        print(f"error: failed to process file {file_path}: {e}")
        # JSON decode error messages indicate the file itself is broken
        if "Expecting" in str(e) or "Invalid control character" in str(e):
            print(f"file {file_path} is corrupted, deleting...")
            try:
                os.remove(file_path)
                print(f"deleted corrupted file: {file_path}")
            except OSError as delete_error:
                print(f"error: failed to delete corrupted file: {delete_error}")
261 |
262 |
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Choose which model to use.")
    parser.add_argument(
        "--specific_data_dir",
        type=str,
        default=None,
        help="Optional path to a specific data directory.",
    )
    parser.add_argument(
        "--events_prefix",
        type=str,
        default=None,
        help="Events prefix",
    )
    parser.add_argument(
        "--boost_count",
        type=int,
        default=None,
        help="Optional number of items to boost. If None, boost all."
    )

    args = parser.parse_args()

    start_time = datetime.now()
    print(f"start time: {start_time}")

    # Get parent directory of current script
    root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    # Build total data folder path
    # NOTE(review): --specific_data_dir and --boost_count are parsed above but
    # never used — data_dir is always <repo>/data and the module constant
    # BOOST_COUNT applies; confirm whether the flags should override them.
    data_dir = os.path.join(root_dir, 'data')
    if not os.path.exists(data_dir):
        print(f"error: {data_dir} directory does not exist")
        exit()

    # Events folder prefix
    events_prefix = "events" if args.events_prefix is None else args.events_prefix

    process_concurrently(data_dir, events_prefix, add_entry_for_file)

    print("process events finished!")

    end_time = datetime.now()
    print(f"end time: {end_time}")
    print(f"Total time: {end_time - start_time}")
308 |
--------------------------------------------------------------------------------
/postprocess/refinement.py:
--------------------------------------------------------------------------------
1 | # multi-function script for data refinement
2 | # 1. rewrite screenshot path
3 | # 2. clean fail and error record
4 | # 3. check last action finish
5 | # 4. merge press and drag
6 | # 5. remove redundant actions
7 | # 6. remove meaningless actions
8 | # 7. rewrite scroll
9 | # 8. resize screenshot and coordinates to 1080p -> 720p
10 | # 9. clean tracker interface
11 | # 10. mark screenshot with red rect and point
12 | # 11. rewrite markdown file
13 | # 12. statistics
14 | # support interrupt
15 |
16 |
17 | import os
18 | import json
19 | import sys
20 | import numpy as np
21 | from PIL import Image
22 | from utils import *
23 |
OVERWRITE_MARKED = True    # re-mark screenshots even if 'marked_screenshot' already exists
REMOVE_FAIL_RECORD = True  # delete trajectories whose last action is 'fail'
DETAIL_OUTPUT = False      # verbose per-file logging
27 |
28 |
def screenshot_of_tracker(screenshot_path, sample_size=100):
    """
    check if the screenshot is a Tracker interface.

    A screenshot is treated as the Tracker UI when the file is small and all
    four corner patches (offset away from the screen edges) are dominated by
    the Tracker's uniform #f0f0f0 background.
    """
    # large files cannot be the mostly-uniform tracker screen
    if get_file_size_kb(screenshot_path) > 83:  # magic number
        return False

    bg_color = "#f0f0f0"
    bg_threshold = 0.8
    top_offset = 40     # skip the title-bar area
    bottom_offset = 80  # skip the taskbar area

    with Image.open(screenshot_path) as img:
        width, height = img.size

        # four corner patches to sample
        corner_boxes = [
            (0, top_offset, sample_size, sample_size + top_offset),
            (width - sample_size, top_offset, width, sample_size + top_offset),
            (0, height - sample_size - bottom_offset, sample_size, height - bottom_offset),
            (width - sample_size, height - sample_size - bottom_offset, width, height - bottom_offset),
        ]

        # parse "#rrggbb" into an RGB vector
        expected_rgb = np.array([int(bg_color[pos:pos + 2], 16) for pos in (1, 3, 5)])

        # every patch must be mostly background-colored
        for box in corner_boxes:
            patch = np.array(img.crop(box))[:, :, :3]
            bg_ratio = np.all(patch == expected_rgb, axis=2).mean()
            if bg_ratio < bg_threshold:
                return False

    return True
66 |
67 |
def clean_tracker_interface(file_path):
    """
    clean the action records of the Tracker interface.

    Scans the trajectory for screenshots that show the Tracker tool itself
    (captured by accident) and removes those entries and their screenshots.

    return the number of actions after cleaning, -1 means the file is deleted
    """
    if DETAIL_OUTPUT:
        print(f"Clean tracker interface: {file_path}")
    screenshot_paths = []
    entries = []

    # load every entry and resolve its screenshot path
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            entry = json.loads(line)
            full_path = os.path.join(os.path.dirname(file_path), entry['screenshot'])
            screenshot_paths.append(full_path)
            entries.append(entry)

    # remember the terminal action so it can be re-attached after truncation
    last_entry_action = entries[-1].get('action')
    markdown_path = file_path.replace('.jsonl', '.md')

    # scan and identify the action of the Tracker interface
    begin = -1
    interval_list = []  # half-open intervals [begin, end) of tracker screenshots
    for index, screenshot_path in enumerate(screenshot_paths):
        # find the screenshot of the Tracker interface
        if screenshot_of_tracker(screenshot_path):
            if begin == -1:
                begin = index
        else:
            # back to the screenshot of non-Tracker interface, end the interval
            if begin != -1:
                interval_list.append((begin, index))
                begin = -1

    interval_list.append((begin, len(screenshot_paths)))  # the last interval (begin maybe -1)

    # delete the last interval (finish/fail)
    begin, end = interval_list.pop()
    if begin != -1:
        # trajectory ends on the tracker: truncate and move the terminal
        # action onto the last remaining entry
        entries = entries[:begin]
        try:
            entries[-1]['action'] = last_entry_action
            entries[-1]['element'] = None
            entries[-1]['rect'] = None
        except Exception as e:  # empty data
            print(f"[ERROR] delete related records (probably empty): {e}")
            # delete the JSONL file
            os.remove(file_path)
            # delete the Markdown file
            os.remove(markdown_path)
            # delete the screenshot files
            for screenshot_path in screenshot_paths:
                remove_screenshot(screenshot_path)
            return -1

        for i in range(begin, end):
            remove_screenshot(screenshot_paths[i])

    # delete other intervals; begin - 1 also drops the action that opened
    # the tracker.
    # NOTE(review): when begin == 0 the range starts at -1, which removes the
    # final screenshot via negative indexing — confirm a tracker interval at
    # index 0 cannot occur here.
    to_remove_entry_set = set()
    for begin, end in interval_list:
        for i in range(begin - 1, end):
            remove_screenshot(screenshot_paths[i])
            to_remove_entry_set.add(i)

    entries = [entry for i, entry in enumerate(entries) if i not in to_remove_entry_set]

    # save the updated JSONL file
    with open(file_path, 'w', encoding='utf-8') as file:
        for entry in entries:
            json.dump(entry, file, ensure_ascii=False)
            file.write('\n')

    return len(entries)
143 |
144 |
def _delete_task_records(file_path, markdown_path, screenshot_paths, remove_markdown):
    """Delete one task's JSONL file, optionally its Markdown file, and all of
    its screenshot files."""
    os.remove(file_path)
    if remove_markdown:
        os.remove(markdown_path)
    for screenshot_path in screenshot_paths:
        remove_screenshot(screenshot_path)


def clean_fail_and_error(file_path):
    """
    clean the records without corresponding Markdown files or the last action is 'fail' or there is None in action.

    return True if the file is deleted, False otherwise.
    """
    markdown_path = file_path.replace('.jsonl', '.md')
    if DETAIL_OUTPUT:
        print(f"Clean fail: {file_path}")
    try:
        with open(file_path, 'r', encoding='utf-8') as infile:
            entries = [json.loads(line) for line in infile]
    except Exception as e:
        print(f"[ERROR] Failed to read file {file_path}: {e}")
        return False

    screenshot_paths = [os.path.join(os.path.dirname(file_path), entry['screenshot']) for entry in entries]
    last_entry_action = entries[-1]['action'] if entries else ''

    # delete the records without corresponding Markdown files
    if not os.path.exists(markdown_path):
        if DETAIL_OUTPUT:
            print(f"File {file_path} has no corresponding Markdown file")
            print("Delete related records...")
        # the markdown file is already missing, so only jsonl + screenshots go
        _delete_task_records(file_path, markdown_path, screenshot_paths, remove_markdown=False)
        return True

    # clean the fail records (optional)
    if REMOVE_FAIL_RECORD and last_entry_action == 'fail':
        if DETAIL_OUTPUT:
            print(f"File {file_path} ends with fail action")
            print("Delete related records...")
        _delete_task_records(file_path, markdown_path, screenshot_paths, remove_markdown=True)
        return True

    # check if there is None in action
    for entry in entries:
        if entry['action'] is None or "None" in entry['action']:
            if DETAIL_OUTPUT:
                print(f"File {file_path} has None in action")
                print("Delete related records...")
            _delete_task_records(file_path, markdown_path, screenshot_paths, remove_markdown=True)
            return True

    return False
206 |
207 |
def resize(file_path):
    """
    Rescale one trajectory to 1280x720 in place: every screenshot file (via
    resize_to_720p), the coordinates inside each 'action' string (via
    resize_action), and each 'rect' bounding box. The scale factors are
    derived from the last entry's screenshot — assumes all screenshots in the
    task share that resolution (TODO confirm).
    """
    if DETAIL_OUTPUT:
        print(f"Resize file: {file_path}")

    # get the directory of the file
    task_dir = os.path.dirname(file_path)

    # read the screenshot path of the last entry
    try:
        with open(file_path, 'r', encoding='utf-8') as infile:
            lines = infile.readlines()
            last_line = lines[-1]
            last_entry = json.loads(last_line)
            screenshot_path = os.path.join(task_dir, last_entry['screenshot'])
    except Exception as e:
        print(f"[ERROR] Failed to read the screenshot path of the last entry: {e}")
        return

    if not os.path.exists(screenshot_path):
        print(f"[ERROR] The screenshot file does not exist: {screenshot_path}")
        return

    # get the resolution of the screenshot
    try:
        with Image.open(screenshot_path) as img:
            original_width, original_height = img.size
            if DETAIL_OUTPUT:
                print(f"Original resolution: {original_width}x{original_height}")
    except Exception as e:
        print(f"[ERROR] Failed to open the screenshot file {screenshot_path}: {e}")
        return

    # original_width, original_height = 2560, 1440

    # target resolution
    target_width, target_height = 1280, 720
    if original_width == target_width and original_height == target_height:
        if DETAIL_OUTPUT:
            print(f"The screenshot resolution is the same as the target resolution, no need to resize")
        return

    scale_x = target_width / original_width
    scale_y = target_height / original_height
    if DETAIL_OUTPUT:
        print(f"Resize ratio - X: {scale_x:.4f}, Y: {scale_y:.4f}")

    # process the JSONL file
    modified_lines = []
    for line in lines:
        try:
            data = json.loads(line)

            # process the screenshot
            screenshot_path = os.path.join(task_dir, data['screenshot'])
            assert resize_to_720p(screenshot_path), "Error occured!"

            # process the action
            data['action'] = resize_action(data['action'], scale_x, scale_y)

            # process the rect
            if 'rect' in data and isinstance(data['rect'], dict):
                rect = data['rect']
                rect['left'] = round(rect['left'] * scale_x)
                rect['top'] = round(rect['top'] * scale_y)
                rect['right'] = round(rect['right'] * scale_x)
                rect['bottom'] = round(rect['bottom'] * scale_y)
                if DETAIL_OUTPUT:
                    print(f"Resize rect: {rect}")

            modified_lines.append(json.dumps(data, ensure_ascii=False) + '\n')
        except Exception as e:
            # keep the original line untouched if anything goes wrong
            print(f"[WARNING] Error when processing the line: {line.strip()} - {e}")
            modified_lines.append(line)

    # directly write the modified content, overwrite the original file
    try:
        with open(file_path, 'w', encoding='utf-8') as outfile:
            outfile.writelines(modified_lines)
        if DETAIL_OUTPUT:
            print(f"Saved the modified file: {file_path}")
    except Exception as e:
        print(f"[ERROR] Failed to write the file {file_path}: {e}")
290 |
291 |
def mark(file_path):
    """
    Annotate each entry's screenshot with its interaction target (red rect /
    point(s), via mark_image) and store the resulting path in the
    'marked_screenshot' field, relative to the task directory. Entries
    without a rect reuse the unmarked screenshot path.
    """
    if DETAIL_OUTPUT:
        print(f"Mark file: {file_path}")

    # get the directory of the file
    task_dir = os.path.dirname(file_path)

    # process the JSONL file
    modified_lines = []
    with open(file_path, 'r', encoding='utf-8') as infile:
        for line in infile:
            entry = json.loads(line)

            # keep existing marks unless overwriting is requested
            if not OVERWRITE_MARKED and 'marked_screenshot' in entry:
                if DETAIL_OUTPUT:
                    print(f"Already marked: {entry['marked_screenshot']}")
                modified_lines.append(line)
                continue

            screenshot = os.path.join(task_dir, entry.get('screenshot'))
            action = entry.get('action')
            rect = entry.get('rect')

            if rect is not None and action != "finish":  # click or drag
                click_action_name, coordinates = parse_click_action(action)
                if click_action_name != None:  # click related action
                    x, y = coordinates
                    marked_screenshot = mark_image(is_click_action=True, image_path=screenshot, rect=rect, point1={'x': x, 'y': y})
                    entry['marked_screenshot'] = marked_screenshot
                else:  # drag related action
                    (x1, y1), (x2, y2) = parse_drag_action(action)
                    marked_screenshot = mark_image(is_click_action=False, image_path=screenshot, rect=rect, point1={'x': x1, 'y': y1}, point2={'x': x2, 'y': y2})
                    entry['marked_screenshot'] = marked_screenshot
            else:
                # rect is None, copy the original screenshot path
                entry['marked_screenshot'] = screenshot

            # remove the task_dir prefix of marked_screenshot
            entry['marked_screenshot'] = entry['marked_screenshot'].replace(
                task_dir + '/', '')

            modified_lines.append(json.dumps(entry, ensure_ascii=False) + '\n')

    # write the modified content, overwrite the original file
    with open(file_path, 'w', encoding='utf-8') as outfile:
        outfile.writelines(modified_lines)
338 |
339 |
def rewrite_screenshot_path(file_path):
    """Normalize every 'screenshot' path in a JSONL file in place: strip a
    leading 'events\\' prefix and convert backslashes to forward slashes."""
    if DETAIL_OUTPUT:
        print(f"Rewrite screenshot path: {file_path}")

    rewritten = []
    with open(file_path, 'r', encoding='utf-8') as infile:
        for raw_line in infile:
            record = json.loads(raw_line)
            path = record['screenshot']

            # drop the possible 'events\' prefix (7 characters)
            if path.startswith('events\\'):
                path = path[7:]

            # replace the backslash with the forward slash (Linux format)
            record['screenshot'] = path.replace("\\", "/")

            rewritten.append(json.dumps(record, ensure_ascii=False) + '\n')

    with open(file_path, 'w', encoding='utf-8') as outfile:
        outfile.writelines(rewritten)
361 |
362 |
# running totals across all processed files — presumably reported by the
# statistics step; duplicate = identical consecutive click actions,
# adjacent = consecutive clicks within 5 px Manhattan distance
duplicate_clicks = 0
adjacent_clicks = 0
365 |
366 |
def remove_redundant_actions(file_path):
    """Drop redundant actions from a task JSONL file and rewrite it.

    Removed actions: 'wait' actions at the very beginning of the
    recording, stray ctrl/shift presses that are absorbed by the next
    action, ctrl immediately followed by shift, and caps_lock presses.
    The screenshots of removed actions are deleted from disk.

    Also updates the module-level duplicate_clicks / adjacent_clicks
    counters for consecutive click statistics.
    """
    # fixed: trailing semicolons removed, loop variable no longer shadows
    # the builtin `id`, and the `global` declarations are hoisted to the top
    global adjacent_clicks, duplicate_clicks

    if DETAIL_OUTPUT:
        print(f"Remove redundant actions: {file_path}")
    ctrl_cnt = 0
    shift_cnt = 0
    wait_cnt = 0
    caps_lock_cnt = 0
    all_entries = []
    kept_entries = []
    screenshot_paths = []
    continuous_wait_at_begin = False

    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            all_entries.append(json.loads(line))

    total_cnt = len(all_entries)
    skip = False
    for idx, entry in enumerate(all_entries):
        if skip:
            skip = False
            continue
        screenshot_path = os.path.join(os.path.dirname(file_path), entry['screenshot'])

        # gather statistics about consecutive (near-)identical clicks
        if entry != all_entries[-1] and 'click' in entry['action'] and 'click' in all_entries[idx + 1]['action']:
            _, (x1, y1) = parse_click_action(entry['action'])
            _, (x2, y2) = parse_click_action(all_entries[idx + 1]['action'])
            if entry['action'] == all_entries[idx + 1]['action']:
                duplicate_clicks += 1
            elif abs(x1 - x2) + abs(y1 - y2) < 5:
                adjacent_clicks += 1

        # a non-wait action terminates the leading run of waits
        if entry['action'] != 'wait':
            continuous_wait_at_begin = False
        if entry['action'] == 'wait' and (idx == 0 or continuous_wait_at_begin):
            # delete the continuous wait at the beginning
            wait_cnt += 1
            screenshot_paths.append(screenshot_path)
            continuous_wait_at_begin = True
        elif entry['action'] == 'press key ctrl' and (entry == all_entries[-1] or all_entries[idx + 1]['action'] == 'press key ctrl' or all_entries[idx + 1]['action'].startswith("hotkey (Ctrl,")):
            # ctrl that is last, repeated, or absorbed by a following hotkey
            ctrl_cnt += 1
            screenshot_paths.append(screenshot_path)
        elif entry['action'] == 'press key shift' and (entry == all_entries[-1] or all_entries[idx + 1]['action'] == 'press key shift' or all_entries[idx + 1]['action'].startswith('type')):
            # shift that is last, repeated, or absorbed by a following type
            shift_cnt += 1
            screenshot_paths.append(screenshot_path)
        elif entry['action'] == 'press key ctrl' and all_entries[idx + 1]['action'] == 'press key shift':
            # this action and the next action should be deleted
            ctrl_cnt += 1
            shift_cnt += 1
            screenshot_paths.append(screenshot_path)
            screenshot_paths.append(os.path.join(os.path.dirname(file_path), all_entries[idx + 1]['screenshot']))
            if DETAIL_OUTPUT:
                print(f"remove ctrl + shift in {file_path} action {idx}")
            skip = True
        elif entry['action'] == 'press key caps_lock':
            caps_lock_cnt += 1
            screenshot_paths.append(screenshot_path)
        else:
            kept_entries.append(entry)

    with open(file_path, 'w', encoding='utf-8') as file:
        for entry in kept_entries:
            json.dump(entry, file, ensure_ascii=False)
            file.write('\n')

    if len(kept_entries) == len(all_entries):
        if DETAIL_OUTPUT:
            print(f"File {file_path} has no redundant actions")
        return
    if DETAIL_OUTPUT:
        if wait_cnt != 0:
            print(f"File {file_path} has {wait_cnt}/{total_cnt} redundant wait, removed")
        if ctrl_cnt != 0:
            print(f"File {file_path} has {ctrl_cnt}/{total_cnt} redundant ctrl, removed")
        if shift_cnt != 0:
            print(f"File {file_path} has {shift_cnt}/{total_cnt} redundant shift, removed")
        if caps_lock_cnt != 0:
            print(f"File {file_path} has {caps_lock_cnt}/{total_cnt} redundant caps_lock, removed")

    # delete the screenshot files of removed actions
    for screenshot_path in screenshot_paths:
        os.remove(screenshot_path)
453 |
454 |
def remove_meaningless_actions(file_path):
    """Remove 'wait'/click actions whose screenshot is pixel-identical to
    the next one (i.e. they had no visible effect), rewrite the JSONL file
    and delete the orphaned screenshots.
    """
    if DETAIL_OUTPUT:
        print(f"Remove meaningless actions: {file_path}")

    with open(file_path, 'r', encoding='utf-8') as fp:
        entries = [json.loads(raw) for raw in fp]

    base_dir = os.path.dirname(file_path)
    kept = []
    obsolete_shots = []

    for idx, entry in enumerate(entries):
        # only non-final wait/click actions are candidates for removal
        is_candidate = (entry != entries[-1]
                        and (entry['action'] == 'wait' or 'click' in entry['action']))
        if not is_candidate:
            kept.append(entry)
            continue
        current_shot = os.path.join(base_dir, entry['screenshot'])
        next_shot = os.path.join(base_dir, entries[idx + 1]['screenshot'])
        if are_screenshots_identical(current_shot, next_shot):
            obsolete_shots.append(current_shot)
            if DETAIL_OUTPUT:
                print(f"action {idx}: {entry['action']} in {file_path} is a meaningless action, it has been removed")
        else:
            kept.append(entry)

    if len(kept) == len(entries):
        if DETAIL_OUTPUT:
            print(f"File {file_path} has no meaningless actions")
        return

    # rewrite the JSON file
    with open(file_path, 'w', encoding='utf-8') as fp:
        for entry in kept:
            json.dump(entry, fp, ensure_ascii=False)
            fp.write('\n')

    # delete the screenshot files
    for shot in obsolete_shots:
        os.remove(shot)
495 |
496 |
def merge_press_drag(file_path):
    """Merge a 'press (x, y)' action with the following 'drag to' action.

    The pair becomes a single 'drag from (x1, y1) to (x2, y2)' action,
    or a plain click when the two points are nearly identical. A press
    directly before the final 'finish' is dropped. The JSONL file is
    rewritten and the screenshot of the absorbed action is deleted.
    """
    # fixed: loop variable no longer shadows the builtin `id`
    if DETAIL_OUTPUT:
        print(f"Merge press and drag: {file_path}")

    all_entries = []
    kept_entries = []
    screenshot_paths = []

    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            all_entries.append(json.loads(line))

    i = 0
    while i < len(all_entries):
        # check the press action (a press as the very last entry is kept)
        if i != len(all_entries) - 1 and all_entries[i]['action'].startswith("press ("):
            if i == len(all_entries) - 2 and all_entries[i + 1]['action'] == "finish":
                # a press right before 'finish' is dropped entirely
                i += 1
                continue
            # the next action must be drag to
            assert all_entries[i + 1]['action'].startswith("drag"), f"Error: In file {file_path}, action {i+1} should start with 'drag', but it's {all_entries[i+1]['action']}"
            x1, y1 = extract_coordinates(all_entries[i]['action'])
            x2, y2 = extract_coordinates(all_entries[i + 1]['action'])
            if abs(x1 - x2) + abs(y1 - y2) <= 10:
                # movement too small to be a real drag: collapse to a click
                if DETAIL_OUTPUT:
                    print(f"delta: {abs(x1-x2) + abs(y1-y2)} in {file_path} action {i} is too small, it's merged into a single click")
                all_entries[i]['action'] = f"click ({x2}, {y2})"
            else:
                if DETAIL_OUTPUT:
                    print(f"action {i}: {all_entries[i]['action']} in {file_path} has been merged with action {i+1}: {all_entries[i+1]['action']}")
                all_entries[i]['action'] = f"drag from ({x1}, {y1}) to ({x2}, {y2})"
            screenshot_paths.append(os.path.join(os.path.dirname(file_path), all_entries[i + 1]['screenshot']))
            kept_entries.append(all_entries[i])
            i += 1  # skip the absorbed drag action
        else:
            kept_entries.append(all_entries[i])

        i += 1

    if len(kept_entries) == len(all_entries):
        if DETAIL_OUTPUT:
            print(f"File {file_path} has no press and drag to be merged")
        return

    # rewrite the JSON file
    with open(file_path, 'w', encoding='utf-8') as file:
        for entry in kept_entries:
            json.dump(entry, file, ensure_ascii=False)
            file.write('\n')

    # delete the screenshot files
    for screenshot_path in screenshot_paths:
        os.remove(screenshot_path)
552 |
def rewrite_scroll(file_path):
    """Rewrite old-format scroll actions into the single-argument form
    'scroll (dy)' when dy is non-zero, then rewrite the JSONL file.
    Entries already in the new form, and old-form entries with dy == 0,
    are kept unchanged.
    """
    if DETAIL_OUTPUT:
        print(f"Rewrite Scroll: {file_path}")

    with open(file_path, 'r', encoding='utf-8') as fp:
        entries = [json.loads(raw) for raw in fp]

    single_arg_form = r'^scroll \(-?\d+\)$'
    result = []
    for entry in entries:
        action = entry['action']
        if action.startswith("scroll") and not re.match(single_arg_form, action):
            # still in the old two-coordinate form: keep only dy
            dx, dy = extract_coordinates(action)
            if dy != 0:
                entry['action'] = f"scroll ({dy})"
        result.append(entry)

    # rewrite the JSON file
    with open(file_path, 'w', encoding='utf-8') as fp:
        for entry in result:
            json.dump(entry, fp, ensure_ascii=False)
            fp.write('\n')
586 |
587 |
def check_finish(file_path):
    """Ensure the last entry of a task JSONL file has action 'finish',
    rewriting the final line in place when it does not.
    """
    if DETAIL_OUTPUT:
        print(f"Check finish: {file_path}")

    # read all lines and parse the final entry
    try:
        with open(file_path, 'r', encoding='utf-8') as infile:
            lines = infile.readlines()
            last_entry = json.loads(lines[-1])
    except Exception as e:
        print(f"[ERROR] Failed to read the file content: {e}")
        return

    # nothing to do when the recording already ends with 'finish'
    if last_entry.get('action') == 'finish':
        if DETAIL_OUTPUT:
            print("The last entry is already 'finish'")
        return

    if DETAIL_OUTPUT:
        print("The last entry is ", last_entry.get('action'))
        print("Modify the last entry to 'finish'")
    last_entry['action'] = 'finish'
    lines[-1] = json.dumps(last_entry, ensure_ascii=False) + '\n'

    # write back to file
    try:
        with open(file_path, 'w', encoding='utf-8') as outfile:
            outfile.writelines(lines)
        if DETAIL_OUTPUT:
            print(f"Saved the modified file: {file_path}")
    except Exception as e:
        print(f"[ERROR] Failed to write the file {file_path}: {e}")
624 |
625 |
def process_task_jsonl_file(file_path):
    """Run the full post-processing pipeline on one task JSONL file.

    Steps: normalize screenshot paths, drop failed/error recordings,
    force a trailing 'finish' action, merge press+drag pairs, strip
    redundant and meaningless actions, normalize scrolls, resize,
    and clean tracker-interface frames; when the file survives, mark
    screenshots and regenerate the markdown log.

    Returns the action count from clean_tracker_interface, or -1 when
    clean_fail_and_error deleted the file.
    """
    if DETAIL_OUTPUT:
        print(f"Process task jsonl file: {file_path}")
    rewrite_screenshot_path(file_path)
    if clean_fail_and_error(file_path):
        return -1  # the file is deleted
    check_finish(file_path)
    merge_press_drag(file_path)
    remove_redundant_actions(file_path)
    remove_meaningless_actions(file_path)
    rewrite_scroll(file_path)
    resize(file_path)
    cnt = clean_tracker_interface(file_path)
    if cnt != -1:
        mark(file_path)
        rewrite_markdown_file_by_jsonl(file_path)
    return cnt
643 |
644 |
def process_events_directories():
    """Post-process every 'events_*' directory under <repo root>/data and
    print aggregate statistics over the processed task recordings.
    """
    # Get parent directory of current script
    root_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

    # Build the path to the data folder
    data_dir = os.path.join(root_dir, 'data')
    if not os.path.exists(data_dir):
        print(f"error: {data_dir} directory does not exist")
        exit()

    # Events folder prefix
    events_prefix = "events_"

    total_action_cnt = 0
    total_record_cnt = 0
    max_action_cnt = 0

    # traverse all subdirectories of the data folder
    for item in os.listdir(data_dir):
        item_path = os.path.join(data_dir, item)

        # check if it's a directory and starts with the specified name
        if os.path.isdir(item_path) and item.startswith(events_prefix):
            print(f'Processing directory: {item_path}')
            for filename in os.listdir(item_path):
                # task jsonl file
                if filename.endswith('.jsonl') and 'task' in filename:
                    file_path = os.path.join(item_path, filename)
                    cnt = process_task_jsonl_file(file_path)
                    if cnt != -1:
                        total_action_cnt += cnt
                        total_record_cnt += 1
                        max_action_cnt = max(max_action_cnt, cnt)

    # fixed: guard against ZeroDivisionError when no record survived
    if total_record_cnt == 0:
        print("No valid task records found")
        return

    average_action_cnt = total_action_cnt / total_record_cnt
    print(f"Total records: {total_record_cnt}")
    print(f"Average actions per record: {average_action_cnt:.2f}")
    print(f"Maximum actions: {max_action_cnt}")
683 |
684 |
if __name__ == "__main__":
    # script entry point: post-process all recorded event directories
    process_events_directories()
--------------------------------------------------------------------------------
/postprocess/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import re
3 | import json
4 | import base64
5 | import cv2
6 | import numpy as np
7 | from PIL import Image, ImageDraw
8 |
# Drawing parameters (pixels) used by the mark_image* / draw_* helpers below.
POINT_RADIUS = 2
CIRCLE_RADIUS = 18
CIRCLE_WIDTH = 2
RECT_WIDTH = 2

# Key names accepted by pyautogui-style keyboard actions.
KEYBOARD_KEYS = ['\t', '\n', '\r', ' ', '!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', 'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '{', '|', '}', '~', 'accept', 'add', 'alt', 'altleft', 'altright', 'apps', 'backspace', 'browserback', 'browserfavorites', 'browserforward', 'browserhome', 'browserrefresh', 'browsersearch', 'browserstop', 'capslock', 'clear', 'convert', 'ctrl', 'ctrlleft', 'ctrlright', 'decimal', 'del', 'delete', 'divide', 'down', 'end', 'enter', 'esc', 'escape', 'execute', 'f1', 'f10', 'f11', 'f12', 'f13', 'f14', 'f15', 'f16', 'f17', 'f18', 'f19', 'f2', 'f20', 'f21', 'f22', 'f23', 'f24', 'f3', 'f4', 'f5', 'f6', 'f7', 'f8', 'f9', 'final', 'fn', 'hanguel', 'hangul', 'hanja', 'help', 'home', 'insert', 'junja', 'kana', 'kanji', 'launchapp1', 'launchapp2', 'launchmail', 'launchmediaselect', 'left', 'modechange', 'multiply', 'nexttrack', 'nonconvert', 'num0', 'num1', 'num2', 'num3', 'num4', 'num5', 'num6', 'num7', 'num8', 'num9', 'numlock', 'pagedown', 'pageup', 'pause', 'pgdn', 'pgup', 'playpause', 'prevtrack', 'print', 'printscreen', 'prntscrn', 'prtsc', 'prtscr', 'return', 'right', 'scrolllock', 'select', 'separator', 'shift', 'shiftleft', 'shiftright', 'sleep', 'space', 'stop', 'subtract', 'tab', 'up', 'volumedown', 'volumemute', 'volumeup', 'win', 'winleft', 'winright', 'yen', 'command', 'option', 'optionleft', 'optionright']
15 |
16 |
def refine_response(response):
    """Normalize a raw model response.

    Returns the cleaned response string, or None when either the thought
    or the action part turns out to be unusable.
    """
    if response is None:
        return None
    cleaned = response.replace("**Action:**", "Action:").strip()
    cleaned = cleaned.replace("### Action:\nAction:", "Action:").strip()
    thought, action = parse_thought_action_from_response(cleaned)
    thought = refine_thought(thought)
    action = refine_action(action)
    if thought is None or action is None:
        return None

    return combine_thought_action_to_response(thought, action)
30 |
31 |
def refine_action(action):
    """Strip trailing comments from an action string and validate it.

    Returns the cleaned action, or None when the action is missing or
    does not translate to executable code.
    """
    if action is None:
        return None

    cleaned = remove_comments_from_action(action)

    # reject anything get_action_code cannot turn into pyautogui code
    if get_action_code(cleaned) is None:
        return None

    return cleaned
44 |
45 |
def remove_comments_from_action(action):
    """Cut off any '#' or '//' trailing comment and strip whitespace.

    None passes through unchanged.
    """
    if action is None:
        return None
    # truncate at the first '#', then at the first '//' of the remainder
    for marker in ('#', '//'):
        cut = action.find(marker)
        if cut != -1:
            action = action[:cut]
    return action.strip()
59 |
60 |
def refine_thought(thought):
    """Clean up a raw thought string.

    Returns the refined thought, or None when it is missing, a refusal,
    a bare title, or too short to be useful.
    """
    # rule 0: nothing to refine
    if thought is None:
        return None

    # rule 1: refusal responses are discarded
    if "sorry, I can't assist" in thought:
        return None

    thought = thought.replace("**Thought Process:**", "").strip()

    # rule 2: cut off any trailing 'Action:'-style section
    for marker in ("Action:", "*Action*:", "**Action:**"):
        if marker in thought:
            thought = thought.split(marker)[0].strip()

    # rule 3: drop a leading '# Thought Process' style title line
    if thought.startswith("# Thought Process") or thought.startswith("# My Thought Process"):
        newline_index = thought.find("\n")
        if newline_index == -1:
            return None
        thought = thought[newline_index + 1:].strip()

    # rule 4: unwrap a thought enclosed in {}
    if thought.startswith("{") and thought.endswith("}"):
        thought = thought[1:-1].strip()

    # rule 5: remove boilerplate placeholders and headers
    unwanted_contents = ["{Your thought process}", "Your thought process", "## Thought Process", "# Thought process", "#*# Thought Process", "Thought process:", "Thought process", "My thought process:", "My thought process", "#\n", "#:\n", ":\n"]
    for unwanted in unwanted_contents:
        if unwanted in thought:
            thought = thought.replace(unwanted, "").strip()

    # rule 6: too short to be a real thought
    if len(thought) < 15:
        return None

    return thought
105 |
106 |
def rewrite_markdown_file_by_jsonl(jsonl_path):
    """Regenerate the .md file that sits next to the given .jsonl file
    from the jsonl entries.
    """
    with open(jsonl_path, 'r', encoding='utf-8') as fp:
        entries = [json.loads(raw) for raw in fp.readlines()]

    markdown_path = jsonl_path.replace('.jsonl', '.md')
    rewrite_markdown_file(markdown_path, entries)
117 |
118 |
def rewrite_markdown_file(markdown_path, entries):
    """Rewrite the markdown log from the given entries.

    Keeps the first 5 (header) lines of the existing file, then appends
    one step section per entry, preferring 'marked_screenshot' over the
    plain 'screenshot' when present.
    """
    prompt = '''Given the screenshot as below. What's the next step that you will do to help with the task?'''
    with open(markdown_path, 'r', encoding='utf-8') as file:
        lines = file.readlines()

    # keep the first 5 lines
    kept_lines = lines[:5]

    # add new lines after the kept lines
    for index, entry in enumerate(entries):
        action = get_full_action(entry)
        screenshot_path = entry['marked_screenshot'] if 'marked_screenshot' in entry else entry['screenshot']
        thought = entry['thought'] if 'thought' in entry else None

        kept_lines.append(f'### Step {index+1}\n')
        kept_lines.append(f'**Input:** \n\n{prompt}\n\n')
        # NOTE(review): the original image-embedding line was corrupted in
        # the source dump; reconstructed as a standard markdown image
        # reference — confirm against the repository history.
        kept_lines.append(f'![screenshot]({screenshot_path})\n\n')

        if thought:
            kept_lines.append(f'**Thought:** \n\n{thought}\n\n')

        kept_lines.append(f'**Output:** \n\n{action}\n\n')

    # rewrite the file
    with open(markdown_path, 'w', encoding='utf-8') as file:
        file.writelines(kept_lines)
150 |
151 |
def remove_screenshot(screenshot_path):
    """Delete a screenshot together with its '<name>_marked.png' sibling,
    silently skipping paths that do not exist.
    """
    marked_path = screenshot_path.replace('.png', '_marked.png')
    for path in (screenshot_path, marked_path):
        if os.path.exists(path):
            os.remove(path)
163 |
164 |
def get_full_action(entry):
    """Return the entry's action string; for click actions with a known
    target element, inject ' element <element> at' after the click verb.
    """
    action = entry['action']
    element = entry['element']
    if not element:
        return action

    marker = 'click'
    pos = action.find(marker)
    if pos == -1:
        return action

    # splice the element description right after the word 'click'
    split_at = pos + len(marker)
    return f"{action[:split_at]} element {element} at{action[split_at:]}"
181 |
182 |
def encode_image(image_path):
    """Read a file in binary mode and return its base64-encoded string."""
    with open(image_path, "rb") as fp:
        raw = fp.read()
    return base64.b64encode(raw).decode('utf-8')
189 |
190 |
def get_file_size_kb(file_path):
    """Return the file size in kilobytes, rounded to one decimal place."""
    size_kb = os.path.getsize(file_path) / 1024
    return round(size_kb, 1)
195 |
196 |
def mark_image(is_click_action, image_path, rect, point1, point2=None):
    """Annotate a screenshot and save it as '<name>_marked<ext>'.

    Click actions get the element rectangle, a dot, a circle and a short
    arrow at point1; drag actions get dots at both points, circles when
    the points are far enough apart, and a long arrow between them.
    Returns the path of the marked copy.
    """
    with Image.open(image_path) as image:
        draw = ImageDraw.Draw(image)
        if is_click_action:
            # highlight the clicked element's bounding box
            draw.rectangle(
                [(rect["left"], rect["top"]), (rect["right"], rect["bottom"])],
                outline="red",
                width=RECT_WIDTH
            )
            draw_point(point1["x"], point1["y"], draw)
            draw_circle(point1["x"], point1["y"], draw)
            draw_short_arrow(point1["x"], point1["y"], draw)
        else:
            draw_point(point1["x"], point1["y"], draw)
            draw_point(point2["x"], point2["y"], draw)

            if (abs(point1["x"] - point2["x"]) + abs(point1["y"] - point2["y"])) > 15:
                draw_circle(point1["x"], point1["y"], draw)
                draw_circle(point2["x"], point2["y"], draw)
            else:
                print(f"the distance between point1 and point2 in image {image_path} is too small, skip drawing circles")

            draw_long_arrow(point1["x"], point1["y"], point2["x"], point2["y"], draw)

    # save next to the original with a '_marked' suffix
    base, ext = os.path.splitext(image_path)
    output_path = f"{base}_marked{ext}"
    image.save(output_path)
    return output_path
248 |
249 |
def mark_image_for_boost(is_click_action, image_path, boost_idx, point1, point2=None):
    """Annotate a screenshot for a boost response and save it as
    '<name>_marked_boost_<boost_idx><ext>'.

    Same annotations as mark_image, except no element rectangle is drawn
    for clicks. Returns the path of the marked copy.
    """
    with Image.open(image_path) as image:
        draw = ImageDraw.Draw(image)
        if is_click_action:
            draw_point(point1["x"], point1["y"], draw)
            draw_circle(point1["x"], point1["y"], draw)
            draw_short_arrow(point1["x"], point1["y"], draw)
        else:
            draw_point(point1["x"], point1["y"], draw)
            draw_point(point2["x"], point2["y"], draw)

            if (abs(point1["x"] - point2["x"]) + abs(point1["y"] - point2["y"])) > 15:
                draw_circle(point1["x"], point1["y"], draw)
                draw_circle(point2["x"], point2["y"], draw)
            else:
                print(f"the distance between point1 and point2 in image {image_path} is too small, skip drawing circles")

            draw_long_arrow(point1["x"], point1["y"], point2["x"], point2["y"], draw)

    # save next to the original with a boost-indexed '_marked' suffix
    base, ext = os.path.splitext(image_path)
    output_path = f"{base}_marked_boost_{boost_idx}{ext}"
    image.save(output_path)
    return output_path
294 |
295 |
def resize_to_720p(image_path):
    """Resize the image to a fixed 1280x720 resolution in place.

    Returns True on success (or when already 720p), False when the image
    is corrupted or cannot be resized.
    """
    # fixed: bare `except:` narrowed to `except Exception` so that
    # KeyboardInterrupt/SystemExit are no longer swallowed.
    # verify() invalidates the handle, so the image is reopened below.
    try:
        with Image.open(image_path) as img:
            img.verify()  # verify the image integrity
    except Exception:
        print(f"[ERROR] image corrupted: {image_path}")
        return False

    # open the image
    with Image.open(image_path) as img:
        if img.size == (1280, 720):
            print(f"image is already 720p, no need to resize: {image_path}")
            return True

        try:
            resized_img = img.resize((1280, 720), Image.LANCZOS)
        except Exception:
            print(f"[ERROR] cannot resize image: {image_path}")
            return False

        # save the resized image, overwrite the original file
        resized_img.save(image_path, optimize=True)
        print(f"image resized to 720p and saved: {image_path}")
        return True
323 |
324 |
def resize_to_1080p(image_path):
    """Resize the image to a fixed 1920x1080 resolution in place.

    Returns True on success (or when already 1080p), False when the
    image is corrupted or cannot be resized.
    """
    # fixed: bare `except:` narrowed to `except Exception` so that
    # KeyboardInterrupt/SystemExit are no longer swallowed.
    # verify() invalidates the handle, so the image is reopened below.
    try:
        with Image.open(image_path) as img:
            img.verify()  # verify the image integrity
    except Exception:
        print(f"[ERROR] image corrupted: {image_path}")
        return False

    # open the image
    with Image.open(image_path) as img:
        # check if the image is already 1080p
        if img.size == (1920, 1080):
            print(f"image is already 1080p, no need to resize: {image_path}")
            return True

        # resize the image to fixed 1920x1080 resolution
        try:
            resized_img = img.resize((1920, 1080), Image.LANCZOS)
        except Exception:
            print(f"[ERROR] cannot resize image: {image_path}")
            return False

        # save the resized image, overwrite the original file
        resized_img.save(image_path, optimize=True)
        print(f"image resized and saved: {image_path}")
        return True
354 |
355 |
def resize_action(action_str, scale_x, scale_y):
    """Scale every "(x, y)" coordinate pair inside an action string.

    Supports single-point actions (e.g. "double click (1415, 741)") and
    drag actions (e.g. "drag from (1230, 26) to (1209, 26)").

    :param action_str: action string
    :param scale_x: X axis scale factor
    :param scale_y: Y axis scale factor
    :return: the scaled action string
    """
    def _scaled(match):
        original_x = float(match.group(1))
        original_y = float(match.group(2))
        scaled_x = round(original_x * scale_x)
        scaled_y = round(original_y * scale_y)
        print(f"scale coordinates: ({original_x}, {original_y}) -> ({scaled_x}, {scaled_y})")
        return f"({scaled_x}, {scaled_y})"

    # rewrite every "(x, y)" pair via the callback
    return re.sub(r'\((\d+),\s*(\d+)\)', _scaled, action_str)
381 |
382 |
def are_screenshots_identical(screenshot_path1, screenshot_path2):
    """Return True when both images load successfully and every pixel
    matches; False on read failure or any size/content difference.
    """
    first = cv2.imread(screenshot_path1)
    second = cv2.imread(screenshot_path2)

    # a failed read means we cannot claim the screenshots are identical
    if first is None or second is None:
        print(f"cannot read image: {screenshot_path1} or {screenshot_path2}")
        return False

    if first.shape != second.shape:
        return False

    # identical images subtract to an all-zero array
    difference = cv2.subtract(first, second)
    return not np.any(difference)
403 |
404 |
def parse_click_action(action):
    """Parse a (double/right) click action string.

    Returns (name, (x, y)), or (None, None) when the string is not a
    click action.
    """
    found = re.match(r'((?:double |right )?click)\s*\((\d+),\s*(\d+)\)', action)
    if not found:
        return None, None

    name, x_str, y_str = found.groups()
    return name, (int(x_str), int(y_str))
416 |
417 |
def parse_drag_action(action):
    """Parse "drag from (x1, y1) to (x2, y2)" into ((x1, y1), (x2, y2))."""
    assert action.startswith('drag from'), f"error: action '{action}' is not a drag action"

    # slice out the two "x, y" substrings between the parentheses
    first_part = action[action.find('from (') + 6:action.find(') to (')]
    second_part = action[action.find('to (') + 4:len(action) - 1]

    x1, y1 = (int(v) for v in first_part.split(', '))
    x2, y2 = (int(v) for v in second_part.split(', '))

    return (x1, y1), (x2, y2)
434 |
435 |
def extract_coordinates(text):
    """Extract an (x, y) / (dx, dy) pair from an action string.

    Handles 'drag to (x, y)', 'press (x, y)', 'scroll (x, y)' and the
    'scroll dx = ..., dy = ...' form. Returns None when nothing matches.
    """
    # "drag to (x, y)", "press (x, y)" or "scroll (x, y)"
    found = re.search(r'(?:drag to|press|scroll) \((\-?\d+), (\-?\d+)\)', text)
    if found:
        return tuple(map(int, found.groups()))

    # "scroll dx = ..., dy = ..." form
    found = re.search(r'scroll dx\s*=\s*(\-?\d+),\s*dy\s*=\s*(\-?\d+)', text)
    if found:
        return tuple(map(int, found.groups()))

    # If no match is found, return None
    return None
453 |
454 |
def draw_point(x, y, draw):
    """Draw a small filled red dot of radius POINT_RADIUS at (x, y)."""
    r = POINT_RADIUS
    bbox = [(x - r, y - r), (x + r, y + r)]
    draw.ellipse(bbox, fill="red")
466 |
467 |
def draw_circle(x, y, draw):
    """Draw a red circle outline of radius CIRCLE_RADIUS around (x, y)."""
    r = CIRCLE_RADIUS
    bbox = [(x - r, y - r), (x + r, y + r)]
    draw.ellipse(bbox, outline="red", width=CIRCLE_WIDTH)
480 |
481 |
def draw_short_arrow(x, y, draw):
    """Draw a short red arrow pointing toward (x, y) from the upper-left,
    stopping just outside the circle drawn by draw_circle.

    :param x: target point x coordinate
    :param y: target point y coordinate
    :param draw: PIL ImageDraw object to render on
    """
    arrow_length = 50  # arrow length
    arrow_gap = CIRCLE_RADIUS + 2  # arrow gap
    arrow_width = 10  # arrow width
    angle = np.radians(30)  # arrow angle
    cos_angle = np.cos(angle)
    sin_angle = np.sin(angle)

    # draw the arrow body
    start_x = x - arrow_length * cos_angle
    start_y = y - arrow_length * sin_angle
    end_x = x - arrow_gap * cos_angle
    end_y = y - arrow_gap * sin_angle
    draw.line([(start_x, start_y), (end_x, end_y)],
              fill="red", width=3)

    # draw the arrow head (two back corners plus the body end point)
    arrow_point1 = (
        int(end_x - arrow_width),
        int(end_y)
    )
    arrow_point2 = (
        int(end_x - arrow_width * sin_angle),
        int(end_y - arrow_width * cos_angle)
    )

    draw.polygon([
        (end_x, end_y),
        arrow_point1,
        arrow_point2
    ], fill="red")
513 |
514 |
def draw_long_arrow(x1, y1, x2, y2, draw):
    """Draw a red line from (x1, y1) to (x2, y2) with an arrow head whose
    tip sits at the midpoint of the line.

    :param draw: PIL ImageDraw object to render on
    """
    head_length = 18  # arrow head length
    head_angle = np.radians(30)  # arrow head angle

    # draw the arrow body
    draw.line([(x1, y1), (x2, y2)], fill="red", width=3)

    # arrow head direction vector
    vector_x = x2 - x1
    vector_y = y2 - y1
    length = np.hypot(vector_x, vector_y)
    # fixed: identical endpoints previously divided by zero and produced
    # NaN polygon coordinates; skip the head for a zero-length arrow
    if length == 0:
        return
    unit_vector_x = vector_x / length
    unit_vector_y = vector_y / length

    # calculate the midpoint of the line (the tip of the arrow head)
    mid_x = (x1 + x2) / 2
    mid_y = (y1 + y2) / 2

    # calculate the positions of the two points of the arrow head (based on the midpoint)
    left_x = mid_x - head_length * \
        (unit_vector_x * np.cos(head_angle) +
         unit_vector_y * np.sin(head_angle))
    left_y = mid_y - head_length * \
        (unit_vector_y * np.cos(head_angle) -
         unit_vector_x * np.sin(head_angle))

    right_x = mid_x - head_length * \
        (unit_vector_x * np.cos(head_angle) -
         unit_vector_y * np.sin(head_angle))
    right_y = mid_y - head_length * \
        (unit_vector_y * np.cos(head_angle) +
         unit_vector_x * np.sin(head_angle))

    # use the midpoint as the vertex of the arrow head
    draw.polygon([(mid_x, mid_y), (left_x, left_y),
                  (right_x, right_y)], fill="red")
551 |
552 |
def parse_thought_action_from_response(response):
    """Split a response at the LAST 'Action:' marker.

    Returns (thought, action); action is None when no marker exists,
    and both are None for a None response.
    """
    if response is None:
        return None, None

    marker = "Action:"
    cut = response.rfind(marker)
    if cut == -1:
        # no marker: the whole response is the thought
        return response.strip(), None

    thought = response[:cut].strip()
    action = response[cut + len(marker):].strip()
    return thought, action
571 |
572 |
def combine_thought_action_to_response(thought, action):
    """Join a thought and an action back into the canonical response format."""
    return thought + "\n\nAction: " + action
575 |
576 |
def get_mllm_messages(instruction, base64_image=None):
    """Build the OpenAI-style chat messages for one user turn, with an
    optional inline base64 JPEG image placed before the instruction text.
    """
    text_part = {"type": "text", "text": instruction}
    if base64_image:
        content = [
            {
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{base64_image}"
                },
            },
            text_part,
        ]
    else:
        content = [text_part]
    return [{"role": "user", "content": content}]
601 |
602 |
def match(action, gt_entry):
    """
    Determine if the predicted action is equivalent to the ground truth entry

    Click-type predictions match when their coordinates fall inside the
    ground truth element rectangle; every other action type must match
    the ground truth string exactly.

    Args:
        action (str): Predicted action string
        gt_entry (dict): Dictionary containing ground truth action and related information

    Returns:
        bool: Returns True if actions are equivalent, False otherwise
    """
    # fixed: removed the redundant in-function `import re` — the module
    # already imports re at the top.
    # Handle edge cases first
    if action is None or gt_entry is None or "action" not in gt_entry:
        return False

    gt_action = gt_entry["action"]

    # Handle all click-type actions (click, right click, double click)
    click_types = ["click", "right click", "double click"]

    for click_type in click_types:
        if action.startswith(click_type) and gt_action.startswith(click_type):
            # After confirming click type match, check coordinates
            try:
                # Extract coordinates from predicted action
                coord_match = re.search(r'\((\d+),\s*(\d+)\)', action)
                if not coord_match:
                    return False

                x, y = int(coord_match.group(1)), int(coord_match.group(2))

                # Check if coordinates are within gt_entry's rect range
                if "rect" in gt_entry:
                    rect = gt_entry["rect"]
                    # rect format is [x1, y1, x2, y2]: top-left and bottom-right corners
                    if isinstance(rect, list) and len(rect) == 4:
                        x1, y1, x2, y2 = rect
                        return x1 <= x <= x2 and y1 <= y <= y2
            except Exception as e:
                print(f"Error in matching click coordinates: {e}")
                return False

    # For all other action types, directly compare if strings are identical
    return action == gt_action
648 |
649 |
def get_action_code(action) -> "str | None":
    """Translate a natural-language action string into executable pyautogui code.

    Args:
        action (str): Action description, e.g. "click (100, 200)",
            "drag from (1, 2) to (3, 4)", "type text: hello".

    Returns:
        str | None: The pyautogui code string, or the sentinel "WAIT"/"DONE"/
        "FAIL" for wait/finish/fail, or None when the action is unrecognized
        or its coordinates/keys are invalid.
        (Fixed: the previous `-> str` annotation was wrong — None is a
        legitimate return on many paths.)
    """
    # NOTE: assumes a fixed 1280x720 screen — TODO confirm against env config.
    screen_width, screen_height = 1280, 720

    def _on_screen(x, y):
        # A coordinate is valid only when it lies inside the screen bounds.
        return 0 <= x < screen_width and 0 <= y < screen_height

    # Use `m`, not `match`, for regex results: `match` would shadow the
    # module-level match() helper defined above in this file.

    # click
    m = re.match(r"click \((-?\d+), (-?\d+)\)", action)
    if m:
        x, y = int(m.group(1)), int(m.group(2))
        return f"pyautogui.click({x}, {y})" if _on_screen(x, y) else None

    # right click
    m = re.match(r"right click \((-?\d+), (-?\d+)\)", action)
    if m:
        x, y = int(m.group(1)), int(m.group(2))
        return f"pyautogui.rightClick({x}, {y})" if _on_screen(x, y) else None

    # double click
    m = re.match(r"double click \((-?\d+), (-?\d+)\)", action)
    if m:
        x, y = int(m.group(1)), int(m.group(2))
        return f"pyautogui.doubleClick({x}, {y})" if _on_screen(x, y) else None

    # drag: both the start and the target point must be on screen
    m = re.match(r"drag from \((-?\d+), (-?\d+)\) to \((-?\d+), (-?\d+)\)", action)
    if m:
        x1, y1 = int(m.group(1)), int(m.group(2))  # start point
        x2, y2 = int(m.group(3)), int(m.group(4))  # target point
        if _on_screen(x1, y1) and _on_screen(x2, y2):
            return f"pyautogui.mouseDown({x1}, {y1})\npyautogui.dragTo({x2}, {y2}, duration=0.5)"
        return None

    # scroll (positive: scroll up, negative: scroll down)
    m = re.match(r"scroll \((-?\d+)\)", action)
    if m:
        return f"pyautogui.scroll({int(m.group(1))})"

    # press key
    m = re.match(r"press key: (.+)", action)
    if m:
        key_content = m.group(1).lower()
        # Reject placeholder output such as "press key: key_name" (format error)
        if 'key' in key_content:
            return None
        # Reject anything outside the valid key list
        if key_content not in KEYBOARD_KEYS:
            return None
        return f"pyautogui.press('{key_content}')"

    # hotkey: try the three-key form first, then the two-key form
    m = re.match(r"hotkey \((.+), (.+), (.+)\)", action)
    if m:
        key1 = m.group(1).strip("'").lower()
        key2 = m.group(2).strip("'").lower()
        key3 = m.group(3).strip("'").lower()
        # Reject placeholder output (format error)
        if 'key' in key1 or 'key' in key2 or 'key' in key3:
            return None
        return f"pyautogui.hotkey('{key1}', '{key2}', '{key3}')"

    m = re.match(r"hotkey \((.+), (.+)\)", action)
    if m:
        key1 = m.group(1).strip("'").lower()
        key2 = m.group(2).strip("'").lower()
        # Reject placeholder output (format error)
        if 'key' in key1 or 'key' in key2:
            return None
        return f"pyautogui.hotkey('{key1}', '{key2}')"

    # type text
    m = re.match(r"type text: (.+)", action)
    if m:
        text_content = m.group(1).strip("'").strip("\"")
        # Escape quotes so the generated code stays syntactically valid
        text_content = text_content.replace("\"", "\\\"")
        text_content = text_content.replace("\'", "\\\'")
        # Reject placeholder output (format error)
        if "text_content" in text_content:
            return None
        return f"pyautogui.write(\"{text_content}\")"

    # wait
    if action == "wait":
        return "WAIT"

    # finish
    if action == "finish":
        return "DONE"

    # fail
    if action == "fail":
        return "FAIL"

    return None
756 |
757 |
def get_history_str(history: list) -> str:
    """Format the most recent history items as a numbered string.

    Only the last 10 entries are kept; each is rendered as "[i] item"
    (numbered from 1 within the kept window) and entries are separated by
    blank lines.

    Args:
        history (list): Chronological list of history items.

    Returns:
        str: The formatted history, or "None" when history is empty.
    """
    history_cut_off = 10
    # A negative-bound slice already returns the whole list when it is
    # shorter than the cutoff, so no separate length branch is needed.
    recent = history[-history_cut_off:]
    history_str = "\n\n".join(f"[{i + 1}] {item}" for i, item in enumerate(recent))
    return history_str if history_str else "None"
--------------------------------------------------------------------------------