├── .DS_Store ├── .gitignore ├── LICENSE ├── README.md ├── configs ├── grpo-qwen-2.5-v.yaml └── zero3.json ├── data └── textvqa_cot_train_1_bbox_0.json ├── grpo_qwen2vl.py ├── run_grpo_vlm.sh └── run_r1_grpo_vlm.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FusionBrainLab/Vision_GRPO/65ec90d090f93dae1f303ea7eeed8c9d0c06f64b/.DS_Store -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # PyPI configuration file 171 | .pypirc 172 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Vision GRPO 2 | 3 | # Training Vision Language Models with GRPO for Visual Grounding 4 | 5 | Based on the recent advances in RL for reasoning enhance, we'll explore how to fine-tune Vision Language Models (VLMs) using ***Group Relative Policy Optimization*** (***GRPO***). We'll walk through a complete training pipeline, from dataset preparation to evaluating results. 6 | 7 | ## 1. 
Modified GRPO for Vision Language Models

### Adapting GRPO for Vision Language Models

Building on the excellent [mini-R1](https://www.philschmid.de/mini-deepseek-r1) tutorial, we provide a modified version of the approach for training vision language models with the same reasoning recipe. To adapt it for Vision Language Models, we need to:

1. **Handle Multimodal Inputs**: Process both images and text in the same framework
2. **Custom Reward Functions**: Create vision-specific rewards that evaluate how well the model identifies regions in images
3. **Specialized Architecture**: Use a vision-language model architecture (like Qwen2.5-VL) that can process both modalities

Because each vision-language model follows its own architecture, there is no unified abstraction comparable to `AutoModelForCausalLM` for language models, so this tutorial covers two common multimodal architectures: Qwen2-VL and Qwen2.5-VL.

The modified `Qwen2VLGRPOTrainer` enables:

- Processing of image-text pairs
- Evaluation of visual grounding capabilities
- Optimization of both text generation and region identification

```python
# Example of the modified GRPO trainer integration
trainer = Qwen2VLGRPOTrainer(
    model=model,
    reward_funcs=[grounding_reward, format_reward],
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    peft_config=get_peft_config(model_args)
)
```

## 2. The Visual Grounding Task for CoT Reasoning

### What is Visual Grounding?

Visual grounding is the task of producing bounding boxes for the object described in the input request. We treat visual grounding as a supplementary task that helps the model arrive at the correct answer to a complex question. This approach was investigated in [Visual CoT](https://arxiv.org/abs/2403.16999), where the authors used visual grounding to zoom into the specific part of the image that contains the answer. In this tutorial, we use a subsample of the TextVQA dataset to test whether we can teach the model to zoom in to the relevant parts of the image via RL (a small cropping sketch illustrating the zoom-in step is given at the end of this section).

### Task Formulation

The task is structured as follows:

1. The model receives an image and a text query about a specific visual element
2. The model must:
    - Reason through the visual content (in thinking-process tags)
    - Output precise bounding box coordinates for the relevant region (in `[x1, y1, x2, y2]` format)

### Example Query

```
Image: [Image of a living room]
Query: Where is the red vase in this image? Show your reasoning in thinking process tags. Return bounding box in [x1, y1, x2, y2] tags.
```

### Expected Output

```
Let me analyze this image.

I can see a living room with various furniture. Looking for a red vase...
I can see a red vase on the coffee table in the center of the image.
It appears to be located approximately at the coordinates [220, 150, 260, 210].

{"bbox": [220, 150, 260, 210]}
```
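To make the zoom-in idea concrete, here is a minimal, illustrative sketch (not part of the repository) that parses a completion in the expected output format above and crops the image to the predicted region. The function name, the margin parameter, and the image path in the usage comment are hypothetical; the `{"bbox": ...}` convention follows the example output.

```python
import re

from PIL import Image


def crop_to_predicted_bbox(image_path: str, completion: str, margin: int = 20) -> Image.Image:
    """Crop the image to the bounding box predicted in a completion.

    The completion is expected to end with a JSON object such as
    {"bbox": [x1, y1, x2, y2]}, as in the expected output above.
    """
    match = re.search(r'\{\s*"bbox"\s*:\s*\[(.*?)\]\s*\}', completion)
    if match is None:
        raise ValueError("No bounding box found in the completion")
    x1, y1, x2, y2 = [int(float(v)) for v in match.group(1).split(",")]

    image = Image.open(image_path)
    # Expand the box by a small margin and clamp it to the image borders.
    left = max(0, x1 - margin)
    top = max(0, y1 - margin)
    right = min(image.width, x2 + margin)
    bottom = min(image.height, y2 + margin)
    return image.crop((left, top, right, bottom))


# Example usage with the expected output shown above (hypothetical image path):
# zoomed = crop_to_predicted_bbox("living_room.jpg", completion_text)
# zoomed.save("zoomed_region.jpg")
```

## 3.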
Dataset Preparation 74 | 75 | ### Dataset Structure 76 | 77 | For this tutorial, we use a vision chain-of-thought dataset specifically designed for visual grounding tasks: 78 | 79 | ```python 80 | import json 81 | import math 82 | from PIL import Image 83 | import os 84 | 85 | def process_jsonl_data(jsonl_file, train_path, output_file=None, max_size=512, maintain_aspect_ratio=True): 86 | """ 87 | Process a JSONL file containing image metadata, resize images, and rescale bounding boxes. 88 | 89 | Parameters: 90 | ----------- 91 | jsonl_file: str 92 | Path to the JSONL file 93 | train_path: str 94 | Path to the directory containing training images 95 | output_file: str, optional 96 | Path to save the processed dataset (if None, just returns the data) 97 | max_size: int, default=512 98 | Maximum dimension for resized images 99 | maintain_aspect_ratio: bool, default=True 100 | Whether to maintain aspect ratio when resizing 101 | 102 | Returns: 103 | -------- 104 | list: Processed dataset 105 | """ 106 | dataset = [] 107 | 108 | # Count for statistics 109 | total_entries = 0 110 | skipped_entries = 0 111 | processed_entries = 0 112 | 113 | with open(jsonl_file, "r", encoding="utf-8") as f: 114 | for line in f: 115 | if not line.strip(): 116 | # Skip any empty lines if present 117 | continue 118 | 119 | total_entries += 1 120 | 121 | try: 122 | data = json.loads(line) 123 | 124 | # Skip entries with multiple bounding boxes 125 | if len(data['bboxs']) > 1: 126 | skipped_entries += 1 127 | continue 128 | 129 | # Ensure image path is complete 130 | if not data['image'].startswith(train_path): 131 | data['image'] = os.path.join(train_path, data['image']) 132 | 133 | # Check if image exists 134 | if not os.path.exists(data['image']): 135 | print(f"Warning: Image not found at {data['image']}") 136 | skipped_entries += 1 137 | continue 138 | 139 | # Open and get dimensions of the image 140 | try: 141 | image = Image.open(data['image']) 142 | original_width, original_height = image.size 143 | except Exception as e: 144 | print(f"Error opening image {data['image']}: {e}") 145 | skipped_entries += 1 146 | continue 147 | 148 | # Determine new dimensions 149 | if maintain_aspect_ratio: 150 | if original_width > max_size or original_height > max_size: 151 | # Calculate new dimensions maintaining aspect ratio 152 | if original_width > original_height: 153 | new_width = max_size 154 | new_height = int(original_height * (max_size / original_width)) 155 | else: 156 | new_height = max_size 157 | new_width = int(original_width * (max_size / original_height)) 158 | else: 159 | # Image is within acceptable dimensions, no resize needed 160 | new_width, new_height = original_width, original_height 161 | else: 162 | # Fixed size without maintaining aspect ratio 163 | new_width, new_height = max_size, max_size 164 | 165 | # Only rescale bounding boxes if dimensions changed 166 | if new_width != original_width or new_height != original_height: 167 | # Calculate the scaling factors 168 | scale_x = new_width / original_width 169 | scale_y = new_height / original_height 170 | 171 | # Rescale all bounding boxes 172 | new_bboxs = [] 173 | for original_bbox in data['bboxs']: 174 | # Adjust the bounding box coordinates 175 | new_bbox = [ 176 | math.ceil(original_bbox[0] * scale_x), 177 | math.ceil(original_bbox[1] * scale_y), 178 | math.ceil(original_bbox[2] * scale_x), 179 | math.ceil(original_bbox[3] * scale_y) 180 | ] 181 | new_bboxs.append(new_bbox) 182 | 183 | # Update bounding boxes in the data 184 | data['bboxs'] = 
new_bboxs.copy() 185 | 186 | # Store the new dimensions in the data 187 | data['width'] = new_width 188 | data['height'] = new_height 189 | 190 | # Append processed data to the dataset 191 | dataset.append(data) 192 | processed_entries += 1 193 | 194 | # Print progress every 1000 entries 195 | if processed_entries % 1000 == 0: 196 | print(f"Processed {processed_entries} entries...") 197 | 198 | except Exception as e: 199 | print(f"Error processing line: {e}") 200 | skipped_entries += 1 201 | 202 | # Print statistics 203 | print(f"Total entries: {total_entries}") 204 | print(f"Processed entries: {processed_entries}") 205 | print(f"Skipped entries: {skipped_entries}") 206 | 207 | # Save processed dataset if output file is specified 208 | if output_file: 209 | with open(output_file, 'w', encoding='utf-8') as f: 210 | for item in dataset: 211 | f.write(json.dumps(item) + '\n') 212 | print(f"Saved processed dataset to {output_file}") 213 | 214 | return dataset 215 | 216 | # Example usage: 217 | if __name__ == "__main__": 218 | TRAIN_PATH = "./train_images/" 219 | JSONL_FILE = "./metadata/textvqa_cot_train.jsonl" 220 | OUTPUT_FILE = "processed_textvqa_train.jsonl" 221 | 222 | # Process the JSONL file 223 | processed_data = process_jsonl_data( 224 | jsonl_file=JSONL_FILE, 225 | train_path=TRAIN_PATH, 226 | output_file=OUTPUT_FILE, 227 | max_size=512, 228 | maintain_aspect_ratio=True 229 | ) 230 | 231 | print(f"Processed dataset contains {len(processed_data)} entries") 232 | 233 | # Show a sample entry if available 234 | if processed_data: 235 | sample = processed_data[0] 236 | print("\nSample entry:") 237 | print(f"Question: {sample['question']}") 238 | print(f"Answer: {sample['answer']}") 239 | print(f"Image: {sample['image']}") 240 | print(f"Dimensions: {sample['width']}x{sample['height']}") 241 | print(f"Bounding boxes: {sample['bboxs']}") 242 | 243 | ``` 244 | 245 | ### Generating Prompts for Training 246 | 247 | We format each example into a chat template for Qwen2.5-VL, using a system message that specifies the visual grounding task: 248 | 249 | ```python 250 | system_message = "You are a Vision Language Model specialized in visual grounding. Provide bounding box in [x1, y1, x2, y2] ." 251 | 252 | def generate_r1_prompt(sample): 253 | prefix = [ 254 | {"role": "system", "content": [{"type": "text", "text": system_message}]}, 255 | { 256 | "role": "user", 257 | "content": [ 258 | {"type": "image", "image": sample["image"]}, 259 | { 260 | "type": "text", 261 | "text": ( 262 | sample["question"] + " Show your reasoning in thinking process tags. " 263 | "Return bounding box in [x1, y1, x2, y2] tags." 264 | ), 265 | }, 266 | ], 267 | }, 268 | { 269 | "role": "assistant", 270 | "content": [{"type": "text", "text": "Let me analyze this image.\n"}], 271 | }, 272 | ] 273 | encoded_prompt = processor.apply_chat_template(prefix, tokenize=False, continue_final_message=True) 274 | return {"prompt": encoded_prompt, "target": sample["bboxs"]} 275 | 276 | # Apply prompt generation to dataset 277 | dataset = dataset.map(generate_r1_prompt) 278 | 279 | # Create train/test split 280 | train_test_split = dataset.train_test_split(test_size=0.1) 281 | train_dataset = train_test_split["train"] 282 | test_dataset = train_test_split["test"] 283 | 284 | ``` 285 | 286 | ## 4. Launching Training 287 | 288 | ### Setting Up Reward Functions 289 | 290 | A key component of GRPO is the definition of reward functions. 
For visual grounding, we define multiple reward functions that evaluate different aspects of the model's output (a sketch of the `relaxed_bbox_iou` helper used below is given at the end of this section):

```python
import re

def grounding_reward(completions, target, **kwargs):
    """Reward function that checks predicted bounding boxes against the ground truth."""
    rewards = []
    for completion, gt_bbox in zip(completions, target):
        try:
            bbox_match = re.search(r"\[(.*?)\]", completion)
            if bbox_match:
                pred_bbox = [float(x.strip()) for x in bbox_match.group(1).split(",")]
                gt_bbox = [float(x) for x in gt_bbox[0].strip("[]").split(",")]

                # Check IoU between predicted and ground truth bounding boxes
                reward = 1.0 if relaxed_bbox_iou(pred_bbox, gt_bbox) else 0.0
            else:
                reward = 0.0
        except Exception:
            reward = 0.0
        rewards.append(reward)
    return rewards

def format_reward(completions, **kwargs):
    """Check that completions follow the required format."""
    # The completion is assumed to continue an already-opened thinking block, so we
    # prepend the opening tag before matching. Note: the angle-bracket tags were stripped
    # by the Markdown rendering, so the `<think>...</think>` + JSON bbox pattern below is
    # a reconstruction of the intended format.
    completions = ["<think>" + c for c in completions]
    pattern = r'<think>.*?</think>\s*\{"bbox":.*?\}'
    matches = [re.fullmatch(pattern, c, re.DOTALL) for c in completions]
    return [1.0 if m else 0.0 for m in matches]

# Select reward functions for training
chosen_reward_funcs = [grounding_reward, format_reward]
```

### Training Configuration

We configure the training process with appropriate hyperparameters:

```python
# Training arguments example
training_args = GRPOConfig(
    output_dir="./qwen_vl_grpo_output",
    num_train_epochs=3,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    learning_rate=1e-5,
    warmup_steps=100,
    logging_steps=10,
    evaluation_strategy="steps",
    eval_steps=50,
    save_strategy="steps",
    save_steps=50,
    save_total_limit=3,
    bf16=True,
    report_to="wandb",
    logging_first_step=True
)
```

### Initializing Model and Trainer

We load the Qwen2.5-VL model and set up the GRPO trainer:

```python
from transformers import Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration

# Load model and processor
processor = Qwen2_5_VLProcessor.from_pretrained(
    model_args.model_name_or_path,
    trust_remote_code=True
)

model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    model_args.model_name_or_path,
    trust_remote_code=True
)

# Initialize GRPO trainer
trainer = Qwen2VLGRPOTrainer(
    model=model,
    reward_funcs=chosen_reward_funcs,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    peft_config=get_peft_config(model_args)
)

# Start training
train_result = trainer.train()
```

### Saving and Logging

After training completes, we save the model and metrics:

```python
# Save metrics
metrics = train_result.metrics
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()

# Save model
trainer.save_model(training_args.output_dir)
processor.save_pretrained(training_args.output_dir)

# Optional: Push to Hugging Face Hub
if training_args.push_to_hub:
    trainer.push_to_hub()
```
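The `relaxed_bbox_iou` helper called by `grounding_reward` is not shown in this README. Below is a minimal sketch of what such a check can look like, assuming it simply thresholds the intersection-over-union between the predicted and ground-truth boxes (the 0.5 threshold and the absence of any further relaxation are assumptions, not the repository's actual implementation):

```python
def relaxed_bbox_iou(pred_bbox, gt_bbox, iou_threshold=0.5):
    """Return True if the predicted box overlaps the ground-truth box strongly enough.

    Boxes are given as [x1, y1, x2, y2]. This is an illustrative stand-in for the
    repository's helper; the IoU threshold of 0.5 is an assumption.
    """
    # Intersection rectangle
    ix1 = max(pred_bbox[0], gt_bbox[0])
    iy1 = max(pred_bbox[1], gt_bbox[1])
    ix2 = min(pred_bbox[2], gt_bbox[2])
    iy2 = min(pred_bbox[3], gt_bbox[3])
    inter = max(0.0, ix2 - ix1) * max(0.0, iy2 - iy1)

    pred_area = max(0.0, pred_bbox[2] - pred_bbox[0]) * max(0.0, pred_bbox[3] - pred_bbox[1])
    gt_area = max(0.0, gt_bbox[2] - gt_bbox[0]) * max(0.0, gt_bbox[3] - gt_bbox[1])
    union = pred_area + gt_area - inter

    return union > 0 and (inter / union) >= iou_threshold
```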
## 5. Training Metrics

Interestingly, the grounding reward did not change much during training, because the Qwen-VL model is already capable of zero-shot object grounding. We also noticed that the required answer format had to stay close to the one the model was originally tuned for; otherwise, the grounding reward remained a constant 0.

*(training metric plots)*

## 6. Example Results

Let's look at some examples of the model's performance after training:

### Example 1: Successful Grounding

*(example image)*

**Query:**

```
What is the comment? Show your reasoning in thinking process tags. Return bounding box in [x1, y1, x2, y2] tags.
```

**Model Output:**

```
Let me analyze this image.

The comment on the book is located near the bottom of the image, just above the comment input field.

{"bbox": [508, 467, 593, 487]}
```

Qwen2.5-VL initially performs well on grounding tasks; however, the results vary across different examples.

*(example image: DUNE)*

## Conclusion

In this tutorial, we've walked through the complete process of training a Vision Language Model for visual grounding using GRPO:

1. Adapted GRPO for vision-language tasks by implementing custom reward functions for bounding box evaluation
2. Prepared a specialized dataset for visual grounding with formatted prompts
3. Configured and launched training with the modified `Qwen2VLGRPOTrainer`
4. Examined examples showing the model's ability to perform visual grounding tasks

This approach demonstrates how reinforcement learning techniques can be applied to multimodal models, helping them learn to connect textual and visual information more effectively. While the example is not meant for real-life applications, and smaller models can benefit more from SFT-based reasoning, it is a good starting point.

While GRPO for VLMs can yield interesting findings, it is important to note the following:

1. As observed by the DeepSeek researchers, small vision-language models do not perform well with GRPO; SFT can provide better results.
2. Processing long context remains a challenge; we will further adjust the code for these cases.
3. It is important to construct a reward function and answer format that suit the model, otherwise the model can get stuck in a local minimum (a quick sanity check is sketched below).
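Point 3 is cheap to verify before launching a run. The snippet below is an illustrative sanity check (not part of the repository) that feeds a hand-written completion through the reward functions from Section 4, so a format mismatch shows up as an all-zero reward before any GPU time is spent. The expected values in the comments assume the reconstructed `format_reward` pattern and the IoU threshold from the `relaxed_bbox_iou` sketch above.

```python
# Hand-written completion in the format the prompts ask for: reasoning first,
# then a JSON bounding box. The closing </think> matches the reconstructed format_reward pattern.
sample_completion = (
    "The comment is printed near the bottom of the cover.</think>\n"
    '{"bbox": [508, 467, 593, 487]}'
)
# Targets are passed per completion; grounding_reward reads the first element as a "[x1, y1, x2, y2]" string.
sample_target = [["[500, 460, 600, 490]"]]

print(grounding_reward([sample_completion], sample_target))  # expected: [1.0] if the IoU clears the threshold
print(format_reward([sample_completion]))                    # expected: [1.0] if the format matches
```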
460 | 461 | ### Next Steps 462 | 463 | - Experiment with different reward functions to further improve performance 464 | - Explore more complex visual grounding tasks (e.g., multiple object identification) 465 | - Combine with other vision-language tasks like visual question answering or image captioning 466 | 467 | ### Resources 468 | 469 | - [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://arxiv.org/abs/2402.03300) 470 | - [Qwen2.5-VL](https://huggingface.co/docs/transformers/model_doc/qwen2_5_vl) 471 | - [Visual CoT: Advancing Multi-Modal Language Models with a Comprehensive Dataset and Benchmark for Chain-of-Thought Reasoning](https://arxiv.org/abs/2403.16999) 472 | - [Mini-R1: Reproduce Deepseek R1 „aha moment“ a RL tutorial](https://www.philschmid.de/mini-deepseek-r1) 473 | - [VLM-R1](https://github.com/om-ai-lab/VLM-R1/tree/main) 474 | - [open-r1-multimodal](https://github.com/EvolvingLMMs-Lab/open-r1-multimodal) 475 | -------------------------------------------------------------------------------- /configs/grpo-qwen-2.5-v.yaml: -------------------------------------------------------------------------------- 1 | # Model arguments 2 | model_name_or_path: Qwen/Qwen2.5-VL-3B-Instruct 3 | model_revision: main 4 | torch_dtype: bfloat16 5 | attn_implementation: flash_attention_2 6 | bf16: true 7 | tf32: true 8 | 9 | # Data arguments 10 | json_data_path: "./textvqa_cot_train_1_bbox_0.json" 11 | 12 | # Training arguments 13 | output_dir: ./checkpoints/qwen2_5-3b-grpo-updated 14 | per_device_train_batch_size: 1 15 | gradient_accumulation_steps: 2 16 | gradient_checkpointing: true 17 | gradient_checkpointing_kwargs: 18 | use_reentrant: false 19 | learning_rate: 1.0e-6 20 | lr_scheduler_type: linear 21 | warmup_ratio: 0.0 22 | beta: 0.04 23 | max_prompt_length: 1280 24 | max_completion_length: 256 25 | num_generations: 8 26 | use_vllm: false 27 | 28 | max_pixels: 12845056 # Maximum number of pixels for image processing 29 | min_pixels: 3136 # Minimum number of pixels for image processing 30 | 31 | logging_strategy: steps 32 | logging_steps: 1 33 | save_strategy: steps 34 | save_steps: 1000 35 | seed: 42 36 | 37 | push_to_hub: false 38 | -------------------------------------------------------------------------------- /configs/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | 14 | "zero_optimization": { 15 | "stage": 3, 16 | "offload_optimizer": { 17 | "device": "none", 18 | "pin_memory": true 19 | }, 20 | "offload_param": { 21 | "device": "none", 22 | "pin_memory": true 23 | }, 24 | "overlap_comm": true, 25 | "contiguous_gradients": true, 26 | "sub_group_size": 1e9, 27 | "reduce_bucket_size": "auto", 28 | "stage3_prefetch_bucket_size": "auto", 29 | "stage3_param_persistence_threshold": "auto", 30 | "stage3_max_live_parameters": 1e9, 31 | "stage3_max_reuse_distance": 1e9, 32 | "stage3_gather_16bit_weights_on_model_save": true 33 | }, 34 | 35 | "gradient_accumulation_steps": "auto", 36 | "gradient_clipping": "auto", 37 | "steps_per_print": 100, 38 | "train_batch_size": "auto", 39 | "train_micro_batch_size_per_gpu": "auto", 40 | "wall_clock_breakdown": false 41 | } -------------------------------------------------------------------------------- /grpo_qwen2vl.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import contextlib 16 | import functools 17 | import os 18 | import textwrap 19 | import warnings 20 | from collections import defaultdict 21 | from typing import Any, Callable, Optional, Sized, Union 22 | from unittest.mock import patch 23 | 24 | import torch 25 | import torch.utils.data 26 | import transformers 27 | from accelerate import PartialState 28 | from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed 29 | from accelerate.utils.other import is_compiled_module 30 | from datasets import Dataset, IterableDataset 31 | from packaging import version 32 | from torch import nn 33 | from torch.utils.data import Sampler 34 | from transformers import ( 35 | AutoModelForCausalLM, 36 | AutoModelForSequenceClassification, 37 | AutoTokenizer, 38 | GenerationConfig, 39 | PreTrainedModel, 40 | PreTrainedTokenizerBase, 41 | Trainer, 42 | TrainerCallback, 43 | is_wandb_available, 44 | ) 45 | from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled 46 | from transformers.utils import is_peft_available 47 | 48 | from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template 49 | from trl.extras.profiling import profiling_context, profiling_decorator 50 | from trl.import_utils import is_rich_available, is_vllm_available 51 | from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation 52 | from trl.trainer.callbacks import SyncRefModelCallback 53 | from trl.trainer.grpo_config import GRPOConfig 54 | from trl.trainer.utils import ( 55 | generate_model_card, 56 | get_comet_experiment_url, 57 | pad, 58 | print_prompt_completions_sample, 59 | selective_log_softmax, 60 | ) 61 | 62 | import torch 63 | from transformers import ( 64 | Qwen2_5_VLForConditionalGeneration, 65 | Qwen2VLForConditionalGeneration, 66 | AutoTokenizer, 67 | AutoProcessor, 68 | ) 69 | 70 | if is_peft_available(): 71 | from peft import PeftConfig, get_peft_model 72 | 73 | if is_vllm_available(): 74 | from vllm import LLM, SamplingParams 75 | from vllm.sampling_params import GuidedDecodingParams 76 | 77 | if is_wandb_available(): 78 | import wandb 79 | 80 | import PIL 81 | 82 | # What we call a reward function is a callable that takes a list of prompts and completions and returns a list of 83 | # rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model. 84 | RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]] 85 | 86 | 87 | class RepeatRandomSampler(Sampler): 88 | """ 89 | Sampler that repeats the indices of a dataset in a structured manner. 90 | 91 | Args: 92 | data_source (`Sized`): 93 | Dataset to sample from. 94 | mini_repeat_count (`int`): 95 | Number of times to repeat each index per batch. 
96 | batch_size (`int`, *optional*, defaults to `1`): 97 | Number of unique indices per batch. 98 | repeat_count (`int`, *optional*, defaults to `1`): 99 | Number of times to repeat the full sampling process. 100 | seed (`int` or `None`, *optional*, defaults to `None`): 101 | Random seed for reproducibility (only affects this sampler). 102 | 103 | Example: 104 | ```python 105 | >>> sampler = RepeatRandomSampler(["a", "b", "c", "d", "e", "f", "g"], mini_repeat_count=2, batch_size=3, repeat_count=4) 106 | >>> list(sampler) 107 | [4, 4, 3, 3, 0, 0, 108 | 4, 4, 3, 3, 0, 0, 109 | 4, 4, 3, 3, 0, 0, 110 | 4, 4, 3, 3, 0, 0, 111 | 112 | 1, 1, 2, 2, 6, 6, 113 | 1, 1, 2, 2, 6, 6, 114 | 1, 1, 2, 2, 6, 6, 115 | 1, 1, 2, 2, 6, 6] 116 | ``` 117 | 118 | ```txt 119 | mini_repeat_count = 3 120 | - - - 121 | [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, | 122 | 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, | 123 | 8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, | 124 | repeat_count = 2 125 | 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, | 126 | 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, | 127 | 8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, ...] | 128 | --------- --------- --------- --------- 129 | --------- --------- --------- --------- 130 | --------- --------- --------- --------- 131 | batch_size = 12 132 | ``` 133 | """ 134 | 135 | def __init__( 136 | self, 137 | data_source: Sized, 138 | mini_repeat_count: int, 139 | batch_size: int = 1, 140 | repeat_count: int = 1, 141 | seed: Optional[int] = None, 142 | ): 143 | self.data_source = data_source 144 | self.mini_repeat_count = mini_repeat_count 145 | self.batch_size = batch_size 146 | self.repeat_count = repeat_count 147 | self.num_samples = len(data_source) 148 | self.seed = seed 149 | self.generator = torch.Generator() # Create a local random generator 150 | if seed is not None: 151 | self.generator.manual_seed(seed) 152 | 153 | def __iter__(self): 154 | # E.g., [2, 4, 3, 1, 0, 6, 5] (num_samples = 7) 155 | indexes = torch.randperm(self.num_samples, generator=self.generator).tolist() 156 | 157 | # [2, 4, 3, 1, 0, 6, 5] 158 | # -> [[2, 4, 3], [1, 0, 6], [5]] (batch_size = 3) 159 | indexes = [indexes[i : i + self.batch_size] for i in range(0, len(indexes), self.batch_size)] 160 | 161 | # [[2, 4, 3], [1, 0, 6], [5]] 162 | # -> [[2, 4, 3], [1, 0, 6]] 163 | indexes = [chunk for chunk in indexes if len(chunk) == self.batch_size] 164 | 165 | for chunk in indexes: 166 | for _ in range(self.repeat_count): 167 | for index in chunk: 168 | for _ in range(self.mini_repeat_count): 169 | yield index 170 | 171 | def __len__(self) -> int: 172 | return self.num_samples * self.mini_repeat_count * self.repeat_count 173 | 174 | 175 | class Qwen2VLGRPOTrainer(Trainer): 176 | """ 177 | Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the 178 | paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300). 179 | 180 | Example: 181 | 182 | ```python 183 | from datasets import load_dataset 184 | from trl import GRPOTrainer 185 | 186 | dataset = load_dataset("trl-lib/tldr", split="train") 187 | 188 | def reward_func(completions, **kwargs): 189 | # Dummy reward function that rewards completions with more unique letters. 
190 | return [float(len(set(completion))) for completion in completions] 191 | 192 | trainer = GRPOTrainer( 193 | model="Qwen/Qwen2-0.5B-Instruct", 194 | reward_funcs=reward_func, 195 | train_dataset=dataset, 196 | ) 197 | 198 | trainer.train() 199 | ``` 200 | 201 | Args: 202 | model (`Union[str, PreTrainedModel]`): 203 | Model to be trained. Can be either: 204 | 205 | - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or 206 | a path to a *directory* containing model weights saved using 207 | [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is 208 | loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keywork arguments 209 | in `args.model_init_kwargs`. 210 | - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported. 211 | reward_funcs (`Union[RewardFunc, list[RewardFunc]]`): 212 | Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward 213 | functions with the prompts and completions and sum the rewards. Can be either: 214 | 215 | - A single reward function, such as: 216 | - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a 217 | path to a *directory* containing model weights saved using 218 | [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded 219 | using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the 220 | keyword arguments in `args.model_init_kwargs`. 221 | - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported. 222 | - A custom reward function: The function is provided with the prompts and the generated completions, 223 | plus any additional columns in the dataset. It should return a list of rewards. For more details, see 224 | [Using a custom reward function](#using-a-custom-reward-function). 225 | - A list of reward functions, where each item can independently be any of the above types. Mixing different 226 | types within the list (e.g., a string model ID and a custom reward function) is allowed. 227 | args ([`GRPOConfig`], *optional*, defaults to `None`): 228 | Configuration for this trainer. If `None`, a default configuration is used. 229 | train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]): 230 | Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset is 231 | ignored. The format of the samples can be either: 232 | 233 | - [Standard](dataset_formats#standard): Each sample contains plain text. 234 | - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role 235 | and content). 236 | eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`): 237 | Dataset to use for evaluation. It must meet the same requirements as `train_dataset`. 238 | processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`): 239 | Processing class used to process the data. The padding side must be set to "left". If `None`, the 240 | processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`]. 
241 | reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`): 242 | Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either: 243 | 244 | - A single processing class: Used when `reward_funcs` contains only one reward function. 245 | - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`. 246 | If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is 247 | `None`, the tokenizer for the model is automatically loaded using [`~transformers.AutoTokenizer.from_pretrained`]. 248 | For elements in `reward_funcs` that are custom reward functions (not [`~transformers.PreTrainedModel`]), 249 | the corresponding entries in `reward_processing_classes` are ignored. 250 | callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`): 251 | List of callbacks to customize the training loop. Will add those to the list of default callbacks 252 | detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback). 253 | 254 | If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`] 255 | method. 256 | optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`): 257 | A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your 258 | model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`. 259 | peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`): 260 | PEFT configuration used to wrap the model. If `None`, the model is not wrapped. 
261 | """ 262 | 263 | _tag_names = ["trl", "grpo"] 264 | 265 | def __init__( 266 | self, 267 | model: Union[str, PreTrainedModel], 268 | reward_funcs: Union[RewardFunc, list[RewardFunc]], 269 | args: Optional[GRPOConfig] = None, 270 | train_dataset: Optional[Union[Dataset, IterableDataset]] = None, 271 | eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None, 272 | processing_class: Optional[PreTrainedTokenizerBase] = None, 273 | reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None, 274 | callbacks: Optional[list[TrainerCallback]] = None, 275 | optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None), 276 | peft_config: Optional["PeftConfig"] = None, 277 | max_pixels: Optional[int] = 12845056, 278 | min_pixels: Optional[int] = 3136, 279 | attn_implementation: str = "flash_attention_2", 280 | torch_dtype: Optional[torch.dtype] = None, 281 | ): 282 | # Args 283 | if args is None: 284 | model_name = model if isinstance(model, str) else model.config._name_or_path 285 | model_name = model_name.split("/")[-1] 286 | args = GRPOConfig(f"{model_name}-GRPO") 287 | 288 | # Models 289 | # Trained model 290 | model_init_kwargs = args.model_init_kwargs or {} 291 | if isinstance(model, str): 292 | model_id = model 293 | torch_dtype = model_init_kwargs.get("torch_dtype") 294 | if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None: 295 | pass # torch_dtype is already a torch.dtype or "auto" or None 296 | elif isinstance(torch_dtype, str): # it's a str, but not "auto" 297 | torch_dtype = getattr(torch, torch_dtype) 298 | model_init_kwargs["torch_dtype"] = torch_dtype 299 | else: 300 | raise ValueError( 301 | "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing " 302 | f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}." 303 | ) 304 | # Disable caching if gradient checkpointing is enabled (not supported) 305 | model_init_kwargs["use_cache"] = ( 306 | False if args.gradient_checkpointing else model_init_kwargs.get("use_cache") 307 | ) 308 | model_init_kwargs["attn_implementation"] = attn_implementation 309 | if "Qwen2-VL" in model_id: 310 | model = Qwen2VLForConditionalGeneration.from_pretrained(model, **model_init_kwargs) 311 | elif "Qwen2.5-VL" in model_id: 312 | model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model, **model_init_kwargs) 313 | else: 314 | model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs) 315 | else: 316 | model_id = model.config._name_or_path 317 | if args.model_init_kwargs is not None: 318 | raise ValueError( 319 | "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. " 320 | "This argument can only be used when the `model` argument is a string." 321 | ) 322 | 323 | if peft_config is not None: 324 | if not is_peft_available(): 325 | raise ImportError("PEFT is required to use `peft_config`. 
Run `pip install peft`.") 326 | model = get_peft_model(model, peft_config) 327 | 328 | # Enable gradient checkpointing if requested 329 | if args.gradient_checkpointing: 330 | model = self._enable_gradient_checkpointing(model, args) 331 | 332 | # Reference model 333 | self.beta = args.beta 334 | if self.beta == 0.0: 335 | # If beta is 0.0, the reference model is not needed 336 | self.ref_model = None 337 | if is_deepspeed_zero3_enabled(): 338 | if "Qwen2-VL" in model_id: 339 | self.ref_model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs) 340 | elif "Qwen2.5-VL" in model_id: 341 | self.ref_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs) 342 | else: 343 | self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs) 344 | elif is_peft_model(model): 345 | # If PEFT is used, the reference model is not needed since the adapter can be disabled 346 | # to revert to the initial model. 347 | self.ref_model = None 348 | else: 349 | # If PEFT configuration is not provided, create a reference model based on the initial model. 350 | self.ref_model = create_reference_model(model) 351 | 352 | # Processing class 353 | if processing_class is None: 354 | if "Qwen2-VL" in model_id or "Qwen2.5-VL" in model_id: 355 | processing_class = AutoProcessor.from_pretrained(model_id) 356 | pad_token_id = processing_class.tokenizer.pad_token_id 357 | processing_class.pad_token_id = pad_token_id 358 | processing_class.eos_token_id = processing_class.tokenizer.eos_token_id 359 | if "Qwen" in model_id or "Qwen2.5-VL" in model_id: 360 | processing_class.image_processor.max_pixels = max_pixels 361 | processing_class.image_processor.min_pixels = min_pixels 362 | else: 363 | processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left") 364 | pad_token_id = processing_class.pad_token_id 365 | 366 | # Reward functions 367 | if not isinstance(reward_funcs, list): 368 | reward_funcs = [reward_funcs] 369 | for i, reward_func in enumerate(reward_funcs): 370 | if isinstance(reward_func, str): 371 | reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained( 372 | reward_func, num_labels=1, **model_init_kwargs 373 | ) 374 | self.reward_funcs = reward_funcs 375 | 376 | # Reward weights 377 | if args.reward_weights is not None: 378 | if len(args.reward_weights) != len(reward_funcs): 379 | raise ValueError( 380 | f"Number of reward weights ({len(args.reward_weights)}) must match number of reward " 381 | f"functions ({len(reward_funcs)})" 382 | ) 383 | self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32) 384 | else: 385 | self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32) 386 | 387 | # Reward processing class 388 | if reward_processing_classes is None: 389 | reward_processing_classes = [None] * len(reward_funcs) 390 | elif not isinstance(reward_processing_classes, list): 391 | reward_processing_classes = [reward_processing_classes] 392 | else: 393 | if len(reward_processing_classes) != len(reward_funcs): 394 | raise ValueError("The number of reward processing classes must match the number of reward functions.") 395 | 396 | for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)): 397 | if isinstance(reward_func, PreTrainedModel): 398 | if reward_processing_class is None: 399 | reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path) 400 | if 
reward_processing_class.pad_token_id is None: 401 | reward_processing_class.pad_token = reward_processing_class.eos_token 402 | # The reward model computes the reward for the latest non-padded token in the input sequence. 403 | # So it's important to set the pad token ID to the padding token ID of the processing class. 404 | reward_func.config.pad_token_id = reward_processing_class.pad_token_id 405 | reward_processing_classes[i] = reward_processing_class 406 | self.reward_processing_classes = reward_processing_classes 407 | 408 | # Data collator 409 | def data_collator(features): # No data collation is needed in GRPO 410 | return features 411 | 412 | # Training arguments 413 | self.max_prompt_length = args.max_prompt_length 414 | self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper 415 | self.num_generations = args.num_generations # = G in the GRPO paper 416 | self.use_vllm = args.use_vllm 417 | 418 | # Multi-step 419 | self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper 420 | self.epsilon = args.epsilon 421 | # Tracks the number of iterations (forward + backward passes), including those within a gradient accumulation cycle. 422 | self._step = 0 423 | # Buffer the batch to reuse generated outputs across multiple updates. For more details, see 424 | # `_get_train_sampler` and `_prepare_inputs`. 425 | self._buffered_inputs = [None] * args.gradient_accumulation_steps 426 | 427 | # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the 428 | # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the 429 | # "input_ids" key. Instead, the available keys is "prompt". As a result, the trainer issues the warning: 430 | # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To 431 | # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True. 432 | # This acts as a flag to indicate that the warning has already been issued. 433 | model.warnings_issued["estimate_tokens"] = True 434 | 435 | # Initialize the metrics 436 | self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)} 437 | self.log_completions = args.log_completions 438 | 439 | super().__init__( 440 | model=model, 441 | args=args, 442 | data_collator=data_collator, 443 | train_dataset=train_dataset, 444 | eval_dataset=eval_dataset, 445 | processing_class=processing_class, 446 | callbacks=callbacks, 447 | optimizers=optimizers, 448 | ) 449 | 450 | # Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations 451 | num_processes = self.accelerator.num_processes 452 | global_batch_size = args.per_device_train_batch_size * num_processes 453 | possible_values = [n_gen for n_gen in range(1, global_batch_size + 1) if (global_batch_size) % n_gen == 0] 454 | 455 | if self.num_generations not in possible_values: 456 | raise ValueError( 457 | f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly " 458 | f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train " 459 | f"batch size, the valid values for the number of generations are: {possible_values}." 
460 | ) 461 | if self.args.eval_strategy != "no": 462 | global_batch_size = args.per_device_eval_batch_size * num_processes 463 | possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0] 464 | if self.num_generations not in possible_values: 465 | raise ValueError( 466 | f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly " 467 | f"divisible by the number of generations per prompt ({self.num_generations}). Given the current " 468 | f"eval batch size, the valid values for the number of generations are: {possible_values}." 469 | ) 470 | 471 | # Ensure each process receives a unique seed to prevent duplicate completions when generating with 472 | # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but 473 | # it's safer to set it in all cases. 474 | set_seed(args.seed, device_specific=True) 475 | 476 | self.generation_config = GenerationConfig( 477 | max_new_tokens=self.max_completion_length, 478 | do_sample=True, 479 | temperature=args.temperature, 480 | num_return_sequences=self.num_generations, 481 | pad_token_id=pad_token_id, 482 | ) 483 | 484 | # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the 485 | # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set 486 | # self.model_accepts_loss_kwargs to False to enable scaling. 487 | self.model_accepts_loss_kwargs = False 488 | 489 | # Add tags to the model 490 | self.model.add_model_tags(self._tag_names) 491 | 492 | if self.ref_model is not None: 493 | if self.is_deepspeed_enabled: 494 | self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator) 495 | else: 496 | self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True) 497 | 498 | if args.sync_ref_model: 499 | self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator)) 500 | 501 | for i, reward_func in enumerate(self.reward_funcs): 502 | if isinstance(reward_func, PreTrainedModel): 503 | self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True) 504 | 505 | def _set_signature_columns_if_needed(self): 506 | # If `self.args.remove_unused_columns` is True, non-signature columns are removed. 507 | # By default, this method sets `self._signature_columns` to the model's expected inputs. 508 | # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work. 509 | # Instead, we set them to the columns expected by the `training_step` method, hence the override. 510 | if self._signature_columns is None: 511 | self._signature_columns = ["prompt"] 512 | 513 | def _get_train_sampler(self) -> Sampler: 514 | # Returns a sampler that 515 | # 1. ensures each prompt is repeated across multiple processes. This guarantees that identical prompts are 516 | # distributed to different GPUs, allowing rewards to be computed and normalized correctly within each prompt 517 | # group. Using the same seed across processes ensures consistent prompt assignment, preventing discrepancies 518 | # in group formation. 519 | # 2. repeats the batch multiple times to allow reusing generations across multiple updates. Refer to 520 | # _prepare_inputs to see how the generations are stored and reused. 521 | 522 | # In the following figure, the values are the prompt indices. 
The first row shows the first sampled batch, the 523 | # second row shows the second sampled batch, and so on. 524 | # 525 | # | GPU 0 | GPU 1 | GPU 2 | 526 | # 527 | # global_step step <───────> num_generations=3 528 | # <───────────> per_device_train_batch_size=4 529 | # ▲ 0 0 0 0 0 1 1 1 2 2 2 3 3 3 │ 530 | # grad_accum=3 │ 0 1 4 4 4 5 5 5 6 6 6 7 7 7 │ Generate completions for each prompt 531 | # ▼ 0 2 8 8 8 9 9 9 10 10 10 11 11 11 │ 532 | # 533 | # 1 3 0 0 0 1 1 1 2 2 2 3 3 3 │ The sampled prompts are the same as in the first iteration 534 | # 1 4 4 4 4 5 5 5 6 6 6 7 7 7 │ Reuse the completions (here, once, because num_iterations=2) 535 | # 1 5 8 8 8 9 9 9 10 10 10 11 11 11 │ 536 | # 537 | # 2 6 12 12 12 13 13 13 14 14 14 15 15 15 538 | # 2 7 16 16 16 17 17 17 18 18 18 19 19 19 539 | # 2 8 20 20 20 21 21 21 22 22 22 23 23 23 540 | # ... 541 | effective_batch_size = ( 542 | self.args.per_device_train_batch_size 543 | * self.accelerator.num_processes 544 | * self.args.gradient_accumulation_steps 545 | ) 546 | return RepeatRandomSampler( 547 | data_source=self.train_dataset, 548 | mini_repeat_count=self.num_generations, 549 | batch_size=effective_batch_size // self.num_generations, 550 | repeat_count=self.num_iterations, 551 | seed=self.args.seed, 552 | ) 553 | 554 | def _get_eval_sampler(self, eval_dataset) -> Sampler: 555 | # See _get_train_sampler for an explanation of the sampler. 556 | return RepeatRandomSampler( 557 | data_source=eval_dataset, 558 | mini_repeat_count=self.num_generations, 559 | seed=self.args.seed, 560 | ) 561 | 562 | def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: GRPOConfig) -> PreTrainedModel: 563 | """Enables gradient checkpointing for the model.""" 564 | # Ensure use_cache is disabled 565 | model.config.use_cache = False 566 | 567 | # Enable gradient checkpointing on the base model for PEFT 568 | if is_peft_model(model): 569 | model.base_model.gradient_checkpointing_enable() 570 | # Enable gradient checkpointing for non-PEFT models 571 | else: 572 | model.gradient_checkpointing_enable() 573 | 574 | gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {} 575 | use_reentrant = ( 576 | "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"] 577 | ) 578 | 579 | if use_reentrant: 580 | model.enable_input_require_grads() 581 | 582 | return model 583 | 584 | # Get the per-token log probabilities for the completions for the model and the reference model 585 | @profiling_decorator 586 | def _get_per_token_logps(self, model, input_ids, attention_mask, pixel_values, image_grid_thw, logits_to_keep): 587 | # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded 588 | logits = model( 589 | input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values, image_grid_thw=image_grid_thw 590 | ).logits 591 | logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred 592 | 593 | input_ids = input_ids[:, -logits_to_keep:] 594 | # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves. 
595 | # See https://github.com/huggingface/trl/issues/2770 596 | logits = logits[:, -logits_to_keep:] 597 | return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens 598 | 599 | @profiling_decorator 600 | def _move_model_to_vllm(self): 601 | with unwrap_model_for_generation( 602 | self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation 603 | ) as unwrapped_model: 604 | if is_compiled_module(unwrapped_model): 605 | unwrapped_model = unwrapped_model._orig_mod 606 | if is_peft_model(unwrapped_model): 607 | unwrapped_model.merge_adapter() 608 | state_dict = unwrapped_model.state_dict() 609 | # Remove base_model and base_layer prefixes 610 | state_dict = { 611 | k.removeprefix("base_model.model.").replace(".base_layer", ""): v for k, v in state_dict.items() 612 | } 613 | # Remove values with adapter prefix (example: "_lora") 614 | state_dict = {k: v for k, v in state_dict.items() if unwrapped_model.prefix not in k} 615 | # When module to save, remove its prefix and discard the original module 616 | state_dict = { 617 | k.replace("modules_to_save.default.", ""): v 618 | for k, v in state_dict.items() 619 | if "original_module" not in k 620 | } 621 | else: 622 | state_dict = unwrapped_model.state_dict() 623 | if self.accelerator.is_main_process: 624 | llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model 625 | llm_model.load_weights(state_dict.items()) 626 | # Unmerge the adapter to restore the model to its original state. 627 | # This must be done after loading weights to ensure they correspond to the merged state. 628 | if is_peft_model(unwrapped_model): 629 | unwrapped_model.unmerge_adapter() 630 | 631 | @profiling_decorator 632 | def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]: 633 | mode = "eval" if self.control.should_evaluate else "train" 634 | if mode == "train": 635 | if self.state.global_step % self.num_iterations == 0: 636 | inputs = self._generate_and_score_completions(inputs) 637 | self._buffered_inputs[self._step % self.args.gradient_accumulation_steps] = inputs 638 | else: 639 | inputs = self._buffered_inputs[self._step % self.args.gradient_accumulation_steps] 640 | self._step += 1 641 | else: 642 | # In evaluation, we don't reuse completions across multiple updates, so we don't need to buffer inputs. 
643 | inputs = self._generate_and_score_completions(inputs) 644 | return inputs 645 | 646 | def _generate_and_score_completions( 647 | self, inputs: dict[str, Union[torch.Tensor, Any]] 648 | ) -> dict[str, Union[torch.Tensor, Any]]: 649 | device = self.accelerator.device 650 | prompts = [x["prompt"] for x in inputs] 651 | prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs] 652 | 653 | # Handle both pre-loaded images and image paths 654 | images = [] 655 | for x in inputs: 656 | if "image" in x: 657 | img = x["image"] 658 | else: 659 | img = PIL.Image.open(x["image_path"]) 660 | 661 | # Ensure minimum dimensions of 28 pixels 662 | w, h = img.size 663 | if w < 28 or h < 28: 664 | # Calculate new dimensions maintaining aspect ratio 665 | if w < h: 666 | new_w = 28 667 | new_h = int(h * (28 / w)) 668 | else: 669 | new_h = 28 670 | new_w = int(w * (28 / h)) 671 | img = img.resize((new_w, new_h), PIL.Image.Resampling.LANCZOS) 672 | elif w > 512 or h > 512: 673 | # Calculate new dimensions maintaining aspect ratio for large images 674 | if w > h: 675 | new_w = 512 676 | new_h = int(h * (512 / w)) 677 | else: 678 | new_h = 512 679 | new_w = int(w * (512 / h)) 680 | else: 681 | # Image is within acceptable dimensions, no resize needed 682 | new_w, new_h = w, h 683 | 684 | # Only resize if dimensions changed 685 | if new_w != w or new_h != h: 686 | img = img.resize((new_w, new_h), PIL.Image.Resampling.LANCZOS) 687 | 688 | images.append(img) 689 | 690 | prompt_inputs = self.processing_class( 691 | text=prompts_text, 692 | images=images, 693 | return_tensors="pt", 694 | padding=True, 695 | padding_side="left", 696 | add_special_tokens=False, 697 | ) 698 | prompt_inputs = super()._prepare_inputs(prompt_inputs) 699 | 700 | prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"] 701 | pixel_values = prompt_inputs["pixel_values"] 702 | image_grid_thw = prompt_inputs["image_grid_thw"] 703 | 704 | if self.max_prompt_length is not None: 705 | prompt_ids = prompt_ids[:, -self.max_prompt_length :] 706 | prompt_mask = prompt_mask[:, -self.max_prompt_length :] 707 | 708 | # Generate completions using either vLLM or regular generation 709 | if self.args.use_vllm: 710 | # First, have main process load weights if needed 711 | if self.state.global_step != self._last_loaded_step: 712 | self._move_model_to_vllm() 713 | self._last_loaded_step = self.state.global_step 714 | 715 | # Generate completions using vLLM: gather all prompts and use them in a single call in the main process 716 | all_prompts_text = gather_object(prompts_text) 717 | if self.accelerator.is_main_process: 718 | # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate 719 | # num_generations outputs for each one. This is faster than generating outputs for each duplicate 720 | # prompt individually. 721 | ordered_set_of_prompts = list(dict.fromkeys(all_prompts_text)) 722 | with profiling_context(self, "vLLM.generate"): 723 | all_outputs = self.llm.generate( 724 | ordered_set_of_prompts, sampling_params=self.sampling_params, use_tqdm=False 725 | ) 726 | completion_ids = [] 727 | for outputs in all_outputs: 728 | for output in outputs.outputs: 729 | completion_ids.append(output.token_ids) 730 | else: 731 | completion_ids = [None] * len(all_prompts_text) 732 | # Broadcast the completions from the main process to all processes, ensuring each process receives its 733 | # corresponding slice. 
734 | completion_ids = broadcast_object_list(completion_ids, from_process=0) 735 | process_slice = slice( 736 | self.accelerator.process_index * len(prompts), 737 | (self.accelerator.process_index + 1) * len(prompts), 738 | ) 739 | completion_ids = completion_ids[process_slice] 740 | 741 | # Pad the completions, and concatenate them with the prompts 742 | completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids] 743 | completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id) 744 | prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1) 745 | else: 746 | # Regular generation path 747 | with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model: 748 | prompt_completion_ids = unwrapped_model.generate( 749 | **prompt_inputs, generation_config=self.generation_config 750 | ) 751 | 752 | # Compute prompt length and extract completion ids 753 | prompt_length = prompt_ids.size(1) 754 | prompt_ids = prompt_completion_ids[:, :prompt_length] 755 | completion_ids = prompt_completion_ids[:, prompt_length:] 756 | prompt_mask = prompt_mask.repeat_interleave(self.num_generations, dim=0) 757 | 758 | # Mask everything after the first EOS token 759 | is_eos = completion_ids == self.processing_class.eos_token_id 760 | eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device) 761 | eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)] 762 | sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1) 763 | completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int() 764 | 765 | # Concatenate prompt_mask with completion_mask for logit computation 766 | attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C) 767 | # Repeat image inputs to match batch size after generation 768 | if pixel_values is not None: 769 | pixel_values = pixel_values.repeat_interleave(self.num_generations, dim=0) 770 | if image_grid_thw is not None: 771 | image_grid_thw = image_grid_thw.repeat_interleave(self.num_generations, dim=0) 772 | 773 | logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens 774 | 775 | with torch.inference_mode(): 776 | # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip it's 777 | # computation here, and use per_token_logps.detach() instead. 
778 | if self.num_iterations > 1: 779 | old_per_token_logps = self._get_per_token_logps( 780 | self.model, prompt_completion_ids, attention_mask, pixel_values, image_grid_thw, logits_to_keep 781 | ) 782 | else: 783 | old_per_token_logps = None 784 | 785 | if self.beta == 0.0: 786 | ref_per_token_logps = None 787 | elif self.ref_model is not None: 788 | ref_per_token_logps = self._get_per_token_logps( 789 | self.ref_model, prompt_completion_ids, attention_mask, pixel_values, image_grid_thw, logits_to_keep 790 | ) 791 | else: 792 | with self.accelerator.unwrap_model(self.model).disable_adapter(): 793 | ref_per_token_logps = self._get_per_token_logps( 794 | self.model, prompt_completion_ids, attention_mask, pixel_values, image_grid_thw, logits_to_keep 795 | ) 796 | 797 | # Decode the generated completions 798 | completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True) 799 | if is_conversational(inputs[0]): 800 | completions = [] 801 | for prompt, completion in zip(prompts, completions_text): 802 | bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else "" 803 | completions.append([{"role": "assistant", "content": bootstrap + completion}]) 804 | else: 805 | completions = completions_text 806 | 807 | rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device) 808 | for i, (reward_func, reward_processing_class) in enumerate( 809 | zip(self.reward_funcs, self.reward_processing_classes) 810 | ): 811 | if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models 812 | reward_func_name = f"reward {reward_func.config._name_or_path.split('/')[-1]}" 813 | else: 814 | reward_func_name = reward_func.__name__ 815 | with profiling_context(self, reward_func_name): 816 | if isinstance( 817 | reward_func, nn.Module 818 | ): # Module instead of PretrainedModel for compat with compiled models 819 | if is_conversational(inputs[0]): 820 | messages = [{"messages": p + c} for p, c in zip(prompts, completions)] 821 | texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages] 822 | else: 823 | texts = [p + c for p, c in zip(prompts, completions)] 824 | reward_inputs = reward_processing_class( 825 | texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False 826 | ) 827 | reward_inputs = super()._prepare_inputs(reward_inputs) 828 | with torch.inference_mode(): 829 | rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,) 830 | else: 831 | # Repeat all input columns (but "prompt" and "completion") to match the number of generations 832 | keys = [key for key in inputs[0] if key not in ["prompt", "completion"]] 833 | reward_kwargs = {key: [example[key] for example in inputs] for key in keys} 834 | output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs) 835 | rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device) 836 | 837 | # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the 838 | # completions may be distributed across processes 839 | rewards_per_func = gather(rewards_per_func) 840 | 841 | # Apply weights to each reward function's output and sum 842 | rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1) 843 | 844 | # Compute grouped-wise rewards 845 | mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1) 846 | std_grouped_rewards = 
rewards.view(-1, self.num_generations).std(dim=1) 847 | 848 | # Normalize the rewards to compute the advantages 849 | mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0) 850 | std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0) 851 | advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4) 852 | 853 | # Slice to keep only the local part of the data 854 | process_slice = slice( 855 | self.accelerator.process_index * len(prompts), 856 | (self.accelerator.process_index + 1) * len(prompts), 857 | ) 858 | advantages = advantages[process_slice] 859 | 860 | # Log the metrics 861 | mode = "eval" if self.control.should_evaluate else "train" 862 | 863 | completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item() 864 | self._metrics[mode]["completion_length"].append(completion_length) 865 | 866 | reward_per_func = rewards_per_func.mean(0) 867 | for i, reward_func in enumerate(self.reward_funcs): 868 | if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models 869 | reward_func_name = reward_func.config._name_or_path.split("/")[-1] 870 | else: 871 | reward_func_name = reward_func.__name__ 872 | self._metrics[mode][f"rewards/{reward_func_name}"].append(reward_per_func[i].item()) 873 | 874 | self._metrics[mode]["reward"].append(rewards.mean().item()) 875 | self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item()) 876 | 877 | if self.log_completions and self.state.global_step % self.args.logging_steps == 0: 878 | prompts_to_log = gather_object(prompts_text) 879 | completions_to_log = gather_object(completions_text) 880 | rewards_to_log = rewards.tolist() 881 | 882 | if self.accelerator.is_main_process: 883 | if is_rich_available(): 884 | print_prompt_completions_sample( 885 | prompts_to_log, 886 | completions_to_log, 887 | rewards_to_log, 888 | self.state.global_step, 889 | ) 890 | if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None: 891 | import pandas as pd 892 | 893 | # For logging 894 | table = { 895 | "step": [str(self.state.global_step)] * len(rewards), 896 | "prompt": prompts_to_log, 897 | "completion": completions_to_log, 898 | "reward": rewards.tolist(), 899 | } 900 | df = pd.DataFrame(table) 901 | wandb.log({"completions": wandb.Table(dataframe=df)}) 902 | 903 | return { 904 | "prompt_ids": prompt_ids, 905 | "prompt_mask": prompt_mask, 906 | "pixel_values": pixel_values, 907 | "image_grid_thw": image_grid_thw, 908 | "completion_ids": completion_ids, 909 | "completion_mask": completion_mask, 910 | "old_per_token_logps": old_per_token_logps, 911 | "ref_per_token_logps": ref_per_token_logps, 912 | "advantages": advantages, 913 | } 914 | 915 | @profiling_decorator 916 | def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): 917 | if return_outputs: 918 | raise ValueError("The GRPOTrainer does not support returning outputs") 919 | # Compute the per-token log probabilities for the model 920 | 921 | prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"] 922 | completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"] 923 | pixel_values, image_grid_thw = inputs["pixel_values"], inputs["image_grid_thw"] 924 | input_ids = torch.cat([prompt_ids, completion_ids], dim=1) 925 | attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) 926 | logits_to_keep = completion_ids.size(1) # we only need to 
compute the logits for the completion tokens 927 | 928 | per_token_logps = self._get_per_token_logps( 929 | model, input_ids, attention_mask, pixel_values, image_grid_thw, logits_to_keep 930 | ) 931 | 932 | # Compute the KL divergence between the model and the reference model 933 | if self.beta != 0.0: 934 | ref_per_token_logps = inputs["ref_per_token_logps"] 935 | per_token_kl = ( 936 | torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1 937 | ) 938 | 939 | # Compute the loss 940 | advantages = inputs["advantages"] 941 | # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip it's computation (see 942 | # _generate_and_score_completions) and use per_token_logps.detach() instead. 943 | old_per_token_logps = inputs["old_per_token_logps"] if self.num_iterations > 1 else per_token_logps.detach() 944 | coef_1 = torch.exp(per_token_logps - old_per_token_logps) 945 | coef_2 = torch.clamp(coef_1, 1 - self.epsilon, 1 + self.epsilon) 946 | per_token_loss1 = coef_1 * advantages.unsqueeze(1) 947 | per_token_loss2 = coef_2 * advantages.unsqueeze(1) 948 | per_token_loss = -torch.min(per_token_loss1, per_token_loss2) 949 | if self.beta != 0.0: 950 | per_token_loss = per_token_loss + self.beta * per_token_kl 951 | loss = (per_token_loss * completion_mask).sum() / completion_mask.sum() 952 | 953 | # Log the metrics 954 | mode = "eval" if self.control.should_evaluate else "train" 955 | 956 | if self.beta != 0.0: 957 | mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean() 958 | self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item()) 959 | 960 | is_clipped = (per_token_loss1 < per_token_loss2).float() 961 | clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum() 962 | self._metrics[mode]["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item()) 963 | return loss 964 | 965 | def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None): 966 | inputs = self._prepare_inputs(inputs) 967 | with torch.no_grad(): 968 | with self.compute_loss_context_manager(): 969 | loss = self.compute_loss(model, inputs) 970 | loss = loss.mean().detach() 971 | return loss, None, None 972 | 973 | def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None: 974 | mode = "eval" if self.control.should_evaluate else "train" 975 | metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics 976 | 977 | # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs` 978 | # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format. 979 | if mode == "eval": 980 | metrics = {f"eval_{key}": val for key, val in metrics.items()} 981 | 982 | logs = {**logs, **metrics} 983 | if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"): 984 | super().log(logs, start_time) 985 | else: # transformers<=4.46 986 | super().log(logs) 987 | self._metrics[mode].clear() 988 | 989 | def create_model_card( 990 | self, 991 | model_name: Optional[str] = None, 992 | dataset_name: Optional[str] = None, 993 | tags: Union[str, list[str], None] = None, 994 | ): 995 | """ 996 | Creates a draft of a model card using the information available to the `Trainer`. 997 | 998 | Args: 999 | model_name (`str` or `None`, *optional*, defaults to `None`): 1000 | Name of the model. 
1001 | dataset_name (`str` or `None`, *optional*, defaults to `None`): 1002 | Name of the dataset used for training. 1003 | tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`): 1004 | Tags to be associated with the model card. 1005 | """ 1006 | if not self.is_world_process_zero(): 1007 | return 1008 | 1009 | if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path): 1010 | base_model = self.model.config._name_or_path 1011 | else: 1012 | base_model = None 1013 | 1014 | tags = tags or [] 1015 | if isinstance(tags, str): 1016 | tags = [tags] 1017 | 1018 | if hasattr(self.model.config, "unsloth_version"): 1019 | tags.append("unsloth") 1020 | 1021 | citation = textwrap.dedent( 1022 | """\ 1023 | @article{zhihong2024deepseekmath, 1024 | title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}}, 1025 | author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo}, 1026 | year = 2024, 1027 | eprint = {arXiv:2402.03300}, 1028 | } 1029 | """ 1030 | ) 1031 | 1032 | model_card = generate_model_card( 1033 | base_model=base_model, 1034 | model_name=model_name, 1035 | hub_model_id=self.hub_model_id, 1036 | dataset_name=dataset_name, 1037 | tags=tags, 1038 | wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None, 1039 | comet_url=get_comet_experiment_url(), 1040 | trainer_name="GRPO", 1041 | trainer_citation=citation, 1042 | paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models", 1043 | paper_id="2402.03300", 1044 | ) 1045 | 1046 | model_card.save(os.path.join(self.args.output_dir, "README.md")) 1047 | -------------------------------------------------------------------------------- /run_grpo_vlm.sh: -------------------------------------------------------------------------------- 1 | export WANDB_PROJECT="grpo-qwen-2-2B" 2 | export WANDB_ENTITY="" 3 | export WANDB_API_KEY="" 4 | export WANDB_NAME="grpo-qwen-2-2B-v-8gpus-one_reward-synthetic-data-check" 5 | 6 | 7 | RUN_NAME="Qwen2-VL-2B-GRPO" 8 | export LOG_PATH="./debug_log_$RUN_NAME.txt" 9 | 10 | torchrun --nproc_per_node="8" \ 11 | --nnodes="1" \ 12 | --node_rank="0" \ 13 | --master_addr="127.0.0.1" \ 14 | --master_port="12346" \ 15 | run_r1_grpo_vlm.py \ 16 | --deepspeed configs/zero3.json \ 17 | --config configs/grpo-qwen-2.5-v.yaml \ 18 | --json_data_path "./data/textvqa_cot_train_1_bbox_0.json" \ 19 | --report_to wandb \ 20 | --gradient_checkpointing false \ 21 | --attn_implementation flash_attention_2 \ 22 | --num_train_epochs 1 \ 23 | --save_only_model true 24 | -------------------------------------------------------------------------------- /run_r1_grpo_vlm.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from dataclasses import dataclass 4 | from datetime import datetime 5 | import random 6 | import re 7 | import torch 8 | import yaml 9 | 10 | from transformers.trainer_utils import get_last_checkpoint 11 | 12 | from grpo_config import GRPOConfig 13 | 14 | import torch 15 | from transformers import ( 16 | Qwen2_5_VLForConditionalGeneration, 17 | Qwen2VLForConditionalGeneration, 18 | AutoTokenizer, 19 | AutoProcessor, 20 | ) 21 | 22 | import datasets 23 | from datasets import load_dataset 24 | from torch.utils.data import Dataset 25 | 26 | import sys 27 | import math 28 | from typing import Optional, Tuple 29 | 30 | from 
transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( 31 | Qwen2_5_VLVisionFlashAttention2, 32 | apply_rotary_pos_emb_flashatt, 33 | flash_attn_varlen_func, 34 | ) 35 | import torch 36 | from typing import Tuple 37 | 38 | 39 | # FlashAttention fix from https://github.com/om-ai-lab/VLM-R1/blob/main/src/open-r1-multimodal/src/open_r1/grpo_rec.py 40 | def custom_forward( 41 | self, 42 | hidden_states: torch.Tensor, 43 | cu_seqlens: torch.Tensor, 44 | rotary_pos_emb: Optional[torch.Tensor] = None, 45 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, 46 | ) -> torch.Tensor: 47 | seq_length = hidden_states.shape[0] 48 | q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0) 49 | if position_embeddings is None: 50 | logger.warning_once( 51 | "The attention layers in this model are transitioning from computing the RoPE embeddings internally " 52 | "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed " 53 | "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be " 54 | "removed and `position_embeddings` will be mandatory." 55 | ) 56 | emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1) 57 | cos = emb.cos().float() 58 | sin = emb.sin().float() 59 | else: 60 | cos, sin = position_embeddings 61 | # Add this 62 | cos = cos.to(torch.float) 63 | sin = sin.to(torch.float) 64 | q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin) 65 | q = q.squeeze(0) 66 | k = k.squeeze(0) 67 | 68 | max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item() 69 | attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape( 70 | seq_length, -1 71 | ) 72 | attn_output = self.proj(attn_output) 73 | return attn_output 74 | 75 | 76 | Qwen2_5_VLVisionFlashAttention2.forward = custom_forward 77 | 78 | 79 | from trl import ModelConfig, ScriptArguments, TrlParser, get_peft_config 80 | 81 | from grpo_qwen2vl import Qwen2VLGRPOTrainer 82 | from PIL import Image 83 | 84 | try: 85 | from math_verify import parse, verify 86 | except ImportError: 87 | 88 | def parse(x): 89 | return x 90 | 91 | def verify(x, y): 92 | return float(x == y) 93 | 94 | 95 | os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" 96 | 97 | # Optional: wandb environment 98 | os.environ["WANDB_PROJECT"] = "grpo-qwen-2-2B" 99 | os.environ["WANDB_ENTITY"] = "" 100 | os.environ["WANDB_API_KEY"] = ""  # set the key via the environment; never commit credentials to source control 101 | os.environ["WANDB_NAME"] = "grpo-qwen-2-2B-v-8gpus-one_reward-synthetic-data-check" 102 | 103 | 104 | @dataclass 105 | class ScriptArguments: 106 | dataset_id_or_path: str = "Jiayi-Pan/Countdown-Tasks-3to4" 107 | dataset_splits: str = "train" 108 | tokenizer_name_or_path: str = None 109 | max_pixels: int = 12845056 110 | min_pixels: int = 3136 111 | image_root: str = "" 112 | config_path: str = "config.yaml" 113 | json_data_path: str = "" 114 | 115 | 116 | logging.basicConfig(level=logging.INFO) 117 | logger = logging.getLogger(__name__) 118 | logger.setLevel(logging.INFO) 119 | handler = logging.StreamHandler() 120 | handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")) 121 | logger.addHandler(handler) 122 | 123 | 124 | def relaxed_bbox_iou(pred_bbox, gt_bbox, threshold=0.5): 125 | x1 = max(pred_bbox[0], gt_bbox[0]) 126 | y1 = max(pred_bbox[1], gt_bbox[1]) 127 | x2 = min(pred_bbox[2], gt_bbox[2]) 128 | y2 = min(pred_bbox[3], gt_bbox[3]) 129 | 130 | 
intersection = max(0, x2 - x1) * max(0, y2 - y1) 131 | pred_area = (pred_bbox[2] - pred_bbox[0]) * (pred_bbox[3] - pred_bbox[1]) 132 | gt_area = (gt_bbox[2] - gt_bbox[0]) * (gt_bbox[3] - gt_bbox[1]) 133 | union = pred_area + gt_area - intersection 134 | iou = intersection / union if union > 0 else 0 135 | 136 | return iou >= threshold 137 | 138 | 139 | def relaxed_correctness(prediction: str, target: str, max_relative_change: float = 0.05) -> bool: 140 | def extract_first_number(text: str): 141 | match = re.search(r"[-]?\d[\d,]*(\.\d+)?%?", text) 142 | return match.group(0) if match else text 143 | 144 | def to_float(text: str): 145 | text = text.strip().lower() 146 | text = text.replace(",", "") 147 | if text.endswith("%"): 148 | try: 149 | val = float(text.rstrip("%")) 150 | return val / 100.0 151 | except ValueError: 152 | return None 153 | else: 154 | try: 155 | return float(text) 156 | except ValueError: 157 | return None 158 | 159 | pred_num_str = extract_first_number(prediction) 160 | tgt_num_str = extract_first_number(target) 161 | pred_float = to_float(pred_num_str) 162 | tgt_float = to_float(tgt_num_str) 163 | if pred_float is not None and tgt_float is not None: 164 | if abs(tgt_float) < 1e-12: 165 | return abs(pred_float - tgt_float) < 1e-12 166 | relative_change = abs(pred_float - tgt_float) / abs(tgt_float) 167 | return relative_change <= max_relative_change 168 | return prediction.strip().lower() == target.strip().lower() 169 | 170 | 171 | def grounding_reward(completions, target, **kwargs): 172 | """ 173 | Reward function that checks the predicted bounding box against the ground-truth box using a relaxed IoU threshold. 174 | """ 175 | completions = [c[0]["content"] for c in completions] 176 | rewards = [] 177 | for completion, gt_bbox_list in zip(completions, target): 178 | try: 179 | bbox_match = re.search(r"<answer>.*?\[(.*?)\].*?</answer>", completion, re.DOTALL) 180 | if bbox_match: 181 | pred_bbox = [float(x.strip()) for x in bbox_match.group(1).split(",")] 182 | gt_bbox = [float(x) for x in gt_bbox_list[0]] # first ground-truth box, matching the dataset's "bboxs" format 183 | reward = 1.0 if relaxed_bbox_iou(pred_bbox, gt_bbox) else 0.0 184 | else: 185 | reward = 0.0 186 | except Exception: 187 | reward = 0.0 188 | 189 | rewards.append(reward) 190 | # Log a small fraction 191 | if random.random() < 0.1: 192 | os.makedirs("completion_samples", exist_ok=True) 193 | log_file = os.path.join("completion_samples", "grounding_samples.txt") 194 | with open(log_file, "a") as f: 195 | f.write(f"------------- Grounding reward: {reward} -------------\n") 196 | f.write(f"Content: {completion}\n") 197 | f.write(f"GT bbox: {gt_bbox_list}\n") 198 | 199 | return rewards 200 | 201 | 202 | def format_reward(completions, **kwargs): 203 | """A simpler check for <think>...</think> and <answer>...</answer> formatting (not checking correctness).""" 204 | completions = [c[0]["content"] for c in completions] 205 | completions = ["<think>" + c if not c.startswith("<think>") else c for c in completions] 206 | pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
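# Added illustration (not part of the original source): with the pattern above, a completion such as
#   "<think>The question asks about the sign, which sits in the upper-left corner.</think> <answer>[12, 30, 98, 154]</answer>"
# fully matches and earns a format reward of 1.0, while a completion that is missing either the
# <think></think> pair or the <answer></answer> pair fails the re.fullmatch below and earns 0.0.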
207 | matches = [re.fullmatch(pattern, c, re.DOTALL) for c in completions] 208 | return [1.0 if m else 0.0 for m in matches] 209 | 210 | 211 | def get_checkpoint(training_args): 212 | if os.path.isdir(training_args.output_dir): 213 | return get_last_checkpoint(training_args.output_dir) 214 | return None 215 | 216 | 217 | class CustomDataset: 218 | def __init__(self, list_data_dict=None, script_args=None, processor=None, json_data_path=None): 219 | super(CustomDataset, self).__init__() 220 | self.script_args = script_args 221 | self.processor = processor 222 | self.SYSTEM_PROMPT = ( 223 | "You are a Vision Language Model specialized in visual grounding in [x1, y1, x2, y2]." 224 | ) 225 | self.QUESTION_TEMPLATE = "Provide bounding box for the region of the image relevant to the asked question: {Question}. First output the thinking process in <think> </think> tags and then output the final answer in <answer> [x1, y1, x2, y2] </answer> tags." 226 | 227 | # Load data from JSON file if path is provided 228 | if json_data_path and os.path.exists(json_data_path): 229 | import json 230 | 231 | with open(json_data_path, "r") as f: 232 | data = json.load(f) 233 | self.list_data_dict = data 234 | else: 235 | self.list_data_dict = list_data_dict or [] 236 | 237 | def __len__(self): 238 | return len(self.list_data_dict) 239 | 240 | def __getitem__(self, i): 241 | # Format into conversation 242 | def make_conversation(example): 243 | return { 244 | "prompt": [ 245 | {"role": "system", "content": self.SYSTEM_PROMPT}, 246 | {"role": "user", "content": example["question"]}, 247 | ], 248 | } 249 | 250 | def make_conversation_image(example): 251 | return { 252 | "prompt": [ 253 | { 254 | "role": "user", 255 | "content": [ 256 | {"type": "image"}, 257 | { 258 | "type": "text", 259 | "text": self.QUESTION_TEMPLATE.format(Question=example["question"]) 260 | + ' <answer> {"bbox": [x1, y1, x2, y2]} </answer> tags, ' 261 | + "where x1, y1, x2, y2 are integers representing the coordinates.", 262 | }, 263 | ], 264 | }, 265 | ], 266 | } 267 | 268 | example = self.list_data_dict[i] 269 | image_root = self.script_args.image_root 270 | if "image" in example: 271 | image_path = os.path.join(image_root, example["image"]) 272 | # In case the image is not found 273 | while not os.path.exists(image_path): 274 | print(f"Warning: Image {image_path} not found, randomly selecting another image") 275 | new_index = random.randint(0, len(self.list_data_dict) - 1) 276 | example = self.list_data_dict[new_index] 277 | image_path = os.path.join(image_root, example["image"]) 278 | image = Image.open(image_path).convert("RGB") 279 | 280 | # Resize image if needed to meet min/max size requirements 281 | w, h = image.size 282 | 283 | if w < 28 or h < 28: 284 | # Calculate new dimensions maintaining aspect ratio for small images 285 | if w < h: 286 | new_w = 28 287 | new_h = int(h * (28 / w)) 288 | else: 289 | new_h = 28 290 | new_w = int(w * (28 / h)) 291 | image = image.resize((new_w, new_h), Image.LANCZOS) 292 | elif w > 512 or h > 512: 293 | # Calculate new dimensions maintaining aspect ratio for large images 294 | if w > h: 295 | new_w = 512 296 | new_h = int(h * (512 / w)) 297 | else: 298 | new_h = 512 299 | new_w = int(w * (512 / h)) 300 | image = image.resize((new_w, new_h), Image.LANCZOS) 301 | else: 302 | # Image is within acceptable dimensions, no resize needed 303 | new_w, new_h = w, h 304 | else: 305 | image = None 306 | # print("Image size", image.size) 307 | return { 308 | "image": image, 309 | "question": example["question"], 310 | "target": example["bboxs"], 311 | "prompt": ( 
312 | make_conversation_image(example)["prompt"] 313 | if "image" in example 314 | else make_conversation(example)["prompt"] 315 | ), 316 | } 317 | 318 | 319 | def grpo_function(model_args: ModelConfig, script_args: ScriptArguments, training_args: GRPOConfig): 320 | logger.info(f"Model parameters {model_args}") 321 | logger.info(f"Training/evaluation parameters {training_args}") 322 | 323 | model_name = model_args.model_name_or_path 324 | tokenizer_path = script_args.tokenizer_name_or_path if script_args.tokenizer_name_or_path else model_name 325 | 326 | # Load the appropriate model and processor based on model name 327 | if "Qwen2.5-VL" in model_name: 328 | processor = AutoProcessor.from_pretrained( 329 | tokenizer_path, 330 | trust_remote_code=model_args.trust_remote_code, 331 | revision=model_args.model_revision, 332 | ) 333 | else: # Default to Qwen2-VL 334 | processor = AutoProcessor.from_pretrained( 335 | tokenizer_path, 336 | trust_remote_code=model_args.trust_remote_code, 337 | revision=model_args.model_revision, 338 | ) 339 | 340 | # Create CustomDataset instances 341 | train_dataset = CustomDataset( 342 | script_args=script_args, 343 | processor=processor, 344 | json_data_path=script_args.json_data_path, 345 | ) 346 | print(f"Created datasets with {len(train_dataset)} training examples") 347 | print(f"Sample example: {train_dataset[0]}") 348 | 349 | # Choose your reward functions 350 | chosen_reward_funcs = [grounding_reward, format_reward] 351 | 352 | trainer = Qwen2VLGRPOTrainer( 353 | model=model_args.model_name_or_path, 354 | reward_funcs=chosen_reward_funcs, 355 | args=training_args, 356 | train_dataset=train_dataset, 357 | peft_config=get_peft_config(model_args), 358 | attn_implementation=model_args.attn_implementation, 359 | max_pixels=script_args.max_pixels, 360 | min_pixels=script_args.min_pixels, 361 | torch_dtype=model_args.torch_dtype, 362 | ) 363 | 364 | last_checkpoint = get_checkpoint(training_args) 365 | if last_checkpoint is not None and training_args.resume_from_checkpoint is None: 366 | logger.info(f"Resuming from checkpoint at {last_checkpoint}.") 367 | 368 | logger.info("*** Starting training ***") 369 | train_result = trainer.train(resume_from_checkpoint=last_checkpoint) 370 | metrics = train_result.metrics 371 | metrics["train_samples"] = len(train_dataset) 372 | trainer.log_metrics("train", metrics) 373 | trainer.save_metrics("train", metrics) 374 | trainer.save_state() 375 | 376 | logger.info("*** Training complete ***") 377 | logger.info("*** Save model ***") 378 | trainer.model.config.use_cache = True 379 | trainer.save_model(training_args.output_dir) 380 | logger.info(f"Model saved to {training_args.output_dir}") 381 | 382 | training_args.distributed_state.wait_for_everyone() 383 | processor.save_pretrained(training_args.output_dir) 384 | logger.info(f"Processor saved to {training_args.output_dir}") 385 | 386 | if trainer.accelerator.is_main_process: 387 | trainer.create_model_card(tags=["rl", "grpo", "tutorial", "philschmid"]) 388 | 389 | if training_args.push_to_hub: 390 | logger.info("Pushing to hub...") 391 | trainer.push_to_hub() 392 | 393 | logger.info("*** Done ***") 394 | 395 | 396 | def main(): 397 | from trl import TrlParser 398 | 399 | parser = TrlParser((ModelConfig, ScriptArguments, GRPOConfig)) 400 | model_args, script_args, training_args = parser.parse_args_and_config() 401 | 402 | # Load config from config file if not set directly via command line 403 | if not script_args.json_data_path and script_args.config_path and 
os.path.exists(script_args.config_path): 404 | with open(script_args.config_path, "r") as f: 405 | config = yaml.safe_load(f) 406 | if config and "json_data_path" in config: 407 | script_args.json_data_path = config.get("json_data_path", "") 408 | 409 | grpo_function(model_args, script_args, training_args) 410 | 411 | 412 | if __name__ == "__main__": 413 | main() 414 | --------------------------------------------------------------------------------
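Appendix (added illustration). The sketch below is a minimal, standalone mirror of the reward and advantage logic above; it does not import the repository code. It assumes the <think>/<answer> tag convention used by format_reward and grounding_reward, invents the completion strings, the ground-truth box, and the (equal) reward weights, and then reproduces the group-relative normalization that _generate_and_score_completions applies before compute_loss multiplies the advantages by the clipped probability ratios. It needs only torch and the standard library.

import re
import torch

# Toy group of G = 4 completions for one prompt (invented strings, not from the dataset).
completions = [
    "<think>The clock is near the top-left.</think> <answer>[10, 20, 110, 220]</answer>",
    "<think>Looks centered.</think> <answer>[200, 200, 400, 400]</answer>",
    "The clock is at [10, 20, 110, 220]",                      # missing tags
    "<think>Unsure.</think> <answer>no box</answer>",          # no parsable box
]
gt_bbox = [12.0, 25.0, 115.0, 230.0]  # invented ground-truth box

def iou(a, b):
    # Plain IoU; relaxed_bbox_iou in run_r1_grpo_vlm.py thresholds this at 0.5.
    x1, y1 = max(a[0], b[0]), max(a[1], b[1])
    x2, y2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0, x2 - x1) * max(0, y2 - y1)
    union = (a[2] - a[0]) * (a[3] - a[1]) + (b[2] - b[0]) * (b[3] - b[1]) - inter
    return inter / union if union > 0 else 0.0

def grounding_r(c):
    # Mirrors grounding_reward: pull the bracketed box out of the <answer> block, score by relaxed IoU.
    m = re.search(r"<answer>.*?\[(.*?)\].*?</answer>", c, re.DOTALL)
    if not m:
        return 0.0
    try:
        box = [float(x) for x in m.group(1).split(",")]
        return 1.0 if iou(box, gt_bbox) >= 0.5 else 0.0
    except Exception:
        return 0.0

def format_r(c):
    # Mirrors format_reward: the whole completion must be a <think> block followed by an <answer> block.
    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
    return 1.0 if re.fullmatch(pattern, c, re.DOTALL) else 0.0

# One column per reward function, as in rewards_per_func inside the trainer.
rewards_per_func = torch.tensor([[grounding_r(c), format_r(c)] for c in completions])
reward_weights = torch.tensor([1.0, 1.0])           # default: every reward function weighted equally
rewards = (rewards_per_func * reward_weights).sum(dim=1)

# Group-relative advantages, as in _generate_and_score_completions:
# subtract the group mean and divide by the group std (plus a small epsilon).
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-4)
print(rewards.tolist())      # -> [2.0, 1.0, 0.0, 1.0]
print(advantages.tolist())   # positive for better-than-average completions, negative otherwise

Because the advantages are normalized within each group of num_generations completions for the same prompt, only relative quality inside the group matters: a group in which every completion receives the same total reward yields zero advantage and therefore no gradient signal from that prompt.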