├── .gitignore
├── LICENSE
├── README.md
├── configs
│   ├── grpo-qwen-2.5-v.yaml
│   └── zero3.json
├── data
│   └── textvqa_cot_train_1_bbox_0.json
├── grpo_qwen2vl.py
├── run_grpo_vlm.sh
└── run_r1_grpo_vlm.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
170 | # PyPI configuration file
171 | .pypirc
172 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Vision GRPO
2 |
3 | ## Training Vision Language Models with GRPO for Visual Grounding
4 |
5 | Building on recent advances in using RL to enhance reasoning, we'll explore how to fine-tune Vision Language Models (VLMs) using ***Group Relative Policy Optimization*** (***GRPO***). We'll walk through a complete training pipeline, from dataset preparation to evaluating results.
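
As a quick refresher: for each prompt, GRPO samples a group of G completions, scores each with the reward functions, and uses the group-normalized reward as the advantage, so no separate value model is needed:

```math
A_i = \frac{r_i - \operatorname{mean}(r_1, \dots, r_G)}{\operatorname{std}(r_1, \dots, r_G)}
```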
6 |
7 | ## 1. Modified GRPO for Vision Language Models
8 |
9 | ### Adapting GRPO for Vision Language Models
10 |
11 | Building on the great [mini-R1](https://www.philschmid.de/mini-deepseek-r1) tutorial, we provide a modified version of the approach for training vision language models with the same reasoning recipe. To adapt it for Vision Language Models, we need to:
12 |
13 | 1. **Handle Multimodal Inputs**: Process both images and text in the same framework
14 | 2. **Custom Reward Functions**: Create vision-specific rewards that evaluate how well the model identifies regions in images
15 | 3. **Specialized Architecture**: Use a vision-language model architecture (like Qwen2.5-VL) that can process both modalities
16 |
17 | Because each vision-language model follows its own architecture, we cannot rely on a unified abstraction such as `AutoModelForCausalLM` the way we can for language models; our tutorial therefore covers the two common multimodal architectures of Qwen-VL (2 and 2.5).
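
Concretely, the trainer picks the model class based on the checkpoint name; the following excerpt is adapted from `Qwen2VLGRPOTrainer.__init__` in `grpo_qwen2vl.py`:

```python
# Dispatch on the checkpoint name (adapted from grpo_qwen2vl.py)
if "Qwen2-VL" in model_id:
    model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs)
elif "Qwen2.5-VL" in model_id:
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs)
else:
    model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
```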
18 |
19 | The modified `Qwen2VLGRPOTrainer` enables:
20 |
21 | - Processing of image-text pairs
22 | - Evaluation of visual grounding capabilities
23 | - Optimization of both text generation and region identification
24 |
25 | ```python
26 | # Example of the modified GRPO trainer integration
27 | trainer = Qwen2VLGRPOTrainer(
28 | model=model,
29 | reward_funcs=[grounding_reward, format_reward],
30 | args=training_args,
31 | train_dataset=train_dataset,
32 | eval_dataset=test_dataset,
33 | peft_config=get_peft_config(model_args)
34 | )
35 |
36 | ```
37 |
38 | ## 2. The Visual Grounding Task for CoT Reasoning
39 |
40 | ### What is Visual Grounding?
41 |
42 | Visual grounding is the task of producing bounding boxes for an object described in the input request. We treat visual grounding as a supplementary task that helps the model give the correct answer to a complex question. This approach was investigated in [Visual CoT](https://arxiv.org/abs/2403.16999), where the authors used visual grounding to zoom into the specific part of the image that contains the answer. In our tutorial, we use a subsample of the TextVQA dataset to test whether we can teach the model to zoom in on the relevant parts of the image via RL.
43 |
44 | ### Task Formulation
45 |
46 | The task is structured as follows:
47 |
48 | 1. The model receives an image and a text query about a specific visual element
49 | 2. The model must:
50 |    - Reason through the visual content (in `<think> ... </think>` tags)
51 | - Output precise bounding box coordinates for the relevant region (in `[x1, y1, x2, y2]` format)
52 |
53 | ### Example Query
54 |
55 | ```
56 | Image: [Image of a living room]
57 | Query: Where is the red vase in this image? Show your reasoning in thinking process tags. Return bounding box in [x1, y1, x2, y2] tags.
58 | ```
59 |
60 | ### Expected Output
61 |
62 | ```
63 | Let me analyze this image.
64 | <think>
65 | I can see a living room with various furniture. Looking for a red vase...
66 | I can see a red vase on the coffee table in the center of the image.
67 | It appears to be located approximately at the coordinates [220, 150, 260, 210].
68 | </think>
69 | {"bbox": [220, 150, 260, 210]}
70 |
71 | ```
72 |
73 | ## 3. Dataset Preparation
74 |
75 | ### Dataset Structure
76 |
77 | For this tutorial, we use a vision chain-of-thought dataset specifically designed for visual grounding tasks:
78 |
79 | ```python
80 | import json
81 | import math
82 | from PIL import Image
83 | import os
84 |
85 | def process_jsonl_data(jsonl_file, train_path, output_file=None, max_size=512, maintain_aspect_ratio=True):
86 | """
87 | Process a JSONL file containing image metadata, resize images, and rescale bounding boxes.
88 |
89 | Parameters:
90 | -----------
91 | jsonl_file: str
92 | Path to the JSONL file
93 | train_path: str
94 | Path to the directory containing training images
95 | output_file: str, optional
96 | Path to save the processed dataset (if None, just returns the data)
97 | max_size: int, default=512
98 | Maximum dimension for resized images
99 | maintain_aspect_ratio: bool, default=True
100 | Whether to maintain aspect ratio when resizing
101 |
102 | Returns:
103 | --------
104 | list: Processed dataset
105 | """
106 | dataset = []
107 |
108 | # Count for statistics
109 | total_entries = 0
110 | skipped_entries = 0
111 | processed_entries = 0
112 |
113 | with open(jsonl_file, "r", encoding="utf-8") as f:
114 | for line in f:
115 | if not line.strip():
116 | # Skip any empty lines if present
117 | continue
118 |
119 | total_entries += 1
120 |
121 | try:
122 | data = json.loads(line)
123 |
124 | # Skip entries with multiple bounding boxes
125 | if len(data['bboxs']) > 1:
126 | skipped_entries += 1
127 | continue
128 |
129 | # Ensure image path is complete
130 | if not data['image'].startswith(train_path):
131 | data['image'] = os.path.join(train_path, data['image'])
132 |
133 | # Check if image exists
134 | if not os.path.exists(data['image']):
135 | print(f"Warning: Image not found at {data['image']}")
136 | skipped_entries += 1
137 | continue
138 |
139 | # Open and get dimensions of the image
140 | try:
141 | image = Image.open(data['image'])
142 | original_width, original_height = image.size
143 | except Exception as e:
144 | print(f"Error opening image {data['image']}: {e}")
145 | skipped_entries += 1
146 | continue
147 |
148 | # Determine new dimensions
149 | if maintain_aspect_ratio:
150 | if original_width > max_size or original_height > max_size:
151 | # Calculate new dimensions maintaining aspect ratio
152 | if original_width > original_height:
153 | new_width = max_size
154 | new_height = int(original_height * (max_size / original_width))
155 | else:
156 | new_height = max_size
157 | new_width = int(original_width * (max_size / original_height))
158 | else:
159 | # Image is within acceptable dimensions, no resize needed
160 | new_width, new_height = original_width, original_height
161 | else:
162 | # Fixed size without maintaining aspect ratio
163 | new_width, new_height = max_size, max_size
164 |
165 | # Only rescale bounding boxes if dimensions changed
166 | if new_width != original_width or new_height != original_height:
167 | # Calculate the scaling factors
168 | scale_x = new_width / original_width
169 | scale_y = new_height / original_height
170 |
171 | # Rescale all bounding boxes
172 | new_bboxs = []
173 | for original_bbox in data['bboxs']:
174 | # Adjust the bounding box coordinates
175 | new_bbox = [
176 | math.ceil(original_bbox[0] * scale_x),
177 | math.ceil(original_bbox[1] * scale_y),
178 | math.ceil(original_bbox[2] * scale_x),
179 | math.ceil(original_bbox[3] * scale_y)
180 | ]
181 | new_bboxs.append(new_bbox)
182 |
183 | # Update bounding boxes in the data
184 | data['bboxs'] = new_bboxs.copy()
185 |
186 | # Store the new dimensions in the data
187 | data['width'] = new_width
188 | data['height'] = new_height
189 |
190 | # Append processed data to the dataset
191 | dataset.append(data)
192 | processed_entries += 1
193 |
194 | # Print progress every 1000 entries
195 | if processed_entries % 1000 == 0:
196 | print(f"Processed {processed_entries} entries...")
197 |
198 | except Exception as e:
199 | print(f"Error processing line: {e}")
200 | skipped_entries += 1
201 |
202 | # Print statistics
203 | print(f"Total entries: {total_entries}")
204 | print(f"Processed entries: {processed_entries}")
205 | print(f"Skipped entries: {skipped_entries}")
206 |
207 | # Save processed dataset if output file is specified
208 | if output_file:
209 | with open(output_file, 'w', encoding='utf-8') as f:
210 | for item in dataset:
211 | f.write(json.dumps(item) + '\n')
212 | print(f"Saved processed dataset to {output_file}")
213 |
214 | return dataset
215 |
216 | # Example usage:
217 | if __name__ == "__main__":
218 | TRAIN_PATH = "./train_images/"
219 | JSONL_FILE = "./metadata/textvqa_cot_train.jsonl"
220 | OUTPUT_FILE = "processed_textvqa_train.jsonl"
221 |
222 | # Process the JSONL file
223 | processed_data = process_jsonl_data(
224 | jsonl_file=JSONL_FILE,
225 | train_path=TRAIN_PATH,
226 | output_file=OUTPUT_FILE,
227 | max_size=512,
228 | maintain_aspect_ratio=True
229 | )
230 |
231 | print(f"Processed dataset contains {len(processed_data)} entries")
232 |
233 | # Show a sample entry if available
234 | if processed_data:
235 | sample = processed_data[0]
236 | print("\nSample entry:")
237 | print(f"Question: {sample['question']}")
238 | print(f"Answer: {sample['answer']}")
239 | print(f"Image: {sample['image']}")
240 | print(f"Dimensions: {sample['width']}x{sample['height']}")
241 | print(f"Bounding boxes: {sample['bboxs']}")
242 |
243 | ```
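
`process_jsonl_data` returns a plain Python list, while the prompt-generation step in the next section calls `.map` and `.train_test_split` and assumes an already-loaded `processor`. A minimal bridge, using the `datasets` library and the same Qwen2.5-VL checkpoint as `configs/grpo-qwen-2.5-v.yaml`:

```python
from datasets import Dataset
from transformers import AutoProcessor

# Wrap the processed list so that .map / .train_test_split below work
dataset = Dataset.from_list(processed_data)

# Processor used by generate_r1_prompt in the next section
processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct")
```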
244 |
245 | ### Generating Prompts for Training
246 |
247 | We format each example into a chat template for Qwen2.5-VL, using a system message that specifies the visual grounding task:
248 |
249 | ```python
250 | system_message = "You are a Vision Language Model specialized in visual grounding. Provide bounding box in [x1, y1, x2, y2] format."
251 |
252 | def generate_r1_prompt(sample):
253 | prefix = [
254 | {"role": "system", "content": [{"type": "text", "text": system_message}]},
255 | {
256 | "role": "user",
257 | "content": [
258 | {"type": "image", "image": sample["image"]},
259 | {
260 | "type": "text",
261 | "text": (
262 | sample["question"] + " Show your reasoning in thinking process tags. "
263 | "Return bounding box in [x1, y1, x2, y2] tags."
264 | ),
265 | },
266 | ],
267 | },
268 | {
269 | "role": "assistant",
270 |         "content": [{"type": "text", "text": "Let me analyze this image.\n<think>"}],
271 | },
272 | ]
273 | encoded_prompt = processor.apply_chat_template(prefix, tokenize=False, continue_final_message=True)
274 | return {"prompt": encoded_prompt, "target": sample["bboxs"]}
275 |
276 | # Apply prompt generation to dataset
277 | dataset = dataset.map(generate_r1_prompt)
278 |
279 | # Create train/test split
280 | train_test_split = dataset.train_test_split(test_size=0.1)
281 | train_dataset = train_test_split["train"]
282 | test_dataset = train_test_split["test"]
283 |
284 | ```
285 |
286 | ## 4. Launching Training
287 |
288 | ### Setting Up Reward Functions
289 |
290 | A key component of GRPO is the definition of reward functions. For visual grounding, we define multiple reward functions to evaluate different aspects of the model's output:
291 |
292 | ```python
293 | import re
294 |
295 | def grounding_reward(completions, target, **kwargs):
294 | """Reward function that checks bounding boxes."""
295 | rewards = []
296 | for completion, gt_bbox in zip(completions, target):
297 | try:
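            # NOTE: re.search returns the first bracketed [x1, y1, x2, y2]-style span
            # found anywhere in the completion (including inside the reasoning)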
298 | bbox_match = re.search(r"\[(.*?)\]", completion)
299 | if bbox_match:
300 | pred_bbox = [float(x.strip()) for x in bbox_match.group(1).split(",")]
301 | gt_bbox = [float(x) for x in gt_bbox[0].strip("[]").split(",")]
302 |
303 | # Check IoU between predicted and ground truth bounding boxes
304 | reward = 1.0 if relaxed_bbox_iou(pred_bbox, gt_bbox) else 0.0
305 | else:
306 | reward = 0.0
307 | except Exception:
308 | reward = 0.0
309 | rewards.append(reward)
310 | return rewards
311 |
312 | def format_reward(completions, **kwargs):
313 |     """Check that completions follow the required <think> ... </think> format."""
314 |     completions = ["<think>" + c for c in completions]  # the opening tag is supplied by the prompt
315 |     pattern = r"<think>.*?</think>\s*.*"
316 | matches = [re.fullmatch(pattern, c, re.DOTALL) for c in completions]
317 | return [1.0 if m else 0.0 for m in matches]
318 |
319 | # Select reward functions for training
320 | chosen_reward_funcs = [grounding_reward, format_reward]
321 |
322 | ```
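
The snippet above calls `relaxed_bbox_iou`, which lives in the training script and is not shown here. For reference, here is a minimal sketch of such a check, assuming a plain IoU test against a fixed threshold (the repository's actual relaxation scheme may differ):

```python
def relaxed_bbox_iou(pred_bbox, gt_bbox, iou_threshold=0.5):
    """Hypothetical sketch: True if the predicted and ground-truth boxes
    overlap with IoU >= iou_threshold."""
    # Intersection rectangle
    x1, y1 = max(pred_bbox[0], gt_bbox[0]), max(pred_bbox[1], gt_bbox[1])
    x2, y2 = min(pred_bbox[2], gt_bbox[2]), min(pred_bbox[3], gt_bbox[3])
    intersection = max(0.0, x2 - x1) * max(0.0, y2 - y1)

    # Union of the two box areas
    pred_area = (pred_bbox[2] - pred_bbox[0]) * (pred_bbox[3] - pred_bbox[1])
    gt_area = (gt_bbox[2] - gt_bbox[0]) * (gt_bbox[3] - gt_bbox[1])
    union = pred_area + gt_area - intersection

    return union > 0 and intersection / union >= iou_threshold
```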
323 |
324 | ### Training Configuration
325 |
326 | We configure the training process with appropriate hyperparameters:
327 |
328 | ```python
329 | from trl import GRPOConfig
330 |
331 | # Training arguments example
330 | training_args = GRPOConfig(
331 | output_dir="./qwen_vl_grpo_output",
332 | num_train_epochs=3,
333 | per_device_train_batch_size=1,
334 | per_device_eval_batch_size=1,
335 | gradient_accumulation_steps=2,
336 | learning_rate=1e-5,
337 | warmup_steps=100,
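    # GRPO-specific parameters; values mirror configs/grpo-qwen-2.5-v.yaml
    num_generations=8,
    max_prompt_length=1280,
    max_completion_length=256,
    beta=0.04,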
338 | logging_steps=10,
339 |     eval_strategy="steps",
340 | eval_steps=50,
341 | save_strategy="steps",
342 | save_steps=50,
343 | save_total_limit=3,
344 | bf16=True,
345 | report_to="wandb",
346 | logging_first_step=True
347 | )
348 |
349 | ```
350 |
351 | ### Initializing Model and Trainer
352 |
353 | We load the Qwen2.5-VL model and set up the GRPO trainer:
354 |
355 | ```python
356 | from transformers import Qwen2_5_VLProcessor, Qwen2_5_VLForConditionalGeneration
357 |
358 | from trl import get_peft_config
359 | from grpo_qwen2vl import Qwen2VLGRPOTrainer
357 |
358 | # Load model and processor
359 | processor = Qwen2_5_VLProcessor.from_pretrained(
360 | model_args.model_name_or_path,
361 | trust_remote_code=True
362 | )
363 |
364 | model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
365 | model_args.model_name_or_path,
366 | trust_remote_code=True
367 | )
368 |
369 | # Initialize GRPO trainer
370 | trainer = Qwen2VLGRPOTrainer(
371 | model=model,
372 | reward_funcs=chosen_reward_funcs,
373 | args=training_args,
374 | train_dataset=train_dataset,
375 | eval_dataset=test_dataset,
376 | peft_config=get_peft_config(model_args)
377 | )
378 |
379 | # Start training
380 | train_result = trainer.train()
381 |
382 | ```
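
One constraint to keep in mind before launching (it is enforced in `Qwen2VLGRPOTrainer.__init__`): the global train batch size, i.e. `per_device_train_batch_size` multiplied by the number of processes, must be divisible by `num_generations`. A quick sanity check, assuming an 8-GPU launch:

```python
num_processes = 8  # assumed number of GPUs/processes in your launch command
global_batch_size = training_args.per_device_train_batch_size * num_processes
assert global_batch_size % training_args.num_generations == 0, (
    "The global train batch size must be divisible by num_generations"
)
```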
383 |
384 | ### Saving and Logging
385 |
386 | After training completes, we save the model and metrics:
387 |
388 | ```python
389 | # Save metrics
390 | metrics = train_result.metrics
391 | trainer.log_metrics("train", metrics)
392 | trainer.save_metrics("train", metrics)
393 | trainer.save_state()
394 |
395 | # Save model
396 | trainer.save_model(training_args.output_dir)
397 | processor.save_pretrained(training_args.output_dir)
398 |
399 | # Optional: Push to Hugging Face Hub
400 | if training_args.push_to_hub:
401 | trainer.push_to_hub()
402 |
403 | ```
404 |
405 | ## 5. Training Metrics
406 |
407 | Interestingly, the grounding reward did not change much, because the Qwen-VL model can already perform zero-shot object grounding. We also noticed that it was necessary to use an answer format close to the one the model was originally tuned for; otherwise, the grounding reward stayed at a constant 0.
408 |
413 | ## 6. Example Results
414 |
415 | Let's look at some examples of the model's performance after training:
416 |
417 | ### Example 1: Successful Grounding
418 |
421 | **Query:**
422 |
423 | ```
424 | What is the comment? Show your reasoning in thinking process tags. Return bounding box in [x1, y1, x2, y2] tags.
425 |
426 | ```
427 |
428 | **Model Output:**
429 |
430 |
431 | ```
432 | Let me analyze this image.
433 |
434 | The comment on the book is located near the bottom of the image, just above the comment input field.
435 |
436 | {"bbox": [508, 467, 593, 487]}
437 |
438 | ```
439 |
440 | Qwen2.5-VL performs well on grounding tasks even before training; however, the results vary across examples.
441 |
444 | ## Conclusion
445 |
446 | In this tutorial, we've walked through the complete process of training a Vision Language Model for visual grounding using GRPO:
447 |
448 | 1. We adapted GRPO for vision-language tasks by implementing custom reward functions for bounding box evaluation
449 | 2. We prepared a specialized dataset for visual grounding with formatted prompts
450 | 3. We configured and launched training with the modified `Qwen2VLGRPOTrainer`
451 | 4. We examined examples showing the model's ability to perform visual grounding tasks
452 |
453 | This approach demonstrates how reinforcement learning techniques can be applied to multimodal models, helping them learn to connect textual and visual information more effectively. While this example is not intended for real-life applications, and smaller models may benefit more from SFT-based reasoning, it is a good starting point.
454 |
455 | While GRPO for VLMs can yield interesting findings, it is important to note the following:
456 |
457 | 1. As noted by the DeepSeek researchers, small vision-language models do not perform well with GRPO; SFT can provide better results.
458 | 2. Processing long context remains a challenge; we will further adjust the code for these cases.
459 | 3. It is important to construct a reward function and an answer format suitable for the model; otherwise, the model can get stuck in a local minimum.
460 |
461 | ### Next Steps
462 |
463 | - Experiment with different reward functions to further improve performance
464 | - Explore more complex visual grounding tasks (e.g., multiple object identification)
465 | - Combine with other vision-language tasks like visual question answering or image captioning
466 |
467 | ### Resources
468 |
469 | - [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://arxiv.org/abs/2402.03300)
470 | - [Qwen2.5-VL](https://huggingface.co/docs/transformers/model_doc/qwen2_5_vl)
471 | - [Visual CoT: Advancing Multi-Modal Language Models with a Comprehensive Dataset and Benchmark for Chain-of-Thought Reasoning](https://arxiv.org/abs/2403.16999)
472 | - [Mini-R1: Reproduce Deepseek R1 „aha moment“ a RL tutorial](https://www.philschmid.de/mini-deepseek-r1)
473 | - [VLM-R1](https://github.com/om-ai-lab/VLM-R1/tree/main)
474 | - [open-r1-multimodal](https://github.com/EvolvingLMMs-Lab/open-r1-multimodal)
475 |
--------------------------------------------------------------------------------
/configs/grpo-qwen-2.5-v.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments
2 | model_name_or_path: Qwen/Qwen2.5-VL-3B-Instruct
3 | model_revision: main
4 | torch_dtype: bfloat16
5 | attn_implementation: flash_attention_2
6 | bf16: true
7 | tf32: true
8 |
9 | # Data arguments
10 | json_data_path: "./textvqa_cot_train_1_bbox_0.json"
11 |
12 | # Training arguments
13 | output_dir: ./checkpoints/qwen2_5-3b-grpo-updated
14 | per_device_train_batch_size: 1
15 | gradient_accumulation_steps: 2
16 | gradient_checkpointing: true
17 | gradient_checkpointing_kwargs:
18 | use_reentrant: false
19 | learning_rate: 1.0e-6
20 | lr_scheduler_type: linear
21 | warmup_ratio: 0.0
22 | beta: 0.04
23 | max_prompt_length: 1280
24 | max_completion_length: 256
25 | num_generations: 8
26 | use_vllm: false
27 |
28 | max_pixels: 12845056 # Maximum number of pixels for image processing
29 | min_pixels: 3136 # Minimum number of pixels for image processing
30 |
31 | logging_strategy: steps
32 | logging_steps: 1
33 | save_strategy: steps
34 | save_steps: 1000
35 | seed: 42
36 |
37 | push_to_hub: false
38 |
--------------------------------------------------------------------------------
/configs/zero3.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 |
14 | "zero_optimization": {
15 | "stage": 3,
16 | "offload_optimizer": {
17 | "device": "none",
18 | "pin_memory": true
19 | },
20 | "offload_param": {
21 | "device": "none",
22 | "pin_memory": true
23 | },
24 | "overlap_comm": true,
25 | "contiguous_gradients": true,
26 | "sub_group_size": 1e9,
27 | "reduce_bucket_size": "auto",
28 | "stage3_prefetch_bucket_size": "auto",
29 | "stage3_param_persistence_threshold": "auto",
30 | "stage3_max_live_parameters": 1e9,
31 | "stage3_max_reuse_distance": 1e9,
32 | "stage3_gather_16bit_weights_on_model_save": true
33 | },
34 |
35 | "gradient_accumulation_steps": "auto",
36 | "gradient_clipping": "auto",
37 | "steps_per_print": 100,
38 | "train_batch_size": "auto",
39 | "train_micro_batch_size_per_gpu": "auto",
40 | "wall_clock_breakdown": false
41 | }
--------------------------------------------------------------------------------
/grpo_qwen2vl.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import contextlib
16 | import functools
17 | import os
18 | import textwrap
19 | import warnings
20 | from collections import defaultdict
21 | from typing import Any, Callable, Optional, Sized, Union
22 | from unittest.mock import patch
23 |
24 | import torch
25 | import torch.utils.data
26 | import transformers
27 | from accelerate import PartialState
28 | from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
29 | from accelerate.utils.other import is_compiled_module
30 | from datasets import Dataset, IterableDataset
31 | from packaging import version
32 | from torch import nn
33 | from torch.utils.data import Sampler
34 | from transformers import (
35 | AutoModelForCausalLM,
36 | AutoModelForSequenceClassification,
37 | AutoTokenizer,
38 | GenerationConfig,
39 | PreTrainedModel,
40 | PreTrainedTokenizerBase,
41 | Trainer,
42 | TrainerCallback,
43 | is_wandb_available,
44 | )
45 | from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
46 | from transformers.utils import is_peft_available
47 |
48 | from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
49 | from trl.extras.profiling import profiling_context, profiling_decorator
50 | from trl.import_utils import is_rich_available, is_vllm_available
51 | from trl.models import create_reference_model, prepare_deepspeed, unwrap_model_for_generation
52 | from trl.trainer.callbacks import SyncRefModelCallback
53 | from trl.trainer.grpo_config import GRPOConfig
54 | from trl.trainer.utils import (
55 | generate_model_card,
56 | get_comet_experiment_url,
57 | pad,
58 | print_prompt_completions_sample,
59 | selective_log_softmax,
60 | )
61 |
62 | # Vision-language model classes and the multimodal processor
63 | # (torch and AutoTokenizer are already imported above)
64 | from transformers import (
65 |     Qwen2_5_VLForConditionalGeneration,
66 |     Qwen2VLForConditionalGeneration,
67 |     AutoProcessor,
68 | )
69 |
70 | if is_peft_available():
71 | from peft import PeftConfig, get_peft_model
72 |
73 | if is_vllm_available():
74 | from vllm import LLM, SamplingParams
75 | from vllm.sampling_params import GuidedDecodingParams
76 |
77 | if is_wandb_available():
78 | import wandb
79 |
80 | import PIL
81 |
82 | # What we call a reward function is a callable that takes a list of prompts and completions and returns a list of
83 | # rewards. When it's a string, it's a model ID, so it's loaded as a pretrained model.
84 | RewardFunc = Union[str, PreTrainedModel, Callable[[list, list], list[float]]]
85 |
86 |
87 | class RepeatRandomSampler(Sampler):
88 | """
89 | Sampler that repeats the indices of a dataset in a structured manner.
90 |
91 | Args:
92 | data_source (`Sized`):
93 | Dataset to sample from.
94 | mini_repeat_count (`int`):
95 | Number of times to repeat each index per batch.
96 | batch_size (`int`, *optional*, defaults to `1`):
97 | Number of unique indices per batch.
98 | repeat_count (`int`, *optional*, defaults to `1`):
99 | Number of times to repeat the full sampling process.
100 | seed (`int` or `None`, *optional*, defaults to `None`):
101 | Random seed for reproducibility (only affects this sampler).
102 |
103 | Example:
104 | ```python
105 | >>> sampler = RepeatRandomSampler(["a", "b", "c", "d", "e", "f", "g"], mini_repeat_count=2, batch_size=3, repeat_count=4)
106 | >>> list(sampler)
107 | [4, 4, 3, 3, 0, 0,
108 | 4, 4, 3, 3, 0, 0,
109 | 4, 4, 3, 3, 0, 0,
110 | 4, 4, 3, 3, 0, 0,
111 |
112 | 1, 1, 2, 2, 6, 6,
113 | 1, 1, 2, 2, 6, 6,
114 | 1, 1, 2, 2, 6, 6,
115 | 1, 1, 2, 2, 6, 6]
116 | ```
117 |
118 | ```txt
119 | mini_repeat_count = 3
120 | - - -
121 | [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, |
122 | 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, |
123 | 8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, |
124 | repeat_count = 2
125 | 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, |
126 | 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, |
127 | 8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, ...] |
128 | --------- --------- --------- ---------
129 | --------- --------- --------- ---------
130 | --------- --------- --------- ---------
131 | batch_size = 12
132 | ```
133 | """
134 |
135 | def __init__(
136 | self,
137 | data_source: Sized,
138 | mini_repeat_count: int,
139 | batch_size: int = 1,
140 | repeat_count: int = 1,
141 | seed: Optional[int] = None,
142 | ):
143 | self.data_source = data_source
144 | self.mini_repeat_count = mini_repeat_count
145 | self.batch_size = batch_size
146 | self.repeat_count = repeat_count
147 | self.num_samples = len(data_source)
148 | self.seed = seed
149 | self.generator = torch.Generator() # Create a local random generator
150 | if seed is not None:
151 | self.generator.manual_seed(seed)
152 |
153 | def __iter__(self):
154 | # E.g., [2, 4, 3, 1, 0, 6, 5] (num_samples = 7)
155 | indexes = torch.randperm(self.num_samples, generator=self.generator).tolist()
156 |
157 | # [2, 4, 3, 1, 0, 6, 5]
158 | # -> [[2, 4, 3], [1, 0, 6], [5]] (batch_size = 3)
159 | indexes = [indexes[i : i + self.batch_size] for i in range(0, len(indexes), self.batch_size)]
160 |
161 | # [[2, 4, 3], [1, 0, 6], [5]]
162 | # -> [[2, 4, 3], [1, 0, 6]]
163 | indexes = [chunk for chunk in indexes if len(chunk) == self.batch_size]
164 |
165 | for chunk in indexes:
166 | for _ in range(self.repeat_count):
167 | for index in chunk:
168 | for _ in range(self.mini_repeat_count):
169 | yield index
170 |
171 | def __len__(self) -> int:
172 | return self.num_samples * self.mini_repeat_count * self.repeat_count
173 |
174 |
175 | class Qwen2VLGRPOTrainer(Trainer):
176 | """
177 | Trainer for the Group Relative Policy Optimization (GRPO) method. This algorithm was initially proposed in the
178 | paper [DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://huggingface.co/papers/2402.03300).
179 |
180 | Example:
181 |
182 | ```python
183 | from datasets import load_dataset
184 | from trl import GRPOTrainer
185 |
186 | dataset = load_dataset("trl-lib/tldr", split="train")
187 |
188 | def reward_func(completions, **kwargs):
189 | # Dummy reward function that rewards completions with more unique letters.
190 | return [float(len(set(completion))) for completion in completions]
191 |
192 | trainer = GRPOTrainer(
193 | model="Qwen/Qwen2-0.5B-Instruct",
194 | reward_funcs=reward_func,
195 | train_dataset=dataset,
196 | )
197 |
198 | trainer.train()
199 | ```
200 |
201 | Args:
202 | model (`Union[str, PreTrainedModel]`):
203 | Model to be trained. Can be either:
204 |
205 | - A string, being the *model id* of a pretrained model hosted inside a model repo on huggingface.co, or
206 | a path to a *directory* containing model weights saved using
207 | [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is
208 |               loaded using [`~transformers.AutoModelForCausalLM.from_pretrained`] with the keyword arguments
209 | in `args.model_init_kwargs`.
210 | - A [`~transformers.PreTrainedModel`] object. Only causal language models are supported.
211 | reward_funcs (`Union[RewardFunc, list[RewardFunc]]`):
212 | Reward functions to be used for computing the rewards. To compute the rewards, we call all the reward
213 | functions with the prompts and completions and sum the rewards. Can be either:
214 |
215 | - A single reward function, such as:
216 | - A string: The *model ID* of a pretrained model hosted inside a model repo on huggingface.co, or a
217 | path to a *directory* containing model weights saved using
218 | [`~transformers.PreTrainedModel.save_pretrained`], e.g., `'./my_model_directory/'`. The model is loaded
219 | using [`~transformers.AutoModelForSequenceClassification.from_pretrained`] with `num_labels=1` and the
220 | keyword arguments in `args.model_init_kwargs`.
221 | - A [`~transformers.PreTrainedModel`] object: Only sequence classification models are supported.
222 | - A custom reward function: The function is provided with the prompts and the generated completions,
223 | plus any additional columns in the dataset. It should return a list of rewards. For more details, see
224 | [Using a custom reward function](#using-a-custom-reward-function).
225 | - A list of reward functions, where each item can independently be any of the above types. Mixing different
226 | types within the list (e.g., a string model ID and a custom reward function) is allowed.
227 | args ([`GRPOConfig`], *optional*, defaults to `None`):
228 | Configuration for this trainer. If `None`, a default configuration is used.
229 | train_dataset ([`~datasets.Dataset`] or [`~datasets.IterableDataset`]):
230 |             Dataset to use for training. It must include a column `"prompt"`. Any additional columns in the dataset are
231 | ignored. The format of the samples can be either:
232 |
233 | - [Standard](dataset_formats#standard): Each sample contains plain text.
234 | - [Conversational](dataset_formats#conversational): Each sample contains structured messages (e.g., role
235 | and content).
236 | eval_dataset ([`~datasets.Dataset`], [`~datasets.IterableDataset`] or `dict[str, Union[Dataset, IterableDataset]]`):
237 | Dataset to use for evaluation. It must meet the same requirements as `train_dataset`.
238 | processing_class ([`~transformers.PreTrainedTokenizerBase`], *optional*, defaults to `None`):
239 | Processing class used to process the data. The padding side must be set to "left". If `None`, the
240 | processing class is loaded from the model's name with [`~transformers.AutoTokenizer.from_pretrained`].
241 | reward_processing_classes (`Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]`, *optional*, defaults to `None`):
242 | Processing classes corresponding to the reward functions specified in `reward_funcs`. Can be either:
243 |
244 | - A single processing class: Used when `reward_funcs` contains only one reward function.
245 | - A list of processing classes: Must match the order and length of the reward functions in `reward_funcs`.
246 | If set to `None`, or if an element of the list corresponding to a [`~transformers.PreTrainedModel`] is
247 | `None`, the tokenizer for the model is automatically loaded using [`~transformers.AutoTokenizer.from_pretrained`].
248 | For elements in `reward_funcs` that are custom reward functions (not [`~transformers.PreTrainedModel`]),
249 | the corresponding entries in `reward_processing_classes` are ignored.
250 | callbacks (list of [`~transformers.TrainerCallback`], *optional*, defaults to `None`):
251 | List of callbacks to customize the training loop. Will add those to the list of default callbacks
252 | detailed in [here](https://huggingface.co/docs/transformers/main_classes/callback).
253 |
254 | If you want to remove one of the default callbacks used, use the [`~transformers.Trainer.remove_callback`]
255 | method.
256 | optimizers (`tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]`, *optional*, defaults to `(None, None)`):
257 | A tuple containing the optimizer and the scheduler to use. Will default to an instance of [`AdamW`] on your
258 | model and a scheduler given by [`get_linear_schedule_with_warmup`] controlled by `args`.
259 | peft_config ([`~peft.PeftConfig`], *optional*, defaults to `None`):
260 | PEFT configuration used to wrap the model. If `None`, the model is not wrapped.
261 | """
262 |
263 | _tag_names = ["trl", "grpo"]
264 |
265 | def __init__(
266 | self,
267 | model: Union[str, PreTrainedModel],
268 | reward_funcs: Union[RewardFunc, list[RewardFunc]],
269 | args: Optional[GRPOConfig] = None,
270 | train_dataset: Optional[Union[Dataset, IterableDataset]] = None,
271 | eval_dataset: Optional[Union[Dataset, IterableDataset, dict[str, Union[Dataset, IterableDataset]]]] = None,
272 | processing_class: Optional[PreTrainedTokenizerBase] = None,
273 | reward_processing_classes: Optional[Union[PreTrainedTokenizerBase, list[PreTrainedTokenizerBase]]] = None,
274 | callbacks: Optional[list[TrainerCallback]] = None,
275 | optimizers: tuple[Optional[torch.optim.Optimizer], Optional[torch.optim.lr_scheduler.LambdaLR]] = (None, None),
276 | peft_config: Optional["PeftConfig"] = None,
277 | max_pixels: Optional[int] = 12845056,
278 | min_pixels: Optional[int] = 3136,
279 | attn_implementation: str = "flash_attention_2",
280 | torch_dtype: Optional[torch.dtype] = None,
281 | ):
282 | # Args
283 | if args is None:
284 | model_name = model if isinstance(model, str) else model.config._name_or_path
285 | model_name = model_name.split("/")[-1]
286 | args = GRPOConfig(f"{model_name}-GRPO")
287 |
288 | # Models
289 | # Trained model
290 | model_init_kwargs = args.model_init_kwargs or {}
291 | if isinstance(model, str):
292 | model_id = model
293 | torch_dtype = model_init_kwargs.get("torch_dtype")
294 | if isinstance(torch_dtype, torch.dtype) or torch_dtype == "auto" or torch_dtype is None:
295 | pass # torch_dtype is already a torch.dtype or "auto" or None
296 | elif isinstance(torch_dtype, str): # it's a str, but not "auto"
297 | torch_dtype = getattr(torch, torch_dtype)
298 | model_init_kwargs["torch_dtype"] = torch_dtype
299 | else:
300 | raise ValueError(
301 | "Invalid `torch_dtype` passed to `GRPOConfig`. Expected either 'auto' or a string representing "
302 | f"a `torch.dtype` (e.g., 'float32'), but got {torch_dtype}."
303 | )
304 | # Disable caching if gradient checkpointing is enabled (not supported)
305 | model_init_kwargs["use_cache"] = (
306 | False if args.gradient_checkpointing else model_init_kwargs.get("use_cache")
307 | )
308 | model_init_kwargs["attn_implementation"] = attn_implementation
309 | if "Qwen2-VL" in model_id:
310 | model = Qwen2VLForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
311 | elif "Qwen2.5-VL" in model_id:
312 | model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model, **model_init_kwargs)
313 | else:
314 | model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)
315 | else:
316 | model_id = model.config._name_or_path
317 | if args.model_init_kwargs is not None:
318 | raise ValueError(
319 | "You passed `model_init_kwargs` to the `GRPOConfig`, but your model is already instantiated. "
320 | "This argument can only be used when the `model` argument is a string."
321 | )
322 |
323 | if peft_config is not None:
324 | if not is_peft_available():
325 | raise ImportError("PEFT is required to use `peft_config`. Run `pip install peft`.")
326 | model = get_peft_model(model, peft_config)
327 |
328 | # Enable gradient checkpointing if requested
329 | if args.gradient_checkpointing:
330 | model = self._enable_gradient_checkpointing(model, args)
331 |
332 | # Reference model
333 | self.beta = args.beta
334 | if self.beta == 0.0:
335 | # If beta is 0.0, the reference model is not needed
336 | self.ref_model = None
337 |         elif is_deepspeed_zero3_enabled():
338 | if "Qwen2-VL" in model_id:
339 | self.ref_model = Qwen2VLForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs)
340 | elif "Qwen2.5-VL" in model_id:
341 | self.ref_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_id, **model_init_kwargs)
342 | else:
343 | self.ref_model = AutoModelForCausalLM.from_pretrained(model_id, **model_init_kwargs)
344 | elif is_peft_model(model):
345 | # If PEFT is used, the reference model is not needed since the adapter can be disabled
346 | # to revert to the initial model.
347 | self.ref_model = None
348 | else:
349 | # If PEFT configuration is not provided, create a reference model based on the initial model.
350 | self.ref_model = create_reference_model(model)
351 |
352 | # Processing class
353 | if processing_class is None:
354 | if "Qwen2-VL" in model_id or "Qwen2.5-VL" in model_id:
355 | processing_class = AutoProcessor.from_pretrained(model_id)
356 | pad_token_id = processing_class.tokenizer.pad_token_id
357 | processing_class.pad_token_id = pad_token_id
358 | processing_class.eos_token_id = processing_class.tokenizer.eos_token_id
359 | if "Qwen" in model_id or "Qwen2.5-VL" in model_id:
360 | processing_class.image_processor.max_pixels = max_pixels
361 | processing_class.image_processor.min_pixels = min_pixels
362 | else:
363 | processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path, padding_side="left")
364 | pad_token_id = processing_class.pad_token_id
365 |
366 | # Reward functions
367 | if not isinstance(reward_funcs, list):
368 | reward_funcs = [reward_funcs]
369 | for i, reward_func in enumerate(reward_funcs):
370 | if isinstance(reward_func, str):
371 | reward_funcs[i] = AutoModelForSequenceClassification.from_pretrained(
372 | reward_func, num_labels=1, **model_init_kwargs
373 | )
374 | self.reward_funcs = reward_funcs
375 |
376 | # Reward weights
377 | if args.reward_weights is not None:
378 | if len(args.reward_weights) != len(reward_funcs):
379 | raise ValueError(
380 | f"Number of reward weights ({len(args.reward_weights)}) must match number of reward "
381 | f"functions ({len(reward_funcs)})"
382 | )
383 | self.reward_weights = torch.tensor(args.reward_weights, dtype=torch.float32)
384 | else:
385 | self.reward_weights = torch.ones(len(reward_funcs), dtype=torch.float32)
386 |
387 | # Reward processing class
388 | if reward_processing_classes is None:
389 | reward_processing_classes = [None] * len(reward_funcs)
390 | elif not isinstance(reward_processing_classes, list):
391 | reward_processing_classes = [reward_processing_classes]
392 | else:
393 | if len(reward_processing_classes) != len(reward_funcs):
394 | raise ValueError("The number of reward processing classes must match the number of reward functions.")
395 |
396 | for i, (reward_processing_class, reward_func) in enumerate(zip(reward_processing_classes, reward_funcs)):
397 | if isinstance(reward_func, PreTrainedModel):
398 | if reward_processing_class is None:
399 | reward_processing_class = AutoTokenizer.from_pretrained(reward_func.config._name_or_path)
400 | if reward_processing_class.pad_token_id is None:
401 | reward_processing_class.pad_token = reward_processing_class.eos_token
402 | # The reward model computes the reward for the latest non-padded token in the input sequence.
403 | # So it's important to set the pad token ID to the padding token ID of the processing class.
404 | reward_func.config.pad_token_id = reward_processing_class.pad_token_id
405 | reward_processing_classes[i] = reward_processing_class
406 | self.reward_processing_classes = reward_processing_classes
407 |
408 | # Data collator
409 | def data_collator(features): # No data collation is needed in GRPO
410 | return features
411 |
412 | # Training arguments
413 | self.max_prompt_length = args.max_prompt_length
414 | self.max_completion_length = args.max_completion_length # = |o_i| in the GRPO paper
415 | self.num_generations = args.num_generations # = G in the GRPO paper
416 | self.use_vllm = args.use_vllm
417 |
418 | # Multi-step
419 | self.num_iterations = args.num_iterations # = 𝜇 in the GRPO paper
420 | self.epsilon = args.epsilon
421 | # Tracks the number of iterations (forward + backward passes), including those within a gradient accumulation cycle.
422 | self._step = 0
423 | # Buffer the batch to reuse generated outputs across multiple updates. For more details, see
424 | # `_get_train_sampler` and `_prepare_inputs`.
425 | self._buffered_inputs = [None] * args.gradient_accumulation_steps
426 |
427 | # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
428 | # input tensor associated with the key "input_ids". However, in GRPO, the sampled data does not include the
429 |         # "input_ids" key. Instead, the available key is "prompt". As a result, the trainer issues the warning:
430 | # "Could not estimate the number of tokens of the input, floating-point operations will not be computed." To
431 | # suppress this warning, we set the "estimate_tokens" key in the model's "warnings_issued" dictionary to True.
432 | # This acts as a flag to indicate that the warning has already been issued.
433 | model.warnings_issued["estimate_tokens"] = True
434 |
435 | # Initialize the metrics
436 | self._metrics = {"train": defaultdict(list), "eval": defaultdict(list)}
437 | self.log_completions = args.log_completions
438 |
439 | super().__init__(
440 | model=model,
441 | args=args,
442 | data_collator=data_collator,
443 | train_dataset=train_dataset,
444 | eval_dataset=eval_dataset,
445 | processing_class=processing_class,
446 | callbacks=callbacks,
447 | optimizers=optimizers,
448 | )
449 |
450 | # Check if the per_device_train/eval_batch_size * num processes can be divided by the number of generations
451 | num_processes = self.accelerator.num_processes
452 | global_batch_size = args.per_device_train_batch_size * num_processes
453 |         possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
454 |
455 | if self.num_generations not in possible_values:
456 | raise ValueError(
457 | f"The global train batch size ({num_processes} x {args.per_device_train_batch_size}) must be evenly "
458 | f"divisible by the number of generations per prompt ({self.num_generations}). Given the current train "
459 | f"batch size, the valid values for the number of generations are: {possible_values}."
460 | )
461 | if self.args.eval_strategy != "no":
462 | global_batch_size = args.per_device_eval_batch_size * num_processes
463 | possible_values = [n_gen for n_gen in range(2, global_batch_size + 1) if (global_batch_size) % n_gen == 0]
464 | if self.num_generations not in possible_values:
465 | raise ValueError(
466 | f"The global eval batch size ({num_processes} x {args.per_device_eval_batch_size}) must be evenly "
467 | f"divisible by the number of generations per prompt ({self.num_generations}). Given the current "
468 | f"eval batch size, the valid values for the number of generations are: {possible_values}."
469 | )
470 |
471 | # Ensure each process receives a unique seed to prevent duplicate completions when generating with
472 | # transformers if num_generations exceeds per_device_train_batch_size. We could skip it if we use vLLM, but
473 | # it's safer to set it in all cases.
474 | set_seed(args.seed, device_specific=True)
475 |
476 | self.generation_config = GenerationConfig(
477 | max_new_tokens=self.max_completion_length,
478 | do_sample=True,
479 | temperature=args.temperature,
480 | num_return_sequences=self.num_generations,
481 | pad_token_id=pad_token_id,
482 | )
483 |
484 | # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
485 | # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
486 | # self.model_accepts_loss_kwargs to False to enable scaling.
487 | self.model_accepts_loss_kwargs = False
488 |
489 | # Add tags to the model
490 | self.model.add_model_tags(self._tag_names)
491 |
492 | if self.ref_model is not None:
493 | if self.is_deepspeed_enabled:
494 | self.ref_model = prepare_deepspeed(self.ref_model, self.accelerator)
495 | else:
496 | self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
497 |
498 | if args.sync_ref_model:
499 | self.add_callback(SyncRefModelCallback(ref_model=self.ref_model, accelerator=self.accelerator))
500 |
501 | for i, reward_func in enumerate(self.reward_funcs):
502 | if isinstance(reward_func, PreTrainedModel):
503 | self.reward_funcs[i] = self.accelerator.prepare_model(reward_func, evaluation_mode=True)
504 |
505 | def _set_signature_columns_if_needed(self):
506 | # If `self.args.remove_unused_columns` is True, non-signature columns are removed.
507 | # By default, this method sets `self._signature_columns` to the model's expected inputs.
508 | # In GRPOTrainer, we preprocess data, so using the model's signature columns doesn't work.
509 | # Instead, we set them to the columns expected by the `training_step` method, hence the override.
510 | if self._signature_columns is None:
511 | self._signature_columns = ["prompt"]
512 |
513 | def _get_train_sampler(self) -> Sampler:
514 | # Returns a sampler that
515 | # 1. ensures each prompt is repeated across multiple processes. This guarantees that identical prompts are
516 | # distributed to different GPUs, allowing rewards to be computed and normalized correctly within each prompt
517 | # group. Using the same seed across processes ensures consistent prompt assignment, preventing discrepancies
518 | # in group formation.
519 | # 2. repeats the batch multiple times to allow reusing generations across multiple updates. Refer to
520 | # _prepare_inputs to see how the generations are stored and reused.
521 |
522 | # In the following figure, the values are the prompt indices. The first row shows the first sampled batch, the
523 | # second row shows the second sampled batch, and so on.
524 | #
525 | # | GPU 0 | GPU 1 | GPU 2 |
526 | #
527 | # global_step step <───────> num_generations=3
528 | # <───────────> per_device_train_batch_size=4
529 | # ▲ 0 0 0 0 0 1 1 1 2 2 2 3 3 3 │
530 | # grad_accum=3 │ 0 1 4 4 4 5 5 5 6 6 6 7 7 7 │ Generate completions for each prompt
531 | # ▼ 0 2 8 8 8 9 9 9 10 10 10 11 11 11 │
532 | #
533 | # 1 3 0 0 0 1 1 1 2 2 2 3 3 3 │ The sampled prompts are the same as in the first iteration
534 | # 1 4 4 4 4 5 5 5 6 6 6 7 7 7 │ Reuse the completions (here, once, because num_iterations=2)
535 | # 1 5 8 8 8 9 9 9 10 10 10 11 11 11 │
536 | #
537 | # 2 6 12 12 12 13 13 13 14 14 14 15 15 15
538 | # 2 7 16 16 16 17 17 17 18 18 18 19 19 19
539 | # 2 8 20 20 20 21 21 21 22 22 22 23 23 23
540 | # ...
541 | effective_batch_size = (
542 | self.args.per_device_train_batch_size
543 | * self.accelerator.num_processes
544 | * self.args.gradient_accumulation_steps
545 | )
546 | return RepeatRandomSampler(
547 | data_source=self.train_dataset,
548 | mini_repeat_count=self.num_generations,
549 | batch_size=effective_batch_size // self.num_generations,
550 | repeat_count=self.num_iterations,
551 | seed=self.args.seed,
552 | )
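
    # Illustrative check (editor's sketch, not executed by the trainer): assuming the
    # RepeatRandomSampler referenced above matches TRL's, a 4-prompt dataset with
    # num_generations=3 and num_iterations=2 yields each prompt index 3 times in a row,
    # and replays the whole shuffled batch once:
    #
    #   sampler = RepeatRandomSampler(
    #       data_source=range(4), mini_repeat_count=3, batch_size=4, repeat_count=2, seed=0
    #   )
    #   list(sampler)
    #   # e.g. [2, 2, 2, 0, 0, 0, 1, 1, 1, 3, 3, 3,
    #   #       2, 2, 2, 0, 0, 0, 1, 1, 1, 3, 3, 3]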
553 |
554 | def _get_eval_sampler(self, eval_dataset) -> Sampler:
555 | # See _get_train_sampler for an explanation of the sampler.
556 | return RepeatRandomSampler(
557 | data_source=eval_dataset,
558 | mini_repeat_count=self.num_generations,
559 | seed=self.args.seed,
560 | )
561 |
562 | def _enable_gradient_checkpointing(self, model: PreTrainedModel, args: GRPOConfig) -> PreTrainedModel:
563 | """Enables gradient checkpointing for the model."""
564 | # Ensure use_cache is disabled
565 | model.config.use_cache = False
566 |
567 | # Enable gradient checkpointing on the base model for PEFT
568 | if is_peft_model(model):
569 | model.base_model.gradient_checkpointing_enable()
570 | # Enable gradient checkpointing for non-PEFT models
571 | else:
572 | model.gradient_checkpointing_enable()
573 |
574 | gradient_checkpointing_kwargs = args.gradient_checkpointing_kwargs or {}
575 | use_reentrant = (
576 | "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]
577 | )
578 |
579 | if use_reentrant:
580 | model.enable_input_require_grads()
581 |
582 | return model
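
    # Usage note (editor's sketch): the use_reentrant check above only forces input
    # grads on for the default reentrant checkpointing path. Assuming GRPOConfig
    # inherits TrainingArguments' fields, opting into the non-reentrant variant
    # would look like:
    #
    #   args = GRPOConfig(
    #       output_dir="out",
    #       gradient_checkpointing=True,
    #       gradient_checkpointing_kwargs={"use_reentrant": False},
    #   )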
583 |
584 | # Get the per-token log probabilities for the completions for the model and the reference model
585 | @profiling_decorator
586 | def _get_per_token_logps(self, model, input_ids, attention_mask, pixel_values, image_grid_thw, logits_to_keep):
587 |         # The last logit of the sequence is excluded below: it corresponds to the prediction for the token after the sequence
588 | logits = model(
589 | input_ids=input_ids, attention_mask=attention_mask, pixel_values=pixel_values, image_grid_thw=image_grid_thw
590 | ).logits
591 | logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
592 |
593 | input_ids = input_ids[:, -logits_to_keep:]
594 | # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
595 | # See https://github.com/huggingface/trl/issues/2770
596 | logits = logits[:, -logits_to_keep:]
597 | return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens
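
    # Reference semantics of `selective_log_softmax` (editor's re-implementation for
    # clarity; the trainer calls the TRL helper, which is more memory-efficient):
    #
    #   import torch.nn.functional as F
    #
    #   def selective_log_softmax_ref(logits, index):
    #       logps = F.log_softmax(logits, dim=-1)  # (B, L, V)
    #       # pick out the log-prob of each realized token
    #       return torch.gather(logps, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)  # (B, L)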
598 |
599 | @profiling_decorator
600 | def _move_model_to_vllm(self):
601 | with unwrap_model_for_generation(
602 | self.model, self.accelerator, gather_deepspeed3_params=self.args.ds3_gather_for_generation
603 | ) as unwrapped_model:
604 | if is_compiled_module(unwrapped_model):
605 | unwrapped_model = unwrapped_model._orig_mod
606 | if is_peft_model(unwrapped_model):
607 | unwrapped_model.merge_adapter()
608 | state_dict = unwrapped_model.state_dict()
609 | # Remove base_model and base_layer prefixes
610 | state_dict = {
611 | k.removeprefix("base_model.model.").replace(".base_layer", ""): v for k, v in state_dict.items()
612 | }
613 | # Remove values with adapter prefix (example: "_lora")
614 | state_dict = {k: v for k, v in state_dict.items() if unwrapped_model.prefix not in k}
615 | # When module to save, remove its prefix and discard the original module
616 | state_dict = {
617 | k.replace("modules_to_save.default.", ""): v
618 | for k, v in state_dict.items()
619 | if "original_module" not in k
620 | }
621 | else:
622 | state_dict = unwrapped_model.state_dict()
623 | if self.accelerator.is_main_process:
624 | llm_model = self.llm.llm_engine.model_executor.driver_worker.model_runner.model
625 | llm_model.load_weights(state_dict.items())
626 | # Unmerge the adapter to restore the model to its original state.
627 | # This must be done after loading weights to ensure they correspond to the merged state.
628 | if is_peft_model(unwrapped_model):
629 | unwrapped_model.unmerge_adapter()
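
    # Example of the PEFT key cleanup above (editor's illustration): a merged-adapter
    # key such as
    #   "base_model.model.model.layers.0.self_attn.q_proj.base_layer.weight"
    # is rewritten to
    #   "model.layers.0.self_attn.q_proj.weight"
    # which is the name vLLM's load_weights expects for the base checkpoint.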
630 |
631 | @profiling_decorator
632 | def _prepare_inputs(self, inputs: dict[str, Union[torch.Tensor, Any]]) -> dict[str, Union[torch.Tensor, Any]]:
633 | mode = "eval" if self.control.should_evaluate else "train"
634 | if mode == "train":
635 | if self.state.global_step % self.num_iterations == 0:
636 | inputs = self._generate_and_score_completions(inputs)
637 | self._buffered_inputs[self._step % self.args.gradient_accumulation_steps] = inputs
638 | else:
639 | inputs = self._buffered_inputs[self._step % self.args.gradient_accumulation_steps]
640 | self._step += 1
641 | else:
642 | # In evaluation, we don't reuse completions across multiple updates, so we don't need to buffer inputs.
643 | inputs = self._generate_and_score_completions(inputs)
644 | return inputs
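
    # Buffering schedule (editor's illustration): with gradient_accumulation_steps=2 and
    # num_iterations=2, completions are generated only when global_step % num_iterations == 0
    # and replayed from the buffer otherwise:
    #
    #   global_step 0: _step 0 -> generate into slot 0, _step 1 -> generate into slot 1
    #   global_step 1: _step 2 -> reuse slot 0,         _step 3 -> reuse slot 1
    #   global_step 2: _step 4 -> generate into slot 0, ...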
645 |
646 | def _generate_and_score_completions(
647 | self, inputs: dict[str, Union[torch.Tensor, Any]]
648 | ) -> dict[str, Union[torch.Tensor, Any]]:
649 | device = self.accelerator.device
650 | prompts = [x["prompt"] for x in inputs]
651 | prompts_text = [maybe_apply_chat_template(example, self.processing_class)["prompt"] for example in inputs]
652 |
653 | # Handle both pre-loaded images and image paths
654 | images = []
655 | for x in inputs:
656 | if "image" in x:
657 | img = x["image"]
658 | else:
659 | img = PIL.Image.open(x["image_path"])
660 |
661 | # Ensure minimum dimensions of 28 pixels
662 | w, h = img.size
663 | if w < 28 or h < 28:
664 | # Calculate new dimensions maintaining aspect ratio
665 | if w < h:
666 | new_w = 28
667 | new_h = int(h * (28 / w))
668 | else:
669 | new_h = 28
670 | new_w = int(w * (28 / h))
671 |                 # resize is applied once below, after the target dimensions are computed
672 | elif w > 512 or h > 512:
673 | # Calculate new dimensions maintaining aspect ratio for large images
674 | if w > h:
675 | new_w = 512
676 | new_h = int(h * (512 / w))
677 | else:
678 | new_h = 512
679 | new_w = int(w * (512 / h))
680 | else:
681 | # Image is within acceptable dimensions, no resize needed
682 | new_w, new_h = w, h
683 |
684 | # Only resize if dimensions changed
685 | if new_w != w or new_h != h:
686 | img = img.resize((new_w, new_h), PIL.Image.Resampling.LANCZOS)
687 |
688 | images.append(img)
689 |
690 | prompt_inputs = self.processing_class(
691 | text=prompts_text,
692 | images=images,
693 | return_tensors="pt",
694 | padding=True,
695 | padding_side="left",
696 | add_special_tokens=False,
697 | )
698 | prompt_inputs = super()._prepare_inputs(prompt_inputs)
699 |
700 | prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
701 | pixel_values = prompt_inputs["pixel_values"]
702 | image_grid_thw = prompt_inputs["image_grid_thw"]
703 |
704 | if self.max_prompt_length is not None:
705 | prompt_ids = prompt_ids[:, -self.max_prompt_length :]
706 | prompt_mask = prompt_mask[:, -self.max_prompt_length :]
707 |
708 | # Generate completions using either vLLM or regular generation
709 | if self.args.use_vllm:
710 | # First, have main process load weights if needed
711 | if self.state.global_step != self._last_loaded_step:
712 | self._move_model_to_vllm()
713 | self._last_loaded_step = self.state.global_step
714 |
715 | # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
716 | all_prompts_text = gather_object(prompts_text)
717 | if self.accelerator.is_main_process:
718 | # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
719 | # num_generations outputs for each one. This is faster than generating outputs for each duplicate
720 | # prompt individually.
721 | ordered_set_of_prompts = list(dict.fromkeys(all_prompts_text))
722 | with profiling_context(self, "vLLM.generate"):
723 | all_outputs = self.llm.generate(
724 | ordered_set_of_prompts, sampling_params=self.sampling_params, use_tqdm=False
725 | )
726 | completion_ids = []
727 | for outputs in all_outputs:
728 | for output in outputs.outputs:
729 | completion_ids.append(output.token_ids)
730 | else:
731 | completion_ids = [None] * len(all_prompts_text)
732 | # Broadcast the completions from the main process to all processes, ensuring each process receives its
733 | # corresponding slice.
734 | completion_ids = broadcast_object_list(completion_ids, from_process=0)
735 | process_slice = slice(
736 | self.accelerator.process_index * len(prompts),
737 | (self.accelerator.process_index + 1) * len(prompts),
738 | )
739 | completion_ids = completion_ids[process_slice]
740 |
741 | # Pad the completions, and concatenate them with the prompts
742 | completion_ids = [torch.tensor(ids, device=device) for ids in completion_ids]
743 | completion_ids = pad(completion_ids, padding_value=self.processing_class.pad_token_id)
744 | prompt_completion_ids = torch.cat([prompt_ids, completion_ids], dim=1)
745 | else:
746 | # Regular generation path
747 | with unwrap_model_for_generation(self.model, self.accelerator) as unwrapped_model:
748 | prompt_completion_ids = unwrapped_model.generate(
749 | **prompt_inputs, generation_config=self.generation_config
750 | )
751 |
752 | # Compute prompt length and extract completion ids
753 | prompt_length = prompt_ids.size(1)
754 | prompt_ids = prompt_completion_ids[:, :prompt_length]
755 | completion_ids = prompt_completion_ids[:, prompt_length:]
756 | prompt_mask = prompt_mask.repeat_interleave(self.num_generations, dim=0)
757 |
758 | # Mask everything after the first EOS token
759 | is_eos = completion_ids == self.processing_class.eos_token_id
760 | eos_idx = torch.full((is_eos.size(0),), is_eos.size(1), dtype=torch.long, device=device)
761 | eos_idx[is_eos.any(dim=1)] = is_eos.int().argmax(dim=1)[is_eos.any(dim=1)]
762 | sequence_indices = torch.arange(is_eos.size(1), device=device).expand(is_eos.size(0), -1)
763 | completion_mask = (sequence_indices <= eos_idx.unsqueeze(1)).int()
764 |
765 | # Concatenate prompt_mask with completion_mask for logit computation
766 | attention_mask = torch.cat([prompt_mask, completion_mask], dim=1) # (B, P+C)
767 | # Repeat image inputs to match batch size after generation
768 | if pixel_values is not None:
769 | pixel_values = pixel_values.repeat_interleave(self.num_generations, dim=0)
770 | if image_grid_thw is not None:
771 | image_grid_thw = image_grid_thw.repeat_interleave(self.num_generations, dim=0)
772 |
773 | logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
774 |
775 | with torch.inference_mode():
776 |             # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its
777 |             # computation here and use per_token_logps.detach() instead.
778 | if self.num_iterations > 1:
779 | old_per_token_logps = self._get_per_token_logps(
780 | self.model, prompt_completion_ids, attention_mask, pixel_values, image_grid_thw, logits_to_keep
781 | )
782 | else:
783 | old_per_token_logps = None
784 |
785 | if self.beta == 0.0:
786 | ref_per_token_logps = None
787 | elif self.ref_model is not None:
788 | ref_per_token_logps = self._get_per_token_logps(
789 | self.ref_model, prompt_completion_ids, attention_mask, pixel_values, image_grid_thw, logits_to_keep
790 | )
791 | else:
792 | with self.accelerator.unwrap_model(self.model).disable_adapter():
793 | ref_per_token_logps = self._get_per_token_logps(
794 | self.model, prompt_completion_ids, attention_mask, pixel_values, image_grid_thw, logits_to_keep
795 | )
796 |
797 | # Decode the generated completions
798 | completions_text = self.processing_class.batch_decode(completion_ids, skip_special_tokens=True)
799 | if is_conversational(inputs[0]):
800 | completions = []
801 | for prompt, completion in zip(prompts, completions_text):
802 | bootstrap = prompt.pop()["content"] if prompt[-1]["role"] == "assistant" else ""
803 | completions.append([{"role": "assistant", "content": bootstrap + completion}])
804 | else:
805 | completions = completions_text
806 |
807 | rewards_per_func = torch.zeros(len(prompts), len(self.reward_funcs), device=device)
808 | for i, (reward_func, reward_processing_class) in enumerate(
809 | zip(self.reward_funcs, self.reward_processing_classes)
810 | ):
811 | if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models
812 | reward_func_name = f"reward {reward_func.config._name_or_path.split('/')[-1]}"
813 | else:
814 | reward_func_name = reward_func.__name__
815 | with profiling_context(self, reward_func_name):
816 | if isinstance(
817 | reward_func, nn.Module
818 | ): # Module instead of PretrainedModel for compat with compiled models
819 | if is_conversational(inputs[0]):
820 | messages = [{"messages": p + c} for p, c in zip(prompts, completions)]
821 | texts = [apply_chat_template(x, reward_processing_class)["text"] for x in messages]
822 | else:
823 | texts = [p + c for p, c in zip(prompts, completions)]
824 | reward_inputs = reward_processing_class(
825 | texts, return_tensors="pt", padding=True, padding_side="right", add_special_tokens=False
826 | )
827 | reward_inputs = super()._prepare_inputs(reward_inputs)
828 | with torch.inference_mode():
829 | rewards_per_func[:, i] = reward_func(**reward_inputs).logits[:, 0] # Shape (B*G,)
830 | else:
831 | # Repeat all input columns (but "prompt" and "completion") to match the number of generations
832 | keys = [key for key in inputs[0] if key not in ["prompt", "completion"]]
833 | reward_kwargs = {key: [example[key] for example in inputs] for key in keys}
834 | output_reward_func = reward_func(prompts=prompts, completions=completions, **reward_kwargs)
835 | rewards_per_func[:, i] = torch.tensor(output_reward_func, dtype=torch.float32, device=device)
836 |
837 | # Gather the reward per function: this part is crucial, because the rewards are normalized per group and the
838 | # completions may be distributed across processes
839 | rewards_per_func = gather(rewards_per_func)
840 |
841 | # Apply weights to each reward function's output and sum
842 | rewards = (rewards_per_func * self.reward_weights.to(device).unsqueeze(0)).sum(dim=1)
843 |
844 | # Compute grouped-wise rewards
845 | mean_grouped_rewards = rewards.view(-1, self.num_generations).mean(dim=1)
846 | std_grouped_rewards = rewards.view(-1, self.num_generations).std(dim=1)
847 |
848 | # Normalize the rewards to compute the advantages
849 | mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
850 | std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
851 | advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
852 |
853 | # Slice to keep only the local part of the data
854 | process_slice = slice(
855 | self.accelerator.process_index * len(prompts),
856 | (self.accelerator.process_index + 1) * len(prompts),
857 | )
858 | advantages = advantages[process_slice]
859 |
860 | # Log the metrics
861 | mode = "eval" if self.control.should_evaluate else "train"
862 |
863 | completion_length = self.accelerator.gather_for_metrics(completion_mask.sum(1)).float().mean().item()
864 | self._metrics[mode]["completion_length"].append(completion_length)
865 |
866 | reward_per_func = rewards_per_func.mean(0)
867 | for i, reward_func in enumerate(self.reward_funcs):
868 | if isinstance(reward_func, nn.Module): # Module instead of PretrainedModel for compat with compiled models
869 | reward_func_name = reward_func.config._name_or_path.split("/")[-1]
870 | else:
871 | reward_func_name = reward_func.__name__
872 | self._metrics[mode][f"rewards/{reward_func_name}"].append(reward_per_func[i].item())
873 |
874 | self._metrics[mode]["reward"].append(rewards.mean().item())
875 | self._metrics[mode]["reward_std"].append(std_grouped_rewards.mean().item())
876 |
877 | if self.log_completions and self.state.global_step % self.args.logging_steps == 0:
878 | prompts_to_log = gather_object(prompts_text)
879 | completions_to_log = gather_object(completions_text)
880 | rewards_to_log = rewards.tolist()
881 |
882 | if self.accelerator.is_main_process:
883 | if is_rich_available():
884 | print_prompt_completions_sample(
885 | prompts_to_log,
886 | completions_to_log,
887 | rewards_to_log,
888 | self.state.global_step,
889 | )
890 | if self.args.report_to and "wandb" in self.args.report_to and wandb.run is not None:
891 | import pandas as pd
892 |
893 | # For logging
894 | table = {
895 | "step": [str(self.state.global_step)] * len(rewards),
896 | "prompt": prompts_to_log,
897 | "completion": completions_to_log,
898 | "reward": rewards.tolist(),
899 | }
900 | df = pd.DataFrame(table)
901 | wandb.log({"completions": wandb.Table(dataframe=df)})
902 |
903 | return {
904 | "prompt_ids": prompt_ids,
905 | "prompt_mask": prompt_mask,
906 | "pixel_values": pixel_values,
907 | "image_grid_thw": image_grid_thw,
908 | "completion_ids": completion_ids,
909 | "completion_mask": completion_mask,
910 | "old_per_token_logps": old_per_token_logps,
911 | "ref_per_token_logps": ref_per_token_logps,
912 | "advantages": advantages,
913 | }
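
    # Worked example of the group-wise advantage normalization above (illustration only):
    # with num_generations=2 and gathered rewards [1.0, 0.0, 0.5, 0.5], the groups are
    # (1.0, 0.0) and (0.5, 0.5):
    #
    #   rewards = torch.tensor([1.0, 0.0, 0.5, 0.5])
    #   mean = rewards.view(-1, 2).mean(dim=1).repeat_interleave(2)  # [0.50, 0.50, 0.50, 0.50]
    #   std = rewards.view(-1, 2).std(dim=1).repeat_interleave(2)    # [0.7071, 0.7071, 0.0, 0.0]
    #   (rewards - mean) / (std + 1e-4)                              # ~[0.707, -0.707, 0.0, 0.0]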
914 |
915 | @profiling_decorator
916 | def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
917 | if return_outputs:
918 | raise ValueError("The GRPOTrainer does not support returning outputs")
919 | # Compute the per-token log probabilities for the model
920 |
921 | prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
922 | completion_ids, completion_mask = inputs["completion_ids"], inputs["completion_mask"]
923 | pixel_values, image_grid_thw = inputs["pixel_values"], inputs["image_grid_thw"]
924 | input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
925 | attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
926 | logits_to_keep = completion_ids.size(1) # we only need to compute the logits for the completion tokens
927 |
928 | per_token_logps = self._get_per_token_logps(
929 | model, input_ids, attention_mask, pixel_values, image_grid_thw, logits_to_keep
930 | )
931 |
932 | # Compute the KL divergence between the model and the reference model
933 | if self.beta != 0.0:
934 | ref_per_token_logps = inputs["ref_per_token_logps"]
935 | per_token_kl = (
936 | torch.exp(ref_per_token_logps - per_token_logps) - (ref_per_token_logps - per_token_logps) - 1
937 | )
938 |
939 | # Compute the loss
940 | advantages = inputs["advantages"]
941 |         # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its
942 |         # computation (see _generate_and_score_completions) and use per_token_logps.detach() instead.
943 | old_per_token_logps = inputs["old_per_token_logps"] if self.num_iterations > 1 else per_token_logps.detach()
944 | coef_1 = torch.exp(per_token_logps - old_per_token_logps)
945 | coef_2 = torch.clamp(coef_1, 1 - self.epsilon, 1 + self.epsilon)
946 | per_token_loss1 = coef_1 * advantages.unsqueeze(1)
947 | per_token_loss2 = coef_2 * advantages.unsqueeze(1)
948 | per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
949 | if self.beta != 0.0:
950 | per_token_loss = per_token_loss + self.beta * per_token_kl
951 | loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
952 |
953 | # Log the metrics
954 | mode = "eval" if self.control.should_evaluate else "train"
955 |
956 | if self.beta != 0.0:
957 | mean_kl = ((per_token_kl * completion_mask).sum(dim=1) / completion_mask.sum(dim=1)).mean()
958 | self._metrics[mode]["kl"].append(self.accelerator.gather_for_metrics(mean_kl).mean().item())
959 |
960 | is_clipped = (per_token_loss1 < per_token_loss2).float()
961 | clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
962 | self._metrics[mode]["clip_ratio"].append(self.accelerator.gather_for_metrics(clip_ratio).mean().item())
963 | return loss
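
    # Numeric check of the clipped surrogate above (editor's illustration): with
    # epsilon=0.2 and a positive advantage, a ratio of 1.5 is clipped to 1.2 while a
    # ratio of 0.5 passes through the min() unchanged:
    #
    #   eps, adv = 0.2, torch.tensor([1.0, 1.0])
    #   coef_1 = torch.tensor([1.5, 0.5])
    #   coef_2 = torch.clamp(coef_1, 1 - eps, 1 + eps)  # [1.2, 0.8]
    #   -torch.min(coef_1 * adv, coef_2 * adv)          # [-1.2, -0.5]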
964 |
965 | def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys: Optional[list[str]] = None):
966 | inputs = self._prepare_inputs(inputs)
967 | with torch.no_grad():
968 | with self.compute_loss_context_manager():
969 | loss = self.compute_loss(model, inputs)
970 | loss = loss.mean().detach()
971 | return loss, None, None
972 |
973 | def log(self, logs: dict[str, float], start_time: Optional[float] = None) -> None:
974 | mode = "eval" if self.control.should_evaluate else "train"
975 | metrics = {key: sum(val) / len(val) for key, val in self._metrics[mode].items()} # average the metrics
976 |
977 | # This method can be called both in training and evaluation. When called in evaluation, the keys in `logs`
978 | # start with "eval_". We need to add the prefix "eval_" to the keys in `metrics` to match the format.
979 | if mode == "eval":
980 | metrics = {f"eval_{key}": val for key, val in metrics.items()}
981 |
982 | logs = {**logs, **metrics}
983 | if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
984 | super().log(logs, start_time)
985 | else: # transformers<=4.46
986 | super().log(logs)
987 | self._metrics[mode].clear()
988 |
989 | def create_model_card(
990 | self,
991 | model_name: Optional[str] = None,
992 | dataset_name: Optional[str] = None,
993 | tags: Union[str, list[str], None] = None,
994 | ):
995 | """
996 | Creates a draft of a model card using the information available to the `Trainer`.
997 |
998 | Args:
999 | model_name (`str` or `None`, *optional*, defaults to `None`):
1000 | Name of the model.
1001 | dataset_name (`str` or `None`, *optional*, defaults to `None`):
1002 | Name of the dataset used for training.
1003 | tags (`str`, `list[str]` or `None`, *optional*, defaults to `None`):
1004 | Tags to be associated with the model card.
1005 | """
1006 | if not self.is_world_process_zero():
1007 | return
1008 |
1009 | if hasattr(self.model.config, "_name_or_path") and not os.path.isdir(self.model.config._name_or_path):
1010 | base_model = self.model.config._name_or_path
1011 | else:
1012 | base_model = None
1013 |
1014 | tags = tags or []
1015 | if isinstance(tags, str):
1016 | tags = [tags]
1017 |
1018 | if hasattr(self.model.config, "unsloth_version"):
1019 | tags.append("unsloth")
1020 |
1021 | citation = textwrap.dedent(
1022 | """\
1023 | @article{zhihong2024deepseekmath,
1024 | title = {{DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}},
1025 | author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Junxiao Song and Mingchuan Zhang and Y. K. Li and Y. Wu and Daya Guo},
1026 | year = 2024,
1027 | eprint = {arXiv:2402.03300},
1028 | }
1029 | """
1030 | )
1031 |
1032 | model_card = generate_model_card(
1033 | base_model=base_model,
1034 | model_name=model_name,
1035 | hub_model_id=self.hub_model_id,
1036 | dataset_name=dataset_name,
1037 | tags=tags,
1038 | wandb_url=wandb.run.get_url() if is_wandb_available() and wandb.run is not None else None,
1039 | comet_url=get_comet_experiment_url(),
1040 | trainer_name="GRPO",
1041 | trainer_citation=citation,
1042 | paper_title="DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models",
1043 | paper_id="2402.03300",
1044 | )
1045 |
1046 | model_card.save(os.path.join(self.args.output_dir, "README.md"))
1047 |
--------------------------------------------------------------------------------
/run_grpo_vlm.sh:
--------------------------------------------------------------------------------
1 | export WANDB_PROJECT="grpo-qwen-2-2B"
2 | export WANDB_ENTITY=""
3 | export WANDB_API_KEY=""
4 | export WANDB_NAME="grpo-qwen-2-2B-v-8gpus-one_reward-synthetic-data-check"
5 |
6 |
7 | RUN_NAME="Qwen2-VL-2B-GRPO"
8 | export LOG_PATH="./debug_log_$RUN_NAME.txt"
9 |
10 | torchrun --nproc_per_node="8" \
11 | --nnodes="1" \
12 | --node_rank="0" \
13 | --master_addr="127.0.0.1" \
14 | --master_port="12346" \
15 | run_r1_grpo_vlm.py \
16 | --deepspeed configs/zero3.json \
17 | --config configs/grpo-qwen-2.5-v.yaml \
18 | --json_data_path "./data/textvqa_cot_train_1_bbox_0.json" \
19 | --report_to wandb \
20 | --gradient_checkpointing false \
21 | --attn_implementation flash_attention_2 \
22 | --num_train_epochs 1 \
23 | --save_only_model true
24 |
--------------------------------------------------------------------------------
/run_r1_grpo_vlm.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import os
3 | from dataclasses import dataclass
4 | from datetime import datetime
5 | import random
6 | import re
7 | import torch
8 | import yaml
9 |
10 | from transformers.trainer_utils import get_last_checkpoint
11 |
12 | from grpo_config import GRPOConfig
13 |
14 |
15 | from transformers import (
16 | Qwen2_5_VLForConditionalGeneration,
17 | Qwen2VLForConditionalGeneration,
18 | AutoTokenizer,
19 | AutoProcessor,
20 | )
21 |
22 | import datasets
23 | from datasets import load_dataset
24 | from torch.utils.data import Dataset
25 |
26 | import sys
27 | import math
28 | from typing import Optional, Tuple
29 |
30 | from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
31 | Qwen2_5_VLVisionFlashAttention2,
32 | apply_rotary_pos_emb_flashatt,
33 | flash_attn_varlen_func,
34 | )
35 |
36 |
37 |
38 |
39 | # FlashAttention fix from https://github.com/om-ai-lab/VLM-R1/blob/main/src/open-r1-multimodal/src/open_r1/grpo_rec.py
40 | def custom_forward(
41 | self,
42 | hidden_states: torch.Tensor,
43 | cu_seqlens: torch.Tensor,
44 | rotary_pos_emb: Optional[torch.Tensor] = None,
45 | position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
46 | ) -> torch.Tensor:
47 | seq_length = hidden_states.shape[0]
48 | q, k, v = self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
49 | if position_embeddings is None:
50 | logger.warning_once(
51 | "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
52 | "through `rotary_pos_emb` (2D tensor of RoPE theta values), to using externally computed "
53 | "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.54 `rotary_pos_emb` will be "
54 | "removed and `position_embeddings` will be mandatory."
55 | )
56 | emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
57 | cos = emb.cos().float()
58 | sin = emb.sin().float()
59 | else:
60 | cos, sin = position_embeddings
61 | # Add this
62 | cos = cos.to(torch.float)
63 | sin = sin.to(torch.float)
64 | q, k = apply_rotary_pos_emb_flashatt(q.unsqueeze(0), k.unsqueeze(0), cos, sin)
65 | q = q.squeeze(0)
66 | k = k.squeeze(0)
67 |
68 | max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
69 | attn_output = flash_attn_varlen_func(q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen).reshape(
70 | seq_length, -1
71 | )
72 | attn_output = self.proj(attn_output)
73 | return attn_output
74 |
75 |
76 | Qwen2_5_VLVisionFlashAttention2.forward = custom_forward
77 |
78 |
79 | from trl import ModelConfig, TrlParser, get_peft_config
80 |
81 | from grpo_qwen2vl import Qwen2VLGRPOTrainer
82 | from PIL import Image
83 |
84 | try:
85 | from math_verify import parse, verify
86 | except ImportError:
87 |
88 | def parse(x):
89 | return x
90 |
91 | def verify(x, y):
92 | return float(x == y)
93 |
94 |
95 | os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
96 |
97 | # Optional: wandb environment
98 | os.environ["WANDB_PROJECT"] = "grpo-qwen-2-2B"
99 | os.environ["WANDB_ENTITY"] = ""
100 | os.environ["WANDB_API_KEY"] = ""  # set locally; never commit a real API key
101 | os.environ["WANDB_NAME"] = "grpo-qwen-2-2B-v-8gpus-one_reward-synthetic-data-check"
102 |
103 |
104 | @dataclass
105 | class ScriptArguments:
106 | dataset_id_or_path: str = "Jiayi-Pan/Countdown-Tasks-3to4"
107 | dataset_splits: str = "train"
108 |     tokenizer_name_or_path: Optional[str] = None
109 | max_pixels: int = 12845056
110 | min_pixels: int = 3136
111 | image_root: str = ""
112 | config_path: str = "config.yaml"
113 | json_data_path: str = ""
114 |
115 |
116 | logging.basicConfig(level=logging.INFO)
117 | logger = logging.getLogger(__name__)
118 | logger.setLevel(logging.INFO)
119 | handler = logging.StreamHandler()
120 | handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
121 | logger.addHandler(handler)
122 |
123 |
124 | def relaxed_bbox_iou(pred_bbox, gt_bbox, threshold=0.5):
125 | x1 = max(pred_bbox[0], gt_bbox[0])
126 | y1 = max(pred_bbox[1], gt_bbox[1])
127 | x2 = min(pred_bbox[2], gt_bbox[2])
128 | y2 = min(pred_bbox[3], gt_bbox[3])
129 |
130 | intersection = max(0, x2 - x1) * max(0, y2 - y1)
131 | pred_area = (pred_bbox[2] - pred_bbox[0]) * (pred_bbox[3] - pred_bbox[1])
132 | gt_area = (gt_bbox[2] - gt_bbox[0]) * (gt_bbox[3] - gt_bbox[1])
133 | union = pred_area + gt_area - intersection
134 | iou = intersection / union if union > 0 else 0
135 |
136 | return iou >= threshold
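
# Quick sanity checks for relaxed_bbox_iou (editor's examples): identical boxes give
# IoU 1.0; two 2x2 boxes overlapping in a 1x1 region give IoU 1/7 ~= 0.14 < 0.5.
#
#   assert relaxed_bbox_iou([0, 0, 2, 2], [0, 0, 2, 2])
#   assert not relaxed_bbox_iou([0, 0, 2, 2], [1, 1, 3, 3])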
137 |
138 |
139 | def relaxed_correctness(prediction: str, target: str, max_relative_change: float = 0.05) -> bool:
140 | def extract_first_number(text: str):
141 | match = re.search(r"[-]?\d[\d,]*(\.\d+)?%?", text)
142 | return match.group(0) if match else text
143 |
144 | def to_float(text: str):
145 | text = text.strip().lower()
146 | text = text.replace(",", "")
147 | if text.endswith("%"):
148 | try:
149 | val = float(text.rstrip("%"))
150 | return val / 100.0
151 | except ValueError:
152 | return None
153 | else:
154 | try:
155 | return float(text)
156 | except ValueError:
157 | return None
158 |
159 | pred_num_str = extract_first_number(prediction)
160 | tgt_num_str = extract_first_number(target)
161 | pred_float = to_float(pred_num_str)
162 | tgt_float = to_float(tgt_num_str)
163 | if pred_float is not None and tgt_float is not None:
164 | if abs(tgt_float) < 1e-12:
165 | return abs(pred_float - tgt_float) < 1e-12
166 | relative_change = abs(pred_float - tgt_float) / abs(tgt_float)
167 | return relative_change <= max_relative_change
168 | return prediction.strip().lower() == target.strip().lower()
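
# Behavior examples for relaxed_correctness (editor's illustration): percentages are
# normalized to fractions and a 5% relative tolerance is applied to numeric answers.
#
#   assert relaxed_correctness("The answer is 98", "100")      # |98 - 100| / 100 = 0.02 <= 0.05
#   assert not relaxed_correctness("The answer is 90", "100")  # 0.10 > 0.05
#   assert relaxed_correctness("5%", "0.05")                   # "5%" -> 0.05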
169 |
170 |
171 | def grounding_reward(completions, target, **kwargs):
172 |     """
173 |     Binary reward: 1.0 if the bbox parsed from the <answer> tags overlaps the ground-truth bbox with IoU >= 0.5.
174 |     """
175 | completions = [c[0]["content"] for c in completions]
176 | rewards = []
177 | for completion, gt_bbox_list in zip(completions, target):
178 | try:
179 |             bbox_match = re.search(r"<answer>.*?\[(.*?)\].*?</answer>", completion, re.DOTALL)
180 | if bbox_match:
181 | pred_bbox = [float(x.strip()) for x in bbox_match.group(1).split(",")]
182 | gt_bbox = [float(x) for x in gt_bbox_list[0]] # given your data format
183 | reward = 1.0 if relaxed_bbox_iou(pred_bbox, gt_bbox) else 0.0
184 | else:
185 | reward = 0.0
186 | except Exception:
187 | reward = 0.0
188 |
189 | rewards.append(reward)
190 | # Log a small fraction
191 | if random.random() < 0.1:
192 | os.makedirs("completion_samples", exist_ok=True)
193 | log_file = os.path.join("completion_samples", "grounding_samples.txt")
194 | with open(log_file, "a") as f:
195 | f.write(f"------------- Grounding reward: {reward} -------------\n")
196 | f.write(f"Content: {completion}\n")
197 | f.write(f"GT bbox: {gt_bbox_list}\n")
198 |
199 | return rewards
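
# Hypothetical call showing the expected shapes (conversational completions, one list
# of ground-truth boxes per sample):
#
#   completions = [[{"role": "assistant",
#                    "content": "<think>locate the sign</think><answer>[10, 10, 50, 50]</answer>"}]]
#   target = [[[10, 10, 50, 50]]]
#   grounding_reward(completions, target)  # -> [1.0]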
200 |
201 |
202 | def format_reward(completions, **kwargs):
203 |     """A simpler check for <think>...</think> and <answer>...</answer> formatting (not checking correctness)."""
204 |     completions = [c[0]["content"] for c in completions]
205 |     completions = ["<think>" + c if not c.startswith("<think>") else c for c in completions]
206 |     pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
207 | matches = [re.fullmatch(pattern, c, re.DOTALL) for c in completions]
208 | return [1.0 if m else 0.0 for m in matches]
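
# Format-check illustration (editor's example):
#
#   good = [[{"role": "assistant", "content": "<think>reason</think> <answer>[1, 2, 3, 4]</answer>"}]]
#   bad = [[{"role": "assistant", "content": "just an answer"}]]
#   format_reward(good)  # -> [1.0]
#   format_reward(bad)   # -> [0.0]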
209 |
210 |
211 | def get_checkpoint(training_args):
212 | if os.path.isdir(training_args.output_dir):
213 | return get_last_checkpoint(training_args.output_dir)
214 | return None
215 |
216 |
217 | class CustomDataset(Dataset):
218 |     def __init__(self, list_data_dict=None, script_args=None, processor=None, json_data_path=None):
219 |         super().__init__()
220 | self.script_args = script_args
221 | self.processor = processor
222 |         self.SYSTEM_PROMPT = (
223 |             "You are a Vision Language Model specialized in visual grounding in <answer> [x1, y1, x2, y2] </answer>."
224 |         )
225 |         self.QUESTION_TEMPLATE = "Provide bounding box for the region of the image relevant to the asked question: {Question}. First output the thinking process in <think> </think> tags and then output the final answer in <answer> [x1, y1, x2, y2] </answer> tags."
226 |
227 | # Load data from JSON file if path is provided
228 | if json_data_path and os.path.exists(json_data_path):
229 | import json
230 |
231 | with open(json_data_path, "r") as f:
232 | data = json.load(f)
233 | self.list_data_dict = data
234 | else:
235 | self.list_data_dict = list_data_dict or []
236 |
237 | def __len__(self):
238 | return len(self.list_data_dict)
239 |
240 | def __getitem__(self, i):
241 | # Format into conversation
242 | def make_conversation(example):
243 | return {
244 | "prompt": [
245 | {"role": "system", "content": self.SYSTEM_PROMPT},
246 | {"role": "user", "content": example["question"]},
247 | ],
248 | }
249 |
250 | def make_conversation_image(example):
251 | return {
252 | "prompt": [
253 | {
254 | "role": "user",
255 | "content": [
256 | {"type": "image"},
257 | {
258 | "type": "text",
259 | "text": self.QUESTION_TEMPLATE.format(Question=example["question"])
260 |                             + ' <answer> {"bbox": [x1, y1, x2, y2]} </answer> tags, '
261 | + "where x1, y1, x2, y2 are integers representing the coordinates.",
262 | },
263 | ],
264 | },
265 | ],
266 | }
267 |
268 | example = self.list_data_dict[i]
269 | image_root = self.script_args.image_root
270 | if "image" in example:
271 | image_path = os.path.join(image_root, example["image"])
272 | # In case the image is not found
273 | while not os.path.exists(image_path):
274 | print(f"Warning: Image {image_path} not found, randomly selecting another image")
275 | new_index = random.randint(0, len(self.list_data_dict) - 1)
276 | example = self.list_data_dict[new_index]
277 | image_path = os.path.join(image_root, example["image"])
278 | image = Image.open(image_path).convert("RGB")
279 |
280 | # Resize image if needed to meet min/max size requirements
281 | w, h = image.size
282 |
283 | if w < 28 or h < 28:
284 | # Calculate new dimensions maintaining aspect ratio for small images
285 | if w < h:
286 | new_w = 28
287 | new_h = int(h * (28 / w))
288 | else:
289 | new_h = 28
290 | new_w = int(w * (28 / h))
291 | image = image.resize((new_w, new_h), Image.LANCZOS)
292 | elif w > 512 or h > 512:
293 | # Calculate new dimensions maintaining aspect ratio for large images
294 | if w > h:
295 | new_w = 512
296 | new_h = int(h * (512 / w))
297 | else:
298 | new_h = 512
299 | new_w = int(w * (512 / h))
300 | image = image.resize((new_w, new_h), Image.LANCZOS)
301 | else:
302 | # Image is within acceptable dimensions, no resize needed
303 | new_w, new_h = w, h
304 | else:
305 | image = None
306 | # print("Image size", image.size)
307 | return {
308 | "image": image,
309 | "question": example["question"],
310 | "target": example["bboxs"],
311 | "prompt": (
312 | make_conversation_image(example)["prompt"]
313 | if "image" in example
314 | else make_conversation(example)["prompt"]
315 | ),
316 | }
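
# Hypothetical usage sketch (paths and script_args are placeholders): each item exposes
# the keys the trainer consumes ("image", "prompt", "target"):
#
#   ds = CustomDataset(script_args=script_args,
#                      json_data_path="./data/textvqa_cot_train_1_bbox_0.json")
#   sample = ds[0]
#   sample["prompt"]  # chat-format messages; the user turn carries the image + question
#   sample["target"]  # ground-truth bboxes, e.g. [[x1, y1, x2, y2]]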
317 |
318 |
319 | def grpo_function(model_args: ModelConfig, script_args: ScriptArguments, training_args: GRPOConfig):
320 | logger.info(f"Model parameters {model_args}")
321 | logger.info(f"Training/evaluation parameters {training_args}")
322 |
323 | model_name = model_args.model_name_or_path
324 | tokenizer_path = script_args.tokenizer_name_or_path if script_args.tokenizer_name_or_path else model_name
325 |
326 |     # AutoProcessor resolves the correct processor class for both Qwen2-VL and
327 |     # Qwen2.5-VL checkpoints, so a single load path covers either model family.
328 |     processor = AutoProcessor.from_pretrained(
329 |         tokenizer_path,
330 |         trust_remote_code=model_args.trust_remote_code,
331 |         revision=model_args.model_revision,
332 |     )
339 |
340 | # Create CustomDataset instances
341 | train_dataset = CustomDataset(
342 | script_args=script_args,
343 | processor=processor,
344 | json_data_path=script_args.json_data_path,
345 | )
346 | print(f"Created datasets with {len(train_dataset)} training examples")
347 | print(f"Sample example: {train_dataset[0]}")
348 |
349 | # Choose your reward functions
350 | chosen_reward_funcs = [grounding_reward, format_reward]
351 |
352 | trainer = Qwen2VLGRPOTrainer(
353 | model=model_args.model_name_or_path,
354 | reward_funcs=chosen_reward_funcs,
355 | args=training_args,
356 | train_dataset=train_dataset,
357 | peft_config=get_peft_config(model_args),
358 | attn_implementation=model_args.attn_implementation,
359 | max_pixels=script_args.max_pixels,
360 | min_pixels=script_args.min_pixels,
361 | torch_dtype=model_args.torch_dtype,
362 | )
363 |
364 | last_checkpoint = get_checkpoint(training_args)
365 | if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
366 | logger.info(f"Resuming from checkpoint at {last_checkpoint}.")
367 |
368 | logger.info("*** Starting training ***")
369 | train_result = trainer.train(resume_from_checkpoint=last_checkpoint)
370 | metrics = train_result.metrics
371 | metrics["train_samples"] = len(train_dataset)
372 | trainer.log_metrics("train", metrics)
373 | trainer.save_metrics("train", metrics)
374 | trainer.save_state()
375 |
376 | logger.info("*** Training complete ***")
377 | logger.info("*** Save model ***")
378 | trainer.model.config.use_cache = True
379 | trainer.save_model(training_args.output_dir)
380 | logger.info(f"Model saved to {training_args.output_dir}")
381 |
382 | training_args.distributed_state.wait_for_everyone()
383 | processor.save_pretrained(training_args.output_dir)
384 | logger.info(f"Processor saved to {training_args.output_dir}")
385 |
386 | if trainer.accelerator.is_main_process:
387 |         trainer.create_model_card(tags=["rl", "grpo", "tutorial", "philschmid"])
388 |
389 | if training_args.push_to_hub:
390 | logger.info("Pushing to hub...")
391 | trainer.push_to_hub()
392 |
393 | logger.info("*** Done ***")
394 |
395 |
396 | def main():
397 |
398 |
399 | parser = TrlParser((ModelConfig, ScriptArguments, GRPOConfig))
400 | model_args, script_args, training_args = parser.parse_args_and_config()
401 |
402 | # Load config from config file if not set directly via command line
403 | if not script_args.json_data_path and script_args.config_path and os.path.exists(script_args.config_path):
404 | with open(script_args.config_path, "r") as f:
405 | config = yaml.safe_load(f)
406 | if config and "json_data_path" in config:
407 | script_args.json_data_path = config.get("json_data_path", "")
408 |
409 | grpo_function(model_args, script_args, training_args)
410 |
411 |
412 | if __name__ == "__main__":
413 | main()
414 |
--------------------------------------------------------------------------------