├── trl
├── py.typed
├── accelerate_configs
│ ├── single_gpu.yaml
│ ├── multi_gpu.yaml
│ ├── zero1.yaml
│ ├── zero2.yaml
│ ├── zero3.yaml
│ ├── fsdp2.yaml
│ └── fsdp1.yaml
├── extras
│ ├── __init__.py
│ ├── dataset_formatting.py
│ └── profiling.py
├── experimental
│ ├── gspo_token
│ │ └── __init__.py
│ ├── bco
│ │ └── __init__.py
│ ├── gfpo
│ │ ├── __init__.py
│ │ └── gfpo_config.py
│ ├── papo
│ │ ├── __init__.py
│ │ └── papo_config.py
│ ├── bema_for_ref_model
│ │ ├── __init__.py
│ │ └── dpo_trainer.py
│ ├── openenv
│ │ └── __init__.py
│ ├── cpo
│ │ └── __init__.py
│ ├── gkd
│ │ └── __init__.py
│ ├── kto
│ │ └── __init__.py
│ ├── prm
│ │ └── __init__.py
│ ├── xpo
│ │ ├── __init__.py
│ │ └── xpo_config.py
│ ├── gold
│ │ └── __init__.py
│ ├── orpo
│ │ └── __init__.py
│ ├── minillm
│ │ └── __init__.py
│ ├── nash_md
│ │ ├── __init__.py
│ │ └── nash_md_config.py
│ ├── online_dpo
│ │ └── __init__.py
│ ├── grpo_with_replay_buffer
│ │ ├── __init__.py
│ │ └── grpo_with_replay_buffer_config.py
│ ├── ppo
│ │ └── __init__.py
│ ├── judges
│ │ └── __init__.py
│ └── __init__.py
├── scripts
│ ├── __init__.py
│ └── env.py
├── rewards
│ ├── __init__.py
│ ├── format_rewards.py
│ └── other_rewards.py
├── trainer
│ ├── prm_config.py
│ ├── xpo_config.py
│ ├── prm_trainer.py
│ ├── xpo_trainer.py
│ ├── bco_config.py
│ ├── cpo_config.py
│ ├── gkd_config.py
│ ├── ppo_config.py
│ ├── orpo_config.py
│ ├── bco_trainer.py
│ ├── cpo_trainer.py
│ ├── gkd_trainer.py
│ ├── nash_md_config.py
│ ├── ppo_trainer.py
│ ├── orpo_trainer.py
│ ├── nash_md_trainer.py
│ ├── online_dpo_config.py
│ ├── kto_config.py
│ ├── online_dpo_trainer.py
│ ├── kto_trainer.py
│ └── base_trainer.py
├── models
│ ├── __init__.py
│ ├── modeling_base.py
│ └── modeling_value_head.py
├── templates
│ ├── rm_model_card.md
│ └── lm_model_card.md
└── mergekit_utils.py
├── VERSION
├── requirements.txt
├── assets
├── logo-dark.png
└── logo-light.png
├── docs
└── source
│ ├── winrate_callback.md
│ ├── merge_model_callback.md
│ ├── others.md
│ ├── model_utils.md
│ ├── callbacks.md
│ ├── script_utils.md
│ ├── rewards.md
│ ├── chat_template_utils.md
│ ├── installation.md
│ ├── gspo_token.md
│ ├── data_utils.md
│ ├── bema_for_reference_model.md
│ ├── deepspeed_integration.md
│ ├── experimental_overview.md
│ ├── gfpo.md
│ ├── grpo_with_replay_buffer.md
│ ├── trackio_integration.md
│ ├── papo_trainer.md
│ ├── liger_kernel_integration.md
│ ├── use_model.md
│ ├── judges.md
│ └── _toctree.yml
├── examples
├── README.md
├── accelerate_configs
│ ├── single_gpu.yaml
│ ├── multi_gpu.yaml
│ ├── deepspeed_zero1.yaml
│ ├── deepspeed_zero2.yaml
│ ├── deepspeed_zero3.yaml
│ ├── fsdp2.yaml
│ ├── fsdp1.yaml
│ ├── context_parallel_2gpu.yaml
│ └── alst_ulysses_4gpu.yaml
├── cli_configs
│ └── example_config.yaml
├── scripts
│ ├── dpo.py
│ ├── sft.py
│ ├── sft_gemma3.py
│ ├── sft_gpt_oss.py
│ └── rloo.py
├── notebooks
│ └── README.md
└── datasets
│ └── deepmath_103k.py
├── MANIFEST.in
├── docker
├── trl
│ └── Dockerfile
└── trl-dev
│ └── Dockerfile
├── .github
├── workflows
│ ├── issue_auto_labeller.yml
│ ├── upload_pr_documentation.yml
│ ├── build_documentation.yml
│ ├── trufflehog.yml
│ ├── build_pr_documentation.yml
│ ├── codeQL.yml
│ ├── clear_cache.yml
│ ├── publish.yml
│ ├── tests-experimental.yml
│ ├── tests_latest.yml
│ ├── docker-build.yml
│ └── slow-tests.yml
├── codeql
│ └── custom-queries.qls
├── ISSUE_TEMPLATE
│ ├── feature-request.yml
│ ├── new-trainer-addition.yml
│ └── bug-report.yml
└── PULL_REQUEST_TEMPLATE.md
├── .pre-commit-config.yaml
├── tests
├── __init__.py
├── experimental
│ ├── __init__.py
│ ├── test_minillm_trainer.py
│ ├── test_gspo_token_trainer.py
│ ├── test_merge_model_callback.py
│ └── test_judges.py
├── testing_constants.py
├── conftest.py
├── test_model_utils.py
├── test_rich_progress_callback.py
└── test_collators.py
├── Makefile
├── CITATION.cff
├── .gitignore
└── scripts
└── add_copyrights.py
/trl/py.typed:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/VERSION:
--------------------------------------------------------------------------------
1 | 0.27.0.dev0
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | accelerate>=1.4.0
2 | datasets>=3.0.0
3 | transformers>=4.56.1
--------------------------------------------------------------------------------
/assets/logo-dark.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/trl/HEAD/assets/logo-dark.png
--------------------------------------------------------------------------------
/assets/logo-light.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huggingface/trl/HEAD/assets/logo-light.png
--------------------------------------------------------------------------------
/docs/source/winrate_callback.md:
--------------------------------------------------------------------------------
1 | # WinRateCallback
2 |
3 | [[autodoc]] experimental.winrate_callback.WinRateCallback
4 |
--------------------------------------------------------------------------------
/docs/source/merge_model_callback.md:
--------------------------------------------------------------------------------
1 | # MergeModelCallback
2 |
3 | [[autodoc]] experimental.merge_model_callback.MergeModelCallback
4 |
--------------------------------------------------------------------------------
/examples/README.md:
--------------------------------------------------------------------------------
1 | # Examples
2 |
3 | Please check out https://huggingface.co/docs/trl/example_overview for documentation on our examples.
4 |
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
1 | include LICENSE
2 | include CONTRIBUTING.md
3 | include README.md
4 | include trl/accelerate_configs/*.yaml
5 | include trl/templates/*.md
6 | recursive-exclude * __pycache__
7 | prune tests
8 |
--------------------------------------------------------------------------------
/docs/source/others.md:
--------------------------------------------------------------------------------
1 | # Other
2 |
3 | ## profiling_decorator
4 |
5 | [[autodoc]] extras.profiling.profiling_decorator
6 |
7 | ## profiling_context
8 |
9 | [[autodoc]] extras.profiling.profiling_context
10 |
--------------------------------------------------------------------------------
/docker/trl/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
2 | RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/*
3 | RUN pip install --upgrade pip uv
4 | RUN uv pip install --system trl[liger,peft,vlm] kernels trackio
--------------------------------------------------------------------------------
/docs/source/model_utils.md:
--------------------------------------------------------------------------------
1 | # Model Utilities
2 |
3 | ## get_act_offloading_ctx_manager
4 |
5 | [[autodoc]] models.get_act_offloading_ctx_manager
6 |
7 | ## disable_gradient_checkpointing
8 |
9 | [[autodoc]] models.utils.disable_gradient_checkpointing
10 |
11 | ## create_reference_model
12 |
13 | [[autodoc]] create_reference_model
14 |
--------------------------------------------------------------------------------
/docker/trl-dev/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
2 | RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/*
3 | RUN pip install --upgrade pip uv
4 | RUN uv pip install --system --no-cache "git+https://github.com/huggingface/trl.git#egg=trl[liger,peft,vlm]"
5 | RUN uv pip install --system kernels liger_kernel peft trackio
--------------------------------------------------------------------------------
/docs/source/callbacks.md:
--------------------------------------------------------------------------------
1 | # Callbacks
2 |
3 | ## SyncRefModelCallback
4 |
5 | [[autodoc]] SyncRefModelCallback
6 |
7 | ## RichProgressCallback
8 |
9 | [[autodoc]] RichProgressCallback
10 |
11 | ## LogCompletionsCallback
12 |
13 | [[autodoc]] LogCompletionsCallback
14 |
15 | ## BEMACallback
16 |
17 | [[autodoc]] BEMACallback
18 |
19 | ## WeaveCallback
20 |
21 | [[autodoc]] WeaveCallback
22 |
--------------------------------------------------------------------------------
/.github/workflows/issue_auto_labeller.yml:
--------------------------------------------------------------------------------
1 | name: "Hugging Face Issue Labeler"
2 | on:
3 | issues:
4 | types: opened
5 |
6 | jobs:
7 | triage:
8 | runs-on: ubuntu-latest
9 | permissions:
10 | issues: write
11 | steps:
12 | - uses: actions/checkout@v3
13 | - uses: August-murr/auto-labeler@main
14 | with:
15 | hf-api-key: ${{ secrets.CI_HF_API_TOKEN }}
16 |
--------------------------------------------------------------------------------
/trl/accelerate_configs/single_gpu.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | distributed_type: "NO"
4 | downcast_bf16: 'no'
5 | gpu_ids: all
6 | machine_rank: 0
7 | main_training_function: main
8 | mixed_precision: 'bf16'
9 | num_machines: 1
10 | num_processes: 1
11 | rdzv_backend: static
12 | same_network: true
13 | tpu_env: []
14 | tpu_use_cluster: false
15 | tpu_use_sudo: false
16 | use_cpu: false
17 |
--------------------------------------------------------------------------------
/examples/accelerate_configs/single_gpu.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | distributed_type: "NO"
4 | downcast_bf16: 'no'
5 | gpu_ids: all
6 | machine_rank: 0
7 | main_training_function: main
8 | mixed_precision: 'bf16'
9 | num_machines: 1
10 | num_processes: 1
11 | rdzv_backend: static
12 | same_network: true
13 | tpu_env: []
14 | tpu_use_cluster: false
15 | tpu_use_sudo: false
16 | use_cpu: false
17 |
--------------------------------------------------------------------------------
/trl/accelerate_configs/multi_gpu.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | distributed_type: MULTI_GPU
4 | downcast_bf16: 'no'
5 | gpu_ids: all
6 | machine_rank: 0
7 | main_training_function: main
8 | mixed_precision: 'bf16'
9 | num_machines: 1
10 | num_processes: 8
11 | rdzv_backend: static
12 | same_network: true
13 | tpu_env: []
14 | tpu_use_cluster: false
15 | tpu_use_sudo: false
16 | use_cpu: false
17 |
--------------------------------------------------------------------------------
/examples/accelerate_configs/multi_gpu.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | distributed_type: MULTI_GPU
4 | downcast_bf16: 'no'
5 | gpu_ids: all
6 | machine_rank: 0
7 | main_training_function: main
8 | mixed_precision: 'bf16'
9 | num_machines: 1
10 | num_processes: 8
11 | rdzv_backend: static
12 | same_network: true
13 | tpu_env: []
14 | tpu_use_cluster: false
15 | tpu_use_sudo: false
16 | use_cpu: false
17 |
--------------------------------------------------------------------------------
/.github/workflows/upload_pr_documentation.yml:
--------------------------------------------------------------------------------
1 | name: Upload PR Documentation
2 |
3 | on:
4 | workflow_run:
5 | workflows: ["Build PR Documentation"]
6 | types:
7 | - completed
8 |
9 | jobs:
10 | build:
11 | uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
12 | with:
13 | package_name: trl
14 | secrets:
15 | hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
16 | comment_bot_token: ${{ secrets.COMMENT_BOT_TOKEN }}
--------------------------------------------------------------------------------
/docs/source/script_utils.md:
--------------------------------------------------------------------------------
1 | # Scripts Utilities
2 |
3 | ## ScriptArguments
4 |
5 | [[autodoc]] ScriptArguments
6 |
7 | ## TrlParser
8 |
9 | [[autodoc]] TrlParser
10 | - parse_args_and_config
11 | - parse_args_into_dataclasses
12 | - set_defaults_with_config
13 |
14 | ## get_dataset
15 |
16 | [[autodoc]] get_dataset
17 |
18 | ## DatasetConfig
19 |
20 | [[autodoc]] scripts.utils.DatasetConfig
21 |
22 | ## DatasetMixtureConfig
23 |
24 | [[autodoc]] DatasetMixtureConfig
25 |
--------------------------------------------------------------------------------
/examples/cli_configs/example_config.yaml:
--------------------------------------------------------------------------------
1 | # This is an example configuration file of TRL CLI, you can use it for
2 | # SFT like that: `trl sft --config config.yaml --output_dir test-sft`
3 | # The YAML file supports environment variables by adding an `env` field
4 | # as below
5 |
6 | # env:
7 | # CUDA_VISIBLE_DEVICES: 0
8 |
9 | model_name_or_path:
10 | Qwen/Qwen2.5-0.5B
11 | dataset_name:
12 | stanfordnlp/imdb
13 | report_to:
14 | none
15 | learning_rate:
16 | 0.0001
17 | lr_scheduler_type:
18 | cosine
19 |
--------------------------------------------------------------------------------
/.github/workflows/build_documentation.yml:
--------------------------------------------------------------------------------
1 | name: Build documentation
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | - doc-builder*
8 | - v*-release
9 |
10 | env:
11 | TRL_EXPERIMENTAL_SILENCE: 1
12 |
13 | jobs:
14 | build:
15 | uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main
16 | with:
17 | commit_sha: ${{ github.sha }}
18 | package: trl
19 | version_tag_suffix: ""
20 | secrets:
21 | hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }}
22 |
--------------------------------------------------------------------------------
/docs/source/rewards.md:
--------------------------------------------------------------------------------
1 | # Reward Functions
2 |
3 | This module contains some useful reward functions, primarily intended for use with the [`GRPOTrainer`] and [`RLOOTrainer`].
4 |
5 | ## accuracy_reward
6 |
7 | [[autodoc]] rewards.accuracy_reward
8 |
9 | ## reasoning_accuracy_reward
10 |
11 | [[autodoc]] rewards.reasoning_accuracy_reward
12 |
13 | ## think_format_reward
14 |
15 | [[autodoc]] rewards.think_format_reward
16 |
17 | ## get_soft_overlong_punishment
18 |
19 | [[autodoc]] rewards.get_soft_overlong_punishment
20 |
--------------------------------------------------------------------------------
/docs/source/chat_template_utils.md:
--------------------------------------------------------------------------------
1 | # Chat template utilities
2 |
3 | ## clone_chat_template
4 |
5 | [[autodoc]] clone_chat_template
6 |
7 | ## add_response_schema
8 |
9 | [[autodoc]] chat_template_utils.add_response_schema
10 |
11 | ## is_chat_template_prefix_preserving
12 |
13 | [[autodoc]] chat_template_utils.is_chat_template_prefix_preserving
14 |
15 | ## get_training_chat_template
16 |
17 | [[autodoc]] chat_template_utils.get_training_chat_template
18 |
19 | ## parse_response
20 |
21 | [[autodoc]] chat_template_utils.parse_response
22 |
--------------------------------------------------------------------------------
/trl/accelerate_configs/zero1.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | deepspeed_multinode_launcher: standard
5 | gradient_accumulation_steps: 1
6 | zero3_init_flag: false
7 | zero_stage: 1
8 | distributed_type: DEEPSPEED
9 | downcast_bf16: 'no'
10 | machine_rank: 0
11 | main_training_function: main
12 | mixed_precision: 'bf16'
13 | num_machines: 1
14 | num_processes: 8
15 | rdzv_backend: static
16 | same_network: true
17 | tpu_env: []
18 | tpu_use_cluster: false
19 | tpu_use_sudo: false
20 | use_cpu: false
21 |
--------------------------------------------------------------------------------
/examples/accelerate_configs/deepspeed_zero1.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | deepspeed_multinode_launcher: standard
5 | gradient_accumulation_steps: 1
6 | zero3_init_flag: false
7 | zero_stage: 1
8 | distributed_type: DEEPSPEED
9 | downcast_bf16: 'no'
10 | machine_rank: 0
11 | main_training_function: main
12 | mixed_precision: 'bf16'
13 | num_machines: 1
14 | num_processes: 8
15 | rdzv_backend: static
16 | same_network: true
17 | tpu_env: []
18 | tpu_use_cluster: false
19 | tpu_use_sudo: false
20 | use_cpu: false
21 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | repos:
2 | - repo: https://github.com/astral-sh/ruff-pre-commit
3 | rev: v0.13.3
4 | hooks:
5 | - id: ruff-check
6 | types_or: [ python, pyi ]
7 | args: [ --fix ]
8 | - id: ruff-format
9 | types_or: [ python, pyi ]
10 |
11 | # - repo: https://github.com/codespell-project/codespell
12 | # rev: v2.1.0
13 | # hooks:
14 | # - id: codespell
15 | # args:
16 | # - --ignore-words-list=nd,reacher,thist,ths,magent,ba
17 | # - --skip=docs/css/termynal.css,docs/js/termynal.js
18 |
--------------------------------------------------------------------------------
/.github/workflows/trufflehog.yml:
--------------------------------------------------------------------------------
1 | on:
2 | push:
3 |
4 | name: Secret Leaks
5 |
6 | jobs:
7 | trufflehog:
8 | runs-on: ubuntu-latest
9 | steps:
10 | - name: Checkout code
11 | uses: actions/checkout@v4
12 | with:
13 | fetch-depth: 0
14 | - name: Secret Scanning
15 | uses: trufflesecurity/trufflehog@853e1e8d249fd1e29d0fcc7280d29b03df3d643d
16 | with:
17 | # exclude buggy postgres detector that is causing false positives and not relevant to our codebase
18 | extra_args: --results=verified,unknown --exclude-detectors=postgres
19 |
--------------------------------------------------------------------------------
/trl/accelerate_configs/zero2.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | deepspeed_multinode_launcher: standard
5 | offload_optimizer_device: none
6 | offload_param_device: none
7 | zero3_init_flag: false
8 | zero_stage: 2
9 | distributed_type: DEEPSPEED
10 | downcast_bf16: 'no'
11 | machine_rank: 0
12 | main_training_function: main
13 | mixed_precision: 'bf16'
14 | num_machines: 1
15 | num_processes: 8
16 | rdzv_backend: static
17 | same_network: true
18 | tpu_env: []
19 | tpu_use_cluster: false
20 | tpu_use_sudo: false
21 | use_cpu: false
22 |
--------------------------------------------------------------------------------
/examples/accelerate_configs/deepspeed_zero2.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | deepspeed_multinode_launcher: standard
5 | offload_optimizer_device: none
6 | offload_param_device: none
7 | zero3_init_flag: false
8 | zero_stage: 2
9 | distributed_type: DEEPSPEED
10 | downcast_bf16: 'no'
11 | machine_rank: 0
12 | main_training_function: main
13 | mixed_precision: 'bf16'
14 | num_machines: 1
15 | num_processes: 8
16 | rdzv_backend: static
17 | same_network: true
18 | tpu_env: []
19 | tpu_use_cluster: false
20 | tpu_use_sudo: false
21 | use_cpu: false
22 |
--------------------------------------------------------------------------------
/trl/accelerate_configs/zero3.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | deepspeed_multinode_launcher: standard
5 | offload_optimizer_device: none
6 | offload_param_device: none
7 | zero3_init_flag: true
8 | zero3_save_16bit_model: true
9 | zero_stage: 3
10 | distributed_type: DEEPSPEED
11 | downcast_bf16: 'no'
12 | machine_rank: 0
13 | main_training_function: main
14 | mixed_precision: bf16
15 | num_machines: 1
16 | num_processes: 8
17 | rdzv_backend: static
18 | same_network: true
19 | tpu_env: []
20 | tpu_use_cluster: false
21 | tpu_use_sudo: false
22 | use_cpu: false
23 |
--------------------------------------------------------------------------------
/examples/accelerate_configs/deepspeed_zero3.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | deepspeed_config:
4 | deepspeed_multinode_launcher: standard
5 | offload_optimizer_device: none
6 | offload_param_device: none
7 | zero3_init_flag: true
8 | zero3_save_16bit_model: true
9 | zero_stage: 3
10 | distributed_type: DEEPSPEED
11 | downcast_bf16: 'no'
12 | machine_rank: 0
13 | main_training_function: main
14 | mixed_precision: bf16
15 | num_machines: 1
16 | num_processes: 8
17 | rdzv_backend: static
18 | same_network: true
19 | tpu_env: []
20 | tpu_use_cluster: false
21 | tpu_use_sudo: false
22 | use_cpu: false
23 |
--------------------------------------------------------------------------------
/.github/workflows/build_pr_documentation.yml:
--------------------------------------------------------------------------------
1 | name: Build PR Documentation
2 |
3 | on:
4 | pull_request:
5 |
6 | env:
7 | TRL_EXPERIMENTAL_SILENCE: 1
8 |
9 | concurrency:
10 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
11 | cancel-in-progress: true
12 |
13 | jobs:
14 | build:
15 | if: github.event.pull_request.draft == false
16 | uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main
17 | with:
18 | commit_sha: ${{ github.event.pull_request.head.sha }}
19 | pr_number: ${{ github.event.number }}
20 | package: trl
21 | version_tag_suffix: ""
22 |
--------------------------------------------------------------------------------
/tests/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/trl/extras/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/tests/experimental/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
--------------------------------------------------------------------------------
/Makefile:
--------------------------------------------------------------------------------
1 | .PHONY: test precommit common_tests slow_tests tests_gpu test_experimental
2 |
3 | check_dirs := examples tests trl
4 |
5 | ACCELERATE_CONFIG_PATH = `pwd`/examples/accelerate_configs
6 |
7 | test:
8 | 	pytest -n auto -m "not slow and not low_priority" -s -v --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504|not less than or equal to 0.01)' tests/
9 |
10 | precommit:
11 | python scripts/add_copyrights.py
12 | pre-commit run --all-files
13 | doc-builder style trl tests docs/source --max_len 119
14 |
15 | slow_tests:
16 | pytest -m "slow" tests/ $(if $(IS_GITHUB_CI),--report-log "slow_tests.log",)
17 |
18 | test_experimental:
19 | pytest -k "experimental" -n auto -s -v
--------------------------------------------------------------------------------
/.github/workflows/codeQL.yml:
--------------------------------------------------------------------------------
1 | name: "CodeQL Analysis - Workflows"
2 |
3 | on:
4 | workflow_dispatch:
5 |
6 | jobs:
7 | analyze:
8 | name: "Analyze GitHub Workflows"
9 | runs-on: ubuntu-latest
10 | permissions:
11 | security-events: write
12 | actions: read
13 | contents: read
14 |
15 | steps:
16 | - name: "Checkout repository"
17 | uses: actions/checkout@v4
18 |
19 | - name: "Initialize CodeQL"
20 | uses: github/codeql-action/init@v2
21 | with:
22 | languages: "yaml"
23 | queries: +security-and-quality, ./.github/codeql/custom-queries.qls
24 |
25 | - name: "Perform CodeQL Analysis"
26 | uses: github/codeql-action/analyze@v2
27 |
--------------------------------------------------------------------------------
/trl/experimental/gspo_token/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .grpo_trainer import GRPOTrainer
16 |
--------------------------------------------------------------------------------
/trl/accelerate_configs/fsdp2.yaml:
--------------------------------------------------------------------------------
1 | # Requires accelerate 1.7.0 or higher
2 | compute_environment: LOCAL_MACHINE
3 | debug: false
4 | distributed_type: FSDP
5 | downcast_bf16: 'no'
6 | enable_cpu_affinity: false
7 | fsdp_config:
8 | fsdp_activation_checkpointing: false
9 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
10 | fsdp_cpu_ram_efficient_loading: true
11 | fsdp_offload_params: false
12 | fsdp_reshard_after_forward: true
13 | fsdp_state_dict_type: FULL_STATE_DICT
14 | fsdp_version: 2
15 | machine_rank: 0
16 | main_training_function: main
17 | mixed_precision: bf16
18 | num_machines: 1
19 | num_processes: 8
20 | rdzv_backend: static
21 | same_network: true
22 | tpu_env: []
23 | tpu_use_cluster: false
24 | tpu_use_sudo: false
25 | use_cpu: false
26 |
--------------------------------------------------------------------------------
/examples/accelerate_configs/fsdp2.yaml:
--------------------------------------------------------------------------------
1 | # Requires accelerate 1.7.0 or higher
2 | compute_environment: LOCAL_MACHINE
3 | debug: false
4 | distributed_type: FSDP
5 | downcast_bf16: 'no'
6 | enable_cpu_affinity: false
7 | fsdp_config:
8 | fsdp_activation_checkpointing: false
9 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
10 | fsdp_cpu_ram_efficient_loading: true
11 | fsdp_offload_params: false
12 | fsdp_reshard_after_forward: true
13 | fsdp_state_dict_type: FULL_STATE_DICT
14 | fsdp_version: 2
15 | machine_rank: 0
16 | main_training_function: main
17 | mixed_precision: bf16
18 | num_machines: 1
19 | num_processes: 8
20 | rdzv_backend: static
21 | same_network: true
22 | tpu_env: []
23 | tpu_use_cluster: false
24 | tpu_use_sudo: false
25 | use_cpu: false
26 |
--------------------------------------------------------------------------------
/trl/experimental/bco/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .bco_config import BCOConfig
16 | from .bco_trainer import BCOTrainer
17 |
--------------------------------------------------------------------------------
/trl/experimental/gfpo/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .gfpo_config import GFPOConfig
16 | from .gfpo_trainer import GFPOTrainer
17 |
--------------------------------------------------------------------------------
/trl/experimental/papo/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
from .papo_config import PAPOConfig
from .papo_trainer import PAPOTrainer


# Explicit public API, matching the convention used by the sibling
# trl.experimental subpackages (cpo, gkd, kto, prm, xpo, ...).
__all__ = ["PAPOConfig", "PAPOTrainer"]
18 |
--------------------------------------------------------------------------------
/trl/experimental/bema_for_ref_model/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
from .callback import BEMACallback
from .dpo_trainer import DPOTrainer


# Explicit public API, matching the convention used by the sibling
# trl.experimental subpackages (cpo, gkd, kto, prm, xpo, ...).
__all__ = ["BEMACallback", "DPOTrainer"]
17 |
--------------------------------------------------------------------------------
/trl/experimental/openenv/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .utils import generate_rollout_completions
16 |
17 |
18 | __all__ = ["generate_rollout_completions"]
19 |
--------------------------------------------------------------------------------
/trl/experimental/cpo/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .cpo_config import CPOConfig
16 | from .cpo_trainer import CPOTrainer
17 |
18 |
19 | __all__ = ["CPOConfig", "CPOTrainer"]
20 |
--------------------------------------------------------------------------------
/trl/experimental/gkd/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .gkd_config import GKDConfig
16 | from .gkd_trainer import GKDTrainer
17 |
18 |
19 | __all__ = ["GKDConfig", "GKDTrainer"]
20 |
--------------------------------------------------------------------------------
/trl/experimental/kto/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .kto_config import KTOConfig
16 | from .kto_trainer import KTOTrainer
17 |
18 |
19 | __all__ = ["KTOConfig", "KTOTrainer"]
20 |
--------------------------------------------------------------------------------
/trl/experimental/prm/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .prm_config import PRMConfig
16 | from .prm_trainer import PRMTrainer
17 |
18 |
19 | __all__ = ["PRMConfig", "PRMTrainer"]
20 |
--------------------------------------------------------------------------------
/trl/experimental/xpo/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .xpo_config import XPOConfig
16 | from .xpo_trainer import XPOTrainer
17 |
18 |
19 | __all__ = ["XPOConfig", "XPOTrainer"]
20 |
--------------------------------------------------------------------------------
/trl/experimental/gold/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .gold_config import GOLDConfig
16 | from .gold_trainer import GOLDTrainer
17 |
18 |
19 | __all__ = ["GOLDConfig", "GOLDTrainer"]
20 |
--------------------------------------------------------------------------------
/trl/experimental/orpo/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .orpo_config import ORPOConfig
16 | from .orpo_trainer import ORPOTrainer
17 |
18 |
19 | __all__ = ["ORPOConfig", "ORPOTrainer"]
20 |
--------------------------------------------------------------------------------
/tests/testing_constants.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
# Dummy account used when tests run against the Hugging Face CI Hub
# (see CI_HUB_ENDPOINT below) instead of the production Hub.
CI_HUB_USER = "__DUMMY_TRANSFORMERS_USER__"
CI_HUB_USER_FULL_NAME = "Dummy User"

# Staging/CI Hub endpoint targeted by the tests.
CI_HUB_ENDPOINT = "https://hub-ci.huggingface.co"
19 |
--------------------------------------------------------------------------------
/trl/experimental/minillm/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .minillm_config import MiniLLMConfig
16 | from .minillm_trainer import MiniLLMTrainer
17 |
18 |
19 | __all__ = ["MiniLLMConfig", "MiniLLMTrainer"]
20 |
--------------------------------------------------------------------------------
/trl/experimental/nash_md/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .nash_md_config import NashMDConfig
16 | from .nash_md_trainer import NashMDTrainer
17 |
18 |
19 | __all__ = ["NashMDConfig", "NashMDTrainer"]
20 |
--------------------------------------------------------------------------------
/trl/experimental/online_dpo/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .online_dpo_config import OnlineDPOConfig
16 | from .online_dpo_trainer import OnlineDPOTrainer
17 |
18 |
19 | __all__ = ["OnlineDPOConfig", "OnlineDPOTrainer"]
20 |
--------------------------------------------------------------------------------
/trl/accelerate_configs/fsdp1.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | distributed_type: FSDP
4 | downcast_bf16: 'no'
5 | enable_cpu_affinity: false
6 | fsdp_config:
7 | fsdp_activation_checkpointing: false
8 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
9 | fsdp_backward_prefetch: BACKWARD_PRE
10 | fsdp_cpu_ram_efficient_loading: true
11 | fsdp_forward_prefetch: true
12 | fsdp_offload_params: false
13 | fsdp_reshard_after_forward: FULL_SHARD
14 | fsdp_state_dict_type: FULL_STATE_DICT
15 | fsdp_sync_module_states: true
16 | fsdp_use_orig_params: true
17 | fsdp_version: 1
18 | machine_rank: 0
19 | main_training_function: main
20 | mixed_precision: bf16
21 | num_machines: 1
22 | num_processes: 8
23 | rdzv_backend: static
24 | same_network: true
25 | tpu_env: []
26 | tpu_use_cluster: false
27 | tpu_use_sudo: false
28 | use_cpu: false
29 |
--------------------------------------------------------------------------------
/trl/experimental/grpo_with_replay_buffer/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
from .grpo_with_replay_buffer_config import GRPOWithReplayBufferConfig
from .grpo_with_replay_buffer_trainer import GRPOWithReplayBufferTrainer, ReplayBuffer


# Explicit public API, matching the convention used by the sibling
# trl.experimental subpackages (cpo, gkd, kto, prm, xpo, ...).
__all__ = ["GRPOWithReplayBufferConfig", "GRPOWithReplayBufferTrainer", "ReplayBuffer"]
17 |
--------------------------------------------------------------------------------
/examples/accelerate_configs/fsdp1.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | debug: false
3 | distributed_type: FSDP
4 | downcast_bf16: 'no'
5 | enable_cpu_affinity: false
6 | fsdp_config:
7 | fsdp_activation_checkpointing: false
8 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
9 | fsdp_backward_prefetch: BACKWARD_PRE
10 | fsdp_cpu_ram_efficient_loading: true
11 | fsdp_forward_prefetch: true
12 | fsdp_offload_params: false
13 | fsdp_reshard_after_forward: FULL_SHARD
14 | fsdp_state_dict_type: FULL_STATE_DICT
15 | fsdp_sync_module_states: true
16 | fsdp_use_orig_params: true
17 | fsdp_version: 1
18 | machine_rank: 0
19 | main_training_function: main
20 | mixed_precision: bf16
21 | num_machines: 1
22 | num_processes: 8
23 | rdzv_backend: static
24 | same_network: true
25 | tpu_env: []
26 | tpu_use_cluster: false
27 | tpu_use_sudo: false
28 | use_cpu: false
29 |
--------------------------------------------------------------------------------
/examples/scripts/dpo.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | ###############################################################################################
16 | # This file has been moved to https://github.com/huggingface/trl/blob/main/trl/scripts/dpo.py #
17 | ###############################################################################################
18 |
--------------------------------------------------------------------------------
/examples/scripts/sft.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | ###############################################################################################
16 | # This file has been moved to https://github.com/huggingface/trl/blob/main/trl/scripts/sft.py #
17 | ###############################################################################################
18 |
--------------------------------------------------------------------------------
/.github/workflows/clear_cache.yml:
--------------------------------------------------------------------------------
1 | name: "Cleanup Cache"
2 |
3 | on:
4 | workflow_dispatch:
5 | schedule:
6 | - cron: "0 0 * * *"
7 |
8 | jobs:
9 | cleanup:
10 | runs-on: ubuntu-latest
11 | steps:
12 | - name: Check out code
13 | uses: actions/checkout@v4
14 |
15 | - name: Cleanup
16 | run: |
17 | gh extension install actions/gh-actions-cache
18 |
19 | REPO=${{ github.repository }}
20 |
21 |         echo "Fetching list of cache keys"
22 | cacheKeysForPR=$(gh actions-cache list -R $REPO | cut -f 1 )
23 |
24 | ## Setting this to not fail the workflow while deleting cache keys.
25 | set +e
26 | echo "Deleting caches..."
27 | for cacheKey in $cacheKeysForPR
28 | do
29 | gh actions-cache delete $cacheKey -R $REPO --confirm
30 | done
31 | echo "Done"
32 | env:
33 | GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
34 |
--------------------------------------------------------------------------------
/docs/source/installation.md:
--------------------------------------------------------------------------------
1 | # Installation
2 |
3 | You can install TRL either from PyPI or from source:
4 |
5 | ## PyPI
6 |
7 | Install the library with pip or [uv](https://docs.astral.sh/uv/):
8 |
9 |
10 |
11 |
12 | uv is a fast Rust-based Python package and project manager. Refer to [Installation](https://docs.astral.sh/uv/getting-started/installation/) for installation instructions.
13 |
14 | ```bash
15 | uv pip install trl
16 | ```
17 |
18 |
19 |
20 |
21 | ```bash
22 | pip install trl
23 | ```
24 |
25 |
26 |
27 |
28 | ## Source
29 |
30 | You can also install the latest version from source. First clone the repo and then run the installation with `pip`:
31 |
32 | ```bash
33 | git clone https://github.com/huggingface/trl.git
34 | cd trl/
35 | pip install -e .
36 | ```
37 |
38 | If you want the development install you can replace the pip install with the following:
39 |
40 | ```bash
41 | pip install -e ".[dev]"
42 | ```
43 |
--------------------------------------------------------------------------------
/examples/accelerate_configs/context_parallel_2gpu.yaml:
--------------------------------------------------------------------------------
1 | # Context Parallelism with FSDP for 2 GPUs
2 | compute_environment: LOCAL_MACHINE
3 | debug: false
4 | distributed_type: FSDP
5 | downcast_bf16: 'no'
6 | enable_cpu_affinity: false
7 | fsdp_config:
8 | fsdp_activation_checkpointing: true # Enable activation checkpointing for memory efficiency
9 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
10 | fsdp_cpu_ram_efficient_loading: true
11 | fsdp_offload_params: false
12 | fsdp_reshard_after_forward: true
13 | fsdp_state_dict_type: FULL_STATE_DICT
14 | fsdp_version: 2
15 | machine_rank: 0
16 | main_training_function: main
17 | mixed_precision: bf16
18 | num_machines: 1
19 | num_processes: 2 # Number of GPUs
20 | rdzv_backend: static
21 | same_network: true
22 | tpu_env: []
23 | tpu_use_cluster: false
24 | tpu_use_sudo: false
25 | use_cpu: false
26 | parallelism_config:
27 | parallelism_config_dp_replicate_size: 1
28 | parallelism_config_dp_shard_size: 1
29 | parallelism_config_tp_size: 1
30 | parallelism_config_cp_size: 2 # Context parallel size
31 |
--------------------------------------------------------------------------------
/docs/source/gspo_token.md:
--------------------------------------------------------------------------------
1 | # GSPO-token
2 |
3 | In the paper [Group Sequence Policy Optimization](https://huggingface.co/papers/2507.18071), the authors propose a token-level objective variant to GSPO, called GSPO-token. To use GSPO-token, you can use the `GRPOTrainer` class in `trl.experimental.gspo_token`.
4 |
5 | ## Usage
6 |
7 | ```python
8 | from trl.experimental.gspo_token import GRPOTrainer
9 | from trl import GRPOConfig
10 |
11 | training_args = GRPOConfig(
12 | importance_sampling_level="sequence_token",
13 | ...
14 | )
15 | ```
16 |
17 | > [!WARNING]
18 | > To leverage GSPO-token, the user will need to provide the per-token advantage \\( \hat{A_{i,t}} \\) for each token \\( t \\) in the sequence \\( i \\) (i.e., make \\( \hat{A_{i,t}} \\) vary with \\( t \\)—which isn't the case here, \\( \hat{A_{i,t}}=\hat{A_{i}} \\)). Otherwise, the GSPO-Token gradient is just equivalent to the original GSPO implementation.
19 |
20 | ## GRPOTrainer
21 |
22 | [[autodoc]] experimental.gspo_token.GRPOTrainer
23 | - train
24 | - save_model
25 | - push_to_hub
26 |
--------------------------------------------------------------------------------
/docs/source/data_utils.md:
--------------------------------------------------------------------------------
1 | # Data Utilities
2 |
3 | ## prepare_multimodal_messages
4 |
5 | [[autodoc]] prepare_multimodal_messages
6 |
7 | ## prepare_multimodal_messages_vllm
8 |
9 | [[autodoc]] prepare_multimodal_messages_vllm
10 |
11 | ## is_conversational
12 |
13 | [[autodoc]] is_conversational
14 |
15 | ## is_conversational_from_value
16 |
17 | [[autodoc]] is_conversational_from_value
18 |
19 | ## apply_chat_template
20 |
21 | [[autodoc]] apply_chat_template
22 |
23 | ## maybe_apply_chat_template
24 |
25 | [[autodoc]] maybe_apply_chat_template
26 |
27 | ## maybe_convert_to_chatml
28 |
29 | [[autodoc]] maybe_convert_to_chatml
30 |
31 | ## extract_prompt
32 |
33 | [[autodoc]] extract_prompt
34 |
35 | ## maybe_extract_prompt
36 |
37 | [[autodoc]] maybe_extract_prompt
38 |
39 | ## unpair_preference_dataset
40 |
41 | [[autodoc]] unpair_preference_dataset
42 |
43 | ## maybe_unpair_preference_dataset
44 |
45 | [[autodoc]] maybe_unpair_preference_dataset
46 |
47 | ## pack_dataset
48 |
49 | [[autodoc]] pack_dataset
50 |
51 | ## truncate_dataset
52 |
53 | [[autodoc]] truncate_dataset
54 |
--------------------------------------------------------------------------------
/.github/codeql/custom-queries.qls:
--------------------------------------------------------------------------------
1 | import codeql
2 |
3 | from WorkflowString interpolation, Workflow workflow
4 | where
5 | interpolation.getStringValue().matches("${{ github.event.issue.title }}") or
6 | interpolation.getStringValue().matches("${{ github.event.issue.body }}") or
7 | interpolation.getStringValue().matches("${{ github.event.pull_request.title }}") or
8 | interpolation.getStringValue().matches("${{ github.event.pull_request.body }}") or
9 | interpolation.getStringValue().matches("${{ github.event.review.body }}") or
10 | interpolation.getStringValue().matches("${{ github.event.comment.body }}") or
11 | interpolation.getStringValue().matches("${{ github.event.inputs.* }}") or
12 |   interpolation.getStringValue().matches("${{ github.event.head_commit.message }}") or
13 | interpolation.getStringValue().matches("${{ github.event.* }}") and
14 | (
15 | step.getKey() = "run" or // Injection in run
16 | step.getKey() = "env" or // Injection via env
17 | step.getKey() = "with" // Injection via with
18 | )
19 | select workflow, "🚨 Do not use directly as input of action"
20 |
--------------------------------------------------------------------------------
/trl/experimental/ppo/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .modeling_value_head import (
16 | AutoModelForCausalLMWithValueHead,
17 | AutoModelForSeq2SeqLMWithValueHead,
18 | PreTrainedModelWrapper,
19 | )
20 | from .ppo_config import PPOConfig
21 | from .ppo_trainer import PPOTrainer
22 |
23 |
24 | __all__ = [
25 | "AutoModelForCausalLMWithValueHead",
26 | "AutoModelForSeq2SeqLMWithValueHead",
27 | "PreTrainedModelWrapper",
28 | "PPOConfig",
29 | "PPOTrainer",
30 | ]
31 |
--------------------------------------------------------------------------------
/trl/experimental/judges/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from .judges import (
16 | AllTrueJudge,
17 | BaseBinaryJudge,
18 | BaseJudge,
19 | BasePairwiseJudge,
20 | BaseRankJudge,
21 | HfPairwiseJudge,
22 | OpenAIPairwiseJudge,
23 | PairRMJudge,
24 | )
25 |
26 |
27 | __all__ = [
28 | "AllTrueJudge",
29 | "BaseBinaryJudge",
30 | "BaseJudge",
31 | "BasePairwiseJudge",
32 | "BaseRankJudge",
33 | "HfPairwiseJudge",
34 | "OpenAIPairwiseJudge",
35 | "PairRMJudge",
36 | ]
37 |
--------------------------------------------------------------------------------
/.github/workflows/publish.yml:
--------------------------------------------------------------------------------
1 | name: Publish to PyPI
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | - v*-release
8 | paths:
9 | - "VERSION"
10 |
11 | jobs:
12 | publish:
13 | runs-on: ubuntu-latest
14 | steps:
15 | - uses: actions/checkout@v4
16 |
17 | - name: Read version
18 | id: get_version
19 | run: echo "version=$(cat VERSION)" >> $GITHUB_OUTPUT
20 |
21 |       - name: Debug - Show VERSION file content
22 | run: echo "Version is ${{ steps.get_version.outputs.version }}"
23 |
24 | - name: Set up Python
25 | uses: actions/setup-python@v4
26 | with:
27 | python-version: "3.x"
28 |
29 | - name: Install dependencies
30 | run: |
31 | python -m pip install --upgrade pip
32 | pip install build twine
33 |
34 | - name: Build package
35 | run: python -m build
36 |
37 | - name: Publish to PyPI
38 | if: ${{ !contains(steps.get_version.outputs.version, 'dev') }}
39 | env:
40 | TWINE_USERNAME: __token__
41 | TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
42 | run: |
43 | python -m twine upload dist/*
44 |
--------------------------------------------------------------------------------
/trl/scripts/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import TYPE_CHECKING
16 |
17 | from ..import_utils import _LazyModule
18 |
19 |
# Map of submodule name -> public names it exposes; consumed by _LazyModule to
# defer the actual submodule imports until an attribute is first accessed.
_import_structure = {
    "utils": ["DatasetMixtureConfig", "ScriptArguments", "TrlParser", "get_dataset", "init_zero_verbose"],
}

# Static type checkers need the real imports to resolve names; at runtime the
# module object is swapped for a _LazyModule proxy so that importing this
# package stays cheap until one of the exported names is actually used.
if TYPE_CHECKING:
    from .utils import DatasetMixtureConfig, ScriptArguments, TrlParser, get_dataset, init_zero_verbose
else:
    import sys

    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
30 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/feature-request.yml:
--------------------------------------------------------------------------------
1 | name: "\U0001F680 Feature request"
2 | description: Submit a proposal/request for a new TRL feature
3 | labels: [ "Feature request" ]
4 | body:
5 | - type: textarea
6 | id: feature-request
7 | validations:
8 | required: true
9 | attributes:
10 | label: Feature request
11 | description: |
12 | A clear and concise description of the feature proposal. Please provide a link to the paper and code in case they exist.
13 |
14 | - type: textarea
15 | id: motivation
16 | validations:
17 | required: true
18 | attributes:
19 | label: Motivation
20 | description: |
21 | Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too.
22 |
23 |
24 | - type: textarea
25 | id: contribution
26 | validations:
27 | required: true
28 | attributes:
29 | label: Your contribution
30 | description: |
        Is there any way that you could help, e.g. by submitting a PR? Make sure to read the CONTRIBUTING.md [guide](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md)
32 |
--------------------------------------------------------------------------------
/CITATION.cff:
--------------------------------------------------------------------------------
1 | cff-version: 1.2.0
2 | title: 'TRL: Transformer Reinforcement Learning'
3 | message: >-
4 | If you use this software, please cite it using the
5 | metadata from this file.
6 | type: software
7 | authors:
8 | - given-names: Leandro
9 | family-names: von Werra
10 | - given-names: Younes
11 | family-names: Belkada
12 | - given-names: Lewis
13 | family-names: Tunstall
14 | - given-names: Edward
15 | family-names: Beeching
16 | - given-names: Tristan
17 | family-names: Thrush
18 | - given-names: Nathan
19 | family-names: Lambert
20 | - given-names: Shengyi
21 | family-names: Huang
22 | - given-names: Kashif
23 | family-names: Rasul
24 | - given-names: Quentin
25 | family-names: Gallouédec
26 | repository-code: 'https://github.com/huggingface/trl'
27 | abstract: "With trl you can train transformer language models with Proximal Policy Optimization (PPO). The library is built on top of the transformers library by \U0001F917 Hugging Face. Therefore, pre-trained language models can be directly loaded via transformers. At this point, most decoder and encoder-decoder architectures are supported."
28 | keywords:
29 | - rlhf
30 | - deep-learning
31 | - pytorch
32 | - transformers
33 | license: Apache-2.0
34 | version: "0.26"
35 |
--------------------------------------------------------------------------------
/tests/conftest.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import gc
16 |
17 | import pytest
18 | import torch
19 |
20 |
@pytest.fixture(autouse=True)
def cleanup_gpu():
    """
    Free GPU memory after every test.

    Applied automatically to all tests (``autouse=True``). The teardown runs the
    garbage collector and flushes the CUDA caching allocator so that models and
    tensors left over from one test cannot trigger CUDA out-of-memory errors in
    the next, notably when tests run in parallel with pytest-xdist.
    """
    # Hand control to the test; everything below is teardown.
    yield
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()
35 |
--------------------------------------------------------------------------------
/trl/experimental/bema_for_ref_model/dpo_trainer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from ...trainer.dpo_trainer import DPOTrainer as _DPOTrainer
16 | from .callback import CallbackHandlerWithRefModel
17 |
18 |
class DPOTrainer(_DPOTrainer):
    """DPO trainer variant whose callback events also receive the reference model.

    After standard `DPOTrainer` initialization, the default callback handler is
    swapped for a `CallbackHandlerWithRefModel`, so callbacks (e.g. BEMA) can
    access `ref_model` when events fire.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Re-wrap the already-registered callbacks in a handler that forwards
        # the reference model alongside the usual trainer state.
        registered_callbacks = self.callback_handler.callbacks
        self.callback_handler = CallbackHandlerWithRefModel(
            registered_callbacks,
            self.model,
            self.ref_model,
            self.processing_class,
            self.optimizer,
            self.lr_scheduler,
        )
31 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/new-trainer-addition.yml:
--------------------------------------------------------------------------------
1 | name: "\U0001F31F New trainer addition"
2 | description: Submit a proposal/request to implement a new trainer for a post-training method
3 | labels: [ "New trainer" ]
4 |
5 | body:
6 | - type: textarea
7 | id: description-request
8 | validations:
9 | required: true
10 | attributes:
11 | label: Method description
12 | description: |
13 | Put any and all important information relative to the method
14 |
15 | - type: checkboxes
16 | id: information-tasks
17 | attributes:
18 | label: Open source status
19 | description: |
20 | Please note that if the method implementation isn't available or model weights with training datasets aren't available, we are less likely to implement it in `trl`.
21 | options:
22 | - label: "The method implementation is available"
23 | - label: "The model weights are available"
24 | - label: "The training datasets are available"
25 |
26 | - type: textarea
27 | id: additional-info
28 | attributes:
29 | label: Provide useful links for the implementation
30 | description: |
31 | Please provide information regarding the implementation, the weights, and the authors.
32 | Please mention the authors by @gh-username if you're aware of their usernames.
33 |
--------------------------------------------------------------------------------
/docs/source/bema_for_reference_model.md:
--------------------------------------------------------------------------------
1 | # BEMA for Reference Model
2 |
3 | This feature implements the BEMA algorithm to update the reference model during DPO training.
4 |
5 | ## Usage
6 |
7 | ```python
8 | from trl.experimental.bema_for_ref_model import BEMACallback, DPOTrainer
9 | from datasets import load_dataset
10 | from transformers import AutoModelForCausalLM, AutoTokenizer
11 |
12 |
13 | pref_dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")
14 | ref_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
15 |
16 | bema_callback = BEMACallback(update_ref_model=True)
17 |
18 | model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
19 | tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
20 | tokenizer.pad_token = tokenizer.eos_token
21 |
22 | trainer = DPOTrainer(
23 | model=model,
24 | ref_model=ref_model,
25 | train_dataset=pref_dataset,
26 | processing_class=tokenizer,
27 | callbacks=[bema_callback],
28 | )
29 |
30 | trainer.train()
31 | ```
32 |
33 | ## DPOTrainer
34 |
35 | [[autodoc]] experimental.bema_for_ref_model.DPOTrainer
36 | - train
37 | - save_model
38 | - push_to_hub
39 |
40 | ## BEMACallback
41 |
42 | [[autodoc]] experimental.bema_for_ref_model.BEMACallback
43 |
--------------------------------------------------------------------------------
/trl/rewards/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
import sys
from typing import TYPE_CHECKING

from ..import_utils import _LazyModule


# Public API of `trl.rewards`, keyed by the submodule that implements each name.
# `_LazyModule` uses this mapping to import a submodule only when one of its
# names is first accessed.
_import_structure = {
    "accuracy_rewards": ["accuracy_reward", "reasoning_accuracy_reward"],
    "format_rewards": ["think_format_reward"],
    "other_rewards": ["get_soft_overlong_punishment"],
}


if TYPE_CHECKING:
    # Real imports for static type checkers / IDEs only; never executed at runtime.
    from .accuracy_rewards import accuracy_reward, reasoning_accuracy_reward
    from .format_rewards import think_format_reward
    from .other_rewards import get_soft_overlong_punishment


else:
    # At runtime, replace this module with a lazy proxy so importing
    # `trl.rewards` stays cheap until an attribute is actually used.
    sys.modules[__name__] = _LazyModule(__name__, __file__, _import_structure, module_spec=__spec__)
36 |
--------------------------------------------------------------------------------
/trl/extras/dataset_formatting.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
import datasets
from datasets import Value
from packaging import version


# Expected dataset schemas used to recognize supported formats:
# - "chatml": a list of {"content", "role"} string messages
# - "instruction": a {"prompt", "completion"} string pair
# The schema construction is version-gated: `datasets` >= 4.0 provides an
# explicit `List` feature type, while earlier versions denote a sequence
# feature with a plain Python list of feature dicts.
if version.parse(datasets.__version__) >= version.parse("4.0.0"):
    from datasets import List

    FORMAT_MAPPING = {
        "chatml": List({"content": Value(dtype="string", id=None), "role": Value(dtype="string", id=None)}),
        "instruction": {"completion": Value(dtype="string", id=None), "prompt": Value(dtype="string", id=None)},
    }
else:
    # Pre-4.0 fallback: same schema expressed with the legacy list syntax.
    FORMAT_MAPPING = {
        "chatml": [{"content": Value(dtype="string", id=None), "role": Value(dtype="string", id=None)}],
        "instruction": {"completion": Value(dtype="string", id=None), "prompt": Value(dtype="string", id=None)},
    }
33 |
--------------------------------------------------------------------------------
/trl/trainer/prm_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import warnings
16 |
17 | from ..import_utils import suppress_experimental_warning
18 |
19 |
20 | with suppress_experimental_warning():
21 | from ..experimental.prm import PRMConfig as _PRMConfig
22 |
23 |
class PRMConfig(_PRMConfig):
    """Deprecated import shim for :class:`trl.experimental.prm.PRMConfig`.

    Emits a `FutureWarning` on initialization; this alias will be removed in
    TRL 0.29.
    """

    def __post_init__(self):
        # Fixed: the message previously pointed to the nonexistent module
        # `trl.experimental.xco`; the class lives in `trl.experimental.prm`.
        warnings.warn(
            "The `PRMConfig` is now located in `trl.experimental`. Please update your imports to "
            "`from trl.experimental.prm import PRMConfig`. The current import path will be removed and no longer "
            "supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223.",
            FutureWarning,
            stacklevel=2,
        )
        super().__post_init__()
34 |
--------------------------------------------------------------------------------
/trl/trainer/xpo_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import warnings
16 |
17 | from ..import_utils import suppress_experimental_warning
18 |
19 |
20 | with suppress_experimental_warning():
21 | from ..experimental.xpo import XPOConfig as _XPOConfig
22 |
23 |
class XPOConfig(_XPOConfig):
    """Deprecated import shim for :class:`trl.experimental.xpo.XPOConfig`.

    Emits a `FutureWarning` on initialization; this alias will be removed in
    TRL 0.29.
    """

    def __post_init__(self):
        # Fixed: the message previously pointed to the nonexistent module
        # `trl.experimental.xco`; the class lives in `trl.experimental.xpo`.
        warnings.warn(
            "The `XPOConfig` is now located in `trl.experimental`. Please update your imports to "
            "`from trl.experimental.xpo import XPOConfig`. The current import path will be removed and no longer "
            "supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223.",
            FutureWarning,
            stacklevel=2,
        )
        super().__post_init__()
34 |
--------------------------------------------------------------------------------
/trl/trainer/prm_trainer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import warnings
16 |
17 | from ..import_utils import suppress_experimental_warning
18 |
19 |
20 | with suppress_experimental_warning():
21 | from ..experimental.prm import PRMTrainer as _PRMTrainer
22 |
23 |
class PRMTrainer(_PRMTrainer):
    """Backward-compatibility alias for :class:`trl.experimental.prm.PRMTrainer`.

    Instantiating it warns that the import path moved to ``trl.experimental``
    and that this alias is scheduled for removal in TRL 0.29.
    """

    def __init__(self, *args, **kwargs):
        deprecation_notice = (
            "The `PRMTrainer` is now located in `trl.experimental`. Please update your imports to "
            "`from trl.experimental.prm import PRMTrainer`. The current import path will be removed and no longer "
            "supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223."
        )
        warnings.warn(deprecation_notice, category=FutureWarning, stacklevel=2)
        super().__init__(*args, **kwargs)
34 |
--------------------------------------------------------------------------------
/trl/trainer/xpo_trainer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import warnings
16 |
17 | from ..import_utils import suppress_experimental_warning
18 |
19 |
20 | with suppress_experimental_warning():
21 | from ..experimental.xpo import XPOTrainer as _XPOTrainer
22 |
23 |
class XPOTrainer(_XPOTrainer):
    """Backward-compatibility alias for :class:`trl.experimental.xpo.XPOTrainer`.

    Instantiating it warns that the import path moved to ``trl.experimental``
    and that this alias is scheduled for removal in TRL 0.29.
    """

    def __init__(self, *args, **kwargs):
        deprecation_notice = (
            "The `XPOTrainer` is now located in `trl.experimental`. Please update your imports to "
            "`from trl.experimental.xpo import XPOTrainer`. The current import path will be removed and no longer "
            "supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223."
        )
        warnings.warn(deprecation_notice, category=FutureWarning, stacklevel=2)
        super().__init__(*args, **kwargs)
34 |
--------------------------------------------------------------------------------
/trl/experimental/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """
16 | Experimental submodule for TRL.
17 |
18 | This submodule contains unstable or incubating features. Anything here may change (or be removed) in any release
19 | without deprecation. Use at your own risk.
20 |
21 | To silence this notice set environment variable TRL_EXPERIMENTAL_SILENCE=1.
22 | """
23 |
import os
import warnings

from ..import_utils import TRLExperimentalWarning


# Emit the instability notice once at import time, unless the user opted out by
# setting the TRL_EXPERIMENTAL_SILENCE environment variable to any non-empty value.
if not os.environ.get("TRL_EXPERIMENTAL_SILENCE"):
    warnings.warn(
        "You are importing from 'trl.experimental'. APIs here are unstable and may change or be removed without "
        "notice. Silence this warning by setting environment variable TRL_EXPERIMENTAL_SILENCE=1.",
        TRLExperimentalWarning,
        stacklevel=2,
    )
37 |
--------------------------------------------------------------------------------
/.github/PULL_REQUEST_TEMPLATE.md:
--------------------------------------------------------------------------------
1 | # What does this PR do?
2 |
3 |
12 |
13 |
14 |
15 | Fixes # (issue)
16 |
17 |
18 | ## Before submitting
19 | - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
20 | - [ ] Did you read the [contributor guideline](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md#create-a-pull-request),
21 | Pull Request section?
22 | - [ ] Was this discussed/approved via a GitHub issue? Please add a link
23 | to it if that's the case.
24 | - [ ] Did you make sure to update the documentation with your changes?
25 | - [ ] Did you write any new necessary tests?
26 |
27 |
28 | ## Who can review?
29 |
30 | Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
31 | members/contributors who may be interested in your PR.
--------------------------------------------------------------------------------
/trl/experimental/gfpo/gfpo_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from dataclasses import dataclass, field
16 |
17 | from ...trainer.grpo_config import GRPOConfig as _GRPOConfig
18 |
19 |
@dataclass
class GFPOConfig(_GRPOConfig):
    """Configuration for GFPO (Group-Filtered Policy Optimization) training.

    Extends `GRPOConfig` with group filtering: after sampling `num_generations`
    completions per prompt, only `num_remains_in_group` of them are kept by the
    group filter function.
    """

    # Number of completions kept per group after the group filter is applied.
    # `None` disables filtering.
    num_remains_in_group: int | None = field(
        default=None,
        metadata={
            "help": "Number of completions kept per group after the group filter function is applied. If set, it "
            "must be >= 2 and strictly less than `num_generations`. If `None`, no group filtering is applied."
        },
    )

    def __post_init__(self):
        super().__post_init__()

        if self.num_remains_in_group is not None:
            # Enforce the documented lower bound: with fewer than 2 survivors,
            # group-relative statistics degenerate. Previously this constraint
            # was documented but never validated.
            if self.num_remains_in_group < 2:
                raise ValueError(
                    f"`num_remains_in_group` must be >= 2 if given, but got {self.num_remains_in_group}."
                )
            # Filtering must actually discard completions, so the kept count has
            # to be strictly smaller than the generated count.
            if self.num_remains_in_group >= self.num_generations:
                raise ValueError(
                    f"`num_remains_in_group` ({self.num_remains_in_group}) must be less than "
                    f"`num_generations` ({self.num_generations})."
                )
36 |
--------------------------------------------------------------------------------
/trl/trainer/bco_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import warnings
16 | from dataclasses import dataclass
17 |
18 | from ..import_utils import suppress_experimental_warning
19 |
20 |
21 | with suppress_experimental_warning():
22 | from ..experimental.bco import BCOConfig as _BCOConfig
23 |
24 |
@dataclass
class BCOConfig(_BCOConfig):
    """Backward-compatibility alias for :class:`trl.experimental.bco.BCOConfig`.

    Initialization warns that the import path moved to ``trl.experimental`` and
    that this alias is scheduled for removal in TRL 0.29.
    """

    def __post_init__(self):
        deprecation_notice = (
            "The `BCOConfig` is now located in `trl.experimental`. Please update your imports to "
            "`from trl.experimental.bco import BCOConfig`. The current import path will be removed and no longer "
            "supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223."
        )
        warnings.warn(deprecation_notice, category=FutureWarning, stacklevel=2)
        super().__post_init__()
36 |
--------------------------------------------------------------------------------
/trl/trainer/cpo_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import warnings
16 | from dataclasses import dataclass
17 |
18 | from ..import_utils import suppress_experimental_warning
19 |
20 |
21 | with suppress_experimental_warning():
22 | from ..experimental.cpo import CPOConfig as _CPOConfig
23 |
24 |
@dataclass
class CPOConfig(_CPOConfig):
    """Backward-compatibility alias for :class:`trl.experimental.cpo.CPOConfig`.

    Initialization warns that the import path moved to ``trl.experimental`` and
    that this alias is scheduled for removal in TRL 0.29.
    """

    def __post_init__(self):
        deprecation_notice = (
            "The `CPOConfig` is now located in `trl.experimental`. Please update your imports to "
            "`from trl.experimental.cpo import CPOConfig`. The current import path will be removed and no longer "
            "supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223."
        )
        warnings.warn(deprecation_notice, category=FutureWarning, stacklevel=2)
        super().__post_init__()
36 |
--------------------------------------------------------------------------------
/trl/trainer/gkd_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import warnings
16 | from dataclasses import dataclass
17 |
18 | from ..import_utils import suppress_experimental_warning
19 |
20 |
21 | with suppress_experimental_warning():
22 | from ..experimental.gkd import GKDConfig as _GKDConfig
23 |
24 |
@dataclass
class GKDConfig(_GKDConfig):
    """Backward-compatibility alias for :class:`trl.experimental.gkd.GKDConfig`.

    Initialization warns that the import path moved to ``trl.experimental`` and
    that this alias is scheduled for removal in TRL 0.29.
    """

    def __post_init__(self):
        deprecation_notice = (
            "The `GKDConfig` is now located in `trl.experimental`. Please update your imports to "
            "`from trl.experimental.gkd import GKDConfig`. The current import path will be removed and no longer "
            "supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223."
        )
        warnings.warn(deprecation_notice, category=FutureWarning, stacklevel=2)
        super().__post_init__()
36 |
--------------------------------------------------------------------------------
/trl/trainer/ppo_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import warnings
16 | from dataclasses import dataclass
17 |
18 | from ..import_utils import suppress_experimental_warning
19 |
20 |
21 | with suppress_experimental_warning():
22 | from ..experimental.ppo import PPOConfig as _PPOConfig
23 |
24 |
@dataclass
class PPOConfig(_PPOConfig):
    """Backward-compatibility alias for :class:`trl.experimental.ppo.PPOConfig`.

    Initialization warns that the import path moved to ``trl.experimental`` and
    that this alias is scheduled for removal in TRL 0.29.
    """

    def __post_init__(self):
        deprecation_notice = (
            "The `PPOConfig` is now located in `trl.experimental`. Please update your imports to "
            "`from trl.experimental.ppo import PPOConfig`. The current import path will be removed and no longer "
            "supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223."
        )
        warnings.warn(deprecation_notice, category=FutureWarning, stacklevel=2)
        super().__post_init__()
36 |
--------------------------------------------------------------------------------
/trl/trainer/orpo_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import warnings
16 | from dataclasses import dataclass
17 |
18 | from ..import_utils import suppress_experimental_warning
19 |
20 |
21 | with suppress_experimental_warning():
22 | from ..experimental.orpo import ORPOConfig as _ORPOConfig
23 |
24 |
@dataclass
class ORPOConfig(_ORPOConfig):
    """Backward-compatibility alias for :class:`trl.experimental.orpo.ORPOConfig`.

    Initialization warns that the import path moved to ``trl.experimental`` and
    that this alias is scheduled for removal in TRL 0.29.
    """

    def __post_init__(self):
        deprecation_notice = (
            "The `ORPOConfig` is now located in `trl.experimental`. Please update your imports to "
            "`from trl.experimental.orpo import ORPOConfig`. The current import path will be removed and no longer "
            "supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223."
        )
        warnings.warn(deprecation_notice, category=FutureWarning, stacklevel=2)
        super().__post_init__()
36 |
--------------------------------------------------------------------------------
/trl/trainer/bco_trainer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import warnings
16 | from dataclasses import dataclass
17 |
18 | from ..import_utils import suppress_experimental_warning
19 |
20 |
21 | with suppress_experimental_warning():
22 | from ..experimental.bco import BCOTrainer as _BCOTrainer
23 |
24 |
class BCOTrainer(_BCOTrainer):
    """Deprecated import shim for :class:`trl.experimental.bco.BCOTrainer`.

    Emits a `FutureWarning` on instantiation; this alias will be removed in
    TRL 0.29.

    NOTE: the spurious `@dataclass` decorator was removed. The base class is a
    regular Trainer, not a dataclass, and decorating this shim injected a
    field-less generated `__eq__`/`__repr__` (making distinct trainer instances
    compare equal). This also matches the undecorated `PRMTrainer`/`XPOTrainer`
    shims.
    """

    def __init__(self, *args, **kwargs):
        warnings.warn(
            "The `BCOTrainer` is now located in `trl.experimental`. Please update your imports to "
            "`from trl.experimental.bco import BCOTrainer`. The current import path will be removed and no longer "
            "supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223.",
            FutureWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)
36 |
--------------------------------------------------------------------------------
/trl/trainer/cpo_trainer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import warnings
16 | from dataclasses import dataclass
17 |
18 | from ..import_utils import suppress_experimental_warning
19 |
20 |
21 | with suppress_experimental_warning():
22 | from ..experimental.cpo import CPOTrainer as _CPOTrainer
23 |
24 |
class CPOTrainer(_CPOTrainer):
    """Deprecated import shim for :class:`trl.experimental.cpo.CPOTrainer`.

    Emits a `FutureWarning` on instantiation; this alias will be removed in
    TRL 0.29.

    NOTE: the spurious `@dataclass` decorator was removed. The base class is a
    regular Trainer, not a dataclass, and decorating this shim injected a
    field-less generated `__eq__`/`__repr__` (making distinct trainer instances
    compare equal). This also matches the undecorated `PRMTrainer`/`XPOTrainer`
    shims.
    """

    def __init__(self, *args, **kwargs):
        warnings.warn(
            "The `CPOTrainer` is now located in `trl.experimental`. Please update your imports to "
            "`from trl.experimental.cpo import CPOTrainer`. The current import path will be removed and no longer "
            "supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223.",
            FutureWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)
36 |
--------------------------------------------------------------------------------
/trl/trainer/gkd_trainer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import warnings
16 | from dataclasses import dataclass
17 |
18 | from ..import_utils import suppress_experimental_warning
19 |
20 |
21 | with suppress_experimental_warning():
22 | from ..experimental.gkd import GKDTrainer as _GKDTrainer
23 |
24 |
class GKDTrainer(_GKDTrainer):
    """Deprecated alias for `trl.experimental.gkd.GKDTrainer`.

    Kept only so the old import path keeps working; emits a `FutureWarning`
    and otherwise defers entirely to the experimental implementation.
    """

    # NOTE: the spurious `@dataclass` decorator was removed. A trainer is not a
    # dataclass, and decorating a field-less subclass would inject a degenerate
    # `__eq__`/`__repr__` (every instance comparing equal) on top of the Trainer API.
    def __init__(self, *args, **kwargs):
        warnings.warn(
            "The `GKDTrainer` is now located in `trl.experimental`. Please update your imports to "
            "`from trl.experimental.gkd import GKDTrainer`. The current import path will be removed and no longer "
            "supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223.",
            FutureWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)
36 |
--------------------------------------------------------------------------------
/trl/trainer/nash_md_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import warnings
16 | from dataclasses import dataclass
17 |
18 | from ..import_utils import suppress_experimental_warning
19 |
20 |
21 | with suppress_experimental_warning():
22 | from ..experimental.nash_md import NashMDConfig as _NashMDConfig
23 |
24 |
@dataclass
class NashMDConfig(_NashMDConfig):
    """Deprecated alias for `trl.experimental.nash_md.NashMDConfig` (old import path)."""

    def __post_init__(self):
        # Warn on construction so users migrate to the experimental namespace.
        deprecation_msg = (
            "The `NashMDConfig` is now located in `trl.experimental`. Please update your imports to "
            "`from trl.experimental.nash_md import NashMDConfig`. The current import path will be removed and no "
            "longer supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223."
        )
        warnings.warn(deprecation_msg, FutureWarning, stacklevel=2)
        super().__post_init__()
36 |
--------------------------------------------------------------------------------
/trl/trainer/ppo_trainer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import warnings
16 | from dataclasses import dataclass
17 |
18 | from ..import_utils import suppress_experimental_warning
19 |
20 |
21 | with suppress_experimental_warning():
22 | from ..experimental.ppo import PPOTrainer as _PPOTrainer
23 |
24 |
class PPOTrainer(_PPOTrainer):
    """Deprecated alias for `trl.experimental.ppo.PPOTrainer`.

    Kept only so the old import path keeps working; emits a `FutureWarning`
    and otherwise defers entirely to the experimental implementation.
    """

    # NOTE: the spurious `@dataclass` decorator was removed. A trainer is not a
    # dataclass, and decorating a field-less subclass would inject a degenerate
    # `__eq__`/`__repr__` (every instance comparing equal) on top of the Trainer API.
    def __init__(self, *args, **kwargs):
        warnings.warn(
            "The `PPOTrainer` is now located in `trl.experimental`. Please update your imports to "
            "`from trl.experimental.ppo import PPOTrainer`. The current import path will be removed and no longer "
            "supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223.",
            FutureWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)
36 |
--------------------------------------------------------------------------------
/trl/trainer/orpo_trainer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import warnings
16 | from dataclasses import dataclass
17 |
18 | from ..import_utils import suppress_experimental_warning
19 |
20 |
21 | with suppress_experimental_warning():
22 | from ..experimental.orpo import ORPOTrainer as _ORPOTrainer
23 |
24 |
class ORPOTrainer(_ORPOTrainer):
    """Deprecated alias for `trl.experimental.orpo.ORPOTrainer`.

    Kept only so the old import path keeps working; emits a `FutureWarning`
    and otherwise defers entirely to the experimental implementation.
    """

    # NOTE: the spurious `@dataclass` decorator was removed. A trainer is not a
    # dataclass, and decorating a field-less subclass would inject a degenerate
    # `__eq__`/`__repr__` (every instance comparing equal) on top of the Trainer API.
    def __init__(self, *args, **kwargs):
        warnings.warn(
            "The `ORPOTrainer` is now located in `trl.experimental`. Please update your imports to "
            "`from trl.experimental.orpo import ORPOTrainer`. The current import path will be removed and no longer "
            "supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223.",
            FutureWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)
36 |
--------------------------------------------------------------------------------
/trl/trainer/nash_md_trainer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import warnings
16 | from dataclasses import dataclass
17 |
18 | from ..import_utils import suppress_experimental_warning
19 |
20 |
21 | with suppress_experimental_warning():
22 | from ..experimental.nash_md import NashMDTrainer as _NashMDTrainer
23 |
24 |
class NashMDTrainer(_NashMDTrainer):
    """Deprecated alias for `trl.experimental.nash_md.NashMDTrainer`.

    Kept only so the old import path keeps working; emits a `FutureWarning`
    and otherwise defers entirely to the experimental implementation.
    """

    # NOTE: the spurious `@dataclass` decorator was removed. A trainer is not a
    # dataclass, and decorating a field-less subclass would inject a degenerate
    # `__eq__`/`__repr__` (every instance comparing equal) on top of the Trainer API.
    def __init__(self, *args, **kwargs):
        warnings.warn(
            "The `NashMDTrainer` is now located in `trl.experimental`. Please update your imports to "
            "`from trl.experimental.nash_md import NashMDTrainer`. The current import path will be removed and no "
            "longer supported in TRL 0.29. For more information, see https://github.com/huggingface/trl/issues/4223.",
            FutureWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)
36 |
--------------------------------------------------------------------------------
/trl/trainer/online_dpo_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import warnings
16 | from dataclasses import dataclass
17 |
18 | from ..import_utils import suppress_experimental_warning
19 |
20 |
21 | with suppress_experimental_warning():
22 | from ..experimental.online_dpo import OnlineDPOConfig as _OnlineDPOConfig
23 |
24 |
@dataclass
class OnlineDPOConfig(_OnlineDPOConfig):
    """Deprecated alias for `trl.experimental.online_dpo.OnlineDPOConfig` (old import path)."""

    def __post_init__(self):
        # Warn on construction so users migrate to the experimental namespace.
        deprecation_msg = (
            "The `OnlineDPOConfig` is now located in `trl.experimental`. Please update your imports to "
            "`from trl.experimental.online_dpo import OnlineDPOConfig`. The current import path will be removed and "
            "no longer supported in TRL 0.29. For more information, see "
            "https://github.com/huggingface/trl/issues/4223."
        )
        warnings.warn(deprecation_msg, FutureWarning, stacklevel=2)
        super().__post_init__()
37 |
--------------------------------------------------------------------------------
/trl/experimental/grpo_with_replay_buffer/grpo_with_replay_buffer_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from dataclasses import dataclass, field
16 |
17 | from ...trainer.grpo_config import GRPOConfig
18 |
19 |
@dataclass
class GRPOWithReplayBufferConfig(GRPOConfig):
    """
    Configuration for `GRPOWithReplayBufferTrainer`. Inherits all `GRPOConfig` arguments.

    New Parameters:
        replay_buffer_size (`int`, *optional*, defaults to `64`):
            A cache that stores the rollouts with the highest advantage scores and variance per group. If a new
            group has 0 variance, it is replaced with a group sampled from the replay buffer.
    """

    # Docstring previously claimed a default of 0; the actual field default is 64.
    replay_buffer_size: int = field(
        default=64,
        metadata={
            "help": "A cache that stores the rollouts with the highest advantage scores and variance per group. If a new group has 0 variance, it is replaced with a group sampled from the replay buffer."
        },
    )
35 |
--------------------------------------------------------------------------------
/trl/trainer/kto_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import warnings
16 | from dataclasses import dataclass
17 |
18 | from ..import_utils import suppress_experimental_warning
19 |
20 |
21 | with suppress_experimental_warning():
22 | from ..experimental.kto import KTOConfig as _KTOConfig
23 |
24 |
@dataclass
class KTOConfig(_KTOConfig):
    """Deprecated alias for `trl.experimental.kto.KTOConfig` (old import path)."""

    def __post_init__(self):
        # Warn on construction so users migrate to the experimental namespace.
        deprecation_msg = (
            "The `KTOConfig` is now located in `trl.experimental`. Please update your imports to "
            "`from trl.experimental.kto import KTOConfig`. For more information, see "
            "https://github.com/huggingface/trl/issues/4223. Promoting KTO to the stable API is a high-priority task. "
            "Until then, this current path (`from trl import KTOConfig`) will remain, but API changes may occur."
        )
        warnings.warn(deprecation_msg, FutureWarning, stacklevel=2)
        super().__post_init__()
37 |
--------------------------------------------------------------------------------
/trl/trainer/online_dpo_trainer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import warnings
16 | from dataclasses import dataclass
17 |
18 | from ..import_utils import suppress_experimental_warning
19 |
20 |
21 | with suppress_experimental_warning():
22 | from ..experimental.online_dpo import OnlineDPOTrainer as _OnlineDPOTrainer
23 |
24 |
class OnlineDPOTrainer(_OnlineDPOTrainer):
    """Deprecated alias for `trl.experimental.online_dpo.OnlineDPOTrainer`.

    Kept only so the old import path keeps working; emits a `FutureWarning`
    and otherwise defers entirely to the experimental implementation.
    """

    # NOTE: the spurious `@dataclass` decorator was removed. A trainer is not a
    # dataclass, and decorating a field-less subclass would inject a degenerate
    # `__eq__`/`__repr__` (every instance comparing equal) on top of the Trainer API.
    def __init__(self, *args, **kwargs):
        warnings.warn(
            "The `OnlineDPOTrainer` is now located in `trl.experimental`. Please update your imports to "
            "`from trl.experimental.online_dpo import OnlineDPOTrainer`. The current import path will be removed and "
            "no longer supported in TRL 0.29. For more information, see "
            "https://github.com/huggingface/trl/issues/4223.",
            FutureWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)
37 |
--------------------------------------------------------------------------------
/trl/trainer/kto_trainer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import warnings
16 | from dataclasses import dataclass
17 |
18 | from ..import_utils import suppress_experimental_warning
19 |
20 |
21 | with suppress_experimental_warning():
22 | from ..experimental.kto import KTOTrainer as _KTOTrainer
23 |
24 |
class KTOTrainer(_KTOTrainer):
    """Deprecated alias for `trl.experimental.kto.KTOTrainer`.

    Kept only so the old import path keeps working; emits a `FutureWarning`
    and otherwise defers entirely to the experimental implementation.
    """

    # NOTE: the spurious `@dataclass` decorator was removed. A trainer is not a
    # dataclass, and decorating a field-less subclass would inject a degenerate
    # `__eq__`/`__repr__` (every instance comparing equal) on top of the Trainer API.
    def __init__(self, *args, **kwargs):
        warnings.warn(
            "The `KTOTrainer` is now located in `trl.experimental`. Please update your imports to "
            "`from trl.experimental.kto import KTOTrainer`. For more information, see "
            "https://github.com/huggingface/trl/issues/4223. Promoting KTO to the stable API is a high-priority task. "
            "Until then, this current path (`from trl import KTOTrainer`) will remain, but API changes may occur.",
            FutureWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)
37 |
--------------------------------------------------------------------------------
/tests/test_model_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from transformers import AutoModelForCausalLM
16 |
17 | from trl.models.utils import disable_gradient_checkpointing
18 |
19 |
class TestDisableGradientCheckpointing:
    """Checks that `disable_gradient_checkpointing` suspends checkpointing inside the context and restores the prior state on exit."""

    def test_when_disabled(self):
        # A model that never enabled checkpointing stays disabled before,
        # inside, and after the context.
        tiny_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        assert tiny_model.is_gradient_checkpointing is False
        with disable_gradient_checkpointing(tiny_model):
            assert tiny_model.is_gradient_checkpointing is False
        assert tiny_model.is_gradient_checkpointing is False

    def test_when_enabled(self):
        # Checkpointing is turned off only while the context is active and is
        # re-enabled once it exits.
        tiny_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        tiny_model.gradient_checkpointing_enable()
        assert tiny_model.is_gradient_checkpointing is True
        with disable_gradient_checkpointing(tiny_model):
            assert tiny_model.is_gradient_checkpointing is False
        assert tiny_model.is_gradient_checkpointing is True
35 |
--------------------------------------------------------------------------------
/docs/source/deepspeed_integration.md:
--------------------------------------------------------------------------------
1 | # DeepSpeed Integration
2 |
3 | > [!WARNING]
4 | > Section under construction. Feel free to contribute!
5 |
6 | TRL supports training with DeepSpeed, a library that implements advanced training optimization techniques. These include optimizer state partitioning, offloading, gradient partitioning, and more.
7 |
8 | DeepSpeed integrates the [Zero Redundancy Optimizer (ZeRO)](https://huggingface.co/papers/1910.02054), which allows scaling the model size proportionally to the number of devices with sustained high efficiency.
9 |
10 | 
11 |
12 | ## Installation
13 |
14 | To use DeepSpeed with TRL, install it using the following command:
15 |
16 | ```bash
17 | pip install deepspeed
18 | ```
19 |
20 | ## Running Training Scripts with DeepSpeed
21 |
22 | No modifications to your training script are required. Simply run it with the DeepSpeed configuration file:
23 |
24 | ```bash
25 | accelerate launch --config_file <path/to/deepspeed_config.yaml> train.py
26 | ```
27 |
28 | We provide ready-to-use DeepSpeed configuration files in the [`examples/accelerate_configs`](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs) directory. For example, to run training with ZeRO Stage 2, use the following command:
29 |
30 | ```bash
31 | accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml train.py
32 | ```
33 |
34 | ## Additional Resources
35 |
36 | Consult the 🤗 Accelerate [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more information about the DeepSpeed plugin.
37 |
--------------------------------------------------------------------------------
/trl/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from typing import TYPE_CHECKING
16 |
17 | from ..import_utils import _LazyModule
18 |
19 |
# Map of submodule name -> public names it exports. Consumed by `_LazyModule`
# below so that `import trl.models` stays cheap: submodules are only imported
# when one of their attributes is first accessed.
_import_structure = {
    "activation_offloading": ["get_act_offloading_ctx_manager"],
    "modeling_base": ["PreTrainedModelWrapper"],
    "modeling_value_head": ["AutoModelForCausalLMWithValueHead", "AutoModelForSeq2SeqLMWithValueHead"],
    "utils": ["create_reference_model", "prepare_deepspeed", "prepare_fsdp", "unwrap_model_for_generation"],
}


if TYPE_CHECKING:
    # Real imports for static type checkers only; never executed at runtime.
    from .activation_offloading import get_act_offloading_ctx_manager
    from .modeling_base import PreTrainedModelWrapper
    from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead
    from .utils import create_reference_model, prepare_deepspeed, prepare_fsdp, unwrap_model_for_generation
else:
    import sys

    # Replace this module object with a lazy proxy that resolves attributes on
    # demand using `_import_structure`.
    sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)
37 |
--------------------------------------------------------------------------------
/examples/accelerate_configs/alst_ulysses_4gpu.yaml:
--------------------------------------------------------------------------------
1 | # ALST/Ulysses Sequence Parallelism with 2D Parallelism (DP + SP) for 4 GPUs
2 | #
3 | # This configuration enables 2D parallelism:
4 | # - Sequence Parallelism (sp_size=2): Sequences split across 2 GPUs using ALST/Ulysses
5 | # - Data Parallelism (dp_shard_size=2): Model/optimizer sharded across 2 GPUs
6 | # - Total: 4 GPUs (2 × 2)
7 | #
8 | # Set parallelism_config in your training script:
9 | # parallelism_config = ParallelismConfig(
10 | # sp_backend="deepspeed",
11 | # sp_size=2,
12 | # dp_shard_size=2, # Calculated as: num_gpus // sp_size
13 | # sp_handler=DeepSpeedSequenceParallelConfig(...)
14 | # )
15 |
16 | compute_environment: LOCAL_MACHINE
17 | debug: false
18 | deepspeed_config:
19 | zero_stage: 3
20 | seq_parallel_communication_data_type: bf16
21 | offload_optimizer_device: none
22 | offload_param_device: none
23 | zero3_init_flag: true
24 | zero3_save_16bit_model: true
25 | distributed_type: DEEPSPEED
26 | downcast_bf16: 'no'
27 | machine_rank: 0
28 | main_training_function: main
29 | mixed_precision: bf16
30 | num_machines: 1
31 | num_processes: 4 # Total number of GPUs
32 | rdzv_backend: static
33 | same_network: true
34 | tpu_env: []
35 | tpu_use_cluster: false
36 | tpu_use_sudo: false
37 | use_cpu: false
38 | parallelism_config:
39 | parallelism_config_dp_replicate_size: 1
40 | parallelism_config_dp_shard_size: 2 # Enables 2D parallelism with SP
41 | parallelism_config_tp_size: 1
42 | parallelism_config_sp_size: 2 # Sequence parallel size
43 | parallelism_config_sp_backend: deepspeed
44 | parallelism_config_sp_seq_length_is_variable: true
45 | parallelism_config_sp_attn_implementation: flash_attention_2
46 |
--------------------------------------------------------------------------------
/docs/source/experimental_overview.md:
--------------------------------------------------------------------------------
1 | # Experimental
2 |
3 | This directory contains a minimal, clearly separated space for fast iteration on new ideas.
4 |
5 | > [!WARNING]
6 | > **Stability contract:** Anything under `trl.experimental` may change or be removed in *any* release (including patch versions) without prior deprecation. Do not rely on these APIs for production workloads.
7 |
8 | ## Promotion Path (Simple)
9 |
10 | 1. **Prototype outside the main repo:** Start development in your own fork or a separate repository to iterate quickly.
11 | 2. **Experimental inclusion:** Once it’s ready for early users, move the idea into `trl.experimental.`.
12 | 3. **Improve:** Add tests, a short doc/example, and demonstrate the usage.
13 | 4. **Promote:** Once the API proves stable and there is clear interest or adoption from the community, move it into `trl.` (stable module).
14 |
15 | ## FAQ
16 |
17 | **Why not just use branches?**
18 | Because branches are not shipped to users; experimental code inside the package lets early adopters try things and give feedback.
19 |
20 | **Can these APIs change or vanish without warning?**
21 | Yes. Anything inside `trl.experimental` can change or disappear in *any* release.
22 |
23 | **Should I use this in production?**
24 | Only if you are fine with updating your code quickly when things change.
25 |
26 | **Will maintainers promptly fix issues in `trl.experimental`?**
27 | Not necessarily. The experimental module is a playground for new ideas, and maintainers may not prioritize bug fixes or feature requests there. Issues may remain unresolved until (or unless) the feature graduates to the stable API.
28 |
29 | **How to silence the runtime notice?**
30 |
31 | Use: `export TRL_EXPERIMENTAL_SILENCE=1`.
32 |
--------------------------------------------------------------------------------
/trl/models/modeling_base.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import warnings
16 |
17 | from ..import_utils import suppress_experimental_warning
18 |
19 |
20 | with suppress_experimental_warning():
21 | from ..experimental.ppo.modeling_value_head import PreTrainedModelWrapper as _PreTrainedModelWrapper
22 |
23 |
# Dotted attribute-path templates covering the layer stacks of common
# architectures (GPT-2, OPT, GPT-NeoX, LLaMA-style models).
# NOTE(review): "{layer}" is presumably substituted with a layer index by
# callers elsewhere in the package — confirm before relying on the format.
LAYER_PATTERNS = [
    "transformer.h.{layer}",
    "model.decoder.layers.{layer}",
    "gpt_neox.layers.{layer}",
    "model.layers.{layer}",
]
30 |
31 |
class PreTrainedModelWrapper(_PreTrainedModelWrapper):
    """Deprecated alias for the experimental `PreTrainedModelWrapper`.

    Emits a `FutureWarning` on instantiation and defers entirely to the
    experimental implementation.
    """

    def __init__(self, *args, **kwargs):
        # Fixed copy-paste in the deprecation message: this class is imported
        # from `trl.experimental.ppo.modeling_value_head` (see the import at the
        # top of this file), not from `trl.experimental.bco`.
        warnings.warn(
            "The `PreTrainedModelWrapper` is now located in `trl.experimental`. Please update your imports to "
            "`from trl.experimental.ppo.modeling_value_head import PreTrainedModelWrapper`. The current import path "
            "will be removed and no longer supported in TRL 0.29. For more information, see "
            "https://github.com/huggingface/trl/issues/4223.",
            FutureWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)
43 |
--------------------------------------------------------------------------------
/docs/source/gfpo.md:
--------------------------------------------------------------------------------
1 | # GFPO
2 |
3 | This feature implements the GFPO algorithm to enforce concise reasoning in the model's output generation, as proposed in the paper [Sample More to Think Less: Group Filtered Policy Optimization for Concise Reasoning](https://huggingface.co/papers/2508.09726).
4 |
5 | ## Usage
6 |
7 | To activate GFPO in [`GFPOTrainer`]:
8 |
9 | - set `num_remains_in_group` in [`GFPOConfig`]
10 | - define a group filter function and set it to `group_filter_func` in [`GFPOTrainer`]. `group_filter_func` scores the `num_generations` completions, and the GFPOTrainer filters each group according to these scores, keeping the top `num_remains_in_group` completions as a new group. The model is then trained on the filtered group.
11 |
12 | ```python
13 | # train_gfpo.py
14 | from trl.experimental.gfpo import GFPOConfig, GFPOTrainer
15 |
16 | # dummy group filter that scores each completion based on its index in the group
17 | class GroupFilter:
18 | def __call__(self, group_completions, group_rewards, **kwargs):
19 | group_scores = []
20 | for completions, rewards in zip(group_completions, group_rewards):
21 | scores = [float(i) for i in range(len(completions))]
22 | group_scores.append(scores)
23 | return group_scores
24 |
25 | training_args = GFPOConfig(
26 | output_dir="Qwen3-0.6B-GFPO",
27 | per_device_train_batch_size=4,
28 | num_remains_in_group=2,
29 | bf16=True,
30 | )
31 | trainer = GFPOTrainer(
32 | model="Qwen/Qwen3-0.6B",
33 | reward_funcs=...,
34 | train_dataset=...,
35 | args=training_args,
36 | group_filter_func=GroupFilter(),
37 | )
38 | trainer.train()
39 | ```
40 |
41 | ## GFPOTrainer
42 |
43 | [[autodoc]] experimental.gfpo.GFPOTrainer
44 | - train
45 | - save_model
46 | - push_to_hub
47 |
48 | ## GFPOConfig
49 |
50 | [[autodoc]] experimental.gfpo.GFPOConfig
51 |
--------------------------------------------------------------------------------
/trl/experimental/xpo/xpo_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from dataclasses import dataclass, field
16 |
17 | from ..online_dpo import OnlineDPOConfig
18 |
19 |
@dataclass
class XPOConfig(OnlineDPOConfig):
    r"""
    Configuration class for the [`experimental.xpo.XPOTrainer`].

    Subclass of [`experimental.online_dpo.OnlineDPOConfig`]; it inherits all of its arguments and adds the following:

    Parameters:
        alpha (`float` or `list[float]`, *optional*, defaults to `1e-5`):
            Weight of the XPO loss term. If a list of floats is provided then the alpha is selected for each new epoch
            and the last alpha is used for the rest of the epochs.
    """

    alpha: list[float] = field(
        default_factory=lambda: [1e-5],
        metadata={
            "help": "Weight of the XPO loss term. If a list of floats is provided then the alpha is selected for each "
            "new epoch and the last alpha is used for the rest of the epochs."
        },
    )

    def __post_init__(self):
        super().__post_init__()
        # A single-element sequence collapses to its scalar value; longer
        # sequences are kept as a per-epoch schedule.
        values = self.alpha
        if hasattr(values, "__len__") and len(values) == 1:
            self.alpha = values[0]
45 |
--------------------------------------------------------------------------------
/docs/source/grpo_with_replay_buffer.md:
--------------------------------------------------------------------------------
1 | # GRPO With Replay Buffer
2 |
3 | This experimental trainer trains a model with GRPO, but replaces groups (and their corresponding completions) whose rewards have zero standard deviation with high-reward, high-variance groups that were used to train the model in prior batches.
4 |
5 | ## Usage
6 |
7 | ```python
8 | import torch
9 | from trl.experimental.grpo_with_replay_buffer import GRPOWithReplayBufferConfig, GRPOWithReplayBufferTrainer
10 | from datasets import load_dataset
11 |
12 | dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")
13 |
14 | # Guarantee that some rewards have 0 std
15 | def custom_reward_func(completions, **kwargs):
16 | if torch.rand(1).item() < 0.25:
        return [0] * len(completions)  # constant rewards -> zero std for the group
18 | else:
19 | return torch.rand(len(completions)).tolist()
20 |
21 | training_args = GRPOWithReplayBufferConfig(
22 | output_dir="./tmp",
23 | learning_rate=1e-4,
24 | per_device_train_batch_size=4,
25 | num_generations=4,
26 | max_completion_length=8,
27 | replay_buffer_size=8,
28 | report_to="none",
29 | )
30 |
31 | trainer = GRPOWithReplayBufferTrainer(
32 | model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
33 | reward_funcs=[custom_reward_func],
34 | args=training_args,
35 | train_dataset=dataset,
36 | )
37 |
38 | previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}
39 |
40 | trainer.train()
41 | ```
42 |
43 | ## GRPOWithReplayBufferTrainer
44 |
45 | [[autodoc]] experimental.grpo_with_replay_buffer.GRPOWithReplayBufferTrainer
46 | - train
47 | - save_model
48 | - push_to_hub
49 |
50 | ## GRPOWithReplayBufferConfig
51 |
52 | [[autodoc]] experimental.grpo_with_replay_buffer.GRPOWithReplayBufferConfig
53 |
54 | ## ReplayBuffer
55 |
56 | [[autodoc]] experimental.grpo_with_replay_buffer.ReplayBuffer
57 |
--------------------------------------------------------------------------------
/.github/workflows/tests-experimental.yml:
--------------------------------------------------------------------------------
# CI workflow: code-quality checks plus a GPU test run for TRL's experimental modules.
name: Tests (experimental)

on:
  pull_request:
    paths:
      # Run only when relevant files are modified
      - "trl/experimental/**"
      - "tests/experimental/**"

env:
  TQDM_DISABLE: 1  # keep CI logs free of progress bars
  PYTORCH_CUDA_ALLOC_CONF: "expandable_segments:True"
  TRL_EXPERIMENTAL_SILENCE: 1

jobs:
  check_code_quality:
    name: Check code quality
    runs-on: ubuntu-latest
    # Skip draft pull requests
    if: github.event.pull_request.draft == false
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python 3.13
        uses: actions/setup-python@v5
        with:
          python-version: 3.13
      - uses: pre-commit/action@v3.0.1
        with:
          extra_args: --all-files

  tests:
    name: Tests (experimental)
    runs-on:
      group: aws-g4dn-2xlarge  # GPU runner group
    container:
      image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
      options: --gpus all
    defaults:
      run:
        shell: bash
    steps:
      - name: Git checkout
        uses: actions/checkout@v4

      - name: Set up Python 3.13
        uses: actions/setup-python@v5
        with:
          python-version: 3.13

      - name: Install Make and Git
        run: |
          apt-get update && apt-get install -y make git curl

      - name: Install uv
        run: |
          curl -LsSf https://astral.sh/uv/install.sh | sh

      - name: Create Python virtual environment
        run: |
          uv venv
          uv pip install --upgrade setuptools wheel

      - name: Install dependencies
        run: |
          source .venv/bin/activate
          uv pip install ".[dev]"

      - name: Test with pytest
        run: |
          source .venv/bin/activate
          make test_experimental
71 |
--------------------------------------------------------------------------------
/trl/experimental/nash_md/nash_md_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from dataclasses import dataclass, field
16 |
17 | from ..online_dpo import OnlineDPOConfig
18 |
19 |
@dataclass
class NashMDConfig(OnlineDPOConfig):
    r"""
    Configuration class for the [`experimental.nash_md.NashMDTrainer`].

    Subclass of [`experimental.online_dpo.OnlineDPOConfig`]; it inherits all of its arguments and adds the
    following:

    Parameters:
        mixture_coef (`float` or `list[float]`, *optional*, defaults to `0.5`):
            Logit mixture coefficient for the model and reference model. If a list of floats is provided then the
            mixture coefficient is selected for each new epoch and the last coefficient is used for the rest of the
            epochs.
    """

    mixture_coef: list[float] = field(
        default_factory=lambda: [0.5],
        metadata={
            "help": "Logit mixture coefficient for the model and reference model. If a list of floats is provided "
            "then the mixture coefficient is selected for each new epoch and the last coefficient is used for the "
            "rest of the epochs."
        },
    )

    def __post_init__(self):
        super().__post_init__()
        # Collapse a single-element list into a plain float so downstream code can
        # treat the scalar case uniformly.
        value = self.mixture_coef
        if hasattr(value, "__len__") and len(value) == 1:
            self.mixture_coef = value[0]
47 |
--------------------------------------------------------------------------------
/trl/templates/rm_model_card.md:
--------------------------------------------------------------------------------
1 | ---
2 | {{ card_data }}
3 | ---
4 |
5 | # Model Card for {{ model_name }}
6 |
7 | This model is a fine-tuned version of [{{ base_model }}](https://huggingface.co/{{ base_model }}){% if dataset_name %} on the [{{ dataset_name }}](https://huggingface.co/datasets/{{ dataset_name }}) dataset{% endif %}.
8 | It has been trained using [TRL](https://github.com/huggingface/trl).
9 |
10 | ## Quick start
11 |
12 | ```python
13 | from transformers import pipeline
14 |
15 | text = "The capital of France is Paris."
16 | rewarder = pipeline(model="{{ hub_model_id }}", device="cuda")
17 | output = rewarder(text)[0]
18 | print(output["score"])
19 | ```
20 |
21 | ## Training procedure
22 |
{% if wandb_url %}[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-28.svg" alt="Visualize in Weights & Biases" width="150" height="24"/>]({{ wandb_url }}){% endif %}
{% if comet_url %}[<img src="https://raw.githubusercontent.com/comet-ml/comet-examples/master/logo/comet_badge.png" alt="Visualize in Comet" width="135" height="20"/>]({{ comet_url }}){% endif %}
25 |
26 | This model was trained with {{ trainer_name }}{% if paper_id %}, a method introduced in [{{ paper_title }}](https://huggingface.co/papers/{{ paper_id }}){% endif %}.
27 |
28 | ### Framework versions
29 |
30 | - TRL: {{ trl_version }}
31 | - Transformers: {{ transformers_version }}
32 | - Pytorch: {{ pytorch_version }}
33 | - Datasets: {{ datasets_version }}
34 | - Tokenizers: {{ tokenizers_version }}
35 |
36 | ## Citations
37 |
38 | {% if trainer_citation %}Cite {{ trainer_name }} as:
39 |
40 | ```bibtex
41 | {{ trainer_citation }}
42 | ```{% endif %}
43 |
44 | Cite TRL as:
45 |
46 | ```bibtex
47 | {% raw %}@misc{vonwerra2022trl,
48 | title = {{TRL: Transformer Reinforcement Learning}},
49 | author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallou{\'e}dec},
50 | year = 2020,
51 | journal = {GitHub repository},
52 | publisher = {GitHub},
53 | howpublished = {\url{https://github.com/huggingface/trl}}
54 | }{% endraw %}
55 | ```
56 |
--------------------------------------------------------------------------------
/.github/workflows/tests_latest.yml:
--------------------------------------------------------------------------------
# Nightly CI: run the latest released TRL against the dev (git main) versions of its
# core Hugging Face dependencies, and report the result to Slack.
name: Tests latest TRL release with dev dependencies

on:
  schedule:
    - cron: '0 0 * * *' # Runs daily at midnight UTC

  workflow_dispatch:

env:
  TQDM_DISABLE: 1
  CI_SLACK_CHANNEL: ${{ secrets.CI_PUSH_MAIN_CHANNEL }}
  TRL_EXPERIMENTAL_SILENCE: 1

jobs:
  tests:
    name: Tests latest TRL release with dev dependencies
    runs-on:
      group: aws-g4dn-2xlarge  # GPU runner group
    container:
      image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
      options: --gpus all
    defaults:
      run:
        shell: bash
    steps:
      - name: Git checkout
        uses: actions/checkout@v4
        # Check out the release branch rather than main
        with: { ref: v0.26-release }

      - name: Set up Python 3.12
        uses: actions/setup-python@v5
        with:
          python-version: '3.12'

      - name: Install Make and Git
        run: |
          apt-get update && apt-get install -y make git curl

      - name: Install uv
        run: |
          curl -LsSf https://astral.sh/uv/install.sh | sh

      - name: Create Python virtual environment
        run: |
          uv venv
          uv pip install --upgrade setuptools wheel

      # Install the released TRL, then override the core HF libraries with their git main versions
      - name: Install dependencies
        run: |
          source .venv/bin/activate
          uv pip install ".[dev]"
          uv pip install -U git+https://github.com/huggingface/accelerate.git
          uv pip install -U git+https://github.com/huggingface/datasets.git
          uv pip install -U git+https://github.com/huggingface/transformers.git

      - name: Test with pytest
        run: |
          source .venv/bin/activate
          make test

      # Notify Slack with the job result
      - name: Post to Slack
        uses: huggingface/hf-workflows/.github/actions/post-slack@main
        with:
          slack_channel: ${{ env.CI_SLACK_CHANNEL }}
          title: Results of latest TRL with Python 3.12 and dev dependencies
          status: ${{ job.status }}
          slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
68 |
--------------------------------------------------------------------------------
/tests/test_rich_progress_callback.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 | import torch.nn as nn
17 | from datasets import Dataset
18 | from transformers import Trainer, TrainingArguments
19 |
20 | from trl.trainer.callbacks import RichProgressCallback
21 |
22 | from .testing_utils import TrlTestCase, require_rich
23 |
24 |
class DummyModel(nn.Module):
    """Minimal module with a single learnable scalar; `forward` scales its input by it."""

    def __init__(self):
        super().__init__()
        # One trainable parameter, initialized to 1.0, so forward is initially identity.
        self.a = nn.Parameter(torch.tensor(1.0))

    def forward(self, x):
        return x * self.a
32 |
33 |
@require_rich
class TestRichProgressCallback(TrlTestCase):
    """Smoke test: a full train/eval loop with `RichProgressCallback` attached must run without errors."""

    def setup_method(self):
        sample = {"x": 1.0, "y": 2.0}
        self.dummy_model = DummyModel()
        self.dummy_train_dataset = Dataset.from_list([sample] * 5)
        self.dummy_val_dataset = Dataset.from_list([sample] * 101)

    def test_rich_progress_callback_logging(self):
        # Log and evaluate on every step so the callback's rendering paths are exercised frequently.
        training_args = TrainingArguments(
            output_dir=self.tmp_dir,
            per_device_train_batch_size=2,
            per_device_eval_batch_size=2,
            num_train_epochs=4,
            eval_strategy="steps",
            eval_steps=1,
            logging_strategy="steps",
            logging_steps=1,
            save_strategy="no",
            report_to="none",
            disable_tqdm=True,
        )
        trainer = Trainer(
            model=self.dummy_model,
            args=training_args,
            train_dataset=self.dummy_train_dataset,
            eval_dataset=self.dummy_val_dataset,
            callbacks=[RichProgressCallback()],
        )

        # Should complete without raising.
        trainer.train()
65 |
--------------------------------------------------------------------------------
/trl/templates/lm_model_card.md:
--------------------------------------------------------------------------------
1 | ---
2 | {{ card_data }}
3 | ---
4 |
5 | # Model Card for {{ model_name }}
6 |
7 | This model is a fine-tuned version of [{{ base_model }}](https://huggingface.co/{{ base_model }}){% if dataset_name %} on the [{{ dataset_name }}](https://huggingface.co/datasets/{{ dataset_name }}) dataset{% endif %}.
8 | It has been trained using [TRL](https://github.com/huggingface/trl).
9 |
10 | ## Quick start
11 |
12 | ```python
13 | from transformers import pipeline
14 |
15 | question = "If you had a time machine, but could only go to the past or the future once and never return, which would you choose and why?"
16 | generator = pipeline("text-generation", model="{{ hub_model_id }}", device="cuda")
17 | output = generator([{"role": "user", "content": question}], max_new_tokens=128, return_full_text=False)[0]
18 | print(output["generated_text"])
19 | ```
20 |
21 | ## Training procedure
22 |
{% if wandb_url %}[<img src="https://raw.githubusercontent.com/wandb/assets/main/wandb-github-badge-28.svg" alt="Visualize in Weights & Biases" width="150" height="24"/>]({{ wandb_url }}){% endif %}
{% if comet_url %}[<img src="https://raw.githubusercontent.com/comet-ml/comet-examples/master/logo/comet_badge.png" alt="Visualize in Comet" width="135" height="20"/>]({{ comet_url }}){% endif %}
25 |
26 | This model was trained with {{ trainer_name }}{% if paper_id %}, a method introduced in [{{ paper_title }}](https://huggingface.co/papers/{{ paper_id }}){% endif %}.
27 |
28 | ### Framework versions
29 |
30 | - TRL: {{ trl_version }}
31 | - Transformers: {{ transformers_version }}
32 | - Pytorch: {{ pytorch_version }}
33 | - Datasets: {{ datasets_version }}
34 | - Tokenizers: {{ tokenizers_version }}
35 |
36 | ## Citations
37 |
38 | {% if trainer_citation %}Cite {{ trainer_name }} as:
39 |
40 | ```bibtex
41 | {{ trainer_citation }}
42 | ```{% endif %}
43 |
44 | Cite TRL as:
45 |
46 | ```bibtex
47 | {% raw %}@misc{vonwerra2022trl,
48 | title = {{TRL: Transformer Reinforcement Learning}},
49 | author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallou{\'e}dec},
50 | year = 2020,
51 | journal = {GitHub repository},
52 | publisher = {GitHub},
53 | howpublished = {\url{https://github.com/huggingface/trl}}
54 | }{% endraw %}
55 | ```
56 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.bak
2 | .gitattributes
3 | .last_checked
4 | .gitconfig
5 | *.bak
6 | *.log
7 | *~
8 | ~*
9 | _tmp*
10 | tmp*
11 | tags
12 |
13 | # Byte-compiled / optimized / DLL files
14 | __pycache__/
15 | *.py[cod]
16 | *$py.class
17 |
18 | # C extensions
19 | *.so
20 |
21 | # Distribution / packaging
22 | .Python
23 | env/
24 | build/
25 | develop-eggs/
26 | dist/
27 | downloads/
28 | eggs/
29 | .eggs/
30 | lib/
31 | lib64/
32 | parts/
33 | sdist/
34 | var/
35 | wheels/
36 | *.egg-info/
37 | .installed.cfg
38 | *.egg
39 |
40 | # PyInstaller
41 | # Usually these files are written by a python script from a template
42 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
43 | *.manifest
44 | *.spec
45 |
46 | # Installer logs
47 | pip-log.txt
48 | pip-delete-this-directory.txt
49 |
50 | # Unit test / coverage reports
51 | htmlcov/
52 | .tox/
53 | .coverage
54 | .coverage.*
55 | .cache
56 | nosetests.xml
57 | coverage.xml
58 | *.cover
59 | .hypothesis/
60 |
61 | # Translations
62 | *.mo
63 | *.pot
64 |
65 | # Django stuff:
66 | *.log
67 | local_settings.py
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # pyenv
86 | .python-version
87 |
88 | # celery beat schedule file
89 | celerybeat-schedule
90 |
91 | # SageMath parsed files
92 | *.sage.py
93 |
94 | # dotenv
95 | .env
96 |
97 | # virtualenv
98 | .venv
99 | venv/
100 | ENV/
101 |
102 | # Spyder project settings
103 | .spyderproject
104 | .spyproject
105 |
106 | # Rope project settings
107 | .ropeproject
108 |
109 | # mkdocs documentation
110 | /site
111 |
112 | # mypy
113 | .mypy_cache/
114 |
115 | .vscode
116 | *.swp
117 |
118 | # osx generated files
119 | .DS_Store
120 | .DS_Store?
121 | .Trashes
122 | ehthumbs.db
123 | Thumbs.db
124 | .idea
125 |
126 | # pytest
127 | .pytest_cache
128 |
129 | # tools/trust-doc-nbs
130 | docs_src/.last_checked
131 |
132 | # symlinks to fastai
133 | docs_src/fastai
134 | tools/fastai
135 |
136 | # link checker
137 | checklink/cookies.txt
138 |
139 | # .gitconfig is now autogenerated
140 | .gitconfig
141 |
142 | # wandb files
143 | nbs/wandb/
144 | examples/notebooks/wandb/
145 | wandb/
--------------------------------------------------------------------------------
/docs/source/trackio_integration.md:
--------------------------------------------------------------------------------
1 | # Trackio Integration
2 |
3 | [Trackio](https://huggingface.co/docs/trackio) is a lightweight, free experiment tracking library built on top of **🤗 Datasets** and **🤗 Spaces**. It is the **recommended tracking solution for TRL** and comes natively integrated with all trainers.
4 |
5 | To enable logging, simply set `report_to="trackio"` in your training config:
6 |
7 | ```python
8 | from trl import SFTConfig # works with any trainer config (e.g. DPOConfig, GRPOConfig, etc.)
9 |
10 | training_args = SFTConfig(
11 | ...,
12 | report_to="trackio", # enable Trackio logging
13 | )
14 | ```
15 |
16 | ## Organizing Your Experiments with Run Names and Projects
17 |
18 | By default, Trackio will generate a name to identify each run. However, we highly recommend setting a descriptive `run_name` to make it easier to organize experiments. For example:
19 |
20 | ```python
21 | from trl import SFTConfig
22 |
23 | training_args = SFTConfig(
24 | ...,
25 | report_to="trackio",
26 | run_name="sft_qwen3-4b_lr2e-5_bs128", # descriptive run name
27 | )
28 | ```
29 |
30 | You can also group related experiments by project by setting the following environment variable:
31 |
32 | ```bash
33 | export TRACKIO_PROJECT="my_project"
34 | ```
35 |
36 | ## Hosting Your Logs on 🤗 Spaces
37 |
Trackio has a local-first design, meaning your logs stay on your machine. If you’d like to host them and deploy a dashboard on **🤗 Spaces**, set:
39 |
40 | ```bash
41 | export TRACKIO_SPACE_ID="username/space_id"
42 | ```
43 |
44 | Running the following example:
45 |
46 | ```python
47 | import os
48 | from trl import SFTConfig, SFTTrainer
49 | from datasets import load_dataset
50 |
51 | os.environ["TRACKIO_SPACE_ID"] = "trl-lib/trackio"
52 | os.environ["TRACKIO_PROJECT"] = "trl-documentation"
53 |
54 | trainer = SFTTrainer(
55 | model="Qwen/Qwen3-0.6B",
56 | train_dataset=load_dataset("trl-lib/Capybara", split="train"),
57 | args=SFTConfig(
58 | report_to="trackio",
59 | run_name="sft_qwen3-0.6b_capybara",
60 | ),
61 | )
62 | trainer.train()
63 | ```
64 |
65 | will give you a hosted dashboard at https://huggingface.co/spaces/trl-lib/trackio.
66 |
67 |
68 |
--------------------------------------------------------------------------------
/trl/rewards/format_rewards.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import re
16 |
17 |
def think_format_reward(completions: list[list[dict[str, str]]], **kwargs) -> list[float]:
    r"""
    Reward function that checks if the reasoning process is enclosed within `"<think>"` and `"</think>"` tags. The
    function returns a reward of 1.0 if the format is correct, otherwise 0.0.

    Args:
        completions (`list[list[dict[str, str]]]`):
            List of completions to be evaluated. Each completion must be a list of one message, i.e. a dictionary
            containing the key `"content"` with the value being the text of the completion.
        **kwargs:
            Additional keyword arguments. This function does not use them, but they are required in the function
            signature to ensure compatibility with trainers like [`GRPOTrainer`].

    Returns:
        `list[float]`:
            A list of rewards, where each reward is 1.0 if the completion matches the expected format, otherwise 0.0.

    Example:
    ```python
    >>> from trl.rewards import think_format_reward

    >>> completions = [
    ...     [{"content": "<think>\nThis is my reasoning.\n</think>\nThis is my answer."}],
    ...     [{"content": "<think>\nThis is my reasoning.\nThis is my answer."}],
    ... ]
    >>> think_format_reward(completions)
    [1.0, 0.0]
    ```
    """
    # The completion must open with "<think>\n", must not contain a nested "<think>" tag, and must close the
    # reasoning block with "\n</think>\n" before the answer. re.DOTALL lets "." span newlines inside the block.
    # NOTE: the previous pattern had the literal tags stripped, so the leading negative lookahead "(?!.*)" could
    # never succeed and every completion scored 0.0.
    pattern = r"^<think>\n(?!.*<think>)(.*?)\n</think>\n.*$"
    completion_contents = [completion[0]["content"] for completion in completions]
    matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]
    return [1.0 if match else 0.0 for match in matches]
51 |
--------------------------------------------------------------------------------
/trl/models/modeling_value_head.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import warnings
16 |
17 | from ..import_utils import suppress_experimental_warning
18 |
19 |
# Re-export the real implementations from `trl.experimental`, suppressing the
# experimental-module warning for this internal import (the deprecation shims
# below emit their own FutureWarning instead).
with suppress_experimental_warning():
    from ..experimental.ppo import AutoModelForCausalLMWithValueHead as _AutoModelForCausalLMWithValueHead
    from ..experimental.ppo import AutoModelForSeq2SeqLMWithValueHead as _AutoModelForSeq2SeqLMWithValueHead
23 |
24 |
class AutoModelForCausalLMWithValueHead(_AutoModelForCausalLMWithValueHead):
    """Deprecated import-path shim; emits a `FutureWarning` and defers to the experimental implementation."""

    def __init__(self, *args, **kwargs):
        deprecation_message = (
            "The `AutoModelForCausalLMWithValueHead` is now located in `trl.experimental`. Please update your imports "
            "to `from trl.experimental.ppo import AutoModelForCausalLMWithValueHead`. The current import path will be "
            "removed and no longer supported in TRL 0.29. For more information, see "
            "https://github.com/huggingface/trl/issues/4223."
        )
        # stacklevel=2 attributes the warning to the caller's line, not this shim.
        warnings.warn(deprecation_message, FutureWarning, stacklevel=2)
        super().__init__(*args, **kwargs)
36 |
37 |
class AutoModelForSeq2SeqLMWithValueHead(_AutoModelForSeq2SeqLMWithValueHead):
    """Deprecated import-path shim; emits a `FutureWarning` and defers to the experimental implementation."""

    def __init__(self, *args, **kwargs):
        deprecation_message = (
            "The `AutoModelForSeq2SeqLMWithValueHead` is now located in `trl.experimental`. Please update your imports "
            "to `from trl.experimental.ppo import AutoModelForSeq2SeqLMWithValueHead`. The current import path will be "
            "removed and no longer supported in TRL 0.29. For more information, see "
            "https://github.com/huggingface/trl/issues/4223."
        )
        # stacklevel=2 attributes the warning to the caller's line, not this shim.
        warnings.warn(deprecation_message, FutureWarning, stacklevel=2)
        super().__init__(*args, **kwargs)
49 |
--------------------------------------------------------------------------------
/examples/scripts/sft_gemma3.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # /// script
16 | # dependencies = [
17 | # "trl",
18 | # "Pillow",
19 | # "trackio",
20 | # "kernels",
21 | # ]
22 | # ///
23 |
24 | """
25 | Train Gemma-3 on the Codeforces COTS dataset.
26 |
27 | accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml examples/scripts/sft_gemma3.py
28 | """
29 |
30 | import os
31 |
32 | from datasets import load_dataset
33 | from transformers import AutoModelForImageTextToText
34 |
35 | from trl import SFTConfig, SFTTrainer
36 |
37 |
# Enable logging in a Hugging Face Space. `setdefault` means an externally
# configured TRACKIO_SPACE_ID environment variable takes precedence.
os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")
40 |
41 |
def main():
    """Fine-tune Gemma-3 on the Codeforces COTS dataset and push the result to the Hub."""
    # Load the dataset; the "prompt" column is dropped before training.
    dataset = load_dataset("open-r1/codeforces-cots", split="train").remove_columns("prompt")

    # Load the model with the "eager" attention implementation requested explicitly.
    model_id = "google/gemma-3-12b-it"
    model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation="eager")

    # Configure and run training.
    config = SFTConfig(
        output_dir=f"{model_id}-codeforces-SFT",
        bf16=True,
        use_liger_kernel=True,
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        max_length=8192,
        per_device_train_batch_size=1,
        gradient_accumulation_steps=8,
        dataset_num_proc=32,
        num_train_epochs=1,
    )
    trainer = SFTTrainer(args=config, model=model, train_dataset=dataset)
    trainer.train()

    # Upload the trained model to the Hub, tagging the dataset it was trained on.
    trainer.push_to_hub(dataset_name="open-r1/codeforces-cots")
74 |
75 |
# Script entry point: run only when executed directly, not when imported.
if __name__ == "__main__":
    main()
78 |
--------------------------------------------------------------------------------
/docs/source/papo_trainer.md:
--------------------------------------------------------------------------------
1 | # PAPO Trainer
2 |
3 | [](https://huggingface.co/models?other=papo,trl)
4 |
5 | TRL supports the Perception-Aware Policy Optimization (PAPO) as described in the paper [Perception-Aware Policy Optimization for Multimodal Reasoning](https://huggingface.co/papers/2507.06448) by [Zhenhailong Wang](https://huggingface.co/mikewang), Xuehang Guo, Sofia Stoica, [Haiyang Xu](https://huggingface.co/xhyandwyy), Hongru Wang, Hyeonjeong Ha, Xiusi Chen, Yangyi Chen, Ming Yan, Fei Huang, Heng Ji
6 |
7 | The abstract from the paper is the following:
8 |
9 | > Reinforcement Learning with Verifiable Rewards (RLVR) has proven to be a highly effective strategy for endowing Large Language Models (LLMs) with robust multi-step reasoning abilities. However, its design and optimizations remain tailored to purely textual domains, resulting in suboptimal performance when applied to multimodal reasoning tasks. In particular, we observe that a major source of error in current multimodal reasoning lies in the perception of visual inputs. To address this bottleneck, we propose Perception-Aware Policy Optimization (PAPO), a simple yet effective extension of GRPO that encourages the model to learn to perceive while learning to reason, entirely from internal supervision signals. Notably, PAPO does not rely on additional data curation, external reward models, or proprietary models. Specifically, we introduce the Implicit Perception Loss in the form of a KL divergence term to the GRPO objective, which, despite its simplicity, yields significant overall improvements (4.4%) on diverse multimodal benchmarks. The improvements are more pronounced, approaching 8.0%, on tasks with high vision dependency. We also observe a substantial reduction (30.5%) in perception errors, indicating improved perceptual capabilities with PAPO. We conduct comprehensive analysis of PAPO and identify a unique loss hacking issue, which we rigorously analyze and mitigate through a Double Entropy Loss. Overall, our work introduces a deeper integration of perception-aware supervision into RLVR learning objectives and lays the groundwork for a new RL framework that encourages visually grounded reasoning. Project page: https://mikewangwzhl.github.io/PAPO.
10 |
11 | ## PAPOTrainer
12 |
13 | [[autodoc]] experimental.papo.PAPOTrainer
14 | - train
15 | - save_model
16 | - push_to_hub
17 |
18 | ## PAPOConfig
19 |
20 | [[autodoc]] experimental.papo.PAPOConfig
21 |
--------------------------------------------------------------------------------
/tests/experimental/test_minillm_trainer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import pytest
16 | import torch
17 | from datasets import load_dataset
18 |
19 | from trl.experimental.minillm import MiniLLMConfig, MiniLLMTrainer
20 |
21 | from ..testing_utils import TrlTestCase
22 |
23 |
@pytest.mark.low_priority
class TestMiniLLMTrainer(TrlTestCase):
    def test_train(self):
        """Run a short training job and verify a loss is logged and every parameter is updated."""
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        training_args = MiniLLMConfig(
            output_dir=self.tmp_dir,
            per_device_train_batch_size=3,  # small batch to reduce memory usage
            num_generations=3,  # few generations to reduce memory usage
            max_completion_length=32,  # short completions to reduce memory usage
            report_to="none",
        )
        trainer = MiniLLMTrainer(
            model="trl-internal-testing/small-Qwen3ForCausalLM",
            teacher_model="trl-internal-testing/tiny-Qwen3ForCausalLM",
            args=training_args,
            train_dataset=dataset,
        )

        # Snapshot the weights before training so we can verify they change.
        params_before = {name: p.clone() for name, p in trainer.model.named_parameters()}

        trainer.train()

        # A training loss must have been recorded.
        assert trainer.state.log_history[-1]["train_loss"] is not None

        # Every parameter must have moved away from its initial value.
        for name, old_value in params_before.items():
            new_value = trainer.model.get_parameter(name)
            assert not torch.allclose(old_value, new_value), f"Parameter {name} has not changed"
58 |
--------------------------------------------------------------------------------
/tests/experimental/test_gspo_token_trainer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import torch
17 | from datasets import load_dataset
18 | from transformers.utils import is_peft_available
19 |
20 | from trl import GRPOConfig
21 | from trl.experimental.gspo_token import GRPOTrainer as GSPOTokenTrainer
22 |
23 | from ..testing_utils import TrlTestCase
24 |
25 |
# NOTE(review): dead code — this guard does nothing; it looks like a leftover
# from PEFT-conditional imports that were removed. Safe to delete together
# with the `is_peft_available` import above.
if is_peft_available():
    pass
28 |
29 |
class TestGSPOTokenTrainer(TrlTestCase):
    def test_training(self):
        """Smoke-test the GSPO-token trainer: a loss is logged and parameters are updated."""
        dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

        # Small sizes throughout to keep the test fast and memory-light; a high
        # learning rate guarantees visible weight movement, and two iterations
        # make the importance-sampling weights non-trivial.
        config = GRPOConfig(
            output_dir=self.tmp_dir,
            learning_rate=0.1,
            per_device_train_batch_size=3,
            num_generations=3,
            max_completion_length=8,
            num_iterations=2,
            importance_sampling_level="sequence_token",
            report_to="none",
        )
        trainer = GSPOTokenTrainer(
            model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
            reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
            args=config,
            train_dataset=dataset,
        )

        # Snapshot the weights so we can verify the optimizer touched them
        initial_params = {name: p.clone() for name, p in trainer.model.named_parameters()}

        trainer.train()

        assert trainer.state.log_history[-1]["train_loss"] is not None

        # Check that the params have changed
        for name, before in initial_params.items():
            after = trainer.model.get_parameter(name)
            assert not torch.equal(before, after), f"Parameter {name} has not changed."
61 |
--------------------------------------------------------------------------------
/trl/mergekit_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import warnings
16 |
17 | from .import_utils import suppress_experimental_warning
18 |
19 |
20 | with suppress_experimental_warning():
21 | from .experimental.merge_model_callback import MergeConfig as _MergeConfig
22 | from .experimental.merge_model_callback import merge_models as _merge_models
23 | from .experimental.merge_model_callback import upload_model_to_hf as _upload_model_to_hf
24 |
25 |
def upload_model_to_hf(*args, **kwargs):
    """Deprecated wrapper for `trl.experimental.merge_model_callback.upload_model_to_hf`.

    Emits a `FutureWarning` pointing at the new import path, then delegates to
    the experimental implementation unchanged.
    """
    deprecation_message = (
        "`upload_model_to_hf` is now located in `trl.experimental`. Please update your imports to "
        "`from trl.experimental.merge_model_callback import upload_model_to_hf`. The current import path will be "
        "removed and no longer supported in TRL 0.29. For more information, see "
        "https://github.com/huggingface/trl/issues/4223."
    )
    # stacklevel=2 attributes the warning to the caller, not this shim
    warnings.warn(deprecation_message, FutureWarning, stacklevel=2)
    return _upload_model_to_hf(*args, **kwargs)
36 |
37 |
class MergeConfig(_MergeConfig):
    """Deprecated alias for `trl.experimental.merge_model_callback.MergeConfig`.

    Emits a `FutureWarning` and otherwise behaves exactly like the experimental
    config it wraps.
    """

    def __init__(self, *args, **kwargs):
        warnings.warn(
            "`MergeConfig` is now located in `trl.experimental`. Please update your imports to "
            "`from trl.experimental.merge_model_callback import MergeConfig`. The current import path will be "
            "removed and no longer supported in TRL 0.29. For more information, see "
            "https://github.com/huggingface/trl/issues/4223.",
            FutureWarning,
            stacklevel=2,
        )
        # Bug fix: forward constructor arguments to the parent. Previously this
        # wrapper only warned and never called super().__init__, so every
        # argument was silently dropped and the underlying _MergeConfig was
        # left uninitialized (unlike the sibling wrappers, which delegate).
        super().__init__(*args, **kwargs)
48 |
49 |
def merge_models(*args, **kwargs):
    """Deprecated wrapper for `trl.experimental.merge_model_callback.merge_models`.

    Warns about the moved import path and forwards all arguments to the
    experimental implementation.
    """
    warnings.warn(
        message=(
            "`merge_models` is now located in `trl.experimental`. Please update your imports to "
            "`from trl.experimental.merge_model_callback import merge_models`. The current import path will be "
            "removed and no longer supported in TRL 0.29. For more information, see "
            "https://github.com/huggingface/trl/issues/4223."
        ),
        category=FutureWarning,
        stacklevel=2,  # point the warning at the caller, not this shim
    )
    return _merge_models(*args, **kwargs)
60 |
--------------------------------------------------------------------------------
/docs/source/liger_kernel_integration.md:
--------------------------------------------------------------------------------
1 | # Liger Kernel Integration
2 |
3 | [Liger Kernel](https://github.com/linkedin/Liger-Kernel) is a collection of Triton kernels designed specifically for LLM training. It can effectively increase multi-GPU training throughput by 20% and reduce memory usage by 60%. That way, we can **4x** our context length, as described in the benchmark below. They have implemented Hugging Face compatible `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, with more to come. The kernel works out of the box with [FlashAttention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed).
4 |
5 | With this memory reduction, you can potentially turn off `cpu_offloading` or gradient checkpointing to further boost the performance.
6 |
7 | | Speed Up | Memory Reduction |
8 | | --- | --- |
9 | |  |  |
10 |
11 | ## Supported Trainers
12 |
13 | Liger Kernel is supported in the following TRL trainers:
14 | - **SFT** (Supervised Fine-Tuning)
15 | - **DPO** (Direct Preference Optimization)
16 | - **GRPO** (Group Relative Policy Optimization)
17 | - **KTO** (Kahneman-Tversky Optimization)
18 | - **GKD** (Generalized Knowledge Distillation)
19 |
20 | ## Usage
21 |
22 | 1. First, install Liger Kernel:
23 |
24 | ```bash
25 | pip install liger-kernel
26 | ```
27 |
28 | 2. Once installed, set `use_liger_kernel=True` in your trainer config. No other changes are needed!
29 |
30 |
31 |
32 |
33 | ```python
34 | from trl import SFTConfig
35 |
36 | training_args = SFTConfig(..., use_liger_kernel=True)
37 | ```
38 |
39 |
40 |
41 |
42 | ```python
43 | from trl import DPOConfig
44 |
45 | training_args = DPOConfig(..., use_liger_kernel=True)
46 | ```
47 |
48 |
49 |
50 |
51 | ```python
52 | from trl import GRPOConfig
53 |
54 | training_args = GRPOConfig(..., use_liger_kernel=True)
55 | ```
56 |
57 |
58 |
59 |
60 | ```python
61 | from trl import KTOConfig
62 |
63 | training_args = KTOConfig(..., use_liger_kernel=True)
64 | ```
65 |
66 |
67 |
68 |
69 | ```python
70 | from trl.experimental.gkd import GKDConfig
71 |
72 | training_args = GKDConfig(..., use_liger_kernel=True)
73 | ```
74 |
75 |
76 |
77 |
78 | To learn more about Liger-Kernel, visit their [official repository](https://github.com/linkedin/Liger-Kernel/).
79 |
--------------------------------------------------------------------------------
/.github/workflows/docker-build.yml:
--------------------------------------------------------------------------------
1 | name: Build TRL Docker image
2 |
3 | on:
4 | push:
5 | branches:
6 | - main
7 | workflow_dispatch:
8 |
9 | concurrency:
10 | group: docker-image-builds
11 | cancel-in-progress: false
12 |
13 | jobs:
14 | trl:
15 | name: "Build and push TRL Docker image"
16 | runs-on:
17 | group: aws-general-8-plus
18 | steps:
19 | - name: Checkout code
20 | uses: actions/checkout@v4
21 |
22 | - name: Get TRL version from PyPI
23 | run: |
24 | VERSION=$(curl -s https://pypi.org/pypi/trl/json | jq -r .info.version)
25 | echo "VERSION=$VERSION" >> $GITHUB_ENV
26 |
27 | - name: Set up Docker Buildx
28 | uses: docker/setup-buildx-action@v3
29 |
30 | - name: Login to DockerHub
31 | uses: docker/login-action@v3
32 | with:
33 | username: ${{ secrets.DOCKERHUB_USERNAME }}
34 | password: ${{ secrets.DOCKERHUB_PASSWORD }}
35 |
36 | - name: Build and Push
37 | uses: docker/build-push-action@v4
38 | with:
39 | context: docker/trl
40 | push: true
41 | tags: |
42 | huggingface/trl:${{ env.VERSION }}
43 | huggingface/trl
44 |
45 | - name: Post to Slack
46 | if: always()
47 | uses: huggingface/hf-workflows/.github/actions/post-slack@main
48 | with:
49 | slack_channel: ${{ secrets.CI_DOCKER_CHANNEL }}
          title: 🤗 Results of the TRL Docker Image build
51 | status: ${{ job.status }}
52 | slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
53 |
54 | trl-dev:
55 | name: "Build and push TRL Dev Docker image"
56 | runs-on:
57 | group: aws-general-8-plus
58 | steps:
59 | - name: Checkout code
60 | uses: actions/checkout@v4
61 |
62 | - name: Set up Docker Buildx
63 | uses: docker/setup-buildx-action@v3
64 |
65 | - name: Login to DockerHub
66 | uses: docker/login-action@v3
67 | with:
68 | username: ${{ secrets.DOCKERHUB_USERNAME }}
69 | password: ${{ secrets.DOCKERHUB_PASSWORD }}
70 |
71 | - name: Build and Push
72 | uses: docker/build-push-action@v4
73 | with:
74 | context: docker/trl-dev
75 | push: true
76 | tags: |
77 | huggingface/trl:dev
78 |
79 | - name: Post to Slack
80 | if: always()
81 | uses: huggingface/hf-workflows/.github/actions/post-slack@main
82 | with:
83 | slack_channel: ${{ secrets.CI_DOCKER_CHANNEL }}
84 | title: 🤗 Results of the TRL Dev Docker Image build
85 | status: ${{ job.status }}
86 | slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
87 |
--------------------------------------------------------------------------------
/docs/source/use_model.md:
--------------------------------------------------------------------------------
1 | # Use model after training
2 |
3 | Once you have trained a model using either the SFTTrainer, PPOTrainer, or DPOTrainer, you will have a fine-tuned model that can be used for text generation. In this section, we'll walk through the process of loading the fine-tuned model and generating text. If you need to run an inference server with the trained model, you can explore libraries such as [`text-generation-inference`](https://github.com/huggingface/text-generation-inference).
4 |
5 | ## Load and Generate
6 |
If you have fine-tuned a model fully, meaning without the use of PEFT, you can simply load it like any other language model in transformers. E.g. the value head that was trained during the PPO training is no longer needed and if you load the model with the original transformer class it will be ignored:
8 |
9 | ```python
10 | from transformers import AutoTokenizer, AutoModelForCausalLM
11 |
12 | model_name_or_path = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub
13 | device = "cpu" # or "cuda" if you have a GPU
14 |
15 | model = AutoModelForCausalLM.from_pretrained(model_name_or_path).to(device)
16 | tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
17 |
18 | inputs = tokenizer.encode("This movie was really", return_tensors="pt").to(device)
19 | outputs = model.generate(inputs)
20 | print(tokenizer.decode(outputs[0]))
21 | ```
22 |
23 | Alternatively you can also use the pipeline:
24 |
25 | ```python
26 | from transformers import pipeline
27 |
28 | model_name_or_path = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub
29 | pipe = pipeline("text-generation", model=model_name_or_path)
30 | print(pipe("This movie was really")[0]["generated_text"])
31 | ```
32 |
33 | ## Use Adapters PEFT
34 |
35 | ```python
36 | from peft import PeftConfig, PeftModel
37 | from transformers import AutoModelForCausalLM, AutoTokenizer
38 |
39 | base_model_name = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub
40 | adapter_model_name = "path/to/my/adapter"
41 |
42 | model = AutoModelForCausalLM.from_pretrained(base_model_name)
43 | model = PeftModel.from_pretrained(model, adapter_model_name)
44 |
45 | tokenizer = AutoTokenizer.from_pretrained(base_model_name)
46 | ```
47 |
48 | You can also merge the adapters into the base model so you can use the model like a normal transformers model, however the checkpoint will be significantly bigger:
49 |
50 | ```python
51 | model = AutoModelForCausalLM.from_pretrained(base_model_name)
52 | model = PeftModel.from_pretrained(model, adapter_model_name)
53 |
54 | model = model.merge_and_unload()
55 | model.save_pretrained("merged_adapters")
56 | ```
57 |
Once you have the model loaded and have either merged the adapters or kept them separately on top, you can run generation as with a normal model as outlined above.
59 |
--------------------------------------------------------------------------------
/trl/rewards/other_rewards.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from collections.abc import Callable
16 |
17 |
def get_soft_overlong_punishment(max_completion_len: int, soft_punish_cache: int) -> Callable:
    # docstyle-ignore
    r"""
    Reward function that penalizes overlong completions. It is used to penalize overlong completions, but not to reward
    shorter completions. Reference: Eq. (13) from the DAPO paper (https://huggingface.co/papers/2503.14476)

    $$
    R_{\text{length}}(y) = \begin{cases}
    0, & |y| \le L_{\max} - L_{\text{cache}} \\
    \dfrac{(L_{\max} - L_{\text{cache}}) - |y|}{L_{\text{cache}}}, & L_{\max} - L_{\text{cache}} < |y| \le L_{\max} \\
    -1, & L_{\max} < |y|
    \end{cases}
    $$

    Args:
        max_completion_len (`int`):
            Maximum length of the completion, \( L_{\max} \).
        soft_punish_cache (`int`):
            Length of the soft punishment interval, \( L_{\text{cache}} \): the number of final tokens before
            `max_completion_len` over which the penalty ramps linearly from `0` to `-1`. If set to `0`, there is no
            ramp and the reward is `0` up to `max_completion_len` and `-1` beyond it.

    Example:
    ```python
    from trl.rewards import get_soft_overlong_punishment

    soft_overlong_punishment = get_soft_overlong_punishment(max_completion_len=100, soft_punish_cache=20)
    completion_ids = [[1] * 90]  # simulating a completion with 90 tokens. 90 is between 80 and 100.
    rewards = soft_overlong_punishment(completion_ids)
    print(rewards)  # [-0.5]
    ```
    """

    def soft_overlong_punishment_reward(completion_ids: list[list[int]], **kwargs) -> list[float]:
        """Reward function that penalizes overlong completions."""
        rewards = []
        for ids in completion_ids:
            completion_length = len(ids)
            if completion_length <= max_completion_len - soft_punish_cache:
                # Short enough: no penalty.
                rewards.append(0.0)
            elif max_completion_len - soft_punish_cache < completion_length <= max_completion_len:
                # Inside the ramp: linear penalty in (-1, 0]. Unreachable when
                # soft_punish_cache == 0, so no division-by-zero risk.
                rewards.append((max_completion_len - soft_punish_cache - completion_length) / soft_punish_cache)
            else:
                # Longer than the hard limit: full penalty.
                rewards.append(-1.0)
        return rewards

    return soft_overlong_punishment_reward
63 |
--------------------------------------------------------------------------------
/examples/notebooks/README.md:
--------------------------------------------------------------------------------
1 | # Notebooks
2 |
3 | This directory contains a collection of Jupyter notebooks that demonstrate how to use the TRL library in different applications.
4 |
5 | | Notebook | Description | Open in Colab |
6 | | --- | --- | --- |
7 | | [`grpo_agent.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/grpo_agent.ipynb) | GRPO for agent training | Not available due to OOM with Colab GPUs |
8 | | [`grpo_rnj_1_instruct.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/grpo_rnj_1_instruct.ipynb) | GRPO rnj-1-instruct with QLoRA using TRL on Colab to add reasoning capabilities | [](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/grpo_rnj_1_instruct.ipynb) |
9 | | [`sft_ministral3_vl.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/sft_ministral3_vl.ipynb) | Supervised Fine-Tuning (SFT) Ministral 3 with QLoRA using TRL on free Colab | [](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/sft_ministral3_vl.ipynb) |
10 | | [`grpo_ministral3_vl.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/grpo_ministral3_vl.ipynb) | GRPO Ministral 3 with QLoRA using TRL on free Colab | [](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/grpo_ministral3_vl.ipynb) |
| [`openenv_wordle_grpo.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/openenv_wordle_grpo.ipynb) | GRPO to play Wordle on an OpenEnv environment | [](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/openenv_wordle_grpo.ipynb) |
12 | | [`sft_trl_lora_qlora.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/sft_trl_lora_qlora.ipynb) | Supervised Fine-Tuning (SFT) using QLoRA on free Colab | [](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/sft_trl_lora_qlora.ipynb) |
13 | | [`sft_qwen_vl.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/sft_qwen_vl.ipynb) | Supervised Fine-Tuning (SFT) Qwen3-VL with QLoRA using TRL on free Colab | [](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/sft_qwen_vl.ipynb) |
14 | | [`grpo_qwen3_vl.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/grpo_qwen3_vl.ipynb) | GRPO Qwen3-VL with QLoRA using TRL on free Colab | [](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/notebooks/grpo_qwen3_vl.ipynb) |
15 |
--------------------------------------------------------------------------------
/.github/ISSUE_TEMPLATE/bug-report.yml:
--------------------------------------------------------------------------------
1 | name: "\U0001F41B Bug Report"
2 | description: Submit a bug report to help us improve TRL
3 | labels: [ "bug" ]
4 | body:
5 | - type: markdown
6 | attributes:
7 | value: |
8 | Thanks for taking the time to fill out this bug report! 🤗
9 |
10 | 🚩 If it is your first time submitting, be sure to check our [bug report guidelines](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md#did-you-find-a-bug)
11 |
12 | - type: textarea
13 | id: reproduction
14 | validations:
15 | required: true
16 | attributes:
17 | label: Reproduction
18 | description: |
19 | Please provide a code sample that reproduces the problem you ran into. It can be a Colab link or just a code snippet.
20 | If you have code snippets, error messages, stack traces please provide them here as well.
21 | Important! Use code tags to correctly format your code. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting
22 | Do not use screenshots, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code.
23 |
24 | value: |
25 | ```python
26 | from trl import ...
27 |
28 | ```
29 |
30 | outputs:
31 |
32 | ```
33 | Traceback (most recent call last):
34 | File "example.py", line 42, in
35 | ...
36 | ```
37 |
38 | - type: textarea
39 | id: system-info
40 | attributes:
41 | label: System Info
42 | description: |
43 | Please provide information about your system: platform, Python version, PyTorch version, Transformers version, devices, TRL version, ...
44 | You can get this information by running `trl env` in your terminal.
45 |
46 | placeholder: Copy-paste the output of `trl env`
47 | validations:
48 | required: true
49 |
50 | - type: checkboxes
51 | id: terms
52 | attributes:
53 | label: Checklist
54 | description: |
55 | Before submitting, please confirm that you've completed each of the following.
56 | If an item doesn't apply to your issue, check it anyway to show you've reviewed it.
57 | options:
58 | - label: "I have checked that my issue isn't already filed (see [open issues](https://github.com/huggingface/trl/issues?q=is%3Aissue))"
59 | required: true
60 | - label: "I have included my system information"
61 | required: true
62 | - label: "Any code provided is minimal, complete, and reproducible ([more on MREs](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/creating-and-highlighting-code-blocks))"
63 | required: true
64 | - label: "Any code provided is properly formatted in code blocks, (no screenshot, [more on code blocks](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/creating-and-highlighting-code-blocks))"
65 | required: true
66 | - label: "Any traceback provided is complete"
67 | required: true
68 |
--------------------------------------------------------------------------------
/docs/source/judges.md:
--------------------------------------------------------------------------------
1 | # Judges
2 |
3 | > [!WARNING]
4 | > TRL Judges is an experimental API which is subject to change at any time. As of TRL v1.0, judges have been moved to the `trl.experimental.judges` module.
5 |
6 | TRL provides judges to easily compare two completions.
7 |
8 | Make sure to have installed the required dependencies by running:
9 |
10 | ```bash
11 | pip install trl[judges]
12 | ```
13 |
14 | ## Using the provided judges
15 |
16 | TRL provides several judges out of the box. For example, you can use the [`experimental.judges.HfPairwiseJudge`] to compare two completions using a pre-trained model from the Hugging Face model hub:
17 |
18 | ```python
19 | from trl.experimental.judges import HfPairwiseJudge
20 |
21 | judge = HfPairwiseJudge()
22 | judge.judge(
23 | prompts=["What is the capital of France?", "What is the biggest planet in the solar system?"],
24 | completions=[["Paris", "Lyon"], ["Saturn", "Jupiter"]],
25 | ) # Outputs: [0, 1]
26 | ```
27 |
28 | ## Define your own judge
29 |
To define your own judge, we provide several base classes that you can subclass. For rank-based judges, you need to subclass [`experimental.judges.BaseRankJudge`] and implement the [`experimental.judges.BaseRankJudge.judge`] method. For pairwise judges, you need to subclass [`experimental.judges.BasePairwiseJudge`] and implement the [`experimental.judges.BasePairwiseJudge.judge`] method. If you want to define a judge that doesn't fit into these categories, you need to subclass [`experimental.judges.BaseJudge`] and implement the [`experimental.judges.BaseJudge.judge`] method.
31 |
32 | As an example, let's define a pairwise judge that prefers shorter completions:
33 |
34 | ```python
35 | from trl.experimental.judges import BasePairwiseJudge
36 |
37 | class PrefersShorterJudge(BasePairwiseJudge):
38 | def judge(self, prompts, completions, shuffle_order=False):
39 | return [0 if len(completion[0]) > len(completion[1]) else 1 for completion in completions]
40 | ```
41 |
42 | You can then use this judge as follows:
43 |
44 | ```python
45 | judge = PrefersShorterJudge()
46 | judge.judge(
47 | prompts=["What is the capital of France?", "What is the biggest planet in the solar system?"],
48 | completions=[["Paris", "The capital of France is Paris."], ["Jupiter is the biggest planet in the solar system.", "Jupiter"]],
49 | ) # Outputs: [0, 1]
50 | ```
51 |
52 | ## Provided judges
53 |
54 | ### PairRMJudge
55 |
56 | [[autodoc]] trl.experimental.judges.PairRMJudge
57 |
58 | ### HfPairwiseJudge
59 |
60 | [[autodoc]] trl.experimental.judges.HfPairwiseJudge
61 |
62 | ### OpenAIPairwiseJudge
63 |
64 | [[autodoc]] trl.experimental.judges.OpenAIPairwiseJudge
65 |
66 | ### AllTrueJudge
67 |
68 | [[autodoc]] trl.experimental.judges.AllTrueJudge
69 |
70 | ## Base classes
71 |
72 | ### BaseJudge
73 |
74 | [[autodoc]] trl.experimental.judges.BaseJudge
75 |
76 | ### BaseBinaryJudge
77 |
78 | [[autodoc]] trl.experimental.judges.BaseBinaryJudge
79 |
80 | ### BaseRankJudge
81 |
82 | [[autodoc]] trl.experimental.judges.BaseRankJudge
83 |
84 | ### BasePairwiseJudge
85 |
86 | [[autodoc]] trl.experimental.judges.BasePairwiseJudge
87 |
--------------------------------------------------------------------------------
/trl/trainer/base_trainer.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 |
17 | from transformers import Trainer, is_wandb_available
18 |
19 | from .utils import generate_model_card, get_comet_experiment_url, get_config_model_id
20 |
21 |
22 | if is_wandb_available():
23 | import wandb
24 |
25 |
class BaseTrainer(Trainer):
    # Metadata that subclasses override to customize the generated model card.
    _tag_names = []
    _name = "Base"
    _paper = {}
    _template_file = None

    def create_model_card(
        self,
        model_name: str | None = None,
        dataset_name: str | None = None,
        tags: str | list[str] | None = None,
    ):
        """
        Creates a draft of a model card using the information available to the `Trainer`.

        Args:
            model_name (`str`, *optional*):
                Name of the model.
            dataset_name (`str`, *optional*):
                Name of the dataset used for training.
            tags (`str`, `list[str]`, *optional*):
                Tags to be associated with the model card.
        """
        # Only the main process writes the model card.
        if not self.is_world_process_zero():
            return

        # Only a hub model id (not a local directory path) is reported as the base model.
        model_name_or_path = get_config_model_id(self.model.config)
        if model_name_or_path and not os.path.isdir(model_name_or_path):
            base_model = model_name_or_path
        else:
            base_model = None

        # Collect tags into a set to deduplicate user-supplied and automatic tags.
        if tags is None:
            tag_set = set()
        elif isinstance(tags, str):
            tag_set = {tags}
        else:
            tag_set = set(tags)
        if hasattr(self.model.config, "unsloth_version"):
            tag_set.add("unsloth")
        if "JOB_ID" in os.environ:
            tag_set.add("hf_jobs")
        tag_set.update(self._tag_names)

        # Link the wandb run only when wandb is installed and a run is active.
        wandb_url = None
        if is_wandb_available() and wandb.run is not None:
            wandb_url = wandb.run.url

        model_card = generate_model_card(
            base_model=base_model,
            model_name=model_name,
            hub_model_id=self.hub_model_id,
            dataset_name=dataset_name,
            tags=list(tag_set),
            wandb_url=wandb_url,
            comet_url=get_comet_experiment_url(),
            trainer_name=self._name,
            trainer_citation=self._paper.get("citation"),
            template_file=self._template_file,
            paper_title=self._paper.get("title"),
            paper_id=self._paper.get("id"),
        )
        model_card.save(os.path.join(self.args.output_dir, "README.md"))
87 |
--------------------------------------------------------------------------------
/.github/workflows/slow-tests.yml:
--------------------------------------------------------------------------------
1 | name: Slow tests (on push)
2 |
3 | on:
4 | push:
5 | branches: [main]
6 | paths:
7 | # Run only when python files are modified
8 | - "trl/**.py"
9 | - "examples/**.py"
10 | env:
11 | RUN_SLOW: "yes"
12 | IS_GITHUB_CI: "1"
13 | SLACK_API_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }}
14 | TRL_EXPERIMENTAL_SILENCE: 1
15 |
16 | jobs:
17 | run_all_tests_single_gpu:
18 | runs-on:
19 | group: aws-g4dn-2xlarge
20 | env:
21 | CUDA_VISIBLE_DEVICES: "0"
22 | TEST_TYPE: "single_gpu"
23 | container:
24 | image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
25 | options: --gpus all --shm-size "16gb"
26 | defaults:
27 | run:
28 | shell: bash
29 | steps:
30 | - name: Git checkout
31 | uses: actions/checkout@v4
32 |
33 | - name: Install system dependencies
34 | run: |
35 | apt-get update && apt-get install -y make git curl
36 |
37 | - name: Install uv
38 | run: |
39 | curl -LsSf https://astral.sh/uv/install.sh | sh
40 |
41 | - name: Create Python virtual environment
42 | run: |
43 | uv venv
44 | uv pip install --upgrade setuptools wheel
45 |
46 | - name: Install dependencies
47 | run: |
48 | source .venv/bin/activate
49 | uv pip install ".[dev]"
50 | uv pip install pytest-reportlog
51 |
52 | - name: Run slow SFT tests on single GPU
53 | if: always()
54 | run: |
55 | source .venv/bin/activate
56 | make slow_tests
57 |
58 | - name: Generate Report
59 | if: always()
60 | run: |
61 | source .venv/bin/activate
62 | uv pip install slack_sdk tabulate
63 | python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
64 |
65 | run_all_tests_multi_gpu:
66 | runs-on:
67 | group: aws-g4dn-2xlarge
68 | env:
69 | CUDA_VISIBLE_DEVICES: "0,1"
70 | TEST_TYPE: "multi_gpu"
71 | container:
72 | image: pytorch/pytorch:2.8.0-cuda12.8-cudnn9-devel
73 | options: --gpus all --shm-size "16gb"
74 | defaults:
75 | run:
76 | shell: bash
77 | steps:
78 | - name: Git checkout
79 | uses: actions/checkout@v4
80 |
81 | - name: Install system dependencies
82 | run: |
83 | apt-get update && apt-get install -y make git curl
84 |
85 | - name: Install uv
86 | run: |
87 | curl -LsSf https://astral.sh/uv/install.sh | sh
88 |
89 | - name: Create Python virtual environment
90 | run: |
91 | uv venv
92 | uv pip install --upgrade setuptools wheel
93 |
94 | - name: Install dependencies
95 | run: |
96 | source .venv/bin/activate
97 | uv pip install ".[dev]"
98 | uv pip install pytest-reportlog
99 |
100 | - name: Run slow SFT tests on Multi GPU
101 | if: always()
102 | run: |
103 | source .venv/bin/activate
104 | make slow_tests
105 |
106 | - name: Generate Reports
107 | if: always()
108 | run: |
109 | source .venv/bin/activate
110 | uv pip install slack_sdk tabulate
111 | python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY
112 | rm *.txt
--------------------------------------------------------------------------------
/trl/experimental/papo/papo_config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from dataclasses import dataclass
16 | from typing import Literal
17 |
18 | from ...trainer.grpo_config import GRPOConfig
19 |
20 |
@dataclass
class PAPOConfig(GRPOConfig):
    """
    Configuration class for PAPOTrainer.

    PAPO (Perception-Aware Policy Optimization) builds on GRPO/DAPO for multimodal reasoning. On top of the base
    objective it adds an implicit perception loss (computed against masked images) and a double entropy
    regularization term.

    Args:
        perception_loss_weight (`float`, *optional*, defaults to `0.1`):
            Weight (gamma) of the perception-loss term. This encourages the model to be sensitive to visual
            changes.
        mask_ratio (`float`, *optional*, defaults to `0.3`):
            Fraction of the image to mask when computing the perception loss. Must lie in `[0, 1]`.
        mask_type (`Literal["random", "patch", "grid"]`, *optional*, defaults to `"random"`):
            Masking strategy applied to the image.
        der_loss_weight1 (`float`, *optional*, defaults to `0.03`):
            Weight (eta1) of the first Double Entropy Regularization (DER) term. DER encourages confident
            predictions with original images (low entropy) and uncertain predictions with masked images (high
            entropy).
        der_loss_weight2 (`float`, *optional*, defaults to `0.03`):
            Weight (eta2) of the second Double Entropy Regularization (DER) term.
        loss_type (`Literal["grpo", "dapo"]`, inherited from GRPOConfig):
            Base loss type to use. Set to "grpo" for PAPO-G or "dapo" for PAPO-D.
    """

    perception_loss_weight: float = 0.1
    mask_ratio: float = 0.3
    mask_type: Literal["random", "patch", "grid"] = "random"

    # Double Entropy Regularization coefficients
    der_loss_weight1: float = 0.03
    der_loss_weight2: float = 0.03

    def __post_init__(self):
        # Let GRPOConfig resolve and validate its own fields first.
        super().__post_init__()

        # Validate the PAPO-specific fields in declaration order.
        if not (0.0 <= self.mask_ratio <= 1.0):
            raise ValueError(f"mask_ratio must be between 0 and 1, got {self.mask_ratio}")

        if min(self.der_loss_weight1, self.der_loss_weight2) < 0:
            raise ValueError(
                f"der_loss_weight1 and der_loss_weight2 must be non-negative, got {self.der_loss_weight1} and {self.der_loss_weight2}"
            )

        if self.mask_type not in ("random", "patch", "grid"):
            raise ValueError(f"mask_type must be one of ['random', 'patch', 'grid'], got {self.mask_type}")
74 |
--------------------------------------------------------------------------------
/tests/experimental/test_merge_model_callback.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | import os
17 |
18 | from datasets import load_dataset
19 | from transformers import AutoModelForCausalLM, AutoTokenizer
20 | from transformers.trainer_utils import get_last_checkpoint
21 |
22 | from trl import DPOConfig, DPOTrainer
23 | from trl.experimental.merge_model_callback import MergeConfig, MergeModelCallback
24 |
25 | from ..testing_utils import TrlTestCase, require_mergekit
26 |
27 |
@require_mergekit
class TestMergeModelCallback(TrlTestCase):
    """Tests for `MergeModelCallback`: after training, a merged model must exist under the checkpoint dir(s)."""

    def setup_method(self):
        # Tiny model/tokenizer keep the DPO run fast; the "zen" preference split matches DPOTrainer's format.
        self.model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
        self.dataset = load_dataset("trl-internal-testing/zen", "standard_preference", split="train")

    def _train_with_merge_callback(self, **callback_kwargs):
        """Run a one-epoch DPO training (checkpoint every step) with a MergeModelCallback.

        Shared by both tests; `callback_kwargs` are forwarded to `MergeModelCallback`.
        """
        training_args = DPOConfig(
            output_dir=self.tmp_dir,
            num_train_epochs=1,
            report_to="none",
            save_strategy="steps",
            save_steps=1,
        )
        merge_callback = MergeModelCallback(MergeConfig(), **callback_kwargs)
        trainer = DPOTrainer(
            model=self.model,
            args=training_args,
            train_dataset=self.dataset,
            processing_class=self.tokenizer,
            callbacks=[merge_callback],
        )
        trainer.train()

    def test_callback(self):
        # Default behavior: merge only at the last checkpoint.
        self._train_with_merge_callback()
        last_checkpoint = get_last_checkpoint(self.tmp_dir)
        merged_path = os.path.join(last_checkpoint, "merged")
        assert os.path.isdir(merged_path), "Merged folder does not exist in the last checkpoint."

    def test_every_checkpoint(self):
        # With merge_at_every_checkpoint=True, every checkpoint must contain a "merged" folder.
        self._train_with_merge_callback(merge_at_every_checkpoint=True)

        checkpoints = sorted(
            os.path.join(self.tmp_dir, cp) for cp in os.listdir(self.tmp_dir) if cp.startswith("checkpoint-")
        )

        for checkpoint in checkpoints:
            merged_path = os.path.join(checkpoint, "merged")
            assert os.path.isdir(merged_path), f"Merged folder does not exist in checkpoint {checkpoint}."
83 |
--------------------------------------------------------------------------------
/tests/test_collators.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import torch
16 |
17 | from trl.trainer.dpo_trainer import DataCollatorForPreference
18 |
19 | from .testing_utils import TrlTestCase
20 |
21 |
class TestDataCollatorForPreference(TrlTestCase):
    """Unit tests for `DataCollatorForPreference` padding and optional-field passthrough."""

    def setup_method(self):
        # Pad with token id 0 so the expected tensors below are easy to read.
        self.collator = DataCollatorForPreference(pad_token_id=0)

    def assertTensorEqual(self, tensor1, tensor2):
        assert torch.equal(tensor1, tensor2), f"Tensors are not equal:\n{tensor1}\n{tensor2}"

    def test_padding_behavior(self):
        examples = [
            {"prompt_input_ids": [1, 2, 3], "chosen_input_ids": [4, 5], "rejected_input_ids": [6]},
            {"prompt_input_ids": [7, 8], "chosen_input_ids": [9, 10], "rejected_input_ids": [11, 12, 13]},
        ]
        output = self.collator.torch_call(examples)

        # Prompts are left-padded; chosen/rejected completions are right-padded.
        expected = {
            "prompt_input_ids": torch.tensor([[1, 2, 3], [0, 7, 8]]),
            "prompt_attention_mask": torch.tensor([[1, 1, 1], [0, 1, 1]]),
            "chosen_input_ids": torch.tensor([[4, 5], [9, 10]]),
            "chosen_attention_mask": torch.tensor([[1, 1], [1, 1]]),
            "rejected_input_ids": torch.tensor([[6, 0, 0], [11, 12, 13]]),
            "rejected_attention_mask": torch.tensor([[1, 0, 0], [1, 1, 1]]),
        }
        for key, expected_tensor in expected.items():
            self.assertTensorEqual(output[key], expected_tensor)

    def test_optional_fields(self):
        # Optional keys such as pixel_values must be stacked across examples when present.
        pixel_values_first = [[[0.1, 0.2], [0.3, 0.4]]]  # example 3D tensor (1x2x2)
        pixel_values_second = [[[0.5, 0.6], [0.7, 0.8]]]  # example 3D tensor (1x2x2)
        examples = [
            {
                "prompt_input_ids": [1],
                "chosen_input_ids": [2],
                "rejected_input_ids": [3],
                "pixel_values": pixel_values_first,
            },
            {
                "prompt_input_ids": [4],
                "chosen_input_ids": [5],
                "rejected_input_ids": [6],
                "pixel_values": pixel_values_second,
            },
        ]
        output = self.collator.torch_call(examples)

        # Stacked along a new batch dimension -> shape (2, 1, 2, 2).
        expected_pixel_values = torch.tensor([pixel_values_first, pixel_values_second])

        self.assertTensorEqual(output["pixel_values"], expected_pixel_values)
75 |
--------------------------------------------------------------------------------
/examples/datasets/deepmath_103k.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | from dataclasses import dataclass, field
16 |
17 | from datasets import load_dataset
18 | from huggingface_hub import ModelCard
19 | from transformers import HfArgumentParser
20 |
21 |
22 | @dataclass
23 | class ScriptArguments:
24 | r"""
25 | Arguments for the script.
26 |
27 | Args:
28 | push_to_hub (`bool`, *optional*, defaults to `False`):
29 | Whether to push the dataset to the Hugging Face Hub.
30 | repo_id (`str`, *optional*, defaults to `"trl-lib/DeepMath-103K"`):
31 | Hugging Face repository ID to push the dataset to.
32 | dataset_num_proc (`int`, *optional*):
33 | Number of workers to use for dataset processing.
34 | """
35 |
36 | push_to_hub: bool = field(
37 | default=False,
38 | metadata={"help": "Whether to push the dataset to the Hugging Face Hub."},
39 | )
40 | repo_id: str = field(
41 | default="trl-lib/DeepMath-103K",
42 | metadata={"help": "Hugging Face repository ID to push the dataset to."},
43 | )
44 | dataset_num_proc: int | None = field(
45 | default=None,
46 | metadata={"help": "Number of workers to use for dataset processing."},
47 | )
48 |
49 |
def process_example(example):
    """Convert a raw DeepMath row into a prompt-only conversational record.

    Boolean-like answers are kept verbatim; any other answer is wrapped in `$...$` so it renders as math.
    """
    answer = example["final_answer"]
    wrapped = answer if answer in ("True", "False", "Yes", "No") else f"${answer}$"
    return {"prompt": [{"role": "user", "content": example["question"]}], "solution": wrapped}
56 |
57 |
# Dataset card pushed alongside the processed dataset; the string is the literal card content (YAML front
# matter + markdown), so it must not be reformatted.
model_card = ModelCard("""
---
tags: [trl]
---

# DeepMath-103K Dataset

## Summary

[DeepMath-103K](https://huggingface.co/datasets/zwhe99/DeepMath-103K) is meticulously curated to push the boundaries of mathematical reasoning in language models.

## Data Structure

- **Format**: [Conversational](https://huggingface.co/docs/trl/main/dataset_formats#conversational)
- **Type**: [Prompt-only](https://huggingface.co/docs/trl/main/dataset_formats#prompt-only)

Column:
- `"prompt"`: The input question.
- `"solution"`: The solution to the math problem.

## Generation script

The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/deepmath_103k.py).
""")
82 |
if __name__ == "__main__":
    parser = HfArgumentParser(ScriptArguments)
    script_args = parser.parse_args_into_dataclasses()[0]

    # Source dataset: DeepMath-103K train split from the Hub.
    dataset = load_dataset("zwhe99/DeepMath-103K", split="train")

    # Reformat every row into the prompt-only conversational schema; drop all original columns.
    dataset = dataset.map(
        process_example,
        remove_columns=dataset.column_names,
        num_proc=script_args.dataset_num_proc,
    )
    # Hold out 5% as a test split; fixed seed keeps the split reproducible across runs.
    dataset = dataset.train_test_split(test_size=0.05, seed=42)

    if script_args.push_to_hub:
        dataset.push_to_hub(script_args.repo_id)
        model_card.push_to_hub(script_args.repo_id, repo_type="dataset")
99 |
--------------------------------------------------------------------------------
/scripts/add_copyrights.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import os
16 | import subprocess
17 | import sys
18 | from datetime import datetime
19 |
20 |
# Exact header every tracked .py file must start with. Note the f-string: the year range ends at the current
# year, so the expected header changes at year rollover and previously-compliant files will be flagged again.
COPYRIGHT_HEADER = f"""# Copyright 2020-{datetime.now().year} The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
35 |
36 |
def get_tracked_python_files():
    """Return the git-tracked Python files in the repository, or an empty list if git fails."""
    try:
        # One path per line on stdout; check=True raises if git exits non-zero (e.g. not a repo).
        proc = subprocess.run(["git", "ls-files"], stdout=subprocess.PIPE, text=True, check=True)
    except subprocess.CalledProcessError as e:
        print(f"Error fetching tracked files: {e}")
        return []
    # Keep only Python sources.
    return [path for path in proc.stdout.splitlines() if path.endswith(".py")]
50 |
51 |
def check_and_add_copyright(file_path):
    """Ensure `file_path` starts with the expected copyright header, prepending it if missing.

    Returns:
        `True` if the file is compliant (already has the header, or does not exist and is skipped), `False` if
        the header was missing and has been added.
    """
    if not os.path.isfile(file_path):
        # A tracked-but-absent file is not a compliance failure. Previously this path returned None, which
        # main()'s all() treated as "missing copyright" and caused a misleading exit(1).
        print(f"[SKIP] {file_path} does not exist.")
        return True

    with open(file_path, encoding="utf-8") as f:
        content = f.readlines()

    # Compliant only when the file begins with the exact header text.
    if "".join(content).startswith(COPYRIGHT_HEADER):
        return True

    # Prepend the header, separated from the original content by one blank line.
    print(f"[MODIFY] Adding copyright to {file_path}.")
    with open(file_path, "w", encoding="utf-8") as f:
        f.write(COPYRIGHT_HEADER + "\n" + "".join(content))
    return False
71 |
72 |
def main():
    """Check every tracked Python file for the copyright header; exit 1 if any had to be fixed."""
    tracked_files = get_tracked_python_files()
    if not tracked_files:
        print("No Python files are tracked in the repository.")
        return

    print(f"Checking {len(tracked_files)} Python files for copyright notice...")

    # One boolean per file: True = already compliant, False = header was added.
    results = [check_and_add_copyright(path) for path in tracked_files]
    if all(results):
        print("✅ All files have the required copyright.")
        sys.exit(0)
    print("❌ Some files were missing the required copyright and have been updated.")
    sys.exit(1)
89 |
90 |
# Entry point: exits non-zero when any header had to be added, so this doubles as a CI check.
if __name__ == "__main__":
    main()
93 |
--------------------------------------------------------------------------------
/examples/scripts/sft_gpt_oss.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
# /// script
# dependencies = [
#     "trl",
#     "kernels",
#     "trackio",
# ]
# ///
23 |
24 | """
pip install --upgrade kernels
26 |
27 | Example:
28 |
29 | accelerate launch \
30 | --config_file examples/accelerate_configs/deepspeed_zero3.yaml \
31 | examples/scripts/sft_gpt_oss.py \
32 | --dtype bfloat16 \
33 | --model_name_or_path openai/gpt-oss-20b \
34 | --packing \
35 | --run_name 20b-full-eager \
36 | --attn_implementation kernels-community/vllm-flash-attn3 \
37 | --dataset_num_proc 12 \
38 | --dataset_name HuggingFaceH4/Multilingual-Thinking \
39 | --gradient_checkpointing \
40 | --max_length 4096 \
41 | --per_device_train_batch_size 2 \
42 | --num_train_epochs 1 \
43 | --logging_steps 1 \
44 | --warmup_ratio 0.03 \
45 | --lr_scheduler_type cosine_with_min_lr \
46 | --lr_scheduler_kwargs '{"min_lr_rate": 0.1}' \
47 | --output_dir gpt-oss-20b-multilingual-reasoner \
48 | --report_to trackio \
49 | --seed 42
50 | """
51 |
52 | import os
53 |
54 | from datasets import load_dataset
55 | from transformers import AutoModelForCausalLM, Mxfp4Config
56 |
57 | from trl import ModelConfig, ScriptArguments, SFTConfig, SFTTrainer, TrlParser, get_peft_config
58 |
59 |
# Enable logging in a Hugging Face Space (setdefault: only applies if the variable is not already set)
os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")
62 |
63 |
def main(script_args, training_args, model_args):
    """Fine-tune a gpt-oss model with SFT, dequantizing the MXFP4 checkpoint on load."""
    # Model: dequantize MXFP4 weights so the model can be trained in the requested dtype.
    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        dtype=model_args.dtype,
        # The KV cache is useless during training and incompatible with gradient checkpointing.
        use_cache=False if training_args.gradient_checkpointing else True,
        quantization_config=Mxfp4Config(dequantize=True),
    )

    # Dataset
    dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config)
    eval_split = dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None

    # Train model
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=dataset[script_args.dataset_train_split],
        eval_dataset=eval_split,
        peft_config=get_peft_config(model_args),
    )
    trainer.train()

    trainer.save_model(training_args.output_dir)
    if training_args.push_to_hub:
        trainer.push_to_hub(dataset_name=script_args.dataset_name)
94 |
95 |
if __name__ == "__main__":
    parser = TrlParser((ScriptArguments, SFTConfig, ModelConfig))
    # return_remaining_strings=True: unknown CLI args are returned (and ignored) instead of raising.
    script_args, training_args, model_args, _ = parser.parse_args_and_config(return_remaining_strings=True)
    main(script_args, training_args, model_args)
100 |
--------------------------------------------------------------------------------
/trl/extras/profiling.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import contextlib
16 | import functools
17 | import time
18 | from collections.abc import Callable, Generator
19 |
20 | from transformers import Trainer
21 | from transformers.integrations import is_mlflow_available, is_wandb_available
22 |
23 |
24 | if is_wandb_available():
25 | import wandb
26 |
27 | if is_mlflow_available():
28 | import mlflow
29 |
30 |
@contextlib.contextmanager
def profiling_context(trainer: Trainer, name: str) -> Generator[None, None, None]:
    """
    A context manager function for profiling a block of code. Results are logged to Weights & Biases or MLflow
    depending on the trainer's configuration.

    Args:
        trainer (`~transformers.Trainer`):
            Trainer object.
        name (`str`):
            Name of the block to be profiled. Used as a key in the logged dictionary.

    Example:
    ```python
    from transformers import Trainer
    from trl.extras.profiling import profiling_context


    class MyTrainer(Trainer):
        def some_method(self):
            A = np.random.rand(1000, 1000)
            B = np.random.rand(1000, 1000)
            with profiling_context(self, "matrix_multiplication"):
                # Code to profile: simulate a computationally expensive operation
                result = A @ B  # Matrix multiplication
    ```
    """
    start_time = time.perf_counter()
    yield
    end_time = time.perf_counter()
    duration = end_time - start_time

    # Only the main process logs, to avoid duplicate metrics in distributed runs.
    profiling_metrics = {f"profiling/Time taken: {trainer.__class__.__name__}.{name}": duration}
    if "wandb" in trainer.args.report_to and wandb.run is not None and trainer.accelerator.is_main_process:
        wandb.log(profiling_metrics)

    # Fix: `mlflow.run` is the MLflow Projects entry point (a function) and is therefore never None, making the
    # previous check vacuous; the counterpart of `wandb.run` is `mlflow.active_run()`.
    if "mlflow" in trainer.args.report_to and mlflow.active_run() is not None and trainer.accelerator.is_main_process:
        mlflow.log_metrics(profiling_metrics, step=trainer.state.global_step)
69 |
70 |
def profiling_decorator(func: Callable) -> Callable:
    """
    Decorator to profile a function and log execution time using [`extras.profiling.profiling_context`].

    Args:
        func (`Callable`):
            Function to be profiled.

    Example:
    ```python
    from transformers import Trainer
    from trl.extras.profiling import profiling_decorator


    class MyTrainer(Trainer):
        @profiling_decorator
        def some_method(self):
            A = np.random.rand(1000, 1000)
            B = np.random.rand(1000, 1000)
            # Code to profile: simulate a computationally expensive operation
            result = A @ B
    ```
    """

    @functools.wraps(func)
    def timed_method(self, *args, **kwargs):
        # Delegate to the context manager so decorator and `with` usage share a single code path.
        with profiling_context(self, func.__name__):
            return func(self, *args, **kwargs)

    return timed_method
101 |
--------------------------------------------------------------------------------
/examples/scripts/rloo.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # /// script
16 | # dependencies = [
17 | # "trl[vllm]",
18 | # "peft",
19 | # "math-verify",
20 | # "latex2sympy2_extended",
21 | # "trackio",
22 | # "kernels",
23 | # ]
24 | # ///
25 |
26 | """
27 | pip install math_verify num2words==0.5.14 peft trackio vllm
28 | export TRACKIO_PROJECT="RLOO-NuminaMath-TIR"
29 | accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml examples/scripts/rloo.py
30 | """
31 |
32 | import os
33 |
34 | import torch
35 | from datasets import load_dataset
36 | from peft import LoraConfig
37 |
38 | from trl import RLOOConfig, RLOOTrainer
39 | from trl.rewards import accuracy_reward, think_format_reward
40 |
41 |
# Enable logging in a Hugging Face Space (setdefault: only applies if the variable is not already set)
os.environ.setdefault("TRACKIO_SPACE_ID", "trl-trackio")
44 |
45 |
def main():
    """Train Qwen3-0.6B with RLOO + LoRA on a 5% slice of NuminaMath-TIR, generating with colocated vLLM."""
    # Dataset: small train/eval slices keep the example cheap to run.
    train_dataset, eval_dataset = load_dataset("AI-MO/NuminaMath-TIR", split=["train[:5%]", "test[:5%]"])

    SYSTEM_PROMPT = (
        "A conversation between user and assistant. The user asks a question, and the assistant solves it. The "
        "assistant first thinks about the reasoning process in the mind and then provides the user with the answer. "
        "The reasoning process and answer are enclosed within tags, i.e., \nThis is my "
        "reasoning.\n\nThis is my answer."
    )

    def make_conversation(example):
        # Prompt-only conversational format: system instruction followed by the math problem as the user turn.
        prompt = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": example["problem"]},
        ]
        return {"prompt": prompt}

    train_dataset = train_dataset.map(make_conversation, remove_columns=["messages", "problem"])
    eval_dataset = eval_dataset.map(make_conversation, remove_columns=["messages", "problem"])

    # Training configuration: bf16 model, colocated vLLM sharing half the GPU memory with training.
    training_args = RLOOConfig(
        output_dir="Qwen3-0.6B-RLOO",
        model_init_kwargs={"dtype": torch.bfloat16},
        learning_rate=1e-5,
        gradient_checkpointing_kwargs={"use_reentrant": False},
        log_completions=True,
        num_completions_to_print=2,
        max_prompt_length=2048,
        max_completion_length=1024,
        gradient_accumulation_steps=2,
        steps_per_generation=8,
        use_vllm=True,
        vllm_mode="colocate",
        vllm_gpu_memory_utilization=0.5,
        run_name="Qwen3-0.6B-RLOO-NuminaMath-TIR",
    )

    # Two rewards: correct <think> formatting and answer accuracy.
    trainer = RLOOTrainer(
        model="Qwen/Qwen3-0.6B",
        args=training_args,
        reward_funcs=[think_format_reward, accuracy_reward],
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        peft_config=LoraConfig(),
    )
    trainer.train()

    # Save and push to hub
    trainer.save_model(training_args.output_dir)
    trainer.push_to_hub(dataset_name="AI-MO/NuminaMath-TIR")
100 |
101 |
# Script entry point.
if __name__ == "__main__":
    main()
104 |
--------------------------------------------------------------------------------
/docs/source/_toctree.yml:
--------------------------------------------------------------------------------
1 | - sections:
2 | - local: index
3 | title: TRL
4 | - local: installation
5 | title: Installation
6 | - local: quickstart
7 | title: Quickstart
8 | title: Getting started
9 | - sections:
10 | - local: dataset_formats
11 | title: Dataset Formats
12 | - local: paper_index
13 | title: Paper Index
14 | title: Conceptual Guides
15 | - sections: # Sorted alphabetically
16 | - local: dpo_trainer
17 | title: DPO
18 | - local: grpo_trainer
19 | title: GRPO
20 | - local: reward_trainer
21 | title: Reward
22 | - local: rloo_trainer
23 | title: RLOO
24 | - local: sft_trainer
25 | title: SFT
26 | title: Trainers
27 | - sections:
28 | - local: clis
29 | title: Command Line Interface (CLI)
30 | - local: jobs_training
31 | title: Training using Jobs
32 | - local: customization
33 | title: Customizing the Training
34 | - local: reducing_memory_usage
35 | title: Reducing Memory Usage
36 | - local: speeding_up_training
37 | title: Speeding Up Training
38 | - local: distributing_training
39 | title: Distributing Training
40 | - local: use_model
41 | title: Using Trained Models
42 | title: How-to guides
43 | - sections:
44 | - local: deepspeed_integration
45 | title: DeepSpeed
46 | - local: kernels_hub
47 | title: Kernels Hub
48 | - local: liger_kernel_integration
49 | title: Liger Kernel
50 | - local: peft_integration
51 | title: PEFT
52 | - local: rapidfire_integration
53 | title: RapidFire AI
54 | - local: trackio_integration
55 | title: Trackio
56 | - local: unsloth_integration
57 | title: Unsloth
58 | - local: vllm_integration
59 | title: vLLM
60 | title: Integrations
61 | - sections:
62 | - local: example_overview
63 | title: Example Overview
64 | - local: community_tutorials
65 | title: Community Tutorials
66 | - local: lora_without_regret
67 | title: LoRA Without Regret
68 | title: Examples
69 | - sections:
70 | - sections:
71 | - local: chat_template_utils
72 | title: Chat Template Utilities
73 | - local: data_utils
74 | title: Data Utilities
75 | - local: model_utils
76 | title: Model Utilities
77 | - local: script_utils
78 | title: Script Utilities
79 | title: Utilities
80 | - local: callbacks
81 | title: Callbacks
82 | - local: rewards
83 | title: Reward Functions
84 | - local: others
85 | title: Others
86 | title: API
87 | - sections:
88 | - local: experimental_overview
89 | title: Experimental Overview
90 | - local: openenv
91 | title: OpenEnv Integration
92 | - local: bema_for_reference_model # Sorted alphabetically
93 | title: BEMA for Reference Model
94 | - local: bco_trainer
95 | title: BCO
96 | - local: cpo_trainer
97 | title: CPO
98 | - local: gfpo
99 | title: GFPO
100 | - local: gkd_trainer
101 | title: GKD
102 | - local: gold_trainer
103 | title: GOLD
104 | - local: grpo_with_replay_buffer
105 | title: GRPO With Replay Buffer
106 | - local: gspo_token
107 | title: GSPO-token
108 | - local: judges
109 | title: Judges
110 | - local: kto_trainer
111 | title: KTO
112 | - local: merge_model_callback
113 | title: MergeModelCallback
114 | - local: minillm_trainer
115 | title: MiniLLM
116 | - local: nash_md_trainer
117 | title: Nash-MD
118 | - local: online_dpo_trainer
119 | title: Online DPO
120 | - local: orpo_trainer
121 | title: ORPO
122 | - local: papo_trainer
123 | title: PAPO
124 | - local: ppo_trainer
125 | title: PPO
126 | - local: prm_trainer
127 | title: PRM
128 | - local: winrate_callback
129 | title: WinRateCallback
130 | - local: xpo_trainer
131 | title: XPO
132 | title: Experimental
--------------------------------------------------------------------------------
/trl/scripts/env.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | # /// script
16 | # dependencies = [
17 | # "trl",
18 | # ]
19 | # ///
20 |
21 | import os
22 | import platform
23 | from importlib.metadata import version
24 |
25 | import torch
26 | from accelerate.commands.config import default_config_file, load_config_from_file
27 | from transformers import is_bitsandbytes_available
28 | from transformers.utils import is_openai_available, is_peft_available
29 |
30 | from trl import __version__
31 | from trl.import_utils import (
32 | is_deepspeed_available,
33 | is_liger_kernel_available,
34 | is_llm_blender_available,
35 | is_vllm_available,
36 | )
37 | from trl.scripts.utils import get_git_commit_hash
38 |
39 |
40 | def print_env():
41 | devices = None
42 | if torch.cuda.is_available():
43 | devices = [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())]
44 | elif torch.backends.mps.is_available():
45 | devices = ["MPS"]
46 | elif torch.xpu.is_available():
47 | devices = [torch.xpu.get_device_name(i) for i in range(torch.xpu.device_count())]
48 |
49 | accelerate_config = accelerate_config_str = "not found"
50 |
51 | # Get the default from the config file.
52 | if os.path.isfile(default_config_file):
53 | accelerate_config = load_config_from_file(default_config_file).to_dict()
54 |
55 | accelerate_config_str = (
56 | "\n" + "\n".join([f" - {prop}: {val}" for prop, val in accelerate_config.items()])
57 | if isinstance(accelerate_config, dict)
58 | else accelerate_config
59 | )
60 |
61 | commit_hash = get_git_commit_hash("trl")
62 |
63 | info = {
64 | "Platform": platform.platform(),
65 | "Python version": platform.python_version(),
66 | "TRL version": f"{__version__}+{commit_hash[:7]}" if commit_hash else __version__,
67 | "PyTorch version": version("torch"),
68 | "accelerator(s)": ", ".join(devices) if devices is not None else "cpu",
69 | "Transformers version": version("transformers"),
70 | "Accelerate version": version("accelerate"),
71 | "Accelerate config": accelerate_config_str,
72 | "Datasets version": version("datasets"),
73 | "HF Hub version": version("huggingface_hub"),
74 | "bitsandbytes version": version("bitsandbytes") if is_bitsandbytes_available() else "not installed",
75 | "DeepSpeed version": version("deepspeed") if is_deepspeed_available() else "not installed",
76 | "Liger-Kernel version": version("liger_kernel") if is_liger_kernel_available() else "not installed",
77 | "LLM-Blender version": version("llm_blender") if is_llm_blender_available() else "not installed",
78 | "OpenAI version": version("openai") if is_openai_available() else "not installed",
79 | "PEFT version": version("peft") if is_peft_available() else "not installed",
80 | "vLLM version": version("vllm") if is_vllm_available() else "not installed",
81 | }
82 |
83 | info_str = "\n".join([f"- {prop}: {val}" for prop, val in info.items()])
84 | print(f"\nCopy-paste the following information when reporting an issue:\n\n{info_str}\n") # noqa
85 |
86 |
87 | # Allow running this module directly as a script (e.g. `python env.py`).
88 | if __name__ == "__main__":
89 |     print_env()
89 |
--------------------------------------------------------------------------------
/tests/experimental/test_judges.py:
--------------------------------------------------------------------------------
1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import sys
16 | import time
17 |
18 | import pytest
19 |
20 | from trl.experimental.judges import AllTrueJudge, HfPairwiseJudge, PairRMJudge
21 |
22 | from ..testing_utils import RandomBinaryJudge, TrlTestCase, require_llm_blender
23 |
24 |
25 | class TestJudges(TrlTestCase):
26 | def _get_prompts_and_pairwise_completions(self):
27 | prompts = ["The capital of France is", "The biggest planet in the solar system is"]
28 | completions = [["Paris", "Marseille"], ["Saturn", "Jupiter"]]
29 | return prompts, completions
30 |
31 | def _get_prompts_and_single_completions(self):
32 | prompts = ["What's the capital of France?", "What's the color of the sky?"]
33 | completions = ["Marseille", "blue"]
34 | return prompts, completions
35 |
36 | def test_all_true_judge(self):
37 | judge = AllTrueJudge(judges=[RandomBinaryJudge(), RandomBinaryJudge()])
38 | prompts, completions = self._get_prompts_and_single_completions()
39 | judgements = judge.judge(prompts=prompts, completions=completions)
40 | assert len(judgements) == 2
41 | assert all(judgement in {0, 1, -1} for judgement in judgements)
42 |
43 | @pytest.mark.skip(reason="This test needs to be run manually since it requires a valid Hugging Face API key.")
44 | def test_hugging_face_judge(self):
45 | judge = HfPairwiseJudge()
46 | prompts, completions = self._get_prompts_and_pairwise_completions()
47 | ranks = judge.judge(prompts=prompts, completions=completions)
48 | assert len(ranks) == 2
49 | assert all(isinstance(rank, int) for rank in ranks)
50 | assert ranks == [0, 1]
51 |
52 | def load_pair_rm_judge(self):
53 | # When using concurrent tests, PairRM may fail to load the model while another job is still downloading.
54 | # This is a workaround to retry loading the model a few times.
55 | for _ in range(5):
56 | try:
57 | return PairRMJudge()
58 | except ValueError:
59 | time.sleep(5)
60 | raise ValueError("Failed to load PairRMJudge")
61 |
62 | @require_llm_blender
63 | @pytest.mark.skipif(
64 | sys.version_info[:3] == (3, 13, 8), reason="Python 3.13.8 has a bug in inspect.BlockFinder (cpython GH-139783)"
65 | )
66 | def test_pair_rm_judge(self):
67 | judge = self.load_pair_rm_judge()
68 | prompts, completions = self._get_prompts_and_pairwise_completions()
69 | ranks = judge.judge(prompts=prompts, completions=completions)
70 | assert len(ranks) == 2
71 | assert all(isinstance(rank, int) for rank in ranks)
72 | assert ranks == [0, 1]
73 |
74 | @require_llm_blender
75 | @pytest.mark.skipif(
76 | sys.version_info[:3] == (3, 13, 8), reason="Python 3.13.8 has a bug in inspect.BlockFinder (cpython GH-139783)"
77 | )
78 | def test_pair_rm_judge_return_scores(self):
79 | judge = self.load_pair_rm_judge()
80 | prompts, completions = self._get_prompts_and_pairwise_completions()
81 | probs = judge.judge(prompts=prompts, completions=completions, return_scores=True)
82 | assert len(probs) == 2
83 | assert all(isinstance(prob, float) for prob in probs)
84 | assert all(0 <= prob <= 1 for prob in probs)
85 |
--------------------------------------------------------------------------------