├── .github └── workflows │ └── stale.yml ├── .gitignore ├── LICENSE ├── README.md ├── README_zh.md ├── assets ├── logo.png └── test.jpeg ├── eval ├── README.md ├── eval_data │ ├── process_ac.py │ ├── process_aitz.py │ ├── process_odyssey.py │ ├── readme.md │ └── requirements.txt ├── grounding_eval │ ├── README.md │ ├── code │ │ ├── Aguvis │ │ │ ├── bbox2text_eval_aguvis.py │ │ │ ├── fun2bbox_eval_aguvis.py │ │ │ └── text2bbox_eval_aguvis.py │ │ ├── GPT-4o │ │ │ ├── bbox2text_eval.py │ │ │ ├── fun2bbox_eval_click.py │ │ │ ├── fun2bbox_eval_gpt-4o_with_grounding.py │ │ │ ├── process_image.py │ │ │ ├── text2bbox_eval_click.py │ │ │ └── text2bbox_eval_gpt-4o_with_grounding.py │ │ ├── Intern2.5-VL │ │ │ ├── evaluate_grounding_bbox2text.py │ │ │ ├── evaluate_grounding_fun2bbox.py │ │ │ └── evaluate_grounding_text2bbox.py │ │ ├── OS-Altas │ │ │ ├── bbox2text_eval_osatlas.py │ │ │ ├── fun2bbox_eval_osatlas.py │ │ │ └── text2bbox_eval_osatlas.py │ │ ├── OS-genesis │ │ │ ├── bbox2text_eval_osgenesis.py │ │ │ ├── fun2bbox_eval_osgenesis.py │ │ │ └── text2bbox_eval_osgenesis.py │ │ ├── Qwen2.5-VL │ │ │ ├── bbox2text_eval_qwen.py │ │ │ ├── fun2bbox_eval_qwen.py │ │ │ └── text2bbox_eval_qwen.py │ │ ├── UI-TARS │ │ │ ├── bbox2text_eval_uitars.py │ │ │ ├── fun2bbox_eval_uitars.py │ │ │ └── text2bbox_eval_uitars.py │ │ └── minicpm │ │ │ ├── bbox2text_eval_minicpm.py │ │ │ ├── fun2bbox_eval_minicpm.py │ │ │ └── text2bbox_eval_minicpm.py │ └── dataset │ │ └── code │ │ ├── cap.jsonl │ │ └── ocr.jsonl ├── run_eval_agent.py ├── run_predict_aguvis.py ├── run_predict_minicpm.py ├── run_predict_odyssey.py ├── run_predict_os_atlas.py ├── run_predict_os_gensis.py ├── run_predict_qwen2_5VL.py ├── run_predict_ui_tars.py └── utils │ ├── SimHei.ttf │ ├── action_type.py │ ├── action_utils.py │ ├── convert_output.py │ ├── evaluator.py │ ├── qwen_mobile_tool.py │ ├── schema │ ├── schema.json │ ├── schema_for_extraction.json │ └── test_schema.py │ ├── utils.py │ ├── utils_odyssey │ ├── config.json │ ├── configuration_qwen.py │ ├── generation_config.json │ ├── his_index.json │ ├── model.safetensors.index.json │ ├── modeling_qwen.py │ ├── pytorch_model.bin.index.json │ ├── qwen.tiktoken │ ├── qwen_generation_utils.py │ ├── special_tokens_map.json │ ├── tokenization_qwen.py │ ├── tokenizer_config.json │ └── visual.py │ └── utils_qwen │ └── agent_function_call.py ├── model └── README.md ├── requirements.txt ├── rft ├── config_files │ ├── ds.yml │ ├── ds_dst.yml │ ├── fsdp.yml │ ├── fsdp2_dst.yml │ ├── fsdp_dst.yml │ ├── hostfile │ ├── zero.json │ ├── zero2.json │ └── zero3.json ├── configs.py ├── fsdp.sh ├── grpo.py ├── readme.md ├── requirements.txt └── trainer │ ├── __init__.py │ ├── arl.py │ ├── utils │ ├── __init__.py │ ├── dataloader.py │ ├── dataset.py │ ├── gui_eval.py │ └── process.py │ └── zmq.py └── sft ├── __init__.py ├── dataset.py ├── ds_config_zero2.json ├── ds_config_zero3.json ├── ds_config_zero3_offload.json ├── finetune.py ├── finetune_ds.sh ├── finetune_lora.sh ├── readme.md └── trainer.py /.github/workflows/stale.yml: -------------------------------------------------------------------------------- 1 | name: Close inactive issues 2 | on: 3 | schedule: 4 | - cron: "30 1 * * *" 5 | 6 | jobs: 7 | close-issues: 8 | runs-on: ubuntu-latest 9 | permissions: 10 | issues: write 11 | pull-requests: write 12 | steps: 13 | - uses: actions/stale@v9 14 | with: 15 | days-before-issue-stale: 30 16 | days-before-issue-close: 14 17 | stale-issue-label: "stale" 18 | stale-issue-message: "This issue is stale because it has been 
open for 30 days with no activity." 19 | close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale." 20 | days-before-pr-stale: -1 21 | days-before-pr-close: -1 22 | repo-token: ${{ secrets.GITHUB_TOKEN }} 23 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | share/python-wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | MANIFEST 30 | 31 | # PyInstaller 32 | # Usually these files are written by a python script from a template 33 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | db.sqlite3-journal 65 | 66 | # Flask stuff: 67 | instance/ 68 | .webassets-cache 69 | 70 | # Scrapy stuff: 71 | .scrapy 72 | 73 | # Sphinx documentation 74 | docs/_build/ 75 | 76 | # PyBuilder 77 | .pybuilder/ 78 | target/ 79 | 80 | # Jupyter Notebook 81 | .ipynb_checkpoints 82 | 83 | # IPython 84 | profile_default/ 85 | ipython_config.py 86 | 87 | # pyenv 88 | # For a library or package, you might want to ignore these files since the code is 89 | # intended to run in multiple environments; otherwise, check them in: 90 | # .python-version 91 | 92 | # pipenv 93 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 94 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 95 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 96 | # install all needed dependencies. 97 | #Pipfile.lock 98 | 99 | # UV 100 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 101 | # This is especially recommended for binary packages to ensure reproducibility, and is more 102 | # commonly ignored for libraries. 103 | #uv.lock 104 | 105 | # poetry 106 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 107 | # This is especially recommended for binary packages to ensure reproducibility, and is more 108 | # commonly ignored for libraries. 109 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 110 | #poetry.lock 111 | 112 | # pdm 113 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 114 | #pdm.lock 115 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 116 | # in version control. 117 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 118 | .pdm.toml 119 | .pdm-python 120 | .pdm-build/ 121 | 122 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 123 | __pypackages__/ 124 | 125 | # Celery stuff 126 | celerybeat-schedule 127 | celerybeat.pid 128 | 129 | # SageMath parsed files 130 | *.sage.py 131 | 132 | # Environments 133 | .env 134 | .venv 135 | env/ 136 | venv/ 137 | ENV/ 138 | env.bak/ 139 | venv.bak/ 140 | 141 | # Spyder project settings 142 | .spyderproject 143 | .spyproject 144 | 145 | # Rope project settings 146 | .ropeproject 147 | 148 | # mkdocs documentation 149 | /site 150 | 151 | # mypy 152 | .mypy_cache/ 153 | .dmypy.json 154 | dmypy.json 155 | 156 | # Pyre type checker 157 | .pyre/ 158 | 159 | # pytype static type analyzer 160 | .pytype/ 161 | 162 | # Cython debug symbols 163 | cython_debug/ 164 | 165 | # PyCharm 166 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 167 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 168 | # and can be added to the global gitignore or merged into this file. For a more nuclear 169 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 170 | #.idea/ 171 | 172 | # Ruff stuff: 173 | .ruff_cache/ 174 | 175 | # PyPI configuration file 176 | .pypirc 177 | -------------------------------------------------------------------------------- /assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenBMB/AgentCPM-GUI/b3dbe5c68643858351fbf10e0ce2a8922e83bf8e/assets/logo.png -------------------------------------------------------------------------------- /assets/test.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenBMB/AgentCPM-GUI/b3dbe5c68643858351fbf10e0ce2a8922e83bf8e/assets/test.jpeg -------------------------------------------------------------------------------- /eval/eval_data/process_aitz.py: -------------------------------------------------------------------------------- 1 | import json, ast, sys, os 2 | from pathlib import Path 3 | from PIL import Image 4 | 5 | current_file_path = os.path.abspath(__file__) 6 | current_dir = os.path.dirname(current_file_path) 7 | 8 | root = Path(os.path.join(current_dir, "aitz_test/test")) 9 | img_key = 'image_path' # field that points to the image 10 | 11 | for jp in root.rglob('*.json'): 12 | with open(jp, 'r', encoding='utf-8') as f: 13 | obj = json.load(f) 14 | 15 | recs = obj if isinstance(obj, list) else [obj] 16 | 17 | for r in recs: 18 | if img_key not in r: 19 | continue 20 | 21 | img = root / Path(r[img_key]) 22 | 23 | w, h = Image.open(img).size 24 | c = len(Image.open(img).getbands()) 25 | r['image_height'] = h 26 | r['image_width'] = w 27 | r['image_channels'] = c 28 | 29 | if 'ui_positions' in r: 30 | try: 31 | pos = ast.literal_eval(r['ui_positions']) 32 | norm = [[y/h, x/w, hh/h, ww/w] for y, x, hh, ww in pos] 33 | r['ui_positions'] = json.dumps(norm, ensure_ascii=False) 34 | except Exception: 35 | pass 36 | 37 | with open(jp, 'w', encoding='utf-8') as f: 38 | json.dump(obj, f, ensure_ascii=False, indent=2) 39 | 40 | print('Done') -------------------------------------------------------------------------------- /eval/eval_data/process_odyssey.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import json 4 | from copy import deepcopy 5 | from tqdm import tqdm 6 | from pathlib import Path 7 | 8 | 9 | current_file_path = os.path.abspath(__file__) 10 | 
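# process_aitz.py above stores `image_width`/`image_height`/`image_channels` on every record
# and rewrites `ui_positions` as fractions of the image size ([y/h, x/w, hh/h, ww/w]), so the
# evaluators can compare boxes independently of screen resolution. A minimal sketch of the
# inverse mapping, handy for spot-checking a converted record (hypothetical helper, not part
# of this repo):
#
#   def denormalize_ui_positions(norm_boxes, image_width, image_height):
#       """Map normalized [y, x, h, w] boxes back to pixel coordinates."""
#       return [[y * image_height, x * image_width, hh * image_height, ww * image_width]
#               for y, x, hh, ww in norm_boxes]
#
#   # e.g. denormalize_ui_positions(json.loads(r['ui_positions']), r['image_width'], r['image_height'])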
current_dir = os.path.dirname(current_file_path) 11 | split_data_path = os.path.join(current_dir, 'tmp/GUI-Odyssey/test_anno/random_split.json') 12 | 13 | out_dir = Path(os.path.join(current_dir, 'odyssey/test/odyssey')) 14 | out_dir.mkdir(parents=True, exist_ok=True) 15 | 16 | def transform_action_data(action_data:dict) -> dict: 17 | 18 | # directly retrieve relevant information from the data. 19 | episode_id:str = action_data["image"].split('/')[-1].split("_")[0] 20 | step_id:str = action_data["image"].split('/')[-1].split("_")[1][:-4] 21 | episode_length:str = action_data["step_length"] 22 | image_path:str = action_data["image"] 23 | instruction:str = action_data["question"] 24 | answer:str = action_data["answer"] 25 | 26 | # Get the picture information (w/h) of the data. They are restored under annotation/*.json. 27 | with open(os.path.join(current_dir, "tmp/GUI-Odyssey/annotations", f"{episode_id}.json"), "r", encoding="utf-8") as f: 28 | data = json.load(f)["device_info"] 29 | image_width:int = data["w"] 30 | image_height:int = data["h"] 31 | 32 | # first intialize the variables. 33 | result_action_text = "" 34 | result_touch_yx = [-1.0,-1.0] 35 | result_lift_yx = [-1.0,-1.0] 36 | duration = 0 37 | result_action_type:int = 2 38 | 39 | # case: Click. 40 | if answer.startswith("CLICK"): 41 | result_action_type = 4 42 | x,y = list(map(float, answer[8:-1].split(","))) 43 | result_touch_yx = [y / 1000, x / 1000] # get the ratio. 44 | result_lift_yx = deepcopy(result_touch_yx) # same point. 45 | 46 | if answer.startswith("SCROLL"): 47 | result_action_type = 4 48 | 49 | result_touch_yx = [0.5, 0.5] 50 | result_lift_yx = deepcopy(result_touch_yx) 51 | 52 | # Manually create the end point as odyssey didn't give us an exact coordinate. 53 | if answer.endswith("UP"): 54 | result_lift_yx[0] -= 0.1 55 | 56 | elif answer.endswith("DOWN"): 57 | result_lift_yx[0] += 0.1 58 | 59 | elif answer.endswith("LEFT"): 60 | result_lift_yx[1] -= 0.1 61 | 62 | else: 63 | result_lift_yx[1] += 0.1 64 | 65 | if answer.startswith("LONG_PRESS"): 66 | result_action_type = 0 67 | 68 | x,y = list(map(float, answer[13:-1].split(","))) 69 | result_touch_yx = [y / 1000, x / 1000] 70 | 71 | result_lift_yx = result_lift_yx = deepcopy(result_touch_yx) 72 | 73 | if answer.startswith("TYPE"): 74 | result_action_type = 3 75 | result_action_text:str = answer[5:].strip() 76 | 77 | if answer.startswith("PRESS_HOME"): 78 | result_action_type = 6 79 | 80 | if answer.startswith("PRESS_BACK"): 81 | result_action_type = 5 82 | 83 | if answer.startswith("PRESS_RECENT"): 84 | result_action_type = 6 # mapping as HOME. 85 | 86 | if answer.startswith("COMPLETE"): 87 | result_action_type = 10 88 | 89 | if answer.startswith("IMPOSSIBLE"): 90 | result_action_type = 11 91 | 92 | data = { 93 | "episode_id": episode_id, 94 | "step_id": step_id, 95 | "episode_length": episode_length, 96 | "image_width": image_width, 97 | "image_height": image_height, 98 | "image_path": image_path, 99 | "instruction": instruction, 100 | "result_action_type": result_action_type, 101 | "result_touch_yx": str(result_touch_yx), 102 | "result_lift_yx": str(result_lift_yx), 103 | "duration": duration, # ignore the duration. 104 | "result_action_text": result_action_text, 105 | "ui_positions": "", 106 | "low_instruction": "" 107 | } 108 | 109 | return data 110 | 111 | # Construct the data. 
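# transform_action_data() above reshapes one raw GUI-Odyssey record into the common evaluation
# schema (numeric action type, normalized touch/lift points, free-form text). A commented usage
# sketch with a hypothetical record — the raw answer string format is assumed from the slicing
# above, and annotations/<episode_id>.json must exist so the device width/height can be read:
#
#   raw = {
#       "image": "screenshots/1234567890_5.png",   # -> episode_id "1234567890", step_id "5"
#       "step_length": 12,
#       "question": "Open the settings app",
#       "answer": "CLICK: (500, 300)",              # x=500, y=300 on a 0-1000 grid
#   }
#   step = transform_action_data(raw)
#   # step["result_action_type"] == 4 (click), and
#   # step["result_touch_yx"] == "[0.3, 0.5]"  (y/1000, x/1000); the lift point is identical.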
112 | with open(split_data_path, "r", encoding="utf-8") as f: 113 | eval_data_raw:dict = json.load(f) 114 | 115 | data_eval = [transform_action_data(data) for data in tqdm(eval_data_raw)] 116 | data_eval = [d for d in data_eval if d is not None] 117 | 118 | 119 | # save data 120 | def dump_traj(traj, out_root: Path, idx: int): 121 | if not traj: 122 | return 123 | 124 | subfolder_name = f"traj_{idx:05d}" 125 | subfolder_path = out_root / subfolder_name 126 | subfolder_path.mkdir(parents=True, exist_ok=True) 127 | 128 | out_filename = f"{subfolder_name}.json" 129 | out_file = subfolder_path / out_filename 130 | 131 | with out_file.open("w", encoding="utf-8") as fw: 132 | json.dump(traj, fw, ensure_ascii=False, indent=2) 133 | 134 | print(f"Save {subfolder_name}/{out_filename} (steps={len(traj)})") 135 | 136 | 137 | records = data_eval 138 | traj = [] 139 | traj_idx = 1 140 | prev_step_id, curr_instr = None, None 141 | 142 | for rec in records: 143 | step_id = int(rec["step_id"]) 144 | instr = rec["instruction"] 145 | rec['subset'] = 'odyssey' 146 | rec['step_id'] = int(rec['step_id']) 147 | rec['ui_positions'] = "[]" 148 | 149 | new_traj = ( 150 | curr_instr is None or 151 | instr != curr_instr or 152 | prev_step_id is None or 153 | step_id != prev_step_id + 1 154 | ) 155 | 156 | if new_traj and traj: 157 | dump_traj(traj, out_dir, traj_idx) 158 | traj_idx += 1 159 | traj = [] 160 | 161 | traj.append(rec) 162 | prev_step_id = step_id 163 | curr_instr = instr 164 | 165 | dump_traj(traj, out_dir, traj_idx) 166 | print("all done.") -------------------------------------------------------------------------------- /eval/eval_data/readme.md: -------------------------------------------------------------------------------- 1 | # Data Processing Scripts 2 | 3 | ``` 4 | # Setup environment 5 | 6 | cd AgentCPM-GUI/eval/eval_data 7 | conda create -n process_data python=3.11 8 | conda activate process_data 9 | pip install -r requirements.txt 10 | 11 | mkdir tmp && cd tmp 12 | git clone https://github.com/deepmind/android_env/ 13 | cd android_env; pip install . 14 | ``` 15 | 16 | ## Android Control 17 | 18 | Download [Android Control](https://github.com/google-research/google-research/tree/master/android_control) and save at ``AgentCPM-GUI/eval/eval_data/tmp/android_control`` 19 | 20 | ``` 21 | cd AgentCPM-GUI/eval/eval_data 22 | python process_ac.py 23 | ln -s android_control_test android_control_high_test 24 | ln -s android_control_test android_control_low_test 25 | ``` 26 | 27 | ## CAGUI 28 | 29 | ``` 30 | cd AgentCPM-GUI/eval/eval_data 31 | mkdir chinese_app_test && cd chinese_app_test 32 | huggingface-cli download openbmb/CAGUI --repo-type dataset --include "CAGUI_agent/**" --local-dir ./ --local-dir-use-symlinks False --resume-download 33 | mv CAGUI_agent test 34 | ``` 35 | 36 | ## aitz 37 | 38 | Download [aitz](https://github.com/IMNearth/CoAT) and save at ``AgentCPM-GUI/eval/eval_data/tmp/android_in_the_zoo`` 39 | 40 | ``` 41 | cd AgentCPM-GUI/eval/eval_data 42 | mv tmp/android_in_the_zoo ./aitz_test 43 | python process_aitz.py 44 | ``` 45 | 46 | ## gui-odyssey 47 | 48 | Download [GUI-Odyssey](https://github.com/OpenGVLab/GUI-Odyssey?tab=readme-ov-file) and save at ``AgentCPM-GUI/eval/eval_data/tmp/GUI-Odyssey``. 
Copy [preprocessing.py](https://github.com/OpenGVLab/GUI-Odyssey/blob/master/data/preprocessing.py) and [format_converter.py](https://github.com/OpenGVLab/GUI-Odyssey/blob/master/data/format_converter.py) from the GUI-Odyssey repo to ``AgentCPM-GUI/eval/eval_data/tmp/GUI-Odyssey`` 49 | 50 | ``` 51 | cd AgentCPM-GUI/eval/eval_data/tmp/GUI-Odyssey 52 | python preprocessing.py 53 | python format_converter.py 54 | python ../../process_odyssey.py 55 | ``` 56 | -------------------------------------------------------------------------------- /eval/eval_data/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==2.2.2 2 | astunparse==1.6.3 3 | attrs==25.3.0 4 | certifi==2025.4.26 5 | charset-normalizer==3.4.2 6 | dm-env==1.6 7 | dm-tree==0.1.9 8 | filelock==3.18.0 9 | flatbuffers==25.2.10 10 | fsspec==2025.3.2 11 | gast==0.6.0 12 | google-pasta==0.2.0 13 | grpcio==1.71.0 14 | h5py==3.13.0 15 | hf-xet==1.1.0 16 | huggingface-hub==0.31.1 17 | idna==3.10 18 | keras==3.9.2 19 | libclang==18.1.1 20 | Markdown==3.8 21 | markdown-it-py==3.0.0 22 | MarkupSafe==3.0.2 23 | mdurl==0.1.2 24 | ml_dtypes==0.5.1 25 | namex==0.0.9 26 | numpy==2.1.3 27 | opt_einsum==3.4.0 28 | optree==0.15.0 29 | packaging==25.0 30 | pillow==11.2.1 31 | portpicker==1.6.0 32 | protobuf==5.29.4 33 | psutil==7.0.0 34 | pygame==2.6.1 35 | Pygments==2.19.1 36 | PyYAML==6.0.2 37 | requests==2.32.3 38 | rich==14.0.0 39 | six==1.17.0 40 | tensorboard==2.19.0 41 | tensorboard-data-server==0.7.2 42 | tensorflow==2.19.0 43 | tensorflow-io-gcs-filesystem==0.37.1 44 | termcolor==3.1.0 45 | tqdm==4.67.1 46 | typing_extensions==4.13.2 47 | urllib3==2.4.0 48 | Werkzeug==3.1.3 49 | wrapt==1.17.2 50 | -------------------------------------------------------------------------------- /eval/grounding_eval/README.md: -------------------------------------------------------------------------------- 1 | # README for Evaluation 2 | 3 | Here, we have listed all the evaluation codes. Since each model has different image processing and action spaces, we have organized the evaluation codes by model. 4 | 5 | ## General Notification 6 | 7 | ### step 1 8 | Please first download the corresponding images and replace the empty folder `eval/grounding_eval/dataset/images`. 9 | 10 | ### step 2 11 | We recommend using vLLM as the engine for inference of most open-source models (except InternVL) to ensure inference speed. The specific command is as follows: 12 | ```code 13 | python -m vllm.entrypoints.openai.api_server --model /path/to/your/model --served-model-name name_of_your_model --tensor-parallel-size 4 14 | ``` 15 | 16 | ### step 3 17 | Modify the `json_data_path` variable in the script to the path of your dataset JSONL file, and then run the script: 18 | ``` 19 | python your_evaluation_script_name.py 20 | ``` 21 | 22 | ## Notification for Special Models 23 | 24 | ### InternVL series 25 | For the InternVL series of models, we rely on the model loading code provided in their open-source repository for inference. Before running the evaluation code, you need to clone the InternVL open-source repository and install the dependencies. Then, place the inference code under the path `InternVL/internvl_chat/eval`. 
26 | After that, run the command: 27 | ``` 28 | torchrun --nproc_per_node=8 path/to/your/evaluate_script.py --checkpoint ${CHECKPOINT} --dynamic 29 | ``` 30 | 31 | ### GPT-4o with grounding 32 | For GPT-4o, in addition to testing its direct grounding capabilities, we also use the Omni-parser to draw bounding boxes on the components in the images, allowing GPT-4o to select the most similar bounding box and then calculate the IoU. 33 | 34 | To reproduce this result, you need to first use Omni-parser to process the images, put the script `grounding_eval/GPT-4o/process_image.py` in following position: 35 | 36 | ``` 37 | Omniparser 38 | ├──docs 39 | ... 40 | ├──weights 41 | ├──utils 42 | └──process_image.py 43 | ``` 44 | run the code and save the annotated image and bounding box in `your/path/to/annotated/image` 45 | 46 | Finally, modify the `image_data_path` variable in the script to `your/path/to/annotated/image`, and run the code. 47 | -------------------------------------------------------------------------------- /eval/grounding_eval/code/Aguvis/bbox2text_eval_aguvis.py: -------------------------------------------------------------------------------- 1 | import json 2 | from PIL import Image 3 | import base64 4 | import re 5 | from openai import AsyncClient 6 | import os 7 | import asyncio 8 | 9 | def read_jsonl(file_path): 10 | """ 11 | 读取 JSONL 文件并解析为 Python 字典列表 12 | :param file_path: JSONL 文件路径 13 | :return: 包含所有 JSON 对象的列表 14 | """ 15 | data = [] 16 | with open(file_path, "r", encoding="utf-8") as f: 17 | for line in f: 18 | data.append(json.loads(line.strip())) 19 | return data 20 | 21 | def load_image_from_path(image_path): 22 | """ 23 | 从指定路径加载图片 24 | :param image_path: 图片文件路径 25 | :return: PIL.Image 对象 26 | """ 27 | try: 28 | image = Image.open(image_path) 29 | return image 30 | except Exception as e: 31 | print(f"Error loading image from {image_path}: {e}") 32 | return None 33 | 34 | # Function to encode the image 35 | def encode_image(image_path): 36 | image = Image.open(image_path) 37 | w, h = image.width, image.height 38 | with open(image_path, "rb") as image_file: 39 | return base64.b64encode(image_file.read()).decode("utf-8"),w, h 40 | 41 | async def call_Qwenvl(item, client, model_name): 42 | """ 43 | 调用Qwen输出function的描述,输出bbox 44 | """ 45 | sys_prompt = '''你是一个GUI组件文字识别的专家,擅长根据组件的边界框(bounding box)描述输出对应的文字。你的任务是根据给定的GUI截图和图中某个组件的边界框输出组件的中的文字。 46 | 输入:屏幕截图,边界框的相对坐标,的格式表示 47 | 输出:组件中的文本,注意是文字而非坐标! 
48 | 示例输出一:可口可乐。 49 | 示例输出二:关注''' 50 | image_path = item["image"] 51 | base64_image, w, h = encode_image(image_path) 52 | content = [] 53 | # 动态添加base64_image部分到 content 列表 54 | content.append({ 55 | "type": "image_url", 56 | "image_url": { 57 | "url": f"data:image/jpeg;base64,{base64_image}", 58 | }, 59 | }) 60 | content.append({ 61 | "type": "text", 62 | "text": "屏幕上某一组件的边界框:{}".format(item["rel_position"]) 63 | }) 64 | 65 | res = await client.chat.completions.create( 66 | messages=[ 67 | { 68 | "role": "system", 69 | "content": sys_prompt, 70 | }, 71 | { 72 | "role":"user", 73 | "content": content, 74 | } 75 | ], 76 | model=model_name, 77 | max_tokens = 256 78 | ) 79 | 80 | response = res.choices[0].message.content 81 | return response.strip("\n").replace(" ","") 82 | 83 | async def verify(response, ground_truth): 84 | """ 85 | 接受模型的字符串输入,判断是否正确 86 | """ 87 | if response == ground_truth: 88 | return 1 89 | else: 90 | return 0 91 | 92 | async def process_item_async(item, client, model_name, semaphore): 93 | async with semaphore: 94 | response = await call_Qwenvl(item, client, model_name) 95 | correct = await verify(response,item["text"]) 96 | return correct 97 | 98 | async def main(): 99 | model_name = "aguivs" 100 | total = 0 101 | correct = 0 102 | json_data_path = "your/path/to/the/dataset" 103 | data = read_jsonl(json_data_path) 104 | client = AsyncClient(api_key="sk-123", base_url='http://localhost:8000/v1') 105 | semaphore = asyncio.Semaphore(16) 106 | tasks = [] 107 | for item in data: 108 | task = asyncio.create_task(process_item_async(item, client, model_name,semaphore)) 109 | tasks.append(task) 110 | 111 | results = await asyncio.gather(*tasks) 112 | for result in results: 113 | correct += result 114 | total += 1 115 | print(correct, total, correct / total) 116 | 117 | return 0 118 | 119 | if __name__=="__main__": 120 | asyncio.run(main()) -------------------------------------------------------------------------------- /eval/grounding_eval/code/Aguvis/fun2bbox_eval_aguvis.py: -------------------------------------------------------------------------------- 1 | import json 2 | from PIL import Image 3 | import base64 4 | import re 5 | from openai import AsyncClient 6 | import os 7 | import asyncio 8 | import traceback 9 | 10 | def read_jsonl(file_path): 11 | """ 12 | 13 | 读取 JSONL 文件并解析为 Python 字典列表 14 | :param file_path: JSONL 文件路径 15 | :return: 包含所有 JSON 对象的列表 16 | """ 17 | data = [] 18 | with open(file_path, "r", encoding="utf-8") as f: 19 | for line in f: 20 | data.append(json.loads(line.strip())) 21 | return data 22 | 23 | def load_image_from_path(image_path): 24 | """ 25 | 从指定路径加载图片 26 | :param image_path: 图片文件路径 27 | :return: PIL.Image 对象 28 | """ 29 | try: 30 | image = Image.open(image_path) 31 | return image 32 | except Exception as e: 33 | print(f"Error loading image from {image_path}: {e}") 34 | return None 35 | 36 | # Function to encode the image 37 | def encode_image(image_path): 38 | image = Image.open(image_path) 39 | w, h = image.width, image.height 40 | with open(image_path, "rb") as image_file: 41 | return base64.b64encode(image_file.read()).decode("utf-8"),w, h 42 | 43 | async def call_Qwenvl(item, client, model_name): 44 | """ 45 | 调用Qwen输出function的描述,输出bbox 46 | """ 47 | sys_prompt = '''你是一个GUI组件定位的专家,擅长输出图片上文本对应的坐标。你的任务是根据给定的GUI截图和图中某个文本输出该文本的坐标。 48 | 输入:屏幕截图,文本描述 49 | 输出:文本的相对坐标的中心点,以pyautogui.click(x=0.4754, y=0.2062)为格式,使用[]定位,其中不能存在任何非坐标字符,注意中心点应当是两个坐标而不是四个。 50 | ''' 51 | image_path = item["image"] 52 | base64_image, w, h = encode_image(image_path) 53 | 
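    # The request assembled below follows the OpenAI chat-completions format served by vLLM:
    # one image part (base64 data URL) plus one text part carrying the component's functional
    # description. The Aguvis model is expected to reply with a click call such as
    # "pyautogui.click(x=0.4754, y=0.2062)" in relative (0-1) coordinates; verify() later
    # checks that this point falls inside the ground-truth box from `rel_position`.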
content = [] 54 | # 动态添加base64_image部分到 content 列表 55 | content.append({ 56 | "type": "image_url", 57 | "image_url": { 58 | "url": f"data:image/jpeg;base64,{base64_image}", 59 | }, 60 | }) 61 | content.append({ 62 | "type": "text", 63 | "text": "屏幕上某一组件的功能描述:{}".format(item["text"]) 64 | }) 65 | try: 66 | res = await client.chat.completions.create( 67 | messages=[ 68 | { 69 | "role": "system", 70 | "content": sys_prompt, 71 | }, 72 | { 73 | "role":"user", 74 | "content": content, 75 | } 76 | ], 77 | model=model_name, 78 | ) 79 | 80 | response = res.choices[0].message.content 81 | except: 82 | response = 'None' 83 | return response.strip("\n").replace(" ","") 84 | 85 | async def verify(response, ground_truth): 86 | """ 87 | 接受模型的字符串输入,判断是否正确 88 | """ 89 | pattern = r'pyautogui\.click\(x=(.*?),y=(.*?)\)' 90 | matches = re.findall(pattern, response) 91 | # 将输入字符串转换为整数列表 92 | if matches: 93 | match = matches[0] 94 | bbox = list(map(float, match)) 95 | gt_bbox = list(map(float, ground_truth.strip('<>').split(','))) 96 | 97 | # 遍历每个值,检查是否在ground truth对应值的±5范围内 98 | gt_x_min = gt_bbox[0] 99 | gt_x_max = gt_bbox[2] 100 | gt_y_min = gt_bbox[1] 101 | gt_y_max = gt_bbox[3] 102 | if gt_x_min<=bbox[0]<=gt_x_max and gt_y_min<=bbox[1]<=gt_y_max: 103 | return 1 104 | else: 105 | print("wrong response: {}".format(response)) 106 | return 0 107 | 108 | async def process_item_async(item, client, model_name, semaphore): 109 | async with semaphore: 110 | bbox = await call_Qwenvl(item, client, model_name) 111 | correct = await verify(bbox,item["rel_position"]) 112 | return correct 113 | 114 | async def main(): 115 | model_name = "aguivs" 116 | total = 0 117 | correct = 0 118 | json_data_path = "your/path/to/the/dataset" 119 | data = read_jsonl(json_data_path) 120 | client = AsyncClient(api_key="sk-123", base_url='http://localhost:8000/v1') 121 | semaphore = asyncio.Semaphore(16) 122 | tasks = [] 123 | for item in data: 124 | task = asyncio.create_task(process_item_async(item, client, model_name,semaphore)) 125 | tasks.append(task) 126 | 127 | results = await asyncio.gather(*tasks) 128 | for result in results: 129 | correct += result 130 | total += 1 131 | print(correct, total, correct / total) 132 | 133 | return 0 134 | 135 | if __name__=="__main__": 136 | asyncio.run(main()) -------------------------------------------------------------------------------- /eval/grounding_eval/code/Aguvis/text2bbox_eval_aguvis.py: -------------------------------------------------------------------------------- 1 | import json 2 | from PIL import Image 3 | import base64 4 | import re 5 | from openai import AsyncClient 6 | import os 7 | import asyncio 8 | 9 | def read_jsonl(file_path): 10 | """ 11 | 读取 JSONL 文件并解析为 Python 字典列表 12 | :param file_path: JSONL 文件路径 13 | :return: 包含所有 JSON 对象的列表 14 | """ 15 | data = [] 16 | with open(file_path, "r", encoding="utf-8") as f: 17 | for line in f: 18 | data.append(json.loads(line.strip())) 19 | return data 20 | 21 | def load_image_from_path(image_path): 22 | """ 23 | 从指定路径加载图片 24 | :param image_path: 图片文件路径 25 | :return: PIL.Image 对象 26 | """ 27 | try: 28 | image = Image.open(image_path) 29 | return image 30 | except Exception as e: 31 | print(f"Error loading image from {image_path}: {e}") 32 | return None 33 | 34 | # Function to encode the image 35 | def encode_image(image_path): 36 | image = Image.open(image_path) 37 | w, h = image.width, image.height 38 | with open(image_path, "rb") as image_file: 39 | return base64.b64encode(image_file.read()).decode("utf-8"),w, h 40 | 41 | async def 
call_Qwenvl(item, client, model_name): 42 | """ 43 | 调用Qwen输出function的描述,输出bbox 44 | """ 45 | sys_prompt = ''' 46 | 你是一个GUI组件定位的专家,擅长输出图片上文本对应的坐标。你的任务是根据给定的GUI截图和图中某个文本输出该文本的坐标。 47 | 输入:屏幕截图,文本描述 48 | 输出:边界框的坐标,使用click 49 | 示例输出:pyautogui.click(x=0.4754, y=0.2062) 50 | ''' 51 | image_path = item["image"].replace("/home/test/test03","/home/test/test12") 52 | base64_image, w, h = encode_image(image_path) 53 | content = [] 54 | # 动态添加base64_image部分到 content 列表 55 | content.append({ 56 | "type": "image_url", 57 | "image_url": { 58 | "url": f"data:image/jpeg;base64,{base64_image}", 59 | }, 60 | }) 61 | content.append({ 62 | "type": "text", 63 | "text": "屏幕上的文本:{}".format(item["text"]) 64 | }) 65 | 66 | res = await client.chat.completions.create( 67 | messages=[ 68 | { 69 | "role": "system", 70 | "content": sys_prompt, 71 | }, 72 | { 73 | "role":"user", 74 | "content": content, 75 | } 76 | ], 77 | model=model_name, 78 | temperature=0 79 | ) 80 | 81 | response = res.choices[0].message.content 82 | return response.strip("\n").replace(" ","") 83 | 84 | async def verify(response, ground_truth): 85 | """ 86 | 接受模型的字符串输入,判断是否正确 87 | """ 88 | pattern = r'pyautogui\.click\(x=(.*?),y=(.*?)\)' 89 | matches = re.findall(pattern, response) 90 | # 将输入字符串转换为整数列表 91 | if matches: 92 | match = matches[0] 93 | try: 94 | bbox = list(map(float, match)) 95 | gt_bbox = list(map(float, ground_truth.strip('<>').split(','))) 96 | 97 | gt_x_min = gt_bbox[0] 98 | gt_x_max = gt_bbox[2] 99 | gt_y_min = gt_bbox[1] 100 | gt_y_max = gt_bbox[3] 101 | if gt_x_min<=bbox[0]<=gt_x_max and gt_y_min<=bbox[1]<=gt_y_max: 102 | return 1 103 | except: 104 | return 0 105 | else: 106 | print("wrong response: {}".format(response)) 107 | return 0 108 | 109 | async def process_item_async(item, client, model_name, semaphore): 110 | async with semaphore: 111 | bbox = await call_Qwenvl(item, client, model_name) 112 | correct = await verify(bbox,item["rel_position"]) 113 | return correct 114 | 115 | async def main(): 116 | model_name = "aguivs" 117 | total = 0 118 | correct = 0 119 | json_data_path = "your/path/to/the/dataset" 120 | data = read_jsonl(json_data_path) 121 | client = AsyncClient(api_key="sk-123", base_url='http://localhost:8000/v1') 122 | semaphore = asyncio.Semaphore(16) 123 | tasks = [] 124 | for item in data: 125 | task = asyncio.create_task(process_item_async(item, client, model_name,semaphore)) 126 | tasks.append(task) 127 | 128 | results = await asyncio.gather(*tasks) 129 | for result in results: 130 | correct += result 131 | total += 1 132 | print(correct, total, correct / total) 133 | 134 | 135 | if __name__=="__main__": 136 | asyncio.run(main()) -------------------------------------------------------------------------------- /eval/grounding_eval/code/GPT-4o/bbox2text_eval.py: -------------------------------------------------------------------------------- 1 | import json 2 | from PIL import Image 3 | import base64 4 | import re 5 | from openai import AsyncClient 6 | import os 7 | import asyncio 8 | 9 | def read_jsonl(file_path): 10 | """ 11 | 读取 JSONL 文件并解析为 Python 字典列表 12 | :param file_path: JSONL 文件路径 13 | :return: 包含所有 JSON 对象的列表 14 | """ 15 | data = [] 16 | with open(file_path, "r", encoding="utf-8") as f: 17 | for line in f: 18 | data.append(json.loads(line.strip())) 19 | return data 20 | 21 | def load_image_from_path(image_path): 22 | """ 23 | 从指定路径加载图片 24 | :param image_path: 图片文件路径 25 | :return: PIL.Image 对象 26 | """ 27 | try: 28 | image = Image.open(image_path) 29 | return image 30 | except Exception as e: 31 | 
print(f"Error loading image from {image_path}: {e}") 32 | return None 33 | 34 | # Function to encode the image 35 | def encode_image(image_path): 36 | image = Image.open(image_path) 37 | w, h = image.width, image.height 38 | with open(image_path, "rb") as image_file: 39 | return base64.b64encode(image_file.read()).decode("utf-8"),w, h 40 | 41 | async def call_Qwenvl(item, client, model_name): 42 | """ 43 | 调用Qwen输出function的描述,输出bbox 44 | """ 45 | sys_prompt = '''你是一个GUI组件文字识别的专家,擅长根据组件的边界框(bounding box)描述输出对应的文字。你的任务是根据给定的GUI截图和图中某个组件的边界框输出组件的中的文字。 46 | 输入:屏幕截图,边界框的绝对坐标的格式表示,前两者为边界框的绝对坐标,后两者为边界框的大小 47 | 输出:组件中的文本,注意是文字而非坐标! 48 | 示例输出一:可口可乐。 49 | 示例输出二:关注''' 50 | image_path = item["image"] 51 | base64_image, w, h = encode_image(image_path) 52 | content = [] 53 | # 动态添加base64_image部分到 content 列表 54 | content.append({ 55 | "type": "image_url", 56 | "image_url": { 57 | "url": f"data:image/jpeg;base64,{base64_image}", 58 | }, 59 | }) 60 | content.append({ 61 | "type": "text", 62 | "text": "当前屏幕的尺寸为{}*{},屏幕上某一组件的边界框:{}".format(w, h, item["abs_position"]) 63 | }) 64 | 65 | res = await client.chat.completions.create( 66 | messages=[ 67 | { 68 | "role": "system", 69 | "content": sys_prompt, 70 | }, 71 | { 72 | "role":"user", 73 | "content": content, 74 | } 75 | ], 76 | model=model_name, 77 | ) 78 | 79 | response = res.choices[0].message.content 80 | return response.strip("\n").replace(" ","") 81 | 82 | async def verify(response, ground_truth): 83 | """ 84 | 接受模型的字符串输入,判断是否正确 85 | """ 86 | print(response, ground_truth) 87 | if response == ground_truth: 88 | return 1 89 | else: 90 | return 0 91 | 92 | async def process_item_async(item, client, model_name, semaphore): 93 | async with semaphore: 94 | response = await call_Qwenvl(item, client, model_name) 95 | correct = await verify(response,item["text"]) 96 | return correct 97 | 98 | async def main(): 99 | model_name = "gpt-4o" 100 | total = 0 101 | correct = 0 102 | json_data_path = "your/path/to/the/dataset" 103 | data = read_jsonl(json_data_path) 104 | client = AsyncClient(api_key="sk-123") 105 | semaphore = asyncio.Semaphore(16) 106 | tasks = [] 107 | for item in data: 108 | task = asyncio.create_task(process_item_async(item, client, model_name,semaphore)) 109 | tasks.append(task) 110 | 111 | results = await asyncio.gather(*tasks) 112 | for result in results: 113 | correct += result 114 | total += 1 115 | print(correct, total, correct / total) 116 | 117 | return 0 118 | 119 | if __name__=="__main__": 120 | asyncio.run(main()) -------------------------------------------------------------------------------- /eval/grounding_eval/code/GPT-4o/fun2bbox_eval_click.py: -------------------------------------------------------------------------------- 1 | import json 2 | from PIL import Image 3 | import base64 4 | import re 5 | from openai import AsyncClient 6 | import os 7 | import asyncio 8 | import traceback 9 | 10 | def read_jsonl(file_path): 11 | """ 12 | 13 | 读取 JSONL 文件并解析为 Python 字典列表 14 | :param file_path: JSONL 文件路径 15 | :return: 包含所有 JSON 对象的列表 16 | """ 17 | data = [] 18 | with open(file_path, "r", encoding="utf-8") as f: 19 | for line in f: 20 | data.append(json.loads(line.strip())) 21 | return data 22 | 23 | def load_image_from_path(image_path): 24 | """ 25 | 从指定路径加载图片 26 | :param image_path: 图片文件路径 27 | :return: PIL.Image 对象 28 | """ 29 | try: 30 | image = Image.open(image_path) 31 | return image 32 | except Exception as e: 33 | print(f"Error loading image from {image_path}: {e}") 34 | return None 35 | 36 | # Function to encode the image 37 | 
def encode_image(image_path): 38 | image = Image.open(image_path) 39 | w, h = image.width, image.height 40 | with open(image_path, "rb") as image_file: 41 | return base64.b64encode(image_file.read()).decode("utf-8"),w, h 42 | 43 | async def call_Qwenvl(item, client, model_name): 44 | """ 45 | 调用Qwen输出function的描述,输出bbox 46 | """ 47 | sys_prompt = '''你是一个GUI组件定位的专家,擅长根据组件的功能描述输出对应的坐标。你的任务是根据给定的GUI截图和图中某个组件的功能描述输出组件的坐标。 48 | 输入:屏幕截图,功能描述 49 | 输出:边界框的绝对坐标,以为格式,使用<>定位,其中不能存在任何非坐标字符 50 | 示例输出:我认为该UI元素在<600,1000>附近 51 | ''' 52 | image_path = item["image"] 53 | base64_image, w, h = encode_image(image_path) 54 | content = [] 55 | # 动态添加base64_image部分到 content 列表 56 | content.append({ 57 | "type": "image_url", 58 | "image_url": { 59 | "url": f"data:image/jpeg;base64,{base64_image}", 60 | }, 61 | }) 62 | content.append({ 63 | "type": "text", 64 | "text": "当前屏幕的尺寸为{}*{},屏幕上某一组件的功能描述:{}".format(w, h, item["text"]) 65 | }) 66 | try: 67 | res = await client.chat.completions.create( 68 | messages=[ 69 | { 70 | "role": "system", 71 | "content": sys_prompt, 72 | }, 73 | { 74 | "role":"user", 75 | "content": content, 76 | } 77 | ], 78 | model=model_name, 79 | ) 80 | 81 | response = res.choices[0].message.content 82 | except: 83 | response = 'None' 84 | return response.strip("\n").replace(" ","") 85 | 86 | async def verify(response, ground_truth): 87 | """ 88 | 接受模型的字符串输入,判断是否正确 89 | """ 90 | pattern = r'<\d+,\d+>' 91 | matches = re.findall(pattern, response) 92 | # 将输入字符串转换为整数列表 93 | if matches: 94 | match = matches[0] 95 | bbox = list(map(int, match.strip('<>').split(','))) 96 | gt_bbox = list(map(int, ground_truth.strip('<>').split(','))) 97 | 98 | gt_x_min = gt_bbox[0] 99 | gt_x_max = gt_bbox[2] 100 | gt_y_min = gt_bbox[1] 101 | gt_y_max = gt_bbox[3] 102 | print(bbox, gt_bbox) 103 | if gt_x_min<=bbox[0]<=gt_x_max and gt_y_min<=bbox[1]<=gt_y_max: 104 | return 1 105 | else: 106 | print("wrong response: {}".format(response)) 107 | return 0 108 | 109 | async def process_item_async(item, client, model_name, semaphore): 110 | async with semaphore: 111 | bbox = await call_Qwenvl(item, client, model_name) 112 | correct = await verify(bbox,item["abs_position"]) 113 | return correct 114 | 115 | 116 | async def main(): 117 | model_name = "gpt-4o" 118 | total = 0 119 | correct = 0 120 | json_data_path = "your/path/to/the/dataset" 121 | data = read_jsonl(json_data_path) 122 | client = AsyncClient(api_key="sk-123") 123 | semaphore = asyncio.Semaphore(16) 124 | tasks = [] 125 | for item in data: 126 | task = asyncio.create_task(process_item_async(item, client, model_name,semaphore)) 127 | tasks.append(task) 128 | 129 | results = await asyncio.gather(*tasks) 130 | for result in results: 131 | correct += result 132 | total += 1 133 | print(correct, total, correct / total) 134 | 135 | return 0 136 | 137 | if __name__=="__main__": 138 | asyncio.run(main()) -------------------------------------------------------------------------------- /eval/grounding_eval/code/GPT-4o/fun2bbox_eval_gpt-4o_with_grounding.py: -------------------------------------------------------------------------------- 1 | import json 2 | from PIL import Image 3 | import base64 4 | import re 5 | from openai import AsyncClient 6 | import os 7 | import asyncio 8 | import traceback 9 | 10 | def read_jsonl(file_path): 11 | """ 12 | 13 | 读取 JSONL 文件并解析为 Python 字典列表 14 | :param file_path: JSONL 文件路径 15 | :return: 包含所有 JSON 对象的列表 16 | """ 17 | data = [] 18 | with open(file_path, "r", encoding="utf-8") as f: 19 | for line in f: 20 | 
data.append(json.loads(line.strip())) 21 | return data 22 | 23 | def load_image_from_path(image_path): 24 | """ 25 | 从指定路径加载图片 26 | :param image_path: 图片文件路径 27 | :return: PIL.Image 对象 28 | """ 29 | try: 30 | image = Image.open(image_path) 31 | return image 32 | except Exception as e: 33 | print(f"Error loading image from {image_path}: {e}") 34 | return None 35 | 36 | # Function to encode the image 37 | def encode_image(image_path): 38 | image = Image.open(image_path) 39 | w, h = image.width, image.height 40 | with open(image_path, "rb") as image_file: 41 | return base64.b64encode(image_file.read()).decode("utf-8") 42 | 43 | async def call_GPT(item, client, model_name, new_image_dir): 44 | """ 45 | 调用Qwen输出function的描述,输出bbox 46 | """ 47 | system_prompt = ''' 48 | 你是一个GUI组件定位的专家,擅长根据组件的功能描述输出对应的位置。你的任务是根据给定的GUI截图和图中某个组件的功能描述输出与组件最接近的框的编号 49 | 输入:屏幕截图,文本描述 50 | 输出:屏幕截图中框的编号,以为格式 51 | 示例输出一:0 52 | 示例输出二:14 53 | ''' 54 | 55 | old_image_path = item["image"] 56 | old_image_dir = '/'.join(old_image_path.split('/')[:-1]) 57 | new_image_path = item["image"].replace(old_image_dir, new_image_dir) 58 | base64_image = encode_image(new_image_path) 59 | content = [] 60 | # 动态添加base64_image部分到 content 列表 61 | content.append({ 62 | "type": "image_url", 63 | "image_url": { 64 | "url": f"data:image/jpeg;base64,{base64_image}", 65 | }, 66 | }) 67 | content.append({ 68 | "type": "text", 69 | "text": "屏幕上组件的功能描述为:{}".format(item["text"]) 70 | }) 71 | 72 | res = await client.chat.completions.create( 73 | messages=[ 74 | { 75 | "role":"system", 76 | "content": system_prompt 77 | }, 78 | { 79 | "role":"user", 80 | "content": content, 81 | } 82 | ], 83 | model=model_name, 84 | ) 85 | 86 | response = res.choices[0].message.content 87 | return response.strip("\n").replace(" ","") 88 | 89 | async def verify(response, json_list, ground_truth,w, h): 90 | """ 91 | 接受模型的字符串输入,判断是否正确 92 | """ 93 | pattern = r'\d+' 94 | matches = re.findall(pattern, response) 95 | # 将输入字符串转换为整数列表 96 | if matches: 97 | match = matches[0] 98 | bbox_idx = int(match.replace('','').replace('','')) 99 | pre_bbox = json_list[bbox_idx]["bbox"] 100 | 101 | gt_bbox = list(map(int, ground_truth.strip('<>').split(','))) 102 | gt_x_min = gt_bbox[0] 103 | gt_x_max = gt_bbox[2] 104 | gt_y_min = gt_bbox[1] 105 | gt_y_max = gt_bbox[3] 106 | pre_x_min = pre_bbox[0]*w 107 | pre_x_max = pre_bbox[2]*w 108 | pre_y_min = pre_bbox[1]*h 109 | pre_y_max = pre_bbox[3]*h 110 | inter_x_min = max(gt_x_min, pre_x_min) 111 | inter_x_max = min(gt_x_max, pre_x_max) 112 | inter_y_min = max(gt_y_min, pre_y_min) 113 | inter_y_max = min(gt_y_max, pre_y_max) 114 | 115 | # 如果两个 bounding boxes 没有交集,交集面积为 0 116 | if inter_x_min > inter_x_max or inter_y_min > inter_y_max: 117 | inter_area = 0 118 | else: 119 | inter_area = (inter_x_max - inter_x_min) * (inter_y_max - inter_y_min) 120 | 121 | # 计算两个 bounding boxes 的并集 122 | gt_area = (gt_x_max - gt_x_min) * (gt_y_max - gt_y_min) 123 | pre_area = (pre_x_max - pre_x_min) * (pre_y_max - pre_y_min) 124 | union_area = gt_area + pre_area - inter_area 125 | 126 | # 计算 IoU 127 | iou = inter_area / union_area 128 | 129 | # 判断 IoU 是否大于 50% 130 | if iou > 0.5: 131 | return 1 132 | else: 133 | return 0 134 | else: 135 | print("wrong response: {}".format(response)) 136 | return 0 137 | 138 | async def process_item_async(item, json_item, client, model_name, semaphore, new_image_dir): 139 | async with semaphore: 140 | old_image_path = item["image"] 141 | old_image_dir = '/'.join(old_image_path.split('/')[:-1]) 142 | new_image_path = 
item["image"].replace(old_image_dir, new_image_dir) 143 | image = Image.open(new_image_path) 144 | w = image.width 145 | h = image.height 146 | bbox = await call_GPT(item, client, model_name, new_image_dir) 147 | correct = await verify(bbox,json_item ,item["abs_position"],w,h) 148 | return correct 149 | 150 | async def main(): 151 | model_name = "gpt-4o" 152 | total = 0 153 | correct = 0 154 | json_data_path = "your/path/to/the/dataset" 155 | image_data_path = "your/path/to/annotated/image" 156 | data = read_jsonl(json_data_path) 157 | client = AsyncClient(api_key="sk-123") 158 | semaphore = asyncio.Semaphore(8) 159 | tasks = [] 160 | for item in data: 161 | try: 162 | old_image_path = item["image"] 163 | old_image_dir = '/'.join(old_image_path.split('/')[:-1]) 164 | new_image_path = item["image"].replace(old_image_dir, image_data_path) 165 | json_path = new_image_path.replace(".jpeg", ".json") 166 | json_item = json.load(open(json_path)) 167 | task = asyncio.create_task(process_item_async(item, json_item,client, model_name, image_data_path, semaphore)) 168 | tasks.append(task) 169 | except: 170 | pass 171 | 172 | results = await asyncio.gather(*tasks) 173 | for result in results: 174 | correct += result 175 | total += 1 176 | print(correct, total, correct / total) 177 | 178 | return 0 179 | 180 | if __name__=="__main__": 181 | asyncio.run(main()) -------------------------------------------------------------------------------- /eval/grounding_eval/code/GPT-4o/process_image.py: -------------------------------------------------------------------------------- 1 | from util.utils import get_som_labeled_img, check_ocr_box, get_caption_model_processor, get_yolo_model 2 | import torch 3 | from ultralytics import YOLO 4 | from PIL import Image 5 | import os 6 | import base64 7 | import matplotlib.pyplot as plt 8 | import io 9 | import json 10 | 11 | device = 'cuda' 12 | model_path='weights/icon_detect/model.pt' 13 | 14 | som_model = get_yolo_model(model_path) 15 | 16 | som_model.to(device) 17 | print('model to {}'.format(device)) 18 | 19 | caption_model_processor = get_caption_model_processor(model_name="florence2", model_name_or_path="weights/icon_caption_florence", device=device) 20 | source_dir = "your/path/to/original/image" 21 | target_dir = "your/path/to/annotated/image" 22 | 23 | 24 | for image_dir in os.listdir(source_dir): 25 | if image_dir.endswith(".jpeg"): 26 | image_path = os.path.join(source_dir, image_dir) 27 | output_dir = image_path.replace(source_dir, target_dir) 28 | if os.path.exists(output_dir): 29 | continue 30 | else: 31 | image = Image.open(image_path) 32 | image_rgb = image.convert('RGB') 33 | print('image size:', image.size) 34 | 35 | box_overlay_ratio = max(image.size) / 3200 36 | draw_bbox_config = { 37 | 'text_scale': 0.8 * box_overlay_ratio, 38 | 'text_thickness': max(int(2 * box_overlay_ratio), 1), 39 | 'text_padding': max(int(3 * box_overlay_ratio), 1), 40 | 'thickness': max(int(3 * box_overlay_ratio), 1), 41 | } 42 | BOX_TRESHOLD = 0.05 43 | 44 | try: 45 | ocr_bbox_rslt, is_goal_filtered = check_ocr_box(image_path, display_img = False, output_bb_format='xyxy', goal_filtering=None, easyocr_args={'paragraph': False, 'text_threshold':0.9}, use_paddleocr=True) 46 | text, ocr_bbox = ocr_bbox_rslt 47 | dino_labled_img, label_coordinates, parsed_content_list = get_som_labeled_img(image_path, som_model, BOX_TRESHOLD = BOX_TRESHOLD, output_coord_in_ratio=True, ocr_bbox=ocr_bbox,draw_bbox_config=draw_bbox_config, caption_model_processor=caption_model_processor, 
ocr_text=text,use_local_semantics=True, iou_threshold=0.7, scale_img=False, batch_size=128) 48 | 49 | image = Image.open(io.BytesIO(base64.b64decode(dino_labled_img))) 50 | if not os.path.exists(output_dir): 51 | mkdir = '/'.join(output_dir.split('/')[:-1]) 52 | os.makedirs(mkdir, exist_ok=True) 53 | image.save(output_dir) 54 | content_output_dir = output_dir.replace('.jpeg', '.json') 55 | if not os.path.exists(content_output_dir): 56 | mkdir = '/'.join(content_output_dir.split('/')[:-1]) 57 | os.makedirs(mkdir, exist_ok=True) 58 | with open(content_output_dir, "w", encoding="utf-8") as f: 59 | json.dump(parsed_content_list,f,indent=2) 60 | except: 61 | print(f"未成功处理:{image_path}") -------------------------------------------------------------------------------- /eval/grounding_eval/code/GPT-4o/text2bbox_eval_click.py: -------------------------------------------------------------------------------- 1 | import json 2 | from PIL import Image 3 | import base64 4 | import re 5 | from openai import AsyncClient 6 | import os 7 | import asyncio 8 | 9 | def read_jsonl(file_path): 10 | """ 11 | 读取 JSONL 文件并解析为 Python 字典列表 12 | :param file_path: JSONL 文件路径 13 | :return: 包含所有 JSON 对象的列表 14 | """ 15 | data = [] 16 | with open(file_path, "r", encoding="utf-8") as f: 17 | for line in f: 18 | data.append(json.loads(line.strip())) 19 | return data 20 | 21 | def load_image_from_path(image_path): 22 | """ 23 | 从指定路径加载图片 24 | :param image_path: 图片文件路径 25 | :return: PIL.Image 对象 26 | """ 27 | try: 28 | image = Image.open(image_path) 29 | return image 30 | except Exception as e: 31 | print(f"Error loading image from {image_path}: {e}") 32 | return None 33 | 34 | # Function to encode the image 35 | def encode_image(image_path): 36 | image = Image.open(image_path) 37 | w, h = image.width, image.height 38 | with open(image_path, "rb") as image_file: 39 | return base64.b64encode(image_file.read()).decode("utf-8"),w, h 40 | 41 | async def call_Qwenvl(item, client, model_name): 42 | """ 43 | 调用Qwen输出function的描述,输出bbox 44 | """ 45 | # sys_prompt = '''你是一个GUI组件定位的专家,擅长输出图片上文本对应的坐标。你的任务是根据给定的GUI截图和图中某个文本输出该文本的坐标。 46 | # 输入:屏幕截图,文本描述 47 | # 输出:文本的绝对坐标的中心点,以为格式,使用<>定位,其中不能存在任何非坐标字符,注意中心点应当是两个坐标而不是四个。 48 | # 示例输出一:我认为该文本在<600,1000>附近 49 | # 示例输出二:该文本的位置是<1238,430>''' 50 | sys_prompt = ''' 51 | 你是一个GUI组件定位的专家,擅长输出图片上文本对应的坐标。你的任务是根据给定的GUI截图和图中某个文本输出该文本的坐标。 52 | 输入:屏幕截图,文本描述 53 | 输出:文本的绝对坐标的中心点,以为格式,使用<>定位,其中不能存在任何非坐标字符,注意中心点应当是两个坐标而不是四个。 54 | 示例输出一:我认为该文本在<600,1000>附近 55 | ''' 56 | image_path = item["image"] 57 | base64_image, w, h = encode_image(image_path) 58 | content = [] 59 | # 动态添加base64_image部分到 content 列表 60 | content.append({ 61 | "type": "image_url", 62 | "image_url": { 63 | "url": f"data:image/jpeg;base64,{base64_image}", 64 | }, 65 | }) 66 | content.append({ 67 | "type": "text", 68 | "text": "当前屏幕的尺寸为{}*{},屏幕上的文本:{}".format(w, h, item["text"]) 69 | }) 70 | try: 71 | res = await client.chat.completions.create( 72 | messages=[ 73 | { 74 | "role": "system", 75 | "content": sys_prompt, 76 | }, 77 | { 78 | "role":"user", 79 | "content": content, 80 | } 81 | ], 82 | model=model_name, 83 | temperature=0 84 | ) 85 | 86 | response = res.choices[0].message.content 87 | except: 88 | response = 'None' 89 | return response.strip("\n").replace(" ","") 90 | 91 | async def verify(response, ground_truth): 92 | """ 93 | 接受模型的字符串输入,判断是否正确 94 | """ 95 | pattern = r'<\d+,\d+>' 96 | matches = re.findall(pattern, response) 97 | # 将输入字符串转换为整数列表 98 | if matches: 99 | match = matches[0] 100 | bbox = list(map(int, 
match.strip('<>').split(','))) 101 | gt_bbox = list(map(int, ground_truth.strip('<>').split(','))) 102 | 103 | gt_x_min = gt_bbox[0] 104 | gt_x_max = gt_bbox[2] 105 | gt_y_min = gt_bbox[1] 106 | gt_y_max = gt_bbox[3] 107 | if gt_x_min<=bbox[0]<=gt_x_max and gt_y_min<=bbox[1]<=gt_y_max: 108 | return 1 109 | else: 110 | print("wrong response: {}".format(response)) 111 | return 0 112 | 113 | async def process_item_async(item, client, model_name, semaphore): 114 | async with semaphore: 115 | bbox = await call_Qwenvl(item, client, model_name) 116 | correct = await verify(bbox,item["abs_position"]) 117 | return correct 118 | 119 | async def main(): 120 | model_name = "gpt-4o" 121 | total = 0 122 | correct = 0 123 | json_data_path = "your/path/to/the/dataset" 124 | data = read_jsonl(json_data_path) 125 | client = AsyncClient(api_key="sk-123") 126 | semaphore = asyncio.Semaphore(8) 127 | tasks = [] 128 | for item in data: 129 | task = asyncio.create_task(process_item_async(item, client, model_name,semaphore)) 130 | tasks.append(task) 131 | 132 | results = await asyncio.gather(*tasks) 133 | for result in results: 134 | correct += result 135 | total += 1 136 | print(correct, total, correct / total) 137 | return 0 138 | 139 | if __name__=="__main__": 140 | asyncio.run(main()) -------------------------------------------------------------------------------- /eval/grounding_eval/code/GPT-4o/text2bbox_eval_gpt-4o_with_grounding.py: -------------------------------------------------------------------------------- 1 | import json 2 | from PIL import Image 3 | import base64 4 | import re 5 | from openai import AsyncClient 6 | import os 7 | import asyncio 8 | import traceback 9 | 10 | def read_jsonl(file_path): 11 | """ 12 | 13 | 读取 JSONL 文件并解析为 Python 字典列表 14 | :param file_path: JSONL 文件路径 15 | :return: 包含所有 JSON 对象的列表 16 | """ 17 | data = [] 18 | with open(file_path, "r", encoding="utf-8") as f: 19 | for line in f: 20 | data.append(json.loads(line.strip())) 21 | return data 22 | 23 | def load_image_from_path(image_path): 24 | """ 25 | 从指定路径加载图片 26 | :param image_path: 图片文件路径 27 | :return: PIL.Image 对象 28 | """ 29 | try: 30 | image = Image.open(image_path) 31 | return image 32 | except Exception as e: 33 | print(f"Error loading image from {image_path}: {e}") 34 | return None 35 | 36 | # Function to encode the image 37 | def encode_image(image_path): 38 | image = Image.open(image_path) 39 | w, h = image.width, image.height 40 | with open(image_path, "rb") as image_file: 41 | return base64.b64encode(image_file.read()).decode("utf-8") 42 | 43 | async def call_GPT(item, client, model_name): 44 | """ 45 | 调用Qwen输出function的描述,输出bbox 46 | """ 47 | system_prompt = ''' 48 | 你是一个GUI组件定位的专家,擅长根据组件的功能描述输出对应的位置。你的任务是根据给定的GUI截图和图中某个文本输出与文本最接近的框的编号 49 | 输入:屏幕截图,文本描述 50 | 输出:屏幕截图中框的编号,以为格式 51 | 示例输出一:0 52 | 示例输出二:14 53 | ''' 54 | 55 | image_path = item["image"] 56 | base64_image = encode_image(image_path) 57 | content = [] 58 | # 动态添加base64_image部分到 content 列表 59 | content.append({ 60 | "type": "image_url", 61 | "image_url": { 62 | "url": f"data:image/jpeg;base64,{base64_image}", 63 | }, 64 | }) 65 | content.append({ 66 | "type": "text", 67 | "text": "屏幕上的文本:{}".format(item["text"]) 68 | }) 69 | 70 | res = await client.chat.completions.create( 71 | messages=[ 72 | { 73 | "role":"system", 74 | "content": system_prompt 75 | }, 76 | { 77 | "role":"user", 78 | "content": content, 79 | } 80 | ], 81 | model=model_name, 82 | ) 83 | 84 | response = res.choices[0].message.content 85 | return response.strip("\n").replace(" ","") 86 | 
87 | async def verify(response, json_list, ground_truth): 88 | """ 89 | 接受模型的字符串输入,判断是否正确 90 | """ 91 | pattern = r'\d+' 92 | matches = re.findall(pattern, response) 93 | # 将输入字符串转换为整数列表 94 | if matches: 95 | try: 96 | match = matches[0] 97 | bbox_idx = int(match.replace('','').replace('','')) 98 | pre_bbox = json_list[bbox_idx]["bbox"] 99 | 100 | gt_bbox = list(map(float, ground_truth.strip('<>').split(','))) 101 | gt_x_min = gt_bbox[0] 102 | gt_x_max = gt_bbox[2] 103 | gt_y_min = gt_bbox[1] 104 | gt_y_max = gt_bbox[3] 105 | print(pre_bbox, [gt_x_min,gt_y_min,gt_x_max,gt_y_max]) 106 | pre_x_min = pre_bbox[0] 107 | pre_x_max = pre_bbox[2] 108 | pre_y_min = pre_bbox[1] 109 | pre_y_max = pre_bbox[3] 110 | inter_x_min = max(gt_x_min, pre_x_min) 111 | inter_x_max = min(gt_x_max, pre_x_max) 112 | inter_y_min = max(gt_y_min, pre_y_min) 113 | inter_y_max = min(gt_y_max, pre_y_max) 114 | 115 | # 如果两个 bounding boxes 没有交集,交集面积为 0 116 | if inter_x_min > inter_x_max or inter_y_min > inter_y_max: 117 | inter_area = 0 118 | else: 119 | inter_area = (inter_x_max - inter_x_min) * (inter_y_max - inter_y_min) 120 | 121 | # 计算两个 bounding boxes 的并集 122 | gt_area = (gt_x_max - gt_x_min) * (gt_y_max - gt_y_min) 123 | pre_area = (pre_x_max - pre_x_min) * (pre_y_max - pre_y_min) 124 | union_area = gt_area + pre_area - inter_area 125 | 126 | # 计算 IoU 127 | iou = inter_area / union_area 128 | 129 | # 判断 IoU 是否大于 50% 130 | if iou > 0.5: 131 | return 1 132 | else: 133 | return 0 134 | except: 135 | return 0 136 | else: 137 | print("wrong response: {}".format(response)) 138 | return 0 139 | 140 | async def process_item_async(item, json_item, client, model_name, semaphore): 141 | async with semaphore: 142 | bbox = await call_GPT(item, client, model_name) 143 | correct = await verify(bbox,json_item ,item["rel_position"]) 144 | return correct 145 | 146 | async def main(): 147 | model_name = "gpt-4o" 148 | total = 0 149 | correct = 0 150 | json_data_path = "your/path/to/the/dataset" 151 | image_data_path = "your/path/to/annotated/image" 152 | data = read_jsonl(json_data_path) 153 | client = AsyncClient(api_key="sk-123") 154 | semaphore = asyncio.Semaphore(8) 155 | tasks = [] 156 | for item in data: 157 | try: 158 | old_image_path = item["image"] 159 | old_image_dir = '/'.join(old_image_path.split('/')[:-1]) 160 | new_image_path = item["image"].replace(old_image_dir, image_data_path) 161 | json_path = new_image_path.replace(".jpeg", ".json") 162 | json_item = json.load(open(json_path)) 163 | task = asyncio.create_task(process_item_async(item, json_item,client, model_name, image_data_path, semaphore)) 164 | tasks.append(task) 165 | except: 166 | pass 167 | 168 | results = await asyncio.gather(*tasks) 169 | for result in results: 170 | correct += result 171 | total += 1 172 | print(correct, total, correct / total) 173 | 174 | return 0 175 | 176 | if __name__=="__main__": 177 | asyncio.run(main()) -------------------------------------------------------------------------------- /eval/grounding_eval/code/OS-Altas/bbox2text_eval_osatlas.py: -------------------------------------------------------------------------------- 1 | import json 2 | from PIL import Image 3 | import base64 4 | import re 5 | from openai import AsyncClient 6 | import os 7 | import asyncio 8 | 9 | def read_jsonl(file_path): 10 | """ 11 | 读取 JSONL 文件并解析为 Python 字典列表 12 | :param file_path: JSONL 文件路径 13 | :return: 包含所有 JSON 对象的列表 14 | """ 15 | data = [] 16 | with open(file_path, "r", encoding="utf-8") as f: 17 | for line in f: 18 | 
data.append(json.loads(line.strip())) 19 | return data 20 | 21 | def load_image_from_path(image_path): 22 | """ 23 | 从指定路径加载图片 24 | :param image_path: 图片文件路径 25 | :return: PIL.Image 对象 26 | """ 27 | try: 28 | image = Image.open(image_path) 29 | return image 30 | except Exception as e: 31 | print(f"Error loading image from {image_path}: {e}") 32 | return None 33 | 34 | # Function to encode the image 35 | def encode_image(image_path): 36 | image = Image.open(image_path) 37 | w, h = image.width, image.height 38 | with open(image_path, "rb") as image_file: 39 | return base64.b64encode(image_file.read()).decode("utf-8"),w, h 40 | 41 | def process_position(item, w, h): 42 | pattern = r'<\d+, \d+, \d+, \d+>' 43 | matches = re.findall(pattern, item["abs_position"]) 44 | if matches: 45 | match = matches[0] 46 | bbox = list(map(int, match.strip('<>').split(','))) 47 | rel_position = [int(bbox[0]/w*1000), int(bbox[1]/h*1000),int(bbox[2]/w*1000),int(bbox[3]/h*1000)] 48 | return rel_position[0],rel_position[1],rel_position[2],rel_position[3] 49 | 50 | async def call_Qwenvl(item, client, model_name): 51 | """ 52 | 调用Qwen输出function的描述,输出bbox 53 | """ 54 | sys_prompt = '''你是一个GUI组件文字识别的专家,擅长根据组件的边界框(bounding box)描述输出对应的文字。你的任务是根据给定的GUI截图和图中某个组件的边界框输出组件的中的文字。 55 | 输入:屏幕截图,边界框的相对坐标,缩放至0~1000, 的格式表示 56 | 输出:组件中的文本,注意是文字而非坐标! 57 | 示例输出一:可口可乐。 58 | 示例输出二:关注''' 59 | image_path = item["image"] 60 | base64_image, w, h = encode_image(image_path) 61 | bbox = process_position(item, w, h) 62 | content = [] 63 | # 动态添加base64_image部分到 content 列表 64 | content.append({ 65 | "type": "image_url", 66 | "image_url": { 67 | "url": f"data:image/jpeg;base64,{base64_image}", 68 | }, 69 | }) 70 | content.append({ 71 | "type": "text", 72 | "text": "屏幕上某一组件的边界框:<{},{},{},{}>".format(bbox[0],bbox[1],bbox[2],bbox[3]) 73 | }) 74 | 75 | res = await client.chat.completions.create( 76 | messages=[ 77 | { 78 | "role": "system", 79 | "content": sys_prompt, 80 | }, 81 | { 82 | "role":"user", 83 | "content": content, 84 | } 85 | ], 86 | model=model_name, 87 | max_tokens = 256 88 | ) 89 | 90 | response = res.choices[0].message.content 91 | return response.strip("\n").replace(" ","") 92 | 93 | async def verify(response, ground_truth): 94 | """ 95 | 接受模型的字符串输入,判断是否正确 96 | """ 97 | if response == ground_truth: 98 | return 1 99 | else: 100 | return 0 101 | 102 | async def process_item_async(item, client, model_name, semaphore): 103 | async with semaphore: 104 | response = await call_Qwenvl(item, client, model_name) 105 | correct = await verify(response,item["text"]) 106 | return correct 107 | 108 | async def main(): 109 | model_name = "OS-Atlas" 110 | total = 0 111 | correct = 0 112 | json_data_path = "your/path/to/the/dataset" 113 | data = read_jsonl(json_data_path) 114 | client = AsyncClient(api_key="sk-123", base_url='http://localhost:8000/v1') 115 | semaphore = asyncio.Semaphore(16) 116 | tasks = [] 117 | for item in data: 118 | task = asyncio.create_task(process_item_async(item, client, model_name,semaphore)) 119 | tasks.append(task) 120 | 121 | results = await asyncio.gather(*tasks) 122 | for result in results: 123 | correct += result 124 | total += 1 125 | print(correct, total, correct / total) 126 | 127 | return 0 128 | 129 | if __name__=="__main__": 130 | asyncio.run(main()) -------------------------------------------------------------------------------- /eval/grounding_eval/code/OS-Altas/fun2bbox_eval_osatlas.py: -------------------------------------------------------------------------------- 1 | import json 2 | from PIL import Image 3 
| import base64 4 | import re 5 | from openai import AsyncClient 6 | import os 7 | import asyncio 8 | import traceback 9 | 10 | def read_jsonl(file_path): 11 | """ 12 | 读取 JSONL 文件并解析为 Python 字典列表 13 | :param file_path: JSONL 文件路径 14 | :return: 包含所有 JSON 对象的列表 15 | """ 16 | data = [] 17 | with open(file_path, "r", encoding="utf-8") as f: 18 | for line in f: 19 | data.append(json.loads(line.strip())) 20 | return data 21 | 22 | def load_image_from_path(image_path): 23 | """ 24 | 从指定路径加载图片 25 | :param image_path: 图片文件路径 26 | :return: PIL.Image 对象 27 | """ 28 | try: 29 | image = Image.open(image_path) 30 | return image 31 | except Exception as e: 32 | print(f"Error loading image from {image_path}: {e}") 33 | return None 34 | 35 | # Function to encode the image 36 | def encode_image(image_path): 37 | image = Image.open(image_path) 38 | w, h = image.width, image.height 39 | with open(image_path, "rb") as image_file: 40 | return base64.b64encode(image_file.read()).decode("utf-8"),w, h 41 | 42 | async def call_Qwenvl(item, client, model_name): 43 | """ 44 | 调用Qwen输出function的描述,输出bbox 45 | """ 46 | sys_prompt = '''你是一个GUI组件定位的专家,擅长根据组件的功能描述输出对应的坐标。你的任务是根据给定的GUI截图和图中某个组件的功能描述输出组件的坐标。 47 | 输入:屏幕截图,功能描述 48 | 输出:边界框的相对坐标,缩放至0~1000,以CLICK [[x-axis, y-axis]]为格式,使用<>定位,其中不能存在任何非坐标字符 49 | 示例输出:CLICK [[600, 1000]] 50 | ''' 51 | image_path = item["image"] 52 | base64_image, w, h = encode_image(image_path) 53 | content = [] 54 | # 动态添加base64_image部分到 content 列表 55 | content.append({ 56 | "type": "image_url", 57 | "image_url": { 58 | "url": f"data:image/jpeg;base64,{base64_image}", 59 | }, 60 | }) 61 | content.append({ 62 | "type": "text", 63 | "text": "屏幕上某一组件的功能描述:{}".format(item["text"]) 64 | }) 65 | try: 66 | res = await client.chat.completions.create( 67 | messages=[ 68 | { 69 | "role": "system", 70 | "content": sys_prompt, 71 | }, 72 | { 73 | "role":"user", 74 | "content": content, 75 | } 76 | ], 77 | model=model_name, 78 | ) 79 | 80 | response = res.choices[0].message.content 81 | except: 82 | response = 'None' 83 | return response.strip("\n").replace(" ","") 84 | 85 | async def verify(response, ground_truth, w, h): 86 | """ 87 | 接受模型的字符串输入,判断是否正确 88 | """ 89 | pattern = r'\[\[\d+,\d+\]\]' 90 | matches = re.findall(pattern, response, re.DOTALL) 91 | # 将输入字符串转换为整数列表 92 | if matches: 93 | match = matches[0] 94 | bbox = list(map(int, match.strip('[[]]').split(','))) 95 | gt_bbox = list(map(int, ground_truth.strip('<>').split(','))) 96 | abs_bbox = [bbox[0]/1000*w, bbox[1]/1000*h] 97 | # 遍历每个值,检查是否在ground truth对应值的±5范围内 98 | gt_x_min = gt_bbox[0] 99 | gt_x_max = gt_bbox[2] 100 | gt_y_min = gt_bbox[1] 101 | gt_y_max = gt_bbox[3] 102 | print(bbox, gt_bbox) 103 | if gt_x_min<=abs_bbox[0]<=gt_x_max and gt_y_min<=abs_bbox[1]<=gt_y_max: 104 | return 1 105 | else: 106 | print("wrong response: {}".format(response)) 107 | return 0 108 | 109 | async def process_item_async(item, client, model_name, semaphore): 110 | async with semaphore: 111 | bbox = await call_Qwenvl(item, client, model_name) 112 | image = Image.open(item["image"]) 113 | w = image.width 114 | h = image.height 115 | correct = await verify(bbox,item["abs_position"], w, h) 116 | return correct 117 | 118 | async def main(): 119 | model_name = "OS-Atlas" 120 | total = 0 121 | correct = 0 122 | json_data_path = "your/path/to/the/dataset" 123 | data = read_jsonl(json_data_path) 124 | client = AsyncClient(api_key="sk-123", base_url='http://localhost:8000/v1') 125 | semaphore = asyncio.Semaphore(16) 126 | tasks = [] 127 | for item in data: 128 | task = 
asyncio.create_task(process_item_async(item, client, model_name,semaphore)) 129 | tasks.append(task) 130 | 131 | results = await asyncio.gather(*tasks) 132 | for result in results: 133 | correct += result 134 | total += 1 135 | print(correct, total, correct / total) 136 | 137 | return 0 138 | 139 | if __name__=="__main__": 140 | asyncio.run(main()) -------------------------------------------------------------------------------- /eval/grounding_eval/code/OS-Altas/text2bbox_eval_osatlas.py: -------------------------------------------------------------------------------- 1 | import json 2 | from PIL import Image 3 | import base64 4 | import re 5 | from openai import AsyncClient 6 | import os 7 | import asyncio 8 | 9 | def read_jsonl(file_path): 10 | """ 11 | 读取 JSONL 文件并解析为 Python 字典列表 12 | :param file_path: JSONL 文件路径 13 | :return: 包含所有 JSON 对象的列表 14 | """ 15 | data = [] 16 | with open(file_path, "r", encoding="utf-8") as f: 17 | for line in f: 18 | data.append(json.loads(line.strip())) 19 | return data 20 | 21 | def load_image_from_path(image_path): 22 | """ 23 | 从指定路径加载图片 24 | :param image_path: 图片文件路径 25 | :return: PIL.Image 对象 26 | """ 27 | try: 28 | image = Image.open(image_path) 29 | return image 30 | except Exception as e: 31 | print(f"Error loading image from {image_path}: {e}") 32 | return None 33 | 34 | # Function to encode the image 35 | def encode_image(image_path): 36 | image = Image.open(image_path) 37 | w, h = image.width, image.height 38 | with open(image_path, "rb") as image_file: 39 | return base64.b64encode(image_file.read()).decode("utf-8"),w, h 40 | 41 | async def call_Qwenvl(item, client, model_name): 42 | """ 43 | 调用Qwen输出function的描述,输出bbox 44 | """ 45 | sys_prompt = ''' 46 | 你是一个GUI组件定位的专家,擅长输出图片上文本对应的坐标。你的任务是根据给定的GUI截图和图中某个文本输出该文本的坐标。 47 | 输入:屏幕截图,文本描述 48 | 输出:边界框的坐标,使用click 49 | 示例输出:CLICK [[600, 1000]] 50 | ''' 51 | image_path = item["image"].replace("/home/test/test03","/home/test/test12") 52 | base64_image, w, h = encode_image(image_path) 53 | content = [] 54 | # 动态添加base64_image部分到 content 列表 55 | content.append({ 56 | "type": "image_url", 57 | "image_url": { 58 | "url": f"data:image/jpeg;base64,{base64_image}", 59 | }, 60 | }) 61 | content.append({ 62 | "type": "text", 63 | "text": "屏幕上的文本:{}".format(item['text']) 64 | }) 65 | try: 66 | res = await client.chat.completions.create( 67 | messages=[ 68 | { 69 | "role": "system", 70 | "content": sys_prompt, 71 | }, 72 | { 73 | "role":"user", 74 | "content": content, 75 | } 76 | ], 77 | model=model_name, 78 | temperature=0 79 | ) 80 | response = res.choices[0].message.content 81 | except: 82 | response = 'None' 83 | return response.strip("\n").replace(" ","") 84 | 85 | async def verify(response, ground_truth, w ,h): 86 | """ 87 | 接受模型的字符串输入,判断是否正确,当前判断方法:点是否落在bbox内 88 | """ 89 | pattern = r'\[\[\d+,\d+\]\]' 90 | matches = re.findall(pattern, response, re.DOTALL) 91 | # 将输入字符串转换为整数列表 92 | if matches: 93 | match = matches[0] 94 | bbox = list(map(int, match.strip('[[]]').split(','))) 95 | gt_bbox = list(map(int, ground_truth.strip('<>').split(','))) 96 | abs_bbox = [bbox[0]/1000*w, bbox[1]/1000*h] 97 | # 遍历每个值,检查是否在ground truth对应值的±5范围内 98 | gt_x_min = gt_bbox[0] 99 | gt_x_max = gt_bbox[2] 100 | gt_y_min = gt_bbox[1] 101 | gt_y_max = gt_bbox[3] 102 | print(abs_bbox, gt_bbox) 103 | if gt_x_min<=abs_bbox[0]<=gt_x_max and gt_y_min<=abs_bbox[1]<=gt_y_max: 104 | return 1 105 | else: 106 | print("wrong response: {}".format(response)) 107 | return 0 108 | 109 | async def process_item_async(item, client, model_name, semaphore): 
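# Illustrative note (annotation): each item is processed under the shared semaphore to
# bound concurrent requests. The screenshot is opened only to read its true width and
# height, because the model is prompted to answer with a click point in 0-1000 relative
# coordinates; verify() rescales that point by w/1000 and h/1000 and counts a hit when
# it falls inside the absolute ground-truth box.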
110 | async with semaphore: 111 | image = Image.open(item["image"]) 112 | w = image.width 113 | h = image.height 114 | bbox = await call_Qwenvl(item, client, model_name) 115 | correct = await verify(bbox,item["abs_position"],w,h) 116 | return correct 117 | 118 | async def main(): 119 | model_name = "OS-Atlas" 120 | total = 0 121 | correct = 0 122 | json_data_path = "your/path/to/the/dataset" 123 | data = read_jsonl(json_data_path) 124 | client = AsyncClient(api_key="sk-123", base_url='http://localhost:8000/v1') 125 | semaphore = asyncio.Semaphore(16) 126 | tasks = [] 127 | for item in data: 128 | task = asyncio.create_task(process_item_async(item, client, model_name,semaphore)) 129 | tasks.append(task) 130 | 131 | results = await asyncio.gather(*tasks) 132 | for result in results: 133 | correct += result 134 | total += 1 135 | print(correct, total, correct / total) 136 | return 0 137 | 138 | if __name__=="__main__": 139 | asyncio.run(main()) -------------------------------------------------------------------------------- /eval/grounding_eval/code/OS-genesis/bbox2text_eval_osgenesis.py: -------------------------------------------------------------------------------- 1 | import json 2 | from PIL import Image 3 | import base64 4 | import re 5 | from openai import AsyncClient 6 | import os 7 | import asyncio 8 | 9 | def read_jsonl(file_path): 10 | """ 11 | 读取 JSONL 文件并解析为 Python 字典列表 12 | :param file_path: JSONL 文件路径 13 | :return: 包含所有 JSON 对象的列表 14 | """ 15 | data = [] 16 | with open(file_path, "r", encoding="utf-8") as f: 17 | for line in f: 18 | data.append(json.loads(line.strip())) 19 | return data 20 | 21 | def load_image_from_path(image_path): 22 | """ 23 | 从指定路径加载图片 24 | :param image_path: 图片文件路径 25 | :return: PIL.Image 对象 26 | """ 27 | try: 28 | image = Image.open(image_path) 29 | return image 30 | except Exception as e: 31 | print(f"Error loading image from {image_path}: {e}") 32 | return None 33 | 34 | # Function to encode the image 35 | def encode_image(image_path): 36 | image = Image.open(image_path) 37 | w, h = image.width, image.height 38 | with open(image_path, "rb") as image_file: 39 | return base64.b64encode(image_file.read()).decode("utf-8"),w, h 40 | 41 | async def call_Qwenvl(item, client, model_name): 42 | """ 43 | 调用Qwen输出function的描述,输出bbox 44 | """ 45 | sys_prompt = '''你是一个GUI组件文字识别的专家,擅长根据组件的边界框(bounding box)描述输出对应的文字。你的任务是根据给定的GUI截图和图中某个组件的边界框输出组件的中的文字。 46 | 输入:屏幕截图,边界框的绝对坐标的格式表示 47 | 输出:组件中的文本,注意是文字而非坐标! 
48 | 示例输出一:可口可乐。 49 | 示例输出二:关注''' 50 | image_path = item["image"] 51 | base64_image, w, h = encode_image(image_path) 52 | content = [] 53 | # 动态添加base64_image部分到 content 列表 54 | content.append({ 55 | "type": "image_url", 56 | "image_url": { 57 | "url": f"data:image/jpeg;base64,{base64_image}", 58 | }, 59 | }) 60 | content.append({ 61 | "type": "text", 62 | "text": "当前屏幕的尺寸为{}*{},屏幕上某一组件的边界框:{}".format(w, h, item["abs_position"]) 63 | }) 64 | 65 | res = await client.chat.completions.create( 66 | messages=[ 67 | { 68 | "role": "system", 69 | "content": sys_prompt, 70 | }, 71 | { 72 | "role":"user", 73 | "content": content, 74 | } 75 | ], 76 | model=model_name, 77 | ) 78 | 79 | response = res.choices[0].message.content 80 | return response.strip("\n").replace(" ","") 81 | 82 | async def verify(response, ground_truth): 83 | """ 84 | 接受模型的字符串输入,判断是否正确 85 | """ 86 | print(response, ground_truth) 87 | if response == ground_truth: 88 | return 1 89 | else: 90 | return 0 91 | 92 | async def process_item_async(item, client, model_name, semaphore): 93 | async with semaphore: 94 | response = await call_Qwenvl(item, client, model_name) 95 | correct = await verify(response,item["text"]) 96 | return correct 97 | 98 | async def main(): 99 | model_name = "osgenesis" 100 | total = 0 101 | correct = 0 102 | json_data_path = "your/path/to/the/dataset" 103 | data = read_jsonl(json_data_path) 104 | client = AsyncClient(api_key="sk-123", base_url='http://localhost:8000/v1/') 105 | semaphore = asyncio.Semaphore(16) 106 | tasks = [] 107 | for item in data: 108 | task = asyncio.create_task(process_item_async(item, client, model_name,semaphore)) 109 | tasks.append(task) 110 | 111 | results = await asyncio.gather(*tasks) 112 | for result in results: 113 | correct += result 114 | total += 1 115 | print(correct, total, correct / total) 116 | 117 | return 0 118 | 119 | if __name__=="__main__": 120 | asyncio.run(main()) -------------------------------------------------------------------------------- /eval/grounding_eval/code/OS-genesis/fun2bbox_eval_osgenesis.py: -------------------------------------------------------------------------------- 1 | import json 2 | from PIL import Image 3 | import base64 4 | import re 5 | from openai import AsyncClient 6 | import os 7 | import asyncio 8 | import traceback 9 | 10 | def read_jsonl(file_path): 11 | """ 12 | 13 | 读取 JSONL 文件并解析为 Python 字典列表 14 | :param file_path: JSONL 文件路径 15 | :return: 包含所有 JSON 对象的列表 16 | """ 17 | data = [] 18 | with open(file_path, "r", encoding="utf-8") as f: 19 | for line in f: 20 | data.append(json.loads(line.strip())) 21 | return data 22 | 23 | def load_image_from_path(image_path): 24 | """ 25 | 从指定路径加载图片 26 | :param image_path: 图片文件路径 27 | :return: PIL.Image 对象 28 | """ 29 | try: 30 | image = Image.open(image_path) 31 | return image 32 | except Exception as e: 33 | print(f"Error loading image from {image_path}: {e}") 34 | return None 35 | 36 | # Function to encode the image 37 | def encode_image(image_path): 38 | image = Image.open(image_path) 39 | w, h = image.width, image.height 40 | with open(image_path, "rb") as image_file: 41 | return base64.b64encode(image_file.read()).decode("utf-8"),w, h 42 | 43 | async def call_Qwenvl(item, client, model_name): 44 | """ 45 | 调用Qwen输出function的描述,输出bbox 46 | """ 47 | sys_prompt = '''你是一个GUI组件定位的专家,擅长根据组件的功能描述输出对应的坐标。你的任务是根据给定的GUI截图和图中某个组件的功能描述输出组件的坐标。 48 | 输入:屏幕截图,功能描述 49 | 输出:边界框的坐标,使用click 50 | 示例输出:actions:\n{\"action_type\":\"click\",\"x\":843,\"y\":231} 51 | ''' 52 | image_path = item["image"] 53 | 
base64_image, w, h = encode_image(image_path) 54 | content = [] 55 | # 动态添加base64_image部分到 content 列表 56 | content.append({ 57 | "type": "image_url", 58 | "image_url": { 59 | "url": f"data:image/jpeg;base64,{base64_image}", 60 | }, 61 | }) 62 | content.append({ 63 | "type": "text", 64 | "text": "当前屏幕的尺寸为{}*{},屏幕上某一组件的功能描述:{}".format(w, h, item["text"]) 65 | }) 66 | 67 | res = await client.chat.completions.create( 68 | messages=[ 69 | { 70 | "role": "system", 71 | "content": sys_prompt, 72 | }, 73 | { 74 | "role":"user", 75 | "content": content, 76 | } 77 | ], 78 | model=model_name, 79 | ) 80 | 81 | response = res.choices[0].message.content 82 | return response.strip("\n").replace(" ","") 83 | 84 | async def verify(response, ground_truth): 85 | """ 86 | 接受模型的字符串输入,判断是否正确 87 | """ 88 | pattern = r'actions:\n{"action_type":"click","x":.*?,"y":.*?}' 89 | matches = re.findall(pattern, response) 90 | # 将输入字符串转换为整数列表 91 | if matches: 92 | match = matches[0] 93 | try: 94 | bbox_json = json.loads(match.split('\n')[-1]) 95 | bbox = [int(bbox_json["x"]),int(bbox_json["y"])] 96 | gt_bbox = list(map(int, ground_truth.strip('<>').split(','))) 97 | # 遍历每个值,检查是否在ground truth对应值的±5范围内 98 | gt_x_min = gt_bbox[0] 99 | gt_x_max = gt_bbox[2] 100 | gt_y_min = gt_bbox[1] 101 | gt_y_max = gt_bbox[3] 102 | print(bbox, [gt_x_min,gt_y_min,gt_x_max,gt_y_max]) 103 | if gt_x_min<=bbox[0]<=gt_x_max and gt_y_min<=bbox[1]<=gt_y_max: 104 | return 1 105 | else: 106 | return 0 107 | except: 108 | return 0 109 | 110 | else: 111 | print("wrong response: {}".format(response)) 112 | return 0 113 | 114 | async def process_item_async(item, client, model_name, semaphore): 115 | async with semaphore: 116 | bbox = await call_Qwenvl(item, client, model_name) 117 | correct = await verify(bbox,item["abs_position"]) 118 | return correct 119 | 120 | async def main(): 121 | model_name = "osgenesis" 122 | total = 0 123 | correct = 0 124 | json_data_path = "your/path/to/the/dataset" 125 | data = read_jsonl(json_data_path) 126 | client = AsyncClient(api_key="sk-123", base_url='http://localhost:8000/v1/') 127 | semaphore = asyncio.Semaphore(16) 128 | tasks = [] 129 | for item in data: 130 | task = asyncio.create_task(process_item_async(item, client, model_name,semaphore)) 131 | tasks.append(task) 132 | 133 | results = await asyncio.gather(*tasks) 134 | for result in results: 135 | correct += result 136 | total += 1 137 | print(correct, total, correct/total) 138 | return 0 139 | 140 | if __name__=="__main__": 141 | asyncio.run(main()) -------------------------------------------------------------------------------- /eval/grounding_eval/code/OS-genesis/text2bbox_eval_osgenesis.py: -------------------------------------------------------------------------------- 1 | import json 2 | from PIL import Image 3 | import base64 4 | import re 5 | from openai import AsyncClient 6 | import os 7 | import asyncio 8 | 9 | def read_jsonl(file_path): 10 | """ 11 | 读取 JSONL 文件并解析为 Python 字典列表 12 | :param file_path: JSONL 文件路径 13 | :return: 包含所有 JSON 对象的列表 14 | """ 15 | data = [] 16 | with open(file_path, "r", encoding="utf-8") as f: 17 | for line in f: 18 | data.append(json.loads(line.strip())) 19 | return data 20 | 21 | def load_image_from_path(image_path): 22 | """ 23 | 从指定路径加载图片 24 | :param image_path: 图片文件路径 25 | :return: PIL.Image 对象 26 | """ 27 | try: 28 | image = Image.open(image_path) 29 | return image 30 | except Exception as e: 31 | print(f"Error loading image from {image_path}: {e}") 32 | return None 33 | 34 | # Function to encode the image 35 | def 
encode_image(image_path): 36 | image = Image.open(image_path) 37 | w, h = image.width, image.height 38 | with open(image_path, "rb") as image_file: 39 | return base64.b64encode(image_file.read()).decode("utf-8"),w, h 40 | 41 | async def call_Qwenvl(item, client, model_name): 42 | """ 43 | 调用Qwen输出function的描述,输出bbox 44 | """ 45 | # sys_prompt = '''你是一个GUI组件定位的专家,擅长输出图片上文本对应的坐标。你的任务是根据给定的GUI截图和图中某个文本输出该文本的坐标。 46 | # 输入:屏幕截图,文本描述 47 | # 输出:文本的绝对坐标的中心点,以为格式,使用<>定位,其中不能存在任何非坐标字符,注意中心点应当是两个坐标而不是四个。 48 | # 示例输出一:我认为该文本在<600,1000>附近 49 | # 示例输出二:该文本的位置是<1238,430>''' 50 | sys_prompt = ''' 51 | 你是一个GUI组件定位的专家,擅长输出图片上文本对应的坐标。你的任务是根据给定的GUI截图和图中某个文本输出该文本的坐标。 52 | 输入:屏幕截图,文本描述 53 | 输出:边界框的坐标,使用click 54 | 示例输出:actions:\n{\"action_type\":\"click\",\"x\":843,\"y\":231} 55 | ''' 56 | image_path = item["image"] 57 | base64_image, w, h = encode_image(image_path) 58 | content = [] 59 | # 动态添加base64_image部分到 content 列表 60 | content.append({ 61 | "type": "image_url", 62 | "image_url": { 63 | "url": f"data:image/jpeg;base64,{base64_image}", 64 | }, 65 | }) 66 | content.append({ 67 | "type": "text", 68 | "text": "当前屏幕的尺寸为{}*{},屏幕上的文本:{}".format(w, h, item['text']) 69 | }) 70 | 71 | res = await client.chat.completions.create( 72 | messages=[ 73 | { 74 | "role": "system", 75 | "content": sys_prompt, 76 | }, 77 | { 78 | "role":"user", 79 | "content": content, 80 | } 81 | ], 82 | model=model_name, 83 | temperature=0 84 | ) 85 | 86 | response = res.choices[0].message.content 87 | return response.strip("\n").replace(" ","") 88 | 89 | async def verify(response, ground_truth): 90 | """ 91 | 接受模型的字符串输入,判断是否正确,当前判断方法 92 | """ 93 | pattern = r'actions:\n{"action_type":"click","x":.*?,"y":.*?}' 94 | matches = re.findall(pattern, response) 95 | # 将输入字符串转换为整数列表 96 | if matches: 97 | match = matches[0] 98 | try: 99 | bbox_json = json.loads(match.split('\n')[-1]) 100 | bbox = [int(bbox_json["x"]),int(bbox_json["y"])] 101 | gt_bbox = list(map(int, ground_truth.strip('<>').split(','))) 102 | 103 | gt_x_min = gt_bbox[0] 104 | gt_x_max = gt_bbox[2] 105 | gt_y_min = gt_bbox[1] 106 | gt_y_max = gt_bbox[3] 107 | if gt_x_min<=bbox[0]<=gt_x_max and gt_y_min<=bbox[1]<=gt_y_max: 108 | return 1 109 | except: 110 | return 0 111 | 112 | else: 113 | print("wrong response: {}".format(response)) 114 | return 0 115 | 116 | async def process_item_async(item, client, model_name, semaphore): 117 | async with semaphore: 118 | bbox = await call_Qwenvl(item, client, model_name) 119 | correct = await verify(bbox,item["abs_position"]) 120 | return correct 121 | 122 | async def main(): 123 | model_name = "osgenesis" 124 | total = 0 125 | correct = 0 126 | json_data_path = "your/path/to/the/dataset" 127 | data = read_jsonl(json_data_path) 128 | client = AsyncClient(api_key="sk-123", base_url='http://localhost:8000/v1') 129 | semaphore = asyncio.Semaphore(16) 130 | tasks = [] 131 | for item in data: 132 | task = asyncio.create_task(process_item_async(item, client, model_name,semaphore)) 133 | tasks.append(task) 134 | 135 | results = await asyncio.gather(*tasks) 136 | for result in results: 137 | correct += result 138 | total += 1 139 | print(correct, total, correct / total) 140 | return 0 141 | 142 | if __name__=="__main__": 143 | asyncio.run(main()) -------------------------------------------------------------------------------- /eval/grounding_eval/code/Qwen2.5-VL/bbox2text_eval_qwen.py: -------------------------------------------------------------------------------- 1 | import json 2 | from PIL import Image 3 | import base64 4 | import re 5 | from openai 
import AsyncClient 6 | import os 7 | import asyncio 8 | 9 | def read_jsonl(file_path): 10 | """ 11 | 读取 JSONL 文件并解析为 Python 字典列表 12 | :param file_path: JSONL 文件路径 13 | :return: 包含所有 JSON 对象的列表 14 | """ 15 | data = [] 16 | with open(file_path, "r", encoding="utf-8") as f: 17 | for line in f: 18 | data.append(json.loads(line.strip())) 19 | return data 20 | 21 | def load_image_from_path(image_path): 22 | """ 23 | 从指定路径加载图片 24 | :param image_path: 图片文件路径 25 | :return: PIL.Image 对象 26 | """ 27 | try: 28 | image = Image.open(image_path) 29 | return image 30 | except Exception as e: 31 | print(f"Error loading image from {image_path}: {e}") 32 | return None 33 | 34 | # Function to encode the image 35 | def encode_image(image_path): 36 | image = Image.open(image_path) 37 | w, h = image.width, image.height 38 | with open(image_path, "rb") as image_file: 39 | return base64.b64encode(image_file.read()).decode("utf-8"),w, h 40 | 41 | async def call_Qwenvl(item, client, model_name): 42 | """ 43 | 调用Qwen输出function的描述,输出bbox 44 | """ 45 | sys_prompt = '''你是一个GUI组件文字识别的专家,擅长根据组件的边界框(bounding box)描述输出对应的文字。你的任务是根据给定的GUI截图和图中某个组件的边界框输出组件的中的文字。 46 | 输入:屏幕截图,边界框的绝对坐标的格式表示 47 | 输出:组件中的文本,注意是文字而非坐标! 48 | 示例输出一:可口可乐。 49 | 示例输出二:关注''' 50 | image_path = item["image"] 51 | base64_image, w, h = encode_image(image_path) 52 | content = [] 53 | # 动态添加base64_image部分到 content 列表 54 | content.append({ 55 | "type": "image_url", 56 | "image_url": { 57 | "url": f"data:image/jpeg;base64,{base64_image}", 58 | }, 59 | }) 60 | content.append({ 61 | "type": "text", 62 | "text": "当前屏幕的尺寸为{}*{},屏幕上某一组件的边界框:{}".format(w, h, item["abs_position"]) 63 | }) 64 | 65 | res = await client.chat.completions.create( 66 | messages=[ 67 | { 68 | "role": "system", 69 | "content": sys_prompt, 70 | }, 71 | { 72 | "role":"user", 73 | "content": content, 74 | } 75 | ], 76 | model=model_name, 77 | ) 78 | 79 | response = res.choices[0].message.content 80 | return response.strip("\n").replace(" ","") 81 | 82 | async def verify(response, ground_truth): 83 | """ 84 | 接受模型的字符串输入,判断是否正确 85 | """ 86 | print(response, ground_truth) 87 | if response == ground_truth: 88 | return 1 89 | else: 90 | return 0 91 | 92 | async def process_item_async(item, client, model_name, semaphore): 93 | async with semaphore: 94 | response = await call_Qwenvl(item, client, model_name) 95 | correct = await verify(response,item["text"]) 96 | return correct 97 | 98 | async def main(): 99 | model_name = "Qwen-VL" 100 | total = 0 101 | correct = 0 102 | json_data_path = "your/path/to/the/dataset" 103 | data = read_jsonl(json_data_path) 104 | client = AsyncClient(api_key="sk-123", base_url='http://localhost:8000/v1') 105 | semaphore = asyncio.Semaphore(16) 106 | tasks = [] 107 | for item in data: 108 | task = asyncio.create_task(process_item_async(item, client, model_name,semaphore)) 109 | tasks.append(task) 110 | 111 | results = await asyncio.gather(*tasks) 112 | for result in results: 113 | correct += result 114 | total += 1 115 | print(correct, total, correct / total) 116 | 117 | return 0 118 | 119 | if __name__=="__main__": 120 | asyncio.run(main()) -------------------------------------------------------------------------------- /eval/grounding_eval/code/Qwen2.5-VL/fun2bbox_eval_qwen.py: -------------------------------------------------------------------------------- 1 | import json 2 | from PIL import Image 3 | import base64 4 | import re 5 | from openai import AsyncClient 6 | import os 7 | import asyncio 8 | import traceback 9 | 10 | def read_jsonl(file_path): 11 | """ 12 | 13 | 
读取 JSONL 文件并解析为 Python 字典列表 14 | :param file_path: JSONL 文件路径 15 | :return: 包含所有 JSON 对象的列表 16 | """ 17 | data = [] 18 | with open(file_path, "r", encoding="utf-8") as f: 19 | for line in f: 20 | data.append(json.loads(line.strip())) 21 | return data 22 | 23 | def load_image_from_path(image_path): 24 | """ 25 | 从指定路径加载图片 26 | :param image_path: 图片文件路径 27 | :return: PIL.Image 对象 28 | """ 29 | try: 30 | image = Image.open(image_path) 31 | return image 32 | except Exception as e: 33 | print(f"Error loading image from {image_path}: {e}") 34 | return None 35 | 36 | # Function to encode the image 37 | def encode_image(image_path): 38 | image = Image.open(image_path) 39 | w, h = image.width, image.height 40 | with open(image_path, "rb") as image_file: 41 | return base64.b64encode(image_file.read()).decode("utf-8"),w, h 42 | 43 | async def call_Qwenvl(item, client, model_name): 44 | """ 45 | 调用Qwen输出function的描述,输出bbox 46 | """ 47 | sys_prompt = '''你是一个GUI组件定位的专家,擅长根据组件的功能描述输出对应的坐标。你的任务是根据给定的GUI截图和图中某个组件的功能描述输出组件的坐标。 48 | 输入:屏幕截图,功能描述 49 | 输出:边界框的绝对坐标,以为格式,使用<>定位,其中不能存在任何非坐标字符 50 | 示例输出:我认为该UI元素在<600,1000>附近 51 | ''' 52 | image_path = item["image"].replace("/home/test/test03", "/home/test/test12") 53 | base64_image, w, h = encode_image(image_path) 54 | content = [] 55 | # 动态添加base64_image部分到 content 列表 56 | content.append({ 57 | "type": "image_url", 58 | "image_url": { 59 | "url": f"data:image/jpeg;base64,{base64_image}", 60 | }, 61 | }) 62 | content.append({ 63 | "type": "text", 64 | "text": "当前屏幕的尺寸为{}*{},屏幕上某一组件的功能描述:{}".format(w, h, item["text"]) 65 | }) 66 | try: 67 | res = await client.chat.completions.create( 68 | messages=[ 69 | { 70 | "role": "system", 71 | "content": sys_prompt, 72 | }, 73 | { 74 | "role":"user", 75 | "content": content, 76 | } 77 | ], 78 | model=model_name, 79 | ) 80 | 81 | response = res.choices[0].message.content 82 | except: 83 | response = 'None' 84 | return response.strip("\n").replace(" ","") 85 | 86 | async def verify(response, ground_truth): 87 | """ 88 | 接受模型的字符串输入,判断是否正确 89 | """ 90 | pattern = r'<\d+,\d+>' 91 | matches = re.findall(pattern, response) 92 | # 将输入字符串转换为整数列表 93 | if matches: 94 | match = matches[0] 95 | bbox = list(map(int, match.strip('<>').split(','))) 96 | gt_bbox = list(map(int, ground_truth.strip('<>').split(','))) 97 | 98 | # 遍历每个值,检查是否在ground truth对应值的±5范围内 99 | gt_x_min = gt_bbox[0] 100 | gt_x_max = gt_bbox[2] 101 | gt_y_min = gt_bbox[1] 102 | gt_y_max = gt_bbox[3] 103 | if gt_x_min<=bbox[0]<=gt_x_max and gt_y_min<=bbox[1]<=gt_y_max: 104 | return 1 105 | else: 106 | print("wrong response: {}".format(response)) 107 | return 0 108 | 109 | async def process_item_async(item, client, model_name, semaphore): 110 | async with semaphore: 111 | bbox = await call_Qwenvl(item, client, model_name) 112 | correct = await verify(bbox,item["abs_position"]) 113 | return correct 114 | 115 | async def main(): 116 | model_name = "Qwen-VL" 117 | total = 0 118 | correct = 0 119 | json_data_path = "your/path/to/the/dataset" 120 | data = read_jsonl(json_data_path) 121 | client = AsyncClient(api_key="sk-123", base_url='http://localhost:8001/v1') 122 | semaphore = asyncio.Semaphore(16) 123 | tasks = [] 124 | for item in data: 125 | task = asyncio.create_task(process_item_async(item, client, model_name,semaphore)) 126 | tasks.append(task) 127 | 128 | results = await asyncio.gather(*tasks) 129 | for result in results: 130 | correct += result 131 | total += 1 132 | print(correct, total, correct / total) 133 | 134 | return 0 135 | 136 | if __name__=="__main__": 137 | 
asyncio.run(main()) -------------------------------------------------------------------------------- /eval/grounding_eval/code/Qwen2.5-VL/text2bbox_eval_qwen.py: -------------------------------------------------------------------------------- 1 | import json 2 | from PIL import Image 3 | import base64 4 | import re 5 | from openai import AsyncClient 6 | import os 7 | import asyncio 8 | 9 | def read_jsonl(file_path): 10 | """ 11 | 读取 JSONL 文件并解析为 Python 字典列表 12 | :param file_path: JSONL 文件路径 13 | :return: 包含所有 JSON 对象的列表 14 | """ 15 | data = [] 16 | with open(file_path, "r", encoding="utf-8") as f: 17 | for line in f: 18 | data.append(json.loads(line.strip())) 19 | return data 20 | 21 | def load_image_from_path(image_path): 22 | """ 23 | 从指定路径加载图片 24 | :param image_path: 图片文件路径 25 | :return: PIL.Image 对象 26 | """ 27 | try: 28 | image = Image.open(image_path) 29 | return image 30 | except Exception as e: 31 | print(f"Error loading image from {image_path}: {e}") 32 | return None 33 | 34 | # Function to encode the image 35 | def encode_image(image_path): 36 | image = Image.open(image_path) 37 | w, h = image.width, image.height 38 | with open(image_path, "rb") as image_file: 39 | return base64.b64encode(image_file.read()).decode("utf-8"),w, h 40 | 41 | async def call_Qwenvl(item, client, model_name): 42 | """ 43 | 调用Qwen输出function的描述,输出bbox 44 | """ 45 | # sys_prompt = '''你是一个GUI组件定位的专家,擅长输出图片上文本对应的坐标。你的任务是根据给定的GUI截图和图中某个文本输出该文本的坐标。 46 | # 输入:屏幕截图,文本描述 47 | # 输出:文本的绝对坐标的中心点,以为格式,使用<>定位,其中不能存在任何非坐标字符,注意中心点应当是两个坐标而不是四个。 48 | # 示例输出一:我认为该文本在<600,1000>附近 49 | # 示例输出二:该文本的位置是<1238,430>''' 50 | sys_prompt = ''' 51 | 你是一个GUI组件定位的专家,擅长输出图片上文本对应的坐标。你的任务是根据给定的GUI截图和图中某个文本输出该文本的坐标。 52 | 输入:屏幕截图,文本描述 53 | 输出:文本的绝对坐标的中心点,以为格式,使用<>定位,其中不能存在任何非坐标字符,注意中心点应当是两个坐标而不是四个。 54 | 示例输出一:我认为该文本在<600,1000>附近 55 | ''' 56 | image_path = item["image"].replace("/home/test/test03","/home/test/test12") 57 | base64_image, w, h = encode_image(image_path) 58 | content = [] 59 | # 动态添加base64_image部分到 content 列表 60 | content.append({ 61 | "type": "image_url", 62 | "image_url": { 63 | "url": f"data:image/jpeg;base64,{base64_image}", 64 | }, 65 | }) 66 | content.append({ 67 | "type": "text", 68 | "text": "当前屏幕的尺寸为{}*{},屏幕上的文本:{}".format(w, h, item["text"]) 69 | }) 70 | try: 71 | res = await client.chat.completions.create( 72 | messages=[ 73 | { 74 | "role": "system", 75 | "content": sys_prompt, 76 | }, 77 | { 78 | "role":"user", 79 | "content": content, 80 | } 81 | ], 82 | model=model_name, 83 | temperature=0 84 | ) 85 | 86 | response = res.choices[0].message.content 87 | except: 88 | response = 'None' 89 | return response.strip("\n").replace(" ","") 90 | 91 | async def verify(response, ground_truth): 92 | """ 93 | 接受模型的字符串输入,判断是否正确 94 | """ 95 | pattern = r'<\d+,\d+>' 96 | matches = re.findall(pattern, response) 97 | # 将输入字符串转换为整数列表 98 | if matches: 99 | match = matches[0] 100 | bbox = list(map(int, match.strip('<>').split(','))) 101 | gt_bbox = list(map(int, ground_truth.strip('<>').split(','))) 102 | 103 | gt_x_min = gt_bbox[0] 104 | gt_x_max = gt_bbox[2] 105 | gt_y_min = gt_bbox[1] 106 | gt_y_max = gt_bbox[3] 107 | print(bbox, gt_bbox) 108 | if gt_x_min<=bbox[0]<=gt_x_max and gt_y_min<=bbox[1]<=gt_y_max: 109 | return 1 110 | else: 111 | print("wrong response: {}".format(response)) 112 | return 0 113 | 114 | async def process_item_async(item, client, model_name, semaphore): 115 | async with semaphore: 116 | bbox = await call_Qwenvl(item, client, model_name) 117 | correct = await verify(bbox,item["abs_position"]) 118 | return 
correct 119 | 120 | async def main(): 121 | model_name = "Qwen-VL" 122 | total = 0 123 | correct = 0 124 | json_data_path = "your/path/to/the/dataset" 125 | data = read_jsonl(json_data_path) 126 | client = AsyncClient(api_key="sk-123", base_url='http://localhost:8001/v1') 127 | semaphore = asyncio.Semaphore(16) 128 | tasks = [] 129 | for item in data: 130 | task = asyncio.create_task(process_item_async(item, client, model_name,semaphore)) 131 | tasks.append(task) 132 | 133 | results = await asyncio.gather(*tasks) 134 | for result in results: 135 | correct += result 136 | total += 1 137 | print(correct, total, correct / total) 138 | return 0 139 | 140 | if __name__=="__main__": 141 | asyncio.run(main()) -------------------------------------------------------------------------------- /eval/grounding_eval/code/UI-TARS/bbox2text_eval_uitars.py: -------------------------------------------------------------------------------- 1 | import json 2 | from PIL import Image 3 | import base64 4 | import re 5 | from openai import AsyncClient 6 | import os 7 | import asyncio 8 | 9 | def read_jsonl(file_path): 10 | """ 11 | 读取 JSONL 文件并解析为 Python 字典列表 12 | :param file_path: JSONL 文件路径 13 | :return: 包含所有 JSON 对象的列表 14 | """ 15 | data = [] 16 | with open(file_path, "r", encoding="utf-8") as f: 17 | for line in f: 18 | data.append(json.loads(line.strip())) 19 | return data 20 | 21 | def load_image_from_path(image_path): 22 | """ 23 | 从指定路径加载图片 24 | :param image_path: 图片文件路径 25 | :return: PIL.Image 对象 26 | """ 27 | try: 28 | image = Image.open(image_path) 29 | return image 30 | except Exception as e: 31 | print(f"Error loading image from {image_path}: {e}") 32 | return None 33 | 34 | # Function to encode the image 35 | def encode_image(image_path): 36 | image = Image.open(image_path) 37 | w, h = image.width, image.height 38 | with open(image_path, "rb") as image_file: 39 | return base64.b64encode(image_file.read()).decode("utf-8"),w, h 40 | 41 | def process_position(item, w, h): 42 | pattern = r'<\d+, \d+, \d+, \d+>' 43 | matches = re.findall(pattern, item["abs_position"]) 44 | if matches: 45 | match = matches[0] 46 | bbox = list(map(int, match.strip('<>').split(','))) 47 | rel_position = [int(bbox[0]/w*1000), int(bbox[1]/h*1000),int(bbox[2]/w*1000),int(bbox[3]/h*1000)] 48 | return rel_position[0],rel_position[1],rel_position[2],rel_position[3] 49 | 50 | async def call_Qwenvl(item, client, model_name): 51 | """ 52 | 调用Qwen输出function的描述,输出bbox 53 | """ 54 | sys_prompt = '''你是一个GUI组件文字识别的专家,擅长根据组件的边界框(bounding box)描述输出对应的文字。你的任务是根据给定的GUI截图和图中某个组件的边界框输出组件的中的文字。 55 | 输入:屏幕截图,边界框的相对坐标,缩放至0~1000, Action:click(start_box='(x_min,y_min,x_max,y_max)')的格式表示 56 | 输出:组件中的文本,注意是文字而非坐标!''' 57 | image_path = item["image"] 58 | base64_image, w, h = encode_image(image_path) 59 | bbox = process_position(item, w, h) 60 | content = [] 61 | # 动态添加base64_image部分到 content 列表 62 | content.append({ 63 | "type": "image_url", 64 | "image_url": { 65 | "url": f"data:image/jpeg;base64,{base64_image}", 66 | }, 67 | }) 68 | content.append({ 69 | "type": "text", 70 | "text": "屏幕上某一组件的边界框:Action:click(start_box='({},{},{},{})')".format(bbox[0],bbox[1],bbox[2],bbox[3]) 71 | }) 72 | 73 | res = await client.chat.completions.create( 74 | messages=[ 75 | { 76 | "role": "system", 77 | "content": sys_prompt, 78 | }, 79 | { 80 | "role":"user", 81 | "content": content, 82 | } 83 | ], 84 | model=model_name, 85 | max_tokens = 256 86 | ) 87 | 88 | response = res.choices[0].message.content 89 | return response.strip("\n").replace(" ","") 90 | 91 | async def 
verify(response, ground_truth): 92 | """ 93 | 接受模型的字符串输入,判断是否正确 94 | """ 95 | if response == ground_truth: 96 | return 1 97 | else: 98 | return 0 99 | 100 | async def process_item_async(item, client, model_name, semaphore): 101 | async with semaphore: 102 | response = await call_Qwenvl(item, client, model_name) 103 | correct = await verify(response,item["text"]) 104 | return correct 105 | 106 | async def main(): 107 | model_name = "uitars" 108 | total = 0 109 | correct = 0 110 | json_data_path = "your/path/to/the/dataset" 111 | data = read_jsonl(json_data_path) 112 | client = AsyncClient(api_key="sk-123", base_url='http://localhost:8000/v1') 113 | semaphore = asyncio.Semaphore(16) 114 | tasks = [] 115 | for item in data: 116 | task = asyncio.create_task(process_item_async(item, client, model_name,semaphore)) 117 | tasks.append(task) 118 | 119 | results = await asyncio.gather(*tasks) 120 | for result in results: 121 | correct += result 122 | total += 1 123 | print(correct, total, correct / total) 124 | 125 | return 0 126 | 127 | if __name__=="__main__": 128 | asyncio.run(main()) -------------------------------------------------------------------------------- /eval/grounding_eval/code/UI-TARS/fun2bbox_eval_uitars.py: -------------------------------------------------------------------------------- 1 | import json 2 | from PIL import Image 3 | import base64 4 | import re 5 | from openai import AsyncClient 6 | import os 7 | import asyncio 8 | import traceback 9 | 10 | def read_jsonl(file_path): 11 | """ 12 | 13 | 读取 JSONL 文件并解析为 Python 字典列表 14 | :param file_path: JSONL 文件路径 15 | :return: 包含所有 JSON 对象的列表 16 | """ 17 | data = [] 18 | with open(file_path, "r", encoding="utf-8") as f: 19 | for line in f: 20 | data.append(json.loads(line.strip())) 21 | return data 22 | 23 | def load_image_from_path(image_path): 24 | """ 25 | 从指定路径加载图片 26 | :param image_path: 图片文件路径 27 | :return: PIL.Image 对象 28 | """ 29 | try: 30 | image = Image.open(image_path) 31 | return image 32 | except Exception as e: 33 | print(f"Error loading image from {image_path}: {e}") 34 | return None 35 | 36 | # Function to encode the image 37 | def encode_image(image_path): 38 | image = Image.open(image_path) 39 | w, h = image.width, image.height 40 | with open(image_path, "rb") as image_file: 41 | return base64.b64encode(image_file.read()).decode("utf-8"),w, h 42 | 43 | async def call_Qwenvl(item, client, model_name): 44 | """ 45 | 调用Qwen输出function的描述,输出bbox 46 | """ 47 | sys_prompt = '''你是一个GUI组件定位的专家,擅长根据组件的功能描述输出对应的坐标。你的任务是根据给定的GUI截图和图中某个组件的功能描述输出组件的坐标。 48 | 输入:屏幕截图,功能描述 49 | 输出:边界框的坐标,使用click 50 | 示例输出:Action:click(start_box='(66,263)') 51 | ''' 52 | image_path = item["image"].replace("/home/test/test03", "/home/test/test12") 53 | base64_image, w, h = encode_image(image_path) 54 | content = [] 55 | # 动态添加base64_image部分到 content 列表 56 | content.append({ 57 | "type": "image_url", 58 | "image_url": { 59 | "url": f"data:image/jpeg;base64,{base64_image}", 60 | }, 61 | }) 62 | content.append({ 63 | "type": "text", 64 | "text": "屏幕上某一组件的功能描述:{}".format(item["text"]) 65 | }) 66 | 67 | res = await client.chat.completions.create( 68 | messages=[ 69 | { 70 | "role": "system", 71 | "content": sys_prompt, 72 | }, 73 | { 74 | "role":"user", 75 | "content": content, 76 | } 77 | ], 78 | model=model_name, 79 | ) 80 | 81 | response = res.choices[0].message.content 82 | return response.strip("\n").replace(" ","") 83 | 84 | async def verify(response, ground_truth, w, h): 85 | """ 86 | 接受模型的字符串输入,判断是否正确 87 | """ 88 | pattern = r'(\d+,\d+)' 89 | 
matches = re.findall(pattern, response) 90 | # 将输入字符串转换为整数列表 91 | if matches: 92 | match = matches[0] 93 | try: 94 | x, y = map(int, match.split(',')) 95 | x = x/1000*w 96 | y = y/1000*h 97 | gt_bbox = list(map(int, ground_truth.strip('<>').split(','))) 98 | # 遍历每个值,检查是否在ground truth对应值的±5范围内 99 | gt_x_min = gt_bbox[0] 100 | gt_x_max = gt_bbox[2] 101 | gt_y_min = gt_bbox[1] 102 | gt_y_max = gt_bbox[3] 103 | print([x,y], [gt_x_min,gt_y_min,gt_x_max,gt_y_max]) 104 | if gt_x_min<=x<=gt_x_max and gt_y_min<=y<=gt_y_max: 105 | return 1 106 | else: 107 | return 0 108 | except: 109 | return 0 110 | 111 | else: 112 | print("wrong response: {}".format(response)) 113 | return 0 114 | 115 | async def process_item_async(item, client, model_name, semaphore): 116 | async with semaphore: 117 | bbox = await call_Qwenvl(item, client, model_name) 118 | image = Image.open(item["image"]) 119 | w, h = image.width,image.height 120 | correct = await verify(bbox,item["abs_position"], w, h) 121 | return correct 122 | 123 | async def main(): 124 | model_name = "uitars" 125 | total = 0 126 | correct = 0 127 | json_data_path = "your/path/to/the/dataset" 128 | data = read_jsonl(json_data_path) 129 | client = AsyncClient(api_key="sk-123", base_url='http://localhost:8000/v1/') 130 | semaphore = asyncio.Semaphore(16) 131 | tasks = [] 132 | for item in data: 133 | task = asyncio.create_task(process_item_async(item, client, model_name,semaphore)) 134 | tasks.append(task) 135 | 136 | results = await asyncio.gather(*tasks) 137 | for result in results: 138 | correct += result 139 | total += 1 140 | print(correct, total, correct/total) 141 | return 0 142 | 143 | if __name__=="__main__": 144 | asyncio.run(main()) -------------------------------------------------------------------------------- /eval/grounding_eval/code/UI-TARS/text2bbox_eval_uitars.py: -------------------------------------------------------------------------------- 1 | import json 2 | from PIL import Image 3 | import base64 4 | import re 5 | from openai import AsyncClient 6 | import os 7 | import asyncio 8 | 9 | def read_jsonl(file_path): 10 | """ 11 | 读取 JSONL 文件并解析为 Python 字典列表 12 | :param file_path: JSONL 文件路径 13 | :return: 包含所有 JSON 对象的列表 14 | """ 15 | data = [] 16 | with open(file_path, "r", encoding="utf-8") as f: 17 | for line in f: 18 | data.append(json.loads(line.strip())) 19 | return data 20 | 21 | def load_image_from_path(image_path): 22 | """ 23 | 从指定路径加载图片 24 | :param image_path: 图片文件路径 25 | :return: PIL.Image 对象 26 | """ 27 | try: 28 | image = Image.open(image_path) 29 | return image 30 | except Exception as e: 31 | print(f"Error loading image from {image_path}: {e}") 32 | return None 33 | 34 | # Function to encode the image 35 | def encode_image(image_path): 36 | image = Image.open(image_path) 37 | w, h = image.width, image.height 38 | with open(image_path, "rb") as image_file: 39 | return base64.b64encode(image_file.read()).decode("utf-8"),w, h 40 | 41 | async def call_Qwenvl(item, client, model_name): 42 | """ 43 | 调用Qwen输出function的描述,输出bbox 44 | """ 45 | # sys_prompt = '''你是一个GUI组件定位的专家,擅长输出图片上文本对应的坐标。你的任务是根据给定的GUI截图和图中某个文本输出该文本的坐标。 46 | # 输入:屏幕截图,文本描述 47 | # 输出:文本的绝对坐标的中心点,以为格式,使用<>定位,其中不能存在任何非坐标字符,注意中心点应当是两个坐标而不是四个。 48 | # 示例输出一:我认为该文本在<600,1000>附近 49 | # 示例输出二:该文本的位置是<1238,430>''' 50 | sys_prompt = ''' 51 | 你是一个GUI组件定位的专家,擅长输出图片上文本对应的坐标。你的任务是根据给定的GUI截图和图中某个文本输出该文本的坐标。 52 | 输入:屏幕截图,文本描述 53 | 输出:边界框的坐标,使用click 54 | 示例输出:Action:click(start_box='(66,263)') 55 | ''' 56 | image_path = item["image"].replace("/home/test/test03","/home/test/test12") 
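# Illustrative note (annotation): the replace() call above rewrites a machine-specific
# path prefix baked into the annotation files; adjust or remove this mapping so that
# item["image"] resolves to wherever the screenshots are stored in your environment.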
57 | base64_image, w, h = encode_image(image_path) 58 | content = [] 59 | # 动态添加base64_image部分到 content 列表 60 | content.append({ 61 | "type": "image_url", 62 | "image_url": { 63 | "url": f"data:image/jpeg;base64,{base64_image}", 64 | }, 65 | }) 66 | content.append({ 67 | "type": "text", 68 | "text": "屏幕上的文本:{}".format(item["text"]) 69 | }) 70 | 71 | res = await client.chat.completions.create( 72 | messages=[ 73 | { 74 | "role": "system", 75 | "content": sys_prompt, 76 | }, 77 | { 78 | "role":"user", 79 | "content": content, 80 | } 81 | ], 82 | model=model_name, 83 | temperature=0 84 | ) 85 | 86 | response = res.choices[0].message.content 87 | return response.strip("\n").replace(" ","") 88 | 89 | async def verify(response, ground_truth, w, h): 90 | """ 91 | 接受模型的字符串输入,判断是否正确 92 | """ 93 | pattern = r'(\d+,\d+)' 94 | matches = re.findall(pattern, response) 95 | # 将输入字符串转换为整数列表 96 | if matches: 97 | match = matches[0] 98 | try: 99 | x, y = map(int, match.split(',')) 100 | x = x/1000*w 101 | y = y/1000*h 102 | gt_bbox = list(map(int, ground_truth.strip('<>').split(','))) 103 | 104 | gt_x_min = gt_bbox[0] 105 | gt_x_max = gt_bbox[2] 106 | gt_y_min = gt_bbox[1] 107 | gt_y_max = gt_bbox[3] 108 | if gt_x_min<=x<=gt_x_max and gt_y_min<=y<=gt_y_max: 109 | return 1 110 | except: 111 | return 0 112 | else: 113 | print("wrong response: {}".format(response)) 114 | return 0 115 | 116 | async def process_item_async(item, client, model_name, semaphore): 117 | async with semaphore: 118 | bbox = await call_Qwenvl(item, client, model_name) 119 | image = Image.open(item["image"]) 120 | w, h = image.width,image.height 121 | correct = await verify(bbox,item["abs_position"], w, h) 122 | return correct 123 | 124 | async def main(): 125 | model_name = "uitars" 126 | total = 0 127 | correct = 0 128 | json_data_path = "your/path/to/the/dataset" 129 | data = read_jsonl(json_data_path) 130 | client = AsyncClient(api_key="sk-123", base_url='http://localhost:8000/v1') 131 | semaphore = asyncio.Semaphore(16) 132 | tasks = [] 133 | for item in data: 134 | task = asyncio.create_task(process_item_async(item, client, model_name,semaphore)) 135 | tasks.append(task) 136 | 137 | results = await asyncio.gather(*tasks) 138 | for result in results: 139 | correct += result 140 | total += 1 141 | print(correct, total, correct / total) 142 | return 0 143 | 144 | if __name__=="__main__": 145 | asyncio.run(main()) -------------------------------------------------------------------------------- /eval/grounding_eval/code/minicpm/bbox2text_eval_minicpm.py: -------------------------------------------------------------------------------- 1 | import json 2 | from PIL import Image 3 | import base64 4 | import re 5 | from openai import AsyncClient 6 | import os 7 | import asyncio 8 | import io 9 | 10 | def read_jsonl(file_path): 11 | """ 12 | 读取 JSONL 文件并解析为 Python 字典列表 13 | :param file_path: JSONL 文件路径 14 | :return: 包含所有 JSON 对象的列表 15 | """ 16 | data = [] 17 | with open(file_path, "r", encoding="utf-8") as f: 18 | for line in f: 19 | data.append(json.loads(line.strip())) 20 | return data 21 | 22 | def load_image_from_path(image_path): 23 | """ 24 | 从指定路径加载图片 25 | :param image_path: 图片文件路径 26 | :return: PIL.Image 对象 27 | """ 28 | try: 29 | image = Image.open(image_path) 30 | return image 31 | except Exception as e: 32 | print(f"Error loading image from {image_path}: {e}") 33 | return None 34 | 35 | # Function to encode the image 36 | def encode_image(image_path): 37 | image = Image.open(image_path) 38 | w, h = image.width, image.height 39 | 
image = resize(image) 40 | buffered = io.BytesIO() 41 | image.save(buffered, format="JPEG") # 保存为 JPEG 格式 42 | image_bytes = buffered.getvalue() 43 | 44 | # 对字节流进行 Base64 编码 45 | base64_encoded = base64.b64encode(image_bytes).decode("utf-8") 46 | return base64_encoded, w, h 47 | 48 | def resize(origin_img): 49 | resolution = origin_img.size 50 | w,h = resolution 51 | max_line_res = 1120 52 | if max_line_res is not None: 53 | max_line = max_line_res 54 | if h > max_line: 55 | w = int(w * max_line / h) 56 | h = max_line 57 | if w > max_line: 58 | h = int(h * max_line / w) 59 | w = max_line 60 | img = origin_img.resize((w,h),resample=Image.Resampling.LANCZOS) 61 | return img 62 | 63 | def process_position(item, w, h): 64 | pattern = r'<\d+, \d+, \d+, \d+>' 65 | matches = re.findall(pattern, item["abs_position"]) 66 | if matches: 67 | match = matches[0] 68 | bbox = list(map(int, match.strip('<>').split(','))) 69 | rel_position = [int(bbox[0]/w*1000), int(bbox[1]/h*1000),int(bbox[2]/w*1000),int(bbox[3]/h*1000)] 70 | print(rel_position) 71 | return rel_position[0],rel_position[1],rel_position[2],rel_position[3] 72 | 73 | async def call_Qwenvl(item, client, model_name): 74 | """ 75 | 调用Qwen输出function的描述,输出bbox 76 | """ 77 | sys_prompt = '''你是一个GUI组件文字识别的专家,擅长根据组件的边界框(bounding box)描述输出对应的文字。你的任务是根据给定的GUI截图和图中某个组件的边界框输出组件的中的文字。\n 输入:屏幕截图,边界框的坐标\n 输出:组件中的文本''' 78 | image_path = item["image"] 79 | base64_image, w, h = encode_image(image_path) 80 | x_min, y_min, x_max, y_max = process_position(item, w, h) 81 | content = [] 82 | # 动态添加base64_image部分到 content 列表 83 | content.append({ 84 | "type": "image_url", 85 | "image_url": { 86 | "url": f"data:image/jpeg;base64,{base64_image}", 87 | }, 88 | }) 89 | content.append({ 90 | "type": "text", 91 | "text": "屏幕上某一组件的边界框:{{\"bbox\":[[{},{}],[{},{}]]\}}".format(x_min,y_min,x_max,y_max) 92 | }) 93 | 94 | res = await client.chat.completions.create( 95 | messages=[ 96 | { 97 | "role": "system", 98 | "content": sys_prompt, 99 | }, 100 | { 101 | "role":"user", 102 | "content": content, 103 | } 104 | ], 105 | model=model_name, 106 | ) 107 | 108 | response = res.choices[0].message.content 109 | return response.strip("\n").replace(" ","") 110 | 111 | async def verify(response, ground_truth): 112 | """ 113 | 接受模型的字符串输入,判断是否正确 114 | """ 115 | print(response, ground_truth) 116 | if response == ground_truth: 117 | return 1 118 | 119 | return 0 120 | 121 | async def process_item_async(item, client, model_name, semaphore): 122 | async with semaphore: 123 | response = await call_Qwenvl(item, client, model_name) 124 | correct = await verify(response,item["text"]) 125 | return correct 126 | 127 | async def main(): 128 | model_name = "minicpm" 129 | total = 0 130 | correct = 0 131 | json_data_path = "your/path/to/the/dataset" 132 | data = read_jsonl(json_data_path) 133 | client = AsyncClient(api_key="sk-123", base_url='http://localhost:8000/v1') 134 | semaphore = asyncio.Semaphore(16) 135 | tasks = [] 136 | for item in data: 137 | task = asyncio.create_task(process_item_async(item, client, model_name,semaphore)) 138 | tasks.append(task) 139 | 140 | results = await asyncio.gather(*tasks) 141 | for result in results: 142 | correct += result 143 | total += 1 144 | print(correct, total, correct / total) 145 | 146 | return 0 147 | 148 | if __name__=="__main__": 149 | asyncio.run(main()) -------------------------------------------------------------------------------- /eval/grounding_eval/code/minicpm/fun2bbox_eval_minicpm.py: 
-------------------------------------------------------------------------------- 1 | import json 2 | from PIL import Image 3 | import base64 4 | import re 5 | from openai import AsyncClient 6 | import os 7 | import asyncio 8 | import traceback 9 | import io 10 | 11 | def read_jsonl(file_path): 12 | """ 13 | 14 | 读取 JSONL 文件并解析为 Python 字典列表 15 | :param file_path: JSONL 文件路径 16 | :return: 包含所有 JSON 对象的列表 17 | """ 18 | data = [] 19 | with open(file_path, "r", encoding="utf-8") as f: 20 | for line in f: 21 | data.append(json.loads(line.strip())) 22 | return data 23 | 24 | def load_image_from_path(image_path): 25 | """ 26 | 从指定路径加载图片 27 | :param image_path: 图片文件路径 28 | :return: PIL.Image 对象 29 | """ 30 | try: 31 | image = Image.open(image_path) 32 | return image 33 | except Exception as e: 34 | print(f"Error loading image from {image_path}: {e}") 35 | return None 36 | 37 | # Function to encode the image 38 | def encode_image(image_path): 39 | image = Image.open(image_path) 40 | w, h = image.width, image.height 41 | image = resize(image) 42 | buffered = io.BytesIO() 43 | image.save(buffered, format="JPEG") # 保存为 JPEG 格式 44 | image_bytes = buffered.getvalue() 45 | 46 | # 对字节流进行 Base64 编码 47 | base64_encoded = base64.b64encode(image_bytes).decode("utf-8") 48 | return base64_encoded, w, h 49 | 50 | def resize(origin_img): 51 | resolution = origin_img.size 52 | w,h = resolution 53 | max_line_res = 1120 54 | if max_line_res is not None: 55 | max_line = max_line_res 56 | if h > max_line: 57 | w = int(w * max_line / h) 58 | h = max_line 59 | if w > max_line: 60 | h = int(h * max_line / w) 61 | w = max_line 62 | img = origin_img.resize((w,h),resample=Image.Resampling.LANCZOS) 63 | return img 64 | 65 | async def call_Qwenvl(item, client, model_name): 66 | """ 67 | 调用Qwen输出function的描述,输出bbox 68 | """ 69 | sys_prompt = '''你是一个GUI组件定位的专家,擅长根据组件的功能描述输出对应的坐标。你的下一步操作是根据给定的GUI截图和图中某个组件的功能描述点击组件的中心位置。坐标为相对于屏幕左上角位原点的相对位置,并且按照宽高比例缩放到0~1000\n 输入:屏幕截图,功能描述\n 输出:点击操作,以{\"POINT\":[...,...]}为格式,其中不能存在任何非坐标字符''' 70 | image_path = item["image"] 71 | base64_image, w, h = encode_image(image_path) 72 | content = [] 73 | # 动态添加base64_image部分到 content 列表 74 | content.append({ 75 | "type": "image_url", 76 | "image_url": { 77 | "url": f"data:image/jpeg;base64,{base64_image}", 78 | }, 79 | }) 80 | content.append({ 81 | "type": "text", 82 | "text": "屏幕上某一组件的功能描述:{}".format(item["text"]) 83 | }) 84 | 85 | res = await client.chat.completions.create( 86 | messages=[ 87 | { 88 | "role": "system", 89 | "content": sys_prompt, 90 | }, 91 | { 92 | "role":"user", 93 | "content": content, 94 | } 95 | ], 96 | model=model_name, 97 | ) 98 | 99 | response = res.choices[0].message.content 100 | return response.strip("\n").replace(" ","") 101 | 102 | async def verify(response, ground_truth,w ,h): 103 | """ 104 | 接受模型的字符串输入,判断是否正确,当前判断方法:点是否落在bbox内 105 | """ 106 | try: 107 | json_action = json.loads(response) 108 | bbox = json_action["POINT"] 109 | abs_bbox = [bbox[0]/1000*w, bbox[1]/1000*h] 110 | gt_bbox = list(map(int, ground_truth.strip('<>').split(','))) 111 | gt_x_min = gt_bbox[0] 112 | gt_x_max = gt_bbox[2] 113 | gt_y_min = gt_bbox[1] 114 | gt_y_max = gt_bbox[3] 115 | if gt_x_min<=abs_bbox[0]<=gt_x_max and gt_y_min<=abs_bbox[1]<=gt_y_max: 116 | return 1 117 | except Exception as e: 118 | print("wrong response: {}".format(response)) 119 | return 0 120 | 121 | async def process_item_async(item, client, model_name, semaphore): 122 | async with semaphore: 123 | bbox = await call_Qwenvl(item, client, model_name) 124 | image = 
Image.open(item["image"]) 125 | w, h = image.width, image.height 126 | correct = await verify(bbox,item["abs_position"],w ,h) 127 | return correct 128 | 129 | async def main(): 130 | model_name = "minicpm" 131 | total = 0 132 | correct = 0 133 | json_data_path = "your/path/to/the/dataset" 134 | data = read_jsonl(json_data_path) 135 | client = AsyncClient(api_key="sk-123", base_url='http://localhost:8000/v1/') 136 | semaphore = asyncio.Semaphore(16) 137 | tasks = [] 138 | for item in data: 139 | task = asyncio.create_task(process_item_async(item, client, model_name,semaphore)) 140 | tasks.append(task) 141 | 142 | results = await asyncio.gather(*tasks) 143 | for result in results: 144 | correct += result 145 | total += 1 146 | print(correct, total, correct / total) 147 | 148 | return 0 149 | 150 | if __name__=="__main__": 151 | asyncio.run(main()) -------------------------------------------------------------------------------- /eval/grounding_eval/code/minicpm/text2bbox_eval_minicpm.py: -------------------------------------------------------------------------------- 1 | import json 2 | from PIL import Image 3 | import base64 4 | import re 5 | from openai import AsyncClient 6 | import os 7 | import asyncio 8 | import io 9 | 10 | def read_jsonl(file_path): 11 | """ 12 | 读取 JSONL 文件并解析为 Python 字典列表 13 | :param file_path: JSONL 文件路径 14 | :return: 包含所有 JSON 对象的列表 15 | """ 16 | data = [] 17 | with open(file_path, "r", encoding="utf-8") as f: 18 | for line in f: 19 | data.append(json.loads(line.strip())) 20 | return data 21 | 22 | def load_image_from_path(image_path): 23 | """ 24 | 从指定路径加载图片 25 | :param image_path: 图片文件路径 26 | :return: PIL.Image 对象 27 | """ 28 | try: 29 | image = Image.open(image_path) 30 | return image 31 | except Exception as e: 32 | print(f"Error loading image from {image_path}: {e}") 33 | return None 34 | 35 | # Function to encode the image 36 | def encode_image(image_path): 37 | image = Image.open(image_path) 38 | image = resize(image) 39 | w, h = image.width, image.height 40 | buffered = io.BytesIO() 41 | image.save(buffered, format="JPEG") # 保存为 JPEG 格式 42 | image_bytes = buffered.getvalue() 43 | 44 | # 对字节流进行 Base64 编码 45 | base64_encoded = base64.b64encode(image_bytes).decode("utf-8") 46 | return base64_encoded, w, h 47 | 48 | def resize(origin_img): 49 | resolution = origin_img.size 50 | w,h = resolution 51 | max_line_res = 1120 52 | if max_line_res is not None: 53 | max_line = max_line_res 54 | if h > max_line: 55 | w = int(w * max_line / h) 56 | h = max_line 57 | if w > max_line: 58 | h = int(h * max_line / w) 59 | w = max_line 60 | img = origin_img.resize((w,h),resample=Image.Resampling.LANCZOS) 61 | return img 62 | 63 | async def call_Qwenvl(item, client, model_name): 64 | """ 65 | 调用Qwen输出function的描述,输出bbox 66 | """ 67 | # sys_prompt = '''你是一个GUI组件定位的专家,擅长输出图片上文本对应的坐标。你的任务是根据给定的GUI截图和图中某个文本输出该文本的坐标。 68 | # 输入:屏幕截图,文本描述 69 | # 输出:文本的绝对坐标的中心点,以为格式,使用<>定位,其中不能存在任何非坐标字符,注意中心点应当是两个坐标而不是四个。 70 | # 示例输出一:我认为该文本在<600,1000>附近 71 | # 示例输出二:该文本的位置是<1238,430>''' 72 | sys_prompt = ''' 73 | 你是一个GUI组件定位的专家,擅长输出图片上文本对应的坐标。你的任务是根据给定的GUI截图和图中某个文本输出该文本的坐标。\n 输入:屏幕截图,文本描述\n 输出:文本的相对坐标的中心点,{\"POINT\":[...,...]}为格式 74 | ''' 75 | image_path = item["image"] 76 | base64_image, w, h = encode_image(image_path) 77 | content = [] 78 | # 动态添加base64_image部分到 content 列表 79 | content.append({ 80 | "type": "image_url", 81 | "image_url": { 82 | "url": f"data:image/jpeg;base64,{base64_image}", 83 | }, 84 | }) 85 | content.append({ 86 | "type": "text", 87 | "text": "屏幕上的文本:## 
Text\n{}\n".format(item['text']) 88 | }) 89 | 90 | res = await client.chat.completions.create( 91 | messages=[ 92 | { 93 | "role": "system", 94 | "content": sys_prompt, 95 | }, 96 | { 97 | "role":"user", 98 | "content": content, 99 | } 100 | ], 101 | model=model_name, 102 | temperature=0 103 | ) 104 | 105 | response = res.choices[0].message.content 106 | return response.strip("\n").replace(" ","") 107 | 108 | async def verify(response, ground_truth, w ,h): 109 | """ 110 | 接受模型的字符串输入,判断是否正确,当前判断方法:点是否落在bbox内 111 | """ 112 | try: 113 | json_action = json.loads(response) 114 | bbox = json_action["POINT"] 115 | abs_bbox = [bbox[0]/1000*w, bbox[1]/1000*h] 116 | gt_bbox = list(map(int, ground_truth.strip('<>').split(','))) 117 | gt_x_min = gt_bbox[0] 118 | gt_x_max = gt_bbox[2] 119 | gt_y_min = gt_bbox[1] 120 | gt_y_max = gt_bbox[3] 121 | if gt_x_min<=abs_bbox[0]<=gt_x_max and gt_y_min<=abs_bbox[1]<=gt_y_max: 122 | return 1 123 | except Exception as e: 124 | print("wrong response: {}".format(response)) 125 | return 0 126 | 127 | async def process_item_async(item, client, model_name, semaphore): 128 | async with semaphore: 129 | image = Image.open(item["image"]) 130 | w = image.width 131 | h = image.height 132 | bbox = await call_Qwenvl(item, client, model_name) 133 | correct = await verify(bbox,item["abs_position"],w,h) 134 | return correct 135 | 136 | async def main(): 137 | model_name = "minicpm" 138 | total = 0 139 | correct = 0 140 | json_data_path = "your/path/to/the/dataset" 141 | data = read_jsonl(json_data_path) 142 | client = AsyncClient(api_key="sk-123", base_url='http://localhost:8000/v1') 143 | semaphore = asyncio.Semaphore(8) 144 | tasks = [] 145 | for item in data: 146 | task = asyncio.create_task(process_item_async(item, client, model_name,semaphore)) 147 | tasks.append(task) 148 | 149 | results = await asyncio.gather(*tasks) 150 | for result in results: 151 | correct += result 152 | total += 1 153 | print(correct, total, correct / total) 154 | return 0 155 | 156 | if __name__=="__main__": 157 | asyncio.run(main()) 158 | -------------------------------------------------------------------------------- /eval/run_predict_minicpm.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import multiprocessing 3 | import os 4 | os.chdir(os.path.dirname(os.path.abspath(__file__))) 5 | import json 6 | import torch 7 | import random 8 | import jsonschema 9 | from tqdm import tqdm 10 | from transformers import AutoTokenizer,AutoModelForCausalLM 11 | from concurrent.futures import ProcessPoolExecutor,as_completed,ThreadPoolExecutor 12 | from PIL import Image 13 | from utils.utils import get_dataset_dir 14 | import argparse 15 | import logging 16 | import time 17 | 18 | DEVICES = [ 19 | "cuda:0", "cuda:1", "cuda:2", "cuda:3", 20 | "cuda:4","cuda:5", "cuda:6", "cuda:7", 21 | ] 22 | 23 | current_file_path = os.path.abspath(__file__) 24 | current_dir = os.path.dirname(current_file_path) 25 | 26 | if current_dir not in sys.path: 27 | sys.path.append(current_dir) 28 | 29 | def compact_json_dumps(obj): 30 | return json.dumps(obj, indent=None, separators=(",", ":"), ensure_ascii=False) 31 | 32 | ACTION_SCHEMA = json.load(open(os.path.join(current_dir, 'utils/schema', 'schema.json'), encoding="utf-8")) 33 | items = list(ACTION_SCHEMA.items()) 34 | insert_index = 3 35 | items.insert(insert_index, ("required", ["thought"])) # enable/disable thought by setting it to "required"/"optional" 36 | ACTION_SCHEMA = dict(items) 37 | SYSTEM_PROMPT = f'''# Role 38 | 
你是一名熟悉安卓系统触屏GUI操作的智能体,将根据用户的问题,分析当前界面的GUI元素和布局,生成相应的操作。 39 | 40 | # Task 41 | 针对用户问题,根据输入的当前屏幕截图,输出下一步的操作。 42 | 43 | # Rule 44 | - 以紧凑JSON格式输出 45 | - 输出操作必须遵循Schema约束 46 | 47 | # Schema 48 | {json.dumps(ACTION_SCHEMA, indent=None, ensure_ascii=False, separators=(',', ':'))}''' 49 | 50 | EXTRACT_SCHEMA = json.load(open(os.path.join(current_dir, 'utils/schema', 'schema_for_extraction.json'), encoding="utf-8")) 51 | 52 | 53 | _llm = None 54 | _tokenizer = None 55 | 56 | def _init_llm(model_name): 57 | global _llm,_tokenizer 58 | if _llm is None: 59 | _llm = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True,torch_dtype=torch.bfloat16) 60 | if _tokenizer is None: 61 | _tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) 62 | 63 | def move_to(device): 64 | global _llm,_tokenizer 65 | if _llm is None: 66 | raise ValueError("Error, LLM is not initialized.") 67 | _llm = _llm.to(device) 68 | if _tokenizer is None: 69 | raise ValueError("Error, Tokenizer is not initialized.") 70 | return f"Moved to {device}" 71 | 72 | 73 | def run_episode(episode, msg,): 74 | global _llm,_tokenizer 75 | outputs = _llm.chat(image=None, msgs=msg, system_prompt=SYSTEM_PROMPT, tokenizer=_tokenizer, temperature=0.1,top_p=0.3,n=1,) 76 | episode["pred"] = extract_and_validate_json(outputs) 77 | return episode 78 | 79 | 80 | def extract_and_validate_json(input_string): 81 | try: 82 | json_obj = json.loads(input_string) 83 | jsonschema.validate(json_obj, EXTRACT_SCHEMA) 84 | return json_obj 85 | except json.JSONDecodeError as e: 86 | print("Error, JSON is NOT valid.") 87 | return input_string 88 | except Exception as e: 89 | print(f"Error, JSON is NOT valid according to the schema.{input_string}", e) 90 | return input_string 91 | 92 | def load_image(episode, image_path, data_name): 93 | # resize the image proportionally so that the longer side is at most 1120 94 | def __resize__(origin_img): 95 | resolution = origin_img.size 96 | w,h = resolution 97 | max_line_res = 1120 98 | if max_line_res is not None: 99 | max_line = max_line_res 100 | if h > max_line: 101 | w = int(w * max_line / h) 102 | h = max_line 103 | if w > max_line: 104 | h = int(h * max_line / w) 105 | w = max_line 106 | img = origin_img.resize((w,h),resample=Image.Resampling.LANCZOS) 107 | return img 108 | 109 | image = Image.open(image_path).convert("RGB") 110 | image = __resize__(image) 111 | 112 | if data_name == 'android_control_low_test': 113 | query = episode['low_instruction'] 114 | else: 115 | query = episode['instruction'] 116 | 117 | messages = [] 118 | messages.append( 119 | { 120 | "role": "user", 121 | "content": [ 122 | f"{query}\n当前屏幕截图:", 123 | image 124 | ] 125 | } 126 | ) 127 | return (episode,messages) 128 | 129 | 130 | def predict(args): 131 | args.data_dir, args.split, data_subset = get_dataset_dir(args.data_name) 132 | print(f"Predicting on: {args.data_dir}/{args.split}") 133 | print(f"Data subset: {data_subset}") 134 | 135 | if multiprocessing.get_start_method(allow_none=True) != "spawn": 136 | multiprocessing.set_start_method("spawn", force=True) 137 | 138 | with ProcessPoolExecutor(max_workers=len(DEVICES),initializer=_init_llm,initargs=(args.model_path,)) as poolexec: 139 | tasks = [] 140 | print("Moving model to devices") 141 | futures = [poolexec.submit(move_to, dev) for dev in DEVICES] 142 | for fut in futures: 143 | print(fut.result()) 144 | 145 | for dataset in data_subset: 146 | save_dir = os.path.join(args.output_dir, dataset) 147 | if not os.path.exists(save_dir): 148 | 
os.makedirs(save_dir) 149 | 150 | episode_dir = os.path.join(args.data_dir, args.split, dataset) 151 | output_file = os.path.join(save_dir, "predict.jsonl") 152 | 153 | # Get the list of all episodes files 154 | if os.path.exists(episode_dir): 155 | episodes_files = os.listdir(episode_dir) 156 | else: 157 | continue 158 | 159 | future = [] 160 | all_tasks = [] 161 | print("Loading episodes") 162 | with ThreadPoolExecutor(max_workers=16) as executor: 163 | for episodes_file in episodes_files: 164 | 165 | episodes_path = os.path.join(episode_dir, episodes_file, f"{episodes_file}.json") 166 | try: 167 | with open(episodes_path, 'r', encoding='utf-8') as f: 168 | episodes = json.load(f) 169 | except Exception as e: 170 | print(f"Failed to load {episodes_path}: {e}") 171 | continue 172 | # Skip this file on error 173 | 174 | for episode in episodes: 175 | episode["category"] = dataset 176 | image_path = os.path.join(episode_dir, episodes_file, f"{episodes_file}_{episode['step_id']}.jpeg") 177 | if not os.path.exists(image_path): 178 | image_path = image_path.replace(".jpeg", ".png") 179 | if not os.path.exists(image_path): 180 | image_path = episode['image_path'] 181 | future.append(executor.submit(load_image, episode, image_path, args.data_name)) 182 | 183 | for f in as_completed(future): 184 | all_tasks.append(f.result()) 185 | 186 | with open(output_file, "w", encoding="utf-8") as f_out: 187 | print("Predicting") 188 | tasks = [] 189 | for task_value in all_tasks: 190 | tasks.append(poolexec.submit(run_episode, *task_value)) 191 | 192 | for task in tqdm(as_completed(tasks), total=len(tasks), dynamic_ncols=True): 193 | try: 194 | episode = task.result() 195 | episode_json = json.dumps(episode, ensure_ascii=False) 196 | f_out.write(episode_json + "\n") 197 | f_out.flush() 198 | except Exception as e: 199 | print(f"Error: {e}") 200 | continue 201 | 202 | print(f"Prediction saved at: {output_file}.") 203 | os.system(f"cat {args.output_dir}/*/predict.jsonl > {args.output_dir}/all.jsonl") 204 | print(f"Merged prediction saved at: {args.output_dir}/all.jsonl.") 205 | 206 | 207 | if __name__ == "__main__": 208 | 209 | parser = argparse.ArgumentParser(description="GUI Agent Inference") 210 | parser.add_argument("--seed", type=int, default=2020, help="Random seed") 211 | parser.add_argument("--model_path", type=str, required=True, help="Model path") 212 | parser.add_argument("--output_dir", type=str, required=True, help="Directory to save results") 213 | parser.add_argument("--data_name", type=str, required=True, choices=['gui_odyssey_test', 'chinese_app_test', 'aitz_test', 'android_control_high_test', 'android_control_low_test'], help="Eval dataset name") 214 | args = parser.parse_args() 215 | random.seed(args.seed) 216 | 217 | print(f'Loading model at : {args.model_path}') 218 | print(f'Saving results at: {args.output_dir}') 219 | 220 | predict(args) 221 | -------------------------------------------------------------------------------- /eval/utils/SimHei.ttf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenBMB/AgentCPM-GUI/b3dbe5c68643858351fbf10e0ce2a8922e83bf8e/eval/utils/SimHei.ttf -------------------------------------------------------------------------------- /eval/utils/action_type.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2024 The Google Research Authors. 
3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | """AndroidInTheWild action types.""" 17 | # https://github.com/google-research/google-research/blob/master/android_in_the_wild/action_type.py 18 | 19 | import enum 20 | 21 | 22 | class ActionType(enum.IntEnum): 23 | """Integer values for each supported action type in AndroidInTheWild.""" 24 | 25 | # Placeholders for unused enum values 26 | # UNUSED_0 = 0 # used for long point 27 | # UNUSED_1 = 1 # used for no action 28 | UNUSED_2 = 2 29 | UNUSED_8 = 8 30 | UNUSED_9 = 9 31 | 32 | ########### Agent actions ########### 33 | 34 | LONG_POINT = 0 # long ponint 35 | NO_ACTION = 1 # no action 36 | 37 | # A type action that sends text to the emulator. Note that this simply sends 38 | # text and does not perform any clicks for element focus or enter presses for 39 | # submitting text. 40 | TYPE = 3 41 | 42 | # The dual point action used to represent all gestures. 43 | DUAL_POINT = 4 44 | 45 | # These actions differentiate pressing the home and back button from touches. 46 | # They represent explicit presses of back and home performed using ADB. 47 | PRESS_BACK = 5 48 | PRESS_HOME = 6 49 | 50 | # An action representing that ADB command for hitting enter was performed. 51 | PRESS_ENTER = 7 52 | 53 | ########### Episode status actions ########### 54 | 55 | # An action used to indicate the desired task has been completed and resets 56 | # the environment. This action should also be used in the case that the task 57 | # has already been completed and there is nothing to do. 58 | # e.g. The task is to turn on the Wi-Fi when it is already on 59 | STATUS_TASK_COMPLETE = 10 60 | 61 | # An action used to indicate that desired task is impossible to complete and 62 | # resets the environment. This can be a result of many different things 63 | # including UI changes, Android version differences, etc. 
64 | STATUS_TASK_IMPOSSIBLE = 11 65 | -------------------------------------------------------------------------------- /eval/utils/convert_output.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import jsonschema 4 | from concurrent.futures import ProcessPoolExecutor, as_completed 5 | from tqdm import tqdm 6 | 7 | 8 | # Get the absolute path of the current file 9 | current_file_path = os.path.abspath(__file__) 10 | schema_dir = os.path.dirname(os.path.dirname(current_file_path)) 11 | EXTRACT_SCHEMA = json.load(open(os.path.join(schema_dir, 'utils/schema', 'schema_for_extraction.json'), encoding="utf-8")) 12 | 13 | 14 | def load_json_data(file_path): 15 | data = [] 16 | # Determine file type, support both JSON and JSONL 17 | if file_path.endswith('.json'): 18 | # Handle JSON file 19 | with open(file_path, 'r') as file: 20 | data = json.load(file) 21 | elif file_path.endswith('.jsonl'): 22 | # Handle JSONL file 23 | with open(file_path, 'r') as file: 24 | first_line = file.readline().strip() 25 | try: 26 | json.loads(first_line) 27 | data.append(json.loads(first_line)) 28 | except json.JSONDecodeError: 29 | pass 30 | for line in file: 31 | line = line.strip() 32 | if line: 33 | data.append(json.loads(line)) 34 | return data 35 | 36 | 37 | def parse_action(data): 38 | try: 39 | jsonschema.validate(data, EXTRACT_SCHEMA) 40 | 41 | actions = {} 42 | parameters = {} 43 | status = data.get("STATUS", "continue") # Default value 44 | 45 | # Define actions 46 | action_keys = ["POINT", "to", "PRESS", "TYPE"] 47 | 48 | # Extract actions 49 | for key in action_keys: 50 | if key in data: 51 | actions[key] = data[key] 52 | 53 | # Extract global parameters 54 | parameters["duration"] = data.get("duration", EXTRACT_SCHEMA["properties"]["duration"]["default"]) 55 | 56 | # Handle "to" parameter, if present 57 | if "to" in data: 58 | parameters["to"] = data["to"] 59 | 60 | return actions, parameters, status 61 | 62 | except Exception as e: 63 | print('Error, JSON is NOT valid according to the schema.') 64 | return None, None, None 65 | 66 | 67 | # Use multiprocessing to speed up processing 68 | def process_step(args): 69 | task, episode_id, step_id, pred, base_path = args 70 | try: 71 | actions, parameters, status = parse_action(pred) 72 | 73 | transformed_entry = { 74 | "action_predict": { 75 | "COA": { 76 | "txt": { 77 | "ACTION": actions, 78 | "ARGS": parameters, 79 | "STATUS": status 80 | }, 81 | } 82 | } 83 | } 84 | 85 | folder = f"{task}-{episode_id}" 86 | file_name = f"{folder}_{step_id}.json" 87 | output_file_path = os.path.join(base_path, folder, file_name) 88 | 89 | with open(output_file_path, 'w', encoding='utf-8') as output_file: 90 | json.dump(transformed_entry, output_file, indent=4, ensure_ascii=False) 91 | 92 | return f"Saved transformed entry to: {output_file_path}" 93 | except Exception as e: 94 | return f"Error processing step {step_id} in episode {episode_id}: {e}" 95 | 96 | 97 | # # Multi-threaded version 98 | def convert2aitz(input_path, output_path, max_workers=None): 99 | data = load_json_data(input_path) 100 | base_path = os.path.join(output_path) 101 | folders = set() 102 | tasks = [] 103 | for item in data: 104 | task = item.get("category", item.get("subset", "unknown")) 105 | episode_id = item.get("episode_id", "unknown") 106 | steps = item.get("steps", [item]) 107 | 108 | for index, each_step in enumerate(steps): 109 | step_id = index if "steps" in item else each_step.get("step_id", index) 110 | folder = 
f"{task}-{episode_id}" 111 | folders.add(folder) 112 | pred = each_step.get("pred", {}) 113 | tasks.append((task, episode_id, step_id, pred, base_path)) 114 | 115 | for folder in folders: 116 | folder_path = os.path.join(base_path, folder) 117 | os.makedirs(folder_path, exist_ok=True) 118 | 119 | with ProcessPoolExecutor(max_workers=max_workers) as executor: 120 | futures = [executor.submit(process_step, task_args) for task_args in tasks] 121 | for future in tqdm(as_completed(futures), total=len(futures), desc="Processing steps"): 122 | result = future.result() 123 | print(result) 124 | 125 | 126 | # # Single-threaded version 127 | def convert2aitz_single_thread(input_path, output_path): 128 | data = load_json_data(input_path) 129 | base_path = os.path.join(output_path) 130 | 131 | for item in data: 132 | task = item.get("category", "unknown") 133 | episode_id = item.get("episode_id", "unknown") 134 | steps = item.get("steps", [item]) 135 | 136 | for index, each_step in enumerate(steps): 137 | step_id = index if "steps" in item else each_step.get("step_id", index) 138 | 139 | actions, parameters, status = parse_action(each_step["pred"]) 140 | 141 | transformed_entry = { 142 | "action_predict": { 143 | "COA": { 144 | "txt": { 145 | "ACTION": actions, 146 | "ARGS": parameters, 147 | "STATUS": status 148 | }, 149 | } 150 | } 151 | } 152 | folder = f"{task}-{episode_id}" 153 | file_name = f"{folder}_{step_id}.json" 154 | folder_path = os.path.join(base_path, folder) 155 | output_path = os.path.join(folder_path, file_name) 156 | 157 | os.makedirs(folder_path, exist_ok=True) 158 | 159 | with open(output_path, 'w', encoding='utf-8') as output_file: 160 | json.dump(transformed_entry, output_file, indent=4, ensure_ascii=False) 161 | 162 | print(f"Saved transformed entry to: {output_path}") 163 | -------------------------------------------------------------------------------- /eval/utils/schema/schema.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": "object", 3 | "description": "执行操作并决定当前任务状态", 4 | "additionalProperties": false, 5 | "properties": { 6 | "thought": { 7 | "type": "string", 8 | "description": "智能体的思维过程" 9 | }, 10 | "POINT": { 11 | "$ref": "#/$defs/Location", 12 | "description": "点击屏幕上的指定位置" 13 | }, 14 | "to": { 15 | "description": "移动,组合手势参数", 16 | "oneOf": [ 17 | { 18 | "enum": [ 19 | "up", 20 | "down", 21 | "left", 22 | "right" 23 | ], 24 | "description": "从当前点(POINT)出发,执行滑动手势操作,方向包括向上、向下、向左、向右" 25 | }, 26 | { 27 | "$ref": "#/$defs/Location", 28 | "description": "移动到某个位置" 29 | } 30 | ] 31 | }, 32 | "duration": { 33 | "type": "integer", 34 | "description": "动作执行的时间或等待时间,毫秒", 35 | "minimum": 0, 36 | "default": 200 37 | }, 38 | "PRESS": { 39 | "type": "string", 40 | "description": "触发特殊按键,HOME为回到主页按钮,BACK为返回按钮,ENTER为回车按钮", 41 | "enum": [ 42 | "HOME", 43 | "BACK", 44 | "ENTER" 45 | ] 46 | }, 47 | "TYPE": { 48 | "type": "string", 49 | "description": "输入文本" 50 | }, 51 | "STATUS": { 52 | "type": "string", 53 | "description": "当前任务的状态。特殊情况:satisfied,无需操作;impossible,任务无法完成;interrupt,任务中断;need_feedback,需要用户反馈;", 54 | "enum": [ 55 | "continue", 56 | "finish", 57 | "satisfied", 58 | "impossible", 59 | "interrupt", 60 | "need_feedback" 61 | ], 62 | "default": "continue" 63 | } 64 | }, 65 | "$defs": { 66 | "Location": { 67 | "type": "array", 68 | "description": "坐标为相对于屏幕左上角位原点的相对位置,并且按照宽高比例缩放到0~1000,数组第一个元素为横坐标x,第二个元素为纵坐标y", 69 | "items": { 70 | "type": "integer", 71 | "minimum": 0, 72 | "maximum": 1000 73 | }, 74 | "minItems": 2, 75 | 
"maxItems": 2 76 | } 77 | } 78 | } -------------------------------------------------------------------------------- /eval/utils/schema/schema_for_extraction.json: -------------------------------------------------------------------------------- 1 | { 2 | "type": "object", 3 | "description": "执行操作并决定当前任务状态", 4 | "additionalProperties": false, 5 | "properties": { 6 | "thought": { 7 | "type": "string" 8 | }, 9 | "POINT": { 10 | "description": "点击屏幕上的指定位置", 11 | "$ref": "#/$defs/Location" 12 | }, 13 | "to": { 14 | "description": "移动,组合手势参数", 15 | "oneOf": [ 16 | { 17 | "enum": [ 18 | "up", 19 | "down", 20 | "left", 21 | "right" 22 | ], 23 | "description": "结合 POINT 操作,实现向上下左右滑动" 24 | }, 25 | { 26 | "$ref": "#/$defs/Location", 27 | "description": "移动到某个位置" 28 | } 29 | ] 30 | }, 31 | "duration": { 32 | "type": "integer", 33 | "description": "动作执行的时间或等待时间,毫秒", 34 | "minimum": 0, 35 | "default": 200 36 | }, 37 | "PRESS": { 38 | "type": "string", 39 | "description": "触发特殊按键,HOME为回到主页按钮,BACK为返回按钮,ENTER为回车按钮,APPSELECT为查看已打开APP列表按钮", 40 | "enum": [ 41 | "HOME", 42 | "BACK", 43 | "ENTER", 44 | "APPSELECT" 45 | ] 46 | }, 47 | "TYPE": { 48 | "type": "string", 49 | "description": "输入文本" 50 | }, 51 | "DEEP_LINK": { 52 | "type": "null", 53 | "description": "跳转到最近打开的 APP" 54 | }, 55 | "CLEAR": { 56 | "type": "null", 57 | "description": "清空输入框的内容" 58 | }, 59 | "STATUS": { 60 | "type": "string", 61 | "description": "当前任务的状态。特殊情况:satisfied,无需操作;impossible,任务无法完成;interrupt,任务中断;need_feedback,需要用户反馈;", 62 | "enum": [ 63 | "continue", 64 | "start", 65 | "finish", 66 | "satisfied", 67 | "impossible", 68 | "interrupt", 69 | "need_feedback" 70 | ], 71 | "default": "continue" 72 | } 73 | }, 74 | "$defs": { 75 | "Location": { 76 | "type": "array", 77 | "description": "坐标为相对于屏幕左上角位原点的相对位置,并且按照宽高比例缩放到 0~1000,数组第一个元素为横坐标 x,第二个元素为纵坐标 y", 78 | "items": { 79 | "type": "integer", 80 | "minimum": 0, 81 | "maximum": 1000 82 | }, 83 | "minItems": 2, 84 | "maxItems": 2 85 | } 86 | }, 87 | "allOf": [ 88 | { 89 | "if": { 90 | "required": ["to"] 91 | }, 92 | "then": { 93 | "required": ["POINT"] 94 | } 95 | }, 96 | { 97 | "if": { 98 | "anyOf": [ 99 | { "not": { "required": ["STATUS"] } }, 100 | { "properties": { "STATUS": { "enum": ["continue", "start"] } } } 101 | ] 102 | }, 103 | "then": { 104 | "anyOf": [ 105 | { "required": ["POINT"] }, 106 | { "required": ["PRESS"] }, 107 | { "required": ["TYPE"] }, 108 | { "required": ["DEEP_LINK"] }, 109 | { "required": ["CLEAR"] }, 110 | { "required": ["duration"] } 111 | ] 112 | } 113 | }, 114 | { 115 | "oneOf": [ 116 | { 117 | "required": ["POINT"], 118 | "not": { 119 | "anyOf": [ 120 | { "required": ["PRESS"] }, 121 | { "required": ["TYPE"] }, 122 | { "required": ["DEEP_LINK"] }, 123 | { "required": ["CLEAR"] } 124 | ] 125 | } 126 | }, 127 | { 128 | "required": ["PRESS"], 129 | "not": { 130 | "anyOf": [ 131 | { "required": ["POINT"] }, 132 | { "required": ["TYPE"] }, 133 | { "required": ["DEEP_LINK"] }, 134 | { "required": ["CLEAR"] } 135 | ] 136 | } 137 | }, 138 | { 139 | "required": ["TYPE"], 140 | "not": { 141 | "anyOf": [ 142 | { "required": ["POINT"] }, 143 | { "required": ["PRESS"] }, 144 | { "required": ["DEEP_LINK"] }, 145 | { "required": ["CLEAR"] } 146 | ] 147 | } 148 | }, 149 | { 150 | "required": ["DEEP_LINK"], 151 | "not": { 152 | "anyOf": [ 153 | { "required": ["POINT"] }, 154 | { "required": ["PRESS"] }, 155 | { "required": ["TYPE"] }, 156 | { "required": ["CLEAR"] } 157 | ] 158 | } 159 | }, 160 | { 161 | "required": ["CLEAR"], 162 | "not": { 163 | "anyOf": [ 164 | { 
"required": ["POINT"] }, 165 | { "required": ["PRESS"] }, 166 | { "required": ["TYPE"] }, 167 | { "required": ["DEEP_LINK"] } 168 | ] 169 | } 170 | }, 171 | { 172 | "not": { 173 | "anyOf": [ 174 | { "required": ["POINT"] }, 175 | { "required": ["PRESS"] }, 176 | { "required": ["TYPE"] }, 177 | { "required": ["DEEP_LINK"] }, 178 | { "required": ["CLEAR"] } 179 | ] 180 | } 181 | } 182 | ] 183 | } 184 | ] 185 | } 186 | -------------------------------------------------------------------------------- /eval/utils/utils.py: -------------------------------------------------------------------------------- 1 | from colorama import init, Fore, Style 2 | import os 3 | from utils.action_type import ActionType 4 | 5 | 6 | def annotate_and_save_image(img_path, output_folder, gt_action_type, gt_action_detail, pd_action_type, pd_action_detail, type_match, exact_match, subset, episode_id, step_id, task_desc): 7 | """Save an annotated image with action details to the specified folder.""" 8 | # Load the image and get its dimensions 9 | image = Image.open(img_path) 10 | draw = ImageDraw.Draw(image) 11 | 12 | # Dynamically compute font size based on image height 13 | base_height = 1080 # Reference height, e.g., 1080p 14 | font_size = max(12, int(image.height / base_height * 20)) # Ensure minimum font size is 12 15 | 16 | current_file_path = os.path.abspath(__file__) 17 | current_dir = os.path.dirname(current_file_path) 18 | try: 19 | font = ImageFont.truetype(os.path.join(current_dir, './SimHei.ttf'), font_size) 20 | except IOError: 21 | # If the specified font file does not exist, use the default font 22 | font = ImageFont.load_default() 23 | 24 | w, h = image.width, image.height 25 | 26 | # Create annotation text 27 | annotation_text = ( 28 | f"taskDesc: {task_desc}\n" 29 | f"taskID: {subset}{episode_id}_{step_id}\n" 30 | f"GT action: {gt_action_type}\n" 31 | f"GT detail: {gt_action_detail}\n" 32 | f"PD action: {pd_action_type}\n" 33 | f"PD detail: {pd_action_detail}\n" 34 | f"type_match: {'Yes' if type_match else 'No'}\n" 35 | f"exac_match: {'Yes' if exact_match else 'No'}" 36 | ) 37 | 38 | # Calculate text size and wrap lines if necessary 39 | max_width = w - 20 # Max width for the text 40 | lines = [] 41 | for line in annotation_text.split('\n'): 42 | # Split line by words to check width 43 | words = line.split() 44 | current_line = "" 45 | for word in words: 46 | # Check width after adding a word 47 | test_line = current_line + " " + word if current_line else word 48 | if draw.textlength(test_line, font=font) > max_width: 49 | # If line is too long, start a new line 50 | lines.append(current_line) 51 | current_line = word 52 | else: 53 | current_line = test_line 54 | lines.append(current_line) # Add the final line 55 | 56 | # Draw each line on the image 57 | y_text = 10 58 | line_spacing = int(font_size * 1.2) # Line spacing is 1.2 times the font size 59 | for line in lines: 60 | draw.text((10, y_text), line, font=font, fill='red') 61 | y_text += line_spacing # Move to next line position 62 | 63 | # Draw rectangle and point based on conditions 64 | if pd_action_type == 'click' and type_match: 65 | if isinstance(gt_action_detail, (list, tuple)) and len(gt_action_detail) == 4: 66 | ymin, xmin, height, width = gt_action_detail # Parse GT action details 67 | pd_x = pd_action_detail.get("x", 0) * w 68 | pd_y = pd_action_detail.get("y", 0) * h 69 | gt_box = [xmin * w, ymin * h, (xmin + width) * w, (ymin + height) * h] 70 | draw.rectangle(gt_box, outline="red", width=max(1, int(font_size / 5))) # Adjust line 
width dynamically 71 | point_radius = max(5, int(font_size / 2)) # Adjust point radius dynamically 72 | draw.ellipse( 73 | (pd_x - point_radius, pd_y - point_radius, pd_x + point_radius, pd_y + point_radius), 74 | fill="red", 75 | outline="blue", 76 | width=max(1, int(font_size / 10)) 77 | ) 78 | 79 | # Save the annotated image to the output folder 80 | if not os.path.exists(output_folder): 81 | os.makedirs(output_folder, exist_ok=True) 82 | output_file_name = os.path.basename(img_path).replace('.png', '_annotated.png') 83 | output_path = os.path.join(output_folder, output_file_name) 84 | image.save(output_path) 85 | 86 | return output_path 87 | 88 | 89 | def get_dataset_dir(data_name): 90 | data_list = ['aitz_test', 'chinese_app_test', 'gui_odyssey_test', 'android_control_high_test', 'android_control_low_test'] 91 | assert data_name in data_list, "Error, unkonw eval dataset." 92 | data_split = None 93 | data_dir = None 94 | data_subset = None 95 | 96 | current_file_path = os.path.abspath(__file__) 97 | data_dir = os.path.dirname(os.path.dirname(current_file_path)) 98 | 99 | match data_name: 100 | case 'aitz_test': 101 | data_dir = os.path.join(data_dir, "eval_data", "aitz_test") 102 | data_split = "test" 103 | data_subset = ["general", "install", "web_shopping", "google_apps"] 104 | case 'chinese_app_test': 105 | data_dir = os.path.join(data_dir, "eval_data", "chinese_app_test") 106 | data_split = "test" 107 | data_subset = ["domestic"] 108 | case 'gui_odyssey_test': 109 | data_dir = os.path.join(data_dir, "eval_data", "odyssey") 110 | data_split = "test" 111 | data_subset = ["odyssey"] 112 | case 'android_control_high_test': 113 | data_dir = os.path.join(data_dir, "eval_data", "android_control_high_test") 114 | data_split = "test" 115 | data_subset = ["android_control"] 116 | case 'android_control_low_test': 117 | data_dir = os.path.join(data_dir, "eval_data", "android_control_low_test") 118 | data_split = "test" 119 | data_subset = ["android_control"] 120 | 121 | return data_dir, data_split, data_subset 122 | -------------------------------------------------------------------------------- /eval/utils/utils_odyssey/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "architectures": [ 3 | "QWenLMHeadModel" 4 | ], 5 | "attn_dropout_prob": 0.0, 6 | "auto_map": { 7 | "AutoConfig": "configuration_qwen.QWenConfig", 8 | "AutoModelForCausalLM": "modeling_qwen.QWenLMHeadModel" 9 | }, 10 | "bf16": true, 11 | "emb_dropout_prob": 0.0, 12 | "fp16": false, 13 | "fp32": false, 14 | "hidden_size": 4096, 15 | "his_len": 4, 16 | "initializer_range": 0.02, 17 | "intermediate_size": 22016, 18 | "kv_channels": 128, 19 | "layer_norm_epsilon": 1e-06, 20 | "max_position_embeddings": 8192, 21 | "model_type": "qwen", 22 | "no_bias": true, 23 | "num_attention_heads": 32, 24 | "num_hidden_layers": 32, 25 | "onnx_safe": null, 26 | "rotary_emb_base": 10000, 27 | "rotary_pct": 1.0, 28 | "scale_attn_weights": true, 29 | "seq_length": 2048, 30 | "tie_word_embeddings": false, 31 | "tokenizer_type": "QWenTokenizer", 32 | "torch_dtype": "bfloat16", 33 | "transformers_version": "4.50.0", 34 | "use_cache": true, 35 | "use_dynamic_ntk": true, 36 | "use_flash_attn": false, 37 | "use_logn_attn": true, 38 | "visual": { 39 | "heads": 16, 40 | "image_size": 448, 41 | "image_start_id": 151857, 42 | "layers": 48, 43 | "mlp_ratio": 4.9231, 44 | "output_dim": 4096, 45 | "patch_size": 14, 46 | "width": 1664 47 | }, 48 | "vocab_size": 151936 49 | } 50 | 
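The `auto_map` block in the config above is what lets `transformers` resolve the custom `QWenConfig` / `QWenLMHeadModel` classes shipped in this `utils_odyssey` folder when `trust_remote_code=True` is passed. A minimal sketch of loading such a checkpoint — the directory path is hypothetical and must contain these files alongside the weight shards:

```python
from transformers import AutoConfig, AutoModelForCausalLM

# Hypothetical checkpoint directory: it must contain this config.json together with
# configuration_qwen.py, modeling_qwen.py, visual.py and the weight shards listed
# in the *.index.json files.
ckpt_dir = "path/to/odyssey_checkpoint"

config = AutoConfig.from_pretrained(ckpt_dir, trust_remote_code=True)           # -> configuration_qwen.QWenConfig
model = AutoModelForCausalLM.from_pretrained(ckpt_dir, trust_remote_code=True)  # -> modeling_qwen.QWenLMHeadModel
```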
-------------------------------------------------------------------------------- /eval/utils/utils_odyssey/configuration_qwen.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Alibaba Cloud. 2 | # 3 | # This source code is licensed under the license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from transformers import PretrainedConfig 7 | 8 | 9 | class QWenConfig(PretrainedConfig): 10 | model_type = "qwen" 11 | keys_to_ignore_at_inference = ["past_key_values"] 12 | 13 | def __init__( 14 | self, 15 | vocab_size=151936, 16 | hidden_size=4096, 17 | num_hidden_layers=32, 18 | num_attention_heads=32, 19 | emb_dropout_prob=0.0, 20 | attn_dropout_prob=0.0, 21 | layer_norm_epsilon=1e-6, 22 | initializer_range=0.02, 23 | max_position_embeddings=8192, 24 | scale_attn_weights=True, 25 | use_cache=True, 26 | bf16=False, 27 | fp16=False, 28 | fp32=False, 29 | kv_channels=128, 30 | rotary_pct=1.0, 31 | rotary_emb_base=10000, 32 | use_dynamic_ntk=True, 33 | use_logn_attn=True, 34 | use_flash_attn="auto", 35 | intermediate_size=22016, 36 | no_bias=True, 37 | tie_word_embeddings=False, 38 | **kwargs, 39 | ): 40 | self.vocab_size = vocab_size 41 | self.hidden_size = hidden_size 42 | self.intermediate_size = intermediate_size 43 | self.num_hidden_layers = num_hidden_layers 44 | self.num_attention_heads = num_attention_heads 45 | self.emb_dropout_prob = emb_dropout_prob 46 | self.attn_dropout_prob = attn_dropout_prob 47 | self.layer_norm_epsilon = layer_norm_epsilon 48 | self.initializer_range = initializer_range 49 | self.scale_attn_weights = scale_attn_weights 50 | self.use_cache = use_cache 51 | self.max_position_embeddings = max_position_embeddings 52 | self.bf16 = bf16 53 | self.fp16 = fp16 54 | self.fp32 = fp32 55 | self.kv_channels = kv_channels 56 | self.rotary_pct = rotary_pct 57 | self.rotary_emb_base = rotary_emb_base 58 | self.use_dynamic_ntk = use_dynamic_ntk 59 | self.use_logn_attn = use_logn_attn 60 | self.use_flash_attn = use_flash_attn 61 | self.no_bias = no_bias 62 | super().__init__( 63 | tie_word_embeddings=tie_word_embeddings, 64 | **kwargs 65 | ) 66 | -------------------------------------------------------------------------------- /eval/utils/utils_odyssey/generation_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "_from_model_config": true, 3 | "transformers_version": "4.50.0" 4 | } 5 | -------------------------------------------------------------------------------- /eval/utils/utils_odyssey/special_tokens_map.json: -------------------------------------------------------------------------------- 1 | {} 2 | -------------------------------------------------------------------------------- /eval/utils/utils_odyssey/tokenizer_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "added_tokens_decoder": {}, 3 | "auto_map": { 4 | "AutoTokenizer": [ 5 | "Qwen/Qwen-VL-Chat--tokenization_qwen.QWenTokenizer", 6 | null 7 | ] 8 | }, 9 | "clean_up_tokenization_spaces": false, 10 | "extra_special_tokens": {}, 11 | "model_max_length": 8192, 12 | "tokenizer_class": "QWenTokenizer" 13 | } 14 | -------------------------------------------------------------------------------- /model/README.md: -------------------------------------------------------------------------------- 1 | # AgentCPM-GUI 2 | -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | torch==2.5.1 2 | absl_py==2.1.0 3 | accelerate==0.30.1 4 | acme.hello==0.1 5 | colorama==0.4.6 6 | datasets==3.1.0 7 | deepspeed 8 | demjson3==3.0.6 9 | dm_env==1.6 10 | einops==0.8.1 11 | #grpc_tools==1.0.0 12 | grpcio-tools==1.71.0 13 | gym==0.26.2 14 | icecream==2.1.4 15 | ipython==8.12.3 16 | jax==0.6.0 17 | json5==0.10.0 18 | jsonschema==4.23.0 19 | matplotlib==3.7.4 20 | numpy==1.26.4 21 | openai==1.77.0 22 | packaging==25.0 23 | peft==0.12.0 24 | Pillow==11.2.1 25 | portpicker==1.6.0 26 | pygame==2.6.1 27 | python_Levenshtein==0.27.1 28 | PyYAML==6.0.2 29 | qwen_agent==0.0.16 30 | qwen_vl_utils==0.0.11 31 | Requests==2.32.3 32 | setuptools==75.1.0 33 | tensorflow==2.19.0 34 | tiktoken==0.7.0 35 | torchvision==0.20.1 36 | tqdm==4.66.3 37 | transformers==4.50.0 38 | transformers_stream_generator==0.0.4 39 | trl==0.9.6 40 | ultralytics==8.3.129 41 | vllm==0.7.1 42 | yacs==0.1.8 43 | tf_keras==2.19.0 44 | flash_attn==2.7.4.post1 45 | -------------------------------------------------------------------------------- /rft/config_files/ds.yml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_config_file: config_files/zero.json 5 | zero3_init_flag: true 6 | distributed_type: DEEPSPEED 7 | downcast_bf16: 'no' 8 | dynamo_config: 9 | dynamo_backend: INDUCTOR 10 | enable_cpu_affinity: false 11 | machine_rank: 0 12 | main_training_function: main 13 | num_machines: 1 14 | num_processes: 8 15 | rdzv_backend: static 16 | same_network: true 17 | tpu_env: [] 18 | tpu_use_cluster: false 19 | tpu_use_sudo: false 20 | use_cpu: false 21 | -------------------------------------------------------------------------------- /rft/config_files/ds_dst.yml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_config_file: /data3/workhome/luyaxi/ARL/config_files/zero3.json 5 | deepspeed_hostfile: /data3/workhome/luyaxi/ARL/config_files/hostfile 6 | deepspeed_multinode_launcher: pdsh 7 | zero3_init_flag: true 8 | distributed_type: DEEPSPEED 9 | downcast_bf16: 'no' 10 | dynamo_config: 11 | dynamo_backend: INDUCTOR 12 | enable_cpu_affinity: false 13 | machine_rank: 1 14 | main_process_ip: 10.0.1.11 15 | main_process_port: 12346 16 | main_training_function: main 17 | num_machines: 2 18 | num_processes: 16 19 | rdzv_backend: static 20 | same_network: true 21 | tpu_env: [] 22 | tpu_use_cluster: false 23 | tpu_use_sudo: false 24 | use_cpu: false 25 | -------------------------------------------------------------------------------- /rft/config_files/fsdp.yml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: FSDP 4 | downcast_bf16: 'no' 5 | enable_cpu_affinity: false 6 | mixed_precision: 'bf16' 7 | fsdp_config: 8 | fsdp_activation_checkpointing: true 9 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP 10 | fsdp_backward_prefetch: BACKWARD_PRE 11 | fsdp_cpu_ram_efficient_loading: true 12 | fsdp_forward_prefetch: true 13 | fsdp_offload_params: false 14 | fsdp_sharding_strategy: FULL_SHARD 15 | fsdp_state_dict_type: SHARDED_STATE_DICT 16 | fsdp_sync_module_states: true 17 | fsdp_transformer_layer_cls_to_wrap: Qwen2DecoderLayer 18 | fsdp_use_orig_params: true 19 | machine_rank: 0 20 | 
main_training_function: main 21 | num_machines: 1 22 | num_processes: 8 23 | rdzv_backend: static 24 | same_network: true 25 | tpu_env: [] 26 | tpu_use_cluster: false 27 | tpu_use_sudo: false 28 | use_cpu: false 29 | -------------------------------------------------------------------------------- /rft/config_files/fsdp2_dst.yml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | distributed_type: FSDP 3 | downcast_bf16: 'no' 4 | fsdp_config: 5 | fsdp_activation_checkpointing: true 6 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP 7 | fsdp_cpu_ram_efficient_loading: true 8 | fsdp_forward_prefetch: true 9 | fsdp_offload_params: false 10 | fsdp_reshard_after_forward: true 11 | fsdp_state_dict_type: SHARDED_STATE_DICT 12 | fsdp_version: 2 13 | main_training_function: main 14 | mixed_precision: bf16 15 | -------------------------------------------------------------------------------- /rft/config_files/fsdp_dst.yml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | distributed_type: FSDP 3 | downcast_bf16: 'no' 4 | fsdp_config: 5 | fsdp_activation_checkpointing: true 6 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP 7 | fsdp_backward_prefetch: BACKWARD_PRE 8 | fsdp_cpu_ram_efficient_loading: true 9 | fsdp_forward_prefetch: false 10 | fsdp_offload_params: true 11 | fsdp_sharding_strategy: FULL_SHARD 12 | fsdp_state_dict_type: SHARDED_STATE_DICT 13 | fsdp_sync_module_states: true 14 | fsdp_use_orig_params: true 15 | main_training_function: main 16 | mixed_precision: bf16 17 | -------------------------------------------------------------------------------- /rft/config_files/hostfile: -------------------------------------------------------------------------------- 1 | g1 slots=8 2 | g2 slots=8 -------------------------------------------------------------------------------- /rft/config_files/zero.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "zero_optimization": { 14 | "stage": 1, 15 | "offload_optimizer": { 16 | "device": "cpu", 17 | "pin_memory": true 18 | }, 19 | "allgather_partitions": true, 20 | "allgather_bucket_size": 2e8, 21 | "overlap_comm": true, 22 | "reduce_scatter": true, 23 | "reduce_bucket_size": 2e8, 24 | "contiguous_gradients": true 25 | }, 26 | "gradient_accumulation_steps": "auto", 27 | "gradient_clipping": "auto", 28 | "steps_per_print": 100, 29 | "train_batch_size": "auto", 30 | "train_micro_batch_size_per_gpu": "auto", 31 | "wall_clock_breakdown": false 32 | } -------------------------------------------------------------------------------- /rft/config_files/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "zero_optimization": { 14 | "stage": 2, 15 | "offload_optimizer": { 16 | "device": "cpu", 17 | "pin_memory": true 18 | }, 19 | "allgather_partitions": true, 20 | "allgather_bucket_size": 5e8, 21 | "overlap_comm": true, 22 | "reduce_scatter": true, 23 | "reduce_bucket_size": 5e8, 24 | 
"contiguous_gradients": true 25 | }, 26 | "gradient_accumulation_steps": "auto", 27 | "gradient_clipping": "auto", 28 | "steps_per_print": 100, 29 | "train_batch_size": "auto", 30 | "train_micro_batch_size_per_gpu": "auto", 31 | "wall_clock_breakdown": false 32 | } -------------------------------------------------------------------------------- /rft/config_files/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | 14 | "zero_optimization": { 15 | "stage": 3, 16 | "offload_optimizer": { 17 | "device": "none", 18 | "pin_memory": true 19 | }, 20 | "offload_param": { 21 | "device": "none", 22 | "pin_memory": true 23 | }, 24 | "overlap_comm": true, 25 | "contiguous_gradients": true, 26 | "sub_group_size": 1e9, 27 | "reduce_bucket_size": "auto", 28 | "stage3_prefetch_bucket_size": "auto", 29 | "stage3_param_persistence_threshold": "auto", 30 | "stage3_max_live_parameters": 1e9, 31 | "stage3_max_reuse_distance": 1e9, 32 | "stage3_gather_16bit_weights_on_model_save": true 33 | }, 34 | 35 | "gradient_accumulation_steps": "auto", 36 | "gradient_clipping": "auto", 37 | "steps_per_print": 100, 38 | "train_batch_size": "auto", 39 | "train_micro_batch_size_per_gpu": "auto", 40 | "wall_clock_breakdown": false 41 | } -------------------------------------------------------------------------------- /rft/fsdp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -euo pipefail 3 | 4 | NODES=("g1" "g2") 5 | NODE_NUMS=${#NODES[@]} 6 | NUM_PROCESSES=16 7 | MASTER_ADDR=${NODES[0]} 8 | MASTER_PORT=29500 9 | RUN_NAME="test" 10 | FILE_DIR=$(cd "$(dirname "$0")" && pwd) 11 | # --------------------- 12 | 13 | PDSH_PIDS=() 14 | 15 | cleanup() { 16 | echo -e "\n>>> Cleanup: killing local pdsh and remote training processes..." 17 | if [ ${#PDSH_PIDS[@]} -gt 0 ]; then 18 | kill "${PDSH_PIDS[@]}" 2>/dev/null || true 19 | fi 20 | pdsh -R ssh -w "${NODES[*]}" "pkill -f 'accelerate launch' || true" 21 | } 22 | trap cleanup EXIT SIGINT SIGTERM 23 | 24 | echo "Training directory: $FILE_DIR" 25 | echo "Launching on nodes: ${NODES[*]}" 26 | 27 | for i in "${!NODES[@]}"; do 28 | NODE=${NODES[$i]} 29 | NODE_RANK=$i 30 | 31 | echo "-> Launching on $NODE (rank $NODE_RANK)..." 
32 | 33 | pdsh -R ssh -w "$NODE" bash -lc " 34 | source ~/miniconda3/bin/activate arl 35 | cd '$FILE_DIR' 36 | export TOKENIZERS_PARALLELISM=false 37 | export WANDB_PROJECT=VLM-RFT 38 | export MASTER_ADDR=$MASTER_ADDR 39 | export MASTER_PORT=$MASTER_PORT 40 | export NODE_RANK=$NODE_RANK 41 | export TORCHDYNAMO_CAPTURE_SCALAR_OUTPUTS=1 42 | export MONITOR_INTERVAL=30 43 | 44 | accelerate launch \ 45 | --config_file config_files/fsdp2_dst.yml \ 46 | --num_machines=$NODE_NUMS \ 47 | --num_processes=$NUM_PROCESSES \ 48 | --machine_rank=$i \ 49 | --main_process_ip=$MASTER_ADDR \ 50 | --main_process_port=$MASTER_PORT \ 51 | --same_network \ 52 | grpo.py \ 53 | --output_dir output/$RUN_NAME \ 54 | --clear_device true \ 55 | --model_name_or_path output/resume-sft \ 56 | --dataset_name /share_data/data1/GUIData/train_ody_aitwc_mb_ac_locate.jsonl \ 57 | --eval_dataset_name /share_data/data1/GUIData/valid_f_ody_aitwc_mb_ac.jsonl \ 58 | --max_prompt_length 2048 \ 59 | --max_completion_length 128 \ 60 | --max_line_res 1120 \ 61 | --hist_length 1 \ 62 | --num_generations 32 \ 63 | --num_iterations 1 \ 64 | --per_device_train_batch_size 1 \ 65 | --gradient_accumulation_steps 32 \ 66 | --dataloader_prefetch_factor 4 \ 67 | --dataloader_num_workers 4 \ 68 | --dataloader_drop_last true \ 69 | --max_grad_norm 1.0 \ 70 | --logging_steps 1 \ 71 | --learning_rate 1e-6 \ 72 | --warmup_steps 10 \ 73 | --weight_decay 0.1 \ 74 | --eval_strategy steps \ 75 | --per_device_eval_batch_size 4 \ 76 | --eval_steps 25 \ 77 | --adam_beta2 0.99 \ 78 | --lr_scheduler_type 'constant' \ 79 | --tune_vision true \ 80 | --bf16 \ 81 | --beta 0.0 \ 82 | --data_seed 41 \ 83 | --report_to wandb \ 84 | --num_train_epochs 3 \ 85 | --run_name $RUN_NAME \ 86 | --save_steps 50 \ 87 | --save_only_model true \ 88 | --attn_implementation flash_attention_2 \ 89 | --reward_funcs 'type' 'args' 90 | " & 91 | 92 | PDSH_PIDS+=($!) 
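# $! is the PID of the local backgrounded pdsh launcher, not of the remote training
# job; cleanup() kills these local PIDs and additionally pkills `accelerate launch`
# on the remote nodes.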
93 | done 94 | 95 | wait 96 | -------------------------------------------------------------------------------- /rft/grpo.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from transformers import AutoModelForCausalLM, AutoProcessor 4 | from trainer.arl import AsyncRLGRPOTrainer 5 | from configs import GRPOTrainingConfig,GRPOScriptArguments 6 | from trl import ModelConfig, TrlParser 7 | from trainer.utils import action_schema_check, action_args_check, GUIRFTDataset,action_type_check,GUIMTRFTDataset,react_check,fsdp2_prepare_model 8 | import torch.distributed as dist 9 | 10 | reward_funcs_registry = { 11 | # "accuracy": iou_reward, 12 | # "format": format_reward, 13 | "react": react_check, 14 | "schema": action_schema_check, 15 | "type":action_type_check, 16 | "args":action_args_check, 17 | } 18 | 19 | # ----------------------- Main Script ----------------------- 20 | def main(script_args, training_args, model_args): 21 | reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs] 22 | print("reward_funcs:", reward_funcs) 23 | 24 | 25 | model_init_kwargs = {} 26 | model_init_kwargs["attn_implementation"] = model_args.attn_implementation 27 | model_init_kwargs["torch_dtype"] = torch.bfloat16 28 | model_init_kwargs["trust_remote_code"] = True 29 | model_id = model_args.model_name_or_path 30 | if "minicpm" in model_id.lower(): 31 | if "minicpm-o" in model_id.lower(): 32 | model_init_kwargs["init_tts"] = False 33 | model_init_kwargs["init_audio"] = False 34 | 35 | model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs) 36 | processing_class = AutoProcessor.from_pretrained(model_id,trust_remote_code=True) 37 | # processing_class.pad_token_id = processing_class.tokenizer.pad_token_id 38 | # if processing_class.pad_token_id is None: 39 | # processing_class.tokenizer.pad_token_id = 2 40 | # processing_class.pad_token_id = 2 41 | # processing_class.padding_side = "left" 42 | # processing_class.tokenizer.padding_side = "left" 43 | device_mesh = None 44 | if training_args.tensor_parallel_size is not None: 45 | tp_size = int(training_args.tensor_parallel_size) 46 | world_size = dist.get_world_size() 47 | if world_size % tp_size != 0: 48 | raise ValueError( 49 | f"world_size {world_size} must be divisible by tensor_parallel_size {tp_size}" 50 | ) 51 | dp_size = world_size // tp_size 52 | 53 | device_mesh = dist.device_mesh.DeviceMesh( 54 | "cuda", 55 | mesh=torch.arange(world_size).reshape((dp_size, tp_size)), 56 | mesh_dim_names=("dp", "tp"), 57 | ) 58 | 59 | tp_mesh = device_mesh["tp"] 60 | dp_mesh = device_mesh["dp"] 61 | 62 | from torch.distributed.tensor.parallel import ColwiseParallel,RowwiseParallel,parallelize_module,SequenceParallel,PrepareModuleInput,PrepareModuleOutput 63 | from torch.distributed.tensor import Replicate,Shard 64 | 65 | layer_tp_plan = { 66 | "llm.model.layers.*.mlp": PrepareModuleInput( 67 | input_layouts=(Shard(1),), 68 | desired_input_layouts=(Replicate(),) 69 | ), 70 | "llm.model.layers.*.mlp.up_proj": ColwiseParallel(), 71 | "llm.model.layers.*.mlp.gate_proj": ColwiseParallel(), 72 | "llm.model.layers.*.mlp.down_proj": RowwiseParallel(output_layouts=Shard(1)), 73 | "llm.model.layers.*.self_attn": PrepareModuleInput( 74 | input_kwarg_layouts={ 75 | "hidden_states": Shard(1), 76 | "attention_mask": None, 77 | "position_ids": None, 78 | "past_key_value": None, 79 | "output_attentions": None, 80 | "use_cache": None, 81 | "cache_position": None, 82 | 
"position_embeddings": None, 83 | }, 84 | desired_input_kwarg_layouts={ 85 | "hidden_states": Replicate(), 86 | "attention_mask": None, 87 | "position_ids": None, 88 | "past_key_value": None, 89 | "output_attentions": None, 90 | "use_cache": None, 91 | "cache_position": None, 92 | "position_embeddings": None, 93 | }, 94 | ), 95 | "llm.model.layers.*.self_attn.k_proj": ColwiseParallel(), 96 | "llm.model.layers.*.self_attn.q_proj": ColwiseParallel(), 97 | "llm.model.layers.*.self_attn.v_proj": ColwiseParallel(), 98 | "llm.model.layers.*.self_attn.o_proj": RowwiseParallel(output_layouts=Shard(1)), 99 | "llm.model.layers.*.input_layernorm": SequenceParallel(), 100 | "llm.model.layers.*.post_attention_layernorm": SequenceParallel(), 101 | "llm.model.layers.*.input_layernorm": SequenceParallel(), 102 | "llm.model.norm": SequenceParallel(), 103 | "llm.model.layers.0": PrepareModuleInput( 104 | input_layouts=(Replicate(),), 105 | desired_input_layouts=(Shard(1),) 106 | ) 107 | # "llm.model.embed_tokens": 108 | # RowwiseParallel( 109 | # input_layouts=Replicate(), 110 | # output_layouts=Shard(1), 111 | # ), 112 | # "llm.lm_head": ColwiseParallel( 113 | # input_layouts=Shard(1), 114 | # output_layouts=Replicate(), 115 | # # use_local_output=False, # TODO: Is this ture for grpo ? 116 | # ) 117 | } 118 | 119 | layer_tp_plan = { 120 | "llm.model.layers.*.mlp.up_proj": ColwiseParallel(), 121 | "llm.model.layers.*.mlp.gate_proj": ColwiseParallel(), 122 | "llm.model.layers.*.mlp.down_proj": RowwiseParallel(), 123 | "llm.model.layers.*.self_attn.k_proj": ColwiseParallel(), 124 | "llm.model.layers.*.self_attn.q_proj": ColwiseParallel(), 125 | "llm.model.layers.*.self_attn.v_proj": ColwiseParallel(), 126 | "llm.model.layers.*.self_attn.o_proj": RowwiseParallel(), 127 | } 128 | 129 | model = parallelize_module( 130 | model, 131 | tp_mesh, 132 | layer_tp_plan 133 | ) 134 | 135 | # for layer_id, transformer_block in enumerate(model.vpm.encoder.layers): 136 | # layer_tp_plan = { 137 | # "mlp.fc1": ColwiseParallel(), 138 | # "mlp.fc2": RowwiseParallel(), 139 | # "self_attn.k_proj": ColwiseParallel(), 140 | # "self_attn.q_proj": ColwiseParallel(), 141 | # "self_attn.v_proj": ColwiseParallel(), 142 | # "self_attn.out_proj": RowwiseParallel() 143 | # } 144 | # # adjuest attention module to use the local number of heads 145 | # attn_layer = transformer_block.self_attn 146 | # attn_layer.num_heads = attn_layer.num_heads // tp_mesh.size() 147 | # attn_layer.embed_dim = attn_layer.embed_dim // tp_mesh.size() 148 | 149 | # model.vpm.encoder.layers[layer_id] = parallelize_module( 150 | # module=transformer_block, 151 | # device_mesh=tp_mesh, 152 | # parallelize_plan=layer_tp_plan 153 | # ) 154 | 155 | model = fsdp2_prepare_model( 156 | model, 157 | mesh=dp_mesh 158 | ) 159 | 160 | global_task_dispatch_addr = f"tcp://{os.environ.get('MASTER_ADDR')}:{training_args.global_data_dispatch_port}" 161 | 162 | dataset_cls = GUIMTRFTDataset if training_args.hist_length > 1 else GUIRFTDataset 163 | 164 | datasets = dataset_cls( 165 | global_task_dispatch_addr=global_task_dispatch_addr, 166 | hist_length=training_args.hist_length, 167 | jsonl_file_path=script_args.dataset_name, 168 | max_line_res=script_args.max_line_res,) 169 | 170 | if script_args.eval_dataset_name is not None: 171 | eval_set = dataset_cls( 172 | global_task_dispatch_addr=global_task_dispatch_addr, 173 | hist_length=training_args.hist_length, 174 | jsonl_file_path=script_args.eval_dataset_name, 175 | max_line_res=script_args.max_line_res,) 176 | else: 177 | 
eval_set = None 178 | 179 | # Initialize the GRPO trainer 180 | trainer = AsyncRLGRPOTrainer( 181 | model=model, 182 | model_init_kwargs=model_init_kwargs, 183 | processing_class=processing_class, 184 | reward_funcs=reward_funcs, 185 | args=training_args, 186 | train_dataset=datasets, 187 | eval_dataset=eval_set, 188 | device_mesh=device_mesh 189 | ) 190 | 191 | # Train and push the model to the Hub 192 | trainer.train() 193 | 194 | # Save and push to hub 195 | trainer.save_model(training_args.output_dir) 196 | if training_args.push_to_hub: 197 | trainer.push_to_hub(dataset_name=script_args.dataset_name) 198 | 199 | 200 | if __name__ == "__main__": 201 | parser = TrlParser((GRPOScriptArguments, GRPOTrainingConfig, ModelConfig)) 202 | script_args, training_args, model_args = parser.parse_args_and_config() 203 | main(script_args, training_args, model_args) 204 | -------------------------------------------------------------------------------- /rft/readme.md: -------------------------------------------------------------------------------- 1 | # ARL 2 | 3 | ARL (Another asynchronous Reinforcement Learning framework) allows us to train a vision-language model with minimal modifications to the Hugging Face transformers Trainer. 4 | 5 | Currently supported features: 6 | 7 | - Load balancing for task dispatch and (multi-turn) completion gathering across nodes. 8 | - Asynchronous rollout before model updates. 9 | - FSDPv2 support. 10 | 11 | ## Installation 12 | 13 | Run the following commands: 14 | 15 | ```bash 16 | conda create -n arl python=3.11 17 | conda activate arl 18 | pip install -r requirements.txt 19 | ``` 20 | 21 | Note: you need PyTorch >= 2.6 and the latest transformers to run with FSDPv2. 22 | 23 | 24 | ## How To Use 25 | 26 | ### 1. Modify the training script 27 | 28 | The example script is `fsdp.sh`; change the following arguments before running it: 29 | 30 | - Set `RUN_NAME` to any name you like. 31 | - `source ~/miniconda3/bin/activate arl` should be adjusted to match your installation. 32 | - `model_name_or_path` should be the path to the model you want to train. 33 | - `dataset_name` and `eval_dataset_name` should be the paths of the processed datasets. 34 | - `NODES` and `NUM_PROCESSES` should be set according to your cluster setup. 35 | 36 | Make sure `pdsh` is installed before starting training. 37 | 38 | 39 | ### (Optional) 2. Modify the Loading and Forwarding Behavior 40 | 41 | You can modify the model loading behavior in `grpo.py`. 42 | 43 | Some models take different keys in their forward pass; you may need to modify the `_get_per_token_logps` method of `AsyncRLGRPOTrainer` in `trainer.arl` to support your model. 44 | 45 | ### 3. Run the script 46 | 47 | ```bash 48 | bash fsdp.sh 49 | ``` 50 | You can monitor the run on wandb; checkpoints are saved under the `output` folder.
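### (Optional) 4. Add a custom reward function

The rewards used during training are looked up by name in `reward_funcs_registry` in `grpo.py` (the example script passes `--reward_funcs 'type' 'args'`). Below is a minimal sketch of registering an extra reward; the `(completions, **kwargs)` signature and the string handling are assumptions — mirror the built-in checks in `trainer/utils/gui_eval.py` (e.g. `action_schema_check`) when writing a real one.

```python
import json

# Hypothetical reward (sketch only): 1.0 if a completion parses as JSON, else 0.0.
# The (completions, **kwargs) interface is an assumption; follow the built-in
# checks in trainer/utils/gui_eval.py for the exact contract.
def json_parsable_reward(completions, **kwargs):
    rewards = []
    for completion in completions:
        text = completion if isinstance(completion, str) else str(completion)
        try:
            json.loads(text)
            rewards.append(1.0)
        except json.JSONDecodeError:
            rewards.append(0.0)
    return rewards

# In grpo.py, add it to the registry, then select it with --reward_funcs 'json':
reward_funcs_registry["json"] = json_parsable_reward
```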
-------------------------------------------------------------------------------- /rft/requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==1.6.0 2 | addict==2.4.0 3 | aiofiles==23.2.1 4 | aiohappyeyeballs==2.4.4 5 | aiohttp==3.11.10 6 | aiosignal==1.3.2 7 | airportsdata==20250224 8 | annotated-types==0.7.0 9 | antlr4-python3-runtime==4.13.2 10 | anyio==4.7.0 11 | astor==0.8.1 12 | attrs==24.3.0 13 | audioread==3.0.1 14 | autocommand==2.2.2 15 | av==14.0.1 16 | backports.tarfile==1.2.0 17 | bitsandbytes==0.45.3 18 | black==25.1.0 19 | blake3==1.0.4 20 | cachetools==5.5.2 21 | certifi==2024.12.14 22 | cffi==1.17.1 23 | charset-normalizer==3.4.0 24 | click==8.1.7 25 | cloudpickle==3.1.0 26 | cmake==3.31.2 27 | colorama==0.4.6 28 | compressed-tensors==0.9.2 29 | contourpy==1.3.1 30 | cupy-cuda12x==13.4.1 31 | cycler==0.12.1 32 | Cython==3.0.11 33 | datasets==3.3.2 34 | decorator==5.2.1 35 | decord==0.6.0 36 | deepspeed==0.15.3 37 | depyf==0.18.0 38 | dill==0.3.8 39 | diskcache==5.6.3 40 | distro==1.9.0 41 | dnspython==2.7.0 42 | docker-pycreds==0.4.0 43 | docstring_parser==0.16 44 | editdistance==0.6.2 45 | einops==0.8.1 46 | einx==0.3.0 47 | email_validator==2.2.0 48 | encodec==0.1.1 49 | et_xmlfile==2.0.0 50 | fairscale==0.4.0 51 | fastapi==0.115.6 52 | fastapi-cli==0.0.7 53 | fastrlock==0.8.3 54 | ffmpy==0.4.0 55 | filelock==3.16.1 56 | fire==0.7.0 57 | flake8==7.1.2 58 | flash_attn==2.7.4.post1 59 | fonttools==4.55.3 60 | frozendict==2.4.6 61 | frozenlist==1.5.0 62 | fsspec==2024.9.0 63 | gguf==0.10.0 64 | gitdb==4.0.12 65 | GitPython==3.1.44 66 | gradio==4.41.0 67 | gradio_client==1.3.0 68 | h11==0.14.0 69 | hf_transfer==0.1.9 70 | hjson==3.1.0 71 | httpcore==1.0.7 72 | httptools==0.6.4 73 | httpx==0.28.1 74 | huggingface-hub==0.27.0 75 | idna==3.10 76 | imagesize==1.4.1 77 | importlib_metadata==8.0.0 78 | importlib_resources==6.4.5 79 | inflect==7.3.1 80 | iniconfig==2.0.0 81 | inquirerpy==0.3.4 82 | interegular==0.3.3 83 | isort==6.0.1 84 | jaraco.collections==5.1.0 85 | jaraco.context==5.3.0 86 | jaraco.functools==4.0.1 87 | jaraco.text==3.12.1 88 | jax==0.4.37 89 | jaxlib==0.4.36 90 | jieba==0.42.1 91 | Jinja2==3.1.4 92 | jiter==0.8.2 93 | joblib==1.4.2 94 | json5==0.10.0 95 | jsonlines==4.0.0 96 | jsonschema==4.23.0 97 | jsonschema-specifications==2024.10.1 98 | kiwisolver==1.4.7 99 | lark==1.2.2 100 | latex2sympy2_extended==1.10.1 101 | lazy_loader==0.4 102 | Levenshtein==0.26.1 103 | librosa==0.10.2.post1 104 | liger_kernel==0.5.2 105 | llguidance==0.7.11 106 | llvmlite==0.43.0 107 | lm-format-enforcer==0.10.11 108 | lxml==5.3.0 109 | markdown-it-py==3.0.0 110 | markdown2==2.4.10 111 | MarkupSafe==2.1.5 112 | math-verify==0.7.0 113 | matplotlib==3.7.4 114 | mccabe==0.7.0 115 | mdurl==0.1.2 116 | mistral_common==1.5.4 117 | ml_dtypes==0.5.0 118 | modelbest-sdk==0.3.1 119 | more-itertools==10.1.0 120 | mpmath==1.3.0 121 | msgpack==1.1.0 122 | msgspec==0.18.6 123 | multidict==6.1.0 124 | multiprocess==0.70.16 125 | mypy-extensions==1.0.0 126 | nest-asyncio==1.6.0 127 | networkx==3.4.2 128 | ninja==1.11.1.3 129 | nltk==3.8.1 130 | numba==0.60.0 131 | numpy==1.26.4 132 | nvidia-cublas-cu12==12.4.5.8 133 | nvidia-cuda-cupti-cu12==12.4.127 134 | nvidia-cuda-nvrtc-cu12==12.4.127 135 | nvidia-cuda-runtime-cu12==12.4.127 136 | nvidia-cudnn-cu12==9.1.0.70 137 | nvidia-cufft-cu12==11.2.1.3 138 | nvidia-curand-cu12==10.3.5.147 139 | nvidia-cusolver-cu12==11.6.1.9 140 | nvidia-cusparse-cu12==12.3.1.170 141 | 
nvidia-cusparselt-cu12==0.6.2 142 | nvidia-ml-py==12.560.30 143 | nvidia-nccl-cu12==2.21.5 144 | nvidia-nvjitlink-cu12==12.4.127 145 | nvidia-nvtx-cu12==12.4.127 146 | openai==1.57.4 147 | opencv-python==4.10.0.84 148 | opencv-python-headless==4.5.5.64 149 | openpyxl==3.1.2 150 | opt_einsum==3.4.0 151 | orjson==3.10.12 152 | outlines==0.1.11 153 | outlines_core==0.1.26 154 | packaging==24.2 155 | pandas==2.2.3 156 | parameterized==0.9.0 157 | partial-json-parser==0.2.1.1.post4 158 | pathspec==0.12.1 159 | peft==0.12.0 160 | pfzy==0.3.4 161 | pillow==10.4.0 162 | platformdirs==4.3.6 163 | pluggy==1.5.0 164 | ply==3.11 165 | pooch==1.8.2 166 | portalocker==3.0.0 167 | prometheus-fastapi-instrumentator==7.0.0 168 | prometheus_client==0.21.1 169 | prompt_toolkit==3.0.50 170 | propcache==0.2.1 171 | protobuf==4.25.0 172 | psutil==7.0.0 173 | py-cpuinfo==9.0.0 174 | pyairports==2.1.1 175 | pyarrow==18.1.0 176 | pycodestyle==2.12.1 177 | pycountry==24.6.1 178 | pycparser==2.22 179 | pydantic==2.9.2 180 | pydantic_core==2.23.4 181 | pydub==0.25.1 182 | pyflakes==3.2.0 183 | Pygments==2.19.1 184 | pynvml==11.5.0 185 | pyparsing==3.2.0 186 | pytest==8.3.5 187 | python-dateutil==2.9.0.post0 188 | python-dotenv==1.0.1 189 | python-json-logger==3.3.0 190 | python-multipart==0.0.19 191 | pytz==2024.2 192 | PyYAML==6.0.2 193 | pyzmq==26.3.0 194 | qwen-vl-utils==0.0.10 195 | RapidFuzz==3.10.1 196 | ray==2.44.1 197 | referencing==0.35.1 198 | regex==2024.11.6 199 | requests==2.32.3 200 | rich==13.9.4 201 | rich-toolkit==0.14.1 202 | rouge-chinese==1.0.3 203 | rpds-py==0.22.3 204 | ruff==0.8.3 205 | sacrebleu==2.3.2 206 | safetensors==0.4.5 207 | scikit-learn==1.6.0 208 | scipy==1.14.1 209 | seaborn==0.13.0 210 | semantic-version==2.10.0 211 | sentencepiece==0.2.0 212 | sentry-sdk==2.22.0 213 | setproctitle==1.3.5 214 | shellingham==1.5.4 215 | shortuuid==1.0.11 216 | shtab==1.7.1 217 | six==1.17.0 218 | smmap==5.0.2 219 | sniffio==1.3.1 220 | socksio==1.0.0 221 | soundfile==0.12.1 222 | soxr==0.5.0.post1 223 | sse-starlette==2.1.3 224 | starlette==0.41.3 225 | sympy==1.13.1 226 | tabulate==0.9.0 227 | tenacity==9.0.0 228 | tensorboardX==2.6.2.2 229 | termcolor==2.5.0 230 | threadpoolctl==3.5.0 231 | thriftpy2==0.5.2 232 | tiktoken==0.7.0 233 | timm==0.9.10 234 | tokenizers==0.21.0 235 | tomli==2.0.1 236 | tomlkit==0.12.0 237 | torch==2.6.0 238 | torchaudio==2.6.0 239 | torchvision==0.21.0 240 | tqdm==4.67.1 241 | transformers 242 | triton==3.2.0 243 | trl 244 | typeguard==4.4.2 245 | typer==0.15.1 246 | typing_extensions==4.12.2 247 | tyro==0.8.14 248 | tzdata==2024.2 249 | urllib3==2.2.3 250 | uvicorn==0.24.0.post1 251 | uvloop==0.21.0 252 | vector-quantize-pytorch==1.21.9 253 | vllm==0.8.2 254 | vocos==0.1.0 255 | wandb==0.18.3 256 | watchfiles==1.0.3 257 | wcwidth==0.2.13 258 | websockets==12.0 259 | xformers==0.0.29.post2 260 | xgrammar==0.1.16 261 | xxhash==3.5.0 262 | yacs==0.1.8 263 | yarl==1.18.3 264 | zipp==3.21.0 265 | -------------------------------------------------------------------------------- /rft/trainer/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenBMB/AgentCPM-GUI/b3dbe5c68643858351fbf10e0ce2a8922e83bf8e/rft/trainer/__init__.py -------------------------------------------------------------------------------- /rft/trainer/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .gui_eval import action_schema_check, action_args_check, 
action_type_check,react_check 2 | from .process import _prepare_messages,_process_inputs,_create_inputs 3 | from .dataloader import GlobalDistributed0MQDataLoader 4 | from .dataset import GUIRFTDataset,GUIMTRFTDataset 5 | from .dataloader import GlobalDistributed0MQDataLoader 6 | 7 | __all__ = [ 8 | "GUIRFTDataset","GUIMTRFTDataset", 9 | "action_schema_check","action_args_check","action_type_check","react_check", 10 | "_prepare_messages","_process_inputs","_create_inputs", 11 | "GlobalDistributed0MQDataLoader", 12 | "no_sync","Timer","logger" 13 | ] 14 | 15 | import os 16 | import time 17 | import torch 18 | import logging 19 | from contextlib import contextmanager,nullcontext 20 | from accelerate import Accelerator 21 | 22 | import warnings 23 | import functools 24 | from accelerate.utils.fsdp_utils import is_compiled_module, get_module_children_bottom_up,fsdp2_prepare_auto_wrap_policy 25 | from accelerate.utils import FullyShardedDataParallelPlugin 26 | import torch.distributed as dist 27 | 28 | @contextmanager 29 | def Timer(name: str): 30 | start = time.perf_counter() 31 | yield 32 | end = time.perf_counter() 33 | # torch.cuda.synchronize() 34 | logger.info(f"[TIMER] {name}: {(end - start)*1000:.2f} ms") 35 | 36 | logger = logging.getLogger("ARL") 37 | logger.setLevel(logging.INFO) 38 | stream_handler = logging.StreamHandler() 39 | stream_handler.setLevel(logging.INFO) 40 | formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') 41 | stream_handler.setFormatter(formatter) 42 | logger.addHandler(stream_handler) 43 | 44 | 45 | @contextmanager 46 | def no_sync(self:Accelerator, model): 47 | '''For FSPD2, disable gradient synchronization for all model parameters.''' 48 | context = nullcontext 49 | if self.use_distributed: 50 | context = getattr(model, "no_sync", context) 51 | if self.is_fsdp2 and os.environ.get("ENABLE_FSDP2_NOSYNC","False") == "True": 52 | model.set_requires_gradient_sync(False) 53 | yield 54 | model.set_requires_gradient_sync(True) 55 | return 56 | 57 | with context(): 58 | yield 59 | 60 | 61 | 62 | 63 | def fsdp2_prepare_model(model: torch.nn.Module,mesh:dist.device_mesh.DeviceMesh) -> torch.nn.Module: 64 | """Prepares the model for FSDP2 in-place. Also returns the model to avoid misuse of the original model. 
65 | 66 | Args: 67 | accelerator (`Accelerator`): The accelerator instance 68 | model (`torch.nn.Module`): The model to prepare 69 | 70 | Returns: 71 | `torch.nn.Module`: Prepared model 72 | """ 73 | from torch.distributed.fsdp import FSDPModule, MixedPrecisionPolicy, fully_shard 74 | 75 | is_type_fsdp = isinstance(model, FSDPModule) or ( 76 | is_compiled_module(model) and isinstance(model._orig_mod, FSDPModule) 77 | ) 78 | if is_type_fsdp: 79 | return model 80 | 81 | fsdp2_plugin = FullyShardedDataParallelPlugin() 82 | 83 | original_sd = model.state_dict() 84 | 85 | from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy 86 | 87 | # We need the `auto_wrap_policy` original type to create a custom poilicy function for sharding 88 | # This is because `fully_shard` doesn't support old auto wrap policies, rather we have to imitate the behaviour 89 | auto_wrap_policy_type = None 90 | if fsdp2_plugin.auto_wrap_policy is transformer_auto_wrap_policy: 91 | auto_wrap_policy_type = "transformer" 92 | elif fsdp2_plugin.auto_wrap_policy is size_based_auto_wrap_policy: 93 | auto_wrap_policy_type = "size" 94 | 95 | # We set `auto_wrap_policy` to `functools.partial` to avoid creating it again 96 | # This is because of `apply_activation_checkpointing` which will can reuse this function 97 | fsdp2_plugin.set_auto_wrap_policy(model) 98 | 99 | if fsdp2_plugin.activation_checkpointing: 100 | from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( 101 | CheckpointImpl, 102 | apply_activation_checkpointing, 103 | checkpoint_wrapper, 104 | ) 105 | 106 | # Apply activation checkpointing before applying `fully_shard` 107 | apply_activation_checkpointing( 108 | model, 109 | checkpoint_wrapper_fn=functools.partial( 110 | checkpoint_wrapper, 111 | checkpoint_impl=CheckpointImpl.NO_REENTRANT, 112 | ), 113 | auto_wrap_policy=fsdp2_plugin.auto_wrap_policy, 114 | ) 115 | fsdp2_kwargs = { 116 | "reshard_after_forward": fsdp2_plugin.reshard_after_forward, 117 | "offload_policy": fsdp2_plugin.cpu_offload, 118 | "mesh": mesh, 119 | # `fully_shard` doesn't accept `None` in case of `MixedPrecisionPolicy` 120 | "mp_policy": fsdp2_plugin.mixed_precision_policy or MixedPrecisionPolicy(), 121 | } 122 | 123 | auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, auto_wrap_policy_type, model) 124 | if auto_wrap_policy is not None: 125 | # We skip the model itself, as that one is always wrapped 126 | for module in get_module_children_bottom_up(model)[:-1]: 127 | if auto_wrap_policy(module): 128 | fully_shard(module, **fsdp2_kwargs) 129 | 130 | fully_shard(model, **fsdp2_kwargs) 131 | 132 | if fsdp2_plugin.cpu_ram_efficient_loading: 133 | # If `cpu_ram_efficient_loading` is enabled, only rank 0 loads the weights 134 | # Other ranks have an empty model on `meta` device, so we need to distribute the weights properly 135 | # fsdp2_load_full_state_dict(model, original_sd) 136 | assert False, "Currently not support `cpu_ram_efficient_loading` with Tensor Parallel." 
137 | 138 | if model.dtype != torch.float32: 139 | # We upcast the model according to `deepspeed`'s implementation 140 | # More info about this can be found in `accelerator.py:prepare_model`s FSDP1 section 141 | model = model.to(torch.float32) 142 | # if accelerator.is_main_process: 143 | # # TODO(siro1): Add a warning for each parameter that was upcasted 144 | # warnings.warn( 145 | # "FSDP upcast of low precision parameters to fp32 (since mixed_precision != 'no') may affect the precision of model checkpoints." 146 | # ) 147 | return model -------------------------------------------------------------------------------- /rft/trainer/utils/dataloader.py: -------------------------------------------------------------------------------- 1 | import zmq 2 | import threading 3 | import multiprocessing 4 | import queue 5 | from typing import Iterator, Any, Callable, Optional, List 6 | import pickle 7 | from torch.utils.data import Sampler 8 | from collections import defaultdict 9 | import torch.distributed as dist 10 | 11 | class GlobalDistributed0MQDataLoader: 12 | def __init__( 13 | self, 14 | dataset: Any, 15 | global_sync_address: str, 16 | batch_size: int, 17 | collate_fn: Callable, 18 | num_workers: int, 19 | sampler: Sampler, 20 | worker_init_fn: Callable, 21 | prefetch_factor: int, 22 | world_size:int, 23 | **kwargs: Any 24 | ): 25 | self.dataset = dataset 26 | self.global_sync_address = global_sync_address 27 | self.batch_size = batch_size 28 | self.collate_fn = collate_fn 29 | self.num_workers = num_workers 30 | self.sampler = sampler 31 | self.worker_init_fn = worker_init_fn 32 | self.prefetch_factor = prefetch_factor 33 | self.kwargs = kwargs 34 | self._init_kwargs = kwargs 35 | self.world_size = world_size 36 | self.rank = dist.get_rank() 37 | 38 | self.index_queue = multiprocessing.Queue(self.num_workers) 39 | self.result_queue = multiprocessing.Queue(self.prefetch_factor) 40 | 41 | self.workers = [ multiprocessing.Process( 42 | target=GlobalDistributed0MQDataLoader._load_data, 43 | args=( 44 | self.dataset, 45 | self.index_queue, 46 | self.result_queue, 47 | self.collate_fn, 48 | self.worker_init_fn 49 | ), 50 | daemon=True 51 | ) for _ in range(self.num_workers) ] 52 | 53 | for p in self.workers: 54 | p.start() 55 | 56 | if self.rank == 0: 57 | self.master_proc = multiprocessing.Process( 58 | target=GlobalDistributed0MQDataLoader._master_loop, 59 | args=(self.global_sync_address, self.batch_size,self.sampler), 60 | daemon=True 61 | ) 62 | self.master_proc.start() 63 | 64 | def __len__(self): 65 | return len(self.sampler) // self.world_size 66 | 67 | @staticmethod 68 | def _master_loop( 69 | global_sync_address: str, 70 | batch_size: int, 71 | sampler: Sampler, 72 | ): 73 | '''Master loop to dispatch tasks to workers''' 74 | zctx = zmq.Context() 75 | task_dispatcher = zctx.socket(zmq.REP) 76 | task_dispatcher.bind(global_sync_address) 77 | 78 | it = None 79 | 80 | multiturn_cache = queue.PriorityQueue() 81 | cached_completions = defaultdict(list) 82 | 83 | while True: 84 | req: str | list[dict] | dict = task_dispatcher.recv_pyobj() 85 | if isinstance(req,str): 86 | if req == "REQ_TASK": 87 | tasks = [] 88 | for _ in range(min(multiturn_cache.qsize(),batch_size)): 89 | tasks.append(multiturn_cache.get()[1]) 90 | 91 | while len(tasks) < batch_size: 92 | try: 93 | tasks.append(next(it)) 94 | except StopIteration: 95 | it = iter(sampler) 96 | print("Restart Sampler During Epoch") 97 | 98 | task_dispatcher.send(pickle.dumps(tasks)) 99 | 100 | elif req == "RESTART": 101 | it = 
iter(sampler) 102 | task_dispatcher.send_string("RESTARTED") 103 | 104 | else: 105 | raise NotImplementedError(f"Receive Unknown Request {req}") 106 | elif isinstance(req,dict): 107 | if "get" in req: 108 | index = req["get"] 109 | if len(cached_completions[index]) > 0: 110 | d = cached_completions[index].pop(0) 111 | if not req["pop"]: 112 | cached_completions[index].append(d) 113 | # print(f"Cache Completions: {d}") 114 | else: 115 | d = "" 116 | print(f"Error: Illegal access of cache completions at index {index}.") 117 | task_dispatcher.send_string(d) 118 | else: 119 | raise NotImplementedError(f"Receive Unknown Request Type {type(req)}") 120 | 121 | elif isinstance(req,list): 122 | for d in req: 123 | multiturn_cache.put((d["gid"],d["next_id"])) 124 | cached_completions[d["id"]].append(d["completion"]) 125 | task_dispatcher.send_string("Received") 126 | 127 | # counts = {} 128 | # for k,v in cached_completions.items(): 129 | # if len(v) > 0: 130 | # counts[k] = len(v) 131 | # print(f"Cache Completions: {counts}") 132 | 133 | else: 134 | raise NotImplementedError(f"Receive Unknown Request Type {type(req)}") 135 | 136 | @staticmethod 137 | def _load_data( 138 | dataset: Any, 139 | index_queue: multiprocessing.Queue, 140 | result_queue: multiprocessing.Queue, 141 | collate_fn: Callable, 142 | worker_init_fn: Callable, 143 | ): 144 | worker_init_fn(None) 145 | 146 | while True: 147 | indices = index_queue.get() 148 | 149 | data = [dataset[index] for index in indices] 150 | 151 | result_queue.put(collate_fn(data)) 152 | 153 | def __iter__(self): 154 | '''Get tasks from the master and load the data into the queue''' 155 | zctx = zmq.Context() 156 | task_receiver = zctx.socket(zmq.REQ) 157 | task_receiver.connect(self.global_sync_address) 158 | 159 | 160 | if self.rank == 0: 161 | task_receiver.send_pyobj("RESTART") 162 | task_receiver.recv() 163 | dist.barrier() 164 | 165 | def get_task(): 166 | while True: 167 | task_receiver.send_pyobj("REQ_TASK") 168 | tasks = pickle.loads(task_receiver.recv(copy=False)) 169 | if tasks is None: 170 | self.result_queue.put(None) 171 | break 172 | self.index_queue.put(tasks) 173 | task_get_thd = threading.Thread(target=get_task,daemon=True) 174 | task_get_thd.start() 175 | 176 | 177 | while True: 178 | data = self.result_queue.get() 179 | if data is None: 180 | break 181 | yield data 182 | -------------------------------------------------------------------------------- /rft/trainer/utils/process.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import copy 3 | from PIL import Image 4 | 5 | def _prepare_messages( 6 | prompts, 7 | processing_class, 8 | max_prompt_length 9 | ): 10 | prompts_lists = [] 11 | input_images_lists = [] 12 | 13 | for msgs in prompts: 14 | copy_msgs = copy.deepcopy(msgs) 15 | 16 | images = [] 17 | for i, msg in enumerate(copy_msgs): 18 | role, content = msg["role"], msg["content"] 19 | 20 | if isinstance(content,str): 21 | content = [content] 22 | cur_msgs = [] 23 | for c in content: 24 | if isinstance(c, Image.Image): 25 | images.append(c) 26 | cur_msgs.append("(./)") 27 | elif isinstance(c, str): 28 | cur_msgs.append(c) 29 | msg['content'] = "\n".join(cur_msgs) 30 | 31 | prompts_lists.append( 32 | processing_class.tokenizer.apply_chat_template(copy_msgs, tokenize=False, add_generation_prompt=True) 33 | ) 34 | input_images_lists.append(images) 35 | 36 | ret = processing_class( 37 | prompts_lists, 38 | input_images_lists, 39 | return_tensors="pt", 40 | 
max_length=max_prompt_length 41 | ) 42 | 43 | 44 | return { 45 | **ret 46 | } 47 | 48 | def _create_inputs( 49 | processing_class, 50 | prompt_inputs, 51 | completions, 52 | ): 53 | # now handle completion_ids and completion_mask 54 | pad_token_id = getattr(processing_class,"pad_token_id", getattr(processing_class.tokenizer,"pad_token_id",None)) 55 | if pad_token_id is None: 56 | pad_token_id = 0 57 | completion_ids = torch.full((len(prompt_inputs["input_ids"]),max(map(len,completions))), pad_token_id , dtype=prompt_inputs["input_ids"].dtype,device=prompt_inputs["input_ids"].device) 58 | for idx,completion in enumerate(completions): 59 | completion_ids[idx,:len(completion)] = completion 60 | 61 | # Mask everything after the first EOS token 62 | im_eos = completion_ids == processing_class.tokenizer.convert_tokens_to_ids('<|im_end|>') 63 | s_eos = completion_ids == processing_class.tokenizer.convert_tokens_to_ids('') 64 | is_eos = im_eos | s_eos 65 | 66 | eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long,device=completion_ids.device) 67 | eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] 68 | sequence_indices = torch.arange(is_eos.size(1)).expand(is_eos.size(0), -1).to(device=eos_idx.device) 69 | completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() 70 | 71 | 72 | prompt_inputs["input_ids"] = torch.cat([prompt_inputs["input_ids"],completion_ids],dim=-1).to(dtype=torch.int64) 73 | prompt_inputs["attention_mask"] = torch.cat([prompt_inputs["attention_mask"], completion_mask], dim=1) # (B, P+C) 74 | 75 | return prompt_inputs,completion_mask 76 | 77 | def _process_inputs( 78 | inputs, 79 | processing_class, 80 | max_prompt_length 81 | ): 82 | prompts = [] 83 | completions = [] 84 | advantages = [] 85 | rewards = [] 86 | ids = [] 87 | step_ids = [] 88 | for inp in inputs: 89 | ids.append(inp["id"]) 90 | prompts.append(inp["prompt"]) 91 | completions.append(inp["completion_ids"]) 92 | advantages.append(inp["advantage"]) 93 | rewards.append(inp["reward"]) 94 | step_ids.append(inp.get("step_id",0)) 95 | 96 | ids = torch.tensor(ids) 97 | advantages = torch.tensor(advantages) 98 | step_ids = torch.tensor(step_ids) 99 | 100 | prompt_inputs = _prepare_messages(prompts,processing_class,max_prompt_length) 101 | prompt_len = prompt_inputs["input_ids"].size(1) 102 | prompt_inputs["rewards"] = torch.tensor(rewards) 103 | 104 | prompt_inputs,completion_mask = _create_inputs(processing_class,prompt_inputs,completions) 105 | return { 106 | "prompt_inputs": prompt_inputs, 107 | "completion_mask": completion_mask, 108 | "advantages": advantages, 109 | "prompt_len": prompt_len, 110 | "step_ids": step_ids 111 | } -------------------------------------------------------------------------------- /sft/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/OpenBMB/AgentCPM-GUI/b3dbe5c68643858351fbf10e0ce2a8922e83bf8e/sft/__init__.py -------------------------------------------------------------------------------- /sft/ds_config_zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | 11 | "bf16": { 12 | "enabled": "auto" 13 | }, 14 | 15 | "optimizer": { 16 | "type": "AdamW", 17 | "params": { 18 | "lr": "auto", 19 | "betas": "auto", 20 | "eps": "auto", 21 | "weight_decay": 
"auto" 22 | } 23 | }, 24 | 25 | "scheduler": { 26 | "type": "WarmupCosineLR", 27 | "params": { 28 | "total_num_steps": "auto", 29 | "cos_min_ratio": 0.01, 30 | "warmup_num_steps": "auto" 31 | } 32 | }, 33 | 34 | "zero_optimization": { 35 | "stage": 2, 36 | "offload_optimizer": { 37 | "device": "none", 38 | "pin_memory": true 39 | }, 40 | "allgather_partitions": true, 41 | "allgather_bucket_size": 1e9, 42 | "overlap_comm": true, 43 | "reduce_scatter": true, 44 | "reduce_bucket_size": 1e9, 45 | "contiguous_gradients": true 46 | }, 47 | 48 | "gradient_accumulation_steps": "auto", 49 | "gradient_clipping": "auto", 50 | "steps_per_print": 100, 51 | "train_batch_size": "auto", 52 | "train_micro_batch_size_per_gpu": "auto", 53 | "wall_clock_breakdown": false 54 | } 55 | -------------------------------------------------------------------------------- /sft/ds_config_zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "bf16": { 3 | "enabled": "auto" 4 | }, 5 | "optimizer": { 6 | "type": "AdamW", 7 | "params": { 8 | "lr": "auto", 9 | "betas": "auto", 10 | "eps": "auto", 11 | "weight_decay": "auto" 12 | } 13 | }, 14 | 15 | "scheduler": { 16 | "type": "WarmupCosineLR", 17 | "params": { 18 | "total_num_steps": "auto", 19 | "cos_min_ratio": 0.01, 20 | "warmup_num_steps": "auto" 21 | } 22 | }, 23 | 24 | "zero_optimization": { 25 | "stage": 3, 26 | "offload_optimizer": { 27 | "device": "none", 28 | "pin_memory": true 29 | }, 30 | "offload_param": { 31 | "device": "none", 32 | "pin_memory": true 33 | }, 34 | "overlap_comm": true, 35 | "contiguous_gradients": true, 36 | "sub_group_size": 1e9, 37 | "reduce_bucket_size": "auto", 38 | "stage3_prefetch_bucket_size": "auto", 39 | "stage3_param_persistence_threshold": "auto", 40 | "stage3_max_live_parameters": 1e9, 41 | "stage3_max_reuse_distance": 1e9, 42 | "stage3_gather_16bit_weights_on_model_save": true 43 | }, 44 | 45 | "gradient_accumulation_steps": "auto", 46 | "gradient_clipping": "auto", 47 | "steps_per_print": 100, 48 | "train_batch_size": "auto", 49 | "train_micro_batch_size_per_gpu": "auto", 50 | "wall_clock_breakdown": false 51 | } 52 | 53 | -------------------------------------------------------------------------------- /sft/ds_config_zero3_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": "auto", 3 | "train_micro_batch_size_per_gpu": "auto", 4 | "gradient_accumulation_steps": "auto", 5 | "gradient_clipping": "auto", 6 | "zero_allow_untested_optimizer": true, 7 | "fp16": { 8 | "enabled": "auto", 9 | "loss_scale": 0, 10 | "loss_scale_window": 1000, 11 | "initial_scale_power": 16, 12 | "hysteresis": 2, 13 | "min_loss_scale": 1 14 | }, 15 | "bf16": { 16 | "enabled": "auto" 17 | }, 18 | "scheduler": { 19 | "type": "WarmupCosineLR", 20 | "params": { 21 | "total_num_steps": "auto", 22 | "cos_min_ratio": 0.01, 23 | "warmup_num_steps": "auto" 24 | } 25 | }, 26 | "zero_optimization": { 27 | "stage": 3, 28 | "offload_optimizer": { 29 | "device": "cpu", 30 | "pin_memory": true 31 | }, 32 | "offload_param": { 33 | "device": "cpu", 34 | "pin_memory": true 35 | }, 36 | "overlap_comm": true, 37 | "contiguous_gradients": true, 38 | "sub_group_size": 1e9, 39 | "reduce_bucket_size": "auto", 40 | "stage3_prefetch_bucket_size": "auto", 41 | "stage3_param_persistence_threshold": "auto", 42 | "stage3_max_live_parameters": 1e9, 43 | "stage3_max_reuse_distance": 1e9, 44 | "stage3_gather_16bit_weights_on_model_save": true 45 | } 46 | } 47 | 48 | 
-------------------------------------------------------------------------------- /sft/finetune_ds.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | GPUS_PER_NODE=8 4 | NNODES=1 5 | NODE_RANK=0 6 | MASTER_ADDR=localhost 7 | MASTER_PORT=6001 8 | 9 | MODEL="openbmb/MiniCPM-V-2_6" 10 | # or openbmb/MiniCPM-V-2, openbmb/MiniCPM-Llama3-V-2_5, openbmb/MiniCPM-V-2_6 11 | # ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations. 12 | # See the section for finetuning in README for more information. 13 | DATA="path/to/trainging_data" 14 | EVAL_DATA="path/to/test_data" 15 | 16 | # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm, if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE="llama3", 17 | # if use openbmb/MiniCPM-o-2_6 or openbmb/MiniCPM-V-2_6, please set LLM_TYPE=qwen 18 | LLM_TYPE="qwen" 19 | MODEL_MAX_Length=2048 # if conduct multi-images sft, please set MODEL_MAX_Length=4096 20 | 21 | 22 | DISTRIBUTED_ARGS=" 23 | --nproc_per_node $GPUS_PER_NODE \ 24 | --nnodes $NNODES \ 25 | --node_rank $NODE_RANK \ 26 | --master_addr $MASTER_ADDR \ 27 | --master_port $MASTER_PORT 28 | " 29 | torchrun $DISTRIBUTED_ARGS finetune.py \ 30 | --model_name_or_path $MODEL \ 31 | --llm_type $LLM_TYPE \ 32 | --data_path $DATA \ 33 | --eval_data_path $EVAL_DATA \ 34 | --remove_unused_columns false \ 35 | --label_names "labels" \ 36 | --prediction_loss_only false \ 37 | --bf16 true \ 38 | --bf16_full_eval true \ 39 | --fp16 false \ 40 | --fp16_full_eval false \ 41 | --do_train \ 42 | --do_eval \ 43 | --tune_vision true \ 44 | --tune_llm true \ 45 | --model_max_length $MODEL_MAX_Length \ 46 | --max_slice_nums 9 \ 47 | --num_train_epochs 3 \ 48 | --output_dir output/model \ 49 | --logging_dir output/model \ 50 | --max_line_res 1120 \ 51 | --logging_strategy "steps" \ 52 | --per_device_train_batch_size 1 \ 53 | --gradient_accumulation_steps 1 \ 54 | --save_strategy "steps" \ 55 | --save_steps 1000 \ 56 | --save_total_limit 100 \ 57 | --learning_rate 1e-5 \ 58 | --weight_decay 0.1 \ 59 | --adam_beta1 0.9 \ 60 | --adam_beta2 0.999 \ 61 | --warmup_ratio 0.05 \ 62 | --lr_scheduler_type "cosine" \ 63 | --logging_steps 1 \ 64 | --gradient_checkpointing false \ 65 | --deepspeed ds_config_zero2.json \ 66 | --report_to "tensorboard" \ 67 | --dataloader_num_workers 8 \ 68 | --dataloader_prefetch_factor 16 \ 69 | -------------------------------------------------------------------------------- /sft/finetune_lora.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | GPUS_PER_NODE=8 4 | NNODES=1 5 | NODE_RANK=0 6 | MASTER_ADDR=localhost 7 | MASTER_PORT=6001 8 | 9 | MODEL="openbmb/MiniCPM-o-2_6" 10 | # or openbmb/MiniCPM-V-2, openbmb/MiniCPM-Llama3-V-2_5, openbmb/MiniCPM-V-2_6 11 | # ATTENTION: specify the path to your training data, which should be a json file consisting of a list of conversations. 12 | # See the section for finetuning in README for more information. 
13 | DATA="path/to/trainging_data" 14 | EVAL_DATA="path/to/test_data" 15 | # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm, if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE="llama3", 16 | # if use openbmb/MiniCPM-o-2_6 or openbmb/MiniCPM-V-2_6, please set LLM_TYPE=qwen 17 | LLM_TYPE="qwen" 18 | MODEL_MAX_Length=2048 # if conduct multi-images sft, please set MODEL_MAX_Length=4096 19 | 20 | DISTRIBUTED_ARGS=" 21 | --nproc_per_node $GPUS_PER_NODE \ 22 | --nnodes $NNODES \ 23 | --node_rank $NODE_RANK \ 24 | --master_addr $MASTER_ADDR \ 25 | --master_port $MASTER_PORT 26 | " 27 | 28 | torchrun $DISTRIBUTED_ARGS finetune.py \ 29 | --model_name_or_path $MODEL \ 30 | --llm_type $LLM_TYPE \ 31 | --data_path $DATA \ 32 | --eval_data_path $EVAL_DATA \ 33 | --remove_unused_columns false \ 34 | --label_names "labels" \ 35 | --prediction_loss_only false \ 36 | --bf16 false \ 37 | --bf16_full_eval false \ 38 | --fp16 true \ 39 | --fp16_full_eval true \ 40 | --do_train \ 41 | --do_eval \ 42 | --tune_vision true \ 43 | --tune_llm false \ 44 | --use_lora true \ 45 | --lora_target_modules "llm\..*layers\.\d+\.self_attn\.(q_proj|k_proj|v_proj|o_proj)" \ 46 | --model_max_length $MODEL_MAX_Length \ 47 | --max_slice_nums 9 \ 48 | --max_steps 10000 \ 49 | --eval_steps 1000 \ 50 | --output_dir output/output__lora \ 51 | --logging_dir output/output_lora \ 52 | --logging_strategy "steps" \ 53 | --per_device_train_batch_size 1 \ 54 | --per_device_eval_batch_size 1 \ 55 | --gradient_accumulation_steps 1 \ 56 | --evaluation_strategy "steps" \ 57 | --save_strategy "steps" \ 58 | --save_steps 1000 \ 59 | --save_total_limit 10 \ 60 | --learning_rate 1e-6 \ 61 | --weight_decay 0.1 \ 62 | --adam_beta2 0.95 \ 63 | --warmup_ratio 0.01 \ 64 | --lr_scheduler_type "cosine" \ 65 | --logging_steps 1 \ 66 | --gradient_checkpointing true \ 67 | --deepspeed ds_config_zero2.json \ 68 | --report_to "tensorboard" # wandb 69 | -------------------------------------------------------------------------------- /sft/readme.md: -------------------------------------------------------------------------------- 1 | # Fine-tuning Scripts 2 | 3 | This repository provides fine-tuning scripts adapted from MiniCPM-o, supporting fine-tuning the model on GUI agent tasks. For detailed parameter configurations, please refer to the [MiniCPM-o](https://github.com/OpenBMB/MiniCPM-o/tree/main/finetune) repository. 4 | 5 | --- 6 | 7 | ## 📂 Data Preparation 8 | 9 | Each training sample should be a dictionary with: 10 | 11 | - `id`: a unique identifier, 12 | - `image`: either a single image path (string) or a dictionary of image paths with placeholders (``, ``, ...), 13 | - `conversations`: a list of dialogue turns between user and assistant. 
14 | 15 | #### Single Image Example 16 | 17 | ```json 18 | [ 19 | { 20 | "id": "0", 21 | "image": "path/to/image.jpg", 22 | "conversations": [ 23 | {"role": "system", "content": "system prompt"}, 24 | {"role": "user", "content": "\nWhat is in the image?"}, 25 | {"role": "assistant", "content": "The image contains..."} 26 | ] 27 | } 28 | ] 29 | ``` 30 | 31 | #### Multi-Image Example 32 | 33 | ```json 34 | [ 35 | { 36 | "id": "0", 37 | "image": { 38 | "": "path/to/image0.jpg", 39 | "": "path/to/image1.jpg" 40 | }, 41 | "conversations": [ 42 | {"role": "system", "content": "system prompt"}, 43 | {"role": "user", "content": "Compare the objects.\n\n"}, 44 | {"role": "assistant", "content": "The first image shows..."} 45 | ] 46 | } 47 | ] 48 | ``` 49 | 50 | If no image placeholder is present in the text, the image embedding will be prepended by default. 51 | 52 | --- 53 | 54 | ## 🚀 Training Setup 55 | 56 | #### Full-Parameter Finetuning 57 | 58 | Edit and run "finetune_ds.sh": 59 | 60 | ```bash 61 | MODEL="openbmb/MiniCPM-V-2_6" # or "openbmb/MiniCPM-o-2_6", openbmb/MiniCPM-Llama3-V-2_5, openbmb/MiniCPM-V-2 62 | DATA="path/to/trainging_data" # json file 63 | EVAL_DATA="path/to/test_data" # json file 64 | LLM_TYPE="qwen" # if use openbmb/MiniCPM-V-2, please set LLM_TYPE=minicpm, if use openbmb/MiniCPM-Llama3-V-2_5, please set LLM_TYPE="llama3", 65 | # if use openbmb/MiniCPM-o-2_6 or openbmb/MiniCPM-V-2_6, please set LLM_TYPE=qwen 66 | ``` 67 | 68 | --- 69 | 70 | ## 📃 Training Data Sample 71 | 72 | ```json 73 | [{"id":"0","image":{"":"/image_path/screenshot.jpeg"},"conversations":[{"role":"system","content":"# Role\n你是一名熟悉安卓系统触屏GUI操作的智能体,将根据用户的问题,分析当前界面的GUI元素和布局,生成相应的操作。\n\n# Task\n针对用户问题,根据输入的当前屏幕截图,输出下一步的操作。\n\n# Rule\n- 以紧凑JSON格式输出\n- 输出操作必须遵循Schema约束\n\n# Schema\n{\"type\":\"object\",\"description\":\"执行操作并决定当前任务状态\",\"additionalProperties\":false,\"optional\":[\"thought\"],\"properties\":{\"thought\":{\"type\":\"string\",\"description\":\"智能体的思维过程\"},\"POINT\":{\"$ref\":\"#/$defs/Location\",\"description\":\"点击屏幕上的指定位置\"},\"to\":{\"description\":\"移动,组合手势参数\",\"oneOf\":[{\"enum\":[\"up\",\"down\",\"left\",\"right\"],\"description\":\"从当前点(POINT)出发,执行滑动手势操作,方向包括向上、向下、向左、向右\"},{\"$ref\":\"#/$defs/Location\",\"description\":\"移动到某个位置\"}]},\"duration\":{\"type\":\"integer\",\"description\":\"动作执行的时间或等待时间,毫秒\",\"minimum\":0,\"default\":200},\"PRESS\":{\"type\":\"string\",\"description\":\"触发特殊按键,HOME为回到主页按钮,BACK为返回按钮,ENTER为回车按钮\",\"enum\":[\"HOME\",\"BACK\",\"ENTER\"]},\"TYPE\":{\"type\":\"string\",\"description\":\"输入文本\"},\"STATUS\":{\"type\":\"string\",\"description\":\"当前任务的状态。特殊情况:satisfied,无需操作;impossible,任务无法完成;interrupt,任务中断;need_feedback,需要用户反馈;\",\"enum\":[\"continue\",\"finish\",\"satisfied\",\"impossible\",\"interrupt\",\"need_feedback\"],\"default\":\"continue\"}},\"$defs\":{\"Location\":{\"type\":\"array\",\"description\":\"坐标为相对于屏幕左上角位原点的相对位置,并且按照宽高比例缩放到0~1000,数组第一个元素为横坐标x,第二个元素为纵坐标y\",\"items\":{\"type\":\"integer\",\"minimum\":0,\"maximum\":1000},\"minItems\":2,\"maxItems\":2}}}"},{"role":"user","content":"打开美团外卖\n当前屏幕截图:"},{"role":"assistant","content":"{\"POINT\":[197,634],\"to\":\"right\"}"}]}] 74 | ``` 75 | --------------------------------------------------------------------------------