├── .gitignore ├── LICENSE ├── README.md ├── install.sh ├── local_scripts ├── fsdp.yaml ├── fsdp_config.json ├── zero2.json ├── zero2_offload.json ├── zero3++.json └── zero3.json ├── open_r1 ├── __init__.py ├── evaluate.py ├── generate.py ├── grpo.py └── trainer │ ├── __init__.py │ ├── grpo_config.py │ ├── grpo_trainer.py │ └── utils │ ├── __init__.py │ ├── misc.py │ ├── prompt_gallery.py │ ├── vllm_client.py │ └── vllm_client_v2.py ├── requirements.txt ├── scripts ├── .DS_Store ├── debug │ ├── debug.sh │ ├── debug_gh200.sh │ ├── debug_gh200_local_vllm.sh │ └── test_vllm.sh ├── grpo │ ├── run_vllm.sh │ ├── train_a100.sh │ ├── train_a100_2B.sh │ ├── train_a100_2B_SFT.sh │ ├── train_a100_2B_close.sh │ ├── train_a100_2B_close_SFT.sh │ ├── train_a100_SFT.sh │ ├── train_a100_close.sh │ ├── train_fsdp.sh │ ├── train_fused.sh │ ├── train_gh200.sh │ ├── train_gh200_2B.sh │ ├── train_gh200_2B_SFT.sh │ ├── train_gh200_2B_close.sh │ ├── train_gh200_SFT.sh │ ├── train_gh200_close.sh │ ├── train_gh200_close_SFT.sh │ └── train_zero3.sh ├── inference │ └── run_sgg_inference.sh ├── sft │ ├── 2B_sgg.sh │ ├── 2B_sgg_predefined.sh │ ├── 7B_sgg.sh │ ├── 7B_sgg_lora.sh │ └── 7B_sgg_predefined.sh └── sft_local │ ├── 2B_sgg.sh │ ├── 2B_sgg_predefined.sh │ ├── 7B_sgg.sh │ └── 7B_sgg_predefined.sh ├── setup.py ├── src ├── __init__.py ├── mega_1m_category.py ├── psg_categories.json ├── sft_sgg.py ├── sgg_gather_preds.py ├── sgg_inference_vllm.py ├── utils │ ├── __init__.py │ ├── bbox_overlaps.py │ ├── bounding_box.py │ ├── cocoeval.py │ ├── misc.py │ ├── sgg_eval.py │ ├── sgg_metrics.py │ ├── wordnet.py │ └── zeroshot_triplet.pytorch ├── vg150_eval.py ├── vg_synonyms.py └── vllm_server_v2.py └── tests ├── test_fsdp.py ├── test_rewards.py ├── test_sampler.py ├── test_vllm.py └── test_vllm_local.py /.gitignore: -------------------------------------------------------------------------------- 1 | *__pycache__* 2 | *pyc 3 | *.out 4 | *.log 5 | *.egg-info* 6 | 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # R1-SGG: Compile Scene Graphs with Reinforcement Learning 2 | 3 | ## **Structured Visual Reasoning with Multimodal LLMs and Reinforcement Learning** 4 | [![Paper](https://img.shields.io/badge/arXiv-2504.13617-b31b1b.svg)](https://arxiv.org/abs/2504.13617) [![License](https://img.shields.io/badge/license-Apache--2.0-blue.svg)](LICENSE) [![Hugging Face](https://img.shields.io/badge/HuggingFace-Demo-orange?logo=huggingface)](https://huggingface.co/spaces/JosephZ/R1-SGG) 5 | --- 6 | 7 | ## 🚀 Update 8 | - ✅ ![Hugging Face](https://img.shields.io/badge/HuggingFace-Model-orange?logo=huggingface)[R1-SGG-7B](https://huggingface.co/JosephZ/R1-SGG-7B), [R1-SGG-Zero-7B](https://huggingface.co/JosephZ/R1-SGG-Zero-7B) 9 | - ✅ Support [PSG](https://github.com/Jingkang50/OpenPSG) dataset (bbox format only, not Panoptic) 10 | - ✅ Updated loss implementation 11 | - ✅ Always use `custom_per_device_train_batch_size` instead of `per_device_train_batch_size` for faster sampling under gradient accumulation 12 | - ⚠️ Current loss implementation might still be affected by gradient accumulation: [trl issue #3021](https://github.com/huggingface/trl/issues/3021) 13 | 14 | --- 15 | 16 | ## 🛠️ Setup Environment 17 | ```bash 18 | bash install.sh 19 | ``` 20 | Main dependencies: 21 | ```bash 22 | - torch == 2.5.0 or 2.5.1 (cu124, optional) 23 | - transformers (supports Qwen2VL, Qwen2.5VL) 24 | - trl 25 | - vLLM 26 | ``` 27 | 28 | --- 29 | 30 | ## 📚 Dataset 31 | Load preprocessed datasets via: 32 | ```python 33 | from datasets import load_dataset 34 | 35 | db_train = load_dataset("JosephZ/vg150_train_sgg_prompt")["train"] 36 | db_val = load_dataset("JosephZ/vg150_val_sgg_prompt")["train"] 37 | ``` 38 | or for PSG: 39 | ```python 40 | db_train = load_dataset("JosephZ/psg_train_sg")["train"] # keys: image_id, image, objects, relationships 41 | db_val = load_dataset("JosephZ/psg_test_sg")["train"] 42 | ``` 43 | We transformed VG150 into HuggingFace Datasets format with keys: 44 | - `image_id` 45 | - `image` 46 | - `prompt_open` 47 | - `prompt_close` 48 | - `objects` 49 | - `relationships` 50 | 51 | --- 52 | 53 | ## 🔥 Supported Models 54 | - [x] Qwen/Qwen2-VL-2B-Instruct 55 | - [x] Qwen/Qwen2-VL-7B-Instruct 56 | - [x] Qwen/Qwen2.5-VL-3B-Instruct 57 | - [x] Qwen/Qwen2.5-VL-7B-Instruct 58 | 59 | --- 60 | 61 | ## 🏋️‍♂️ Training 62 | 63 | ### Training with Supervised Fine-Tuning (SFT) 64 | 65 | For **SLURM users**: 66 | ```bash 67 | sbatch scripts/sft/7B_sgg.sh 68 | ``` 69 | 70 | For **local machines**: 71 | ```bash 72 | bash scripts/sft_local/7B_sgg.sh 73 | ``` 74 | ⏱️ Approximate training time: 75 | - 2B models: ~4 hours (4×A100 SXM4 GPUs) 76 | - 7B models: ~10 hours (4×A100 SXM4 GPUs) 77 | 78 | --- 79 | 80 | ### Training with Reinforcement Learning (GRPO) 81 | ** Update (11/05/2025): to use "Hard Recall"**: 82 | ``` 83 | --reward_funcs format_reward edge_hard_reward 84 | ``` 85 | 86 | For **A100 GPUs**: 87 | ```bash 88 | sbatch scripts/grpo/train_a100_2B.sh 89 | ``` 90 | (12 hours on 16×A100 GPUs) 91 | 92 | For **GH200 GPUs**: 93 | ```bash 94 | sbatch scripts/grpo/train_gh200.sh 95 | ``` 96 | (16 hours on 16×GH200 GPUs) 97 | 98 | For clusters with many RTX_3090/4090 GPUs: 99 | ```bash 100 | sbatch scripts/grpo/train_fused.sh 101 | ``` 102 | - Training 7B models on 24GB cards is possible with Zero3, but slow due to communication bottlenecks. 
103 | - (Fun fact: training with 120×RTX_4090 is crazy but severely limited by communication latency.) 104 | 105 | 💡 **Recommended learning rate**: `6e-7`. 106 | 107 | --- 108 | 109 | ## 🧪 Inference and Evaluation 110 | 111 | ### Inference with SFT-trained models: 112 | ```bash 113 | bash scripts/inference/run_sgg_inference.sh $DATASET $MODEL_NAME $OUTPUT_DIR 114 | ``` 115 | For models trained **with predefined categories**, add `true`: 116 | ```bash 117 | bash scripts/inference/run_sgg_inference.sh $DATASET $MODEL_NAME $OUTPUT_DIR true 118 | ``` 119 | 120 | ### Inference with GRPO-trained models: 121 | ```bash 122 | bash scripts/inference/run_sgg_inference.sh $DATASET $MODEL_NAME $OUTPUT_DIR false/true true 123 | ``` 124 | 125 | ### Evaluation: 126 | ```bash 127 | DATASET_TYPE=vg # or psg 128 | python src/sgg_gather_preds.py $DATASET_TYPE $OUTPUT_DIR sgg_pred_results.json 129 | python src/vg150_eval.py $DATASET sgg_pred_results.json 130 | ``` 131 | 132 | --- 133 | 134 | ## 🤝 Acknowledgement 135 | The `GRPOTrainer` used in this project is based on [trl's GRPOTrainer](https://github.com/huggingface/trl/blob/main/trl/trainer/grpo_trainer.py), extended to support multimodal inputs. 136 | 137 | --- 138 | 139 | ## 📖 Citation 140 | If you find this work helpful, please cite: 141 | ```bibtex 142 | @article{chen2025compile, 143 | title={Compile Scene Graphs with Reinforcement Learning}, 144 | author={Chen, Zuyao and Wu, Jinlin and Lei, Zhen and Pollefeys, Marc and Chen, Chang Wen}, 145 | journal={arXiv preprint arXiv:2504.13617}, 146 | year={2025} 147 | } 148 | ``` 149 | 150 | --- 151 | 152 | # ✨ Happy Compiling! 153 | 154 | -------------------------------------------------------------------------------- /install.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | pip install torch==2.5.1 torchvision==0.20.1 torchaudio==2.5.1 --index-url https://download.pytorch.org/whl/cu124 5 | 6 | #pip install transformers@git+https://github.com/huggingface/transformers.git@2c2495cc7b0e3e2942a9310f61548f40a2bc8425 7 | pip install transformers==4.50.3 8 | 9 | pip install trl@git+https://github.com/huggingface/trl.git@ece6738686a8527345532e6fed8b3b1b75f16b16 10 | 11 | pip install --upgrade --no-build-isolation flash-attn==2.7.4.post1 12 | 13 | # for GH200, 14 | #MAX_JOBS=20 pip install --upgrade --no-build-isolation flash-attn==2.7.4.post1 15 | 16 | #git clone https://github.com/triton-lang/triton.git && git checkout 85267600 && cd triton && \ 17 | #pip install -r python/requirements.txt # build-time dependencies && \ 18 | #pip install -e python 19 | 20 | pip install -r requirements.txt 21 | 22 | # for GH200, 23 | #pip uninstall -y vllm &&git clone https://github.com/vllm-project/vllm.git&& cd vllm && git checkout ed6e9075d31e32c8548b480\ 24 | #python use_existing_torch.py && pip install -r requirements/build.txt && pip install --no-build-isolation -e . 25 | 26 | pip install -e . 
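# Optional sanity check (a minimal sketch, assuming the pinned versions above installed
# cleanly): confirm the core packages import and that CUDA is visible before launching
# any SFT/GRPO run. It makes no changes to the environment.
python - <<'EOF'
import torch, transformers, trl, vllm
print("torch       :", torch.__version__, "| cuda:", torch.version.cuda, "| available:", torch.cuda.is_available())
print("transformers:", transformers.__version__)
print("trl         :", trl.__version__)
print("vllm        :", vllm.__version__)
EOF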
27 | -------------------------------------------------------------------------------- /local_scripts/fsdp.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: FSDP 4 | downcast_bf16: 'no' 5 | enable_cpu_affinity: false 6 | fsdp_config: 7 | fsdp_activation_checkpointing: true 8 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP 9 | fsdp_backward_prefetch: BACKWARD_PRE 10 | fsdp_cpu_ram_efficient_loading: true 11 | fsdp_forward_prefetch: false 12 | fsdp_offload_params: false 13 | fsdp_sharding_strategy: HYBRID_SHARD 14 | fsdp_state_dict_type: SHARDED_STATE_DICT 15 | fsdp_sync_module_states: true 16 | fsdp_transformer_layer_cls_to_wrap: null 17 | fsdp_use_orig_params: true 18 | 19 | main_training_function: main 20 | mixed_precision: bf16 21 | rdzv_backend: static 22 | same_network: true 23 | tpu_env: [] 24 | tpu_use_cluster: false 25 | tpu_use_sudo: false 26 | use_cpu: false 27 | -------------------------------------------------------------------------------- /local_scripts/fsdp_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "backward_prefetch": "backward_pre", 3 | "forward_prefetch": "true", 4 | "activation_checkpointing": "true" 5 | } 6 | -------------------------------------------------------------------------------- /local_scripts/zero2.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "zero_optimization": { 23 | "stage": 2, 24 | "offload_optimizer": { 25 | "device": "none", 26 | "pin_memory": true 27 | }, 28 | "allgather_partitions": true, 29 | "allgather_bucket_size": 2e8, 30 | "overlap_comm": false, 31 | "reduce_scatter": true, 32 | "reduce_bucket_size": 2e8, 33 | "contiguous_gradients": true 34 | }, 35 | "gradient_accumulation_steps": "auto", 36 | "gradient_clipping": "auto", 37 | "steps_per_print": 100, 38 | "train_batch_size": "auto", 39 | "train_micro_batch_size_per_gpu": "auto", 40 | "wall_clock_breakdown": false 41 | } 42 | -------------------------------------------------------------------------------- /local_scripts/zero2_offload.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "optimizer": { 14 | "type": "AdamW", 15 | "params": { 16 | "lr": "auto", 17 | "betas": "auto", 18 | "eps": "auto", 19 | "weight_decay": "auto" 20 | } 21 | }, 22 | "zero_optimization": { 23 | "stage": 2, 24 | "offload_optimizer": { 25 | "device": "cpu", 26 | "pin_memory": true 27 | }, 28 | "allgather_partitions": true, 29 | "allgather_bucket_size": 2e8, 30 | "overlap_comm": false, 31 | "reduce_scatter": true, 32 | "reduce_bucket_size": 2e8, 33 | "contiguous_gradients": true 34 | }, 35 | "gradient_accumulation_steps": "auto", 36 | "gradient_clipping": "auto", 37 | "steps_per_print": 100, 38 | "train_batch_size": "auto", 39 | "train_micro_batch_size_per_gpu": 
"auto", 40 | "wall_clock_breakdown": false 41 | } 42 | -------------------------------------------------------------------------------- /local_scripts/zero3++.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | "zero_optimization": { 14 | "stage": 3, 15 | "offload_optimizer": { 16 | "device": "cpu", 17 | "pin_memory": true 18 | }, 19 | "offload_param": { 20 | "device": "cpu", 21 | "pin_memory": true 22 | }, 23 | "zero_hpz_partition_size": 8, 24 | "zero_quantized_weights": false, 25 | "zero_quantized_gradients": false, 26 | "overlap_comm": true, 27 | "contiguous_gradients": true, 28 | "sub_group_size": 1e9, 29 | "reduce_bucket_size": "auto", 30 | "stage3_prefetch_bucket_size": "auto", 31 | "stage3_param_persistence_threshold": "auto", 32 | "stage3_max_live_parameters": 1e9, 33 | "stage3_max_reuse_distance": 1e9, 34 | "stage3_gather_16bit_weights_on_model_save": true 35 | }, 36 | "train_batch_size": "auto", 37 | "train_micro_batch_size_per_gpu": "auto", 38 | "gradient_accumulation_steps": "auto", 39 | "gradient_clipping": "auto", 40 | "zero_allow_untested_optimizer": true, 41 | } 42 | -------------------------------------------------------------------------------- /local_scripts/zero3.json: -------------------------------------------------------------------------------- 1 | { 2 | "fp16": { 3 | "enabled": "auto", 4 | "loss_scale": 0, 5 | "loss_scale_window": 1000, 6 | "initial_scale_power": 16, 7 | "hysteresis": 2, 8 | "min_loss_scale": 1 9 | }, 10 | "bf16": { 11 | "enabled": "auto" 12 | }, 13 | 14 | "zero_optimization": { 15 | "stage": 3, 16 | "offload_optimizer": { 17 | "device": "none", 18 | "pin_memory": true 19 | }, 20 | "offload_param": { 21 | "device": "none", 22 | "pin_memory": true 23 | }, 24 | "overlap_comm": true, 25 | "contiguous_gradients": true, 26 | "sub_group_size": 1e9, 27 | "reduce_bucket_size": "auto", 28 | "stage3_prefetch_bucket_size": "auto", 29 | "stage3_param_persistence_threshold": "auto", 30 | "stage3_max_live_parameters": 1e9, 31 | "stage3_max_reuse_distance": 1e9, 32 | "stage3_gather_16bit_weights_on_model_save": true 33 | }, 34 | 35 | "gradient_accumulation_steps": "auto", 36 | "gradient_clipping": "auto", 37 | "steps_per_print": 100, 38 | "train_batch_size": "auto", 39 | "train_micro_batch_size_per_gpu": "auto", 40 | "wall_clock_breakdown": false 41 | } -------------------------------------------------------------------------------- /open_r1/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gpt4vision/R1-SGG/e4de64d4c4c97edec648021d012198b21a9b1864/open_r1/__init__.py -------------------------------------------------------------------------------- /open_r1/evaluate.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Custom evaluation tasks for LightEval.""" 16 | 17 | from lighteval.metrics.dynamic_metrics import ( 18 | ExprExtractionConfig, 19 | LatexExtractionConfig, 20 | multilingual_extractive_match_metric, 21 | ) 22 | from lighteval.tasks.lighteval_task import LightevalTaskConfig 23 | from lighteval.tasks.requests import Doc 24 | from lighteval.utils.language import Language 25 | 26 | 27 | metric = multilingual_extractive_match_metric( 28 | language=Language.ENGLISH, 29 | fallback_mode="first_match", 30 | precision=5, 31 | gold_extraction_target=(LatexExtractionConfig(),), 32 | pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig()), 33 | aggregation_function=max, 34 | ) 35 | 36 | 37 | def prompt_fn(line, task_name: str = None): 38 | """Assumes the model is either prompted to emit \\boxed{answer} or does so automatically""" 39 | return Doc( 40 | task_name=task_name, 41 | query=line["problem"], 42 | choices=[line["solution"]], 43 | gold_index=0, 44 | ) 45 | 46 | 47 | # Define tasks 48 | aime24 = LightevalTaskConfig( 49 | name="aime24", 50 | suite=["custom"], 51 | prompt_function=prompt_fn, 52 | hf_repo="HuggingFaceH4/aime_2024", 53 | hf_subset="default", 54 | hf_avail_splits=["train"], 55 | evaluation_splits=["train"], 56 | few_shots_split=None, 57 | few_shots_select=None, 58 | generation_size=32768, 59 | metric=[metric], 60 | version=1, 61 | ) 62 | math_500 = LightevalTaskConfig( 63 | name="math_500", 64 | suite=["custom"], 65 | prompt_function=prompt_fn, 66 | hf_repo="HuggingFaceH4/MATH-500", 67 | hf_subset="default", 68 | hf_avail_splits=["test"], 69 | evaluation_splits=["test"], 70 | few_shots_split=None, 71 | few_shots_select=None, 72 | generation_size=32768, 73 | metric=[metric], 74 | version=1, 75 | ) 76 | 77 | # Add tasks to the table 78 | TASKS_TABLE = [] 79 | TASKS_TABLE.append(aime24) 80 | TASKS_TABLE.append(math_500) 81 | 82 | # MODULE LOGIC 83 | if __name__ == "__main__": 84 | print([t["name"] for t in TASKS_TABLE]) 85 | print(len(TASKS_TABLE)) 86 | -------------------------------------------------------------------------------- /open_r1/generate.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
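# Distilabel pipeline for batched response generation: prompts from a Hugging Face
# dataset are sent to an OpenAI-compatible vLLM server (see --vllm-server-url below),
# and the resulting completions can optionally be pushed back to the Hub.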
14 | 15 | from typing import Optional 16 | 17 | from distilabel.llms import OpenAILLM 18 | from distilabel.pipeline import Pipeline 19 | from distilabel.steps.tasks import TextGeneration 20 | 21 | 22 | def build_distilabel_pipeline( 23 | model: str, 24 | base_url: str = "http://localhost:8000/v1", 25 | prompt_column: Optional[str] = None, 26 | temperature: Optional[float] = None, 27 | top_p: Optional[float] = None, 28 | max_new_tokens: int = 8192, 29 | num_generations: int = 1, 30 | ) -> Pipeline: 31 | generation_kwargs = {"max_new_tokens": max_new_tokens} 32 | 33 | if temperature is not None: 34 | generation_kwargs["temperature"] = temperature 35 | 36 | if top_p is not None: 37 | generation_kwargs["top_p"] = top_p 38 | 39 | with Pipeline().ray() as pipeline: 40 | TextGeneration( 41 | llm=OpenAILLM( 42 | base_url=base_url, 43 | api_key="something", 44 | model=model, 45 | # thinking can take some time... 46 | timeout=10 * 60, 47 | generation_kwargs=generation_kwargs, 48 | ), 49 | input_mappings={"instruction": prompt_column} if prompt_column is not None else {}, 50 | input_batch_size=64, # on 4 nodes bs ~60+ leads to preemption due to KV cache exhaustion 51 | num_generations=num_generations, 52 | ) 53 | 54 | return pipeline 55 | 56 | 57 | if __name__ == "__main__": 58 | import argparse 59 | 60 | from datasets import load_dataset 61 | 62 | parser = argparse.ArgumentParser(description="Run distilabel pipeline for generating responses with DeepSeek R1") 63 | parser.add_argument( 64 | "--hf-dataset", 65 | type=str, 66 | required=True, 67 | help="HuggingFace dataset to load", 68 | ) 69 | parser.add_argument( 70 | "--hf-dataset-config", 71 | type=str, 72 | required=False, 73 | help="Dataset config to use", 74 | ) 75 | parser.add_argument( 76 | "--hf-dataset-split", 77 | type=str, 78 | default="train", 79 | help="Dataset split to use", 80 | ) 81 | parser.add_argument("--prompt-column", type=str, default="prompt") 82 | parser.add_argument( 83 | "--model", 84 | type=str, 85 | required=True, 86 | help="Model name to use for generation", 87 | ) 88 | parser.add_argument( 89 | "--vllm-server-url", 90 | type=str, 91 | default="http://localhost:8000/v1", 92 | help="URL of the vLLM server", 93 | ) 94 | parser.add_argument( 95 | "--temperature", 96 | type=float, 97 | help="Temperature for generation", 98 | ) 99 | parser.add_argument( 100 | "--top-p", 101 | type=float, 102 | help="Top-p value for generation", 103 | ) 104 | parser.add_argument( 105 | "--max-new-tokens", 106 | type=int, 107 | default=8192, 108 | help="Maximum number of new tokens to generate", 109 | ) 110 | parser.add_argument( 111 | "--num-generations", 112 | type=int, 113 | default=1, 114 | help="Number of generations per problem", 115 | ) 116 | parser.add_argument( 117 | "--hf-output-dataset", 118 | type=str, 119 | required=False, 120 | help="HuggingFace repo to push results to", 121 | ) 122 | parser.add_argument( 123 | "--private", 124 | action="store_true", 125 | help="Whether to make the output dataset private when pushing to HF Hub", 126 | ) 127 | 128 | args = parser.parse_args() 129 | 130 | print("\nRunning with arguments:") 131 | for arg, value in vars(args).items(): 132 | print(f" {arg}: {value}") 133 | print() 134 | 135 | print(f"Loading '{args.hf_dataset}' (config: {args.hf_dataset_config}, split: {args.hf_dataset_split}) dataset...") 136 | dataset = load_dataset(args.hf_dataset, split=args.hf_dataset_split) 137 | print("Dataset loaded!") 138 | 139 | pipeline = build_distilabel_pipeline( 140 | model=args.model, 141 | 
base_url=args.vllm_server_url, 142 | prompt_column=args.prompt_column, 143 | temperature=args.temperature, 144 | top_p=args.top_p, 145 | max_new_tokens=args.max_new_tokens, 146 | num_generations=args.num_generations, 147 | ) 148 | 149 | print("Running generation pipeline...") 150 | distiset = pipeline.run(dataset=dataset, use_cache=False) 151 | print("Generation pipeline finished!") 152 | 153 | if args.hf_output_dataset: 154 | print(f"Pushing resulting dataset to '{args.hf_output_dataset}'...") 155 | distiset.push_to_hub(args.hf_output_dataset, private=args.private) 156 | print("Dataset pushed!") 157 | -------------------------------------------------------------------------------- /open_r1/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .grpo_config import GRPOConfig 2 | from .grpo_trainer import GRPOTrainerV2 3 | 4 | 5 | __all__ = ["GRPOTrainerV2", "GRPOConfig"] 6 | -------------------------------------------------------------------------------- /open_r1/trainer/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gpt4vision/R1-SGG/e4de64d4c4c97edec648021d012198b21a9b1864/open_r1/trainer/utils/__init__.py -------------------------------------------------------------------------------- /open_r1/trainer/utils/misc.py: -------------------------------------------------------------------------------- 1 | import base64 2 | from io import BytesIO 3 | from PIL import Image 4 | 5 | from transformers.utils.import_utils import _is_package_available 6 | 7 | 8 | _fastapi_available = _is_package_available("fastapi") 9 | _pydantic_available = _is_package_available("pydantic") 10 | _uvicorn_available = _is_package_available("uvicorn") 11 | _vllm_available = _is_package_available("vllm") 12 | _requests_available = _is_package_available("requests") 13 | 14 | def is_fastapi_available() -> bool: 15 | return _fastapi_available 16 | 17 | 18 | def is_pydantic_available() -> bool: 19 | return _pydantic_available 20 | 21 | def is_uvicorn_available() -> bool: 22 | return _uvicorn_available 23 | 24 | 25 | def is_vllm_available() -> bool: 26 | return _vllm_available 27 | 28 | def is_requests_available() -> bool: 29 | return _requests_available 30 | 31 | 32 | def is_pil_image(image) -> bool: 33 | return isinstance(image, Image.Image) 34 | 35 | 36 | def encode_image_to_base64(image: Image.Image, format: str = "PNG") -> str: 37 | """ 38 | Encode a PIL Image to a base64 string. 39 | 40 | Args: 41 | image (PIL.Image): The image to encode. 42 | format (str): Image format to use (e.g., "PNG", "JPEG"). Default is "PNG". 43 | 44 | Returns: 45 | str: Base64-encoded string of the image. 46 | """ 47 | buffer = BytesIO() 48 | image.save(buffer, format=format) 49 | buffer.seek(0) 50 | encoded_string = base64.b64encode(buffer.read()).decode("utf-8") 51 | return encoded_string 52 | 53 | def decode_base64_to_image(base64_str: str) -> Image.Image: 54 | """ 55 | Decode a base64 string back to a PIL Image. 56 | 57 | Args: 58 | base64_str (str): Base64-encoded string of the image. 59 | 60 | Returns: 61 | PIL.Image: Decoded image. 
62 | """ 63 | image_data = base64.b64decode(base64_str) 64 | buffer = BytesIO(image_data) 65 | image = Image.open(buffer) 66 | return image 67 | -------------------------------------------------------------------------------- /open_r1/trainer/utils/prompt_gallery.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | VG150_OBJ_CATEGORIES = ['__background__', 'airplane', 'animal', 'arm', 'bag', 'banana', 'basket', 'beach', 'bear', 'bed', 'bench', 'bike', 'bird', 'board', 'boat', 'book', 'boot', 'bottle', 'bowl', 'box', 'boy', 'branch', 'building', 'bus', 'cabinet', 'cap', 'car', 'cat', 'chair', 'child', 'clock', 'coat', 'counter', 'cow', 'cup', 'curtain', 'desk', 'dog', 'door', 'drawer', 'ear', 'elephant', 'engine', 'eye', 'face', 'fence', 'finger', 'flag', 'flower', 'food', 'fork', 'fruit', 'giraffe', 'girl', 'glass', 'glove', 'guy', 'hair', 'hand', 'handle', 'hat', 'head', 'helmet', 'hill', 'horse', 'house', 'jacket', 'jean', 'kid', 'kite', 'lady', 'lamp', 'laptop', 'leaf', 'leg', 'letter', 'light', 'logo', 'man', 'men', 'motorcycle', 'mountain', 'mouth', 'neck', 'nose', 'number', 'orange', 'pant', 'paper', 'paw', 'people', 'person', 'phone', 'pillow', 'pizza', 'plane', 'plant', 'plate', 'player', 'pole', 'post', 'pot', 'racket', 'railing', 'rock', 'roof', 'room', 'screen', 'seat', 'sheep', 'shelf', 'shirt', 'shoe', 'short', 'sidewalk', 'sign', 'sink', 'skateboard', 'ski', 'skier', 'sneaker', 'snow', 'sock', 'stand', 'street', 'surfboard', 'table', 'tail', 'tie', 'tile', 'tire', 'toilet', 'towel', 'tower', 'track', 'train', 'tree', 'truck', 'trunk', 'umbrella', 'vase', 'vegetable', 'vehicle', 'wave', 'wheel', 'window', 'windshield', 'wing', 'wire', 'woman', 'zebra'] 4 | 5 | 6 | VG150_PREDICATES = ['__background__', "above", "across", "against", "along", "and", "at", "attached to", "behind", "belonging to", "between", "carrying", "covered in", "covering", "eating", "flying in", "for", "from", "growing on", "hanging from", "has", "holding", "in", "in front of", "laying on", "looking at", "lying on", "made of", "mounted on", "near", "of", "on", "on back of", "over", "painted on", "parked on", "part of", "playing", "riding", "says", "sitting on", "standing on", "to", "under", "using", "walking in", "walking on", "watching", "wearing", "wears", "with"] 7 | 8 | 9 | VG150_BASE_OBJ_CATEGORIES = set(['tile', 'drawer', 'men', 'railing', 'stand', 'towel', 'sneaker', 'vegetable', 'screen', 'vehicle', 'animal', 'kite', 'cabinet', 'sink', 'wire', 'fruit', 'curtain', 'lamp', 'flag', 'pot', 'sock', 'boot', 'guy', 'kid', 'finger', 'basket', 'wave', 'lady', 'orange', 'number', 'toilet', 'post', 'room', 'paper', 'mountain', 'paw', 'banana', 'rock', 'cup', 'hill', 'house', 'airplane', 'plant', 'skier', 'fork', 'box', 'seat', 'engine', 'mouth', 'letter', 'windshield', 'desk', 'board', 'counter', 'branch', 'coat', 'logo', 'book', 'roof', 'tie', 'tower', 'glove', 'sheep', 'neck', 'shelf', 'bottle', 'cap', 'vase', 'racket', 'ski', 'phone', 'handle', 'boat', 'tire', 'flower', 'child', 'bowl', 'pillow', 'player', 'trunk', 'bag', 'wing', 'light', 'laptop', 'pizza', 'cow', 'truck', 'jean', 'eye', 'arm', 'leaf', 'bird', 'surfboard', 'umbrella', 'food', 'people', 'nose', 'beach', 'sidewalk', 'helmet', 'face', 'skateboard', 'motorcycle', 'clock', 'bear']) 10 | 11 | VG150_BASE_PREDICATE = set(["between", "to", "made of", "looking at", "along", "laying on", "using", "carrying", "against", "mounted on", "sitting on", "flying in", "covering", "from", "over", "near", 
"hanging from", "across", "at", "above", "watching", "covered in", "wearing", "holding", "and", "standing on", "lying on", "growing on", "under", "on back of", "with", "has", "in front of", "behind", "parked on"]) 12 | 13 | 14 | PROMPT_SG='Generate a structured scene graph for an image using the following format:\n\n```json\n{\n "objects": [\n {"id": "object_name.number", "bbox": [x1, y1, x2, y2]},\n ...\n ],\n "relationships": [\n {"subject": "object_name.number", "predicate": "relationship_type", "object": "object_name.number"},\n ...\n ]\n}\n```\n\n### **Guidelines:**\n- **Objects:**\n - Assign a unique ID for each object using the format `"object_name.number"` (e.g., `"person.1"`, `"bike.2"`).\n - Provide its bounding box `[x1, y1, x2, y2]` in integer pixel format.\n - Include all visible objects, even if they have no relationships.\n\n- **Relationships:**\n - Represent interactions accurately using `"subject"`, `"predicate"`, and `"object"`.\n - Omit relationships for orphan objects.\n\n### **Example Output:**\n```json\n{\n "objects": [\n {"id": "person.1", "bbox": [120, 200, 350, 700]},\n {"id": "bike.2", "bbox": [100, 600, 400, 800]},\n {"id": "helmet.3", "bbox": [150, 150, 280, 240]},\n {"id": "tree.4", "bbox": [500, 100, 750, 700]}\n ],\n "relationships": [\n {"subject": "person.1", "predicate": "riding", "object": "bike.2"},\n {"subject": "person.1", "predicate": "wearing", "object": "helmet.3"}\n ]\n}\n```\n\nNow, generate the complete scene graph for the provided image:\n' 15 | 16 | PROMPT_CLOSE_TEMPLATE='Generate a structured scene graph for an image using the specified object and relationship categories.\n\n### **Output Format:**\n```json\n{\n "objects": [\n {"id": "object_name.number", "bbox": [x1, y1, x2, y2]},\n ...\n ],\n "relationships": [\n {"subject": "object_name.number", "predicate": "relationship_type", "object": "object_name.number"},\n ...\n ]\n}\n```\n\n### **Guidelines:**\n- **Objects:**\n - Assign unique IDs in the format `"object_name.number"` (e.g., `"person.1"`). 
The **object_name** must belong to the predefined object set: `{OBJ_CLS}`.\n - Provide a bounding box `[x1, y1, x2, y2]` in integer pixel format.\n - Include all visible objects, even if they have no relationships.\n\n- **Relationships:**\n - Define relationships using `"subject"`, `"predicate"`, and `"object"`.\n - The **predicate** must belong to the predefined relationship set: `{REL_CLS}`.\n - Omit relationships for orphan objects.\n\n### **Example Output:**\n```json\n{\n "objects": [\n {"id": "person.1", "bbox": [120, 200, 350, 700]},\n {"id": "bike.2", "bbox": [100, 600, 400, 800]},\n {"id": "helmet.3", "bbox": [150, 150, 280, 240]},\n {"id": "tree.4", "bbox": [500, 100, 750, 700]}\n ],\n "relationships": [\n {"subject": "person.1", "predicate": "riding", "object": "bike.2"},\n {"subject": "person.1", "predicate": "wearing", "object": "helmet.3"}\n ]\n}\n```\n\nNow, generate the complete scene graph for the provided image:\n' 17 | 18 | 19 | 20 | psg_categories = {"thing_classes": ["person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", "fire hydrant", "stop sign", "parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"], "stuff_classes": ["banner", "blanket", "bridge", "cardboard", "counter", "curtain", "door-stuff", "floor-wood", "flower", "fruit", "gravel", "house", "light", "mirror-stuff", "net", "pillow", "platform", "playingfield", "railroad", "river", "road", "roof", "sand", "sea", "shelf", "snow", "stairs", "tent", "towel", "wall-brick", "wall-stone", "wall-tile", "wall-wood", "water-other", "window-blind", "window-other", "tree-merged", "fence-merged", "ceiling-merged", "sky-other-merged", "cabinet-merged", "table-merged", "floor-other-merged", "pavement-merged", "mountain-merged", "grass-merged", "dirt-merged", "paper-merged", "food-other-merged", "building-other-merged", "rock-merged", "wall-other-merged", "rug-merged"], "predicate_classes": ["over", "in front of", "beside", "on", "in", "attached to", "hanging from", "on back of", "falling off", "going down", "painted on", "walking on", "running on", "crossing", "standing on", "lying on", "sitting on", "flying over", "jumping over", "jumping from", "wearing", "holding", "carrying", "looking at", "guiding", "kissing", "eating", "drinking", "feeding", "biting", "catching", "picking", "playing with", "chasing", "climbing", "cleaning", "playing", "touching", "pushing", "pulling", "opening", "cooking", "talking to", "throwing", "slicing", "driving", "riding", "parked on", "driving on", "about to hit", "kicking", "swinging", "entering", "exiting", "enclosing", "leaning on"]} 21 | 22 | PSG_OBJ_CATEGORIES = psg_categories['thing_classes'] + psg_categories['stuff_classes'] 23 | PSG_REL_CATEGORIES = psg_categories['predicate_classes'] 24 | 25 | 26 | def format_prompt_close_sg(obj_cls, rel_cls): 27 | return 
PROMPT_CLOSE_TEMPLATE.replace("{OBJ_CLS}", json.dumps(obj_cls)).replace("{REL_CLS}", json.dumps(rel_cls)) 28 | 29 | 30 | PROMPT_CLOSE_PSG = format_prompt_close_sg(PSG_OBJ_CATEGORIES, PSG_REL_CATEGORIES) 31 | PROMPT_CLOSE_VG150 = format_prompt_close_sg(VG150_OBJ_CATEGORIES[1:], VG150_PREDICATES[1:]) 32 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | aiohttp 2 | accelerate==1.4.0 3 | datasets==3.3.2 4 | deepspeed==0.15.4 5 | numpy==1.26.4 6 | pillow==10.4.0 7 | qwen_vl_utils 8 | tqdm 9 | typing_extensions>=4.12.2 10 | pydantic>=2.10.6 11 | pydantic_core>=2.27.2 12 | thinc>=8.2.2 13 | packaging==24.2 14 | PyYAML==6.0.2 15 | spacy 16 | scipy 17 | safetensors==0.5.3 18 | fastapi 19 | openai==1.65.5 20 | vllm==0.7.3 21 | wandb==0.18.3 22 | pycocotools 23 | tabulate 24 | nltk 25 | matplotlib 26 | -------------------------------------------------------------------------------- /scripts/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gpt4vision/R1-SGG/e4de64d4c4c97edec648021d012198b21a9b1864/scripts/.DS_Store -------------------------------------------------------------------------------- /scripts/debug/debug.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | 5 | export DATA_PATH="JosephZ/vg150_train_sgg_prompt" 6 | 7 | export CUDA_VISIBLE_DEVICES=0 8 | 9 | accelerate launch --num_processes=1 open_r1/grpo.py \ 10 | --output_dir models/qwen2vl-sgg-g8 \ 11 | --model_name_or_path "Qwen/Qwen2-VL-2B-Instruct" \ 12 | --dataset_name $DATA_PATH \ 13 | --max_prompt_length 2048 \ 14 | --max_completion_length 1024 \ 15 | --per_device_train_batch_size 2 \ 16 | --gradient_accumulation_steps 1 \ 17 | --logging_steps 1 \ 18 | --use_vllm true \ 19 | --use_local_vllm true\ 20 | --use_liger_loss true\ 21 | --vllm_gpu_memory_utilization 0.25\ 22 | --bf16 \ 23 | --report_to wandb \ 24 | --gradient_checkpointing true \ 25 | --max_pixels 401408 \ 26 | --temperature 0.7 \ 27 | --top_p 0.01 \ 28 | --top_k 1 \ 29 | --num_train_epochs 2 \ 30 | --run_name Qwen2-VL-2B-GRPO-SGG-debug \ 31 | --save_steps 100 \ 32 | --num_generations 2 33 | -------------------------------------------------------------------------------- /scripts/debug/debug_gh200.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export HF_HOME=$SCRATCH/huggingface 4 | # ---------- Environment Setup ---------- 5 | export NCCL_ASYNC_ERROR_HANDLING=1 6 | export DEBUG_MODE=True 7 | export WANDB_PROJECT=RL4SGG 8 | 9 | 10 | GROUP_SIZE=8 11 | MODEL_PATH="Qwen/Qwen2-VL-7B-Instruct" 12 | DATA_PATH="JosephZ/vg150_train_sgg_prompt" 13 | RUN_NAME="qwen2vl-7b-grpo-g${GROUP_SIZE}-n1-gh200" 14 | export OUTPUT_DIR="${SCRATCH}/models/${RUN_NAME}" 15 | mkdir -p "$OUTPUT_DIR" 16 | 17 | TP_SIZE=1 18 | PORT_BASE=8000 19 | MAX_PIXELS=$((512 * 28 * 28)) 20 | 21 | 22 | MIXED_NODES=1 # Set this dynamically if needed 23 | 24 | 25 | HEAD_NODE_IP=0.0.0.0 26 | MASTER_PORT=29500 27 | 28 | SERVER_IP=$(hostname -I | awk '{print $1}') 29 | SERVER_PORT='8000' 30 | 31 | 32 | # zero2: 33 | # bsz_per_devie=16, OOM; Ok, with CPU offload for optimizer, ~60h with 3x GPUs 34 | # bsz_per_devie=8, 386s for 30 steps, ~60h with 3x GPUs 35 | TRAIN_CMD="open_r1/grpo.py \ 36 | --output_dir ${OUTPUT_DIR} \ 37 | --model_name_or_path ${MODEL_PATH} \ 38 | --dataset_name 
${DATA_PATH} \ 39 | --max_prompt_length 2048 \ 40 | --max_completion_length 1024 \ 41 | --per_device_train_batch_size 16 \ 42 | --deepspeed ./local_scripts/zero2.json \ 43 | --gradient_accumulation_steps 1 \ 44 | --logging_steps 1 \ 45 | --use_vllm true \ 46 | --vllm_server_host ${SERVER_IP} \ 47 | --vllm_server_port ${SERVER_PORT} \ 48 | --vllm_server_timeout 600 \ 49 | --vllm_locate_same_node true\ 50 | --vllm_locate_same_remain_gpus 3\ 51 | --bf16 true\ 52 | --tf32 true\ 53 | --report_to wandb \ 54 | --gradient_checkpointing true \ 55 | --max_pixels ${MAX_PIXELS} \ 56 | --temperature 0.3 \ 57 | --top_p 0.001 \ 58 | --top_k 1 \ 59 | --num_train_epochs 1 \ 60 | --run_name ${RUN_NAME} \ 61 | --save_steps 100 \ 62 | --num_generations ${GROUP_SIZE} \ 63 | --num_iterations 1 \ 64 | --beta 0.0" 65 | 66 | 67 | log_file="vllm_node_0.log" 68 | 69 | # vLLM: GPUs 3 70 | CUDA_VISIBLE_DEVICES=3 python src/vllm_server_v2.py \ 71 | --model ${MODEL_PATH} \ 72 | --gpu_memory_utilization 0.9 \ 73 | --enable-prefix-caching true \ 74 | --dtype 'bfloat16' \ 75 | --max_model_len 4096 \ 76 | --tensor_parallel_size ${TP_SIZE} \ 77 | --host '0.0.0.0' \ 78 | --port ${PORT_BASE} > ${log_file} 2>&1 & 79 | 80 | echo "waiting for vLLM servers..." 81 | #sleep 60 82 | echo "start training..." 83 | # Training: GPUs 0-3 84 | CUDA_VISIBLE_DEVICES=0,1,2 torchrun --nnodes=1 --nproc_per_node=3 \ 85 | --node_rank=0 \ 86 | --master_addr=${HEAD_NODE_IP} \ 87 | --master_port=${MASTER_PORT} \ 88 | ${TRAIN_CMD} > debug-gh200.log 2>&1 & 89 | 90 | 91 | -------------------------------------------------------------------------------- /scripts/debug/debug_gh200_local_vllm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export HF_HOME=$SCRATCH/huggingface 4 | # ---------- Environment Setup ---------- 5 | export NCCL_ASYNC_ERROR_HANDLING=1 6 | export DEBUG_MODE=True 7 | export WANDB_PROJECT=RL4SGG 8 | 9 | 10 | GROUP_SIZE=8 11 | MODEL_PATH="Qwen/Qwen2-VL-7B-Instruct" 12 | DATA_PATH="JosephZ/vg150_train_sgg_prompt" 13 | RUN_NAME="qwen2vl-7b-grpo-g${GROUP_SIZE}-n1-gh200" 14 | export OUTPUT_DIR="${SCRATCH}/models/${RUN_NAME}" 15 | mkdir -p "$OUTPUT_DIR" 16 | 17 | MAX_PIXELS=$((512 * 28 * 28)) 18 | 19 | 20 | 21 | HEAD_NODE_IP=0.0.0.0 22 | MASTER_PORT=29500 23 | 24 | 25 | 26 | # GH200 has a very high bandwidth between CPU and GPU, we should use it! 27 | # zero2: 28 | # bsz_per_devie=16, OOM; Ok, with CPU offload for optimizer, ~60h with 3x GPUs 29 | # bsz_per_devie=8, 386s for 30 steps, ~60h with 3x GPUs 30 | # bsz_per_devie=16, ~40h with 4x GPUs 31 | TRAIN_CMD="open_r1/grpo.py \ 32 | --output_dir ${OUTPUT_DIR} \ 33 | --model_name_or_path ${MODEL_PATH} \ 34 | --dataset_name ${DATA_PATH} \ 35 | --max_prompt_length 2048 \ 36 | --max_completion_length 1024 \ 37 | --per_device_train_batch_size 16 \ 38 | --deepspeed ./local_scripts/zero2.json \ 39 | --gradient_accumulation_steps 1 \ 40 | --logging_steps 1 \ 41 | --use_vllm true \ 42 | --use_local_vllm true\ 43 | --bf16 true\ 44 | --tf32 true\ 45 | --report_to wandb \ 46 | --gradient_checkpointing true \ 47 | --max_pixels ${MAX_PIXELS} \ 48 | --temperature 0.3 \ 49 | --top_p 0.001 \ 50 | --top_k 1 \ 51 | --num_train_epochs 1 \ 52 | --run_name ${RUN_NAME} \ 53 | --save_steps 100 \ 54 | --num_generations ${GROUP_SIZE} \ 55 | --num_iterations 1 \ 56 | --beta 0.0\ 57 | --use_liger_loss false\ 58 | --vllm_max_model_len 4096 \ 59 | --vllm_gpu_memory_utilization 0.25" 60 | 61 | 62 | echo "start training..." 
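# Batch-size arithmetic behind the "16*4//8=8" note below (assuming the flags above are
# left unchanged): 16 completions per GPU x 4 training GPUs = 64 sampled completions per
# step; with num_generations (GROUP_SIZE) = 8, that corresponds to 64 / 8 = 8 unique
# prompts per optimizer step.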
63 | # Training: GPUs 0-3, batch size: 16*4//8=8 64 | CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nnodes=1 --nproc_per_node=4 \ 65 | --node_rank=0 \ 66 | --master_addr=${HEAD_NODE_IP} \ 67 | --master_port=${MASTER_PORT} \ 68 | ${TRAIN_CMD} > debug-gh200.log 2>&1 & 69 | -------------------------------------------------------------------------------- /scripts/debug/test_vllm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | export HF_HOME=$SCRATCH/huggingface 4 | # ---------- Environment Setup ---------- 5 | export NCCL_ASYNC_ERROR_HANDLING=1 6 | export DEBUG_MODE=True 7 | export WANDB_PROJECT=RL4SGG 8 | 9 | 10 | MODEL_PATH="Qwen/Qwen2-VL-7B-Instruct" 11 | DATA_PATH="JosephZ/vg150_train_sgg_prompt" 12 | 13 | TP_SIZE=1 14 | PORT_BASE=8000 15 | 16 | MAX_PIXELS=$((512 * 28 * 28)) 17 | 18 | 19 | 20 | 21 | HEAD_NODE_IP=0.0.0.0 22 | MASTER_PORT=29500 23 | 24 | 25 | 26 | 27 | server_ip=$(hostname -I | awk '{print $1}') 28 | 29 | 30 | # Launch vLLM servers 31 | for i in {0..1}; do 32 | log_file="vllm_server_${i}.log" 33 | port=$((PORT_BASE + i)) 34 | CUDA_VISIBLE_DEVICES=${i} python src/vllm_server_v2.py \ 35 | --model "${MODEL_PATH}" \ 36 | --gpu_memory_utilization 0.9 \ 37 | --enable_prefix_caching true \ 38 | --dtype 'bfloat16' \ 39 | --max_model_len 4096 \ 40 | --tensor_parallel_size "${TP_SIZE}" \ 41 | --host '0.0.0.0' \ 42 | --port "${port}" > "${log_file}" 2>&1 & 43 | done 44 | 45 | echo "Waiting for vLLM servers to initialize..." 46 | #sleep 60 47 | 48 | # Run tests 49 | for i in {2..3}; do 50 | log_file="vllm_client_${i}.log" 51 | port=$((PORT_BASE + i - 2)) 52 | group_port=$(( 51200 + i)) 53 | CUDA_VISIBLE_DEVICES=${i} python tests/test_vllm.py \ 54 | --hosts ${server_ip} \ 55 | --server_port "${port}" \ 56 | --group_port ${group_port}\ 57 | --model_name_or_path "${MODEL_PATH}" > "${log_file}" 2>&1 & 58 | done 59 | 60 | 61 | -------------------------------------------------------------------------------- /scripts/grpo/run_vllm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH --job-name=VLLM 4 | #SBATCH --time=24:00:00 5 | 6 | #SBATCH --nodes=8 7 | #SBATCH --ntasks=16 8 | #SBATCH --ntasks-per-node=2 9 | #SBATCH --gpus-per-node=rtx_4090:8 10 | #SBATCH --cpus-per-task=4 11 | #SBATCH --mem-per-cpu=25000M 12 | #SBATCH --output=VLLM_%j_%N.out 13 | 14 | # ------------------ Config ------------------ 15 | MODEL_PATH="Qwen/Qwen2-VL-7B-Instruct" 16 | TP_SIZE=4 # Tensor parallelism (4 GPUs per process) 17 | PORT_BASE=8000 # Base port to offset by local rank 18 | 19 | # ------------------ Environment ------------------ 20 | nodes=($(scontrol show hostnames "$SLURM_JOB_NODELIST")) 21 | DP_WORLD_SIZE=$(( ${#nodes[@]} * 2 )) # 2 processes per node 22 | 23 | head_node=${nodes[0]} 24 | head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) 25 | MASTER_PORT=$(shuf -i 20000-40000 -n 1) 26 | MASTER_IP=${head_node_ip} 27 | 28 | echo "Head node IP: ${MASTER_IP}, Port: ${MASTER_PORT}" 29 | echo "DP_WORLD_SIZE : ${DP_WORLD_SIZE}" 30 | echo "Node list: ${nodes[@]}" 31 | 32 | # ------------------ Export IPs and Ports ------------------ 33 | IP_FILE=ip_port_list.txt 34 | 35 | > ${IP_FILE} # Reset output list 36 | 37 | RANK=0 38 | for node in "${nodes[@]}"; do 39 | node_ip=$(srun --nodes=1 --ntasks=1 -w "$node" hostname --ip-address) 40 | 41 | for local_rank in 0 1; do 42 | PORT=$((PORT_BASE + local_rank)) 43 | echo "${node_ip}:${PORT}" >> ${IP_FILE} 44 | done 45 | done 
45 | done
46 | 47 | # ------------------ Launch per-process ------------------ 48 | RANK=0 49 | for node in "${nodes[@]}"; do 50 | echo "Launching 2 ranks on node $node" 51 | 52 | srun --nodes=1 --ntasks=1 --ntasks-per-node=1 -w "$node" \ 53 | bash -c " 54 | for local_rank in 0 1; do 55 | ( 56 | export RANK=\$(( ${RANK} + local_rank )) 57 | export DP_WORLD_SIZE=${DP_WORLD_SIZE} 58 | export TP_SIZE=${TP_SIZE} 59 | export MASTER_ADDR=${MASTER_IP} 60 | export MASTER_PORT=${MASTER_PORT} 61 | 62 | export CUDA_VISIBLE_DEVICES=\$(seq -s, \$(( local_rank * 4 )) \$(( local_rank * 4 + 3 ))) 63 | PORT=\$(( ${PORT_BASE} + local_rank )) 64 | 65 | echo \"Starting rank \$RANK on $node with CUDA_VISIBLE_DEVICES=\$CUDA_VISIBLE_DEVICES, port \$PORT\" 66 | 67 | python src/vllm_server_v2.py \ 68 | --model '${MODEL_PATH}' \ 69 | --gpu_memory_utilization 0.85 \ 70 | --dtype 'bfloat16' \ 71 | --max_model_len 4096 \ 72 | --tensor_parallel_size ${TP_SIZE} \ 73 | --host '0.0.0.0' \ 74 | --port \$PORT 75 | ) & 76 | done 77 | wait 78 | " & 79 | 80 | RANK=$((RANK + 2)) 81 | done 82 | 83 | wait 84 | -------------------------------------------------------------------------------- /scripts/grpo/train_a100.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | #SBATCH --job-name=7B_A100_det_cls 5 | #SBATCH --time=24:00:00 6 | 7 | #SBATCH --nodes=4 # each has 4x A100 8 | #SBATCH --ntasks-per-node=1 9 | #SBATCH --gpus-per-node=4 10 | #SBATCH --cpus-per-task=128 11 | 12 | #SBATCH --partition=normal 13 | #SBATCH --output=RL_A100_%j_%N.out 14 | #SBATCH --mail-user="zychen.uestc@gmail.com" --mail-type=ALL 15 | 16 | 17 | set -x 18 | # ---------- Environment Setup ---------- 19 | export NCCL_ASYNC_ERROR_HANDLING=1 20 | export DEBUG_MODE=True 21 | export WANDB_PROJECT=RL4SGG 22 | 23 | 24 | GPUS_PER_NODE=4 25 | GROUP_SIZE=8 26 | MODEL_PATH="Qwen/Qwen2-VL-7B-Instruct" 27 | DATA_PATH="JosephZ/vg150_train_sgg_prompt" 28 | RUN_NAME="qwen2vl-7b-grpo-det-cls-g8-n1-bs32-A100-SXM4" 29 | 30 | export OUTPUT_DIR="${SCRATCH}/models/${RUN_NAME}" 31 | export LOG_PATH=${OUTPUT_DIR}/debug.log 32 | 33 | mkdir -p "$OUTPUT_DIR" 34 | 35 | MAX_PIXELS=$((512 * 28 * 28)) 36 | 37 | 38 | MASTER_PORT=29500 39 | 40 | NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST)) 41 | NUM_TRAIN_NODES=${#NODELIST[@]} 42 | TRAIN_NODES_LIST=("${NODELIST[@]:0:$NUM_TRAIN_NODES}") 43 | 44 | # Choose the first training node as the rendezvous head node 45 | HEAD_NODE=${TRAIN_NODES_LIST[0]} 46 | 47 | #MASTER_ADDR=$(srun --nodes=1 --ntasks=1 -w "$HEAD_NODE" hostname --ip-address) 48 | 49 | MASTER_ADDR=$(echo "${SLURM_NODELIST}" | sed 's/[],].*//g; s/\[//g') 50 | echo "MASTER_ADDR: $MASTER_ADDR" 51 | 52 | 53 | 54 | # batch size: PER_GPU(2)*GPUS(4)*NODES(8)*ACC(4) //8=32 55 | # local vLLM: 80G*0.25=20G 56 | # 57 | TRAIN_CMD="open_r1/grpo.py \ 58 | --task_type det cls \ 59 | --output_dir ${OUTPUT_DIR} \ 60 | --model_name_or_path ${MODEL_PATH} \ 61 | --dataset_name ${DATA_PATH} \ 62 | --max_prompt_length 2048 \ 63 | --max_completion_length 1024 \ 64 | --custom_per_device_train_batch_size 4 \ 65 | --deepspeed ./local_scripts/zero2.json \ 66 | --gradient_accumulation_steps 4 \ 67 | --learning_rate 3e-7 \ 68 | --logging_steps 1 \ 69 | --use_vllm true \ 70 | --use_local_vllm true\ 71 | --bf16 true\ 72 | --tf32 true\ 73 | --report_to wandb \ 74 | --gradient_checkpointing true \ 75 | --max_pixels ${MAX_PIXELS} \ 76 | --temperature 1.0 \ 77 | --top_p 0.9 \ 78 | --top_k 50 \ 79 | --num_train_epochs 1 \ 80 | --run_name ${RUN_NAME} \ 81 | 
--save_steps 100 \ 82 | --num_generations ${GROUP_SIZE} \ 83 | --num_iterations 1 \ 84 | --beta 0.0\ 85 | --vllm_max_model_len 4096 \ 86 | --vllm_gpu_memory_utilization 0.25\ 87 | --save_only_model true\ 88 | --seed 42" 89 | 90 | 91 | echo "start training..." 92 | 93 | srun --nodes=${NUM_TRAIN_NODES} --nodelist="${TRAIN_NODES_LIST}" \ 94 | torchrun --nnodes ${NUM_TRAIN_NODES} --nproc_per_node ${GPUS_PER_NODE} \ 95 | --node_rank ${SLURM_NODEID} \ 96 | --rdzv_id $RANDOM \ 97 | --rdzv_backend c10d \ 98 | --rdzv_endpoint ${MASTER_ADDR}:${MASTER_PORT} \ 99 | ${TRAIN_CMD} 100 | -------------------------------------------------------------------------------- /scripts/grpo/train_a100_2B.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | #SBATCH --job-name=A100_2B_1k_lr6e-7_psg_debug 5 | #SBATCH --time=00:30:00 6 | 7 | #SBATCH --exclude=nid002289,nid002325 8 | #SBATCH --nodes=2 # 4 nodes, each has 4x A100 9 | #SBATCH --ntasks-per-node=1 10 | #SBATCH --gpus-per-node=4 11 | #SBATCH --cpus-per-task=128 12 | 13 | #SBATCH --partition=normal 14 | #SBATCH --output=RL_A100_%j_%N.out 15 | #SBATCH --mail-user="zychen.uestc@gmail.com" --mail-type=ALL 16 | 17 | 18 | set -x 19 | # ---------- Environment Setup ---------- 20 | export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 21 | export DEBUG_MODE=True 22 | export WANDB_PROJECT=RL4SGG 23 | 24 | export NCCL_DEBUG=INFO 25 | 26 | 27 | GPUS_PER_NODE=4 28 | GROUP_SIZE=8 29 | MODEL_PATH="Qwen/Qwen2-VL-2B-Instruct" 30 | #DATA_PATH="JosephZ/vg150_train_sgg_prompt" 31 | DATA_PATH="JosephZ/psg_train_sg" 32 | 33 | RUN_NAME="qwen2vl-2b-grpo-g8-n1-bs32-1k-lr6e-7-psg-debug-A100-SXM4" 34 | export OUTPUT_DIR="${SCRATCH}/models/${RUN_NAME}" 35 | mkdir -p "$OUTPUT_DIR" 36 | 37 | export LOG_PATH=${OUTPUT_DIR}/debug.log 38 | 39 | export FORMAT_REWARD_WEIGHT=1.0 40 | export STRICT_FORMAT=True 41 | 42 | MAX_PIXELS=$((512 * 28 * 28)) 43 | 44 | 45 | MASTER_PORT=29500 46 | 47 | NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST)) 48 | NUM_TRAIN_NODES=${#NODELIST[@]} 49 | TRAIN_NODES_LIST=("${NODELIST[@]:0:$NUM_TRAIN_NODES}") 50 | 51 | # Choose the first training node as the rendezvous head node 52 | HEAD_NODE=${TRAIN_NODES_LIST[0]} 53 | 54 | #MASTER_ADDR=$(srun --nodes=1 --ntasks=1 -w "$HEAD_NODE" hostname --ip-address) 55 | #MASTER_ADDR=$(echo "${SLURM_NODELIST}" | sed 's/[],].*//g; s/\[//g') 56 | 57 | MASTER_ADDR=$(scontrol show hostnames $SLURM_NODELIST | head -n 1) 58 | echo "MASTER_ADDR: $MASTER_ADDR" 59 | 60 | 61 | 62 | # batch size: PER_GPU(4)*GPUS(4)*NODES(4)*ACC(4) // GROUP_SIZE(8) = 32 63 | # local vLLM: 80G*0.2=16G 64 | # 65 | # ['format_reward', 'node_acc_reward', "node_box_reward", "edge_reward"] 66 | TRAIN_CMD="open_r1/grpo.py \ 67 | --task_type sgg \ 68 | --output_dir ${OUTPUT_DIR} \ 69 | --model_name_or_path ${MODEL_PATH} \ 70 | --dataset_name ${DATA_PATH} \ 71 | --max_prompt_length 2048 \ 72 | --max_completion_length 1024 \ 73 | --custom_per_device_train_batch_size 4 \ 74 | --deepspeed ./local_scripts/zero2.json \ 75 | --gradient_accumulation_steps 4 \ 76 | --learning_rate 6e-7 \ 77 | --logging_steps 1 \ 78 | --use_vllm true \ 79 | --use_local_vllm true\ 80 | --bf16 true\ 81 | --tf32 true\ 82 | --report_to wandb \ 83 | --gradient_checkpointing true \ 84 | --max_pixels ${MAX_PIXELS} \ 85 | --temperature 1.0 \ 86 | --top_p 0.9 \ 87 | --top_k 50 \ 88 | --num_train_epochs 1.0 \ 89 | --run_name ${RUN_NAME} \ 90 | --save_steps 100 \ 91 | --num_generations ${GROUP_SIZE} \ 92 | --num_iterations 1 \ 93 | --beta 0.0 \ 94 | 
--vllm_max_model_len 4096 \ 95 | --vllm_gpu_memory_utilization 0.2 \ 96 | --save_only_model true\ 97 | --seed 42" 98 | 99 | 100 | echo "start training..." 101 | 102 | srun --nodes=${NUM_TRAIN_NODES} --nodelist="${TRAIN_NODES_LIST}" \ 103 | torchrun --nnodes ${NUM_TRAIN_NODES} --nproc_per_node ${GPUS_PER_NODE} \ 104 | --node_rank ${SLURM_NODEID} \ 105 | --rdzv_id $RANDOM \ 106 | --rdzv_backend c10d \ 107 | --rdzv_endpoint ${MASTER_ADDR}:${MASTER_PORT} \ 108 | ${TRAIN_CMD} 109 | -------------------------------------------------------------------------------- /scripts/grpo/train_a100_2B_SFT.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | #SBATCH --job-name=A100_2B_SFT_RL_2k_lr6e-7 5 | #SBATCH --time=24:00:00 6 | 7 | #SBATCH --exclude=nid002289,nid002325 8 | #SBATCH --nodes=4 # 4 nodes, each has 4x A100 9 | #SBATCH --ntasks-per-node=1 10 | #SBATCH --gpus-per-node=4 11 | #SBATCH --cpus-per-task=128 12 | 13 | #SBATCH --partition=normal 14 | #SBATCH --output=RL_A100_%j_%N.out 15 | #SBATCH --mail-user="zychen.uestc@gmail.com" --mail-type=ALL 16 | 17 | 18 | # ---------- Environment Setup ---------- 19 | export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 20 | export DEBUG_MODE=True 21 | export WANDB_PROJECT=RL4SGG 22 | 23 | 24 | GPUS_PER_NODE=4 25 | GROUP_SIZE=8 26 | #MODEL_PATH="Qwen/Qwen2-VL-2B-Instruct" 27 | MODEL_PATH=$1 28 | 29 | DATA_PATH="JosephZ/vg150_train_sgg_prompt" 30 | RUN_NAME="qwen2vl-2b-sft-grpo-sgg-g8-n1-bs32-2k-lr6e-7-A100-SXM4" 31 | export OUTPUT_DIR="${SCRATCH}/models/${RUN_NAME}" 32 | mkdir -p "$OUTPUT_DIR" 33 | 34 | export LOG_PATH=${OUTPUT_DIR}/debug.log 35 | 36 | MAX_PIXELS=$((512 * 28 * 28)) 37 | 38 | 39 | MASTER_PORT=29500 40 | 41 | NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST)) 42 | NUM_TRAIN_NODES=${#NODELIST[@]} 43 | TRAIN_NODES_LIST=("${NODELIST[@]:0:$NUM_TRAIN_NODES}") 44 | 45 | # Choose the first training node as the rendezvous head node 46 | HEAD_NODE=${TRAIN_NODES_LIST[0]} 47 | 48 | #MASTER_ADDR=$(srun --nodes=1 --ntasks=1 -w "$HEAD_NODE" hostname --ip-address) 49 | #MASTER_ADDR=$(echo "${SLURM_NODELIST}" | sed 's/[],].*//g; s/\[//g') 50 | 51 | MASTER_ADDR=$(scontrol show hostnames $SLURM_NODELIST | head -n 1) 52 | echo "MASTER_ADDR: $MASTER_ADDR" 53 | 54 | 55 | 56 | # batch size: PER_GPU(4)*GPUS(4)*NODES(4)*ACC(4) // GROUP_SIZE(8) = 32 57 | # local vLLM: 80G*0.2=16G 58 | # 59 | TRAIN_CMD="open_r1/grpo.py \ 60 | --task_type sgg \ 61 | --output_dir ${OUTPUT_DIR} \ 62 | --model_name_or_path ${MODEL_PATH} \ 63 | --dataset_name ${DATA_PATH} \ 64 | --max_prompt_length 2048 \ 65 | --max_completion_length 2048 \ 66 | --custom_per_device_train_batch_size 4 \ 67 | --deepspeed ./local_scripts/zero2.json \ 68 | --gradient_accumulation_steps 4 \ 69 | --learning_rate 6e-7 \ 70 | --logging_steps 1 \ 71 | --use_vllm true \ 72 | --use_local_vllm true\ 73 | --bf16 true\ 74 | --tf32 true\ 75 | --report_to wandb \ 76 | --gradient_checkpointing true \ 77 | --max_pixels ${MAX_PIXELS} \ 78 | --temperature 1.0 \ 79 | --top_p 0.9 \ 80 | --top_k 50 \ 81 | --num_train_epochs 1 \ 82 | --run_name ${RUN_NAME} \ 83 | --save_steps 100 \ 84 | --num_generations ${GROUP_SIZE} \ 85 | --num_iterations 1 \ 86 | --beta 0.0\ 87 | --vllm_max_model_len 4096 \ 88 | --vllm_gpu_memory_utilization 0.2 \ 89 | --save_only_model true\ 90 | --seed 42" 91 | 92 | 93 | echo "start training..." 
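# Sanity check for the effective GRPO batch size noted in the comment above: unique
# prompts per optimizer step = per-device batch * GPUs per node * nodes * grad-accum
# steps / num_generations. A minimal, purely informational sketch using this script's
# values (it does not affect training):
PER_DEVICE_BS=4
GRAD_ACC=4
echo "effective prompts per step: $(( PER_DEVICE_BS * GPUS_PER_NODE * NUM_TRAIN_NODES * GRAD_ACC / GROUP_SIZE ))"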
94 | 95 | srun --nodes=${NUM_TRAIN_NODES} --nodelist="${TRAIN_NODES_LIST}" \ 96 | torchrun --nnodes ${NUM_TRAIN_NODES} --nproc_per_node ${GPUS_PER_NODE} \ 97 | --node_rank ${SLURM_NODEID} \ 98 | --rdzv_id $RANDOM \ 99 | --rdzv_backend c10d \ 100 | --rdzv_endpoint ${MASTER_ADDR}:${MASTER_PORT} \ 101 | ${TRAIN_CMD} 102 | -------------------------------------------------------------------------------- /scripts/grpo/train_a100_2B_close.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | #SBATCH --job-name=A100_2B 5 | #SBATCH --time=24:00:00 6 | 7 | #SBATCH --nodes=4 # 4 nodes, each has 4x A100 8 | #SBATCH --ntasks-per-node=1 9 | #SBATCH --gpus-per-node=4 10 | #SBATCH --cpus-per-task=128 11 | 12 | #SBATCH --partition=normal 13 | #SBATCH --output=RL_A100_%j_%N.out 14 | #SBATCH --mail-user="zychen.uestc@gmail.com" --mail-type=ALL 15 | 16 | 17 | # ---------- Environment Setup ---------- 18 | export NCCL_ASYNC_ERROR_HANDLING=1 19 | export DEBUG_MODE=True 20 | export WANDB_PROJECT=RL4SGG 21 | 22 | 23 | GPUS_PER_NODE=4 24 | GROUP_SIZE=8 25 | MODEL_PATH="Qwen/Qwen2-VL-2B-Instruct" 26 | DATA_PATH="JosephZ/vg150_train_sgg_prompt" 27 | RUN_NAME="qwen2vl-2b-close-grpo-g${GROUP_SIZE}-n1-bs32-A100-SXM4" 28 | export OUTPUT_DIR="${SCRATCH}/models/${RUN_NAME}" 29 | mkdir -p "$OUTPUT_DIR" 30 | 31 | export LOG_PATH=${OUTPUT_DIR}/debug.log 32 | 33 | MAX_PIXELS=$((512 * 28 * 28)) 34 | 35 | 36 | MASTER_PORT=29500 37 | 38 | NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST)) 39 | NUM_TRAIN_NODES=${#NODELIST[@]} 40 | TRAIN_NODES_LIST=("${NODELIST[@]:0:$NUM_TRAIN_NODES}") 41 | 42 | # Choose the first training node as the rendezvous head node 43 | HEAD_NODE=${TRAIN_NODES_LIST[0]} 44 | 45 | #MASTER_ADDR=$(srun --nodes=1 --ntasks=1 -w "$HEAD_NODE" hostname --ip-address) 46 | #MASTER_ADDR=$(echo "${SLURM_NODELIST}" | sed 's/[],].*//g; s/\[//g') 47 | 48 | MASTER_ADDR=$(scontrol show hostnames $SLURM_NODELIST | head -n 1) 49 | echo "MASTER_ADDR: $MASTER_ADDR" 50 | 51 | 52 | 53 | # batch size: PER_GPU(4)*GPUS(4)*NODES(4)*ACC(4) // GROUP_SIZE(8) = 32 54 | # local vLLM: 80G*0.2=16G 55 | # 56 | TRAIN_CMD="open_r1/grpo.py \ 57 | --output_dir ${OUTPUT_DIR} \ 58 | --model_name_or_path ${MODEL_PATH} \ 59 | --dataset_name ${DATA_PATH} \ 60 | --max_prompt_length 2048 \ 61 | --max_completion_length 1024 \ 62 | --custom_per_device_train_batch_size 4 \ 63 | --deepspeed ./local_scripts/zero2.json \ 64 | --gradient_accumulation_steps 4 \ 65 | --learning_rate 3e-7 \ 66 | --use_predefined_cats true \ 67 | --logging_steps 1 \ 68 | --use_vllm true \ 69 | --use_local_vllm true\ 70 | --bf16 true\ 71 | --tf32 true\ 72 | --report_to wandb \ 73 | --gradient_checkpointing false \ 74 | --max_pixels ${MAX_PIXELS} \ 75 | --temperature 1.0 \ 76 | --top_p 0.9 \ 77 | --top_k 50 \ 78 | --num_train_epochs 1 \ 79 | --run_name ${RUN_NAME} \ 80 | --save_steps 100 \ 81 | --num_generations ${GROUP_SIZE} \ 82 | --num_iterations 1 \ 83 | --beta 0.0\ 84 | --vllm_max_model_len 4096 \ 85 | --vllm_gpu_memory_utilization 0.2 \ 86 | --save_only_model true\ 87 | --seed 42" 88 | 89 | 90 | echo "start training..." 
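# The "close" variants differ from the open ones mainly through --use_predefined_cats:
# prompts restrict predictions to the fixed VG150 object/predicate vocabulary
# (closed-set SGG) instead of leaving the category names open. This reading follows the
# flag name and the "-close" suffix; see open_r1/grpo.py for the authoritative behavior.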
91 | 92 | srun torchrun --nnodes ${NUM_TRAIN_NODES} --nproc_per_node ${GPUS_PER_NODE} \ 93 | --node_rank ${SLURM_NODEID} \ 94 | --rdzv_id $RANDOM \ 95 | --rdzv_backend c10d \ 96 | --rdzv_endpoint ${MASTER_ADDR}:${MASTER_PORT} \ 97 | ${TRAIN_CMD} 98 | -------------------------------------------------------------------------------- /scripts/grpo/train_a100_2B_close_SFT.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | #SBATCH --job-name=A100_2B_SFT_CLOSE_RL 5 | #SBATCH --time=24:00:00 6 | 7 | #SBATCH --nodes=4 # 4 nodes, each has 4x A100 8 | #SBATCH --ntasks-per-node=1 9 | #SBATCH --gpus-per-node=4 10 | #SBATCH --cpus-per-task=128 11 | 12 | #SBATCH --partition=normal 13 | #SBATCH --output=RL_A100_%j_%N.out 14 | #SBATCH --mail-user="zychen.uestc@gmail.com" --mail-type=ALL 15 | 16 | 17 | # ---------- Environment Setup ---------- 18 | export NCCL_ASYNC_ERROR_HANDLING=1 19 | export DEBUG_MODE=True 20 | export WANDB_PROJECT=RL4SGG 21 | 22 | 23 | GPUS_PER_NODE=4 24 | GROUP_SIZE=8 25 | #MODEL_PATH="Qwen/Qwen2-VL-2B-Instruct" 26 | MODEL_PATH=$1 27 | 28 | DATA_PATH="JosephZ/vg150_train_sgg_prompt" 29 | RUN_NAME="qwen2vl-2b-sft-close-grpo-g${GROUP_SIZE}-n1-bs32-A100-SXM4" 30 | export OUTPUT_DIR="${SCRATCH}/models/${RUN_NAME}" 31 | mkdir -p "$OUTPUT_DIR" 32 | 33 | export LOG_PATH=${OUTPUT_DIR}/debug.log 34 | 35 | MAX_PIXELS=$((512 * 28 * 28)) 36 | 37 | 38 | MASTER_PORT=29500 39 | 40 | NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST)) 41 | NUM_TRAIN_NODES=${#NODELIST[@]} 42 | TRAIN_NODES_LIST=("${NODELIST[@]:0:$NUM_TRAIN_NODES}") 43 | 44 | # Choose the first training node as the rendezvous head node 45 | HEAD_NODE=${TRAIN_NODES_LIST[0]} 46 | 47 | #MASTER_ADDR=$(srun --nodes=1 --ntasks=1 -w "$HEAD_NODE" hostname --ip-address) 48 | 49 | MASTER_ADDR=$(echo "${SLURM_NODELIST}" | sed 's/[],].*//g; s/\[//g') 50 | echo "MASTER_ADDR: $MASTER_ADDR" 51 | 52 | 53 | 54 | # batch size: PER_GPU(4)*GPUS(4)*NODES(4)*ACC(4) // GROUP_SIZE(8) = 32 55 | # local vLLM: 80G*0.2=16G 56 | # 57 | TRAIN_CMD="open_r1/grpo.py \ 58 | --output_dir ${OUTPUT_DIR} \ 59 | --model_name_or_path ${MODEL_PATH} \ 60 | --dataset_name ${DATA_PATH} \ 61 | --max_prompt_length 2048 \ 62 | --max_completion_length 1024 \ 63 | --custom_per_device_train_batch_size 4 \ 64 | --deepspeed ./local_scripts/zero2.json \ 65 | --gradient_accumulation_steps 4 \ 66 | --learning_rate 3e-7 \ 67 | --logging_steps 1 \ 68 | --use_vllm true \ 69 | --use_local_vllm true\ 70 | --bf16 true\ 71 | --tf32 true\ 72 | --report_to wandb \ 73 | --gradient_checkpointing true \ 74 | --max_pixels ${MAX_PIXELS} \ 75 | --temperature 1.0 \ 76 | --top_p 0.9 \ 77 | --top_k 50 \ 78 | --num_train_epochs 1 \ 79 | --run_name ${RUN_NAME} \ 80 | --save_steps 100 \ 81 | --num_generations ${GROUP_SIZE} \ 82 | --num_iterations 1 \ 83 | --beta 0.0\ 84 | --vllm_max_model_len 4096 \ 85 | --vllm_gpu_memory_utilization 0.25 \ 86 | --save_only_model true\ 87 | --use_predefined_cats true \ 88 | --seed 42" 89 | 90 | 91 | echo "start training..." 
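# Note on the MASTER_ADDR derivation above: the sed one-liner strips SLURM's bracketed
# list syntax to recover the first hostname, e.g. "nid[002289,002325]" -> "nid002289".
# For a contiguous range such as "nid[002203-002206]" it would yield "nid002203-002206",
# which is not a hostname; that is why most other scripts here use
#   MASTER_ADDR=$(scontrol show hostnames $SLURM_NODELIST | head -n 1)
# instead.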
92 | 93 | srun --nodes=${NUM_TRAIN_NODES} --nodelist="${TRAIN_NODES_LIST}" \ 94 | torchrun --nnodes ${NUM_TRAIN_NODES} --nproc_per_node ${GPUS_PER_NODE} \ 95 | --node_rank ${SLURM_NODEID} \ 96 | --rdzv_id $RANDOM \ 97 | --rdzv_backend c10d \ 98 | --rdzv_endpoint ${MASTER_ADDR}:${MASTER_PORT} \ 99 | ${TRAIN_CMD} 100 | -------------------------------------------------------------------------------- /scripts/grpo/train_a100_SFT.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | #SBATCH --job-name=7B_SFT_RL 5 | #SBATCH --time=24:00:00 6 | 7 | #SBATCH --nodes=4 # 4 nodes, each has 4x A100 8 | #SBATCH --ntasks-per-node=1 9 | #SBATCH --gpus-per-node=4 10 | #SBATCH --cpus-per-task=128 11 | 12 | #SBATCH --partition=normal 13 | #SBATCH --output=RL_A100_%j_%N.out 14 | #SBATCH --mail-user="zychen.uestc@gmail.com" --mail-type=ALL 15 | 16 | 17 | set -x 18 | # ---------- Environment Setup ---------- 19 | export NCCL_ASYNC_ERROR_HANDLING=1 20 | export DEBUG_MODE=True 21 | export WANDB_PROJECT=RL4SGG 22 | 23 | 24 | GPUS_PER_NODE=4 25 | GROUP_SIZE=8 26 | #MODEL_PATH="Qwen/Qwen2-VL-7B-Instruct" 27 | MODEL_PATH=$1 28 | 29 | DATA_PATH="JosephZ/vg150_train_sgg_prompt" 30 | RUN_NAME="qwen2vl-7b-sft-grpo-g${GROUP_SIZE}-n1-bs32-A100-SXM4" 31 | export OUTPUT_DIR="${SCRATCH}/models/${RUN_NAME}" 32 | mkdir -p "$OUTPUT_DIR" 33 | 34 | MAX_PIXELS=$((512 * 28 * 28)) 35 | 36 | 37 | MASTER_PORT=29500 38 | 39 | NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST)) 40 | NUM_TRAIN_NODES=${#NODELIST[@]} 41 | TRAIN_NODES_LIST=("${NODELIST[@]:0:$NUM_TRAIN_NODES}") 42 | 43 | # Choose the first training node as the rendezvous head node 44 | HEAD_NODE=${TRAIN_NODES_LIST[0]} 45 | 46 | #MASTER_ADDR=$(srun --nodes=1 --ntasks=1 -w "$HEAD_NODE" hostname --ip-address) 47 | 48 | MASTER_ADDR=$(echo "${SLURM_NODELIST}" | sed 's/[],].*//g; s/\[//g') 49 | echo "MASTER_ADDR: $MASTER_ADDR" 50 | 51 | 52 | 53 | # batch size: PER_GPU(2)*GPUS(4)*NODES(4)*ACC(8) //8=32 54 | # local vLLM: 80G*0.25=20G 55 | # 56 | TRAIN_CMD="open_r1/grpo.py \ 57 | --output_dir ${OUTPUT_DIR} \ 58 | --model_name_or_path ${MODEL_PATH} \ 59 | --dataset_name ${DATA_PATH} \ 60 | --max_prompt_length 2048 \ 61 | --max_completion_length 1024 \ 62 | --custom_per_device_train_batch_size 2 \ 63 | --deepspeed ./local_scripts/zero2.json \ 64 | --gradient_accumulation_steps 8 \ 65 | --learning_rate 3e-7 \ 66 | --logging_steps 1 \ 67 | --use_vllm true \ 68 | --use_local_vllm true\ 69 | --bf16 true\ 70 | --tf32 true\ 71 | --report_to wandb \ 72 | --gradient_checkpointing true \ 73 | --max_pixels ${MAX_PIXELS} \ 74 | --temperature 1.0 \ 75 | --top_p 0.9 \ 76 | --top_k 50 \ 77 | --num_train_epochs 1 \ 78 | --run_name ${RUN_NAME} \ 79 | --save_steps 100 \ 80 | --num_generations ${GROUP_SIZE} \ 81 | --num_iterations 1 \ 82 | --beta 0.0\ 83 | --vllm_max_model_len 4096 \ 84 | --vllm_gpu_memory_utilization 0.25\ 85 | --save_only_model true\ 86 | --seed 42" 87 | 88 | 89 | echo "start training..." 
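# MODEL_PATH is taken from $1, so this job expects an SFT checkpoint path at submit
# time. Illustrative usage (the checkpoint path below is a placeholder, not a real
# artifact of this repo):
#   sbatch scripts/grpo/train_a100_SFT.sh ${SCRATCH}/models/<your-sft-run>/checkpoint-XXXX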
90 | 91 | srun --nodes=${NUM_TRAIN_NODES} --nodelist="${TRAIN_NODES_LIST}" \ 92 | torchrun --nnodes ${NUM_TRAIN_NODES} --nproc_per_node ${GPUS_PER_NODE} \ 93 | --node_rank ${SLURM_NODEID} \ 94 | --rdzv_id $RANDOM \ 95 | --rdzv_backend c10d \ 96 | --rdzv_endpoint ${MASTER_ADDR}:${MASTER_PORT} \ 97 | ${TRAIN_CMD} 98 | -------------------------------------------------------------------------------- /scripts/grpo/train_a100_close.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | #SBATCH --job-name=7B_A100 5 | #SBATCH --time=24:00:00 6 | 7 | #SBATCH --nodes=8 # each has 4x A100 8 | #SBATCH --ntasks-per-node=1 9 | #SBATCH --gpus-per-node=4 10 | #SBATCH --cpus-per-task=128 11 | 12 | #SBATCH --partition=normal 13 | #SBATCH --output=RL_A100_%j_%N.out 14 | #SBATCH --mail-user="zychen.uestc@gmail.com" --mail-type=ALL 15 | 16 | 17 | set -x 18 | # ---------- Environment Setup ---------- 19 | export NCCL_ASYNC_ERROR_HANDLING=1 20 | export DEBUG_MODE=True 21 | export WANDB_PROJECT=RL4SGG 22 | 23 | 24 | GPUS_PER_NODE=4 25 | GROUP_SIZE=8 26 | MODEL_PATH="Qwen/Qwen2-VL-7B-Instruct" 27 | DATA_PATH="JosephZ/vg150_train_sgg_prompt" 28 | RUN_NAME="qwen2vl-7b-close-grpo-g${GROUP_SIZE}-n1-bs32-A100-SXM4" 29 | export OUTPUT_DIR="${SCRATCH}/models/${RUN_NAME}" 30 | mkdir -p "$OUTPUT_DIR" 31 | 32 | MAX_PIXELS=$((512 * 28 * 28)) 33 | 34 | 35 | MASTER_PORT=29500 36 | 37 | NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST)) 38 | NUM_TRAIN_NODES=${#NODELIST[@]} 39 | TRAIN_NODES_LIST=("${NODELIST[@]:0:$NUM_TRAIN_NODES}") 40 | 41 | # Choose the first training node as the rendezvous head node 42 | HEAD_NODE=${TRAIN_NODES_LIST[0]} 43 | 44 | #MASTER_ADDR=$(srun --nodes=1 --ntasks=1 -w "$HEAD_NODE" hostname --ip-address) 45 | 46 | MASTER_ADDR=$(echo "${SLURM_NODELIST}" | sed 's/[],].*//g; s/\[//g') 47 | echo "MASTER_ADDR: $MASTER_ADDR" 48 | 49 | 50 | 51 | # batch size: PER_GPU(2)*GPUS(4)*NODES(8)*ACC(4) //8=32 52 | # local vLLM: 80G*0.25=20G 53 | # 54 | TRAIN_CMD="open_r1/grpo.py \ 55 | --output_dir ${OUTPUT_DIR} \ 56 | --model_name_or_path ${MODEL_PATH} \ 57 | --dataset_name ${DATA_PATH} \ 58 | --max_prompt_length 2048 \ 59 | --max_completion_length 1024 \ 60 | --custom_per_device_train_batch_size 4 \ 61 | --deepspeed ./local_scripts/zero2.json \ 62 | --gradient_accumulation_steps 2 \ 63 | --learning_rate 3e-7 \ 64 | --use_predefined_cats true \ 65 | --logging_steps 1 \ 66 | --use_vllm true \ 67 | --use_local_vllm true\ 68 | --bf16 true\ 69 | --tf32 true\ 70 | --report_to wandb \ 71 | --gradient_checkpointing true \ 72 | --max_pixels ${MAX_PIXELS} \ 73 | --temperature 1.0 \ 74 | --top_p 0.9 \ 75 | --top_k 50 \ 76 | --num_train_epochs 1 \ 77 | --run_name ${RUN_NAME} \ 78 | --save_steps 100 \ 79 | --num_generations ${GROUP_SIZE} \ 80 | --num_iterations 1 \ 81 | --beta 0.0\ 82 | --vllm_max_model_len 4096 \ 83 | --vllm_gpu_memory_utilization 0.25\ 84 | --save_only_model true\ 85 | --seed 42" 86 | 87 | 88 | echo "start training..." 
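# With --use_local_vllm, every training GPU also hosts its own rollout engine:
# --vllm_gpu_memory_utilization 0.25 reserves roughly 25% of an 80GB A100 (~20GB) for
# the vLLM weights and KV cache, leaving the rest for the ZeRO-2 training states,
# which is what the "80G*0.25=20G" note above refers to.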
89 | 90 | srun --nodes=${NUM_TRAIN_NODES} --nodelist="${TRAIN_NODES_LIST}" \ 91 | torchrun --nnodes ${NUM_TRAIN_NODES} --nproc_per_node ${GPUS_PER_NODE} \ 92 | --node_rank ${SLURM_NODEID} \ 93 | --rdzv_id $RANDOM \ 94 | --rdzv_backend c10d \ 95 | --rdzv_endpoint ${MASTER_ADDR}:${MASTER_PORT} \ 96 | ${TRAIN_CMD} 97 | -------------------------------------------------------------------------------- /scripts/grpo/train_fsdp.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | 5 | #SBATCH --job-name=GRPO_train 6 | #SBATCH --time=24:00:00 7 | #SBATCH --nodes=16 # 4 training nodes + 1 vLLM node = 5 nodes 8 | #SBATCH --ntasks=16 # Total tasks equals total nodes 9 | #SBATCH --ntasks-per-node=1 10 | #SBATCH --gpus-per-node=rtx_4090:8 11 | #SBATCH --cpus-per-task=8 12 | #SBATCH --mem-per-cpu=25000M 13 | #SBATCH --output=RL_%j_%N.out 14 | #SBATCH --mail-user="zychen.uestc@gmail.com" --mail-type=ALL 15 | 16 | 17 | # force crashing on nccl issues like hanging broadcast 18 | export NCCL_ASYNC_ERROR_HANDLING=1 19 | #export NCCL_IB_DISABLE=1 20 | 21 | # export NCCL_DEBUG=INFO 22 | # export NCCL_DEBUG_SUBSYS=COLL 23 | # export NCCL_SOCKET_NTHREADS=1 24 | # export NCCL_NSOCKS_PERTHREAD=1 25 | # export CUDA_LAUNCH_BLOCKING=1 26 | 27 | # wait for vLLM servers 28 | #sleep 60 29 | 30 | # Read IPs from file and join them with commas 31 | #ip_str=$(paste -sd, ip_list.txt) 32 | #echo "vLLM servers: $ip_str" 33 | 34 | FILE="ip_port_list.txt" 35 | 36 | SERVER_IP="" 37 | SERVER_PORT="" 38 | 39 | while IFS=: read -r ip port; do 40 | SERVER_IP+="${ip}," 41 | SERVER_PORT+="${port}," 42 | done < "$FILE" 43 | 44 | # Remove trailing commas 45 | SERVER_IP="${SERVER_IP%,}" 46 | SERVER_PORT="${SERVER_PORT%,}" 47 | 48 | echo "SERVER_IP=$SERVER_IP" 49 | echo "SERVER_PORT=$SERVER_PORT" 50 | 51 | 52 | # Define node counts 53 | NUM_TRAIN_NODES=${SLURM_NNODES} # all nodes 54 | GPUS_PER_NODE=8 55 | 56 | # Get the list of allocated nodes 57 | NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST)) 58 | 59 | # Assign training nodes (first NUM_TRAIN_NODES nodes) 60 | TRAIN_NODES=("${NODELIST[@]:0:$NUM_TRAIN_NODES}") 61 | 62 | # Choose the first training node as the rendezvous head node 63 | HEAD_NODE=${TRAIN_NODES[0]} 64 | MASTER_IP=$(srun --nodes=1 --ntasks=1 -w "$HEAD_NODE" hostname --ip-address) 65 | 66 | MASTER_PORT=6000 67 | echo "Head Node IP: $MASTER_IP, port: ${MASTER_PORT}" 68 | 69 | # Create a comma-separated list of training nodes for srun 70 | TRAIN_NODES_LIST=$(IFS=, ; echo "${TRAIN_NODES[*]}") 71 | 72 | export NCCL_DEBUG=INFO 73 | echo "environment: $(env | grep NCCL)" 74 | 75 | 76 | 77 | export DEBUG_MODE=True 78 | export WANDB_PROJECT=RL4SGG 79 | 80 | export DATA_PATH="JosephZ/vg150_train_sgg_prompt" 81 | export MODEL_PATH="Qwen/Qwen2-VL-7B-Instruct" 82 | 83 | 84 | # Training setup 85 | NNODES=$SLURM_NNODES 86 | NODE_RANK=$SLURM_PROCID 87 | WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES)) 88 | 89 | MAX_PIXELS=$((512 * 28 * 28)) 90 | 91 | echo "Start training script..." 
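# The accelerate launcher below maps SLURM onto multi-node FSDP: --num_machines is the
# node count, --num_processes the total GPU count, and --machine_rank is taken from
# SLURM_PROCID (one launcher task per node). Rollouts are generated remotely in this
# setup: the trainer connects to the vLLM servers listed in ip_port_list.txt through
# --vllm_server_host / --vllm_server_port parsed above. --num_processes is passed twice
# with the same value; the duplicate is redundant but harmless.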
92 | 93 | LAUNCHER="accelerate launch \ 94 | --multi_gpu \ 95 | --num_machines $NNODES \ 96 | --num_processes $WORLD_SIZE \ 97 | --main_process_ip "$MASTER_IP" \ 98 | --main_process_port $MASTER_PORT \ 99 | --num_processes $WORLD_SIZE \ 100 | --machine_rank \$SLURM_PROCID \ 101 | --role $SLURMD_NODENAME: \ 102 | --rdzv_conf rdzv_backend=c10d \ 103 | --max_restarts 0 \ 104 | --tee 3 \ 105 | --config_file local_scripts/fsdp.yaml \ 106 | " 107 | 108 | CMD=" \ 109 | open_r1/grpo.py \ 110 | --output_dir models/qwen2vl-fsdp-g8 \ 111 | --model_name_or_path ${MODEL_PATH} \ 112 | --dataset_name $DATA_PATH \ 113 | --max_prompt_length 2048 \ 114 | --max_completion_length 1024 \ 115 | --per_device_train_batch_size 1 \ 116 | --gradient_accumulation_steps 1 \ 117 | --logging_steps 1 \ 118 | --use_vllm true \ 119 | --vllm_server_host ${SERVER_IP} \ 120 | --vllm_server_port ${SERVER_PORT} \ 121 | --vllm_server_timeout 600 \ 122 | --bf16 \ 123 | --report_to wandb \ 124 | --gradient_checkpointing true \ 125 | --max_pixels ${MAX_PIXELS} \ 126 | --temperature 0.3 \ 127 | --top_p 0.001 \ 128 | --top_k 1 \ 129 | --num_train_epochs 1 \ 130 | --run_name Qwen2VL-7B-GRPO-fsdp-G8 \ 131 | --save_steps 100 \ 132 | --num_generations 8 133 | " 134 | 135 | 136 | srun --jobid $SLURM_JOB_ID bash -c "$LAUNCHER $CMD" 137 | 138 | echo "END TIME: $(date)" 139 | -------------------------------------------------------------------------------- /scripts/grpo/train_fused.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=GRPO_train_vllm 3 | #SBATCH --time=24:00:00 4 | #SBATCH --nodes=15 5 | #SBATCH --ntasks=15 6 | #SBATCH --ntasks-per-node=1 7 | #SBATCH --gpus-per-node=rtx_4090:8 8 | #SBATCH --cpus-per-task=15 9 | #SBATCH --mem-per-cpu=16000M 10 | #SBATCH --output=TrainVLLM_%j_%N.out 11 | #SBATCH --mail-user="zychen.uestc@gmail.com" --mail-type=ALL 12 | 13 | set -euo pipefail 14 | 15 | # ---------- Environment Setup ---------- 16 | export NCCL_ASYNC_ERROR_HANDLING=1 17 | export DEBUG_MODE=True 18 | export WANDB_PROJECT=RL4SGG 19 | 20 | GROUP_SIZE=8 21 | MODEL_PATH="Qwen/Qwen2-VL-7B-Instruct" 22 | DATA_PATH="JosephZ/vg150_train_sgg_prompt" 23 | RUN_NAME="qwen2vl-7b-grpo-g${GROUP_SIZE}-n1-4090" 24 | OUTPUT_DIR="models/${RUN_NAME}" 25 | mkdir -p "$OUTPUT_DIR" 26 | 27 | TP_SIZE=4 28 | PORT_BASE=8000 29 | MAX_PIXELS=$((512 * 28 * 28)) 30 | 31 | NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST)) 32 | NUM_NODES=${#NODELIST[@]} 33 | 34 | if (( NUM_NODES % 3 != 0 )); then 35 | echo "Error: number of nodes ($NUM_NODES) must be divisible by 3." 
36 | exit 1 37 | fi 38 | 39 | MIXED_NODES=10 # Set this dynamically if needed 40 | # vLLM: 10*4=40 GPUs, 10 servers 41 | # training, x=(8-4)*10//8= 5 42 | # training GPUs: 10*4+ 5*8=80, batch size=80//8*2=20 43 | 44 | HEAD_NODE=${NODELIST[0]} 45 | HEAD_NODE_IP=$(srun --nodes=1 --ntasks=1 -w "$HEAD_NODE" hostname --ip-address) 46 | RDZV_PORT=29500 47 | IP_FILE="${OUTPUT_DIR}/ip_port_list.txt" 48 | > "$IP_FILE" 49 | 50 | for i in $(seq 0 $((MIXED_NODES - 1))); do 51 | node=${NODELIST[$i]} 52 | ip=$(srun --nodes=1 --ntasks=1 -w "$node" hostname --ip-address) 53 | echo "${ip}:$((PORT_BASE))" >> "$IP_FILE" 54 | done 55 | 56 | SERVER_IP=$(cut -d: -f1 $IP_FILE | paste -sd,) 57 | SERVER_PORT=$(cut -d: -f2 $IP_FILE | paste -sd,) 58 | 59 | TRAIN_CMD="open_r1/grpo.py \ 60 | --output_dir ${OUTPUT_DIR} \ 61 | --model_name_or_path ${MODEL_PATH} \ 62 | --dataset_name ${DATA_PATH} \ 63 | --deepspeed ./local_scripts/zero3.json \ 64 | --max_prompt_length 2048 \ 65 | --max_completion_length 1024 \ 66 | --custom_per_device_train_batch_size 1 \ 67 | --gradient_accumulation_steps 2 \ 68 | --learning_rate 3e-7 \ 69 | --logging_steps 1 \ 70 | --use_vllm true \ 71 | --vllm_server_host ${SERVER_IP} \ 72 | --vllm_server_port ${SERVER_PORT} \ 73 | --vllm_server_timeout 600 \ 74 | --vllm_locate_same_node true\ 75 | --vllm_locate_same_remain_gpus 4\ 76 | --bf16 \ 77 | --report_to wandb \ 78 | --gradient_checkpointing true \ 79 | --max_pixels ${MAX_PIXELS} \ 80 | --temperature 1 \ 81 | --top_p 0.9 \ 82 | --top_k 50 \ 83 | --num_train_epochs 1 \ 84 | --run_name ${RUN_NAME} \ 85 | --save_steps 100 \ 86 | --num_generations 8 \ 87 | --num_iterations 1 \ 88 | --beta 0.0 \ 89 | --save_only_model true \ 90 | --seed 42" 91 | 92 | # ---------- Functions ---------- 93 | launch_mixed_node() { 94 | local i=$1 95 | local node=$2 96 | local log_file="${OUTPUT_DIR}/vllm_node_${i}_${node}.log" 97 | srun --nodes=1 --ntasks=1 -w "$node" bash -c " 98 | export RANK=$i 99 | export NODE_RANK=$i 100 | # vLLM: GPUs 4-7 101 | CUDA_VISIBLE_DEVICES=4,5,6,7 python src/vllm_server_v2.py \ 102 | --model '${MODEL_PATH}' \ 103 | --gpu_memory_utilization 0.85 \ 104 | --enable-prefix-caching true \ 105 | --dtype 'bfloat16' \ 106 | --max_model_len 4096 \ 107 | --tensor_parallel_size ${TP_SIZE} \ 108 | --host '0.0.0.0' \ 109 | --port ${PORT_BASE} > ${log_file} 2>&1 & 110 | 111 | # Training: GPUs 0-3 112 | CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nnodes ${NUM_NODES} --nproc_per_node 4 \ 113 | --node_rank \$NODE_RANK \ 114 | --rdzv_id grpo_run \ 115 | --rdzv_backend c10d \ 116 | --rdzv_endpoint ${HEAD_NODE_IP}:${RDZV_PORT} \ 117 | ${TRAIN_CMD} & 118 | wait 119 | " & 120 | } 121 | 122 | launch_training_node() { 123 | local i=$1 124 | local node=$2 125 | srun --nodes=1 --ntasks=1 -w "$node" bash -c " 126 | export NODE_RANK=$i 127 | CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 torchrun --nnodes ${NUM_NODES} --nproc_per_node 8 \ 128 | --node_rank \$NODE_RANK \ 129 | --rdzv_id grpo_run \ 130 | --rdzv_backend c10d \ 131 | --rdzv_endpoint ${HEAD_NODE_IP}:${RDZV_PORT} \ 132 | ${TRAIN_CMD} 133 | " & 134 | } 135 | 136 | # ---------- Main Launcher ---------- 137 | for i in "${!NODELIST[@]}"; do 138 | if [ $i -lt ${MIXED_NODES} ]; then 139 | launch_mixed_node $i "${NODELIST[$i]}" 140 | else 141 | launch_training_node $i "${NODELIST[$i]}" 142 | fi 143 | done 144 | 145 | wait 146 | -------------------------------------------------------------------------------- /scripts/grpo/train_gh200.sh: -------------------------------------------------------------------------------- 1 | 
#!/bin/bash 2 | 3 | 4 | #SBATCH --job-name=7B_GH200_zero_lr2x_fp8 5 | #SBATCH --time=12:00:00 6 | 7 | #SBATCH --exclude=nid006792,nid007085 8 | 9 | #SBATCH --nodes=8 # 4 nodes, each has 4x GH200 10 | #SBATCH --ntasks=8 # Total tasks equals total nodes 11 | #SBATCH --ntasks-per-node=1 12 | #SBATCH --gpus-per-node=4 13 | #SBATCH --cpus-per-task=288 # fixed for GH200 14 | 15 | 16 | #SBATCH --partition=normal 17 | #SBATCH --output=RL_gh200_%j_%N.out 18 | #SBATCH --mail-user="zychen.uestc@gmail.com" --mail-type=ALL 19 | 20 | 21 | set -x 22 | # ---------- Environment Setup ---------- 23 | export NCCL_ASYNC_ERROR_HANDLING=1 24 | export DEBUG_MODE=True 25 | export WANDB_PROJECT=RL4SGG 26 | 27 | 28 | GPUS_PER_NODE=4 29 | GROUP_SIZE=8 30 | MODEL_PATH="Qwen/Qwen2-VL-7B-Instruct" 31 | DATA_PATH="JosephZ/vg150_train_sgg_prompt" 32 | RUN_NAME="qwen2vl-7b-grpo-2k-lr6e-7-g8-n1-bs32-fp8-gh200" 33 | export OUTPUT_DIR="${SCRATCH}/models/${RUN_NAME}" 34 | mkdir -p "$OUTPUT_DIR" 35 | 36 | export LOG_PATH=${OUTPUT_DIR}/debug.log 37 | 38 | export STRICT_FORMAT=True 39 | 40 | MAX_PIXELS=$((512 * 28 * 28)) 41 | 42 | 43 | MASTER_PORT=29500 44 | 45 | NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST)) 46 | NUM_TRAIN_NODES=${#NODELIST[@]} 47 | TRAIN_NODES_LIST=("${NODELIST[@]:0:$NUM_TRAIN_NODES}") 48 | 49 | # Choose the first training node as the rendezvous head node 50 | ##HEAD_NODE=${TRAIN_NODES_LIST[0]} 51 | ##HEAD_NODE_IP=$(srun --nodes=1 --ntasks=1 -w "$HEAD_NODE" hostname --ip-address) 52 | ##echo "Head Node IP: $HEAD_NODE_IP" 53 | 54 | MASTER_ADDR=$(scontrol show hostnames $SLURM_NODELIST | head -n 1) 55 | echo "MASTER_ADDR: $MASTER_ADDR" 56 | 57 | 58 | 59 | # GH200 has a very high bandwidth between CPU and GPU, we should use it! 60 | # zero2: 61 | # bsz_per_devie=16, OOM; Ok, with CPU offload for optimizer, ~60h with 3x GPUs 62 | # bsz_per_devie=8, 386s for 30 steps, ~60h with 3x GPUs 63 | # bsz_per_devie=16, ~40h with 4x GPUs 64 | # 65 | # batch size: 16*1*4*4 //8=32 66 | TRAIN_CMD="open_r1/grpo.py \ 67 | --task_type sgg \ 68 | --use_fp8 true \ 69 | --output_dir ${OUTPUT_DIR} \ 70 | --model_name_or_path ${MODEL_PATH} \ 71 | --dataset_name ${DATA_PATH} \ 72 | --max_prompt_length 2048 \ 73 | --max_completion_length 1024 \ 74 | --custom_per_device_train_batch_size 8 \ 75 | --deepspeed ./local_scripts/zero2_offload.json \ 76 | --gradient_accumulation_steps 1 \ 77 | --learning_rate 6e-7 \ 78 | --logging_steps 1 \ 79 | --use_vllm true \ 80 | --use_local_vllm true\ 81 | --bf16 true\ 82 | --tf32 true\ 83 | --report_to wandb \ 84 | --gradient_checkpointing true \ 85 | --max_pixels ${MAX_PIXELS} \ 86 | --temperature 1 \ 87 | --top_p 0.9 \ 88 | --top_k 50 \ 89 | --num_train_epochs 1 \ 90 | --run_name ${RUN_NAME} \ 91 | --save_steps 100 \ 92 | --num_generations ${GROUP_SIZE} \ 93 | --num_iterations 1 \ 94 | --beta 0.0 \ 95 | --vllm_max_model_len 4096 \ 96 | --vllm_gpu_memory_utilization 0.2 \ 97 | --ddp_timeout 3600 \ 98 | --save_only_model false" 99 | 100 | 101 | echo "start training with CMD=${TRAIN_CMD} ..." 
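# This variant trains with FP8: --use_fp8 true in TRAIN_CMD is paired with
# --mixed_precision fp8 on the accelerate launcher below. The commented-out torchrun
# block at the bottom is an alternative launch path that skips accelerate. As in
# train_fsdp.sh, --num_processes appears twice in the launcher with the same value,
# which is redundant but harmless.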
102 | 103 | WORLD_SIZE=$(($GPUS_PER_NODE*$NUM_TRAIN_NODES)) 104 | 105 | 106 | LAUNCHER="accelerate launch \ 107 | --multi_gpu \ 108 | --num_machines $NUM_TRAIN_NODES \ 109 | --num_processes $WORLD_SIZE \ 110 | --main_process_ip "$MASTER_ADDR" \ 111 | --main_process_port $MASTER_PORT \ 112 | --num_processes $WORLD_SIZE \ 113 | --machine_rank $SLURM_PROCID \ 114 | --role $SLURMD_NODENAME: \ 115 | --rdzv_conf rdzv_backend=c10d \ 116 | --rdzv_timeout 3600 \ 117 | --max_restarts 0 \ 118 | --tee 3 \ 119 | --mixed_precision fp8 \ 120 | " 121 | 122 | 123 | srun --jobid $SLURM_JOB_ID bash -c "$LAUNCHER $TRAIN_CMD" 124 | 125 | #srun torchrun --nnodes ${NUM_TRAIN_NODES} --nproc_per_node ${GPUS_PER_NODE} \ 126 | # --node_rank ${SLURM_NODEID} \ 127 | # --rdzv_id $RANDOM \ 128 | # --rdzv_backend c10d \ 129 | # --rdzv_endpoint ${MASTER_ADDR}:${MASTER_PORT} \ 130 | # --rdzv_timeout 3600 \ 131 | # ${TRAIN_CMD} 132 | -------------------------------------------------------------------------------- /scripts/grpo/train_gh200_2B.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | #SBATCH --job-name=2B_bs32_gh200_lr6e-7_debug 5 | #SBATCH --time=00:20:00 6 | 7 | #SBATCH --nodes=1 # 2 nodes, each has 4x GH200 8 | #SBATCH --ntasks=1 # Total tasks equals total nodes 9 | #SBATCH --ntasks-per-node=1 10 | #SBATCH --gpus-per-node=4 11 | #SBATCH --cpus-per-task=288 # fixed for GH200 12 | 13 | #SBATCH --partition=debug 14 | #SBATCH --output=RL_gh200_%j_%N.out 15 | #SBATCH --mail-user="zychen.uestc@gmail.com" --mail-type=ALL 16 | 17 | 18 | # ---------- Environment Setup ---------- 19 | export NCCL_ASYNC_ERROR_HANDLING=1 20 | export DEBUG_MODE=True 21 | export WANDB_PROJECT=RL4SGG 22 | 23 | 24 | GPUS_PER_NODE=4 25 | GROUP_SIZE=16 26 | MODEL_PATH="Qwen/Qwen2-VL-2B-Instruct" 27 | DATA_PATH="JosephZ/vg150_train_sgg_prompt" 28 | RUN_NAME="qwen2vl-2b-grpo-debug-n1-sgg-bs16-lr6e-7-gh200" 29 | export OUTPUT_DIR="${SCRATCH}/models/${RUN_NAME}" 30 | mkdir -p "$OUTPUT_DIR" 31 | 32 | MAX_PIXELS=$((512 * 28 * 28)) 33 | 34 | REF_MODEL_NAME=$1 35 | export LOG_PATH=${OUTPUT_DIR}/debug.log 36 | 37 | export STRICT_FORMAT=True 38 | 39 | MASTER_PORT=29500 40 | 41 | NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST)) 42 | NUM_TRAIN_NODES=${#NODELIST[@]} 43 | TRAIN_NODES_LIST=("${NODELIST[@]:0:$NUM_TRAIN_NODES}") 44 | 45 | # Choose the first training node as the rendezvous head node 46 | HEAD_NODE=${TRAIN_NODES_LIST[0]} 47 | HEAD_NODE_IP=$(srun --nodes=1 --ntasks=1 -w "$HEAD_NODE" hostname --ip-address) 48 | echo "Head Node IP: $HEAD_NODE_IP" 49 | 50 | 51 | 52 | # batch size: PER_DEVICE(16) * ACC(2) * GPU (4) * NODE(2) // GROUP_SIZE(8) = 32 53 | TRAIN_CMD="open_r1/grpo.py \ 54 | --output_dir ${OUTPUT_DIR} \ 55 | --task_type sgg \ 56 | --model_name_or_path ${MODEL_PATH} \ 57 | --dataset_name ${DATA_PATH} \ 58 | --max_prompt_length 2048 \ 59 | --max_completion_length 1024 \ 60 | --custom_per_device_train_batch_size 16 \ 61 | --deepspeed ./local_scripts/zero2_offload.json \ 62 | --gradient_accumulation_steps 1 \ 63 | --learning_rate 6e-7 \ 64 | --logging_steps 1 \ 65 | --use_vllm true \ 66 | --use_local_vllm true\ 67 | --bf16 true\ 68 | --tf32 true\ 69 | --report_to wandb \ 70 | --gradient_checkpointing true \ 71 | --max_pixels ${MAX_PIXELS} \ 72 | --temperature 1 \ 73 | --top_p 0.9 \ 74 | --top_k 50 \ 75 | --num_train_epochs 1 \ 76 | --run_name ${RUN_NAME} \ 77 | --save_steps 100 \ 78 | --num_generations ${GROUP_SIZE} \ 79 | --num_iterations 1 \ 80 | --beta 0.00 \ 81 | 
--vllm_max_model_len 4096 \ 82 | --vllm_gpu_memory_utilization 0.2 \ 83 | --save_only_model true \ 84 | --seed 42" 85 | 86 | 87 | echo "start training..." 88 | 89 | srun --nodes=${NUM_TRAIN_NODES} --nodelist="${TRAIN_NODES_LIST}" \ 90 | torchrun --nnodes ${NUM_TRAIN_NODES} --nproc_per_node ${GPUS_PER_NODE} \ 91 | --node_rank ${SLURM_NODEID} \ 92 | --rdzv_id $RANDOM \ 93 | --rdzv_backend c10d \ 94 | --rdzv_endpoint ${HEAD_NODE_IP}:${MASTER_PORT} \ 95 | ${TRAIN_CMD} 96 | -------------------------------------------------------------------------------- /scripts/grpo/train_gh200_2B_SFT.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | #SBATCH --job-name=2B_bs32_gh200_SFT_lr6e-7_psg 5 | #SBATCH --time=12:00:00 6 | 7 | #SBATCH --nodes=4 # 2 nodes, each has 4x GH200 8 | #SBATCH --ntasks=4 # Total tasks equals total nodes 9 | #SBATCH --ntasks-per-node=1 10 | #SBATCH --gpus-per-node=4 11 | #SBATCH --cpus-per-task=288 # fixed for GH200 12 | 13 | #SBATCH --partition=normal 14 | #SBATCH --output=RL_gh200_%j_%N.out 15 | #SBATCH --mail-user="zychen.uestc@gmail.com" --mail-type=ALL 16 | 17 | 18 | # ---------- Environment Setup ---------- 19 | export NCCL_ASYNC_ERROR_HANDLING=1 20 | export DEBUG_MODE=True 21 | export WANDB_PROJECT=RL4SGG 22 | 23 | 24 | GPUS_PER_NODE=4 25 | GROUP_SIZE=8 26 | #MODEL_PATH="Qwen/Qwen2-VL-2B-Instruct" 27 | MODEL_PATH=$1 28 | 29 | #DATA_PATH="JosephZ/vg150_train_sgg_prompt" 30 | DATA_PATH="JosephZ/psg_train_sg" 31 | RUN_NAME="qwen2vl-2b-sft-grpo-g8-n1-bs32-lr6e-7-psg-merged-gh200" 32 | export OUTPUT_DIR="${SCRATCH}/models/${RUN_NAME}" 33 | mkdir -p "$OUTPUT_DIR" 34 | 35 | export LOG_PATH=${OUTPUT_DIR}/debug.log 36 | export STRICT_FORMAT=True 37 | 38 | MAX_PIXELS=$((512 * 28 * 28)) 39 | 40 | 41 | MASTER_PORT=29500 42 | 43 | NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST)) 44 | NUM_TRAIN_NODES=${#NODELIST[@]} 45 | TRAIN_NODES_LIST=("${NODELIST[@]:0:$NUM_TRAIN_NODES}") 46 | 47 | # Choose the first training node as the rendezvous head node 48 | HEAD_NODE=${TRAIN_NODES_LIST[0]} 49 | HEAD_NODE_IP=$(srun --nodes=1 --ntasks=1 -w "$HEAD_NODE" hostname --ip-address) 50 | echo "Head Node IP: $HEAD_NODE_IP" 51 | 52 | 53 | 54 | # batch size: PER_DEVICE(16) * ACC(2) * GPU (4) * NODE(2) // GROUP_SIZE(8) = 32 55 | TRAIN_CMD="open_r1/grpo.py \ 56 | --output_dir ${OUTPUT_DIR} \ 57 | --model_name_or_path ${MODEL_PATH} \ 58 | --dataset_name ${DATA_PATH} \ 59 | --max_prompt_length 2048 \ 60 | --max_completion_length 1024 \ 61 | --custom_per_device_train_batch_size 16 \ 62 | --deepspeed ./local_scripts/zero2_offload.json \ 63 | --gradient_accumulation_steps 1 \ 64 | --learning_rate 6e-7 \ 65 | --logging_steps 1 \ 66 | --use_vllm true \ 67 | --use_local_vllm true\ 68 | --bf16 true\ 69 | --tf32 true\ 70 | --report_to wandb \ 71 | --gradient_checkpointing true \ 72 | --max_pixels ${MAX_PIXELS} \ 73 | --temperature 1 \ 74 | --top_p 0.9 \ 75 | --top_k 50 \ 76 | --max_steps 2000 \ 77 | --run_name ${RUN_NAME} \ 78 | --save_steps 100 \ 79 | --num_generations ${GROUP_SIZE} \ 80 | --num_iterations 1 \ 81 | --beta 0.0\ 82 | --vllm_max_model_len 4096 \ 83 | --vllm_gpu_memory_utilization 0.2 \ 84 | --save_only_model true \ 85 | --seed 42" 86 | 87 | 88 | echo "start training..." 
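# Two notes on the configuration above: (1) --max_steps 2000 caps training at a fixed
# number of optimizer steps and takes precedence over epoch-based stopping in the HF
# Trainer; (2) with this script's actual values the stated effective batch size still
# holds: 16 per device * 4 GPUs * 4 nodes * 1 accumulation step / 8 generations = 32
# prompts per step.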
89 | 90 | srun --nodes=${NUM_TRAIN_NODES} --nodelist="${TRAIN_NODES_LIST}" \ 91 | torchrun --nnodes ${NUM_TRAIN_NODES} --nproc_per_node ${GPUS_PER_NODE} \ 92 | --node_rank ${SLURM_NODEID} \ 93 | --rdzv_id $RANDOM \ 94 | --rdzv_backend c10d \ 95 | --rdzv_endpoint ${HEAD_NODE_IP}:${MASTER_PORT} \ 96 | ${TRAIN_CMD} 97 | -------------------------------------------------------------------------------- /scripts/grpo/train_gh200_2B_close.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | #SBATCH --job-name=2B_bs32_gh200_sft_close_lr2x 5 | #SBATCH --time=12:00:00 6 | 7 | #SBATCH --nodes=4 # 2 nodes, each has 4x GH200 8 | #SBATCH --ntasks=4 # Total tasks equals total nodes 9 | #SBATCH --ntasks-per-node=1 10 | #SBATCH --gpus-per-node=4 11 | #SBATCH --cpus-per-task=288 # fixed for GH200 12 | 13 | #SBATCH --partition=normal 14 | #SBATCH --output=RL_gh200_%j_%N.out 15 | #SBATCH --mail-user="zychen.uestc@gmail.com" --mail-type=ALL 16 | 17 | 18 | set -x 19 | # ---------- Environment Setup ---------- 20 | export NCCL_ASYNC_ERROR_HANDLING=1 21 | export DEBUG_MODE=True 22 | export WANDB_PROJECT=RL4SGG 23 | 24 | 25 | GPUS_PER_NODE=4 26 | GROUP_SIZE=8 27 | MODEL_PATH=$1 #"Qwen/Qwen2-VL-2B-Instruct" 28 | DATA_PATH="JosephZ/vg150_train_sgg_prompt" 29 | RUN_NAME="qwen2vl-2b-sft-close-grpo-g8-n1-bs32-lr6e-7-gh200" 30 | export OUTPUT_DIR="${SCRATCH}/models/${RUN_NAME}" 31 | mkdir -p "$OUTPUT_DIR" 32 | 33 | MAX_PIXELS=$((512 * 28 * 28)) 34 | 35 | export LOG_PATH=${OUTPUT_DIR}/debug.log 36 | export STRICT_FORMAT=True 37 | 38 | MASTER_PORT=29500 39 | 40 | NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST)) 41 | NUM_TRAIN_NODES=${#NODELIST[@]} 42 | TRAIN_NODES_LIST=("${NODELIST[@]:0:$NUM_TRAIN_NODES}") 43 | 44 | # Choose the first training node as the rendezvous head node 45 | HEAD_NODE=${TRAIN_NODES_LIST[0]} 46 | HEAD_NODE_IP=$(srun --nodes=1 --ntasks=1 -w "$HEAD_NODE" hostname --ip-address) 47 | echo "Head Node IP: $HEAD_NODE_IP" 48 | 49 | 50 | 51 | # batch size: PER_DEVICE(16) * ACC(2) * GPU (4) * NODE(2) // GROUP_SIZE(8) = 32 52 | TRAIN_CMD="open_r1/grpo.py \ 53 | --output_dir ${OUTPUT_DIR} \ 54 | --model_name_or_path ${MODEL_PATH} \ 55 | --dataset_name ${DATA_PATH} \ 56 | --max_prompt_length 2048 \ 57 | --max_completion_length 1024 \ 58 | --custom_per_device_train_batch_size 16 \ 59 | --deepspeed ./local_scripts/zero2_offload.json \ 60 | --gradient_accumulation_steps 1 \ 61 | --learning_rate 6e-7 \ 62 | --use_predefined_cats true \ 63 | --logging_steps 1 \ 64 | --use_vllm true \ 65 | --use_local_vllm true\ 66 | --bf16 true\ 67 | --tf32 true\ 68 | --report_to wandb \ 69 | --gradient_checkpointing true \ 70 | --max_pixels ${MAX_PIXELS} \ 71 | --temperature 1 \ 72 | --top_p 0.9 \ 73 | --top_k 50 \ 74 | --num_train_epochs 1 \ 75 | --run_name ${RUN_NAME} \ 76 | --save_steps 100 \ 77 | --num_generations ${GROUP_SIZE} \ 78 | --num_iterations 1 \ 79 | --beta 0.0\ 80 | --vllm_max_model_len 4096 \ 81 | --vllm_gpu_memory_utilization 0.2 \ 82 | --save_only_model true \ 83 | --seed 42" 84 | 85 | 86 | echo "start training..." 
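# DEBUG_MODE, LOG_PATH, and STRICT_FORMAT exported above are consumed by the trainer /
# reward code rather than by this script itself, presumably for per-completion reward
# logging and a stricter output-format check. That is an assumption based on the
# variable names; see open_r1/trainer for the actual behavior.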
87 | 88 | srun --nodes=${NUM_TRAIN_NODES} --nodelist="${TRAIN_NODES_LIST}" \ 89 | torchrun --nnodes ${NUM_TRAIN_NODES} --nproc_per_node ${GPUS_PER_NODE} \ 90 | --node_rank ${SLURM_NODEID} \ 91 | --rdzv_id $RANDOM \ 92 | --rdzv_backend c10d \ 93 | --rdzv_endpoint ${HEAD_NODE_IP}:${MASTER_PORT} \ 94 | ${TRAIN_CMD} 95 | -------------------------------------------------------------------------------- /scripts/grpo/train_gh200_SFT.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | #SBATCH --job-name=7B_GH200_SFT_GRPO_1k_lr2x_psg 5 | #SBATCH --time=12:00:00 6 | 7 | #SBATCH --nodes=8 # 4 nodes, each has 4x GH200 8 | #SBATCH --ntasks=8 # Total tasks equals total nodes 9 | #SBATCH --ntasks-per-node=1 10 | #SBATCH --gpus-per-node=4 11 | #SBATCH --cpus-per-task=288 # fixed for GH200 12 | 13 | #SBATCH --partition=normal 14 | #SBATCH --output=RL_gh200_%j_%N.out 15 | #SBATCH --mail-user="zychen.uestc@gmail.com" --mail-type=ALL 16 | 17 | 18 | set -x 19 | # ---------- Environment Setup ---------- 20 | export NCCL_ASYNC_ERROR_HANDLING=1 21 | export DEBUG_MODE=True 22 | export WANDB_PROJECT=RL4SGG 23 | 24 | 25 | GPUS_PER_NODE=4 26 | GROUP_SIZE=8 27 | MODEL_PATH=$1 28 | 29 | #DATA_PATH="JosephZ/vg150_train_sgg_prompt" 30 | DATA_PATH="JosephZ/psg_train_sg" 31 | 32 | RUN_NAME="qwen2vl-7b-sft-grpo-psg-merged-g8-n1-bs32-1k-lr6e-7-gh200" 33 | export OUTPUT_DIR="${SCRATCH}/models/${RUN_NAME}" 34 | mkdir -p "$OUTPUT_DIR" 35 | 36 | export LOG_PATH=${OUTPUT_DIR}/debug.log 37 | export STRICT_FORMAT=True 38 | 39 | MAX_PIXELS=$((512 * 28 * 28)) 40 | 41 | 42 | MASTER_PORT=29500 43 | 44 | NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST)) 45 | NUM_TRAIN_NODES=${#NODELIST[@]} 46 | TRAIN_NODES_LIST=("${NODELIST[@]:0:$NUM_TRAIN_NODES}") 47 | 48 | #HEAD_NODE=${TRAIN_NODES_LIST[0]} 49 | #HEAD_NODE_IP=$(srun --nodes=1 --ntasks=1 -w "$HEAD_NODE" hostname --ip-address) 50 | #echo "Head Node IP: $HEAD_NODE_IP" 51 | 52 | MASTER_ADDR=$(scontrol show hostnames $SLURM_NODELIST | head -n 1) 53 | echo "MASTER_ADDR: $MASTER_ADDR" 54 | 55 | 56 | 57 | # GH200 has a very high bandwidth between CPU and GPU, we should use it! 58 | # zero2: 59 | # bsz_per_devie=16, OOM; Ok, with CPU offload for optimizer, ~60h with 3x GPUs 60 | # bsz_per_devie=8, 386s for 30 steps, ~60h with 3x GPUs 61 | # bsz_per_devie=16, ~40h with 4x GPUs 62 | # 63 | # batch size: PER DEVICE(16) * ACC(1) * GPU(4) * NODE(4) // GROUP_SIZE(8) = 32 64 | TRAIN_CMD="open_r1/grpo.py \ 65 | --output_dir ${OUTPUT_DIR} \ 66 | --model_name_or_path ${MODEL_PATH} \ 67 | --dataset_name ${DATA_PATH} \ 68 | --max_prompt_length 2048 \ 69 | --max_completion_length 1024 \ 70 | --custom_per_device_train_batch_size 8 \ 71 | --deepspeed ./local_scripts/zero2_offload.json \ 72 | --gradient_accumulation_steps 1 \ 73 | --learning_rate 6e-7 \ 74 | --logging_steps 1 \ 75 | --use_vllm true \ 76 | --use_local_vllm true\ 77 | --bf16 true\ 78 | --tf32 true\ 79 | --report_to wandb \ 80 | --gradient_checkpointing true \ 81 | --max_pixels ${MAX_PIXELS} \ 82 | --temperature 1 \ 83 | --top_p 0.9 \ 84 | --top_k 50 \ 85 | --num_train_epochs 1 \ 86 | --run_name ${RUN_NAME} \ 87 | --save_steps 100 \ 88 | --num_generations ${GROUP_SIZE} \ 89 | --num_iterations 1 \ 90 | --beta 0.0\ 91 | --vllm_max_model_len 4096 \ 92 | --vllm_gpu_memory_utilization 0.2 \ 93 | --save_only_model false" 94 | 95 | 96 | echo "start training with TRAIN_CMD=${TRAIN_CMD} ..." 
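# Dataset selection is controlled by DATA_PATH above: this run uses JosephZ/psg_train_sg
# (PSG); switch back to the commented-out JosephZ/vg150_train_sgg_prompt for VG150, and
# update RUN_NAME / OUTPUT_DIR accordingly so checkpoints do not collide.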
97 | 98 | srun torchrun --nnodes ${NUM_TRAIN_NODES} --nproc_per_node ${GPUS_PER_NODE} \ 99 | --node_rank ${SLURM_NODEID} \ 100 | --rdzv_id $RANDOM \ 101 | --rdzv_backend c10d \ 102 | --rdzv_endpoint ${MASTER_ADDR}:${MASTER_PORT} \ 103 | ${TRAIN_CMD} 104 | -------------------------------------------------------------------------------- /scripts/grpo/train_gh200_close.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | #SBATCH --job-name=7B_GH200_close_lr2x 5 | #SBATCH --time=12:00:00 6 | 7 | #SBATCH --exclude=nid006792 8 | #SBATCH --nodes=8 # 4 nodes, each has 4x GH200 9 | #SBATCH --ntasks=8 # Total tasks equals total nodes 10 | #SBATCH --ntasks-per-node=1 11 | #SBATCH --gpus-per-node=4 12 | #SBATCH --cpus-per-task=288 # fixed for GH200 13 | 14 | #SBATCH --partition=normal 15 | #SBATCH --output=RL_gh200_%j_%N.out 16 | #SBATCH --mail-user="zychen.uestc@gmail.com" --mail-type=ALL 17 | 18 | 19 | set -x 20 | # ---------- Environment Setup ---------- 21 | export NCCL_ASYNC_ERROR_HANDLING=1 22 | export DEBUG_MODE=True 23 | export WANDB_PROJECT=RL4SGG 24 | 25 | 26 | GPUS_PER_NODE=4 27 | GROUP_SIZE=8 28 | MODEL_PATH="Qwen/Qwen2-VL-7B-Instruct" 29 | DATA_PATH="JosephZ/vg150_train_sgg_prompt" 30 | RUN_NAME="qwen2vl-7b-close-grpo-g8-n1-bs32-lr6e-7-gh200" 31 | export OUTPUT_DIR="${SCRATCH}/models/7B/${RUN_NAME}" 32 | mkdir -p "$OUTPUT_DIR" 33 | 34 | export STRICT_FORMAT=True 35 | export LOG_PATH=${OUTPUT_DIR}/debug.log 36 | 37 | MAX_PIXELS=$((512 * 28 * 28)) 38 | 39 | 40 | MASTER_PORT=29500 41 | 42 | NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST)) 43 | NUM_TRAIN_NODES=${#NODELIST[@]} 44 | TRAIN_NODES_LIST=("${NODELIST[@]:0:$NUM_TRAIN_NODES}") 45 | 46 | 47 | MASTER_ADDR=$(scontrol show hostnames $SLURM_NODELIST | head -n 1) 48 | echo "MASTER_ADDR: $MASTER_ADDR" 49 | 50 | 51 | # GH200 has a very high bandwidth between CPU and GPU, we should use it! 52 | # zero2: 53 | # bsz_per_devie=16, OOM; Ok, with CPU offload for optimizer, ~60h with 3x GPUs 54 | # bsz_per_devie=8, 386s for 30 steps, ~60h with 3x GPUs 55 | # bsz_per_devie=16, ~40h with 4x GPUs 56 | # 57 | # batch size: 16*1*4*4 //8=32 58 | TRAIN_CMD="open_r1/grpo.py \ 59 | --output_dir ${OUTPUT_DIR} \ 60 | --model_name_or_path ${MODEL_PATH} \ 61 | --dataset_name ${DATA_PATH} \ 62 | --max_prompt_length 2048 \ 63 | --max_completion_length 1024 \ 64 | --custom_per_device_train_batch_size 8 \ 65 | --deepspeed ./local_scripts/zero2_offload.json \ 66 | --gradient_accumulation_steps 1 \ 67 | --learning_rate 6e-7 \ 68 | --use_predefined_cats true \ 69 | --logging_steps 1 \ 70 | --use_vllm true \ 71 | --use_local_vllm true\ 72 | --bf16 true\ 73 | --tf32 true\ 74 | --report_to wandb \ 75 | --gradient_checkpointing true \ 76 | --max_pixels ${MAX_PIXELS} \ 77 | --temperature 1 \ 78 | --top_p 0.9 \ 79 | --top_k 50 \ 80 | --num_train_epochs 1 \ 81 | --run_name ${RUN_NAME} \ 82 | --save_steps 100 \ 83 | --num_generations ${GROUP_SIZE} \ 84 | --num_iterations 1 \ 85 | --beta 0.0 \ 86 | --vllm_max_model_len 4096 \ 87 | --vllm_gpu_memory_utilization 0.2 \ 88 | --ddp_timeout 3600 \ 89 | --save_only_model false" 90 | 91 | 92 | echo "start training with CMD=${TRAIN_CMD} ..." 
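# Launch topology for the command below: srun starts one torchrun per node, and each
# torchrun spawns GPUS_PER_NODE worker processes. Because $RANDOM and the
# MASTER_ADDR:MASTER_PORT endpoint are expanded once in this batch script, every node
# joins the same c10d rendezvous, and rank assignment is then handled by the rendezvous
# itself.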
93 | 94 | srun torchrun --nnodes ${NUM_TRAIN_NODES} --nproc_per_node ${GPUS_PER_NODE} \ 95 | --node_rank ${SLURM_NODEID} \ 96 | --rdzv_id $RANDOM \ 97 | --rdzv_backend c10d \ 98 | --rdzv_endpoint ${MASTER_ADDR}:${MASTER_PORT} \ 99 | ${TRAIN_CMD} 100 | -------------------------------------------------------------------------------- /scripts/grpo/train_gh200_close_SFT.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | #SBATCH --job-name=7B_GH200_SFT_CLOSE_GRPO_lr2x 5 | #SBATCH --time=12:00:00 6 | 7 | #SBATCH --nodes=8 # 8 nodes, each has 4x GH200 8 | #SBATCH --ntasks=8 # Total tasks equals total nodes 9 | #SBATCH --ntasks-per-node=1 10 | #SBATCH --gpus-per-node=4 11 | #SBATCH --cpus-per-task=288 # fixed for GH200 12 | 13 | #SBATCH --partition=normal 14 | #SBATCH --output=RL_gh200_%j_%N.out 15 | #SBATCH --mail-user="zychen.uestc@gmail.com" --mail-type=ALL 16 | 17 | 18 | set -x 19 | # ---------- Environment Setup ---------- 20 | export NCCL_ASYNC_ERROR_HANDLING=1 21 | export DEBUG_MODE=True 22 | export WANDB_PROJECT=RL4SGG 23 | 24 | 25 | GPUS_PER_NODE=4 26 | GROUP_SIZE=8 27 | MODEL_PATH=$1 28 | 29 | DATA_PATH="JosephZ/vg150_train_sgg_prompt" 30 | RUN_NAME="qwen2vl-7b-sft-grpo-g8-n1-bs32-close-lr6e-7-gh200" 31 | export OUTPUT_DIR="${SCRATCH}/models/7B/${RUN_NAME}" 32 | mkdir -p "$OUTPUT_DIR" 33 | 34 | export LOG_PATH=${OUTPUT_DIR}/debug.log 35 | export STRICT_FORMAT=True 36 | 37 | MAX_PIXELS=$((512 * 28 * 28)) 38 | 39 | 40 | MASTER_PORT=29500 41 | 42 | NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST)) 43 | NUM_TRAIN_NODES=${#NODELIST[@]} 44 | TRAIN_NODES_LIST=("${NODELIST[@]:0:$NUM_TRAIN_NODES}") 45 | 46 | 47 | MASTER_ADDR=$(scontrol show hostnames $SLURM_NODELIST | head -n 1) 48 | echo "MASTER_ADDR: $MASTER_ADDR" 49 | 50 | 51 | 52 | # GH200 has a very high bandwidth between CPU and GPU, we should use it! 53 | # zero2: 54 | # bsz_per_devie=16, OOM; Ok, with CPU offload for optimizer, ~60h with 3x GPUs 55 | # bsz_per_devie=8, 386s for 30 steps, ~60h with 3x GPUs 56 | # bsz_per_devie=16, ~40h with 4x GPUs 57 | # 58 | # batch size: PER DEVICE(16) * ACC(1) * GPU(4) * NODE(4) // GROUP_SIZE(8) = 32 59 | TRAIN_CMD="open_r1/grpo.py \ 60 | --output_dir ${OUTPUT_DIR} \ 61 | --model_name_or_path ${MODEL_PATH} \ 62 | --dataset_name ${DATA_PATH} \ 63 | --max_prompt_length 2048 \ 64 | --max_completion_length 1024 \ 65 | --custom_per_device_train_batch_size 8 \ 66 | --deepspeed ./local_scripts/zero2_offload.json \ 67 | --gradient_accumulation_steps 1 \ 68 | --learning_rate 6e-7 \ 69 | --logging_steps 1 \ 70 | --use_vllm true \ 71 | --use_local_vllm true\ 72 | --bf16 true\ 73 | --tf32 true\ 74 | --report_to wandb \ 75 | --gradient_checkpointing true \ 76 | --max_pixels ${MAX_PIXELS} \ 77 | --temperature 1 \ 78 | --top_p 0.9 \ 79 | --top_k 50 \ 80 | --num_train_epochs 1 \ 81 | --run_name ${RUN_NAME} \ 82 | --save_steps 100 \ 83 | --num_generations ${GROUP_SIZE} \ 84 | --num_iterations 1 \ 85 | --beta 0.0\ 86 | --vllm_max_model_len 4096 \ 87 | --vllm_gpu_memory_utilization 0.2 \ 88 | --use_predefined_cats true \ 89 | --ddp_timeout 3600 \ 90 | --save_only_model false" 91 | 92 | 93 | echo "start training with TRAIN_CMD=${TRAIN_CMD} ..." 
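# --save_only_model false keeps optimizer and scheduler state in checkpoints so the run
# can be resumed, at the cost of larger checkpoints; --ddp_timeout 3600 widens the
# process-group timeout, e.g. to tolerate long rollout/generation phases without NCCL
# watchdog aborts.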
94 | 95 | srun torchrun --nnodes ${NUM_TRAIN_NODES} --nproc_per_node ${GPUS_PER_NODE} \ 96 | --node_rank ${SLURM_NODEID} \ 97 | --rdzv_id $RANDOM \ 98 | --rdzv_backend c10d \ 99 | --rdzv_endpoint ${MASTER_ADDR}:${MASTER_PORT} \ 100 | ${TRAIN_CMD} 101 | -------------------------------------------------------------------------------- /scripts/grpo/train_zero3.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | 5 | #SBATCH --job-name=GRPO_train 6 | #SBATCH --time=24:00:00 7 | #SBATCH --nodes=16 # each node has 8x GPUs, 4x for training, 4x for vLLM inference 8 | #SBATCH --ntasks=16 # Total tasks equals total nodes 9 | #SBATCH --ntasks-per-node=1 10 | #SBATCH --gpus-per-node=rtx_4090:8 11 | #SBATCH --cpus-per-task=16 12 | #SBATCH --mem-per-cpu=16000M 13 | #SBATCH --output=RL_%j_%N.out 14 | #SBATCH --mail-user="zychen.uestc@gmail.com" --mail-type=ALL 15 | 16 | 17 | # force crashing on nccl issues like hanging broadcast 18 | export NCCL_ASYNC_ERROR_HANDLING=1 19 | # export NCCL_DEBUG=INFO 20 | # export NCCL_DEBUG_SUBSYS=COLL 21 | # export NCCL_SOCKET_NTHREADS=1 22 | # export NCCL_NSOCKS_PERTHREAD=1 23 | # export CUDA_LAUNCH_BLOCKING=1 24 | 25 | # wait for vLLM servers 26 | #sleep 60 27 | 28 | # Read IPs from file and join them with commas 29 | #ip_str=$(paste -sd, ip_list.txt) 30 | #echo "vLLM servers: $ip_str" 31 | 32 | FILE="ip_port_list.txt" 33 | 34 | SERVER_IP="" 35 | SERVER_PORT="" 36 | 37 | while IFS=: read -r ip port; do 38 | SERVER_IP+="${ip}," 39 | SERVER_PORT+="${port}," 40 | done < "$FILE" 41 | 42 | # Remove trailing commas 43 | SERVER_IP="${SERVER_IP%,}" 44 | SERVER_PORT="${SERVER_PORT%,}" 45 | 46 | echo "SERVER_IP=$SERVER_IP" 47 | echo "SERVER_PORT=$SERVER_PORT" 48 | 49 | 50 | # Define node counts 51 | GPUS_PER_NODE=8 52 | 53 | # Get the list of allocated nodes 54 | NODELIST=($(scontrol show hostnames $SLURM_JOB_NODELIST)) 55 | NUM_TRAIN_NODES=${#NODELIST[@]} 56 | 57 | # Assign training nodes (first NUM_TRAIN_NODES nodes) 58 | TRAIN_NODES=("${NODELIST[@]:0:$NUM_TRAIN_NODES}") 59 | 60 | # Choose the first training node as the rendezvous head node 61 | HEAD_NODE=${TRAIN_NODES[0]} 62 | HEAD_NODE_IP=$(srun --nodes=1 --ntasks=1 -w "$HEAD_NODE" hostname --ip-address) 63 | echo "Head Node IP: $HEAD_NODE_IP" 64 | 65 | echo "environment: $(env | grep NCCL)" 66 | 67 | 68 | # Create a comma-separated list of training nodes for srun 69 | TRAIN_NODES_LIST=$(IFS=, ; echo "${TRAIN_NODES[*]}") 70 | 71 | # Define HOST and PORT for the vLLM server 72 | PORT_A=8888 73 | 74 | 75 | export DEBUG_MODE=True 76 | export WANDB_PROJECT=RL4SGG 77 | 78 | export DATA_PATH="JosephZ/vg150_train_sgg_prompt" 79 | export MODEL_PATH="Qwen/Qwen2-VL-7B-Instruct" 80 | 81 | export NODE_RANK=${SLURM_NODEID} # Provided by SLURM 82 | 83 | MAX_PIXELS=$((512 * 28 * 28)) 84 | 85 | # Launch distributed training on the training nodes using 8 GPUs per node 86 | srun --nodes=${NUM_TRAIN_NODES} --nodelist="${TRAIN_NODES_LIST}" \ 87 | torchrun --nnodes ${NUM_TRAIN_NODES} --nproc_per_node ${GPUS_PER_NODE} \ 88 | --node_rank $NODE_RANK \ 89 | --rdzv_id $RANDOM \ 90 | --rdzv_backend c10d \ 91 | --rdzv_endpoint ${HEAD_NODE_IP}:29500 \ 92 | open_r1/grpo.py \ 93 | --output_dir models/qwen2vl-nokl-n1-g8 \ 94 | --model_name_or_path ${MODEL_PATH} \ 95 | --dataset_name $DATA_PATH \ 96 | --deepspeed ./local_scripts/zero3.json \ 97 | --max_prompt_length 2048 \ 98 | --max_completion_length 1024 \ 99 | --per_device_train_batch_size 1 \ 100 | --gradient_accumulation_steps 1 \ 101 | 
--logging_steps 1 \ 102 | --use_vllm true \ 103 | --vllm_server_host ${SERVER_IP} \ 104 | --vllm_server_port ${SERVER_PORT} \ 105 | --vllm_server_timeout 600 \ 106 | --bf16 \ 107 | --report_to wandb \ 108 | --gradient_checkpointing true \ 109 | --max_pixels ${MAX_PIXELS} \ 110 | --temperature 0.3 \ 111 | --top_p 0.001 \ 112 | --top_k 1 \ 113 | --num_train_epochs 1 \ 114 | --run_name Qwen2VL-7B-GRPO-nokl-n1-G8 \ 115 | --save_steps 100 \ 116 | --num_generations 8 \ 117 | --num_iterations 1 \ 118 | --beta 0.0 119 | -------------------------------------------------------------------------------- /scripts/inference/run_sgg_inference.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | 5 | export GPUS_PER_NODE=4 6 | 7 | 8 | DATASET=$1 9 | MODEL_NAME=$2 10 | OUTPUT_DIR=$3 11 | USE_CATS=$4 # true/false 12 | PROMPT_TYPE=$5 # true/false 13 | 14 | BATCH_SIZE=${6:-8} 15 | 16 | echo "MODEL_NAME: $MODEL_NAME, OUTPUT_DIR: $OUTPUT_DIR" 17 | echo "USE_CATS: $USE_CATS, PROMPT_TYPE: $PROMPT_TYPE" 18 | 19 | ARGS="--dataset $DATASET --model $MODEL_NAME --output_dir $OUTPUT_DIR --max_model_len 4096 --batch_size $BATCH_SIZE" 20 | 21 | 22 | if [ "$PROMPT_TYPE" == "true" ]; then 23 | ARGS="$ARGS --use_think_system_prompt" 24 | fi 25 | 26 | if [ "$USE_CATS" == "true" ]; then 27 | ARGS="$ARGS --use_predefined_cats" 28 | fi 29 | 30 | echo "ARGS:$ARGS" 31 | 32 | torchrun --nnodes 1 \ 33 | --nproc_per_node $GPUS_PER_NODE \ 34 | --node_rank 0 \ 35 | src/sgg_inference_vllm.py -- $ARGS 36 | -------------------------------------------------------------------------------- /scripts/sft/2B_sgg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | 5 | #SBATCH --job-name=SFT_2B_psg 6 | #SBATCH --time=12:00:00 7 | 8 | # 4x A100 9 | 10 | #SBATCH --nodes=8 11 | #SBATCH --ntasks-per-node=1 12 | #SBATCH --gpus-per-node=4 13 | #SBATCH --cpus-per-task=288 14 | 15 | #SBATCH --mail-user="zychen.uestc@gmail.com" 16 | #SBATCH --mail-type=ALL 17 | #SBATCH --output=SFT-2B_%j_%N.out 18 | 19 | # Get node list and determine head node 20 | nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) ) 21 | head_node=${nodes[0]} 22 | head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) 23 | 24 | 25 | #DATASET=JosephZ/vg150_train_sgg_prompt 26 | DATASET=JosephZ/psg_train_sg 27 | 28 | 29 | echo "Head Node IP: $head_node_ip" 30 | 31 | # Set NODE_RANK from SLURM environment variable 32 | export NODE_RANK=${SLURM_NODEID} 33 | 34 | export GPUS_PER_NODE=4 35 | 36 | export WANDB_PROJECT=RL4SGG 37 | 38 | RUN_NAME="qwen2vl-2b-sft-open-psg-bs128-e6-merged-gh200" 39 | export OUTPUT_DIR="${SCRATCH}/models/${RUN_NAME}" 40 | mkdir -p "$OUTPUT_DIR" 41 | 42 | 43 | # batch size=4 * 2 * 16 = 128 44 | srun torchrun --nnodes ${SLURM_NNODES} \ 45 | --nproc_per_node $GPUS_PER_NODE \ 46 | --node_rank $NODE_RANK \ 47 | --rdzv_id $RANDOM \ 48 | --rdzv_backend c10d \ 49 | --rdzv_endpoint ${head_node_ip}:29500 \ 50 | src/sft_sgg.py \ 51 | --model_name_or_path Qwen/Qwen2-VL-2B-Instruct \ 52 | --dataset_name $DATASET \ 53 | --learning_rate 1e-5 \ 54 | --per_device_train_batch_size 4\ 55 | --gradient_accumulation_steps 1\ 56 | --warmup_ratio 0.05 \ 57 | --max_grad_norm 0.3 \ 58 | --logging_steps 1 \ 59 | --bf16 true\ 60 | --tf32 true\ 61 | --report_to wandb \ 62 | --attn_implementation flash_attention_2 \ 63 | --num_train_epochs 6 \ 64 | --run_name $RUN_NAME \ 65 | --save_steps 500 \ 66 | --save_only_model true \ 67 | --torch_dtype 
bfloat16 \ 68 | --fsdp "full_shard auto_wrap" \ 69 | --fsdp_config local_scripts/fsdp_config.json \ 70 | --output_dir $OUTPUT_DIR \ 71 | --seed 42 72 | -------------------------------------------------------------------------------- /scripts/sft/2B_sgg_predefined.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | 5 | #SBATCH --job-name=SFT_2B_close 6 | #SBATCH --time=24:00:00 7 | 8 | # 4x A100 9 | 10 | #SBATCH --nodes=1 11 | #SBATCH --ntasks-per-node=1 12 | #SBATCH --gpus-per-node=4 13 | #SBATCH --cpus-per-task=128 14 | 15 | #SBATCH --mail-user="zychen.uestc@gmail.com" 16 | #SBATCH --mail-type=ALL 17 | #SBATCH --output=SFT-2B_%j_%N.out 18 | 19 | # Get node list and determine head node 20 | nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) ) 21 | head_node=${nodes[0]} 22 | head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) 23 | 24 | echo "Head Node IP: $head_node_ip" 25 | 26 | # Set NODE_RANK from SLURM environment variable 27 | export NODE_RANK=${SLURM_NODEID} 28 | 29 | export GPUS_PER_NODE=4 30 | 31 | export WANDB_PROJECT=RL4SGG 32 | 33 | 34 | # batch size=4 * 2 * 16 = 128 35 | srun torchrun --nnodes ${SLURM_NNODES} \ 36 | --nproc_per_node $GPUS_PER_NODE \ 37 | --node_rank $NODE_RANK \ 38 | --rdzv_id $RANDOM \ 39 | --rdzv_backend c10d \ 40 | --rdzv_endpoint ${head_node_ip}:29500 \ 41 | src/sft_sgg.py \ 42 | --model_name_or_path Qwen/Qwen2-VL-2B-Instruct \ 43 | --dataset_name JosephZ/vg150_train_sgg_prompt \ 44 | --learning_rate 1e-5 \ 45 | --per_device_train_batch_size 4\ 46 | --gradient_accumulation_steps 8\ 47 | --warmup_ratio 0.05 \ 48 | --max_grad_norm 0.3 \ 49 | --logging_steps 1 \ 50 | --bf16 true\ 51 | --tf32 true\ 52 | --report_to wandb \ 53 | --attn_implementation flash_attention_2 \ 54 | --num_train_epochs 3 \ 55 | --run_name Qwen2-VL-2B_vg150_sgg_b128_predefined_e3 \ 56 | --save_steps 100 \ 57 | --save_only_model true \ 58 | --torch_dtype bfloat16 \ 59 | --fsdp "full_shard auto_wrap" \ 60 | --fsdp_config local_scripts/fsdp_config.json \ 61 | --use_predefined_cats true \ 62 | --output_dir models/qwen2vl-2b-sft-vg150-b128-predefined-e3 \ 63 | --seed 42 64 | -------------------------------------------------------------------------------- /scripts/sft/7B_sgg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | 5 | #SBATCH --job-name=SFT_7B_psg 6 | #SBATCH --time=12:00:00 7 | 8 | # 4x A100 9 | 10 | #SBATCH --nodes=8 11 | #SBATCH --ntasks=8 12 | #SBATCH --ntasks-per-node=1 13 | #SBATCH --gpus-per-node=4 14 | #SBATCH --cpus-per-task=288 15 | 16 | #SBATCH --mail-user="zychen.uestc@gmail.com" 17 | #SBATCH --mail-type=ALL 18 | #SBATCH --output=SFT-7B_%j_%N.out 19 | 20 | 21 | set -x 22 | MASTER_ADDR=$(scontrol show hostnames $SLURM_NODELIST | head -n 1) 23 | echo "MASTER_ADDR:$MASTER_ADDR" 24 | 25 | 26 | # Set NODE_RANK from SLURM environment variable 27 | export NODE_RANK=${SLURM_NODEID} 28 | 29 | export GPUS_PER_NODE=4 30 | 31 | export WANDB_PROJECT=RL4SGG 32 | 33 | DATASET=JosephZ/psg_train_sg 34 | 35 | RUN_NAME="qwen2vl-7b-sft-open-psg-bs128-e3-merged-gh200" 36 | export OUTPUT_DIR="${SCRATCH}/models/${RUN_NAME}" 37 | mkdir -p "$OUTPUT_DIR" 38 | 39 | # batch size=4 * 2 * 16 = 128 40 | srun torchrun --nnodes ${SLURM_NNODES} \ 41 | --nproc_per_node $GPUS_PER_NODE \ 42 | --node_rank $NODE_RANK \ 43 | --rdzv_id $RANDOM \ 44 | --rdzv_backend c10d \ 45 | --rdzv_endpoint ${MASTER_ADDR}:29500 \ 46 | src/sft_sgg.py \ 47 | 
--model_name_or_path Qwen/Qwen2-VL-7B-Instruct \ 48 | --dataset_name $DATASET \ 49 | --learning_rate 1e-5 \ 50 | --per_device_train_batch_size 4 \ 51 | --gradient_accumulation_steps 1 \ 52 | --warmup_ratio 0.05 \ 53 | --max_grad_norm 0.3 \ 54 | --logging_steps 1 \ 55 | --bf16 true\ 56 | --tf32 true\ 57 | --report_to wandb \ 58 | --attn_implementation flash_attention_2 \ 59 | --num_train_epochs 3 \ 60 | --run_name $RUN_NAME \ 61 | --save_steps 100 \ 62 | --save_only_model true \ 63 | --torch_dtype bfloat16 \ 64 | --fsdp "full_shard auto_wrap" \ 65 | --fsdp_config local_scripts/fsdp_config.json \ 66 | --output_dir $OUTPUT_DIR \ 67 | --seed 42 68 | 69 | -------------------------------------------------------------------------------- /scripts/sft/7B_sgg_lora.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | 5 | #SBATCH --job-name=SFT_7B_lora 6 | #SBATCH --time=12:00:00 7 | 8 | # 4x A100 9 | 10 | #SBATCH --nodes=4 11 | #SBATCH --ntasks=4 12 | #SBATCH --ntasks-per-node=1 13 | #SBATCH --gpus-per-node=4 14 | #SBATCH --cpus-per-task=288 15 | 16 | #SBATCH --mail-user="zychen.uestc@gmail.com" 17 | #SBATCH --mail-type=ALL 18 | #SBATCH --output=SFT-7B_%j_%N.out 19 | 20 | 21 | set -x 22 | MASTER_ADDR=$(scontrol show hostnames $SLURM_NODELIST | head -n 1) 23 | echo "MASTER_ADDR:$MASTER_ADDR" 24 | 25 | 26 | # Set NODE_RANK from SLURM environment variable 27 | export NODE_RANK=${SLURM_NODEID} 28 | 29 | export GPUS_PER_NODE=4 30 | 31 | export WANDB_PROJECT=RL4SGG 32 | 33 | DATASET=JosephZ/vg150_train_sgg_prompt 34 | 35 | RUN_NAME="qwen2vl-7b-sft-open-vg150-lora-bs128-gh200" 36 | export OUTPUT_DIR="${SCRATCH}/models/${RUN_NAME}" 37 | mkdir -p "$OUTPUT_DIR" 38 | 39 | # batch size=4 * 2 * 16 = 128 40 | srun torchrun --nnodes ${SLURM_NNODES} \ 41 | --nproc_per_node $GPUS_PER_NODE \ 42 | --node_rank $NODE_RANK \ 43 | --rdzv_id $RANDOM \ 44 | --rdzv_backend c10d \ 45 | --rdzv_endpoint ${MASTER_ADDR}:29500 \ 46 | src/sft_sgg.py \ 47 | --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \ 48 | --dataset_name $DATASET \ 49 | --learning_rate 1e-5 \ 50 | --per_device_train_batch_size 8 \ 51 | --gradient_accumulation_steps 1 \ 52 | --warmup_ratio 0.05 \ 53 | --max_grad_norm 0.3 \ 54 | --logging_steps 1 \ 55 | --bf16 true\ 56 | --tf32 true\ 57 | --report_to wandb \ 58 | --attn_implementation flash_attention_2 \ 59 | --num_train_epochs 3 \ 60 | --run_name $RUN_NAME \ 61 | --save_steps 100 \ 62 | --save_only_model true \ 63 | --torch_dtype bfloat16 \ 64 | --lora_r 16 \ 65 | --lora_alpha 32 \ 66 | --lora_dropout 0.05 \ 67 | --fsdp "full_shard auto_wrap" \ 68 | --fsdp_config local_scripts/fsdp_config.json \ 69 | --output_dir $OUTPUT_DIR \ 70 | --seed 42 71 | 72 | -------------------------------------------------------------------------------- /scripts/sft/7B_sgg_predefined.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | 5 | #SBATCH --job-name=SFT_7B_close 6 | #SBATCH --time=24:00:00 7 | 8 | # 4x A100 9 | 10 | #SBATCH --nodes=1 11 | #SBATCH --ntasks-per-node=1 12 | #SBATCH --gpus-per-node=4 13 | #SBATCH --cpus-per-task=128 14 | 15 | #SBATCH --mail-user="zychen.uestc@gmail.com" 16 | #SBATCH --mail-type=ALL 17 | #SBATCH --output=SFT-7B-close_%j_%N.out 18 | 19 | # Get node list and determine head node 20 | nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) ) 21 | head_node=${nodes[0]} 22 | head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) 23 | 24 | echo "Head Node IP: 
$head_node_ip" 25 | 26 | # Set NODE_RANK from SLURM environment variable 27 | export NODE_RANK=${SLURM_NODEID} 28 | 29 | export GPUS_PER_NODE=4 30 | 31 | export WANDB_PROJECT=RL4SGG 32 | 33 | 34 | # batch size=4 * 2 * 16 = 128 35 | srun torchrun --nnodes ${SLURM_NNODES} \ 36 | --nproc_per_node $GPUS_PER_NODE \ 37 | --node_rank $NODE_RANK \ 38 | --rdzv_id $RANDOM \ 39 | --rdzv_backend c10d \ 40 | --rdzv_endpoint ${head_node_ip}:29500 \ 41 | src/sft_sgg.py \ 42 | --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \ 43 | --dataset_name JosephZ/vg150_train_sgg_prompt \ 44 | --learning_rate 1e-5 \ 45 | --per_device_train_batch_size 2 \ 46 | --gradient_accumulation_steps 16 \ 47 | --warmup_ratio 0.05 \ 48 | --max_grad_norm 0.3 \ 49 | --logging_steps 1 \ 50 | --bf16 true\ 51 | --tf32 true\ 52 | --report_to wandb \ 53 | --attn_implementation flash_attention_2 \ 54 | --num_train_epochs 3 \ 55 | --run_name Qwen2-VL-7B_vg150_sgg_b128_predefined_e3 \ 56 | --save_steps 100 \ 57 | --save_only_model true \ 58 | --torch_dtype bfloat16 \ 59 | --fsdp "full_shard auto_wrap" \ 60 | --fsdp_config local_scripts/fsdp_config.json \ 61 | --use_predefined_cats true \ 62 | --output_dir models/qwen2vl-7b-sft-vg150-b128-predefined-e3 \ 63 | --seed 42 64 | 65 | -------------------------------------------------------------------------------- /scripts/sft_local/2B_sgg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | 5 | 6 | export TORCH_DISTRIBUTED_DEBUG=INFO 7 | export NCCL_DEBUG=INFO 8 | 9 | 10 | export GPUS_PER_NODE=4 11 | export WANDB_PROJECT=RL4SGG 12 | 13 | 14 | # batch size=4 * 2 * 16 = 128 15 | torchrun --nnodes 1 \ 16 | --nproc_per_node $GPUS_PER_NODE \ 17 | --node_rank 0 \ 18 | src/sft_sgg.py \ 19 | --model_name_or_path Qwen/Qwen2-VL-2B-Instruct \ 20 | --dataset_name JosephZ/vg150_train_sgg_prompt \ 21 | --learning_rate 1e-5 \ 22 | --per_device_train_batch_size 4\ 23 | --gradient_accumulation_steps 8\ 24 | --warmup_ratio 0.05 \ 25 | --max_grad_norm 0.3 \ 26 | --logging_steps 1 \ 27 | --bf16 true\ 28 | --tf32 true\ 29 | --report_to wandb \ 30 | --attn_implementation flash_attention_2 \ 31 | --num_train_epochs 3 \ 32 | --run_name Qwen2-VL-2B_vg150_sgg_b128_open_e3 \ 33 | --save_steps 100 \ 34 | --save_only_model true \ 35 | --torch_dtype bfloat16 \ 36 | --fsdp "full_shard auto_wrap" \ 37 | --fsdp_config local_scripts/fsdp_config.json \ 38 | --output_dir models/qwen2vl-2b-sft-vg150-b128-open-e3 \ 39 | --seed 42 40 | -------------------------------------------------------------------------------- /scripts/sft_local/2B_sgg_predefined.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | 5 | 6 | export GPUS_PER_NODE=4 7 | 8 | export WANDB_PROJECT=RL4SGG 9 | 10 | 11 | # batch size=4 * 2 * 16 = 128 12 | torchrun --nnodes 1 \ 13 | --nproc_per_node $GPUS_PER_NODE \ 14 | --node_rank 0 \ 15 | src/sft_sgg.py \ 16 | --model_name_or_path Qwen/Qwen2-VL-2B-Instruct \ 17 | --dataset_name JosephZ/vg150_train_sgg_prompt \ 18 | --learning_rate 1e-5 \ 19 | --per_device_train_batch_size 4\ 20 | --gradient_accumulation_steps 8\ 21 | --warmup_ratio 0.05 \ 22 | --max_grad_norm 0.3 \ 23 | --logging_steps 1 \ 24 | --bf16 true\ 25 | --tf32 true\ 26 | --report_to wandb \ 27 | --attn_implementation flash_attention_2 \ 28 | --num_train_epochs 3 \ 29 | --run_name Qwen2-VL-2B_vg150_sgg_b128_predefined_e3 \ 30 | --save_steps 100 \ 31 | --save_only_model true \ 32 | --torch_dtype bfloat16 \ 33 | --fsdp "full_shard auto_wrap" \ 34 
| --fsdp_config local_scripts/fsdp_config.json \ 35 | --use_predefined_cats true \ 36 | --output_dir models/qwen2vl-2b-sft-vg150-b128-predefined-e3 \ 37 | --seed 42 38 | -------------------------------------------------------------------------------- /scripts/sft_local/7B_sgg.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | 5 | 6 | 7 | export GPUS_PER_NODE=4 8 | 9 | export WANDB_PROJECT=RL4SGG 10 | 11 | 12 | # batch size=4 * 2 * 16 = 128 13 | torchrun --nnodes 1 \ 14 | --nproc_per_node $GPUS_PER_NODE \ 15 | --node_rank 0 \ 16 | src/sft_sgg.py \ 17 | --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \ 18 | --dataset_name JosephZ/vg150_train_sgg_prompt \ 19 | --learning_rate 1e-5 \ 20 | --per_device_train_batch_size 2 \ 21 | --gradient_accumulation_steps 16 \ 22 | --warmup_ratio 0.05 \ 23 | --max_grad_norm 0.3 \ 24 | --logging_steps 1 \ 25 | --bf16 true\ 26 | --tf32 true\ 27 | --report_to wandb \ 28 | --attn_implementation flash_attention_2 \ 29 | --num_train_epochs 3 \ 30 | --run_name Qwen2-VL-7B_vg150_sgg_b128_open_e3 \ 31 | --save_steps 100 \ 32 | --save_only_model true \ 33 | --torch_dtype bfloat16 \ 34 | --fsdp "full_shard auto_wrap" \ 35 | --fsdp_config local_scripts/fsdp_config.json \ 36 | --output_dir models/qwen2vl-7b-sft-vg150-b128-open-e3 \ 37 | --seed 42 38 | 39 | -------------------------------------------------------------------------------- /scripts/sft_local/7B_sgg_predefined.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | 4 | 5 | 6 | export GPUS_PER_NODE=4 7 | 8 | export WANDB_PROJECT=RL4SGG 9 | 10 | 11 | # batch size=4 * 2 * 16 = 128 12 | srun torchrun --nnodes 1 \ 13 | --nproc_per_node $GPUS_PER_NODE \ 14 | --node_rank 0 \ 15 | src/sft_sgg.py \ 16 | --model_name_or_path Qwen/Qwen2-VL-7B-Instruct \ 17 | --dataset_name JosephZ/vg150_train_sgg_prompt \ 18 | --learning_rate 1e-5 \ 19 | --per_device_train_batch_size 2 \ 20 | --gradient_accumulation_steps 16 \ 21 | --warmup_ratio 0.05 \ 22 | --max_grad_norm 0.3 \ 23 | --logging_steps 1 \ 24 | --bf16 true\ 25 | --tf32 true\ 26 | --report_to wandb \ 27 | --attn_implementation flash_attention_2 \ 28 | --num_train_epochs 3 \ 29 | --run_name Qwen2-VL-7B_vg150_sgg_b128_predefined_e3 \ 30 | --save_steps 100 \ 31 | --save_only_model true \ 32 | --torch_dtype bfloat16 \ 33 | --fsdp "full_shard auto_wrap" \ 34 | --fsdp_config local_scripts/fsdp_config.json \ 35 | --use_predefined_cats true \ 36 | --output_dir models/qwen2vl-7b-sft-vg150-b128-predefined-e3 \ 37 | --seed 42 38 | 39 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name="open_r1", 5 | version="0.1", 6 | packages=find_packages(), 7 | ) 8 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gpt4vision/R1-SGG/e4de64d4c4c97edec648021d012198b21a9b1864/src/__init__.py -------------------------------------------------------------------------------- /src/psg_categories.json: -------------------------------------------------------------------------------- 1 | {"thing_classes": ["person", "bicycle", "car", "motorcycle", "airplane", "bus", "train", "truck", "boat", "traffic light", "fire hydrant", "stop sign", 
"parking meter", "bench", "bird", "cat", "dog", "horse", "sheep", "cow", "elephant", "bear", "zebra", "giraffe", "backpack", "umbrella", "handbag", "tie", "suitcase", "frisbee", "skis", "snowboard", "sports ball", "kite", "baseball bat", "baseball glove", "skateboard", "surfboard", "tennis racket", "bottle", "wine glass", "cup", "fork", "knife", "spoon", "bowl", "banana", "apple", "sandwich", "orange", "broccoli", "carrot", "hot dog", "pizza", "donut", "cake", "chair", "couch", "potted plant", "bed", "dining table", "toilet", "tv", "laptop", "mouse", "remote", "keyboard", "cell phone", "microwave", "oven", "toaster", "sink", "refrigerator", "book", "clock", "vase", "scissors", "teddy bear", "hair drier", "toothbrush"], "stuff_classes": ["banner", "blanket", "bridge", "cardboard", "counter", "curtain", "door-stuff", "floor-wood", "flower", "fruit", "gravel", "house", "light", "mirror-stuff", "net", "pillow", "platform", "playingfield", "railroad", "river", "road", "roof", "sand", "sea", "shelf", "snow", "stairs", "tent", "towel", "wall-brick", "wall-stone", "wall-tile", "wall-wood", "water-other", "window-blind", "window-other", "tree-merged", "fence-merged", "ceiling-merged", "sky-other-merged", "cabinet-merged", "table-merged", "floor-other-merged", "pavement-merged", "mountain-merged", "grass-merged", "dirt-merged", "paper-merged", "food-other-merged", "building-other-merged", "rock-merged", "wall-other-merged", "rug-merged"], "predicate_classes": ["over", "in front of", "beside", "on", "in", "attached to", "hanging from", "on back of", "falling off", "going down", "painted on", "walking on", "running on", "crossing", "standing on", "lying on", "sitting on", "flying over", "jumping over", "jumping from", "wearing", "holding", "carrying", "looking at", "guiding", "kissing", "eating", "drinking", "feeding", "biting", "catching", "picking", "playing with", "chasing", "climbing", "cleaning", "playing", "touching", "pushing", "pulling", "opening", "cooking", "talking to", "throwing", "slicing", "driving", "riding", "parked on", "driving on", "about to hit", "kicking", "swinging", "entering", "exiting", "enclosing", "leaning on"]} -------------------------------------------------------------------------------- /src/sgg_inference_vllm.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import re 4 | import torch 5 | import glob 6 | import argparse 7 | from datasets import load_dataset 8 | from transformers import AutoProcessor 9 | from accelerate import Accelerator 10 | from torch.utils.data import DataLoader 11 | from tqdm import tqdm 12 | from transformers import Qwen2VLForConditionalGeneration, GenerationConfig 13 | from qwen_vl_utils import process_vision_info 14 | 15 | import numpy as np 16 | import random 17 | from PIL import Image, ImageDraw 18 | 19 | from transformers import Qwen2_5_VLForConditionalGeneration 20 | 21 | from vllm import LLM, SamplingParams 22 | from huggingface_hub import snapshot_download 23 | 24 | os.environ["NCCL_SOCKET_TIMEOUT"] = "3600000" # 1 hours 25 | os.environ["NCCL_BLOCKING_WAIT"] = "1" 26 | 27 | 28 | from src.vg_synonyms import VG150_OBJ_CATEGORIES, VG150_PREDICATES 29 | 30 | from transformers import AutoProcessor, LlavaForConditionalGeneration 31 | from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration 32 | 33 | 34 | 35 | from open_r1.trainer.utils.misc import encode_image_to_base64, is_pil_image 36 | 37 | SYSTEM_PROMPT = ( 38 | "A conversation between User and 
Assistant. The user asks a question, and the Assistant solves it. The assistant " 39 | "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning " 40 | "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., " 41 | "<think> reasoning process here </think> <answer> answer here </answer>" 42 | ) 43 | 44 | from open_r1.trainer.utils.prompt_gallery import ( 45 | PROMPT_SG, 46 | PROMPT_CLOSE_PSG, 47 | PROMPT_CLOSE_VG150, 48 | VG150_BASE_OBJ_CATEGORIES, 49 | VG150_BASE_PREDICATE, 50 | format_prompt_close_sg, 51 | VG150_OBJ_CATEGORIES, 52 | VG150_PREDICATES 53 | ) 54 | 55 | 56 | def get_model(name, device_map="auto", max_model_len=4096): 57 | is_qwen2vl = 'qwen2vl' in name.lower() or 'qwen2-vl' in name.lower() 58 | is_qwen25vl = 'qwen2.5-vl' in name.lower() or 'qwen25-vl' in name.lower() or 'qwen2.5vl' in name.lower() 59 | is_llava = 'llava' in name.lower() 60 | base_model_name = None 61 | if is_qwen2vl or is_qwen25vl: 62 | print("Using model:", name) 63 | min_pixels = 4*28*28 64 | max_pixels = 1024*28*28 65 | if is_qwen2vl: 66 | if '7b' in name.lower(): 67 | base_model_name = "Qwen/Qwen2-VL-7B-Instruct" 68 | elif '2b' in name.lower(): 69 | base_model_name = "Qwen/Qwen2-VL-2B-Instruct" 70 | if is_qwen25vl: 71 | if '7b' in name.lower(): 72 | base_model_name = "Qwen/Qwen2.5-VL-7B-Instruct" 73 | elif '3b' in name.lower(): 74 | base_model_name = "Qwen/Qwen2.5-VL-3B-Instruct" 75 | 76 | assert base_model_name is not None, "TODO: check the model -- {}".format(name) 77 | processor = AutoProcessor.from_pretrained(base_model_name, 78 | min_pixels=min_pixels, max_pixels=max_pixels) 79 | 80 | try: 81 | local_model_path = snapshot_download(name) 82 | print(f"set model:{name} to local path:", local_model_path) 83 | name = local_model_path 84 | except: 85 | pass 86 | 87 | model = LLM( 88 | model=name, 89 | limit_mm_per_prompt={"image": 1}, 90 | dtype='bfloat16', 91 | device=device_map, 92 | max_model_len=max_model_len, 93 | mm_processor_kwargs= { "max_pixels": max_pixels, "min_pixels": min_pixels}, 94 | ) 95 | elif is_llava: 96 | model_cls = LlavaForConditionalGeneration if '1.5' in name else LlavaNextForConditionalGeneration 97 | model = model_cls.from_pretrained( 98 | name, 99 | torch_dtype=torch.bfloat16, 100 | ).to(device_map) 101 | processor = AutoProcessor.from_pretrained(name) 102 | else: 103 | raise Exception(f"Unknown model_id: {name}") 104 | 105 | return is_qwen2vl, is_qwen25vl, is_llava, model, processor 106 | 107 | 108 | def replace_answer_format(item: str) -> str: 109 | return item.replace("<answer>", "```json").replace("</answer>", "```") 110 | 111 | def format_data(dataset_name, sample, use_predefined_cats=False, use_think_system_prompt=False, remove_image_size_in_prompt=True): 112 | image = sample['image'].convert('RGB') 113 | iw, ih = image.size 114 | if use_predefined_cats: 115 | prompt = PROMPT_CLOSE_PSG if 'psg' in dataset_name else PROMPT_CLOSE_VG150 116 | else: 117 | prompt = PROMPT_SG 118 | 119 | if remove_image_size_in_prompt: 120 | prompt = prompt.replace(f"of size ({iw} x {ih}) ", "") 121 | 122 | prompt = replace_answer_format(prompt) 123 | 124 | system_prompt = SYSTEM_PROMPT if use_think_system_prompt else "You are a helpful and multimodal AI assistant."
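    # The image below is embedded directly in the user message as a base64 data URL
    # ("image_url" content in the OpenAI-style chat format), so the same message dict can be
    # handed to vLLM's chat interface without passing image tensors separately.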
125 | 126 | base64_image = encode_image_to_base64(image) 127 | messages = [ 128 | { 129 | "role": "system", 130 | "content": system_prompt 131 | }, 132 | { 133 | "role": "user", 134 | "content": [ 135 | {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}, 136 | {"type": "text", "text": prompt}, 137 | ], 138 | }, 139 | ] 140 | return image, messages 141 | 142 | def parse_args(): 143 | parser = argparse.ArgumentParser(description="Run model inference on a dataset.") 144 | parser.add_argument("--dataset", required=True, help="Hugging Face dataset identifier") 145 | parser.add_argument("--model", required=True, help="Model name to load") 146 | parser.add_argument("--output_dir", required=True, help="Directory to save the outputs") 147 | parser.add_argument("--use_think_system_prompt", action="store_true", help="Use system prompt with ...") 148 | parser.add_argument("--use_predefined_cats", action="store_true", help="Use predefined categories in the prompt") 149 | parser.add_argument("--max_model_len", type=int, default=4096, help="max_model_len for vLLM") 150 | parser.add_argument("--batch_size", type=int, default=1, help="batch size") 151 | 152 | return parser.parse_args() 153 | 154 | def main(): 155 | # Parse command line arguments. 156 | args = parse_args() 157 | print("args:", args) 158 | 159 | # Initialize Accelerator for distributed training/inference. 160 | accelerator = Accelerator() 161 | local_rank = accelerator.local_process_index 162 | device = f"cuda:{local_rank}" # each process occupies a GPU 163 | 164 | # Get rank and world size for manual splitting 165 | rank = torch.distributed.get_rank() # GPU ID or node rank 166 | world_size = torch.distributed.get_world_size() # Total number of GPUs/nodes 167 | 168 | 169 | # Load the model and processor. 
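    # get_model() returns the backend flags (Qwen2-VL / Qwen2.5-VL / LLaVA) together with the
    # engine and its processor; Qwen checkpoints are served through vLLM, LLaVA through HF generate().
    # Rough standalone usage sketch (the checkpoint name here is only an example):
    #   _, _, _, llm, proc = get_model("Qwen/Qwen2-VL-7B-Instruct", device_map="cuda:0")
    #   outs = llm.chat(messages, sampling_params=SamplingParams(temperature=0.01, top_k=1, top_p=0.001, max_tokens=2048))
    #   print(outs[0].outputs[0].text)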
170 | is_qwen2vl, is_qwen25vl, is_llava, model, processor = get_model(args.model, device_map=device, max_model_len=args.max_model_len) 171 | sampling_params = SamplingParams( 172 | temperature=0.01, 173 | top_k=1, 174 | top_p=0.001, 175 | repetition_penalty=1.0, 176 | max_tokens=2048, 177 | ) 178 | 179 | print(f"model_id: {args.model}", " generation_config:", sampling_params) 180 | 181 | class Collator(object): 182 | def __init__(self, data_name, 183 | processor, 184 | use_predefined_cats, use_think_system_prompt, 185 | is_llava=False): 186 | self.data_name = data_name 187 | self.processor = processor 188 | self.use_predefined_cats = use_predefined_cats 189 | self.use_think_system_prompt = use_think_system_prompt 190 | self.is_llava = is_llava 191 | 192 | def __call__(self, examples): 193 | ids = [e['image_id'] for e in examples] 194 | gt_objs = [e['objects'] for e in examples] 195 | gt_rels = [e['relationships'] for e in examples] 196 | 197 | llm_inputs = [] 198 | images = [] 199 | for example in examples: 200 | image, prompt = format_data(self.data_name, example, 201 | use_predefined_cats=self.use_predefined_cats, 202 | use_think_system_prompt=self.use_think_system_prompt) 203 | 204 | if self.is_llava: 205 | conversation = [{'role': 'user', 'content': [{'type': 'text', 'text': prompt[-1]['content'][-1]['text']}, {"type": "image"},]}] 206 | prompt_item = self.processor.apply_chat_template(conversation, add_generation_prompt=True) 207 | llm_inputs.append(prompt_item) 208 | else: 209 | llm_inputs.append(prompt) 210 | images.append(image) 211 | 212 | if self.is_llava: 213 | llm_inputs = self.processor(text=llm_inputs, images=images, padding=True, return_tensors="pt") 214 | input_height = input_width = [336]*len(images) 215 | else: 216 | texts = [self.processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) 217 | for msg in llm_inputs] 218 | inputs = processor( 219 | text=texts, 220 | images=images, 221 | padding=True, 222 | return_tensors="pt", 223 | ) 224 | input_height = [inputs['image_grid_thw'][idx][1].item()*14 for idx in range(len(images))] 225 | input_width = [inputs['image_grid_thw'][idx][2].item()*14 for idx in range(len(images))] 226 | 227 | return ids, gt_objs, gt_rels, input_width, input_height, llm_inputs 228 | 229 | 230 | 231 | # Load dataset from Hugging Face hub. 
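    # The 'train' split is loaded from the Hub; any image whose <image_id>.json already exists in
    # output_dir is skipped (cheap resume after an interrupted run), and the remaining samples are
    # sharded manually: rank r takes the r-th contiguous slice of the filtered dataset.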
232 | dataset = load_dataset(args.dataset)['train'] 233 | 234 | names = glob.glob(args.output_dir + "/*json") 235 | names = set([e.split('/')[-1].replace('.json', '') for e in tqdm(names)]) 236 | ids = [] 237 | for idx, item in enumerate(tqdm(dataset)): 238 | if item['image_id'] in names: 239 | continue 240 | ids.append(idx) 241 | dataset = dataset.select(ids) 242 | print("*"*100, " old:", len(names), " unhandled:", len(dataset)) 243 | 244 | 245 | # Split dataset manually 246 | total_size = len(dataset) 247 | per_gpu_size = total_size // world_size 248 | start_idx = rank * per_gpu_size 249 | end_idx = total_size if rank == world_size - 1 else (rank + 1) * per_gpu_size 250 | 251 | subset = dataset.select(range(start_idx, end_idx)) # Select subset for this GPU 252 | print("*"*100, "\n rank:", rank, " world size:", world_size, 253 | "subset from", start_idx, " to ", end_idx, "\n", 254 | "\n data[0]:", format_data(args.dataset, dataset[0], use_predefined_cats=args.use_predefined_cats, use_think_system_prompt=args.use_think_system_prompt), 255 | "*"*100) 256 | 257 | data_loader = DataLoader( 258 | subset, 259 | batch_size=args.batch_size, 260 | shuffle=False, 261 | collate_fn=Collator(args.dataset, processor, 262 | use_predefined_cats=args.use_predefined_cats, 263 | use_think_system_prompt=args.use_think_system_prompt, 264 | is_llava=is_llava), 265 | pin_memory=True 266 | ) 267 | #data_loader = accelerator.prepare(data_loader) 268 | print(f"Local ID: {local_rank} | len(dataset): {len(data_loader)}") 269 | 270 | # Create output directory if it doesn't exist. 271 | os.makedirs(args.output_dir, exist_ok=True) 272 | print(f"Save to {args.output_dir}") 273 | 274 | # Iterate over the data loader. 275 | _iter = 0 276 | for im_ids, gt_objs, gt_rels, input_width, input_height, batch in tqdm(data_loader, desc=f"Progress at rank {local_rank}"): 277 | with torch.no_grad(): 278 | if is_llava: 279 | batch = batch.to(model.device) 280 | outputs = model.generate(**batch, max_new_tokens=2048) 281 | output_texts = processor.batch_decode(outputs, skip_special_tokens=True) 282 | output_texts = [text.split("ASSISTANT:")[-1] for text in output_texts] 283 | else: 284 | outputs = model.chat(batch, sampling_params=sampling_params) 285 | output_texts = [output.outputs[0].text for output in outputs] 286 | 287 | 288 | if local_rank == 0 and _iter % 100 == 0: 289 | print("*" * 100) 290 | print("nvidia-smi:") 291 | os.system("nvidia-smi") 292 | print("*" * 100) 293 | print("*"*100, "\n", "image_id:", im_ids[0], "\n", 294 | "Response:", output_texts[0], "\n", 295 | "GT objs:", gt_objs[0], " GT rels.: ", gt_rels[0], 296 | "*"*100) 297 | 298 | _iter += 1 299 | for im_id, gt_obj, gt_rel, output_text, input_iw, input_ih in zip(im_ids, gt_objs, gt_rels, output_texts, input_width, input_height): 300 | if is_qwen2vl: 301 | box_scale = [1000.0, 1000.0] 302 | else: 303 | box_scale = [input_iw, input_ih] 304 | 305 | out = {"image_id": im_id, "response": output_text, 306 | "gt_objects": gt_obj, "gt_relationships": gt_rel, 307 | "box_scale": box_scale 308 | } 309 | dst_file = os.path.join(args.output_dir, f"{im_id}.json") 310 | with open(dst_file, 'w') as fout: 311 | json.dump(out, fout) 312 | 313 | print("Rank:", rank, " finished!") 314 | torch.cuda.empty_cache() 315 | accelerator.wait_for_everyone() 316 | print("All jobs finished!") 317 | 318 | if __name__ == "__main__": 319 | main() 320 | -------------------------------------------------------------------------------- /src/utils/__init__.py: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/gpt4vision/R1-SGG/e4de64d4c4c97edec648021d012198b21a9b1864/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/bbox_overlaps.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) OpenMMLab. All rights reserved. 2 | import numpy as np 3 | 4 | 5 | def bbox_overlaps(bboxes1, 6 | bboxes2, 7 | mode='iou', 8 | eps=1e-6, 9 | use_legacy_coordinate=False): 10 | """Calculate the ious between each bbox of bboxes1 and bboxes2. 11 | 12 | Args: 13 | bboxes1 (ndarray): Shape (n, 4) 14 | bboxes2 (ndarray): Shape (k, 4) 15 | mode (str): IOU (intersection over union) or IOF (intersection 16 | over foreground) 17 | use_legacy_coordinate (bool): Whether to use coordinate system in 18 | mmdet v1.x. which means width, height should be 19 | calculated as 'x2 - x1 + 1` and 'y2 - y1 + 1' respectively. 20 | Note when function is used in `VOCDataset`, it should be 21 | True to align with the official implementation 22 | `http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCdevkit_18-May-2011.tar` 23 | Default: False. 24 | 25 | Returns: 26 | ious (ndarray): Shape (n, k) 27 | """ 28 | 29 | assert mode in ['iou', 'iof'] 30 | if not use_legacy_coordinate: 31 | extra_length = 0. 32 | else: 33 | extra_length = 1. 34 | bboxes1 = bboxes1.astype(np.float32) 35 | bboxes2 = bboxes2.astype(np.float32) 36 | rows = bboxes1.shape[0] 37 | cols = bboxes2.shape[0] 38 | ious = np.zeros((rows, cols), dtype=np.float32) 39 | if rows * cols == 0: 40 | return ious 41 | exchange = False 42 | if bboxes1.shape[0] > bboxes2.shape[0]: 43 | bboxes1, bboxes2 = bboxes2, bboxes1 44 | ious = np.zeros((cols, rows), dtype=np.float32) 45 | exchange = True 46 | area1 = (bboxes1[:, 2] - bboxes1[:, 0] + extra_length) * ( 47 | bboxes1[:, 3] - bboxes1[:, 1] + extra_length) 48 | area2 = (bboxes2[:, 2] - bboxes2[:, 0] + extra_length) * ( 49 | bboxes2[:, 3] - bboxes2[:, 1] + extra_length) 50 | for i in range(bboxes1.shape[0]): 51 | x_start = np.maximum(bboxes1[i, 0], bboxes2[:, 0]) 52 | y_start = np.maximum(bboxes1[i, 1], bboxes2[:, 1]) 53 | x_end = np.minimum(bboxes1[i, 2], bboxes2[:, 2]) 54 | y_end = np.minimum(bboxes1[i, 3], bboxes2[:, 3]) 55 | overlap = np.maximum(x_end - x_start + extra_length, 0) * np.maximum( 56 | y_end - y_start + extra_length, 0) 57 | if mode == 'iou': 58 | union = area1[i] + area2 - overlap 59 | else: 60 | union = area1[i] if not exchange else area2 61 | union = np.maximum(union, eps) 62 | ious[i, :] = overlap / union 63 | if exchange: 64 | ious = ious.T 65 | return ious 66 | -------------------------------------------------------------------------------- /src/utils/cocoeval.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from pycocotools.cocoeval import COCOeval 4 | from .bbox_overlaps import bbox_overlaps 5 | 6 | from .misc import get_rank 7 | 8 | 9 | class COCOEval(COCOeval): 10 | def __init__(self, *args, **kwargs): 11 | super().__init__(*args, **kwargs) 12 | self.gt_dt_valid = {} 13 | 14 | def evaluateImg(self, imgId, catId, aRng, maxDet): 15 | ''' 16 | perform evaluation for single category and image 17 | :return: dict (single image results) 18 | ''' 19 | p = self.params 20 | if p.useCats: 21 | gt = self._gts[imgId,catId] 22 | dt = self._dts[imgId,catId] 23 | else: 24 | gt = [_ for cId in p.catIds for _ in self._gts[imgId,cId]] 25 | dt = [_ 
for cId in p.catIds for _ in self._dts[imgId,cId]] 26 | if len(gt) == 0 and len(dt) ==0: 27 | return None 28 | 29 | for g in gt: 30 | if g['ignore'] or (g['area']<aRng[0] or g['area']>aRng[1]): 31 | g['_ignore'] = 1 32 | else: 33 | g['_ignore'] = 0 34 | 35 | # sort dt highest score first, sort gt ignore last 36 | gtind = np.argsort([g['_ignore'] for g in gt], kind='mergesort') 37 | gt = [gt[i] for i in gtind] 38 | dtind = np.argsort([-d['score'] for d in dt], kind='mergesort') 39 | dt = [dt[i] for i in dtind[0:maxDet]] 40 | iscrowd = [int(o['iscrowd']) for o in gt] 41 | 42 | # load computed ious 43 | ious = self.ious[imgId, catId][:, gtind] if len(self.ious[imgId, catId]) > 0 else self.ious[imgId, catId] 44 | 45 | T = len(p.iouThrs) 46 | G = len(gt) 47 | D = len(dt) 48 | gtm = np.zeros((T,G)) 49 | dtm = np.zeros((T,D)) 50 | gtIg = np.array([g['_ignore'] for g in gt]) 51 | dtIg = np.zeros((T,D)) 52 | 53 | has_group_of = False 54 | if len(gt) > 0 and len(dt) > 0 and 'is_group_of' in gt[0]: 55 | has_group_of = True 56 | 57 | if has_group_of: 58 | dt_boxes = np.array([d['bbox'] for d in dt]).reshape(-1, 4) 59 | gt_boxes = np.array([g['bbox'] for g in gt]).reshape(-1, 4) 60 | dt_boxes[:, 2:] = dt_boxes[:, 2:] + dt_boxes[:, :2] 61 | gt_boxes[:, 2:] = gt_boxes[:, 2:] + gt_boxes[:, :2] 62 | 63 | 64 | is_group_of = np.array([g['is_group_of'] for g in gt], dtype=bool) 65 | is_group_idx = np.where(is_group_of)[0] 66 | non_group_idx = np.where(~is_group_of)[0] 67 | 68 | iofs = bbox_overlaps(dt_boxes, gt_boxes, mode='iof') 69 | 70 | non_group_gt = [gt[e] for e in non_group_idx] 71 | group_gt = [gt[e] for e in is_group_idx] 72 | # step 1: for non-group-of gts. 73 | if len(ious) > 0: 74 | for tind, t in enumerate(p.iouThrs): 75 | for dind, d in enumerate(dt): 76 | # information about best match so far (m=-1 -> unmatched) 77 | iou = min([t,1-1e-10]) 78 | m = -1 79 | for gind, g in zip(non_group_idx, non_group_gt): 80 | # if this gt already matched, and not a crowd, continue 81 | if gtm[tind,gind]>0 and not iscrowd[gind]: 82 | continue 83 | # if dt matched to reg gt, and on ignore gt, stop 84 | if m>-1 and gtIg[m]==0 and gtIg[gind]==1: 85 | break 86 | 87 | # continue to next gt unless better match made 88 | if ious[dind,gind] < iou: 89 | continue 90 | # if match successful and best so far, store appropriately 91 | iou=ious[dind,gind] 92 | m=gind 93 | 94 | if m ==-1: 95 | continue 96 | # if match made store id of match for both dt and gt 97 | dtIg[tind,dind] = gtIg[m] 98 | dtm[tind,dind] = gt[m]['id'] 99 | gtm[tind,m] = d['id'] 100 | 101 | # step 2: for group-of gts 102 | if len(is_group_idx) > 0 and len(iofs) > 0: 103 | for tind, t in enumerate(p.iouThrs): 104 | iof_thresh = min([t,1-1e-10]) 105 | maxIoF = [-1 for _ in range(len(gt))] # store maximum IoF for each gt 106 | maxDtIndex = [-1 for _ in range(len(gt))] # store dt index with maximum IoF for each gt 107 | 108 | for dind, d in enumerate(dt): 109 | # dt already matched 110 | if dtm[tind, dind] > 0 or dtIg[tind, dind] > 0: 111 | continue 112 | # 113 | m = -1 114 | for gind, g in zip(is_group_idx, group_gt): 115 | iof = iofs[dind, gind] 116 | if iof > iof_thresh: 117 | if iof > maxIoF[gind]: # if current IoF is larger than stored maxIoF 118 | if maxDtIndex[gind] != -1: # if there was a previously stored dt for this gt 119 | dtIg[tind,maxDtIndex[gind]] = 1 # ignore the previous dt 120 | tmp = int(maxDtIndex[gind]) 121 | 122 | maxIoF[gind] = iof 123 | maxDtIndex[gind] = dind 124 | dtIg[tind,dind] = 0 # not ignored 125 | m = gind 126 | else: 127 | dtIg[tind,dind] = 1 # ignore
other dts inside gt 128 | 129 | if m == -1: 130 | continue 131 | 132 | # if match made store id of match for both dt and gt 133 | dtm[tind,dind] = gt[m]['id'] 134 | gtm[tind,m] = d['id'] 135 | 136 | else: # normal 137 | if not len(ious)==0: 138 | for tind, t in enumerate(p.iouThrs): 139 | for dind, d in enumerate(dt): 140 | # information about best match so far (m=-1 -> unmatched) 141 | iou = min([t,1-1e-10]) 142 | m = -1 143 | for gind, g in enumerate(gt): 144 | # if this gt already matched, and not a crowd, continue 145 | if gtm[tind,gind]>0 and not iscrowd[gind]: 146 | continue 147 | # if dt matched to reg gt, and on ignore gt, stop 148 | if m>-1 and gtIg[m]==0 and gtIg[gind]==1: 149 | break 150 | 151 | # continue to next gt unless better match made 152 | if ious[dind,gind] < iou: 153 | continue 154 | # if match successful and best so far, store appropriately 155 | iou=ious[dind,gind] 156 | m=gind 157 | 158 | if m ==-1: 159 | continue 160 | 161 | # if match made store id of match for both dt and gt 162 | dtIg[tind,dind] = gtIg[m] 163 | dtm[tind,dind] = gt[m]['id'] 164 | gtm[tind,m] = d['id'] 165 | 166 | 167 | 168 | # set unmatched detections outside of area range to ignore 169 | a = np.array([d['area']<aRng[0] or d['area']>aRng[1] for d in dt]).reshape((1, len(dt))) 170 | dtIg = np.logical_or(dtIg, np.logical_and(dtm==0, np.repeat(a,T,0))) 171 | 172 | for gind, g in enumerate(gt): 173 | if gtIg[gind] != 1: 174 | gid = g['category_id'] 175 | if gid not in self.gt_dt_valid: 176 | self.gt_dt_valid[gid] = {'gts': 0, 'dts': [0]*T} 177 | self.gt_dt_valid[gid]['gts'] += 1 178 | 179 | for dind, d in enumerate(dt): 180 | for tind in range(len(p.iouThrs)): 181 | if dtIg[tind, dind] != 1: 182 | did = d['category_id'] 183 | if did not in self.gt_dt_valid: 184 | self.gt_dt_valid[did] = {'gts': 0, 'dts': [0]*T} 185 | self.gt_dt_valid[did]['dts'][tind] += 1 186 | 187 | # store results for given image and category 188 | return { 189 | 'image_id': imgId, 190 | 'category_id': catId, 191 | 'aRng': aRng, 192 | 'maxDet': maxDet, 193 | 'dtIds': [d['id'] for d in dt], 194 | 'gtIds': [g['id'] for g in gt], 195 | 'dtMatches': dtm, 196 | 'gtMatches': gtm, 197 | 'dtScores': [d['score'] for d in dt], 198 | 'gtIgnore': gtIg, 199 | 'dtIgnore': dtIg, 200 | } 201 | 202 | def summarize(self): 203 | ''' 204 | Compute and display summary metrics for evaluation results.
205 | Note this functin can *only* be applied on the default parameter setting 206 | ''' 207 | def _summarize( ap=1, iouThr=None, areaRng='all', maxDets=100 ): 208 | p = self.params 209 | iStr = ' {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}' 210 | titleStr = 'Average Precision' if ap == 1 else 'Average Recall' 211 | typeStr = '(AP)' if ap==1 else '(AR)' 212 | iouStr = '{:0.2f}:{:0.2f}'.format(p.iouThrs[0], p.iouThrs[-1]) \ 213 | if iouThr is None else '{:0.2f}'.format(iouThr) 214 | 215 | aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng] 216 | mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets] 217 | if ap == 1: 218 | # dimension of precision: [TxRxKxAxM] 219 | s = self.eval['precision'] 220 | # IoU 221 | if iouThr is not None: 222 | t = np.where(iouThr == p.iouThrs)[0] 223 | s = s[t] 224 | s = s[:,:,:,aind,mind] 225 | else: 226 | # dimension of recall: [TxKxAxM] 227 | s = self.eval['recall'] 228 | if iouThr is not None: 229 | t = np.where(iouThr == p.iouThrs)[0] 230 | s = s[t] 231 | s = s[:,:,aind,mind] 232 | if len(s[s>-1])==0: 233 | mean_s = -1 234 | else: 235 | mean_s = np.mean(s[s>-1]) 236 | 237 | if get_rank() == 0: 238 | print(iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, mean_s)) 239 | 240 | return mean_s 241 | 242 | def _summarizeDets(): 243 | stats = np.zeros((13,)) 244 | stats[0] = _summarize(1, maxDets=self.params.maxDets[2]) 245 | stats[1] = _summarize(1, iouThr=.5, maxDets=self.params.maxDets[2]) 246 | stats[2] = _summarize(1, iouThr=.75, maxDets=self.params.maxDets[2]) 247 | stats[3] = _summarize(1, areaRng='small', maxDets=self.params.maxDets[2]) 248 | stats[4] = _summarize(1, areaRng='medium', maxDets=self.params.maxDets[2]) 249 | stats[5] = _summarize(1, areaRng='large', maxDets=self.params.maxDets[2]) 250 | stats[6] = _summarize(0, maxDets=self.params.maxDets[0]) 251 | stats[7] = _summarize(0, maxDets=self.params.maxDets[1]) 252 | stats[8] = _summarize(0, maxDets=self.params.maxDets[2]) 253 | stats[9] = _summarize(0, areaRng='small', maxDets=self.params.maxDets[2]) 254 | stats[10] = _summarize(0, areaRng='medium', maxDets=self.params.maxDets[2]) 255 | stats[11] = _summarize(0, areaRng='large', maxDets=self.params.maxDets[2]) 256 | # add 257 | stats[12] = _summarize(0, iouThr=.5, maxDets=self.params.maxDets[2]) 258 | return stats 259 | 260 | def _summarizeKps(): 261 | stats = np.zeros((10,)) 262 | stats[0] = _summarize(1, maxDets=20) 263 | stats[1] = _summarize(1, maxDets=20, iouThr=.5) 264 | stats[2] = _summarize(1, maxDets=20, iouThr=.75) 265 | stats[3] = _summarize(1, maxDets=20, areaRng='medium') 266 | stats[4] = _summarize(1, maxDets=20, areaRng='large') 267 | stats[5] = _summarize(0, maxDets=20) 268 | stats[6] = _summarize(0, maxDets=20, iouThr=.5) 269 | stats[7] = _summarize(0, maxDets=20, iouThr=.75) 270 | stats[8] = _summarize(0, maxDets=20, areaRng='medium') 271 | stats[9] = _summarize(0, maxDets=20, areaRng='large') 272 | return stats 273 | if not self.eval: 274 | raise Exception('Please run accumulate() first') 275 | iouType = self.params.iouType 276 | if iouType == 'segm' or iouType == 'bbox': 277 | summarize = _summarizeDets 278 | elif iouType == 'keypoints': 279 | summarize = _summarizeKps 280 | self.stats = summarize() 281 | -------------------------------------------------------------------------------- /src/utils/wordnet.py: -------------------------------------------------------------------------------- 1 | import nltk 2 | from nltk.corpus import wordnet 3 | import json 4 | 5 | # Ensure 
nltk resources are downloaded 6 | #nltk.download('wordnet') 7 | 8 | def find_synonym_map(set_a, set_b_list): 9 | """ 10 | Builds a synonym map for items in set A if a synonym exists in set B. 11 | 12 | Args: 13 | set_a: A list of strings (set A). 14 | set_b_list: A list of strings (set B). 15 | 16 | Returns: 17 | A dictionary representing the synonym map. 18 | Keys are items from set A that have synonyms in set B. 19 | Values are the corresponding synonym items from set B. 20 | """ 21 | synonym_map = {} 22 | set_b = set(set_b_list) # Convert set B list to set for faster lookup 23 | 24 | for item_a in set_a: 25 | found_synonym = False 26 | # First, check for direct match (item_a itself in set_b) 27 | if item_a in set_b: 28 | synonym_map[item_a] = item_a 29 | found_synonym = True 30 | else: 31 | # If no direct match, try to find synonyms using WordNet 32 | for syn in wordnet.synsets(item_a): 33 | for lemma in syn.lemmas(): 34 | lemma_name = lemma.name() 35 | if lemma_name in set_b: 36 | synonym_map[item_a] = lemma_name 37 | found_synonym = True 38 | break # Found a synonym, move to next item in set_a 39 | if found_synonym: 40 | break # Found a synonym, move to next item in set_a 41 | if not found_synonym: 42 | # Check for synonyms by splitting words if item_a is a phrase 43 | words_in_a = item_a.split() 44 | if len(words_in_a) > 1: 45 | for word_a in words_in_a: 46 | if word_a in set_b: 47 | synonym_map[item_a] = word_a 48 | found_synonym = True 49 | break 50 | else: 51 | for syn in wordnet.synsets(word_a): 52 | for lemma in syn.lemmas(): 53 | lemma_name = lemma.name() 54 | if lemma_name in set_b: 55 | synonym_map[item_a] = lemma_name 56 | found_synonym = True 57 | break 58 | if found_synonym: 59 | break 60 | if found_synonym: 61 | break 62 | 63 | 64 | return synonym_map 65 | 66 | 67 | 68 | if __name__ == "__main__": 69 | set_A = ['quilt', 'sensor', 'fish', 'wall socket', 'goat', 'mouse', 'snow globe', 'santa', 'tie', 'mushroom', 'storefront', 'leaf', 'jar', 'feeder', 'child', 'crib', 'propeller', 'headboard', 'outhouse', 'lemon', 'weather vane', 'snowboarder', 'bike route', 'statue', 'cork', 'ice rink', 'log cabin', 'donkey', 'catering truck', 'flowerbed', 'boarding bridge', 'post it', 'lid', 'saucer', 'water tower', 'blue pants', 'hand', 'planter', 'brickwall', 'red bandana', 'bowl', 'banana', 'wheat thins', 'keyboard', 'wine bottle', 'crosswalk', 'game machine', 'sandals', 'mongoose', 'apple', 'pond', 'clothes', 'marker', 'white chair', 'earphones', 'medals', 'blocks', 'panda', 'saw', 'sun', 'taxi', 'blue mug', 'rugby ball', 'first aid kit', 'candle', 'kite', 'zebra', 'pumpkin', 'yellow truck', 'hotdog', 'bookstore', 'charger', 'grape vine', 'file', 'speaker', 'wetsuit', 'portrait', 'gondola', 'coca cola', 'leaves', 'purple flowers', 'pancake', 'racket', 'birdhouse', 'tablecloth', 'chicken', 'sticker', 'ceiling fan', 'scaffolding', 'sunglasses', 'cheese', 'aircraft carrier', 'sleeping bag', 'rock', 'telescope', 'boat', 'piano', 'flagpole', 'brick wall', 'coleslaw', 'tablet', 'shoes', 'thermos', 'corkboard', 'cowboy hat', 'cheerios', 'restaurant sign', 'floral pattern', 'sand dune', 'mug', 'buildings', 'oar', 'towel holder', 'toaster', 'backpack', 'zebra bag', 'crane', 'vase', 'server rack', 'wood', 'beer glass', 'tv stand', 'can', 'moon', 'fireplug', 'pen holder', 'antelope', 'keys', 'whipped cream', 'shelves', 'external drive', 'driver', 'calendar', 'bathroom', 'church', 'white truck', 'plate', 'vegetables', 'book', 'lights', 'hair dryer', 'knife block', 'hill', 'glue', 'star', 'stuffed 
animal', 'countertop', 'rainbow', 'tracks', 'green beans', 'noodles', 'chips', 'coral', 'wine', 'pepsi', 'jacket', 'police', 'bridge', 'school bus', 'slide', 'ball', 'helicopter', 'cymbal', 'pot', 'porta potty', 'wave', 'refrigerator', 'arch', 'ribs', 'basketball', 'dog', 'clouds', 'projector', 'sandbag wall', 'glasses', 'pillars', 'knee pads', 'folder', 'palm', 'bread', 'tent', 'napkin', 'gazebo', 'printer', 'peas', 'sauce', 'police car', 'lamp shade', 'picnic table', 'cigarette', 'staircase', 'driveway', 'yellow mug', 'candy', 'mccafe drink', 'soccer ball', 'coffee pot', 'mousepad', 'snow sign', 'fence', 'crates', 'well', 'asparagus', 'go sign', 'pail', 'stapler', 'bookcase', 'field', 'surfer', 'vending machine', 'store', 'buoy', 'postbox', 'tank', 'floor', 'internet', 'wardrobe', 'lake', 'water', 'hammer', 'sheep', 'pancakes', 'sky', 'paint', 'wreath', 'duck', 'wooden box', 'paintbrush', 'bouquet', 'jars', 'tray', 'pipe', 'tissue box', 'net', 'engine', 'closet', 'salad', 'toothpaste', 'vent', 'feathers', 'river', 'tennis bag', 'greenhouse', 'sculpture', 'bed', 'green fabric', 'name card', 'rockwall', 'scorpions', 'path', 'record player', 'parking sign', 'giraffe', 'watch', 'bath', 'coca cola bottle', 'metal grates', 'air conditioner', 'pitcher', 'deer head', 'menu board', 'bun', 'ski poles', 'grass', 'wrench', 'butterfly', 'purse', 'iceberg', 'sidewalk', 'flowers', 'gray car', 'atm', 'ice', 'socks', 'tractor', 'egg', 'dock', 'wind turbine', 'gloves', 'clock', 'kiteboard', 'barn', 'ornament', 'shell', 'toilet tissue', 'cow', 'truck', 'blinds', 'table', 'podium', 'ship', 'bottle', 'rocks', 'cucumber', 'seat', 'billboard', 'file cabinet', 'ruler', 'chain', 'dough', 'meat', 'scooter', 'green fridge', 'wind chime', 'microphone', 'tomato', 'ski lift', 'chocolate sauce', 'shelf', 'fries', 'oven', 'reflection', 'power pole', 'cross', 'cutlery holder', 'directory', 'berries', 'water glass', 'couscous', 'ottoman', 'box', 'plant', 'water bottle', 'snowboard', 'door', 'pizza', 'waterfall', 'stroller', 'stop sign', 'bus', 'paper bag', 'snack bar', 'headphones', 'balcony', 'bowling lane', 'apples', 'bedroom', 'tennis shorts', 'pineapple', 'stained glass', 'canvas', 'milk box', 'cutting board', 'shed', 'cereal', 'van', 'basket', 'toothbrush', 'tricycle', 'basketball court', 'fire extinguisher', 'pants', 'grill', 'boxcar', 'rug', 'greenwall', 'exit sign', 'conveyor', 'house', 'flower hanging', 'fire hydrant', 'bowling lanes', 'barrel', 'blanket', 'tile', 'cake', 'tv', 'streetlight', 'laptop', 'stone wall', 'water heater', 'sand', 'yellow strap', 'hamburger', 'game case', 'pole', 'tower', 'minnie', 'vehicle', 'firetruck', 'spectators', 'mirror', 'fridge', 'books', 'sea lion', 'tree', 'scoreboard', 'bulletin board', 'file drawer', 'board', 'hat', 'wheelchair', 'tennis player', 'wall', 'fountain', 'candles', 'ceiling', 'ambulance', 'palm tree', 'beach', 'window', 'building', 'police officer', 'drill', 'bar', 'remote', 'coffee table', 'crosswalk sign', 'lightswitch', 'jeans', 'basketball jersey', 'sandbox', 'gun', 'kiwi', 'microwave', 'power line', 'tail', 'no biking sign', 'lamp base', 'restaurant', 'stairs', 'sofa', 'pool', 'bike lane sign', 'wii', 'home plate', 'antenna', 'parrot', 'gravel', 'elephant', 'iguana', 'tag', 'bacon', 'toilet', 'squirrel', 'hot tub', 't shirt', 'side table', 'street', 'bat', 'doll', 'whiteboard', 'potatoes', 'label', 'pool table', 'sink', 'pear', 'coffee', 'skis', 'gas pump', 'placemat', 'whirlpool', 'sugar pack', 'stool', 'paper', 'frisbee', 'monitor', 'car', 'handbag', 
'courtroom', 'shower', 'snowflake', 'castle', 'wine glass', 'track', 'number', 'wagon', 'ocean', 'glass roof', 'cauliflower', 'gate', 'parking by permit only', 'concrete path', 'glass', 'shirt', 'bike rack', 'plastic container', 'stove', 'paper towel', 'cup', 'snow', 'container', 'toys', 'fan', 'base', 'bowling ball', 'stuffed toy', 'lime', 'parking permit only sign', 'ski', 'pen', 'person', 'rocky slope', 'carrousel', 'curtain', 'laptop case', 'towel', 'sandwich bar', 'baby', 'water fountain', 'shrub', 'cones', 'chair', 'lottery sign', 'baptism tub', 'railing', 'goggles', 'fork', 'green car', 'forklift', 'road', 'cushion', 'cabinet', 'computer', 'juice', 'camera', 'lock', 'power lines', 'high heel shoe', 'red car', 'bowling pins', 'blue seat', 'bed sheet', 'doily', 'vacuum', 'ice cream', 'coffee cup', 'phone booth', 'bucket', 'ping pong table', 'phone', 'violin', 'lampshade', 'rice', 'cloud', 'jet', 'structure', 'sailboat', 'green plants', 'houses', 'invitation', 'pine tree', 'speedometer', 'skier', 'lamp', 'belt', 'text', 'concrete wall', 'green shirt', 'white gate', 'tarp', 'foosball table', 'desk', 'carrots', 'wafer', 'busstop', 'surfboard', 'couch', 'blue court', 'grapes', 'trolley', 'pouch', 'television', 'lighthouse', 'leg', 'soda can', 'candy bar', 'barbed wire', 'sweater', 'fish tank', 'bell', 'bookshelf', 'sack', 'seats', 'bison', 'chocolate', 'guardrail', 'pitcher mound', 'cactus', 'flipflops', 'cross traffic sign', 'scissors', 'hoop', 'airplane', 'carrot', 'underwear', 'food truck', 'bus sign', 'raft', 'bulldozer', 'dishwasher', 'fruit', 'map', 'mat', 'sail', 'drum', 'lanyard', 'toy', 'painting', 'cups', 'ground', 'pump', 'branch', 'spectator', 'inflatable', 'vegetation', 'stage', 'yellow flower', 'concrete', 'green box', 'button', 'food', 'parking meter', 'log', 'cable', 'scrambled eggs', 'clock tower', 'blue wall', 'motorcycle', 'washing machine', 'anchor', 'cockpit', 'screwdriver', 'shutters', 'tennis court', 'baby monkey', 'strawberry', 'skeleton', 'wing', 'train track', 'lollipop', 'balloon', 'tennis racket', 'candle holder', 'horse', 'goal', 'bath mat', 'shoe', 'dryer', 'baseball field', 'flag', 'beer bottle', 'chocolate dessert', 'lunchbox', 'milking machine', 'trashcan', 'ski pole', 'garage', 'skull', 'cd', 'shutter', 'knife', 'cell phone', 'wheel', 'manhole', 'roof', 'counter', 'battery', 'people', 'soil', 'broccoli', 'white cabinet', 'teddy bear', 'tissue paper', 'drain', 'banana leaf', 'papers', 'coffee maker', 'pufferfish', 'z crossing', 'fire', 'sign', 'guitar', 'chimney', 'cheetah', 'kettle', 'corn', 'bathtub', 'bag', 'sausage', 'train', 'fireplace', 'cloth', 'paella', 'platform', 'ramp', 'fire escape', 'pepper', 'mountain', '4 way sign', 'hedge', 'toast', 'calculator', 'trash can', 'boy', 'menu', 'sandwich', 'bank', 'salt', 'lifeguard', 'mountains', 'mobile home', 'frying egg', 'street sign', 'spaghetti', 'magazine', 'nest', 'water meter', 'goalpost', 'yellow pole', 'donut', 'ladder', 'drawer', 'shorts', 'escalator', 'newspaper', 'toilet paper', 'store sign', 'plastic bag', 'tennis net', 'street light', 'monkey bar', 'parking lot', 'telephone pole', 'case', 'spatula', 'soap', 'pluto', 'scarf', 'brick column', 'screen', 'bush', 'motor', 'elevator', 'microphone stand', 'blue building', 'tennis ball', 'spoon', 'flower arrangement', 'computer monitor', 'bench', 'dirt', 'pillow', 'woman', 'lettuce', 'dresser', 'ketchup', 'tea', 'metal object', 'tow truck', 'light', 'mouse pad', 'potted plant', 'bus stop', 'no parking sign', 'fruits', 'moss', 'plug', 'thermostat', 
'traffic light', 'machine', 'goose', 'tissue', 'polling station sign', 'earrings', 'mirror ball', 'umbrella', 'fishbowl', 'nightstand', 'baseball', 'mailbox', 'towel bar', 'trees', 'scorpion', 'cd rack', 'glove', 'office', 'picture frame', 'white fridge', 'notebook', 'flower', 'knee pad', 'poster', 'handle', 'faucet', 'minibar', 'plane', 'bicycle', 'spiderman', 'column', 'soap dish', 'kitkat', 'carriage', 'hose', 'frame', 'potato', 'dome', 'cereal box', 'spray bottle', 'trash', 'trailer', 'sidecar', 'chandelier', 'collaborate sign', 'tripod', 'record', 'garbage can', 'monkey', 'watering can', 'street lamp', 'cat', 'mannequin', 'helmet', 'baseball glove', 'socket', 'valve', 'cowboy', 'arm', 'straw', 'toilet seat', 'wallet', 'collar', 'bird', 'package', 'grape', 'bear', 'cone', 'penne', 'phonebooth', 'no entry', 'radiator', 'lily pad', 'one way sign', 'bagel', 'lantern', 'armchair', 'air hockey table', 'orange juice', 'stripes', 'suv', 'suitcase', 'bike', 'alarm', 'steak', 'grate', 'toilet brush', 'luggage', 'onions', 'pier', 'skateboard', 'cash register', 'spire', 'pan', 'red counter', 'picture', 'orange', 'teapot', 'tank top', 'cart', 'rope', 'man', 'windmill', 'mashed potatoes', 'eraser', 'yogurt', 'tiger'] 70 | set_B_list = ['airplane', 'animal', 'arm', 'bag', 'banana', 'basket', 'beach', 'bear', 'bed', 'bench', 'bike', 'bird', 'board', 'boat', 'book', 'boot', 'bottle', 'bowl', 'box', 'boy', 'branch', 'building', 'bus', 'cabinet', 'cap', 'car', 'cat', 'chair', 'child', 'clock', 'coat', 'counter', 'cow', 'cup', 'curtain', 'desk', 'dog', 'door', 'drawer', 'ear', 'elephant', 'engine', 'eye', 'face', 'fence', 'finger', 'flag', 'flower', 'food', 'fork', 'fruit', 'giraffe', 'girl', 'glass', 'glove', 'guy', 'hair', 'hand', 'handle', 'hat', 'head', 'helmet', 'hill', 'horse', 'house', 'jacket', 'jean', 'kid', 'kite', 'lady', 'lamp', 'laptop', 'leaf', 'leg', 'letter', 'light', 'logo', 'man', 'men', 'motorcycle', 'mountain', 'mouth', 'neck', 'nose', 'number', 'orange', 'pant', 'paper', 'paw', 'people', 'person', 'phone', 'pillow', 'pizza', 'plane', 'plant', 'plate', 'player', 'pole', 'post', 'pot', 'racket', 'railing', 'rock', 'roof', 'room', 'screen', 'seat', 'sheep', 'shelf', 'shirt', 'shoe', 'short', 'sidewalk', 'sign', 'sink', 'skateboard', 'ski', 'skier', 'sneaker', 'snow', 'sock', 'stand', 'street', 'surfboard', 'table', 'tail', 'tie', 'tile', 'tire', 'toilet', 'towel', 'tower', 'track', 'train', 'tree', 'truck', 'trunk', 'umbrella', 'vase', 'vegetable', 'vehicle', 'wave', 'wheel', 'window', 'windshield', 'wing', 'wire', 'woman', 'zebra'] 71 | 72 | 73 | print(find_synonym_map(["bicycle"], set_B_list)) 74 | import pdb; pdb.set_trace() 75 | 76 | synonym_mapping = find_synonym_map(set_A, set_B_list) 77 | print(synonym_mapping) 78 | with open("synonym_mapping.json", 'w') as fout: 79 | json.dump(synonym_mapping, fout) 80 | 81 | -------------------------------------------------------------------------------- /src/utils/zeroshot_triplet.pytorch: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gpt4vision/R1-SGG/e4de64d4c4c97edec648021d012198b21a9b1864/src/utils/zeroshot_triplet.pytorch -------------------------------------------------------------------------------- /src/vg150_eval.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | import numpy as np 3 | import copy 4 | 5 | import torch 6 | from pycocotools.coco import COCO 7 | 8 | from src.utils.sgg_eval import 
SggEvaluator 9 | 10 | 11 | VG150_OBJ_CATEGORIES = ['__background__', 'airplane', 'animal', 'arm', 'bag', 'banana', 'basket', 'beach', 'bear', 'bed', 'bench', 'bike', 'bird', 'board', 'boat', 'book', 'boot', 'bottle', 'bowl', 'box', 'boy', 'branch', 'building', 'bus', 'cabinet', 'cap', 'car', 'cat', 'chair', 'child', 'clock', 'coat', 'counter', 'cow', 'cup', 'curtain', 'desk', 'dog', 'door', 'drawer', 'ear', 'elephant', 'engine', 'eye', 'face', 'fence', 'finger', 'flag', 'flower', 'food', 'fork', 'fruit', 'giraffe', 'girl', 'glass', 'glove', 'guy', 'hair', 'hand', 'handle', 'hat', 'head', 'helmet', 'hill', 'horse', 'house', 'jacket', 'jean', 'kid', 'kite', 'lady', 'lamp', 'laptop', 'leaf', 'leg', 'letter', 'light', 'logo', 'man', 'men', 'motorcycle', 'mountain', 'mouth', 'neck', 'nose', 'number', 'orange', 'pant', 'paper', 'paw', 'people', 'person', 'phone', 'pillow', 'pizza', 'plane', 'plant', 'plate', 'player', 'pole', 'post', 'pot', 'racket', 'railing', 'rock', 'roof', 'room', 'screen', 'seat', 'sheep', 'shelf', 'shirt', 'shoe', 'short', 'sidewalk', 'sign', 'sink', 'skateboard', 'ski', 'skier', 'sneaker', 'snow', 'sock', 'stand', 'street', 'surfboard', 'table', 'tail', 'tie', 'tile', 'tire', 'toilet', 'towel', 'tower', 'track', 'train', 'tree', 'truck', 'trunk', 'umbrella', 'vase', 'vegetable', 'vehicle', 'wave', 'wheel', 'window', 'windshield', 'wing', 'wire', 'woman', 'zebra'] 12 | 13 | VG150_PREDICATES = ['__background__', "above", "across", "against", "along", "and", "at", "attached to", "behind", "belonging to", "between", "carrying", "covered in", "covering", "eating", "flying in", "for", "from", "growing on", "hanging from", "has", "holding", "in", "in front of", "laying on", "looking at", "lying on", "made of", "mounted on", "near", "of", "on", "on back of", "over", "painted on", "parked on", "part of", "playing", "riding", "says", "sitting on", "standing on", "to", "under", "using", "walking in", "walking on", "watching", "wearing", "wears", "with"] 14 | 15 | 16 | def compute_iou(boxA, boxB): 17 | # box format: [x1, y1, x2, y2] 18 | xA = max(boxA[0], boxB[0]) 19 | yA = max(boxA[1], boxB[1]) 20 | xB = min(boxA[2], boxB[2]) 21 | yB = min(boxA[3], boxB[3]) 22 | interWidth = max(0, xB - xA) 23 | interHeight = max(0, yB - yA) 24 | interArea = interWidth * interHeight 25 | boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1]) 26 | boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1]) 27 | unionArea = boxAArea + boxBArea - interArea 28 | return 0.0 if unionArea == 0 else interArea / unionArea 29 | 30 | 31 | class MyDataset(object): 32 | def __init__(self, db, db_type='vg150'): 33 | self._coco = None 34 | self.db_type = db_type 35 | assert self.db_type in ['vg150', 'psg'] 36 | 37 | if self.db_type == 'vg150': 38 | self.ind_to_classes = VG150_OBJ_CATEGORIES 39 | self.ind_to_predicates = VG150_PREDICATES 40 | self.name2classes = {name: cls for cls, name in enumerate(self.ind_to_classes) if name != "__background__"} 41 | self.categories = [{'supercategory': 'none', # not used? 
42 | 'id': idx, 43 | 'name': self.ind_to_classes[idx]} 44 | for idx in range(len(self.ind_to_classes)) if self.ind_to_classes[idx] != '__background__' 45 | ] 46 | elif self.db_type == 'psg': 47 | psg_categories = json.load(open("src/psg_categories.json")) 48 | PSG_OBJ_CATEGORIES = psg_categories['thing_classes'] + psg_categories['stuff_classes'] 49 | PSG_PREDICATES = psg_categories['predicate_classes'] 50 | self.ind_to_classes = PSG_OBJ_CATEGORIES 51 | self.ind_to_predicates = ['__background__'] + PSG_PREDICATES 52 | self.name2classes = {name: cls for cls, name in enumerate(self.ind_to_classes) if name != "__background__"} 53 | self.categories = [{'supercategory': 'none', # not used? 54 | 'id': idx, 55 | 'name': self.ind_to_classes[idx]} 56 | for idx in range(len(self.ind_to_classes)) if self.ind_to_classes[idx] != '__background__' 57 | ] 58 | 59 | 60 | self.images = [] 61 | self.annotations = [] 62 | self.ids = [] 63 | for item in tqdm(db): 64 | im_id = item['image_id'] 65 | 66 | self.images.append({'id': im_id}) 67 | self.ids.append(im_id) 68 | objs = json.loads(item['objects']) 69 | 70 | ann = {'image_id': im_id, 'labels': [], 'boxes': []} 71 | names = [] 72 | for obj in objs: 73 | name, box = obj['id'].split('.')[0], obj['bbox'] 74 | names.append(obj['id']) 75 | cls = self.name2classes[name] 76 | ann['labels'].append(cls) 77 | ann['boxes'].append(box) 78 | 79 | rels = json.loads(item['relationships']) 80 | edges = [] 81 | for rel in rels: 82 | sub = rel['subject'] 83 | obj = rel['object'] 84 | pred = rel['predicate'] 85 | sid = names.index(sub) 86 | oid = names.index(obj) 87 | tmp = [sid, oid, self.ind_to_predicates.index(pred)] 88 | edges.append(tmp) 89 | 90 | ann['edges'] = edges 91 | self.annotations.append(ann) 92 | 93 | print("total images", len(self.images), self.images[0]) 94 | 95 | def get_groundtruth(self, index): 96 | ann = self.annotations[index] 97 | 98 | return torch.as_tensor(ann['boxes']), \ 99 | torch.as_tensor(ann['labels']), \ 100 | torch.as_tensor(ann['edges']) 101 | 102 | 103 | 104 | @property 105 | def coco(self): 106 | if self._coco is None: 107 | _coco = COCO() 108 | coco_dicts = dict( 109 | images=self.images, 110 | annotations=[], 111 | categories=self.categories) 112 | 113 | for ann in tqdm(self.annotations): 114 | for cls, box in zip(ann['labels'], ann['boxes']): 115 | assert len(box) == 4 116 | item = { 117 | 'area': (box[3] - box[1]) * (box[2] - box[0]), 118 | 'bbox': [box[0], box[1], box[2] - box[0], box[3] - box[1]], # xywh 119 | 'category_id': cls, 120 | 'image_id': ann['image_id'], 121 | 'id': len(coco_dicts['annotations']), 122 | 'iscrowd': 0, 123 | } 124 | coco_dicts['annotations'].append(item) 125 | 126 | _coco.dataset = coco_dicts 127 | _coco.createIndex() 128 | self._coco = _coco 129 | 130 | return self._coco 131 | 132 | def refine_node_edge(obj): 133 | """ remove speical chars in the name. 
""" 134 | obj = obj.replace("_", " ").replace("-", " ") 135 | return obj.strip().lower() 136 | 137 | if __name__ == "__main__": 138 | import os 139 | import sys 140 | import json 141 | from datasets import load_dataset 142 | import torch 143 | from collections import defaultdict 144 | 145 | preds = json.load(open(sys.argv[2])) 146 | db = load_dataset(sys.argv[1])['train'] 147 | db_type = 'vg150' if 'psg' not in sys.argv[1] else 'psg' 148 | dataset = MyDataset(db, db_type) 149 | 150 | 151 | ngR = [] 152 | mR = defaultdict(list) 153 | ngR_per_image = [] 154 | mR_per_image = defaultdict(list) 155 | num_gt_rels = 0 156 | for gt in tqdm(db): 157 | im_id = gt['image_id'] 158 | if im_id in preds: # to prevent wrong generated image_id 159 | pred = preds[im_id] 160 | else: 161 | pred = None 162 | gt_rels = json.loads(gt['relationships']) 163 | gt_objects = json.loads(gt['objects']) 164 | gt_boxes = {refine_node_edge(obj['id']): obj['bbox'] for obj in gt_objects} 165 | recall = [] 166 | recall_per_cat = defaultdict(list) 167 | 168 | for gt_rel in gt_rels: 169 | num_gt_rels += 1 170 | match = False 171 | gt_pred = refine_node_edge(gt_rel['predicate']) 172 | gt_sub_name = refine_node_edge(gt_rel['subject']) 173 | gt_obj_name = refine_node_edge(gt_rel['object']) 174 | 175 | if pred is not None: 176 | for pred_rel in pred['relation_tuples']: 177 | if refine_node_edge(gt_pred) != refine_node_edge(pred_rel[-1]): 178 | continue 179 | 180 | if gt_sub_name.split('.')[0].strip() != refine_node_edge(pred_rel[0]).split('.')[0].strip() or \ 181 | gt_obj_name.split('.')[0].strip() != refine_node_edge(pred_rel[2]).split('.')[0].strip(): 182 | continue 183 | 184 | sub_iou = compute_iou(gt_boxes[gt_sub_name], pred_rel[1]) 185 | obj_iou = compute_iou(gt_boxes[gt_obj_name], pred_rel[3]) 186 | if sub_iou >= 0.5 and obj_iou >= 0.5: 187 | match = True 188 | break 189 | 190 | recall.append(match) 191 | ngR.append(match) 192 | mR[gt_pred].append(match) 193 | recall_per_cat[gt_pred].append(match) 194 | 195 | if len(recall) > 0: 196 | ngR_per_image.append(sum(recall) / len(recall) ) 197 | 198 | for k in recall_per_cat.keys(): 199 | mR_per_image[k].append(sum(recall_per_cat[k]) / len(recall_per_cat[k]) ) 200 | 201 | mR_list = [] 202 | for k in mR.keys(): 203 | tmp = round(np.mean(mR[k]), 4) 204 | mR_list.append((k, tmp)) 205 | 206 | ngR_per_image = np.mean(ngR_per_image) 207 | mR_per_image = [(cat, round(np.mean(mR_per_image[cat]), 4) ) for cat in mR_per_image.keys()] 208 | 209 | 210 | sgg_evaluator = SggEvaluator(dataset, iou_types=("bbox","relation"), 211 | num_workers=4, 212 | num_rel_category=len(dataset.ind_to_predicates)) 213 | 214 | def to_torch(item): 215 | for k in item.keys(): 216 | try: 217 | item[k] = torch.as_tensor(item[k]) 218 | except: 219 | pass 220 | 221 | 222 | k0 = None 223 | for k in tqdm(preds.keys()): 224 | k0 = k 225 | to_torch(preds[k]) 226 | if 'graph' in preds[k]: 227 | graph = preds[k]['graph'] 228 | to_torch(graph) 229 | preds[k]['graph'] = graph 230 | 231 | print('id:', k0, ' v:', preds[k0]) 232 | 233 | sgg_evaluator.update(preds) 234 | sgg_evaluator.synchronize_between_processes() 235 | 236 | sgg_res = sgg_evaluator.accumulate() 237 | sgg_evaluator.summarize() 238 | sgg_evaluator.reset() 239 | 240 | 241 | print("whole ng recall list:", mR_list) 242 | print(f'whole ngR: {np.mean(ngR) * 100:.2f}') 243 | print(f'whole mean of ngR: {sum([e[1] for e in mR_list]) / len(mR_list) * 100:.2f}') 244 | print(f'ngR per image:{ngR_per_image * 100:.2f}') 245 | print(f'mean ngR per image: {sum([e[1] for e in 
mR_per_image]) / len(mR_per_image) * 100:.2f}') 246 | print(f'mean ngR list:{mR_per_image}') 247 | -------------------------------------------------------------------------------- /tests/test_fsdp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.distributed as dist 4 | import os 5 | import time 6 | from io import BytesIO 7 | import base64 8 | import json 9 | from contextlib import nullcontext 10 | 11 | import deepspeed 12 | from accelerate.utils import DistributedType 13 | from accelerate import Accelerator 14 | from torch.distributed.fsdp import FullyShardedDataParallel as FSDP 15 | from transformers import Qwen2VLProcessor, Qwen2VLForConditionalGeneration 16 | 17 | from open_r1.trainer.utils.vllm_client_v2 import VLLMClient 18 | from datasets import load_dataset 19 | from PIL import Image 20 | from tqdm import tqdm 21 | def encode_image_to_base64(image: Image.Image, format: str = "JPEG") -> str: 22 | buffer = BytesIO() 23 | image.save(buffer, format=format) 24 | return base64.b64encode(buffer.getvalue()).decode("utf-8") 25 | 26 | 27 | def prepare_messages(image): 28 | encoded_image_text = encode_image_to_base64(image) 29 | base64_qwen = f"data:image/jpeg;base64,{encoded_image_text}" 30 | 31 | messages_vllm = [ 32 | {"role": "system", "content": "You are a helpful assistant."}, 33 | { 34 | "role": "user", 35 | "content": [ 36 | {"type": "image_url", "image_url": {"url": base64_qwen}}, 37 | {"type": "text", "text": "Describe this image."}, 38 | ], 39 | }, 40 | ] 41 | 42 | return messages_vllm 43 | 44 | 45 | 46 | def main(): 47 | model_name = "Qwen/Qwen2-VL-7B-Instruct" 48 | 49 | accelerator = Accelerator() 50 | device = accelerator.device 51 | 52 | processor = Qwen2VLProcessor.from_pretrained(model_name, max_pixels=512*28*28) 53 | model = Qwen2VLForConditionalGeneration.from_pretrained( 54 | model_name, 55 | torch_dtype=torch.bfloat16, 56 | attn_implementation='flash_attention_2' 57 | ) 58 | model = accelerator.prepare_model(model) 59 | 60 | 61 | """ test clients """ 62 | def get_gateway_client_id(world_size, rank, gpus_per_node, num_clients): 63 | num_nodes = world_size // gpus_per_node 64 | client_ranks = [ 65 | (i % num_nodes) * gpus_per_node + (i // num_nodes) 66 | for i in range(num_clients) 67 | ] 68 | if rank in client_ranks: 69 | return client_ranks.index(rank) 70 | return None 71 | 72 | rank = accelerator.process_index 73 | world_size = accelerator.num_processes 74 | gpus_per_node = torch.cuda.device_count() 75 | 76 | 77 | hosts, ports = [], [] 78 | for line in open("ip_port_test.txt"): 79 | host, port = line.strip().split(':') 80 | hosts.append(host.strip()) 81 | ports.append(port.strip()) 82 | 83 | num_clients = len(hosts) 84 | client_id = get_gateway_client_id(world_size, rank, gpus_per_node, num_clients) 85 | 86 | # create N=len(hosts) clients 87 | if client_id is not None: 88 | vllm_client = VLLMClient( 89 | hosts, ports, 90 | connection_timeout=360, 91 | client_rank = client_id 92 | ) 93 | print("*"*100, "\n Create VLLMClient at rank:", rank, " client_rank:", client_id) 94 | else: 95 | vllm_client = None 96 | 97 | 98 | """ test chat """ 99 | 100 | if accelerator.is_main_process: 101 | 102 | db = load_dataset("JosephZ/vg150_val_sgg_prompt")['train'] 103 | prompts = [] 104 | for kk, item in enumerate(tqdm(db)): 105 | if len(prompts) >=128: break 106 | prompt = prepare_messages(item['image']) 107 | prompts.append(prompt) 108 | 109 | print("[INFO] Running vLLM inference...") 110 | t0 = time.time() 111 | prompts = [json.dumps(e) 
for e in prompts] 112 | print(len(prompts)) 113 | 114 | generated_ids = vllm_client.loop.run_until_complete(vllm_client.chat(prompts, n=1, max_tokens=50, 115 | top_p=0.001, top_k=1, temperature=0.01)) 116 | 117 | t1 = time.time() - t0 118 | #generated_ids = [torch.as_tensor(e) for e in generated_ids] 119 | outputs = processor.batch_decode(generated_ids, skip_special_tokens=True) 120 | print(len(outputs)) 121 | print("****** vLLM generated text:", 122 | outputs, 123 | " cost:", t1) 124 | 125 | """ test weight synchronization """ 126 | max_chunk_size = 100 * 1024 * 1024 # 100 MB 127 | param_chunk = [] 128 | current_chunk_size = 0 129 | debug_file = "tests/debug_%s.log" % accelerator.process_index 130 | with open(debug_file, 'w') as fout: 131 | pass 132 | 133 | is_fsdp_used = accelerator.distributed_type == DistributedType.FSDP 134 | deepspeed_plugin = accelerator.state.deepspeed_plugin 135 | zero_stage_3 = deepspeed_plugin is not None and deepspeed_plugin.zero_stage == 3 136 | gather_if_zero3 = deepspeed.zero.GatheredParameters if zero_stage_3 else nullcontext 137 | 138 | if is_fsdp_used: 139 | print("*"*100, "\n Test FSDP ...\n", "*"*100) 140 | with FSDP.summon_full_params(model, recurse=True, writeback=False): 141 | for name, param in model.named_parameters(): 142 | if param.data is None: 143 | continue 144 | if vllm_client is not None: 145 | # Calculate the size of this parameter in bytes 146 | param_size = param.numel() * param.element_size() 147 | 148 | param_chunk.append((name, param.data)) 149 | current_chunk_size += param_size 150 | 151 | # When the accumulated chunk reaches or exceeds 100MB, update the model parameters in one chunk. 152 | if current_chunk_size >= max_chunk_size: 153 | if os.path.exists(debug_file): 154 | with open(debug_file, 'a') as fout: 155 | names = [(p[0], p[1].shape) for p in param_chunk] 156 | cmd = f"FSDP --- rank={accelerator.process_index}, send params={names}\n" 157 | fout.write(cmd) 158 | vllm_client.update_model_in_chunks_from_named_list(param_chunk) 159 | # Reset for the next chunk 160 | param_chunk = [] 161 | current_chunk_size = 0 162 | else: 163 | print("*"*100, "\n Test non-FSDP ...\n", "*"*100) 164 | for name, param in model.named_parameters(): 165 | with gather_if_zero3([param]): # gather if zero3 used 166 | if vllm_client is not None: 167 | # Calculate the size of this parameter in bytes 168 | param_size = param.numel() * param.element_size() 169 | 170 | param_chunk.append((name, param.data)) 171 | current_chunk_size += param_size 172 | 173 | # When the accumulated chunk reaches or exceeds 100MB, update the model parameters in one chunk. 174 | if current_chunk_size >= max_chunk_size: 175 | if os.path.exists(debug_file): 176 | with open(debug_file, 'a') as fout: 177 | names = [(p[0], p[1].shape) for p in param_chunk] 178 | cmd = f"rank={accelerator.process_index}, send params={names}\n" 179 | fout.write(cmd) 180 | vllm_client.update_model_in_chunks_from_named_list(param_chunk) 181 | # Reset for the next chunk 182 | param_chunk = [] 183 | current_chunk_size = 0 184 | 185 | # If any parameters remain that didn't reach the 100MB threshold, update them as well. 
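# (Without this final flush, any parameters left over after the last full 100 MB chunk would never be sent, leaving the vLLM server with stale weights.)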
186 | if param_chunk and vllm_client is not None: 187 | if os.path.exists(debug_file): 188 | with open(debug_file, 'a') as fout: 189 | names = [(p[0], p[1].shape) for p in param_chunk] 190 | cmd = f"rank={accelerator.process_index}, send params={names}\n" 191 | fout.write(cmd) 192 | vllm_client.update_model_in_chunks_from_named_list(param_chunk) 193 | 194 | 195 | 196 | 197 | 198 | if __name__ == "__main__": 199 | main() 200 | -------------------------------------------------------------------------------- /tests/test_sampler.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from typing import Optional 3 | 4 | from transformers import Trainer, TrainingArguments 5 | 6 | from torch.utils.data import Sampler 7 | from accelerate.utils import broadcast_object_list, gather, gather_object, is_peft_model, set_seed 8 | 9 | BATCH_PER_DEVICE=2 10 | GRAD_ACC=3 11 | # 2*4*3 // 8 = 3 12 | 13 | class RepeatRandomSampler(Sampler): 14 | """ 15 | Sampler that repeats the indices of a dataset in a structured manner. 16 | 17 | Args: 18 | data_source (`Sized`): 19 | Dataset to sample from. 20 | mini_repeat_count (`int`): 21 | Number of times to repeat each index per batch. 22 | batch_size (`int`, *optional*, defaults to `1`): 23 | Number of unique indices per batch. 24 | repeat_count (`int`, *optional*, defaults to `1`): 25 | Number of times to repeat the full sampling process. 26 | seed (`int` or `None`, *optional*, defaults to `None`): 27 | Random seed for reproducibility (only affects this sampler). 28 | 29 | Example: 30 | ```python 31 | >>> sampler = RepeatRandomSampler(["a", "b", "c", "d", "e", "f", "g"], mini_repeat_count=2, batch_size=3, repeat_count=4) 32 | >>> list(sampler) 33 | [4, 4, 3, 3, 0, 0, 34 | 4, 4, 3, 3, 0, 0, 35 | 4, 4, 3, 3, 0, 0, 36 | 4, 4, 3, 3, 0, 0, 37 | 38 | 1, 1, 2, 2, 6, 6, 39 | 1, 1, 2, 2, 6, 6, 40 | 1, 1, 2, 2, 6, 6, 41 | 1, 1, 2, 2, 6, 6] 42 | ``` 43 | 44 | ```txt 45 | mini_repeat_count = 3 46 | - - - 47 | [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, | 48 | 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, | 49 | 8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, | 50 | repeat_count = 2 51 | 0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, | 52 | 4, 4, 4, 5, 5, 5, 6, 6, 6, 7, 7, 7, | 53 | 8, 8, 8, 9, 9, 9, 10, 10, 10, 11, 11, 11, ...] 
| 54 | --------- --------- --------- --------- 55 | --------- --------- --------- --------- 56 | --------- --------- --------- --------- 57 | batch_size = 12 58 | ``` 59 | """ 60 | 61 | def __init__( 62 | self, 63 | data_source , 64 | mini_repeat_count: int, 65 | batch_size: int = 1, 66 | repeat_count: int = 1, 67 | seed: Optional[int] = None, 68 | ): 69 | self.data_source = data_source 70 | self.mini_repeat_count = mini_repeat_count 71 | self.batch_size = batch_size 72 | self.repeat_count = repeat_count 73 | self.num_samples = len(data_source) 74 | self.seed = seed 75 | self.generator = torch.Generator() # Create a local random generator 76 | if seed is not None: 77 | self.generator.manual_seed(seed) 78 | 79 | def __iter__(self): 80 | # E.g., [2, 4, 3, 1, 0, 6, 5] (num_samples = 7) 81 | indexes = torch.randperm(self.num_samples, generator=self.generator).tolist() 82 | 83 | # [2, 4, 3, 1, 0, 6, 5] 84 | # -> [[2, 4, 3], [1, 0, 6], [5]] (batch_size = 3) 85 | indexes = [indexes[i : i + self.batch_size] for i in range(0, len(indexes), self.batch_size)] 86 | 87 | # [[2, 4, 3], [1, 0, 6], [5]] 88 | # -> [[2, 4, 3], [1, 0, 6]] 89 | indexes = [chunk for chunk in indexes if len(chunk) == self.batch_size] 90 | 91 | for chunk in indexes: 92 | for _ in range(self.repeat_count): 93 | for index in chunk: 94 | for _ in range(self.mini_repeat_count): 95 | yield index 96 | 97 | def __len__(self) -> int: 98 | return self.num_samples * self.mini_repeat_count * self.repeat_count 99 | 100 | 101 | 102 | 103 | 104 | # Dummy dataset 105 | class DummyDataset(torch.utils.data.Dataset): 106 | def __init__(self, size=100): 107 | self.data = list(range(size)) 108 | 109 | def __len__(self): 110 | return len(self.data) 111 | 112 | def __getitem__(self, idx): 113 | return { 114 | "input_ids": torch.tensor([idx]), 115 | "labels": torch.tensor([idx]), 116 | } 117 | 118 | # Dummy model 119 | class DummyModel(torch.nn.Module): 120 | def __init__(self): 121 | super().__init__() 122 | self.linear = torch.nn.Linear(1, 1) 123 | 124 | def forward(self, input_ids=None, labels=None): 125 | outputs = self.linear(input_ids.float()) 126 | return {"logits": outputs} 127 | 128 | 129 | class MyTrainer(Trainer): 130 | def _get_train_sampler(self): 131 | effective_batch_size = BATCH_PER_DEVICE * self.accelerator.num_processes * GRAD_ACC 132 | print("effective_batch_size:", effective_batch_size) 133 | return RepeatRandomSampler( 134 | data_source=self.train_dataset, 135 | mini_repeat_count=8, 136 | batch_size=effective_batch_size//8, 137 | repeat_count=1*GRAD_ACC, 138 | seed=self.args.seed, 139 | ) 140 | 141 | def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): 142 | input_ids = inputs['input_ids'] 143 | labels = inputs['labels'] 144 | all_labels = self.accelerator.gather(labels) 145 | 146 | rank = self.accelerator.process_index 147 | if rank == 0: 148 | print("rank:", self.accelerator.process_index, "global_step:", self.state.global_step, "labels:", all_labels.tolist(), "\n") 149 | 150 | # Forward pass 151 | outputs = model(input_ids) 152 | 153 | logits = outputs['logits'] 154 | loss = torch.nn.functional.mse_loss(logits, labels.float()) 155 | 156 | if return_outputs: 157 | return loss, outputs 158 | return loss 159 | 160 | # Run dummy training 161 | if __name__ == "__main__": 162 | dataset = DummyDataset(size=300) 163 | model = DummyModel() 164 | 165 | training_args = TrainingArguments( 166 | per_device_train_batch_size=BATCH_PER_DEVICE*GRAD_ACC, 167 | gradient_accumulation_steps=GRAD_ACC, 168 | 
num_train_epochs=1, 169 | logging_steps=1, 170 | save_steps=10, 171 | logging_dir="./logs", 172 | disable_tqdm=False, 173 | seed=42, 174 | report_to=None 175 | ) 176 | 177 | trainer = MyTrainer( 178 | model=model, 179 | args=training_args, 180 | train_dataset=dataset, 181 | ) 182 | 183 | trainer.train() 184 | -------------------------------------------------------------------------------- /tests/test_vllm.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import base64 3 | from io import BytesIO 4 | import json 5 | import time 6 | from typing import List 7 | from transformers import AutoProcessor 8 | 9 | import torch 10 | from PIL import Image 11 | 12 | from open_r1.trainer.utils.vllm_client_v2 import VLLMClient 13 | from datasets import load_dataset 14 | from tqdm import tqdm 15 | 16 | from transformers import Qwen2VLForConditionalGeneration 17 | 18 | SYSTEM_PROMPT = ( 19 | "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant " 20 | "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning " 21 | "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., " 22 | "<think> reasoning process here </think> <answer> answer here </answer>" 23 | ) 24 | 25 | 26 | def encode_image_to_base64(image: Image.Image, format: str = "JPEG") -> str: 27 | buffer = BytesIO() 28 | image.save(buffer, format=format) 29 | return base64.b64encode(buffer.getvalue()).decode("utf-8") 30 | 31 | 32 | def prepare_messages(item): 33 | def replace_answer_format(item: str) -> str: 34 | return item.replace("<answer>", "```json").replace("</answer>", "```") 35 | 36 | image = item['image'] 37 | org_iw, org_ih = image.size 38 | 39 | prompt = item['prompt_open'] 40 | prompt = prompt.replace(f"of size ({org_iw} x {org_ih}) ", "") 41 | prompt = replace_answer_format(prompt) 42 | 43 | encoded_image_text = encode_image_to_base64(image) 44 | base64_qwen = f"data:image/jpeg;base64,{encoded_image_text}" 45 | 46 | messages_vllm = [ 47 | {"role": "system", 48 | "content": SYSTEM_PROMPT 49 | }, 50 | { 51 | "role": "user", 52 | "content": [ 53 | {"type": "image_url", "image_url": {"url": base64_qwen}}, 54 | {"type": "text", "text": prompt}, 55 | ], 56 | }, 57 | ] 58 | 59 | return messages_vllm 60 | 61 | 62 | 63 | 64 | def main(args): 65 | db = load_dataset("JosephZ/vg150_val_sgg_prompt")['train'] 66 | 67 | print(f"[INFO] Connecting to vLLM server at {args.hosts}:{args.server_port}") 68 | processor = AutoProcessor.from_pretrained(args.model_name_or_path) 69 | prompts = [] 70 | for kk, item in enumerate(tqdm(db)): 71 | if kk > 10: break 72 | prompt = prepare_messages(item) 73 | prompts.append(prompt) 74 | 75 | 76 | client = VLLMClient( 77 | hosts=args.hosts, #.split(','), 78 | server_ports=args.server_port, 79 | group_port=args.group_port, 80 | connection_timeout=60, 81 | ) 82 | model = Qwen2VLForConditionalGeneration.from_pretrained( 83 | "Qwen/Qwen2-VL-7B-Instruct", torch_dtype="auto", device_map="auto" 84 | ) 85 | for name, param in model.named_parameters(): 86 | client.update_named_param(name, param) 87 | 88 | print("[INFO] Running vLLM inference...") 89 | t0 = time.time() 90 | prompts = [json.dumps(e) for e in prompts] 91 | print(len(prompts)) 92 | 93 | generated_ids = client.loop.run_until_complete(client.chat(prompts, n=8, max_tokens=1024, 94 | top_p=0.001, top_k=1, temperature=0.01)) 95 | 96 | t1 = time.time() - t0 97 | #generated_ids = [torch.as_tensor(e) for e in generated_ids] 98 | outputs = 
processor.batch_decode(generated_ids, skip_special_tokens=True), 99 | print(len(outputs)) 100 | print("****** vLLM generated text:", 101 | outputs[0][0], 102 | " cost:", t1) 103 | 104 | 105 | 106 | def cal_cost(client, model, lens): 107 | cost = [] 108 | for i in range(3): 109 | t0 = time.time() 110 | #client.update_model_in_chunks(model, lens) 111 | 112 | named_params = list(model.named_parameters()) 113 | chunk_size = lens # or tune based on memory 114 | 115 | for i in range(0, len(named_params), chunk_size): 116 | chunk = named_params[i:i+chunk_size] 117 | client.update_model_in_chunks_from_named_list(chunk) 118 | 119 | t1 = time.time() 120 | cost.append(t1-t0) 121 | return sum(cost)/len(cost) 122 | 123 | def cal_cost_by_size(client, model, max_bytes): 124 | cost = [] 125 | for i in range(3): 126 | t0 = time.time() 127 | chunks = [] # List to accumulate (name, param) tuples 128 | current_chunk_bytes = 0 # Accumulated memory size in bytes 129 | 130 | for name, param in model.named_parameters(): 131 | param_bytes = param.numel() * param.element_size() 132 | 133 | # If adding this parameter would exceed the max_bytes limit 134 | if current_chunk_bytes + param_bytes > max_bytes: 135 | # Process the current chunk if not empty 136 | if chunks: 137 | client.update_model_in_chunks_from_named_list(chunks) 138 | chunks = [] 139 | current_chunk_bytes = 0 140 | 141 | # If the parameter itself exceeds max_bytes, process it individually 142 | if param_bytes > max_bytes: 143 | client.update_model_in_chunks_from_named_list([(name, param)]) 144 | else: 145 | # Otherwise, add the parameter to the current chunk 146 | chunks.append((name, param)) 147 | current_chunk_bytes += param_bytes 148 | 149 | # Process any remaining parameters 150 | if chunks: 151 | client.update_model_in_chunks_from_named_list(chunks) 152 | 153 | t1 = time.time() 154 | cost.append(t1 - t0) 155 | return sum(cost) / len(cost) 156 | 157 | 158 | 159 | for k in range(1, 10): 160 | try: 161 | GB = (1<<30) * 0.1 * k 162 | print(f"update cost with chunk size={k} GB:", cal_cost_by_size(client, model, GB)) 163 | except: 164 | print("Timeout at", k) 165 | break 166 | 167 | 168 | if __name__ == "__main__": 169 | parser = argparse.ArgumentParser() 170 | parser.add_argument("--hosts", type=str, default="[127.0.0.1]", help="Host address of the vLLM server.") 171 | parser.add_argument("--server_port", type=str, default='8000', help="Port for vLLM API requests.") 172 | parser.add_argument("--group_port", type=int, default=51216, help="Port for NCCL communication.") 173 | parser.add_argument("--model_name_or_path", type=str, default="Qwen/Qwen2-VL-7B-Instruct", help="Model ID or path.") 174 | args = parser.parse_args() 175 | main(args) 176 | -------------------------------------------------------------------------------- /tests/test_vllm_local.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import base64 3 | from io import BytesIO 4 | import json 5 | import time 6 | import random 7 | from typing import List 8 | from transformers import AutoProcessor 9 | 10 | import torch 11 | from PIL import Image 12 | 13 | from open_r1.trainer.utils.vllm_client_v2 import VLLMClient 14 | from datasets import load_dataset 15 | from tqdm import tqdm 16 | 17 | from transformers import Qwen2VLForConditionalGeneration 18 | from transformers import FineGrainedFP8Config 19 | 20 | SYSTEM_PROMPT = ( 21 | "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. 
The assistant " 22 | "first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning " 23 | "process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., " 24 | "<think> reasoning process here </think> <answer> answer here </answer>" 25 | ) 26 | 27 | 28 | def encode_image_to_base64(image: Image.Image, format: str = "JPEG") -> str: 29 | buffer = BytesIO() 30 | image.save(buffer, format=format) 31 | return base64.b64encode(buffer.getvalue()).decode("utf-8") 32 | 33 | 34 | def prepare_messages(item): 35 | def replace_answer_format(item: str) -> str: 36 | return item.replace("<answer>", "```json").replace("</answer>", "```") 37 | 38 | image = item['image'] 39 | org_iw, org_ih = image.size 40 | 41 | prompt = item['prompt_open'] 42 | prompt = prompt.replace(f"of size ({org_iw} x {org_ih}) ", "") 43 | prompt = replace_answer_format(prompt) 44 | 45 | encoded_image_text = encode_image_to_base64(image) 46 | base64_qwen = f"data:image/jpeg;base64,{encoded_image_text}" 47 | 48 | messages_vllm = [ 49 | {"role": "system", 50 | "content": SYSTEM_PROMPT 51 | }, 52 | { 53 | "role": "user", 54 | "content": [ 55 | {"type": "image_url", "image_url": {"url": base64_qwen}}, 56 | {"type": "text", "text": prompt}, 57 | ], 58 | }, 59 | ] 60 | return messages_vllm 61 | 62 | 63 | def main(args): 64 | db = load_dataset("JosephZ/vg150_val_sgg_prompt")['train'] 65 | 66 | processor = AutoProcessor.from_pretrained(args.model_name_or_path, max_pixels=512*28*28) 67 | prompts = [] 68 | for kk, item in enumerate(tqdm(db)): 69 | if kk > 200: break 70 | prompt = prepare_messages(item) 71 | prompts.append(prompt) 72 | 73 | use_fp8 = True 74 | 75 | quantization_config = FineGrainedFP8Config() 76 | kwargs = {"device_map": "cuda:0", 'attn_implementation': 'flash_attention_2'} 77 | if use_fp8: 78 | kwargs['quantization_config'] = quantization_config 79 | 80 | 81 | client = VLLMClient( 82 | local_vllm=True, 83 | model_name=args.model_name_or_path, 84 | max_pixels=512*28*28, 85 | use_fp8=use_fp8, 86 | device='cuda:0', 87 | gpu_memory_utilization=0.9 88 | ) 89 | 90 | 91 | 92 | 93 | print("[INFO] Running vLLM inference...") 94 | t0 = time.time() 95 | print("len(prompts):", len(prompts)) 96 | 97 | generated_ids = client.run_chat([prompts[0]], n=8, max_tokens=1024, 98 | top_p=0.95, top_k=50, temperature=1.0) 99 | 100 | t1 = time.time() - t0 101 | #generated_ids = [torch.as_tensor(e) for e in generated_ids] 102 | outputs = processor.batch_decode(generated_ids, skip_special_tokens=True), 103 | print(len(outputs)) 104 | print("****** vLLM generated text:") 105 | for i in range(8): 106 | print(outputs[0][i]) 107 | 108 | print(" cost:", t1) 109 | # benchmark speed 110 | cost = [] 111 | t0 = time.time() 112 | BS = 8 113 | N = 5 114 | for i in tqdm(range(N)): 115 | s0 = time.time() 116 | generated_ids = client.run_chat(prompts[2+i*BS: 2+(i+1)*BS], n=8, max_tokens=1024, 117 | top_p=0.95, top_k=50, temperature=1.0) 118 | s1 = time.time() 119 | cost.append(s1-s0) 120 | 121 | t1 = time.time() 122 | print("Benchmark speed :", (t1-t0)/ (BS*N), "second / item", " Batch cost: ", sum(cost)/len(cost) ) 123 | quit() 124 | model = Qwen2VLForConditionalGeneration.from_pretrained( 125 | "Qwen/Qwen2-VL-7B-Instruct", torch_dtype=torch.bfloat16, 126 | **kwargs 127 | ) 128 | 129 | 130 | # check weight sync. 
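# The local vLLM engine exposes its underlying torch module via the driver worker's model_runner (see llmp below); comparing those tensors against the HF model verifies the weights match, while parameters that do not appear under the same name on the vLLM side (e.g. fused or renamed weights) are collected in miss1/miss2 instead of being compared.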
131 | llmp = client.llm.llm_engine.model_executor.driver_worker.model_runner.model 132 | llmp_dicts = dict(llmp.named_parameters()) 133 | 134 | miss1 = [] 135 | with torch.no_grad(): 136 | for name, param in model.named_parameters(): 137 | # Fetch corresponding param from llmp 138 | llmp_param = llmp_dicts.get(name, None) 139 | if llmp_param is None: 140 | #print(f"[WARN] Parameter {name} not found in vLLM model.") 141 | miss1.append( (name, param.data.min()) ) 142 | param.data += 100 143 | continue 144 | 145 | # Compare tensors 146 | if not torch.allclose(param.data, llmp_param.data, atol=1e-5): 147 | print(f"[FAIL] Mismatch in param '{name}'") 148 | else: 149 | print(f"[PASS] Param '{name}' is synchronized.") 150 | param.data += 100 151 | 152 | 153 | print("\n", "*"*100, "start weight synchronization ...\n", "*"*100, "\n") 154 | max_chunk_size = 100 * 1024 * 1024 # 100 MB 155 | param_chunk = [] 156 | current_chunk_size = 0 157 | del llmp_dicts 158 | 159 | t0 = time.time() 160 | updated_params = set() 161 | for name, param in model.named_parameters(): 162 | # Calculate the size of this parameter in bytes 163 | param_size = param.numel() * param.element_size() 164 | 165 | param_chunk.append((name, param.data)) 166 | current_chunk_size += param_size 167 | 168 | # When the accumulated chunk reaches or exceeds 100MB, update the model parameters in one chunk. 169 | if current_chunk_size >= max_chunk_size: 170 | old = client.update_model_in_chunks_from_named_list(param_chunk) 171 | updated_params.update(old) 172 | # Reset for the next chunk 173 | param_chunk = [] 174 | current_chunk_size = 0 175 | 176 | if param_chunk and client is not None: 177 | client.update_model_in_chunks_from_named_list(param_chunk) 178 | t1 = time.time() 179 | print("weight synchronization cost:", t1-t0) 180 | # check again 181 | llmp_dicts = dict(llmp.named_parameters()) 182 | 183 | miss2 = [] 184 | for name, param in model.named_parameters(): 185 | # Fetch corresponding param from llmp 186 | llmp_param = llmp_dicts.get(name, None) 187 | if llmp_param is None: 188 | miss2.append((name, param.data.min())) 189 | #print(f"[WARN] Parameter {name} not found in vLLM model.") 190 | continue 191 | # Compare tensors 192 | if not torch.allclose(param.data, llmp_param.data, atol=1e-5): 193 | print(f"[FAIL] Mismatch in param '{name}'") 194 | else: 195 | print(f"[PASS] Param '{name}' is synchronized.") 196 | 197 | #import pdb; pdb.set_trace() 198 | 199 | def cal_cost(client, model, lens): 200 | cost = [] 201 | for i in range(3): 202 | t0 = time.time() 203 | #client.update_model_in_chunks(model, lens) 204 | 205 | named_params = list(model.named_parameters()) 206 | chunk_size = lens # or tune based on memory 207 | 208 | for i in range(0, len(named_params), chunk_size): 209 | chunk = named_params[i:i+chunk_size] 210 | client.update_model_in_chunks_from_named_list(chunk) 211 | 212 | t1 = time.time() 213 | cost.append(t1-t0) 214 | return sum(cost)/len(cost) 215 | 216 | def cal_cost_by_size(client, model, max_bytes): 217 | cost = [] 218 | for i in range(3): 219 | t0 = time.time() 220 | chunks = [] # List to accumulate (name, param) tuples 221 | current_chunk_bytes = 0 # Accumulated memory size in bytes 222 | 223 | for name, param in model.named_parameters(): 224 | param_bytes = param.numel() * param.element_size() 225 | 226 | # If adding this parameter would exceed the max_bytes limit 227 | if current_chunk_bytes + param_bytes > max_bytes: 228 | # Process the current chunk if not empty 229 | if chunks: 230 | 
client.update_model_in_chunks_from_named_list(chunks) 231 | chunks = [] 232 | current_chunk_bytes = 0 233 | 234 | # If the parameter itself exceeds max_bytes, process it individually 235 | if param_bytes > max_bytes: 236 | client.update_model_in_chunks_from_named_list([(name, param)]) 237 | else: 238 | # Otherwise, add the parameter to the current chunk 239 | chunks.append((name, param)) 240 | current_chunk_bytes += param_bytes 241 | 242 | # Process any remaining parameters 243 | if chunks: 244 | client.update_model_in_chunks_from_named_list(chunks) 245 | 246 | t1 = time.time() 247 | cost.append(t1 - t0) 248 | return sum(cost) / len(cost) 249 | 250 | 251 | 252 | for k in range(1, 10): 253 | try: 254 | GB = (1<<30) * 0.1 * k 255 | print(f"update cost with chunk size={k} GB:", cal_cost_by_size(client, model, GB)) 256 | except: 257 | print("Timeout at", k) 258 | break 259 | 260 | 261 | if __name__ == "__main__": 262 | parser = argparse.ArgumentParser() 263 | parser.add_argument("--model_name_or_path", type=str, default="Qwen/Qwen2-VL-7B-Instruct", help="Model ID or path.") 264 | args = parser.parse_args() 265 | main(args) 266 | 267 | --------------------------------------------------------------------------------