├── .github
└── workflows
│ └── codeql.yml
├── .gitignore
├── LICENSE
├── README.md
├── SECURITY.md
├── assets
└── images
│ ├── main_figure.png
│ └── title.png
├── data
└── data_config.yaml
├── eval
├── screenSpot.py
├── screenSpot_pro.py
└── screenSpot_v2.py
├── pyproject.toml
├── scripts
├── train.sh
├── warmup.sh
└── zero3.json
├── src
└── gui_actor
│ ├── __init__.py
│ ├── constants.py
│ ├── dataset.py
│ ├── inference.py
│ ├── modeling.py
│ ├── modeling_qwen25vl.py
│ ├── trainer.py
│ └── utils.py
├── train.py
└── verifier
├── README.md
├── ScreenSpot-v2-new
├── screenspot_desktop_v2.json
├── screenspot_mobile_v2.json
└── screenspot_web_v2.json
├── eval_ss_with_verifier.py
├── run_ss_pro.sh
├── run_ss_v1.sh
├── run_ss_v2.sh
├── verifier_data_generation.py
└── verifier_model.py
/.github/workflows/codeql.yml:
--------------------------------------------------------------------------------
1 | # For most projects, this workflow file will not need changing; you simply need
2 | # to commit it to your repository.
3 | #
4 | # You may wish to alter this file to override the set of languages analyzed,
5 | # or to provide custom queries or build logic.
6 | #
7 | # ******** NOTE ********
8 | # We have attempted to detect the languages in your repository. Please check
9 | # the `language` matrix defined below to confirm you have the correct set of
10 | # supported CodeQL languages.
11 | #
12 | name: "CodeQL Advanced"
13 |
14 | on:
15 | push:
16 | branches: [ "main" ]
17 | pull_request:
18 | branches: [ "main" ]
19 | schedule:
20 | - cron: '35 12 * * 3'
21 |
22 | jobs:
23 | analyze:
24 | name: Analyze (${{ matrix.language }})
25 | # Runner size impacts CodeQL analysis time. To learn more, please see:
26 | # - https://gh.io/recommended-hardware-resources-for-running-codeql
27 | # - https://gh.io/supported-runners-and-hardware-resources
28 | # - https://gh.io/using-larger-runners (GitHub.com only)
29 | # Consider using larger runners or machines with greater resources for possible analysis time improvements.
30 | runs-on: ${{ (matrix.language == 'swift' && 'macos-latest') || 'ubuntu-latest' }}
31 | permissions:
32 | # required for all workflows
33 | security-events: write
34 |
35 | # required to fetch internal or private CodeQL packs
36 | packages: read
37 |
38 | # only required for workflows in private repositories
39 | actions: read
40 | contents: read
41 |
42 | strategy:
43 | fail-fast: false
44 | matrix:
45 | include:
46 | - language: python
47 | build-mode: none
48 | # CodeQL supports the following values for 'language': 'actions', 'c-cpp', 'csharp', 'go', 'java-kotlin', 'javascript-typescript', 'python', 'ruby', 'swift'
49 | # Use `c-cpp` to analyze code written in C, C++ or both
50 | # Use 'java-kotlin' to analyze code written in Java, Kotlin or both
51 | # Use 'javascript-typescript' to analyze code written in JavaScript, TypeScript or both
52 | # To learn more about changing the languages that are analyzed or customizing the build mode for your analysis,
53 | # see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/customizing-your-advanced-setup-for-code-scanning.
54 | # If you are analyzing a compiled language, you can modify the 'build-mode' for that language to customize how
55 | # your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages
56 | steps:
57 | - name: Checkout repository
58 | uses: actions/checkout@v4
59 |
60 | # Add any setup steps before running the `github/codeql-action/init` action.
61 | # This includes steps like installing compilers or runtimes (`actions/setup-node`
62 | # or others). This is typically only required for manual builds.
63 | # - name: Setup runtime (example)
64 | # uses: actions/setup-example@v1
65 |
66 | # Initializes the CodeQL tools for scanning.
67 | - name: Initialize CodeQL
68 | uses: github/codeql-action/init@v3
69 | with:
70 | languages: ${{ matrix.language }}
71 | build-mode: ${{ matrix.build-mode }}
72 | # If you wish to specify custom queries, you can do so here or in a config file.
73 | # By default, queries listed here will override any specified in a config file.
74 | # Prefix the list here with "+" to use these queries and those in the config file.
75 |
76 | # For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs
77 | # queries: security-extended,security-and-quality
78 |
79 | # If the analyze step fails for one of the languages you are analyzing with
80 | # "We were unable to automatically build your code", modify the matrix above
81 | # to set the build mode to "manual" for that language. Then modify this step
82 | # to build your code.
83 | # ℹ️ Command-line programs to run using the OS shell.
84 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun
85 | - if: matrix.build-mode == 'manual'
86 | shell: bash
87 | run: |
88 | echo 'If you are using a "manual" build mode for one or more of the' \
89 | 'languages you are analyzing, replace this with the commands to build' \
90 | 'your code, for example:'
91 | echo ' make bootstrap'
92 | echo ' make release'
93 | exit 1
94 |
95 | - name: Perform CodeQL Analysis
96 | uses: github/codeql-action/analyze@v3
97 | with:
98 | category: "/language:${{matrix.language}}"
99 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by https://www.toptal.com/developers/gitignore/api/macos,linux,python,visualstudiocode
2 | # Edit at https://www.toptal.com/developers/gitignore?templates=macos,linux,python,visualstudiocode
3 |
4 | ### Linux ###
5 | *~
6 |
7 | # temporary files which can be created if a process still has a handle open of a deleted file
8 | .fuse_hidden*
9 |
10 | # KDE directory preferences
11 | .directory
12 |
13 | # Linux trash folder which might appear on any partition or disk
14 | .Trash-*
15 |
16 | # .nfs files are created when an open file is removed but is still being accessed
17 | .nfs*
18 |
19 | ### macOS ###
20 | # General
21 | .DS_Store
22 | .AppleDouble
23 | .LSOverride
24 |
25 | # Icon must end with two \r
26 | Icon
27 |
28 |
29 | # Thumbnails
30 | ._*
31 |
32 | # Files that might appear in the root of a volume
33 | .DocumentRevisions-V100
34 | .fseventsd
35 | .Spotlight-V100
36 | .TemporaryItems
37 | .Trashes
38 | .VolumeIcon.icns
39 | .com.apple.timemachine.donotpresent
40 |
41 | # Directories potentially created on remote AFP share
42 | .AppleDB
43 | .AppleDesktop
44 | Network Trash Folder
45 | Temporary Items
46 | .apdisk
47 |
48 | ### macOS Patch ###
49 | # iCloud generated files
50 | *.icloud
51 |
52 | ### Python ###
53 | # Byte-compiled / optimized / DLL files
54 | __pycache__/
55 | *.py[cod]
56 | *$py.class
57 |
58 | # C extensions
59 | *.so
60 |
61 | # Distribution / packaging
62 | .Python
63 | build/
64 | develop-eggs/
65 | dist/
66 | downloads/
67 | eggs/
68 | .eggs/
69 | lib/
70 | lib64/
71 | parts/
72 | sdist/
73 | var/
74 | wheels/
75 | share/python-wheels/
76 | *.egg-info/
77 | .installed.cfg
78 | *.egg
79 | MANIFEST
80 |
81 | # PyInstaller
82 | # Usually these files are written by a python script from a template
83 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
84 | *.manifest
85 | *.spec
86 |
87 | # Installer logs
88 | pip-log.txt
89 | pip-delete-this-directory.txt
90 |
91 | # Unit test / coverage reports
92 | htmlcov/
93 | .tox/
94 | .nox/
95 | .coverage
96 | .coverage.*
97 | .cache
98 | nosetests.xml
99 | coverage.xml
100 | *.cover
101 | *.py,cover
102 | .hypothesis/
103 | .pytest_cache/
104 | cover/
105 |
106 | # Translations
107 | *.mo
108 | *.pot
109 |
110 | # Django stuff:
111 | *.log
112 | local_settings.py
113 | db.sqlite3
114 | db.sqlite3-journal
115 |
116 | # Flask stuff:
117 | instance/
118 | .webassets-cache
119 |
120 | # Scrapy stuff:
121 | .scrapy
122 |
123 | # Sphinx documentation
124 | docs/_build/
125 |
126 | # PyBuilder
127 | .pybuilder/
128 | target/
129 |
130 | # Jupyter Notebook
131 | .ipynb_checkpoints
132 |
133 | # IPython
134 | profile_default/
135 | ipython_config.py
136 |
137 | # pyenv
138 | # For a library or package, you might want to ignore these files since the code is
139 | # intended to run in multiple environments; otherwise, check them in:
140 | # .python-version
141 |
142 | # pipenv
143 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
144 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
145 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
146 | # install all needed dependencies.
147 | #Pipfile.lock
148 |
149 | # poetry
150 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
151 | # This is especially recommended for binary packages to ensure reproducibility, and is more
152 | # commonly ignored for libraries.
153 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
154 | #poetry.lock
155 |
156 | # pdm
157 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
158 | #pdm.lock
159 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
160 | # in version control.
161 | # https://pdm.fming.dev/#use-with-ide
162 | .pdm.toml
163 |
164 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
165 | __pypackages__/
166 |
167 | # Celery stuff
168 | celerybeat-schedule
169 | celerybeat.pid
170 |
171 | # SageMath parsed files
172 | *.sage.py
173 |
174 | # Environments
175 | .env
176 | .venv
177 | env/
178 | venv/
179 | ENV/
180 | env.bak/
181 | venv.bak/
182 |
183 | # Spyder project settings
184 | .spyderproject
185 | .spyproject
186 |
187 | # Rope project settings
188 | .ropeproject
189 |
190 | # mkdocs documentation
191 | /site
192 |
193 | # mypy
194 | .mypy_cache/
195 | .dmypy.json
196 | dmypy.json
197 |
198 | # Pyre type checker
199 | .pyre/
200 |
201 | # pytype static type analyzer
202 | .pytype/
203 |
204 | # Cython debug symbols
205 | cython_debug/
206 |
207 | # PyCharm
208 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
209 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
210 | # and can be added to the global gitignore or merged into this file. For a more nuclear
211 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
212 | #.idea/
213 |
214 | ### Python Patch ###
215 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
216 | poetry.toml
217 |
218 | # ruff
219 | .ruff_cache/
220 |
221 | # LSP config files
222 | pyrightconfig.json
223 |
224 | ### VisualStudioCode ###
225 | .vscode/*
226 | !.vscode/settings.json
227 | !.vscode/tasks.json
228 | !.vscode/launch.json
229 | !.vscode/extensions.json
230 | !.vscode/*.code-snippets
231 |
232 | # Local History for Visual Studio Code
233 | .history/
234 |
235 | # Built Visual Studio Code Extensions
236 | *.vsix
237 |
238 | ### VisualStudioCode Patch ###
239 | # Ignore all local history of files
240 | .history
241 | .ionide
242 |
243 | # End of https://www.toptal.com/developers/gitignore/api/macos,linux,python,visualstudiocode
244 |
245 | wandb/
246 | results/*
247 | data/data_config_test.yaml
248 | checkpoints
249 | eval/inference.ipynb
250 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 Microsoft
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |

5 |
6 |
7 | [Qianhui Wu](https://qianhuiwu.github.io/)
*1
8 | [Kanzhi Cheng](https://scholar.google.com/citations?user=S2IPVnwAAAAJ&hl=en&oi=ao/)
*2
9 | [Rui Yang](https://yangrui2015.github.io/)
*3
10 | [Chaoyun Zhang](https://vyokky.github.io/)
1
11 | [Jianwei Yang](https://jwyang.github.io/)
1
12 | [Huiqiang Jiang](https://hqjiang.com/)
1
13 | [Jian Mu]()
2
14 | [Baolin Peng](https://scholar.google.com/citations?user=u1CNjgwAAAAJ&hl=zh-CN)
1
15 | [Bo Qiao](https://scholar.google.com/citations?user=_6ugrdYAAAAJ&hl=en)
1
16 | [Reuben Tan](https://cs-people.bu.edu/rxtan/)
1
17 | [Si Qin](https://sqin860.github.io/)
1
18 | [Lars Liden](https://sites.google.com/site/larsliden)
1
19 | [Qingwei Lin](https://scholar.google.com/citations?user=W9fdsxMAAAAJ&hl=zh-CN)
1
20 | [Huan Zhang](https://huan-zhang.com/)
3
21 | [Tong Zhang](https://tongzhang-ml.org/)
3
22 | [Jianbing Zhang](https://cs.nju.edu.cn/zhangjb/index.htm)
2
23 | [Dongmei Zhang](https://scholar.google.com/citations?user=jLlBBl4AAAAJ&hl=en)
1
24 | [Jianfeng Gao](https://scholar.google.com/citations?user=CQ1cqKkAAAAJ&hl=en)
1†
25 |
26 |
1 Microsoft Research
2 Nanjing University
3 University of Illinois Urbana-Champaign
27 |
* Equal Contribution
† Leadership
28 |
29 |
34 |
35 |
36 |
37 |
38 |

39 |
40 |
41 | Figure 1. **Left**: Model performance vs. training data scale on the ScreenSpot-Pro benchmark. Higher and further left is better; larger points indicate models with more parameters. We only show GUI-Actor models built upon Qwen2-VL here for fair comparison. With Qwen2.5-VL as the backbone, GUI-Actor-3B/7B reaches scores up to 42.2/44.6 (without Verifier). **Right**: Illustration of action attention. GUI-Actor grounds target elements by attending to the most relevant visual regions.
42 |
43 | ## :sparkles: Highlights
44 | 🤔 **We identify several limitations in coordinate-generation-based methods** (_i.e._, methods that output screen positions as text tokens x=…, y=…) for GUI grounding, including (1) weak spatial-semantic alignment, (2) ambiguous supervision signals, and (3) granularity mismatch between vision and action space.
45 |
46 | 💡 **Rethink how humans interact with digital interfaces**: humans do NOT calculate precise screen coordinates before acting—they perceive the target element and interact with it directly.
47 |
48 | 🚀 **We propose _GUI-Actor_, a VLM enhanced by an action head, to mitigate the above limitations.** The attention-based action head not only enables GUI-Actor to perform coordinate-free GUI grounding that more closely aligns with human behavior, but also allows it to generate multiple candidate regions in a single forward pass, offering flexibility for downstream modules such as search strategies.
49 |
50 | ➕ **We design a _grounding verifier_ to evaluate and select the most plausible action region** among the candidates proposed from the action attention map. We show that this verifier can be easily integrated with other grounding methods for a further performance boost.
51 |
52 | 🎯 **GUI-Actor achieves state-of-the-art performance on multiple GUI action grounding benchmarks** with the same Qwen2-VL backbone, demonstrating its effectiveness and generalization to unseen screen resolutions and layouts. Notably, GUI-Actor-7B even surpasses UI-TARS-72B (38.1) on **ScreenSpot-Pro**, achieving scores of **40.7** with Qwen2-VL and **44.6** with Qwen2.5-VL as backbones.
53 |
54 |
57 |
58 | ## :bookmark_tabs: Todos
59 | We will be releasing all the following contents:
60 | - [x] Model training and evaluation based on Qwen2-VL (2025.06.03)
61 | - [x] Model checkpoint (2025.06.03)
62 | - [x] Code for grounding verifier (2025.06.06)
63 | - [ ] Support for Qwen2.5-VL
64 | - [ ] Processed training data
65 | - [ ] Demo
66 |
67 | ## :bar_chart: Main Results
68 | Table 1. Main results on ScreenSpot-Pro, ScreenSpot, and ScreenSpot-v2 with **Qwen2-VL** as the backbone. † indicates scores obtained from our own evaluation of the official models on Hugging Face.
69 | | Method | Backbone VLM | ScreenSpot-Pro | ScreenSpot | ScreenSpot-v2 |
70 | |------------------|--------------|----------------|------------|----------------|
71 | | **_72B models:_**
72 | | AGUVIS-72B | Qwen2-VL | - | 89.2 | - |
73 | | UGround-V1-72B | Qwen2-VL | 34.5 | **89.4** | - |
74 | | UI-TARS-72B | Qwen2-VL | **38.1** | 88.4 | **90.3** |
75 | | **_7B models:_**
76 | | OS-Atlas-7B | Qwen2-VL | 18.9 | 82.5 | 84.1 |
77 | | AGUVIS-7B | Qwen2-VL | 22.9 | 84.4 | 86.0† |
78 | | UGround-V1-7B | Qwen2-VL | 31.1 | 86.3 | 87.6† |
79 | | UI-TARS-7B | Qwen2-VL | 35.7 | **89.5** | **91.6** |
80 | | GUI-Actor-7B | Qwen2-VL | **40.7** | 88.3 | 89.5 |
81 | | GUI-Actor-7B + Verifier | Qwen2-VL | 44.2 | 89.7 | 90.9 |
82 | | **_2B models:_**
83 | | UGround-V1-2B | Qwen2-VL | 26.6 | 77.1 | - |
84 | | UI-TARS-2B | Qwen2-VL | 27.7 | 82.3 | 84.7 |
85 | | GUI-Actor-2B | Qwen2-VL | **36.7** | **86.5** | **88.6** |
86 | | GUI-Actor-2B + Verifier | Qwen2-VL | 41.8 | 86.9 | 89.3 |
87 |
88 | Table 2. Main results on the ScreenSpot-Pro and ScreenSpot-v2 with **Qwen2.5-VL** as the backbone.
89 | | Method | Backbone VLM | ScreenSpot-Pro | ScreenSpot-v2 |
90 | |----------------|---------------|----------------|----------------|
91 | | **_7B models:_**
92 | | Qwen2.5-VL-7B | Qwen2.5-VL | 27.6 | 88.8 |
93 | | Jedi-7B | Qwen2.5-VL | 39.5 | 91.7 |
94 | | GUI-Actor-7B | Qwen2.5-VL | **44.6** | **92.1** |
95 | | GUI-Actor-7B + Verifier | Qwen2.5-VL | 47.7 | 92.5 |
96 | | **_3B models:_**
97 | | Qwen2.5-VL-3B | Qwen2.5-VL | 25.9 | 80.9 |
98 | | Jedi-3B | Qwen2.5-VL | 36.1 | 88.6 |
99 | | GUI-Actor-3B | Qwen2.5-VL | **42.2** | **91.0** |
100 | | GUI-Actor-3B + Verifier | Qwen2.5-VL | 45.9 | 92.4 |
101 |
102 | ## :rescue_worker_helmet: Installation
103 | 1. Clone this repo to your local machine:
104 | ```bash
105 | git clone https://github.com/microsoft/GUI-Actor.git
106 | cd GUI-Actor
107 | ```
108 | 2. Create a conda environment and install the dependencies:
109 | ```bash
110 | conda create -n gui_actor python=3.10
111 | conda activate gui_actor
112 | conda install pytorch torchvision torchaudio pytorch-cuda -c pytorch -c nvidia
113 | pip install -e .
114 | ```
115 | ## :minidisc: Data Preparation
116 | 1. Download the processed data from [here (coming soon)]().
117 | 2. Modify the paths in the [data_config.yaml](./data/data_config.yaml) file to point to the downloaded data.
118 |
119 | ## :building_construction: Model Training
120 | 1. Warmup stage:
121 | ```bash
122 | bash scripts/warmup.sh
123 | ```
124 | 2. Full-parameter training stage:
125 | ```bash
126 | bash scripts/train.sh
127 | ```
128 |
129 | ## :checkered_flag: Evaluation on GUI Grounding Benchmarks
130 | For evaluation on ScreenSpot and ScreenSpot-v2, you can directly run the scripts under the `eval/` folder, e.g., `python eval/screenSpot.py` or `python eval/screenSpot_v2.py`.
131 |
132 | For evaluation on ScreenSpot-Pro, you first need to download the data from [here](https://huggingface.co/datasets/likaixin/ScreenSpot-Pro), then run the following command:
133 | ```bash
134 | python eval/screenSpot_pro.py --save_path <path_to_save_results> --data_path <path_to_ScreenSpot-Pro_data>
135 | ```
136 |
137 | Example usage:
138 | ```python
139 | import torch
140 |
141 | from qwen_vl_utils import process_vision_info
142 | from datasets import load_dataset
143 | from transformers import AutoProcessor
144 | from gui_actor.constants import chat_template
145 | from gui_actor.modeling import Qwen2VLForConditionalGenerationWithPointer
146 | from gui_actor.inference import inference
147 |
148 |
149 | # load model
150 | model_name_or_path = "microsoft/GUI-Actor-7B-Qwen2-VL"
151 | data_processor = AutoProcessor.from_pretrained(model_name_or_path)
152 | tokenizer = data_processor.tokenizer
153 | model = Qwen2VLForConditionalGenerationWithPointer.from_pretrained(
154 | model_name_or_path,
155 | torch_dtype=torch.bfloat16,
156 | device_map="cuda:0",
157 | attn_implementation="flash_attention_2"
158 | ).eval()
159 |
160 | # prepare example
161 | dataset = load_dataset("rootsautomation/ScreenSpot")["test"]
162 | example = dataset[0]
163 | print(f"Intruction: {example['instruction']}")
164 | print(f"ground-truth action region (x1, y1, x2, y2): {[round(i, 2) for i in example['bbox']]}")
165 |
166 | conversation = [
167 | {
168 | "role": "system",
169 | "content": [
170 | {
171 | "type": "text",
172 | "text": "You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task.",
173 | }
174 | ]
175 | },
176 | {
177 | "role": "user",
178 | "content": [
179 | {
180 | "type": "image",
181 | "image": example["image"], # PIL.Image.Image or str to path
182 | # "image_url": "https://xxxxx.png" or "https://xxxxx.jpg" or "file://xxxxx.png" or "", will be split by "base64,"
183 | },
184 | {
185 | "type": "text",
186 | "text": example["instruction"]
187 | },
188 | ],
189 | },
190 | ]
191 |
192 | # inference
193 | pred = inference(conversation, model, tokenizer, data_processor, use_placeholder=True, topk=3)
194 | px, py = pred["topk_points"][0]
195 | print(f"Predicted click point: [{round(px, 4)}, {round(py, 4)}]")
196 |
197 | # >> Model Response
198 | # Intruction: close this window
199 | # ground-truth action region (x1, y1, x2, y2): [0.9479, 0.1444, 0.9938, 0.2074]
200 | # Predicted click point: [0.9709, 0.1548]
201 | ```
202 |
203 | ## :+1: Acknowledgements
204 |
205 | This project is built upon the following projects. Thanks for their great work!
206 | - [Transformers](https://github.com/huggingface/transformers)
207 | - [Qwen2.5-VL](https://github.com/QwenLM/Qwen2.5-VL)
208 | - [AGUVIS](https://github.com/xlang-ai/aguvis)
209 |
210 | We also thank the authors of the following projects for their insightful work, as well as for providing datasets and engaging in valuable discussions.
211 | - [AGUVIS](https://github.com/xlang-ai/aguvis)
212 | - [UGround](https://github.com/OSU-NLP-Group/UGround)
213 | - [OS-Atlas](https://github.com/OS-Copilot/OS-Atlas)
214 | - [SeeClick](https://github.com/njucckevin/SeeClick)
215 |
216 | ## :memo: Citation
217 | If you find this work useful in your research, please consider citing:
218 | ```bibtex
219 | @misc{wu2025guiactor,
220 | title={GUI-Actor: Coordinate-Free Visual Grounding for GUI Agents},
221 | author={Qianhui Wu and Kanzhi Cheng and Rui Yang and Chaoyun Zhang and Jianwei Yang and Huiqiang Jiang and Jian Mu and Baolin Peng and Bo Qiao and Reuben Tan and Si Qin and Lars Liden and Qingwei Lin and Huan Zhang and Tong Zhang and Jianbing Zhang and Dongmei Zhang and Jianfeng Gao},
222 | year={2025},
223 | eprint={2506.03143},
224 | archivePrefix={arXiv},
225 | primaryClass={cs.CV},
226 | url={https://arxiv.org/abs/2506.03143},
227 | }
228 | ```
229 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | ## Security
4 |
5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin).
6 |
7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below.
8 |
9 | ## Reporting Security Issues
10 |
11 | **Please do not report security vulnerabilities through public GitHub issues.**
12 |
13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report).
14 |
15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp).
16 |
17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc).
18 |
19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue:
20 |
21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.)
22 | * Full paths of source file(s) related to the manifestation of the issue
23 | * The location of the affected source code (tag/branch/commit or direct URL)
24 | * Any special configuration required to reproduce the issue
25 | * Step-by-step instructions to reproduce the issue
26 | * Proof-of-concept or exploit code (if possible)
27 | * Impact of the issue, including how an attacker might exploit the issue
28 |
29 | This information will help us triage your report more quickly.
30 |
31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs.
32 |
33 | ## Preferred Languages
34 |
35 | We prefer all communications to be in English.
36 |
37 | ## Policy
38 |
39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd).
40 |
41 |
42 |
--------------------------------------------------------------------------------
/assets/images/main_figure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/GUI-Actor/30e148da6d719117444f9d05a944561f31e362f1/assets/images/main_figure.png
--------------------------------------------------------------------------------
/assets/images/title.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/GUI-Actor/30e148da6d719117444f9d05a944561f31e362f1/assets/images/title.png
--------------------------------------------------------------------------------
/data/data_config.yaml:
--------------------------------------------------------------------------------
1 | datasets:
2 | - json_path: /mnt/datasets/Uground/uground_aguvis_bbox_filter.json
3 | images_folder: /mnt/datasets/Uground/images/
4 | sampling_strategy: "all"
5 | - json_path: /mnt/datasets/GUIEnv/guienv_aguvis_bbox.json
6 | images_folder: /mnt/datasets/GUIEnv/guienvs/images/
7 | sampling_strategy: "all"
8 | - json_path: /mnt/datasets/GUIAct/guiact_aguvis_bbox.json
9 | images_folder: /mnt/datasets/GUIAct/web_imgs/
10 | sampling_strategy: "all"
11 | - json_path: /mnt/datasets/AMEX/amex_aguvis_bbox.json
12 | images_folder: /mnt/datasets/AMEX/screenshots/
13 | sampling_strategy: "all"
14 | - json_path: /mnt/datasets/AndroidControl/androidcontrol_aguvis_bbox.json
15 | images_folder: /mnt/datasets/AndroidControl/tfrecord/images/
16 | sampling_strategy: "all"
17 | - json_path: /mnt/datasets/Wave-UI/wave_ui_aguvis_bbox_fixed.json
18 | images_folder: /mnt/datasets/Wave-UI/images_fixed/
19 | sampling_strategy: "all"
--------------------------------------------------------------------------------
/eval/screenSpot.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import json
4 | import argparse
5 |
6 | from tqdm import tqdm
7 | from datasets import load_dataset
8 | from transformers import AutoProcessor
9 |
10 | from gui_actor.constants import chat_template
11 | from gui_actor.modeling import Qwen2VLForConditionalGenerationWithPointer
12 | from gui_actor.modeling_qwen25vl import Qwen2_5_VLForConditionalGenerationWithPointer
13 | from gui_actor.inference import inference, ForceFollowTokensLogitsProcessor
14 | from gui_actor.utils import do_boxes_overlap
15 | from gui_actor.constants import DEFAULT_POINTER_PAD_TOKEN, DEFAULT_POINTER_END_TOKEN
16 |
17 | IMAGE_PATCH_SIZE =14
18 |
def normalize_bbox(bbox_x1y1x2y2, img_width, img_height):
    """Express an (x1, y1, x2, y2) box in coordinates relative to the image size.

    Heuristic: if every coordinate already lies in [0, 1], the box is assumed to
    be normalized and the *same object* that was passed in is returned untouched.
    Otherwise the x values are divided by ``img_width`` and the y values by
    ``img_height``, and a new 4-tuple is returned.

    NOTE(review): a pixel-space box whose coordinates all happen to be <= 1
    would be mistaken for an already-normalized one.

    Args:
        bbox_x1y1x2y2: sequence of four numbers (x1, y1, x2, y2).
        img_width: image width used to scale x1 and x2.
        img_height: image height used to scale y1 and y2.

    Returns:
        The input object if it is already normalized, else a tuple
        (x1 / img_width, y1 / img_height, x2 / img_width, y2 / img_height).
    """
    left, top, right, bottom = bbox_x1y1x2y2
    already_normalized = all(0 <= coord <= 1 for coord in (left, top, right, bottom))
    if already_normalized:
        return bbox_x1y1x2y2
    return (
        left / img_width,
        top / img_height,
        right / img_width,
        bottom / img_height,
    )
30 |
def _score_topk_points(topk_points, gt_bbox, half_width, half_height):
    """Score ranked click points against a ground-truth box.

    Args:
        topk_points: Ranked ``(x, y)`` points, best first. They are compared
            directly with ``gt_bbox``, so they must use the same (normalized)
            coordinate space.
        gt_bbox: Ground-truth box ``(x1, y1, x2, y2)``.
        half_width: Half-width of the box drawn around each point for the
            overlap metrics, in the same units as ``gt_bbox``.
        half_height: Half-height of that box, in the same units as ``gt_bbox``.

    Returns:
        Dict with 0/1 values for ``hit_top1``, ``overlap_top1``, ``hit_topk``
        and ``overlap_topk``. An empty ``topk_points`` scores all zeros.
    """
    x1, y1, x2, y2 = gt_bbox
    scores = {"hit_top1": 0, "overlap_top1": 0, "hit_topk": 0, "overlap_topk": 0}
    for rank, (px, py) in enumerate(topk_points):
        hit = bool((x1 <= px <= x2) and (y1 <= py <= y2))
        pred_bbox = [px - half_width, py - half_height, px + half_width, py + half_height]
        overlap = bool(do_boxes_overlap(pred_bbox, gt_bbox))
        if rank == 0:
            scores["hit_top1"] = int(hit)
            scores["overlap_top1"] = int(overlap)
        if hit:
            scores["hit_topk"] = 1
        if overlap:
            scores["overlap_topk"] = 1
    return scores


def evaluate(model_name_or_path, model_type, use_placeholder, topk):
    """Run GUI-Actor on the ScreenSpot test split and score every prediction.

    Args:
        model_name_or_path: Hub id or local path used for both the processor and the model.
        model_type: ``"qwen2vl"`` or ``"qwen25vl"``; selects the model class and the system prompt.
        use_placeholder: Forwarded unchanged to ``inference``.
        topk: Number of candidate points requested from ``inference``; the
            ``*_topk`` metrics consider every returned point.

    Returns:
        One dict per test example holding its metadata, the normalized
        ground-truth box ``bbox_x1y1x2y2`` and the 0/1 metrics ``hit_top1``,
        ``overlap_top1``, ``hit_topk`` and ``overlap_topk``. The overlap metrics
        use a box of ``2 * IMAGE_PATCH_SIZE`` pixels around each predicted point.

    Raises:
        ValueError: If ``model_type`` is not one of the supported values.
    """
    # initialize model
    data_processor = AutoProcessor.from_pretrained(model_name_or_path)
    tokenizer = data_processor.tokenizer
    # Print the tokenizer's added tokens for inspection.
    for k, v in tokenizer.added_tokens_encoder.items():
        print(v, k)

    if model_type == "qwen2vl":
        print(f"Loading model with Qwen2-VL backbone from {model_name_or_path}")
        model = Qwen2VLForConditionalGenerationWithPointer.from_pretrained(
            model_name_or_path,
            torch_dtype=torch.bfloat16,
            device_map="cuda:0",
            attn_implementation="flash_attention_2"
        ).eval()
        grounding_system_message = "You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task."
    elif model_type == "qwen25vl":
        print(f"Loading model with Qwen2.5-VL backbone from {model_name_or_path}")
        model = Qwen2_5_VLForConditionalGenerationWithPointer.from_pretrained(
            model_name_or_path,
            torch_dtype=torch.bfloat16,
            device_map="cuda:0",
            attn_implementation="flash_attention_2"
        ).eval()
        grounding_system_message = "You are a GUI agent. Given a screenshot of the current GUI and a human instruction, your task is to locate the screen element that corresponds to the instruction. You should output a PyAutoGUI action that performs a click on the correct position. To indicate the click location, we will use some special tokens, which is used to refer to a visual patch later. For example, you can output: pyautogui.click()."
    else:
        raise ValueError(f"Invalid model type: {model_type}")
    print(f"Loaded model from {model_name_or_path}")

    # Presumably forces the pointer-end token right after the pointer-pad token
    # during generation (see ForceFollowTokensLogitsProcessor) - confirm.
    logits_processor_pointer = ForceFollowTokensLogitsProcessor(
        token_a_id=tokenizer.encode(DEFAULT_POINTER_PAD_TOKEN)[0],
        forced_sequence=[
            tokenizer.encode(DEFAULT_POINTER_END_TOKEN)[0]
        ]
    )

    dataset = load_dataset("rootsautomation/ScreenSpot")["test"]
    # Map each data_source to the coarse domain reported in the metric table.
    domain_dict = {
        "windows": "desktop",
        "macos": "desktop",
        "ios": "mobile",
        "android": "mobile",
        "tool": "web",
        "shop": "web",
        "gitlab": "web",
        "forum": "web"
    }

    results = []
    for example in tqdm(dataset, total=len(dataset)):
        image = example["image"]
        img_width, img_height = image.size
        ele = {
            "file_name": example["file_name"],
            "data_type": example["data_type"],
            "domain": domain_dict[example["data_source"]],
            "instruction": example["instruction"],
            "img_size": image.size,
            "bbox_x1y1x2y2": normalize_bbox(example["bbox"], img_width, img_height),
            "hit_top1": 0,
            "overlap_top1": 0,
            "hit_topk": 0,
            "overlap_topk": 0,
        }

        conversation = [
            {
                "role": "system",
                "content": [{"type": "text", "text": grounding_system_message}],
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},  # PIL.Image.Image or str to path
                    {"type": "text", "text": example["instruction"]},
                ],
            },
        ]

        pred = inference(
            conversation,
            model,
            tokenizer,
            data_processor,
            logits_processor=logits_processor_pointer,
            use_placeholder=use_placeholder,
            topk=topk,
        )

        # Predicted points are compared with the normalized GT box, so the
        # IMAGE_PATCH_SIZE pixel margin must also be expressed as a fraction of
        # the image size. A raw pixel margin (14 >> 1) would cover the whole
        # image and make every prediction "overlap".
        ele.update(_score_topk_points(
            pred["topk_points"],
            ele["bbox_x1y1x2y2"],
            IMAGE_PATCH_SIZE / img_width,
            IMAGE_PATCH_SIZE / img_height,
        ))
        results.append(ele)

    return results
147 |
148 |
def get_metric(list_of_examples,
               domains=("mobile", "desktop", "web"),
               data_types=("text", "icon")):
    """Aggregate per-example 0/1 metrics into a table, print it, and return it.

    Each element of ``list_of_examples`` is a dict containing:
        - "domain": domain name (e.g. "web", "mobile", "desktop")
        - "data_type": element type (e.g. "text", "icon")
        - "hit_top1", "overlap_top1", "hit_topk", "overlap_topk": 0 or 1
          (a missing key counts as 0)

    Columns, in order:
        - "<domain>-<data_type>" for every domain/data type pair;
        - "<domain>-avg" over every example of that domain;
        - "All-<data_type>" for each data type;
        - "All-avg" over all examples.

    Rows: hit_top1, overlap_top1, hit_topk, overlap_topk. Cells are percentages
    with two decimals, or "N/A" when a column has no examples.

    Args:
        list_of_examples: Per-example result dicts (e.g. from ``evaluate``).
        domains: Domains to report, in column order.
        data_types: Data types to report, in column order.

    Returns:
        The tab-delimited table (preceded by a title line) that is also
        printed, ready to paste into a spreadsheet.
    """
    metrics = ["hit_top1", "overlap_top1", "hit_topk", "overlap_topk"]

    def compute_mean(examples, key):
        # None marks an empty column and is rendered as "N/A".
        if not examples:
            return None
        return sum(example.get(key, 0) for example in examples) / len(examples)

    def format_cell(cell):
        if isinstance(cell, float):
            return f"{cell * 100:.2f}"
        if cell is None:
            return "N/A"
        return str(cell)

    # Column name -> examples that column averages over (dicts keep insertion order).
    column_subsets = {}
    for domain in domains:
        domain_examples = [ex for ex in list_of_examples if ex.get("domain") == domain]
        for data_type in data_types:
            column_subsets[f"{domain}-{data_type}"] = [
                ex for ex in domain_examples if ex.get("data_type") == data_type
            ]
        column_subsets[f"{domain}-avg"] = domain_examples
    for data_type in data_types:
        column_subsets[f"All-{data_type}"] = [
            ex for ex in list_of_examples if ex.get("data_type") == data_type
        ]
    column_subsets["All-avg"] = list(list_of_examples)
    columns_order = list(column_subsets)

    results = {
        metric: {col: compute_mean(subset, metric) for col, subset in column_subsets.items()}
        for metric in metrics
    }
    table_rows = [
        [metric] + [format_cell(results[metric][col]) for col in columns_order]
        for metric in metrics
    ]

    # ------------- Print Table to Console -------------
    header = [""] + columns_order
    col_widths = [max(len(col), 12) for col in header]

    def join_padded(cells):
        return " | ".join(cell.ljust(width) for cell, width in zip(cells, col_widths))

    print(join_padded(header))
    print("-+-".join("-" * width for width in col_widths))
    for row in table_rows:
        print(join_padded(row))

    # ------------- Tab-delimited Version (for Excel Copy-Paste) -------------
    metric_info = "Tab-delimited Table for Excel:\n"
    metric_info += "\t".join(header) + "\n"
    for row in table_rows:
        metric_info += "\t".join(row) + "\n"
    print(metric_info)
    return metric_info
258 |
259 |
"""
# cd to project root directory
python eval/screenSpot.py --save_path ./results/screenspot
"""
if __name__ == "__main__":
    import sys

    parser = argparse.ArgumentParser(description="Evaluate a GUI-Actor checkpoint on ScreenSpot.")
    parser.add_argument("--model_type", type=str, default="qwen25vl", choices=["qwen2vl", "qwen25vl"])
    parser.add_argument("--model_name_or_path", type=str, default="qianhuiwu/GUI-Actor-3B-Qwen-2.5-VL")
    parser.add_argument("--save_path", type=str, default="./")
    parser.add_argument('--topk', type=int, default=3, help='Number of candidate points scored by the top-k metrics')
    parser.add_argument('--no-placeholder', dest='use_placeholder', action='store_false', help='Disable the placeholder')
    parser.set_defaults(use_placeholder=True)

    args = parser.parse_args()

    save_path = args.save_path
    os.makedirs(save_path, exist_ok=True)
    pred_path = os.path.join(save_path, "screenspot_all_preds.json")
    metric_path = os.path.join(save_path, "screenspot_all_metrics.txt")

    # A finished run is never recomputed; delete the metric file to re-score.
    if os.path.exists(metric_path):
        print(f"Metrics already exist at {metric_path}; skipping.")
        sys.exit(0)

    # Predictions cached by an earlier run are reused as-is (no model is loaded).
    if os.path.exists(pred_path):
        print(f"Loading predictions from {pred_path}")
        with open(pred_path, "r") as f:
            results = json.load(f)
    else:
        print(f"Evaluating {args.model_name_or_path}...")
        results = evaluate(args.model_name_or_path, args.model_type, args.use_placeholder, args.topk)
        with open(pred_path, "w") as f:
            json.dump(results, f)
        print(f"Saved {len(results)} predictions to {pred_path}")

    metric_info = get_metric(results)
    with open(metric_path, "w") as f:
        f.write(metric_info)
    print(f"Saved metric to {metric_path}")
300 |
--------------------------------------------------------------------------------
/eval/screenSpot_pro.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import json
4 | import argparse
5 |
6 | from tqdm import tqdm
7 | from datasets import load_dataset
8 | from transformers import AutoProcessor
9 | from PIL import Image
10 | from gui_actor.constants import chat_template
11 | from gui_actor.modeling import Qwen2VLForConditionalGenerationWithPointer
12 | from gui_actor.modeling_qwen25vl import Qwen2_5_VLForConditionalGenerationWithPointer
13 | from gui_actor.inference import inference, ForceFollowTokensLogitsProcessor
14 | from gui_actor.utils import do_boxes_overlap
15 | from gui_actor.constants import DEFAULT_POINTER_PAD_TOKEN, DEFAULT_POINTER_END_TOKEN
16 |
IMAGE_PATCH_SIZE = 14

def normalize_bbox(bbox_x1y1x2y2, img_width, img_height):
    """Express an ``(x1, y1, x2, y2)`` box in [0, 1] image coordinates.

    If all four values already lie in [0, 1], the box is treated as normalized
    and the very same object is returned. Otherwise the x values are divided
    by ``img_width`` and the y values by ``img_height``, and a tuple is returned.
    """
    x1, y1, x2, y2 = bbox_x1y1x2y2
    already_normalized = all(0 <= coord <= 1 for coord in (x1, y1, x2, y2))
    if already_normalized:
        return bbox_x1y1x2y2
    return x1 / img_width, y1 / img_height, x2 / img_width, y2 / img_height
30 |
def _score_topk_points(topk_points, gt_bbox, half_width, half_height):
    """Score ranked click points against a ground-truth box.

    Args:
        topk_points: Ranked ``(x, y)`` points, best first. They are compared
            directly with ``gt_bbox``, so they must use the same (normalized)
            coordinate space.
        gt_bbox: Ground-truth box ``(x1, y1, x2, y2)``.
        half_width: Half-width of the box drawn around each point for the
            overlap metrics, in the same units as ``gt_bbox``.
        half_height: Half-height of that box, in the same units as ``gt_bbox``.

    Returns:
        Dict with 0/1 values for ``hit_top1``, ``overlap_top1``, ``hit_topk``
        and ``overlap_topk``. An empty ``topk_points`` scores all zeros.
    """
    x1, y1, x2, y2 = gt_bbox
    scores = {"hit_top1": 0, "overlap_top1": 0, "hit_topk": 0, "overlap_topk": 0}
    for rank, (px, py) in enumerate(topk_points):
        hit = bool((x1 <= px <= x2) and (y1 <= py <= y2))
        pred_bbox = [px - half_width, py - half_height, px + half_width, py + half_height]
        overlap = bool(do_boxes_overlap(pred_bbox, gt_bbox))
        if rank == 0:
            scores["hit_top1"] = int(hit)
            scores["overlap_top1"] = int(overlap)
        if hit:
            scores["hit_topk"] = 1
        if overlap:
            scores["overlap_topk"] = 1
    return scores


def evaluate(model_name_or_path, model_type, data_fn, image_dir, use_placeholder, topk, resize_to_pixels=None):
    """Run GUI-Actor on ScreenSpot-Pro and score every prediction.

    Args:
        model_name_or_path: Hub id or local path used for both the processor and the model.
        model_type: ``"qwen2vl"`` or ``"qwen25vl"``; selects the model class and the system prompt.
        data_fn: Path to the JSON annotation list. Each entry is read for
            ``img_filename``, ``ui_type``, ``group``, ``platform``,
            ``application``, ``id``, ``instruction``, ``img_size`` and ``bbox``.
        image_dir: Directory that each ``img_filename`` is relative to.
        use_placeholder: Forwarded unchanged to ``inference``.
        topk: Number of candidate points requested from ``inference``; the
            ``*_topk`` metrics consider every returned point.
        resize_to_pixels: If given, every image whose annotated area differs
            from it is resized (aspect ratio kept) to roughly this many pixels
            before inference. ``None`` keeps the original resolution.

    Returns:
        One dict per example holding its metadata, the normalized ground-truth
        box ``bbox_x1y1x2y2``, ``img_size_resized`` (or None) and the 0/1
        metrics ``hit_top1``, ``overlap_top1``, ``hit_topk`` and
        ``overlap_topk``. The overlap metrics use a box of
        ``2 * IMAGE_PATCH_SIZE`` pixels of the model-input image around each
        predicted point.

    Raises:
        ValueError: If ``model_type`` is not one of the supported values.
    """
    # initialize model
    data_processor = AutoProcessor.from_pretrained(model_name_or_path)
    tokenizer = data_processor.tokenizer
    # Print the tokenizer's added tokens for inspection.
    for k, v in tokenizer.added_tokens_encoder.items():
        print(v, k)

    if model_type == "qwen2vl":
        print(f"Loading model with Qwen2-VL backbone from {model_name_or_path}")
        model = Qwen2VLForConditionalGenerationWithPointer.from_pretrained(
            model_name_or_path,
            torch_dtype=torch.bfloat16,
            device_map="cuda:0",
            attn_implementation="flash_attention_2"
        ).eval()
        grounding_system_message = "You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task."
    elif model_type == "qwen25vl":
        print(f"Loading model with Qwen2.5-VL backbone from {model_name_or_path}")
        model = Qwen2_5_VLForConditionalGenerationWithPointer.from_pretrained(
            model_name_or_path,
            torch_dtype=torch.bfloat16,
            device_map="cuda:0",
            attn_implementation="flash_attention_2"
        ).eval()
        grounding_system_message = "You are a GUI agent. Given a screenshot of the current GUI and a human instruction, your task is to locate the screen element that corresponds to the instruction. You should output a PyAutoGUI action that performs a click on the correct position. To indicate the click location, we will use some special tokens, which is used to refer to a visual patch later. For example, you can output: pyautogui.click()."
    else:
        raise ValueError(f"Invalid model type: {model_type}")
    print(f"Loaded model from {model_name_or_path}")

    # Presumably forces the pointer-end token right after the pointer-pad token
    # during generation (see ForceFollowTokensLogitsProcessor) - confirm.
    logits_processor_pointer = ForceFollowTokensLogitsProcessor(
        token_a_id=tokenizer.encode(DEFAULT_POINTER_PAD_TOKEN)[0],
        forced_sequence=[
            tokenizer.encode(DEFAULT_POINTER_END_TOKEN)[0]
        ]
    )

    # load data
    with open(data_fn, "r") as f:
        data = json.load(f)
    print(f"Loaded {len(data)} examples from {data_fn}")

    results = []
    for example in tqdm(data, total=len(data)):
        image_width, image_height = example["img_size"]
        ele = {
            "file_name": example["img_filename"],
            "ui_type": example["ui_type"],
            "group": example["group"],
            "platform": example["platform"],
            "application": example["application"],
            "id": example["id"],
            "instruction": example["instruction"],
            "img_size": example["img_size"],
            "bbox_x1y1x2y2": normalize_bbox(example["bbox"], image_width, image_height),
            "hit_top1": 0,
            "overlap_top1": 0,
            "hit_topk": 0,
            "overlap_topk": 0,
        }

        image_path = os.path.join(image_dir, example["img_filename"])
        # Work on an in-memory image so the file handle is released immediately.
        with Image.open(image_path) as image_file:
            if (resize_to_pixels is not None) and ((image_width * image_height) != resize_to_pixels):
                resize_ratio = (resize_to_pixels / (image_width * image_height)) ** 0.5
                image_width_resized = int(image_width * resize_ratio)
                image_height_resized = int(image_height * resize_ratio)
                image = image_file.resize((image_width_resized, image_height_resized))
                ele["img_size_resized"] = [image_width_resized, image_height_resized]
            else:
                image = image_file.copy()
                ele["img_size_resized"] = None

        conversation = [
            {
                "role": "system",
                "content": [{"type": "text", "text": grounding_system_message}],
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},  # PIL.Image.Image or str to path
                    {"type": "text", "text": example["instruction"]},
                ],
            },
        ]

        pred = inference(
            conversation,
            model,
            tokenizer,
            data_processor,
            logits_processor=logits_processor_pointer,
            use_placeholder=use_placeholder,
            topk=topk,
        )

        # Predicted points are compared with the normalized GT box, so the
        # IMAGE_PATCH_SIZE pixel margin is converted into a fraction of the
        # image actually given to the model. A raw pixel margin (14 >> 1)
        # would cover the whole image and make every prediction "overlap".
        model_width, model_height = image.size
        ele.update(_score_topk_points(
            pred["topk_points"],
            ele["bbox_x1y1x2y2"],
            IMAGE_PATCH_SIZE / model_width,
            IMAGE_PATCH_SIZE / model_height,
        ))
        results.append(ele)

    return results
155 |
156 |
def get_metric(list_of_examples,
               groups=("Dev", "Creative", "CAD", "Scientific", "Office", "OS"),
               ui_types=("text", "icon")):
    """Aggregate per-example 0/1 metrics into a table, print it, and return it.

    Each element of ``list_of_examples`` is a dict containing:
        - "group": group name (e.g. "Dev", "Creative", ...)
        - "ui_type": UI type (e.g. "text", "icon")
        - "hit_top1", "overlap_top1", "hit_topk", "overlap_topk": 0 or 1
          (a missing key counts as 0)

    Columns, in order:
        - "<group>-<ui_type>" for every group/UI type pair;
        - "<group>-avg" over every example of that group;
        - "All-<ui_type>" for each UI type;
        - "All-avg" over all examples.

    Rows: hit_top1, overlap_top1, hit_topk, overlap_topk. Cells are percentages
    with two decimals, or "N/A" when a column has no examples.

    Args:
        list_of_examples: Per-example result dicts (e.g. from ``evaluate``).
        groups: Groups to report, in column order.
        ui_types: UI types to report, in column order.

    Returns:
        The tab-delimited table (preceded by a title line) that is also
        printed, ready to paste into a spreadsheet.
    """
    metrics = ["hit_top1", "overlap_top1", "hit_topk", "overlap_topk"]

    def compute_mean(examples, key):
        # None marks an empty column and is rendered as "N/A".
        if not examples:
            return None
        return sum(example.get(key, 0) for example in examples) / len(examples)

    def format_cell(cell):
        if isinstance(cell, float):
            return f"{cell * 100:.2f}"
        if cell is None:
            return "N/A"
        return str(cell)

    # Column name -> examples that column averages over (dicts keep insertion order).
    column_subsets = {}
    for group in groups:
        group_examples = [ex for ex in list_of_examples if ex.get("group") == group]
        for ui in ui_types:
            column_subsets[f"{group}-{ui}"] = [ex for ex in group_examples if ex.get("ui_type") == ui]
        column_subsets[f"{group}-avg"] = group_examples
    for ui in ui_types:
        column_subsets[f"All-{ui}"] = [ex for ex in list_of_examples if ex.get("ui_type") == ui]
    column_subsets["All-avg"] = list(list_of_examples)
    columns_order = list(column_subsets)

    results = {
        metric: {col: compute_mean(subset, metric) for col, subset in column_subsets.items()}
        for metric in metrics
    }
    table_rows = [
        [metric] + [format_cell(results[metric][col]) for col in columns_order]
        for metric in metrics
    ]

    # ------------- Print Table to Console -------------
    header = [""] + columns_order
    col_widths = [max(len(col), 12) for col in header]

    def join_padded(cells):
        return " | ".join(cell.ljust(width) for cell, width in zip(cells, col_widths))

    print(join_padded(header))
    print("-+-".join("-" * width for width in col_widths))
    for row in table_rows:
        print(join_padded(row))

    # ------------- Tab-delimited Version (for Excel Copy-Paste) -------------
    metric_info = "Tab-delimited Table for Excel:\n"
    metric_info += "\t".join(header) + "\n"
    for row in table_rows:
        metric_info += "\t".join(row) + "\n"
    print(metric_info)
    return metric_info
266 |
267 |
"""
# cd to project root directory
python eval/screenSpot_pro.py --save_path ./results/screenspot_pro --data_path /path/to/ScreenSpot-Pro
"""
if __name__ == "__main__":
    import sys

    parser = argparse.ArgumentParser(description="Evaluate a GUI-Actor checkpoint on ScreenSpot-Pro.")
    parser.add_argument("--model_type", type=str, default="qwen25vl", choices=["qwen2vl", "qwen25vl"])
    parser.add_argument("--model_name_or_path", type=str, default="microsoft/GUI-Actor-7B-Qwen2.5-VL")
    parser.add_argument("--save_path", type=str, default="./")
    parser.add_argument("--data_path", type=str, default="/mnt/data/ScreenSpot-Pro")
    parser.add_argument("--resize_to_pixels", type=int, default=3200*1800,
                        help="Resize each image to about this many pixels (aspect ratio kept); values <= 0 disable resizing.")
    parser.add_argument('--no-placeholder', dest='use_placeholder', action='store_false', help='Disable the placeholder')
    parser.add_argument('--topk', type=int, default=3, help='Number of candidate points scored by the top-k metrics')
    parser.set_defaults(use_placeholder=True)

    args = parser.parse_args()

    # Non-positive values switch resizing off.
    resize_to_pixels = args.resize_to_pixels if args.resize_to_pixels > 0 else None
    image_dir = os.path.join(args.data_path, "images")
    data_fn = os.path.join(args.data_path, "annotations", "all.json")
    save_path = args.save_path
    os.makedirs(save_path, exist_ok=True)
    pred_path = os.path.join(save_path, "screenspot-Pro_all_preds_StandardResize.json")
    metric_path = os.path.join(save_path, "screenspot-Pro_all_preds_StandardResize.txt")

    # A finished run is never recomputed; delete the metric file to re-score.
    if os.path.exists(metric_path):
        print(f"Metrics already exist at {metric_path}; skipping.")
        sys.exit(0)

    # Predictions cached by an earlier run are reused as-is (no model is loaded).
    if os.path.exists(pred_path):
        print(f"Loading predictions from {pred_path}")
        with open(pred_path, "r") as f:
            results = json.load(f)
    else:
        print(f"Evaluating {args.model_name_or_path}...")
        results = evaluate(args.model_name_or_path, args.model_type, data_fn, image_dir,
                           args.use_placeholder, args.topk, resize_to_pixels)
        with open(pred_path, "w") as f:
            json.dump(results, f)
        print(f"Saved {len(results)} predictions to {pred_path}")

    metric_info = get_metric(results)
    with open(metric_path, "w") as f:
        f.write(metric_info)
    print(f"Saved metric to {metric_path}")
313 |
--------------------------------------------------------------------------------
/eval/screenSpot_v2.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 | import json
4 | import argparse
5 |
6 | from tqdm import tqdm
7 | from datasets import load_dataset
8 | from transformers import AutoProcessor
9 |
10 | from gui_actor.constants import chat_template
11 | from gui_actor.modeling import Qwen2VLForConditionalGenerationWithPointer
12 | from gui_actor.modeling_qwen25vl import Qwen2_5_VLForConditionalGenerationWithPointer
13 | from gui_actor.inference import inference, ForceFollowTokensLogitsProcessor
14 | from gui_actor.utils import do_boxes_overlap
15 | from gui_actor.constants import DEFAULT_POINTER_PAD_TOKEN, DEFAULT_POINTER_END_TOKEN
16 |
IMAGE_PATCH_SIZE = 14

def normalize_bbox(bbox_x1y1x2y2, img_width, img_height):
    # Boxes whose four coordinates all fall in [0, 1] are taken to be normalized
    # already and are handed back untouched (same object). Pixel boxes are
    # scaled by the image size and returned as a tuple.
    x1, y1, x2, y2 = bbox_x1y1x2y2
    if any(not (0 <= value <= 1) for value in (x1, y1, x2, y2)):
        return (x1 / img_width, y1 / img_height, x2 / img_width, y2 / img_height)
    return bbox_x1y1x2y2
30 |
def _score_topk_points(topk_points, gt_bbox, half_width, half_height):
    """Score ranked click points against a ground-truth box.

    Args:
        topk_points: Ranked ``(x, y)`` points, best first. They are compared
            directly with ``gt_bbox``, so they must use the same (normalized)
            coordinate space.
        gt_bbox: Ground-truth box ``(x1, y1, x2, y2)``.
        half_width: Half-width of the box drawn around each point for the
            overlap metrics, in the same units as ``gt_bbox``.
        half_height: Half-height of that box, in the same units as ``gt_bbox``.

    Returns:
        Dict with 0/1 values for ``hit_top1``, ``overlap_top1``, ``hit_topk``
        and ``overlap_topk``. An empty ``topk_points`` scores all zeros.
    """
    x1, y1, x2, y2 = gt_bbox
    scores = {"hit_top1": 0, "overlap_top1": 0, "hit_topk": 0, "overlap_topk": 0}
    for rank, (px, py) in enumerate(topk_points):
        hit = bool((x1 <= px <= x2) and (y1 <= py <= y2))
        pred_bbox = [px - half_width, py - half_height, px + half_width, py + half_height]
        overlap = bool(do_boxes_overlap(pred_bbox, gt_bbox))
        if rank == 0:
            scores["hit_top1"] = int(hit)
            scores["overlap_top1"] = int(overlap)
        if hit:
            scores["hit_topk"] = 1
        if overlap:
            scores["overlap_topk"] = 1
    return scores


def evaluate(model_name_or_path, model_type, use_placeholder, topk):
    """Run GUI-Actor on the ScreenSpot-v2 test split and score every prediction.

    Args:
        model_name_or_path: Hub id or local path used for both the processor and the model.
        model_type: ``"qwen2vl"`` or ``"qwen25vl"``; selects the model class and the system prompt.
        use_placeholder: Forwarded unchanged to ``inference``.
        topk: Number of candidate points requested from ``inference``; the
            ``*_topk`` metrics consider every returned point.

    Returns:
        One dict per test example holding its metadata, the normalized
        ground-truth box ``bbox_x1y1x2y2`` and the 0/1 metrics ``hit_top1``,
        ``overlap_top1``, ``hit_topk`` and ``overlap_topk``. The overlap metrics
        use a box of ``2 * IMAGE_PATCH_SIZE`` pixels around each predicted point.

    Raises:
        ValueError: If ``model_type`` is not one of the supported values.
    """
    # initialize model
    data_processor = AutoProcessor.from_pretrained(model_name_or_path)
    tokenizer = data_processor.tokenizer
    # Print the tokenizer's added tokens for inspection.
    for k, v in tokenizer.added_tokens_encoder.items():
        print(v, k)

    if model_type == "qwen2vl":
        print(f"Loading model with Qwen2-VL backbone from {model_name_or_path}")
        model = Qwen2VLForConditionalGenerationWithPointer.from_pretrained(
            model_name_or_path,
            torch_dtype=torch.bfloat16,
            device_map="cuda:0",
            attn_implementation="flash_attention_2"
        ).eval()
        grounding_system_message = "You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task."
    elif model_type == "qwen25vl":
        print(f"Loading model with Qwen2.5-VL backbone from {model_name_or_path}")
        model = Qwen2_5_VLForConditionalGenerationWithPointer.from_pretrained(
            model_name_or_path,
            torch_dtype=torch.bfloat16,
            device_map="cuda:0",
            attn_implementation="flash_attention_2"
        ).eval()
        grounding_system_message = "You are a GUI agent. Given a screenshot of the current GUI and a human instruction, your task is to locate the screen element that corresponds to the instruction. You should output a PyAutoGUI action that performs a click on the correct position. To indicate the click location, we will use some special tokens, which is used to refer to a visual patch later. For example, you can output: pyautogui.click()."
    else:
        raise ValueError(f"Invalid model type: {model_type}")
    print(f"Loaded model from {model_name_or_path}")

    # Presumably forces the pointer-end token right after the pointer-pad token
    # during generation (see ForceFollowTokensLogitsProcessor) - confirm.
    logits_processor_pointer = ForceFollowTokensLogitsProcessor(
        token_a_id=tokenizer.encode(DEFAULT_POINTER_PAD_TOKEN)[0],
        forced_sequence=[
            tokenizer.encode(DEFAULT_POINTER_END_TOKEN)[0]
        ]
    )

    dataset = load_dataset("HongxinLi/ScreenSpot_v2")["test"]
    # Map each data_source to the coarse domain reported in the metric table.
    domain_dict = {
        "windows": "desktop",
        "macos": "desktop",
        "ios": "mobile",
        "android": "mobile",
        "tool": "web",
        "shop": "web",
        "gitlab": "web",
        "forum": "web"
    }

    results = []
    for example in tqdm(dataset, total=len(dataset)):
        image = example["image"]
        img_width, img_height = image.size
        ele = {
            "file_name": example["file_name"],
            "data_type": example["data_type"],
            "domain": domain_dict[example["data_source"]],
            "instruction": example["instruction"],
            "img_size": image.size,
            "bbox_x1y1x2y2": normalize_bbox(example["bbox"], img_width, img_height),
            "hit_top1": 0,
            "overlap_top1": 0,
            "hit_topk": 0,
            "overlap_topk": 0,
        }

        conversation = [
            {
                "role": "system",
                "content": [{"type": "text", "text": grounding_system_message}],
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},  # PIL.Image.Image or str to path
                    {"type": "text", "text": example["instruction"]},
                ],
            },
        ]

        pred = inference(
            conversation,
            model,
            tokenizer,
            data_processor,
            logits_processor=logits_processor_pointer,
            use_placeholder=use_placeholder,
            topk=topk,
        )

        # Predicted points are compared with the normalized GT box, so the
        # IMAGE_PATCH_SIZE pixel margin must also be expressed as a fraction of
        # the image size. A raw pixel margin (14 >> 1) would cover the whole
        # image and make every prediction "overlap".
        ele.update(_score_topk_points(
            pred["topk_points"],
            ele["bbox_x1y1x2y2"],
            IMAGE_PATCH_SIZE / img_width,
            IMAGE_PATCH_SIZE / img_height,
        ))
        results.append(ele)

    return results
147 |
148 |
def get_metric(list_of_examples,
               domains=("mobile", "desktop", "web"),
               data_types=("text", "icon")):
    """Aggregate per-example 0/1 metrics into a table, print it, and return it.

    Each element of ``list_of_examples`` is a dict containing:
        - "domain": domain name (e.g. "web", "mobile", "desktop")
        - "data_type": element type (e.g. "text", "icon")
        - "hit_top1", "overlap_top1", "hit_topk", "overlap_topk": 0 or 1
          (a missing key counts as 0)

    Columns, in order:
        - "<domain>-<data_type>" for every domain/data type pair;
        - "<domain>-avg" over every example of that domain;
        - "All-<data_type>" for each data type;
        - "All-avg" over all examples.

    Rows: hit_top1, overlap_top1, hit_topk, overlap_topk. Cells are percentages
    with two decimals, or "N/A" when a column has no examples.

    Args:
        list_of_examples: Per-example result dicts (e.g. from ``evaluate``).
        domains: Domains to report, in column order.
        data_types: Data types to report, in column order.

    Returns:
        The tab-delimited table (preceded by a title line) that is also
        printed, ready to paste into a spreadsheet.
    """
    metrics = ["hit_top1", "overlap_top1", "hit_topk", "overlap_topk"]

    def compute_mean(examples, key):
        # None marks an empty column and is rendered as "N/A".
        if not examples:
            return None
        return sum(example.get(key, 0) for example in examples) / len(examples)

    def format_cell(cell):
        if isinstance(cell, float):
            return f"{cell * 100:.2f}"
        if cell is None:
            return "N/A"
        return str(cell)

    # Column name -> examples that column averages over (dicts keep insertion order).
    column_subsets = {}
    for domain in domains:
        domain_examples = [ex for ex in list_of_examples if ex.get("domain") == domain]
        for data_type in data_types:
            column_subsets[f"{domain}-{data_type}"] = [
                ex for ex in domain_examples if ex.get("data_type") == data_type
            ]
        column_subsets[f"{domain}-avg"] = domain_examples
    for data_type in data_types:
        column_subsets[f"All-{data_type}"] = [
            ex for ex in list_of_examples if ex.get("data_type") == data_type
        ]
    column_subsets["All-avg"] = list(list_of_examples)
    columns_order = list(column_subsets)

    results = {
        metric: {col: compute_mean(subset, metric) for col, subset in column_subsets.items()}
        for metric in metrics
    }
    table_rows = [
        [metric] + [format_cell(results[metric][col]) for col in columns_order]
        for metric in metrics
    ]

    # ------------- Print Table to Console -------------
    header = [""] + columns_order
    col_widths = [max(len(col), 12) for col in header]

    def join_padded(cells):
        return " | ".join(cell.ljust(width) for cell, width in zip(cells, col_widths))

    print(join_padded(header))
    print("-+-".join("-" * width for width in col_widths))
    for row in table_rows:
        print(join_padded(row))

    # ------------- Tab-delimited Version (for Excel Copy-Paste) -------------
    metric_info = "Tab-delimited Table for Excel:\n"
    metric_info += "\t".join(header) + "\n"
    for row in table_rows:
        metric_info += "\t".join(row) + "\n"
    print(metric_info)
    return metric_info
258 |
259 |
"""
# cd to project root directory
python eval/screenSpot_v2.py --save_path ./results/screenspot_v2
"""
if __name__ == "__main__":
    import sys

    parser = argparse.ArgumentParser(description="Evaluate a GUI-Actor checkpoint on ScreenSpot-v2.")
    parser.add_argument("--model_type", type=str, default="qwen2vl", choices=["qwen2vl", "qwen25vl"])
    parser.add_argument("--model_name_or_path", type=str, default="microsoft/GUI-Actor-2B-Qwen2-VL")
    parser.add_argument("--save_path", type=str, default="./")
    parser.add_argument('--topk', type=int, default=3, help='Number of candidate points scored by the top-k metrics')
    parser.add_argument('--no-placeholder', dest='use_placeholder', action='store_false', help='Disable the placeholder')
    parser.set_defaults(use_placeholder=True)

    args = parser.parse_args()

    save_path = args.save_path
    os.makedirs(save_path, exist_ok=True)
    pred_path = os.path.join(save_path, "screenspot_v2_all_preds.json")
    metric_path = os.path.join(save_path, "screenspot_v2_all_metrics.txt")

    # A finished run is never recomputed; delete the metric file to re-score.
    if os.path.exists(metric_path):
        print(f"Metrics already exist at {metric_path}; skipping.")
        sys.exit(0)

    # Predictions cached by an earlier run are reused as-is (no model is loaded).
    if os.path.exists(pred_path):
        print(f"Loading predictions from {pred_path}")
        with open(pred_path, "r") as f:
            results = json.load(f)
    else:
        print(f"Evaluating {args.model_name_or_path}...")
        results = evaluate(args.model_name_or_path, args.model_type, args.use_placeholder, args.topk)
        with open(pred_path, "w") as f:
            json.dump(results, f)
        print(f"Saved {len(results)} predictions to {pred_path}")

    metric_info = get_metric(results)
    with open(metric_path, "w") as f:
        f.write(metric_info)
    print(f"Saved metric to {metric_path}")
300 |
--------------------------------------------------------------------------------
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [project]
2 | name = "gui-actor"
3 | version = "0.1.0"
4 | description = "Coordinate-Free Visual Grounding for GUI Agents"
5 | authors = [
6 | {name = "GUI-Actor team"},
7 | ]
8 | dependencies = [
9 | "pre-commit>=3.7.1",
10 | "pip>=24.1.1",
11 | "Pillow>=10.4.0",
12 | "liger-kernel==0.5.2",
13 | "opencv-python-headless>=4.10.0.84",
14 | "accelerate==1.1.1",
15 | "qwen-vl-utils==0.0.8",
16 | "deepspeed==0.16.0",
17 | "transformers==4.51.3",
18 | "flash-attn",
19 | "wandb==0.18.3",
20 | "datasets>=2.18.0"
21 | ]
22 | requires-python = ">=3.10,<3.13"
23 | readme = "README.md"
24 | license = {text = "MIT"}
25 |
26 |
27 | [tool.pdm]
28 | distribution = false
29 |
30 | [tool.pdm.dev-dependencies]
31 | test = [
32 | "pytest>=8.2.2",
33 | ]
34 | [tool.ruff]
35 | target-version = 'py38'
36 | line-length = 120 # Must agree with Black
37 |
38 | [tool.ruff.lint]
39 | select = [
40 | "B", # flake8-bugbear
41 | "C4", # flake8-comprehensions
42 | "D", # pydocstyle
43 | "E", # Error
44 | "F", # pyflakes
45 | "I", # isort
46 | "ISC", # flake8-implicit-str-concat
47 | "N", # pep8-naming
48 | "PGH", # pygrep-hooks
49 | # "PTH", # flake8-use-pathlib
50 | "Q", # flake8-quotes
51 | "SIM", # flake8-simplify
52 | "TRY", # tryceratops
53 | "UP", # pyupgrade
54 | "W", # Warning
55 | "YTT", # flake8-2020
56 | ]
57 |
58 | exclude = [
59 | "migrations",
60 | "__pycache__",
61 | "manage.py",
62 | "settings.py",
63 | "env",
64 | ".env",
65 | "venv",
66 | ".venv",
67 | ]
68 |
69 | ignore = [
70 | "B905", # zip strict=True; remove once python <3.10 support is dropped.
71 | "D100",
72 | "D101",
73 | "D102",
74 | "D103",
75 | "D104",
76 | "D105",
77 | "D106",
78 | "D107",
79 | "D200",
80 | "D401",
81 | "E402",
82 | "E501",
83 | "TRY003", # Avoid specifying messages outside exception class; overly strict, especially for ValueError
84 | "N812",
85 | ]
86 |
87 | [tool.ruff.lint.flake8-bugbear]
88 | extend-immutable-calls = [
89 | "chr",
90 | "typer.Argument",
91 | "typer.Option",
92 | ]
93 |
94 | [tool.ruff.lint.pydocstyle]
95 | convention = "numpy"
96 |
97 | [tool.ruff.lint.per-file-ignores]
98 | "tests/*.py" = [
99 | "D100",
100 | "D101",
101 | "D102",
102 | "D103",
103 | "D104",
104 | "D105",
105 | "D106",
106 | "D107",
107 | "S101", # use of "assert"
108 | "S102", # use of "exec"
109 | "S106", # possible hardcoded password.
110 | "PGH001", # use of "eval"
111 | ]
112 |
113 | [tool.ruff.lint.pep8-naming]
114 | staticmethod-decorators = [
115 | "pydantic.validator",
116 | "pydantic.root_validator",
117 | ]
118 |
--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Stage 2 of GUI-Actor training: fine-tune starting from the checkpoint written by
# scripts/warmup.sh. Run from the project root (all paths below are relative to it).
#
# model_type: qwen2vl or qwen25vl. It must match the model_type of the warmup run, because
# llm_model is that run's output_dir. The default matches scripts/warmup.sh's default.
# Optional overrides via environment variables: MODEL_TYPE, NPROC_PER_NODE, e.g.
#   MODEL_TYPE=qwen2vl NPROC_PER_NODE=8 bash scripts/train.sh
set -euo pipefail

model_type="${MODEL_TYPE:-qwen25vl}"
case "${model_type}" in
  qwen2vl | qwen25vl) ;;
  *)
    echo "Unsupported model_type '${model_type}' (expected qwen2vl or qwen25vl)" >&2
    exit 1
    ;;
esac
nproc_per_node="${NPROC_PER_NODE:-4}"
llm_model="./checkpoints/${model_type}_warmup"
output_dir="./checkpoints/${model_type}_sft"

# Fail fast with a clear message when the warmup stage has not produced a checkpoint yet.
if [[ ! -d "${llm_model}" ]]; then
  echo "Warmup checkpoint not found at ${llm_model}; run scripts/warmup.sh with the same model_type first." >&2
  exit 1
fi

# === Training Command ===
# Variables are quoted so paths containing spaces or glob characters are passed intact.
torchrun --nproc_per_node="${nproc_per_node}" train.py \
  --deepspeed ./scripts/zero3.json \
  --data_path data/data_config.yaml \
  --image_folder "" \
  --model_type "${model_type}" \
  --model_name_or_path "${llm_model}" \
  --group_by_modality_length True \
  --bf16 True \
  --output_dir "${output_dir}" \
  --num_train_epochs 1 \
  --per_device_train_batch_size 1 \
  --per_device_eval_batch_size 4 \
  --gradient_accumulation_steps 1 \
  --eval_strategy "no" \
  --save_strategy "steps" \
  --save_steps 2000 \
  --learning_rate 1e-4 \
  --weight_decay 0. \
  --warmup_ratio 0.03 \
  --lr_scheduler_type "cosine" \
  --logging_steps 10 \
  --tf32 True \
  --model_max_length 24576 \
  --gradient_checkpointing True \
  --dataloader_num_workers 8 \
  --max_pixels 5720064 \
  --unfreeze_all_parameters True \
  --unfreeze_pointer_head False \
  --unfreeze_lm_head False \
  --unfreeze_base_model False \
  --unfreeze_last_n_layers -1 \
  --unfreeze_new_tokens False \
  --unfreeze_visual False \
  --pointer_loss_weight 1.0 \
  --lm_loss_weight 1.0
43 |
--------------------------------------------------------------------------------
/scripts/warmup.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Stage 1 of GUI-Actor training ("warmup"): start from a base Qwen VL checkpoint and train
# with --unfreeze_pointer_head True and --unfreeze_new_tokens True (the other --unfreeze_*
# flags below are False). The result is written to ./checkpoints/${model_type}_warmup, which
# is the checkpoint scripts/train.sh loads. Run from the project root (paths are relative).
#
# model_type: qwen2vl or qwen25vl; llm_model must be a checkpoint of the matching family.
# Optional overrides via environment variables: MODEL_TYPE, LLM_MODEL, NPROC_PER_NODE.
set -euo pipefail

model_type="${MODEL_TYPE:-qwen25vl}"
case "${model_type}" in
  qwen2vl | qwen25vl) ;;
  *)
    echo "Unsupported model_type '${model_type}' (expected qwen2vl or qwen25vl)" >&2
    exit 1
    ;;
esac
llm_model="${LLM_MODEL:-Qwen/Qwen2.5-VL-3B-Instruct}"
nproc_per_node="${NPROC_PER_NODE:-4}"
output_dir="./checkpoints/${model_type}_warmup"

# === Training Command ===
# Variables are quoted so paths containing spaces or glob characters are passed intact.
# NOTE(review): --lm_loss_weight -1.0 presumably disables the LM loss for this stage —
# confirm against train.py.
torchrun --nproc_per_node="${nproc_per_node}" train.py \
  --deepspeed ./scripts/zero3.json \
  --data_path data/data_config.yaml \
  --image_folder "" \
  --model_type "${model_type}" \
  --model_name_or_path "${llm_model}" \
  --group_by_modality_length True \
  --bf16 True \
  --output_dir "${output_dir}" \
  --num_train_epochs 1 \
  --per_device_train_batch_size 1 \
  --per_device_eval_batch_size 4 \
  --gradient_accumulation_steps 1 \
  --eval_strategy "no" \
  --save_strategy "steps" \
  --save_steps 2000 \
  --learning_rate 1e-4 \
  --weight_decay 0. \
  --warmup_ratio 0.03 \
  --lr_scheduler_type "cosine" \
  --logging_steps 10 \
  --tf32 True \
  --model_max_length 24576 \
  --gradient_checkpointing True \
  --dataloader_num_workers 8 \
  --max_pixels 5720064 \
  --unfreeze_all_parameters False \
  --unfreeze_pointer_head True \
  --unfreeze_lm_head False \
  --unfreeze_base_model False \
  --unfreeze_last_n_layers -1 \
  --unfreeze_new_tokens True \
  --unfreeze_visual False \
  --pointer_loss_weight 1.0 \
  --lm_loss_weight -1.0
43 |
--------------------------------------------------------------------------------
/scripts/zero3.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "zero_optimization": {
14 | "stage": 3,
15 | "offload_optimizer": {
16 | "device": "none",
17 | "pin_memory": false
18 | },
19 | "offload_param": {
20 | "device": "none",
21 | "pin_memory": false
22 | },
23 | "overlap_comm": false,
24 | "contiguous_gradients": true,
25 | "sub_group_size": 1000000000.0,
26 | "reduce_bucket_size": "auto",
27 | "stage3_prefetch_bucket_size": "auto",
28 | "stage3_param_persistence_threshold": "auto",
29 | "stage3_max_live_parameters": 1000000000.0,
30 | "stage3_max_reuse_distance": 1000000000.0,
31 | "stage3_gather_16bit_weights_on_model_save": true
32 | },
33 | "gradient_accumulation_steps": "auto",
34 | "gradient_clipping": "auto",
35 | "steps_per_print": 100,
36 | "train_batch_size": "auto",
37 | "train_micro_batch_size_per_gpu": "auto",
38 | "wall_clock_breakdown": false
39 | }
40 |
--------------------------------------------------------------------------------
/src/gui_actor/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/microsoft/GUI-Actor/30e148da6d719117444f9d05a944561f31e362f1/src/gui_actor/__init__.py
--------------------------------------------------------------------------------
/src/gui_actor/constants.py:
--------------------------------------------------------------------------------
1 | import json
2 |
# Heartbeat timing values (units not visible in this file — presumably seconds; confirm at use site).
CONTROLLER_HEART_BEAT_EXPIRATION = 30
WORKER_HEART_BEAT_INTERVAL = 15

LOGDIR = "."

# ---------------------------------------------------------------------------
# Model constants
# ---------------------------------------------------------------------------
# -100 is PyTorch's default ignore_index for cross-entropy.
IGNORE_INDEX = -100
# NOTE(review): empty string here; LLaVA-style code usually uses "<image>" — confirm the value
# was not lost when this file was exported.
DEFAULT_IMAGE_TOKEN = ""

# Tokens delimiting the pointer span; <|pointer_pad|> marks where a click location is predicted.
DEFAULT_POINTER_START_TOKEN = "<|pointer_start|>"
DEFAULT_POINTER_END_TOKEN = "<|pointer_end|>"
DEFAULT_POINTER_PAD_TOKEN = "<|pointer_pad|>"

# UNMASK_TOKEN_IDS = [198, 151644, 151645]

# ---------------------------------------------------------------------------
# System message
# ---------------------------------------------------------------------------
# NOTE(review): the example "pyautogui.click()" appears to be missing its special-token
# placeholder — confirm against the prompt used in training. The text must match training exactly.
grounding_system_message = (
    "You are a GUI agent. Given a screenshot of the current GUI and a human instruction, "
    "your task is to locate the screen element that corresponds to the instruction. "
    "You should output a PyAutoGUI action that performs a click on the correct position. "
    "To indicate the click location, we will use some special tokens, "
    "which is used to refer to a visual patch later. "
    "For example, you can output: pyautogui.click()."
)

# ---------------------------------------------------------------------------
# Chat templates (Jinja), split into pieces purely for readability; adjacent string
# literals are concatenated by Python into the single template string.
# ---------------------------------------------------------------------------
# Renders system/user turns; images/videos become <|vision_start|>...<|vision_end|> spans.
chat_template = (
    "{% set image_count = namespace(value=0) %}"
    "{% set video_count = namespace(value=0) %}"
    "{% for message in messages %}"
    "<|im_start|>{{ message['role'] }}\n"
    "{% if message['content'] is string %}"
    "{{ message['content'] }}<|im_end|>\n"
    "{% else %}"
    "{% for content in message['content'] %}"
    "{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}"
    "{% set image_count.value = image_count.value + 1 %}"
    "{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}"
    "<|vision_start|><|image_pad|><|vision_end|>"
    "{% elif content['type'] == 'video' or 'video' in content %}"
    "{% set video_count.value = video_count.value + 1 %}"
    "{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}"
    "<|vision_start|><|video_pad|><|vision_end|>"
    "{% elif 'text' in content %}"
    "{{ content['text'] }}"
    "{% endif %}"
    "{% endfor %}"
    "<|im_end|>\n"
    "{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}"
    "<|im_start|>assistant\n"
    "{% endif %}"
)

# Renders assistant turns: optional <|recipient|>, then the first text content item; a turn
# with end_turn set is closed by <|diff_marker|> instead of <|im_end|>.
assistant_template = (
    "{% for message in messages %}"
    "{{'<|im_start|>' + message['role']}}"
    "{% if 'recipient' in message %}<|recipient|>{{ message['recipient'] }}{% endif %}"
    "{{'\n' + message['content'][0]['text']}}"
    "{% if 'end_turn' in message and message['end_turn'] %}{{'<|diff_marker|>\n'}}"
    "{% else %}{{'<|im_end|>\n'}}{% endif %}"
    "{% endfor %}"
    "{% if add_generation_prompt %}{{ '<|im_start|>assistant<|recipient|>' }}{% endif %}"
)

# ---------------------------------------------------------------------------
# Special tokens: two template markers followed by the three pointer tokens.
# ---------------------------------------------------------------------------
ADDITIONAL_SPECIAL_TOKENS = ["<|recipient|>", "<|diff_marker|>"] + [
    DEFAULT_POINTER_START_TOKEN,
    DEFAULT_POINTER_END_TOKEN,
    DEFAULT_POINTER_PAD_TOKEN,
]

# Action argument patterns (click-style x/y and drag-style from/to coordinates) that are
# replaced with special tokens.
ACTION_PATTENS_XY = [
    r"x=([0-9.]+), y=([0-9.]+)",
    r"from_coord=\[([0-9.]+), ([0-9.]+)\], to_coord=\[([0-9.]+), ([0-9.]+)\]",
]

# Presumably generation stop strings: <|diff_marker|> closes an end_turn assistant turn.
until = ["<|diff_marker|>"]
41 |
--------------------------------------------------------------------------------
/src/gui_actor/inference.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import json
3 | import re
4 | import os
5 | from qwen_vl_utils import process_vision_info
6 | from transformers import (
7 | Qwen2VLForConditionalGeneration,
8 | LogitsProcessor,
9 | LogitsProcessorList,
10 | AutoModelForCausalLM,
11 | AutoTokenizer
12 | )
13 | from gui_actor.constants import (
14 | DEFAULT_POINTER_END_TOKEN,
15 | DEFAULT_POINTER_PAD_TOKEN,
16 | chat_template
17 | )
18 |
class ForceFollowTokensLogitsProcessor(LogitsProcessor):
    """
    Force a fixed sequence of token IDs to follow a trigger token during generation.

    Whenever the most recent token in ``input_ids`` equals ``token_a_id``, the IDs in
    ``forced_sequence`` are queued. While the queue is non-empty, each decoding step sets every
    logit to ``-inf`` except the next queued token, which gets 0 (log-prob 0 => probability 1).
    ``inference`` uses this to emit ``<|pointer_end|>`` right after ``<|pointer_pad|>``.

    Only batch size 1 is supported.

    Args:
        token_a_id: token ID that triggers the forced sequence.
        forced_sequence: token IDs to force, in order, after ``token_a_id``. These must be IDs
            (e.g. ``tokenizer.encode(token)[0]``), not token strings, because they index the
            scores tensor directly.

    Raises:
        ValueError: if ``forced_sequence`` is not given.
        TypeError: if ``forced_sequence`` contains strings.
    """

    def __init__(self, token_a_id, forced_sequence=None):
        super().__init__()
        # The former default was a shared list of token *strings*, which could never index the
        # scores tensor; require explicit IDs instead of failing mid-generation.
        if forced_sequence is None:
            raise ValueError("forced_sequence must be given as a list of token IDs, e.g. [tokenizer.encode(token)[0]]")
        forced_sequence = list(forced_sequence)
        if any(isinstance(token_id, str) for token_id in forced_sequence):
            raise TypeError("forced_sequence must contain token IDs, not token strings")
        self.token_a_id = token_a_id
        self.forced_sequence = forced_sequence  # list of token IDs, e.g. [B_id, C_id]
        self.force_queue = []  # token IDs still to be forced
        self._last_seq_len = None  # input length seen on the previous call

    def reset(self):
        """Drop any pending forced tokens (call before reusing the processor for a new generation)."""
        self.force_queue.clear()
        self._last_seq_len = None

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        """
        Called at each decoding step to modify ``scores``.

        Args:
            input_ids: shape (batch_size, seq_len). The already-decoded tokens.
            scores: shape (batch_size, vocab_size). Model logits for the next token.

        Returns:
            ``scores`` unchanged, or a new tensor that allows only the next forced token.

        Raises:
            NotImplementedError: if batch_size > 1.
        """
        if input_ids.shape[0] > 1:
            raise NotImplementedError("Batch size must be 1 for this logits processor.")

        # Within one generate() call the sequence grows by exactly one token per step. Anything
        # else means a new generation started, so forced tokens left over from an interrupted
        # previous run (e.g. max_new_tokens hit right after token A) must not leak into it.
        seq_len = input_ids.shape[1]
        if self._last_seq_len is None or seq_len != self._last_seq_len + 1:
            self.force_queue.clear()
        self._last_seq_len = seq_len

        # The last token was A: enqueue the forced follow-up tokens.
        if input_ids[0, -1].item() == self.token_a_id:
            self.force_queue.extend(self.forced_sequence)

        if not self.force_queue:
            return scores

        forced_token = self.force_queue.pop(0)
        # Mask everything to -inf except the forced token (log prob 0 => prob 1).
        new_scores = torch.full_like(scores, float('-inf'))
        new_scores[0, forced_token] = 0.0
        return new_scores
61 |
62 |
def get_prediction_region_point(attn_scores, n_width, n_height, top_n=30, activation_threshold=0.3, return_all_regions=True, rect_center=False):
    """
    Turn patch-level pointer scores into click points.

    1. Select activated patches: those whose score in ``attn_scores[0]`` exceeds
       ``activation_threshold`` times the maximum score.
    2. Group activated patches into 4-connected regions on the patch grid.
    3. Score each region by the mean activation of its patches.
    4. Sort regions by that score, descending (ties keep discovery order).
    5. Return the center of the best region (and, optionally, all regions).

    Args:
        attn_scores: tensor whose row 0 holds one score per patch in row-major order
            (flat index ``i`` -> row ``i // n_width``, column ``i % n_width``).
        n_width: number of patches per grid row.
        n_height: number of patch rows.
        top_n: unused; kept for backward compatibility.
        activation_threshold: fraction of the maximum score a patch must exceed to be activated.
        return_all_regions: if False, only ``best_point`` is returned.
        rect_center: if True, a region's center is the mean of its distinct patch-center x values
            and distinct y values; otherwise it is the score-weighted mean of its patch centers.

    Returns:
        ``best_point`` as ``(x, y)`` (patch centers divided by the grid size), or, when
        ``return_all_regions`` is True, ``(best_point, sorted_centers, sorted_scores, sorted_points)``:
        region centers and mean scores sorted by score (descending), and for each region the
        list of its normalized ``(x, y)`` patch centers.

    If no patch exceeds the threshold (possible only when the maximum score is <= 0), the single
    highest-scoring patch is used as the only activated patch instead of failing.
    """
    scores = attn_scores[0]
    threshold = scores.max().item() * activation_threshold
    activated_indices = torch.nonzero(scores > threshold).squeeze(-1)
    if activated_indices.numel() == 0:
        # Fallback: without it the region list would be empty and indexing [0] below would fail.
        activated_indices = scores.argmax().reshape(1)
    activated_values = scores[activated_indices].tolist()

    # (row, col) -> (flat index, score). Insertion order follows ascending flat index, which
    # fixes the order in which regions are discovered.
    activated = {}
    for idx, value in zip(activated_indices.tolist(), activated_values):
        activated[(idx // n_width, idx % n_width)] = (idx, value)

    # BFS over 4-connected neighbours; the dict lookup replaces a linear scan per neighbour.
    regions = []
    visited = set()
    for (y, x), (idx, value) in activated.items():
        if idx in visited:
            continue
        visited.add(idx)
        region = [(y, x, idx, value)]
        head = 0  # `region` doubles as the BFS queue: items are appended in visit order
        while head < len(region):
            cy, cx = region[head][0], region[head][1]
            head += 1
            for dy, dx in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                ny, nx = cy + dy, cx + dx
                neighbor = activated.get((ny, nx))
                if neighbor is not None and neighbor[0] not in visited:
                    visited.add(neighbor[0])
                    region.append((ny, nx, neighbor[0], neighbor[1]))
        regions.append(region)

    region_scores = []
    region_centers = []
    region_points = []
    for region in regions:
        weights = [item[3] for item in region]
        region_scores.append(sum(weights) / len(region))

        # Normalized (x, y) center of every patch in the region.
        normalized_centers = [((x + 0.5) / n_width, (y + 0.5) / n_height) for y, x, _, _ in region]
        region_points.append(normalized_centers)

        if rect_center:
            x_coords = {cx for cx, _ in normalized_centers}
            y_coords = {cy for _, cy in normalized_centers}
            center = (sum(x_coords) / len(x_coords), sum(y_coords) / len(y_coords))
        else:
            total_weight = sum(weights)
            if total_weight != 0:
                center = (
                    sum(cx * w for (cx, _), w in zip(normalized_centers, weights)) / total_weight,
                    sum(cy * w for (_, cy), w in zip(normalized_centers, weights)) / total_weight,
                )
            else:
                # Degenerate all-zero weights (e.g. an all-zero score map): plain mean instead.
                center = (
                    sum(cx for cx, _ in normalized_centers) / len(normalized_centers),
                    sum(cy for _, cy in normalized_centers) / len(normalized_centers),
                )
        region_centers.append(center)

    # sorted() is stable, so equal-score regions keep their discovery order.
    order = sorted(range(len(region_scores)), key=lambda i: region_scores[i], reverse=True)
    sorted_scores = [region_scores[i] for i in order]
    sorted_centers = [region_centers[i] for i in order]
    sorted_points = [region_points[i] for i in order]
    best_point = sorted_centers[0]

    if return_all_regions:
        return best_point, sorted_centers, sorted_scores, sorted_points
    return best_point
178 |
179 |
def inference(conversation, model, tokenizer, data_processor, logits_processor=None, use_placeholder=False, topk=5):
    """
    Run pointer-based grounding on one conversation and return the predicted click regions.

    Example ``conversation``::

        conversation = [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": grounding_system_message,
                    }
                ]
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": example["image"],  # PIL.Image.Image or str to path
                        # alternatively "image_url": an http(s)/file:// URL or a base64 data URI
                    },
                    {
                        "type": "text",
                        "text": example["instruction"]
                    },
                ],
            },
        ]

    Args:
        conversation: chat messages as above; rendered with ``data_processor.apply_chat_template``
            and passed to ``process_vision_info``.
        model: model providing ``generate``, ``visual``, ``multi_patch_pointer_head`` and
            ``config.pointer_pad_token_id``.
        tokenizer: used to encode the pointer tokens and to decode the generated ids.
        data_processor: builds the model inputs from the rendered text and the images/videos.
        logits_processor: optional; defaults to a processor that forces ``<|pointer_end|>``
            right after each generated ``<|pointer_pad|>``.
        use_placeholder: if True, the assistant turn including the pointer tokens is appended to
            the prompt and a single new token is generated; the pointer query is read from the
            prompt's hidden states. Otherwise up to 2048 tokens are generated and the query comes
            from the generated ``<|pointer_pad|>`` tokens.
        topk: maximum number of regions kept in the ``topk_*`` fields.

    Returns:
        dict with keys ``output_text``, ``n_width``, ``n_height``, ``attn_scores``,
        ``topk_points``, ``topk_values``, ``topk_points_all``. When no ``<|pointer_pad|>`` token
        is found, only ``output_text`` is filled in and the other values stay None.
    """
    if logits_processor is None:
        # Always close a pointer span: <|pointer_pad|> is followed by <|pointer_end|>.
        logits_processor = ForceFollowTokensLogitsProcessor(
            token_a_id=tokenizer.encode(DEFAULT_POINTER_PAD_TOKEN)[0],
            forced_sequence=[tokenizer.encode(DEFAULT_POINTER_END_TOKEN)[0]],
        )

    assistant_starter = (
        "<|im_start|>assistant<|recipient|>os\npyautogui.click(<|pointer_start|><|pointer_pad|><|pointer_end|>)"
        if use_placeholder
        else ""
    )

    pred = {
        "output_text": None,  # generated text
        "n_width": None,  # number of patch tokens along the width
        "n_height": None,  # number of patch tokens along the height
        "attn_scores": None,  # pointer-head scores over the image patches
        "topk_points": None,  # centers of the top-k regions
        "topk_values": None,  # mean scores of the top-k regions
        "topk_points_all": None,  # patch centers of each top-k region
    }

    # Prepare text.
    text = data_processor.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=False,
        chat_template=chat_template,
    )
    text += assistant_starter

    # Prepare inputs.
    image_inputs, video_inputs = process_vision_info(conversation)
    inputs = data_processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    )
    inputs = inputs.to(model.device)

    # Generate.
    results = model.generate(
        **inputs,
        max_new_tokens=1 if use_placeholder else 2048,
        logits_processor=LogitsProcessorList([logits_processor]),
        return_dict_in_generate=True,
        output_hidden_states=True,
    )

    # Decode the generated ids.
    input_ids = inputs["input_ids"][0]
    generated_ids = results.sequences[0][len(input_ids):]
    pred["output_text"] = tokenizer.decode(generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False)

    # Locate <|pointer_pad|> tokens in the prompt (placeholder mode) or in the generated ids.
    pointer_pad_id = model.config.pointer_pad_token_id
    if use_placeholder:
        pointer_pad_mask = input_ids == pointer_pad_id  # (n_input_tokens,)
    else:
        # Paired with results.hidden_states[1:] below, which has one entry fewer than the
        # number of generated tokens; hence the [:-1].
        pointer_pad_mask = generated_ids[:-1] == pointer_pad_id

    # The mask has one entry per token, so its length says nothing about pointer tokens: test
    # for any True entry. Otherwise an empty query set would reach the pointer head and
    # get_prediction_region_point would fail on attn_scores[0].
    if not pointer_pad_mask.any():
        return pred

    # Pure inference: avoid building autograd graphs for the vision tower and pointer head.
    with torch.no_grad():
        if use_placeholder:
            # Prefill step, last layer, first batch item: (n_input_tokens, hidden_size).
            decoder_hidden_states = results.hidden_states[0][-1][0]
        else:
            decoder_hidden_states = torch.cat(
                [step_hidden_states[-1][0] for step_hidden_states in results.hidden_states[1:]],
                dim=0,
            )  # (len(generated_ids) - 1, hidden_size)
        decoder_hidden_states = decoder_hidden_states[pointer_pad_mask.to(decoder_hidden_states.device)]  # (n_pointer_pad_tokens, hidden_size)

        # Image embeddings act as the pointer head's encoder vectors.
        image_embeds = model.visual(inputs["pixel_values"], grid_thw=inputs["image_grid_thw"])  # (n_image_tokens, hidden_size)
        attn_scores, _ = model.multi_patch_pointer_head(image_embeds, decoder_hidden_states)
    pred["attn_scores"] = attn_scores.tolist()

    _, n_height, n_width = (inputs["image_grid_thw"][0] // model.visual.spatial_merge_size).tolist()
    pred["n_width"] = n_width
    pred["n_height"] = n_height

    # Region-level points from the pointer scores; slicing already copes with fewer than topk regions.
    _, region_points, region_scores, region_points_all = get_prediction_region_point(
        attn_scores, n_width, n_height, return_all_regions=True, rect_center=False
    )
    pred["topk_points"] = region_points[:topk]
    pred["topk_values"] = region_scores[:topk]
    pred["topk_points_all"] = region_points_all[:topk]

    return pred
--------------------------------------------------------------------------------
/src/gui_actor/modeling.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast, Qwen2VLForConditionalGeneration
6 | from gui_actor.constants import IGNORE_INDEX
7 | from typing import List, Tuple, Union, Optional
8 | from gui_actor.trainer import rank0_print
9 |
class QwenVLwithVisionHeadOutputWithPast(Qwen2VLCausalLMOutputWithPast):
    """
    Qwen2-VL causal-LM output extended with the pointer-head results.

    Args:
        lm_loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
            Language modeling loss.
        pointer_loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
            Vision pointer network loss.
        pointer_scores (`List[torch.FloatTensor]`, *optional*):
            Attention scores from the pointer network, one tensor per batch item.
        loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
            Combined loss (weighted sum of lm_loss and pointer_loss).
        logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
            Prediction scores from the language modeling head.
        past_key_values, hidden_states, attentions, rope_deltas:
            Same as parent class.

    Note:
        ``lm_loss``, ``pointer_loss`` and ``pointer_scores`` occupy the first three positional
        slots, so parent fields (``loss``, ``logits``, ...) should be passed by keyword; any
        further positional arguments are forwarded to the parent constructor unchanged.
    """
    def __init__(self, lm_loss=None, pointer_loss=None, pointer_scores=None, *args, **kwargs):
        # Let the parent build its own fields first, then attach the pointer-specific ones.
        super().__init__(*args, **kwargs)
        self.lm_loss, self.pointer_loss, self.pointer_scores = lm_loss, pointer_loss, pointer_scores
33 |
34 |
class VisionHead_MultiPatch(nn.Module):
    """
    Pointer head that scores every visual patch against each pointer query.

    Visual features are contextualized by one self-attention layer (residual + LayerNorm). Both
    the patches and the queries then pass through a 2-layer GELU MLP, and the patch logits are
    their dot products divided by sqrt(d_model).

    Args:
        d_model: hidden size of the visual features and of the decoder hidden states.
        projection_dim: inner width of the two projection MLPs.
        num_attention_heads: number of heads in the visual self-attention layer.
        dropout_rate: dropout in the self-attention layer and on its residual branch.
    """

    def __init__(self, d_model, projection_dim, num_attention_heads=8, dropout_rate=0.1):
        super().__init__()
        self.d_model = d_model

        # Note: We omit additional normalization here because Qwen2VL
        # already normalizes hidden states using RMSNorm.
        self.projection_enc = nn.Sequential(
            nn.Linear(d_model, projection_dim),
            nn.GELU(),
            nn.Linear(projection_dim, d_model)
        )
        self.projection_dec = nn.Sequential(
            nn.Linear(d_model, projection_dim),
            nn.GELU(),
            nn.Linear(projection_dim, d_model)
        )

        # Self-attention layer over the visual features.
        self.self_attention = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=num_attention_heads,
            dropout=dropout_rate,
            batch_first=True
        )

        # Layer normalization and dropout for the residual connection around self-attention.
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self,
                hidden_state_enc,  # [n_enc, d_model]; n_enc varies with image size
                hidden_state_dec,  # [n_dec, d_model]; several queries may come from one sample
                labels: Optional[torch.Tensor] = None,
                do_single_patch: bool = False,
                ):
        """
        Score the patches for each query.

        Args:
            hidden_state_enc: visual patch features, [n_enc, d_model].
            hidden_state_dec: pointer-query hidden states, [n_dec, d_model].
            labels: optional targets. Multi-patch mode: [n_dec, n_enc] mask of patches inside the
                target box (row-normalized into a distribution). Single-patch mode: any target
                accepted by ``F.cross_entropy`` for [n_dec, n_enc] logits, e.g. patch indices [n_dec].
            do_single_patch: use cross-entropy against a single patch instead of the KL loss.

        Returns:
            (attn_weights, loss): attn_weights is the softmax over patches, [n_dec, n_enc];
            loss is None when labels is None.
        """
        enc_input = hidden_state_enc.unsqueeze(0)  # add batch dim: [1, n_enc, d_model]
        attn_output, _ = self.self_attention(
            query=enc_input,
            key=enc_input,
            value=enc_input,
            need_weights=False
        )
        # Residual connection + layer norm, then drop the batch dimension again.
        hidden_state_enc_ctx = self.layer_norm(enc_input + self.dropout(attn_output))
        hidden_state_enc_ctx = hidden_state_enc_ctx.squeeze(0)  # [n_enc, d_model]

        # Apply the projection networks.
        proj_enc = self.projection_enc(hidden_state_enc_ctx)  # [n_enc, d_model]
        proj_dec = self.projection_dec(hidden_state_dec)  # [n_dec, d_model]

        # Scaled dot-product scores; scaling by sqrt(d_model) is applied regardless of n_enc.
        scaling = self.d_model ** 0.5
        patch_logits = torch.matmul(proj_dec, proj_enc.transpose(0, 1)) / scaling  # [n_dec, n_enc]

        # Softmax normalization along the encoder (patch) dimension.
        attn_weights = F.softmax(patch_logits, dim=-1)

        loss = None
        if labels is not None:
            if do_single_patch:
                # F.cross_entropy applies log-softmax itself, so it takes the raw patch logits.
                # (Previously this referenced an undefined name `attn_scores` -> NameError.)
                loss = F.cross_entropy(patch_logits, labels)
            else:
                epsilon = 1e-8
                labels_float = labels.float()
                # Normalize each row into a target probability distribution over patches.
                target_dist = labels_float / (labels_float.sum(dim=-1, keepdim=True) + epsilon)
                pred_log_probs = F.log_softmax(patch_logits, dim=-1)
                # KL(target || prediction), averaged over queries.
                loss = F.kl_div(pred_log_probs, target_dist, reduction='batchmean')

        return attn_weights, loss
113 |
114 |
115 | class Qwen2VLForConditionalGenerationWithPointer(Qwen2VLForConditionalGeneration):
116 | def __init__(self, *args, **kwargs):
117 | super().__init__(*args, **kwargs)
118 | self.multi_patch_pointer_head = VisionHead_MultiPatch(self.config.hidden_size, self.config.hidden_size)
119 | self.pointer_loss_weight = kwargs.get("pointer_loss_weight", 1.0)
120 | self.lm_loss_weight = kwargs.get("lm_loss_weight", 1.0)
121 | self.post_init()
122 |
123 | def reset_loss_weights(self, pointer_loss_weight, lm_loss_weight):
124 | self.pointer_loss_weight = pointer_loss_weight
125 | self.lm_loss_weight = lm_loss_weight
126 |
127 | def forward(self,
128 | input_ids: torch.LongTensor = None, # (batch_size, seq_len)
129 | attention_mask: Optional[torch.Tensor] = None,
130 | position_ids: Optional[torch.LongTensor] = None,
131 | past_key_values: Optional[List[torch.FloatTensor]] = None,
132 | inputs_embeds: Optional[torch.FloatTensor] = None,
133 | labels: Optional[torch.LongTensor] = None,
134 | use_cache: Optional[bool] = None,
135 | output_attentions: Optional[bool] = None,
136 | output_hidden_states: Optional[bool] = None,
137 | return_dict: Optional[bool] = None,
138 | pixel_values: Optional[torch.Tensor] = None,
139 | pixel_values_videos: Optional[torch.FloatTensor] = None,
140 | image_grid_thw: Optional[torch.LongTensor] = None,
141 | video_grid_thw: Optional[torch.LongTensor] = None,
142 | rope_deltas: Optional[torch.LongTensor] = None,
143 | cache_position: Optional[torch.LongTensor] = None,
144 | # Grounding
145 | visual_token_indices_of_coordinates: Optional[torch.Tensor] = None, # shape: (batch_size, n_target); each element is the ground-truth index of the visual token that should be attended to for the corresponding target token
146 | multi_patch_labels: Optional[torch.Tensor] = None, # shape: list [(n_target, n_visual), ...]; binary mask of patches in bbox
147 | if_multi_patch: bool = True,
148 | coordinates: Optional[List[Tuple[float, float]]] = None,
149 | verbose: bool = False) -> Union[Tuple, QwenVLwithVisionHeadOutputWithPast]:
150 |
151 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
152 | output_hidden_states = (
153 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
154 | )
155 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict
156 |
157 | if verbose:
158 | rank0_print(f"input_ids: {input_ids.shape}, {input_ids[0][:5]}...")
159 | rank0_print(f"labels: {labels.shape}, {labels[0][:5]}...")
160 | rank0_print(f"pixel_values: {pixel_values.shape}")
161 | rank0_print(f"image_grid_thw: {image_grid_thw.shape}, {image_grid_thw}")
162 | rank0_print(f"coordinates: {coordinates}")
163 | rank0_print(f"visual_token_indices_of_coordinates: {visual_token_indices_of_coordinates}")
164 | rank0_print(f"return_dict: {return_dict}")
165 |
166 | if inputs_embeds is None:
167 | inputs_embeds = self.model.embed_tokens(input_ids) # shape: (batch_size, seq_len, d_model)
168 | if pixel_values is not None:
169 | pixel_values = pixel_values.type(self.visual.dtype)
170 | image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
171 | n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
172 | n_image_features = image_embeds.shape[0]
173 | if n_image_tokens != n_image_features:
174 | raise ValueError(
175 | f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
176 | )
177 | image_mask = (
178 | (input_ids == self.config.image_token_id)
179 | .unsqueeze(-1)
180 | .expand_as(inputs_embeds)
181 | .to(inputs_embeds.device)
182 | )
183 | image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
184 | inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
185 |
186 | if pixel_values_videos is not None:
187 | pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
188 | video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
189 | n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
190 | n_video_features = video_embeds.shape[0]
191 | if n_video_tokens != n_video_features:
192 | raise ValueError(
193 | f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
194 | )
195 | video_mask = (
196 | (input_ids == self.config.video_token_id)
197 | .unsqueeze(-1)
198 | .expand_as(inputs_embeds)
199 | .to(inputs_embeds.device)
200 | )
201 | video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
202 | inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
203 |
204 | if attention_mask is not None:
205 | attention_mask = attention_mask.to(inputs_embeds.device)
206 |
207 | # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
208 | if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
209 | # calculate RoPE index once per generation in the pre-fill stage only
210 | if (
211 | (cache_position is not None and cache_position[0] == 0)
212 | or self.rope_deltas is None
213 | or (past_key_values is None or past_key_values.get_seq_length() == 0)
214 | ):
215 | position_ids, rope_deltas = self.get_rope_index(
216 | input_ids, image_grid_thw, video_grid_thw, attention_mask
217 | )
218 | self.rope_deltas = rope_deltas
219 | # then use the prev pre-calculated rope-deltas to get the correct position ids
220 | else:
221 | batch_size, seq_length, _ = inputs_embeds.shape
222 | delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
223 | position_ids = torch.arange(seq_length, device=inputs_embeds.device)
224 | position_ids = position_ids.view(1, -1).expand(batch_size, -1)
225 | if cache_position is not None: # otherwise `deltas` is an int `0`
226 | delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
227 | delta = delta.to(position_ids.device)
228 | position_ids = position_ids.add(delta)
229 | position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
230 |
231 | outputs = self.model(
232 | input_ids=None,
233 | position_ids=position_ids,
234 | attention_mask=attention_mask,
235 | past_key_values=past_key_values,
236 | inputs_embeds=inputs_embeds,
237 | use_cache=use_cache,
238 | output_attentions=output_attentions,
239 | output_hidden_states=output_hidden_states,
240 | return_dict=return_dict,
241 | cache_position=cache_position,
242 | )
243 |
244 | hidden_states = outputs[0] # shape: (batch_size, seq_len, d_model)
245 | logits = self.lm_head(hidden_states)
246 |
247 | lm_loss = None
248 | if labels is not None and self.lm_loss_weight > 0:
249 | # Upcast to float if we need to compute the loss to avoid potential precision issues
250 | logits = logits.float()
251 | # Shift so that tokens < n predict n
252 | shift_logits = logits[..., :-1, :].contiguous()
253 | shift_labels = labels[..., 1:].contiguous()
254 | # Flatten the tokens
255 | loss_fct = nn.CrossEntropyLoss()
256 | shift_logits = shift_logits.view(-1, self.config.vocab_size)
257 | shift_labels = shift_labels.view(-1)
258 | # Enable model parallelism
259 | shift_labels = shift_labels.to(shift_logits.device)
260 | lm_loss = loss_fct(shift_logits, shift_labels)
261 |
262 |
263 | # If vision supervision is requested, process the action head.
264 | pointer_loss = None
265 | pointer_scores = []
266 | if visual_token_indices_of_coordinates is not None:
267 | batch_size = input_ids.shape[0]
268 | pointer_losses = []
269 |
270 | # Process each sample individually because the number of visual and target tokens may vary.
271 | for i in range(batch_size):
272 | dummy_target = False
273 |
274 | # Get the token ids and corresponding hidden states for sample i.
275 | token_ids = input_ids[i] # shape: (seq_length,)
276 | hs = hidden_states[i] # shape: (seq_length, d_model)
277 |
278 | # Identify visual tokens indices.
279 | visual_mask = (token_ids == self.config.image_token_id)
280 | visual_indices = torch.nonzero(visual_mask, as_tuple=False).squeeze(-1) # shape: (n_visual,)
281 |
282 | # Identify target tokens (the ones that should attend to visual features).
283 | target_mask = (token_ids == self.config.pointer_pad_token_id)
284 | target_indices = torch.nonzero(target_mask, as_tuple=False).squeeze(-1)
285 |
286 | # If either visual or target tokens are missing, skip this sample.
287 | if visual_indices.numel() == 0:
288 | raise ValueError(f"No visual or target tokens found for sample {i}.")
289 | if target_indices.numel() == 0:
290 | target_indices = torch.tensor([hs.shape[0] - 1]) # take the last token as the dummy target token
291 | gt = torch.tensor([0]).to(hs.device) # take the first visual token as the dummy ground truth
292 | if if_multi_patch: # task the first 4 visual tokens as the ground truth
293 | sample_labels = torch.zeros_like(visual_indices).unsqueeze(0)
294 | sample_labels[0][:4] = 1
295 | dummy_target = True
296 | else:
297 | # For supervision, we assume that visual_token_indices_of_coordinates[i] is a tensor of shape (n_target,)
298 | # where each element is an integer in the range [0, n_visual-1] indicating the ground-truth visual token.
299 | gt = visual_token_indices_of_coordinates[i].to(hs.device) # shape: (n_target,)
300 | if if_multi_patch:
301 | sample_labels = multi_patch_labels[i]
302 |
303 | # Gather the corresponding hidden state representations.
304 | # visual_hidden = hs[visual_indices] # shape: (n_visual, d_model)
305 | visual_embeds = inputs_embeds[i][visual_indices]
306 | target_hidden = hs[target_indices] # shape: (n_target, d_model)
307 |
308 | # Calculate loss for multi-patch mode
309 | if if_multi_patch:
310 | # Ensure the number of targets matches between sample and labels
311 | if sample_labels.shape[0] != target_indices.shape[0]:
312 | raise ValueError(f"Sample {i} has mismatched target counts: {sample_labels.shape[0]} labels but found {target_indices.shape[0]} target tokens")
313 |
314 | # Process using VisionHead_MultiPatch
315 | attn_scores, loss_v = self.multi_patch_pointer_head(
316 | visual_embeds,
317 | target_hidden,
318 | labels=sample_labels
319 | )
320 |
321 | else:
322 | # Deprecated branch - single patch mode is no longer used
323 | # Run the action head to compute the attention (from target tokens to visual tokens) and its loss.
324 | attn_scores, loss_v = self.pointer_head(visual_embeds, target_hidden, labels=gt)
325 |
326 | pointer_scores.append(attn_scores.detach().cpu())
327 |
328 | pointer_losses.append(loss_v * 0.0 if dummy_target else loss_v)
329 |
330 | pointer_loss = torch.stack(pointer_losses).mean()
331 |
332 | # Combine the LM loss and vision loss using the provided loss weights.
333 |
334 | if lm_loss is None:
335 | total_loss = pointer_loss
336 | elif pointer_loss is None:
337 | total_loss = lm_loss
338 | else:
339 | total_loss = self.lm_loss_weight * lm_loss + self.pointer_loss_weight * pointer_loss
340 |
341 | if return_dict:
342 | return QwenVLwithVisionHeadOutputWithPast(
343 | lm_loss=lm_loss,
344 | pointer_loss=pointer_loss,
345 | pointer_scores=pointer_scores,
346 | loss=total_loss,
347 | logits=logits,
348 | past_key_values=outputs.past_key_values,
349 | hidden_states=outputs.hidden_states,
350 | attentions=outputs.attentions,
351 | rope_deltas=self.rope_deltas,
352 | )
353 | else:
354 | # When labels are provided, parent's forward returns a tuple with loss as the first element.
355 | if labels is not None:
356 | # Replace the LM loss with the combined loss.
357 | output = (lm_loss, pointer_loss, logits, pointer_scores,) + outputs[1:]
358 | print(f"returning: total_loss, logits, pointer_scores, ...")
359 | return (total_loss,) + output if total_loss is not None else output
360 | else:
361 | return outputs
--------------------------------------------------------------------------------
/src/gui_actor/modeling_qwen25vl.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLCausalLMOutputWithPast, Qwen2_5_VLForConditionalGeneration
6 | from gui_actor.constants import IGNORE_INDEX
7 | from typing import List, Tuple, Union, Optional
8 | from gui_actor.trainer import rank0_print
9 |
class QwenVLwithVisionHeadOutputWithPast(Qwen2_5_VLCausalLMOutputWithPast):
    """
    Qwen2.5-VL causal-LM output extended with the results of the pointer head.

    All positional/keyword arguments other than the three below are forwarded
    unchanged to the parent output class (``loss``, ``logits``,
    ``past_key_values``, ``hidden_states``, ``attentions``, ``rope_deltas``).

    Args:
        lm_loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
            Language-modeling loss on its own.
        pointer_loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
            Loss of the vision pointer head on its own.
        pointer_scores (`List[torch.FloatTensor]`, *optional*):
            Pointer-head attention scores, one tensor per batch item.

    The three extra values are stored as instance attributes after the parent
    initializer has run.
    """

    def __init__(self, lm_loss=None, pointer_loss=None, pointer_scores=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        extra_fields = {
            "lm_loss": lm_loss,
            "pointer_loss": pointer_loss,
            "pointer_scores": pointer_scores,
        }
        for field_name, field_value in extra_fields.items():
            setattr(self, field_name, field_value)
33 |
34 |
class VisionHead_MultiPatch(nn.Module):
    """Pointer head that scores visual patch tokens against query (pointer) tokens.

    The visual ("encoder") tokens are first contextualised by one self-attention
    layer with a residual connection and LayerNorm. Encoder and query ("decoder")
    states then pass through separate two-layer GELU MLPs. Each query scores every
    patch with a dot product scaled by ``sqrt(d_model)``.

    Args:
        d_model: Width of both the visual embeddings and the query hidden states.
        projection_dim: Inner width of the two projection MLPs.
        num_attention_heads: Number of heads of the self-attention over visual tokens.
        dropout_rate: Dropout inside the self-attention and on its residual branch.
    """

    def __init__(self, d_model, projection_dim, num_attention_heads=8, dropout_rate=0.1):
        super().__init__()
        self.d_model = d_model

        # Note: We omit additional normalization here because Qwen2VL
        # already normalizes hidden states using RMSNorm.
        self.projection_enc = nn.Sequential(
            nn.Linear(d_model, projection_dim),
            nn.GELU(),
            nn.Linear(projection_dim, d_model)
        )
        self.projection_dec = nn.Sequential(
            nn.Linear(d_model, projection_dim),
            nn.GELU(),
            nn.Linear(projection_dim, d_model)
        )

        # Self-attention over the visual tokens so every patch sees its context.
        self.self_attention = nn.MultiheadAttention(
            embed_dim=d_model,
            num_heads=num_attention_heads,
            dropout=dropout_rate,
            batch_first=True
        )

        # Layer normalization and dropout for the residual branch.
        self.layer_norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self,
                hidden_state_enc,  # shape: [n_enc, d_model] where n_enc can vary with image size
                hidden_state_dec,  # shape: [n_dec, d_model] there can be multiple queries in one sample
                labels: Optional[torch.Tensor] = None,
                do_single_patch: bool = False,
                ):
        """Score every visual token for every query token.

        Args:
            hidden_state_enc: ``[n_enc, d_model]`` visual-token states (unbatched).
            hidden_state_dec: ``[n_dec, d_model]`` query-token states (unbatched).
            labels: Optional supervision.
                - Multi-patch mode (default): ``[n_dec, n_enc]`` binary mask of patches
                  inside the target box. Each row is normalised into a target distribution.
                - Single-patch mode: ``[n_dec]`` integer index of the ground-truth patch per query.
            do_single_patch: Use cross-entropy against one patch index instead of the
                multi-patch KL objective.

        Returns:
            ``(attn_weights, loss)``:
                - ``attn_weights`` is ``[n_dec, n_enc]``, softmax-normalised over the patches.
                - ``loss`` is ``None`` when ``labels`` is ``None``. Otherwise it is the
                  ``batchmean`` KL divergence (multi-patch) or the mean cross-entropy
                  (single-patch).
        """
        # nn.MultiheadAttention was built with batch_first=True: add a batch dim of 1.
        enc_input = hidden_state_enc.unsqueeze(0)  # [1, n_enc, d_model]
        attn_output, _ = self.self_attention(
            query=enc_input,
            key=enc_input,
            value=enc_input,
            need_weights=False
        )
        # Residual connection + layer norm, then drop the batch dim again.
        hidden_state_enc_ctx = self.layer_norm(enc_input + self.dropout(attn_output)).squeeze(0)  # [n_enc, d_model]

        # Apply the projection networks.
        proj_enc = self.projection_enc(hidden_state_enc_ctx)  # [n_enc, d_model]
        proj_dec = self.projection_dec(hidden_state_dec)  # [n_dec, d_model]

        # Scaled dot-product scores. Scaling by sqrt(d_model) is applied regardless of n_enc.
        scaling = self.d_model ** 0.5
        patch_logits = torch.matmul(proj_dec, proj_enc.transpose(0, 1)) / scaling  # [n_dec, n_enc]

        # Softmax normalization over the encoder (patch) dimension.
        attn_weights = F.softmax(patch_logits, dim=-1)

        loss = None
        if labels is not None:
            labels = labels.to(patch_logits.device)
            if do_single_patch:
                # cross_entropy applies log-softmax itself, so it must receive the raw logits.
                loss = F.cross_entropy(patch_logits, labels)
            else:
                epsilon = 1e-8
                labels_float = labels.float()
                # Normalize each row into a target probability distribution over patches;
                # epsilon keeps all-zero rows from dividing by zero.
                target_dist = labels_float / (labels_float.sum(dim=-1, keepdim=True) + epsilon)
                pred_log_probs = F.log_softmax(patch_logits, dim=-1)
                # KL(target || prediction), averaged over the query rows.
                loss = F.kl_div(pred_log_probs, target_dist, reduction='batchmean')

        return attn_weights, loss
113 |
114 |
class Qwen2_5_VLForConditionalGenerationWithPointer(Qwen2_5_VLForConditionalGeneration):
    """Qwen2.5-VL conditional-generation model with an additional multi-patch pointer head.

    ``forward`` computes two losses:
    - the usual next-token LM loss, when ``labels`` is given;
    - a pointer loss, when ``visual_token_indices_of_coordinates`` is given. In each sample,
      every token equal to ``config.pointer_pad_token_id`` is used as a query that scores all
      tokens equal to ``config.image_token_id`` through ``multi_patch_pointer_head``.

    When both losses are present they are combined as
    ``lm_loss_weight * lm_loss + pointer_loss_weight * pointer_loss``.
    """

    def __init__(self, *args, **kwargs):
        # The loss weights are options of this subclass only. Pop them so they are not forwarded
        # to the Hugging Face parent constructor, which only takes the config.
        pointer_loss_weight = kwargs.pop("pointer_loss_weight", 1.0)
        lm_loss_weight = kwargs.pop("lm_loss_weight", 1.0)
        super().__init__(*args, **kwargs)
        # The pointer head uses the LM hidden size both as its width and as its projection width.
        self.multi_patch_pointer_head = VisionHead_MultiPatch(self.config.hidden_size, self.config.hidden_size)
        self.pointer_loss_weight = pointer_loss_weight
        self.lm_loss_weight = lm_loss_weight
        self.post_init()

    def reset_loss_weights(self, pointer_loss_weight, lm_loss_weight):
        """Overwrite the weights used to combine the pointer loss and the LM loss in ``forward``."""
        self.pointer_loss_weight = pointer_loss_weight
        self.lm_loss_weight = lm_loss_weight

    def _compute_pointer_loss(self,
                              input_ids,
                              inputs_embeds,
                              hidden_states,
                              visual_token_indices_of_coordinates,
                              multi_patch_labels,
                              if_multi_patch):
        """Run the pointer head sample by sample and average the per-sample losses.

        Samples are processed one at a time because the number of visual and target tokens
        may differ between them.

        Returns:
            ``(pointer_loss, pointer_scores)``:
                - ``pointer_loss`` is the mean of the per-sample losses.
                - ``pointer_scores`` is a list with one detached CPU score tensor per sample.

        Raises:
            ValueError: a sample has no image tokens, or its multi-patch label rows do not
                match its number of pointer-pad tokens.
        """
        pointer_losses = []
        pointer_scores = []

        for i in range(input_ids.shape[0]):
            token_ids = input_ids[i]  # (seq_length,)
            hs = hidden_states[i]  # (seq_length, d_model)

            # Positions of the visual tokens and of the pointer (target) query tokens.
            visual_indices = torch.nonzero(token_ids == self.config.image_token_id, as_tuple=False).squeeze(-1)
            target_indices = torch.nonzero(token_ids == self.config.pointer_pad_token_id, as_tuple=False).squeeze(-1)

            if visual_indices.numel() == 0:
                raise ValueError(f"No visual tokens found for sample {i}.")

            dummy_target = target_indices.numel() == 0
            if dummy_target:
                # No pointer token in this sample. Use the last token as a dummy query and the
                # first visual token(s) as a dummy target. The loss is multiplied by 0 below,
                # presumably so the head still takes part in the backward pass.
                # NOTE(review): confirm this is only needed for distributed training.
                target_indices = torch.tensor([hs.shape[0] - 1], device=hs.device)
                gt = torch.tensor([0], device=hs.device)
                if if_multi_patch:
                    # Mark the first 4 visual tokens as the dummy ground truth.
                    sample_labels = torch.zeros_like(visual_indices).unsqueeze(0)
                    sample_labels[0][:4] = 1
            else:
                # visual_token_indices_of_coordinates[i] is expected to hold, per target token,
                # an index in [0, n_visual - 1] of the ground-truth visual token.
                gt = visual_token_indices_of_coordinates[i].to(hs.device)  # (n_target,)
                if if_multi_patch:
                    sample_labels = multi_patch_labels[i]

            # The visual side uses the input embeddings; the query side uses final hidden states.
            visual_embeds = inputs_embeds[i][visual_indices]  # (n_visual, d_model)
            target_hidden = hs[target_indices]  # (n_target, d_model)

            if if_multi_patch:
                if sample_labels.shape[0] != target_indices.shape[0]:
                    raise ValueError(f"Sample {i} has mismatched target counts: {sample_labels.shape[0]} labels but found {target_indices.shape[0]} target tokens")
                attn_scores, loss_v = self.multi_patch_pointer_head(
                    visual_embeds,
                    target_hidden,
                    labels=sample_labels
                )
            else:
                # Deprecated single-patch mode. This model has no separate `pointer_head`, so the
                # multi-patch head is used in its single-patch (cross-entropy) mode.
                attn_scores, loss_v = self.multi_patch_pointer_head(
                    visual_embeds,
                    target_hidden,
                    labels=gt,
                    do_single_patch=True,
                )

            pointer_scores.append(attn_scores.detach().cpu())
            pointer_losses.append(loss_v * 0.0 if dummy_target else loss_v)

        return torch.stack(pointer_losses).mean(), pointer_scores

    def forward(self,
                input_ids: torch.LongTensor = None,  # (batch_size, seq_len)
                attention_mask: Optional[torch.Tensor] = None,
                position_ids: Optional[torch.LongTensor] = None,
                past_key_values: Optional[List[torch.FloatTensor]] = None,
                inputs_embeds: Optional[torch.FloatTensor] = None,
                labels: Optional[torch.LongTensor] = None,
                use_cache: Optional[bool] = None,
                output_attentions: Optional[bool] = None,
                output_hidden_states: Optional[bool] = None,
                return_dict: Optional[bool] = None,
                pixel_values: Optional[torch.Tensor] = None,
                pixel_values_videos: Optional[torch.FloatTensor] = None,
                image_grid_thw: Optional[torch.LongTensor] = None,
                video_grid_thw: Optional[torch.LongTensor] = None,
                rope_deltas: Optional[torch.LongTensor] = None,
                cache_position: Optional[torch.LongTensor] = None,
                second_per_grid_ts: Optional[torch.Tensor] = None,
                # Grounding
                visual_token_indices_of_coordinates: Optional[torch.Tensor] = None,  # shape: (batch_size, n_target); ground-truth visual-token index per target token
                multi_patch_labels: Optional[torch.Tensor] = None,  # shape: list [(n_target, n_visual), ...]; binary mask of patches in bbox
                if_multi_patch: bool = True,
                coordinates: Optional[List[Tuple[float, float]]] = None,
                verbose: bool = False) -> Union[Tuple, QwenVLwithVisionHeadOutputWithPast]:
        """Run the backbone, then compute the LM loss and/or the pointer loss.

        Args (grounding-specific; the others follow the parent Qwen2.5-VL ``forward``):
            visual_token_indices_of_coordinates: Per-sample ground-truth visual-token indices.
                Passing it enables the pointer head; ``None`` means no pointer loss.
            multi_patch_labels: Per-sample ``(n_target, n_visual)`` binary masks. Used when
                ``if_multi_patch`` is True for samples that contain pointer-pad tokens.
            if_multi_patch: Use the multi-patch KL objective (default) rather than the
                deprecated single-patch cross-entropy.
            coordinates: Only printed when ``verbose`` is True.
            verbose: Print a summary of the inputs via ``rank0_print``.

        Note:
            The ``rope_deltas`` argument is not read. The value cached on ``self.rope_deltas``
            is used instead.

        Returns:
            With ``return_dict``: a ``QwenVLwithVisionHeadOutputWithPast``.

            Otherwise, when ``labels`` is given: the tuple
            ``(total_loss, lm_loss, pointer_loss, logits, pointer_scores, *outputs[1:])``,
            where the leading ``total_loss`` is dropped if it is ``None``.

            Otherwise, the backbone outputs unchanged.
            NOTE(review): that last path carries no logits; confirm no caller relies on it.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if verbose:
            rank0_print(f"input_ids: {input_ids.shape}, {input_ids[0][:5]}...")
            # Guard optional inputs so verbose mode also works for inference calls.
            if labels is not None:
                rank0_print(f"labels: {labels.shape}, {labels[0][:5]}...")
            if pixel_values is not None:
                rank0_print(f"pixel_values: {pixel_values.shape}")
            if image_grid_thw is not None:
                rank0_print(f"image_grid_thw: {image_grid_thw.shape}, {image_grid_thw}")
            rank0_print(f"coordinates: {coordinates}")
            rank0_print(f"visual_token_indices_of_coordinates: {visual_token_indices_of_coordinates}")
            rank0_print(f"return_dict: {return_dict}")

        if inputs_embeds is None:
            inputs_embeds = self.model.embed_tokens(input_ids)  # shape: (batch_size, seq_len, d_model)
            if pixel_values is not None:
                # Replace image placeholder token embeddings with the vision encoder outputs.
                pixel_values = pixel_values.type(self.visual.dtype)
                image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
                n_image_tokens = (input_ids == self.config.image_token_id).sum().item()
                n_image_features = image_embeds.shape[0]
                if n_image_tokens != n_image_features:
                    raise ValueError(
                        f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
                    )
                image_mask = (
                    (input_ids == self.config.image_token_id)
                    .unsqueeze(-1)
                    .expand_as(inputs_embeds)
                    .to(inputs_embeds.device)
                )
                image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

            if pixel_values_videos is not None:
                # Same substitution for video placeholder tokens.
                pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
                video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
                n_video_tokens = (input_ids == self.config.video_token_id).sum().item()
                n_video_features = video_embeds.shape[0]
                if n_video_tokens != n_video_features:
                    raise ValueError(
                        f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}"
                    )
                video_mask = (
                    (input_ids == self.config.video_token_id)
                    .unsqueeze(-1)
                    .expand_as(inputs_embeds)
                    .to(inputs_embeds.device)
                )
                video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
                inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)

            if attention_mask is not None:
                attention_mask = attention_mask.to(inputs_embeds.device)

        # if we get 4D attention mask we cannot calculate rope deltas anymore. TODO @raushan fixme
        if position_ids is None and (attention_mask is None or attention_mask.ndim == 2):
            # calculate RoPE index once per generation in the pre-fill stage only
            if (
                (cache_position is not None and cache_position[0] == 0)
                or self.rope_deltas is None
                or (past_key_values is None or past_key_values.get_seq_length() == 0)
            ):
                # Pass second_per_grid_ts / attention_mask by keyword. Qwen2.5-VL's get_rope_index
                # takes second_per_grid_ts before attention_mask, so passing the mask positionally
                # would bind it to the wrong parameter and ignore padding.
                position_ids, rope_deltas = self.get_rope_index(
                    input_ids,
                    image_grid_thw,
                    video_grid_thw,
                    second_per_grid_ts=second_per_grid_ts,
                    attention_mask=attention_mask,
                )
                self.rope_deltas = rope_deltas
            # then use the prev pre-calculated rope-deltas to get the correct position ids
            else:
                batch_size, seq_length, _ = inputs_embeds.shape
                delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
                position_ids = torch.arange(seq_length, device=inputs_embeds.device)
                position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                if cache_position is not None:  # otherwise `deltas` is an int `0`
                    delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
                    delta = delta.to(position_ids.device)
                position_ids = position_ids.add(delta)
                position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)

        outputs = self.model(
            input_ids=None,
            position_ids=position_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        hidden_states = outputs[0]  # shape: (batch_size, seq_len, d_model)
        logits = self.lm_head(hidden_states)

        lm_loss = None
        if labels is not None and self.lm_loss_weight > 0:
            # Upcast to float if we need to compute the loss to avoid potential precision issues
            logits = logits.float()
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)
            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            lm_loss = loss_fct(shift_logits, shift_labels)

        # Pointer supervision is enabled by passing visual_token_indices_of_coordinates.
        pointer_loss = None
        pointer_scores = []
        if visual_token_indices_of_coordinates is not None:
            pointer_loss, pointer_scores = self._compute_pointer_loss(
                input_ids,
                inputs_embeds,
                hidden_states,
                visual_token_indices_of_coordinates,
                multi_patch_labels,
                if_multi_patch,
            )

        # Combine the LM loss and pointer loss. Weights apply only when both are present.
        if lm_loss is None:
            total_loss = pointer_loss
        elif pointer_loss is None:
            total_loss = lm_loss
        else:
            total_loss = self.lm_loss_weight * lm_loss + self.pointer_loss_weight * pointer_loss

        if return_dict:
            return QwenVLwithVisionHeadOutputWithPast(
                lm_loss=lm_loss,
                pointer_loss=pointer_loss,
                pointer_scores=pointer_scores,
                loss=total_loss,
                logits=logits,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
                rope_deltas=self.rope_deltas,
            )
        if labels is not None:
            output = (lm_loss, pointer_loss, logits, pointer_scores,) + outputs[1:]
            return (total_loss,) + output if total_loss is not None else output
        return outputs
--------------------------------------------------------------------------------
/src/gui_actor/trainer.py:
--------------------------------------------------------------------------------
1 | from datetime import timedelta
2 | from functools import wraps
3 | from typing import Optional
4 |
5 | import torch
6 | import torch.distributed as dist
7 | import transformers
8 | from accelerate import Accelerator, DataLoaderConfiguration
9 | from accelerate.utils import GradientAccumulationPlugin, InitProcessGroupKwargs
10 | from torch.utils.data import DataLoader, RandomSampler
11 | from transformers import Trainer
12 | from transformers.trainer import (
13 | ALL_LAYERNORM_LAYERS,
14 | get_parameter_names,
15 | has_length,
16 | is_accelerate_available,
17 | is_datasets_available,
18 | is_sagemaker_mp_enabled,
19 | )
20 | from transformers.trainer_pt_utils import LengthGroupedSampler as HFLengthGroupedSampler
21 | from transformers.trainer_utils import seed_worker
22 | from transformers.utils import logging
23 |
24 | if is_datasets_available():
25 | import datasets
26 |
27 |
def rank0_print(*args):
    """Print ``args``, but only from rank 0 when torch.distributed is initialised.

    In a distributed run the output is prefixed with ``"Rank 0: "``. Outside a
    distributed run, every process prints the arguments unchanged.
    """
    if not dist.is_initialized():
        print(*args)
        return
    rank = dist.get_rank()
    if rank == 0:
        print(f"Rank {rank}: ", *args)
34 |
35 |
def maybe_zero_3(param, ignore_status=False, name=None):
    """Return a detached, cloned CPU copy of ``param``, gathering it first if ZeRO-3 partitioned.

    Args:
        param: A tensor or parameter. Objects with a ``ds_id`` attribute are treated as
            DeepSpeed-managed. They are gathered with ``deepspeed.zero.GatheredParameters``
            before being copied.
        ignore_status: When False, log a warning if a DeepSpeed parameter's ``ds_status`` is
            ``ZeroParamStatus.NOT_AVAILABLE``.
        name: Optional parameter name included in that warning.

    Returns:
        A CPU tensor that is detached from the autograd graph and does not share storage
        with ``param``.
    """
    if not hasattr(param, "ds_id"):
        # Plain tensor: no DeepSpeed involvement, so DeepSpeed does not need to be installed.
        return param.detach().cpu().clone()

    from deepspeed import zero
    from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus

    if param.ds_status == ZeroParamStatus.NOT_AVAILABLE and not ignore_status:
        # `logging` in this module is transformers.utils.logging, which exposes get_logger()
        # rather than a module-level warning() function.
        logging.get_logger(__name__).warning(
            f"{name}: param.ds_status == ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}"
        )
    with zero.GatheredParameters([param]):
        # Copy while the full parameter is gathered; the partition is released on exit.
        return param.data.detach().cpu().clone()
48 |
49 |
def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match):
    """Collect CPU copies of the parameters whose names contain any of ``keys_to_match``.

    Args:
        named_params: Iterable of ``(name, tensor)`` pairs, e.g. ``model.named_parameters()``.
        keys_to_match: Substrings; a parameter is kept if any of them occurs in its name.

    Returns:
        Dict mapping each selected name to a CPU tensor produced by ``maybe_zero_3``
        with ``ignore_status=True``.
    """
    selected = {}
    for param_name, tensor in named_params:
        if any(fragment in param_name for fragment in keys_to_match):
            selected[param_name] = tensor

    result = {}
    for param_name, tensor in selected.items():
        result[param_name] = maybe_zero_3(tensor, ignore_status=True).cpu()
    return result
54 |
55 |
def safe_save_model_for_hf_trainer(trainer: "transformers.Trainer", output_dir: str):
    """Collect the model state dict and dump it to ``output_dir``.

    Steps:
    1. Wait for all processes via ``trainer.accelerator.wait_for_everyone()``.
    2. Under DeepSpeed, delegate to ``trainer.save_model``.
    3. Otherwise, take ``trainer.model.state_dict()``. On processes where
       ``trainer.args.should_save`` is true, move every tensor to CPU and write it
       with ``trainer._save``.
    """
    trainer.accelerator.wait_for_everyone()
    # torch.cuda.synchronize() raises when CUDA is unavailable, so only sync when it is usable.
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    if trainer.deepspeed:
        trainer.save_model(output_dir)
        return

    state_dict = trainer.model.state_dict()
    if trainer.args.should_save:
        cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
        # Drop the device-side references before writing to reduce peak memory.
        del state_dict
        trainer._save(output_dir, state_dict=cpu_state_dict)
70 |
71 |
72 | class AGUVISTrainer(Trainer):
73 |
def __init__(self, *args, **kwargs):
    """Initialise the HF ``Trainer`` and wrap its save methods to use ``<|diff_marker|>`` as EOS.

    While ``_save`` / ``save_model`` run, these are switched to ``<|diff_marker|>``:
    - the model config's ``eos_token_id``;
    - the tokenizer's ``eos_token``;
    - the generation config's ``eos_token_id``, if the model has a generation config.

    The previous values are restored in a ``finally`` block, so they come back even if
    saving raises.
    """
    super().__init__(*args, **kwargs)

    original_save = self._save
    original_save_model = self.save_model

    def modify_eos_token(func):
        """Wrap ``func`` so it runs with ``<|diff_marker|>`` temporarily set as the EOS token."""
        @wraps(func)
        def wrapper(*args, **kwargs):
            tokenizer = self.processing_class.tokenizer
            has_generation_config = hasattr(self.model, "generation_config")
            old_config_id = self.model.config.eos_token_id
            old_eos_token = tokenizer.eos_token
            old_generation_config_eos_token_id = (
                self.model.generation_config.eos_token_id if has_generation_config else None
            )

            try:
                new_eos_token_id = tokenizer.convert_tokens_to_ids("<|diff_marker|>")
                self.model.config.eos_token_id = [new_eos_token_id]
                tokenizer.eos_token = "<|diff_marker|>"
                if has_generation_config:
                    self.model.generation_config.eos_token_id = [new_eos_token_id]

                print("Set eos token id to", new_eos_token_id)
                print("Set eos token to", "<|diff_marker|>")
                print("Set generation config eos token id to", [new_eos_token_id])

                result = func(*args, **kwargs)
                return result
            finally:
                self.model.config.eos_token_id = old_config_id
                tokenizer.eos_token = old_eos_token
                # Restore whenever a generation config exists, even if the old value was None.
                # Otherwise the temporary EOS would leak into the live generation config.
                if has_generation_config:
                    self.model.generation_config.eos_token_id = old_generation_config_eos_token_id

                print("Set eos token id back to", old_config_id)
                print("Set eos token back to", old_eos_token)
                if has_generation_config:
                    print("Set generation config eos token id back to", old_generation_config_eos_token_id)

        return wrapper

    self._save = modify_eos_token(original_save)
    self.save_model = modify_eos_token(original_save_model)
118 |
def create_accelerator_and_postprocess(self):
    """Create ``self.accelerator`` and apply the post-creation setup.

    Configuration used here:
    - gradient accumulation over ``args.gradient_accumulation_steps`` steps with
      ``sync_with_dataloader=False``;
    - a process-group initialisation timeout of 52 weeks;
    - ``dispatch_batches`` / ``split_batches`` taken from ``self.args`` when present
      (``None`` otherwise).

    Afterwards it:
    - sets ``gather_function``, ``is_deepspeed_enabled`` and ``is_fsdp_enabled``;
    - copies FSDP options from ``args.fsdp_config``;
    - calls ``propagate_args_to_deepspeed`` when DeepSpeed is enabled without an
      ``hf_deepspeed_config``.

    Raises:
        ValueError: FSDP activation checkpointing and ``gradient_checkpointing`` are both enabled.
    """
    accumulation_plugin = GradientAccumulationPlugin(
        num_steps=self.args.gradient_accumulation_steps,
        sync_with_dataloader=False,
    )
    process_group_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))

    self.dataloader_config = DataLoaderConfiguration(
        dispatch_batches=getattr(self.args, "dispatch_batches", None),
        split_batches=getattr(self.args, "split_batches", None),
    )
    self.accelerator = Accelerator(
        dataloader_config=self.dataloader_config,
        deepspeed_plugin=self.args.deepspeed_plugin,
        gradient_accumulation_plugin=accumulation_plugin,
        kwargs_handlers=[process_group_kwargs],
    )
    # Stored as an attribute so that subclasses can swap in `gather` instead.
    self.gather_function = self.accelerator.gather_for_metrics

    # Flags covering both trainer args and the accelerate launcher.
    accel_state = self.accelerator.state
    self.is_deepspeed_enabled = getattr(accel_state, "deepspeed_plugin", None) is not None
    self.is_fsdp_enabled = getattr(accel_state, "fsdp_plugin", None) is not None

    if self.is_fsdp_enabled:
        fsdp_plugin = accel_state.fsdp_plugin
        fsdp_config = self.args.fsdp_config
        fsdp_plugin.limit_all_gathers = fsdp_config.get(
            "limit_all_gathers", fsdp_plugin.limit_all_gathers
        )
        if is_accelerate_available("0.23.0"):
            fsdp_plugin.activation_checkpointing = fsdp_config.get(
                "activation_checkpointing", fsdp_plugin.activation_checkpointing
            )
            both_checkpointing = fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing
            if both_checkpointing:
                raise ValueError(
                    "The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg "
                    "can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic "
                    "when using FSDP."
                )

    needs_ds_propagation = self.is_deepspeed_enabled and getattr(self.args, "hf_deepspeed_config", None) is None
    if needs_ds_propagation:
        self.propagate_args_to_deepspeed()
165 |
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
    """Pick the sampler for the training set.

    Returns:
        - ``None`` when there is no training dataset or it has no length.
        - A length-grouped sampler when ``args.group_by_length`` or
          ``args.group_by_modality_length`` is set. It uses ``dataset.lengths`` or
          ``dataset.modality_lengths`` respectively, and a group size of
          ``train_batch_size * gradient_accumulation_steps``.
        - Otherwise a ``RandomSampler``.
    """
    dataset = self.train_dataset
    if dataset is None or not has_length(dataset):
        return None

    if self.args.group_by_length:
        length_attribute = "lengths"
    elif self.args.group_by_modality_length:
        length_attribute = "modality_lengths"
    else:
        return RandomSampler(dataset)

    per_sample_lengths = getattr(dataset, length_attribute)
    group_size = self.args.train_batch_size * self.args.gradient_accumulation_steps
    return HFLengthGroupedSampler(group_size, dataset=dataset, lengths=per_sample_lengths)
186 |
def get_train_dataloader(self) -> DataLoader:
    """
    Build the training [`~torch.utils.data.DataLoader`] and pass it through `accelerator.prepare`.

    For Hugging Face `datasets.Dataset` inputs, unused columns are removed from the dataset; for any
    other dataset they are removed by wrapping the collator instead. Map-style datasets additionally get
    the sampler from `_get_train_sampler`, `drop_last`, the `seed_worker` init function and a prefetch
    factor of `2 * dataloader_num_workers` (`None` when `dataloader_num_workers == 0`). Iterable datasets
    only receive the shared batch/collate/worker options.

    Subclass and override this method if you want to inject some custom behavior.

    Raises:
        ValueError: if the trainer has no `train_dataset`.
    """
    if self.train_dataset is None:
        raise ValueError("Trainer: training requires a train_dataset.")

    dataset = self.train_dataset
    collator = self.data_collator
    is_hf_dataset = is_datasets_available() and isinstance(dataset, datasets.Dataset)
    if is_hf_dataset:
        dataset = self._remove_unused_columns(dataset, description="training")
    else:
        collator = self._get_collator_with_removed_columns(collator, description="training")

    num_workers = self.args.dataloader_num_workers
    loader_kwargs = dict(
        batch_size=self._train_batch_size,
        collate_fn=collator,
        num_workers=num_workers,
        pin_memory=self.args.dataloader_pin_memory,
        persistent_workers=self.args.dataloader_persistent_workers,
    )

    if not isinstance(dataset, torch.utils.data.IterableDataset):
        loader_kwargs.update(
            sampler=self._get_train_sampler(),
            drop_last=self.args.dataloader_drop_last,
            worker_init_fn=seed_worker,
            # torch's DataLoader rejects a non-None prefetch_factor when loading in the main process.
            prefetch_factor=(num_workers * 2) if num_workers != 0 else None,
        )

    return self.accelerator.prepare(DataLoader(dataset, **loader_kwargs))
225 |
def create_optimizer(self):
    """
    Setup the optimizer.

    Trainable parameters are split into two groups: names returned by
    `get_parameter_names(model, ALL_LAYERNORM_LAYERS)` that do not contain "bias" use
    `args.weight_decay`; all other trainable parameters use a weight decay of 0.0. The optimizer class
    and its kwargs come from `Trainer.get_optimizer_cls_and_kwargs(self.args)`. If `self.optimizer` is
    already set it is returned unchanged; with SageMaker model parallel the parent implementation is used.

    We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the
    Trainer's init through `optimizers`, or subclass and override this method in a subclass.
    """
    if is_sagemaker_mp_enabled():
        return super().create_optimizer()

    if self.optimizer is not None:
        return self.optimizer

    model = self.model
    decay_names = [name for name in get_parameter_names(model, ALL_LAYERNORM_LAYERS) if "bias" not in name]

    with_decay = []
    without_decay = []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # frozen parameters are not handed to the optimizer
        if name in decay_names:
            with_decay.append(param)
        else:
            without_decay.append(param)

    param_groups = [
        {"params": with_decay, "weight_decay": self.args.weight_decay},
        {"params": without_decay, "weight_decay": 0.0},
    ]
    optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
    self.optimizer = optimizer_cls(param_groups, **optimizer_kwargs)
    return self.optimizer
261 |
def create_optimizer_with_different_learning_rates(self):
    """
    Setup an optimizer in which newly added parameters get their own learning rate.

    Parameters whose name contains "pointer_head" or "embed_tokens" are treated as *new* and use
    `args.learning_rate_new_params`; all other trainable parameters use `args.learning_rate`. Each of
    the two sets is split again into a weight-decay group (names returned by
    `get_parameter_names(model, ALL_LAYERNORM_LAYERS)` that do not contain "bias", decayed with
    `args.weight_decay`) and a zero-decay group, giving four groups in this order:
    old+decay, old+no-decay, new+decay, new+no-decay. Frozen parameters are skipped.

    If `self.optimizer` is already set it is returned unchanged.

    Raises:
        NotImplementedError: when SageMaker model parallel is enabled.
    """
    if is_sagemaker_mp_enabled():
        raise NotImplementedError("Sagemaker MP is not supported for separate learning rate yet")

    if self.optimizer is None:
        opt_model = self.model

        # Sets give O(1) membership tests (lists made the grouping quadratic in #parameters).
        decay_parameters = {
            name for name in get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) if "bias" not in name
        }
        new_parameters = {
            name for name, _ in opt_model.named_parameters()
            if ("pointer_head" in name) or ("embed_tokens" in name)
        }
        rank0_print(f"new_parameters: {len(new_parameters)}")

        # One pass over the parameters; bucket key is (is_new, has_decay). Order inside each bucket
        # follows named_parameters(), exactly as the per-group comprehensions did.
        buckets = {(False, True): [], (False, False): [], (True, True): [], (True, False): []}
        for name, param in opt_model.named_parameters():
            if param.requires_grad:
                buckets[(name in new_parameters, name in decay_parameters)].append(param)

        base_lr = self.args.learning_rate
        new_lr = self.args.learning_rate_new_params
        weight_decay = self.args.weight_decay
        optimizer_grouped_parameters = [
            {"params": buckets[(False, True)], "weight_decay": weight_decay, "lr": base_lr},
            {"params": buckets[(False, False)], "weight_decay": 0.0, "lr": base_lr},
            {"params": buckets[(True, True)], "weight_decay": weight_decay, "lr": new_lr},
            {"params": buckets[(True, False)], "weight_decay": 0.0, "lr": new_lr},
        ]

        optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args)
        # Every group carries its own "lr", so drop the global one; the default keeps this from
        # depending on the key being present.
        optimizer_kwargs.pop("lr", None)

        self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)

    return self.optimizer
--------------------------------------------------------------------------------
/src/gui_actor/utils.py:
--------------------------------------------------------------------------------
1 | from PIL import Image, ImageDraw, ImageColor
2 | import json
3 | import os
4 |
def dump_args_to_json(model_config, data_processor, model_args, data_args, training_args, output_dir):
    """Write the JSON-serializable attributes of the run configuration to ``<output_dir>/args.json``.

    For each object, the entries of its ``__dict__`` that ``json.dumps`` accepts are kept and
    everything else is dropped. The image-processor section is read from
    ``data_processor.image_processor``. An existing ``args.json`` is never overwritten.

    The payload is fully built before the file is opened. An error while collecting attributes
    therefore no longer leaves an empty ``args.json`` behind, which the exists-check would
    otherwise treat as already written.

    Args:
        model_config: model configuration object.
        data_processor: processor object exposing an ``image_processor`` attribute.
        model_args, data_args, training_args: argument objects to record.
        output_dir: existing directory to write ``args.json`` into.
    """
    def is_json_serializable(v):
        # Catch Exception rather than using a bare except so KeyboardInterrupt/SystemExit propagate.
        # json.dumps raises TypeError for unsupported types, ValueError for circular references and
        # RecursionError for overly deep nesting.
        try:
            json.dumps(v)
            return True
        except Exception:
            return False

    def serializable_attrs(obj):
        return {k: v for k, v in obj.__dict__.items() if is_json_serializable(v)}

    save_path = os.path.join(output_dir, "args.json")
    if os.path.exists(save_path):
        return

    payload = {
        "model_config": serializable_attrs(model_config),
        "data_processor_config": serializable_attrs(data_processor),
        "image_processor_config": serializable_attrs(data_processor.image_processor),
        "model_args": serializable_attrs(model_args),
        "data_args": serializable_attrs(data_args),
        "training_args": serializable_attrs(training_args),
    }
    with open(save_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=4)
24 |
def draw_point(image: Image.Image, point: list, color=None):
    """Mark ``point`` on ``image`` with a translucent square and a small green centre dot.

    Args:
        image: source image. A converted RGBA copy is used for compositing; ``image`` itself is not modified.
        point: ``(x, y)`` pixel coordinates.
        color: optional PIL color string (name, ``#rrggbb``, ``#rrggbbaa``, ``rgb(...)``, ...).
            Any alpha in the string is ignored and the square is always drawn at alpha 128.
            Unparseable strings and non-string values fall back to red.

    Returns:
        A new RGB image with the marker composited on top.
    """
    fill = (255, 0, 0, 128)
    if isinstance(color, str):
        try:
            # getrgb returns 4 values for colors that carry alpha; keep RGB only so the fill is a
            # valid 4-tuple instead of an invalid 5-tuple.
            fill = tuple(ImageColor.getrgb(color)[:3]) + (128,)
        except ValueError:
            pass

    overlay = Image.new('RGBA', image.size, (255, 255, 255, 0))
    overlay_draw = ImageDraw.Draw(overlay)
    radius = 14
    x, y = point

    overlay_draw.rectangle(
        [x - radius, y - radius, x + radius, y + radius],
        fill=fill
    )

    center_radius = radius * 0.1
    overlay_draw.ellipse(
        [(x - center_radius, y - center_radius),
         (x + center_radius, y + center_radius)],
        fill=(0, 255, 0, 255)
    )

    combined = Image.alpha_composite(image.convert('RGBA'), overlay)
    return combined.convert('RGB')
56 |
def draw_bbox(image: Image.Image, bbox: list, color=None):
    """Fill ``bbox`` on ``image`` with a translucent color.

    bbox is in the format of [x1, y1, x2, y2].

    Args:
        image: source image of any mode. A converted RGBA copy is used for compositing; ``image`` itself is not modified.
        bbox: ``[x1, y1, x2, y2]`` pixel coordinates.
        color: optional PIL color string. Any alpha in it is ignored and the box is drawn at alpha 128.
            Unparseable strings and non-string values fall back to red.

    Returns:
        A new RGB image with the box composited on top.
    """
    fill = (255, 0, 0, 128)
    if isinstance(color, str):
        try:
            # Drop any alpha returned by getrgb so the fill stays a valid RGBA 4-tuple.
            fill = tuple(ImageColor.getrgb(color)[:3]) + (128,)
        except ValueError:
            pass

    overlay = Image.new('RGBA', image.size, (255, 255, 255, 0))
    ImageDraw.Draw(overlay).rectangle(bbox, fill=fill)
    # Image.alpha_composite requires both images in RGBA mode; convert so RGB inputs work too.
    return Image.alpha_composite(image.convert('RGBA'), overlay).convert('RGB')
72 |
def do_boxes_overlap(box1, box2):
    """
    Return True if two axis-aligned boxes intersect.

    Each box is ``(x1, y1, x2, y2)``, with ``(x1, y1)`` the top-left and ``(x2, y2)`` the
    bottom-right corner. The separation tests are strict, so boxes that only touch along an
    edge or at a corner count as overlapping.
    """
    left_a, top_a, right_a, bottom_a = box1
    left_b, top_b, right_b, bottom_b = box2

    separated_horizontally = right_a < left_b or right_b < left_a
    separated_vertically = bottom_a < top_b or bottom_b < top_a
    return not (separated_horizontally or separated_vertically)
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import pathlib
2 | from dataclasses import dataclass, field
3 | from typing import Dict, Optional, Sequence
4 |
5 | import os
6 | import json
7 | import torch
8 | import transformers
9 | from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
10 | from PIL import ImageFile
11 | from transformers import (
12 | Qwen2VLForConditionalGeneration,
13 | AutoProcessor,
14 | )
15 |
16 | from gui_actor.dataset import LazySupervisedDataset
17 | from gui_actor.trainer import AGUVISTrainer, rank0_print, safe_save_model_for_hf_trainer
18 | from gui_actor.utils import dump_args_to_json
19 |
20 | from gui_actor.constants import (
21 | IGNORE_INDEX,
22 | ADDITIONAL_SPECIAL_TOKENS,
23 | DEFAULT_POINTER_START_TOKEN,
24 | DEFAULT_POINTER_END_TOKEN,
25 | DEFAULT_POINTER_PAD_TOKEN,
26 | )
27 |
28 | from gui_actor.modeling import Qwen2VLForConditionalGenerationWithPointer
29 | from gui_actor.modeling_qwen25vl import Qwen2_5_VLForConditionalGenerationWithPointer
30 |
# Patch the Qwen2-VL modules with Liger kernels at import time, before train() builds a model.
# NOTE(review): only the Qwen2-VL patch is applied; confirm whether model_type="qwen25vl" runs
# need the Qwen2.5-VL variant as well.
apply_liger_kernel_to_qwen2_vl()

# Share tensors between DataLoader worker processes via the file system instead of file descriptors.
torch.multiprocessing.set_sharing_strategy("file_system")

# Let PIL open truncated image files instead of raising.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Filled in by train() from the parsed training arguments.
local_rank = None
37 |
38 |
@dataclass
class ModelArguments:
    """Which base checkpoint to load and which pointer-model variant to build on it."""

    # Hub id or local path of the base checkpoint (also used for the tokenizer and processor).
    model_name_or_path: Optional[str] = "facebook/opt-125m"
    # When true, the model is loaded with attn_implementation="flash_attention_2".
    flash_attn_2_enabled: bool = True
    model_type: str = field(default="qwen2vl", metadata={"help": "model type: qwen2vl or qwen25vl"})
44 |
@dataclass
class DataArguments:
    """Training-data location and image-resolution limits for the processor."""

    # Passed to LazySupervisedDataset as ``data_path``. The default is None, so the annotation must be
    # Optional (it was previously declared as plain ``str``).
    data_path: Optional[str] = field(default=None)
    early_mix_text: bool = False
    image_folder: Optional[str] = field(default=None)
    # min_pixels / max_pixels are handed to AutoProcessor.from_pretrained.
    min_pixels: Optional[int] = field(default=3136)  # 2 * 2 * 28 * 28 = 56 * 56
    max_pixels: Optional[int] = field(default=5720064)  # 5720064 = 114 * 64 * 28 * 28 = 3192 * 1792, 12845056 = 128 * 128 * 28 * 28
    max_conv_turns: Optional[int] = field(default=10)  # 30 => 20 => 10
53 |
54 |
@dataclass
class TrainingArguments(transformers.TrainingArguments):
    """``transformers.TrainingArguments`` extended with GUI-Actor specific options."""

    cache_dir: Optional[str] = None
    optim: str = "adamw_torch"
    model_max_length: int = field(
        default=8192,
        metadata={"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."},
    )
    group_by_modality_length: bool = False
    gradient_checkpointing: bool = True
    verbose_logging: bool = False

    # Which parts of the model are trainable; applied by setup_params_to_update().
    unfreeze_all_parameters: bool = False
    unfreeze_pointer_head: bool = True
    unfreeze_lm_head: bool = False
    unfreeze_base_model: bool = False
    unfreeze_last_n_layers: int = -1  # values <= 0 leave the layers as they are
    unfreeze_new_tokens: bool = True
    unfreeze_visual: bool = False

    # Loss weights passed to model.reset_loss_weights().
    # NOTE(review): lm_loss_weight=-1.0 presumably means "use the model's default"; confirm there.
    pointer_loss_weight: float = 0.1
    lm_loss_weight: float = -1.0
76 |
77 | # def mask_embedding_grad(grad):
78 | # n_new_tokens = len(ADDITIONAL_SPECIAL_TOKENS)
79 | # mask = torch.zeros_like(grad)
80 | # mask[-n_new_tokens:] = 1.0
81 | # return grad * mask
82 |
def smart_tokenizer_and_embedding_resize(
    special_tokens_dict: Dict,
    tokenizer: transformers.PreTrainedTokenizer,
    model: transformers.PreTrainedModel,
):
    """Add special tokens to ``tokenizer`` and resize ``model``'s embeddings to ``len(tokenizer)``.

    The new vocabulary size is written to ``model.config.text_config.vocab_size`` when the config has
    a ``text_config`` (otherwise to ``model.config.vocab_size``) and to ``model.vocab_size``. Rows for
    the newly added tokens in both the input and output embedding matrices are initialised to the
    mean of the pre-existing rows.

    Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
    """
    added = tokenizer.add_special_tokens(special_tokens_dict)
    vocab_size = len(tokenizer)
    model.resize_token_embeddings(vocab_size)

    # Update base model and current model config
    config_to_update = model.config.text_config if hasattr(model.config, "text_config") else model.config
    config_to_update.vocab_size = vocab_size
    model.vocab_size = vocab_size

    if added <= 0:
        return

    embedding_matrices = (
        model.get_input_embeddings().weight.data,
        model.get_output_embeddings().weight.data,
    )
    for matrix in embedding_matrices:
        # The new tokens occupy the last `added` rows after resizing to len(tokenizer).
        matrix[-added:] = matrix[:-added].mean(dim=0, keepdim=True)
112 |
113 |
def update_pointer_token_ids(model_config: transformers.PretrainedConfig, tokenizer: transformers.PreTrainedTokenizer):
    """Record the ids of the pointer start/end/pad special tokens on ``model_config``.

    Each id is the first element of ``tokenizer.encode(<token string>)``.
    NOTE(review): this assumes the tokenizer does not prepend a BOS token when encoding.
    """
    def first_token_id(text):
        return tokenizer.encode(text)[0]

    model_config.pointer_start_token_id = first_token_id(DEFAULT_POINTER_START_TOKEN)
    model_config.pointer_end_token_id = first_token_id(DEFAULT_POINTER_END_TOKEN)
    model_config.pointer_pad_token_id = first_token_id(DEFAULT_POINTER_PAD_TOKEN)

    pad_id = model_config.pointer_pad_token_id
    start_id = model_config.pointer_start_token_id
    end_id = model_config.pointer_end_token_id
    rank0_print(f"Updated pointer token ids: {pad_id}, {start_id}, {end_id}")
119 |
def setup_params_to_update(model: transformers.PreTrainedModel, training_args: TrainingArguments):
    """Set ``requires_grad`` on the model parameters according to the ``unfreeze_*`` flags.

    All parameters are first unfrozen (``unfreeze_all_parameters``) or frozen (otherwise). Each
    enabled flag then re-enables gradients for its part of the model, in this order:
    ``multi_patch_pointer_head``, ``lm_head``, the whole ``model.model``, the last
    ``unfreeze_last_n_layers`` layers of ``model.model.layers``, the ``model.model.embed_tokens``
    weight, and ``visual``.
    """
    def set_requires_grad(params, enabled):
        for param in params:
            param.requires_grad = enabled

    if training_args.unfreeze_all_parameters:
        rank0_print("Unfreezing all model parameters...")
        set_requires_grad(model.parameters(), True)
    else:
        rank0_print("Freezing all model parameters...")
        set_requires_grad(model.parameters(), False)

    if training_args.unfreeze_pointer_head:
        rank0_print("Unfreezing pointer head parameters...")
        set_requires_grad(model.multi_patch_pointer_head.parameters(), True)

    if training_args.unfreeze_lm_head:
        rank0_print("Unfreezing lm head parameters...")
        set_requires_grad(model.lm_head.parameters(), True)

    if training_args.unfreeze_base_model:  # including text tokens
        rank0_print("Unfreezing base model parameters...")
        set_requires_grad(model.model.parameters(), True)

    last_n = training_args.unfreeze_last_n_layers
    if last_n > 0:
        rank0_print(f"Unfreezing last {last_n} layers of base model parameters...")
        set_requires_grad(model.model.layers[-last_n:].parameters(), True)

    if training_args.unfreeze_new_tokens:
        rank0_print("Unfreezing new tokens parameters via embedding hook...")
        # Only the flag is set here. A hook registered before Trainer initialization is not effective,
        # so the gradient mask restricting updates to the new rows is registered after the Trainer is built.
        model.model.embed_tokens.weight.requires_grad = True

    if training_args.unfreeze_visual:
        rank0_print("Unfreezing visual parameters...")
        set_requires_grad(model.visual.parameters(), True)
162 |
@dataclass
class DataCollatorForSupervisedDataset:
    """Collate examples for supervised fine-tuning.

    Each example is a dict holding ``input_ids`` and ``labels`` tensors and, optionally,
    ``pixel_values``/``image_grid_thw`` and the pointer fields ``coordinates``,
    ``visual_token_indices_of_coordinates`` and ``multi_patch_labels``.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def pad_sequence(self, input_ids, batch_first, padding_value):
        """Pad a list of sequence tensors on the side given by ``tokenizer.padding_side``.

        ``torch.nn.utils.rnn.pad_sequence`` always pads at the end. For left padding, each sequence
        is therefore reversed first and the padded batch is flipped back along dim 1. This assumes
        1-D sequences and ``batch_first=True``, which is how ``__call__`` uses it.
        """
        pad_left = self.tokenizer.padding_side == "left"
        if pad_left:
            input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids]
        input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=batch_first, padding_value=padding_value)
        if pad_left:
            input_ids = torch.flip(input_ids, [1])
        return input_ids

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Build a padded training batch from ``instances``.

        ``input_ids``/``labels`` are truncated to ``tokenizer.model_max_length`` and padded with the
        pad token id and ``IGNORE_INDEX`` respectively. If the tokenizer has no pad token id, it is
        set to 0 (this mutates the tokenizer). ``pixel_values`` and ``image_grid_thw`` are
        concatenated along dim 0 over every example that has image inputs, so text-only examples
        can appear anywhere in the batch. Pointer fields are passed through as per-example lists
        when the first example has them.
        """
        max_len = self.tokenizer.model_max_length
        input_ids = [instance["input_ids"][:max_len] for instance in instances]
        labels = [instance["labels"][:max_len] for instance in instances]
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = 0  # This gets the best result. Don't know why.
        pad_id = self.tokenizer.pad_token_id
        input_ids = self.pad_sequence(input_ids, batch_first=True, padding_value=pad_id)
        labels = self.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        batch = {
            "input_ids": input_ids,
            # Cross-entropy style losses need int64 targets.
            "labels": labels.long() if labels.dtype == torch.int32 else labels,
            # NOTE(review): a real token equal to pad_token_id is masked out as well; a mask built
            # from the sequence lengths would avoid that if such tokens can occur in the data.
            "attention_mask": input_ids.ne(pad_id),
        }

        # Gather image inputs from every example that has them, instead of trusting the first one:
        # a text-only first example used to drop all images, and a text-only later one raised KeyError.
        with_images = [instance for instance in instances if instance.get("pixel_values") is not None]
        if with_images:
            batch["pixel_values"] = torch.concat([instance["pixel_values"] for instance in with_images], dim=0)
            batch["image_grid_thw"] = torch.concat([instance["image_grid_thw"] for instance in with_images], dim=0)

        if "coordinates" in instances[0]:
            batch["coordinates"] = [instance["coordinates"] for instance in instances]
            batch["visual_token_indices_of_coordinates"] = [instance["visual_token_indices_of_coordinates"] for instance in instances]

        if "multi_patch_labels" in instances[0]:
            batch["multi_patch_labels"] = [instance["multi_patch_labels"] for instance in instances]

        return batch
202 | return batch
203 |
204 |
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
                                processor: transformers.ProcessorMixin,
                                data_args: DataArguments,
                                training_args: TrainingArguments) -> Dict:
    """Make dataset and collator for supervised fine-tuning.

    Returns the Trainer keyword arguments: a ``LazySupervisedDataset`` read from
    ``data_args.data_path``, ``eval_dataset=None`` and a ``DataCollatorForSupervisedDataset``.
    ``training_args`` is accepted for interface compatibility and is not used.
    """
    dataset = LazySupervisedDataset(
        tokenizer=tokenizer,
        processor=processor,
        data_path=data_args.data_path,
        data_args=data_args,
    )
    collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=dataset, eval_dataset=None, data_collator=collator)
215 |
216 |
def _load_pointer_model(model_args, training_args):
    """Load the pointer-augmented VL model selected by ``model_args.model_type``.

    ``"qwen2vl"`` maps to Qwen2VLForConditionalGenerationWithPointer and ``"qwen25vl"`` to
    Qwen2_5_VLForConditionalGenerationWithPointer. FlashAttention-2 is requested when
    ``model_args.flash_attn_2_enabled`` is set; weights are loaded in bfloat16 when
    ``training_args.bf16`` is set.

    Raises:
        ValueError: for any other ``model_type``.
    """
    model_classes = {
        "qwen2vl": Qwen2VLForConditionalGenerationWithPointer,
        "qwen25vl": Qwen2_5_VLForConditionalGenerationWithPointer,
    }
    if model_args.model_type not in model_classes:
        raise ValueError(f"Invalid model type: {model_args.model_type}")
    return model_classes[model_args.model_type].from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        attn_implementation="flash_attention_2" if model_args.flash_attn_2_enabled else None,
        torch_dtype=(torch.bfloat16 if training_args.bf16 else None),
        low_cpu_mem_usage=False,
    )


def _mask_embedding_grad_to_new_tokens(model, n_new_tokens):
    """Register a hook so only the last ``n_new_tokens`` embedding rows receive gradient.

    The hook is attached to the parameter of ``model`` whose name ends with
    ``model.embed_tokens.weight``. The new special tokens occupy the last rows after
    ``smart_tokenizer_and_embedding_resize``, so every earlier row gets a zero gradient.

    Raises:
        ValueError: if no such parameter exists.
    """
    emb_param = next(
        (p for n, p in model.named_parameters() if n.endswith("model.embed_tokens.weight")),
        None,
    )
    if emb_param is None:
        raise ValueError("embed_tokens.weight not found")

    def mask_grad(grad):
        # Tensor hooks must not modify their argument in place; return a masked copy instead.
        masked = grad.clone()
        masked[:-n_new_tokens] = 0.0
        return masked

    emb_param.register_hook(mask_grad)


def train():
    """Parse CLI arguments, build the model, tokenizer and data, then train and save.

    Arguments are parsed into ModelArguments, DataArguments and TrainingArguments. Training resumes
    from the latest checkpoint when ``output_dir`` already contains ``checkpoint-*`` entries.
    """
    global local_rank

    parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
    local_rank = training_args.local_rank

    if training_args.verbose_logging:
        rank0_print("Inspecting experiment hyperparameters:\n")
        rank0_print(f"model_args = {vars(model_args)}\n\n")
        rank0_print(f"data_args = {vars(data_args)}\n\n")
        rank0_print(f"training_args = {vars(training_args)}\n\n")

    # set up model
    model = _load_pointer_model(model_args, training_args)
    model.config.use_cache = False
    model.reset_loss_weights(pointer_loss_weight=training_args.pointer_loss_weight, lm_loss_weight=training_args.lm_loss_weight)

    if training_args.gradient_checkpointing:
        # Make the embedding outputs require grad so that gradients reach the checkpointed layers
        # even though the embedding weights may be frozen.
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        else:
            def make_inputs_require_grad(module, input, output):
                output.requires_grad_(True)
            model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

    setup_params_to_update(model, training_args)

    # set up tokenizer, add the pointer special tokens and record their ids on the config
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        model_args.model_name_or_path,
        cache_dir=training_args.cache_dir,
        model_max_length=training_args.model_max_length,
        padding_side="right",
    )
    smart_tokenizer_and_embedding_resize(
        special_tokens_dict={"additional_special_tokens": ADDITIONAL_SPECIAL_TOKENS},
        tokenizer=tokenizer,
        model=model,
    )
    update_pointer_token_ids(model.config, tokenizer)

    data_args.processor = AutoProcessor.from_pretrained(
        model_args.model_name_or_path, min_pixels=data_args.min_pixels, max_pixels=data_args.max_pixels
    )
    # Use the extended tokenizer so the processor knows the pointer tokens.
    data_args.processor.tokenizer = tokenizer

    os.makedirs(training_args.output_dir, exist_ok=True)

    if training_args.local_rank in (0, -1):
        dump_args_to_json(model.config, data_args.processor, model_args, data_args, training_args, training_args.output_dir)

    data_module = make_supervised_data_module(tokenizer=tokenizer, processor=data_args.processor, data_args=data_args, training_args=training_args)

    trainer = AGUVISTrainer(
        model=model,
        processing_class=data_args.processor,
        args=training_args,
        **data_module,
    )

    # When LiteTrain, only update the gradient of the new tokens. The hook is registered here,
    # after the Trainer exists, because registering it earlier is not effective.
    if training_args.unfreeze_new_tokens:
        _mask_embedding_grad_to_new_tokens(trainer.model, len(ADDITIONAL_SPECIAL_TOKENS))

    if any(pathlib.Path(training_args.output_dir).glob("checkpoint-*")):
        trainer.train(resume_from_checkpoint=True)
    else:
        trainer.train()
    trainer.save_state()

    model.config.use_cache = True

    safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir)

    rank0_print(f"Model saved to {training_args.output_dir}")


if __name__ == "__main__":
    train()
328 |
--------------------------------------------------------------------------------
/verifier/README.md:
--------------------------------------------------------------------------------
1 | # Grounding Verifier for GUI-Actor
2 |
3 | We developed a grounding verifier to assess whether a selected action position aligns with a given language instruction. This model is particularly effective for GUI-Actor, as GUI-Actor's attention map produces diverse candidate positions from a single inference. With the verifier, we can efficiently evaluate actions **in hindsight**—after identifying the chosen position on the image—and make more informed decisions.
4 |
5 |
6 |
7 | ## Training
8 |
9 | The verifier is trained to take a language instruction and an image (with a red circle marking the candidate position) as input, and predict whether the position is correct—outputting "True" or "False."
10 |
11 | We use the [OS-Atlas dataset](https://huggingface.co/datasets/OS-Copilot/OS-Atlas-data) and process it with `verifier_data_generation.py` to curate the training data. The model is trained with supervised fine-tuning (SFT) starting from the UI-TARS-2B-SFT checkpoint, which provides strong performance at a relatively small model size.
12 |
13 | ### Data Preparation
14 |
15 | To prepare the dataset:
16 |
17 | 1. Download and unzip the [OS-Atlas dataset](https://huggingface.co/datasets/OS-Copilot/OS-Atlas-data) following the instructions on the Hugging Face page.
18 | 2. Organize the images into the following directory structure:
19 |
20 | ```python
21 | image_folder_dict = {
22 | 'windows_splited': f'{root_path}/desktop_domain/windows_images',
23 | 'linux_splited': f'{root_path}/desktop_domain/linux_images',
24 | 'macos_splited': f'{root_path}/desktop_domain/macos_images',
25 | 'widget_captioning': f'{root_path}/mobile_domain/combined',
26 | 'uibert_raw': f'{root_path}/mobile_domain/UIBert',
27 | 'ricosca': f'{root_path}/mobile_domain/combined',
28 | 'amex_raw': f'{root_path}/mobile_domain/amex_images',
29 | 'seeclick_web': f'{root_path}/web_domain/seeclick_web_imgs',
30 | 'fineweb_3m': f'{root_path}/web_domain/fineweb'
31 | }
32 | ```
33 |
34 | Each training sample includes one positive example and one or more negative examples:
35 |
36 | * **Positive samples**: taken directly from the original dataset with a red circle marking the correct target.
37 | * **Negative samples**: created by either (a) selecting another meaningful UI element or (b) randomly sampling a point, which may not correspond to any actionable item.
38 |
39 | To generate the dataset, run the following commands (since the dataset is very large, you can use `--selected_size` to limit how many samples are selected from each domain):
40 |
41 | ```bash
42 | python verifier_data_generation.py --root_path ${path_to_OS-Atlas-data} --new_directory ${save_path} --file_dict_key desktop_domain --selected_size 30000
43 | python verifier_data_generation.py --root_path ${path_to_OS-Atlas-data} --new_directory ${save_path} --file_dict_key mobile_domain --selected_size 30000
44 | python verifier_data_generation.py --root_path ${path_to_OS-Atlas-data} --new_directory ${save_path} --file_dict_key web_domain --selected_size 30000
45 | ```
46 |
47 |
48 | ### SFT
49 |
50 | We use the official code from [Aguvis](https://github.com/xlang-ai/aguvis) to perform SFT. Make sure to set the data file paths correctly in the `stage1.yaml` configuration. For training, we use [**UI-TARS-2B-SFT**](https://huggingface.co/ByteDance-Seed/UI-TARS-2B-SFT) as the base model with a learning rate of $2 \times 10^{-5}$ and train for one epoch.
51 |
52 |
53 |
54 | ## Evaluation
55 |
56 | We evaluate our method by combining the grounding verifier with the attention weights generated by GUI-Actor, which are saved in a JSON file (e.g., `screenspot_all_preds_Original.json`). Before running the evaluation scripts, please update the file paths in `run_ss_v1.sh`, `run_ss_v2.sh`, and `run_ss_pro.sh` accordingly.
57 |
58 | Make sure to download the ScreenSpot datasets and ensure their paths exactly match those specified in the shell scripts. Specifically, download **ScreenSpot** and **ScreenSpot-Pro** from [ss-v1](https://huggingface.co/datasets/rootsautomation/ScreenSpot) and [ss-pro](https://huggingface.co/datasets/likaixin/ScreenSpot-Pro), respectively.
59 | For **ScreenSpot-v2**, we provide a converted version (`ScreenSpot-v2-new`) that aligns with the format used by the other datasets. However, you still need to download the original images from [ss-v2](https://huggingface.co/datasets/OS-Copilot/ScreenSpot-v2).
60 |
61 |
62 | Once everything is set up, run the following commands:
63 |
64 | ```bash
65 | bash run_ss_v1.sh
66 | bash run_ss_v2.sh
67 | bash run_ss_pro.sh
68 | ```
69 |
70 |
71 |
72 |
73 |
74 |
--------------------------------------------------------------------------------
/verifier/eval_ss_with_verifier.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import itertools
3 | import torch
4 | import json
5 | import re
6 | import argparse
7 | import os
8 | from PIL import Image, ImageDraw
9 | import logging
10 | from tqdm import tqdm
11 |
12 |
13 |
def draw_annotations(img, point_in_pixel, bbox, output_path='test.png'):
    """Annotate ``img`` in place with the ground-truth box and the predicted point, then save it.

    ``bbox`` (``[x1, y1, x2, y2]``) is outlined in yellow. A red circle with an 8-pixel radius is
    drawn around ``point_in_pixel`` (``(x, y)``). Either annotation is skipped when its argument is
    falsy. The annotated image is written to ``output_path``.

    Returns:
        The same ``img`` object that was passed in.
    """
    canvas = ImageDraw.Draw(img)

    # Ground-truth bounding box.
    if bbox:
        canvas.rectangle(bbox, outline="yellow", width=4)

    # Circle centred on the predicted point.
    if point_in_pixel:
        r = 8
        px, py = point_in_pixel[0], point_in_pixel[1]
        canvas.ellipse([px - r, py - r, px + r, py + r], outline="red", width=4)

    img.save(output_path)
    print(f"Annotated image saved to {output_path}")
    return img
37 |
38 |
39 |
40 |
logging.basicConfig(level=logging.INFO)
# Seed torch's RNG so repeated evaluation runs are reproducible.
torch.manual_seed(114514)


GT_TYPES = ['positive', 'negative']
INSTRUCTION_STYLES = ['instruction', 'action', 'description']
LANGUAGES = ['en', 'cn']


def parse_args():
    """Parse the command-line options for the ScreenSpot evaluation with a grounding verifier.

    ``--inst_style``, ``--language`` and ``--gt_type`` accept their listed values or ``'all'``.
    ``--language`` is optional and defaults to ``'en'``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, required=False)
    parser.add_argument('--screenspot_imgs', type=str, required=True)
    parser.add_argument('--screenspot_test', type=str, required=True)
    parser.add_argument('--task', type=str, required=True)
    parser.add_argument('--inst_style', type=str, required=True, choices=INSTRUCTION_STYLES + ['all'], help="Instruction style to use.")
    # Not required, so the declared default applies; argparse never uses the default of a required option.
    parser.add_argument('--language', type=str, required=False, choices=LANGUAGES + ['all'], default='en', help="Language to use.")
    parser.add_argument('--gt_type', type=str, required=True, choices=GT_TYPES + ['all'], help="Ground truth type: 'positive' or 'negative'.")
    parser.add_argument('--log_path', type=str, required=True)
    parser.add_argument('--json_prediction', type=str, required=False)
    parser.add_argument('--verifier_path', type=str, required=True)
    parser.add_argument('--verifier_method', type=str, required=True)

    args = parser.parse_args()
    return args
66 |
67 |
def build_model(args):
    """Construct the ``GroundingVerifier`` described by ``args`` and load its weights.

    ``args.model_path``, ``args.json_prediction`` and ``args.verifier_method`` configure the
    verifier; the weights are loaded from ``args.verifier_path``.
    """
    # Imported lazily from the sibling verifier_model module.
    from verifier_model import GroundingVerifier

    verifier = GroundingVerifier(
        model_name_or_path=args.model_path,
        json_prediction=args.json_prediction,
        method=args.verifier_method,
    )
    verifier.load_model(args.verifier_path)
    return verifier
73 |
74 |
def collect_results_to_eval(results, platform=None, group=None, application=None, language=None, gt_type=None, instruction_style=None, ui_type=None):
    """
    Return the samples of ``results`` that match every given filter.

    Each keyword is compared with ``sample.get(<keyword name>)``; passing ``None`` (the default)
    disables that filter. The input order is preserved.

    Parameters:
        results (list): A list of dictionaries containing sample results.

    Returns:
        list: A new list with the matching samples.
    """
    requested = {
        "platform": platform,
        "group": group,
        "application": application,
        "language": language,
        "gt_type": gt_type,
        "instruction_style": instruction_style,
        "ui_type": ui_type,
    }
    active_filters = [(key, value) for key, value in requested.items() if value is not None]
    return [
        sample for sample in results
        if all(sample.get(key) == value for key, value in active_filters)
    ]
102 |
103 |
104 |
105 |
def make_combinations(results, platform=False, group=None, application=False, language=False, gt_type=False, instruction_style=False, ui_type=False):
    """
    Enumerate every combination of observed values for the selected attributes.

    For each attribute whose flag is truthy, the distinct values of ``sample.get(attribute)`` over
    ``results`` are collected (a missing key contributes ``None``). The cartesian product of those
    value sets is returned as a list of dicts keyed by attribute name. The keys follow the order
    platform, group, application, language, gt_type, instruction_style, ui_type. Returns ``[]``
    when no attribute is selected or nothing was observed. Value order within an attribute follows
    set iteration order.
    """
    flags = {
        "platform": platform,
        "group": group,
        "application": application,
        "language": language,
        "gt_type": gt_type,
        "instruction_style": instruction_style,
        "ui_type": ui_type,
    }
    selected = [attr for attr, enabled in flags.items() if enabled]

    # Single pass over results, so iterators are consumed only once.
    observed = {attr: set() for attr in selected}
    for sample in results:
        for attr in selected:
            observed[attr].add(sample.get(attr))

    values_by_attr = {attr: list(values) for attr, values in observed.items() if values}
    if not values_by_attr:
        return []

    attr_names = list(values_by_attr)
    return [dict(zip(attr_names, combo)) for combo in itertools.product(*values_by_attr.values())]
157 |
158 |
159 |
160 |
def calc_metric_for_result_list(results):
    """
    Compute accuracy counters for a list of evaluated samples.

    Each sample must carry a ``"correctness"`` label ("correct", "wrong" or
    "wrong_format"). Text/icon accuracies are computed over the samples whose
    ``ui_type`` is "text" / "icon" respectively. Any ratio with an empty
    denominator is reported as 0.

    Returns:
        dict with num_correct_action, num_total, wrong_format_num, action_acc,
        text_acc and icon_acc.
    """
    def _count_correct(rows):
        return sum(row["correctness"] == "correct" for row in rows)

    def _ratio(numerator, denominator):
        return numerator / denominator if denominator > 0 else 0

    num_total = len(results)
    correct_num = _count_correct(results)
    wrong_format_num = sum(row["correctness"] == "wrong_format" for row in results)

    text_results = collect_results_to_eval(results, ui_type="text")
    icon_results = collect_results_to_eval(results, ui_type="icon")

    return {
        "num_correct_action": correct_num,
        "num_total": num_total,
        "wrong_format_num": wrong_format_num,
        "action_acc": _ratio(correct_num, num_total),
        "text_acc": _ratio(_count_correct(text_results), len(text_results)),
        "icon_acc": _ratio(_count_correct(icon_results), len(icon_results)),
    }
186 |
187 |
188 |
189 |
def eval_sample_positive_gt(sample, response):
    """
    Judge a prediction for a sample whose target is present on screen.

    ``sample["bbox"]`` is read as pixel [x1, y1, x2, y2] and normalized by
    ``sample["img_size"]`` (width, height). The predicted ``response["point"]`` is
    compared against that normalized box, i.e. it is expected to be a normalized
    (x, y) pair.

    Returns:
        "correct" if the point lies inside the box (edges inclusive), "wrong" if it
        lies outside, and "wrong_format" if the response carries no usable point.
    """
    x1, y1, x2, y2 = sample["bbox"][:4]
    width, height = sample["img_size"][:2]
    left, top, right, bottom = x1 / width, y1 / height, x2 / width, y2 / height

    # The point may be absent, None, or too short when the model output is malformed.
    click_point = response.get("point")
    if click_point is None or len(click_point) < 2:
        return "wrong_format"

    px, py = click_point[0], click_point[1]
    if left <= px <= right and top <= py <= bottom:
        return "correct"
    return "wrong"
207 |
def eval_sample_negative_gt(sample, response):
    """
    Judge a prediction for a sample whose target is absent from the screen.

    Returns "correct" when the model answered "negative", "wrong" when it answered
    "positive", and "wrong_format" for any other ``response["result"]``.
    ``sample`` is accepted for signature symmetry and is not read.
    """
    verdict = response["result"]
    if verdict == "negative":
        return "correct"
    if verdict == "positive":
        return "wrong"
    return "wrong_format"
215 |
216 |
def evaluate_fine_grained(results):
    """
    Metrics for every (platform, application, instruction_style, gt_type) combination.

    Keys look like "plat:<p> app:<a> inst_style:<s> gt_type:<g>"; combinations
    that match no samples are omitted.
    """
    report = {}
    combos = make_combinations(
        results,
        platform=True,
        application=True,
        instruction_style=True,
        gt_type=True,
    )
    for combo in combos:
        plat = combo.get("platform")
        app = combo.get("application")
        style = combo.get("instruction_style")
        gt = combo.get("gt_type")

        subset = collect_results_to_eval(
            results=results,
            platform=plat,
            application=app,
            instruction_style=style,
            gt_type=gt,
        )
        metrics = calc_metric_for_result_list(subset)
        if metrics['num_total'] == 0:
            continue
        report[f"plat:{plat} app:{app} inst_style:{style} gt_type:{gt}"] = metrics
    return report
258 |
259 |
def evaluate_fine_grained_v2(results):
    """
    Metrics for every distinct ``group`` value, keyed "group:<group>".

    Groups that match no samples are omitted.
    """
    report = {}
    for combo in make_combinations(results, group=True):
        group_name = combo.get("group")
        subset = collect_results_to_eval(results=results, group=group_name)
        metrics = calc_metric_for_result_list(subset)
        if metrics['num_total'] > 0:
            report[f"group:{group_name}"] = metrics
    return report
293 |
294 |
def evaluate_seeclick_paper_style(results):
    """
    Metrics for every (platform, instruction_style, gt_type) combination.

    Keys look like "plat:<p> inst_style:<s> gt_type:<g>"; combinations that
    match no samples are omitted.
    """
    report = {}
    combos = make_combinations(results, platform=True, instruction_style=True, gt_type=True)
    for combo in combos:
        plat = combo.get("platform")
        style = combo.get("instruction_style")
        gt = combo.get("gt_type")

        subset = collect_results_to_eval(
            results=results,
            platform=plat,
            instruction_style=style,
            gt_type=gt,
        )
        metrics = calc_metric_for_result_list(subset)
        if metrics['num_total'] > 0:
            report[f"plat:{plat} inst_style:{style} gt_type:{gt}"] = metrics
    return report
333 |
334 |
def evaluate_leaderboard_detailed_style(results):
    """
    Metrics for every distinct ``application`` value, keyed "app:<application>".

    Applications that match no samples are omitted.
    """
    by_app = {}
    for combo in make_combinations(results, application=True):
        app = combo.get("application")
        metrics = calc_metric_for_result_list(
            collect_results_to_eval(results=results, application=app)
        )
        if metrics['num_total'] > 0:
            by_app[f"app:{app}"] = metrics
    return by_app
367 |
368 |
def evaluate_leaderboard_simple_style(results):
    """
    Metrics for every distinct ``group`` value, keyed "group:<group>".

    Groups that match no samples are omitted.
    """
    by_group = {}
    for combo in make_combinations(results, group=True):
        group_name = combo.get("group")
        metrics = calc_metric_for_result_list(
            collect_results_to_eval(results=results, group=group_name)
        )
        if metrics['num_total'] == 0:
            continue
        by_group[f"group:{group_name}"] = metrics
    return by_group
401 |
402 |
def evaluate_overall(results):
    """
    Compute metrics over the full, unfiltered result list.

    Parameters:
        results (list): sample-result dicts, each with a "correctness" label.

    Returns:
        dict: the metrics produced by ``calc_metric_for_result_list``.
    """
    return calc_metric_for_result_list(results)
417 |
418 |
419 |
420 |
def evaluate(results):
    """
    Assemble the evaluation report for a list of per-sample results.

    Returns:
        dict with "details" (the input results, unchanged) and "metrics" holding
        the per-group breakdown ("fine_grained") and the aggregate ("overall").
        Further breakdowns can be enabled by uncommenting entries below.
    """
    metrics = {
        "fine_grained": evaluate_fine_grained_v2(results),
        # "seeclick_style": evaluate_seeclick_paper_style(results),
        # "leaderboard_simple_style": evaluate_leaderboard_simple_style(results),
        # "leaderboard_detailed_style": evaluate_leaderboard_detailed_style(results),
        "overall": evaluate_overall(results),
    }
    return {"details": results, "metrics": metrics}
443 |
444 |
def _expand_option(value, all_values):
    """Return ``all_values`` when ``value`` is "all", otherwise the comma-separated items of ``value``."""
    return all_values if value == "all" else value.split(",")


def main(args):
    """
    Evaluate the verifier-backed grounding model on ScreenSpot-style annotation files.

    Every selected task file is expanded across the requested instruction styles,
    ground-truth types and languages. The model is queried once per expanded
    sample, each answer is scored, and the report from ``evaluate`` is written to
    ``args.log_path`` as JSON.

    Raises:
        AttributeError: Chinese ("cn") prompts requested for anything other than
            positive samples with the 'instruction' style.
        ValueError: an unsupported language or ground-truth type.
    """
    model = build_model(args)
    print("Load model success")

    if args.task == "all":
        task_filenames = [
            os.path.splitext(f)[0]
            for f in os.listdir(args.screenspot_test)
            if f.endswith(".json")
        ]
    else:
        task_filenames = args.task.split(",")
    inst_styles = _expand_option(args.inst_style, INSTRUCTION_STYLES)
    languages = _expand_option(args.language, LANGUAGES)
    gt_types = _expand_option(args.gt_type, GT_TYPES)

    tasks_to_run = []
    for task_filename in task_filenames:
        with open(os.path.join(args.screenspot_test, task_filename + ".json"), 'r') as f:
            task_data = json.load(f)

        # One task instance per (sample, instruction style, gt type, language).
        for inst_style in inst_styles:
            for gt_type in gt_types:
                for lang in languages:
                    for task_instance in task_data:
                        task_instance = copy.deepcopy(task_instance)
                        task_instance["task_filename"] = task_filename
                        task_instance["gt_type"] = gt_type
                        task_instance["instruction_style"] = inst_style
                        task_instance["language"] = lang
                        if lang == "cn":
                            if inst_style != 'instruction' or gt_type != 'positive':
                                # TODO: Translate the data
                                raise AttributeError("Only positive samples and 'instruction' style are supported for Chinese instructions.")
                            task_instance["prompt_to_evaluate"] = task_instance["instruction_cn"]
                        elif lang == "en":
                            task_instance["prompt_to_evaluate"] = task_instance["instruction"]
                        else:
                            raise ValueError(f"Unsupported language: {lang}")
                        tasks_to_run.append(task_instance)
        n_expanded = len(task_data) * len(inst_styles) * len(gt_types) * len(languages)
        print(f"Num of sample in {task_filename}: {len(task_data)} * {len(inst_styles)} * {len(gt_types)} * {len(languages)} = {n_expanded}")
    print(f"Total tasks: {len(tasks_to_run)}")

    results = []
    for sample in tqdm(tasks_to_run):
        img_path = os.path.join(args.screenspot_imgs, sample["img_filename"])
        gt_type = sample["gt_type"]

        # Dispatch on this sample's own gt_type.
        if gt_type == "positive":
            response = model.ground_only_positive(instruction=sample["prompt_to_evaluate"], image=img_path, target_point=sample['bbox'])
        elif gt_type == "negative":
            response = model.ground_allow_negative(instruction=sample["prompt_to_evaluate"], image=img_path)
        else:
            raise ValueError("Wrong instruction type")

        point = response["point"]
        img_size = sample["img_size"]
        # Scale the predicted point by the image size to get pixel coordinates.
        point_in_pixel = [point[0] * img_size[0], point[1] * img_size[1]] if point else None

        sample_result = {
            "img_path": img_path,
            "group": sample.get("group"),
            "platform": sample["platform"],
            "application": sample.get("application"),
            "lang": sample["language"],
            "instruction_style": sample.get("instruction_style"),
            "prompt_to_evaluate": sample["prompt_to_evaluate"],
            "gt_type": gt_type,
            "ui_type": sample["ui_type"],
            "task_filename": sample["task_filename"],
            "pred": point_in_pixel,
            "raw_response": response["raw_response"],
        }

        if gt_type == "positive":
            correctness = eval_sample_positive_gt(sample, response)
            sample_result["bbox"] = sample["bbox"]
        else:
            correctness = eval_sample_negative_gt(sample, response)
        sample_result["correctness"] = correctness
        results.append(sample_result)

    result_report = evaluate(results)

    # Save to file; a bare file name has no directory component to create.
    log_dir = os.path.dirname(args.log_path)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    with open(args.log_path, 'w') as f:
        json.dump(result_report, f, indent=4)
    logging.info("Evaluation of ScreenSpot finished.")
564 |
565 |
566 |
567 |
if __name__ == "__main__":
    # Script entry point: parse CLI flags, then run the evaluation.
    cli_args = parse_args()
    main(cli_args)
570 |
571 |
572 |
573 |
--------------------------------------------------------------------------------
/verifier/run_ss_pro.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Evaluate GUI-Actor predictions on ScreenSpot-Pro, re-ranked by the grounding verifier.
set -e

# Grounding-model predictions fed to the verifier.
json_path='Screenspot_eval/json_data/7B_full_qwen2vl/final_eval/screenspot-Pro_all_preds_StandardResize.json'
exp_name='Actor-7b-fixprompt-bon_score_verifier'
verifier_path='microsoft/GUI-Actor-Verifier-2B'
screenspot_dataset_path='/datadisk/data/ss-eval/ScreenSpot-Pro'
logdir='results_pro'
# Optional tag inserted into the log file name; empty unless set in the environment.
checkpoint="${checkpoint:-}"

verifier_method='score'
# verifier_method='best_one'
export CUDA_VISIBLE_DEVICES=0


python eval_ss_with_verifier.py \
    --screenspot_imgs "${screenspot_dataset_path}/images" \
    --screenspot_test "${screenspot_dataset_path}/annotations" \
    --task "all" \
    --language "en" \
    --gt_type "positive" \
    --log_path "${logdir}/${exp_name}_${checkpoint}_sspro.json" \
    --inst_style "instruction" \
    --verifier_method "${verifier_method}" \
    --verifier_path "${verifier_path}" \
    --json_prediction "${json_path}"
26 |
27 |
--------------------------------------------------------------------------------
/verifier/run_ss_v1.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Evaluate GUI-Actor predictions on ScreenSpot (v1), re-ranked by the grounding verifier.
set -e

# Grounding-model predictions fed to the verifier.
json_path='/home/t-yangrui/code/Screenspot_eval/json_data/new_prompt_7B/screenspot_all_preds_Original.json'
exp_name='Actor-7b-warmup-fixprompt-bon_score_verifierultracpt6000_crop500'
verifier_path='microsoft/GUI-Actor-Verifier-2B'
screenspot_dataset_path="data/ss-eval/ScreenSpot"
logdir='results_v1'
# Optional tag inserted into the log file name; empty unless set in the environment.
checkpoint="${checkpoint:-}"

verifier_method='score'
# verifier_method='best_one'
export CUDA_VISIBLE_DEVICES=0


python eval_ss_with_verifier.py \
    --screenspot_imgs "${screenspot_dataset_path}/images" \
    --screenspot_test "${screenspot_dataset_path}" \
    --task "all" \
    --language "en" \
    --gt_type "positive" \
    --log_path "${logdir}/${exp_name}_${checkpoint}_ssv1.json" \
    --inst_style "instruction" \
    --verifier_method "${verifier_method}" \
    --verifier_path "${verifier_path}" \
    --json_prediction "${json_path}"
26 |
27 |
--------------------------------------------------------------------------------
/verifier/run_ss_v2.sh:
--------------------------------------------------------------------------------
#!/bin/bash
# Evaluate GUI-Actor predictions on ScreenSpot-v2, re-ranked by the grounding verifier.
set -e

# Grounding-model predictions fed to the verifier.
# NOTE(review): this points at 3B predictions while exp_name says 7b — confirm.
json_path='/home/t-yangrui/code/Screenspot_eval/json_data/new_prompt_3B/screenspot_v2_all_preds_Original.json'
exp_name='Actor-7b-warmup-fixprompt-bon_score_verifierultracpt6000_crop500'
verifier_path='microsoft/GUI-Actor-Verifier-2B'
screenspot_dataset_path='ss-eval/ScreenSpot-v2'
logdir='results_v2'
# Optional tag inserted into the log file name; empty unless set in the environment.
checkpoint="${checkpoint:-}"


verifier_method='score'
# verifier_method='best_one'
export CUDA_VISIBLE_DEVICES=0


python eval_ss_with_verifier.py \
    --screenspot_imgs "${screenspot_dataset_path}/screenspotv2_image" \
    --screenspot_test "ScreenSpot-v2-new" \
    --task "all" \
    --language "en" \
    --gt_type "positive" \
    --log_path "${logdir}/${exp_name}_${checkpoint}_ssv2.json" \
    --inst_style "instruction" \
    --verifier_method "${verifier_method}" \
    --verifier_path "${verifier_path}" \
    --json_prediction "${json_path}"
30 |
31 |
32 |
33 |
--------------------------------------------------------------------------------
/verifier/verifier_data_generation.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import re
4 | import numpy as np
5 | import random
6 | from PIL import Image, ImageDraw
7 | import argparse
8 |
9 |
def _gpt_turn(answer):
    """Build the assistant turn that labels a training image with `answer`."""
    return {"from": "gpt", "value": answer, "recipient": "os", "end_turn": True}


# Target turns: "True" for a correctly placed circle, "False" for a misplaced one.
dic = _gpt_turn("True")
neg_dic = _gpt_turn("False")
22 |
23 |
24 |
def sample_point(bbox, max_tries=50):
    """
    Draw a uniformly random point in the unit square that lies outside `bbox`.

    Args:
        bbox: [x0, y0, x1, y1] in normalized coordinates. A point on the box edge
            counts as outside (the inside test is strict).
        max_tries: number of random draws to attempt before giving up.

    Returns:
        (x, y) tuple of the first draw outside the box, or None if every draw
        landed inside (e.g. the box covers the whole unit square).
    """
    x0, y0, x1, y1 = bbox
    for _ in range(max_tries):
        xx, yy = np.random.random(2)
        if not ((x0 < xx < x1) and (y0 < yy < y1)):
            return xx, yy
    return None
36 |
37 |
def load_json_file(file_path):
    """
    Read and parse the JSON file at `file_path`.

    Returns the parsed data, or None after printing an error message when the
    file does not exist or is not valid JSON.
    """
    try:
        with open(file_path, 'r') as handle:
            return json.load(handle)
    except FileNotFoundError:
        print(f"Error: File '{file_path}' not found.")
    except json.JSONDecodeError:
        print(f"Error: '{file_path}' contains invalid JSON.")
    return None
50 |
51 |
def draw_annotations(img, point_in_pixel, bbox, output_path='test.png', color='red', size=1):
    """
    Draw an optional box and an optional hollow circle on `img`, save it, and return it.

    Args:
        img: PIL image; it is drawn on in place.
        point_in_pixel: (x, y) pixel centre of the circle, or a falsy value to skip it.
        bbox: [x1, y1, x2, y2] pixel box outlined in yellow, or a falsy value to skip it.
        output_path: file the annotated image is saved to.
        color: outline colour of the circle.
        size: scale factor; the circle radius is ceil(8 * size) and its outline
            width is ceil(4 * size).

    Returns:
        The same (now annotated) image object.
    """
    canvas = ImageDraw.Draw(img)

    if bbox:
        canvas.rectangle(bbox, outline="yellow", width=4)

    if point_in_pixel:
        r = np.ceil(8 * size).astype(int)
        cx, cy = point_in_pixel[0], point_in_pixel[1]
        canvas.ellipse(
            [cx - r, cy - r, cx + r, cy + r],
            outline=color,
            width=np.ceil(4 * size).astype(int),
        )

    img.save(output_path)
    print(f"Annotated image saved to {output_path}")
    return img
74 |
75 |
def _bbox_center(bbox):
    """Return the (x, y) centre of an [x1, y1, x2, y2] box."""
    return (bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2


def _unique_save_path(new_directory, file, stem, tag, ext):
    """Return ``<new_directory>/<file>_<stem without '/'>_<tag><ext>``, adding a random suffix until the path is unused."""
    base = file + '_' + stem.replace('/', '')
    save_path = os.path.join(new_directory, f'{base}_{tag}{ext}')
    while os.path.exists(save_path):
        save_path = os.path.join(new_directory, f'{base}_{tag}_{random.randint(0, 1000)}{ext}')
    return save_path


def _write_annotated_pair(image_path, img_filename, file, new_directory, tag, pos_point, neg_point):
    """
    Save two annotated copies of `image_path`: a circle at `pos_point`, then one at `neg_point`.

    Points are normalized (x, y) and scaled by the image size. Returns the
    (positive_path, negative_path) pair. Exceptions from opening, drawing or
    saving propagate to the caller.
    """
    # splitext copes with file names that contain several dots or none at all.
    stem, ext = os.path.splitext(img_filename)
    saved = []
    for label, (px, py) in ((f'pos{tag}', pos_point), (f'neg{tag}', neg_point)):
        # Reopen the source for each copy so the two circles never end up on one image;
        # the context manager closes the underlying file handle.
        with Image.open(image_path) as img:
            width, height = img.size
            save_path = _unique_save_path(new_directory, file, stem, label, ext)
            draw_annotations(img, [px * width, py * height], None, output_path=save_path, size=height / 1000 * 1.2)
        saved.append(save_path)
    return saved[0], saved[1]


def transform_to_conversation_format(data, file, image_folder_dict, new_directory):
    """
    Build positive/negative verifier training pairs from grounding annotations.

    For every selected (instruction, bbox) target, two annotated screenshots are
    written into `new_directory`: one with a hollow circle at the bbox centre
    (answer "True") and one with the circle at a negative location (answer "False").
    Targets whose box area is >= 0.8 are skipped. Targets whose image cannot be
    read or written, or for which no negative location can be sampled, are skipped
    with a message.

    Args:
        data: list of annotation dicts, each with ``img_filename`` and either an
            ``elements`` list (items with ``instruction`` and ``bbox``) or top-level
            ``instruction`` / ``bbox``. Boxes are multiplied by the image size
            below, so they are expected in normalized [x1, y1, x2, y2] form.
        file: key into `image_folder_dict`; also prefixes generated file names.
        image_folder_dict: maps `file` to the folder containing its screenshots.
        new_directory: output folder for the annotated images.

    Returns:
        List of ``{"image": <path relative to new_directory>, "conversations": [...]}`` dicts.
    """
    image_folder = image_folder_dict[file]
    prompt = 'Please observe the screenshot and exame whether the hollow red circle accurately placed on the intended position in the image:'
    result = []

    def add_pair(image_path, img_filename, tag, instruction, pos_point, neg_point):
        """Write both images for one target and append its True/False conversation records."""
        if neg_point is None:
            # sample_point found no location outside the box.
            print(f"Skipping {image_path}: no negative point could be sampled")
            return
        try:
            pos_path, neg_path = _write_annotated_pair(
                image_path, img_filename, file, new_directory, tag, pos_point, neg_point
            )
        except Exception as exc:  # best-effort: unreadable/missing image or failed save
            print(f"Skipping {image_path}: {exc}")
            return
        conversations = [{
            "from": "human",
            "value": f"\n{prompt} " + f"'{instruction}'. Answer True or False."
        }]
        result.append({
            "image": pos_path.replace(new_directory, ''),
            "conversations": conversations + [dic]
        })
        result.append({
            "image": neg_path.replace(new_directory, ''),
            "conversations": conversations + [neg_dic]
        })

    for i, item in enumerate(data):
        print(i / len(data))
        img_filename = item['img_filename']
        image_path = os.path.join(image_folder, img_filename)

        if 'elements' in item:
            elements = item['elements']
            n = len(elements)
            if n == 0:
                continue  # nothing to ground in this screenshot
            # A single element is always used; otherwise up to min(n // 2, 3) distinct ones.
            ind_list = [0] if n <= 1 else random.sample(range(n), min(n // 2, 3))

            for ind in ind_list:
                element = elements[ind]
                bbox = element['bbox']
                if (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) >= 0.8:
                    continue  # box covers most of the screen; no meaningful negative

                x_center, y_center = _bbox_center(bbox)
                if n >= 2:
                    # Prefer another element's centre as a hard negative, unless it is too close.
                    neg_ind = random.choice([k for k in range(n) if k != ind])
                    neg_point = _bbox_center(elements[neg_ind]['bbox'])
                    if (x_center - neg_point[0]) ** 2 + (y_center - neg_point[1]) ** 2 < 0.05:
                        neg_point = sample_point(bbox)
                else:
                    neg_point = sample_point(bbox)

                add_pair(image_path, img_filename, str(ind), element['instruction'], (x_center, y_center), neg_point)
        else:
            bbox = item['bbox']
            if (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) >= 0.8:
                continue
            add_pair(image_path, img_filename, '', item['instruction'], _bbox_center(bbox), sample_point(bbox))

    return result
209 |
210 |
211 |
212 |
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Generate verifier data")
    parser.add_argument('--root_path', type=str, required=True, help='Root path to OS-Atlas-data')
    parser.add_argument('--new_directory', type=str, default='./verifier_data', help='Directory to save the new verifier data')
    parser.add_argument('--file_dict_key', type=str, default='', help='Key for the file dictionary to process')
    parser.add_argument('--save_suffix', type=str, default='verifier', help='Suffix for the saved files')
    parser.add_argument('--selected_size', type=int, default=10000, help='Number of samples to select from each file')
    args = parser.parse_args()

    root_path = args.root_path
    new_directory = args.new_directory
    save_suffix = args.save_suffix
    selected_size = args.selected_size

    os.makedirs(new_directory, exist_ok=True)

    # Screenshot folder for each sub-dataset under root_path.
    image_folder_dict = {
        'windows_splited': f'{root_path}/desktop_domain/windows_images',
        'linux_splited': f'{root_path}/desktop_domain/linux_images',
        'macos_splited': f'{root_path}/desktop_domain/macos_images',
        'widget_captioning': f'{root_path}/mobile_domain/combined',
        'uibert_raw': f'{root_path}/mobile_domain/UIBert',
        'ricosca': f'{root_path}/mobile_domain/combined',
        'amex_raw': f'{root_path}/mobile_domain/amex_images',
        'seeclick_web': f'{root_path}/web_domain/seeclick_web_imgs',
        'fineweb_3m': f'{root_path}/web_domain/fineweb'
    }

    # Sub-dataset annotation files (without ".json"), grouped by domain directory.
    file_dict = {
        'desktop_domain': ['linux_splited', 'windows_splited', 'macos_splited'],
        'mobile_domain': ['uibert_raw', 'ricosca', 'amex_raw', 'widget_captioning'],
        'web_domain': ['fineweb_3m', 'seeclick_web'],
    }

    def process_files(directory):
        """Generate verifier data for every sub-dataset in `directory` and save it beside the source JSON."""
        for file in file_dict[directory]:
            file_path = os.path.join(root_path, directory, file + '.json')
            data = load_json_file(file_path)
            if data is None:
                # load_json_file has already printed why; continue with the next file.
                continue
            if len(data) >= selected_size:
                data = random.sample(data, selected_size)
            print(directory, file, len(data))

            new_data = transform_to_conversation_format(data, file, image_folder_dict, new_directory)

            print(directory, file, len(new_data))
            # Only the trailing ".json" is replaced, so directory names are left intact.
            out_path = os.path.splitext(file_path)[0] + f'_{save_suffix}.json'
            with open(out_path, "w", encoding="utf-8") as f:
                json.dump(new_data, f)

    if not args.file_dict_key:
        for directory in file_dict:
            process_files(directory)
    else:
        key = args.file_dict_key
        if key not in file_dict:
            parser.error(f"Key {key} not found in file_dict")
        process_files(key)
277 |
278 |
279 |
280 |
281 |
282 |
283 |
284 |
285 |
--------------------------------------------------------------------------------
/verifier/verifier_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
3 | from transformers.generation import GenerationConfig
4 | import json
5 | import re
6 | import os
7 | import tempfile
8 | from PIL import Image, ImageDraw
9 | from qwen_vl_utils import process_vision_info
10 | from typing import List, Literal, Optional
11 | import numpy as np
12 | import random
13 |
# System prompt describing the GUI-agent role.
grounding_system_message = (
    "You are a GUI agent. You are given a task and a screenshot of the screen. "
    "You need to perform a series of pyautogui actions to complete the task."
)
15 |
16 |
def image_to_temp_filename(image):
    """
    Save `image` to a new temporary ``.png`` file and return the file's path.

    The file is created with ``delete=False``, so it outlives this call; the
    caller is responsible for removing it. The temporary-file handle is closed
    before ``image.save`` reopens the path. Otherwise the handle would leak, and
    on Windows the named file cannot be opened a second time while it is still
    open.
    """
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
        temp_path = temp_file.name
    image.save(temp_path)
    print(f"Image saved to temporary file: {temp_path}")
    return temp_path
22 |
23 |
def draw_point_list(img, points, color='red', size=1, crop=True, sample_crop=False, crop_size=500):
    """
    Draw a hollow circle at each pixel point on `img`, optionally cropping around the first.

    Args:
        img: PIL image; circles are drawn on it in place.
        points: sequence of (x, y) pixel coordinates.
        color: circle outline colour.
        size: scale factor; radius is ceil(7 * size), outline width ceil(3 * size).
        crop: when True and `points` is non-empty, return the region extending up
            to `crop_size` pixels on each side of ``points[0]``, clamped to the
            image bounds.
        sample_crop: not used by this function; accepted for call compatibility.
        crop_size: half-width of the crop window in pixels.

    Returns:
        The annotated image, or the annotated crop when cropping applies. If the
        crop raises, the error is printed and the full annotated image is returned.
    """
    draw = ImageDraw.Draw(img)
    radius = np.ceil(7 * size).astype(int)
    line_width = np.ceil(3 * size).astype(int)
    for point in points:
        circle_bbox = [
            point[0] - radius,  # x1
            point[1] - radius,  # y1
            point[0] + radius,  # x2
            point[1] + radius   # y2
        ]
        draw.ellipse(circle_bbox, outline=color, width=line_width)

    if crop and len(points) > 0:
        x, y = points[0]
        width, height = img.size
        # PIL crop boxes have exclusive right/lower edges, so clamp to width/height
        # (not width-1/height-1) to keep the last column/row of the image.
        left = max(0, x - crop_size)
        right = min(width, x + crop_size)
        top = max(0, y - crop_size)
        bottom = min(height, y + crop_size)
        try:
            img = img.crop((left, top, right, bottom))
        except Exception as e:
            print(f"Error cropping image: {e}")
            # If cropping fails, return the original (annotated) image.
            return img
    return img
52 |
53 |
54 |
class GroundingVerifier():
    """Re-rank GUI-Actor grounding candidates with a Qwen2-VL verifier model.

    Candidate points come from a JSON prediction file produced by the grounding
    model: either precomputed ``topk_points`` (method 'best_one') or raw
    ``attn_scores`` that are grouped into connected regions here (other methods).
    For method 'score', each candidate is drawn onto the screenshot and the
    verifier is asked whether the marker sits on the instructed target.
    """

    def __init__(self,
                 model_name_or_path="microsoft/GUI-Actor-Verifier-2B",
                 json_prediction=None,
                 method='score'  # 'best_one', 'comparison', 'score'
                 ):
        """
        Args:
            model_name_or_path: stored on the instance; ``load_model`` receives
                its own path argument.
            json_prediction: path to the JSON prediction file (required).
            method: 'score' verifies candidates with the verifier; 'best_one'
                keeps the first precomputed ``topk_points`` entry. Any other
                value (e.g. 'comparison') falls back to the top-ranked candidate
                in ``ground_only_positive``.

        Raises:
            AssertionError: if ``json_prediction`` is missing or not a file.
        """
        self.method = method
        self.model_name_or_path = model_name_or_path
        self.system_message = {
            "role": "system",
            "content": grounding_system_message,
        }
        self.json_prediction_path = json_prediction
        # load json prediction (check for None first: os.path.exists(None) raises TypeError)
        assert json_prediction is not None and os.path.isfile(json_prediction), "Invalid json prediction path."
        with open(json_prediction, 'r') as f:
            self.json_prediction = json.load(f)

        self.verifier_crop_size = 500  # half of the true crop size
        # Early-stop threshold on verifier scores:
        # 0.95 for ScreenSpot-Pro, 0.8 for ScreenSpot and ScreenSpot-v2.
        if '-pro' in self.json_prediction_path.lower():
            self.threshold = 0.95
        else:
            self.threshold = 0.8

        # Map "<image file name><instruction>" -> position in self.json_prediction.
        self.json_index_dict = {}
        for i, item in enumerate(self.json_prediction):
            self.json_index_dict[self._prediction_key(item)] = i

    @staticmethod
    def _prediction_key(item):
        """Lookup key for one prediction record: image file name + instruction.

        A record without an 'instruction' contributes an empty instruction,
        rather than turning the whole key into '' (which made every such
        record collide on the same key).
        """
        key = 'img_filename' if 'img_filename' in item else 'file_name'
        return item[key] + item.get('instruction', '')

    def load_model(self, verifier_path):
        """Load the Qwen2-VL verifier, its tokenizer and processor onto cuda:0.

        Does nothing for method 'best_one', which never queries the verifier.
        """
        if self.method == 'best_one':
            return
        verifier_model_name_or_path = verifier_path

        self.verifier = Qwen2VLForConditionalGeneration.from_pretrained(
            verifier_model_name_or_path,
            device_map="cuda:0",
            trust_remote_code=True,
            torch_dtype=torch.bfloat16,
            attn_implementation="flash_attention_2"
        ).eval()
        self.verifier_tokenizer = AutoTokenizer.from_pretrained(verifier_model_name_or_path, trust_remote_code=True)
        self.verifier_processor = AutoProcessor.from_pretrained(verifier_model_name_or_path)
        # Batch padding needs a pad token; reuse EOS.
        self.verifier_processor.tokenizer.pad_token = self.verifier_processor.tokenizer.eos_token

    def set_generation_config(self, **kwargs):
        """Accepted for interface compatibility; all arguments are ignored."""
        pass

    def verify(self, instruction, image):
        """Score how well the marker drawn on ``image`` matches ``instruction``.

        Runs one forward pass of the verifier and reads the next-token
        distribution after the prompt.

        Returns:
            float: P("True") / (P("True") + P("False")) for that next token.
        """
        # NOTE: the prompt text (including its wording) is model input; keep it
        # identical to what the verifier was trained with.
        verifier_prompt = "Please observe the screenshot and exame whether the hollow red circle accurately placed on the intended position in the image: '{}'. Answer True or False."
        full_prompt = verifier_prompt.format(instruction)
        messages = [
            {
                "role": "user",
                "content": [
                    {
                        "type": "image",
                        "image": image,
                    },
                    {"type": "text", "text": full_prompt},
                ],
            }
        ]
        text_input = self.verifier_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        image_inputs, video_inputs = process_vision_info(messages)
        inputs = self.verifier_processor(
            text=[text_input],
            images=image_inputs,
            videos=video_inputs,
            padding=True,
            return_tensors="pt",
        )
        inputs = inputs.to("cuda:0")

        # Forward pass only; no generation is needed to read the next-token logits.
        with torch.no_grad():
            outputs = self.verifier(**inputs)
            logits = outputs.logits  # (batch_size, seq_len, vocab_size)

        # Logits for the token that would follow the prompt.
        last_token_logits = logits[:, -1, :]  # (batch_size, vocab_size)

        # First sub-token id of "True" / "False".
        true_id = self.verifier_processor.tokenizer.encode("True", add_special_tokens=False)[0]
        false_id = self.verifier_processor.tokenizer.encode("False", add_special_tokens=False)[0]

        probs = torch.softmax(last_token_logits, dim=-1)
        true_prob = probs[0, true_id].item()
        false_prob = probs[0, false_id].item()
        # Renormalise over the two answers only.
        score = true_prob / (true_prob + false_prob)
        return score

    def verifier_score(self, instruction, image, box):
        """Draw the candidate ``box`` (a pixel point) on a copy of ``image`` and verify it.

        The original ``image`` is not modified. Returns the score from ``verify``.
        """
        box = [box]
        img_copy = image.copy()
        img_new = draw_point_list(img_copy, box, crop_size=self.verifier_crop_size)
        score = self.verify(instruction, img_new)
        return score

    @staticmethod
    def _group_connected_patches(topk_coords, topk_values):
        """Group kept patches into 4-connected regions with a breadth-first search.

        Args:
            topk_coords: list of (row, col, patch_index), one per kept patch.
            topk_values: score of each kept patch, aligned with ``topk_coords``.

        Returns:
            list of regions; each region is a list of
            (row, col, patch_index, score) in BFS discovery order. Regions are
            ordered by the position of their first patch in ``topk_coords``.
        """
        from collections import deque

        # (row, col) -> position in topk_coords, for O(1) neighbour lookup.
        position = {(y, x): j for j, (y, x, _) in enumerate(topk_coords)}
        regions = []
        visited = set()

        for i, (y, x, idx) in enumerate(topk_coords):
            if idx in visited:
                continue
            visited.add(idx)
            region = [(y, x, idx, topk_values[i])]
            queue = deque(region)

            while queue:
                cy, cx, _, _ = queue.popleft()
                for dy, dx in ((-1, 0), (1, 0), (0, -1), (0, 1)):
                    ny, nx = cy + dy, cx + dx
                    j = position.get((ny, nx))
                    # Only neighbours that are themselves kept patches join the region.
                    if j is None:
                        continue
                    t_idx = topk_coords[j][2]
                    if t_idx in visited:
                        continue
                    visited.add(t_idx)
                    entry = (ny, nx, t_idx, topk_values[j])
                    region.append(entry)
                    queue.append(entry)

            regions.append(region)
        return regions

    def get_prediction_region_point(self, attn_scores, n_width, n_height, top_n=20, return_all_regions=True, rect_center=False, no_groups=False):
        """Turn patch attention scores into candidate click points.

        Patches scoring above 20% of the maximum are kept (at most ``top_n``,
        highest first) and grouped into 4-connected regions on the
        ``n_width`` x ``n_height`` patch grid. Regions are ranked by their mean
        score. All returned points are normalised (x, y) in [0, 1].

        Args:
            attn_scores: scores indexed as ``attn_scores[0, patch_index]``;
                assumed shape (1, num_patches) -- TODO confirm against the JSON producer.
            n_width, n_height: patch-grid size. When ``n_width * n_height`` does
                not equal the number of scores, both are halved (presumably for
                2x2 patch merging -- verify).
            top_n: maximum number of patches kept.
            return_all_regions: see Returns.
            rect_center: if True, a region's center is the mean of its distinct
                patch column/row centers; otherwise the score-weighted mean of
                its patch centers (falling back to the former if the weights sum to 0).
            no_groups: if True, return a flat list of points (see Returns).

        Returns:
            no_groups and return_all_regions: region centers (best first)
                followed by the center of every kept patch.
            no_groups only: region centers followed by the center of the first
                kept patch, as one (x, y) pair.
            return_all_regions only: (best_point, sorted_centers, sorted_scores, sorted_points).
            neither: best_point.
        """
        attn_scores = np.array(attn_scores)
        max_score = attn_scores.max()
        threshold = max_score * 0.2
        # select patches with activation scores above the threshold
        mask = attn_scores > threshold
        if not mask.any():
            # Nothing passes (e.g. all scores <= 0): keep the single best patch
            # instead of failing later on an empty region list.
            mask = np.zeros_like(mask)
            mask.flat[attn_scores.argmax()] = True
        valid_indices = np.where(mask)

        # keep only the top_n highest-scoring patches
        if len(valid_indices[1]) > top_n:
            valid_scores = attn_scores[valid_indices]
            sorted_idx = np.argsort(valid_scores)[::-1][:top_n]
            topk_indices = valid_indices[1][sorted_idx]
            topk_values = valid_scores[sorted_idx]
        else:
            topk_values = attn_scores[valid_indices].tolist()
            topk_indices = valid_indices[1]

        if n_width * n_height != attn_scores.shape[1]:
            n_width = n_width // 2
            n_height = n_height // 2

        # (row, col, flat patch index) of each kept patch
        topk_coords = [(int(idx // n_width), int(idx % n_width), int(idx)) for idx in topk_indices]

        regions = self._group_connected_patches(topk_coords, topk_values)

        region_scores = []
        region_centers = []
        region_points = []
        for region in regions:
            # mean activation of the region
            region_scores.append(sum(item[3] for item in region) / len(region))

            # normalised center of every patch in the region
            normalized_centers = [((x + 0.5) / n_width, (y + 0.5) / n_height) for y, x, _, _ in region]
            weights = [item[3] for item in region]
            region_points.append(normalized_centers)

            total_weight = sum(weights)
            if not rect_center and total_weight != 0:
                # score-weighted average of patch centers
                avg_center_x = sum(nc[0] * w for nc, w in zip(normalized_centers, weights)) / total_weight
                avg_center_y = sum(nc[1] * w for nc, w in zip(normalized_centers, weights)) / total_weight
            else:
                # mean of the distinct column / row centers
                x_coords = {nc[0] for nc in normalized_centers}
                y_coords = {nc[1] for nc in normalized_centers}
                avg_center_x = sum(x_coords) / len(x_coords)
                avg_center_y = sum(y_coords) / len(y_coords)
            region_centers.append((avg_center_x, avg_center_y))

        # rank regions by mean score, best first
        sorted_indices = sorted(range(len(region_scores)), key=lambda i: region_scores[i], reverse=True)
        sorted_scores = [region_scores[i] for i in sorted_indices]
        sorted_centers = [region_centers[i] for i in sorted_indices]
        sorted_points = [region_points[i] for i in sorted_indices]
        best_point = sorted_centers[0]

        if no_groups:
            if return_all_regions:
                return sorted_centers + [[(x[1] + 0.5) / n_width, (x[0] + 0.5) / n_height] for x in topk_coords]
            # Append the first kept patch as ONE (x, y) point; previously its x
            # and y were appended as two separate list elements.
            first_y, first_x, _ = topk_coords[0]
            return sorted_centers + [((first_x + 0.5) / n_width, (first_y + 0.5) / n_height)]

        if return_all_regions:
            return best_point, sorted_centers, sorted_scores, sorted_points
        return best_point

    def ground_only_positive(self, instruction, image, target_point):
        """Choose the final click point for ``instruction`` on ``image``.

        Looks the (image, instruction) pair up in the JSON predictions and
        builds candidate points. With several candidates and method 'score',
        candidates are verified in rank order, stopping at the first score
        >= ``self.threshold``, and the highest-scoring one seen is chosen;
        otherwise the top-ranked candidate is used. ``target_point`` is unused.

        Returns:
            dict with the chosen normalised ``point`` and every candidate in
            pixel coordinates under ``raw_response``.

        Raises:
            KeyError: if the pair is not in the prediction file.
            ValueError: if the record has neither 'n_width' nor 'img_size_crop'.
        """
        if isinstance(image, str):
            image_path = image
            assert os.path.exists(image_path) and os.path.isfile(image_path), "Invalid input image path."
            image = Image.open(image_path).convert('RGB')
        else:
            assert isinstance(image, Image.Image)
            image_path = image_to_temp_filename(image)

        width, height = image.size

        # Prediction key: ScreenSpot-Pro keeps "<subdir>/<file>", other sets the bare file name.
        print(image_path)
        if 'v2' in image_path:
            key = image_path.split('/')[-1]
        elif 'Pro' in image_path:
            key = '/'.join(image_path.split('/')[-2:])
        else:
            key = image_path.split('/')[-1]
        key += instruction
        index = self.json_index_dict[key]
        record = self.json_prediction[index]

        if self.method == 'best_one':
            predictions = record['topk_points']
            predictions = [predictions[0]]  # only the first one
        else:
            attn_scores = record['attn_scores']
            if 'n_width' in record:
                n_width, n_height = record['n_width'], record['n_height']
            elif 'img_size_crop' in record:
                n_width, n_height = record['img_size_crop']
            else:
                raise ValueError("Invalid json prediction format. 'n_width' or 'img_size_crop' not found.")
            predictions = self.get_prediction_region_point(attn_scores, n_width, n_height, top_n=20, return_all_regions=True, rect_center=False, no_groups=True)

        # candidates in pixel coordinates
        pred_points_list = [[pred[0] * width, pred[1] * height] for pred in predictions]

        print(predictions, len(predictions))
        # Default to the top-ranked candidate so `point` is always defined
        # (previously unset for unimplemented methods such as 'comparison').
        point = predictions[0]
        if len(predictions) > 1 and self.method == 'score':
            score_list = []
            for candidate in pred_points_list:
                score = self.verifier_score(instruction, image, candidate)
                score_list.append(score)
                if score >= self.threshold:
                    break
            # pick the highest-scoring candidate verified so far
            print(score_list, len(score_list))
            point = predictions[score_list.index(max(score_list))]

        result_dict = {
            "result": "positive",
            "format": "x1y1x2y2",
            "raw_response": pred_points_list,
            "bbox": None,
            "point": point,
        }
        return result_dict
356 |
357 |
358 |
359 |
360 |
361 |
362 |
363 |
--------------------------------------------------------------------------------