├── .gitignore
├── LICENSE
├── README.md
├── assets
│   ├── overview-2.jpg
│   └── overview-2.pdf
├── installation.md
├── openrlhf
│   ├── __init__.py
│   ├── cli
│   │   ├── __init__.py
│   │   ├── eval_ray.py
│   │   └── train_ppo_ray.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── prompts_dataset.py
│   │   └── utils.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── actor.py
│   │   ├── loss.py
│   │   ├── model.py
│   │   ├── ring_attn_utils.py
│   │   └── utils.py
│   ├── trainer
│   │   ├── __init__.py
│   │   ├── evaluator.py
│   │   ├── ppo_trainer.py
│   │   ├── ppo_utils
│   │   │   ├── __init__.py
│   │   │   ├── data_processor.py
│   │   │   ├── experience_maker.py
│   │   │   ├── kl_controller.py
│   │   │   └── replay_buffer.py
│   │   └── ray
│   │       ├── __init__.py
│   │       ├── evaluator2.py
│   │       ├── launcher.py
│   │       ├── ppo_actor.py
│   │       ├── ppo_critic.py
│   │       ├── utils.py
│   │       ├── vllm_engine.py
│   │       └── vllm_worker_wrap.py
│   └── utils
│       ├── __init__.py
│       ├── deepspeed
│       │   ├── __init__.py
│       │   ├── deepspeed.py
│       │   └── deepspeed_utils.py
│       ├── distributed_sampler.py
│       ├── distributed_util.py
│       ├── logging_utils.py
│       ├── processor.py
│       └── utils.py
├── pyproject.toml
├── requirements.txt
├── scripts
│   ├── eval_7b.sh
│   ├── eval_vlm_new.sh
│   └── train_vlm_multi.sh
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | share/python-wheels/
24 | *.egg-info/
25 | .installed.cfg
26 | *.egg
27 | MANIFEST
28 |
29 | # PyInstaller
30 | # Usually these files are written by a python script from a template
31 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
32 | *.manifest
33 | *.spec
34 |
35 | # Installer logs
36 | pip-log.txt
37 | pip-delete-this-directory.txt
38 |
39 | # Unit test / coverage reports
40 | htmlcov/
41 | .tox/
42 | .nox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | *.py,cover
50 | .hypothesis/
51 | .pytest_cache/
52 | cover/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | .pybuilder/
76 | target/
77 |
78 | # Jupyter Notebook
79 | .ipynb_checkpoints
80 |
81 | # IPython
82 | profile_default/
83 | ipython_config.py
84 |
85 | # pyenv
86 | # For a library or package, you might want to ignore these files since the code is
87 | # intended to run in multiple environments; otherwise, check them in:
88 | # .python-version
89 |
90 | # pipenv
91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
94 | # install all needed dependencies.
95 | #Pipfile.lock
96 |
97 | # UV
98 | # Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
99 | # This is especially recommended for binary packages to ensure reproducibility, and is more
100 | # commonly ignored for libraries.
101 | #uv.lock
102 |
103 | # poetry
104 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
105 | # This is especially recommended for binary packages to ensure reproducibility, and is more
106 | # commonly ignored for libraries.
107 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
108 | #poetry.lock
109 |
110 | # pdm
111 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
112 | #pdm.lock
113 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
114 | # in version control.
115 | # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
116 | .pdm.toml
117 | .pdm-python
118 | .pdm-build/
119 |
120 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
121 | __pypackages__/
122 |
123 | # Celery stuff
124 | celerybeat-schedule
125 | celerybeat.pid
126 |
127 | # SageMath parsed files
128 | *.sage.py
129 |
130 | # Environments
131 | .env
132 | .venv
133 | env/
134 | venv/
135 | ENV/
136 | env.bak/
137 | venv.bak/
138 |
139 | # Spyder project settings
140 | .spyderproject
141 | .spyproject
142 |
143 | # Rope project settings
144 | .ropeproject
145 |
146 | # mkdocs documentation
147 | /site
148 |
149 | # mypy
150 | .mypy_cache/
151 | .dmypy.json
152 | dmypy.json
153 |
154 | # Pyre type checker
155 | .pyre/
156 |
157 | # pytype static type analyzer
158 | .pytype/
159 |
160 | # Cython debug symbols
161 | cython_debug/
162 |
163 | # PyCharm
164 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
165 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
166 | # and can be added to the global gitignore or merged into this file. For a more nuclear
167 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
168 | #.idea/
169 |
170 | # Ruff stuff:
171 | .ruff_cache/
172 |
173 | # PyPI configuration file
174 | .pypirc
175 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # VL-Rethinker: Incentivizing Self-Reflection of Vision-Language Models with Reinforcement Learning
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 | Authors:
20 | Haozhe Wang,
21 | Chao Qu,
22 | Zuming Huang,
23 | Wei Chu,
24 | Fangzhen Lin,
25 | Wenhu Chen
26 |
27 | ## 🔥News
28 |
29 | - [2025/4/22] We release the dataset [🤗 ViRL39K](https://huggingface.co/datasets/TIGER-Lab/ViRL39K). It covers a **comprehensive collection** of 39K queries spanning **eight categories**, and provides fine-grained **model-capability annotations** for data selection.
30 |
31 |
32 | ## Overview
33 | ![overview](./assets/overview-2.jpg)
34 |
35 | **Abstract**
36 | Recently, slow-thinking systems like GPT-o1 and DeepSeek-R1 have demonstrated great potential in solving challenging problems through explicit reflection. They significantly outperform the best fast-thinking models, such as GPT-4o, on various math and science benchmarks. However, their multimodal reasoning capabilities remain on par with fast-thinking models. For instance, GPT-o1's performance on benchmarks like MathVista, MathVerse, and MathVision is similar to fast-thinking models. In this paper, we aim to enhance the slow-thinking capabilities of vision-language models using reinforcement learning (without relying on distillation) to advance the state of the art. First, we adapt the GRPO algorithm with a novel technique called Selective Sample Replay (SSR) to address the vanishing advantages problem. While this approach yields strong performance, the resulting RL-trained models exhibit limited self-reflection or self-verification.
37 | To further encourage slow-thinking, we introduce Forced Rethinking, which appends a textual rethinking trigger to the end of initial rollouts in RL training, explicitly enforcing a self-reflection reasoning step. By combining these two techniques, our model, VL-Rethinker, significantly advances the state-of-the-art scores on MathVista, MathVerse, and MathVision to 80.3%, 61.8%, and 43.9%, respectively. VL-Rethinker also achieves open-source SoTA on multi-disciplinary benchmarks such as MMMU-Pro, EMMA, and MEGA-Bench, narrowing the gap with GPT-o1. Our empirical results show the effectiveness of our approaches.
38 |
39 |
40 |
41 | ## Release Progress
42 | - [x] models.
43 | - [x] data.
44 | - [ ] inference and evaluation code.
45 | - [x] training code.
46 |
47 | ### Dataset
48 | **[ViRL39K](https://huggingface.co/datasets/TIGER-Lab/ViRL39K)** lays the foundation for our RL training. It has the following merits (see the loading sketch after this list):
49 | - **high-quality** and **verifiable**: the QAs undergo rigorous filtering and quality control, removing problematic queries or ones that cannot be verified by rules.
50 | - covering **comprehensive** topics and categories: from grade school problems to broader STEM and Social topics; reasoning with charts, diagrams, tables, documents, spatial relationships, etc.
51 | - with fine-grained **model-capability annotations**: it tells you what queries to use when training models at different scales.
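
For convenience, here is a minimal sketch for browsing the dataset with the Hugging Face `datasets` library; the split name and the printed fields are assumptions, not guaranteed by the dataset card.

```python
from datasets import load_dataset

# Load ViRL39K from the Hugging Face Hub (the split name "train" is an assumption).
ds = load_dataset("TIGER-Lab/ViRL39K", split="train")

print(ds)     # inspect the available columns (questions, answers, category/capability annotations, ...)
print(ds[0])  # peek at a single record
```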
52 |
53 |
54 | ### RL-ed Models
55 | - [VL-Rethinker-7B](https://huggingface.co/TIGER-Lab/VL-Rethinker-7B): trained from Qwen2.5-VL-7B-Instruct with the proposed SSR and Forced Rethinking.
56 | - [VL-Rethinker-72B](https://huggingface.co/TIGER-Lab/VL-Rethinker-72B): trained from Qwen2.5-VL-72B-Instruct with the proposed SSR and Forced Rethinking.
57 |
58 | We are training a 32B model and further enhancing these models. Stay tuned!
59 |
60 |
61 | ## Performance
62 | See our [website](https://tiger-ai-lab.github.io/VL-Rethinker/) or [paper](https://arxiv.org/abs/2504.08837) for detailed performance report.
63 |
64 |
65 | ## Selective Sample Replay (SSR)
66 |
67 | Training 72B models on publicly collected queries reveals "vanishing advantages," a phenomenon where rapid saturation in large models drastically reduces the number of effective training samples. The concurrent work [DAPO](https://arxiv.org/abs/2503.14476) on LLMs made a similar observation.
68 |
69 | DAPO combats this by filtering out ineffective queries for gradient stability. Different from this gradient perspective, our method, Selective Sample Replay (SSR), takes an active-learning perspective. Drawing on Prioritized Experience Replay, SSR reweights the selection of training samples based on their informativeness: examples with high advantages, which lie near the model's capability limits (i.e., correct responses to queries the model is likely to fail), are particularly informative. This active selection focuses training on the samples most likely to contribute to model improvement, thereby improving training efficiency.
70 |
71 | The implementation of SSR is simple. In addition to the code in `active_sampling()` in `openrlhf/trainer/ppo_utils/replay_buffer.py`, here is pseudocode for the key idea of SSR:
72 | ```python
73 | effective_qas = rule_out_zero(candidates)   # drop candidates whose advantage is zero
74 | p = normalize_adv(effective_qas, alpha=1)   # normalize |advantage| into a sampling distribution
75 | selection = np.random.choice(np.arange(len(effective_qas)), size=size, p=p)
76 | ```
77 |
78 | Note: for different scenarios, e.g., on-policy or off-policy, the choices of `candidates` and `size` can differ. A fuller, self-contained sketch of the selection step follows.
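
Below is a minimal, runnable sketch of the SSR selection step. It is not the repository's exact implementation: the helper name `ssr_select` and the inputs `advantages` and `num_to_select` are illustrative.

```python
import numpy as np


def ssr_select(advantages, num_to_select, alpha=1.0, rng=None):
    """Pick replay indices with probability proportional to |advantage|**alpha."""
    if rng is None:
        rng = np.random.default_rng()
    advantages = np.asarray(advantages, dtype=np.float64)

    # Rule out zero-advantage samples: they carry no policy-gradient signal.
    effective = np.flatnonzero(np.abs(advantages) > 0)
    if len(effective) == 0:
        return effective  # nothing informative to replay

    # Normalize |advantage| into a sampling distribution.
    weights = np.abs(advantages[effective]) ** alpha
    p = weights / weights.sum()

    # Sample (with replacement) in proportion to informativeness.
    return rng.choice(effective, size=num_to_select, p=p)


# Example: advantages of 6 rollouts; select 4 samples for the replay buffer.
print(ssr_select([0.0, 0.8, -0.5, 0.0, 1.2, 0.1], num_to_select=4))
```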
79 |
80 | ## Inference
81 | Our models are built on top of the Qwen2.5-VL family, so we include a simple use case here and refer readers to [the standard inference procedure of Qwen2.5-VL](https://github.com/QwenLM/Qwen2.5-VL).
82 |
83 |
84 | ```python
85 | from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
86 | from qwen_vl_utils import process_vision_info
87 |
88 | # default: Load the model on the available device(s)
89 | model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
90 | "TIGER-Lab/VL-Rethinker-7B", torch_dtype="auto", device_map="auto"
91 | )
92 |
93 | # We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
94 | # model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
95 | # "Qwen/Qwen2.5-VL-7B-Instruct",
96 | # torch_dtype=torch.bfloat16,
97 | # attn_implementation="flash_attention_2",
98 | # device_map="auto",
99 | # )
100 |
101 | # default processor
102 | # processor = AutoProcessor.from_pretrained("TIGER-Lab/VL-Rethinker-7B")
103 |
104 |
105 | min_pixels = 256*28*28
106 | max_pixels = 1280*28*28
107 | processor = AutoProcessor.from_pretrained("TIGER-Lab/VL-Rethinker-7B", min_pixels=min_pixels, max_pixels=max_pixels)
108 |
109 | messages = [
110 | {
111 | "role": "user",
112 | "content": [
113 | {
114 | "type": "image",
115 | "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
116 | },
117 | {"type": "text", "text": "Describe this image."},
118 | ],
119 | }
120 | ]
121 |
122 | # Preparation for inference
123 | text = processor.apply_chat_template(
124 | messages, tokenize=False, add_generation_prompt=True
125 | )
126 | image_inputs, video_inputs = process_vision_info(messages)
127 | inputs = processor(
128 | text=[text],
129 | images=image_inputs,
130 | videos=video_inputs,
131 | padding=True,
132 | return_tensors="pt",
133 | )
134 | inputs = inputs.to(model.device)
135 |
136 | # Inference: Generation of the output
137 | generated_ids = model.generate(**inputs, max_new_tokens=128)
138 | generated_ids_trimmed = [
139 | out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
140 | ]
141 | output_text = processor.batch_decode(
142 | generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
143 | )
144 | print(output_text)
145 |
146 | ```
147 |
148 | **Important Notes**:
149 |
150 | Based on the training configurations of the VL-Rethinker family, it is recommended to (a combined sketch follows this list):
151 | - *Prompt*:
152 |
153 | append `\n\nPlease reason step by step, and put your final answer within \\boxed{}` after the user queries.
154 | - *Resolutions*:
155 | ```
156 | min_pixels = 256*28*28
157 | max_pixels = 1280*28*28
158 | ```
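
The following is a minimal sketch (not taken from the repository) that combines both recommendations with the inference example above; the image path and the query text are placeholders.

```python
from transformers import AutoProcessor

# Recommended rethinking trigger, appended to every user query.
TRIGGER = "\n\nPlease reason step by step, and put your final answer within \\boxed{}."

# Recommended resolution bounds for the image processor.
processor = AutoProcessor.from_pretrained(
    "TIGER-Lab/VL-Rethinker-7B",
    min_pixels=256 * 28 * 28,
    max_pixels=1280 * 28 * 28,
)

user_query = "What is the total cost shown in the table?"  # placeholder query
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "image": "path/to/your_image.jpg"},  # placeholder image path
            {"type": "text", "text": user_query + TRIGGER},
        ],
    }
]
# `messages` can then be passed through the chat template and processor exactly as in the inference example.
```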
159 |
160 |
161 | ## 🚀Quick Start
162 | The proposed algorithm is implemented with the [OpenRLHF](https://github.com/OpenRLHF/OpenRLHF) framework.
163 |
164 | ### Installation
165 | Please see [the installation instructions](installation.md).
166 |
167 | ### Evaluation
168 | Our models can be evaluated like Qwen2.5-VL using [lmms_eval](https://github.com/EvolvingLMMs-Lab/lmms-eval).
169 |
170 | Here we provide an alternative evaluation approach. It offers the following benefits:
171 | - Fast: batch inference with vLLM processes 1K queries on 8 A800 GPUs within 30 minutes.
172 | - Convenient: evaluation without time-consuming API calls; judgements made by our rule-based functions align with LLM judges.
173 | - Train-Test Aligned: the evaluation reuses the correctness judgement from training to minimize the gap between training and test-time evaluation.
174 |
175 | The evaluation is integrated with the OpenRLHF framework.
176 | ```bash
177 | bash ./scripts/eval_7b.sh [benchmark] [modelname] [modelpath]
178 | ```
179 | **Note: for MMMU-Val, we cannot reproduce the reported Qwen2.5-VL results with lmms_eval, VLMEvalKit, or our native evaluation. We would greatly appreciate any insights into the correct means of reproducing them.**
180 |
181 |
182 | ### Training
183 | Run the following:
184 | ```bash
185 | bash ./scripts/train_vlm_multi.sh
186 | ```
187 |
188 |
189 | ## Acknowledgement
190 | This project is adapted from [OpenRLHF](https://github.com/OpenRLHF/OpenRLHF) and [LMM-R1](https://github.com/TideDra/lmm-r1), both released under the Apache License 2.0. Thanks for their open-source contributions!
191 |
192 | ## Citation
193 | If you find this work useful, please consider citing our work:
194 | ```bibtex
195 | @article{vl-rethinker,
196 | title={VL-Rethinker: Incentivizing Self-Reflection of Vision-Language Models with Reinforcement Learning},
197 | author = {Wang, Haozhe and Qu, Chao and Huang, Zuming and Chu, Wei and Lin, Fangzhen and Chen, Wenhu},
198 | journal={arXiv preprint arXiv:2504.08837},
199 | year={2025}
200 | }
201 | ```
202 |
--------------------------------------------------------------------------------
/assets/overview-2.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TIGER-AI-Lab/VL-Rethinker/dd2c17d149a5939314690c59a804b817b7d422df/assets/overview-2.jpg
--------------------------------------------------------------------------------
/assets/overview-2.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TIGER-AI-Lab/VL-Rethinker/dd2c17d149a5939314690c59a804b817b7d422df/assets/overview-2.pdf
--------------------------------------------------------------------------------
/installation.md:
--------------------------------------------------------------------------------
1 | ### Installation
2 |
3 | ```bash
4 | cd VL-Rethinker
5 | conda create -n rethinker python=3.10
conda activate rethinker
6 | pip install -e .[vllm]
7 | pip install flash_attn --no-build-isolation
8 | ```
9 |
10 | Note: vLLM >=0.7.2 is recommended.
11 |
12 | Note: if you use multi-node training, downgrade DeepSpeed to 0.15.0.
13 | reference: https://github.com/OpenRLHF/OpenRLHF/issues/776#issuecomment-2694472824
14 |
15 | ### Workarounds
16 | At the time of this project, some bugs still lingered around using flash-attn and vLLM with Qwen2.5-VL. The following are workarounds from the community:
17 | 1. To fix flash-attn issues:
18 | ```
19 | export LD_LIBRARY_PATH=/path/to/nvidia/nvjitlink/lib:$LD_LIBRARY_PATH
20 | ```
21 | reference: https://github.com/pytorch/pytorch/issues/111469#issuecomment-1869208750
22 |
23 |
24 | 2. To fix Qwen-VL preprocessor issues, modify `preprocessor_config.json`.
25 |
26 | reference:
27 | - https://github.com/huggingface/transformers/issues/36193#issuecomment-2661278628
28 | - https://github.com/huggingface/transformers/issues/36246
29 |
--------------------------------------------------------------------------------
/openrlhf/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TIGER-AI-Lab/VL-Rethinker/dd2c17d149a5939314690c59a804b817b7d422df/openrlhf/__init__.py
--------------------------------------------------------------------------------
/openrlhf/cli/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/TIGER-AI-Lab/VL-Rethinker/dd2c17d149a5939314690c59a804b817b7d422df/openrlhf/cli/__init__.py
--------------------------------------------------------------------------------
/openrlhf/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # from .process_reward_dataset import ProcessRewardDataset
2 | from .prompts_dataset import PromptDataset
3 | # from .reward_dataset import RewardDataset
4 | # from .sft_dataset import SFTDataset
5 | # from .unpaired_preference_dataset import UnpairedPreferenceDataset
6 |
7 | __all__ = [
8 | # "ProcessRewardDataset",
9 | "PromptDataset",
10 | # "RewardDataset",
11 | # "SFTDataset",
12 | # "UnpairedPreferenceDataset"
13 | ]
14 |
--------------------------------------------------------------------------------
/openrlhf/datasets/prompts_dataset.py:
--------------------------------------------------------------------------------
1 | # /*
2 | # * Modified by Haozhe Wang in 2025
3 | # *
4 | # * Licensed under the Apache License, Version 2.0 (the "License");
5 | # */
6 | from torch.utils.data import Dataset
7 | from tqdm import tqdm
8 | import json
9 |
10 | def preprocess_data(data, input_template=None, input_key="input", apply_chat_template=None) -> str:
11 | if apply_chat_template:
12 | chat = data[input_key]
13 | if isinstance(chat, str):
14 | chat = [{"role": "user", "content": chat}]
15 | prompt = apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
16 | else:
17 | prompt = data[input_key]
18 | if input_template:
19 | prompt = input_template.format(prompt)
20 |
21 | return prompt
22 |
23 |
24 | templates = dict(longcot="""
25 | You are a thoughtful and diligent student tasked with solving a problem. As you work through the problem, document your thought process in a reflective, first-person narrative. Think of yourself as talking to yourself through each step. Consider each step carefully, question your reasoning, and adjust as needed to arrive at a sound solution. Here's how you should proceed:
26 |
27 | 1. **Step-by-Step Analysis**: Start by thoroughly understanding the problem. Identify what is provided and what is being asked. Consider high-level strategies or approaches first, and then break them down into smaller, manageable steps. Ensure you address each component one at a time and do not skip over any details.
28 |
29 | 2. **Self-Questioning**: As you work through each step, ask yourself reflective questions like, "Is this correct?", "Does it make sense?", or "What might I be overlooking?" Be critical of your own reasoning, and adjust your approach as needed. Use notation to express your confidence and evaluate the progress about solving the problem.
30 |
31 | 3. **Reassessment**: If you notice a mistake or feel uncertain about your approach, reassess your work. Go back and revise your assumptions, logic, or calculations to correct any missteps, ensuring you're on the right track.
32 |
33 | 4. **Alternative Approaches**: If you find yourself stuck or unsure about the current method, consider alternative approaches. Look at the problem from different angles, and if one method feels insufficient, explore others.
34 |
35 | 5. **Clear Detailing**: For each step, explain your reasoning clearly and in simple language. Make sure anyone who follows your work can easily understand the logic behind your decisions and the steps you've taken.
36 |
37 | 6. **Final Solution**: Once you're confident in your solution, enclose it in \\boxed{} to highlight your final answer.
38 |
39 | **Your goal is to approach the problem in a reflective, iterative manner, ensuring that no steps are skipped and no assumptions go unchecked.**
40 | """,
41 | default="Please reason step by step, and put your final answer within \\boxed{}.",
42 | elaborate="First understand the problem: understand what information is given in the text and understand what the images describes. Then think about what the problem is asking for and what knowledge the problem aims to examine. Finally, think about how to solve the problem step by step. Explain your solution in simple words that are easy to follow, assuming the readers are junior students who DOT NOT master well the relevant knowledge. Remember to put your final answer within \\boxed{}.",
43 | elaborate_rethink="""Guidelines:
44 | - First understand the problem: understand what information is given in the text and understand what the images describes. Then think about what the problem is asking for and what knowledge the problem aims to examine. Finally, think about how to solve the problem step by step. Explain your solution in simple words that are easy to follow, assuming the readers are junior students who DOT NOT master well the relevant knowledge.
45 | - **Regularly perform self-questioning, self-verification, self-correction to check your ongoing reasoning**, using connectives such as "Wait a moment", "Wait, does it seem right?", etc.
46 | - Remember to put your final answer within \\boxed{}.""",
47 | explain="""Guidelines:
48 | Understand what the problem is asking for, and what knowledge the problem aims to examine.
49 | Explain the problem and your solution in simple words to a reader, assuming he has rare knowledge and poor mastery about the related concepts.
50 | """,
51 | rethink="""Guidelines:
52 | Please think step by step, and **regularly perform self-questioning, self-verification, self-correction to check your ongoing reasoning**, using connectives such as "Wait a moment", "Wait, does it seem right?", etc. Remember to put your final answer within \\boxed{}.""",
53 | )
54 | templates['none'] = ""
55 | templates['autocode'] = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{}\n\n### Response:"
56 |
57 |
58 |
59 | class PromptDataset(Dataset):
60 | """
61 | Dataset for PPO model
62 |
63 | Args:
64 | dataset: dataset for PPO model
65 | tokenizer: tokenizer for PPO model
66 | max_length: max length of input
67 | """
68 |
69 | def preprocess_data(self, data, input_template=None, input_key="input", apply_chat_template=None, system_prompt="longcot") -> str:
70 | has_vlm_processor = self.processor is not None
71 | # print('!!!! apply chat', apply_chat_template)
72 | # print('!!!! sys', system_prompt, input_key)
73 | # import pdb; pdb.set_trace()
74 | # if system_prompt=='dpsk':
75 | # # import json
76 | # if input_key=='response' and not self.is_eval:
77 | # chat = [{"role": "user", "content": data['question']},
78 | # # {"role": "assistant", "content": data['response']}
79 | # ]
80 |
81 | # prompt = data['question'] # self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
82 | # else:
83 | # input_key = 'messages'
84 | # chat = data[input_key]
85 | # prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
86 |
87 | # elif system_prompt=='dsmath':
88 | # chat = data['messages']
89 | # for entry in chat:
90 | # if entry['role']=='user': break
91 | # template = "User:{instruction}\n\nAssistant:"
92 | # # entry['content'] += f'\n{templates["default"]}'
93 |
94 | # prompt = template.format(instruction=entry['content'])
95 | # elif system_prompt=='autocode':
96 | # chat = data['messages']
97 | # for entry in chat:
98 | # if entry['role']=='user': break
99 | # template = templates[system_prompt]
100 | # # template = "User:{instruction}\n\nAssistant:"
101 | # # entry['content'] += f'\n{templates["default"]}'
102 |
103 | # prompt = template.format(entry['content'])
104 | # elif input_key=='question':
105 | # prompt = data[input_key]
106 | # if system_prompt=='default':
107 | # trigger = templates[system_prompt]
108 | # chat = [{"role": "system", "content": trigger},
109 | # {"role": "user", "content": prompt}]
110 | # prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
111 | # else:
112 | # input_template = templates[system_prompt]
113 | # prompt = input_template.format(prompt)
114 | if has_vlm_processor:
115 | if False:
116 | chat = data[input_key]
117 | if system_prompt in templates:
118 | chat.insert(0, dict(role='system', content=templates[system_prompt]))
119 | else: print(f'!!!! warning: {system_prompt} not in templates')
120 | if isinstance(chat[-1]['content'], str):
121 | text = chat[-1]['content']
122 | content = [
123 | # dict(type='image', image=None),
124 | dict(type='text', text=text)
125 | ]
126 | chat[-1]['content'] = content
127 |
128 | else:
129 | # sysp = None
130 | # if system_prompt in templates:
131 | # sysp = templates[system_prompt]
132 | # else: print(f'!!!! warning: {system_prompt} not in templates')
133 | # now we don't use system prompt
134 | if system_prompt == 'notrigger':
135 | trigger = ""
136 | elif system_prompt == 'elaborate':
137 | trigger = f"\n\n{templates['elaborate']}"
138 | elif system_prompt == 'elaborate_rethink':
139 | trigger = f"\n\n{templates['elaborate_rethink']}"
140 | elif system_prompt == 'rethink':
141 | trigger = f"\n\n{templates['rethink']}"
142 | else:
143 | trigger = f"\n\n{templates[system_prompt]}"
144 | q = data['question']
145 | img = data.get('image', None)
146 | imglist = []
147 | if img is None or img=="" :
148 | pass # keep it empty
149 | elif isinstance(img, list):
150 | imglist = [dict(type='image', image=imm) for imm in img if imm]
151 | else: imglist = [dict(type='image', image=img)]
152 | if len(imglist)>10:
153 | print('!!! [debug]', img)
154 | chat = [dict(role='user',
155 | content=imglist+[dict(type='text', text=q+trigger)] # if img else q
156 | )]
157 |
158 | if 'qid' in data:
159 | chat.append(dict(qid=data['qid']))
160 | prompt = json.dumps(chat)
161 | elif input_key=='question':
162 | chat = [{"role": "system", "content": templates["default"]},
163 | {"role": "user", "content": data['question']}]
164 | prompt = apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
165 | elif input_key=='messages':
166 | chat = data[input_key]
167 | if len(chat)>1:
168 | chat[0] = dict(role='system', content=templates[system_prompt]) # replace
169 | else:
170 | if system_prompt in templates:
171 | chat.insert(0, dict(role='system', content=templates[system_prompt]))
172 | else: print(f'!!!! warning: {system_prompt} not in templates')
173 | prompt = apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
174 |
175 | elif apply_chat_template:
176 | chat = data[input_key]
177 | if isinstance(chat, str):
178 | chat = [{"role": "user", "content": chat}]
179 | else: # messages
180 | # if system_prompt!="none":
181 | if len(chat)>1:
182 | chat[0] = dict(role='system', content=templates[system_prompt]) # replace
183 | else: chat.insert(0, dict(role='system', content=templates[system_prompt]))
184 | prompt = apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
185 | else:
186 | prompt = data[input_key]
187 | input_template = templates[system_prompt]
188 | if system_prompt in ['none']:
189 | print(f"template cannot be {system_prompt} when not using chat template")
190 | chat = [dict(role='system', content=templates[system_prompt]),
191 | dict(role='user', content=prompt)
192 | ]
193 | prompt = self.tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
194 | else:
195 | prompt = input_template.format(prompt)
196 | if prompt=="": print('!!!! warning, prompts incorrect')
197 | return prompt
198 |
199 | def __init__(
200 | self,
201 | dataset,
202 | tokenizer,
203 | strategy,
204 | input_template=None,
205 | is_eval=False,
206 | processor=None,
207 | ) -> None:
208 | super().__init__()
209 | self.strategy = strategy
210 | self.tokenizer = tokenizer
211 | self.processor = processor
212 | self.is_eval = is_eval
213 |
214 | # chat_template
215 | self.input_template = input_template
216 | input_key = getattr(self.strategy.args, "input_key", None)
217 | controlled_shuffle = getattr(self.strategy.args, "controlled_shuffle", 0)
218 | apply_chat_template = getattr(self.strategy.args, "apply_chat_template", False)
219 |
220 | system_prompt = getattr(self.strategy.args, "system_prompt", "none")
221 | # print("sysprompt", system_prompt)
222 | do_vlm = getattr(self.strategy.args, "train_vlm", False)
223 | # import pdb; pdb.set_trace()
224 | if apply_chat_template:
225 | apply_chat_template = self.processor.apply_chat_template if do_vlm else self.tokenizer.apply_chat_template
226 |
227 |
228 | self.prompts = []
229 | repeat = 1 if controlled_shuffle==0 else controlled_shuffle
230 | for _ in range(repeat):
231 | for data in tqdm(dataset, desc="Preprocessing data", disable=not self.strategy.is_rank_0()):
232 | prompt = self.preprocess_data(data, input_template, input_key, apply_chat_template, system_prompt)
233 | self.prompts.append(prompt)
234 | # print("!!!! peek", self.prompts[0])
235 |
236 |
237 | def __len__(self):
238 | length = len(self.prompts)
239 | return length
240 |
241 | def __getitem__(self, idx):
242 | return self.prompts[idx]
243 |
244 |
--------------------------------------------------------------------------------
/openrlhf/datasets/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def zero_pad_sequences(sequences, side: str = "left", value=0):
6 | assert side in ("left", "right")
7 | max_len = max(seq.size(-1) for seq in sequences)
8 | padded_sequences = []
9 | for seq in sequences:
10 | pad_len = max_len - seq.size(-1)
11 | padding = (pad_len, 0) if side == "left" else (0, pad_len)
12 | padded_sequences.append(F.pad(seq, padding, value=value))
13 | return torch.stack(padded_sequences, dim=0)
14 |
15 |
16 | def exist_and_not_none(d, key):
17 | return key in d and not d[key] is None
18 |
--------------------------------------------------------------------------------
/openrlhf/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .actor import Actor
2 | from .loss import (
3 | DPOLoss,
4 | GPTLMLoss,
5 | KDLoss,
6 | KTOLoss,
7 | LogExpLoss,
8 | PairWiseLoss,
9 | PolicyLoss,
10 | SFTLoss,
11 | PRMLoss,
12 | ValueLoss,
13 | VanillaKTOLoss,
14 | )
15 | from .model import get_llm_for_sequence_regression
16 |
17 | __all__ = [
18 | "Actor",
19 | "DPOLoss",
20 | "GPTLMLoss",
21 | "KDLoss",
22 | "KTOLoss",
23 | "LogExpLoss",
24 | "PairWiseLoss",
25 | "PolicyLoss",
26 | "SFTLoss",
27 | "PRMLoss",
28 | "ValueLoss",
29 | "VanillaKTOLoss",
30 | "get_llm_for_sequence_regression",
31 | ]
32 |
--------------------------------------------------------------------------------
/openrlhf/models/model.py:
--------------------------------------------------------------------------------
1 | from typing import Optional, Union
2 |
3 | import deepspeed
4 | import torch
5 | import torch.nn as nn
6 | from flash_attn.utils.distributed import all_gather
7 | from peft import LoraConfig, get_peft_model
8 | from peft.tuners.lora import LoraLayer
9 | from transformers import AutoConfig, AutoModel, BitsAndBytesConfig
10 | from transformers.integrations.deepspeed import HfDeepSpeedConfig
11 |
12 | from openrlhf.utils.logging_utils import init_logger
13 |
14 | from .ring_attn_utils import convert_ring_attn_params
15 | from .utils import reset_position_ids
16 | from ..utils.utils import get_generation_cls
17 |
18 | logger = init_logger(__name__)
19 |
20 |
21 | # Construct transformer with a value head for sequence classification.
22 | # https://github.com/huggingface/transformers/blob/405b56269812056d9593869e22b7b264d806cb1e/src/transformers/models/llama/modeling_llama.py#L1254
23 | def get_llm_for_sequence_regression(
24 | model_name_or_path: str,
25 | model_type: str,
26 | *,
27 | bf16=True,
28 | load_in_4bit=False,
29 | lora_rank=0,
30 | lora_alpha=16,
31 | target_modules=None,
32 | lora_dropout=0,
33 | normalize_reward=False,
34 | use_flash_attention_2=False,
35 | ds_config: dict = None,
36 | init_value_head: bool = False,
37 | value_head_prefix="score",
38 | device_map=None,
39 | packing_samples=False,
40 | **kwargs,
41 | ) -> nn.Module:
42 | """Retrieve a transformer model with a sequence regression head on top.
43 |
44 | This function loads a pretrained transformer model and attaches a linear layer for sequence regression.
45 |
46 | Args:
47 | model_name_or_path (str): Path to the pretrained model.
48 | model_type (str): Type of the model, either "reward" or "critic".
49 | bf16 (bool, optional): Enable bfloat16 precision. Defaults to True.
50 | load_in_4bit (bool, optional): Load the model in 4-bit precision. Defaults to False.
51 | lora_rank (int, optional): Rank for LoRA adaptation. Defaults to 0.
52 | lora_alpha (int, optional): Alpha parameter for LoRA. Defaults to 16.
53 | target_modules (list, optional): List of target modules for LoRA. Defaults to None.
54 | lora_dropout (float, optional): Dropout rate for LoRA layers. Defaults to 0.
55 | normalize_reward (bool, optional): Normalize reward values. Defaults to False.
56 | use_flash_attention_2 (bool, optional): Use Flash Attention 2.0. Defaults to False.
57 | ds_config (dict, optional): Deepspeed configuration for model partitioning across multiple GPUs when ZeRO-3 is enabled. Defaults to None.
58 | init_value_head (bool, optional): Initialize the value head. Defaults to False.
59 | value_head_prefix (str, optional): Prefix for the value head. Defaults to "score".
60 | device_map (dict, optional): Map of devices for model loading. Defaults to None.
61 | packing_samples (bool, optional): Whether to pack samples during training. Defaults to False.
62 |
63 | Returns:
64 | nn.Module: A pretrained transformer model with a sequence regression head.
65 | """
66 | assert (
67 | model_type == "critic" or model_type == "reward"
68 | ), f"invalid model_type: {model_type}, should be critic or reward."
69 |
70 | config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
71 | config.normalize_reward = normalize_reward
72 | config._attn_implementation = "flash_attention_2" if use_flash_attention_2 else "eager"
73 |
74 | # Prioritize using the value_head_prefix in the model configuration.
75 | value_head_prefix = getattr(config, "value_head_prefix", value_head_prefix)
76 | logger.info(f"set value_head_prefix to `{value_head_prefix}`")
77 | base_class = get_generation_cls(config)
78 | base_pretrained_class = base_class.__base__
79 | if model_type == "reward":
80 | cls_class = _get_reward_model(base_class, value_head_prefix, packing_samples)
81 | else:
82 | cls_class = _get_critic_model(base_class, value_head_prefix, packing_samples)
83 |
84 | # Note: dschf is defined in function scope to avoid global effects
85 | # https://huggingface.co/docs/transformers/main_classes/deepspeed#nontrainer-deepspeed-integration
86 | if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3:
87 | dschf = HfDeepSpeedConfig(ds_config)
88 | else:
89 | dschf = None
90 |
91 | if load_in_4bit:
92 | assert bf16, "we only support bnb_4bit_compute_dtype = bf16"
93 | nf4_config = BitsAndBytesConfig(
94 | load_in_4bit=True,
95 | bnb_4bit_quant_type="nf4",
96 | bnb_4bit_use_double_quant=True,
97 | bnb_4bit_compute_dtype=torch.bfloat16,
98 | )
99 | else:
100 | nf4_config = None
101 |
102 | model = cls_class.from_pretrained(
103 | model_name_or_path,
104 | config=config,
105 | trust_remote_code=True,
106 | torch_dtype=torch.bfloat16 if bf16 else "auto",
107 | quantization_config=nf4_config,
108 | device_map=device_map,
109 | **kwargs,
110 | )
111 |
112 | # LoRA
113 | if lora_rank > 0:
114 | model.enable_input_require_grads()
115 | lora_config = LoraConfig(
116 | r=lora_rank,
117 | lora_alpha=lora_alpha,
118 | target_modules=target_modules,
119 | lora_dropout=lora_dropout,
120 | bias="none",
121 | )
122 | model = get_peft_model(model, lora_config)
123 |
124 | if load_in_4bit:
125 | for name, module in model.named_modules():
126 | if isinstance(module, LoraLayer):
127 | module = module.to(torch.bfloat16)
128 | if "norm" in name:
129 | module = module.to(torch.float32)
130 | if value_head_prefix in name or "embed_tokens" in name:
131 | if hasattr(module, "weight"):
132 | module = module.to(torch.bfloat16)
133 |
134 | # MoE - balancing loss
135 | model_config = model.config.to_dict()
136 | if "output_router_logits" in model_config:
137 | print("[MoE] set output_router_logits as True")
138 | model.config.output_router_logits = True
139 |
140 | # https://github.com/huggingface/transformers/issues/26877
141 | model.config.use_cache = False
142 |
143 | # NOTE: For reward model training only, initialize value_head manually
144 | # because deepspeed.zero.Init() will not initialize it.
145 | # TODO: Find a better way to clarify reward model training.
146 | if init_value_head:
147 | value_head = getattr(model, value_head_prefix)
148 | if dschf is not None:
149 | logger.info("initialize value_head for ZeRO-3 reward model training.")
150 | with deepspeed.zero.GatheredParameters([value_head.weight], modifier_rank=0):
151 | if torch.distributed.get_rank() == 0:
152 | value_head.weight.data.normal_(mean=0.0, std=1 / (config.hidden_size + 1))
153 | else:
154 | value_head.weight.data.normal_(mean=0.0, std=1 / (config.hidden_size + 1))
155 |
156 | return model
157 |
158 |
159 | def _get_reward_model(base_llm_model, value_head_prefix="score", packing_samples=False):
160 | class RewardModel(base_llm_model):
161 | supports_gradient_checkpointing = True
162 |
163 | def __init__(self, config: AutoConfig):
164 | super().__init__(config)
165 |
166 | self.value_head_prefix = value_head_prefix
167 | setattr(self, value_head_prefix, nn.Linear(config.hidden_size, 1, bias=False))
168 |
169 | self.packing_samples = packing_samples
170 |
171 | # mean std
172 | self.normalize_reward = config.normalize_reward
173 | self.register_buffer("mean", torch.zeros(1), persistent=False)
174 | self.register_buffer("std", torch.ones(1), persistent=False)
175 |
176 | # load mean/std from config.json
177 | if hasattr(config, "mean"):
178 | self.mean[0] = config.mean
179 | self.std[0] = config.std
180 |
181 | def forward(
182 | self,
183 | input_ids: torch.LongTensor = None,
184 | attention_mask: Optional[torch.Tensor] = None,
185 | return_output=False,
186 | ring_attn_group=None,
187 | packed_seq_lens=None,
188 | visual_inputs=None,
189 | ) -> torch.Tensor:
190 | if visual_inputs is None:
191 | visual_inputs = {}
192 | if not self.packing_samples:
193 | # https://github.com/OpenRLHF/OpenRLHF/issues/217
194 | position_ids = attention_mask.long().cumsum(-1) - 1
195 | position_ids.masked_fill_(attention_mask == 0, 1)
196 | else:
197 | # convert attention_mask to position_ids
198 | if ring_attn_group is not None:
199 | input_ids, attention_mask, position_ids = convert_ring_attn_params(
200 | input_ids, attention_mask, packed_seq_lens, ring_attn_group
201 | )
202 | else:
203 | position_ids = reset_position_ids(attention_mask)
204 | # explicitly ignore attention_mask for packing_samples
205 | attention_mask = None
206 |
207 | outputs = super().forward(
208 | input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids,output_hidden_states=True, **visual_inputs
209 | )
210 | if "last_hidden_state" in outputs:
211 | last_hidden_states = outputs["last_hidden_state"]
212 | elif "hidden_states" in outputs:
213 | last_hidden_states = outputs["hidden_states"][-1]
214 | else:
215 | raise ValueError("outputs should contain either last_hidden_state or hidden_states")
216 | values = getattr(self, self.value_head_prefix)(last_hidden_states).squeeze(-1)
217 |
218 | if self.packing_samples:
219 | if ring_attn_group is not None:
220 | reward = all_gather(values, ring_attn_group).reshape(1, -1)
221 | else:
222 | reward = values
223 | # TODO: convert packed_seq_lens into torch tensor in advance
224 | packed_seq_lens = torch.tensor(packed_seq_lens, device=values.device)
225 | eos_indices = packed_seq_lens.cumsum(dim=0) - 1
226 | reward = reward.squeeze(0).gather(dim=0, index=eos_indices)
227 | else:
228 | eos_indices = attention_mask.size(1) - 1 - attention_mask.long().fliplr().argmax(dim=1, keepdim=True)
229 | reward = values.gather(dim=1, index=eos_indices).squeeze(1)
230 |
231 | if not self.training and self.normalize_reward:
232 | reward = (reward - self.mean) / self.std
233 |
234 | return (reward, outputs) if return_output else reward
235 |
236 | return RewardModel
237 |
238 |
239 | def _get_critic_model(base_llm_model, value_head_prefix="score", packing_samples=False):
240 | class CriticModel(base_llm_model):
241 | supports_gradient_checkpointing = True
242 |
243 | def __init__(self, config: AutoConfig):
244 | super().__init__(config)
245 |
246 | self.value_head_prefix = value_head_prefix
247 | setattr(self, value_head_prefix, nn.Linear(config.hidden_size, 1, bias=False))
248 |
249 | self.packing_samples = packing_samples
250 |
251 | # mean std
252 | self.normalize_reward = config.normalize_reward
253 | self.register_buffer("mean", torch.zeros(1), persistent=False)
254 | self.register_buffer("std", torch.ones(1), persistent=False)
255 |
256 | # load mean/std from config.json
257 | if hasattr(config, "mean"):
258 | self.mean[0] = config.mean
259 | self.std[0] = config.std
260 |
261 | def forward(
262 | self,
263 | input_ids: torch.LongTensor = None,
264 | num_actions: Optional[Union[int, list[int]]] = None,
265 | attention_mask: Optional[torch.Tensor] = None,
266 | return_output=False,
267 | packed_seq_lens=None,
268 | visual_inputs={},
269 | ) -> torch.Tensor:
270 | if not self.packing_samples:
271 | # https://github.com/OpenRLHF/OpenRLHF/issues/217
272 | position_ids = attention_mask.long().cumsum(-1) - 1
273 | position_ids.masked_fill_(attention_mask == 0, 1)
274 | else:
275 | # convert attention_mask to position_ids
276 | position_ids = reset_position_ids(attention_mask)
277 | # explicitly ignore attention_mask for packing_samples
278 | attention_mask = None
279 |
280 | outputs = super().forward(
281 | input_ids=input_ids, attention_mask=attention_mask, position_ids=position_ids,output_hidden_states=True, **visual_inputs
282 | )
283 | if "last_hidden_state" in outputs:
284 | last_hidden_states = outputs["last_hidden_state"]
285 | elif "hidden_states" in outputs:
286 | last_hidden_states = outputs["hidden_states"][-1]
287 | else:
288 | raise ValueError("outputs should contain either last_hidden_state or hidden_states")
289 | values = getattr(self, self.value_head_prefix)(last_hidden_states).squeeze(-1)[:, :-1]
290 |
291 | # normalize reward
292 | if self.normalize_reward:
293 | values = (values - self.mean) / self.std
294 |
295 | if num_actions is None:
296 | assert return_output
297 | return outputs
298 |
299 | if not self.packing_samples:
300 | action_values = values[:, -num_actions:]
301 | else:
302 | assert isinstance(num_actions, list) and len(num_actions) == len(packed_seq_lens)
303 | action_values = []
304 | offset = 0
305 | for num_action, seq_len in zip(num_actions, packed_seq_lens):
306 | start, end = max(0, offset + seq_len - num_action - 1), offset + seq_len - 1
307 | action_values.append(values[:, start:end])
308 | offset += seq_len
309 | action_values = torch.cat(action_values, dim=1)
310 |
311 | if return_output:
312 | return (action_values, outputs)
313 | else:
314 | return action_values
315 |
316 | return CriticModel
317 |
--------------------------------------------------------------------------------
/openrlhf/models/ring_attn_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.distributed as dist
3 | import torch.nn.functional as F
4 |
5 |
6 | RING_ATTN_GROUP = None
7 |
8 |
9 | def set_ring_attn_group(group):
10 | global RING_ATTN_GROUP
11 | RING_ATTN_GROUP = group
12 |
13 |
14 | def get_ring_attn_group():
15 | return RING_ATTN_GROUP
16 |
17 |
18 | def reset_ring_attn_position_ids(start, end, packed_seq_lens):
19 | """
20 | Calculate position ids for packed_seq_ids[start:end].
21 | For example, if the packed_seq_lens is [3, 2, 4, 1], start=2, end=8,
22 | the position ids will be [2, 0, 1, 0, 1, 2].
23 |
24 | Args:
25 | start: the start position
26 | end: the end position
27 | packed_seq_lens: the sequence lengths of packed sequences
28 | """
29 | position_ids = torch.zeros((1, end - start), dtype=torch.long, device=torch.cuda.current_device())
30 | offset = 0
31 | for seqlen in packed_seq_lens:
32 | seq_start = max(offset, start)
33 | seq_end = min(offset + seqlen, end)
34 | if seq_start < seq_end:
35 | position_ids[0, seq_start - start : seq_end - start] = torch.arange(seq_start - offset, seq_end - offset)
36 |
37 | offset += seqlen
38 | if offset >= end:
39 | break
40 | return position_ids
41 |
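# Usage sketch (illustrative, assumes a CUDA device since the tensor is allocated on
# torch.cuda.current_device()): with packed_seq_lens=[3, 2, 4, 1] the full packed
# positions are [0,1,2, 0,1, 0,1,2,3, 0], so the slice owned by start=2, end=8 is
#
#   reset_ring_attn_position_ids(2, 8, [3, 2, 4, 1])
#   # -> tensor([[2, 0, 1, 0, 1, 2]])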
42 |
43 | def update_ring_attn_params(packed_seq_lens, total_seq_len):
44 | """
45 | Calculate the cu_seqlens for the current forward pass and pass the value to
46 | the substituted ring_flash_attn.
47 |
48 | Note that total_seq_len may be larger than the sum of packed_seq_lens because of padding.
49 | """
50 | assert RING_ATTN_GROUP is not None
51 | cu_seqlens = torch.cumsum(
52 | torch.tensor(packed_seq_lens, device=torch.cuda.current_device(), dtype=torch.int32),
53 | dim=-1,
54 | dtype=torch.int32,
55 | )
56 | cu_seqlens = F.pad(F.pad(cu_seqlens, (1, 0), value=0), (0, 1), value=total_seq_len)
57 |
58 | from ring_flash_attn import update_ring_flash_attn_params
59 |
60 | update_ring_flash_attn_params(cu_seqlens, RING_ATTN_GROUP)
61 |
62 |
63 | def convert_ring_attn_params(sequences, attention_mask, packed_seq_lens, ring_attn_group):
64 | # each rank within the ring group will process sequences[start:end]
65 | ring_attn_rank = dist.get_rank(group=ring_attn_group)
66 | ring_attn_size = dist.get_world_size(group=ring_attn_group)
67 | total_seq_len = sequences.numel()
68 | local_seq_len = total_seq_len // ring_attn_size
69 | start, end = ring_attn_rank * local_seq_len, (ring_attn_rank + 1) * local_seq_len
70 | sequences = sequences[:, start:end]
71 | attention_mask = attention_mask[:, start:end]
72 | position_ids = reset_ring_attn_position_ids(start, end, packed_seq_lens)
73 | update_ring_attn_params(packed_seq_lens, total_seq_len)
74 | return sequences, attention_mask, position_ids
75 |
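# Split sketch (illustrative): with ring_attn_size=2 and a packed batch of 8 tokens,
# rank 0 keeps sequences[:, 0:4] and rank 1 keeps sequences[:, 4:8]; position ids are
# recomputed for the local slice, while the global cu_seqlens derived from
# packed_seq_lens is registered with ring_flash_attn via update_ring_attn_params.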
--------------------------------------------------------------------------------
/openrlhf/models/utils.py:
--------------------------------------------------------------------------------
1 | # /*
2 | # * Modified by Haozhe Wang in 2025
3 | # *
4 | # * Licensed under the Apache License, Version 2.0 (the "License");
5 | # */
6 |
7 | from typing import Optional, Tuple, Union
8 |
9 | import torch
10 | import torch.nn.functional as F
11 |
12 |
13 | def compute_approx_kl(
14 | log_probs: torch.Tensor,
15 | log_probs_base: torch.Tensor,
16 | action_mask: Optional[torch.Tensor] = None,
17 | use_kl_estimator_k3: bool = False,
18 | ) -> torch.Tensor:
19 | """
20 | Compute the approximate KL divergence between two distributions.
21 | Schulman blog: http://joschu.net/blog/kl-approx.html
22 |
23 | Args:
24 | log_probs: Log probabilities of the new distribution.
25 | log_probs_base: Log probabilities of the base distribution.
26 | action_mask: Mask for actions.
27 | """
28 |
29 | log_ratio = log_probs.float() - log_probs_base.float()
30 | if action_mask is not None:
31 | log_ratio = log_ratio * action_mask
32 |
33 |     # The k3 estimator is the non-negative KL approximation from
34 |     # http://joschu.net/blog/kl-approx.html
35 |     # Besides being non-negative, it is also unbiased and has lower variance.
36 | if use_kl_estimator_k3:
37 | log_ratio = -log_ratio
38 | log_ratio = log_ratio.exp() - 1 - log_ratio
39 |
40 | return log_ratio
41 |
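# Minimal sketch of the two estimators (illustrative, not executed on import):
#
#   logp      = torch.log_softmax(torch.randn(2, 4), dim=-1)
#   logp_base = torch.log_softmax(torch.randn(2, 4), dim=-1)
#   k1 = compute_approx_kl(logp, logp_base)                            # plain log-ratio, can be negative per token
#   k3 = compute_approx_kl(logp, logp_base, use_kl_estimator_k3=True)  # exp(-x) - 1 + x, >= 0 per token
#   assert (k3 >= 0).all()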
42 |
43 | def compute_reward(
44 | r: Union[torch.Tensor, float],
45 | kl_coef: float,
46 | kl: Union[torch.Tensor, list[torch.Tensor]],
47 | action_mask: Optional[torch.Tensor] = None,
48 | num_actions: Optional[Union[int, list[int]]] = None,
49 | reward_clip_range: Tuple[float, float] = None,
50 | ) -> Union[torch.Tensor, list[torch.Tensor]]:
51 | if kl_coef <= 0.0:
52 | kl_coef = 0.0
53 |
54 | if reward_clip_range:
55 | r = r.clamp(min=reward_clip_range[0], max=reward_clip_range[1])
56 |
57 | if action_mask is not None:
58 | kl_reward = -kl_coef * kl
59 | # The following code is equivalent to:
60 | #
61 | # last_reward = torch.zeros_like(kl)
62 | # for i in range(last_reward.size(0)):
63 | # for t in reversed(range(last_reward.size(1))):
64 | # if action_mask[i][t] > 0.5:
65 | # last_reward[i][t] = r[i]
66 | # break
67 | #
68 | eos_indices = action_mask.size(1) - 1 - action_mask.long().fliplr().argmax(dim=1, keepdim=True)
69 | last_reward = torch.zeros_like(kl).scatter_(dim=1, index=eos_indices, src=r.unsqueeze(1).to(kl.dtype))
70 |
71 | reward = last_reward + kl_reward
72 | else:
73 | # TODO: write a more efficient version
74 | reward = []
75 | for i, (kl_seg, action_len) in enumerate(zip(kl, num_actions)):
76 | kl_reward = -kl_coef * kl_seg
77 | kl_reward[action_len - 1] += r[i]
78 | reward.append(kl_reward)
79 |
80 | return reward
81 |
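# Worked sketch of the masked branch (illustrative values): the scalar return r is
# scattered onto the last valid action position, and the KL penalty applies everywhere.
#
#   r           = torch.tensor([1.0])
#   kl          = torch.zeros(1, 4)
#   action_mask = torch.tensor([[1, 1, 1, 0]])
#   compute_reward(r, kl_coef=0.1, kl=kl, action_mask=action_mask)
#   # -> tensor([[0., 0., 1., 0.]])  (the reward lands at index 2, the last unmasked step)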
82 |
83 | def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
84 | # https://github.com/OpenRLHF/OpenRLHF/pull/718#issuecomment-2641081881
85 | if logits.dtype in [torch.float32, torch.float64]:
86 | logits_labels = torch.gather(logits, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)
87 | logsumexp_values = torch.stack(
88 | [torch.logsumexp(l, dim=-1) for l in logits] # loop to reduce peak mem consumption
89 | )
90 | log_probs_labels = logits_labels - logsumexp_values # log_softmax(x_i) = x_i - logsumexp(x)
91 | else:
92 | log_probs_labels = []
93 | for row_logits, row_labels in zip(logits, labels): # loop to reduce peak mem consumption
94 | row_log_probs = F.log_softmax(row_logits, dim=-1)
95 | row_log_probs_labels = row_log_probs.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)
96 | log_probs_labels.append(row_log_probs_labels)
97 | log_probs_labels = torch.stack(log_probs_labels)
98 | return log_probs_labels
99 |
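# Equivalence sketch (illustrative): for fp32/fp64 logits the row-wise logsumexp path
# matches a plain log_softmax + gather, while avoiding materializing the full
# log_softmax tensor at once.
#
#   logits = torch.randn(2, 5, 32)          # (batch, seq, vocab)
#   labels = torch.randint(0, 32, (2, 5))
#   ref = F.log_softmax(logits, dim=-1).gather(-1, labels.unsqueeze(-1)).squeeze(-1)
#   assert torch.allclose(log_probs_from_logits(logits, labels), ref, atol=1e-5)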
100 |
101 | def masked_mean(tensor: torch.Tensor, mask: Optional[torch.Tensor], dim: int = None) -> torch.Tensor:
102 | if mask is None:
103 | return tensor.mean(axis=dim)
104 | return (tensor * mask).sum(axis=dim) / (mask.sum(axis=dim)+1e-4)
105 |
106 |
107 | def masked_normalize(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1, eps: float = 1e-8) -> torch.Tensor:
108 | tensor = tensor * mask
109 | mean = masked_mean(tensor, mask, dim=dim)
110 | mean_centered = tensor - mean
111 | var = masked_mean(mean_centered**2, mask, dim=dim)
112 | return mean_centered * var.clamp(min=eps).rsqrt()
113 |
114 |
115 | # Reset positions for packed samples
116 | # For example
117 | # Input: attention_mask = torch.tensor([[1, 1, 1, 2, 2, 2, 3, 3, 0]])
118 | # Output: position_ids = torch.tensor([[0, 1, 2, 0, 1, 2, 0, 1, 0]])
119 | def reset_position_ids(attention_mask):
120 | position_ids = torch.zeros_like(attention_mask, dtype=torch.long)
121 | for i in range(attention_mask.size(0)):
122 | mask = attention_mask[i]
123 | seq_num = mask.max().item()
124 | for index in range(1, seq_num + 1):
125 | sample_mask = mask == index
126 | sample_length = sample_mask.sum().item()
127 | position_ids[i, sample_mask] = torch.arange(sample_length, device=mask.device)
128 | return position_ids
129 |
130 | def packed_sequence_to_position_tensor(packed_seq_lens, device):
131 | """
132 | Converts packed_seq_lens to a tensor of token positions.
133 |
134 | Args:
135 | packed_seq_lens: A list of integers representing token length for each sequence.
136 |
137 | Returns:
138 | A tensor of shape (1, ntoken) containing the sequences of positions.
139 | """
140 | output_list = []
141 | for seq_len in packed_seq_lens:
142 | output_list.extend(list(range(seq_len)))
143 | return torch.tensor(output_list, device=device).unsqueeze(0)
144 |
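# Usage sketch (illustrative):
#
#   packed_sequence_to_position_tensor([3, 2], device="cpu")
#   # -> tensor([[0, 1, 2, 0, 1]])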
145 |
146 | def unpacking_samples(values: torch.Tensor, packed_seqlens: list[int]):
147 | values = values.squeeze(0)
148 | unpacked_values = []
149 | offset = 0
150 | for seqlen in packed_seqlens:
151 | unpacked_values.append(values[offset : offset + seqlen])
152 | offset += seqlen
153 | return unpacked_values
154 |
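# Usage sketch (illustrative): splits a packed (1, N) tensor back into per-sequence chunks.
#
#   unpacking_samples(torch.arange(5).view(1, 5), [3, 2])
#   # -> [tensor([0, 1, 2]), tensor([3, 4])]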
--------------------------------------------------------------------------------
/openrlhf/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | # from .dpo_trainer import DPOTrainer
2 | # from .kd_trainer import KDTrainer
3 | # from .kto_trainer import KTOTrainer
4 | from .ppo_trainer import PPOTrainer
5 | from .evaluator import Evaluator
6 | # from .prm_trainer import ProcessRewardModelTrainer
7 | # from .rm_trainer import RewardModelTrainer
8 | # from .sft_trainer import SFTTrainer
9 |
10 | __all__ = [
11 | # "DPOTrainer",
12 | # "KDTrainer",
13 | # "KTOTrainer",
14 | "PPOTrainer",
15 | # "ProcessRewardModelTrainer",
16 | # "RewardModelTrainer",
17 | # "SFTTrainer",
18 | "Evaluator"
19 | ]
20 |
--------------------------------------------------------------------------------
/openrlhf/trainer/evaluator.py:
--------------------------------------------------------------------------------
1 | # /*
2 | # * Original Copyright Haozhe Wang in 2025
3 | # *
4 | # * Licensed under the Apache License, Version 2.0 (the "License");
5 | # */
6 |
7 | import os
8 | import os.path
9 | from abc import ABC
10 | from typing import Any, Callable, Dict, List, Optional
11 |
12 | import torch
13 | import torch.distributed
14 | import torch.nn as nn
15 | from torch.optim import Optimizer
16 | from torch.utils.data import DataLoader
17 | from tqdm import tqdm
18 |
19 | from openrlhf.models import Actor, GPTLMLoss, PolicyLoss, SFTLoss, ValueLoss
20 | from openrlhf.models.utils import masked_mean
21 | from openrlhf.utils.distributed_sampler import DistributedSampler
22 | from openrlhf.models.utils import log_probs_from_logits
23 |
24 | from .ppo_utils import AdaptiveKLController, Experience, FixedKLController, NaiveExperienceMaker, NaiveReplayBuffer, DATA_PROCESSOR_MAP
25 | import random
26 | import copy
27 | import numpy as np
28 | from collections import defaultdict
29 | import json
30 |
31 |
32 |
33 | def read_jsonl(filepath):
34 | """
35 | Reads a JSON Lines (jsonl) file and returns a list of dictionaries.
36 |
37 | Args:
38 | filepath (str): The path to the jsonl file.
39 |
40 | Returns:
41 | list: A list of dictionaries, where each dictionary represents a line
42 | from the jsonl file. Returns an empty list if the file is empty
43 | or if an error occurs.
44 | """
45 | data = []
46 | try:
47 | with open(filepath, 'r', encoding='utf-8') as f:
48 | for line in f:
49 | try:
50 | data.append(json.loads(line.strip()))
51 | except json.JSONDecodeError:
52 | print(f"Warning: Invalid JSON on line: {line.strip()}")
53 | # Optionally, you might want to log the error or handle it differently.
54 |
55 | except FileNotFoundError:
56 | print(f"Error: File not found at {filepath}")
57 | except Exception as e:
58 | print(f"An unexpected error occurred: {e}")
59 |
60 | return data
61 |
62 | class Evaluator(ABC):
63 | """
64 |     Evaluator that runs rollout-based evaluation, reusing the PPO trainer's components and arguments.
65 |
66 | Args:
67 | strategy (Strategy): The training strategy to use.
68 | actor (Actor): The actor model in the PPO algorithm.
69 | critic (nn.Module): The critic model in the PPO algorithm.
70 | reward_model (nn.Module): The reward model for calculating rewards in the RLHF setup.
71 | initial_model (Actor): The initial model for reference logits to limit actor updates in RLHF.
72 | ema_model (Actor): The exponential moving average model for stable training.
73 | actor_optim (Optimizer): The optimizer for the actor model.
74 | critic_optim (Optimizer): The optimizer for the critic model.
75 | actor_scheduler (Scheduler): The learning rate scheduler for the actor.
76 | critic_scheduler (Scheduler): The learning rate scheduler for the critic.
77 | ema_beta (float, defaults to 0.992): EMA decay rate for model stability.
78 | init_kl_coef (float, defaults to 0.001): Initial coefficient for KL divergence.
79 | kl_target (float, optional): Target value for KL divergence.
80 | kl_horizon (int, defaults to 10000): Horizon for KL annealing.
81 | ptx_coef (float, defaults to 0): Coefficient for supervised loss from pre-trained data.
82 | micro_train_batch_size (int, defaults to 8): Micro-batch size for actor training.
83 | buffer_limit (int, defaults to 0): Maximum size of the replay buffer.
84 | buffer_cpu_offload (bool, defaults to True): If True, offloads replay buffer to CPU.
85 | eps_clip (float, defaults to 0.2): Clipping coefficient for policy loss.
86 | value_clip (float, defaults to 0.2): Clipping coefficient for value function loss.
87 | micro_rollout_batch_size (int, defaults to 8): Micro-batch size for generating rollouts.
88 | gradient_checkpointing (bool, defaults to False): If True, enables gradient checkpointing.
89 | max_epochs (int, defaults to 1): Number of epochs to train.
90 | max_norm (float, defaults to 1.0): Maximum gradient norm for gradient clipping.
91 | tokenizer (Callable, optional): Tokenizer for input data.
92 | prompt_max_len (int, defaults to 128): Maximum length for prompts.
93 | dataloader_pin_memory (bool, defaults to True): If True, pins memory in the data loader.
95 | reward_fn (Callable, optional): Custom reward function for computing rewards.
96 |         save_hf_ckpt (bool): Whether to save HuggingFace-format model weights.
97 |         disable_ds_ckpt (bool): If True, skip saving DeepSpeed-format model weights (which are used for training recovery).
98 | **generate_kwargs: Additional arguments for model generation.
99 | """
100 |
101 | def __init__(
102 | self,
103 | strategy,
104 | ema_beta: float = 0.992,
105 | init_kl_coef: float = 0.001,
106 | kl_target: float = None,
107 | kl_horizon: int = 10000,
108 | ptx_coef: float = 0,
109 | micro_train_batch_size: int = 8,
110 | buffer_limit: int = 0,
111 | buffer_cpu_offload: bool = True,
112 | eps_clip: float = 0.2,
113 | value_clip: float = 0.2,
114 | micro_rollout_batch_size: int = 8,
115 | gradient_checkpointing: bool = False,
116 | max_epochs: int = 1,
117 | max_norm: float = 1.0,
118 | processor: Optional[Callable[[Any], Dict]] = None,
119 | tokenizer: Optional[Callable[[Any], Dict]] = None,
120 | prompt_max_len: int = 128,
121 | dataloader_pin_memory: bool = True,
122 | reward_fn: Callable[[List[torch.Tensor]], torch.Tensor] = None,
123 | save_hf_ckpt: bool = False,
124 | disable_ds_ckpt: bool = False,
125 | **generate_kwargs,
126 | ) -> None:
127 | # assert (
128 | # not isinstance(reward_model, List) or len(reward_model) == 1 or reward_fn is not None
129 | # ), "reward_fn must be specified if using multiple reward models"
130 |
131 | super().__init__()
132 | self.strategy = strategy
133 |
134 | strategy.setup_distributed()
135 | self.args = strategy.args
136 | self.rloo_sft = self.args.advantage_estimator.lower() in ['rloo_sft', 'group_sft']
137 | self.save_hf_ckpt = save_hf_ckpt
138 | self.disable_ds_ckpt = disable_ds_ckpt
139 | self.micro_rollout_batch_size = micro_rollout_batch_size
140 | self.max_epochs = max_epochs
141 | self.tokenizer = tokenizer
142 | self.processor = processor
143 | self.data_processor = None
144 |         # for the VLM critic model, no processor is provided.
145 | if self.args.train_vlm and processor is not None:
146 | self.data_processor = DATA_PROCESSOR_MAP[type(processor)](processor)
147 | self.tokenizer = self.data_processor.tokenizer
148 |
149 | self.generate_kwargs = generate_kwargs
150 | self.dataloader_pin_memory = dataloader_pin_memory
151 | self.max_norm = max_norm
152 | self.ptx_coef = ptx_coef
153 | self.micro_train_batch_size = micro_train_batch_size
154 | self.kl_target = kl_target
155 | self.prompt_max_len = prompt_max_len
156 | self.ema_beta = ema_beta
157 | self.gradient_checkpointing = gradient_checkpointing
158 | self.reward_fn = reward_fn
159 |
160 | args = self.args
161 | self.max_len = args.max_len if args.max_len else args.prompt_max_len + args.generate_max_len
162 |
163 | packing_samples = getattr(self.args, "packing_samples", False)
164 | self.replay_buffer = NaiveReplayBuffer(
165 | micro_train_batch_size, self.data_processor, buffer_limit, buffer_cpu_offload, packing_samples,
166 | drop_maxlen=self.args.drop_maxlen,
167 | maxlen=self.args.generate_max_len + prompt_max_len,
168 | )
169 |
170 | self.iter = 0
171 | self.eval_step = 0
172 | self.best = -1
173 |
174 | def eval_unit(self, args, ep, global_step, dataloader):
175 | keys = ['reward', 'response_length', 'validity','match','usefmt','round1_nwait']
176 | infos = {k:[] for k in keys}
177 | print("!!!! eval loader size", len(dataloader), 'step', global_step)
178 | batchsize = dataloader.batch_sampler.batch_size
179 | for idx, rand_prompts in enumerate(dataloader):
180 | if batchsize>len(rand_prompts):
181 | current_len = len(rand_prompts)
182 | needed = batchsize - current_len
183 | repeat_indices = np.arange(needed) % current_len
184 | # repeat_indices = repeat_indices.to(rand_prompts.device)
185 | additional = [rand_prompts[ii] for ii in repeat_indices]
186 | rand_prompts = rand_prompts + additional
187 | else: needed = 0
188 | print(f"!!!! ========== eval progress {idx}/{len(dataloader)} ==========")
189 |
190 | exp_list = self.get_explist_from_prompts(args, ep, rand_prompts, is_eval=True, eval_step=global_step)
191 |
192 | for i, experience in enumerate(exp_list):
193 | self.replay_buffer.append_split(experience, is_eval=True)
194 |
195 |
196 | for item in self.replay_buffer.eval_items:
197 | info = item.info
198 | for k in keys:
199 | infos[k].append(info[k])
200 | out_lens = infos['response_length']
201 |
202 | for k,vlist in infos.items():
203 | infos[k] = np.mean(vlist)
204 | infos['generation_exceed_rate'] = np.mean([x>args.generate_max_len-1 for x in out_lens])
205 |
206 | torch.distributed.barrier()
207 | gather_info = self.strategy.all_reduce(infos) # mean
208 |
209 | return gather_info
210 |
211 |
212 |
213 | def get_eval_result_from_disk(self):
214 | args = self.strategy.args
215 | from glob import glob
216 | # os.makedirs(args.ckpt_path, exist_ok=True)
217 | # os.makedirs(f'{args.ckpt_path}/logs', exist_ok=True)
218 | tmp = f'{args.ckpt_path}/logs/sample.eval_iter{self.eval_step}*.jsonl'
219 | files = glob(tmp)
220 | print(f'!!!! [eval] reading from disk {len(files)} files', tmp, )
221 |
222 | datalist = [read_jsonl(file) for file in files]
223 | results_each = defaultdict(list)
224 | q2results = defaultdict(list)
225 | for info in datalist:
226 | for x in info:
227 | qid = x['qids']
228 | res = x.get('match')
229 | if res is None:
230 | r0_res = x['round0_correctness']
231 | res = r0_res
232 |
233 | q2results[qid].append(res>0.5)
234 | # We compute query-wise mean acc, and then average them
235 | # this is a trick to handle the drop_last=False issue
236 | for qid, vlist in q2results.items():
237 | bench = qid.split('-')[0]
238 | macc = np.mean(vlist)
239 | results_each[bench].append(macc)
240 | all_results = []
241 | dump_info = []
242 | modelpath = args.pretrain
243 | for k in results_each.keys():
244 | nc = np.sum(results_each[k])
245 | num = len(results_each[k])
246 | dump_info.append(dict(benchname=k, pass1=nc/num, ncorrect=nc, ntotal=num, modelpath=modelpath))
247 | print(f'!!!! [eval] from disk bench={k}, acc={np.mean(results_each[k])}={nc}/{num}')
248 | all_results.extend(results_each[k])
249 | results_each[k] = np.mean(results_each[k])
250 |
251 | json.dump(dump_info, open(f'{args.ckpt_path}/logs/metrics_iter{self.eval_step}.json', 'w'))
252 | acc = np.mean(all_results)
253 | return acc, results_each
254 |
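    # Aggregation sketch (illustrative numbers, assuming qids of the form "<bench>-<index>"):
    # results are averaged per query id first and then per benchmark, so queries repeated
    # by the drop_last=False padding are not over-counted.
    #
    #   q2results = {"mathbench-1": [True, False], "mathbench-2": [True]}
    #   # per-query means -> [0.5, 1.0]; benchmark "mathbench" pass1 -> 0.75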
255 | def fill_replay_buffer(self, buffer, num_expected):
256 | # Ensure every item in buffer appears at least once
257 | for item in buffer[:num_expected]:
258 | self.replay_buffer.append_split(item)
259 |
260 | # Fill the remaining slots with random choices from buffer
261 | remaining_slots = num_expected - len(buffer)
262 | if remaining_slots>0:
263 | for _ in range(remaining_slots):
264 | item = random.choice(buffer)
265 | self.replay_buffer.append_split(item)
266 | print(f'!!!! rbuffersize after filling: {len(self.replay_buffer)} should be {num_expected} x nsamples_per_query', )
267 | # assert len(self.replay_buffer)==num_expected
268 |
269 | def get_explist_from_prompts(self, args, ep, all_prompts, append=False, is_eval=False, force_noprefix=False, eval_step=None):
270 | autocode = getattr(args, "prefix_generation", None)
271 | requires_group = getattr(args, "advantage_estimator", None) in ['']
272 | # print('!!!! requires group', requires_group)
273 | generate_kwargs = copy.copy(self.generate_kwargs)
274 | generate_kwargs['requires_group'] = requires_group
275 | if force_noprefix:
276 | pass
277 | elif autocode=='autocode':
278 | if ep==0:
279 | new_prompts = [x+prefix for x in all_prompts for prefix in self.prefixes[:2]]
280 | all_prompts = new_prompts
281 | generate_kwargs['prefix_lengths'] = [plen for x in all_prompts for plen in self.prefix_lengths[:2]]
282 | else:
283 | new_prompts = [x+prefix for x in all_prompts for prefix in self.prefixes[2:3]]
284 | all_prompts = new_prompts
285 | generate_kwargs['prefix_lengths'] = [plen for x in all_prompts for plen in self.prefix_lengths[2:3]]
286 | elif autocode=='autocode1':
287 | # if ep==0:
288 | new_prompts = [x+prefix for x in all_prompts for prefix in self.prefixes[:2]]
289 | all_prompts = new_prompts
290 | generate_kwargs['prefix_lengths'] = [plen for x in all_prompts for plen in self.prefix_lengths[:2]]
291 | # else:
292 | # new_prompts = [x+prefix for x in all_prompts for prefix in self.prefixes[2:3]]
293 | # all_prompts = new_prompts
294 | # generate_kwargs['prefix_lengths'] = [plen for x in all_prompts for plen in self.prefix_lengths[2:3]]
295 | elif autocode=='autocode2':
296 | # if ep==0:
297 | new_prompts = [x+prefix for x in all_prompts for prefix in self.prefixes[:3]]
298 | all_prompts = new_prompts
299 | generate_kwargs['prefix_lengths'] = [plen for x in all_prompts for plen in self.prefix_lengths[:3]]
300 | elif autocode=='autocode_continue':
301 | # if ep==0:
302 | new_prompts = [x+prefix for x in all_prompts for prefix in self.prefixes[3:5]]
303 | all_prompts = new_prompts
304 | generate_kwargs['prefix_lengths'] = [plen for x in all_prompts for plen in self.prefix_lengths[3:5]]
305 | elif append and autocode=="autocode_append":
306 | new_prompts = [x+prefix for x in all_prompts for prefix in self.prefixes[5:6]]
307 | all_prompts = new_prompts
308 | generate_kwargs['prefix_lengths'] = [plen for x in all_prompts for plen in self.prefix_lengths[5:6]]
309 |
310 | return self.experience_maker.make_experience_list(all_prompts, is_eval=is_eval, eval_step=eval_step, **generate_kwargs)
311 |
312 |
313 | def evaluate(
314 | self,
315 | args,
316 | eval_data
317 | ) -> None:
318 |
319 | tmp = eval_data
320 | eval_bsz = args.micro_rollout_batch_size
321 | eval_dataloader = self.strategy.setup_dataloader(
322 | tmp,
323 | eval_bsz, # should larger than world size?
324 | True,
325 | True,
326 | drop_last=False
327 | )
328 | print(f'!!!! eval dataloader size', len(eval_dataloader), 'eval_bsz', eval_bsz)
329 | self.eval_dataloader = eval_dataloader
330 | if len(eval_data)==0 or len(eval_dataloader)==0: print('!!!! no eval data, eval_data should be larger than num_vllm * micro_bsz', len(eval_data), len(eval_dataloader))
331 | else: print(f'!!!! eval data {len(eval_data)} eval dataloader', len(eval_dataloader), args.micro_rollout_batch_size)
332 | info = self.eval_unit(args, 0, self.eval_step, eval_dataloader)
333 | eval_result = info['match']
334 | torch.distributed.barrier()
335 | result2, bench_results = self.get_eval_result_from_disk()
336 | print(f'!!!! [eval] finish with step {self.eval_step} rank {self.strategy.get_rank()} gathered eval stats', info, 'from disk:', result2)
337 |
338 | self.eval_step += 1
339 | # info['match_overall'] = result2
340 | for k,v in bench_results.items():
341 | info[f'match_{k}'] = v
342 | info['match_overall'] = result2
343 | eval_save = self.best<=result2 # and args.rollout_batch_size>16
344 | if eval_save:
345 | self.best = result2
346 | print(f"!!!! [eval] saving with average score {self.best}")
347 |
348 | del eval_dataloader
349 | self.replay_buffer.eval_items.clear()
350 |
351 |
--------------------------------------------------------------------------------
/openrlhf/trainer/ppo_utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .experience_maker import Experience, NaiveExperienceMaker, RemoteExperienceMaker
2 | from .kl_controller import AdaptiveKLController, FixedKLController
3 | from .replay_buffer import NaiveReplayBuffer
4 | from .data_processor import BaseDataProcessor, DATA_PROCESSOR_MAP
5 |
6 | __all__ = [
7 | "Experience",
8 | "NaiveExperienceMaker",
9 | "RemoteExperienceMaker",
10 | "AdaptiveKLController",
11 | "FixedKLController",
12 | "NaiveReplayBuffer",
13 | ]
14 |
--------------------------------------------------------------------------------
/openrlhf/trainer/ppo_utils/data_processor.py:
--------------------------------------------------------------------------------
1 | # /*
2 | # * Modified by Haozhe Wang in 2025
3 | # *
4 | # * Licensed under the Apache License, Version 2.0 (the "License");
5 | # */
6 | import json
7 | import os
8 | from abc import ABC, abstractmethod
9 | from typing import List, Optional, Union, Dict
10 |
11 | import torch
12 | from qwen_vl_utils import process_vision_info
13 | from transformers import Qwen2VLProcessor
14 | from transformers.processing_utils import ProcessorMixin
15 | try:
16 |     from transformers import Qwen2_5_VLProcessor
17 | except Exception:
18 |     Qwen2_5_VLProcessor = None  # not available in this transformers version
19 |
20 | # https://github.com/huggingface/transformers/blob/main/docs/source/en/model_doc/qwen2_5_vl.md
21 | class BaseDataProcessor(ABC):
22 | def __init__(self, processor: ProcessorMixin):
23 | super().__init__()
24 | self.processor = processor
25 |
26 | @abstractmethod
27 | def __call__(
28 | self,
29 | messages: Union[Dict, List[str], str],
30 | max_length: int,
31 | padding: bool = True,
32 | device: Optional[Union[str, torch.device]] = None,
33 | return_tensors: Optional[str] = "pt",
34 | add_special_tokens: Optional[bool] = False,
35 | truncation: Optional[bool] = True,
36 | ) -> Dict:
37 | raise NotImplementedError
38 |
39 | @abstractmethod
40 | def make_input_batch(self, inputs: List[Dict]) -> Dict:
41 | raise NotImplementedError
42 |
43 | @abstractmethod
44 | def split_input_batch(self, batch: Dict) -> List[Dict]:
45 | raise NotImplementedError
46 |
47 | def _format_messages(self, messages: Union[Dict, List[str], str]) -> List[Dict]:
48 | if isinstance(messages, list) and isinstance(messages[0], str):
49 | return [json.loads(m) for m in messages]
50 | elif isinstance(messages, str):
51 | return [json.loads(messages)]
52 | elif isinstance(messages, dict):
53 | return [messages]
54 | else:
55 | raise ValueError("Invalid messages format, must be a list of strings or a string or a dict")
56 |
57 | def apply_chat_template(
58 | self,
59 | messages: Union[Dict, List[str], str],
60 | tokenize: bool = False,
61 | add_generation_prompt: bool = True,
62 | ) -> List[str]:
63 | messages = self._format_messages(messages)
64 |
65 | return self.processor.apply_chat_template(
66 | messages, tokenize=tokenize, add_generation_prompt=add_generation_prompt
67 | )
68 |
69 | def get_images_from_messages(
70 | self, messages: Union[Dict, List[str], str]
71 | ) -> List[Dict]:
72 | messages = self._format_messages(messages)
73 | return self._get_images_from_messages(messages)
74 |
75 | @abstractmethod
76 | def _get_images_from_messages(self, messages: List[Dict]) -> List[Dict]:
77 | raise NotImplementedError
78 |
79 | @property
80 | def pad_token_id(self) -> int:
81 | return self.processor.tokenizer.pad_token_id
82 |
83 | @property
84 | def eos_token_id(self) -> int:
85 | return self.processor.tokenizer.eos_token_id
86 |
87 | @property
88 | def tokenizer(self):
89 | return self.processor.tokenizer
90 |
91 |
92 | def add_pixel_bounds(messages):
93 |     # Default pixel bounds for image inputs
94 | DEFAULT_MIN_PIXELS = int(os.getenv("MIN_PIXELS", 256 * 28 * 28))
95 | DEFAULT_MAX_PIXELS = int(os.getenv("MAX_PIXELS", 1280 * 28 * 28))
96 |
97 | def process_content(content):
98 | if isinstance(content, list):
99 | for item in content:
100 | if isinstance(item, dict) and item.get("type") == "image":
101 | if "min_pixels" not in item:
102 | item["min_pixels"] = DEFAULT_MIN_PIXELS
103 | if "max_pixels" not in item:
104 | item["max_pixels"] = DEFAULT_MAX_PIXELS
105 | return content
106 |
107 | for message in messages:
108 | for msg in message:
109 | msg["content"] = process_content(msg["content"])
110 | return messages
111 |
112 | def remove_except_last(text, tag):
113 | cnt = text.count(tag)
114 | if cnt>1:
115 | index = text.rfind(tag)
116 | return text[:index].replace(tag, "")+text[index:]
117 | else: return text
118 |
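# Usage sketch (illustrative): only the last occurrence of the tag is kept.
#
#   remove_except_last("a <tag> b <tag> c", "<tag>")
#   # -> "a  b <tag> c"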
119 | def find_rank_occurrence(ids, target, rank):
120 | """
121 | Finds the position (index) of the rank-th occurrence of the target in the list ids.
122 |
123 | Args:
124 | ids (list): List of integers to search through.
125 | target (int): Integer to find.
126 | rank (int): The occurrence number to locate (1-based).
127 |
128 | Returns:
129 | int: Index of the rank-th occurrence, or -1 if it doesn’t exist.
130 | """
131 | count = 0
132 | for i, val in enumerate(ids):
133 | if val == target:
134 | count += 1
135 | if count == rank:
136 | return i
137 | return -1
138 |
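# Usage sketch (illustrative):
#
#   find_rank_occurrence([5, 3, 5, 5], target=5, rank=2)  # -> 2
#   find_rank_occurrence([5, 3], target=9, rank=1)        # -> -1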
139 | class Qwen2VLDataProcessor(BaseDataProcessor):
140 | def __call__(
141 | self,
142 | messages,
143 | max_length,
144 | padding=True,
145 | device=None,
146 | return_tensors="pt",
147 | add_special_tokens=False,
148 | truncation=True,
149 | ) -> Dict:
150 |
151 | # messages = newlist
152 | messages = self._format_messages(messages) # list of dicts
153 | processor = self.processor
154 | # for entry in messages:
155 | # if entry['role'] == 'user':
156 | # content = entry['content'][-1]['text']
157 | # if "" in content:
158 | # content = content.replace("", "<|vision_start|><|image_pad|><|vision_end|>")
159 | # entry['content'][-1]['text'] = content
160 |
161 | texts = processor.apply_chat_template(
162 | messages, tokenize=False, add_generation_prompt=True
163 | )
164 | texts = self.handle_placeholders(texts)
165 | messages = add_pixel_bounds(messages)
166 | image_inputs, video_inputs = process_vision_info(messages)
167 | # print(texts)
168 |         max_length = 10240  # override to make sure the inputs are not truncated
169 | batch = processor(
170 | text=texts,
171 | images=image_inputs,
172 | videos=video_inputs,
173 | padding=padding,
174 | max_length=max_length,
175 | add_special_tokens=False,
176 | truncation=truncation,
177 | return_tensors=return_tensors,
178 | )
179 | if device:
180 | return {k: v.to(device) for k, v in batch.items()}
181 | return {k: v for k, v in batch.items()}
182 |
183 | def handle_placeholders(self, texts):
184 | newlist = []
185 |         placeholder = "<image>"  # assumed image placeholder token in the raw prompt text
186 | # placeholder2 = ""
187 | replacewith = "<|vision_start|><|image_pad|><|vision_end|>"
188 | for m in texts:
189 | new = m
190 | for k in ["<|vision_start|>","<|image_pad|>","<|vision_end|>"]:
191 | new = new.replace(k,"")
192 | # now new has no replacewith
193 | if new.count(placeholder)>0:
194 | new = new.replace(placeholder, replacewith)
195 | else:
196 | new = replacewith + new
197 | newlist.append(new)
198 | return newlist
199 |
200 | def make_input_batch(self, inputs: List[Dict]) -> Dict:
201 | # each element has no batch dimension
202 | batch = {k: None for k in inputs[0].keys()}
203 | for k in batch.keys():
204 | if k in ["input_ids", "attention_mask"]:
205 | batch[k] = torch.stack([inp[k] for inp in inputs], dim=0)
206 | elif k in ["pixel_values", "image_grid_thw"]:
207 | # qwen2vl concat all patches of all images in a batch in the first dimension
208 | batch[k] = torch.cat([inp[k] for inp in inputs], dim=0)
209 | else:
210 | raise ValueError(f"Unknown key {k} for Qwen2VLDataProcessor")
211 | return batch
212 |
213 | def split_input_batch(self, batch: Dict) -> List[Dict]:
214 | batch_size = len(batch["input_ids"])
215 | batch_kwargs = [{} for _ in range(batch_size)]
216 | # first process None values
217 | keys = []
218 | for k, v in batch.items():
219 | if v is not None:
220 | keys.append(k)
221 | else:
222 | for i in range(batch_size):
223 | batch_kwargs[i][k] = None
224 |
225 | if "pixel_values" in keys and (
226 | "input_ids" not in keys or "image_grid_thw" not in keys
227 | ):
228 | raise ValueError(
229 | "Cannot split batch with pixel_values without input_ids and image_grid_thw"
230 | )
231 | if "image_grid_thw" in keys and ("input_ids" not in keys):
232 | raise ValueError("Cannot split batch with image_grid_thw without input_ids")
233 | for k in ["input_ids", "attention_mask"]:
234 | if k in keys:
235 | vals = batch[k]
236 | if isinstance(vals, torch.Tensor):
237 | vals = torch.unbind(vals)
238 | assert batch_size == len(vals)
239 | for i, v in enumerate(vals):
240 | batch_kwargs[i][k] = v
241 | if "pixel_values" in keys:
242 | thws = batch["image_grid_thw"] # (total_img_num, (t,h,w))
243 | pixel_values = batch["pixel_values"]
244 | vision_start_id = self.processor.tokenizer("<|vision_start|>")["input_ids"][0]
245 | vision_end_id = self.processor.tokenizer("<|vision_end|>")["input_ids"][0]
246 | img_idx = 0
247 | patch_idx = 0
248 | for i in range(batch_size):
249 | input_ids_i = batch_kwargs[i]["input_ids"]
250 | if not isinstance(input_ids_i, torch.Tensor):
251 | input_ids_i = torch.tensor(input_ids_i)
252 | vision_start_num = (input_ids_i == vision_start_id).sum().item()
253 | vision_end_num = (input_ids_i == vision_end_id).sum().item()
254 |
255 | img_num = vision_end_num
256 | if img_num == 0:
257 | batch_kwargs[i]["pixel_values"] = None
258 | batch_kwargs[i]["image_grid_thw"] = None
259 | continue
260 | thws_i = thws[img_idx:img_num+img_idx]
261 | img_idx += img_num
262 | flag = False
263 |             if len(thws_i) != img_num:
264 |                 thws_i = thws[-img_num:]
265 |                 print(f'[warning] image_grid_thw count does not match the parsed image count (likely polluted data); falling back to the last entries: {len(thws_i)} vs {img_num}')
266 |                 flag = True
267 |             # thws = thws[img_num:]
268 | if not isinstance(thws_i, torch.Tensor):
269 | thws_i = torch.stack(thws_i)
270 | batch_kwargs[i]["image_grid_thw"] = thws_i
271 | patchs_num = thws_i.prod(dim=1).sum().item()
272 | pixel_values_i = pixel_values[patch_idx:patchs_num+patch_idx]
273 | if len(pixel_values_i) != patchs_num:
274 | pixel_values_i = pixel_values[-patchs_num:]
275 |                 print(f'[warning] pixel_values slice does not match the expected patch count (likely polluted data); falling back to the last entries: expected {patchs_num} of {len(pixel_values)}, got {len(pixel_values_i)}')
276 | flag = True
277 | # assert len(pixel_values_i) == patchs_num
278 | # pixel_values = pixel_values[patch_idx:patchs_num+patch_idx]
279 | batch_kwargs[i]["pixel_values"] = pixel_values_i
280 | if flag:
281 | batch_kwargs[i] = None
282 |                 print('[truncation warning] a sample appears to have mismatched vision_start and vision_end tokens, likely due to garbage outputs; its current length is', len(input_ids_i))
283 | # print(input_ids_i.detach().cpu().numpy().tolist())
284 | error_index = find_rank_occurrence(input_ids_i.detach().cpu().numpy().tolist(), vision_start_id, 1)
285 | input_ids_i[error_index:] = self.eos_token_id # how about directly before the vision start?
286 | continue
287 | # assert len(thws) == 0
288 | # assert len(pixel_values) == 0
289 | return batch_kwargs
290 |
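    # Bookkeeping sketch (illustrative numbers): qwen2-vl concatenates all image patches
    # of a batch along dim 0 of pixel_values, so a sample whose image_grid_thw rows are
    # [[1, 4, 4], [1, 2, 2]] owns 1*4*4 + 1*2*2 = 20 consecutive rows, which is exactly
    # the thws_i.prod(dim=1).sum() used above.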
291 | def _get_images_from_messages(self, messages: List[Dict]) -> List[Dict]:
292 | messages = add_pixel_bounds(messages)
293 | image_inputs, _ = process_vision_info(messages)
294 | return image_inputs
295 |
296 |
297 | DATA_PROCESSOR_MAP = {
298 |     Qwen2VLProcessor: Qwen2VLDataProcessor,
299 |     **({Qwen2_5_VLProcessor: Qwen2VLDataProcessor} if Qwen2_5_VLProcessor is not None else {}),
300 | }
301 |
--------------------------------------------------------------------------------
/openrlhf/trainer/ppo_utils/kl_controller.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class AdaptiveKLController:
5 | """
6 | Adaptive KL controller described in the paper:
7 | https://arxiv.org/pdf/1909.08593.pdf
8 | """
9 |
10 | def __init__(self, init_kl_coef, target, horizon):
11 | self.value = init_kl_coef
12 | self.target = target
13 | self.horizon = horizon
14 |
15 | def update(self, current, n_steps):
16 | target = self.target
17 | proportional_error = np.clip(current / target - 1, -0.2, 0.2)
18 | mult = 1 + proportional_error * n_steps / self.horizon
19 | self.value *= mult
20 |
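# Numeric sketch (illustrative values): with target=0.1 and horizon=10000, a measured
# KL of 0.2 over n_steps=100 gives proportional_error = clip(1.0, -0.2, 0.2) = 0.2 and
# mult = 1 + 0.2 * 100 / 10000 = 1.002, so the coefficient grows by 0.2%.
#
#   ctl = AdaptiveKLController(init_kl_coef=0.01, target=0.1, horizon=10000)
#   ctl.update(current=0.2, n_steps=100)
#   # ctl.value -> 0.01002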
21 |
22 | class FixedKLController:
23 | """Fixed KL controller."""
24 |
25 | def __init__(self, kl_coef):
26 | self.value = kl_coef
27 |
28 | def update(self, current, n_steps):
29 | pass
30 |
--------------------------------------------------------------------------------
/openrlhf/trainer/ray/__init__.py:
--------------------------------------------------------------------------------
1 | from .launcher import DistributedTorchRayActor, PPORayActorGroup, ReferenceModelRayActor, RewardModelRayActor
2 | from .ppo_actor import ActorModelRayActor
3 | from .ppo_critic import CriticModelRayActor
4 | from .vllm_engine import create_vllm_engines
5 | from .evaluator2 import Evaluator2
6 |
7 | __all__ = [
8 | "DistributedTorchRayActor",
9 | "PPORayActorGroup",
10 | "ReferenceModelRayActor",
11 | "RewardModelRayActor",
12 | "ActorModelRayActor",
13 | "CriticModelRayActor",
14 | "create_vllm_engines",
15 | "Evaluator2"
16 | ]
17 |
--------------------------------------------------------------------------------
/openrlhf/trainer/ray/launcher.py:
--------------------------------------------------------------------------------
1 | # /*
2 | # * Modified by Haozhe Wang in 2025
3 | # *
4 | # * Licensed under the Apache License, Version 2.0 (the "License");
5 | # */
6 | import logging
7 | import os
8 | import socket
9 | from typing import Callable, Dict, List, Optional, Type
10 |
11 | import ray
12 | import torch
13 | from ray.util.placement_group import PlacementGroup, placement_group
14 | from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
15 |
16 | from openrlhf.models import Actor, get_llm_for_sequence_regression
17 | from openrlhf.trainer.ray.utils import ray_noset_visible_devices
18 | from openrlhf.utils.deepspeed import DeepspeedStrategy
19 |
20 |
21 | class DistributedTorchRayActor:
22 | def __init__(self, world_size, rank, master_addr, master_port):
23 | logging.basicConfig(
24 | format="%(asctime)s %(levelname)-8s %(message)s",
25 | level=logging.INFO,
26 | datefmt="%Y-%m-%d %H:%M:%S",
27 | )
28 | self._world_size = world_size
29 | self._rank = rank
30 | self._master_addr = master_addr if master_addr else self._get_current_node_ip()
31 | self._master_port = master_port if master_port else self._get_free_port()
32 | os.environ["MASTER_ADDR"] = self._master_addr
33 | os.environ["MASTER_PORT"] = str(self._master_port)
34 | os.environ["WORLD_SIZE"] = str(self._world_size)
35 | os.environ["RANK"] = str(self._rank)
36 | # NOTE: Ray will automatically set the *_VISIBLE_DEVICES
37 | # environment variable for each actor, unless
38 | # RAY_EXPERIMENTAL_NOSET_*_VISIBLE_DEVICES is set, so
39 | # set local rank to 0 when the flag is not applicable.
40 | os.environ["LOCAL_RANK"] = str(ray.get_gpu_ids()[0]) if ray_noset_visible_devices() else "0"
41 |
42 | @staticmethod
43 | def _get_current_node_ip():
44 | address = ray._private.services.get_node_ip_address()
45 | # strip ipv6 address
46 | return address.strip("[]")
47 |
48 | @staticmethod
49 | def _get_free_port():
50 | with socket.socket() as sock:
51 | sock.bind(("", 0))
52 | return sock.getsockname()[1]
53 |
54 | def get_master_addr_port(self):
55 | return self._master_addr, self._master_port
56 |
57 |
58 | class BasePPORole(DistributedTorchRayActor):
59 | def _setup_distributed(self, strategy: DeepspeedStrategy):
60 | # configure strategy
61 | self.strategy = strategy
62 | strategy.setup_distributed()
63 |
64 | def init_model_from_pretrained(self, *args, **kwargs):
65 | raise NotImplementedError()
66 |
67 |
68 | @ray.remote(num_gpus=1)
69 | class ReferenceModelRayActor(BasePPORole):
70 | def init_model_from_pretrained(self, strategy: DeepspeedStrategy, pretrain):
71 | self._setup_distributed(strategy)
72 | model = Actor(
73 | pretrain,
74 | use_flash_attention_2=strategy.args.flash_attn,
75 | bf16=strategy.args.bf16,
76 | load_in_4bit=strategy.args.load_in_4bit,
77 | ds_config=strategy.get_ds_eval_config(offload=strategy.args.ref_reward_offload),
78 | packing_samples=strategy.args.packing_samples,
79 | )
80 | strategy.print(model)
81 |
82 | if strategy.args.ref_reward_offload:
83 | model._offload = True
84 |
85 | self.model = self.strategy.prepare(model, is_rlhf=True)
86 | self.model.eval()
87 |
88 | def forward(
89 | self,
90 | sequences: torch.LongTensor,
91 | num_actions: int = None,
92 | attention_mask: Optional[torch.Tensor] = None,
93 | return_output=False,
94 | packed_seq_lens: Optional[list[int]] = None,
95 | visual_inputs: Optional[dict] = None,
96 | ) -> torch.Tensor:
97 | if visual_inputs is None:
98 | visual_inputs = {}
99 | device = torch.cuda.current_device()
100 | with torch.no_grad():
101 | visual_inputs = {k:v.to(device) for k,v in visual_inputs.items()}
102 | log_probs = self.model(
103 | sequences.to(device),
104 | num_actions,
105 | attention_mask.to(device),
106 | return_output=return_output,
107 | packed_seq_lens=packed_seq_lens,
108 | visual_inputs=visual_inputs,
109 | )
110 | return log_probs.to("cpu")
111 |
112 | def empty_cache(self) -> None:
113 | torch.cuda.empty_cache()
114 |
115 |
116 | @ray.remote(num_gpus=1)
117 | class RewardModelRayActor(BasePPORole):
118 | def init_model_from_pretrained(self, strategy: DeepspeedStrategy, pretrain):
119 | self._setup_distributed(strategy)
120 | model = get_llm_for_sequence_regression(
121 | pretrain,
122 | "reward",
123 | normalize_reward=strategy.args.normalize_reward,
124 | use_flash_attention_2=strategy.args.flash_attn,
125 | bf16=strategy.args.bf16,
126 | load_in_4bit=strategy.args.load_in_4bit,
127 | ds_config=strategy.get_ds_eval_config(offload=strategy.args.ref_reward_offload),
128 | value_head_prefix=strategy.args.value_head_prefix,
129 | packing_samples=strategy.args.packing_samples,
130 | )
131 | strategy.print(model)
132 | strategy.print("reward normalization status: {}".format(strategy.args.normalize_reward))
133 | strategy.print("mean: {}, std {}".format(model.mean, model.std))
134 |
135 | if strategy.args.ref_reward_offload:
136 | model._offload = True
137 |
138 | self.model = self.strategy.prepare(model, is_rlhf=True)
139 | self.model.eval()
140 |
141 | def forward(
142 | self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None, packed_seq_lens=None, visual_inputs: Optional[dict] = None,
143 | ) -> torch.Tensor:
144 | device = torch.cuda.current_device()
145 | if visual_inputs is None:
146 | visual_inputs = {}
147 | visual_inputs = {k:v.to(device) for k,v in visual_inputs.items()}
148 | with torch.no_grad():
149 | reward = self.model(sequences.to(device), attention_mask.to(device), packed_seq_lens=packed_seq_lens, visual_inputs=visual_inputs)
150 | return reward.to("cpu")
151 |
152 | def empty_cache(self) -> None:
153 | torch.cuda.empty_cache()
154 |
155 |
156 | class PPORayActorGroup:
157 | """
158 |     A group of Ray actors.
159 |     Functions whose names start with 'async' return a list of remote object refs.
160 |
161 | Args:
162 | num_nodes (int): Number of nodes for this actor group.
163 | num_gpus_per_node (int): Number of gpus for this actor group.
164 |         ray_actor_type (Type[BasePPORole]): PPO model type that this actor group serves.
165 | pg (PlacementGroup, optional): Placement group to schedule actor on.
166 | If none, create new placement group automatically. Defaults to None.
167 | num_gpus_per_actor (float, optional): Number of gpus allocated for each actor.
168 | If < 1.0, multiple models can share same gpu. Defaults to 1.
169 | """
170 |
171 | def __init__(
172 | self,
173 | num_nodes,
174 | num_gpus_per_node,
175 | ray_actor_type: Type[BasePPORole],
176 | pg: PlacementGroup = None,
177 | num_gpus_per_actor=1,
178 | resources: Dict[str, float] = None,
179 | num_resources_per_node: int = None,
180 | ) -> None:
181 | self._num_nodes = num_nodes
182 | self._num_gpus_per_node = num_gpus_per_node
183 | self.ray_actor_type = ray_actor_type
184 |
185 | # custom resources, see https://docs.ray.io/en/latest/ray-core/scheduling/resources.html
186 | self._resources = resources
187 | self._num_resources_per_node = num_resources_per_node
188 |
189 | self._initiate_actors(pg, num_gpus_per_actor)
190 |
191 | def _initiate_actors(self, pg, num_gpus_per_actor):
192 | world_size = self._num_nodes * self._num_gpus_per_node
193 | print(f'!!!! [config] worldsize={world_size}, num_nodes={self._num_nodes}, num_gpus_per_node={self._num_gpus_per_node}, placementgroup={pg}')
194 | # Use placement group to lock resources for models of same type
195 | if self._num_gpus_per_node > 1 and pg is None:
196 | bundles = [{"GPU": 1, "CPU": 1} for _ in range(self._num_nodes * self._num_gpus_per_node)]
197 | if self._resources:
198 | resources_name = list(self._resources.keys())[0]
199 | for i in range(len(bundles)):
200 | bundles[i][resources_name] = self._num_resources_per_node
201 |
202 | pg = placement_group(bundles, strategy="PACK")
203 | ray.get(pg.ready())
204 | if pg:
205 | print(f'!!!! [config] worldsize={world_size}, num_nodes={self._num_nodes}, num_gpus_per_node={self._num_gpus_per_node}, placementgroup={pg}, num_gpus_per_actor={num_gpus_per_actor}')
206 | master_actor = self.ray_actor_type.options(
207 | num_cpus=num_gpus_per_actor,
208 | num_gpus=num_gpus_per_actor,
209 | resources=self._resources,
210 | scheduling_strategy=PlacementGroupSchedulingStrategy(
211 | placement_group=pg, placement_group_bundle_index=0
212 | ),
213 | ).remote(world_size, 0, None, None)
214 | else:
215 | print(f'!!!! [config] worldsize={world_size}, num_nodes={self._num_nodes}, num_gpus_per_node={self._num_gpus_per_node}, placementgroup={pg}, num_gpus_per_actor={num_gpus_per_actor}')
216 | master_actor = self.ray_actor_type.options(
217 | num_cpus=num_gpus_per_actor,
218 | num_gpus=num_gpus_per_actor,
219 | resources=self._resources,
220 | ).remote(world_size, 0, None, None)
221 | self._actor_handlers = [master_actor]
222 |
223 | # Create worker actors
224 | if world_size > 1:
225 | master_addr, master_port = ray.get(master_actor.get_master_addr_port.remote())
226 | for rank in range(1, world_size):
227 | if pg:
228 | worker_actor = self.ray_actor_type.options(
229 | num_cpus=num_gpus_per_actor,
230 | num_gpus=num_gpus_per_actor,
231 | resources=self._resources,
232 | scheduling_strategy=PlacementGroupSchedulingStrategy(
233 | placement_group=pg,
234 | placement_group_bundle_index=rank,
235 | ),
236 | ).remote(world_size, rank, master_addr, master_port)
237 | else:
238 | worker_actor = self.ray_actor_type.options(
239 | num_cpus=num_gpus_per_actor,
240 | num_gpus=num_gpus_per_actor,
241 | resources=self._resources,
242 | ).remote(world_size, rank, master_addr, master_port)
243 | self._actor_handlers.append(worker_actor)
244 |
245 | def async_init_model_from_pretrained(
246 | self,
247 | *args,
248 | **kwargs,
249 | ):
250 | """Init model from pretrained checkpoint.
251 |
252 | Returns:
253 | List: list of remote object refs.
254 | """
255 | return [actor.init_model_from_pretrained.remote(*args, **kwargs) for actor in self._actor_handlers]
256 |
257 | def async_fit_actor_model(
258 | self,
259 | critic_model_group: "PPORayActorGroup",
260 | initial_model_group: "PPORayActorGroup",
261 | reward_model_groups: List["PPORayActorGroup"],
262 | remote_rm_urls: List[str] = None,
263 | reward_fn: Callable[[List[torch.Tensor]], torch.Tensor] = None,
264 | vllm_engines: List = None,
265 | ):
266 | """Train actor model.
267 |
268 | Args:
269 | critic_model_group (PPORayActorGroup): critic model group.
270 | initial_model_group (PPORayActorGroup): reference model group.
271 | reward_model_groups (PPORayActorGroup): reward model groups.
272 | remote_rm_urls: remote RM APIs.
273 | reward_fn: reward calculate function, must be specified if using multiple reward models.
274 | vllm_engines: vllm engines for text generation, if not specified, generate text by actor model directly.
275 |
276 | Returns:
277 | List: list of remote object refs.
278 | """
279 | assert (
280 | (remote_rm_urls and len(remote_rm_urls) == 1)
281 | or (reward_model_groups and len(reward_model_groups) == 1)
282 | or reward_fn is not None
283 | ), "reward_fn must be specified if using multiple reward models"
284 |
285 | critic_actors = critic_model_group._actor_handlers if critic_model_group else None
286 | initial_actors = initial_model_group._actor_handlers if initial_model_group else None
287 |
288 | refs = []
289 | # TODO(wuxibin): actor model choose critic/reward/initial model in a
290 | # round robin fashion, implement more efficient dispatching strategy.
291 | for i, actor in enumerate(self._actor_handlers):
292 | critic_actor = critic_actors[i % len(critic_actors)] if critic_actors else None
293 | initial_actor = initial_actors[i % len(initial_actors)] if initial_actors else None
294 |
295 | reward_actors = []
296 | if reward_model_groups:
297 | for reward_model_group in reward_model_groups:
298 | actors = reward_model_group._actor_handlers
299 | reward_actors.append(actors[i % len(actors)])
300 |
301 | refs.append(
302 | actor.fit.remote(
303 | critic_model=critic_actor,
304 | initial_model=initial_actor,
305 | reward_model=reward_actors,
306 | remote_rm_url=remote_rm_urls,
307 | reward_fn=reward_fn,
308 | vllm_engines=vllm_engines,
309 | # whether this actor should triger corresponding critic model training
310 | critic_train_remote=(i < len(critic_actors)) if critic_actor else None,
311 | )
312 | )
313 |
314 | return refs
315 |
316 |
317 | def async_save_model(self):
318 | """Save actor model on rank 0.
319 |
320 | Returns:
321 | List: list of remote object refs.
322 | """
323 | return [actor.save_model.remote() for actor in self._actor_handlers]
324 |
325 | def async_run_method(self, method_name, *args, **kwargs):
326 | refs = []
327 | for actor in self._actor_handlers:
328 | method = getattr(actor, method_name)
329 | refs.append(method.remote(*args, **kwargs))
330 | return refs
331 |
--------------------------------------------------------------------------------
/openrlhf/trainer/ray/ppo_critic.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os
3 | from typing import Dict, Optional, Union
4 |
5 | import ray
6 | import torch
7 | from torch.utils.data import DataLoader
8 | from tqdm import tqdm
9 | from transformers.trainer import get_scheduler
10 |
11 | from openrlhf.models import get_llm_for_sequence_regression
12 | from openrlhf.trainer import PPOTrainer
13 | from openrlhf.trainer.ppo_utils import Experience
14 | from openrlhf.utils import get_tokenizer, get_vl_processor
15 | from openrlhf.utils.deepspeed import DeepspeedStrategy
16 |
17 | from .launcher import BasePPORole
18 |
19 |
20 | class CriticPPOTrainer(PPOTrainer):
21 | def ppo_train(self):
22 |         # the replay buffer may be empty at first, so rebuild the dataloader for each training round
23 | dataloader = DataLoader(
24 | self.replay_buffer,
25 | batch_size=self.replay_buffer.sample_batch_size,
26 | shuffle=True,
27 | drop_last=True,
28 | pin_memory=self.dataloader_pin_memory,
29 | collate_fn=self.replay_buffer.collate_fn,
30 | )
31 | device = torch.cuda.current_device()
32 |
33 | status_list = []
34 | status_mean = {}
35 | for epoch in range(self.max_epochs):
36 | pbar = tqdm(
37 | dataloader,
38 | desc=f"Train epoch [{epoch + 1}/{self.max_epochs}]",
39 | disable=not self.strategy.is_rank_0(),
40 | )
41 | for experience in pbar:
42 | experience.to_device(device)
43 | status = self.training_step(experience)
44 |
45 | # for DP
46 | status = self.strategy.all_reduce(status)
47 |
48 | status_list.append(status)
49 | pbar.set_postfix(status)
50 |
51 | if status_list:
52 | status_mean = status_list[0]
53 | for m in status_list[1:]:
54 | for k, v in m.items():
55 | status_mean[k] += v
56 | for k in status_mean.keys():
57 | status_mean[k] /= len(status_list)
58 | return status_mean
59 |
60 | def training_step(self, experience: Experience) -> Dict[str, float]:
61 | return self.training_step_critic(experience)
62 |
63 |
64 | @ray.remote(num_gpus=1)
65 | class CriticModelRayActor(BasePPORole):
66 | def init_model_from_pretrained(self, strategy: DeepspeedStrategy, pretrain, max_steps):
67 | args = strategy.args
68 |
69 | self._setup_distributed(strategy)
70 | critic = get_llm_for_sequence_regression(
71 | pretrain,
72 | "critic",
73 | normalize_reward=strategy.args.normalize_reward,
74 | use_flash_attention_2=strategy.args.flash_attn,
75 | bf16=strategy.args.bf16,
76 | load_in_4bit=strategy.args.load_in_4bit,
77 | lora_rank=strategy.args.lora_rank,
78 | lora_alpha=strategy.args.lora_alpha,
79 | target_modules=strategy.args.target_modules,
80 | lora_dropout=strategy.args.lora_dropout,
81 | ds_config=strategy.get_ds_train_config(is_actor=False),
82 | value_head_prefix=strategy.args.value_head_prefix,
83 | init_value_head=strategy.args.pretrain == strategy.args.critic_pretrain,
84 | packing_samples=strategy.args.packing_samples,
85 | )
86 | strategy.print(critic)
87 | strategy.print("reward normalization status: {}".format(strategy.args.normalize_reward))
88 | strategy.print("mean: {}, std {}".format(critic.mean, critic.std))
89 |
90 | # configure optimizer
91 | critic_optim = strategy.create_optimizer(
92 | critic, lr=args.critic_learning_rate, betas=args.adam_betas, weight_decay=args.l2
93 | )
94 |
95 | # configure scheduler
96 | critic_scheduler = get_scheduler(
97 | "cosine_with_min_lr",
98 | critic_optim,
99 | num_warmup_steps=math.ceil(max_steps * args.lr_warmup_ratio),
100 | num_training_steps=max_steps,
101 | scheduler_specific_kwargs={"min_lr": args.critic_learning_rate * 0.1},
102 | )
103 |
104 | if args.gradient_checkpointing:
105 | critic.gradient_checkpointing_enable(
106 | gradient_checkpointing_kwargs={"use_reentrant": args.gradient_checkpointing_use_reentrant}
107 | )
108 |
109 | # prepare models/optimizers...
110 | self.critic, self.critic_optim, self.critic_scheduler = strategy.prepare(
111 | (critic, critic_optim, critic_scheduler),
112 | is_rlhf=True,
113 | )
114 |
115 | # load checkpoint
116 | if args.load_checkpoint and os.path.exists(os.path.join(args.ckpt_path, "_actor")):
117 | ckpt_path = os.path.join(args.ckpt_path, "_critic")
118 | strategy.load_ckpt(self.critic, ckpt_path)
119 | strategy.print(f"Loaded the checkpoint: {ckpt_path}")
120 |
121 | # configure Trainer
122 | # only use wandb at actor model
123 | strategy.args.use_wandb = False
124 | # configure tokenizer
125 | args = strategy.args
126 | if args.train_vlm:
127 | self.processor = get_vl_processor(
128 | pretrain, self.critic, "left", strategy, use_fast=not strategy.args.disable_fast_tokenizer
129 | )
130 | self.tokenizer = self.processor.tokenizer
131 | else:
132 | self.processor = None
133 | self.tokenizer = get_tokenizer(
134 | pretrain, self.critic, "left", strategy, use_fast=not strategy.args.disable_fast_tokenizer
135 | )
136 | self.trainer = CriticPPOTrainer(
137 | strategy,
138 | actor=None,
139 | critic=self.critic,
140 | reward_model=None,
141 | initial_model=None,
142 | ema_model=None,
143 | actor_optim=None,
144 | critic_optim=self.critic_optim,
145 | actor_scheduler=None,
146 | critic_scheduler=self.critic_scheduler,
147 | max_epochs=args.max_epochs,
148 | micro_train_batch_size=args.micro_train_batch_size,
149 | micro_rollout_batch_size=args.micro_rollout_batch_size,
150 | gradient_checkpointing=args.gradient_checkpointing,
151 | prompt_max_len=args.prompt_max_len,
152 | value_clip=args.value_clip,
153 | eps_clip=args.eps_clip,
154 | processor=self.processor,
155 | tokenizer=self.tokenizer
156 | )
157 |
158 | def forward(
159 | self,
160 | sequences: torch.LongTensor,
161 | num_actions: Optional[Union[int, list[int]]] = None,
162 | attention_mask: Optional[torch.Tensor] = None,
163 | packed_seq_lens=None,
164 | visual_inputs=None,
165 | ) -> torch.Tensor:
166 | """Generates critic values."""
167 | device = torch.cuda.current_device()
168 | self.critic.eval()
169 | if visual_inputs is None:
170 | visual_inputs = {}
171 | with torch.no_grad():
172 | visual_inputs = {k: v.to(device) for k, v in visual_inputs.items()}
173 | value = self.critic(
174 | sequences.to(device), num_actions, attention_mask.to(device), packed_seq_lens=packed_seq_lens, visual_inputs=visual_inputs
175 | )
176 | self.critic.train() # reset model state
177 | return value.to("cpu")
178 |
179 | def append(self, experience):
180 | """Append experience to replay buffer."""
181 | self.trainer.replay_buffer.append(experience)
182 |
183 | def fit(self):
184 | """Train critic model with the replay buffer."""
185 | torch.cuda.empty_cache()
186 | self.critic.train()
187 | status = self.trainer.ppo_train()
188 | self.trainer.replay_buffer.clear()
189 | torch.cuda.empty_cache()
190 | return status
191 |
192 | def empty_cache(self) -> None:
193 | torch.cuda.empty_cache()
194 |
195 | def save_model(self):
196 | args = self.strategy.args
197 |
198 | # save model checkpoint after fitting on only rank0
199 | if args.train_vlm:
200 | self.strategy.save_model(
201 | self.critic,
202 | self.processor,
203 | args.save_path + "_critic",
204 | )
205 | else:
206 | self.strategy.save_model(
207 | self.critic,
208 | self.tokenizer,
209 | args.save_path + "_critic",
210 | )
211 |
212 | def save_checkpoint(self, tag):
213 | args = self.strategy.args
214 | self.strategy.save_ckpt(
215 | self.critic, os.path.join(args.ckpt_path, "_critic"), tag, args.max_ckpt_num, args.max_ckpt_mem
216 | )
--------------------------------------------------------------------------------
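Usage sketch (illustrative, not part of the repository) of how the critic Ray actor defined in ppo_critic.py above is driven during PPO: forward() produces value estimates for a rollout, append() buffers experiences, fit() runs one critic update over the replay buffer, and save_model() writes the final weights. The handles and tensors below (critic_actor, sequences, num_actions, attention_mask, experiences) are placeholders produced elsewhere in the training loop.

import ray

# `critic_actor` is assumed to be a Ray handle to the critic actor above.
values = ray.get(critic_actor.forward.remote(sequences, num_actions, attention_mask))

# experiences come from the experience maker during rollout
for experience in experiences:
    critic_actor.append.remote(experience)

status = ray.get(critic_actor.fit.remote())  # aggregated critic-training metrics
critic_actor.save_model.remote()             # writes to args.save_path + "_critic"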
/openrlhf/trainer/ray/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | def ray_noset_visible_devices(env_vars=os.environ):
5 | # Refer to
6 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/nvidia_gpu.py#L95-L96
7 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/amd_gpu.py#L102-L103
8 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/npu.py#L94-L95
9 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/hpu.py#L116-L117
10 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/neuron.py#L108-L109
11 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/tpu.py#L171-L172
12 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/intel_gpu.py#L97-L98
13 | NOSET_VISIBLE_DEVICES_ENV_VARS_LIST = [
14 | "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES",
15 | "RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES",
16 | "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES",
17 | "RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES",
18 | "RAY_EXPERIMENTAL_NOSET_NEURON_RT_VISIBLE_CORES",
19 | "RAY_EXPERIMENTAL_NOSET_TPU_VISIBLE_CHIPS",
20 | "RAY_EXPERIMENTAL_NOSET_ONEAPI_DEVICE_SELECTOR",
21 | ]
22 | return any(env_vars.get(env_var) for env_var in NOSET_VISIBLE_DEVICES_ENV_VARS_LIST)
23 |
24 |
25 | def get_physical_gpu_id():
26 | import torch
27 |
28 | device = torch.cuda.current_device()
29 | props = torch.cuda.get_device_properties(device)
30 | return str(props.uuid)
31 |
--------------------------------------------------------------------------------
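Usage sketch (illustrative) of how these two helpers are typically combined inside a Ray GPU worker: when one of the RAY_EXPERIMENTAL_NOSET_* flags is set, Ray leaves *_VISIBLE_DEVICES untouched, so the process has to select its allocated GPU from ray.get_gpu_ids() itself; get_physical_gpu_id() then serves as a stable key (the GPU UUID) for the CUDA-IPC weight updates in vllm_worker_wrap.py below.

import ray
import torch

from openrlhf.trainer.ray.utils import get_physical_gpu_id, ray_noset_visible_devices

# Inside a Ray actor/task that was allocated one GPU:
if ray_noset_visible_devices():
    # Ray did not narrow CUDA_VISIBLE_DEVICES for this worker, so pick the device explicitly.
    torch.cuda.set_device(int(ray.get_gpu_ids()[0]))

print(get_physical_gpu_id())  # e.g. the UUID used to index ipc_handles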
/openrlhf/trainer/ray/vllm_engine.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import ray
5 | from ray.util.placement_group import placement_group
6 | from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
7 | from vllm import LLM
8 |
9 | from openrlhf.utils.logging_utils import init_logger
10 |
11 | logger = init_logger(__name__)
12 |
13 |
14 | @ray.remote
15 | def get_all_env_variables():
16 | import os
17 |
18 | return os.environ
19 |
20 |
21 | @ray.remote
22 | class LLMRayActor:
23 |
24 | def __init__(self, *args, bundle_indices: list = None, **kwargs):
25 | if kwargs.get("distributed_executor_backend") == "ray":
26 | # a hack to make the script work.
27 | # stop ray from manipulating CUDA_VISIBLE_DEVICES
28 | # at the top-level when the distributed_executor_backend is ray.
29 | os.environ.pop("CUDA_VISIBLE_DEVICES", None)
30 | # every worker will use 0.2 GPU, so that we can schedule
31 | # 2 instances on the same GPUs.
32 | if bundle_indices is not None:
33 | os.environ["VLLM_RAY_PER_WORKER_GPUS"] = "0.2"
34 | os.environ["VLLM_RAY_BUNDLE_INDICES"] = ",".join(map(str, bundle_indices))
35 | print(f"creating LLM with bundle_indices={bundle_indices}")
36 |
37 | # Number of actors that will send prompt to this engine
38 | self.num_actors = kwargs.pop("num_actors")
39 | self.actor_counter = 0
40 | self.requests = {}
41 | self.responses = {}
42 |
43 | self.llm = LLM(*args, **kwargs)
44 |
45 | def init_process_group(self, master_address, master_port, rank_offset, world_size, group_name, backend, use_ray):
46 | return self.llm.collective_rpc(
47 | "init_process_group",
48 | args=(master_address, master_port, rank_offset, world_size, group_name, backend, use_ray),
49 | )
50 |
51 | def update_weight(self, name, dtype, shape, empty_cache=False):
52 | return self.llm.collective_rpc("update_weight", args=(name, dtype, shape, empty_cache))
53 |
54 | def update_weight_cuda_ipc(self, name, dtype, shape, ipc_handles, empty_cache=False):
55 | return self.llm.collective_rpc("update_weight_cuda_ipc", args=(name, dtype, shape, ipc_handles, empty_cache))
56 |
57 | def reset_prefix_cache(self):
58 | self.llm.llm_engine.reset_prefix_cache()
59 |
60 | def sleep(self, level=1):
61 | self.llm.sleep(level=level)
62 |
63 | def wake_up(self):
64 | self.llm.wake_up()
65 |
66 | def add_requests(self, actor_rank, *, sampling_params, prompt_token_ids):
67 | """
68 | Save the requests from actors and generate responses when all actors have sent their requests
69 | """
70 | self.requests[actor_rank] = prompt_token_ids
71 | self.actor_counter += 1
72 | if self.actor_counter == self.num_actors:
73 | assert len(self.requests) == self.num_actors
74 | num_requests = []
75 | requests = []
76 | for actor_rank, request in self.requests.items():
77 | num_requests.append((actor_rank, len(request)))
78 | requests.extend(request)
79 |
80 | if len(requests) > 0:
81 | # For now we assume that all requests have the same sampling params
82 | responses = self.llm.generate(sampling_params=sampling_params, prompt_token_ids=requests)
83 | else:
84 | responses = []
85 |
86 | offset = 0
87 | self.responses = {}
88 | for actor_rank, num in num_requests:
89 | self.responses[actor_rank] = responses[offset : offset + num]
90 | offset += num
91 |
92 | self.actor_counter = 0
93 | self.requests = {}
94 |
95 | def add_requests_vlm(self, actor_rank, *, sampling_params, vllm_vision_input):
96 | """
97 | Save the requests from actors and generate responses when all actors have sent their requests
98 | """
99 | self.requests[actor_rank] = vllm_vision_input
100 | self.actor_counter += 1
101 | if self.actor_counter == self.num_actors:
102 | assert len(self.requests) == self.num_actors, f"{len(self.requests)} != {self.num_actors}"
103 | num_requests = []
104 | requests = []
105 | for actor_rank, request in self.requests.items():
106 | num_requests.append((actor_rank, len(request)))
107 | requests.extend(request)
108 |
109 | if len(requests) > 0:
110 | # For now we assume that all requests have the same sampling params
111 | responses = self.llm.generate(requests, sampling_params=sampling_params)
112 | else:
113 | responses = []
114 |
115 | offset = 0
116 | self.responses = {}
117 | for actor_rank, num in num_requests:
118 | self.responses[actor_rank] = responses[offset : offset + num]
119 | offset += num
120 |
121 | self.actor_counter = 0
122 | self.requests = {}
123 |
124 | def add_requests_vlm_mix(self, actor_rank, *, sampling_params, vllm_vision_input):
125 | """
126 | Save the requests from actors and generate responses when all actors have sent their requests
127 | """
128 | self.requests[actor_rank] = vllm_vision_input
129 | self.actor_counter += 1
130 | if self.actor_counter == self.num_actors:
131 | assert len(self.requests) == self.num_actors, f"{len(self.requests)} != {self.num_actors}"
132 |             # Collect vision and text requests separately, remembering which
133 |             # actor rank issued each one so that responses can be routed back.
134 |             vrall, trall = [], []
135 |             vrsrc, trsrc = [], []
136 |             self.responses = {}
137 |             for actor_rank, request in self.requests.items():
138 |                 vreq, treq = request
139 |                 if vreq:
140 |                     vrall.extend(vreq)
141 |                     vrsrc.extend([actor_rank] * len(vreq))
142 |                 if treq:
143 |                     trall.extend(treq)
144 |                     trsrc.extend([actor_rank] * len(treq))
145 | 
146 |             # For now we assume that all requests share the same sampling params.
147 |             # Skip generation when a batch is empty so vLLM is never called with no inputs.
148 |             vresponses = self.llm.generate(vrall, sampling_params=sampling_params) if vrall else []
149 |             tresponses = (
150 |                 self.llm.generate(sampling_params=sampling_params, prompt_token_ids=trall) if trall else []
151 |             )
152 | 
153 |             # Route each response back to the actor that issued it
154 |             # (vision responses first, then text responses).
155 |             for actor_rank in self.requests:
156 |                 self.responses[actor_rank] = []
157 |             for rank, rsp in zip(vrsrc, vresponses):
158 |                 self.responses[rank].append(rsp)
159 |             for rank, rsp in zip(trsrc, tresponses):
160 |                 self.responses[rank].append(rsp)
161 | 
162 |             self.actor_counter = 0
163 |             self.requests = {}
164 |
165 | def get_responses(self, actor_rank):
166 | """
167 | Return the responses for the actor with the given rank
168 | """
169 | return self.responses.pop(actor_rank)
170 |
171 |
172 | def create_vllm_engines(
173 | num_engines: int,
174 | tensor_parallel_size: int,
175 | pretrain: str,
176 | seed: int,
177 | enable_prefix_caching: bool,
178 | enforce_eager: bool,
179 | max_model_len: int,
180 | num_total_actors: int,
181 | shared_pg=None,
182 | gpu_memory_utilization=None,
183 | vllm_enable_sleep=False,
184 | ):
185 |     import vllm
186 |     from packaging import version
187 |     assert version.parse(vllm.__version__) >= version.parse("0.7.0"), "OpenRLHF only supports vllm >= 0.7.0"
188 |
189 | vllm_engines = []
190 | num_gpus = int(tensor_parallel_size == 1)
191 | distributed_executor_backend = "uni" if tensor_parallel_size == 1 else "ray"
192 | for i in range(num_engines):
193 | bundle_indices = None
194 | scheduling_strategy = None
195 |
196 | # Hybrid engine
197 | if shared_pg is not None:
198 |             assert version.parse(vllm.__version__) >= version.parse("0.7.2"), "Only vllm >= 0.7.2 supports hybrid engine"
199 |
200 | if tensor_parallel_size > 1:
201 | scheduling_strategy = PlacementGroupSchedulingStrategy(
202 | placement_group=shared_pg,
203 | placement_group_capture_child_tasks=True,
204 | placement_group_bundle_index=i * tensor_parallel_size
205 | )
206 | bundle_indices = np.arange(i * tensor_parallel_size, (i + 1) * tensor_parallel_size).tolist()
207 | else:
208 | num_gpus = 0.2
209 | scheduling_strategy = PlacementGroupSchedulingStrategy(
210 | placement_group=shared_pg, placement_group_capture_child_tasks=True, placement_group_bundle_index=i
211 | )
212 | # Distributed RLHF
213 | elif tensor_parallel_size > 1:
214 | bundles = [{"GPU": 1, "CPU": 1}] * tensor_parallel_size
215 | pg = placement_group(bundles)
216 | ray.get(pg.ready())
217 |
218 | scheduling_strategy = PlacementGroupSchedulingStrategy(
219 | placement_group=pg, placement_group_capture_child_tasks=True, placement_group_bundle_index=0
220 | )
221 |
222 | if num_engines >= num_total_actors:
223 | num_actors = 1
224 | else:
225 | num_actors = num_total_actors // num_engines + int(i < num_total_actors % num_engines)
226 |
227 | vllm_engines.append(
228 | LLMRayActor.options(
229 | num_cpus=0,
230 | num_gpus=num_gpus,
231 | scheduling_strategy=scheduling_strategy,
232 | ).remote(
233 | model=pretrain,
234 | enforce_eager=enforce_eager,
235 | worker_cls="openrlhf.trainer.ray.vllm_worker_wrap.WorkerWrap",
236 | tensor_parallel_size=tensor_parallel_size,
237 | seed=seed + i,
238 | distributed_executor_backend=distributed_executor_backend,
239 | max_model_len=max_model_len,
240 | enable_prefix_caching=enable_prefix_caching,
241 | dtype="bfloat16",
242 | trust_remote_code=True,
243 | num_actors=num_actors,
244 | gpu_memory_utilization=gpu_memory_utilization,
245 | bundle_indices=bundle_indices if shared_pg else None,
246 | enable_sleep_mode=vllm_enable_sleep,
247 | limit_mm_per_prompt={"image": 8}
248 | )
249 | )
250 |
251 | return vllm_engines
252 |
--------------------------------------------------------------------------------
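Driver-side sketch (illustrative placeholders, single engine, single actor rank) of the request/response round trip implemented by LLMRayActor above: with num_total_actors=1 each engine serves one actor, so a single add_requests call triggers generation immediately and get_responses returns that actor's slice.

import ray
from vllm import SamplingParams

from openrlhf.trainer.ray.vllm_engine import create_vllm_engines

ray.init()
engines = create_vllm_engines(
    num_engines=1,
    tensor_parallel_size=1,
    pretrain="Qwen/Qwen2.5-VL-7B-Instruct",  # placeholder checkpoint
    seed=42,
    enable_prefix_caching=False,
    enforce_eager=True,
    max_model_len=4096,
    num_total_actors=1,
    gpu_memory_utilization=0.6,
)

sampling_params = SamplingParams(temperature=1.0, top_p=1.0, max_tokens=512)
prompt_token_ids = [[151644, 872, 198]]  # placeholder ids produced by the tokenizer
ray.get(engines[0].add_requests.remote(0, sampling_params=sampling_params, prompt_token_ids=prompt_token_ids))
outputs = ray.get(engines[0].get_responses.remote(0))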
/openrlhf/trainer/ray/vllm_worker_wrap.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from vllm.worker.worker import Worker
3 |
4 | from openrlhf.utils.distributed_util import init_process_group
5 | from openrlhf.utils.logging_utils import init_logger
6 | from .utils import get_physical_gpu_id
7 |
8 | logger = init_logger(__name__)
9 |
10 |
11 | class WorkerWrap(Worker):
12 | def init_process_group(
13 | self, master_address, master_port, rank_offset, world_size, group_name, backend="nccl", use_ray=False
14 | ):
15 | """Init torch process group for model weights update"""
16 |         assert torch.distributed.is_initialized(), "default torch process group must be initialized"
17 |         assert group_name != "", "group name must not be empty"
18 |
19 | rank = torch.distributed.get_rank() + rank_offset
20 | if use_ray:
21 | import ray.util.collective as collective
22 |
23 | collective.init_collective_group(world_size=world_size, rank=rank, backend=backend, group_name=group_name)
24 | self._model_update_group = group_name
25 | else:
26 | self._model_update_group = init_process_group(
27 | backend=backend,
28 | init_method=f"tcp://{master_address}:{master_port}",
29 | world_size=world_size,
30 | rank=rank,
31 | group_name=group_name,
32 | )
33 | self._model_update_with_ray = use_ray
34 | print(
35 | f"init_process_group: master_address={master_address}, master_port={master_port}, ",
36 | f"rank={rank}, world_size={world_size}, group_name={group_name}",
37 | )
38 |
39 | def update_weight(self, name, dtype, shape, empty_cache=False):
40 | """Broadcast weight to all vllm workers from source rank 0 (actor model)"""
41 | if torch.distributed.get_rank() == 0:
42 | print(f"[vllm broadcast] update weight: {name}, dtype: {dtype}, shape: {shape}")
43 |
44 | assert dtype == self.model_config.dtype, f"mismatch dtype: src {dtype}, dst {self.model_config.dtype}"
45 | weight = torch.empty(shape, dtype=dtype, device="cuda")
46 | if self._model_update_with_ray:
47 | import ray.util.collective as collective
48 |
49 | collective.broadcast(weight, 0, group_name=self._model_update_group)
50 | else:
51 | torch.distributed.broadcast(weight, 0, group=self._model_update_group)
52 |
53 | self.model_runner.model.load_weights(weights=[(name, weight)])
54 |
55 | del weight
56 | # TODO: should we empty cache if all weights have updated?
57 | # if empty_cache:
58 | # torch.cuda.empty_cache()
59 |
60 | def update_weight_cuda_ipc(self, name, dtype, shape, ipc_handles=None, empty_cache=False):
61 | if torch.distributed.get_rank() == 0:
62 | print(f"update weight: {name}, dtype: {dtype}, shape: {shape}")
63 |
64 | assert dtype == self.model_config.dtype, f"mismatch dtype: src {dtype}, dst {self.model_config.dtype}"
65 |
66 | handle = ipc_handles[get_physical_gpu_id()]
67 | device_id = self.device.index
68 | func, args = handle
69 | list_args = list(args)
70 | # the key is to change device id to the current device id
71 | # in case two processes have different CUDA_VISIBLE_DEVICES
72 | list_args[6] = device_id
73 | weight = func(*list_args)
74 | self.model_runner.model.load_weights(weights=[(name, weight)])
75 | torch.cuda.synchronize()
76 |
--------------------------------------------------------------------------------
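Sketch of the sending side that pairs with WorkerWrap.update_weight above (illustrative; it ignores ZeRO-3 parameter gathering and LoRA merging). The training actor, holding rank 0 of the weight-sync group, first tells every engine to post a matching receive buffer and then broadcasts the trained tensor; model, model_update_group and vllm_engines are placeholders created elsewhere.

import ray
import torch

for name, param in model.named_parameters():
    # ask each vLLM worker to allocate a buffer of the same dtype/shape
    # (param.dtype must match the engine dtype, "bfloat16" in create_vllm_engines above) ...
    refs = [
        engine.update_weight.remote(name, dtype=param.dtype, shape=param.shape)
        for engine in vllm_engines
    ]
    # ... then broadcast the weight from rank 0 of the shared process group
    torch.distributed.broadcast(param.data, 0, group=model_update_group)
    ray.get(refs)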
/openrlhf/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .processor import get_processor, reward_normalization
2 | from .utils import blending_datasets, get_strategy, get_tokenizer, get_vl_processor
3 |
4 | __all__ = [
5 | "get_processor",
6 | "reward_normalization",
7 | "blending_datasets",
8 | "get_strategy",
9 | "get_tokenizer",
10 | "get_vl_processor",
11 | ]
12 |
--------------------------------------------------------------------------------
/openrlhf/utils/deepspeed/__init__.py:
--------------------------------------------------------------------------------
1 | from .deepspeed import DeepspeedStrategy
2 |
3 | __all__ = [
4 | "DeepspeedStrategy",
5 | ]
6 |
--------------------------------------------------------------------------------
/openrlhf/utils/deepspeed/deepspeed.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import shutil
4 | from abc import ABC
5 | from collections import defaultdict
6 | from datetime import timedelta
7 | from typing import List, Tuple, Union
8 |
9 | import deepspeed
10 | import numpy as np
11 | import torch
12 | import torch.nn as nn
13 | import torch.optim as optim
14 | from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam
15 | from peft import PeftModel, get_peft_model_state_dict
16 | from torch import distributed as dist
17 | from torch.optim import Optimizer
18 | from torch.utils.data import DataLoader
19 |
20 | from openrlhf.models import Actor
21 | from openrlhf.models.ring_attn_utils import get_ring_attn_group, set_ring_attn_group
22 | from openrlhf.utils.distributed_sampler import DistributedSampler
23 |
24 | from .deepspeed_utils import (
25 | _z3_params_to_fetch,
26 | get_eval_ds_config,
27 | get_optimizer_grouped_parameters,
28 | get_train_ds_config,
29 | )
30 |
31 | ModelOptimPair = Tuple[nn.Module, Optimizer]
32 | ModelOrModelOptimPair = Union[nn.Module, ModelOptimPair]
33 |
34 |
35 | class DeepspeedStrategy(ABC):
36 | """
37 |     The strategy for training with DeepSpeed.
38 | """
39 |
40 | def __init__(
41 | self,
42 | seed: int = 42,
43 | max_norm: float = 0.0,
44 | micro_train_batch_size=1,
45 | train_batch_size=1,
46 | zero_stage=2,
47 | bf16=True,
48 | args=None,
49 | ) -> None:
50 | super().__init__()
51 |
52 | self.args = args
53 | self.stage = zero_stage
54 | self.train_batch_size = train_batch_size
55 | self.micro_train_batch_size = micro_train_batch_size
56 | self.bf16 = bf16
57 | self.seed = seed
58 | self.max_norm = max_norm
59 | self.adam_offload = getattr(args, "adam_offload", False)
60 | self.param_offload = getattr(args, "param_offload", False)
61 | self.zpg = getattr(args, "zpg", 1)
62 | self.grad_accum_dtype = getattr(args, "grad_accum_dtype", None)
63 | # overlap_comm
64 | self.overlap_comm = getattr(args, "overlap_comm", False)
65 |
66 | self.is_rlhf = False
67 | self.time_steps = defaultdict(int)
68 |
69 | def set_seed(self, seed: int) -> None:
70 | random.seed(seed)
71 | np.random.seed(seed)
72 | torch.manual_seed(seed)
73 | torch.cuda.manual_seed_all(seed)
74 |
75 | def setup_distributed(self, timeout=timedelta(minutes=60)) -> None:
76 | self.set_seed(self.seed)
77 |
78 | if self.args.local_rank == -1 and "LOCAL_RANK" in os.environ: # for slurm
79 | self.args.local_rank = int(os.environ["LOCAL_RANK"])
80 |
81 | if self.args.local_rank != -1:
82 | torch.cuda.set_device(self.args.local_rank)
83 | 
84 |         # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
85 | deepspeed.init_distributed(timeout=timeout)
86 | self.setup_ring_attn()
87 | self.world_size = dist.get_world_size()
88 | self.accumulated_gradient = (
89 | self.train_batch_size * self.ring_attn_size // self.micro_train_batch_size // self.world_size
90 | )
91 |
92 | def setup_ring_attn(self):
93 | self.ring_attn_size = getattr(self.args, "ring_attn_size", 1)
94 | if self.ring_attn_size == 1:
95 | self.ring_attn_rank = 0
96 | return
97 |
98 | ring_head_stride = getattr(self.args, "ring_head_stride", 1)
99 | for i in range(dist.get_world_size() // self.ring_attn_size):
100 | ring_attn_ranks = list(
101 | range(
102 | i * self.ring_attn_size,
103 | (i + 1) * self.ring_attn_size,
104 | )
105 | )
106 | group = dist.new_group(ranks=ring_attn_ranks, backend="nccl")
107 | if dist.get_rank() in ring_attn_ranks:
108 | set_ring_attn_group(group)
109 | self.ring_attn_rank = dist.get_rank(group=group)
110 |
111 | from ring_flash_attn import substitute_hf_flash_attn
112 |
113 | substitute_hf_flash_attn(self.ring_attn_group, ring_head_stride)
114 |
115 | @property
116 | def ring_attn_group(self):
117 | return get_ring_attn_group()
118 |
119 | def create_optimizer(self, model, **kwargs) -> Optimizer:
120 | if isinstance(model, Actor):
121 | model = model.model
122 | # Optimizer
123 | AdamOptimizer = DeepSpeedCPUAdam if self.adam_offload else FusedAdam
124 | optim_params = get_optimizer_grouped_parameters(model, kwargs["weight_decay"])
125 | optim = AdamOptimizer(optim_params, **kwargs)
126 | return optim
127 |
128 | def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None:
129 | if isinstance(model, Actor):
130 | model = model.model
131 | model.backward(loss)
132 |
133 | def optimizer_step(
134 | self,
135 | optimizer: optim.Optimizer,
136 | model: nn.Module,
137 | scheduler,
138 | name="model",
139 | **kwargs,
140 | ) -> None:
141 | if isinstance(model, Actor):
142 | model = model.model
143 | model.step()
144 |
145 | def setup_dataloader(
146 | self,
147 | replay_buffer,
148 | batch_size: int,
149 | pin_memory: bool = False,
150 | shuffle=True,
151 | collate_fn=None,
152 | drop_last=True,
153 | sampler=None,
154 | consumed_samples=0,
155 | ):
156 | # DDP only mode, replay buffers on each rank are different.
157 | if sampler is None:
158 | num_replicas = dist.get_world_size() // self.ring_attn_size
159 | rank = dist.get_rank() // self.ring_attn_size
160 | sampler = DistributedSampler(
161 | replay_buffer,
162 | num_replicas=num_replicas,
163 | rank=rank,
164 | shuffle=shuffle,
165 | seed=self.seed,
166 | drop_last=drop_last,
167 | consumed_samples=consumed_samples,
168 | )
169 |
170 | return DataLoader(
171 | replay_buffer,
172 | batch_size=batch_size,
173 | sampler=sampler,
174 | drop_last=drop_last,
175 | collate_fn=collate_fn,
176 | pin_memory=pin_memory,
177 | )
178 |
179 | def _unwrap_model(self, model) -> nn.Module:
180 | if isinstance(model, Actor):
181 | return self._unwrap_model(model.model)
182 | elif hasattr(model, "module"):
183 | return model.module
184 | else:
185 | return model
186 |
187 | def prepare(
188 | self, *models_or_model_optim_pairs: ModelOrModelOptimPair, is_rlhf=False
189 | ) -> Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]:
190 | ret = []
191 | self.is_rlhf = is_rlhf
192 | for arg in models_or_model_optim_pairs:
193 | if isinstance(arg, tuple):
194 | assert len(arg) == 3, f'Expect (model, optimizer, scheduler) pair, got a tuple with size "{len(arg)}"'
195 | if arg[0] is not None:
196 | ret.append(self._ds_init_train_model(*arg))
197 | else:
198 | ret.append((None, None, None))
199 | else:
200 | ret.append(self._ds_init_eval_model(arg))
201 |
202 | return ret[0] if len(ret) == 1 else ret
203 |
204 | def _ds_init_train_model(self, model, optim, scheduler):
205 | is_actor = isinstance(model, Actor)
206 | ds_config = self.get_ds_train_config(is_actor)
207 |
208 | engine, optim, _, scheduler = deepspeed.initialize(
209 | model=model.model if is_actor else model,
210 | optimizer=optim,
211 | lr_scheduler=scheduler,
212 | config=ds_config,
213 | args={"local_rank": self.args.local_rank},
214 | dist_init_required=True,
215 | )
216 | if is_actor:
217 | model.model = engine
218 | else:
219 | model = engine
220 |
221 | return model, optim, scheduler
222 |
223 | def get_ds_train_config(self, is_actor):
224 | # DS Config
225 | ds_config = get_train_ds_config(
226 | offload=self.param_offload,
227 | adam_offload=self.adam_offload,
228 | stage=self.stage,
229 | bf16=self.bf16,
230 | max_norm=self.max_norm,
231 | zpg=self.zpg,
232 | grad_accum_dtype=self.grad_accum_dtype,
233 | overlap_comm=self.overlap_comm,
234 | )
235 |         # the two batch-size keys are filled in below before the config reaches deepspeed.initialize
236 |
237 | ds_config["train_micro_batch_size_per_gpu"] = self.micro_train_batch_size
238 | train_batch_size = self.train_batch_size
239 | # corner case for ptx loss (backward twice)
240 | if self.is_rlhf and is_actor and self.args.pretrain_data is not None:
241 | train_batch_size *= 2
242 | ds_config["train_batch_size"] = train_batch_size * self.ring_attn_size
243 |
244 | return ds_config
245 |
246 | def _ds_init_eval_model(self, model):
247 | if not model:
248 | return model
249 | is_actor = isinstance(model, Actor)
250 | ds_config = self.get_ds_eval_config(offload=getattr(model, "_offload", False))
251 |
252 | engine, *_ = deepspeed.initialize(
253 | model=model.model if is_actor else model,
254 | args={"local_rank": self.args.local_rank},
255 | config=ds_config,
256 | dist_init_required=True,
257 | )
258 | if is_actor:
259 | model.model = engine
260 | else:
261 | model = engine
262 | return model
263 |
264 | def get_ds_eval_config(self, offload=False):
265 | # DS Config
266 | ds_config = get_eval_ds_config(offload=offload, stage=self.stage if self.stage == 3 else 0, bf16=self.bf16)
267 | ds_config["train_micro_batch_size_per_gpu"] = self.micro_train_batch_size
268 | ds_config["train_batch_size"] = self.train_batch_size * self.ring_attn_size
269 |
270 | return ds_config
271 |
272 | def moving_average(self, model, model_ema, beta=0.992, device="cpu"):
273 | self.time_steps["ema"] += 1
274 | if self.time_steps["ema"] % self.accumulated_gradient == 0:
275 | with torch.no_grad():
276 | for param, param_ema in zip(model.parameters(), model_ema.parameters()):
277 | if param.requires_grad:
278 | if self.stage != 3:
279 | data = param.data.to(device)
280 | param_ema.data.copy_((1 - beta) * data + beta * param_ema.data)
281 | else:
282 | # TODO: use prefiltering for efficiency
283 | params_to_fetch = _z3_params_to_fetch([param, param_ema])
284 | with deepspeed.zero.GatheredParameters(params_to_fetch, enabled=len(params_to_fetch) > 0):
285 | data = param.data.to(device)
286 | param_ema.data.copy_((1 - beta) * data + beta * param_ema.data)
287 |
288 | def load_model(
289 | self,
290 | model: nn.Module,
291 | path: str,
292 | map_location="cpu",
293 | strict: bool = False,
294 | key_replace_fn=None,
295 | ) -> None:
296 | unwrapped_model = self._unwrap_model(model)
297 | state_dict = torch.load(path, map_location=map_location)
298 | if key_replace_fn:
299 | state_dict = key_replace_fn(state_dict)
300 | unwrapped_model.load_state_dict(state_dict, strict=strict)
301 |
302 | def save_model(self, model: nn.Module, tokenizer, output_dir, **kwargs) -> None:
303 | if self.is_rank_0():
304 | os.makedirs(output_dir, exist_ok=True)
305 | 
306 | torch.distributed.barrier()
307 | # save model weights for ZeRO2/3
308 | model_to_save = self._unwrap_model(model)
309 |
310 | # gather parameters
311 | output_state_dict = {}
312 | dist.barrier()
313 | for k, v in model_to_save.named_parameters():
314 |
315 | # only gather z3 params
316 | params_to_fetch = _z3_params_to_fetch([v])
317 | with deepspeed.zero.GatheredParameters(params_to_fetch, enabled=len(params_to_fetch) > 0):
318 | vv = v.data.cpu()
319 | if self.is_rank_0():
320 | output_state_dict[k] = vv
321 | 
322 |
323 | if self.is_rank_0():
324 | # print('!!!! after named_parameters', sorted(list(output_state_dict.keys())))
325 | state_dict = model_to_save.state_dict()
326 |
327 | # copy named_buffers with `persistent=True`
328 | for k, v in model_to_save.named_buffers():
329 | if k not in state_dict:
330 | continue
331 | # print(f"!!!! [saving] named_buffers, {k}:{v.shape}")
332 | vv = v.data.cpu()
333 | output_state_dict[k] = vv
334 | # print('!!!! after named_buffers', sorted(list(output_state_dict.keys())))
335 |
336 | for k in output_state_dict:
337 | v = output_state_dict[k]
338 | # print(f'!!!! [saving] {k}:{v.shape}')
339 | if v.size(0) == 0:
340 |                 print(f"[saving] warning: tensor {k} is empty")
341 | # exit(-1)
342 |
343 | state_dict_keys = set(state_dict.keys())
344 | output_state_dict_keys = set(output_state_dict.keys())
345 |
346 | # corner case for tie_word_embeddings, such as Qwen2-0.5B
347 | if getattr(model_to_save.config, "tie_word_embeddings", False) and "lm_head.weight" in state_dict_keys:
348 | state_dict_keys.remove("lm_head.weight")
349 |
350 | assert state_dict_keys.issubset(
351 | output_state_dict_keys
352 | ), f"mismatch keys {output_state_dict_keys.symmetric_difference(state_dict_keys)}"
353 |
354 | # only save peft weights https://github.com/microsoft/DeepSpeed/issues/4295
355 | if isinstance(model_to_save, PeftModel):
356 | model_to_save.save_pretrained(output_dir, **kwargs)
357 | if self.stage == 3:
358 | torch.save(
359 | get_peft_model_state_dict(model_to_save, output_state_dict),
360 | os.path.join(output_dir, "adapter_model.bin"),
361 | )
362 | filename = os.path.join(output_dir, "adapter_model.safetensors")
363 | if os.path.exists(filename):
364 | os.remove(filename)
365 | else:
366 | # save model
367 | model_to_save.save_pretrained(output_dir, state_dict=output_state_dict, **kwargs)
368 |
369 | # save config
370 | output_config_file = os.path.join(output_dir, "config.json")
371 | model_to_save.config.to_json_file(output_config_file)
372 | # save tokenizer
373 | tokenizer.save_pretrained(output_dir)
374 |
375 | # for models not in AutoModel, copy python module files
376 | train_from_model_path = model_to_save.config._name_or_path
377 | if os.path.exists(train_from_model_path):
378 | for filename in os.listdir(train_from_model_path):
379 | if filename.endswith(".py"):
380 | shutil.copy(os.path.join(train_from_model_path, filename), os.path.join(output_dir, filename))
381 |
382 | def all_reduce(self, data, op="mean"):
383 | assert op in ("mean", "max", "sum")
384 | if isinstance(data, dict):
385 | ret = {}
386 | for k, v in data.items():
387 | ret[k] = self.all_reduce(v, op)
388 | return ret
389 | else:
390 | is_tensor = True
391 | if not isinstance(data, torch.Tensor):
392 | data = torch.Tensor([data])
393 | is_tensor = False
394 | is_cpu_tensor = data.device.type == "cpu"
395 |
396 | if is_cpu_tensor:
397 | data = data.to(torch.cuda.current_device())
398 | if op == "mean":
399 | data /= self.world_size
400 | dist.all_reduce(data, op=dist.ReduceOp.MAX if op == "max" else dist.ReduceOp.SUM)
401 | if is_cpu_tensor:
402 | data = data.cpu()
403 | return data.item() if not is_tensor else data
404 |
405 | def all_gather(self, data):
406 | if isinstance(data, dict):
407 | ret = {}
408 | for k, v in data.items():
409 | ret[k] = self.all_gather(v)
410 | return ret
411 | else:
412 | if not isinstance(data, torch.Tensor):
413 | data = torch.Tensor([data])
414 | is_cpu_tensor = data.device.type == "cpu"
415 |
416 | ret = [torch.zeros_like(data).to(torch.cuda.current_device()) for _ in range(self.world_size)]
417 | dist.all_gather(ret, data.to(torch.cuda.current_device()))
418 | return torch.cat(ret).cpu() if is_cpu_tensor else torch.cat(ret)
419 |
420 | def print(self, *msg):
421 | if self.is_rank_0():
422 | print(*msg)
423 |
424 | def is_rank_0(self) -> bool:
425 | return dist.get_rank() == 0
426 |
427 | def get_rank(self) -> int:
428 | return dist.get_rank()
429 |
430 | def save_ckpt(self, model, save_dir, tag=None, max_num=3, max_mem=1000, client_state={}, save_latest=True):
431 | assert isinstance(model, deepspeed.DeepSpeedEngine)
432 | if self.is_rank_0():
433 | os.makedirs(save_dir, exist_ok=True)
434 | MAX_SIZE = max_mem * 1024**3 # Convert GB to bytes
435 |
436 | while True:
437 | subdirs = sorted(
438 | [
439 | (os.path.join(save_dir, d), os.path.getmtime(os.path.join(save_dir, d)))
440 | for d in os.listdir(save_dir)
441 | if os.path.isdir(os.path.join(save_dir, d))
442 | ],
443 | key=lambda x: x[1],
444 | )
445 | total_size = sum(
446 | os.path.getsize(os.path.join(dirpath, f))
447 | for subdir, _ in subdirs
448 | for dirpath, _, filenames in os.walk(subdir)
449 | for f in filenames
450 | )
451 |
452 | if len(subdirs) >= max_num or total_size > MAX_SIZE:
453 | oldest_dir = subdirs[0][0]
454 | if os.path.exists(oldest_dir):
455 | shutil.rmtree(oldest_dir)
456 | self.print(f"Deleted oldest ckpt {oldest_dir}")
457 | else:
458 | break
459 |
460 | dist.barrier()
461 | model.save_checkpoint(save_dir, tag=tag, client_state=client_state, save_latest=save_latest)
462 |
463 | def load_ckpt(
464 | self,
465 | model,
466 | load_dir,
467 | tag=None,
468 | load_module_strict=True,
469 | load_optimizer_states=True,
470 | load_lr_scheduler_states=True,
471 | load_module_only=False,
472 | ):
473 | assert isinstance(model, deepspeed.DeepSpeedEngine)
474 | load_path, states = model.load_checkpoint(
475 | load_dir,
476 | tag,
477 | load_module_strict=load_module_strict,
478 | load_optimizer_states=load_optimizer_states,
479 | load_lr_scheduler_states=load_lr_scheduler_states,
480 | load_module_only=load_module_only,
481 | )
482 | if load_path is None:
483 | raise Exception(f"[deepspeed] failed to resume from checkpoint {load_dir}")
484 | return load_path, states
485 |
--------------------------------------------------------------------------------
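A small illustration of the metric-aggregation helpers at the end of DeepspeedStrategy (all_reduce, all_gather, print). It assumes torch.distributed is already initialized and a GPU is visible; `strategy` stands for an initialized DeepspeedStrategy.

# per-rank training metrics, e.g. returned by a PPO micro-step
status = {"critic_loss": 0.12, "values": 0.85, "response_length": 243.0}

status_mean = strategy.all_reduce(status, op="mean")  # dict in, dict of cross-rank means out
all_values = strategy.all_gather(status["values"])    # 1-D tensor with one entry per rank
strategy.print("averaged:", status_mean)              # printed on rank 0 only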
/openrlhf/utils/deepspeed/deepspeed_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # SPDX-License-Identifier: Apache-2.0
3 |
4 | # DeepSpeed Team
5 |
6 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus
7 |
8 |
9 | def get_train_ds_config(
10 | offload,
11 | adam_offload=True,
12 | stage=2,
13 | bf16=True,
14 | max_norm=1.0,
15 | zpg=8,
16 | grad_accum_dtype=None,
17 | overlap_comm=False,
18 | ):
19 | device = "cpu" if offload else "none"
20 | zero_opt_dict = {
21 | "stage": stage,
22 | "offload_param": {"device": device},
23 | "offload_optimizer": {
24 | "device": "cpu" if adam_offload else "none",
25 | "pin_memory": True
26 | # "pin_memory": False,
27 | # "ratio": 0.9,
28 | },
29 | "sub_group_size": "auto",
30 | "stage3_max_live_parameters": "auto",
31 | "stage3_max_reuse_distance": "auto",
32 | "stage3_param_persistence_threshold": "auto",
33 | "stage3_prefetch_bucket_size": "auto",
34 | "reduce_bucket_size": "auto",
35 | # ZeRO++
36 | "zero_hpz_partition_size": zpg,
37 | "zero_quantized_weights": False,
38 | "zero_quantized_gradients": False,
39 | }
40 | if overlap_comm:
41 | zero_opt_dict["overlap_comm"] = True
42 | zero_opt_dict["contiguous_gradients"] = True
43 |
44 | return {
45 | "steps_per_print": 100,
46 | "zero_optimization": zero_opt_dict,
47 | "bf16": {
48 | "enabled": bf16,
49 | },
50 | "gradient_clipping": max_norm,
51 | "prescale_gradients": False,
52 | "wall_clock_breakdown": False,
53 | "data_types": {"grad_accum_dtype": grad_accum_dtype},
54 | }
55 |
56 |
57 | def get_eval_ds_config(
58 | offload,
59 | stage=0,
60 | bf16=True,
61 | ):
62 | zero_opt_dict = {
63 | "stage": stage,
64 | "stage3_param_persistence_threshold": "auto",
65 | "offload_param": {
66 | "device": "cpu" if offload else "none",
67 | "pin_memory": True,
68 | },
69 | }
70 | return {
71 | "steps_per_print": 100,
72 | "zero_optimization": zero_opt_dict,
73 | "bf16": {
74 | "enabled": bf16,
75 | },
76 | "gradient_clipping": 1.0,
77 | "prescale_gradients": False,
78 | "wall_clock_breakdown": False,
79 | }
80 |
81 |
82 | def get_optimizer_grouped_parameters(
83 | model,
84 | weight_decay,
85 | no_decay_name_list=["bias", "layer_norm.weight", "layernorm.weight", "norm.weight", "ln_f.weight"],
86 | ):
87 | optimizer_grouped_parameters = [
88 | {
89 | "params": [
90 | p
91 | for n, p in model.named_parameters()
92 | if (not any(nd in n for nd in no_decay_name_list) and p.requires_grad)
93 | ],
94 | "weight_decay": weight_decay,
95 | },
96 | {
97 | "params": [
98 | p
99 | for n, p in model.named_parameters()
100 | if (any(nd in n for nd in no_decay_name_list) and p.requires_grad)
101 | ],
102 | "weight_decay": 0.0,
103 | },
104 | ]
105 | return optimizer_grouped_parameters
106 |
107 |
108 | def _z3_params_to_fetch(param_list):
109 | return [p for p in param_list if hasattr(p, "ds_id") and p.ds_status == ZeroParamStatus.NOT_AVAILABLE]
110 |
--------------------------------------------------------------------------------
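For reference, the helpers above only assemble plain dictionaries, so they can be inspected offline. A quick way to look at the resulting DeepSpeed config (the argument values here are arbitrary examples, and the two batch-size keys are normally filled in by DeepspeedStrategy.get_ds_train_config):

import json

from openrlhf.utils.deepspeed.deepspeed_utils import get_eval_ds_config, get_train_ds_config

train_cfg = get_train_ds_config(offload=False, adam_offload=False, stage=3, bf16=True, max_norm=1.0, zpg=1)
train_cfg["train_micro_batch_size_per_gpu"] = 1
train_cfg["train_batch_size"] = 128
print(json.dumps(train_cfg, indent=2))

eval_cfg = get_eval_ds_config(offload=False, stage=3, bf16=True)
print(json.dumps(eval_cfg, indent=2))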
/openrlhf/utils/distributed_sampler.py:
--------------------------------------------------------------------------------
1 | import math
2 | from typing import Iterator, Optional, TypeVar
3 |
4 | import torch
5 | import torch.distributed as dist
6 | from torch.utils.data.dataset import Dataset
7 | from torch.utils.data.sampler import Sampler
8 |
9 |
10 | __all__ = ["DistributedSampler"]
11 |
12 |
13 | _T_co = TypeVar("_T_co", covariant=True)
14 |
15 |
16 | # Adapted from https://github.com/pytorch/pytorch/blob/5298acb5c76855bc5a99ae10016efc86b27949bd/torch/utils/data/distributed.py
17 | class DistributedSampler(Sampler[_T_co]):
18 | r"""Sampler that restricts data loading to a subset of the dataset.
19 |
20 | It is especially useful in conjunction with
21 | :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
22 | process can pass a :class:`~torch.utils.data.DistributedSampler` instance as a
23 | :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the
24 | original dataset that is exclusive to it.
25 |
26 | .. note::
27 | Dataset is assumed to be of constant size and that any instance of it always
28 | returns the same elements in the same order.
29 |
30 | Args:
31 | dataset: Dataset used for sampling.
32 | num_replicas (int, optional): Number of processes participating in
33 | distributed training. By default, :attr:`world_size` is retrieved from the
34 | current distributed group.
35 | rank (int, optional): Rank of the current process within :attr:`num_replicas`.
36 | By default, :attr:`rank` is retrieved from the current distributed
37 | group.
38 | shuffle (bool, optional): If ``True`` (default), sampler will shuffle the
39 | indices.
40 | seed (int, optional): random seed used to shuffle the sampler if
41 | :attr:`shuffle=True`. This number should be identical across all
42 | processes in the distributed group. Default: ``0``.
43 | drop_last (bool, optional): if ``True``, then the sampler will drop the
44 | tail of the data to make it evenly divisible across the number of
45 | replicas. If ``False``, the sampler will add extra indices to make
46 | the data evenly divisible across the replicas. Default: ``False``.
47 |
48 | .. warning::
49 | In distributed mode, calling the :meth:`set_epoch` method at
50 | the beginning of each epoch **before** creating the :class:`DataLoader` iterator
51 | is necessary to make shuffling work properly across multiple epochs. Otherwise,
52 | the same ordering will be always used.
53 |
54 | Example::
55 |
56 | >>> # xdoctest: +SKIP
57 | >>> sampler = DistributedSampler(dataset) if is_distributed else None
58 | >>> loader = DataLoader(dataset, shuffle=(sampler is None),
59 | ... sampler=sampler)
60 | >>> for epoch in range(start_epoch, n_epochs):
61 | ... if is_distributed:
62 | ... sampler.set_epoch(epoch)
63 | ... train(loader)
64 | """
65 |
66 | def __init__(
67 | self,
68 | dataset: Dataset,
69 | num_replicas: Optional[int] = None,
70 | rank: Optional[int] = None,
71 | shuffle: bool = True,
72 | seed: int = 0,
73 | drop_last: bool = False,
74 | consumed_samples=0,
75 | ) -> None:
76 | if num_replicas is None:
77 | if not dist.is_available():
78 | raise RuntimeError("Requires distributed package to be available")
79 | num_replicas = dist.get_world_size()
80 | if rank is None:
81 | if not dist.is_available():
82 | raise RuntimeError("Requires distributed package to be available")
83 | rank = dist.get_rank()
84 | if rank >= num_replicas or rank < 0:
85 | raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]")
86 | self.dataset = dataset
87 | self.num_replicas = num_replicas
88 | self.rank = rank
89 | self.epoch = 0
90 | self.drop_last = drop_last
91 | # If the dataset length is evenly divisible by # of replicas, then there
92 | # is no need to drop any data, since the dataset will be split equally.
93 | if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type]
94 | # Split to nearest available length that is evenly divisible.
95 | # This is to ensure each rank receives the same amount of data when
96 | # using this Sampler.
97 | self.num_samples = math.ceil(
98 | (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
99 | )
100 | else:
101 | self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type]
102 | self.total_size = self.num_samples * self.num_replicas
103 | self.shuffle = shuffle
104 | self.seed = seed
105 | self.consumed_indicies = consumed_samples // self.num_replicas
106 |
107 | def __iter__(self) -> Iterator[_T_co]:
108 | if self.shuffle:
109 | # deterministically shuffle based on epoch and seed
110 | g = torch.Generator()
111 | g.manual_seed(self.seed + self.epoch)
112 | indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type]
113 | else:
114 | indices = list(range(len(self.dataset))) # type: ignore[arg-type]
115 |
116 | if not self.drop_last:
117 | # add extra samples to make it evenly divisible
118 | padding_size = self.total_size - len(indices)
119 | if padding_size <= len(indices):
120 | indices += indices[:padding_size]
121 | else:
122 | indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
123 | else:
124 | # remove tail of data to make it evenly divisible.
125 | indices = indices[: self.total_size]
126 | assert len(indices) == self.total_size
127 |
128 | # subsample
129 | indices = indices[self.rank : self.total_size : self.num_replicas]
130 | # skip consumed_samples
131 | indices = indices[self.consumed_indicies :]
132 | assert len(indices) == self.num_samples - self.consumed_indicies
133 |
134 | return iter(indices)
135 |
136 | def __len__(self) -> int:
137 | return self.num_samples - self.consumed_indicies
138 |
139 | def set_epoch(self, epoch: int, consumed_samples=0) -> None:
140 | r"""
141 | Set the epoch for this sampler.
142 |
143 | When :attr:`shuffle=True`, this ensures all replicas
144 | use a different random ordering for each epoch. Otherwise, the next iteration of this
145 | sampler will yield the same ordering.
146 |
147 | Args:
148 | epoch (int): Epoch number.
149 | """
150 | self.epoch = epoch
151 | self.consumed_indicies = consumed_samples // self.num_replicas
152 |
--------------------------------------------------------------------------------
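A single-process illustration of the resumable sampler above: passing num_replicas and rank explicitly avoids needing torch.distributed, and consumed_samples skips data already seen before a restart.

from openrlhf.utils.distributed_sampler import DistributedSampler

dataset = list(range(10))
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True, seed=0, consumed_samples=4)
sampler.set_epoch(0, consumed_samples=4)  # 4 consumed samples across 2 replicas, so skip 2 indices on this rank

print(list(sampler))  # the 3 remaining indices out of num_samples = 5
print(len(sampler))   # 3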
/openrlhf/utils/distributed_util.py:
--------------------------------------------------------------------------------
1 | from datetime import timedelta
2 | from typing import Any, Optional, Union
3 |
4 | import torch
5 | import torch.distributed
6 | from torch.distributed.distributed_c10d import (
7 | Backend,
8 | PrefixStore,
9 | Store,
10 | _new_process_group_helper,
11 | _world,
12 | default_pg_timeout,
13 | rendezvous,
14 | )
15 |
16 |
17 | # Copy from pytorch to allow creating multiple main groups.
18 | # https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py
19 | def init_process_group(
20 | backend: Union[str, Backend] = None,
21 | init_method: Optional[str] = None,
22 | timeout: Optional[timedelta] = None,
23 | world_size: int = -1,
24 | rank: int = -1,
25 | store: Optional[Store] = None,
26 | group_name: str = None,
27 | pg_options: Optional[Any] = None,
28 | ):
29 | assert (store is None) or (init_method is None), "Cannot specify both init_method and store."
30 |
31 | if store is not None:
32 | assert world_size > 0, "world_size must be positive if using store"
33 | assert rank >= 0, "rank must be non-negative if using store"
34 | elif init_method is None:
35 | init_method = "env://"
36 |
37 | if backend:
38 | backend = Backend(backend)
39 | else:
40 | backend = Backend("undefined")
41 |
42 | if timeout is None:
43 | timeout = default_pg_timeout
44 |
45 | # backward compatible API
46 | if store is None:
47 | rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout)
48 | store, rank, world_size = next(rendezvous_iterator)
49 | store.set_timeout(timeout)
50 |
51 | # Use a PrefixStore to avoid accidental overrides of keys used by
52 | # different systems (e.g. RPC) in case the store is multi-tenant.
53 | store = PrefixStore(group_name, store)
54 |
55 |     # NOTE: The pg_options parameter was renamed to backend_options in PyTorch 2.6.0
56 |     # https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844
57 |     torch_version = tuple(int(v) for v in torch.__version__.split("+")[0].split(".")[:2])  # avoids string-comparison pitfalls (e.g. "2.10" < "2.6")
58 |     pg_options_param_name = "backend_options" if torch_version >= (2, 6) else "pg_options"
59 | pg, _ = _new_process_group_helper(
60 | world_size,
61 | rank,
62 | [],
63 | backend,
64 | store,
65 | group_name=group_name,
66 | **{pg_options_param_name: pg_options},
67 | timeout=timeout,
68 | )
69 |
70 | _world.pg_group_ranks[pg] = {i: i for i in range(world_size)}
71 |
72 | return pg
73 |
--------------------------------------------------------------------------------
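Hedged sketch of how this init_process_group variant is used to build the extra weight-sync group between the training actor and the vLLM workers (see WorkerWrap.init_process_group earlier); master_address, master_port, world_size and rank are placeholders supplied by the launcher.

from openrlhf.utils.distributed_util import init_process_group

model_update_group = init_process_group(
    backend="nccl",
    init_method=f"tcp://{master_address}:{master_port}",
    world_size=world_size,  # training-actor rank plus all vLLM worker ranks
    rank=rank,              # 0 on the training actor that broadcasts the weights
    group_name="openrlhf",
)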
/openrlhf/utils/logging_utils.py:
--------------------------------------------------------------------------------
1 | # Adapted from
2 | # https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py
3 | """Logging configuration for vLLM."""
4 | import logging
5 | import sys
6 |
7 | _FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
8 | _DATE_FORMAT = "%m-%d %H:%M:%S"
9 |
10 |
11 | class NewLineFormatter(logging.Formatter):
12 | """Adds logging prefix to newlines to align multi-line messages."""
13 |
14 | def __init__(self, fmt, datefmt=None):
15 | logging.Formatter.__init__(self, fmt, datefmt)
16 |
17 | def format(self, record):
18 | msg = logging.Formatter.format(self, record)
19 | if record.message != "":
20 | parts = msg.split(record.message)
21 | msg = msg.replace("\n", "\r\n" + parts[0])
22 | return msg
23 |
24 |
25 | _root_logger = logging.getLogger("openrlhf")
26 | _default_handler = None
27 |
28 |
29 | def _setup_logger():
30 | _root_logger.setLevel(logging.DEBUG)
31 | global _default_handler
32 | if _default_handler is None:
33 | _default_handler = logging.StreamHandler(sys.stdout)
34 | _default_handler.flush = sys.stdout.flush # type: ignore
35 | _default_handler.setLevel(logging.INFO)
36 | _root_logger.addHandler(_default_handler)
37 | fmt = NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT)
38 | _default_handler.setFormatter(fmt)
39 | # Setting this will avoid the message
40 | # being propagated to the parent logger.
41 | _root_logger.propagate = False
42 |
43 |
44 | # The logger is initialized when the module is imported.
45 | # This is thread-safe as the module is only imported once,
46 | # guaranteed by the Python GIL.
47 | _setup_logger()
48 |
49 |
50 | def init_logger(name: str):
51 | # Use the same settings as above for root logger
52 | logger = logging.getLogger(name)
53 | logger.setLevel(logging.DEBUG)
54 | logger.addHandler(_default_handler)
55 | logger.propagate = False
56 | return logger
57 |
--------------------------------------------------------------------------------
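Typical usage of the logging helper above, mirroring the other modules in this repository; note that DEBUG records are dropped because the shared stream handler is set to INFO.

from openrlhf.utils.logging_utils import init_logger

logger = init_logger(__name__)
logger.info("vLLM engines are ready")
logger.debug("not emitted: the shared handler level is INFO")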
/openrlhf/utils/processor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from tqdm import tqdm
3 |
4 |
5 | def reward_normalization(objs):
6 | rewards = [float(obj["reward"]) for obj in objs]
7 | rewards = torch.tensor(rewards, dtype=torch.float64)
8 | rewards = (rewards - rewards.mean()) / rewards.std()
9 | for i, obj in enumerate(objs):
10 | obj["reward"] = rewards[i].item()
11 |
12 |
13 | # Conditional SFT
14 | # See https://arxiv.org/abs/2308.12050
15 | DEFAULT_REWARD_PROMPT = "{input} : {reward} "
16 |
17 |
18 | def conditional_sft_processor(args, objs):
19 | if "reward_template" not in args or args.reward_template is None:
20 | reward_template = DEFAULT_REWARD_PROMPT
21 | else:
22 | reward_template = args.reward_template
23 | assert "{input}" in reward_template
24 | assert "{reward}" in reward_template
25 |
26 | if args.normalize_reward:
27 | reward_normalization(objs)
28 |
29 | for obj in tqdm(objs, desc="Conditional SFT process..."):
30 | input = obj["input"]
31 | reward = "{:.2f}".format(float(obj["reward"]))
32 | input = reward_template.replace("{reward}", reward).replace("{input}", input)
33 | obj["input"] = input
34 |
35 | return objs
36 |
37 |
38 | # Rejection Sampling
39 | # See https://arxiv.org/abs/2307.09288
40 | def rejection_sampling_processor(args, objs):
41 | out = {}
42 | for obj in tqdm(objs, desc="Rejection Sampling process...."):
43 | input = obj["input"]
44 | output = obj["output"]
45 | reward = float(obj["reward"])
46 |
47 | if input not in out:
48 | out[input] = {"output": output, "reward": reward}
49 | elif reward > out[input]["reward"]:
50 | out[input]["reward"] = reward
51 | out[input]["output"] = output
52 |
53 | return [{"input": k, "output": v["output"], "reward": v["reward"]} for k, v in out.items()]
54 |
55 |
56 | # Iterative DPO
57 | # See https://github.com/RLHFlow/Online-RLHF/blob/main/run_loop.sh
58 | def iterative_dpo_processor(args, objs):
59 | out = {}
60 | for obj in tqdm(objs, desc="Iterative DPO process...."):
61 | input = obj["input"]
62 | output = obj["output"]
63 | reward = float(obj["reward"])
64 |
65 | if input not in out:
66 | out[input] = {
67 | "output": output,
68 | "chosen": output,
69 | "chosen_reward": reward,
70 | "rejected": output,
71 | "rejected_reward": reward,
72 | }
73 | elif reward > out[input]["chosen_reward"]:
74 | out[input]["chosen_reward"] = reward
75 | out[input]["chosen"] = output
76 | elif reward < out[input]["rejected_reward"]:
77 | out[input]["rejected_reward"] = reward
78 | out[input]["rejected"] = output
79 |
80 | return [
81 | {
82 | "prompt": k,
83 | "chosen": v["chosen"],
84 | "chosen_reward": v["chosen_reward"],
85 | "rejected": v["rejected"],
86 | "rejected_reward": v["rejected_reward"],
87 | }
88 | for k, v in out.items()
89 | ]
90 |
91 |
92 | PROCESSORS = {
93 | "rs": rejection_sampling_processor,
94 | "csft": conditional_sft_processor,
95 | "iter_dpo": iterative_dpo_processor,
96 | }
97 |
98 |
99 | def get_processor(name):
100 | if name in PROCESSORS:
101 | return PROCESSORS[name]
102 | else:
103 | raise ValueError(f"Processor {name} does not exist.")
104 |
--------------------------------------------------------------------------------
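A self-contained example of the rejection-sampling processor above: for each unique prompt, only the highest-reward completion is kept (the rewards below are made-up numbers).

from openrlhf.utils.processor import get_processor

samples = [
    {"input": "What is 2+2?", "output": "4", "reward": 1.0},
    {"input": "What is 2+2?", "output": "5", "reward": -1.0},
    {"input": "Name a prime number.", "output": "7", "reward": 0.8},
]
best = get_processor("rs")(args=None, objs=samples)
print(best)  # one entry per unique input, keeping the higher-reward output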
/openrlhf/utils/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from datasets import interleave_datasets, load_dataset, load_from_disk
4 | from transformers import AutoTokenizer, AutoProcessor, AutoModel
5 |
6 |
7 | def get_vl_processor(pretrain, model, padding_side="left", strategy=None, use_fast=True):
8 |     # TODO: maybe better ways to set max_pixels for other VL models
9 |     # follows Qwen2.5-VL: https://github.com/QwenLM/Qwen2.5-VL?tab=readme-ov-file#image-resolution-for-performance-boost
10 | min_pixels = int(os.getenv("MIN_PIXELS", 256*28*28))
11 | max_pixels = int(os.getenv("MAX_PIXELS", 1280*28*28))
12 | processor = AutoProcessor.from_pretrained(pretrain, trust_remote_code=True, use_fast=use_fast, min_pixels=min_pixels, max_pixels=max_pixels)
13 | tokenizer = processor.tokenizer
14 | tokenizer.padding_side = padding_side
15 |     # NOTE: When vLLM is enabled, do not resize_token_embeddings, or the vocab size will mismatch with vLLM.
16 | # https://github.com/facebookresearch/llama-recipes/pull/196
17 | if tokenizer.pad_token is None:
18 | tokenizer.pad_token = tokenizer.eos_token
19 | tokenizer.pad_token_id = tokenizer.eos_token_id
20 | if model is not None:
21 | model.config.pad_token_id = tokenizer.pad_token_id
22 | return processor
23 |
24 | def get_tokenizer(pretrain, model, padding_side="left", strategy=None, use_fast=True):
25 | tokenizer = AutoTokenizer.from_pretrained(pretrain, trust_remote_code=True, use_fast=use_fast)
26 | tokenizer.padding_side = padding_side
27 |     # NOTE: When vLLM is enabled, do not resize_token_embeddings, or the vocab size will mismatch with vLLM.
28 | # https://github.com/facebookresearch/llama-recipes/pull/196
29 | if tokenizer.pad_token is None:
30 | tokenizer.pad_token = tokenizer.eos_token
31 | tokenizer.pad_token_id = tokenizer.eos_token_id
32 | model.config.pad_token_id = tokenizer.pad_token_id
33 |
34 | return tokenizer
35 |
36 |
37 | def get_strategy(args):
38 | from openrlhf.utils.deepspeed import DeepspeedStrategy
39 |
40 | strategy = DeepspeedStrategy(
41 | seed=getattr(args, "seed", 42),
42 | max_norm=getattr(args, "max_norm", 1.0),
43 | micro_train_batch_size=getattr(args, "micro_train_batch_size", 1),
44 | train_batch_size=getattr(args, "train_batch_size", 128),
45 | zero_stage=args.zero_stage,
46 | bf16=getattr(args, "bf16", True),
47 | args=args,
48 | )
49 | return strategy
50 |
51 |
52 | def blending_datasets(
53 | datasets,
54 | probabilities,
55 | strategy=None,
56 | seed=42,
57 | max_count=5000000,
58 | return_eval=True,
59 | stopping_strategy="first_exhausted",
60 | train_split="train",
61 | eval_split="test",
62 | ):
63 | datasets = datasets.split(",")
64 | probabilities = list(map(float, probabilities.split(",")))
65 | assert len(probabilities) == len(datasets)
66 |
67 | train_data_list = []
68 | eval_data_list = []
69 | for i, dataset in enumerate(datasets):
70 | dataset = dataset.strip()
71 | strategy.print(f"dataset: {dataset}")
72 | dp = dataset
73 | data_dir = dataset.split("@")[1].strip() if "@" in dataset else None
74 | dataset = dataset.split("@")[0].strip()
75 | dataset_basename = os.path.basename(dataset)
76 |
77 | ext = os.path.splitext(dataset)[-1]
78 | # local python script
79 | if ext == ".py" or (
80 | os.path.isdir(dataset) and os.path.exists(os.path.join(dataset, f"{dataset_basename}.py"))
81 | ):
82 | data = load_dataset(dataset, trust_remote_code=True)
83 | strategy.print(f"loaded {dataset} with python script")
84 | # local text file
85 | elif ext in [".json", ".jsonl", ".csv"]:
86 | ext = ext.lower().strip(".")
87 | if ext == "jsonl":
88 | ext = "json"
89 | data = load_dataset(ext, data_files=dataset)
90 | strategy.print(f"loaded {dataset} with data_files={dataset}")
91 |         elif dp.endswith("parquet"):
92 |             data = load_dataset("parquet", data_files=dp)
93 |             strategy.print(f"loaded parquet file: {dp}")
94 |
95 | # local dataset saved with `datasets.Dataset.save_to_disk`
96 | # elif os.path.isdir(dataset):
97 | # data = load_from_disk(dataset)
98 | # strategy.print(f"loaded {dataset} from disk")
99 | # # remote/local folder or common file
100 | else:
101 | data = load_dataset(dataset, data_dir=data_dir)
102 | strategy.print(f"loaded {dataset} from files")
103 |         strategy.print(data)
104 | if train_split and train_split in data:
105 | train_data = data[train_split].select(range(min(max_count, len(data[train_split]))))
106 | else:
107 | train_data = data.select(range(min(max_count, len(data))))
108 | train_data_list.append(train_data)
109 |
110 | if return_eval:
111 | if eval_split and eval_split in data:
112 | eval_data = data[eval_split].select(range(min(max_count, len(data[eval_split]))))
113 |             # TODO: in the fallback below, eval data is sampled from the train split and may overlap with it
114 | else:
115 | eval_data = train_data.select(range(min(max_count, int(len(train_data) * 0.03))))
116 | eval_data_list.append(eval_data)
117 |
118 | # merge datasets
119 | if strategy.is_rank_0():
120 | print(train_data_list)
121 |
122 | train_dataset = interleave_datasets(
123 | train_data_list,
124 | probabilities=probabilities,
125 | seed=seed,
126 | stopping_strategy=stopping_strategy,
127 | )
128 | if return_eval:
129 | eval_dataset = interleave_datasets(
130 | eval_data_list,
131 | probabilities=probabilities,
132 | seed=seed,
133 | stopping_strategy=stopping_strategy,
134 | )
135 | return train_dataset, eval_dataset
136 | else:
137 | return train_dataset
138 |
139 |
140 | def convert_token_to_id(token, tokenizer):
141 | if isinstance(token, str):
142 | token = tokenizer.encode(token, add_special_tokens=False)
143 | assert len(token) == 1
144 | return token[0]
145 | else:
146 |         raise ValueError(f"token should be a string, got {type(token)}")
147 |
148 | def get_generation_cls(config):
149 | model_type = config.model_type
150 | model_arch = AutoModel._model_mapping[type(config)].__name__
151 | if model_arch.endswith("ForCausalLM") or \
152 | model_arch.endswith("ForConditionalGeneration"):
153 | return AutoModel._model_mapping[type(config)]
154 | elif model_arch.endswith("Model"):
155 | possible_arch = [model_arch.replace("Model", "ForCausalLM"), model_arch.replace("Model", "ForConditionalGeneration")]
156 | import importlib
157 |         module = importlib.import_module(f".models.{model_type}.modeling_{model_type}", package="transformers")
158 | for arch in possible_arch:
159 | model_cls = getattr(module, arch, None)
160 | if model_cls is not None:
161 | return model_cls
162 | raise ValueError(f"Cannot find ForCausalLM or ForConditionalGeneration class for {model_arch}")
163 | else:
164 | raise ValueError(f"Unexpected model architecture {model_arch}")
--------------------------------------------------------------------------------
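Usage note (not part of the repository): a minimal sketch of how `blending_datasets` above can be called. The dataset paths, the mixing weights, and the stand-in strategy object are hypothetical, and the import path is assumed from the tree layout; in the training entry points the strategy comes from `get_strategy(args)`.

    # Sketch only: paths/weights are hypothetical, _PrintStrategy stands in for the
    # DeepSpeed strategy, which must expose print() and is_rank_0().
    from openrlhf.utils.utils import blending_datasets

    class _PrintStrategy:
        def print(self, *args, **kwargs):
            print(*args, **kwargs)

        def is_rank_0(self):
            return True

    train_ds, eval_ds = blending_datasets(
        datasets="./data/train_a.parquet,./data/train_b.parquet",  # comma-separated dataset specs
        probabilities="0.7,0.3",                                   # one weight per dataset
        strategy=_PrintStrategy(),
        seed=42,
        return_eval=True,  # eval comes from the "test" split, or ~3% of train as a fallback
    )
    print(len(train_ds), len(eval_ds))

The two splits are interleaved with `interleave_datasets`, so the weights control the sampling ratio rather than a hard concatenation.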
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [build-system]
2 | requires = [
3 | "packaging",
4 | "setuptools >= 49.4.0",
5 | "wheel",
6 | ]
7 | build-backend = "setuptools.build_meta"
8 |
9 | [tool.isort]
10 | profile = "black" # black-compatible
11 | line_length = 119 # should match black parameters
12 | ignore_whitespace = true # ignore whitespace for compatibility with the initial style
13 | py_version = 310 # python 3.10 as a target version
14 | sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"]
15 | default_section = "THIRDPARTY"
16 | extend_skip = ["setup.py", "docs/source/conf.py"]
17 |
18 |
19 | [tool.black]
20 | line_length = 119
21 |
22 | [tool.ruff]
23 | line-length = 119
24 |
25 | [tool.pytest.ini_options]
26 | # durations=0 will display the execution time of all tests, sorted starting from the slowest one.
27 | # -vv will also display tests with duration = 0.00s
28 | addopts = "--verbose --pyargs --durations=0 --strict-markers" # always add these arguments to pytest
29 | testpaths = ["./tests"] # must be an explicit path to avoid importing another "tests" module
30 | # directories to ignore when discovering tests
31 | norecursedirs = [
32 | "external",
33 | "examples",
34 | "docs",
35 | "scripts",
36 | "tools",
37 | "tutorials",
38 | "*.egg",
39 | ".*",
40 | "_darcs",
41 | "build",
42 | "CVS",
43 | "dist",
44 | "venv",
45 | "{arch}",
46 | ]
47 | # markers to select tests, use `pytest --markers` to see all available markers, `pytest -m ""` to select tests
48 | markers = [
49 | "unit: marks unit test, i.e. testing a single, well isolated functionality (deselect with '-m \"not unit\"')",
50 | "integration: marks test checking the elements when integrated into subsystems (deselect with '-m \"not integration\"')",
51 | "system: marks test working at the highest integration level (deselect with '-m \"not system\"')",
52 | "acceptance: marks test checking whether the developed product/model passes the user defined acceptance criteria (deselect with '-m \"not acceptance\"')",
53 | "docs: mark tests related to documentation (deselect with '-m \"not docs\"')",
54 |     "skipduringci: marks tests that are skipped in CI because they are covered by Jenkins jobs, but should still be run to test user setups",
55 | "pleasefixme: marks tests that are broken and need fixing",
56 | ]
57 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate
2 | bitsandbytes
3 | datasets
4 | deepspeed==0.15
5 | einops
6 | flask
7 | isort
8 | jsonlines
9 | loralib
10 | levenshtein
11 | math-verify
12 | optimum
13 | packaging
14 | peft
15 | pynvml>=12.0.0
16 | qwen_vl_utils
17 | ray[default]==2.42.0
18 | tensorboard
19 | torch
20 | torchmetrics
21 | tqdm
22 | transformers @ git+https://github.com/huggingface/transformers@main
23 | transformers_stream_generator
24 | wandb
25 | wheel
26 |
--------------------------------------------------------------------------------
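As a convenience (not part of the repository), a quick check that an environment matches the pins above; `importlib.metadata` is in the Python standard library, and the "expected" strings are just labels taken from this requirements file.

    # Print installed vs. pinned versions for the strictly pinned packages.
    from importlib.metadata import PackageNotFoundError, version

    for pkg, expected in [("deepspeed", "0.15"), ("ray", "2.42.0")]:
        try:
            print(f"{pkg}: installed {version(pkg)} (pinned: {expected})")
        except PackageNotFoundError:
            print(f"{pkg}: not installed")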
/scripts/eval_7b.sh:
--------------------------------------------------------------------------------
1 | benchmark=m3u
2 | if [[ "$benchmark" == "m3u" ]]; then
3 | export testdata="./data/MMMUPro_full.parquet"
4 | elif [[ "$benchmark" == "m3u_val" ]]; then
5 | export testdata="./data/m3u_val.parquet"
6 | elif [[ "$benchmark" == "emma" ]]; then
7 | export factor=4
8 | export testdata="./data/emma_full.parquet"
9 | elif [[ "$benchmark" == "mathverse" ]]; then
10 | export testdata="./data/MathVerse_testmini.parquet"
11 | elif [[ "$benchmark" == "mathvista" ]]; then
12 | export testdata=./data/MathVista_testmini.parquet
13 | elif [[ "$benchmark" == "mathvision" ]]; then
14 | export testdata="./data/MathVision_test3040.parquet"
15 | else
16 | export testdata="./data/${benchmark}.parquet"
17 | fi
18 |
19 | export num_vllm=8
20 | export num_gpus=8
21 | export tagname=eval_debug_${benchmark}
22 | export policy=/path/to/policy
23 | export nvj_path=""
24 | export working_dir=/path/to/dir
25 | bash ./scripts/eval_vlm_new.sh
26 |
27 |
--------------------------------------------------------------------------------
/scripts/eval_vlm_new.sh:
--------------------------------------------------------------------------------
1 | set -x
2 |
3 | RAY_MASTER_NODE_ADDRESS="0.0.0.0"
4 | RAY_MASTER_NODE_PORT=$(shuf -n 1 -i 30000-65535)
5 | WORLD_SIZE=1
6 | NODE_RANK=0
7 | GPUS_PER_NODE=8
8 |
9 | MASTER_HOST="$VC_WORKER_HOSTS"
10 | MASTER_ADDR="${VC_WORKER_HOSTS%%,*}"
11 | # export NCCL_SOCKET_IFNAME=ens2f5
12 | # export GLOO_SOCKET_IFNAME=${NCCL_SOCKET_IFNAME}
13 | export NCCL_NET_PLUGIN=none
14 | export NCCL_IB_TIMEOUT=22
15 | export NCCL_IB_RETRY_CNT=15
16 | export NCCL_DEBUG=INFO
17 | export CUDA_LAUNCH_BLOCKING=1
18 | export HOST_IP=0.0.0.0
19 | export VLLM_HOST_IP=0.0.0.0
20 | export WANDB_MODE="offline"
21 | export WANDB_API_KEY="null"
22 | working_dir=${working_dir:-"/path/to/VL-Rethinker"}
23 | cd $working_dir
24 | export HF_ENDPOINT=https://hf-mirror.com
25 | nnode=$WORLD_SIZE
26 | testdata=${testdata:-"none"}
27 | num_vllm=${num_vllm:-"4"}
28 | num_gpus=${num_gpus:-"4"}
29 | tp=${tp:-"1"}
30 | actor_ngpus=${actor_ngpus:-"1"}
31 | nsamples=${nsamples:-"1"}
32 | temperature=${temperature:-"0.6"}
33 | factor=${factor:-"1"}
34 | export MIN_PIXELS=$(( 256 * 28 * 28))
35 | export MAX_PIXELS=$(( 1280 * 28 * 28))
36 | tag=${tagname} # -n${nsamples}
37 | rule_reward=${rule:-"none"}
38 | sys=${sys:-"default"}
39 | lr=${lr:-"10"}
40 | algo=${algo:-"group"}
41 | dataver=${dataver:-"none"}
42 | util=${util:-"0.7"}
43 |
44 | numref=0
45 |
46 | maxlen=${maxlen:-"8192"}
47 | policy=${policy:-"/path/to/policy"}
48 | save_name="${tag}" # rbsize 1024->256
49 | DATASET=${testdata}
50 | MODEL_CPK_NAME=${save_name}
51 | PRETRAIN_MODEL=${policy}
52 | savefolder=${savefolder:-"eval_results"}
53 | SAVE_PATH=$working_dir/${savefolder}/$save_name
54 | mkdir -p "${SAVE_PATH}"
55 |
56 | # python=/home/ma-user/anaconda3/envs/rethinker/bin/python
57 | # source /home/ma-user/anaconda3/bin/activate
58 | # conda activate rethinker
59 |
60 |
61 |
62 | post_args=()
63 | if [ $nnode -gt 1 ]; then
64 | if [ $nnode -gt 3 ]; then
65 | post_args=(--ref_num_nodes 0
66 | --ref_num_gpus_per_node 8
67 | --actor_num_nodes 16
68 | --actor_num_gpus_per_node 1
69 | --vllm_num_engines 16
70 | --vllm_tensor_parallel_size 1
71 | --micro_train_batch_size 4
72 | --train_batch_size 256
73 | --micro_rollout_batch_size 8
74 | --rollout_batch_size 1024
75 | )
76 | else
77 | post_args=(--ref_num_nodes 0
78 | --ref_num_gpus_per_node 8
79 | --actor_num_nodes 8
80 | --actor_num_gpus_per_node 1
81 | --vllm_num_engines 8
82 | --vllm_tensor_parallel_size 1
83 | --micro_train_batch_size 4
84 | --train_batch_size 256
85 | --micro_rollout_batch_size 8
86 | --rollout_batch_size 1024
87 | )
88 | fi
89 | else
90 | post_args=(--ref_num_nodes 0
91 | --ref_num_gpus_per_node 8
92 | --actor_num_nodes 0
93 | --actor_num_gpus_per_node ${actor_ngpus}
94 | --vllm_num_engines ${num_vllm}
95 | --vllm_tensor_parallel_size ${tp}
96 | --adam_offload
97 | --micro_train_batch_size 4
98 | --train_batch_size 256
99 | --micro_rollout_batch_size $(( 64 * ${num_vllm} / ${nsamples} / ${factor}))
100 | --rollout_batch_size 1024
101 | )
102 | fi
103 |
104 | LD_LIBRARY_PATH_VALUE=$nvj_path:$LD_LIBRARY_PATH
105 |
106 | RUNTIME_ENV_JSON="{\"env_vars\": {\"RAY_DEBUG\": \"legacy\", \"LD_LIBRARY_PATH\": \"$LD_LIBRARY_PATH_VALUE\"}}"
107 |
108 |
109 | ray_output=$(ray start --head --num-gpus ${num_gpus})
110 |
111 |
112 | ray status
113 | ray job submit --address="http://127.0.0.1:8265" \
114 | --runtime-env-json="$RUNTIME_ENV_JSON" \
115 | -- python3 -m openrlhf.cli.eval_ray \
116 | --vllm_enable_sleep \
117 | --vllm_gpu_memory_utilization ${util} \
118 | --vllm_sync_backend gloo \
119 | --enable_prefix_caching \
120 | --pretrain $PRETRAIN_MODEL \
121 | --save_path $SAVE_PATH \
122 | --n_samples_per_prompt ${nsamples} \
123 | --max_epochs 1 \
124 | --num_episodes 3 \
125 | --prompt_max_len 2048 \
126 | --max_samples 100000 \
127 | --generate_max_len ${maxlen} \
128 | --advantage_estimator ${algo} \
129 | --zero_stage 3 \
130 | --bf16 \
131 | --actor_learning_rate ${lr}e-7 \
132 | --rule_reward ${rule_reward} \
133 | --temperature 1.0 \
134 | --top_p 0.95 \
135 | --init_kl_coef 0.0 \
136 | --aux_loss_coef 0.05 \
137 | --entropy_loss_coef 0.0 \
138 | --prompt_data $DATASET \
139 | --input_key question \
140 | --apply_chat_template \
141 | --normalize_reward \
142 | --data_version ${dataver} \
143 | --flash_attn \
144 | --gradient_checkpointing \
145 | --ckpt_path $SAVE_PATH \
146 | --save_steps 5 \
147 | --max_ckpt_num 5 \
148 | --save_hf_ckpt \
149 | --disable_ds_ckpt \
150 | --use_wandb $WANDB_API_KEY \
151 | --wandb_run_name $save_name \
152 | --system_prompt ${sys} \
153 | --use_kl_estimator_k3 \
154 | --wandb_project vlm-rl-eval \
155 | --buffer_norm 0 \
156 | --train_vlm \
157 | --training_mode eval_only \
158 | --eval_data ${testdata} \
159 |     "${post_args[@]}"
--------------------------------------------------------------------------------
/scripts/train_vlm_multi.sh:
--------------------------------------------------------------------------------
1 | set -x
2 |
3 |
4 | find_interface() {
5 | local ip_output=$(ip addr show | head -n 10) # Limit to first 10 lines
6 | local selected_interface=""
7 |
8 | # Debug output (can be removed in final version)
9 | # echo "--- First 10 lines of ip addr show output: ---"
10 | # echo "$ip_output"
11 | # echo "--- End of ip addr show output ---"
12 |
13 | while IFS= read -r line; do
14 | # Debug output (can be removed in final version)
15 | # echo "Processing line: $line"
16 |
17 | if [[ "$line" =~ ^[0-9]+:\ ([^:]+):\ \<.*UP.*\> ]]; then
18 | local interface_name="${BASH_REMATCH[1]}"
19 | # Debug output (can be removed in final version)
20 | # echo " Interface found: $interface_name"
21 | local interface_up=true
22 | local is_loopback=false
23 |
24 | if [[ "$interface_name" == "lo" ]]; then
25 | is_loopback=true
26 | # Debug output (can be removed in final version)
27 | # echo " Interface '$interface_name' is loopback. Skipping."
28 | fi
29 |
30 | if $is_loopback; then
31 | continue # Skip loopback interface
32 | fi
33 |
34 | # Look for inet lines within this interface block
35 | while IFS= read -r subnet_line; do
36 | # Debug output (can be removed in final version)
37 | # echo " Processing subnet line: $subnet_line"
38 | if [[ "$subnet_line" =~ inet\ ([0-9]+\.[0-9]+\.[0-9]+\.[0-9]+)/([0-9]+)\ .*scope\ ([^ ]+) ]]; then
39 | local ip_address="${BASH_REMATCH[1]}"
40 | local scope="${BASH_REMATCH[3]}"
41 | # Debug output (can be removed in final version)
42 | # echo " Found inet line: IP Address: $ip_address, Scope: $scope"
43 |
44 | # Exclude loopback IPs and docker0/bridge related IPs by IP range
45 | if [[ "$ip_address" =~ ^127\. ]]; then
46 | # Debug output (can be removed in final version)
47 | # echo " IP '$ip_address' is loopback. Skipping."
48 | continue # Skip 127.0.0.0/8 loopback IPs (although 'lo' should already be skipped)
49 | elif [[ "$ip_address" =~ ^169\.254\. ]]; then
50 | # Debug output (can be removed in final version)
51 | # echo " IP '$ip_address' is link-local (169.254.x.x). Skipping."
52 | continue # Skip 169.254.0.0/16 link-local IPs (like docker0 often has)
53 | fi
54 |
55 | local is_private_ip=false
56 | if [[ "$ip_address" =~ ^10\.([0-9]{1,3}\.){2}[0-9]{1,3}$ ]] ||
57 | [[ "$ip_address" =~ ^172\.(1[6-9]|2[0-9]|3[0-1])\.([0-9]{1,3}\.){1}[0-9]{1,3}$ ]] ||
58 | [[ "$ip_address" =~ ^192\.168\.([0-9]{1,3}\.){1}[0-9]{1,3}$ ]]; then
59 | is_private_ip=true
60 | # Debug output (can be removed in final version)
61 | # echo " IP '$ip_address' is a private IP."
62 | # else
63 | # Debug output (can be removed in final version)
64 | # echo " IP '$ip_address' is NOT a private IP."
65 | fi
66 |
67 | if $is_private_ip || [[ "$scope" == "global" ]]; then # Consider private or global scope interfaces
68 | selected_interface="$interface_name"
69 | # Debug output (can be removed in final version)
70 | # echo " Interface '$interface_name' with IP '$ip_address' and scope '$scope' is selected."
71 | # echo "export GLOO_SOCKET_IFNAME=$selected_interface"
72 | # exit 0 # Exit immediately after finding the first suitable interface for debugging (removed for function)
73 | break 2 # Found a suitable interface! Break out of both inner and outer loops
74 | # else
75 | # Debug output (can be removed in final version)
76 | # echo " Interface '$interface_name' with IP '$ip_address' and scope '$scope' is NOT suitable (not private or global)."
77 | fi
78 | fi
79 | done < <(echo "$ip_output" | sed -n "/$interface_name: /,/^[0-9]\+:/p" | sed '$d' ) # Extract lines belonging to current interface block
80 | if [[ -n "$selected_interface" ]]; then # Check if selected_interface is not empty, if so, interface found and loops broken.
81 | # Debug output (can be removed in final version)
82 | # echo " Selected interface '$selected_interface' already found. Breaking outer loop."
83 | break # Already found and assigned an interface, break outer loop as well.
84 | fi
85 | # else
86 | # Debug output (can be removed in final version)
87 | # echo " Line does not match interface pattern."
88 | fi
89 | done < <(echo "$ip_output")
90 |
91 | if [[ -n "$selected_interface" ]]; then
92 | echo "$selected_interface"
93 | else
94 | echo "" # Return empty string if no interface is found, so export GLOO_SOCKET_IFNAME= (empty)
95 | # echo "No suitable network interface could be automatically identified for GLOO_SOCKET_IFNAME." # No longer print error message to stderr in function context
96 | # return 1 # Optionally, you could return a non-zero exit code if you need to check for failure.
97 | fi
98 | }
99 |
100 | MULTINODE_FLAG=True
101 | if [ -v MULTINODE_FLAG ]; then
102 |     # Split the comma-separated VC_WORKER_HOSTS list into an array
103 |     # (scoping IFS to the read call leaves the rest of the script unaffected)
104 |     IFS=',' read -ra myvar <<< "$VC_WORKER_HOSTS"
105 | 
106 |     WORLD_SIZE=${MA_NUM_HOSTS:-"1"}
107 |     # use the last listed host as the Ray master node
108 |     export RAY_MASTER_NODE_ADDRESS=${myvar[$((WORLD_SIZE-1))]}
109 | export RAY_MASTER_NODE_PORT=$(shuf -n 1 -i 30000-40000)
110 |
111 | NODE_RANK=""
112 | GPUS_PER_NODE=""
113 |
114 | else
115 | RAY_MASTER_NODE_ADDRESS="0.0.0.0"
116 | RAY_MASTER_NODE_PORT=$(shuf -n 1 -i 30000-65535)
117 | WORLD_SIZE=1
118 | NODE_RANK=0
119 | GPUS_PER_NODE=8
120 | fi
121 | MASTER_HOST="$VC_WORKER_HOSTS"
122 | MASTER_ADDR="${VC_WORKER_HOSTS%%,*}"
123 | # export NCCL_SOCKET_IFNAME=ens2f5
124 | # export GLOO_SOCKET_IFNAME=${NCCL_SOCKET_IFNAME}
125 | export NCCL_NET_PLUGIN=none
126 | export NCCL_IB_TIMEOUT=22
127 | export NCCL_IB_RETRY_CNT=15
128 | export NCCL_DEBUG=INFO
129 | export CUDA_LAUNCH_BLOCKING=1
130 |
131 | export HOST_IP=0.0.0.0
132 | export VLLM_HOST_IP=0.0.0.0
133 |
134 | working_dir=/path/to/workdir
135 | cd $working_dir
136 | export HF_ENDPOINT=https://hf-mirror.com
137 | export WANDB_API_KEY=""
138 | nnode=$WORLD_SIZE
139 | tagname=${tagname:-""}
140 | dataver=${dataver:-"none"}
141 | tag=qw-vl7b-${trainver:-"none"}-${tagname}
142 | rule_reward=${rule:-"none"}
143 | sys=${sys:-"default"}
144 | lr=${lr:-"10"}
145 | algo=${algo:-"group_sft"}
146 | temperature=${temperature:-"1.0"}
147 | numref=0
148 | fmt=${fmt:-"none"}
149 | bsz=${bsz:-"512"}
150 | rbuffer=${rbuffer:-"1024"}
151 | nsamples=${nsamples:-"8"}
152 | mbsz=${mbsz:-"4"}
153 | maxlen=${maxlen:-"6144"}
154 | lossver=${lossver:-"none"}
155 | mode=${mode:-"none"}
156 | nactor=${nactor:-"16"}
157 | nvllm=${nvllm:-"8"}
158 | filter=${filter:-"None"}
159 | repeat=${repeat:-"0"}
160 | nepoch=${nepoch:-"3"}
161 | logp_bsz=${logp_bsz:-"8"}
162 | maxtoken=${maxtoken:-"2048"}
163 | tp=${tp:-"1"}
164 | aux=${aux:-"0.05"}
165 | evalsteps=${evalsteps:-"0"}
166 | save_name="${tag}-${bsz}-lossver${lossver}-samplever${dataver}-fmt${fmt}-${algo}-n${nsamples}-ml${maxlen}-lr${lr}-sys${sys}-${nnode}node" # rbsize 1024->256
167 |
168 | DATASET=/path/to/train.parquet
169 | MODEL_CPK_NAME=${save_name}
170 | PRETRAIN_MODEL=${policy}
171 | testdata="/path/to/test.parquet"
172 | SAVE_PATH=$working_dir/saves/$save_name
173 | mkdir -p "${SAVE_PATH}"
174 | # pip install -U deepspeed==0.15.0 # https://github.com/OpenRLHF/OpenRLHF/issues/776#issuecomment-2694472824
175 | #
176 |
177 |
178 | post_args=()
179 | if [ $nnode -gt 1 ]; then
180 |
181 | post_args=(--ref_num_nodes 0
182 | --ref_num_gpus_per_node 8
183 | --actor_num_nodes ${nactor}
184 | --actor_num_gpus_per_node 8
185 | --vllm_num_engines ${nvllm}
186 | --vllm_tensor_parallel_size ${tp}
187 | --micro_train_batch_size ${mbsz}
188 | --train_batch_size ${bsz}
189 | --micro_rollout_batch_size ${logp_bsz}
190 | --rollout_batch_size ${rbuffer}
191 | )
192 |
193 | else
194 | post_args=(--ref_num_nodes 0
195 | --ref_num_gpus_per_node 8
196 | --actor_num_nodes 4
197 | --actor_num_gpus_per_node 1
198 | --vllm_num_engines 4
199 | --vllm_tensor_parallel_size 1
200 | --adam_offload
201 | --micro_train_batch_size 4
202 | --train_batch_size ${bsz}
203 | --micro_rollout_batch_size 4
204 | --rollout_batch_size ${rbuffer}
205 | )
206 | fi
207 | # :/usr/local/cuda/targets/x86_64-linux/lib
208 | LD_LIBRARY_PATH_VALUE=/path/to/nvidia/nvjitlink/lib:$LD_LIBRARY_PATH
209 | export BNB_CUDA_VERSION=122
210 | RUNTIME_ENV_JSON="{\"env_vars\": {\"LD_LIBRARY_PATH\": \"$LD_LIBRARY_PATH_VALUE\"}}"
211 |
212 |
213 | if [ "$NODE_RANK" = "0" ]; then
214 | # Start Ray head node and capture the output
215 | ray_output=$(ray start --head --num-gpus 8)
216 |
217 | # Extract the IP address using grep and sed
218 | ip_address=$(echo "$ray_output" | grep -oP "ray start --address='\K[^']+")
219 |
220 | # Write the extracted IP address to a file named "ip.txt"
221 | mkdir -p ip_tmp
222 | echo "$ip_address" > ip_tmp/ip_${tagname}.txt
223 | cat ip_tmp/ip_${tagname}.txt
224 |
225 |
226 |
227 | if [ $nnode -gt 1 ]; then
228 | # Example usage (to set the environment variable):
229 | export GLOO_SOCKET_IFNAME=$(find_interface)
230 | echo "$GLOO_SOCKET_IFNAME" > ip_tmp/gloo_${tagname}.txt
231 | sleep 60
232 | else
233 | unset GLOO_SOCKET_IFNAME
234 |     unset NCCL_SOCKET_IFNAME
235 | fi
236 | ray status
237 | ray job submit --address="http://127.0.0.1:8265" \
238 | --runtime-env-json="$RUNTIME_ENV_JSON" \
239 | -- python3 -m openrlhf.cli.train_ppo_ray \
240 | --vllm_enable_sleep \
241 | --vllm_gpu_memory_utilization 0.85 \
242 | --vllm_sync_backend gloo \
243 | --pretrain $PRETRAIN_MODEL \
244 | --save_path $SAVE_PATH \
245 | --n_samples_per_prompt ${nsamples} \
246 | --max_epochs 1 \
247 | --num_episodes ${nepoch} \
248 | --filter ${filter} \
249 | --prompt_max_len 2048 \
250 | --max_out_tokens ${maxtoken} \
251 | --max_samples 100000 \
252 | --generate_max_len ${maxlen} \
253 | --advantage_estimator ${algo} \
254 | --zero_stage 3 \
255 | --controlled_shuffle ${repeat} \
256 | --bf16 \
257 | --actor_learning_rate ${lr}e-7 \
258 | --rule_reward ${rule_reward} \
259 | --temperature 1.0 \
260 | --val_temperature 0.6 \
261 | --top_p 0.95 \
262 | --training_mode ${mode} \
263 | --init_kl_coef 0.0 \
264 | --aux_loss_coef ${aux} \
265 | --entropy_loss_coef 0.0 \
266 | --prompt_data $DATASET \
267 | --input_key question \
268 | --apply_chat_template \
269 | --normalize_reward \
270 | --flash_attn \
271 | --gradient_checkpointing \
272 | --ckpt_path $SAVE_PATH \
273 | --save_steps 3 \
274 | --eval_steps ${evalsteps} \
275 | --max_ckpt_num 3 \
276 | --save_hf_ckpt \
277 | --disable_ds_ckpt \
278 | --disable_fast_tokenizer \
279 | --use_wandb $WANDB_API_KEY \
280 | --wandb_run_name $save_name \
281 | --system_prompt ${sys} \
282 | --use_kl_estimator_k3 \
283 | --wandb_project vlm-rl \
284 | --buffer_norm 0 \
285 | --train_vlm \
286 | --filter ${filter} \
287 | --eval_data ${testdata} \
288 | --data_version ${dataver} \
289 | --loss_version ${lossver} \
290 | --format ${fmt} \
291 |     "${post_args[@]}"
292 | # --train_vlm
293 | else
294 | sleep 15
295 | # Read the IP address from the file and assign it to the variable "head_ip"
296 | head_ip=$(cat ip_tmp/ip_${tagname}.txt)
297 | gloo=$(cat ip_tmp/gloo_${tagname}.txt)
298 | export GLOO_SOCKET_IFNAME=$gloo
299 | echo "gloo: $GLOO_SOCKET_IFNAME"
300 | # Print the value of head_ip for verification
301 | echo "Head IP Address: $head_ip"
302 |
303 | ray start --address ${head_ip}
304 | # echo $HOST_IP
305 | fi
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import platform
4 |
5 | from datetime import datetime
6 | from setuptools import find_packages, setup
7 | from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
8 |
9 | _build_mode = os.getenv("OPENRLHF_BUILD_MODE", "")
10 |
11 |
12 | def _is_nightly():
13 | return _build_mode.lower() == "nightly"
14 |
15 |
16 | def _fetch_requirements(path):
17 | with open(path, "r") as fd:
18 |         return [r.strip() for r in fd if r.strip()]  # skip blank lines
19 |
20 |
21 | def _fetch_readme():
22 | with open("README.md", encoding="utf-8") as f:
23 | return f.read()
24 |
25 |
26 | def _fetch_version():
27 | with open("version.txt", "r") as f:
28 | version = f.read().strip()
29 |
30 | if _is_nightly():
31 | now = datetime.now()
32 | date_str = now.strftime("%Y%m%d")
33 | version += f".dev{date_str}"
34 |
35 | return version
36 |
37 |
38 | def _fetch_package_name():
39 | return "openrlhf-nightly" if _is_nightly() else "openrlhf"
40 |
41 |
42 | # Custom wheel class to modify the wheel name
43 | class bdist_wheel(_bdist_wheel):
44 | def finalize_options(self):
45 | _bdist_wheel.finalize_options(self)
46 | self.root_is_pure = False
47 |
48 | def get_tag(self):
49 | python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
50 | abi_tag = f"{python_version}"
51 |
52 | if platform.system() == "Linux":
53 | platform_tag = "manylinux1_x86_64"
54 | else:
55 | platform_tag = platform.system().lower()
56 |
57 | return python_version, abi_tag, platform_tag
58 |
59 |
60 | # Setup configuration
61 | setup(
62 | author="OpenRLHF Team",
63 | name=_fetch_package_name(),
64 | version=_fetch_version(),
65 | packages=find_packages(
66 | exclude=(
67 | "data",
68 | "docs",
69 | "examples",
70 | )
71 | ),
72 | description="A Ray-based High-performance RLHF framework.",
73 | long_description=_fetch_readme(),
74 | long_description_content_type="text/markdown",
75 | install_requires=_fetch_requirements("requirements.txt"),
76 | extras_require={
77 | "vllm": ["vllm==0.7.2"],
78 | "vllm_latest": ["vllm>0.7.2"],
79 | },
80 | python_requires=">=3.10",
81 | classifiers=[
82 | "Programming Language :: Python :: 3.10",
83 | "Programming Language :: Python :: 3.11",
84 | "Environment :: GPU :: NVIDIA CUDA",
85 | "Topic :: Scientific/Engineering :: Artificial Intelligence",
86 | "Topic :: System :: Distributed Computing",
87 | ],
88 | cmdclass={"bdist_wheel": bdist_wheel},
89 | )
90 |
--------------------------------------------------------------------------------