├── .gitignore
├── LICENSE
├── README.md
├── install.sh
├── local_scripts
│   ├── fsdp.yaml
│   ├── fsdp_config.json
│   ├── zero2.json
│   ├── zero2_offload.json
│   ├── zero3++.json
│   └── zero3.json
├── open_r1
│   ├── __init__.py
│   ├── evaluate.py
│   ├── generate.py
│   ├── grpo.py
│   └── trainer
│       ├── __init__.py
│       ├── grpo_config.py
│       ├── grpo_trainer.py
│       └── utils
│           ├── __init__.py
│           ├── misc.py
│           ├── prompt_gallery.py
│           ├── vllm_client.py
│           └── vllm_client_v2.py
├── requirements.txt
├── scripts
│   ├── .DS_Store
│   ├── debug
│   │   ├── debug.sh
│   │   ├── debug_gh200.sh
│   │   ├── debug_gh200_local_vllm.sh
│   │   └── test_vllm.sh
│   ├── grpo
│   │   ├── run_vllm.sh
│   │   ├── train_a100.sh
│   │   ├── train_a100_2B.sh
│   │   ├── train_a100_2B_SFT.sh
│   │   ├── train_a100_2B_close.sh
│   │   ├── train_a100_2B_close_SFT.sh
│   │   ├── train_a100_SFT.sh
│   │   ├── train_a100_close.sh
│   │   ├── train_fsdp.sh
│   │   ├── train_fused.sh
│   │   ├── train_gh200.sh
│   │   ├── train_gh200_2B.sh
│   │   ├── train_gh200_2B_SFT.sh
│   │   ├── train_gh200_2B_close.sh
│   │   ├── train_gh200_SFT.sh
│   │   ├── train_gh200_close.sh
│   │   ├── train_gh200_close_SFT.sh
│   │   └── train_zero3.sh
│   ├── inference
│   │   └── run_sgg_inference.sh
│   ├── sft
│   │   ├── 2B_sgg.sh
│   │   ├── 2B_sgg_predefined.sh
│   │   ├── 7B_sgg.sh
│   │   ├── 7B_sgg_lora.sh
│   │   └── 7B_sgg_predefined.sh
│   └── sft_local
│       ├── 2B_sgg.sh
│       ├── 2B_sgg_predefined.sh
│       ├── 7B_sgg.sh
│       └── 7B_sgg_predefined.sh
├── setup.py
├── src
│   ├── __init__.py
│   ├── mega_1m_category.py
│   ├── psg_categories.json
│   ├── sft_sgg.py
│   ├── sgg_gather_preds.py
│   ├── sgg_inference_vllm.py
│   ├── utils
│   │   ├── __init__.py
│   │   ├── bbox_overlaps.py
│   │   ├── bounding_box.py
│   │   ├── cocoeval.py
│   │   ├── misc.py
│   │   ├── sgg_eval.py
│   │   ├── sgg_metrics.py
│   │   ├── wordnet.py
│   │   └── zeroshot_triplet.pytorch
│   ├── vg150_eval.py
│   ├── vg_synonyms.py
│   └── vllm_server_v2.py
└── tests
    ├── test_fsdp.py
    ├── test_rewards.py
    ├── test_sampler.py
    ├── test_vllm.py
    └── test_vllm_local.py
/.gitignore:
--------------------------------------------------------------------------------
1 | *__pycache__*
2 | *pyc
3 | *.out
4 | *.log
5 | *.egg-info*
6 |
7 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # R1-SGG: Compile Scene Graphs with Reinforcement Learning
2 |
3 | ## **Structured Visual Reasoning with Multimodal LLMs and Reinforcement Learning**
4 | [Paper (arXiv:2504.13617)](https://arxiv.org/abs/2504.13617) [License](LICENSE) [Hugging Face Space](https://huggingface.co/spaces/JosephZ/R1-SGG)
5 | ---
6 |
7 | ## 🚀 Update
8 | - ✅ [R1-SGG-7B](https://huggingface.co/JosephZ/R1-SGG-7B), [R1-SGG-Zero-7B](https://huggingface.co/JosephZ/R1-SGG-Zero-7B)
9 | - ✅ Support [PSG](https://github.com/Jingkang50/OpenPSG) dataset (bbox format only, not Panoptic)
10 | - ✅ Updated loss implementation
11 | - ✅ Always use `custom_per_device_train_batch_size` instead of `per_device_train_batch_size` for faster sampling under gradient accumulation
12 | - ⚠️ Current loss implementation might still be affected by gradient accumulation: [trl issue #3021](https://github.com/huggingface/trl/issues/3021)
13 |
14 | ---
15 |
16 | ## 🛠️ Setup Environment
17 | ```bash
18 | bash install.sh
19 | ```
20 | Main dependencies:
21 | ```bash
22 | - torch == 2.5.0 or 2.5.1 (cu124, optional)
23 | - transformers (supports Qwen2VL, Qwen2.5VL)
24 | - trl
25 | - vLLM
26 | ```
27 |
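A quick post-install sanity check (optional; a minimal sketch, and the printed versions will depend on your CUDA build):
```python
# Optional sanity check: confirm the core dependencies are importable.
import torch
import transformers
import trl
import vllm

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("transformers:", transformers.__version__)
print("trl:", trl.__version__)
print("vllm:", vllm.__version__)
```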
28 | ---
29 |
30 | ## 📚 Dataset
31 | Load preprocessed datasets via:
32 | ```python
33 | from datasets import load_dataset
34 |
35 | db_train = load_dataset("JosephZ/vg150_train_sgg_prompt")["train"]
36 | db_val = load_dataset("JosephZ/vg150_val_sgg_prompt")["train"]
37 | ```
38 | or for PSG:
39 | ```python
40 | db_train = load_dataset("JosephZ/psg_train_sg")["train"] # keys: image_id, image, objects, relationships
41 | db_val = load_dataset("JosephZ/psg_test_sg")["train"]
42 | ```
43 | We transformed VG150 into HuggingFace Datasets format with keys:
44 | - `image_id`
45 | - `image`
46 | - `prompt_open`
47 | - `prompt_close`
48 | - `objects`
49 | - `relationships`
50 |
51 | ---
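A minimal sketch for inspecting a single record (it assumes only the keys listed above; the exact encoding of `objects` and `relationships` is whatever is stored in the Hub dataset):
```python
from datasets import load_dataset

db_train = load_dataset("JosephZ/vg150_train_sgg_prompt")["train"]

sample = db_train[0]
print(sample["image_id"])        # image identifier
print(sample["image"])           # image (decoded to a PIL image by `datasets`)
print(sample["prompt_open"])     # prompt without predefined categories
print(sample["prompt_close"])    # prompt with predefined object/predicate categories
print(sample["objects"])         # ground-truth objects
print(sample["relationships"])   # ground-truth relationships
```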
52 |
53 | ## 🔥 Supported Models
54 | - [x] Qwen/Qwen2-VL-2B-Instruct
55 | - [x] Qwen/Qwen2-VL-7B-Instruct
56 | - [x] Qwen/Qwen2.5-VL-3B-Instruct
57 | - [x] Qwen/Qwen2.5-VL-7B-Instruct
58 |
59 | ---
60 |
61 | ## 🏋️♂️ Training
62 |
63 | ### Training with Supervised Fine-Tuning (SFT)
64 |
65 | For **SLURM users**:
66 | ```bash
67 | sbatch scripts/sft/7B_sgg.sh
68 | ```
69 |
70 | For **local machines**:
71 | ```bash
72 | bash scripts/sft_local/7B_sgg.sh
73 | ```
74 | ⏱️ Approximate training time:
75 | - 2B models: ~4 hours (4×A100 SXM4 GPUs)
76 | - 7B models: ~10 hours (4×A100 SXM4 GPUs)
77 |
78 | ---
79 |
80 | ### Training with Reinforcement Learning (GRPO)
81 | **Update (11/05/2025)**: to use "Hard Recall", pass:
82 | ```
83 | --reward_funcs format_reward edge_hard_reward
84 | ```
85 |
86 | For **A100 GPUs**:
87 | ```bash
88 | sbatch scripts/grpo/train_a100_2B.sh
89 | ```
90 | (12 hours on 16×A100 GPUs)
91 |
92 | For **GH200 GPUs**:
93 | ```bash
94 | sbatch scripts/grpo/train_gh200.sh
95 | ```
96 | (16 hours on 16×GH200 GPUs)
97 |
98 | For clusters with many RTX_3090/4090 GPUs:
99 | ```bash
100 | sbatch scripts/grpo/train_fused.sh
101 | ```
102 | - Training 7B models on 24GB cards is possible with Zero3, but slow due to communication bottlenecks.
103 | - (Fun fact: training with 120×RTX_4090 is crazy but severely limited by communication latency.)
104 |
105 | 💡 **Recommended learning rate**: `6e-7`.
106 |
107 | ---
108 |
109 | ## 🧪 Inference and Evaluation
110 |
111 | ### Inference with SFT-trained models:
112 | ```bash
113 | bash scripts/inference/run_sgg_inference.sh $DATASET $MODEL_NAME $OUTPUT_DIR
114 | ```
115 | For models trained **with predefined categories**, pass `true` as the fourth argument:
116 | ```bash
117 | bash scripts/inference/run_sgg_inference.sh $DATASET $MODEL_NAME $OUTPUT_DIR true
118 | ```
119 |
120 | ### Inference with GRPO-trained models:
121 | ```bash
122 | bash scripts/inference/run_sgg_inference.sh $DATASET $MODEL_NAME $OUTPUT_DIR <true|false> true  # 4th arg: predefined categories; 5th arg: true for GRPO-trained models
123 | ```
124 |
125 | ### Evaluation:
126 | ```bash
127 | DATASET_TYPE=vg # or psg
128 | python src/sgg_gather_preds.py $DATASET_TYPE $OUTPUT_DIR sgg_pred_results.json
129 | python src/vg150_eval.py $DATASET sgg_pred_results.json
130 | ```
131 |
132 | ---
133 |
134 | ## 🤝 Acknowledgement
135 | The `GRPOTrainer` used in this project is based on [trl's GRPOTrainer](https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py), extended to support multimodal inputs.
136 |
137 | ---
138 |
139 | ## 📖 Citation
140 | If you find this work helpful, please cite:
141 | ```bibtex
142 | @article{chen2025compile,
143 | title={Compile Scene Graphs with Reinforcement Learning},
144 | author={Chen, Zuyao and Wu, Jinlin and Lei, Zhen and Pollefeys, Marc and Chen, Chang Wen},
145 | journal={arXiv preprint arXiv:2504.13617},
146 | year={2025}
147 | }
148 | ```
149 |
150 | ---
151 |
152 | # ✨ Happy Compiling!
153 |
154 |
--------------------------------------------------------------------------------
/install.sh:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124
5 |
6 | #pip install transformers@git+https://github.com/huggingface/transformers.git@2c2495cc7b0e3e2942a9310f61548f40a2bc8425
7 | pip install transformers==4.50.3
8 |
9 | pip install trl@git+https://github.com/huggingface/trl.git@ece6738686a8527345532e6fed8b3b1b75f16b16
10 |
11 | pip install --upgrade --no-build-isolation flash-attn==2.7.4.post1
12 |
13 | # for GH200,
14 | #MAX_JOBS=20 pip install --upgrade --no-build-isolation flash-attn==2.7.4.post1
15 |
16 | #git clone https://github.com/triton-lang/triton.git && cd triton && git checkout 85267600 && \
17 | #pip install -r python/requirements.txt  # build-time dependencies
18 | #pip install -e python
19 |
20 | pip install -r requirements.txt
21 |
22 | # for GH200,
23 | #pip uninstall -y vllm && git clone https://github.com/vllm-project/vllm.git && cd vllm && git checkout ed6e9075d31e32c8548b480 && \
24 | #python use_existing_torch.py && pip install -r requirements/build.txt && pip install --no-build-isolation -e .
25 |
26 | pip install -e .
27 |
--------------------------------------------------------------------------------
/local_scripts/fsdp.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | distributed_type: FSDP
4 | downcast_bf16: 'no'
5 | enable_cpu_affinity: false
6 | fsdp_config:
7 | fsdp_activation_checkpointing: true
8 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
9 | fsdp_backward_prefetch: BACKWARD_PRE
10 | fsdp_cpu_ram_efficient_loading: true
11 | fsdp_forward_prefetch: false
12 | fsdp_offload_params: false
13 | fsdp_sharding_strategy: HYBRID_SHARD
14 | fsdp_state_dict_type: SHARDED_STATE_DICT
15 | fsdp_sync_module_states: true
16 | fsdp_transformer_layer_cls_to_wrap: null
17 | fsdp_use_orig_params: true
18 |
19 | main_training_function: main
20 | mixed_precision: bf16
21 | rdzv_backend: static
22 | same_network: true
23 | tpu_env: []
24 | tpu_use_cluster: false
25 | tpu_use_sudo: false
26 | use_cpu: false
27 |
--------------------------------------------------------------------------------
/local_scripts/fsdp_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "backward_prefetch": "backward_pre",
3 | "forward_prefetch": "true",
4 | "activation_checkpointing": "true"
5 | }
6 |
--------------------------------------------------------------------------------
/local_scripts/zero2.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "optimizer": {
14 | "type": "AdamW",
15 | "params": {
16 | "lr": "auto",
17 | "betas": "auto",
18 | "eps": "auto",
19 | "weight_decay": "auto"
20 | }
21 | },
22 | "zero_optimization": {
23 | "stage": 2,
24 | "offload_optimizer": {
25 | "device": "none",
26 | "pin_memory": true
27 | },
28 | "allgather_partitions": true,
29 | "allgather_bucket_size": 2e8,
30 | "overlap_comm": false,
31 | "reduce_scatter": true,
32 | "reduce_bucket_size": 2e8,
33 | "contiguous_gradients": true
34 | },
35 | "gradient_accumulation_steps": "auto",
36 | "gradient_clipping": "auto",
37 | "steps_per_print": 100,
38 | "train_batch_size": "auto",
39 | "train_micro_batch_size_per_gpu": "auto",
40 | "wall_clock_breakdown": false
41 | }
42 |
--------------------------------------------------------------------------------
/local_scripts/zero2_offload.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "optimizer": {
14 | "type": "AdamW",
15 | "params": {
16 | "lr": "auto",
17 | "betas": "auto",
18 | "eps": "auto",
19 | "weight_decay": "auto"
20 | }
21 | },
22 | "zero_optimization": {
23 | "stage": 2,
24 | "offload_optimizer": {
25 | "device": "cpu",
26 | "pin_memory": true
27 | },
28 | "allgather_partitions": true,
29 | "allgather_bucket_size": 2e8,
30 | "overlap_comm": false,
31 | "reduce_scatter": true,
32 | "reduce_bucket_size": 2e8,
33 | "contiguous_gradients": true
34 | },
35 | "gradient_accumulation_steps": "auto",
36 | "gradient_clipping": "auto",
37 | "steps_per_print": 100,
38 | "train_batch_size": "auto",
39 | "train_micro_batch_size_per_gpu": "auto",
40 | "wall_clock_breakdown": false
41 | }
42 |
--------------------------------------------------------------------------------
/local_scripts/zero3++.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 | "zero_optimization": {
14 | "stage": 3,
15 | "offload_optimizer": {
16 | "device": "cpu",
17 | "pin_memory": true
18 | },
19 | "offload_param": {
20 | "device": "cpu",
21 | "pin_memory": true
22 | },
23 | "zero_hpz_partition_size": 8,
24 | "zero_quantized_weights": false,
25 | "zero_quantized_gradients": false,
26 | "overlap_comm": true,
27 | "contiguous_gradients": true,
28 | "sub_group_size": 1e9,
29 | "reduce_bucket_size": "auto",
30 | "stage3_prefetch_bucket_size": "auto",
31 | "stage3_param_persistence_threshold": "auto",
32 | "stage3_max_live_parameters": 1e9,
33 | "stage3_max_reuse_distance": 1e9,
34 | "stage3_gather_16bit_weights_on_model_save": true
35 | },
36 | "train_batch_size": "auto",
37 | "train_micro_batch_size_per_gpu": "auto",
38 | "gradient_accumulation_steps": "auto",
39 | "gradient_clipping": "auto",
40 |     "zero_allow_untested_optimizer": true
41 | }
42 |
--------------------------------------------------------------------------------
/local_scripts/zero3.json:
--------------------------------------------------------------------------------
1 | {
2 | "fp16": {
3 | "enabled": "auto",
4 | "loss_scale": 0,
5 | "loss_scale_window": 1000,
6 | "initial_scale_power": 16,
7 | "hysteresis": 2,
8 | "min_loss_scale": 1
9 | },
10 | "bf16": {
11 | "enabled": "auto"
12 | },
13 |
14 | "zero_optimization": {
15 | "stage": 3,
16 | "offload_optimizer": {
17 | "device": "none",
18 | "pin_memory": true
19 | },
20 | "offload_param": {
21 | "device": "none",
22 | "pin_memory": true
23 | },
24 | "overlap_comm": true,
25 | "contiguous_gradients": true,
26 | "sub_group_size": 1e9,
27 | "reduce_bucket_size": "auto",
28 | "stage3_prefetch_bucket_size": "auto",
29 | "stage3_param_persistence_threshold": "auto",
30 | "stage3_max_live_parameters": 1e9,
31 | "stage3_max_reuse_distance": 1e9,
32 | "stage3_gather_16bit_weights_on_model_save": true
33 | },
34 |
35 | "gradient_accumulation_steps": "auto",
36 | "gradient_clipping": "auto",
37 | "steps_per_print": 100,
38 | "train_batch_size": "auto",
39 | "train_micro_batch_size_per_gpu": "auto",
40 | "wall_clock_breakdown": false
41 | }
--------------------------------------------------------------------------------
/open_r1/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gpt4vision/R1-SGG/e4de64d4c4c97edec648021d012198b21a9b1864/open_r1/__init__.py
--------------------------------------------------------------------------------
/open_r1/evaluate.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Custom evaluation tasks for LightEval."""
16 |
17 | from lighteval.metrics.dynamic_metrics import (
18 | ExprExtractionConfig,
19 | LatexExtractionConfig,
20 | multilingual_extractive_match_metric,
21 | )
22 | from lighteval.tasks.lighteval_task import LightevalTaskConfig
23 | from lighteval.tasks.requests import Doc
24 | from lighteval.utils.language import Language
25 |
26 |
27 | metric = multilingual_extractive_match_metric(
28 | language=Language.ENGLISH,
29 | fallback_mode="first_match",
30 | precision=5,
31 | gold_extraction_target=(LatexExtractionConfig(),),
32 | pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()),
33 | aggregation_function=max,
34 | )
35 |
36 |
37 | def prompt_fn(line, task_name: str = None):
38 | """Assumes the model is either prompted to emit \\boxed{answer} or does so automatically"""
39 | return Doc(
40 | task_name=task_name,
41 | query=line["problem"],
42 | choices=[line["solution"]],
43 | gold_index=0,
44 | )
45 |
46 |
47 | # Define tasks
48 | aime24 = LightevalTaskConfig(
49 | name="aime24",
50 | suite=["custom"],
51 | prompt_function=prompt_fn,
52 | hf_repo="HuggingFaceH4/aime_2024",
53 | hf_subset="default",
54 | hf_avail_splits=["train"],
55 | evaluation_splits=["train"],
56 | few_shots_split=None,
57 | few_shots_select=None,
58 | generation_size=32768,
59 | metric=[metric],
60 | version=1,
61 | )
62 | math_500 = LightevalTaskConfig(
63 | name="math_500",
64 | suite=["custom"],
65 | prompt_function=prompt_fn,
66 | hf_repo="HuggingFaceH4/MATH-500",
67 | hf_subset="default",
68 | hf_avail_splits=["test"],
69 | evaluation_splits=["test"],
70 | few_shots_split=None,
71 | few_shots_select=None,
72 | generation_size=32768,
73 | metric=[metric],
74 | version=1,
75 | )
76 |
77 | # Add tasks to the table
78 | TASKS_TABLE = []
79 | TASKS_TABLE.append(aime24)
80 | TASKS_TABLE.append(math_500)
81 |
82 | # MODULE LOGIC
83 | if __name__ == "__main__":
84 | print([t["name"] for t in TASKS_TABLE])
85 | print(len(TASKS_TABLE))
86 |
--------------------------------------------------------------------------------
/open_r1/generate.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Optional
16 |
17 | from distilabel.llms import OpenAILLM
18 | from distilabel.pipeline import Pipeline
19 | from distilabel.steps.tasks import TextGeneration
20 |
21 |
22 | def build_distilabel_pipeline(
23 | model: str,
24 | base_url: str = "http://localhost:8000/v1",
25 | prompt_column: Optional[str] = None,
26 | temperature: Optional[float] = None,
27 | top_p: Optional[float] = None,
28 | max_new_tokens: int = 8192,
29 | num_generations: int = 1,
30 | ) -> Pipeline:
31 | generation_kwargs = {"max_new_tokens": max_new_tokens}
32 |
33 | if temperature is not None:
34 | generation_kwargs["temperature"] = temperature
35 |
36 | if top_p is not None:
37 | generation_kwargs["top_p"] = top_p
38 |
39 | with Pipeline().ray() as pipeline:
40 | TextGeneration(
41 | llm=OpenAILLM(
42 | base_url=base_url,
43 | api_key="something",
44 | model=model,
45 | # thinking can take some time...
46 | timeout=10 * 60,
47 | generation_kwargs=generation_kwargs,
48 | ),
49 | input_mappings={"instruction": prompt_column} if prompt_column is not None else {},
50 | input_batch_size=64, # on 4 nodes bs ~60+ leads to preemption due to KV cache exhaustion
51 | num_generations=num_generations,
52 | )
53 |
54 | return pipeline
55 |
56 |
57 | if __name__ == "__main__":
58 | import argparse
59 |
60 | from datasets import load_dataset
61 |
62 | parser = argparse.ArgumentParser(description="Run distilabel pipeline for generating responses with DeepSeek R1")
63 | parser.add_argument(
64 | "--hf-dataset",
65 | type=str,
66 | required=True,
67 | help="HuggingFace dataset to load",
68 | )
69 | parser.add_argument(
70 | "--hf-dataset-config",
71 | type=str,
72 | required=False,
73 | help="Dataset config to use",
74 | )
75 | parser.add_argument(
76 | "--hf-dataset-split",
77 | type=str,
78 | default="train",
79 | help="Dataset split to use",
80 | )
81 | parser.add_argument("--prompt-column", type=str, default="prompt")
82 | parser.add_argument(
83 | "--model",
84 | type=str,
85 | required=True,
86 | help="Model name to use for generation",
87 | )
88 | parser.add_argument(
89 | "--vllm-server-url",
90 | type=str,
91 | default="http://localhost:8000/v1",
92 | help="URL of the vLLM server",
93 | )
94 | parser.add_argument(
95 | "--temperature",
96 | type=float,
97 | help="Temperature for generation",
98 | )
99 | parser.add_argument(
100 | "--top-p",
101 | type=float,
102 | help="Top-p value for generation",
103 | )
104 | parser.add_argument(
105 | "--max-new-tokens",
106 | type=int,
107 | default=8192,
108 | help="Maximum number of new tokens to generate",
109 | )
110 | parser.add_argument(
111 | "--num-generations",
112 | type=int,
113 | default=1,
114 | help="Number of generations per problem",
115 | )
116 | parser.add_argument(
117 | "--hf-output-dataset",
118 | type=str,
119 | required=False,
120 | help="HuggingFace repo to push results to",
121 | )
122 | parser.add_argument(
123 | "--private",
124 | action="store_true",
125 | help="Whether to make the output dataset private when pushing to HF Hub",
126 | )
127 |
128 | args = parser.parse_args()
129 |
130 | print("\nRunning with arguments:")
131 | for arg, value in vars(args).items():
132 | print(f" {arg}: {value}")
133 | print()
134 |
135 | print(f"Loading '{args.hf_dataset}' (config: {args.hf_dataset_config}, split: {args.hf_dataset_split}) dataset...")
136 |     dataset = load_dataset(args.hf_dataset, args.hf_dataset_config, split=args.hf_dataset_split)
137 | print("Dataset loaded!")
138 |
139 | pipeline = build_distilabel_pipeline(
140 | model=args.model,
141 | base_url=args.vllm_server_url,
142 | prompt_column=args.prompt_column,
143 | temperature=args.temperature,
144 | top_p=args.top_p,
145 | max_new_tokens=args.max_new_tokens,
146 | num_generations=args.num_generations,
147 | )
148 |
149 | print("Running generation pipeline...")
150 | distiset = pipeline.run(dataset=dataset, use_cache=False)
151 | print("Generation pipeline finished!")
152 |
153 | if args.hf_output_dataset:
154 | print(f"Pushing resulting dataset to '{args.hf_output_dataset}'...")
155 | distiset.push_to_hub(args.hf_output_dataset, private=args.private)
156 | print("Dataset pushed!")
157 |
--------------------------------------------------------------------------------
/open_r1/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | from .grpo_config import GRPOConfig
2 | from .grpo_trainer import GRPOTrainerV2
3 |
4 |
5 | __all__ = ["GRPOTrainerV2", "GRPOConfig"]
6 |
--------------------------------------------------------------------------------
/open_r1/trainer/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gpt4vision/R1-SGG/e4de64d4c4c97edec648021d012198b21a9b1864/open_r1/trainer/utils/__init__.py
--------------------------------------------------------------------------------
/open_r1/trainer/utils/misc.py:
--------------------------------------------------------------------------------
1 | import base64
2 | from io import BytesIO
3 | from PIL import Image
4 |
5 | from transformers.utils.import_utils import _is_package_available
6 |
7 |
8 | _fastapi_available = _is_package_available("fastapi")
9 | _pydantic_available = _is_package_available("pydantic")
10 | _uvicorn_available = _is_package_available("uvicorn")
11 | _vllm_available = _is_package_available("vllm")
12 | _requests_available = _is_package_available("requests")
13 |
14 | def is_fastapi_available() -> bool:
15 | return _fastapi_available
16 |
17 |
18 | def is_pydantic_available() -> bool:
19 | return _pydantic_available
20 |
21 | def is_uvicorn_available() -> bool:
22 | return _uvicorn_available
23 |
24 |
25 | def is_vllm_available() -> bool:
26 | return _vllm_available
27 |
28 | def is_requests_available() -> bool:
29 | return _requests_available
30 |
31 |
32 | def is_pil_image(image) -> bool:
33 | return isinstance(image, Image.Image)
34 |
35 |
36 | def encode_image_to_base64(image: Image.Image, format: str = "PNG") -> str:
37 | """
38 | Encode a PIL Image to a base64 string.
39 |
40 | Args:
41 | image (PIL.Image): The image to encode.
42 | format (str): Image format to use (e.g., "PNG", "JPEG"). Default is "PNG".
43 |
44 | Returns:
45 | str: Base64-encoded string of the image.
46 | """
47 | buffer = BytesIO()
48 | image.save(buffer, format=format)
49 | buffer.seek(0)
50 | encoded_string = base64.b64encode(buffer.read()).decode("utf-8")
51 | return encoded_string
52 |
53 | def decode_base64_to_image(base64_str: str) -> Image.Image:
54 | """
55 | Decode a base64 string back to a PIL Image.
56 |
57 | Args:
58 | base64_str (str): Base64-encoded string of the image.
59 |
60 | Returns:
61 | PIL.Image: Decoded image.
62 | """
63 | image_data = base64.b64decode(base64_str)
64 | buffer = BytesIO(image_data)
65 | image = Image.open(buffer)
66 | return image
67 |
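68 | 
69 | if __name__ == "__main__":
70 |     # Minimal round-trip usage sketch for the image helpers above (illustrative
71 |     # example only): encode a small solid-color image to base64, decode it back,
72 |     # and check that the image size is preserved.
73 |     img = Image.new("RGB", (8, 8), color=(255, 0, 0))
74 |     assert is_pil_image(img)
75 |     b64 = encode_image_to_base64(img, format="PNG")
76 |     restored = decode_base64_to_image(b64)
77 |     assert restored.size == img.size
78 |     print(f"round-trip ok, base64 length: {len(b64)}")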
--------------------------------------------------------------------------------
/open_r1/trainer/utils/prompt_gallery.py:
--------------------------------------------------------------------------------
1 | import json
2 |
3 | VG150_OBJ_CATEGORIES = ['__background__', 'airplane', 'animal', 'arm', 'bag', 'banana', 'basket', 'beach', 'bear', 'bed', 'bench', 'bike', 'bird', 'board', 'boat', 'book', 'boot', 'bottle', 'bowl', 'box', 'boy', 'branch', 'building', 'bus', 'cabinet', 'cap', 'car', 'cat', 'chair', 'child', 'clock', 'coat', 'counter', 'cow', 'cup', 'curtain', 'desk', 'dog', 'door', 'drawer', 'ear', 'elephant', 'engine', 'eye', 'face', 'fence', 'finger', 'flag', 'flower', 'food', 'fork', 'fruit', 'giraffe', 'girl', 'glass', 'glove', 'guy', 'hair', 'hand', 'handle', 'hat', 'head', 'helmet', 'hill', 'horse', 'house', 'jacket', 'jean', 'kid', 'kite', 'lady', 'lamp', 'laptop', 'leaf', 'leg', 'letter', 'light', 'logo', 'man', 'men', 'motorcycle', 'mountain', 'mouth', 'neck', 'nose', 'number', 'orange', 'pant', 'paper', 'paw', 'people', 'person', 'phone', 'pillow', 'pizza', 'plane', 'plant', 'plate', 'player', 'pole', 'post', 'pot', 'racket', 'railing', 'rock', 'roof', 'room', 'screen', 'seat', 'sheep', 'shelf', 'shirt', 'shoe', 'short', 'sidewalk', 'sign', 'sink', 'skateboard', 'ski', 'skier', 'sneaker', 'snow', 'sock', 'stand', 'street', 'surfboard', 'table', 'tail', 'tie', 'tile', 'tire', 'toilet', 'towel', 'tower', 'track', 'train', 'tree', 'truck', 'trunk', 'umbrella', 'vase', 'vegetable', 'vehicle', 'wave', 'wheel', 'window', 'windshield', 'wing', 'wire', 'woman', 'zebra']
4 |
5 |
6 | VG150_PREDICATES = ['__background__', "above", "across", "against", "along", "and", "at", "attached to", "behind", "belonging to", "between", "carrying", "covered in", "covering", "eating", "flying in", "for", "from", "growing on", "hanging from", "has", "holding", "in", "in front of", "laying on", "looking at", "lying on", "made of", "mounted on", "near", "of", "on", "on back of", "over", "painted on", "parked on", "part of", "playing", "riding", "says", "sitting on", "standing on", "to", "under", "using", "walking in", "walking on", "watching", "wearing", "wears", "with"]
7 |
8 |
9 | VG150_BASE_OBJ_CATEGORIES = set(['tile', 'drawer', 'men', 'railing', 'stand', 'towel', 'sneaker', 'vegetable', 'screen', 'vehicle', 'animal', 'kite', 'cabinet', 'sink', 'wire', 'fruit', 'curtain', 'lamp', 'flag', 'pot', 'sock', 'boot', 'guy', 'kid', 'finger', 'basket', 'wave', 'lady', 'orange', 'number', 'toilet', 'post', 'room', 'paper', 'mountain', 'paw', 'banana', 'rock', 'cup', 'hill', 'house', 'airplane', 'plant', 'skier', 'fork', 'box', 'seat', 'engine', 'mouth', 'letter', 'windshield', 'desk', 'board', 'counter', 'branch', 'coat', 'logo', 'book', 'roof', 'tie', 'tower', 'glove', 'sheep', 'neck', 'shelf', 'bottle', 'cap', 'vase', 'racket', 'ski', 'phone', 'handle', 'boat', 'tire', 'flower', 'child', 'bowl', 'pillow', 'player', 'trunk', 'bag', 'wing', 'light', 'laptop', 'pizza', 'cow', 'truck', 'jean', 'eye', 'arm', 'leaf', 'bird', 'surfboard', 'umbrella', 'food', 'people', 'nose', 'beach', 'sidewalk', 'helmet', 'face', 'skateboard', 'motorcycle', 'clock', 'bear'])
10 |
11 | VG150_BASE_PREDICATE = set(["between", "to", "made of", "looking at", "along", "laying on", "using", "carrying", "against", "mounted on", "sitting on", "flying in", "covering", "from", "over", "near", "hanging from", "across", "at", "above", "watching", "covered in", "wearing", "holding", "and", "standing on", "lying on", "growing on", "under", "on back of", "with", "has", "in front of", "behind", "parked on"])
12 |
13 |
14 | PROMPT_SG='Generate a structured scene graph for an image using the following format:\n\n```json\n{\n "objects": [\n {"id": "object_name.number", "bbox": [x1, y1, x2, y2]},\n ...\n ],\n "relationships": [\n {"subject": "object_name.number", "predicate": "relationship_type", "object": "object_name.number"},\n ...\n ]\n}\n```\n\n### **Guidelines:**\n- **Objects:**\n - Assign a unique ID for each object using the format `"object_name.number"` (e.g., `"person.1"`, `"bike.2"`).\n - Provide its bounding box `[x1, y1, x2, y2]` in integer pixel format.\n - Include all visible objects, even if they have no relationships.\n\n- **Relationships:**\n - Represent interactions accurately using `"subject"`, `"predicate"`, and `"object"`.\n - Omit relationships for orphan objects.\n\n### **Example Output:**\n```json\n{\n "objects": [\n {"id": "person.1", "bbox": [120, 200, 350, 700]},\n {"id": "bike.2", "bbox": [100, 600, 400, 800]},\n {"id": "helmet.3", "bbox": [150, 150, 280, 240]},\n {"id": "tree.4", "bbox": [500, 100, 750, 700]}\n ],\n "relationships": [\n {"subject": "person.1", "predicate": "riding", "object": "bike.2"},\n {"subject": "person.1", "predicate": "wearing", "object": "helmet.3"}\n ]\n}\n```\n\nNow, generate the complete scene graph for the provided image:\n'
15 |
16 | PROMPT_CLOSE_TEMPLATE='Generate a structured scene graph for an image using the specified object and relationship categories.\n\n### **Output Format:**\n```json\n{\n "objects": [\n {"id": "object_name.number", "bbox": [x1, y1, x2, y2]},\n ...\n ],\n "relationships": [\n {"subject": "object_name.number", "predicate": "relationship_type", "object": "object_name.number"},\n ...\n ]\n}\n```\n\n### **Guidelines:**\n- **Objects:**\n - Assign unique IDs in the format `"object_name.number"` (e.g., `"person.1"`). The **object_name** must belong to the predefined object set: `{OBJ_CLS}`.\n - Provide a bounding box `[x1, y1, x2, y2]` in integer pixel format.\n - Include all visible objects, even if they have no relationships.\n\n- **Relationships:**\n - Define relationships using `"subject"`, `"predicate"`, and `"object"`.\n - The **predicate** must belong to the predefined relationship set: `{REL_CLS}`.\n - Omit relationships for orphan objects.\n\n### **Example Output:**\n```json\n{\n "objects": [\n {"id": "person.1", "bbox": [120, 200, 350, 700]},\n {"id": "bike.2", "bbox": [100, 600, 400, 800]},\n {"id": "helmet.3", "bbox": [150, 150, 280, 240]},\n {"id": "tree.4", "bbox": [500, 100, 750, 700]}\n ],\n "relationships": [\n {"subject": "person.1", "predicate": "riding", "object": "bike.2"},\n {"subject": "person.1", "predicate": "wearing", "object": "helmet.3"}\n ]\n}\n```\n\nNow, generate the complete scene graph for the provided image:\n'
17 |
18 |
19 |
20 | psg_categories = {"thing_classes": ["person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"], "stuff_classes": ["banner", "blanket", "bridge", "cardboard", "counter", "curtain", "door-stuff", "floor-wood", "flower", "fruit", "gravel", "house", "light", "mirror-stuff", "net", "pillow", "platform", "playingfield", "railroad", "river", "road", "roof", "sand", "sea", "shelf", "snow", "stairs", "tent", "towel", "wall-brick", "wall-stone", "wall-tile", "wall-wood", "water-other", "window-blind", "window-other", "tree-merged", "fence-merged", "ceiling-merged", "sky-other-merged", "cabinet-merged", "table-merged", "floor-other-merged", "pavement-merged", "mountain-merged", "grass-merged", "dirt-merged", "paper-merged", "food-other-merged", "building-other-merged", "rock-merged", "wall-other-merged", "rug-merged"], "predicate_classes": ["over", "in front of", "beside", "on", "in", "attached to", "hanging from", "on back of", "falling off", "going down", "painted on", "walking on", "running on", "crossing", "standing on", "lying on", "sitting on", "flying over", "jumping over", "jumping from", "wearing", "holding", "carrying", "looking at", "guiding", "kissing", "eating", "drinking", "feeding", "biting", "catching", "picking", "playing with", "chasing", "climbing", "cleaning", "playing", "touching", "pushing", "pulling", "opening", "cooking", "talking to", "throwing", "slicing", "driving", "riding", "parked on", "driving on", "about to hit", "kicking", "swinging", "entering", "exiting", "enclosing", "leaning on"]}
21 |
22 | PSG_OBJ_CATEGORIES = psg_categories['thing_classes'] + psg_categories['stuff_classes']
23 | PSG_REL_CATEGORIES = psg_categories['predicate_classes']
24 |
25 |
26 | def format_prompt_close_sg(obj_cls, rel_cls):
27 | return PROMPT_CLOSE_TEMPLATE.replace("{OBJ_CLS}", json.dumps(obj_cls)).replace("{REL_CLS}", json.dumps(rel_cls))
28 |
29 |
30 | PROMPT_CLOSE_PSG = format_prompt_close_sg(PSG_OBJ_CATEGORIES, PSG_REL_CATEGORIES)
31 | PROMPT_CLOSE_VG150 = format_prompt_close_sg(VG150_OBJ_CATEGORIES[1:], VG150_PREDICATES[1:])
32 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | aiohttp
2 | accelerate==1.4.0
3 | datasets==3.3.2
4 | deepspeed==0.15.4
5 | numpy==1.26.4
6 | pillow==10.4.0
7 | qwen_vl_utils
8 | tqdm
9 | typing_extensions>=4.12.2
10 | pydantic>=2.10.6
11 | pydantic_core>=2.27.2
12 | thinc>=8.2.2
13 | packaging==24.2
14 | PyYAML==6.0.2
15 | spacy
16 | scipy
17 | safetensors==0.5.3
18 | fastapi
19 | openai==1.65.5
20 | vllm==0.7.3
21 | wandb==0.18.3
22 | pycocotools
23 | tabulate
24 | nltk
25 | matplotlib
26 |
--------------------------------------------------------------------------------
/scripts/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gpt4vision/R1-SGG/e4de64d4c4c97edec648021d012198b21a9b1864/scripts/.DS_Store
--------------------------------------------------------------------------------
/scripts/debug/debug.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 |
5 | export DATA_PATH="JosephZ/vg150_train_sgg_prompt"
6 |
7 | export CUDA_VISIBLE_DEVICES=0
8 |
9 | accelerate launch --num_processes=1 open_r1/grpo.py \
10 | --output_dir models/qwen2vl-sgg-g8 \
11 | --model_name_or_path "Qwen/Qwen2-VL-2B-Instruct" \
12 | --dataset_name $DATA_PATH \
13 | --max_prompt_length 2048 \
14 | --max_completion_length 1024 \
15 | --per_device_train_batch_size 2 \
16 | --gradient_accumulation_steps 1 \
17 | --logging_steps 1 \
18 | --use_vllm true \
19 | --use_local_vllm true\
20 | --use_liger_loss true\
21 | --vllm_gpu_memory_utilization 0.25\
22 | --bf16 \
23 | --report_to wandb \
24 | --gradient_checkpointing true \
25 | --max_pixels 401408 \
26 | --temperature 0.7 \
27 | --top_p 0.01 \
28 | --top_k 1 \
29 | --num_train_epochs 2 \
30 | --run_name Qwen2-VL-2B-GRPO-SGG-debug \
31 | --save_steps 100 \
32 | --num_generations 2
33 |
--------------------------------------------------------------------------------
/scripts/debug/debug_gh200.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export HF_HOME=$SCRATCH/huggingface
4 | # ---------- Environment Setup ----------
5 | export NCCL_ASYNC_ERROR_HANDLING=1
6 | export DEBUG_MODE=True
7 | export WANDB_PROJECT=RL4SGG
8 |
9 |
10 | GROUP_SIZE=8
11 | MODEL_PATH="Qwen/Qwen2-VL-7B-Instruct"
12 | DATA_PATH="JosephZ/vg150_train_sgg_prompt"
13 | RUN_NAME="qwen2vl-7b-grpo-g${GROUP_SIZE}-n1-gh200"
14 | export OUTPUT_DIR="${SCRATCH}/models/${RUN_NAME}"
15 | mkdir -p "$OUTPUT_DIR"
16 |
17 | TP_SIZE=1
18 | PORT_BASE=8000
19 | MAX_PIXELS=$((512 * 28 * 28))
20 |
21 |
22 | MIXED_NODES=1 # Set this dynamically if needed
23 |
24 |
25 | HEAD_NODE_IP=0.0.0.0
26 | MASTER_PORT=29500
27 |
28 | SERVER_IP=$(hostname -I | awk '{print $1}')
29 | SERVER_PORT='8000'
30 |
31 |
32 | # zero2:
33 | # bsz_per_device=16: OOM; OK with CPU offload for the optimizer, ~60h with 3x GPUs
34 | # bsz_per_device=8: 386s for 30 steps, ~60h with 3x GPUs
35 | TRAIN_CMD="open_r1/grpo.py \
36 | --output_dir ${OUTPUT_DIR} \
37 | --model_name_or_path ${MODEL_PATH} \
38 | --dataset_name ${DATA_PATH} \
39 | --max_prompt_length 2048 \
40 | --max_completion_length 1024 \
41 | --per_device_train_batch_size 16 \
42 | --deepspeed ./local_scripts/zero2.json \
43 | --gradient_accumulation_steps 1 \
44 | --logging_steps 1 \
45 | --use_vllm true \
46 | --vllm_server_host ${SERVER_IP} \
47 | --vllm_server_port ${SERVER_PORT} \
48 | --vllm_server_timeout 600 \
49 | --vllm_locate_same_node true\
50 | --vllm_locate_same_remain_gpus 3\
51 | --bf16 true\
52 | --tf32 true\
53 | --report_to wandb \
54 | --gradient_checkpointing true \
55 | --max_pixels ${MAX_PIXELS} \
56 | --temperature 0.3 \
57 | --top_p 0.001 \
58 | --top_k 1 \
59 | --num_train_epochs 1 \
60 | --run_name ${RUN_NAME} \
61 | --save_steps 100 \
62 | --num_generations ${GROUP_SIZE} \
63 | --num_iterations 1 \
64 | --beta 0.0"
65 |
66 |
67 | log_file="vllm_node_0.log"
68 |
69 | # vLLM: GPU 3
70 | CUDA_VISIBLE_DEVICES=3 python src/vllm_server_v2.py \
71 | --model ${MODEL_PATH} \
72 | --gpu_memory_utilization 0.9 \
73 | --enable-prefix-caching true \
74 | --dtype 'bfloat16' \
75 | --max_model_len 4096 \
76 | --tensor_parallel_size ${TP_SIZE} \
77 | --host '0.0.0.0' \
78 | --port ${PORT_BASE} > ${log_file} 2>&1 &
79 |
80 | echo "waiting for vLLM servers..."
81 | #sleep 60
82 | echo "start training..."
83 | # Training: GPUs 0-2
84 | CUDA_VISIBLE_DEVICES=0,1,2 torchrun --nnodes=1 --nproc_per_node=3 \
85 | --node_rank=0 \
86 | --master_addr=${HEAD_NODE_IP} \
87 | --master_port=${MASTER_PORT} \
88 | ${TRAIN_CMD} > debug-gh200.log 2>&1 &
89 |
90 |
91 |
--------------------------------------------------------------------------------
/scripts/debug/debug_gh200_local_vllm.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export HF_HOME=$SCRATCH/huggingface
4 | # ---------- Environment Setup ----------
5 | export NCCL_ASYNC_ERROR_HANDLING=1
6 | export DEBUG_MODE=True
7 | export WANDB_PROJECT=RL4SGG
8 |
9 |
10 | GROUP_SIZE=8
11 | MODEL_PATH="Qwen/Qwen2-VL-7B-Instruct"
12 | DATA_PATH="JosephZ/vg150_train_sgg_prompt"
13 | RUN_NAME="qwen2vl-7b-grpo-g${GROUP_SIZE}-n1-gh200"
14 | export OUTPUT_DIR="${SCRATCH}/models/${RUN_NAME}"
15 | mkdir -p "$OUTPUT_DIR"
16 |
17 | MAX_PIXELS=$((512 * 28 * 28))
18 |
19 |
20 |
21 | HEAD_NODE_IP=0.0.0.0
22 | MASTER_PORT=29500
23 |
24 |
25 |
26 | # GH200 has very high bandwidth between CPU and GPU; we should use it!
27 | # zero2:
28 | # bsz_per_device=16: OOM; OK with CPU offload for the optimizer, ~60h with 3x GPUs
29 | # bsz_per_device=8: 386s for 30 steps, ~60h with 3x GPUs
30 | # bsz_per_device=16: ~40h with 4x GPUs
31 | TRAIN_CMD="open_r1/grpo.py \
32 | --output_dir ${OUTPUT_DIR} \
33 | --model_name_or_path ${MODEL_PATH} \
34 | --dataset_name ${DATA_PATH} \
35 | --max_prompt_length 2048 \
36 | --max_completion_length 1024 \
37 | --per_device_train_batch_size 16 \
38 | --deepspeed ./local_scripts/zero2.json \
39 | --gradient_accumulation_steps 1 \
40 | --logging_steps 1 \
41 | --use_vllm true \
42 | --use_local_vllm true\
43 | --bf16 true\
44 | --tf32 true\
45 | --report_to wandb \
46 | --gradient_checkpointing true \
47 | --max_pixels ${MAX_PIXELS} \
48 | --temperature 0.3 \
49 | --top_p 0.001 \
50 | --top_k 1 \
51 | --num_train_epochs 1 \
52 | --run_name ${RUN_NAME} \
53 | --save_steps 100 \
54 | --num_generations ${GROUP_SIZE} \
55 | --num_iterations 1 \
56 | --beta 0.0\
57 | --use_liger_loss false\
58 | --vllm_max_model_len 4096 \
59 | --vllm_gpu_memory_utilization 0.25"
60 |
61 |
62 | echo "start training..."
63 | # Training: GPUs 0-3, batch size: 16*4//8=8
64 | CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nnodes=1 --nproc_per_node=4 \
65 | --node_rank=0 \
66 | --master_addr=${HEAD_NODE_IP} \
67 | --master_port=${MASTER_PORT} \
68 | ${TRAIN_CMD} > debug-gh200.log 2>&1 &
69 |
--------------------------------------------------------------------------------
/scripts/debug/test_vllm.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | export HF_HOME=$SCRATCH/huggingface
4 | # ---------- Environment Setup ----------
5 | export NCCL_ASYNC_ERROR_HANDLING=1
6 | export DEBUG_MODE=True
7 | export WANDB_PROJECT=RL4SGG
8 |
9 |
10 | MODEL_PATH="Qwen/Qwen2-VL-7B-Instruct"
11 | DATA_PATH="JosephZ/vg150_train_sgg_prompt"
12 |
13 | TP_SIZE=1
14 | PORT_BASE=8000
15 |
16 | MAX_PIXELS=$((512 * 28 * 28))
17 |
18 |
19 |
20 |
21 | HEAD_NODE_IP=0.0.0.0
22 | MASTER_PORT=29500
23 |
24 |
25 |
26 |
27 | server_ip=$(hostname -I | awk '{print $1}')
28 |
29 |
30 | # Launch vLLM servers
31 | for i in {0..1}; do
32 | log_file="vllm_server_${i}.log"
33 | port=$((PORT_BASE + i))
34 | CUDA_VISIBLE_DEVICES=${i} python src/vllm_server_v2.py \
35 | --model "${MODEL_PATH}" \
36 | --gpu_memory_utilization 0.9 \
37 | --enable_prefix_caching true \
38 | --dtype 'bfloat16' \
39 | --max_model_len 4096 \
40 | --tensor_parallel_size "${TP_SIZE}" \
41 | --host '0.0.0.0' \
42 | --port "${port}" > "${log_file}" 2>&1 &
43 | done
44 |
45 | echo "Waiting for vLLM servers to initialize..."
46 | #sleep 60
47 |
48 | # Run tests
49 | for i in {2..3}; do
50 | log_file="vllm_client_${i}.log"
51 | port=$((PORT_BASE + i - 2))
52 | group_port=$(( 51200 + i))
53 | CUDA_VISIBLE_DEVICES=${i} python tests/test_vllm.py \
54 | --hosts ${server_ip} \
55 | --server_port "${port}" \
56 | --group_port ${group_port}\
57 | --model_name_or_path "${MODEL_PATH}" > "${log_file}" 2>&1 &
58 | done
59 |
60 |
61 |
--------------------------------------------------------------------------------
/scripts/grpo/run_vllm.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | #SBATCH --job-name=VLLM
4 | #SBATCH --time=24:00:00
5 |
6 | #SBATCH --nodes=8
7 | #SBATCH --ntasks=16
8 | #SBATCH --ntasks-per-node=2
9 | #SBATCH --gpus-per-node=rtx_4090:8
10 | #SBATCH --cpus-per-task=4
11 | #SBATCH --mem-per-cpu=25000M
12 | #SBATCH --output=VLLM_%j_%N.out
13 |
14 | # ------------------ Config ------------------
15 | MODEL_PATH="Qwen/Qwen2-VL-7B-Instruct"
16 | TP_SIZE=4 # Tensor parallelism (4 GPUs per process)
17 | PORT_BASE=8000 # Base port to offset by local rank
18 |
19 | # ------------------ Environment ------------------
20 | nodes=($(scontrol show hostnames "$SLURM_JOB_NODELIST"))
21 | DP_WORLD_SIZE=$(( ${#nodes[@]} * 2 )) # 2 processes per node
22 |
23 | head_node=${nodes[0]}
24 | head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
25 | MASTER_PORT=$(shuf -i 20000-40000 -n 1)
26 | MASTER_IP=${head_node_ip}
27 |
28 | echo "Head node IP: ${MASTER_IP}, Port: ${MASTER_PORT}"
29 | echo "DP_WORLD_SIZE : ${DP_WORLD_SIZE}"
30 | echo "Node list: ${nodes[@]}"
31 |
32 | # ------------------ Export IPs and Ports ------------------
33 | IP_FILE=ip_port_list.txt
34 |
35 | > ${IP_FILE} # Reset output list
36 |
37 | RANK=0
38 | for node in "${nodes[@]}"; do
39 | node_ip=$(srun --nodes=1 --ntasks=1 -w "$node" hostname --ip-address)
40 |
41 | for local_rank in 0 1; do
42 | PORT=$((PORT_BASE + local_rank))
43 | echo "${node_ip}:${PORT}" >> ${IP_FILE}
44 | done
45 | done
46 |
47 | # ------------------ Launch per-process ------------------
48 | RANK=0
49 | for node in "${nodes[@]}"; do
50 | echo "Launching 2 ranks on node $node"
51 |
52 | srun --nodes=1 --ntasks=1 --ntasks-per-node=1 -w "$node" \
53 | bash -c "
54 | for local_rank in 0 1; do
55 | (
56 | export RANK=\$(( ${RANK} + local_rank ))
57 | export DP_WORLD_SIZE=${DP_WORLD_SIZE}
58 | export TP_SIZE=${TP_SIZE}
59 | export MASTER_ADDR=${MASTER_IP}
60 | export MASTER_PORT=${MASTER_PORT}
61 |
62 | export CUDA_VISIBLE_DEVICES=\$(seq -s, \$(( local_rank * 4 )) \$(( local_rank * 4 + 3 )))
63 | PORT=\$(( ${PORT_BASE} + local_rank ))
64 |
65 | echo \"Starting rank \$RANK on $node with CUDA_VISIBLE_DEVICES=\$CUDA_VISIBLE_DEVICES, port \$PORT\"
66 |
67 | python src/vllm_server_v2.py \
68 | --model '${MODEL_PATH}' \
69 | --gpu_memory_utilization 0.85 \
70 | --dtype 'bfloat16' \
71 | --max_model_len 4096 \
72 | --tensor_parallel_size ${TP_SIZE} \
73 | --host '0.0.0.0' \
74 | --port \$PORT
75 | ) &
76 | done
77 | wait
78 | " &
79 |
80 | RANK=$((RANK + 2))
81 | done
82 |
83 | wait
84 |
--------------------------------------------------------------------------------
/scripts/grpo/train_a100.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | #SBATCH --job-name=7B_A100_det_cls
5 | #SBATCH --time=24:00:00
6 |
7 | #SBATCH --nodes=4 # each has 4x A100
8 | #SBATCH --ntasks-per-node=1
9 | #SBATCH --gpus-per-node=4
10 | #SBATCH --cpus-per-task=128
11 |
12 | #SBATCH --partition=normal
13 | #SBATCH --output=RL_A100_%j_%N.out
14 | #SBATCH --mail-user="zychen.uestc@gmail.com" --mail-type=ALL
15 |
16 |
17 | set -x
18 | # ---------- Environment Setup ----------
19 | export NCCL_ASYNC_ERROR_HANDLING=1
20 | export DEBUG_MODE=True
21 | export WANDB_PROJECT=RL4SGG
22 |
23 |
24 | GPUS_PER_NODE=4
25 | GROUP_SIZE=8
26 | MODEL_PATH="Qwen/Qwen2-VL-7B-Instruct"
27 | DATA_PATH="JosephZ/vg150_train_sgg_prompt"
28 | RUN_NAME="qwen2vl-7b-grpo-det-cls-g8-n1-bs32-A100-SXM4"
29 |
30 | export OUTPUT_DIR="${SCRATCH}/models/${RUN_NAME}"
31 | export LOG_PATH=${OUTPUT_DIR}/debug.log
32 |
33 | mkdir -p "$OUTPUT_DIR"
34 |
35 | MAX_PIXELS=$((512 * 28 * 28))
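# 512 * 28 * 28 = 401,408 pixels; Qwen2-VL measures image size in 28x28-pixel patches,
# so this caps each image at roughly 512 visual tokens.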
36 |
37 |
38 | MASTER_PORT=29500
39 |
40 | NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))
41 | NUM_TRAIN_NODES=${#NODELIST[@]}
42 | TRAIN_NODES_LIST=("${NODELIST[@]:0:$NUM_TRAIN_NODES}")
43 |
44 | # Choose the first training node as the rendezvous head node
45 | HEAD_NODE=${TRAIN_NODES_LIST[0]}
46 |
47 | #MASTER_ADDR=$(srun --nodes=1 --ntasks=1 -w "$HEAD_NODE" hostname --ip-address)
48 |
49 | MASTER_ADDR=$(echo "${SLURM_NODELIST}" | sed 's/[],].*//g; s/\[//g')
50 | echo "MASTER_ADDR: $MASTER_ADDR"
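# The sed above strips everything from the first ',' or ']' and drops the '[', so a nodelist
# such as "nid[002289,002325]" (hypothetical example) becomes "nid002289". It does not expand
# hyphenated ranges like "nid[002289-002325]"; the scontrol-based variant used in the other
# scripts handles those cases as well.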
51 |
52 |
53 |
54 | # batch size: PER_GPU(4) * GPUS(4) * NODES(4) * ACC(4) // GROUP_SIZE(8) = 32
55 | # local vLLM: 80G*0.25=20G
56 | #
57 | TRAIN_CMD="open_r1/grpo.py \
58 | --task_type det cls \
59 | --output_dir ${OUTPUT_DIR} \
60 | --model_name_or_path ${MODEL_PATH} \
61 | --dataset_name ${DATA_PATH} \
62 | --max_prompt_length 2048 \
63 | --max_completion_length 1024 \
64 | --custom_per_device_train_batch_size 4 \
65 | --deepspeed ./local_scripts/zero2.json \
66 | --gradient_accumulation_steps 4 \
67 | --learning_rate 3e-7 \
68 | --logging_steps 1 \
69 | --use_vllm true \
70 | --use_local_vllm true\
71 | --bf16 true\
72 | --tf32 true\
73 | --report_to wandb \
74 | --gradient_checkpointing true \
75 | --max_pixels ${MAX_PIXELS} \
76 | --temperature 1.0 \
77 | --top_p 0.9 \
78 | --top_k 50 \
79 | --num_train_epochs 1 \
80 | --run_name ${RUN_NAME} \
81 | --save_steps 100 \
82 | --num_generations ${GROUP_SIZE} \
83 | --num_iterations 1 \
84 | --beta 0.0\
85 | --vllm_max_model_len 4096 \
86 | --vllm_gpu_memory_utilization 0.25\
87 | --save_only_model true\
88 | --seed 42"
89 |
90 |
91 | echo "start training..."
92 |
93 | srun --nodes=${NUM_TRAIN_NODES} --nodelist="${TRAIN_NODES_LIST}" \
94 | torchrun --nnodes ${NUM_TRAIN_NODES} --nproc_per_node ${GPUS_PER_NODE} \
95 | --node_rank ${SLURM_NODEID} \
96 | --rdzv_id $RANDOM \
97 | --rdzv_backend c10d \
98 | --rdzv_endpoint ${MASTER_ADDR}:${MASTER_PORT} \
99 | ${TRAIN_CMD}
100 |
--------------------------------------------------------------------------------
/scripts/grpo/train_a100_2B.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | #SBATCH --job-name=A100_2B_1k_lr6e-7_psg_debug
5 | #SBATCH --time=00:30:00
6 |
7 | #SBATCH --exclude=nid002289,nid002325
8 | #SBATCH --nodes=2   # 2 nodes, each has 4x A100
9 | #SBATCH --ntasks-per-node=1
10 | #SBATCH --gpus-per-node=4
11 | #SBATCH --cpus-per-task=128
12 |
13 | #SBATCH --partition=normal
14 | #SBATCH --output=RL_A100_%j_%N.out
15 | #SBATCH --mail-user="zychen.uestc@gmail.com" --mail-type=ALL
16 |
17 |
18 | set -x
19 | # ---------- Environment Setup ----------
20 | export TORCH_NCCL_ASYNC_ERROR_HANDLING=1
21 | export DEBUG_MODE=True
22 | export WANDB_PROJECT=RL4SGG
23 |
24 | export NCCL_DEBUG=INFO
25 |
26 |
27 | GPUS_PER_NODE=4
28 | GROUP_SIZE=8
29 | MODEL_PATH="Qwen/Qwen2-VL-2B-Instruct"
30 | #DATA_PATH="JosephZ/vg150_train_sgg_prompt"
31 | DATA_PATH="JosephZ/psg_train_sg"
32 |
33 | RUN_NAME="qwen2vl-2b-grpo-g8-n1-bs32-1k-lr6e-7-psg-debug-A100-SXM4"
34 | export OUTPUT_DIR="${SCRATCH}/models/${RUN_NAME}"
35 | mkdir -p "$OUTPUT_DIR"
36 |
37 | export LOG_PATH=${OUTPUT_DIR}/debug.log
38 |
39 | export FORMAT_REWARD_WEIGHT=1.0
40 | export STRICT_FORMAT=True
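# These environment variables are presumably read by the reward functions in open_r1/grpo.py:
# FORMAT_REWARD_WEIGHT scales the format reward and STRICT_FORMAT toggles stricter format checking.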
41 |
42 | MAX_PIXELS=$((512 * 28 * 28))
43 |
44 |
45 | MASTER_PORT=29500
46 |
47 | NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))
48 | NUM_TRAIN_NODES=${#NODELIST[@]}
49 | TRAIN_NODES_LIST=("${NODELIST[@]:0:$NUM_TRAIN_NODES}")
50 |
51 | # Choose the first training node as the rendezvous head node
52 | HEAD_NODE=${TRAIN_NODES_LIST[0]}
53 |
54 | #MASTER_ADDR=$(srun --nodes=1 --ntasks=1 -w "$HEAD_NODE" hostname --ip-address)
55 | #MASTER_ADDR=$(echo "${SLURM_NODELIST}" | sed 's/[],].*//g; s/\[//g')
56 |
57 | MASTER_ADDR=$(scontrol show hostnames $SLURM_NODELIST | head -n 1)
58 | echo "MASTER_ADDR: $MASTER_ADDR"
59 |
60 |
61 |
62 | # batch size: PER_GPU(4) * GPUS(4) * NODES(2) * ACC(4) // GROUP_SIZE(8) = 16
63 | # local vLLM: 80G*0.2=16G
64 | #
65 | # ['format_reward', 'node_acc_reward', "node_box_reward", "edge_reward"]
66 | TRAIN_CMD="open_r1/grpo.py \
67 | --task_type sgg \
68 | --output_dir ${OUTPUT_DIR} \
69 | --model_name_or_path ${MODEL_PATH} \
70 | --dataset_name ${DATA_PATH} \
71 | --max_prompt_length 2048 \
72 | --max_completion_length 1024 \
73 | --custom_per_device_train_batch_size 4 \
74 | --deepspeed ./local_scripts/zero2.json \
75 | --gradient_accumulation_steps 4 \
76 | --learning_rate 6e-7 \
77 | --logging_steps 1 \
78 | --use_vllm true \
79 | --use_local_vllm true\
80 | --bf16 true\
81 | --tf32 true\
82 | --report_to wandb \
83 | --gradient_checkpointing true \
84 | --max_pixels ${MAX_PIXELS} \
85 | --temperature 1.0 \
86 | --top_p 0.9 \
87 | --top_k 50 \
88 | --num_train_epochs 1.0 \
89 | --run_name ${RUN_NAME} \
90 | --save_steps 100 \
91 | --num_generations ${GROUP_SIZE} \
92 | --num_iterations 1 \
93 | --beta 0.0 \
94 | --vllm_max_model_len 4096 \
95 | --vllm_gpu_memory_utilization 0.2 \
96 | --save_only_model true\
97 | --seed 42"
98 |
99 |
100 | echo "start training..."
101 |
102 | srun --nodes=${NUM_TRAIN_NODES} --nodelist="${TRAIN_NODES_LIST}" \
103 | torchrun --nnodes ${NUM_TRAIN_NODES} --nproc_per_node ${GPUS_PER_NODE} \
104 | --node_rank ${SLURM_NODEID} \
105 | --rdzv_id $RANDOM \
106 | --rdzv_backend c10d \
107 | --rdzv_endpoint ${MASTER_ADDR}:${MASTER_PORT} \
108 | ${TRAIN_CMD}
109 |
--------------------------------------------------------------------------------
/scripts/grpo/train_a100_2B_SFT.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | #SBATCH --job-name=A100_2B_SFT_RL_2k_lr6e-7
5 | #SBATCH --time=24:00:00
6 |
7 | #SBATCH --exclude=nid002289,nid002325
8 | #SBATCH --nodes=4 # 4 nodes, each has 4x A100
9 | #SBATCH --ntasks-per-node=1
10 | #SBATCH --gpus-per-node=4
11 | #SBATCH --cpus-per-task=128
12 |
13 | #SBATCH --partition=normal
14 | #SBATCH --output=RL_A100_%j_%N.out
15 | #SBATCH --mail-user="zychen.uestc@gmail.com" --mail-type=ALL
16 |
17 |
18 | # ---------- Environment Setup ----------
19 | export TORCH_NCCL_ASYNC_ERROR_HANDLING=1
20 | export DEBUG_MODE=True
21 | export WANDB_PROJECT=RL4SGG
22 |
23 |
24 | GPUS_PER_NODE=4
25 | GROUP_SIZE=8
26 | #MODEL_PATH="Qwen/Qwen2-VL-2B-Instruct"
27 | MODEL_PATH=$1
28 |
29 | DATA_PATH="JosephZ/vg150_train_sgg_prompt"
30 | RUN_NAME="qwen2vl-2b-sft-grpo-sgg-g8-n1-bs32-2k-lr6e-7-A100-SXM4"
31 | export OUTPUT_DIR="${SCRATCH}/models/${RUN_NAME}"
32 | mkdir -p "$OUTPUT_DIR"
33 |
34 | export LOG_PATH=${OUTPUT_DIR}/debug.log
35 |
36 | MAX_PIXELS=$((512 * 28 * 28))
37 |
38 |
39 | MASTER_PORT=29500
40 |
41 | NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))
42 | NUM_TRAIN_NODES=${#NODELIST[@]}
43 | TRAIN_NODES_LIST=("${NODELIST[@]:0:$NUM_TRAIN_NODES}")
44 |
45 | # Choose the first training node as the rendezvous head node
46 | HEAD_NODE=${TRAIN_NODES_LIST[0]}
47 |
48 | #MASTER_ADDR=$(srun --nodes=1 --ntasks=1 -w "$HEAD_NODE" hostname --ip-address)
49 | #MASTER_ADDR=$(echo "${SLURM_NODELIST}" | sed 's/[],].*//g; s/\[//g')
50 |
51 | MASTER_ADDR=$(scontrol show hostnames $SLURM_NODELIST | head -n 1)
52 | echo "MASTER_ADDR: $MASTER_ADDR"
53 |
54 |
55 |
56 | # batch size: PER_GPU(4)*GPUS(4)*NODES(4)*ACC(4) // GROUP_SIZE(8) = 32
57 | # local vLLM: 80G*0.2=16G
58 | #
59 | TRAIN_CMD="open_r1/grpo.py \
60 | --task_type sgg \
61 | --output_dir ${OUTPUT_DIR} \
62 | --model_name_or_path ${MODEL_PATH} \
63 | --dataset_name ${DATA_PATH} \
64 | --max_prompt_length 2048 \
65 | --max_completion_length 2048 \
66 | --custom_per_device_train_batch_size 4 \
67 | --deepspeed ./local_scripts/zero2.json \
68 | --gradient_accumulation_steps 4 \
69 | --learning_rate 6e-7 \
70 | --logging_steps 1 \
71 | --use_vllm true \
72 | --use_local_vllm true\
73 | --bf16 true\
74 | --tf32 true\
75 | --report_to wandb \
76 | --gradient_checkpointing true \
77 | --max_pixels ${MAX_PIXELS} \
78 | --temperature 1.0 \
79 | --top_p 0.9 \
80 | --top_k 50 \
81 | --num_train_epochs 1 \
82 | --run_name ${RUN_NAME} \
83 | --save_steps 100 \
84 | --num_generations ${GROUP_SIZE} \
85 | --num_iterations 1 \
86 | --beta 0.0\
87 | --vllm_max_model_len 4096 \
88 | --vllm_gpu_memory_utilization 0.2 \
89 | --save_only_model true\
90 | --seed 42"
91 |
92 |
93 | echo "start training..."
94 |
95 | srun --nodes=${NUM_TRAIN_NODES} --nodelist="${TRAIN_NODES_LIST}" \
96 | torchrun --nnodes ${NUM_TRAIN_NODES} --nproc_per_node ${GPUS_PER_NODE} \
97 | --node_rank ${SLURM_NODEID} \
98 | --rdzv_id $RANDOM \
99 | --rdzv_backend c10d \
100 | --rdzv_endpoint ${MASTER_ADDR}:${MASTER_PORT} \
101 | ${TRAIN_CMD}
102 |
--------------------------------------------------------------------------------
/scripts/grpo/train_a100_2B_close.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | #SBATCH --job-name=A100_2B
5 | #SBATCH --time=24:00:00
6 |
7 | #SBATCH --nodes=4 # 4 nodes, each has 4x A100
8 | #SBATCH --ntasks-per-node=1
9 | #SBATCH --gpus-per-node=4
10 | #SBATCH --cpus-per-task=128
11 |
12 | #SBATCH --partition=normal
13 | #SBATCH --output=RL_A100_%j_%N.out
14 | #SBATCH --mail-user="zychen.uestc@gmail.com" --mail-type=ALL
15 |
16 |
17 | # ---------- Environment Setup ----------
18 | export NCCL_ASYNC_ERROR_HANDLING=1
19 | export DEBUG_MODE=True
20 | export WANDB_PROJECT=RL4SGG
21 |
22 |
23 | GPUS_PER_NODE=4
24 | GROUP_SIZE=8
25 | MODEL_PATH="Qwen/Qwen2-VL-2B-Instruct"
26 | DATA_PATH="JosephZ/vg150_train_sgg_prompt"
27 | RUN_NAME="qwen2vl-2b-close-grpo-g${GROUP_SIZE}-n1-bs32-A100-SXM4"
28 | export OUTPUT_DIR="${SCRATCH}/models/${RUN_NAME}"
29 | mkdir -p "$OUTPUT_DIR"
30 |
31 | export LOG_PATH=${OUTPUT_DIR}/debug.log
32 |
33 | MAX_PIXELS=$((512 * 28 * 28))
34 |
35 |
36 | MASTER_PORT=29500
37 |
38 | NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))
39 | NUM_TRAIN_NODES=${#NODELIST[@]}
40 | TRAIN_NODES_LIST=("${NODELIST[@]:0:$NUM_TRAIN_NODES}")
41 |
42 | # Choose the first training node as the rendezvous head node
43 | HEAD_NODE=${TRAIN_NODES_LIST[0]}
44 |
45 | #MASTER_ADDR=$(srun --nodes=1 --ntasks=1 -w "$HEAD_NODE" hostname --ip-address)
46 | #MASTER_ADDR=$(echo "${SLURM_NODELIST}" | sed 's/[],].*//g; s/\[//g')
47 |
48 | MASTER_ADDR=$(scontrol show hostnames $SLURM_NODELIST | head -n 1)
49 | echo "MASTER_ADDR: $MASTER_ADDR"
50 |
51 |
52 |
53 | # batch size: PER_GPU(4)*GPUS(4)*NODES(4)*ACC(4) // GROUP_SIZE(8) = 32
54 | # local vLLM: 80G*0.2=16G
55 | #
56 | TRAIN_CMD="open_r1/grpo.py \
57 | --output_dir ${OUTPUT_DIR} \
58 | --model_name_or_path ${MODEL_PATH} \
59 | --dataset_name ${DATA_PATH} \
60 | --max_prompt_length 2048 \
61 | --max_completion_length 1024 \
62 | --custom_per_device_train_batch_size 4 \
63 | --deepspeed ./local_scripts/zero2.json \
64 | --gradient_accumulation_steps 4 \
65 | --learning_rate 3e-7 \
66 | --use_predefined_cats true \
67 | --logging_steps 1 \
68 | --use_vllm true \
69 | --use_local_vllm true\
70 | --bf16 true\
71 | --tf32 true\
72 | --report_to wandb \
73 | --gradient_checkpointing false \
74 | --max_pixels ${MAX_PIXELS} \
75 | --temperature 1.0 \
76 | --top_p 0.9 \
77 | --top_k 50 \
78 | --num_train_epochs 1 \
79 | --run_name ${RUN_NAME} \
80 | --save_steps 100 \
81 | --num_generations ${GROUP_SIZE} \
82 | --num_iterations 1 \
83 | --beta 0.0\
84 | --vllm_max_model_len 4096 \
85 | --vllm_gpu_memory_utilization 0.2 \
86 | --save_only_model true\
87 | --seed 42"
88 |
89 |
90 | echo "start training..."
91 |
92 | srun torchrun --nnodes ${NUM_TRAIN_NODES} --nproc_per_node ${GPUS_PER_NODE} \
93 | --node_rank ${SLURM_NODEID} \
94 | --rdzv_id $RANDOM \
95 | --rdzv_backend c10d \
96 | --rdzv_endpoint ${MASTER_ADDR}:${MASTER_PORT} \
97 | ${TRAIN_CMD}
98 |
--------------------------------------------------------------------------------
/scripts/grpo/train_a100_2B_close_SFT.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | #SBATCH --job-name=A100_2B_SFT_CLOSE_RL
5 | #SBATCH --time=24:00:00
6 |
7 | #SBATCH --nodes=4 # 4 nodes, each has 4x A100
8 | #SBATCH --ntasks-per-node=1
9 | #SBATCH --gpus-per-node=4
10 | #SBATCH --cpus-per-task=128
11 |
12 | #SBATCH --partition=normal
13 | #SBATCH --output=RL_A100_%j_%N.out
14 | #SBATCH --mail-user="zychen.uestc@gmail.com" --mail-type=ALL
15 |
16 |
17 | # ---------- Environment Setup ----------
18 | export NCCL_ASYNC_ERROR_HANDLING=1
19 | export DEBUG_MODE=True
20 | export WANDB_PROJECT=RL4SGG
21 |
22 |
23 | GPUS_PER_NODE=4
24 | GROUP_SIZE=8
25 | #MODEL_PATH="Qwen/Qwen2-VL-2B-Instruct"
26 | MODEL_PATH=$1
27 |
28 | DATA_PATH="JosephZ/vg150_train_sgg_prompt"
29 | RUN_NAME="qwen2vl-2b-sft-close-grpo-g${GROUP_SIZE}-n1-bs32-A100-SXM4"
30 | export OUTPUT_DIR="${SCRATCH}/models/${RUN_NAME}"
31 | mkdir -p "$OUTPUT_DIR"
32 |
33 | export LOG_PATH=${OUTPUT_DIR}/debug.log
34 |
35 | MAX_PIXELS=$((512 * 28 * 28))
36 |
37 |
38 | MASTER_PORT=29500
39 |
40 | NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))
41 | NUM_TRAIN_NODES=${#NODELIST[@]}
42 | TRAIN_NODES_LIST=("${NODELIST[@]:0:$NUM_TRAIN_NODES}")
43 |
44 | # Choose the first training node as the rendezvous head node
45 | HEAD_NODE=${TRAIN_NODES_LIST[0]}
46 |
47 | #MASTER_ADDR=$(srun --nodes=1 --ntasks=1 -w "$HEAD_NODE" hostname --ip-address)
48 |
49 | MASTER_ADDR=$(echo "${SLURM_NODELIST}" | sed 's/[],].*//g; s/\[//g')
50 | echo "MASTER_ADDR: $MASTER_ADDR"
51 |
52 |
53 |
54 | # batch size: PER_GPU(4)*GPUS(4)*NODES(4)*ACC(4) // GROUP_SIZE(8) = 32
55 | # local vLLM: 80G*0.25=20G
56 | #
57 | TRAIN_CMD="open_r1/grpo.py \
58 | --output_dir ${OUTPUT_DIR} \
59 | --model_name_or_path ${MODEL_PATH} \
60 | --dataset_name ${DATA_PATH} \
61 | --max_prompt_length 2048 \
62 | --max_completion_length 1024 \
63 | --custom_per_device_train_batch_size 4 \
64 | --deepspeed ./local_scripts/zero2.json \
65 | --gradient_accumulation_steps 4 \
66 | --learning_rate 3e-7 \
67 | --logging_steps 1 \
68 | --use_vllm true \
69 | --use_local_vllm true\
70 | --bf16 true\
71 | --tf32 true\
72 | --report_to wandb \
73 | --gradient_checkpointing true \
74 | --max_pixels ${MAX_PIXELS} \
75 | --temperature 1.0 \
76 | --top_p 0.9 \
77 | --top_k 50 \
78 | --num_train_epochs 1 \
79 | --run_name ${RUN_NAME} \
80 | --save_steps 100 \
81 | --num_generations ${GROUP_SIZE} \
82 | --num_iterations 1 \
83 | --beta 0.0\
84 | --vllm_max_model_len 4096 \
85 | --vllm_gpu_memory_utilization 0.25 \
86 | --save_only_model true\
87 | --use_predefined_cats true \
88 | --seed 42"
89 |
90 |
91 | echo "start training..."
92 |
93 | srun --nodes=${NUM_TRAIN_NODES} --nodelist="${TRAIN_NODES_LIST}" \
94 | torchrun --nnodes ${NUM_TRAIN_NODES} --nproc_per_node ${GPUS_PER_NODE} \
95 | --node_rank ${SLURM_NODEID} \
96 | --rdzv_id $RANDOM \
97 | --rdzv_backend c10d \
98 | --rdzv_endpoint ${MASTER_ADDR}:${MASTER_PORT} \
99 | ${TRAIN_CMD}
100 |
--------------------------------------------------------------------------------
/scripts/grpo/train_a100_SFT.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | #SBATCH --job-name=7B_SFT_RL
5 | #SBATCH --time=24:00:00
6 |
7 | #SBATCH --nodes=4 # 4 nodes, each has 4x A100
8 | #SBATCH --ntasks-per-node=1
9 | #SBATCH --gpus-per-node=4
10 | #SBATCH --cpus-per-task=128
11 |
12 | #SBATCH --partition=normal
13 | #SBATCH --output=RL_A100_%j_%N.out
14 | #SBATCH --mail-user="zychen.uestc@gmail.com" --mail-type=ALL
15 |
16 |
17 | set -x
18 | # ---------- Environment Setup ----------
19 | export NCCL_ASYNC_ERROR_HANDLING=1
20 | export DEBUG_MODE=True
21 | export WANDB_PROJECT=RL4SGG
22 |
23 |
24 | GPUS_PER_NODE=4
25 | GROUP_SIZE=8
26 | #MODEL_PATH="Qwen/Qwen2-VL-7B-Instruct"
27 | MODEL_PATH=$1
28 |
29 | DATA_PATH="JosephZ/vg150_train_sgg_prompt"
30 | RUN_NAME="qwen2vl-7b-sft-grpo-g${GROUP_SIZE}-n1-bs32-A100-SXM4"
31 | export OUTPUT_DIR="${SCRATCH}/models/${RUN_NAME}"
32 | mkdir -p "$OUTPUT_DIR"
33 |
34 | MAX_PIXELS=$((512 * 28 * 28))
35 |
36 |
37 | MASTER_PORT=29500
38 |
39 | NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))
40 | NUM_TRAIN_NODES=${#NODELIST[@]}
41 | TRAIN_NODES_LIST=("${NODELIST[@]:0:$NUM_TRAIN_NODES}")
42 |
43 | # Choose the first training node as the rendezvous head node
44 | HEAD_NODE=${TRAIN_NODES_LIST[0]}
45 |
46 | #MASTER_ADDR=$(srun --nodes=1 --ntasks=1 -w "$HEAD_NODE" hostname --ip-address)
47 |
48 | MASTER_ADDR=$(echo "${SLURM_NODELIST}" | sed 's/[],].*//g; s/\[//g')
49 | echo "MASTER_ADDR: $MASTER_ADDR"
50 |
51 |
52 |
53 | # batch size: PER_GPU(2)*GPUS(4)*NODES(4)*ACC(8) //8=32
54 | # local vLLM: 80G*0.25=20G
55 | #
56 | TRAIN_CMD="open_r1/grpo.py \
57 | --output_dir ${OUTPUT_DIR} \
58 | --model_name_or_path ${MODEL_PATH} \
59 | --dataset_name ${DATA_PATH} \
60 | --max_prompt_length 2048 \
61 | --max_completion_length 1024 \
62 | --custom_per_device_train_batch_size 2 \
63 | --deepspeed ./local_scripts/zero2.json \
64 | --gradient_accumulation_steps 8 \
65 | --learning_rate 3e-7 \
66 | --logging_steps 1 \
67 | --use_vllm true \
68 | --use_local_vllm true\
69 | --bf16 true\
70 | --tf32 true\
71 | --report_to wandb \
72 | --gradient_checkpointing true \
73 | --max_pixels ${MAX_PIXELS} \
74 | --temperature 1.0 \
75 | --top_p 0.9 \
76 | --top_k 50 \
77 | --num_train_epochs 1 \
78 | --run_name ${RUN_NAME} \
79 | --save_steps 100 \
80 | --num_generations ${GROUP_SIZE} \
81 | --num_iterations 1 \
82 | --beta 0.0\
83 | --vllm_max_model_len 4096 \
84 | --vllm_gpu_memory_utilization 0.25\
85 | --save_only_model true\
86 | --seed 42"
87 |
88 |
89 | echo "start training..."
90 |
91 | srun --nodes=${NUM_TRAIN_NODES} --nodelist="${TRAIN_NODES_LIST}" \
92 | torchrun --nnodes ${NUM_TRAIN_NODES} --nproc_per_node ${GPUS_PER_NODE} \
93 | --node_rank ${SLURM_NODEID} \
94 | --rdzv_id $RANDOM \
95 | --rdzv_backend c10d \
96 | --rdzv_endpoint ${MASTER_ADDR}:${MASTER_PORT} \
97 | ${TRAIN_CMD}
98 |
--------------------------------------------------------------------------------
/scripts/grpo/train_a100_close.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | #SBATCH --job-name=7B_A100
5 | #SBATCH --time=24:00:00
6 |
7 | #SBATCH --nodes=8 # each has 4x A100
8 | #SBATCH --ntasks-per-node=1
9 | #SBATCH --gpus-per-node=4
10 | #SBATCH --cpus-per-task=128
11 |
12 | #SBATCH --partition=normal
13 | #SBATCH --output=RL_A100_%j_%N.out
14 | #SBATCH --mail-user="zychen.uestc@gmail.com" --mail-type=ALL
15 |
16 |
17 | set -x
18 | # ---------- Environment Setup ----------
19 | export NCCL_ASYNC_ERROR_HANDLING=1
20 | export DEBUG_MODE=True
21 | export WANDB_PROJECT=RL4SGG
22 |
23 |
24 | GPUS_PER_NODE=4
25 | GROUP_SIZE=8
26 | MODEL_PATH="Qwen/Qwen2-VL-7B-Instruct"
27 | DATA_PATH="JosephZ/vg150_train_sgg_prompt"
28 | RUN_NAME="qwen2vl-7b-close-grpo-g${GROUP_SIZE}-n1-bs32-A100-SXM4"
29 | export OUTPUT_DIR="${SCRATCH}/models/${RUN_NAME}"
30 | mkdir -p "$OUTPUT_DIR"
31 |
32 | MAX_PIXELS=$((512 * 28 * 28))
33 |
34 |
35 | MASTER_PORT=29500
36 |
37 | NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))
38 | NUM_TRAIN_NODES=${#NODELIST[@]}
39 | TRAIN_NODES_LIST=("${NODELIST[@]:0:$NUM_TRAIN_NODES}")
40 |
41 | # Choose the first training node as the rendezvous head node
42 | HEAD_NODE=${TRAIN_NODES_LIST[0]}
43 |
44 | #MASTER_ADDR=$(srun --nodes=1 --ntasks=1 -w "$HEAD_NODE" hostname --ip-address)
45 |
46 | MASTER_ADDR=$(echo "${SLURM_NODELIST}" | sed 's/[],].*//g; s/\[//g')
47 | echo "MASTER_ADDR: $MASTER_ADDR"
48 |
49 |
50 |
51 | # batch size: PER_GPU(4) * GPUS(4) * NODES(8) * ACC(2) // GROUP_SIZE(8) = 32
52 | # local vLLM: 80G*0.25=20G
53 | #
54 | TRAIN_CMD="open_r1/grpo.py \
55 | --output_dir ${OUTPUT_DIR} \
56 | --model_name_or_path ${MODEL_PATH} \
57 | --dataset_name ${DATA_PATH} \
58 | --max_prompt_length 2048 \
59 | --max_completion_length 1024 \
60 | --custom_per_device_train_batch_size 4 \
61 | --deepspeed ./local_scripts/zero2.json \
62 | --gradient_accumulation_steps 2 \
63 | --learning_rate 3e-7 \
64 | --use_predefined_cats true \
65 | --logging_steps 1 \
66 | --use_vllm true \
67 | --use_local_vllm true\
68 | --bf16 true\
69 | --tf32 true\
70 | --report_to wandb \
71 | --gradient_checkpointing true \
72 | --max_pixels ${MAX_PIXELS} \
73 | --temperature 1.0 \
74 | --top_p 0.9 \
75 | --top_k 50 \
76 | --num_train_epochs 1 \
77 | --run_name ${RUN_NAME} \
78 | --save_steps 100 \
79 | --num_generations ${GROUP_SIZE} \
80 | --num_iterations 1 \
81 | --beta 0.0\
82 | --vllm_max_model_len 4096 \
83 | --vllm_gpu_memory_utilization 0.25\
84 | --save_only_model true\
85 | --seed 42"
86 |
87 |
88 | echo "start training..."
89 |
90 | srun --nodes=${NUM_TRAIN_NODES} --nodelist="${TRAIN_NODES_LIST}" \
91 | torchrun --nnodes ${NUM_TRAIN_NODES} --nproc_per_node ${GPUS_PER_NODE} \
92 | --node_rank ${SLURM_NODEID} \
93 | --rdzv_id $RANDOM \
94 | --rdzv_backend c10d \
95 | --rdzv_endpoint ${MASTER_ADDR}:${MASTER_PORT} \
96 | ${TRAIN_CMD}
97 |
--------------------------------------------------------------------------------
/scripts/grpo/train_fsdp.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 |
5 | #SBATCH --job-name=GRPO_train
6 | #SBATCH --time=24:00:00
7 | #SBATCH --nodes=16            # all 16 nodes are training nodes; the vLLM servers run in a separate job (see run_vllm.sh)
8 | #SBATCH --ntasks=16 # Total tasks equals total nodes
9 | #SBATCH --ntasks-per-node=1
10 | #SBATCH --gpus-per-node=rtx_4090:8
11 | #SBATCH --cpus-per-task=8
12 | #SBATCH --mem-per-cpu=25000M
13 | #SBATCH --output=RL_%j_%N.out
14 | #SBATCH --mail-user="zychen.uestc@gmail.com" --mail-type=ALL
15 |
16 |
17 | # force crashing on nccl issues like hanging broadcast
18 | export NCCL_ASYNC_ERROR_HANDLING=1
19 | #export NCCL_IB_DISABLE=1
20 |
21 | # export NCCL_DEBUG=INFO
22 | # export NCCL_DEBUG_SUBSYS=COLL
23 | # export NCCL_SOCKET_NTHREADS=1
24 | # export NCCL_NSOCKS_PERTHREAD=1
25 | # export CUDA_LAUNCH_BLOCKING=1
26 |
27 | # wait for vLLM servers
28 | #sleep 60
29 |
30 | # Read IPs from file and join them with commas
31 | #ip_str=$(paste -sd, ip_list.txt)
32 | #echo "vLLM servers: $ip_str"
33 |
34 | FILE="ip_port_list.txt"
35 |
36 | SERVER_IP=""
37 | SERVER_PORT=""
38 |
39 | while IFS=: read -r ip port; do
40 | SERVER_IP+="${ip},"
41 | SERVER_PORT+="${port},"
42 | done < "$FILE"
43 |
44 | # Remove trailing commas
45 | SERVER_IP="${SERVER_IP%,}"
46 | SERVER_PORT="${SERVER_PORT%,}"
47 |
48 | echo "SERVER_IP=$SERVER_IP"
49 | echo "SERVER_PORT=$SERVER_PORT"
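# Example (hypothetical addresses): an ip_port_list.txt containing
#   10.0.0.1:8000
#   10.0.0.1:8001
#   10.0.0.2:8000
# yields SERVER_IP=10.0.0.1,10.0.0.1,10.0.0.2 and SERVER_PORT=8000,8001,8000,
# i.e. one comma-separated entry per vLLM server, matching --vllm_server_host/--vllm_server_port below.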
50 |
51 |
52 | # Define node counts
53 | NUM_TRAIN_NODES=${SLURM_NNODES} # all nodes
54 | GPUS_PER_NODE=8
55 |
56 | # Get the list of allocated nodes
57 | NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))
58 |
59 | # Assign training nodes (first NUM_TRAIN_NODES nodes)
60 | TRAIN_NODES=("${NODELIST[@]:0:$NUM_TRAIN_NODES}")
61 |
62 | # Choose the first training node as the rendezvous head node
63 | HEAD_NODE=${TRAIN_NODES[0]}
64 | MASTER_IP=$(srun --nodes=1 --ntasks=1 -w "$HEAD_NODE" hostname --ip-address)
65 |
66 | MASTER_PORT=6000
67 | echo "Head Node IP: $MASTER_IP, port: ${MASTER_PORT}"
68 |
69 | # Create a comma-separated list of training nodes for srun
70 | TRAIN_NODES_LIST=$(IFS=, ; echo "${TRAIN_NODES[*]}")
71 |
72 | export NCCL_DEBUG=INFO
73 | echo "environment: $(env | grep NCCL)"
74 |
75 |
76 |
77 | export DEBUG_MODE=True
78 | export WANDB_PROJECT=RL4SGG
79 |
80 | export DATA_PATH="JosephZ/vg150_train_sgg_prompt"
81 | export MODEL_PATH="Qwen/Qwen2-VL-7B-Instruct"
82 |
83 |
84 | # Training setup
85 | NNODES=$SLURM_NNODES
86 | NODE_RANK=$SLURM_PROCID
87 | WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))
88 |
89 | MAX_PIXELS=$((512 * 28 * 28))
90 |
91 | echo "Start training script..."
92 |
93 | LAUNCHER="accelerate launch \
94 | --multi_gpu \
95 | --num_machines $NNODES \
96 | --num_processes $WORLD_SIZE \
97 | --main_process_ip "$MASTER_IP" \
98 | --main_process_port $MASTER_PORT \
100 | --machine_rank \$SLURM_PROCID \
101 | --role $SLURMD_NODENAME: \
102 | --rdzv_conf rdzv_backend=c10d \
103 | --max_restarts 0 \
104 | --tee 3 \
105 | --config_file local_scripts/fsdp.yaml \
106 | "
107 |
108 | CMD=" \
109 | open_r1/grpo.py \
110 | --output_dir models/qwen2vl-fsdp-g8 \
111 | --model_name_or_path ${MODEL_PATH} \
112 | --dataset_name $DATA_PATH \
113 | --max_prompt_length 2048 \
114 | --max_completion_length 1024 \
115 | --per_device_train_batch_size 1 \
116 | --gradient_accumulation_steps 1 \
117 | --logging_steps 1 \
118 | --use_vllm true \
119 | --vllm_server_host ${SERVER_IP} \
120 | --vllm_server_port ${SERVER_PORT} \
121 | --vllm_server_timeout 600 \
122 | --bf16 \
123 | --report_to wandb \
124 | --gradient_checkpointing true \
125 | --max_pixels ${MAX_PIXELS} \
126 | --temperature 0.3 \
127 | --top_p 0.001 \
128 | --top_k 1 \
129 | --num_train_epochs 1 \
130 | --run_name Qwen2VL-7B-GRPO-fsdp-G8 \
131 | --save_steps 100 \
132 | --num_generations 8
133 | "
134 |
135 |
136 | srun --jobid $SLURM_JOB_ID bash -c "$LAUNCHER $CMD"
137 |
138 | echo "END TIME: $(date)"
139 |
--------------------------------------------------------------------------------
/scripts/grpo/train_fused.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | #SBATCH --job-name=GRPO_train_vllm
3 | #SBATCH --time=24:00:00
4 | #SBATCH --nodes=15
5 | #SBATCH --ntasks=15
6 | #SBATCH --ntasks-per-node=1
7 | #SBATCH --gpus-per-node=rtx_4090:8
8 | #SBATCH --cpus-per-task=15
9 | #SBATCH --mem-per-cpu=16000M
10 | #SBATCH --output=TrainVLLM_%j_%N.out
11 | #SBATCH --mail-user="zychen.uestc@gmail.com" --mail-type=ALL
12 |
13 | set -euo pipefail
14 |
15 | # ---------- Environment Setup ----------
16 | export NCCL_ASYNC_ERROR_HANDLING=1
17 | export DEBUG_MODE=True
18 | export WANDB_PROJECT=RL4SGG
19 |
20 | GROUP_SIZE=8
21 | MODEL_PATH="Qwen/Qwen2-VL-7B-Instruct"
22 | DATA_PATH="JosephZ/vg150_train_sgg_prompt"
23 | RUN_NAME="qwen2vl-7b-grpo-g${GROUP_SIZE}-n1-4090"
24 | OUTPUT_DIR="models/${RUN_NAME}"
25 | mkdir -p "$OUTPUT_DIR"
26 |
27 | TP_SIZE=4
28 | PORT_BASE=8000
29 | MAX_PIXELS=$((512 * 28 * 28))
30 |
31 | NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))
32 | NUM_NODES=${#NODELIST[@]}
33 |
34 | if (( NUM_NODES % 3 != 0 )); then
35 | echo "Error: number of nodes ($NUM_NODES) must be divisible by 3."
36 | exit 1
37 | fi
38 |
39 | MIXED_NODES=10 # Set this dynamically if needed
40 | # vLLM: 10*4=40 GPUs, 10 servers
41 | # training, x=(8-4)*10//8= 5
42 | # training GPUs: 10*4+ 5*8=80, batch size=80//8*2=20
43 |
44 | HEAD_NODE=${NODELIST[0]}
45 | HEAD_NODE_IP=$(srun --nodes=1 --ntasks=1 -w "$HEAD_NODE" hostname --ip-address)
46 | RDZV_PORT=29500
47 | IP_FILE="${OUTPUT_DIR}/ip_port_list.txt"
48 | > "$IP_FILE"
49 |
50 | for i in $(seq 0 $((MIXED_NODES - 1))); do
51 | node=${NODELIST[$i]}
52 | ip=$(srun --nodes=1 --ntasks=1 -w "$node" hostname --ip-address)
53 | echo "${ip}:$((PORT_BASE))" >> "$IP_FILE"
54 | done
55 |
56 | SERVER_IP=$(cut -d: -f1 $IP_FILE | paste -sd,)
57 | SERVER_PORT=$(cut -d: -f2 $IP_FILE | paste -sd,)
58 |
59 | TRAIN_CMD="open_r1/grpo.py \
60 | --output_dir ${OUTPUT_DIR} \
61 | --model_name_or_path ${MODEL_PATH} \
62 | --dataset_name ${DATA_PATH} \
63 | --deepspeed ./local_scripts/zero3.json \
64 | --max_prompt_length 2048 \
65 | --max_completion_length 1024 \
66 | --custom_per_device_train_batch_size 1 \
67 | --gradient_accumulation_steps 2 \
68 | --learning_rate 3e-7 \
69 | --logging_steps 1 \
70 | --use_vllm true \
71 | --vllm_server_host ${SERVER_IP} \
72 | --vllm_server_port ${SERVER_PORT} \
73 | --vllm_server_timeout 600 \
74 | --vllm_locate_same_node true\
75 | --vllm_locate_same_remain_gpus 4\
76 | --bf16 \
77 | --report_to wandb \
78 | --gradient_checkpointing true \
79 | --max_pixels ${MAX_PIXELS} \
80 | --temperature 1 \
81 | --top_p 0.9 \
82 | --top_k 50 \
83 | --num_train_epochs 1 \
84 | --run_name ${RUN_NAME} \
85 | --save_steps 100 \
86 | --num_generations 8 \
87 | --num_iterations 1 \
88 | --beta 0.0 \
89 | --save_only_model true \
90 | --seed 42"
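# In this fused setup the first ${MIXED_NODES} nodes split their 8 GPUs: GPUs 4-7 serve vLLM and
# GPUs 0-3 join training, while the remaining nodes train on all 8 GPUs (see the launch functions below).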
91 |
92 | # ---------- Functions ----------
93 | launch_mixed_node() {
94 | local i=$1
95 | local node=$2
96 | local log_file="${OUTPUT_DIR}/vllm_node_${i}_${node}.log"
97 | srun --nodes=1 --ntasks=1 -w "$node" bash -c "
98 | export RANK=$i
99 | export NODE_RANK=$i
100 | # vLLM: GPUs 4-7
101 | CUDA_VISIBLE_DEVICES=4,5,6,7 python src/vllm_server_v2.py \
102 | --model '${MODEL_PATH}' \
103 | --gpu_memory_utilization 0.85 \
104 |           --enable_prefix_caching true \
105 | --dtype 'bfloat16' \
106 | --max_model_len 4096 \
107 | --tensor_parallel_size ${TP_SIZE} \
108 | --host '0.0.0.0' \
109 | --port ${PORT_BASE} > ${log_file} 2>&1 &
110 |
111 | # Training: GPUs 0-3
112 | CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nnodes ${NUM_NODES} --nproc_per_node 4 \
113 | --node_rank \$NODE_RANK \
114 | --rdzv_id grpo_run \
115 | --rdzv_backend c10d \
116 | --rdzv_endpoint ${HEAD_NODE_IP}:${RDZV_PORT} \
117 | ${TRAIN_CMD} &
118 | wait
119 | " &
120 | }
121 |
122 | launch_training_node() {
123 | local i=$1
124 | local node=$2
125 | srun --nodes=1 --ntasks=1 -w "$node" bash -c "
126 | export NODE_RANK=$i
127 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nnodes ${NUM_NODES} --nproc_per_node 8 \
128 | --node_rank \$NODE_RANK \
129 | --rdzv_id grpo_run \
130 | --rdzv_backend c10d \
131 | --rdzv_endpoint ${HEAD_NODE_IP}:${RDZV_PORT} \
132 | ${TRAIN_CMD}
133 | " &
134 | }
135 |
136 | # ---------- Main Launcher ----------
137 | for i in "${!NODELIST[@]}"; do
138 | if [ $i -lt ${MIXED_NODES} ]; then
139 | launch_mixed_node $i "${NODELIST[$i]}"
140 | else
141 | launch_training_node $i "${NODELIST[$i]}"
142 | fi
143 | done
144 |
145 | wait
146 |
--------------------------------------------------------------------------------
/scripts/grpo/train_gh200.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | #SBATCH --job-name=7B_GH200_zero_lr2x_fp8
5 | #SBATCH --time=12:00:00
6 |
7 | #SBATCH --exclude=nid006792,nid007085
8 |
9 | #SBATCH --nodes=8   # 8 nodes, each has 4x GH200
10 | #SBATCH --ntasks=8 # Total tasks equals total nodes
11 | #SBATCH --ntasks-per-node=1
12 | #SBATCH --gpus-per-node=4
13 | #SBATCH --cpus-per-task=288 # fixed for GH200
14 |
15 |
16 | #SBATCH --partition=normal
17 | #SBATCH --output=RL_gh200_%j_%N.out
18 | #SBATCH --mail-user="zychen.uestc@gmail.com" --mail-type=ALL
19 |
20 |
21 | set -x
22 | # ---------- Environment Setup ----------
23 | export NCCL_ASYNC_ERROR_HANDLING=1
24 | export DEBUG_MODE=True
25 | export WANDB_PROJECT=RL4SGG
26 |
27 |
28 | GPUS_PER_NODE=4
29 | GROUP_SIZE=8
30 | MODEL_PATH="Qwen/Qwen2-VL-7B-Instruct"
31 | DATA_PATH="JosephZ/vg150_train_sgg_prompt"
32 | RUN_NAME="qwen2vl-7b-grpo-2k-lr6e-7-g8-n1-bs32-fp8-gh200"
33 | export OUTPUT_DIR="${SCRATCH}/models/${RUN_NAME}"
34 | mkdir -p "$OUTPUT_DIR"
35 |
36 | export LOG_PATH=${OUTPUT_DIR}/debug.log
37 |
38 | export STRICT_FORMAT=True
39 |
40 | MAX_PIXELS=$((512 * 28 * 28))
41 |
42 |
43 | MASTER_PORT=29500
44 |
45 | NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))
46 | NUM_TRAIN_NODES=${#NODELIST[@]}
47 | TRAIN_NODES_LIST=("${NODELIST[@]:0:$NUM_TRAIN_NODES}")
48 |
49 | # Choose the first training node as the rendezvous head node
50 | ##HEAD_NODE=${TRAIN_NODES_LIST[0]}
51 | ##HEAD_NODE_IP=$(srun --nodes=1 --ntasks=1 -w "$HEAD_NODE" hostname --ip-address)
52 | ##echo "Head Node IP: $HEAD_NODE_IP"
53 |
54 | MASTER_ADDR=$(scontrol show hostnames $SLURM_NODELIST | head -n 1)
55 | echo "MASTER_ADDR: $MASTER_ADDR"
56 |
57 |
58 |
59 | # GH200 has a very high bandwidth between CPU and GPU, we should use it!
60 | # zero2:
61 | # bsz_per_device=16: OOM; OK with CPU offload for the optimizer, ~60h with 3x GPUs
62 | # bsz_per_device=8: 386s for 30 steps, ~60h with 3x GPUs
63 | # bsz_per_device=16: ~40h with 4x GPUs
64 | #
65 | # batch size: PER_DEVICE(8) * ACC(1) * GPU(4) * NODE(8) // GROUP_SIZE(8) = 32
66 | TRAIN_CMD="open_r1/grpo.py \
67 | --task_type sgg \
68 | --use_fp8 true \
69 | --output_dir ${OUTPUT_DIR} \
70 | --model_name_or_path ${MODEL_PATH} \
71 | --dataset_name ${DATA_PATH} \
72 | --max_prompt_length 2048 \
73 | --max_completion_length 1024 \
74 | --custom_per_device_train_batch_size 8 \
75 | --deepspeed ./local_scripts/zero2_offload.json \
76 | --gradient_accumulation_steps 1 \
77 | --learning_rate 6e-7 \
78 | --logging_steps 1 \
79 | --use_vllm true \
80 | --use_local_vllm true\
81 | --bf16 true\
82 | --tf32 true\
83 | --report_to wandb \
84 | --gradient_checkpointing true \
85 | --max_pixels ${MAX_PIXELS} \
86 | --temperature 1 \
87 | --top_p 0.9 \
88 | --top_k 50 \
89 | --num_train_epochs 1 \
90 | --run_name ${RUN_NAME} \
91 | --save_steps 100 \
92 | --num_generations ${GROUP_SIZE} \
93 | --num_iterations 1 \
94 | --beta 0.0 \
95 | --vllm_max_model_len 4096 \
96 | --vllm_gpu_memory_utilization 0.2 \
97 | --ddp_timeout 3600 \
98 | --save_only_model false"
99 |
100 |
101 | echo "start training with CMD=${TRAIN_CMD} ..."
102 |
103 | WORLD_SIZE=$(($GPUS_PER_NODE*$NUM_TRAIN_NODES))
104 |
105 |
106 | LAUNCHER="accelerate launch \
107 | --multi_gpu \
108 | --num_machines $NUM_TRAIN_NODES \
109 | --num_processes $WORLD_SIZE \
110 | --main_process_ip "$MASTER_ADDR" \
111 | --main_process_port $MASTER_PORT \
113 | --machine_rank $SLURM_PROCID \
114 | --role $SLURMD_NODENAME: \
115 | --rdzv_conf rdzv_backend=c10d \
116 | --rdzv_timeout 3600 \
117 | --max_restarts 0 \
118 | --tee 3 \
119 | --mixed_precision fp8 \
120 | "
121 |
122 |
123 | srun --jobid $SLURM_JOB_ID bash -c "$LAUNCHER $TRAIN_CMD"
124 |
125 | #srun torchrun --nnodes ${NUM_TRAIN_NODES} --nproc_per_node ${GPUS_PER_NODE} \
126 | # --node_rank ${SLURM_NODEID} \
127 | # --rdzv_id $RANDOM \
128 | # --rdzv_backend c10d \
129 | # --rdzv_endpoint ${MASTER_ADDR}:${MASTER_PORT} \
130 | # --rdzv_timeout 3600 \
131 | # ${TRAIN_CMD}
132 |
--------------------------------------------------------------------------------
/scripts/grpo/train_gh200_2B.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | #SBATCH --job-name=2B_bs32_gh200_lr6e-7_debug
5 | #SBATCH --time=00:20:00
6 |
7 | #SBATCH --nodes=1    # 1 node with 4x GH200
8 | #SBATCH --ntasks=1 # Total tasks equals total nodes
9 | #SBATCH --ntasks-per-node=1
10 | #SBATCH --gpus-per-node=4
11 | #SBATCH --cpus-per-task=288 # fixed for GH200
12 |
13 | #SBATCH --partition=debug
14 | #SBATCH --output=RL_gh200_%j_%N.out
15 | #SBATCH --mail-user="zychen.uestc@gmail.com" --mail-type=ALL
16 |
17 |
18 | # ---------- Environment Setup ----------
19 | export NCCL_ASYNC_ERROR_HANDLING=1
20 | export DEBUG_MODE=True
21 | export WANDB_PROJECT=RL4SGG
22 |
23 |
24 | GPUS_PER_NODE=4
25 | GROUP_SIZE=16
26 | MODEL_PATH="Qwen/Qwen2-VL-2B-Instruct"
27 | DATA_PATH="JosephZ/vg150_train_sgg_prompt"
28 | RUN_NAME="qwen2vl-2b-grpo-debug-n1-sgg-bs16-lr6e-7-gh200"
29 | export OUTPUT_DIR="${SCRATCH}/models/${RUN_NAME}"
30 | mkdir -p "$OUTPUT_DIR"
31 |
32 | MAX_PIXELS=$((512 * 28 * 28))
33 |
34 | REF_MODEL_NAME=$1
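# NOTE: REF_MODEL_NAME is taken from the first CLI argument but is not referenced in TRAIN_CMD below.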
35 | export LOG_PATH=${OUTPUT_DIR}/debug.log
36 |
37 | export STRICT_FORMAT=True
38 |
39 | MASTER_PORT=29500
40 |
41 | NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))
42 | NUM_TRAIN_NODES=${#NODELIST[@]}
43 | TRAIN_NODES_LIST=("${NODELIST[@]:0:$NUM_TRAIN_NODES}")
44 |
45 | # Choose the first training node as the rendezvous head node
46 | HEAD_NODE=${TRAIN_NODES_LIST[0]}
47 | HEAD_NODE_IP=$(srun --nodes=1 --ntasks=1 -w "$HEAD_NODE" hostname --ip-address)
48 | echo "Head Node IP: $HEAD_NODE_IP"
49 |
50 |
51 |
52 | # batch size: PER_DEVICE(16) * ACC(1) * GPU(4) * NODE(1) // GROUP_SIZE(16) = 4
53 | TRAIN_CMD="open_r1/grpo.py \
54 | --output_dir ${OUTPUT_DIR} \
55 | --task_type sgg \
56 | --model_name_or_path ${MODEL_PATH} \
57 | --dataset_name ${DATA_PATH} \
58 | --max_prompt_length 2048 \
59 | --max_completion_length 1024 \
60 | --custom_per_device_train_batch_size 16 \
61 | --deepspeed ./local_scripts/zero2_offload.json \
62 | --gradient_accumulation_steps 1 \
63 | --learning_rate 6e-7 \
64 | --logging_steps 1 \
65 | --use_vllm true \
66 | --use_local_vllm true\
67 | --bf16 true\
68 | --tf32 true\
69 | --report_to wandb \
70 | --gradient_checkpointing true \
71 | --max_pixels ${MAX_PIXELS} \
72 | --temperature 1 \
73 | --top_p 0.9 \
74 | --top_k 50 \
75 | --num_train_epochs 1 \
76 | --run_name ${RUN_NAME} \
77 | --save_steps 100 \
78 | --num_generations ${GROUP_SIZE} \
79 | --num_iterations 1 \
80 | --beta 0.00 \
81 | --vllm_max_model_len 4096 \
82 | --vllm_gpu_memory_utilization 0.2 \
83 | --save_only_model true \
84 | --seed 42"
85 |
86 |
87 | echo "start training..."
88 |
89 | srun --nodes=${NUM_TRAIN_NODES} --nodelist="${TRAIN_NODES_LIST}" \
90 | torchrun --nnodes ${NUM_TRAIN_NODES} --nproc_per_node ${GPUS_PER_NODE} \
91 | --node_rank ${SLURM_NODEID} \
92 | --rdzv_id $RANDOM \
93 | --rdzv_backend c10d \
94 | --rdzv_endpoint ${HEAD_NODE_IP}:${MASTER_PORT} \
95 | ${TRAIN_CMD}
96 |
--------------------------------------------------------------------------------
/scripts/grpo/train_gh200_2B_SFT.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | #SBATCH --job-name=2B_bs32_gh200_SFT_lr6e-7_psg
5 | #SBATCH --time=12:00:00
6 |
7 | #SBATCH --nodes=4    # 4 nodes, each has 4x GH200
8 | #SBATCH --ntasks=4 # Total tasks equals total nodes
9 | #SBATCH --ntasks-per-node=1
10 | #SBATCH --gpus-per-node=4
11 | #SBATCH --cpus-per-task=288 # fixed for GH200
12 |
13 | #SBATCH --partition=normal
14 | #SBATCH --output=RL_gh200_%j_%N.out
15 | #SBATCH --mail-user="zychen.uestc@gmail.com" --mail-type=ALL
16 |
17 |
18 | # ---------- Environment Setup ----------
19 | export NCCL_ASYNC_ERROR_HANDLING=1
20 | export DEBUG_MODE=True
21 | export WANDB_PROJECT=RL4SGG
22 |
23 |
24 | GPUS_PER_NODE=4
25 | GROUP_SIZE=8
26 | #MODEL_PATH="Qwen/Qwen2-VL-2B-Instruct"
27 | MODEL_PATH=$1
28 |
29 | #DATA_PATH="JosephZ/vg150_train_sgg_prompt"
30 | DATA_PATH="JosephZ/psg_train_sg"
31 | RUN_NAME="qwen2vl-2b-sft-grpo-g8-n1-bs32-lr6e-7-psg-merged-gh200"
32 | export OUTPUT_DIR="${SCRATCH}/models/${RUN_NAME}"
33 | mkdir -p "$OUTPUT_DIR"
34 |
35 | export LOG_PATH=${OUTPUT_DIR}/debug.log
36 | export STRICT_FORMAT=True
37 |
38 | MAX_PIXELS=$((512 * 28 * 28))
39 |
40 |
41 | MASTER_PORT=29500
42 |
43 | NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))
44 | NUM_TRAIN_NODES=${#NODELIST[@]}
45 | TRAIN_NODES_LIST=("${NODELIST[@]:0:$NUM_TRAIN_NODES}")
46 |
47 | # Choose the first training node as the rendezvous head node
48 | HEAD_NODE=${TRAIN_NODES_LIST[0]}
49 | HEAD_NODE_IP=$(srun --nodes=1 --ntasks=1 -w "$HEAD_NODE" hostname --ip-address)
50 | echo "Head Node IP: $HEAD_NODE_IP"
51 |
52 |
53 |
54 | # batch size: PER_DEVICE(16) * ACC(1) * GPU(4) * NODE(4) // GROUP_SIZE(8) = 32
55 | TRAIN_CMD="open_r1/grpo.py \
56 | --output_dir ${OUTPUT_DIR} \
57 | --model_name_or_path ${MODEL_PATH} \
58 | --dataset_name ${DATA_PATH} \
59 | --max_prompt_length 2048 \
60 | --max_completion_length 1024 \
61 | --custom_per_device_train_batch_size 16 \
62 | --deepspeed ./local_scripts/zero2_offload.json \
63 | --gradient_accumulation_steps 1 \
64 | --learning_rate 6e-7 \
65 | --logging_steps 1 \
66 | --use_vllm true \
67 | --use_local_vllm true\
68 | --bf16 true\
69 | --tf32 true\
70 | --report_to wandb \
71 | --gradient_checkpointing true \
72 | --max_pixels ${MAX_PIXELS} \
73 | --temperature 1 \
74 | --top_p 0.9 \
75 | --top_k 50 \
76 | --max_steps 2000 \
77 | --run_name ${RUN_NAME} \
78 | --save_steps 100 \
79 | --num_generations ${GROUP_SIZE} \
80 | --num_iterations 1 \
81 | --beta 0.0\
82 | --vllm_max_model_len 4096 \
83 | --vllm_gpu_memory_utilization 0.2 \
84 | --save_only_model true \
85 | --seed 42"
86 |
87 |
88 | echo "start training..."
89 |
90 | srun --nodes=${NUM_TRAIN_NODES} --nodelist="${TRAIN_NODES_LIST}" \
91 | torchrun --nnodes ${NUM_TRAIN_NODES} --nproc_per_node ${GPUS_PER_NODE} \
92 | --node_rank ${SLURM_NODEID} \
93 | --rdzv_id $RANDOM \
94 | --rdzv_backend c10d \
95 | --rdzv_endpoint ${HEAD_NODE_IP}:${MASTER_PORT} \
96 | ${TRAIN_CMD}
97 |
--------------------------------------------------------------------------------
/scripts/grpo/train_gh200_2B_close.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | #SBATCH --job-name=2B_bs32_gh200_sft_close_lr2x
5 | #SBATCH --time=12:00:00
6 |
7 | #SBATCH --nodes=4    # 4 nodes, each has 4x GH200
8 | #SBATCH --ntasks=4 # Total tasks equals total nodes
9 | #SBATCH --ntasks-per-node=1
10 | #SBATCH --gpus-per-node=4
11 | #SBATCH --cpus-per-task=288 # fixed for GH200
12 |
13 | #SBATCH --partition=normal
14 | #SBATCH --output=RL_gh200_%j_%N.out
15 | #SBATCH --mail-user="zychen.uestc@gmail.com" --mail-type=ALL
16 |
17 |
18 | set -x
19 | # ---------- Environment Setup ----------
20 | export NCCL_ASYNC_ERROR_HANDLING=1
21 | export DEBUG_MODE=True
22 | export WANDB_PROJECT=RL4SGG
23 |
24 |
25 | GPUS_PER_NODE=4
26 | GROUP_SIZE=8
27 | MODEL_PATH=$1 #"Qwen/Qwen2-VL-2B-Instruct"
28 | DATA_PATH="JosephZ/vg150_train_sgg_prompt"
29 | RUN_NAME="qwen2vl-2b-sft-close-grpo-g8-n1-bs32-lr6e-7-gh200"
30 | export OUTPUT_DIR="${SCRATCH}/models/${RUN_NAME}"
31 | mkdir -p "$OUTPUT_DIR"
32 |
33 | MAX_PIXELS=$((512 * 28 * 28))
34 |
35 | export LOG_PATH=${OUTPUT_DIR}/debug.log
36 | export STRICT_FORMAT=True
37 |
38 | MASTER_PORT=29500
39 |
40 | NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))
41 | NUM_TRAIN_NODES=${#NODELIST[@]}
42 | TRAIN_NODES_LIST=("${NODELIST[@]:0:$NUM_TRAIN_NODES}")
43 |
44 | # Choose the first training node as the rendezvous head node
45 | HEAD_NODE=${TRAIN_NODES_LIST[0]}
46 | HEAD_NODE_IP=$(srun --nodes=1 --ntasks=1 -w "$HEAD_NODE" hostname --ip-address)
47 | echo "Head Node IP: $HEAD_NODE_IP"
48 |
49 |
50 |
51 | # batch size: PER_DEVICE(16) * ACC(1) * GPU(4) * NODE(4) // GROUP_SIZE(8) = 32
52 | TRAIN_CMD="open_r1/grpo.py \
53 | --output_dir ${OUTPUT_DIR} \
54 | --model_name_or_path ${MODEL_PATH} \
55 | --dataset_name ${DATA_PATH} \
56 | --max_prompt_length 2048 \
57 | --max_completion_length 1024 \
58 | --custom_per_device_train_batch_size 16 \
59 | --deepspeed ./local_scripts/zero2_offload.json \
60 | --gradient_accumulation_steps 1 \
61 | --learning_rate 6e-7 \
62 | --use_predefined_cats true \
63 | --logging_steps 1 \
64 | --use_vllm true \
65 | --use_local_vllm true\
66 | --bf16 true\
67 | --tf32 true\
68 | --report_to wandb \
69 | --gradient_checkpointing true \
70 | --max_pixels ${MAX_PIXELS} \
71 | --temperature 1 \
72 | --top_p 0.9 \
73 | --top_k 50 \
74 | --num_train_epochs 1 \
75 | --run_name ${RUN_NAME} \
76 | --save_steps 100 \
77 | --num_generations ${GROUP_SIZE} \
78 | --num_iterations 1 \
79 | --beta 0.0\
80 | --vllm_max_model_len 4096 \
81 | --vllm_gpu_memory_utilization 0.2 \
82 | --save_only_model true \
83 | --seed 42"
84 |
85 |
86 | echo "start training..."
87 |
88 | srun --nodes=${NUM_TRAIN_NODES} --nodelist="${TRAIN_NODES_LIST}" \
89 | torchrun --nnodes ${NUM_TRAIN_NODES} --nproc_per_node ${GPUS_PER_NODE} \
90 | --node_rank ${SLURM_NODEID} \
91 | --rdzv_id $RANDOM \
92 | --rdzv_backend c10d \
93 | --rdzv_endpoint ${HEAD_NODE_IP}:${MASTER_PORT} \
94 | ${TRAIN_CMD}
95 |
--------------------------------------------------------------------------------
/scripts/grpo/train_gh200_SFT.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | #SBATCH --job-name=7B_GH200_SFT_GRPO_1k_lr2x_psg
5 | #SBATCH --time=12:00:00
6 |
7 | #SBATCH --nodes=8    # 8 nodes, each has 4x GH200
8 | #SBATCH --ntasks=8 # Total tasks equals total nodes
9 | #SBATCH --ntasks-per-node=1
10 | #SBATCH --gpus-per-node=4
11 | #SBATCH --cpus-per-task=288 # fixed for GH200
12 |
13 | #SBATCH --partition=normal
14 | #SBATCH --output=RL_gh200_%j_%N.out
15 | #SBATCH --mail-user="zychen.uestc@gmail.com" --mail-type=ALL
16 |
17 |
18 | set -x
19 | # ---------- Environment Setup ----------
20 | export NCCL_ASYNC_ERROR_HANDLING=1
21 | export DEBUG_MODE=True
22 | export WANDB_PROJECT=RL4SGG
23 |
24 |
25 | GPUS_PER_NODE=4
26 | GROUP_SIZE=8
27 | MODEL_PATH=$1
28 |
29 | #DATA_PATH="JosephZ/vg150_train_sgg_prompt"
30 | DATA_PATH="JosephZ/psg_train_sg"
31 |
32 | RUN_NAME="qwen2vl-7b-sft-grpo-psg-merged-g8-n1-bs32-1k-lr6e-7-gh200"
33 | export OUTPUT_DIR="${SCRATCH}/models/${RUN_NAME}"
34 | mkdir -p "$OUTPUT_DIR"
35 |
36 | export LOG_PATH=${OUTPUT_DIR}/debug.log
37 | export STRICT_FORMAT=True
38 |
39 | MAX_PIXELS=$((512 * 28 * 28))
40 |
41 |
42 | MASTER_PORT=29500
43 |
44 | NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))
45 | NUM_TRAIN_NODES=${#NODELIST[@]}
46 | TRAIN_NODES_LIST=("${NODELIST[@]:0:$NUM_TRAIN_NODES}")
47 |
48 | #HEAD_NODE=${TRAIN_NODES_LIST[0]}
49 | #HEAD_NODE_IP=$(srun --nodes=1 --ntasks=1 -w "$HEAD_NODE" hostname --ip-address)
50 | #echo "Head Node IP: $HEAD_NODE_IP"
51 |
52 | MASTER_ADDR=$(scontrol show hostnames $SLURM_NODELIST | head -n 1)
53 | echo "MASTER_ADDR: $MASTER_ADDR"
54 |
55 |
56 |
57 | # GH200 has a very high bandwidth between CPU and GPU, we should use it!
58 | # zero2:
59 | # bsz_per_device=16: OOM; OK with CPU offload for the optimizer, ~60h with 3x GPUs
60 | # bsz_per_device=8: 386s for 30 steps, ~60h with 3x GPUs
61 | # bsz_per_device=16: ~40h with 4x GPUs
62 | #
63 | # batch size: PER_DEVICE(8) * ACC(1) * GPU(4) * NODE(8) // GROUP_SIZE(8) = 32
64 | TRAIN_CMD="open_r1/grpo.py \
65 | --output_dir ${OUTPUT_DIR} \
66 | --model_name_or_path ${MODEL_PATH} \
67 | --dataset_name ${DATA_PATH} \
68 | --max_prompt_length 2048 \
69 | --max_completion_length 1024 \
70 | --custom_per_device_train_batch_size 8 \
71 | --deepspeed ./local_scripts/zero2_offload.json \
72 | --gradient_accumulation_steps 1 \
73 | --learning_rate 6e-7 \
74 | --logging_steps 1 \
75 | --use_vllm true \
76 | --use_local_vllm true\
77 | --bf16 true\
78 | --tf32 true\
79 | --report_to wandb \
80 | --gradient_checkpointing true \
81 | --max_pixels ${MAX_PIXELS} \
82 | --temperature 1 \
83 | --top_p 0.9 \
84 | --top_k 50 \
85 | --num_train_epochs 1 \
86 | --run_name ${RUN_NAME} \
87 | --save_steps 100 \
88 | --num_generations ${GROUP_SIZE} \
89 | --num_iterations 1 \
90 | --beta 0.0\
91 | --vllm_max_model_len 4096 \
92 | --vllm_gpu_memory_utilization 0.2 \
93 | --save_only_model false"
94 |
95 |
96 | echo "start training with TRAIN_CMD=${TRAIN_CMD} ..."
97 |
98 | srun torchrun --nnodes ${NUM_TRAIN_NODES} --nproc_per_node ${GPUS_PER_NODE} \
99 | --node_rank ${SLURM_NODEID} \
100 | --rdzv_id $RANDOM \
101 | --rdzv_backend c10d \
102 | --rdzv_endpoint ${MASTER_ADDR}:${MASTER_PORT} \
103 | ${TRAIN_CMD}
104 |
--------------------------------------------------------------------------------
/scripts/grpo/train_gh200_close.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | #SBATCH --job-name=7B_GH200_close_lr2x
5 | #SBATCH --time=12:00:00
6 |
7 | #SBATCH --exclude=nid006792
8 | #SBATCH --nodes=8    # 8 nodes, each has 4x GH200
9 | #SBATCH --ntasks=8 # Total tasks equals total nodes
10 | #SBATCH --ntasks-per-node=1
11 | #SBATCH --gpus-per-node=4
12 | #SBATCH --cpus-per-task=288 # fixed for GH200
13 |
14 | #SBATCH --partition=normal
15 | #SBATCH --output=RL_gh200_%j_%N.out
16 | #SBATCH --mail-user="zychen.uestc@gmail.com" --mail-type=ALL
17 |
18 |
19 | set -x
20 | # ---------- Environment Setup ----------
21 | export NCCL_ASYNC_ERROR_HANDLING=1
22 | export DEBUG_MODE=True
23 | export WANDB_PROJECT=RL4SGG
24 |
25 |
26 | GPUS_PER_NODE=4
27 | GROUP_SIZE=8
28 | MODEL_PATH="Qwen/Qwen2-VL-7B-Instruct"
29 | DATA_PATH="JosephZ/vg150_train_sgg_prompt"
30 | RUN_NAME="qwen2vl-7b-close-grpo-g8-n1-bs32-lr6e-7-gh200"
31 | export OUTPUT_DIR="${SCRATCH}/models/7B/${RUN_NAME}"
32 | mkdir -p "$OUTPUT_DIR"
33 |
34 | export STRICT_FORMAT=True
35 | export LOG_PATH=${OUTPUT_DIR}/debug.log
36 |
37 | MAX_PIXELS=$((512 * 28 * 28))
38 |
39 |
40 | MASTER_PORT=29500
41 |
42 | NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))
43 | NUM_TRAIN_NODES=${#NODELIST[@]}
44 | TRAIN_NODES_LIST=("${NODELIST[@]:0:$NUM_TRAIN_NODES}")
45 |
46 |
47 | MASTER_ADDR=$(scontrol show hostnames $SLURM_NODELIST | head -n 1)
48 | echo "MASTER_ADDR: $MASTER_ADDR"
49 |
50 |
51 | # GH200 has a very high bandwidth between CPU and GPU, we should use it!
52 | # zero2:
53 | # bsz_per_device=16: OOM; OK with CPU offload for the optimizer, ~60h with 3x GPUs
54 | # bsz_per_device=8: 386s for 30 steps, ~60h with 3x GPUs
55 | # bsz_per_device=16: ~40h with 4x GPUs
56 | #
57 | # batch size: PER_DEVICE(8) * ACC(1) * GPU(4) * NODE(8) // GROUP_SIZE(8) = 32
58 | TRAIN_CMD="open_r1/grpo.py \
59 | --output_dir ${OUTPUT_DIR} \
60 | --model_name_or_path ${MODEL_PATH} \
61 | --dataset_name ${DATA_PATH} \
62 | --max_prompt_length 2048 \
63 | --max_completion_length 1024 \
64 | --custom_per_device_train_batch_size 8 \
65 | --deepspeed ./local_scripts/zero2_offload.json \
66 | --gradient_accumulation_steps 1 \
67 | --learning_rate 6e-7 \
68 | --use_predefined_cats true \
69 | --logging_steps 1 \
70 | --use_vllm true \
71 | --use_local_vllm true\
72 | --bf16 true\
73 | --tf32 true\
74 | --report_to wandb \
75 | --gradient_checkpointing true \
76 | --max_pixels ${MAX_PIXELS} \
77 | --temperature 1 \
78 | --top_p 0.9 \
79 | --top_k 50 \
80 | --num_train_epochs 1 \
81 | --run_name ${RUN_NAME} \
82 | --save_steps 100 \
83 | --num_generations ${GROUP_SIZE} \
84 | --num_iterations 1 \
85 | --beta 0.0 \
86 | --vllm_max_model_len 4096 \
87 | --vllm_gpu_memory_utilization 0.2 \
88 | --ddp_timeout 3600 \
89 | --save_only_model false"
90 |
91 |
92 | echo "start training with CMD=${TRAIN_CMD} ..."
93 |
94 | srun torchrun --nnodes ${NUM_TRAIN_NODES} --nproc_per_node ${GPUS_PER_NODE} \
95 | --node_rank ${SLURM_NODEID} \
96 | --rdzv_id $RANDOM \
97 | --rdzv_backend c10d \
98 | --rdzv_endpoint ${MASTER_ADDR}:${MASTER_PORT} \
99 | ${TRAIN_CMD}
100 |
--------------------------------------------------------------------------------
/scripts/grpo/train_gh200_close_SFT.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 | #SBATCH --job-name=7B_GH200_SFT_CLOSE_GRPO_lr2x
5 | #SBATCH --time=12:00:00
6 |
7 | #SBATCH --nodes=8 # 8 nodes, each has 4x GH200
8 | #SBATCH --ntasks=8 # Total tasks equals total nodes
9 | #SBATCH --ntasks-per-node=1
10 | #SBATCH --gpus-per-node=4
11 | #SBATCH --cpus-per-task=288 # fixed for GH200
12 |
13 | #SBATCH --partition=normal
14 | #SBATCH --output=RL_gh200_%j_%N.out
15 | #SBATCH --mail-user="zychen.uestc@gmail.com" --mail-type=ALL
16 |
17 |
18 | set -x
19 | # ---------- Environment Setup ----------
20 | export NCCL_ASYNC_ERROR_HANDLING=1
21 | export DEBUG_MODE=True
22 | export WANDB_PROJECT=RL4SGG
23 |
24 |
25 | GPUS_PER_NODE=4
26 | GROUP_SIZE=8
27 | MODEL_PATH=$1
28 |
29 | DATA_PATH="JosephZ/vg150_train_sgg_prompt"
30 | RUN_NAME="qwen2vl-7b-sft-grpo-g8-n1-bs32-close-lr6e-7-gh200"
31 | export OUTPUT_DIR="${SCRATCH}/models/7B/${RUN_NAME}"
32 | mkdir -p "$OUTPUT_DIR"
33 |
34 | export LOG_PATH=${OUTPUT_DIR}/debug.log
35 | export STRICT_FORMAT=True
36 |
37 | MAX_PIXELS=$((512 * 28 * 28))
38 |
39 |
40 | MASTER_PORT=29500
41 |
42 | NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))
43 | NUM_TRAIN_NODES=${#NODELIST[@]}
44 | TRAIN_NODES_LIST=("${NODELIST[@]:0:$NUM_TRAIN_NODES}")
45 |
46 |
47 | MASTER_ADDR=$(scontrol show hostnames $SLURM_NODELIST | head -n 1)
48 | echo "MASTER_ADDR: $MASTER_ADDR"
49 |
50 |
51 |
52 | # GH200 has a very high bandwidth between CPU and GPU, we should use it!
53 | # zero2:
54 | # bsz_per_device=16: OOM; OK with CPU offload for the optimizer, ~60h with 3x GPUs
55 | # bsz_per_device=8: 386s for 30 steps, ~60h with 3x GPUs
56 | # bsz_per_device=16: ~40h with 4x GPUs
57 | #
58 | # batch size: PER_DEVICE(8) * ACC(1) * GPU(4) * NODE(8) // GROUP_SIZE(8) = 32
59 | TRAIN_CMD="open_r1/grpo.py \
60 | --output_dir ${OUTPUT_DIR} \
61 | --model_name_or_path ${MODEL_PATH} \
62 | --dataset_name ${DATA_PATH} \
63 | --max_prompt_length 2048 \
64 | --max_completion_length 1024 \
65 | --custom_per_device_train_batch_size 8 \
66 | --deepspeed ./local_scripts/zero2_offload.json \
67 | --gradient_accumulation_steps 1 \
68 | --learning_rate 6e-7 \
69 | --logging_steps 1 \
70 | --use_vllm true \
71 | --use_local_vllm true\
72 | --bf16 true\
73 | --tf32 true\
74 | --report_to wandb \
75 | --gradient_checkpointing true \
76 | --max_pixels ${MAX_PIXELS} \
77 | --temperature 1 \
78 | --top_p 0.9 \
79 | --top_k 50 \
80 | --num_train_epochs 1 \
81 | --run_name ${RUN_NAME} \
82 | --save_steps 100 \
83 | --num_generations ${GROUP_SIZE} \
84 | --num_iterations 1 \
85 | --beta 0.0\
86 | --vllm_max_model_len 4096 \
87 | --vllm_gpu_memory_utilization 0.2 \
88 | --use_predefined_cats true \
89 | --ddp_timeout 3600 \
90 | --save_only_model false"
91 |
92 |
93 | echo "start training with TRAIN_CMD=${TRAIN_CMD} ..."
94 |
95 | srun torchrun --nnodes ${NUM_TRAIN_NODES} --nproc_per_node ${GPUS_PER_NODE} \
96 | --node_rank ${SLURM_NODEID} \
97 | --rdzv_id $RANDOM \
98 | --rdzv_backend c10d \
99 | --rdzv_endpoint ${MASTER_ADDR}:${MASTER_PORT} \
100 | ${TRAIN_CMD}
101 |
--------------------------------------------------------------------------------
/scripts/grpo/train_zero3.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 |
5 | #SBATCH --job-name=GRPO_train
6 | #SBATCH --time=24:00:00
7 | #SBATCH --nodes=16       # each node has 8x GPUs, all used for training; vLLM servers run in a separate job (see run_vllm.sh)
8 | #SBATCH --ntasks=16 # Total tasks equals total nodes
9 | #SBATCH --ntasks-per-node=1
10 | #SBATCH --gpus-per-node=rtx_4090:8
11 | #SBATCH --cpus-per-task=16
12 | #SBATCH --mem-per-cpu=16000M
13 | #SBATCH --output=RL_%j_%N.out
14 | #SBATCH --mail-user="zychen.uestc@gmail.com" --mail-type=ALL
15 |
16 |
17 | # force crashing on nccl issues like hanging broadcast
18 | export NCCL_ASYNC_ERROR_HANDLING=1
19 | # export NCCL_DEBUG=INFO
20 | # export NCCL_DEBUG_SUBSYS=COLL
21 | # export NCCL_SOCKET_NTHREADS=1
22 | # export NCCL_NSOCKS_PERTHREAD=1
23 | # export CUDA_LAUNCH_BLOCKING=1
24 |
25 | # wait for vLLM servers
26 | #sleep 60
27 |
28 | # Read IPs from file and join them with commas
29 | #ip_str=$(paste -sd, ip_list.txt)
30 | #echo "vLLM servers: $ip_str"
31 |
32 | FILE="ip_port_list.txt"
33 |
34 | SERVER_IP=""
35 | SERVER_PORT=""
36 |
37 | while IFS=: read -r ip port; do
38 | SERVER_IP+="${ip},"
39 | SERVER_PORT+="${port},"
40 | done < "$FILE"
41 |
42 | # Remove trailing commas
43 | SERVER_IP="${SERVER_IP%,}"
44 | SERVER_PORT="${SERVER_PORT%,}"
45 |
46 | echo "SERVER_IP=$SERVER_IP"
47 | echo "SERVER_PORT=$SERVER_PORT"
48 |
49 |
50 | # Define node counts
51 | GPUS_PER_NODE=8
52 |
53 | # Get the list of allocated nodes
54 | NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST))
55 | NUM_TRAIN_NODES=${#NODELIST[@]}
56 |
57 | # Assign training nodes (first NUM_TRAIN_NODES nodes)
58 | TRAIN_NODES=("${NODELIST[@]:0:$NUM_TRAIN_NODES}")
59 |
60 | # Choose the first training node as the rendezvous head node
61 | HEAD_NODE=${TRAIN_NODES[0]}
62 | HEAD_NODE_IP=$(srun --nodes=1 --ntasks=1 -w "$HEAD_NODE" hostname --ip-address)
63 | echo "Head Node IP: $HEAD_NODE_IP"
64 |
65 | echo "environment: $(env | grep NCCL)"
66 |
67 |
68 | # Create a comma-separated list of training nodes for srun
69 | TRAIN_NODES_LIST=$(IFS=, ; echo "${TRAIN_NODES[*]}")
70 |
71 | # Define HOST and PORT for the vLLM server
72 | PORT_A=8888
73 |
74 |
75 | export DEBUG_MODE=True
76 | export WANDB_PROJECT=RL4SGG
77 |
78 | export DATA_PATH="JosephZ/vg150_train_sgg_prompt"
79 | export MODEL_PATH="Qwen/Qwen2-VL-7B-Instruct"
80 |
81 | export NODE_RANK=${SLURM_NODEID} # Provided by SLURM
82 |
83 | MAX_PIXELS=$((512 * 28 * 28))
84 |
85 | # Launch distributed training on the training nodes using 8 GPUs per node
86 | srun --nodes=${NUM_TRAIN_NODES} --nodelist="${TRAIN_NODES_LIST}" \
87 | torchrun --nnodes ${NUM_TRAIN_NODES} --nproc_per_node ${GPUS_PER_NODE} \
88 | --node_rank $NODE_RANK \
89 | --rdzv_id $RANDOM \
90 | --rdzv_backend c10d \
91 | --rdzv_endpoint ${HEAD_NODE_IP}:29500 \
92 | open_r1/grpo.py \
93 | --output_dir models/qwen2vl-nokl-n1-g8 \
94 | --model_name_or_path ${MODEL_PATH} \
95 | --dataset_name $DATA_PATH \
96 | --deepspeed ./local_scripts/zero3.json \
97 | --max_prompt_length 2048 \
98 | --max_completion_length 1024 \
99 | --per_device_train_batch_size 1 \
100 | --gradient_accumulation_steps 1 \
101 | --logging_steps 1 \
102 | --use_vllm true \
103 | --vllm_server_host ${SERVER_IP} \
104 | --vllm_server_port ${SERVER_PORT} \
105 | --vllm_server_timeout 600 \
106 | --bf16 \
107 | --report_to wandb \
108 | --gradient_checkpointing true \
109 | --max_pixels ${MAX_PIXELS} \
110 | --temperature 0.3 \
111 | --top_p 0.001 \
112 | --top_k 1 \
113 | --num_train_epochs 1 \
114 | --run_name Qwen2VL-7B-GRPO-nokl-n1-G8 \
115 | --save_steps 100 \
116 | --num_generations 8 \
117 | --num_iterations 1 \
118 | --beta 0.0
119 |
--------------------------------------------------------------------------------
/scripts/inference/run_sgg_inference.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 |
5 | export GPUS_PER_NODE=4
6 |
7 |
8 | DATASET=$1
9 | MODEL_NAME=$2
10 | OUTPUT_DIR=$3
11 | USE_CATS=$4 # true/false
12 | PROMPT_TYPE=$5 # true/false
13 |
14 | BATCH_SIZE=${6:-8}
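15 | # Illustrative invocation (placeholder arguments, not a recorded command from the repo):
15 | #   bash scripts/inference/run_sgg_inference.sh <hf_dataset> <model_or_checkpoint> <output_dir> true false 8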
15 |
16 | echo "MODEL_NAME: $MODEL_NAME, OUTPUT_DIR: $OUTPUT_DIR"
17 | echo "USE_CATS: $USE_CATS, PROMPT_TYPE: $PROMPT_TYPE"
18 |
19 | ARGS="--dataset $DATASET --model $MODEL_NAME --output_dir $OUTPUT_DIR --max_model_len 4096 --batch_size $BATCH_SIZE"
20 |
21 |
22 | if [ "$PROMPT_TYPE" == "true" ]; then
23 | ARGS="$ARGS --use_think_system_prompt"
24 | fi
25 |
26 | if [ "$USE_CATS" == "true" ]; then
27 | ARGS="$ARGS --use_predefined_cats"
28 | fi
29 |
30 | echo "ARGS:$ARGS"
31 |
32 | torchrun --nnodes 1 \
33 | --nproc_per_node $GPUS_PER_NODE \
34 | --node_rank 0 \
35 | src/sgg_inference_vllm.py -- $ARGS
36 |
--------------------------------------------------------------------------------
/scripts/sft/2B_sgg.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 |
5 | #SBATCH --job-name=SFT_2B_psg
6 | #SBATCH --time=12:00:00
7 |
8 | # 4x A100
9 |
10 | #SBATCH --nodes=8
11 | #SBATCH --ntasks-per-node=1
12 | #SBATCH --gpus-per-node=4
13 | #SBATCH --cpus-per-task=288
14 |
15 | #SBATCH --mail-user="zychen.uestc@gmail.com"
16 | #SBATCH --mail-type=ALL
17 | #SBATCH --output=SFT-2B_%j_%N.out
18 |
19 | # Get node list and determine head node
20 | nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) )
21 | head_node=${nodes[0]}
22 | head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
23 |
24 |
25 | #DATASET=JosephZ/vg150_train_sgg_prompt
26 | DATASET=JosephZ/psg_train_sg
27 |
28 |
29 | echo "Head Node IP: $head_node_ip"
30 |
31 | # Set NODE_RANK from SLURM environment variable
32 | export NODE_RANK=${SLURM_NODEID}
33 |
34 | export GPUS_PER_NODE=4
35 |
36 | export WANDB_PROJECT=RL4SGG
37 |
38 | RUN_NAME="qwen2vl-2b-sft-open-psg-bs128-e6-merged-gh200"
39 | export OUTPUT_DIR="${SCRATCH}/models/${RUN_NAME}"
40 | mkdir -p "$OUTPUT_DIR"
41 |
42 |
43 | # global batch size = 4 (per device) * 1 (grad accum) * 32 (8 nodes x 4 GPUs) = 128
44 | srun torchrun --nnodes ${SLURM_NNODES} \
45 | --nproc_per_node $GPUS_PER_NODE \
46 | --node_rank $NODE_RANK \
47 | --rdzv_id $RANDOM \
48 | --rdzv_backend c10d \
49 | --rdzv_endpoint ${head_node_ip}:29500 \
50 | src/sft_sgg.py \
51 | --model_name_or_path Qwen/Qwen2-VL-2B-Instruct \
52 | --dataset_name $DATASET \
53 | --learning_rate 1e-5 \
54 |     --per_device_train_batch_size 4 \
55 |     --gradient_accumulation_steps 1 \
56 | --warmup_ratio 0.05 \
57 | --max_grad_norm 0.3 \
58 | --logging_steps 1 \
59 |     --bf16 true \
60 |     --tf32 true \
61 | --report_to wandb \
62 | --attn_implementation flash_attention_2 \
63 | --num_train_epochs 6 \
64 | --run_name $RUN_NAME \
65 | --save_steps 500 \
66 | --save_only_model true \
67 | --torch_dtype bfloat16 \
68 | --fsdp "full_shard auto_wrap" \
69 | --fsdp_config local_scripts/fsdp_config.json \
70 | --output_dir $OUTPUT_DIR \
71 | --seed 42
72 |
--------------------------------------------------------------------------------
/scripts/sft/2B_sgg_predefined.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 |
5 | #SBATCH --job-name=SFT_2B_close
6 | #SBATCH --time=24:00:00
7 |
8 | # 4x A100
9 |
10 | #SBATCH --nodes=1
11 | #SBATCH --ntasks-per-node=1
12 | #SBATCH --gpus-per-node=4
13 | #SBATCH --cpus-per-task=128
14 |
15 | #SBATCH --mail-user="zychen.uestc@gmail.com"
16 | #SBATCH --mail-type=ALL
17 | #SBATCH --output=SFT-2B_%j_%N.out
18 |
19 | # Get node list and determine head node
20 | nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) )
21 | head_node=${nodes[0]}
22 | head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
23 |
24 | echo "Head Node IP: $head_node_ip"
25 |
26 | # Set NODE_RANK from SLURM environment variable
27 | export NODE_RANK=${SLURM_NODEID}
28 |
29 | export GPUS_PER_NODE=4
30 |
31 | export WANDB_PROJECT=RL4SGG
32 |
33 |
34 | # global batch size = 4 (per device) * 8 (grad accum) * 4 (GPUs) = 128
35 | srun torchrun --nnodes ${SLURM_NNODES} \
36 | --nproc_per_node $GPUS_PER_NODE \
37 | --node_rank $NODE_RANK \
38 | --rdzv_id $RANDOM \
39 | --rdzv_backend c10d \
40 | --rdzv_endpoint ${head_node_ip}:29500 \
41 | src/sft_sgg.py \
42 | --model_name_or_path Qwen/Qwen2-VL-2B-Instruct \
43 | --dataset_name JosephZ/vg150_train_sgg_prompt \
44 | --learning_rate 1e-5 \
45 |     --per_device_train_batch_size 4 \
46 |     --gradient_accumulation_steps 8 \
47 | --warmup_ratio 0.05 \
48 | --max_grad_norm 0.3 \
49 | --logging_steps 1 \
50 |     --bf16 true \
51 |     --tf32 true \
52 | --report_to wandb \
53 | --attn_implementation flash_attention_2 \
54 | --num_train_epochs 3 \
55 | --run_name Qwen2-VL-2B_vg150_sgg_b128_predefined_e3 \
56 | --save_steps 100 \
57 | --save_only_model true \
58 | --torch_dtype bfloat16 \
59 | --fsdp "full_shard auto_wrap" \
60 | --fsdp_config local_scripts/fsdp_config.json \
61 | --use_predefined_cats true \
62 | --output_dir models/qwen2vl-2b-sft-vg150-b128-predefined-e3 \
63 | --seed 42
64 |
--------------------------------------------------------------------------------
/scripts/sft/7B_sgg.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 |
5 | #SBATCH --job-name=SFT_7B_psg
6 | #SBATCH --time=12:00:00
7 |
8 | # 4x A100
9 |
10 | #SBATCH --nodes=8
11 | #SBATCH --ntasks=8
12 | #SBATCH --ntasks-per-node=1
13 | #SBATCH --gpus-per-node=4
14 | #SBATCH --cpus-per-task=288
15 |
16 | #SBATCH --mail-user="zychen.uestc@gmail.com"
17 | #SBATCH --mail-type=ALL
18 | #SBATCH --output=SFT-7B_%j_%N.out
19 |
20 |
21 | set -x
22 | MASTER_ADDR=$(scontrol show hostnames $SLURM_NODELIST | head -n 1)
23 | echo "MASTER_ADDR:$MASTER_ADDR"
24 |
25 |
26 | # Set NODE_RANK from SLURM environment variable
27 | export NODE_RANK=${SLURM_NODEID}
28 |
29 | export GPUS_PER_NODE=4
30 |
31 | export WANDB_PROJECT=RL4SGG
32 |
33 | DATASET=JosephZ/psg_train_sg
34 |
35 | RUN_NAME="qwen2vl-7b-sft-open-psg-bs128-e3-merged-gh200"
36 | export OUTPUT_DIR="${SCRATCH}/models/${RUN_NAME}"
37 | mkdir -p "$OUTPUT_DIR"
38 |
39 | # global batch size = 4 (per device) * 1 (grad accum) * 32 (8 nodes x 4 GPUs) = 128
40 | srun torchrun --nnodes ${SLURM_NNODES} \
41 | --nproc_per_node $GPUS_PER_NODE \
42 | --node_rank $NODE_RANK \
43 | --rdzv_id $RANDOM \
44 | --rdzv_backend c10d \
45 | --rdzv_endpoint ${MASTER_ADDR}:29500 \
46 | src/sft_sgg.py \
47 | --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
48 | --dataset_name $DATASET \
49 | --learning_rate 1e-5 \
50 | --per_device_train_batch_size 4 \
51 | --gradient_accumulation_steps 1 \
52 | --warmup_ratio 0.05 \
53 | --max_grad_norm 0.3 \
54 | --logging_steps 1 \
55 |     --bf16 true \
56 |     --tf32 true \
57 | --report_to wandb \
58 | --attn_implementation flash_attention_2 \
59 | --num_train_epochs 3 \
60 | --run_name $RUN_NAME \
61 | --save_steps 100 \
62 | --save_only_model true \
63 | --torch_dtype bfloat16 \
64 | --fsdp "full_shard auto_wrap" \
65 | --fsdp_config local_scripts/fsdp_config.json \
66 | --output_dir $OUTPUT_DIR \
67 | --seed 42
68 |
69 |
--------------------------------------------------------------------------------
/scripts/sft/7B_sgg_lora.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 |
5 | #SBATCH --job-name=SFT_7B_lora
6 | #SBATCH --time=12:00:00
7 |
8 | # 4x A100
9 |
10 | #SBATCH --nodes=4
11 | #SBATCH --ntasks=4
12 | #SBATCH --ntasks-per-node=1
13 | #SBATCH --gpus-per-node=4
14 | #SBATCH --cpus-per-task=288
15 |
16 | #SBATCH --mail-user="zychen.uestc@gmail.com"
17 | #SBATCH --mail-type=ALL
18 | #SBATCH --output=SFT-7B_%j_%N.out
19 |
20 |
21 | set -x
22 | MASTER_ADDR=$(scontrol show hostnames $SLURM_NODELIST | head -n 1)
23 | echo "MASTER_ADDR:$MASTER_ADDR"
24 |
25 |
26 | # Set NODE_RANK from SLURM environment variable
27 | export NODE_RANK=${SLURM_NODEID}
28 |
29 | export GPUS_PER_NODE=4
30 |
31 | export WANDB_PROJECT=RL4SGG
32 |
33 | DATASET=JosephZ/vg150_train_sgg_prompt
34 |
35 | RUN_NAME="qwen2vl-7b-sft-open-vg150-lora-bs128-gh200"
36 | export OUTPUT_DIR="${SCRATCH}/models/${RUN_NAME}"
37 | mkdir -p "$OUTPUT_DIR"
38 |
39 | # global batch size = 8 (per device) * 1 (grad accum) * 16 (4 nodes x 4 GPUs) = 128
40 | srun torchrun --nnodes ${SLURM_NNODES} \
41 | --nproc_per_node $GPUS_PER_NODE \
42 | --node_rank $NODE_RANK \
43 | --rdzv_id $RANDOM \
44 | --rdzv_backend c10d \
45 | --rdzv_endpoint ${MASTER_ADDR}:29500 \
46 | src/sft_sgg.py \
47 | --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
48 | --dataset_name $DATASET \
49 | --learning_rate 1e-5 \
50 | --per_device_train_batch_size 8 \
51 | --gradient_accumulation_steps 1 \
52 | --warmup_ratio 0.05 \
53 | --max_grad_norm 0.3 \
54 | --logging_steps 1 \
55 |     --bf16 true \
56 |     --tf32 true \
57 | --report_to wandb \
58 | --attn_implementation flash_attention_2 \
59 | --num_train_epochs 3 \
60 | --run_name $RUN_NAME \
61 | --save_steps 100 \
62 | --save_only_model true \
63 | --torch_dtype bfloat16 \
64 | --lora_r 16 \
65 | --lora_alpha 32 \
66 | --lora_dropout 0.05 \
67 | --fsdp "full_shard auto_wrap" \
68 | --fsdp_config local_scripts/fsdp_config.json \
69 | --output_dir $OUTPUT_DIR \
70 | --seed 42
71 |
72 |
--------------------------------------------------------------------------------
/scripts/sft/7B_sgg_predefined.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 |
5 | #SBATCH --job-name=SFT_7B_close
6 | #SBATCH --time=24:00:00
7 |
8 | # 4x A100
9 |
10 | #SBATCH --nodes=1
11 | #SBATCH --ntasks-per-node=1
12 | #SBATCH --gpus-per-node=4
13 | #SBATCH --cpus-per-task=128
14 |
15 | #SBATCH --mail-user="zychen.uestc@gmail.com"
16 | #SBATCH --mail-type=ALL
17 | #SBATCH --output=SFT-7B-close_%j_%N.out
18 |
19 | # Get node list and determine head node
20 | nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) )
21 | head_node=${nodes[0]}
22 | head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
23 |
24 | echo "Head Node IP: $head_node_ip"
25 |
26 | # Set NODE_RANK from SLURM environment variable
27 | export NODE_RANK=${SLURM_NODEID}
28 |
29 | export GPUS_PER_NODE=4
30 |
31 | export WANDB_PROJECT=RL4SGG
32 |
33 |
34 | # global batch size = 2 (per device) * 16 (grad accum) * 4 (GPUs) = 128
35 | srun torchrun --nnodes ${SLURM_NNODES} \
36 | --nproc_per_node $GPUS_PER_NODE \
37 | --node_rank $NODE_RANK \
38 | --rdzv_id $RANDOM \
39 | --rdzv_backend c10d \
40 | --rdzv_endpoint ${head_node_ip}:29500 \
41 | src/sft_sgg.py \
42 | --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
43 | --dataset_name JosephZ/vg150_train_sgg_prompt \
44 | --learning_rate 1e-5 \
45 | --per_device_train_batch_size 2 \
46 | --gradient_accumulation_steps 16 \
47 | --warmup_ratio 0.05 \
48 | --max_grad_norm 0.3 \
49 | --logging_steps 1 \
50 |     --bf16 true \
51 |     --tf32 true \
52 | --report_to wandb \
53 | --attn_implementation flash_attention_2 \
54 | --num_train_epochs 3 \
55 | --run_name Qwen2-VL-7B_vg150_sgg_b128_predefined_e3 \
56 | --save_steps 100 \
57 | --save_only_model true \
58 | --torch_dtype bfloat16 \
59 | --fsdp "full_shard auto_wrap" \
60 | --fsdp_config local_scripts/fsdp_config.json \
61 | --use_predefined_cats true \
62 | --output_dir models/qwen2vl-7b-sft-vg150-b128-predefined-e3 \
63 | --seed 42
64 |
65 |
--------------------------------------------------------------------------------
/scripts/sft_local/2B_sgg.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 |
5 |
6 | export TORCH_DISTRIBUTED_DEBUG=INFO
7 | export NCCL_DEBUG=INFO
8 |
9 |
10 | export GPUS_PER_NODE=4
11 | export WANDB_PROJECT=RL4SGG
12 |
13 |
14 | # global batch size = 4 (per device) * 8 (grad accum) * 4 (GPUs) = 128
15 | torchrun --nnodes 1 \
16 | --nproc_per_node $GPUS_PER_NODE \
17 | --node_rank 0 \
18 | src/sft_sgg.py \
19 | --model_name_or_path Qwen/Qwen2-VL-2B-Instruct \
20 | --dataset_name JosephZ/vg150_train_sgg_prompt \
21 | --learning_rate 1e-5 \
22 |     --per_device_train_batch_size 4 \
23 |     --gradient_accumulation_steps 8 \
24 | --warmup_ratio 0.05 \
25 | --max_grad_norm 0.3 \
26 | --logging_steps 1 \
27 |     --bf16 true \
28 |     --tf32 true \
29 | --report_to wandb \
30 | --attn_implementation flash_attention_2 \
31 | --num_train_epochs 3 \
32 | --run_name Qwen2-VL-2B_vg150_sgg_b128_open_e3 \
33 | --save_steps 100 \
34 | --save_only_model true \
35 | --torch_dtype bfloat16 \
36 | --fsdp "full_shard auto_wrap" \
37 | --fsdp_config local_scripts/fsdp_config.json \
38 | --output_dir models/qwen2vl-2b-sft-vg150-b128-open-e3 \
39 | --seed 42
40 |
--------------------------------------------------------------------------------
/scripts/sft_local/2B_sgg_predefined.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 |
5 |
6 | export GPUS_PER_NODE=4
7 |
8 | export WANDB_PROJECT=RL4SGG
9 |
10 |
11 | # global batch size = 4 (per device) * 8 (grad accum) * 4 (GPUs) = 128
12 | torchrun --nnodes 1 \
13 | --nproc_per_node $GPUS_PER_NODE \
14 | --node_rank 0 \
15 | src/sft_sgg.py \
16 | --model_name_or_path Qwen/Qwen2-VL-2B-Instruct \
17 | --dataset_name JosephZ/vg150_train_sgg_prompt \
18 | --learning_rate 1e-5 \
19 |     --per_device_train_batch_size 4 \
20 |     --gradient_accumulation_steps 8 \
21 | --warmup_ratio 0.05 \
22 | --max_grad_norm 0.3 \
23 | --logging_steps 1 \
24 |     --bf16 true \
25 |     --tf32 true \
26 | --report_to wandb \
27 | --attn_implementation flash_attention_2 \
28 | --num_train_epochs 3 \
29 | --run_name Qwen2-VL-2B_vg150_sgg_b128_predefined_e3 \
30 | --save_steps 100 \
31 | --save_only_model true \
32 | --torch_dtype bfloat16 \
33 | --fsdp "full_shard auto_wrap" \
34 | --fsdp_config local_scripts/fsdp_config.json \
35 | --use_predefined_cats true \
36 | --output_dir models/qwen2vl-2b-sft-vg150-b128-predefined-e3 \
37 | --seed 42
38 |
--------------------------------------------------------------------------------
/scripts/sft_local/7B_sgg.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 |
5 |
6 |
7 | export GPUS_PER_NODE=4
8 |
9 | export WANDB_PROJECT=RL4SGG
10 |
11 |
12 | # global batch size = 2 (per device) * 16 (grad accum) * 4 (GPUs) = 128
13 | torchrun --nnodes 1 \
14 | --nproc_per_node $GPUS_PER_NODE \
15 | --node_rank 0 \
16 | src/sft_sgg.py \
17 | --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
18 | --dataset_name JosephZ/vg150_train_sgg_prompt \
19 | --learning_rate 1e-5 \
20 | --per_device_train_batch_size 2 \
21 | --gradient_accumulation_steps 16 \
22 | --warmup_ratio 0.05 \
23 | --max_grad_norm 0.3 \
24 | --logging_steps 1 \
25 |     --bf16 true \
26 |     --tf32 true \
27 | --report_to wandb \
28 | --attn_implementation flash_attention_2 \
29 | --num_train_epochs 3 \
30 | --run_name Qwen2-VL-7B_vg150_sgg_b128_open_e3 \
31 | --save_steps 100 \
32 | --save_only_model true \
33 | --torch_dtype bfloat16 \
34 | --fsdp "full_shard auto_wrap" \
35 | --fsdp_config local_scripts/fsdp_config.json \
36 | --output_dir models/qwen2vl-7b-sft-vg150-b128-open-e3 \
37 | --seed 42
38 |
39 |
--------------------------------------------------------------------------------
/scripts/sft_local/7B_sgg_predefined.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 |
4 |
5 |
6 | export GPUS_PER_NODE=4
7 |
8 | export WANDB_PROJECT=RL4SGG
9 |
10 |
11 | # global batch size = 2 (per device) * 16 (grad accum) * 4 (GPUs) = 128
12 | torchrun --nnodes 1 \
13 | --nproc_per_node $GPUS_PER_NODE \
14 | --node_rank 0 \
15 | src/sft_sgg.py \
16 | --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \
17 | --dataset_name JosephZ/vg150_train_sgg_prompt \
18 | --learning_rate 1e-5 \
19 | --per_device_train_batch_size 2 \
20 | --gradient_accumulation_steps 16 \
21 | --warmup_ratio 0.05 \
22 | --max_grad_norm 0.3 \
23 | --logging_steps 1 \
24 | --bf16 true\
25 | --tf32 true\
26 | --report_to wandb \
27 | --attn_implementation flash_attention_2 \
28 | --num_train_epochs 3 \
29 | --run_name Qwen2-VL-7B_vg150_sgg_b128_predefined_e3 \
30 | --save_steps 100 \
31 | --save_only_model true \
32 | --torch_dtype bfloat16 \
33 | --fsdp "full_shard auto_wrap" \
34 | --fsdp_config local_scripts/fsdp_config.json \
35 | --use_predefined_cats true \
36 | --output_dir models/qwen2vl-7b-sft-vg150-b128-predefined-e3 \
37 | --seed 42
38 |
39 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name="open_r1",
5 | version="0.1",
6 | packages=find_packages(),
7 | )
8 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gpt4vision/R1-SGG/e4de64d4c4c97edec648021d012198b21a9b1864/src/__init__.py
--------------------------------------------------------------------------------
/src/psg_categories.json:
--------------------------------------------------------------------------------
1 | {"thing_classes": ["person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"], "stuff_classes": ["banner", "blanket", "bridge", "cardboard", "counter", "curtain", "door-stuff", "floor-wood", "flower", "fruit", "gravel", "house", "light", "mirror-stuff", "net", "pillow", "platform", "playingfield", "railroad", "river", "road", "roof", "sand", "sea", "shelf", "snow", "stairs", "tent", "towel", "wall-brick", "wall-stone", "wall-tile", "wall-wood", "water-other", "window-blind", "window-other", "tree-merged", "fence-merged", "ceiling-merged", "sky-other-merged", "cabinet-merged", "table-merged", "floor-other-merged", "pavement-merged", "mountain-merged", "grass-merged", "dirt-merged", "paper-merged", "food-other-merged", "building-other-merged", "rock-merged", "wall-other-merged", "rug-merged"], "predicate_classes": ["over", "in front of", "beside", "on", "in", "attached to", "hanging from", "on back of", "falling off", "going down", "painted on", "walking on", "running on", "crossing", "standing on", "lying on", "sitting on", "flying over", "jumping over", "jumping from", "wearing", "holding", "carrying", "looking at", "guiding", "kissing", "eating", "drinking", "feeding", "biting", "catching", "picking", "playing with", "chasing", "climbing", "cleaning", "playing", "touching", "pushing", "pulling", "opening", "cooking", "talking to", "throwing", "slicing", "driving", "riding", "parked on", "driving on", "about to hit", "kicking", "swinging", "entering", "exiting", "enclosing", "leaning on"]}
--------------------------------------------------------------------------------
/src/sgg_inference_vllm.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import re
4 | import torch
5 | import glob
6 | import argparse
7 | from datasets import load_dataset
8 | from transformers import AutoProcessor
9 | from accelerate import Accelerator
10 | from torch.utils.data import DataLoader
11 | from tqdm import tqdm
12 | from transformers import Qwen2VLForConditionalGeneration, GenerationConfig
13 | from qwen_vl_utils import process_vision_info
14 |
15 | import numpy as np
16 | import random
17 | from PIL import Image, ImageDraw
18 |
19 | from transformers import Qwen2_5_VLForConditionalGeneration
20 |
21 | from vllm import LLM, SamplingParams
22 | from huggingface_hub import snapshot_download
23 |
24 | os.environ["NCCL_SOCKET_TIMEOUT"] = "3600000" # 1 hour
25 | os.environ["NCCL_BLOCKING_WAIT"] = "1"
26 |
27 |
28 | from src.vg_synonyms import VG150_OBJ_CATEGORIES, VG150_PREDICATES
29 |
30 | from transformers import AutoProcessor, LlavaForConditionalGeneration
31 | from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
32 |
33 |
34 |
35 | from open_r1.trainer.utils.misc import encode_image_to_base64, is_pil_image
36 |
37 | SYSTEM_PROMPT = (
38 | "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
39 | "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
40 |     "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
41 |     "<think> reasoning process here </think><answer> answer here </answer>"
42 | )
43 |
44 | from open_r1.trainer.utils.prompt_gallery import (
45 | PROMPT_SG,
46 | PROMPT_CLOSE_PSG,
47 | PROMPT_CLOSE_VG150,
48 | VG150_BASE_OBJ_CATEGORIES,
49 | VG150_BASE_PREDICATE,
50 | format_prompt_close_sg,
51 | VG150_OBJ_CATEGORIES,
52 | VG150_PREDICATES
53 | )
54 |
55 |
56 | def get_model(name, device_map="auto", max_model_len=4096):
57 | is_qwen2vl = 'qwen2vl' in name.lower() or 'qwen2-vl' in name.lower()
58 | is_qwen25vl = 'qwen2.5-vl' in name.lower() or 'qwen25-vl' in name.lower() or 'qwen2.5vl' in name.lower()
59 | is_llava = 'llava' in name.lower()
60 | base_model_name = None
61 | if is_qwen2vl or is_qwen25vl:
62 | print("Using model:", name)
63 | min_pixels = 4*28*28
64 | max_pixels = 1024*28*28
65 | if is_qwen2vl:
66 | if '7b' in name.lower():
67 | base_model_name = "Qwen/Qwen2-VL-7B-Instruct"
68 | elif '2b' in name.lower():
69 | base_model_name = "Qwen/Qwen2-VL-2B-Instruct"
70 | if is_qwen25vl:
71 | if '7b' in name.lower():
72 | base_model_name = "Qwen/Qwen2.5-VL-7B-Instruct"
73 | elif '3b' in name.lower():
74 | base_model_name = "Qwen/Qwen2.5-VL-3B-Instruct"
75 |
76 | assert base_model_name is not None, "TODO: check the model -- {}".format(name)
77 | processor = AutoProcessor.from_pretrained(base_model_name,
78 | min_pixels=min_pixels, max_pixels=max_pixels)
79 |
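        # The processor is built from the base model above, so fine-tuned checkpoints that
        # ship without processor files still work; snapshot_download resolves a Hub repo id
        # to a local path for vLLM, and the bare except leaves `name` unchanged otherwise.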
80 | try:
81 | local_model_path = snapshot_download(name)
82 | print(f"set model:{name} to local path:", local_model_path)
83 | name = local_model_path
84 | except:
85 | pass
86 |
87 | model = LLM(
88 | model=name,
89 | limit_mm_per_prompt={"image": 1},
90 | dtype='bfloat16',
91 | device=device_map,
92 | max_model_len=max_model_len,
93 | mm_processor_kwargs= { "max_pixels": max_pixels, "min_pixels": min_pixels},
94 | )
95 | elif is_llava:
96 | model_cls = LlavaForConditionalGeneration if '1.5' in name else LlavaNextForConditionalGeneration
97 | model = model_cls.from_pretrained(
98 | name,
99 | torch_dtype=torch.bfloat16,
100 | ).to(device_map)
101 | processor = AutoProcessor.from_pretrained(name)
102 | else:
103 | raise Exception(f"Unknown model_id: {name}")
104 |
105 | return is_qwen2vl, is_qwen25vl, is_llava, model, processor
106 |
107 |
108 | def replace_answer_format(item: str) -> str:
109 |     return item.replace("<answer>", "```json").replace("</answer>", "```")
110 |
111 | def format_data(dataset_name, sample, use_predefined_cats=False, use_think_system_prompt=False, remove_image_size_in_prompt=True):
112 | image = sample['image'].convert('RGB')
113 | iw, ih = image.size
114 | if use_predefined_cats:
115 | prompt = PROMPT_CLOSE_PSG if 'psg' in dataset_name else PROMPT_CLOSE_VG150
116 | else:
117 | prompt = PROMPT_SG
118 |
119 | if remove_image_size_in_prompt:
120 | prompt = prompt.replace(f"of size ({iw} x {ih}) ", "")
121 |
122 | prompt = replace_answer_format(prompt)
123 |
124 | system_prompt = SYSTEM_PROMPT if use_think_system_prompt else "You are a helpful and multimodal AI assistant."
125 |
126 | base64_image = encode_image_to_base64(image)
127 | messages = [
128 | {
129 | "role": "system",
130 | "content": system_prompt
131 | },
132 | {
133 | "role": "user",
134 | "content": [
135 | {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}},
136 | {"type": "text", "text": prompt},
137 | ],
138 | },
139 | ]
140 | return image, messages
141 |
142 | def parse_args():
143 | parser = argparse.ArgumentParser(description="Run model inference on a dataset.")
144 | parser.add_argument("--dataset", required=True, help="Hugging Face dataset identifier")
145 | parser.add_argument("--model", required=True, help="Model name to load")
146 | parser.add_argument("--output_dir", required=True, help="Directory to save the outputs")
147 | parser.add_argument("--use_think_system_prompt", action="store_true", help="Use system prompt with ...")
148 | parser.add_argument("--use_predefined_cats", action="store_true", help="Use predefined categories in the prompt")
149 | parser.add_argument("--max_model_len", type=int, default=4096, help="max_model_len for vLLM")
150 | parser.add_argument("--batch_size", type=int, default=1, help="batch size")
151 |
152 | return parser.parse_args()
153 |
154 | def main():
155 | # Parse command line arguments.
156 | args = parse_args()
157 | print("args:", args)
158 |
159 | # Initialize Accelerator for distributed training/inference.
160 | accelerator = Accelerator()
161 | local_rank = accelerator.local_process_index
162 | device = f"cuda:{local_rank}" # each process occupies a GPU
163 |
164 | # Get rank and world size for manual splitting
165 | rank = torch.distributed.get_rank() # GPU ID or node rank
166 | world_size = torch.distributed.get_world_size() # Total number of GPUs/nodes
167 |
168 |
169 | # Load the model and processor.
170 | is_qwen2vl, is_qwen25vl, is_llava, model, processor = get_model(args.model, device_map=device, max_model_len=args.max_model_len)
171 | sampling_params = SamplingParams(
172 | temperature=0.01,
173 | top_k=1,
174 | top_p=0.001,
175 | repetition_penalty=1.0,
176 | max_tokens=2048,
177 | )
178 |
179 | print(f"model_id: {args.model}", " generation_config:", sampling_params)
180 |
181 | class Collator(object):
182 | def __init__(self, data_name,
183 | processor,
184 | use_predefined_cats, use_think_system_prompt,
185 | is_llava=False):
186 | self.data_name = data_name
187 | self.processor = processor
188 | self.use_predefined_cats = use_predefined_cats
189 | self.use_think_system_prompt = use_think_system_prompt
190 | self.is_llava = is_llava
191 |
192 | def __call__(self, examples):
193 | ids = [e['image_id'] for e in examples]
194 | gt_objs = [e['objects'] for e in examples]
195 | gt_rels = [e['relationships'] for e in examples]
196 |
197 | llm_inputs = []
198 | images = []
199 | for example in examples:
200 | image, prompt = format_data(self.data_name, example,
201 | use_predefined_cats=self.use_predefined_cats,
202 | use_think_system_prompt=self.use_think_system_prompt)
203 |
204 | if self.is_llava:
205 | conversation = [{'role': 'user', 'content': [{'type': 'text', 'text': prompt[-1]['content'][-1]['text']}, {"type": "image"},]}]
206 | prompt_item = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
207 | llm_inputs.append(prompt_item)
208 | else:
209 | llm_inputs.append(prompt)
210 | images.append(image)
211 |
212 | if self.is_llava:
213 | llm_inputs = self.processor(text=llm_inputs, images=images, padding=True, return_tensors="pt")
214 | input_height = input_width = [336]*len(images)
215 | else:
216 | texts = [self.processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
217 | for msg in llm_inputs]
218 | inputs = processor(
219 | text=texts,
220 | images=images,
221 | padding=True,
222 | return_tensors="pt",
223 | )
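                # image_grid_thw is (t, h, w) counted in vision patches; with the 14x14
                # patch size of Qwen2-VL, h*14 and w*14 recover the resized input resolution.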
224 | input_height = [inputs['image_grid_thw'][idx][1].item()*14 for idx in range(len(images))]
225 | input_width = [inputs['image_grid_thw'][idx][2].item()*14 for idx in range(len(images))]
226 |
227 | return ids, gt_objs, gt_rels, input_width, input_height, llm_inputs
228 |
229 |
230 |
231 | # Load dataset from Hugging Face hub.
232 | dataset = load_dataset(args.dataset)['train']
233 |
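    # Resume support: skip images whose per-image JSON prediction already exists in output_dir.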
234 | names = glob.glob(args.output_dir + "/*json")
235 | names = set([e.split('/')[-1].replace('.json', '') for e in tqdm(names)])
236 | ids = []
237 | for idx, item in enumerate(tqdm(dataset)):
238 | if item['image_id'] in names:
239 | continue
240 | ids.append(idx)
241 | dataset = dataset.select(ids)
242 | print("*"*100, " old:", len(names), " unhandled:", len(dataset))
243 |
244 |
245 | # Split dataset manually
246 | total_size = len(dataset)
247 | per_gpu_size = total_size // world_size
248 | start_idx = rank * per_gpu_size
249 | end_idx = total_size if rank == world_size - 1 else (rank + 1) * per_gpu_size
250 |
251 | subset = dataset.select(range(start_idx, end_idx)) # Select subset for this GPU
252 | print("*"*100, "\n rank:", rank, " world size:", world_size,
253 | "subset from", start_idx, " to ", end_idx, "\n",
254 | "\n data[0]:", format_data(args.dataset, dataset[0], use_predefined_cats=args.use_predefined_cats, use_think_system_prompt=args.use_think_system_prompt),
255 | "*"*100)
256 |
257 | data_loader = DataLoader(
258 | subset,
259 | batch_size=args.batch_size,
260 | shuffle=False,
261 | collate_fn=Collator(args.dataset, processor,
262 | use_predefined_cats=args.use_predefined_cats,
263 | use_think_system_prompt=args.use_think_system_prompt,
264 | is_llava=is_llava),
265 | pin_memory=True
266 | )
267 | #data_loader = accelerator.prepare(data_loader)
268 | print(f"Local ID: {local_rank} | len(dataset): {len(data_loader)}")
269 |
270 | # Create output directory if it doesn't exist.
271 | os.makedirs(args.output_dir, exist_ok=True)
272 | print(f"Save to {args.output_dir}")
273 |
274 | # Iterate over the data loader.
275 | _iter = 0
276 | for im_ids, gt_objs, gt_rels, input_width, input_height, batch in tqdm(data_loader, desc=f"Progress at rank {local_rank}"):
277 | with torch.no_grad():
278 | if is_llava:
279 | batch = batch.to(model.device)
280 | outputs = model.generate(**batch, max_new_tokens=2048)
281 | output_texts = processor.batch_decode(outputs, skip_special_tokens=True)
282 | output_texts = [text.split("ASSISTANT:")[-1] for text in output_texts]
283 | else:
284 | outputs = model.chat(batch, sampling_params=sampling_params)
285 | output_texts = [output.outputs[0].text for output in outputs]
286 |
287 |
288 | if local_rank == 0 and _iter % 100 == 0:
289 | print("*" * 100)
290 | print("nvidia-smi:")
291 | os.system("nvidia-smi")
292 | print("*" * 100)
293 | print("*"*100, "\n", "image_id:", im_ids[0], "\n",
294 | "Response:", output_texts[0], "\n",
295 | "GT objs:", gt_objs[0], " GT rels.: ", gt_rels[0],
296 | "*"*100)
297 |
298 | _iter += 1
299 | for im_id, gt_obj, gt_rel, output_text, input_iw, input_ih in zip(im_ids, gt_objs, gt_rels, output_texts, input_width, input_height):
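            # box_scale records the coordinate frame of the predicted boxes: Qwen2-VL emits
            # boxes on a 0-1000 normalized grid, other models on the resized input resolution.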
300 | if is_qwen2vl:
301 | box_scale = [1000.0, 1000.0]
302 | else:
303 | box_scale = [input_iw, input_ih]
304 |
305 | out = {"image_id": im_id, "response": output_text,
306 | "gt_objects": gt_obj, "gt_relationships": gt_rel,
307 | "box_scale": box_scale
308 | }
309 | dst_file = os.path.join(args.output_dir, f"{im_id}.json")
310 | with open(dst_file, 'w') as fout:
311 | json.dump(out, fout)
312 |
313 | print("Rank:", rank, " finished!")
314 | torch.cuda.empty_cache()
315 | accelerator.wait_for_everyone()
316 | print("All jobs finished!")
317 |
318 | if __name__ == "__main__":
319 | main()
320 |
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gpt4vision/R1-SGG/e4de64d4c4c97edec648021d012198b21a9b1864/src/utils/__init__.py
--------------------------------------------------------------------------------
/src/utils/bbox_overlaps.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) OpenMMLab. All rights reserved.
2 | import numpy as np
3 |
4 |
5 | def bbox_overlaps(bboxes1,
6 | bboxes2,
7 | mode='iou',
8 | eps=1e-6,
9 | use_legacy_coordinate=False):
10 | """Calculate the ious between each bbox of bboxes1 and bboxes2.
11 |
12 | Args:
13 | bboxes1 (ndarray): Shape (n, 4)
14 | bboxes2 (ndarray): Shape (k, 4)
15 | mode (str): IOU (intersection over union) or IOF (intersection
16 | over foreground)
17 | use_legacy_coordinate (bool): Whether to use coordinate system in
18 | mmdet v1.x. which means width, height should be
19 |             calculated as 'x2 - x1 + 1' and 'y2 - y1 + 1' respectively.
20 | Note when function is used in `VOCDataset`, it should be
21 | True to align with the official implementation
22 | `http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCdevkit_18-May-2011.tar`
23 | Default: False.
24 |
25 | Returns:
26 | ious (ndarray): Shape (n, k)
27 | """
28 |
29 | assert mode in ['iou', 'iof']
30 | if not use_legacy_coordinate:
31 | extra_length = 0.
32 | else:
33 | extra_length = 1.
34 | bboxes1 = bboxes1.astype(np.float32)
35 | bboxes2 = bboxes2.astype(np.float32)
36 | rows = bboxes1.shape[0]
37 | cols = bboxes2.shape[0]
38 | ious = np.zeros((rows, cols), dtype=np.float32)
39 | if rows * cols == 0:
40 | return ious
41 | exchange = False
42 | if bboxes1.shape[0] > bboxes2.shape[0]:
43 | bboxes1, bboxes2 = bboxes2, bboxes1
44 | ious = np.zeros((cols, rows), dtype=np.float32)
45 | exchange = True
46 | area1 = (bboxes1[:, 2] - bboxes1[:, 0] + extra_length) * (
47 | bboxes1[:, 3] - bboxes1[:, 1] + extra_length)
48 | area2 = (bboxes2[:, 2] - bboxes2[:, 0] + extra_length) * (
49 | bboxes2[:, 3] - bboxes2[:, 1] + extra_length)
50 | for i in range(bboxes1.shape[0]):
51 | x_start = np.maximum(bboxes1[i, 0], bboxes2[:, 0])
52 | y_start = np.maximum(bboxes1[i, 1], bboxes2[:, 1])
53 | x_end = np.minimum(bboxes1[i, 2], bboxes2[:, 2])
54 | y_end = np.minimum(bboxes1[i, 3], bboxes2[:, 3])
55 | overlap = np.maximum(x_end - x_start + extra_length, 0) * np.maximum(
56 | y_end - y_start + extra_length, 0)
57 | if mode == 'iou':
58 | union = area1[i] + area2 - overlap
59 | else:
60 | union = area1[i] if not exchange else area2
61 | union = np.maximum(union, eps)
62 | ious[i, :] = overlap / union
63 | if exchange:
64 | ious = ious.T
65 | return ious
66 |
--------------------------------------------------------------------------------
/src/utils/cocoeval.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from pycocotools.cocoeval import COCOeval
4 | from .bbox_overlaps import bbox_overlaps
5 |
6 | from .misc import get_rank
7 |
8 |
9 | class COCOEval(COCOeval):
10 | def __init__(self, *args, **kwargs):
11 | super().__init__(*args, **kwargs)
12 | self.gt_dt_valid = {}
13 |
14 | def evaluateImg(self, imgId, catId, aRng, maxDet):
15 | '''
16 | perform evaluation for single category and image
17 | :return: dict (single image results)
18 | '''
19 | p = self.params
20 | if p.useCats:
21 | gt = self._gts[imgId,catId]
22 | dt = self._dts[imgId,catId]
23 | else:
24 | gt = [_ for cId in p.catIds for _ in self._gts[imgId,cId]]
25 | dt = [_ for cId in p.catIds for _ in self._dts[imgId,cId]]
26 | if len(gt) == 0 and len(dt) ==0:
27 | return None
28 |
29 | for g in gt:
30 |         if g['ignore'] or (g['area']<aRng[0] or g['area']>aRng[1]):
31 | g['_ignore'] = 1
32 | else:
33 | g['_ignore'] = 0
34 |
35 | # sort dt highest score first, sort gt ignore last
36 | gtind = np.argsort([g['_ignore'] for g in gt], kind='mergesort')
37 | gt = [gt[i] for i in gtind]
38 | dtind = np.argsort([-d['score'] for d in dt], kind='mergesort')
39 | dt = [dt[i] for i in dtind[0:maxDet]]
40 | iscrowd = [int(o['iscrowd']) for o in gt]
41 |
42 | # load computed ious
43 | ious = self.ious[imgId, catId][:, gtind] if len(self.ious[imgId, catId]) > 0 else self.ious[imgId, catId]
44 |
45 | T = len(p.iouThrs)
46 | G = len(gt)
47 | D = len(dt)
48 | gtm = np.zeros((T,G))
49 | dtm = np.zeros((T,D))
50 | gtIg = np.array([g['_ignore'] for g in gt])
51 | dtIg = np.zeros((T,D))
52 |
53 | has_group_of = False
54 | if len(gt) > 0 and len(dt) > 0 and 'is_group_of' in gt[0]:
55 | has_group_of = True
56 |
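        # Open Images-style handling of "group-of" ground truths: detections are matched to
        # group boxes by IoF (intersection over the detection's own area) rather than IoU;
        # only the highest-IoF detection per group box counts, the rest are ignored.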
57 | if has_group_of:
58 | dt_boxes = np.array([d['bbox'] for d in dt]).reshape(-1, 4)
59 | gt_boxes = np.array([g['bbox'] for g in gt]).reshape(-1, 4)
60 | dt_boxes[:, 2:] = dt_boxes[:, 2:] + dt_boxes[:, :2]
61 | gt_boxes[:, 2:] = gt_boxes[:, 2:] + gt_boxes[:, :2]
62 |
63 |
64 | is_group_of = np.array([g['is_group_of'] for g in gt], dtype=bool)
65 | is_group_idx = np.where(is_group_of)[0]
66 | non_group_idx = np.where(~is_group_of)[0]
67 |
68 | iofs = bbox_overlaps(dt_boxes, gt_boxes, mode='iof')
69 |
70 | non_group_gt = [gt[e] for e in non_group_idx]
71 | group_gt = [gt[e] for e in is_group_idx]
72 | # step 1: for non-group-of gts.
73 | if len(ious) > 0:
74 | for tind, t in enumerate(p.iouThrs):
75 | for dind, d in enumerate(dt):
76 | # information about best match so far (m=-1 -> unmatched)
77 | iou = min([t,1-1e-10])
78 | m = -1
79 | for gind, g in zip(non_group_idx, non_group_gt):
80 | # if this gt already matched, and not a crowd, continue
81 | if gtm[tind,gind]>0 and not iscrowd[gind]:
82 | continue
83 | # if dt matched to reg gt, and on ignore gt, stop
84 | if m>-1 and gtIg[m]==0 and gtIg[gind]==1:
85 | break
86 |
87 | # continue to next gt unless better match made
88 | if ious[dind,gind] < iou:
89 | continue
90 | # if match successful and best so far, store appropriately
91 | iou=ious[dind,gind]
92 | m=gind
93 |
94 | if m ==-1:
95 | continue
96 | # if match made store id of match for both dt and gt
97 | dtIg[tind,dind] = gtIg[m]
98 | dtm[tind,dind] = gt[m]['id']
99 | gtm[tind,m] = d['id']
100 |
101 | # step 2: for group-of gts
102 | if len(is_group_idx) > 0 and len(iofs) > 0:
103 | for tind, t in enumerate(p.iouThrs):
104 | iof_thresh = min([t,1-1e-10])
105 | maxIoF = [-1 for _ in range(len(gt))] # store maximum IoF for each gt
106 | maxDtIndex = [-1 for _ in range(len(gt))] # store dt index with maximum IoF for each gt
107 |
108 | for dind, d in enumerate(dt):
109 | # dt already matched
110 | if dtm[tind, dind] > 0 or dtIg[tind, dind] > 0:
111 | continue
112 | #
113 | m = -1
114 | for gind, g in zip(is_group_idx, group_gt):
115 | iof = iofs[dind, gind]
116 | if iof > iof_thresh:
117 | if iof > maxIoF[gind]: # if current IoF is larger than stored maxIoF
118 | if maxDtIndex[gind] != -1: # if there was a previously stored dt for this gt
119 | dtIg[tind,maxDtIndex[gind]] = 1 # ignore the previous dt
120 | tmp = int(maxDtIndex[gind])
121 |
122 | maxIoF[gind] = iof
123 | maxDtIndex[gind] = dind
124 | dtIg[tind,dind] = 0 # not ignored
125 | m = gind
126 | else:
127 | dtIg[tind,dind] = 1 # ignore other dts inside gt
128 |
129 | if m == -1:
130 | continue
131 |
132 | # if match made store id of match for both dt and gt
133 | dtm[tind,dind] = gt[m]['id']
134 | gtm[tind,m] = d['id']
135 |
136 | else: # normal
137 | if not len(ious)==0:
138 | for tind, t in enumerate(p.iouThrs):
139 | for dind, d in enumerate(dt):
140 | # information about best match so far (m=-1 -> unmatched)
141 | iou = min([t,1-1e-10])
142 | m = -1
143 | for gind, g in enumerate(gt):
144 | # if this gt already matched, and not a crowd, continue
145 | if gtm[tind,gind]>0 and not iscrowd[gind]:
146 | continue
147 | # if dt matched to reg gt, and on ignore gt, stop
148 | if m>-1 and gtIg[m]==0 and gtIg[gind]==1:
149 | break
150 |
151 | # continue to next gt unless better match made
152 | if ious[dind,gind] < iou:
153 | continue
154 | # if match successful and best so far, store appropriately
155 | iou=ious[dind,gind]
156 | m=gind
157 |
158 | if m ==-1:
159 | continue
160 |
161 | # if match made store id of match for both dt and gt
162 | dtIg[tind,dind] = gtIg[m]
163 | dtm[tind,dind] = gt[m]['id']
164 | gtm[tind,m] = d['id']
165 |
166 |
167 |
168 | # set unmatched detections outside of area range to ignore
169 |         a = np.array([d['area']<aRng[0] or d['area']>aRng[1] for d in dt]).reshape((1, len(dt)))
170 | dtIg = np.logical_or(dtIg, np.logical_and(dtm==0, np.repeat(a,T,0)))
171 |
172 | for gind, g in enumerate(gt):
173 | if gtIg[gind] != 1:
174 | gid = g['category_id']
175 | if gid not in self.gt_dt_valid:
176 | self.gt_dt_valid[gid] = {'gts': 0, 'dts': [0]*T}
177 | self.gt_dt_valid[gid]['gts'] += 1
178 |
179 | for dind, d in enumerate(dt):
180 | for tind in range(len(p.iouThrs)):
181 | if dtIg[tind, dind] != 1:
182 | did = d['category_id']
183 | if did not in self.gt_dt_valid:
184 | self.gt_dt_valid[did] = {'gts': 0, 'dts': [0]*T}
185 | self.gt_dt_valid[did]['dts'][tind] += 1
186 |
187 | # store results for given image and category
188 | return {
189 | 'image_id': imgId,
190 | 'category_id': catId,
191 | 'aRng': aRng,
192 | 'maxDet': maxDet,
193 | 'dtIds': [d['id'] for d in dt],
194 | 'gtIds': [g['id'] for g in gt],
195 | 'dtMatches': dtm,
196 | 'gtMatches': gtm,
197 | 'dtScores': [d['score'] for d in dt],
198 | 'gtIgnore': gtIg,
199 | 'dtIgnore': dtIg,
200 | }
201 |
202 | def summarize(self):
203 | '''
204 | Compute and display summary metrics for evaluation results.
205 |         Note this function can *only* be applied on the default parameter setting
206 | '''
207 | def _summarize( ap=1, iouThr=None, areaRng='all', maxDets=100 ):
208 | p = self.params
209 | iStr = ' {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}'
210 | titleStr = 'Average Precision' if ap == 1 else 'Average Recall'
211 | typeStr = '(AP)' if ap==1 else '(AR)'
212 | iouStr = '{:0.2f}:{:0.2f}'.format(p.iouThrs[0], p.iouThrs[-1]) \
213 | if iouThr is None else '{:0.2f}'.format(iouThr)
214 |
215 | aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
216 | mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
217 | if ap == 1:
218 | # dimension of precision: [TxRxKxAxM]
219 | s = self.eval['precision']
220 | # IoU
221 | if iouThr is not None:
222 | t = np.where(iouThr == p.iouThrs)[0]
223 | s = s[t]
224 | s = s[:,:,:,aind,mind]
225 | else:
226 | # dimension of recall: [TxKxAxM]
227 | s = self.eval['recall']
228 | if iouThr is not None:
229 | t = np.where(iouThr == p.iouThrs)[0]
230 | s = s[t]
231 | s = s[:,:,aind,mind]
232 | if len(s[s>-1])==0:
233 | mean_s = -1
234 | else:
235 | mean_s = np.mean(s[s>-1])
236 |
237 | if get_rank() == 0:
238 | print(iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, mean_s))
239 |
240 | return mean_s
241 |
242 | def _summarizeDets():
243 | stats = np.zeros((13,))
244 | stats[0] = _summarize(1, maxDets=self.params.maxDets[2])
245 | stats[1] = _summarize(1, iouThr=.5, maxDets=self.params.maxDets[2])
246 | stats[2] = _summarize(1, iouThr=.75, maxDets=self.params.maxDets[2])
247 | stats[3] = _summarize(1, areaRng='small', maxDets=self.params.maxDets[2])
248 | stats[4] = _summarize(1, areaRng='medium', maxDets=self.params.maxDets[2])
249 | stats[5] = _summarize(1, areaRng='large', maxDets=self.params.maxDets[2])
250 | stats[6] = _summarize(0, maxDets=self.params.maxDets[0])
251 | stats[7] = _summarize(0, maxDets=self.params.maxDets[1])
252 | stats[8] = _summarize(0, maxDets=self.params.maxDets[2])
253 | stats[9] = _summarize(0, areaRng='small', maxDets=self.params.maxDets[2])
254 | stats[10] = _summarize(0, areaRng='medium', maxDets=self.params.maxDets[2])
255 | stats[11] = _summarize(0, areaRng='large', maxDets=self.params.maxDets[2])
256 | # add
257 | stats[12] = _summarize(0, iouThr=.5, maxDets=self.params.maxDets[2])
258 | return stats
259 |
260 | def _summarizeKps():
261 | stats = np.zeros((10,))
262 | stats[0] = _summarize(1, maxDets=20)
263 | stats[1] = _summarize(1, maxDets=20, iouThr=.5)
264 | stats[2] = _summarize(1, maxDets=20, iouThr=.75)
265 | stats[3] = _summarize(1, maxDets=20, areaRng='medium')
266 | stats[4] = _summarize(1, maxDets=20, areaRng='large')
267 | stats[5] = _summarize(0, maxDets=20)
268 | stats[6] = _summarize(0, maxDets=20, iouThr=.5)
269 | stats[7] = _summarize(0, maxDets=20, iouThr=.75)
270 | stats[8] = _summarize(0, maxDets=20, areaRng='medium')
271 | stats[9] = _summarize(0, maxDets=20, areaRng='large')
272 | return stats
273 | if not self.eval:
274 | raise Exception('Please run accumulate() first')
275 | iouType = self.params.iouType
276 | if iouType == 'segm' or iouType == 'bbox':
277 | summarize = _summarizeDets
278 | elif iouType == 'keypoints':
279 | summarize = _summarizeKps
280 | self.stats = summarize()
281 |
--------------------------------------------------------------------------------
/src/utils/wordnet.py:
--------------------------------------------------------------------------------
1 | import nltk
2 | from nltk.corpus import wordnet
3 | import json
4 |
5 | # Ensure nltk resources are downloaded
6 | #nltk.download('wordnet')
7 |
8 | def find_synonym_map(set_a, set_b_list):
9 | """
10 | Builds a synonym map for items in set A if a synonym exists in set B.
11 |
12 | Args:
13 | set_a: A list of strings (set A).
14 | set_b_list: A list of strings (set B).
15 |
16 | Returns:
17 | A dictionary representing the synonym map.
18 | Keys are items from set A that have synonyms in set B.
19 | Values are the corresponding synonym items from set B.
20 | """
21 | synonym_map = {}
22 | set_b = set(set_b_list) # Convert set B list to set for faster lookup
23 |
24 | for item_a in set_a:
25 | found_synonym = False
26 | # First, check for direct match (item_a itself in set_b)
27 | if item_a in set_b:
28 | synonym_map[item_a] = item_a
29 | found_synonym = True
30 | else:
31 | # If no direct match, try to find synonyms using WordNet
32 | for syn in wordnet.synsets(item_a):
33 | for lemma in syn.lemmas():
34 | lemma_name = lemma.name()
35 | if lemma_name in set_b:
36 | synonym_map[item_a] = lemma_name
37 | found_synonym = True
38 | break # Found a synonym, move to next item in set_a
39 | if found_synonym:
40 | break # Found a synonym, move to next item in set_a
41 | if not found_synonym:
42 | # Check for synonyms by splitting words if item_a is a phrase
43 | words_in_a = item_a.split()
44 | if len(words_in_a) > 1:
45 | for word_a in words_in_a:
46 | if word_a in set_b:
47 | synonym_map[item_a] = word_a
48 | found_synonym = True
49 | break
50 | else:
51 | for syn in wordnet.synsets(word_a):
52 | for lemma in syn.lemmas():
53 | lemma_name = lemma.name()
54 | if lemma_name in set_b:
55 | synonym_map[item_a] = lemma_name
56 | found_synonym = True
57 | break
58 | if found_synonym:
59 | break
60 | if found_synonym:
61 | break
62 |
63 |
64 | return synonym_map
65 |
66 |
67 |
68 | if __name__ == "__main__":
69 | set_A = ['quilt', 'sensor', 'fish', 'wall socket', 'goat', 'mouse', 'snow globe', 'santa', 'tie', 'mushroom', 'storefront', 'leaf', 'jar', 'feeder', 'child', 'crib', 'propeller', 'headboard', 'outhouse', 'lemon', 'weather vane', 'snowboarder', 'bike route', 'statue', 'cork', 'ice rink', 'log cabin', 'donkey', 'catering truck', 'flowerbed', 'boarding bridge', 'post it', 'lid', 'saucer', 'water tower', 'blue pants', 'hand', 'planter', 'brickwall', 'red bandana', 'bowl', 'banana', 'wheat thins', 'keyboard', 'wine bottle', 'crosswalk', 'game machine', 'sandals', 'mongoose', 'apple', 'pond', 'clothes', 'marker', 'white chair', 'earphones', 'medals', 'blocks', 'panda', 'saw', 'sun', 'taxi', 'blue mug', 'rugby ball', 'first aid kit', 'candle', 'kite', 'zebra', 'pumpkin', 'yellow truck', 'hotdog', 'bookstore', 'charger', 'grape vine', 'file', 'speaker', 'wetsuit', 'portrait', 'gondola', 'coca cola', 'leaves', 'purple flowers', 'pancake', 'racket', 'birdhouse', 'tablecloth', 'chicken', 'sticker', 'ceiling fan', 'scaffolding', 'sunglasses', 'cheese', 'aircraft carrier', 'sleeping bag', 'rock', 'telescope', 'boat', 'piano', 'flagpole', 'brick wall', 'coleslaw', 'tablet', 'shoes', 'thermos', 'corkboard', 'cowboy hat', 'cheerios', 'restaurant sign', 'floral pattern', 'sand dune', 'mug', 'buildings', 'oar', 'towel holder', 'toaster', 'backpack', 'zebra bag', 'crane', 'vase', 'server rack', 'wood', 'beer glass', 'tv stand', 'can', 'moon', 'fireplug', 'pen holder', 'antelope', 'keys', 'whipped cream', 'shelves', 'external drive', 'driver', 'calendar', 'bathroom', 'church', 'white truck', 'plate', 'vegetables', 'book', 'lights', 'hair dryer', 'knife block', 'hill', 'glue', 'star', 'stuffed animal', 'countertop', 'rainbow', 'tracks', 'green beans', 'noodles', 'chips', 'coral', 'wine', 'pepsi', 'jacket', 'police', 'bridge', 'school bus', 'slide', 'ball', 'helicopter', 'cymbal', 'pot', 'porta potty', 'wave', 'refrigerator', 'arch', 'ribs', 'basketball', 'dog', 'clouds', 'projector', 'sandbag wall', 'glasses', 'pillars', 'knee pads', 'folder', 'palm', 'bread', 'tent', 'napkin', 'gazebo', 'printer', 'peas', 'sauce', 'police car', 'lamp shade', 'picnic table', 'cigarette', 'staircase', 'driveway', 'yellow mug', 'candy', 'mccafe drink', 'soccer ball', 'coffee pot', 'mousepad', 'snow sign', 'fence', 'crates', 'well', 'asparagus', 'go sign', 'pail', 'stapler', 'bookcase', 'field', 'surfer', 'vending machine', 'store', 'buoy', 'postbox', 'tank', 'floor', 'internet', 'wardrobe', 'lake', 'water', 'hammer', 'sheep', 'pancakes', 'sky', 'paint', 'wreath', 'duck', 'wooden box', 'paintbrush', 'bouquet', 'jars', 'tray', 'pipe', 'tissue box', 'net', 'engine', 'closet', 'salad', 'toothpaste', 'vent', 'feathers', 'river', 'tennis bag', 'greenhouse', 'sculpture', 'bed', 'green fabric', 'name card', 'rockwall', 'scorpions', 'path', 'record player', 'parking sign', 'giraffe', 'watch', 'bath', 'coca cola bottle', 'metal grates', 'air conditioner', 'pitcher', 'deer head', 'menu board', 'bun', 'ski poles', 'grass', 'wrench', 'butterfly', 'purse', 'iceberg', 'sidewalk', 'flowers', 'gray car', 'atm', 'ice', 'socks', 'tractor', 'egg', 'dock', 'wind turbine', 'gloves', 'clock', 'kiteboard', 'barn', 'ornament', 'shell', 'toilet tissue', 'cow', 'truck', 'blinds', 'table', 'podium', 'ship', 'bottle', 'rocks', 'cucumber', 'seat', 'billboard', 'file cabinet', 'ruler', 'chain', 'dough', 'meat', 'scooter', 'green fridge', 'wind chime', 'microphone', 'tomato', 'ski lift', 'chocolate sauce', 'shelf', 'fries', 'oven', 'reflection', 'power 
pole', 'cross', 'cutlery holder', 'directory', 'berries', 'water glass', 'couscous', 'ottoman', 'box', 'plant', 'water bottle', 'snowboard', 'door', 'pizza', 'waterfall', 'stroller', 'stop sign', 'bus', 'paper bag', 'snack bar', 'headphones', 'balcony', 'bowling lane', 'apples', 'bedroom', 'tennis shorts', 'pineapple', 'stained glass', 'canvas', 'milk box', 'cutting board', 'shed', 'cereal', 'van', 'basket', 'toothbrush', 'tricycle', 'basketball court', 'fire extinguisher', 'pants', 'grill', 'boxcar', 'rug', 'greenwall', 'exit sign', 'conveyor', 'house', 'flower hanging', 'fire hydrant', 'bowling lanes', 'barrel', 'blanket', 'tile', 'cake', 'tv', 'streetlight', 'laptop', 'stone wall', 'water heater', 'sand', 'yellow strap', 'hamburger', 'game case', 'pole', 'tower', 'minnie', 'vehicle', 'firetruck', 'spectators', 'mirror', 'fridge', 'books', 'sea lion', 'tree', 'scoreboard', 'bulletin board', 'file drawer', 'board', 'hat', 'wheelchair', 'tennis player', 'wall', 'fountain', 'candles', 'ceiling', 'ambulance', 'palm tree', 'beach', 'window', 'building', 'police officer', 'drill', 'bar', 'remote', 'coffee table', 'crosswalk sign', 'lightswitch', 'jeans', 'basketball jersey', 'sandbox', 'gun', 'kiwi', 'microwave', 'power line', 'tail', 'no biking sign', 'lamp base', 'restaurant', 'stairs', 'sofa', 'pool', 'bike lane sign', 'wii', 'home plate', 'antenna', 'parrot', 'gravel', 'elephant', 'iguana', 'tag', 'bacon', 'toilet', 'squirrel', 'hot tub', 't shirt', 'side table', 'street', 'bat', 'doll', 'whiteboard', 'potatoes', 'label', 'pool table', 'sink', 'pear', 'coffee', 'skis', 'gas pump', 'placemat', 'whirlpool', 'sugar pack', 'stool', 'paper', 'frisbee', 'monitor', 'car', 'handbag', 'courtroom', 'shower', 'snowflake', 'castle', 'wine glass', 'track', 'number', 'wagon', 'ocean', 'glass roof', 'cauliflower', 'gate', 'parking by permit only', 'concrete path', 'glass', 'shirt', 'bike rack', 'plastic container', 'stove', 'paper towel', 'cup', 'snow', 'container', 'toys', 'fan', 'base', 'bowling ball', 'stuffed toy', 'lime', 'parking permit only sign', 'ski', 'pen', 'person', 'rocky slope', 'carrousel', 'curtain', 'laptop case', 'towel', 'sandwich bar', 'baby', 'water fountain', 'shrub', 'cones', 'chair', 'lottery sign', 'baptism tub', 'railing', 'goggles', 'fork', 'green car', 'forklift', 'road', 'cushion', 'cabinet', 'computer', 'juice', 'camera', 'lock', 'power lines', 'high heel shoe', 'red car', 'bowling pins', 'blue seat', 'bed sheet', 'doily', 'vacuum', 'ice cream', 'coffee cup', 'phone booth', 'bucket', 'ping pong table', 'phone', 'violin', 'lampshade', 'rice', 'cloud', 'jet', 'structure', 'sailboat', 'green plants', 'houses', 'invitation', 'pine tree', 'speedometer', 'skier', 'lamp', 'belt', 'text', 'concrete wall', 'green shirt', 'white gate', 'tarp', 'foosball table', 'desk', 'carrots', 'wafer', 'busstop', 'surfboard', 'couch', 'blue court', 'grapes', 'trolley', 'pouch', 'television', 'lighthouse', 'leg', 'soda can', 'candy bar', 'barbed wire', 'sweater', 'fish tank', 'bell', 'bookshelf', 'sack', 'seats', 'bison', 'chocolate', 'guardrail', 'pitcher mound', 'cactus', 'flipflops', 'cross traffic sign', 'scissors', 'hoop', 'airplane', 'carrot', 'underwear', 'food truck', 'bus sign', 'raft', 'bulldozer', 'dishwasher', 'fruit', 'map', 'mat', 'sail', 'drum', 'lanyard', 'toy', 'painting', 'cups', 'ground', 'pump', 'branch', 'spectator', 'inflatable', 'vegetation', 'stage', 'yellow flower', 'concrete', 'green box', 'button', 'food', 'parking meter', 'log', 'cable', 'scrambled eggs', 'clock tower', 
'blue wall', 'motorcycle', 'washing machine', 'anchor', 'cockpit', 'screwdriver', 'shutters', 'tennis court', 'baby monkey', 'strawberry', 'skeleton', 'wing', 'train track', 'lollipop', 'balloon', 'tennis racket', 'candle holder', 'horse', 'goal', 'bath mat', 'shoe', 'dryer', 'baseball field', 'flag', 'beer bottle', 'chocolate dessert', 'lunchbox', 'milking machine', 'trashcan', 'ski pole', 'garage', 'skull', 'cd', 'shutter', 'knife', 'cell phone', 'wheel', 'manhole', 'roof', 'counter', 'battery', 'people', 'soil', 'broccoli', 'white cabinet', 'teddy bear', 'tissue paper', 'drain', 'banana leaf', 'papers', 'coffee maker', 'pufferfish', 'z crossing', 'fire', 'sign', 'guitar', 'chimney', 'cheetah', 'kettle', 'corn', 'bathtub', 'bag', 'sausage', 'train', 'fireplace', 'cloth', 'paella', 'platform', 'ramp', 'fire escape', 'pepper', 'mountain', '4 way sign', 'hedge', 'toast', 'calculator', 'trash can', 'boy', 'menu', 'sandwich', 'bank', 'salt', 'lifeguard', 'mountains', 'mobile home', 'frying egg', 'street sign', 'spaghetti', 'magazine', 'nest', 'water meter', 'goalpost', 'yellow pole', 'donut', 'ladder', 'drawer', 'shorts', 'escalator', 'newspaper', 'toilet paper', 'store sign', 'plastic bag', 'tennis net', 'street light', 'monkey bar', 'parking lot', 'telephone pole', 'case', 'spatula', 'soap', 'pluto', 'scarf', 'brick column', 'screen', 'bush', 'motor', 'elevator', 'microphone stand', 'blue building', 'tennis ball', 'spoon', 'flower arrangement', 'computer monitor', 'bench', 'dirt', 'pillow', 'woman', 'lettuce', 'dresser', 'ketchup', 'tea', 'metal object', 'tow truck', 'light', 'mouse pad', 'potted plant', 'bus stop', 'no parking sign', 'fruits', 'moss', 'plug', 'thermostat', 'traffic light', 'machine', 'goose', 'tissue', 'polling station sign', 'earrings', 'mirror ball', 'umbrella', 'fishbowl', 'nightstand', 'baseball', 'mailbox', 'towel bar', 'trees', 'scorpion', 'cd rack', 'glove', 'office', 'picture frame', 'white fridge', 'notebook', 'flower', 'knee pad', 'poster', 'handle', 'faucet', 'minibar', 'plane', 'bicycle', 'spiderman', 'column', 'soap dish', 'kitkat', 'carriage', 'hose', 'frame', 'potato', 'dome', 'cereal box', 'spray bottle', 'trash', 'trailer', 'sidecar', 'chandelier', 'collaborate sign', 'tripod', 'record', 'garbage can', 'monkey', 'watering can', 'street lamp', 'cat', 'mannequin', 'helmet', 'baseball glove', 'socket', 'valve', 'cowboy', 'arm', 'straw', 'toilet seat', 'wallet', 'collar', 'bird', 'package', 'grape', 'bear', 'cone', 'penne', 'phonebooth', 'no entry', 'radiator', 'lily pad', 'one way sign', 'bagel', 'lantern', 'armchair', 'air hockey table', 'orange juice', 'stripes', 'suv', 'suitcase', 'bike', 'alarm', 'steak', 'grate', 'toilet brush', 'luggage', 'onions', 'pier', 'skateboard', 'cash register', 'spire', 'pan', 'red counter', 'picture', 'orange', 'teapot', 'tank top', 'cart', 'rope', 'man', 'windmill', 'mashed potatoes', 'eraser', 'yogurt', 'tiger']
70 | set_B_list = ['airplane', 'animal', 'arm', 'bag', 'banana', 'basket', 'beach', 'bear', 'bed', 'bench', 'bike', 'bird', 'board', 'boat', 'book', 'boot', 'bottle', 'bowl', 'box', 'boy', 'branch', 'building', 'bus', 'cabinet', 'cap', 'car', 'cat', 'chair', 'child', 'clock', 'coat', 'counter', 'cow', 'cup', 'curtain', 'desk', 'dog', 'door', 'drawer', 'ear', 'elephant', 'engine', 'eye', 'face', 'fence', 'finger', 'flag', 'flower', 'food', 'fork', 'fruit', 'giraffe', 'girl', 'glass', 'glove', 'guy', 'hair', 'hand', 'handle', 'hat', 'head', 'helmet', 'hill', 'horse', 'house', 'jacket', 'jean', 'kid', 'kite', 'lady', 'lamp', 'laptop', 'leaf', 'leg', 'letter', 'light', 'logo', 'man', 'men', 'motorcycle', 'mountain', 'mouth', 'neck', 'nose', 'number', 'orange', 'pant', 'paper', 'paw', 'people', 'person', 'phone', 'pillow', 'pizza', 'plane', 'plant', 'plate', 'player', 'pole', 'post', 'pot', 'racket', 'railing', 'rock', 'roof', 'room', 'screen', 'seat', 'sheep', 'shelf', 'shirt', 'shoe', 'short', 'sidewalk', 'sign', 'sink', 'skateboard', 'ski', 'skier', 'sneaker', 'snow', 'sock', 'stand', 'street', 'surfboard', 'table', 'tail', 'tie', 'tile', 'tire', 'toilet', 'towel', 'tower', 'track', 'train', 'tree', 'truck', 'trunk', 'umbrella', 'vase', 'vegetable', 'vehicle', 'wave', 'wheel', 'window', 'windshield', 'wing', 'wire', 'woman', 'zebra']
71 |
72 |
73 | print(find_synonym_map(["bicycle"], set_B_list))
74 | # import pdb; pdb.set_trace()  # debug breakpoint disabled so the full mapping below runs
75 |
76 | synonym_mapping = find_synonym_map(set_A, set_B_list)
77 | print(synonym_mapping)
78 | with open("synonym_mapping.json", 'w') as fout:
79 | json.dump(synonym_mapping, fout)
80 |
81 |
--------------------------------------------------------------------------------
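A minimal consumption sketch for the mapping dumped above (not part of the repo): it assumes find_synonym_map returns a dict of the form {open_vocab_name: vg150_name_or_None}, so open-vocabulary object names can be projected onto the closed VG150 vocabulary before evaluation; to_vg150 is a hypothetical helper.

import json

with open("synonym_mapping.json") as fin:
    synonym_mapping = json.load(fin)

def to_vg150(name: str) -> str:
    # Hypothetical helper: fall back to the original name when no synonym was found.
    return synonym_mapping.get(name) or name

print(to_vg150("bicycle"))  # e.g. "bike" under the assumed mapping
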
/src/utils/zeroshot_triplet.pytorch:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gpt4vision/R1-SGG/e4de64d4c4c97edec648021d012198b21a9b1864/src/utils/zeroshot_triplet.pytorch
--------------------------------------------------------------------------------
/src/vg150_eval.py:
--------------------------------------------------------------------------------
1 | from tqdm import tqdm
2 | import numpy as np
3 | import copy
4 | import json
5 | import torch
6 | from pycocotools.coco import COCO
7 |
8 | from src.utils.sgg_eval import SggEvaluator
9 |
10 |
11 | VG150_OBJ_CATEGORIES = ['__background__', 'airplane', 'animal', 'arm', 'bag', 'banana', 'basket', 'beach', 'bear', 'bed', 'bench', 'bike', 'bird', 'board', 'boat', 'book', 'boot', 'bottle', 'bowl', 'box', 'boy', 'branch', 'building', 'bus', 'cabinet', 'cap', 'car', 'cat', 'chair', 'child', 'clock', 'coat', 'counter', 'cow', 'cup', 'curtain', 'desk', 'dog', 'door', 'drawer', 'ear', 'elephant', 'engine', 'eye', 'face', 'fence', 'finger', 'flag', 'flower', 'food', 'fork', 'fruit', 'giraffe', 'girl', 'glass', 'glove', 'guy', 'hair', 'hand', 'handle', 'hat', 'head', 'helmet', 'hill', 'horse', 'house', 'jacket', 'jean', 'kid', 'kite', 'lady', 'lamp', 'laptop', 'leaf', 'leg', 'letter', 'light', 'logo', 'man', 'men', 'motorcycle', 'mountain', 'mouth', 'neck', 'nose', 'number', 'orange', 'pant', 'paper', 'paw', 'people', 'person', 'phone', 'pillow', 'pizza', 'plane', 'plant', 'plate', 'player', 'pole', 'post', 'pot', 'racket', 'railing', 'rock', 'roof', 'room', 'screen', 'seat', 'sheep', 'shelf', 'shirt', 'shoe', 'short', 'sidewalk', 'sign', 'sink', 'skateboard', 'ski', 'skier', 'sneaker', 'snow', 'sock', 'stand', 'street', 'surfboard', 'table', 'tail', 'tie', 'tile', 'tire', 'toilet', 'towel', 'tower', 'track', 'train', 'tree', 'truck', 'trunk', 'umbrella', 'vase', 'vegetable', 'vehicle', 'wave', 'wheel', 'window', 'windshield', 'wing', 'wire', 'woman', 'zebra']
12 |
13 | VG150_PREDICATES = ['__background__', "above", "across", "against", "along", "and", "at", "attached to", "behind", "belonging to", "between", "carrying", "covered in", "covering", "eating", "flying in", "for", "from", "growing on", "hanging from", "has", "holding", "in", "in front of", "laying on", "looking at", "lying on", "made of", "mounted on", "near", "of", "on", "on back of", "over", "painted on", "parked on", "part of", "playing", "riding", "says", "sitting on", "standing on", "to", "under", "using", "walking in", "walking on", "watching", "wearing", "wears", "with"]
14 |
15 |
16 | def compute_iou(boxA, boxB):
17 | # box format: [x1, y1, x2, y2]
18 | xA = max(boxA[0], boxB[0])
19 | yA = max(boxA[1], boxB[1])
20 | xB = min(boxA[2], boxB[2])
21 | yB = min(boxA[3], boxB[3])
22 | interWidth = max(0, xB - xA)
23 | interHeight = max(0, yB - yA)
24 | interArea = interWidth * interHeight
25 | boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1])
26 | boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1])
27 | unionArea = boxAArea + boxBArea - interArea
28 | return 0.0 if unionArea == 0 else interArea / unionArea
29 |
30 |
31 | class MyDataset(object):
32 | def __init__(self, db, db_type='vg150'):
33 | self._coco = None
34 | self.db_type = db_type
35 | assert self.db_type in ['vg150', 'psg']
36 |
37 | if self.db_type == 'vg150':
38 | self.ind_to_classes = VG150_OBJ_CATEGORIES
39 | self.ind_to_predicates = VG150_PREDICATES
40 | self.name2classes = {name: cls for cls, name in enumerate(self.ind_to_classes) if name != "__background__"}
41 | self.categories = [{'supercategory': 'none', # not used?
42 | 'id': idx,
43 | 'name': self.ind_to_classes[idx]}
44 | for idx in range(len(self.ind_to_classes)) if self.ind_to_classes[idx] != '__background__'
45 | ]
46 | elif self.db_type == 'psg':
47 | psg_categories = json.load(open("src/psg_categories.json"))
48 | PSG_OBJ_CATEGORIES = psg_categories['thing_classes'] + psg_categories['stuff_classes']
49 | PSG_PREDICATES = psg_categories['predicate_classes']
50 | self.ind_to_classes = PSG_OBJ_CATEGORIES
51 | self.ind_to_predicates = ['__background__'] + PSG_PREDICATES
52 | self.name2classes = {name: cls for cls, name in enumerate(self.ind_to_classes) if name != "__background__"}
53 | self.categories = [{'supercategory': 'none', # not used?
54 | 'id': idx,
55 | 'name': self.ind_to_classes[idx]}
56 | for idx in range(len(self.ind_to_classes)) if self.ind_to_classes[idx] != '__background__'
57 | ]
58 |
59 |
60 | self.images = []
61 | self.annotations = []
62 | self.ids = []
63 | for item in tqdm(db):
64 | im_id = item['image_id']
65 |
66 | self.images.append({'id': im_id})
67 | self.ids.append(im_id)
68 | objs = json.loads(item['objects'])
69 |
70 | ann = {'image_id': im_id, 'labels': [], 'boxes': []}
71 | names = []
72 | for obj in objs:
73 | name, box = obj['id'].split('.')[0], obj['bbox']
74 | names.append(obj['id'])
75 | cls = self.name2classes[name]
76 | ann['labels'].append(cls)
77 | ann['boxes'].append(box)
78 |
79 | rels = json.loads(item['relationships'])
80 | edges = []
81 | for rel in rels:
82 | sub = rel['subject']
83 | obj = rel['object']
84 | pred = rel['predicate']
85 | sid = names.index(sub)
86 | oid = names.index(obj)
87 | tmp = [sid, oid, self.ind_to_predicates.index(pred)]
88 | edges.append(tmp)
89 |
90 | ann['edges'] = edges
91 | self.annotations.append(ann)
92 |
93 | print("total images", len(self.images), self.images[0])
94 |
95 | def get_groundtruth(self, index):
96 | ann = self.annotations[index]
97 |
98 | return torch.as_tensor(ann['boxes']), \
99 | torch.as_tensor(ann['labels']), \
100 | torch.as_tensor(ann['edges'])
101 |
102 |
103 |
104 | @property
105 | def coco(self):
106 | if self._coco is None:
107 | _coco = COCO()
108 | coco_dicts = dict(
109 | images=self.images,
110 | annotations=[],
111 | categories=self.categories)
112 |
113 | for ann in tqdm(self.annotations):
114 | for cls, box in zip(ann['labels'], ann['boxes']):
115 | assert len(box) == 4
116 | item = {
117 | 'area': (box[3] - box[1]) * (box[2] - box[0]),
118 | 'bbox': [box[0], box[1], box[2] - box[0], box[3] - box[1]], # xywh
119 | 'category_id': cls,
120 | 'image_id': ann['image_id'],
121 | 'id': len(coco_dicts['annotations']),
122 | 'iscrowd': 0,
123 | }
124 | coco_dicts['annotations'].append(item)
125 |
126 | _coco.dataset = coco_dicts
127 | _coco.createIndex()
128 | self._coco = _coco
129 |
130 | return self._coco
131 |
132 | def refine_node_edge(obj):
133 | """ remove speical chars in the name. """
134 | obj = obj.replace("_", " ").replace("-", " ")
135 | return obj.strip().lower()
136 |
137 | if __name__ == "__main__":
138 | import os
139 | import sys
140 | import json
141 | from datasets import load_dataset
142 | import torch
143 | from collections import defaultdict
144 |
145 | preds = json.load(open(sys.argv[2]))
146 | db = load_dataset(sys.argv[1])['train']
147 | db_type = 'vg150' if 'psg' not in sys.argv[1] else 'psg'
148 | dataset = MyDataset(db, db_type)
149 |
150 |
151 | ngR = []
152 | mR = defaultdict(list)
153 | ngR_per_image = []
154 | mR_per_image = defaultdict(list)
155 | num_gt_rels = 0
156 | for gt in tqdm(db):
157 | im_id = gt['image_id']
158 | if im_id in preds: # to prevent wrong generated image_id
159 | pred = preds[im_id]
160 | else:
161 | pred = None
162 | gt_rels = json.loads(gt['relationships'])
163 | gt_objects = json.loads(gt['objects'])
164 | gt_boxes = {refine_node_edge(obj['id']): obj['bbox'] for obj in gt_objects}
165 | recall = []
166 | recall_per_cat = defaultdict(list)
167 |
168 | for gt_rel in gt_rels:
169 | num_gt_rels += 1
170 | match = False
171 | gt_pred = refine_node_edge(gt_rel['predicate'])
172 | gt_sub_name = refine_node_edge(gt_rel['subject'])
173 | gt_obj_name = refine_node_edge(gt_rel['object'])
174 |
175 | if pred is not None:
176 | for pred_rel in pred['relation_tuples']:
177 | if refine_node_edge(gt_pred) != refine_node_edge(pred_rel[-1]):
178 | continue
179 |
180 | if gt_sub_name.split('.')[0].strip() != refine_node_edge(pred_rel[0]).split('.')[0].strip() or \
181 | gt_obj_name.split('.')[0].strip() != refine_node_edge(pred_rel[2]).split('.')[0].strip():
182 | continue
183 |
184 | sub_iou = compute_iou(gt_boxes[gt_sub_name], pred_rel[1])
185 | obj_iou = compute_iou(gt_boxes[gt_obj_name], pred_rel[3])
186 | if sub_iou >= 0.5 and obj_iou >= 0.5:
187 | match = True
188 | break
189 |
190 | recall.append(match)
191 | ngR.append(match)
192 | mR[gt_pred].append(match)
193 | recall_per_cat[gt_pred].append(match)
194 |
195 | if len(recall) > 0:
196 | ngR_per_image.append(sum(recall) / len(recall) )
197 |
198 | for k in recall_per_cat.keys():
199 | mR_per_image[k].append(sum(recall_per_cat[k]) / len(recall_per_cat[k]) )
200 |
201 | mR_list = []
202 | for k in mR.keys():
203 | tmp = round(np.mean(mR[k]), 4)
204 | mR_list.append((k, tmp))
205 |
206 | ngR_per_image = np.mean(ngR_per_image)
207 | mR_per_image = [(cat, round(np.mean(mR_per_image[cat]), 4) ) for cat in mR_per_image.keys()]
208 |
209 |
210 | sgg_evaluator = SggEvaluator(dataset, iou_types=("bbox","relation"),
211 | num_workers=4,
212 | num_rel_category=len(dataset.ind_to_predicates))
213 |
214 | def to_torch(item):
215 | for k in item.keys():
216 | try:
217 | item[k] = torch.as_tensor(item[k])
218 | except:
219 | pass
220 |
221 |
222 | k0 = None
223 | for k in tqdm(preds.keys()):
224 | k0 = k
225 | to_torch(preds[k])
226 | if 'graph' in preds[k]:
227 | graph = preds[k]['graph']
228 | to_torch(graph)
229 | preds[k]['graph'] = graph
230 |
231 | print('id:', k0, ' v:', preds[k0])
232 |
233 | sgg_evaluator.update(preds)
234 | sgg_evaluator.synchronize_between_processes()
235 |
236 | sgg_res = sgg_evaluator.accumulate()
237 | sgg_evaluator.summarize()
238 | sgg_evaluator.reset()
239 |
240 |
241 | print("whole ng recall list:", mR_list)
242 | print(f'whole ngR: {np.mean(ngR) * 100:.2f}')
243 | print(f'whole mean of ngR: {sum([e[1] for e in mR_list]) / len(mR_list) * 100:.2f}')
244 | print(f'ngR per image:{ngR_per_image * 100:.2f}')
245 | print(f'mean ngR per image: {sum([e[1] for e in mR_per_image]) / len(mR_per_image) * 100:.2f}')
246 | print(f'mean ngR list:{mR_per_image}')
247 |
--------------------------------------------------------------------------------
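Usage sketch (not part of the repo): the __main__ block above expects a Hugging Face dataset name as sys.argv[1] and a predictions JSON as sys.argv[2], keyed by image_id. The recall loop only relies on "relation_tuples" entries of the form [subject, subject_box, object, object_box, predicate] with [x1, y1, x2, y2] boxes; any extra fields (e.g. a "graph" dict) are tensorized and forwarded to SggEvaluator, whose exact schema lives in src/utils/sgg_eval.py. All values below are made up for illustration.

# python src/vg150_eval.py JosephZ/vg150_val_sgg_prompt preds.json
example_preds = {
    "2343728": {  # image_id as it appears in the dataset
        "relation_tuples": [
            ["man.1", [10, 20, 110, 220], "horse.1", [90, 40, 300, 260], "riding"],
        ],
        # optional extra fields (e.g. "graph") are converted to tensors before
        # being passed to SggEvaluator.update(...)
    }
}
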
/tests/test_fsdp.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.distributed as dist
4 | import os
5 | import time
6 | from io import BytesIO
7 | import base64
8 | import json
9 | from contextlib import nullcontext
10 |
11 | import deepspeed
12 | from accelerate.utils import DistributedType
13 | from accelerate import Accelerator
14 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
15 | from transformers import Qwen2VLProcessor, Qwen2VLForConditionalGeneration
16 |
17 | from open_r1.trainer.utils.vllm_client_v2 import VLLMClient
18 | from datasets import load_dataset
19 |
20 |
21 | def encode_image_to_base64(image: Image.Image, format: str = "JPEG") -> str:
22 | buffer = BytesIO()
23 | image.save(buffer, format=format)
24 | return base64.b64encode(buffer.getvalue()).decode("utf-8")
25 |
26 |
27 | def prepare_messages(image):
28 | encoded_image_text = encode_image_to_base64(image)
29 | base64_qwen = f"data:image/jpeg;base64,{encoded_image_text}"
30 |
31 | messages_vllm = [
32 | {"role": "system", "content": "You are a helpful assistant."},
33 | {
34 | "role": "user",
35 | "content": [
36 | {"type": "image_url", "image_url": {"url": base64_qwen}},
37 | {"type": "text", "text": "Describe this image."},
38 | ],
39 | },
40 | ]
41 |
42 | return messages_vllm
43 |
44 |
45 |
46 | def main():
47 | model_name = "Qwen/Qwen2-VL-7B-Instruct"
48 |
49 | accelerator = Accelerator()
50 | device = accelerator.device
51 |
52 | processor = Qwen2VLProcessor.from_pretrained(model_name, max_pixels=512*28*28)
53 | model = Qwen2VLForConditionalGeneration.from_pretrained(
54 | model_name,
55 | torch_dtype=torch.bfloat16,
56 | attn_implementation='flash_attention_2'
57 | )
58 | model = accelerator.prepare_model(model)
59 |
60 |
61 | """ test clients """
62 | def get_gateway_client_id(world_size, rank, gpus_per_node, num_clients):
63 | num_nodes = world_size // gpus_per_node
64 | client_ranks = [
65 | (i % num_nodes) * gpus_per_node + (i // num_nodes)
66 | for i in range(num_clients)
67 | ]
68 | if rank in client_ranks:
69 | return client_ranks.index(rank)
70 | return None
71 |
72 | rank = accelerator.process_index
73 | world_size = accelerator.num_processes
74 | gpus_per_node = torch.cuda.device_count()
75 |
76 |
77 | hosts, ports = [], []
78 | for line in open("ip_port_test.txt"):
79 |         host, port = line.strip().split(':')
80 | hosts.append(host.strip())
81 | ports.append(port.strip())
82 |
83 | num_clients = len(hosts)
84 | client_id = get_gateway_client_id(world_size, rank, gpus_per_node, num_clients)
85 |
86 | # create N=len(hosts) clients
87 | if client_id is not None:
88 | vllm_client = VLLMClient(
89 | hosts, ports,
90 | connection_timeout=360,
91 | client_rank = client_id
92 | )
93 | print("*"*100, "\n Create VLLMClient at rank:", rank, " cliend_rank:", client_id)
94 | else:
95 | vllm_client = None
96 |
97 |
98 | """ test chat """
99 |
100 | if accelerator.is_main_process:
101 |
102 | db = load_dataset("JosephZ/vg150_val_sgg_prompt")['train']
103 | prompts = []
104 | for kk, item in enumerate(tqdm(db)):
105 | if len(prompts) >=128: break
106 | prompt = prepare_messages(item['image'])
107 | prompts.append(prompt)
108 |
109 | print("[INFO] Running vLLM inference...")
110 | t0 = time.time()
111 | prompts = [json.dumps(e) for e in prompts]
112 | print(len(prompts))
113 |
114 | generated_ids = vllm_client.loop.run_until_complete(vllm_client.chat(prompts, n=1, max_tokens=50,
115 | top_p=0.001, top_k=1, temperature=0.01))
116 |
117 | t1 = time.time() - t0
118 | #generated_ids = [torch.as_tensor(e) for e in generated_ids]
119 |         outputs = processor.batch_decode(generated_ids, skip_special_tokens=True)
120 | print(len(outputs))
121 | print("****** vLLM generated text:",
122 | outputs,
123 | " cost:", t1)
124 |
125 | """ test weight synchronization """
126 | max_chunk_size = 100 * 1024 * 1024 # 100 MB
127 | param_chunk = []
128 | current_chunk_size = 0
129 | debug_file = "tests/debug_%s.log" % accelerator.process_index
130 | with open(debug_file, 'w') as fout:
131 | pass
132 |
133 | is_fsdp_used = accelerator.distributed_type == DistributedType.FSDP
134 | deepspeed_plugin = accelerator.state.deepspeed_plugin
135 | zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3
136 | gather_if_zero3 = deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext
137 |
138 | if is_fsdp_used:
139 | print("*"*100, "\n Test FSDP ...\n", "*"*100)
140 | with FSDP.summon_full_params(model, recurse=True, writeback=False):
141 | for name, param in model.named_parameters():
142 | if param.data is None:
143 | continue
144 | if vllm_client is not None:
145 | # Calculate the size of this parameter in bytes
146 | param_size = param.numel() * param.element_size()
147 |
148 | param_chunk.append((name, param.data))
149 | current_chunk_size += param_size
150 |
151 | # When the accumulated chunk reaches or exceeds 100MB, update the model parameters in one chunk.
152 | if current_chunk_size >= max_chunk_size:
153 | if os.path.exists(debug_file):
154 | with open(debug_file, 'a') as fout:
155 | names = [(p[0], p[1].shape) for p in param_chunk]
156 | cmd = f"FSDP --- rank={accelerator.process_index}, send params={names}\n"
157 | fout.write(cmd)
158 | vllm_client.update_model_in_chunks_from_named_list(param_chunk)
159 | # Reset for the next chunk
160 | param_chunk = []
161 | current_chunk_size = 0
162 | else:
163 | print("*"*100, "\n Test non-FSDP ...\n", "*"*100)
164 |         for name, param in model.named_parameters():
165 | with gather_if_zero3([param]): # gather if zero3 used
166 | if vllm_client is not None:
167 | # Calculate the size of this parameter in bytes
168 | param_size = param.numel() * param.element_size()
169 |
170 | param_chunk.append((name, param.data))
171 | current_chunk_size += param_size
172 |
173 | # When the accumulated chunk reaches or exceeds 100MB, update the model parameters in one chunk.
174 | if current_chunk_size >= max_chunk_size:
175 | if os.path.exists(debug_file):
176 | with open(debug_file, 'a') as fout:
177 | names = [(p[0], p[1].shape) for p in param_chunk]
178 | cmd = f"rank={accelerator.process_index}, send params={names}\n"
179 | fout.write(cmd)
180 | vllm_client.update_model_in_chunks_from_named_list(param_chunk)
181 | # Reset for the next chunk
182 | param_chunk = []
183 | current_chunk_size = 0
184 |
185 | # If any parameters remain that didn't reach the 100MB threshold, update them as well.
186 | if param_chunk and vllm_client is not None:
187 | if os.path.exists(debug_file):
188 | with open(debug_file, 'a') as fout:
189 | names = [(p[0], p[1].shape) for p in param_chunk]
190 | cmd = f"rank={accelerator.process_index}, send params={names}\n"
191 | fout.write(cmd)
192 | vllm_client.update_model_in_chunks_from_named_list(param_chunk)
193 |
194 |
195 |
196 |
197 |
198 | if __name__ == "__main__":
199 | main()
200 |
--------------------------------------------------------------------------------
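For reference, get_gateway_client_id above spreads clients node-first: with world_size=8, gpus_per_node=4 and num_clients=2 the client ranks are [0, 4], i.e. one gateway per node. The 100 MB accumulation used in both the FSDP and non-FSDP branches can be factored into a reusable generator; a minimal sketch (not part of the repo), assuming the same update_model_in_chunks_from_named_list client API:

from typing import Iterable, Iterator, List, Tuple
import torch

def iter_param_chunks(
    named_params: Iterable[Tuple[str, torch.Tensor]],
    max_bytes: int = 100 * 1024 * 1024,  # 100 MB, as in the test above
) -> Iterator[List[Tuple[str, torch.Tensor]]]:
    """Yield lists of (name, tensor) whose accumulated size reaches roughly max_bytes."""
    chunk, size = [], 0
    for name, param in named_params:
        chunk.append((name, param.data))
        size += param.numel() * param.element_size()
        if size >= max_bytes:
            yield chunk
            chunk, size = [], 0
    if chunk:  # flush the remainder that never hit the threshold
        yield chunk

# for chunk in iter_param_chunks(model.named_parameters()):
#     vllm_client.update_model_in_chunks_from_named_list(chunk)
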
/tests/test_sampler.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from typing import Optional
3 |
4 | from transformers import Trainer, TrainingArguments
5 |
6 | from torch.utils.data import Sampler
7 | from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed
8 |
9 | BATCH_PER_DEVICE=2
10 | GRAD_ACC=3
11 | # with 4 processes: BATCH_PER_DEVICE * num_processes * GRAD_ACC // mini_repeat_count = 2*4*3 // 8 = 3 unique prompts per step
12 |
13 | class RepeatRandomSampler(Sampler):
14 | """
15 | Sampler that repeats the indices of a dataset in a structured manner.
16 |
17 | Args:
18 | data_source (`Sized`):
19 | Dataset to sample from.
20 | mini_repeat_count (`int`):
21 | Number of times to repeat each index per batch.
22 | batch_size (`int`, *optional*, defaults to `1`):
23 | Number of unique indices per batch.
24 | repeat_count (`int`, *optional*, defaults to `1`):
25 | Number of times to repeat the full sampling process.
26 | seed (`int` or `None`, *optional*, defaults to `None`):
27 | Random seed for reproducibility (only affects this sampler).
28 |
29 | Example:
30 | ```python
31 | >>> sampler = RepeatRandomSampler(["a", "b", "c", "d", "e", "f", "g"], mini_repeat_count=2, batch_size=3, repeat_count=4)
32 | >>> list(sampler)
33 | [4, 4, 3, 3, 0, 0,
34 | 4, 4, 3, 3, 0, 0,
35 | 4, 4, 3, 3, 0, 0,
36 | 4, 4, 3, 3, 0, 0,
37 |
38 | 1, 1, 2, 2, 6, 6,
39 | 1, 1, 2, 2, 6, 6,
40 | 1, 1, 2, 2, 6, 6,
41 | 1, 1, 2, 2, 6, 6]
42 | ```
43 |
44 | ```txt
45 | mini_repeat_count = 3
46 | - - -
47 | [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, |
48 | 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, |
49 | 8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, |
50 | repeat_count = 2
51 | 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, |
52 | 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, |
53 | 8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, ...] |
54 | --------- --------- --------- ---------
55 | --------- --------- --------- ---------
56 | --------- --------- --------- ---------
57 | batch_size = 12
58 | ```
59 | """
60 |
61 | def __init__(
62 | self,
63 |         data_source,
64 | mini_repeat_count: int,
65 | batch_size: int = 1,
66 | repeat_count: int = 1,
67 | seed: Optional[int] = None,
68 | ):
69 | self.data_source = data_source
70 | self.mini_repeat_count = mini_repeat_count
71 | self.batch_size = batch_size
72 | self.repeat_count = repeat_count
73 | self.num_samples = len(data_source)
74 | self.seed = seed
75 | self.generator = torch.Generator() # Create a local random generator
76 | if seed is not None:
77 | self.generator.manual_seed(seed)
78 |
79 | def __iter__(self):
80 | # E.g., [2, 4, 3, 1, 0, 6, 5] (num_samples = 7)
81 | indexes = torch.randperm(self.num_samples, generator=self.generator).tolist()
82 |
83 | # [2, 4, 3, 1, 0, 6, 5]
84 | # -> [[2, 4, 3], [1, 0, 6], [5]] (batch_size = 3)
85 | indexes = [indexes[i : i + self.batch_size] for i in range(0, len(indexes), self.batch_size)]
86 |
87 | # [[2, 4, 3], [1, 0, 6], [5]]
88 | # -> [[2, 4, 3], [1, 0, 6]]
89 | indexes = [chunk for chunk in indexes if len(chunk) == self.batch_size]
90 |
91 | for chunk in indexes:
92 | for _ in range(self.repeat_count):
93 | for index in chunk:
94 | for _ in range(self.mini_repeat_count):
95 | yield index
96 |
97 | def __len__(self) -> int:
98 | return self.num_samples * self.mini_repeat_count * self.repeat_count
99 |
100 |
101 |
102 |
103 |
104 | # Dummy dataset
105 | class DummyDataset(torch.utils.data.Dataset):
106 | def __init__(self, size=100):
107 | self.data = list(range(size))
108 |
109 | def __len__(self):
110 | return len(self.data)
111 |
112 | def __getitem__(self, idx):
113 | return {
114 | "input_ids": torch.tensor([idx]),
115 | "labels": torch.tensor([idx]),
116 | }
117 |
118 | # Dummy model
119 | class DummyModel(torch.nn.Module):
120 | def __init__(self):
121 | super().__init__()
122 | self.linear = torch.nn.Linear(1, 1)
123 |
124 | def forward(self, input_ids=None, labels=None):
125 | outputs = self.linear(input_ids.float())
126 | return {"logits": outputs}
127 |
128 |
129 | class MyTrainer(Trainer):
130 | def _get_train_sampler(self):
131 | effective_batch_size = BATCH_PER_DEVICE * self.accelerator.num_processes * GRAD_ACC
132 | print("effective_batch_size:", effective_batch_size)
133 | return RepeatRandomSampler(
134 | data_source=self.train_dataset,
135 | mini_repeat_count=8,
136 | batch_size=effective_batch_size//8,
137 | repeat_count=1*GRAD_ACC,
138 | seed=self.args.seed,
139 | )
140 |
141 | def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
142 | input_ids = inputs['input_ids']
143 | labels = inputs['labels']
144 | all_labels = self.accelerator.gather(labels)
145 |
146 | rank = self.accelerator.process_index
147 | if rank == 0:
148 | print("rank:", self.accelerator.process_index, "global_step:", self.state.global_step, "labels:", all_labels.tolist(), "\n")
149 |
150 | # Forward pass
151 | outputs = model(input_ids)
152 |
153 | logits = outputs['logits']
154 | loss = torch.nn.functional.mse_loss(logits, labels.float())
155 |
156 | if return_outputs:
157 | return loss, outputs
158 | return loss
159 |
160 | # Run dummy training
161 | if __name__ == "__main__":
162 | dataset = DummyDataset(size=300)
163 | model = DummyModel()
164 |
165 | training_args = TrainingArguments(
166 | per_device_train_batch_size=BATCH_PER_DEVICE*GRAD_ACC,
167 | gradient_accumulation_steps=GRAD_ACC,
168 | num_train_epochs=1,
169 | logging_steps=1,
170 | save_steps=10,
171 | logging_dir="./logs",
172 | disable_tqdm=False,
173 | seed=42,
174 | report_to=None
175 | )
176 |
177 | trainer = MyTrainer(
178 | model=model,
179 | args=training_args,
180 | train_dataset=dataset,
181 | )
182 |
183 | trainer.train()
184 |
--------------------------------------------------------------------------------
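A quick sanity check of the sampler arithmetic used in MyTrainer (not part of the repo; it assumes RepeatRandomSampler from this file is in scope and 4 training processes): the effective batch is BATCH_PER_DEVICE * num_processes * GRAD_ACC = 2*4*3 = 24 prompts per optimizer step, i.e. 24 // 8 = 3 unique prompts each repeated 8 times, with the chunk re-emitted repeat_count = GRAD_ACC = 3 times so every accumulation step sees the same prompts.

sampler = RepeatRandomSampler(list(range(12)), mini_repeat_count=8, batch_size=3, repeat_count=3, seed=0)
indices = list(sampler)
assert len(indices) == 12 * 8 * 3                         # matches __len__
assert indices[:24] == indices[24:48] == indices[48:72]   # same chunk, once per accumulation step
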
/tests/test_vllm.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import base64
3 | from io import BytesIO
4 | import json
5 | import time
6 | from typing import List
7 | from transformers import AutoProcessor
8 |
9 | import torch
10 | from PIL import Image
11 |
12 | from open_r1.trainer.utils.vllm_client_v2 import VLLMClient
13 | from datasets import load_dataset
14 | from tqdm import tqdm
15 |
16 | from transformers import Qwen2VLForConditionalGeneration
17 |
18 | SYSTEM_PROMPT = (
19 | "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
20 | "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
21 | "process and answer are enclosed within and tags, respectively, i.e., "
22 | " reasoning process here answer here "
23 | )
24 |
25 |
26 | def encode_image_to_base64(image: Image.Image, format: str = "JPEG") -> str:
27 | buffer = BytesIO()
28 | image.save(buffer, format=format)
29 | return base64.b64encode(buffer.getvalue()).decode("utf-8")
30 |
31 |
32 | def prepare_messages(item):
33 | def replace_answer_format(item: str) -> str:
34 |         return item.replace("<answer>", "```json").replace("</answer>", "```")
35 |
36 | image = item['image']
37 | org_iw, org_ih = image.size
38 |
39 | prompt = item['prompt_open']
40 | prompt = prompt.replace(f"of size ({org_iw} x {org_ih}) ", "")
41 | prompt = replace_answer_format(prompt)
42 |
43 | encoded_image_text = encode_image_to_base64(image)
44 | base64_qwen = f"data:image/jpeg;base64,{encoded_image_text}"
45 |
46 | messages_vllm = [
47 | {"role": "system",
48 | "content": SYSTEM_PROMPT
49 | },
50 | {
51 | "role": "user",
52 | "content": [
53 | {"type": "image_url", "image_url": {"url": base64_qwen}},
54 | {"type": "text", "text": prompt},
55 | ],
56 | },
57 | ]
58 |
59 | return messages_vllm
60 |
61 |
62 |
63 |
64 | def main(args):
65 | db = load_dataset("JosephZ/vg150_val_sgg_prompt")['train']
66 |
67 | print(f"[INFO] Connecting to vLLM server at {args.hosts}:{args.server_port}")
68 | processor = AutoProcessor.from_pretrained(args.model_name_or_path)
69 | prompts = []
70 | for kk, item in enumerate(tqdm(db)):
71 | if kk > 10: break
72 | prompt = prepare_messages(item)
73 | prompts.append(prompt)
74 |
75 |
76 | client = VLLMClient(
77 | hosts=args.hosts, #.split(','),
78 | server_ports=args.server_port,
79 | group_port=args.group_port,
80 | connection_timeout=60,
81 | )
82 | model = Qwen2VLForConditionalGeneration.from_pretrained(
83 | "Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto"
84 | )
85 | for name, param in model.named_parameters():
86 | client.update_named_param(name, param)
87 |
88 | print("[INFO] Running vLLM inference...")
89 | t0 = time.time()
90 | prompts = [json.dumps(e) for e in prompts]
91 | print(len(prompts))
92 |
93 | generated_ids = client.loop.run_until_complete(client.chat(prompts, n=8, max_tokens=1024,
94 | top_p=0.001, top_k=1, temperature=0.01))
95 |
96 | t1 = time.time() - t0
97 | #generated_ids = [torch.as_tensor(e) for e in generated_ids]
98 |     outputs = processor.batch_decode(generated_ids, skip_special_tokens=True)
99 |     print(len(outputs))
100 |     print("****** vLLM generated text:",
101 |           outputs[0],
102 |           " cost:", t1)
103 |
104 |
105 |
106 | def cal_cost(client, model, lens):
107 | cost = []
108 | for i in range(3):
109 | t0 = time.time()
110 | #client.update_model_in_chunks(model, lens)
111 |
112 | named_params = list(model.named_parameters())
113 | chunk_size = lens # or tune based on memory
114 |
115 | for i in range(0, len(named_params), chunk_size):
116 | chunk = named_params[i:i+chunk_size]
117 | client.update_model_in_chunks_from_named_list(chunk)
118 |
119 | t1 = time.time()
120 | cost.append(t1-t0)
121 | return sum(cost)/len(cost)
122 |
123 | def cal_cost_by_size(client, model, max_bytes):
124 | cost = []
125 | for i in range(3):
126 | t0 = time.time()
127 | chunks = [] # List to accumulate (name, param) tuples
128 | current_chunk_bytes = 0 # Accumulated memory size in bytes
129 |
130 | for name, param in model.named_parameters():
131 | param_bytes = param.numel() * param.element_size()
132 |
133 | # If adding this parameter would exceed the max_bytes limit
134 | if current_chunk_bytes + param_bytes > max_bytes:
135 | # Process the current chunk if not empty
136 | if chunks:
137 | client.update_model_in_chunks_from_named_list(chunks)
138 | chunks = []
139 | current_chunk_bytes = 0
140 |
141 | # If the parameter itself exceeds max_bytes, process it individually
142 | if param_bytes > max_bytes:
143 | client.update_model_in_chunks_from_named_list([(name, param)])
144 | else:
145 | # Otherwise, add the parameter to the current chunk
146 | chunks.append((name, param))
147 | current_chunk_bytes += param_bytes
148 |
149 | # Process any remaining parameters
150 | if chunks:
151 | client.update_model_in_chunks_from_named_list(chunks)
152 |
153 | t1 = time.time()
154 | cost.append(t1 - t0)
155 | return sum(cost) / len(cost)
156 |
157 |
158 |
159 | for k in range(1, 10):
160 | try:
161 | GB = (1<<30) * 0.1 * k
162 | print(f"update cost with chunk size={k} GB:", cal_cost_by_size(client, model, GB))
163 | except:
164 | print("Timeout at", k)
165 | break
166 |
167 |
168 | if __name__ == "__main__":
169 | parser = argparse.ArgumentParser()
170 | parser.add_argument("--hosts", type=str, default="[127.0.0.1]", help="Host address of the vLLM server.")
171 | parser.add_argument("--server_port", type=str, default='8000', help="Port for vLLM API requests.")
172 | parser.add_argument("--group_port", type=int, default=51216, help="Port for NCCL communication.")
173 | parser.add_argument("--model_name_or_path", type=str, default="Qwen/Qwen2-VL-7B-Instruct", help="Model ID or path.")
174 | args = parser.parse_args()
175 | main(args)
176 |
--------------------------------------------------------------------------------
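A small companion sketch (not part of the repo): the cal_cost_by_size sweep above chunks a full weight synchronization, so it helps to know how many bytes one sync moves in total; total_param_bytes is a hypothetical helper using the same accounting as the chunking loops.

import torch
from transformers import Qwen2VLForConditionalGeneration

def total_param_bytes(model: torch.nn.Module) -> int:
    # Same accounting as the chunking loops: numel * element_size per parameter.
    return sum(p.numel() * p.element_size() for p in model.parameters())

if __name__ == "__main__":
    model = Qwen2VLForConditionalGeneration.from_pretrained(
        "Qwen/Qwen2-VL-7B-Instruct", torch_dtype=torch.bfloat16
    )
    print(f"~{total_param_bytes(model) / (1 << 30):.1f} GiB per synchronization")  # bf16 ~ 2 bytes/param
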
/tests/test_vllm_local.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import base64
3 | from io import BytesIO
4 | import json
5 | import time
6 | import random
7 | from typing import List
8 | from transformers import AutoProcessor
9 |
10 | import torch
11 | from PIL import Image
12 |
13 | from open_r1.trainer.utils.vllm_client_v2 import VLLMClient
14 | from datasets import load_dataset
15 | from tqdm import tqdm
16 |
17 | from transformers import Qwen2VLForConditionalGeneration
18 | from transformers import FineGrainedFP8Config
19 |
20 | SYSTEM_PROMPT = (
21 | "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
22 | "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
23 | "process and answer are enclosed within and tags, respectively, i.e., "
24 | " reasoning process here answer here "
25 | )
26 |
27 |
28 | def encode_image_to_base64(image: Image.Image, format: str = "JPEG") -> str:
29 | buffer = BytesIO()
30 | image.save(buffer, format=format)
31 | return base64.b64encode(buffer.getvalue()).decode("utf-8")
32 |
33 |
34 | def prepare_messages(item):
35 | def replace_answer_format(item: str) -> str:
36 |         return item.replace("<answer>", "```json").replace("</answer>", "```")
37 |
38 | image = item['image']
39 | org_iw, org_ih = image.size
40 |
41 | prompt = item['prompt_open']
42 | prompt = prompt.replace(f"of size ({org_iw} x {org_ih}) ", "")
43 | prompt = replace_answer_format(prompt)
44 |
45 | encoded_image_text = encode_image_to_base64(image)
46 | base64_qwen = f"data:image/jpeg;base64,{encoded_image_text}"
47 |
48 | messages_vllm = [
49 | {"role": "system",
50 | "content": SYSTEM_PROMPT
51 | },
52 | {
53 | "role": "user",
54 | "content": [
55 | {"type": "image_url", "image_url": {"url": base64_qwen}},
56 | {"type": "text", "text": prompt},
57 | ],
58 | },
59 | ]
60 | return messages_vllm
61 |
62 |
63 | def main(args):
64 | db = load_dataset("JosephZ/vg150_val_sgg_prompt")['train']
65 |
66 | processor = AutoProcessor.from_pretrained(args.model_name_or_path, max_pixels=512*28*28)
67 | prompts = []
68 | for kk, item in enumerate(tqdm(db)):
69 | if kk > 200: break
70 | prompt = prepare_messages(item)
71 | prompts.append(prompt)
72 |
73 | use_fp8 = True
74 |
75 | quantization_config = FineGrainedFP8Config()
76 | kwargs = {"device_map": "cuda:0", 'attn_implementation': 'flash_attention_2'}
77 | if use_fp8:
78 | kwargs['quantization_config'] = quantization_config
79 |
80 |
81 | client = VLLMClient(
82 | local_vllm=True,
83 | model_name=args.model_name_or_path,
84 | max_pixels=512*28*28,
85 | use_fp8=use_fp8,
86 | device='cuda:0',
87 | gpu_memory_utilization=0.9
88 | )
89 |
90 |
91 |
92 |
93 | print("[INFO] Running vLLM inference...")
94 | t0 = time.time()
95 | print("len(prompts):", len(prompts))
96 |
97 | generated_ids = client.run_chat([prompts[0]], n=8, max_tokens=1024,
98 | top_p=0.95, top_k=50, temperature=1.0)
99 |
100 | t1 = time.time() - t0
101 | #generated_ids = [torch.as_tensor(e) for e in generated_ids]
102 |     outputs = processor.batch_decode(generated_ids, skip_special_tokens=True)
103 |     print(len(outputs))
104 |     print("****** vLLM generated text:")
105 |     for i in range(8):
106 |         print(outputs[i])
107 |
108 | print(" cost:", t1)
109 | # benchmark speed
110 | cost = []
111 | t0 = time.time()
112 | BS = 8
113 | N = 5
114 | for i in tqdm(range(N)):
115 | s0 = time.time()
116 | generated_ids = client.run_chat(prompts[2+i*BS: 2+(i+1)*BS], n=8, max_tokens=1024,
117 | top_p=0.95, top_k=50, temperature=1.0)
118 | s1 = time.time()
119 | cost.append(s1-s0)
120 |
121 | t1 = time.time()
122 | print("Benchmark speed :", (t1-t0)/ (BS*N), "second / item", " Batch cost: ", sum(cost)/len(cost) )
123 |     quit()  # end of the benchmark run; the weight-sync check below only runs if this line is removed
124 | model = Qwen2VLForConditionalGeneration.from_pretrained(
125 | "Qwen/Qwen2-VL-7B-Instruct", torch_dtype=torch.bfloat16,
126 | **kwargs
127 | )
128 |
129 |
130 | # check weight sync.
131 | llmp = client.llm.llm_engine.model_executor.driver_worker.model_runner.model
132 | llmp_dicts = dict(llmp.named_parameters())
133 |
134 | miss1 = []
135 | with torch.no_grad():
136 | for name, param in model.named_parameters():
137 | # Fetch corresponding param from llmp
138 | llmp_param = llmp_dicts.get(name, None)
139 | if llmp_param is None:
140 | #print(f"[WARN] Parameter {name} not found in vLLM model.")
141 | miss1.append( (name, param.data.min()) )
142 | param.data += 100
143 | continue
144 |
145 | # Compare tensors
146 | if not torch.allclose(param.data, llmp_param.data, atol=1e-5):
147 | print(f"[FAIL] Mismatch in param '{name}'")
148 | else:
149 | print(f"[PASS] Param '{name}' is synchronized.")
150 | param.data += 100
151 |
152 |
153 | print("\n", "*"*100, "start weight synchronization ...\n", "*"*100, "\n")
154 | max_chunk_size = 100 * 1024 * 1024 # 100 MB
155 | param_chunk = []
156 | current_chunk_size = 0
157 | del llmp_dicts
158 |
159 | t0 = time.time()
160 | updated_params = set()
161 | for name, param in model.named_parameters():
162 | # Calculate the size of this parameter in bytes
163 | param_size = param.numel() * param.element_size()
164 |
165 | param_chunk.append((name, param.data))
166 | current_chunk_size += param_size
167 |
168 | # When the accumulated chunk reaches or exceeds 100MB, update the model parameters in one chunk.
169 | if current_chunk_size >= max_chunk_size:
170 | old = client.update_model_in_chunks_from_named_list(param_chunk)
171 | updated_params.update(old)
172 | # Reset for the next chunk
173 | param_chunk = []
174 | current_chunk_size = 0
175 |
176 | if param_chunk and client is not None:
177 | client.update_model_in_chunks_from_named_list(param_chunk)
178 | t1 = time.time()
179 | print("weight synchronization cost:", t1-t0)
180 | # check again
181 | llmp_dicts = dict(llmp.named_parameters())
182 |
183 | miss2 = []
184 | for name, param in model.named_parameters():
185 | # Fetch corresponding param from llmp
186 | llmp_param = llmp_dicts.get(name, None)
187 | if llmp_param is None:
188 | miss2.append((name, param.data.min()))
189 | #print(f"[WARN] Parameter {name} not found in vLLM model.")
190 | continue
191 | # Compare tensors
192 | if not torch.allclose(param.data, llmp_param.data, atol=1e-5):
193 | print(f"[FAIL] Mismatch in param '{name}'")
194 | else:
195 | print(f"[PASS] Param '{name}' is synchronized.")
196 |
197 | #import pdb; pdb.set_trace()
198 |
199 | def cal_cost(client, model, lens):
200 | cost = []
201 | for i in range(3):
202 | t0 = time.time()
203 | #client.update_model_in_chunks(model, lens)
204 |
205 | named_params = list(model.named_parameters())
206 | chunk_size = lens # or tune based on memory
207 |
208 | for i in range(0, len(named_params), chunk_size):
209 | chunk = named_params[i:i+chunk_size]
210 | client.update_model_in_chunks_from_named_list(chunk)
211 |
212 | t1 = time.time()
213 | cost.append(t1-t0)
214 | return sum(cost)/len(cost)
215 |
216 | def cal_cost_by_size(client, model, max_bytes):
217 | cost = []
218 | for i in range(3):
219 | t0 = time.time()
220 | chunks = [] # List to accumulate (name, param) tuples
221 | current_chunk_bytes = 0 # Accumulated memory size in bytes
222 |
223 | for name, param in model.named_parameters():
224 | param_bytes = param.numel() * param.element_size()
225 |
226 | # If adding this parameter would exceed the max_bytes limit
227 | if current_chunk_bytes + param_bytes > max_bytes:
228 | # Process the current chunk if not empty
229 | if chunks:
230 | client.update_model_in_chunks_from_named_list(chunks)
231 | chunks = []
232 | current_chunk_bytes = 0
233 |
234 | # If the parameter itself exceeds max_bytes, process it individually
235 | if param_bytes > max_bytes:
236 | client.update_model_in_chunks_from_named_list([(name, param)])
237 | else:
238 | # Otherwise, add the parameter to the current chunk
239 | chunks.append((name, param))
240 | current_chunk_bytes += param_bytes
241 |
242 | # Process any remaining parameters
243 | if chunks:
244 | client.update_model_in_chunks_from_named_list(chunks)
245 |
246 | t1 = time.time()
247 | cost.append(t1 - t0)
248 | return sum(cost) / len(cost)
249 |
250 |
251 |
252 | for k in range(1, 10):
253 | try:
254 | GB = (1<<30) * 0.1 * k
255 | print(f"update cost with chunk size={k} GB:", cal_cost_by_size(client, model, GB))
256 | except:
257 | print("Timeout at", k)
258 | break
259 |
260 |
261 | if __name__ == "__main__":
262 | parser = argparse.ArgumentParser()
263 | parser.add_argument("--model_name_or_path", type=str, default="Qwen/Qwen2-VL-7B-Instruct", help="Model ID or path.")
264 | args = parser.parse_args()
265 | main(args)
266 |
267 |
--------------------------------------------------------------------------------