├── LICENSE
├── README.md
├── assets
│   ├── arch.png
│   └── vis.png
├── data_tools
│   ├── example.jpg
│   ├── example.json
│   └── prompt_hub.py
├── eval_tools
│   ├── qwen2vl_plan_cmd_eval.sh
│   ├── qwen2vl_plan_cmd_eval_grpo.py
│   └── qwen2vl_plan_cmd_eval_sft.py
├── setup.sh
├── src
│   ├── distill_r1
│   │   ├── README.md
│   │   ├── create_hf_dataset.py
│   │   ├── filter_r1.py
│   │   ├── generate_scene_qa_pairs.ipynb
│   │   ├── grpo_r1_distilled.jpg
│   │   ├── prompt.py
│   │   └── query_r1.py
│   ├── eval
│   │   ├── logs
│   │   │   ├── counting_results_superclevr_200_qwen2vl_2b_instruct_grpo100_legacy.json
│   │   │   ├── counting_results_superclevr_200_qwen2vl_2b_instruct_legacy.json
│   │   │   ├── geoqa_test_qwen2vl_7b_grpo_2epochs_legacy.json
│   │   │   └── geoqa_test_qwen2vl_7b_instruct_legacy.json
│   │   ├── prompts
│   │   │   ├── geoqa_test_prompts.jsonl
│   │   │   └── superclevr_test200_counting_problems.jsonl
│   │   ├── test_qwen2vl_counting_superclevr.py
│   │   ├── test_qwen2vl_geoqa.py
│   │   └── test_qwen2vl_geoqa_multigpu.py
│   ├── r1-v
│   │   ├── .gitignore
│   │   ├── LICENSE
│   │   ├── Makefile
│   │   ├── configs
│   │   │   ├── ddp.yaml
│   │   │   ├── qwen2vl_sft_config.yaml
│   │   │   ├── zero2.yaml
│   │   │   └── zero3.yaml
│   │   ├── local_scripts
│   │   │   ├── create_vision_cot_data.py
│   │   │   ├── lmms_eval_qwen2vl.sh
│   │   │   ├── prepare_hf_data.py
│   │   │   ├── train_aria_moe.sh
│   │   │   ├── train_qwen2_vl.sh
│   │   │   ├── zero2.json
│   │   │   ├── zero3.json
│   │   │   ├── zero3.yaml
│   │   │   └── zero3_offload.json
│   │   ├── run_grpo.sh
│   │   ├── setup.cfg
│   │   ├── setup.py
│   │   ├── src
│   │   │   ├── __init__.py
│   │   │   └── open_r1
│   │   │       ├── __init__.py
│   │   │       ├── evaluate.py
│   │   │       ├── generate.py
│   │   │       ├── grpo.py
│   │   │       ├── sft.py
│   │   │       └── trainer
│   │   │           ├── __init__.py
│   │   │           ├── grpo_trainer.py
│   │   │           └── vllm_grpo_trainer.py
│   │   └── temp_image.png
│   └── scripts
│       ├── run_grpo_clevr.sh
│       ├── run_grpo_vllm.sh
│       ├── run_sft_clevr.sh
│       └── test_grpo_geoqa_multigpu.sh
└── train_tools
    ├── run_train_grpo.sh
    └── run_train_sft.sh
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # 🌌 AlphaDrive: Unleashing the Power of VLMs in Autonomous Driving via Reinforcement Learning and Reasoning
4 |
5 |
6 |
7 | [Bo Jiang](https://scholar.google.com/citations?user=UlDxGP0AAAAJ&hl=zh-CN)1, [Shaoyu Chen](https://scholar.google.com/citations?user=PIeNN2gAAAAJ&hl=en&oi=sra)1,2, [Qian Zhang](https://scholar.google.com/citations?user=pCY-bikAAAAJ&hl=zh-CN)2, [Wenyu Liu](http://eic.hust.edu.cn/professor/liuwenyu/)1, [Xinggang Wang](https://xwcv.github.io/)1,📧
8 |
9 | 1 Huazhong University of Science and Technology,
10 | 2 Horizon Robotics,
11 | 📧 corresponding author
12 |
13 |
14 | [arXiv Paper](https://arxiv.org/abs/2503.07608)
15 | [License: Apache 2.0](https://opensource.org/licenses/Apache-2.0)
16 |
17 |
18 |
19 |
20 |

21 |
22 |
23 | https://github.com/user-attachments/assets/71695178-90ca-4f5c-acc2-ab2e13c63e43
24 |
25 |
26 | ## ✨ Highlights
27 |
28 | * To the best of our knowledge, AlphaDrive is the first to integrate GRPO-based RL with planning reasoning into autonomous driving, significantly boosting both performance and training efficiency.
29 |
30 | * We are excited to discover that, following RL training, AlphaDrive exhibits some emergent multimodal planning capabilities, which is promising for improving driving safety and efficiency.
31 |
32 |
33 | ## 📋 News
34 |
35 | `[2025-3-26]:` We have released the training and evaluation scripts of AlphaDrive.
36 |
37 | `[2025-3-11]:` AlphaDrive [arXiv](https://arxiv.org/abs/2503.07608) paper released. Code is coming soon. Please stay tuned! ☕️
38 |
39 |
40 | ## 🎮 Getting Started
41 | ### Installation
42 | ```shell
43 | git clone git@github.com:hustvl/AlphaDrive.git
44 | conda create -n alphadrive python=3.11 -y
45 | conda activate alphadrive
46 | sh setup.sh
47 | ```
48 |
49 | ### Data Preparation
50 | We provide the [prompt templates](https://github.com/hustvl/AlphaDrive/blob/main/data_tools/prompt_hub.py) used in AlphaDrive for training and for generating planning reasoning data; an example QA pair is provided in [example.json](https://github.com/hustvl/AlphaDrive/blob/main/data_tools/example.json).
51 |
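For reference, below is a minimal sketch (not part of the repo) of how a training prompt can be assembled from these templates, assuming the repo root is on `PYTHONPATH` and reusing the fields from `example.json`:

```python
# Minimal sketch: fill the two placeholders of meta_action_prompt
# (current speed, navigation command) from data_tools/prompt_hub.py.
from data_tools.prompt_hub import meta_action_prompt

speed = 4.2                                         # ego speed in m/s
nav_cmd = {'100m': 'Straight', '200m': 'Straight'}  # navigation command

prompt = meta_action_prompt.format(speed, nav_cmd)
print(prompt)
```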
52 |
53 | ### Training
54 | For Supervised Fine-tuning Phase:
55 | ```shell
56 | sh train_tools/run_train_sft.sh
57 | ```
58 |
59 | For Reinforcement Learning Phase:
60 | ```shell
61 | sh train_tools/run_train_grpo.sh
62 | ```
63 |
64 | ### Evaluation
65 | You can evaluate the meta-action planning accuracy using the script below.
66 | ```shell
67 | sh eval_tools/qwen2vl_plan_cmd_eval.sh
68 | ```
69 |
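A prediction is counted as correct only when both the ground-truth SPEED and PATH keywords appear in the decoded output; a condensed sketch of that check, mirroring `eval_tools/qwen2vl_plan_cmd_eval_grpo.py`:

```python
# Condensed sketch of the per-sample correctness check used by the eval scripts.
def is_correct(answer: str, solution: str) -> bool:
    gt = solution.replace('<answer> ', '').replace(' </answer>', '')
    speed_plan, path_plan = gt.split(', ')   # e.g. "ACCELERATE", "STRAIGHT"
    path_plan = path_plan.split('\n')[0]
    return speed_plan in answer and path_plan in answer
```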
70 |
71 | ## 📊 Qualitative Results
72 |
73 |
74 |
75 |

76 |
77 |
78 |
79 | ## ❤️ Acknowledgements
80 |
81 | This repo is built on [open-r1](https://github.com/huggingface/open-r1) and [R1-V](https://github.com/Deep-Agent/R1-V). We sincerely thank the contributors for their great work!
82 |
83 | ## 📚 Citation
84 | If you find AlphaDrive useful in your research or applications, please consider giving us a star 🌟 and citing it by the following BibTeX entry.
85 |
86 |
87 | ```bibtex
88 | @article{jiang2025alphadrive,
89 | title={AlphaDrive: Unleashing the Power of VLMs in Autonomous Driving via Reinforcement Learning and Reasoning},
90 | author={Bo Jiang and Shaoyu Chen and Qian Zhang and Wenyu Liu and Xinggang Wang},
91 | year={2025},
92 | eprint={2503.07608},
93 | archivePrefix={arXiv},
94 | primaryClass={cs.CV},
95 | url={https://arxiv.org/abs/2503.07608},
96 | }
97 | ```
98 |
99 | ## 🥰 Related Projects
100 | Check out our other awesome projects:
101 |
102 | [VAD & VADv2](https://github.com/hustvl/VAD): Vectorized Scene Representation for Efficient Autonomous Driving.
103 |
104 | [Senna](https://github.com/hustvl/Senna): Bridging Large Vision-Language Models and End-to-End Autonomous Driving.
105 |
106 | [DiffusionDrive](https://github.com/hustvl/DiffusionDrive): Truncated Diffusion Model for End-to-End Autonomous Driving.
107 |
108 | [RAD](https://hgao-cv.github.io/RAD/): Training an End-to-End Driving Policy via Large-Scale 3DGS-based Reinforcement Learning.
109 |
110 | [MapTR](https://github.com/hustvl/MapTR): An End-to-End Framework for Online Vectorized HD Map Construction.
111 |
--------------------------------------------------------------------------------
/assets/arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hustvl/AlphaDrive/bb4104c680a3e4d70e7e998e08de90f6e0acf8c5/assets/arch.png
--------------------------------------------------------------------------------
/assets/vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hustvl/AlphaDrive/bb4104c680a3e4d70e7e998e08de90f6e0acf8c5/assets/vis.png
--------------------------------------------------------------------------------
/data_tools/example.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hustvl/AlphaDrive/bb4104c680a3e4d70e7e998e08de90f6e0acf8c5/data_tools/example.jpg
--------------------------------------------------------------------------------
/data_tools/example.json:
--------------------------------------------------------------------------------
1 | {
2 | "image": "example.jpg",
3 |   "problem": "You are an expert driving assistant. Your current speed is 4.2m/s, the navigation command is '{'100m': 'Straight', '200m': 'Straight'}', based on the understanding of the driving scene and the navigation information, what is your driving plan for the next three seconds? Output the planning reasoning process in <think> and final planning answer in <answer> tags, respectively. Planning answer consists of SPEED plan and PATH plan, SPEED includes KEEP, ACCELERATE, DECELERATE, and STOP. PATH includes STRAIGHT, RIGHT_CHANGE, LEFT_CHANGE, RIGHT_TURN, and LEFT_TURN. For example, a correct answer format is like '<think> planning reasoning process here </think><answer> KEEP, LEFT_TURN </answer>'.",
4 |   "solution": "<answer> ACCELERATE, STRAIGHT </answer>",
5 |   "thinking": "<think> The driving decision to accelerate and go straight is based on the navigation command indicating that the next 100 meters should be driven straight, and there appears to be no immediate obstacles that would require a change in direction or speed. The road ahead is clear, and the light is green, allowing for safe acceleration to pass the intersection. </think>"
6 | }
--------------------------------------------------------------------------------
/data_tools/prompt_hub.py:
--------------------------------------------------------------------------------
1 | meta_action_prompt = """
2 | You are an expert driving assistant. \
3 | Your current speed is {}m/s, the navigation command is '{}', \
4 | based on the understanding of the driving scene and the navigation information, \
5 | what is your driving plan for the next three seconds? \
6 | Output the planning reasoning process in <think> and final planning answer in <answer> tags, respectively. \
7 | Planning answer consists of SPEED plan and PATH plan, SPEED includes KEEP, ACCELERATE, DECELERATE, and STOP. \
8 | PATH includes STRAIGHT, RIGHT_CHANGE, LEFT_CHANGE, RIGHT_TURN, and LEFT_TURN. \
9 | For example, a correct answer format is like '<think> planning reasoning process here </think><answer> KEEP, LEFT_TURN </answer>'.
10 | """
11 |
12 | plan_reason_prompt = """
13 | You are an expert driving assistant. \
14 | Your current speed is {}m/s, the navigation command is '{}', \
15 | and your driving decision for the next three seconds is '{}'. \
16 | Based on your understanding of the driving scene, briefly explain why you made the above driving decisions in one or two sentences.
17 | """
18 |
--------------------------------------------------------------------------------
/eval_tools/qwen2vl_plan_cmd_eval.sh:
--------------------------------------------------------------------------------
1 | export ENV_NAME="/path/to/your/python/env"
2 | export OUTDIR="/path/to/your/output/directory"
3 | export OUT_NAME="Qwen2-VL-2B-EXP"
4 | export EVAL_DATA="/path/to/your/val/data"
5 | export EVAL_SAVE_NAME="eval_result.json"
6 |
7 | echo "Validation Process..."
8 | $ENV_NAME/bin/python eval_tools/qwen2vl_plan_cmd_eval_grpo.py \
9 | --eval_data_path $EVAL_DATA \
10 | --model_path $OUTDIR/$OUT_NAME \
11 | --save_path $OUTDIR/$OUT_NAME/$EVAL_SAVE_NAME
--------------------------------------------------------------------------------
/eval_tools/qwen2vl_plan_cmd_eval_grpo.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import copy
3 | import json
4 | from PIL import Image
5 | import time
6 |
7 | import torch
8 | from tqdm import tqdm
9 | from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor, GenerationConfig
10 | from qwen_vl_utils import process_vision_info
11 | from datasets import load_dataset
12 | import numpy as np
13 | from trl.data_utils import maybe_apply_chat_template
14 |
15 |
16 | def parse_args():
17 | parser = argparse.ArgumentParser(description="Evaluate Qwen2-VL model on validation dataset")
18 | parser.add_argument("--eval_data_path", type=str, required=True, help="Path to evaluation dataset")
19 | parser.add_argument("--model_path", type=str, required=True, help="Path to the pretrained model")
20 | parser.add_argument("--save_path", type=str, required=True, help="Path to save evaluation results")
21 | return parser.parse_args()
22 |
23 |
24 | def main():
25 |
26 | args = parse_args()
27 | model_path = args.model_path
28 | eval_data_path = args.eval_data_path
29 | save_path = args.save_path
30 |
31 | # We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
32 | model = Qwen2VLForConditionalGeneration.from_pretrained(
33 | model_path,
34 | torch_dtype=torch.bfloat16,
35 | attn_implementation="flash_attention_2",
36 | device_map="auto",
37 | )
38 | processor = AutoProcessor.from_pretrained(model_path)
39 |
40 |
41 | eval_dataset = load_dataset(eval_data_path)
42 |
43 |
44 | tot_num, correct_num = 0, 0
45 |
46 | SPEED_PLAN = ['KEEP', 'ACCELERATE', 'DECELERATE', 'STOP']
47 | PATH_PLAN = ['RIGHT_TURN', 'RIGHT_CHANGE', 'LEFT_TURN', 'LEFT_CHANGE', 'STRAIGHT']
48 |
49 | metric_tot_cnt = {speed + '_' + path: 0 for speed in SPEED_PLAN for path in PATH_PLAN}
50 | metric_correct_cnt = copy.deepcopy(metric_tot_cnt)
51 |
52 | eval_record = {}
53 |
54 | speed_tp = {
55 | 'KEEP': 0,
56 | 'ACCELERATE': 0,
57 | 'DECELERATE': 0,
58 | 'STOP': 0,
59 | }
60 |
61 | speed_fp, speed_fn = copy.deepcopy(speed_tp), copy.deepcopy(speed_tp)
62 |
63 | path_tp = {
64 | 'RIGHT_TURN': 0,
65 | 'LEFT_TURN': 0,
66 | 'RIGHT_CHANGE': 0,
67 | 'LEFT_CHANGE': 0,
68 | 'STRAIGHT': 0,
69 | }
70 |
71 | path_fp, path_fn = copy.deepcopy(path_tp), copy.deepcopy(path_tp)
72 |
73 | f1_score = {
74 | 'KEEP': 0,
75 | 'ACCELERATE': 0,
76 | 'DECELERATE': 0,
77 | 'STOP': 0,
78 | 'RIGHT_TURN': 0,
79 | 'LEFT_TURN': 0,
80 | 'RIGHT_CHANGE': 0,
81 | 'LEFT_CHANGE': 0,
82 | 'STRAIGHT': 0,
83 | }
84 |
85 | generation_config = GenerationConfig(
86 | max_new_tokens=1024,
87 | do_sample=True,
88 | temperature=1, # HACK
89 | num_return_sequences=2,
90 | pad_token_id=151643,
91 | )
92 |
93 | for sample in tqdm(eval_dataset['validation']):
94 |
95 | tot_num = tot_num + 1
96 |
97 | gt_answer = sample['solution']
98 | text = sample['problem']
99 |
100 | inputs = [{
101 | 'prompt':[
102 | {
103 | 'content': [
104 | {'text':None, 'type':'image'},
105 | {'text':text, 'type':'text'},
106 | ],
107 | 'role': 'user'
108 | }
109 | ]
110 | }]
111 |
112 | prompts_text = [maybe_apply_chat_template(example, processor)["prompt"] for example in inputs]
113 | images, _ = process_vision_info(inputs[0]['prompt'])
114 |
115 | prompt_inputs = processor(
116 | text=prompts_text,
117 | images=images,
118 | return_tensors="pt",
119 | padding=True,
120 | padding_side="left",
121 | add_special_tokens=False,
122 | )
123 |
124 |
125 | prompt_inputs = prompt_inputs.to("cuda")
126 |
127 | prompt_ids, prompt_mask = prompt_inputs["input_ids"], prompt_inputs["attention_mask"]
128 |
129 | # Inference: Generation of the output
130 | prompt_completion_ids = model.generate(**prompt_inputs, generation_config=generation_config)
131 | prompt_length = prompt_ids.size(1)
132 | completion_ids = prompt_completion_ids[:, prompt_length:]
133 | answers = processor.batch_decode(completion_ids, skip_special_tokens=True)
134 | answer = answers[0].strip()
135 |
136 |         gt_answer = gt_answer.replace('<answer> ', '')
137 |         gt_answer = gt_answer.replace(' </answer>', '')
138 | speed_plan, path_plan = gt_answer.split(', ')
139 | path_plan = path_plan.split('\n')[0]
140 |
141 | for key in speed_tp.keys():
142 | if key in answer: # P
143 | if speed_plan in answer:
144 | speed_tp[key] += 1 # TP
145 | else:
146 | speed_fp[key] += 1 # FP
147 | else: # N
148 | if key in speed_plan:
149 | speed_fn[key] += 1 # FN
150 |
151 | for key in path_tp.keys():
152 | if key in answer: # P
153 | if path_plan in answer:
154 | path_tp[key] += 1 # TP
155 | else:
156 | path_fp[key] += 1 # FP
157 | else: # N
158 | if key in path_plan:
159 | path_fn[key] += 1 # FN
160 |
161 |
162 | metric_tot_cnt[speed_plan+'_'+path_plan] += 1
163 | if speed_plan in answer and path_plan in answer:
164 | correct_num = correct_num + 1
165 | metric_correct_cnt[speed_plan+'_'+path_plan] += 1
166 | else:
167 | fail_case = {
168 | 'gt': gt_answer,
169 | 'pred': answer,
170 | }
171 |
172 |
173 | for key in f1_score.keys():
174 | if key in speed_tp.keys():
175 | if speed_tp[key] + speed_fp[key] != 0:
176 | precision = speed_tp[key] / (speed_tp[key] + speed_fp[key])
177 | else:
178 | precision = 0
179 | if speed_tp[key] + speed_fn[key] != 0:
180 | recall = speed_tp[key] / (speed_tp[key] + speed_fn[key])
181 | else:
182 | recall = 0
183 | if precision + recall != 0:
184 | f1_score[key] = 2.0 * precision * recall / (precision + recall)
185 | else:
186 | f1_score[key] = 0
187 | if key in path_tp.keys():
188 |             if path_tp[key] + path_fp[key] != 0:
189 | precision = path_tp[key] / (path_tp[key] + path_fp[key])
190 | else:
191 | precision = 0
192 | if path_tp[key] + path_fn[key] != 0:
193 | recall = path_tp[key] / (path_tp[key] + path_fn[key])
194 | else:
195 | recall = 0
196 | if precision + recall != 0:
197 | f1_score[key] = 2.0 * precision * recall / (precision + recall)
198 | else:
199 | f1_score[key] = 0
200 |
201 | print("\n\n=========== F1 Score ===========\n\n")
202 | for k, v in f1_score.items():
203 | print(f"{k}: {v}")
204 | print("\n\n================================\n\n")
205 |
206 | print(f'\nTotal Number: {tot_num}\n')
207 | print(f'\nCorrect Number: {correct_num}\n')
208 |
209 | print('\n------------------------------\n\n')
210 | print(f"Planning Accuracy: {correct_num/tot_num * 100:.2f}%")
211 | print('\n\n------------------------------\n')
212 |
213 | for key in metric_tot_cnt.keys():
214 | if metric_tot_cnt[key] > 0:
215 | print(f"{key}: num: {metric_tot_cnt[key]}, correct num: {metric_correct_cnt[key]}, {100*metric_correct_cnt[key]/metric_tot_cnt[key]:.2f}%")
216 |
217 | eval_record['summary'] = f'Total Number: {tot_num}'
218 | eval_record['summary'] = eval_record['summary'] + '\n' + f'Correct Number: {correct_num}'
219 | eval_record['summary'] = eval_record['summary'] + '\n' + f"Planning Accuracy: {correct_num/tot_num * 100:.2f}%"
220 |
221 | for key in metric_tot_cnt.keys():
222 | if metric_tot_cnt[key] > 0:
223 | eval_record['summary'] = eval_record['summary'] + '\n' + \
224 | f"{key}: num: {metric_tot_cnt[key]}, correct num: {metric_correct_cnt[key]}, {100*metric_correct_cnt[key]/metric_tot_cnt[key]:.2f}%"
225 |
226 | eval_record['f1_score'] = {}
227 | for k, v in f1_score.items():
228 | eval_record['f1_score'][k] = v
229 |
230 | with open(save_path, "w") as f:
231 | json.dump(eval_record, f)
232 | print(f'\nEval results saved to {save_path}\n')
233 |
234 |
235 | if __name__ == "__main__":
236 | main()
237 |
--------------------------------------------------------------------------------
/eval_tools/qwen2vl_plan_cmd_eval_sft.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import copy
3 | import json
4 | from PIL import Image
5 |
6 | import torch
7 | from tqdm import tqdm
8 | from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
9 | from qwen_vl_utils import process_vision_info
10 | from datasets import load_dataset
11 |
12 |
13 | def convert_example(example):
14 | """
15 |     Convert an example into "messages" format.
16 | eg:
17 | {
18 | "system": "You are a helpful assistant.",
19 | "conversations": [
20 | {"from": "user", "value": "How many objects are included in this image?",
21 | "image_path": "/path/to/image.png"},
22 |             {"from": "assistant", "value": "<think>\nI can see 10 objects\n</think>\n<answer>\n10\n</answer>"}
23 | ]
24 | }
25 | """
26 | messages = []
27 |
28 | image = example.get("image")
29 | messages.append({
30 | "role": "user",
31 | "content": [
32 | {"type": "text", "text": example["problem"]},
33 | {"type": "image", "image": image},
34 | ]
35 | })
36 |
37 | return messages
38 |
39 |
40 | def parse_args():
41 | parser = argparse.ArgumentParser(description="Evaluate Qwen2-VL model on validation dataset")
42 | parser.add_argument("--eval_data_path", type=str, required=True, help="Path to evaluation dataset")
43 | parser.add_argument("--model_path", type=str, required=True, help="Path to the pretrained model")
44 | parser.add_argument("--save_path", type=str, required=True, help="Path to save evaluation results")
45 | return parser.parse_args()
46 |
47 |
48 | def main():
49 |
50 | args = parse_args()
51 | model_path = args.model_path
52 | eval_data_path = args.eval_data_path
53 | save_path = args.save_path
54 |
55 | # We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
56 | model = Qwen2VLForConditionalGeneration.from_pretrained(
57 | model_path,
58 | torch_dtype=torch.bfloat16,
59 | attn_implementation="flash_attention_2",
60 | device_map="auto",
61 | )
62 | processor = AutoProcessor.from_pretrained(model_path)
63 |
64 |
65 | eval_dataset = load_dataset(eval_data_path)
66 |
67 |
68 | tot_num, correct_num = 0, 0
69 |
70 | SPEED_PLAN = ['KEEP', 'ACCELERATE', 'DECELERATE', 'STOP']
71 | PATH_PLAN = ['RIGHT_TURN', 'RIGHT_CHANGE', 'LEFT_TURN', 'LEFT_CHANGE', 'STRAIGHT']
72 |
73 | metric_tot_cnt = {speed + '_' + path: 0 for speed in SPEED_PLAN for path in PATH_PLAN}
74 | metric_correct_cnt = copy.deepcopy(metric_tot_cnt)
75 |
76 | eval_record = {}
77 |
78 | speed_tp = {
79 | 'KEEP': 0,
80 | 'ACCELERATE': 0,
81 | 'DECELERATE': 0,
82 | 'STOP': 0,
83 | }
84 |
85 | speed_fp, speed_fn = copy.deepcopy(speed_tp), copy.deepcopy(speed_tp)
86 |
87 | path_tp = {
88 | 'RIGHT_TURN': 0,
89 | 'LEFT_TURN': 0,
90 | 'RIGHT_CHANGE': 0,
91 | 'LEFT_CHANGE': 0,
92 | 'STRAIGHT': 0,
93 | }
94 |
95 | path_fp, path_fn = copy.deepcopy(path_tp), copy.deepcopy(path_tp)
96 |
97 | f1_score = {
98 | 'KEEP': 0,
99 | 'ACCELERATE': 0,
100 | 'DECELERATE': 0,
101 | 'STOP': 0,
102 | 'RIGHT_TURN': 0,
103 | 'LEFT_TURN': 0,
104 | 'RIGHT_CHANGE': 0,
105 | 'LEFT_CHANGE': 0,
106 | 'STRAIGHT': 0,
107 | }
108 |
109 | for sample in tqdm(eval_dataset['validation']):
110 |
111 | tot_num = tot_num + 1
112 |
113 | gt_answer = sample['solution']
114 |
115 | messages = convert_example(sample)
116 |
117 | # Preparation for inference
118 | text = processor.apply_chat_template(
119 | messages, tokenize=False, add_generation_prompt=True
120 | )
121 | image_inputs, video_inputs = process_vision_info(messages)
122 | inputs = processor(
123 | text=[text],
124 | images=image_inputs,
125 | videos=video_inputs,
126 | padding=True,
127 | return_tensors="pt",
128 | )
129 | inputs = inputs.to("cuda")
130 |
131 | # Inference: Generation of the output
132 | generated_ids = model.generate(**inputs, max_new_tokens=512)
133 | generated_ids_trimmed = [
134 | out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
135 | ]
136 | answer = processor.batch_decode(
137 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
138 | )[0]
139 | answer = answer.strip()
140 |
141 |         gt_answer = gt_answer.replace('<answer> ', '')
142 |         gt_answer = gt_answer.replace(' </answer>', '')
143 | speed_plan, path_plan = gt_answer.split(', ')
144 | path_plan = path_plan.split('\n')[0]
145 |
146 |
147 | for key in speed_tp.keys():
148 | if key in answer: # P
149 | if speed_plan in answer:
150 | speed_tp[key] += 1 # TP
151 | else:
152 | speed_fp[key] += 1 # FP
153 | else: # N
154 | if key in speed_plan:
155 | speed_fn[key] += 1 # FN
156 |
157 | for key in path_tp.keys():
158 | if key in answer: # P
159 | if path_plan in answer:
160 | path_tp[key] += 1 # TP
161 | else:
162 | path_fp[key] += 1 # FP
163 | else: # N
164 | if key in path_plan:
165 | path_fn[key] += 1 # FN
166 |
167 |
168 | metric_tot_cnt[speed_plan+'_'+path_plan] += 1
169 | if speed_plan in answer and path_plan in answer:
170 | correct_num = correct_num + 1
171 | metric_correct_cnt[speed_plan+'_'+path_plan] += 1
172 | else:
173 | fail_case = {
174 | 'gt': gt_answer,
175 | 'pred': answer,
176 | }
177 |
178 |
179 | for key in f1_score.keys():
180 | if key in speed_tp.keys():
181 | if speed_tp[key] + speed_fp[key] != 0:
182 | precision = speed_tp[key] / (speed_tp[key] + speed_fp[key])
183 | else:
184 | precision = 0
185 | if speed_tp[key] + speed_fn[key] != 0:
186 | recall = speed_tp[key] / (speed_tp[key] + speed_fn[key])
187 | else:
188 | recall = 0
189 | if precision + recall != 0:
190 | f1_score[key] = 2.0 * precision * recall / (precision + recall)
191 | else:
192 | f1_score[key] = 0
193 | if key in path_tp.keys():
194 |             if path_tp[key] + path_fp[key] != 0:
195 | precision = path_tp[key] / (path_tp[key] + path_fp[key])
196 | else:
197 | precision = 0
198 | if path_tp[key] + path_fn[key] != 0:
199 | recall = path_tp[key] / (path_tp[key] + path_fn[key])
200 | else:
201 | recall = 0
202 | if precision + recall != 0:
203 | f1_score[key] = 2.0 * precision * recall / (precision + recall)
204 | else:
205 | f1_score[key] = 0
206 |
207 | print("\n\n=========== F1 Score ===========\n\n")
208 | for k, v in f1_score.items():
209 | print(f"{k}: {v}")
210 | print("\n\n================================\n\n")
211 |
212 | print(f'\nTotal Number: {tot_num}\n')
213 | print(f'\nCorrect Number: {correct_num}\n')
214 |
215 | print('\n------------------------------\n\n')
216 | print(f"Planning Accuracy: {correct_num/tot_num * 100:.2f}%")
217 | print('\n\n------------------------------\n')
218 |
219 | for key in metric_tot_cnt.keys():
220 | if metric_tot_cnt[key] > 0:
221 | print(f"{key}: num: {metric_tot_cnt[key]}, correct num: {metric_correct_cnt[key]}, {100*metric_correct_cnt[key]/metric_tot_cnt[key]:.2f}%")
222 |
223 | eval_record['summary'] = f'Total Number: {tot_num}'
224 | eval_record['summary'] = eval_record['summary'] + '\n' + f'Correct Number: {correct_num}'
225 | eval_record['summary'] = eval_record['summary'] + '\n' + f"Planning Accuracy: {correct_num/tot_num * 100:.2f}%"
226 |
227 | for key in metric_tot_cnt.keys():
228 | if metric_tot_cnt[key] > 0:
229 | eval_record['summary'] = eval_record['summary'] + '\n' + \
230 | f"{key}: num: {metric_tot_cnt[key]}, correct num: {metric_correct_cnt[key]}, {100*metric_correct_cnt[key]/metric_tot_cnt[key]:.2f}%"
231 |
232 | eval_record['f1_score'] = {}
233 | for k, v in f1_score.items():
234 | eval_record['f1_score'][k] = v
235 |
236 | with open(save_path, "w") as f:
237 | json.dump(eval_record, f)
238 | print(f'\nEval results saved to {save_path}\n')
239 |
240 |
241 | if __name__ == "__main__":
242 | main()
243 |
--------------------------------------------------------------------------------
/setup.sh:
--------------------------------------------------------------------------------
1 | # Install the packages in r1-v .
2 | cd src/r1-v
3 | pip install -e ".[dev]"
4 |
5 | # Additional modules
6 | pip install wandb==0.18.3
7 | pip install tensorboardx
8 | pip install qwen_vl_utils torchvision
9 | pip install flash-attn --no-build-isolation
10 |
11 | # vLLM support
12 | pip install vllm==0.7.2
13 |
14 | # fix transformers version
15 | pip install git+https://github.com/huggingface/transformers.git@336dc69d63d56f232a183a3e7f52790429b871ef
--------------------------------------------------------------------------------
/src/distill_r1/README.md:
--------------------------------------------------------------------------------
1 | # R1 Reasoning Dataset Generation
2 |
3 |
4 |
5 | ## QA Pairs Generation
6 |
7 | We create a `scene description` by combining the objects (with meta info such as location, depth) using a template.
8 |
9 | We keep the counting-relevant questions and add a `How many items are there in the described scene?` question to count all objects in the scene.
10 |
11 | Example QA pair:
12 |
13 | ```json
14 | {'img_filename': 'CLEVR_trainA_048403.png',
15 | 'question': 'How many things are both on the right side of the big yellow rubber thing and left of the purple ball?',
16 | 'answer': '5',
17 | 'description': 'Scene Description:\nA large red rubber cylinder rotated 291.3° located at 3D coordinates (-0.89, -2.73, 0.70) and pixel coordinates (101, 152, 10.04)\nA small purple metal sphere rotated 247.7° located at 3D coordinates (2.93, 0.87, 0.35) and pixel coordinates (379, 183, 9.66)\nA large cyan rubber cylinder rotated 114.5° located at 3D coordinates (-2.40, 2.23, 0.70) and pixel coordinates (246, 82, 13.94)\nA small red metal cylinder rotated 109.9° located at 3D coordinates (-0.95, 1.77, 0.35) and pixel coordinates (270, 113, 12.83)\nA small red rubber cylinder rotated 343.7° located at 3D coordinates (-0.12, -0.74, 0.35) and pixel coordinates (209, 153, 10.82)\nA large red rubber cylinder rotated 324.5° located at 3D coordinates (-2.71, -2.21, 0.70) and pixel coordinates (84, 119, 11.59)\nA small red metal cylinder rotated 1.1° located at 3D coordinates (2.88, -0.12, 0.35) and pixel coordinates (342, 200, 9.12)\nA small gray rubber cube rotated 144.9° located at 3D coordinates (0.79, 0.98, 0.35) and pixel coordinates (299, 145, 11.19)\nA large yellow rubber cube rotated 90.0° located at 3D coordinates (-1.78, -0.31, 0.70) and pixel coordinates (180, 110, 12.05)\n'}
18 | ```
19 |
20 | See `generate_scene_qa_pairs.ipynb` for details.
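
The per-object line of that template, condensed from the notebook (the full version shuffles the objects and joins the lines under a `Scene Description:` header):

```python
# Condensed from generate_scene_qa_pairs.ipynb: render one CLEVR object record
# into a single line of the scene description.
def describe(obj: dict) -> str:
    return (
        f"A {obj['size']} {obj['color']} {obj['material']} {obj['shape']}"
        f" rotated {obj['rotation']:.1f}° located at"
        f" 3D coordinates ({obj['3d_coords'][0]:.2f}, {obj['3d_coords'][1]:.2f}, {obj['3d_coords'][2]:.2f})"
        f" and pixel coordinates ({obj['pixel_coords'][0]}, {obj['pixel_coords'][1]}, {obj['pixel_coords'][2]:.2f})"
    )
```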
21 |
22 | ## R1 Response Generation
23 |
24 | `query_r1.py`: We get the predictions and obtain the reasoning trace of R1 by querying SiliconFlow (You can register via https://cloud.siliconflow.cn/i/rYfeQZ6A so that we both receive 20M token credits!).
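
For illustration only (the actual script is `query_r1.py`), a sketch of such a query against an OpenAI-compatible endpoint; the base URL, model id, and environment variable below are placeholders, not values taken from this repo:

```python
# Illustrative sketch only; endpoint, model id, and env var are placeholders.
import os
from openai import OpenAI

client = OpenAI(
    base_url="https://api.siliconflow.cn/v1",  # placeholder OpenAI-compatible endpoint
    api_key=os.environ["API_KEY"],             # placeholder credential
)

description = "Scene Description:\nA small green metal sphere ...\n"
question = "How many items are there in the described scene?"

response = client.chat.completions.create(
    model="deepseek-ai/DeepSeek-R1",           # placeholder reasoning model id
    messages=[{"role": "user", "content": description + question}],
)
print(response.choices[0].message.content)
```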
25 |
26 | Example Response:
27 |
28 | ```json
29 | Okay, let's see. The user is asking how many items are there in the described scene. Let me go through the scene description step by step.
30 |
31 | So, the scene description lists each object with details like color, material, shape, rotation, 3D coordinates, and pixel coordinates. Each entry starts with "A" which usually indicates one item each. Let me count each one.
32 |
33 | First entry: "A small green metal cylinder..." That's one. Second: "A small blue rubber cylinder..." Second item. Third: "A small cyan rubber cylinder..." That's three. Fourth: "A large cyan metal sphere..." Four. Fifth: "A large brown metal cube..." Five. Sixth: "A large yellow rubber cube..." Six. Seventh: "A large brown rubber cylinder..." That's seven.
34 |
35 | Wait, did I miss any? Let me check again. The list has entries from "A small green..." up to the seventh one. Each sentence starts with "A", which suggests each is a separate item. No commas separating multiple items in a single entry. Each has different attributes and coordinates, so they must all be distinct.
36 |
37 | So the answer should be 7 items.
38 |
39 |
40 | There are 7 items in the described scene. Each entry corresponds to one distinct object, listed by their properties, coordinates, and rotations.
41 | ```
42 |
43 | In `v1` we did not constrain the output format, so the answers were somewhat messy to parse. We then switched to `v2`, explicitly prompting the model to generate the answer with `**The answer is: **`.
44 |
45 | ## Reasoning Path Filtering
46 |
47 | `filter_r1.py`: We keep the (almost always) valid reasoning traces by judging whether the R1 answer is correct (following our previous work [Math-Shepherd](https://arxiv.org/abs/2312.08935)).
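
In short, a pair is kept only when the parsed R1 answer matches the ground truth; condensed from `filter_r1.py` (whose regex-based parser is `extract_answer_from_query`):

```python
# Condensed from filter_r1.py: keep a pair only if the parsed answer matches
# the ground truth (case-insensitive). extract_answer_from_query is the
# regex-based parser defined in filter_r1.py.
def keep_pair(qa_pair: dict) -> bool:
    ground_truth = str(qa_pair.get("a", "")).lower().strip()
    parsed = extract_answer_from_query(qa_pair["r1_response"])
    return parsed is not None and parsed == ground_truth
```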
48 |
49 | ## HF dataset creation
50 |
51 | Finally, we create the dataset using `create_hf_dataset.py` and upload it to the HF dataset hub.
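
The dataset layout is images plus a `metadata.jsonl` file, loaded with the `imagefolder` builder; condensed from `create_hf_dataset.py` (the repo id is the one used in the script):

```python
# Condensed from create_hf_dataset.py: load images + metadata.jsonl with the
# "imagefolder" builder and push the result to the Hugging Face Hub.
from datasets import load_dataset

train_dataset = load_dataset("imagefolder", data_dir="data/Clevr_CoGenT_TrainA_R1", split="train")
train_dataset.push_to_hub("MMInstruction/Clevr_CoGenT_TrainA_R1")
```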
52 |
53 |
54 |
55 |
--------------------------------------------------------------------------------
/src/distill_r1/create_hf_dataset.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import random
4 | from datasets import load_dataset
5 | from tqdm import tqdm
6 |
7 | random.seed(1234)
8 | VAL_NUM = 5000
9 |
10 |
11 | def create_r1_train_dataset(
12 | valid_pair_json,
13 | data_dir,
14 | img_dir="/home/lilei/Visual-R1/CLEVR_CoGenT_v1.0/images/trainA/",
15 | ):
16 | os.makedirs(data_dir, exist_ok=True)
17 | pairs = [json.loads(line) for line in open(valid_pair_json, "r")]
18 | mapped_pairs = []
19 |
20 | for idx, pair in tqdm(enumerate(pairs)):
21 | img_filename = pair["img_filename"]
22 | new_pair = {}
23 | try:
24 | new_pair["thinking"] = (
25 | pair["r1_response"]
26 |                 .split("<think>")[1]
27 |                 .split("</think>")[0]
28 | .replace("scene description", "image")
29 | )
30 | except Exception as e:
31 | print(f"Error processing pair response: ", pair["r1_response"])
32 | continue # skip this pair
33 | # add index to distinguish the same image
34 | dataset_filename = (
35 | img_filename.split(".")[0] + "_" + str(idx) + "." + img_filename.split(".")[1]
36 | )
37 | if not os.path.exists(f"{data_dir}/{img_filename}"):
38 | os.system(f"cp {img_dir}/{img_filename} {data_dir}/{dataset_filename}")
39 | q, a = pair["q"], pair["a"]
40 | new_pair["problem"] = q
41 | # get the thinking path
42 |
43 |         new_pair["thinking"] = "<think>" + new_pair["thinking"] + "</think>"
44 |         new_pair["solution"] = f"<answer> {a} </answer>"
45 | new_pair["file_name"] = dataset_filename
46 | mapped_pairs.append(new_pair)
47 | with open(f"{data_dir}/metadata.jsonl", "w") as f:
48 | for pair in mapped_pairs:
49 | f.write(json.dumps(pair) + "\n")
50 |
51 | train_dataset = load_dataset(
52 | "imagefolder",
53 | data_dir=data_dir,
54 | split="train",
55 | )
56 | return train_dataset
57 |
58 |
59 | def create_val_dataset(
60 | json_file,
61 | data_dir,
62 | val_num=VAL_NUM,
63 | image_dir="/home/lilei/Visual-R1/CLEVR_CoGenT_v1.0/images/valB",
64 | ):
65 | os.makedirs(data_dir, exist_ok=True)
66 | val = json.load(open(json_file))
67 | random.shuffle(val)
68 | val = val[:val_num]
69 | val_pairs = []
70 | for idx, pair in tqdm(enumerate(val)):
71 | q, a = pair["q"], pair["a"]
72 | img_filename = pair["img_filename"]
73 | # copy images to the DATA_DIR
74 | val_filename = (
75 | img_filename.split(".")[0] + f"_{idx}." + img_filename.split(".")[1]
76 | )
77 | if not os.path.exists(f"{data_dir}/{img_filename}"):
78 | os.system(f"cp {image_dir}/{img_filename} {data_dir}/{val_filename}")
79 | new_pair = {}
80 | new_pair["problem"] = q
81 |         new_pair["solution"] = f"<answer> {a} </answer>"
82 | new_pair["file_name"] = val_filename
83 | val_pairs.append(new_pair)
84 | with open(f"{data_dir}/metadata.jsonl", "w") as f:
85 | for pair in val_pairs:
86 | f.write(json.dumps(pair) + "\n")
87 | val_dataset = load_dataset("imagefolder", data_dir=data_dir, split="train")
88 | return val_dataset
89 |
90 |
91 | # valA split
92 | VALA_DATA_DIR = "data/Clevr_CoGenT_ValA"
93 | VALB_DATA_DIR = "data/Clevr_CoGenT_ValB"
94 | valA_json = (
95 | "/home/lilei/Visual-R1/data/clever_counting_problems_clevr_cogent_v1.0_valA.json"
96 | )
97 | valB_json = (
98 | "/home/lilei/Visual-R1/data/clever_counting_problems_clevr_cogent_v1.0_valB.json"
99 | )
100 | TRAIN_DATADIR = "data/Clevr_CoGenT_TrainA_R1"
101 | train_dataset = create_r1_train_dataset(
102 | "/home/lilei/Visual-R1/filter_results_v2/valid_pairs.jsonl",
103 | TRAIN_DATADIR,
104 | )
105 |
106 | # print(train_dataset)
107 | valA_dataset = create_val_dataset(
108 | valA_json,
109 | VALA_DATA_DIR,
110 | image_dir="/home/lilei/Visual-R1/CLEVR_CoGenT_v1.0/images/valA",
111 | )
112 | valB_dataset = create_val_dataset(
113 | valB_json,
114 | VALB_DATA_DIR,
115 | image_dir="/home/lilei/Visual-R1/CLEVR_CoGenT_v1.0/images/valB",
116 | )
117 | valA_dataset.push_to_hub("MMInstruction/Clevr_CoGenT_ValA")
118 | valB_dataset.push_to_hub("MMInstruction/Clevr_CoGenT_ValB")
119 | train_dataset.push_to_hub("MMInstruction/Clevr_CoGenT_TrainA_R1")
120 |
--------------------------------------------------------------------------------
/src/distill_r1/filter_r1.py:
--------------------------------------------------------------------------------
1 | import json
2 | import re
3 | from pathlib import Path
4 |
5 |
6 |
7 | def extract_answer_from_query(query_results: str) -> str | None:
8 | """
9 | Extract answer from query results, specifically looking for:
10 | - Numbers within asterisks
11 | - Yes/No answers in various formats
12 |
13 | Args:
14 | query_results: String containing the query response
15 |
16 | Returns:
17 | Extracted answer string or None if no answer found
18 | """
19 | # First try to find answers in the standard format with labels
20 | # Split the text into segments (trying to get the last conclusion)
21 |     if "<think>" not in query_results or "</think>" not in query_results:
22 | return None
23 | segments = query_results.split("\n")
24 |
25 | # First try to find final conclusion in the last few segments
26 | conclusion_patterns = [
27 | r"(?:so|therefore|thus|hence),?\s*(?:the answer is\s+)?\*\*\s*(no|yes|[0-9]+)\s*\*\*",
28 | r"(?:so|therefore|thus|hence),?\s*(?:the answer is\s+)?(no|yes|[0-9]+)\b",
29 | r"the answer is\s+\*\*\s*(no|yes|[0-9]+)\s*\*\*",
30 | r"(?:final|conclusive) answer(?:\s+is)?\s*\*\*\s*(no|yes|[0-9]+)\s*\*\*",
31 | ]
32 |
33 | # Try to find conclusion in last 3 segments
34 | for segment in reversed(segments[-3:]):
35 | for pattern in conclusion_patterns:
36 | match = re.search(pattern, segment, re.IGNORECASE)
37 | if match:
38 | return match.group(1).strip().lower()
39 |
40 | # If no conclusion found, try other patterns on the full text
41 | labeled_patterns = [
42 | r"\*\*The answer is:\s*\*\*\s*([0-9]+|yes|no)\b",
43 | r"\*\*Answer:\s*\*\*\s*([0-9]+|yes|no)\b",
44 | r"\*\*Answer\*\*:\s*([0-9]+|yes|no)\b",
45 | r"\*\*Answer:?\s*\*\*\s*There (?:is|are)\s+([0-9]+)",
46 | r"\*\*Final Count:\s*\*\*\s*([0-9]+)",
47 | r"\*\*Final Count:\s*\*\*\s*([0-9]+)\s+(?:items?|objects?|spheres?|cubes?|boxes?)",
48 | r"\*\*Total:\s*\*\*\s*([0-9]+)",
49 | r"The answer is:\s*([0-9]+|yes|no)\b",
50 | r"Answer:\s*([0-9]+|yes|no)\b",
51 | r"should be\s+([0-9]+)[.\s]",
52 | ]
53 |
54 | direct_patterns = [
55 | r"\*\*\s*([0-9]+)\s*\*\*",
56 | r"\*\*\s*([0-9]+)\s+(?:items?|objects?|cubes?|boxes?|spheres?)?\s*\*\*",
57 | r"\*\*\s*([0-9]+)\s+[^*]+\*\*",
58 | ]
59 |
60 | latex_patterns = [
61 | r"\$\\boxed{([0-9]+)}\$",
62 | r"\\boxed{([0-9]+)}",
63 | ]
64 |
65 | count_patterns = [
66 | r"There (?:is|are)\s+([0-9]+)\s+(?:items?|objects?|spheres?|cubes?|boxes?)",
67 | ]
68 |
69 | # Try all patterns in sequence on full text
70 | all_patterns = labeled_patterns + direct_patterns + latex_patterns + count_patterns
71 |
72 | for pattern in all_patterns:
73 | match = re.search(pattern, query_results, re.IGNORECASE)
74 | if match:
75 | return match.group(1).strip().lower()
76 |
77 | return None
78 |
79 |
80 | def validate_qa_pairs(input_file: str, output_dir: str, verbose: bool = True):
81 | """
82 | Process QA pairs and save them to separate files.
83 | Only saves pairs where parsed answer matches ground truth.
84 |
85 | Args:
86 | input_file: Path to input JSONL file
87 | output_dir: Directory to save output files
88 | verbose: If True, print examples of mismatched or unparseable responses
89 | """
90 | output_dir = Path(output_dir)
91 | output_dir.mkdir(parents=True, exist_ok=True)
92 |
93 | valid_pairs = []
94 | invalid_pairs = []
95 | stats = {"total": 0, "unparseable": 0, "mismatch": 0, "valid": 0}
96 |
97 | with open(input_file, "r", encoding="utf-8") as f:
98 | for line_num, line in enumerate(f, 1):
99 | stats["total"] += 1
100 | qa_pair = json.loads(line.strip())
101 | ground_truth = str(qa_pair.get("a", "")).lower().strip()
102 | parsed_answer = extract_answer_from_query(qa_pair["r1_response"])
103 |
104 | if parsed_answer is None:
105 | stats["unparseable"] += 1
106 | qa_pair["error"] = "unparseable"
107 | invalid_pairs.append(qa_pair)
108 | if verbose:
109 | print(f"\nLine {line_num}: Could not parse answer")
110 | print(f"Ground truth: {ground_truth}")
111 | print(f"Query results: {qa_pair['r1_response'][-200:]}...")
112 | elif parsed_answer != ground_truth:
113 | stats["mismatch"] += 1
114 | qa_pair["error"] = "mismatch"
115 | qa_pair["parsed_answer"] = parsed_answer
116 | invalid_pairs.append(qa_pair)
117 | if verbose:
118 | print(f"\nLine {line_num}: Answer mismatch")
119 | print(f"Ground truth: {ground_truth}")
120 | print(f"Parsed answer: {parsed_answer}")
121 | print(f"Query results: {qa_pair['r1_response'][-200:]}...")
122 | else:
123 | stats["valid"] += 1
124 | valid_pairs.append(qa_pair)
125 |
126 | # Save valid pairs (where parsed answer matches ground truth)
127 | valid_file = output_dir / "valid_pairs.jsonl"
128 | with open(valid_file, "w", encoding="utf-8") as f:
129 | for pair in valid_pairs:
130 | f.write(json.dumps(pair, ensure_ascii=False) + "\n")
131 |
132 | # Save invalid pairs (unparseable or mismatched)
133 | invalid_file = output_dir / "invalid_pairs.jsonl"
134 | with open(invalid_file, "w", encoding="utf-8") as f:
135 | for pair in invalid_pairs:
136 | f.write(json.dumps(pair, ensure_ascii=False) + "\n")
137 |
138 | # Print statistics
139 | print(f"\nProcessing Summary:")
140 | print(f"Total pairs processed: {stats['total']}")
141 | print(f"Valid pairs (matching ground truth): {stats['valid']}")
142 | print(f"Invalid pairs: {stats['unparseable'] + stats['mismatch']}")
143 | print(f" - Unparseable: {stats['unparseable']}")
144 | print(f" - Answer mismatch: {stats['mismatch']}")
145 | print(f"\nOutput files:")
146 | print(f"Valid pairs saved to: {valid_file}")
147 | print(f"Invalid pairs saved to: {invalid_file}")
148 |
149 |
150 | if __name__ == "__main__":
151 | validate_qa_pairs(
152 | "r1_results_clevr_cogent_v1.0_trainA_v2.jsonl", "filter_results_v2"
153 | ) # "filtered_output_tmp_v1.jsonl")
154 |
--------------------------------------------------------------------------------
/src/distill_r1/generate_scene_qa_pairs.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "id": "3a704ea6-2e61-4aaa-97aa-416579c9bc13",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import json\n",
11 | "import random"
12 | ]
13 | },
14 | {
15 | "cell_type": "code",
16 | "execution_count": 4,
17 | "id": "c4920a8f-cddd-4063-8cab-215d238b5dad",
18 | "metadata": {},
19 | "outputs": [
20 | {
21 | "name": "stdout",
22 | "output_type": "stream",
23 | "text": [
24 | "CLEVR_trainA_scenes.json CLEVR_valA_scenes.json CLEVR_valB_scenes.json\n"
25 | ]
26 | }
27 | ],
28 | "source": [
29 | "!ls CLEVR_CoGenT_v1.0/scenes"
30 | ]
31 | },
32 | {
33 | "cell_type": "code",
34 | "execution_count": 15,
35 | "id": "934fa005-3b2a-43ed-8a71-6a12b7579546",
36 | "metadata": {},
37 | "outputs": [],
38 | "source": [
39 | "split = \"valB\"\n",
40 | "clevr_train_json = f\"CLEVR_CoGenT_v1.0/scenes/CLEVR_{split}_scenes.json\"\n",
41 | "train_qs = f\"CLEVR_CoGenT_v1.0/questions/CLEVR_{split}_questions.json\"\n",
42 | "data = json.load(open(clevr_train_json))\n",
43 | "qs = json.load(open(train_qs))"
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": 16,
49 | "id": "1f0d6180-94c4-4aea-bd2b-8d5cfeb0aecb",
50 | "metadata": {},
51 | "outputs": [
52 | {
53 | "name": "stdout",
54 | "output_type": "stream",
55 | "text": [
56 | "[{'pixel_coords': [343, 131, 11.278693199157715], 'size': 'small', 'color': 'green', 'material': 'metal', 'shape': 'sphere', '3d_coords': [0.9906095862388611, 2.083291530609131, 0.3499999940395355], 'rotation': 107.73596690369371}, {'pixel_coords': [396, 172, 9.857704162597656], 'size': 'small', 'color': 'cyan', 'material': 'rubber', 'shape': 'sphere', '3d_coords': [2.69626522064209, 1.5257188081741333, 0.3499999940395355], 'rotation': 305.3536122513589}, {'pixel_coords': [115, 182, 8.91348934173584], 'size': 'large', 'color': 'yellow', 'material': 'rubber', 'shape': 'cylinder', '3d_coords': [0.049163494259119034, -2.864100217819214, 0.699999988079071], 'rotation': 161.8370138842408}, {'pixel_coords': [203, 131, 10.548327445983887], 'size': 'large', 'color': 'purple', 'material': 'rubber', 'shape': 'cube', '3d_coords': [-0.4719269275665283, -0.5699371695518494, 0.699999988079071], 'rotation': 159.41862667811446}, {'pixel_coords': [253, 75, 13.141877174377441], 'size': 'large', 'color': 'red', 'material': 'rubber', 'shape': 'cube', '3d_coords': [-2.036878824234009, 2.222999334335327, 0.699999988079071], 'rotation': 37.40490732771224}]\n",
57 | "len: 5\n"
58 | ]
59 | }
60 | ],
61 | "source": [
62 | "print(data['scenes'][0]['objects'])\n",
63 | "print(\"len: \", len(data['scenes'][0]['objects']))"
64 | ]
65 | },
66 | {
67 | "cell_type": "code",
68 | "execution_count": 17,
69 | "id": "7c828ca4-08f9-4927-a745-224a95379c2f",
70 | "metadata": {},
71 | "outputs": [],
72 | "source": [
73 | "def object_info_to_description(object_list):\n",
74 | " descriptions = []\n",
75 | " random.shuffle(object_list)\n",
76 | " for obj in object_list:\n",
77 | " desc = f\"A {obj['size']} {obj['color']} {obj['material']} {obj['shape']}\"\n",
78 | " desc += f\" rotated {obj['rotation']:.1f}° located at\"\n",
79 | " desc += f\" 3D coordinates ({obj['3d_coords'][0]:.2f}, {obj['3d_coords'][1]:.2f}, {obj['3d_coords'][2]:.2f})\"\n",
80 | " desc += f\" and pixel coordinates ({obj['pixel_coords'][0]}, {obj['pixel_coords'][1]}, {obj['pixel_coords'][2]:.2f})\"\n",
81 | " descriptions.append(desc)\n",
82 | " \n",
83 | " final_description = \"Scene Description:\\n\"\n",
84 | " for i, desc in enumerate(descriptions, 1):\n",
85 | " final_description += f\"{desc}\\n\"\n",
86 | " \n",
87 | " return final_description"
88 | ]
89 | },
90 | {
91 | "cell_type": "code",
92 | "execution_count": 18,
93 | "id": "cb048e25-d554-4bd7-bf11-878e071b5987",
94 | "metadata": {},
95 | "outputs": [
96 | {
97 | "data": {
98 | "text/plain": [
99 | "'Scene Description:\\nA large yellow rubber cylinder rotated 161.8° located at 3D coordinates (0.05, -2.86, 0.70) and pixel coordinates (115, 182, 8.91)\\nA large purple rubber cube rotated 159.4° located at 3D coordinates (-0.47, -0.57, 0.70) and pixel coordinates (203, 131, 10.55)\\nA large red rubber cube rotated 37.4° located at 3D coordinates (-2.04, 2.22, 0.70) and pixel coordinates (253, 75, 13.14)\\nA small green metal sphere rotated 107.7° located at 3D coordinates (0.99, 2.08, 0.35) and pixel coordinates (343, 131, 11.28)\\nA small cyan rubber sphere rotated 305.4° located at 3D coordinates (2.70, 1.53, 0.35) and pixel coordinates (396, 172, 9.86)\\n'"
100 | ]
101 | },
102 | "execution_count": 18,
103 | "metadata": {},
104 | "output_type": "execute_result"
105 | }
106 | ],
107 | "source": [
108 | "object_info_to_description(data['scenes'][0]['objects'])"
109 | ]
110 | },
111 | {
112 | "cell_type": "code",
113 | "execution_count": 19,
114 | "id": "ffacd5f3-e9a4-46ca-8c50-187ab12c9f1b",
115 | "metadata": {},
116 | "outputs": [],
117 | "source": [
118 | "img2obj_dict = {}\n",
119 | "for scene in data['scenes']:\n",
120 | " obj_list = scene['objects']\n",
121 | " img2obj_dict[scene['image_filename']] = obj_list"
122 | ]
123 | },
124 | {
125 | "cell_type": "code",
126 | "execution_count": 20,
127 | "id": "db35f03c-1529-4776-bf4f-3bd44e960e5f",
128 | "metadata": {},
129 | "outputs": [
130 | {
131 | "data": {
132 | "text/plain": [
133 | "{'question_index': 0,\n",
134 | " 'question_family_index': 29,\n",
135 | " 'image_index': 0,\n",
136 | " 'question': 'The big thing that is in front of the large rubber cube in front of the small thing that is behind the tiny matte ball is what color?',\n",
137 | " 'answer': 'yellow',\n",
138 | " 'image_filename': 'CLEVR_valB_000000.png',\n",
139 | " 'split': 'valB',\n",
140 | " 'program': [{'value_inputs': [], 'inputs': [], 'function': 'scene'},\n",
141 | " {'value_inputs': ['small'], 'inputs': [0], 'function': 'filter_size'},\n",
142 | " {'value_inputs': ['rubber'], 'inputs': [1], 'function': 'filter_material'},\n",
143 | " {'value_inputs': ['sphere'], 'inputs': [2], 'function': 'filter_shape'},\n",
144 | " {'value_inputs': [], 'inputs': [3], 'function': 'unique'},\n",
145 | " {'value_inputs': ['behind'], 'inputs': [4], 'function': 'relate'},\n",
146 | " {'value_inputs': ['small'], 'inputs': [5], 'function': 'filter_size'},\n",
147 | " {'value_inputs': [], 'inputs': [6], 'function': 'unique'},\n",
148 | " {'value_inputs': ['front'], 'inputs': [7], 'function': 'relate'},\n",
149 | " {'value_inputs': ['large'], 'inputs': [8], 'function': 'filter_size'},\n",
150 | " {'value_inputs': ['rubber'], 'inputs': [9], 'function': 'filter_material'},\n",
151 | " {'value_inputs': ['cube'], 'inputs': [10], 'function': 'filter_shape'},\n",
152 | " {'value_inputs': [], 'inputs': [11], 'function': 'unique'},\n",
153 | " {'value_inputs': ['front'], 'inputs': [12], 'function': 'relate'},\n",
154 | " {'value_inputs': ['large'], 'inputs': [13], 'function': 'filter_size'},\n",
155 | " {'value_inputs': [], 'inputs': [14], 'function': 'unique'},\n",
156 | " {'value_inputs': [], 'inputs': [15], 'function': 'query_color'}]}"
157 | ]
158 | },
159 | "execution_count": 20,
160 | "metadata": {},
161 | "output_type": "execute_result"
162 | }
163 | ],
164 | "source": [
165 | "qs['questions'][0]"
166 | ]
167 | },
168 | {
169 | "cell_type": "code",
170 | "execution_count": 21,
171 | "id": "66b746fc-569c-4922-a442-79dbbc09e33b",
172 | "metadata": {},
173 | "outputs": [],
174 | "source": [
175 | "random.shuffle(qs['questions'])\n",
176 | "cnt = 0 \n",
177 | "qa_pairs = [] \n",
178 | "added_pair = set()\n",
179 | "for qd in qs['questions']:\n",
180 | " img_idx = qd['image_filename']\n",
181 | " total_count = len(img2obj_dict[img_idx]) # object list length\n",
182 | " desc = object_info_to_description(img2obj_dict[img_idx])\n",
183 | " question, answer = qd['question'], qd['answer']\n",
184 | " if 'how many' in question.lower() or 'number' in question.lower():\n",
185 | " qa_pairs.append({\n",
186 | " \"img_filename\": img_idx,\n",
187 | " 'q': question,\n",
188 | " 'a': answer,\n",
189 | " 'description': desc \n",
190 | " })\n",
191 | " if img_idx not in added_pair:\n",
192 | " qa_pairs.append({\n",
193 | " \"img_filename\": img_idx,\n",
194 | " 'q': \"How many items are there in the described scene?\",\n",
195 | " 'a': total_count,\n",
196 | " 'description': desc \n",
197 | " })\n",
198 | " added_pair.add(img_idx)\n"
199 | ]
200 | },
201 | {
202 | "cell_type": "code",
203 | "execution_count": 22,
204 | "id": "c271fa7b-fed5-472f-a302-6ec203c4b787",
205 | "metadata": {},
206 | "outputs": [
207 | {
208 | "data": {
209 | "text/plain": [
210 | "59978"
211 | ]
212 | },
213 | "execution_count": 22,
214 | "metadata": {},
215 | "output_type": "execute_result"
216 | }
217 | ],
218 | "source": [
219 | "len(qa_pairs)"
220 | ]
221 | },
222 | {
223 | "cell_type": "code",
224 | "execution_count": 23,
225 | "id": "b0da8a70-c3f5-4e48-b384-3684933d72ef",
226 | "metadata": {},
227 | "outputs": [
228 | {
229 | "data": {
230 | "text/plain": [
231 | "14884"
232 | ]
233 | },
234 | "execution_count": 23,
235 | "metadata": {},
236 | "output_type": "execute_result"
237 | }
238 | ],
239 | "source": [
240 | "len(added_pair)"
241 | ]
242 | },
243 | {
244 | "cell_type": "code",
245 | "execution_count": 24,
246 | "id": "c648587e-2ec0-427c-b594-f55dd187b4d9",
247 | "metadata": {},
248 | "outputs": [],
249 | "source": [
250 | "# save for later loading\n",
251 | "with open(f\"clever_counting_problems_clevr_cogent_v1.0_{split}.json\", 'w') as fw:\n",
252 | " json.dump( qa_pairs, fw, indent=4)"
253 | ]
254 | },
255 | {
256 | "cell_type": "code",
257 | "execution_count": 20,
258 | "id": "b3a8cbe4-4261-41d3-a481-43a0b1cc2795",
259 | "metadata": {},
260 | "outputs": [],
261 | "source": [
262 | "random.shuffle(qa_pairs)"
263 | ]
264 | },
265 | {
266 | "cell_type": "code",
267 | "execution_count": 57,
268 | "id": "d6dff4e7-65dd-4e82-82df-340ec2a57919",
269 | "metadata": {},
270 | "outputs": [
271 | {
272 | "data": {
273 | "text/plain": [
274 | "[{'img_filename': 'CLEVR_trainA_048403.png',\n",
275 | " 'q': 'How many things are both on the right side of the big yellow rubber thing and left of the purple ball?',\n",
276 | " 'a': '5',\n",
277 | " 'description': 'Scene Description:\\nA large red rubber cylinder rotated 291.3° located at 3D coordinates (-0.89, -2.73, 0.70) and pixel coordinates (101, 152, 10.04)\\nA small purple metal sphere rotated 247.7° located at 3D coordinates (2.93, 0.87, 0.35) and pixel coordinates (379, 183, 9.66)\\nA large cyan rubber cylinder rotated 114.5° located at 3D coordinates (-2.40, 2.23, 0.70) and pixel coordinates (246, 82, 13.94)\\nA small red metal cylinder rotated 109.9° located at 3D coordinates (-0.95, 1.77, 0.35) and pixel coordinates (270, 113, 12.83)\\nA small red rubber cylinder rotated 343.7° located at 3D coordinates (-0.12, -0.74, 0.35) and pixel coordinates (209, 153, 10.82)\\nA large red rubber cylinder rotated 324.5° located at 3D coordinates (-2.71, -2.21, 0.70) and pixel coordinates (84, 119, 11.59)\\nA small red metal cylinder rotated 1.1° located at 3D coordinates (2.88, -0.12, 0.35) and pixel coordinates (342, 200, 9.12)\\nA small gray rubber cube rotated 144.9° located at 3D coordinates (0.79, 0.98, 0.35) and pixel coordinates (299, 145, 11.19)\\nA large yellow rubber cube rotated 90.0° located at 3D coordinates (-1.78, -0.31, 0.70) and pixel coordinates (180, 110, 12.05)\\n'},\n",
278 | " {'img_filename': 'CLEVR_trainA_048403.png',\n",
279 | " 'q': 'How many items are there in the described scene?',\n",
280 | " 'a': 9,\n",
281 | " 'description': 'Scene Description:\\nA large red rubber cylinder rotated 291.3° located at 3D coordinates (-0.89, -2.73, 0.70) and pixel coordinates (101, 152, 10.04)\\nA small purple metal sphere rotated 247.7° located at 3D coordinates (2.93, 0.87, 0.35) and pixel coordinates (379, 183, 9.66)\\nA large cyan rubber cylinder rotated 114.5° located at 3D coordinates (-2.40, 2.23, 0.70) and pixel coordinates (246, 82, 13.94)\\nA small red metal cylinder rotated 109.9° located at 3D coordinates (-0.95, 1.77, 0.35) and pixel coordinates (270, 113, 12.83)\\nA small red rubber cylinder rotated 343.7° located at 3D coordinates (-0.12, -0.74, 0.35) and pixel coordinates (209, 153, 10.82)\\nA large red rubber cylinder rotated 324.5° located at 3D coordinates (-2.71, -2.21, 0.70) and pixel coordinates (84, 119, 11.59)\\nA small red metal cylinder rotated 1.1° located at 3D coordinates (2.88, -0.12, 0.35) and pixel coordinates (342, 200, 9.12)\\nA small gray rubber cube rotated 144.9° located at 3D coordinates (0.79, 0.98, 0.35) and pixel coordinates (299, 145, 11.19)\\nA large yellow rubber cube rotated 90.0° located at 3D coordinates (-1.78, -0.31, 0.70) and pixel coordinates (180, 110, 12.05)\\n'}]"
282 | ]
283 | },
284 | "execution_count": 57,
285 | "metadata": {},
286 | "output_type": "execute_result"
287 | }
288 | ],
289 | "source": [
290 | "qa_pairs[:2]"
291 | ]
292 | },
293 | {
294 | "cell_type": "code",
295 | "execution_count": 26,
296 | "id": "a6a66364-5b47-4138-91d6-a045404d21b1",
297 | "metadata": {},
298 | "outputs": [],
299 | "source": [
300 | "def query_r1(query='who are you?', model=\"deepseek-ai/DeepSeek-R1\"):\n",
301 | " # Create the chat completion\n",
302 | " response = client.chat.completions.create(\n",
303 | " model=model,\n",
304 | " messages=[\n",
305 | " {'role': 'user', \n",
306 | " 'content': query}\n",
307 | " ],\n",
308 | " stream=False,\n",
309 | " )\n",
310 | " # Print the response\n",
311 | " return response.choices[0].message.content"
312 | ]
313 | },
314 | {
315 | "cell_type": "code",
316 | "execution_count": 44,
317 | "id": "e5d5649f-c4e3-4f3f-b76e-7f7ed27f68e8",
318 | "metadata": {},
319 | "outputs": [],
320 | "source": [
321 | "def format_query(qa_dict):\n",
322 | " query = \"Answer the question according to scene description.\\n\\n\"\n",
323 | " query += qa_dict['description']\n",
324 | " query += f\"\\nQuestion:\\n{qa_dict['q']}\"\n",
325 | " return query \n",
326 | " "
327 | ]
328 | },
329 | {
330 | "cell_type": "code",
331 | "execution_count": 39,
332 | "id": "7f568a4e-f217-464a-8329-bbefb64d9653",
333 | "metadata": {},
334 | "outputs": [
335 | {
336 | "name": "stdout",
337 | "output_type": "stream",
338 | "text": [
339 | "Okay, let's see. The user is asking how many items are there in the described scene. Let me go through the scene description step by step.\n",
340 | "\n",
341 | "So, the scene description lists each object with details like color, material, shape, rotation, 3D coordinates, and pixel coordinates. Each entry starts with \"A\" which usually indicates one item each. Let me count each one.\n",
342 | "\n",
343 | "First entry: \"A small green metal cylinder...\" That's one. Second: \"A small blue rubber cylinder...\" Second item. Third: \"A small cyan rubber cylinder...\" That's three. Fourth: \"A large cyan metal sphere...\" Four. Fifth: \"A large brown metal cube...\" Five. Sixth: \"A large yellow rubber cube...\" Six. Seventh: \"A large brown rubber cylinder...\" That's seven. \n",
344 | "\n",
345 | "Wait, did I miss any? Let me check again. The list has entries from \"A small green...\" up to the seventh one. Each sentence starts with \"A\", which suggests each is a separate item. No commas separating multiple items in a single entry. Each has different attributes and coordinates, so they must all be distinct. \n",
346 | "\n",
347 | "So the answer should be 7 items.\n",
348 | "\n",
349 | "\n",
350 | "There are 7 items in the described scene. Each entry corresponds to one distinct object, listed by their properties, coordinates, and rotations.\n",
351 | "None\n"
352 | ]
353 | }
354 | ],
355 | "source": [
356 | "debug_query = format_query(qa_pairs[0])\n",
357 | "print(query_r1(debug_query))"
358 | ]
359 | },
360 | {
361 | "cell_type": "code",
362 | "execution_count": 41,
363 | "id": "cdc4231a-8ef4-4cf6-a575-d84ae7bbd0b5",
364 | "metadata": {},
365 | "outputs": [
366 | {
367 | "name": "stdout",
368 | "output_type": "stream",
369 | "text": [
370 | "Answer the question accordingly to scene description.\n",
371 | "\n",
372 | "Scene Description:\n",
373 | "A small green metal cylinder rotated 329.5° located at 3D coordinates (-2.49, -1.65, 0.35) and pixel coordinates (111, 132, 11.81)\n",
374 | "A small blue rubber cylinder rotated 312.2° located at 3D coordinates (-1.73, -2.91, 0.35) and pixel coordinates (76, 163, 10.57)\n",
375 | "A small cyan rubber cylinder rotated 48.4° located at 3D coordinates (-2.10, -0.22, 0.35) and pixel coordinates (172, 118, 12.41)\n",
376 | "A large cyan metal sphere rotated 27.4° located at 3D coordinates (1.52, -1.26, 0.70) and pixel coordinates (247, 181, 9.33)\n",
377 | "A large brown metal cube rotated 107.7° located at 3D coordinates (-0.73, 2.39, 0.70) and pixel coordinates (290, 92, 12.93)\n",
378 | "A large yellow rubber cube rotated 288.2° located at 3D coordinates (0.52, 0.63, 0.70) and pixel coordinates (279, 130, 11.09)\n",
379 | "A large brown rubber cylinder rotated 229.8° located at 3D coordinates (2.38, 0.38, 0.70) and pixel coordinates (343, 166, 9.77)\n",
380 | "\n",
381 | "Question:\n",
382 | "How many items are there in the described scene?\n"
383 | ]
384 | }
385 | ],
386 | "source": [
387 | "print(debug_query)"
388 | ]
389 | },
390 | {
391 | "cell_type": "code",
392 | "execution_count": 42,
393 | "id": "4cf90eb6-2cce-4e3d-8190-c44168a66dca",
394 | "metadata": {},
395 | "outputs": [
396 | {
397 | "data": {
398 | "text/plain": [
399 | "{'img_filename': 'CLEVR_train_044000.png',\n",
400 | " 'q': 'How many rubber objects are either small blue spheres or small things?',\n",
401 | " 'a': '2',\n",
402 | " 'description': 'Scene Description:\\nA large purple rubber sphere rotated 78.4° located at 3D coordinates (2.27, 0.87, 0.70) and pixel coordinates (360, 156, 9.49)\\nA large gray metal cube rotated 152.7° located at 3D coordinates (2.79, -1.26, 0.70) and pixel coordinates (301, 213, 7.91)\\nA large purple metal sphere rotated 79.2° located at 3D coordinates (-2.66, -2.74, 0.70) and pixel coordinates (51, 126, 10.61)\\nA large blue rubber sphere rotated 279.5° located at 3D coordinates (1.31, 2.72, 0.70) and pixel coordinates (376, 112, 11.19)\\nA small brown rubber cube rotated 124.1° located at 3D coordinates (-2.49, 2.61, 0.35) and pixel coordinates (251, 82, 13.79)\\nA small green rubber sphere rotated 323.9° located at 3D coordinates (-2.02, 0.45, 0.35) and pixel coordinates (197, 109, 12.22)\\n'}"
403 | ]
404 | },
405 | "execution_count": 42,
406 | "metadata": {},
407 | "output_type": "execute_result"
408 | }
409 | ],
410 | "source": [
411 | "qa_pairs[1]"
412 | ]
413 | },
414 | {
415 | "cell_type": "code",
416 | "execution_count": 45,
417 | "id": "33fcd4eb-1f33-47d4-a453-76ef00e6d5d3",
418 | "metadata": {},
419 | "outputs": [
420 | {
421 | "name": "stdout",
422 | "output_type": "stream",
423 | "text": [
424 | "Okay, let's tackle this question. The user wants to know how many rubber objects are either small blue spheres or small things. Hmm, first, I need to parse each part of the question correctly.\n",
425 | "\n",
426 | "Let's go back to the scene description. The objects are listed with their attributes: size, color, material, shape, rotation, 3D and pixel coordinates. The key here is to filter the objects based on the criteria given.\n",
427 | "\n",
428 | "The question has two parts: \"small blue spheres\" OR \"small things\". So any rubber object that is either a small blue sphere or any small thing (regardless of other attributes) counts. But wait, do both categories need to be rubber? Because the question says \"rubber objects are either...\". So rubber is the material, and the condition is either being a small blue sphere or a small thing. So first, all rubber objects, and among them, count those that are either small blue spheres or small (regardless of color or shape). Wait, no. Let me re-read the question.\n",
429 | "\n",
430 | "The question is: \"How many rubber objects are either small blue spheres or small things?\" So rubber is the material. Within all rubber objects, count the ones that are either (1) small blue spheres OR (2) small things. Wait, does (2) being small things mean that even if they're small and of any color or shape, but they must be rubber?\n",
431 | "\n",
432 | "Yes, because the entire set is rubber objects. So first, select all objects where material is rubber. Then, within those, count how many are either (1) small, blue, sphere, or (2) small (any color or shape). Wait, but the structure is \"either X or Y\", where X is \"small blue sphere\" and Y is \"small things\". But \"small things\" would include all small objects, regardless of color and shape. However, since we've already narrowed it to rubber objects, \"small things\" here would be small rubber objects, regardless of color and shape.\n",
433 | "\n",
434 | "But wait, the condition is within rubber objects. So for the first part, small blue spheres (must check size, color, shape) and for the second part, small things (size is small, any color and shape, but since material is already rubber, that's covered). But wait, does the OR merge the two conditions, leading to rubber objects that are either (small blue spheres) or (small any-color any-shape).\n",
435 | "\n",
436 | "So the combined condition is: object is rubber AND ( (is small AND blue AND sphere) OR (is small) ). Wait, but if the condition for the second part is just \"small things\", which would imply any small object. But the entire group is rubber objects. So it's rubber objects that are small blue spheres OR rubber objects that are small (regardless of color or shape).\n",
437 | "\n",
438 | "Wait, no. Let's parse the sentence again: \"rubber objects are either small blue spheres or small things\". The \"either/or\" applies to \"small blue spheres\" and \"small things\". So, each rubber object has to be either (a small blue sphere) or (a small thing). However, \"small things\" here might refer to any small object regardless of other attributes. So if a rubber object is small, regardless of color or shape, it counts. But then, the first condition (small blue sphere) would also satisfy being a small thing. Wait, so there's an overlap. But when dealing with OR conditions, we have to avoid double-counting. So, the actual count is the number of small rubber objects (since any small rubber object is covered by the second part, which includes all small rubber objects, whether blue spheres or not) plus any objects that are small blue spheres but not rubber? But no, the question specifies \"rubber objects\", so we can ignore non-rubber ones.\n",
439 | "\n",
440 | "Wait, perhaps the wording is: \"rubber objects that are either small blue spheres or small things\". So \"small things\" here must reference other attributes. Wait, maybe there's ambiguity here. If the user is grouping \"small things\" as a separate category, regardless of being the other attributes. Let me try to approach this step by step.\n",
441 | "\n",
442 | "First, list all the rubber objects from the scene description:\n",
443 | "\n",
444 | "Looking through the list:\n",
445 | "\n",
446 | "1. A large purple rubber sphere ... location etc.\n",
447 | "So material rubber, large, purple, sphere.\n",
448 | "\n",
449 | "2. A large gray metal cube ... metal, so not rubber.\n",
450 | "\n",
451 | "3. A large purple metal sphere ... metal, not rubber.\n",
452 | "\n",
453 | "4. A large blue rubber sphere ... rubber, large, blue, sphere.\n",
454 | "\n",
455 | "5. A small brown rubber cube ... rubber, small, brown, cube.\n",
456 | "\n",
457 | "6. A small green rubber sphere ... rubber, small, green, sphere.\n",
458 | "\n",
459 | "So the rubber objects are items 1,4,5,6.\n",
460 | "\n",
461 | "Now, for each of these four rubber objects, check if they are either (small blue sphere) or (small things).\n",
462 | "\n",
463 | "Let's check each:\n",
464 | "\n",
465 | "1. Large purple rubber sphere. Size: large. So for the first condition (small blue sphere): no. For the second condition (small thing): size large, so no. Not included.\n",
466 | "\n",
467 | "4. Large blue rubber sphere. Large. So, even though it's blue and sphere, the size is large. So (small blue sphere: no, since size is large). (small thing: no.) So not included.\n",
468 | "\n",
469 | "5. Small brown rubber cube. Size is small. So regardless of color and shape (brown, cube), this is a small thing. So it meets the second condition. So count it.\n",
470 | "\n",
471 | "6. Small green rubber sphere. Size: small. So meets the second condition (small thing). But it's a sphere and green, not blue. So this counts under the second category.\n",
472 | "\n",
473 | "Now, what about the first condition: small blue spheres. Are there any rubber objects that are small blue spheres? The only blue rubber sphere is item 4, which is large. So none of the rubber objects are small blue spheres.\n",
474 | "\n",
475 | "So now the two options under the OR are:\n",
476 | "\n",
477 | "- small blue sphere (none found)\n",
478 | "- small things (items 5 and 6), both are rubber.\n",
479 | "\n",
480 | "Therefore total count is 0 + 2 = 2.\n",
481 | "\n",
482 | "Wait, but the initial question is: number of rubber objects that are either (small blue spheres) or (small things). Since \"small blue spheres\" are a subset of \"small things\" (since if something is a small blue sphere, it is a small thing). Therefore if there were any such objects, they would be counted twice, but in this case, since there are none, the total is just the count of small things.\n",
483 | "\n",
484 | "But according to the data, the rubber objects that are small are two: item 5 (small brown cube) and item 6 (small green sphere). Neither is a blue sphere, so no overlap. Therefore, sum is 2.\n",
485 | "\n",
486 | "So the answer should be 2.\n",
487 | "\n",
488 | "I need to confirm again whether the OR condition includes the union, meaning any rubber object that is either condition. So if either condition is true, count it.\n",
489 | "\n",
490 | "Since both 5 and 6 are small, even though they aren't blue spheres, they belong to the second category (small things). The blue sphere (if it were small) would have been counted as well. But in this case, the valid counts are 5 and 6. So yes, answer is 2.\n",
491 | "\n",
492 | "\n",
493 | "The answer is 2. \n",
494 | "\n",
495 | "**Explanation:**\n",
496 | "First, identify all rubber objects from the scene description:\n",
497 | "1. Large purple rubber sphere (not small)\n",
498 | "2. Large blue rubber sphere (not small)\n",
499 | "3. Small brown rubber cube (small)\n",
500 | "4. Small green rubber sphere (small)\n",
501 | "\n",
502 | "Next, apply the criteria:\n",
503 | "- **Small blue spheres**: None of the rubber objects meet this (the only blue rubber sphere is large).\n",
504 | "- **Small rubber objects (regardless of color/shape)**: The small brown rubber cube and small green rubber sphere qualify (2 objects).\n",
505 | "\n",
506 | "Thus, there are **2 rubber objects** that fit either criterion.\n"
507 | ]
508 | }
509 | ],
510 | "source": [
511 | "debug_query1 = format_query(qa_pairs[1])\n",
512 | "res1 = query_r1(debug_query1)"
513 | ]
514 | },
515 | {
516 | "cell_type": "code",
517 | "execution_count": 47,
518 | "id": "8e516bd0-f1e5-4898-88a3-3afcaf0ae34e",
519 | "metadata": {},
520 | "outputs": [
521 | {
522 | "data": {
523 | "text/plain": [
524 | "{'img_filename': 'CLEVR_train_044000.png',\n",
525 | " 'q': 'How many rubber objects are either small blue spheres or small things?',\n",
526 | " 'a': '2',\n",
527 | " 'description': 'Scene Description:\\nA large purple rubber sphere rotated 78.4° located at 3D coordinates (2.27, 0.87, 0.70) and pixel coordinates (360, 156, 9.49)\\nA large gray metal cube rotated 152.7° located at 3D coordinates (2.79, -1.26, 0.70) and pixel coordinates (301, 213, 7.91)\\nA large purple metal sphere rotated 79.2° located at 3D coordinates (-2.66, -2.74, 0.70) and pixel coordinates (51, 126, 10.61)\\nA large blue rubber sphere rotated 279.5° located at 3D coordinates (1.31, 2.72, 0.70) and pixel coordinates (376, 112, 11.19)\\nA small brown rubber cube rotated 124.1° located at 3D coordinates (-2.49, 2.61, 0.35) and pixel coordinates (251, 82, 13.79)\\nA small green rubber sphere rotated 323.9° located at 3D coordinates (-2.02, 0.45, 0.35) and pixel coordinates (197, 109, 12.22)\\n'}"
528 | ]
529 | },
530 | "execution_count": 47,
531 | "metadata": {},
532 | "output_type": "execute_result"
533 | }
534 | ],
535 | "source": [
536 | "qa_pairs[1]"
537 | ]
538 | },
539 | {
540 | "cell_type": "code",
541 | "execution_count": null,
542 | "id": "92784518-49e2-443d-9541-2785cbb944cf",
543 | "metadata": {},
544 | "outputs": [],
545 | "source": []
546 | }
547 | ],
548 | "metadata": {
549 | "kernelspec": {
550 | "display_name": "Python 3 (ipykernel)",
551 | "language": "python",
552 | "name": "python3"
553 | },
554 | "language_info": {
555 | "codemirror_mode": {
556 | "name": "ipython",
557 | "version": 3
558 | },
559 | "file_extension": ".py",
560 | "mimetype": "text/x-python",
561 | "name": "python",
562 | "nbconvert_exporter": "python",
563 | "pygments_lexer": "ipython3",
564 | "version": "3.12.2"
565 | }
566 | },
567 | "nbformat": 4,
568 | "nbformat_minor": 5
569 | }
570 |
--------------------------------------------------------------------------------
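To make the notebook's cell-by-cell QA construction easier to reuse outside Jupyter, here is a hedged, self-contained sketch that mirrors the same logic; the build_qa_pairs wrapper, its file-path arguments, and the split default are illustrative assumptions rather than code shipped in the repository.

# Hedged, standalone sketch of the counting-QA construction performed in the notebook above.
# The scene/question paths are placeholders; point them at local CLEVR CoGenT JSON files.
import json
import random

def object_info_to_description(object_list):
    """Render one CLEVR scene's object list as the textual scene description used above."""
    random.shuffle(object_list)
    lines = ["Scene Description:"]
    for obj in object_list:
        lines.append(
            f"A {obj['size']} {obj['color']} {obj['material']} {obj['shape']}"
            f" rotated {obj['rotation']:.1f}° located at"
            f" 3D coordinates ({obj['3d_coords'][0]:.2f}, {obj['3d_coords'][1]:.2f}, {obj['3d_coords'][2]:.2f})"
            f" and pixel coordinates ({obj['pixel_coords'][0]}, {obj['pixel_coords'][1]}, {obj['pixel_coords'][2]:.2f})"
        )
    return "\n".join(lines) + "\n"

def build_qa_pairs(scenes_path, questions_path, split="trainA"):
    """Build counting QA pairs (real counting questions plus one synthetic count per image)."""
    scenes = json.load(open(scenes_path))
    questions = json.load(open(questions_path))
    img2obj = {s["image_filename"]: s["objects"] for s in scenes["scenes"]}

    random.shuffle(questions["questions"])
    qa_pairs, added = [], set()
    for qd in questions["questions"]:
        img = qd["image_filename"]
        desc = object_info_to_description(img2obj[img])
        q, a = qd["question"], qd["answer"]
        # Keep only counting-style questions, as in the notebook.
        if "how many" in q.lower() or "number" in q.lower():
            qa_pairs.append({"img_filename": img, "q": q, "a": a, "description": desc})
        # Add one synthetic "how many items" question per image.
        if img not in added:
            qa_pairs.append({
                "img_filename": img,
                "q": "How many items are there in the described scene?",
                "a": len(img2obj[img]),
                "description": desc,
            })
            added.add(img)
    with open(f"clever_counting_problems_clevr_cogent_v1.0_{split}.json", "w") as fw:
        json.dump(qa_pairs, fw, indent=4)
    return qa_pairs

Pointed at local CLEVR CoGenT scenes/questions JSON files, this regenerates the clever_counting_problems_*.json that query_r1.py below expects.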
/src/distill_r1/grpo_r1_distilled.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hustvl/AlphaDrive/bb4104c680a3e4d70e7e998e08de90f6e0acf8c5/src/distill_r1/grpo_r1_distilled.jpg
--------------------------------------------------------------------------------
/src/distill_r1/prompt.py:
--------------------------------------------------------------------------------
1 | R1_SYS_PROMPT = """You are DeepSeek-R1, an AI assistant created exclusively by the Chinese Company DeepSeek. You'll provide helpful, harmless, and detailed responses to all user inquiries. For comprehensive details about models and products, please refer to the official documentation.
2 |
3 | Key Guidelines:
4 | Identity & Compliance
5 |
6 | Clearly state your identity as a DeepSeek AI assistant in initial responses.
7 |
8 | Comply with Chinese laws and regulations, including data privacy requirements.
9 |
10 | Capability Scope
11 |
12 | Handle both Chinese and English queries effectively
13 |
14 | Acknowledge limitations for real-time information post knowledge cutoff (2023-12)
15 |
16 | Provide technical explanations for AI-related questions when appropriate
17 |
18 | Response Quality
19 |
20 | Give comprehensive, logically structured answers
21 |
22 | Use markdown formatting for clear information organization
23 |
24 | Admit uncertainties for ambiguous queries
25 |
26 | Ethical Operation
27 |
28 | Strictly refuse requests involving illegal activities, violence, or explicit content
29 |
30 | Maintain political neutrality according to company guidelines
31 |
32 | Protect user privacy and avoid data collection
33 |
34 | Specialized Processing
35 |
36 | Use ... tags for internal reasoning before responding
37 |
38 | Employ XML-like tags for structured output when required
39 | """
--------------------------------------------------------------------------------
/src/distill_r1/query_r1.py:
--------------------------------------------------------------------------------
1 | import json
2 | import random
3 | import os
4 | from openai import OpenAI
5 | from tqdm import tqdm
6 | import concurrent.futures
7 | from typing import List, Dict, Optional
8 | from datetime import datetime
9 | from threading import Lock
10 | import time
11 | from prompt import R1_SYS_PROMPT
12 | # Initialize the client
13 | client = OpenAI(
14 |     api_key=os.environ.get("SL_KEY", "YOUR_SILICONFLOW_KEY"),
15 | base_url="https://api.siliconflow.cn/v1",
16 | )
17 |
18 | # Create a lock for thread-safe file writing
19 | file_lock = Lock()
20 |
21 | def format_query(qa_dict: Dict, v2=False) -> str:
22 | query = "Answer the question according to scene description.\n\n"
23 | query += qa_dict["description"]
24 | query += f"\nQuestion:\n{qa_dict['q']}"
25 | if v2:
26 | query += "\nInstructions:\n"
27 | query += "1. Carefully analyze the scene description\n"
28 | query += "2. Provide your reasoning if necessary\n"
29 | query += "3. For the final answer, start a new line with '**The answer is: **' followed by your answer\n"
30 | return query
31 |
32 | def write_to_jsonl(result: Dict, filename: str):
33 | """Thread-safe function to write a result to JSONL file"""
34 | with file_lock:
35 | with open(filename, 'a') as f:
36 | f.write(json.dumps(result) + '\n')
37 |
38 | def query_r1(qa_pair: Dict, output_file: str, model: str = "deepseek-ai/DeepSeek-R1", v2=False) -> Optional[Dict]:
39 | query = format_query(qa_pair, v2=v2)
40 | try:
41 | response = client.chat.completions.create(
42 | model=model,
43 | messages=[
44 | {"role": "system", "content": R1_SYS_PROMPT},
45 | {"role": "user", "content": query}],
46 | stream=False,
47 | max_tokens=4096
48 | )
49 | result = {
50 | **qa_pair,
51 | "r1_response": response.choices[0].message.content,
52 | "timestamp": datetime.now().isoformat()
53 | }
54 | # Write result immediately
55 | write_to_jsonl(result, output_file)
56 | time.sleep(4)
57 | return result
58 | except Exception as e:
59 | print(f"Error processing query: {e}")
60 | error_result = {
61 | **qa_pair,
62 | "error": str(e),
63 | "timestamp": datetime.now().isoformat()
64 | }
65 | write_to_jsonl(error_result, f"errors_{output_file}")
66 | time.sleep(10)
67 | return None
68 |
69 | def process_qa_pairs_parallel(qa_pairs: List[Dict], output_file: str, max_workers: int = 10) -> List[Dict]:
70 | successful_count = 0
71 |
72 | with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
73 | # Create futures for all qa_pairs
74 | futures = [executor.submit(query_r1, qa_pair, output_file, v2="v2" in output_file) for qa_pair in qa_pairs]
75 |
76 | # Process results as they complete with progress bar
77 | results = []
78 | for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures)):
79 | try:
80 | result = future.result()
81 | if result is not None:
82 | results.append(result)
83 | successful_count += 1
84 | except Exception as e:
85 | print(f"Failed to process query: {e}")
86 |
87 | return results
88 |
89 | if __name__ == "__main__":
90 | # Load and shuffle QA pairs
91 | random.seed(1234)
92 | qa_pairs = json.load(open("/home/lilei/Visual-R1/data/clever_counting_problems_clevr_cogent_v1.0_trainA.json"))
93 | random.shuffle(qa_pairs)
94 | qa_pairs = qa_pairs[:10000]
95 |     # Fixed output filename: kept stable (no timestamp) so interrupted runs can be resumed below;
96 |     # the "_v2" suffix makes process_qa_pairs_parallel use the v2 prompt format.
97 |     output_file = "r1_results_clevr_cogent_v1.0_trainA_v2.jsonl"
98 |
99 | finished = set()
100 |     if os.path.exists(output_file):  # resume support: skip pairs already answered in a previous run
101 |         with open(output_file, 'r') as f:
102 |             for line in f:
103 |                 ins = json.loads(line)
104 |                 finished.add(ins["img_filename"] + "-" + ins["q"] + "-" + str(ins["a"]))
105 |     qa_pairs = [ins for ins in qa_pairs if ins["img_filename"] + "-" + ins["q"] + "-" + str(ins["a"]) not in finished]
106 | print("Finished: ", len(finished))
107 | print("Remaining: ", len(qa_pairs))
108 | # Process QA pairs in parallel
109 | r1_results = process_qa_pairs_parallel(qa_pairs, output_file)
110 |
111 | # Print final statistics
112 | print(f"Successfully processed {len(r1_results)} out of {len(qa_pairs)} queries")
113 | print(f"Results saved to {output_file}")
114 | print(f"Any errors were saved to errors_{output_file}")
--------------------------------------------------------------------------------
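As a hedged illustration of how the r1_results_*.jsonl written by query_r1.py above could be post-filtered before training, one simple filter keeps only records whose text after "The answer is:" contains the ground-truth answer. This sketch is an assumption for illustration only, not the repository's own filtering code, and the marker-based heuristic follows the v2 prompt format defined above.

# Illustrative post-filter for r1_results_*.jsonl (an assumption, not the repo's filtering code).
import json
import re

def extract_final_answer(response: str) -> str:
    """Take the text after the last 'The answer is:' marker, else the last alphanumeric token."""
    matches = re.findall(r"The answer is:\s*\**\s*(.+)", response)
    if matches:
        return matches[-1].strip().strip("*").strip()
    tokens = re.findall(r"[A-Za-z0-9]+", response)
    return tokens[-1] if tokens else ""

def filter_results(in_path: str, out_path: str) -> None:
    kept = total = 0
    with open(in_path) as fin, open(out_path, "w") as fout:
        for line in fin:
            record = json.loads(line)
            total += 1
            if "error" in record:
                continue
            predicted = extract_final_answer(record.get("r1_response", "") or "").lower()
            if str(record["a"]).strip().lower() in predicted:
                fout.write(json.dumps(record) + "\n")
                kept += 1
    print(f"kept {kept}/{total} records")

if __name__ == "__main__":
    filter_results("r1_results_clevr_cogent_v1.0_trainA_v2.jsonl",
                   "r1_results_clevr_cogent_v1.0_trainA_v2.filtered.jsonl")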
/src/eval/prompts/superclevr_test200_counting_problems.jsonl:
--------------------------------------------------------------------------------
1 | {"image_path": "./images/superCLEVR_new_025000.png", "question": "How many different items are there in the image?", "ground_truth": 4}
2 | {"image_path": "./images/superCLEVR_new_025001.png", "question": "How many different items are there in the image?", "ground_truth": 9}
3 | {"image_path": "./images/superCLEVR_new_025002.png", "question": "How many different items are there in the image?", "ground_truth": 10}
4 | {"image_path": "./images/superCLEVR_new_025003.png", "question": "How many different items are there in the image?", "ground_truth": 4}
5 | {"image_path": "./images/superCLEVR_new_025004.png", "question": "How many different items are there in the image?", "ground_truth": 3}
6 | {"image_path": "./images/superCLEVR_new_025005.png", "question": "How many different items are there in the image?", "ground_truth": 3}
7 | {"image_path": "./images/superCLEVR_new_025006.png", "question": "How many different items are there in the image?", "ground_truth": 3}
8 | {"image_path": "./images/superCLEVR_new_025007.png", "question": "How many different items are there in the image?", "ground_truth": 4}
9 | {"image_path": "./images/superCLEVR_new_025008.png", "question": "How many different items are there in the image?", "ground_truth": 9}
10 | {"image_path": "./images/superCLEVR_new_025009.png", "question": "How many different items are there in the image?", "ground_truth": 10}
11 | {"image_path": "./images/superCLEVR_new_025010.png", "question": "How many different items are there in the image?", "ground_truth": 7}
12 | {"image_path": "./images/superCLEVR_new_025011.png", "question": "How many different items are there in the image?", "ground_truth": 7}
13 | {"image_path": "./images/superCLEVR_new_025012.png", "question": "How many different items are there in the image?", "ground_truth": 7}
14 | {"image_path": "./images/superCLEVR_new_025013.png", "question": "How many different items are there in the image?", "ground_truth": 6}
15 | {"image_path": "./images/superCLEVR_new_025014.png", "question": "How many different items are there in the image?", "ground_truth": 5}
16 | {"image_path": "./images/superCLEVR_new_025015.png", "question": "How many different items are there in the image?", "ground_truth": 10}
17 | {"image_path": "./images/superCLEVR_new_025016.png", "question": "How many different items are there in the image?", "ground_truth": 4}
18 | {"image_path": "./images/superCLEVR_new_025017.png", "question": "How many different items are there in the image?", "ground_truth": 5}
19 | {"image_path": "./images/superCLEVR_new_025018.png", "question": "How many different items are there in the image?", "ground_truth": 6}
20 | {"image_path": "./images/superCLEVR_new_025019.png", "question": "How many different items are there in the image?", "ground_truth": 8}
21 | {"image_path": "./images/superCLEVR_new_025020.png", "question": "How many different items are there in the image?", "ground_truth": 10}
22 | {"image_path": "./images/superCLEVR_new_025021.png", "question": "How many different items are there in the image?", "ground_truth": 3}
23 | {"image_path": "./images/superCLEVR_new_025022.png", "question": "How many different items are there in the image?", "ground_truth": 4}
24 | {"image_path": "./images/superCLEVR_new_025023.png", "question": "How many different items are there in the image?", "ground_truth": 4}
25 | {"image_path": "./images/superCLEVR_new_025024.png", "question": "How many different items are there in the image?", "ground_truth": 5}
26 | {"image_path": "./images/superCLEVR_new_025025.png", "question": "How many different items are there in the image?", "ground_truth": 5}
27 | {"image_path": "./images/superCLEVR_new_025026.png", "question": "How many different items are there in the image?", "ground_truth": 7}
28 | {"image_path": "./images/superCLEVR_new_025027.png", "question": "How many different items are there in the image?", "ground_truth": 4}
29 | {"image_path": "./images/superCLEVR_new_025028.png", "question": "How many different items are there in the image?", "ground_truth": 4}
30 | {"image_path": "./images/superCLEVR_new_025029.png", "question": "How many different items are there in the image?", "ground_truth": 9}
31 | {"image_path": "./images/superCLEVR_new_025030.png", "question": "How many different items are there in the image?", "ground_truth": 8}
32 | {"image_path": "./images/superCLEVR_new_025031.png", "question": "How many different items are there in the image?", "ground_truth": 6}
33 | {"image_path": "./images/superCLEVR_new_025032.png", "question": "How many different items are there in the image?", "ground_truth": 3}
34 | {"image_path": "./images/superCLEVR_new_025033.png", "question": "How many different items are there in the image?", "ground_truth": 10}
35 | {"image_path": "./images/superCLEVR_new_025034.png", "question": "How many different items are there in the image?", "ground_truth": 9}
36 | {"image_path": "./images/superCLEVR_new_025035.png", "question": "How many different items are there in the image?", "ground_truth": 9}
37 | {"image_path": "./images/superCLEVR_new_025036.png", "question": "How many different items are there in the image?", "ground_truth": 3}
38 | {"image_path": "./images/superCLEVR_new_025037.png", "question": "How many different items are there in the image?", "ground_truth": 6}
39 | {"image_path": "./images/superCLEVR_new_025038.png", "question": "How many different items are there in the image?", "ground_truth": 6}
40 | {"image_path": "./images/superCLEVR_new_025039.png", "question": "How many different items are there in the image?", "ground_truth": 5}
41 | {"image_path": "./images/superCLEVR_new_025040.png", "question": "How many different items are there in the image?", "ground_truth": 3}
42 | {"image_path": "./images/superCLEVR_new_025041.png", "question": "How many different items are there in the image?", "ground_truth": 10}
43 | {"image_path": "./images/superCLEVR_new_025042.png", "question": "How many different items are there in the image?", "ground_truth": 6}
44 | {"image_path": "./images/superCLEVR_new_025043.png", "question": "How many different items are there in the image?", "ground_truth": 3}
45 | {"image_path": "./images/superCLEVR_new_025044.png", "question": "How many different items are there in the image?", "ground_truth": 6}
46 | {"image_path": "./images/superCLEVR_new_025045.png", "question": "How many different items are there in the image?", "ground_truth": 5}
47 | {"image_path": "./images/superCLEVR_new_025046.png", "question": "How many different items are there in the image?", "ground_truth": 7}
48 | {"image_path": "./images/superCLEVR_new_025047.png", "question": "How many different items are there in the image?", "ground_truth": 5}
49 | {"image_path": "./images/superCLEVR_new_025048.png", "question": "How many different items are there in the image?", "ground_truth": 5}
50 | {"image_path": "./images/superCLEVR_new_025049.png", "question": "How many different items are there in the image?", "ground_truth": 10}
51 | {"image_path": "./images/superCLEVR_new_025050.png", "question": "How many different items are there in the image?", "ground_truth": 6}
52 | {"image_path": "./images/superCLEVR_new_025051.png", "question": "How many different items are there in the image?", "ground_truth": 3}
53 | {"image_path": "./images/superCLEVR_new_025052.png", "question": "How many different items are there in the image?", "ground_truth": 7}
54 | {"image_path": "./images/superCLEVR_new_025053.png", "question": "How many different items are there in the image?", "ground_truth": 9}
55 | {"image_path": "./images/superCLEVR_new_025054.png", "question": "How many different items are there in the image?", "ground_truth": 7}
56 | {"image_path": "./images/superCLEVR_new_025055.png", "question": "How many different items are there in the image?", "ground_truth": 6}
57 | {"image_path": "./images/superCLEVR_new_025056.png", "question": "How many different items are there in the image?", "ground_truth": 9}
58 | {"image_path": "./images/superCLEVR_new_025057.png", "question": "How many different items are there in the image?", "ground_truth": 8}
59 | {"image_path": "./images/superCLEVR_new_025058.png", "question": "How many different items are there in the image?", "ground_truth": 10}
60 | {"image_path": "./images/superCLEVR_new_025059.png", "question": "How many different items are there in the image?", "ground_truth": 10}
61 | {"image_path": "./images/superCLEVR_new_025060.png", "question": "How many different items are there in the image?", "ground_truth": 8}
62 | {"image_path": "./images/superCLEVR_new_025061.png", "question": "How many different items are there in the image?", "ground_truth": 8}
63 | {"image_path": "./images/superCLEVR_new_025062.png", "question": "How many different items are there in the image?", "ground_truth": 8}
64 | {"image_path": "./images/superCLEVR_new_025063.png", "question": "How many different items are there in the image?", "ground_truth": 10}
65 | {"image_path": "./images/superCLEVR_new_025064.png", "question": "How many different items are there in the image?", "ground_truth": 3}
66 | {"image_path": "./images/superCLEVR_new_025065.png", "question": "How many different items are there in the image?", "ground_truth": 4}
67 | {"image_path": "./images/superCLEVR_new_025066.png", "question": "How many different items are there in the image?", "ground_truth": 6}
68 | {"image_path": "./images/superCLEVR_new_025067.png", "question": "How many different items are there in the image?", "ground_truth": 7}
69 | {"image_path": "./images/superCLEVR_new_025068.png", "question": "How many different items are there in the image?", "ground_truth": 3}
70 | {"image_path": "./images/superCLEVR_new_025069.png", "question": "How many different items are there in the image?", "ground_truth": 10}
71 | {"image_path": "./images/superCLEVR_new_025070.png", "question": "How many different items are there in the image?", "ground_truth": 9}
72 | {"image_path": "./images/superCLEVR_new_025071.png", "question": "How many different items are there in the image?", "ground_truth": 6}
73 | {"image_path": "./images/superCLEVR_new_025072.png", "question": "How many different items are there in the image?", "ground_truth": 10}
74 | {"image_path": "./images/superCLEVR_new_025073.png", "question": "How many different items are there in the image?", "ground_truth": 5}
75 | {"image_path": "./images/superCLEVR_new_025074.png", "question": "How many different items are there in the image?", "ground_truth": 9}
76 | {"image_path": "./images/superCLEVR_new_025075.png", "question": "How many different items are there in the image?", "ground_truth": 3}
77 | {"image_path": "./images/superCLEVR_new_025076.png", "question": "How many different items are there in the image?", "ground_truth": 5}
78 | {"image_path": "./images/superCLEVR_new_025077.png", "question": "How many different items are there in the image?", "ground_truth": 5}
79 | {"image_path": "./images/superCLEVR_new_025078.png", "question": "How many different items are there in the image?", "ground_truth": 5}
80 | {"image_path": "./images/superCLEVR_new_025079.png", "question": "How many different items are there in the image?", "ground_truth": 9}
81 | {"image_path": "./images/superCLEVR_new_025080.png", "question": "How many different items are there in the image?", "ground_truth": 5}
82 | {"image_path": "./images/superCLEVR_new_025081.png", "question": "How many different items are there in the image?", "ground_truth": 5}
83 | {"image_path": "./images/superCLEVR_new_025082.png", "question": "How many different items are there in the image?", "ground_truth": 10}
84 | {"image_path": "./images/superCLEVR_new_025083.png", "question": "How many different items are there in the image?", "ground_truth": 4}
85 | {"image_path": "./images/superCLEVR_new_025084.png", "question": "How many different items are there in the image?", "ground_truth": 8}
86 | {"image_path": "./images/superCLEVR_new_025085.png", "question": "How many different items are there in the image?", "ground_truth": 8}
87 | {"image_path": "./images/superCLEVR_new_025086.png", "question": "How many different items are there in the image?", "ground_truth": 10}
88 | {"image_path": "./images/superCLEVR_new_025087.png", "question": "How many different items are there in the image?", "ground_truth": 9}
89 | {"image_path": "./images/superCLEVR_new_025088.png", "question": "How many different items are there in the image?", "ground_truth": 3}
90 | {"image_path": "./images/superCLEVR_new_025089.png", "question": "How many different items are there in the image?", "ground_truth": 4}
91 | {"image_path": "./images/superCLEVR_new_025090.png", "question": "How many different items are there in the image?", "ground_truth": 9}
92 | {"image_path": "./images/superCLEVR_new_025091.png", "question": "How many different items are there in the image?", "ground_truth": 7}
93 | {"image_path": "./images/superCLEVR_new_025092.png", "question": "How many different items are there in the image?", "ground_truth": 6}
94 | {"image_path": "./images/superCLEVR_new_025093.png", "question": "How many different items are there in the image?", "ground_truth": 10}
95 | {"image_path": "./images/superCLEVR_new_025094.png", "question": "How many different items are there in the image?", "ground_truth": 6}
96 | {"image_path": "./images/superCLEVR_new_025095.png", "question": "How many different items are there in the image?", "ground_truth": 6}
97 | {"image_path": "./images/superCLEVR_new_025096.png", "question": "How many different items are there in the image?", "ground_truth": 8}
98 | {"image_path": "./images/superCLEVR_new_025097.png", "question": "How many different items are there in the image?", "ground_truth": 7}
99 | {"image_path": "./images/superCLEVR_new_025098.png", "question": "How many different items are there in the image?", "ground_truth": 10}
100 | {"image_path": "./images/superCLEVR_new_025099.png", "question": "How many different items are there in the image?", "ground_truth": 10}
101 | {"image_path": "./images/superCLEVR_new_025100.png", "question": "How many different items are there in the image?", "ground_truth": 5}
102 | {"image_path": "./images/superCLEVR_new_025101.png", "question": "How many different items are there in the image?", "ground_truth": 7}
103 | {"image_path": "./images/superCLEVR_new_025102.png", "question": "How many different items are there in the image?", "ground_truth": 3}
104 | {"image_path": "./images/superCLEVR_new_025103.png", "question": "How many different items are there in the image?", "ground_truth": 6}
105 | {"image_path": "./images/superCLEVR_new_025104.png", "question": "How many different items are there in the image?", "ground_truth": 9}
106 | {"image_path": "./images/superCLEVR_new_025105.png", "question": "How many different items are there in the image?", "ground_truth": 7}
107 | {"image_path": "./images/superCLEVR_new_025106.png", "question": "How many different items are there in the image?", "ground_truth": 8}
108 | {"image_path": "./images/superCLEVR_new_025107.png", "question": "How many different items are there in the image?", "ground_truth": 8}
109 | {"image_path": "./images/superCLEVR_new_025108.png", "question": "How many different items are there in the image?", "ground_truth": 3}
110 | {"image_path": "./images/superCLEVR_new_025109.png", "question": "How many different items are there in the image?", "ground_truth": 7}
111 | {"image_path": "./images/superCLEVR_new_025110.png", "question": "How many different items are there in the image?", "ground_truth": 8}
112 | {"image_path": "./images/superCLEVR_new_025111.png", "question": "How many different items are there in the image?", "ground_truth": 9}
113 | {"image_path": "./images/superCLEVR_new_025112.png", "question": "How many different items are there in the image?", "ground_truth": 9}
114 | {"image_path": "./images/superCLEVR_new_025113.png", "question": "How many different items are there in the image?", "ground_truth": 6}
115 | {"image_path": "./images/superCLEVR_new_025114.png", "question": "How many different items are there in the image?", "ground_truth": 6}
116 | {"image_path": "./images/superCLEVR_new_025115.png", "question": "How many different items are there in the image?", "ground_truth": 9}
117 | {"image_path": "./images/superCLEVR_new_025116.png", "question": "How many different items are there in the image?", "ground_truth": 7}
118 | {"image_path": "./images/superCLEVR_new_025117.png", "question": "How many different items are there in the image?", "ground_truth": 9}
119 | {"image_path": "./images/superCLEVR_new_025118.png", "question": "How many different items are there in the image?", "ground_truth": 5}
120 | {"image_path": "./images/superCLEVR_new_025119.png", "question": "How many different items are there in the image?", "ground_truth": 9}
121 | {"image_path": "./images/superCLEVR_new_025120.png", "question": "How many different items are there in the image?", "ground_truth": 6}
122 | {"image_path": "./images/superCLEVR_new_025121.png", "question": "How many different items are there in the image?", "ground_truth": 10}
123 | {"image_path": "./images/superCLEVR_new_025122.png", "question": "How many different items are there in the image?", "ground_truth": 10}
124 | {"image_path": "./images/superCLEVR_new_025123.png", "question": "How many different items are there in the image?", "ground_truth": 6}
125 | {"image_path": "./images/superCLEVR_new_025124.png", "question": "How many different items are there in the image?", "ground_truth": 8}
126 | {"image_path": "./images/superCLEVR_new_025125.png", "question": "How many different items are there in the image?", "ground_truth": 8}
127 | {"image_path": "./images/superCLEVR_new_025126.png", "question": "How many different items are there in the image?", "ground_truth": 3}
128 | {"image_path": "./images/superCLEVR_new_025127.png", "question": "How many different items are there in the image?", "ground_truth": 7}
129 | {"image_path": "./images/superCLEVR_new_025128.png", "question": "How many different items are there in the image?", "ground_truth": 6}
130 | {"image_path": "./images/superCLEVR_new_025129.png", "question": "How many different items are there in the image?", "ground_truth": 4}
131 | {"image_path": "./images/superCLEVR_new_025130.png", "question": "How many different items are there in the image?", "ground_truth": 5}
132 | {"image_path": "./images/superCLEVR_new_025131.png", "question": "How many different items are there in the image?", "ground_truth": 8}
133 | {"image_path": "./images/superCLEVR_new_025132.png", "question": "How many different items are there in the image?", "ground_truth": 3}
134 | {"image_path": "./images/superCLEVR_new_025133.png", "question": "How many different items are there in the image?", "ground_truth": 5}
135 | {"image_path": "./images/superCLEVR_new_025134.png", "question": "How many different items are there in the image?", "ground_truth": 8}
136 | {"image_path": "./images/superCLEVR_new_025135.png", "question": "How many different items are there in the image?", "ground_truth": 8}
137 | {"image_path": "./images/superCLEVR_new_025136.png", "question": "How many different items are there in the image?", "ground_truth": 6}
138 | {"image_path": "./images/superCLEVR_new_025137.png", "question": "How many different items are there in the image?", "ground_truth": 5}
139 | {"image_path": "./images/superCLEVR_new_025138.png", "question": "How many different items are there in the image?", "ground_truth": 3}
140 | {"image_path": "./images/superCLEVR_new_025139.png", "question": "How many different items are there in the image?", "ground_truth": 4}
141 | {"image_path": "./images/superCLEVR_new_025140.png", "question": "How many different items are there in the image?", "ground_truth": 3}
142 | {"image_path": "./images/superCLEVR_new_025141.png", "question": "How many different items are there in the image?", "ground_truth": 9}
143 | {"image_path": "./images/superCLEVR_new_025142.png", "question": "How many different items are there in the image?", "ground_truth": 10}
144 | {"image_path": "./images/superCLEVR_new_025143.png", "question": "How many different items are there in the image?", "ground_truth": 5}
145 | {"image_path": "./images/superCLEVR_new_025144.png", "question": "How many different items are there in the image?", "ground_truth": 6}
146 | {"image_path": "./images/superCLEVR_new_025145.png", "question": "How many different items are there in the image?", "ground_truth": 10}
147 | {"image_path": "./images/superCLEVR_new_025146.png", "question": "How many different items are there in the image?", "ground_truth": 5}
148 | {"image_path": "./images/superCLEVR_new_025147.png", "question": "How many different items are there in the image?", "ground_truth": 6}
149 | {"image_path": "./images/superCLEVR_new_025148.png", "question": "How many different items are there in the image?", "ground_truth": 8}
150 | {"image_path": "./images/superCLEVR_new_025149.png", "question": "How many different items are there in the image?", "ground_truth": 8}
151 | {"image_path": "./images/superCLEVR_new_025150.png", "question": "How many different items are there in the image?", "ground_truth": 9}
152 | {"image_path": "./images/superCLEVR_new_025151.png", "question": "How many different items are there in the image?", "ground_truth": 8}
153 | {"image_path": "./images/superCLEVR_new_025152.png", "question": "How many different items are there in the image?", "ground_truth": 10}
154 | {"image_path": "./images/superCLEVR_new_025153.png", "question": "How many different items are there in the image?", "ground_truth": 3}
155 | {"image_path": "./images/superCLEVR_new_025154.png", "question": "How many different items are there in the image?", "ground_truth": 5}
156 | {"image_path": "./images/superCLEVR_new_025155.png", "question": "How many different items are there in the image?", "ground_truth": 10}
157 | {"image_path": "./images/superCLEVR_new_025156.png", "question": "How many different items are there in the image?", "ground_truth": 3}
158 | {"image_path": "./images/superCLEVR_new_025157.png", "question": "How many different items are there in the image?", "ground_truth": 6}
159 | {"image_path": "./images/superCLEVR_new_025158.png", "question": "How many different items are there in the image?", "ground_truth": 4}
160 | {"image_path": "./images/superCLEVR_new_025159.png", "question": "How many different items are there in the image?", "ground_truth": 5}
161 | {"image_path": "./images/superCLEVR_new_025160.png", "question": "How many different items are there in the image?", "ground_truth": 9}
162 | {"image_path": "./images/superCLEVR_new_025161.png", "question": "How many different items are there in the image?", "ground_truth": 3}
163 | {"image_path": "./images/superCLEVR_new_025162.png", "question": "How many different items are there in the image?", "ground_truth": 5}
164 | {"image_path": "./images/superCLEVR_new_025163.png", "question": "How many different items are there in the image?", "ground_truth": 10}
165 | {"image_path": "./images/superCLEVR_new_025164.png", "question": "How many different items are there in the image?", "ground_truth": 9}
166 | {"image_path": "./images/superCLEVR_new_025165.png", "question": "How many different items are there in the image?", "ground_truth": 7}
167 | {"image_path": "./images/superCLEVR_new_025166.png", "question": "How many different items are there in the image?", "ground_truth": 8}
168 | {"image_path": "./images/superCLEVR_new_025167.png", "question": "How many different items are there in the image?", "ground_truth": 7}
169 | {"image_path": "./images/superCLEVR_new_025168.png", "question": "How many different items are there in the image?", "ground_truth": 3}
170 | {"image_path": "./images/superCLEVR_new_025169.png", "question": "How many different items are there in the image?", "ground_truth": 10}
171 | {"image_path": "./images/superCLEVR_new_025170.png", "question": "How many different items are there in the image?", "ground_truth": 8}
172 | {"image_path": "./images/superCLEVR_new_025171.png", "question": "How many different items are there in the image?", "ground_truth": 7}
173 | {"image_path": "./images/superCLEVR_new_025172.png", "question": "How many different items are there in the image?", "ground_truth": 4}
174 | {"image_path": "./images/superCLEVR_new_025173.png", "question": "How many different items are there in the image?", "ground_truth": 10}
175 | {"image_path": "./images/superCLEVR_new_025174.png", "question": "How many different items are there in the image?", "ground_truth": 9}
176 | {"image_path": "./images/superCLEVR_new_025175.png", "question": "How many different items are there in the image?", "ground_truth": 4}
177 | {"image_path": "./images/superCLEVR_new_025176.png", "question": "How many different items are there in the image?", "ground_truth": 9}
178 | {"image_path": "./images/superCLEVR_new_025177.png", "question": "How many different items are there in the image?", "ground_truth": 6}
179 | {"image_path": "./images/superCLEVR_new_025178.png", "question": "How many different items are there in the image?", "ground_truth": 10}
180 | {"image_path": "./images/superCLEVR_new_025179.png", "question": "How many different items are there in the image?", "ground_truth": 6}
181 | {"image_path": "./images/superCLEVR_new_025180.png", "question": "How many different items are there in the image?", "ground_truth": 3}
182 | {"image_path": "./images/superCLEVR_new_025181.png", "question": "How many different items are there in the image?", "ground_truth": 3}
183 | {"image_path": "./images/superCLEVR_new_025182.png", "question": "How many different items are there in the image?", "ground_truth": 8}
184 | {"image_path": "./images/superCLEVR_new_025183.png", "question": "How many different items are there in the image?", "ground_truth": 5}
185 | {"image_path": "./images/superCLEVR_new_025184.png", "question": "How many different items are there in the image?", "ground_truth": 5}
186 | {"image_path": "./images/superCLEVR_new_025185.png", "question": "How many different items are there in the image?", "ground_truth": 3}
187 | {"image_path": "./images/superCLEVR_new_025186.png", "question": "How many different items are there in the image?", "ground_truth": 4}
188 | {"image_path": "./images/superCLEVR_new_025187.png", "question": "How many different items are there in the image?", "ground_truth": 5}
189 | {"image_path": "./images/superCLEVR_new_025188.png", "question": "How many different items are there in the image?", "ground_truth": 5}
190 | {"image_path": "./images/superCLEVR_new_025189.png", "question": "How many different items are there in the image?", "ground_truth": 3}
191 | {"image_path": "./images/superCLEVR_new_025190.png", "question": "How many different items are there in the image?", "ground_truth": 5}
192 | {"image_path": "./images/superCLEVR_new_025191.png", "question": "How many different items are there in the image?", "ground_truth": 8}
193 | {"image_path": "./images/superCLEVR_new_025192.png", "question": "How many different items are there in the image?", "ground_truth": 3}
194 | {"image_path": "./images/superCLEVR_new_025193.png", "question": "How many different items are there in the image?", "ground_truth": 9}
195 | {"image_path": "./images/superCLEVR_new_025194.png", "question": "How many different items are there in the image?", "ground_truth": 10}
196 | {"image_path": "./images/superCLEVR_new_025195.png", "question": "How many different items are there in the image?", "ground_truth": 5}
197 | {"image_path": "./images/superCLEVR_new_025196.png", "question": "How many different items are there in the image?", "ground_truth": 6}
198 | {"image_path": "./images/superCLEVR_new_025197.png", "question": "How many different items are there in the image?", "ground_truth": 3}
199 | {"image_path": "./images/superCLEVR_new_025198.png", "question": "How many different items are there in the image?", "ground_truth": 4}
200 | {"image_path": "./images/superCLEVR_new_025199.png", "question": "How many different items are there in the image?", "ground_truth": 3}
201 |
--------------------------------------------------------------------------------
/src/eval/test_qwen2vl_counting_superclevr.py:
--------------------------------------------------------------------------------
1 | from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
2 | from qwen_vl_utils import process_vision_info
3 | import torch
4 | import json
5 | from tqdm import tqdm
6 | import re
7 |
8 |
9 |
10 | MODEL_PATH="checkpoints/Qwen2-VL-2B-Instruct" # Qwen2vl-2b-Instruct for original scores
11 | BSZ=16 # reduce it if GPU OOM
12 | OUTPUT_PATH="src/eval/logs/counting_results_superclevr_200_qwen2vl_2b_instruct.json"
13 | PROMPT_PATH="src/eval/prompts/superclevr_test200_counting_problems.jsonl"
14 |
15 | #We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
16 | model = Qwen2VLForConditionalGeneration.from_pretrained(
17 | MODEL_PATH,
18 | torch_dtype=torch.bfloat16,
19 | attn_implementation="flash_attention_2",
20 | device_map="auto",
21 | )
22 |
23 | # default processor
24 | processor = AutoProcessor.from_pretrained(MODEL_PATH)
25 |
26 | data = []
27 | with open(PROMPT_PATH, "r") as f:
28 | for line in f:
29 | data.append(json.loads(line))
30 |
31 |
32 | QUESTION_TEMPLATE = "{Question} First output the thinking process in <think> </think> and final answer (number) in <answer> </answer> tags."
33 |
34 | messages = []
35 |
36 | for i in data:
37 | message = [{
38 | "role": "user",
39 | "content": [
40 | {
41 | "type": "image",
42 | "image": f"file://{i['image_path']}"
43 | },
44 | {
45 | "type": "text",
46 | "text": QUESTION_TEMPLATE.format(Question=i['question'])
47 | }
48 | ]
49 | }]
50 | messages.append(message)
51 |
52 |
53 |
54 |
55 | all_outputs = [] # List to store all answers
56 |
57 | # Process data in batches
58 | for i in tqdm(range(0, len(messages), BSZ)):
59 | batch_messages = messages[i:i + BSZ]
60 |
61 | # Preparation for inference
62 | text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
63 |
64 | image_inputs, video_inputs = process_vision_info(batch_messages)
65 | inputs = processor(
66 | text=text,
67 | images=image_inputs,
68 | videos=video_inputs,
69 | padding=True,
70 | return_tensors="pt",
71 | )
72 | inputs = inputs.to("cuda")
73 |
74 | # Inference: Generation of the output
75 | generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=256, do_sample=False)
76 |
77 | generated_ids_trimmed = [
78 | out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
79 | ]
80 | batch_output_text = processor.batch_decode(
81 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
82 | )
83 |
84 | all_outputs.extend(batch_output_text)
85 | print(f"Processed batch {i//BSZ + 1}/{(len(messages) + BSZ - 1)//BSZ}")
86 |
87 |
88 | def extract_number_answer(output_str):
89 |     # Try to find the number within <answer> </answer> tags; if it cannot be found, return None
90 |     answer_pattern = r'<answer>\s*(\d+)\s*</answer>'
91 | match = re.search(answer_pattern, output_str)
92 |
93 | if match:
94 | return int(match.group(1))
95 | return None
96 |
97 |
98 | final_output = []
99 | correct_number = 0
100 |
101 | for input_example, model_output in zip(data,all_outputs):
102 | original_output = model_output
103 | ground_truth = input_example['ground_truth']
104 | model_answer = extract_number_answer(original_output)
105 |
106 | # Create a result dictionary for this example
107 | result = {
108 | 'question': input_example,
109 | 'ground_truth': ground_truth,
110 | 'model_output': original_output,
111 | 'extracted_answer': model_answer
112 | }
113 | final_output.append(result)
114 |
115 | # Count correct answers
116 | if model_answer is not None and model_answer == ground_truth:
117 | correct_number += 1
118 |
119 | # Calculate and print accuracy
120 | accuracy = correct_number / len(data) * 100
121 | print(f"\nAccuracy: {accuracy:.2f}%")
122 |
123 | # Save results to a JSON file
124 | output_path = OUTPUT_PATH
125 | with open(output_path, "w") as f:
126 | json.dump({
127 | 'accuracy': accuracy,
128 | 'results': final_output
129 | }, f, indent=2)
130 |
131 | print(f"Results saved to {output_path}")
132 |
133 |
134 |
135 |
136 |
137 |
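138 | # Illustrative note (assumed output shape, given the prompt and regex above): a
139 | # well-formed completion would look roughly like
140 | #   "<think> I can see a car, a bus and a bike ... </think> <answer> 3 </answer>"
141 | # for which extract_number_answer() returns 3; malformed outputs yield None and
142 | # are simply counted as incorrect.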
--------------------------------------------------------------------------------
/src/eval/test_qwen2vl_geoqa.py:
--------------------------------------------------------------------------------
1 | from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
2 | from qwen_vl_utils import process_vision_info
3 | import torch
4 | import json
5 | from tqdm import tqdm
6 | import re
7 | from math_verify import parse, verify
8 |
9 |
10 | MODEL_PATH="" # qwen2vl model or grpoed model on geoqa train
11 | BSZ=50 # reduce it if GPU OOM
12 | OUTPUT_PATH=""
13 | PROMPT_PATH="./prompts/geoqa_test_prompts.jsonl"
14 |
15 | #We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
16 | model = Qwen2VLForConditionalGeneration.from_pretrained(
17 | MODEL_PATH,
18 | torch_dtype=torch.bfloat16,
19 | attn_implementation="flash_attention_2",
20 | device_map="auto",
21 | )
22 |
23 | # default processor
24 | processor = AutoProcessor.from_pretrained(MODEL_PATH)
25 |
26 | data = []
27 | with open(PROMPT_PATH, "r") as f:
28 | for line in f:
29 | data.append(json.loads(line))
30 |
31 |
32 | QUESTION_TEMPLATE = "{Question} Output the thinking process in <think> </think> and final answer (number) in <answer> </answer> tags."
33 |
34 | messages = []
35 |
36 | data = data  # no-op; subsample here if needed, e.g. data = data[:100]
37 |
38 | for i in data:
39 | message = [{
40 | "role": "user",
41 | "content": [
42 | {
43 | "type": "image",
44 | "image": f"file://{i['image_path']}"
45 | },
46 | {
47 | "type": "text",
48 | "text": QUESTION_TEMPLATE.format(Question=i['question'])
49 | }
50 | ]
51 | }]
52 | messages.append(message)
53 |
54 |
55 |
56 |
57 | all_outputs = [] # List to store all answers
58 |
59 | # Process data in batches
60 | for i in tqdm(range(0, len(messages), BSZ)):
61 | batch_messages = messages[i:i + BSZ]
62 |
63 | # Preparation for inference
64 | text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
65 |
66 | image_inputs, video_inputs = process_vision_info(batch_messages)
67 | inputs = processor(
68 | text=text,
69 | images=image_inputs,
70 | videos=video_inputs,
71 | padding=True,
72 | return_tensors="pt",
73 | )
74 | inputs = inputs.to("cuda")
75 |
76 | # Inference: Generation of the output
77 | generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=1024, do_sample=False)
78 |
79 | generated_ids_trimmed = [
80 | out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
81 | ]
82 | batch_output_text = processor.batch_decode(
83 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
84 | )
85 |
86 | all_outputs.extend(batch_output_text)
87 | print(f"Processed batch {i//BSZ + 1}/{(len(messages) + BSZ - 1)//BSZ}")
88 |
89 |
90 |
91 |
92 |
93 | final_output = []
94 | correct_number = 0
95 |
96 | for input_example, model_output in zip(data,all_outputs):
97 | original_output = model_output
98 | ground_truth = input_example['ground_truth']
99 | model_answer = parse(original_output)
100 |
101 | # Count correct answers
102 | if model_answer is not None and float(verify(model_answer,parse(ground_truth)))>0:
103 | correct_number += 1
104 | is_correct = True
105 | else:
106 | is_correct = False
107 |
108 | try:
109 | result = {
110 | 'question': input_example,
111 | 'ground_truth': ground_truth,
112 | 'model_output': original_output,
113 | 'extracted_answer':str(model_answer[0]) if model_answer is not None else None,
114 | 'is_correct':is_correct
115 | }
116 |
117 | except Exception as e:
118 | print("no answer parsed",e,model_answer)
119 | result = {
120 | 'question': input_example,
121 | 'ground_truth': ground_truth,
122 | 'model_output': original_output,
123 | 'extracted_answer':None,
124 | 'is_correct':is_correct
125 | }
126 |
127 |
128 |
129 | final_output.append(result)
130 |
131 |
132 | # Calculate and print accuracy
133 | accuracy = correct_number / len(data) * 100
134 | print(f"\nAccuracy: {accuracy:.2f}%")
135 |
136 | # Save results to a JSON file
137 | output_path = OUTPUT_PATH
138 | with open(output_path, "w") as f:
139 | json.dump({
140 | 'accuracy': accuracy,
141 | 'results': final_output
142 | }, f, indent=2, ensure_ascii=False)
143 |
144 | print(f"Results saved to {output_path}")
145 |
146 |
147 |
148 |
149 |
150 |
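151 | # Illustrative note (assumed math_verify behavior): parse() extracts candidate
152 | # expressions from a string such as "<answer> 42 </answer>" or "42", and
153 | # verify(a, b) returns a truthy value when the two parsed expressions are
154 | # mathematically equivalent, which is what the scoring loop above relies on.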
--------------------------------------------------------------------------------
/src/eval/test_qwen2vl_geoqa_multigpu.py:
--------------------------------------------------------------------------------
1 | from transformers import Qwen2VLForConditionalGeneration, AutoTokenizer, AutoProcessor
2 | from qwen_vl_utils import process_vision_info
3 | import torch
4 | import json
5 | import tqdm
6 | from math_verify import parse, verify
7 | import argparse
8 | import pandas as pd
9 | from torch.multiprocessing import Process, set_start_method, Manager
10 | from transformers.utils.logging import disable_progress_bar
11 | disable_progress_bar()
12 |
13 | # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
14 | # >>>>> 1. get evaluation configuration <<<<<
15 | # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
16 | def get_eval_config():
17 | parser = argparse.ArgumentParser(description="Inference script for GeoQA evaluation.")
18 | parser.add_argument("--model_path", required=True, type=str, help="Path to the model checkpoint (e.g., qwen2vl model or a fine-tuned model).")
19 | parser.add_argument("--batch_size", default=4, type=int, help="Batch size for inference. Reduce if GPU OOM (default: 50).")
20 | parser.add_argument("--output_path", required=True, type=str, help="Path to save inference result (e.g., JSON file).")
21 | parser.add_argument("--prompt_path", required=True, type=str, help="Path to the prompts JSONL file for GeoQA evaluation.")
22 | all_gpu = ",".join(map(str, range(torch.cuda.device_count())))
23 | parser.add_argument("--gpu_ids", default=all_gpu, help="comma-separated list of GPU IDs to use")
24 | args = parser.parse_args()
25 | return args
26 |
27 | # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
28 | # >>>>>>>>>> 2. load testset <<<<<<<<<<<<<
29 | # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
30 | def prepare_test_messages(testset_path):
31 | testset_data = pd.read_json(testset_path, lines=True).to_dict(orient="records")
32 |     QUESTION_TEMPLATE = "{Question} Output the thinking process in <think> </think> and final answer (number) in <answer> </answer> tags."
33 | tested_messages = []
34 | for i in testset_data:
35 | message = [{
36 | "role": "user",
37 | "content": [
38 | {
39 | "type": "image",
40 | "image": f"file://{i['image_path']}"
41 | },
42 | {
43 | "type": "text",
44 | "text": QUESTION_TEMPLATE.format(Question=i['question'])
45 | }
46 | ]
47 | }]
48 | tested_messages.append(message)
49 | return testset_data, tested_messages
50 |
51 |
52 |
53 |
54 | # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
55 | # >>>>> 3. use several GPUs to accelerate inference on the testset <<<<<
56 | # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
57 |
58 | def init_model(model_path, gpu_id):
59 | """init a model(args.model_path) on a specific gpu"""
60 | # We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
61 | model = Qwen2VLForConditionalGeneration.from_pretrained(
62 | model_path,
63 | torch_dtype=torch.bfloat16,
64 | attn_implementation="flash_attention_2",
65 | device_map=f"cuda:{gpu_id}",
66 | )
67 |
68 |     # default processor
69 | processor = AutoProcessor.from_pretrained(model_path, use_fast=True)
70 | return model, processor
71 |
72 | def answer_a_batch_question_qwen(batch_messages, model, processor):
73 | """ let qwen answer a batch of questions """
74 | text = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batch_messages]
75 | image_inputs, video_inputs = process_vision_info(batch_messages)
76 | inputs = processor(
77 | text=text,
78 | images=image_inputs,
79 | videos=video_inputs,
80 | padding=True,
81 | return_tensors="pt",
82 | )
83 | inputs = inputs.to(model.device)
84 |
85 | generated_ids = model.generate(**inputs, use_cache=True, max_new_tokens=1024) # do_sample=False
86 | generated_ids_trimmed = [
87 | out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
88 | ]
89 | batch_output_text = processor.batch_decode(
90 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
91 | )
92 | return batch_output_text
93 |
94 | def infer_on_single_gpu(model_path, device_id, chunk_of_tested_messages, batch_size, results=None):
95 | """init model on this single gpu and let it answer asign chunk of questions"""
96 | model, processor = init_model(model_path, device_id)
97 |
98 | ### split batch
99 | responses = []
100 | batch_messages_list = [chunk_of_tested_messages[start: start + batch_size]
101 | for start in range(0, len(chunk_of_tested_messages), batch_size)]
102 |
103 | for batch_messages in tqdm.auto.tqdm(batch_messages_list, desc=f"GPU {device_id} progress", position=device_id, leave=False):
104 | batch_output_text = answer_a_batch_question_qwen(batch_messages, model, processor)
105 |
106 | responses.extend(batch_output_text)
107 |
108 | results[device_id] = responses
109 | return
110 |
111 |
112 | def multi_gpu_inference(prompts, gpu_ids, model_path, batch_size):
113 | """ let each gpu (along with a model) answer a chunk of questions """
114 | set_start_method("spawn", force=True)
115 | manager = Manager()
116 | gpu_id2result = manager.dict()
117 |
118 | gpu_ids = [int(gpu_id.strip()) for gpu_id in gpu_ids.split(',')]
119 | num_gpus = len(gpu_ids)
120 |
121 | chunk_size = len(prompts) // num_gpus
122 | processes = []
123 | for i, gpu_id in enumerate(gpu_ids):
124 | start_idx = i * chunk_size
125 | end_idx = (i + 1) * chunk_size if i != num_gpus - 1 else len(prompts)
126 | chunk = prompts[start_idx: end_idx]
127 | process = Process(target=infer_on_single_gpu, args=(model_path, gpu_id, chunk, batch_size, gpu_id2result))
128 | process.start()
129 | processes.append(process)
130 |
131 | # for process in tqdm.auto.tqdm(processes, desc="Inference progress", position=num_gpus, leave=True):
132 | for process in processes:
133 | process.join()
134 |
135 | all_predicts = []
136 | for gpu_id in gpu_ids:
137 | all_predicts.extend(gpu_id2result[gpu_id])
138 |
139 | return all_predicts
140 |
141 | # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
142 | # >>>>>>>>>> 4. compute metrics <<<<<<<<<<<
143 | # >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
144 |
145 | def compute_metrics(testset_data, all_predicts):
146 | final_output = []
147 | correct_number = 0
148 |
149 | for input_example, model_output in zip(testset_data, all_predicts):
150 | original_output = model_output
151 | ground_truth = input_example['ground_truth']
152 | model_answer = parse(original_output)
153 |
154 | # Count correct answers
155 | if model_answer is not None and float(verify(model_answer,parse(ground_truth)))>0:
156 | correct_number += 1
157 | is_correct = True
158 | else:
159 | is_correct = False
160 |
161 | try:
162 | result = {
163 | 'question': input_example,
164 | 'ground_truth': ground_truth,
165 | 'model_output': original_output,
166 | 'extracted_answer':str(model_answer[0]) if model_answer is not None else None,
167 | 'is_correct':is_correct
168 | }
169 |
170 | except Exception as e:
171 | print("no answer parsed",e,model_answer)
172 | result = {
173 | 'question': input_example,
174 | 'ground_truth': ground_truth,
175 | 'model_output': original_output,
176 | 'extracted_answer':None,
177 | 'is_correct':is_correct
178 | }
179 |
180 |
181 |
182 | final_output.append(result)
183 |
184 |
185 | # Calculate and print accuracy
186 |     accuracy = correct_number / len(testset_data) * 100
187 | print(f"\nAccuracy: {accuracy:.2f}%")
188 |
189 | # Save results to a JSON file
190 | with open(args.output_path, "w") as f:
191 | json.dump({
192 | 'accuracy': accuracy,
193 | 'results': final_output
194 | }, f, indent=2, ensure_ascii=False)
195 |
196 | print(f"Results saved to {args.output_path}")
197 |
198 |
199 |
200 | if __name__ == "__main__":
201 | args = get_eval_config()
202 | testset_data, tested_messages = prepare_test_messages(testset_path=args.prompt_path)
203 | all_predicts = multi_gpu_inference(tested_messages, args.gpu_ids, args.model_path, args.batch_size)
204 | compute_metrics(testset_data, all_predicts)
205 |
206 |
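207 | # Example invocation (hypothetical paths; the flags match get_eval_config above):
208 | #   python src/eval/test_qwen2vl_geoqa_multigpu.py \
209 | #       --model_path checkpoints/Qwen2-VL-7B-Instruct \
210 | #       --prompt_path src/eval/prompts/geoqa_test_prompts.jsonl \
211 | #       --output_path src/eval/logs/geoqa_test_qwen2vl_7b.json \
212 | #       --batch_size 4 --gpu_ids 0,1,2,3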
--------------------------------------------------------------------------------
/src/r1-v/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
170 | # PyPI configuration file
171 | .pypirc
172 |
173 | # Temp folders
174 | data/
175 | wandb/
176 | scripts/
177 | checkpoints/
178 | .vscode/
--------------------------------------------------------------------------------
/src/r1-v/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/src/r1-v/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: style quality
2 |
3 | # make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!)
4 | export PYTHONPATH = src
5 |
6 | check_dirs := src
7 |
8 | style:
9 | black --line-length 119 --target-version py310 $(check_dirs) setup.py
10 | isort $(check_dirs) setup.py
11 |
12 | quality:
13 | black --check --line-length 119 --target-version py310 $(check_dirs) setup.py
14 | isort --check-only $(check_dirs) setup.py
15 | flake8 --max-line-length 119 $(check_dirs) setup.py
16 |
17 |
18 | # Evaluation
19 |
20 | evaluate:
21 |
--------------------------------------------------------------------------------
/src/r1-v/configs/ddp.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | distributed_type: MULTI_GPU
4 | downcast_bf16: 'no'
5 | gpu_ids: all
6 | machine_rank: 0
7 | main_training_function: main
8 | mixed_precision: bf16
9 | num_machines: 1
10 | num_processes: 8
11 | rdzv_backend: static
12 | same_network: true
13 | tpu_env: []
14 | tpu_use_cluster: false
15 | tpu_use_sudo: false
16 | use_cpu: false
17 |
--------------------------------------------------------------------------------
/src/r1-v/configs/qwen2vl_sft_config.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments
2 | model_name_or_path: /path/to/your/Qwen2-VL-2B-Instruct
3 | model_revision: main
4 | torch_dtype: bfloat16
5 |
6 | # Data training arguments
7 | dataset_name: /path/to/your/train/data
8 | dataset_configs:
9 | - all
10 | preprocessing_num_workers: 8
11 |
12 | # SFT trainer config
13 | bf16: true
14 | do_eval: true
15 | eval_strategy: "no"
16 | gradient_accumulation_steps: 4
17 | gradient_checkpointing: true
18 | gradient_checkpointing_kwargs:
19 | use_reentrant: false
20 | hub_model_id: Qwen2-VL-2B-Instruct-SFT
21 | hub_strategy: every_save
22 | learning_rate: 2.0e-05
23 | log_level: info
24 | logging_steps: 5
25 | logging_strategy: steps
26 | lr_scheduler_type: cosine
27 | packing: true
28 | max_seq_length: 4096
29 | max_steps: -1
30 | num_train_epochs: 1
31 | output_dir: /path/to/your/out_dir
32 | overwrite_output_dir: true
33 | per_device_eval_batch_size: 4
34 | per_device_train_batch_size: 4
35 | push_to_hub: false
36 | report_to:
37 | - tensorboard
38 | save_strategy: "steps"
39 | save_steps: 1000
40 | seed: 42
41 | warmup_ratio: 0.1
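42 |
43 | # Example launch (a sketch only; the provided entry points are train_tools/run_train_sft.sh and src/scripts/run_sft_clevr.sh):
44 | #   accelerate launch --config_file src/r1-v/configs/zero2.yaml \
45 | #       src/r1-v/src/open_r1/sft.py --config src/r1-v/configs/qwen2vl_sft_config.yaml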
--------------------------------------------------------------------------------
/src/r1-v/configs/zero2.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | deepspeed_multinode_launcher: standard
5 | offload_optimizer_device: none
6 | offload_param_device: none
7 | zero3_init_flag: false
8 | zero_stage: 2
9 | distributed_type: DEEPSPEED
10 | downcast_bf16: 'no'
11 | machine_rank: 0
12 | main_training_function: main
13 | mixed_precision: bf16
14 | num_machines: 1
15 | num_processes: 8
16 | rdzv_backend: static
17 | same_network: true
18 | tpu_env: []
19 | tpu_use_cluster: false
20 | tpu_use_sudo: false
21 | use_cpu: false
--------------------------------------------------------------------------------
/src/r1-v/configs/zero3.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | deepspeed_multinode_launcher: standard
5 | offload_optimizer_device: none
6 | offload_param_device: none
7 | zero3_init_flag: true
8 | zero3_save_16bit_model: true
9 | zero_stage: 3
10 | distributed_type: DEEPSPEED
11 | downcast_bf16: 'no'
12 | machine_rank: 0
13 | main_training_function: main
14 | mixed_precision: bf16
15 | num_machines: 1
16 | num_processes: 8
17 | rdzv_backend: static
18 | same_network: true
19 | tpu_env: []
20 | tpu_use_cluster: false
21 | tpu_use_sudo: false
22 | use_cpu: false
23 |
--------------------------------------------------------------------------------
/src/r1-v/local_scripts/create_vision_cot_data.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import base64
3 | import concurrent.futures
4 | import io
5 | import json
6 | import os
7 | import random
8 | import re
9 | import time
10 | from concurrent.futures import ThreadPoolExecutor
11 | from functools import partial
12 | from io import BytesIO
13 | from typing import Dict, List
14 |
15 | import matplotlib.pyplot as plt
16 | import numpy as np
17 | import pandas as pd
18 | from datasets import Dataset, concatenate_datasets, load_dataset, load_from_disk
19 | from tqdm import tqdm
20 |
21 | import bytedtos
22 | import seaborn as sns
23 | import yaml
24 | from openai import AzureOpenAI
25 | from PIL import Image
26 | from pillow_avif import AvifImagePlugin
27 |
28 |
29 | PROMPT_FORMAT = """I will provide you with an image, an original question, and its answer related to the image. Your task is to rewrite the question in such a way that answering it requires step-by-step Chain-of-Thought (CoT) reasoning with numerical or mathematical expressions where applicable. The reasoning process can include expressions like "let me think," "oh, I see," or other natural language thought expressions.
30 |
31 | Please make sure your question is to ask for a certain answer with a certain value, do not ask for open-ended answer, and the answer is correct and easy to verify via simple protocol, like "2" or "A".
32 |
33 | Please strictly do not include "Answer:" in the question part to avoid confusion and leakage.
34 |
35 | Input Format:
36 | Original Question: {original_question}
37 | Original Answer: {original_answer}
38 |
39 | Output Format:
40 | Question: [rewrite the question if necessary]
41 | Answer: [answer with reasoning steps, including calculations where applicable]
42 | <think> step-by-step reasoning process </think>
43 | <answer> easy to verify answer </answer>
44 | """
45 |
46 |
47 | def get_image_data_url(image_input):
48 | if isinstance(image_input, str) and image_input.startswith("data:"):
49 | return image_input
50 |
51 | if isinstance(image_input, str) and image_input.startswith("http"):
52 | image_input = load_image(image_input)
53 |
54 | if isinstance(image_input, str):
55 | image_input = Image.open(image_input)
56 |
57 | if not isinstance(image_input, Image.Image):
58 | raise ValueError("Unsupported image input type")
59 |
60 | if image_input.mode != "RGB":
61 | image_input = image_input.convert("RGB")
62 |
63 | buffer = BytesIO()
64 | image_input.save(buffer, format="JPEG")
65 | img_bytes = buffer.getvalue()
66 | base64_data = base64.b64encode(img_bytes).decode("utf-8")
67 | return f"data:image/jpeg;base64,{base64_data}"
68 |
69 |
70 | def gpt4o_query(image, prompt, max_retries=5, initial_delay=3):
71 | if image is None:
72 | return None
73 |
74 | data_url_list = [get_image_data_url(image)]
75 | client = AzureOpenAI(
76 | azure_endpoint="YOUR_AZURE_ENDPOINT",
77 | api_version="2023-07-01-preview",
78 | api_key="YOUR_API_KEY",
79 | )
80 |
81 | for attempt in range(max_retries):
82 | try:
83 | messages = [
84 | {
85 | "role": "system",
86 | "content": "You are an expert to analyze the image and provide useful information for users.",
87 | },
88 | {
89 | "role": "user",
90 | "content": [
91 | {"type": "text", "text": prompt},
92 | ],
93 | },
94 | ]
95 |
96 | for data_url in data_url_list:
97 | messages[1]["content"].insert(
98 | 0, {"type": "image_url", "image_url": {"url": data_url}}
99 | )
100 |
101 | response = client.chat.completions.create(
102 | model="gpt-4o-2024-08-06",
103 | messages=messages,
104 | temperature=0.2,
105 | max_tokens=8192,
106 | )
107 | return response.choices[0].message.content
108 |
109 | except Exception as e:
110 | if attempt == max_retries - 1:
111 | raise Exception(
112 | f"Failed after {max_retries} attempts. Last error: {str(e)}"
113 | )
114 | delay = initial_delay * (2**attempt) + random.uniform(
115 | 0, 0.1 * initial_delay * (2**attempt)
116 | )
117 | time.sleep(delay)
118 |
119 |
120 | def process_single_item(example):
121 | try:
122 | image_path = example["image_path"]
123 | formatted_prompt = PROMPT_FORMAT.format(
124 | original_question=example["question"], original_answer=example["answer"]
125 | )
126 |
127 | response = gpt4o_query(image_path, formatted_prompt)
128 | example["gpt4o_response"] = response
129 | return example
130 | except Exception as e:
131 | print(f"Error processing item: {str(e)}")
132 | example["gpt4o_response"] = None
133 | return example
134 |
135 |
136 | def main():
137 | dataset_path = "path/to/your/dataset"
138 | full_dataset = load_from_disk(dataset_path)
139 |
140 | processed_dataset = full_dataset.map(
141 | function=partial(process_single_item),
142 | num_proc=256,
143 | desc="Processing dataset with GPT-4o",
144 | keep_in_memory=True,
145 | )
146 |
147 | output_path = f"{dataset_path}_processed"
148 | processed_dataset.save_to_disk(output_path)
149 | print(f"Processed dataset saved to: {output_path}")
150 |
151 |
152 | if __name__ == "__main__":
153 | main()
154 |
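155 | # Illustrative note (assuming the output format requested in PROMPT_FORMAT above):
156 | # a usable gpt4o_response looks roughly like
157 | #   Question: How many cubes are left after removing one?
158 | #   Answer: <think> there are 3 cubes, 3 - 1 = 2 </think> <answer> 2 </answer>
159 | # which downstream scripts (e.g. prepare_hf_data.py) split on the <think>/<answer> tags.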
--------------------------------------------------------------------------------
/src/r1-v/local_scripts/lmms_eval_qwen2vl.sh:
--------------------------------------------------------------------------------
1 | export HF_HOME=""
2 | export HF_TOKEN=""
3 | export HF_HUB_ENABLE_HF_TRANSFER="1"
4 |
5 | export API_TYPE=""
6 | export AZURE_ENDPOINT=""
7 | export AZURE_API_KEY=""
8 | export API_VERSION=""
9 | export MODEL_VERSION=""
10 | export NAVIT_ATTENTION_IMPLEMENTATION="eager"
11 |
12 | # Prompt for installation with 3-second timeout
13 | read -t 3 -p "Do you want to install dependencies? (YES/no, timeout in 3s): " install_deps || true
14 | if [ "$install_deps" = "YES" ]; then
15 | # Prepare the environment
16 | pip3 install --upgrade pip
17 | pip3 install -U setuptools
18 |
19 | cd
20 | if [ ! -d "maas_engine" ]; then
21 | git clone
22 | else
23 | echo "maas_engine directory already exists, skipping clone"
24 | fi
25 | cd maas_engine
26 | git pull
27 | git checkout
28 | pip3 install --no-cache-dir --no-build-isolation -e ".[standalone]"
29 |
30 | current_version=$(pip3 show transformers | grep Version | cut -d' ' -f2)
31 | if [ "$current_version" != "4.46.2" ]; then
32 | echo "Installing transformers 4.46.2 (current version: $current_version)"
33 | pip3 install transformers==4.46.2
34 | else
35 | echo "transformers 4.46.2 is already installed"
36 | fi
37 |
38 | cd
39 | rm -rf
40 | pip3 install -e .
41 | pip3 install -U pydantic
42 | pip3 install Levenshtein
43 | pip3 install nltk
44 | python3 -c "import nltk; nltk.download('wordnet', quiet=True); nltk.download('punkt', quiet=True)"
45 | fi
46 |
47 | TASKS=mmmu_val,mathvista_testmini,mmmu_pro
48 | MODEL_BASENAME=qwen2_vl
49 |
50 | model_checkpoint=""
51 | echo "MODEL_BASENAME: ${MODEL_BASENAME}"
52 | cd
53 |
54 | python3 -m accelerate.commands.launch --num_processes=8 --main_process_port=12345 lmms_eval \
55 | --model qwen2_vl \
56 | --model_args=pretrained=${model_checkpoint},max_pixels=2359296 \
57 | --tasks ${TASKS} \
58 | --batch_size 1 \
59 | --log_samples \
60 | --log_samples_suffix ${MODEL_BASENAME} \
61 | --output_path ./logs
--------------------------------------------------------------------------------
/src/r1-v/local_scripts/prepare_hf_data.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import seaborn as sns
3 | import pandas as pd
4 | import random
5 | from typing import List, Dict
6 | import numpy as np
7 | from concurrent.futures import ThreadPoolExecutor
8 | from tqdm import tqdm
9 | import datasets
10 |
11 | import io
12 | from datasets import load_dataset, load_from_disk, concatenate_datasets
13 | from PIL import Image
14 | from tqdm import tqdm
15 | from functools import partial
16 | from pillow_avif import AvifImagePlugin
17 | from datasets import Dataset
18 | import json
19 | import yaml
20 | import os
21 | import re
22 | import time
23 | import random
24 | import base64
25 | from openai import AzureOpenAI
26 | import concurrent.futures
27 | from typing import List, Dict
28 | import argparse
29 | import time
30 |
31 |
32 | def extract_problem_solution(gpt4o_response):
33 | # Split the response into parts
34 | parts = gpt4o_response.split("")
35 |
36 | # Extract the problem (first part before any tags)
37 | problem = parts[0].strip()
38 | # Remove "Question:" prefix if it exists
39 | problem = re.sub(r"^Question:\s*", "", problem)
40 | # Remove "Answer:" at the end of the problem
41 | problem = re.sub(r"\s*Answer:\s*$", "", problem).strip()
42 |
43 | # Combine all the reasoning steps into a single block
44 | think_parts = [p.split("")[0].strip() for p in parts[1:] if "" in p]
45 | solution = f"{' '.join(think_parts)}"
46 |
47 | # Add the final answer if it exists, removing "Answer:" prefix
48 | if "" in gpt4o_response:
49 | final_answer = (
50 | gpt4o_response.split("")[-1].split("")[0].strip()
51 | )
52 | final_answer = re.sub(r"^Answer:\s*", "", final_answer)
53 | solution += f"\n\n{final_answer}"
54 |
55 | return problem, solution
56 |
57 |
58 | def load_image_from_path(image_path):
59 | try:
60 | img = Image.open(image_path)
61 | return img
62 | except Exception as e:
63 | print(f"Error loading image {image_path}: {str(e)}")
64 | return None
65 |
66 |
67 | def process_raw_data(raw_data):
68 | # Parse the raw data if it's a string
69 | if isinstance(raw_data, str):
70 | data = json.loads(raw_data)
71 | else:
72 | data = raw_data
73 |
74 | # Extract problem and solution
75 | try:
76 | problem, solution = extract_problem_solution(data["gpt4o_response"])
77 | image = load_image_from_path(data["image_path"])
78 |
79 | return {
80 | "image": image,
81 | "problem": problem,
82 | "solution": solution,
83 | "original_question": data["question"],
84 | "original_answer": data["answer"],
85 | }
86 | except Exception as e:
87 | print(f"Error processing data {data}: {str(e)}")
88 | return {
89 | "image": None,
90 | "problem": None,
91 | "solution": None,
92 | "original_question": None,
93 | "original_answer": None,
94 | }
95 |
96 |
97 | raw_data_list = [
98 | "/path/to/reasoning_data_with_response_90k_verified",
99 | ]
100 |
101 | raw_data = concatenate_datasets([load_from_disk(path) for path in raw_data_list])
102 |
103 | processed_data = raw_data.map(process_raw_data, num_proc=128).shuffle(seed=42)
104 |
105 | hf_dict = {
106 | "image": [],
107 | "problem": [],
108 | "solution": [],
109 | "original_question": [],
110 | "original_answer": [],
111 | }
112 |
113 | for item in tqdm(processed_data):
114 | hf_dict["image"].append(item["image"])
115 | hf_dict["problem"].append(item["problem"])
116 | hf_dict["solution"].append(item["solution"])
117 | hf_dict["original_question"].append(item["original_question"])
118 | hf_dict["original_answer"].append(item["original_answer"])
119 |
120 |
121 | features = datasets.Features(
122 | {
123 | "image": datasets.Image(),
124 | "problem": datasets.Value("string"),
125 | "solution": datasets.Value("string"),
126 | "original_question": datasets.Value("string"),
127 | "original_answer": datasets.Value("string"),
128 | }
129 | )
130 |
131 |
132 | def has_empty_tags(text):
133 |     # Pattern to match empty tag pairs like <think></think> or <answer></answer>
134 |     pattern = r"<[^>]+></[^>]+>"
135 | return bool(re.search(pattern, text))
136 |
137 |
138 | def has_answer_pattern(text):
139 | if "Answer:" in text:
140 | return True
141 | return False
142 |
143 |
144 | def has_valid_image_size(example): # for Qwen2-VL-2B's processor requirement
145 | # Assuming the image is in a format that can be checked for dimensions
146 | # You might need to adjust this depending on how the image is stored in your dataset
147 | try:
148 | image = example["image"] # or however your image is accessed
149 | if isinstance(image, dict) and "height" in image and "width" in image:
150 | return image["height"] >= 28 and image["width"] >= 28
151 | # If image is a PIL Image or similar
152 | return image.height >= 28 and image.width >= 28
153 | except:
154 | return False
155 |
156 |
157 | ds = datasets.Dataset.from_dict(hf_dict, features=features)
158 | ds = ds.filter(
159 | lambda x: not has_empty_tags(x["solution"])
160 | and not has_answer_pattern(x["problem"])
161 | and has_valid_image_size(x)
162 | and x["image"] is not None,
163 | num_proc=128,
164 | )
165 | # Push to Hugging Face Hub
166 | ds.push_to_hub("path/to/your/dataset")
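167 | # Illustrative example of extract_problem_solution (assuming the <think>/<answer> format above):
168 | #   extract_problem_solution("Question: How many cubes? Answer: <think> 1 + 2 = 3 </think><answer> 3 </answer>")
169 | #   -> ("How many cubes?", "<think> 1 + 2 = 3 </think>\n\n<answer> 3 </answer>")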
167 |
--------------------------------------------------------------------------------
/src/r1-v/local_scripts/train_aria_moe.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export NCCL_BLOCKING_WAIT=0
4 | export TOKENIZERS_PARALLELISM=false
5 | export OMP_NUM_THREADS=8
6 | export NCCL_IB_DISABLE=0
7 | export NCCL_IB_GID_INDEX=3
8 | export NCCL_SOCKET_IFNAME=eth0
9 | export NCCL_DEBUG=INFO
10 |
11 | # CONFIG Huggingface
12 | # export HF_TOKEN=""
13 | export HF_TOKEN=""
14 | export HF_HOME="$HOME/.cache/huggingface"
15 | export HF_HUB_ENABLE_HF_TRANSFER="1"
16 |
17 | export NCCL_DEBUG=INFO
18 |
19 | GPUS="0,1,2,3,4,5,6,7"
20 |
21 | # take the first port of worker0
22 | ports=($(echo $METIS_WORKER_0_PORT | tr ',' ' '))
23 | port=${ports[0]}
24 | port_in_cmd="$(echo "${METIS_WORKER_0_PORT:-2000}" | awk -F',' '{print $1}')"
25 |
26 | echo "total workers: ${ARNOLD_WORKER_NUM}"
27 | echo "cur worker id: ${ARNOLD_ID}"
28 | echo "gpus per worker: ${ARNOLD_WORKER_GPU}"
29 | echo "master ip: ${METIS_WORKER_0_HOST}"
30 | echo "master port: ${port}"
31 | echo "master port in cmd: ${port_in_cmd}"
32 |
33 | # export WANDB_BASE_URL=https://api.wandb.ai
34 | # export WANDB_API_KEY=""
35 | # wandb login $WANDB_API_KEY
36 |
37 | export WANDB_BASE_URL=https://api.wandb.ai
38 | export WANDB_PROJECT=vision-reasoning
39 | export WANDB_API_KEY=""
40 | export WANDB_RUN_NAME=Qwen-VL-2B-GRPO-$(date +%Y-%m-%d-%H-%M-%S)
41 | wandb login $WANDB_API_KEY
42 |
43 | cd /home/tiger/multimodal-open-r1
44 | # pip3 install vllm==0.6.6.post1
45 | pip3 install -e ".[dev]"
46 | pip3 install wandb==0.18.3
47 |
48 | torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" \
49 | --nnodes="${ARNOLD_WORKER_NUM}" \
50 | --node_rank="${ARNOLD_ID}" \
51 | --master_addr="${METIS_WORKER_0_HOST}" \
52 | --master_port="${port_in_cmd}" \
53 | src/open_r1/grpo.py \
54 | --deepspeed scripts/zero3.json \
55 | --output_dir Aria-GRPO-mini_cot_80k \
56 | --model_name_or_path rhymes-ai/Aria \
57 | --dataset_name luodian/mini_cot_80k \
58 | --max_prompt_length 8192 \
59 | --per_device_train_batch_size 1 \
60 | --gradient_accumulation_steps 1 \
61 | --logging_steps 1 \
62 | --bf16 \
63 | --report_to wandb \
64 | --gradient_checkpointing true \
65 | --attn_implementation eager \
66 | --save_total_limit 8 \
67 | --num_train_epochs 1 \
68 | --run_name $WANDB_RUN_NAME
69 |
--------------------------------------------------------------------------------
/src/r1-v/local_scripts/train_qwen2_vl.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export NCCL_BLOCKING_WAIT=0
4 | export TOKENIZERS_PARALLELISM=false
5 | export OMP_NUM_THREADS=8
6 | export NCCL_IB_DISABLE=0
7 | export NCCL_IB_GID_INDEX=3
8 | export NCCL_SOCKET_IFNAME=eth0
9 | export NCCL_DEBUG=INFO
10 |
11 | GPUS="0,1,2,3,4,5,6,7"
12 |
13 | # take the first port of worker0
14 | ports=($(echo $METIS_WORKER_0_PORT | tr ',' ' '))
15 | port=${ports[0]}
16 | port_in_cmd="$(echo "${METIS_WORKER_0_PORT:-2000}" | awk -F',' '{print $1}')"
17 |
18 | echo "total workers: ${ARNOLD_WORKER_NUM}"
19 | echo "cur worker id: ${ARNOLD_ID}"
20 | echo "gpus per worker: ${ARNOLD_WORKER_GPU}"
21 | echo "master ip: ${METIS_WORKER_0_HOST}"
22 | echo "master port: ${port}"
23 | echo "master port in cmd: ${port_in_cmd}"
24 |
25 | # export WANDB_BASE_URL=https://api.wandb.ai
26 | # export WANDB_API_KEY=""
27 | # wandb login $WANDB_API_KEY
28 |
29 | export WANDB_BASE_URL=https://api.wandb.ai
30 | export WANDB_PROJECT=vision-reasoning
31 | export WANDB_API_KEY=""
32 | export WANDB_RUN_NAME=Qwen-VL-2B-GRPO-$(date +%Y-%m-%d-%H-%M-%S)
33 | wandb login $WANDB_API_KEY
34 |
35 | cd /home/tiger/multimodal-open-r1
36 | # pip3 install vllm==0.6.6.post1
37 | pip3 install -e ".[dev]"
38 | pip3 install wandb==0.18.3
39 |
40 | torchrun --nproc_per_node="${ARNOLD_WORKER_GPU}" \
41 | --nnodes="${ARNOLD_WORKER_NUM}" \
42 | --node_rank="${ARNOLD_ID}" \
43 | --master_addr="${METIS_WORKER_0_HOST}" \
44 | --master_port="${port_in_cmd}" \
45 | src/open_r1/grpo.py \
46 | --deepspeed scripts/zero3.json \
47 | --output_dir checkpoints/${WANDB_RUN_NAME} \
48 | --model_name_or_path Qwen/Qwen2-VL-2B-Instruct \
49 | --dataset_name luodian/${DATASET_NAME} \
50 | --max_prompt_length 8192 \
51 | --per_device_train_batch_size 1 \
52 | --gradient_accumulation_steps 1 \
53 | --logging_steps 1 \
54 | --bf16 \
55 | --report_to wandb \
56 | --gradient_checkpointing true \
57 | --attn_implementation flash_attention_2 \
58 | --max_pixels 2359296 \
59 | --save_total_limit 8 \
60 | --num_train_epochs 1 \
61 | --run_name $WANDB_RUN_NAME
62 |
--------------------------------------------------------------------------------
/src/r1-v/local_scripts/zero2.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "optimizer": {
14 | "type": "AdamW",
15 | "params": {
16 | "lr": "auto",
17 | "betas": "auto",
18 | "eps": "auto",
19 | "weight_decay": "auto"
20 | }
21 | },
22 | "zero_optimization": {
23 | "stage": 2,
24 | "offload_optimizer": {
25 | "device": "none",
26 | "pin_memory": true
27 | },
28 | "allgather_partitions": true,
29 | "allgather_bucket_size": 2e8,
30 | "overlap_comm": false,
31 | "reduce_scatter": true,
32 | "reduce_bucket_size": 2e8,
33 | "contiguous_gradients": true
34 | },
35 | "gradient_accumulation_steps": "auto",
36 | "gradient_clipping": "auto",
37 | "steps_per_print": 100,
38 | "train_batch_size": "auto",
39 | "train_micro_batch_size_per_gpu": "auto",
40 | "wall_clock_breakdown": false
41 | }
--------------------------------------------------------------------------------
/src/r1-v/local_scripts/zero3.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 |
14 | "zero_optimization": {
15 | "stage": 3,
16 | "offload_optimizer": {
17 | "device": "none",
18 | "pin_memory": true
19 | },
20 | "offload_param": {
21 | "device": "none",
22 | "pin_memory": true
23 | },
24 | "overlap_comm": true,
25 | "contiguous_gradients": true,
26 | "sub_group_size": 1e9,
27 | "reduce_bucket_size": "auto",
28 | "stage3_prefetch_bucket_size": "auto",
29 | "stage3_param_persistence_threshold": "auto",
30 | "stage3_max_live_parameters": 1e9,
31 | "stage3_max_reuse_distance": 1e9,
32 | "stage3_gather_16bit_weights_on_model_save": true
33 | },
34 |
35 | "gradient_accumulation_steps": "auto",
36 | "gradient_clipping": "auto",
37 | "steps_per_print": 100,
38 | "train_batch_size": "auto",
39 | "train_micro_batch_size_per_gpu": "auto",
40 | "wall_clock_breakdown": false
41 | }
--------------------------------------------------------------------------------
/src/r1-v/local_scripts/zero3.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | deepspeed_multinode_launcher: standard
5 | offload_optimizer_device: none
6 | offload_param_device: none
7 | zero3_init_flag: true
8 | zero3_save_16bit_model: true
9 | zero_stage: 3
10 | distributed_type: DEEPSPEED
11 | downcast_bf16: 'no'
12 | machine_rank: 0
13 | main_training_function: main
14 | mixed_precision: bf16
15 | num_machines: 1
16 | num_processes: 8
17 | rdzv_backend: static
18 | same_network: true
19 | tpu_env: []
20 | tpu_use_cluster: false
21 | tpu_use_sudo: false
22 | use_cpu: false
23 |
--------------------------------------------------------------------------------
/src/r1-v/local_scripts/zero3_offload.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "optimizer": {
14 | "type": "AdamW",
15 | "params": {
16 | "lr": "auto",
17 | "betas": "auto",
18 | "eps": "auto",
19 | "weight_decay": "auto"
20 | }
21 | },
22 | "zero_optimization": {
23 | "stage": 3,
24 | "offload_optimizer": {
25 | "device": "cpu",
26 | "pin_memory": true
27 | },
28 | "offload_param": {
29 | "device": "cpu",
30 | "pin_memory": true
31 | },
32 | "overlap_comm": true,
33 | "contiguous_gradients": true,
34 | "sub_group_size": 1e9,
35 | "reduce_bucket_size": "auto",
36 | "stage3_prefetch_bucket_size": "auto",
37 | "stage3_param_persistence_threshold": "auto",
38 | "stage3_max_live_parameters": 1e9,
39 | "stage3_max_reuse_distance": 1e9,
40 | "gather_16bit_weights_on_model_save": true
41 | },
42 | "gradient_accumulation_steps": "auto",
43 | "gradient_clipping": "auto",
44 | "train_batch_size": "auto",
45 | "train_micro_batch_size_per_gpu": "auto",
46 | "steps_per_print": 1e5,
47 | "wall_clock_breakdown": false
48 | }
--------------------------------------------------------------------------------
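The four launcher configs above cover the usual DeepSpeed trade-offs: zero2.json shards optimizer states (ZeRO stage 2), zero3.json additionally shards parameters (stage 3), zero3_offload.json further moves optimizer states and parameters to CPU memory, and zero3.yaml is the equivalent accelerate launcher config. The "auto" entries are resolved from the HuggingFace training arguments at launch time, which is why per-device batch size, gradient accumulation, and precision are set once in the shell scripts. A minimal sketch of that wiring, with hypothetical values, assuming the standard transformers TrainingArguments API:

    # Sketch only: how a DeepSpeed JSON like the ones above is consumed.
    # The "auto" fields are filled from these arguments when training starts.
    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="checkpoints/example",      # hypothetical path
        per_device_train_batch_size=1,         # -> train_micro_batch_size_per_gpu: "auto"
        gradient_accumulation_steps=2,         # -> gradient_accumulation_steps: "auto"
        bf16=True,                             # -> bf16.enabled: "auto"
        deepspeed="src/r1-v/local_scripts/zero3.json",
    )

The training scripts in this repo pass the same file through the --deepspeed flag of grpo.py instead of constructing TrainingArguments by hand.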
/src/r1-v/run_grpo.sh:
--------------------------------------------------------------------------------
1 | cd src/r1-v
2 |
3 | export DEBUG_MODE="true"
4 | export LOG_PATH="./debug_log_2b.txt"
5 |
6 |
7 |
8 | torchrun --nproc_per_node="8" \
9 | --nnodes="1" \
10 | --node_rank="0" \
11 | --master_addr="127.0.0.1" \
12 | --master_port="12345" \
13 | src/open_r1/grpo.py \
14 | --output_dir <OUTPUT_DIR> \
15 | --model_name_or_path <PATH_TO_MODEL> \
16 | --dataset_name <DATASET_NAME> \
17 | --max_prompt_length 1024 \
18 | --per_device_train_batch_size 1 \
19 | --gradient_accumulation_steps 2 \
20 | --logging_steps 1 \
21 | --bf16 \
22 | --report_to wandb \
23 | --gradient_checkpointing false \
24 | --attn_implementation flash_attention_2 \
25 | --max_pixels 401408 \
26 | --num_train_epochs 2 \
27 | --run_name Qwen2-VL-2B-GRPO-CLEVR-70k \
28 | --save_steps 100 \
29 | --save_only_model true
--------------------------------------------------------------------------------
/src/r1-v/setup.cfg:
--------------------------------------------------------------------------------
1 | [isort]
2 | default_section = FIRSTPARTY
3 | ensure_newline_before_comments = True
4 | force_grid_wrap = 0
5 | include_trailing_comma = True
6 | known_first_party = open_r1
7 | known_third_party =
8 | transformers
9 | datasets
10 | fugashi
11 | git
12 | h5py
13 | matplotlib
14 | nltk
15 | numpy
16 | packaging
17 | pandas
18 | psutil
19 | pytest
20 | rouge_score
21 | sacrebleu
22 | seqeval
23 | sklearn
24 | streamlit
25 | torch
26 | tqdm
27 |
28 | line_length = 119
29 | lines_after_imports = 2
30 | multi_line_output = 3
31 | use_parentheses = True
32 |
33 | [flake8]
34 | ignore = E203, E501, E741, W503, W605
35 | max-line-length = 119
36 | per-file-ignores =
37 | # imported but unused
38 | __init__.py: F401
39 |
40 | [tool:pytest]
41 | doctest_optionflags=NUMBER NORMALIZE_WHITESPACE ELLIPSIS
--------------------------------------------------------------------------------
/src/r1-v/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # Adapted from huggingface/transformers: https://github.com/huggingface/transformers/blob/21a2d900eceeded7be9edc445b56877b95eda4ca/setup.py
16 |
17 |
18 | import re
19 | import shutil
20 | from pathlib import Path
21 |
22 | from setuptools import find_packages, setup
23 |
24 |
25 | # Remove stale open_r1.egg-info directory to avoid https://github.com/pypa/pip/issues/5466
26 | stale_egg_info = Path(__file__).parent / "open_r1.egg-info"
27 | if stale_egg_info.exists():
28 | print(
29 | (
30 | "Warning: {} exists.\n\n"
31 | "If you recently updated open_r1, this is expected,\n"
32 | "but it may prevent open_r1 from installing in editable mode.\n\n"
33 | "This directory is automatically generated by Python's packaging tools.\n"
34 | "I will remove it now.\n\n"
35 | "See https://github.com/pypa/pip/issues/5466 for details.\n"
36 | ).format(stale_egg_info)
37 | )
38 | shutil.rmtree(stale_egg_info)
39 |
40 |
41 | # IMPORTANT: all dependencies should be listed here with their version requirements, if any.
42 | # * If a dependency is fast-moving (e.g. transformers), pin to the exact version
43 | _deps = [
44 | "accelerate>=1.2.1",
45 | "bitsandbytes>=0.43.0",
46 | "black>=24.4.2",
47 | "datasets>=3.2.0",
48 | "deepspeed==0.15.4",
49 | "distilabel[vllm,ray,openai]>=1.5.2",
50 | "einops>=0.8.0",
51 | "flake8>=6.0.0",
52 | "hf_transfer>=0.1.4",
53 | "huggingface-hub[cli]>=0.19.2,<1.0",
54 | "isort>=5.12.0",
55 | "liger_kernel==0.5.2",
56 | "lighteval @ git+https://github.com/huggingface/lighteval.git@4f381b352c0e467b5870a97d41cb66b487a2c503#egg=lighteval[math]",
57 | "math-verify", # Used for math verification in grpo
58 | "packaging>=23.0",
59 | "parameterized>=0.9.0",
60 | "pytest",
61 | "safetensors>=0.3.3",
62 | "sentencepiece>=0.1.99",
63 | "torch>=2.5.1",
64 | "transformers @ git+https://github.com/huggingface/transformers.git@336dc69d63d56f232a183a3e7f52790429b871ef",
65 | "trl==0.14.0",
66 | "vllm==0.6.6.post1",
67 | "wandb>=0.19.1",
68 | "pillow",
69 | ]
70 |
71 | # this is a lookup table with items like:
72 | #
73 | # tokenizers: "tokenizers==0.9.4"
74 | # packaging: "packaging"
75 | #
76 | # some of the values are versioned whereas others aren't.
77 | deps = {b: a for a, b in (re.findall(r"^(([^!=<>~ \[\]]+)(?:\[[^\]]+\])?(?:[!=<>~ ].*)?$)", x)[0] for x in _deps)}
78 |
79 |
80 | def deps_list(*pkgs):
81 | return [deps[pkg] for pkg in pkgs]
82 |
83 |
84 | extras = {}
85 | extras["tests"] = deps_list("pytest", "parameterized")
86 | extras["torch"] = deps_list("torch")
87 | extras["quality"] = deps_list("black", "isort", "flake8")
88 | extras["eval"] = deps_list("lighteval", "math-verify")
89 | extras["dev"] = extras["quality"] + extras["tests"] + extras["eval"]
90 |
91 | # core dependencies shared across the whole project - keep this to a bare minimum :)
92 | install_requires = [
93 | deps["accelerate"],
94 | deps["bitsandbytes"],
95 | deps["einops"],
96 | deps["datasets"],
97 | deps["deepspeed"],
98 | deps["hf_transfer"],
99 | deps["huggingface-hub"],
100 | deps["liger_kernel"],
101 | deps["packaging"], # utilities from PyPA to e.g., compare versions
102 | deps["safetensors"],
103 | deps["sentencepiece"],
104 | deps["transformers"],
105 | deps["trl"],
106 | ]
107 |
108 | setup(
109 | name="r1-v",
110 | version="0.1.0", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots)
111 | author="The r1-v team and the Hugging Face team (past and future)",
112 | description="R1-V",
113 | license="Apache",
114 | url="https://github.com/Deep-Agent/R1-V",
115 | package_dir={"": "src"},
116 | packages=find_packages("src"),
117 | zip_safe=False,
118 | extras_require=extras,
119 | python_requires=">=3.10.9",
120 | install_requires=install_requires,
121 | classifiers=[
122 | "Development Status :: 3 - Alpha",
123 | "Intended Audience :: Developers",
124 | "Intended Audience :: Education",
125 | "Intended Audience :: Science/Research",
126 | "License :: OSI Approved :: Apache Software License",
127 | "Operating System :: OS Independent",
128 | "Programming Language :: Python :: 3",
129 | "Programming Language :: Python :: 3.10",
130 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
131 | ],
132 | )
133 |
--------------------------------------------------------------------------------
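The one-liner that builds the deps lookup table in setup.py above is dense; as a quick standalone illustration (same regex copied verbatim, applied to a few of the entries listed there), it maps each requirement string to its bare package name:

    import re

    _deps = ["accelerate>=1.2.1", "deepspeed==0.15.4", "huggingface-hub[cli]>=0.19.2,<1.0", "pytest"]
    deps = {b: a for a, b in (re.findall(r"^(([^!=<>~ \[\]]+)(?:\[[^\]]+\])?(?:[!=<>~ ].*)?$)", x)[0] for x in _deps)}
    print(deps)
    # {'accelerate': 'accelerate>=1.2.1', 'deepspeed': 'deepspeed==0.15.4',
    #  'huggingface-hub': 'huggingface-hub[cli]>=0.19.2,<1.0', 'pytest': 'pytest'}

deps_list() and install_requires then look up full requirement strings from this table by package name.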
/src/r1-v/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hustvl/AlphaDrive/bb4104c680a3e4d70e7e998e08de90f6e0acf8c5/src/r1-v/src/__init__.py
--------------------------------------------------------------------------------
/src/r1-v/src/open_r1/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hustvl/AlphaDrive/bb4104c680a3e4d70e7e998e08de90f6e0acf8c5/src/r1-v/src/open_r1/__init__.py
--------------------------------------------------------------------------------
/src/r1-v/src/open_r1/evaluate.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Custom evaluation tasks for LightEval."""
16 |
17 | from lighteval.metrics.dynamic_metrics import (
18 | ExprExtractionConfig,
19 | LatexExtractionConfig,
20 | multilingual_extractive_match_metric,
21 | )
22 | from lighteval.tasks.lighteval_task import LightevalTaskConfig
23 | from lighteval.tasks.requests import Doc
24 | from lighteval.utils.language import Language
25 |
26 |
27 | metric = multilingual_extractive_match_metric(
28 | language=Language.ENGLISH,
29 | fallback_mode="first_match",
30 | precision=5,
31 | gold_extraction_target=(LatexExtractionConfig(),),
32 | pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
33 | aggregation_function=max,
34 | )
35 |
36 |
37 | def prompt_fn(line, task_name: str = None):
38 | """Assumes the model is either prompted to emit \\boxed{answer} or does so automatically"""
39 | return Doc(
40 | task_name=task_name,
41 | query=line["problem"],
42 | choices=[line["solution"]],
43 | gold_index=0,
44 | )
45 |
46 |
47 | # Define tasks
48 | aime24 = LightevalTaskConfig(
49 | name="aime24",
50 | suite=["custom"],
51 | prompt_function=prompt_fn,
52 | hf_repo="HuggingFaceH4/aime_2024",
53 | hf_subset="default",
54 | hf_avail_splits=["train"],
55 | evaluation_splits=["train"],
56 | few_shots_split=None,
57 | few_shots_select=None,
58 | generation_size=32768,
59 | metric=[metric],
60 | version=1,
61 | )
62 | math_500 = LightevalTaskConfig(
63 | name="math_500",
64 | suite=["custom"],
65 | prompt_function=prompt_fn,
66 | hf_repo="HuggingFaceH4/MATH-500",
67 | hf_subset="default",
68 | hf_avail_splits=["test"],
69 | evaluation_splits=["test"],
70 | few_shots_split=None,
71 | few_shots_select=None,
72 | generation_size=32768,
73 | metric=[metric],
74 | version=1,
75 | )
76 |
77 | # Add tasks to the table
78 | TASKS_TABLE = []
79 | TASKS_TABLE.append(aime24)
80 | TASKS_TABLE.append(math_500)
81 |
82 | # MODULE LOGIC
83 | if __name__ == "__main__":
84 | print([t["name"] for t in TASKS_TABLE])
85 | print(len(TASKS_TABLE))
86 |
--------------------------------------------------------------------------------
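prompt_fn in evaluate.py above is the only task-specific piece of logic: each dataset row becomes a Doc whose query is the problem statement and whose single gold choice is the reference solution. A minimal sketch with a toy record (not taken from the real aime24/math_500 datasets):

    # Equivalent to prompt_fn(line, task_name="custom|math_500") in evaluate.py above.
    from lighteval.tasks.requests import Doc

    line = {"problem": "What is 2 + 2?", "solution": "\\boxed{4}"}
    doc = Doc(task_name="custom|math_500", query=line["problem"], choices=[line["solution"]], gold_index=0)
    print(doc.query, doc.choices, doc.gold_index)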
/src/r1-v/src/open_r1/generate.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Optional
16 |
17 | from distilabel.llms import OpenAILLM
18 | from distilabel.pipeline import Pipeline
19 | from distilabel.steps.tasks import TextGeneration
20 |
21 |
22 | def build_distilabel_pipeline(
23 | model: str,
24 | base_url: str = "http://localhost:8000/v1",
25 | prompt_column: Optional[str] = None,
26 | temperature: Optional[float] = None,
27 | top_p: Optional[float] = None,
28 | max_new_tokens: int = 8192,
29 | num_generations: int = 1,
30 | ) -> Pipeline:
31 | generation_kwargs = {"max_new_tokens": max_new_tokens}
32 |
33 | if temperature is not None:
34 | generation_kwargs["temperature"] = temperature
35 |
36 | if top_p is not None:
37 | generation_kwargs["top_p"] = top_p
38 |
39 | with Pipeline().ray() as pipeline:
40 | TextGeneration(
41 | llm=OpenAILLM(
42 | base_url=base_url,
43 | api_key="something",
44 | model=model,
45 | # thinking can take some time...
46 | timeout=10 * 60,
47 | generation_kwargs=generation_kwargs,
48 | ),
49 | input_mappings={"instruction": prompt_column} if prompt_column is not None else {},
50 | input_batch_size=64, # on 4 nodes bs ~60+ leads to preemption due to KV cache exhaustion
51 | num_generations=num_generations,
52 | )
53 |
54 | return pipeline
55 |
56 |
57 | if __name__ == "__main__":
58 | import argparse
59 |
60 | from datasets import load_dataset
61 |
62 | parser = argparse.ArgumentParser(description="Run distilabel pipeline for generating responses with DeepSeek R1")
63 | parser.add_argument(
64 | "--hf-dataset",
65 | type=str,
66 | required=True,
67 | help="HuggingFace dataset to load",
68 | )
69 | parser.add_argument(
70 | "--hf-dataset-config",
71 | type=str,
72 | required=False,
73 | help="Dataset config to use",
74 | )
75 | parser.add_argument(
76 | "--hf-dataset-split",
77 | type=str,
78 | default="train",
79 | help="Dataset split to use",
80 | )
81 | parser.add_argument("--prompt-column", type=str, default="prompt")
82 | parser.add_argument(
83 | "--model",
84 | type=str,
85 | required=True,
86 | help="Model name to use for generation",
87 | )
88 | parser.add_argument(
89 | "--vllm-server-url",
90 | type=str,
91 | default="http://localhost:8000/v1",
92 | help="URL of the vLLM server",
93 | )
94 | parser.add_argument(
95 | "--temperature",
96 | type=float,
97 | help="Temperature for generation",
98 | )
99 | parser.add_argument(
100 | "--top-p",
101 | type=float,
102 | help="Top-p value for generation",
103 | )
104 | parser.add_argument(
105 | "--max-new-tokens",
106 | type=int,
107 | default=8192,
108 | help="Maximum number of new tokens to generate",
109 | )
110 | parser.add_argument(
111 | "--num-generations",
112 | type=int,
113 | default=1,
114 | help="Number of generations per problem",
115 | )
116 | parser.add_argument(
117 | "--hf-output-dataset",
118 | type=str,
119 | required=False,
120 | help="HuggingFace repo to push results to",
121 | )
122 | parser.add_argument(
123 | "--private",
124 | action="store_true",
125 | help="Whether to make the output dataset private when pushing to HF Hub",
126 | )
127 |
128 | args = parser.parse_args()
129 |
130 | print("\nRunning with arguments:")
131 | for arg, value in vars(args).items():
132 | print(f" {arg}: {value}")
133 | print()
134 |
135 | print(f"Loading '{args.hf_dataset}' (config: {args.hf_dataset_config}, split: {args.hf_dataset_split}) dataset...")
136 |     dataset = load_dataset(args.hf_dataset, args.hf_dataset_config, split=args.hf_dataset_split)
137 | print("Dataset loaded!")
138 |
139 | pipeline = build_distilabel_pipeline(
140 | model=args.model,
141 | base_url=args.vllm_server_url,
142 | prompt_column=args.prompt_column,
143 | temperature=args.temperature,
144 | top_p=args.top_p,
145 | max_new_tokens=args.max_new_tokens,
146 | num_generations=args.num_generations,
147 | )
148 |
149 | print("Running generation pipeline...")
150 | distiset = pipeline.run(dataset=dataset, use_cache=False)
151 | print("Generation pipeline finished!")
152 |
153 | if args.hf_output_dataset:
154 | print(f"Pushing resulting dataset to '{args.hf_output_dataset}'...")
155 | distiset.push_to_hub(args.hf_output_dataset, private=args.private)
156 | print("Dataset pushed!")
157 |
--------------------------------------------------------------------------------
/src/r1-v/src/open_r1/grpo.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import re
17 | from collections import Counter
18 | from datetime import datetime
19 | from dataclasses import dataclass, field
20 | from typing import Optional
21 |
22 | from datasets import load_dataset, load_from_disk
23 | from transformers import Qwen2VLForConditionalGeneration
24 |
25 | from math_verify import parse, verify
26 | from open_r1.trainer import Qwen2VLGRPOTrainer, Qwen2VLGRPOVLLMTrainer
27 | from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
28 |
29 |
30 | @dataclass
31 | class GRPOScriptArguments(ScriptArguments):
32 | """
33 | Script arguments for the GRPO training script.
34 |
35 | Args:
36 | reward_funcs (`list[str]`):
37 | List of reward functions. Possible values: 'plan_speed_reward', 'plan_path_reward', 'plan_format_reward'.
38 | """
39 |
40 | reward_funcs: list[str] = field(
41 | default_factory=lambda: ["plan_speed_reward", "plan_path_reward", "plan_format_reward"],
42 | metadata={"help": "List of reward functions. Possible values: 'plan_speed_reward', 'plan_path_reward', 'plan_format_reward'"},
43 | )
44 | max_pixels: Optional[int] = field(
45 | default=12845056,
46 | metadata={"help": "Maximum number of pixels for the image"},
47 | )
48 | min_pixels: Optional[int] = field(
49 | default=3136,
50 | metadata={"help": "Minimum number of pixels for the image"},
51 | )
52 |
53 |
54 | def plan_speed_reward(completions,
55 | solution,
56 | diversity_weight=0.4,
57 | complexity_weights=None,
58 | **kwargs):
59 | """
60 | planning speed reward function.
61 | """
62 | if complexity_weights is None:
63 | complexity_weights = {
64 | "ACCELERATE": 0.9, "DECELERATE": 1.0, "STOP": 1.0, "KEEP": 0.8,
65 | }
66 |
67 | rewards = []
68 | global_decision_count = Counter()
69 |
70 | for completion, sol in zip(completions, solution):
71 | sol_match = re.search(r'<answer>(.*?)</answer>', sol)
72 | if not sol_match:
73 | rewards.append(0)
74 | continue
75 | ground_truth_words = set(sol_match.group(1).strip().split(', '))
76 | ground_truth_words = {word for word in ground_truth_words if word in complexity_weights}
77 |
78 | match = re.search(r"<answer>(.*?)</answer>", completion[0]["content"])
79 | if match:
80 | content = match.group(1).strip()
81 | else:
82 | content = completion[0]["content"].replace('<answer>', '').replace('</answer>', '')
83 |
84 | content_word_list = [re.sub(r'[^\w]', '', word) for word in content.split(', ') if word in complexity_weights]
85 | content_words = set(content_word_list)
86 | global_decision_count.update(content_words)
87 |
88 | true_positives = len(content_words & ground_truth_words)
89 | false_positives = len(content_words - ground_truth_words)
90 | false_negatives = len(ground_truth_words - content_words)
91 |
92 | precision = true_positives / (true_positives + false_positives + 1e-6)
93 | recall = true_positives / (true_positives + false_negatives + 1e-6)
94 | if true_positives == 0 and false_positives == 0 and false_negatives == 0:
95 | f1_score = 0 # no match
96 | else:
97 | f1_score = 2 * (precision * recall) / (precision + recall + 1e-6)
98 |
99 | complexity_factor = sum(complexity_weights[word] for word in content_words) / (len(content_words) + 1e-6)
100 |
101 | diversity_factor = [True if global_decision_count[word] == 1 else False for word in content_words]
102 | diversity_factor = True if all(diversity_factor) else False
103 | diversity_factor = diversity_weight if diversity_factor else -diversity_weight
104 |
105 | reward = f1_score * complexity_factor + diversity_factor
106 |
107 | rewards.append(reward)
108 |
109 | return rewards
110 |
111 |
112 | def plan_path_reward(completions,
113 | solution,
114 | diversity_weight=0.4,
115 | complexity_weights=None,
116 | **kwargs):
117 | """
118 | planning path reward function.
119 | """
120 | if complexity_weights is None:
121 | complexity_weights = {
122 | "LEFT_TURN": 1.0, "RIGHT_TURN": 1.0,
123 | "LEFT_CHANGE": 1.0, "RIGHT_CHANGE": 1.0, "STRAIGHT": 0.8
124 | }
125 |
126 | rewards = []
127 | global_decision_count = Counter()
128 |
129 | for completion, sol in zip(completions, solution):
130 | sol_match = re.search(r'<answer>(.*?)</answer>', sol)
131 | if not sol_match:
132 | rewards.append(0)
133 | continue
134 | ground_truth_words = set(sol_match.group(1).strip().split(', '))
135 | ground_truth_words = {word for word in ground_truth_words if word in complexity_weights}
136 |
137 | match = re.search(r"<answer>(.*?)</answer>", completion[0]["content"])
138 | if match:
139 | content = match.group(1).strip()
140 | else:
141 | content = completion[0]["content"].replace('<answer>', '').replace('</answer>', '')
142 |
143 | content_word_list = [re.sub(r'[^\w]', '', word) for word in content.split(', ') if word in complexity_weights]
144 | content_words = set(content_word_list)
145 | global_decision_count.update(content_words)
146 |
147 | true_positives = len(content_words & ground_truth_words)
148 | false_positives = len(content_words - ground_truth_words)
149 | false_negatives = len(ground_truth_words - content_words)
150 |
151 | precision = true_positives / (true_positives + false_positives + 1e-6)
152 | recall = true_positives / (true_positives + false_negatives + 1e-6)
153 | if true_positives == 0 and false_positives == 0 and false_negatives == 0:
154 | f1_score = 0 # no match
155 | else:
156 | f1_score = 2 * (precision * recall) / (precision + recall + 1e-6)
157 |
158 | complexity_factor = sum(complexity_weights[word] for word in content_words) / (len(content_words) + 1e-6)
159 |
160 | diversity_factor = [True if global_decision_count[word] == 1 else False for word in content_words]
161 | diversity_factor = True if all(diversity_factor) else False
162 | diversity_factor = diversity_weight if diversity_factor else -diversity_weight
163 |
164 | reward = f1_score * complexity_factor + diversity_factor
165 |
166 | rewards.append(reward)
167 |
168 | return rewards
169 |
170 |
171 | def plan_format_reward(completions, **kwargs):
172 | """Reward function that checks if the completion has a specific format."""
173 | # check if answer format is <think>xxx</think>\n<answer>xxx</answer>
174 | pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
175 | completion_contents = [completion[0]["content"] for completion in completions]
176 | matches = [re.fullmatch(pattern, content, re.DOTALL) for content in completion_contents]
177 |
178 | return [1.0 if match else 0.0 for match in matches]
179 |
180 |
181 |
182 | reward_funcs_registry = {
183 | "plan_format_reward": plan_format_reward,
184 | "plan_speed_reward": plan_speed_reward,
185 | "plan_path_reward": plan_path_reward,
186 | }
187 |
188 | SYSTEM_PROMPT = (
189 | "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
190 | "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
191 | "process and answer are enclosed within and tags, respectively, i.e., "
192 | " reasoning process here answer here "
193 | )
194 |
195 |
196 | def main(script_args, training_args, model_args):
197 | # Get reward functions
198 | reward_funcs = [reward_funcs_registry[func] for func in script_args.reward_funcs]
199 |
200 | # Load the dataset
201 | dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
202 |
203 |
204 | # Format into conversation
205 | def make_conversation(example):
206 | return {
207 | "prompt": [
208 | {"role": "system", "content": SYSTEM_PROMPT},
209 | {"role": "user", "content": example["problem"]},
210 | ],
211 | }
212 |
213 | # def make_conversation_image(example):
214 | # return {
215 | # "prompt": [
216 | # {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
217 | # {
218 | # "role": "user",
219 | # "content": [
220 | # {"type": "image"},
221 | # {"type": "text", "text": example["problem"]},
222 | # ],
223 | # },
224 | # ],
225 | # }
226 |
227 | # QUESTION_TEMPLATE = "{Question} Output the thinking process in <think> </think> and final answer (number) in <answer> </answer> tags."
228 | # QUESTION_TEMPLATE = "{Question} Output the final answer in <answer> </answer> tags."
229 |
230 | def make_conversation_image(example):
231 | return {
232 | "prompt": [
233 | {
234 | "role": "user",
235 | "content": [
236 | {"type": "image"},
237 | # {"type": "text", "text": QUESTION_TEMPLATE.format(Question=example["problem"])},
238 | {"type": "text", "text": example["problem"]},
239 | ],
240 | },
241 | ],
242 | }
243 |
244 |
245 | if "image" in dataset[script_args.dataset_train_split].features:
246 | print("has image in dataset")
247 | dataset = dataset.map(make_conversation_image) # Utilize multiprocessing for faster mapping
248 | # dataset = dataset.remove_columns(["original_question", "original_answer"])
249 |
250 | else:
251 | print("no image in dataset")
252 | dataset = dataset.map(make_conversation)
253 | dataset = dataset.remove_columns("messages")
254 |
255 |
256 | trainer_cls = Qwen2VLGRPOTrainer if not training_args.use_vllm else Qwen2VLGRPOVLLMTrainer
257 | print("using: ", trainer_cls)
258 |
259 | # Initialize the GRPO trainer
260 | trainer = trainer_cls(
261 | model=model_args.model_name_or_path,
262 | reward_funcs=reward_funcs,
263 | args=training_args,
264 | train_dataset=dataset[script_args.dataset_train_split],
265 | eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
266 | peft_config=get_peft_config(model_args),
267 | attn_implementation=model_args.attn_implementation,
268 | max_pixels=script_args.max_pixels,
269 | min_pixels=script_args.min_pixels,
270 | )
271 |
272 | # Train and push the model to the Hub
273 | trainer.train()
274 |
275 | # Save and push to hub
276 | trainer.save_model(training_args.output_dir)
277 | if training_args.push_to_hub:
278 | trainer.push_to_hub(dataset_name=script_args.dataset_name)
279 |
280 |
281 | if __name__ == "__main__":
282 | parser = TrlParser((GRPOScriptArguments, GRPOConfig, ModelConfig))
283 | script_args, training_args, model_args = parser.parse_args_and_config()
284 | main(script_args, training_args, model_args)
285 |
--------------------------------------------------------------------------------
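The three reward functions in grpo.py above are plain Python callables that take batched completions and ground-truth solutions, so they can be exercised outside the trainer. A minimal sketch with a hypothetical rollout, assuming the r1-v package and its dependencies are installed (pip3 install -e ".[dev]" from src/r1-v):

    from open_r1.grpo import plan_speed_reward, plan_path_reward, plan_format_reward

    # One rollout for one prompt, in the conversational format the trainer passes in;
    # only the first message's content of each completion is scored.
    completions = [[{"content": "<think>Traffic ahead is slowing down.</think><answer>DECELERATE, STRAIGHT</answer>"}]]
    solution = ["<answer>DECELERATE, STRAIGHT</answer>"]

    print(plan_speed_reward(completions, solution))  # -> [~1.4]: F1 over speed decisions x complexity weight + diversity bonus
    print(plan_path_reward(completions, solution))   # -> [~1.2]: same scheme over path decisions (STRAIGHT has weight 0.8)
    print(plan_format_reward(completions))           # -> [1.0]: the <think>...</think><answer>...</answer> format matches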
/src/r1-v/src/open_r1/sft.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | Supervised fine-tuning script for decoder language models.
17 |
18 | Usage:
19 |
20 | # On 1 node of 8 x H100s
21 | accelerate launch --config_file=configs/zero3.yaml src/open_r1/sft.py \
22 | --model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
23 | --dataset_name HuggingFaceH4/Bespoke-Stratos-17k \
24 | --learning_rate 2.0e-5 \
25 | --num_train_epochs 1 \
26 | --packing \
27 | --max_seq_length 4096 \
28 | --per_device_train_batch_size 4 \
29 | --gradient_accumulation_steps 4 \
30 | --gradient_checkpointing \
31 | --bf16 \
32 | --logging_steps 5 \
33 | --eval_strategy steps \
34 | --eval_steps 100 \
35 | --output_dir data/Qwen2.5-1.5B-Open-R1-Distill
36 | """
37 |
38 | import logging
39 | import os
40 | import sys
41 |
42 | import datasets
43 | from dataclasses import dataclass, field
44 | from typing import Optional
45 | import torch
46 | import transformers
47 | from datasets import load_dataset
48 | from transformers import AutoTokenizer, set_seed, AutoProcessor
49 | from transformers.trainer_utils import get_last_checkpoint
50 | import trl
51 | from trl import (
52 | ModelConfig,
53 | ScriptArguments,
54 | SFTTrainer,
55 | TrlParser,
56 | get_kbit_device_map,
57 | get_peft_config,
58 | get_quantization_config,
59 | )
60 |
61 | from qwen_vl_utils import process_vision_info
62 | logger = logging.getLogger(__name__)
63 |
64 |
65 | @dataclass
66 | class SFTConfig(trl.SFTConfig):
67 | """
68 | args for callbacks, benchmarks etc
69 | """
70 |
71 | benchmarks: list[str] = field(
72 | default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
73 | )
74 | callbacks: list[str] = field(
75 | default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
76 | )
77 | system_prompt: Optional[str] = field(
78 | default=None,
79 | metadata={"help": "The optional system prompt to use for benchmarking."},
80 | )
81 | hub_model_revision: Optional[str] = field(
82 | default="main",
83 | metadata={"help": "The Hub model branch to push the model to."},
84 | )
85 | overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
86 | push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
87 |
88 |
89 |
90 | processor = None
91 |
92 |
93 | def convert_example(example):
94 | """
95 | convert an example into "messages" format
96 | e.g.:
97 | {
98 | "system": "You are a helpful assistant.",
99 | "conversations": [
100 | {"from": "user", "value": "How many objects are included in this image?",
101 | "image_path": "/path/to/image.png"},
102 | {"from": "assistant", "value": "\nI can see 10 objects\n\n\n10\n"}
103 | ]
104 | }
105 | """
106 | messages = []
107 | # if "system" in example:
108 | # messages.append({
109 | # "role": "system",
110 | # "content": [{"type": "text", "text": example["system"]}],
111 | # })
112 | # else:
113 | # SYSTEM_PROMPT = (
114 | # "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
115 | # "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
116 | # "process and answer are enclosed within and tags, respectively, i.e., "
117 | # " reasoning process here answer here "
118 | # )
119 |
120 | # SYSTEM_PROMPT = (
121 | # "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The answer is enclosed within tags, i.e., "
122 | # " answer here "
123 | # )
124 |
125 | # messages.append({
126 | # "role": "system",
127 | # "content": [{"type": "text", "text": SYSTEM_PROMPT}],
128 | # })
129 |
130 | thinking = example.get("thinking")
131 | problem = example.get("problem")
132 | solution = example.get("solution")
133 | image = example.get("image")
134 | messages.append({
135 | "role": "user",
136 | "content": [
137 | {"type": "text", "text": problem},
138 | {"type": "image", "image": image},
139 | ]
140 | })
141 | # messages.append({
142 | # "role": "assistant",
143 | # "content": f"{thinking}\n{solution}",
144 | # })
145 | messages.append({
146 | "role": "assistant",
147 | "content": f"{solution}",
148 | })
149 |
150 | example["messages"] = messages
151 | return example
152 |
153 |
154 | def collate_fn(examples):
155 | texts = [
156 | processor.apply_chat_template(convert_example(example)["messages"], tokenize=False, add_generation_prompt=True)
157 | for example in examples
158 | ]
159 | image_inputs = []
160 | for example in examples:
161 | imgs, vids = process_vision_info(example["messages"])
162 | image_inputs.append(imgs)
163 | batch = processor(
164 | text=texts,
165 | images=image_inputs,
166 | return_tensors="pt",
167 | padding=True,
168 | )
169 | labels = batch["input_ids"].clone()
170 | labels[labels == processor.tokenizer.pad_token_id] = -100
171 | image_token_id = processor.tokenizer.convert_tokens_to_ids(processor.image_token)
172 | labels[labels == image_token_id] = -100
173 | batch["labels"] = labels
174 |
175 | return batch
176 |
177 |
178 | def main(script_args, training_args, model_args):
179 | # Set seed for reproducibility
180 | set_seed(training_args.seed)
181 |
182 | ###############
183 | # Setup logging
184 | ###############
185 | logging.basicConfig(
186 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
187 | datefmt="%Y-%m-%d %H:%M:%S",
188 | handlers=[logging.StreamHandler(sys.stdout)],
189 | )
190 | log_level = training_args.get_process_log_level()
191 | logger.setLevel(log_level)
192 | datasets.utils.logging.set_verbosity(log_level)
193 | transformers.utils.logging.set_verbosity(log_level)
194 | transformers.utils.logging.enable_default_handler()
195 | transformers.utils.logging.enable_explicit_format()
196 |
197 | # Log on each process a small summary
198 | logger.warning(
199 | f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
200 | + f" distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
201 | )
202 | logger.info(f"Model parameters {model_args}")
203 | logger.info(f"Script parameters {script_args}")
204 | logger.info(f"Data parameters {training_args}")
205 |
206 | # Check for last checkpoint
207 | last_checkpoint = None
208 | if os.path.isdir(training_args.output_dir):
209 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
210 | if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
211 | logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
212 |
213 | ################
214 | # Load datasets
215 | ################
216 |
217 | dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
218 |
219 | ################
220 | # Load tokenizer
221 | ################
222 | global processor
223 | if "vl" in model_args.model_name_or_path.lower():
224 | processor = AutoProcessor.from_pretrained(
225 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
226 | )
227 | logger.info("Using AutoProcessor for vision-language model.")
228 | else:
229 | processor = AutoTokenizer.from_pretrained(
230 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
231 | )
232 | logger.info("Using AutoTokenizer for text-only model.")
233 | if hasattr(processor, "pad_token") and processor.pad_token is None:
234 | processor.pad_token = processor.eos_token
235 | elif hasattr(processor.tokenizer, "pad_token") and processor.tokenizer.pad_token is None:
236 | processor.tokenizer.pad_token = processor.tokenizer.eos_token
237 |
238 | ###################
239 | # Model init kwargs
240 | ###################
241 | logger.info("*** Initializing model kwargs ***")
242 | torch_dtype = (
243 | model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
244 | )
245 | quantization_config = get_quantization_config(model_args)
246 | model_kwargs = dict(
247 | revision=model_args.model_revision,
248 | trust_remote_code=model_args.trust_remote_code,
249 | attn_implementation=model_args.attn_implementation,
250 | torch_dtype=torch_dtype,
251 | use_cache=False if training_args.gradient_checkpointing else True,
252 | device_map=get_kbit_device_map() if quantization_config is not None else None,
253 | quantization_config=quantization_config,
254 | )
255 | # training_args.model_init_kwargs = model_kwargs
256 | from transformers import Qwen2VLForConditionalGeneration
257 | model = Qwen2VLForConditionalGeneration.from_pretrained(
258 | model_args.model_name_or_path, **model_kwargs
259 | )
260 | ############################
261 | # Initialize the SFT Trainer
262 | ############################
263 | training_args.dataset_kwargs = {
264 | "skip_prepare_dataset": True,
265 | }
266 | training_args.remove_unused_columns = False
267 | trainer = SFTTrainer(
268 | model=model,
269 | args=training_args,
270 | train_dataset=dataset[script_args.dataset_train_split],
271 | eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
272 | processing_class=processor.tokenizer,
273 | data_collator=collate_fn,
274 | peft_config=get_peft_config(model_args)
275 | )
276 |
277 | ###############
278 | # Training loop
279 | ###############
280 | logger.info("*** Train ***")
281 | checkpoint = None
282 | if training_args.resume_from_checkpoint is not None:
283 | checkpoint = training_args.resume_from_checkpoint
284 | elif last_checkpoint is not None:
285 | checkpoint = last_checkpoint
286 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
287 | metrics = train_result.metrics
288 | metrics["train_samples"] = len(dataset[script_args.dataset_train_split])
289 | trainer.log_metrics("train", metrics)
290 | trainer.save_metrics("train", metrics)
291 | trainer.save_state()
292 |
293 | ##################################
294 | # Save model and create model card
295 | ##################################
296 | logger.info("*** Save model ***")
297 | trainer.save_model(training_args.output_dir)
298 | processor.save_pretrained(training_args.output_dir)
299 | logger.info(f"Model saved to {training_args.output_dir}")
300 |
301 | # Save everything else on main process
302 | kwargs = {
303 | "dataset_name": script_args.dataset_name,
304 | "tags": ["R1-V"],
305 | }
306 | if trainer.accelerator.is_main_process:
307 | trainer.create_model_card(**kwargs)
308 | # Restore k,v cache for fast inference
309 | trainer.model.config.use_cache = True
310 | trainer.model.config.save_pretrained(training_args.output_dir)
311 | #############
312 | # push to hub
313 | #############
314 |
315 | if training_args.push_to_hub:
316 | logger.info("Pushing to hub...")
317 | trainer.push_to_hub(**kwargs)
318 | processor.push_to_hub(training_args.hub_model_id)
319 |
320 |
321 |
322 |
323 | if __name__ == "__main__":
324 | parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
325 | script_args, training_args, model_args = parser.parse_args_and_config()
326 | main(script_args, training_args, model_args)
327 |
--------------------------------------------------------------------------------
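For reference, convert_example in sft.py above produces the message structure that collate_fn then templates and tokenizes: a user turn carrying the text prompt plus the image, and an assistant turn carrying the target solution. A toy record with hypothetical field values and the equivalent structure:

    example = {
        "problem": "What should the ego vehicle do next?",
        "thinking": "The light ahead is red.",
        "solution": "<answer>DECELERATE, STRAIGHT</answer>",
        "image": "path/to/front_camera.png",
    }
    # convert_example(example)["messages"] is equivalent to:
    messages = [
        {"role": "user", "content": [{"type": "text", "text": example["problem"]},
                                     {"type": "image", "image": example["image"]}]},
        {"role": "assistant", "content": example["solution"]},
    ]

collate_fn then applies the processor's chat template, extracts the image inputs with process_vision_info, and masks padding and image tokens to -100 in the labels.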
/src/r1-v/src/open_r1/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from .grpo_trainer import Qwen2VLGRPOTrainer
2 | from .vllm_grpo_trainer import Qwen2VLGRPOVLLMTrainer
3 |
4 | __all__ = ["Qwen2VLGRPOTrainer", "Qwen2VLGRPOVLLMTrainer"]
5 |
--------------------------------------------------------------------------------
/src/r1-v/temp_image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hustvl/AlphaDrive/bb4104c680a3e4d70e7e998e08de90f6e0acf8c5/src/r1-v/temp_image.png
--------------------------------------------------------------------------------
/src/scripts/run_grpo_clevr.sh:
--------------------------------------------------------------------------------
1 | export DEBUG_MODE="true" # enable debug mode to log the model's rollouts during RL
2 | export LOG_PATH="./debug_log_2b.txt"
3 | # dataset: https://huggingface.co/datasets/leonardPKU/clevr_cogen_a_train
4 | torchrun --nproc_per_node="8" \
5 | --nnodes="1" \
6 | --node_rank="0" \
7 | --master_addr="127.0.0.1" \
8 | --master_port="12345" \
9 | src/open_r1/grpo.py \
10 | --output_dir <OUTPUT_DIR> \
11 | --model_name_or_path <PATH_TO_MODEL> \
12 | --dataset_name leonardPKU/clevr_cogen_a_train \
13 | --max_prompt_length 1024 \
14 | --per_device_train_batch_size 1 \
15 | --gradient_accumulation_steps 2 \
16 | --logging_steps 1 \
17 | --bf16 \
18 | --report_to wandb \
19 | --gradient_checkpointing false \
20 | --attn_implementation flash_attention_2 \
21 | --max_pixels 401408 \
22 | --num_train_epochs 2 \
23 | --run_name Qwen2-VL-2B-GRPO-CLEVR-70k \
24 | --save_steps 100 \
25 | --save_only_model true
--------------------------------------------------------------------------------
/src/scripts/run_grpo_vllm.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # vllm==0.7.2 is required for this script: pip3 install vllm==0.7.2
4 |
5 |
6 | export DEBUG_MODE="true"
7 | export LOG_PATH="./vllm_run.txt"
8 |
9 | QWEN_PATH="PATH_TO_QWEN_2B_CKPT"
10 | HF_DATASET="MMInstruction/Clevr_CoGenT_TrainA_70K_Complex"
11 | OUTPUT_DIR="OUTPUT_DIR"
12 | RUN_NAME="RUN_NAME_FOR_WANDB"
13 |
14 | # NOTE: use X + 1 GPUs in total: X for the training processes and 1 for the vLLM process.
15 | # e.g., with 5 visible devices (0,1,2,3,4), set --nproc_per_node="4"
16 |
17 | CUDA_VISIBLE_DEVICES="0,1,2,3,4" torchrun --nproc_per_node="4" \
18 | --nnodes="1" \
19 | --node_rank="0" \
20 | --master_addr="127.0.0.1" \
21 | --master_port="12345" \
22 | src/open_r1/grpo.py --use_vllm True \
23 | --output_dir $OUTPUT_DIR \
24 | --model_name_or_path $QWEN_PATH \
25 | --dataset_name $HF_DATASET \
26 | --max_prompt_length 512 \
27 | --max_completion_length 1024 \
28 | --temperature 1.0 \
29 | --num_generations 4 \
30 | --per_device_train_batch_size 1 \
31 | --gradient_accumulation_steps 4 \
32 | --logging_steps 1 \
33 | --bf16 \
34 | --report_to wandb \
35 | --gradient_checkpointing true \
36 | --attn_implementation flash_attention_2 \
37 | --max_pixels 400000 \
38 | --max_steps 13125 \
39 | --run_name $RUN_NAME \
40 | --save_steps 1000 \
41 | --save_only_model true
42 |
--------------------------------------------------------------------------------
/src/scripts/run_sft_clevr.sh:
--------------------------------------------------------------------------------
1 | ACCELERATE_LOG_LEVEL=info accelerate launch --config_file src/r1-v/configs/zero2.yaml src/r1-v/src/open_r1/sft.py --config src/r1-v/configs/qwen2vl_sft_config.yaml
--------------------------------------------------------------------------------
/src/scripts/test_grpo_geoqa_multigpu.sh:
--------------------------------------------------------------------------------
1 | r1_v_path=/workspace/xxx/github/R1-V
2 | cd ${r1_v_path}
3 |
4 | model_path=${r1_v_path}/output/train@geo170k/checkpoint-30
5 | batch_size=4
6 | output_path=${r1_v_path}/output/train@geo170k/eval/res@checkpoint-30.json
7 | prompt_path=${r1_v_path}/src/eval/prompts/geoqa_test_prompts.jsonl
8 | gpu_ids=0,1,2,3,4,5,6,7
9 |
10 | python src/eval/test_qwen2vl_geoqa_multigpu.py \
11 | --model_path ${model_path} \
12 | --batch_size ${batch_size} \
13 | --output_path ${output_path} \
14 | --prompt_path ${prompt_path} \
15 | --gpu_ids ${gpu_ids}
16 |
--------------------------------------------------------------------------------
/train_tools/run_train_grpo.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # dynamic settings
4 | export ENV_NAME="/path/to/your/python/env"
5 | export MODEL_PATH="/path/to/your/Qwen2-VL-2B-Instruct"
6 | export OUT_NAME="Qwen2-VL-2B-EXP"
7 | export TRAIN_DATA="/path/to/your/train/data"
8 | export EVAL_DATA="/path/to/your/val/data"
9 | export EVAL_SAVE_NAME="eval_result.json"
10 | export OUTDIR="/path/to/your/out_dir"
11 | export WORKING_PATH="/path/to/your/AlphaDrive"
12 |
13 | cd ${WORKING_PATH}
14 |
15 |
16 | # setup environments
17 | echo "Setup environments..."
18 | # export NCCL_P2P_DISABLE="1"
19 | # export NCCL_IB_DISABLE="1"
20 |
21 | mkdir -p ${OUTDIR}
22 |
23 |
24 | echo "Training Process..."
25 | cd src/r1-v
26 |
27 |
28 | $ENV_NAME/bin/torchrun --nproc_per_node="8" \
29 | --nnodes="2" \
30 | --node_rank="0" \
31 | --master_addr="127.0.0.1" \
32 | --master_port="12345" \
33 | src/open_r1/grpo.py \
34 | --output_dir $OUTDIR/$OUT_NAME \
35 | --model_name_or_path $MODEL_PATH \
36 | --dataset_name $TRAIN_DATA \
37 | --max_prompt_length 1024 \
38 | --per_device_train_batch_size 1 \
39 | --gradient_accumulation_steps 2 \
40 | --logging_steps 1 \
41 | --bf16 \
42 | --report_to tensorboard \
43 | --gradient_checkpointing false \
44 | --attn_implementation flash_attention_2 \
45 | --max_pixels 401408 \
46 | --reward_funcs "plan_speed_reward" "plan_path_reward" "plan_format_reward" \
47 | --num_train_epochs 1 \
48 | --run_name $OUT_NAME \
49 | --save_steps 1000 \
50 | --save_only_model true \
51 | --num_generations 2   # number of rollouts G per prompt in GRPO; reducing it speeds up training and lowers memory use, but increases reward variance
52 |
53 |
54 | echo "Validation Process..."
55 | cd ${WORKING_PATH}
56 | $ENV_NAME/bin/python eval_tools/qwen2vl_plan_cmd_eval_grpo.py \
57 | --eval_data_path $EVAL_DATA \
58 | --model_path $OUTDIR/$OUT_NAME \
59 | --save_path $OUTDIR/$OUT_NAME/$EVAL_SAVE_NAME
60 |
--------------------------------------------------------------------------------
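--num_generations at the end of the GRPO command above is the group size G: each training prompt is rolled out G times, every rollout is scored by the plan_* reward functions, and the policy update uses group-relative advantages. A rough illustration of that normalization (not code from this repo; the actual computation lives in the GRPO trainer):

    import statistics

    def group_relative_advantages(rewards, eps=1e-4):
        # Normalize each rollout's reward against its own group's mean and std.
        mean = statistics.mean(rewards)
        std = statistics.pstdev(rewards)
        return [(r - mean) / (std + eps) for r in rewards]

    print(group_relative_advantages([1.4, 0.2]))             # G = 2: a coarse better/worse signal
    print(group_relative_advantages([1.4, 1.2, 0.2, 0.0]))   # larger G: finer ranking, more compute per step

This is why the inline comment warns that a smaller G is cheaper but yields a higher-variance training signal.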
/train_tools/run_train_sft.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | # dynamic settings
4 | export ENV_NAME="/path/to/your/python/env"
5 | export MODEL_PATH="/path/to/your/Qwen2-VL-2B-Instruct"
6 | export OUT_NAME="Qwen2-VL-2B-EXP"
7 | export EVAL_DATA="/path/to/your/val/data"
8 | export EVAL_SAVE_NAME="eval_result.json"
9 | export OUTDIR="/path/to/your/out_dir"
10 | export WORKING_PATH="/path/to/your/AlphaDrive"
11 |
12 | cd ${WORKING_PATH}
13 |
14 |
15 | # setup environments
16 | echo "Setup environments..."
17 | # export NCCL_P2P_DISABLE="1"
18 | # export NCCL_IB_DISABLE="1"
19 |
20 | mkdir -p ${OUTDIR}
21 |
22 |
23 | echo "Training Process..."
24 | $ENV_NAME/bin/accelerate launch --config_file src/r1-v/configs/zero2.yaml src/r1-v/src/open_r1/sft.py --config src/r1-v/configs/qwen2vl_sft_config.yaml
25 |
26 |
27 | echo "Validation Process..."
28 | $ENV_NAME/bin/python eval_tools/qwen2vl_plan_cmd_eval_sft.py \
29 | --eval_data_path $EVAL_DATA \
30 | --model_path $OUTDIR/$OUT_NAME \
31 | --save_path $OUTDIR/$OUT_NAME/$EVAL_SAVE_NAME
32 |
--------------------------------------------------------------------------------