├── src
│ ├── readme.md
│ ├── open_r1
│ │ ├── readme.md
│ │ ├── utils
│ │ │ ├── __pycache__
│ │ │ │ ├── hub.cpython-311.pyc
│ │ │ │ ├── __init__.cpython-311.pyc
│ │ │ │ ├── callbacks.cpython-311.pyc
│ │ │ │ ├── evaluation.cpython-311.pyc
│ │ │ │ ├── import_utils.cpython-311.pyc
│ │ │ │ ├── model_utils.cpython-311.pyc
│ │ │ │ └── wandb_logging.cpython-311.pyc
│ │ │ ├── __init__.py
│ │ │ ├── wandb_logging.py
│ │ │ ├── import_utils.py
│ │ │ ├── model_utils.py
│ │ │ ├── callbacks.py
│ │ │ ├── evaluation.py
│ │ │ └── hub.py
│ │ ├── trl
│ │ │ ├── __pycache__
│ │ │ │ ├── __init__.cpython-311.pyc
│ │ │ │ ├── data_utils.cpython-311.pyc
│ │ │ │ ├── import_utils.cpython-311.pyc
│ │ │ │ └── mergekit_utils.cpython-311.pyc
│ │ │ ├── models
│ │ │ │ ├── __pycache__
│ │ │ │ │ ├── utils.cpython-311.pyc
│ │ │ │ │ ├── __init__.cpython-311.pyc
│ │ │ │ │ ├── modeling_base.cpython-311.pyc
│ │ │ │ │ └── modeling_value_head.cpython-311.pyc
│ │ │ │ ├── __init__.py
│ │ │ │ ├── auxiliary_modules.py
│ │ │ │ └── sd_utils.py
│ │ │ ├── scripts
│ │ │ │ ├── __pycache__
│ │ │ │ │ ├── utils.cpython-311.pyc
│ │ │ │ │ └── __init__.cpython-311.pyc
│ │ │ │ ├── __init__.py
│ │ │ │ ├── env.py
│ │ │ │ ├── grpo.py
│ │ │ │ ├── sft.py
│ │ │ │ ├── kto.py
│ │ │ │ └── dpo.py
│ │ │ ├── trainer
│ │ │ │ ├── __pycache__
│ │ │ │ │ ├── judges.cpython-311.pyc
│ │ │ │ │ ├── utils.cpython-311.pyc
│ │ │ │ │ ├── __init__.cpython-311.pyc
│ │ │ │ │ ├── callbacks.cpython-311.pyc
│ │ │ │ │ ├── grpo_config.cpython-311.pyc
│ │ │ │ │ ├── sft_config.cpython-311.pyc
│ │ │ │ │ ├── grpo_trainer.cpython-311.pyc
│ │ │ │ │ ├── model_config.cpython-311.pyc
│ │ │ │ │ └── drgrpo_trainer.cpython-311.pyc
│ │ │ │ ├── xpo_config.py
│ │ │ │ ├── nash_md_config.py
│ │ │ │ ├── reward_config.py
│ │ │ │ ├── prm_config.py
│ │ │ │ ├── rloo_config.py
│ │ │ │ ├── gkd_config.py
│ │ │ │ ├── ppo_config.py
│ │ │ │ ├── __init__.py
│ │ │ │ ├── orpo_config.py
│ │ │ │ ├── sft_config.py
│ │ │ │ ├── model_config.py
│ │ │ │ ├── online_dpo_config.py
│ │ │ │ ├── cpo_config.py
│ │ │ │ └── bco_config.py
│ │ │ ├── extras
│ │ │ │ ├── __init__.py
│ │ │ │ ├── dataset_formatting.py
│ │ │ │ └── best_of_n_sampler.py
│ │ │ ├── environment
│ │ │ │ └── __init__.py
│ │ │ ├── templates
│ │ │ │ └── lm_model_card.md
│ │ │ ├── cli.py
│ │ │ ├── import_utils.py
│ │ │ ├── core.py
│ │ │ └── __init__.py
│ │ ├── __init__.py
│ │ ├── configs.py
│ │ ├── generate.py
│ │ ├── sft.py
│ │ └── evaluate.py
│ └── open_r1.egg-info
│   ├── not-zip-safe
│   ├── dependency_links.txt
│   ├── top_level.txt
│   ├── SOURCES.txt
│   └── requires.txt
├── Figs
│ └── readme.md
├── recipes
│ ├── data_cleaner.yaml
│ ├── accelerate_configs
│ │ ├── ddp.yaml
│ │ ├── zero2.yaml
│ │ ├── zero3.yaml
│ │ └── fsdp.yaml
│ ├── dra_grpo.yaml
│ └── dra_dr_grpo.yaml
├── LICENSE
└── README.md
/src/readme.md:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/Figs/readme.md:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/src/open_r1/readme.md:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/src/open_r1.egg-info/not-zip-safe:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/src/open_r1.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/src/open_r1.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | open_r1
2 |
--------------------------------------------------------------------------------
/src/open_r1/utils/__pycache__/hub.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/utils/__pycache__/hub.cpython-311.pyc
--------------------------------------------------------------------------------
/src/open_r1/trl/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/src/open_r1/trl/__pycache__/data_utils.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/__pycache__/data_utils.cpython-311.pyc
--------------------------------------------------------------------------------
/src/open_r1/utils/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/utils/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/src/open_r1/utils/__pycache__/callbacks.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/utils/__pycache__/callbacks.cpython-311.pyc
--------------------------------------------------------------------------------
/src/open_r1/trl/__pycache__/import_utils.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/__pycache__/import_utils.cpython-311.pyc
--------------------------------------------------------------------------------
/src/open_r1/trl/__pycache__/mergekit_utils.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/__pycache__/mergekit_utils.cpython-311.pyc
--------------------------------------------------------------------------------
/src/open_r1/trl/models/__pycache__/utils.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/models/__pycache__/utils.cpython-311.pyc
--------------------------------------------------------------------------------
/src/open_r1/trl/scripts/__pycache__/utils.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/scripts/__pycache__/utils.cpython-311.pyc
--------------------------------------------------------------------------------
/src/open_r1/trl/trainer/__pycache__/judges.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/trainer/__pycache__/judges.cpython-311.pyc
--------------------------------------------------------------------------------
/src/open_r1/trl/trainer/__pycache__/utils.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/trainer/__pycache__/utils.cpython-311.pyc
--------------------------------------------------------------------------------
/src/open_r1/utils/__pycache__/evaluation.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/utils/__pycache__/evaluation.cpython-311.pyc
--------------------------------------------------------------------------------
/src/open_r1/utils/__pycache__/import_utils.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/utils/__pycache__/import_utils.cpython-311.pyc
--------------------------------------------------------------------------------
/src/open_r1/utils/__pycache__/model_utils.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/utils/__pycache__/model_utils.cpython-311.pyc
--------------------------------------------------------------------------------
/src/open_r1/trl/models/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/models/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/src/open_r1/trl/scripts/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/scripts/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/src/open_r1/trl/trainer/__pycache__/__init__.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/trainer/__pycache__/__init__.cpython-311.pyc
--------------------------------------------------------------------------------
/src/open_r1/utils/__pycache__/wandb_logging.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/utils/__pycache__/wandb_logging.cpython-311.pyc
--------------------------------------------------------------------------------
/src/open_r1/trl/trainer/__pycache__/callbacks.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/trainer/__pycache__/callbacks.cpython-311.pyc
--------------------------------------------------------------------------------
/src/open_r1/trl/trainer/__pycache__/grpo_config.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/trainer/__pycache__/grpo_config.cpython-311.pyc
--------------------------------------------------------------------------------
/src/open_r1/trl/trainer/__pycache__/sft_config.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/trainer/__pycache__/sft_config.cpython-311.pyc
--------------------------------------------------------------------------------
/src/open_r1/trl/models/__pycache__/modeling_base.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/models/__pycache__/modeling_base.cpython-311.pyc
--------------------------------------------------------------------------------
/src/open_r1/trl/trainer/__pycache__/grpo_trainer.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/trainer/__pycache__/grpo_trainer.cpython-311.pyc
--------------------------------------------------------------------------------
/src/open_r1/trl/trainer/__pycache__/model_config.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/trainer/__pycache__/model_config.cpython-311.pyc
--------------------------------------------------------------------------------
/src/open_r1/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .import_utils import is_e2b_available
2 | from .model_utils import get_tokenizer
3 |
4 |
5 | __all__ = ["get_tokenizer", "is_e2b_available"]
6 |
--------------------------------------------------------------------------------
/src/open_r1/trl/trainer/__pycache__/drgrpo_trainer.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/trainer/__pycache__/drgrpo_trainer.cpython-311.pyc
--------------------------------------------------------------------------------
/src/open_r1/trl/models/__pycache__/modeling_value_head.cpython-311.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiwenc1/DRA-GRPO/HEAD/src/open_r1/trl/models/__pycache__/modeling_value_head.cpython-311.pyc
--------------------------------------------------------------------------------
/recipes/data_cleaner.yaml:
--------------------------------------------------------------------------------
1 | model_kwargs:
2 | model: Qwen/Qwen2.5-Math-7B-Instruct
3 | trust_remote_code: true
4 | max_model_len: 4096
5 | gpu_memory_utilization: 0.9
6 | enforce_eager: true
7 | tensor_parallel_size: 4
8 |
9 | sampling_params:
10 | temperature: 0.7
11 | top_p: 0.9
12 | max_tokens: 4096
13 |
--------------------------------------------------------------------------------
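The recipe above maps directly onto vLLM's constructor and sampling arguments. A minimal sketch of how such a recipe could be consumed (the loading code is not part of this dump, so the wiring below is an assumption):

```python
# Sketch only: feed the recipe's model_kwargs / sampling_params into vLLM.
# The YAML path and this wiring are illustrative; the repo's actual loader may differ.
import yaml
from vllm import LLM, SamplingParams

with open("recipes/data_cleaner.yaml") as f:
    recipe = yaml.safe_load(f)

llm = LLM(**recipe["model_kwargs"])                   # Qwen2.5-Math-7B-Instruct, TP=4, etc.
params = SamplingParams(**recipe["sampling_params"])  # temperature=0.7, top_p=0.9, max_tokens=4096
outputs = llm.generate(["Simplify: (3x^2 - x) + (2x^2 + 5x)"], params)
print(outputs[0].outputs[0].text)
```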
/recipes/accelerate_configs/ddp.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | distributed_type: MULTI_GPU
4 | downcast_bf16: 'no'
5 | gpu_ids: all
6 | machine_rank: 0
7 | main_training_function: main
8 | mixed_precision: bf16
9 | num_machines: 1
10 | num_processes: 8
11 | rdzv_backend: static
12 | same_network: true
13 | tpu_env: []
14 | tpu_use_cluster: false
15 | tpu_use_sudo: false
16 | use_cpu: false
17 |
--------------------------------------------------------------------------------
/src/open_r1/utils/wandb_logging.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 |
4 | def init_wandb_training(training_args):
5 | """
6 | Helper function for setting up Weights & Biases logging tools.
7 | """
8 | if training_args.wandb_entity is not None:
9 | os.environ["WANDB_ENTITY"] = training_args.wandb_entity
10 | if training_args.wandb_project is not None:
11 | os.environ["WANDB_PROJECT"] = training_args.wandb_project
12 |
--------------------------------------------------------------------------------
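For context, a small usage sketch: `init_wandb_training` only exports the two environment variables that Weights & Biases reads when a run is initialized. Any object exposing `wandb_entity` / `wandb_project` works (e.g. the `GRPOConfig` / `SFTConfig` dataclasses in `src/open_r1/configs.py`); the values below are made up:

```python
# Illustrative call with a minimal stand-in for the training arguments.
from types import SimpleNamespace

from open_r1.utils.wandb_logging import init_wandb_training

training_args = SimpleNamespace(wandb_entity="my-team", wandb_project="dra-grpo")  # hypothetical values
init_wandb_training(training_args)
# Subsequent wandb.init() calls (made by the Trainer) pick up WANDB_ENTITY / WANDB_PROJECT.
```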
/recipes/accelerate_configs/zero2.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | deepspeed_multinode_launcher: standard
5 | offload_optimizer_device: none
6 | offload_param_device: none
7 | zero3_init_flag: false
8 | zero_stage: 2
9 | distributed_type: DEEPSPEED
10 | downcast_bf16: 'no'
11 | machine_rank: 0
12 | main_training_function: main
13 | mixed_precision: bf16
14 | num_machines: 1
15 | num_processes: 8
16 | rdzv_backend: static
17 | same_network: true
18 | tpu_env: []
19 | tpu_use_cluster: false
20 | tpu_use_sudo: false
21 | use_cpu: false
--------------------------------------------------------------------------------
/recipes/accelerate_configs/zero3.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | deepspeed_multinode_launcher: standard
5 | offload_optimizer_device: none
6 | offload_param_device: none
7 | zero3_init_flag: true
8 | zero3_save_16bit_model: true
9 | zero_stage: 3
10 | distributed_type: DEEPSPEED
11 | downcast_bf16: 'no'
12 | machine_rank: 0
13 | main_training_function: main
14 | mixed_precision: bf16
15 | num_machines: 1
16 | num_processes: 8
17 | rdzv_backend: static
18 | same_network: true
19 | tpu_env: []
20 | tpu_use_cluster: false
21 | tpu_use_sudo: false
22 | use_cpu: false
23 |
--------------------------------------------------------------------------------
/src/open_r1/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/src/open_r1.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | LICENSE
2 | README.md
3 | setup.cfg
4 | setup.py
5 | src/open_r1/__init__.py
6 | src/open_r1/configs.py
7 | src/open_r1/evaluate.py
8 | src/open_r1/generate.py
9 | src/open_r1/grpo.py
10 | src/open_r1/rewards.py
11 | src/open_r1/sft.py
12 | src/open_r1.egg-info/PKG-INFO
13 | src/open_r1.egg-info/SOURCES.txt
14 | src/open_r1.egg-info/dependency_links.txt
15 | src/open_r1.egg-info/not-zip-safe
16 | src/open_r1.egg-info/requires.txt
17 | src/open_r1.egg-info/top_level.txt
18 | src/open_r1/utils/__init__.py
19 | src/open_r1/utils/callbacks.py
20 | src/open_r1/utils/evaluation.py
21 | src/open_r1/utils/hub.py
22 | src/open_r1/utils/import_utils.py
23 | src/open_r1/utils/model_utils.py
24 | src/open_r1/utils/wandb_logging.py
--------------------------------------------------------------------------------
/recipes/accelerate_configs/fsdp.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | distributed_type: FSDP
4 | downcast_bf16: 'no'
5 | enable_cpu_affinity: false
6 | fsdp_config:
7 | fsdp_activation_checkpointing: false # Need fix from: https://github.com/huggingface/transformers/pull/36610
8 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
9 | fsdp_backward_prefetch: BACKWARD_PRE
10 | fsdp_cpu_ram_efficient_loading: true
11 | fsdp_forward_prefetch: true
12 | fsdp_offload_params: false
13 | fsdp_sharding_strategy: FULL_SHARD
14 | fsdp_state_dict_type: FULL_STATE_DICT
15 | fsdp_sync_module_states: true
16 | fsdp_use_orig_params: true
17 | machine_rank: 0
18 | main_training_function: main
19 | mixed_precision: bf16
20 | num_machines: 1
21 | num_processes: 8
22 | rdzv_backend: static
23 | same_network: true
24 | tpu_env: []
25 | tpu_use_cluster: false
26 | tpu_use_sudo: false
27 | use_cpu: false
--------------------------------------------------------------------------------
/src/open_r1/utils/import_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from transformers.utils.import_utils import _is_package_available
16 |
17 |
18 | # Use same as transformers.utils.import_utils
19 | _e2b_available = _is_package_available("e2b")
20 |
21 |
22 | def is_e2b_available() -> bool:
23 | return _e2b_available
24 |
--------------------------------------------------------------------------------
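A hedged sketch of how this availability flag is typically used to guard the optional e2b dependency from the `[code]` extra; the guarded import below is an assumption for illustration, not code from this repository:

```python
# Sketch: gate optional code-execution support on the availability check.
from open_r1.utils import is_e2b_available

if is_e2b_available():
    from e2b_code_interpreter import Sandbox  # provided by the optional [code] extra
else:
    Sandbox = None  # code-execution features are disabled without e2b
```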
/src/open_r1/trl/extras/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import TYPE_CHECKING
16 |
17 | from ..import_utils import _LazyModule
18 |
19 |
20 | _import_structure = {
21 | "best_of_n_sampler": ["BestOfNSampler"],
22 | }
23 |
24 | if TYPE_CHECKING:
25 | from .best_of_n_sampler import BestOfNSampler
26 | else:
27 | import sys
28 |
29 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
30 |
--------------------------------------------------------------------------------
/src/open_r1.egg-info/requires.txt:
--------------------------------------------------------------------------------
1 | accelerate==1.4.0
2 | bitsandbytes>=0.43.0
3 | einops>=0.8.0
4 | datasets>=3.2.0
5 | deepspeed==0.15.4
6 | hf_transfer>=0.1.4
7 | huggingface-hub[cli]<1.0,>=0.19.2
8 | langdetect
9 | latex2sympy2_extended>=1.0.6
10 | math-verify==0.5.2
11 | liger_kernel==0.5.3
12 | packaging>=23.0
13 | safetensors>=0.3.3
14 | sentencepiece>=0.1.99
15 | transformers==4.49.0
16 | trl@ git+https://github.com/huggingface/trl.git@69ad852e5654a77f1695eb4c608906fe0c7e8624
17 | wandb>=0.19.1
18 |
19 | [code]
20 | e2b-code-interpreter>=1.0.5
21 | python-dotenv
22 |
23 | [dev]
24 | ruff>=0.9.0
25 | isort>=5.12.0
26 | flake8>=6.0.0
27 | pytest
28 | parameterized>=0.9.0
29 | math-verify==0.5.2
30 | lighteval@ git+https://github.com/huggingface/lighteval.git@ed084813e0bd12d82a06d9f913291fdbee774905
31 |
32 | [eval]
33 | lighteval@ git+https://github.com/huggingface/lighteval.git@ed084813e0bd12d82a06d9f913291fdbee774905
34 | math-verify==0.5.2
35 |
36 | [quality]
37 | ruff>=0.9.0
38 | isort>=5.12.0
39 | flake8>=6.0.0
40 |
41 | [tests]
42 | pytest
43 | parameterized>=0.9.0
44 | math-verify==0.5.2
45 |
46 | [torch]
47 | torch==2.5.1
48 |
--------------------------------------------------------------------------------
/src/open_r1/trl/environment/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import TYPE_CHECKING
16 |
17 | from ..import_utils import _LazyModule
18 |
19 |
20 | _import_structure = {
21 | "base_environment": ["TextEnvironment", "TextHistory"],
22 | }
23 |
24 | if TYPE_CHECKING:
25 | from .base_environment import TextEnvironment, TextHistory
26 | else:
27 | import sys
28 |
29 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
30 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2025 xiwenc1
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/src/open_r1/trl/scripts/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import TYPE_CHECKING
16 |
17 | from ..import_utils import _LazyModule
18 |
19 |
20 | _import_structure = {
21 | "utils": ["init_zero_verbose", "ScriptArguments", "TrlParser"],
22 | }
23 |
24 | if TYPE_CHECKING:
25 | from .utils import ScriptArguments, TrlParser, init_zero_verbose
26 | else:
27 | import sys
28 |
29 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
30 |
--------------------------------------------------------------------------------
/src/open_r1/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | from transformers import AutoTokenizer, PreTrainedTokenizer
2 |
3 | from trl import ModelConfig
4 |
5 | from ..configs import GRPOConfig, SFTConfig
6 |
7 |
8 | DEFAULT_CHAT_TEMPLATE = "{% for message in messages %}\n{% if message['role'] == 'user' %}\n{{ '<|user|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'system' %}\n{{ '<|system|>\n' + message['content'] + eos_token }}\n{% elif message['role'] == 'assistant' %}\n{{ '<|assistant|>\n' + message['content'] + eos_token }}\n{% endif %}\n{% if loop.last and add_generation_prompt %}\n{{ '<|assistant|>' }}\n{% endif %}\n{% endfor %}"
9 |
10 |
11 | def get_tokenizer(
12 | model_args: ModelConfig, training_args: SFTConfig | GRPOConfig, auto_set_chat_template: bool = True
13 | ) -> PreTrainedTokenizer:
14 | """Get the tokenizer for the model."""
15 | tokenizer = AutoTokenizer.from_pretrained(
16 | model_args.model_name_or_path,
17 | revision=model_args.model_revision,
18 | trust_remote_code=model_args.trust_remote_code,
19 | )
20 |
21 | if training_args.chat_template is not None:
22 | tokenizer.chat_template = training_args.chat_template
23 | elif auto_set_chat_template and tokenizer.get_chat_template() is None:
24 | tokenizer.chat_template = DEFAULT_CHAT_TEMPLATE
25 |
26 | return tokenizer
27 |
--------------------------------------------------------------------------------
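`get_tokenizer` only attaches `DEFAULT_CHAT_TEMPLATE` when neither the recipe's `chat_template` nor the tokenizer itself provides one. A short sketch of how the resulting tokenizer is then applied (model name and messages are illustrative):

```python
# Illustrative: render a conversation with whatever chat template the tokenizer ends up with.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 15% of 80?"},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
print(prompt)  # the string fed to the model / vLLM during generation
```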
/src/open_r1/trl/trainer/xpo_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from dataclasses import dataclass, field
16 |
17 | from trl.trainer.online_dpo_config import OnlineDPOConfig
18 |
19 |
20 | @dataclass
21 | class XPOConfig(OnlineDPOConfig):
22 | r"""
23 | Configuration class for the [`XPOTrainer`].
24 |
25 |     Subclass of [`OnlineDPOConfig`]; we can use all its arguments and add the following:
26 |
27 | Parameters:
28 | alpha (`float` or `list[float]`, *optional*, defaults to `1e-5`):
29 | Weight of the XPO loss term. If a list of floats is provided then the alpha is selected for each new epoch
30 | and the last alpha is used for the rest of the epochs.
31 | """
32 |
33 | alpha: list[float] = field(
34 | default_factory=lambda: [1e-5],
35 | metadata={
36 | "help": "Weight of the XPO loss term. If a list of floats is provided then the alpha is selected for each "
37 | "new epoch and the last alpha is used for the rest of the epochs."
38 | },
39 | )
40 |
41 | def __post_init__(self):
42 | super().__post_init__()
43 | if hasattr(self.alpha, "__len__") and len(self.alpha) == 1:
44 | self.alpha = self.alpha[0]
45 |
--------------------------------------------------------------------------------
/recipes/dra_grpo.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments
2 | model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
3 | model_revision: main
4 | torch_dtype: bfloat16
5 | attn_implementation: flash_attention_2
6 |
7 | # Data training arguments
8 | dataset_name: knoveleng/open-rs
9 | system_prompt: "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer, and put your final answer within \\boxed{{}}. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>. Note that respond by English, NOT use other languages."
10 |
11 | # GRPO trainer config
12 | bf16: true
13 | use_vllm: true
14 | vllm_device: auto
15 | vllm_enforce_eager: true
16 | vllm_gpu_memory_utilization: 0.7
17 | vllm_max_model_len: 4608
18 | do_eval: false
19 | gradient_accumulation_steps: 4
20 | gradient_checkpointing: true
21 | gradient_checkpointing_kwargs:
22 | use_reentrant: false
23 | hub_model_id: DRA-GRPO
24 | hub_strategy: every_save
25 | learning_rate: 1.0e-06
26 | log_completions: true
27 | log_level: info
28 | logging_first_step: true
29 | logging_steps: 1
30 | logging_strategy: steps
31 | lr_scheduler_type: cosine_with_min_lr
32 | lr_scheduler_kwargs:
33 | min_lr_rate: 0.1
34 | max_prompt_length: 512
35 | max_completion_length: 3584
36 | max_steps: 500
37 | num_generations: 6 #6
38 | num_train_epochs: 1
39 | output_dir: data/DRA-GRPO
40 | overwrite_output_dir: true
41 | per_device_eval_batch_size: 6 #6
42 | per_device_train_batch_size: 4 #6
43 | push_to_hub: true
44 | report_to:
45 | - wandb
46 | reward_funcs:
47 | - format
48 | - cosine
49 | reward_weights:
50 | - 1.0
51 | - 2.0
52 | save_strategy: "steps"
53 | save_steps: 50
54 | seed: 2025
55 | temperature: 0.7
56 | warmup_ratio: 0.1
57 | SMI_reweighting: True
58 | extractor_name: jina
--------------------------------------------------------------------------------
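In TRL's GRPO trainer, the configured `reward_funcs` are scored per completion and combined via `reward_weights` into a single scalar before group-relative advantages are formed; the DRA-specific keys `SMI_reweighting` and `extractor_name` are presumably consumed by the customized trainer vendored under `src/open_r1/trl`. A tiny numeric sketch of the weighted combination and group baseline, with made-up reward values:

```python
# Made-up per-completion rewards for one prompt's group of num_generations = 6 samples.
format_rewards = [1.0, 0.0, 1.0, 1.0, 0.0, 1.0]
cosine_rewards = [0.8, 0.1, 0.5, 0.9, 0.2, 0.7]
w_format, w_cosine = 1.0, 2.0  # reward_weights from the recipe above

total = [w_format * f + w_cosine * c for f, c in zip(format_rewards, cosine_rewards)]
baseline = sum(total) / len(total)          # group mean (GRPO additionally divides by the group std)
advantages = [r - baseline for r in total]  # what the policy gradient actually sees
print([round(a, 3) for a in advantages])
```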
/recipes/dra_dr_grpo.yaml:
--------------------------------------------------------------------------------
1 | # Model arguments
2 | model_name_or_path: deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B
3 | model_revision: main
4 | torch_dtype: bfloat16
5 | attn_implementation: flash_attention_2
6 |
7 | # Data training arguments
8 | dataset_name: knoveleng/open-rs
9 | system_prompt: "A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer, and put your final answer within \\boxed{{}}. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think> <answer> answer here </answer>. Note that respond by English, NOT use other languages."
10 |
11 | # GRPO trainer config
12 | bf16: true
13 | use_vllm: true
14 | vllm_device: auto
15 | vllm_enforce_eager: true
16 | vllm_gpu_memory_utilization: 0.7
17 | vllm_max_model_len: 4608
18 | do_eval: false
19 | gradient_accumulation_steps: 4
20 | gradient_checkpointing: true
21 | gradient_checkpointing_kwargs:
22 | use_reentrant: false
23 | hub_model_id: DRA-DR_GRPO
24 | hub_strategy: every_save
25 | learning_rate: 1.0e-06
26 | log_completions: true
27 | log_level: info
28 | logging_first_step: true
29 | logging_steps: 1
30 | logging_strategy: steps
31 | lr_scheduler_type: cosine_with_min_lr
32 | lr_scheduler_kwargs:
33 | min_lr_rate: 0.1
34 | max_prompt_length: 512
35 | max_completion_length: 3584
36 | max_steps: 500
37 | num_generations: 6 #6
38 | num_train_epochs: 1
39 | output_dir: data/DRA-DR_GRPO
40 | overwrite_output_dir: true
41 | per_device_eval_batch_size: 6 #6
42 | per_device_train_batch_size: 4 #6
43 | push_to_hub: true
44 | report_to:
45 | - wandb
46 | reward_funcs:
47 | - format
48 | - cosine
49 | reward_weights:
50 | - 1.0
51 | - 2.0
52 | save_strategy: "steps"
53 | save_steps: 25
54 | seed: 42
55 | temperature: 0.7
56 | warmup_ratio: 0.1
57 | SMI_reweighting: True
58 | extractor_name: jina
--------------------------------------------------------------------------------
/src/open_r1/trl/trainer/nash_md_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from dataclasses import dataclass, field
16 |
17 | from trl.trainer.online_dpo_config import OnlineDPOConfig
18 |
19 |
20 | @dataclass
21 | class NashMDConfig(OnlineDPOConfig):
22 | r"""
23 | Configuration class for the [`NashMDTrainer`].
24 |
25 |     Subclass of [`OnlineDPOConfig`]; we can use all its arguments and add the following:
26 |
27 | Parameters:
28 | mixture_coef (`float` or `list[float]`, *optional*, defaults to `0.5`):
29 | Logit mixture coefficient for the model and reference model. If a list of floats is provided then the
30 | mixture coefficient is selected for each new epoch and the last coefficient is used for the rest of the
31 | epochs.
32 | """
33 |
34 | mixture_coef: list[float] = field(
35 | default_factory=lambda: [0.5],
36 | metadata={
37 | "help": "Logit mixture coefficient for the model and reference model. If a list of floats is provided "
38 | "then the mixture coefficient is selected for each new epoch and the last coefficient is used for the "
39 | "rest of the epochs."
40 | },
41 | )
42 |
43 | def __post_init__(self):
44 | super().__post_init__()
45 | if hasattr(self.mixture_coef, "__len__") and len(self.mixture_coef) == 1:
46 | self.mixture_coef = self.mixture_coef[0]
47 |
--------------------------------------------------------------------------------
/src/open_r1/trl/templates/lm_model_card.md:
--------------------------------------------------------------------------------
1 | ---
2 | {{ card_data }}
3 | ---
4 |
5 | # Model Card for {{ model_name }}
6 |
7 | This model is a fine-tuned version of [{{ base_model }}](https://huggingface.co/{{ base_model }}){% if dataset_name %} on the [{{ dataset_name }}](https://huggingface.co/datasets/{{ dataset_name }}) dataset{% endif %}.
8 | It has been trained using [TRL](https://github.com/huggingface/trl).
9 |
10 | ## Quick start
11 |
12 | ```python
13 | from transformers import pipeline
14 |
15 | question = "If you had a time machine, but could only go to the past or the future once and never return, which would you choose and why?"
16 | generator = pipeline("text-generation", model="{{ hub_model_id }}", device="cuda")
17 | output = generator([{"role": "user", "content": question}], max_new_tokens=128, return_full_text=False)[0]
18 | print(output["generated_text"])
19 | ```
20 |
21 | ## Training procedure
22 |
23 | {% if wandb_url %}[Visualize in Weights & Biases]({{ wandb_url }}){% endif %}
24 | {% if comet_url %}[Visualize in Comet]({{ comet_url }}){% endif %}
25 |
26 | This model was trained with {{ trainer_name }}{% if paper_id %}, a method introduced in [{{ paper_title }}](https://huggingface.co/papers/{{ paper_id }}){% endif %}.
27 |
28 | ### Framework versions
29 |
30 | - TRL: {{ trl_version }}
31 | - Transformers: {{ transformers_version }}
32 | - Pytorch: {{ pytorch_version }}
33 | - Datasets: {{ datasets_version }}
34 | - Tokenizers: {{ tokenizers_version }}
35 |
36 | ## Citations
37 |
38 | {% if trainer_citation %}Cite {{ trainer_name }} as:
39 |
40 | ```bibtex
41 | {{ trainer_citation }}
42 | ```{% endif %}
43 |
44 | Cite TRL as:
45 |
46 | ```bibtex
47 | {% raw %}@misc{vonwerra2022trl,
48 | title = {{TRL: Transformer Reinforcement Learning}},
49 | author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallouédec},
50 | year = 2020,
51 | journal = {GitHub repository},
52 | publisher = {GitHub},
53 | howpublished = {\url{https://github.com/huggingface/trl}}
54 | }{% endraw %}
55 | ```
56 |
--------------------------------------------------------------------------------
/src/open_r1/trl/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import TYPE_CHECKING
16 |
17 | from ..import_utils import OptionalDependencyNotAvailable, _LazyModule, is_diffusers_available
18 |
19 |
20 | _import_structure = {
21 | "modeling_base": ["GeometricMixtureWrapper", "PreTrainedModelWrapper", "create_reference_model"],
22 | "modeling_value_head": ["AutoModelForCausalLMWithValueHead", "AutoModelForSeq2SeqLMWithValueHead"],
23 | "utils": ["SUPPORTED_ARCHITECTURES", "prepare_deepspeed", "setup_chat_format", "unwrap_model_for_generation"],
24 | }
25 |
26 | try:
27 | if not is_diffusers_available():
28 | raise OptionalDependencyNotAvailable()
29 | except OptionalDependencyNotAvailable:
30 | pass
31 | else:
32 | _import_structure["modeling_sd_base"] = [
33 | "DDPOPipelineOutput",
34 | "DDPOSchedulerOutput",
35 | "DDPOStableDiffusionPipeline",
36 | "DefaultDDPOStableDiffusionPipeline",
37 | ]
38 |
39 | if TYPE_CHECKING:
40 | from .modeling_base import GeometricMixtureWrapper, PreTrainedModelWrapper, create_reference_model
41 | from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
42 | from .utils import SUPPORTED_ARCHITECTURES, prepare_deepspeed, setup_chat_format, unwrap_model_for_generation
43 |
44 | try:
45 | if not is_diffusers_available():
46 | raise OptionalDependencyNotAvailable()
47 | except OptionalDependencyNotAvailable:
48 | pass
49 | else:
50 | from .modeling_sd_base import (
51 | DDPOPipelineOutput,
52 | DDPOSchedulerOutput,
53 | DDPOStableDiffusionPipeline,
54 | DefaultDDPOStableDiffusionPipeline,
55 | )
56 | else:
57 | import sys
58 |
59 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
60 |
--------------------------------------------------------------------------------
/src/open_r1/utils/callbacks.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import subprocess
18 | from typing import List
19 |
20 | from transformers import TrainerCallback
21 | from transformers.trainer_callback import TrainerControl, TrainerState
22 | from transformers.training_args import TrainingArguments
23 |
24 | from .evaluation import run_benchmark_jobs
25 | from .hub import push_to_hub_revision
26 |
27 |
28 | def is_slurm_available() -> bool:
29 | # returns true if a slurm queueing system is available
30 | try:
31 | subprocess.run(["sinfo"], check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
32 | return True
33 | except FileNotFoundError:
34 | return False
35 |
36 |
37 | class DummyConfig:
38 | def __init__(self, **kwargs):
39 | for k, v in kwargs.items():
40 | setattr(self, k, v)
41 |
42 |
43 | class PushToHubRevisionCallback(TrainerCallback):
44 | def __init__(self, model_config) -> None:
45 | self.model_config = model_config
46 |
47 | def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
48 | if state.is_world_process_zero:
49 | global_step = state.global_step
50 |
51 | # WARNING: if you use dataclasses.replace(args, ...) the accelerator dist state will be broken, so I do this workaround
52 | # Also if you instantiate a new SFTConfig, the accelerator dist state will be broken
53 | dummy_config = DummyConfig(
54 | hub_model_id=args.hub_model_id,
55 | hub_model_revision=f"{args.hub_model_revision}-step-{global_step:09d}",
56 | output_dir=f"{args.output_dir}/checkpoint-{global_step}",
57 | system_prompt=args.system_prompt,
58 | )
59 |
60 | future = push_to_hub_revision(
61 | dummy_config, extra_ignore_patterns=["*.pt"]
62 | ) # don't push the optimizer states
63 |
64 | if is_slurm_available():
65 | dummy_config.benchmarks = args.benchmarks
66 |
67 | def run_benchmark_callback(_):
68 | print(f"Checkpoint {global_step} pushed to hub.")
69 | run_benchmark_jobs(dummy_config, self.model_config)
70 |
71 | future.add_done_callback(run_benchmark_callback)
72 |
73 |
74 | CALLBACKS = {
75 | "push_to_hub_revision": PushToHubRevisionCallback,
76 | }
77 |
78 |
79 | def get_callbacks(train_config, model_config) -> List[TrainerCallback]:
80 | callbacks = []
81 | for callback_name in train_config.callbacks:
82 | if callback_name not in CALLBACKS:
83 | raise ValueError(f"Callback {callback_name} not found in CALLBACKS.")
84 | callbacks.append(CALLBACKS[callback_name](model_config))
85 |
86 | return callbacks
87 |
--------------------------------------------------------------------------------
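A small usage sketch for `get_callbacks`: the recipe's `callbacks` list selects entries from `CALLBACKS` by name. The config objects below are minimal stand-ins for the real dataclasses:

```python
# Illustrative only: enable the push-to-hub callback from a config.
from types import SimpleNamespace

from open_r1.utils.callbacks import get_callbacks

train_config = SimpleNamespace(callbacks=["push_to_hub_revision"])
model_config = SimpleNamespace()  # forwarded to PushToHubRevisionCallback(model_config)
callbacks = get_callbacks(train_config, model_config)
print(callbacks)  # [PushToHubRevisionCallback instance] -> passed to the Trainer
```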
/src/open_r1/trl/scripts/env.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import platform
17 | from importlib.metadata import version
18 |
19 | import torch
20 | from accelerate.commands.config import default_config_file, load_config_from_file
21 | from transformers import is_bitsandbytes_available
22 | from transformers.utils import is_liger_kernel_available, is_openai_available, is_peft_available
23 |
24 | from .. import __version__
25 | from ..import_utils import is_deepspeed_available, is_diffusers_available, is_llm_blender_available
26 | from .utils import get_git_commit_hash
27 |
28 |
29 | def print_env():
30 | if torch.cuda.is_available():
31 | devices = [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]
32 |
33 | accelerate_config = accelerate_config_str = "not found"
34 |
35 | # Get the default from the config file.
36 | if os.path.isfile(default_config_file):
37 | accelerate_config = load_config_from_file(default_config_file).to_dict()
38 |
39 | accelerate_config_str = (
40 | "\n" + "\n".join([f" - {prop}: {val}" for prop, val in accelerate_config.items()])
41 | if isinstance(accelerate_config, dict)
42 | else accelerate_config
43 | )
44 |
45 | commit_hash = get_git_commit_hash("trl")
46 |
47 | info = {
48 | "Platform": platform.platform(),
49 | "Python version": platform.python_version(),
50 | "PyTorch version": version("torch"),
51 | "CUDA device(s)": ", ".join(devices) if torch.cuda.is_available() else "not available",
52 | "Transformers version": version("transformers"),
53 | "Accelerate version": version("accelerate"),
54 | "Accelerate config": accelerate_config_str,
55 | "Datasets version": version("datasets"),
56 | "HF Hub version": version("huggingface_hub"),
57 | "TRL version": f"{__version__}+{commit_hash[:7]}" if commit_hash else __version__,
58 | "bitsandbytes version": version("bitsandbytes") if is_bitsandbytes_available() else "not installed",
59 | "DeepSpeed version": version("deepspeed") if is_deepspeed_available() else "not installed",
60 | "Diffusers version": version("diffusers") if is_diffusers_available() else "not installed",
61 | "Liger-Kernel version": version("liger_kernel") if is_liger_kernel_available() else "not installed",
62 | "LLM-Blender version": version("llm_blender") if is_llm_blender_available() else "not installed",
63 | "OpenAI version": version("openai") if is_openai_available() else "not installed",
64 | "PEFT version": version("peft") if is_peft_available() else "not installed",
65 | }
66 |
67 | info_str = "\n".join([f"- {prop}: {val}" for prop, val in info.items()])
68 | print(f"\nCopy-paste the following information when reporting an issue:\n\n{info_str}\n") # noqa
69 |
70 |
71 | if __name__ == "__main__":
72 | print_env()
73 |
--------------------------------------------------------------------------------
/src/open_r1/trl/trainer/reward_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from dataclasses import dataclass, field
16 | from typing import Optional
17 |
18 | from transformers import TrainingArguments
19 |
20 |
21 | @dataclass
22 | class RewardConfig(TrainingArguments):
23 | r"""
24 | Configuration class for the [`RewardTrainer`].
25 |
26 | Using [`~transformers.HfArgumentParser`] we can turn this class into
27 | [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
28 | command line.
29 |
30 | Parameters:
31 | max_length (`int` or `None`, *optional*, defaults to `1024`):
32 | Maximum length of the sequences (prompt + completion) in the batch, filters out entries that exceed the
33 | limit. This argument is required if you want to use the default data collator.
34 | disable_dropout (`bool`, *optional*, defaults to `True`):
35 | Whether to disable dropout in the model.
36 | dataset_num_proc (`int`, *optional*, defaults to `None`):
37 | Number of processes to use for processing the dataset.
38 | center_rewards_coefficient (`float`, *optional*, defaults to `None`):
39 | Coefficient to incentivize the reward model to output mean-zero rewards (proposed by
40 | https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`.
41 | remove_unused_columns (`bool`, *optional*, defaults to `False`):
42 | Whether to remove the columns that are not used by the model's forward pass. Can be `True` only if
43 | the dataset is pretokenized.
44 | """
45 |
46 | max_length: Optional[int] = field(
47 | default=1024,
48 | metadata={
49 | "help": "Maximum length of the sequences (prompt + completion) in the batch, filters out entries that "
50 | "exceed the limit. This argument is required if you want to use the default data collator."
51 | },
52 | )
53 | disable_dropout: bool = field(
54 | default=True,
55 | metadata={"help": "Whether to disable dropout in the model and reference model."},
56 | )
57 | dataset_num_proc: Optional[int] = field(
58 | default=None,
59 | metadata={"help": "Number of processes to use for processing the dataset."},
60 | )
61 | center_rewards_coefficient: Optional[float] = field(
62 | default=None,
63 | metadata={
64 | "help": "Coefficient to incentivize the reward model to output mean-zero rewards (proposed by "
65 | "https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`."
66 | },
67 | )
68 | remove_unused_columns: bool = field(
69 | default=False,
70 | metadata={
71 | "help": "Whether to remove the columns that are not used by the model's forward pass. Can be `True` only "
72 | "if the dataset is pretokenized."
73 | },
74 | )
75 |
--------------------------------------------------------------------------------
/src/open_r1/configs.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2025 The HuggingFace Team. All rights reserved.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 |
16 | from dataclasses import dataclass, field
17 | from typing import Optional
18 |
19 | import trl
20 |
21 |
22 | # TODO: add the shared options with a mixin to reduce code duplication
23 | @dataclass
24 | class GRPOConfig(trl.GRPOConfig):
25 | """
26 | args for callbacks, benchmarks etc
27 | """
28 |
29 | benchmarks: list[str] = field(
30 | default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
31 | )
32 | callbacks: list[str] = field(
33 | default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
34 | )
35 | chat_template: Optional[str] = field(default=None, metadata={"help": "The chat template to use."})
36 | system_prompt: Optional[str] = field(
37 | default=None,
38 | metadata={"help": "The optional system prompt to use."},
39 | )
40 | hub_model_revision: Optional[str] = field(
41 | default="main", metadata={"help": "The Hub model branch to push the model to."}
42 | )
43 | overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
44 | push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
45 | wandb_entity: Optional[str] = field(
46 | default=None,
47 | metadata={"help": ("The entity to store runs under.")},
48 | )
49 | wandb_project: Optional[str] = field(
50 | default=None,
51 | metadata={"help": ("The project to store runs under.")},
52 | )
53 |
54 |
55 | @dataclass
56 | class SFTConfig(trl.SFTConfig):
57 | """
58 | args for callbacks, benchmarks etc
59 | """
60 |
61 | benchmarks: list[str] = field(
62 | default_factory=lambda: [], metadata={"help": "The benchmarks to run after training."}
63 | )
64 | callbacks: list[str] = field(
65 | default_factory=lambda: [], metadata={"help": "The callbacks to run during training."}
66 | )
67 | chat_template: Optional[str] = field(default=None, metadata={"help": "The chat template to use."})
68 | system_prompt: Optional[str] = field(
69 | default=None,
70 | metadata={"help": "The optional system prompt to use for benchmarking."},
71 | )
72 | hub_model_revision: Optional[str] = field(
73 | default="main",
74 | metadata={"help": "The Hub model branch to push the model to."},
75 | )
76 | overwrite_hub_revision: bool = field(default=False, metadata={"help": "Whether to overwrite the Hub revision."})
77 | push_to_hub_revision: bool = field(default=False, metadata={"help": "Whether to push to a Hub revision/branch."})
78 | wandb_entity: Optional[str] = field(
79 | default=None,
80 | metadata={"help": ("The entity to store runs under.")},
81 | )
82 | wandb_project: Optional[str] = field(
83 | default=None,
84 | metadata={"help": ("The project to store runs under.")},
85 | )
86 |
--------------------------------------------------------------------------------
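These dataclasses extend `trl.GRPOConfig` / `trl.SFTConfig` with repo-specific fields (benchmarks, callbacks, chat template, Hub revision, W&B settings). A hedged sketch of how a recipe YAML is typically parsed into them, following the usual open-r1 / TRL convention; the actual entry-point scripts are not shown in this dump, so the wiring is an assumption:

```python
# Hypothetical wiring: parse CLI arguments plus an optional `--config recipes/dra_grpo.yaml`
# YAML file into the dataclasses defined above.
from trl import ModelConfig, TrlParser

from open_r1.configs import GRPOConfig

parser = TrlParser((GRPOConfig, ModelConfig))
training_args, model_args = parser.parse_args_and_config()
print(training_args.callbacks, training_args.benchmarks, model_args.model_name_or_path)
```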
/src/open_r1/trl/models/auxiliary_modules.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
17 | import torch
18 | import torch.nn as nn
19 | import torchvision
20 | from huggingface_hub import hf_hub_download
21 | from huggingface_hub.utils import EntryNotFoundError
22 | from transformers import CLIPModel, is_torch_npu_available, is_torch_xpu_available
23 |
24 |
25 | class MLP(nn.Module):
26 | def __init__(self):
27 | super().__init__()
28 | self.layers = nn.Sequential(
29 | nn.Linear(768, 1024),
30 | nn.Dropout(0.2),
31 | nn.Linear(1024, 128),
32 | nn.Dropout(0.2),
33 | nn.Linear(128, 64),
34 | nn.Dropout(0.1),
35 | nn.Linear(64, 16),
36 | nn.Linear(16, 1),
37 | )
38 |
39 | def forward(self, embed):
40 | return self.layers(embed)
41 |
42 |
43 | class AestheticScorer(torch.nn.Module):
44 | """
45 | This model attempts to predict the aesthetic score of an image. The aesthetic score
46 | is a numerical approximation of how much a specific image is liked by humans on average.
47 | This is from https://github.com/christophschuhmann/improved-aesthetic-predictor
48 | """
49 |
50 | def __init__(self, *, dtype, model_id, model_filename):
51 | super().__init__()
52 | self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
53 | self.normalize = torchvision.transforms.Normalize(
54 | mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711]
55 | )
56 | self.target_size = 224
57 | self.mlp = MLP()
58 | try:
59 | cached_path = hf_hub_download(model_id, model_filename)
60 | except EntryNotFoundError:
61 | cached_path = os.path.join(model_id, model_filename)
62 | state_dict = torch.load(cached_path, map_location=torch.device("cpu"), weights_only=True)
63 | self.mlp.load_state_dict(state_dict)
64 | self.dtype = dtype
65 | self.eval()
66 |
67 | def __call__(self, images):
68 | device = next(self.parameters()).device
69 | images = torchvision.transforms.Resize(self.target_size)(images)
70 | images = self.normalize(images).to(self.dtype).to(device)
71 | embed = self.clip.get_image_features(pixel_values=images)
72 | # normalize embedding
73 | embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True)
74 | reward = self.mlp(embed).squeeze(1)
75 | return reward
76 |
77 |
78 | def aesthetic_scorer(hub_model_id, model_filename):
79 | scorer = AestheticScorer(
80 | model_id=hub_model_id,
81 | model_filename=model_filename,
82 | dtype=torch.float32,
83 | )
84 | if is_torch_npu_available():
85 | scorer = scorer.npu()
86 | elif is_torch_xpu_available():
87 | scorer = scorer.xpu()
88 | else:
89 | scorer = scorer.cuda()
90 |
91 | def _fn(images, prompts, metadata):
92 |         images = images.clamp(0, 1)
93 | scores = scorer(images)
94 | return scores, {}
95 |
96 | return _fn
97 |
--------------------------------------------------------------------------------
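
For reference, a minimal usage sketch of the `aesthetic_scorer` reward function defined above. This is illustrative only: it requires a CUDA/NPU/XPU device, and the Hub repo id and filename are assumptions based on common usage of the improved-aesthetic-predictor weights, not values taken from this repository.

```
import torch

# Illustrative only: the hub repo and filename below are assumptions.
reward_fn = aesthetic_scorer("trl-lib/ddpo-aesthetic-predictor", "aesthetic-model.pth")

images = torch.rand(2, 3, 256, 256)   # a batch of RGB images with values in [0, 1]
scores, extra = reward_fn(images, prompts=None, metadata=None)
print(scores.shape)                   # -> torch.Size([2]), one aesthetic score per image
```
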
/README.md:
--------------------------------------------------------------------------------
1 | # DRA-GRPO
2 | Official code for the paper *DRA-GRPO: Exploring Diversity-Aware Reward Adjustment for R1-Zero-Like Training of Large Language Models* ([arXiv:2505.09655](https://arxiv.org/abs/2505.09655))
3 |
4 | Paper link (preprint): https://arxiv.org/abs/2505.09655
5 |
6 |
7 |
8 | > **Abstract.** Recent advances in reinforcement learning for language model post-training, such as Group Relative Policy Optimization (GRPO), have shown promise in low-resource settings. However, GRPO typically relies on solution-level and scalar reward signals that fail to capture the semantic diversity among sampled completions. This leads to what we identify as a diversity-quality inconsistency, where distinct reasoning paths may receive indistinguishable rewards. To address this limitation, we propose *Diversity-aware Reward Adjustment* (DRA), a method that explicitly incorporates semantic diversity into the reward computation. DRA uses Submodular Mutual Information (SMI) to downweight redundant completions and amplify rewards for diverse ones. This encourages better exploration during learning while maintaining stable exploitation of high-quality samples. Our method integrates seamlessly with both GRPO and its variant DR. GRPO, resulting in *DRA-GRPO* and *DRA-DR. GRPO*. We evaluate our method on five mathematical reasoning benchmarks and find that it outperforms recent strong baselines, achieving state-of-the-art performance with an average accuracy of 58.2%, using only 7,000 fine-tuning samples and a total training cost of approximately $55.
9 |
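For intuition only, here is a minimal sketch of the diversity-aware reward adjustment idea. This is not the repository's implementation: it assumes a simple graph-cut-style notion of redundancy (mean pairwise cosine similarity between completion embeddings), and the function and variable names are made up for illustration.

```
import numpy as np

def dra_adjust_rewards(rewards, embeddings, alpha=0.5):
    """Illustrative sketch: scale each completion's reward by its diversity within the group."""
    # Normalize embeddings so dot products become cosine similarities.
    e = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
    sim = e @ e.T
    np.fill_diagonal(sim, 0.0)
    # Mean similarity to the rest of the group acts as a redundancy score.
    redundancy = sim.sum(axis=1) / (len(rewards) - 1)
    diversity = 1.0 - redundancy
    # Center the adjustment so the group's average reward is roughly preserved.
    weights = 1.0 + alpha * (diversity - diversity.mean())
    return rewards * weights

# Toy example: four completions with identical scalar rewards, two of them near-duplicates.
rng = np.random.default_rng(0)
embeddings = rng.normal(size=(4, 8))
embeddings[1] = embeddings[0] + 0.01
print(dra_adjust_rewards(np.ones(4), embeddings))  # the near-duplicates receive smaller adjusted rewards
```

The actual method described in the abstract uses Submodular Mutual Information over the sampled group, but the sketch captures the high-level effect: redundant completions are downweighted and diverse ones are amplified before the group-relative advantage computation.
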
10 | ## Installation
11 |
12 | Clone the repository. We use the following environment modules:
13 |
14 | ```
15 | module load anaconda3/2023.09-0
16 | module load git-lfs/3.3.0
17 | module load cuda/11.8.0
18 |
19 | ```
20 |
21 | Please follow the instructions in [Open-RS](https://github.com/knoveleng/open-rs) to install the environment.
22 | Log in to Hugging Face and Weights & Biases:
23 | ```
24 | huggingface-cli login
25 | wandb login
26 | ```
27 | Then activate the environment:
28 | ```
29 | source activate openr3
30 | ```
31 |
32 | **You can then remove the ```trl``` package from the environment, because this repository ships a customized version of it.**
33 |
34 |
35 |
36 | ## Training
37 |
38 | ### DRA-GRPO
39 | ```
40 | ACCELERATE_LOG_LEVEL=info accelerate launch \
41 | --main_process_port 11188 \
42 | --config_file recipes/accelerate_configs/zero2.yaml \
43 | --num_processes=3 \
44 | src/open_r1/grpo.py \
45 | --config recipes/dra_grpo.yaml
46 | ```
47 |
48 |
49 | ### DRA-DR. GRPO
50 | ```
51 | ACCELERATE_LOG_LEVEL=info accelerate launch \
52 | --main_process_port 18007 \
53 | --config_file recipes/accelerate_configs/zero2.yaml \
54 | --num_processes=3 \
55 | src/open_r1/drgrpo.py \
56 | --config recipes/dra_dr_grpo.yaml
57 | ```
58 |
59 | All trained weights will be pushed to the Hugging Face Hub.
60 |
61 | ## Inference via lighteval (test multiple training steps)
62 | We provide an evaluation template:
63 |
64 | ```
65 | bash eval_all.sh
66 | ```
67 |
68 | ## Checkpoints (Updated in [Huggingface Repo](https://huggingface.co/SpiceRL) )
69 | ```
70 | MODEL=xxx
71 | MODEL_NAME=$(basename "$MODEL")
72 |
73 | TASKS="math_500 amc23 minerva olympiadbench aime24"
74 |
75 | OUTPUT_DIR=data-test/evals/${MODEL_NAME}
76 | MODEL_ARGS="pretrained=$MODEL,dtype=bfloat16,max_model_length=32768,gpu_memory_utilization=0.8,generation_parameters={max_new_tokens:32768,temperature:0.6,top_p:0.95}"
77 | for TASK in $TASKS; do
78 | lighteval vllm "$MODEL_ARGS" "custom|$TASK|0|0" \
79 | --custom-tasks src/open_r1/evaluate.py \
80 | --use-chat-template \
81 | --output-dir "$OUTPUT_DIR"
82 | done
83 |
85 | ```
86 |
87 | Replace ```xxx``` with ```SpiceRL/DRA-GRPO``` or ```SpiceRL/DRA-DR.GRPO```. The evaluation only requires one GPU.
88 |
89 | ## Acknowledgements
90 |
91 | Our code is built on top of [Open-RS](https://github.com/knoveleng/open-rs). Thanks!
92 |
93 |
94 |
95 |
--------------------------------------------------------------------------------
/src/open_r1/trl/scripts/grpo.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import argparse
16 | from dataclasses import dataclass, field
17 | from typing import Optional
18 |
19 | from datasets import load_dataset
20 | from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
21 |
22 | from trl import GRPOConfig, GRPOTrainer, ModelConfig, ScriptArguments, TrlParser, get_peft_config
23 |
24 |
25 | @dataclass
26 | class GRPOScriptArguments(ScriptArguments):
27 | """
28 | Script arguments for the GRPO training script.
29 |
30 | Args:
31 | reward_model_name_or_path (`str` or `None`):
32 | Reward model id of a pretrained model hosted inside a model repo on huggingface.co or local path to a
33 | directory containing model weights saved using [`~transformers.PreTrainedModel.save_pretrained`].
34 | """
35 |
36 | reward_model_name_or_path: Optional[str] = field(
37 | default=None,
38 | metadata={
39 | "help": "Reward model id of a pretrained model hosted inside a model repo on huggingface.co or "
40 | "local path to a directory containing model weights saved using `PreTrainedModel.save_pretrained`."
41 | },
42 | )
43 |
44 |
45 | def main(script_args, training_args, model_args):
46 | # Load a pretrained model
47 | model = AutoModelForCausalLM.from_pretrained(
48 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
49 | )
50 | tokenizer = AutoTokenizer.from_pretrained(
51 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
52 | )
53 | reward_model = AutoModelForSequenceClassification.from_pretrained(
54 | script_args.reward_model_name_or_path, trust_remote_code=model_args.trust_remote_code, num_labels=1
55 | )
56 |
57 | # Load the dataset
58 | dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
59 |
60 | # Initialize the GRPO trainer
61 | trainer = GRPOTrainer(
62 | model=model,
63 | reward_funcs=reward_model,
64 | args=training_args,
65 | train_dataset=dataset[script_args.dataset_train_split],
66 | eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
67 | processing_class=tokenizer,
68 | peft_config=get_peft_config(model_args),
69 | )
70 |
71 | # Train and push the model to the Hub
72 | trainer.train()
73 |
74 | # Save and push to hub
75 | trainer.save_model(training_args.output_dir)
76 | if training_args.push_to_hub:
77 | trainer.push_to_hub(dataset_name=script_args.dataset_name)
78 |
79 |
80 | def make_parser(subparsers: argparse._SubParsersAction = None):
81 | dataclass_types = (GRPOScriptArguments, GRPOConfig, ModelConfig)
82 | if subparsers is not None:
83 | parser = subparsers.add_parser("grpo", help="Run the GRPO training script", dataclass_types=dataclass_types)
84 | else:
85 | parser = TrlParser(dataclass_types)
86 | return parser
87 |
88 |
89 | if __name__ == "__main__":
90 | parser = make_parser()
91 | script_args, training_args, model_args = parser.parse_args_and_config()
92 | main(script_args, training_args, model_args)
93 |
--------------------------------------------------------------------------------
/src/open_r1/trl/cli.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import sys
17 |
18 | from accelerate.commands.launch import launch_command, launch_command_parser
19 |
20 | from .scripts.chat import main as chat_main
21 | from .scripts.chat import make_parser as make_chat_parser
22 | from .scripts.dpo import make_parser as make_dpo_parser
23 | from .scripts.env import print_env
24 | from .scripts.grpo import make_parser as make_grpo_parser
25 | from .scripts.kto import make_parser as make_kto_parser
26 | from .scripts.sft import make_parser as make_sft_parser
27 | from .scripts.utils import TrlParser
28 |
29 |
30 | def main():
31 | parser = TrlParser(prog="TRL CLI", usage="trl", allow_abbrev=False)
32 |
33 | # Add the subparsers
34 | subparsers = parser.add_subparsers(help="available commands", dest="command", parser_class=TrlParser)
35 |
36 | # Add the subparsers for every script
37 | make_chat_parser(subparsers)
38 | make_dpo_parser(subparsers)
39 | subparsers.add_parser("env", help="Print the environment information")
40 | make_grpo_parser(subparsers)
41 | make_kto_parser(subparsers)
42 | make_sft_parser(subparsers)
43 |
44 | # Parse the arguments
45 | args = parser.parse_args()
46 |
47 | if args.command == "chat":
48 | (chat_args,) = parser.parse_args_and_config()
49 | chat_main(chat_args)
50 |
51 | if args.command == "dpo":
52 | # Get the default args for the launch command
53 | dpo_training_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts", "dpo.py")
54 | args = launch_command_parser().parse_args([dpo_training_script])
55 |
56 | # Feed the args to the launch command
57 | args.training_script_args = sys.argv[2:] # remove "trl" and "dpo"
58 | launch_command(args) # launch training
59 |
60 | elif args.command == "env":
61 | print_env()
62 |
63 | elif args.command == "grpo":
64 | # Get the default args for the launch command
65 | grpo_training_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts", "grpo.py")
66 | args = launch_command_parser().parse_args([grpo_training_script])
67 |
68 | # Feed the args to the launch command
69 | args.training_script_args = sys.argv[2:] # remove "trl" and "grpo"
70 | launch_command(args) # launch training
71 |
72 | elif args.command == "kto":
73 | # Get the default args for the launch command
74 | kto_training_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts", "kto.py")
75 | args = launch_command_parser().parse_args([kto_training_script])
76 |
77 | # Feed the args to the launch command
78 | args.training_script_args = sys.argv[2:] # remove "trl" and "kto"
79 | launch_command(args) # launch training
80 |
81 | elif args.command == "sft":
82 | # Get the default args for the launch command
83 | sft_training_script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "scripts", "sft.py")
84 | args = launch_command_parser().parse_args([sft_training_script])
85 |
86 | # Feed the args to the launch command
87 | args.training_script_args = sys.argv[2:] # remove "trl" and "sft"
88 | launch_command(args) # launch training
89 |
90 |
91 | if __name__ == "__main__":
92 | main()
93 |
--------------------------------------------------------------------------------
/src/open_r1/trl/trainer/prm_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from dataclasses import dataclass, field
16 | from typing import Optional
17 |
18 | from transformers import TrainingArguments
19 |
20 |
21 | @dataclass
22 | class PRMConfig(TrainingArguments):
23 | r"""
24 | Configuration class for the [`PRMTrainer`].
25 |
26 | Using [`~transformers.HfArgumentParser`] we can turn this class into
27 | [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
28 | command line.
29 |
30 | Parameters:
31 | learning_rate (`float`, *optional*, defaults to `1e-5`):
32 | Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
33 | [`~transformers.TrainingArguments`].
34 | max_length (`int` or `None`, *optional*, defaults to `1024`):
35 | Maximum length of the sequences (prompt + completion) used for truncation.
36 | max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
37 | Maximum length of the prompt used for truncation.
38 | max_completion_length (`int` or `None`, *optional*, defaults to `None`):
39 | Maximum length of the completion used for truncation. The completion is the concatenation of the steps.
40 | disable_dropout (`bool`, *optional*, defaults to `True`):
41 | Whether to disable dropout in the model.
42 | step_separator (`str`, *optional*, defaults to `"\n"`):
43 | Separator used to separate each step of the reasoning process.
44 | train_on_last_step_only (`bool`, *optional*, defaults to `False`):
45 | Whether to train only on the last step.
46 | dataset_num_proc (`int`, *optional*, defaults to `None`):
47 | Number of processes to use for processing the dataset.
48 | """
49 |
50 | learning_rate: float = field(
51 | default=1e-5,
52 | metadata={
53 | "help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of "
54 | "`TrainingArguments`."
55 | },
56 | )
57 | max_length: Optional[int] = field(
58 | default=1024,
59 | metadata={"help": "Maximum length of the sequences (prompt + completion) used for truncation."},
60 | )
61 | max_prompt_length: Optional[int] = field(
62 | default=512,
63 | metadata={"help": "Maximum length of the prompt used for truncation."},
64 | )
65 | max_completion_length: Optional[int] = field(
66 | default=None,
67 | metadata={
68 | "help": "Maximum length of the completion used for truncation. The completion is the concatenation of the "
69 | "steps."
70 | },
71 | )
72 | disable_dropout: bool = field(
73 | default=True,
74 | metadata={"help": "Whether to disable dropout in the model and reference model."},
75 | )
76 | step_separator: str = field(
77 | default="\n",
78 | metadata={"help": "Separator used to separate each step of the reasoning process."},
79 | )
80 | train_on_last_step_only: bool = field(
81 | default=False,
82 | metadata={"help": "Whether to train only on the last step."},
83 | )
84 | dataset_num_proc: Optional[int] = field(
85 | default=None,
86 | metadata={"help": "Number of processes to use for processing the dataset."},
87 | )
88 |
--------------------------------------------------------------------------------
/src/open_r1/trl/scripts/sft.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | # Full training
17 | python trl/scripts/sft.py \
18 | --model_name_or_path Qwen/Qwen2-0.5B \
19 | --dataset_name trl-lib/Capybara \
20 | --learning_rate 2.0e-5 \
21 | --num_train_epochs 1 \
22 | --packing \
23 | --per_device_train_batch_size 2 \
24 | --gradient_accumulation_steps 8 \
25 | --gradient_checkpointing \
26 | --logging_steps 25 \
27 | --eval_strategy steps \
28 | --eval_steps 100 \
29 | --output_dir Qwen2-0.5B-SFT \
30 | --push_to_hub
31 |
32 | # LoRA
33 | python trl/scripts/sft.py \
34 | --model_name_or_path Qwen/Qwen2-0.5B \
35 | --dataset_name trl-lib/Capybara \
36 | --learning_rate 2.0e-4 \
37 | --num_train_epochs 1 \
38 | --packing \
39 | --per_device_train_batch_size 2 \
40 | --gradient_accumulation_steps 8 \
41 | --gradient_checkpointing \
42 | --logging_steps 25 \
43 | --eval_strategy steps \
44 | --eval_steps 100 \
45 | --use_peft \
46 | --lora_r 32 \
47 | --lora_alpha 16 \
48 | --output_dir Qwen2-0.5B-SFT \
49 | --push_to_hub
50 | """
51 |
52 | import argparse
53 |
54 | from datasets import load_dataset
55 | from transformers import AutoTokenizer
56 |
57 | from trl import (
58 | ModelConfig,
59 | ScriptArguments,
60 | SFTConfig,
61 | SFTTrainer,
62 | TrlParser,
63 | get_kbit_device_map,
64 | get_peft_config,
65 | get_quantization_config,
66 | )
67 |
68 |
69 | def main(script_args, training_args, model_args):
70 | ################
71 | # Model init kwargs & Tokenizer
72 | ################
73 | quantization_config = get_quantization_config(model_args)
74 | model_kwargs = dict(
75 | revision=model_args.model_revision,
76 | trust_remote_code=model_args.trust_remote_code,
77 | attn_implementation=model_args.attn_implementation,
78 | torch_dtype=model_args.torch_dtype,
79 | use_cache=False if training_args.gradient_checkpointing else True,
80 | device_map=get_kbit_device_map() if quantization_config is not None else None,
81 | quantization_config=quantization_config,
82 | )
83 | training_args.model_init_kwargs = model_kwargs
84 | tokenizer = AutoTokenizer.from_pretrained(
85 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, use_fast=True
86 | )
87 | if tokenizer.pad_token is None:
88 | tokenizer.pad_token = tokenizer.eos_token
89 |
90 | ################
91 | # Dataset
92 | ################
93 | dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
94 |
95 | ################
96 | # Training
97 | ################
98 | trainer = SFTTrainer(
99 | model=model_args.model_name_or_path,
100 | args=training_args,
101 | train_dataset=dataset[script_args.dataset_train_split],
102 | eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
103 | processing_class=tokenizer,
104 | peft_config=get_peft_config(model_args),
105 | )
106 |
107 | trainer.train()
108 |
109 | # Save and push to hub
110 | trainer.save_model(training_args.output_dir)
111 | if training_args.push_to_hub:
112 | trainer.push_to_hub(dataset_name=script_args.dataset_name)
113 |
114 |
115 | def make_parser(subparsers: argparse._SubParsersAction = None):
116 | dataclass_types = (ScriptArguments, SFTConfig, ModelConfig)
117 | if subparsers is not None:
118 | parser = subparsers.add_parser("sft", help="Run the SFT training script", dataclass_types=dataclass_types)
119 | else:
120 | parser = TrlParser(dataclass_types)
121 | return parser
122 |
123 |
124 | if __name__ == "__main__":
125 | parser = make_parser()
126 | script_args, training_args, model_args = parser.parse_args_and_config()
127 | main(script_args, training_args, model_args)
128 |
--------------------------------------------------------------------------------
/src/open_r1/trl/scripts/kto.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | Run the KTO training script with the commands below. In general, the optimal configuration for KTO will be similar to that of DPO.
17 |
18 | # Full training:
19 | python trl/scripts/kto.py \
20 | --dataset_name trl-lib/kto-mix-14k \
21 | --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \
22 | --per_device_train_batch_size 16 \
23 | --num_train_epochs 1 \
24 | --learning_rate 5e-7 \
25 | --lr_scheduler_type=cosine \
26 | --gradient_accumulation_steps 1 \
27 | --logging_steps 10 \
28 | --eval_steps 500 \
29 | --output_dir=kto-aligned-model \
30 | --warmup_ratio 0.1 \
31 | --report_to wandb \
32 | --bf16 \
33 | --logging_first_step
34 |
35 | # QLoRA:
36 | python trl/scripts/kto.py \
37 | --dataset_name trl-lib/kto-mix-14k \
38 | --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \
39 | --per_device_train_batch_size 8 \
40 | --num_train_epochs 1 \
41 | --learning_rate 5e-7 \
42 | --lr_scheduler_type=cosine \
43 | --gradient_accumulation_steps 1 \
44 | --logging_steps 10 \
45 | --eval_steps 500 \
46 | --output_dir=kto-aligned-model-lora \
47 | --warmup_ratio 0.1 \
48 | --report_to wandb \
49 | --bf16 \
50 | --logging_first_step \
51 | --use_peft \
52 | --load_in_4bit \
53 | --lora_target_modules=all-linear \
54 | --lora_r=16 \
55 | --lora_alpha=16
56 | """
57 |
58 | import argparse
59 |
60 | from datasets import load_dataset
61 | from transformers import AutoModelForCausalLM, AutoTokenizer
62 |
63 | from trl import (
64 | KTOConfig,
65 | KTOTrainer,
66 | ModelConfig,
67 | ScriptArguments,
68 | TrlParser,
69 | get_peft_config,
70 | setup_chat_format,
71 | )
72 |
73 |
74 | def main(script_args, training_args, model_args):
75 | # Load a pretrained model
76 | model = AutoModelForCausalLM.from_pretrained(
77 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
78 | )
79 | ref_model = AutoModelForCausalLM.from_pretrained(
80 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
81 | )
82 |
83 | tokenizer = AutoTokenizer.from_pretrained(
84 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
85 | )
86 | if tokenizer.pad_token is None:
87 | tokenizer.pad_token = tokenizer.eos_token
88 |
89 | # If we are aligning a base model, we use ChatML as the default template
90 | if tokenizer.chat_template is None:
91 | model, tokenizer = setup_chat_format(model, tokenizer)
92 |
93 | # Load the dataset
94 | dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
95 |
96 | # Initialize the KTO trainer
97 | trainer = KTOTrainer(
98 | model,
99 | ref_model,
100 | args=training_args,
101 | train_dataset=dataset[script_args.dataset_train_split],
102 | eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
103 | processing_class=tokenizer,
104 | peft_config=get_peft_config(model_args),
105 | )
106 |
107 | # Train and push the model to the Hub
108 | trainer.train()
109 |
110 | # Save and push to hub
111 | trainer.save_model(training_args.output_dir)
112 | if training_args.push_to_hub:
113 | trainer.push_to_hub(dataset_name=script_args.dataset_name)
114 |
115 |
116 | def make_parser(subparsers: argparse._SubParsersAction = None):
117 | dataclass_types = (ScriptArguments, KTOConfig, ModelConfig)
118 | if subparsers is not None:
119 | parser = subparsers.add_parser("kto", help="Run the KTO training script", dataclass_types=dataclass_types)
120 | else:
121 | parser = TrlParser(dataclass_types)
122 | return parser
123 |
124 |
125 | if __name__ == "__main__":
126 | parser = make_parser()
127 | script_args, training_args, model_args = parser.parse_args_and_config()
128 | main(script_args, training_args, model_args)
129 |
--------------------------------------------------------------------------------
/src/open_r1/utils/evaluation.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | from typing import TYPE_CHECKING, Dict, Union
3 |
4 | from .hub import get_gpu_count_for_vllm, get_param_count_from_repo_id
5 |
6 |
7 | if TYPE_CHECKING:
8 | from trl import GRPOConfig, SFTConfig, ModelConfig
9 |
10 | import os
11 |
12 |
13 | # We need a special environment setup to launch vLLM from within Slurm training jobs.
14 | # - Reference code: https://github.com/huggingface/brrr/blob/c55ba3505686d690de24c7ace6487a5c1426c0fd/brrr/lighteval/one_job_runner.py#L105
15 | # - Slack thread: https://huggingface.slack.com/archives/C043JTYE1MJ/p1726566494958269
16 | user_home_directory = os.path.expanduser("~")
17 | VLLM_SLURM_PREFIX = [
18 | "env",
19 | "-i",
20 | "bash",
21 | "-c",
22 | f"for f in /etc/profile.d/*.sh; do source $f; done; export HOME={user_home_directory}; sbatch ",
23 | ]
24 |
25 |
26 | def register_lighteval_task(
27 | configs: Dict[str, str], eval_suite: str, task_name: str, task_list: str, num_fewshot: int = 0
28 | ):
29 | """Registers a LightEval task configuration.
30 |
31 | - Core tasks can be added from this table: https://github.com/huggingface/lighteval/blob/main/src/lighteval/tasks/tasks_table.jsonl
32 | - Custom tasks that require their own metrics / scripts, should be stored in scripts/evaluation/extended_lighteval_tasks
33 |
34 | Args:
35 | configs (Dict[str, str]): The dictionary to store the task configuration.
36 | eval_suite (str, optional): The evaluation suite.
37 | task_name (str): The name of the task.
38 | task_list (str): The comma-separated list of tasks in the format "extended|{task_name}|{num_fewshot}|0" or "lighteval|{task_name}|{num_fewshot}|0".
39 | num_fewshot (int, optional): The number of few-shot examples. Defaults to 0.
41 | """
42 | # Format task list in lighteval format
43 | task_list = ",".join(f"{eval_suite}|{task}|{num_fewshot}|0" for task in task_list.split(","))
44 | configs[task_name] = task_list
45 |
46 |
47 | LIGHTEVAL_TASKS = {}
48 |
49 | register_lighteval_task(LIGHTEVAL_TASKS, "custom", "math_500", "math_500", 0)
50 | register_lighteval_task(LIGHTEVAL_TASKS, "custom", "aime24", "aime24", 0)
51 | register_lighteval_task(LIGHTEVAL_TASKS, "custom", "aime25", "aime25", 0)
52 | register_lighteval_task(LIGHTEVAL_TASKS, "custom", "gpqa", "gpqa:diamond", 0)
53 | register_lighteval_task(LIGHTEVAL_TASKS, "extended", "lcb", "lcb:codegeneration", 0)
54 | register_lighteval_task(LIGHTEVAL_TASKS, "extended", "lcb_v4", "lcb:codegeneration_v4", 0)
55 |
56 |
57 | def get_lighteval_tasks():
58 | return list(LIGHTEVAL_TASKS.keys())
59 |
60 |
61 | SUPPORTED_BENCHMARKS = get_lighteval_tasks()
62 |
63 |
64 | def run_lighteval_job(
65 | benchmark: str, training_args: Union["SFTConfig", "GRPOConfig"], model_args: "ModelConfig"
66 | ) -> None:
67 | task_list = LIGHTEVAL_TASKS[benchmark]
68 | model_name = training_args.hub_model_id
69 | model_revision = training_args.hub_model_revision
70 | # For large models >= 30b params or those running the MATH benchmark, we need to shard them across the GPUs to avoid OOM
71 | num_gpus = get_gpu_count_for_vllm(model_name, model_revision)
72 | if get_param_count_from_repo_id(model_name) >= 30_000_000_000:
73 | tensor_parallel = True
74 | else:
75 | num_gpus = 8
76 | tensor_parallel = False
77 |
78 | cmd = VLLM_SLURM_PREFIX.copy()
79 | cmd_args = [
80 | f"--gres=gpu:{num_gpus}",
81 | f"--job-name=or1_{benchmark}_{model_name.split('/')[-1]}_{model_revision}",
82 | "slurm/evaluate.slurm",
83 | benchmark,
84 | f'"{task_list}"',
85 | model_name,
86 | model_revision,
87 | f"{tensor_parallel}",
88 | f"{model_args.trust_remote_code}",
89 | ]
90 | if training_args.system_prompt is not None:
91 | cmd_args.append(f"--system_prompt={training_args.system_prompt}")
92 | cmd[-1] += " " + " ".join(cmd_args)
93 | subprocess.run(cmd, check=True)
94 |
95 |
96 | def run_benchmark_jobs(training_args: Union["SFTConfig", "GRPOConfig"], model_args: "ModelConfig") -> None:
97 | benchmarks = training_args.benchmarks
98 | if len(benchmarks) == 1 and benchmarks[0] == "all":
99 | benchmarks = get_lighteval_tasks()
100 | # Evaluate on all supported benchmarks. Later we may want to include a `chat` option
101 | # that just evaluates on `ifeval` and `mt_bench` etc.
102 |
103 | for benchmark in benchmarks:
104 | print(f"Launching benchmark `{benchmark}`")
105 | if benchmark in get_lighteval_tasks():
106 | run_lighteval_job(benchmark, training_args, model_args)
107 | else:
108 | raise ValueError(f"Unknown benchmark {benchmark}")
109 |
--------------------------------------------------------------------------------
/src/open_r1/trl/extras/dataset_formatting.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import logging
16 | from typing import Callable, Literal, Optional, Union
17 |
18 | from datasets import Dataset, Value
19 | from transformers import AutoTokenizer
20 |
21 | from ..trainer.utils import ConstantLengthDataset
22 |
23 |
24 | FORMAT_MAPPING = {
25 | "chatml": [{"content": Value(dtype="string", id=None), "role": Value(dtype="string", id=None)}],
26 | "instruction": {"completion": Value(dtype="string", id=None), "prompt": Value(dtype="string", id=None)},
27 | }
28 |
29 |
30 | def conversations_formatting_function(
31 | tokenizer: AutoTokenizer, messages_field: Literal["messages", "conversations"], tools: Optional[list] = None
32 | ):
33 | r"""
34 |     Return a callable that formats a "messages"-style dataset with the tokenizer's chat template,
35 |     passing along the schema of the functions in the `tools` list when provided.
36 | """
37 |
38 | def format_dataset(examples):
39 | if isinstance(examples[messages_field][0], list):
40 | output_texts = []
41 | for i in range(len(examples[messages_field])):
42 | output_texts.append(
43 | tokenizer.apply_chat_template(examples[messages_field][i], tokenize=False, tools=tools)
44 | )
45 | return output_texts
46 | else:
47 | return tokenizer.apply_chat_template(examples[messages_field], tokenize=False, tools=tools)
48 |
49 | return format_dataset
50 |
51 |
52 | def instructions_formatting_function(tokenizer: AutoTokenizer):
53 | r"""
54 |     Return a callable that converts an "instruction" dataset (prompt/completion pairs) into
55 |     chat-formatted text using the tokenizer's chat template.
56 | """
57 |
58 | def format_dataset(examples):
59 | if isinstance(examples["prompt"], list):
60 | output_texts = []
61 | for i in range(len(examples["prompt"])):
62 | converted_sample = [
63 | {"role": "user", "content": examples["prompt"][i]},
64 | {"role": "assistant", "content": examples["completion"][i]},
65 | ]
66 | output_texts.append(tokenizer.apply_chat_template(converted_sample, tokenize=False))
67 | return output_texts
68 | else:
69 | converted_sample = [
70 | {"role": "user", "content": examples["prompt"]},
71 | {"role": "assistant", "content": examples["completion"]},
72 | ]
73 | return tokenizer.apply_chat_template(converted_sample, tokenize=False)
74 |
75 | return format_dataset
76 |
77 |
78 | def get_formatting_func_from_dataset(
79 | dataset: Union[Dataset, ConstantLengthDataset], tokenizer: AutoTokenizer, tools: Optional[list] = None
80 | ) -> Optional[Callable]:
81 | r"""
82 | Finds the correct formatting function based on the dataset structure. Currently supported datasets are:
83 | - `ChatML` with [{"role": str, "content": str}]
84 | - `instruction` with [{"prompt": str, "completion": str}]
85 |
86 | Args:
87 | dataset (Dataset): User dataset
88 | tokenizer (AutoTokenizer): Tokenizer used for formatting
89 |
90 | Returns:
91 | Callable: Formatting function if the dataset format is supported else None
92 | """
93 | if isinstance(dataset, Dataset):
94 | if "messages" in dataset.features:
95 | if dataset.features["messages"] == FORMAT_MAPPING["chatml"]:
96 | logging.info("Formatting dataset with chatml format")
97 | return conversations_formatting_function(tokenizer, "messages", tools)
98 | if "conversations" in dataset.features:
99 | if dataset.features["conversations"] == FORMAT_MAPPING["chatml"]:
100 | logging.info("Formatting dataset with chatml format")
101 | return conversations_formatting_function(tokenizer, "conversations", tools)
102 | elif dataset.features == FORMAT_MAPPING["instruction"]:
103 | logging.info("Formatting dataset with instruction format")
104 | return instructions_formatting_function(tokenizer)
105 |
106 | return None
107 |
--------------------------------------------------------------------------------
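
For reference, a minimal usage sketch of the formatting helpers above. This is illustrative only: the import path is assumed from the repository layout, and the model id is an arbitrary chat-template-equipped example, not something this repository prescribes.

```
from transformers import AutoTokenizer

# Import path assumed from the repository layout (src/open_r1/trl/extras/dataset_formatting.py).
from open_r1.trl.extras.dataset_formatting import conversations_formatting_function

# Hypothetical example: any tokenizer with a chat template works here.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
fmt = conversations_formatting_function(tokenizer, "messages")

batch = {"messages": [[
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4"},
]]}
print(fmt(batch))  # -> a list with one chat-templated training string
```
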
/src/open_r1/trl/import_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import importlib
16 | import os
17 | from itertools import chain
18 | from types import ModuleType
19 | from typing import Any
20 |
21 | from transformers.utils.import_utils import _is_package_available
22 |
23 |
24 | # Use same as transformers.utils.import_utils
25 | _deepspeed_available = _is_package_available("deepspeed")
26 | _diffusers_available = _is_package_available("diffusers")
27 | _llm_blender_available = _is_package_available("llm_blender")
28 | _mergekit_available = _is_package_available("mergekit")
29 | _rich_available = _is_package_available("rich")
30 | _unsloth_available = _is_package_available("unsloth")
31 | _vllm_available = _is_package_available("vllm")
32 |
33 |
34 | def is_deepspeed_available() -> bool:
35 | return _deepspeed_available
36 |
37 |
38 | def is_diffusers_available() -> bool:
39 | return _diffusers_available
40 |
41 |
42 | def is_llm_blender_available() -> bool:
43 | return _llm_blender_available
44 |
45 |
46 | def is_mergekit_available() -> bool:
47 | return _mergekit_available
48 |
49 |
50 | def is_rich_available() -> bool:
51 | return _rich_available
52 |
53 |
54 | def is_unsloth_available() -> bool:
55 | return _unsloth_available
56 |
57 |
58 | def is_vllm_available() -> bool:
59 | return _vllm_available
60 |
61 |
62 | class _LazyModule(ModuleType):
63 | """
64 | Module class that surfaces all objects but only performs associated imports when the objects are requested.
65 | """
66 |
67 | # Very heavily inspired by optuna.integration._IntegrationModule
68 | # https://github.com/optuna/optuna/blob/master/optuna/integration/__init__.py
69 | def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None):
70 | super().__init__(name)
71 | self._modules = set(import_structure.keys())
72 | self._class_to_module = {}
73 | for key, values in import_structure.items():
74 | for value in values:
75 | self._class_to_module[value] = key
76 | # Needed for autocompletion in an IDE
77 | self.__all__ = list(import_structure.keys()) + list(chain(*import_structure.values()))
78 | self.__file__ = module_file
79 | self.__spec__ = module_spec
80 | self.__path__ = [os.path.dirname(module_file)]
81 | self._objects = {} if extra_objects is None else extra_objects
82 | self._name = name
83 | self._import_structure = import_structure
84 |
85 | # Needed for autocompletion in an IDE
86 | def __dir__(self):
87 | result = super().__dir__()
88 | # The elements of self.__all__ that are submodules may or may not be in the dir already, depending on whether
89 | # they have been accessed or not. So we only add the elements of self.__all__ that are not already in the dir.
90 | for attr in self.__all__:
91 | if attr not in result:
92 | result.append(attr)
93 | return result
94 |
95 | def __getattr__(self, name: str) -> Any:
96 | if name in self._objects:
97 | return self._objects[name]
98 | if name in self._modules:
99 | value = self._get_module(name)
100 | elif name in self._class_to_module.keys():
101 | module = self._get_module(self._class_to_module[name])
102 | value = getattr(module, name)
103 | else:
104 | raise AttributeError(f"module {self.__name__} has no attribute {name}")
105 |
106 | setattr(self, name, value)
107 | return value
108 |
109 | def _get_module(self, module_name: str):
110 | try:
111 | return importlib.import_module("." + module_name, self.__name__)
112 | except Exception as e:
113 | raise RuntimeError(
114 | f"Failed to import {self.__name__}.{module_name} because of the following error (look up to see its"
115 | f" traceback):\n{e}"
116 | ) from e
117 |
118 | def __reduce__(self):
119 | return (self.__class__, (self._name, self.__file__, self._import_structure))
120 |
121 |
122 | class OptionalDependencyNotAvailable(BaseException):
123 | """Internally used error class for signalling an optional dependency was not found."""
124 |
--------------------------------------------------------------------------------
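
For context, a minimal sketch of how a package `__init__.py` typically wires up the `_LazyModule` class defined above (illustrative; the `_import_structure` entries are made up, and this snippet only makes sense inside a package's `__init__.py`).

```
import sys

from .import_utils import _LazyModule

# Names listed here are imported only when first accessed, keeping `import <package>` cheap.
_import_structure = {
    "core": ["set_seed"],
    "data_utils": ["apply_chat_template"],
}

# Replace this package's module object with a lazy proxy.
sys.modules[__name__] = _LazyModule(
    __name__, globals()["__file__"], _import_structure, module_spec=__spec__
)
```
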
/src/open_r1/trl/trainer/rloo_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | from dataclasses import dataclass, field
17 |
18 | from ..trainer.utils import OnPolicyConfig
19 |
20 |
21 | @dataclass
22 | class RLOOConfig(OnPolicyConfig):
23 | r"""
24 | Configuration class for the [`RLOOTrainer`].
25 |
26 | Using [`~transformers.HfArgumentParser`] we can turn this class into
27 | [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
28 | command line.
29 |
30 | Parameters:
31 | exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[: -len(".py")]`):
32 | Name of this experiment.
33 | reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
34 | Path to the reward model.
35 | num_ppo_epochs (`int`, *optional*, defaults to `4`):
36 | Number of epochs to train.
37 | whiten_rewards (`bool`, *optional*, defaults to `False`):
38 | Whether to whiten the rewards.
39 | kl_coef (`float`, *optional*, defaults to `0.05`):
40 | KL coefficient.
41 | cliprange (`float`, *optional*, defaults to `0.2`):
42 | Clip range.
43 | rloo_k (`int`, *optional*, defaults to `2`):
44 | REINFORCE Leave-One-Out (RLOO) number of online samples per prompt.
45 | normalize_reward (`bool`, *optional*, defaults to `False`):
46 | Whether to normalize rewards.
47 | reward_clip_range (`float`, *optional*, defaults to `10.0`):
48 | Clip range for rewards.
49 | normalize_advantage (`bool`, *optional*, defaults to `False`):
50 | Whether to normalize advantages.
51 |         token_level_kl (`bool`, *optional*, defaults to `False`):
52 | Whether to use token-level KL penalty or sequence-level KL penalty.
53 | ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
54 | This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
55 | improving generation speed. However, disabling this option allows training models that exceed the VRAM
56 | capacity of a single GPU, albeit at the cost of slower generation.
57 | """
58 |
59 | exp_name: str = field(
60 | default=os.path.basename(__file__)[:-3],
61 | metadata={"help": "Name of this experiment."},
62 | )
63 | reward_model_path: str = field(
64 | default="EleutherAI/pythia-160m",
65 | metadata={"help": "Path to the reward model."},
66 | )
67 | num_ppo_epochs: int = field(
68 | default=4,
69 | metadata={"help": "Number of epochs to train."},
70 | )
71 | whiten_rewards: bool = field(
72 | default=False,
73 | metadata={"help": "Whether to whiten the rewards."},
74 | )
75 | kl_coef: float = field(
76 | default=0.05,
77 | metadata={"help": "KL coefficient."},
78 | )
79 | cliprange: float = field(
80 | default=0.2,
81 | metadata={"help": "Clip range."},
82 | )
83 | rloo_k: int = field(
84 | default=2,
85 | metadata={"help": "REINFORCE Leave-One-Out (RLOO) number of online samples per prompt."},
86 | )
87 | normalize_reward: bool = field(
88 | default=False,
89 | metadata={"help": "Whether to normalize rewards"},
90 | )
91 | reward_clip_range: float = field(
92 | default=10.0,
93 | metadata={"help": "Clip range for rewards"},
94 | )
95 | normalize_advantage: bool = field(
96 | default=False,
97 | metadata={"help": "Whether to normalize advantages"},
98 | )
99 | token_level_kl: bool = field(
100 | default=False,
101 | metadata={"help": "Whether to use token-level KL penalty or sequence-level KL penalty"},
102 | )
103 | ds3_gather_for_generation: bool = field(
104 | default=True,
105 | metadata={
106 | "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for "
107 | "generation, improving generation speed. However, disabling this option allows training models that "
108 | "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation."
109 | },
110 | )
111 |
--------------------------------------------------------------------------------
/src/open_r1/trl/trainer/gkd_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from dataclasses import dataclass, field
16 | from typing import Any, Optional
17 |
18 | from .sft_config import SFTConfig
19 |
20 |
21 | @dataclass
22 | class GKDConfig(SFTConfig):
23 | """
24 | Configuration class for [`GKDTrainer`].
25 |
26 | Args:
27 | temperature (`float`, *optional*, defaults to `0.9`):
28 | Temperature for sampling. The higher the temperature, the more random the completions.
29 | lmbda (`float`, *optional*, defaults to `0.5`):
30 | Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy
31 | student-generated outputs).
32 | beta (`float`, *optional*, defaults to `0.5`):
33 | Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence loss. When
34 | beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL Divergence.
35 | max_new_tokens (`int`, *optional*, defaults to `128`):
36 | Maximum number of tokens to generate per completion.
37 | teacher_model_name_or_path (`str` or `None`, *optional*, defaults to `None`):
38 | Model name or path of the teacher model. If `None`, the teacher model will be the same as the model
39 | being trained.
40 | teacher_model_init_kwargs (`dict[str, Any]]` or `None`, *optional*, defaults to `None`):
41 | Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the teacher model
42 | from a string.
43 | disable_dropout (`bool`, *optional*, defaults to `True`):
44 | Whether to disable dropout in the model.
45 | seq_kd (`bool`, *optional*, defaults to `False`):
46 | Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised FT
47 | on teacher-generated output).
48 | """
49 |
50 | temperature: float = field(
51 | default=0.9,
52 | metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."},
53 | )
54 | lmbda: float = field(
55 | default=0.5,
56 | metadata={
57 | "help": "Lambda parameter that controls the student data fraction (i.e., the proportion of on-policy "
58 | "student-generated outputs)."
59 | },
60 | )
61 | beta: float = field(
62 | default=0.5,
63 | metadata={
64 | "help": "Interpolation coefficient between `0.0` and `1.0` of the Generalized Jensen-Shannon Divergence "
65 | "loss. When beta is `0.0`, the loss is the KL divergence. When beta is `1.0`, the loss is the Inverse KL "
66 | "Divergence."
67 | },
68 | )
69 | max_new_tokens: int = field(
70 | default=128,
71 | metadata={"help": "Maximum number of tokens to generate per completion."},
72 | )
73 | teacher_model_name_or_path: Optional[str] = field(
74 | default=None,
75 | metadata={
76 | "help": "Model name or path of the teacher model. If `None`, the teacher model will be the same as the "
77 | "model being trained."
78 | },
79 | )
80 | teacher_model_init_kwargs: Optional[dict[str, Any]] = field(
81 | default=None,
82 | metadata={
83 | "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the "
84 | "teacher model from a string."
85 | },
86 | )
87 | disable_dropout: bool = field(
88 | default=True,
89 | metadata={"help": "Whether to disable dropouts in `model`."},
90 | )
91 | seq_kd: bool = field(
92 | default=False,
93 | metadata={
94 | "help": "Seq_kd parameter that controls whether to perform Sequence-Level KD (can be viewed as supervised "
95 | "FT on teacher-generated output)."
96 | },
97 | )
98 |
99 |
100 |     SMI_reweighting: bool = field(
101 |         default=False,
102 |         metadata={
103 |             "help": "Whether to use SMI reweighting."
104 |         },
105 |     )
106 |
107 |
108 | def __post_init__(self):
109 | super().__post_init__()
110 | # check lmbda and beta are in the range [0, 1]
111 | if self.lmbda < 0.0 or self.lmbda > 1.0:
112 | raise ValueError("lmbda must be in the range [0.0, 1.0].")
113 | if self.beta < 0.0 or self.beta > 1.0:
114 | raise ValueError("beta must be in the range [0.0, 1.0].")
115 |
--------------------------------------------------------------------------------
/src/open_r1/trl/trainer/ppo_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | from dataclasses import dataclass, field
17 | from typing import Optional
18 |
19 | from ..trainer.utils import OnPolicyConfig
20 |
21 |
22 | @dataclass
23 | class PPOConfig(OnPolicyConfig):
24 | r"""
25 | Configuration class for the [`PPOTrainer`].
26 |
27 | Using [`~transformers.HfArgumentParser`] we can turn this class into
28 | [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
29 | command line.
30 |
31 | Parameters:
32 | exp_name (`str`, *optional*, defaults to `os.path.basename(__file__)[:-3]`):
33 | Name of this experiment.
34 | reward_model_path (`str`, *optional*, defaults to `"EleutherAI/pythia-160m"`):
35 | Path to the reward model.
36 | model_adapter_name (`str` or `None`, *optional*, defaults to `None`):
37 | Name of the train target PEFT adapter, when using LoRA with multiple adapters.
38 | ref_adapter_name (`str` or `None`, *optional*, defaults to `None`):
39 | Name of the reference PEFT adapter, when using LoRA with multiple adapters.
40 | num_ppo_epochs (`int`, *optional*, defaults to `4`):
41 | Number of epochs to train.
42 | whiten_rewards (`bool`, *optional*, defaults to `False`):
43 | Whether to whiten the rewards.
44 | kl_coef (`float`, *optional*, defaults to `0.05`):
45 | KL coefficient.
46 | cliprange (`float`, *optional*, defaults to `0.2`):
47 | Clip range.
48 | vf_coef (`float`, *optional*, defaults to `0.1`):
49 | Value function coefficient.
50 | cliprange_value (`float`, *optional*, defaults to `0.2`):
51 | Clip range for the value function.
52 | gamma (`float`, *optional*, defaults to `1.0`):
53 | Discount factor.
54 | lam (`float`, *optional*, defaults to `0.95`):
55 | Lambda value for GAE.
56 | ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
57 | This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
58 | improving generation speed. However, disabling this option allows training models that exceed the VRAM
59 | capacity of a single GPU, albeit at the cost of slower generation.
60 | """
61 |
62 | exp_name: str = field(
63 | default=os.path.basename(__file__)[:-3],
64 | metadata={"help": "Name of this experiment."},
65 | )
66 | reward_model_path: str = field(
67 | default="EleutherAI/pythia-160m",
68 | metadata={"help": "Path to the reward model."},
69 | )
70 | model_adapter_name: Optional[str] = field(
71 | default=None,
72 | metadata={"help": "Name of the train target PEFT adapter, when using LoRA with multiple adapters."},
73 | )
74 | ref_adapter_name: Optional[str] = field(
75 | default=None,
76 | metadata={"help": "Name of the reference PEFT adapter, when using LoRA with multiple adapters."},
77 | )
78 | num_ppo_epochs: int = field(
79 | default=4,
80 | metadata={"help": "Number of epochs to train."},
81 | )
82 | whiten_rewards: bool = field(
83 | default=False,
84 | metadata={"help": "Whether to whiten the rewards."},
85 | )
86 | kl_coef: float = field(
87 | default=0.05,
88 | metadata={"help": "KL coefficient."},
89 | )
90 | cliprange: float = field(
91 | default=0.2,
92 | metadata={"help": "Clip range."},
93 | )
94 | vf_coef: float = field(
95 | default=0.1,
96 | metadata={"help": "Value function coefficient."},
97 | )
98 | cliprange_value: float = field(
99 | default=0.2,
100 | metadata={"help": "Clip range for the value function."},
101 | )
102 | gamma: float = field(
103 | default=1.0,
104 | metadata={"help": "Discount factor."},
105 | )
106 | lam: float = field(
107 | default=0.95,
108 | metadata={"help": "Lambda value for GAE."},
109 | )
110 | ds3_gather_for_generation: bool = field(
111 | default=True,
112 | metadata={
113 | "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for "
114 | "generation, improving generation speed. However, disabling this option allows training models that "
115 | "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation."
116 | },
117 | )
118 |
--------------------------------------------------------------------------------
/src/open_r1/trl/scripts/dpo.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | # Full training
17 | python trl/scripts/dpo.py \
18 | --dataset_name trl-lib/ultrafeedback_binarized \
19 | --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
20 | --learning_rate 5.0e-7 \
21 | --num_train_epochs 1 \
22 | --per_device_train_batch_size 2 \
23 | --gradient_accumulation_steps 8 \
24 | --gradient_checkpointing \
25 | --logging_steps 25 \
26 | --eval_strategy steps \
27 | --eval_steps 50 \
28 | --output_dir Qwen2-0.5B-DPO \
29 | --no_remove_unused_columns
30 |
31 | # LoRA:
32 | python trl/scripts/dpo.py \
33 | --dataset_name trl-lib/ultrafeedback_binarized \
34 | --model_name_or_path Qwen/Qwen2-0.5B-Instruct \
35 | --learning_rate 5.0e-6 \
36 | --num_train_epochs 1 \
37 | --per_device_train_batch_size 2 \
38 | --gradient_accumulation_steps 8 \
39 | --gradient_checkpointing \
40 | --logging_steps 25 \
41 | --eval_strategy steps \
42 | --eval_steps 50 \
43 | --output_dir Qwen2-0.5B-DPO \
44 | --no_remove_unused_columns \
45 | --use_peft \
46 | --lora_r 32 \
47 | --lora_alpha 16
48 | """
49 |
50 | import argparse
51 |
52 | import torch
53 | from datasets import load_dataset
54 | from transformers import AutoModelForCausalLM, AutoTokenizer
55 |
56 | from trl import (
57 | DPOConfig,
58 | DPOTrainer,
59 | ModelConfig,
60 | ScriptArguments,
61 | TrlParser,
62 | get_kbit_device_map,
63 | get_peft_config,
64 | get_quantization_config,
65 | )
66 | from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE
67 |
68 |
69 | def main(script_args, training_args, model_args):
70 | ################
71 | # Model & Tokenizer
72 | ###################
73 | torch_dtype = (
74 | model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
75 | )
76 | quantization_config = get_quantization_config(model_args)
77 | model_kwargs = dict(
78 | revision=model_args.model_revision,
79 | attn_implementation=model_args.attn_implementation,
80 | torch_dtype=torch_dtype,
81 | use_cache=False if training_args.gradient_checkpointing else True,
82 | device_map=get_kbit_device_map() if quantization_config is not None else None,
83 | quantization_config=quantization_config,
84 | )
85 | model = AutoModelForCausalLM.from_pretrained(
86 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
87 | )
88 | peft_config = get_peft_config(model_args)
89 | if peft_config is None:
90 | ref_model = AutoModelForCausalLM.from_pretrained(
91 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code, **model_kwargs
92 | )
93 | else:
94 | ref_model = None
95 | tokenizer = AutoTokenizer.from_pretrained(
96 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
97 | )
98 | if tokenizer.pad_token is None:
99 | tokenizer.pad_token = tokenizer.eos_token
100 | if tokenizer.chat_template is None:
101 | tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE
102 | if script_args.ignore_bias_buffers:
103 | # torch distributed hack
104 | model._ddp_params_and_buffers_to_ignore = [
105 | name for name, buffer in model.named_buffers() if buffer.dtype == torch.bool
106 | ]
107 |
108 | ################
109 | # Dataset
110 | ################
111 | dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
112 |
113 | ##########
114 | # Training
115 | ################
116 | trainer = DPOTrainer(
117 | model,
118 | ref_model,
119 | args=training_args,
120 | train_dataset=dataset[script_args.dataset_train_split],
121 | eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
122 | processing_class=tokenizer,
123 | peft_config=peft_config,
124 | )
125 |
126 | trainer.train()
127 |
128 | if training_args.eval_strategy != "no":
129 | metrics = trainer.evaluate()
130 | trainer.log_metrics("eval", metrics)
131 | trainer.save_metrics("eval", metrics)
132 |
133 | # Save and push to hub
134 | trainer.save_model(training_args.output_dir)
135 | if training_args.push_to_hub:
136 | trainer.push_to_hub(dataset_name=script_args.dataset_name)
137 |
138 |
139 | def make_parser(subparsers: argparse._SubParsersAction = None):
140 | dataclass_types = (ScriptArguments, DPOConfig, ModelConfig)
141 | if subparsers is not None:
142 | parser = subparsers.add_parser("dpo", help="Run the DPO training script", dataclass_types=dataclass_types)
143 | else:
144 | parser = TrlParser(dataclass_types)
145 | return parser
146 |
147 |
148 | if __name__ == "__main__":
149 | parser = make_parser()
150 | script_args, training_args, model_args = parser.parse_args_and_config()
151 | main(script_args, training_args, model_args)
152 |
--------------------------------------------------------------------------------
/src/open_r1/utils/hub.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding=utf-8
3 | # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
4 | #
5 | # Licensed under the Apache License, Version 2.0 (the "License");
6 | # you may not use this file except in compliance with the License.
7 | # You may obtain a copy of the License at
8 | #
9 | # http://www.apache.org/licenses/LICENSE-2.0
10 | #
11 | # Unless required by applicable law or agreed to in writing, software
12 | # distributed under the License is distributed on an "AS IS" BASIS,
13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 | # See the License for the specific language governing permissions and
15 | # limitations under the License.
16 |
17 | import logging
18 | import re
19 | from concurrent.futures import Future
20 |
21 | from transformers import AutoConfig
22 |
23 | from huggingface_hub import (
24 | create_branch,
25 | create_repo,
26 | get_safetensors_metadata,
27 | list_repo_commits,
28 | list_repo_files,
29 | list_repo_refs,
30 | repo_exists,
31 | upload_folder,
32 | )
33 | from trl import GRPOConfig, SFTConfig
34 |
35 |
36 | logger = logging.getLogger(__name__)
37 |
38 |
39 | def push_to_hub_revision(training_args: SFTConfig | GRPOConfig, extra_ignore_patterns=[]) -> Future:
40 | """Pushes the model to branch on a Hub repo."""
41 |
42 | # Create a repo if it doesn't exist yet
43 | repo_url = create_repo(repo_id=training_args.hub_model_id, private=True, exist_ok=True)
44 | # Get initial commit to branch from
45 | initial_commit = list_repo_commits(training_args.hub_model_id)[-1]
46 | # Now create the branch we'll be pushing to
47 | create_branch(
48 | repo_id=training_args.hub_model_id,
49 | branch=training_args.hub_model_revision,
50 | revision=initial_commit.commit_id,
51 | exist_ok=True,
52 | )
53 | logger.info(f"Created target repo at {repo_url}")
54 | logger.info(f"Pushing to the Hub revision {training_args.hub_model_revision}...")
55 | ignore_patterns = ["checkpoint-*", "*.pth"]
56 | ignore_patterns.extend(extra_ignore_patterns)
57 | future = upload_folder(
58 | repo_id=training_args.hub_model_id,
59 | folder_path=training_args.output_dir,
60 | revision=training_args.hub_model_revision,
61 | commit_message=f"Add {training_args.hub_model_revision} checkpoint",
62 | ignore_patterns=ignore_patterns,
63 | run_as_future=True,
64 | )
65 | logger.info(f"Pushed to {repo_url} revision {training_args.hub_model_revision} successfully!")
66 |
67 | return future
68 |
69 |
70 | def check_hub_revision_exists(training_args: SFTConfig | GRPOConfig):
71 | """Checks if a given Hub revision exists."""
72 | if repo_exists(training_args.hub_model_id):
73 | if training_args.push_to_hub_revision is True:
74 | # First check if the revision exists
75 | revisions = [rev.name for rev in list_repo_refs(training_args.hub_model_id).branches]
76 | # If the revision exists, we next check it has a README file
77 | if training_args.hub_model_revision in revisions:
78 | repo_files = list_repo_files(
79 | repo_id=training_args.hub_model_id, revision=training_args.hub_model_revision
80 | )
81 | if "README.md" in repo_files and training_args.overwrite_hub_revision is False:
82 | raise ValueError(
83 | f"Revision {training_args.hub_model_revision} already exists. "
84 | "Use --overwrite_hub_revision to overwrite it."
85 | )
86 |
87 |
88 | def get_param_count_from_repo_id(repo_id: str) -> int:
89 | """Function to get model param counts from safetensors metadata or find patterns like 42m, 1.5b, 0.5m or products like 8x7b in a repo ID."""
90 | try:
91 | metadata = get_safetensors_metadata(repo_id)
92 | return list(metadata.parameter_count.values())[0]
93 | except Exception:
94 | # Pattern to match products (like 8x7b) and single values (like 42m)
95 | pattern = r"((\d+(\.\d+)?)(x(\d+(\.\d+)?))?)([bm])"
96 | matches = re.findall(pattern, repo_id.lower())
97 |
98 | param_counts = []
99 | for full_match, number1, _, _, number2, _, unit in matches:
100 | if number2: # If there's a second number, it's a product
101 | number = float(number1) * float(number2)
102 | else: # Otherwise, it's a single value
103 | number = float(number1)
104 |
105 | if unit == "b":
106 | number *= 1_000_000_000 # Convert to billion
107 | elif unit == "m":
108 | number *= 1_000_000 # Convert to million
109 |
110 | param_counts.append(number)
111 |
112 | if len(param_counts) > 0:
113 | # Return the largest number
114 | return int(max(param_counts))
115 | else:
116 | # Return -1 if no match found
117 | return -1
118 |
119 |
120 | def get_gpu_count_for_vllm(model_name: str, revision: str = "main", num_gpus: int = 8) -> int:
121 | """vLLM enforces a constraint that the number of attention heads must be divisible by the number of GPUs and 64 must be divisible by the number of GPUs.
122 | This function calculates the number of GPUs to use for decoding based on the number of attention heads in the model.
123 | """
124 | config = AutoConfig.from_pretrained(model_name, revision=revision, trust_remote_code=True)
125 | # Get number of attention heads
126 | num_heads = config.num_attention_heads
127 | # Reduce num_gpus so that num_heads is divisible by num_gpus and 64 is divisible by num_gpus
128 | while num_heads % num_gpus != 0 or 64 % num_gpus != 0:
129 | logger.info(f"Reducing num_gpus from {num_gpus} to {num_gpus - 1} to make num_heads divisible by num_gpus")
130 | num_gpus -= 1
131 | return num_gpus
132 |
--------------------------------------------------------------------------------
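For intuition, the regex fallback in `get_param_count_from_repo_id` and the divisibility loop in `get_gpu_count_for_vllm` behave roughly as follows (the repo IDs and head count below are made-up examples; the safetensors lookup is assumed to fail so the fallback runs):

    # get_param_count_from_repo_id, regex fallback:
    #   "org/llama-1.5b-sft"      -> 1_500_000_000
    #   "org/mixtral-8x7b-draft"  -> 56_000_000_000   (8 x 7b)
    #   "org/no-size-in-name"     -> -1
    #
    # get_gpu_count_for_vllm: a model with 28 attention heads requested on 8 GPUs
    # steps down 8 -> 7 -> 6 -> 5 -> 4 and returns 4, since 28 % 4 == 0 and 64 % 4 == 0.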
/src/open_r1/trl/extras/best_of_n_sampler.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Any, Callable, Optional, Union
16 |
17 | import torch
18 | from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast, set_seed
19 |
20 | from ..models import SUPPORTED_ARCHITECTURES, PreTrainedModelWrapper
21 |
22 |
23 | class BestOfNSampler:
24 | def __init__(
25 | self,
26 | model: PreTrainedModelWrapper,
27 | tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
28 | queries_to_scores: Callable[[list[str]], list[float]],
29 | length_sampler: Any,
30 | sample_size: int = 4,
31 | seed: Optional[int] = None,
32 | n_candidates: int = 1,
33 | generation_config: Optional[GenerationConfig] = None,
34 | ) -> None:
35 | r"""
36 | Initialize the sampler for best-of-n generation
37 |
38 | Args:
39 | model (`PreTrainedModelWrapper`):
40 | The pretrained model to use for generation
41 | tokenizer (`PreTrainedTokenizer` or `PreTrainedTokenizerFast`):
42 | Tokenizer associated with the pretrained model
43 | queries_to_scores (`Callable[[list[str]], list[float]]`):
44 | Callable that takes a list of generated texts and returns the associated reward scores
45 | length_sampler (`Any`):
46 | Sampler used to sample the length of the generated text
47 | sample_size (`int`):
48 | Number of samples to generate for each query
49 | seed (`int`, *optional*):
50 | Random seed used to control generation
51 | n_candidates (`int`):
52 | Number of candidates to return for each query
53 | generation_config (`GenerationConfig`, *optional*):
54 | Generation config passed to the underlying model's `generate` method.
55 | See `GenerationConfig` (https://huggingface.co/docs/transformers/v4.29.1/en/main_classes/text_generation#transformers.GenerationConfig) for more details
56 | """
57 | if seed is not None:
58 | set_seed(seed)
59 |
60 | if not isinstance(tokenizer, (PreTrainedTokenizer, PreTrainedTokenizerFast)):
61 | raise ValueError(
62 | f"tokenizer must be a PreTrainedTokenizer or PreTrainedTokenizerFast, got {type(tokenizer)}"
63 | )
64 | if not isinstance(model, (SUPPORTED_ARCHITECTURES)):
65 | raise ValueError(
66 | f"model must be a PreTrainedModelWrapper, got {type(model)} - supported architectures are: {SUPPORTED_ARCHITECTURES}"
67 | )
68 |
69 | self.model = model
70 | self.tokenizer = tokenizer
71 |
72 | self.queries_to_scores = queries_to_scores
73 | self.length_sampler = length_sampler
74 | self.gen_config = generation_config
75 | self.sample_size = sample_size
76 | self.n_candidates = n_candidates
77 |
78 | def generate(
79 | self,
80 | tokenized_query: Union[list[int], torch.Tensor, list[torch.Tensor], list[list[int]]],
81 | skip_special_tokens: bool = True,
82 | device: Optional[Union[str, torch.device]] = None,
83 | **generation_kwargs,
84 | ) -> list[list[str]]:
85 | r"""
86 | Generate the best of n samples for input queries
87 |
88 | Args:
89 | tokenized_query (`list[int]` or `torch.Tensor` or `list[torch.Tensor]` or `list[list[int]]`):
90 | represents either a single tokenized query (a single tensor or a list of integers) or a batch of tokenized queries (a list of tensors or a list of lists of integers)
91 | skip_special_tokens (`bool`):
92 | Whether to remove the special tokens from the output
93 | device (`str` or `torch.device`, *optional*):
94 | The device on which the model will be loaded
95 | **generation_kwargs (`dict`, *optional*):
96 | Additional keyword arguments passed along to the underlying model's `generate` method.
97 | This is used to override generation config
98 |
99 | Returns:
100 | list[list[str]]: A list of lists of generated texts
101 | """
102 | queries = None
103 |
104 | if isinstance(tokenized_query, torch.Tensor) and tokenized_query.ndim == 1:
105 | queries = tokenized_query.unsqueeze(0)
106 | elif isinstance(tokenized_query, list):
107 | element_type = type(tokenized_query[0])
108 | if element_type is int:
109 | queries = torch.tensor(tokenized_query).unsqueeze(0)
110 | elif element_type is torch.Tensor:
111 | queries = [tensor.reshape((1, -1)) for tensor in tokenized_query]
112 | else:
113 | queries = [torch.tensor(query).reshape((1, -1)) for query in tokenized_query]
114 |
115 | result = []
116 |
117 | for query in queries:
118 | queries = query.repeat((self.sample_size, 1))
119 | output = self.model.generate(
120 | queries.to(device),
121 | max_new_tokens=self.length_sampler(),
122 | generation_config=self.gen_config,
123 | **generation_kwargs,
124 | ).squeeze()
125 | output = self.tokenizer.batch_decode(output, skip_special_tokens=skip_special_tokens)
126 | scores = torch.tensor(self.queries_to_scores(output))
127 | output = [output[i] for i in scores.topk(self.n_candidates).indices]
128 | result.append(output)
129 |
130 | return result
131 |
--------------------------------------------------------------------------------
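A hedged usage sketch for `BestOfNSampler` (the "gpt2" checkpoint and the toy length-based reward are placeholders, not something this repository ships):

    from transformers import AutoTokenizer
    from trl import AutoModelForCausalLMWithValueHead
    from trl.core import LengthSampler
    from trl.extras import BestOfNSampler

    model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")  # placeholder policy + value head
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    def queries_to_scores(texts):
        # Toy reward: prefer longer completions; swap in a real reward model in practice.
        return [float(len(t)) for t in texts]

    sampler = BestOfNSampler(
        model,
        tokenizer,
        queries_to_scores,
        length_sampler=LengthSampler(16, 32),  # samples max_new_tokens for each generate call
        sample_size=4,                         # candidates drawn per query
        n_candidates=1,                        # best candidates kept per query
    )
    query_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids[0]
    print(sampler.generate(query_ids, device="cpu")[0][0])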
/src/open_r1/trl/trainer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import TYPE_CHECKING
16 |
17 | from ..import_utils import OptionalDependencyNotAvailable, _LazyModule, is_diffusers_available
18 |
19 |
20 | _import_structure = {
21 | "alignprop_config": ["AlignPropConfig"],
22 | "alignprop_trainer": ["AlignPropTrainer"],
23 | "bco_config": ["BCOConfig"],
24 | "bco_trainer": ["BCOTrainer"],
25 | "callbacks": [
26 | "LogCompletionsCallback",
27 | "MergeModelCallback",
28 | "RichProgressCallback",
29 | "SyncRefModelCallback",
30 | "WinRateCallback",
31 | ],
32 | "cpo_config": ["CPOConfig"],
33 | "cpo_trainer": ["CPOTrainer"],
34 | "ddpo_config": ["DDPOConfig"],
35 | "dpo_config": ["DPOConfig", "FDivergenceConstants", "FDivergenceType"],
36 | "dpo_trainer": ["DPOTrainer"],
37 | "gkd_config": ["GKDConfig"],
38 | "gkd_trainer": ["GKDTrainer"],
39 | "grpo_config": ["GRPOConfig"],
40 | "grpo_trainer": ["GRPOTrainer"],
41 | "drgrpo_trainer": ["DRGRPOTrainer"],
42 | "iterative_sft_trainer": ["IterativeSFTTrainer"],
43 | "judges": [
44 | "AllTrueJudge",
45 | "BaseBinaryJudge",
46 | "BaseJudge",
47 | "BasePairwiseJudge",
48 | "BaseRankJudge",
49 | "HfPairwiseJudge",
50 | "OpenAIPairwiseJudge",
51 | "PairRMJudge",
52 | ],
53 | "kto_config": ["KTOConfig"],
54 | "kto_trainer": ["KTOTrainer"],
55 | "model_config": ["ModelConfig"],
56 | "nash_md_config": ["NashMDConfig"],
57 | "nash_md_trainer": ["NashMDTrainer"],
58 | "online_dpo_config": ["OnlineDPOConfig"],
59 | "online_dpo_trainer": ["OnlineDPOTrainer"],
60 | "orpo_config": ["ORPOConfig"],
61 | "orpo_trainer": ["ORPOTrainer"],
62 | "ppo_config": ["PPOConfig"],
63 | "ppo_trainer": ["PPOTrainer"],
64 | "prm_config": ["PRMConfig"],
65 | "prm_trainer": ["PRMTrainer"],
66 | "reward_config": ["RewardConfig"],
67 | "reward_trainer": ["RewardTrainer"],
68 | "rloo_config": ["RLOOConfig"],
69 | "rloo_trainer": ["RLOOTrainer"],
70 | "sft_config": ["SFTConfig"],
71 | "sft_trainer": ["SFTTrainer"],
72 | "utils": [
73 | "ConstantLengthDataset",
74 | "DataCollatorForCompletionOnlyLM",
75 | "RunningMoments",
76 | "compute_accuracy",
77 | "disable_dropout_in_model",
78 | "empty_cache",
79 | "peft_module_casting_to_bf16",
80 | ],
81 | "xpo_config": ["XPOConfig"],
82 | "xpo_trainer": ["XPOTrainer"],
83 | }
84 | try:
85 | if not is_diffusers_available():
86 | raise OptionalDependencyNotAvailable()
87 | except OptionalDependencyNotAvailable:
88 | pass
89 | else:
90 | _import_structure["ddpo_trainer"] = ["DDPOTrainer"]
91 |
92 | if TYPE_CHECKING:
93 | from .alignprop_config import AlignPropConfig
94 | from .alignprop_trainer import AlignPropTrainer
95 | from .bco_config import BCOConfig
96 | from .bco_trainer import BCOTrainer
97 | from .callbacks import (
98 | LogCompletionsCallback,
99 | MergeModelCallback,
100 | RichProgressCallback,
101 | SyncRefModelCallback,
102 | WinRateCallback,
103 | )
104 | from .cpo_config import CPOConfig
105 | from .cpo_trainer import CPOTrainer
106 | from .ddpo_config import DDPOConfig
107 | from .dpo_config import DPOConfig, FDivergenceConstants, FDivergenceType
108 | from .dpo_trainer import DPOTrainer
109 | from .gkd_config import GKDConfig
110 | from .gkd_trainer import GKDTrainer
111 | from .grpo_config import GRPOConfig
112 | from .grpo_trainer import GRPOTrainer
113 | from .drgrpo_trainer import DRGRPOTrainer
114 | from .iterative_sft_trainer import IterativeSFTTrainer
115 | from .judges import (
116 | AllTrueJudge,
117 | BaseBinaryJudge,
118 | BaseJudge,
119 | BasePairwiseJudge,
120 | BaseRankJudge,
121 | HfPairwiseJudge,
122 | OpenAIPairwiseJudge,
123 | PairRMJudge,
124 | )
125 | from .kto_config import KTOConfig
126 | from .kto_trainer import KTOTrainer
127 | from .model_config import ModelConfig
128 | from .nash_md_config import NashMDConfig
129 | from .nash_md_trainer import NashMDTrainer
130 | from .online_dpo_config import OnlineDPOConfig
131 | from .online_dpo_trainer import OnlineDPOTrainer
132 | from .orpo_config import ORPOConfig
133 | from .orpo_trainer import ORPOTrainer
134 | from .ppo_config import PPOConfig
135 | from .ppo_trainer import PPOTrainer
136 | from .prm_config import PRMConfig
137 | from .prm_trainer import PRMTrainer
138 | from .reward_config import RewardConfig
139 | from .reward_trainer import RewardTrainer
140 | from .rloo_config import RLOOConfig
141 | from .rloo_trainer import RLOOTrainer
142 | from .sft_config import SFTConfig
143 | from .sft_trainer import SFTTrainer
144 | from .utils import (
145 | ConstantLengthDataset,
146 | DataCollatorForCompletionOnlyLM,
147 | RunningMoments,
148 | compute_accuracy,
149 | disable_dropout_in_model,
150 | empty_cache,
151 | peft_module_casting_to_bf16,
152 | )
153 | from .xpo_config import XPOConfig
154 | from .xpo_trainer import XPOTrainer
155 |
156 | try:
157 | if not is_diffusers_available():
158 | raise OptionalDependencyNotAvailable()
159 | except OptionalDependencyNotAvailable:
160 | pass
161 | else:
162 | from .ddpo_trainer import DDPOTrainer
163 | else:
164 | import sys
165 |
166 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
167 |
--------------------------------------------------------------------------------
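Because of the `_LazyModule` registration above, importing `trl.trainer` is cheap and each submodule is only loaded on first attribute access; a small sketch of the effect (assuming the package is importable as plain `trl`):

    import trl.trainer as trainer_pkg

    # Nothing heavy has been imported yet; this attribute access triggers the
    # import of trainer/grpo_config.py on demand.
    config = trainer_pkg.GRPOConfig(output_dir="out")
    print(type(config).__name__)  # -> "GRPOConfig"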
/src/open_r1/trl/models/sd_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | State dict utilities: utility methods for converting state dicts easily
17 | File copied from diffusers to avoid import issues and make TRL compatible
18 | with most of diffusers versions.
19 | """
20 |
21 | import enum
22 |
23 |
24 | class StateDictType(enum.Enum):
25 | """
26 | The mode to use when converting state dicts.
27 | """
28 |
29 | DIFFUSERS_OLD = "diffusers_old"
30 | PEFT = "peft"
31 |
32 |
33 | PEFT_TO_DIFFUSERS = {
34 | ".q_proj.lora_B": ".q_proj.lora_linear_layer.up",
35 | ".q_proj.lora_A": ".q_proj.lora_linear_layer.down",
36 | ".k_proj.lora_B": ".k_proj.lora_linear_layer.up",
37 | ".k_proj.lora_A": ".k_proj.lora_linear_layer.down",
38 | ".v_proj.lora_B": ".v_proj.lora_linear_layer.up",
39 | ".v_proj.lora_A": ".v_proj.lora_linear_layer.down",
40 | ".out_proj.lora_B": ".out_proj.lora_linear_layer.up",
41 | ".out_proj.lora_A": ".out_proj.lora_linear_layer.down",
42 | "to_k.lora_A": "to_k.lora.down",
43 | "to_k.lora_B": "to_k.lora.up",
44 | "to_q.lora_A": "to_q.lora.down",
45 | "to_q.lora_B": "to_q.lora.up",
46 | "to_v.lora_A": "to_v.lora.down",
47 | "to_v.lora_B": "to_v.lora.up",
48 | "to_out.0.lora_A": "to_out.0.lora.down",
49 | "to_out.0.lora_B": "to_out.0.lora.up",
50 | }
51 |
52 | DIFFUSERS_OLD_TO_DIFFUSERS = {
53 | ".to_q_lora.up": ".q_proj.lora_linear_layer.up",
54 | ".to_q_lora.down": ".q_proj.lora_linear_layer.down",
55 | ".to_k_lora.up": ".k_proj.lora_linear_layer.up",
56 | ".to_k_lora.down": ".k_proj.lora_linear_layer.down",
57 | ".to_v_lora.up": ".v_proj.lora_linear_layer.up",
58 | ".to_v_lora.down": ".v_proj.lora_linear_layer.down",
59 | ".to_out_lora.up": ".out_proj.lora_linear_layer.up",
60 | ".to_out_lora.down": ".out_proj.lora_linear_layer.down",
61 | }
62 |
63 | DIFFUSERS_STATE_DICT_MAPPINGS = {
64 | StateDictType.DIFFUSERS_OLD: DIFFUSERS_OLD_TO_DIFFUSERS,
65 | StateDictType.PEFT: PEFT_TO_DIFFUSERS,
66 | }
67 |
68 | KEYS_TO_ALWAYS_REPLACE = {
69 | ".processor.": ".",
70 | }
71 |
72 |
73 | def convert_state_dict(state_dict, mapping):
74 | r"""
75 | Simply iterates over the state dict and replaces the patterns in `mapping` with the corresponding values.
76 |
77 | Args:
78 | state_dict (`dict[str, torch.Tensor]`):
79 | The state dict to convert.
80 | mapping (`dict[str, str]`):
81 | The mapping to use for conversion, the mapping should be a dictionary with the following structure:
82 | - key: the pattern to replace
83 | - value: the pattern to replace with
84 |
85 | Returns:
86 | converted_state_dict (`dict`)
87 | The converted state dict.
88 | """
89 | converted_state_dict = {}
90 | for k, v in state_dict.items():
91 | # First, filter out the keys that we always want to replace
92 | for pattern in KEYS_TO_ALWAYS_REPLACE.keys():
93 | if pattern in k:
94 | new_pattern = KEYS_TO_ALWAYS_REPLACE[pattern]
95 | k = k.replace(pattern, new_pattern)
96 |
97 | for pattern in mapping.keys():
98 | if pattern in k:
99 | new_pattern = mapping[pattern]
100 | k = k.replace(pattern, new_pattern)
101 | break
102 | converted_state_dict[k] = v
103 | return converted_state_dict
104 |
105 |
106 | def convert_state_dict_to_diffusers(state_dict, original_type=None, **kwargs):
107 | r"""
108 | Converts a state dict to the new diffusers format. The state dict can be in the previous diffusers format
109 | (`DIFFUSERS_OLD`), in the PEFT format (`PEFT`), or already in the new diffusers format (`DIFFUSERS`). In the last
110 | case the method returns the state dict as is.
111 |
112 | The method only supports the conversion from diffusers old, PEFT to diffusers new for now.
113 |
114 | Args:
115 | state_dict (`dict[str, torch.Tensor]`):
116 | The state dict to convert.
117 | original_type (`StateDictType`, *optional*):
118 | The original type of the state dict, if not provided, the method will try to infer it automatically.
119 | kwargs (`dict`, *optional*):
120 | Additional arguments to pass to the method.
121 |
122 | - **adapter_name**: For example, in case of PEFT, some keys will be pre-pended
123 | with the adapter name and therefore need special handling. By default, PEFT also takes care of that in the
124 | `get_peft_model_state_dict` method:
125 | https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92
126 | but we add it here in case we don't want to rely on that method.
127 | """
128 | peft_adapter_name = kwargs.pop("adapter_name", None)
129 | if peft_adapter_name is not None:
130 | peft_adapter_name = "." + peft_adapter_name
131 | else:
132 | peft_adapter_name = ""
133 |
134 | if original_type is None:
135 | # Infer the original state dict type from the key naming scheme
136 | if any("to_out_lora" in k for k in state_dict.keys()):
137 | original_type = StateDictType.DIFFUSERS_OLD
138 | elif any(f".lora_A{peft_adapter_name}.weight" in k for k in state_dict.keys()):
139 | original_type = StateDictType.PEFT
140 | elif any("lora_linear_layer" in k for k in state_dict.keys()):
141 | # nothing to do
142 | return state_dict
143 | else:
144 | raise ValueError("Could not automatically infer state dict type")
145 |
146 | if original_type not in DIFFUSERS_STATE_DICT_MAPPINGS.keys():
147 | raise ValueError(f"Original type {original_type} is not supported")
148 |
149 | mapping = DIFFUSERS_STATE_DICT_MAPPINGS[original_type]
150 | return convert_state_dict(state_dict, mapping)
151 |
--------------------------------------------------------------------------------
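A tiny sketch of `convert_state_dict_to_diffusers` on a fabricated PEFT-style state dict (key names chosen to hit the `PEFT_TO_DIFFUSERS` mapping above; the tensors are dummies):

    import torch

    peft_sd = {
        "unet.to_q.lora_A.weight": torch.zeros(4, 8),
        "unet.to_q.lora_B.weight": torch.zeros(8, 4),
    }
    converted = convert_state_dict_to_diffusers(peft_sd)  # original type inferred as StateDictType.PEFT
    print(sorted(converted))
    # -> ['unet.to_q.lora.down.weight', 'unet.to_q.lora.up.weight']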
/src/open_r1/trl/core.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import gc
16 | import warnings
17 | from collections.abc import Mapping
18 | from contextlib import contextmanager
19 | from typing import Optional, Union
20 |
21 | import numpy as np
22 | import torch
23 | from transformers import is_torch_npu_available, is_torch_xpu_available
24 |
25 |
26 | def flatten_dict(nested: dict, sep: str = "/") -> dict:
27 | """Flatten dictionary and concatenate nested keys with separator."""
28 |
29 | def recurse(nest: dict, prefix: str, into: dict) -> None:
30 | for k, v in nest.items():
31 | if sep in k:
32 | raise ValueError(f"separator '{sep}' not allowed to be in key '{k}'")
33 | if isinstance(v, Mapping):
34 | recurse(v, prefix + k + sep, into)
35 | else:
36 | into[prefix + k] = v
37 |
38 | flat = {}
39 | recurse(nested, "", flat)
40 | return flat
41 |
42 |
43 | def masked_mean(values: torch.Tensor, mask: torch.Tensor, axis: Optional[int] = None) -> torch.Tensor:
44 | """Compute the mean of a tensor over positions where the mask is non-zero."""
45 | if axis is not None:
46 | return (values * mask).sum(axis=axis) / mask.sum(axis=axis)
47 | else:
48 | return (values * mask).sum() / mask.sum()
49 |
50 |
51 | def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor:
52 | """Compute variance of tensor with masked values."""
53 | mean = masked_mean(values, mask)
54 | centered_values = values - mean
55 | variance = masked_mean(centered_values**2, mask)
56 | if unbiased:
57 | mask_sum = mask.sum()
58 | if mask_sum == 0:
59 | raise ValueError(
60 | "The sum of the mask is zero, which can happen when `mini_batch_size=1`;"
61 | "try increase the `mini_batch_size` or `gradient_accumulation_steps`"
62 | )
63 | # note that if mask_sum == 1, then there is a division by zero issue
64 | # to avoid it you just need to use a larger minibatch_size
65 | bessel_correction = mask_sum / (mask_sum - 1)
66 | variance = variance * bessel_correction
67 | return variance
68 |
69 |
70 | def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True) -> torch.Tensor:
71 | """Whiten values with masked values."""
72 | mean, var = masked_mean(values, mask), masked_var(values, mask)
73 | whitened = (values - mean) * torch.rsqrt(var + 1e-8)
74 | if not shift_mean:
75 | whitened += mean
76 | return whitened
77 |
78 |
79 | class LengthSampler:
80 | """
81 | Samples a length uniformly at random from `[min_value, max_value)`.
82 | """
83 |
84 | def __init__(self, min_value: int, max_value: int):
85 | self.values = list(range(min_value, max_value))
86 |
87 | def __call__(self) -> int:
88 | return np.random.choice(self.values)
89 |
90 |
91 | class PPODecorators:
92 | optimize_device_cache = False
93 |
94 | @classmethod
95 | @contextmanager
96 | def empty_device_cache(cls):
97 | yield
98 | if cls.optimize_device_cache:
99 | if is_torch_xpu_available():
100 | gc.collect()
101 | torch.xpu.empty_cache()
102 | gc.collect()
103 | elif is_torch_npu_available():
104 | gc.collect()
105 | torch.npu.empty_cache()
106 | gc.collect()
107 | elif torch.cuda.is_available():
108 | gc.collect()
109 | torch.cuda.empty_cache()
110 | gc.collect()
111 |
112 |
113 | def randn_tensor(
114 | shape: Union[tuple, list],
115 | generator: Optional[Union[list[torch.Generator], torch.Generator]] = None,
116 | device: Optional[torch.device] = None,
117 | dtype: Optional[torch.dtype] = None,
118 | layout: Optional[torch.layout] = None,
119 | ) -> torch.Tensor:
120 | """A helper function to create random tensors on the desired `device` with the desired `dtype`. When
121 | passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor
122 | is always created on the CPU.
123 | """
124 | # the device on which the tensor is created defaults to the requested device
125 | rand_device = device
126 | batch_size = shape[0]
127 |
128 | layout = layout or torch.strided
129 | device = device or torch.device("cpu")
130 |
131 | if generator is not None:
132 | gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
133 | if gen_device_type != device.type and gen_device_type == "cpu":
134 | rand_device = "cpu"
135 | if device != "mps":
136 | warnings.warn(
137 | f"The passed generator was created on 'cpu' even though a tensor on {device} was expected."
138 | f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably"
139 | f" slighly speed up this function by passing a generator that was created on the {device} device.",
140 | UserWarning,
141 | )
142 | elif gen_device_type != device.type and gen_device_type == "cuda":
143 | raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
144 |
145 | # make sure generator list of length 1 is treated like a non-list
146 | if isinstance(generator, list) and len(generator) == 1:
147 | generator = generator[0]
148 |
149 | if isinstance(generator, list):
150 | shape = (1,) + shape[1:]
151 | latents = [
152 | torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
153 | for i in range(batch_size)
154 | ]
155 | latents = torch.cat(latents, dim=0).to(device)
156 | else:
157 | latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)
158 |
159 | return latents
160 |
--------------------------------------------------------------------------------
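A short sketch of the masked statistics helpers above on a padded batch (the numbers are arbitrary; mask entries are 1 for real tokens and 0 for padding):

    import torch

    values = torch.tensor([[1.0, 2.0, 3.0, 0.0],
                           [4.0, 5.0, 0.0, 0.0]])
    mask = torch.tensor([[1.0, 1.0, 1.0, 0.0],
                         [1.0, 1.0, 0.0, 0.0]])

    print(masked_mean(values, mask))        # (1 + 2 + 3 + 4 + 5) / 5 = 3.0
    print(masked_var(values, mask))         # unbiased variance over the 5 real entries = 2.5
    whitened = masked_whiten(values, mask)  # ~zero mean, unit variance over the real entries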
/src/open_r1/generate.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import Optional
16 |
17 | from distilabel.llms import OpenAILLM
18 | from distilabel.pipeline import Pipeline
19 | from distilabel.steps import StepResources
20 | from distilabel.steps.tasks import TextGeneration
21 |
22 |
23 | def build_distilabel_pipeline(
24 | model: str,
25 | base_url: str = "http://localhost:8000/v1",
26 | prompt_column: Optional[str] = None,
27 | prompt_template: str = "{{ instruction }}",
28 | temperature: Optional[float] = None,
29 | top_p: Optional[float] = None,
30 | max_new_tokens: int = 8192,
31 | num_generations: int = 1,
32 | input_batch_size: int = 64,
33 | client_replicas: int = 1,
34 | timeout: int = 900,
35 | retries: int = 0,
36 | ) -> Pipeline:
37 | generation_kwargs = {"max_new_tokens": max_new_tokens}
38 |
39 | if temperature is not None:
40 | generation_kwargs["temperature"] = temperature
41 |
42 | if top_p is not None:
43 | generation_kwargs["top_p"] = top_p
44 |
45 | with Pipeline().ray() as pipeline:
46 | TextGeneration(
47 | llm=OpenAILLM(
48 | base_url=base_url,
49 | api_key="something",
50 | model=model,
51 | timeout=timeout,
52 | max_retries=retries,
53 | generation_kwargs=generation_kwargs,
54 | ),
55 | template=prompt_template,
56 | input_mappings={"instruction": prompt_column} if prompt_column is not None else {},
57 | input_batch_size=input_batch_size,
58 | num_generations=num_generations,
59 | group_generations=True,
60 | resources=StepResources(replicas=client_replicas),
61 | )
62 |
63 | return pipeline
64 |
65 |
66 | if __name__ == "__main__":
67 | import argparse
68 |
69 | from datasets import load_dataset
70 |
71 | parser = argparse.ArgumentParser(description="Run distilabel pipeline for generating responses with DeepSeek R1")
72 | parser.add_argument(
73 | "--hf-dataset",
74 | type=str,
75 | required=True,
76 | help="HuggingFace dataset to load",
77 | )
78 | parser.add_argument(
79 | "--hf-dataset-config",
80 | type=str,
81 | required=False,
82 | help="Dataset config to use",
83 | )
84 | parser.add_argument(
85 | "--hf-dataset-split",
86 | type=str,
87 | default="train",
88 | help="Dataset split to use",
89 | )
90 | parser.add_argument(
91 | "--prompt-column",
92 | type=str,
93 | default="prompt",
94 | )
95 | parser.add_argument(
96 | "--prompt-template",
97 | type=str,
98 | default="{{ instruction }}",
99 | help="Template string for formatting prompts.",
100 | )
101 | parser.add_argument(
102 | "--model",
103 | type=str,
104 | required=True,
105 | help="Model name to use for generation",
106 | )
107 | parser.add_argument(
108 | "--vllm-server-url",
109 | type=str,
110 | default="http://localhost:8000/v1",
111 | help="URL of the vLLM server",
112 | )
113 | parser.add_argument(
114 | "--temperature",
115 | type=float,
116 | help="Temperature for generation",
117 | )
118 | parser.add_argument(
119 | "--top-p",
120 | type=float,
121 | help="Top-p value for generation",
122 | )
123 | parser.add_argument(
124 | "--max-new-tokens",
125 | type=int,
126 | default=8192,
127 | help="Maximum number of new tokens to generate",
128 | )
129 | parser.add_argument(
130 | "--num-generations",
131 | type=int,
132 | default=1,
133 | help="Number of generations per problem",
134 | )
135 | parser.add_argument(
136 | "--input-batch-size",
137 | type=int,
138 | default=64,
139 | help="Batch size for input processing",
140 | )
141 | parser.add_argument(
142 | "--client-replicas",
143 | type=int,
144 | default=1,
145 | help="Number of client replicas for parallel processing",
146 | )
147 | parser.add_argument(
148 | "--timeout",
149 | type=int,
150 | default=600,
151 | help="Request timeout in seconds (default: 600)",
152 | )
153 | parser.add_argument(
154 | "--retries",
155 | type=int,
156 | default=0,
157 | help="Number of retries for failed requests (default: 0)",
158 | )
159 | parser.add_argument(
160 | "--hf-output-dataset",
161 | type=str,
162 | required=False,
163 | help="HuggingFace repo to push results to",
164 | )
165 | parser.add_argument(
166 | "--private",
167 | action="store_true",
168 | help="Whether to make the output dataset private when pushing to HF Hub",
169 | )
170 |
171 | args = parser.parse_args()
172 |
173 | print("\nRunning with arguments:")
174 | for arg, value in vars(args).items():
175 | print(f" {arg}: {value}")
176 | print()
177 |
178 | print(f"Loading '{args.hf_dataset}' (config: {args.hf_dataset_config}, split: {args.hf_dataset_split}) dataset...")
179 | dataset = load_dataset(args.hf_dataset, args.hf_dataset_config, split=args.hf_dataset_split)
180 | print("Dataset loaded!")
181 |
182 | pipeline = build_distilabel_pipeline(
183 | model=args.model,
184 | base_url=args.vllm_server_url,
185 | prompt_template=args.prompt_template,
186 | prompt_column=args.prompt_column,
187 | temperature=args.temperature,
188 | top_p=args.top_p,
189 | max_new_tokens=args.max_new_tokens,
190 | num_generations=args.num_generations,
191 | input_batch_size=args.input_batch_size,
192 | client_replicas=args.client_replicas,
193 | timeout=args.timeout,
194 | retries=args.retries,
195 | )
196 |
197 | print("Running generation pipeline...")
198 | distiset = pipeline.run(
199 | dataset=dataset,
200 | dataset_batch_size=args.input_batch_size * 1000,
201 | use_cache=False,
202 | )
203 | print("Generation pipeline finished!")
204 |
205 | if args.hf_output_dataset:
206 | print(f"Pushing resulting dataset to '{args.hf_output_dataset}'...")
207 | distiset.push_to_hub(args.hf_output_dataset, private=args.private)
208 | print("Dataset pushed!")
209 |
--------------------------------------------------------------------------------
/src/open_r1/trl/trainer/orpo_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from dataclasses import dataclass, field
16 | from typing import Any, Optional
17 |
18 | from transformers import TrainingArguments
19 |
20 |
21 | @dataclass
22 | class ORPOConfig(TrainingArguments):
23 | r"""
24 | Configuration class for the [`ORPOTrainer`].
25 |
26 | Using [`~transformers.HfArgumentParser`] we can turn this class into
27 | [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
28 | command line.
29 |
30 | Parameters:
31 | learning_rate (`float`, *optional*, defaults to `1e-6`):
32 | Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
33 | [`~transformers.TrainingArguments`].
34 | max_length (`int` or `None`, *optional*, defaults to `1024`):
35 | Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
36 | to use the default data collator.
37 | max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
38 | Maximum length of the prompt. This argument is required if you want to use the default data collator.
39 | max_completion_length (`int` or `None`, *optional*, defaults to `None`):
40 | Maximum length of the completion. This argument is required if you want to use the default data collator
41 | and your model is an encoder-decoder.
42 | beta (`float`, *optional*, defaults to `0.1`):
43 | Parameter controlling the relative ratio loss weight in the ORPO loss. In the [paper](https://huggingface.co/papers/2403.07691),
44 | it is denoted by λ. In the [code](https://github.com/xfactlab/orpo), it is denoted by `alpha`.
45 | disable_dropout (`bool`, *optional*, defaults to `True`):
46 | Whether to disable dropout in the model.
47 | label_pad_token_id (`int`, *optional*, defaults to `-100`):
48 | Label pad token id. This argument is required if you want to use the default data collator.
49 | padding_value (`int` or `None`, *optional*, defaults to `None`):
50 | Padding value to use. If `None`, the padding value of the tokenizer is used.
51 | truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
52 | Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
53 | This argument is required if you want to use the default data collator.
54 | generate_during_eval (`bool`, *optional*, defaults to `False`):
55 | If `True`, generates and logs completions from the model to W&B or Comet during evaluation.
56 | is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
57 | When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
58 | you need to specify if the model returned by the callable is an encoder-decoder model.
59 | model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
60 | Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
61 | string.
62 | dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
63 | Number of processes to use for processing the dataset.
64 | """
65 |
66 | learning_rate: float = field(
67 | default=1e-6,
68 | metadata={
69 | "help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of "
70 | "transformers.TrainingArguments."
71 | },
72 | )
73 | max_length: Optional[int] = field(
74 | default=1024,
75 | metadata={"help": "Maximum length of the sequences (prompt + completion) in the batch."},
76 | )
77 | max_prompt_length: Optional[int] = field(
78 | default=512,
79 | metadata={
80 | "help": "Maximum length of the prompt. This argument is required if you want to use the default data "
81 | "collator and your model is an encoder-decoder."
82 | },
83 | )
84 | max_completion_length: Optional[int] = field(
85 | default=None,
86 | metadata={
87 | "help": "Maximum length of the completion. This argument is required if you want to use the default data "
88 | "collator and your model is an encoder-decoder."
89 | },
90 | )
91 | beta: float = field(
92 | default=0.1,
93 | metadata={
94 | "help": "Parameter controlling the relative ratio loss weight in the ORPO loss. In the paper, it is "
95 | "denoted by λ."
96 | },
97 | )
98 | disable_dropout: bool = field(
99 | default=True,
100 | metadata={"help": "Whether to disable dropout in the model."},
101 | )
102 | label_pad_token_id: int = field(
103 | default=-100,
104 | metadata={
105 | "help": "Label pad token id. This argument is required if you want to use the default data collator."
106 | },
107 | )
108 | padding_value: Optional[int] = field(
109 | default=None,
110 | metadata={"help": "Padding value to use. If `None`, the padding value of the tokenizer is used."},
111 | )
112 | truncation_mode: str = field(
113 | default="keep_end",
114 | metadata={
115 | "help": "Truncation mode to use when the prompt is too long.",
116 | "choices": ["keep_end", "keep_start"],
117 | },
118 | )
119 | generate_during_eval: bool = field(
120 | default=False,
121 | metadata={"help": "If `True`, generates and logs completions from the model to W&B during evaluation."},
122 | )
123 | is_encoder_decoder: Optional[bool] = field(
124 | default=None,
125 | metadata={
126 | "help": "When using the `model_init` argument (callable) to instantiate the model instead of the `model` "
127 | "argument, you need to specify if the model returned by the callable is an encoder-decoder model."
128 | },
129 | )
130 | model_init_kwargs: Optional[dict[str, Any]] = field(
131 | default=None,
132 | metadata={
133 | "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model "
134 | "from a string."
135 | },
136 | )
137 | dataset_num_proc: Optional[int] = field(
138 | default=None,
139 | metadata={"help": "Number of processes to use for processing the dataset."},
140 | )
141 |
--------------------------------------------------------------------------------
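As the class docstring notes, the dataclass can be turned into command-line arguments with `HfArgumentParser`; a minimal sketch (the script name in the comment is hypothetical):

    from transformers import HfArgumentParser
    from trl import ORPOConfig

    parser = HfArgumentParser(ORPOConfig)
    # e.g. python train_orpo.py --output_dir out --beta 0.2 --max_length 2048
    (orpo_config,) = parser.parse_args_into_dataclasses()
    print(orpo_config.beta, orpo_config.max_length)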
/src/open_r1/sft.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | Supervised fine-tuning script for decoder language models.
17 |
18 | Usage:
19 |
20 | # On 1 node of 8 x H100s
21 | accelerate launch --config_file=recipes/accelerate_configs/zero3.yaml src/open_r1/sft.py \
22 | --model_name_or_path Qwen/Qwen2.5-1.5B-Instruct \
23 | --dataset_name HuggingFaceH4/Bespoke-Stratos-17k \
24 | --learning_rate 2.0e-5 \
25 | --num_train_epochs 1 \
26 | --packing \
27 | --max_seq_length 4096 \
28 | --per_device_train_batch_size 2 \
29 | --gradient_accumulation_steps 8 \
30 | --gradient_checkpointing \
31 | --bf16 \
32 | --logging_steps 5 \
33 | --eval_strategy steps \
34 | --eval_steps 100 \
35 | --output_dir data/Qwen2.5-1.5B-Open-R1-Distill
36 | """
37 |
38 | import logging
39 | import os
40 | import sys
41 |
42 | import datasets
43 | import torch
44 | import transformers
45 | from datasets import load_dataset
46 | from transformers import set_seed
47 | from transformers.trainer_utils import get_last_checkpoint
48 |
49 | from open_r1.configs import SFTConfig
50 | from open_r1.utils import get_tokenizer
51 | from open_r1.utils.callbacks import get_callbacks
52 | from open_r1.utils.wandb_logging import init_wandb_training
53 | from trl import (
54 | ModelConfig,
55 | ScriptArguments,
56 | SFTTrainer,
57 | TrlParser,
58 | get_kbit_device_map,
59 | get_peft_config,
60 | get_quantization_config,
61 | )
62 |
63 |
64 | logger = logging.getLogger(__name__)
65 |
66 |
67 | def main(script_args, training_args, model_args):
68 | # Set seed for reproducibility
69 | set_seed(training_args.seed)
70 |
71 | ###############
72 | # Setup logging
73 | ###############
74 | logging.basicConfig(
75 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
76 | datefmt="%Y-%m-%d %H:%M:%S",
77 | handlers=[logging.StreamHandler(sys.stdout)],
78 | )
79 | log_level = training_args.get_process_log_level()
80 | logger.setLevel(log_level)
81 | datasets.utils.logging.set_verbosity(log_level)
82 | transformers.utils.logging.set_verbosity(log_level)
83 | transformers.utils.logging.enable_default_handler()
84 | transformers.utils.logging.enable_explicit_format()
85 |
86 | logger.info(f"Model parameters {model_args}")
87 | logger.info(f"Script parameters {script_args}")
88 | logger.info(f"Training parameters {training_args}")
89 |
90 | # Check for last checkpoint
91 | last_checkpoint = None
92 | if os.path.isdir(training_args.output_dir):
93 | last_checkpoint = get_last_checkpoint(training_args.output_dir)
94 | if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
95 | logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")
96 |
97 | if "wandb" in training_args.report_to:
98 | init_wandb_training(training_args)
99 |
100 | ################
101 | # Load datasets
102 | ################
103 | dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
104 |
105 | ################
106 | # Load tokenizer
107 | ################
108 | tokenizer = get_tokenizer(model_args, training_args)
109 | tokenizer.pad_token = tokenizer.eos_token
110 |
111 | ###################
112 | # Model init kwargs
113 | ###################
114 | logger.info("*** Initializing model kwargs ***")
115 | torch_dtype = (
116 | model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
117 | )
118 | quantization_config = get_quantization_config(model_args)
119 | model_kwargs = dict(
120 | revision=model_args.model_revision,
121 | trust_remote_code=model_args.trust_remote_code,
122 | attn_implementation=model_args.attn_implementation,
123 | torch_dtype=torch_dtype,
124 | use_cache=False if training_args.gradient_checkpointing else True,
125 | device_map=get_kbit_device_map() if quantization_config is not None else None,
126 | quantization_config=quantization_config,
127 | )
128 | training_args.model_init_kwargs = model_kwargs
129 |
130 | ############################
131 | # Initialize the SFT Trainer
132 | ############################
133 | trainer = SFTTrainer(
134 | model=model_args.model_name_or_path,
135 | args=training_args,
136 | train_dataset=dataset[script_args.dataset_train_split],
137 | eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,
138 | processing_class=tokenizer,
139 | peft_config=get_peft_config(model_args),
140 | callbacks=get_callbacks(training_args, model_args),
141 | )
142 |
143 | ###############
144 | # Training loop
145 | ###############
146 | logger.info("*** Train ***")
147 | checkpoint = None
148 | if training_args.resume_from_checkpoint is not None:
149 | checkpoint = training_args.resume_from_checkpoint
150 | elif last_checkpoint is not None:
151 | checkpoint = last_checkpoint
152 | train_result = trainer.train(resume_from_checkpoint=checkpoint)
153 | metrics = train_result.metrics
154 | metrics["train_samples"] = len(dataset[script_args.dataset_train_split])
155 | trainer.log_metrics("train", metrics)
156 | trainer.save_metrics("train", metrics)
157 | trainer.save_state()
158 |
159 | ##################################
160 | # Save model and create model card
161 | ##################################
162 | logger.info("*** Save model ***")
163 | trainer.save_model(training_args.output_dir)
164 | logger.info(f"Model saved to {training_args.output_dir}")
165 |
166 | # Save everything else on main process
167 | kwargs = {
168 | "dataset_name": script_args.dataset_name,
169 | "tags": ["open-r1"],
170 | }
171 | if trainer.accelerator.is_main_process:
172 | trainer.create_model_card(**kwargs)
173 | # Restore k,v cache for fast inference
174 | trainer.model.config.use_cache = True
175 | trainer.model.config.save_pretrained(training_args.output_dir)
176 |
177 | ##########
178 | # Evaluate
179 | ##########
180 | if training_args.do_eval:
181 | logger.info("*** Evaluate ***")
182 | metrics = trainer.evaluate()
183 | metrics["eval_samples"] = len(dataset[script_args.dataset_test_split])
184 | trainer.log_metrics("eval", metrics)
185 | trainer.save_metrics("eval", metrics)
186 |
187 | #############
188 | # push to hub
189 | #############
190 | if training_args.push_to_hub:
191 | logger.info("Pushing to hub...")
192 | trainer.push_to_hub(**kwargs)
193 |
194 |
195 | if __name__ == "__main__":
196 | parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
197 | script_args, training_args, model_args = parser.parse_args_and_config()
198 | main(script_args, training_args, model_args)
199 |
--------------------------------------------------------------------------------
/src/open_r1/trl/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | __version__ = "0.15.2"
16 |
17 | from typing import TYPE_CHECKING
18 |
19 | from .import_utils import OptionalDependencyNotAvailable, _LazyModule, is_diffusers_available
20 |
21 |
22 | _import_structure = {
23 | "scripts": ["init_zero_verbose", "ScriptArguments", "TrlParser"],
24 | "data_utils": [
25 | "apply_chat_template",
26 | "extract_prompt",
27 | "is_conversational",
28 | "maybe_apply_chat_template",
29 | "maybe_convert_to_chatml",
30 | "maybe_extract_prompt",
31 | "maybe_unpair_preference_dataset",
32 | "pack_examples",
33 | "unpair_preference_dataset",
34 | ],
35 | "environment": ["TextEnvironment", "TextHistory"],
36 | "extras": ["BestOfNSampler"],
37 | "import_utils": [
38 | "is_deepspeed_available",
39 | "is_diffusers_available",
40 | "is_llm_blender_available",
41 | "is_mergekit_available",
42 | "is_rich_available",
43 | "is_unsloth_available",
44 | "is_vllm_available",
45 | ],
46 | "models": [
47 | "SUPPORTED_ARCHITECTURES",
48 | "AutoModelForCausalLMWithValueHead",
49 | "AutoModelForSeq2SeqLMWithValueHead",
50 | "PreTrainedModelWrapper",
51 | "create_reference_model",
52 | "setup_chat_format",
53 | ],
54 | "trainer": [
55 | "AlignPropConfig",
56 | "AlignPropTrainer",
57 | "AllTrueJudge",
58 | "BaseBinaryJudge",
59 | "BaseJudge",
60 | "BasePairwiseJudge",
61 | "BaseRankJudge",
62 | "BCOConfig",
63 | "BCOTrainer",
64 | "CPOConfig",
65 | "CPOTrainer",
66 | "DataCollatorForCompletionOnlyLM",
67 | "DPOConfig",
68 | "DPOTrainer",
69 | "FDivergenceConstants",
70 | "FDivergenceType",
71 | "GKDConfig",
72 | "GKDTrainer",
73 | "GRPOConfig",
74 | "GRPOTrainer",
75 | "DRGRPOTrainer",
76 | "HfPairwiseJudge",
77 | "IterativeSFTTrainer",
78 | "KTOConfig",
79 | "KTOTrainer",
80 | "LogCompletionsCallback",
81 | "MergeModelCallback",
82 | "ModelConfig",
83 | "NashMDConfig",
84 | "NashMDTrainer",
85 | "OnlineDPOConfig",
86 | "OnlineDPOTrainer",
87 | "OpenAIPairwiseJudge",
88 | "ORPOConfig",
89 | "ORPOTrainer",
90 | "PairRMJudge",
91 | "PPOConfig",
92 | "PPOTrainer",
93 | "PRMConfig",
94 | "PRMTrainer",
95 | "RewardConfig",
96 | "RewardTrainer",
97 | "RLOOConfig",
98 | "RLOOTrainer",
99 | "SFTConfig",
100 | "SFTTrainer",
101 | "WinRateCallback",
102 | "XPOConfig",
103 | "XPOTrainer",
104 | ],
105 | "trainer.callbacks": ["MergeModelCallback", "RichProgressCallback", "SyncRefModelCallback"],
106 | "trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config"],
107 | }
108 |
109 | try:
110 | if not is_diffusers_available():
111 | raise OptionalDependencyNotAvailable()
112 | except OptionalDependencyNotAvailable:
113 | pass
114 | else:
115 | _import_structure["models"].extend(
116 | [
117 | "DDPOPipelineOutput",
118 | "DDPOSchedulerOutput",
119 | "DDPOStableDiffusionPipeline",
120 | "DefaultDDPOStableDiffusionPipeline",
121 | ]
122 | )
123 | _import_structure["trainer"].extend(["DDPOConfig", "DDPOTrainer"])
124 |
125 | if TYPE_CHECKING:
126 | from .data_utils import (
127 | apply_chat_template,
128 | extract_prompt,
129 | is_conversational,
130 | maybe_apply_chat_template,
131 | maybe_convert_to_chatml,
132 | maybe_extract_prompt,
133 | maybe_unpair_preference_dataset,
134 | pack_examples,
135 | unpair_preference_dataset,
136 | )
137 | from .environment import TextEnvironment, TextHistory
138 | from .extras import BestOfNSampler
139 | from .import_utils import (
140 | is_deepspeed_available,
141 | is_diffusers_available,
142 | is_llm_blender_available,
143 | is_mergekit_available,
144 | is_rich_available,
145 | is_unsloth_available,
146 | is_vllm_available,
147 | )
148 | from .models import (
149 | SUPPORTED_ARCHITECTURES,
150 | AutoModelForCausalLMWithValueHead,
151 | AutoModelForSeq2SeqLMWithValueHead,
152 | PreTrainedModelWrapper,
153 | create_reference_model,
154 | setup_chat_format,
155 | )
156 | from .scripts import ScriptArguments, TrlParser, init_zero_verbose
157 | from .trainer import (
158 | AlignPropConfig,
159 | AlignPropTrainer,
160 | AllTrueJudge,
161 | BaseBinaryJudge,
162 | BaseJudge,
163 | BasePairwiseJudge,
164 | BaseRankJudge,
165 | BCOConfig,
166 | BCOTrainer,
167 | CPOConfig,
168 | CPOTrainer,
169 | DataCollatorForCompletionOnlyLM,
170 | DPOConfig,
171 | DPOTrainer,
172 | FDivergenceConstants,
173 | FDivergenceType,
174 | GKDConfig,
175 | GKDTrainer,
176 | GRPOConfig,
177 | GRPOTrainer,
178 | DRGRPOTrainer,
179 | HfPairwiseJudge,
180 | IterativeSFTTrainer,
181 | KTOConfig,
182 | KTOTrainer,
183 | LogCompletionsCallback,
184 | MergeModelCallback,
185 | ModelConfig,
186 | NashMDConfig,
187 | NashMDTrainer,
188 | OnlineDPOConfig,
189 | OnlineDPOTrainer,
190 | OpenAIPairwiseJudge,
191 | ORPOConfig,
192 | ORPOTrainer,
193 | PairRMJudge,
194 | PPOConfig,
195 | PPOTrainer,
196 | PRMConfig,
197 | PRMTrainer,
198 | RewardConfig,
199 | RewardTrainer,
200 | RLOOConfig,
201 | RLOOTrainer,
202 | SFTConfig,
203 | SFTTrainer,
204 | WinRateCallback,
205 | XPOConfig,
206 | XPOTrainer,
207 | )
208 | from .trainer.callbacks import RichProgressCallback, SyncRefModelCallback
209 | from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config
210 |
211 | try:
212 | if not is_diffusers_available():
213 | raise OptionalDependencyNotAvailable()
214 | except OptionalDependencyNotAvailable:
215 | pass
216 | else:
217 | from .models import (
218 | DDPOPipelineOutput,
219 | DDPOSchedulerOutput,
220 | DDPOStableDiffusionPipeline,
221 | DefaultDDPOStableDiffusionPipeline,
222 | )
223 | from .trainer import DDPOConfig, DDPOTrainer
224 |
225 | else:
226 | import sys
227 |
228 | sys.modules[__name__] = _LazyModule(
229 | __name__,
230 | globals()["__file__"],
231 | _import_structure,
232 | module_spec=__spec__,
233 | extra_objects={"__version__": __version__},
234 | )
235 |
--------------------------------------------------------------------------------
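The `_LazyModule` indirection above keeps `import open_r1.trl` cheap: a submodule listed in `_import_structure` is only imported when one of its exported names is first accessed. The snippet below is an illustrative, simplified sketch of that idea, not the actual `_LazyModule` from `import_utils`.

import importlib
import types


class TinyLazyModule(types.ModuleType):
    """Minimal illustration of lazy attribute-to-submodule resolution."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # Map each exported attribute to the submodule that defines it.
        self._attr_to_module = {
            attr: submodule for submodule, attrs in import_structure.items() for attr in attrs
        }

    def __getattr__(self, attr):
        submodule = self._attr_to_module.get(attr)
        if submodule is None:
            raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
        # Import the real submodule only on first access, then cache the resolved attribute.
        value = getattr(importlib.import_module(f"{self.__name__}.{submodule}"), attr)
        setattr(self, attr, value)
        return value
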
/src/open_r1/trl/trainer/sft_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import warnings
16 | from dataclasses import dataclass, field
17 | from typing import Any, Optional
18 |
19 | from transformers import TrainingArguments
20 |
21 |
22 | @dataclass
23 | class SFTConfig(TrainingArguments):
24 | r"""
25 | Configuration class for the [`SFTTrainer`].
26 |
27 | Only the parameters specific to SFT training are listed here. For details on other parameters, refer to the
28 | [`~transformers.TrainingArguments`] documentation.
29 |
30 | Using [`~transformers.HfArgumentParser`] we can turn this class into
31 | [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
32 | command line.
33 |
34 | Parameters:
35 | > Parameters that control the model
36 |
37 | model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
38 | Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model`
39 | argument of the [`SFTTrainer`] is provided as a string.
40 | use_liger (`bool`, *optional*, defaults to `False`):
41 | Monkey patch the model with Liger kernels to increase throughput and reduce memory usage.
42 |
43 | > Parameters that control the data preprocessing
44 |
45 | dataset_text_field (`str`, *optional*, defaults to `"text"`):
46 | Name of the column that contains text data in the dataset.
47 | dataset_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
48 | Dictionary of optional keyword arguments for the dataset preparation. The only supported key is
49 | `skip_prepare_dataset`.
50 | dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
51 | Number of processes to use for processing the dataset.
52 | max_seq_length (`int` or `None`, *optional*, defaults to `1024`):
53 | Maximum length of the tokenized sequence. Sequences longer than `max_seq_length` are truncated from the
54 | right.
55 | If `None`, no truncation is applied. When packing is enabled, this value sets the sequence length.
56 | packing (`bool`, *optional*, defaults to `False`):
57 | Whether to pack multiple sequences into a fixed-length format. Uses `max_seq_length` to define sequence
58 | length.
59 | eval_packing (`bool` or `None`, *optional*, defaults to `None`):
60 | Whether to pack the eval dataset. If `None`, uses the same value as `packing`.
61 |
62 | > Parameters that control the training
63 |
64 | learning_rate (`float`, *optional*, defaults to `2e-5`):
65 | Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
66 | [`~transformers.TrainingArguments`].
67 | """
68 |
69 | # Parameters that control the model
70 | model_init_kwargs: Optional[dict[str, Any]] = field(
71 | default=None,
72 | metadata={
73 | "help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of "
74 | "the `SFTTrainer` is provided as a string."
75 | },
76 | )
77 | use_liger: bool = field(
78 | default=False,
79 | metadata={"help": "Monkey patch the model with Liger kernels to increase throughput and reduce memory usage."},
80 | )
81 |
82 | # Parameters that control the data preprocessing
83 | dataset_text_field: str = field(
84 | default="text",
85 | metadata={"help": "Name of the column that contains text data in the dataset."},
86 | )
87 | dataset_kwargs: Optional[dict[str, Any]] = field(
88 | default=None,
89 | metadata={
90 | "help": "Dictionary of optional keyword arguments for the dataset preparation. The only supported key is "
91 | "`skip_prepare_dataset`."
92 | },
93 | )
94 | dataset_num_proc: Optional[int] = field(
95 | default=None,
96 | metadata={"help": "Number of processes to use for processing the dataset."},
97 | )
98 | max_seq_length: Optional[int] = field(
99 | default=1024,
100 | metadata={
101 | "help": "Maximum length of the tokenized sequence. Sequences longer than `max_seq_length` are truncated "
102 | "from the right. If `None`, no truncation is applied. When packing is enabled, this value sets the "
103 | "sequence length."
104 | },
105 | )
106 | packing: bool = field(
107 | default=False,
108 | metadata={
109 | "help": "Whether to pack multiple sequences into a fixed-length format. Uses `max_seq_length` to "
110 | "define sequence length."
111 | },
112 | )
113 | eval_packing: Optional[bool] = field(
114 | default=None,
115 | metadata={"help": "Whether to pack the eval dataset. If `None`, uses the same value as `packing`."},
116 | )
117 |
118 | # Parameters that control the training
119 | learning_rate: float = field(
120 | default=2.0e-5,
121 | metadata={
122 | "help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of "
123 | "`TrainingArguments`."
124 | },
125 | )
126 |
127 | # Deprecated parameters
128 |     dataset_batch_size: Optional[int] = field(
129 | default=None,
130 | metadata={"help": "Deprecated. You can safely remove this parameter from your configuration."},
131 | )
132 |     num_of_sequences: Optional[int] = field(
133 | default=None,
134 | metadata={
135 | "help": "Deprecated. Use `max_seq_length` instead, which specifies the maximum length of the tokenized "
136 | "sequence, unlike `num_of_sequences`, which referred to string sequences."
137 | },
138 | )
139 |     chars_per_token: Optional[float] = field(
140 | default=None,
141 | metadata={"help": "Deprecated. If you want to customize the packing length, use `max_seq_length`."},
142 | )
143 |
144 | def __post_init__(self):
145 | super().__post_init__()
146 |
147 | if self.dataset_batch_size is not None:
148 | warnings.warn(
149 |                 "`dataset_batch_size` is deprecated and will be removed in version 0.18.0. You can safely remove this "
150 | "parameter from your configuration.",
151 | DeprecationWarning,
152 | )
153 |
154 | if self.num_of_sequences is not None:
155 | warnings.warn(
156 |                 "`num_of_sequences` is deprecated and will be removed in version 0.18.0. Use `max_seq_length` instead, "
157 |                 "which specifies the maximum length of the tokenized sequence, unlike `num_of_sequences`, which "
158 |                 "referred to string sequences.",
159 | DeprecationWarning,
160 | )
161 |
162 | if self.chars_per_token is not None:
163 | warnings.warn(
164 |                 "`chars_per_token` is deprecated and will be removed in version 0.18.0. If you want to customize the "
165 | "packing length, use `max_seq_length`.",
166 | DeprecationWarning,
167 | )
168 |
--------------------------------------------------------------------------------
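As its docstring notes, `SFTConfig` is designed to be exposed on the command line through `HfArgumentParser`. A minimal, hypothetical launcher sketch (the script name, output directory, and the assumption that `open_r1.trl` is importable as laid out in this repo are all illustrative):

from transformers import HfArgumentParser

from open_r1.trl import ModelConfig, SFTConfig

if __name__ == "__main__":
    parser = HfArgumentParser((SFTConfig, ModelConfig))
    sft_args, model_args = parser.parse_args_into_dataclasses()
    # e.g. python sft_demo.py --output_dir outputs/sft-demo --max_seq_length 2048 --packing True
    print(sft_args.max_seq_length, sft_args.packing, model_args.model_name_or_path)
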
/src/open_r1/trl/trainer/model_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from dataclasses import dataclass, field
16 | from typing import Optional
17 |
18 |
19 | @dataclass
20 | class ModelConfig:
21 | """
22 | Configuration class for the models.
23 |
24 | Using [`~transformers.HfArgumentParser`] we can turn this class into
25 | [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
26 | command line.
27 |
28 | Parameters:
29 | model_name_or_path (`str` or `None`, *optional*, defaults to `None`):
30 | Model checkpoint for weights initialization.
31 | model_revision (`str`, *optional*, defaults to `"main"`):
32 | Specific model version to use. It can be a branch name, a tag name, or a commit id.
33 | torch_dtype (`Literal["auto", "bfloat16", "float16", "float32"]` or `None`, *optional*, defaults to `None`):
34 | Override the default `torch.dtype` and load the model under this dtype. Possible values are
35 |
36 | - `"bfloat16"`: `torch.bfloat16`
37 | - `"float16"`: `torch.float16`
38 | - `"float32"`: `torch.float32`
39 | - `"auto"`: Automatically derive the dtype from the model's weights.
40 |
41 | trust_remote_code (`bool`, *optional*, defaults to `False`):
42 | Whether to allow for custom models defined on the Hub in their own modeling files. This option should only
43 | be set to `True` for repositories you trust and in which you have read the code, as it will execute code
44 | present on the Hub on your local machine.
45 | attn_implementation (`str` or `None`, *optional*, defaults to `None`):
46 | Which attention implementation to use. You can run `--attn_implementation=flash_attention_2`, in which case
47 | you must install this manually by running `pip install flash-attn --no-build-isolation`.
48 | use_peft (`bool`, *optional*, defaults to `False`):
49 | Whether to use PEFT for training.
50 | lora_r (`int`, *optional*, defaults to `16`):
51 | LoRA R value.
52 | lora_alpha (`int`, *optional*, defaults to `32`):
53 | LoRA alpha.
54 | lora_dropout (`float`, *optional*, defaults to `0.05`):
55 | LoRA dropout.
56 | lora_target_modules (`Union[str, list[str]]` or `None`, *optional*, defaults to `None`):
57 | LoRA target modules.
58 | lora_modules_to_save (`list[str]` or `None`, *optional*, defaults to `None`):
59 | Model layers to unfreeze & train.
60 | lora_task_type (`str`, *optional*, defaults to `"CAUSAL_LM"`):
61 | Task type to pass for LoRA (use `"SEQ_CLS"` for reward modeling).
62 | use_rslora (`bool`, *optional*, defaults to `False`):
63 | Whether to use Rank-Stabilized LoRA, which sets the adapter scaling factor to `lora_alpha/√r`, instead of
64 | the original default value of `lora_alpha/r`.
65 | load_in_8bit (`bool`, *optional*, defaults to `False`):
66 | Whether to use 8 bit precision for the base model. Works only with LoRA.
67 | load_in_4bit (`bool`, *optional*, defaults to `False`):
68 | Whether to use 4 bit precision for the base model. Works only with LoRA.
69 | bnb_4bit_quant_type (`str`, *optional*, defaults to `"nf4"`):
70 | Quantization type (`"fp4"` or `"nf4"`).
71 | use_bnb_nested_quant (`bool`, *optional*, defaults to `False`):
72 | Whether to use nested quantization.
73 | """
74 |
75 | model_name_or_path: Optional[str] = field(
76 | default=None,
77 | metadata={"help": "Model checkpoint for weights initialization."},
78 | )
79 | model_revision: str = field(
80 | default="main",
81 | metadata={"help": "Specific model version to use. It can be a branch name, a tag name, or a commit id."},
82 | )
83 | torch_dtype: Optional[str] = field(
84 | default=None,
85 | metadata={
86 | "help": "Override the default `torch.dtype` and load the model under this dtype.",
87 | "choices": ["auto", "bfloat16", "float16", "float32"],
88 | },
89 | )
90 | trust_remote_code: bool = field(
91 | default=False,
92 | metadata={
93 | "help": "Whether to allow for custom models defined on the Hub in their own modeling files. This option "
94 | "should only be set to `True` for repositories you trust and in which you have read the code, as it will "
95 | "execute code present on the Hub on your local machine."
96 | },
97 | )
98 | attn_implementation: Optional[str] = field(
99 | default=None,
100 | metadata={
101 | "help": "Which attention implementation to use. You can run `--attn_implementation=flash_attention_2`, in "
102 | "which case you must install this manually by running `pip install flash-attn --no-build-isolation`."
103 | },
104 | )
105 | use_peft: bool = field(
106 | default=False,
107 | metadata={"help": "Whether to use PEFT for training."},
108 | )
109 | lora_r: int = field(
110 | default=16,
111 | metadata={"help": "LoRA R value."},
112 | )
113 | lora_alpha: int = field(
114 | default=32,
115 | metadata={"help": "LoRA alpha."},
116 | )
117 | lora_dropout: float = field(
118 | default=0.05,
119 | metadata={"help": "LoRA dropout."},
120 | )
121 | lora_target_modules: Optional[list[str]] = field(
122 | default=None,
123 | metadata={"help": "LoRA target modules."},
124 | )
125 | lora_modules_to_save: Optional[list[str]] = field(
126 | default=None,
127 | metadata={"help": "Model layers to unfreeze & train."},
128 | )
129 | lora_task_type: str = field(
130 | default="CAUSAL_LM",
131 | metadata={"help": "Task type to pass for LoRA (use 'SEQ_CLS' for reward modeling)."},
132 | )
133 | use_rslora: bool = field(
134 | default=False,
135 | metadata={
136 | "help": "Whether to use Rank-Stabilized LoRA, which sets the adapter scaling factor to `lora_alpha/√r`, "
137 | "instead of the original default value of `lora_alpha/r`."
138 | },
139 | )
140 | load_in_8bit: bool = field(
141 | default=False,
142 | metadata={"help": "Whether to use 8 bit precision for the base model. Works only with LoRA."},
143 | )
144 | load_in_4bit: bool = field(
145 | default=False,
146 | metadata={"help": "Whether to use 4 bit precision for the base model. Works only with LoRA."},
147 | )
148 | bnb_4bit_quant_type: str = field(
149 | default="nf4",
150 | metadata={"help": "Quantization type.", "choices": ["fp4", "nf4"]},
151 | )
152 | use_bnb_nested_quant: bool = field(
153 | default=False,
154 | metadata={"help": "Whether to use nested quantization."},
155 | )
156 |
157 | def __post_init__(self):
158 | if self.load_in_8bit and self.load_in_4bit:
159 | raise ValueError("You can't use 8 bit and 4 bit precision at the same time")
160 |
161 | if hasattr(self.lora_target_modules, "__len__") and len(self.lora_target_modules) == 1:
162 | self.lora_target_modules = self.lora_target_modules[0]
163 |
--------------------------------------------------------------------------------
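A `ModelConfig` is typically consumed together with the `get_quantization_config` / `get_kbit_device_map` / `get_peft_config` helpers exported through `trainer.utils` in the package `__init__` above. The sketch below shows one plausible translation into `from_pretrained` keyword arguments; the checkpoint name is hypothetical and the exact wiring in this repo's training scripts may differ.

import torch
from transformers import AutoModelForCausalLM

from open_r1.trl import ModelConfig, get_kbit_device_map, get_peft_config, get_quantization_config

model_args = ModelConfig(
    model_name_or_path="Qwen/Qwen2.5-1.5B-Instruct",  # hypothetical checkpoint
    torch_dtype="bfloat16",
    load_in_4bit=True,
    use_peft=True,
)

quantization_config = get_quantization_config(model_args)
model = AutoModelForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    revision=model_args.model_revision,
    torch_dtype=getattr(torch, model_args.torch_dtype),
    trust_remote_code=model_args.trust_remote_code,
    attn_implementation=model_args.attn_implementation,
    device_map=get_kbit_device_map() if quantization_config is not None else None,
    quantization_config=quantization_config,
)
peft_config = get_peft_config(model_args)  # None when use_peft=False; otherwise passed to the trainer
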
/src/open_r1/trl/trainer/online_dpo_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from dataclasses import dataclass, field
16 | from typing import Optional
17 |
18 | from transformers import TrainingArguments
19 |
20 |
21 | @dataclass
22 | class OnlineDPOConfig(TrainingArguments):
23 | r"""
24 | Configuration class for the [`OnlineDPOTrainer`].
25 |
26 | Using [`~transformers.HfArgumentParser`] we can turn this class into
27 | [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
28 | command line.
29 |
30 | Parameters:
31 | learning_rate (`float`, *optional*, defaults to `5e-7`):
32 | Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
33 | [`~transformers.TrainingArguments`].
34 | reward_model_path (`str` or `None`, *optional*, defaults to `None`):
35 | Path to the reward model. Either `judge` or `reward_model_path` must be set, but not both.
36 | judge (`str` or `None`, *optional*, defaults to `None`):
37 | Name of the judge to use. Either `judge` or `reward_model_path` must be set, but not both.
38 | max_new_tokens (`int`, *optional*, defaults to `64`):
39 | Maximum number of tokens to generate per completion.
40 |         max_length (`int`, *optional*, defaults to `512`):
41 | Maximum total length of the sequence (prompt + completion) used to compute log probabilities. If the
42 | sequence exceeds this limit, the leftmost tokens will be truncated to preserve as much of the completion as
43 | possible.
44 | temperature (`float`, *optional*, defaults to `0.9`):
45 | Temperature for sampling. The higher the temperature, the more random the completions.
46 | missing_eos_penalty (`float` or `None`, *optional*, defaults to `None`):
47 | Penalty applied to the score when the model fails to generate an EOS token. This is useful to encourage
48 | to generate completions shorter than the maximum length (`max_new_tokens`). The penalty must be a positive
49 |             the model to generate completions shorter than the maximum length (`max_new_tokens`). The penalty must be a positive
50 | beta (`float` or `list[float]`, *optional*, defaults to `0.1`):
51 | Parameter controlling the deviation from the reference model. Higher β means less deviation from the
52 | reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
53 | the [paper](https://huggingface.co/papers/2310.12036). If a list of floats is provided then the β is
54 | selected for each new epoch and the last β is used for the rest of the epochs.
55 | loss_type (`str`, *optional*, defaults to `"sigmoid"`):
56 | Type of loss to use. Possible values are:
57 |
58 | - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
59 | - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
60 |
61 | dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
62 | Number of processes to use for processing the dataset.
63 | disable_dropout (`bool`, *optional*, defaults to `True`):
64 | Whether to disable dropout in the model and reference model.
65 | use_vllm (`bool`, *optional*, defaults to `False`):
66 | Whether to use vLLM for generating completions. Requires vLLM to be installed (`pip install vllm`).
67 | ds3_gather_for_generation (`bool`, *optional*, defaults to `True`):
68 | This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for generation,
69 | improving generation speed. However, disabling this option allows training models that exceed the VRAM
70 | capacity of a single GPU, albeit at the cost of slower generation.
71 | """
72 |
73 | learning_rate: float = field(
74 | default=5e-7,
75 | metadata={
76 | "help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of "
77 |             "`transformers.TrainingArguments`."
78 | },
79 | )
80 | reward_model_path: Optional[str] = field(
81 | default=None,
82 | metadata={
83 | "help": "Path to the reward model. Either `judge` or `reward_model_path` must be set, but not both."
84 | },
85 | )
86 | judge: Optional[str] = field(
87 | default=None,
88 | metadata={
89 | "help": "Name of the judge to use. Either `judge` or `reward_model_path` must be set, but not both."
90 | },
91 | )
92 | max_new_tokens: int = field(
93 | default=64,
94 | metadata={"help": "Maximum number of tokens to generate per completion."},
95 | )
96 | max_length: int = field(
97 | default=512,
98 | metadata={
99 | "help": "Maximum total length of the sequence (prompt + completion) used to compute log probabilities. If "
100 | "the sequence exceeds this limit, the leftmost tokens will be truncated to preserve as much of the "
101 | "completion as possible."
102 | },
103 | )
104 | temperature: float = field(
105 | default=0.9,
106 | metadata={"help": "Temperature for sampling. The higher the temperature, the more random the completions."},
107 | )
108 | missing_eos_penalty: Optional[float] = field(
109 | default=None,
110 | metadata={
111 | "help": "Penalty applied to the score when the model fails to generate an EOS token. This is useful to "
112 |             "encourage the model to generate completions shorter than the maximum length (`max_new_tokens`). The penalty must be "
113 | "a positive value."
114 | },
115 | )
116 | beta: list[float] = field(
117 | default_factory=lambda: [0.1],
118 | metadata={
119 | "help": "Parameter controlling the deviation from the reference model. Higher β means less deviation from "
120 | "the reference model. For the IPO loss (`loss_type='ipo'`), β is the regularization parameter denoted by "
121 | "τ in the [paper](https://huggingface.co/papers/2310.12036). If a list of floats is provided then the β "
122 | "is selected for each new epoch and the last β is used for the rest of the epochs."
123 | },
124 | )
125 | loss_type: str = field(
126 | default="sigmoid",
127 | metadata={
128 | "help": "Type of loss to use.",
129 | "choices": ["sigmoid", "ipo"],
130 | },
131 | )
132 | dataset_num_proc: Optional[int] = field(
133 | default=None,
134 | metadata={"help": "Number of processes to use for processing the dataset."},
135 | )
136 | disable_dropout: bool = field(
137 | default=True,
138 |         metadata={"help": "Whether to disable dropout in the model and reference model."},
139 | )
140 | use_vllm: bool = field(
141 | default=False,
142 | metadata={
143 | "help": "Whether to use vLLM for generating completions. Requires vLLM to be installed "
144 | "(`pip install vllm`)."
145 | },
146 | )
147 | ds3_gather_for_generation: bool = field(
148 | default=True,
149 | metadata={
150 | "help": "This setting applies to DeepSpeed ZeRO-3. If enabled, the policy model weights are gathered for "
151 | "generation, improving generation speed. However, disabling this option allows training models that "
152 | "exceed the VRAM capacity of a single GPU, albeit at the cost of slower generation."
153 | },
154 | )
155 |
156 | def __post_init__(self):
157 | super().__post_init__()
158 | if hasattr(self.beta, "__len__") and len(self.beta) == 1:
159 | self.beta = self.beta[0]
160 |
--------------------------------------------------------------------------------
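A small sketch of the `beta` handling in `__post_init__` above: a one-element list collapses to a scalar, while a longer list is kept as a per-epoch schedule whose last value is reused once the epochs outnumber the entries. The output directories are placeholders.

from open_r1.trl import OnlineDPOConfig

args = OnlineDPOConfig(output_dir="tmp-online-dpo", beta=[0.1])
assert args.beta == 0.1  # single-element list collapsed to a float

args = OnlineDPOConfig(output_dir="tmp-online-dpo", beta=[0.1, 0.2, 0.3])
assert args.beta == [0.1, 0.2, 0.3]  # interpreted as one value per epoch
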
/src/open_r1/trl/trainer/cpo_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from dataclasses import dataclass, field
16 | from typing import Any, Optional
17 |
18 | from transformers import TrainingArguments
19 |
20 |
21 | @dataclass
22 | class CPOConfig(TrainingArguments):
23 | r"""
24 | Configuration class for the [`CPOTrainer`].
25 |
26 | Using [`~transformers.HfArgumentParser`] we can turn this class into
27 | [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
28 | command line.
29 |
30 | Parameters:
31 | learning_rate (`float`, *optional*, defaults to `1e-6`):
32 | Initial learning rate for [`AdamW`] optimizer. The default value replaces that of
33 | [`~transformers.TrainingArguments`].
34 | max_length (`int` or `None`, *optional*, defaults to `1024`):
35 | Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
36 | to use the default data collator.
37 | max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
38 | Maximum length of the prompt. This argument is required if you want to use the default data collator.
39 | max_completion_length (`int` or `None`, *optional*, defaults to `None`):
40 | Maximum length of the completion. This argument is required if you want to use the default data collator
41 | and your model is an encoder-decoder.
42 | beta (`float`, *optional*, defaults to `0.1`):
43 | Parameter controlling the deviation from the reference model. Higher β means less deviation from the
44 | reference model. For the IPO loss (`loss_type="ipo"`), β is the regularization parameter denoted by τ in
45 | the [paper](https://huggingface.co/papers/2310.12036).
46 | label_smoothing (`float`, *optional*, defaults to `0.0`):
47 | Label smoothing factor. This argument is required if you want to use the default data collator.
48 | loss_type (`str`, *optional*, defaults to `"sigmoid"`):
49 | Type of loss to use. Possible values are:
50 |
51 | - `"sigmoid"`: sigmoid loss from the original [DPO](https://huggingface.co/papers/2305.18290) paper.
52 | - `"hinge"`: hinge loss on the normalized likelihood from the [SLiC](https://huggingface.co/papers/2305.10425) paper.
53 | - `"ipo"`: IPO loss from the [IPO](https://huggingface.co/papers/2310.12036) paper.
54 | - `"simpo"`: SimPO loss from the [SimPO](https://huggingface.co/papers/2405.14734) paper.
55 |
56 | disable_dropout (`bool`, *optional*, defaults to `True`):
57 | Whether to disable dropout in the model.
58 | cpo_alpha (`float`, *optional*, defaults to `1.0`):
59 | Weight of the BC regularizer in CPO training.
60 | simpo_gamma (`float`, *optional*, defaults to `0.5`):
61 | Target reward margin for the SimPO loss, used only when the `loss_type="simpo"`.
62 | label_pad_token_id (`int`, *optional*, defaults to `-100`):
63 | Label pad token id. This argument is required if you want to use the default data collator.
64 | padding_value (`int` or `None`, *optional*, defaults to `None`):
65 | Padding value to use. If `None`, the padding value of the tokenizer is used.
66 | truncation_mode (`str`,*optional*, defaults to `"keep_end"`):
67 | Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
68 | This argument is required if you want to use the default data collator.
69 | generate_during_eval (`bool`, *optional*, defaults to `False`):
70 | If `True`, generates and logs completions from the model to W&B or Comet during evaluation.
71 | is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
72 | When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
73 | you need to specify if the model returned by the callable is an encoder-decoder model.
74 | model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
75 | Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
76 | string.
77 | dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
78 | Number of processes to use for processing the dataset.
79 | """
80 |
81 | learning_rate: float = field(
82 | default=1e-6,
83 | metadata={
84 | "help": "Initial learning rate for `AdamW` optimizer. The default value replaces that of "
85 | "`transformers.TrainingArguments`."
86 | },
87 | )
88 | max_length: Optional[int] = field(
89 | default=1024,
90 | metadata={"help": "Maximum length of the sequences (prompt + completion) in the batch."},
91 | )
92 | max_prompt_length: Optional[int] = field(
93 | default=512,
94 | metadata={
95 | "help": "Maximum length of the prompt. This argument is required if you want to use the default data "
96 |             "collator."
97 | },
98 | )
99 | max_completion_length: Optional[int] = field(
100 | default=None,
101 | metadata={
102 | "help": "Maximum length of the completion. This argument is required if you want to use the default data "
103 | "collator and your model is an encoder-decoder."
104 | },
105 | )
106 | beta: float = field(
107 | default=0.1,
108 | metadata={
109 | "help": "Parameter controlling the deviation from the reference model. Higher β means less deviation from "
110 | "the reference model."
111 | },
112 | )
113 | label_smoothing: float = field(
114 | default=0.0,
115 | metadata={"help": "Label smoothing factor."},
116 | )
117 | loss_type: str = field(
118 | default="sigmoid",
119 | metadata={
120 | "help": "Type of loss to use.",
121 | "choices": ["sigmoid", "hinge", "ipo", "simpo"],
122 | },
123 | )
124 | disable_dropout: bool = field(
125 | default=True,
126 | metadata={"help": "Whether to disable dropout in the model."},
127 | )
128 | cpo_alpha: float = field(
129 | default=1.0,
130 | metadata={"help": "Weight of the BC regularizer in CPO training."},
131 | )
132 | simpo_gamma: float = field(
133 | default=0.5,
134 | metadata={"help": "Target reward margin for the SimPO loss, used only when the `loss_type='simpo'`."},
135 | )
136 | label_pad_token_id: int = field(
137 | default=-100,
138 | metadata={"help": "Label pad token id."},
139 | )
140 | padding_value: Optional[int] = field(
141 | default=None,
142 | metadata={"help": "Padding value to use. If `None`, the padding value of the tokenizer is used."},
143 | )
144 | truncation_mode: str = field(
145 | default="keep_end",
146 | metadata={
147 | "help": "Truncation mode to use when the prompt is too long.",
148 | "choices": ["keep_end", "keep_start"],
149 | },
150 | )
151 | generate_during_eval: bool = field(
152 | default=False,
153 | metadata={"help": "If `True`, generates and logs completions from the model to W&B during evaluation."},
154 | )
155 | is_encoder_decoder: Optional[bool] = field(
156 | default=None,
157 | metadata={"help": "Whether the model is an encoder-decoder model."},
158 | )
159 | model_init_kwargs: Optional[dict[str, Any]] = field(
160 | default=None,
161 | metadata={
162 | "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model "
163 | "from a string."
164 | },
165 | )
166 | dataset_num_proc: Optional[int] = field(
167 | default=None,
168 | metadata={"help": "Number of processes to use for processing the dataset."},
169 | )
170 |
--------------------------------------------------------------------------------
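The SimPO objective is reached through this config by choosing `loss_type="simpo"`. A hedged sketch follows; the output directory is hypothetical, and setting `cpo_alpha=0.0` reflects the common convention of dropping the BC regularizer to recover the pure SimPO loss.

from open_r1.trl import CPOConfig

simpo_args = CPOConfig(
    output_dir="tmp-simpo",  # hypothetical
    loss_type="simpo",
    simpo_gamma=0.5,         # target reward margin, only used by the SimPO loss
    cpo_alpha=0.0,           # 0.0 disables the BC regularizer; > 0 gives a CPO-SimPO mixture
    max_length=1024,
    max_prompt_length=512,
)
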
/src/open_r1/evaluate.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Custom evaluation tasks for LightEval."""
16 |
17 | import random
18 |
19 | from lighteval.metrics.dynamic_metrics import (
20 | ExprExtractionConfig,
21 | IndicesExtractionConfig,
22 | LatexExtractionConfig,
23 | multilingual_extractive_match_metric,
24 | )
25 | from lighteval.tasks.lighteval_task import LightevalTaskConfig
26 | from lighteval.tasks.requests import Doc
27 | from lighteval.utils.language import Language
28 |
29 |
30 | # Prompt template adapted from
31 | # - simple-evals: https://github.com/openai/simple-evals/blob/6e84f4e2aed6b60f6a0c7b8f06bbbf4bfde72e58/math_eval.py#L17
32 | # - Llama 3: https://huggingface.co/datasets/meta-llama/Llama-3.2-1B-Instruct-evals/viewer/Llama-3.2-1B-Instruct-evals__math__details?views%5B%5D=llama_32_1b_instruct_evals__math__details
33 | # Note that it is important to have the final answer in a box for math-verify to work correctly
34 | MATH_QUERY_TEMPLATE = """
35 | Solve the following math problem efficiently and clearly. The last line of your response should be of the following format: 'Therefore, the final answer is: $\\boxed{{ANSWER}}$. I hope it is correct' (without quotes) where ANSWER is just the final number or expression that solves the problem. Think step by step before answering.
36 |
37 | {Question}
38 | """.strip()
39 |
40 | # Prompt template from simple-evals: https://github.com/openai/simple-evals/blob/83ed7640a7d9cd26849bcb3340125002ef14abbe/common.py#L14
41 | GPQA_QUERY_TEMPLATE = """
42 | Answer the following multiple choice question. The last line of your response should be of the following format: 'Answer: $LETTER' (without quotes) where LETTER is one of ABCD. Think step by step before answering.
43 |
44 | {Question}
45 |
46 | A) {A}
47 | B) {B}
48 | C) {C}
49 | D) {D}
50 | """.strip()
51 |
52 | latex_gold_metric = multilingual_extractive_match_metric(
53 | language=Language.ENGLISH,
54 | fallback_mode="first_match",
55 | precision=5,
56 | gold_extraction_target=(LatexExtractionConfig(),),
57 | # Match boxed first before trying other regexes
58 | pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig(boxed_match_priority=0)),
59 | aggregation_function=max,
60 | )
61 |
62 | expr_gold_metric = multilingual_extractive_match_metric(
63 | language=Language.ENGLISH,
64 | fallback_mode="first_match",
65 | precision=5,
66 | gold_extraction_target=(ExprExtractionConfig(),),
67 | # Match boxed first before trying other regexes
68 | pred_extraction_target=(ExprExtractionConfig(), LatexExtractionConfig(boxed_match_priority=0)),
69 | aggregation_function=max,
70 | )
71 |
72 | gpqa_metric = multilingual_extractive_match_metric(
73 | language=Language.ENGLISH,
74 | gold_extraction_target=[IndicesExtractionConfig(prefix_for_extraction="NativeLetters")],
75 | pred_extraction_target=[IndicesExtractionConfig(prefix_for_extraction="NativeLetters")],
76 | precision=5,
77 | )
78 |
79 |
80 | def math_prompt_fn(line, task_name: str = None):
81 | return Doc(
82 | task_name=task_name,
83 | query=MATH_QUERY_TEMPLATE.format(Question=line["problem"]),
84 | choices=[line["solution"]],
85 | gold_index=0,
86 | )
87 |
88 |
89 | def aime_prompt_fn(line, task_name: str = None):
90 | return Doc(
91 | task_name=task_name,
92 | query=MATH_QUERY_TEMPLATE.format(Question=line["problem"]),
93 | choices=[line["answer"]],
94 | gold_index=0,
95 | )
96 |
97 |
98 | def amc_prompt_fn(line, task_name: str = None):
99 | return Doc(
100 | task_name=task_name,
101 | query=MATH_QUERY_TEMPLATE.format(Question=line["problem"]),
102 | choices=[line["answer"]],
103 | gold_index=0,
104 | )
105 |
106 |
107 | def minerva_prompt_fn(line, task_name: str = None):
108 | return Doc(
109 | task_name=task_name,
110 | query=MATH_QUERY_TEMPLATE.format(Question=line["problem"]),
111 | choices=[line["solution"]],
112 | gold_index=0,
113 | )
114 |
115 |
116 | def olympiadbench_prompt_fn(line, task_name: str = None):
117 | return Doc(
118 | task_name=task_name,
119 | query=MATH_QUERY_TEMPLATE.format(Question=line["question"]),
120 | choices=[line["answer"]],
121 | gold_index=0,
122 | )
123 |
124 |
125 | def gpqa_prompt_fn(line, task_name: str = None):
126 | gold_index = random.randint(0, 3)
127 | choices = [line["Incorrect Answer 1"], line["Incorrect Answer 2"], line["Incorrect Answer 3"]]
128 | choices.insert(gold_index, line["Correct Answer"])
129 | query = GPQA_QUERY_TEMPLATE.format(
130 | A=choices[0], B=choices[1], C=choices[2], D=choices[3], Question=line["Question"]
131 | )
132 | return Doc(
133 | task_name=task_name,
134 | query=query,
135 | choices=["A", "B", "C", "D"],
136 | gold_index=gold_index,
137 | instruction=query,
138 | )
139 |
140 |
141 | # Define tasks
142 | aime24 = LightevalTaskConfig(
143 | name="aime24",
144 | suite=["custom"],
145 | prompt_function=aime_prompt_fn,
146 | hf_repo="HuggingFaceH4/aime_2024",
147 | hf_subset="default",
148 | hf_avail_splits=["train"],
149 | evaluation_splits=["train"],
150 | few_shots_split=None,
151 | few_shots_select=None,
152 | generation_size=32768,
153 | metric=[expr_gold_metric],
154 | version=1,
155 | )
156 | aime25 = LightevalTaskConfig(
157 | name="aime25",
158 | suite=["custom"],
159 | prompt_function=aime_prompt_fn,
160 | hf_repo="yentinglin/aime_2025",
161 | hf_subset="default",
162 | hf_avail_splits=["train"],
163 | evaluation_splits=["train"],
164 | few_shots_split=None,
165 | few_shots_select=None,
166 | generation_size=32768,
167 | metric=[expr_gold_metric],
168 | version=1,
169 | )
170 | math_500 = LightevalTaskConfig(
171 | name="math_500",
172 | suite=["custom"],
173 | prompt_function=math_prompt_fn,
174 | hf_repo="HuggingFaceH4/MATH-500",
175 | hf_subset="default",
176 | hf_avail_splits=["test"],
177 | evaluation_splits=["test"],
178 | few_shots_split=None,
179 | few_shots_select=None,
180 | generation_size=32768,
181 | metric=[latex_gold_metric],
182 | version=1,
183 | )
184 | gpqa_diamond = LightevalTaskConfig(
185 | name="gpqa:diamond",
186 | suite=["custom"],
187 | prompt_function=gpqa_prompt_fn,
188 | hf_repo="Idavidrein/gpqa",
189 | hf_subset="gpqa_diamond",
190 | hf_avail_splits=["train"],
191 | evaluation_splits=["train"],
192 | few_shots_split=None,
193 | few_shots_select=None,
194 | generation_size=32768, # needed for reasoning models like R1
195 | metric=[gpqa_metric],
196 | stop_sequence=[], # no stop sequence, will use eos token
197 | trust_dataset=True,
198 | version=1,
199 | )
200 | minerva = LightevalTaskConfig(
201 | name="minerva",
202 | suite=["custom"],
203 | prompt_function=minerva_prompt_fn,
204 | hf_repo="knoveleng/Minerva-Math",
205 | hf_subset="default",
206 | hf_avail_splits=["train"],
207 | evaluation_splits=["train"],
208 | few_shots_split=None,
209 | few_shots_select=None,
210 | generation_size=32768,
211 | metric=[latex_gold_metric],
212 | version=1,
213 | )
214 | amc23 = LightevalTaskConfig(
215 | name="amc23",
216 | suite=["custom"],
217 | prompt_function=amc_prompt_fn,
218 | hf_repo="knoveleng/AMC-23",
219 | hf_subset="default",
220 | hf_avail_splits=["train"],
221 | evaluation_splits=["train"],
222 | few_shots_split=None,
223 | few_shots_select=None,
224 | generation_size=32768,
225 | metric=[expr_gold_metric],
226 | version=1,
227 | )
228 | olympiadbench = LightevalTaskConfig(
229 | name="olympiadbench",
230 | suite=["custom"],
231 | prompt_function=olympiadbench_prompt_fn,
232 | hf_repo="knoveleng/OlympiadBench",
233 | hf_subset="default",
234 | hf_avail_splits=["train"],
235 | evaluation_splits=["train"],
236 | few_shots_split=None,
237 | few_shots_select=None,
238 | generation_size=32768,
239 | metric=[latex_gold_metric],
240 | version=1,
241 | )
242 |
243 | # Add tasks to the table
244 | TASKS_TABLE = []
245 | TASKS_TABLE.append(aime24)
246 | TASKS_TABLE.append(aime25)
247 | TASKS_TABLE.append(math_500)
248 | TASKS_TABLE.append(gpqa_diamond)
249 | TASKS_TABLE.append(minerva)
250 | TASKS_TABLE.append(amc23)
251 | TASKS_TABLE.append(olympiadbench)
252 |
253 | # MODULE LOGIC
254 | if __name__ == "__main__":
255 |     print([t.name for t in TASKS_TABLE])  # LightevalTaskConfig is a dataclass, not a dict
256 | print(len(TASKS_TABLE))
--------------------------------------------------------------------------------
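A worked sketch of `gpqa_prompt_fn` on a hypothetical row; the field names follow the GPQA schema the function reads, the question content is made up, and `open_r1` is assumed to be importable.

from open_r1.evaluate import gpqa_prompt_fn

row = {
    "Question": "Which particle mediates the electromagnetic force?",
    "Correct Answer": "The photon",
    "Incorrect Answer 1": "The gluon",
    "Incorrect Answer 2": "The W boson",
    "Incorrect Answer 3": "The graviton",
}
doc = gpqa_prompt_fn(row, task_name="gpqa:diamond")
# The correct answer is shuffled into one of the four lettered options, and
# doc.gold_index records its position so the letter-extraction metric can score it.
print(doc.query)
print(doc.choices, doc.gold_index)
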
/src/open_r1/trl/trainer/bco_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from dataclasses import dataclass, field
16 | from typing import Any, Optional
17 |
18 | from transformers import TrainingArguments
19 |
20 |
21 | @dataclass
22 | class BCOConfig(TrainingArguments):
23 | r"""
24 | Configuration class for the [`BCOTrainer`].
25 |
26 | Using [`~transformers.HfArgumentParser`] we can turn this class into
27 | [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the
28 | command line.
29 |
30 | Parameters:
31 | max_length (`int` or `None`, *optional*, defaults to `1024`):
32 | Maximum length of the sequences (prompt + completion) in the batch. This argument is required if you want
33 | to use the default data collator.
34 | max_prompt_length (`int` or `None`, *optional*, defaults to `512`):
35 | Maximum length of the prompt. This argument is required if you want to use the default data collator.
36 | max_completion_length (`int` or `None`, *optional*, defaults to `None`):
37 | Maximum length of the completion. This argument is required if you want to use the default data collator
38 | and your model is an encoder-decoder.
39 | beta (`float`, *optional*, defaults to `0.1`):
40 | Parameter controlling the deviation from the reference model. Higher β means less deviation from the
41 | reference model.
42 | label_pad_token_id (`int`, *optional*, defaults to `-100`):
43 | Label pad token id. This argument is required if you want to use the default data collator.
44 | padding_value (`int` or `None`, *optional*, defaults to `None`):
45 | Padding value to use. If `None`, the padding value of the tokenizer is used.
46 | truncation_mode (`str`, *optional*, defaults to `"keep_end"`):
47 | Truncation mode to use when the prompt is too long. Possible values are `"keep_end"` or `"keep_start"`.
48 | This argument is required if you want to use the default data collator.
49 | disable_dropout (`bool`, *optional*, defaults to `True`):
50 | Whether to disable dropout in the model and reference model.
51 | generate_during_eval (`bool`, *optional*, defaults to `False`):
52 | If `True`, generates and logs completions from both the model and the reference model to W&B or Comet during
53 | evaluation.
54 | is_encoder_decoder (`bool` or `None`, *optional*, defaults to `None`):
55 | When using the `model_init` argument (callable) to instantiate the model instead of the `model` argument,
56 | you need to specify if the model returned by the callable is an encoder-decoder model.
57 | precompute_ref_log_probs (`bool`, *optional*, defaults to `False`):
58 | Whether to precompute reference model log probabilities for training and evaluation datasets. This is
59 | useful when training without the reference model to reduce the total GPU memory needed.
60 | model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
61 | Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the model from a
62 | string.
63 | ref_model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`):
64 | Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the reference model
65 | from a string.
66 | dataset_num_proc (`int` or `None`, *optional*, defaults to `None`):
67 | Number of processes to use for processing the dataset.
68 | prompt_sample_size (`int`, *optional*, defaults to `1024`):
69 |             Number of prompts that are fed to the density ratio classifier.
70 | min_density_ratio (`float`, *optional*, defaults to `0.5`):
71 | Minimum value of the density ratio. The estimated density ratio is clamped to this value.
72 | max_density_ratio (`float`, *optional*, defaults to `10.0`):
73 | Maximum value of the density ratio. The estimated density ratio is clamped to this value.
74 | """
75 |
76 | max_length: Optional[int] = field(
77 | default=1024,
78 | metadata={
79 | "help": "Maximum length of the sequences (prompt + completion) in the batch. "
80 | "This argument is required if you want to use the default data collator."
81 | },
82 | )
83 | max_prompt_length: Optional[int] = field(
84 | default=512,
85 | metadata={
86 | "help": "Maximum length of the prompt. "
87 | "This argument is required if you want to use the default data collator."
88 | },
89 | )
90 | max_completion_length: Optional[int] = field(
91 | default=None,
92 | metadata={
93 | "help": "Maximum length of the completion. This argument is required if you want to use the "
94 | "default data collator and your model is an encoder-decoder."
95 | },
96 | )
97 | beta: float = field(
98 | default=0.1,
99 | metadata={
100 | "help": "Parameter controlling the deviation from the reference model. "
101 | "Higher β means less deviation from the reference model."
102 | },
103 | )
104 | label_pad_token_id: int = field(
105 | default=-100,
106 | metadata={
107 | "help": "Label pad token id. This argument is required if you want to use the default data collator."
108 | },
109 | )
110 | padding_value: Optional[int] = field(
111 | default=None,
112 | metadata={"help": "Padding value to use. If `None`, the padding value of the tokenizer is used."},
113 | )
114 | truncation_mode: str = field(
115 | default="keep_end",
116 | metadata={
117 | "help": "Truncation mode to use when the prompt is too long. Possible values are "
118 | "`keep_end` or `keep_start`. This argument is required if you want to use the "
119 | "default data collator."
120 | },
121 | )
122 | disable_dropout: bool = field(
123 | default=True,
124 | metadata={"help": "Whether to disable dropout in the model and reference model."},
125 | )
126 | generate_during_eval: bool = field(
127 | default=False,
128 | metadata={
129 | "help": "If `True`, generates and logs completions from both the model and the reference model "
130 | "to W&B during evaluation."
131 | },
132 | )
133 | is_encoder_decoder: Optional[bool] = field(
134 | default=None,
135 | metadata={
136 | "help": "When using the `model_init` argument (callable) to instantiate the model instead of the "
137 | "`model` argument, you need to specify if the model returned by the callable is an "
138 | "encoder-decoder model."
139 | },
140 | )
141 | precompute_ref_log_probs: bool = field(
142 | default=False,
143 | metadata={
144 | "help": "Whether to precompute reference model log probabilities for training and evaluation datasets. "
145 | "This is useful when training without the reference model to reduce the total GPU memory "
146 | "needed."
147 | },
148 | )
149 | model_init_kwargs: Optional[dict[str, Any]] = field(
150 | default=None,
151 | metadata={
152 | "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the "
153 | "model from a string."
154 | },
155 | )
156 | ref_model_init_kwargs: Optional[dict[str, Any]] = field(
157 | default=None,
158 | metadata={
159 | "help": "Keyword arguments to pass to `AutoModelForCausalLM.from_pretrained` when instantiating the "
160 | "reference model from a string."
161 | },
162 | )
163 | dataset_num_proc: Optional[int] = field(
164 | default=None,
165 | metadata={"help": "Number of processes to use for processing the dataset."},
166 | )
167 | prompt_sample_size: int = field(
168 | default=1024,
169 |         metadata={"help": "Number of prompts that are fed to the density ratio classifier."},
170 | )
171 | min_density_ratio: float = field(
172 | default=0.5,
173 | metadata={"help": "Minimum value of the density ratio. The estimated density ratio is clamped to this value."},
174 | )
175 | max_density_ratio: float = field(
176 | default=10.0,
177 | metadata={"help": "Maximum value of the density ratio. The estimated density ratio is clamped to this value."},
178 | )
179 |
--------------------------------------------------------------------------------
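A brief sketch of a BCO setup that precomputes reference log-probabilities so the reference model can be dropped during training; the output directory is a placeholder.

from open_r1.trl import BCOConfig

bco_args = BCOConfig(
    output_dir="tmp-bco",           # hypothetical
    beta=0.1,
    precompute_ref_log_probs=True,  # compute reference log-probs once, then free the reference model
    prompt_sample_size=1024,        # prompts fed to the density-ratio classifier
    min_density_ratio=0.5,
    max_density_ratio=10.0,
)
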