├── .gitignore ├── LICENSE ├── README.md ├── assets ├── overview-2.jpg └── overview-2.pdf ├── installation.md ├── openrlhf ├── __init__.py ├── cli │ ├── __init__.py │ ├── eval_ray.py │ └── train_ppo_ray.py ├── datasets │ ├── __init__.py │ ├── prompts_dataset.py │ └── utils.py ├── models │ ├── __init__.py │ ├── actor.py │ ├── loss.py │ ├── model.py │ ├── ring_attn_utils.py │ └── utils.py ├── trainer │ ├── __init__.py │ ├── evaluator.py │ ├── ppo_trainer.py │ ├── ppo_utils │ │ ├── __init__.py │ │ ├── data_processor.py │ │ ├── experience_maker.py │ │ ├── kl_controller.py │ │ └── replay_buffer.py │ └── ray │ │ ├── __init__.py │ │ ├── evaluator2.py │ │ ├── launcher.py │ │ ├── ppo_actor.py │ │ ├── ppo_critic.py │ │ ├── utils.py │ │ ├── vllm_engine.py │ │ └── vllm_worker_wrap.py └── utils │ ├── __init__.py │ ├── deepspeed │ ├── __init__.py │ ├── deepspeed.py │ └── deepspeed_utils.py │ ├── distributed_sampler.py │ ├── distributed_util.py │ ├── logging_utils.py │ ├── processor.py │ └── utils.py ├── pyproject.toml ├── requirements.txt ├── scripts ├── eval_7b.sh ├── eval_vlm_new.sh └── train_vlm_multi.sh └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # UV 98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 
101 | #uv.lock 102 | 103 | # poetry 104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 105 | # This is especially recommended for binary packages to ensure reproducibility, and is more 106 | # commonly ignored for libraries. 107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 108 | #poetry.lock 109 | 110 | # pdm 111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 112 | #pdm.lock 113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 114 | # in version control. 115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control 116 | .pdm.toml 117 | .pdm-python 118 | .pdm-build/ 119 | 120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 121 | __pypackages__/ 122 | 123 | # Celery stuff 124 | celerybeat-schedule 125 | celerybeat.pid 126 | 127 | # SageMath parsed files 128 | *.sage.py 129 | 130 | # Environments 131 | .env 132 | .venv 133 | env/ 134 | venv/ 135 | ENV/ 136 | env.bak/ 137 | venv.bak/ 138 | 139 | # Spyder project settings 140 | .spyderproject 141 | .spyproject 142 | 143 | # Rope project settings 144 | .ropeproject 145 | 146 | # mkdocs documentation 147 | /site 148 | 149 | # mypy 150 | .mypy_cache/ 151 | .dmypy.json 152 | dmypy.json 153 | 154 | # Pyre type checker 155 | .pyre/ 156 | 157 | # pytype static type analyzer 158 | .pytype/ 159 | 160 | # Cython debug symbols 161 | cython_debug/ 162 | 163 | # PyCharm 164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 166 | # and can be added to the global gitignore or merged into this file. For a more nuclear 167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 168 | #.idea/ 169 | 170 | # Ruff stuff: 171 | .ruff_cache/ 172 | 173 | # PyPI configuration file 174 | .pypirc 175 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VL-Rethinker: Incentivizing Self-Reflection of Vision-Language Models with Reinforcement Learning 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 |
16 | 17 | 18 | 19 | Authors: 20 | Haozhe Wang, 21 | Chao Qu, 22 | Zuming Huang, 23 | Wei Chu, 24 | Fangzhen Lin, 25 | Wenhu Chen 26 | 27 | ## 🔥News 28 | 29 | - [2025/4/22] We release the dataset [🤗 ViRL39K](https://huggingface.co/datasets/TIGER-Lab/ViRL39K). It is a **comprehensive collection** of 39K queries spanning **eight categories**, and it provides fine-grained **model-capability annotations** for data selection. 30 | 31 | 32 | ## Overview 33 | ![overview](./assets/overview-2.jpg) 34 | 35 |
Abstract 36 | Recently, slow-thinking systems like GPT-o1 and DeepSeek-R1 have demonstrated great potential in solving challenging problems through explicit reflection. They significantly outperform the best fast-thinking models, such as GPT-4o, on various math and science benchmarks. However, their multimodal reasoning capabilities remain on par with fast-thinking models. For instance, GPT-o1's performance on benchmarks like MathVista, MathVerse, and MathVision is similar to that of fast-thinking models. In this paper, we aim to enhance the slow-thinking capabilities of vision-language models using reinforcement learning (without relying on distillation) to advance the state of the art. First, we adapt the GRPO algorithm with a novel technique called Selective Sample Replay (SSR) to address the vanishing advantages problem. While this approach yields strong performance, the resulting RL-trained models exhibit limited self-reflection or self-verification. 37 | To further encourage slow thinking, we introduce Forced Rethinking, which appends a textual rethinking trigger to the end of initial rollouts in RL training, explicitly enforcing a self-reflection reasoning step. By combining these two techniques, our model, VL-Rethinker, advances the state-of-the-art scores on MathVista, MathVerse, and MathVision to 80.3%, 61.8%, and 43.9%, respectively. VL-Rethinker also achieves open-source SoTA on multi-disciplinary benchmarks such as MMMU-Pro, EMMA, and MEGA-Bench, narrowing the gap with GPT-o1. Our empirical results show the effectiveness of our approaches. 38 | 39 |
40 | 41 | ## Release Progress 42 | - [x] models. 43 | - [x] data. 44 | - [ ] inference and evaluation code. 45 | - [x] training code. 46 | 47 | ### Dataset 48 | **[ViRL39K](https://huggingface.co/datasets/TIGER-Lab/ViRL39K)** lays the foundation for our RL training. It has the following merits: 49 | - **high-quality** and **verifiable**: the QAs undergo rigorous filtering and quality control, removing problematic queries or ones that cannot be verified by rules. 50 | - covering **comprehensive** topics and categories: from grade-school problems to broader STEM and social topics; reasoning with charts, diagrams, tables, documents, spatial relationships, etc. 51 | - with fine-grained **model-capability annotations**: they tell you which queries to use when training models at different scales. 52 | 53 | 54 | ### RL-ed Models 55 | - [VL-Rethinker-7B](https://huggingface.co/TIGER-Lab/VL-Rethinker-7B): trained with the proposed SSR and Forced Rethinking, starting from Qwen2.5-VL-7B-Instruct. 56 | - [VL-Rethinker-72B](https://huggingface.co/TIGER-Lab/VL-Rethinker-72B): trained with the proposed SSR and Forced Rethinking, starting from Qwen2.5-VL-72B-Instruct. 57 | 58 | We are training a 32B model and further enhancing these models. Stay tuned! 59 | 60 | 61 | ## Performance 62 | See our [website](https://tiger-ai-lab.github.io/VL-Rethinker/) or [paper](https://arxiv.org/abs/2504.08837) for a detailed performance report. 63 | 64 | 65 | ## Selective Sample Replay (SSR) 66 | 67 | Training 72B models on publicly collected queries reveals "vanishing advantages," a phenomenon where rapid saturation in large models drastically reduces the number of effective training samples. The concurrent work [DAPO](https://arxiv.org/abs/2503.14476) on LLMs made a similar observation. 68 | 69 | DAPO combats this by filtering out ineffective queries for gradient stability. Different from this gradient perspective, our method, Selective Sample Replay (SSR), takes an active learning perspective. Drawing on the idea of Prioritized Experience Replay, SSR re-arranges training samples based on their informativeness -- examples with high advantages, which lie near the model's capability limits (i.e., correct responses to queries the model likely fails), are particularly informative. This active selection focuses training on the samples most likely to contribute to model improvement, thereby improving training efficiency. 70 | 71 | The implementation of SSR is simple; the full code is in `active_sampling()` in `openrlhf/trainer/ppo_utils/replay_buffer.py`. Here is pseudocode for the key idea of SSR. 72 | ```python 73 | effective_qas = rule_out_zero(candidates)  # filter out candidates with zero advantage (no learning signal) 74 | p = normalize_adv(effective_qas, alpha=1)  # normalize advantages into a sampling distribution 75 | selection = np.random.choice(np.arange(len(effective_qas)), size=size, p=p)  # advantage-weighted replay selection 76 | ``` 77 | 78 | Note: For different scenarios, e.g., on-policy or off-policy, the choices of `candidates` and `size` can differ. 79 | 80 | ## Inference 81 | Our models are built on top of the Qwen2.5-VL family, so we include a simple usage example here and refer readers to [the standard inference procedure of Qwen2.5-VL](https://github.com/QwenLM/Qwen2.5-VL) for details.
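The example below follows the standard Qwen2.5-VL usage and relies on the `qwen_vl_utils` helper package for `process_vision_info`. If your environment does not already provide these dependencies (e.g., via this repo's `requirements.txt`), one minimal setup (versions unpinned here) is:
```bash
# helper utilities for preparing image/video inputs (process_vision_info in the example below)
pip install qwen-vl-utils
# a recent transformers release with Qwen2.5-VL support, plus accelerate for device_map="auto"
pip install -U transformers accelerate
```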
82 | 83 | 84 | ```python 85 | from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor 86 | from qwen_vl_utils import process_vision_info 87 | 88 | # default: Load the model on the available device(s) 89 | model = Qwen2_5_VLForConditionalGeneration.from_pretrained( 90 | "TIGER-Lab/VL-Rethinker-7B", torch_dtype="auto", device_map="auto" 91 | ) 92 | 93 | # We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios. 94 | # model = Qwen2_5_VLForConditionalGeneration.from_pretrained( 95 | # "Qwen/Qwen2.5-VL-7B-Instruct", 96 | # torch_dtype=torch.bfloat16, 97 | # attn_implementation="flash_attention_2", 98 | # device_map="auto", 99 | # ) 100 | 101 | # default processor 102 | # processor = AutoProcessor.from_pretrained("TIGER-Lab/VL-Rethinker-7B") 103 | 104 | 105 | min_pixels = 256*28*28 106 | max_pixels = 1280*28*28 107 | processor = AutoProcessor.from_pretrained("TIGER-Lab/VL-Rethinker-7B", min_pixels=min_pixels, max_pixels=max_pixels) 108 | 109 | messages = [ 110 | { 111 | "role": "user", 112 | "content": [ 113 | { 114 | "type": "image", 115 | "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg", 116 | }, 117 | {"type": "text", "text": "Describe this image."}, 118 | ], 119 | } 120 | ] 121 | 122 | # Preparation for inference 123 | text = processor.apply_chat_template( 124 | messages, tokenize=False, add_generation_prompt=True 125 | ) 126 | image_inputs, video_inputs = process_vision_info(messages) 127 | inputs = processor( 128 | text=[text], 129 | images=image_inputs, 130 | videos=video_inputs, 131 | padding=True, 132 | return_tensors="pt", 133 | ) 134 | inputs = inputs.to(model.device) 135 | 136 | # Inference: Generation of the output 137 | generated_ids = model.generate(**inputs, max_new_tokens=128) 138 | generated_ids_trimmed = [ 139 | out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids) 140 | ] 141 | output_text = processor.batch_decode( 142 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False 143 | ) 144 | print(output_text) 145 | 146 | ``` 147 | 148 | **Important Notes**: 149 | 150 | Based on the training configurations of the VL-Rethinker family, it's recommended to: 151 | - *Prompt*: 152 | 153 | append `\n\nPlease reason step by step, and put your final answer within \\boxed{}` after the use queries. 154 | - *Resolutions*: 155 | ``` 156 | min_pixels = 256*28*28 157 | max_pixels = 1280*28*28 158 | ``` 159 | 160 | 161 | ## 🚀Quick Start 162 | The proposed algorithm is implemented with the [OpenRLHF](https://github.com/OpenRLHF/OpenRLHF) framework. 163 | 164 | ### Installations 165 | Please see [the installation instructions](installation.md). 166 | 167 | ### Evaluation 168 | Our models can be evaluated like Qwen2.5-VL using [lmms_eval](https://github.com/EvolvingLMMs-Lab/lmms-eval). 169 | 170 | Here we provide an alternative evaluation approach. It offers the following benefits: 171 | - Fast: Batch inference using vLLM for 1K queries on 8 A800 within 30 mins. 172 | - Convenient: Evaluation without time-consuming API calls. Judgement made by our rule-based functions align with LLM Judges. 173 | - Train-Test Aligned: the evaluation re-uses the correctness judgement of training to minimize the gap between training and test-time evaluation. 174 | 175 | The evaluation is integrated with the OpenRLHF framework. 
176 | ```bash 177 | bash ./scripts/eval_7b.sh [benchmark] [modelname] [modelpath] 178 | ``` 179 | **Note: for MMMU-Val, we cannot reproduce Qwen2.5-VL's reported results with lmms_eval, VLMEvalKit, or our native evaluation. We would greatly appreciate any insights into the correct means of reproducing them.** 180 | 181 | 182 | ### Training 183 | Run the following: 184 | ```bash 185 | bash ./scripts/train_vlm_multi.sh 186 | ``` 187 | 188 | 189 | ## Acknowledgement 190 | This project is adapted from [OpenRLHF](https://github.com/OpenRLHF/OpenRLHF) and [LMM-R1](https://github.com/TideDra/lmm-r1), both released under the Apache License 2.0. Thanks for their open-source contributions! 191 | 192 | ## Citation 193 | If you find this work useful, please cite it: 194 | ```bibtex 195 | @article{vl-rethinker, 196 | title={VL-Rethinker: Incentivizing Self-Reflection of Vision-Language Models with Reinforcement Learning}, 197 | author = {Wang, Haozhe and Qu, Chao and Huang, Zuming and Chu, Wei and Lin, Fangzhen and Chen, Wenhu}, 198 | journal={arXiv preprint arXiv:2504.08837}, 199 | year={2025} 200 | } 201 | ``` 202 | -------------------------------------------------------------------------------- /assets/overview-2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/VL-Rethinker/dd2c17d149a5939314690c59a804b817b7d422df/assets/overview-2.jpg -------------------------------------------------------------------------------- /assets/overview-2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/VL-Rethinker/dd2c17d149a5939314690c59a804b817b7d422df/assets/overview-2.pdf -------------------------------------------------------------------------------- /installation.md: -------------------------------------------------------------------------------- 1 | ### Installation 2 | 3 | ```bash 4 | cd VL-Rethinker 5 | conda create -n rethinker python=3.10 6 | pip install -e .[vllm] 7 | pip install flash_attn --no-build-isolation 8 | ``` 9 | 10 | Note: vLLM >=0.7.2 is recommended. 11 | 12 | Note: if you use multi-node training, downgrade DeepSpeed to 0.15.0. 13 | reference: https://github.com/OpenRLHF/OpenRLHF/issues/776#issuecomment-2694472824 14 | 15 | ### Workarounds 16 | At the time of writing, some bugs remain when using flash-attn and vLLM with Qwen2.5-VL. The following are community-provided workarounds: 17 | 1. To fix flash-attn issues: 18 | ``` 19 | export LD_LIBRARY_PATH=/path/to/nvidia/nvjitlink/lib:$LD_LIBRARY_PATH 20 | ``` 21 | reference: https://github.com/pytorch/pytorch/issues/111469#issuecomment-1869208750 22 | 23 | 24 | 2.
to fix qwen-vl preprocessor issues: modify preprocessor_config.json 25 | 26 | reference: 27 | - https://github.com/huggingface/transformers/issues/36193#issuecomment-2661278628 28 | - https://github.com/huggingface/transformers/issues/36246 29 | -------------------------------------------------------------------------------- /openrlhf/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/VL-Rethinker/dd2c17d149a5939314690c59a804b817b7d422df/openrlhf/__init__.py -------------------------------------------------------------------------------- /openrlhf/cli/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TIGER-AI-Lab/VL-Rethinker/dd2c17d149a5939314690c59a804b817b7d422df/openrlhf/cli/__init__.py -------------------------------------------------------------------------------- /openrlhf/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # from .process_reward_dataset import ProcessRewardDataset 2 | from .prompts_dataset import PromptDataset 3 | # from .reward_dataset import RewardDataset 4 | # from .sft_dataset import SFTDataset 5 | # from .unpaired_preference_dataset import UnpairedPreferenceDataset 6 | 7 | __all__ = [ 8 | # "ProcessRewardDataset", 9 | "PromptDataset", 10 | # "RewardDataset", 11 | # "SFTDataset", 12 | # "UnpairedPreferenceDataset" 13 | ] 14 | -------------------------------------------------------------------------------- /openrlhf/datasets/prompts_dataset.py: -------------------------------------------------------------------------------- 1 | # /* 2 | # * Modified by Haozhe Wang in 2025 3 | # * 4 | # * Licensed under the Apache License, Version 2.0 (the "License"); 5 | # */ 6 | from torch.utils.data import Dataset 7 | from tqdm import tqdm 8 | import json 9 | 10 | def preprocess_data(data, input_template=None, input_key="input", apply_chat_template=None) -> str: 11 | if apply_chat_template: 12 | chat = data[input_key] 13 | if isinstance(chat, str): 14 | chat = [{"role": "user", "content": chat}] 15 | prompt = apply_chat_template(chat, tokenize=False, add_generation_prompt=True) 16 | else: 17 | prompt = data[input_key] 18 | if input_template: 19 | prompt = input_template.format(prompt) 20 | 21 | return prompt 22 | 23 | 24 | templates = dict(longcot=""" 25 | You are a thoughtful and diligent student tasked with solving a problem. As you work through the problem, document your thought process in a reflective, first-person narrative. Think of yourself as talking to yourself through each step. Consider each step carefully, question your reasoning, and adjust as needed to arrive at a sound solution. Here's how you should proceed: 26 | 27 | 1. **Step-by-Step Analysis**: Start by thoroughly understanding the problem. Identify what is provided and what is being asked. Consider high-level strategies or approaches first, and then break them down into smaller, manageable steps. Ensure you address each component one at a time and do not skip over any details. 28 | 29 | 2. **Self-Questioning**: As you work through each step, ask yourself reflective questions like, "Is this correct?", "Does it make sense?", or "What might I be overlooking?" Be critical of your own reasoning, and adjust your approach as needed. Use notation to express your confidence and evaluate the progress about solving the problem. 30 | 31 | 3. 
**Reassessment**: If you notice a mistake or feel uncertain about your approach, reassess your work. Go back and revise your assumptions, logic, or calculations to correct any missteps, ensuring you're on the right track. 32 | 33 | 4. **Alternative Approaches**: If you find yourself stuck or unsure about the current method, consider alternative approaches. Look at the problem from different angles, and if one method feels insufficient, explore others. 34 | 35 | 5. **Clear Detailing**: For each step, explain your reasoning clearly and in simple language. Make sure anyone who follows your work can easily understand the logic behind your decisions and the steps you've taken. 36 | 37 | 6. **Final Solution**: Once you're confident in your solution, enclose it in \\boxed{} to highlight your final answer. 38 | 39 | **Your goal is to approach the problem in a reflective, iterative manner, ensuring that no steps are skipped and no assumptions go unchecked.** 40 | """, 41 | default="Please reason step by step, and put your final answer within \\boxed{}.", 42 | elaborate="First understand the problem: understand what information is given in the text and understand what the images describes. Then think about what the problem is asking for and what knowledge the problem aims to examine. Finally, think about how to solve the problem step by step. Explain your solution in simple words that are easy to follow, assuming the readers are junior students who DOT NOT master well the relevant knowledge. Remember to put your final answer within \\boxed{}.", 43 | elaborate_rethink="""Guidelines: 44 | - First understand the problem: understand what information is given in the text and understand what the images describes. Then think about what the problem is asking for and what knowledge the problem aims to examine. Finally, think about how to solve the problem step by step. Explain your solution in simple words that are easy to follow, assuming the readers are junior students who DOT NOT master well the relevant knowledge. 45 | - **Regularly perform self-questioning, self-verification, self-correction to check your ongoing reasoning**, using connectives such as "Wait a moment", "Wait, does it seem right?", etc. 46 | - Remember to put your final answer within \\boxed{}.""", 47 | explain="""Guidelines: 48 | Understand what the problem is asking for, and what knowledge the problem aims to examine. 49 | Explain the problem and your solution in simple words to a reader, assuming he has rare knowledge and poor mastery about the related concepts. 50 | """, 51 | rethink="""Guidelines: 52 | Please think step by step, and **regularly perform self-questioning, self-verification, self-correction to check your ongoing reasoning**, using connectives such as "Wait a moment", "Wait, does it seem right?", etc. Remember to put your final answer within \\boxed{}.""", 53 | ) 54 | templates['none'] = "" 55 | templates['autocode'] = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{}\n\n### Response:" 56 | 57 | 58 | 59 | class PromptDataset(Dataset): 60 | """ 61 | Dataset for PPO model 62 | 63 | Args: 64 | dataset: dataset for PPO model 65 | tokenizer: tokenizer for PPO model 66 | max_length: max length of input 67 | """ 68 | 69 | def preprocess_data(self, data, input_template=None, input_key="input", apply_chat_template=None, system_prompt="longcot") -> str: 70 | has_vlm_processor = self.processor is not None 71 | # print('!!!! 
apply chat', apply_chat_template) 72 | # print('!!!! sys', system_prompt, input_key) 73 | # import pdb; pdb.set_trace() 74 | # if system_prompt=='dpsk': 75 | # # import json 76 | # if input_key=='response' and not self.is_eval: 77 | # chat = [{"role": "user", "content": data['question']}, 78 | # # {"role": "assistant", "content": data['response']} 79 | # ] 80 | 81 | # prompt = data['question'] # self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) 82 | # else: 83 | # input_key = 'messages' 84 | # chat = data[input_key] 85 | # prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) 86 | 87 | # elif system_prompt=='dsmath': 88 | # chat = data['messages'] 89 | # for entry in chat: 90 | # if entry['role']=='user': break 91 | # template = "User:{instruction}\n\nAssistant:" 92 | # # entry['content'] += f'\n{templates["default"]}' 93 | 94 | # prompt = template.format(instruction=entry['content']) 95 | # elif system_prompt=='autocode': 96 | # chat = data['messages'] 97 | # for entry in chat: 98 | # if entry['role']=='user': break 99 | # template = templates[system_prompt] 100 | # # template = "User:{instruction}\n\nAssistant:" 101 | # # entry['content'] += f'\n{templates["default"]}' 102 | 103 | # prompt = template.format(entry['content']) 104 | # elif input_key=='question': 105 | # prompt = data[input_key] 106 | # if system_prompt=='default': 107 | # trigger = templates[system_prompt] 108 | # chat = [{"role": "system", "content": trigger}, 109 | # {"role": "user", "content": prompt}] 110 | # prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) 111 | # else: 112 | # input_template = templates[system_prompt] 113 | # prompt = input_template.format(prompt) 114 | if has_vlm_processor: 115 | if False: 116 | chat = data[input_key] 117 | if system_prompt in templates: 118 | chat.insert(0, dict(role='system', content=templates[system_prompt])) 119 | else: print(f'!!!! warning: {system_prompt} not in templates') 120 | if isinstance(chat[-1]['content'], str): 121 | text = chat[-1]['content'] 122 | content = [ 123 | # dict(type='image', image=None), 124 | dict(type='text', text=text) 125 | ] 126 | chat[-1]['content'] = content 127 | 128 | else: 129 | # sysp = None 130 | # if system_prompt in templates: 131 | # sysp = templates[system_prompt] 132 | # else: print(f'!!!! warning: {system_prompt} not in templates') 133 | # now we don't use system prompt 134 | if system_prompt == 'notrigger': 135 | trigger = "" 136 | elif system_prompt == 'elaborate': 137 | trigger = f"\n\n{templates['elaborate']}" 138 | elif system_prompt == 'elaborate_rethink': 139 | trigger = f"\n\n{templates['elaborate_rethink']}" 140 | elif system_prompt == 'rethink': 141 | trigger = f"\n\n{templates['rethink']}" 142 | else: 143 | trigger = f"\n\n{templates[system_prompt]}" 144 | q = data['question'] 145 | img = data.get('image', None) 146 | imglist = [] 147 | if img is None or img=="" : 148 | pass # keep it empty 149 | elif isinstance(img, list): 150 | imglist = [dict(type='image', image=imm) for imm in img if imm] 151 | else: imglist = [dict(type='image', image=img)] 152 | if len(imglist)>10: 153 | print('!!! 
[debug]', img) 154 | chat = [dict(role='user', 155 | content=imglist+[dict(type='text', text=q+trigger)] # if img else q 156 | )] 157 | 158 | if 'qid' in data: 159 | chat.append(dict(qid=data['qid'])) 160 | prompt = json.dumps(chat) 161 | elif input_key=='question': 162 | chat = [{"role": "system", "content": templates["default"]}, 163 | {"role": "user", "content": data['question']}] 164 | prompt = apply_chat_template(chat, tokenize=False, add_generation_prompt=True) 165 | elif input_key=='messages': 166 | chat = data[input_key] 167 | if len(chat)>1: 168 | chat[0] = dict(role='system', content=templates[system_prompt]) # replace 169 | else: 170 | if system_prompt in templates: 171 | chat.insert(0, dict(role='system', content=templates[system_prompt])) 172 | else: print(f'!!!! warning: {system_prompt} not in templates') 173 | prompt = apply_chat_template(chat, tokenize=False, add_generation_prompt=True) 174 | 175 | elif apply_chat_template: 176 | chat = data[input_key] 177 | if isinstance(chat, str): 178 | chat = [{"role": "user", "content": chat}] 179 | else: # messages 180 | # if system_prompt!="none": 181 | if len(chat)>1: 182 | chat[0] = dict(role='system', content=templates[system_prompt]) # replace 183 | else: chat.insert(0, dict(role='system', content=templates[system_prompt])) 184 | prompt = apply_chat_template(chat, tokenize=False, add_generation_prompt=True) 185 | else: 186 | prompt = data[input_key] 187 | input_template = templates[system_prompt] 188 | if system_prompt in ['none']: 189 | print(f"template cannot be {system_prompt} when not using chat template") 190 | chat = [dict(role='system', content=templates[system_prompt]), 191 | dict(role='user', content=prompt) 192 | ] 193 | prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True) 194 | else: 195 | prompt = input_template.format(prompt) 196 | if prompt=="": print('!!!! warning, prompts incorrect') 197 | return prompt 198 | 199 | def __init__( 200 | self, 201 | dataset, 202 | tokenizer, 203 | strategy, 204 | input_template=None, 205 | is_eval=False, 206 | processor=None, 207 | ) -> None: 208 | super().__init__() 209 | self.strategy = strategy 210 | self.tokenizer = tokenizer 211 | self.processor = processor 212 | self.is_eval = is_eval 213 | 214 | # chat_template 215 | self.input_template = input_template 216 | input_key = getattr(self.strategy.args, "input_key", None) 217 | controlled_shuffle = getattr(self.strategy.args, "controlled_shuffle", 0) 218 | apply_chat_template = getattr(self.strategy.args, "apply_chat_template", False) 219 | 220 | system_prompt = getattr(self.strategy.args, "system_prompt", "none") 221 | # print("sysprompt", system_prompt) 222 | do_vlm = getattr(self.strategy.args, "train_vlm", False) 223 | # import pdb; pdb.set_trace() 224 | if apply_chat_template: 225 | apply_chat_template = self.processor.apply_chat_template if do_vlm else self.tokenizer.apply_chat_template 226 | 227 | 228 | self.prompts = [] 229 | repeat = 1 if controlled_shuffle==0 else controlled_shuffle 230 | for _ in range(repeat): 231 | for data in tqdm(dataset, desc="Preprocessing data", disable=not self.strategy.is_rank_0()): 232 | prompt = self.preprocess_data(data, input_template, input_key, apply_chat_template, system_prompt) 233 | self.prompts.append(prompt) 234 | # print("!!!! 
peek", self.prompts[0]) 235 | 236 | 237 | def __len__(self): 238 | length = len(self.prompts) 239 | return length 240 | 241 | def __getitem__(self, idx): 242 | return self.prompts[idx] 243 | 244 | -------------------------------------------------------------------------------- /openrlhf/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def zero_pad_sequences(sequences, side: str = "left", value=0): 6 | assert side in ("left", "right") 7 | max_len = max(seq.size(-1) for seq in sequences) 8 | padded_sequences = [] 9 | for seq in sequences: 10 | pad_len = max_len - seq.size(-1) 11 | padding = (pad_len, 0) if side == "left" else (0, pad_len) 12 | padded_sequences.append(F.pad(seq, padding, value=value)) 13 | return torch.stack(padded_sequences, dim=0) 14 | 15 | 16 | def exist_and_not_none(d, key): 17 | return key in d and not d[key] is None 18 | -------------------------------------------------------------------------------- /openrlhf/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .actor import Actor 2 | from .loss import ( 3 | DPOLoss, 4 | GPTLMLoss, 5 | KDLoss, 6 | KTOLoss, 7 | LogExpLoss, 8 | PairWiseLoss, 9 | PolicyLoss, 10 | SFTLoss, 11 | PRMLoss, 12 | ValueLoss, 13 | VanillaKTOLoss, 14 | ) 15 | from .model import get_llm_for_sequence_regression 16 | 17 | __all__ = [ 18 | "Actor", 19 | "DPOLoss", 20 | "GPTLMLoss", 21 | "KDLoss", 22 | "KTOLoss", 23 | "LogExpLoss", 24 | "PairWiseLoss", 25 | "PolicyLoss", 26 | "SFTLoss", 27 | "PRMLoss", 28 | "ValueLoss", 29 | "VanillaKTOLoss", 30 | "get_llm_for_sequence_regression", 31 | ] 32 | -------------------------------------------------------------------------------- /openrlhf/models/model.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Union 2 | 3 | import deepspeed 4 | import torch 5 | import torch.nn as nn 6 | from flash_attn.utils.distributed import all_gather 7 | from peft import LoraConfig, get_peft_model 8 | from peft.tuners.lora import LoraLayer 9 | from transformers import AutoConfig, AutoModel, BitsAndBytesConfig 10 | from transformers.integrations.deepspeed import HfDeepSpeedConfig 11 | 12 | from openrlhf.utils.logging_utils import init_logger 13 | 14 | from .ring_attn_utils import convert_ring_attn_params 15 | from .utils import reset_position_ids 16 | from ..utils.utils import get_generation_cls 17 | 18 | logger = init_logger(__name__) 19 | 20 | 21 | # Construct transformer with a value head for sequence classification. 22 | # https://github.com/huggingface/transformers/blob/405b56269812056d9593869e22b7b264d806cb1e/src/transformers/models/llama/modeling_llama.py#L1254 23 | def get_llm_for_sequence_regression( 24 | model_name_or_path: str, 25 | model_type: str, 26 | *, 27 | bf16=True, 28 | load_in_4bit=False, 29 | lora_rank=0, 30 | lora_alpha=16, 31 | target_modules=None, 32 | lora_dropout=0, 33 | normalize_reward=False, 34 | use_flash_attention_2=False, 35 | ds_config: dict = None, 36 | init_value_head: bool = False, 37 | value_head_prefix="score", 38 | device_map=None, 39 | packing_samples=False, 40 | **kwargs, 41 | ) -> nn.Module: 42 | """Retrieve a transformer model with a sequence regression head on top. 43 | 44 | This function loads a pretrained transformer model and attaches a linear layer for sequence regression. 45 | 46 | Args: 47 | model_name_or_path (str): Path to the pretrained model. 
48 | model_type (str): Type of the model, either "reward" or "critic". 49 | bf16 (bool, optional): Enable bfloat16 precision. Defaults to True. 50 | load_in_4bit (bool, optional): Load the model in 4-bit precision. Defaults to False. 51 | lora_rank (int, optional): Rank for LoRA adaptation. Defaults to 0. 52 | lora_alpha (int, optional): Alpha parameter for LoRA. Defaults to 16. 53 | target_modules (list, optional): List of target modules for LoRA. Defaults to None. 54 | lora_dropout (float, optional): Dropout rate for LoRA layers. Defaults to 0. 55 | normalize_reward (bool, optional): Normalize reward values. Defaults to False. 56 | use_flash_attention_2 (bool, optional): Use Flash Attention 2.0. Defaults to False. 57 | ds_config (dict, optional): Deepspeed configuration for model partitioning across multiple GPUs when ZeRO-3 is enabled. Defaults to None. 58 | init_value_head (bool, optional): Initialize the value head. Defaults to False. 59 | value_head_prefix (str, optional): Prefix for the value head. Defaults to "score". 60 | device_map (dict, optional): Map of devices for model loading. Defaults to None. 61 | packing_samples (bool, optional): Whether to pack samples during training. Defaults to False. 62 | 63 | Returns: 64 | nn.Module: A pretrained transformer model with a sequence regression head. 65 | """ 66 | assert ( 67 | model_type == "critic" or model_type == "reward" 68 | ), f"invalid model_type: {model_type}, should be critic or reward." 69 | 70 | config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True) 71 | config.normalize_reward = normalize_reward 72 | config._attn_implementation = "flash_attention_2" if use_flash_attention_2 else "eager" 73 | 74 | # Prioritize using the value_head_prefix in the model configuration. 
75 | value_head_prefix = getattr(config, "value_head_prefix", value_head_prefix) 76 | logger.info(f"set value_head_prefix to `{value_head_prefix}`") 77 | base_class = get_generation_cls(config) 78 | base_pretrained_class = base_class.__base__ 79 | if model_type == "reward": 80 | cls_class = _get_reward_model(base_class, value_head_prefix, packing_samples) 81 | else: 82 | cls_class = _get_critic_model(base_class, value_head_prefix, packing_samples) 83 | 84 | # Note: dschf is defined in function scope to avoid global effects 85 | # https://huggingface.co/docs/transformers/main_classes/deepspeed#nontrainer-deepspeed-integration 86 | if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3: 87 | dschf = HfDeepSpeedConfig(ds_config) 88 | else: 89 | dschf = None 90 | 91 | if load_in_4bit: 92 | assert bf16, "we only support bnb_4bit_compute_dtype = bf16" 93 | nf4_config = BitsAndBytesConfig( 94 | load_in_4bit=True, 95 | bnb_4bit_quant_type="nf4", 96 | bnb_4bit_use_double_quant=True, 97 | bnb_4bit_compute_dtype=torch.bfloat16, 98 | ) 99 | else: 100 | nf4_config = None 101 | 102 | model = cls_class.from_pretrained( 103 | model_name_or_path, 104 | config=config, 105 | trust_remote_code=True, 106 | torch_dtype=torch.bfloat16 if bf16 else "auto", 107 | quantization_config=nf4_config, 108 | device_map=device_map, 109 | **kwargs, 110 | ) 111 | 112 | # LoRA 113 | if lora_rank > 0: 114 | model.enable_input_require_grads() 115 | lora_config = LoraConfig( 116 | r=lora_rank, 117 | lora_alpha=lora_alpha, 118 | target_modules=target_modules, 119 | lora_dropout=lora_dropout, 120 | bias="none", 121 | ) 122 | model = get_peft_model(model, lora_config) 123 | 124 | if load_in_4bit: 125 | for name, module in model.named_modules(): 126 | if isinstance(module, LoraLayer): 127 | module = module.to(torch.bfloat16) 128 | if "norm" in name: 129 | module = module.to(torch.float32) 130 | if value_head_prefix in name or "embed_tokens" in name: 131 | if hasattr(module, "weight"): 132 | module = module.to(torch.bfloat16) 133 | 134 | # MoE - balancing loss 135 | model_config = model.config.to_dict() 136 | if "output_router_logits" in model_config: 137 | print("[MoE] set output_router_logits as True") 138 | model.config.output_router_logits = True 139 | 140 | # https://github.com/huggingface/transformers/issues/26877 141 | model.config.use_cache = False 142 | 143 | # NOTE: For reward model training only, intialize value_head manually 144 | # because deepspeed.zero.Init() will not intialize them. 145 | # TODO: Find a better way to clarify reward model training. 
146 | if init_value_head: 147 | value_head = getattr(model, value_head_prefix) 148 | if dschf is not None: 149 | logger.info("initialize value_head for ZeRO-3 reward model training.") 150 | with deepspeed.zero.GatheredParameters([value_head.weight], modifier_rank=0): 151 | if torch.distributed.get_rank() == 0: 152 | value_head.weight.data.normal_(mean=0.0, std=1 / (config.hidden_size + 1)) 153 | else: 154 | value_head.weight.data.normal_(mean=0.0, std=1 / (config.hidden_size + 1)) 155 | 156 | return model 157 | 158 | 159 | def _get_reward_model(base_llm_model, value_head_prefix="score", packing_samples=False): 160 | class RewardModel(base_llm_model): 161 | supports_gradient_checkpointing = True 162 | 163 | def __init__(self, config: AutoConfig): 164 | super().__init__(config) 165 | 166 | self.value_head_prefix = value_head_prefix 167 | setattr(self, value_head_prefix, nn.Linear(config.hidden_size, 1, bias=False)) 168 | 169 | self.packing_samples = packing_samples 170 | 171 | # mean std 172 | self.normalize_reward = config.normalize_reward 173 | self.register_buffer("mean", torch.zeros(1), persistent=False) 174 | self.register_buffer("std", torch.ones(1), persistent=False) 175 | 176 | # load mean/std from config.json 177 | if hasattr(config, "mean"): 178 | self.mean[0] = config.mean 179 | self.std[0] = config.std 180 | 181 | def forward( 182 | self, 183 | input_ids: torch.LongTensor = None, 184 | attention_mask: Optional[torch.Tensor] = None, 185 | return_output=False, 186 | ring_attn_group=None, 187 | packed_seq_lens=None, 188 | visual_inputs=None, 189 | ) -> torch.Tensor: 190 | if visual_inputs is None: 191 | visual_inputs = {} 192 | if not self.packing_samples: 193 | # https://github.com/OpenRLHF/OpenRLHF/issues/217 194 | position_ids = attention_mask.long().cumsum(-1) - 1 195 | position_ids.masked_fill_(attention_mask == 0, 1) 196 | else: 197 | # convert attention_mask to position_ids 198 | if ring_attn_group is not None: 199 | input_ids, attention_mask, position_ids = convert_ring_attn_params( 200 | input_ids, attention_mask, packed_seq_lens, ring_attn_group 201 | ) 202 | else: 203 | position_ids = reset_position_ids(attention_mask) 204 | # explicitly ignore attention_mask for packing_samples 205 | attention_mask = None 206 | 207 | outputs = super().forward( 208 | input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids,output_hidden_states=True, **visual_inputs 209 | ) 210 | if "last_hidden_state" in outputs: 211 | last_hidden_states = outputs["last_hidden_state"] 212 | elif "hidden_states" in outputs: 213 | last_hidden_states = outputs["hidden_states"][-1] 214 | else: 215 | raise ValueError("outputs should contain either last_hidden_state or hidden_states") 216 | values = getattr(self, self.value_head_prefix)(last_hidden_states).squeeze(-1) 217 | 218 | if self.packing_samples: 219 | if ring_attn_group is not None: 220 | reward = all_gather(values, ring_attn_group).reshape(1, -1) 221 | else: 222 | reward = values 223 | # TODO: convert packed_seq_lens into torch tensor in advance 224 | packed_seq_lens = torch.tensor(packed_seq_lens, device=values.device) 225 | eos_indices = packed_seq_lens.cumsum(dim=0) - 1 226 | reward = reward.squeeze(0).gather(dim=0, index=eos_indices) 227 | else: 228 | eos_indices = attention_mask.size(1) - 1 - attention_mask.long().fliplr().argmax(dim=1, keepdim=True) 229 | reward = values.gather(dim=1, index=eos_indices).squeeze(1) 230 | 231 | if not self.training and self.normalize_reward: 232 | reward = (reward - self.mean) / self.std 233 | 
234 | return (reward, outputs) if return_output else reward 235 | 236 | return RewardModel 237 | 238 | 239 | def _get_critic_model(base_llm_model, value_head_prefix="score", packing_samples=False): 240 | class CriticModel(base_llm_model): 241 | supports_gradient_checkpointing = True 242 | 243 | def __init__(self, config: AutoConfig): 244 | super().__init__(config) 245 | 246 | self.value_head_prefix = value_head_prefix 247 | setattr(self, value_head_prefix, nn.Linear(config.hidden_size, 1, bias=False)) 248 | 249 | self.packing_samples = packing_samples 250 | 251 | # mean std 252 | self.normalize_reward = config.normalize_reward 253 | self.register_buffer("mean", torch.zeros(1), persistent=False) 254 | self.register_buffer("std", torch.ones(1), persistent=False) 255 | 256 | # load mean/std from config.json 257 | if hasattr(config, "mean"): 258 | self.mean[0] = config.mean 259 | self.std[0] = config.std 260 | 261 | def forward( 262 | self, 263 | input_ids: torch.LongTensor = None, 264 | num_actions: Optional[Union[int, list[int]]] = None, 265 | attention_mask: Optional[torch.Tensor] = None, 266 | return_output=False, 267 | packed_seq_lens=None, 268 | visual_inputs={}, 269 | ) -> torch.Tensor: 270 | if not self.packing_samples: 271 | # https://github.com/OpenRLHF/OpenRLHF/issues/217 272 | position_ids = attention_mask.long().cumsum(-1) - 1 273 | position_ids.masked_fill_(attention_mask == 0, 1) 274 | else: 275 | # convert attention_mask to position_ids 276 | position_ids = reset_position_ids(attention_mask) 277 | # explicitly ignore attention_mask for packing_samples 278 | attention_mask = None 279 | 280 | outputs = super().forward( 281 | input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids,output_hidden_states=True, **visual_inputs 282 | ) 283 | if "last_hidden_state" in outputs: 284 | last_hidden_states = outputs["last_hidden_state"] 285 | elif "hidden_states" in outputs: 286 | last_hidden_states = outputs["hidden_states"][-1] 287 | else: 288 | raise ValueError("outputs should contain either last_hidden_state or hidden_states") 289 | values = getattr(self, self.value_head_prefix)(last_hidden_states).squeeze(-1)[:, :-1] 290 | 291 | # normalize reward 292 | if self.normalize_reward: 293 | values = (values - self.mean) / self.std 294 | 295 | if num_actions is None: 296 | assert return_output 297 | return outputs 298 | 299 | if not self.packing_samples: 300 | action_values = values[:, -num_actions:] 301 | else: 302 | assert isinstance(num_actions, list) and len(num_actions) == len(packed_seq_lens) 303 | action_values = [] 304 | offset = 0 305 | for num_action, seq_len in zip(num_actions, packed_seq_lens): 306 | start, end = max(0, offset + seq_len - num_action - 1), offset + seq_len - 1 307 | action_values.append(values[:, start:end]) 308 | offset += seq_len 309 | action_values = torch.cat(action_values, dim=1) 310 | 311 | if return_output: 312 | return (action_values, outputs) 313 | else: 314 | return action_values 315 | 316 | return CriticModel 317 | -------------------------------------------------------------------------------- /openrlhf/models/ring_attn_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.distributed as dist 3 | import torch.nn.functional as F 4 | 5 | 6 | RING_ATTN_GROUP = None 7 | 8 | 9 | def set_ring_attn_group(group): 10 | global RING_ATTN_GROUP 11 | RING_ATTN_GROUP = group 12 | 13 | 14 | def get_ring_attn_group(): 15 | return RING_ATTN_GROUP 16 | 17 | 18 | def 
reset_ring_attn_position_ids(start, end, packed_seq_lens): 19 | """ 20 | Calculate position ids for packed_seq_ids[start:end]. 21 | For example, if the packed_seq_lens is [3, 2, 4, 1], start=2, end=8, 22 | the position ids will be [2, 0, 1, 0, 1, 2]. 23 | 24 | Args: 25 | start: the start position 26 | end: the end position 27 | packed_seq_lens: the sequence lengths of packed sequences 28 | """ 29 | position_ids = torch.zeros((1, end - start), dtype=torch.long, device=torch.cuda.current_device()) 30 | offset = 0 31 | for seqlen in packed_seq_lens: 32 | seq_start = max(offset, start) 33 | seq_end = min(offset + seqlen, end) 34 | if seq_start < seq_end: 35 | position_ids[0, seq_start - start : seq_end - start] = torch.arange(seq_start - offset, seq_end - offset) 36 | 37 | offset += seqlen 38 | if offset >= end: 39 | break 40 | return position_ids 41 | 42 | 43 | def update_ring_attn_params(packed_seq_lens, total_seq_len): 44 | """ 45 | Calculate the cu_seqlens for the current forward pass and pass the value to 46 | the substituted ring_flash_attn. 47 | 48 | Note that total_seq_len may be larger than the sum of packed_seq_lens because of padding. 49 | """ 50 | assert RING_ATTN_GROUP is not None 51 | cu_seqlens = torch.cumsum( 52 | torch.tensor(packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32), 53 | dim=-1, 54 | dtype=torch.int32, 55 | ) 56 | cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len) 57 | 58 | from ring_flash_attn import update_ring_flash_attn_params 59 | 60 | update_ring_flash_attn_params(cu_seqlens, RING_ATTN_GROUP) 61 | 62 | 63 | def convert_ring_attn_params(sequences, attention_mask, packed_seq_lens, ring_attn_group): 64 | # each rank within the ring group will process sequences[start:end] 65 | ring_attn_rank = dist.get_rank(group=ring_attn_group) 66 | ring_attn_size = dist.get_world_size(group=ring_attn_group) 67 | total_seq_len = sequences.numel() 68 | local_seq_len = total_seq_len // ring_attn_size 69 | start, end = ring_attn_rank * local_seq_len, (ring_attn_rank + 1) * local_seq_len 70 | sequences = sequences[:, start:end] 71 | attention_mask = attention_mask[:, start:end] 72 | position_ids = reset_ring_attn_position_ids(start, end, packed_seq_lens) 73 | update_ring_attn_params(packed_seq_lens, total_seq_len) 74 | return sequences, attention_mask, position_ids 75 | -------------------------------------------------------------------------------- /openrlhf/models/utils.py: -------------------------------------------------------------------------------- 1 | # /* 2 | # * Modified by Haozhe Wang in 2025 3 | # * 4 | # * Licensed under the Apache License, Version 2.0 (the "License"); 5 | # */ 6 | 7 | from typing import Optional, Tuple, Union 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | 12 | 13 | def compute_approx_kl( 14 | log_probs: torch.Tensor, 15 | log_probs_base: torch.Tensor, 16 | action_mask: Optional[torch.Tensor] = None, 17 | use_kl_estimator_k3: bool = False, 18 | ) -> torch.Tensor: 19 | """ 20 | Compute the approximate KL divergence between two distributions. 21 | Schulman blog: http://joschu.net/blog/kl-approx.html 22 | 23 | Args: 24 | log_probs: Log probabilities of the new distribution. 25 | log_probs_base: Log probabilities of the base distribution. 26 | action_mask: Mask for actions. 
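        use_kl_estimator_k3: If True, use the non-negative k3 estimator
            exp(-log_ratio) - 1 + log_ratio from the same blog post, which is
            unbiased and has lower variance than the plain log-ratio estimate.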
27 | """ 28 | 29 | log_ratio = log_probs.float() - log_probs_base.float() 30 | if action_mask is not None: 31 | log_ratio = log_ratio * action_mask 32 | 33 | # The k3 estimator is the non negative kl approximation in 34 | # http://joschu.net/blog/kl-approx.html 35 | # Besides non negative, it is also unbiased and have lower variance. 36 | if use_kl_estimator_k3: 37 | log_ratio = -log_ratio 38 | log_ratio = log_ratio.exp() - 1 - log_ratio 39 | 40 | return log_ratio 41 | 42 | 43 | def compute_reward( 44 | r: Union[torch.Tensor, float], 45 | kl_coef: float, 46 | kl: Union[torch.Tensor, list[torch.Tensor]], 47 | action_mask: Optional[torch.Tensor] = None, 48 | num_actions: Optional[Union[int, list[int]]] = None, 49 | reward_clip_range: Tuple[float, float] = None, 50 | ) -> Union[torch.Tensor, list[torch.Tensor]]: 51 | if kl_coef <= 0.0: 52 | kl_coef = 0.0 53 | 54 | if reward_clip_range: 55 | r = r.clamp(min=reward_clip_range[0], max=reward_clip_range[1]) 56 | 57 | if action_mask is not None: 58 | kl_reward = -kl_coef * kl 59 | # The following code is equivalent to: 60 | # 61 | # last_reward = torch.zeros_like(kl) 62 | # for i in range(last_reward.size(0)): 63 | # for t in reversed(range(last_reward.size(1))): 64 | # if action_mask[i][t] > 0.5: 65 | # last_reward[i][t] = r[i] 66 | # break 67 | # 68 | eos_indices = action_mask.size(1) - 1 - action_mask.long().fliplr().argmax(dim=1, keepdim=True) 69 | last_reward = torch.zeros_like(kl).scatter_(dim=1, index=eos_indices, src=r.unsqueeze(1).to(kl.dtype)) 70 | 71 | reward = last_reward + kl_reward 72 | else: 73 | # TODO: write a more efficient version 74 | reward = [] 75 | for i, (kl_seg, action_len) in enumerate(zip(kl, num_actions)): 76 | kl_reward = -kl_coef * kl_seg 77 | kl_reward[action_len - 1] += r[i] 78 | reward.append(kl_reward) 79 | 80 | return reward 81 | 82 | 83 | def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: 84 | # https://github.com/OpenRLHF/OpenRLHF/pull/718#issuecomment-2641081881 85 | if logits.dtype in [torch.float32, torch.float64]: 86 | logits_labels = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) 87 | logsumexp_values = torch.stack( 88 | [torch.logsumexp(l, dim=-1) for l in logits] # loop to reduce peak mem consumption 89 | ) 90 | log_probs_labels = logits_labels - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x) 91 | else: 92 | log_probs_labels = [] 93 | for row_logits, row_labels in zip(logits, labels): # loop to reduce peak mem consumption 94 | row_log_probs = F.log_softmax(row_logits, dim=-1) 95 | row_log_probs_labels = row_log_probs.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1) 96 | log_probs_labels.append(row_log_probs_labels) 97 | log_probs_labels = torch.stack(log_probs_labels) 98 | return log_probs_labels 99 | 100 | 101 | def masked_mean(tensor: torch.Tensor, mask: Optional[torch.Tensor], dim: int = None) -> torch.Tensor: 102 | if mask is None: 103 | return tensor.mean(axis=dim) 104 | return (tensor * mask).sum(axis=dim) / (mask.sum(axis=dim)+1e-4) 105 | 106 | 107 | def masked_normalize(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1, eps: float = 1e-8) -> torch.Tensor: 108 | tensor = tensor * mask 109 | mean = masked_mean(tensor, mask, dim=dim) 110 | mean_centered = tensor - mean 111 | var = masked_mean(mean_centered**2, mask, dim=dim) 112 | return mean_centered * var.clamp(min=eps).rsqrt() 113 | 114 | 115 | # Reset positions for packed samples 116 | # For example 117 | # Input: attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2, 
3, 3, 0]]) 118 | # Output: position_ids = torch.tensor([[0, 1, 2, 0, 1, 2, 0, 1, 0]]) 119 | def reset_position_ids(attention_mask): 120 | position_ids = torch.zeros_like(attention_mask, dtype=torch.long) 121 | for i in range(attention_mask.size(0)): 122 | mask = attention_mask[i] 123 | seq_num = mask.max().item() 124 | for index in range(1, seq_num + 1): 125 | sample_mask = mask == index 126 | sample_length = sample_mask.sum().item() 127 | position_ids[i, sample_mask] = torch.arange(sample_length, device=mask.device) 128 | return position_ids 129 | 130 | def packed_sequence_to_position_tensor(packed_seq_lens, device): 131 | """ 132 | Converts packed_seq_lens to a tensor of token positions. 133 | 134 | Args: 135 | packed_seq_lens: A list of integers representing token length for each sequence. 136 | 137 | Returns: 138 | A tensor of shape (1, ntoken) containing the sequences of positions. 139 | """ 140 | output_list = [] 141 | for seq_len in packed_seq_lens: 142 | output_list.extend(list(range(seq_len))) 143 | return torch.tensor(output_list, device=device).unsqueeze(0) 144 | 145 | 146 | def unpacking_samples(values: torch.Tensor, packed_seqlens: list[int]): 147 | values = values.squeeze(0) 148 | unpacked_values = [] 149 | offset = 0 150 | for seqlen in packed_seqlens: 151 | unpacked_values.append(values[offset : offset + seqlen]) 152 | offset += seqlen 153 | return unpacked_values 154 | -------------------------------------------------------------------------------- /openrlhf/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | # from .dpo_trainer import DPOTrainer 2 | # from .kd_trainer import KDTrainer 3 | # from .kto_trainer import KTOTrainer 4 | from .ppo_trainer import PPOTrainer 5 | from .evaluator import Evaluator 6 | # from .prm_trainer import ProcessRewardModelTrainer 7 | # from .rm_trainer import RewardModelTrainer 8 | # from .sft_trainer import SFTTrainer 9 | 10 | __all__ = [ 11 | # "DPOTrainer", 12 | # "KDTrainer", 13 | # "KTOTrainer", 14 | "PPOTrainer", 15 | # "ProcessRewardModelTrainer", 16 | # "RewardModelTrainer", 17 | # "SFTTrainer", 18 | "Evaluator" 19 | ] 20 | -------------------------------------------------------------------------------- /openrlhf/trainer/evaluator.py: -------------------------------------------------------------------------------- 1 | # /* 2 | # * Original Copyright Haozhe Wang in 2025 3 | # * 4 | # * Licensed under the Apache License, Version 2.0 (the "License"); 5 | # */ 6 | 7 | import os 8 | import os.path 9 | from abc import ABC 10 | from typing import Any, Callable, Dict, List, Optional 11 | 12 | import torch 13 | import torch.distributed 14 | import torch.nn as nn 15 | from torch.optim import Optimizer 16 | from torch.utils.data import DataLoader 17 | from tqdm import tqdm 18 | 19 | from openrlhf.models import Actor, GPTLMLoss, PolicyLoss, SFTLoss, ValueLoss 20 | from openrlhf.models.utils import masked_mean 21 | from openrlhf.utils.distributed_sampler import DistributedSampler 22 | from openrlhf.models.utils import log_probs_from_logits 23 | 24 | from .ppo_utils import AdaptiveKLController, Experience, FixedKLController, NaiveExperienceMaker, NaiveReplayBuffer, DATA_PROCESSOR_MAP 25 | import random 26 | import copy 27 | import numpy as np 28 | from collections import defaultdict 29 | import json 30 | 31 | 32 | 33 | def read_jsonl(filepath): 34 | """ 35 | Reads a JSON Lines (jsonl) file and returns a list of dictionaries. 36 | 37 | Args: 38 | filepath (str): The path to the jsonl file. 
39 | 40 | Returns: 41 | list: A list of dictionaries, where each dictionary represents a line 42 | from the jsonl file. Returns an empty list if the file is empty 43 | or if an error occurs. 44 | """ 45 | data = [] 46 | try: 47 | with open(filepath, 'r', encoding='utf-8') as f: 48 | for line in f: 49 | try: 50 | data.append(json.loads(line.strip())) 51 | except json.JSONDecodeError: 52 | print(f"Warning: Invalid JSON on line: {line.strip()}") 53 | # Optionally, you might want to log the error or handle it differently. 54 | 55 | except FileNotFoundError: 56 | print(f"Error: File not found at {filepath}") 57 | except Exception as e: 58 | print(f"An unexpected error occurred: {e}") 59 | 60 | return data 61 | 62 | class Evaluator(ABC): 63 | """ 64 | Trainer for Proximal Policy Optimization (PPO) algorithm. 65 | 66 | Args: 67 | strategy (Strategy): The training strategy to use. 68 | actor (Actor): The actor model in the PPO algorithm. 69 | critic (nn.Module): The critic model in the PPO algorithm. 70 | reward_model (nn.Module): The reward model for calculating rewards in the RLHF setup. 71 | initial_model (Actor): The initial model for reference logits to limit actor updates in RLHF. 72 | ema_model (Actor): The exponential moving average model for stable training. 73 | actor_optim (Optimizer): The optimizer for the actor model. 74 | critic_optim (Optimizer): The optimizer for the critic model. 75 | actor_scheduler (Scheduler): The learning rate scheduler for the actor. 76 | critic_scheduler (Scheduler): The learning rate scheduler for the critic. 77 | ema_beta (float, defaults to 0.992): EMA decay rate for model stability. 78 | init_kl_coef (float, defaults to 0.001): Initial coefficient for KL divergence. 79 | kl_target (float, optional): Target value for KL divergence. 80 | kl_horizon (int, defaults to 10000): Horizon for KL annealing. 81 | ptx_coef (float, defaults to 0): Coefficient for supervised loss from pre-trained data. 82 | micro_train_batch_size (int, defaults to 8): Micro-batch size for actor training. 83 | buffer_limit (int, defaults to 0): Maximum size of the replay buffer. 84 | buffer_cpu_offload (bool, defaults to True): If True, offloads replay buffer to CPU. 85 | eps_clip (float, defaults to 0.2): Clipping coefficient for policy loss. 86 | value_clip (float, defaults to 0.2): Clipping coefficient for value function loss. 87 | micro_rollout_batch_size (int, defaults to 8): Micro-batch size for generating rollouts. 88 | gradient_checkpointing (bool, defaults to False): If True, enables gradient checkpointing. 89 | max_epochs (int, defaults to 1): Number of epochs to train. 90 | max_norm (float, defaults to 1.0): Maximum gradient norm for gradient clipping. 91 | tokenizer (Callable, optional): Tokenizer for input data. 92 | prompt_max_len (int, defaults to 128): Maximum length for prompts. 93 | dataloader_pin_memory (bool, defaults to True): If True, pins memory in the data loader. 94 | remote_rm_url (str, optional): URL for remote reward model API. 95 | reward_fn (Callable, optional): Custom reward function for computing rewards. 96 | save_hf_ckpt (bool): Whether to save huggingface-format model weight. 97 | disable_ds_ckpt (bool): Whether not to save deepspeed-format model weight. (Deepspeed model weight is used for training recovery) 98 | **generate_kwargs: Additional arguments for model generation. 
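        processor (Callable, optional): Multimodal processor for vision-language models;
            when train_vlm is set it is wrapped by the matching DATA_PROCESSOR_MAP entry
            and its tokenizer replaces the plain text tokenizer.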
99 | """ 100 | 101 | def __init__( 102 | self, 103 | strategy, 104 | ema_beta: float = 0.992, 105 | init_kl_coef: float = 0.001, 106 | kl_target: float = None, 107 | kl_horizon: int = 10000, 108 | ptx_coef: float = 0, 109 | micro_train_batch_size: int = 8, 110 | buffer_limit: int = 0, 111 | buffer_cpu_offload: bool = True, 112 | eps_clip: float = 0.2, 113 | value_clip: float = 0.2, 114 | micro_rollout_batch_size: int = 8, 115 | gradient_checkpointing: bool = False, 116 | max_epochs: int = 1, 117 | max_norm: float = 1.0, 118 | processor: Optional[Callable[[Any], Dict]] = None, 119 | tokenizer: Optional[Callable[[Any], Dict]] = None, 120 | prompt_max_len: int = 128, 121 | dataloader_pin_memory: bool = True, 122 | reward_fn: Callable[[List[torch.Tensor]], torch.Tensor] = None, 123 | save_hf_ckpt: bool = False, 124 | disable_ds_ckpt: bool = False, 125 | **generate_kwargs, 126 | ) -> None: 127 | # assert ( 128 | # not isinstance(reward_model, List) or len(reward_model) == 1 or reward_fn is not None 129 | # ), "reward_fn must be specified if using multiple reward models" 130 | 131 | super().__init__() 132 | self.strategy = strategy 133 | 134 | strategy.setup_distributed() 135 | self.args = strategy.args 136 | self.rloo_sft = self.args.advantage_estimator.lower() in ['rloo_sft', 'group_sft'] 137 | self.save_hf_ckpt = save_hf_ckpt 138 | self.disable_ds_ckpt = disable_ds_ckpt 139 | self.micro_rollout_batch_size = micro_rollout_batch_size 140 | self.max_epochs = max_epochs 141 | self.tokenizer = tokenizer 142 | self.processor = processor 143 | self.data_processor = None 144 | # for vlm critic model, not provice processor. 145 | if self.args.train_vlm and processor is not None: 146 | self.data_processor = DATA_PROCESSOR_MAP[type(processor)](processor) 147 | self.tokenizer = self.data_processor.tokenizer 148 | 149 | self.generate_kwargs = generate_kwargs 150 | self.dataloader_pin_memory = dataloader_pin_memory 151 | self.max_norm = max_norm 152 | self.ptx_coef = ptx_coef 153 | self.micro_train_batch_size = micro_train_batch_size 154 | self.kl_target = kl_target 155 | self.prompt_max_len = prompt_max_len 156 | self.ema_beta = ema_beta 157 | self.gradient_checkpointing = gradient_checkpointing 158 | self.reward_fn = reward_fn 159 | 160 | args = self.args 161 | self.max_len = args.max_len if args.max_len else args.prompt_max_len + args.generate_max_len 162 | 163 | packing_samples = getattr(self.args, "packing_samples", False) 164 | self.replay_buffer = NaiveReplayBuffer( 165 | micro_train_batch_size, self.data_processor, buffer_limit, buffer_cpu_offload, packing_samples, 166 | drop_maxlen=self.args.drop_maxlen, 167 | maxlen=self.args.generate_max_len + prompt_max_len, 168 | ) 169 | 170 | self.iter = 0 171 | self.eval_step = 0 172 | self.best = -1 173 | 174 | def eval_unit(self, args, ep, global_step, dataloader): 175 | keys = ['reward', 'response_length', 'validity','match','usefmt','round1_nwait'] 176 | infos = {k:[] for k in keys} 177 | print("!!!! eval loader size", len(dataloader), 'step', global_step) 178 | batchsize = dataloader.batch_sampler.batch_size 179 | for idx, rand_prompts in enumerate(dataloader): 180 | if batchsize>len(rand_prompts): 181 | current_len = len(rand_prompts) 182 | needed = batchsize - current_len 183 | repeat_indices = np.arange(needed) % current_len 184 | # repeat_indices = repeat_indices.to(rand_prompts.device) 185 | additional = [rand_prompts[ii] for ii in repeat_indices] 186 | rand_prompts = rand_prompts + additional 187 | else: needed = 0 188 | print(f"!!!! 
========== eval progress {idx}/{len(dataloader)} ==========") 189 | 190 | exp_list = self.get_explist_from_prompts(args, ep, rand_prompts, is_eval=True, eval_step=global_step) 191 | 192 | for i, experience in enumerate(exp_list): 193 | self.replay_buffer.append_split(experience, is_eval=True) 194 | 195 | 196 | for item in self.replay_buffer.eval_items: 197 | info = item.info 198 | for k in keys: 199 | infos[k].append(info[k]) 200 | out_lens = infos['response_length'] 201 | 202 | for k,vlist in infos.items(): 203 | infos[k] = np.mean(vlist) 204 | infos['generation_exceed_rate'] = np.mean([x>args.generate_max_len-1 for x in out_lens]) 205 | 206 | torch.distributed.barrier() 207 | gather_info = self.strategy.all_reduce(infos) # mean 208 | 209 | return gather_info 210 | 211 | 212 | 213 | def get_eval_result_from_disk(self): 214 | args = self.strategy.args 215 | from glob import glob 216 | # os.makedirs(args.ckpt_path, exist_ok=True) 217 | # os.makedirs(f'{args.ckpt_path}/logs', exist_ok=True) 218 | tmp = f'{args.ckpt_path}/logs/sample.eval_iter{self.eval_step}*.jsonl' 219 | files = glob(tmp) 220 | print(f'!!!! [eval] reading from disk {len(files)} files', tmp, ) 221 | 222 | datalist = [read_jsonl(file) for file in files] 223 | results_each = defaultdict(list) 224 | q2results = defaultdict(list) 225 | for info in datalist: 226 | for x in info: 227 | qid = x['qids'] 228 | res = x.get('match') 229 | if res is None: 230 | r0_res = x['round0_correctness'] 231 | res = r0_res 232 | 233 | q2results[qid].append(res>0.5) 234 | # We compute query-wise mean acc, and then average them 235 | # this is a trick to handle the drop_last=False issue 236 | for qid, vlist in q2results.items(): 237 | bench = qid.split('-')[0] 238 | macc = np.mean(vlist) 239 | results_each[bench].append(macc) 240 | all_results = [] 241 | dump_info = [] 242 | modelpath = args.pretrain 243 | for k in results_each.keys(): 244 | nc = np.sum(results_each[k]) 245 | num = len(results_each[k]) 246 | dump_info.append(dict(benchname=k, pass1=nc/num, ncorrect=nc, ntotal=num, modelpath=modelpath)) 247 | print(f'!!!! [eval] from disk bench={k}, acc={np.mean(results_each[k])}={nc}/{num}') 248 | all_results.extend(results_each[k]) 249 | results_each[k] = np.mean(results_each[k]) 250 | 251 | json.dump(dump_info, open(f'{args.ckpt_path}/logs/metrics_iter{self.eval_step}.json', 'w')) 252 | acc = np.mean(all_results) 253 | return acc, results_each 254 | 255 | def fill_replay_buffer(self, buffer, num_expected): 256 | # Ensure every item in buffer appears at least once 257 | for item in buffer[:num_expected]: 258 | self.replay_buffer.append_split(item) 259 | 260 | # Fill the remaining slots with random choices from buffer 261 | remaining_slots = num_expected - len(buffer) 262 | if remaining_slots>0: 263 | for _ in range(remaining_slots): 264 | item = random.choice(buffer) 265 | self.replay_buffer.append_split(item) 266 | print(f'!!!! rbuffersize after filling: {len(self.replay_buffer)} should be {num_expected} x nsamples_per_query', ) 267 | # assert len(self.replay_buffer)==num_expected 268 | 269 | def get_explist_from_prompts(self, args, ep, all_prompts, append=False, is_eval=False, force_noprefix=False, eval_step=None): 270 | autocode = getattr(args, "prefix_generation", None) 271 | requires_group = getattr(args, "advantage_estimator", None) in [''] 272 | # print('!!!! 
requires group', requires_group) 273 | generate_kwargs = copy.copy(self.generate_kwargs) 274 | generate_kwargs['requires_group'] = requires_group 275 | if force_noprefix: 276 | pass 277 | elif autocode=='autocode': 278 | if ep==0: 279 | new_prompts = [x+prefix for x in all_prompts for prefix in self.prefixes[:2]] 280 | all_prompts = new_prompts 281 | generate_kwargs['prefix_lengths'] = [plen for x in all_prompts for plen in self.prefix_lengths[:2]] 282 | else: 283 | new_prompts = [x+prefix for x in all_prompts for prefix in self.prefixes[2:3]] 284 | all_prompts = new_prompts 285 | generate_kwargs['prefix_lengths'] = [plen for x in all_prompts for plen in self.prefix_lengths[2:3]] 286 | elif autocode=='autocode1': 287 | # if ep==0: 288 | new_prompts = [x+prefix for x in all_prompts for prefix in self.prefixes[:2]] 289 | all_prompts = new_prompts 290 | generate_kwargs['prefix_lengths'] = [plen for x in all_prompts for plen in self.prefix_lengths[:2]] 291 | # else: 292 | # new_prompts = [x+prefix for x in all_prompts for prefix in self.prefixes[2:3]] 293 | # all_prompts = new_prompts 294 | # generate_kwargs['prefix_lengths'] = [plen for x in all_prompts for plen in self.prefix_lengths[2:3]] 295 | elif autocode=='autocode2': 296 | # if ep==0: 297 | new_prompts = [x+prefix for x in all_prompts for prefix in self.prefixes[:3]] 298 | all_prompts = new_prompts 299 | generate_kwargs['prefix_lengths'] = [plen for x in all_prompts for plen in self.prefix_lengths[:3]] 300 | elif autocode=='autocode_continue': 301 | # if ep==0: 302 | new_prompts = [x+prefix for x in all_prompts for prefix in self.prefixes[3:5]] 303 | all_prompts = new_prompts 304 | generate_kwargs['prefix_lengths'] = [plen for x in all_prompts for plen in self.prefix_lengths[3:5]] 305 | elif append and autocode=="autocode_append": 306 | new_prompts = [x+prefix for x in all_prompts for prefix in self.prefixes[5:6]] 307 | all_prompts = new_prompts 308 | generate_kwargs['prefix_lengths'] = [plen for x in all_prompts for plen in self.prefix_lengths[5:6]] 309 | 310 | return self.experience_maker.make_experience_list(all_prompts, is_eval=is_eval, eval_step=eval_step, **generate_kwargs) 311 | 312 | 313 | def evaluate( 314 | self, 315 | args, 316 | eval_data 317 | ) -> None: 318 | 319 | tmp = eval_data 320 | eval_bsz = args.micro_rollout_batch_size 321 | eval_dataloader = self.strategy.setup_dataloader( 322 | tmp, 323 | eval_bsz, # should larger than world size? 324 | True, 325 | True, 326 | drop_last=False 327 | ) 328 | print(f'!!!! eval dataloader size', len(eval_dataloader), 'eval_bsz', eval_bsz) 329 | self.eval_dataloader = eval_dataloader 330 | if len(eval_data)==0 or len(eval_dataloader)==0: print('!!!! no eval data, eval_data should be larger than num_vllm * micro_bsz', len(eval_data), len(eval_dataloader)) 331 | else: print(f'!!!! eval data {len(eval_data)} eval dataloader', len(eval_dataloader), args.micro_rollout_batch_size) 332 | info = self.eval_unit(args, 0, self.eval_step, eval_dataloader) 333 | eval_result = info['match'] 334 | torch.distributed.barrier() 335 | result2, bench_results = self.get_eval_result_from_disk() 336 | print(f'!!!! 
[eval] finish with step {self.eval_step} rank {self.strategy.get_rank()} gathered eval stats', info, 'from disk:', result2) 337 | 338 | self.eval_step += 1 339 | # info['match_overall'] = result2 340 | for k,v in bench_results.items(): 341 | info[f'match_{k}'] = v 342 | info['match_overall'] = result2 343 | eval_save = self.best<=result2 # and args.rollout_batch_size>16 344 | if eval_save: 345 | self.best = result2 346 | print(f"!!!! [eval] saving with average score {self.best}") 347 | 348 | del eval_dataloader 349 | self.replay_buffer.eval_items.clear() 350 | 351 | -------------------------------------------------------------------------------- /openrlhf/trainer/ppo_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .experience_maker import Experience, NaiveExperienceMaker, RemoteExperienceMaker 2 | from .kl_controller import AdaptiveKLController, FixedKLController 3 | from .replay_buffer import NaiveReplayBuffer 4 | from .data_processor import BaseDataProcessor, DATA_PROCESSOR_MAP 5 | 6 | __all__ = [ 7 | "Experience", 8 | "NaiveExperienceMaker", 9 | "RemoteExperienceMaker", 10 | "AdaptiveKLController", 11 | "FixedKLController", 12 | "NaiveReplayBuffer", 13 | ] 14 | -------------------------------------------------------------------------------- /openrlhf/trainer/ppo_utils/data_processor.py: -------------------------------------------------------------------------------- 1 | # /* 2 | # * Modified by Haozhe Wang in 2025 3 | # * 4 | # * Licensed under the Apache License, Version 2.0 (the "License"); 5 | # */ 6 | import json 7 | import os 8 | from abc import ABC, abstractmethod 9 | from typing import List, Optional, Union, Dict 10 | 11 | import torch 12 | from qwen_vl_utils import process_vision_info 13 | from transformers import Qwen2VLProcessor 14 | from transformers.processing_utils import ProcessorMixin 15 | try: 16 | from transformers import Qwen2_5_VLProcessor 17 | except Exception as e: 18 | print("Qocal Qwen2_5_VLProcessor not found") 19 | 20 | # https://github.com/huggingface/transformers/blob/main/docs/source/en/model_doc/qwen2_5_vl.md 21 | class BaseDataProcessor(ABC): 22 | def __init__(self, processor: ProcessorMixin): 23 | super().__init__() 24 | self.processor = processor 25 | 26 | @abstractmethod 27 | def __call__( 28 | self, 29 | messages: Union[Dict, List[str], str], 30 | max_length: int, 31 | padding: bool = True, 32 | device: Optional[Union[str, torch.device]] = None, 33 | return_tensors: Optional[str] = "pt", 34 | add_special_tokens: Optional[bool] = False, 35 | truncation: Optional[bool] = True, 36 | ) -> Dict: 37 | raise NotImplementedError 38 | 39 | @abstractmethod 40 | def make_input_batch(self, inputs: List[Dict]) -> Dict: 41 | raise NotImplementedError 42 | 43 | @abstractmethod 44 | def split_input_batch(self, batch: Dict) -> List[Dict]: 45 | raise NotImplementedError 46 | 47 | def _format_messages(self, messages: Union[Dict, List[str], str]) -> List[Dict]: 48 | if isinstance(messages, list) and isinstance(messages[0], str): 49 | return [json.loads(m) for m in messages] 50 | elif isinstance(messages, str): 51 | return [json.loads(messages)] 52 | elif isinstance(messages, dict): 53 | return [messages] 54 | else: 55 | raise ValueError("Invalid messages format, must be a list of strings or a string or a dict") 56 | 57 | def apply_chat_template( 58 | self, 59 | messages: Union[Dict, List[str], str], 60 | tokenize: bool = False, 61 | add_generation_prompt: bool = True, 62 | ) -> List[str]: 63 | messages = 
self._format_messages(messages) 64 | 65 | return self.processor.apply_chat_template( 66 | messages, tokenize=tokenize, add_generation_prompt=add_generation_prompt 67 | ) 68 | 69 | def get_images_from_messages( 70 | self, messages: Union[Dict, List[str], str] 71 | ) -> List[Dict]: 72 | messages = self._format_messages(messages) 73 | return self._get_images_from_messages(messages) 74 | 75 | @abstractmethod 76 | def _get_images_from_messages(self, messages: List[Dict]) -> List[Dict]: 77 | raise NotImplementedError 78 | 79 | @property 80 | def pad_token_id(self) -> int: 81 | return self.processor.tokenizer.pad_token_id 82 | 83 | @property 84 | def eos_token_id(self) -> int: 85 | return self.processor.tokenizer.eos_token_id 86 | 87 | @property 88 | def tokenizer(self): 89 | return self.processor.tokenizer 90 | 91 | 92 | def add_pixel_bounds(messages): 93 | # 默认的像素范围 94 | DEFAULT_MIN_PIXELS = int(os.getenv("MIN_PIXELS", 256 * 28 * 28)) 95 | DEFAULT_MAX_PIXELS = int(os.getenv("MAX_PIXELS", 1280 * 28 * 28)) 96 | 97 | def process_content(content): 98 | if isinstance(content, list): 99 | for item in content: 100 | if isinstance(item, dict) and item.get("type") == "image": 101 | if "min_pixels" not in item: 102 | item["min_pixels"] = DEFAULT_MIN_PIXELS 103 | if "max_pixels" not in item: 104 | item["max_pixels"] = DEFAULT_MAX_PIXELS 105 | return content 106 | 107 | for message in messages: 108 | for msg in message: 109 | msg["content"] = process_content(msg["content"]) 110 | return messages 111 | 112 | def remove_except_last(text, tag): 113 | cnt = text.count(tag) 114 | if cnt>1: 115 | index = text.rfind(tag) 116 | return text[:index].replace(tag, "")+text[index:] 117 | else: return text 118 | 119 | def find_rank_occurrence(ids, target, rank): 120 | """ 121 | Finds the position (index) of the rank-th occurrence of the target in the list ids. 122 | 123 | Args: 124 | ids (list): List of integers to search through. 125 | target (int): Integer to find. 126 | rank (int): The occurrence number to locate (1-based). 127 | 128 | Returns: 129 | int: Index of the rank-th occurrence, or -1 if it doesn’t exist. 
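
    Example:
        >>> find_rank_occurrence([7, 3, 7, 7], target=7, rank=2)
        2
        >>> find_rank_occurrence([7, 3, 7, 7], target=5, rank=1)
        -1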
130 | """ 131 | count = 0 132 | for i, val in enumerate(ids): 133 | if val == target: 134 | count += 1 135 | if count == rank: 136 | return i 137 | return -1 138 | 139 | class Qwen2VLDataProcessor(BaseDataProcessor): 140 | def __call__( 141 | self, 142 | messages, 143 | max_length, 144 | padding=True, 145 | device=None, 146 | return_tensors="pt", 147 | add_special_tokens=False, 148 | truncation=True, 149 | ) -> Dict: 150 | 151 | # messages = newlist 152 | messages = self._format_messages(messages) # list of dicts 153 | processor = self.processor 154 | # for entry in messages: 155 | # if entry['role'] == 'user': 156 | # content = entry['content'][-1]['text'] 157 | # if "" in content: 158 | # content = content.replace("", "<|vision_start|><|image_pad|><|vision_end|>") 159 | # entry['content'][-1]['text'] = content 160 | 161 | texts = processor.apply_chat_template( 162 | messages, tokenize=False, add_generation_prompt=True 163 | ) 164 | texts = self.handle_placeholders(texts) 165 | messages = add_pixel_bounds(messages) 166 | image_inputs, video_inputs = process_vision_info(messages) 167 | # print(texts) 168 | max_length = 10240 # we need to make sure it does not trucate 169 | batch = processor( 170 | text=texts, 171 | images=image_inputs, 172 | videos=video_inputs, 173 | padding=padding, 174 | max_length=max_length, 175 | add_special_tokens=False, 176 | truncation=truncation, 177 | return_tensors=return_tensors, 178 | ) 179 | if device: 180 | return {k: v.to(device) for k, v in batch.items()} 181 | return {k: v for k, v in batch.items()} 182 | 183 | def handle_placeholders(self, texts): 184 | newlist = [] 185 | placeholder = "" 186 | # placeholder2 = "" 187 | replacewith = "<|vision_start|><|image_pad|><|vision_end|>" 188 | for m in texts: 189 | new = m 190 | for k in ["<|vision_start|>","<|image_pad|>","<|vision_end|>"]: 191 | new = new.replace(k,"") 192 | # now new has no replacewith 193 | if new.count(placeholder)>0: 194 | new = new.replace(placeholder, replacewith) 195 | else: 196 | new = replacewith + new 197 | newlist.append(new) 198 | return newlist 199 | 200 | def make_input_batch(self, inputs: List[Dict]) -> Dict: 201 | # each element has no batch dimension 202 | batch = {k: None for k in inputs[0].keys()} 203 | for k in batch.keys(): 204 | if k in ["input_ids", "attention_mask"]: 205 | batch[k] = torch.stack([inp[k] for inp in inputs], dim=0) 206 | elif k in ["pixel_values", "image_grid_thw"]: 207 | # qwen2vl concat all patches of all images in a batch in the first dimension 208 | batch[k] = torch.cat([inp[k] for inp in inputs], dim=0) 209 | else: 210 | raise ValueError(f"Unknown key {k} for Qwen2VLDataProcessor") 211 | return batch 212 | 213 | def split_input_batch(self, batch: Dict) -> List[Dict]: 214 | batch_size = len(batch["input_ids"]) 215 | batch_kwargs = [{} for _ in range(batch_size)] 216 | # first process None values 217 | keys = [] 218 | for k, v in batch.items(): 219 | if v is not None: 220 | keys.append(k) 221 | else: 222 | for i in range(batch_size): 223 | batch_kwargs[i][k] = None 224 | 225 | if "pixel_values" in keys and ( 226 | "input_ids" not in keys or "image_grid_thw" not in keys 227 | ): 228 | raise ValueError( 229 | "Cannot split batch with pixel_values without input_ids and image_grid_thw" 230 | ) 231 | if "image_grid_thw" in keys and ("input_ids" not in keys): 232 | raise ValueError("Cannot split batch with image_grid_thw without input_ids") 233 | for k in ["input_ids", "attention_mask"]: 234 | if k in keys: 235 | vals = batch[k] 236 | if isinstance(vals, 
torch.Tensor): 237 | vals = torch.unbind(vals) 238 | assert batch_size == len(vals) 239 | for i, v in enumerate(vals): 240 | batch_kwargs[i][k] = v 241 | if "pixel_values" in keys: 242 | thws = batch["image_grid_thw"] # (total_img_num, (t,h,w)) 243 | pixel_values = batch["pixel_values"] 244 | vision_start_id = self.processor.tokenizer("<|vision_start|>")["input_ids"][0] 245 | vision_end_id = self.processor.tokenizer("<|vision_end|>")["input_ids"][0] 246 | img_idx = 0 247 | patch_idx = 0 248 | for i in range(batch_size): 249 | input_ids_i = batch_kwargs[i]["input_ids"] 250 | if not isinstance(input_ids_i, torch.Tensor): 251 | input_ids_i = torch.tensor(input_ids_i) 252 | vision_start_num = (input_ids_i == vision_start_id).sum().item() 253 | vision_end_num = (input_ids_i == vision_end_id).sum().item() 254 | 255 | img_num = vision_end_num 256 | if img_num == 0: 257 | batch_kwargs[i]["pixel_values"] = None 258 | batch_kwargs[i]["image_grid_thw"] = None 259 | continue 260 | thws_i = thws[img_idx:img_num+img_idx] 261 | img_idx += img_num 262 | flag = False 263 | if len(thws_i) != img_num: 264 | thws_i = thws[-img_num:] 265 | print(f'[warning] the image_grid_thw does not match, this is polluted data, attempting: {len(thws_i)} vs {img_num}') 266 | flag = True 267 | # thws = thws[img_num:] 268 | if not isinstance(thws_i, torch.Tensor): 269 | thws_i = torch.stack(thws_i) 270 | batch_kwargs[i]["image_grid_thw"] = thws_i 271 | patchs_num = thws_i.prod(dim=1).sum().item() 272 | pixel_values_i = pixel_values[patch_idx:patchs_num+patch_idx] 273 | if len(pixel_values_i) != patchs_num: 274 | pixel_values_i = pixel_values[-patchs_num:] 275 | print(f'[warning] the pixel_values_i does not match, this is polluted data, attempting: {patchs_num} in {len(pixel_values)} resulting in {len(pixel_values_i)}') 276 | flag = True 277 | # assert len(pixel_values_i) == patchs_num 278 | # pixel_values = pixel_values[patch_idx:patchs_num+patch_idx] 279 | batch_kwargs[i]["pixel_values"] = pixel_values_i 280 | if flag: 281 | batch_kwargs[i] = None 282 | print('[truncation warning] appears a sample has mismatched vision_start and vision_end, likely due to garbage outputs, its current length is ', len(input_ids_i)) 283 | # print(input_ids_i.detach().cpu().numpy().tolist()) 284 | error_index = find_rank_occurrence(input_ids_i.detach().cpu().numpy().tolist(), vision_start_id, 1) 285 | input_ids_i[error_index:] = self.eos_token_id # how about directly before the vision start? 
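                    # Mismatch recovery: the sample was marked invalid (batch_kwargs[i] = None
                    # above) and every token from the first <|vision_start|> onward is
                    # overwritten with the EOS id, truncating the corrupted image span.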
286 | continue 287 | # assert len(thws) == 0 288 | # assert len(pixel_values) == 0 289 | return batch_kwargs 290 | 291 | def _get_images_from_messages(self, messages: List[Dict]) -> List[Dict]: 292 | messages = add_pixel_bounds(messages) 293 | image_inputs, _ = process_vision_info(messages) 294 | return image_inputs 295 | 296 | 297 | DATA_PROCESSOR_MAP = { 298 | Qwen2VLProcessor: Qwen2VLDataProcessor, 299 | Qwen2_5_VLProcessor: Qwen2VLDataProcessor, 300 | } 301 | -------------------------------------------------------------------------------- /openrlhf/trainer/ppo_utils/kl_controller.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class AdaptiveKLController: 5 | """ 6 | Adaptive KL controller described in the paper: 7 | https://arxiv.org/pdf/1909.08593.pdf 8 | """ 9 | 10 | def __init__(self, init_kl_coef, target, horizon): 11 | self.value = init_kl_coef 12 | self.target = target 13 | self.horizon = horizon 14 | 15 | def update(self, current, n_steps): 16 | target = self.target 17 | proportional_error = np.clip(current / target - 1, -0.2, 0.2) 18 | mult = 1 + proportional_error * n_steps / self.horizon 19 | self.value *= mult 20 | 21 | 22 | class FixedKLController: 23 | """Fixed KL controller.""" 24 | 25 | def __init__(self, kl_coef): 26 | self.value = kl_coef 27 | 28 | def update(self, current, n_steps): 29 | pass 30 | -------------------------------------------------------------------------------- /openrlhf/trainer/ray/__init__.py: -------------------------------------------------------------------------------- 1 | from .launcher import DistributedTorchRayActor, PPORayActorGroup, ReferenceModelRayActor, RewardModelRayActor 2 | from .ppo_actor import ActorModelRayActor 3 | from .ppo_critic import CriticModelRayActor 4 | from .vllm_engine import create_vllm_engines 5 | from .evaluator2 import Evaluator2 6 | 7 | __all__ = [ 8 | "DistributedTorchRayActor", 9 | "PPORayActorGroup", 10 | "ReferenceModelRayActor", 11 | "RewardModelRayActor", 12 | "ActorModelRayActor", 13 | "CriticModelRayActor", 14 | "create_vllm_engines", 15 | "Evaluator2" 16 | ] 17 | -------------------------------------------------------------------------------- /openrlhf/trainer/ray/launcher.py: -------------------------------------------------------------------------------- 1 | # /* 2 | # * Modified by Haozhe Wang in 2025 3 | # * 4 | # * Licensed under the Apache License, Version 2.0 (the "License"); 5 | # */ 6 | import logging 7 | import os 8 | import socket 9 | from typing import Callable, Dict, List, Optional, Type 10 | 11 | import ray 12 | import torch 13 | from ray.util.placement_group import PlacementGroup, placement_group 14 | from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy 15 | 16 | from openrlhf.models import Actor, get_llm_for_sequence_regression 17 | from openrlhf.trainer.ray.utils import ray_noset_visible_devices 18 | from openrlhf.utils.deepspeed import DeepspeedStrategy 19 | 20 | 21 | class DistributedTorchRayActor: 22 | def __init__(self, world_size, rank, master_addr, master_port): 23 | logging.basicConfig( 24 | format="%(asctime)s %(levelname)-8s %(message)s", 25 | level=logging.INFO, 26 | datefmt="%Y-%m-%d %H:%M:%S", 27 | ) 28 | self._world_size = world_size 29 | self._rank = rank 30 | self._master_addr = master_addr if master_addr else self._get_current_node_ip() 31 | self._master_port = master_port if master_port else self._get_free_port() 32 | os.environ["MASTER_ADDR"] = self._master_addr 33 | 
os.environ["MASTER_PORT"] = str(self._master_port) 34 | os.environ["WORLD_SIZE"] = str(self._world_size) 35 | os.environ["RANK"] = str(self._rank) 36 | # NOTE: Ray will automatically set the *_VISIBLE_DEVICES 37 | # environment variable for each actor, unless 38 | # RAY_EXPERIMENTAL_NOSET_*_VISIBLE_DEVICES is set, so 39 | # set local rank to 0 when the flag is not applicable. 40 | os.environ["LOCAL_RANK"] = str(ray.get_gpu_ids()[0]) if ray_noset_visible_devices() else "0" 41 | 42 | @staticmethod 43 | def _get_current_node_ip(): 44 | address = ray._private.services.get_node_ip_address() 45 | # strip ipv6 address 46 | return address.strip("[]") 47 | 48 | @staticmethod 49 | def _get_free_port(): 50 | with socket.socket() as sock: 51 | sock.bind(("", 0)) 52 | return sock.getsockname()[1] 53 | 54 | def get_master_addr_port(self): 55 | return self._master_addr, self._master_port 56 | 57 | 58 | class BasePPORole(DistributedTorchRayActor): 59 | def _setup_distributed(self, strategy: DeepspeedStrategy): 60 | # configure strategy 61 | self.strategy = strategy 62 | strategy.setup_distributed() 63 | 64 | def init_model_from_pretrained(self, *args, **kwargs): 65 | raise NotImplementedError() 66 | 67 | 68 | @ray.remote(num_gpus=1) 69 | class ReferenceModelRayActor(BasePPORole): 70 | def init_model_from_pretrained(self, strategy: DeepspeedStrategy, pretrain): 71 | self._setup_distributed(strategy) 72 | model = Actor( 73 | pretrain, 74 | use_flash_attention_2=strategy.args.flash_attn, 75 | bf16=strategy.args.bf16, 76 | load_in_4bit=strategy.args.load_in_4bit, 77 | ds_config=strategy.get_ds_eval_config(offload=strategy.args.ref_reward_offload), 78 | packing_samples=strategy.args.packing_samples, 79 | ) 80 | strategy.print(model) 81 | 82 | if strategy.args.ref_reward_offload: 83 | model._offload = True 84 | 85 | self.model = self.strategy.prepare(model, is_rlhf=True) 86 | self.model.eval() 87 | 88 | def forward( 89 | self, 90 | sequences: torch.LongTensor, 91 | num_actions: int = None, 92 | attention_mask: Optional[torch.Tensor] = None, 93 | return_output=False, 94 | packed_seq_lens: Optional[list[int]] = None, 95 | visual_inputs: Optional[dict] = None, 96 | ) -> torch.Tensor: 97 | if visual_inputs is None: 98 | visual_inputs = {} 99 | device = torch.cuda.current_device() 100 | with torch.no_grad(): 101 | visual_inputs = {k:v.to(device) for k,v in visual_inputs.items()} 102 | log_probs = self.model( 103 | sequences.to(device), 104 | num_actions, 105 | attention_mask.to(device), 106 | return_output=return_output, 107 | packed_seq_lens=packed_seq_lens, 108 | visual_inputs=visual_inputs, 109 | ) 110 | return log_probs.to("cpu") 111 | 112 | def empty_cache(self) -> None: 113 | torch.cuda.empty_cache() 114 | 115 | 116 | @ray.remote(num_gpus=1) 117 | class RewardModelRayActor(BasePPORole): 118 | def init_model_from_pretrained(self, strategy: DeepspeedStrategy, pretrain): 119 | self._setup_distributed(strategy) 120 | model = get_llm_for_sequence_regression( 121 | pretrain, 122 | "reward", 123 | normalize_reward=strategy.args.normalize_reward, 124 | use_flash_attention_2=strategy.args.flash_attn, 125 | bf16=strategy.args.bf16, 126 | load_in_4bit=strategy.args.load_in_4bit, 127 | ds_config=strategy.get_ds_eval_config(offload=strategy.args.ref_reward_offload), 128 | value_head_prefix=strategy.args.value_head_prefix, 129 | packing_samples=strategy.args.packing_samples, 130 | ) 131 | strategy.print(model) 132 | strategy.print("reward normalization status: {}".format(strategy.args.normalize_reward)) 133 | 
strategy.print("mean: {}, std {}".format(model.mean, model.std)) 134 | 135 | if strategy.args.ref_reward_offload: 136 | model._offload = True 137 | 138 | self.model = self.strategy.prepare(model, is_rlhf=True) 139 | self.model.eval() 140 | 141 | def forward( 142 | self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None, packed_seq_lens=None, visual_inputs: Optional[dict] = None, 143 | ) -> torch.Tensor: 144 | device = torch.cuda.current_device() 145 | if visual_inputs is None: 146 | visual_inputs = {} 147 | visual_inputs = {k:v.to(device) for k,v in visual_inputs.items()} 148 | with torch.no_grad(): 149 | reward = self.model(sequences.to(device), attention_mask.to(device), packed_seq_lens=packed_seq_lens, visual_inputs=visual_inputs) 150 | return reward.to("cpu") 151 | 152 | def empty_cache(self) -> None: 153 | torch.cuda.empty_cache() 154 | 155 | 156 | class PPORayActorGroup: 157 | """ 158 | A group of ray actors 159 | Functions start with 'async' should return list of object refs 160 | 161 | Args: 162 | num_nodes (int): Number of nodes for this actor group. 163 | num_gpus_per_node (int): Number of gpus for this actor group. 164 | ray_actor_type (Type[BasePPORole]): PPO model type that this actor group serve on. 165 | pg (PlacementGroup, optional): Placement group to schedule actor on. 166 | If none, create new placement group automatically. Defaults to None. 167 | num_gpus_per_actor (float, optional): Number of gpus allocated for each actor. 168 | If < 1.0, multiple models can share same gpu. Defaults to 1. 169 | """ 170 | 171 | def __init__( 172 | self, 173 | num_nodes, 174 | num_gpus_per_node, 175 | ray_actor_type: Type[BasePPORole], 176 | pg: PlacementGroup = None, 177 | num_gpus_per_actor=1, 178 | resources: Dict[str, float] = None, 179 | num_resources_per_node: int = None, 180 | ) -> None: 181 | self._num_nodes = num_nodes 182 | self._num_gpus_per_node = num_gpus_per_node 183 | self.ray_actor_type = ray_actor_type 184 | 185 | # custom resources, see https://docs.ray.io/en/latest/ray-core/scheduling/resources.html 186 | self._resources = resources 187 | self._num_resources_per_node = num_resources_per_node 188 | 189 | self._initiate_actors(pg, num_gpus_per_actor) 190 | 191 | def _initiate_actors(self, pg, num_gpus_per_actor): 192 | world_size = self._num_nodes * self._num_gpus_per_node 193 | print(f'!!!! [config] worldsize={world_size}, num_nodes={self._num_nodes}, num_gpus_per_node={self._num_gpus_per_node}, placementgroup={pg}') 194 | # Use placement group to lock resources for models of same type 195 | if self._num_gpus_per_node > 1 and pg is None: 196 | bundles = [{"GPU": 1, "CPU": 1} for _ in range(self._num_nodes * self._num_gpus_per_node)] 197 | if self._resources: 198 | resources_name = list(self._resources.keys())[0] 199 | for i in range(len(bundles)): 200 | bundles[i][resources_name] = self._num_resources_per_node 201 | 202 | pg = placement_group(bundles, strategy="PACK") 203 | ray.get(pg.ready()) 204 | if pg: 205 | print(f'!!!! [config] worldsize={world_size}, num_nodes={self._num_nodes}, num_gpus_per_node={self._num_gpus_per_node}, placementgroup={pg}, num_gpus_per_actor={num_gpus_per_actor}') 206 | master_actor = self.ray_actor_type.options( 207 | num_cpus=num_gpus_per_actor, 208 | num_gpus=num_gpus_per_actor, 209 | resources=self._resources, 210 | scheduling_strategy=PlacementGroupSchedulingStrategy( 211 | placement_group=pg, placement_group_bundle_index=0 212 | ), 213 | ).remote(world_size, 0, None, None) 214 | else: 215 | print(f'!!!! 
[config] worldsize={world_size}, num_nodes={self._num_nodes}, num_gpus_per_node={self._num_gpus_per_node}, placementgroup={pg}, num_gpus_per_actor={num_gpus_per_actor}') 216 | master_actor = self.ray_actor_type.options( 217 | num_cpus=num_gpus_per_actor, 218 | num_gpus=num_gpus_per_actor, 219 | resources=self._resources, 220 | ).remote(world_size, 0, None, None) 221 | self._actor_handlers = [master_actor] 222 | 223 | # Create worker actors 224 | if world_size > 1: 225 | master_addr, master_port = ray.get(master_actor.get_master_addr_port.remote()) 226 | for rank in range(1, world_size): 227 | if pg: 228 | worker_actor = self.ray_actor_type.options( 229 | num_cpus=num_gpus_per_actor, 230 | num_gpus=num_gpus_per_actor, 231 | resources=self._resources, 232 | scheduling_strategy=PlacementGroupSchedulingStrategy( 233 | placement_group=pg, 234 | placement_group_bundle_index=rank, 235 | ), 236 | ).remote(world_size, rank, master_addr, master_port) 237 | else: 238 | worker_actor = self.ray_actor_type.options( 239 | num_cpus=num_gpus_per_actor, 240 | num_gpus=num_gpus_per_actor, 241 | resources=self._resources, 242 | ).remote(world_size, rank, master_addr, master_port) 243 | self._actor_handlers.append(worker_actor) 244 | 245 | def async_init_model_from_pretrained( 246 | self, 247 | *args, 248 | **kwargs, 249 | ): 250 | """Init model from pretrained checkpoint. 251 | 252 | Returns: 253 | List: list of remote object refs. 254 | """ 255 | return [actor.init_model_from_pretrained.remote(*args, **kwargs) for actor in self._actor_handlers] 256 | 257 | def async_fit_actor_model( 258 | self, 259 | critic_model_group: "PPORayActorGroup", 260 | initial_model_group: "PPORayActorGroup", 261 | reward_model_groups: List["PPORayActorGroup"], 262 | remote_rm_urls: List[str] = None, 263 | reward_fn: Callable[[List[torch.Tensor]], torch.Tensor] = None, 264 | vllm_engines: List = None, 265 | ): 266 | """Train actor model. 267 | 268 | Args: 269 | critic_model_group (PPORayActorGroup): critic model group. 270 | initial_model_group (PPORayActorGroup): reference model group. 271 | reward_model_groups (PPORayActorGroup): reward model groups. 272 | remote_rm_urls: remote RM APIs. 273 | reward_fn: reward calculate function, must be specified if using multiple reward models. 274 | vllm_engines: vllm engines for text generation, if not specified, generate text by actor model directly. 275 | 276 | Returns: 277 | List: list of remote object refs. 278 | """ 279 | assert ( 280 | (remote_rm_urls and len(remote_rm_urls) == 1) 281 | or (reward_model_groups and len(reward_model_groups) == 1) 282 | or reward_fn is not None 283 | ), "reward_fn must be specified if using multiple reward models" 284 | 285 | critic_actors = critic_model_group._actor_handlers if critic_model_group else None 286 | initial_actors = initial_model_group._actor_handlers if initial_model_group else None 287 | 288 | refs = [] 289 | # TODO(wuxibin): actor model choose critic/reward/initial model in a 290 | # round robin fashion, implement more efficient dispatching strategy. 
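        # Illustration of the current dispatch: with 4 actor handlers and 2 critic
        # actors, `i % len(critic_actors)` pairs actors 0..3 with critics 0, 1, 0, 1;
        # the same modular indexing is applied to the initial and reward model groups.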
291 | for i, actor in enumerate(self._actor_handlers): 292 | critic_actor = critic_actors[i % len(critic_actors)] if critic_actors else None 293 | initial_actor = initial_actors[i % len(initial_actors)] if initial_actors else None 294 | 295 | reward_actors = [] 296 | if reward_model_groups: 297 | for reward_model_group in reward_model_groups: 298 | actors = reward_model_group._actor_handlers 299 | reward_actors.append(actors[i % len(actors)]) 300 | 301 | refs.append( 302 | actor.fit.remote( 303 | critic_model=critic_actor, 304 | initial_model=initial_actor, 305 | reward_model=reward_actors, 306 | remote_rm_url=remote_rm_urls, 307 | reward_fn=reward_fn, 308 | vllm_engines=vllm_engines, 309 | # whether this actor should triger corresponding critic model training 310 | critic_train_remote=(i < len(critic_actors)) if critic_actor else None, 311 | ) 312 | ) 313 | 314 | return refs 315 | 316 | 317 | def async_save_model(self): 318 | """Save actor model on rank 0. 319 | 320 | Returns: 321 | List: list of remote object refs. 322 | """ 323 | return [actor.save_model.remote() for actor in self._actor_handlers] 324 | 325 | def async_run_method(self, method_name, *args, **kwargs): 326 | refs = [] 327 | for actor in self._actor_handlers: 328 | method = getattr(actor, method_name) 329 | refs.append(method.remote(*args, **kwargs)) 330 | return refs 331 | -------------------------------------------------------------------------------- /openrlhf/trainer/ray/ppo_critic.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | from typing import Dict, Optional, Union 4 | 5 | import ray 6 | import torch 7 | from torch.utils.data import DataLoader 8 | from tqdm import tqdm 9 | from transformers.trainer import get_scheduler 10 | 11 | from openrlhf.models import get_llm_for_sequence_regression 12 | from openrlhf.trainer import PPOTrainer 13 | from openrlhf.trainer.ppo_utils import Experience 14 | from openrlhf.utils import get_tokenizer, get_vl_processor 15 | from openrlhf.utils.deepspeed import DeepspeedStrategy 16 | 17 | from .launcher import BasePPORole 18 | 19 | 20 | class CriticPPOTrainer(PPOTrainer): 21 | def ppo_train(self): 22 | # replay buffer may be empty at first, we should rebuild at each training 23 | dataloader = DataLoader( 24 | self.replay_buffer, 25 | batch_size=self.replay_buffer.sample_batch_size, 26 | shuffle=True, 27 | drop_last=True, 28 | pin_memory=self.dataloader_pin_memory, 29 | collate_fn=self.replay_buffer.collate_fn, 30 | ) 31 | device = torch.cuda.current_device() 32 | 33 | status_list = [] 34 | status_mean = {} 35 | for epoch in range(self.max_epochs): 36 | pbar = tqdm( 37 | dataloader, 38 | desc=f"Train epoch [{epoch + 1}/{self.max_epochs}]", 39 | disable=not self.strategy.is_rank_0(), 40 | ) 41 | for experience in pbar: 42 | experience.to_device(device) 43 | status = self.training_step(experience) 44 | 45 | # for DP 46 | status = self.strategy.all_reduce(status) 47 | 48 | status_list.append(status) 49 | pbar.set_postfix(status) 50 | 51 | if status_list: 52 | status_mean = status_list[0] 53 | for m in status_list[1:]: 54 | for k, v in m.items(): 55 | status_mean[k] += v 56 | for k in status_mean.keys(): 57 | status_mean[k] /= len(status_list) 58 | return status_mean 59 | 60 | def training_step(self, experience: Experience) -> Dict[str, float]: 61 | return self.training_step_critic(experience) 62 | 63 | 64 | @ray.remote(num_gpus=1) 65 | class CriticModelRayActor(BasePPORole): 66 | def init_model_from_pretrained(self, 
strategy: DeepspeedStrategy, pretrain, max_steps): 67 | args = strategy.args 68 | 69 | self._setup_distributed(strategy) 70 | critic = get_llm_for_sequence_regression( 71 | pretrain, 72 | "critic", 73 | normalize_reward=strategy.args.normalize_reward, 74 | use_flash_attention_2=strategy.args.flash_attn, 75 | bf16=strategy.args.bf16, 76 | load_in_4bit=strategy.args.load_in_4bit, 77 | lora_rank=strategy.args.lora_rank, 78 | lora_alpha=strategy.args.lora_alpha, 79 | target_modules=strategy.args.target_modules, 80 | lora_dropout=strategy.args.lora_dropout, 81 | ds_config=strategy.get_ds_train_config(is_actor=False), 82 | value_head_prefix=strategy.args.value_head_prefix, 83 | init_value_head=strategy.args.pretrain == strategy.args.critic_pretrain, 84 | packing_samples=strategy.args.packing_samples, 85 | ) 86 | strategy.print(critic) 87 | strategy.print("reward normalization status: {}".format(strategy.args.normalize_reward)) 88 | strategy.print("mean: {}, std {}".format(critic.mean, critic.std)) 89 | 90 | # configure optimizer 91 | critic_optim = strategy.create_optimizer( 92 | critic, lr=args.critic_learning_rate, betas=args.adam_betas, weight_decay=args.l2 93 | ) 94 | 95 | # configure scheduler 96 | critic_scheduler = get_scheduler( 97 | "cosine_with_min_lr", 98 | critic_optim, 99 | num_warmup_steps=math.ceil(max_steps * args.lr_warmup_ratio), 100 | num_training_steps=max_steps, 101 | scheduler_specific_kwargs={"min_lr": args.critic_learning_rate * 0.1}, 102 | ) 103 | 104 | if args.gradient_checkpointing: 105 | critic.gradient_checkpointing_enable( 106 | gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant} 107 | ) 108 | 109 | # prepare models/optimizers... 110 | self.critic, self.critic_optim, self.critic_scheduler = strategy.prepare( 111 | (critic, critic_optim, critic_scheduler), 112 | is_rlhf=True, 113 | ) 114 | 115 | # load checkpoint 116 | if args.load_checkpoint and os.path.exists(os.path.join(args.ckpt_path, "_actor")): 117 | ckpt_path = os.path.join(args.ckpt_path, "_critic") 118 | strategy.load_ckpt(self.critic, ckpt_path) 119 | strategy.print(f"Loaded the checkpoint: {ckpt_path}") 120 | 121 | # configure Trainer 122 | # only use wandb at actor model 123 | strategy.args.use_wandb = False 124 | # configure tokenizer 125 | args = strategy.args 126 | if args.train_vlm: 127 | self.processor = get_vl_processor( 128 | pretrain, self.critic, "left", strategy, use_fast=not strategy.args.disable_fast_tokenizer 129 | ) 130 | self.tokenizer = self.processor.tokenizer 131 | else: 132 | self.processor = None 133 | self.tokenizer = get_tokenizer( 134 | pretrain, self.critic, "left", strategy, use_fast=not strategy.args.disable_fast_tokenizer 135 | ) 136 | self.trainer = CriticPPOTrainer( 137 | strategy, 138 | actor=None, 139 | critic=self.critic, 140 | reward_model=None, 141 | initial_model=None, 142 | ema_model=None, 143 | actor_optim=None, 144 | critic_optim=self.critic_optim, 145 | actor_scheduler=None, 146 | critic_scheduler=self.critic_scheduler, 147 | max_epochs=args.max_epochs, 148 | micro_train_batch_size=args.micro_train_batch_size, 149 | micro_rollout_batch_size=args.micro_rollout_batch_size, 150 | gradient_checkpointing=args.gradient_checkpointing, 151 | prompt_max_len=args.prompt_max_len, 152 | value_clip=args.value_clip, 153 | eps_clip=args.eps_clip, 154 | processor=self.processor, 155 | tokenizer=self.tokenizer 156 | ) 157 | 158 | def forward( 159 | self, 160 | sequences: torch.LongTensor, 161 | num_actions: Optional[Union[int, list[int]]] = 
None, 162 | attention_mask: Optional[torch.Tensor] = None, 163 | packed_seq_lens=None, 164 | visual_inputs=None, 165 | ) -> torch.Tensor: 166 | """Generates critic values.""" 167 | device = torch.cuda.current_device() 168 | self.critic.eval() 169 | if visual_inputs is None: 170 | visual_inputs = {} 171 | with torch.no_grad(): 172 | visual_inputs = {k: v.to(device) for k, v in visual_inputs.items()} 173 | value = self.critic( 174 | sequences.to(device), num_actions, attention_mask.to(device), packed_seq_lens=packed_seq_lens, visual_inputs=visual_inputs 175 | ) 176 | self.critic.train() # reset model state 177 | return value.to("cpu") 178 | 179 | def append(self, experience): 180 | """Append experience to replay buffer.""" 181 | self.trainer.replay_buffer.append(experience) 182 | 183 | def fit(self): 184 | """Train critic model with the replay buffer.""" 185 | torch.cuda.empty_cache() 186 | self.critic.train() 187 | status = self.trainer.ppo_train() 188 | self.trainer.replay_buffer.clear() 189 | torch.cuda.empty_cache() 190 | return status 191 | 192 | def empty_cache(self) -> None: 193 | torch.cuda.empty_cache() 194 | 195 | def save_model(self): 196 | args = self.strategy.args 197 | 198 | # save model checkpoint after fitting on only rank0 199 | if args.train_vlm: 200 | self.strategy.save_model( 201 | self.critic, 202 | self.processor, 203 | args.save_path + "_critic", 204 | ) 205 | else: 206 | self.strategy.save_model( 207 | self.critic, 208 | self.tokenizer, 209 | args.save_path + "_critic", 210 | ) 211 | 212 | def save_checkpoint(self, tag): 213 | args = self.strategy.args 214 | self.strategy.save_ckpt( 215 | self.critic, os.path.join(args.ckpt_path, "_critic"), tag, args.max_ckpt_num, args.max_ckpt_mem 216 | ) -------------------------------------------------------------------------------- /openrlhf/trainer/ray/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def ray_noset_visible_devices(env_vars=os.environ): 5 | # Refer to 6 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/nvidia_gpu.py#L95-L96 7 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/amd_gpu.py#L102-L103 8 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/npu.py#L94-L95 9 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/hpu.py#L116-L117 10 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/neuron.py#L108-L109 11 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/tpu.py#L171-L172 12 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/intel_gpu.py#L97-L98 13 | NOSET_VISIBLE_DEVICES_ENV_VARS_LIST = [ 14 | "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", 15 | "RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES", 16 | "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES", 17 | "RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES", 18 | "RAY_EXPERIMENTAL_NOSET_NEURON_RT_VISIBLE_CORES", 19 | "RAY_EXPERIMENTAL_NOSET_TPU_VISIBLE_CHIPS", 20 | "RAY_EXPERIMENTAL_NOSET_ONEAPI_DEVICE_SELECTOR", 21 | ] 22 | return any(env_vars.get(env_var) for env_var in NOSET_VISIBLE_DEVICES_ENV_VARS_LIST) 23 | 24 | 25 | def 
get_physical_gpu_id(): 26 | import torch 27 | 28 | device = torch.cuda.current_device() 29 | props = torch.cuda.get_device_properties(device) 30 | return str(props.uuid) 31 | -------------------------------------------------------------------------------- /openrlhf/trainer/ray/vllm_engine.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import ray 5 | from ray.util.placement_group import placement_group 6 | from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy 7 | from vllm import LLM 8 | 9 | from openrlhf.utils.logging_utils import init_logger 10 | 11 | logger = init_logger(__name__) 12 | 13 | 14 | @ray.remote 15 | def get_all_env_variables(): 16 | import os 17 | 18 | return os.environ 19 | 20 | 21 | @ray.remote 22 | class LLMRayActor: 23 | 24 | def __init__(self, *args, bundle_indices: list = None, **kwargs): 25 | if kwargs.get("distributed_executor_backend") == "ray": 26 | # a hack to make the script work. 27 | # stop ray from manipulating CUDA_VISIBLE_DEVICES 28 | # at the top-level when the distributed_executor_backend is ray. 29 | os.environ.pop("CUDA_VISIBLE_DEVICES", None) 30 | # every worker will use 0.2 GPU, so that we can schedule 31 | # 2 instances on the same GPUs. 32 | if bundle_indices is not None: 33 | os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.2" 34 | os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, bundle_indices)) 35 | print(f"creating LLM with bundle_indices={bundle_indices}") 36 | 37 | # Number of actors that will send prompt to this engine 38 | self.num_actors = kwargs.pop("num_actors") 39 | self.actor_counter = 0 40 | self.requests = {} 41 | self.responses = {} 42 | 43 | self.llm = LLM(*args, **kwargs) 44 | 45 | def init_process_group(self, master_address, master_port, rank_offset, world_size, group_name, backend, use_ray): 46 | return self.llm.collective_rpc( 47 | "init_process_group", 48 | args=(master_address, master_port, rank_offset, world_size, group_name, backend, use_ray), 49 | ) 50 | 51 | def update_weight(self, name, dtype, shape, empty_cache=False): 52 | return self.llm.collective_rpc("update_weight", args=(name, dtype, shape, empty_cache)) 53 | 54 | def update_weight_cuda_ipc(self, name, dtype, shape, ipc_handles, empty_cache=False): 55 | return self.llm.collective_rpc("update_weight_cuda_ipc", args=(name, dtype, shape, ipc_handles, empty_cache)) 56 | 57 | def reset_prefix_cache(self): 58 | self.llm.llm_engine.reset_prefix_cache() 59 | 60 | def sleep(self, level=1): 61 | self.llm.sleep(level=level) 62 | 63 | def wake_up(self): 64 | self.llm.wake_up() 65 | 66 | def add_requests(self, actor_rank, *, sampling_params, prompt_token_ids): 67 | """ 68 | Save the requests from actors and generate responses when all actors have sent their requests 69 | """ 70 | self.requests[actor_rank] = prompt_token_ids 71 | self.actor_counter += 1 72 | if self.actor_counter == self.num_actors: 73 | assert len(self.requests) == self.num_actors 74 | num_requests = [] 75 | requests = [] 76 | for actor_rank, request in self.requests.items(): 77 | num_requests.append((actor_rank, len(request))) 78 | requests.extend(request) 79 | 80 | if len(requests) > 0: 81 | # For now we assume that all requests have the same sampling params 82 | responses = self.llm.generate(sampling_params=sampling_params, prompt_token_ids=requests) 83 | else: 84 | responses = [] 85 | 86 | offset = 0 87 | self.responses = {} 88 | for actor_rank, num in num_requests: 89 | self.responses[actor_rank] = 
responses[offset : offset + num] 90 | offset += num 91 | 92 | self.actor_counter = 0 93 | self.requests = {} 94 | 95 | def add_requests_vlm(self, actor_rank, *, sampling_params, vllm_vision_input): 96 | """ 97 | Save the requests from actors and generate responses when all actors have sent their requests 98 | """ 99 | self.requests[actor_rank] = vllm_vision_input 100 | self.actor_counter += 1 101 | if self.actor_counter == self.num_actors: 102 | assert len(self.requests) == self.num_actors, f"{len(self.requests)} != {self.num_actors}" 103 | num_requests = [] 104 | requests = [] 105 | for actor_rank, request in self.requests.items(): 106 | num_requests.append((actor_rank, len(request))) 107 | requests.extend(request) 108 | 109 | if len(requests) > 0: 110 | # For now we assume that all requests have the same sampling params 111 | responses = self.llm.generate(requests, sampling_params=sampling_params) 112 | else: 113 | responses = [] 114 | 115 | offset = 0 116 | self.responses = {} 117 | for actor_rank, num in num_requests: 118 | self.responses[actor_rank] = responses[offset : offset + num] 119 | offset += num 120 | 121 | self.actor_counter = 0 122 | self.requests = {} 123 | 124 | def add_requests_vlm_mix(self, actor_rank, *, sampling_params, vllm_vision_input): 125 | """ 126 | Save the requests from actors and generate responses when all actors have sent their requests 127 | """ 128 | self.requests[actor_rank] = vllm_vision_input 129 | self.actor_counter += 1 130 | if self.actor_counter == self.num_actors: 131 | assert len(self.requests) == self.num_actors, f"{len(self.requests)} != {self.num_actors}" 132 | num_requests = [] 133 | requests = [] 134 | vrall, trall = [], [] 135 | vrsrc, trsrc = [], [] 136 | self.responses = {} 137 | for actor_rank, request in self.requests.items(): 138 | vreq, treq = request 139 | if vreq: 140 | vrall.extend(vreq) 141 | vrsrc.extend([actor_rank] * len(vreq)) 142 | # vresponses = self.llm.generate(vreq, sampling_params=sampling_params) 143 | # print('!!!! debug vr', type(vresponses)) 144 | # else: 145 | # vresponses = [] 146 | if treq: 147 | trall.extend(treq) 148 | trsrc.extend([actor_rank] * len(treq)) 149 | # tresponses = self.llm.generate(treq, sampling_params=sampling_params) 150 | # print('!!!! 
debug tr', type(tresponses)) 151 | 152 | vresponses = self.llm.generate(vrall, sampling_params=sampling_params) 153 | tresponses = self.llm.generate(sampling_params=sampling_params, prompt_token_ids=trall) 154 | for actor_rank, request in self.requests.items(): 155 | self.responses[actor_rank] = [] 156 | for rank, rsp in zip(vrsrc, vresponses): 157 | self.responses[rank].append(rsp) 158 | for rank, rsp in zip(trsrc, tresponses): 159 | self.responses[rank].append(rsp) 160 | print('debug inside vllm engine') 161 | 162 | self.actor_counter = 0 163 | self.requests = {} 164 | 165 | def get_responses(self, actor_rank): 166 | """ 167 | Return the responses for the actor with the given rank 168 | """ 169 | return self.responses.pop(actor_rank) 170 | 171 | 172 | def create_vllm_engines( 173 | num_engines: int, 174 | tensor_parallel_size: int, 175 | pretrain: str, 176 | seed: int, 177 | enable_prefix_caching: bool, 178 | enforce_eager: bool, 179 | max_model_len: int, 180 | num_total_actors: int, 181 | shared_pg=None, 182 | gpu_memory_utilization=None, 183 | vllm_enable_sleep=False, 184 | ): 185 | import vllm 186 | 187 | assert vllm.__version__ >= "0.7.0", "OpenRLHF only supports vllm >= 0.7.0" 188 | 189 | vllm_engines = [] 190 | num_gpus = int(tensor_parallel_size == 1) 191 | distributed_executor_backend = "uni" if tensor_parallel_size == 1 else "ray" 192 | for i in range(num_engines): 193 | bundle_indices = None 194 | scheduling_strategy = None 195 | 196 | # Hybrid engine 197 | if shared_pg is not None: 198 | assert vllm.__version__ >= "0.7.2", "Only vllm >= 0.7.2 supports hybrid engine" 199 | 200 | if tensor_parallel_size > 1: 201 | scheduling_strategy = PlacementGroupSchedulingStrategy( 202 | placement_group=shared_pg, 203 | placement_group_capture_child_tasks=True, 204 | placement_group_bundle_index=i * tensor_parallel_size 205 | ) 206 | bundle_indices = np.arange(i * tensor_parallel_size, (i + 1) * tensor_parallel_size).tolist() 207 | else: 208 | num_gpus = 0.2 209 | scheduling_strategy = PlacementGroupSchedulingStrategy( 210 | placement_group=shared_pg, placement_group_capture_child_tasks=True, placement_group_bundle_index=i 211 | ) 212 | # Distributed RLHF 213 | elif tensor_parallel_size > 1: 214 | bundles = [{"GPU": 1, "CPU": 1}] * tensor_parallel_size 215 | pg = placement_group(bundles) 216 | ray.get(pg.ready()) 217 | 218 | scheduling_strategy = PlacementGroupSchedulingStrategy( 219 | placement_group=pg, placement_group_capture_child_tasks=True, placement_group_bundle_index=0 220 | ) 221 | 222 | if num_engines >= num_total_actors: 223 | num_actors = 1 224 | else: 225 | num_actors = num_total_actors // num_engines + int(i < num_total_actors % num_engines) 226 | 227 | vllm_engines.append( 228 | LLMRayActor.options( 229 | num_cpus=0, 230 | num_gpus=num_gpus, 231 | scheduling_strategy=scheduling_strategy, 232 | ).remote( 233 | model=pretrain, 234 | enforce_eager=enforce_eager, 235 | worker_cls="openrlhf.trainer.ray.vllm_worker_wrap.WorkerWrap", 236 | tensor_parallel_size=tensor_parallel_size, 237 | seed=seed + i, 238 | distributed_executor_backend=distributed_executor_backend, 239 | max_model_len=max_model_len, 240 | enable_prefix_caching=enable_prefix_caching, 241 | dtype="bfloat16", 242 | trust_remote_code=True, 243 | num_actors=num_actors, 244 | gpu_memory_utilization=gpu_memory_utilization, 245 | bundle_indices=bundle_indices if shared_pg else None, 246 | enable_sleep_mode=vllm_enable_sleep, 247 | limit_mm_per_prompt={"image": 8} 248 | ) 249 | ) 250 | 251 | return vllm_engines 252 | 
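The add_requests / get_responses pair above acts as a barrier: each engine buffers prompts until every actor rank assigned to it has checked in, runs a single llm.generate() call over the combined batch, and then hands each rank back exactly its own slice. The sketch below shows one way a driver could exercise that flow; it is illustrative only, and the model path, the prompt, the engine/actor counts, and the simple rank % NUM_ENGINES assignment are assumptions rather than values taken from the training scripts.

# Minimal driver sketch (assumptions: a Ray cluster with at least NUM_ENGINES free GPUs,
# a local checkpoint at MODEL_PATH; all constants below are placeholders).
import ray
from transformers import AutoTokenizer
from vllm import SamplingParams

from openrlhf.trainer.ray.vllm_engine import create_vllm_engines

MODEL_PATH = "/path/to/policy"       # placeholder checkpoint
NUM_ENGINES, NUM_ACTOR_RANKS = 2, 4  # placeholder sizes

if not ray.is_initialized():
    ray.init()

engines = create_vllm_engines(
    num_engines=NUM_ENGINES,
    tensor_parallel_size=1,
    pretrain=MODEL_PATH,
    seed=42,
    enable_prefix_caching=True,
    enforce_eager=False,
    max_model_len=8192,
    num_total_actors=NUM_ACTOR_RANKS,
    gpu_memory_utilization=0.7,
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
prompt_token_ids = [tokenizer("What is 2 + 2?")["input_ids"]]
sampling_params = SamplingParams(temperature=1.0, top_p=0.95, max_tokens=256)

# Every actor rank submits its prompts first; an engine only calls llm.generate()
# once all of its assigned ranks have checked in (actor_counter == num_actors).
submit_refs = []
for rank in range(NUM_ACTOR_RANKS):
    engine = engines[rank % NUM_ENGINES]
    submit_refs.append(
        engine.add_requests.remote(rank, sampling_params=sampling_params, prompt_token_ids=prompt_token_ids)
    )
ray.get(submit_refs)

# After the barrier fires, each rank pops exactly its own slice of the outputs.
outputs = ray.get(engines[0].get_responses.remote(0))
print(f"rank 0 received {len(outputs)} vLLM outputs")
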
-------------------------------------------------------------------------------- /openrlhf/trainer/ray/vllm_worker_wrap.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from vllm.worker.worker import Worker 3 | 4 | from openrlhf.utils.distributed_util import init_process_group 5 | from openrlhf.utils.logging_utils import init_logger 6 | from .utils import get_physical_gpu_id 7 | 8 | logger = init_logger(__name__) 9 | 10 | 11 | class WorkerWrap(Worker): 12 | def init_process_group( 13 | self, master_address, master_port, rank_offset, world_size, group_name, backend="nccl", use_ray=False 14 | ): 15 | """Init torch process group for model weights update""" 16 | assert torch.distributed.is_initialized(), f"default torch process group must be initialized" 17 | assert group_name != "", f"group name must not be empty" 18 | 19 | rank = torch.distributed.get_rank() + rank_offset 20 | if use_ray: 21 | import ray.util.collective as collective 22 | 23 | collective.init_collective_group(world_size=world_size, rank=rank, backend=backend, group_name=group_name) 24 | self._model_update_group = group_name 25 | else: 26 | self._model_update_group = init_process_group( 27 | backend=backend, 28 | init_method=f"tcp://{master_address}:{master_port}", 29 | world_size=world_size, 30 | rank=rank, 31 | group_name=group_name, 32 | ) 33 | self._model_update_with_ray = use_ray 34 | print( 35 | f"init_process_group: master_address={master_address}, master_port={master_port}, ", 36 | f"rank={rank}, world_size={world_size}, group_name={group_name}", 37 | ) 38 | 39 | def update_weight(self, name, dtype, shape, empty_cache=False): 40 | """Broadcast weight to all vllm workers from source rank 0 (actor model)""" 41 | if torch.distributed.get_rank() == 0: 42 | print(f"[vllm broadcast] update weight: {name}, dtype: {dtype}, shape: {shape}") 43 | 44 | assert dtype == self.model_config.dtype, f"mismatch dtype: src {dtype}, dst {self.model_config.dtype}" 45 | weight = torch.empty(shape, dtype=dtype, device="cuda") 46 | if self._model_update_with_ray: 47 | import ray.util.collective as collective 48 | 49 | collective.broadcast(weight, 0, group_name=self._model_update_group) 50 | else: 51 | torch.distributed.broadcast(weight, 0, group=self._model_update_group) 52 | 53 | self.model_runner.model.load_weights(weights=[(name, weight)]) 54 | 55 | del weight 56 | # TODO: should we empty cache if all weights have updated? 
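        # Note: calling torch.cuda.empty_cache() after every per-tensor broadcast would
        # add a synchronization for each parameter; the `empty_cache` flag passed by the
        # caller appears intended to gate this so the cache is cleared only once, after
        # the final weight of the model has been received.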
57 | # if empty_cache: 58 | # torch.cuda.empty_cache() 59 | 60 | def update_weight_cuda_ipc(self, name, dtype, shape, ipc_handles=None, empty_cache=False): 61 | if torch.distributed.get_rank() == 0: 62 | print(f"update weight: {name}, dtype: {dtype}, shape: {shape}") 63 | 64 | assert dtype == self.model_config.dtype, f"mismatch dtype: src {dtype}, dst {self.model_config.dtype}" 65 | 66 | handle = ipc_handles[get_physical_gpu_id()] 67 | device_id = self.device.index 68 | func, args = handle 69 | list_args = list(args) 70 | # the key is to change device id to the current device id 71 | # in case two processes have different CUDA_VISIBLE_DEVICES 72 | list_args[6] = device_id 73 | weight = func(*list_args) 74 | self.model_runner.model.load_weights(weights=[(name, weight)]) 75 | torch.cuda.synchronize() 76 | -------------------------------------------------------------------------------- /openrlhf/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .processor import get_processor, reward_normalization 2 | from .utils import blending_datasets, get_strategy, get_tokenizer, get_vl_processor 3 | 4 | __all__ = [ 5 | "get_processor", 6 | "reward_normalization", 7 | "blending_datasets", 8 | "get_strategy", 9 | "get_tokenizer", 10 | "get_vl_processor", 11 | ] 12 | -------------------------------------------------------------------------------- /openrlhf/utils/deepspeed/__init__.py: -------------------------------------------------------------------------------- 1 | from .deepspeed import DeepspeedStrategy 2 | 3 | __all__ = [ 4 | "DeepspeedStrategy", 5 | ] 6 | -------------------------------------------------------------------------------- /openrlhf/utils/deepspeed/deepspeed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import shutil 4 | from abc import ABC 5 | from collections import defaultdict 6 | from datetime import timedelta 7 | from typing import List, Tuple, Union 8 | 9 | import deepspeed 10 | import numpy as np 11 | import torch 12 | import torch.nn as nn 13 | import torch.optim as optim 14 | from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam 15 | from peft import PeftModel, get_peft_model_state_dict 16 | from torch import distributed as dist 17 | from torch.optim import Optimizer 18 | from torch.utils.data import DataLoader 19 | 20 | from openrlhf.models import Actor 21 | from openrlhf.models.ring_attn_utils import get_ring_attn_group, set_ring_attn_group 22 | from openrlhf.utils.distributed_sampler import DistributedSampler 23 | 24 | from .deepspeed_utils import ( 25 | _z3_params_to_fetch, 26 | get_eval_ds_config, 27 | get_optimizer_grouped_parameters, 28 | get_train_ds_config, 29 | ) 30 | 31 | ModelOptimPair = Tuple[nn.Module, Optimizer] 32 | ModelOrModelOptimPair = Union[nn.Module, ModelOptimPair] 33 | 34 | 35 | class DeepspeedStrategy(ABC): 36 | """ 37 | The strategy for training with Accelerator. 
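    It wraps deepspeed.initialize() for both training and evaluation engines and
    centralizes optimizer/scheduler creation, dataloader setup, checkpoint save/load,
    and the all_reduce/all_gather helpers used by the RLHF trainers.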
38 | """ 39 | 40 | def __init__( 41 | self, 42 | seed: int = 42, 43 | max_norm: float = 0.0, 44 | micro_train_batch_size=1, 45 | train_batch_size=1, 46 | zero_stage=2, 47 | bf16=True, 48 | args=None, 49 | ) -> None: 50 | super().__init__() 51 | 52 | self.args = args 53 | self.stage = zero_stage 54 | self.train_batch_size = train_batch_size 55 | self.micro_train_batch_size = micro_train_batch_size 56 | self.bf16 = bf16 57 | self.seed = seed 58 | self.max_norm = max_norm 59 | self.adam_offload = getattr(args, "adam_offload", False) 60 | self.param_offload = getattr(args, "param_offload", False) 61 | self.zpg = getattr(args, "zpg", 1) 62 | self.grad_accum_dtype = getattr(args, "grad_accum_dtype", None) 63 | # overlap_comm 64 | self.overlap_comm = getattr(args, "overlap_comm", False) 65 | 66 | self.is_rlhf = False 67 | self.time_steps = defaultdict(int) 68 | 69 | def set_seed(self, seed: int) -> None: 70 | random.seed(seed) 71 | np.random.seed(seed) 72 | torch.manual_seed(seed) 73 | torch.cuda.manual_seed_all(seed) 74 | 75 | def setup_distributed(self, timeout=timedelta(minutes=60)) -> None: 76 | self.set_seed(self.seed) 77 | 78 | if self.args.local_rank == -1 and "LOCAL_RANK" in os.environ: # for slurm 79 | self.args.local_rank = int(os.environ["LOCAL_RANK"]) 80 | 81 | if self.args.local_rank != -1: 82 | torch.cuda.set_device(self.args.local_rank) 83 | print('!!!! setting up distributed') 84 | # Initializes the distributed backend which will take care of sychronizing nodes/GPUs 85 | deepspeed.init_distributed(timeout=timeout) 86 | self.setup_ring_attn() 87 | self.world_size = dist.get_world_size() 88 | self.accumulated_gradient = ( 89 | self.train_batch_size * self.ring_attn_size // self.micro_train_batch_size // self.world_size 90 | ) 91 | 92 | def setup_ring_attn(self): 93 | self.ring_attn_size = getattr(self.args, "ring_attn_size", 1) 94 | if self.ring_attn_size == 1: 95 | self.ring_attn_rank = 0 96 | return 97 | 98 | ring_head_stride = getattr(self.args, "ring_head_stride", 1) 99 | for i in range(dist.get_world_size() // self.ring_attn_size): 100 | ring_attn_ranks = list( 101 | range( 102 | i * self.ring_attn_size, 103 | (i + 1) * self.ring_attn_size, 104 | ) 105 | ) 106 | group = dist.new_group(ranks=ring_attn_ranks, backend="nccl") 107 | if dist.get_rank() in ring_attn_ranks: 108 | set_ring_attn_group(group) 109 | self.ring_attn_rank = dist.get_rank(group=group) 110 | 111 | from ring_flash_attn import substitute_hf_flash_attn 112 | 113 | substitute_hf_flash_attn(self.ring_attn_group, ring_head_stride) 114 | 115 | @property 116 | def ring_attn_group(self): 117 | return get_ring_attn_group() 118 | 119 | def create_optimizer(self, model, **kwargs) -> Optimizer: 120 | if isinstance(model, Actor): 121 | model = model.model 122 | # Optimizer 123 | AdamOptimizer = DeepSpeedCPUAdam if self.adam_offload else FusedAdam 124 | optim_params = get_optimizer_grouped_parameters(model, kwargs["weight_decay"]) 125 | optim = AdamOptimizer(optim_params, **kwargs) 126 | return optim 127 | 128 | def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None: 129 | if isinstance(model, Actor): 130 | model = model.model 131 | model.backward(loss) 132 | 133 | def optimizer_step( 134 | self, 135 | optimizer: optim.Optimizer, 136 | model: nn.Module, 137 | scheduler, 138 | name="model", 139 | **kwargs, 140 | ) -> None: 141 | if isinstance(model, Actor): 142 | model = model.model 143 | model.step() 144 | 145 | def setup_dataloader( 146 | self, 147 | replay_buffer, 148 | 
batch_size: int, 149 | pin_memory: bool = False, 150 | shuffle=True, 151 | collate_fn=None, 152 | drop_last=True, 153 | sampler=None, 154 | consumed_samples=0, 155 | ): 156 | # DDP only mode, replay buffers on each rank are different. 157 | if sampler is None: 158 | num_replicas = dist.get_world_size() // self.ring_attn_size 159 | rank = dist.get_rank() // self.ring_attn_size 160 | sampler = DistributedSampler( 161 | replay_buffer, 162 | num_replicas=num_replicas, 163 | rank=rank, 164 | shuffle=shuffle, 165 | seed=self.seed, 166 | drop_last=drop_last, 167 | consumed_samples=consumed_samples, 168 | ) 169 | 170 | return DataLoader( 171 | replay_buffer, 172 | batch_size=batch_size, 173 | sampler=sampler, 174 | drop_last=drop_last, 175 | collate_fn=collate_fn, 176 | pin_memory=pin_memory, 177 | ) 178 | 179 | def _unwrap_model(self, model) -> nn.Module: 180 | if isinstance(model, Actor): 181 | return self._unwrap_model(model.model) 182 | elif hasattr(model, "module"): 183 | return model.module 184 | else: 185 | return model 186 | 187 | def prepare( 188 | self, *models_or_model_optim_pairs: ModelOrModelOptimPair, is_rlhf=False 189 | ) -> Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]: 190 | ret = [] 191 | self.is_rlhf = is_rlhf 192 | for arg in models_or_model_optim_pairs: 193 | if isinstance(arg, tuple): 194 | assert len(arg) == 3, f'Expect (model, optimizer, scheduler) pair, got a tuple with size "{len(arg)}"' 195 | if arg[0] is not None: 196 | ret.append(self._ds_init_train_model(*arg)) 197 | else: 198 | ret.append((None, None, None)) 199 | else: 200 | ret.append(self._ds_init_eval_model(arg)) 201 | 202 | return ret[0] if len(ret) == 1 else ret 203 | 204 | def _ds_init_train_model(self, model, optim, scheduler): 205 | is_actor = isinstance(model, Actor) 206 | ds_config = self.get_ds_train_config(is_actor) 207 | 208 | engine, optim, _, scheduler = deepspeed.initialize( 209 | model=model.model if is_actor else model, 210 | optimizer=optim, 211 | lr_scheduler=scheduler, 212 | config=ds_config, 213 | args={"local_rank": self.args.local_rank}, 214 | dist_init_required=True, 215 | ) 216 | if is_actor: 217 | model.model = engine 218 | else: 219 | model = engine 220 | 221 | return model, optim, scheduler 222 | 223 | def get_ds_train_config(self, is_actor): 224 | # DS Config 225 | ds_config = get_train_ds_config( 226 | offload=self.param_offload, 227 | adam_offload=self.adam_offload, 228 | stage=self.stage, 229 | bf16=self.bf16, 230 | max_norm=self.max_norm, 231 | zpg=self.zpg, 232 | grad_accum_dtype=self.grad_accum_dtype, 233 | overlap_comm=self.overlap_comm, 234 | ) 235 | print('!!!! 
ds config', ds_config) 236 | 237 | ds_config["train_micro_batch_size_per_gpu"] = self.micro_train_batch_size 238 | train_batch_size = self.train_batch_size 239 | # corner case for ptx loss (backward twice) 240 | if self.is_rlhf and is_actor and self.args.pretrain_data is not None: 241 | train_batch_size *= 2 242 | ds_config["train_batch_size"] = train_batch_size * self.ring_attn_size 243 | 244 | return ds_config 245 | 246 | def _ds_init_eval_model(self, model): 247 | if not model: 248 | return model 249 | is_actor = isinstance(model, Actor) 250 | ds_config = self.get_ds_eval_config(offload=getattr(model, "_offload", False)) 251 | 252 | engine, *_ = deepspeed.initialize( 253 | model=model.model if is_actor else model, 254 | args={"local_rank": self.args.local_rank}, 255 | config=ds_config, 256 | dist_init_required=True, 257 | ) 258 | if is_actor: 259 | model.model = engine 260 | else: 261 | model = engine 262 | return model 263 | 264 | def get_ds_eval_config(self, offload=False): 265 | # DS Config 266 | ds_config = get_eval_ds_config(offload=offload, stage=self.stage if self.stage == 3 else 0, bf16=self.bf16) 267 | ds_config["train_micro_batch_size_per_gpu"] = self.micro_train_batch_size 268 | ds_config["train_batch_size"] = self.train_batch_size * self.ring_attn_size 269 | 270 | return ds_config 271 | 272 | def moving_average(self, model, model_ema, beta=0.992, device="cpu"): 273 | self.time_steps["ema"] += 1 274 | if self.time_steps["ema"] % self.accumulated_gradient == 0: 275 | with torch.no_grad(): 276 | for param, param_ema in zip(model.parameters(), model_ema.parameters()): 277 | if param.requires_grad: 278 | if self.stage != 3: 279 | data = param.data.to(device) 280 | param_ema.data.copy_((1 - beta) * data + beta * param_ema.data) 281 | else: 282 | # TODO: use prefiltering for efficiency 283 | params_to_fetch = _z3_params_to_fetch([param, param_ema]) 284 | with deepspeed.zero.GatheredParameters(params_to_fetch, enabled=len(params_to_fetch) > 0): 285 | data = param.data.to(device) 286 | param_ema.data.copy_((1 - beta) * data + beta * param_ema.data) 287 | 288 | def load_model( 289 | self, 290 | model: nn.Module, 291 | path: str, 292 | map_location="cpu", 293 | strict: bool = False, 294 | key_replace_fn=None, 295 | ) -> None: 296 | unwrapped_model = self._unwrap_model(model) 297 | state_dict = torch.load(path, map_location=map_location) 298 | if key_replace_fn: 299 | state_dict = key_replace_fn(state_dict) 300 | unwrapped_model.load_state_dict(state_dict, strict=strict) 301 | 302 | def save_model(self, model: nn.Module, tokenizer, output_dir, **kwargs) -> None: 303 | if self.is_rank_0(): 304 | os.makedirs(output_dir, exist_ok=True) 305 | print('!!!! [saving] model', model) 306 | torch.distributed.barrier() 307 | # save model weights for ZeRO2/3 308 | model_to_save = self._unwrap_model(model) 309 | 310 | # gather parameters 311 | output_state_dict = {} 312 | dist.barrier() 313 | for k, v in model_to_save.named_parameters(): 314 | 315 | # only gather z3 params 316 | params_to_fetch = _z3_params_to_fetch([v]) 317 | with deepspeed.zero.GatheredParameters(params_to_fetch, enabled=len(params_to_fetch) > 0): 318 | vv = v.data.cpu() 319 | if self.is_rank_0(): 320 | output_state_dict[k] = vv 321 | print(f"!!!! [saving] named_parameters after gather, {k}:{v.shape}") 322 | 323 | if self.is_rank_0(): 324 | # print('!!!! 
after named_parameters', sorted(list(output_state_dict.keys()))) 325 | state_dict = model_to_save.state_dict() 326 | 327 | # copy named_buffers with `persistent=True` 328 | for k, v in model_to_save.named_buffers(): 329 | if k not in state_dict: 330 | continue 331 | # print(f"!!!! [saving] named_buffers, {k}:{v.shape}") 332 | vv = v.data.cpu() 333 | output_state_dict[k] = vv 334 | # print('!!!! after named_buffers', sorted(list(output_state_dict.keys()))) 335 | 336 | for k in output_state_dict: 337 | v = output_state_dict[k] 338 | # print(f'!!!! [saving] {k}:{v.shape}') 339 | if v.size(0) == 0: 340 | print(f"!!!! [saving] {k} is empty") 341 | # exit(-1) 342 | 343 | state_dict_keys = set(state_dict.keys()) 344 | output_state_dict_keys = set(output_state_dict.keys()) 345 | 346 | # corner case for tie_word_embeddings, such as Qwen2-0.5B 347 | if getattr(model_to_save.config, "tie_word_embeddings", False) and "lm_head.weight" in state_dict_keys: 348 | state_dict_keys.remove("lm_head.weight") 349 | 350 | assert state_dict_keys.issubset( 351 | output_state_dict_keys 352 | ), f"mismatch keys {output_state_dict_keys.symmetric_difference(state_dict_keys)}" 353 | 354 | # only save peft weights https://github.com/microsoft/DeepSpeed/issues/4295 355 | if isinstance(model_to_save, PeftModel): 356 | model_to_save.save_pretrained(output_dir, **kwargs) 357 | if self.stage == 3: 358 | torch.save( 359 | get_peft_model_state_dict(model_to_save, output_state_dict), 360 | os.path.join(output_dir, "adapter_model.bin"), 361 | ) 362 | filename = os.path.join(output_dir, "adapter_model.safetensors") 363 | if os.path.exists(filename): 364 | os.remove(filename) 365 | else: 366 | # save model 367 | model_to_save.save_pretrained(output_dir, state_dict=output_state_dict, **kwargs) 368 | 369 | # save config 370 | output_config_file = os.path.join(output_dir, "config.json") 371 | model_to_save.config.to_json_file(output_config_file) 372 | # save tokenizer 373 | tokenizer.save_pretrained(output_dir) 374 | 375 | # for models not in AutoModel, copy python module files 376 | train_from_model_path = model_to_save.config._name_or_path 377 | if os.path.exists(train_from_model_path): 378 | for filename in os.listdir(train_from_model_path): 379 | if filename.endswith(".py"): 380 | shutil.copy(os.path.join(train_from_model_path, filename), os.path.join(output_dir, filename)) 381 | 382 | def all_reduce(self, data, op="mean"): 383 | assert op in ("mean", "max", "sum") 384 | if isinstance(data, dict): 385 | ret = {} 386 | for k, v in data.items(): 387 | ret[k] = self.all_reduce(v, op) 388 | return ret 389 | else: 390 | is_tensor = True 391 | if not isinstance(data, torch.Tensor): 392 | data = torch.Tensor([data]) 393 | is_tensor = False 394 | is_cpu_tensor = data.device.type == "cpu" 395 | 396 | if is_cpu_tensor: 397 | data = data.to(torch.cuda.current_device()) 398 | if op == "mean": 399 | data /= self.world_size 400 | dist.all_reduce(data, op=dist.ReduceOp.MAX if op == "max" else dist.ReduceOp.SUM) 401 | if is_cpu_tensor: 402 | data = data.cpu() 403 | return data.item() if not is_tensor else data 404 | 405 | def all_gather(self, data): 406 | if isinstance(data, dict): 407 | ret = {} 408 | for k, v in data.items(): 409 | ret[k] = self.all_gather(v) 410 | return ret 411 | else: 412 | if not isinstance(data, torch.Tensor): 413 | data = torch.Tensor([data]) 414 | is_cpu_tensor = data.device.type == "cpu" 415 | 416 | ret = [torch.zeros_like(data).to(torch.cuda.current_device()) for _ in range(self.world_size)] 417 | dist.all_gather(ret, 
data.to(torch.cuda.current_device())) 418 | return torch.cat(ret).cpu() if is_cpu_tensor else torch.cat(ret) 419 | 420 | def print(self, *msg): 421 | if self.is_rank_0(): 422 | print(*msg) 423 | 424 | def is_rank_0(self) -> bool: 425 | return dist.get_rank() == 0 426 | 427 | def get_rank(self) -> int: 428 | return dist.get_rank() 429 | 430 | def save_ckpt(self, model, save_dir, tag=None, max_num=3, max_mem=1000, client_state={}, save_latest=True): 431 | assert isinstance(model, deepspeed.DeepSpeedEngine) 432 | if self.is_rank_0(): 433 | os.makedirs(save_dir, exist_ok=True) 434 | MAX_SIZE = max_mem * 1024**3 # Convert GB to bytes 435 | 436 | while True: 437 | subdirs = sorted( 438 | [ 439 | (os.path.join(save_dir, d), os.path.getmtime(os.path.join(save_dir, d))) 440 | for d in os.listdir(save_dir) 441 | if os.path.isdir(os.path.join(save_dir, d)) 442 | ], 443 | key=lambda x: x[1], 444 | ) 445 | total_size = sum( 446 | os.path.getsize(os.path.join(dirpath, f)) 447 | for subdir, _ in subdirs 448 | for dirpath, _, filenames in os.walk(subdir) 449 | for f in filenames 450 | ) 451 | 452 | if len(subdirs) >= max_num or total_size > MAX_SIZE: 453 | oldest_dir = subdirs[0][0] 454 | if os.path.exists(oldest_dir): 455 | shutil.rmtree(oldest_dir) 456 | self.print(f"Deleted oldest ckpt {oldest_dir}") 457 | else: 458 | break 459 | 460 | dist.barrier() 461 | model.save_checkpoint(save_dir, tag=tag, client_state=client_state, save_latest=save_latest) 462 | 463 | def load_ckpt( 464 | self, 465 | model, 466 | load_dir, 467 | tag=None, 468 | load_module_strict=True, 469 | load_optimizer_states=True, 470 | load_lr_scheduler_states=True, 471 | load_module_only=False, 472 | ): 473 | assert isinstance(model, deepspeed.DeepSpeedEngine) 474 | load_path, states = model.load_checkpoint( 475 | load_dir, 476 | tag, 477 | load_module_strict=load_module_strict, 478 | load_optimizer_states=load_optimizer_states, 479 | load_lr_scheduler_states=load_lr_scheduler_states, 480 | load_module_only=load_module_only, 481 | ) 482 | if load_path is None: 483 | raise Exception(f"[deepspeed] failed to resume from checkpoint {load_dir}") 484 | return load_path, states 485 | -------------------------------------------------------------------------------- /openrlhf/utils/deepspeed/deepspeed_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Microsoft Corporation. 
2 | # SPDX-License-Identifier: Apache-2.0 3 | 4 | # DeepSpeed Team 5 | 6 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus 7 | 8 | 9 | def get_train_ds_config( 10 | offload, 11 | adam_offload=True, 12 | stage=2, 13 | bf16=True, 14 | max_norm=1.0, 15 | zpg=8, 16 | grad_accum_dtype=None, 17 | overlap_comm=False, 18 | ): 19 | device = "cpu" if offload else "none" 20 | zero_opt_dict = { 21 | "stage": stage, 22 | "offload_param": {"device": device}, 23 | "offload_optimizer": { 24 | "device": "cpu" if adam_offload else "none", 25 | "pin_memory": True 26 | # "pin_memory": False, 27 | # "ratio": 0.9, 28 | }, 29 | "sub_group_size": "auto", 30 | "stage3_max_live_parameters": "auto", 31 | "stage3_max_reuse_distance": "auto", 32 | "stage3_param_persistence_threshold": "auto", 33 | "stage3_prefetch_bucket_size": "auto", 34 | "reduce_bucket_size": "auto", 35 | # ZeRO++ 36 | "zero_hpz_partition_size": zpg, 37 | "zero_quantized_weights": False, 38 | "zero_quantized_gradients": False, 39 | } 40 | if overlap_comm: 41 | zero_opt_dict["overlap_comm"] = True 42 | zero_opt_dict["contiguous_gradients"] = True 43 | 44 | return { 45 | "steps_per_print": 100, 46 | "zero_optimization": zero_opt_dict, 47 | "bf16": { 48 | "enabled": bf16, 49 | }, 50 | "gradient_clipping": max_norm, 51 | "prescale_gradients": False, 52 | "wall_clock_breakdown": False, 53 | "data_types": {"grad_accum_dtype": grad_accum_dtype}, 54 | } 55 | 56 | 57 | def get_eval_ds_config( 58 | offload, 59 | stage=0, 60 | bf16=True, 61 | ): 62 | zero_opt_dict = { 63 | "stage": stage, 64 | "stage3_param_persistence_threshold": "auto", 65 | "offload_param": { 66 | "device": "cpu" if offload else "none", 67 | "pin_memory": True, 68 | }, 69 | } 70 | return { 71 | "steps_per_print": 100, 72 | "zero_optimization": zero_opt_dict, 73 | "bf16": { 74 | "enabled": bf16, 75 | }, 76 | "gradient_clipping": 1.0, 77 | "prescale_gradients": False, 78 | "wall_clock_breakdown": False, 79 | } 80 | 81 | 82 | def get_optimizer_grouped_parameters( 83 | model, 84 | weight_decay, 85 | no_decay_name_list=["bias", "layer_norm.weight", "layernorm.weight", "norm.weight", "ln_f.weight"], 86 | ): 87 | optimizer_grouped_parameters = [ 88 | { 89 | "params": [ 90 | p 91 | for n, p in model.named_parameters() 92 | if (not any(nd in n for nd in no_decay_name_list) and p.requires_grad) 93 | ], 94 | "weight_decay": weight_decay, 95 | }, 96 | { 97 | "params": [ 98 | p 99 | for n, p in model.named_parameters() 100 | if (any(nd in n for nd in no_decay_name_list) and p.requires_grad) 101 | ], 102 | "weight_decay": 0.0, 103 | }, 104 | ] 105 | return optimizer_grouped_parameters 106 | 107 | 108 | def _z3_params_to_fetch(param_list): 109 | return [p for p in param_list if hasattr(p, "ds_id") and p.ds_status == ZeroParamStatus.NOT_AVAILABLE] 110 | -------------------------------------------------------------------------------- /openrlhf/utils/distributed_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Iterator, Optional, TypeVar 3 | 4 | import torch 5 | import torch.distributed as dist 6 | from torch.utils.data.dataset import Dataset 7 | from torch.utils.data.sampler import Sampler 8 | 9 | 10 | __all__ = ["DistributedSampler"] 11 | 12 | 13 | _T_co = TypeVar("_T_co", covariant=True) 14 | 15 | 16 | # Adapted from https://github.com/pytorch/pytorch/blob/5298acb5c76855bc5a99ae10016efc86b27949bd/torch/utils/data/distributed.py 17 | class DistributedSampler(Sampler[_T_co]): 18 | r"""Sampler that 
restricts data loading to a subset of the dataset. 19 | 20 | It is especially useful in conjunction with 21 | :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each 22 | process can pass a :class:`~torch.utils.data.DistributedSampler` instance as a 23 | :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the 24 | original dataset that is exclusive to it. 25 | 26 | .. note:: 27 | Dataset is assumed to be of constant size and that any instance of it always 28 | returns the same elements in the same order. 29 | 30 | Args: 31 | dataset: Dataset used for sampling. 32 | num_replicas (int, optional): Number of processes participating in 33 | distributed training. By default, :attr:`world_size` is retrieved from the 34 | current distributed group. 35 | rank (int, optional): Rank of the current process within :attr:`num_replicas`. 36 | By default, :attr:`rank` is retrieved from the current distributed 37 | group. 38 | shuffle (bool, optional): If ``True`` (default), sampler will shuffle the 39 | indices. 40 | seed (int, optional): random seed used to shuffle the sampler if 41 | :attr:`shuffle=True`. This number should be identical across all 42 | processes in the distributed group. Default: ``0``. 43 | drop_last (bool, optional): if ``True``, then the sampler will drop the 44 | tail of the data to make it evenly divisible across the number of 45 | replicas. If ``False``, the sampler will add extra indices to make 46 | the data evenly divisible across the replicas. Default: ``False``. 47 | 48 | .. warning:: 49 | In distributed mode, calling the :meth:`set_epoch` method at 50 | the beginning of each epoch **before** creating the :class:`DataLoader` iterator 51 | is necessary to make shuffling work properly across multiple epochs. Otherwise, 52 | the same ordering will be always used. 53 | 54 | Example:: 55 | 56 | >>> # xdoctest: +SKIP 57 | >>> sampler = DistributedSampler(dataset) if is_distributed else None 58 | >>> loader = DataLoader(dataset, shuffle=(sampler is None), 59 | ... sampler=sampler) 60 | >>> for epoch in range(start_epoch, n_epochs): 61 | ... if is_distributed: 62 | ... sampler.set_epoch(epoch) 63 | ... train(loader) 64 | """ 65 | 66 | def __init__( 67 | self, 68 | dataset: Dataset, 69 | num_replicas: Optional[int] = None, 70 | rank: Optional[int] = None, 71 | shuffle: bool = True, 72 | seed: int = 0, 73 | drop_last: bool = False, 74 | consumed_samples=0, 75 | ) -> None: 76 | if num_replicas is None: 77 | if not dist.is_available(): 78 | raise RuntimeError("Requires distributed package to be available") 79 | num_replicas = dist.get_world_size() 80 | if rank is None: 81 | if not dist.is_available(): 82 | raise RuntimeError("Requires distributed package to be available") 83 | rank = dist.get_rank() 84 | if rank >= num_replicas or rank < 0: 85 | raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]") 86 | self.dataset = dataset 87 | self.num_replicas = num_replicas 88 | self.rank = rank 89 | self.epoch = 0 90 | self.drop_last = drop_last 91 | # If the dataset length is evenly divisible by # of replicas, then there 92 | # is no need to drop any data, since the dataset will be split equally. 93 | if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type] 94 | # Split to nearest available length that is evenly divisible. 95 | # This is to ensure each rank receives the same amount of data when 96 | # using this Sampler. 
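            # Worked example: len(dataset) = 10, num_replicas = 4, drop_last = True
            #   -> num_samples = ceil((10 - 4) / 4) = 2 and total_size = 8, so the
            #      last 2 samples are dropped; with drop_last = False the else-branch
            #      gives ceil(10 / 4) = 3 per rank and pads the index list up to 12.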
97 | self.num_samples = math.ceil( 98 | (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type] 99 | ) 100 | else: 101 | self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type] 102 | self.total_size = self.num_samples * self.num_replicas 103 | self.shuffle = shuffle 104 | self.seed = seed 105 | self.consumed_indicies = consumed_samples // self.num_replicas 106 | 107 | def __iter__(self) -> Iterator[_T_co]: 108 | if self.shuffle: 109 | # deterministically shuffle based on epoch and seed 110 | g = torch.Generator() 111 | g.manual_seed(self.seed + self.epoch) 112 | indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] 113 | else: 114 | indices = list(range(len(self.dataset))) # type: ignore[arg-type] 115 | 116 | if not self.drop_last: 117 | # add extra samples to make it evenly divisible 118 | padding_size = self.total_size - len(indices) 119 | if padding_size <= len(indices): 120 | indices += indices[:padding_size] 121 | else: 122 | indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] 123 | else: 124 | # remove tail of data to make it evenly divisible. 125 | indices = indices[: self.total_size] 126 | assert len(indices) == self.total_size 127 | 128 | # subsample 129 | indices = indices[self.rank : self.total_size : self.num_replicas] 130 | # skip consumed_samples 131 | indices = indices[self.consumed_indicies :] 132 | assert len(indices) == self.num_samples - self.consumed_indicies 133 | 134 | return iter(indices) 135 | 136 | def __len__(self) -> int: 137 | return self.num_samples - self.consumed_indicies 138 | 139 | def set_epoch(self, epoch: int, consumed_samples=0) -> None: 140 | r""" 141 | Set the epoch for this sampler. 142 | 143 | When :attr:`shuffle=True`, this ensures all replicas 144 | use a different random ordering for each epoch. Otherwise, the next iteration of this 145 | sampler will yield the same ordering. 146 | 147 | Args: 148 | epoch (int): Epoch number. 149 | """ 150 | self.epoch = epoch 151 | self.consumed_indicies = consumed_samples // self.num_replicas 152 | -------------------------------------------------------------------------------- /openrlhf/utils/distributed_util.py: -------------------------------------------------------------------------------- 1 | from datetime import timedelta 2 | from typing import Any, Optional, Union 3 | 4 | import torch 5 | import torch.distributed 6 | from torch.distributed.distributed_c10d import ( 7 | Backend, 8 | PrefixStore, 9 | Store, 10 | _new_process_group_helper, 11 | _world, 12 | default_pg_timeout, 13 | rendezvous, 14 | ) 15 | 16 | 17 | # Copy from pytorch to allow creating multiple main groups. 18 | # https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py 19 | def init_process_group( 20 | backend: Union[str, Backend] = None, 21 | init_method: Optional[str] = None, 22 | timeout: Optional[timedelta] = None, 23 | world_size: int = -1, 24 | rank: int = -1, 25 | store: Optional[Store] = None, 26 | group_name: str = None, 27 | pg_options: Optional[Any] = None, 28 | ): 29 | assert (store is None) or (init_method is None), "Cannot specify both init_method and store." 
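    # Unlike torch.distributed.init_process_group, this helper returns the new group
    # instead of installing it as the default, which is what lets WorkerWrap above build
    # a dedicated weight-update group spanning the trainer and the vLLM workers.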
30 | 31 | if store is not None: 32 | assert world_size > 0, "world_size must be positive if using store" 33 | assert rank >= 0, "rank must be non-negative if using store" 34 | elif init_method is None: 35 | init_method = "env://" 36 | 37 | if backend: 38 | backend = Backend(backend) 39 | else: 40 | backend = Backend("undefined") 41 | 42 | if timeout is None: 43 | timeout = default_pg_timeout 44 | 45 | # backward compatible API 46 | if store is None: 47 | rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout) 48 | store, rank, world_size = next(rendezvous_iterator) 49 | store.set_timeout(timeout) 50 | 51 | # Use a PrefixStore to avoid accidental overrides of keys used by 52 | # different systems (e.g. RPC) in case the store is multi-tenant. 53 | store = PrefixStore(group_name, store) 54 | 55 | # NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0 56 | # https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844 57 | # We need to determine the appropriate parameter name based on PyTorch version 58 | pg_options_param_name = "backend_options" if str(torch.__version__) >= "2.6" else "pg_options" 59 | pg, _ = _new_process_group_helper( 60 | world_size, 61 | rank, 62 | [], 63 | backend, 64 | store, 65 | group_name=group_name, 66 | **{pg_options_param_name: pg_options}, 67 | timeout=timeout, 68 | ) 69 | 70 | _world.pg_group_ranks[pg] = {i: i for i in range(world_size)} 71 | 72 | return pg 73 | -------------------------------------------------------------------------------- /openrlhf/utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | # Adapted from 2 | # https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py 3 | """Logging configuration for vLLM.""" 4 | import logging 5 | import sys 6 | 7 | _FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s" 8 | _DATE_FORMAT = "%m-%d %H:%M:%S" 9 | 10 | 11 | class NewLineFormatter(logging.Formatter): 12 | """Adds logging prefix to newlines to align multi-line messages.""" 13 | 14 | def __init__(self, fmt, datefmt=None): 15 | logging.Formatter.__init__(self, fmt, datefmt) 16 | 17 | def format(self, record): 18 | msg = logging.Formatter.format(self, record) 19 | if record.message != "": 20 | parts = msg.split(record.message) 21 | msg = msg.replace("\n", "\r\n" + parts[0]) 22 | return msg 23 | 24 | 25 | _root_logger = logging.getLogger("openrlhf") 26 | _default_handler = None 27 | 28 | 29 | def _setup_logger(): 30 | _root_logger.setLevel(logging.DEBUG) 31 | global _default_handler 32 | if _default_handler is None: 33 | _default_handler = logging.StreamHandler(sys.stdout) 34 | _default_handler.flush = sys.stdout.flush # type: ignore 35 | _default_handler.setLevel(logging.INFO) 36 | _root_logger.addHandler(_default_handler) 37 | fmt = NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT) 38 | _default_handler.setFormatter(fmt) 39 | # Setting this will avoid the message 40 | # being propagated to the parent logger. 41 | _root_logger.propagate = False 42 | 43 | 44 | # The logger is initialized when the module is imported. 45 | # This is thread-safe as the module is only imported once, 46 | # guaranteed by the Python GIL. 
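# Other modules are expected to obtain their logger via init_logger(__name__) (defined
# below), which attaches the shared stdout handler configured here instead of creating
# a new one per module.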
47 | _setup_logger() 48 | 49 | 50 | def init_logger(name: str): 51 | # Use the same settings as above for root logger 52 | logger = logging.getLogger(name) 53 | logger.setLevel(logging.DEBUG) 54 | logger.addHandler(_default_handler) 55 | logger.propagate = False 56 | return logger 57 | -------------------------------------------------------------------------------- /openrlhf/utils/processor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | 4 | 5 | def reward_normalization(objs): 6 | rewards = [float(obj["reward"]) for obj in objs] 7 | rewards = torch.tensor(rewards, dtype=torch.float64) 8 | rewards = (rewards - rewards.mean()) / rewards.std() 9 | for i, obj in enumerate(objs): 10 | obj["reward"] = rewards[i].item() 11 | 12 | 13 | # Conditional SFT 14 | # See https://arxiv.org/abs/2308.12050 15 | DEFAULT_REWARD_PROMPT = "{input} : {reward} " 16 | 17 | 18 | def conditional_sft_processor(args, objs): 19 | if "reward_template" not in args or args.reward_template is None: 20 | reward_template = DEFAULT_REWARD_PROMPT 21 | else: 22 | reward_template = args.reward_template 23 | assert "{input}" in reward_template 24 | assert "{reward}" in reward_template 25 | 26 | if args.normalize_reward: 27 | reward_normalization(objs) 28 | 29 | for obj in tqdm(objs, desc="Conditional SFT process..."): 30 | input = obj["input"] 31 | reward = "{:.2f}".format(float(obj["reward"])) 32 | input = reward_template.replace("{reward}", reward).replace("{input}", input) 33 | obj["input"] = input 34 | 35 | return objs 36 | 37 | 38 | # Rejection Sampling 39 | # See https://arxiv.org/abs/2307.09288 40 | def rejection_sampling_processor(args, objs): 41 | out = {} 42 | for obj in tqdm(objs, desc="Rejection Sampling process...."): 43 | input = obj["input"] 44 | output = obj["output"] 45 | reward = float(obj["reward"]) 46 | 47 | if input not in out: 48 | out[input] = {"output": output, "reward": reward} 49 | elif reward > out[input]["reward"]: 50 | out[input]["reward"] = reward 51 | out[input]["output"] = output 52 | 53 | return [{"input": k, "output": v["output"], "reward": v["reward"]} for k, v in out.items()] 54 | 55 | 56 | # Iterative DPO 57 | # See https://github.com/RLHFlow/Online-RLHF/blob/main/run_loop.sh 58 | def iterative_dpo_processor(args, objs): 59 | out = {} 60 | for obj in tqdm(objs, desc="Iterative DPO process...."): 61 | input = obj["input"] 62 | output = obj["output"] 63 | reward = float(obj["reward"]) 64 | 65 | if input not in out: 66 | out[input] = { 67 | "output": output, 68 | "chosen": output, 69 | "chosen_reward": reward, 70 | "rejected": output, 71 | "rejected_reward": reward, 72 | } 73 | elif reward > out[input]["chosen_reward"]: 74 | out[input]["chosen_reward"] = reward 75 | out[input]["chosen"] = output 76 | elif reward < out[input]["rejected_reward"]: 77 | out[input]["rejected_reward"] = reward 78 | out[input]["rejected"] = output 79 | 80 | return [ 81 | { 82 | "prompt": k, 83 | "chosen": v["chosen"], 84 | "chosen_reward": v["chosen_reward"], 85 | "rejected": v["rejected"], 86 | "rejected_reward": v["rejected_reward"], 87 | } 88 | for k, v in out.items() 89 | ] 90 | 91 | 92 | PROCESSORS = { 93 | "rs": rejection_sampling_processor, 94 | "csft": conditional_sft_processor, 95 | "iter_dpo": iterative_dpo_processor, 96 | } 97 | 98 | 99 | def get_processor(name): 100 | if name in PROCESSORS: 101 | return PROCESSORS[name] 102 | else: 103 | raise ValueError(f"Processor {name} does not exist.") 104 | 
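The three post-processors above all consume a flat list of {input, output, reward} records and reshape it for a specific training recipe: conditional SFT folds the (optionally normalized) reward into the prompt via a template, rejection sampling keeps only the highest-reward completion per prompt, and iterative DPO keeps the best and worst completions as a chosen/rejected pair. Below is a small, self-contained illustration; the records are made up, and `args` is passed as None because the two processors exercised here never read it.

# Toy illustration of the processors above (made-up records, hypothetical rewards).
from openrlhf.utils.processor import get_processor

records = [
    {"input": "What is 2 + 2?", "output": "4", "reward": 1.0},
    {"input": "What is 2 + 2?", "output": "5", "reward": -1.0},
    {"input": "Name a prime.",  "output": "7", "reward": 0.5},
]

# Rejection sampling: keep only the highest-reward completion for each prompt.
best = get_processor("rs")(None, records)
assert {r["output"] for r in best} == {"4", "7"}

# Iterative DPO: turn the same pool into chosen/rejected preference pairs.
pairs = get_processor("iter_dpo")(None, records)
assert pairs[0]["chosen"] == "4" and pairs[0]["rejected"] == "5"
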
-------------------------------------------------------------------------------- /openrlhf/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from datasets import interleave_datasets, load_dataset, load_from_disk 4 | from transformers import AutoTokenizer, AutoProcessor, AutoModel 5 | 6 | 7 | def get_vl_processor(pretrain, model, padding_side="left", strategy=None, use_fast=True): 8 | # TODO: Maybe better max_pixels set methods for other vl model 9 | # follow qwen-vl2.5 https://github.com/QwenLM/Qwen2.5-VL?tab=readme-ov-file#image-resolution-for-performance-boost 10 | min_pixels = int(os.getenv("MIN_PIXELS", 256*28*28)) 11 | max_pixels = int(os.getenv("MAX_PIXELS", 1280*28*28)) 12 | processor = AutoProcessor.from_pretrained(pretrain, trust_remote_code=True, use_fast=use_fast, min_pixels=min_pixels, max_pixels=max_pixels) 13 | tokenizer = processor.tokenizer 14 | tokenizer.padding_side = padding_side 15 | # NOTE: When enable vLLM, do not resize_token_embeddings, or the vocab size will mismatch with vLLM. 16 | # https://github.com/facebookresearch/llama-recipes/pull/196 17 | if tokenizer.pad_token is None: 18 | tokenizer.pad_token = tokenizer.eos_token 19 | tokenizer.pad_token_id = tokenizer.eos_token_id 20 | if model is not None: 21 | model.config.pad_token_id = tokenizer.pad_token_id 22 | return processor 23 | 24 | def get_tokenizer(pretrain, model, padding_side="left", strategy=None, use_fast=True): 25 | tokenizer = AutoTokenizer.from_pretrained(pretrain, trust_remote_code=True, use_fast=use_fast) 26 | tokenizer.padding_side = padding_side 27 | # NOTE: When enable vLLM, do not resize_token_embeddings, or the vocab size will mismatch with vLLM. 28 | # https://github.com/facebookresearch/llama-recipes/pull/196 29 | if tokenizer.pad_token is None: 30 | tokenizer.pad_token = tokenizer.eos_token 31 | tokenizer.pad_token_id = tokenizer.eos_token_id 32 | model.config.pad_token_id = tokenizer.pad_token_id 33 | 34 | return tokenizer 35 | 36 | 37 | def get_strategy(args): 38 | from openrlhf.utils.deepspeed import DeepspeedStrategy 39 | 40 | strategy = DeepspeedStrategy( 41 | seed=getattr(args, "seed", 42), 42 | max_norm=getattr(args, "max_norm", 1.0), 43 | micro_train_batch_size=getattr(args, "micro_train_batch_size", 1), 44 | train_batch_size=getattr(args, "train_batch_size", 128), 45 | zero_stage=args.zero_stage, 46 | bf16=getattr(args, "bf16", True), 47 | args=args, 48 | ) 49 | return strategy 50 | 51 | 52 | def blending_datasets( 53 | datasets, 54 | probabilities, 55 | strategy=None, 56 | seed=42, 57 | max_count=5000000, 58 | return_eval=True, 59 | stopping_strategy="first_exhausted", 60 | train_split="train", 61 | eval_split="test", 62 | ): 63 | datasets = datasets.split(",") 64 | probabilities = list(map(float, probabilities.split(","))) 65 | assert len(probabilities) == len(datasets) 66 | 67 | train_data_list = [] 68 | eval_data_list = [] 69 | for i, dataset in enumerate(datasets): 70 | dataset = dataset.strip() 71 | strategy.print(f"dataset: {dataset}") 72 | dp = dataset 73 | data_dir = dataset.split("@")[1].strip() if "@" in dataset else None 74 | dataset = dataset.split("@")[0].strip() 75 | dataset_basename = os.path.basename(dataset) 76 | 77 | ext = os.path.splitext(dataset)[-1] 78 | # local python script 79 | if ext == ".py" or ( 80 | os.path.isdir(dataset) and os.path.exists(os.path.join(dataset, f"{dataset_basename}.py")) 81 | ): 82 | data = load_dataset(dataset, trust_remote_code=True) 83 | strategy.print(f"loaded 
{dataset} with python script") 84 | # local text file 85 | elif ext in [".json", ".jsonl", ".csv"]: 86 | ext = ext.lower().strip(".") 87 | if ext == "jsonl": 88 | ext = "json" 89 | data = load_dataset(ext, data_files=dataset) 90 | strategy.print(f"loaded {dataset} with data_files={dataset}") 91 | elif dp.endswith('parquet'): 92 | strategy.print(f"loaded parquet: {dp} from files") 93 | data = load_dataset("parquet", data_files=dp) 94 | 95 | # local dataset saved with `datasets.Dataset.save_to_disk` 96 | # elif os.path.isdir(dataset): 97 | # data = load_from_disk(dataset) 98 | # strategy.print(f"loaded {dataset} from disk") 99 | # # remote/local folder or common file 100 | else: 101 | data = load_dataset(dataset, data_dir=data_dir) 102 | strategy.print(f"loaded {dataset} from files") 103 | print(data) 104 | if train_split and train_split in data: 105 | train_data = data[train_split].select(range(min(max_count, len(data[train_split])))) 106 | else: 107 | train_data = data.select(range(min(max_count, len(data)))) 108 | train_data_list.append(train_data) 109 | 110 | if return_eval: 111 | if eval_split and eval_split in data: 112 | eval_data = data[eval_split].select(range(min(max_count, len(data[eval_split])))) 113 | # train will contains eval? TODO 114 | else: 115 | eval_data = train_data.select(range(min(max_count, int(len(train_data) * 0.03)))) 116 | eval_data_list.append(eval_data) 117 | 118 | # merge datasets 119 | if strategy.is_rank_0(): 120 | print(train_data_list) 121 | 122 | train_dataset = interleave_datasets( 123 | train_data_list, 124 | probabilities=probabilities, 125 | seed=seed, 126 | stopping_strategy=stopping_strategy, 127 | ) 128 | if return_eval: 129 | eval_dataset = interleave_datasets( 130 | eval_data_list, 131 | probabilities=probabilities, 132 | seed=seed, 133 | stopping_strategy=stopping_strategy, 134 | ) 135 | return train_dataset, eval_dataset 136 | else: 137 | return train_dataset 138 | 139 | 140 | def convert_token_to_id(token, tokenizer): 141 | if isinstance(token, str): 142 | token = tokenizer.encode(token, add_special_tokens=False) 143 | assert len(token) == 1 144 | return token[0] 145 | else: 146 | raise ValueError("token should be int or str") 147 | 148 | def get_generation_cls(config): 149 | model_type = config.model_type 150 | model_arch = AutoModel._model_mapping[type(config)].__name__ 151 | if model_arch.endswith("ForCausalLM") or \ 152 | model_arch.endswith("ForConditionalGeneration"): 153 | return AutoModel._model_mapping[type(config)] 154 | elif model_arch.endswith("Model"): 155 | possible_arch = [model_arch.replace("Model", "ForCausalLM"), model_arch.replace("Model", "ForConditionalGeneration")] 156 | import importlib 157 | module = importlib.import_module(f".models.{model_type}.modeling_{model_type}",package="transformers") 158 | for arch in possible_arch: 159 | model_cls = getattr(module, arch, None) 160 | if model_cls is not None: 161 | return model_cls 162 | raise ValueError(f"Cannot find ForCausalLM or ForConditionalGeneration class for {model_arch}") 163 | else: 164 | raise ValueError(f"Unexpected model architecture {model_arch}") -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "packaging", 4 | "setuptools >= 49.4.0", 5 | "wheel", 6 | ] 7 | build-backend = "setuptools.build_meta" 8 | 9 | [tool.isort] 10 | profile = "black" # black-compatible 11 | line_length = 119 # should match 
black parameters 12 | ignore_whitespace = true # ignore whitespace for compatibility with the initial style 13 | py_version = 310 # python 3.10 as a target version 14 | sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"] 15 | default_section = "THIRDPARTY" 16 | extend_skip = ["setup.py", "docs/source/conf.py"] 17 | 18 | 19 | [tool.black] 20 | line_length = 119 21 | 22 | [tool.ruff] 23 | line-length = 119 24 | 25 | [tool.pytest.ini_options] 26 | # durations=0 will display all tests execution time, sorted in ascending order starting from from the slowest one. 27 | # -vv will also display tests with durration = 0.00s 28 | addopts = "--verbose --pyargs --durations=0 --strict-markers" # always add these arguments to pytest 29 | testpaths = ["./tests"] # must be an explicit path to avoid importing another "tests" module 30 | # directories to ignore when discovering tests 31 | norecursedirs = [ 32 | "external", 33 | "examples", 34 | "docs", 35 | "scripts", 36 | "tools", 37 | "tutorials", 38 | "*.egg", 39 | ".*", 40 | "_darcs", 41 | "build", 42 | "CVS", 43 | "dist", 44 | "venv", 45 | "{arch}", 46 | ] 47 | # markers to select tests, use `pytest --markers` to see all available markers, `pytest -m ""` to select tests 48 | markers = [ 49 | "unit: marks unit test, i.e. testing a single, well isolated functionality (deselect with '-m \"not unit\"')", 50 | "integration: marks test checking the elements when integrated into subsystems (deselect with '-m \"not integration\"')", 51 | "system: marks test working at the highest integration level (deselect with '-m \"not system\"')", 52 | "acceptance: marks test checking whether the developed product/model passes the user defined acceptance criteria (deselect with '-m \"not acceptance\"')", 53 | "docs: mark tests related to documentation (deselect with '-m \"not docs\"')", 54 | "skipduringci: marks tests that are skipped ci as they are addressed by Jenkins jobs but should be run to test user setups", 55 | "pleasefixme: marks tests that are broken and need fixing", 56 | ] 57 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | bitsandbytes 3 | datasets 4 | deepspeed==0.15 5 | einops 6 | flask 7 | isort 8 | jsonlines 9 | loralib 10 | math-verify 11 | levenshtein 12 | optimum 13 | packaging 14 | peft 15 | pynvml>=12.0.0 16 | qwen_vl_utils 17 | ray[default]==2.42.0 18 | tensorboard 19 | torch 20 | torchmetrics 21 | tqdm 22 | transformers @ git+https://github.com/huggingface/transformers@main 23 | transformers_stream_generator 24 | wandb 25 | wheel 26 | -------------------------------------------------------------------------------- /scripts/eval_7b.sh: -------------------------------------------------------------------------------- 1 | benchmark=m3u 2 | if [[ "$benchmark" == "m3u" ]]; then 3 | export testdata="./data/MMMUPro_full.parquet" 4 | elif [[ "$benchmark" == "m3u_val" ]]; then 5 | export testdata="./data/m3u_val.parquet" 6 | elif [[ "$benchmark" == "emma" ]]; then 7 | export factor=4 8 | export testdata="./data/emma_full.parquet" 9 | elif [[ "$benchmark" == "mathverse" ]]; then 10 | export testdata="./data/MathVerse_testmini.parquet" 11 | elif [[ "$benchmark" == "mathvista" ]]; then 12 | export testdata=./data/MathVista_testmini.parquet 13 | elif [[ "$benchmark" == "mathvision" ]]; then 14 | export testdata="./data/MathVision_test3040.parquet" 15 | else 16 | export 
testdata="./data/${benchmark}.parquet" 17 | fi 18 | 19 | export num_vllm=8 20 | export num_gpus=8 21 | export tagname=eval_debug_${benchmark} 22 | export policy=/path/to/policy 23 | export nvj_path="" 24 | export working_dir=/path/to/dir 25 | bash ./scripts/eval_vlm_new.sh 26 | 27 | -------------------------------------------------------------------------------- /scripts/eval_vlm_new.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | RAY_MASTER_NODE_ADDRESS="0.0.0.0" 4 | RAY_MASTER_NODE_PORT=$(shuf -n 1 -i 30000-65535) 5 | WORLD_SIZE=1 6 | NODE_RANK=0 7 | GPUS_PER_NODE=8 8 | 9 | MASTER_HOST="$VC_WORKER_HOSTS" 10 | MASTER_ADDR="${VC_WORKER_HOSTS%%,*}" 11 | # export NCCL_SOCKET_IFNAME=ens2f5 12 | # export GLOO_SOCKET_IFNAME=${NCCL_SOCKET_IFNAME} 13 | export NCCL_NET_PLUGIN=none 14 | export NCCL_IB_TIMEOUT=22 15 | export NCCL_IB_RETRY_CNT=15 16 | export NCCL_DEBUG=INFO 17 | export CUDA_LAUNCH_BLOCKING=1 18 | export HOST_IP=0.0.0.0 19 | export VLLM_HOST_IP=0.0.0.0 20 | export WANDB_MODE="offline" 21 | export WANDB_API_KEY="null" 22 | working_dir=${working_dir:"/path/to/VL-Rethinker"} 23 | cd $working_dir 24 | export HF_ENDPOINT=https://hf-mirror.com 25 | nnode=$WORLD_SIZE 26 | testdata=${testdata:-"none"} 27 | num_vllm=${num_vllm:-"4"} 28 | num_gpus=${num_gpus:-"4"} 29 | tp=${tp:-"1"} 30 | actor_ngpus=${actor_ngpus:-"1"} 31 | nsamples=${nsamples:-"1"} 32 | temperature=${temperature:-"0.6"} 33 | factor=${factor:-"1"} 34 | export MIN_PIXELS=$(( 256 * 28 * 28)) 35 | export MAX_PIXELS=$(( 1280 * 28 * 28)) 36 | tag=${tagname} # -n${nsamples} 37 | rule_reward=${rule:-"none"} 38 | sys=${sys:-"default"} 39 | lr=${lr:-"10"} 40 | algo=${algo:-"group"} 41 | dataver=${dataver:-"none"} 42 | util=${util:-"0.7"} 43 | 44 | numref=0 45 | 46 | maxlen=${maxlen:-"8192"} 47 | policy=${policy:-"/path/to/policy"} 48 | save_name="${tag}" # rbsize 1024->256 49 | DATASET=${testdata} 50 | MODEL_CPK_NAME=${save_name} 51 | PRETRAIN_MODEL=${policy} 52 | savefolder=${savefolder:-"eval_results"} 53 | SAVE_PATH=$working_dir/${savefolder}/$save_name 54 | mkdir -p "${SAVE_PATH}" 55 | 56 | # python=/home/ma-user/anaconda3/envs/rethinker/bin/python 57 | # source /home/ma-user/anaconda3/bin/activate 58 | # conda activate rethinker 59 | 60 | 61 | 62 | post_args="" 63 | if [ $nnode -gt 1 ]; then 64 | if [ $nnode -gt 3 ]; then 65 | post_args=(--ref_num_nodes 0 66 | --ref_num_gpus_per_node 8 67 | --actor_num_nodes 16 68 | --actor_num_gpus_per_node 1 69 | --vllm_num_engines 16 70 | --vllm_tensor_parallel_size 1 71 | --micro_train_batch_size 4 72 | --train_batch_size 256 73 | --micro_rollout_batch_size 8 74 | --rollout_batch_size 1024 75 | ) 76 | else 77 | post_args=(--ref_num_nodes 0 78 | --ref_num_gpus_per_node 8 79 | --actor_num_nodes 8 80 | --actor_num_gpus_per_node 1 81 | --vllm_num_engines 8 82 | --vllm_tensor_parallel_size 1 83 | --micro_train_batch_size 4 84 | --train_batch_size 256 85 | --micro_rollout_batch_size 8 86 | --rollout_batch_size 1024 87 | ) 88 | fi 89 | else 90 | post_args=(--ref_num_nodes 0 91 | --ref_num_gpus_per_node 8 92 | --actor_num_nodes 0 93 | --actor_num_gpus_per_node ${actor_ngpus} 94 | --vllm_num_engines ${num_vllm} 95 | --vllm_tensor_parallel_size ${tp} 96 | --adam_offload 97 | --micro_train_batch_size 4 98 | --train_batch_size 256 99 | --micro_rollout_batch_size $(( 64 * ${num_vllm} / ${nsamples} / ${factor})) 100 | --rollout_batch_size 1024 101 | ) 102 | fi 103 | 104 | LD_LIBRARY_PATH_VALUE=$nvj_path:$LD_LIBRARY_PATH 105 | 106 | 
RUNTIME_ENV_JSON="{\"env_vars\": {\"RAY_DEBUG\": \"legacy\", \"LD_LIBRARY_PATH\": \"$LD_LIBRARY_PATH_VALUE\"}}" 107 | 108 | 109 | ray_output=$(ray start --head --num-gpus ${num_gpus}) 110 | 111 | 112 | ray status 113 | ray job submit --address="http://127.0.0.1:8265" \ 114 | --runtime-env-json="$RUNTIME_ENV_JSON" \ 115 | -- python3 -m openrlhf.cli.eval_ray \ 116 | --vllm_enable_sleep \ 117 | --vllm_gpu_memory_utilization ${util} \ 118 | --vllm_sync_backend gloo \ 119 | --enable_prefix_caching \ 120 | --pretrain $PRETRAIN_MODEL \ 121 | --save_path $SAVE_PATH \ 122 | --n_samples_per_prompt ${nsamples} \ 123 | --max_epochs 1 \ 124 | --num_episodes 3 \ 125 | --prompt_max_len 2048 \ 126 | --max_samples 100000 \ 127 | --generate_max_len ${maxlen} \ 128 | --advantage_estimator ${algo} \ 129 | --zero_stage 3 \ 130 | --bf16 \ 131 | --actor_learning_rate ${lr}e-7 \ 132 | --rule_reward ${rule_reward} \ 133 | --temperature 1.0 \ 134 | --top_p 0.95 \ 135 | --init_kl_coef 0.0 \ 136 | --aux_loss_coef 0.05 \ 137 | --entropy_loss_coef 0.0 \ 138 | --prompt_data $DATASET \ 139 | --input_key question \ 140 | --apply_chat_template \ 141 | --normalize_reward \ 142 | --data_version ${dataver} \ 143 | --flash_attn \ 144 | --gradient_checkpointing \ 145 | --ckpt_path $SAVE_PATH \ 146 | --save_steps 5 \ 147 | --max_ckpt_num 5 \ 148 | --save_hf_ckpt \ 149 | --disable_ds_ckpt \ 150 | --use_wandb $WANDB_API_KEY \ 151 | --wandb_run_name $save_name \ 152 | --system_prompt ${sys} \ 153 | --use_kl_estimator_k3 \ 154 | --wandb_project vlm-rl-eval \ 155 | --buffer_norm 0 \ 156 | --train_vlm \ 157 | --training_mode eval_only \ 158 | --eval_data ${testdata} \ 159 | ${post_args[@]} -------------------------------------------------------------------------------- /scripts/train_vlm_multi.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | 4 | find_interface() { 5 | local ip_output=$(ip addr show | head -n 10) # Limit to first 10 lines 6 | local selected_interface="" 7 | 8 | # Debug output (can be removed in final version) 9 | # echo "--- First 10 lines of ip addr show output: ---" 10 | # echo "$ip_output" 11 | # echo "--- End of ip addr show output ---" 12 | 13 | while IFS= read -r line; do 14 | # Debug output (can be removed in final version) 15 | # echo "Processing line: $line" 16 | 17 | if [[ "$line" =~ ^[0-9]+:\ ([^:]+):\ \<.*UP.*\> ]]; then 18 | local interface_name="${BASH_REMATCH[1]}" 19 | # Debug output (can be removed in final version) 20 | # echo " Interface found: $interface_name" 21 | local interface_up=true 22 | local is_loopback=false 23 | 24 | if [[ "$interface_name" == "lo" ]]; then 25 | is_loopback=true 26 | # Debug output (can be removed in final version) 27 | # echo " Interface '$interface_name' is loopback. Skipping." 28 | fi 29 | 30 | if $is_loopback; then 31 | continue # Skip loopback interface 32 | fi 33 | 34 | # Look for inet lines within this interface block 35 | while IFS= read -r subnet_line; do 36 | # Debug output (can be removed in final version) 37 | # echo " Processing subnet line: $subnet_line" 38 | if [[ "$subnet_line" =~ inet\ ([0-9]+\.[0-9]+\.[0-9]+\.[0-9]+)/([0-9]+)\ .*scope\ ([^ ]+) ]]; then 39 | local ip_address="${BASH_REMATCH[1]}" 40 | local scope="${BASH_REMATCH[3]}" 41 | # Debug output (can be removed in final version) 42 | # echo " Found inet line: IP Address: $ip_address, Scope: $scope" 43 | 44 | # Exclude loopback IPs and docker0/bridge related IPs by IP range 45 | if [[ "$ip_address" =~ ^127\. 
]]; then 46 | # Debug output (can be removed in final version) 47 | # echo " IP '$ip_address' is loopback. Skipping." 48 | continue # Skip 127.0.0.0/8 loopback IPs (although 'lo' should already be skipped) 49 | elif [[ "$ip_address" =~ ^169\.254\. ]]; then 50 | # Debug output (can be removed in final version) 51 | # echo " IP '$ip_address' is link-local (169.254.x.x). Skipping." 52 | continue # Skip 169.254.0.0/16 link-local IPs (like docker0 often has) 53 | fi 54 | 55 | local is_private_ip=false 56 | if [[ "$ip_address" =~ ^10\.([0-9]{1,3}\.){2}[0-9]{1,3}$ ]] || 57 | [[ "$ip_address" =~ ^172\.(1[6-9]|2[0-9]|3[0-1])\.([0-9]{1,3}\.){1}[0-9]{1,3}$ ]] || 58 | [[ "$ip_address" =~ ^192\.168\.([0-9]{1,3}\.){1}[0-9]{1,3}$ ]]; then 59 | is_private_ip=true 60 | # Debug output (can be removed in final version) 61 | # echo " IP '$ip_address' is a private IP." 62 | # else 63 | # Debug output (can be removed in final version) 64 | # echo " IP '$ip_address' is NOT a private IP." 65 | fi 66 | 67 | if $is_private_ip || [[ "$scope" == "global" ]]; then # Consider private or global scope interfaces 68 | selected_interface="$interface_name" 69 | # Debug output (can be removed in final version) 70 | # echo " Interface '$interface_name' with IP '$ip_address' and scope '$scope' is selected." 71 | # echo "export GLOO_SOCKET_IFNAME=$selected_interface" 72 | # exit 0 # Exit immediately after finding the first suitable interface for debugging (removed for function) 73 | break 2 # Found a suitable interface! Break out of both inner and outer loops 74 | # else 75 | # Debug output (can be removed in final version) 76 | # echo " Interface '$interface_name' with IP '$ip_address' and scope '$scope' is NOT suitable (not private or global)." 77 | fi 78 | fi 79 | done < <(echo "$ip_output" | sed -n "/$interface_name: /,/^[0-9]\+:/p" | sed '$d' ) # Extract lines belonging to current interface block 80 | if [[ -n "$selected_interface" ]]; then # Check if selected_interface is not empty, if so, interface found and loops broken. 81 | # Debug output (can be removed in final version) 82 | # echo " Selected interface '$selected_interface' already found. Breaking outer loop." 83 | break # Already found and assigned an interface, break outer loop as well. 84 | fi 85 | # else 86 | # Debug output (can be removed in final version) 87 | # echo " Line does not match interface pattern." 88 | fi 89 | done < <(echo "$ip_output") 90 | 91 | if [[ -n "$selected_interface" ]]; then 92 | echo "$selected_interface" 93 | else 94 | echo "" # Return empty string if no interface is found, so export GLOO_SOCKET_IFNAME= (empty) 95 | # echo "No suitable network interface could be automatically identified for GLOO_SOCKET_IFNAME." # No longer print error message to stderr in function context 96 | # return 1 # Optionally, you could return a non-zero exit code if you need to check for failure. 
97 | fi 98 | } 99 | 100 | MULTINODE_FLAG=True 101 | if [ -v MULTINODE_FLAG ]; then 102 | # Define a string 103 | 104 | # Set the IFS (Internal Field Separator) to space 105 | IFS=',' 106 | 107 | WORLD_SIZE=${MA_NUM_HOSTS:-"1"} 108 | export RAY_MASTER_NODE_ADDRESS=${myvar[(($WORLD_SIZE-1))]} 109 | export RAY_MASTER_NODE_PORT=$(shuf -n 1 -i 30000-40000) 110 | 111 | NODE_RANK="" 112 | GPUS_PER_NODE="" 113 | 114 | else 115 | RAY_MASTER_NODE_ADDRESS="0.0.0.0" 116 | RAY_MASTER_NODE_PORT=$(shuf -n 1 -i 30000-65535) 117 | WORLD_SIZE=1 118 | NODE_RANK=0 119 | GPUS_PER_NODE=8 120 | fi 121 | MASTER_HOST="$VC_WORKER_HOSTS" 122 | MASTER_ADDR="${VC_WORKER_HOSTS%%,*}" 123 | # export NCCL_SOCKET_IFNAME=ens2f5 124 | # export GLOO_SOCKET_IFNAME=${NCCL_SOCKET_IFNAME} 125 | export NCCL_NET_PLUGIN=none 126 | export NCCL_IB_TIMEOUT=22 127 | export NCCL_IB_RETRY_CNT=15 128 | export NCCL_DEBUG=INFO 129 | export CUDA_LAUNCH_BLOCKING=1 130 | 131 | export HOST_IP=0.0.0.0 132 | export VLLM_HOST_IP=0.0.0.0 133 | 134 | working_dir=/path/to/workdir 135 | cd $working_dir 136 | export HF_ENDPOINT=https://hf-mirror.com 137 | export WANDB_API_KEY="" 138 | nnode=$WORLD_SIZE 139 | tagname=${tagname:-""} 140 | dataver=${dataver:-"none"} 141 | tag=qw-vl7b-${trainver}-${tagname} 142 | rule_reward=${rule:-"none"} 143 | sys=${sys:-"default"} 144 | lr=${lr:-"10"} 145 | algo=${algo:-"group_sft"} 146 | temperature=${temperature:-"1.0"} 147 | numref=0 148 | fmt=${fmt:-"none"} 149 | bsz=${bsz:-"512"} 150 | rbuffer=${bsz:-"1024"} 151 | nsamples=${nsamples:-"8"} 152 | mbsz=${mbsz:-"4"} 153 | maxlen=${maxlen:-"6144"} 154 | lossver=${lossver:-"none"} 155 | mode=${mode:-"none"} 156 | nactor=${nactor:-"16"} 157 | nvllm=${nvllm:-"8"} 158 | filter=${filter:-"None"} 159 | repeat=${repeat:-"0"} 160 | nepoch=${nepoch:-"3"} 161 | logp_bsz=${logp_bsz:-"8"} 162 | maxtoken=${maxtoken:-"2048"} 163 | tp=${tp:-"1"} 164 | aux=${aux:-"0.05"} 165 | evalsteps=${evalsteps:-"0"} 166 | save_name="${tag}-${bsz}-lossver${lossver}-samplever${dataver}-fmt${fmt}-${algo}-n${nsamples}-ml${maxlen}-lr${lr}-sys${sys}-${nnode}node" # rbsize 1024->256 167 | 168 | DATASET=/path/to/train.parquet 169 | MODEL_CPK_NAME=${save_name} 170 | PRETRAIN_MODEL=${policy} 171 | testdata="/path/to/test.parquet" 172 | SAVE_PATH=$working_dir/saves/$save_name 173 | mkdir -p "${SAVE_PATH}" 174 | # pip install -U deepspeed==0.15.0 # https://github.com/OpenRLHF/OpenRLHF/issues/776#issuecomment-2694472824 175 | # 176 | 177 | 178 | post_args="" 179 | if [ $nnode -gt 1 ]; then 180 | 181 | post_args=(--ref_num_nodes 0 182 | --ref_num_gpus_per_node 8 183 | --actor_num_nodes ${nactor} 184 | --actor_num_gpus_per_node 8 185 | --vllm_num_engines ${nvllm} 186 | --vllm_tensor_parallel_size ${tp} 187 | --micro_train_batch_size ${mbsz} 188 | --train_batch_size ${bsz} 189 | --micro_rollout_batch_size ${logp_bsz} 190 | --rollout_batch_size ${rbuffer} 191 | ) 192 | 193 | else 194 | post_args=(--ref_num_nodes 0 195 | --ref_num_gpus_per_node 8 196 | --actor_num_nodes 4 197 | --actor_num_gpus_per_node 1 198 | --vllm_num_engines 4 199 | --vllm_tensor_parallel_size 1 200 | --adam_offload 201 | --micro_train_batch_size 4 202 | --train_batch_size ${bsz} 203 | --micro_rollout_batch_size 4 204 | --rollout_batch_size ${rbuffer} 205 | ) 206 | fi 207 | # :/usr/local/cuda/targets/x86_64-linux/lib 208 | LD_LIBRARY_PATH_VALUE=/path/to/nvidia/nvjitlink/lib:$LD_LIBRARY_PATH 209 | export BNB_CUDA_VERSION=122 210 | RUNTIME_ENV_JSON="{\"env_vars\": {\"LD_LIBRARY_PATH\": \"$LD_LIBRARY_PATH_VALUE\"}}" 211 | 212 | 213 | if [ 
"$NODE_RANK" = "0" ]; then 214 | # Start Ray head node and capture the output 215 | ray_output=$(ray start --head --num-gpus 8) 216 | 217 | # Extract the IP address using grep and sed 218 | ip_address=$(echo "$ray_output" | grep -oP "ray start --address='\K[^']+") 219 | 220 | # Write the extracted IP address to a file named "ip.txt" 221 | mkdir -p ip_tmp 222 | echo "$ip_address" > ip_tmp/ip_${tagname}.txt 223 | cat ip_tmp/ip_${tagname}.txt 224 | 225 | 226 | 227 | if [ $nnode -gt 1 ]; then 228 | # Example usage (to set the environment variable): 229 | export GLOO_SOCKET_IFNAME=$(find_interface) 230 | echo "$GLOO_SOCKET_IFNAME" > ip_tmp/gloo_${tagname}.txt 231 | sleep 60 232 | else 233 | unset GLOO_SOCKET_IFNAME 234 | unset NCLL_SOCKET_IFNAME 235 | fi 236 | ray status 237 | ray job submit --address="http://127.0.0.1:8265" \ 238 | --runtime-env-json="$RUNTIME_ENV_JSON" \ 239 | -- python3 -m openrlhf.cli.train_ppo_ray \ 240 | --vllm_enable_sleep \ 241 | --vllm_gpu_memory_utilization 0.85 \ 242 | --vllm_sync_backend gloo \ 243 | --pretrain $PRETRAIN_MODEL \ 244 | --save_path $SAVE_PATH \ 245 | --n_samples_per_prompt ${nsamples} \ 246 | --max_epochs 1 \ 247 | --num_episodes ${nepoch} \ 248 | --filter ${filter} \ 249 | --prompt_max_len 2048 \ 250 | --max_out_tokens ${maxtoken} \ 251 | --max_samples 100000 \ 252 | --generate_max_len ${maxlen} \ 253 | --advantage_estimator ${algo} \ 254 | --zero_stage 3 \ 255 | --controlled_shuffle ${repeat} \ 256 | --bf16 \ 257 | --actor_learning_rate ${lr}e-7 \ 258 | --rule_reward ${rule_reward} \ 259 | --temperature 1.0 \ 260 | --val_temperature 0.6 \ 261 | --top_p 0.95 \ 262 | --training_mode ${mode} \ 263 | --init_kl_coef 0.0 \ 264 | --aux_loss_coef ${aux} \ 265 | --entropy_loss_coef 0.0 \ 266 | --prompt_data $DATASET \ 267 | --input_key question \ 268 | --apply_chat_template \ 269 | --normalize_reward \ 270 | --flash_attn \ 271 | --gradient_checkpointing \ 272 | --ckpt_path $SAVE_PATH \ 273 | --save_steps 3 \ 274 | --eval_steps ${evalsteps} \ 275 | --max_ckpt_num 3 \ 276 | --save_hf_ckpt \ 277 | --disable_ds_ckpt \ 278 | --disable_fast_tokenizer \ 279 | --use_wandb $WANDB_API_KEY \ 280 | --wandb_run_name $save_name \ 281 | --system_prompt ${sys} \ 282 | --use_kl_estimator_k3 \ 283 | --wandb_project vlm-rl \ 284 | --buffer_norm 0 \ 285 | --train_vlm \ 286 | --filter ${filter} \ 287 | --eval_data ${testdata} \ 288 | --data_version ${dataver} \ 289 | --loss_version ${lossver} \ 290 | --format ${fmt} \ 291 | ${post_args[@]} 292 | # --train_vlm 293 | else 294 | sleep 15 295 | # Read the IP address from the file and assign it to the variable "head_ip" 296 | head_ip=$(cat ip_tmp/ip_${tagname}.txt) 297 | gloo=$(cat ip_tmp/gloo_${tagname}.txt) 298 | export GLOO_SOCKET_IFNAME=$gloo 299 | echo "gloo: $GLOO_SOCKET_IFNAME" 300 | # Print the value of head_ip for verification 301 | echo "Head IP Address: $head_ip" 302 | 303 | ray start --address ${head_ip} 304 | # echo $HOST_IP 305 | fi -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import platform 4 | 5 | from datetime import datetime 6 | from setuptools import find_packages, setup 7 | from wheel.bdist_wheel import bdist_wheel as _bdist_wheel 8 | 9 | _build_mode = os.getenv("OPENRLHF_BUILD_MODE", "") 10 | 11 | 12 | def _is_nightly(): 13 | return _build_mode.lower() == "nightly" 14 | 15 | 16 | def _fetch_requirements(path): 17 | with open(path, "r") as fd: 18 | return 
[r.strip() for r in fd.readlines()] 19 | 20 | 21 | def _fetch_readme(): 22 | with open("README.md", encoding="utf-8") as f: 23 | return f.read() 24 | 25 | 26 | def _fetch_version(): 27 | with open("version.txt", "r") as f: 28 | version = f.read().strip() 29 | 30 | if _is_nightly(): 31 | now = datetime.now() 32 | date_str = now.strftime("%Y%m%d") 33 | version += f".dev{date_str}" 34 | 35 | return version 36 | 37 | 38 | def _fetch_package_name(): 39 | return "openrlhf-nightly" if _is_nightly() else "openrlhf" 40 | 41 | 42 | # Custom wheel class to modify the wheel name 43 | class bdist_wheel(_bdist_wheel): 44 | def finalize_options(self): 45 | _bdist_wheel.finalize_options(self) 46 | self.root_is_pure = False 47 | 48 | def get_tag(self): 49 | python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" 50 | abi_tag = f"{python_version}" 51 | 52 | if platform.system() == "Linux": 53 | platform_tag = "manylinux1_x86_64" 54 | else: 55 | platform_tag = platform.system().lower() 56 | 57 | return python_version, abi_tag, platform_tag 58 | 59 | 60 | # Setup configuration 61 | setup( 62 | author="OpenRLHF Team", 63 | name=_fetch_package_name(), 64 | version=_fetch_version(), 65 | packages=find_packages( 66 | exclude=( 67 | "data", 68 | "docs", 69 | "examples", 70 | ) 71 | ), 72 | description="A Ray-based High-performance RLHF framework.", 73 | long_description=_fetch_readme(), 74 | long_description_content_type="text/markdown", 75 | install_requires=_fetch_requirements("requirements.txt"), 76 | extras_require={ 77 | "vllm": ["vllm==0.7.2"], 78 | "vllm_latest": ["vllm>0.7.2"], 79 | }, 80 | python_requires=">=3.10", 81 | classifiers=[ 82 | "Programming Language :: Python :: 3.10", 83 | "Programming Language :: Python :: 3.11", 84 | "Environment :: GPU :: NVIDIA CUDA", 85 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 86 | "Topic :: System :: Distributed Computing", 87 | ], 88 | cmdclass={"bdist_wheel": bdist_wheel}, 89 | ) 90 | --------------------------------------------------------------------------------
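Usage sketch: scripts/eval_vlm_new.sh is parameterized through the environment variables it reads (testdata, policy, num_vllm, num_gpus, tagname, nvj_path, working_dir), so a single-node 8-GPU evaluation can be launched directly, mirroring what scripts/eval_7b.sh does; the paths below are placeholders and the benchmark parquet is assumed to already exist under ./data.
export testdata="./data/MathVista_testmini.parquet"
export policy=/path/to/policy            # HF-format checkpoint to evaluate
export num_vllm=8                        # number of vLLM engines
export num_gpus=8                        # GPUs handed to `ray start`
export tagname=eval_debug_mathvista      # results land in ${working_dir}/eval_results/${tagname}
export nvj_path=""                       # optional nvJitLink lib dir, prepended to LD_LIBRARY_PATH
export working_dir=/path/to/VL-Rethinker
bash ./scripts/eval_vlm_new.sh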