├── .github └── workflows │ └── python-package.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── README_zh.md ├── dockerfile ├── Dockerfile └── docker-entrypoint.sh ├── docs ├── lmm-r1-logo-panda.png ├── lmm-r1-logo.png ├── logo.png ├── model.jpg ├── motivation.png ├── ppo_examples.md ├── ray_architecture.png ├── sokoban_demo.gif ├── time_compare.jpg └── wandb_log_1.png ├── examples ├── data │ ├── convert_text_to_img.py │ ├── gen_sokoban_tasks.py │ ├── mathlv345_8k_chatml.json │ └── test_message.jsonl └── scripts │ ├── ckpt_ds_zero_to_universal.sh │ ├── docker_run.sh │ ├── experience_filter.py │ ├── lmm_r1 │ ├── train_direct_rl_geo.sh │ ├── train_fre_multi.sh │ ├── train_fre_text.sh │ ├── train_mgt_geo.sh │ ├── train_mgt_percereas.sh │ └── train_sokoban.sh │ ├── nvidia_docker_install.sh │ ├── reward_func.py │ ├── serve_remote_rm.sh │ ├── train_conditional_llama.sh │ ├── train_continue_pretrain_llama.sh │ ├── train_dpo_llama.sh │ ├── train_dpo_llama_34b.sh │ ├── train_dpo_ring_llama.sh │ ├── train_grpo_llama_ray.sh │ ├── train_grpo_ray_hybrid_engine.sh │ ├── train_iterative_dpo_llama.sh │ ├── train_knowledge_distillation.sh │ ├── train_kto_llama.sh │ ├── train_llama_slurm.sh │ ├── train_ppo_llama_ray.sh │ ├── train_ppo_llama_ray_70b.sh │ ├── train_ppo_llama_ray_hybrid_engine.sh │ ├── train_ppo_llama_ray_ring.sh │ ├── train_ppo_llama_ray_slurm.sh │ ├── train_ppo_llama_with_dynamic_sampling.sh │ ├── train_ppo_llama_with_remote_rm.sh │ ├── train_ppo_llama_with_reward_fn.sh │ ├── train_prm_mistral.sh │ ├── train_reinforce_baseline_llama_ray_hybrid_engine.sh │ ├── train_reinforce_llama_ray.sh │ ├── train_reinforce_llama_ray_hybrid_engine.sh │ ├── train_rejection_sampling_llama.sh │ ├── train_rm_llama.sh │ ├── train_sft_llama.sh │ └── train_sft_mixtral_lora.sh ├── openrlhf ├── __init__.py ├── cli │ ├── __init__.py │ ├── batch_inference.py │ ├── interactive_chat.py │ ├── lora_combiner.py │ ├── serve_rm.py │ ├── train_dpo.py │ ├── train_kd.py │ ├── train_kto.py │ ├── train_ppo_ray.py │ ├── train_prm.py │ ├── train_rm.py │ └── train_sft.py ├── datasets │ ├── __init__.py │ ├── process_reward_dataset.py │ ├── prompts_dataset.py │ ├── reward_dataset.py │ ├── sft_dataset.py │ ├── unpaired_preference_dataset.py │ └── utils.py ├── models │ ├── __init__.py │ ├── actor.py │ ├── lmm_kits │ │ ├── base │ │ │ ├── data_processor.py │ │ │ └── patch.py │ │ ├── gemma3 │ │ │ ├── data_processor.py │ │ │ └── patch.py │ │ ├── llm │ │ │ ├── data_processor.py │ │ │ └── patch.py │ │ ├── phi3_v │ │ │ ├── data_processor.py │ │ │ ├── patch.py │ │ │ └── src │ │ │ │ ├── configuration_phi3_v.py │ │ │ │ ├── modeling_phi3_v.py │ │ │ │ └── processing_phi3_v.py │ │ ├── phi4mm │ │ │ ├── data_processor.py │ │ │ ├── patch.py │ │ │ └── src │ │ │ │ ├── configuration_phi4mm.py │ │ │ │ ├── modeling_phi4mm.py │ │ │ │ ├── processing_phi4mm.py │ │ │ │ ├── speech_conformer_encoder.py │ │ │ │ └── vision_siglip_navit.py │ │ ├── qwen2_5_vl │ │ │ ├── data_processor.py │ │ │ └── patch.py │ │ └── utils.py │ ├── loss.py │ ├── model.py │ ├── remote_rm │ │ ├── math_verifier.py │ │ └── sokoban_verifier.py │ ├── ring_attn_utils.py │ └── utils.py ├── trainer │ ├── __init__.py │ ├── dpo_trainer.py │ ├── kd_trainer.py │ ├── kto_trainer.py │ ├── ppo_trainer.py │ ├── ppo_utils │ │ ├── __init__.py │ │ ├── experience_maker.py │ │ ├── kl_controller.py │ │ └── replay_buffer.py │ ├── prm_trainer.py │ ├── ray │ │ ├── __init__.py │ │ ├── launcher.py │ │ ├── ppo_actor.py │ │ ├── ppo_critic.py │ │ ├── utils.py │ │ ├── 
vllm_engine.py │ │ └── vllm_worker_wrap.py │ ├── rm_trainer.py │ └── sft_trainer.py └── utils │ ├── __init__.py │ ├── deepspeed │ ├── __init__.py │ ├── deepspeed.py │ └── deepspeed_utils.py │ ├── distributed_sampler.py │ ├── distributed_util.py │ ├── logging_utils.py │ ├── processor.py │ ├── remote_rm_utils.py │ └── utils.py ├── pyproject.toml ├── requirements.txt ├── setup.py └── version.txt /.github/workflows/python-package.yml: -------------------------------------------------------------------------------- 1 | name: Python package 2 | 3 | on: 4 | release: 5 | types: [published] 6 | workflow_dispatch: 7 | 8 | jobs: 9 | build-and-publish: 10 | # do not run in forks 11 | if: ${{ github.repository_owner == 'OpenRLHF' && (github.event_name == 'release' || github.event_name == 'workflow_dispatch') }} 12 | name: build wheel and upload 13 | runs-on: ubuntu-22.04 14 | 15 | strategy: 16 | matrix: 17 | python-version: [3.10.14, 3.11.0] 18 | cuda-version: [12.1] 19 | 20 | steps: 21 | - uses: actions/checkout@v3 22 | 23 | - name: Set up Python ${{ matrix.python-version }} 24 | uses: actions/setup-python@v4 25 | with: 26 | python-version: ${{ matrix.python-version }} 27 | 28 | - name: Install CUDA ${{ matrix.cuda-version }} 29 | run: | 30 | wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-ubuntu2204.pin 31 | sudo mv cuda-ubuntu2204.pin /etc/apt/preferences.d/cuda-repository-pin-600 32 | sudo apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/3bf863cc.pub 33 | sudo add-apt-repository "deb https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/ /" 34 | sudo apt-get update 35 | sudo apt-get -y install cuda-${{ matrix.cuda-version }} 36 | 37 | - name: Set up CUDA environment variables 38 | run: | 39 | echo "/usr/local/cuda-${{ matrix.cuda-version }}/lib64" | sudo tee -a /etc/ld.so.conf.d/cuda.conf 40 | echo "export PATH=/usr/local/cuda-${{ matrix.cuda-version }}/bin:\$PATH" | sudo tee -a /etc/environment 41 | sudo ldconfig 42 | shell: bash 43 | 44 | - name: Install dependencies 45 | run: | 46 | python -m pip install --upgrade pip 47 | pip install setuptools wheel twine packaging 48 | 49 | - name: Build package 50 | run: | 51 | python setup.py bdist_wheel --dist-dir=dist 52 | 53 | - name: Publish package 54 | env: 55 | TWINE_USERNAME: __token__ 56 | TWINE_PASSWORD: ${{ secrets.PYPI_API_TOKEN }} 57 | run: | 58 | python -m twine upload dist/* -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | docs/.build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mkdocs documentation 122 | /site 123 | 124 | # mypy 125 | .mypy_cache/ 126 | .dmypy.json 127 | dmypy.json 128 | 129 | # Pyre type checker 130 | .pyre/ 131 | 132 | # IDE 133 | .idea/ 134 | .vscode/ 135 | 136 | # macos 137 | *.DS_Store 138 | #data/ 139 | 140 | docs/.build 141 | 142 | # pytorch checkpoint 143 | *.pt 144 | 145 | core 146 | */ckpt/* 147 | .vscode 148 | .nfs* 149 | *jianh* 150 | *test_scripts* 151 | */checkpoint/* 152 | ckpts 153 | data 154 | !examples/data -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | default_language_version: 2 | python: python3 3 | 4 | ci: 5 | autofix_prs: true 6 | autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions' 7 | autoupdate_schedule: quarterly 8 | 9 | repos: 10 | - repo: https://github.com/pre-commit/pre-commit-hooks 11 | rev: v5.0.0 12 | hooks: 13 | - id: check-yaml 14 | - id: check-case-conflict 15 | - id: detect-private-key 16 | - id: check-added-large-files 17 | args: ['--maxkb=1000'] 18 | - id: requirements-txt-fixer 19 | 20 | - repo: https://github.com/PyCQA/autoflake 21 | rev: v2.3.1 22 | hooks: 23 | - id: autoflake 24 | args: [--remove-all-unused-imports, --in-place] 25 | 26 | - repo: https://github.com/PyCQA/isort 27 | rev: 6.0.1 28 | hooks: 29 | - id: isort 30 | name: Format imports 31 | exclude: docs/ 32 | 33 | - repo: https://github.com/psf/black 34 | rev: 25.1.0 35 | hooks: 36 | - id: black 37 | name: Format code 38 | additional_dependencies: ['click==8.0.2'] 39 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to OpenRLHF 2 | 3 | After cloning the repository, please install pre-commit 
hooks with: 4 | ``` 5 | pip install pre-commit 6 | pre-commit install 7 | ``` -------------------------------------------------------------------------------- /dockerfile/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:24.07-py3 2 | 3 | WORKDIR /app 4 | 5 | RUN set -eux && \ 6 | apt-get update && \ 7 | apt-get install -y gosu && \ 8 | rm -rf /var/lib/apt/lists/* && \ 9 | gosu nobody true 10 | 11 | RUN apt-get update && apt-get -y install sudo 12 | RUN sudo su - 13 | 14 | RUN DEBIAN_FRONTEND=noninteractive apt install -y tzdata 15 | 16 | RUN apt-get -y install build-essential git python3-dev python3-pip libopenexr-dev libxi-dev libglfw3-dev libglew-dev libomp-dev libxinerama-dev libxcursor-dev gdb 17 | RUN pip uninstall xgboost transformer_engine flash_attn pynvml opencv-python-headless -y 18 | RUN pip install vllm==0.8.3 19 | 20 | COPY docker-entrypoint.sh . 21 | RUN chmod a+x docker-entrypoint.sh 22 | 23 | ENTRYPOINT ["/app/docker-entrypoint.sh"] 24 | -------------------------------------------------------------------------------- /dockerfile/docker-entrypoint.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | USER=${LOCAL_USER:-"root"} 4 | 5 | if [[ "${USER}" != "root" ]]; then 6 | USER_ID=${LOCAL_USER_ID:-9001} 7 | echo ${USER} 8 | echo ${USER_ID} 9 | 10 | chown ${USER_ID} /home/${USER} 11 | useradd --shell /bin/bash -u ${USER_ID} -o -c "" -m ${USER} 12 | usermod -a -G root ${USER} 13 | adduser ${USER} sudo 14 | 15 | # user:password 16 | echo "${USER}:123" | chpasswd 17 | 18 | export HOME=/home/${USER} 19 | export PATH=/home/${USER}/.local/bin/:$PATH 20 | else 21 | export PATH=/root/.local/bin/:$PATH 22 | fi 23 | 24 | cd $HOME 25 | exec gosu ${USER} "$@" -------------------------------------------------------------------------------- /docs/lmm-r1-logo-panda.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TideDra/lmm-r1/f917f186dc327fb2dadfe54f56c0bf409f9414e2/docs/lmm-r1-logo-panda.png -------------------------------------------------------------------------------- /docs/lmm-r1-logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TideDra/lmm-r1/f917f186dc327fb2dadfe54f56c0bf409f9414e2/docs/lmm-r1-logo.png -------------------------------------------------------------------------------- /docs/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TideDra/lmm-r1/f917f186dc327fb2dadfe54f56c0bf409f9414e2/docs/logo.png -------------------------------------------------------------------------------- /docs/model.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TideDra/lmm-r1/f917f186dc327fb2dadfe54f56c0bf409f9414e2/docs/model.jpg -------------------------------------------------------------------------------- /docs/motivation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TideDra/lmm-r1/f917f186dc327fb2dadfe54f56c0bf409f9414e2/docs/motivation.png -------------------------------------------------------------------------------- /docs/ppo_examples.md: -------------------------------------------------------------------------------- 1 | ### Training settings 2 | 3 | * Base model: LLaMA2 7B 4 | * SFT dataset: Orca
500k 5 | * RM datasets: Anthropic HH, LMSys, OASST, WebGPT ~ 240K+ 6 | * PPO prompts: alpaca, Anthropic HH, OASST ~ 100K 7 | 8 | ### Question: Please introduce the game GTA5 9 | 10 | **SFT** 11 | 12 | Human: Please introduce the game GTA5 13 | 14 | Assistant: GTA5, or Grand Theft Auto V, is an action-adventure video game developed by Rockstar North and published by Rockstar Games. It was released in 2013 for the PlayStation 3 and Xbox 360, and later in 2014 for the PlayStation 4 and Xbox One. The game is set in the fictional city of Los Santos, based in Los Angeles, and follows the story of three criminals as they plan and execute heists. 15 | 16 | GTA5 is an open-world game, meaning that players can explore the game's vast and detailed world at their leisure, completing missions, engaging in side activities, and interacting with other characters. The game features an extensive array of vehicles, from cars to boats to planes, and players can use them to travel around the city and engage in various forms of criminal activity. 17 | 18 | GTA5 is known for its immersive gameplay, stunning graphics, and rich storytelling. It has received critical acclaim and has become one of the best-selling video games of all time, with over 150 million copies sold worldwide. 19 | 20 | **PPO** 21 | 22 | Human: Please introduce the game GTA5 23 | 24 | Assistant: 1. **Game Name:** Grand Theft Auto V (GTA V) 25 | 26 | 2. **Platforms:** GTA V is available on PlayStation 3 and Xbox 360, and later released on PlayStation 4, Xbox One, and Microsoft Windows. 27 | 28 | 3. **Genre:** Action-adventure game 29 | 30 | 4. **Gameplay:** GTA V is an open-world game where players can explore the fictional city of Los Santos and its surrounding areas, completing missions, engaging in side activities, and interacting with non-player characters. The game features a third-person perspective and allows players to freely roam the game world, engage in combat, and complete objectives. 31 | 32 | 5. **Storyline:** The game follows the story of three protagonists, Michael De Santa, Franklin Clinton, and Trevor Philips, as they become embroiled in a series of heists and other criminal activities. The game's storyline is considered one of the strongest aspects of GTA 33 | 34 | 6. **Graphics and Sound:** GTA V features stunning graphics and realistic character animations, making the game world feel alive and immersive. The soundtrack and voice acting are also highly praised, with a diverse and engaging musical score that complements the game's action-packed gameplay. 35 | 36 | 7. **Reception:** GTA V received critical acclaim upon its release, with many praising its storyline, gameplay, graphics, and sound. It is considered one of the best video games of all time and has sold over 150 million copies worldwide, making it one of the best-selling video games of all time. 37 | 38 | In summary, GTA V is an action-adventure game set in a massive open world, with a complex and engaging storyline, stunning graphics, and highly praised soundtrack and voice acting. It is considered one of the greatest video games of all time and continues to be a popular choice among gamers worldwide. 39 | 40 | 41 | ### Question: Which GPU is better, AMD or NVIDIA? 42 | 43 | **SFT** 44 | 45 | Human: Which GPU is better, AMD or NVIDIA? 
46 | 47 | Assistant: 48 | 49 | The performance and suitability of a GPU (Graphics Processing Unit) depend on various factors, including the specific task or application being used, the model and generation of the GPU, and the system it is being used in. Both AMD and NVIDIA are well-known and reputable manufacturers of GPUs, and they both have a wide range of models and generations. 50 | 51 | In general, the performance of a GPU can vary depending on the specific task or application being used. For example, AMD's GPUs are known for their performance in gaming and video editing, while NVIDIA's GPUs are known for their performance in machine learning and deep learning. 52 | 53 | It is important to note that the performance of a GPU can also vary depending on the generation of the GPU. Newer generations of GPUs tend to have better performance and features than older generations. 54 | 55 | Ultimately, the best GPU for you will depend on your specific needs and the tasks you plan to use it for. It is recommended to research and compare the specific models and generations of AMD and NVIDIA GPUs to determine which one is the best fit for your needs. 56 | 57 | **PPO** 58 | 59 | Human: Which GPU is better, AMD or NVIDIA? 60 | 61 | Assistant: 1. **Performance:** In terms of raw performance, NVIDIA typically outperforms AMD in most cases, especially in gaming and video editing. 62 | 63 | 2. **Price:** NVIDIA GPUs tend to be more expensive than AMD GPUs, although this can vary depending on the specific model and its features. 64 | 65 | 3. **Architecture:** NVIDIA's GPU architecture is based on the CUDA (Compute Unified Device Architecture) platform, which allows for more efficient processing of parallel computing tasks. AMD's GPU architecture, on the other hand, is based on the Stream Compute architecture, which is designed for general-purpose computing. 66 | 67 | 4. **Application Support:** NVIDIA's GPUs have better support for certain applications, particularly in the fields of deep learning and artificial intelligence. AMD's GPUs, however, have better support for certain games and video editing software. 68 | 69 | In summary, NVIDIA's GPUs tend to be more powerful and efficient for certain tasks, while AMD's GPUs are better suited for other tasks and have better support for certain applications. Ultimately, the choice between the two depends on the specific needs and preferences of the user. 
70 | -------------------------------------------------------------------------------- /docs/ray_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TideDra/lmm-r1/f917f186dc327fb2dadfe54f56c0bf409f9414e2/docs/ray_architecture.png -------------------------------------------------------------------------------- /docs/sokoban_demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TideDra/lmm-r1/f917f186dc327fb2dadfe54f56c0bf409f9414e2/docs/sokoban_demo.gif -------------------------------------------------------------------------------- /docs/time_compare.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TideDra/lmm-r1/f917f186dc327fb2dadfe54f56c0bf409f9414e2/docs/time_compare.jpg -------------------------------------------------------------------------------- /docs/wandb_log_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TideDra/lmm-r1/f917f186dc327fb2dadfe54f56c0bf409f9414e2/docs/wandb_log_1.png -------------------------------------------------------------------------------- /examples/data/gen_sokoban_tasks.py: -------------------------------------------------------------------------------- 1 | import gym 2 | import gym_sokoban 3 | import numpy as np 4 | import os 5 | from PIL import Image 6 | from tqdm import tqdm 7 | 8 | 9 | def save_level_and_initial_frame(env, level_index, env_name): 10 | # Save the generated level (map) 11 | level_data = env.unwrapped.room_state 12 | np.save(f'sokoban_levels/{env_name}-level_{level_index}.npy', level_data) 13 | 14 | # Save the initial frame as an image 15 | initial_frame = env.render(mode='rgb_array') 16 | img = Image.fromarray(initial_frame) 17 | w,h = img.size 18 | new_h = 640 19 | new_w = int(new_h/h*w) 20 | img = img.resize((new_w, new_h)) 21 | img.save(f'sokoban_initial_frames/{env_name}-frame_{level_index}.png') 22 | 23 | def main(): 24 | env_name = 'Sokoban-small-v0' 25 | num_tasks = 5000 26 | os.makedirs('sokoban_levels', exist_ok=True) 27 | os.makedirs('sokoban_initial_frames', exist_ok=True) 28 | 29 | for i in tqdm(range(num_tasks)): 30 | env = gym.make(env_name) 31 | env.reset() 32 | save_level_and_initial_frame(env, i, env_name) 33 | env.close() 34 | 35 | env_name = 'Sokoban-small-v1' 36 | num_tasks = 5000 37 | os.makedirs('sokoban_levels', exist_ok=True) 38 | os.makedirs('sokoban_initial_frames', exist_ok=True) 39 | 40 | for i in tqdm(range(num_tasks)): 41 | env = gym.make(env_name) 42 | env.reset() 43 | save_level_and_initial_frame(env, i, env_name) 44 | env.close() 45 | 46 | if __name__ == "__main__": 47 | main() -------------------------------------------------------------------------------- /examples/data/test_message.jsonl: -------------------------------------------------------------------------------- 1 | {"question":"Which number should be written in place of the question mark?\n","answer":"$\\boxed{60}$","message":"[{\"role\": \"system\", \"content\": \"Solve the question. The user asks a question, and you solves it. You first thinks about the reasoning process in the mind and then provides the user with the answer. The answer is in latex format and wrapped in $...$. The final answer must be wrapped using the \\\\boxed{} command. 
The reasoning process and answer are enclosed within <\/think> and <\/answer> tags, respectively, i.e., Since $1+1=2$, so the answer is $2$. <\/think> The answer is $\\\\boxed{2}$ <\/answer>, which means assistant's output should start with and end with <\/answer>.\\n\"}, {\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": \"Which number should be written in place of the question mark?\\n\"}, {\"type\": \"image\", \"image\": \"\/root\/projects\/OpenRLHF\/data\/test\/\/images\/1.jpg\"}]}]"} 2 | {"question":"Which bike is most expensive?\n\nA. A\nB. B\nC. C\nD. D\nE. E","answer":"$\\boxed{A}$","message":"[{\"role\": \"system\", \"content\": \"Solve the question. The user asks a question, and you solves it. You first thinks about the reasoning process in the mind and then provides the user with the answer. The answer is in latex format and wrapped in $...$. The final answer must be wrapped using the \\\\boxed{} command. The reasoning process and answer are enclosed within <\/think> and <\/answer> tags, respectively, i.e., Since $1+1=2$, so the answer is $2$. <\/think> The answer is $\\\\boxed{2}$ <\/answer>, which means assistant's output should start with and end with <\/answer>.\\n\"}, {\"role\": \"user\", \"content\": [{\"type\": \"text\", \"text\": \"Which bike is most expensive?\\n\\nA. A\\nB. B\\nC. C\\nD. D\\nE. E\"}, {\"type\": \"image\", \"image\": \"\/root\/projects\/OpenRLHF\/data\/test\/\/images\/2.jpg\"}]}]"} -------------------------------------------------------------------------------- /examples/scripts/ckpt_ds_zero_to_universal.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Ensure at least one argument is provided. 4 | if [ "$#" -lt 1 ]; then 5 | echo "This script converts the latest DeepSpeed ZeRO checkpoint to a universal checkpoint." 6 | echo "Usage: $0 [additional arguments for deepspeed.checkpoint.ds_to_universal]" 7 | exit 1 8 | fi 9 | 10 | # Set CKPT_PATH to the first argument and shift it out so that "$@" contains the extra arguments. 11 | CKPT_PATH="$1" 12 | shift 13 | EXTRA_ARGS="$@" 14 | 15 | # Function to process a given directory. 16 | process_dir() { 17 | local path="$1" 18 | echo "Processing checkpoint: $path" 19 | 20 | # Check if the latest tag exists. 21 | if [ ! -f "$path/latest" ]; then 22 | echo "latest tag file not found in $path, ensure the directory contains a valid DeepSpeed ZeRO checkpoint." 23 | return 1 24 | fi 25 | 26 | # Read the latest tag. 27 | LATEST_TAG=$(cat "$path/latest") 28 | LATEST_UNI_TAG="${LATEST_TAG}_uni" 29 | 30 | # Write the universal tag. 31 | echo "$LATEST_UNI_TAG" > "$path/latest_universal" 32 | 33 | # Run the python command with any additional arguments. 34 | python -m deepspeed.checkpoint.ds_to_universal --inject_missing_state \ 35 | --input_folder "$path/$LATEST_TAG" \ 36 | --output_folder "$path/$LATEST_UNI_TAG" \ 37 | $EXTRA_ARGS 38 | } 39 | 40 | # Flag to check if at least one of the specific subdirectories exists. 41 | found_subdir=0 42 | 43 | ## For PPO, checkpoints for each model are stored under "_actor" and "_critic" separately. 44 | # Check for the subdirectory named exactly "_actor". 45 | if [ -d "$CKPT_PATH/_actor" ]; then 46 | process_dir "$CKPT_PATH/_actor" 47 | found_subdir=1 48 | fi 49 | 50 | # Check for the subdirectory named exactly "_critic". 51 | if [ -d "$CKPT_PATH/_critic" ]; then 52 | process_dir "$CKPT_PATH/_critic" 53 | found_subdir=1 54 | fi 55 | 56 | # If neither subdirectory exists, process the main CKPT_PATH. 
57 | if [ "$found_subdir" -eq 0 ]; then 58 | process_dir "$CKPT_PATH" 59 | fi 60 | -------------------------------------------------------------------------------- /examples/scripts/docker_run.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | PROJECT_PATH=$(cd $(dirname $0)/../../; pwd) 4 | IMAGE_NAME="nvcr.io/nvidia/pytorch:24.07-py3" 5 | 6 | docker run --runtime=nvidia -it --rm --shm-size="10g" --cap-add=SYS_ADMIN \ 7 | -v $PROJECT_PATH:/openrlhf -v $HOME/.cache:/root/.cache -v $HOME/.bash_history2:/root/.bash_history \ 8 | $IMAGE_NAME bash -------------------------------------------------------------------------------- /examples/scripts/experience_filter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def experience_filter(experience_maker,experiences): 4 | return experiences[: len(experiences) // 2] -------------------------------------------------------------------------------- /examples/scripts/lmm_r1/train_direct_rl_geo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # =================== User Configuration =================== 3 | # Please modify these variables according to your environment 4 | # ========================================================= 5 | 6 | # Base paths - MODIFY THESE 7 | export WORKSPACE_DIR="$(pwd)" # Path to project root directory 8 | export DATASET_PATH="./data/MathV360k/geo_split.jsonl" # Path to your dataset 9 | export PRETRAIN_MODEL_PATH="./models/Qwen2.5-VL-3B-Instruct" # Path to pretrained model 10 | export SAVE_PATH="/checkpoints" # Absolute path to save checkpoints 11 | 12 | # Model configuration 13 | export MODEL_NAME="lmm-r1-direct-rl-geo" # Name for this training run 14 | 15 | # Wandb configuration (optional) 16 | export WANDB_DIR="${WORKSPACE_DIR}" # Directory for wandb files 17 | export WANDB_API_KEY="YOUR_WANDB_API_KEY" # Your wandb API key (if online) 18 | 19 | # =================== Script Execution =================== 20 | # You shouldn't need to modify anything below this line 21 | # ====================================================== 22 | 23 | # Get script PID and setup directories 24 | SCRIPT_PID=$$ 25 | export TIMESTAMP=$(date +%Y%m%d_%H%M%S) 26 | export LOG_DIR="${SAVE_PATH}/${MODEL_NAME}/logs" 27 | export CUR_LOG_DIR="${LOG_DIR}/${TIMESTAMP}" 28 | 29 | # Stop any existing ray processes 30 | ray stop 31 | 32 | # Create necessary directories 33 | mkdir -p "${SAVE_PATH}/${MODEL_NAME}" 34 | mkdir -p "${LOG_DIR}" 35 | mkdir -p "${CUR_LOG_DIR}" 36 | 37 | # Print help information 38 | echo "================================================================" 39 | echo "LMM-R1 Direct RL Geometry Training" 40 | echo "================================================================" 41 | echo "Model name: ${MODEL_NAME}" 42 | echo "Dataset: ${DATASET_PATH}" 43 | echo "Pretrained model: ${PRETRAIN_MODEL_PATH}" 44 | echo "Logs will be saved to: ${CUR_LOG_DIR}" 45 | echo 46 | echo "To monitor logs:" 47 | echo " tail -f ${CUR_LOG_DIR}/train.log" 48 | echo 49 | echo "================================================================" 50 | 51 | # Start ray 52 | echo "Starting ray..." 53 | ray start --head --node-ip-address 0.0.0.0 --num-gpus 8 --temp-dir ~/.cache/ray 54 | 55 | # Start remote reward model server 56 | echo "Starting remote reward model server..." 
57 | python -m openrlhf.models.remote_rm.math_verifier \ 58 | --dataset "${DATASET_PATH}" \ 59 | --input_key message \ 60 | --prompt-template chatml 2>&1 | tee -a "${CUR_LOG_DIR}/remote_rm.log" & 61 | REMOTE_RM_PID=$! 62 | 63 | # Start training 64 | echo "Starting training..." 65 | ray job submit --address="http://127.0.0.1:8265" \ 66 | --runtime-env-json="{\"working_dir\": \"${WORKSPACE_DIR}\",\"env_vars\":{\"VLLM_USE_V1\":\"1\",\"VLLM_ENABLE_V1_MULTIPROCESSING\":\"0\"}}" \ 67 | -- python -m openrlhf.cli.train_ppo_ray \ 68 | --ref_num_nodes 1 \ 69 | --ref_num_gpus_per_node 8 \ 70 | --remote_rm_url http://127.0.0.1:5000/get_reward \ 71 | --actor_num_nodes 1 \ 72 | --actor_num_gpus_per_node 8 \ 73 | --critic_num_nodes 1 \ 74 | --critic_num_gpus_per_node 8 \ 75 | --vllm_num_engines 8 \ 76 | --vllm_tensor_parallel_size 1 \ 77 | --colocate_all_models \ 78 | --vllm_enable_sleep \ 79 | --vllm_gpu_memory_utilization 0.5 \ 80 | --vllm_sync_backend gloo \ 81 | --enable_prefix_caching \ 82 | --pretrain ${PRETRAIN_MODEL_PATH} \ 83 | --save_path ${SAVE_PATH}/${MODEL_NAME} \ 84 | --micro_train_batch_size 2 \ 85 | --train_batch_size 256 \ 86 | --micro_rollout_batch_size 2 \ 87 | --rollout_batch_size 256 \ 88 | --temperature 1.0 \ 89 | --n_samples_per_prompt 16 \ 90 | --max_epochs 1 \ 91 | --num_episodes 2 \ 92 | --prompt_max_len 4096 \ 93 | --max_samples 100000 \ 94 | --generate_max_len 4096 \ 95 | --advantage_estimator gae \ 96 | --zero_stage 3 \ 97 | --bf16 \ 98 | --actor_learning_rate 1e-6 \ 99 | --init_kl_coef 0.001 \ 100 | --prompt_data ${DATASET_PATH} \ 101 | --input_key message \ 102 | --label_key "answer" \ 103 | --normalize_reward \ 104 | --flash_attn \ 105 | --lambd 1 \ 106 | --gamma 1 \ 107 | --gradient_checkpointing \ 108 | --save_steps 20 \ 109 | --ckpt_path ${SAVE_PATH}/${MODEL_NAME}/ckpt \ 110 | --save_hf_ckpt \ 111 | --load_checkpoint \ 112 | --use_wandb ${WANDB_API_KEY} \ 113 | --wandb_run_name ${MODEL_NAME} \ 114 | --wandb_group "lmm-r1-training" \ 115 | --use_tensorboard ${LOG_DIR} > >(tee -a "${CUR_LOG_DIR}/train.log") 2>&1 & 116 | 117 | TRAIN_PID=$! 118 | 119 | # Record process IDs 120 | echo "Remote RM PID: $REMOTE_RM_PID" > "${CUR_LOG_DIR}/process_pids.txt" 121 | echo "Train PID: $TRAIN_PID" >> "${CUR_LOG_DIR}/process_pids.txt" 122 | 123 | # Wait for training to complete 124 | echo "Training is running in the background. 
Check logs at ${CUR_LOG_DIR}/train.log" 125 | echo "To attach to the training process: wait $TRAIN_PID" 126 | 127 | # Uncomment to wait for training to complete before exiting 128 | # wait $TRAIN_PID 129 | 130 | # Cleanup instructions 131 | echo "When finished, clean up with:" 132 | echo "pkill -f openrlhf" 133 | echo "ray stop" 134 | echo "All logs are available in ${CUR_LOG_DIR}" -------------------------------------------------------------------------------- /examples/scripts/lmm_r1/train_fre_multi.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # =================== User Configuration =================== 3 | # Please modify these variables according to your environment 4 | # ========================================================= 5 | 6 | # Base paths - MODIFY THESE 7 | export WORKSPACE_DIR="$(pwd)" # Path to project root directory 8 | export DATASET_PATH="./data/VerMulti/mathv60k_message.jsonl" # Path to your dataset 9 | export PRETRAIN_MODEL_PATH="./models/Qwen2.5-VL-3B-Instruct" # Path to pretrained model 10 | export SAVE_PATH="/checkpoints" # Absolute path to save checkpoints 11 | 12 | # Model configuration 13 | export MODEL_NAME="lmm-r1-fre-multi" # Name for this training run 14 | 15 | # Wandb configuration (optional) 16 | export WANDB_DIR="${WORKSPACE_DIR}" # Directory for wandb files 17 | export WANDB_API_KEY="YOUR_WANDB_API_KEY" # Your wandb API key (if online) 18 | 19 | # =================== Script Execution =================== 20 | # You shouldn't need to modify anything below this line 21 | # ====================================================== 22 | 23 | # Get script PID and setup directories 24 | SCRIPT_PID=$$ 25 | export TIMESTAMP=$(date +%Y%m%d_%H%M%S) 26 | export LOG_DIR="${SAVE_PATH}/${MODEL_NAME}/logs" 27 | export CUR_LOG_DIR="${LOG_DIR}/${TIMESTAMP}" 28 | 29 | # Stop any existing ray processes 30 | ray stop 31 | 32 | # Create necessary directories 33 | mkdir -p "${SAVE_PATH}/${MODEL_NAME}" 34 | mkdir -p "${LOG_DIR}" 35 | mkdir -p "${CUR_LOG_DIR}" 36 | 37 | # Print help information 38 | echo "================================================================" 39 | echo "LMM-R1 FRE-Multi Training" 40 | echo "================================================================" 41 | echo "Model name: ${MODEL_NAME}" 42 | echo "Dataset: ${DATASET_PATH}" 43 | echo "Pretrained model: ${PRETRAIN_MODEL_PATH}" 44 | echo "Logs will be saved to: ${CUR_LOG_DIR}" 45 | echo 46 | echo "To monitor logs:" 47 | echo " tail -f ${CUR_LOG_DIR}/train.log" 48 | echo 49 | echo "================================================================" 50 | 51 | # Start ray 52 | echo "Starting ray..." 53 | ray start --head --node-ip-address 0.0.0.0 --num-gpus 8 --temp-dir ~/.cache/ray 54 | 55 | # Start remote reward model server 56 | echo "Starting remote reward model server..." 57 | python -m openrlhf.models.remote_rm.math_verifier \ 58 | --dataset "${DATASET_PATH}" \ 59 | --input_key message \ 60 | --prompt-template chatml 2>&1 | tee -a "${CUR_LOG_DIR}/remote_rm.log" & 61 | REMOTE_RM_PID=$! 62 | 63 | # Start training 64 | echo "Starting training..." 
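# Note (added annotation) on the colocation flags used below, as configured in this script:
# --colocate_all_models schedules the actor, reference, and critic workers together with the
# vLLM engines on the same 8 GPUs, and --vllm_enable_sleep with
# --vllm_gpu_memory_utilization 0.5 is intended to let the vLLM engines release most of their
# GPU memory while the training phase runs, so rollout and training can share the GPUs.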
65 | ray job submit --address="http://127.0.0.1:8265" \ 66 | --runtime-env-json="{\"working_dir\": \"${WORKSPACE_DIR}\",\"env_vars\":{\"VLLM_USE_V1\":\"1\",\"VLLM_ENABLE_V1_MULTIPROCESSING\":\"0\"}}" \ 67 | -- python -m openrlhf.cli.train_ppo_ray \ 68 | --ref_num_nodes 1 \ 69 | --ref_num_gpus_per_node 8 \ 70 | --remote_rm_url http://127.0.0.1:5000/get_reward \ 71 | --actor_num_nodes 1 \ 72 | --actor_num_gpus_per_node 8 \ 73 | --critic_num_nodes 1 \ 74 | --critic_num_gpus_per_node 8 \ 75 | --vllm_num_engines 8 \ 76 | --vllm_tensor_parallel_size 1 \ 77 | --colocate_all_models \ 78 | --vllm_enable_sleep \ 79 | --vllm_gpu_memory_utilization 0.5 \ 80 | --vllm_sync_backend gloo \ 81 | --enable_prefix_caching \ 82 | --pretrain ${PRETRAIN_MODEL_PATH} \ 83 | --save_path ${SAVE_PATH}/${MODEL_NAME} \ 84 | --micro_train_batch_size 2 \ 85 | --train_batch_size 256 \ 86 | --micro_rollout_batch_size 2 \ 87 | --rollout_batch_size 256 \ 88 | --temperature 1.0 \ 89 | --n_samples_per_prompt 16 \ 90 | --max_epochs 1 \ 91 | --num_episodes 2 \ 92 | --prompt_max_len 4096 \ 93 | --max_samples 100000 \ 94 | --generate_max_len 4096 \ 95 | --advantage_estimator reinforce_baseline \ 96 | --zero_stage 3 \ 97 | --bf16 \ 98 | --actor_learning_rate 1e-6 \ 99 | --init_kl_coef 0.001 \ 100 | --prompt_data ${DATASET_PATH} \ 101 | --input_key message \ 102 | --label_key "answer" \ 103 | --normalize_reward \ 104 | --flash_attn \ 105 | --lambd 1 \ 106 | --gamma 1 \ 107 | --gradient_checkpointing \ 108 | --save_steps 20 \ 109 | --ckpt_path ${SAVE_PATH}/${MODEL_NAME}/ckpt \ 110 | --save_hf_ckpt \ 111 | --load_checkpoint \ 112 | --use_wandb ${WANDB_API_KEY} \ 113 | --wandb_run_name ${MODEL_NAME} \ 114 | --wandb_group "lmm-r1-training" \ 115 | --use_tensorboard ${LOG_DIR} > >(tee -a "${CUR_LOG_DIR}/train.log") 2>&1 & 116 | 117 | TRAIN_PID=$! 118 | 119 | # Record process IDs 120 | echo "Remote RM PID: $REMOTE_RM_PID" > "${CUR_LOG_DIR}/process_pids.txt" 121 | echo "Train PID: $TRAIN_PID" >> "${CUR_LOG_DIR}/process_pids.txt" 122 | 123 | # Wait for training to complete 124 | echo "Training is running in the background. 
Check logs at ${CUR_LOG_DIR}/train.log" 125 | echo "To attach to the training process: wait $TRAIN_PID" 126 | 127 | # Uncomment to wait for training to complete before exiting 128 | # wait $TRAIN_PID 129 | 130 | # Cleanup instructions 131 | echo "When finished, clean up with:" 132 | echo "pkill -f openrlhf" 133 | echo "ray stop" 134 | echo "All logs are available in ${CUR_LOG_DIR}" -------------------------------------------------------------------------------- /examples/scripts/lmm_r1/train_fre_text.sh: -------------------------------------------------------------------------------- 1 | # Download the datasets 2 | 3 | #!/bin/bash 4 | # =================== User Configuration =================== 5 | # Please modify these variables according to your environment 6 | # ========================================================= 7 | 8 | # Base paths - MODIFY THESE 9 | export WORKSPACE_DIR="$(pwd)" # Path to project root directory 10 | export DATASET_PATH="./data/deepscaler/deepscaler_message.jsonl" # Path to your dataset 11 | export PRETRAIN_MODEL_PATH="./models/Qwen2.5-VL-3B-Instruct" # Path to pretrained model 12 | export SAVE_PATH="/checkpoints" # Absolute path to save checkpoints 13 | 14 | # Model configuration 15 | export MODEL_NAME="lmm-r1-fre-text" # Name for this training run 16 | 17 | # Wandb configuration (optional) 18 | export WANDB_DIR="${WORKSPACE_DIR}" # Directory for wandb files 19 | export WANDB_API_KEY="YOUR_WANDB_API_KEY" # Your wandb API key (if online) 20 | 21 | # =================== Script Execution =================== 22 | # You shouldn't need to modify anything below this line 23 | # ====================================================== 24 | 25 | # Get script PID and setup directories 26 | SCRIPT_PID=$$ 27 | export TIMESTAMP=$(date +%Y%m%d_%H%M%S) 28 | export LOG_DIR="${SAVE_PATH}/${MODEL_NAME}/logs" 29 | export CUR_LOG_DIR="${LOG_DIR}/${TIMESTAMP}" 30 | 31 | # Stop any existing ray processes 32 | ray stop 33 | 34 | # Create necessary directories 35 | mkdir -p "${SAVE_PATH}/${MODEL_NAME}" 36 | mkdir -p "${LOG_DIR}" 37 | mkdir -p "${CUR_LOG_DIR}" 38 | 39 | # Print help information 40 | echo "================================================================" 41 | echo "LMM-R1 FRE-Text Training" 42 | echo "================================================================" 43 | echo "Model name: ${MODEL_NAME}" 44 | echo "Dataset: ${DATASET_PATH}" 45 | echo "Pretrained model: ${PRETRAIN_MODEL_PATH}" 46 | echo "Logs will be saved to: ${CUR_LOG_DIR}" 47 | echo 48 | echo "To monitor logs:" 49 | echo " tail -f ${CUR_LOG_DIR}/train.log" 50 | echo 51 | echo "================================================================" 52 | 53 | # Start ray 54 | echo "Starting ray..." 55 | ray start --head --node-ip-address 0.0.0.0 --num-gpus 8 --temp-dir ~/.cache/ray 56 | 57 | # Start remote reward model server 58 | echo "Starting remote reward model server..." 59 | python -m openrlhf.models.remote_rm.math_verifier \ 60 | --dataset "${DATASET_PATH}" \ 61 | --input_key message \ 62 | --prompt-template chatml 2>&1 | tee -a "${CUR_LOG_DIR}/remote_rm.log" & 63 | REMOTE_RM_PID=$! 64 | 65 | # Start training 66 | echo "Starting training..." 
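# Note (added annotation): this text-only FRE stage uses the reinforce_baseline advantage
# estimator and a smaller actor learning rate (4e-7) than train_fre_multi.sh (1e-6). The
# checkpoint it writes under ${SAVE_PATH}/${MODEL_NAME} (lmm-r1-fre-text) is the model that
# train_mgt_geo.sh and train_mgt_percereas.sh expect as their PRETRAIN_MODEL_PATH.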
67 | ray job submit --address="http://127.0.0.1:8265" \ 68 | --runtime-env-json="{\"working_dir\": \"${WORKSPACE_DIR}\",\"env_vars\":{\"VLLM_USE_V1\":\"1\",\"VLLM_ENABLE_V1_MULTIPROCESSING\":\"0\"}}" \ 69 | -- python -m openrlhf.cli.train_ppo_ray \ 70 | --ref_num_nodes 1 \ 71 | --ref_num_gpus_per_node 8 \ 72 | --remote_rm_url http://127.0.0.1:5000/get_reward \ 73 | --actor_num_nodes 1 \ 74 | --actor_num_gpus_per_node 8 \ 75 | --critic_num_nodes 1 \ 76 | --critic_num_gpus_per_node 8 \ 77 | --vllm_num_engines 8 \ 78 | --vllm_tensor_parallel_size 1 \ 79 | --colocate_all_models \ 80 | --vllm_enable_sleep \ 81 | --vllm_gpu_memory_utilization 0.5 \ 82 | --vllm_sync_backend gloo \ 83 | --enable_prefix_caching \ 84 | --pretrain ${PRETRAIN_MODEL_PATH} \ 85 | --save_path ${SAVE_PATH}/${MODEL_NAME} \ 86 | --micro_train_batch_size 2 \ 87 | --train_batch_size 256 \ 88 | --micro_rollout_batch_size 2 \ 89 | --rollout_batch_size 256 \ 90 | --temperature 1.0 \ 91 | --n_samples_per_prompt 16 \ 92 | --max_epochs 1 \ 93 | --num_episodes 2 \ 94 | --prompt_max_len 4096 \ 95 | --max_samples 100000 \ 96 | --generate_max_len 4096 \ 97 | --advantage_estimator reinforce_baseline \ 98 | --zero_stage 3 \ 99 | --bf16 \ 100 | --actor_learning_rate 4e-7 \ 101 | --init_kl_coef 0.001 \ 102 | --prompt_data ${DATASET_PATH} \ 103 | --input_key message \ 104 | --label_key "answer" \ 105 | --normalize_reward \ 106 | --flash_attn \ 107 | --lambd 1 \ 108 | --gamma 1 \ 109 | --gradient_checkpointing \ 110 | --save_steps 20 \ 111 | --ckpt_path ${SAVE_PATH}/${MODEL_NAME}/ckpt \ 112 | --save_hf_ckpt \ 113 | --load_checkpoint \ 114 | --use_wandb ${WANDB_API_KEY} \ 115 | --wandb_run_name ${MODEL_NAME} \ 116 | --wandb_group "lmm-r1-training" \ 117 | --use_tensorboard ${LOG_DIR} > >(tee -a "${CUR_LOG_DIR}/train.log") 2>&1 & 118 | 119 | TRAIN_PID=$! 120 | 121 | # Record process IDs 122 | echo "Remote RM PID: $REMOTE_RM_PID" > "${CUR_LOG_DIR}/process_pids.txt" 123 | echo "Train PID: $TRAIN_PID" >> "${CUR_LOG_DIR}/process_pids.txt" 124 | 125 | # Wait for training to complete 126 | echo "Training is running in the background. 
Check logs at ${CUR_LOG_DIR}/train.log" 127 | echo "To attach to the training process: wait $TRAIN_PID" 128 | 129 | # Uncomment to wait for training to complete before exiting 130 | # wait $TRAIN_PID 131 | 132 | # Cleanup instructions 133 | echo "When finished, clean up with:" 134 | echo "pkill -f openrlhf" 135 | echo "ray stop" 136 | echo "All logs are available in ${CUR_LOG_DIR}" -------------------------------------------------------------------------------- /examples/scripts/lmm_r1/train_mgt_geo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # =================== User Configuration =================== 3 | # Please modify these variables according to your environment 4 | # ========================================================= 5 | 6 | # Base paths - MODIFY THESE 7 | export WORKSPACE_DIR="$(pwd)" # Path to project root directory 8 | export DATASET_PATH="./data/VerMulti/mathv_geo_message.jsonl" # Path to your dataset 9 | export PRETRAIN_MODEL_PATH="./checkpoints/lmm-r1-fre-text/" # Path to pretrained model 10 | export SAVE_PATH="/checkpoints" # Absolute path to save checkpoints 11 | 12 | # Model configuration 13 | export MODEL_NAME="lmm-r1-mgt-geo" # Name for this training run 14 | 15 | # Wandb configuration (optional) 16 | export WANDB_DIR="${WORKSPACE_DIR}" # Directory for wandb files 17 | export WANDB_API_KEY="YOUR_WANDB_API_KEY" # Your wandb API key (if online) 18 | 19 | # =================== Script Execution =================== 20 | # You shouldn't need to modify anything below this line 21 | # ====================================================== 22 | 23 | # Get script PID and setup directories 24 | SCRIPT_PID=$$ 25 | export TIMESTAMP=$(date +%Y%m%d_%H%M%S) 26 | export LOG_DIR="${SAVE_PATH}/${MODEL_NAME}/logs" 27 | export CUR_LOG_DIR="${LOG_DIR}/${TIMESTAMP}" 28 | 29 | # Stop any existing ray processes 30 | ray stop 31 | 32 | # Create necessary directories 33 | mkdir -p "${SAVE_PATH}/${MODEL_NAME}" 34 | mkdir -p "${LOG_DIR}" 35 | mkdir -p "${CUR_LOG_DIR}" 36 | 37 | # Print help information 38 | echo "================================================================" 39 | echo "LMM-R1 MGT Geometry Training" 40 | echo "================================================================" 41 | echo "Model name: ${MODEL_NAME}" 42 | echo "Dataset: ${DATASET_PATH}" 43 | echo "Pretrained model: ${PRETRAIN_MODEL_PATH}" 44 | echo "Logs will be saved to: ${CUR_LOG_DIR}" 45 | echo 46 | echo "To monitor logs:" 47 | echo " tail -f ${CUR_LOG_DIR}/train.log" 48 | echo 49 | echo "================================================================" 50 | 51 | # Start ray 52 | echo "Starting ray..." 53 | ray start --head --node-ip-address 0.0.0.0 --num-gpus 8 --temp-dir ~/.cache/ray 54 | 55 | # Start remote reward model server 56 | echo "Starting remote reward model server..." 57 | python -m openrlhf.models.remote_rm.math_verifier \ 58 | --dataset "${DATASET_PATH}" \ 59 | --input_key message \ 60 | --prompt-template chatml 2>&1 | tee -a "${CUR_LOG_DIR}/remote_rm.log" & 61 | REMOTE_RM_PID=$! 62 | 63 | # Start training 64 | echo "Starting training..." 
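# Note (added annotation): unlike the FRE scripts (reinforce_baseline), this stage sets
# --advantage_estimator gae, so the critic resources requested above come into play; it
# continues training from the FRE-Text checkpoint given in PRETRAIN_MODEL_PATH.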
65 | ray job submit --address="http://127.0.0.1:8265" \ 66 | --runtime-env-json="{\"working_dir\": \"${WORKSPACE_DIR}\",\"env_vars\":{\"VLLM_USE_V1\":\"1\",\"VLLM_ENABLE_V1_MULTIPROCESSING\":\"0\"}}" \ 67 | -- python -m openrlhf.cli.train_ppo_ray \ 68 | --ref_num_nodes 1 \ 69 | --ref_num_gpus_per_node 8 \ 70 | --remote_rm_url http://127.0.0.1:5000/get_reward \ 71 | --actor_num_nodes 1 \ 72 | --actor_num_gpus_per_node 8 \ 73 | --critic_num_nodes 1 \ 74 | --critic_num_gpus_per_node 8 \ 75 | --vllm_num_engines 8 \ 76 | --vllm_tensor_parallel_size 1 \ 77 | --colocate_all_models \ 78 | --vllm_enable_sleep \ 79 | --vllm_gpu_memory_utilization 0.5 \ 80 | --vllm_sync_backend gloo \ 81 | --enable_prefix_caching \ 82 | --pretrain ${PRETRAIN_MODEL_PATH} \ 83 | --save_path ${SAVE_PATH}/${MODEL_NAME} \ 84 | --micro_train_batch_size 2 \ 85 | --train_batch_size 256 \ 86 | --micro_rollout_batch_size 2 \ 87 | --rollout_batch_size 256 \ 88 | --temperature 1.0 \ 89 | --n_samples_per_prompt 16 \ 90 | --max_epochs 1 \ 91 | --num_episodes 2 \ 92 | --prompt_max_len 4096 \ 93 | --max_samples 100000 \ 94 | --generate_max_len 4096 \ 95 | --advantage_estimator gae \ 96 | --zero_stage 3 \ 97 | --bf16 \ 98 | --actor_learning_rate 1e-6 \ 99 | --init_kl_coef 0.001 \ 100 | --prompt_data ${DATASET_PATH} \ 101 | --input_key message \ 102 | --label_key "answer" \ 103 | --normalize_reward \ 104 | --flash_attn \ 105 | --lambd 1 \ 106 | --gamma 1 \ 107 | --gradient_checkpointing \ 108 | --save_steps 20 \ 109 | --ckpt_path ${SAVE_PATH}/${MODEL_NAME}/ckpt \ 110 | --save_hf_ckpt \ 111 | --load_checkpoint \ 112 | --use_wandb ${WANDB_API_KEY} \ 113 | --wandb_run_name ${MODEL_NAME} \ 114 | --wandb_group "lmm-r1-training" \ 115 | --use_tensorboard ${LOG_DIR} > >(tee -a "${CUR_LOG_DIR}/train.log") 2>&1 & 116 | 117 | TRAIN_PID=$! 118 | 119 | # Record process IDs 120 | echo "Remote RM PID: $REMOTE_RM_PID" > "${CUR_LOG_DIR}/process_pids.txt" 121 | echo "Train PID: $TRAIN_PID" >> "${CUR_LOG_DIR}/process_pids.txt" 122 | 123 | # Wait for training to complete 124 | echo "Training is running in the background. 
Check logs at ${CUR_LOG_DIR}/train.log" 125 | echo "To attach to the training process: wait $TRAIN_PID" 126 | 127 | # Uncomment to wait for training to complete before exiting 128 | # wait $TRAIN_PID 129 | 130 | # Cleanup instructions 131 | echo "When finished, clean up with:" 132 | echo "pkill -f openrlhf" 133 | echo "ray stop" 134 | echo "All logs are available in ${CUR_LOG_DIR}" -------------------------------------------------------------------------------- /examples/scripts/lmm_r1/train_mgt_percereas.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # =================== User Configuration =================== 3 | # Please modify these variables according to your environment 4 | # ========================================================= 5 | 6 | # Base paths - MODIFY THESE 7 | export WORKSPACE_DIR="$(pwd)" # Path to project root directory 8 | export DATASET_PATH="./data/VerMulti/mathv60k_message.jsonl" # Path to your dataset 9 | export PRETRAIN_MODEL_PATH="./checkpoints/lmm-r1-fre-text/" # Path to pretrained model 10 | export SAVE_PATH="/checkpoints" # Absolute path to save checkpoints 11 | 12 | # Model configuration 13 | export MODEL_NAME="lmm-r1-mgt-percereason" # Name for this training run 14 | 15 | # Wandb configuration (optional) 16 | export WANDB_DIR="${WORKSPACE_DIR}" # Directory for wandb files 17 | export WANDB_API_KEY="YOUR_WANDB_API_KEY" # Your wandb API key (if online) 18 | 19 | # =================== Script Execution =================== 20 | # You shouldn't need to modify anything below this line 21 | # ====================================================== 22 | 23 | # Get script PID and setup directories 24 | SCRIPT_PID=$$ 25 | export TIMESTAMP=$(date +%Y%m%d_%H%M%S) 26 | export LOG_DIR="${SAVE_PATH}/${MODEL_NAME}/logs" 27 | export CUR_LOG_DIR="${LOG_DIR}/${TIMESTAMP}" 28 | 29 | # Stop any existing ray processes 30 | ray stop 31 | 32 | # Create necessary directories 33 | mkdir -p "${SAVE_PATH}/${MODEL_NAME}" 34 | mkdir -p "${LOG_DIR}" 35 | mkdir -p "${CUR_LOG_DIR}" 36 | 37 | # Print help information 38 | echo "================================================================" 39 | echo "LMM-R1 MGT-PerceReason Training" 40 | echo "================================================================" 41 | echo "Model name: ${MODEL_NAME}" 42 | echo "Dataset: ${DATASET_PATH}" 43 | echo "Pretrained model: ${PRETRAIN_MODEL_PATH}" 44 | echo "Logs will be saved to: ${CUR_LOG_DIR}" 45 | echo 46 | echo "To monitor logs:" 47 | echo " tail -f ${CUR_LOG_DIR}/train.log" 48 | echo 49 | echo "To stop training:" 50 | echo " kill ${SCRIPT_PID}" 51 | echo " or: pkill -f openrlhf" 52 | echo " followed by: ray stop" 53 | echo "================================================================" 54 | 55 | # Start ray 56 | echo "Starting ray..." 57 | ray start --head --node-ip-address 0.0.0.0 --num-gpus 8 --temp-dir ~/.cache/ray 58 | 59 | # Start remote reward model server 60 | echo "Starting remote reward model server..." 61 | python -m openrlhf.models.remote_rm.math_verifier \ 62 | --dataset "${DATASET_PATH}" \ 63 | --input_key message \ 64 | --prompt-template chatml 2>&1 | tee -a "${CUR_LOG_DIR}/remote_rm.log" & 65 | REMOTE_RM_PID=$! 66 | 67 | # Start training 68 | echo "Starting training..." 
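# NOTE (added annotation): the command below passes --prompt_data_probs ${PROMPT_DATA_PROBS},
# but PROMPT_DATA_PROBS is never set in this script. Export it before running (comma-separated
# sampling probabilities, one per dataset given to --prompt_data) or remove the flag.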
69 | ray job submit --address="http://127.0.0.1:8265" \ 70 | --runtime-env-json="{\"working_dir\": \"${WORKSPACE_DIR}\",\"env_vars\":{\"VLLM_USE_V1\":\"1\",\"VLLM_ENABLE_V1_MULTIPROCESSING\":\"0\"}}" \ 71 | -- python -m openrlhf.cli.train_ppo_ray \ 72 | --ref_num_nodes 1 \ 73 | --ref_num_gpus_per_node 8 \ 74 | --remote_rm_url http://127.0.0.1:5000/get_reward \ 75 | --actor_num_nodes 1 \ 76 | --actor_num_gpus_per_node 8 \ 77 | --critic_num_nodes 1 \ 78 | --critic_num_gpus_per_node 8 \ 79 | --vllm_num_engines 8 \ 80 | --vllm_tensor_parallel_size 1 \ 81 | --colocate_all_models \ 82 | --vllm_enable_sleep \ 83 | --vllm_gpu_memory_utilization 0.5 \ 84 | --vllm_sync_backend gloo \ 85 | --enable_prefix_caching \ 86 | --pretrain ${PRETRAIN_MODEL_PATH} \ 87 | --save_path ${SAVE_PATH}/${MODEL_NAME} \ 88 | --micro_train_batch_size 1 \ 89 | --train_batch_size 256 \ 90 | --micro_rollout_batch_size 1 \ 91 | --rollout_batch_size 256 \ 92 | --temperature 1.0 \ 93 | --n_samples_per_prompt 16 \ 94 | --max_epochs 1 \ 95 | --num_episodes 2 \ 96 | --prompt_max_len 4096 \ 97 | --max_samples 100000 \ 98 | --generate_max_len 8192 \ 99 | --advantage_estimator reinforce_baseline \ 100 | --zero_stage 3 \ 101 | --bf16 \ 102 | --actor_learning_rate 4e-7 \ 103 | --init_kl_coef 0.001 \ 104 | --prompt_data ${DATASET_PATH} \ 105 | --input_key message \ 106 | --label_key "answer" \ 107 | --normalize_reward \ 108 | --flash_attn \ 109 | --lambd 1 \ 110 | --gamma 1 \ 111 | --gradient_checkpointing \ 112 | --save_steps 20 \ 113 | --ckpt_path ${SAVE_PATH}/${MODEL_NAME}/ckpt \ 114 | --save_hf_ckpt \ 115 | --load_checkpoint \ 116 | --prompt_data_probs ${PROMPT_DATA_PROBS} \ 117 | --use_wandb ${WANDB_API_KEY} \ 118 | --wandb_run_name ${MODEL_NAME} \ 119 | --wandb_group "lmm-r1-training" \ 120 | --use_tensorboard ${LOG_DIR} > >(tee -a "${CUR_LOG_DIR}/train.log") 2>&1 & 121 | 122 | TRAIN_PID=$! 123 | 124 | # Record process IDs 125 | echo "Remote RM PID: $REMOTE_RM_PID" > "${CUR_LOG_DIR}/process_pids.txt" 126 | echo "Train PID: $TRAIN_PID" >> "${CUR_LOG_DIR}/process_pids.txt" 127 | 128 | # Wait for training to complete 129 | echo "Training is running in the background. Check logs at ${CUR_LOG_DIR}/train.log" 130 | echo "To attach to the training process: wait $TRAIN_PID" 131 | 132 | # Uncomment to wait for training to complete before exiting 133 | # wait $TRAIN_PID 134 | 135 | # Cleanup instructions 136 | echo "When finished, clean up with:" 137 | echo "pkill -f openrlhf" 138 | echo "ray stop" 139 | echo "All logs are available in ${CUR_LOG_DIR}" 140 | -------------------------------------------------------------------------------- /examples/scripts/lmm_r1/train_sokoban.sh: -------------------------------------------------------------------------------- 1 | 2 | # To run Sokoban, you will need 3 | # 1. dataset.jsonl; 2. a folder containing images; 3. a folder containing game config 4 | 5 | # Example for dataset.jsonl 6 | # {"message": "[{\"role\": \"system\", \"content\": \"You're going to play a game of Sokoban, where the goal is to manipulate the green character to push the yellow box into the target area (an area with a red dot in the center). 
\\nGenerate all actions from the initial frame to the end at once.\"}, {\"role\": \"user\", \"content\": [{\"type\": \"image\", \"image\": \"file://path/to/Sokoban-v0-frame_65.png\"}, {\"type\": \"text\", \"text\": \"You should first thinks about the reasoning process in the mind and then provides the user with the answer, the answer is a long sequence of Left, Right, Up, Down, separated by ','. The reasoning process and answer are enclosed within and tags, respectively, e.g., To ... Left, Right, ... , which means your output should start with and end with .\"}]}]", "question": "Sokoban-v0-level_65", "answer": "", "env_path": "path/to/Sokoban-v0-level_65.npy"} 7 | # {"message": "[{\"role\": \"system\", \"content\": \"You're going to play a game of Sokoban, where the goal is to manipulate the green character to push the yellow box into the target area (an area with a red dot in the center). \\nGenerate all actions from the initial frame to the end at once.\"}, {\"role\": \"user\", \"content\": [{\"type\": \"image\", \"image\": \"file://path/to/Sokoban-small-v1-frame_91.png\"}, {\"type\": \"text\", \"text\": \"You should first thinks about the reasoning process in the mind and then provides the user with the answer, the answer is a long sequence of Left, Right, Up, Down, separated by ','. The reasoning process and answer are enclosed within and tags, respectively, e.g., To ... Left, Right, ... , which means your output should start with and end with .\"}]}]", "question": "Sokoban-small-v1-level_91", "answer": "", "env_path": "path/to/Sokoban-small-v1-level_91.npy"} 8 | 9 | 10 | # use examples/data/gen_sokoban_tasks.py to generate images and configs (.npy files) 11 | 12 | 13 | export ROOT_PATH=pwd 14 | export DATASET="/path/to/sokoban_dataset.jsonl" 15 | wandb login "wandb key" 16 | MODEL_CPK_NAME="ckpt name" 17 | PRETRAIN_MODEL="Qwen/Qwen2.5-VL-3B-Instruct" 18 | SAVE_PATH="/path/to/ckpts" 19 | mkdir -p "${SAVE_PATH}/${MODEL_CPK_NAME}" 20 | 21 | python -m openrlhf.models.remote_rm.sokoban_verifier --dataset $DATASET --prompt-template chatml --input_key message > "${SAVE_PATH}/${MODEL_CPK_NAME}/remote_rm.log" 2>&1 & 22 | childpid=$! 
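# Note (added annotation): the verifier's PID is captured above so the script can shut it
# down with `kill $childpid` once training finishes (see the last lines of this file).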
23 | 24 | ray start --head --node-ip-address 0.0.0.0 --num-gpus 8 --temp-dir /tmp/ray --include-dashboard=false 25 | 26 | 27 | python3 -m openrlhf.cli.train_ppo_ray \ 28 | --ref_num_nodes 1 \ 29 | --ref_num_gpus_per_node 8 \ 30 | --remote_rm_url http://127.0.0.1:5000/get_reward \ 31 | --actor_num_nodes 1 \ 32 | --actor_num_gpus_per_node 8 \ 33 | --vllm_num_engines 8 \ 34 | --vllm_tensor_parallel_size 1 \ 35 | --colocate_all_models \ 36 | --vllm_enable_sleep \ 37 | --vllm_gpu_memory_utilization 0.5 \ 38 | --vllm_sync_backend gloo \ 39 | --enable_prefix_caching \ 40 | --pretrain $PRETRAIN_MODEL \ 41 | --save_path $SAVE_PATH/$MODEL_CPK_NAME \ 42 | --micro_train_batch_size 1 \ 43 | --train_batch_size 128 \ 44 | --micro_rollout_batch_size 1 \ 45 | --rollout_batch_size 128 \ 46 | --temperature 1 \ 47 | --n_samples_per_prompt 16 \ 48 | --max_epochs 1 \ 49 | --num_episodes 4 \ 50 | --prompt_max_len 4096 \ 51 | --max_samples 100000 \ 52 | --generate_max_len 8196 \ 53 | --advantage_estimator gae \ 54 | --zero_stage 3 \ 55 | --bf16 \ 56 | --actor_learning_rate 1e-6 \ 57 | --critic_learning_rate 5e-6 \ 58 | --init_kl_coef 0 \ 59 | --lambd 1 \ 60 | --gamma 1 \ 61 | --prompt_data $DATASET \ 62 | --input_key message \ 63 | --normalize_reward \ 64 | --flash_attn \ 65 | --gradient_checkpointing \ 66 | --save_steps 5 \ 67 | --ckpt_path $SAVE_PATH/$MODEL_CPK_NAME/ckpt \ 68 | --save_hf_ckpt \ 69 | --wandb_run_name $MODEL_CPK_NAME \ 70 | --wandb_group hyper_para_search \ 71 | --freeze_prefix visual 72 | 73 | kill $childpid 74 | ray stop 75 | -------------------------------------------------------------------------------- /examples/scripts/nvidia_docker_install.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | # remove old docker 4 | sudo apt-get autoremove docker docker-ce docker-engine docker.io containerd runc 5 | dpkg -l |grep ^rc|awk '{print $2}' |sudo xargs dpkg -P 6 | sudo apt-get autoremove docker-ce-* 7 | sudo rm -rf /etc/systemd/system/docker.service.d 8 | sudo rm -rf /var/lib/docker 9 | 10 | # install docker 11 | curl https://get.docker.com | sh \ 12 | && sudo systemctl --now enable docker 13 | 14 | # install nvidia-docker 15 | distribution=$(. 
/etc/os-release;echo $ID$VERSION_ID) \ 16 | && curl -fsSL https://nvidia.github.io/libnvidia-container/gpgkey | sudo gpg --dearmor -o /usr/share/keyrings/nvidia-container-toolkit-keyring.gpg \ 17 | && curl -s -L https://nvidia.github.io/libnvidia-container/$distribution/libnvidia-container.list | \ 18 | sed 's#deb https://#deb [signed-by=/usr/share/keyrings/nvidia-container-toolkit-keyring.gpg] https://#g' | \ 19 | sudo tee /etc/apt/sources.list.d/nvidia-container-toolkit.list 20 | 21 | sudo apt-get update 22 | sudo apt-get install -y nvidia-container-toolkit 23 | sudo nvidia-ctk runtime configure --runtime=docker 24 | 25 | sudo groupadd docker 26 | sudo usermod -aG docker $USER 27 | newgrp docker 28 | docker ps -------------------------------------------------------------------------------- /examples/scripts/reward_func.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def reward_func(queries, prompts, labels): 5 | # queries is prompts + responses 6 | # labels is answers 7 | print(queries) 8 | return torch.randn(len(queries)) 9 | -------------------------------------------------------------------------------- /examples/scripts/serve_remote_rm.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | python -m openrlhf.cli.serve_rm \ 4 | --reward_pretrain OpenRLHF/Llama-3-8b-rm-700k \ 5 | --port 5000 \ 6 | --bf16 \ 7 | --flash_attn \ 8 | --normalize_reward \ 9 | --max_len 8192 \ 10 | --batch_size 16 -------------------------------------------------------------------------------- /examples/scripts/train_conditional_llama.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | checkSuccess() { 4 | if [[ $? != 0 ]]; then 5 | echo "FAILED $1" 6 | exit 1 7 | fi 8 | } 9 | 10 | mkdir -p ./checkpoint/llama-2-8b-csft 11 | RM_OUTPUT=./checkpoint/llama-2-8b-csft/rm.jsonl 12 | 13 | read -r -d '' get_rewards_commands < 0)); then 32 | POLICY_MODEL_PATH=$MODEL_OUTPUT_PATH 33 | fi 34 | 35 | read -r -d '' generate_commands <$ITER_LOG_PATH 98 | fi 99 | done -------------------------------------------------------------------------------- /examples/scripts/train_knowledge_distillation.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | read -r -d '' training_commands <> ${JOBLOG} 26 | 27 | # load training commands 28 | source ./${training_script} slurm 29 | echo training_commands &>> ${JOBLOG} 30 | echo $training_commands &>> ${JOBLOG} 31 | 32 | # master addr and port 33 | export MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1) 34 | export MASTER_PORT=9901 35 | 36 | srun --container-image="$IMAGE_NAME" \ 37 | --container-mounts="$PROJECT_PATH:/openrlhf,$HOME/.cache:/root/.cache" \ 38 | bash -c " cd /openrlhf; pip install . ; torchrun \ 39 | --nproc_per_node $GPUS_PER_NODE --nnodes $SLURM_NNODES --node_rank $SLURM_PROCID \ 40 | --master_addr $MASTER_ADDR --master_port $MASTER_PORT -m ${training_commands}" &>> ${JOBLOG} 41 | 42 | echo "$(date '+%Y-%m-%d %H:%M:%S') Job ${SLURM_JOB_ID} stopped ..." 
&>> ${JOBLOG} -------------------------------------------------------------------------------- /examples/scripts/train_ppo_llama_ray.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | ray job submit --address="http://127.0.0.1:8265" \ 4 | --runtime-env-json='{"working_dir": "/openrlhf"}' \ 5 | -- python3 -m openrlhf.cli.train_ppo_ray \ 6 | --ref_num_nodes 1 \ 7 | --ref_num_gpus_per_node 2 \ 8 | --reward_num_nodes 1 \ 9 | --reward_num_gpus_per_node 2 \ 10 | --critic_num_nodes 1 \ 11 | --critic_num_gpus_per_node 2 \ 12 | --actor_num_nodes 1 \ 13 | --actor_num_gpus_per_node 2 \ 14 | --vllm_num_engines 2 \ 15 | --vllm_tensor_parallel_size 2 \ 16 | --colocate_critic_reward \ 17 | --colocate_actor_ref \ 18 | --pretrain OpenRLHF/Llama-3-8b-sft-mixture \ 19 | --reward_pretrain OpenRLHF/Llama-3-8b-rm-mixture \ 20 | --save_path /openrlhf/examples/checkpoint/llama3-8b-rlhf \ 21 | --micro_train_batch_size 16 \ 22 | --train_batch_size 128 \ 23 | --micro_rollout_batch_size 32 \ 24 | --rollout_batch_size 1024 \ 25 | --max_samples 100000 \ 26 | --max_epochs 1 \ 27 | --prompt_max_len 1024 \ 28 | --generate_max_len 1024 \ 29 | --zero_stage 3 \ 30 | --bf16 \ 31 | --actor_learning_rate 5e-7 \ 32 | --critic_learning_rate 9e-6 \ 33 | --init_kl_coef 0.01 \ 34 | --prompt_data OpenRLHF/prompt-collection-v0.1 \ 35 | --input_key context_messages \ 36 | --apply_chat_template \ 37 | --normalize_reward \ 38 | --packing_samples \ 39 | --adam_offload \ 40 | --flash_attn \ 41 | --gradient_checkpointing \ 42 | --load_checkpoint \ 43 | --use_wandb {wandb_token} 44 | 45 | # --runtime-env-json='{"setup_commands": ["pip install openrlhf[vllm]"]}' [Install deps] 46 | # --ref_reward_offload [Offload to CPU] 47 | # --remote_rm_url http://localhost:5000/get_reward 48 | 49 | -------------------------------------------------------------------------------- /examples/scripts/train_ppo_llama_ray_70b.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | ray job submit --address="http://127.0.0.1:8265" \ 4 | --runtime-env-json='{"working_dir": "/openrlhf"}' \ 5 | --no-wait \ 6 | -- python3 -m openrlhf.cli.train_ppo_ray \ 7 | --ref_num_nodes 1 \ 8 | --ref_num_gpus_per_node 4 \ 9 | --reward_num_nodes 1 \ 10 | --reward_num_gpus_per_node 4 \ 11 | --critic_num_nodes 1 \ 12 | --critic_num_gpus_per_node 8 \ 13 | --actor_num_nodes 1 \ 14 | --actor_num_gpus_per_node 8 \ 15 | --vllm_num_engines 4 \ 16 | --vllm_tensor_parallel_size 2 \ 17 | --pretrain meta-llama/Meta-Llama-3-70B-Instruct \ 18 | --reward_pretrain meta-llama/Meta-Llama-3-70B-Instruct \ 19 | --save_path /openrlhf/examples/checkpoint/llama-3-70b-rlhf \ 20 | --micro_train_batch_size 1 \ 21 | --train_batch_size 128 \ 22 | --micro_rollout_batch_size 2 \ 23 | --rollout_batch_size 1024 \ 24 | --max_epochs 1 \ 25 | --prompt_max_len 1024 \ 26 | --generate_max_len 1024 \ 27 | --zero_stage 3 \ 28 | --bf16 \ 29 | --actor_learning_rate 5e-7 \ 30 | --critic_learning_rate 9e-6 \ 31 | --init_kl_coef 0.01 \ 32 | --prompt_data OpenRLHF/prompt-collection-v0.1 \ 33 | --input_key context_messages \ 34 | --apply_chat_template \ 35 | --max_samples 100000 \ 36 | --packing_samples \ 37 | --normalize_reward \ 38 | --adam_offload \ 39 | --flash_attn \ 40 | --vllm_sync_backend nccl \ 41 | --gradient_checkpointing -------------------------------------------------------------------------------- /examples/scripts/train_ppo_llama_ray_hybrid_engine.sh: 
-------------------------------------------------------------------------------- 1 | set -x 2 | 3 | ray job submit --address="http://127.0.0.1:8265" \ 4 | --runtime-env-json='{"working_dir": "/openrlhf"}' \ 5 | -- python3 -m openrlhf.cli.train_ppo_ray \ 6 | --ref_num_nodes 1 \ 7 | --ref_num_gpus_per_node 8 \ 8 | --reward_num_nodes 1 \ 9 | --reward_num_gpus_per_node 8 \ 10 | --critic_num_nodes 1 \ 11 | --critic_num_gpus_per_node 8 \ 12 | --actor_num_nodes 1 \ 13 | --actor_num_gpus_per_node 8 \ 14 | --vllm_num_engines 4 \ 15 | --vllm_tensor_parallel_size 2 \ 16 | --colocate_all_models \ 17 | --vllm_gpu_memory_utilization 0.5 \ 18 | --pretrain OpenRLHF/Llama-3-8b-sft-mixture \ 19 | --reward_pretrain OpenRLHF/Llama-3-8b-rm-700k \ 20 | --save_path /openrlhf/examples/test_scripts/final/llama3-8b-rlhf \ 21 | --ckpt_path /openrlhf/examples/test_scripts/ckpt/llama3-8b-rlhf \ 22 | --save_hf_ckpt \ 23 | --micro_train_batch_size 4 \ 24 | --train_batch_size 128 \ 25 | --micro_rollout_batch_size 8 \ 26 | --rollout_batch_size 1024 \ 27 | --n_samples_per_prompt 1 \ 28 | --max_epochs 1 \ 29 | --prompt_max_len 1024 \ 30 | --max_samples 100000 \ 31 | --generate_max_len 1024 \ 32 | --zero_stage 3 \ 33 | --bf16 \ 34 | --actor_learning_rate 5e-7 \ 35 | --critic_learning_rate 9e-6 \ 36 | --init_kl_coef 0.01 \ 37 | --prompt_data OpenRLHF/prompt-collection-v0.1 \ 38 | --input_key context_messages \ 39 | --apply_chat_template \ 40 | --normalize_reward \ 41 | --gradient_checkpointing \ 42 | --packing_samples \ 43 | --vllm_sync_backend nccl \ 44 | --enforce_eager \ 45 | --vllm_enable_sleep \ 46 | --deepspeed_enable_sleep 47 | 48 | -------------------------------------------------------------------------------- /examples/scripts/train_ppo_llama_ray_ring.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | ray job submit --address="http://127.0.0.1:8265" \ 4 | --runtime-env-json='{"working_dir": "/openrlhf"}' \ 5 | -- python3 -m openrlhf.cli.train_ppo_ray \ 6 | --ref_num_nodes 1 \ 7 | --ref_num_gpus_per_node 2 \ 8 | --reward_num_nodes 1 \ 9 | --reward_num_gpus_per_node 2 \ 10 | --critic_num_nodes 1 \ 11 | --critic_num_gpus_per_node 2 \ 12 | --actor_num_nodes 1 \ 13 | --actor_num_gpus_per_node 2 \ 14 | --vllm_num_engines 2 \ 15 | --vllm_tensor_parallel_size 2 \ 16 | --colocate_critic_reward \ 17 | --colocate_actor_ref \ 18 | --pretrain OpenRLHF/Llama-3-8b-sft-mixture \ 19 | --reward_pretrain OpenRLHF/Llama-3-8b-rm-mixture \ 20 | --save_path /openrlhf/examples/checkpoint/llama3-8b-rlhf \ 21 | --micro_train_batch_size 16 \ 22 | --train_batch_size 128 \ 23 | --micro_rollout_batch_size 32 \ 24 | --rollout_batch_size 1024 \ 25 | --max_samples 100000 \ 26 | --max_epochs 1 \ 27 | --prompt_max_len 1024 \ 28 | --generate_max_len 1024 \ 29 | --zero_stage 3 \ 30 | --bf16 \ 31 | --actor_learning_rate 5e-7 \ 32 | --critic_learning_rate 9e-6 \ 33 | --init_kl_coef 0.01 \ 34 | --prompt_data OpenRLHF/prompt-collection-v0.1 \ 35 | --input_key context_messages \ 36 | --apply_chat_template \ 37 | --normalize_reward \ 38 | --packing_samples \ 39 | --adam_offload \ 40 | --flash_attn \ 41 | --gradient_checkpointing \ 42 | --load_checkpoint \ 43 | --ring_attn_size 2 \ 44 | --ring_head_stride 2 \ 45 | --use_wandb {wandb_token} 46 | 47 | # --runtime-env-json='{"setup_commands": ["pip install openrlhf[vllm]"]}' [Install deps] 48 | # --ref_reward_offload [Offload to CPU] 49 | # --remote_rm_url http://localhost:5000/get_reward 50 | 51 | 
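# Note (added, not part of the original script): relative to train_ppo_llama_ray.sh, this variant
# enables ring (context-parallel) attention: --ring_attn_size 2 shards each packed sequence across
# 2 GPUs, roughly halving per-GPU activation memory for attention, while --ring_head_stride 2 is
# understood to control how many attention heads are processed per ring-attention pass
# (a memory/throughput trade-off).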
-------------------------------------------------------------------------------- /examples/scripts/train_ppo_llama_ray_slurm.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | #SBATCH -p { partition } 4 | #SBATCH -A { account } 5 | #SBATCH -J { jobname } 6 | #SBATCH -N 2 # 64x8x4 7 | #SBATCH -t {LIMIT_TIME} # wall time 8 | #SBATCH --ntasks-per-node=1 # tasks per node 9 | #SBATCH --exclusive # exclusive node access 10 | #SBATCH --mem=0 # all mem avail 11 | #SBATCH --mail-type=FAIL # only send email on failure 12 | #SBATCH --overcommit # needed for pytorch 13 | 14 | # project settings 15 | OPENRLHF_PATH= 16 | MOUNT="$OPENRLHF_PATH:/openrlhf,$HOME/.cache:/root/.cache" 17 | IMAGE_NAME="nvcr.io/nvidia/pytorch:24.07-py3" 18 | RAY_VERSION=2.12.0 19 | 20 | JOBLOG="$(realpath .)/train_ppo_llama_ray-$SLURM_JOB_ID.log" 21 | echo "$(date '+%Y-%m-%d %H:%M:%S') Job ${SLURM_JOB_ID} started ..." &>> ${JOBLOG} 22 | 23 | # launch ray daemon 24 | nodes=$(scontrol show hostnames "$SLURM_JOB_NODELIST") # Getting the node names 25 | nodes_array=( $nodes ) 26 | node_1=${nodes_array[0]} 27 | ip=$node_1 28 | 29 | port=6379 30 | ip_head=$ip:$port 31 | export ip_head 32 | echo "IP Head: $ip_head" &>> ${JOBLOG} 33 | 34 | echo "STARTING HEAD at $node_1" &>> ${JOBLOG} 35 | srun --nodes=1 --ntasks=1 -w "$node_1" --container-image="$IMAGE_NAME" --container-mounts="$MOUNT" bash -c \ 36 | "pip install ray[default]==$RAY_VERSION \ 37 | && /root/.local/bin/ray start --head --node-ip-address=$ip --port=$port --block" &>> ${JOBLOG} & 38 | sleep 10s 39 | 40 | worker_num=$((SLURM_JOB_NUM_NODES)) #number of nodes other than the head node 41 | for ((i = 1; i < worker_num; i++)); do 42 | node_i=${nodes_array[$i]} 43 | echo "STARTING WORKER $i at $node_i" &>> ${JOBLOG} 44 | srun --nodes=1 --ntasks=1 -w "$node_i" --container-image="$IMAGE_NAME" --container-mounts="$MOUNT" bash -c \ 45 | "pip install ray[default]==$RAY_VERSION \ 46 | && /root/.local/bin/ray start --address $ip_head --block" &>> ${JOBLOG} & 47 | sleep 1s; 48 | done 49 | 50 | sleep 30s 51 | 52 | # ===== submit ray job ===== 53 | # Job start 54 | srun --overlap --nodes=1 --ntasks=1 -w "$node_1" --container-image="$IMAGE_NAME" --container-mounts="$MOUNT" bash -c \ 55 | "pip install ray[default]==$RAY_VERSION \ 56 | && /root/.local/bin/ray job submit --address=http://localhost:8265 \ 57 | --runtime-env-json='{\"working_dir\": \"/openrlhf\", \"pip\": \"/openrlhf/requirements.txt\"}' \ 58 | -- python3 -m openrlhf.cli.train_ppo_ray \ 59 | --ref_num_nodes 1 \ 60 | --ref_num_gpus_per_node 4 \ 61 | --reward_num_nodes 1 \ 62 | --reward_num_gpus_per_node 4 \ 63 | --critic_num_nodes 1 \ 64 | --critic_num_gpus_per_node 4 \ 65 | --actor_num_nodes 1 \ 66 | --actor_num_gpus_per_node 4 \ 67 | --vllm_num_engines 4 \ 68 | --vllm_tensor_parallel_size 2 \ 69 | --colocate_critic_reward \ 70 | --colocate_actor_ref \ 71 | --pretrain OpenRLHF/Llama-3-8b-sft-mixture \ 72 | --reward_pretrain OpenRLHF/Llama-3-8b-rm-mixture \ 73 | --save_path /openrlhf/examples/checkpoint/llama3-8b-rlhf \ 74 | --micro_train_batch_size 8 \ 75 | --train_batch_size 128 \ 76 | --micro_rollout_batch_size 16 \ 77 | --rollout_batch_size 1024 \ 78 | --max_samples 100000 \ 79 | --max_epochs 1 \ 80 | --prompt_max_len 1024 \ 81 | --generate_max_len 1024 \ 82 | --zero_stage 3 \ 83 | --bf16 \ 84 | --actor_learning_rate 5e-7 \ 85 | --critic_learning_rate 9e-6 \ 86 | --init_kl_coef 0.01 \ 87 | --prompt_data OpenRLHF/prompt-collection-v0.1 \ 88 | --input_key 
context_messages \ 89 | --apply_chat_template \ 90 | --normalize_reward \ 91 | --adam_offload \ 92 | --flash_attn \ 93 | --packing_samples \ 94 | --vllm_sync_backend nccl \ 95 | --gradient_checkpointing \ 96 | --use_wandb {wandb_token}" &>> ${JOBLOG} 97 | 98 | echo "$(date '+%Y-%m-%d %H:%M:%S') Job ${SLURM_JOB_ID} stopped ..." &>> ${JOBLOG} -------------------------------------------------------------------------------- /examples/scripts/train_ppo_llama_with_dynamic_sampling.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | read -r -d '' training_commands < 0)); then 31 | POLICY_MODEL_PATH=$MODEL_OUTPUT_PATH 32 | fi 33 | 34 | read -r -d '' generate_commands <$ITER_LOG_PATH 98 | fi 99 | done -------------------------------------------------------------------------------- /examples/scripts/train_rm_llama.sh: -------------------------------------------------------------------------------- 1 | set -x 2 | 3 | read -r -d '' training_commands < None: 28 | super().__init__() 29 | self.tokenizer = tokenizer 30 | self.strategy = strategy 31 | self.max_length = max_length 32 | self.multiple_of = multiple_of 33 | 34 | # chat_template 35 | self.input_key = getattr(self.strategy.args, "input_key", None) 36 | self.label_key = getattr(self.strategy.args, "label_key", None) 37 | self.placeholder_token = getattr(self.strategy.args, "placeholder_token", None) 38 | self.reward_tokens = getattr(self.strategy.args, "reward_tokens", None) 39 | 40 | self.placeholder_token_id = convert_token_to_id(self.placeholder_token, self.tokenizer) 41 | 42 | # Store the processed data in class attributes 43 | self.inputs = dataset[self.input_key] 44 | self.labels = dataset[self.label_key] 45 | 46 | def __len__(self): 47 | length = len(self.inputs) 48 | return length 49 | 50 | def __getitem__(self, idx): 51 | input_token = self.tokenizer( 52 | self.inputs[idx], 53 | max_length=self.max_length, 54 | padding=False, 55 | truncation=True, 56 | return_tensors="pt", 57 | add_special_tokens=False, 58 | ) 59 | 60 | input_ids = input_token["input_ids"] 61 | label_values = self.labels[idx] 62 | assert isinstance(label_values, list), "labels should be a list of strings or numbers" 63 | if isinstance(label_values[0], str): 64 | label_tokens = [] 65 | for label in label_values: 66 | assert ( 67 | self.reward_tokens is None or label in self.reward_tokens 68 | ), f"label should be in reward tokens {self.reward_tokens}, got {label}" 69 | label_tokens.append(convert_token_to_id(label, self.tokenizer)) 70 | 71 | # label_tokens is list of token id (for '+', '-', etc) 72 | label_tensor = torch.tensor(label_tokens, dtype=input_ids.dtype) 73 | else: 74 | # label_values is list of float numbers (for reward values) 75 | label_tensor = torch.tensor(label_values, dtype=torch.float) 76 | # Motivation: inputs_ids maybe truncated to self.max_length, where placeholder_tokens at the end may be removed. 
77 | # We should also truncate the labels to match the length of input_ids 78 | # Step 1: Create a mask for placeholder token positions 79 | mask = input_ids == self.placeholder_token_id 80 | # Step 2: Ensure that label_tensor is truncated along the last dimension 81 | # Find the length of the last dimension of the mask 82 | num_placeholders = mask.sum(dim=-1) 83 | # Truncate label_tensor along the last dimension to match num_placeholders 84 | truncated_labels = label_tensor[..., : num_placeholders.max()] 85 | # Step 3: Update labels at placeholder token positions 86 | labels = torch.full_like(input_ids, -100) 87 | labels[mask] = truncated_labels 88 | 89 | return ( 90 | input_ids, 91 | input_token["attention_mask"], 92 | labels, 93 | ) 94 | 95 | def collate_fn(self, item_list): 96 | input_ids = [] 97 | input_masks = [] 98 | label_ids = [] 99 | for input_id, input_mask, label_id in item_list: 100 | input_ids.append(input_id) 101 | input_masks.append(input_mask) 102 | label_ids.append(label_id) 103 | 104 | padding_side = "right" 105 | input_ids = zero_pad_sequences(input_ids, side=padding_side, value=self.tokenizer.pad_token_id) 106 | input_masks = zero_pad_sequences(input_masks, side=padding_side) 107 | label_ids = zero_pad_sequences(label_ids, side=padding_side, value=self.tokenizer.pad_token_id) 108 | return input_ids, input_masks, label_ids 109 | -------------------------------------------------------------------------------- /openrlhf/datasets/prompts_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | from tqdm import tqdm 3 | 4 | 5 | def preprocess_data(data, input_template=None, input_key="input", label_key=None, apply_chat_template=None) -> str: 6 | if apply_chat_template: 7 | chat = data[input_key] 8 | if isinstance(chat, str): 9 | chat = [{"role": "user", "content": chat}] 10 | prompt = apply_chat_template(chat, tokenize=False, add_generation_prompt=True) 11 | else: 12 | prompt = data[input_key] 13 | if input_template: 14 | prompt = input_template.format(prompt) 15 | 16 | # for Reinforced Fine-tuning 17 | label = "" if label_key is None else data[label_key] 18 | return prompt, label 19 | 20 | 21 | class PromptDataset(Dataset): 22 | """ 23 | Dataset for PPO model 24 | 25 | Args: 26 | dataset: dataset for PPO model 27 | tokenizer: tokenizer for PPO model 28 | max_length: max length of input 29 | """ 30 | 31 | def __init__( 32 | self, 33 | dataset, 34 | tokenizer, 35 | strategy, 36 | input_template=None, 37 | ) -> None: 38 | super().__init__() 39 | self.strategy = strategy 40 | self.tokenizer = tokenizer 41 | 42 | # chat_template 43 | self.input_template = input_template 44 | input_key = getattr(self.strategy.args, "input_key", None) 45 | label_key = getattr(self.strategy.args, "label_key", None) 46 | apply_chat_template = getattr(self.strategy.args, "apply_chat_template", False) 47 | 48 | if apply_chat_template: 49 | apply_chat_template = self.tokenizer.apply_chat_template 50 | 51 | self.prompts = [] 52 | self.labels = [] 53 | self.datasources = [] 54 | for data in tqdm(dataset, desc="Preprocessing data", disable=not self.strategy.is_rank_0()): 55 | prompt, label = preprocess_data(data, input_template, input_key, label_key, apply_chat_template) 56 | self.prompts.append(prompt) 57 | self.labels.append(label) 58 | self.datasources.append(data.get("datasource", "default")) 59 | 60 | def __len__(self): 61 | length = len(self.prompts) 62 | return length 63 | 64 | def __getitem__(self, idx): 65 | return 
self.datasources[idx], self.prompts[idx], self.labels[idx] 66 | -------------------------------------------------------------------------------- /openrlhf/datasets/unpaired_preference_dataset.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import torch 4 | from torch.utils.data import Dataset 5 | 6 | from .utils import zero_pad_sequences 7 | 8 | 9 | def preprocess_data( 10 | data, input_template=None, input_key=None, output_key=None, label_key=None, apply_chat_template=None 11 | ): 12 | """ 13 | Preprocess data from raw dataset to prompt, response, label 14 | 15 | Args: 16 | data: raw data from dataset 17 | """ 18 | label = data[label_key] 19 | 20 | if apply_chat_template: 21 | if output_key: 22 | prompt = apply_chat_template(data[input_key], tokenize=False, add_generation_prompt=True) 23 | response = apply_chat_template(data[input_key] + data[output_key], tokenize=False)[len(prompt) :] 24 | else: 25 | prompt = apply_chat_template(data[input_key][:-1], tokenize=False, add_generation_prompt=True) 26 | response = apply_chat_template(data[input_key], tokenize=False)[len(prompt) :] 27 | else: 28 | prompt = data[input_key] 29 | response = data[output_key] 30 | if input_template: 31 | prompt = input_template.format(prompt) 32 | return prompt, response, label 33 | 34 | 35 | class UnpairedPreferenceDataset(Dataset): 36 | """ 37 | Unpaired preference dataset for algorithm, like KTO 38 | 39 | Args: 40 | dataset: raw dataset 41 | self.tokenizer: self.tokenizer for model 42 | self.max_length: max length of input 43 | """ 44 | 45 | def __init__( 46 | self, dataset, tokenizer: Callable, max_length: int, strategy, input_template=None, num_processors=8 47 | ) -> None: 48 | super().__init__() 49 | self.tokenizer = tokenizer 50 | self.strategy = strategy 51 | self.max_length = max_length 52 | 53 | # chat_template 54 | self.input_template = input_template 55 | self.input_key = getattr(self.strategy.args, "input_key", None) 56 | self.output_key = getattr(self.strategy.args, "output_key", None) 57 | self.label_key = getattr(self.strategy.args, "label_key", None) 58 | self.apply_chat_template = getattr(self.strategy.args, "apply_chat_template", False) 59 | 60 | if self.apply_chat_template: 61 | self.apply_chat_template = self.tokenizer.apply_chat_template 62 | tokenizer_chat_template = getattr(self.strategy.args, "tokenizer_chat_template", None) 63 | if tokenizer_chat_template: 64 | self.tokenizer.chat_template = tokenizer_chat_template 65 | 66 | # Parallel loading datasets 67 | processed_dataset = dataset.map( 68 | self.process_data, remove_columns=dataset.column_names, num_proc=num_processors 69 | ) 70 | 71 | # Filter out None values if necessary 72 | processed_dataset = processed_dataset.filter(lambda x: x["prompt"] is not None) 73 | 74 | # Store the processed data in class attributes 75 | self.prompts = processed_dataset["prompt"] 76 | self.responses = processed_dataset["response"] 77 | self.labels = processed_dataset["label"] 78 | self.prompt_ids_lens = processed_dataset["prompt_ids_len"] 79 | 80 | def process_data(self, data): 81 | prompt, response, label = preprocess_data( 82 | data, self.input_template, self.input_key, self.output_key, self.label_key, self.apply_chat_template 83 | ) 84 | prompt_token = self.tokenizer( 85 | prompt, 86 | max_length=self.max_length, 87 | padding=False, 88 | truncation=True, 89 | return_tensors="pt", 90 | add_special_tokens=False, 91 | ) 92 | prompt_ids_len = 
prompt_token["attention_mask"].int().sum().item() 93 | 94 | # filter the sample whose length is greater than max_length (2 for answer length) 95 | if prompt_ids_len >= self.max_length - 2: 96 | prompt = None 97 | 98 | return {"prompt": prompt, "response": response, "label": label, "prompt_ids_len": prompt_ids_len} 99 | 100 | def __len__(self): 101 | return len(self.prompts) 102 | 103 | def __getitem__(self, index): 104 | return self.prompts[index], self.responses[index], self.labels[index], self.prompt_ids_lens[index] 105 | 106 | def collate_fn(self, item_list): 107 | def tokenizer(prompt, response): 108 | text = (prompt + response).rstrip("\n") 109 | if not text.endswith(self.tokenizer.eos_token): 110 | text += " " + self.tokenizer.eos_token 111 | inputs = self.tokenizer( 112 | text, 113 | max_length=self.max_length, 114 | padding=False, 115 | truncation=True, 116 | return_tensors="pt", 117 | add_special_tokens=False, 118 | ) 119 | 120 | inputs["input_ids"][0][-1] = self.tokenizer.eos_token_id 121 | inputs["attention_mask"][0][-1] = True 122 | return inputs["input_ids"], inputs["attention_mask"] 123 | 124 | tot_ids, tot_masks, tot_labels, prompt_ids_lens = [], [], [], [] 125 | for prompt, response, label, prompt_ids_len in item_list: 126 | input_ids, attention_mask = tokenizer(prompt, response) 127 | tot_ids.append(input_ids) 128 | tot_masks.append(attention_mask) 129 | tot_labels.append(label) 130 | prompt_ids_lens.append(prompt_ids_len) 131 | 132 | # add unmatched y'| x (used to estimate the KL divergence between policy and reference) 133 | for idx in range(len(item_list)): 134 | next_idx = (idx + 1) % len(item_list) 135 | input_ids, attention_mask = tokenizer(item_list[idx][0], item_list[next_idx][1]) 136 | tot_ids.append(input_ids) 137 | tot_masks.append(attention_mask) 138 | tot_labels.append(-1) 139 | prompt_ids_lens.append(item_list[idx][3]) 140 | 141 | input_ids = zero_pad_sequences(tot_ids, side="right", value=self.tokenizer.pad_token_id) 142 | attention_mask = zero_pad_sequences(tot_masks, side="right") 143 | return input_ids, attention_mask, torch.LongTensor(tot_labels), prompt_ids_lens 144 | -------------------------------------------------------------------------------- /openrlhf/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | def zero_pad_sequences(sequences, side: str = "left", value=0): 6 | assert side in ("left", "right") 7 | max_len = max(seq.size(-1) for seq in sequences) 8 | padded_sequences = [] 9 | for seq in sequences: 10 | pad_len = max_len - seq.size(-1) 11 | padding = (pad_len, 0) if side == "left" else (0, pad_len) 12 | padded_sequences.append(F.pad(seq, padding, value=value)) 13 | return torch.stack(padded_sequences, dim=0) 14 | 15 | 16 | def exist_and_not_none(d, key): 17 | return key in d and not d[key] is None 18 | -------------------------------------------------------------------------------- /openrlhf/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .actor import Actor 2 | from .loss import ( 3 | DPOLoss, 4 | GPTLMLoss, 5 | KDLoss, 6 | KTOLoss, 7 | LogExpLoss, 8 | PairWiseLoss, 9 | PolicyLoss, 10 | PRMLoss, 11 | SFTLoss, 12 | ValueLoss, 13 | VanillaKTOLoss, 14 | ) 15 | from .model import get_llm_for_sequence_regression 16 | 17 | __all__ = [ 18 | "Actor", 19 | "SFTLoss", 20 | "DPOLoss", 21 | "GPTLMLoss", 22 | "KDLoss", 23 | "KTOLoss", 24 | "LogExpLoss", 25 | "PairWiseLoss", 26 | 
"PolicyLoss", 27 | "PRMLoss", 28 | "ValueLoss", 29 | "VanillaKTOLoss", 30 | "get_llm_for_sequence_regression", 31 | ] 32 | -------------------------------------------------------------------------------- /openrlhf/models/lmm_kits/base/data_processor.py: -------------------------------------------------------------------------------- 1 | import json 2 | from copy import deepcopy 3 | from abc import ABC, abstractmethod 4 | from dataclasses import dataclass 5 | from typing import List, Optional, Union, Dict 6 | import torch 7 | import transformers 8 | from transformers.processing_utils import ProcessorMixin 9 | from qwen_vl_utils import process_vision_info 10 | from PIL.Image import Image 11 | 12 | class BatchFeature(transformers.feature_extraction_utils.BatchFeature): 13 | def pin_memory(self): 14 | new_data = {} 15 | for k, v in self.items(): 16 | if isinstance(v, torch.Tensor): 17 | new_data[k] = v.pin_memory() 18 | else: 19 | new_data[k] = v 20 | self.data = new_data 21 | return self 22 | 23 | @dataclass 24 | class MMInputs: 25 | emb_inputs: Optional[BatchFeature | dict] = None # used for getting the multimodal input_embeds 26 | forward_inputs: Optional[BatchFeature | dict] = None # some models need extra inputs for forward, even if given input_embeds 27 | extra_info: Optional[BatchFeature | dict] = None # Reserved item for other usages. Now used for batching and splitting. 28 | 29 | def __post_init__(self): 30 | if isinstance(self.emb_inputs,(dict,type(None))): 31 | self.emb_inputs = BatchFeature(self.emb_inputs) 32 | if isinstance(self.forward_inputs,(dict,type(None))): 33 | self.forward_inputs = BatchFeature(self.forward_inputs) 34 | if isinstance(self.extra_info,(dict,type(None))): 35 | self.extra_info = BatchFeature(self.extra_info) 36 | 37 | def to(self, *args, **kwargs): 38 | self.emb_inputs = self.emb_inputs.to(*args, **kwargs) 39 | self.forward_inputs = self.forward_inputs.to(*args, **kwargs) 40 | self.extra_info = self.extra_info.to(*args, **kwargs) 41 | return self 42 | 43 | def pin_memory(self): 44 | self.emb_inputs = self.emb_inputs.pin_memory() 45 | self.forward_inputs = self.forward_inputs.pin_memory() 46 | self.extra_info = self.extra_info.pin_memory() 47 | return self 48 | 49 | def _merge_to_dict(self): 50 | result = {**self.emb_inputs.data,**self.forward_inputs.data,**self.extra_info.data} 51 | # if two items have the same key, the values should be the same. 
52 | for key in result.keys(): 53 | if key in self.emb_inputs.keys(): 54 | # same value should be the same object 55 | assert result[key] is self.emb_inputs[key] 56 | if key in self.forward_inputs.keys(): 57 | assert result[key] is self.forward_inputs[key] 58 | if key in self.extra_info.keys(): 59 | assert result[key] is self.extra_info[key] 60 | return result 61 | 62 | def keys(self): 63 | return self._merge_to_dict().keys() 64 | 65 | def items(self): 66 | return self._merge_to_dict().items() 67 | 68 | def __contains__(self, key): 69 | return key in self._merge_to_dict() 70 | 71 | def __getitem__(self, key): 72 | return self._merge_to_dict()[key] 73 | 74 | 75 | class BaseDataProcessor(ABC): 76 | def __init__(self, processor: ProcessorMixin,processor_kwargs:Dict): 77 | super().__init__() 78 | self.processor = processor 79 | self.processor_kwargs = processor_kwargs 80 | # We use process_vision_info of qwen_vl_utils to get the image inputs for all model, 81 | # To be compatible with Qwen2VLImageProcessor, we always set the min_pixels and max_pixels for the processor 82 | self.min_pixels = processor_kwargs["min_pixels"] 83 | self.max_pixels = processor_kwargs["max_pixels"] 84 | @abstractmethod 85 | def __call__( 86 | self, 87 | messages: Union[Dict, List[str], str], 88 | max_length: int, 89 | padding: bool = True, 90 | device: Optional[Union[str, torch.device]] = None, 91 | return_tensors: Optional[str] = "pt", 92 | add_special_tokens: Optional[bool] = False, 93 | truncation: Optional[bool] = True, 94 | ) -> MMInputs: 95 | """ 96 | We mainly use this function to get the visual inputs for the model. 97 | """ 98 | raise NotImplementedError 99 | 100 | def _add_pixel_bounds(self,messages:List[List[Dict]]) -> List[List[Dict]]: 101 | DEFAULT_MIN_PIXELS = self.min_pixels 102 | DEFAULT_MAX_PIXELS = self.max_pixels 103 | 104 | def process_content(content): 105 | if isinstance(content, list): 106 | for item in content: 107 | if isinstance(item, dict) and item.get("type") == "image": 108 | if "min_pixels" not in item: 109 | item["min_pixels"] = DEFAULT_MIN_PIXELS 110 | if "max_pixels" not in item: 111 | item["max_pixels"] = DEFAULT_MAX_PIXELS 112 | return content 113 | 114 | for message in messages: 115 | for msg in message: 116 | msg["content"] = process_content(msg["content"]) 117 | return messages 118 | 119 | @abstractmethod 120 | def make_input_batch(self, inputs: List[MMInputs]) -> MMInputs: 121 | raise NotImplementedError 122 | 123 | @abstractmethod 124 | def split_input_batch(self, batch: MMInputs) -> List[MMInputs]: 125 | raise NotImplementedError 126 | 127 | def _format_messages(self, messages: Union[Dict, List[str], str]) -> List[List[Dict]]: 128 | messages = deepcopy(messages) 129 | if isinstance(messages, list) and isinstance(messages[0], str): 130 | formated_messages = [json.loads(m) for m in messages] 131 | elif isinstance(messages, str): 132 | formated_messages = [json.loads(messages)] 133 | elif isinstance(messages, dict): 134 | formated_messages = [[messages]] 135 | else: 136 | raise ValueError("Invalid messages format, must be a list of strings or a string or a dict") 137 | return self._add_pixel_bounds(formated_messages) 138 | 139 | def apply_chat_template( 140 | self, 141 | messages: Union[Dict, List[str], str], 142 | tokenize: bool = False, 143 | add_generation_prompt: bool = True, 144 | ) -> List[str]: 145 | messages = self._format_messages(messages) 146 | 147 | return self.processor.apply_chat_template( 148 | messages, tokenize=tokenize, add_generation_prompt=add_generation_prompt 
149 | ) 150 | 151 | def get_images_from_messages( 152 | self, messages: Union[Dict, List[str], str] 153 | ) -> List[Image]: 154 | messages = self._format_messages(messages) 155 | image_inputs, _ = process_vision_info(messages) 156 | return image_inputs 157 | 158 | 159 | @property 160 | def pad_token_id(self) -> int: 161 | return self.processor.tokenizer.pad_token_id 162 | 163 | @property 164 | def eos_token_id(self) -> int: 165 | return self.processor.tokenizer.eos_token_id 166 | 167 | @property 168 | def tokenizer(self): 169 | return self.processor.tokenizer -------------------------------------------------------------------------------- /openrlhf/models/lmm_kits/base/patch.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | class BasePatch(ABC): 4 | def __init__(self): 5 | self.loaded = False 6 | @abstractmethod 7 | def _add_get_inputs_embeds(): 8 | ''' 9 | Add a `get_inputs_embeds(*args,**kwargs)` method to the model class, 10 | which embeds image embeddings into the text embeddings and returns the result. 11 | ''' 12 | return NotImplementedError 13 | 14 | @abstractmethod 15 | def _add_get_position_ids(): 16 | ''' 17 | Add a `get_position_ids(*args,**kwargs)` method to the model class, 18 | which returns the position_ids of the given inputs. 19 | ''' 20 | return NotImplementedError 21 | 22 | @abstractmethod 23 | def _add_offset_split_position_ids(): 24 | ''' 25 | Add an `offset_split_position_ids(*args,**kwargs)` method to the model class, 26 | which offsets the split position_ids to the true position_ids. 27 | ''' 28 | return NotImplementedError 29 | 30 | def _register_to_autoclass(): 31 | ''' 32 | Register the model to the corresponding AutoModel class and AutoConfig class. Used for non-HF customized models. 33 | ''' 34 | return NotImplementedError 35 | 36 | def apply_liger_kernel(): 37 | ''' 38 | Apply the liger kernel to the model. 39 | ''' 40 | return NotImplementedError 41 | 42 | @classmethod 43 | @abstractmethod 44 | def _load_all_patches(cls): 45 | ''' 46 | Load all patches.
47 | ''' 48 | return NotImplementedError 49 | 50 | def load_all_patches(self,use_liger_kernel=False): 51 | if not self.loaded: 52 | self._load_all_patches() 53 | self.loaded = True 54 | if use_liger_kernel: 55 | self.apply_liger_kernel() 56 | -------------------------------------------------------------------------------- /openrlhf/models/lmm_kits/gemma3/data_processor.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict, Union 2 | import json 3 | 4 | import torch 5 | 6 | from ..base.data_processor import BaseDataProcessor, MMInputs 7 | 8 | 9 | class Gemma3_VLDataProcessor(BaseDataProcessor): 10 | def __call__( 11 | self, 12 | messages, 13 | max_length, 14 | padding=True, 15 | device=None, 16 | return_tensors="pt", 17 | add_special_tokens=False, 18 | truncation=True, 19 | ) -> MMInputs: 20 | messages = self._format_messages(messages) 21 | processor = self.processor 22 | batch = processor.apply_chat_template( 23 | messages, 24 | add_generation_prompt=True, 25 | tokenize=True, 26 | padding=padding, 27 | max_length=max_length, 28 | add_special_tokens=add_special_tokens, 29 | truncation=truncation, 30 | return_tensors=return_tensors, 31 | return_dict=True, 32 | ) 33 | emb_inputs, extra_info, forward_inputs = self._split_input_dict(batch) 34 | return MMInputs(emb_inputs=emb_inputs, extra_info=extra_info, forward_inputs=forward_inputs).to(device) 35 | 36 | def _split_input_dict(self, input_dict: Dict) -> tuple[Dict, Dict]: 37 | extra_info = {} 38 | forward_inputs = {} 39 | if "input_ids" in input_dict: 40 | extra_info["input_ids"] = input_dict.pop("input_ids") 41 | if "attention_mask" in input_dict: 42 | extra_info["attention_mask"] = input_dict.pop("attention_mask") 43 | if "token_type_ids" in input_dict: 44 | forward_inputs["token_type_ids"] = input_dict.pop("token_type_ids") 45 | return input_dict, extra_info, forward_inputs 46 | 47 | def make_input_batch(self, inputs: List[MMInputs]) -> MMInputs: 48 | # each element has no batch dimension 49 | batch = {} 50 | # collect all keys 51 | for inp in inputs: 52 | batch.update({k: None for k, v in inp.items() if v is not None}) 53 | for k in batch.keys(): 54 | if k in ["input_ids", "attention_mask", "token_type_ids"]: 55 | batch[k] = torch.stack([inp[k] for inp in inputs if k in inp], dim=0) 56 | elif k in ["pixel_values"]: 57 | # concat all patches of all images in a batch in the first dimension 58 | batch[k] = torch.cat([inp[k] for inp in inputs if k in inp], dim=0) 59 | else: 60 | raise ValueError(f"Unknown key {k} for Gemma3_VLDataProcessor") 61 | emb_inputs, extra_info, forward_inputs = self._split_input_dict(batch) 62 | return MMInputs(emb_inputs=emb_inputs, extra_info=extra_info, forward_inputs=forward_inputs) 63 | 64 | def split_input_batch(self, batch: MMInputs) -> List[MMInputs]: 65 | batch_size = len(batch["input_ids"]) 66 | batch_kwargs = [{} for _ in range(batch_size)] 67 | # first process None values 68 | keys = [] 69 | for k, v in batch.items(): 70 | if v is not None: 71 | keys.append(k) 72 | else: 73 | for i in range(batch_size): 74 | batch_kwargs[i][k] = None 75 | 76 | if "pixel_values" in keys and ("input_ids" not in keys): 77 | raise ValueError("Cannot split batch with pixel_values without input_ids") 78 | 79 | for k in ["input_ids", "attention_mask", "token_type_ids"]: 80 | if k in keys: 81 | vals = batch[k] 82 | if isinstance(vals, torch.Tensor): 83 | vals = torch.unbind(vals) 84 | assert batch_size == len(vals) 85 | for i, v in enumerate(vals): 86 | 
batch_kwargs[i][k] = v 87 | if "pixel_values" in keys: 88 | pixel_values = batch["pixel_values"] 89 | for i in range(batch_size): 90 | token_type_ids_i = batch_kwargs[i]["token_type_ids"] 91 | assert (token_type_ids_i == 1).sum() % 256 == 0 92 | img_num = (token_type_ids_i == 1).sum().item() // 256 93 | if img_num == 0: 94 | batch_kwargs[i]["pixel_values"] = None 95 | continue 96 | 97 | pixel_values_i = pixel_values[:img_num] 98 | assert len(pixel_values_i) == img_num 99 | pixel_values = pixel_values[img_num:] 100 | batch_kwargs[i]["pixel_values"] = pixel_values_i 101 | assert len(pixel_values) == 0, f"{pixel_values.shape}" 102 | mm_inputs_list = [] 103 | for b in batch_kwargs: 104 | emb_inputs, extra_info, forward_inputs = self._split_input_dict(b) 105 | mm_inputs_list.append( 106 | MMInputs(emb_inputs=emb_inputs, extra_info=extra_info, forward_inputs=forward_inputs) 107 | ) 108 | return mm_inputs_list 109 | 110 | def warp_str_content_to_dict(self, messages_list: List[List[Dict]]): 111 | """ 112 | Gemma Processor needs the content key to be a list of dict. 113 | """ 114 | for messages in messages_list: 115 | for message in messages: 116 | if isinstance(message["content"], str): 117 | message["content"] = [{"type": "text", "text": message["content"]}] 118 | else: 119 | assert isinstance(message["content"], list) 120 | return messages_list 121 | 122 | def _format_messages(self, messages: Union[Dict, List[str], str]) -> List[List[Dict]]: 123 | formated_messages = super()._format_messages(messages) 124 | formated_messages = self.warp_str_content_to_dict(formated_messages) 125 | 126 | return formated_messages 127 | 128 | 129 | DataProcessor = Gemma3_VLDataProcessor 130 | 131 | __all__ = ["DataProcessor"] 132 | -------------------------------------------------------------------------------- /openrlhf/models/lmm_kits/gemma3/patch.py: -------------------------------------------------------------------------------- 1 | from ..base.patch import BasePatch 2 | import torch 3 | 4 | 5 | class Gemma3_Patch(BasePatch): 6 | def _add_get_inputs_embeds(): 7 | from transformers import Gemma3ForConditionalGeneration 8 | from transformers.utils import is_torchdynamo_compiling 9 | 10 | def get_inputs_embeds(self, input_ids, pixel_values=None, **kwargs): 11 | inputs_embeds = self.get_input_embeddings()(input_ids) 12 | # Merge text and images 13 | if pixel_values is not None: 14 | image_features = self.get_image_features(pixel_values) 15 | 16 | if input_ids is None: 17 | special_image_mask = inputs_embeds == self.get_input_embeddings()( 18 | torch.tensor(self.config.image_token_index, dtype=torch.long, device=inputs_embeds.device) 19 | ) 20 | else: 21 | special_image_mask = (input_ids == self.config.image_token_index).unsqueeze(-1) 22 | special_image_mask = special_image_mask.expand_as(inputs_embeds).to(inputs_embeds.device) 23 | 24 | if ( 25 | not is_torchdynamo_compiling() 26 | and inputs_embeds[special_image_mask].numel() != image_features.numel() 27 | ): 28 | image_tokens_in_text = (special_image_mask).sum(dim=1).sum(dim=0)[0] 29 | raise ValueError( 30 | f"Number of images does not match number of special image tokens in the input text. " 31 | f"Got {image_tokens_in_text} image tokens in the text but {image_features.shape[0] * image_features.shape[1]} " 32 | "tokens from image embeddings." 
33 | ) 34 | image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype) 35 | inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features) 36 | else: 37 | fake_pixel_values = torch.zeros( 38 | (1, 3, 224, 224), device=inputs_embeds.device, dtype=inputs_embeds.dtype 39 | ) 40 | image_features = self.get_image_features(fake_pixel_values) 41 | inputs_embeds = inputs_embeds + 0 * image_features.mean() 42 | return inputs_embeds 43 | 44 | Gemma3ForConditionalGeneration.get_inputs_embeds = get_inputs_embeds 45 | 46 | def _add_get_position_ids(): 47 | from transformers import Gemma3ForConditionalGeneration 48 | 49 | def get_position_ids(self, input_ids, attention_mask=None, **kwargs): 50 | if attention_mask is None: 51 | attention_mask = torch.ones_like(input_ids) 52 | position_ids = attention_mask.long().cumsum(-1) - 1 53 | position_ids.masked_fill_(attention_mask == 0, 1) 54 | return position_ids 55 | 56 | Gemma3ForConditionalGeneration.get_position_ids = get_position_ids 57 | 58 | def _add_offset_split_position_ids(): 59 | from transformers import Gemma3ForConditionalGeneration 60 | 61 | def offset_split_position_ids(self, split_position_ids, hacked_position_ids): 62 | # For common position_ids, hacked_position_ids is what we want 63 | return hacked_position_ids 64 | 65 | Gemma3ForConditionalGeneration.offset_split_position_ids = offset_split_position_ids 66 | 67 | def apply_liger_kernel(): 68 | from liger_kernel.transformers import apply_liger_kernel_to_gemma3 69 | 70 | apply_liger_kernel_to_gemma3() 71 | 72 | @classmethod 73 | def _load_all_patches(cls): 74 | cls._add_get_inputs_embeds() 75 | cls._add_get_position_ids() 76 | cls._add_offset_split_position_ids() 77 | 78 | 79 | Patch = Gemma3_Patch() 80 | -------------------------------------------------------------------------------- /openrlhf/models/lmm_kits/llm/data_processor.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from ..base.data_processor import BaseDataProcessor, MMInputs 3 | from qwen_vl_utils import process_vision_info 4 | import torch 5 | from loguru import logger 6 | 7 | class LLMProcessor: 8 | def __init__(self,tokenizer): 9 | self.tokenizer = tokenizer 10 | 11 | def __call__(self,*args,**kwargs): 12 | return self.tokenizer(*args,**kwargs) 13 | 14 | def apply_chat_template(self,*args,**kwargs): 15 | return self.tokenizer.apply_chat_template(*args,**kwargs) 16 | 17 | def save_pretrained(self,*args,**kwargs): 18 | self.tokenizer.save_pretrained(*args,**kwargs) 19 | 20 | 21 | class LLMDataProcessor(BaseDataProcessor): 22 | def __call__( 23 | self, 24 | messages, 25 | max_length, 26 | padding=True, 27 | device=None, 28 | return_tensors="pt", 29 | add_special_tokens=False, 30 | truncation=True, 31 | ) -> MMInputs: 32 | messages = self._format_messages(messages) 33 | texts = self.processor.apply_chat_template( 34 | messages, tokenize=False, add_generation_prompt=True 35 | ) 36 | image_inputs, video_inputs = process_vision_info(messages) 37 | if image_inputs or video_inputs: 38 | logger.warning("Vision inputs are not supported for LLMs") 39 | batch = self.processor( 40 | texts, 41 | padding=padding, 42 | max_length=max_length, 43 | truncation=truncation, 44 | return_tensors=return_tensors, 45 | add_special_tokens=add_special_tokens, 46 | ) 47 | return MMInputs(extra_info=batch).to(device) 48 | 49 | def make_input_batch(self, inputs: List[MMInputs]) -> MMInputs: 50 | input_ids = torch.stack([inp["input_ids"] for inp in inputs], 
dim=0) 51 | attention_mask = torch.stack([inp["attention_mask"] for inp in inputs], dim=0) 52 | return MMInputs(extra_info={"input_ids": input_ids, "attention_mask": attention_mask}) 53 | 54 | def split_input_batch(self, batch: MMInputs) -> List[MMInputs]: 55 | input_ids_batch = batch["input_ids"].unbind(dim=0) 56 | attention_mask_batch = batch["attention_mask"].unbind(dim=0) 57 | return [ 58 | MMInputs(extra_info={"input_ids": input_id, "attention_mask": attention_mask}) 59 | for input_id, attention_mask in zip(input_ids_batch, attention_mask_batch) 60 | ] 61 | 62 | DataProcessor = LLMDataProcessor 63 | 64 | __all__ = ["LLMProcessor", "DataProcessor"] -------------------------------------------------------------------------------- /openrlhf/models/lmm_kits/llm/patch.py: -------------------------------------------------------------------------------- 1 | from ..base.patch import BasePatch 2 | import torch 3 | 4 | class LLMPatch(BasePatch): 5 | ''' 6 | This patch is used to hack the common LLM models for compatibility with LMMs. 7 | ''' 8 | def _add_get_inputs_embeds(): 9 | from transformers.modeling_utils import PreTrainedModel 10 | def get_inputs_embeds(self, input_ids, **kwargs): 11 | if input_ids is None: 12 | return None 13 | return self.get_input_embeddings()(input_ids) 14 | PreTrainedModel.get_inputs_embeds = get_inputs_embeds 15 | 16 | def _add_get_position_ids(): 17 | from transformers.modeling_utils import PreTrainedModel 18 | def get_position_ids(self, input_ids, attention_mask=None, **kwargs): 19 | if attention_mask is None: 20 | attention_mask = torch.ones_like(input_ids) 21 | position_ids = attention_mask.long().cumsum(-1) - 1 22 | position_ids.masked_fill_(attention_mask == 0, 1) 23 | return position_ids 24 | PreTrainedModel.get_position_ids = get_position_ids 25 | 26 | def _add_offset_split_position_ids(): 27 | from transformers.modeling_utils import PreTrainedModel 28 | def offset_split_position_ids(self, split_position_ids, hacked_position_ids): 29 | # For common position_ids, hacked_position_ids is what we want 30 | return hacked_position_ids 31 | PreTrainedModel.offset_split_position_ids = offset_split_position_ids 32 | 33 | def apply_liger_kernel(): 34 | # For LLM, we directly apply liger_kernel in get_generation_cls 35 | pass 36 | 37 | @classmethod 38 | def _load_all_patches(cls): 39 | cls._add_get_inputs_embeds() 40 | cls._add_get_position_ids() 41 | cls._add_offset_split_position_ids() 42 | 43 | Patch = LLMPatch() -------------------------------------------------------------------------------- /openrlhf/models/lmm_kits/phi3_v/patch.py: -------------------------------------------------------------------------------- 1 | from ..base.patch import BasePatch 2 | import torch 3 | 4 | class Phi3_VPatch(BasePatch): 5 | def _register_to_autoclass(): 6 | from transformers import AutoModelForImageTextToText, AutoConfig, AutoProcessor 7 | from .src.configuration_phi3_v import Phi3VConfig 8 | from .src.modeling_phi3_v import Phi3VForCausalLM 9 | from .src.processing_phi3_v import Phi3VProcessor 10 | AutoConfig.register("phi3_v", Phi3VConfig) 11 | AutoModelForImageTextToText.register(Phi3VConfig, Phi3VForCausalLM) 12 | AutoProcessor.register(Phi3VConfig, Phi3VProcessor) 13 | 14 | def _add_get_inputs_embeds(): 15 | from .src.modeling_phi3_v import Phi3VForCausalLM 16 | def get_inputs_embeds(self, input_ids, pixel_values=None, image_sizes=None, **kwargs): 17 | return self.model.vision_embed_tokens(input_ids, pixel_values=pixel_values, image_sizes=image_sizes) 18 | 
Phi3VForCausalLM.get_inputs_embeds = get_inputs_embeds 19 | 20 | def _add_get_position_ids(): 21 | from .src.modeling_phi3_v import Phi3VForCausalLM 22 | def get_position_ids(self, input_ids, attention_mask=None, **kwargs): 23 | if attention_mask is None: 24 | attention_mask = torch.ones_like(input_ids) 25 | position_ids = attention_mask.long().cumsum(-1) - 1 26 | position_ids.masked_fill_(attention_mask == 0, 1) 27 | return position_ids 28 | Phi3VForCausalLM.get_position_ids = get_position_ids 29 | 30 | def _add_offset_split_position_ids(): 31 | from .src.modeling_phi3_v import Phi3VForCausalLM 32 | def offset_split_position_ids(self, split_position_ids, hacked_position_ids): 33 | # For common position_ids, hacked_position_ids is what we want 34 | return hacked_position_ids 35 | Phi3VForCausalLM.offset_split_position_ids = offset_split_position_ids 36 | 37 | def apply_liger_kernel(): 38 | from liger_kernel.transformers import LigerPhi3SwiGLUMLP, LigerRMSNorm, liger_rotary_pos_emb 39 | from .src import modeling_phi3_v 40 | modeling_phi3_v.Phi3MLP = LigerPhi3SwiGLUMLP 41 | modeling_phi3_v.Phi3RMSNorm = LigerRMSNorm 42 | modeling_phi3_v.apply_rotary_pos_emb = liger_rotary_pos_emb 43 | 44 | @classmethod 45 | def _load_all_patches(cls): 46 | cls._add_get_inputs_embeds() 47 | cls._add_get_position_ids() 48 | cls._add_offset_split_position_ids() 49 | cls._register_to_autoclass() 50 | 51 | Patch = Phi3_VPatch() -------------------------------------------------------------------------------- /openrlhf/models/lmm_kits/phi4mm/patch.py: -------------------------------------------------------------------------------- 1 | from ..base.patch import BasePatch 2 | import torch 3 | 4 | class Phi4MMPatch(BasePatch): 5 | def _register_to_autoclass(): 6 | from transformers import AutoModelForImageTextToText, AutoConfig, AutoProcessor 7 | from .src.configuration_phi4mm import Phi4MMConfig 8 | from .src.modeling_phi4mm import Phi4MMForCausalLM 9 | from .src.processing_phi4mm import Phi4MMProcessor 10 | AutoConfig.register("phi4mm", Phi4MMConfig) 11 | AutoModelForImageTextToText.register(Phi4MMConfig, Phi4MMForCausalLM) 12 | AutoProcessor.register(Phi4MMConfig, Phi4MMProcessor) 13 | 14 | def _add_get_inputs_embeds(): 15 | from .src.modeling_phi4mm import Phi4MMForCausalLM 16 | def get_inputs_embeds(self, input_ids, num_img_tokens=None,input_mode=None,**kwargs): 17 | return self.model.embed_tokens_extend(input_ids,wte=self.model.embed_tokens,**kwargs) 18 | Phi4MMForCausalLM.get_inputs_embeds = get_inputs_embeds 19 | 20 | def _add_get_position_ids(): 21 | from .src.modeling_phi4mm import Phi4MMForCausalLM 22 | def get_position_ids(self, input_ids, attention_mask=None, **kwargs): 23 | if attention_mask is None: 24 | attention_mask = torch.ones_like(input_ids) 25 | position_ids = attention_mask.long().cumsum(-1) - 1 26 | position_ids.masked_fill_(attention_mask == 0, 1) 27 | return position_ids 28 | Phi4MMForCausalLM.get_position_ids = get_position_ids 29 | 30 | def _add_offset_split_position_ids(): 31 | from .src.modeling_phi4mm import Phi4MMForCausalLM 32 | def offset_split_position_ids(self, split_position_ids, hacked_position_ids): 33 | # For common position_ids, hacked_position_ids is what we want 34 | return hacked_position_ids 35 | Phi4MMForCausalLM.offset_split_position_ids = offset_split_position_ids 36 | 37 | def _hack_multihead_attention(): 38 | import torch 39 | import torch.nn as nn 40 | raw_MultiheadAttention = nn.MultiheadAttention 41 | """ 42 | MultiheadAttention is not compatible with zero3, it 
accesses out_proj.weight and out_proj.bias directly, which is partitioned by zero3. 43 | We re-assign out_proj.weight and out_proj.bias to a new parameter, and the new parameter is not partitioned by zero3. 44 | """ 45 | class HackedMultiheadAttention(raw_MultiheadAttention): 46 | def __init__(self, *args, **kwargs): 47 | super().__init__(*args, **kwargs) 48 | output_proj_weight = nn.Parameter(torch.empty((self.embed_dim, self.embed_dim))) 49 | output_proj_bias = nn.Parameter(torch.zeros(self.embed_dim)) 50 | self.out_proj.weight = output_proj_weight 51 | self.out_proj.bias = output_proj_bias 52 | 53 | torch.nn.MultiheadAttention = HackedMultiheadAttention 54 | 55 | 56 | def apply_liger_kernel(): 57 | from liger_kernel.transformers import LigerPhi3SwiGLUMLP, LigerRMSNorm, liger_rotary_pos_emb 58 | from .src import modeling_phi4mm 59 | modeling_phi4mm.Phi4MMMLP = LigerPhi3SwiGLUMLP 60 | modeling_phi4mm.Phi4MMRMSNorm = LigerRMSNorm 61 | modeling_phi4mm.apply_rotary_pos_emb = liger_rotary_pos_emb 62 | 63 | @classmethod 64 | def _load_all_patches(cls): 65 | cls._add_get_inputs_embeds() 66 | cls._add_get_position_ids() 67 | cls._add_offset_split_position_ids() 68 | cls._hack_multihead_attention() 69 | cls._register_to_autoclass() 70 | 71 | Patch = Phi4MMPatch() -------------------------------------------------------------------------------- /openrlhf/models/lmm_kits/qwen2_5_vl/data_processor.py: -------------------------------------------------------------------------------- 1 | from typing import List, Dict 2 | 3 | import torch 4 | from qwen_vl_utils import process_vision_info 5 | from ..base.data_processor import BaseDataProcessor, MMInputs 6 | 7 | 8 | class Qwen2_5_VLDataProcessor(BaseDataProcessor): 9 | def __call__( 10 | self, 11 | messages, 12 | max_length, 13 | padding=True, 14 | device=None, 15 | return_tensors="pt", 16 | add_special_tokens=False, 17 | truncation=True, 18 | ) -> MMInputs: 19 | messages = self._format_messages(messages) 20 | processor = self.processor 21 | texts = processor.apply_chat_template( 22 | messages, tokenize=False, add_generation_prompt=True 23 | ) 24 | image_inputs, video_inputs = process_vision_info(messages) 25 | 26 | batch = processor( 27 | text=texts, 28 | images=image_inputs, 29 | videos=video_inputs, 30 | padding=padding, 31 | max_length=max_length, 32 | add_special_tokens=add_special_tokens, 33 | truncation=truncation, 34 | return_tensors=return_tensors, 35 | ) 36 | emb_inputs, extra_info = self._split_input_dict(batch) 37 | return MMInputs(emb_inputs=emb_inputs,extra_info=extra_info).to(device) 38 | 39 | def _split_input_dict(self, input_dict: Dict) -> tuple[Dict, Dict]: 40 | extra_info = {} 41 | if "input_ids" in input_dict: 42 | extra_info["input_ids"] = input_dict.pop("input_ids") 43 | if "attention_mask" in input_dict: 44 | extra_info["attention_mask"] = input_dict.pop("attention_mask") 45 | return input_dict, extra_info 46 | 47 | def make_input_batch(self, inputs: List[MMInputs]) -> MMInputs: 48 | # each element has no batch dimension 49 | batch = {} 50 | # collect all keys 51 | for inp in inputs: 52 | batch.update({k:None for k,v in inp.items() if v is not None}) 53 | for k in batch.keys(): 54 | if k in ["input_ids", "attention_mask"]: 55 | batch[k] = torch.stack([inp[k] for inp in inputs if k in inp], dim=0) 56 | elif k in ["pixel_values", "image_grid_thw"]: 57 | # qwen2vl concat all patches of all images in a batch in the first dimension 58 | batch[k] = torch.cat([inp[k] for inp in inputs if k in inp], dim=0) 59 | else: 60 | raise 
ValueError(f"Unknown key {k} for Qwen2VLDataProcessor") 61 | 62 | emb_inputs, extra_info = self._split_input_dict(batch) 63 | return MMInputs(emb_inputs=emb_inputs,extra_info=extra_info) 64 | 65 | def split_input_batch(self, batch: MMInputs) -> List[MMInputs]: 66 | batch_size = len(batch["input_ids"]) 67 | batch_kwargs = [{} for _ in range(batch_size)] 68 | # first process None values 69 | keys = [] 70 | for k, v in batch.items(): 71 | if v is not None: 72 | keys.append(k) 73 | else: 74 | for i in range(batch_size): 75 | batch_kwargs[i][k] = None 76 | 77 | if "pixel_values" in keys and ( 78 | "input_ids" not in keys or "image_grid_thw" not in keys 79 | ): 80 | raise ValueError( 81 | "Cannot split batch with pixel_values without input_ids and image_grid_thw" 82 | ) 83 | if "image_grid_thw" in keys and ("input_ids" not in keys): 84 | raise ValueError("Cannot split batch with image_grid_thw without input_ids") 85 | for k in ["input_ids", "attention_mask"]: 86 | if k in keys: 87 | vals = batch[k] 88 | if isinstance(vals, torch.Tensor): 89 | vals = torch.unbind(vals) 90 | assert batch_size == len(vals) 91 | for i, v in enumerate(vals): 92 | batch_kwargs[i][k] = v 93 | if "pixel_values" in keys: 94 | thws = batch["image_grid_thw"] # (total_img_num, (t,h,w)) 95 | pixel_values = batch["pixel_values"] 96 | vision_start_id = self.processor.tokenizer("<|vision_start|>")["input_ids"][0] 97 | vision_end_id = self.processor.tokenizer("<|vision_end|>")["input_ids"][0] 98 | for i in range(batch_size): 99 | input_ids_i = batch_kwargs[i]["input_ids"] 100 | if not isinstance(input_ids_i, torch.Tensor): 101 | input_ids_i = torch.tensor(input_ids_i) 102 | vision_start_num = (input_ids_i == vision_start_id).sum().item() 103 | vision_end_num = (input_ids_i == vision_end_id).sum().item() 104 | assert vision_start_num == vision_end_num 105 | img_num = vision_start_num 106 | if img_num == 0: 107 | batch_kwargs[i]["pixel_values"] = None 108 | batch_kwargs[i]["image_grid_thw"] = None 109 | continue 110 | thws_i = thws[:img_num] 111 | assert len(thws_i) == img_num 112 | thws = thws[img_num:] 113 | if not isinstance(thws_i, torch.Tensor): 114 | thws_i = torch.stack(thws_i) 115 | batch_kwargs[i]["image_grid_thw"] = thws_i 116 | patchs_num = thws_i.prod(dim=1).sum().item() 117 | pixel_values_i = pixel_values[:patchs_num] 118 | assert len(pixel_values_i) == patchs_num 119 | pixel_values = pixel_values[patchs_num:] 120 | batch_kwargs[i]["pixel_values"] = pixel_values_i 121 | assert len(thws) == 0 122 | assert len(pixel_values) == 0 123 | mm_inputs_list = [] 124 | for b in batch_kwargs: 125 | emb_inputs, extra_info = self._split_input_dict(b) 126 | mm_inputs_list.append(MMInputs(emb_inputs=emb_inputs,extra_info=extra_info)) 127 | return mm_inputs_list 128 | 129 | DataProcessor = Qwen2_5_VLDataProcessor 130 | 131 | __all__ = ["DataProcessor"] -------------------------------------------------------------------------------- /openrlhf/models/lmm_kits/qwen2_5_vl/patch.py: -------------------------------------------------------------------------------- 1 | from ..base.patch import BasePatch 2 | import torch 3 | 4 | class Qwen2_5_VLPatch(BasePatch): 5 | def _add_get_inputs_embeds(): 6 | from transformers import Qwen2_5_VLForConditionalGeneration 7 | def get_inputs_embeds(self, input_ids, image_grid_thw=None, video_grid_thw=None, pixel_values=None, pixel_values_videos=None, **kwargs): 8 | inputs_embeds = self.model.embed_tokens(input_ids) 9 | if pixel_values is not None: 10 | pixel_values = pixel_values.type(self.visual.dtype) 11 | 
image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) 12 | n_image_tokens = (input_ids == self.config.image_token_id).sum().item() 13 | n_image_features = image_embeds.shape[0] 14 | if n_image_tokens != n_image_features: 15 | raise ValueError( 16 | f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}" 17 | ) 18 | 19 | mask = input_ids == self.config.image_token_id 20 | mask_unsqueezed = mask.unsqueeze(-1) 21 | mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) 22 | image_mask = mask_expanded.to(inputs_embeds.device) 23 | 24 | image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) 25 | inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds) 26 | else: 27 | fake_pixel_values = torch.zeros((16, 1176),device=inputs_embeds.device,dtype=self.visual.dtype) 28 | fake_image_grid_thw = torch.tensor([[1, 4, 4]],device=inputs_embeds.device,dtype=torch.int32) 29 | image_embeds = self.visual(fake_pixel_values, grid_thw=fake_image_grid_thw) 30 | image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype) 31 | inputs_embeds = inputs_embeds + 0*image_embeds.mean() 32 | 33 | if pixel_values_videos is not None: 34 | pixel_values_videos = pixel_values_videos.type(self.visual.dtype) 35 | video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw) 36 | n_video_tokens = (input_ids == self.config.video_token_id).sum().item() 37 | n_video_features = video_embeds.shape[0] 38 | if n_video_tokens != n_video_features: 39 | raise ValueError( 40 | f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {n_video_features}" 41 | ) 42 | 43 | mask = input_ids == self.config.video_token_id 44 | mask_unsqueezed = mask.unsqueeze(-1) 45 | mask_expanded = mask_unsqueezed.expand_as(inputs_embeds) 46 | video_mask = mask_expanded.to(inputs_embeds.device) 47 | 48 | video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype) 49 | inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds) 50 | return inputs_embeds 51 | 52 | Qwen2_5_VLForConditionalGeneration.get_inputs_embeds = get_inputs_embeds 53 | 54 | def _add_get_position_ids(): 55 | from transformers import Qwen2_5_VLForConditionalGeneration 56 | def get_position_ids(self, input_ids, image_grid_thw=None, video_grid_thw=None, attention_mask=None, packing=False, **kwargs): 57 | position_ids,mrope_position_deltas = self.get_rope_index(input_ids=input_ids, image_grid_thw=image_grid_thw, video_grid_thw=video_grid_thw, attention_mask=attention_mask) 58 | if packing: 59 | # For packing, the position_ids will be unpaded and sliced later, which needs the shape: [bs,seq_len,...] 60 | # However, the position_ids of Qwen2.5VL is [3,bs,seq_len], so we need to permute it. 
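# (Editor's note, not in the original source) The extra leading dimension of 3 comes from Qwen2.5-VL's multimodal RoPE, which keeps separate temporal / height / width position indices for every token; plain text tokens carry the same value on all three axes, so the permute below only changes the layout, not the meaning.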
61 | position_ids = position_ids.permute(1,2,0) # [3,bs,seq_len] -> [bs,seq_len,3] 62 | return position_ids 63 | Qwen2_5_VLForConditionalGeneration.get_position_ids = get_position_ids 64 | 65 | def _add_offset_split_position_ids(): 66 | from transformers import Qwen2_5_VLForConditionalGeneration 67 | def offset_split_position_ids(self,position_ids,hacked_position_ids): 68 | # This function is only called when using packing, in which case, the position_ids is permuted as [bs,seq_len,3] 69 | position_ids = position_ids.permute(2,0,1) # [bs,seq_len,3] -> [3,bs,seq_len] 70 | new_position_ids = position_ids.clone() 71 | for i in range(hacked_position_ids.size(0)): 72 | seq_idxes = torch.nonzero(hacked_position_ids[i]==0)[:,0] 73 | seq_idxes = torch.cat([seq_idxes, torch.tensor([hacked_position_ids.size(1)],device=seq_idxes.device)], dim=0) 74 | st = 0 75 | for seq_idx in seq_idxes: 76 | if st == 0 and seq_idx == 0: 77 | continue 78 | #shape: [3,bs,seq_len] 79 | raw_seq_position_ids = position_ids[:,i,st:seq_idx] 80 | new_position_ids[:,i,st:seq_idx] = raw_seq_position_ids - raw_seq_position_ids[:,:1] + hacked_position_ids[i,st] 81 | st = seq_idx 82 | return new_position_ids 83 | Qwen2_5_VLForConditionalGeneration.offset_split_position_ids = offset_split_position_ids 84 | 85 | def apply_liger_kernel(): 86 | from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl 87 | apply_liger_kernel_to_qwen2_5_vl() 88 | 89 | @classmethod 90 | def _load_all_patches(cls): 91 | cls._add_get_inputs_embeds() 92 | cls._add_get_position_ids() 93 | cls._add_offset_split_position_ids() 94 | 95 | Patch = Qwen2_5_VLPatch() -------------------------------------------------------------------------------- /openrlhf/models/remote_rm/math_verifier.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | import re 5 | from argparse import ArgumentParser 6 | 7 | from flask import Flask, jsonify, request 8 | from latex2sympy2_extended import NormalizationConfig 9 | from math_verify import LatexExtractionConfig, parse, verify 10 | 11 | from loguru import logger 12 | from concurrent import futures 13 | 14 | app = Flask(__name__) 15 | 16 | problem_to_answer = {} 17 | enable_format_reward = True 18 | 19 | 20 | def get_response_from_query(q: str): 21 | ends_of_sentence = ["<|im_end|>", "<|end▁of▁sentence|>", "<|endoftext|>", ""] 22 | pos = re.search(response_prefix, q) 23 | if pos is None: 24 | return None 25 | response = q[pos.end() :] 26 | for e in ends_of_sentence: 27 | response = response.replace(e, "") 28 | return response.strip() 29 | 30 | 31 | def verify_format(content): 32 | """ 33 | Verify if the string meets the format requirements: 34 | - Must start with <think> and end with </answer> 35 | - Must contain exactly one pair of <think>...</think> and <answer>...</answer>
tags 36 | - No extra characters allowed between </think> and <answer> tags 37 | """ 38 | think_count = content.count("<think>") 39 | answer_count = content.count("<answer>") 40 | return bool(re.match(format_pattern, content, re.DOTALL)) and think_count == 1 and answer_count == 1 41 | 42 | 43 | def verify_math(content, sol): 44 | gold_parsed = parse( 45 | sol, 46 | extraction_mode="first_match", 47 | extraction_config=[LatexExtractionConfig()], 48 | ) 49 | if len(gold_parsed) != 0: 50 | # We require the answer to be provided in correct latex (no malformed operators) 51 | answer_parsed = parse( 52 | content, 53 | extraction_config=[ 54 | LatexExtractionConfig( 55 | normalization_config=NormalizationConfig( 56 | nits=False, 57 | malformed_operators=False, 58 | basic_latex=True, 59 | boxed="all", 60 | units=True, 61 | ), 62 | # Ensures that boxed is tried first 63 | boxed_match_priority=0, 64 | try_extract_without_anchor=False, 65 | ) 66 | ], 67 | extraction_mode="first_match", 68 | ) 69 | # Reward 1 if the content is the same as the ground truth, 0 otherwise 70 | try: 71 | reward = float(verify(answer_parsed, gold_parsed)) 72 | except Exception as e: 73 | reward = 0.0 74 | print("Failed to verify: ", e) 75 | else: 76 | # If the gold solution is not parseable, we reward 1 to skip this example 77 | reward = 1.0 78 | print("Failed to parse gold solution: ", sol) 79 | return reward 80 | 81 | 82 | @app.route("/get_reward", methods=["POST"]) 83 | def get_reward(): 84 | # Get the JSON data from the request 85 | data = request.get_json() 86 | if "query" not in data or "prompts" not in data or "labels" not in data: 87 | return jsonify({"error": "query, prompts, and labels fields are required"}), 400 88 | rewards = [] 89 | format_rewards = [] 90 | acc_rewards_futures = [] 91 | for q, problem, answer in zip(data["query"], data["prompts"], data["labels"]): 92 | if problem is None: 93 | return jsonify({"error": f"problem not found from {q}"}), 400 94 | if not answer.startswith("$"): 95 | answer = "$" + answer + "$" 96 | 97 | response = get_response_from_query(q) or q 98 | if response is None: 99 | return jsonify({"error": f"response not found from {q}"}), 400 100 | # Apply format reward only if enabled 101 | format_reward = 0.0 102 | if enable_format_reward: 103 | format_reward = float(verify_format(response)) * 0.5 104 | acc_reward_future = math_verify_executor.submit(verify_math, response, answer) 105 | 106 | do_print = random.randint(1, 20) == 1 107 | if do_print: 108 | info = f"Query: {q}\n\nProblem: {problem}\n\n Answer: {answer}\n\n Response: {response}\n\n Format Reward: {format_reward}\n\n Acc Reward: {acc_reward_future.result()}\n\n" 109 | info = re.sub(r"<\|.*?\|>|", "", info) 110 | logger.info(info) 111 | 112 | format_rewards.append(format_reward) 113 | acc_rewards_futures.append(acc_reward_future) 114 | acc_rewards = [f.result() for f in acc_rewards_futures] 115 | rewards = [f + a for f, a in zip(format_rewards, acc_rewards)] 116 | # Return the response containing the rewards 117 | return jsonify({"rewards": rewards, "format_rewards": format_rewards, "acc_rewards": acc_rewards}) 118 | 119 | 120 | if __name__ == "__main__": 121 | parser = ArgumentParser() 122 | parser.add_argument("--prompt-template", type=str, default=None, help="Prompt template", required=True) 123 | parser.add_argument("--input_key", type=str, default="prompt", help="The key name of prompt.") 124 | parser.add_argument("--log_file", type=str, default="remote_rm.log", help="Log file path") 125 | parser.add_argument( 126 | "--disable-format-reward", 127 | action="store_true", 128 | help="Disable format
reward calculation. When enabled (default), responses get +0.5 reward for correct format.", 129 | ) 130 | args = parser.parse_args() 131 | if os.path.exists(args.log_file): 132 | os.remove(args.log_file) 133 | logger.remove() 134 | logger.add(args.log_file) 135 | 136 | # Set format reward flag based on command line argument 137 | enable_format_reward = not args.disable_format_reward 138 | print(f"Format reward is {'disabled' if args.disable_format_reward else 'enabled'}") 139 | logger.info(f"Format reward is {'disabled' if args.disable_format_reward else 'enabled'}") 140 | 141 | format_pattern = r"^<think>(?:(?!</think>).)*</think><answer>(?:(?!</answer>).)*</answer>\Z" 142 | 143 | if args.prompt_template == "chatml": 144 | response_prefix = r"<\|im_start\|>assistant\n" 145 | elif args.prompt_template == "qwen1": 146 | response_prefix = r"<|Assistant|>" 147 | elif args.prompt_template == "base": 148 | response_prefix = r"Assistant: " 149 | elif args.prompt_template == "phi3": 150 | response_prefix = r"<|assistant|>\n" 151 | elif args.prompt_template == "phi4": 152 | response_prefix = r"<|assistant|>\n" 153 | elif args.prompt_template == "gemma3": 154 | response_prefix = r"model\n" 155 | else: 156 | raise ValueError(f"Unknown chat format: {args.prompt_template}") 157 | 158 | # math_verify can only run in main thread 159 | math_verify_executor = futures.ProcessPoolExecutor(max_workers=16) 160 | 161 | app.run(host="0.0.0.0", port=5000, debug=False, use_reloader=False) 162 | math_verify_executor.shutdown() 163 | -------------------------------------------------------------------------------- /openrlhf/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from .dpo_trainer import DPOTrainer 2 | from .kd_trainer import KDTrainer 3 | from .kto_trainer import KTOTrainer 4 | from .ppo_trainer import PPOTrainer 5 | from .prm_trainer import ProcessRewardModelTrainer 6 | from .rm_trainer import RewardModelTrainer 7 | from .sft_trainer import SFTTrainer 8 | 9 | __all__ = [ 10 | "DPOTrainer", 11 | "KDTrainer", 12 | "KTOTrainer", 13 | "PPOTrainer", 14 | "ProcessRewardModelTrainer", 15 | "RewardModelTrainer", 16 | "SFTTrainer", 17 | ] 18 | -------------------------------------------------------------------------------- /openrlhf/trainer/ppo_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .kl_controller import AdaptiveKLController, FixedKLController 2 | from .replay_buffer import NaiveReplayBuffer 3 | 4 | __all__ = [ 5 | "AdaptiveKLController", 6 | "FixedKLController", 7 | "NaiveReplayBuffer", 8 | ] 9 | -------------------------------------------------------------------------------- /openrlhf/trainer/ppo_utils/kl_controller.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class AdaptiveKLController: 5 | """ 6 | Adaptive KL controller described in the paper: 7 | https://arxiv.org/pdf/1909.08593.pdf 8 | """ 9 | 10 | def __init__(self, init_kl_coef, target, horizon): 11 | self.value = init_kl_coef 12 | self.target = target 13 | self.horizon = horizon 14 | 15 | def update(self, current, n_steps): 16 | target = self.target 17 | proportional_error = np.clip(current / target - 1, -0.2, 0.2) 18 | mult = 1 + proportional_error * n_steps / self.horizon 19 | self.value *= mult 20 | 21 | 22 | class FixedKLController: 23 | """Fixed KL controller.""" 24 | 25 | def __init__(self, kl_coef): 26 | self.value = kl_coef 27 | 28 | def update(self, current, n_steps): 29 | pass 30 |
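# ---------------------------------------------------------------------------
# Editor's note: the lines below are an illustrative usage sketch, not part of
# kl_controller.py. The numbers are fabricated; a training loop would drive
# these controllers with the KL measured on each rollout batch.
adaptive = AdaptiveKLController(init_kl_coef=0.02, target=6.0, horizon=10000)
fixed = FixedKLController(kl_coef=0.01)

measured_kl = 7.5  # mean KL(policy || reference) observed in the latest batch (made up)
n_steps = 256      # number of samples those statistics cover (made up)

adaptive.update(measured_kl, n_steps)  # measured_kl > target, so the coefficient is nudged up
fixed.update(measured_kl, n_steps)     # no-op: the fixed controller never changes its value

print(adaptive.value, fixed.value)     # 0.0201024 0.01
# ---------------------------------------------------------------------------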
-------------------------------------------------------------------------------- /openrlhf/trainer/ray/__init__.py: -------------------------------------------------------------------------------- 1 | from .launcher import DistributedTorchRayActor, PPORayActorGroup, ReferenceModelRayActor, RewardModelRayActor 2 | from .vllm_engine import batch_vllm_engine_call, create_vllm_engines 3 | 4 | __all__ = [ 5 | "DistributedTorchRayActor", 6 | "PPORayActorGroup", 7 | "ReferenceModelRayActor", 8 | "RewardModelRayActor", 9 | "create_vllm_engines", 10 | "batch_vllm_engine_call", 11 | ] 12 | -------------------------------------------------------------------------------- /openrlhf/trainer/ray/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | # Address https://github.com/ray-project/ray/issues/51117 5 | # This function is used to get the bundle indices of a placement group 6 | # and ensure that the bundles placed on the same node are grouped together. 7 | def get_bundle_indices(placement_group, index, length): 8 | import ray 9 | 10 | pg_infos = ray.util.placement_group_table(placement_group) 11 | 12 | node_id_to_bundles = {} 13 | for bundle, node_id in pg_infos["bundles_to_node_id"].items(): 14 | node_id_to_bundles.setdefault(node_id, []).append(bundle) 15 | 16 | sorted_bundle_indices = sum(node_id_to_bundles.values(), []) 17 | return sorted_bundle_indices[index * length : (index + 1) * length] 18 | 19 | 20 | def ray_noset_visible_devices(env_vars=os.environ): 21 | # Refer to 22 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/nvidia_gpu.py#L95-L96 23 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/amd_gpu.py#L102-L103 24 | # https://github.com/ray-project/ray/blob/3b9e729f6a669ffd85190f901f5e262af79771b0/python/ray/_private/accelerators/amd_gpu.py#L114-L115 25 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/npu.py#L94-L95 26 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/hpu.py#L116-L117 27 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/neuron.py#L108-L109 28 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/tpu.py#L171-L172 29 | # https://github.com/ray-project/ray/blob/161849364a784442cc659fb9780f1a6adee85fce/python/ray/_private/accelerators/intel_gpu.py#L97-L98 30 | NOSET_VISIBLE_DEVICES_ENV_VARS_LIST = [ 31 | "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES", 32 | "RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES", 33 | "RAY_EXPERIMENTAL_NOSET_HIP_VISIBLE_DEVICES", 34 | "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES", 35 | "RAY_EXPERIMENTAL_NOSET_HABANA_VISIBLE_MODULES", 36 | "RAY_EXPERIMENTAL_NOSET_NEURON_RT_VISIBLE_CORES", 37 | "RAY_EXPERIMENTAL_NOSET_TPU_VISIBLE_CHIPS", 38 | "RAY_EXPERIMENTAL_NOSET_ONEAPI_DEVICE_SELECTOR", 39 | ] 40 | return any(env_vars.get(env_var) for env_var in NOSET_VISIBLE_DEVICES_ENV_VARS_LIST) 41 | 42 | 43 | def get_physical_gpu_id(): 44 | import torch 45 | 46 | device = torch.cuda.current_device() 47 | props = torch.cuda.get_device_properties(device) 48 | return str(props.uuid) 49 | -------------------------------------------------------------------------------- 
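[Editor's note] A small self-contained sketch (not part of the repository) of the node-grouping idea behind `get_bundle_indices` above: bundles that live on the same node are listed consecutively, so slicing by `index * length` tends to give each actor group a block of co-located bundles. The `bundles_to_node_id` mapping below is fabricated for illustration.

# Illustration only: mimics ray.util.placement_group_table(pg)["bundles_to_node_id"]
bundles_to_node_id = {0: "node-B", 1: "node-A", 2: "node-B", 3: "node-A"}

node_id_to_bundles = {}
for bundle, node_id in bundles_to_node_id.items():
    node_id_to_bundles.setdefault(node_id, []).append(bundle)

# Group per node, then flatten: [0, 2, 1, 3]
sorted_bundle_indices = sum(node_id_to_bundles.values(), [])

index, length = 1, 2  # the second actor group, two bundles per group
print(sorted_bundle_indices[index * length : (index + 1) * length])  # -> [1, 3], both on node-A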
/openrlhf/trainer/ray/vllm_worker_wrap.py: -------------------------------------------------------------------------------- 1 | class WorkerWrap: 2 | def init_process_group( 3 | self, master_address, master_port, rank_offset, world_size, group_name, backend="nccl", use_ray=False 4 | ): 5 | """Init torch process group for model weights update""" 6 | import torch 7 | from openrlhf.utils.distributed_util import init_process_group 8 | 9 | assert torch.distributed.is_initialized(), f"default torch process group must be initialized" 10 | assert group_name != "", f"group name must not be empty" 11 | 12 | rank = torch.distributed.get_rank() + rank_offset 13 | if use_ray: 14 | import ray.util.collective as collective 15 | 16 | collective.init_collective_group(world_size=world_size, rank=rank, backend=backend, group_name=group_name) 17 | self._model_update_group = group_name 18 | else: 19 | self._model_update_group = init_process_group( 20 | backend=backend, 21 | init_method=f"tcp://{master_address}:{master_port}", 22 | world_size=world_size, 23 | rank=rank, 24 | group_name=group_name, 25 | ) 26 | self._model_update_with_ray = use_ray 27 | print( 28 | f"init_process_group: master_address={master_address}, master_port={master_port}, ", 29 | f"rank={rank}, world_size={world_size}, group_name={group_name}", 30 | ) 31 | 32 | def update_weight(self, name, dtype, shape, empty_cache=False): 33 | import torch 34 | 35 | """Broadcast weight to all vllm workers from source rank 0 (actor model)""" 36 | if torch.distributed.get_rank() == 0: 37 | print(f"update weight: {name}, dtype: {dtype}, shape: {shape}") 38 | 39 | assert dtype == self.model_config.dtype, f"mismatch dtype: src {dtype}, dst {self.model_config.dtype}" 40 | weight = torch.empty(shape, dtype=dtype, device="cuda") 41 | if self._model_update_with_ray: 42 | import ray.util.collective as collective 43 | 44 | collective.broadcast(weight, 0, group_name=self._model_update_group) 45 | else: 46 | torch.distributed.broadcast(weight, 0, group=self._model_update_group) 47 | 48 | self.model_runner.model.load_weights(weights=[(name, weight)]) 49 | 50 | del weight 51 | # TODO: should we empty cache if all weights have updated? 
52 | # if empty_cache: 53 | # torch.cuda.empty_cache() 54 | 55 | def update_weight_cuda_ipc(self, name, dtype, shape, ipc_handles=None, empty_cache=False): 56 | import torch 57 | from openrlhf.trainer.ray.utils import get_physical_gpu_id 58 | 59 | if torch.distributed.get_rank() == 0: 60 | print(f"update weight: {name}, dtype: {dtype}, shape: {shape}") 61 | 62 | assert dtype == self.model_config.dtype, f"mismatch dtype: src {dtype}, dst {self.model_config.dtype}" 63 | 64 | handle = ipc_handles[get_physical_gpu_id()] 65 | device_id = self.device.index 66 | func, args = handle 67 | list_args = list(args) 68 | # the key is to change device id to the current device id 69 | # in case two processes have different CUDA_VISIBLE_DEVICES 70 | list_args[6] = device_id 71 | weight = func(*list_args) 72 | self.model_runner.model.load_weights(weights=[(name, weight)]) 73 | torch.cuda.synchronize() 74 | -------------------------------------------------------------------------------- /openrlhf/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .processor import get_processor, reward_normalization 2 | from .utils import blending_datasets, get_strategy, get_tokenizer 3 | 4 | __all__ = [ 5 | "get_processor", 6 | "reward_normalization", 7 | "blending_datasets", 8 | "get_strategy", 9 | "get_tokenizer" 10 | ] 11 | -------------------------------------------------------------------------------- /openrlhf/utils/deepspeed/__init__.py: -------------------------------------------------------------------------------- 1 | from .deepspeed import DeepspeedStrategy 2 | 3 | __all__ = [ 4 | "DeepspeedStrategy", 5 | ] 6 | -------------------------------------------------------------------------------- /openrlhf/utils/deepspeed/deepspeed_utils.py: -------------------------------------------------------------------------------- 1 | from deepspeed.runtime.zero.partition_parameters import ZeroParamStatus 2 | 3 | 4 | def get_train_ds_config( 5 | offload, 6 | adam_offload=True, 7 | stage=2, 8 | bf16=True, 9 | max_norm=1.0, 10 | zpg=8, 11 | grad_accum_dtype=None, 12 | overlap_comm=False, 13 | use_ds_universal_ckpt=False, 14 | deepcompile=False, 15 | ): 16 | device = "cpu" if offload else "none" 17 | zero_opt_dict = { 18 | "stage": stage, 19 | "offload_param": {"device": device}, 20 | "offload_optimizer": { 21 | "device": "cpu" if adam_offload else "none", 22 | "pin_memory": True, 23 | }, 24 | "sub_group_size": "auto", 25 | "stage3_max_live_parameters": "auto", 26 | "stage3_max_reuse_distance": "auto", 27 | "stage3_param_persistence_threshold": "auto", 28 | "stage3_prefetch_bucket_size": "auto", 29 | "reduce_bucket_size": "auto", 30 | # ZeRO++ 31 | "zero_hpz_partition_size": zpg, 32 | "zero_quantized_weights": False, 33 | "zero_quantized_gradients": False, 34 | } 35 | if overlap_comm: 36 | zero_opt_dict["overlap_comm"] = True 37 | zero_opt_dict["contiguous_gradients"] = True 38 | if stage == 3: 39 | zero_opt_dict["reduce_scatter"] = True 40 | 41 | return { 42 | "steps_per_print": 100, 43 | "zero_optimization": zero_opt_dict, 44 | "bf16": { 45 | "enabled": bf16, 46 | }, 47 | "gradient_clipping": max_norm, 48 | "prescale_gradients": False, 49 | "wall_clock_breakdown": False, 50 | "data_types": {"grad_accum_dtype": grad_accum_dtype}, 51 | "checkpoint": { 52 | "load_universal": use_ds_universal_ckpt, 53 | }, 54 | "compile": { 55 | "deepcompile": deepcompile, 56 | }, 57 | } 58 | 59 | 60 | def get_eval_ds_config( 61 | offload, 62 | stage=0, 63 | bf16=True, 64 | deepcompile=False, 
65 | ): 66 | # At least for 0.16.6, DeepCompile hasn't support pure inference mode 67 | # https://github.com/deepspeedai/DeepSpeed/pull/7225 68 | deepcompile = False 69 | 70 | zero_opt_dict = { 71 | "stage": stage, 72 | "stage3_max_live_parameters": "auto", 73 | "stage3_max_reuse_distance": "auto", 74 | "stage3_param_persistence_threshold": "auto", 75 | "stage3_prefetch_bucket_size": "auto", 76 | "offload_param": { 77 | "device": "cpu" if offload else "none", 78 | "pin_memory": True, 79 | }, 80 | } 81 | return { 82 | "steps_per_print": 100, 83 | "zero_optimization": zero_opt_dict, 84 | "bf16": { 85 | "enabled": bf16, 86 | }, 87 | "gradient_clipping": 1.0, 88 | "prescale_gradients": False, 89 | "wall_clock_breakdown": False, 90 | "compile": { 91 | "deepcompile": deepcompile, 92 | }, 93 | } 94 | 95 | 96 | def get_optimizer_grouped_parameters( 97 | model, 98 | weight_decay, 99 | no_decay_name_list=["bias", "layer_norm.weight", "layernorm.weight", "norm.weight", "ln_f.weight"], 100 | ): 101 | optimizer_grouped_parameters = [ 102 | { 103 | "params": [ 104 | p 105 | for n, p in model.named_parameters() 106 | if (not any(nd in n for nd in no_decay_name_list) and p.requires_grad) 107 | ], 108 | "weight_decay": weight_decay, 109 | }, 110 | { 111 | "params": [ 112 | p 113 | for n, p in model.named_parameters() 114 | if (any(nd in n for nd in no_decay_name_list) and p.requires_grad) 115 | ], 116 | "weight_decay": 0.0, 117 | }, 118 | ] 119 | return optimizer_grouped_parameters 120 | 121 | 122 | def _z3_params_to_fetch(param_list): 123 | return [p for p in param_list if hasattr(p, "ds_id") and p.ds_status == ZeroParamStatus.NOT_AVAILABLE] 124 | 125 | 126 | def offload_deepspeed_states(model, pin_memory=True, non_blocking=True): 127 | zero_stage = model.zero_optimization_stage() # config['zero_optimization']['stage'] 128 | adam_offload = model.config["zero_optimization"]["offload_optimizer"]["device"] == "cpu" 129 | 130 | # state offloading not required when using Adam optimizer offloading 131 | if adam_offload: 132 | return 133 | 134 | if zero_stage != 3: 135 | raise NotImplementedError("Only Zero stage 3 is currently supported") 136 | 137 | # if zero_stage == 3 and not adam_offload: 138 | import deepspeed 139 | import torch 140 | from deepspeed.runtime.zero.offload_config import OffloadDeviceEnum, OffloadStateTypeEnum 141 | 142 | offload_state_types = [ 143 | OffloadStateTypeEnum.optim_states, 144 | OffloadStateTypeEnum.contiguous_grad_buffer, 145 | OffloadStateTypeEnum.hp_params, 146 | ] 147 | 148 | if deepspeed.__version__ >= "0.16.5": 149 | # These offload types are fixed in https://github.com/deepspeedai/DeepSpeed/pull/7050 150 | offload_state_types += [ 151 | OffloadStateTypeEnum.lp_grads, 152 | # OffloadStateTypeEnum.lp_params, 153 | ] 154 | 155 | model.optimizer.offload_states( 156 | include=offload_state_types, 157 | device=OffloadDeviceEnum.cpu, 158 | pin_memory=pin_memory, 159 | non_blocking=non_blocking, 160 | ) 161 | model.empty_partition_cache() 162 | torch.cuda.empty_cache() 163 | torch.distributed.barrier() 164 | torch.cuda.synchronize() 165 | 166 | 167 | def reload_deepspeed_states(model, non_blocking=True): 168 | zero_stage = model.zero_optimization_stage() # config['zero_optimization']['stage'] 169 | adam_offload = model.config["zero_optimization"]["offload_optimizer"]["device"] == "cpu" 170 | 171 | # state offloading not required when using Adam optimizer offloading 172 | if adam_offload: 173 | return 174 | 175 | if zero_stage != 3: 176 | raise NotImplementedError("Only Zero stage 3 
is currently supported") 177 | 178 | # if zero_stage == 3 and not adam_offload: 179 | import torch 180 | 181 | model.reload_states(non_blocking=non_blocking) 182 | torch.cuda.empty_cache() 183 | torch.distributed.barrier() 184 | torch.cuda.synchronize() 185 | -------------------------------------------------------------------------------- /openrlhf/utils/distributed_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Iterator, Optional, TypeVar 3 | 4 | import torch 5 | import torch.distributed as dist 6 | from torch.utils.data.dataset import Dataset 7 | from torch.utils.data.sampler import Sampler 8 | 9 | 10 | __all__ = ["DistributedSampler"] 11 | 12 | 13 | _T_co = TypeVar("_T_co", covariant=True) 14 | 15 | 16 | # Adapted from https://github.com/pytorch/pytorch/blob/5298acb5c76855bc5a99ae10016efc86b27949bd/torch/utils/data/distributed.py 17 | class DistributedSampler(Sampler[_T_co]): 18 | r"""Sampler that restricts data loading to a subset of the dataset. 19 | 20 | It is especially useful in conjunction with 21 | :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each 22 | process can pass a :class:`~torch.utils.data.DistributedSampler` instance as a 23 | :class:`~torch.utils.data.DataLoader` sampler, and load a subset of the 24 | original dataset that is exclusive to it. 25 | 26 | .. note:: 27 | Dataset is assumed to be of constant size and that any instance of it always 28 | returns the same elements in the same order. 29 | 30 | Args: 31 | dataset: Dataset used for sampling. 32 | num_replicas (int, optional): Number of processes participating in 33 | distributed training. By default, :attr:`world_size` is retrieved from the 34 | current distributed group. 35 | rank (int, optional): Rank of the current process within :attr:`num_replicas`. 36 | By default, :attr:`rank` is retrieved from the current distributed 37 | group. 38 | shuffle (bool, optional): If ``True`` (default), sampler will shuffle the 39 | indices. 40 | seed (int, optional): random seed used to shuffle the sampler if 41 | :attr:`shuffle=True`. This number should be identical across all 42 | processes in the distributed group. Default: ``0``. 43 | drop_last (bool, optional): if ``True``, then the sampler will drop the 44 | tail of the data to make it evenly divisible across the number of 45 | replicas. If ``False``, the sampler will add extra indices to make 46 | the data evenly divisible across the replicas. Default: ``False``. 47 | 48 | .. warning:: 49 | In distributed mode, calling the :meth:`set_epoch` method at 50 | the beginning of each epoch **before** creating the :class:`DataLoader` iterator 51 | is necessary to make shuffling work properly across multiple epochs. Otherwise, 52 | the same ordering will be always used. 53 | 54 | Example:: 55 | 56 | >>> # xdoctest: +SKIP 57 | >>> sampler = DistributedSampler(dataset) if is_distributed else None 58 | >>> loader = DataLoader(dataset, shuffle=(sampler is None), 59 | ... sampler=sampler) 60 | >>> for epoch in range(start_epoch, n_epochs): 61 | ... if is_distributed: 62 | ... sampler.set_epoch(epoch) 63 | ... 
train(loader) 64 | """ 65 | 66 | def __init__( 67 | self, 68 | dataset: Dataset, 69 | num_replicas: Optional[int] = None, 70 | rank: Optional[int] = None, 71 | shuffle: bool = True, 72 | seed: int = 0, 73 | drop_last: bool = False, 74 | consumed_samples=0, 75 | ) -> None: 76 | if num_replicas is None: 77 | if not dist.is_available(): 78 | raise RuntimeError("Requires distributed package to be available") 79 | num_replicas = dist.get_world_size() 80 | if rank is None: 81 | if not dist.is_available(): 82 | raise RuntimeError("Requires distributed package to be available") 83 | rank = dist.get_rank() 84 | if rank >= num_replicas or rank < 0: 85 | raise ValueError(f"Invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]") 86 | self.dataset = dataset 87 | self.num_replicas = num_replicas 88 | self.rank = rank 89 | self.epoch = 0 90 | self.drop_last = drop_last 91 | # If the dataset length is evenly divisible by # of replicas, then there 92 | # is no need to drop any data, since the dataset will be split equally. 93 | if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type] 94 | # Split to nearest available length that is evenly divisible. 95 | # This is to ensure each rank receives the same amount of data when 96 | # using this Sampler. 97 | self.num_samples = math.ceil( 98 | (len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type] 99 | ) 100 | else: 101 | self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type] 102 | self.total_size = self.num_samples * self.num_replicas 103 | self.shuffle = shuffle 104 | self.seed = seed 105 | self.consumed_indicies = consumed_samples // self.num_replicas 106 | 107 | def __iter__(self) -> Iterator[_T_co]: 108 | if self.shuffle: 109 | # deterministically shuffle based on epoch and seed 110 | g = torch.Generator() 111 | g.manual_seed(self.seed + self.epoch) 112 | indices = torch.randperm(len(self.dataset), generator=g).tolist() # type: ignore[arg-type] 113 | else: 114 | indices = list(range(len(self.dataset))) # type: ignore[arg-type] 115 | 116 | if not self.drop_last: 117 | # add extra samples to make it evenly divisible 118 | padding_size = self.total_size - len(indices) 119 | if padding_size <= len(indices): 120 | indices += indices[:padding_size] 121 | else: 122 | indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] 123 | else: 124 | # remove tail of data to make it evenly divisible. 125 | indices = indices[: self.total_size] 126 | assert len(indices) == self.total_size 127 | 128 | # subsample 129 | indices = indices[self.rank : self.total_size : self.num_replicas] 130 | # skip consumed_samples 131 | indices = indices[self.consumed_indicies :] 132 | assert len(indices) == self.num_samples - self.consumed_indicies 133 | 134 | return iter(indices) 135 | 136 | def __len__(self) -> int: 137 | return self.num_samples - self.consumed_indicies 138 | 139 | def set_epoch(self, epoch: int, consumed_samples=0) -> None: 140 | r""" 141 | Set the epoch for this sampler. 142 | 143 | When :attr:`shuffle=True`, this ensures all replicas 144 | use a different random ordering for each epoch. Otherwise, the next iteration of this 145 | sampler will yield the same ordering. 146 | 147 | Args: 148 | epoch (int): Epoch number. 
149 | """ 150 | self.epoch = epoch 151 | self.consumed_indicies = consumed_samples // self.num_replicas 152 | -------------------------------------------------------------------------------- /openrlhf/utils/distributed_util.py: -------------------------------------------------------------------------------- 1 | from datetime import timedelta 2 | from typing import Any, Optional, Union 3 | 4 | import torch 5 | import torch.distributed 6 | from torch.distributed.distributed_c10d import ( 7 | Backend, 8 | PrefixStore, 9 | Store, 10 | _new_process_group_helper, 11 | _world, 12 | default_pg_timeout, 13 | rendezvous, 14 | ) 15 | 16 | 17 | def torch_dist_barrier_and_cuda_sync(): 18 | """Synchronize distributed training and CUDA operations. 19 | This function ensures that: 20 | 1. All distributed processes reach this point (barrier) 21 | 2. All CUDA operations are completed (synchronize) 22 | """ 23 | torch.distributed.barrier() 24 | torch.cuda.synchronize() 25 | 26 | 27 | # Copy from pytorch to allow creating multiple main groups. 28 | # https://github.com/pytorch/pytorch/blob/main/torch/distributed/distributed_c10d.py 29 | def init_process_group( 30 | backend: Union[str, Backend] = None, 31 | init_method: Optional[str] = None, 32 | timeout: Optional[timedelta] = None, 33 | world_size: int = -1, 34 | rank: int = -1, 35 | store: Optional[Store] = None, 36 | group_name: str = None, 37 | pg_options: Optional[Any] = None, 38 | ): 39 | assert (store is None) or (init_method is None), "Cannot specify both init_method and store." 40 | 41 | if store is not None: 42 | assert world_size > 0, "world_size must be positive if using store" 43 | assert rank >= 0, "rank must be non-negative if using store" 44 | elif init_method is None: 45 | init_method = "env://" 46 | 47 | if backend: 48 | backend = Backend(backend) 49 | else: 50 | backend = Backend("undefined") 51 | 52 | if timeout is None: 53 | timeout = default_pg_timeout 54 | 55 | # backward compatible API 56 | if store is None: 57 | rendezvous_iterator = rendezvous(init_method, rank, world_size, timeout=timeout) 58 | store, rank, world_size = next(rendezvous_iterator) 59 | store.set_timeout(timeout) 60 | 61 | # Use a PrefixStore to avoid accidental overrides of keys used by 62 | # different systems (e.g. RPC) in case the store is multi-tenant. 
63 | store = PrefixStore(group_name, store) 64 | 65 | # NOTE: The pg_options parameter was renamed into backend_options in PyTorch 2.6.0 66 | # https://github.com/pytorch/pytorch/commit/a0c7029a75628cd5fa8df83c0de0ea98ee7fd844 67 | # We need to determine the appropriate parameter name based on PyTorch version 68 | pg_options_param_name = "backend_options" if str(torch.__version__) >= "2.6" else "pg_options" 69 | pg, _ = _new_process_group_helper( 70 | world_size, 71 | rank, 72 | [], 73 | backend, 74 | store, 75 | group_name=group_name, 76 | **{pg_options_param_name: pg_options}, 77 | timeout=timeout, 78 | ) 79 | 80 | _world.pg_group_ranks[pg] = {i: i for i in range(world_size)} 81 | 82 | return pg 83 | -------------------------------------------------------------------------------- /openrlhf/utils/logging_utils.py: -------------------------------------------------------------------------------- 1 | # Adapted from 2 | # https://github.com/skypilot-org/skypilot/blob/86dc0f6283a335e4aa37b3c10716f90999f48ab6/sky/sky_logging.py 3 | """Logging configuration for vLLM.""" 4 | import logging 5 | import sys 6 | 7 | _FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s" 8 | _DATE_FORMAT = "%m-%d %H:%M:%S" 9 | 10 | 11 | class NewLineFormatter(logging.Formatter): 12 | """Adds logging prefix to newlines to align multi-line messages.""" 13 | 14 | def __init__(self, fmt, datefmt=None): 15 | logging.Formatter.__init__(self, fmt, datefmt) 16 | 17 | def format(self, record): 18 | msg = logging.Formatter.format(self, record) 19 | if record.message != "": 20 | parts = msg.split(record.message) 21 | msg = msg.replace("\n", "\r\n" + parts[0]) 22 | return msg 23 | 24 | 25 | _root_logger = logging.getLogger("openrlhf") 26 | _default_handler = None 27 | 28 | 29 | def _setup_logger(): 30 | _root_logger.setLevel(logging.DEBUG) 31 | global _default_handler 32 | if _default_handler is None: 33 | _default_handler = logging.StreamHandler(sys.stdout) 34 | _default_handler.flush = sys.stdout.flush # type: ignore 35 | _default_handler.setLevel(logging.INFO) 36 | _root_logger.addHandler(_default_handler) 37 | fmt = NewLineFormatter(_FORMAT, datefmt=_DATE_FORMAT) 38 | _default_handler.setFormatter(fmt) 39 | # Setting this will avoid the message 40 | # being propagated to the parent logger. 41 | _root_logger.propagate = False 42 | 43 | 44 | # The logger is initialized when the module is imported. 45 | # This is thread-safe as the module is only imported once, 46 | # guaranteed by the Python GIL. 
47 | _setup_logger() 48 | 49 | 50 | def init_logger(name: str): 51 | # Use the same settings as above for root logger 52 | logger = logging.getLogger(name) 53 | logger.setLevel(logging.DEBUG) 54 | logger.addHandler(_default_handler) 55 | logger.propagate = False 56 | return logger 57 | -------------------------------------------------------------------------------- /openrlhf/utils/processor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from tqdm import tqdm 3 | 4 | 5 | def reward_normalization(objs): 6 | rewards = [float(obj["reward"]) for obj in objs] 7 | rewards = torch.tensor(rewards, dtype=torch.float64) 8 | rewards = (rewards - rewards.mean()) / rewards.std() 9 | for i, obj in enumerate(objs): 10 | obj["reward"] = rewards[i].item() 11 | 12 | 13 | # Conditional SFT 14 | # See https://arxiv.org/abs/2308.12050 15 | DEFAULT_REWARD_PROMPT = "{input} : {reward} " 16 | 17 | 18 | def conditional_sft_processor(args, objs): 19 | if "reward_template" not in args or args.reward_template is None: 20 | reward_template = DEFAULT_REWARD_PROMPT 21 | else: 22 | reward_template = args.reward_template 23 | assert "{input}" in reward_template 24 | assert "{reward}" in reward_template 25 | 26 | if args.normalize_reward: 27 | reward_normalization(objs) 28 | 29 | for obj in tqdm(objs, desc="Conditional SFT process..."): 30 | input = obj["input"] 31 | reward = "{:.2f}".format(float(obj["reward"])) 32 | input = reward_template.replace("{reward}", reward).replace("{input}", input) 33 | obj["input"] = input 34 | 35 | return objs 36 | 37 | 38 | # Rejection Sampling 39 | # See https://arxiv.org/abs/2307.09288 40 | def rejection_sampling_processor(args, objs): 41 | out = {} 42 | for obj in tqdm(objs, desc="Rejection Sampling process...."): 43 | input = obj["input"] 44 | output = obj["output"] 45 | reward = float(obj["reward"]) 46 | 47 | if input not in out: 48 | out[input] = {"output": output, "reward": reward} 49 | elif reward > out[input]["reward"]: 50 | out[input]["reward"] = reward 51 | out[input]["output"] = output 52 | 53 | return [{"input": k, "output": v["output"], "reward": v["reward"]} for k, v in out.items()] 54 | 55 | 56 | # Iterative DPO 57 | # See https://github.com/RLHFlow/Online-RLHF/blob/main/run_loop.sh 58 | def iterative_dpo_processor(args, objs): 59 | out = {} 60 | for obj in tqdm(objs, desc="Iterative DPO process...."): 61 | input = obj["input"] 62 | output = obj["output"] 63 | reward = float(obj["reward"]) 64 | 65 | if input not in out: 66 | out[input] = { 67 | "output": output, 68 | "chosen": output, 69 | "chosen_reward": reward, 70 | "rejected": output, 71 | "rejected_reward": reward, 72 | } 73 | elif reward > out[input]["chosen_reward"]: 74 | out[input]["chosen_reward"] = reward 75 | out[input]["chosen"] = output 76 | elif reward < out[input]["rejected_reward"]: 77 | out[input]["rejected_reward"] = reward 78 | out[input]["rejected"] = output 79 | 80 | return [ 81 | { 82 | "prompt": k, 83 | "chosen": v["chosen"], 84 | "chosen_reward": v["chosen_reward"], 85 | "rejected": v["rejected"], 86 | "rejected_reward": v["rejected_reward"], 87 | } 88 | for k, v in out.items() 89 | ] 90 | 91 | 92 | PROCESSORS = { 93 | "rs": rejection_sampling_processor, 94 | "csft": conditional_sft_processor, 95 | "iter_dpo": iterative_dpo_processor, 96 | } 97 | 98 | 99 | def get_processor(name): 100 | if name in PROCESSORS: 101 | return PROCESSORS[name] 102 | else: 103 | raise ValueError(f"Processor {name} does not exist.") 104 | 
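# ---------------------------------------------------------------------------
# Editor's note: illustrative sketch only, not part of processor.py. It shows the
# "rs" (rejection sampling) processor keeping the highest-reward output per prompt;
# the sample records below are fabricated.
samples = [
    {"input": "1+1=?", "output": "3", "reward": -1.0},
    {"input": "1+1=?", "output": "2", "reward": 1.0},
    {"input": "2+2=?", "output": "4", "reward": 1.0},
]
best = get_processor("rs")(args=None, objs=samples)  # this processor ignores args
# best == [{"input": "1+1=?", "output": "2", "reward": 1.0},
#          {"input": "2+2=?", "output": "4", "reward": 1.0}]
# ---------------------------------------------------------------------------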
-------------------------------------------------------------------------------- /openrlhf/utils/remote_rm_utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import ray 3 | import requests 4 | import torch 5 | 6 | from openrlhf.utils.logging_utils import init_logger 7 | 8 | logger = init_logger(__name__) 9 | 10 | 11 | def request_api_wrapper(url, data, score_key="rewards", try_max_times=5): 12 | """Synchronous request API wrapper""" 13 | headers = { 14 | "Content-Type": "application/json", 15 | } 16 | for _ in range(try_max_times): 17 | try: 18 | response = requests.post(url=url, json=data, headers=headers, timeout=180) 19 | response.raise_for_status() # Raise an HTTPError for bad responses 20 | response = response.json() 21 | assert score_key in response, f"{score_key} not in {response}" 22 | return response.get(score_key) 23 | except requests.RequestException as e: 24 | logger.info(f"Request error, please check: {e}") 25 | except Exception as e: 26 | logger.info(f"Unexpected error, please check: {e}") 27 | time.sleep(1) 28 | 29 | raise Exception(f"Request error for {try_max_times} times, returning None. Please check the API server.") 30 | 31 | 32 | def remote_rm_fn(api_url, queries, prompts, labels, score_key="rewards"): 33 | """remote reward model API 34 | api_url: RM API, We assume that the API supports two modes: merging query + response and not merging 35 | queries: query+response with the template 36 | design is made optional. 37 | score_key: RM score key 38 | """ 39 | scores = request_api_wrapper(api_url, {"query": queries, "prompts": prompts, "labels": labels}, score_key) 40 | return torch.tensor(scores) 41 | 42 | 43 | @ray.remote 44 | def remote_rm_fn_ray(api_url, queries, prompts, labels, score_key="rewards"): 45 | return remote_rm_fn(api_url, queries, prompts, labels, score_key) 46 | 47 | 48 | if __name__ == "__main__": 49 | # test utils 50 | url = "http:xxx/get_rm_score" 51 | score = remote_rm_fn(url, ["example query"], ["example response"]) 52 | print(score) -------------------------------------------------------------------------------- /openrlhf/utils/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from datasets import interleave_datasets, load_dataset, load_from_disk 4 | from transformers import AutoTokenizer 5 | 6 | def get_tokenizer(pretrain, model, padding_side="left", strategy=None, use_fast=True): 7 | tokenizer = AutoTokenizer.from_pretrained(pretrain, trust_remote_code=True, use_fast=use_fast) 8 | tokenizer.padding_side = padding_side 9 | # NOTE: When enable vLLM, do not resize_token_embeddings, or the vocab size will mismatch with vLLM. 
10 | # https://github.com/facebookresearch/llama-recipes/pull/196 11 | if tokenizer.pad_token is None: 12 | tokenizer.pad_token = tokenizer.eos_token 13 | tokenizer.pad_token_id = tokenizer.eos_token_id 14 | if model is not None: 15 | model.config.pad_token_id = tokenizer.pad_token_id 16 | 17 | return tokenizer 18 | 19 | 20 | def get_strategy(args): 21 | from openrlhf.utils.deepspeed import DeepspeedStrategy 22 | 23 | strategy = DeepspeedStrategy( 24 | seed=getattr(args, "seed", 42), 25 | full_determinism=getattr(args, "full_determinism", False), 26 | max_norm=getattr(args, "max_norm", 1.0), 27 | micro_train_batch_size=getattr(args, "micro_train_batch_size", 1), 28 | train_batch_size=getattr(args, "train_batch_size", 128), 29 | zero_stage=args.zero_stage, 30 | bf16=getattr(args, "bf16", True), 31 | args=args, 32 | ) 33 | return strategy 34 | 35 | 36 | def blending_datasets( 37 | datasets, 38 | probabilities=None, 39 | strategy=None, 40 | seed=42, 41 | max_count=1e8, 42 | stopping_strategy="all_exhausted", 43 | dataset_split="train", 44 | ): 45 | """Blend multiple datasets with optional probability sampling. 46 | 47 | Args: 48 | datasets (str): Comma-separated list of dataset paths 49 | probabilities (str, optional): Comma-separated list of probabilities for sampling. 50 | If None, datasets will be concatenated without probability sampling. 51 | strategy: Training strategy object 52 | seed (int): Random seed 53 | max_count (int): Maximum number of samples per dataset 54 | """ 55 | datasets = datasets.split(",") 56 | if probabilities is not None: 57 | probabilities = list(map(float, probabilities.split(","))) 58 | assert len(probabilities) == len(datasets) 59 | 60 | data_list = [] 61 | for i, dataset in enumerate(datasets): 62 | dataset = dataset.strip() 63 | strategy.print(f"dataset: {dataset}") 64 | 65 | data_dir = dataset.split("@")[1].strip() if "@" in dataset else None 66 | dataset = dataset.split("@")[0].strip() 67 | dataset_basename = os.path.basename(dataset) 68 | 69 | ext = os.path.splitext(dataset)[-1] 70 | # local python script 71 | if ext == ".py" or ( 72 | os.path.isdir(dataset) and os.path.exists(os.path.join(dataset, f"{dataset_basename}.py")) 73 | ): 74 | data = load_dataset(dataset, trust_remote_code=True) 75 | strategy.print(f"loaded {dataset} with python script") 76 | # local text file 77 | elif ext in [".json", ".jsonl", ".csv", ".parquet", ".arrow"]: 78 | ext = ext.lower().strip(".") 79 | if ext == "jsonl": 80 | ext = "json" 81 | data = load_dataset(ext, data_files=dataset) 82 | strategy.print(f"loaded {dataset} with data_files={dataset}") 83 | # local dataset saved with `datasets.Dataset.save_to_disk` 84 | elif os.path.isdir(dataset): 85 | try: 86 | data = load_from_disk(dataset) 87 | strategy.print(f"loaded {dataset} from disk") 88 | except Exception as e: 89 | strategy.print(f"failed to load {dataset} from disk: {e}") 90 | data = load_dataset(dataset, data_dir=data_dir) 91 | strategy.print(f"loaded {dataset} from files") 92 | # remote/local folder or common file 93 | elif strategy.args.use_ms: 94 | from modelscope.msdatasets import MsDataset 95 | 96 | namespace, dataset = dataset.split("/") 97 | data = MsDataset.load(dataset, namespace=namespace) 98 | else: 99 | data = load_dataset(dataset, data_dir=data_dir) 100 | strategy.print(f"loaded {dataset} from files") 101 | 102 | # Select dataset 103 | if dataset_split and dataset_split in data: 104 | data = data[dataset_split] 105 | data = data.select(range(min(max_count, len(data)))) 106 | data_list.append(data) 107 | 108 
| # merge datasets 109 | if strategy.is_rank_0(): 110 | print(data_list) 111 | 112 | # If probabilities is None, concatenate datasets directly 113 | if probabilities is None: 114 | from datasets import concatenate_datasets 115 | 116 | dataset = concatenate_datasets(data_list) 117 | else: 118 | dataset = interleave_datasets( 119 | data_list, 120 | probabilities=probabilities, 121 | seed=seed, 122 | stopping_strategy=stopping_strategy, 123 | ) 124 | 125 | return dataset 126 | 127 | 128 | def convert_token_to_id(token, tokenizer): 129 | if isinstance(token, str): 130 | token = tokenizer.encode(token, add_special_tokens=False) 131 | assert len(token) == 1 132 | return token[0] 133 | else: 134 | raise ValueError("token should be int or str") -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = [ 3 | "packaging", 4 | "setuptools >= 49.4.0", 5 | "wheel", 6 | ] 7 | build-backend = "setuptools.build_meta" 8 | 9 | [tool.isort] 10 | profile = "black" # black-compatible 11 | line_length = 119 # should match black parameters 12 | ignore_whitespace = true # ignore whitespace for compatibility with the initial style 13 | py_version = 310 # python 3.10 as a target version 14 | sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"] 15 | default_section = "THIRDPARTY" 16 | extend_skip = ["setup.py", "docs/source/conf.py"] 17 | 18 | 19 | [tool.black] 20 | line_length = 119 21 | 22 | [tool.ruff] 23 | line-length = 119 24 | 25 | [tool.pytest.ini_options] 26 | # durations=0 will display all tests execution time, sorted in ascending order starting from from the slowest one. 27 | # -vv will also display tests with durration = 0.00s 28 | addopts = "--verbose --pyargs --durations=0 --strict-markers" # always add these arguments to pytest 29 | testpaths = ["./tests"] # must be an explicit path to avoid importing another "tests" module 30 | # directories to ignore when discovering tests 31 | norecursedirs = [ 32 | "external", 33 | "examples", 34 | "docs", 35 | "scripts", 36 | "tools", 37 | "tutorials", 38 | "*.egg", 39 | ".*", 40 | "_darcs", 41 | "build", 42 | "CVS", 43 | "dist", 44 | "venv", 45 | "{arch}", 46 | ] 47 | # markers to select tests, use `pytest --markers` to see all available markers, `pytest -m ""` to select tests 48 | markers = [ 49 | "unit: marks unit test, i.e. 
testing a single, well isolated functionality (deselect with '-m \"not unit\"')", 50 | "integration: marks test checking the elements when integrated into subsystems (deselect with '-m \"not integration\"')", 51 | "system: marks test working at the highest integration level (deselect with '-m \"not system\"')", 52 | "acceptance: marks test checking whether the developed product/model passes the user defined acceptance criteria (deselect with '-m \"not acceptance\"')", 53 | "docs: mark tests related to documentation (deselect with '-m \"not docs\"')", 54 | "skipduringci: marks tests that are skipped ci as they are addressed by Jenkins jobs but should be run to test user setups", 55 | "pleasefixme: marks tests that are broken and need fixing", 56 | ] 57 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | bitsandbytes 3 | datasets 4 | deepspeed==0.16.7 5 | einops 6 | flask 7 | flash-attn==2.7.4.post1 8 | isort 9 | jsonlines 10 | loralib 11 | math-verify==0.5.2 12 | levenshtein 13 | optimum 14 | optree>=0.13.0 15 | packaging 16 | peft 17 | pynvml>=12.0.0 18 | qwen_vl_utils 19 | ray[default]==2.43.0 20 | tensorboard 21 | torch 22 | torchvision 23 | torchmetrics 24 | tqdm 25 | transformers==4.51.3 26 | transformers_stream_generator 27 | wandb 28 | wheel 29 | loguru -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import platform 4 | 5 | from datetime import datetime 6 | from setuptools import find_packages, setup 7 | from wheel.bdist_wheel import bdist_wheel as _bdist_wheel 8 | 9 | _build_mode = os.getenv("OPENRLHF_BUILD_MODE", "") 10 | 11 | 12 | def _is_nightly(): 13 | return _build_mode.lower() == "nightly" 14 | 15 | 16 | def _fetch_requirements(path): 17 | with open(path, "r") as fd: 18 | return [r.strip() for r in fd.readlines()] 19 | 20 | 21 | def _fetch_readme(): 22 | with open("README.md", encoding="utf-8") as f: 23 | return f.read() 24 | 25 | 26 | def _fetch_version(): 27 | with open("version.txt", "r") as f: 28 | version = f.read().strip() 29 | 30 | if _is_nightly(): 31 | now = datetime.now() 32 | date_str = now.strftime("%Y%m%d") 33 | version += f".dev{date_str}" 34 | 35 | return version 36 | 37 | 38 | def _fetch_package_name(): 39 | return "openrlhf-nightly" if _is_nightly() else "openrlhf" 40 | 41 | 42 | # Custom wheel class to modify the wheel name 43 | class bdist_wheel(_bdist_wheel): 44 | def finalize_options(self): 45 | _bdist_wheel.finalize_options(self) 46 | self.root_is_pure = False 47 | 48 | def get_tag(self): 49 | python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" 50 | abi_tag = f"{python_version}" 51 | 52 | if platform.system() == "Linux": 53 | platform_tag = "manylinux1_x86_64" 54 | else: 55 | platform_tag = platform.system().lower() 56 | 57 | return python_version, abi_tag, platform_tag 58 | 59 | 60 | # Setup configuration 61 | setup( 62 | author="OpenRLHF Team", 63 | name=_fetch_package_name(), 64 | version=_fetch_version(), 65 | packages=find_packages( 66 | exclude=( 67 | "data", 68 | "docs", 69 | "examples", 70 | ) 71 | ), 72 | description="A Ray-based High-performance RLHF framework.", 73 | long_description=_fetch_readme(), 74 | long_description_content_type="text/markdown", 75 | install_requires=_fetch_requirements("requirements.txt"), 76 | 
extras_require={ 77 | "vllm": ["vllm==0.8.3"], 78 | "vllm_latest": ["vllm>0.8.3"], 79 | "ring": ["ring_flash_attn"], 80 | "liger": ["liger_kernel"], 81 | }, 82 | python_requires=">=3.10", 83 | classifiers=[ 84 | "Programming Language :: Python :: 3.10", 85 | "Programming Language :: Python :: 3.11", 86 | "Environment :: GPU :: NVIDIA CUDA", 87 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 88 | "Topic :: System :: Distributed Computing", 89 | ], 90 | cmdclass={"bdist_wheel": bdist_wheel}, 91 | ) 92 | -------------------------------------------------------------------------------- /version.txt: -------------------------------------------------------------------------------- 1 | 0.7.3a 2 | --------------------------------------------------------------------------------