├── src ├── readme.md ├── open_r1 │ ├── readme.md │ ├── utils │ │ ├── __pycache__ │ │ │ ├── hub.cpython-311.pyc │ │ │ ├── __init__.cpython-311.pyc │ │ │ ├── callbacks.cpython-311.pyc │ │ │ ├── evaluation.cpython-311.pyc │ │ │ ├── import_utils.cpython-311.pyc │ │ │ ├── model_utils.cpython-311.pyc │ │ │ └── wandb_logging.cpython-311.pyc │ │ ├── __init__.py │ │ ├── wandb_logging.py │ │ ├── import_utils.py │ │ ├── model_utils.py │ │ ├── callbacks.py │ │ ├── evaluation.py │ │ └── hub.py │ ├── trl │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-311.pyc │ │ │ ├── data_utils.cpython-311.pyc │ │ │ ├── import_utils.cpython-311.pyc │ │ │ └── mergekit_utils.cpython-311.pyc │ │ ├── models │ │ │ ├── __pycache__ │ │ │ │ ├── utils.cpython-311.pyc │ │ │ │ ├── __init__.cpython-311.pyc │ │ │ │ ├── modeling_base.cpython-311.pyc │ │ │ │ └── modeling_value_head.cpython-311.pyc │ │ │ ├── __init__.py │ │ │ ├── auxiliary_modules.py │ │ │ └── sd_utils.py │ │ ├── scripts │ │ │ ├── __pycache__ │ │ │ │ ├── utils.cpython-311.pyc │ │ │ │ └── __init__.cpython-311.pyc │ │ │ ├── __init__.py │ │ │ ├── env.py │ │ │ ├── grpo.py │ │ │ ├── sft.py │ │ │ ├── kto.py │ │ │ └── dpo.py │ │ ├── trainer │ │ │ ├── __pycache__ │ │ │ │ ├── judges.cpython-311.pyc │ │ │ │ ├── utils.cpython-311.pyc │ │ │ │ ├── __init__.cpython-311.pyc │ │ │ │ ├── callbacks.cpython-311.pyc │ │ │ │ ├── grpo_config.cpython-311.pyc │ │ │ │ ├── sft_config.cpython-311.pyc │ │ │ │ ├── grpo_trainer.cpython-311.pyc │ │ │ │ ├── model_config.cpython-311.pyc │ │ │ │ └── drgrpo_trainer.cpython-311.pyc │ │ │ ├── xpo_config.py │ │ │ ├── nash_md_config.py │ │ │ ├── reward_config.py │ │ │ ├── prm_config.py │ │ │ ├── rloo_config.py │ │ │ ├── gkd_config.py │ │ │ ├── ppo_config.py │ │ │ ├── __init__.py │ │ │ ├── orpo_config.py │ │ │ ├── sft_config.py │ │ │ ├── model_config.py │ │ │ ├── online_dpo_config.py │ │ │ ├── cpo_config.py │ │ │ └── bco_config.py │ │ ├── extras │ │ │ ├── __init__.py │ │ │ ├── dataset_formatting.py │ │ │ └── best_of_n_sampler.py │ │ ├── environment │ │ │ └── __init__.py │ │ ├── templates │ │ │ └── lm_model_card.md │ │ ├── cli.py │ │ ├── import_utils.py │ │ ├── core.py │ │ └── __init__.py │ ├── __init__.py │ ├── configs.py │ ├── generate.py │ ├── sft.py │ └── evaluate.py └── open_r1.egg-info │ ├── not-zip-safe │ ├── dependency_links.txt │ ├── top_level.txt │ ├── SOURCES.txt │ └── requires.txt ├── Figs └── readme.md ├── recipes ├── data_cleaner.yaml ├── accelerate_configs │ ├── ddp.yaml │ ├── zero2.yaml │ ├── zero3.yaml │ └── fsdp.yaml ├── dra_grpo.yaml └── dra_dr_grpo.yaml ├── LICENSE └── README.md /src/readme.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /Figs/readme.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/open_r1/readme.md: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/open_r1.egg-info/not-zip-safe: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/open_r1.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | 
-------------------------------------------------------------------------------- /src/open_r1.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | open_r1 2 | -------------------------------------------------------------------------------- /src/open_r1/utils/__pycache__/hub.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/utils/__pycache__/hub.cpython-311.pyc -------------------------------------------------------------------------------- /src/open_r1/trl/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /src/open_r1/trl/__pycache__/data_utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/__pycache__/data_utils.cpython-311.pyc -------------------------------------------------------------------------------- /src/open_r1/utils/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/utils/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /src/open_r1/utils/__pycache__/callbacks.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/utils/__pycache__/callbacks.cpython-311.pyc -------------------------------------------------------------------------------- /src/open_r1/trl/__pycache__/import_utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/__pycache__/import_utils.cpython-311.pyc -------------------------------------------------------------------------------- /src/open_r1/trl/__pycache__/mergekit_utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/__pycache__/mergekit_utils.cpython-311.pyc -------------------------------------------------------------------------------- /src/open_r1/trl/models/__pycache__/utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/models/__pycache__/utils.cpython-311.pyc -------------------------------------------------------------------------------- /src/open_r1/trl/scripts/__pycache__/utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/scripts/__pycache__/utils.cpython-311.pyc -------------------------------------------------------------------------------- /src/open_r1/trl/trainer/__pycache__/judges.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/trainer/__pycache__/judges.cpython-311.pyc 
-------------------------------------------------------------------------------- /src/open_r1/trl/trainer/__pycache__/utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/trainer/__pycache__/utils.cpython-311.pyc -------------------------------------------------------------------------------- /src/open_r1/utils/__pycache__/evaluation.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/utils/__pycache__/evaluation.cpython-311.pyc -------------------------------------------------------------------------------- /src/open_r1/utils/__pycache__/import_utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/utils/__pycache__/import_utils.cpython-311.pyc -------------------------------------------------------------------------------- /src/open_r1/utils/__pycache__/model_utils.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/utils/__pycache__/model_utils.cpython-311.pyc -------------------------------------------------------------------------------- /src/open_r1/trl/models/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/models/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /src/open_r1/trl/scripts/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/scripts/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /src/open_r1/trl/trainer/__pycache__/__init__.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/trainer/__pycache__/__init__.cpython-311.pyc -------------------------------------------------------------------------------- /src/open_r1/utils/__pycache__/wandb_logging.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/utils/__pycache__/wandb_logging.cpython-311.pyc -------------------------------------------------------------------------------- /src/open_r1/trl/trainer/__pycache__/callbacks.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/trainer/__pycache__/callbacks.cpython-311.pyc -------------------------------------------------------------------------------- /src/open_r1/trl/trainer/__pycache__/grpo_config.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/trainer/__pycache__/grpo_config.cpython-311.pyc -------------------------------------------------------------------------------- /src/open_r1/trl/trainer/__pycache__/sft_config.cpython-311.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/trainer/__pycache__/sft_config.cpython-311.pyc -------------------------------------------------------------------------------- /src/open_r1/trl/models/__pycache__/modeling_base.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/models/__pycache__/modeling_base.cpython-311.pyc -------------------------------------------------------------------------------- /src/open_r1/trl/trainer/__pycache__/grpo_trainer.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/trainer/__pycache__/grpo_trainer.cpython-311.pyc -------------------------------------------------------------------------------- /src/open_r1/trl/trainer/__pycache__/model_config.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/trainer/__pycache__/model_config.cpython-311.pyc -------------------------------------------------------------------------------- /src/open_r1/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .import_utils import is_e2b_available 2 | from .model_utils import get_tokenizer 3 | 4 | 5 | __all__ = ["get_tokenizer", "is_e2b_available"] 6 | -------------------------------------------------------------------------------- /src/open_r1/trl/trainer/__pycache__/drgrpo_trainer.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/trainer/__pycache__/drgrpo_trainer.cpython-311.pyc -------------------------------------------------------------------------------- /src/open_r1/trl/models/__pycache__/modeling_value_head.cpython-311.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/models/__pycache__/modeling_value_head.cpython-311.pyc -------------------------------------------------------------------------------- /recipes/data_cleaner.yaml: -------------------------------------------------------------------------------- 1 | model_kwargs: 2 | model: Qwen/Qwen2.5-Math-7B-Instruct 3 | trust_remote_code: true 4 | max_model_len: 4096 5 | gpu_memory_utilization: 0.9 6 | enforce_eager: true 7 | tensor_parallel_size: 4 8 | 9 | sampling_params: 10 | temperature: 0.7 11 | top_p: 0.9 12 | max_tokens: 4096 13 | -------------------------------------------------------------------------------- /recipes/accelerate_configs/ddp.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: 'no' 5 | gpu_ids: all 6 | machine_rank: 0 7 | main_training_function: main 8 | mixed_precision: bf16 9 | num_machines: 1 10 | num_processes: 8 11 | rdzv_backend: static 12 | same_network: true 13 | tpu_env: [] 14 | tpu_use_cluster: false 15 | tpu_use_sudo: false 16 | use_cpu: false 17 | -------------------------------------------------------------------------------- /src/open_r1/utils/wandb_logging.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | 3 | 4 | def init_wandb_training(training_args): 5 | """ 6 | Helper function for setting up Weights & Biases logging tools. 7 | """ 8 | if training_args.wandb_entity is not None: 9 | os.environ["WANDB_ENTITY"] = training_args.wandb_entity 10 | if training_args.wandb_project is not None: 11 | os.environ["WANDB_PROJECT"] = training_args.wandb_project 12 | -------------------------------------------------------------------------------- /recipes/accelerate_configs/zero2.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_multinode_launcher: standard 5 | offload_optimizer_device: none 6 | offload_param_device: none 7 | zero3_init_flag: false 8 | zero_stage: 2 9 | distributed_type: DEEPSPEED 10 | downcast_bf16: 'no' 11 | machine_rank: 0 12 | main_training_function: main 13 | mixed_precision: bf16 14 | num_machines: 1 15 | num_processes: 8 16 | rdzv_backend: static 17 | same_network: true 18 | tpu_env: [] 19 | tpu_use_cluster: false 20 | tpu_use_sudo: false 21 | use_cpu: false -------------------------------------------------------------------------------- /recipes/accelerate_configs/zero3.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_multinode_launcher: standard 5 | offload_optimizer_device: none 6 | offload_param_device: none 7 | zero3_init_flag: true 8 | zero3_save_16bit_model: true 9 | zero_stage: 3 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: 'no' 12 | machine_rank: 0 13 | main_training_function: main 14 | mixed_precision: bf16 15 | num_machines: 1 16 | num_processes: 8 17 | rdzv_backend: static 18 | same_network: true 19 | tpu_env: [] 20 | tpu_use_cluster: false 21 | tpu_use_sudo: false 22 | use_cpu: false 23 | -------------------------------------------------------------------------------- /src/open_r1/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /src/open_r1.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | LICENSE 2 | README.md 3 | setup.cfg 4 | setup.py 5 | src/open_r1/__init__.py 6 | src/open_r1/configs.py 7 | src/open_r1/evaluate.py 8 | src/open_r1/generate.py 9 | src/open_r1/grpo.py 10 | src/open_r1/rewards.py 11 | src/open_r1/sft.py 12 | src/open_r1.egg-info/PKG-INFO 13 | src/open_r1.egg-info/SOURCES.txt 14 | src/open_r1.egg-info/dependency_links.txt 15 | src/open_r1.egg-info/not-zip-safe 16 | src/open_r1.egg-info/requires.txt 17 | src/open_r1.egg-info/top_level.txt 18 | src/open_r1/utils/__init__.py 19 | src/open_r1/utils/callbacks.py 20 | src/open_r1/utils/evaluation.py 21 | src/open_r1/utils/hub.py 22 | src/open_r1/utils/import_utils.py 23 | src/open_r1/utils/model_utils.py 24 | src/open_r1/utils/wandb_logging.py -------------------------------------------------------------------------------- /recipes/accelerate_configs/fsdp.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: FSDP 4 | downcast_bf16: 'no' 5 | enable_cpu_affinity: false 6 | fsdp_config: 7 | fsdp_activation_checkpointing: false # Need fix from: https://github.com/huggingface/transformers/pull/36610 8 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP 9 | fsdp_backward_prefetch: BACKWARD_PRE 10 | fsdp_cpu_ram_efficient_loading: true 11 | fsdp_forward_prefetch: true 12 | fsdp_offload_params: false 13 | fsdp_sharding_strategy: FULL_SHARD 14 | fsdp_state_dict_type: FULL_STATE_DICT 15 | fsdp_sync_module_states: true 16 | fsdp_use_orig_params: true 17 | machine_rank: 0 18 | main_training_function: main 19 | mixed_precision: bf16 20 | num_machines: 1 21 | num_processes: 8 22 | rdzv_backend: static 23 | same_network: true 24 | tpu_env: [] 25 | tpu_use_cluster: false 26 | tpu_use_sudo: false 27 | use_cpu: false -------------------------------------------------------------------------------- /src/open_r1/utils/import_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from transformers.utils.import_utils import _is_package_available 16 | 17 | 18 | # Use same as transformers.utils.import_utils 19 | _e2b_available = _is_package_available("e2b") 20 | 21 | 22 | def is_e2b_available() -> bool: 23 | return _e2b_available 24 | -------------------------------------------------------------------------------- /src/open_r1/trl/extras/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import TYPE_CHECKING 16 | 17 | from ..import_utils import _LazyModule 18 | 19 | 20 | _import_structure = { 21 | "best_of_n_sampler": ["BestOfNSampler"], 22 | } 23 | 24 | if TYPE_CHECKING: 25 | from .best_of_n_sampler import BestOfNSampler 26 | else: 27 | import sys 28 | 29 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) 30 | -------------------------------------------------------------------------------- /src/open_r1.egg-info/requires.txt: -------------------------------------------------------------------------------- 1 | accelerate==1.4.0 2 | bitsandbytes>=0.43.0 3 | einops>=0.8.0 4 | datasets>=3.2.0 5 | deepspeed==0.15.4 6 | hf_transfer>=0.1.4 7 | huggingface-hub[cli]<1.0,>=0.19.2 8 | langdetect 9 | latex2sympy2_extended>=1.0.6 10 | math-verify==0.5.2 11 | liger_kernel==0.5.3 12 | packaging>=23.0 13 | safetensors>=0.3.3 14 | sentencepiece>=0.1.99 15 | transformers==4.49.0 16 | trl@ git+https://github.com/huggingface/trl.git@69ad852e5654a77f1695eb4c608906fe0c7e8624 17 | wandb>=0.19.1 18 | 19 | [code] 20 | e2b-code-interpreter>=1.0.5 21 | python-dotenv 22 | 23 | [dev] 24 | ruff>=0.9.0 25 | isort>=5.12.0 26 | flake8>=6.0.0 27 | pytest 28 | parameterized>=0.9.0 29 | math-verify==0.5.2 30 | lighteval@ git+https://github.com/huggingface/lighteval.git@ed084813e0bd12d82a06d9f913291fdbee774905 31 | 32 | [eval] 33 | lighteval@ git+https://github.com/huggingface/lighteval.git@ed084813e0bd12d82a06d9f913291fdbee774905 34 | math-verify==0.5.2 35 | 36 | [quality] 37 | ruff>=0.9.0 38 | isort>=5.12.0 39 | flake8>=6.0.0 40 | 41 | [tests] 42 | pytest 43 | parameterized>=0.9.0 44 | math-verify==0.5.2 45 | 46 | [torch] 47 | torch==2.5.1 48 | -------------------------------------------------------------------------------- /src/open_r1/trl/environment/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from typing import TYPE_CHECKING 16 | 17 | from ..import_utils import _LazyModule 18 | 19 | 20 | _import_structure = { 21 | "base_environment": ["TextEnvironment", "TextHistory"], 22 | } 23 | 24 | if TYPE_CHECKING: 25 | from .base_environment import TextEnvironment, TextHistory 26 | else: 27 | import sys 28 | 29 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) 30 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 xiwenc1 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/open_r1/trl/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from typing import TYPE_CHECKING 16 | 17 | from ..import_utils import _LazyModule 18 | 19 | 20 | _import_structure = { 21 | "utils": ["init_zero_verbose", "ScriptArguments", "TrlParser"], 22 | } 23 | 24 | if TYPE_CHECKING: 25 | from .utils import ScriptArguments, TrlParser, init_zero_verbose 26 | else: 27 | import sys 28 | 29 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) 30 | -------------------------------------------------------------------------------- /src/open_r1/utils/model_utils.py: -------------------------------------------------------------------------------- 1 | from transformers import AutoTokenizer, PreTrainedTokenizer 2 | 3 | from trl import ModelConfig 4 | 5 | from ..configs import GRPOConfig, SFTConfig 6 | 7 | 8 | DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}" 9 | 10 | 11 | def get_tokenizer( 12 | model_args: ModelConfig, training_args: SFTConfig | GRPOConfig, auto_set_chat_template: bool = True 13 | ) -> PreTrainedTokenizer: 14 | """Get the tokenizer for the model.""" 15 | tokenizer = AutoTokenizer.from_pretrained( 16 | model_args.model_name_or_path, 17 | revision=model_args.model_revision, 18 | trust_remote_code=model_args.trust_remote_code, 19 | ) 20 | 21 | if training_args.chat_template is not None: 22 | tokenizer.chat_template = training_args.chat_template 23 | elif auto_set_chat_template and tokenizer.get_chat_template() is None: 24 | tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE 25 | 26 | return tokenizer 27 | -------------------------------------------------------------------------------- /src/open_r1/trl/trainer/xpo_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataclasses import dataclass, field 16 | 17 | from trl.trainer.online_dpo_config import OnlineDPOConfig 18 | 19 | 20 | @dataclass 21 | class XPOConfig(OnlineDPOConfig): 22 | r""" 23 | Configuration class for the [`XPOTrainer`]. 24 | 25 | Subclass of [`OnlineDPOConfig`] we can use all its arguments and add the following: 26 | 27 | Parameters: 28 | alpha (`float` or `list[float]`, *optional*, defaults to `1e-5`): 29 | Weight of the XPO loss term. If a list of floats is provided then the alpha is selected for each new epoch 30 | and the last alpha is used for the rest of the epochs. 31 | """ 32 | 33 | alpha: list[float] = field( 34 | default_factory=lambda: [1e-5], 35 | metadata={ 36 | "help": "Weight of the XPO loss term. 
If a list of floats is provided then the alpha is selected for each " 37 | "new epoch and the last alpha is used for the rest of the epochs." 38 | }, 39 | ) 40 | 41 | def __post_init__(self): 42 | super().__post_init__() 43 | if hasattr(self.alpha, "__len__") and len(self.alpha) == 1: 44 | self.alpha = self.alpha[0] 45 | -------------------------------------------------------------------------------- /recipes/dra_grpo.yaml: -------------------------------------------------------------------------------- 1 | # Model arguments 2 | model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B 3 | model_revision: main 4 | torch_dtype: bfloat16 5 | attn_implementation: flash_attention_2 6 | 7 | # Data training arguments 8 | dataset_name: knoveleng/open-rs 9 | system_prompt: "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer, and put your final answer within \\boxed{{}} . The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>. Note that respond by English, NOT use other languages." 10 | 11 | # GRPO trainer config 12 | bf16: true 13 | use_vllm: true 14 | vllm_device: auto 15 | vllm_enforce_eager: true 16 | vllm_gpu_memory_utilization: 0.7 17 | vllm_max_model_len: 4608 18 | do_eval: false 19 | gradient_accumulation_steps: 4 20 | gradient_checkpointing: true 21 | gradient_checkpointing_kwargs: 22 | use_reentrant: false 23 | hub_model_id: DRA-GRPO 24 | hub_strategy: every_save 25 | learning_rate: 1.0e-06 26 | log_completions: true 27 | log_level: info 28 | logging_first_step: true 29 | logging_steps: 1 30 | logging_strategy: steps 31 | lr_scheduler_type: cosine_with_min_lr 32 | lr_scheduler_kwargs: 33 | min_lr_rate: 0.1 34 | max_prompt_length: 512 35 | max_completion_length: 3584 36 | max_steps: 500 37 | num_generations: 6 #6 38 | num_train_epochs: 1 39 | output_dir: data/DRA-GRPO 40 | overwrite_output_dir: true 41 | per_device_eval_batch_size: 6 #6 42 | per_device_train_batch_size: 4 #6 43 | push_to_hub: true 44 | report_to: 45 | - wandb 46 | reward_funcs: 47 | - format 48 | - cosine 49 | reward_weights: 50 | - 1.0 51 | - 2.0 52 | save_strategy: "steps" 53 | save_steps: 50 54 | seed: 2025 55 | temperature: 0.7 56 | warmup_ratio: 0.1 57 | SMI_reweighting: True 58 | extractor_name: jina -------------------------------------------------------------------------------- /recipes/dra_dr_grpo.yaml: -------------------------------------------------------------------------------- 1 | # Model arguments 2 | model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B 3 | model_revision: main 4 | torch_dtype: bfloat16 5 | attn_implementation: flash_attention_2 6 | 7 | # Data training arguments 8 | dataset_name: knoveleng/open-rs 9 | system_prompt: "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer, and put your final answer within \\boxed{{}} . The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>. Note that respond by English, NOT use other languages."
10 | 11 | # GRPO trainer config 12 | bf16: true 13 | use_vllm: true 14 | vllm_device: auto 15 | vllm_enforce_eager: true 16 | vllm_gpu_memory_utilization: 0.7 17 | vllm_max_model_len: 4608 18 | do_eval: false 19 | gradient_accumulation_steps: 4 20 | gradient_checkpointing: true 21 | gradient_checkpointing_kwargs: 22 | use_reentrant: false 23 | hub_model_id: DRA-DR_GRPO 24 | hub_strategy: every_save 25 | learning_rate: 1.0e-06 26 | log_completions: true 27 | log_level: info 28 | logging_first_step: true 29 | logging_steps: 1 30 | logging_strategy: steps 31 | lr_scheduler_type: cosine_with_min_lr 32 | lr_scheduler_kwargs: 33 | min_lr_rate: 0.1 34 | max_prompt_length: 512 35 | max_completion_length: 3584 36 | max_steps: 500 37 | num_generations: 6 #6 38 | num_train_epochs: 1 39 | output_dir: data/DRA-DR_GRPO 40 | overwrite_output_dir: true 41 | per_device_eval_batch_size: 6 #6 42 | per_device_train_batch_size: 4 #6 43 | push_to_hub: true 44 | report_to: 45 | - wandb 46 | reward_funcs: 47 | - format 48 | - cosine 49 | reward_weights: 50 | - 1.0 51 | - 2.0 52 | save_strategy: "steps" 53 | save_steps: 25 54 | seed: 42 55 | temperature: 0.7 56 | warmup_ratio: 0.1 57 | SMI_reweighting: True 58 | extractor_name: jina -------------------------------------------------------------------------------- /src/open_r1/trl/trainer/nash_md_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataclasses import dataclass, field 16 | 17 | from trl.trainer.online_dpo_config import OnlineDPOConfig 18 | 19 | 20 | @dataclass 21 | class NashMDConfig(OnlineDPOConfig): 22 | r""" 23 | Configuration class for the [`NashMDTrainer`]. 24 | 25 | Subclass of [`OnlineDPOConfig`] we can use all its arguments and add the following: 26 | 27 | Parameters: 28 | mixture_coef (`float` or `list[float]`, *optional*, defaults to `0.5`): 29 | Logit mixture coefficient for the model and reference model. If a list of floats is provided then the 30 | mixture coefficient is selected for each new epoch and the last coefficient is used for the rest of the 31 | epochs. 32 | """ 33 | 34 | mixture_coef: list[float] = field( 35 | default_factory=lambda: [0.5], 36 | metadata={ 37 | "help": "Logit mixture coefficient for the model and reference model. If a list of floats is provided " 38 | "then the mixture coefficient is selected for each new epoch and the last coefficient is used for the " 39 | "rest of the epochs." 
40 | }, 41 | ) 42 | 43 | def __post_init__(self): 44 | super().__post_init__() 45 | if hasattr(self.mixture_coef, "__len__") and len(self.mixture_coef) == 1: 46 | self.mixture_coef = self.mixture_coef[0] 47 | -------------------------------------------------------------------------------- /src/open_r1/trl/templates/lm_model_card.md: -------------------------------------------------------------------------------- 1 | --- 2 | {{ card_data }} 3 | --- 4 | 5 | # Model Card for {{ model_name }} 6 | 7 | This model is a fine-tuned version of [{{ base_model }}](https://huggingface.co/{{ base_model }}){% if dataset_name %} on the [{{ dataset_name }}](https://huggingface.co/datasets/{{ dataset_name }}) dataset{% endif %}. 8 | It has been trained using [TRL](https://github.com/huggingface/trl). 9 | 10 | ## Quick start 11 | 12 | ```python 13 | from transformers import pipeline 14 | 15 | question = "If you had a time machine, but could only go to the past or the future once and never return, which would you choose and why?" 16 | generator = pipeline("text-generation", model="{{ hub_model_id }}", device="cuda") 17 | output = generator([{"role": "user", "content": question}], max_new_tokens=128, return_full_text=False)[0] 18 | print(output["generated_text"]) 19 | ``` 20 | 21 | ## Training procedure 22 | 23 | {% if wandb_url %}[Visualize in Weights & Biases]({{ wandb_url }}){% endif %} 24 | {% if comet_url %}[Visualize in Comet]({{ comet_url }}){% endif %} 25 | 26 | This model was trained with {{ trainer_name }}{% if paper_id %}, a method introduced in [{{ paper_title }}](https://huggingface.co/papers/{{ paper_id }}){% endif %}. 27 | 28 | ### Framework versions 29 | 30 | - TRL: {{ trl_version }} 31 | - Transformers: {{ transformers_version }} 32 | - Pytorch: {{ pytorch_version }} 33 | - Datasets: {{ datasets_version }} 34 | - Tokenizers: {{ tokenizers_version }} 35 | 36 | ## Citations 37 | 38 | {% if trainer_citation %}Cite {{ trainer_name }} as: 39 | 40 | ```bibtex 41 | {{ trainer_citation }} 42 | ```{% endif %} 43 | 44 | Cite TRL as: 45 | 46 | ```bibtex 47 | {% raw %}@misc{vonwerra2022trl, 48 | title = {{TRL: Transformer Reinforcement Learning}}, 49 | author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec}, 50 | year = 2020, 51 | journal = {GitHub repository}, 52 | publisher = {GitHub}, 53 | howpublished = {\url{https://github.com/huggingface/trl}} 54 | }{% endraw %} 55 | ``` 56 | -------------------------------------------------------------------------------- /src/open_r1/trl/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from typing import TYPE_CHECKING 16 | 17 | from ..import_utils import OptionalDependencyNotAvailable, _LazyModule, is_diffusers_available 18 | 19 | 20 | _import_structure = { 21 | "modeling_base": ["GeometricMixtureWrapper", "PreTrainedModelWrapper", "create_reference_model"], 22 | "modeling_value_head": ["AutoModelForCausalLMWithValueHead", "AutoModelForSeq2SeqLMWithValueHead"], 23 | "utils": ["SUPPORTED_ARCHITECTURES", "prepare_deepspeed", "setup_chat_format", "unwrap_model_for_generation"], 24 | } 25 | 26 | try: 27 | if not is_diffusers_available(): 28 | raise OptionalDependencyNotAvailable() 29 | except OptionalDependencyNotAvailable: 30 | pass 31 | else: 32 | _import_structure["modeling_sd_base"] = [ 33 | "DDPOPipelineOutput", 34 | "DDPOSchedulerOutput", 35 | "DDPOStableDiffusionPipeline", 36 | "DefaultDDPOStableDiffusionPipeline", 37 | ] 38 | 39 | if TYPE_CHECKING: 40 | from .modeling_base import GeometricMixtureWrapper, PreTrainedModelWrapper, create_reference_model 41 | from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead 42 | from .utils import SUPPORTED_ARCHITECTURES, prepare_deepspeed, setup_chat_format, unwrap_model_for_generation 43 | 44 | try: 45 | if not is_diffusers_available(): 46 | raise OptionalDependencyNotAvailable() 47 | except OptionalDependencyNotAvailable: 48 | pass 49 | else: 50 | from .modeling_sd_base import ( 51 | DDPOPipelineOutput, 52 | DDPOSchedulerOutput, 53 | DDPOStableDiffusionPipeline, 54 | DefaultDDPOStableDiffusionPipeline, 55 | ) 56 | else: 57 | import sys 58 | 59 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) 60 | -------------------------------------------------------------------------------- /src/open_r1/utils/callbacks.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2025 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | import subprocess 18 | from typing import List 19 | 20 | from transformers import TrainerCallback 21 | from transformers.trainer_callback import TrainerControl, TrainerState 22 | from transformers.training_args import TrainingArguments 23 | 24 | from .evaluation import run_benchmark_jobs 25 | from .hub import push_to_hub_revision 26 | 27 | 28 | def is_slurm_available() -> bool: 29 | # returns true if a slurm queueing system is available 30 | try: 31 | subprocess.run(["sinfo"], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) 32 | return True 33 | except FileNotFoundError: 34 | return False 35 | 36 | 37 | class DummyConfig: 38 | def __init__(self, **kwargs): 39 | for k, v in kwargs.items(): 40 | setattr(self, k, v) 41 | 42 | 43 | class PushToHubRevisionCallback(TrainerCallback): 44 | def __init__(self, model_config) -> None: 45 | self.model_config = model_config 46 | 47 | def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs): 48 | if state.is_world_process_zero: 49 | global_step = state.global_step 50 | 51 | # WARNING: if you use dataclasses.replace(args, ...) the accelerator dist state will be broken, so I do this workaround 52 | # Also if you instantiate a new SFTConfig, the accelerator dist state will be broken 53 | dummy_config = DummyConfig( 54 | hub_model_id=args.hub_model_id, 55 | hub_model_revision=f"{args.hub_model_revision}-step-{global_step:09d}", 56 | output_dir=f"{args.output_dir}/checkpoint-{global_step}", 57 | system_prompt=args.system_prompt, 58 | ) 59 | 60 | future = push_to_hub_revision( 61 | dummy_config, extra_ignore_patterns=["*.pt"] 62 | ) # don't push the optimizer states 63 | 64 | if is_slurm_available(): 65 | dummy_config.benchmarks = args.benchmarks 66 | 67 | def run_benchmark_callback(_): 68 | print(f"Checkpoint {global_step} pushed to hub.") 69 | run_benchmark_jobs(dummy_config, self.model_config) 70 | 71 | future.add_done_callback(run_benchmark_callback) 72 | 73 | 74 | CALLBACKS = { 75 | "push_to_hub_revision": PushToHubRevisionCallback, 76 | } 77 | 78 | 79 | def get_callbacks(train_config, model_config) -> List[TrainerCallback]: 80 | callbacks = [] 81 | for callback_name in train_config.callbacks: 82 | if callback_name not in CALLBACKS: 83 | raise ValueError(f"Callback {callback_name} not found in CALLBACKS.") 84 | callbacks.append(CALLBACKS[callback_name](model_config)) 85 | 86 | return callbacks 87 | -------------------------------------------------------------------------------- /src/open_r1/trl/scripts/env.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import os 16 | import platform 17 | from importlib.metadata import version 18 | 19 | import torch 20 | from accelerate.commands.config import default_config_file, load_config_from_file 21 | from transformers import is_bitsandbytes_available 22 | from transformers.utils import is_liger_kernel_available, is_openai_available, is_peft_available 23 | 24 | from .. import __version__ 25 | from ..import_utils import is_deepspeed_available, is_diffusers_available, is_llm_blender_available 26 | from .utils import get_git_commit_hash 27 | 28 | 29 | def print_env(): 30 | if torch.cuda.is_available(): 31 | devices = [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())] 32 | 33 | accelerate_config = accelerate_config_str = "not found" 34 | 35 | # Get the default from the config file. 36 | if os.path.isfile(default_config_file): 37 | accelerate_config = load_config_from_file(default_config_file).to_dict() 38 | 39 | accelerate_config_str = ( 40 | "\n" + "\n".join([f" - {prop}: {val}" for prop, val in accelerate_config.items()]) 41 | if isinstance(accelerate_config, dict) 42 | else accelerate_config 43 | ) 44 | 45 | commit_hash = get_git_commit_hash("trl") 46 | 47 | info = { 48 | "Platform": platform.platform(), 49 | "Python version": platform.python_version(), 50 | "PyTorch version": version("torch"), 51 | "CUDA device(s)": ", ".join(devices) if torch.cuda.is_available() else "not available", 52 | "Transformers version": version("transformers"), 53 | "Accelerate version": version("accelerate"), 54 | "Accelerate config": accelerate_config_str, 55 | "Datasets version": version("datasets"), 56 | "HF Hub version": version("huggingface_hub"), 57 | "TRL version": f"{__version__}+{commit_hash[:7]}" if commit_hash else __version__, 58 | "bitsandbytes version": version("bitsandbytes") if is_bitsandbytes_available() else "not installed", 59 | "DeepSpeed version": version("deepspeed") if is_deepspeed_available() else "not installed", 60 | "Diffusers version": version("diffusers") if is_diffusers_available() else "not installed", 61 | "Liger-Kernel version": version("liger_kernel") if is_liger_kernel_available() else "not installed", 62 | "LLM-Blender version": version("llm_blender") if is_llm_blender_available() else "not installed", 63 | "OpenAI version": version("openai") if is_openai_available() else "not installed", 64 | "PEFT version": version("peft") if is_peft_available() else "not installed", 65 | } 66 | 67 | info_str = "\n".join([f"- {prop}: {val}" for prop, val in info.items()]) 68 | print(f"\nCopy-paste the following information when reporting an issue:\n\n{info_str}\n") # noqa 69 | 70 | 71 | if __name__ == "__main__": 72 | print_env() 73 | -------------------------------------------------------------------------------- /src/open_r1/trl/trainer/reward_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataclasses import dataclass, field 16 | from typing import Optional 17 | 18 | from transformers import TrainingArguments 19 | 20 | 21 | @dataclass 22 | class RewardConfig(TrainingArguments): 23 | r""" 24 | Configuration class for the [`RewardTrainer`]. 25 | 26 | Using [`~transformers.HfArgumentParser`] we can turn this class into 27 | [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the 28 | command line. 29 | 30 | Parameters: 31 | max_length (`int` or `None`, *optional*, defaults to `1024`): 32 | Maximum length of the sequences (prompt + completion) in the batch, filters out entries that exceed the 33 | limit. This argument is required if you want to use the default data collator. 34 | disable_dropout (`bool`, *optional*, defaults to `True`): 35 | Whether to disable dropout in the model. 36 | dataset_num_proc (`int`, *optional*, defaults to `None`): 37 | Number of processes to use for processing the dataset. 38 | center_rewards_coefficient (`float`, *optional*, defaults to `None`): 39 | Coefficient to incentivize the reward model to output mean-zero rewards (proposed by 40 | https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`. 41 | remove_unused_columns (`bool`, *optional*, defaults to `False`): 42 | Whether to remove the columns that are not used by the model's forward pass. Can be `True` only if 43 | the dataset is pretokenized. 44 | """ 45 | 46 | max_length: Optional[int] = field( 47 | default=1024, 48 | metadata={ 49 | "help": "Maximum length of the sequences (prompt + completion) in the batch, filters out entries that " 50 | "exceed the limit. This argument is required if you want to use the default data collator." 51 | }, 52 | ) 53 | disable_dropout: bool = field( 54 | default=True, 55 | metadata={"help": "Whether to disable dropout in the model and reference model."}, 56 | ) 57 | dataset_num_proc: Optional[int] = field( 58 | default=None, 59 | metadata={"help": "Number of processes to use for processing the dataset."}, 60 | ) 61 | center_rewards_coefficient: Optional[float] = field( 62 | default=None, 63 | metadata={ 64 | "help": "Coefficient to incentivize the reward model to output mean-zero rewards (proposed by " 65 | "https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`." 66 | }, 67 | ) 68 | remove_unused_columns: bool = field( 69 | default=False, 70 | metadata={ 71 | "help": "Whether to remove the columns that are not used by the model's forward pass. Can be `True` only " 72 | "if the dataset is pretokenized." 73 | }, 74 | ) 75 | -------------------------------------------------------------------------------- /src/open_r1/configs.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2025 The HuggingFace Team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | from dataclasses import dataclass, field 17 | from typing import Optional 18 | 19 | import trl 20 | 21 | 22 | # TODO: add the shared options with a mixin to reduce code duplication 23 | @dataclass 24 | class GRPOConfig(trl.GRPOConfig): 25 | """ 26 | args for callbacks, benchmarks etc 27 | """ 28 | 29 | benchmarks: list[str] = field( 30 | default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."} 31 | ) 32 | callbacks: list[str] = field( 33 | default_factory=lambda: [], metadata={"help": "The callbacks to run during training."} 34 | ) 35 | chat_template: Optional[str] = field(default=None, metadata={"help": "The chat template to use."}) 36 | system_prompt: Optional[str] = field( 37 | default=None, 38 | metadata={"help": "The optional system prompt to use."}, 39 | ) 40 | hub_model_revision: Optional[str] = field( 41 | default="main", metadata={"help": "The Hub model branch to push the model to."} 42 | ) 43 | overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."}) 44 | push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."}) 45 | wandb_entity: Optional[str] = field( 46 | default=None, 47 | metadata={"help": ("The entity to store runs under.")}, 48 | ) 49 | wandb_project: Optional[str] = field( 50 | default=None, 51 | metadata={"help": ("The project to store runs under.")}, 52 | ) 53 | 54 | 55 | @dataclass 56 | class SFTConfig(trl.SFTConfig): 57 | """ 58 | args for callbacks, benchmarks etc 59 | """ 60 | 61 | benchmarks: list[str] = field( 62 | default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."} 63 | ) 64 | callbacks: list[str] = field( 65 | default_factory=lambda: [], metadata={"help": "The callbacks to run during training."} 66 | ) 67 | chat_template: Optional[str] = field(default=None, metadata={"help": "The chat template to use."}) 68 | system_prompt: Optional[str] = field( 69 | default=None, 70 | metadata={"help": "The optional system prompt to use for benchmarking."}, 71 | ) 72 | hub_model_revision: Optional[str] = field( 73 | default="main", 74 | metadata={"help": "The Hub model branch to push the model to."}, 75 | ) 76 | overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."}) 77 | push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."}) 78 | wandb_entity: Optional[str] = field( 79 | default=None, 80 | metadata={"help": ("The entity to store runs under.")}, 81 | ) 82 | wandb_project: Optional[str] = field( 83 | default=None, 84 | metadata={"help": ("The project to store runs under.")}, 85 | ) 86 | -------------------------------------------------------------------------------- /src/open_r1/trl/models/auxiliary_modules.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | 17 | import torch 18 | import torch.nn as nn 19 | import torchvision 20 | from huggingface_hub import hf_hub_download 21 | from huggingface_hub.utils import EntryNotFoundError 22 | from transformers import CLIPModel, is_torch_npu_available, is_torch_xpu_available 23 | 24 | 25 | class MLP(nn.Module): 26 | def __init__(self): 27 | super().__init__() 28 | self.layers = nn.Sequential( 29 | nn.Linear(768, 1024), 30 | nn.Dropout(0.2), 31 | nn.Linear(1024, 128), 32 | nn.Dropout(0.2), 33 | nn.Linear(128, 64), 34 | nn.Dropout(0.1), 35 | nn.Linear(64, 16), 36 | nn.Linear(16, 1), 37 | ) 38 | 39 | def forward(self, embed): 40 | return self.layers(embed) 41 | 42 | 43 | class AestheticScorer(torch.nn.Module): 44 | """ 45 | This model attempts to predict the aesthetic score of an image. The aesthetic score 46 | is a numerical approximation of how much a specific image is liked by humans on average. 47 | This is from https://github.com/christophschuhmann/improved-aesthetic-predictor 48 | """ 49 | 50 | def __init__(self, *, dtype, model_id, model_filename): 51 | super().__init__() 52 | self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14") 53 | self.normalize = torchvision.transforms.Normalize( 54 | mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711] 55 | ) 56 | self.target_size = 224 57 | self.mlp = MLP() 58 | try: 59 | cached_path = hf_hub_download(model_id, model_filename) 60 | except EntryNotFoundError: 61 | cached_path = os.path.join(model_id, model_filename) 62 | state_dict = torch.load(cached_path, map_location=torch.device("cpu"), weights_only=True) 63 | self.mlp.load_state_dict(state_dict) 64 | self.dtype = dtype 65 | self.eval() 66 | 67 | def __call__(self, images): 68 | device = next(self.parameters()).device 69 | images = torchvision.transforms.Resize(self.target_size)(images) 70 | images = self.normalize(images).to(self.dtype).to(device) 71 | embed = self.clip.get_image_features(pixel_values=images) 72 | # normalize embedding 73 | embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True) 74 | reward = self.mlp(embed).squeeze(1) 75 | return reward 76 | 77 | 78 | def aesthetic_scorer(hub_model_id, model_filename): 79 | scorer = AestheticScorer( 80 | model_id=hub_model_id, 81 | model_filename=model_filename, 82 | dtype=torch.float32, 83 | ) 84 | if is_torch_npu_available(): 85 | scorer = scorer.npu() 86 | elif is_torch_xpu_available(): 87 | scorer = scorer.xpu() 88 | else: 89 | scorer = scorer.cuda() 90 | 91 | def _fn(images, prompts, metadata): 92 | images = (images).clamp(0, 1) 93 | scores = scorer(images) 94 | return scores, {} 95 | 96 | return _fn 97 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DRA-GRPO 2 | Official code for the paper: DRA-GRPO: Exploring Diversity-Aware Reward Adjustment for R1-Zero-Like Training of Large Language Models [![paper](https://img.shields.io/badge/arXiv-Paper-brightgreen)](https://arxiv.org/abs/2505.09655) 3 
|
4 | Paper link (preprint): https://arxiv.org/abs/2505.09655
5 | 
6 | 
7 | 
8 | > **Abstract.** Recent advances in reinforcement learning for language model post-training, such as Group Relative Policy Optimization (GRPO), have shown promise in low-resource settings. However, GRPO typically relies on solution-level and scalar reward signals that fail to capture the semantic diversity among sampled completions. This leads to what we identify as a diversity-quality inconsistency, where distinct reasoning paths may receive indistinguishable rewards. To address this limitation, we propose $\textit{Diversity-aware Reward Adjustment} (DRA)$, a method that explicitly incorporates semantic diversity into the reward computation. DRA uses Submodular Mutual Information (SMI) to downweight redundant completions and amplify rewards for diverse ones. This encourages better exploration during learning, while maintaining stable exploitation of high-quality samples. Our method integrates seamlessly with both GRPO and its variant DR.~GRPO, resulting in $\textit{DRA-GRPO}$ and $\textit{DGA-DR.~GRPO}$. We evaluate our method on five mathematical reasoning benchmarks and find that it outperforms recent strong baselines. It achieves state-of-the-art performance with an average accuracy of 58.2\%, using only 7,000 fine-tuning samples and a total training cost of approximately $55.
9 | 
10 | ## Installation
11 | 
12 | Clone the repository. We load the following environment modules:
13 | 
14 | ```
15 | module load anaconda3/2023.09-0
16 | module load git-lfs/3.3.0
17 | module load cuda/11.8.0
18 | 
19 | ```
20 | 
21 | Please follow the instructions of [Open-RS](https://github.com/knoveleng/open-rs) to install the environment.
22 | Log in to Hugging Face and Weights & Biases:
23 | ```
24 | huggingface-cli login
25 | wandb login
26 | ```
27 | 
28 | ```
29 | source activate openr3
30 | ```
31 | 
32 | **You can then remove the ```trl``` package from the environment, because this repository ships a customized version of it.**
33 | 
34 | 
35 | 
36 | ## Training
37 | 
38 | ### DRA-GRPO
39 | ```
40 | ACCELERATE_LOG_LEVEL=info accelerate launch \
41 | --main_process_port 11188 \
42 | --config_file recipes/accelerate_configs/zero2.yaml \
43 | --num_processes=3 \
44 | src/open_r1/grpo.py \
45 | --config recipes/dra_grpo.yaml
46 | ```
47 | 
48 | 
49 | ### DRA-DR. GRPO
50 | ```
51 | ACCELERATE_LOG_LEVEL=info accelerate launch \
52 | --main_process_port 18007 \
53 | --config_file recipes/accelerate_configs/zero2.yaml \
54 | --num_processes=3 \
55 | src/open_r1/drgrpo.py \
56 | --config recipes/dra_dr_grpo.yaml
57 | ```
58 | 
59 | All trained weights are pushed to the Hugging Face Hub.
60 | 
61 | ## Inference via lighteval (evaluate multiple training steps)
62 | We provide an evaluation template:
63 | 
64 | ```
65 | bash evaL_all.sh
66 | ```
67 | 
68 | ## Checkpoints (uploaded to the [Hugging Face repo](https://huggingface.co/SpiceRL))
69 | ```
70 | MODEL=xxx
71 | MODEL_NAME=$(basename "$MODEL")
72 | 
73 | TASKS="math_500 amc23 minerva olympiadbench aime24"
74 | 
75 | OUTPUT_DIR=data-test/evals/${MODEL_NAME}
76 | MODEL_ARGS="pretrained=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilization=0.8,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
77 | for TASK in $TASKS; do
78 | lighteval vllm "$MODEL_ARGS" "custom|$TASK|0|0" \
79 | --custom-tasks src/open_r1/evaluate.py \
80 | --use-chat-template \
81 | --output-dir "$OUTPUT_DIR"
82 | done
83 | 
84 | 
85 | ```
86 | 
87 | Replace ```xxx``` with ```SpiceRL/DRA-GRPO``` or ```SpiceRL/DRA-DR.GRPO```. Evaluation requires only one GPU.
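
## How the reward adjustment works (illustrative sketch)

The core idea of the paper is to reweight the scalar rewards of each sampled group before the group-relative advantages are computed, using Submodular Mutual Information (SMI) so that semantically redundant completions are downweighted and diverse ones are amplified. The snippet below is only a minimal sketch of that idea, not the implementation shipped in this repository (the customized ```trl``` package under ```src/open_r1/trl``` contains the actual code): it assumes cosine similarity over completion embeddings as the semantic kernel, a simple graph-cut-style redundancy score as the SMI surrogate, and a hypothetical strength parameter ```alpha```.

```
import numpy as np

def redundancy(embeddings: np.ndarray) -> np.ndarray:
    """Per-completion redundancy: mean cosine similarity to the other completions in the group."""
    normed = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    sim = normed @ normed.T            # pairwise cosine similarities
    np.fill_diagonal(sim, 0.0)         # ignore self-similarity
    return sim.sum(axis=1) / (sim.shape[0] - 1)

def adjust_rewards(rewards: np.ndarray, embeddings: np.ndarray, alpha: float = 0.5) -> np.ndarray:
    """Downweight redundant completions and amplify diverse ones before
    the group-relative advantage is computed (illustrative only)."""
    diversity = 1.0 - redundancy(embeddings)        # higher = more semantically distinct
    weights = 1.0 + alpha * (diversity - diversity.mean())
    return rewards * weights

# Toy example: 4 completions sampled for one prompt, three of them near-duplicates.
rng = np.random.default_rng(0)
base = rng.normal(size=768)
embeds = np.stack([base,
                   base + 0.01 * rng.normal(size=768),
                   base + 0.01 * rng.normal(size=768),
                   rng.normal(size=768)])
rewards = np.ones(4)
print(adjust_rewards(rewards, embeds))  # the distinct completion receives the largest adjusted reward
```

In this toy run all four completions share the same scalar reward, so an unadjusted group-relative update would treat them identically; the diversity-aware weights break that tie in favour of the semantically distinct completion, which is exactly the diversity-quality inconsistency the paper targets.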
88 | 
89 | ## Acknowledgements
90 | 
91 | Our code is built on [Open-RS](https://github.com/knoveleng/open-rs). Thanks!
92 | 
93 | 
94 | 
95 | 
--------------------------------------------------------------------------------
/src/open_r1/trl/scripts/grpo.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | import argparse
16 | from dataclasses import dataclass, field
17 | from typing import Optional
18 | 
19 | from datasets import load_dataset
20 | from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
21 | 
22 | from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
23 | 
24 | 
25 | @dataclass
26 | class GRPOScriptArguments(ScriptArguments):
27 |     """
28 |     Script arguments for the GRPO training script.
29 | 
30 |     Args:
31 |         reward_model_name_or_path (`str` or `None`):
32 |             Reward model id of a pretrained model hosted inside a model repo on huggingface.co or local path to a
33 |             directory containing model weights saved using [`~transformers.PreTrainedModel.save_pretrained`].
34 |     """
35 | 
36 |     reward_model_name_or_path: Optional[str] = field(
37 |         default=None,
38 |         metadata={
39 |             "help": "Reward model id of a pretrained model hosted inside a model repo on huggingface.co or "
40 |             "local path to a directory containing model weights saved using `PreTrainedModel.save_pretrained`."
41 | }, 42 | ) 43 | 44 | 45 | def main(script_args, training_args, model_args): 46 | # Load a pretrained model 47 | model = AutoModelForCausalLM.from_pretrained( 48 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code 49 | ) 50 | tokenizer = AutoTokenizer.from_pretrained( 51 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code 52 | ) 53 | reward_model = AutoModelForSequenceClassification.from_pretrained( 54 | script_args.reward_model_name_or_path, trust_remote_code=model_args.trust_remote_code, num_labels=1 55 | ) 56 | 57 | # Load the dataset 58 | dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) 59 | 60 | # Initialize the GRPO trainer 61 | trainer = GRPOTrainer( 62 | model=model, 63 | reward_funcs=reward_model, 64 | args=training_args, 65 | train_dataset=dataset[script_args.dataset_train_split], 66 | eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, 67 | processing_class=tokenizer, 68 | peft_config=get_peft_config(model_args), 69 | ) 70 | 71 | # Train and push the model to the Hub 72 | trainer.train() 73 | 74 | # Save and push to hub 75 | trainer.save_model(training_args.output_dir) 76 | if training_args.push_to_hub: 77 | trainer.push_to_hub(dataset_name=script_args.dataset_name) 78 | 79 | 80 | def make_parser(subparsers: argparse._SubParsersAction = None): 81 | dataclass_types = (GRPOScriptArguments, GRPOConfig, ModelConfig) 82 | if subparsers is not None: 83 | parser = subparsers.add_parser("grpo", help="Run the GRPO training script", dataclass_types=dataclass_types) 84 | else: 85 | parser = TrlParser(dataclass_types) 86 | return parser 87 | 88 | 89 | if __name__ == "__main__": 90 | parser = make_parser() 91 | script_args, training_args, model_args = parser.parse_args_and_config() 92 | main(script_args, training_args, model_args) 93 | -------------------------------------------------------------------------------- /src/open_r1/trl/cli.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import os 16 | import sys 17 | 18 | from accelerate.commands.launch import launch_command, launch_command_parser 19 | 20 | from .scripts.chat import main as chat_main 21 | from .scripts.chat import make_parser as make_chat_parser 22 | from .scripts.dpo import make_parser as make_dpo_parser 23 | from .scripts.env import print_env 24 | from .scripts.grpo import make_parser as make_grpo_parser 25 | from .scripts.kto import make_parser as make_kto_parser 26 | from .scripts.sft import make_parser as make_sft_parser 27 | from .scripts.utils import TrlParser 28 | 29 | 30 | def main(): 31 | parser = TrlParser(prog="TRL CLI", usage="trl", allow_abbrev=False) 32 | 33 | # Add the subparsers 34 | subparsers = parser.add_subparsers(help="available commands", dest="command", parser_class=TrlParser) 35 | 36 | # Add the subparsers for every script 37 | make_chat_parser(subparsers) 38 | make_dpo_parser(subparsers) 39 | subparsers.add_parser("env", help="Print the environment information") 40 | make_grpo_parser(subparsers) 41 | make_kto_parser(subparsers) 42 | make_sft_parser(subparsers) 43 | 44 | # Parse the arguments 45 | args = parser.parse_args() 46 | 47 | if args.command == "chat": 48 | (chat_args,) = parser.parse_args_and_config() 49 | chat_main(chat_args) 50 | 51 | if args.command == "dpo": 52 | # Get the default args for the launch command 53 | dpo_training_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts", "dpo.py") 54 | args = launch_command_parser().parse_args([dpo_training_script]) 55 | 56 | # Feed the args to the launch command 57 | args.training_script_args = sys.argv[2:] # remove "trl" and "dpo" 58 | launch_command(args) # launch training 59 | 60 | elif args.command == "env": 61 | print_env() 62 | 63 | elif args.command == "grpo": 64 | # Get the default args for the launch command 65 | grpo_training_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts", "grpo.py") 66 | args = launch_command_parser().parse_args([grpo_training_script]) 67 | 68 | # Feed the args to the launch command 69 | args.training_script_args = sys.argv[2:] # remove "trl" and "grpo" 70 | launch_command(args) # launch training 71 | 72 | elif args.command == "kto": 73 | # Get the default args for the launch command 74 | kto_training_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts", "kto.py") 75 | args = launch_command_parser().parse_args([kto_training_script]) 76 | 77 | # Feed the args to the launch command 78 | args.training_script_args = sys.argv[2:] # remove "trl" and "kto" 79 | launch_command(args) # launch training 80 | 81 | elif args.command == "sft": 82 | # Get the default args for the launch command 83 | sft_training_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts", "sft.py") 84 | args = launch_command_parser().parse_args([sft_training_script]) 85 | 86 | # Feed the args to the launch command 87 | args.training_script_args = sys.argv[2:] # remove "trl" and "sft" 88 | launch_command(args) # launch training 89 | 90 | 91 | if __name__ == "__main__": 92 | main() 93 | -------------------------------------------------------------------------------- /src/open_r1/trl/trainer/prm_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataclasses import dataclass, field 16 | from typing import Optional 17 | 18 | from transformers import TrainingArguments 19 | 20 | 21 | @dataclass 22 | class PRMConfig(TrainingArguments): 23 | r""" 24 | Configuration class for the [`PRMTrainer`]. 25 | 26 | Using [`~transformers.HfArgumentParser`] we can turn this class into 27 | [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the 28 | command line. 29 | 30 | Parameters: 31 | learning_rate (`float`, *optional*, defaults to `1e-5`): 32 | Initial learning rate for [`AdamW`] optimizer. The default value replaces that of 33 | [`~transformers.TrainingArguments`]. 34 | max_length (`int` or `None`, *optional*, defaults to `1024`): 35 | Maximum length of the sequences (prompt + completion) used for truncation. 36 | max_prompt_length (`int` or `None`, *optional*, defaults to `512`): 37 | Maximum length of the prompt used for truncation. 38 | max_completion_length (`int` or `None`, *optional*, defaults to `None`): 39 | Maximum length of the completion used for truncation. The completion is the concatenation of the steps. 40 | disable_dropout (`bool`, *optional*, defaults to `True`): 41 | Whether to disable dropout in the model. 42 | step_separator (`str`, *optional*, defaults to `"\n"`): 43 | Separator used to separate each step of the reasoning process. 44 | train_on_last_step_only (`bool`, *optional*, defaults to `False`): 45 | Whether to train only on the last step. 46 | dataset_num_proc (`int`, *optional*, defaults to `None`): 47 | Number of processes to use for processing the dataset. 48 | """ 49 | 50 | learning_rate: float = field( 51 | default=1e-5, 52 | metadata={ 53 | "help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of " 54 | "`TrainingArguments`." 55 | }, 56 | ) 57 | max_length: Optional[int] = field( 58 | default=1024, 59 | metadata={"help": "Maximum length of the sequences (prompt + completion) used for truncation."}, 60 | ) 61 | max_prompt_length: Optional[int] = field( 62 | default=512, 63 | metadata={"help": "Maximum length of the prompt used for truncation."}, 64 | ) 65 | max_completion_length: Optional[int] = field( 66 | default=None, 67 | metadata={ 68 | "help": "Maximum length of the completion used for truncation. The completion is the concatenation of the " 69 | "steps." 
70 | }, 71 | ) 72 | disable_dropout: bool = field( 73 | default=True, 74 | metadata={"help": "Whether to disable dropout in the model and reference model."}, 75 | ) 76 | step_separator: str = field( 77 | default="\n", 78 | metadata={"help": "Separator used to separate each step of the reasoning process."}, 79 | ) 80 | train_on_last_step_only: bool = field( 81 | default=False, 82 | metadata={"help": "Whether to train only on the last step."}, 83 | ) 84 | dataset_num_proc: Optional[int] = field( 85 | default=None, 86 | metadata={"help": "Number of processes to use for processing the dataset."}, 87 | ) 88 | -------------------------------------------------------------------------------- /src/open_r1/trl/scripts/sft.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | # Full training 17 | python trl/scripts/sft.py \ 18 | --model_name_or_path Qwen/Qwen2-0.5B \ 19 | --dataset_name trl-lib/Capybara \ 20 | --learning_rate 2.0e-5 \ 21 | --num_train_epochs 1 \ 22 | --packing \ 23 | --per_device_train_batch_size 2 \ 24 | --gradient_accumulation_steps 8 \ 25 | --gradient_checkpointing \ 26 | --logging_steps 25 \ 27 | --eval_strategy steps \ 28 | --eval_steps 100 \ 29 | --output_dir Qwen2-0.5B-SFT \ 30 | --push_to_hub 31 | 32 | # LoRA 33 | python trl/scripts/sft.py \ 34 | --model_name_or_path Qwen/Qwen2-0.5B \ 35 | --dataset_name trl-lib/Capybara \ 36 | --learning_rate 2.0e-4 \ 37 | --num_train_epochs 1 \ 38 | --packing \ 39 | --per_device_train_batch_size 2 \ 40 | --gradient_accumulation_steps 8 \ 41 | --gradient_checkpointing \ 42 | --logging_steps 25 \ 43 | --eval_strategy steps \ 44 | --eval_steps 100 \ 45 | --use_peft \ 46 | --lora_r 32 \ 47 | --lora_alpha 16 \ 48 | --output_dir Qwen2-0.5B-SFT \ 49 | --push_to_hub 50 | """ 51 | 52 | import argparse 53 | 54 | from datasets import load_dataset 55 | from transformers import AutoTokenizer 56 | 57 | from trl import ( 58 | ModelConfig, 59 | ScriptArguments, 60 | SFTConfig, 61 | SFTTrainer, 62 | TrlParser, 63 | get_kbit_device_map, 64 | get_peft_config, 65 | get_quantization_config, 66 | ) 67 | 68 | 69 | def main(script_args, training_args, model_args): 70 | ################ 71 | # Model init kwargs & Tokenizer 72 | ################ 73 | quantization_config = get_quantization_config(model_args) 74 | model_kwargs = dict( 75 | revision=model_args.model_revision, 76 | trust_remote_code=model_args.trust_remote_code, 77 | attn_implementation=model_args.attn_implementation, 78 | torch_dtype=model_args.torch_dtype, 79 | use_cache=False if training_args.gradient_checkpointing else True, 80 | device_map=get_kbit_device_map() if quantization_config is not None else None, 81 | quantization_config=quantization_config, 82 | ) 83 | training_args.model_init_kwargs = model_kwargs 84 | tokenizer = AutoTokenizer.from_pretrained( 85 | model_args.model_name_or_path, 
trust_remote_code=model_args.trust_remote_code, use_fast=True 86 | ) 87 | if tokenizer.pad_token is None: 88 | tokenizer.pad_token = tokenizer.eos_token 89 | 90 | ################ 91 | # Dataset 92 | ################ 93 | dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) 94 | 95 | ################ 96 | # Training 97 | ################ 98 | trainer = SFTTrainer( 99 | model=model_args.model_name_or_path, 100 | args=training_args, 101 | train_dataset=dataset[script_args.dataset_train_split], 102 | eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, 103 | processing_class=tokenizer, 104 | peft_config=get_peft_config(model_args), 105 | ) 106 | 107 | trainer.train() 108 | 109 | # Save and push to hub 110 | trainer.save_model(training_args.output_dir) 111 | if training_args.push_to_hub: 112 | trainer.push_to_hub(dataset_name=script_args.dataset_name) 113 | 114 | 115 | def make_parser(subparsers: argparse._SubParsersAction = None): 116 | dataclass_types = (ScriptArguments, SFTConfig, ModelConfig) 117 | if subparsers is not None: 118 | parser = subparsers.add_parser("sft", help="Run the SFT training script", dataclass_types=dataclass_types) 119 | else: 120 | parser = TrlParser(dataclass_types) 121 | return parser 122 | 123 | 124 | if __name__ == "__main__": 125 | parser = make_parser() 126 | script_args, training_args, model_args = parser.parse_args_and_config() 127 | main(script_args, training_args, model_args) 128 | -------------------------------------------------------------------------------- /src/open_r1/trl/scripts/kto.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Run the KTO training script with the commands below. In general, the optimal configuration for KTO will be similar to that of DPO. 
17 | 18 | # Full training: 19 | python trl/scripts/kto.py \ 20 | --dataset_name trl-lib/kto-mix-14k \ 21 | --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \ 22 | --per_device_train_batch_size 16 \ 23 | --num_train_epochs 1 \ 24 | --learning_rate 5e-7 \ 25 | --lr_scheduler_type=cosine \ 26 | --gradient_accumulation_steps 1 \ 27 | --logging_steps 10 \ 28 | --eval_steps 500 \ 29 | --output_dir=kto-aligned-model \ 30 | --warmup_ratio 0.1 \ 31 | --report_to wandb \ 32 | --bf16 \ 33 | --logging_first_step 34 | 35 | # QLoRA: 36 | python trl/scripts/kto.py \ 37 | --dataset_name trl-lib/kto-mix-14k \ 38 | --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \ 39 | --per_device_train_batch_size 8 \ 40 | --num_train_epochs 1 \ 41 | --learning_rate 5e-7 \ 42 | --lr_scheduler_type=cosine \ 43 | --gradient_accumulation_steps 1 \ 44 | --logging_steps 10 \ 45 | --eval_steps 500 \ 46 | --output_dir=kto-aligned-model-lora \ 47 | --warmup_ratio 0.1 \ 48 | --report_to wandb \ 49 | --bf16 \ 50 | --logging_first_step \ 51 | --use_peft \ 52 | --load_in_4bit \ 53 | --lora_target_modules=all-linear \ 54 | --lora_r=16 \ 55 | --lora_alpha=16 56 | """ 57 | 58 | import argparse 59 | 60 | from datasets import load_dataset 61 | from transformers import AutoModelForCausalLM, AutoTokenizer 62 | 63 | from trl import ( 64 | KTOConfig, 65 | KTOTrainer, 66 | ModelConfig, 67 | ScriptArguments, 68 | TrlParser, 69 | get_peft_config, 70 | setup_chat_format, 71 | ) 72 | 73 | 74 | def main(script_args, training_args, model_args): 75 | # Load a pretrained model 76 | model = AutoModelForCausalLM.from_pretrained( 77 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code 78 | ) 79 | ref_model = AutoModelForCausalLM.from_pretrained( 80 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code 81 | ) 82 | 83 | tokenizer = AutoTokenizer.from_pretrained( 84 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code 85 | ) 86 | if tokenizer.pad_token is None: 87 | tokenizer.pad_token = tokenizer.eos_token 88 | 89 | # If we are aligning a base model, we use ChatML as the default template 90 | if tokenizer.chat_template is None: 91 | model, tokenizer = setup_chat_format(model, tokenizer) 92 | 93 | # Load the dataset 94 | dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) 95 | 96 | # Initialize the KTO trainer 97 | trainer = KTOTrainer( 98 | model, 99 | ref_model, 100 | args=training_args, 101 | train_dataset=dataset[script_args.dataset_train_split], 102 | eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, 103 | processing_class=tokenizer, 104 | peft_config=get_peft_config(model_args), 105 | ) 106 | 107 | # Train and push the model to the Hub 108 | trainer.train() 109 | 110 | # Save and push to hub 111 | trainer.save_model(training_args.output_dir) 112 | if training_args.push_to_hub: 113 | trainer.push_to_hub(dataset_name=script_args.dataset_name) 114 | 115 | 116 | def make_parser(subparsers: argparse._SubParsersAction = None): 117 | dataclass_types = (ScriptArguments, KTOConfig, ModelConfig) 118 | if subparsers is not None: 119 | parser = subparsers.add_parser("kto", help="Run the KTO training script", dataclass_types=dataclass_types) 120 | else: 121 | parser = TrlParser(dataclass_types) 122 | return parser 123 | 124 | 125 | if __name__ == "__main__": 126 | parser = make_parser() 127 | script_args, training_args, model_args = parser.parse_args_and_config() 128 | main(script_args, training_args, 
model_args) 129 | -------------------------------------------------------------------------------- /src/open_r1/utils/evaluation.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | from typing import TYPE_CHECKING, Dict, Union 3 | 4 | from .hub import get_gpu_count_for_vllm, get_param_count_from_repo_id 5 | 6 | 7 | if TYPE_CHECKING: 8 | from trl import GRPOConfig, SFTConfig, ModelConfig 9 | 10 | import os 11 | 12 | 13 | # We need a special environment setup to launch vLLM from within Slurm training jobs. 14 | # - Reference code: https://github.com/huggingface/brrr/blob/c55ba3505686d690de24c7ace6487a5c1426c0fd/brrr/lighteval/one_job_runner.py#L105 15 | # - Slack thread: https://huggingface.slack.com/archives/C043JTYE1MJ/p1726566494958269 16 | user_home_directory = os.path.expanduser("~") 17 | VLLM_SLURM_PREFIX = [ 18 | "env", 19 | "-i", 20 | "bash", 21 | "-c", 22 | f"for f in /etc/profile.d/*.sh; do source $f; done; export HOME={user_home_directory}; sbatch ", 23 | ] 24 | 25 | 26 | def register_lighteval_task( 27 | configs: Dict[str, str], eval_suite: str, task_name: str, task_list: str, num_fewshot: int = 0 28 | ): 29 | """Registers a LightEval task configuration. 30 | 31 | - Core tasks can be added from this table: https://github.com/huggingface/lighteval/blob/main/src/lighteval/tasks/tasks_table.jsonl 32 | - Custom tasks that require their own metrics / scripts, should be stored in scripts/evaluation/extended_lighteval_tasks 33 | 34 | Args: 35 | configs (Dict[str, str]): The dictionary to store the task configuration. 36 | eval_suite (str, optional): The evaluation suite. 37 | task_name (str): The name of the task. 38 | task_list (str): The comma-separated list of tasks in the format "extended|{task_name}|{num_fewshot}|0" or "lighteval|{task_name}|{num_fewshot}|0". 39 | num_fewshot (int, optional): The number of few-shot examples. Defaults to 0. 40 | is_custom_task (bool, optional): Whether the task is a custom task. Defaults to False. 
41 | """ 42 | # Format task list in lighteval format 43 | task_list = ",".join(f"{eval_suite}|{task}|{num_fewshot}|0" for task in task_list.split(",")) 44 | configs[task_name] = task_list 45 | 46 | 47 | LIGHTEVAL_TASKS = {} 48 | 49 | register_lighteval_task(LIGHTEVAL_TASKS, "custom", "math_500", "math_500", 0) 50 | register_lighteval_task(LIGHTEVAL_TASKS, "custom", "aime24", "aime24", 0) 51 | register_lighteval_task(LIGHTEVAL_TASKS, "custom", "aime25", "aime25", 0) 52 | register_lighteval_task(LIGHTEVAL_TASKS, "custom", "gpqa", "gpqa:diamond", 0) 53 | register_lighteval_task(LIGHTEVAL_TASKS, "extended", "lcb", "lcb:codegeneration", 0) 54 | register_lighteval_task(LIGHTEVAL_TASKS, "extended", "lcb_v4", "lcb:codegeneration_v4", 0) 55 | 56 | 57 | def get_lighteval_tasks(): 58 | return list(LIGHTEVAL_TASKS.keys()) 59 | 60 | 61 | SUPPORTED_BENCHMARKS = get_lighteval_tasks() 62 | 63 | 64 | def run_lighteval_job( 65 | benchmark: str, training_args: Union["SFTConfig", "GRPOConfig"], model_args: "ModelConfig" 66 | ) -> None: 67 | task_list = LIGHTEVAL_TASKS[benchmark] 68 | model_name = training_args.hub_model_id 69 | model_revision = training_args.hub_model_revision 70 | # For large models >= 30b params or those running the MATH benchmark, we need to shard them across the GPUs to avoid OOM 71 | num_gpus = get_gpu_count_for_vllm(model_name, model_revision) 72 | if get_param_count_from_repo_id(model_name) >= 30_000_000_000: 73 | tensor_parallel = True 74 | else: 75 | num_gpus = 8 76 | tensor_parallel = False 77 | 78 | cmd = VLLM_SLURM_PREFIX.copy() 79 | cmd_args = [ 80 | f"--gres=gpu:{num_gpus}", 81 | f"--job-name=or1_{benchmark}_{model_name.split('/')[-1]}_{model_revision}", 82 | "slurm/evaluate.slurm", 83 | benchmark, 84 | f'"{task_list}"', 85 | model_name, 86 | model_revision, 87 | f"{tensor_parallel}", 88 | f"{model_args.trust_remote_code}", 89 | ] 90 | if training_args.system_prompt is not None: 91 | cmd_args.append(f"--system_prompt={training_args.system_prompt}") 92 | cmd[-1] += " " + " ".join(cmd_args) 93 | subprocess.run(cmd, check=True) 94 | 95 | 96 | def run_benchmark_jobs(training_args: Union["SFTConfig", "GRPOConfig"], model_args: "ModelConfig") -> None: 97 | benchmarks = training_args.benchmarks 98 | if len(benchmarks) == 1 and benchmarks[0] == "all": 99 | benchmarks = get_lighteval_tasks() 100 | # Evaluate on all supported benchmarks. Later we may want to include a `chat` option 101 | # that just evaluates on `ifeval` and `mt_bench` etc. 102 | 103 | for benchmark in benchmarks: 104 | print(f"Launching benchmark `{benchmark}`") 105 | if benchmark in get_lighteval_tasks(): 106 | run_lighteval_job(benchmark, training_args, model_args) 107 | else: 108 | raise ValueError(f"Unknown benchmark {benchmark}") 109 | -------------------------------------------------------------------------------- /src/open_r1/trl/extras/dataset_formatting.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import logging 16 | from typing import Callable, Literal, Optional, Union 17 | 18 | from datasets import Dataset, Value 19 | from transformers import AutoTokenizer 20 | 21 | from ..trainer.utils import ConstantLengthDataset 22 | 23 | 24 | FORMAT_MAPPING = { 25 | "chatml": [{"content": Value(dtype="string", id=None), "role": Value(dtype="string", id=None)}], 26 | "instruction": {"completion": Value(dtype="string", id=None), "prompt": Value(dtype="string", id=None)}, 27 | } 28 | 29 | 30 | def conversations_formatting_function( 31 | tokenizer: AutoTokenizer, messages_field: Literal["messages", "conversations"], tools: Optional[list] = None 32 | ): 33 | r""" 34 | return a callable function that takes in a "messages" dataset and returns a formatted dataset, based on the tokenizer 35 | apply chat template to the dataset along with the schema of the list of functions in the tools list. 36 | """ 37 | 38 | def format_dataset(examples): 39 | if isinstance(examples[messages_field][0], list): 40 | output_texts = [] 41 | for i in range(len(examples[messages_field])): 42 | output_texts.append( 43 | tokenizer.apply_chat_template(examples[messages_field][i], tokenize=False, tools=tools) 44 | ) 45 | return output_texts 46 | else: 47 | return tokenizer.apply_chat_template(examples[messages_field], tokenize=False, tools=tools) 48 | 49 | return format_dataset 50 | 51 | 52 | def instructions_formatting_function(tokenizer: AutoTokenizer): 53 | r""" 54 | return a callable function that takes in an "instructions" dataset and returns a formatted dataset, based on the tokenizer 55 | apply chat template to the dataset 56 | """ 57 | 58 | def format_dataset(examples): 59 | if isinstance(examples["prompt"], list): 60 | output_texts = [] 61 | for i in range(len(examples["prompt"])): 62 | converted_sample = [ 63 | {"role": "user", "content": examples["prompt"][i]}, 64 | {"role": "assistant", "content": examples["completion"][i]}, 65 | ] 66 | output_texts.append(tokenizer.apply_chat_template(converted_sample, tokenize=False)) 67 | return output_texts 68 | else: 69 | converted_sample = [ 70 | {"role": "user", "content": examples["prompt"]}, 71 | {"role": "assistant", "content": examples["completion"]}, 72 | ] 73 | return tokenizer.apply_chat_template(converted_sample, tokenize=False) 74 | 75 | return format_dataset 76 | 77 | 78 | def get_formatting_func_from_dataset( 79 | dataset: Union[Dataset, ConstantLengthDataset], tokenizer: AutoTokenizer, tools: Optional[list] = None 80 | ) -> Optional[Callable]: 81 | r""" 82 | Finds the correct formatting function based on the dataset structure. 
Currently supported datasets are: 83 | - `ChatML` with [{"role": str, "content": str}] 84 | - `instruction` with [{"prompt": str, "completion": str}] 85 | 86 | Args: 87 | dataset (Dataset): User dataset 88 | tokenizer (AutoTokenizer): Tokenizer used for formatting 89 | 90 | Returns: 91 | Callable: Formatting function if the dataset format is supported else None 92 | """ 93 | if isinstance(dataset, Dataset): 94 | if "messages" in dataset.features: 95 | if dataset.features["messages"] == FORMAT_MAPPING["chatml"]: 96 | logging.info("Formatting dataset with chatml format") 97 | return conversations_formatting_function(tokenizer, "messages", tools) 98 | if "conversations" in dataset.features: 99 | if dataset.features["conversations"] == FORMAT_MAPPING["chatml"]: 100 | logging.info("Formatting dataset with chatml format") 101 | return conversations_formatting_function(tokenizer, "conversations", tools) 102 | elif dataset.features == FORMAT_MAPPING["instruction"]: 103 | logging.info("Formatting dataset with instruction format") 104 | return instructions_formatting_function(tokenizer) 105 | 106 | return None 107 | -------------------------------------------------------------------------------- /src/open_r1/trl/import_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import importlib 16 | import os 17 | from itertools import chain 18 | from types import ModuleType 19 | from typing import Any 20 | 21 | from transformers.utils.import_utils import _is_package_available 22 | 23 | 24 | # Use same as transformers.utils.import_utils 25 | _deepspeed_available = _is_package_available("deepspeed") 26 | _diffusers_available = _is_package_available("diffusers") 27 | _llm_blender_available = _is_package_available("llm_blender") 28 | _mergekit_available = _is_package_available("mergekit") 29 | _rich_available = _is_package_available("rich") 30 | _unsloth_available = _is_package_available("unsloth") 31 | _vllm_available = _is_package_available("vllm") 32 | 33 | 34 | def is_deepspeed_available() -> bool: 35 | return _deepspeed_available 36 | 37 | 38 | def is_diffusers_available() -> bool: 39 | return _diffusers_available 40 | 41 | 42 | def is_llm_blender_available() -> bool: 43 | return _llm_blender_available 44 | 45 | 46 | def is_mergekit_available() -> bool: 47 | return _mergekit_available 48 | 49 | 50 | def is_rich_available() -> bool: 51 | return _rich_available 52 | 53 | 54 | def is_unsloth_available() -> bool: 55 | return _unsloth_available 56 | 57 | 58 | def is_vllm_available() -> bool: 59 | return _vllm_available 60 | 61 | 62 | class _LazyModule(ModuleType): 63 | """ 64 | Module class that surfaces all objects but only performs associated imports when the objects are requested. 
65 | """ 66 | 67 | # Very heavily inspired by optuna.integration._IntegrationModule 68 | # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py 69 | def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None): 70 | super().__init__(name) 71 | self._modules = set(import_structure.keys()) 72 | self._class_to_module = {} 73 | for key, values in import_structure.items(): 74 | for value in values: 75 | self._class_to_module[value] = key 76 | # Needed for autocompletion in an IDE 77 | self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values())) 78 | self.__file__ = module_file 79 | self.__spec__ = module_spec 80 | self.__path__ = [os.path.dirname(module_file)] 81 | self._objects = {} if extra_objects is None else extra_objects 82 | self._name = name 83 | self._import_structure = import_structure 84 | 85 | # Needed for autocompletion in an IDE 86 | def __dir__(self): 87 | result = super().__dir__() 88 | # The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether 89 | # they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir. 90 | for attr in self.__all__: 91 | if attr not in result: 92 | result.append(attr) 93 | return result 94 | 95 | def __getattr__(self, name: str) -> Any: 96 | if name in self._objects: 97 | return self._objects[name] 98 | if name in self._modules: 99 | value = self._get_module(name) 100 | elif name in self._class_to_module.keys(): 101 | module = self._get_module(self._class_to_module[name]) 102 | value = getattr(module, name) 103 | else: 104 | raise AttributeError(f"module {self.__name__} has no attribute {name}") 105 | 106 | setattr(self, name, value) 107 | return value 108 | 109 | def _get_module(self, module_name: str): 110 | try: 111 | return importlib.import_module("." + module_name, self.__name__) 112 | except Exception as e: 113 | raise RuntimeError( 114 | f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its" 115 | f" traceback):\n{e}" 116 | ) from e 117 | 118 | def __reduce__(self): 119 | return (self.__class__, (self._name, self.__file__, self._import_structure)) 120 | 121 | 122 | class OptionalDependencyNotAvailable(BaseException): 123 | """Internally used error class for signalling an optional dependency was not found.""" 124 | -------------------------------------------------------------------------------- /src/open_r1/trl/trainer/rloo_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | from dataclasses import dataclass, field 17 | 18 | from ..trainer.utils import OnPolicyConfig 19 | 20 | 21 | @dataclass 22 | class RLOOConfig(OnPolicyConfig): 23 | r""" 24 | Configuration class for the [`RLOOTrainer`]. 
25 | 26 | Using [`~transformers.HfArgumentParser`] we can turn this class into 27 | [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the 28 | command line. 29 | 30 | Parameters: 31 | exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[: -len(".py")]`): 32 | Name of this experiment. 33 | reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`): 34 | Path to the reward model. 35 | num_ppo_epochs (`int`, *optional*, defaults to `4`): 36 | Number of epochs to train. 37 | whiten_rewards (`bool`, *optional*, defaults to `False`): 38 | Whether to whiten the rewards. 39 | kl_coef (`float`, *optional*, defaults to `0.05`): 40 | KL coefficient. 41 | cliprange (`float`, *optional*, defaults to `0.2`): 42 | Clip range. 43 | rloo_k (`int`, *optional*, defaults to `2`): 44 | REINFORCE Leave-One-Out (RLOO) number of online samples per prompt. 45 | normalize_reward (`bool`, *optional*, defaults to `False`): 46 | Whether to normalize rewards. 47 | reward_clip_range (`float`, *optional*, defaults to `10.0`): 48 | Clip range for rewards. 49 | normalize_advantage (`bool`, *optional*, defaults to `False`): 50 | Whether to normalize advantages. 51 | token_level_kl (`bool`, *optional*, defaults to `True`): 52 | Whether to use token-level KL penalty or sequence-level KL penalty. 53 | ds3_gather_for_generation (`bool`, *optional*, defaults to `True`): 54 | This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation, 55 | improving generation speed. However, disabling this option allows training models that exceed the VRAM 56 | capacity of a single GPU, albeit at the cost of slower generation. 57 | """ 58 | 59 | exp_name: str = field( 60 | default=os.path.basename(__file__)[:-3], 61 | metadata={"help": "Name of this experiment."}, 62 | ) 63 | reward_model_path: str = field( 64 | default="EleutherAI/pythia-160m", 65 | metadata={"help": "Path to the reward model."}, 66 | ) 67 | num_ppo_epochs: int = field( 68 | default=4, 69 | metadata={"help": "Number of epochs to train."}, 70 | ) 71 | whiten_rewards: bool = field( 72 | default=False, 73 | metadata={"help": "Whether to whiten the rewards."}, 74 | ) 75 | kl_coef: float = field( 76 | default=0.05, 77 | metadata={"help": "KL coefficient."}, 78 | ) 79 | cliprange: float = field( 80 | default=0.2, 81 | metadata={"help": "Clip range."}, 82 | ) 83 | rloo_k: int = field( 84 | default=2, 85 | metadata={"help": "REINFORCE Leave-One-Out (RLOO) number of online samples per prompt."}, 86 | ) 87 | normalize_reward: bool = field( 88 | default=False, 89 | metadata={"help": "Whether to normalize rewards"}, 90 | ) 91 | reward_clip_range: float = field( 92 | default=10.0, 93 | metadata={"help": "Clip range for rewards"}, 94 | ) 95 | normalize_advantage: bool = field( 96 | default=False, 97 | metadata={"help": "Whether to normalize advantages"}, 98 | ) 99 | token_level_kl: bool = field( 100 | default=False, 101 | metadata={"help": "Whether to use token-level KL penalty or sequence-level KL penalty"}, 102 | ) 103 | ds3_gather_for_generation: bool = field( 104 | default=True, 105 | metadata={ 106 | "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for " 107 | "generation, improving generation speed. However, disabling this option allows training models that " 108 | "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation." 
109 | }, 110 | ) 111 | -------------------------------------------------------------------------------- /src/open_r1/trl/trainer/gkd_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataclasses import dataclass, field 16 | from typing import Any, Optional 17 | 18 | from .sft_config import SFTConfig 19 | 20 | 21 | @dataclass 22 | class GKDConfig(SFTConfig): 23 | """ 24 | Configuration class for [`GKDTrainer`]. 25 | 26 | Args: 27 | temperature (`float`, *optional*, defaults to `0.9`): 28 | Temperature for sampling. The higher the temperature, the more random the completions. 29 | lmbda (`float`, *optional*, defaults to `0.5`): 30 | Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy 31 | student-generated outputs). 32 | beta (`float`, *optional*, defaults to `0.5`): 33 | Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence loss. When 34 | beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL Divergence. 35 | max_new_tokens (`int`, *optional*, defaults to `128`): 36 | Maximum number of tokens to generate per completion. 37 | teacher_model_name_or_path (`str` or `None`, *optional*, defaults to `None`): 38 | Model name or path of the teacher model. If `None`, the teacher model will be the same as the model 39 | being trained. 40 | teacher_model_init_kwargs (`dict[str, Any]]` or `None`, *optional*, defaults to `None`): 41 | Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model 42 | from a string. 43 | disable_dropout (`bool`, *optional*, defaults to `True`): 44 | Whether to disable dropout in the model. 45 | seq_kd (`bool`, *optional*, defaults to `False`): 46 | Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT 47 | on teacher-generated output). 48 | """ 49 | 50 | temperature: float = field( 51 | default=0.9, 52 | metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."}, 53 | ) 54 | lmbda: float = field( 55 | default=0.5, 56 | metadata={ 57 | "help": "Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy " 58 | "student-generated outputs)." 59 | }, 60 | ) 61 | beta: float = field( 62 | default=0.5, 63 | metadata={ 64 | "help": "Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence " 65 | "loss. When beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL " 66 | "Divergence." 
67 |         },
68 |     )
69 |     max_new_tokens: int = field(
70 |         default=128,
71 |         metadata={"help": "Maximum number of tokens to generate per completion."},
72 |     )
73 |     teacher_model_name_or_path: Optional[str] = field(
74 |         default=None,
75 |         metadata={
76 |             "help": "Model name or path of the teacher model. If `None`, the teacher model will be the same as the "
77 |             "model being trained."
78 |         },
79 |     )
80 |     teacher_model_init_kwargs: Optional[dict[str, Any]] = field(
81 |         default=None,
82 |         metadata={
83 |             "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the "
84 |             "teacher model from a string."
85 |         },
86 |     )
87 |     disable_dropout: bool = field(
88 |         default=True,
89 |         metadata={"help": "Whether to disable dropouts in `model`."},
90 |     )
91 |     seq_kd: bool = field(
92 |         default=False,
93 |         metadata={
94 |             "help": "Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised "
95 |             "FT on teacher-generated output)."
96 |         },
97 |     )
98 | 
99 | 
100 |     SMI_reweighting: bool = field(
101 |         default=False,
102 |         metadata={
103 |             "help": "Whether to use SMI reweighting."
104 |         },
105 |     )
106 | 
107 | 
108 |     def __post_init__(self):
109 |         super().__post_init__()
110 |         # check lmbda and beta are in the range [0, 1]
111 |         if self.lmbda < 0.0 or self.lmbda > 1.0:
112 |             raise ValueError("lmbda must be in the range [0.0, 1.0].")
113 |         if self.beta < 0.0 or self.beta > 1.0:
114 |             raise ValueError("beta must be in the range [0.0, 1.0].")
115 | 
--------------------------------------------------------------------------------
/src/open_r1/trl/trainer/ppo_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | 
15 | import os
16 | from dataclasses import dataclass, field
17 | from typing import Optional
18 | 
19 | from ..trainer.utils import OnPolicyConfig
20 | 
21 | 
22 | @dataclass
23 | class PPOConfig(OnPolicyConfig):
24 |     r"""
25 |     Configuration class for the [`PPOTrainer`].
26 | 
27 |     Using [`~transformers.HfArgumentParser`] we can turn this class into
28 |     [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
29 |     command line.
30 | 
31 |     Parameters:
32 |         exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[:-3]`):
33 |             Name of this experiment.
34 |         reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
35 |             Path to the reward model.
36 |         model_adapter_name (`str` or `None`, *optional*, defaults to `None`):
37 |             Name of the train target PEFT adapter, when using LoRA with multiple adapters.
38 |         ref_adapter_name (`str` or `None`, *optional*, defaults to `None`):
39 |             Name of the reference PEFT adapter, when using LoRA with multiple adapters.
40 |         num_ppo_epochs (`int`, *optional*, defaults to `4`):
41 |             Number of epochs to train.
42 | whiten_rewards (`bool`, *optional*, defaults to `False`): 43 | Whether to whiten the rewards. 44 | kl_coef (`float`, *optional*, defaults to `0.05`): 45 | KL coefficient. 46 | cliprange (`float`, *optional*, defaults to `0.2`): 47 | Clip range. 48 | vf_coef (`float`, *optional*, defaults to `0.1`): 49 | Value function coefficient. 50 | cliprange_value (`float`, *optional*, defaults to `0.2`): 51 | Clip range for the value function. 52 | gamma (`float`, *optional*, defaults to `1.0`): 53 | Discount factor. 54 | lam (`float`, *optional*, defaults to `0.95`): 55 | Lambda value for GAE. 56 | ds3_gather_for_generation (`bool`, *optional*, defaults to `True`): 57 | This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation, 58 | improving generation speed. However, disabling this option allows training models that exceed the VRAM 59 | capacity of a single GPU, albeit at the cost of slower generation. 60 | """ 61 | 62 | exp_name: str = field( 63 | default=os.path.basename(__file__)[:-3], 64 | metadata={"help": "Name of this experiment."}, 65 | ) 66 | reward_model_path: str = field( 67 | default="EleutherAI/pythia-160m", 68 | metadata={"help": "Path to the reward model."}, 69 | ) 70 | model_adapter_name: Optional[str] = field( 71 | default=None, 72 | metadata={"help": "Name of the train target PEFT adapter, when using LoRA with multiple adapters."}, 73 | ) 74 | ref_adapter_name: Optional[str] = field( 75 | default=None, 76 | metadata={"help": "Name of the reference PEFT adapter, when using LoRA with multiple adapters."}, 77 | ) 78 | num_ppo_epochs: int = field( 79 | default=4, 80 | metadata={"help": "Number of epochs to train."}, 81 | ) 82 | whiten_rewards: bool = field( 83 | default=False, 84 | metadata={"help": "Whether to whiten the rewards."}, 85 | ) 86 | kl_coef: float = field( 87 | default=0.05, 88 | metadata={"help": "KL coefficient."}, 89 | ) 90 | cliprange: float = field( 91 | default=0.2, 92 | metadata={"help": "Clip range."}, 93 | ) 94 | vf_coef: float = field( 95 | default=0.1, 96 | metadata={"help": "Value function coefficient."}, 97 | ) 98 | cliprange_value: float = field( 99 | default=0.2, 100 | metadata={"help": "Clip range for the value function."}, 101 | ) 102 | gamma: float = field( 103 | default=1.0, 104 | metadata={"help": "Discount factor."}, 105 | ) 106 | lam: float = field( 107 | default=0.95, 108 | metadata={"help": "Lambda value for GAE."}, 109 | ) 110 | ds3_gather_for_generation: bool = field( 111 | default=True, 112 | metadata={ 113 | "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for " 114 | "generation, improving generation speed. However, disabling this option allows training models that " 115 | "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation." 116 | }, 117 | ) 118 | -------------------------------------------------------------------------------- /src/open_r1/trl/scripts/dpo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | # Full training 17 | python trl/scripts/dpo.py \ 18 | --dataset_name trl-lib/ultrafeedback_binarized \ 19 | --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ 20 | --learning_rate 5.0e-7 \ 21 | --num_train_epochs 1 \ 22 | --per_device_train_batch_size 2 \ 23 | --gradient_accumulation_steps 8 \ 24 | --gradient_checkpointing \ 25 | --logging_steps 25 \ 26 | --eval_strategy steps \ 27 | --eval_steps 50 \ 28 | --output_dir Qwen2-0.5B-DPO \ 29 | --no_remove_unused_columns 30 | 31 | # LoRA: 32 | python trl/scripts/dpo.py \ 33 | --dataset_name trl-lib/ultrafeedback_binarized \ 34 | --model_name_or_path Qwen/Qwen2-0.5B-Instruct \ 35 | --learning_rate 5.0e-6 \ 36 | --num_train_epochs 1 \ 37 | --per_device_train_batch_size 2 \ 38 | --gradient_accumulation_steps 8 \ 39 | --gradient_checkpointing \ 40 | --logging_steps 25 \ 41 | --eval_strategy steps \ 42 | --eval_steps 50 \ 43 | --output_dir Qwen2-0.5B-DPO \ 44 | --no_remove_unused_columns \ 45 | --use_peft \ 46 | --lora_r 32 \ 47 | --lora_alpha 16 48 | """ 49 | 50 | import argparse 51 | 52 | import torch 53 | from datasets import load_dataset 54 | from transformers import AutoModelForCausalLM, AutoTokenizer 55 | 56 | from trl import ( 57 | DPOConfig, 58 | DPOTrainer, 59 | ModelConfig, 60 | ScriptArguments, 61 | TrlParser, 62 | get_kbit_device_map, 63 | get_peft_config, 64 | get_quantization_config, 65 | ) 66 | from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE 67 | 68 | 69 | def main(script_args, training_args, model_args): 70 | ################ 71 | # Model & Tokenizer 72 | ################### 73 | torch_dtype = ( 74 | model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) 75 | ) 76 | quantization_config = get_quantization_config(model_args) 77 | model_kwargs = dict( 78 | revision=model_args.model_revision, 79 | attn_implementation=model_args.attn_implementation, 80 | torch_dtype=torch_dtype, 81 | use_cache=False if training_args.gradient_checkpointing else True, 82 | device_map=get_kbit_device_map() if quantization_config is not None else None, 83 | quantization_config=quantization_config, 84 | ) 85 | model = AutoModelForCausalLM.from_pretrained( 86 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs 87 | ) 88 | peft_config = get_peft_config(model_args) 89 | if peft_config is None: 90 | ref_model = AutoModelForCausalLM.from_pretrained( 91 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs 92 | ) 93 | else: 94 | ref_model = None 95 | tokenizer = AutoTokenizer.from_pretrained( 96 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code 97 | ) 98 | if tokenizer.pad_token is None: 99 | tokenizer.pad_token = tokenizer.eos_token 100 | if tokenizer.chat_template is None: 101 | tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE 102 | if script_args.ignore_bias_buffers: 103 | # torch distributed hack 104 | model._ddp_params_and_buffers_to_ignore = [ 105 | name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool 106 | ] 
107 | 108 | ################ 109 | # Dataset 110 | ################ 111 | dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) 112 | 113 | ########## 114 | # Training 115 | ################ 116 | trainer = DPOTrainer( 117 | model, 118 | ref_model, 119 | args=training_args, 120 | train_dataset=dataset[script_args.dataset_train_split], 121 | eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, 122 | processing_class=tokenizer, 123 | peft_config=peft_config, 124 | ) 125 | 126 | trainer.train() 127 | 128 | if training_args.eval_strategy != "no": 129 | metrics = trainer.evaluate() 130 | trainer.log_metrics("eval", metrics) 131 | trainer.save_metrics("eval", metrics) 132 | 133 | # Save and push to hub 134 | trainer.save_model(training_args.output_dir) 135 | if training_args.push_to_hub: 136 | trainer.push_to_hub(dataset_name=script_args.dataset_name) 137 | 138 | 139 | def make_parser(subparsers: argparse._SubParsersAction = None): 140 | dataclass_types = (ScriptArguments, DPOConfig, ModelConfig) 141 | if subparsers is not None: 142 | parser = subparsers.add_parser("dpo", help="Run the DPO training script", dataclass_types=dataclass_types) 143 | else: 144 | parser = TrlParser(dataclass_types) 145 | return parser 146 | 147 | 148 | if __name__ == "__main__": 149 | parser = make_parser() 150 | script_args, training_args, model_args = parser.parse_args_and_config() 151 | main(script_args, training_args, model_args) 152 | -------------------------------------------------------------------------------- /src/open_r1/utils/hub.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2025 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
16 | 17 | import logging 18 | import re 19 | from concurrent.futures import Future 20 | 21 | from transformers import AutoConfig 22 | 23 | from huggingface_hub import ( 24 | create_branch, 25 | create_repo, 26 | get_safetensors_metadata, 27 | list_repo_commits, 28 | list_repo_files, 29 | list_repo_refs, 30 | repo_exists, 31 | upload_folder, 32 | ) 33 | from trl import GRPOConfig, SFTConfig 34 | 35 | 36 | logger = logging.getLogger(__name__) 37 | 38 | 39 | def push_to_hub_revision(training_args: SFTConfig | GRPOConfig, extra_ignore_patterns=[]) -> Future: 40 | """Pushes the model to branch on a Hub repo.""" 41 | 42 | # Create a repo if it doesn't exist yet 43 | repo_url = create_repo(repo_id=training_args.hub_model_id, private=True, exist_ok=True) 44 | # Get initial commit to branch from 45 | initial_commit = list_repo_commits(training_args.hub_model_id)[-1] 46 | # Now create the branch we'll be pushing to 47 | create_branch( 48 | repo_id=training_args.hub_model_id, 49 | branch=training_args.hub_model_revision, 50 | revision=initial_commit.commit_id, 51 | exist_ok=True, 52 | ) 53 | logger.info(f"Created target repo at {repo_url}") 54 | logger.info(f"Pushing to the Hub revision {training_args.hub_model_revision}...") 55 | ignore_patterns = ["checkpoint-*", "*.pth"] 56 | ignore_patterns.extend(extra_ignore_patterns) 57 | future = upload_folder( 58 | repo_id=training_args.hub_model_id, 59 | folder_path=training_args.output_dir, 60 | revision=training_args.hub_model_revision, 61 | commit_message=f"Add {training_args.hub_model_revision} checkpoint", 62 | ignore_patterns=ignore_patterns, 63 | run_as_future=True, 64 | ) 65 | logger.info(f"Pushed to {repo_url} revision {training_args.hub_model_revision} successfully!") 66 | 67 | return future 68 | 69 | 70 | def check_hub_revision_exists(training_args: SFTConfig | GRPOConfig): 71 | """Checks if a given Hub revision exists.""" 72 | if repo_exists(training_args.hub_model_id): 73 | if training_args.push_to_hub_revision is True: 74 | # First check if the revision exists 75 | revisions = [rev.name for rev in list_repo_refs(training_args.hub_model_id).branches] 76 | # If the revision exists, we next check it has a README file 77 | if training_args.hub_model_revision in revisions: 78 | repo_files = list_repo_files( 79 | repo_id=training_args.hub_model_id, revision=training_args.hub_model_revision 80 | ) 81 | if "README.md" in repo_files and training_args.overwrite_hub_revision is False: 82 | raise ValueError( 83 | f"Revision {training_args.hub_model_revision} already exists. " 84 | "Use --overwrite_hub_revision to overwrite it." 
85 | ) 86 | 87 | 88 | def get_param_count_from_repo_id(repo_id: str) -> int: 89 | """Function to get model param counts from safetensors metadata or find patterns like 42m, 1.5b, 0.5m or products like 8x7b in a repo ID.""" 90 | try: 91 | metadata = get_safetensors_metadata(repo_id) 92 | return list(metadata.parameter_count.values())[0] 93 | except Exception: 94 | # Pattern to match products (like 8x7b) and single values (like 42m) 95 | pattern = r"((\d+(\.\d+)?)(x(\d+(\.\d+)?))?)([bm])" 96 | matches = re.findall(pattern, repo_id.lower()) 97 | 98 | param_counts = [] 99 | for full_match, number1, _, _, number2, _, unit in matches: 100 | if number2: # If there's a second number, it's a product 101 | number = float(number1) * float(number2) 102 | else: # Otherwise, it's a single value 103 | number = float(number1) 104 | 105 | if unit == "b": 106 | number *= 1_000_000_000 # Convert to billion 107 | elif unit == "m": 108 | number *= 1_000_000 # Convert to million 109 | 110 | param_counts.append(number) 111 | 112 | if len(param_counts) > 0: 113 | # Return the largest number 114 | return int(max(param_counts)) 115 | else: 116 | # Return -1 if no match found 117 | return -1 118 | 119 | 120 | def get_gpu_count_for_vllm(model_name: str, revision: str = "main", num_gpus: int = 8) -> int: 121 | """vLLM enforces a constraint that the number of attention heads must be divisible by the number of GPUs and 64 must be divisible by the number of GPUs. 122 | This function calculates the number of GPUs to use for decoding based on the number of attention heads in the model. 123 | """ 124 | config = AutoConfig.from_pretrained(model_name, revision=revision, trust_remote_code=True) 125 | # Get number of attention heads 126 | num_heads = config.num_attention_heads 127 | # Reduce num_gpus so that num_heads is divisible by num_gpus and 64 is divisible by num_gpus 128 | while num_heads % num_gpus != 0 or 64 % num_gpus != 0: 129 | logger.info(f"Reducing num_gpus from {num_gpus} to {num_gpus - 1} to make num_heads divisible by num_gpus") 130 | num_gpus -= 1 131 | return num_gpus 132 | -------------------------------------------------------------------------------- /src/open_r1/trl/extras/best_of_n_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
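# Worked example (a sketch, not part of the original sources) for the regex fallback in
# `get_param_count_from_repo_id` defined in hub.py above. The pattern recognises single
# sizes such as "0.5b" and products such as "8x7b" inside a lower-cased repo ID:
#
#   import re
#   pattern = r"((\d+(\.\d+)?)(x(\d+(\.\d+)?))?)([bm])"
#   re.findall(pattern, "qwen/qwen2-0.5b-instruct")     # number1="0.5", unit="b" -> 0.5e9 params
#   re.findall(pattern, "mistralai/mixtral-8x7b-v0.1")  # number1="8", number2="7" -> 56e9 params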
14 | 15 | from typing import Any, Callable, Optional, Union 16 | 17 | import torch 18 | from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast, set_seed 19 | 20 | from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper 21 | 22 | 23 | class BestOfNSampler: 24 | def __init__( 25 | self, 26 | model: PreTrainedModelWrapper, 27 | tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], 28 | queries_to_scores: Callable[[list[str]], list[float]], 29 | length_sampler: Any, 30 | sample_size: int = 4, 31 | seed: Optional[int] = None, 32 | n_candidates: int = 1, 33 | generation_config: Optional[GenerationConfig] = None, 34 | ) -> None: 35 | r""" 36 | Initialize the sampler for best-of-n generation 37 | 38 | Args: 39 | model (`PreTrainedModelWrapper`): 40 | The pretrained model to use for generation 41 | tokenizer (`PreTrainedTokenizer` or `PreTrainedTokenizerFast`): 42 | Tokenizer associated with the pretrained model 43 | queries_to_scores (`Callable[[list[str]], list[float]]`): 44 | Callable that takes a list of generated texts and returns the associated reward scores 45 | length_sampler (`Any`): 46 | Sampler used to sample the length of the generated text 47 | sample_size (`int`): 48 | Number of samples to generate for each query 49 | seed (`int`, *optional*): 50 | Random seed used to control generation 51 | n_candidates (`int`): 52 | Number of candidates to return for each query 53 | generation_config (`GenerationConfig`, *optional*): 54 | Generation config passed to the underlying model's `generate` method. 55 | See `GenerationConfig` (https://huggingface.co/docs/transformers/v4.29.1/en/main_classes/text_generation#transformers.GenerationConfig) for more details 56 | """ 57 | if seed is not None: 58 | set_seed(seed) 59 | 60 | if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)): 61 | raise ValueError( 62 | f"tokenizer must be a PreTrainedTokenizer or PreTrainedTokenizerFast, got {type(tokenizer)}" 63 | ) 64 | if not isinstance(model, (SUPPORTED_ARCHITECTURES)): 65 | raise ValueError( 66 | f"model must be a PreTrainedModelWrapper, got {type(model)} - supported architectures are: {SUPPORTED_ARCHITECTURES}" 67 | ) 68 | 69 | self.model = model 70 | self.tokenizer = tokenizer 71 | 72 | self.queries_to_scores = queries_to_scores 73 | self.length_sampler = length_sampler 74 | self.gen_config = generation_config 75 | self.sample_size = sample_size 76 | self.n_candidates = n_candidates 77 | 78 | def generate( 79 | self, 80 | tokenized_query: Union[list[int], torch.Tensor, list[torch.Tensor], list[list[int]]], 81 | skip_special_tokens: bool = True, 82 | device: Optional[Union[str, torch.device]] = None, 83 | **generation_kwargs, 84 | ) -> list[list[str]]: 85 | r""" 86 | Generate the best of n samples for input queries 87 | 88 | Args: 89 | tokenized_query (`list[int]` or `torch.Tensor` or `list[torch.Tensor]` or `list[int]`): 90 | represents either a single tokenized query (a single tensor or a list of integers) or a batch of tokenized queries (a list of tensors or a list of lists of integers) 91 | skip_special_tokens (`bool`): 92 | Whether to remove the special tokens from the output 93 | device (`str` or `torch.device`, *optional*): 94 | The device on which the model will be loaded 95 | **generation_kwargs (`dict`, *optional*): 96 | Additional keyword arguments passed along to the underlying model's `generate` method. 
97 | This is used to override generation config 98 | 99 | Returns: 100 | list[list[str]]: A list of lists of generated texts 101 | """ 102 | queries = None 103 | 104 | if isinstance(tokenized_query, torch.Tensor) and tokenized_query.ndim == 1: 105 | queries = tokenized_query.unsqueeze(0) 106 | elif isinstance(tokenized_query, list): 107 | element_type = type(tokenized_query[0]) 108 | if element_type is int: 109 | queries = torch.tensor(tokenized_query).unsqueeze(0) 110 | elif element_type is torch.Tensor: 111 | queries = [tensor.reshape((1, -1)) for tensor in tokenized_query] 112 | else: 113 | queries = [torch.tensor(query).reshape((1, -1)) for query in tokenized_query] 114 | 115 | result = [] 116 | 117 | for query in queries: 118 | queries = query.repeat((self.sample_size, 1)) 119 | output = self.model.generate( 120 | queries.to(device), 121 | max_new_tokens=self.length_sampler(), 122 | generation_config=self.gen_config, 123 | **generation_kwargs, 124 | ).squeeze() 125 | output = self.tokenizer.batch_decode(output, skip_special_tokens=skip_special_tokens) 126 | scores = torch.tensor(self.queries_to_scores(output)) 127 | output = [output[i] for i in scores.topk(self.n_candidates).indices] 128 | result.append(output) 129 | 130 | return result 131 | -------------------------------------------------------------------------------- /src/open_r1/trl/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
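# Minimal usage sketch (illustrative, not part of the original sources) for the
# BestOfNSampler defined in extras/best_of_n_sampler.py above; it assumes a value-head
# model and substitutes a toy length-based reward for a real reward model:
#
#   from transformers import AutoTokenizer
#   from trl import AutoModelForCausalLMWithValueHead, BestOfNSampler
#   from trl.core import LengthSampler
#
#   model = AutoModelForCausalLMWithValueHead.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
#   tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
#
#   def queries_to_scores(texts):  # placeholder reward: prefer longer completions
#       return [float(len(t)) for t in texts]
#
#   sampler = BestOfNSampler(
#       model, tokenizer, queries_to_scores,
#       length_sampler=LengthSampler(16, 32), sample_size=4, n_candidates=1,
#   )
#   query_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids[0]
#   best = sampler.generate(query_ids, device="cpu")  # list of lists of decoded strings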
14 | 15 | from typing import TYPE_CHECKING 16 | 17 | from ..import_utils import OptionalDependencyNotAvailable, _LazyModule, is_diffusers_available 18 | 19 | 20 | _import_structure = { 21 | "alignprop_config": ["AlignPropConfig"], 22 | "alignprop_trainer": ["AlignPropTrainer"], 23 | "bco_config": ["BCOConfig"], 24 | "bco_trainer": ["BCOTrainer"], 25 | "callbacks": [ 26 | "LogCompletionsCallback", 27 | "MergeModelCallback", 28 | "RichProgressCallback", 29 | "SyncRefModelCallback", 30 | "WinRateCallback", 31 | ], 32 | "cpo_config": ["CPOConfig"], 33 | "cpo_trainer": ["CPOTrainer"], 34 | "ddpo_config": ["DDPOConfig"], 35 | "dpo_config": ["DPOConfig", "FDivergenceConstants", "FDivergenceType"], 36 | "dpo_trainer": ["DPOTrainer"], 37 | "gkd_config": ["GKDConfig"], 38 | "gkd_trainer": ["GKDTrainer"], 39 | "grpo_config": ["GRPOConfig"], 40 | "grpo_trainer": ["GRPOTrainer"], 41 | "drgrpo_trainer": ["DRGRPOTrainer"], 42 | "iterative_sft_trainer": ["IterativeSFTTrainer"], 43 | "judges": [ 44 | "AllTrueJudge", 45 | "BaseBinaryJudge", 46 | "BaseJudge", 47 | "BasePairwiseJudge", 48 | "BaseRankJudge", 49 | "HfPairwiseJudge", 50 | "OpenAIPairwiseJudge", 51 | "PairRMJudge", 52 | ], 53 | "kto_config": ["KTOConfig"], 54 | "kto_trainer": ["KTOTrainer"], 55 | "model_config": ["ModelConfig"], 56 | "nash_md_config": ["NashMDConfig"], 57 | "nash_md_trainer": ["NashMDTrainer"], 58 | "online_dpo_config": ["OnlineDPOConfig"], 59 | "online_dpo_trainer": ["OnlineDPOTrainer"], 60 | "orpo_config": ["ORPOConfig"], 61 | "orpo_trainer": ["ORPOTrainer"], 62 | "ppo_config": ["PPOConfig"], 63 | "ppo_trainer": ["PPOTrainer"], 64 | "prm_config": ["PRMConfig"], 65 | "prm_trainer": ["PRMTrainer"], 66 | "reward_config": ["RewardConfig"], 67 | "reward_trainer": ["RewardTrainer"], 68 | "rloo_config": ["RLOOConfig"], 69 | "rloo_trainer": ["RLOOTrainer"], 70 | "sft_config": ["SFTConfig"], 71 | "sft_trainer": ["SFTTrainer"], 72 | "utils": [ 73 | "ConstantLengthDataset", 74 | "DataCollatorForCompletionOnlyLM", 75 | "RunningMoments", 76 | "compute_accuracy", 77 | "disable_dropout_in_model", 78 | "empty_cache", 79 | "peft_module_casting_to_bf16", 80 | ], 81 | "xpo_config": ["XPOConfig"], 82 | "xpo_trainer": ["XPOTrainer"], 83 | } 84 | try: 85 | if not is_diffusers_available(): 86 | raise OptionalDependencyNotAvailable() 87 | except OptionalDependencyNotAvailable: 88 | pass 89 | else: 90 | _import_structure["ddpo_trainer"] = ["DDPOTrainer"] 91 | 92 | if TYPE_CHECKING: 93 | from .alignprop_config import AlignPropConfig 94 | from .alignprop_trainer import AlignPropTrainer 95 | from .bco_config import BCOConfig 96 | from .bco_trainer import BCOTrainer 97 | from .callbacks import ( 98 | LogCompletionsCallback, 99 | MergeModelCallback, 100 | RichProgressCallback, 101 | SyncRefModelCallback, 102 | WinRateCallback, 103 | ) 104 | from .cpo_config import CPOConfig 105 | from .cpo_trainer import CPOTrainer 106 | from .ddpo_config import DDPOConfig 107 | from .dpo_config import DPOConfig, FDivergenceConstants, FDivergenceType 108 | from .dpo_trainer import DPOTrainer 109 | from .gkd_config import GKDConfig 110 | from .gkd_trainer import GKDTrainer 111 | from .grpo_config import GRPOConfig 112 | from .grpo_trainer import GRPOTrainer 113 | from .drgrpo_trainer import DRGRPOTrainer 114 | from .iterative_sft_trainer import IterativeSFTTrainer 115 | from .judges import ( 116 | AllTrueJudge, 117 | BaseBinaryJudge, 118 | BaseJudge, 119 | BasePairwiseJudge, 120 | BaseRankJudge, 121 | HfPairwiseJudge, 122 | OpenAIPairwiseJudge, 123 | PairRMJudge, 124 | ) 125 
| from .kto_config import KTOConfig 126 | from .kto_trainer import KTOTrainer 127 | from .model_config import ModelConfig 128 | from .nash_md_config import NashMDConfig 129 | from .nash_md_trainer import NashMDTrainer 130 | from .online_dpo_config import OnlineDPOConfig 131 | from .online_dpo_trainer import OnlineDPOTrainer 132 | from .orpo_config import ORPOConfig 133 | from .orpo_trainer import ORPOTrainer 134 | from .ppo_config import PPOConfig 135 | from .ppo_trainer import PPOTrainer 136 | from .prm_config import PRMConfig 137 | from .prm_trainer import PRMTrainer 138 | from .reward_config import RewardConfig 139 | from .reward_trainer import RewardTrainer 140 | from .rloo_config import RLOOConfig 141 | from .rloo_trainer import RLOOTrainer 142 | from .sft_config import SFTConfig 143 | from .sft_trainer import SFTTrainer 144 | from .utils import ( 145 | ConstantLengthDataset, 146 | DataCollatorForCompletionOnlyLM, 147 | RunningMoments, 148 | compute_accuracy, 149 | disable_dropout_in_model, 150 | empty_cache, 151 | peft_module_casting_to_bf16, 152 | ) 153 | from .xpo_config import XPOConfig 154 | from .xpo_trainer import XPOTrainer 155 | 156 | try: 157 | if not is_diffusers_available(): 158 | raise OptionalDependencyNotAvailable() 159 | except OptionalDependencyNotAvailable: 160 | pass 161 | else: 162 | from .ddpo_trainer import DDPOTrainer 163 | else: 164 | import sys 165 | 166 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) 167 | -------------------------------------------------------------------------------- /src/open_r1/trl/models/sd_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | State dict utilities: utility methods for converting state dicts easily 17 | File copied from diffusers to avoid import issues and make TRL compatible 18 | with most of diffusers versions. 19 | """ 20 | 21 | import enum 22 | 23 | 24 | class StateDictType(enum.Enum): 25 | """ 26 | The mode to use when converting state dicts. 
27 | """ 28 | 29 | DIFFUSERS_OLD = "diffusers_old" 30 | PEFT = "peft" 31 | 32 | 33 | PEFT_TO_DIFFUSERS = { 34 | ".q_proj.lora_B": ".q_proj.lora_linear_layer.up", 35 | ".q_proj.lora_A": ".q_proj.lora_linear_layer.down", 36 | ".k_proj.lora_B": ".k_proj.lora_linear_layer.up", 37 | ".k_proj.lora_A": ".k_proj.lora_linear_layer.down", 38 | ".v_proj.lora_B": ".v_proj.lora_linear_layer.up", 39 | ".v_proj.lora_A": ".v_proj.lora_linear_layer.down", 40 | ".out_proj.lora_B": ".out_proj.lora_linear_layer.up", 41 | ".out_proj.lora_A": ".out_proj.lora_linear_layer.down", 42 | "to_k.lora_A": "to_k.lora.down", 43 | "to_k.lora_B": "to_k.lora.up", 44 | "to_q.lora_A": "to_q.lora.down", 45 | "to_q.lora_B": "to_q.lora.up", 46 | "to_v.lora_A": "to_v.lora.down", 47 | "to_v.lora_B": "to_v.lora.up", 48 | "to_out.0.lora_A": "to_out.0.lora.down", 49 | "to_out.0.lora_B": "to_out.0.lora.up", 50 | } 51 | 52 | DIFFUSERS_OLD_TO_DIFFUSERS = { 53 | ".to_q_lora.up": ".q_proj.lora_linear_layer.up", 54 | ".to_q_lora.down": ".q_proj.lora_linear_layer.down", 55 | ".to_k_lora.up": ".k_proj.lora_linear_layer.up", 56 | ".to_k_lora.down": ".k_proj.lora_linear_layer.down", 57 | ".to_v_lora.up": ".v_proj.lora_linear_layer.up", 58 | ".to_v_lora.down": ".v_proj.lora_linear_layer.down", 59 | ".to_out_lora.up": ".out_proj.lora_linear_layer.up", 60 | ".to_out_lora.down": ".out_proj.lora_linear_layer.down", 61 | } 62 | 63 | DIFFUSERS_STATE_DICT_MAPPINGS = { 64 | StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_DIFFUSERS, 65 | StateDictType.PEFT: PEFT_TO_DIFFUSERS, 66 | } 67 | 68 | KEYS_TO_ALWAYS_REPLACE = { 69 | ".processor.": ".", 70 | } 71 | 72 | 73 | def convert_state_dict(state_dict, mapping): 74 | r""" 75 | Simply iterates over the state dict and replaces the patterns in `mapping` with the corresponding values. 76 | 77 | Args: 78 | state_dict (`dict[str, torch.Tensor]`): 79 | The state dict to convert. 80 | mapping (`dict[str, str]`): 81 | The mapping to use for conversion, the mapping should be a dictionary with the following structure: 82 | - key: the pattern to replace 83 | - value: the pattern to replace with 84 | 85 | Returns: 86 | converted_state_dict (`dict`) 87 | The converted state dict. 88 | """ 89 | converted_state_dict = {} 90 | for k, v in state_dict.items(): 91 | # First, filter out the keys that we always want to replace 92 | for pattern in KEYS_TO_ALWAYS_REPLACE.keys(): 93 | if pattern in k: 94 | new_pattern = KEYS_TO_ALWAYS_REPLACE[pattern] 95 | k = k.replace(pattern, new_pattern) 96 | 97 | for pattern in mapping.keys(): 98 | if pattern in k: 99 | new_pattern = mapping[pattern] 100 | k = k.replace(pattern, new_pattern) 101 | break 102 | converted_state_dict[k] = v 103 | return converted_state_dict 104 | 105 | 106 | def convert_state_dict_to_diffusers(state_dict, original_type=None, **kwargs): 107 | r""" 108 | Converts a state dict to new diffusers format. The state dict can be from previous diffusers format 109 | (`OLD_DIFFUSERS`), or PEFT format (`PEFT`) or new diffusers format (`DIFFUSERS`). In the last case the method will 110 | return the state dict as is. 111 | 112 | The method only supports the conversion from diffusers old, PEFT to diffusers new for now. 113 | 114 | Args: 115 | state_dict (`dict[str, torch.Tensor]`): 116 | The state dict to convert. 117 | original_type (`StateDictType`, *optional*): 118 | The original type of the state dict, if not provided, the method will try to infer it automatically. 119 | kwargs (`dict`, *args*): 120 | Additional arguments to pass to the method. 
121 | 122 | - **adapter_name**: For example, in case of PEFT, some keys will be pre-pended 123 | with the adapter name, therefore needs a special handling. By default PEFT also takes care of that in 124 | `get_peft_model_state_dict` method: 125 | https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92 126 | but we add it here in case we don't want to rely on that method. 127 | """ 128 | peft_adapter_name = kwargs.pop("adapter_name", None) 129 | if peft_adapter_name is not None: 130 | peft_adapter_name = "." + peft_adapter_name 131 | else: 132 | peft_adapter_name = "" 133 | 134 | if original_type is None: 135 | # Old diffusers to PEFT 136 | if any("to_out_lora" in k for k in state_dict.keys()): 137 | original_type = StateDictType.DIFFUSERS_OLD 138 | elif any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict.keys()): 139 | original_type = StateDictType.PEFT 140 | elif any("lora_linear_layer" in k for k in state_dict.keys()): 141 | # nothing to do 142 | return state_dict 143 | else: 144 | raise ValueError("Could not automatically infer state dict type") 145 | 146 | if original_type not in DIFFUSERS_STATE_DICT_MAPPINGS.keys(): 147 | raise ValueError(f"Original type {original_type} is not supported") 148 | 149 | mapping = DIFFUSERS_STATE_DICT_MAPPINGS[original_type] 150 | return convert_state_dict(state_dict, mapping) 151 | -------------------------------------------------------------------------------- /src/open_r1/trl/core.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
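# Illustrative trace (a sketch, not part of the original sources) of the key remapping
# performed by `convert_state_dict_to_diffusers` in sd_utils.py above, using one entry
# of PEFT_TO_DIFFUSERS:
#
#   state_dict = {"unet.attn.to_q.lora_A.weight": ...}
#   # ".lora_A.weight" is present, so original_type is inferred as StateDictType.PEFT;
#   # the mapping rewrites "to_q.lora_A" -> "to_q.lora.down", giving
#   {"unet.attn.to_q.lora.down.weight": ...}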
14 | 15 | import gc 16 | import warnings 17 | from collections.abc import Mapping 18 | from contextlib import contextmanager 19 | from typing import Optional, Union 20 | 21 | import numpy as np 22 | import torch 23 | from transformers import is_torch_npu_available, is_torch_xpu_available 24 | 25 | 26 | def flatten_dict(nested: dict, sep: str = "/") -> dict: 27 | """Flatten dictionary and concatenate nested keys with separator.""" 28 | 29 | def recurse(nest: dict, prefix: str, into: dict) -> None: 30 | for k, v in nest.items(): 31 | if sep in k: 32 | raise ValueError(f"separator '{sep}' not allowed to be in key '{k}'") 33 | if isinstance(v, Mapping): 34 | recurse(v, prefix + k + sep, into) 35 | else: 36 | into[prefix + k] = v 37 | 38 | flat = {} 39 | recurse(nested, "", flat) 40 | return flat 41 | 42 | 43 | def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[bool] = None) -> torch.Tensor: 44 | """Compute mean of tensor with a masked values.""" 45 | if axis is not None: 46 | return (values * mask).sum(axis=axis) / mask.sum(axis=axis) 47 | else: 48 | return (values * mask).sum() / mask.sum() 49 | 50 | 51 | def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor: 52 | """Compute variance of tensor with masked values.""" 53 | mean = masked_mean(values, mask) 54 | centered_values = values - mean 55 | variance = masked_mean(centered_values**2, mask) 56 | if unbiased: 57 | mask_sum = mask.sum() 58 | if mask_sum == 0: 59 | raise ValueError( 60 | "The sum of the mask is zero, which can happen when `mini_batch_size=1`;" 61 | "try increase the `mini_batch_size` or `gradient_accumulation_steps`" 62 | ) 63 | # note that if mask_sum == 1, then there is a division by zero issue 64 | # to avoid it you just need to use a larger minibatch_size 65 | bessel_correction = mask_sum / (mask_sum - 1) 66 | variance = variance * bessel_correction 67 | return variance 68 | 69 | 70 | def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True) -> torch.Tensor: 71 | """Whiten values with masked values.""" 72 | mean, var = masked_mean(values, mask), masked_var(values, mask) 73 | whitened = (values - mean) * torch.rsqrt(var + 1e-8) 74 | if not shift_mean: 75 | whitened += mean 76 | return whitened 77 | 78 | 79 | class LengthSampler: 80 | """ 81 | Samples a length 82 | """ 83 | 84 | def __init__(self, min_value: int, max_value: int): 85 | self.values = list(range(min_value, max_value)) 86 | 87 | def __call__(self) -> int: 88 | return np.random.choice(self.values) 89 | 90 | 91 | class PPODecorators: 92 | optimize_device_cache = False 93 | 94 | @classmethod 95 | @contextmanager 96 | def empty_device_cache(cls): 97 | yield 98 | if cls.optimize_device_cache: 99 | if is_torch_xpu_available(): 100 | gc.collect() 101 | torch.xpu.empty_cache() 102 | gc.collect() 103 | elif is_torch_npu_available(): 104 | gc.collect() 105 | torch.npu.empty_cache() 106 | gc.collect() 107 | elif torch.cuda.is_available(): 108 | gc.collect() 109 | torch.cuda.empty_cache() 110 | gc.collect() 111 | 112 | 113 | def randn_tensor( 114 | shape: Union[tuple, list], 115 | generator: Optional[Union[list[torch.Generator], torch.Generator]] = None, 116 | device: Optional[torch.device] = None, 117 | dtype: Optional[torch.dtype] = None, 118 | layout: Optional[torch.layout] = None, 119 | ) -> torch.Tensor: 120 | """A helper function to create random tensors on the desired `device` with the desired `dtype`. 
When 121 | passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor 122 | is always created on the CPU. 123 | """ 124 | # device on which tensor is created defaults to device 125 | rand_device = device 126 | batch_size = shape[0] 127 | 128 | layout = layout or torch.strided 129 | device = device or torch.device("cpu") 130 | 131 | if generator is not None: 132 | gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type 133 | if gen_device_type != device.type and gen_device_type == "cpu": 134 | rand_device = "cpu" 135 | if device != "mps": 136 | warnings.warn( 137 | f"The passed generator was created on 'cpu' even though a tensor on {device} was expected." 138 | f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably" 139 | f" slighly speed up this function by passing a generator that was created on the {device} device.", 140 | UserWarning, 141 | ) 142 | elif gen_device_type != device.type and gen_device_type == "cuda": 143 | raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.") 144 | 145 | # make sure generator list of length 1 is treated like a non-list 146 | if isinstance(generator, list) and len(generator) == 1: 147 | generator = generator[0] 148 | 149 | if isinstance(generator, list): 150 | shape = (1,) + shape[1:] 151 | latents = [ 152 | torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout) 153 | for i in range(batch_size) 154 | ] 155 | latents = torch.cat(latents, dim=0).to(device) 156 | else: 157 | latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) 158 | 159 | return latents 160 | -------------------------------------------------------------------------------- /src/open_r1/generate.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
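# Example invocation (a sketch; flag names are taken from the argument parser defined
# below, while the dataset and model names are placeholders and assume a vLLM
# OpenAI-compatible server is already running):
#
#   python src/open_r1/generate.py \
#       --hf-dataset AI-MO/NuminaMath-TIR \
#       --hf-dataset-split train \
#       --prompt-column problem \
#       --model deepseek-ai/DeepSeek-R1-Distill-Qwen-7B \
#       --vllm-server-url http://localhost:8000/v1 \
#       --temperature 0.6 \
#       --top-p 0.95 \
#       --max-new-tokens 8192 \
#       --num-generations 4 \
#       --hf-output-dataset <username>/generated-completions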
14 | 15 | from typing import Optional 16 | 17 | from distilabel.llms import OpenAILLM 18 | from distilabel.pipeline import Pipeline 19 | from distilabel.steps import StepResources 20 | from distilabel.steps.tasks import TextGeneration 21 | 22 | 23 | def build_distilabel_pipeline( 24 | model: str, 25 | base_url: str = "http://localhost:8000/v1", 26 | prompt_column: Optional[str] = None, 27 | prompt_template: str = "{{ instruction }}", 28 | temperature: Optional[float] = None, 29 | top_p: Optional[float] = None, 30 | max_new_tokens: int = 8192, 31 | num_generations: int = 1, 32 | input_batch_size: int = 64, 33 | client_replicas: int = 1, 34 | timeout: int = 900, 35 | retries: int = 0, 36 | ) -> Pipeline: 37 | generation_kwargs = {"max_new_tokens": max_new_tokens} 38 | 39 | if temperature is not None: 40 | generation_kwargs["temperature"] = temperature 41 | 42 | if top_p is not None: 43 | generation_kwargs["top_p"] = top_p 44 | 45 | with Pipeline().ray() as pipeline: 46 | TextGeneration( 47 | llm=OpenAILLM( 48 | base_url=base_url, 49 | api_key="something", 50 | model=model, 51 | timeout=timeout, 52 | max_retries=retries, 53 | generation_kwargs=generation_kwargs, 54 | ), 55 | template=prompt_template, 56 | input_mappings={"instruction": prompt_column} if prompt_column is not None else {}, 57 | input_batch_size=input_batch_size, 58 | num_generations=num_generations, 59 | group_generations=True, 60 | resources=StepResources(replicas=client_replicas), 61 | ) 62 | 63 | return pipeline 64 | 65 | 66 | if __name__ == "__main__": 67 | import argparse 68 | 69 | from datasets import load_dataset 70 | 71 | parser = argparse.ArgumentParser(description="Run distilabel pipeline for generating responses with DeepSeek R1") 72 | parser.add_argument( 73 | "--hf-dataset", 74 | type=str, 75 | required=True, 76 | help="HuggingFace dataset to load", 77 | ) 78 | parser.add_argument( 79 | "--hf-dataset-config", 80 | type=str, 81 | required=False, 82 | help="Dataset config to use", 83 | ) 84 | parser.add_argument( 85 | "--hf-dataset-split", 86 | type=str, 87 | default="train", 88 | help="Dataset split to use", 89 | ) 90 | parser.add_argument( 91 | "--prompt-column", 92 | type=str, 93 | default="prompt", 94 | ) 95 | parser.add_argument( 96 | "--prompt-template", 97 | type=str, 98 | default="{{ instruction }}", 99 | help="Template string for formatting prompts.", 100 | ) 101 | parser.add_argument( 102 | "--model", 103 | type=str, 104 | required=True, 105 | help="Model name to use for generation", 106 | ) 107 | parser.add_argument( 108 | "--vllm-server-url", 109 | type=str, 110 | default="http://localhost:8000/v1", 111 | help="URL of the vLLM server", 112 | ) 113 | parser.add_argument( 114 | "--temperature", 115 | type=float, 116 | help="Temperature for generation", 117 | ) 118 | parser.add_argument( 119 | "--top-p", 120 | type=float, 121 | help="Top-p value for generation", 122 | ) 123 | parser.add_argument( 124 | "--max-new-tokens", 125 | type=int, 126 | default=8192, 127 | help="Maximum number of new tokens to generate", 128 | ) 129 | parser.add_argument( 130 | "--num-generations", 131 | type=int, 132 | default=1, 133 | help="Number of generations per problem", 134 | ) 135 | parser.add_argument( 136 | "--input-batch-size", 137 | type=int, 138 | default=64, 139 | help="Batch size for input processing", 140 | ) 141 | parser.add_argument( 142 | "--client-replicas", 143 | type=int, 144 | default=1, 145 | help="Number of client replicas for parallel processing", 146 | ) 147 | parser.add_argument( 148 | "--timeout", 149 | 
type=int, 150 | default=600, 151 | help="Request timeout in seconds (default: 600)", 152 | ) 153 | parser.add_argument( 154 | "--retries", 155 | type=int, 156 | default=0, 157 | help="Number of retries for failed requests (default: 0)", 158 | ) 159 | parser.add_argument( 160 | "--hf-output-dataset", 161 | type=str, 162 | required=False, 163 | help="HuggingFace repo to push results to", 164 | ) 165 | parser.add_argument( 166 | "--private", 167 | action="store_true", 168 | help="Whether to make the output dataset private when pushing to HF Hub", 169 | ) 170 | 171 | args = parser.parse_args() 172 | 173 | print("\nRunning with arguments:") 174 | for arg, value in vars(args).items(): 175 | print(f" {arg}: {value}") 176 | print() 177 | 178 | print(f"Loading '{args.hf_dataset}' (config: {args.hf_dataset_config}, split: {args.hf_dataset_split}) dataset...") 179 | dataset = load_dataset(args.hf_dataset, args.hf_dataset_config, split=args.hf_dataset_split) 180 | print("Dataset loaded!") 181 | 182 | pipeline = build_distilabel_pipeline( 183 | model=args.model, 184 | base_url=args.vllm_server_url, 185 | prompt_template=args.prompt_template, 186 | prompt_column=args.prompt_column, 187 | temperature=args.temperature, 188 | top_p=args.top_p, 189 | max_new_tokens=args.max_new_tokens, 190 | num_generations=args.num_generations, 191 | input_batch_size=args.input_batch_size, 192 | client_replicas=args.client_replicas, 193 | timeout=args.timeout, 194 | retries=args.retries, 195 | ) 196 | 197 | print("Running generation pipeline...") 198 | distiset = pipeline.run( 199 | dataset=dataset, 200 | dataset_batch_size=args.input_batch_size * 1000, 201 | use_cache=False, 202 | ) 203 | print("Generation pipeline finished!") 204 | 205 | if args.hf_output_dataset: 206 | print(f"Pushing resulting dataset to '{args.hf_output_dataset}'...") 207 | distiset.push_to_hub(args.hf_output_dataset, private=args.private) 208 | print("Dataset pushed!") 209 | -------------------------------------------------------------------------------- /src/open_r1/trl/trainer/orpo_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataclasses import dataclass, field 16 | from typing import Any, Optional 17 | 18 | from transformers import TrainingArguments 19 | 20 | 21 | @dataclass 22 | class ORPOConfig(TrainingArguments): 23 | r""" 24 | Configuration class for the [`ORPOTrainer`]. 25 | 26 | Using [`~transformers.HfArgumentParser`] we can turn this class into 27 | [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the 28 | command line. 29 | 30 | Parameters: 31 | learning_rate (`float`, *optional*, defaults to `1e-6`): 32 | Initial learning rate for [`AdamW`] optimizer. The default value replaces that of 33 | [`~transformers.TrainingArguments`]. 
34 | max_length (`int` or `None`, *optional*, defaults to `1024`): 35 | Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want 36 | to use the default data collator. 37 | max_prompt_length (`int` or `None`, *optional*, defaults to `512`): 38 | Maximum length of the prompt. This argument is required if you want to use the default data collator. 39 | max_completion_length (`int` or `None`, *optional*, defaults to `None`): 40 | Maximum length of the completion. This argument is required if you want to use the default data collator 41 | and your model is an encoder-decoder. 42 | beta (`float`, *optional*, defaults to `0.1`): 43 | Parameter controlling the relative ratio loss weight in the ORPO loss. In the [paper](https://huggingface.co/papers/2403.07691), 44 | it is denoted by λ. In the [code](https://github.com/xfactlab/orpo), it is denoted by `alpha`. 45 | disable_dropout (`bool`, *optional*, defaults to `True`): 46 | Whether to disable dropout in the model. 47 | label_pad_token_id (`int`, *optional*, defaults to `-100`): 48 | Label pad token id. This argument is required if you want to use the default data collator. 49 | padding_value (`int` or `None`, *optional*, defaults to `None`): 50 | Padding value to use. If `None`, the padding value of the tokenizer is used. 51 | truncation_mode (`str`, *optional*, defaults to `"keep_end"`): 52 | Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`. 53 | This argument is required if you want to use the default data collator. 54 | generate_during_eval (`bool`, *optional*, defaults to `False`): 55 | If `True`, generates and logs completions from the model to W&B or Comet during evaluation. 56 | is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`): 57 | When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument, 58 | you need to specify if the model returned by the callable is an encoder-decoder model. 59 | model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): 60 | Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a 61 | string. 62 | dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): 63 | Number of processes to use for processing the dataset. 64 | """ 65 | 66 | learning_rate: float = field( 67 | default=1e-6, 68 | metadata={ 69 | "help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of " 70 | "transformers.TrainingArguments." 71 | }, 72 | ) 73 | max_length: Optional[int] = field( 74 | default=1024, 75 | metadata={"help": "Maximum length of the sequences (prompt + completion) in the batch."}, 76 | ) 77 | max_prompt_length: Optional[int] = field( 78 | default=512, 79 | metadata={ 80 | "help": "Maximum length of the prompt. This argument is required if you want to use the default data " 81 | "collator and your model is an encoder-decoder." 82 | }, 83 | ) 84 | max_completion_length: Optional[int] = field( 85 | default=None, 86 | metadata={ 87 | "help": "Maximum length of the completion. This argument is required if you want to use the default data " 88 | "collator and your model is an encoder-decoder." 89 | }, 90 | ) 91 | beta: float = field( 92 | default=0.1, 93 | metadata={ 94 | "help": "Parameter controlling the relative ratio loss weight in the ORPO loss. In the paper, it is " 95 | "denoted by λ." 
96 | }, 97 | ) 98 | disable_dropout: bool = field( 99 | default=True, 100 | metadata={"help": "Whether to disable dropout in the model."}, 101 | ) 102 | label_pad_token_id: int = field( 103 | default=-100, 104 | metadata={ 105 | "help": "Label pad token id. This argument is required if you want to use the default data collator." 106 | }, 107 | ) 108 | padding_value: Optional[int] = field( 109 | default=None, 110 | metadata={"help": "Padding value to use. If `None`, the padding value of the tokenizer is used."}, 111 | ) 112 | truncation_mode: str = field( 113 | default="keep_end", 114 | metadata={ 115 | "help": "Truncation mode to use when the prompt is too long.", 116 | "choices": ["keep_end", "keep_start"], 117 | }, 118 | ) 119 | generate_during_eval: bool = field( 120 | default=False, 121 | metadata={"help": "If `True`, generates and logs completions from the model to W&B during evaluation."}, 122 | ) 123 | is_encoder_decoder: Optional[bool] = field( 124 | default=None, 125 | metadata={ 126 | "help": "When using the `model_init` argument (callable) to instantiate the model instead of the `model` " 127 | "argument, you need to specify if the model returned by the callable is an encoder-decoder model." 128 | }, 129 | ) 130 | model_init_kwargs: Optional[dict[str, Any]] = field( 131 | default=None, 132 | metadata={ 133 | "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model " 134 | "from a string." 135 | }, 136 | ) 137 | dataset_num_proc: Optional[int] = field( 138 | default=None, 139 | metadata={"help": "Number of processes to use for processing the dataset."}, 140 | ) 141 | -------------------------------------------------------------------------------- /src/open_r1/sft.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Supervised fine-tuning script for decoder language models. 
17 | 18 | Usage: 19 | 20 | # One 1 node of 8 x H100s 21 | accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \ 22 | --model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \ 23 | --dataset_name HuggingFaceH4/Bespoke-Stratos-17k \ 24 | --learning_rate 2.0e-5 \ 25 | --num_train_epochs 1 \ 26 | --packing \ 27 | --max_seq_length 4096 \ 28 | --per_device_train_batch_size 2 \ 29 | --gradient_accumulation_steps 8 \ 30 | --gradient_checkpointing \ 31 | --bf16 \ 32 | --logging_steps 5 \ 33 | --eval_strategy steps \ 34 | --eval_steps 100 \ 35 | --output_dir data/Qwen2.5-1.5B-Open-R1-Distill 36 | """ 37 | 38 | import logging 39 | import os 40 | import sys 41 | 42 | import datasets 43 | import torch 44 | import transformers 45 | from datasets import load_dataset 46 | from transformers import set_seed 47 | from transformers.trainer_utils import get_last_checkpoint 48 | 49 | from open_r1.configs import SFTConfig 50 | from open_r1.utils import get_tokenizer 51 | from open_r1.utils.callbacks import get_callbacks 52 | from open_r1.utils.wandb_logging import init_wandb_training 53 | from trl import ( 54 | ModelConfig, 55 | ScriptArguments, 56 | SFTTrainer, 57 | TrlParser, 58 | get_kbit_device_map, 59 | get_peft_config, 60 | get_quantization_config, 61 | ) 62 | 63 | 64 | logger = logging.getLogger(__name__) 65 | 66 | 67 | def main(script_args, training_args, model_args): 68 | # Set seed for reproducibility 69 | set_seed(training_args.seed) 70 | 71 | ############### 72 | # Setup logging 73 | ############### 74 | logging.basicConfig( 75 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 76 | datefmt="%Y-%m-%d %H:%M:%S", 77 | handlers=[logging.StreamHandler(sys.stdout)], 78 | ) 79 | log_level = training_args.get_process_log_level() 80 | logger.setLevel(log_level) 81 | datasets.utils.logging.set_verbosity(log_level) 82 | transformers.utils.logging.set_verbosity(log_level) 83 | transformers.utils.logging.enable_default_handler() 84 | transformers.utils.logging.enable_explicit_format() 85 | 86 | logger.info(f"Model parameters {model_args}") 87 | logger.info(f"Script parameters {script_args}") 88 | logger.info(f"Training parameters {training_args}") 89 | 90 | # Check for last checkpoint 91 | last_checkpoint = None 92 | if os.path.isdir(training_args.output_dir): 93 | last_checkpoint = get_last_checkpoint(training_args.output_dir) 94 | if last_checkpoint is not None and training_args.resume_from_checkpoint is None: 95 | logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.") 96 | 97 | if "wandb" in training_args.report_to: 98 | init_wandb_training(training_args) 99 | 100 | ################ 101 | # Load datasets 102 | ################ 103 | dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) 104 | 105 | ################ 106 | # Load tokenizer 107 | ################ 108 | tokenizer = get_tokenizer(model_args, training_args) 109 | tokenizer.pad_token = tokenizer.eos_token 110 | 111 | ################### 112 | # Model init kwargs 113 | ################### 114 | logger.info("*** Initializing model kwargs ***") 115 | torch_dtype = ( 116 | model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype) 117 | ) 118 | quantization_config = get_quantization_config(model_args) 119 | model_kwargs = dict( 120 | revision=model_args.model_revision, 121 | trust_remote_code=model_args.trust_remote_code, 122 | attn_implementation=model_args.attn_implementation, 123 | torch_dtype=torch_dtype, 
124 | use_cache=False if training_args.gradient_checkpointing else True, 125 | device_map=get_kbit_device_map() if quantization_config is not None else None, 126 | quantization_config=quantization_config, 127 | ) 128 | training_args.model_init_kwargs = model_kwargs 129 | 130 | ############################ 131 | # Initialize the SFT Trainer 132 | ############################ 133 | trainer = SFTTrainer( 134 | model=model_args.model_name_or_path, 135 | args=training_args, 136 | train_dataset=dataset[script_args.dataset_train_split], 137 | eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, 138 | processing_class=tokenizer, 139 | peft_config=get_peft_config(model_args), 140 | callbacks=get_callbacks(training_args, model_args), 141 | ) 142 | 143 | ############### 144 | # Training loop 145 | ############### 146 | logger.info("*** Train ***") 147 | checkpoint = None 148 | if training_args.resume_from_checkpoint is not None: 149 | checkpoint = training_args.resume_from_checkpoint 150 | elif last_checkpoint is not None: 151 | checkpoint = last_checkpoint 152 | train_result = trainer.train(resume_from_checkpoint=checkpoint) 153 | metrics = train_result.metrics 154 | metrics["train_samples"] = len(dataset[script_args.dataset_train_split]) 155 | trainer.log_metrics("train", metrics) 156 | trainer.save_metrics("train", metrics) 157 | trainer.save_state() 158 | 159 | ################################## 160 | # Save model and create model card 161 | ################################## 162 | logger.info("*** Save model ***") 163 | trainer.save_model(training_args.output_dir) 164 | logger.info(f"Model saved to {training_args.output_dir}") 165 | 166 | # Save everything else on main process 167 | kwargs = { 168 | "dataset_name": script_args.dataset_name, 169 | "tags": ["open-r1"], 170 | } 171 | if trainer.accelerator.is_main_process: 172 | trainer.create_model_card(**kwargs) 173 | # Restore k,v cache for fast inference 174 | trainer.model.config.use_cache = True 175 | trainer.model.config.save_pretrained(training_args.output_dir) 176 | 177 | ########## 178 | # Evaluate 179 | ########## 180 | if training_args.do_eval: 181 | logger.info("*** Evaluate ***") 182 | metrics = trainer.evaluate() 183 | metrics["eval_samples"] = len(dataset[script_args.dataset_test_split]) 184 | trainer.log_metrics("eval", metrics) 185 | trainer.save_metrics("eval", metrics) 186 | 187 | ############# 188 | # push to hub 189 | ############# 190 | if training_args.push_to_hub: 191 | logger.info("Pushing to hub...") 192 | trainer.push_to_hub(**kwargs) 193 | 194 | 195 | if __name__ == "__main__": 196 | parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig)) 197 | script_args, training_args, model_args = parser.parse_args_and_config() 198 | main(script_args, training_args, model_args) 199 | -------------------------------------------------------------------------------- /src/open_r1/trl/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | __version__ = "0.15.2" 16 | 17 | from typing import TYPE_CHECKING 18 | 19 | from .import_utils import OptionalDependencyNotAvailable, _LazyModule, is_diffusers_available 20 | 21 | 22 | _import_structure = { 23 | "scripts": ["init_zero_verbose", "ScriptArguments", "TrlParser"], 24 | "data_utils": [ 25 | "apply_chat_template", 26 | "extract_prompt", 27 | "is_conversational", 28 | "maybe_apply_chat_template", 29 | "maybe_convert_to_chatml", 30 | "maybe_extract_prompt", 31 | "maybe_unpair_preference_dataset", 32 | "pack_examples", 33 | "unpair_preference_dataset", 34 | ], 35 | "environment": ["TextEnvironment", "TextHistory"], 36 | "extras": ["BestOfNSampler"], 37 | "import_utils": [ 38 | "is_deepspeed_available", 39 | "is_diffusers_available", 40 | "is_llm_blender_available", 41 | "is_mergekit_available", 42 | "is_rich_available", 43 | "is_unsloth_available", 44 | "is_vllm_available", 45 | ], 46 | "models": [ 47 | "SUPPORTED_ARCHITECTURES", 48 | "AutoModelForCausalLMWithValueHead", 49 | "AutoModelForSeq2SeqLMWithValueHead", 50 | "PreTrainedModelWrapper", 51 | "create_reference_model", 52 | "setup_chat_format", 53 | ], 54 | "trainer": [ 55 | "AlignPropConfig", 56 | "AlignPropTrainer", 57 | "AllTrueJudge", 58 | "BaseBinaryJudge", 59 | "BaseJudge", 60 | "BasePairwiseJudge", 61 | "BaseRankJudge", 62 | "BCOConfig", 63 | "BCOTrainer", 64 | "CPOConfig", 65 | "CPOTrainer", 66 | "DataCollatorForCompletionOnlyLM", 67 | "DPOConfig", 68 | "DPOTrainer", 69 | "FDivergenceConstants", 70 | "FDivergenceType", 71 | "GKDConfig", 72 | "GKDTrainer", 73 | "GRPOConfig", 74 | "GRPOTrainer", 75 | "DRGRPOTrainer", 76 | "HfPairwiseJudge", 77 | "IterativeSFTTrainer", 78 | "KTOConfig", 79 | "KTOTrainer", 80 | "LogCompletionsCallback", 81 | "MergeModelCallback", 82 | "ModelConfig", 83 | "NashMDConfig", 84 | "NashMDTrainer", 85 | "OnlineDPOConfig", 86 | "OnlineDPOTrainer", 87 | "OpenAIPairwiseJudge", 88 | "ORPOConfig", 89 | "ORPOTrainer", 90 | "PairRMJudge", 91 | "PPOConfig", 92 | "PPOTrainer", 93 | "PRMConfig", 94 | "PRMTrainer", 95 | "RewardConfig", 96 | "RewardTrainer", 97 | "RLOOConfig", 98 | "RLOOTrainer", 99 | "SFTConfig", 100 | "SFTTrainer", 101 | "WinRateCallback", 102 | "XPOConfig", 103 | "XPOTrainer", 104 | ], 105 | "trainer.callbacks": ["MergeModelCallback", "RichProgressCallback", "SyncRefModelCallback"], 106 | "trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config"], 107 | } 108 | 109 | try: 110 | if not is_diffusers_available(): 111 | raise OptionalDependencyNotAvailable() 112 | except OptionalDependencyNotAvailable: 113 | pass 114 | else: 115 | _import_structure["models"].extend( 116 | [ 117 | "DDPOPipelineOutput", 118 | "DDPOSchedulerOutput", 119 | "DDPOStableDiffusionPipeline", 120 | "DefaultDDPOStableDiffusionPipeline", 121 | ] 122 | ) 123 | _import_structure["trainer"].extend(["DDPOConfig", "DDPOTrainer"]) 124 | 125 | if TYPE_CHECKING: 126 | from .data_utils import ( 127 | apply_chat_template, 128 | extract_prompt, 129 | is_conversational, 130 | maybe_apply_chat_template, 131 | maybe_convert_to_chatml, 132 | 
maybe_extract_prompt, 133 | maybe_unpair_preference_dataset, 134 | pack_examples, 135 | unpair_preference_dataset, 136 | ) 137 | from .environment import TextEnvironment, TextHistory 138 | from .extras import BestOfNSampler 139 | from .import_utils import ( 140 | is_deepspeed_available, 141 | is_diffusers_available, 142 | is_llm_blender_available, 143 | is_mergekit_available, 144 | is_rich_available, 145 | is_unsloth_available, 146 | is_vllm_available, 147 | ) 148 | from .models import ( 149 | SUPPORTED_ARCHITECTURES, 150 | AutoModelForCausalLMWithValueHead, 151 | AutoModelForSeq2SeqLMWithValueHead, 152 | PreTrainedModelWrapper, 153 | create_reference_model, 154 | setup_chat_format, 155 | ) 156 | from .scripts import ScriptArguments, TrlParser, init_zero_verbose 157 | from .trainer import ( 158 | AlignPropConfig, 159 | AlignPropTrainer, 160 | AllTrueJudge, 161 | BaseBinaryJudge, 162 | BaseJudge, 163 | BasePairwiseJudge, 164 | BaseRankJudge, 165 | BCOConfig, 166 | BCOTrainer, 167 | CPOConfig, 168 | CPOTrainer, 169 | DataCollatorForCompletionOnlyLM, 170 | DPOConfig, 171 | DPOTrainer, 172 | FDivergenceConstants, 173 | FDivergenceType, 174 | GKDConfig, 175 | GKDTrainer, 176 | GRPOConfig, 177 | GRPOTrainer, 178 | DRGRPOTrainer, 179 | HfPairwiseJudge, 180 | IterativeSFTTrainer, 181 | KTOConfig, 182 | KTOTrainer, 183 | LogCompletionsCallback, 184 | MergeModelCallback, 185 | ModelConfig, 186 | NashMDConfig, 187 | NashMDTrainer, 188 | OnlineDPOConfig, 189 | OnlineDPOTrainer, 190 | OpenAIPairwiseJudge, 191 | ORPOConfig, 192 | ORPOTrainer, 193 | PairRMJudge, 194 | PPOConfig, 195 | PPOTrainer, 196 | PRMConfig, 197 | PRMTrainer, 198 | RewardConfig, 199 | RewardTrainer, 200 | RLOOConfig, 201 | RLOOTrainer, 202 | SFTConfig, 203 | SFTTrainer, 204 | WinRateCallback, 205 | XPOConfig, 206 | XPOTrainer, 207 | ) 208 | from .trainer.callbacks import RichProgressCallback, SyncRefModelCallback 209 | from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config 210 | 211 | try: 212 | if not is_diffusers_available(): 213 | raise OptionalDependencyNotAvailable() 214 | except OptionalDependencyNotAvailable: 215 | pass 216 | else: 217 | from .models import ( 218 | DDPOPipelineOutput, 219 | DDPOSchedulerOutput, 220 | DDPOStableDiffusionPipeline, 221 | DefaultDDPOStableDiffusionPipeline, 222 | ) 223 | from .trainer import DDPOConfig, DDPOTrainer 224 | 225 | else: 226 | import sys 227 | 228 | sys.modules[__name__] = _LazyModule( 229 | __name__, 230 | globals()["__file__"], 231 | _import_structure, 232 | module_spec=__spec__, 233 | extra_objects={"__version__": __version__}, 234 | ) 235 | -------------------------------------------------------------------------------- /src/open_r1/trl/trainer/sft_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
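# Note on the lazy-import pattern used in trl/__init__.py above (illustrative, not part
# of the original sources): the package replaces its own module object with a
# `_LazyModule`, so `import trl` stays cheap and a heavy submodule is only imported the
# first time one of its exported names is accessed:
#
#   import trl                        # builds the lazy module from `_import_structure`
#   trainer_cls = trl.DRGRPOTrainer   # first access imports trainer/drgrpo_trainer.py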
14 | 15 | import warnings 16 | from dataclasses import dataclass, field 17 | from typing import Any, Optional 18 | 19 | from transformers import TrainingArguments 20 | 21 | 22 | @dataclass 23 | class SFTConfig(TrainingArguments): 24 | r""" 25 | Configuration class for the [`SFTTrainer`]. 26 | 27 | Only the parameters specific to SFT training are listed here. For details on other parameters, refer to the 28 | [`~transformers.TrainingArguments`] documentation. 29 | 30 | Using [`~transformers.HfArgumentParser`] we can turn this class into 31 | [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the 32 | command line. 33 | 34 | Parameters: 35 | > Parameters that control the model 36 | 37 | model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): 38 | Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model` 39 | argument of the [`SFTTrainer`] is provided as a string. 40 | use_liger (`bool`, *optional*, defaults to `False`): 41 | Monkey patch the model with Liger kernels to increase throughput and reduce memory usage. 42 | 43 | > Parameters that control the data preprocessing 44 | 45 | dataset_text_field (`str`, *optional*, defaults to `"text"`): 46 | Name of the column that contains text data in the dataset. 47 | dataset_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): 48 | Dictionary of optional keyword arguments for the dataset preparation. The only supported key is 49 | `skip_prepare_dataset`. 50 | dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): 51 | Number of processes to use for processing the dataset. 52 | max_seq_length (`int` or `None`, *optional*, defaults to `1024`): 53 | Maximum length of the tokenized sequence. Sequences longer than `max_seq_length` are truncated from the 54 | right. 55 | If `None`, no truncation is applied. When packing is enabled, this value sets the sequence length. 56 | packing (`bool`, *optional*, defaults to `False`): 57 | Whether to pack multiple sequences into a fixed-length format. Uses `max_seq_length` to define sequence 58 | length. 59 | eval_packing (`bool` or `None`, *optional*, defaults to `None`): 60 | Whether to pack the eval dataset. If `None`, uses the same value as `packing`. 61 | 62 | > Parameters that control the training 63 | 64 | learning_rate (`float`, *optional*, defaults to `2e-5`): 65 | Initial learning rate for [`AdamW`] optimizer. The default value replaces that of 66 | [`~transformers.TrainingArguments`]. 67 | """ 68 | 69 | # Parameters that control the model 70 | model_init_kwargs: Optional[dict[str, Any]] = field( 71 | default=None, 72 | metadata={ 73 | "help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of " 74 | "the `SFTTrainer` is provided as a string." 75 | }, 76 | ) 77 | use_liger: bool = field( 78 | default=False, 79 | metadata={"help": "Monkey patch the model with Liger kernels to increase throughput and reduce memory usage."}, 80 | ) 81 | 82 | # Parameters that control the data preprocessing 83 | dataset_text_field: str = field( 84 | default="text", 85 | metadata={"help": "Name of the column that contains text data in the dataset."}, 86 | ) 87 | dataset_kwargs: Optional[dict[str, Any]] = field( 88 | default=None, 89 | metadata={ 90 | "help": "Dictionary of optional keyword arguments for the dataset preparation. The only supported key is " 91 | "`skip_prepare_dataset`." 
92 | }, 93 | ) 94 | dataset_num_proc: Optional[int] = field( 95 | default=None, 96 | metadata={"help": "Number of processes to use for processing the dataset."}, 97 | ) 98 | max_seq_length: Optional[int] = field( 99 | default=1024, 100 | metadata={ 101 | "help": "Maximum length of the tokenized sequence. Sequences longer than `max_seq_length` are truncated " 102 | "from the right. If `None`, no truncation is applied. When packing is enabled, this value sets the " 103 | "sequence length." 104 | }, 105 | ) 106 | packing: bool = field( 107 | default=False, 108 | metadata={ 109 | "help": "Whether to pack multiple sequences into a fixed-length format. Uses `max_seq_length` to " 110 | "define sequence length." 111 | }, 112 | ) 113 | eval_packing: Optional[bool] = field( 114 | default=None, 115 | metadata={"help": "Whether to pack the eval dataset. If `None`, uses the same value as `packing`."}, 116 | ) 117 | 118 | # Parameters that control the training 119 | learning_rate: float = field( 120 | default=2.0e-5, 121 | metadata={ 122 | "help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of " 123 | "`TrainingArguments`." 124 | }, 125 | ) 126 | 127 | # Deprecated parameters 128 | dataset_batch_size: Optional[int] = field( 129 | default=None, 130 | metadata={"help": "Deprecated. You can safely remove this parameter from your configuration."}, 131 | ) 132 | num_of_sequences: Optional[int] = field( 133 | default=None, 134 | metadata={ 135 | "help": "Deprecated. Use `max_seq_length` instead, which specifies the maximum length of the tokenized " 136 | "sequence, unlike `num_of_sequences`, which referred to string sequences." 137 | }, 138 | ) 139 | chars_per_token: Optional[float] = field( 140 | default=None, 141 | metadata={"help": "Deprecated. If you want to customize the packing length, use `max_seq_length`."}, 142 | ) 143 | 144 | def __post_init__(self): 145 | super().__post_init__() 146 | 147 | if self.dataset_batch_size is not None: 148 | warnings.warn( 149 | "`dataset_batch_size` is deprecated and will be removed in version 0.18.0. You can safely remove this " 150 | "parameter from your configuration.", 151 | DeprecationWarning, 152 | ) 153 | 154 | if self.num_of_sequences is not None: 155 | warnings.warn( 156 | "`num_of_sequences` is deprecated and will be removed in version 0.18.0. Use `max_seq_length` instead, " 157 | "which specifies the maximum length of the tokenized sequence, unlike `num_of_sequences`, which " 158 | "referred to string sequences.", 159 | DeprecationWarning, 160 | ) 161 | 162 | if self.chars_per_token is not None: 163 | warnings.warn( 164 | "`chars_per_token` is deprecated and will be removed in version 0.18.0. If you want to customize the " 165 | "packing length, use `max_seq_length`.", 166 | DeprecationWarning, 167 | ) 168 | -------------------------------------------------------------------------------- /src/open_r1/trl/trainer/model_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataclasses import dataclass, field 16 | from typing import Optional 17 | 18 | 19 | @dataclass 20 | class ModelConfig: 21 | """ 22 | Configuration class for the models. 23 | 24 | Using [`~transformers.HfArgumentParser`] we can turn this class into 25 | [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the 26 | command line. 27 | 28 | Parameters: 29 | model_name_or_path (`str` or `None`, *optional*, defaults to `None`): 30 | Model checkpoint for weights initialization. 31 | model_revision (`str`, *optional*, defaults to `"main"`): 32 | Specific model version to use. It can be a branch name, a tag name, or a commit id. 33 | torch_dtype (`Literal["auto", "bfloat16", "float16", "float32"]` or `None`, *optional*, defaults to `None`): 34 | Override the default `torch.dtype` and load the model under this dtype. Possible values are 35 | 36 | - `"bfloat16"`: `torch.bfloat16` 37 | - `"float16"`: `torch.float16` 38 | - `"float32"`: `torch.float32` 39 | - `"auto"`: Automatically derive the dtype from the model's weights. 40 | 41 | trust_remote_code (`bool`, *optional*, defaults to `False`): 42 | Whether to allow for custom models defined on the Hub in their own modeling files. This option should only 43 | be set to `True` for repositories you trust and in which you have read the code, as it will execute code 44 | present on the Hub on your local machine. 45 | attn_implementation (`str` or `None`, *optional*, defaults to `None`): 46 | Which attention implementation to use. You can run `--attn_implementation=flash_attention_2`, in which case 47 | you must install this manually by running `pip install flash-attn --no-build-isolation`. 48 | use_peft (`bool`, *optional*, defaults to `False`): 49 | Whether to use PEFT for training. 50 | lora_r (`int`, *optional*, defaults to `16`): 51 | LoRA R value. 52 | lora_alpha (`int`, *optional*, defaults to `32`): 53 | LoRA alpha. 54 | lora_dropout (`float`, *optional*, defaults to `0.05`): 55 | LoRA dropout. 56 | lora_target_modules (`Union[str, list[str]]` or `None`, *optional*, defaults to `None`): 57 | LoRA target modules. 58 | lora_modules_to_save (`list[str]` or `None`, *optional*, defaults to `None`): 59 | Model layers to unfreeze & train. 60 | lora_task_type (`str`, *optional*, defaults to `"CAUSAL_LM"`): 61 | Task type to pass for LoRA (use `"SEQ_CLS"` for reward modeling). 62 | use_rslora (`bool`, *optional*, defaults to `False`): 63 | Whether to use Rank-Stabilized LoRA, which sets the adapter scaling factor to `lora_alpha/√r`, instead of 64 | the original default value of `lora_alpha/r`. 65 | load_in_8bit (`bool`, *optional*, defaults to `False`): 66 | Whether to use 8 bit precision for the base model. Works only with LoRA. 67 | load_in_4bit (`bool`, *optional*, defaults to `False`): 68 | Whether to use 4 bit precision for the base model. Works only with LoRA. 69 | bnb_4bit_quant_type (`str`, *optional*, defaults to `"nf4"`): 70 | Quantization type (`"fp4"` or `"nf4"`). 71 | use_bnb_nested_quant (`bool`, *optional*, defaults to `False`): 72 | Whether to use nested quantization. 73 | """ 74 | 75 | model_name_or_path: Optional[str] = field( 76 | default=None, 77 | metadata={"help": "Model checkpoint for weights initialization."}, 78 | ) 79 | model_revision: str = field( 80 | default="main", 81 | metadata={"help": "Specific model version to use. 
It can be a branch name, a tag name, or a commit id."}, 82 | ) 83 | torch_dtype: Optional[str] = field( 84 | default=None, 85 | metadata={ 86 | "help": "Override the default `torch.dtype` and load the model under this dtype.", 87 | "choices": ["auto", "bfloat16", "float16", "float32"], 88 | }, 89 | ) 90 | trust_remote_code: bool = field( 91 | default=False, 92 | metadata={ 93 | "help": "Whether to allow for custom models defined on the Hub in their own modeling files. This option " 94 | "should only be set to `True` for repositories you trust and in which you have read the code, as it will " 95 | "execute code present on the Hub on your local machine." 96 | }, 97 | ) 98 | attn_implementation: Optional[str] = field( 99 | default=None, 100 | metadata={ 101 | "help": "Which attention implementation to use. You can run `--attn_implementation=flash_attention_2`, in " 102 | "which case you must install this manually by running `pip install flash-attn --no-build-isolation`." 103 | }, 104 | ) 105 | use_peft: bool = field( 106 | default=False, 107 | metadata={"help": "Whether to use PEFT for training."}, 108 | ) 109 | lora_r: int = field( 110 | default=16, 111 | metadata={"help": "LoRA R value."}, 112 | ) 113 | lora_alpha: int = field( 114 | default=32, 115 | metadata={"help": "LoRA alpha."}, 116 | ) 117 | lora_dropout: float = field( 118 | default=0.05, 119 | metadata={"help": "LoRA dropout."}, 120 | ) 121 | lora_target_modules: Optional[list[str]] = field( 122 | default=None, 123 | metadata={"help": "LoRA target modules."}, 124 | ) 125 | lora_modules_to_save: Optional[list[str]] = field( 126 | default=None, 127 | metadata={"help": "Model layers to unfreeze & train."}, 128 | ) 129 | lora_task_type: str = field( 130 | default="CAUSAL_LM", 131 | metadata={"help": "Task type to pass for LoRA (use 'SEQ_CLS' for reward modeling)."}, 132 | ) 133 | use_rslora: bool = field( 134 | default=False, 135 | metadata={ 136 | "help": "Whether to use Rank-Stabilized LoRA, which sets the adapter scaling factor to `lora_alpha/√r`, " 137 | "instead of the original default value of `lora_alpha/r`." 138 | }, 139 | ) 140 | load_in_8bit: bool = field( 141 | default=False, 142 | metadata={"help": "Whether to use 8 bit precision for the base model. Works only with LoRA."}, 143 | ) 144 | load_in_4bit: bool = field( 145 | default=False, 146 | metadata={"help": "Whether to use 4 bit precision for the base model. Works only with LoRA."}, 147 | ) 148 | bnb_4bit_quant_type: str = field( 149 | default="nf4", 150 | metadata={"help": "Quantization type.", "choices": ["fp4", "nf4"]}, 151 | ) 152 | use_bnb_nested_quant: bool = field( 153 | default=False, 154 | metadata={"help": "Whether to use nested quantization."}, 155 | ) 156 | 157 | def __post_init__(self): 158 | if self.load_in_8bit and self.load_in_4bit: 159 | raise ValueError("You can't use 8 bit and 4 bit precision at the same time") 160 | 161 | if hasattr(self.lora_target_modules, "__len__") and len(self.lora_target_modules) == 1: 162 | self.lora_target_modules = self.lora_target_modules[0] 163 | -------------------------------------------------------------------------------- /src/open_r1/trl/trainer/online_dpo_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataclasses import dataclass, field 16 | from typing import Optional 17 | 18 | from transformers import TrainingArguments 19 | 20 | 21 | @dataclass 22 | class OnlineDPOConfig(TrainingArguments): 23 | r""" 24 | Configuration class for the [`OnlineDPOTrainer`]. 25 | 26 | Using [`~transformers.HfArgumentParser`] we can turn this class into 27 | [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the 28 | command line. 29 | 30 | Parameters: 31 | learning_rate (`float`, *optional*, defaults to `5e-7`): 32 | Initial learning rate for [`AdamW`] optimizer. The default value replaces that of 33 | [`~transformers.TrainingArguments`]. 34 | reward_model_path (`str` or `None`, *optional*, defaults to `None`): 35 | Path to the reward model. Either `judge` or `reward_model_path` must be set, but not both. 36 | judge (`str` or `None`, *optional*, defaults to `None`): 37 | Name of the judge to use. Either `judge` or `reward_model_path` must be set, but not both. 38 | max_new_tokens (`int`, *optional*, defaults to `64`): 39 | Maximum number of tokens to generate per completion. 40 | max_length (`int`, *optional*, defaults to `512`): 41 | Maximum total length of the sequence (prompt + completion) used to compute log probabilities. If the 42 | sequence exceeds this limit, the leftmost tokens will be truncated to preserve as much of the completion as 43 | possible. 44 | temperature (`float`, *optional*, defaults to `0.9`): 45 | Temperature for sampling. The higher the temperature, the more random the completions. 46 | missing_eos_penalty (`float` or `None`, *optional*, defaults to `None`): 47 | Penalty applied to the score when the model fails to generate an EOS token. This is useful to encourage 48 | the model to generate completions shorter than the maximum length (`max_new_tokens`). The penalty must be a positive 49 | value. 50 | beta (`float` or `list[float]`, *optional*, defaults to `0.1`): 51 | Parameter controlling the deviation from the reference model. Higher β means less deviation from the 52 | reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in 53 | the [paper](https://huggingface.co/papers/2310.12036). If a list of floats is provided then the β is 54 | selected for each new epoch and the last β is used for the rest of the epochs. 55 | loss_type (`str`, *optional*, defaults to `"sigmoid"`): 56 | Type of loss to use. Possible values are: 57 | 58 | - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper. 59 | - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper. 60 | 61 | dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): 62 | Number of processes to use for processing the dataset. 63 | disable_dropout (`bool`, *optional*, defaults to `True`): 64 | Whether to disable dropout in the model and reference model. 65 | use_vllm (`bool`, *optional*, defaults to `False`): 66 | Whether to use vLLM for generating completions.
Requires vLLM to be installed (`pip install vllm`). 67 | ds3_gather_for_generation (`bool`, *optional*, defaults to `True`): 68 | This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation, 69 | improving generation speed. However, disabling this option allows training models that exceed the VRAM 70 | capacity of a single GPU, albeit at the cost of slower generation. 71 | """ 72 | 73 | learning_rate: float = field( 74 | default=5e-7, 75 | metadata={ 76 | "help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of " 77 | "`transformers.TrainingArguments`." 78 | }, 79 | ) 80 | reward_model_path: Optional[str] = field( 81 | default=None, 82 | metadata={ 83 | "help": "Path to the reward model. Either `judge` or `reward_model_path` must be set, but not both." 84 | }, 85 | ) 86 | judge: Optional[str] = field( 87 | default=None, 88 | metadata={ 89 | "help": "Name of the judge to use. Either `judge` or `reward_model_path` must be set, but not both." 90 | }, 91 | ) 92 | max_new_tokens: int = field( 93 | default=64, 94 | metadata={"help": "Maximum number of tokens to generate per completion."}, 95 | ) 96 | max_length: int = field( 97 | default=512, 98 | metadata={ 99 | "help": "Maximum total length of the sequence (prompt + completion) used to compute log probabilities. If " 100 | "the sequence exceeds this limit, the leftmost tokens will be truncated to preserve as much of the " 101 | "completion as possible." 102 | }, 103 | ) 104 | temperature: float = field( 105 | default=0.9, 106 | metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."}, 107 | ) 108 | missing_eos_penalty: Optional[float] = field( 109 | default=None, 110 | metadata={ 111 | "help": "Penalty applied to the score when the model fails to generate an EOS token. This is useful to " 112 | "encourage the model to generate completions shorter than the maximum length (`max_new_tokens`). The penalty must be " 113 | "a positive value." 114 | }, 115 | ) 116 | beta: list[float] = field( 117 | default_factory=lambda: [0.1], 118 | metadata={ 119 | "help": "Parameter controlling the deviation from the reference model. Higher β means less deviation from " 120 | "the reference model. For the IPO loss (`loss_type='ipo'`), β is the regularization parameter denoted by " 121 | "τ in the [paper](https://huggingface.co/papers/2310.12036). If a list of floats is provided then the β " 122 | "is selected for each new epoch and the last β is used for the rest of the epochs." 123 | }, 124 | ) 125 | loss_type: str = field( 126 | default="sigmoid", 127 | metadata={ 128 | "help": "Type of loss to use.", 129 | "choices": ["sigmoid", "ipo"], 130 | }, 131 | ) 132 | dataset_num_proc: Optional[int] = field( 133 | default=None, 134 | metadata={"help": "Number of processes to use for processing the dataset."}, 135 | ) 136 | disable_dropout: bool = field( 137 | default=True, 138 | metadata={"help": "Whether to disable dropout in the model and reference model."}, 139 | ) 140 | use_vllm: bool = field( 141 | default=False, 142 | metadata={ 143 | "help": "Whether to use vLLM for generating completions. Requires vLLM to be installed " 144 | "(`pip install vllm`)." 145 | }, 146 | ) 147 | ds3_gather_for_generation: bool = field( 148 | default=True, 149 | metadata={ 150 | "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for " 151 | "generation, improving generation speed. 
However, disabling this option allows training models that " 152 | "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation." 153 | }, 154 | ) 155 | 156 | def __post_init__(self): 157 | super().__post_init__() 158 | if hasattr(self.beta, "__len__") and len(self.beta) == 1: 159 | self.beta = self.beta[0] 160 | -------------------------------------------------------------------------------- /src/open_r1/trl/trainer/cpo_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataclasses import dataclass, field 16 | from typing import Any, Optional 17 | 18 | from transformers import TrainingArguments 19 | 20 | 21 | @dataclass 22 | class CPOConfig(TrainingArguments): 23 | r""" 24 | Configuration class for the [`CPOTrainer`]. 25 | 26 | Using [`~transformers.HfArgumentParser`] we can turn this class into 27 | [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the 28 | command line. 29 | 30 | Parameters: 31 | learning_rate (`float`, *optional*, defaults to `1e-6`): 32 | Initial learning rate for [`AdamW`] optimizer. The default value replaces that of 33 | [`~transformers.TrainingArguments`]. 34 | max_length (`int` or `None`, *optional*, defaults to `1024`): 35 | Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want 36 | to use the default data collator. 37 | max_prompt_length (`int` or `None`, *optional*, defaults to `512`): 38 | Maximum length of the prompt. This argument is required if you want to use the default data collator. 39 | max_completion_length (`int` or `None`, *optional*, defaults to `None`): 40 | Maximum length of the completion. This argument is required if you want to use the default data collator 41 | and your model is an encoder-decoder. 42 | beta (`float`, *optional*, defaults to `0.1`): 43 | Parameter controlling the deviation from the reference model. Higher β means less deviation from the 44 | reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in 45 | the [paper](https://huggingface.co/papers/2310.12036). 46 | label_smoothing (`float`, *optional*, defaults to `0.0`): 47 | Label smoothing factor. This argument is required if you want to use the default data collator. 48 | loss_type (`str`, *optional*, defaults to `"sigmoid"`): 49 | Type of loss to use. Possible values are: 50 | 51 | - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper. 52 | - `"hinge"`: hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper. 53 | - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper. 54 | - `"simpo"`: SimPO loss from the [SimPO](https://huggingface.co/papers/2405.14734) paper. 
55 | 56 | disable_dropout (`bool`, *optional*, defaults to `True`): 57 | Whether to disable dropout in the model. 58 | cpo_alpha (`float`, *optional*, defaults to `1.0`): 59 | Weight of the BC regularizer in CPO training. 60 | simpo_gamma (`float`, *optional*, defaults to `0.5`): 61 | Target reward margin for the SimPO loss, used only when the `loss_type="simpo"`. 62 | label_pad_token_id (`int`, *optional*, defaults to `-100`): 63 | Label pad token id. This argument is required if you want to use the default data collator. 64 | padding_value (`int` or `None`, *optional*, defaults to `None`): 65 | Padding value to use. If `None`, the padding value of the tokenizer is used. 66 | truncation_mode (`str`, *optional*, defaults to `"keep_end"`): 67 | Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`. 68 | This argument is required if you want to use the default data collator. 69 | generate_during_eval (`bool`, *optional*, defaults to `False`): 70 | If `True`, generates and logs completions from the model to W&B or Comet during evaluation. 71 | is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`): 72 | When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument, 73 | you need to specify if the model returned by the callable is an encoder-decoder model. 74 | model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): 75 | Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a 76 | string. 77 | dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): 78 | Number of processes to use for processing the dataset. 79 | """ 80 | 81 | learning_rate: float = field( 82 | default=1e-6, 83 | metadata={ 84 | "help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of " 85 | "`transformers.TrainingArguments`." 86 | }, 87 | ) 88 | max_length: Optional[int] = field( 89 | default=1024, 90 | metadata={"help": "Maximum length of the sequences (prompt + completion) in the batch."}, 91 | ) 92 | max_prompt_length: Optional[int] = field( 93 | default=512, 94 | metadata={ 95 | "help": "Maximum length of the prompt. This argument is required if you want to use the default data " 96 | "collator." 97 | }, 98 | ) 99 | max_completion_length: Optional[int] = field( 100 | default=None, 101 | metadata={ 102 | "help": "Maximum length of the completion. This argument is required if you want to use the default data " 103 | "collator and your model is an encoder-decoder." 104 | }, 105 | ) 106 | beta: float = field( 107 | default=0.1, 108 | metadata={ 109 | "help": "Parameter controlling the deviation from the reference model. Higher β means less deviation from " 110 | "the reference model."
111 | }, 112 | ) 113 | label_smoothing: float = field( 114 | default=0.0, 115 | metadata={"help": "Label smoothing factor."}, 116 | ) 117 | loss_type: str = field( 118 | default="sigmoid", 119 | metadata={ 120 | "help": "Type of loss to use.", 121 | "choices": ["sigmoid", "hinge", "ipo", "simpo"], 122 | }, 123 | ) 124 | disable_dropout: bool = field( 125 | default=True, 126 | metadata={"help": "Whether to disable dropout in the model."}, 127 | ) 128 | cpo_alpha: float = field( 129 | default=1.0, 130 | metadata={"help": "Weight of the BC regularizer in CPO training."}, 131 | ) 132 | simpo_gamma: float = field( 133 | default=0.5, 134 | metadata={"help": "Target reward margin for the SimPO loss, used only when the `loss_type='simpo'`."}, 135 | ) 136 | label_pad_token_id: int = field( 137 | default=-100, 138 | metadata={"help": "Label pad token id."}, 139 | ) 140 | padding_value: Optional[int] = field( 141 | default=None, 142 | metadata={"help": "Padding value to use. If `None`, the padding value of the tokenizer is used."}, 143 | ) 144 | truncation_mode: str = field( 145 | default="keep_end", 146 | metadata={ 147 | "help": "Truncation mode to use when the prompt is too long.", 148 | "choices": ["keep_end", "keep_start"], 149 | }, 150 | ) 151 | generate_during_eval: bool = field( 152 | default=False, 153 | metadata={"help": "If `True`, generates and logs completions from the model to W&B during evaluation."}, 154 | ) 155 | is_encoder_decoder: Optional[bool] = field( 156 | default=None, 157 | metadata={"help": "Whether the model is an encoder-decoder model."}, 158 | ) 159 | model_init_kwargs: Optional[dict[str, Any]] = field( 160 | default=None, 161 | metadata={ 162 | "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model " 163 | "from a string." 164 | }, 165 | ) 166 | dataset_num_proc: Optional[int] = field( 167 | default=None, 168 | metadata={"help": "Number of processes to use for processing the dataset."}, 169 | ) 170 | -------------------------------------------------------------------------------- /src/open_r1/evaluate.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | """Custom evaluation tasks for LightEval.""" 16 | 17 | import random 18 | 19 | from lighteval.metrics.dynamic_metrics import ( 20 | ExprExtractionConfig, 21 | IndicesExtractionConfig, 22 | LatexExtractionConfig, 23 | multilingual_extractive_match_metric, 24 | ) 25 | from lighteval.tasks.lighteval_task import LightevalTaskConfig 26 | from lighteval.tasks.requests import Doc 27 | from lighteval.utils.language import Language 28 | 29 | 30 | # Prompt template adapted from 31 | # - simple-evals: https://github.com/openai/simple-evals/blob/6e84f4e2aed6b60f6a0c7b8f06bbbf4bfde72e58/math_eval.py#L17 32 | # - Llama 3: https://huggingface.co/datasets/meta-llama/Llama-3.2-1B-Instruct-evals/viewer/Llama-3.2-1B-Instruct-evals__math__details?views%5B%5D=llama_32_1b_instruct_evals__math__details 33 | # Note that it is important to have the final answer in a box for math-verify to work correctly 34 | MATH_QUERY_TEMPLATE = """ 35 | Solve the following math problem efficiently and clearly. The last line of your response should be of the following format: 'Therefore, the final answer is: $\\boxed{{ANSWER}}$. I hope it is correct' (without quotes) where ANSWER is just the final number or expression that solves the problem. Think step by step before answering. 36 | 37 | {Question} 38 | """.strip() 39 | 40 | # Prompt template from simple-evals: https://github.com/openai/simple-evals/blob/83ed7640a7d9cd26849bcb3340125002ef14abbe/common.py#L14 41 | GPQA_QUERY_TEMPLATE = """ 42 | Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering. 43 | 44 | {Question} 45 | 46 | A) {A} 47 | B) {B} 48 | C) {C} 49 | D) {D} 50 | """.strip() 51 | 52 | latex_gold_metric = multilingual_extractive_match_metric( 53 | language=Language.ENGLISH, 54 | fallback_mode="first_match", 55 | precision=5, 56 | gold_extraction_target=(LatexExtractionConfig(),), 57 | # Match boxed first before trying other regexes 58 | pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig(boxed_match_priority=0)), 59 | aggregation_function=max, 60 | ) 61 | 62 | expr_gold_metric = multilingual_extractive_match_metric( 63 | language=Language.ENGLISH, 64 | fallback_mode="first_match", 65 | precision=5, 66 | gold_extraction_target=(ExprExtractionConfig(),), 67 | # Match boxed first before trying other regexes 68 | pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig(boxed_match_priority=0)), 69 | aggregation_function=max, 70 | ) 71 | 72 | gpqa_metric = multilingual_extractive_match_metric( 73 | language=Language.ENGLISH, 74 | gold_extraction_target=[IndicesExtractionConfig(prefix_for_extraction="NativeLetters")], 75 | pred_extraction_target=[IndicesExtractionConfig(prefix_for_extraction="NativeLetters")], 76 | precision=5, 77 | ) 78 | 79 | 80 | def math_prompt_fn(line, task_name: str = None): 81 | return Doc( 82 | task_name=task_name, 83 | query=MATH_QUERY_TEMPLATE.format(Question=line["problem"]), 84 | choices=[line["solution"]], 85 | gold_index=0, 86 | ) 87 | 88 | 89 | def aime_prompt_fn(line, task_name: str = None): 90 | return Doc( 91 | task_name=task_name, 92 | query=MATH_QUERY_TEMPLATE.format(Question=line["problem"]), 93 | choices=[line["answer"]], 94 | gold_index=0, 95 | ) 96 | 97 | 98 | def amc_prompt_fn(line, task_name: str = None): 99 | return Doc( 100 | task_name=task_name, 101 | query=MATH_QUERY_TEMPLATE.format(Question=line["problem"]), 102 | 
choices=[line["answer"]], 103 | gold_index=0, 104 | ) 105 | 106 | 107 | def minerva_prompt_fn(line, task_name: str = None): 108 | return Doc( 109 | task_name=task_name, 110 | query=MATH_QUERY_TEMPLATE.format(Question=line["problem"]), 111 | choices=[line["solution"]], 112 | gold_index=0, 113 | ) 114 | 115 | 116 | def olympiadbench_prompt_fn(line, task_name: str = None): 117 | return Doc( 118 | task_name=task_name, 119 | query=MATH_QUERY_TEMPLATE.format(Question=line["question"]), 120 | choices=[line["answer"]], 121 | gold_index=0, 122 | ) 123 | 124 | 125 | def gpqa_prompt_fn(line, task_name: str = None): 126 | gold_index = random.randint(0, 3) 127 | choices = [line["Incorrect Answer 1"], line["Incorrect Answer 2"], line["Incorrect Answer 3"]] 128 | choices.insert(gold_index, line["Correct Answer"]) 129 | query = GPQA_QUERY_TEMPLATE.format( 130 | A=choices[0], B=choices[1], C=choices[2], D=choices[3], Question=line["Question"] 131 | ) 132 | return Doc( 133 | task_name=task_name, 134 | query=query, 135 | choices=["A", "B", "C", "D"], 136 | gold_index=gold_index, 137 | instruction=query, 138 | ) 139 | 140 | 141 | # Define tasks 142 | aime24 = LightevalTaskConfig( 143 | name="aime24", 144 | suite=["custom"], 145 | prompt_function=aime_prompt_fn, 146 | hf_repo="HuggingFaceH4/aime_2024", 147 | hf_subset="default", 148 | hf_avail_splits=["train"], 149 | evaluation_splits=["train"], 150 | few_shots_split=None, 151 | few_shots_select=None, 152 | generation_size=32768, 153 | metric=[expr_gold_metric], 154 | version=1, 155 | ) 156 | aime25 = LightevalTaskConfig( 157 | name="aime25", 158 | suite=["custom"], 159 | prompt_function=aime_prompt_fn, 160 | hf_repo="yentinglin/aime_2025", 161 | hf_subset="default", 162 | hf_avail_splits=["train"], 163 | evaluation_splits=["train"], 164 | few_shots_split=None, 165 | few_shots_select=None, 166 | generation_size=32768, 167 | metric=[expr_gold_metric], 168 | version=1, 169 | ) 170 | math_500 = LightevalTaskConfig( 171 | name="math_500", 172 | suite=["custom"], 173 | prompt_function=math_prompt_fn, 174 | hf_repo="HuggingFaceH4/MATH-500", 175 | hf_subset="default", 176 | hf_avail_splits=["test"], 177 | evaluation_splits=["test"], 178 | few_shots_split=None, 179 | few_shots_select=None, 180 | generation_size=32768, 181 | metric=[latex_gold_metric], 182 | version=1, 183 | ) 184 | gpqa_diamond = LightevalTaskConfig( 185 | name="gpqa:diamond", 186 | suite=["custom"], 187 | prompt_function=gpqa_prompt_fn, 188 | hf_repo="Idavidrein/gpqa", 189 | hf_subset="gpqa_diamond", 190 | hf_avail_splits=["train"], 191 | evaluation_splits=["train"], 192 | few_shots_split=None, 193 | few_shots_select=None, 194 | generation_size=32768, # needed for reasoning models like R1 195 | metric=[gpqa_metric], 196 | stop_sequence=[], # no stop sequence, will use eos token 197 | trust_dataset=True, 198 | version=1, 199 | ) 200 | minerva = LightevalTaskConfig( 201 | name="minerva", 202 | suite=["custom"], 203 | prompt_function=minerva_prompt_fn, 204 | hf_repo="knoveleng/Minerva-Math", 205 | hf_subset="default", 206 | hf_avail_splits=["train"], 207 | evaluation_splits=["train"], 208 | few_shots_split=None, 209 | few_shots_select=None, 210 | generation_size=32768, 211 | metric=[latex_gold_metric], 212 | version=1, 213 | ) 214 | amc23 = LightevalTaskConfig( 215 | name="amc23", 216 | suite=["custom"], 217 | prompt_function=amc_prompt_fn, 218 | hf_repo="knoveleng/AMC-23", 219 | hf_subset="default", 220 | hf_avail_splits=["train"], 221 | evaluation_splits=["train"], 222 | few_shots_split=None, 223 | 
few_shots_select=None, 224 | generation_size=32768, 225 | metric=[expr_gold_metric], 226 | version=1, 227 | ) 228 | olympiadbench = LightevalTaskConfig( 229 | name="olympiadbench", 230 | suite=["custom"], 231 | prompt_function=olympiadbench_prompt_fn, 232 | hf_repo="knoveleng/OlympiadBench", 233 | hf_subset="default", 234 | hf_avail_splits=["train"], 235 | evaluation_splits=["train"], 236 | few_shots_split=None, 237 | few_shots_select=None, 238 | generation_size=32768, 239 | metric=[latex_gold_metric], 240 | version=1, 241 | ) 242 | 243 | # Add tasks to the table 244 | TASKS_TABLE = [] 245 | TASKS_TABLE.append(aime24) 246 | TASKS_TABLE.append(aime25) 247 | TASKS_TABLE.append(math_500) 248 | TASKS_TABLE.append(gpqa_diamond) 249 | TASKS_TABLE.append(minerva) 250 | TASKS_TABLE.append(amc23) 251 | TASKS_TABLE.append(olympiadbench) 252 | 253 | # MODULE LOGIC 254 | if __name__ == "__main__": 255 | print([t.name for t in TASKS_TABLE]) 256 | print(len(TASKS_TABLE)) -------------------------------------------------------------------------------- /src/open_r1/trl/trainer/bco_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataclasses import dataclass, field 16 | from typing import Any, Optional 17 | 18 | from transformers import TrainingArguments 19 | 20 | 21 | @dataclass 22 | class BCOConfig(TrainingArguments): 23 | r""" 24 | Configuration class for the [`BCOTrainer`]. 25 | 26 | Using [`~transformers.HfArgumentParser`] we can turn this class into 27 | [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the 28 | command line. 29 | 30 | Parameters: 31 | max_length (`int` or `None`, *optional*, defaults to `1024`): 32 | Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want 33 | to use the default data collator. 34 | max_prompt_length (`int` or `None`, *optional*, defaults to `512`): 35 | Maximum length of the prompt. This argument is required if you want to use the default data collator. 36 | max_completion_length (`int` or `None`, *optional*, defaults to `None`): 37 | Maximum length of the completion. This argument is required if you want to use the default data collator 38 | and your model is an encoder-decoder. 39 | beta (`float`, *optional*, defaults to `0.1`): 40 | Parameter controlling the deviation from the reference model. Higher β means less deviation from the 41 | reference model. 42 | label_pad_token_id (`int`, *optional*, defaults to `-100`): 43 | Label pad token id. This argument is required if you want to use the default data collator. 44 | padding_value (`int` or `None`, *optional*, defaults to `None`): 45 | Padding value to use. If `None`, the padding value of the tokenizer is used.
46 | truncation_mode (`str`, *optional*, defaults to `"keep_end"`): 47 | Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`. 48 | This argument is required if you want to use the default data collator. 49 | disable_dropout (`bool`, *optional*, defaults to `True`): 50 | Whether to disable dropout in the model and reference model. 51 | generate_during_eval (`bool`, *optional*, defaults to `False`): 52 | If `True`, generates and logs completions from both the model and the reference model to W&B or Comet during 53 | evaluation. 54 | is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`): 55 | When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument, 56 | you need to specify if the model returned by the callable is an encoder-decoder model. 57 | precompute_ref_log_probs (`bool`, *optional*, defaults to `False`): 58 | Whether to precompute reference model log probabilities for training and evaluation datasets. This is 59 | useful when training without the reference model to reduce the total GPU memory needed. 60 | model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): 61 | Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a 62 | string. 63 | ref_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): 64 | Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference model 65 | from a string. 66 | dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): 67 | Number of processes to use for processing the dataset. 68 | prompt_sample_size (`int`, *optional*, defaults to `1024`): 69 | Number of prompts that are fed to density ratio classifier. 70 | min_density_ratio (`float`, *optional*, defaults to `0.5`): 71 | Minimum value of the density ratio. The estimated density ratio is clamped to this value. 72 | max_density_ratio (`float`, *optional*, defaults to `10.0`): 73 | Maximum value of the density ratio. The estimated density ratio is clamped to this value. 74 | """ 75 | 76 | max_length: Optional[int] = field( 77 | default=1024, 78 | metadata={ 79 | "help": "Maximum length of the sequences (prompt + completion) in the batch. " 80 | "This argument is required if you want to use the default data collator." 81 | }, 82 | ) 83 | max_prompt_length: Optional[int] = field( 84 | default=512, 85 | metadata={ 86 | "help": "Maximum length of the prompt. " 87 | "This argument is required if you want to use the default data collator." 88 | }, 89 | ) 90 | max_completion_length: Optional[int] = field( 91 | default=None, 92 | metadata={ 93 | "help": "Maximum length of the completion. This argument is required if you want to use the " 94 | "default data collator and your model is an encoder-decoder." 95 | }, 96 | ) 97 | beta: float = field( 98 | default=0.1, 99 | metadata={ 100 | "help": "Parameter controlling the deviation from the reference model. " 101 | "Higher β means less deviation from the reference model." 102 | }, 103 | ) 104 | label_pad_token_id: int = field( 105 | default=-100, 106 | metadata={ 107 | "help": "Label pad token id. This argument is required if you want to use the default data collator." 108 | }, 109 | ) 110 | padding_value: Optional[int] = field( 111 | default=None, 112 | metadata={"help": "Padding value to use. 
If `None`, the padding value of the tokenizer is used."}, 113 | ) 114 | truncation_mode: str = field( 115 | default="keep_end", 116 | metadata={ 117 | "help": "Truncation mode to use when the prompt is too long. Possible values are " 118 | "`keep_end` or `keep_start`. This argument is required if you want to use the " 119 | "default data collator." 120 | }, 121 | ) 122 | disable_dropout: bool = field( 123 | default=True, 124 | metadata={"help": "Whether to disable dropout in the model and reference model."}, 125 | ) 126 | generate_during_eval: bool = field( 127 | default=False, 128 | metadata={ 129 | "help": "If `True`, generates and logs completions from both the model and the reference model " 130 | "to W&B during evaluation." 131 | }, 132 | ) 133 | is_encoder_decoder: Optional[bool] = field( 134 | default=None, 135 | metadata={ 136 | "help": "When using the `model_init` argument (callable) to instantiate the model instead of the " 137 | "`model` argument, you need to specify if the model returned by the callable is an " 138 | "encoder-decoder model." 139 | }, 140 | ) 141 | precompute_ref_log_probs: bool = field( 142 | default=False, 143 | metadata={ 144 | "help": "Whether to precompute reference model log probabilities for training and evaluation datasets. " 145 | "This is useful when training without the reference model to reduce the total GPU memory " 146 | "needed." 147 | }, 148 | ) 149 | model_init_kwargs: Optional[dict[str, Any]] = field( 150 | default=None, 151 | metadata={ 152 | "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the " 153 | "model from a string." 154 | }, 155 | ) 156 | ref_model_init_kwargs: Optional[dict[str, Any]] = field( 157 | default=None, 158 | metadata={ 159 | "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the " 160 | "reference model from a string." 161 | }, 162 | ) 163 | dataset_num_proc: Optional[int] = field( 164 | default=None, 165 | metadata={"help": "Number of processes to use for processing the dataset."}, 166 | ) 167 | prompt_sample_size: int = field( 168 | default=1024, 169 | metadata={"help": "Number of prompts that are fed to density ratio classifier."}, 170 | ) 171 | min_density_ratio: float = field( 172 | default=0.5, 173 | metadata={"help": "Minimum value of the density ratio. The estimated density ratio is clamped to this value."}, 174 | ) 175 | max_density_ratio: float = field( 176 | default=10.0, 177 | metadata={"help": "Maximum value of the density ratio. The estimated density ratio is clamped to this value."}, 178 | ) 179 | --------------------------------------------------------------------------------
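
The config dataclasses above are all designed to be populated from the command line: each docstring notes that `HfArgumentParser` can turn the class into argparse arguments. Below is a minimal sketch of that pattern, assuming the vendored fork is importable as `open_r1.trl` and using only names re-exported by the package `__init__.py` shown earlier (`TrlParser`, `ScriptArguments`, `SFTConfig`, `ModelConfig`, `get_peft_config`, `get_quantization_config`); the `dataset_name` field comes from upstream TRL's `ScriptArguments` and is an assumption here, and the snippet is an illustration rather than one of the repository's own scripts.

# Sketch: turning the config dataclasses into CLI arguments.
# Assumption: the vendored TRL fork is importable as `open_r1.trl`;
# `dataset_name` is a field of upstream TRL's ScriptArguments.
from open_r1.trl import (
    ModelConfig,
    ScriptArguments,
    SFTConfig,
    TrlParser,
    get_peft_config,
    get_quantization_config,
)

if __name__ == "__main__":
    # Each dataclass contributes one group of flags, e.g. --model_name_or_path,
    # --dataset_name, --learning_rate, --packing, --use_peft, --load_in_4bit, ...
    parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
    script_args, training_args, model_args = parser.parse_args_and_config()

    # ModelConfig drives the helpers re-exported by the package __init__:
    # get_peft_config returns a LoRA config only when --use_peft is set, and
    # get_quantization_config returns a bitsandbytes config only for 4-/8-bit loading.
    peft_config = get_peft_config(model_args)
    quantization_config = get_quantization_config(model_args)

    print(script_args.dataset_name, model_args.model_name_or_path)
    print(training_args.learning_rate)  # 2e-5 by default, per SFTConfig above
    print(peft_config, quantization_config)

The same pattern applies to the other configs in this directory: passing `GRPOConfig`, `OnlineDPOConfig`, `CPOConfig`, or `BCOConfig` to `TrlParser` instead of `SFTConfig` exposes that trainer's fields as command-line flags.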