├── trl ├── py.typed ├── accelerate_configs │ ├── single_gpu.yaml │ ├── multi_gpu.yaml │ ├── zero1.yaml │ ├── zero2.yaml │ ├── zero3.yaml │ ├── fsdp2.yaml │ └── fsdp1.yaml ├── extras │ ├── __init__.py │ ├── dataset_formatting.py │ └── profiling.py ├── experimental │ ├── gspo_token │ │ └── __init__.py │ ├── bco │ │ └── __init__.py │ ├── gfpo │ │ ├── __init__.py │ │ └── gfpo_config.py │ ├── papo │ │ ├── __init__.py │ │ └── papo_config.py │ ├── bema_for_ref_model │ │ ├── __init__.py │ │ └── dpo_trainer.py │ ├── openenv │ │ └── __init__.py │ ├── cpo │ │ └── __init__.py │ ├── gkd │ │ └── __init__.py │ ├── kto │ │ └── __init__.py │ ├── prm │ │ └── __init__.py │ ├── xpo │ │ ├── __init__.py │ │ └── xpo_config.py │ ├── gold │ │ └── __init__.py │ ├── orpo │ │ └── __init__.py │ ├── minillm │ │ └── __init__.py │ ├── nash_md │ │ ├── __init__.py │ │ └── nash_md_config.py │ ├── online_dpo │ │ └── __init__.py │ ├── grpo_with_replay_buffer │ │ ├── __init__.py │ │ └── grpo_with_replay_buffer_config.py │ ├── ppo │ │ └── __init__.py │ ├── judges │ │ └── __init__.py │ └── __init__.py ├── scripts │ ├── __init__.py │ └── env.py ├── rewards │ ├── __init__.py │ ├── format_rewards.py │ └── other_rewards.py ├── trainer │ ├── prm_config.py │ ├── xpo_config.py │ ├── prm_trainer.py │ ├── xpo_trainer.py │ ├── bco_config.py │ ├── cpo_config.py │ ├── gkd_config.py │ ├── ppo_config.py │ ├── orpo_config.py │ ├── bco_trainer.py │ ├── cpo_trainer.py │ ├── gkd_trainer.py │ ├── nash_md_config.py │ ├── ppo_trainer.py │ ├── orpo_trainer.py │ ├── nash_md_trainer.py │ ├── online_dpo_config.py │ ├── kto_config.py │ ├── online_dpo_trainer.py │ ├── kto_trainer.py │ └── base_trainer.py ├── models │ ├── __init__.py │ ├── modeling_base.py │ └── modeling_value_head.py ├── templates │ ├── rm_model_card.md │ └── lm_model_card.md └── mergekit_utils.py ├── VERSION ├── requirements.txt ├── assets ├── logo-dark.png └── logo-light.png ├── docs └── source │ ├── winrate_callback.md │ ├── merge_model_callback.md │ ├── 
others.md │ ├── model_utils.md │ ├── callbacks.md │ ├── script_utils.md │ ├── rewards.md │ ├── chat_template_utils.md │ ├── installation.md │ ├── gspo_token.md │ ├── data_utils.md │ ├── bema_for_reference_model.md │ ├── deepspeed_integration.md │ ├── experimental_overview.md │ ├── gfpo.md │ ├── grpo_with_replay_buffer.md │ ├── trackio_integration.md │ ├── papo_trainer.md │ ├── liger_kernel_integration.md │ ├── use_model.md │ ├── judges.md │ └── _toctree.yml ├── examples ├── README.md ├── accelerate_configs │ ├── single_gpu.yaml │ ├── multi_gpu.yaml │ ├── deepspeed_zero1.yaml │ ├── deepspeed_zero2.yaml │ ├── deepspeed_zero3.yaml │ ├── fsdp2.yaml │ ├── fsdp1.yaml │ ├── context_parallel_2gpu.yaml │ └── alst_ulysses_4gpu.yaml ├── cli_configs │ └── example_config.yaml ├── scripts │ ├── dpo.py │ ├── sft.py │ ├── sft_gemma3.py │ ├── sft_gpt_oss.py │ └── rloo.py ├── notebooks │ └── README.md └── datasets │ └── deepmath_103k.py ├── MANIFEST.in ├── docker ├── trl │ └── Dockerfile └── trl-dev │ └── Dockerfile ├── .github ├── workflows │ ├── issue_auto_labeller.yml │ ├── upload_pr_documentation.yml │ ├── build_documentation.yml │ ├── trufflehog.yml │ ├── build_pr_documentation.yml │ ├── codeQL.yml │ ├── clear_cache.yml │ ├── publish.yml │ ├── tests-experimental.yml │ ├── tests_latest.yml │ ├── docker-build.yml │ └── slow-tests.yml ├── codeql │ └── custom-queries.qls ├── ISSUE_TEMPLATE │ ├── feature-request.yml │ ├── new-trainer-addition.yml │ └── bug-report.yml └── PULL_REQUEST_TEMPLATE.md ├── .pre-commit-config.yaml ├── tests ├── __init__.py ├── experimental │ ├── __init__.py │ ├── test_minillm_trainer.py │ ├── test_gspo_token_trainer.py │ ├── test_merge_model_callback.py │ └── test_judges.py ├── testing_constants.py ├── conftest.py ├── test_model_utils.py ├── test_rich_progress_callback.py └── test_collators.py ├── Makefile ├── CITATION.cff ├── .gitignore └── scripts └── add_copyrights.py /trl/py.typed: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /VERSION: -------------------------------------------------------------------------------- 1 | 0.27.0.dev0 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate>=1.4.0 2 | datasets>=3.0.0 3 | transformers>=4.56.1 -------------------------------------------------------------------------------- /assets/logo-dark.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/trl/HEAD/assets/logo-dark.png -------------------------------------------------------------------------------- /assets/logo-light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huggingface/trl/HEAD/assets/logo-light.png -------------------------------------------------------------------------------- /docs/source/winrate_callback.md: -------------------------------------------------------------------------------- 1 | # WinRateCallback 2 | 3 | [[autodoc]] experimental.winrate_callback.WinRateCallback 4 | -------------------------------------------------------------------------------- /docs/source/merge_model_callback.md: -------------------------------------------------------------------------------- 1 | # MergeModelCallback 2 | 3 | [[autodoc]] experimental.merge_model_callback.MergeModelCallback 4 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | Please check out https://huggingface.co/docs/trl/example_overview for documentation on our examples. 
4 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include CONTRIBUTING.md 3 | include README.md 4 | include trl/accelerate_configs/*.yaml 5 | include trl/templates/*.md 6 | recursive-exclude * __pycache__ 7 | prune tests 8 | -------------------------------------------------------------------------------- /docs/source/others.md: -------------------------------------------------------------------------------- 1 | # Other 2 | 3 | ## profiling_decorator 4 | 5 | [[autodoc]] extras.profiling.profiling_decorator 6 | 7 | ## profiling_context 8 | 9 | [[autodoc]] extras.profiling.profiling_context 10 | -------------------------------------------------------------------------------- /docker/trl/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel 2 | RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/* 3 | RUN pip install --upgrade pip uv 4 | RUN uv pip install --system trl[liger,peft,vlm] kernels trackio -------------------------------------------------------------------------------- /docs/source/model_utils.md: -------------------------------------------------------------------------------- 1 | # Model Utilities 2 | 3 | ## get_act_offloading_ctx_manager 4 | 5 | [[autodoc]] models.get_act_offloading_ctx_manager 6 | 7 | ## disable_gradient_checkpointing 8 | 9 | [[autodoc]] models.utils.disable_gradient_checkpointing 10 | 11 | ## create_reference_model 12 | 13 | [[autodoc]] create_reference_model 14 | -------------------------------------------------------------------------------- /docker/trl-dev/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel 2 | RUN apt-get update && apt-get install -y git && rm -rf 
/var/lib/apt/lists/* 3 | RUN pip install --upgrade pip uv 4 | RUN uv pip install --system --no-cache "git+https://github.com/huggingface/trl.git#egg=trl[liger,peft,vlm]" 5 | RUN uv pip install --system kernels liger_kernel peft trackio -------------------------------------------------------------------------------- /docs/source/callbacks.md: -------------------------------------------------------------------------------- 1 | # Callbacks 2 | 3 | ## SyncRefModelCallback 4 | 5 | [[autodoc]] SyncRefModelCallback 6 | 7 | ## RichProgressCallback 8 | 9 | [[autodoc]] RichProgressCallback 10 | 11 | ## LogCompletionsCallback 12 | 13 | [[autodoc]] LogCompletionsCallback 14 | 15 | ## BEMACallback 16 | 17 | [[autodoc]] BEMACallback 18 | 19 | ## WeaveCallback 20 | 21 | [[autodoc]] WeaveCallback 22 | -------------------------------------------------------------------------------- /.github/workflows/issue_auto_labeller.yml: -------------------------------------------------------------------------------- 1 | name: "Hugging Face Issue Labeler" 2 | on: 3 | issues: 4 | types: opened 5 | 6 | jobs: 7 | triage: 8 | runs-on: ubuntu-latest 9 | permissions: 10 | issues: write 11 | steps: 12 | - uses: actions/checkout@v3 13 | - uses: August-murr/auto-labeler@main 14 | with: 15 | hf-api-key: ${{ secrets.CI_HF_API_TOKEN }} 16 | -------------------------------------------------------------------------------- /trl/accelerate_configs/single_gpu.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: "NO" 4 | downcast_bf16: 'no' 5 | gpu_ids: all 6 | machine_rank: 0 7 | main_training_function: main 8 | mixed_precision: 'bf16' 9 | num_machines: 1 10 | num_processes: 8 11 | rdzv_backend: static 12 | same_network: true 13 | tpu_env: [] 14 | tpu_use_cluster: false 15 | tpu_use_sudo: false 16 | use_cpu: false 17 | -------------------------------------------------------------------------------- 
/examples/accelerate_configs/single_gpu.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: "NO" 4 | downcast_bf16: 'no' 5 | gpu_ids: all 6 | machine_rank: 0 7 | main_training_function: main 8 | mixed_precision: 'bf16' 9 | num_machines: 1 10 | num_processes: 1 11 | rdzv_backend: static 12 | same_network: true 13 | tpu_env: [] 14 | tpu_use_cluster: false 15 | tpu_use_sudo: false 16 | use_cpu: false 17 | -------------------------------------------------------------------------------- /trl/accelerate_configs/multi_gpu.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: 'no' 5 | gpu_ids: all 6 | machine_rank: 0 7 | main_training_function: main 8 | mixed_precision: 'bf16' 9 | num_machines: 1 10 | num_processes: 8 11 | rdzv_backend: static 12 | same_network: true 13 | tpu_env: [] 14 | tpu_use_cluster: false 15 | tpu_use_sudo: false 16 | use_cpu: false 17 | -------------------------------------------------------------------------------- /examples/accelerate_configs/multi_gpu.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: 'no' 5 | gpu_ids: all 6 | machine_rank: 0 7 | main_training_function: main 8 | mixed_precision: 'bf16' 9 | num_machines: 1 10 | num_processes: 8 11 | rdzv_backend: static 12 | same_network: true 13 | tpu_env: [] 14 | tpu_use_cluster: false 15 | tpu_use_sudo: false 16 | use_cpu: false 17 | -------------------------------------------------------------------------------- /.github/workflows/upload_pr_documentation.yml: -------------------------------------------------------------------------------- 1 | name: Upload PR Documentation 2 | 3 | on: 4 | workflow_run: 5 | 
workflows: ["Build PR Documentation"] 6 | types: 7 | - completed 8 | 9 | jobs: 10 | build: 11 | uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main 12 | with: 13 | package_name: trl 14 | secrets: 15 | hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} 16 | comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }} -------------------------------------------------------------------------------- /docs/source/script_utils.md: -------------------------------------------------------------------------------- 1 | # Scripts Utilities 2 | 3 | ## ScriptArguments 4 | 5 | [[autodoc]] ScriptArguments 6 | 7 | ## TrlParser 8 | 9 | [[autodoc]] TrlParser 10 | - parse_args_and_config 11 | - parse_args_into_dataclasses 12 | - set_defaults_with_config 13 | 14 | ## get_dataset 15 | 16 | [[autodoc]] get_dataset 17 | 18 | ## DatasetConfig 19 | 20 | [[autodoc]] scripts.utils.DatasetConfig 21 | 22 | ## DatasetMixtureConfig 23 | 24 | [[autodoc]] DatasetMixtureConfig 25 | -------------------------------------------------------------------------------- /examples/cli_configs/example_config.yaml: -------------------------------------------------------------------------------- 1 | # This is an example configuration file of TRL CLI, you can use it for 2 | # SFT like that: `trl sft --config config.yaml --output_dir test-sft` 3 | # The YAML file supports environment variables by adding an `env` field 4 | # as below 5 | 6 | # env: 7 | # CUDA_VISIBLE_DEVICES: 0 8 | 9 | model_name_or_path: 10 | Qwen/Qwen2.5-0.5B 11 | dataset_name: 12 | stanfordnlp/imdb 13 | report_to: 14 | none 15 | learning_rate: 16 | 0.0001 17 | lr_scheduler_type: 18 | cosine 19 | -------------------------------------------------------------------------------- /.github/workflows/build_documentation.yml: -------------------------------------------------------------------------------- 1 | name: Build documentation 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - doc-builder* 8 | - v*-release 9 | 10 | env: 11 | 
TRL_EXPERIMENTAL_SILENCE: 1 12 | 13 | jobs: 14 | build: 15 | uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main 16 | with: 17 | commit_sha: ${{ github.sha }} 18 | package: trl 19 | version_tag_suffix: "" 20 | secrets: 21 | hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} 22 | -------------------------------------------------------------------------------- /docs/source/rewards.md: -------------------------------------------------------------------------------- 1 | # Reward Functions 2 | 3 | This module contains some useful reward functions, primarily intended for use with the [`GRPOTrainer`] and [`RLOOTrainer`]. 4 | 5 | ## accuracy_reward 6 | 7 | [[autodoc]] rewards.accuracy_reward 8 | 9 | ## reasoning_accuracy_reward 10 | 11 | [[autodoc]] rewards.reasoning_accuracy_reward 12 | 13 | ## think_format_reward 14 | 15 | [[autodoc]] rewards.think_format_reward 16 | 17 | ## get_soft_overlong_punishment 18 | 19 | [[autodoc]] rewards.get_soft_overlong_punishment 20 | -------------------------------------------------------------------------------- /docs/source/chat_template_utils.md: -------------------------------------------------------------------------------- 1 | # Chat template utilities 2 | 3 | ## clone_chat_template 4 | 5 | [[autodoc]] clone_chat_template 6 | 7 | ## add_response_schema 8 | 9 | [[autodoc]] chat_template_utils.add_response_schema 10 | 11 | ## is_chat_template_prefix_preserving 12 | 13 | [[autodoc]] chat_template_utils.is_chat_template_prefix_preserving 14 | 15 | ## get_training_chat_template 16 | 17 | [[autodoc]] chat_template_utils.get_training_chat_template 18 | 19 | ## parse_response 20 | 21 | [[autodoc]] chat_template_utils.parse_response 22 | -------------------------------------------------------------------------------- /trl/accelerate_configs/zero1.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | 
deepspeed_multinode_launcher: standard 5 | gradient_accumulation_steps: 1 6 | zero3_init_flag: false 7 | zero_stage: 1 8 | distributed_type: DEEPSPEED 9 | downcast_bf16: 'no' 10 | machine_rank: 0 11 | main_training_function: main 12 | mixed_precision: 'bf16' 13 | num_machines: 1 14 | num_processes: 8 15 | rdzv_backend: static 16 | same_network: true 17 | tpu_env: [] 18 | tpu_use_cluster: false 19 | tpu_use_sudo: false 20 | use_cpu: false 21 | -------------------------------------------------------------------------------- /examples/accelerate_configs/deepspeed_zero1.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_multinode_launcher: standard 5 | gradient_accumulation_steps: 1 6 | zero3_init_flag: false 7 | zero_stage: 1 8 | distributed_type: DEEPSPEED 9 | downcast_bf16: 'no' 10 | machine_rank: 0 11 | main_training_function: main 12 | mixed_precision: 'bf16' 13 | num_machines: 1 14 | num_processes: 8 15 | rdzv_backend: static 16 | same_network: true 17 | tpu_env: [] 18 | tpu_use_cluster: false 19 | tpu_use_sudo: false 20 | use_cpu: false 21 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | rev: v0.13.3 4 | hooks: 5 | - id: ruff-check 6 | types_or: [ python, pyi ] 7 | args: [ --fix ] 8 | - id: ruff-format 9 | types_or: [ python, pyi ] 10 | 11 | # - repo: https://github.com/codespell-project/codespell 12 | # rev: v2.1.0 13 | # hooks: 14 | # - id: codespell 15 | # args: 16 | # - --ignore-words-list=nd,reacher,thist,ths,magent,ba 17 | # - --skip=docs/css/termynal.css,docs/js/termynal.js 18 | -------------------------------------------------------------------------------- /.github/workflows/trufflehog.yml: 
-------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | 4 | name: Secret Leaks 5 | 6 | jobs: 7 | trufflehog: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - name: Checkout code 11 | uses: actions/checkout@v4 12 | with: 13 | fetch-depth: 0 14 | - name: Secret Scanning 15 | uses: trufflesecurity/trufflehog@853e1e8d249fd1e29d0fcc7280d29b03df3d643d 16 | with: 17 | # exclude buggy postgres detector that is causing false positives and not relevant to our codebase 18 | extra_args: --results=verified,unknown --exclude-detectors=postgres 19 | -------------------------------------------------------------------------------- /trl/accelerate_configs/zero2.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_multinode_launcher: standard 5 | offload_optimizer_device: none 6 | offload_param_device: none 7 | zero3_init_flag: false 8 | zero_stage: 2 9 | distributed_type: DEEPSPEED 10 | downcast_bf16: 'no' 11 | machine_rank: 0 12 | main_training_function: main 13 | mixed_precision: 'bf16' 14 | num_machines: 1 15 | num_processes: 8 16 | rdzv_backend: static 17 | same_network: true 18 | tpu_env: [] 19 | tpu_use_cluster: false 20 | tpu_use_sudo: false 21 | use_cpu: false 22 | -------------------------------------------------------------------------------- /examples/accelerate_configs/deepspeed_zero2.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_multinode_launcher: standard 5 | offload_optimizer_device: none 6 | offload_param_device: none 7 | zero3_init_flag: false 8 | zero_stage: 2 9 | distributed_type: DEEPSPEED 10 | downcast_bf16: 'no' 11 | machine_rank: 0 12 | main_training_function: main 13 | mixed_precision: 'bf16' 14 | num_machines: 1 15 | num_processes: 8 16 | rdzv_backend: static 
17 | same_network: true 18 | tpu_env: [] 19 | tpu_use_cluster: false 20 | tpu_use_sudo: false 21 | use_cpu: false 22 | -------------------------------------------------------------------------------- /trl/accelerate_configs/zero3.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_multinode_launcher: standard 5 | offload_optimizer_device: none 6 | offload_param_device: none 7 | zero3_init_flag: true 8 | zero3_save_16bit_model: true 9 | zero_stage: 3 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: 'no' 12 | machine_rank: 0 13 | main_training_function: main 14 | mixed_precision: bf16 15 | num_machines: 1 16 | num_processes: 8 17 | rdzv_backend: static 18 | same_network: true 19 | tpu_env: [] 20 | tpu_use_cluster: false 21 | tpu_use_sudo: false 22 | use_cpu: false 23 | -------------------------------------------------------------------------------- /examples/accelerate_configs/deepspeed_zero3.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_multinode_launcher: standard 5 | offload_optimizer_device: none 6 | offload_param_device: none 7 | zero3_init_flag: true 8 | zero3_save_16bit_model: true 9 | zero_stage: 3 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: 'no' 12 | machine_rank: 0 13 | main_training_function: main 14 | mixed_precision: bf16 15 | num_machines: 1 16 | num_processes: 8 17 | rdzv_backend: static 18 | same_network: true 19 | tpu_env: [] 20 | tpu_use_cluster: false 21 | tpu_use_sudo: false 22 | use_cpu: false 23 | -------------------------------------------------------------------------------- /.github/workflows/build_pr_documentation.yml: -------------------------------------------------------------------------------- 1 | name: Build PR Documentation 2 | 3 | on: 4 | pull_request: 5 | 6 | env: 
7 | TRL_EXPERIMENTAL_SILENCE: 1 8 | 9 | concurrency: 10 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 11 | cancel-in-progress: true 12 | 13 | jobs: 14 | build: 15 | if: github.event.pull_request.draft == false 16 | uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main 17 | with: 18 | commit_sha: ${{ github.event.pull_request.head.sha }} 19 | pr_number: ${{ github.event.number }} 20 | package: trl 21 | version_tag_suffix: "" 22 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /trl/extras/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/experimental/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: test precommit common_tests slow_tests tests_gpu test_experimental 2 | 3 | check_dirs := examples tests trl 4 | 5 | ACCELERATE_CONFIG_PATH = `pwd`/examples/accelerate_configs 6 | 7 | test: 8 | pytest -n auto -m "not slow and not low_priority" -s -v --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504||not less than or equal to 0.01)' tests/ 9 | 10 | precommit: 11 | python scripts/add_copyrights.py 12 | pre-commit run --all-files 13 | doc-builder style trl tests docs/source --max_len 119 14 | 15 | slow_tests: 16 | pytest -m "slow" tests/ $(if $(IS_GITHUB_CI),--report-log "slow_tests.log",) 17 | 18 | test_experimental: 19 | pytest -k "experimental" -n auto -s -v -------------------------------------------------------------------------------- /.github/workflows/codeQL.yml: -------------------------------------------------------------------------------- 1 | name: "CodeQL Analysis - Workflows" 2 | 3 | on: 4 | workflow_dispatch: 5 | 6 | jobs: 7 | analyze: 8 | name: "Analyze GitHub Workflows" 9 | runs-on: ubuntu-latest 10 | permissions: 11 | security-events: write 12 | actions: read 13 | contents: read 14 | 15 | steps: 16 | - name: "Checkout repository" 17 | uses: actions/checkout@v4 18 | 19 | - name: "Initialize CodeQL" 20 | uses: github/codeql-action/init@v2 21 | with: 22 | languages: "yaml" 23 | queries: +security-and-quality, ./.github/codeql/custom-queries.qls 24 | 25 | - name: "Perform CodeQL Analysis" 26 | uses: github/codeql-action/analyze@v2 27 | -------------------------------------------------------------------------------- /trl/experimental/gspo_token/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .grpo_trainer import GRPOTrainer 16 | -------------------------------------------------------------------------------- /trl/accelerate_configs/fsdp2.yaml: -------------------------------------------------------------------------------- 1 | # Requires accelerate 1.7.0 or higher 2 | compute_environment: LOCAL_MACHINE 3 | debug: false 4 | distributed_type: FSDP 5 | downcast_bf16: 'no' 6 | enable_cpu_affinity: false 7 | fsdp_config: 8 | fsdp_activation_checkpointing: false 9 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP 10 | fsdp_cpu_ram_efficient_loading: true 11 | fsdp_offload_params: false 12 | fsdp_reshard_after_forward: true 13 | fsdp_state_dict_type: FULL_STATE_DICT 14 | fsdp_version: 2 15 | machine_rank: 0 16 | main_training_function: main 17 | mixed_precision: bf16 18 | num_machines: 1 19 | num_processes: 8 20 | rdzv_backend: static 21 | same_network: true 22 | tpu_env: [] 23 | tpu_use_cluster: false 24 | tpu_use_sudo: false 25 | use_cpu: false 26 | -------------------------------------------------------------------------------- /examples/accelerate_configs/fsdp2.yaml: -------------------------------------------------------------------------------- 1 | # Requires accelerate 1.7.0 or higher 2 | compute_environment: LOCAL_MACHINE 3 | debug: false 4 | distributed_type: FSDP 5 | downcast_bf16: 'no' 6 | enable_cpu_affinity: false 7 | fsdp_config: 8 | 
fsdp_activation_checkpointing: false 9 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP 10 | fsdp_cpu_ram_efficient_loading: true 11 | fsdp_offload_params: false 12 | fsdp_reshard_after_forward: true 13 | fsdp_state_dict_type: FULL_STATE_DICT 14 | fsdp_version: 2 15 | machine_rank: 0 16 | main_training_function: main 17 | mixed_precision: bf16 18 | num_machines: 1 19 | num_processes: 8 20 | rdzv_backend: static 21 | same_network: true 22 | tpu_env: [] 23 | tpu_use_cluster: false 24 | tpu_use_sudo: false 25 | use_cpu: false 26 | -------------------------------------------------------------------------------- /trl/experimental/bco/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .bco_config import BCOConfig 16 | from .bco_trainer import BCOTrainer 17 | -------------------------------------------------------------------------------- /trl/experimental/gfpo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .gfpo_config import GFPOConfig 16 | from .gfpo_trainer import GFPOTrainer 17 | -------------------------------------------------------------------------------- /trl/experimental/papo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | from .papo_config import PAPOConfig 17 | from .papo_trainer import PAPOTrainer 18 | -------------------------------------------------------------------------------- /trl/experimental/bema_for_ref_model/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .callback import BEMACallback 16 | from .dpo_trainer import DPOTrainer 17 | -------------------------------------------------------------------------------- /trl/experimental/openenv/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .utils import generate_rollout_completions 16 | 17 | 18 | __all__ = ["generate_rollout_completions"] 19 | -------------------------------------------------------------------------------- /trl/experimental/cpo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .cpo_config import CPOConfig 16 | from .cpo_trainer import CPOTrainer 17 | 18 | 19 | __all__ = ["CPOConfig", "CPOTrainer"] 20 | -------------------------------------------------------------------------------- /trl/experimental/gkd/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .gkd_config import GKDConfig 16 | from .gkd_trainer import GKDTrainer 17 | 18 | 19 | __all__ = ["GKDConfig", "GKDTrainer"] 20 | -------------------------------------------------------------------------------- /trl/experimental/kto/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .kto_config import KTOConfig 16 | from .kto_trainer import KTOTrainer 17 | 18 | 19 | __all__ = ["KTOConfig", "KTOTrainer"] 20 | -------------------------------------------------------------------------------- /trl/experimental/prm/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .prm_config import PRMConfig 16 | from .prm_trainer import PRMTrainer 17 | 18 | 19 | __all__ = ["PRMConfig", "PRMTrainer"] 20 | -------------------------------------------------------------------------------- /trl/experimental/xpo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. 
All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .xpo_config import XPOConfig 16 | from .xpo_trainer import XPOTrainer 17 | 18 | 19 | __all__ = ["XPOConfig", "XPOTrainer"] 20 | -------------------------------------------------------------------------------- /trl/experimental/gold/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from .gold_config import GOLDConfig 16 | from .gold_trainer import GOLDTrainer 17 | 18 | 19 | __all__ = ["GOLDConfig", "GOLDTrainer"] 20 | -------------------------------------------------------------------------------- /trl/experimental/orpo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .orpo_config import ORPOConfig 16 | from .orpo_trainer import ORPOTrainer 17 | 18 | 19 | __all__ = ["ORPOConfig", "ORPOTrainer"] 20 | -------------------------------------------------------------------------------- /tests/testing_constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | CI_HUB_USER = "__DUMMY_TRANSFORMERS_USER__" 16 | CI_HUB_USER_FULL_NAME = "Dummy User" 17 | 18 | CI_HUB_ENDPOINT = "https://hub-ci.huggingface.co" 19 | -------------------------------------------------------------------------------- /trl/experimental/minillm/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .minillm_config import MiniLLMConfig 16 | from .minillm_trainer import MiniLLMTrainer 17 | 18 | 19 | __all__ = ["MiniLLMConfig", "MiniLLMTrainer"] 20 | -------------------------------------------------------------------------------- /trl/experimental/nash_md/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .nash_md_config import NashMDConfig 16 | from .nash_md_trainer import NashMDTrainer 17 | 18 | 19 | __all__ = ["NashMDConfig", "NashMDTrainer"] 20 | -------------------------------------------------------------------------------- /trl/experimental/online_dpo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from .online_dpo_config import OnlineDPOConfig 16 | from .online_dpo_trainer import OnlineDPOTrainer 17 | 18 | 19 | __all__ = ["OnlineDPOConfig", "OnlineDPOTrainer"] 20 | -------------------------------------------------------------------------------- /trl/accelerate_configs/fsdp1.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: FSDP 4 | downcast_bf16: 'no' 5 | enable_cpu_affinity: false 6 | fsdp_config: 7 | fsdp_activation_checkpointing: false 8 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP 9 | fsdp_backward_prefetch: BACKWARD_PRE 10 | fsdp_cpu_ram_efficient_loading: true 11 | fsdp_forward_prefetch: true 12 | fsdp_offload_params: false 13 | fsdp_reshard_after_forward: FULL_SHARD 14 | fsdp_state_dict_type: FULL_STATE_DICT 15 | fsdp_sync_module_states: true 16 | fsdp_use_orig_params: true 17 | fsdp_version: 1 18 | machine_rank: 0 19 | main_training_function: main 20 | mixed_precision: bf16 21 | num_machines: 1 22 | num_processes: 8 23 | rdzv_backend: static 24 | same_network: true 25 | tpu_env: [] 26 | tpu_use_cluster: false 27 | tpu_use_sudo: false 28 | use_cpu: false 29 | -------------------------------------------------------------------------------- /trl/experimental/grpo_with_replay_buffer/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .grpo_with_replay_buffer_config import GRPOWithReplayBufferConfig 16 | from .grpo_with_replay_buffer_trainer import GRPOWithReplayBufferTrainer, ReplayBuffer 17 | -------------------------------------------------------------------------------- /examples/accelerate_configs/fsdp1.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: FSDP 4 | downcast_bf16: 'no' 5 | enable_cpu_affinity: false 6 | fsdp_config: 7 | fsdp_activation_checkpointing: false 8 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP 9 | fsdp_backward_prefetch: BACKWARD_PRE 10 | fsdp_cpu_ram_efficient_loading: true 11 | fsdp_forward_prefetch: true 12 | fsdp_offload_params: false 13 | fsdp_reshard_after_forward: FULL_SHARD 14 | fsdp_state_dict_type: FULL_STATE_DICT 15 | fsdp_sync_module_states: true 16 | fsdp_use_orig_params: true 17 | fsdp_version: 1 18 | machine_rank: 0 19 | main_training_function: main 20 | mixed_precision: bf16 21 | num_machines: 1 22 | num_processes: 8 23 | rdzv_backend: static 24 | same_network: true 25 | tpu_env: [] 26 | tpu_use_cluster: false 27 | tpu_use_sudo: false 28 | use_cpu: false 29 | -------------------------------------------------------------------------------- /examples/scripts/dpo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | ############################################################################################### 16 | # This file has been moved to https://github.com/huggingface/trl/blob/main/trl/scripts/dpo.py # 17 | ############################################################################################### 18 | -------------------------------------------------------------------------------- /examples/scripts/sft.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | ############################################################################################### 16 | # This file has been moved to https://github.com/huggingface/trl/blob/main/trl/scripts/sft.py # 17 | ############################################################################################### 18 | -------------------------------------------------------------------------------- /.github/workflows/clear_cache.yml: -------------------------------------------------------------------------------- 1 | name: "Cleanup Cache" 2 | 3 | on: 4 | workflow_dispatch: 5 | schedule: 6 | - cron: "0 0 * * *" 7 | 8 | jobs: 9 | cleanup: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Check out code 13 | uses: actions/checkout@v4 14 | 15 | - name: Cleanup 16 | run: | 17 | gh extension install actions/gh-actions-cache 18 | 19 | REPO=${{ github.repository }} 20 | 21 | echo "Fetching list of cache key" 22 | cacheKeysForPR=$(gh actions-cache list -R $REPO | cut -f 1 ) 23 | 24 | ## Setting this to not fail the workflow while deleting cache keys. 25 | set +e 26 | echo "Deleting caches..." 27 | for cacheKey in $cacheKeysForPR 28 | do 29 | gh actions-cache delete $cacheKey -R $REPO --confirm 30 | done 31 | echo "Done" 32 | env: 33 | GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} 34 | -------------------------------------------------------------------------------- /docs/source/installation.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | 3 | You can install TRL either from PyPI or from source: 4 | 5 | ## PyPI 6 | 7 | Install the library with pip or [uv](https://docs.astral.sh/uv/): 8 | 9 | 10 | 11 | 12 | uv is a fast Rust-based Python package and project manager. Refer to [Installation](https://docs.astral.sh/uv/getting-started/installation/) for installation instructions. 
13 | 14 | ```bash 15 | uv pip install trl 16 | ``` 17 | 18 | 19 | 20 | 21 | ```bash 22 | pip install trl 23 | ``` 24 | 25 | 26 | 27 | 28 | ## Source 29 | 30 | You can also install the latest version from source. First clone the repo and then run the installation with `pip`: 31 | 32 | ```bash 33 | git clone https://github.com/huggingface/trl.git 34 | cd trl/ 35 | pip install -e . 36 | ``` 37 | 38 | If you want the development install you can replace the pip install with the following: 39 | 40 | ```bash 41 | pip install -e ".[dev]" 42 | ``` 43 | -------------------------------------------------------------------------------- /examples/accelerate_configs/context_parallel_2gpu.yaml: -------------------------------------------------------------------------------- 1 | # Context Parallelism with FSDP for 2 GPUs 2 | compute_environment: LOCAL_MACHINE 3 | debug: false 4 | distributed_type: FSDP 5 | downcast_bf16: 'no' 6 | enable_cpu_affinity: false 7 | fsdp_config: 8 | fsdp_activation_checkpointing: true # Enable activation checkpointing for memory efficiency 9 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP 10 | fsdp_cpu_ram_efficient_loading: true 11 | fsdp_offload_params: false 12 | fsdp_reshard_after_forward: true 13 | fsdp_state_dict_type: FULL_STATE_DICT 14 | fsdp_version: 2 15 | machine_rank: 0 16 | main_training_function: main 17 | mixed_precision: bf16 18 | num_machines: 1 19 | num_processes: 2 # Number of GPUs 20 | rdzv_backend: static 21 | same_network: true 22 | tpu_env: [] 23 | tpu_use_cluster: false 24 | tpu_use_sudo: false 25 | use_cpu: false 26 | parallelism_config: 27 | parallelism_config_dp_replicate_size: 1 28 | parallelism_config_dp_shard_size: 1 29 | parallelism_config_tp_size: 1 30 | parallelism_config_cp_size: 2 # Context parallel size 31 | -------------------------------------------------------------------------------- /docs/source/gspo_token.md: -------------------------------------------------------------------------------- 1 | # GSPO-token 2 | 3 | 
In the paper [Group Sequence Policy Optimization](https://huggingface.co/papers/2507.18071), the authors propose a token-level objective variant to GSPO, called GSPO-token. To use GSPO-token, you can use the `GRPOTrainer` class in `trl.experimental.gspo_token`. 4 | 5 | ## Usage 6 | 7 | ```python 8 | from trl.experimental.gspo_token import GRPOTrainer 9 | from trl import GRPOConfig 10 | 11 | training_args = GRPOConfig( 12 | importance_sampling_level="sequence_token", 13 | ... 14 | ) 15 | ``` 16 | 17 | > [!WARNING] 18 | > To leverage GSPO-token, the user will need to provide the per-token advantage \\( \hat{A_{i,t}} \\) for each token \\( t \\) in the sequence \\( i \\) (i.e., make \\( \hat{A_{i,t}} \\) varies with \\( t \\)—which isn't the case here, \\( \hat{A_{i,t}}=\hat{A_{i}} \\)). Otherwise, GSPO-Token gradient is just equivalent to the original GSPO implementation. 19 | 20 | ## GRPOTrainer 21 | 22 | [[autodoc]] experimental.gspo_token.GRPOTrainer 23 | - train 24 | - save_model 25 | - push_to_hub 26 | -------------------------------------------------------------------------------- /docs/source/data_utils.md: -------------------------------------------------------------------------------- 1 | # Data Utilities 2 | 3 | ## prepare_multimodal_messages 4 | 5 | [[autodoc]] prepare_multimodal_messages 6 | 7 | ## prepare_multimodal_messages_vllm 8 | 9 | [[autodoc]] prepare_multimodal_messages_vllm 10 | 11 | ## is_conversational 12 | 13 | [[autodoc]] is_conversational 14 | 15 | ## is_conversational_from_value 16 | 17 | [[autodoc]] is_conversational_from_value 18 | 19 | ## apply_chat_template 20 | 21 | [[autodoc]] apply_chat_template 22 | 23 | ## maybe_apply_chat_template 24 | 25 | [[autodoc]] maybe_apply_chat_template 26 | 27 | ## maybe_convert_to_chatml 28 | 29 | [[autodoc]] maybe_convert_to_chatml 30 | 31 | ## extract_prompt 32 | 33 | [[autodoc]] extract_prompt 34 | 35 | ## maybe_extract_prompt 36 | 37 | [[autodoc]] maybe_extract_prompt 38 | 39 | ## 
unpair_preference_dataset 40 | 41 | [[autodoc]] unpair_preference_dataset 42 | 43 | ## maybe_unpair_preference_dataset 44 | 45 | [[autodoc]] maybe_unpair_preference_dataset 46 | 47 | ## pack_dataset 48 | 49 | [[autodoc]] pack_dataset 50 | 51 | ## truncate_dataset 52 | 53 | [[autodoc]] truncate_dataset 54 | -------------------------------------------------------------------------------- /.github/codeql/custom-queries.qls: -------------------------------------------------------------------------------- 1 | import codeql 2 | 3 | from WorkflowString interpolation, Workflow workflow 4 | where 5 | interpolation.getStringValue().matches("${{ github.event.issue.title }}") or 6 | interpolation.getStringValue().matches("${{ github.event.issue.body }}") or 7 | interpolation.getStringValue().matches("${{ github.event.pull_request.title }}") or 8 | interpolation.getStringValue().matches("${{ github.event.pull_request.body }}") or 9 | interpolation.getStringValue().matches("${{ github.event.review.body }}") or 10 | interpolation.getStringValue().matches("${{ github.event.comment.body }}") or 11 | interpolation.getStringValue().matches("${{ github.event.inputs.* }}") or 12 | interpolation.getStringValue().matches("${{ github.event.head_commit.message }}") or 13 | interpolation.getStringValue().matches("${{ github.event.* }}") and 14 | ( 15 | step.getKey() = "run" or // Injection in run 16 | step.getKey() = "env" or // Injection via env 17 | step.getKey() = "with" // Injection via with 18 | ) 19 | select workflow, "🚨 Do not use directly as input of action" 20 | -------------------------------------------------------------------------------- /trl/experimental/ppo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from .modeling_value_head import ( 16 | AutoModelForCausalLMWithValueHead, 17 | AutoModelForSeq2SeqLMWithValueHead, 18 | PreTrainedModelWrapper, 19 | ) 20 | from .ppo_config import PPOConfig 21 | from .ppo_trainer import PPOTrainer 22 | 23 | 24 | __all__ = [ 25 | "AutoModelForCausalLMWithValueHead", 26 | "AutoModelForSeq2SeqLMWithValueHead", 27 | "PreTrainedModelWrapper", 28 | "PPOConfig", 29 | "PPOTrainer", 30 | ] 31 | -------------------------------------------------------------------------------- /trl/experimental/judges/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from .judges import ( 16 | AllTrueJudge, 17 | BaseBinaryJudge, 18 | BaseJudge, 19 | BasePairwiseJudge, 20 | BaseRankJudge, 21 | HfPairwiseJudge, 22 | OpenAIPairwiseJudge, 23 | PairRMJudge, 24 | ) 25 | 26 | 27 | __all__ = [ 28 | "AllTrueJudge", 29 | "BaseBinaryJudge", 30 | "BaseJudge", 31 | "BasePairwiseJudge", 32 | "BaseRankJudge", 33 | "HfPairwiseJudge", 34 | "OpenAIPairwiseJudge", 35 | "PairRMJudge", 36 | ] 37 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to PyPI 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - v*-release 8 | paths: 9 | - "VERSION" 10 | 11 | jobs: 12 | publish: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v4 16 | 17 | - name: Read version 18 | id: get_version 19 | run: echo "version=$(cat VERSION)" >> $GITHUB_OUTPUT 20 | 21 | - name: Debug - Show version.txt content 22 | run: echo "Version is ${{ steps.get_version.outputs.version }}" 23 | 24 | - name: Set up Python 25 | uses: actions/setup-python@v4 26 | with: 27 | python-version: "3.x" 28 | 29 | - name: Install dependencies 30 | run: | 31 | python -m pip install --upgrade pip 32 | pip install build twine 33 | 34 | - name: Build package 35 | run: python -m build 36 | 37 | - name: Publish to PyPI 38 | if: ${{ !contains(steps.get_version.outputs.version, 'dev') }} 39 | env: 40 | TWINE_USERNAME: __token__ 41 | TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }} 42 | run: | 43 | python -m twine upload dist/* 44 | -------------------------------------------------------------------------------- /trl/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import TYPE_CHECKING 16 | 17 | from ..import_utils import _LazyModule 18 | 19 | 20 | _import_structure = { 21 | "utils": ["DatasetMixtureConfig", "ScriptArguments", "TrlParser", "get_dataset", "init_zero_verbose"], 22 | } 23 | 24 | if TYPE_CHECKING: 25 | from .utils import DatasetMixtureConfig, ScriptArguments, TrlParser, get_dataset, init_zero_verbose 26 | else: 27 | import sys 28 | 29 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) 30 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.yml: -------------------------------------------------------------------------------- 1 | name: "\U0001F680 Feature request" 2 | description: Submit a proposal/request for a new TRL feature 3 | labels: [ "Feature request" ] 4 | body: 5 | - type: textarea 6 | id: feature-request 7 | validations: 8 | required: true 9 | attributes: 10 | label: Feature request 11 | description: | 12 | A clear and concise description of the feature proposal. Please provide a link to the paper and code in case they exist. 13 | 14 | - type: textarea 15 | id: motivation 16 | validations: 17 | required: true 18 | attributes: 19 | label: Motivation 20 | description: | 21 | Please outline the motivation for the proposal. Is your feature request related to a problem? 
e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too. 22 | 23 | 24 | - type: textarea 25 | id: contribution 26 | validations: 27 | required: true 28 | attributes: 29 | label: Your contribution 30 | description: | 31 | Is there any way that you could help, e.g. by submitting a PR? Make sure to read the CONTRIBUTING.MD [readme](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md) 32 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | title: 'TRL: Transformer Reinforcement Learning' 3 | message: >- 4 | If you use this software, please cite it using the 5 | metadata from this file. 6 | type: software 7 | authors: 8 | - given-names: Leandro 9 | family-names: von Werra 10 | - given-names: Younes 11 | family-names: Belkada 12 | - given-names: Lewis 13 | family-names: Tunstall 14 | - given-names: Edward 15 | family-names: Beeching 16 | - given-names: Tristan 17 | family-names: Thrush 18 | - given-names: Nathan 19 | family-names: Lambert 20 | - given-names: Shengyi 21 | family-names: Huang 22 | - given-names: Kashif 23 | family-names: Rasul 24 | - given-names: Quentin 25 | family-names: Gallouédec 26 | repository-code: 'https://github.com/huggingface/trl' 27 | abstract: "With trl you can train transformer language models with Proximal Policy Optimization (PPO). The library is built on top of the transformers library by \U0001F917 Hugging Face. Therefore, pre-trained language models can be directly loaded via transformers. At this point, most decoder and encoder-decoder architectures are supported." 
28 | keywords: 29 | - rlhf 30 | - deep-learning 31 | - pytorch 32 | - transformers 33 | license: Apache-2.0 34 | version: "0.26" 35 | -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import gc 16 | 17 | import pytest 18 | import torch 19 | 20 | 21 | @pytest.fixture(autouse=True) 22 | def cleanup_gpu(): 23 | """ 24 | Automatically cleanup GPU memory after each test. 25 | 26 | This fixture helps prevent CUDA out of memory errors when running tests in parallel with pytest-xdist by ensuring 27 | models and tensors are properly garbage collected and GPU memory caches are cleared between tests. 28 | """ 29 | yield 30 | # Cleanup after test 31 | gc.collect() 32 | if torch.cuda.is_available(): 33 | torch.cuda.empty_cache() 34 | torch.cuda.synchronize() 35 | -------------------------------------------------------------------------------- /trl/experimental/bema_for_ref_model/dpo_trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from ...trainer.dpo_trainer import DPOTrainer as _DPOTrainer 16 | from .callback import CallbackHandlerWithRefModel 17 | 18 | 19 | class DPOTrainer(_DPOTrainer): 20 | def __init__(self, *args, **kwargs): 21 | super().__init__(*args, **kwargs) 22 | # Replace with a new one that calls the events with the reference model 23 | self.callback_handler = CallbackHandlerWithRefModel( 24 | self.callback_handler.callbacks, 25 | self.model, 26 | self.ref_model, 27 | self.processing_class, 28 | self.optimizer, 29 | self.lr_scheduler, 30 | ) 31 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/new-trainer-addition.yml: -------------------------------------------------------------------------------- 1 | name: "\U0001F31F New trainer addition" 2 | description: Submit a proposal/request to implement a new trainer for a post-training method 3 | labels: [ "New trainer" ] 4 | 5 | body: 6 | - type: textarea 7 | id: description-request 8 | validations: 9 | required: true 10 | attributes: 11 | label: Method description 12 | description: | 13 | Put any and all important information relative to the method 14 | 15 | - type: checkboxes 16 | id: information-tasks 17 | attributes: 18 | label: Open source status 19 | description: | 20 | Please note that if the method implementation isn't available or model weights with training datasets aren't available, we are less likely to implement it in `trl`. 
21 | options: 22 | - label: "The method implementation is available" 23 | - label: "The model weights are available" 24 | - label: "The training datasets are available" 25 | 26 | - type: textarea 27 | id: additional-info 28 | attributes: 29 | label: Provide useful links for the implementation 30 | description: | 31 | Please provide information regarding the implementation, the weights, and the authors. 32 | Please mention the authors by @gh-username if you're aware of their usernames. 33 | -------------------------------------------------------------------------------- /docs/source/bema_for_reference_model.md: -------------------------------------------------------------------------------- 1 | # BEMA for Reference Model 2 | 3 | This feature implements the BEMA algorithm to update the reference model during DPO training. 4 | 5 | ## Usage 6 | 7 | ```python 8 | from trl.experimental.bema_for_ref_model import BEMACallback, DPOTrainer 9 | from datasets import load_dataset 10 | from transformers import AutoModelForCausalLM, AutoTokenizer 11 | 12 | 13 | pref_dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train") 14 | ref_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") 15 | 16 | bema_callback = BEMACallback(update_ref_model=True) 17 | 18 | model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") 19 | tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5") 20 | tokenizer.pad_token = tokenizer.eos_token 21 | 22 | trainer = DPOTrainer( 23 | model=model, 24 | ref_model=ref_model, 25 | train_dataset=pref_dataset, 26 | processing_class=tokenizer, 27 | callbacks=[bema_callback], 28 | ) 29 | 30 | trainer.train() 31 | ``` 32 | 33 | ## DPOTrainer 34 | 35 | [[autodoc]] experimental.bema_for_ref_model.DPOTrainer 36 | - train 37 | - save_model 38 | - push_to_hub 39 | 40 | ## BEMACallback 41 | 42 | [[autodoc]] 
experimental.bema_for_ref_model.BEMACallback 43 | -------------------------------------------------------------------------------- /trl/rewards/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import sys 16 | from typing import TYPE_CHECKING 17 | 18 | from ..import_utils import _LazyModule 19 | 20 | 21 | _import_structure = { 22 | "accuracy_rewards": ["accuracy_reward", "reasoning_accuracy_reward"], 23 | "format_rewards": ["think_format_reward"], 24 | "other_rewards": ["get_soft_overlong_punishment"], 25 | } 26 | 27 | 28 | if TYPE_CHECKING: 29 | from .accuracy_rewards import accuracy_reward, reasoning_accuracy_reward 30 | from .format_rewards import think_format_reward 31 | from .other_rewards import get_soft_overlong_punishment 32 | 33 | 34 | else: 35 | sys.modules[__name__] = _LazyModule(__name__, __file__, _import_structure, module_spec=__spec__) 36 | -------------------------------------------------------------------------------- /trl/extras/dataset_formatting.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | import datasets 17 | from datasets import Value 18 | from packaging import version 19 | 20 | 21 | if version.parse(datasets.__version__) >= version.parse("4.0.0"): 22 | from datasets import List 23 | 24 | FORMAT_MAPPING = { 25 | "chatml": List({"content": Value(dtype="string", id=None), "role": Value(dtype="string", id=None)}), 26 | "instruction": {"completion": Value(dtype="string", id=None), "prompt": Value(dtype="string", id=None)}, 27 | } 28 | else: 29 | FORMAT_MAPPING = { 30 | "chatml": [{"content": Value(dtype="string", id=None), "role": Value(dtype="string", id=None)}], 31 | "instruction": {"completion": Value(dtype="string", id=None), "prompt": Value(dtype="string", id=None)}, 32 | } 33 | -------------------------------------------------------------------------------- /trl/trainer/prm_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import warnings 16 | 17 | from ..import_utils import suppress_experimental_warning 18 | 19 | 20 | with suppress_experimental_warning(): 21 | from ..experimental.prm import PRMConfig as _PRMConfig 22 | 23 | 24 | class PRMConfig(_PRMConfig): 25 | def __post_init__(self): 26 | warnings.warn( 27 | "The `PRMConfig` is now located in `trl.experimental`. Please update your imports to " 28 | "`from trl.experimental.xco import PRMConfig`. The current import path will be removed and no longer " 29 | "supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223.", 30 | FutureWarning, 31 | stacklevel=2, 32 | ) 33 | super().__post_init__() 34 | -------------------------------------------------------------------------------- /trl/trainer/xpo_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import warnings 16 | 17 | from ..import_utils import suppress_experimental_warning 18 | 19 | 20 | with suppress_experimental_warning(): 21 | from ..experimental.xpo import XPOConfig as _XPOConfig 22 | 23 | 24 | class XPOConfig(_XPOConfig): 25 | def __post_init__(self): 26 | warnings.warn( 27 | "The `XPOConfig` is now located in `trl.experimental`. Please update your imports to " 28 | "`from trl.experimental.xco import XPOConfig`. The current import path will be removed and no longer " 29 | "supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223.", 30 | FutureWarning, 31 | stacklevel=2, 32 | ) 33 | super().__post_init__() 34 | -------------------------------------------------------------------------------- /trl/trainer/prm_trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import warnings 16 | 17 | from ..import_utils import suppress_experimental_warning 18 | 19 | 20 | with suppress_experimental_warning(): 21 | from ..experimental.prm import PRMTrainer as _PRMTrainer 22 | 23 | 24 | class PRMTrainer(_PRMTrainer): 25 | def __init__(self, *args, **kwargs): 26 | warnings.warn( 27 | "The `PRMTrainer` is now located in `trl.experimental`. 
Please update your imports to " 28 | "`from trl.experimental.prm import PRMTrainer`. The current import path will be removed and no longer " 29 | "supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223.", 30 | FutureWarning, 31 | stacklevel=2, 32 | ) 33 | super().__init__(*args, **kwargs) 34 | -------------------------------------------------------------------------------- /trl/trainer/xpo_trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import warnings 16 | 17 | from ..import_utils import suppress_experimental_warning 18 | 19 | 20 | with suppress_experimental_warning(): 21 | from ..experimental.xpo import XPOTrainer as _XPOTrainer 22 | 23 | 24 | class XPOTrainer(_XPOTrainer): 25 | def __init__(self, *args, **kwargs): 26 | warnings.warn( 27 | "The `XPOTrainer` is now located in `trl.experimental`. Please update your imports to " 28 | "`from trl.experimental.xpo import XPOTrainer`. The current import path will be removed and no longer " 29 | "supported in TRL 0.29. 
For more information, see https://github.com/huggingface/trl/issues/4223.", 30 | FutureWarning, 31 | stacklevel=2, 32 | ) 33 | super().__init__(*args, **kwargs) 34 | -------------------------------------------------------------------------------- /trl/experimental/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Experimental submodule for TRL. 17 | 18 | This submodule contains unstable or incubating features. Anything here may change (or be removed) in any release 19 | without deprecation. Use at your own risk. 20 | 21 | To silence this notice set environment variable TRL_EXPERIMENTAL_SILENCE=1. 22 | """ 23 | 24 | import os 25 | import warnings 26 | 27 | from ..import_utils import TRLExperimentalWarning 28 | 29 | 30 | if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"): 31 | warnings.warn( 32 | "You are importing from 'trl.experimental'. APIs here are unstable and may change or be removed without " 33 | "notice. 
Silence this warning by setting environment variable TRL_EXPERIMENTAL_SILENCE=1.", 34 | TRLExperimentalWarning, 35 | stacklevel=2, 36 | ) 37 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | # What does this PR do? 2 | 3 | 12 | 13 | 14 | 15 | Fixes # (issue) 16 | 17 | 18 | ## Before submitting 19 | - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). 20 | - [ ] Did you read the [contributor guideline](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md#create-a-pull-request), 21 | Pull Request section? 22 | - [ ] Was this discussed/approved via a GitHub issue? Please add a link 23 | to it if that's the case. 24 | - [ ] Did you make sure to update the documentation with your changes? 25 | - [ ] Did you write any new necessary tests? 26 | 27 | 28 | ## Who can review? 29 | 30 | Anyone in the community is free to review the PR once the tests have passed. Feel free to tag 31 | members/contributors who may be interested in your PR. -------------------------------------------------------------------------------- /trl/experimental/gfpo/gfpo_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataclasses import dataclass, field 16 | 17 | from ...trainer.grpo_config import GRPOConfig as _GRPOConfig 18 | 19 | 20 | @dataclass 21 | class GFPOConfig(_GRPOConfig): 22 | num_remains_in_group: int | None = field( 23 | default=None, 24 | metadata={ 25 | "help": "number inputs remains after group filter function, `'num_remains_in_group'` must be >=2 if given." 26 | }, 27 | ) 28 | 29 | def __post_init__(self): 30 | super().__post_init__() 31 | 32 | if self.num_remains_in_group is not None and self.num_remains_in_group >= self.num_generations: 33 | raise ValueError( 34 | f"Number remains in Group {self.num_remains_in_group} must be less than num_generations : {self.num_generations}." 35 | ) 36 | -------------------------------------------------------------------------------- /trl/trainer/bco_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import warnings 16 | from dataclasses import dataclass 17 | 18 | from ..import_utils import suppress_experimental_warning 19 | 20 | 21 | with suppress_experimental_warning(): 22 | from ..experimental.bco import BCOConfig as _BCOConfig 23 | 24 | 25 | @dataclass 26 | class BCOConfig(_BCOConfig): 27 | def __post_init__(self): 28 | warnings.warn( 29 | "The `BCOConfig` is now located in `trl.experimental`. Please update your imports to " 30 | "`from trl.experimental.bco import BCOConfig`. The current import path will be removed and no longer " 31 | "supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223.", 32 | FutureWarning, 33 | stacklevel=2, 34 | ) 35 | super().__post_init__() 36 | -------------------------------------------------------------------------------- /trl/trainer/cpo_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import warnings 16 | from dataclasses import dataclass 17 | 18 | from ..import_utils import suppress_experimental_warning 19 | 20 | 21 | with suppress_experimental_warning(): 22 | from ..experimental.cpo import CPOConfig as _CPOConfig 23 | 24 | 25 | @dataclass 26 | class CPOConfig(_CPOConfig): 27 | def __post_init__(self): 28 | warnings.warn( 29 | "The `CPOConfig` is now located in `trl.experimental`. Please update your imports to " 30 | "`from trl.experimental.cpo import CPOConfig`. The current import path will be removed and no longer " 31 | "supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223.", 32 | FutureWarning, 33 | stacklevel=2, 34 | ) 35 | super().__post_init__() 36 | -------------------------------------------------------------------------------- /trl/trainer/gkd_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import warnings 16 | from dataclasses import dataclass 17 | 18 | from ..import_utils import suppress_experimental_warning 19 | 20 | 21 | with suppress_experimental_warning(): 22 | from ..experimental.gkd import GKDConfig as _GKDConfig 23 | 24 | 25 | @dataclass 26 | class GKDConfig(_GKDConfig): 27 | def __post_init__(self): 28 | warnings.warn( 29 | "The `GKDConfig` is now located in `trl.experimental`. Please update your imports to " 30 | "`from trl.experimental.gkd import GKDConfig`. The current import path will be removed and no longer " 31 | "supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223.", 32 | FutureWarning, 33 | stacklevel=2, 34 | ) 35 | super().__post_init__() 36 | -------------------------------------------------------------------------------- /trl/trainer/ppo_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import warnings 16 | from dataclasses import dataclass 17 | 18 | from ..import_utils import suppress_experimental_warning 19 | 20 | 21 | with suppress_experimental_warning(): 22 | from ..experimental.ppo import PPOConfig as _PPOConfig 23 | 24 | 25 | @dataclass 26 | class PPOConfig(_PPOConfig): 27 | def __post_init__(self): 28 | warnings.warn( 29 | "The `PPOConfig` is now located in `trl.experimental`. Please update your imports to " 30 | "`from trl.experimental.ppo import PPOConfig`. The current import path will be removed and no longer " 31 | "supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223.", 32 | FutureWarning, 33 | stacklevel=2, 34 | ) 35 | super().__post_init__() 36 | -------------------------------------------------------------------------------- /trl/trainer/orpo_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import warnings 16 | from dataclasses import dataclass 17 | 18 | from ..import_utils import suppress_experimental_warning 19 | 20 | 21 | with suppress_experimental_warning(): 22 | from ..experimental.orpo import ORPOConfig as _ORPOConfig 23 | 24 | 25 | @dataclass 26 | class ORPOConfig(_ORPOConfig): 27 | def __post_init__(self): 28 | warnings.warn( 29 | "The `ORPOConfig` is now located in `trl.experimental`. Please update your imports to " 30 | "`from trl.experimental.orpo import ORPOConfig`. The current import path will be removed and no longer " 31 | "supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223.", 32 | FutureWarning, 33 | stacklevel=2, 34 | ) 35 | super().__post_init__() 36 | -------------------------------------------------------------------------------- /trl/trainer/bco_trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import warnings 16 | from dataclasses import dataclass 17 | 18 | from ..import_utils import suppress_experimental_warning 19 | 20 | 21 | with suppress_experimental_warning(): 22 | from ..experimental.bco import BCOTrainer as _BCOTrainer 23 | 24 | 25 | @dataclass 26 | class BCOTrainer(_BCOTrainer): 27 | def __init__(self, *args, **kwargs): 28 | warnings.warn( 29 | "The `BCOTrainer` is now located in `trl.experimental`. Please update your imports to " 30 | "`from trl.experimental.bco import BCOTrainer`. The current import path will be removed and no longer " 31 | "supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223.", 32 | FutureWarning, 33 | stacklevel=2, 34 | ) 35 | super().__init__(*args, **kwargs) 36 | -------------------------------------------------------------------------------- /trl/trainer/cpo_trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import warnings 16 | from dataclasses import dataclass 17 | 18 | from ..import_utils import suppress_experimental_warning 19 | 20 | 21 | with suppress_experimental_warning(): 22 | from ..experimental.cpo import CPOTrainer as _CPOTrainer 23 | 24 | 25 | @dataclass 26 | class CPOTrainer(_CPOTrainer): 27 | def __init__(self, *args, **kwargs): 28 | warnings.warn( 29 | "The `CPOTrainer` is now located in `trl.experimental`. Please update your imports to " 30 | "`from trl.experimental.cpo import CPOTrainer`. The current import path will be removed and no longer " 31 | "supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223.", 32 | FutureWarning, 33 | stacklevel=2, 34 | ) 35 | super().__init__(*args, **kwargs) 36 | -------------------------------------------------------------------------------- /trl/trainer/gkd_trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import warnings 16 | from dataclasses import dataclass 17 | 18 | from ..import_utils import suppress_experimental_warning 19 | 20 | 21 | with suppress_experimental_warning(): 22 | from ..experimental.gkd import GKDTrainer as _GKDTrainer 23 | 24 | 25 | @dataclass 26 | class GKDTrainer(_GKDTrainer): 27 | def __init__(self, *args, **kwargs): 28 | warnings.warn( 29 | "The `GKDTrainer` is now located in `trl.experimental`. Please update your imports to " 30 | "`from trl.experimental.gkd import GKDTrainer`. The current import path will be removed and no longer " 31 | "supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223.", 32 | FutureWarning, 33 | stacklevel=2, 34 | ) 35 | super().__init__(*args, **kwargs) 36 | -------------------------------------------------------------------------------- /trl/trainer/nash_md_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
import warnings
from dataclasses import dataclass

from ..import_utils import suppress_experimental_warning


with suppress_experimental_warning():
    from ..experimental.nash_md import NashMDConfig as _NashMDConfig


@dataclass
class NashMDConfig(_NashMDConfig):
    """Deprecated import-path shim for `trl.experimental.nash_md.NashMDConfig`."""

    def __post_init__(self):
        # Warn on construction, then defer to the experimental config's own validation.
        deprecation_message = (
            "The `NashMDConfig` is now located in `trl.experimental`. Please update your imports to "
            "`from trl.experimental.nash_md import NashMDConfig`. The current import path will be removed and no "
            "longer supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223."
        )
        warnings.warn(deprecation_message, category=FutureWarning, stacklevel=2)
        super().__post_init__()
import warnings

from ..import_utils import suppress_experimental_warning


with suppress_experimental_warning():
    from ..experimental.ppo import PPOTrainer as _PPOTrainer


class PPOTrainer(_PPOTrainer):
    """Deprecated import-path shim for `trl.experimental.ppo.PPOTrainer`.

    Emits a `FutureWarning` on instantiation; this alias will be removed in TRL 0.29.
    """

    # NOTE: no `@dataclass` here — the trainer is an ordinary class. Decorating it with
    # `@dataclass` would synthesize a field-less `__eq__` (every instance would compare
    # equal) and set `__hash__` to None, making instances unhashable.
    def __init__(self, *args, **kwargs):
        warnings.warn(
            "The `PPOTrainer` is now located in `trl.experimental`. Please update your imports to "
            "`from trl.experimental.ppo import PPOTrainer`. The current import path will be removed and no longer "
            "supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223.",
            FutureWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)
import warnings

from ..import_utils import suppress_experimental_warning


with suppress_experimental_warning():
    from ..experimental.orpo import ORPOTrainer as _ORPOTrainer


class ORPOTrainer(_ORPOTrainer):
    """Deprecated import-path shim for `trl.experimental.orpo.ORPOTrainer`.

    Emits a `FutureWarning` on instantiation; this alias will be removed in TRL 0.29.
    """

    # NOTE: no `@dataclass` here — the trainer is an ordinary class. Decorating it with
    # `@dataclass` would synthesize a field-less `__eq__` (every instance would compare
    # equal) and set `__hash__` to None, making instances unhashable.
    def __init__(self, *args, **kwargs):
        warnings.warn(
            "The `ORPOTrainer` is now located in `trl.experimental`. Please update your imports to "
            "`from trl.experimental.orpo import ORPOTrainer`. The current import path will be removed and no longer "
            "supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223.",
            FutureWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)
import warnings

from ..import_utils import suppress_experimental_warning


with suppress_experimental_warning():
    from ..experimental.nash_md import NashMDTrainer as _NashMDTrainer


class NashMDTrainer(_NashMDTrainer):
    """Deprecated import-path shim for `trl.experimental.nash_md.NashMDTrainer`.

    Emits a `FutureWarning` on instantiation; this alias will be removed in TRL 0.29.
    """

    # NOTE: no `@dataclass` here — the trainer is an ordinary class. Decorating it with
    # `@dataclass` would synthesize a field-less `__eq__` (every instance would compare
    # equal) and set `__hash__` to None, making instances unhashable.
    def __init__(self, *args, **kwargs):
        warnings.warn(
            "The `NashMDTrainer` is now located in `trl.experimental`. Please update your imports to "
            "`from trl.experimental.nash_md import NashMDTrainer`. The current import path will be removed and no "
            "longer supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223.",
            FutureWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)
import warnings
from dataclasses import dataclass

from ..import_utils import suppress_experimental_warning


with suppress_experimental_warning():
    from ..experimental.online_dpo import OnlineDPOConfig as _OnlineDPOConfig


@dataclass
class OnlineDPOConfig(_OnlineDPOConfig):
    """Deprecated import-path shim for `trl.experimental.online_dpo.OnlineDPOConfig`."""

    def __post_init__(self):
        # Warn on construction, then defer to the experimental config's own validation.
        deprecation_message = (
            "The `OnlineDPOConfig` is now located in `trl.experimental`. Please update your imports to "
            "`from trl.experimental.online_dpo import OnlineDPOConfig`. The current import path will be removed and "
            "no longer supported in TRL 0.29. For more information, see "
            "https://github.com/huggingface/trl/issues/4223."
        )
        warnings.warn(deprecation_message, category=FutureWarning, stacklevel=2)
        super().__post_init__()
from dataclasses import dataclass, field

from ...trainer.grpo_config import GRPOConfig


@dataclass
class GRPOWithReplayBufferConfig(GRPOConfig):
    """
    Configuration class for `GRPOWithReplayBufferTrainer`, extending [`GRPOConfig`].

    New Parameters:
        replay_buffer_size (`int`, *optional*, defaults to `64`):
            A cache that stores the rollouts with the highest advantage scores and variance per group. If a new
            group has 0 variance, it is replaced with a group sampled from the replay buffer.
    """

    # Docstring default above is kept in sync with `field(default=64)` below; the
    # previous docstring incorrectly advertised a default of 0.
    replay_buffer_size: int = field(
        default=64,
        metadata={
            "help": "A cache that stores the rollouts with the highest advantage scores and variance per group. If a new group has 0 variance, it is replaced with a group sampled from the replay buffer."
        },
    )
import warnings
from dataclasses import dataclass

from ..import_utils import suppress_experimental_warning


with suppress_experimental_warning():
    from ..experimental.kto import KTOConfig as _KTOConfig


@dataclass
class KTOConfig(_KTOConfig):
    """Deprecated import-path shim for `trl.experimental.kto.KTOConfig`."""

    def __post_init__(self):
        # Warn on construction, then defer to the experimental config's own validation.
        deprecation_message = (
            "The `KTOConfig` is now located in `trl.experimental`. Please update your imports to "
            "`from trl.experimental.kto import KTOConfig`. For more information, see "
            "https://github.com/huggingface/trl/issues/4223. Promoting KTO to the stable API is a high-priority task. "
            "Until then, this current path (`from trl import KTOConfig`) will remain, but API changes may occur."
        )
        warnings.warn(deprecation_message, category=FutureWarning, stacklevel=2)
        super().__post_init__()
import warnings

from ..import_utils import suppress_experimental_warning


with suppress_experimental_warning():
    from ..experimental.online_dpo import OnlineDPOTrainer as _OnlineDPOTrainer


class OnlineDPOTrainer(_OnlineDPOTrainer):
    """Deprecated import-path shim for `trl.experimental.online_dpo.OnlineDPOTrainer`.

    Emits a `FutureWarning` on instantiation; this alias will be removed in TRL 0.29.
    """

    # NOTE: no `@dataclass` here — the trainer is an ordinary class. Decorating it with
    # `@dataclass` would synthesize a field-less `__eq__` (every instance would compare
    # equal) and set `__hash__` to None, making instances unhashable.
    def __init__(self, *args, **kwargs):
        warnings.warn(
            "The `OnlineDPOTrainer` is now located in `trl.experimental`. Please update your imports to "
            "`from trl.experimental.online_dpo import OnlineDPOTrainer`. The current import path will be removed and "
            "no longer supported in TRL 0.29. For more information, see "
            "https://github.com/huggingface/trl/issues/4223.",
            FutureWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)
import warnings

from ..import_utils import suppress_experimental_warning


with suppress_experimental_warning():
    from ..experimental.kto import KTOTrainer as _KTOTrainer


class KTOTrainer(_KTOTrainer):
    """Deprecated import-path shim for `trl.experimental.kto.KTOTrainer`.

    Emits a `FutureWarning` on instantiation; kept until KTO is promoted to the stable API.
    """

    # NOTE: no `@dataclass` here — the trainer is an ordinary class. Decorating it with
    # `@dataclass` would synthesize a field-less `__eq__` (every instance would compare
    # equal) and set `__hash__` to None, making instances unhashable.
    def __init__(self, *args, **kwargs):
        warnings.warn(
            "The `KTOTrainer` is now located in `trl.experimental`. Please update your imports to "
            "`from trl.experimental.kto import KTOTrainer`. For more information, see "
            "https://github.com/huggingface/trl/issues/4223. Promoting KTO to the stable API is a high-priority task. "
            "Until then, this current path (`from trl import KTOTrainer`) will remain, but API changes may occur.",
            FutureWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)
from transformers import AutoModelForCausalLM

from trl.models.utils import disable_gradient_checkpointing


class TestDisableGradientCheckpointing:
    """Tests for the `disable_gradient_checkpointing` context manager."""

    def test_when_disabled(self):
        # Checkpointing starts off: the context manager must be a no-op both inside
        # the `with` block and after exiting it.
        model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        assert model.is_gradient_checkpointing is False
        with disable_gradient_checkpointing(model):
            assert model.is_gradient_checkpointing is False
        assert model.is_gradient_checkpointing is False

    def test_when_enabled(self):
        # Checkpointing starts on: the context manager must switch it off inside the
        # `with` block and restore the enabled state on exit.
        model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        model.gradient_checkpointing_enable()
        assert model.is_gradient_checkpointing is True
        with disable_gradient_checkpointing(model):
            assert model.is_gradient_checkpointing is False
        assert model.is_gradient_checkpointing is True
No modifications to your training script are required. Simply run it with the DeepSpeed configuration file:

```bash
accelerate launch --config_file <path/to/deepspeed_config.yaml> train.py
```
from typing import TYPE_CHECKING

from ..import_utils import _LazyModule


# Maps submodule name -> list of public names it provides. `_LazyModule` uses this to
# defer the actual imports until an attribute is first accessed, keeping
# `import trl.models` cheap.
_import_structure = {
    "activation_offloading": ["get_act_offloading_ctx_manager"],
    "modeling_base": ["PreTrainedModelWrapper"],
    "modeling_value_head": ["AutoModelForCausalLMWithValueHead", "AutoModelForSeq2SeqLMWithValueHead"],
    "utils": ["create_reference_model", "prepare_deepspeed", "prepare_fsdp", "unwrap_model_for_generation"],
}


if TYPE_CHECKING:
    # Real imports are only visible to static type checkers; at runtime the lazy
    # module below resolves them on demand.
    from .activation_offloading import get_act_offloading_ctx_manager
    from .modeling_base import PreTrainedModelWrapper
    from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
    from .utils import create_reference_model, prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation
else:
    import sys

    # Replace this module object with a lazy proxy that imports submodules on first
    # attribute access.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
2. **Experimental inclusion:** Once it’s ready for early users, move the idea into `trl.experimental.<feature_name>`.
**Promote:** Once the API proves stable and there is clear interest or adoption from the community, move it into `trl.<feature_name>` (stable module).
import warnings

from ..import_utils import suppress_experimental_warning


with suppress_experimental_warning():
    from ..experimental.ppo.modeling_value_head import PreTrainedModelWrapper as _PreTrainedModelWrapper


LAYER_PATTERNS = [
    "transformer.h.{layer}",
    "model.decoder.layers.{layer}",
    "gpt_neox.layers.{layer}",
    "model.layers.{layer}",
]


class PreTrainedModelWrapper(_PreTrainedModelWrapper):
    """Deprecated import-path shim for `trl.experimental.ppo.modeling_value_head.PreTrainedModelWrapper`.

    Emits a `FutureWarning` on instantiation; this alias will be removed in TRL 0.29.
    """

    def __init__(self, *args, **kwargs):
        # Fix: the warning previously pointed users at `trl.experimental.bco`, which is
        # not where the class lives — the import above is from
        # `trl.experimental.ppo.modeling_value_head`.
        warnings.warn(
            "The `PreTrainedModelWrapper` is now located in `trl.experimental`. Please update your imports to "
            "`from trl.experimental.ppo.modeling_value_head import PreTrainedModelWrapper`. The current import path "
            "will be removed and no longer supported in TRL 0.29. For more information, see "
            "https://github.com/huggingface/trl/issues/4223.",
            FutureWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)
11 | 12 | ```python 13 | # train_gfpo.py 14 | from trl.experimental.gfpo import GFPOConfig, GFPOTrainer 15 | 16 | # dummy group filter to scores the completions based on its indice in group 17 | class GroupFilter: 18 | def __call__(self, group_completions, group_rewards, **kwargs): 19 | group_scores = [] 20 | for completions, rewards in zip(group_completions, group_rewards): 21 | scores = [float(i) for i in range(len(completions))] 22 | group_scores.append(scores) 23 | return group_scores 24 | 25 | training_args = GFPOConfig( 26 | output_dir="Qwen3-0.6B-GFPO", 27 | per_device_train_batch_size=4, 28 | num_remains_in_group=2, 29 | bf16=True, 30 | ) 31 | trainer = GFPOTrainer( 32 | model="Qwen/Qwen3-0.6B", 33 | reward_funcs=..., 34 | train_dataset=..., 35 | args=training_args, 36 | group_filter_func=GroupFilter(), 37 | ) 38 | trainer.train() 39 | ``` 40 | 41 | ## GFPOTrainer 42 | 43 | [[autodoc]] experimental.gfpo.GFPOTrainer 44 | - train 45 | - save_model 46 | - push_to_hub 47 | 48 | ## GFPOConfig 49 | 50 | [[autodoc]] experimental.gfpo.GFPOConfig 51 | -------------------------------------------------------------------------------- /trl/experimental/xpo/xpo_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
from dataclasses import dataclass, field

from ..online_dpo import OnlineDPOConfig


@dataclass
class XPOConfig(OnlineDPOConfig):
    r"""
    Configuration class for the [`experimental.xpo.XPOTrainer`].

    Inherits every argument of [`experimental.online_dpo.OnlineDPOConfig`] and adds the following:

    Parameters:
        alpha (`float` or `list[float]`, *optional*, defaults to `1e-5`):
            Weight of the XPO loss term. When a list is given, one alpha is selected per new epoch and the last
            alpha is reused for any remaining epochs.
    """

    alpha: list[float] = field(
        default_factory=lambda: [1e-5],
        metadata={
            "help": "Weight of the XPO loss term. If a list of floats is provided then the alpha is selected for each "
            "new epoch and the last alpha is used for the rest of the epochs."
        },
    )

    def __post_init__(self):
        super().__post_init__()
        alpha = self.alpha
        # A single-element list is unwrapped to its scalar value for convenience.
        if hasattr(alpha, "__len__") and len(alpha) == 1:
            self.alpha = alpha[0]
4 | 5 | ## Usage 6 | 7 | ```python 8 | import torch 9 | from trl.experimental.grpo_with_replay_buffer import GRPOWithReplayBufferConfig, GRPOWithReplayBufferTrainer 10 | from datasets import load_dataset 11 | 12 | dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train") 13 | 14 | # Guarantee that some rewards have 0 std 15 | def custom_reward_func(completions, **kwargs): 16 | if torch.rand(1).item() < 0.25: 17 | return [0] * len(completions) # simulate some None rewards 18 | else: 19 | return torch.rand(len(completions)).tolist() 20 | 21 | training_args = GRPOWithReplayBufferConfig( 22 | output_dir="./tmp", 23 | learning_rate=1e-4, 24 | per_device_train_batch_size=4, 25 | num_generations=4, 26 | max_completion_length=8, 27 | replay_buffer_size=8, 28 | report_to="none", 29 | ) 30 | 31 | trainer = GRPOWithReplayBufferTrainer( 32 | model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5", 33 | reward_funcs=[custom_reward_func], 34 | args=training_args, 35 | train_dataset=dataset, 36 | ) 37 | 38 | previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()} 39 | 40 | trainer.train() 41 | ``` 42 | 43 | ## GRPOWithReplayBufferTrainer 44 | 45 | [[autodoc]] experimental.grpo_with_replay_buffer.GRPOWithReplayBufferTrainer 46 | - train 47 | - save_model 48 | - push_to_hub 49 | 50 | ## GRPOWithReplayBufferConfig 51 | 52 | [[autodoc]] experimental.grpo_with_replay_buffer.GRPOWithReplayBufferConfig 53 | 54 | ## ReplayBuffer 55 | 56 | [[autodoc]] experimental.grpo_with_replay_buffer.ReplayBuffer 57 | -------------------------------------------------------------------------------- /.github/workflows/tests-experimental.yml: -------------------------------------------------------------------------------- 1 | name: Tests (experimental) 2 | 3 | on: 4 | pull_request: 5 | paths: 6 | # Run only when relevant files are modified 7 | - "trl/experimental/**" 8 | - "tests/experimental/**" 9 | 10 | env: 11 | TQDM_DISABLE: 1 12 
| PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True" 13 | TRL_EXPERIMENTAL_SILENCE: 1 14 | 15 | jobs: 16 | check_code_quality: 17 | name: Check code quality 18 | runs-on: ubuntu-latest 19 | if: github.event.pull_request.draft == false 20 | steps: 21 | - uses: actions/checkout@v4 22 | - name: Set up Python 3.13 23 | uses: actions/setup-python@v5 24 | with: 25 | python-version: 3.13 26 | - uses: pre-commit/action@v3.0.1 27 | with: 28 | extra_args: --all-files 29 | 30 | tests: 31 | name: Tests (experimental) 32 | runs-on: 33 | group: aws-g4dn-2xlarge 34 | container: 35 | image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel 36 | options: --gpus all 37 | defaults: 38 | run: 39 | shell: bash 40 | steps: 41 | - name: Git checkout 42 | uses: actions/checkout@v4 43 | 44 | - name: Set up Python 3.13 45 | uses: actions/setup-python@v5 46 | with: 47 | python-version: 3.13 48 | 49 | - name: Install Make and Git 50 | run: | 51 | apt-get update && apt-get install -y make git curl 52 | 53 | - name: Install uv 54 | run: | 55 | curl -LsSf https://astral.sh/uv/install.sh | sh 56 | 57 | - name: Create Python virtual environment 58 | run: | 59 | uv venv 60 | uv pip install --upgrade setuptools wheel 61 | 62 | - name: Install dependencies 63 | run: | 64 | source .venv/bin/activate 65 | uv pip install ".[dev]" 66 | 67 | - name: Test with pytest 68 | run: | 69 | source .venv/bin/activate 70 | make test_experimental 71 | -------------------------------------------------------------------------------- /trl/experimental/nash_md/nash_md_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
@dataclass
class NashMDConfig(OnlineDPOConfig):
    r"""
    Configuration class for the [`experimental.nash_md.NashMDTrainer`].

    Subclass of [`experimental.online_dpo.OnlineDPOConfig`], so all of its arguments are available here as well, in
    addition to the following:

    Parameters:
        mixture_coef (`float` or `list[float]`, *optional*, defaults to `0.5`):
            Logit mixture coefficient for the model and reference model. If a list of floats is provided then the
            mixture coefficient is selected for each new epoch and the last coefficient is used for the rest of the
            epochs.
    """

    mixture_coef: list[float] = field(
        default_factory=lambda: [0.5],
        metadata={
            "help": "Logit mixture coefficient for the model and reference model. If a list of floats is provided "
            "then the mixture coefficient is selected for each new epoch and the last coefficient is used for the "
            "rest of the epochs."
        },
    )

    def __post_init__(self):
        super().__post_init__()
        # Collapse a single-element list to a plain float so downstream code can treat a scalar
        # `mixture_coef` uniformly; longer lists are kept as per-epoch schedules.
        if hasattr(self.mixture_coef, "__len__") and len(self.mixture_coef) == 1:
            self.mixture_coef = self.mixture_coef[0]
27 | 28 | ### Framework versions 29 | 30 | - TRL: {{ trl_version }} 31 | - Transformers: {{ transformers_version }} 32 | - Pytorch: {{ pytorch_version }} 33 | - Datasets: {{ datasets_version }} 34 | - Tokenizers: {{ tokenizers_version }} 35 | 36 | ## Citations 37 | 38 | {% if trainer_citation %}Cite {{ trainer_name }} as: 39 | 40 | ```bibtex 41 | {{ trainer_citation }} 42 | ```{% endif %} 43 | 44 | Cite TRL as: 45 | 46 | ```bibtex 47 | {% raw %}@misc{vonwerra2022trl, 48 | title = {{TRL: Transformer Reinforcement Learning}}, 49 | author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallou{\'e}dec}, 50 | year = 2020, 51 | journal = {GitHub repository}, 52 | publisher = {GitHub}, 53 | howpublished = {\url{https://github.com/huggingface/trl}} 54 | }{% endraw %} 55 | ``` 56 | -------------------------------------------------------------------------------- /.github/workflows/tests_latest.yml: -------------------------------------------------------------------------------- 1 | name: Tests latest TRL release with dev dependencies 2 | 3 | on: 4 | schedule: 5 | - cron: '0 0 * * *' # Runs daily at midnight UTC 6 | 7 | workflow_dispatch: 8 | 9 | env: 10 | TQDM_DISABLE: 1 11 | CI_SLACK_CHANNEL: ${{ secrets.CI_PUSH_MAIN_CHANNEL }} 12 | TRL_EXPERIMENTAL_SILENCE: 1 13 | 14 | jobs: 15 | tests: 16 | name: Tests latest TRL release with dev dependencies 17 | runs-on: 18 | group: aws-g4dn-2xlarge 19 | container: 20 | image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel 21 | options: --gpus all 22 | defaults: 23 | run: 24 | shell: bash 25 | steps: 26 | - name: Git checkout 27 | uses: actions/checkout@v4 28 | with: { ref: v0.26-release } 29 | 30 | - name: Set up Python 3.12 31 | uses: actions/setup-python@v5 32 | with: 33 | python-version: '3.12' 34 | 35 | - name: Install Make and Git 36 | run: | 37 | apt-get update && apt-get install -y make git curl 38 | 39 | - name: 
Install uv 40 | run: | 41 | curl -LsSf https://astral.sh/uv/install.sh | sh 42 | 43 | - name: Create Python virtual environment 44 | run: | 45 | uv venv 46 | uv pip install --upgrade setuptools wheel 47 | 48 | - name: Install dependencies 49 | run: | 50 | source .venv/bin/activate 51 | uv pip install ".[dev]" 52 | uv pip install -U git+https://github.com/huggingface/accelerate.git 53 | uv pip install -U git+https://github.com/huggingface/datasets.git 54 | uv pip install -U git+https://github.com/huggingface/transformers.git 55 | 56 | - name: Test with pytest 57 | run: | 58 | source .venv/bin/activate 59 | make test 60 | 61 | - name: Post to Slack 62 | uses: huggingface/hf-workflows/.github/actions/post-slack@main 63 | with: 64 | slack_channel: ${{ env.CI_SLACK_CHANNEL }} 65 | title: Results of latest TRL with Python 3.12 and dev dependencies 66 | status: ${{ job.status }} 67 | slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }} 68 | -------------------------------------------------------------------------------- /tests/test_rich_progress_callback.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
class DummyModel(nn.Module):
    """Trivial module with a single learnable scalar; forward scales the input by it."""

    def __init__(self):
        super().__init__()
        # One parameter is enough to give the Trainer something to optimize.
        self.a = nn.Parameter(torch.tensor(1.0))

    def forward(self, x):
        scaled = x * self.a
        return scaled
8 | It has been trained using [TRL](https://github.com/huggingface/trl). 9 | 10 | ## Quick start 11 | 12 | ```python 13 | from transformers import pipeline 14 | 15 | question = "If you had a time machine, but could only go to the past or the future once and never return, which would you choose and why?" 16 | generator = pipeline("text-generation", model="{{ hub_model_id }}", device="cuda") 17 | output = generator([{"role": "user", "content": question}], max_new_tokens=128, return_full_text=False)[0] 18 | print(output["generated_text"]) 19 | ``` 20 | 21 | ## Training procedure 22 | 23 | {% if wandb_url %}[Visualize in Weights & Biases]({{ wandb_url }}){% endif %} 24 | {% if comet_url %}[Visualize in Comet]({{ comet_url }}){% endif %} 25 | 26 | This model was trained with {{ trainer_name }}{% if paper_id %}, a method introduced in [{{ paper_title }}](https://huggingface.co/papers/{{ paper_id }}){% endif %}. 27 | 28 | ### Framework versions 29 | 30 | - TRL: {{ trl_version }} 31 | - Transformers: {{ transformers_version }} 32 | - Pytorch: {{ pytorch_version }} 33 | - Datasets: {{ datasets_version }} 34 | - Tokenizers: {{ tokenizers_version }} 35 | 36 | ## Citations 37 | 38 | {% if trainer_citation %}Cite {{ trainer_name }} as: 39 | 40 | ```bibtex 41 | {{ trainer_citation }} 42 | ```{% endif %} 43 | 44 | Cite TRL as: 45 | 46 | ```bibtex 47 | {% raw %}@misc{vonwerra2022trl, 48 | title = {{TRL: Transformer Reinforcement Learning}}, 49 | author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallou{\'e}dec}, 50 | year = 2020, 51 | journal = {GitHub repository}, 52 | publisher = {GitHub}, 53 | howpublished = {\url{https://github.com/huggingface/trl}} 54 | }{% endraw %} 55 | ``` 56 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 
*.bak 2 | .gitattributes 3 | .last_checked 4 | .gitconfig 5 | *.bak 6 | *.log 7 | *~ 8 | ~* 9 | _tmp* 10 | tmp* 11 | tags 12 | 13 | # Byte-compiled / optimized / DLL files 14 | __pycache__/ 15 | *.py[cod] 16 | *$py.class 17 | 18 | # C extensions 19 | *.so 20 | 21 | # Distribution / packaging 22 | .Python 23 | env/ 24 | build/ 25 | develop-eggs/ 26 | dist/ 27 | downloads/ 28 | eggs/ 29 | .eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | .hypothesis/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # celery beat schedule file 89 | celerybeat-schedule 90 | 91 | # SageMath parsed files 92 | *.sage.py 93 | 94 | # dotenv 95 | .env 96 | 97 | # virtualenv 98 | .venv 99 | venv/ 100 | ENV/ 101 | 102 | # Spyder project settings 103 | .spyderproject 104 | .spyproject 105 | 106 | # Rope project settings 107 | .ropeproject 108 | 109 | # mkdocs documentation 110 | /site 111 | 112 | # mypy 113 | .mypy_cache/ 114 | 115 | .vscode 116 | *.swp 117 | 118 | # osx generated files 119 | .DS_Store 120 | .DS_Store? 
121 | .Trashes 122 | ehthumbs.db 123 | Thumbs.db 124 | .idea 125 | 126 | # pytest 127 | .pytest_cache 128 | 129 | # tools/trust-doc-nbs 130 | docs_src/.last_checked 131 | 132 | # symlinks to fastai 133 | docs_src/fastai 134 | tools/fastai 135 | 136 | # link checker 137 | checklink/cookies.txt 138 | 139 | # .gitconfig is now autogenerated 140 | .gitconfig 141 | 142 | # wandb files 143 | nbs/wandb/ 144 | examples/notebooks/wandb/ 145 | wandb/ -------------------------------------------------------------------------------- /docs/source/trackio_integration.md: -------------------------------------------------------------------------------- 1 | # Trackio Integration 2 | 3 | [Trackio](https://huggingface.co/docs/trackio) is a lightweight, free experiment tracking library built on top of **🤗 Datasets** and **🤗 Spaces**. It is the **recommended tracking solution for TRL** and comes natively integrated with all trainers. 4 | 5 | To enable logging, simply set `report_to="trackio"` in your training config: 6 | 7 | ```python 8 | from trl import SFTConfig # works with any trainer config (e.g. DPOConfig, GRPOConfig, etc.) 9 | 10 | training_args = SFTConfig( 11 | ..., 12 | report_to="trackio", # enable Trackio logging 13 | ) 14 | ``` 15 | 16 | ## Organizing Your Experiments with Run Names and Projects 17 | 18 | By default, Trackio will generate a name to identify each run. However, we highly recommend setting a descriptive `run_name` to make it easier to organize experiments. 
For example: 19 | 20 | ```python 21 | from trl import SFTConfig 22 | 23 | training_args = SFTConfig( 24 | ..., 25 | report_to="trackio", 26 | run_name="sft_qwen3-4b_lr2e-5_bs128", # descriptive run name 27 | ) 28 | ``` 29 | 30 | You can also group related experiments by project by setting the following environment variable: 31 | 32 | ```bash 33 | export TRACKIO_PROJECT="my_project" 34 | ``` 35 | 36 | ## Hosting Your Logs on 🤗 Spaces 37 | 38 | Trackio has local-first design, meaning your logs stay on your machine. If you’d like to host them and deploy a dashboard on **🤗 Spaces**, set: 39 | 40 | ```bash 41 | export TRACKIO_SPACE_ID="username/space_id" 42 | ``` 43 | 44 | Running the following example: 45 | 46 | ```python 47 | import os 48 | from trl import SFTConfig, SFTTrainer 49 | from datasets import load_dataset 50 | 51 | os.environ["TRACKIO_SPACE_ID"] = "trl-lib/trackio" 52 | os.environ["TRACKIO_PROJECT"] = "trl-documentation" 53 | 54 | trainer = SFTTrainer( 55 | model="Qwen/Qwen3-0.6B", 56 | train_dataset=load_dataset("trl-lib/Capybara", split="train"), 57 | args=SFTConfig( 58 | report_to="trackio", 59 | run_name="sft_qwen3-0.6b_capybara", 60 | ), 61 | ) 62 | trainer.train() 63 | ``` 64 | 65 | will give you a hosted dashboard at https://huggingface.co/spaces/trl-lib/trackio. 66 | 67 | 68 | -------------------------------------------------------------------------------- /trl/rewards/format_rewards.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
def think_format_reward(completions: list[list[dict[str, str]]], **kwargs) -> list[float]:
    r"""
    Reward function that checks if the reasoning process is enclosed within `"<think>"` and `"</think>"` tags. The
    function returns a reward of 1.0 if the format is correct, otherwise 0.0.

    Args:
        completions (`list[list[dict[str, str]]]`):
            List of completions to be evaluated. Each completion must be a list of one message, i.e. a dictionary
            containing the key `"content"` with the value being the text of the completion.
        **kwargs:
            Additional keyword arguments. This function does not use them, but they are required in the function
            signature to ensure compatibility with trainers like [`GRPOTrainer`].

    Returns:
        `list[float]`:
            A list of rewards, where each reward is 1.0 if the completion matches the expected format, otherwise 0.0.

    Example:
    ```python
    >>> from trl.rewards import think_format_reward

    >>> completions = [
    ...     [{"content": "<think>\nThis is my reasoning.\n</think>\nThis is my answer."}],
    ...     [{"content": "<think>\nThis is my reasoning.\nThis is my answer."}],
    ... ]
    >>> think_format_reward(completions)
    [1.0, 0.0]
    ```
    """
    # The completion must open with "<think>", contain no second "<think>" tag, and close the reasoning
    # block with "</think>"; anything after the closing tag is allowed. DOTALL lets ".*" span newlines.
    pattern = r"^<think>(?!.*<think>)(.*?)</think>.*$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]
The current import path will be " 30 | "removed and no longer supported in TRL 0.29. For more information, see " 31 | "https://github.com/huggingface/trl/issues/4223.", 32 | FutureWarning, 33 | stacklevel=2, 34 | ) 35 | super().__init__(*args, **kwargs) 36 | 37 | 38 | class AutoModelForSeq2SeqLMWithValueHead(_AutoModelForSeq2SeqLMWithValueHead): 39 | def __init__(self, *args, **kwargs): 40 | warnings.warn( 41 | "The `AutoModelForSeq2SeqLMWithValueHead` is now located in `trl.experimental`. Please update your imports " 42 | "to `from trl.experimental.ppo import AutoModelForSeq2SeqLMWithValueHead`. The current import path will be " 43 | "removed and no longer supported in TRL 0.29. For more information, see " 44 | "https://github.com/huggingface/trl/issues/4223.", 45 | FutureWarning, 46 | stacklevel=2, 47 | ) 48 | super().__init__(*args, **kwargs) 49 | -------------------------------------------------------------------------------- /examples/scripts/sft_gemma3.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # /// script 16 | # dependencies = [ 17 | # "trl", 18 | # "Pillow", 19 | # "trackio", 20 | # "kernels", 21 | # ] 22 | # /// 23 | 24 | """ 25 | Train Gemma-3 on the Codeforces COTS dataset. 
def main():
    """Fine-tune Gemma-3 12B on Codeforces chain-of-thought traces with SFT."""
    # Dataset: only the completion messages are used; the raw prompt column is dropped.
    codeforces_cots = load_dataset("open-r1/codeforces-cots", split="train")
    codeforces_cots = codeforces_cots.remove_columns("prompt")

    # Model: eager attention is required for Gemma-3's image-text-to-text architecture here.
    model_id = "google/gemma-3-12b-it"
    model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation="eager")

    # Training configuration: Liger kernels + gradient checkpointing keep the memory footprint manageable.
    config = SFTConfig(
        output_dir=f"{model_id}-codeforces-SFT",
        bf16=True,
        use_liger_kernel=True,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        max_length=8192,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        dataset_num_proc=32,
        num_train_epochs=1,
    )

    trainer = SFTTrainer(
        args=config,
        model=model,
        train_dataset=codeforces_cots,
    )
    trainer.train()

    # Publish the fine-tuned model, tagging the dataset it was trained on.
    trainer.push_to_hub(dataset_name="open-r1/codeforces-cots")
Wang](https://huggingface.co/mikewang), Xuehang Guo, Sofia Stoica, [Haiyang Xu](https://huggingface.co/xhyandwyy), Hongru Wang, Hyeonjeong Ha, Xiusi Chen, Yangyi Chen, Ming Yan, Fei Huang, Heng Ji 6 | 7 | The abstract from the paper is the following: 8 | 9 | > Reinforcement Learning with Verifiable Rewards (RLVR) has proven to be a highly effective strategy for endowing Large Language Models (LLMs) with robust multi-step reasoning abilities. However, its design and optimizations remain tailored to purely textual domains, resulting in suboptimal performance when applied to multimodal reasoning tasks. In particular, we observe that a major source of error in current multimodal reasoning lies in the perception of visual inputs. To address this bottleneck, we propose Perception-Aware Policy Optimization (PAPO), a simple yet effective extension of GRPO that encourages the model to learn to perceive while learning to reason, entirely from internal supervision signals. Notably, PAPO does not rely on additional data curation, external reward models, or proprietary models. Specifically, we introduce the Implicit Perception Loss in the form of a KL divergence term to the GRPO objective, which, despite its simplicity, yields significant overall improvements (4.4%) on diverse multimodal benchmarks. The improvements are more pronounced, approaching 8.0%, on tasks with high vision dependency. We also observe a substantial reduction (30.5%) in perception errors, indicating improved perceptual capabilities with PAPO. We conduct comprehensive analysis of PAPO and identify a unique loss hacking issue, which we rigorously analyze and mitigate through a Double Entropy Loss. Overall, our work introduces a deeper integration of perception-aware supervision into RLVR learning objectives and lays the groundwork for a new RL framework that encourages visually grounded reasoning. Project page: https://mikewangwzhl.github.io/PAPO. 
10 | 11 | ## PAPOTrainer 12 | 13 | [[autodoc]] experimental.papo.PAPOTrainer 14 | - train 15 | - save_model 16 | - push_to_hub 17 | 18 | ## PAPOConfig 19 | 20 | [[autodoc]] experimental.papo.PAPOConfig 21 | -------------------------------------------------------------------------------- /tests/experimental/test_minillm_trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
@pytest.mark.low_priority
class TestMiniLLMTrainer(TrlTestCase):
    def test_train(self):
        """Smoke-test MiniLLM training: a loss is logged and the model parameters move."""
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        # Keep batch size, generation count, and completion length small to limit memory usage.
        training_args = MiniLLMConfig(
            output_dir=self.tmp_dir,
            per_device_train_batch_size=3,
            num_generations=3,
            max_completion_length=32,
            report_to="none",
        )
        trainer = MiniLLMTrainer(
            model="trl-internal-testing/small-Qwen3ForCausalLM",
            teacher_model="trl-internal-testing/tiny-Qwen3ForCausalLM",
            args=training_args,
            train_dataset=dataset,
        )

        # Snapshot the initial weights so we can verify training actually moved them.
        params_before = {name: tensor.clone() for name, tensor in trainer.model.named_parameters()}

        trainer.train()

        # A training loss must have been logged.
        assert trainer.state.log_history[-1]["train_loss"] is not None

        # Every parameter must have been updated by the optimizer step(s).
        for name, old_tensor in params_before.items():
            updated = trainer.model.get_parameter(name)
            assert not torch.allclose(old_tensor, updated), f"Parameter {name} has not changed"
# Fix: removed the dead `if is_peft_available(): pass` guard and the `is_peft_available` import that
# existed only to feed it — leftover template residue; nothing in this module uses PEFT.
import torch
from datasets import load_dataset

from trl import GRPOConfig
from trl.experimental.gspo_token import GRPOTrainer as GSPOTokenTrainer

from ..testing_utils import TrlTestCase


class TestGSPOTokenTrainer(TrlTestCase):
    def test_training(self):
        """Smoke-test GSPO-token training: a loss is logged and the model parameters move."""
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        training_args = GRPOConfig(
            output_dir=self.tmp_dir,
            learning_rate=0.1,  # increase the learning rate to speed up the test
            per_device_train_batch_size=3,  # reduce the batch size to reduce memory usage
            num_generations=3,  # reduce the number of generations to reduce memory usage
            max_completion_length=8,  # reduce the completion length to reduce memory usage
            num_iterations=2,  # the importance sampling weights won't be 0 in this case
            importance_sampling_level="sequence_token",
            report_to="none",
        )
        trainer = GSPOTokenTrainer(
            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
            args=training_args,
            train_dataset=dataset,
        )

        # Snapshot the initial weights so we can verify training actually moved them.
        previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

        trainer.train()

        # A training loss must have been logged.
        assert trainer.state.log_history[-1]["train_loss"] is not None

        # Check that the params have changed
        for n, param in previous_trainable_params.items():
            new_param = trainer.model.get_parameter(n)
            assert not torch.equal(param, new_param), f"Parameter {n} has not changed."
class MergeConfig(_MergeConfig):
    """Deprecated alias for `trl.experimental.merge_model_callback.MergeConfig`.

    Emits a `FutureWarning` on construction and otherwise behaves exactly like the
    experimental class it wraps.
    """

    def __init__(self, *args, **kwargs):
        warnings.warn(
            "`MergeConfig` is now located in `trl.experimental`. Please update your imports to "
            "`from trl.experimental.merge_model_callback import MergeConfig`. The current import path will be "
            "removed and no longer supported in TRL 0.29. For more information, see "
            "https://github.com/huggingface/trl/issues/4223.",
            FutureWarning,
            stacklevel=2,
        )
        # Bug fix: the original shim never initialized the parent class, so all
        # constructor arguments were silently dropped and instances were left
        # uninitialized compared to the experimental `MergeConfig`.
        super().__init__(*args, **kwargs)
The kernel works out of the box with [FlashAttention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). 4 | 5 | With this memory reduction, you can potentially turn off `cpu_offloading` or gradient checkpointing to further boost the performance. 6 | 7 | | Speed Up | Memory Reduction | 8 | | --- | --- | 9 | | ![Speed up](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-tps.png) | ![Memory](https://raw.githubusercontent.com/linkedin/Liger-Kernel/main/docs/images/e2e-memory.png) | 10 | 11 | ## Supported Trainers 12 | 13 | Liger Kernel is supported in the following TRL trainers: 14 | - **SFT** (Supervised Fine-Tuning) 15 | - **DPO** (Direct Preference Optimization) 16 | - **GRPO** (Group Relative Policy Optimization) 17 | - **KTO** (Kahneman-Tversky Optimization) 18 | - **GKD** (Generalized Knowledge Distillation) 19 | 20 | ## Usage 21 | 22 | 1. First, install Liger Kernel: 23 | 24 | ```bash 25 | pip install liger-kernel 26 | ``` 27 | 28 | 2. Once installed, set `use_liger_kernel=True` in your trainer config. No other changes are needed! 
29 | 30 | 31 | 32 | 33 | ```python 34 | from trl import SFTConfig 35 | 36 | training_args = SFTConfig(..., use_liger_kernel=True) 37 | ``` 38 | 39 | 40 | 41 | 42 | ```python 43 | from trl import DPOConfig 44 | 45 | training_args = DPOConfig(..., use_liger_kernel=True) 46 | ``` 47 | 48 | 49 | 50 | 51 | ```python 52 | from trl import GRPOConfig 53 | 54 | training_args = GRPOConfig(..., use_liger_kernel=True) 55 | ``` 56 | 57 | 58 | 59 | 60 | ```python 61 | from trl import KTOConfig 62 | 63 | training_args = KTOConfig(..., use_liger_kernel=True) 64 | ``` 65 | 66 | 67 | 68 | 69 | ```python 70 | from trl.experimental.gkd import GKDConfig 71 | 72 | training_args = GKDConfig(..., use_liger_kernel=True) 73 | ``` 74 | 75 | 76 | 77 | 78 | To learn more about Liger-Kernel, visit their [official repository](https://github.com/linkedin/Liger-Kernel/). 79 | -------------------------------------------------------------------------------- /.github/workflows/docker-build.yml: -------------------------------------------------------------------------------- 1 | name: Build TRL Docker image 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | workflow_dispatch: 8 | 9 | concurrency: 10 | group: docker-image-builds 11 | cancel-in-progress: false 12 | 13 | jobs: 14 | trl: 15 | name: "Build and push TRL Docker image" 16 | runs-on: 17 | group: aws-general-8-plus 18 | steps: 19 | - name: Checkout code 20 | uses: actions/checkout@v4 21 | 22 | - name: Get TRL version from PyPI 23 | run: | 24 | VERSION=$(curl -s https://pypi.org/pypi/trl/json | jq -r .info.version) 25 | echo "VERSION=$VERSION" >> $GITHUB_ENV 26 | 27 | - name: Set up Docker Buildx 28 | uses: docker/setup-buildx-action@v3 29 | 30 | - name: Login to DockerHub 31 | uses: docker/login-action@v3 32 | with: 33 | username: ${{ secrets.DOCKERHUB_USERNAME }} 34 | password: ${{ secrets.DOCKERHUB_PASSWORD }} 35 | 36 | - name: Build and Push 37 | uses: docker/build-push-action@v4 38 | with: 39 | context: docker/trl 40 | push: true 41 | tags: 
| 42 | huggingface/trl:${{ env.VERSION }} 43 | huggingface/trl 44 | 45 | - name: Post to Slack 46 | if: always() 47 | uses: huggingface/hf-workflows/.github/actions/post-slack@main 48 | with: 49 | slack_channel: ${{ secrets.CI_DOCKER_CHANNEL }} 50 | title: 🤗 Results of the TRL Dev Docker Image build 51 | status: ${{ job.status }} 52 | slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }} 53 | 54 | trl-dev: 55 | name: "Build and push TRL Dev Docker image" 56 | runs-on: 57 | group: aws-general-8-plus 58 | steps: 59 | - name: Checkout code 60 | uses: actions/checkout@v4 61 | 62 | - name: Set up Docker Buildx 63 | uses: docker/setup-buildx-action@v3 64 | 65 | - name: Login to DockerHub 66 | uses: docker/login-action@v3 67 | with: 68 | username: ${{ secrets.DOCKERHUB_USERNAME }} 69 | password: ${{ secrets.DOCKERHUB_PASSWORD }} 70 | 71 | - name: Build and Push 72 | uses: docker/build-push-action@v4 73 | with: 74 | context: docker/trl-dev 75 | push: true 76 | tags: | 77 | huggingface/trl:dev 78 | 79 | - name: Post to Slack 80 | if: always() 81 | uses: huggingface/hf-workflows/.github/actions/post-slack@main 82 | with: 83 | slack_channel: ${{ secrets.CI_DOCKER_CHANNEL }} 84 | title: 🤗 Results of the TRL Dev Docker Image build 85 | status: ${{ job.status }} 86 | slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }} 87 | -------------------------------------------------------------------------------- /docs/source/use_model.md: -------------------------------------------------------------------------------- 1 | # Use model after training 2 | 3 | Once you have trained a model using either the SFTTrainer, PPOTrainer, or DPOTrainer, you will have a fine-tuned model that can be used for text generation. In this section, we'll walk through the process of loading the fine-tuned model and generating text. 
If you need to run an inference server with the trained model, you can explore libraries such as [`text-generation-inference`](https://github.com/huggingface/text-generation-inference). 4 | 5 | ## Load and Generate 6 | 7 | If you have fine-tuned a model fully, meaning without the use of PEFT you can simply load it like any other language model in transformers. E.g. the value head that was trained during the PPO training is no longer needed and if you load the model with the original transformer class it will be ignored: 8 | 9 | ```python 10 | from transformers import AutoTokenizer, AutoModelForCausalLM 11 | 12 | model_name_or_path = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub 13 | device = "cpu" # or "cuda" if you have a GPU 14 | 15 | model = AutoModelForCausalLM.from_pretrained(model_name_or_path).to(device) 16 | tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) 17 | 18 | inputs = tokenizer.encode("This movie was really", return_tensors="pt").to(device) 19 | outputs = model.generate(inputs) 20 | print(tokenizer.decode(outputs[0])) 21 | ``` 22 | 23 | Alternatively you can also use the pipeline: 24 | 25 | ```python 26 | from transformers import pipeline 27 | 28 | model_name_or_path = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub 29 | pipe = pipeline("text-generation", model=model_name_or_path) 30 | print(pipe("This movie was really")[0]["generated_text"]) 31 | ``` 32 | 33 | ## Use Adapters PEFT 34 | 35 | ```python 36 | from peft import PeftConfig, PeftModel 37 | from transformers import AutoModelForCausalLM, AutoTokenizer 38 | 39 | base_model_name = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub 40 | adapter_model_name = "path/to/my/adapter" 41 | 42 | model = AutoModelForCausalLM.from_pretrained(base_model_name) 43 | model = PeftModel.from_pretrained(model, adapter_model_name) 44 | 45 | tokenizer = AutoTokenizer.from_pretrained(base_model_name) 46 | ``` 47 | 48 | You can also merge the adapters into the base model so 
def get_soft_overlong_punishment(max_completion_len: int, soft_punish_cache: int) -> Callable:
    # docstyle-ignore
    r"""
    Reward function that penalizes overlong completions. It is used to penalize overlong completions, but not to
    reward shorter completions. Reference: Eq. (13) from the DAPO paper (https://huggingface.co/papers/2503.14476)

    $$
    R_{\text{length}}(y) = \begin{cases}
    0, & |y| \le L_{\max} - L_{\text{cache}} \\
    \dfrac{(L_{\max} - L_{\text{cache}}) - |y|}{L_{\text{cache}}}, & L_{\max} - L_{\text{cache}} < |y| \le L_{\max} \\
    -1, & L_{\max} < |y|
    \end{cases}
    $$

    Args:
        max_completion_len (`int`):
            Maximum length of the completion, \( L_{\max} \).
        soft_punish_cache (`int`):
            Length of the soft punishment interval, \( L_{\text{cache}} \): completions longer than
            \( L_{\max} - L_{\text{cache}} \) receive a penalty that decreases linearly from `0` to `-1`. If set to
            `0`, there is no soft interval and the reward drops directly from `0` to `-1` once the completion
            exceeds `max_completion_len`.

    Example:
    ```python
    from trl.rewards import get_soft_overlong_punishment

    soft_overlong_punishment = get_soft_overlong_punishment(max_completion_len=100, soft_punish_cache=20)
    completion_ids = [[1] * 90]  # simulating a completion with 90 tokens. 90 is between 80 and 100.
    rewards = soft_overlong_punishment(completion_ids)
    print(rewards)  # [-0.5]
    ```
    """

    def soft_overlong_punishment_reward(completion_ids: list[list[int]], **kwargs) -> list[float]:
        """Reward function that penalizes overlong completions."""
        rewards = []
        for ids in completion_ids:
            completion_length = len(ids)
            if completion_length <= max_completion_len - soft_punish_cache:
                # Within the allowed budget: no penalty (and no reward for being short).
                rewards.append(0.0)
            elif max_completion_len - soft_punish_cache < completion_length <= max_completion_len:
                # Inside the soft interval: linear penalty in (-1, 0). This branch is
                # unreachable when soft_punish_cache == 0, so the division is safe.
                rewards.append((max_completion_len - soft_punish_cache - completion_length) / soft_punish_cache)
            else:
                # Past the hard limit: maximum penalty.
                rewards.append(-1.0)
        return rewards

    return soft_overlong_punishment_reward
4 | 5 | | Notebook | Description | Open in Colab | 6 | | --- | --- | --- | 7 | | [`grpo_agent.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/grpo_agent.ipynb) | GRPO for agent training | Not available due to OOM with Colab GPUs | 8 | | [`grpo_rnj_1_instruct.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/grpo_rnj_1_instruct.ipynb) | GRPO rnj-1-instruct with QLoRA using TRL on Colab to add reasoning capabilities | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/grpo_rnj_1_instruct.ipynb) | 9 | | [`sft_ministral3_vl.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/sft_ministral3_vl.ipynb) | Supervised Fine-Tuning (SFT) Ministral 3 with QLoRA using TRL on free Colab | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/sft_ministral3_vl.ipynb) | 10 | | [`grpo_ministral3_vl.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/grpo_ministral3_vl.ipynb) | GRPO Ministral 3 with QLoRA using TRL on free Colab | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/grpo_ministral3_vl.ipynb) | 11 | | [`openenv_wordle_grpo.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/openenv_wordle_grpo.ipynb) | GRPO to play Worldle on an OpenEnv environment | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/openenv_wordle_grpo.ipynb) | 12 | | [`sft_trl_lora_qlora.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/sft_trl_lora_qlora.ipynb) | Supervised Fine-Tuning (SFT) using QLoRA on free Colab | [![Open In 
Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/sft_trl_lora_qlora.ipynb) | 13 | | [`sft_qwen_vl.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/sft_qwen_vl.ipynb) | Supervised Fine-Tuning (SFT) Qwen3-VL with QLoRA using TRL on free Colab | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/sft_qwen_vl.ipynb) | 14 | | [`grpo_qwen3_vl.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/grpo_qwen3_vl.ipynb) | GRPO Qwen3-VL with QLoRA using TRL on free Colab | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/grpo_qwen3_vl.ipynb) | 15 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug-report.yml: -------------------------------------------------------------------------------- 1 | name: "\U0001F41B Bug Report" 2 | description: Submit a bug report to help us improve TRL 3 | labels: [ "bug" ] 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: | 8 | Thanks for taking the time to fill out this bug report! 🤗 9 | 10 | 🚩 If it is your first time submitting, be sure to check our [bug report guidelines](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md#did-you-find-a-bug) 11 | 12 | - type: textarea 13 | id: reproduction 14 | validations: 15 | required: true 16 | attributes: 17 | label: Reproduction 18 | description: | 19 | Please provide a code sample that reproduces the problem you ran into. It can be a Colab link or just a code snippet. 20 | If you have code snippets, error messages, stack traces please provide them here as well. 21 | Important! Use code tags to correctly format your code. 
See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting 22 | Do not use screenshots, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code. 23 | 24 | value: | 25 | ```python 26 | from trl import ... 27 | 28 | ``` 29 | 30 | outputs: 31 | 32 | ``` 33 | Traceback (most recent call last): 34 | File "example.py", line 42, in 35 | ... 36 | ``` 37 | 38 | - type: textarea 39 | id: system-info 40 | attributes: 41 | label: System Info 42 | description: | 43 | Please provide information about your system: platform, Python version, PyTorch version, Transformers version, devices, TRL version, ... 44 | You can get this information by running `trl env` in your terminal. 45 | 46 | placeholder: Copy-paste the output of `trl env` 47 | validations: 48 | required: true 49 | 50 | - type: checkboxes 51 | id: terms 52 | attributes: 53 | label: Checklist 54 | description: | 55 | Before submitting, please confirm that you've completed each of the following. 56 | If an item doesn't apply to your issue, check it anyway to show you've reviewed it. 
57 | options: 58 | - label: "I have checked that my issue isn't already filed (see [open issues](https://github.com/huggingface/trl/issues?q=is%3Aissue))" 59 | required: true 60 | - label: "I have included my system information" 61 | required: true 62 | - label: "Any code provided is minimal, complete, and reproducible ([more on MREs](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/creating-and-highlighting-code-blocks))" 63 | required: true 64 | - label: "Any code provided is properly formatted in code blocks, (no screenshot, [more on code blocks](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/creating-and-highlighting-code-blocks))" 65 | required: true 66 | - label: "Any traceback provided is complete" 67 | required: true 68 | -------------------------------------------------------------------------------- /docs/source/judges.md: -------------------------------------------------------------------------------- 1 | # Judges 2 | 3 | > [!WARNING] 4 | > TRL Judges is an experimental API which is subject to change at any time. As of TRL v1.0, judges have been moved to the `trl.experimental.judges` module. 5 | 6 | TRL provides judges to easily compare two completions. 7 | 8 | Make sure to have installed the required dependencies by running: 9 | 10 | ```bash 11 | pip install trl[judges] 12 | ``` 13 | 14 | ## Using the provided judges 15 | 16 | TRL provides several judges out of the box. 
For example, you can use the [`experimental.judges.HfPairwiseJudge`] to compare two completions using a pre-trained model from the Hugging Face model hub:

```python
from trl.experimental.judges import HfPairwiseJudge

judge = HfPairwiseJudge()
judge.judge(
    prompts=["What is the capital of France?", "What is the biggest planet in the solar system?"],
    completions=[["Paris", "Lyon"], ["Saturn", "Jupiter"]],
)  # Outputs: [0, 1]
```

## Define your own judge

To define your own judge, we provide several base classes that you can subclass. For rank-based judges, you need to subclass [`experimental.judges.BaseRankJudge`] and implement the [`experimental.judges.BaseRankJudge.judge`] method. For pairwise judges, you need to subclass [`experimental.judges.BasePairwiseJudge`] and implement the [`experimental.judges.BasePairwiseJudge.judge`] method. If you want to define a judge that doesn't fit into these categories, you need to subclass [`experimental.judges.BaseJudge`] and implement the [`experimental.judges.BaseJudge.judge`] method.
31 | 32 | As an example, let's define a pairwise judge that prefers shorter completions: 33 | 34 | ```python 35 | from trl.experimental.judges import BasePairwiseJudge 36 | 37 | class PrefersShorterJudge(BasePairwiseJudge): 38 | def judge(self, prompts, completions, shuffle_order=False): 39 | return [0 if len(completion[0]) > len(completion[1]) else 1 for completion in completions] 40 | ``` 41 | 42 | You can then use this judge as follows: 43 | 44 | ```python 45 | judge = PrefersShorterJudge() 46 | judge.judge( 47 | prompts=["What is the capital of France?", "What is the biggest planet in the solar system?"], 48 | completions=[["Paris", "The capital of France is Paris."], ["Jupiter is the biggest planet in the solar system.", "Jupiter"]], 49 | ) # Outputs: [0, 1] 50 | ``` 51 | 52 | ## Provided judges 53 | 54 | ### PairRMJudge 55 | 56 | [[autodoc]] trl.experimental.judges.PairRMJudge 57 | 58 | ### HfPairwiseJudge 59 | 60 | [[autodoc]] trl.experimental.judges.HfPairwiseJudge 61 | 62 | ### OpenAIPairwiseJudge 63 | 64 | [[autodoc]] trl.experimental.judges.OpenAIPairwiseJudge 65 | 66 | ### AllTrueJudge 67 | 68 | [[autodoc]] trl.experimental.judges.AllTrueJudge 69 | 70 | ## Base classes 71 | 72 | ### BaseJudge 73 | 74 | [[autodoc]] trl.experimental.judges.BaseJudge 75 | 76 | ### BaseBinaryJudge 77 | 78 | [[autodoc]] trl.experimental.judges.BaseBinaryJudge 79 | 80 | ### BaseRankJudge 81 | 82 | [[autodoc]] trl.experimental.judges.BaseRankJudge 83 | 84 | ### BasePairwiseJudge 85 | 86 | [[autodoc]] trl.experimental.judges.BasePairwiseJudge 87 | -------------------------------------------------------------------------------- /trl/trainer/base_trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
class BaseTrainer(Trainer):
    """Common base class for TRL trainers; adds model-card generation on top of `transformers.Trainer`."""

    # Subclasses override these to customize the generated model card.
    _tag_names = []
    _name = "Base"
    _paper = {}
    _template_file = None

    def create_model_card(
        self,
        model_name: str | None = None,
        dataset_name: str | None = None,
        tags: str | list[str] | None = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        Args:
            model_name (`str`, *optional*):
                Name of the model.
            dataset_name (`str`, *optional*):
                Name of the dataset used for training.
            tags (`str`, `list[str]`, *optional*):
                Tags to be associated with the model card.
        """
        # Only the main process writes the card.
        if not self.is_world_process_zero():
            return

        # Treat the config's model id as the base model only when it is a hub id,
        # not a local directory path.
        model_id = get_config_model_id(self.model.config)
        base_model = model_id if model_id and not os.path.isdir(model_id) else None

        # Normalize `tags` into a set, then enrich it with environment-derived tags.
        if tags is None:
            tag_set = set()
        elif isinstance(tags, str):
            tag_set = {tags}
        else:
            tag_set = set(tags)
        if hasattr(self.model.config, "unsloth_version"):
            tag_set.add("unsloth")
        if "JOB_ID" in os.environ:
            tag_set.add("hf_jobs")
        tag_set.update(self._tag_names)

        card = generate_model_card(
            base_model=base_model,
            model_name=model_name,
            hub_model_id=self.hub_model_id,
            dataset_name=dataset_name,
            tags=list(tag_set),
            wandb_url=wandb.run.url if is_wandb_available() and wandb.run is not None else None,
            comet_url=get_comet_experiment_url(),
            trainer_name=self._name,
            trainer_citation=self._paper.get("citation"),
            template_file=self._template_file,
            paper_title=self._paper.get("title"),
            paper_id=self._paper.get("id"),
        )
        card.save(os.path.join(self.args.output_dir, "README.md"))
--shm-size "16gb" 26 | defaults: 27 | run: 28 | shell: bash 29 | steps: 30 | - name: Git checkout 31 | uses: actions/checkout@v4 32 | 33 | - name: Install system dependencies 34 | run: | 35 | apt-get update && apt-get install -y make git curl 36 | 37 | - name: Install uv 38 | run: | 39 | curl -LsSf https://astral.sh/uv/install.sh | sh 40 | 41 | - name: Create Python virtual environment 42 | run: | 43 | uv venv 44 | uv pip install --upgrade setuptools wheel 45 | 46 | - name: Install dependencies 47 | run: | 48 | source .venv/bin/activate 49 | uv pip install ".[dev]" 50 | uv pip install pytest-reportlog 51 | 52 | - name: Run slow SFT tests on single GPU 53 | if: always() 54 | run: | 55 | source .venv/bin/activate 56 | make slow_tests 57 | 58 | - name: Generate Report 59 | if: always() 60 | run: | 61 | source .venv/bin/activate 62 | uv pip install slack_sdk tabulate 63 | python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY 64 | 65 | run_all_tests_multi_gpu: 66 | runs-on: 67 | group: aws-g4dn-2xlarge 68 | env: 69 | CUDA_VISIBLE_DEVICES: "0,1" 70 | TEST_TYPE: "multi_gpu" 71 | container: 72 | image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel 73 | options: --gpus all --shm-size "16gb" 74 | defaults: 75 | run: 76 | shell: bash 77 | steps: 78 | - name: Git checkout 79 | uses: actions/checkout@v4 80 | 81 | - name: Install system dependencies 82 | run: | 83 | apt-get update && apt-get install -y make git curl 84 | 85 | - name: Install uv 86 | run: | 87 | curl -LsSf https://astral.sh/uv/install.sh | sh 88 | 89 | - name: Create Python virtual environment 90 | run: | 91 | uv venv 92 | uv pip install --upgrade setuptools wheel 93 | 94 | - name: Install dependencies 95 | run: | 96 | source .venv/bin/activate 97 | uv pip install ".[dev]" 98 | uv pip install pytest-reportlog 99 | 100 | - name: Run slow SFT tests on Multi GPU 101 | if: always() 102 | run: | 103 | source .venv/bin/activate 104 | make slow_tests 105 | 106 | - name: Generate Reports 107 | if: always() 108 | run: | 109 
@dataclass
class PAPOConfig(GRPOConfig):
    """
    Configuration class for PAPOTrainer.

    PAPO (Perception-Aware Policy Optimization) extends GRPO/DAPO for multimodal reasoning by adding an implicit
    perception loss and double entropy regularization.

    Args:
        perception_loss_weight (`float`, *optional*, defaults to `0.1`):
            gamma — weight coefficient for the perception loss term. This encourages the model to be sensitive to
            visual changes.
        mask_ratio (`float`, *optional*, defaults to `0.3`):
            Ratio of the image to mask when computing the perception loss.
        mask_type (`Literal["random", "patch", "grid"]`, *optional*, defaults to `"random"`):
            Masking strategy to use.
        der_loss_weight1 (`float`, *optional*, defaults to `0.03`):
            eta1 — weight coefficient for the Double Entropy Regularization (DER) term. This term encourages
            confident predictions with original images (low entropy) and uncertain predictions with masked images
            (high entropy).
        der_loss_weight2 (`float`, *optional*, defaults to `0.03`):
            eta2 — weight coefficient for the Double Entropy Regularization (DER) term. This term encourages
            confident predictions with original images (low entropy) and uncertain predictions with masked images
            (high entropy).
        loss_type (`Literal["grpo", "dapo"]`, inherited from GRPOConfig):
            Base loss type to use. Set to "grpo" for PAPO-G or "dapo" for PAPO-D.
    """

    perception_loss_weight: float = 0.1
    mask_ratio: float = 0.3
    mask_type: Literal["random", "patch", "grid"] = "random"

    # Double Entropy Regularization weights.
    der_loss_weight1: float = 0.03
    der_loss_weight2: float = 0.03

    def __post_init__(self):
        super().__post_init__()

        # `Literal` annotations are not enforced at runtime, so validate explicitly.
        if self.mask_ratio < 0.0 or self.mask_ratio > 1.0:
            raise ValueError(f"mask_ratio must be between 0 and 1, got {self.mask_ratio}")

        if self.der_loss_weight1 < 0 or self.der_loss_weight2 < 0:
            raise ValueError(
                f"der_loss_weight1 and der_loss_weight2 must be non-negative, got {self.der_loss_weight1} and {self.der_loss_weight2}"
            )

        if self.mask_type not in ("random", "patch", "grid"):
            raise ValueError(f"mask_type must be one of ['random', 'patch', 'grid'], got {self.mask_type}")
@require_mergekit
class TestMergeModelCallback(TrlTestCase):
    """Tests that `MergeModelCallback` writes a merged model next to saved checkpoints."""

    def setup_method(self):
        model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
        self.model = AutoModelForCausalLM.from_pretrained(model_id)
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")

    def _train_with_callback(self, callback):
        # Shared setup: a one-epoch DPO run that saves a checkpoint at every step.
        training_args = DPOConfig(
            output_dir=self.tmp_dir,
            num_train_epochs=1,
            report_to="none",
            save_strategy="steps",
            save_steps=1,
        )
        trainer = DPOTrainer(
            model=self.model,
            args=training_args,
            train_dataset=self.dataset,
            processing_class=self.tokenizer,
            callbacks=[callback],
        )
        trainer.train()

    def test_callback(self):
        self._train_with_callback(MergeModelCallback(MergeConfig()))

        # By default only the last checkpoint should contain a merged model.
        last_checkpoint = get_last_checkpoint(self.tmp_dir)
        merged_path = os.path.join(last_checkpoint, "merged")
        assert os.path.isdir(merged_path), "Merged folder does not exist in the last checkpoint."

    def test_every_checkpoint(self):
        self._train_with_callback(MergeModelCallback(MergeConfig(), merge_at_every_checkpoint=True))

        checkpoints = sorted(
            os.path.join(self.tmp_dir, name)
            for name in os.listdir(self.tmp_dir)
            if name.startswith("checkpoint-")
        )
        for checkpoint in checkpoints:
            merged_path = os.path.join(checkpoint, "merged")
            assert os.path.isdir(merged_path), f"Merged folder does not exist in checkpoint {checkpoint}."
class TestDataCollatorForPreference(TrlTestCase):
    """Unit tests for `DataCollatorForPreference` padding and pass-through behavior."""

    def setup_method(self):
        # pad_token_id=0 makes padded positions easy to spot in the expected tensors.
        self.collator = DataCollatorForPreference(pad_token_id=0)

    def assertTensorEqual(self, tensor1, tensor2):
        # torch.equal checks both shape and values; show both tensors on failure.
        assert torch.equal(tensor1, tensor2), f"Tensors are not equal:\n{tensor1}\n{tensor2}"

    def test_padding_behavior(self):
        examples = [
            {"prompt_input_ids": [1, 2, 3], "chosen_input_ids": [4, 5], "rejected_input_ids": [6]},
            {"prompt_input_ids": [7, 8], "chosen_input_ids": [9, 10], "rejected_input_ids": [11, 12, 13]},
        ]
        output = self.collator.torch_call(examples)

        # Prompts are padded on the left; chosen/rejected completions on the right.
        expected = {
            "prompt_input_ids": torch.tensor([[1, 2, 3], [0, 7, 8]]),
            "prompt_attention_mask": torch.tensor([[1, 1, 1], [0, 1, 1]]),
            "chosen_input_ids": torch.tensor([[4, 5], [9, 10]]),
            "chosen_attention_mask": torch.tensor([[1, 1], [1, 1]]),
            "rejected_input_ids": torch.tensor([[6, 0, 0], [11, 12, 13]]),
            "rejected_attention_mask": torch.tensor([[1, 0, 0], [1, 1, 1]]),
        }
        for key, expected_tensor in expected.items():
            self.assertTensorEqual(output[key], expected_tensor)

    def test_optional_fields(self):
        # Extra keys such as pixel_values must be batched and passed through.
        examples = [
            {
                "prompt_input_ids": [1],
                "chosen_input_ids": [2],
                "rejected_input_ids": [3],
                "pixel_values": [[[0.1, 0.2], [0.3, 0.4]]],  # example 1x2x2 tensor
            },
            {
                "prompt_input_ids": [4],
                "chosen_input_ids": [5],
                "rejected_input_ids": [6],
                "pixel_values": [[[0.5, 0.6], [0.7, 0.8]]],  # example 1x2x2 tensor
            },
        ]
        output = self.collator.torch_call(examples)

        # Stacking the two 1x2x2 examples yields a (2, 1, 2, 2) batch.
        expected_pixel_values = torch.tensor(
            [
                [[[0.1, 0.2], [0.3, 0.4]]],
                [[[0.5, 0.6], [0.7, 0.8]]],
            ]
        )
        self.assertTensorEqual(output["pixel_values"], expected_pixel_values)
34 | """ 35 | 36 | push_to_hub: bool = field( 37 | default=False, 38 | metadata={"help": "Whether to push the dataset to the Hugging Face Hub."}, 39 | ) 40 | repo_id: str = field( 41 | default="trl-lib/DeepMath-103K", 42 | metadata={"help": "Hugging Face repository ID to push the dataset to."}, 43 | ) 44 | dataset_num_proc: int | None = field( 45 | default=None, 46 | metadata={"help": "Number of workers to use for dataset processing."}, 47 | ) 48 | 49 | 50 | def process_example(example): 51 | solution = example["final_answer"] 52 | if solution not in ["True", "False", "Yes", "No"]: 53 | solution = f"${solution}$" 54 | prompt = [{"role": "user", "content": example["question"]}] 55 | return {"prompt": prompt, "solution": solution} 56 | 57 | 58 | model_card = ModelCard(""" 59 | --- 60 | tags: [trl] 61 | --- 62 | 63 | # DeepMath-103K Dataset 64 | 65 | ## Summary 66 | 67 | [DeepMath-103K](https://huggingface.co/datasets/zwhe99/DeepMath-103K) is meticulously curated to push the boundaries of mathematical reasoning in language models. 68 | 69 | ## Data Structure 70 | 71 | - **Format**: [Conversational](https://huggingface.co/docs/trl/main/dataset_formats#conversational) 72 | - **Type**: [Prompt-only](https://huggingface.co/docs/trl/main/dataset_formats#prompt-only) 73 | 74 | Column: 75 | - `"prompt"`: The input question. 76 | - `"solution"`: The solution to the math problem. 77 | 78 | ## Generation script 79 | 80 | The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/deepmath_103k.py). 
81 | """) 82 | 83 | if __name__ == "__main__": 84 | parser = HfArgumentParser(ScriptArguments) 85 | script_args = parser.parse_args_into_dataclasses()[0] 86 | 87 | dataset = load_dataset("zwhe99/DeepMath-103K", split="train") 88 | 89 | dataset = dataset.map( 90 | process_example, 91 | remove_columns=dataset.column_names, 92 | num_proc=script_args.dataset_num_proc, 93 | ) 94 | dataset = dataset.train_test_split(test_size=0.05, seed=42) 95 | 96 | if script_args.push_to_hub: 97 | dataset.push_to_hub(script_args.repo_id) 98 | model_card.push_to_hub(script_args.repo_id, repo_type="dataset") 99 | -------------------------------------------------------------------------------- /scripts/add_copyrights.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import subprocess 17 | import sys 18 | from datetime import datetime 19 | 20 | 21 | COPYRIGHT_HEADER = f"""# Copyright 2020-{datetime.now().year} The HuggingFace Team. All rights reserved. 22 | # 23 | # Licensed under the Apache License, Version 2.0 (the "License"); 24 | # you may not use this file except in compliance with the License. 
def check_and_add_copyright(file_path):
    """Ensure `file_path` starts with the expected copyright header, prepending it if missing.

    Returns:
        `True` if the file already had the header (or does not exist and was skipped),
        `False` if the header was missing and has been prepended.
    """
    if not os.path.isfile(file_path):
        # A git-tracked path can be absent from the working tree (e.g. deleted but not
        # yet committed). There is nothing to fix, so don't count it as a failure.
        # (Previously this path returned None, which `main()` treated as a missing
        # copyright and wrongly exited with status 1.)
        print(f"[SKIP] {file_path} does not exist.")
        return True

    with open(file_path, encoding="utf-8") as f:
        content = f.read()

    # Nothing to do if the exact header is already at the top of the file.
    if content.startswith(COPYRIGHT_HEADER):
        return True

    # Prepend the header, separated from the original content by a blank line.
    print(f"[MODIFY] Adding copyright to {file_path}.")
    with open(file_path, "w", encoding="utf-8") as f:
        f.write(COPYRIGHT_HEADER + "\n" + content)
    return False
print("No Python files are tracked in the repository.") 78 | return 79 | 80 | print(f"Checking {len(py_files)} Python files for copyright notice...") 81 | 82 | have_copyright = [check_and_add_copyright(file_path) for file_path in py_files] 83 | if not all(have_copyright): 84 | print("❌ Some files were missing the required copyright and have been updated.") 85 | sys.exit(1) 86 | else: 87 | print("✅ All files have the required copyright.") 88 | sys.exit(0) 89 | 90 | 91 | if __name__ == "__main__": 92 | main() 93 | -------------------------------------------------------------------------------- /examples/scripts/sft_gpt_oss.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
def main(script_args, training_args, model_args):
    """Fine-tune a gpt-oss checkpoint with SFT on the configured dataset."""
    # Dequantize the MXFP4 checkpoint on load so training runs in the requested dtype.
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        dtype=model_args.dtype,
        # Disable the KV cache when gradient checkpointing is enabled (the two conflict).
        use_cache=not training_args.gradient_checkpointing,
        quantization_config=Mxfp4Config(dequantize=True),
    )

    # Load dataset and pick the eval split only when evaluation is requested.
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
    if training_args.eval_strategy != "no":
        eval_dataset = dataset[script_args.dataset_test_split]
    else:
        eval_dataset = None

    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=eval_dataset,
        peft_config=get_peft_config(model_args),
    )
    trainer.train()

    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)
@contextlib.contextmanager
def profiling_context(trainer: Trainer, name: str) -> Generator[None, None, None]:
    """
    A context manager function for profiling a block of code. Results are logged to Weights & Biases or MLflow
    depending on the trainer's configuration.

    Args:
        trainer (`~transformers.Trainer`):
            Trainer object.
        name (`str`):
            Name of the block to be profiled. Used as a key in the logged dictionary.

    Example:
    ```python
    from transformers import Trainer
    from trl.extras.profiling import profiling_context


    class MyTrainer(Trainer):
        def some_method(self):
            A = np.random.rand(1000, 1000)
            B = np.random.rand(1000, 1000)
            with profiling_context(self, "matrix_multiplication"):
                # Code to profile: simulate a computationally expensive operation
                result = A @ B  # Matrix multiplication
    ```
    """
    start_time = time.perf_counter()
    yield
    end_time = time.perf_counter()
    duration = end_time - start_time

    profiling_metrics = {f"profiling/Time taken: {trainer.__class__.__name__}.{name}": duration}
    if "wandb" in trainer.args.report_to and wandb.run is not None and trainer.accelerator.is_main_process:
        wandb.log(profiling_metrics)

    # Fix: `mlflow.run` is the MLflow Projects entry point (a module-level function), so
    # `mlflow.run is not None` was always true. The presence of an active tracking run
    # must be checked with `mlflow.active_run()`, mirroring the `wandb.run` check above.
    if "mlflow" in trainer.args.report_to and mlflow.active_run() is not None and trainer.accelerator.is_main_process:
        mlflow.log_metrics(profiling_metrics, step=trainer.state.global_step)
[`extras.profiling.profiling_context`]. 74 | 75 | Args: 76 | func (`Callable`): 77 | Function to be profiled. 78 | 79 | Example: 80 | ```python 81 | from transformers import Trainer 82 | from trl.extras.profiling import profiling_decorator 83 | 84 | 85 | class MyTrainer(Trainer): 86 | @profiling_decorator 87 | def some_method(self): 88 | A = np.random.rand(1000, 1000) 89 | B = np.random.rand(1000, 1000) 90 | # Code to profile: simulate a computationally expensive operation 91 | result = A @ B 92 | ``` 93 | """ 94 | 95 | @functools.wraps(func) 96 | def wrapper(self, *args, **kwargs): 97 | with profiling_context(self, func.__name__): 98 | return func(self, *args, **kwargs) 99 | 100 | return wrapper 101 | -------------------------------------------------------------------------------- /examples/scripts/rloo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
def main():
    """Train Qwen3-0.6B with RLOO (LoRA + colocated vLLM) on a slice of NuminaMath-TIR."""
    # Use 5% slices of each split to keep the run short.
    train_dataset, eval_dataset = load_dataset("AI-MO/NuminaMath-TIR", split=["train[:5%]", "test[:5%]"])

    # NOTE(review): the prompt text mentions enclosing tags but no tag markers appear in
    # the string literal — confirm the think-style tags were not lost in an earlier edit.
    SYSTEM_PROMPT = (
        "A conversation between user and assistant. The user asks a question, and the assistant solves it. The "
        "assistant first thinks about the reasoning process in the mind and then provides the user with the answer. "
        "The reasoning process and answer are enclosed within tags, i.e., \nThis is my "
        "reasoning.\n\nThis is my answer."
    )

    def make_conversation(example):
        # Prepend the system prompt and wrap the problem as a user turn.
        return {
            "prompt": [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": example["problem"]},
            ],
        }

    train_dataset = train_dataset.map(make_conversation, remove_columns=["messages", "problem"])
    eval_dataset = eval_dataset.map(make_conversation, remove_columns=["messages", "problem"])

    # Training configuration: bf16 weights, colocated vLLM generation, LoRA adapters.
    training_args = RLOOConfig(
        output_dir="Qwen3-0.6B-RLOO",
        model_init_kwargs={"dtype": torch.bfloat16},
        learning_rate=1e-5,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        log_completions=True,
        num_completions_to_print=2,
        max_prompt_length=2048,
        max_completion_length=1024,
        gradient_accumulation_steps=2,
        steps_per_generation=8,
        use_vllm=True,
        vllm_mode="colocate",
        vllm_gpu_memory_utilization=0.5,
        run_name="Qwen3-0.6B-RLOO-NuminaMath-TIR",
    )

    trainer = RLOOTrainer(
        model="Qwen/Qwen3-0.6B",
        args=training_args,
        reward_funcs=[think_format_reward, accuracy_reward],
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=LoraConfig(),
    )
    trainer.train()

    # Save locally and push the final model to the Hub.
    trainer.save_model(training_args.output_dir)
    trainer.push_to_hub(dataset_name="AI-MO/NuminaMath-TIR")
- local: dpo_trainer 17 | title: DPO 18 | - local: grpo_trainer 19 | title: GRPO 20 | - local: reward_trainer 21 | title: Reward 22 | - local: rloo_trainer 23 | title: RLOO 24 | - local: sft_trainer 25 | title: SFT 26 | title: Trainers 27 | - sections: 28 | - local: clis 29 | title: Command Line Interface (CLI) 30 | - local: jobs_training 31 | title: Training using Jobs 32 | - local: customization 33 | title: Customizing the Training 34 | - local: reducing_memory_usage 35 | title: Reducing Memory Usage 36 | - local: speeding_up_training 37 | title: Speeding Up Training 38 | - local: distributing_training 39 | title: Distributing Training 40 | - local: use_model 41 | title: Using Trained Models 42 | title: How-to guides 43 | - sections: 44 | - local: deepspeed_integration 45 | title: DeepSpeed 46 | - local: kernels_hub 47 | title: Kernels Hub 48 | - local: liger_kernel_integration 49 | title: Liger Kernel 50 | - local: peft_integration 51 | title: PEFT 52 | - local: rapidfire_integration 53 | title: RapidFire AI 54 | - local: trackio_integration 55 | title: Trackio 56 | - local: unsloth_integration 57 | title: Unsloth 58 | - local: vllm_integration 59 | title: vLLM 60 | title: Integrations 61 | - sections: 62 | - local: example_overview 63 | title: Example Overview 64 | - local: community_tutorials 65 | title: Community Tutorials 66 | - local: lora_without_regret 67 | title: LoRA Without Regret 68 | title: Examples 69 | - sections: 70 | - sections: 71 | - local: chat_template_utils 72 | title: Chat Template Utilities 73 | - local: data_utils 74 | title: Data Utilities 75 | - local: model_utils 76 | title: Model Utilities 77 | - local: script_utils 78 | title: Script Utilities 79 | title: Utilities 80 | - local: callbacks 81 | title: Callbacks 82 | - local: rewards 83 | title: Reward Functions 84 | - local: others 85 | title: Others 86 | title: API 87 | - sections: 88 | - local: experimental_overview 89 | title: Experimental Overview 90 | - local: openenv 91 | title: 
OpenEnv Integration 92 | - local: bema_for_reference_model # Sorted alphabetically 93 | title: BEMA for Reference Model 94 | - local: bco_trainer 95 | title: BCO 96 | - local: cpo_trainer 97 | title: CPO 98 | - local: gfpo 99 | title: GFPO 100 | - local: gkd_trainer 101 | title: GKD 102 | - local: gold_trainer 103 | title: GOLD 104 | - local: grpo_with_replay_buffer 105 | title: GRPO With Replay Buffer 106 | - local: gspo_token 107 | title: GSPO-token 108 | - local: judges 109 | title: Judges 110 | - local: kto_trainer 111 | title: KTO 112 | - local: merge_model_callback 113 | title: MergeModelCallback 114 | - local: minillm_trainer 115 | title: MiniLLM 116 | - local: nash_md_trainer 117 | title: Nash-MD 118 | - local: online_dpo_trainer 119 | title: Online DPO 120 | - local: orpo_trainer 121 | title: ORPO 122 | - local: papo_trainer 123 | title: PAPO 124 | - local: ppo_trainer 125 | title: PPO 126 | - local: prm_trainer 127 | title: PRM 128 | - local: winrate_callback 129 | title: WinRateCallback 130 | - local: xpo_trainer 131 | title: XPO 132 | title: Experimental -------------------------------------------------------------------------------- /trl/scripts/env.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
def print_env():
    """Print platform, accelerator, and library-version info in a copy-pasteable
    format for bug reports."""

    def _version_or_missing(package, installed):
        # Report a version string only when the optional dependency is importable.
        return version(package) if installed else "not installed"

    # Detect accelerator devices, preferring CUDA, then MPS, then XPU.
    if torch.cuda.is_available():
        devices = [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]
    elif torch.backends.mps.is_available():
        devices = ["MPS"]
    elif torch.xpu.is_available():
        devices = [torch.xpu.get_device_name(i) for i in range(torch.xpu.device_count())]
    else:
        devices = None

    # Read the default accelerate config file if the user has created one.
    accelerate_config = accelerate_config_str = "not found"
    if os.path.isfile(default_config_file):
        accelerate_config = load_config_from_file(default_config_file).to_dict()
    if isinstance(accelerate_config, dict):
        accelerate_config_str = "\n" + "\n".join(f" - {prop}: {val}" for prop, val in accelerate_config.items())

    # Append the short commit hash when TRL is installed from source.
    commit_hash = get_git_commit_hash("trl")
    trl_version = f"{__version__}+{commit_hash[:7]}" if commit_hash else __version__

    info = {
        "Platform": platform.platform(),
        "Python version": platform.python_version(),
        "TRL version": trl_version,
        "PyTorch version": version("torch"),
        "accelerator(s)": ", ".join(devices) if devices is not None else "cpu",
        "Transformers version": version("transformers"),
        "Accelerate version": version("accelerate"),
        "Accelerate config": accelerate_config_str,
        "Datasets version": version("datasets"),
        "HF Hub version": version("huggingface_hub"),
        "bitsandbytes version": _version_or_missing("bitsandbytes", is_bitsandbytes_available()),
        "DeepSpeed version": _version_or_missing("deepspeed", is_deepspeed_available()),
        "Liger-Kernel version": _version_or_missing("liger_kernel", is_liger_kernel_available()),
        "LLM-Blender version": _version_or_missing("llm_blender", is_llm_blender_available()),
        "OpenAI version": _version_or_missing("openai", is_openai_available()),
        "PEFT version": _version_or_missing("peft", is_peft_available()),
        "vLLM version": _version_or_missing("vllm", is_vllm_available()),
    }

    info_str = "\n".join(f"- {prop}: {val}" for prop, val in info.items())
    print(f"\nCopy-paste the following information when reporting an issue:\n\n{info_str}\n")  # noqa
class TestJudges(TrlTestCase):
    """Tests for the experimental judge implementations."""

    def _get_prompts_and_pairwise_completions(self):
        # Two candidate completions per prompt; expected best indices are [0, 1].
        return (
            ["The capital of France is", "The biggest planet in the solar system is"],
            [["Paris", "Marseille"], ["Saturn", "Jupiter"]],
        )

    def _get_prompts_and_single_completions(self):
        # One completion per prompt, used by the binary judges.
        return (
            ["What's the capital of France?", "What's the color of the sky?"],
            ["Marseille", "blue"],
        )

    def test_all_true_judge(self):
        judge = AllTrueJudge(judges=[RandomBinaryJudge(), RandomBinaryJudge()])
        prompts, completions = self._get_prompts_and_single_completions()
        judgements = judge.judge(prompts=prompts, completions=completions)
        # One judgement per prompt; -1 marks an undecided/invalid combination.
        assert len(judgements) == 2
        assert all(j in {0, 1, -1} for j in judgements)

    @pytest.mark.skip(reason="This test needs to be run manually since it requires a valid Hugging Face API key.")
    def test_hugging_face_judge(self):
        judge = HfPairwiseJudge()
        prompts, completions = self._get_prompts_and_pairwise_completions()
        ranks = judge.judge(prompts=prompts, completions=completions)
        assert len(ranks) == 2
        assert all(isinstance(rank, int) for rank in ranks)
        assert ranks == [0, 1]

    def load_pair_rm_judge(self):
        # When using concurrent tests, PairRM may fail to load the model while another
        # job is still downloading. Retry a few times as a workaround.
        for _attempt in range(5):
            try:
                return PairRMJudge()
            except ValueError:
                time.sleep(5)
        raise ValueError("Failed to load PairRMJudge")

    @require_llm_blender
    @pytest.mark.skipif(
        sys.version_info[:3] == (3, 13, 8), reason="Python 3.13.8 has a bug in inspect.BlockFinder (cpython GH-139783)"
    )
    def test_pair_rm_judge(self):
        judge = self.load_pair_rm_judge()
        prompts, completions = self._get_prompts_and_pairwise_completions()
        ranks = judge.judge(prompts=prompts, completions=completions)
        assert len(ranks) == 2
        assert all(isinstance(rank, int) for rank in ranks)
        assert ranks == [0, 1]

    @require_llm_blender
    @pytest.mark.skipif(
        sys.version_info[:3] == (3, 13, 8), reason="Python 3.13.8 has a bug in inspect.BlockFinder (cpython GH-139783)"
    )
    def test_pair_rm_judge_return_scores(self):
        judge = self.load_pair_rm_judge()
        prompts, completions = self._get_prompts_and_pairwise_completions()
        probs = judge.judge(prompts=prompts, completions=completions, return_scores=True)
        # With return_scores=True the judge yields probabilities, one per prompt.
        assert len(probs) == 2
        assert all(isinstance(prob, float) for prob in probs)
        assert all(0 <= prob <= 1 for prob in probs)