├── .github
│   └── workflows
│       └── codeql.yml
├── .gitignore
├── LICENSE
├── README.md
├── SECURITY.md
├── assets
│   └── images
│       ├── main_figure.png
│       └── title.png
├── data
│   └── data_config.yaml
├── eval
│   ├── screenSpot.py
│   ├── screenSpot_pro.py
│   └── screenSpot_v2.py
├── pyproject.toml
├── scripts
│   ├── train.sh
│   ├── warmup.sh
│   └── zero3.json
├── src
│   └── gui_actor
│       ├── __init__.py
│       ├── constants.py
│       ├── dataset.py
│       ├── inference.py
│       ├── modeling.py
│       ├── modeling_qwen25vl.py
│       ├── trainer.py
│       └── utils.py
├── train.py
└── verifier
    ├── README.md
    ├── ScreenSpot-v2-new
    │   ├── screenspot_desktop_v2.json
    │   ├── screenspot_mobile_v2.json
    │   └── screenspot_web_v2.json
    ├── eval_ss_with_verifier.py
    ├── run_ss_pro.sh
    ├── run_ss_v1.sh
    ├── run_ss_v2.sh
    ├── verifier_data_generation.py
    └── verifier_model.py
/.github/workflows/codeql.yml: -------------------------------------------------------------------------------- 1 | # For most projects, this workflow file will not need changing; you simply need 2 | # to commit it to your repository. 3 | # 4 | # You may wish to alter this file to override the set of languages analyzed, 5 | # or to provide custom queries or build logic. 6 | # 7 | # ******** NOTE ******** 8 | # We have attempted to detect the languages in your repository. Please check 9 | # the `language` matrix defined below to confirm you have the correct set of 10 | # supported CodeQL languages. 11 | # 12 | name: "CodeQL Advanced" 13 | 14 | on: 15 | push: 16 | branches: [ "main" ] 17 | pull_request: 18 | branches: [ "main" ] 19 | schedule: 20 | - cron: '35 12 * * 3' 21 | 22 | jobs: 23 | analyze: 24 | name: Analyze (${{ matrix.language }}) 25 | # Runner size impacts CodeQL analysis time. To learn more, please see: 26 | # - https://gh.io/recommended-hardware-resources-for-running-codeql 27 | # - https://gh.io/supported-runners-and-hardware-resources 28 | # - https://gh.io/using-larger-runners (GitHub.com only) 29 | # Consider using larger runners or machines with greater resources for possible analysis time improvements. 30 | runs-on: ${{ (matrix.language == 'swift' && 'macos-latest') || 'ubuntu-latest' }} 31 | permissions: 32 | # required for all workflows 33 | security-events: write 34 | 35 | # required to fetch internal or private CodeQL packs 36 | packages: read 37 | 38 | # only required for workflows in private repositories 39 | actions: read 40 | contents: read 41 | 42 | strategy: 43 | fail-fast: false 44 | matrix: 45 | include: 46 | - language: python 47 | build-mode: none 48 | # CodeQL supports the following values keywords for 'language': 'actions', 'c-cpp', 'csharp', 'go', 'java-kotlin', 'javascript-typescript', 'python', 'ruby', 'swift' 49 | # Use `c-cpp` to analyze code written in C, C++ or both 50 | # Use 'java-kotlin' to analyze code written in Java, Kotlin or both 51 | # Use 'javascript-typescript' to analyze code written in JavaScript, TypeScript or both 52 | # To learn more about changing the languages that are analyzed or customizing the build mode for your analysis, 53 | # see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/customizing-your-advanced-setup-for-code-scanning. 
54 | # If you are analyzing a compiled language, you can modify the 'build-mode' for that language to customize how 55 | # your codebase is analyzed, see https://docs.github.com/en/code-security/code-scanning/creating-an-advanced-setup-for-code-scanning/codeql-code-scanning-for-compiled-languages 56 | steps: 57 | - name: Checkout repository 58 | uses: actions/checkout@v4 59 | 60 | # Add any setup steps before running the `github/codeql-action/init` action. 61 | # This includes steps like installing compilers or runtimes (`actions/setup-node` 62 | # or others). This is typically only required for manual builds. 63 | # - name: Setup runtime (example) 64 | # uses: actions/setup-example@v1 65 | 66 | # Initializes the CodeQL tools for scanning. 67 | - name: Initialize CodeQL 68 | uses: github/codeql-action/init@v3 69 | with: 70 | languages: ${{ matrix.language }} 71 | build-mode: ${{ matrix.build-mode }} 72 | # If you wish to specify custom queries, you can do so here or in a config file. 73 | # By default, queries listed here will override any specified in a config file. 74 | # Prefix the list here with "+" to use these queries and those in the config file. 75 | 76 | # For more details on CodeQL's query packs, refer to: https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs 77 | # queries: security-extended,security-and-quality 78 | 79 | # If the analyze step fails for one of the languages you are analyzing with 80 | # "We were unable to automatically build your code", modify the matrix above 81 | # to set the build mode to "manual" for that language. Then modify this step 82 | # to build your code. 83 | # ℹ️ Command-line programs to run using the OS shell. 
84 | # 📚 See https://docs.github.com/en/actions/using-workflows/workflow-syntax-for-github-actions#jobsjob_idstepsrun 85 | - if: matrix.build-mode == 'manual' 86 | shell: bash 87 | run: | 88 | echo 'If you are using a "manual" build mode for one or more of the' \ 89 | 'languages you are analyzing, replace this with the commands to build' \ 90 | 'your code, for example:' 91 | echo ' make bootstrap' 92 | echo ' make release' 93 | exit 1 94 | 95 | - name: Perform CodeQL Analysis 96 | uses: github/codeql-action/analyze@v3 97 | with: 98 | category: "/language:${{matrix.language}}" 99 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/macos,linux,python,visualstudiocode 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=macos,linux,python,visualstudiocode 3 | 4 | ### Linux ### 5 | *~ 6 | 7 | # temporary files which can be created if a process still has a handle open of a deleted file 8 | .fuse_hidden* 9 | 10 | # KDE directory preferences 11 | .directory 12 | 13 | # Linux trash folder which might appear on any partition or disk 14 | .Trash-* 15 | 16 | # .nfs files are created when an open file is removed but is still being accessed 17 | .nfs* 18 | 19 | ### macOS ### 20 | # General 21 | .DS_Store 22 | .AppleDouble 23 | .LSOverride 24 | 25 | # Icon must end with two \r 26 | Icon 27 | 28 | 29 | # Thumbnails 30 | ._* 31 | 32 | # Files that might appear in the root of a volume 33 | .DocumentRevisions-V100 34 | .fseventsd 35 | .Spotlight-V100 36 | .TemporaryItems 37 | .Trashes 38 | .VolumeIcon.icns 39 | .com.apple.timemachine.donotpresent 40 | 41 | # Directories potentially created on remote AFP share 42 | .AppleDB 43 | .AppleDesktop 44 | Network Trash Folder 45 | Temporary Items 46 | .apdisk 47 | 48 | ### macOS Patch ### 49 | # iCloud generated files 50 | *.icloud 51 | 52 | ### Python ### 53 | # Byte-compiled / optimized / DLL files 54 | __pycache__/ 55 | *.py[cod] 56 | *$py.class 57 | 58 | # C extensions 59 | *.so 60 | 61 | # Distribution / packaging 62 | .Python 63 | build/ 64 | develop-eggs/ 65 | dist/ 66 | downloads/ 67 | eggs/ 68 | .eggs/ 69 | lib/ 70 | lib64/ 71 | parts/ 72 | sdist/ 73 | var/ 74 | wheels/ 75 | share/python-wheels/ 76 | *.egg-info/ 77 | .installed.cfg 78 | *.egg 79 | MANIFEST 80 | 81 | # PyInstaller 82 | # Usually these files are written by a python script from a template 83 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
84 | *.manifest 85 | *.spec 86 | 87 | # Installer logs 88 | pip-log.txt 89 | pip-delete-this-directory.txt 90 | 91 | # Unit test / coverage reports 92 | htmlcov/ 93 | .tox/ 94 | .nox/ 95 | .coverage 96 | .coverage.* 97 | .cache 98 | nosetests.xml 99 | coverage.xml 100 | *.cover 101 | *.py,cover 102 | .hypothesis/ 103 | .pytest_cache/ 104 | cover/ 105 | 106 | # Translations 107 | *.mo 108 | *.pot 109 | 110 | # Django stuff: 111 | *.log 112 | local_settings.py 113 | db.sqlite3 114 | db.sqlite3-journal 115 | 116 | # Flask stuff: 117 | instance/ 118 | .webassets-cache 119 | 120 | # Scrapy stuff: 121 | .scrapy 122 | 123 | # Sphinx documentation 124 | docs/_build/ 125 | 126 | # PyBuilder 127 | .pybuilder/ 128 | target/ 129 | 130 | # Jupyter Notebook 131 | .ipynb_checkpoints 132 | 133 | # IPython 134 | profile_default/ 135 | ipython_config.py 136 | 137 | # pyenv 138 | # For a library or package, you might want to ignore these files since the code is 139 | # intended to run in multiple environments; otherwise, check them in: 140 | # .python-version 141 | 142 | # pipenv 143 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 144 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 145 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 146 | # install all needed dependencies. 147 | #Pipfile.lock 148 | 149 | # poetry 150 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 151 | # This is especially recommended for binary packages to ensure reproducibility, and is more 152 | # commonly ignored for libraries. 153 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 154 | #poetry.lock 155 | 156 | # pdm 157 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 158 | #pdm.lock 159 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 160 | # in version control. 161 | # https://pdm.fming.dev/#use-with-ide 162 | .pdm.toml 163 | 164 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 165 | __pypackages__/ 166 | 167 | # Celery stuff 168 | celerybeat-schedule 169 | celerybeat.pid 170 | 171 | # SageMath parsed files 172 | *.sage.py 173 | 174 | # Environments 175 | .env 176 | .venv 177 | env/ 178 | venv/ 179 | ENV/ 180 | env.bak/ 181 | venv.bak/ 182 | 183 | # Spyder project settings 184 | .spyderproject 185 | .spyproject 186 | 187 | # Rope project settings 188 | .ropeproject 189 | 190 | # mkdocs documentation 191 | /site 192 | 193 | # mypy 194 | .mypy_cache/ 195 | .dmypy.json 196 | dmypy.json 197 | 198 | # Pyre type checker 199 | .pyre/ 200 | 201 | # pytype static type analyzer 202 | .pytype/ 203 | 204 | # Cython debug symbols 205 | cython_debug/ 206 | 207 | # PyCharm 208 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 209 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 210 | # and can be added to the global gitignore or merged into this file. For a more nuclear 211 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
212 | #.idea/ 213 | 214 | ### Python Patch ### 215 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 216 | poetry.toml 217 | 218 | # ruff 219 | .ruff_cache/ 220 | 221 | # LSP config files 222 | pyrightconfig.json 223 | 224 | ### VisualStudioCode ### 225 | .vscode/* 226 | !.vscode/settings.json 227 | !.vscode/tasks.json 228 | !.vscode/launch.json 229 | !.vscode/extensions.json 230 | !.vscode/*.code-snippets 231 | 232 | # Local History for Visual Studio Code 233 | .history/ 234 | 235 | # Built Visual Studio Code Extensions 236 | *.vsix 237 | 238 | ### VisualStudioCode Patch ### 239 | # Ignore all local history of files 240 | .history 241 | .ionide 242 | 243 | # End of https://www.toptal.com/developers/gitignore/api/macos,linux,python,visualstudiocode 244 | 245 | wandb/ 246 | results/* 247 | data/data_config_test.yaml 248 | checkpoints 249 | eval/inference.ipynb 250 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Microsoft 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 |
4 | 5 |
6 | 7 | [Qianhui Wu](https://qianhuiwu.github.io/)*1  8 | [Kanzhi Cheng](https://scholar.google.com/citations?user=S2IPVnwAAAAJ&hl=en&oi=ao/)*2  9 | [Rui Yang](https://yangrui2015.github.io/)*3  10 | [Chaoyun Zhang](https://vyokky.github.io/)1  11 | [Jianwei Yang](https://jwyang.github.io/)1  12 | [Huiqiang Jiang](https://hqjiang.com/)1
13 | [Jian Mu]()2  14 | [Baolin Peng](https://scholar.google.com/citations?user=u1CNjgwAAAAJ&hl=zh-CN)1  15 | [Bo Qiao](https://scholar.google.com/citations?user=_6ugrdYAAAAJ&hl=en)1  16 | [Reuben Tan](https://cs-people.bu.edu/rxtan/)1  17 | [Si Qin](https://sqin860.github.io/)1  18 | [Lars Liden](https://sites.google.com/site/larsliden)1
19 | [Qingwei Lin](https://scholar.google.com/citations?user=W9fdsxMAAAAJ&hl=zh-CN)1  20 | [Huan Zhang](https://huan-zhang.com/)3  21 | [Tong Zhang](https://tongzhang-ml.org/)3  22 | [Jianbing Zhang](https://cs.nju.edu.cn/zhangjb/index.htm)2  23 | [Dongmei Zhang](https://scholar.google.com/citations?user=jLlBBl4AAAAJ&hl=en)1  24 | [Jianfeng Gao](https://scholar.google.com/citations?user=CQ1cqKkAAAAJ&hl=en)1 25 | 26 | 1 Microsoft Research  2 Nanjing University  3 University of Illinois Urbana-Champaign
27 | * Equal Contribution     Leadership 28 | 29 |

30 | 📄 arXiv Paper   31 | 🌐 Project Page   32 | 🤗 Hugging Face Models 33 |

34 | 35 |
36 | 37 |
38 | 39 |
40 | 41 | Figure 1. **Left**: Model performance vs. training data scale on the ScreenSpot-Pro benchmark. Higher and further left is better; larger points indicate models with more parameters. We only show GUI-Actor models built upon Qwen2-VL here for fair comparison. With Qwen2.5-VL as the backbone, GUI-Actor-3B/7B reach scores up to 42.2/44.6 (without Verifier). **Right**: Illustration of action attention. GUI-Actor grounds target elements by attending to the most relevant visual regions. 42 | 43 | ## :sparkles: Highlights 44 | 🤔 **We identify several limitations in coordinate-generation based methods** (_i.e._, output screen positions as text tokens x=…, y=…) for GUI grounding, including (1) weak spatial-semantic alignment, (2) ambiguous supervision signals, and (3) granularity mismatch between vision and action space. 45 | 46 | 💡 **Rethink how humans interact with digital interfaces**: humans do NOT calculate precise screen coordinates before acting—they perceive the target element and interact with it directly. 47 | 48 | 🚀 **We propose _GUI-Actor_, a VLM enhanced by an action head, to mitigate the above limitations.** The attention-based action head not only enables GUI-Actor to perform coordinate-free GUI grounding that more closely aligns with human behavior, but can also generate multiple candidate regions in a single forward pass, offering flexibility for downstream modules such as search strategies. 49 | 50 | ➕ **We design a _grounding verifier_ to evaluate and select the most plausible action region** among the candidates proposed from the action attention map. We show that this verifier can be easily integrated with other grounding methods for a further performance boost. 51 | 52 | 🎯 **GUI-Actor achieves state-of-the-art performance on multiple GUI action grounding benchmarks** with the same Qwen2-VL backbone, demonstrating its effectiveness and generalization to unseen screen resolutions and layouts. Notably, GUI-Actor-7B even surpasses UI-TARS-72B (38.1) on **ScreenSpot-Pro**, achieving scores of **40.7** with Qwen2-VL and **44.6** with Qwen2.5-VL as backbones. 53 | 54 | 57 | 58 | ## :bookmark_tabs: Todos 59 | We will be releasing all the following contents: 60 | - [x] Model training and evaluation based on Qwen2-VL (2025.06.03) 61 | - [x] Model checkpoint (2025.06.03) 62 | - [x] Code for grounding verifier (2025.06.06) 63 | - [ ] Support for Qwen2.5-VL 64 | - [ ] Processed training data 65 | - [ ] Demo 66 | 67 | ## :bar_chart: Main Results 68 | Table 1. Main results on ScreenSpot-Pro, ScreenSpot, and ScreenSpot-v2 with **Qwen2-VL** as the backbone. † indicates scores obtained from our own evaluation of the official models on Hugging Face. 
69 | | Method | Backbone VLM | ScreenSpot-Pro | ScreenSpot | ScreenSpot-v2 | 70 | |------------------|--------------|----------------|------------|----------------| 71 | | **_72B models:_** 72 | | AGUVIS-72B | Qwen2-VL | - | 89.2 | - | 73 | | UGround-V1-72B | Qwen2-VL | 34.5 | **89.4** | - | 74 | | UI-TARS-72B | Qwen2-VL | **38.1** | 88.4 | **90.3** | 75 | | **_7B models:_** 76 | | OS-Atlas-7B | Qwen2-VL | 18.9 | 82.5 | 84.1 | 77 | | AGUVIS-7B | Qwen2-VL | 22.9 | 84.4 | 86.0† | 78 | | UGround-V1-7B | Qwen2-VL | 31.1 | 86.3 | 87.6† | 79 | | UI-TARS-7B | Qwen2-VL | 35.7 | **89.5** | **91.6** | 80 | | GUI-Actor-7B | Qwen2-VL | **40.7** | 88.3 | 89.5 | 81 | | GUI-Actor-7B + Verifier | Qwen2-VL | 44.2 | 89.7 | 90.9 | 82 | | **_2B models:_** 83 | | UGround-V1-2B | Qwen2-VL | 26.6 | 77.1 | - | 84 | | UI-TARS-2B | Qwen2-VL | 27.7 | 82.3 | 84.7 | 85 | | GUI-Actor-2B | Qwen2-VL | **36.7** | **86.5** | **88.6** | 86 | | GUI-Actor-2B + Verifier | Qwen2-VL | 41.8 | 86.9 | 89.3 | 87 | 88 | Table 2. Main results on ScreenSpot-Pro and ScreenSpot-v2 with **Qwen2.5-VL** as the backbone. 89 | | Method | Backbone VLM | ScreenSpot-Pro | ScreenSpot-v2 | 90 | |----------------|---------------|----------------|----------------| 91 | | **_7B models:_** 92 | | Qwen2.5-VL-7B | Qwen2.5-VL | 27.6 | 88.8 | 93 | | Jedi-7B | Qwen2.5-VL | 39.5 | 91.7 | 94 | | GUI-Actor-7B | Qwen2.5-VL | **44.6** | **92.1** | 95 | | GUI-Actor-7B + Verifier | Qwen2.5-VL | 47.7 | 92.5 | 96 | | **_3B models:_** 97 | | Qwen2.5-VL-3B | Qwen2.5-VL | 25.9 | 80.9 | 98 | | Jedi-3B | Qwen2.5-VL | 36.1 | 88.6 | 99 | | GUI-Actor-3B | Qwen2.5-VL | **42.2** | **91.0** | 100 | | GUI-Actor-3B + Verifier | Qwen2.5-VL | 45.9 | 92.4 | 101 | 102 | ## :rescue_worker_helmet: Installation 103 | 1. Clone this repo to your local machine: 104 | ```bash 105 | git clone https://github.com/microsoft/GUI-Actor.git 106 | cd GUI-Actor 107 | ``` 108 | 2. Create a conda environment and install the dependencies: 109 | ```bash 110 | conda create -n gui_actor python=3.10 111 | conda activate gui_actor 112 | conda install pytorch torchvision torchaudio pytorch-cuda -c pytorch -c nvidia 113 | pip install -e . 114 | ``` 115 | ## :minidisc: Data Preparation 116 | 1. Download the processed data from [here (coming soon)](). 117 | 2. Modify the paths in the [data_config.yaml](./data/data_config.yaml) file to point to the downloaded data. 118 | 119 | ## :building_construction: Model Training 120 | 1. Warmup stage: 121 | ```bash 122 | bash scripts/warmup.sh 123 | ``` 124 | 2. Full-parameter training stage: 125 | ```bash 126 | bash scripts/train.sh 127 | ``` 128 | 129 | ## :checkered_flag: Evaluation on GUI Grounding Benchmarks 130 | For evaluation on ScreenSpot and ScreenSpot-v2, you can directly run the scripts under the `eval/` folder like `python eval/screenSpot.py` or `python eval/screenSpot_v2.py`. 
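All three evaluation scripts in `eval/` share the same correctness criterion: a prediction counts as a hit when the predicted click point (normalized to [0, 1]) falls inside the ground-truth bounding box. Below is a minimal sketch of that check, mirroring the logic in `eval/screenSpot.py`; the helper name `is_hit` is illustrative and not part of the package.

```python
def is_hit(point_xy, bbox_x1y1x2y2):
    """Return True if a normalized (x, y) click point lies inside a normalized (x1, y1, x2, y2) box."""
    px, py = point_xy
    x1, y1, x2, y2 = bbox_x1y1x2y2
    return (x1 <= px <= x2) and (y1 <= py <= y2)

# Example with the values from the usage snippet below:
# the model's top-1 click point vs. the ground-truth region for "close this window".
print(is_hit((0.9709, 0.1548), (0.9479, 0.1444, 0.9938, 0.2074)))  # True
```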
131 | 132 | For evaluation on ScreenSpot-Pro, you first need to download the data from [here](https://huggingface.co/datasets/likaixin/ScreenSpot-Pro), then run the following command: 133 | ```bash 134 | python eval/screenSpot_pro.py --save_path <path_to_save_dir> --data_path <path_to_screenspot_pro_data> 135 | ``` 136 | 137 | Example usage: 138 | ```python 139 | import torch 140 | 141 | from qwen_vl_utils import process_vision_info 142 | from datasets import load_dataset 143 | from transformers import AutoProcessor 144 | from gui_actor.constants import chat_template 145 | from gui_actor.modeling import Qwen2VLForConditionalGenerationWithPointer 146 | from gui_actor.inference import inference 147 | 148 | 149 | # load model 150 | model_name_or_path = "microsoft/GUI-Actor-7B-Qwen2-VL" 151 | data_processor = AutoProcessor.from_pretrained(model_name_or_path) 152 | tokenizer = data_processor.tokenizer 153 | model = Qwen2VLForConditionalGenerationWithPointer.from_pretrained( 154 | model_name_or_path, 155 | torch_dtype=torch.bfloat16, 156 | device_map="cuda:0", 157 | attn_implementation="flash_attention_2" 158 | ).eval() 159 | 160 | # prepare example 161 | dataset = load_dataset("rootsautomation/ScreenSpot")["test"] 162 | example = dataset[0] 163 | print(f"Instruction: {example['instruction']}") 164 | print(f"ground-truth action region (x1, y1, x2, y2): {[round(i, 2) for i in example['bbox']]}") 165 | 166 | conversation = [ 167 | { 168 | "role": "system", 169 | "content": [ 170 | { 171 | "type": "text", 172 | "text": "You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task.", 173 | } 174 | ] 175 | }, 176 | { 177 | "role": "user", 178 | "content": [ 179 | { 180 | "type": "image", 181 | "image": example["image"], # PIL.Image.Image or str to path 182 | # "image_url": "https://xxxxx.png" or "https://xxxxx.jpg" or "file://xxxxx.png" or "data:image/png;base64,xxxxxxxx", will be split by "base64," 183 | }, 184 | { 185 | "type": "text", 186 | "text": example["instruction"] 187 | }, 188 | ], 189 | }, 190 | ] 191 | 192 | # inference 193 | pred = inference(conversation, model, tokenizer, data_processor, use_placeholder=True, topk=3) 194 | px, py = pred["topk_points"][0] 195 | print(f"Predicted click point: [{round(px, 4)}, {round(py, 4)}]") 196 | 197 | # >> Model Response 198 | # Instruction: close this window 199 | # ground-truth action region (x1, y1, x2, y2): [0.9479, 0.1444, 0.9938, 0.2074] 200 | # Predicted click point: [0.9709, 0.1548] 201 | ``` 202 | 203 | ## :+1: Acknowledgements 204 | 205 | This project is built upon the following projects. Thanks for their great work! 206 | - [Transformers](https://github.com/huggingface/transformers) 207 | - [Qwen2.5-VL](https://github.com/QwenLM/Qwen2.5-VL) 208 | - [AGUVIS](https://github.com/xlang-ai/aguvis) 209 | 210 | We also thank the authors of the following projects for their insightful work, as well as for providing datasets and engaging in valuable discussions. 
211 | - [AGUVIS](https://github.com/xlang-ai/aguvis) 212 | - [UGround](https://github.com/OSU-NLP-Group/UGround) 213 | - [OS-Atlas](https://github.com/OS-Copilot/OS-Atlas) 214 | - [SeeClick](https://github.com/njucckevin/SeeClick) 215 | 216 | ## :memo: Citation 217 | If you find this work useful in your research, please consider citing: 218 | ```bibtex 219 | @misc{wu2025guiactor, 220 | title={GUI-Actor: Coordinate-Free Visual Grounding for GUI Agents}, 221 | author={Qianhui Wu and Kanzhi Cheng and Rui Yang and Chaoyun Zhang and Jianwei Yang and Huiqiang Jiang and Jian Mu and Baolin Peng and Bo Qiao and Reuben Tan and Si Qin and Lars Liden and Qingwei Lin and Huan Zhang and Tong Zhang and Jianbing Zhang and Dongmei Zhang and Jianfeng Gao}, 222 | year={2025}, 223 | eprint={2506.03143}, 224 | archivePrefix={arXiv}, 225 | primaryClass={cs.CV}, 226 | url={https://arxiv.org/abs/2506.03143}, 227 | } 228 | ``` 229 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/Microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet) and [Xamarin](https://github.com/xamarin). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/security.md/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/security.md/msrc/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/security.md/msrc/pgp). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://www.microsoft.com/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. 
Please visit our [Microsoft Bug Bounty Program](https://aka.ms/security.md/msrc/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/security.md/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /assets/images/main_figure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/GUI-Actor/30e148da6d719117444f9d05a944561f31e362f1/assets/images/main_figure.png -------------------------------------------------------------------------------- /assets/images/title.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/GUI-Actor/30e148da6d719117444f9d05a944561f31e362f1/assets/images/title.png -------------------------------------------------------------------------------- /data/data_config.yaml: -------------------------------------------------------------------------------- 1 | datasets: 2 | - json_path: /mnt/datasets/Uground/uground_aguvis_bbox_filter.json 3 | images_folder: /mnt/datasets/Uground/images/ 4 | sampling_strategy: "all" 5 | - json_path: /mnt/datasets/GUIEnv/guienv_aguvis_bbox.json 6 | images_folder: /mnt/datasets/GUIEnv/guienvs/images/ 7 | sampling_strategy: "all" 8 | - json_path: /mnt/datasets/GUIAct/guiact_aguvis_bbox.json 9 | images_folder: /mnt/datasets/GUIAct/web_imgs/ 10 | sampling_strategy: "all" 11 | - json_path: /mnt/datasets/AMEX/amex_aguvis_bbox.json 12 | images_folder: /mnt/datasets/AMEX/screenshots/ 13 | sampling_strategy: "all" 14 | - json_path: /mnt/datasets/AndroidControl/androidcontrol_aguvis_bbox.json 15 | images_folder: /mnt/datasets/AndroidControl/tfrecord/images/ 16 | sampling_strategy: "all" 17 | - json_path: /mnt/datasets/Wave-UI/wave_ui_aguvis_bbox_fixed.json 18 | images_folder: /mnt/datasets/Wave-UI/images_fixed/ 19 | sampling_strategy: "all" -------------------------------------------------------------------------------- /eval/screenSpot.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import json 4 | import argparse 5 | 6 | from tqdm import tqdm 7 | from datasets import load_dataset 8 | from transformers import AutoProcessor 9 | 10 | from gui_actor.constants import chat_template 11 | from gui_actor.modeling import Qwen2VLForConditionalGenerationWithPointer 12 | from gui_actor.modeling_qwen25vl import Qwen2_5_VLForConditionalGenerationWithPointer 13 | from gui_actor.inference import inference, ForceFollowTokensLogitsProcessor 14 | from gui_actor.utils import do_boxes_overlap 15 | from gui_actor.constants import DEFAULT_POINTER_PAD_TOKEN, DEFAULT_POINTER_END_TOKEN 16 | 17 | IMAGE_PATCH_SIZE =14 18 | 19 | def normalize_bbox(bbox_x1y1x2y2, img_width, img_height): 20 | # if bbox_x1y1x2y2 is not normalized to [0, 1], normalize it 21 | x1, y1, x2, y2 = bbox_x1y1x2y2 22 | if (0 <= x1 <= 1) and (0 <= y1 <= 1) and (0 <= x2 <= 1) and (0 <= y2 <= 1): 23 | return bbox_x1y1x2y2 24 | else: 25 | x1 = x1 / img_width 26 | y1 = y1 / img_height 27 | x2 = x2 / img_width 28 | y2 = y2 / img_height 29 | return x1, y1, x2, y2 30 | 31 | def evaluate(model_name_or_path, model_type, use_placeholder, topk): 32 | # initialize model 33 | data_processor = AutoProcessor.from_pretrained(model_name_or_path) 34 | 
tokenizer = data_processor.tokenizer 35 | for k, v in tokenizer.added_tokens_encoder.items(): 36 | print(v, k) 37 | 38 | if model_type == "qwen2vl": 39 | print(f"Loading model with Qwen2-VL backbone from {model_name_or_path}") 40 | model = Qwen2VLForConditionalGenerationWithPointer.from_pretrained( 41 | model_name_or_path, 42 | torch_dtype=torch.bfloat16, 43 | device_map="cuda:0", 44 | attn_implementation="flash_attention_2" 45 | ).eval() 46 | grounding_system_message = "You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task." 47 | elif model_type == "qwen25vl": 48 | print(f"Loading model with Qwen2.5-VL backbone from {model_name_or_path}") 49 | model = Qwen2_5_VLForConditionalGenerationWithPointer.from_pretrained( 50 | model_name_or_path, 51 | torch_dtype=torch.bfloat16, 52 | device_map="cuda:0", 53 | attn_implementation="flash_attention_2" 54 | ).eval() 55 | grounding_system_message = "You are a GUI agent. Given a screenshot of the current GUI and a human instruction, your task is to locate the screen element that corresponds to the instruction. You should output a PyAutoGUI action that performs a click on the correct position. To indicate the click location, we will use some special tokens, which is used to refer to a visual patch later. For example, you can output: pyautogui.click()." 56 | else: 57 | raise ValueError(f"Invalid model type: {model_type}") 58 | print(f"Loaded model from {model_name_or_path}") 59 | 60 | logits_processor_pointer = ForceFollowTokensLogitsProcessor( 61 | token_a_id=tokenizer.encode(DEFAULT_POINTER_PAD_TOKEN)[0], 62 | forced_sequence=[ 63 | tokenizer.encode(DEFAULT_POINTER_END_TOKEN)[0] 64 | ] 65 | ) 66 | 67 | dataset = load_dataset("rootsautomation/ScreenSpot")["test"] 68 | domain_dict = { 69 | "windows": "desktop", 70 | "macos": "desktop", 71 | "ios": "mobile", 72 | "android": "mobile", 73 | "tool": "web", 74 | "shop": "web", 75 | "gitlab": "web", 76 | "forum": "web" 77 | } 78 | 79 | results = [] 80 | for i, example in tqdm(enumerate(dataset), total=len(dataset)): 81 | ele = { 82 | "file_name": example["file_name"], 83 | "data_type": example["data_type"], 84 | "domain": domain_dict[example["data_source"]], 85 | "instruction": example["instruction"], 86 | "img_size": example["image"].size, 87 | "bbox_x1y1x2y2": normalize_bbox(example["bbox"], example["image"].size[0], example["image"].size[1]), 88 | "hit_top1": 0, 89 | "overlap_top1": 0, 90 | "hit_topk": 0, 91 | "overlap_topk": 0, 92 | } 93 | 94 | conversation = [ 95 | { 96 | "role": "system", 97 | "content": [ 98 | { 99 | "type": "text", 100 | "text": grounding_system_message, 101 | } 102 | ] 103 | }, 104 | { 105 | "role": "user", 106 | "content": [ 107 | { 108 | "type": "image", 109 | "image": example["image"], # PIL.Image.Image or str to path 110 | # "image_url": "https://xxxxx.png" or "https://xxxxx.jpg" or "file://xxxxx.png" or "data:image/png;base64,xxxxxxxx", will be split by "base64," 111 | }, 112 | { 113 | "type": "text", 114 | "text": example["instruction"] 115 | }, 116 | ], 117 | }, 118 | ] 119 | 120 | pred = inference(conversation, model, tokenizer, data_processor, logits_processor=logits_processor_pointer, use_placeholder=use_placeholder, topk=3) 121 | topk_points = pred["topk_points"] 122 | gt_bbox = ele["bbox_x1y1x2y2"] 123 | 124 | # compute the metrics 125 | px, py = topk_points[0] 126 | x1, y1, x2, y2 = gt_bbox 127 | 128 | if (x1 <= px <= x2) and (y1 <= py <= y2): 129 | ele["hit_top1"] = 1 130 | 
ele["hit_topk"] = 1 131 | 132 | pred_bbox = [px - IMAGE_PATCH_SIZE, py - IMAGE_PATCH_SIZE, px + IMAGE_PATCH_SIZE, py + IMAGE_PATCH_SIZE] 133 | if do_boxes_overlap(pred_bbox, gt_bbox): 134 | ele["overlap_top1"] = 1 135 | ele["overlap_topk"] = 1 136 | 137 | for px, py in topk_points[1:]: 138 | if (x1 <= px <= x2) and (y1 <= py <= y2): 139 | ele["hit_topk"] = 1 140 | pred_bbox = [px - IMAGE_PATCH_SIZE, py - IMAGE_PATCH_SIZE, px + IMAGE_PATCH_SIZE, py + IMAGE_PATCH_SIZE] 141 | if do_boxes_overlap(pred_bbox, gt_bbox): 142 | ele["overlap_topk"] = 1 143 | 144 | results.append(ele) 145 | 146 | return results 147 | 148 | 149 | def get_metric(list_of_examples, 150 | domains=["mobile", "desktop", "web"], 151 | data_types=["text", "icon"]): 152 | """ 153 | Computes metrics over a list of examples and prints/plots a table. 154 | 155 | Each element in list_of_examples is a dict containing: 156 | - "domain": Domain name (e.g., "web", "mobile", "desktop") 157 | - "data_type": Data type (e.g., "text", "icon") 158 | - "hit_top1", "overlap_top1", "hit_topk", "overlap_topk": binary (0 or 1) 159 | 160 | The final table has columns for each domain broken down by UI type (plus a domain-average) 161 | and overall columns ("All-text", "All-icon", "All-average"). 162 | 163 | The rows of the table are: 164 | - hit_top1 165 | - overlap_top1 166 | - hit_topk 167 | - overlap_topk 168 | """ 169 | 170 | # List of metric keys to compute. 171 | metrics = ["hit_top1", "overlap_top1", "hit_topk", "overlap_topk"] 172 | 173 | # Helper function to compute the mean of a given key from a list of examples. 174 | def compute_mean(examples, key): 175 | if not examples: 176 | return None 177 | return sum(example.get(key, 0) for example in examples) / len(examples) 178 | 179 | # Prepare results dictionary: structure {metric: {column_name: value}}. 180 | results = {metric: {} for metric in metrics} 181 | 182 | # Compute metrics for each group broken down by UI type. 183 | for domain in domains: 184 | # Filter examples for the current group. 185 | domain_examples = [ex for ex in list_of_examples if ex.get("domain") == domain] 186 | for data_type in data_types: 187 | # Filter further for the specific UI type. 188 | domain_data_type_examples = [ex for ex in domain_examples if ex.get("data_type") == data_type] 189 | col_name = f"{domain}-{data_type}" 190 | for metric in metrics: 191 | results[metric][col_name] = compute_mean(domain_data_type_examples, metric) 192 | 193 | # Compute domain-average (all UI types for this domain). 194 | col_name_avg = f"{domain}-avg" 195 | for metric in metrics: 196 | results[metric][col_name_avg] = compute_mean(domain_examples, metric) 197 | 198 | # Compute overall metrics for each UI type across all domains. 199 | for data_type in data_types: 200 | data_type_examples = [ex for ex in list_of_examples if ex.get("data_type") == data_type] 201 | col_name = f"All-{data_type}" 202 | for metric in metrics: 203 | results[metric][col_name] = compute_mean(data_type_examples, metric) 204 | 205 | # Compute overall average across all examples. 206 | overall_key = "All-avg" 207 | for metric in metrics: 208 | results[metric][overall_key] = compute_mean(list_of_examples, metric) 209 | 210 | # Define the order of columns. 
211 | columns_order = [] 212 | for domain in domains: 213 | for data_type in data_types: 214 | columns_order.append(f"{domain}-{data_type}") 215 | columns_order.append(f"{domain}-avg") 216 | for data_type in data_types: 217 | columns_order.append(f"All-{data_type}") 218 | columns_order.append("All-avg") 219 | 220 | # ------------- Print Table to Console ------------- 221 | # Prepare header row. 222 | header = [""] + columns_order 223 | # Calculate column widths for console printing. 224 | col_widths = [max(len(col), 12) for col in header] 225 | 226 | def format_cell(cell): 227 | if isinstance(cell, float): 228 | return f"{cell*100:.2f}" 229 | elif cell is None: 230 | return "N/A" 231 | return str(cell) 232 | 233 | # Print header. 234 | header_line = " | ".join(word.ljust(width) for word, width in zip(header, col_widths)) 235 | separator_line = "-+-".join("-" * width for width in col_widths) 236 | print(header_line) 237 | print(separator_line) 238 | 239 | for metric in metrics: 240 | row = [metric] 241 | for col in columns_order: 242 | val = results[metric].get(col) 243 | row.append(format_cell(val)) 244 | row_line = " | ".join(word.ljust(width) for word, width in zip(row, col_widths)) 245 | print(row_line) 246 | 247 | # ------------- Print Tab-delimited Version (for Excel Copy-Paste) ------------- 248 | metric_info = "Tab-delimited Table for Excel:\n" 249 | # Header row. 250 | header_tab = "\t".join([""] + columns_order) 251 | metric_info += (header_tab + "\n") 252 | # Each row. 253 | for metric in metrics: 254 | row = [metric] + [format_cell(results[metric].get(col)) for col in columns_order] 255 | metric_info += ("\t".join(row) + "\n") 256 | print(metric_info) 257 | return metric_info 258 | 259 | 260 | """ 261 | # cd to project root directory 262 | python eval/screenSpot.py --save_path 263 | """ 264 | if __name__ == "__main__": 265 | parser = argparse.ArgumentParser() 266 | parser.add_argument("--model_type", type=str, default="qwen25vl", choices=["qwen2vl", "qwen25vl"]) 267 | parser.add_argument("--model_name_or_path", type=str, default="qianhuiwu/GUI-Actor-3B-Qwen-2.5-VL") 268 | parser.add_argument("--save_path", type=str, default="./") 269 | parser.add_argument('--topk', type=int, default=3, help='Topk') 270 | parser.add_argument('--no-placeholder', dest='use_placeholder', action='store_false', help='Disable the placeholder') 271 | parser.set_defaults(use_placeholder=True) 272 | 273 | args = parser.parse_args() 274 | 275 | save_path = args.save_path 276 | if not os.path.exists(save_path): 277 | os.makedirs(save_path, exist_ok=True) 278 | pred_path = f"{save_path}/screenspot_all_preds.json" 279 | metric_path = f"{save_path}/screenspot_all_metrics.txt" 280 | 281 | if os.path.exists(metric_path): 282 | exit() 283 | 284 | if os.path.exists(pred_path): 285 | print(f"Loading predictions from {pred_path}") 286 | with open(pred_path, "r") as f: 287 | results = json.load(f) 288 | else: 289 | print(f"Evaluating {args.model_name_or_path}...") 290 | results = evaluate(args.model_name_or_path, args.model_type, args.use_placeholder, args.topk) 291 | with open(pred_path, "w") as f: 292 | json.dump(results, f) 293 | print(f"Saved {len(results)} predictions to {pred_path}") 294 | 295 | if not os.path.exists(metric_path): 296 | metric_info = get_metric(results) 297 | with open(metric_path, "w") as f: 298 | f.write(metric_info) 299 | print(f"Saved metric to {metric_path}") 300 | -------------------------------------------------------------------------------- /eval/screenSpot_pro.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import json 4 | import argparse 5 | 6 | from tqdm import tqdm 7 | from datasets import load_dataset 8 | from transformers import AutoProcessor 9 | from PIL import Image 10 | from gui_actor.constants import chat_template 11 | from gui_actor.modeling import Qwen2VLForConditionalGenerationWithPointer 12 | from gui_actor.modeling_qwen25vl import Qwen2_5_VLForConditionalGenerationWithPointer 13 | from gui_actor.inference import inference, ForceFollowTokensLogitsProcessor 14 | from gui_actor.utils import do_boxes_overlap 15 | from gui_actor.constants import DEFAULT_POINTER_PAD_TOKEN, DEFAULT_POINTER_END_TOKEN 16 | 17 | IMAGE_PATCH_SIZE =14 18 | 19 | def normalize_bbox(bbox_x1y1x2y2, img_width, img_height): 20 | # if bbox_x1y1x2y2 is not normalized to [0, 1], normalize it 21 | x1, y1, x2, y2 = bbox_x1y1x2y2 22 | if (0 <= x1 <= 1) and (0 <= y1 <= 1) and (0 <= x2 <= 1) and (0 <= y2 <= 1): 23 | return bbox_x1y1x2y2 24 | else: 25 | x1 = x1 / img_width 26 | y1 = y1 / img_height 27 | x2 = x2 / img_width 28 | y2 = y2 / img_height 29 | return x1, y1, x2, y2 30 | 31 | def evaluate(model_name_or_path, model_type, data_fn, image_dir, use_placeholder, topk, resize_to_pixels=None): 32 | # initialize model 33 | data_processor = AutoProcessor.from_pretrained(model_name_or_path) 34 | tokenizer = data_processor.tokenizer 35 | for k, v in tokenizer.added_tokens_encoder.items(): 36 | print(v, k) 37 | 38 | if model_type == "qwen2vl": 39 | print(f"Loading model with Qwen2-VL backbone from {model_name_or_path}") 40 | model = Qwen2VLForConditionalGenerationWithPointer.from_pretrained( 41 | model_name_or_path, 42 | torch_dtype=torch.bfloat16, 43 | device_map="cuda:0", 44 | attn_implementation="flash_attention_2" 45 | ).eval() 46 | grounding_system_message = "You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task." 47 | elif model_type == "qwen25vl": 48 | print(f"Loading model with Qwen2.5-VL backbone from {model_name_or_path}") 49 | model = Qwen2_5_VLForConditionalGenerationWithPointer.from_pretrained( 50 | model_name_or_path, 51 | torch_dtype=torch.bfloat16, 52 | device_map="cuda:0", 53 | attn_implementation="flash_attention_2" 54 | ).eval() 55 | grounding_system_message = "You are a GUI agent. Given a screenshot of the current GUI and a human instruction, your task is to locate the screen element that corresponds to the instruction. You should output a PyAutoGUI action that performs a click on the correct position. To indicate the click location, we will use some special tokens, which is used to refer to a visual patch later. For example, you can output: pyautogui.click()." 
56 | else: 57 | raise ValueError(f"Invalid model type: {model_type}") 58 | print(f"Loaded model from {model_name_or_path}") 59 | 60 | logits_processor_pointer = ForceFollowTokensLogitsProcessor( 61 | token_a_id=tokenizer.encode(DEFAULT_POINTER_PAD_TOKEN)[0], 62 | forced_sequence=[ 63 | tokenizer.encode(DEFAULT_POINTER_END_TOKEN)[0] 64 | ] 65 | ) 66 | 67 | # load data 68 | with open(data_fn, "r") as f: 69 | data = json.load(f) 70 | print(f"Loaded {len(data)} examples from {data_fn}") 71 | 72 | results = [] 73 | for i, example in tqdm(enumerate(data), total=len(data)): 74 | ele = { 75 | "file_name": example["img_filename"], 76 | "ui_type": example["ui_type"], 77 | "group": example["group"], 78 | "platform": example["platform"], 79 | "application": example["application"], 80 | "id": example["id"], 81 | "instruction": example["instruction"], 82 | "img_size": example["img_size"], 83 | "bbox_x1y1x2y2": normalize_bbox(example["bbox"], example["img_size"][0], example["img_size"][1]), 84 | "hit_top1": 0, 85 | "overlap_top1": 0, 86 | "hit_topk": 0, 87 | "overlap_topk": 0, 88 | } 89 | 90 | image_path = os.path.join(image_dir, example["img_filename"]) 91 | image = Image.open(image_path) 92 | # resize the image if needed 93 | image_width, image_height = example["img_size"] 94 | if (resize_to_pixels is not None) and ((image_width * image_height) != resize_to_pixels): 95 | resize_ratio = (resize_to_pixels / (image_width * image_height)) ** 0.5 96 | image_width_resized, image_height_resized = int(image_width * resize_ratio), int(image_height * resize_ratio) 97 | image = image.resize((image_width_resized, image_height_resized)) 98 | ele["img_size_resized"] = [image_width_resized, image_height_resized] 99 | else: 100 | ele["img_size_resized"] = None 101 | 102 | conversation = [ 103 | { 104 | "role": "system", 105 | "content": [ 106 | { 107 | "type": "text", 108 | "text": grounding_system_message, 109 | } 110 | ] 111 | }, 112 | { 113 | "role": "user", 114 | "content": [ 115 | { 116 | "type": "image", 117 | "image": image, # PIL.Image.Image or str to path 118 | # "image_url": "https://xxxxx.png" or "https://xxxxx.jpg" or "file://xxxxx.png" or "data:image/png;base64,xxxxxxxx", will be split by "base64," 119 | }, 120 | { 121 | "type": "text", 122 | "text": example["instruction"] 123 | }, 124 | ], 125 | }, 126 | ] 127 | 128 | pred = inference(conversation, model, tokenizer, data_processor, logits_processor=logits_processor_pointer, use_placeholder=use_placeholder, topk=3) 129 | topk_points = pred["topk_points"] 130 | gt_bbox = ele["bbox_x1y1x2y2"] 131 | 132 | # compute the metrics 133 | px, py = topk_points[0] 134 | x1, y1, x2, y2 = gt_bbox 135 | 136 | if (x1 <= px <= x2) and (y1 <= py <= y2): 137 | ele["hit_top1"] = 1 138 | ele["hit_topk"] = 1 139 | 140 | pred_bbox = [px - IMAGE_PATCH_SIZE, py - IMAGE_PATCH_SIZE, px + IMAGE_PATCH_SIZE, py + IMAGE_PATCH_SIZE] 141 | if do_boxes_overlap(pred_bbox, gt_bbox): 142 | ele["overlap_top1"] = 1 143 | ele["overlap_topk"] = 1 144 | 145 | for px, py in topk_points[1:]: 146 | if (x1 <= px <= x2) and (y1 <= py <= y2): 147 | ele["hit_topk"] = 1 148 | pred_bbox = [px - IMAGE_PATCH_SIZE, py - IMAGE_PATCH_SIZE, px + IMAGE_PATCH_SIZE, py + IMAGE_PATCH_SIZE] 149 | if do_boxes_overlap(pred_bbox, gt_bbox): 150 | ele["overlap_topk"] = 1 151 | 152 | results.append(ele) 153 | 154 | return results 155 | 156 | 157 | def get_metric(list_of_examples, 158 | groups=["Dev", "Creative", "CAD", "Scientific", "Office", "OS"], 159 | ui_types=["text", "icon"]): 160 | """ 161 | Computes metrics over 
a list of examples and prints/plots a table. 162 | 163 | Each element in list_of_examples is a dict containing: 164 | - "group": Group name (e.g., "Dev", "Creative", etc.) 165 | - "ui_type": UI type (e.g., "text", "icon") 166 | - "hit_top1", "overlap_top1", "hit_topk", "overlap_topk": binary (0 or 1) 167 | 168 | The final table has columns for each group broken down by UI type (plus a group-average) 169 | and overall columns ("All-text", "All-icon", "All-average"). 170 | 171 | The rows of the table are: 172 | - hit_top1 173 | - overlap_top1 174 | - hit_topk 175 | - overlap_topk 176 | """ 177 | 178 | # List of metric keys to compute. 179 | metrics = ["hit_top1", "overlap_top1", "hit_topk", "overlap_topk"] 180 | 181 | # Helper function to compute the mean of a given key from a list of examples. 182 | def compute_mean(examples, key): 183 | if not examples: 184 | return None 185 | return sum(example.get(key, 0) for example in examples) / len(examples) 186 | 187 | # Prepare results dictionary: structure {metric: {column_name: value}}. 188 | results = {metric: {} for metric in metrics} 189 | 190 | # Compute metrics for each group broken down by UI type. 191 | for group in groups: 192 | # Filter examples for the current group. 193 | group_examples = [ex for ex in list_of_examples if ex.get("group") == group] 194 | for ui in ui_types: 195 | # Filter further for the specific UI type. 196 | group_ui_examples = [ex for ex in group_examples if ex.get("ui_type") == ui] 197 | col_name = f"{group}-{ui}" 198 | for metric in metrics: 199 | results[metric][col_name] = compute_mean(group_ui_examples, metric) 200 | 201 | # Compute group-average (all UI types for this group). 202 | col_name_avg = f"{group}-avg" 203 | for metric in metrics: 204 | results[metric][col_name_avg] = compute_mean(group_examples, metric) 205 | 206 | # Compute overall metrics for each UI type across all groups. 207 | for ui in ui_types: 208 | ui_examples = [ex for ex in list_of_examples if ex.get("ui_type") == ui] 209 | col_name = f"All-{ui}" 210 | for metric in metrics: 211 | results[metric][col_name] = compute_mean(ui_examples, metric) 212 | 213 | # Compute overall average across all examples. 214 | overall_key = "All-avg" 215 | for metric in metrics: 216 | results[metric][overall_key] = compute_mean(list_of_examples, metric) 217 | 218 | # Define the order of columns. 219 | columns_order = [] 220 | for group in groups: 221 | for ui in ui_types: 222 | columns_order.append(f"{group}-{ui}") 223 | columns_order.append(f"{group}-avg") 224 | for ui in ui_types: 225 | columns_order.append(f"All-{ui}") 226 | columns_order.append("All-avg") 227 | 228 | # ------------- Print Table to Console ------------- 229 | # Prepare header row. 230 | header = [""] + columns_order 231 | # Calculate column widths for console printing. 232 | col_widths = [max(len(col), 12) for col in header] 233 | 234 | def format_cell(cell): 235 | if isinstance(cell, float): 236 | return f"{cell*100:.2f}" 237 | elif cell is None: 238 | return "N/A" 239 | return str(cell) 240 | 241 | # Print header. 
242 | header_line = " | ".join(word.ljust(width) for word, width in zip(header, col_widths)) 243 | separator_line = "-+-".join("-" * width for width in col_widths) 244 | print(header_line) 245 | print(separator_line) 246 | 247 | for metric in metrics: 248 | row = [metric] 249 | for col in columns_order: 250 | val = results[metric].get(col) 251 | row.append(format_cell(val)) 252 | row_line = " | ".join(word.ljust(width) for word, width in zip(row, col_widths)) 253 | print(row_line) 254 | 255 | # ------------- Print Tab-delimited Version (for Excel Copy-Paste) ------------- 256 | metric_info = "Tab-delimited Table for Excel:\n" 257 | # Header row. 258 | header_tab = "\t".join([""] + columns_order) 259 | metric_info += header_tab + "\n" 260 | # Each row. 261 | for metric in metrics: 262 | row = [metric] + [format_cell(results[metric].get(col)) for col in columns_order] 263 | metric_info += ("\t".join(row) + "\n") 264 | print(metric_info) 265 | return metric_info 266 | 267 | 268 | """ 269 | # cd to project root directory 270 | python eval/screenSpot_pro.py --save_path --data_path 271 | """ 272 | if __name__ == "__main__": 273 | parser = argparse.ArgumentParser() 274 | parser.add_argument("--model_type", type=str, default="qwen25vl", choices=["qwen2vl", "qwen25vl"]) 275 | parser.add_argument("--model_name_or_path", type=str, default="microsoft/GUI-Actor-7B-Qwen2.5-VL") 276 | parser.add_argument("--save_path", type=str, default="./") 277 | parser.add_argument("--data_path", type=str, default="/mnt/data/ScreenSpot-Pro") 278 | parser.add_argument("--resize_to_pixels", type=int, default=3200*1800, help="If set to <0, will not resize the image.") 279 | parser.add_argument('--no-placeholder', dest='use_placeholder', action='store_false', help='Disable the placeholder') 280 | parser.add_argument('--topk', type=int, default=3, help='Topk') 281 | parser.set_defaults(use_placeholder=True) 282 | 283 | args = parser.parse_args() 284 | 285 | resize_to_pixels = args.resize_to_pixels if args.resize_to_pixels > 0 else None 286 | image_dir = os.path.join(args.data_path, "images") 287 | data_fn = os.path.join(args.data_path, "annotations/all.json") 288 | save_path = args.save_path 289 | if not os.path.exists(save_path): 290 | os.makedirs(save_path, exist_ok=True) 291 | pred_path = f"{save_path}/screenspot-Pro_all_preds_StandardResize.json" 292 | metric_path = f"{save_path}/screenspot-Pro_all_preds_StandardResize.txt" 293 | 294 | if os.path.exists(metric_path): 295 | exit() 296 | 297 | if os.path.exists(pred_path): 298 | print(f"Loading predictions from {pred_path}") 299 | with open(pred_path, "r") as f: 300 | results = json.load(f) 301 | else: 302 | print(f"Evaluating {args.model_name_or_path}...") 303 | results = evaluate(args.model_name_or_path, args.model_type, data_fn, image_dir, args.use_placeholder, args.topk, resize_to_pixels) 304 | with open(pred_path, "w") as f: 305 | json.dump(results, f) 306 | print(f"Saved {len(results)} predictions to {pred_path}") 307 | 308 | if not os.path.exists(metric_path): 309 | metric_info = get_metric(results) 310 | with open(metric_path, "w") as f: 311 | f.write(metric_info) 312 | print(f"Saved metric to {metric_path}") 313 | -------------------------------------------------------------------------------- /eval/screenSpot_v2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import json 4 | import argparse 5 | 6 | from tqdm import tqdm 7 | from datasets import load_dataset 8 | from transformers import 
AutoProcessor 9 | 10 | from gui_actor.constants import chat_template 11 | from gui_actor.modeling import Qwen2VLForConditionalGenerationWithPointer 12 | from gui_actor.modeling_qwen25vl import Qwen2_5_VLForConditionalGenerationWithPointer 13 | from gui_actor.inference import inference, ForceFollowTokensLogitsProcessor 14 | from gui_actor.utils import do_boxes_overlap 15 | from gui_actor.constants import DEFAULT_POINTER_PAD_TOKEN, DEFAULT_POINTER_END_TOKEN 16 | 17 | IMAGE_PATCH_SIZE =14 18 | 19 | def normalize_bbox(bbox_x1y1x2y2, img_width, img_height): 20 | # if bbox_x1y1x2y2 is not normalized to [0, 1], normalize it 21 | x1, y1, x2, y2 = bbox_x1y1x2y2 22 | if (0 <= x1 <= 1) and (0 <= y1 <= 1) and (0 <= x2 <= 1) and (0 <= y2 <= 1): 23 | return bbox_x1y1x2y2 24 | else: 25 | x1 = x1 / img_width 26 | y1 = y1 / img_height 27 | x2 = x2 / img_width 28 | y2 = y2 / img_height 29 | return x1, y1, x2, y2 30 | 31 | def evaluate(model_name_or_path, model_type, use_placeholder, topk): 32 | # initialize model 33 | data_processor = AutoProcessor.from_pretrained(model_name_or_path) 34 | tokenizer = data_processor.tokenizer 35 | for k, v in tokenizer.added_tokens_encoder.items(): 36 | print(v, k) 37 | 38 | if model_type == "qwen2vl": 39 | print(f"Loading model with Qwen2-VL backbone from {model_name_or_path}") 40 | model = Qwen2VLForConditionalGenerationWithPointer.from_pretrained( 41 | model_name_or_path, 42 | torch_dtype=torch.bfloat16, 43 | device_map="cuda:0", 44 | attn_implementation="flash_attention_2" 45 | ).eval() 46 | grounding_system_message = "You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task." 47 | elif model_type == "qwen25vl": 48 | print(f"Loading model with Qwen2.5-VL backbone from {model_name_or_path}") 49 | model = Qwen2_5_VLForConditionalGenerationWithPointer.from_pretrained( 50 | model_name_or_path, 51 | torch_dtype=torch.bfloat16, 52 | device_map="cuda:0", 53 | attn_implementation="flash_attention_2" 54 | ).eval() 55 | grounding_system_message = "You are a GUI agent. Given a screenshot of the current GUI and a human instruction, your task is to locate the screen element that corresponds to the instruction. You should output a PyAutoGUI action that performs a click on the correct position. To indicate the click location, we will use some special tokens, which is used to refer to a visual patch later. For example, you can output: pyautogui.click()." 
56 | else: 57 | raise ValueError(f"Invalid model type: {model_type}") 58 | print(f"Loaded model from {model_name_or_path}") 59 | 60 | logits_processor_pointer = ForceFollowTokensLogitsProcessor( 61 | token_a_id=tokenizer.encode(DEFAULT_POINTER_PAD_TOKEN)[0], 62 | forced_sequence=[ 63 | tokenizer.encode(DEFAULT_POINTER_END_TOKEN)[0] 64 | ] 65 | ) 66 | 67 | dataset = load_dataset("HongxinLi/ScreenSpot_v2")["test"] 68 | domain_dict = { 69 | "windows": "desktop", 70 | "macos": "desktop", 71 | "ios": "mobile", 72 | "android": "mobile", 73 | "tool": "web", 74 | "shop": "web", 75 | "gitlab": "web", 76 | "forum": "web" 77 | } 78 | 79 | results = [] 80 | for i, example in tqdm(enumerate(dataset), total=len(dataset)): 81 | ele = { 82 | "file_name": example["file_name"], 83 | "data_type": example["data_type"], 84 | "domain": domain_dict[example["data_source"]], 85 | "instruction": example["instruction"], 86 | "img_size": example["image"].size, 87 | "bbox_x1y1x2y2": normalize_bbox(example["bbox"], example["image"].size[0], example["image"].size[1]), 88 | "hit_top1": 0, 89 | "overlap_top1": 0, 90 | "hit_topk": 0, 91 | "overlap_topk": 0, 92 | } 93 | 94 | conversation = [ 95 | { 96 | "role": "system", 97 | "content": [ 98 | { 99 | "type": "text", 100 | "text": grounding_system_message, 101 | } 102 | ] 103 | }, 104 | { 105 | "role": "user", 106 | "content": [ 107 | { 108 | "type": "image", 109 | "image": example["image"], # PIL.Image.Image or str to path 110 | # "image_url": "https://xxxxx.png" or "https://xxxxx.jpg" or "file://xxxxx.png" or "data:image/png;base64,xxxxxxxx", will be split by "base64," 111 | }, 112 | { 113 | "type": "text", 114 | "text": example["instruction"] 115 | }, 116 | ], 117 | }, 118 | ] 119 | 120 | pred = inference(conversation, model, tokenizer, data_processor, logits_processor=logits_processor_pointer, use_placeholder=use_placeholder, topk=3) 121 | topk_points = pred["topk_points"] 122 | gt_bbox = ele["bbox_x1y1x2y2"] 123 | 124 | # compute the metrics 125 | px, py = topk_points[0] 126 | x1, y1, x2, y2 = gt_bbox 127 | 128 | if (x1 <= px <= x2) and (y1 <= py <= y2): 129 | ele["hit_top1"] = 1 130 | ele["hit_topk"] = 1 131 | 132 | pred_bbox = [px - IMAGE_PATCH_SIZE, py - IMAGE_PATCH_SIZE, px + IMAGE_PATCH_SIZE, py + IMAGE_PATCH_SIZE] 133 | if do_boxes_overlap(pred_bbox, gt_bbox): 134 | ele["overlap_top1"] = 1 135 | ele["overlap_topk"] = 1 136 | 137 | for px, py in topk_points[1:]: 138 | if (x1 <= px <= x2) and (y1 <= py <= y2): 139 | ele["hit_topk"] = 1 140 | pred_bbox = [px - IMAGE_PATCH_SIZE, py - IMAGE_PATCH_SIZE, px + IMAGE_PATCH_SIZE, py + IMAGE_PATCH_SIZE] 141 | if do_boxes_overlap(pred_bbox, gt_bbox): 142 | ele["overlap_topk"] = 1 143 | 144 | results.append(ele) 145 | 146 | return results 147 | 148 | 149 | def get_metric(list_of_examples, 150 | domains=["mobile", "desktop", "web"], 151 | data_types=["text", "icon"]): 152 | """ 153 | Computes metrics over a list of examples and prints/plots a table. 154 | 155 | Each element in list_of_examples is a dict containing: 156 | - "domain": Domain name (e.g., "web", "mobile", "desktop") 157 | - "data_type": Data type (e.g., "text", "icon") 158 | - "hit_top1", "overlap_top1", "hit_topk", "overlap_topk": binary (0 or 1) 159 | 160 | The final table has columns for each domain broken down by UI type (plus a domain-average) 161 | and overall columns ("All-text", "All-icon", "All-average"). 
162 | 163 | The rows of the table are: 164 | - hit_top1 165 | - overlap_top1 166 | - hit_topk 167 | - overlap_topk 168 | """ 169 | 170 | # List of metric keys to compute. 171 | metrics = ["hit_top1", "overlap_top1", "hit_topk", "overlap_topk"] 172 | 173 | # Helper function to compute the mean of a given key from a list of examples. 174 | def compute_mean(examples, key): 175 | if not examples: 176 | return None 177 | return sum(example.get(key, 0) for example in examples) / len(examples) 178 | 179 | # Prepare results dictionary: structure {metric: {column_name: value}}. 180 | results = {metric: {} for metric in metrics} 181 | 182 | # Compute metrics for each group broken down by UI type. 183 | for domain in domains: 184 | # Filter examples for the current group. 185 | domain_examples = [ex for ex in list_of_examples if ex.get("domain") == domain] 186 | for data_type in data_types: 187 | # Filter further for the specific UI type. 188 | domain_data_type_examples = [ex for ex in domain_examples if ex.get("data_type") == data_type] 189 | col_name = f"{domain}-{data_type}" 190 | for metric in metrics: 191 | results[metric][col_name] = compute_mean(domain_data_type_examples, metric) 192 | 193 | # Compute domain-average (all UI types for this domain). 194 | col_name_avg = f"{domain}-avg" 195 | for metric in metrics: 196 | results[metric][col_name_avg] = compute_mean(domain_examples, metric) 197 | 198 | # Compute overall metrics for each UI type across all domains. 199 | for data_type in data_types: 200 | data_type_examples = [ex for ex in list_of_examples if ex.get("data_type") == data_type] 201 | col_name = f"All-{data_type}" 202 | for metric in metrics: 203 | results[metric][col_name] = compute_mean(data_type_examples, metric) 204 | 205 | # Compute overall average across all examples. 206 | overall_key = "All-avg" 207 | for metric in metrics: 208 | results[metric][overall_key] = compute_mean(list_of_examples, metric) 209 | 210 | # Define the order of columns. 211 | columns_order = [] 212 | for domain in domains: 213 | for data_type in data_types: 214 | columns_order.append(f"{domain}-{data_type}") 215 | columns_order.append(f"{domain}-avg") 216 | for data_type in data_types: 217 | columns_order.append(f"All-{data_type}") 218 | columns_order.append("All-avg") 219 | 220 | # ------------- Print Table to Console ------------- 221 | # Prepare header row. 222 | header = [""] + columns_order 223 | # Calculate column widths for console printing. 224 | col_widths = [max(len(col), 12) for col in header] 225 | 226 | def format_cell(cell): 227 | if isinstance(cell, float): 228 | return f"{cell*100:.2f}" 229 | elif cell is None: 230 | return "N/A" 231 | return str(cell) 232 | 233 | # Print header. 234 | header_line = " | ".join(word.ljust(width) for word, width in zip(header, col_widths)) 235 | separator_line = "-+-".join("-" * width for width in col_widths) 236 | print(header_line) 237 | print(separator_line) 238 | 239 | for metric in metrics: 240 | row = [metric] 241 | for col in columns_order: 242 | val = results[metric].get(col) 243 | row.append(format_cell(val)) 244 | row_line = " | ".join(word.ljust(width) for word, width in zip(row, col_widths)) 245 | print(row_line) 246 | 247 | # ------------- Print Tab-delimited Version (for Excel Copy-Paste) ------------- 248 | metric_info = "Tab-delimited Table for Excel:\n" 249 | # Header row. 250 | header_tab = "\t".join([""] + columns_order) 251 | metric_info += (header_tab + "\n") 252 | # Each row. 
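    # For reference, with the default `domains` and `data_types` the columns come out as:
    #   mobile-text, mobile-icon, mobile-avg,
    #   desktop-text, desktop-icon, desktop-avg,
    #   web-text, web-icon, web-avg,
    #   All-text, All-icon, All-avg
    # Each data row is the metric name followed by the corresponding means, formatted by
    # `format_cell` (percentages with two decimals, or "N/A" for an empty group).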
253 | for metric in metrics: 254 | row = [metric] + [format_cell(results[metric].get(col)) for col in columns_order] 255 | metric_info += ("\t".join(row) + "\n") 256 | print(metric_info) 257 | return metric_info 258 | 259 | 260 | """ 261 | # cd to project root directory 262 | python eval/screenSpot_v2.py --save_path 263 | """ 264 | if __name__ == "__main__": 265 | parser = argparse.ArgumentParser() 266 | parser.add_argument("--model_type", type=str, default="qwen2vl", choices=["qwen2vl", "qwen25vl"]) 267 | parser.add_argument("--model_name_or_path", type=str, default="microsoft/GUI-Actor-2B-Qwen2-VL") 268 | parser.add_argument("--save_path", type=str, default="./") 269 | parser.add_argument('--topk', type=int, default=3, help='Topk') 270 | parser.add_argument('--no-placeholder', dest='use_placeholder', action='store_false', help='Disable the placeholder') 271 | parser.set_defaults(use_placeholder=True) 272 | 273 | args = parser.parse_args() 274 | 275 | save_path = args.save_path 276 | if not os.path.exists(save_path): 277 | os.makedirs(save_path, exist_ok=True) 278 | pred_path = f"{save_path}/screenspot_v2_all_preds.json" 279 | metric_path = f"{save_path}/screenspot_v2_all_metrics.txt" 280 | 281 | if os.path.exists(metric_path): 282 | exit() 283 | 284 | if os.path.exists(pred_path): 285 | print(f"Loading predictions from {pred_path}") 286 | with open(pred_path, "r") as f: 287 | results = json.load(f) 288 | else: 289 | print(f"Evaluating {args.model_name_or_path}...") 290 | results = evaluate(args.model_name_or_path, args.model_type, args.use_placeholder, args.topk) 291 | with open(pred_path, "w") as f: 292 | json.dump(results, f) 293 | print(f"Saved {len(results)} predictions to {pred_path}") 294 | 295 | if not os.path.exists(metric_path): 296 | metric_info = get_metric(results) 297 | with open(metric_path, "w") as f: 298 | f.write(metric_info) 299 | print(f"Saved metric to {metric_path}") 300 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "gui-actor" 3 | version = "0.1.0" 4 | description = "Coordinate-Free Visual Grounding for GUI Agents" 5 | authors = [ 6 | {name = "GUI-Actor team"}, 7 | ] 8 | dependencies = [ 9 | "pre-commit>=3.7.1", 10 | "pip>=24.1.1", 11 | "Pillow>=10.4.0", 12 | "liger-kernel==0.5.2", 13 | "opencv-python-headless>=4.10.0.84", 14 | "accelerate==1.1.1", 15 | "qwen-vl-utils==0.0.8", 16 | "deepspeed==0.16.0", 17 | "transformers==4.51.3", 18 | "flash-attn", 19 | "wandb==0.18.3", 20 | "datasets>=2.18.0" 21 | ] 22 | requires-python = ">=3.10,<3.13" 23 | readme = "README.md" 24 | license = {text = "MIT"} 25 | 26 | 27 | [tool.pdm] 28 | distribution = false 29 | 30 | [tool.pdm.dev-dependencies] 31 | test = [ 32 | "pytest>=8.2.2", 33 | ] 34 | [tool.ruff] 35 | target-version = 'py38' 36 | line-length = 120 # Must agree with Black 37 | 38 | [tool.ruff.lint] 39 | select = [ 40 | "B", # flake8-bugbear 41 | "C4", # flake8-comprehensions 42 | "D", # pydocstyle 43 | "E", # Error 44 | "F", # pyflakes 45 | "I", # isort 46 | "ISC", # flake8-implicit-str-concat 47 | "N", # pep8-naming 48 | "PGH", # pygrep-hooks 49 | # "PTH", # flake8-use-pathlib 50 | "Q", # flake8-quotes 51 | "SIM", # flake8-simplify 52 | "TRY", # tryceratops 53 | "UP", # pyupgrade 54 | "W", # Warning 55 | "YTT", # flake8-2020 56 | ] 57 | 58 | exclude = [ 59 | "migrations", 60 | "__pycache__", 61 | "manage.py", 62 | "settings.py", 63 | "env", 64 | ".env", 65 | 
"venv", 66 | ".venv", 67 | ] 68 | 69 | ignore = [ 70 | "B905", # zip strict=True; remove once python <3.10 support is dropped. 71 | "D100", 72 | "D101", 73 | "D102", 74 | "D103", 75 | "D104", 76 | "D105", 77 | "D106", 78 | "D107", 79 | "D200", 80 | "D401", 81 | "E402", 82 | "E501", 83 | "TRY003", # Avoid specifying messages outside exception class; overly strict, especially for ValueError 84 | "N812", 85 | ] 86 | 87 | [tool.ruff.lint.flake8-bugbear] 88 | extend-immutable-calls = [ 89 | "chr", 90 | "typer.Argument", 91 | "typer.Option", 92 | ] 93 | 94 | [tool.ruff.lint.pydocstyle] 95 | convention = "numpy" 96 | 97 | [tool.ruff.lint.per-file-ignores] 98 | "tests/*.py" = [ 99 | "D100", 100 | "D101", 101 | "D102", 102 | "D103", 103 | "D104", 104 | "D105", 105 | "D106", 106 | "D107", 107 | "S101", # use of "assert" 108 | "S102", # use of "exec" 109 | "S106", # possible hardcoded password. 110 | "PGH001", # use of "eval" 111 | ] 112 | 113 | [tool.ruff.lint.pep8-naming] 114 | staticmethod-decorators = [ 115 | "pydantic.validator", 116 | "pydantic.root_validator", 117 | ] 118 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # model_type: qwen2vl or qwen25vl 3 | model_type="qwen2vl" 4 | llm_model="./checkpoints/${model_type}_warmup" 5 | output_dir="./checkpoints/${model_type}_sft" 6 | 7 | # === Training Command === 8 | torchrun --nproc_per_node=4 train.py \ 9 | --deepspeed ./scripts/zero3.json \ 10 | --data_path data/data_config.yaml \ 11 | --image_folder "" \ 12 | --model_type ${model_type} \ 13 | --model_name_or_path ${llm_model} \ 14 | --group_by_modality_length True \ 15 | --bf16 True \ 16 | --output_dir ${output_dir} \ 17 | --num_train_epochs 1 \ 18 | --per_device_train_batch_size 1 \ 19 | --per_device_eval_batch_size 4 \ 20 | --gradient_accumulation_steps 1 \ 21 | --eval_strategy "no" \ 22 | --save_strategy "steps" \ 23 | --save_steps 2000 \ 24 | --learning_rate 1e-4 \ 25 | --weight_decay 0. 
\ 26 | --warmup_ratio 0.03 \ 27 | --lr_scheduler_type "cosine" \ 28 | --logging_steps 10 \ 29 | --tf32 True \ 30 | --model_max_length 24576 \ 31 | --gradient_checkpointing True \ 32 | --dataloader_num_workers 8 \ 33 | --max_pixels 5720064 \ 34 | --unfreeze_all_parameters True \ 35 | --unfreeze_pointer_head False \ 36 | --unfreeze_lm_head False \ 37 | --unfreeze_base_model False \ 38 | --unfreeze_last_n_layers -1 \ 39 | --unfreeze_new_tokens False \ 40 | --unfreeze_visual False \ 41 | --pointer_loss_weight 1.0 \ 42 | --lm_loss_weight 1.0 43 | -------------------------------------------------------------------------------- /scripts/warmup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # model_type: qwen2vl or qwen25vl 3 | model_type="qwen25vl" 4 | llm_model="Qwen/Qwen2.5-VL-3B-Instruct" 5 | output_dir="./checkpoints/${model_type}_warmup" 6 | 7 | # === Training Command === 8 | torchrun --nproc_per_node=4 train.py \ 9 | --deepspeed ./scripts/zero3.json \ 10 | --data_path data/data_config.yaml \ 11 | --image_folder "" \ 12 | --model_type ${model_type} \ 13 | --model_name_or_path ${llm_model} \ 14 | --group_by_modality_length True \ 15 | --bf16 True \ 16 | --output_dir ${output_dir} \ 17 | --num_train_epochs 1 \ 18 | --per_device_train_batch_size 1 \ 19 | --per_device_eval_batch_size 4 \ 20 | --gradient_accumulation_steps 1 \ 21 | --eval_strategy "no" \ 22 | --save_strategy "steps" \ 23 | --save_steps 2000 \ 24 | --learning_rate 1e-4 \ 25 | --weight_decay 0. \ 26 | --warmup_ratio 0.03 \ 27 | --lr_scheduler_type "cosine" \ 28 | --logging_steps 10 \ 29 | --tf32 True \ 30 | --model_max_length 24576 \ 31 | --gradient_checkpointing True \ 32 | --dataloader_num_workers 8 \ 33 | --max_pixels 5720064 \ 34 | --unfreeze_all_parameters False \ 35 | --unfreeze_pointer_head True \ 36 | --unfreeze_lm_head False \ 37 | --unfreeze_base_model False \ 38 | --unfreeze_last_n_layers -1 \ 39 | --unfreeze_new_tokens True \ 40 | --unfreeze_visual False \ 41 | --pointer_loss_weight 1.0 \ 42 | --lm_loss_weight -1.0 43 | -------------------------------------------------------------------------------- /scripts/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "zero_optimization": { 14 | "stage": 3, 15 | "offload_optimizer": { 16 | "device": "none", 17 | "pin_memory": false 18 | }, 19 | "offload_param": { 20 | "device": "none", 21 | "pin_memory": false 22 | }, 23 | "overlap_comm": false, 24 | "contiguous_gradients": true, 25 | "sub_group_size": 1000000000.0, 26 | "reduce_bucket_size": "auto", 27 | "stage3_prefetch_bucket_size": "auto", 28 | "stage3_param_persistence_threshold": "auto", 29 | "stage3_max_live_parameters": 1000000000.0, 30 | "stage3_max_reuse_distance": 1000000000.0, 31 | "stage3_gather_16bit_weights_on_model_save": true 32 | }, 33 | "gradient_accumulation_steps": "auto", 34 | "gradient_clipping": "auto", 35 | "steps_per_print": 100, 36 | "train_batch_size": "auto", 37 | "train_micro_batch_size_per_gpu": "auto", 38 | "wall_clock_breakdown": false 39 | } 40 | -------------------------------------------------------------------------------- /src/gui_actor/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/microsoft/GUI-Actor/30e148da6d719117444f9d05a944561f31e362f1/src/gui_actor/__init__.py -------------------------------------------------------------------------------- /src/gui_actor/constants.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | CONTROLLER_HEART_BEAT_EXPIRATION = 30 4 | WORKER_HEART_BEAT_INTERVAL = 15 5 | 6 | LOGDIR = "." 7 | 8 | # Model Constants 9 | IGNORE_INDEX = -100 10 | DEFAULT_IMAGE_TOKEN = "" 11 | DEFAULT_POINTER_START_TOKEN = "<|pointer_start|>" 12 | DEFAULT_POINTER_END_TOKEN = "<|pointer_end|>" 13 | DEFAULT_POINTER_PAD_TOKEN = "<|pointer_pad|>" 14 | 15 | # UNMASK_TOKEN_IDS = [198, 151644, 151645] 16 | 17 | # System Message 18 | grounding_system_message = "You are a GUI agent. Given a screenshot of the current GUI and a human instruction, your task is to locate the screen element that corresponds to the instruction. You should output a PyAutoGUI action that performs a click on the correct position. To indicate the click location, we will use some special tokens, which is used to refer to a visual patch later. For example, you can output: pyautogui.click()." 19 | 20 | # Chat Template 21 | chat_template = "{% set image_count = namespace(value=0) %}{% set video_count = namespace(value=0) %}{% for message in messages %}<|im_start|>{{ message['role'] }}\n{% if message['content'] is string %}{{ message['content'] }}<|im_end|>\n{% else %}{% for content in message['content'] %}{% if content['type'] == 'image' or 'image' in content or 'image_url' in content %}{% set image_count.value = image_count.value + 1 %}{% if add_vision_id %}Picture {{ image_count.value }}: {% endif %}<|vision_start|><|image_pad|><|vision_end|>{% elif content['type'] == 'video' or 'video' in content %}{% set video_count.value = video_count.value + 1 %}{% if add_vision_id %}Video {{ video_count.value }}: {% endif %}<|vision_start|><|video_pad|><|vision_end|>{% elif 'text' in content %}{{ content['text'] }}{% endif %}{% endfor %}<|im_end|>\n{% endif %}{% endfor %}{% if add_generation_prompt %}<|im_start|>assistant\n{% endif %}" 22 | 23 | assistant_template = "{% for message in messages %}{{'<|im_start|>' + message['role']}}{% if 'recipient' in message %}<|recipient|>{{ message['recipient'] }}{% endif %}{{'\n' + message['content'][0]['text']}}{% if 'end_turn' in message and message['end_turn'] %}{{'<|diff_marker|>\n'}}{% else %}{{'<|im_end|>\n'}}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant<|recipient|>' }}{% endif %}" 24 | 25 | # Special Tokens 26 | ADDITIONAL_SPECIAL_TOKENS = [ 27 | "<|recipient|>", 28 | "<|diff_marker|>", 29 | DEFAULT_POINTER_START_TOKEN, 30 | DEFAULT_POINTER_END_TOKEN, 31 | DEFAULT_POINTER_PAD_TOKEN, 32 | ] 33 | 34 | # Action Patterns to be replaced with special tokens 35 | ACTION_PATTENS_XY = [ 36 | r"x=([0-9.]+), y=([0-9.]+)", 37 | r"from_coord=\[([0-9.]+), ([0-9.]+)\], to_coord=\[([0-9.]+), ([0-9.]+)\]", 38 | ] 39 | 40 | until = ["<|diff_marker|>"] 41 | -------------------------------------------------------------------------------- /src/gui_actor/inference.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import json 3 | import re 4 | import os 5 | from qwen_vl_utils import process_vision_info 6 | from transformers import ( 7 | Qwen2VLForConditionalGeneration, 8 | LogitsProcessor, 9 | LogitsProcessorList, 10 | AutoModelForCausalLM, 11 | AutoTokenizer 12 | ) 13 | from gui_actor.constants import ( 14 | 
DEFAULT_POINTER_END_TOKEN, 15 | DEFAULT_POINTER_PAD_TOKEN, 16 | chat_template 17 | ) 18 | 19 | class ForceFollowTokensLogitsProcessor(LogitsProcessor): 20 | """ 21 | Forces tokens B (pointer_pad_token) and C (pointer_end_token) to follow token A (pointer_start_token). 22 | Whenever token_a_id is generated, enqueue the forced_sequence (e.g. [B, C]). 23 | As long as forced tokens remain in the queue, force them in the output. 24 | """ 25 | def __init__(self, token_a_id, forced_sequence=[DEFAULT_POINTER_PAD_TOKEN, DEFAULT_POINTER_END_TOKEN]): 26 | super().__init__() 27 | self.token_a_id = token_a_id 28 | self.forced_sequence = forced_sequence # list of token IDs, e.g. [B_id, C_id] 29 | self.force_queue = [] # holds the tokens we still need to force 30 | 31 | def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: 32 | """ 33 | Called at each decoding step to modify `scores`. 34 | 35 | Args: 36 | input_ids: shape (batch_size, seq_len). The already-decoded tokens. 37 | scores: shape (batch_size, vocab_size). Model logits for the next token. 38 | """ 39 | batch_size = input_ids.shape[0] 40 | if batch_size > 1: 41 | raise NotImplementedError("Batch size must be 1 for this logits processor.") 42 | 43 | # We assume batch_size=1 for simplicity; if you have multiple sequences, 44 | # you'll need to adapt the logic to handle each item in the batch. 45 | last_token_id = input_ids[0, -1].item() 46 | 47 | # If the last token was A, enqueue B and C 48 | if last_token_id == self.token_a_id: 49 | self.force_queue.extend(self.forced_sequence) 50 | 51 | # If we have forced tokens waiting in the queue, override the distribution 52 | if len(self.force_queue) > 0: 53 | forced_token = self.force_queue.pop(0) # next token to force 54 | # Create a mask of -inf for all tokens except the forced one 55 | new_scores = torch.full_like(scores, float('-inf')) 56 | new_scores[0, forced_token] = 0.0 # log prob = 0 => prob = 1 57 | return new_scores 58 | 59 | # Otherwise, return scores unmodified 60 | return scores 61 | 62 | 63 | def get_prediction_region_point(attn_scores, n_width, n_height, top_n=30, activation_threshold=0.3, return_all_regions=True, rect_center=False): 64 | """ 65 | 1. Select activated patches 66 | 2. Divide connected patches into different regions 67 | 3. Calculate the average activation value for each region 68 | 4. Select the region with the highest average activation value 69 | 5. 
Return the center point of that region as the final prediction point 70 | """ 71 | 72 | # Get patches with activation values greater than a certain proportion of the maximum activation value as activated patches 73 | # Get the highest activation value and threshold 74 | max_score = attn_scores[0].max().item() 75 | threshold = max_score * activation_threshold 76 | # Select all patches above the threshold 77 | mask = attn_scores[0] > threshold 78 | valid_indices = torch.nonzero(mask).squeeze(-1) 79 | topk_values = attn_scores[0][valid_indices] 80 | topk_indices = valid_indices 81 | 82 | # Convert indices to 2D coordinates 83 | topk_coords = [] 84 | for idx in topk_indices.tolist(): 85 | y = idx // n_width 86 | x = idx % n_width 87 | topk_coords.append((y, x, idx)) 88 | 89 | # Divide into connected regions 90 | regions = [] 91 | visited = set() 92 | for i, (y, x, idx) in enumerate(topk_coords): 93 | if idx in visited: 94 | continue 95 | 96 | # Start a new region 97 | region = [(y, x, idx, topk_values[i].item())] 98 | visited.add(idx) 99 | queue = [(y, x, idx, topk_values[i].item())] 100 | 101 | # BFS to find connected points 102 | while queue: 103 | cy, cx, c_idx, c_val = queue.pop(0) 104 | 105 | # Check 4 adjacent directions 106 | for dy, dx in [(-1, 0), (1, 0), (0, -1), (0, 1)]: 107 | ny, nx = cy + dy, cx + dx 108 | n_idx = ny * n_width + nx 109 | 110 | # Check if this adjacent point is in the topk list 111 | for j, (ty, tx, t_idx) in enumerate(topk_coords): 112 | if ty == ny and tx == nx and t_idx not in visited: 113 | visited.add(t_idx) 114 | region.append((ny, nx, t_idx, topk_values[j].item())) 115 | queue.append((ny, nx, t_idx, topk_values[j].item())) 116 | 117 | regions.append(region) 118 | 119 | # Calculate the average activation value for each region 120 | region_scores = [] 121 | region_centers = [] 122 | region_points = [] 123 | 124 | for region in regions: 125 | # Calculate average score for the region 126 | avg_score = sum(item[3] for item in region) / len(region) 127 | region_scores.append(avg_score) 128 | 129 | # Calculate normalized center coordinates for each patch, then take the average 130 | normalized_centers = [] 131 | weights = [] 132 | y_coords = set() 133 | x_coords = set() 134 | 135 | for y, x, _, score in region: 136 | # Normalized coordinates of the center point for each patch 137 | center_y = (y + 0.5) / n_height 138 | center_x = (x + 0.5) / n_width 139 | normalized_centers.append((center_x, center_y)) 140 | weights.append(score) 141 | 142 | y_coords.add(center_y) 143 | x_coords.add(center_x) 144 | 145 | region_points.append(normalized_centers) 146 | 147 | # Calculate the average of normalized coordinates as the region center 148 | if not rect_center: 149 | # Weighted average 150 | total_weight = sum(weights) 151 | weighted_x = sum(nc[0] * w for nc, w in zip(normalized_centers, weights)) / total_weight 152 | weighted_y = sum(nc[1] * w for nc, w in zip(normalized_centers, weights)) / total_weight 153 | avg_center_x, avg_center_y = weighted_x, weighted_y 154 | # # Simple average 155 | # avg_center_x = sum(nc[0] for nc in normalized_centers) / len(normalized_centers) 156 | # avg_center_y = sum(nc[1] for nc in normalized_centers) / len(normalized_centers) 157 | else: 158 | avg_center_x = sum(x_coords) / len(x_coords) 159 | avg_center_y = sum(y_coords) / len(y_coords) 160 | region_centers.append((avg_center_x, avg_center_y)) 161 | 162 | # Select the region with the highest average activation value 163 | sorted_indices = sorted(range(len(region_scores)), key=lambda i: 
region_scores[i], reverse=True) 164 | sorted_scores = [region_scores[i] for i in sorted_indices] 165 | sorted_centers = [region_centers[i] for i in sorted_indices] 166 | sorted_points = [region_points[i] for i in sorted_indices] 167 | best_point = sorted_centers[0] 168 | 169 | if return_all_regions: 170 | # Outputs: 171 | # 1. best_point: the center point of the region with the highest average activation value 172 | # 2. sorted_centers: the center points of all regions, sorted by the average activation value in descending order 173 | # 3. sorted_scores: the average activation values of all regions, sorted in descending order 174 | # 4. sorted_points: the normalized center coordinates of all patches, sorted by the average activation value in descending order 175 | return best_point, sorted_centers, sorted_scores, sorted_points 176 | else: 177 | return best_point 178 | 179 | 180 | def inference(conversation, model, tokenizer, data_processor, logits_processor=None, use_placeholder=False, topk=5): 181 | """ 182 | conversation = [ 183 | { 184 | "role": "system", 185 | "content": [ 186 | { 187 | "type": "text", 188 | "text": grounding_system_message, 189 | } 190 | ] 191 | }, 192 | { 193 | "role": "user", 194 | "content": [ 195 | { 196 | "type": "image", 197 | "image": example["image"], # PIL.Image.Image or str to path 198 | # "image_url": "https://xxxxx.png" or "https://xxxxx.jpg" or "file://xxxxx.png" or "data:image/png;base64,xxxxxxxx", will be split by "base64," 199 | }, 200 | { 201 | "type": "text", 202 | "text": example["instruction"] 203 | }, 204 | ], 205 | }, 206 | ] 207 | """ 208 | if logits_processor is None: 209 | logits_processor = ForceFollowTokensLogitsProcessor( 210 | token_a_id=tokenizer.encode(DEFAULT_POINTER_PAD_TOKEN)[0], 211 | forced_sequence=[ 212 | tokenizer.encode(DEFAULT_POINTER_END_TOKEN)[0] 213 | ] 214 | ) 215 | 216 | assiatant_starter = "" if not use_placeholder else "<|im_start|>assistant<|recipient|>os\npyautogui.click(<|pointer_start|><|pointer_pad|><|pointer_end|>)" 217 | 218 | pred = { 219 | "output_text": None, # generated text 220 | "n_width": None, # number of patch_tokens in width dimension 221 | "n_height": None, # number of patch_tokens in height dimension 222 | "attn_scores": None, # attention scores over the image patches 223 | "topk_points": None, # topk points 224 | "topk_values": None, # topk values 225 | "topk_points_all": None, # all points 226 | } 227 | 228 | # prepare text 229 | text = data_processor.apply_chat_template(conversation, 230 | tokenize=False, 231 | add_generation_prompt=False, 232 | chat_template=chat_template 233 | ) 234 | text += assiatant_starter 235 | 236 | # prepare inputs 237 | image_inputs, video_inputs = process_vision_info(conversation) 238 | inputs = data_processor(text=[text], 239 | images=image_inputs, 240 | videos=video_inputs, 241 | padding=True, 242 | return_tensors="pt" 243 | ) 244 | inputs = inputs.to(model.device) 245 | 246 | # generate 247 | results = model.generate(**inputs, 248 | max_new_tokens=2048 if not use_placeholder else 1, 249 | logits_processor=LogitsProcessorList([logits_processor]), 250 | return_dict_in_generate=True, 251 | output_hidden_states=True 252 | ) 253 | 254 | 255 | # decode the generated ids 256 | input_ids = inputs["input_ids"][0] 257 | generated_ids = results.sequences[0][len(input_ids):] 258 | output_text = tokenizer.decode(generated_ids, skip_special_tokens=False, clean_up_tokenization_spaces=False) 259 | pred["output_text"] = output_text 260 | 261 | # check if there are is inside the 
input_ids or generated_ids 262 | if use_placeholder: 263 | pointer_pad_mask = (inputs["input_ids"][0] == model.config.pointer_pad_token_id) # n_all_input_tokens 264 | else: 265 | pointer_pad_mask = (generated_ids[:-1] == model.config.pointer_pad_token_id) # seq_len_generated_ids-1 266 | 267 | # if there are no in the input_ids or generated_ids, return the pred 268 | if len(pointer_pad_mask) == 0: 269 | return pred 270 | 271 | # otherwise, get the coordinate from the action head 272 | if use_placeholder: 273 | decoder_hidden_states = results.hidden_states[0][-1][0] # n_all_input_tokens, hidden_size 274 | else: 275 | decoder_hidden_states = [step_hidden_states[-1][0] for step_hidden_states in results.hidden_states[1:]] 276 | decoder_hidden_states = torch.cat(decoder_hidden_states, dim=0) # seq_len_generated_ids-1, hidden_size 277 | decoder_hidden_states = decoder_hidden_states[pointer_pad_mask] # n_pointer_pad_tokens, hidden_size 278 | 279 | # get the image embeddings as encoder vectors 280 | image_embeds = model.visual(inputs["pixel_values"], grid_thw=inputs["image_grid_thw"]) # n_image_tokens, hidden_size 281 | 282 | attn_scores, _ = model.multi_patch_pointer_head(image_embeds, decoder_hidden_states) 283 | pred["attn_scores"] = attn_scores.tolist() 284 | 285 | _, n_height, n_width = (inputs["image_grid_thw"][0] // model.visual.spatial_merge_size).tolist() 286 | pred["n_width"] = n_width 287 | pred["n_height"] = n_height 288 | 289 | # get the topk points according to the attention scores 290 | best_point, region_points, region_scores, region_points_all = get_prediction_region_point(attn_scores, n_width, n_height, return_all_regions=True, rect_center=False) 291 | topk_points = region_points[:topk] if len(region_points) > topk else region_points 292 | topk_values = region_scores[:topk] if len(region_scores) > topk else region_scores 293 | topk_points_all = region_points_all[:topk] if len(region_points_all) > topk else region_points_all 294 | pred["topk_points"] = topk_points 295 | pred["topk_values"] = topk_values 296 | pred["topk_points_all"] = topk_points_all 297 | 298 | return pred -------------------------------------------------------------------------------- /src/gui_actor/modeling.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from transformers.models.qwen2_vl.modeling_qwen2_vl import Qwen2VLCausalLMOutputWithPast, Qwen2VLForConditionalGeneration 6 | from gui_actor.constants import IGNORE_INDEX 7 | from typing import List, Tuple, Union, Optional 8 | from gui_actor.trainer import rank0_print 9 | 10 | class QwenVLwithVisionHeadOutputWithPast(Qwen2VLCausalLMOutputWithPast): 11 | """ 12 | Output class for Qwen2VL with pointer head, extending the base output class. 13 | 14 | Args: 15 | lm_loss (`torch.FloatTensor` of shape `(1,)`, *optional*): 16 | Language modeling loss. 17 | pointer_loss (`torch.FloatTensor` of shape `(1,)`, *optional*): 18 | Vision pointer network loss. 19 | pointer_scores (`List[torch.FloatTensor]`, *optional*): 20 | Attention scores from the pointer network, one tensor per batch item. 21 | loss (`torch.FloatTensor` of shape `(1,)`, *optional*): 22 | Combined loss (weighted sum of lm_loss and pointer_loss). 23 | logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): 24 | Prediction scores from the language modeling head. 25 | past_key_values, hidden_states, attentions, rope_deltas: 26 | Same as parent class. 
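
    Example (illustrative sketch only; `batch`, `gt_indices` and `patch_masks` stand in
    for the tensors produced by the training collator, and shapes depend on the inputs):

        >>> out = model(**batch,
        ...             visual_token_indices_of_coordinates=gt_indices,
        ...             multi_patch_labels=patch_masks,
        ...             return_dict=True)
        >>> out.loss            # lm_loss_weight * lm_loss + pointer_loss_weight * pointer_loss
        >>> out.pointer_scores  # list with one (n_target, n_visual) tensor per sample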
27 | """ 28 | def __init__(self, lm_loss=None, pointer_loss=None, pointer_scores=None, *args, **kwargs): 29 | super().__init__(*args, **kwargs) 30 | self.lm_loss = lm_loss 31 | self.pointer_loss = pointer_loss 32 | self.pointer_scores = pointer_scores 33 | 34 | 35 | class VisionHead_MultiPatch(nn.Module): 36 | def __init__(self, d_model, projection_dim, num_attention_heads=8, dropout_rate=0.1): 37 | super().__init__() 38 | self.d_model = d_model 39 | 40 | # Note: We omit additional normalization here because Qwen2VL 41 | # already normalizes hidden states using RMSNorm. 42 | self.projection_enc = nn.Sequential( 43 | nn.Linear(d_model, projection_dim), 44 | nn.GELU(), 45 | nn.Linear(projection_dim, d_model) 46 | ) 47 | self.projection_dec = nn.Sequential( 48 | nn.Linear(d_model, projection_dim), 49 | nn.GELU(), 50 | nn.Linear(projection_dim, d_model) 51 | ) 52 | 53 | # Add self-attention layer for visual features 54 | self.self_attention = nn.MultiheadAttention( 55 | embed_dim=d_model, 56 | num_heads=num_attention_heads, 57 | dropout=dropout_rate, 58 | batch_first=True 59 | ) 60 | 61 | # Layer normalization and residual connection 62 | self.layer_norm = nn.LayerNorm(d_model) 63 | self.dropout = nn.Dropout(dropout_rate) 64 | 65 | def forward(self, 66 | hidden_state_enc, # shape: [n_enc, d_model] where n_enc can vary with image size 67 | hidden_state_dec, # shape: [n_dec, d_model] there can be multiple query in one sample 68 | labels: Optional[torch.Tensor] = None, # shape: [n_dec, n_enc], binary mask of patches in bbox 69 | do_single_patch: bool = False, 70 | ): 71 | 72 | enc_input = hidden_state_enc.unsqueeze(0) 73 | attn_output, _ = self.self_attention( 74 | query=enc_input, 75 | key=enc_input, 76 | value=enc_input, 77 | # attn_mask=attention_mask, 78 | need_weights=False 79 | ) 80 | # Residual connection and layer normalization 81 | hidden_state_enc_ctx = self.layer_norm(enc_input + self.dropout(attn_output)) 82 | # Remove batch dimension 83 | hidden_state_enc_ctx = hidden_state_enc_ctx.squeeze(0) # [n_enc, d_model] 84 | 85 | # Apply the projection networks. 86 | proj_enc = self.projection_enc(hidden_state_enc_ctx) # [n_enc, d_model] 87 | proj_dec = self.projection_dec(hidden_state_dec) # [n_dec, d_model] 88 | 89 | # Compute scaled dot-product attention scores. 90 | # Scaling by sqrt(d_model) is critical regardless of variable n_enc. 91 | scaling = self.d_model ** 0.5 92 | patch_logits = torch.matmul(proj_dec, proj_enc.transpose(0, 1)) / scaling # [n_dec, n_enc] 93 | 94 | # Softmax normalization is applied along the encoder dimension. 
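        # Illustrative shapes only (assuming Qwen2-VL-2B's hidden size of 1536, a screenshot
        # yielding 1024 merged visual tokens and a single pointer query):
        #   proj_dec [1, 1536] @ proj_enc.T [1536, 1024] / sqrt(1536) -> patch_logits [1, 1024]
        # The softmax below then turns each row into a probability distribution over patches.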
95 | attn_weights = F.softmax(patch_logits, dim=-1) 96 | 97 | loss = None 98 | if (labels is not None) and (not do_single_patch): 99 | epsilon = 1e-8 100 | labels_float = labels.float() 101 | # Normalize each row to get target probability distribution 102 | target_dist = labels_float / (labels_float.sum(dim=-1, keepdim=True) + epsilon) 103 | 104 | # Apply log_softmax to logits 105 | pred_log_probs = F.log_softmax(patch_logits, dim=-1) 106 | # Use KL divergence as loss 107 | loss = F.kl_div(pred_log_probs, target_dist, reduction='batchmean') 108 | 109 | if do_single_patch and (labels is not None): 110 | loss = F.cross_entropy(attn_scores, labels) 111 | 112 | return attn_weights, loss 113 | 114 | 115 | class Qwen2VLForConditionalGenerationWithPointer(Qwen2VLForConditionalGeneration): 116 | def __init__(self, *args, **kwargs): 117 | super().__init__(*args, **kwargs) 118 | self.multi_patch_pointer_head = VisionHead_MultiPatch(self.config.hidden_size, self.config.hidden_size) 119 | self.pointer_loss_weight = kwargs.get("pointer_loss_weight", 1.0) 120 | self.lm_loss_weight = kwargs.get("lm_loss_weight", 1.0) 121 | self.post_init() 122 | 123 | def reset_loss_weights(self, pointer_loss_weight, lm_loss_weight): 124 | self.pointer_loss_weight = pointer_loss_weight 125 | self.lm_loss_weight = lm_loss_weight 126 | 127 | def forward(self, 128 | input_ids: torch.LongTensor = None, # (batch_size, seq_len) 129 | attention_mask: Optional[torch.Tensor] = None, 130 | position_ids: Optional[torch.LongTensor] = None, 131 | past_key_values: Optional[List[torch.FloatTensor]] = None, 132 | inputs_embeds: Optional[torch.FloatTensor] = None, 133 | labels: Optional[torch.LongTensor] = None, 134 | use_cache: Optional[bool] = None, 135 | output_attentions: Optional[bool] = None, 136 | output_hidden_states: Optional[bool] = None, 137 | return_dict: Optional[bool] = None, 138 | pixel_values: Optional[torch.Tensor] = None, 139 | pixel_values_videos: Optional[torch.FloatTensor] = None, 140 | image_grid_thw: Optional[torch.LongTensor] = None, 141 | video_grid_thw: Optional[torch.LongTensor] = None, 142 | rope_deltas: Optional[torch.LongTensor] = None, 143 | cache_position: Optional[torch.LongTensor] = None, 144 | # Grounding 145 | visual_token_indices_of_coordinates: Optional[torch.Tensor] = None, # shape: (batch_size, n_target); each element is the ground-truth index of the visual token that should be attended to for the corresponding target token 146 | multi_patch_labels: Optional[torch.Tensor] = None, # shape: list [(n_target, n_visual), ...]; binary mask of patches in bbox 147 | if_multi_patch: bool = True, 148 | coordinates: Optional[List[Tuple[float, float]]] = None, 149 | verbose: bool = False) -> Union[Tuple, QwenVLwithVisionHeadOutputWithPast]: 150 | 151 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 152 | output_hidden_states = ( 153 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 154 | ) 155 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 156 | 157 | if verbose: 158 | rank0_print(f"input_ids: {input_ids.shape}, {input_ids[0][:5]}...") 159 | rank0_print(f"labels: {labels.shape}, {labels[0][:5]}...") 160 | rank0_print(f"pixel_values: {pixel_values.shape}") 161 | rank0_print(f"image_grid_thw: {image_grid_thw.shape}, {image_grid_thw}") 162 | rank0_print(f"coordinates: {coordinates}") 163 | rank0_print(f"visual_token_indices_of_coordinates: 
{visual_token_indices_of_coordinates}") 164 | rank0_print(f"return_dict: {return_dict}") 165 | 166 | if inputs_embeds is None: 167 | inputs_embeds = self.model.embed_tokens(input_ids) # shape: (batch_size, seq_len, d_model) 168 | if pixel_values is not None: 169 | pixel_values = pixel_values.type(self.visual.dtype) 170 | image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) 171 | n_image_tokens = (input_ids == self.config.image_token_id).sum().item() 172 | n_image_features = image_embeds.shape[0] 173 | if n_image_tokens != n_image_features: 174 | raise ValueError( 175 | f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" 176 | ) 177 | image_mask = ( 178 | (input_ids == self.config.image_token_id) 179 | .unsqueeze(-1) 180 | .expand_as(inputs_embeds) 181 | .to(inputs_embeds.device) 182 | ) 183 | image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) 184 | inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) 185 | 186 | if pixel_values_videos is not None: 187 | pixel_values_videos = pixel_values_videos.type(self.visual.dtype) 188 | video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) 189 | n_video_tokens = (input_ids == self.config.video_token_id).sum().item() 190 | n_video_features = video_embeds.shape[0] 191 | if n_video_tokens != n_video_features: 192 | raise ValueError( 193 | f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" 194 | ) 195 | video_mask = ( 196 | (input_ids == self.config.video_token_id) 197 | .unsqueeze(-1) 198 | .expand_as(inputs_embeds) 199 | .to(inputs_embeds.device) 200 | ) 201 | video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) 202 | inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) 203 | 204 | if attention_mask is not None: 205 | attention_mask = attention_mask.to(inputs_embeds.device) 206 | 207 | # if we get 4D attention mask we cannot calculate rope deltas anymore. 
TODO @raushan fixme 208 | if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): 209 | # calculate RoPE index once per generation in the pre-fill stage only 210 | if ( 211 | (cache_position is not None and cache_position[0] == 0) 212 | or self.rope_deltas is None 213 | or (past_key_values is None or past_key_values.get_seq_length() == 0) 214 | ): 215 | position_ids, rope_deltas = self.get_rope_index( 216 | input_ids, image_grid_thw, video_grid_thw, attention_mask 217 | ) 218 | self.rope_deltas = rope_deltas 219 | # then use the prev pre-calculated rope-deltas to get the correct position ids 220 | else: 221 | batch_size, seq_length, _ = inputs_embeds.shape 222 | delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0 223 | position_ids = torch.arange(seq_length, device=inputs_embeds.device) 224 | position_ids = position_ids.view(1, -1).expand(batch_size, -1) 225 | if cache_position is not None: # otherwise `deltas` is an int `0` 226 | delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) 227 | delta = delta.to(position_ids.device) 228 | position_ids = position_ids.add(delta) 229 | position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) 230 | 231 | outputs = self.model( 232 | input_ids=None, 233 | position_ids=position_ids, 234 | attention_mask=attention_mask, 235 | past_key_values=past_key_values, 236 | inputs_embeds=inputs_embeds, 237 | use_cache=use_cache, 238 | output_attentions=output_attentions, 239 | output_hidden_states=output_hidden_states, 240 | return_dict=return_dict, 241 | cache_position=cache_position, 242 | ) 243 | 244 | hidden_states = outputs[0] # shape: (batch_size, seq_len, d_model) 245 | logits = self.lm_head(hidden_states) 246 | 247 | lm_loss = None 248 | if labels is not None and self.lm_loss_weight > 0: 249 | # Upcast to float if we need to compute the loss to avoid potential precision issues 250 | logits = logits.float() 251 | # Shift so that tokens < n predict n 252 | shift_logits = logits[..., :-1, :].contiguous() 253 | shift_labels = labels[..., 1:].contiguous() 254 | # Flatten the tokens 255 | loss_fct = nn.CrossEntropyLoss() 256 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 257 | shift_labels = shift_labels.view(-1) 258 | # Enable model parallelism 259 | shift_labels = shift_labels.to(shift_logits.device) 260 | lm_loss = loss_fct(shift_logits, shift_labels) 261 | 262 | 263 | # If vision supervision is requested, process the action head. 264 | pointer_loss = None 265 | pointer_scores = [] 266 | if visual_token_indices_of_coordinates is not None: 267 | batch_size = input_ids.shape[0] 268 | pointer_losses = [] 269 | 270 | # Process each sample individually because the number of visual and target tokens may vary. 271 | for i in range(batch_size): 272 | dummy_target = False 273 | 274 | # Get the token ids and corresponding hidden states for sample i. 275 | token_ids = input_ids[i] # shape: (seq_length,) 276 | hs = hidden_states[i] # shape: (seq_length, d_model) 277 | 278 | # Identify visual tokens indices. 279 | visual_mask = (token_ids == self.config.image_token_id) 280 | visual_indices = torch.nonzero(visual_mask, as_tuple=False).squeeze(-1) # shape: (n_visual,) 281 | 282 | # Identify target tokens (the ones that should attend to visual features). 283 | target_mask = (token_ids == self.config.pointer_pad_token_id) 284 | target_indices = torch.nonzero(target_mask, as_tuple=False).squeeze(-1) 285 | 286 | # If either visual or target tokens are missing, skip this sample. 
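                # More precisely, as implemented below: a sample with no visual tokens is an
                # error, while a sample with no <|pointer_pad|> targets gets a dummy target
                # (the last token) and a dummy label mask; its pointer loss is still computed
                # but scaled by 0.0, presumably so every rank builds the same graph without
                # the dummy sample contributing any gradient.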
287 | if visual_indices.numel() == 0: 288 | raise ValueError(f"No visual or target tokens found for sample {i}.") 289 | if target_indices.numel() == 0: 290 | target_indices = torch.tensor([hs.shape[0] - 1]) # take the last token as the dummy target token 291 | gt = torch.tensor([0]).to(hs.device) # take the first visual token as the dummy ground truth 292 | if if_multi_patch: # task the first 4 visual tokens as the ground truth 293 | sample_labels = torch.zeros_like(visual_indices).unsqueeze(0) 294 | sample_labels[0][:4] = 1 295 | dummy_target = True 296 | else: 297 | # For supervision, we assume that visual_token_indices_of_coordinates[i] is a tensor of shape (n_target,) 298 | # where each element is an integer in the range [0, n_visual-1] indicating the ground-truth visual token. 299 | gt = visual_token_indices_of_coordinates[i].to(hs.device) # shape: (n_target,) 300 | if if_multi_patch: 301 | sample_labels = multi_patch_labels[i] 302 | 303 | # Gather the corresponding hidden state representations. 304 | # visual_hidden = hs[visual_indices] # shape: (n_visual, d_model) 305 | visual_embeds = inputs_embeds[i][visual_indices] 306 | target_hidden = hs[target_indices] # shape: (n_target, d_model) 307 | 308 | # Calculate loss for multi-patch mode 309 | if if_multi_patch: 310 | # Ensure the number of targets matches between sample and labels 311 | if sample_labels.shape[0] != target_indices.shape[0]: 312 | raise ValueError(f"Sample {i} has mismatched target counts: {sample_labels.shape[0]} labels but found {target_indices.shape[0]} target tokens") 313 | 314 | # Process using VisionHead_MultiPatch 315 | attn_scores, loss_v = self.multi_patch_pointer_head( 316 | visual_embeds, 317 | target_hidden, 318 | labels=sample_labels 319 | ) 320 | 321 | else: 322 | # Deprecated branch - single patch mode is no longer used 323 | # Run the action head to compute the attention (from target tokens to visual tokens) and its loss. 324 | attn_scores, loss_v = self.pointer_head(visual_embeds, target_hidden, labels=gt) 325 | 326 | pointer_scores.append(attn_scores.detach().cpu()) 327 | 328 | pointer_losses.append(loss_v * 0.0 if dummy_target else loss_v) 329 | 330 | pointer_loss = torch.stack(pointer_losses).mean() 331 | 332 | # Combine the LM loss and vision loss using the provided loss weights. 333 | 334 | if lm_loss is None: 335 | total_loss = pointer_loss 336 | elif pointer_loss is None: 337 | total_loss = lm_loss 338 | else: 339 | total_loss = self.lm_loss_weight * lm_loss + self.pointer_loss_weight * pointer_loss 340 | 341 | if return_dict: 342 | return QwenVLwithVisionHeadOutputWithPast( 343 | lm_loss=lm_loss, 344 | pointer_loss=pointer_loss, 345 | pointer_scores=pointer_scores, 346 | loss=total_loss, 347 | logits=logits, 348 | past_key_values=outputs.past_key_values, 349 | hidden_states=outputs.hidden_states, 350 | attentions=outputs.attentions, 351 | rope_deltas=self.rope_deltas, 352 | ) 353 | else: 354 | # When labels are provided, parent's forward returns a tuple with loss as the first element. 355 | if labels is not None: 356 | # Replace the LM loss with the combined loss. 
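            # Resulting tuple layout when total_loss is not None:
            #   (total_loss, lm_loss, pointer_loss, logits, pointer_scores, *outputs[1:])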
357 | output = (lm_loss, pointer_loss, logits, pointer_scores,) + outputs[1:] 358 | print(f"returning: total_loss, logits, pointer_scores, ...") 359 | return (total_loss,) + output if total_loss is not None else output 360 | else: 361 | return outputs -------------------------------------------------------------------------------- /src/gui_actor/modeling_qwen25vl.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLCausalLMOutputWithPast, Qwen2_5_VLForConditionalGeneration 6 | from gui_actor.constants import IGNORE_INDEX 7 | from typing import List, Tuple, Union, Optional 8 | from gui_actor.trainer import rank0_print 9 | 10 | class QwenVLwithVisionHeadOutputWithPast(Qwen2_5_VLCausalLMOutputWithPast): 11 | """ 12 | Output class for Qwen2_5_VL with pointer head, extending the base output class. 13 | 14 | Args: 15 | lm_loss (`torch.FloatTensor` of shape `(1,)`, *optional*): 16 | Language modeling loss. 17 | pointer_loss (`torch.FloatTensor` of shape `(1,)`, *optional*): 18 | Vision pointer network loss. 19 | pointer_scores (`List[torch.FloatTensor]`, *optional*): 20 | Attention scores from the pointer network, one tensor per batch item. 21 | loss (`torch.FloatTensor` of shape `(1,)`, *optional*): 22 | Combined loss (weighted sum of lm_loss and pointer_loss). 23 | logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): 24 | Prediction scores from the language modeling head. 25 | past_key_values, hidden_states, attentions, rope_deltas: 26 | Same as parent class. 27 | """ 28 | def __init__(self, lm_loss=None, pointer_loss=None, pointer_scores=None, *args, **kwargs): 29 | super().__init__(*args, **kwargs) 30 | self.lm_loss = lm_loss 31 | self.pointer_loss = pointer_loss 32 | self.pointer_scores = pointer_scores 33 | 34 | 35 | class VisionHead_MultiPatch(nn.Module): 36 | def __init__(self, d_model, projection_dim, num_attention_heads=8, dropout_rate=0.1): 37 | super().__init__() 38 | self.d_model = d_model 39 | 40 | # Note: We omit additional normalization here because Qwen2VL 41 | # already normalizes hidden states using RMSNorm. 
42 | self.projection_enc = nn.Sequential( 43 | nn.Linear(d_model, projection_dim), 44 | nn.GELU(), 45 | nn.Linear(projection_dim, d_model) 46 | ) 47 | self.projection_dec = nn.Sequential( 48 | nn.Linear(d_model, projection_dim), 49 | nn.GELU(), 50 | nn.Linear(projection_dim, d_model) 51 | ) 52 | 53 | # Add self-attention layer for visual features 54 | self.self_attention = nn.MultiheadAttention( 55 | embed_dim=d_model, 56 | num_heads=num_attention_heads, 57 | dropout=dropout_rate, 58 | batch_first=True 59 | ) 60 | 61 | # Layer normalization and residual connection 62 | self.layer_norm = nn.LayerNorm(d_model) 63 | self.dropout = nn.Dropout(dropout_rate) 64 | 65 | def forward(self, 66 | hidden_state_enc, # shape: [n_enc, d_model] where n_enc can vary with image size 67 | hidden_state_dec, # shape: [n_dec, d_model] there can be multiple query in one sample 68 | labels: Optional[torch.Tensor] = None, # shape: [n_dec, n_enc], binary mask of patches in bbox 69 | do_single_patch: bool = False, 70 | ): 71 | 72 | enc_input = hidden_state_enc.unsqueeze(0) 73 | attn_output, _ = self.self_attention( 74 | query=enc_input, 75 | key=enc_input, 76 | value=enc_input, 77 | # attn_mask=attention_mask, 78 | need_weights=False 79 | ) 80 | # Residual connection and layer normalization 81 | hidden_state_enc_ctx = self.layer_norm(enc_input + self.dropout(attn_output)) 82 | # Remove batch dimension 83 | hidden_state_enc_ctx = hidden_state_enc_ctx.squeeze(0) # [n_enc, d_model] 84 | 85 | # Apply the projection networks. 86 | proj_enc = self.projection_enc(hidden_state_enc_ctx) # [n_enc, d_model] 87 | proj_dec = self.projection_dec(hidden_state_dec) # [n_dec, d_model] 88 | 89 | # Compute scaled dot-product attention scores. 90 | # Scaling by sqrt(d_model) is critical regardless of variable n_enc. 91 | scaling = self.d_model ** 0.5 92 | patch_logits = torch.matmul(proj_dec, proj_enc.transpose(0, 1)) / scaling # [n_dec, n_enc] 93 | 94 | # Softmax normalization is applied along the encoder dimension. 
95 | attn_weights = F.softmax(patch_logits, dim=-1) 96 | 97 | loss = None 98 | if (labels is not None) and (not do_single_patch): 99 | epsilon = 1e-8 100 | labels_float = labels.float() 101 | # Normalize each row to get target probability distribution 102 | target_dist = labels_float / (labels_float.sum(dim=-1, keepdim=True) + epsilon) 103 | 104 | # Apply log_softmax to logits 105 | pred_log_probs = F.log_softmax(patch_logits, dim=-1) 106 | # Use KL divergence as loss 107 | loss = F.kl_div(pred_log_probs, target_dist, reduction='batchmean') 108 | 109 | if do_single_patch and (labels is not None): 110 | loss = F.cross_entropy(attn_scores, labels) 111 | 112 | return attn_weights, loss 113 | 114 | 115 | class Qwen2_5_VLForConditionalGenerationWithPointer(Qwen2_5_VLForConditionalGeneration): 116 | def __init__(self, *args, **kwargs): 117 | super().__init__(*args, **kwargs) 118 | self.multi_patch_pointer_head = VisionHead_MultiPatch(self.config.hidden_size, self.config.hidden_size) 119 | self.pointer_loss_weight = kwargs.get("pointer_loss_weight", 1.0) 120 | self.lm_loss_weight = kwargs.get("lm_loss_weight", 1.0) 121 | self.post_init() 122 | 123 | def reset_loss_weights(self, pointer_loss_weight, lm_loss_weight): 124 | self.pointer_loss_weight = pointer_loss_weight 125 | self.lm_loss_weight = lm_loss_weight 126 | 127 | def forward(self, 128 | input_ids: torch.LongTensor = None, # (batch_size, seq_len) 129 | attention_mask: Optional[torch.Tensor] = None, 130 | position_ids: Optional[torch.LongTensor] = None, 131 | past_key_values: Optional[List[torch.FloatTensor]] = None, 132 | inputs_embeds: Optional[torch.FloatTensor] = None, 133 | labels: Optional[torch.LongTensor] = None, 134 | use_cache: Optional[bool] = None, 135 | output_attentions: Optional[bool] = None, 136 | output_hidden_states: Optional[bool] = None, 137 | return_dict: Optional[bool] = None, 138 | pixel_values: Optional[torch.Tensor] = None, 139 | pixel_values_videos: Optional[torch.FloatTensor] = None, 140 | image_grid_thw: Optional[torch.LongTensor] = None, 141 | video_grid_thw: Optional[torch.LongTensor] = None, 142 | rope_deltas: Optional[torch.LongTensor] = None, 143 | cache_position: Optional[torch.LongTensor] = None, 144 | second_per_grid_ts: Optional[torch.Tensor] = None, 145 | # Grounding 146 | visual_token_indices_of_coordinates: Optional[torch.Tensor] = None, # shape: (batch_size, n_target); each element is the ground-truth index of the visual token that should be attended to for the corresponding target token 147 | multi_patch_labels: Optional[torch.Tensor] = None, # shape: list [(n_target, n_visual), ...]; binary mask of patches in bbox 148 | if_multi_patch: bool = True, 149 | coordinates: Optional[List[Tuple[float, float]]] = None, 150 | verbose: bool = False) -> Union[Tuple, QwenVLwithVisionHeadOutputWithPast]: 151 | 152 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 153 | output_hidden_states = ( 154 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 155 | ) 156 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 157 | 158 | if verbose: 159 | rank0_print(f"input_ids: {input_ids.shape}, {input_ids[0][:5]}...") 160 | rank0_print(f"labels: {labels.shape}, {labels[0][:5]}...") 161 | rank0_print(f"pixel_values: {pixel_values.shape}") 162 | rank0_print(f"image_grid_thw: {image_grid_thw.shape}, {image_grid_thw}") 163 | rank0_print(f"coordinates: {coordinates}") 164 | 
rank0_print(f"visual_token_indices_of_coordinates: {visual_token_indices_of_coordinates}") 165 | rank0_print(f"return_dict: {return_dict}") 166 | 167 | if inputs_embeds is None: 168 | inputs_embeds = self.model.embed_tokens(input_ids) # shape: (batch_size, seq_len, d_model) 169 | if pixel_values is not None: 170 | pixel_values = pixel_values.type(self.visual.dtype) 171 | image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) 172 | n_image_tokens = (input_ids == self.config.image_token_id).sum().item() 173 | n_image_features = image_embeds.shape[0] 174 | if n_image_tokens != n_image_features: 175 | raise ValueError( 176 | f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" 177 | ) 178 | image_mask = ( 179 | (input_ids == self.config.image_token_id) 180 | .unsqueeze(-1) 181 | .expand_as(inputs_embeds) 182 | .to(inputs_embeds.device) 183 | ) 184 | image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) 185 | inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) 186 | 187 | if pixel_values_videos is not None: 188 | pixel_values_videos = pixel_values_videos.type(self.visual.dtype) 189 | video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) 190 | n_video_tokens = (input_ids == self.config.video_token_id).sum().item() 191 | n_video_features = video_embeds.shape[0] 192 | if n_video_tokens != n_video_features: 193 | raise ValueError( 194 | f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" 195 | ) 196 | video_mask = ( 197 | (input_ids == self.config.video_token_id) 198 | .unsqueeze(-1) 199 | .expand_as(inputs_embeds) 200 | .to(inputs_embeds.device) 201 | ) 202 | video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) 203 | inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) 204 | 205 | if attention_mask is not None: 206 | attention_mask = attention_mask.to(inputs_embeds.device) 207 | 208 | # if we get 4D attention mask we cannot calculate rope deltas anymore. 
TODO @raushan fixme 209 | if position_ids is None and (attention_mask is None or attention_mask.ndim == 2): 210 | # calculate RoPE index once per generation in the pre-fill stage only 211 | if ( 212 | (cache_position is not None and cache_position[0] == 0) 213 | or self.rope_deltas is None 214 | or (past_key_values is None or past_key_values.get_seq_length() == 0) 215 | ): 216 | position_ids, rope_deltas = self.get_rope_index( 217 | input_ids, image_grid_thw, video_grid_thw, attention_mask 218 | ) 219 | self.rope_deltas = rope_deltas 220 | # then use the prev pre-calculated rope-deltas to get the correct position ids 221 | else: 222 | batch_size, seq_length, _ = inputs_embeds.shape 223 | delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0 224 | position_ids = torch.arange(seq_length, device=inputs_embeds.device) 225 | position_ids = position_ids.view(1, -1).expand(batch_size, -1) 226 | if cache_position is not None: # otherwise `deltas` is an int `0` 227 | delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0) 228 | delta = delta.to(position_ids.device) 229 | position_ids = position_ids.add(delta) 230 | position_ids = position_ids.unsqueeze(0).expand(3, -1, -1) 231 | 232 | outputs = self.model( 233 | input_ids=None, 234 | position_ids=position_ids, 235 | attention_mask=attention_mask, 236 | past_key_values=past_key_values, 237 | inputs_embeds=inputs_embeds, 238 | use_cache=use_cache, 239 | output_attentions=output_attentions, 240 | output_hidden_states=output_hidden_states, 241 | return_dict=return_dict, 242 | cache_position=cache_position, 243 | ) 244 | 245 | hidden_states = outputs[0] # shape: (batch_size, seq_len, d_model) 246 | logits = self.lm_head(hidden_states) 247 | 248 | lm_loss = None 249 | if labels is not None and self.lm_loss_weight > 0: 250 | # Upcast to float if we need to compute the loss to avoid potential precision issues 251 | logits = logits.float() 252 | # Shift so that tokens < n predict n 253 | shift_logits = logits[..., :-1, :].contiguous() 254 | shift_labels = labels[..., 1:].contiguous() 255 | # Flatten the tokens 256 | loss_fct = nn.CrossEntropyLoss() 257 | shift_logits = shift_logits.view(-1, self.config.vocab_size) 258 | shift_labels = shift_labels.view(-1) 259 | # Enable model parallelism 260 | shift_labels = shift_labels.to(shift_logits.device) 261 | lm_loss = loss_fct(shift_logits, shift_labels) 262 | 263 | 264 | # If vision supervision is requested, process the action head. 265 | pointer_loss = None 266 | pointer_scores = [] 267 | if visual_token_indices_of_coordinates is not None: 268 | batch_size = input_ids.shape[0] 269 | pointer_losses = [] 270 | 271 | # Process each sample individually because the number of visual and target tokens may vary. 272 | for i in range(batch_size): 273 | dummy_target = False 274 | 275 | # Get the token ids and corresponding hidden states for sample i. 276 | token_ids = input_ids[i] # shape: (seq_length,) 277 | hs = hidden_states[i] # shape: (seq_length, d_model) 278 | 279 | # Identify visual tokens indices. 280 | visual_mask = (token_ids == self.config.image_token_id) 281 | visual_indices = torch.nonzero(visual_mask, as_tuple=False).squeeze(-1) # shape: (n_visual,) 282 | 283 | # Identify target tokens (the ones that should attend to visual features). 284 | target_mask = (token_ids == self.config.pointer_pad_token_id) 285 | target_indices = torch.nonzero(target_mask, as_tuple=False).squeeze(-1) 286 | 287 | # If either visual or target tokens are missing, skip this sample. 
288 |                 if visual_indices.numel() == 0:
289 |                     raise ValueError(f"No visual tokens found for sample {i}.")
290 |                 if target_indices.numel() == 0:
291 |                     target_indices = torch.tensor([hs.shape[0] - 1]) # take the last token as the dummy target token
292 |                     gt = torch.tensor([0]).to(hs.device) # take the first visual token as the dummy ground truth
293 |                     if if_multi_patch: # take the first 4 visual tokens as the ground truth
294 |                         sample_labels = torch.zeros_like(visual_indices).unsqueeze(0)
295 |                         sample_labels[0][:4] = 1
296 |                         # n_t = target_indices.size(0)  # number of target tokens
297 |                         # n_v = visual_indices.size(0)
298 |                         # sample_labels = torch.zeros(
299 |                         #     (n_t, n_v), device=hs.device, dtype=torch.float
300 |                         # )
301 |                         # sample_labels[:, :min(4, n_v)] = 1
302 |                     dummy_target = True
303 |                 else:
304 |                     # For supervision, we assume that visual_token_indices_of_coordinates[i] is a tensor of shape (n_target,)
305 |                     # where each element is an integer in the range [0, n_visual-1] indicating the ground-truth visual token.
306 |                     gt = visual_token_indices_of_coordinates[i].to(hs.device) # shape: (n_target,)
307 |                     if if_multi_patch:
308 |                         sample_labels = multi_patch_labels[i]
309 |                         # if sample_labels is None:
310 |                         #     n_t = target_indices.size(0)  # number of target tokens
311 |                         #     n_v = visual_indices.size(0)
312 |                         #     sample_labels = torch.zeros(
313 |                         #         (n_t, n_v), device=hs.device, dtype=torch.float
314 |                         #     )
315 |                         #     sample_labels[:, :min(4, n_v)] = 1
316 |                         #     dummy_target = True
317 | 
318 |                 # Gather the corresponding hidden state representations.
319 |                 # visual_hidden = hs[visual_indices] # shape: (n_visual, d_model)
320 |                 visual_embeds = inputs_embeds[i][visual_indices]
321 |                 target_hidden = hs[target_indices] # shape: (n_target, d_model)
322 | 
323 |                 # Calculate loss for multi-patch mode
324 |                 if if_multi_patch:
325 |                     # Ensure the number of targets matches between sample and labels
326 |                     if sample_labels.shape[0] != target_indices.shape[0]:
327 |                         raise ValueError(f"Sample {i} has mismatched target counts: {sample_labels.shape[0]} labels but found {target_indices.shape[0]} target tokens")
328 | 
329 |                     # Process using VisionHead_MultiPatch
330 |                     attn_scores, loss_v = self.multi_patch_pointer_head(
331 |                         visual_embeds,
332 |                         target_hidden,
333 |                         labels=sample_labels
334 |                     )
335 | 
336 |                 else:
337 |                     # Deprecated branch - single patch mode is no longer used
338 |                     # Run the action head to compute the attention (from target tokens to visual tokens) and its loss.
339 |                     attn_scores, loss_v = self.pointer_head(visual_embeds, target_hidden, labels=gt)
340 | 
341 |                 pointer_scores.append(attn_scores.detach().cpu())
342 | 
343 |                 pointer_losses.append(loss_v * 0.0 if dummy_target else loss_v)
344 | 
345 |             pointer_loss = torch.stack(pointer_losses).mean()
346 | 
347 |         # Combine the LM loss and vision loss using the provided loss weights.
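        # (For reference: the defaults in train.py are pointer_loss_weight=0.1 and lm_loss_weight=-1.0,
        #  and lm_loss above is only computed when lm_loss_weight > 0, so with the default settings
        #  lm_loss stays None and only the pointer loss is optimized.)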
348 | 349 | if lm_loss is None: 350 | total_loss = pointer_loss 351 | elif pointer_loss is None: 352 | total_loss = lm_loss 353 | else: 354 | total_loss = self.lm_loss_weight * lm_loss + self.pointer_loss_weight * pointer_loss 355 | 356 | if return_dict: 357 | return QwenVLwithVisionHeadOutputWithPast( 358 | lm_loss=lm_loss, 359 | pointer_loss=pointer_loss, 360 | pointer_scores=pointer_scores, 361 | loss=total_loss, 362 | logits=logits, 363 | past_key_values=outputs.past_key_values, 364 | hidden_states=outputs.hidden_states, 365 | attentions=outputs.attentions, 366 | rope_deltas=self.rope_deltas, 367 | ) 368 | else: 369 | # When labels are provided, parent's forward returns a tuple with loss as the first element. 370 | if labels is not None: 371 | # Replace the LM loss with the combined loss. 372 | output = (lm_loss, pointer_loss, logits, pointer_scores,) + outputs[1:] 373 | print(f"returning: total_loss, logits, pointer_scores, ...") 374 | return (total_loss,) + output if total_loss is not None else output 375 | else: 376 | return outputs -------------------------------------------------------------------------------- /src/gui_actor/trainer.py: -------------------------------------------------------------------------------- 1 | from datetime import timedelta 2 | from functools import wraps 3 | from typing import Optional 4 | 5 | import torch 6 | import torch.distributed as dist 7 | import transformers 8 | from accelerate import Accelerator, DataLoaderConfiguration 9 | from accelerate.utils import GradientAccumulationPlugin, InitProcessGroupKwargs 10 | from torch.utils.data import DataLoader, RandomSampler 11 | from transformers import Trainer 12 | from transformers.trainer import ( 13 | ALL_LAYERNORM_LAYERS, 14 | get_parameter_names, 15 | has_length, 16 | is_accelerate_available, 17 | is_datasets_available, 18 | is_sagemaker_mp_enabled, 19 | ) 20 | from transformers.trainer_pt_utils import LengthGroupedSampler as HFLengthGroupedSampler 21 | from transformers.trainer_utils import seed_worker 22 | from transformers.utils import logging 23 | 24 | if is_datasets_available(): 25 | import datasets 26 | 27 | 28 | def rank0_print(*args): 29 | if dist.is_initialized(): 30 | if dist.get_rank() == 0: 31 | print(f"Rank {dist.get_rank()}: ", *args) 32 | else: 33 | print(*args) 34 | 35 | 36 | def maybe_zero_3(param, ignore_status=False, name=None): 37 | from deepspeed import zero 38 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus 39 | 40 | if hasattr(param, "ds_id"): 41 | if param.ds_status == ZeroParamStatus.NOT_AVAILABLE and not ignore_status: 42 | logging.warning(f"{name}: param.ds_status != ZeroParamStatus.NOT_AVAILABLE: {param.ds_status}") 43 | with zero.GatheredParameters([param]): 44 | param = param.data.detach().cpu().clone() 45 | else: 46 | param = param.detach().cpu().clone() 47 | return param 48 | 49 | 50 | def get_mm_adapter_state_maybe_zero_3(named_params, keys_to_match): 51 | to_return = {k: t for k, t in named_params if any(key_match in k for key_match in keys_to_match)} 52 | to_return = {k: maybe_zero_3(v, ignore_status=True).cpu() for k, v in to_return.items()} 53 | return to_return 54 | 55 | 56 | def safe_save_model_for_hf_trainer(trainer: transformers.Trainer, output_dir: str): 57 | """Collects the state dict and dump to disk.""" 58 | trainer.accelerator.wait_for_everyone() 59 | torch.cuda.synchronize() 60 | 61 | if trainer.deepspeed: 62 | trainer.save_model(output_dir) 63 | return 64 | 65 | state_dict = trainer.model.state_dict() 66 | if 
trainer.args.should_save: 67 | cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()} 68 | del state_dict 69 | trainer._save(output_dir, state_dict=cpu_state_dict) 70 | 71 | 72 | class AGUVISTrainer(Trainer): 73 | 74 | def __init__(self, *args, **kwargs): 75 | super().__init__(*args, **kwargs) 76 | 77 | original_save = self._save 78 | original_save_model = self.save_model 79 | 80 | def modify_eos_token(func): 81 | @wraps(func) 82 | def wrapper(*args, **kwargs): 83 | tokenizer = self.processing_class.tokenizer 84 | old_config_id = self.model.config.eos_token_id 85 | old_eos_token = tokenizer.eos_token 86 | old_generation_config_eos_token_id = ( 87 | self.model.generation_config.eos_token_id if hasattr(self.model, "generation_config") else None 88 | ) 89 | 90 | try: 91 | new_eos_token_id = tokenizer.convert_tokens_to_ids("<|diff_marker|>") 92 | self.model.config.eos_token_id = [new_eos_token_id] 93 | tokenizer.eos_token = "<|diff_marker|>" 94 | if hasattr(self.model, "generation_config"): 95 | self.model.generation_config.eos_token_id = [new_eos_token_id] 96 | 97 | print("Set eos token id to", new_eos_token_id) 98 | print("Set eos token to", "<|diff_marker|>") 99 | print("Set generation config eos token id to", [new_eos_token_id]) 100 | 101 | result = func(*args, **kwargs) 102 | return result 103 | finally: 104 | self.model.config.eos_token_id = old_config_id 105 | tokenizer.eos_token = old_eos_token 106 | if hasattr(self.model, "generation_config") and old_generation_config_eos_token_id is not None: 107 | self.model.generation_config.eos_token_id = old_generation_config_eos_token_id 108 | 109 | print("Set eos token id back to", old_config_id) 110 | print("Set eos token back to", old_eos_token) 111 | if old_generation_config_eos_token_id is not None: 112 | print("Set generation config eos token id back to", old_generation_config_eos_token_id) 113 | 114 | return wrapper 115 | 116 | self._save = modify_eos_token(original_save) 117 | self.save_model = modify_eos_token(original_save_model) 118 | 119 | def create_accelerator_and_postprocess(self): 120 | grad_acc_kwargs = {"num_steps": self.args.gradient_accumulation_steps} 121 | grad_acc_kwargs["sync_with_dataloader"] = False 122 | gradient_accumulation_plugin = GradientAccumulationPlugin(**grad_acc_kwargs) 123 | 124 | accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52)) 125 | 126 | # create accelerator object 127 | dispatch_batches = getattr(self.args, "dispatch_batches", None) 128 | split_batches = getattr(self.args, "split_batches", None) 129 | self.dataloader_config = DataLoaderConfiguration( 130 | dispatch_batches=dispatch_batches, 131 | split_batches=split_batches, 132 | ) 133 | self.accelerator = Accelerator( 134 | dataloader_config=self.dataloader_config, 135 | deepspeed_plugin=self.args.deepspeed_plugin, 136 | gradient_accumulation_plugin=gradient_accumulation_plugin, 137 | kwargs_handlers=[accelerator_kwargs], 138 | ) 139 | # some Trainer classes need to use `gather` instead of `gather_for_metrics`, thus we store a flag 140 | self.gather_function = self.accelerator.gather_for_metrics 141 | 142 | # deepspeed and accelerate flags covering both trainer args and accelerate launcher 143 | self.is_deepspeed_enabled = getattr(self.accelerator.state, "deepspeed_plugin", None) is not None 144 | self.is_fsdp_enabled = getattr(self.accelerator.state, "fsdp_plugin", None) is not None 145 | 146 | # post accelerator creation setup 147 | if self.is_fsdp_enabled: 148 | fsdp_plugin = 
self.accelerator.state.fsdp_plugin 149 | fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get( 150 | "limit_all_gathers", fsdp_plugin.limit_all_gathers 151 | ) 152 | if is_accelerate_available("0.23.0"): 153 | fsdp_plugin.activation_checkpointing = self.args.fsdp_config.get( 154 | "activation_checkpointing", fsdp_plugin.activation_checkpointing 155 | ) 156 | if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing: 157 | raise ValueError( 158 | "The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg " 159 | "can't be set to True simultaneously. Please use FSDP's activation_checkpointing logic " 160 | "when using FSDP." 161 | ) 162 | 163 | if self.is_deepspeed_enabled and getattr(self.args, "hf_deepspeed_config", None) is None: 164 | self.propagate_args_to_deepspeed() 165 | 166 | def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: 167 | if self.train_dataset is None or not has_length(self.train_dataset): 168 | return None 169 | 170 | if self.args.group_by_length: 171 | lengths = self.train_dataset.lengths 172 | return HFLengthGroupedSampler( 173 | self.args.train_batch_size * self.args.gradient_accumulation_steps, 174 | dataset=self.train_dataset, 175 | lengths=lengths, 176 | ) 177 | elif self.args.group_by_modality_length: 178 | lengths = self.train_dataset.modality_lengths 179 | return HFLengthGroupedSampler( 180 | self.args.train_batch_size * self.args.gradient_accumulation_steps, 181 | dataset=self.train_dataset, 182 | lengths=lengths, 183 | ) 184 | else: 185 | return RandomSampler(self.train_dataset) 186 | 187 | def get_train_dataloader(self) -> DataLoader: 188 | """ 189 | Returns the training [`~torch.utils.data.DataLoader`]. 190 | 191 | Will use no sampler if `train_dataset` does not implement `__len__`, a random sampler (adapted to distributed 192 | training if necessary) otherwise. 193 | 194 | Subclass and override this method if you want to inject some custom behavior. 195 | """ 196 | if self.train_dataset is None: 197 | raise ValueError("Trainer: training requires a train_dataset.") 198 | 199 | train_dataset = self.train_dataset 200 | data_collator = self.data_collator 201 | if is_datasets_available() and isinstance(train_dataset, datasets.Dataset): 202 | train_dataset = self._remove_unused_columns(train_dataset, description="training") 203 | else: 204 | data_collator = self._get_collator_with_removed_columns(data_collator, description="training") 205 | 206 | dataloader_params = { 207 | "batch_size": self._train_batch_size, 208 | "collate_fn": data_collator, 209 | "num_workers": self.args.dataloader_num_workers, 210 | "pin_memory": self.args.dataloader_pin_memory, 211 | "persistent_workers": self.args.dataloader_persistent_workers, 212 | } 213 | 214 | if not isinstance(train_dataset, torch.utils.data.IterableDataset): 215 | dataloader_params["sampler"] = self._get_train_sampler() 216 | dataloader_params["drop_last"] = self.args.dataloader_drop_last 217 | dataloader_params["worker_init_fn"] = seed_worker 218 | dataloader_params["prefetch_factor"] = ( 219 | self.args.dataloader_num_workers * 2 if self.args.dataloader_num_workers != 0 else None 220 | ) 221 | 222 | dataloader = self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params)) 223 | 224 | return dataloader 225 | 226 | def create_optimizer(self): 227 | """ 228 | Setup the optimizer. 229 | 230 | We provide a reasonable default that works well. 
If you want to use something else, you can pass a tuple in the 231 | Trainer's init through `optimizers`, or subclass and override this method in a subclass. 232 | """ 233 | if is_sagemaker_mp_enabled(): 234 | return super().create_optimizer() 235 | 236 | opt_model = self.model 237 | 238 | if self.optimizer is None: 239 | decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) 240 | decay_parameters = [name for name in decay_parameters if "bias" not in name] 241 | optimizer_grouped_parameters = [ 242 | { 243 | "params": [ 244 | p for n, p in opt_model.named_parameters() if (n in decay_parameters and p.requires_grad) 245 | ], 246 | "weight_decay": self.args.weight_decay, 247 | }, 248 | { 249 | "params": [ 250 | p for n, p in opt_model.named_parameters() if (n not in decay_parameters and p.requires_grad) 251 | ], 252 | "weight_decay": 0.0, 253 | }, 254 | ] 255 | 256 | optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) 257 | 258 | self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) 259 | 260 | return self.optimizer 261 | 262 | def create_optimizer_with_different_learning_rates(self): 263 | """ 264 | Setup the optimizer. 265 | 266 | We provide a reasonable default that works well. If you want to use something else, you can pass a tuple in the 267 | Trainer's init through `optimizers`, or subclass and override this method in a subclass. 268 | """ 269 | if is_sagemaker_mp_enabled(): 270 | raise NotImplementedError("Sagemaker MP is not supported for separate learning rate yet") 271 | return super().create_optimizer() 272 | 273 | opt_model = self.model 274 | 275 | if self.optimizer is None: 276 | decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS) 277 | decay_parameters = [name for name in decay_parameters if "bias" not in name] 278 | 279 | new_parameters = [] 280 | for name, param in opt_model.named_parameters(): 281 | if ("pointer_head" in name) or ("embed_tokens" in name): 282 | new_parameters.append(name) 283 | rank0_print(f"new_parameters: {len(new_parameters)}") 284 | 285 | optimizer_grouped_parameters = [ 286 | { 287 | "params": [p for n, p in opt_model.named_parameters() if ((n in decay_parameters) and (n not in new_parameters) and p.requires_grad)], 288 | "weight_decay": self.args.weight_decay, 289 | "lr": self.args.learning_rate, 290 | }, 291 | { 292 | "params": [p for n, p in opt_model.named_parameters() if ((n not in decay_parameters) and (n not in new_parameters) and p.requires_grad)], 293 | "weight_decay": 0.0, 294 | "lr": self.args.learning_rate, 295 | }, 296 | { 297 | "params": [p for n, p in opt_model.named_parameters() if ((n in decay_parameters) and (n in new_parameters) and p.requires_grad)], 298 | "weight_decay": self.args.weight_decay, 299 | "lr": self.args.learning_rate_new_params, 300 | }, 301 | { 302 | "params": [p for n, p in opt_model.named_parameters() if ((n not in decay_parameters) and (n in new_parameters) and p.requires_grad)], 303 | "weight_decay": 0.0, 304 | "lr": self.args.learning_rate_new_params, 305 | }, 306 | ] 307 | 308 | optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(self.args) # {'lr': 0.0001, 'betas': (0.9, 0.999), 'eps': 1e-08} 309 | optimizer_kwargs.pop("lr") 310 | 311 | self.optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) 312 | 313 | return self.optimizer -------------------------------------------------------------------------------- /src/gui_actor/utils.py: 
-------------------------------------------------------------------------------- 1 | from PIL import Image, ImageDraw, ImageColor 2 | import json 3 | import os 4 | 5 | def dump_args_to_json(model_config, data_processor, model_args, data_args, training_args, output_dir): 6 | def is_json_serializable(v): 7 | try: 8 | json.dumps(v) 9 | return True 10 | except: 11 | return False 12 | 13 | save_path = f"{output_dir}/args.json" 14 | if not os.path.exists(save_path): 15 | with open(save_path, "w") as f: 16 | json.dump({ 17 | "model_config": {k: v for k, v in model_config.__dict__.items() if is_json_serializable(v)}, 18 | "data_processor_config": {k: v for k, v in data_processor.__dict__.items() if is_json_serializable(v)}, 19 | "image_processor_config": {k: v for k, v in data_processor.image_processor.__dict__.items() if is_json_serializable(v)}, 20 | "model_args": {k: v for k, v in model_args.__dict__.items() if is_json_serializable(v)}, 21 | "data_args": {k: v for k, v in data_args.__dict__.items() if is_json_serializable(v)}, 22 | "training_args": {k: v for k, v in training_args.__dict__.items() if is_json_serializable(v)}, 23 | }, f, indent=4) 24 | 25 | def draw_point(image: Image.Image, point: list, color=None): 26 | if isinstance(color, str): 27 | try: 28 | color = ImageColor.getrgb(color) 29 | color = color + (128,) 30 | except ValueError: 31 | color = (255, 0, 0, 128) 32 | else: 33 | color = (255, 0, 0, 128) 34 | 35 | overlay = Image.new('RGBA', image.size, (255, 255, 255, 0)) 36 | overlay_draw = ImageDraw.Draw(overlay) 37 | radius = 14 38 | x, y = point 39 | 40 | overlay_draw.rectangle( 41 | [x - radius, y - radius, x + radius, y + radius], 42 | fill=color 43 | ) 44 | 45 | center_radius = radius * 0.1 46 | overlay_draw.ellipse( 47 | [(x - center_radius, y - center_radius), 48 | (x + center_radius, y + center_radius)], 49 | fill=(0, 255, 0, 255) 50 | ) 51 | 52 | image = image.convert('RGBA') 53 | combined = Image.alpha_composite(image, overlay) 54 | 55 | return combined.convert('RGB') 56 | 57 | def draw_bbox(image: Image.Image, bbox: list, color=None): 58 | """bbox is in the format of [x1, y1, x2, y2]""" 59 | if isinstance(color, str): 60 | try: 61 | color = ImageColor.getrgb(color) 62 | color = color + (128,) 63 | except ValueError: 64 | color = (255, 0, 0, 128) 65 | else: 66 | color = (255, 0, 0, 128) 67 | 68 | overlay = Image.new('RGBA', image.size, (255, 255, 255, 0)) 69 | overlay_draw = ImageDraw.Draw(overlay) 70 | overlay_draw.rectangle(bbox, fill=color) 71 | return Image.alpha_composite(image, overlay).convert('RGB') 72 | 73 | def do_boxes_overlap(box1, box2): 74 | """ 75 | Check if two boxes overlap. 76 | 77 | Each box is represented as a tuple: (x1, y1, x2, y2) 78 | Where (x1, y1) is the top-left and (x2, y2) is the bottom-right corner. 
79 | """ 80 | # Unpack the coordinates 81 | x1_min, y1_min, x1_max, y1_max = box1 82 | x2_min, y2_min, x2_max, y2_max = box2 83 | 84 | # Check for no overlap 85 | if x1_max < x2_min or x2_max < x1_min: 86 | return False 87 | if y1_max < y2_min or y2_max < y1_min: 88 | return False 89 | 90 | return True -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import pathlib 2 | from dataclasses import dataclass, field 3 | from typing import Dict, Optional, Sequence 4 | 5 | import os 6 | import json 7 | import torch 8 | import transformers 9 | from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl 10 | from PIL import ImageFile 11 | from transformers import ( 12 | Qwen2VLForConditionalGeneration, 13 | AutoProcessor, 14 | ) 15 | 16 | from gui_actor.dataset import LazySupervisedDataset 17 | from gui_actor.trainer import AGUVISTrainer, rank0_print, safe_save_model_for_hf_trainer 18 | from gui_actor.utils import dump_args_to_json 19 | 20 | from gui_actor.constants import ( 21 | IGNORE_INDEX, 22 | ADDITIONAL_SPECIAL_TOKENS, 23 | DEFAULT_POINTER_START_TOKEN, 24 | DEFAULT_POINTER_END_TOKEN, 25 | DEFAULT_POINTER_PAD_TOKEN, 26 | ) 27 | 28 | from gui_actor.modeling import Qwen2VLForConditionalGenerationWithPointer 29 | from gui_actor.modeling_qwen25vl import Qwen2_5_VLForConditionalGenerationWithPointer 30 | 31 | apply_liger_kernel_to_qwen2_vl() 32 | 33 | torch.multiprocessing.set_sharing_strategy("file_system") 34 | 35 | ImageFile.LOAD_TRUNCATED_IMAGES = True 36 | local_rank = None 37 | 38 | 39 | @dataclass 40 | class ModelArguments: 41 | model_name_or_path: Optional[str] = field(default="facebook/opt-125m") 42 | flash_attn_2_enabled: bool = field(default=True) 43 | model_type: str = field(default="qwen2vl", metadata={"help": "model type: qwen2vl or qwen25vl"}) 44 | 45 | @dataclass 46 | class DataArguments: 47 | data_path: str = field(default=None) 48 | early_mix_text: bool = False 49 | image_folder: Optional[str] = field(default=None) 50 | min_pixels: Optional[int] = field(default=3136) # 2 * 2 * 28 * 28 = 56 * 56 51 | max_pixels: Optional[int] = field(default=5720064) # 5720064 = 114 * 64 * 28 * 28 = 3192 * 1792, 12845056 = 128 * 128 * 28 * 28 52 | max_conv_turns: Optional[int] = field(default=10) # 30 => 20 => 10 53 | 54 | 55 | @dataclass 56 | class TrainingArguments(transformers.TrainingArguments): 57 | cache_dir: Optional[str] = field(default=None) 58 | optim: str = field(default="adamw_torch") 59 | model_max_length: int = field( 60 | default=8192, 61 | metadata={"help": "Maximum sequence length. 
Sequences will be right padded (and possibly truncated)."}, 62 | ) 63 | group_by_modality_length: bool = field(default=False) 64 | gradient_checkpointing: bool = field(default=True) 65 | verbose_logging: bool = field(default=False) 66 | 67 | unfreeze_all_parameters: bool = field(default=False) 68 | unfreeze_pointer_head: bool = field(default=True) 69 | unfreeze_lm_head: bool = field(default=False) 70 | unfreeze_base_model: bool = field(default=False) 71 | unfreeze_last_n_layers: int = field(default=-1) 72 | unfreeze_new_tokens: bool = field(default=True) 73 | unfreeze_visual: bool = field(default=False) 74 | pointer_loss_weight: float = field(default=0.1) 75 | lm_loss_weight: float = field(default=-1.0) 76 | 77 | # def mask_embedding_grad(grad): 78 | # n_new_tokens = len(ADDITIONAL_SPECIAL_TOKENS) 79 | # mask = torch.zeros_like(grad) 80 | # mask[-n_new_tokens:] = 1.0 81 | # return grad * mask 82 | 83 | def smart_tokenizer_and_embedding_resize( 84 | special_tokens_dict: Dict, 85 | tokenizer: transformers.PreTrainedTokenizer, 86 | model: transformers.PreTrainedModel, 87 | ): 88 | """Resize tokenizer and embedding. 89 | 90 | Note: This is the unoptimized version that may make your embedding size not be divisible by 64. 91 | """ 92 | num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict) 93 | model.resize_token_embeddings(len(tokenizer)) 94 | 95 | new_vocab_size = len(tokenizer) 96 | # Update base model and current model config 97 | if hasattr(model.config, "text_config"): 98 | model.config.text_config.vocab_size = new_vocab_size 99 | else: 100 | model.config.vocab_size = new_vocab_size 101 | model.vocab_size = new_vocab_size 102 | 103 | if num_new_tokens > 0: 104 | input_embeddings = model.get_input_embeddings().weight.data 105 | output_embeddings = model.get_output_embeddings().weight.data 106 | 107 | input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 108 | output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True) 109 | 110 | input_embeddings[-num_new_tokens:] = input_embeddings_avg 111 | output_embeddings[-num_new_tokens:] = output_embeddings_avg 112 | 113 | 114 | def update_pointer_token_ids(model_config: transformers.PretrainedConfig, tokenizer: transformers.PreTrainedTokenizer): 115 | model_config.pointer_start_token_id = tokenizer.encode(DEFAULT_POINTER_START_TOKEN)[0] 116 | model_config.pointer_end_token_id = tokenizer.encode(DEFAULT_POINTER_END_TOKEN)[0] 117 | model_config.pointer_pad_token_id = tokenizer.encode(DEFAULT_POINTER_PAD_TOKEN)[0] 118 | rank0_print(f"Updated pointer token ids: {model_config.pointer_pad_token_id}, {model_config.pointer_start_token_id}, {model_config.pointer_end_token_id}") 119 | 120 | def setup_params_to_update(model: transformers.PreTrainedModel, training_args: TrainingArguments): 121 | if training_args.unfreeze_all_parameters: 122 | rank0_print(f"Unfreezing all model parameters...") 123 | for p in model.parameters(): 124 | p.requires_grad = True 125 | else: 126 | rank0_print(f"Freezing all model parameters...") 127 | for p in model.parameters(): 128 | p.requires_grad = False 129 | 130 | if training_args.unfreeze_pointer_head: 131 | rank0_print(f"Unfreezing pointer head parameters...") 132 | # for p in model.pointer_head.parameters(): 133 | # p.requires_grad = True 134 | for p in model.multi_patch_pointer_head.parameters(): 135 | p.requires_grad = True 136 | 137 | if training_args.unfreeze_lm_head: 138 | rank0_print(f"Unfreezing lm head parameters...") 139 | for p in 
model.lm_head.parameters(): 140 | p.requires_grad = True 141 | 142 | if training_args.unfreeze_base_model: # including text tokens 143 | rank0_print(f"Unfreezing base model parameters...") 144 | for p in model.model.parameters(): 145 | p.requires_grad = True 146 | 147 | if training_args.unfreeze_last_n_layers > 0: 148 | rank0_print(f"Unfreezing last {training_args.unfreeze_last_n_layers} layers of base model parameters...") 149 | for p in model.model.layers[-training_args.unfreeze_last_n_layers:].parameters(): 150 | p.requires_grad = True 151 | 152 | if training_args.unfreeze_new_tokens: 153 | rank0_print(f"Unfreezing new tokens parameters via embedding hook...") 154 | model.model.embed_tokens.weight.requires_grad = True 155 | # Registering hook before Trainer initialization is invalid, so it is disabled 156 | # model.model.embed_tokens.weight.register_hook(mask_embedding_grad) 157 | 158 | if training_args.unfreeze_visual: 159 | rank0_print(f"Unfreezing visual parameters...") 160 | for p in model.visual.parameters(): 161 | p.requires_grad = True 162 | 163 | @dataclass 164 | class DataCollatorForSupervisedDataset: 165 | """Collate examples for supervised fine-tuning.""" 166 | 167 | tokenizer: transformers.PreTrainedTokenizer 168 | 169 | def pad_sequence(self, input_ids, batch_first, padding_value): 170 | if self.tokenizer.padding_side == "left": 171 | input_ids = [torch.flip(_input_ids, [0]) for _input_ids in input_ids] 172 | input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=batch_first, padding_value=padding_value) 173 | if self.tokenizer.padding_side == "left": 174 | input_ids = torch.flip(input_ids, [1]) 175 | return input_ids 176 | 177 | def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 178 | input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels")) 179 | input_ids = [_input_ids[: self.tokenizer.model_max_length] for _input_ids in input_ids] 180 | labels = [_labels[: self.tokenizer.model_max_length] for _labels in labels] 181 | if self.tokenizer.pad_token_id is None: 182 | self.tokenizer.pad_token_id = 0 # This gets the best result. Don't know why. 
183 | input_ids = self.pad_sequence(input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id) 184 | labels = self.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX) 185 | batch = { 186 | "input_ids": input_ids, 187 | "labels": labels.long() if labels.dtype == torch.int32 else labels, 188 | "attention_mask": input_ids.ne(self.tokenizer.pad_token_id), 189 | } 190 | 191 | if "pixel_values" in instances[0]: 192 | batch["pixel_values"] = torch.concat([instance["pixel_values"] for instance in instances], dim=0) 193 | batch["image_grid_thw"] = torch.concat([instance["image_grid_thw"] for instance in instances], dim=0) 194 | 195 | if "coordinates" in instances[0]: 196 | batch["coordinates"] = [instance["coordinates"] for instance in instances] 197 | batch["visual_token_indices_of_coordinates"] = [instance["visual_token_indices_of_coordinates"] for instance in instances] 198 | 199 | if "multi_patch_labels" in instances[0]: 200 | batch["multi_patch_labels"] = [instance["multi_patch_labels"] for instance in instances] 201 | 202 | return batch 203 | 204 | 205 | def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer, 206 | processor: transformers.ProcessorMixin, 207 | data_args: DataArguments, 208 | training_args: TrainingArguments) -> Dict: 209 | """Make dataset and collator for supervised fine-tuning.""" 210 | train_dataset = LazySupervisedDataset( 211 | tokenizer=tokenizer, processor=processor, data_path=data_args.data_path, data_args=data_args 212 | ) 213 | data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer) 214 | return {"train_dataset": train_dataset, "eval_dataset": None, "data_collator": data_collator} 215 | 216 | 217 | def train(): 218 | global local_rank 219 | 220 | parser = transformers.HfArgumentParser((ModelArguments, DataArguments, TrainingArguments)) 221 | model_args, data_args, training_args = parser.parse_args_into_dataclasses() 222 | local_rank = training_args.local_rank 223 | 224 | if training_args.verbose_logging: 225 | rank0_print("Inspecting experiment hyperparameters:\n") 226 | rank0_print(f"model_args = {vars(model_args)}\n\n") 227 | rank0_print(f"data_args = {vars(data_args)}\n\n") 228 | rank0_print(f"training_args = {vars(training_args)}\n\n") 229 | # rank0_print(f"evaluation_args = {vars(evaluation_args)}\n\n") 230 | 231 | # set up model 232 | if model_args.model_type == "qwen2vl": 233 | model = Qwen2VLForConditionalGenerationWithPointer.from_pretrained( 234 | model_args.model_name_or_path, 235 | cache_dir=training_args.cache_dir, 236 | attn_implementation="flash_attention_2" if model_args.flash_attn_2_enabled else None, 237 | torch_dtype=(torch.bfloat16 if training_args.bf16 else None), 238 | low_cpu_mem_usage=False, 239 | ) 240 | elif model_args.model_type == "qwen25vl": 241 | model = Qwen2_5_VLForConditionalGenerationWithPointer.from_pretrained( 242 | model_args.model_name_or_path, 243 | cache_dir=training_args.cache_dir, 244 | attn_implementation="flash_attention_2" if model_args.flash_attn_2_enabled else None, 245 | torch_dtype=(torch.bfloat16 if training_args.bf16 else None), 246 | low_cpu_mem_usage=False, 247 | ) 248 | else: 249 | raise ValueError(f"Invalid model type: {model_args.model_type}") 250 | model.config.use_cache = False 251 | model.reset_loss_weights(pointer_loss_weight=training_args.pointer_loss_weight, lm_loss_weight=training_args.lm_loss_weight) 252 | 253 | if training_args.gradient_checkpointing: 254 | if hasattr(model, "enable_input_require_grads"): 255 | 
model.enable_input_require_grads() 256 | else: 257 | def make_inputs_require_grad(module, input, output): 258 | output.requires_grad_(True) 259 | model.get_input_embeddings().register_forward_hook(make_inputs_require_grad) 260 | 261 | setup_params_to_update(model, training_args) 262 | 263 | # set up tokenizer 264 | tokenizer = transformers.AutoTokenizer.from_pretrained( 265 | model_args.model_name_or_path, 266 | cache_dir=training_args.cache_dir, 267 | model_max_length=training_args.model_max_length, 268 | padding_side="right", 269 | ) 270 | 271 | smart_tokenizer_and_embedding_resize( 272 | special_tokens_dict={"additional_special_tokens": ADDITIONAL_SPECIAL_TOKENS}, 273 | tokenizer=tokenizer, 274 | model=model, 275 | ) 276 | update_pointer_token_ids(model.config, tokenizer) 277 | 278 | data_args.processor = AutoProcessor.from_pretrained( 279 | model_args.model_name_or_path, min_pixels=data_args.min_pixels, max_pixels=data_args.max_pixels 280 | ) 281 | data_args.processor.tokenizer = tokenizer 282 | 283 | if not os.path.exists(training_args.output_dir): 284 | os.makedirs(training_args.output_dir, exist_ok=True) 285 | 286 | if training_args.local_rank == 0 or training_args.local_rank == -1: 287 | dump_args_to_json(model.config, data_args.processor, model_args, data_args, training_args, training_args.output_dir) 288 | 289 | data_module = make_supervised_data_module(tokenizer=tokenizer, processor=data_args.processor, data_args=data_args, training_args=training_args) 290 | 291 | trainer = AGUVISTrainer( 292 | model=model, 293 | processing_class=data_args.processor, 294 | args=training_args, 295 | **data_module, 296 | ) 297 | 298 | # When LiteTrain, only update the gradient of the new tokens 299 | if training_args.unfreeze_new_tokens: 300 | emb_param = None 301 | for n, p in trainer.model.named_parameters(): 302 | if n.endswith("model.embed_tokens.weight"): 303 | emb_param = p; break 304 | if emb_param is None: 305 | raise ValueError("embed_tokens.weight not found") 306 | 307 | n_new_tokens = len(ADDITIONAL_SPECIAL_TOKENS) 308 | def mask_grad(grad): 309 | grad[:-n_new_tokens] = 0.0 310 | return grad 311 | emb_param.register_hook(mask_grad) 312 | 313 | if list(pathlib.Path(training_args.output_dir).glob("checkpoint-*")): 314 | trainer.train(resume_from_checkpoint=True) 315 | else: 316 | trainer.train() 317 | trainer.save_state() 318 | 319 | model.config.use_cache = True 320 | 321 | safe_save_model_for_hf_trainer(trainer=trainer, output_dir=training_args.output_dir) 322 | 323 | rank0_print(f"Model saved to {training_args.output_dir}") 324 | 325 | 326 | if __name__ == "__main__": 327 | train() 328 | -------------------------------------------------------------------------------- /verifier/README.md: -------------------------------------------------------------------------------- 1 | # Grounding Verifier for GUI-Actor 2 | 3 | We developed a grounding verifier to assess whether a selected action position aligns with a given language instruction. This model is particularly effective for GUI-Actor, as GUI-Actor's attention map produces diverse candidate positions from a single inference. With the verifier, we can efficiently evaluate actions **in hindsight**—after identifying the chosen position on the image—and make more informed decisions. 4 | 5 | image 6 | 7 | ## Training 8 | 9 | The verifier is trained to take a language instruction and an image (with a red circle marking the candidate position) as input, and predict whether the position is correct—outputting "True" or "False." 
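As a rough illustration, the sketch below shows how a candidate position can be rendered onto a screenshot before scoring. It is not the exact pipeline: the real marker style and query template live in `verifier_data_generation.py` and `verifier_model.py`, and the helper name `mark_candidate` is ours.

```python
from PIL import Image, ImageDraw

def mark_candidate(image_path, point_xy, radius=8):
    """Draw a red circle around a candidate click position (pixel coordinates)."""
    img = Image.open(image_path).convert("RGB")
    draw = ImageDraw.Draw(img)
    x, y = point_xy
    draw.ellipse([x - radius, y - radius, x + radius, y + radius], outline="red", width=4)
    return img

# The marked screenshot is then paired with the instruction (e.g. "Open the
# settings menu"), and the verifier answers "True" if the circled position
# satisfies the instruction, "False" otherwise.
marked = mark_candidate("screenshot.png", (640, 360))
marked.save("screenshot_marked.png")
```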
10 | 
11 | We use the [OS-Atlas dataset](https://huggingface.co/datasets/OS-Copilot/OS-Atlas-data) and process it using `verifier_data_generation.py` to curate training data. The model is trained via supervised fine-tuning (SFT) starting from the UI-TARS-2B-SFT checkpoint, providing strong performance with a relatively small model size.
12 | 
13 | ### Data Preparation
14 | 
15 | To prepare the dataset:
16 | 
17 | 1. Download and unzip the [OS-Atlas dataset](https://huggingface.co/datasets/OS-Copilot/OS-Atlas-data) following the instructions on the Hugging Face page.
18 | 2. Organize the images into the following directory structure:
19 | 
20 | ```python
21 | image_folder_dict = {
22 |     'windows_splited': f'{root_path}/desktop_domain/windows_images',
23 |     'linux_splited': f'{root_path}/desktop_domain/linux_images',
24 |     'macos_splited': f'{root_path}/desktop_domain/macos_images',
25 |     'widget_captioning': f'{root_path}/mobile_domain/combined',
26 |     'uibert_raw': f'{root_path}/mobile_domain/UIBert',
27 |     'ricosca': f'{root_path}/mobile_domain/combined',
28 |     'amex_raw': f'{root_path}/mobile_domain/amex_images',
29 |     'seeclick_web': f'{root_path}/web_domain/seeclick_web_imgs',
30 |     'fineweb_3m': f'{root_path}/web_domain/fineweb'
31 | }
32 | ```
33 | 
34 | Each training sample includes a positive and one or more negative examples:
35 | 
36 | * **Positive samples**: taken directly from the original dataset with a red circle marking the correct target.
37 | * **Negative samples**: created by either (a) selecting another meaningful UI element or (b) randomly sampling a point, which may not correspond to any actionable item.
38 | 
39 | To generate the dataset, run the following commands (since the full dataset is very large, you can adjust `--selected_size` to control how many samples are drawn from each domain):
40 | 
41 | ```bash
42 | python verifier_data_generation.py --root_path ${path_to_OS-Atlas-data} --new_directory ${save_path} --file_dict_key desktop_domain --selected_size 30000
43 | python verifier_data_generation.py --root_path ${path_to_OS-Atlas-data} --new_directory ${save_path} --file_dict_key mobile_domain --selected_size 30000
44 | python verifier_data_generation.py --root_path ${path_to_OS-Atlas-data} --new_directory ${save_path} --file_dict_key web_domain --selected_size 30000
45 | ```
46 | 
47 | 
48 | ### SFT
49 | 
50 | We use the official code from [Aguvis](https://github.com/xlang-ai/aguvis) to perform SFT training. Make sure to set the file path correctly in the `stage1.yaml` configuration. For training, we use [**UI-TARS-2B-SFT**](https://huggingface.co/ByteDance-Seed/UI-TARS-2B-SFT) as the base model with a learning rate of $2 \times 10^{-5}$, running for one epoch.
51 | 
52 | 
53 | 
54 | ## Evaluation
55 | 
56 | We evaluate our method using the attention weights generated by GUI-Actor and the grounding verifier, saved in a JSON file (e.g., `screenspot_all_preds_Original.json`). Before running the evaluation scripts, please update the file paths in `run_ss_v1.sh`, `run_ss_v2.sh`, and `run_ss_pro.sh` accordingly.
57 | 
58 | Make sure to download the ScreenSpot datasets and ensure their paths exactly match those specified in the shell scripts. Specifically, download **ScreenSpot** and **ScreenSpot-Pro** from [ss-v1](https://huggingface.co/datasets/rootsautomation/ScreenSpot) and [ss-pro](https://huggingface.co/datasets/likaixin/ScreenSpot-Pro), respectively.
59 | For **ScreenSpot-v2**, we provide a converted version (`ScreenSpot-v2-new`) that aligns with the format used by the other datasets.
However, you still need to download the original images from [ss-v2](https://huggingface.co/datasets/OS-Copilot/ScreenSpot-v2). 60 | 61 | 62 | Once everything is set up, run the following commands: 63 | 64 | ```bash 65 | bash run_ss_v1.sh 66 | bash run_ss_v2.sh 67 | bash run_ss_pro.sh 68 | ``` 69 | 70 | 71 | 72 | 73 | 74 | -------------------------------------------------------------------------------- /verifier/eval_ss_with_verifier.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import itertools 3 | import torch 4 | import json 5 | import re 6 | import argparse 7 | import os 8 | from PIL import Image, ImageDraw 9 | import logging 10 | from tqdm import tqdm 11 | 12 | 13 | 14 | def draw_annotations(img, point_in_pixel, bbox, output_path='test.png'): 15 | draw = ImageDraw.Draw(img) 16 | 17 | # Draw the ground truth bounding box in green 18 | if bbox: 19 | # Assuming bbox format is [x1, y1, x2, y2] 20 | draw.rectangle(bbox, outline="yellow", width=4) 21 | 22 | # Draw a small rectangle around the predicted point in red 23 | if point_in_pixel: 24 | # Create a small rectangle around the point (5 pixels in each direction) 25 | radius = 8 26 | circle_bbox = [ 27 | point_in_pixel[0] - radius, # x1 28 | point_in_pixel[1] - radius, # y1 29 | point_in_pixel[0] + radius, # x2 30 | point_in_pixel[1] + radius # y2 31 | ] 32 | draw.ellipse(circle_bbox, outline="red", width=4) 33 | 34 | img.save(output_path) 35 | print(f"Annotated image saved to {output_path}") 36 | return img 37 | 38 | 39 | 40 | 41 | logging.basicConfig(level=logging.INFO) 42 | torch.manual_seed(114514) 43 | 44 | 45 | GT_TYPES = ['positive', 'negative'] 46 | INSTRUCTION_STYLES = ['instruction', 'action', 'description'] 47 | LANGUAGES = ['en', 'cn'] 48 | 49 | 50 | def parse_args(): 51 | parser = argparse.ArgumentParser() 52 | parser.add_argument('--model_path', type=str, required=False) 53 | parser.add_argument('--screenspot_imgs', type=str, required=True) 54 | parser.add_argument('--screenspot_test', type=str, required=True) 55 | parser.add_argument('--task', type=str, required=True) 56 | parser.add_argument('--inst_style', type=str, required=True, choices=INSTRUCTION_STYLES + ['all'], help="Instruction style to use.") 57 | parser.add_argument('--language', type=str, required=True, choices=LANGUAGES + ['all'], default='en', help="Language to use.") 58 | parser.add_argument('--gt_type', type=str, required=True, choices=GT_TYPES + ['all'], help="Ground truth type: 'positive' or 'negative'.") 59 | parser.add_argument('--log_path', type=str, required=True) 60 | parser.add_argument('--json_prediction', type=str, required=False) 61 | parser.add_argument('--verifier_path', type=str, required=True) 62 | parser.add_argument('--verifier_method', type=str, required=True) 63 | 64 | args = parser.parse_args() 65 | return args 66 | 67 | 68 | def build_model(args): 69 | from verifier_model import GroundingVerifier 70 | model = GroundingVerifier(model_name_or_path=args.model_path, json_prediction=args.json_prediction, method=args.verifier_method) 71 | model.load_model(args.verifier_path) 72 | return model 73 | 74 | 75 | def collect_results_to_eval(results, platform=None, group=None, application=None, language=None, gt_type=None, instruction_style=None, ui_type=None): 76 | """ 77 | Filters the results based on provided values. None means include all (ignore filtering this attribute). 78 | 79 | 80 | Parameters: 81 | results (list): A list of dictionaries containing sample results. 
82 | 83 | Returns: 84 | list: A filtered list of dictionaries based on the given criteria. 85 | """ 86 | filtered_results = [] 87 | 88 | 89 | for sample in results: 90 | # Check each filter condition; if None, consider it as passed 91 | if (platform is None or sample.get("platform") == platform) and \ 92 | (group is None or sample.get("group") == group) and \ 93 | (application is None or sample.get("application") == application) and \ 94 | (language is None or sample.get("language") == language) and \ 95 | (gt_type is None or sample.get("gt_type") == gt_type) and \ 96 | (instruction_style is None or sample.get("instruction_style") == instruction_style) and \ 97 | (ui_type is None or sample.get("ui_type") == ui_type): 98 | filtered_results.append(sample) 99 | 100 | 101 | return filtered_results 102 | 103 | 104 | 105 | 106 | def make_combinations(results, platform=False, group=None, application=False, language=False, gt_type=False, instruction_style=False, ui_type=False): 107 | """ 108 | Returns a list of combinations of values for attributes where the corresponding parameter is set to True. 109 | """ 110 | # Initialize a dictionary to store unique values for each attribute 111 | unique_values = { 112 | "platform": set(), 113 | "group": set(), 114 | "application": set(), 115 | "language": set(), 116 | "gt_type": set(), 117 | "instruction_style": set(), 118 | "ui_type": set(), 119 | } 120 | 121 | 122 | # Collect unique values from the results 123 | for sample in results: 124 | if platform: 125 | unique_values["platform"].add(sample.get("platform")) 126 | if group: 127 | unique_values["group"].add(sample.get("group")) 128 | if application: 129 | unique_values["application"].add(sample.get("application")) 130 | if language: 131 | unique_values["language"].add(sample.get("language")) 132 | if gt_type: 133 | unique_values["gt_type"].add(sample.get("gt_type")) 134 | if instruction_style: 135 | unique_values["instruction_style"].add(sample.get("instruction_style")) 136 | if ui_type: 137 | unique_values["ui_type"].add(sample.get("ui_type")) 138 | 139 | 140 | # Filter out the attributes that are set to False (no need for combinations) 141 | filtered_values = {key: list(value) for key, value in unique_values.items() if value} 142 | if not filtered_values: 143 | return [] 144 | 145 | 146 | # Generate all combinations of the selected attributes using itertools.product 147 | attribute_combinations = list(itertools.product(*filtered_values.values())) 148 | 149 | 150 | # Convert combinations into dictionaries with corresponding attribute names 151 | combinations = [] 152 | for combination in attribute_combinations: 153 | combinations.append(dict(zip(filtered_values.keys(), combination))) 154 | 155 | 156 | return combinations 157 | 158 | 159 | 160 | 161 | def calc_metric_for_result_list(results): 162 | """Calculates the metrics for a simple result list.""" 163 | num_total = len(results) 164 | correct_num = sum(1 for res in results if res["correctness"] == "correct") 165 | wrong_format_num = sum(1 for res in results if res["correctness"] == "wrong_format") 166 | 167 | 168 | # Calculate text and icon specific metrics using collect_results_to_eval 169 | text_results = collect_results_to_eval(results, ui_type="text") 170 | icon_results = collect_results_to_eval(results, ui_type="icon") 171 | 172 | 173 | text_correct = sum(1 for res in text_results if res["correctness"] == "correct") 174 | text_total = len(text_results) 175 | icon_correct = sum(1 for res in icon_results if res["correctness"] == "correct") 176 | 
icon_total = len(icon_results) 177 | metrics = { 178 | "num_correct_action": correct_num, 179 | "num_total": num_total, 180 | "wrong_format_num": wrong_format_num, 181 | "action_acc": correct_num / num_total if num_total > 0 else 0, 182 | "text_acc": text_correct / text_total if text_total > 0 else 0, 183 | "icon_acc": icon_correct / icon_total if icon_total > 0 else 0 184 | } 185 | return metrics 186 | 187 | 188 | 189 | 190 | def eval_sample_positive_gt(sample, response): 191 | bbox = sample["bbox"] 192 | bbox = [bbox[0], bbox[1], bbox[2], bbox[3]] # x1, y1, x2, y2 193 | # bbox = [bbox[0], bbox[1], bbox[0] + bbox[2], bbox[1] + bbox[3]] # x1, y1, w, h 194 | img_size = sample["img_size"] 195 | bbox = [bbox[0] / img_size[0], bbox[1] / img_size[1], bbox[2] / img_size[0], bbox[3] / img_size[1]] 196 | 197 | click_point = response["point"] # may be none 198 | print(click_point, bbox) 199 | # import pdb;pdb.set_trace() 200 | if click_point is None: 201 | return "wrong_format" 202 | # Check if the predicted point falls in the ground truth box 203 | if (bbox[0] <= click_point[0] <= bbox[2]) and (bbox[1] <= click_point[1] <= bbox[3]): 204 | return "correct" 205 | else: 206 | return "wrong" 207 | 208 | def eval_sample_negative_gt(sample, response): 209 | if response["result"] == "negative": 210 | return "correct" 211 | elif response["result"] == "positive": 212 | return "wrong" 213 | else: ## response["result"] == wrong_format 214 | return "wrong_format" 215 | 216 | 217 | def evaluate_fine_grained(results): 218 | # Generate all combinations of platform, instruction_style, and gt_type 219 | combinations = make_combinations( 220 | results, 221 | platform=True, 222 | application=True, 223 | instruction_style=True, 224 | gt_type=True 225 | ) 226 | 227 | 228 | evaluation_result = {} 229 | 230 | 231 | # Iterate through each combination 232 | for combo in combinations: 233 | platform = combo.get("platform") 234 | application = combo.get("application") 235 | inst_style = combo.get("instruction_style") 236 | gt_type = combo.get("gt_type") 237 | 238 | # Filter results for the current combination 239 | filtered_results = collect_results_to_eval( 240 | results=results, 241 | platform=platform, 242 | application=application, 243 | instruction_style=inst_style, 244 | gt_type=gt_type 245 | ) 246 | 247 | # Calculate metrics using the calc_metric_for_result_list function 248 | metrics = calc_metric_for_result_list(filtered_results) 249 | if metrics['num_total'] == 0: 250 | continue 251 | 252 | # Construct a unique key based on the combination 253 | key = f"plat:{platform} app:{application} inst_style:{inst_style} gt_type:{gt_type}" 254 | evaluation_result[key] = metrics 255 | 256 | 257 | return evaluation_result 258 | 259 | 260 | def evaluate_fine_grained_v2(results): 261 | # Generate all combinations of platform, instruction_style, and gt_type 262 | combinations = make_combinations( 263 | results, 264 | group=True, 265 | ) 266 | 267 | 268 | evaluation_result = {} 269 | 270 | 271 | # Iterate through each combination 272 | for combo in combinations: 273 | group = combo.get("group") 274 | 275 | 276 | # Filter results for the current combination 277 | filtered_results = collect_results_to_eval( 278 | results=results, 279 | group=group, 280 | ) 281 | 282 | # Calculate metrics using the calc_metric_for_result_list function 283 | metrics = calc_metric_for_result_list(filtered_results) 284 | if metrics['num_total'] == 0: 285 | continue 286 | 287 | # Construct a unique key based on the combination 288 | key = 
f"group:{group}" 289 | evaluation_result[key] = metrics 290 | 291 | 292 | return evaluation_result 293 | 294 | 295 | def evaluate_seeclick_paper_style(results): 296 | # Generate all combinations of platform, instruction_style, and gt_type 297 | combinations = make_combinations( 298 | results, 299 | platform=True, 300 | instruction_style=True, 301 | gt_type=True 302 | ) 303 | 304 | 305 | evaluation_result = {} 306 | 307 | 308 | # Iterate through each combination 309 | for combo in combinations: 310 | platform = combo.get("platform") 311 | inst_style = combo.get("instruction_style") 312 | gt_type = combo.get("gt_type") 313 | 314 | # Filter results for the current combination 315 | filtered_results = collect_results_to_eval( 316 | results=results, 317 | platform=platform, 318 | instruction_style=inst_style, 319 | gt_type=gt_type 320 | ) 321 | 322 | # Calculate metrics using the calc_metric_for_result_list function 323 | metrics = calc_metric_for_result_list(filtered_results) 324 | if metrics['num_total'] == 0: 325 | continue 326 | 327 | # Construct a unique key based on the combination 328 | key = f"plat:{platform} inst_style:{inst_style} gt_type:{gt_type}" 329 | evaluation_result[key] = metrics 330 | 331 | 332 | return evaluation_result 333 | 334 | 335 | def evaluate_leaderboard_detailed_style(results): 336 | # Generate all combinations of platform, instruction_style, and gt_type 337 | combinations = make_combinations( 338 | results, 339 | application=True, 340 | ) 341 | 342 | 343 | evaluation_result = {} 344 | 345 | 346 | # Iterate through each combination 347 | for combo in combinations: 348 | application = combo.get("application") 349 | 350 | # Filter results for the current combination 351 | filtered_results = collect_results_to_eval( 352 | results=results, 353 | application=application, 354 | ) 355 | 356 | # Calculate metrics using the calc_metric_for_result_list function 357 | metrics = calc_metric_for_result_list(filtered_results) 358 | if metrics['num_total'] == 0: 359 | continue 360 | 361 | # Construct a unique key based on the combination 362 | key = f"app:{application}" 363 | evaluation_result[key] = metrics 364 | 365 | 366 | return evaluation_result 367 | 368 | 369 | def evaluate_leaderboard_simple_style(results): 370 | # Generate all combinations of platform, instruction_style, and gt_type 371 | combinations = make_combinations( 372 | results, 373 | group=True, 374 | ) 375 | 376 | 377 | evaluation_result = {} 378 | 379 | 380 | # Iterate through each combination 381 | for combo in combinations: 382 | group = combo.get("group") 383 | 384 | # Filter results for the current combination 385 | filtered_results = collect_results_to_eval( 386 | results=results, 387 | group=group, 388 | ) 389 | 390 | # Calculate metrics using the calc_metric_for_result_list function 391 | metrics = calc_metric_for_result_list(filtered_results) 392 | if metrics['num_total'] == 0: 393 | continue 394 | 395 | # Construct a unique key based on the combination 396 | key = f"group:{group}" 397 | evaluation_result[key] = metrics 398 | 399 | 400 | return evaluation_result 401 | 402 | 403 | def evaluate_overall(results): 404 | """ 405 | Evaluates the overall metrics for all results without any filtering. 406 | 407 | Parameters: 408 | results (list): A list of dictionaries containing sample results. 409 | 410 | Returns: 411 | dict: A dictionary containing the overall metrics. 
412 | """ 413 | # Calculate metrics for the entire result set 414 | metrics = calc_metric_for_result_list(results) 415 | 416 | return metrics 417 | 418 | 419 | 420 | 421 | def evaluate(results): 422 | """Collect results and calculate metrics. You can comment out function calls or add new ones based on your need. 423 | """ 424 | result_report = { 425 | "details": [], # Store detailed information for each sample 426 | "metrics": {} 427 | } 428 | 429 | 430 | # # TODO: comment out function calls based on your need 431 | result_report["metrics"]["fine_grained"] = evaluate_fine_grained_v2(results) 432 | # result_report["metrics"]["seeclick_style"] = evaluate_seeclick_paper_style(results) 433 | # result_report["metrics"]["leaderboard_simple_style"] = evaluate_leaderboard_simple_style(results) 434 | # result_report["metrics"]["leaderboard_detailed_style"] = evaluate_leaderboard_detailed_style(results) 435 | result_report["metrics"]["overall"] = evaluate_overall(results) 436 | 437 | 438 | # Save detailed results 439 | result_report["details"] = results 440 | 441 | 442 | return result_report 443 | 444 | 445 | def main(args): 446 | model = build_model(args) 447 | print("Load model success") 448 | 449 | 450 | if args.task == "all": 451 | task_filenames = [ 452 | os.path.splitext(f)[0] 453 | for f in os.listdir(args.screenspot_test) 454 | if f.endswith(".json") 455 | ] 456 | else: 457 | task_filenames = args.task.split(",") 458 | 459 | 460 | if args.inst_style == "all": 461 | inst_styles = INSTRUCTION_STYLES 462 | else: 463 | inst_styles = args.inst_style.split(",") 464 | 465 | 466 | if args.language == "all": 467 | languages = LANGUAGES 468 | else: 469 | languages = args.language.split(",") 470 | 471 | 472 | if args.gt_type == "all": 473 | gt_types = GT_TYPES 474 | else: 475 | gt_types = args.gt_type.split(",") 476 | 477 | 478 | tasks_to_run = [] 479 | for task_filename in task_filenames: 480 | dataset = task_filename + ".json" 481 | with open(os.path.join(args.screenspot_test, dataset), 'r') as f: 482 | task_data = json.load(f) 483 | 484 | 485 | # Create the list of tasks to run, one item as an instance. Tasks may be reused. 
486 | for inst_style in inst_styles: # Expand tasks based on user configurations 487 | for gt_type in gt_types: 488 | for lang in languages: 489 | for task_instance in task_data: # [30:] 490 | task_instance = copy.deepcopy(task_instance) 491 | task_instance["task_filename"] = task_filename 492 | task_instance["gt_type"] = gt_type 493 | task_instance["instruction_style"] = inst_style 494 | task_instance["language"] = lang 495 | if lang == "cn": 496 | if inst_style!= 'instruction' or gt_type != 'positive': 497 | # TODO: Translate the data 498 | raise AttributeError("Only positive samples and 'instruction' style are supported for Chinese instructions.") 499 | task_instance["prompt_to_evaluate"] = task_instance["instruction_cn"] 500 | elif lang == "en": 501 | task_instance["prompt_to_evaluate"] = task_instance["instruction"] 502 | 503 | 504 | tasks_to_run.append(task_instance) 505 | print(f"Num of sample in {task_filename}: {len(task_data)} * {len(inst_styles)} * {len(gt_types)} * {len(languages)} = {len(task_data) * len(inst_styles) * len(gt_types) * len(languages)}") 506 | print(f"Total tasks: {len(tasks_to_run)}") 507 | 508 | 509 | results = [] 510 | for sample in tqdm(tasks_to_run[:]): 511 | filename = sample["img_filename"] 512 | img_path = os.path.join(args.screenspot_imgs, filename) 513 | 514 | if task_instance["gt_type"] == "positive": 515 | response = model.ground_only_positive(instruction=sample["prompt_to_evaluate"], image=img_path, target_point=sample['bbox']) 516 | 517 | 518 | elif task_instance["gt_type"] == "negative": 519 | response = model.ground_allow_negative(instruction=sample["prompt_to_evaluate"], image=img_path) 520 | # print(response) 521 | point = response["point"] 522 | img_size = sample["img_size"] 523 | point_in_pixel = [point[0] * img_size[0], point[1] * img_size[1]] if point else None 524 | 525 | sample_result = { 526 | "img_path": img_path, 527 | "group": sample["group"] if "group" in sample else None, 528 | "platform": sample["platform"], 529 | "application": sample["application"] if 'application' in sample else None, 530 | "lang": sample["language"], 531 | "instruction_style": sample["instruction_style"] if 'instruction_style' in sample else None, 532 | "prompt_to_evaluate": sample["prompt_to_evaluate"], 533 | "gt_type": sample["gt_type"], 534 | "ui_type": sample["ui_type"], 535 | "task_filename": sample["task_filename"], 536 | "pred": point_in_pixel, 537 | "raw_response": response["raw_response"] 538 | } 539 | 540 | if sample["gt_type"] == "positive": 541 | correctness = eval_sample_positive_gt(sample, response) 542 | sample_result.update({ 543 | "bbox": sample["bbox"], 544 | }) 545 | print(correctness) 546 | elif sample["gt_type"] == "negative": 547 | correctness = eval_sample_negative_gt(sample, response) 548 | else: 549 | raise ValueError("Wrong instruction type") 550 | 551 | 552 | 553 | sample_result.update({ 554 | "correctness": correctness, 555 | }) 556 | results.append(sample_result) 557 | 558 | result_report = evaluate(results) 559 | # Save to file 560 | os.makedirs(os.path.dirname(args.log_path), exist_ok=True) 561 | with open(args.log_path, 'w') as f: 562 | json.dump(result_report, f, indent=4) 563 | logging.info("Evaluation of ScreenSpot finished.") 564 | 565 | 566 | 567 | 568 | if __name__ == "__main__": 569 | main(parse_args()) 570 | 571 | 572 | 573 | -------------------------------------------------------------------------------- /verifier/run_ss_pro.sh: -------------------------------------------------------------------------------- 1 | # 
2 | set -e 3 | 4 | json_path='Screenspot_eval/json_data/7B_full_qwen2vl/final_eval/screenspot-Pro_all_preds_StandardResize.json' 5 | exp_name='Actor-7b-fixprompt-bon_score_verifier' 6 | verifier_path='microsoft/GUI-Actor-Verifier-2B' 7 | screenspot_dataset_path='/datadisk/data/ss-eval/ScreenSpot-Pro' 8 | logdir='results_pro' 9 | 10 | verifier_method='score' 11 | # verifier_method='best_one' 12 | export CUDA_VISIBLE_DEVICES=0 13 | 14 | 15 | python eval_ss_with_verifier.py \ 16 | --screenspot_imgs ${screenspot_dataset_path}'/images' \ 17 | --screenspot_test ${screenspot_dataset_path}'/annotations' \ 18 | --task "all" \ 19 | --language "en" \ 20 | --gt_type "positive" \ 21 | --log_path "${logdir}/${exp_name}_${checkpoint}_sspro.json" \ 22 | --inst_style "instruction" \ 23 | --verifier_method ${verifier_method} \ 24 | --verifier_path ${verifier_path} \ 25 | --json_prediction ${json_path} 26 | 27 | -------------------------------------------------------------------------------- /verifier/run_ss_v1.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | json_path='/home/t-yangrui/code/Screenspot_eval/json_data/new_prompt_7B/screenspot_all_preds_Original.json' 5 | exp_name='Actor-7b-warmup-fixprompt-bon_score_verifierultracpt6000_crop500' 6 | verifier_path='microsoft/GUI-Actor-Verifier-2B' 7 | screenspot_dataset_path="data/ss-eval/ScreenSpot" 8 | logdir='results_v1' 9 | 10 | verifier_method='score' 11 | # verifier_method='best_one' 12 | export CUDA_VISIBLE_DEVICES=0 13 | 14 | 15 | python eval_ss_with_verifier.py \ 16 | --screenspot_imgs ${screenspot_dataset_path}'/images' \ 17 | --screenspot_test ${screenspot_dataset_path} \ 18 | --task "all" \ 19 | --language "en" \ 20 | --gt_type "positive" \ 21 | --log_path "${logdir}/${exp_name}_${checkpoint}_ssv1.json" \ 22 | --inst_style "instruction" \ 23 | --verifier_method ${verifier_method} \ 24 | --verifier_path ${verifier_path} \ 25 | --json_prediction ${json_path} 26 | 27 | -------------------------------------------------------------------------------- /verifier/run_ss_v2.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | models=("aguvis-with-verifier") 5 | 6 | 7 | json_path='/home/t-yangrui/code/Screenspot_eval/json_data/new_prompt_3B/screenspot_v2_all_preds_Original.json' 8 | exp_name='Actor-7b-warmup-fixprompt-bon_score_verifierultracpt6000_crop500' 9 | verifier_path='microsoft/GUI-Actor-Verifier-2B' 10 | screenspot_dataset_path='ss-eval/ScreenSpot-v2' 11 | logdir='results_v2' 12 | 13 | 14 | verifier_method='score' 15 | # verifier_method='best_one' 16 | export CUDA_VISIBLE_DEVICES=0 17 | 18 | 19 | python eval_ss_with_verifier.py \ 20 | --screenspot_imgs "${screenspot_dataset_path}/screenspotv2_image" \ 21 | --screenspot_test "ScreenSpot-v2-new" \ 22 | --task "all" \ 23 | --language "en" \ 24 | --gt_type "positive" \ 25 | --log_path "${logdir}/${exp_name}_${checkpoint}_ssv2.json" \ 26 | --inst_style "instruction" \ 27 | --verifier_method ${verifier_method} \ 28 | --verifier_path ${verifier_path} \ 29 | --json_prediction ${json_path} 30 | 31 | 32 | 33 | -------------------------------------------------------------------------------- /verifier/verifier_data_generation.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import re 4 | import numpy as np 5 | import random 6 | from PIL import Image, ImageDraw 7 | import argparse 8 | 9 | 10 | dic =
{ 11 | "from": "gpt", 12 | "value": "True", 13 | "recipient": "os", 14 | "end_turn": True 15 | } 16 | neg_dic = { 17 | "from": "gpt", 18 | "value": "False", 19 | "recipient": "os", 20 | "end_turn": True 21 | } 22 | 23 | 24 | 25 | def sample_point(bbox): 26 | x0, y0, x1, y1 = bbox 27 | t = 0 28 | while t <= 50: 29 | xx, yy = np.random.random(2) 30 | t += 1 31 | if not ((x0 < xx < x1) and (y0 < yy < y1)): 32 | break 33 | if t > 50: 34 | return 35 | return xx, yy 36 | 37 | 38 | def load_json_file(file_path): 39 | """Load and parse JSON data from a file.""" 40 | try: 41 | with open(file_path, 'r') as file: 42 | data = json.load(file) 43 | return data 44 | except FileNotFoundError: 45 | print(f"Error: File '{file_path}' not found.") 46 | return None 47 | except json.JSONDecodeError: 48 | print(f"Error: '{file_path}' contains invalid JSON.") 49 | return None 50 | 51 | 52 | def draw_annotations(img, point_in_pixel, bbox, output_path='test.png', color='red', size=1): 53 | draw = ImageDraw.Draw(img) 54 | 55 | # Draw the ground truth bounding box in green 56 | if bbox: 57 | # Assuming bbox format is [x1, y1, x2, y2] 58 | draw.rectangle(bbox, outline="yellow", width=4) 59 | # Draw a small rectangle around the predicted point in red 60 | if point_in_pixel: 61 | # Create a small rectangle around the point (5 pixels in each direction) 62 | radius = np.ceil(8 * size).astype(int) 63 | circle_bbox = [ 64 | point_in_pixel[0] - radius, # x1 65 | point_in_pixel[1] - radius, # y1 66 | point_in_pixel[0] + radius, # x2 67 | point_in_pixel[1] + radius # y2 68 | ] 69 | draw.ellipse(circle_bbox, outline=color, width=np.ceil(4 * size).astype(int)) 70 | 71 | img.save(output_path) 72 | print(f"Annotated image saved to {output_path}") 73 | return img 74 | 75 | 76 | def transform_to_conversation_format(data, file, image_folder_dict, new_directory): 77 | """ 78 | Transform the input data to the specified conversation format. 79 | Args: 80 | data: List of dictionaries containing webpage elements data 81 | 82 | Returns: 83 | List of dictionaries in the conversation format 84 | """ 85 | image_folder = image_folder_dict[file] 86 | result = [] 87 | for i, item in enumerate(data): 88 | print(i / len(data)) 89 | img_filename = item['img_filename'] 90 | 91 | prompt = 'Please observe the screenshot and exame whether the hollow red circle accurately placed on the intended position in the image:' 92 | 93 | if 'elements' in item: 94 | # sample n//2 element 95 | n = len(item['elements']) 96 | ind_list = [] 97 | if n <= 1: 98 | ind_list = [0] 99 | else: 100 | ind_list = random.sample(range(n), min(n//2, 3)) 101 | 102 | for ind in ind_list: 103 | conversations = [] 104 | instruction = item['elements'][ind]['instruction'] 105 | bbox = item['elements'][ind]['bbox'] 106 | if (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) >= 0.8: 107 | continue 108 | 109 | conversations.append({ 110 | "from": "human", 111 | "value": f"\n{prompt} " + f"'{instruction}'. Answer True or False." 
112 | }) 113 | 114 | # Calculate the center point of the bounding box 115 | x_center = (bbox[0] + bbox[2]) / 2 116 | y_center = (bbox[1] + bbox[3]) / 2 117 | 118 | if n >= 2: 119 | neg_ind = random.choice([k for k in range(n) if k != ind]) 120 | neg_bbox = item['elements'][neg_ind]['bbox'] 121 | x_neg, y_neg = (neg_bbox[0] + neg_bbox[2]) / 2, (neg_bbox[1] + neg_bbox[3]) / 2 122 | if (x_center - x_neg) ** 2 + (y_center - y_neg) ** 2 < 0.05: 123 | x_neg, y_neg = sample_point(bbox) 124 | else: 125 | x_neg, y_neg = sample_point(bbox) 126 | 127 | # draw image 128 | try: 129 | img = Image.open(os.path.join(image_folder, img_filename)) 130 | except: 131 | continue 132 | 133 | 134 | prefix, suffix = img_filename.split('.') 135 | width, height = img.size 136 | save_path = os.path.join(new_directory, file+'_'+ prefix.replace('/', '') + f'_pos{ind}.' + suffix) 137 | while os.path.exists(save_path): 138 | save_path = os.path.join(new_directory, file+'_'+ prefix.replace('/', '') + f'_pos{ind}_{random.randint(0, 1000)}.' + suffix) 139 | 140 | try: 141 | draw_annotations(img, [x_center * width, y_center* height], None, output_path=save_path, size=height/1000 * 1.2) 142 | except: 143 | continue 144 | img = Image.open(os.path.join(image_folder, img_filename)) 145 | neg_save_path = os.path.join(new_directory, file+'_'+ prefix.replace('/', '') + f'_neg{ind}.' + suffix) 146 | while os.path.exists(neg_save_path): 147 | neg_save_path = os.path.join(new_directory, file+'_'+ prefix.replace('/', '') + f'_neg{ind}_{random.randint(0, 1000)}.' + suffix) 148 | 149 | draw_annotations(img, [x_neg * width, y_neg* height], None, output_path=neg_save_path, size=height/1000 * 1.2) 150 | 151 | 152 | # Create the conversation item 153 | result.append({ 154 | "image":save_path.replace(new_directory, ''), 155 | "conversations": conversations + [dic] 156 | }) 157 | result.append({ 158 | "image":neg_save_path.replace(new_directory, ''), 159 | "conversations": conversations + [neg_dic] 160 | }) 161 | else: 162 | conversations = [] 163 | instruction = item['instruction'] 164 | bbox = item['bbox'] 165 | conversations.append({ 166 | "from": "human", 167 | "value": f"\n{prompt} " + f"'{instruction}'. Answer True or False." 168 | }) 169 | 170 | 171 | if (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) >= 0.8: 172 | continue 173 | 174 | x_center = (bbox[0] + bbox[2]) / 2 175 | y_center = (bbox[1] + bbox[3]) / 2 176 | x_neg, y_neg = sample_point(bbox) 177 | 178 | # draw image 179 | try: 180 | img = Image.open(os.path.join(image_folder, img_filename)) 181 | except: 182 | continue 183 | prefix, suffix = img_filename.split('.') 184 | width, height = img.size 185 | save_path = os.path.join(new_directory, file+'_'+ prefix.replace('/', '') + '_pos.' + suffix) 186 | while os.path.exists(save_path): 187 | save_path = os.path.join(new_directory, file+'_'+ prefix.replace('/', '') + f'_pos_{random.randint(0, 1000)}.' + suffix) 188 | 189 | draw_annotations(img, [x_center * width, y_center* height], None, output_path=save_path, size=height/1000 * 1.2) 190 | 191 | img = Image.open(os.path.join(image_folder, img_filename)) 192 | neg_save_path = os.path.join(new_directory, file+'_'+ prefix.replace('/', '') + '_neg.' + suffix) 193 | while os.path.exists(neg_save_path): 194 | neg_save_path = os.path.join(new_directory, file+'_'+ prefix.replace('/', '') + f'_neg_{random.randint(0, 1000)}.' 
+ suffix) 195 | 196 | draw_annotations(img, [x_neg * width, y_neg* height], None, output_path=neg_save_path, size=height/1000 * 1.2) 197 | 198 | # Create the conversation item 199 | result.append({ 200 | "image":save_path.replace(new_directory, ''), 201 | "conversations": conversations + [dic] 202 | }) 203 | result.append({ 204 | "image":neg_save_path.replace(new_directory, ''), 205 | "conversations": conversations + [neg_dic] 206 | }) 207 | 208 | return result 209 | 210 | 211 | 212 | 213 | if __name__ == "__main__": 214 | parser = argparse.ArgumentParser(description="Generate verifier data") 215 | parser.add_argument('--root_path', type=str, required=True, help='Root path to OS-Atlas-data') 216 | parser.add_argument('--new_directory', type=str, default='./verifier_data', help='Directory to save the new verifier data') 217 | parser.add_argument('--file_dict_key', type=str, default='', help='Key for the file dictionary to process') 218 | parser.add_argument('--save_suffix', type=str, default='verifier', help='Suffix for the saved files') 219 | parser.add_argument('--selected_size', type=int, default=10000, help='Number of samples to select from each file') 220 | args = parser.parse_args() 221 | 222 | 223 | root_path = args.root_path 224 | new_directory = args.new_directory 225 | save_suffix = args.save_suffix 226 | selected_size = args.selected_size 227 | 228 | if not os.path.exists(new_directory): 229 | os.makedirs(new_directory) 230 | 231 | 232 | image_folder_dict = { 233 | 'windows_splited': f'{root_path}/desktop_domain/windows_images', 234 | 'linux_splited': f'{root_path}/desktop_domain/linux_images', 235 | 'macos_splited': f'{root_path}/desktop_domain/macos_images', 236 | 'widget_captioning': f'{root_path}/mobile_domain/combined', 237 | 'uibert_raw': f'{root_path}/mobile_domain/UIBert', 238 | 'ricosca': f'{root_path}/mobile_domain/combined', 239 | 'amex_raw': f'{root_path}/mobile_domain/amex_images', 240 | 'seeclick_web': f'{root_path}/web_domain/seeclick_web_imgs', 241 | 'fineweb_3m': f'{root_path}/web_domain/fineweb' 242 | } 243 | 244 | 245 | file_dict = { 246 | 'desktop_domain': ['linux_splited', 'windows_splited', 'macos_splited'], 247 | 'mobile_domain': ['uibert_raw', 'ricosca', 'amex_raw', 'widget_captioning'], 248 | 'web_domain': ['fineweb_3m', 'seeclick_web'], 249 | } 250 | 251 | 252 | def process_files(directory): 253 | files = file_dict[directory] 254 | for file in files: 255 | file_path = os.path.join(root_path, directory, file + '.json') 256 | # Load the JSON data 257 | data = load_json_file(file_path) 258 | data = random.sample(data, selected_size) if len(data) >= selected_size else data 259 | print(directory, file, len(data)) 260 | 261 | # Extract coordinates 262 | new_data = transform_to_conversation_format(data, file, image_folder_dict, new_directory) 263 | 264 | 265 | print(directory, file, len(data)) 266 | with open(file_path.replace('.json', f'_{save_suffix}.json'), "w", encoding="utf-8") as f: 267 | json.dump(new_data, f) 268 | 269 | 270 | if len(args.file_dict_key) == 0: 271 | for directory in file_dict.keys(): 272 | process_files(directory) 273 | else: 274 | key = args.file_dict_key 275 | assert key in file_dict.keys(), f"Key {key} not found in file_dict" 276 | process_files(key) 277 | 278 | 279 | 280 | 281 | 282 | 283 | 284 | 285 | -------------------------------------------------------------------------------- /verifier/verifier_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers 
import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor 3 | from transformers.generation import GenerationConfig 4 | import json 5 | import re 6 | import os 7 | import tempfile 8 | from PIL import Image, ImageDraw 9 | from qwen_vl_utils import process_vision_info 10 | from typing import List, Literal, Optional 11 | import numpy as np 12 | import random 13 | 14 | grounding_system_message = "You are a GUI agent. You are given a task and a screenshot of the screen. You need to perform a series of pyautogui actions to complete the task." 15 | 16 | 17 | def image_to_temp_filename(image): 18 | temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False) 19 | image.save(temp_file.name) 20 | print(f"Image saved to temporary file: {temp_file.name}") 21 | return temp_file.name 22 | 23 | 24 | def draw_point_list(img, points, color='red', size=1, crop=True, sample_crop=False, crop_size=500): 25 | draw = ImageDraw.Draw(img) 26 | radius = np.ceil(7 * size).astype(int) 27 | for point in points: 28 | circle_bbox = [ 29 | point[0] - radius, # x1 30 | point[1] - radius, # y1 31 | point[0] + radius, # x2 32 | point[1] + radius # y2 33 | ] 34 | draw.ellipse(circle_bbox, outline=color, width=np.ceil(3 * size).astype(int)) 35 | 36 | 37 | if crop: 38 | x, y = points[0] 39 | width, height = img.size 40 | crop_half_size = crop_size 41 | left = max(0, x - crop_half_size) 42 | right = min(width-1, x + crop_half_size) 43 | top = max(0, y - crop_half_size) 44 | bottom = min(height-1, y + crop_half_size) 45 | try: 46 | img = img.crop((left, top, right, bottom)) 47 | except Exception as e: 48 | print(f"Error cropping image: {e}") 49 | # If cropping fails, return the original image 50 | return img 51 | return img 52 | 53 | 54 | 55 | class GroundingVerifier(): 56 | def __init__(self, 57 | model_name_or_path="microsoft/GUI-Actor-Verifier-2B", 58 | json_prediction=None, 59 | method='score' # 'best_one', 'comparison', 'score' 60 | ): 61 | self.method = method 62 | self.model_name_or_path = model_name_or_path 63 | self.system_message = { 64 | "role": "system", 65 | "content": grounding_system_message, 66 | } 67 | self.json_prediction_path = json_prediction 68 | # load json prediction 69 | assert os.path.exists(json_prediction) and os.path.isfile(json_prediction), "Invalid json prediction path." 
70 | with open(json_prediction, 'r') as f: 71 | self.json_prediction = json.load(f) 72 | 73 | self.verifier_crop_size = 500 # half of the true crop size 74 | # use 0.95 for ss-pro 75 | if '-pro' in self.json_prediction_path.lower(): 76 | self.threshold = 0.95 77 | else: # use 0.8 for ss and ss-v2 78 | self.threshold = 0.8 79 | 80 | self.json_index_dict = {} 81 | for i, item in enumerate(self.json_prediction): 82 | key = 'img_filename' if 'img_filename' in item else 'file_name' 83 | json_key = item[key] + item['instruction'] if 'instruction' in item else '' 84 | self.json_index_dict[json_key] = i 85 | 86 | 87 | 88 | def load_model(self, verifier_path): 89 | if self.method == 'best_one': 90 | return 91 | else: 92 | verifier_model_name_or_path = verifier_path 93 | 94 | self.verifier = Qwen2VLForConditionalGeneration.from_pretrained( 95 | verifier_model_name_or_path, 96 | device_map="cuda:0", 97 | trust_remote_code=True, 98 | torch_dtype=torch.bfloat16, 99 | attn_implementation="flash_attention_2" 100 | ).eval() 101 | self.verifier_tokenizer = AutoTokenizer.from_pretrained(verifier_model_name_or_path, trust_remote_code=True) 102 | self.verifier_processor = AutoProcessor.from_pretrained(verifier_model_name_or_path) 103 | self.verifier_processor.tokenizer.pad_token = self.verifier_processor.tokenizer.eos_token 104 | 105 | 106 | 107 | def set_generation_config(self, **kwargs): 108 | pass 109 | 110 | 111 | def verify(self, instruction, image): 112 | verifier_prompt = "Please observe the screenshot and exame whether the hollow red circle accurately placed on the intended position in the image: '{}'. Answer True or False." 113 | full_prompt = verifier_prompt.format(instruction) 114 | messages = [ 115 | { 116 | "role": "user", 117 | "content": [ 118 | { 119 | "type": "image", 120 | "image": image, 121 | }, 122 | {"type": "text", "text": full_prompt}, 123 | ], 124 | } 125 | ] 126 | text_input = self.verifier_processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) 127 | image_inputs, video_inputs = process_vision_info(messages) 128 | inputs = self.verifier_processor( 129 | text=[text_input], 130 | images=image_inputs, 131 | videos=video_inputs, 132 | padding=True, 133 | return_tensors="pt", 134 | ) 135 | inputs = inputs.to("cuda:0") 136 | 137 | 138 | # get the token probability of True and False using the verifier 139 | # Forward pass to get logits 140 | with torch.no_grad(): 141 | outputs = self.verifier(**inputs) 142 | logits = outputs.logits # shape: (batch_size, seq_len, vocab_size) 143 | 144 | # Get the last token's logits 145 | last_token_logits = logits[:, -1, :] # (batch_size, vocab_size) 146 | 147 | 148 | # Get vocab IDs for "True" and "False" 149 | true_id = self.verifier_processor.tokenizer.encode("True", add_special_tokens=False)[0] 150 | false_id = self.verifier_processor.tokenizer.encode("False", add_special_tokens=False)[0] 151 | 152 | # Get probabilities using softmax 153 | probs = torch.softmax(last_token_logits, dim=-1) 154 | true_prob = probs[0, true_id].item() 155 | false_prob = probs[0, false_id].item() 156 | score = true_prob / (true_prob + false_prob) 157 | return score 158 | 159 | 160 | 161 | def verifier_score(self, instruction, image, box): 162 | box = [box] 163 | img_copy = image.copy() 164 | img_new = draw_point_list(img_copy, box, crop_size=self.verifier_crop_size) 165 | score = self.verify(instruction, img_new) 166 | return score 167 | 168 | 169 | def get_prediction_region_point(self, attn_scores, n_width, n_height, top_n=20, 
return_all_regions=True, rect_center=False, no_groups=False): 170 | attn_scores = np.array(attn_scores) 171 | max_score = attn_scores.max() 172 | threshold = max_score * 0.2 173 | # select patches with activation scores above the threshold 174 | mask = attn_scores > threshold 175 | valid_indices = np.where(mask) 176 | # keep only top_n patches 177 | if len(valid_indices[1]) > top_n: 178 | valid_scores = attn_scores[valid_indices] 179 | sorted_idx = np.argsort(valid_scores)[::-1][:top_n] 180 | valid_indices = valid_indices[1][sorted_idx] 181 | topk_values = valid_scores[sorted_idx] 182 | topk_indices = valid_indices 183 | else: 184 | topk_values = attn_scores[valid_indices].tolist() 185 | topk_indices = valid_indices[1] 186 | 187 | # topk_values, topk_indices = attn_scores.topk(top_n, dim=-1) 188 | if n_width * n_height != attn_scores.shape[1]: 189 | n_width = n_width // 2 190 | n_height = n_height // 2 191 | 192 | 193 | # transform the topk_indices into coordinates 194 | topk_coords = [] 195 | for idx in topk_indices: 196 | x = idx % n_width 197 | y = idx // n_width 198 | topk_coords.append((int(y), int(x), int(idx))) 199 | 200 | # divide the topk_coords into regions based on connectivity 201 | regions = [] 202 | visited = set() 203 | 204 | for i, (y, x, idx) in enumerate(topk_coords): 205 | if idx in visited: 206 | continue 207 | 208 | region = [(y, x, idx, topk_values[i])] 209 | visited.add(idx) 210 | queue = [(y, x, idx, topk_values[i])] 211 | 212 | # BFS 213 | while queue: 214 | cy, cx, c_idx, c_val = queue.pop(0) 215 | 216 | # check four directions 217 | for dy, dx in [(-1, 0), (1, 0), (0, -1), (0, 1)]: 218 | ny, nx = cy + dy, cx + dx 219 | n_idx = ny * n_width + nx 220 | 221 | # check whether the new coordinates are within bounds 222 | for j, (ty, tx, t_idx) in enumerate(topk_coords): 223 | if ty == ny and tx == nx and t_idx not in visited: 224 | visited.add(t_idx) 225 | region.append((ny, nx, t_idx, topk_values[j])) 226 | queue.append((ny, nx, t_idx, topk_values[j])) 227 | 228 | regions.append(region) 229 | 230 | region_scores = [] 231 | region_centers = [] 232 | region_points = [] 233 | 234 | for region in regions: 235 | # calculate the average score of the region 236 | avg_score = sum(item[3] for item in region) / len(region) 237 | region_scores.append(avg_score) 238 | 239 | 240 | # calculate the normalized center of the region 241 | normalized_centers = [] 242 | weights = [] 243 | y_coords = set() 244 | x_coords = set() 245 | 246 | for y, x, _, score in region: 247 | center_y = (y + 0.5) / n_height 248 | center_x = (x + 0.5) / n_width 249 | normalized_centers.append((center_x, center_y)) 250 | weights.append(score) 251 | 252 | 253 | y_coords.add(center_y) 254 | x_coords.add(center_x) 255 | 256 | 257 | region_points.append(normalized_centers) 258 | 259 | 260 | # calculate the average center of the region 261 | if not rect_center: 262 | # weighted average 263 | total_weight = sum(weights) 264 | weighted_x = sum(nc[0] * w for nc, w in zip(normalized_centers, weights)) / total_weight 265 | weighted_y = sum(nc[1] * w for nc, w in zip(normalized_centers, weights)) / total_weight 266 | avg_center_x, avg_center_y = weighted_x, weighted_y 267 | else: 268 | avg_center_x = sum(x_coords) / len(x_coords) 269 | avg_center_y = sum(y_coords) / len(y_coords) 270 | region_centers.append((avg_center_x, avg_center_y)) 271 | 272 | # select top regions based on scores 273 | sorted_indices = sorted(range(len(region_scores)), key=lambda i: region_scores[i], reverse=True) 274 | sorted_scores = 
[region_scores[i] for i in sorted_indices] 275 | sorted_centers = [region_centers[i] for i in sorted_indices] 276 | sorted_points = [region_points[i] for i in sorted_indices] 277 | best_point = sorted_centers[0] 278 | 279 | 280 | if no_groups: 281 | if return_all_regions: 282 | return sorted_centers + [[(x[1] + 0.5) / n_width, (x[0] + 0.5) /n_height] for x in topk_coords] 283 | else: 284 | return sorted_centers + [(topk_coords[0][1]+ 0.5) / n_width, (topk_coords[0][0]+ 0.5) / n_height] 285 | 286 | if return_all_regions: 287 | return best_point, sorted_centers, sorted_scores, sorted_points 288 | else: 289 | return best_point 290 | 291 | 292 | 293 | 294 | def ground_only_positive(self, instruction, image, target_point): 295 | if isinstance(image, str): 296 | image_path = image 297 | assert os.path.exists(image_path) and os.path.isfile(image_path), "Invalid input image path." 298 | image = Image.open(image_path).convert('RGB') 299 | else: 300 | assert isinstance(image, Image.Image) 301 | image_path = image_to_temp_filename(image) 302 | 303 | width, height = image.size 304 | 305 | print(image_path) 306 | if 'v2' in image_path: 307 | key = image_path.split('/')[-1] 308 | elif 'Pro' in image_path: 309 | key = '/'.join(image_path.split('/')[-2:]) 310 | else: 311 | key = image_path.split('/')[-1] 312 | key += instruction 313 | index = self.json_index_dict[key] 314 | 315 | 316 | if self.method == 'best_one': 317 | predictions = self.json_prediction[index]['topk_points'] 318 | predictions = [predictions[0]] # only the first one 319 | else: 320 | attn_scores = self.json_prediction[index]['attn_scores'] 321 | if 'n_width' in self.json_prediction[index]: 322 | n_width, n_height = self.json_prediction[index]['n_width'], self.json_prediction[index]['n_height'] 323 | elif 'img_size_crop' in self.json_prediction[index]: 324 | n_width, n_height = self.json_prediction[index]['img_size_crop'] 325 | else: 326 | raise ValueError("Invalid json prediction format. 'n_width' or 'img_size_crop' not found.") 327 | predictions = self.get_prediction_region_point(attn_scores, n_width, n_height, top_n=20, return_all_regions=True, rect_center=False, no_groups=True) 328 | 329 | pred_points_list = [[pred[0] * image.size[0], pred[1] * image.size[1]] for pred in predictions] 330 | score_list = [] 331 | 332 | 333 | print(predictions, len(predictions)) 334 | if len(predictions) > 1: 335 | if self.method == 'score': 336 | for point in pred_points_list[len(score_list):]: 337 | score = self.verifier_score(instruction, image, point) 338 | score_list.append(score) 339 | if score >= self.threshold: 340 | break 341 | # get the max score 342 | print(score_list, len(score_list)) 343 | point = predictions[score_list.index(max(score_list))] 344 | else: 345 | point = predictions[0] 346 | 347 | 348 | result_dict = { 349 | "result": "positive", 350 | "format": "x1y1x2y2", 351 | "raw_response": pred_points_list, 352 | "bbox": None, 353 | "point": point, 354 | } 355 | return result_dict 356 | 357 | 358 | 359 | 360 | 361 | 362 | 363 | --------------------------------------------------------------------------------
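Usage sketch for the verifier scoring path: the snippet below shows, under stated assumptions, how GroundingVerifier from verifier_model.py can be exercised on a single candidate point outside of the eval_ss_with_verifier.py harness. The screenshot path, instruction, candidate point, and predictions JSON path are placeholders (the constructor asserts that the predictions file exists); the class name, constructor arguments, and method names come from the file above.

from PIL import Image
from verifier_model import GroundingVerifier

# Placeholder predictions JSON: GroundingVerifier requires an existing file here,
# e.g. one produced by the evaluation runs referenced in run_ss_*.sh.
verifier = GroundingVerifier(
    model_name_or_path="microsoft/GUI-Actor-Verifier-2B",
    json_prediction="json_data/screenspot_all_preds.json",
    method="score",
)
verifier.load_model("microsoft/GUI-Actor-Verifier-2B")  # loads the Qwen2-VL verifier, tokenizer, and processor

image = Image.open("screenshot.png").convert("RGB")  # placeholder screenshot
candidate_point = [640, 360]  # placeholder click prediction in pixel coordinates

# verifier_score() draws a hollow red circle at the point, crops a window around it,
# and returns P("True") / (P("True") + P("False")) from the verifier's next-token logits.
score = verifier.verifier_score("click the search button", image, candidate_point)
print(f"verifier score: {score:.3f}")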