├── .github ├── ISSUE_TEMPLATE │ ├── bug-report.yml │ ├── feature-request.yml │ └── new-trainer-addition.yml ├── PULL_REQUEST_TEMPLATE.md ├── codeql │ └── custom-queries.qls └── workflows │ ├── build_documentation.yml │ ├── build_pr_documentation.yml │ ├── clear_cache.yml │ ├── codeQL.yml │ ├── docker-build.yml │ ├── issue_auto_labeller.yml │ ├── pr_style_bot.yml │ ├── slow-tests.yml │ ├── tests.yml │ ├── tests_latest.yml │ ├── trufflehog.yml │ └── upload_pr_documentation.yml ├── .gitignore ├── .pre-commit-config.yaml ├── CITATION.cff ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── Makefile ├── README.md ├── commands ├── run_dpo.sh └── run_sft.sh ├── docker ├── trl-latest-gpu │ └── Dockerfile └── trl-source-gpu │ └── Dockerfile ├── docs └── source │ ├── _toctree.yml │ ├── alignprop_trainer.md │ ├── bco_trainer.md │ ├── best_of_n.md │ ├── callbacks.md │ ├── clis.md │ ├── community_tutorials.md │ ├── cpo_trainer.md │ ├── customization.md │ ├── data_utils.md │ ├── dataset_formats.md │ ├── ddpo_trainer.md │ ├── deepspeed_integration.md │ ├── detoxifying_a_lm.md │ ├── distributing_training.md │ ├── dpo_trainer.md │ ├── example_overview.md │ ├── gkd_trainer.md │ ├── grpo_trainer.md │ ├── how_to_train.md │ ├── index.md │ ├── installation.md │ ├── iterative_sft_trainer.md │ ├── judges.md │ ├── kto_trainer.md │ ├── liger_kernel_integration.md │ ├── logging.md │ ├── model_utils.md │ ├── models.md │ ├── multi_adapter_rl.md │ ├── nash_md_trainer.md │ ├── online_dpo_trainer.md │ ├── orpo_trainer.md │ ├── others.md │ ├── peft_integration.md │ ├── ppo_trainer.md │ ├── prm_trainer.md │ ├── quickstart.md │ ├── reducing_memory_usage.md │ ├── reward_trainer.md │ ├── rewards.md │ ├── rloo_trainer.md │ ├── script_utils.md │ ├── sentiment_tuning.md │ ├── sft_trainer.md │ ├── speeding_up_training.md │ ├── training_vlm_sft.md │ ├── unsloth_integration.md │ ├── use_model.md │ ├── using_llama_models.md │ ├── vllm_integration.md │ └── xpo_trainer.md ├── examples ├── README.md ├── accelerate_configs │ ├── deepspeed_zero1.yaml │ ├── deepspeed_zero2.yaml │ ├── deepspeed_zero3.yaml │ ├── fsdp1.yaml │ ├── fsdp2.yaml │ ├── multi_gpu.yaml │ └── single_gpu.yaml ├── cli_configs │ └── example_config.yaml ├── datasets │ ├── hh-rlhf-helpful-base.py │ ├── lm-human-preferences-descriptiveness.py │ ├── lm-human-preferences-sentiment.py │ ├── math_shepherd.py │ ├── prm800k.py │ ├── rlaif-v.py │ ├── tldr.py │ ├── tldr_preference.py │ ├── ultrafeedback-prompt.py │ └── ultrafeedback.py ├── notebooks │ ├── README.md │ ├── best_of_n.ipynb │ ├── gpt2-sentiment-control.ipynb │ └── gpt2-sentiment.ipynb ├── research_projects │ ├── README.md │ ├── layer_skip │ │ ├── README.md │ │ └── scripts │ │ │ ├── benchmark_layer_skip.py │ │ │ ├── config.py │ │ │ ├── custom_trainer.py │ │ │ └── layer_skip_sft.py │ ├── stack_llama │ │ └── scripts │ │ │ ├── README.md │ │ │ ├── merge_peft_adapter.py │ │ │ ├── reward_modeling.py │ │ │ ├── rl_training.py │ │ │ └── supervised_finetuning.py │ ├── stack_llama_2 │ │ └── scripts │ │ │ ├── README.md │ │ │ ├── dpo_llama2.py │ │ │ ├── requirements.txt │ │ │ └── sft_llama2.py │ └── toxicity │ │ ├── README.md │ │ └── scripts │ │ ├── evaluate-toxicity.py │ │ └── gpt-j-6b-toxicity.py └── scripts │ ├── alignprop.py │ ├── bco.py │ ├── cpo.py │ ├── ddpo.py │ ├── dpo.py │ ├── dpo_online.py │ ├── dpo_vlm.py │ ├── evals │ └── judge_tldr.py │ ├── gkd.py │ ├── kto.py │ ├── nash_md.py │ ├── orpo.py │ ├── ppo │ ├── ppo.py │ └── ppo_tldr.py │ ├── prm.py │ ├── reward_modeling.py │ ├── rloo │ ├── rloo.py │ 
└── rloo_tldr.py │ ├── sft.py │ ├── sft_gemma3.py │ ├── sft_video_llm.py │ ├── sft_vlm.py │ ├── sft_vlm_gemma3.py │ ├── sft_vlm_smol_vlm.py │ └── xpo.py ├── pyproject.toml ├── requirements.txt ├── scripts ├── add_copyrights.py ├── generate_tiny_models.py ├── generate_zen_dataset.py ├── log_example_reports.py └── log_reports.py ├── setup.cfg ├── setup.py ├── tests ├── __init__.py ├── slow │ ├── __init__.py │ ├── test_dpo_slow.py │ ├── test_grpo_slow.py │ ├── test_sft_slow.py │ └── testing_constants.py ├── test_activation_offloading.py ├── test_alignprop_trainer.py ├── test_bco_trainer.py ├── test_best_of_n_sampler.py ├── test_callbacks.py ├── test_cli.py ├── test_cli_utils.py ├── test_collators.py ├── test_core.py ├── test_cpo_trainer.py ├── test_data_collator_completion_only.py ├── test_data_utils.py ├── test_dataset_formatting.py ├── test_ddpo_trainer.py ├── test_dpo_trainer.py ├── test_environments.py ├── test_gkd_trainer.py ├── test_grpo_trainer.py ├── test_iterative_sft_trainer.py ├── test_judges.py ├── test_kto_trainer.py ├── test_modeling_geometric_mixture_wrapper.py ├── test_modeling_value_head.py ├── test_nash_md_trainer.py ├── test_online_dpo_trainer.py ├── test_orpo_trainer.py ├── test_peft_models.py ├── test_ppo_trainer.py ├── test_prm_trainer.py ├── test_reward_trainer.py ├── test_rewards.py ├── test_rich_progress_callback.py ├── test_rloo_trainer.py ├── test_sft_trainer.py ├── test_trainers_args.py ├── test_utils.py ├── test_vllm_client_server.py ├── test_xpo_trainer.py ├── testing_constants.py └── testing_utils.py └── trl ├── __init__.py ├── accelerate_configs ├── fsdp1.yaml ├── fsdp2.yaml ├── multi_gpu.yaml ├── single_gpu.yaml ├── zero1.yaml ├── zero2.yaml └── zero3.yaml ├── cli.py ├── core.py ├── data_utils.py ├── environment ├── __init__.py └── base_environment.py ├── extras ├── __init__.py ├── best_of_n_sampler.py ├── dataset_formatting.py ├── profiling.py └── vllm_client.py ├── import_utils.py ├── mergekit_utils.py ├── models ├── __init__.py ├── activation_offloading.py ├── auxiliary_modules.py ├── modeling_base.py ├── modeling_sd_base.py ├── modeling_value_head.py ├── sd_utils.py └── utils.py ├── rewards ├── __init__.py └── format_rewards.py ├── scripts ├── __init__.py ├── dpo.py ├── env.py ├── grpo.py ├── kto.py ├── sft.py ├── utils.py └── vllm_serve.py ├── templates └── lm_model_card.md └── trainer ├── __init__.py ├── alignprop_config.py ├── alignprop_trainer.py ├── bco_config.py ├── bco_trainer.py ├── callbacks.py ├── cpo_config.py ├── cpo_trainer.py ├── ddpo_config.py ├── ddpo_trainer.py ├── dpo_config.py ├── dpo_trainer.py ├── gkd_config.py ├── gkd_trainer.py ├── grpo_config.py ├── grpo_trainer.py ├── iterative_sft_config.py ├── iterative_sft_trainer.py ├── judges.py ├── kto_config.py ├── kto_trainer.py ├── model_config.py ├── nash_md_config.py ├── nash_md_trainer.py ├── online_dpo_config.py ├── online_dpo_trainer.py ├── orpo_config.py ├── orpo_trainer.py ├── ppo_config.py ├── ppo_trainer.py ├── prm_config.py ├── prm_trainer.py ├── reward_config.py ├── reward_trainer.py ├── rloo_config.py ├── rloo_trainer.py ├── sft_config.py ├── sft_trainer.py ├── utils.py ├── xpo_config.py └── xpo_trainer.py /.github/ISSUE_TEMPLATE/bug-report.yml: -------------------------------------------------------------------------------- 1 | name: "\U0001F41B Bug Report" 2 | description: Submit a bug report to help us improve TRL 3 | labels: [ "bug" ] 4 | body: 5 | - type: markdown 6 | attributes: 7 | value: | 8 | Thanks for taking the time to fill out this bug report! 
🤗 9 | 10 | 🚩 If it is your first time submitting, be sure to check our [bug report guidelines](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md#did-you-find-a-bug) 11 | 12 | - type: textarea 13 | id: reproduction 14 | validations: 15 | required: true 16 | attributes: 17 | label: Reproduction 18 | description: | 19 | Please provide a code sample that reproduces the problem you ran into. It can be a Colab link or just a code snippet. 20 | If you have code snippets, error messages, stack traces please provide them here as well. 21 | Important! Use code tags to correctly format your code. See https://help.github.com/en/github/writing-on-github/creating-and-highlighting-code-blocks#syntax-highlighting 22 | Do not use screenshots, as they are hard to read and (more importantly) don't allow others to copy-and-paste your code. 23 | 24 | value: | 25 | ```python 26 | from trl import ... 27 | 28 | ``` 29 | 30 | outputs: 31 | 32 | ``` 33 | Traceback (most recent call last): 34 | File "example.py", line 42, in 35 | ... 36 | ``` 37 | 38 | - type: textarea 39 | id: system-info 40 | attributes: 41 | label: System Info 42 | description: | 43 | Please provide information about your system: platform, Python version, PyTorch version, Transformers version, devices, TRL version, ... 44 | You can get this information by running `trl env` in your terminal. 45 | 46 | placeholder: Copy-paste the output of `trl env` 47 | validations: 48 | required: true 49 | 50 | - type: checkboxes 51 | id: terms 52 | attributes: 53 | label: Checklist 54 | description: | 55 | Before submitting, please confirm that you've completed each of the following. 56 | If an item doesn't apply to your issue, check it anyway to show you've reviewed it. 57 | options: 58 | - label: "I have checked that my issue isn't already filed (see [open issues](https://github.com/huggingface/trl/issues?q=is%3Aissue))" 59 | required: true 60 | - label: "I have included my system information" 61 | required: true 62 | - label: "Any code provided is minimal, complete, and reproducible ([more on MREs](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/creating-and-highlighting-code-blocks))" 63 | required: true 64 | - label: "Any code provided is properly formatted in code blocks, (no screenshot, [more on code blocks](https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/creating-and-highlighting-code-blocks))" 65 | required: true 66 | - label: "Any traceback provided is complete" 67 | required: true 68 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature-request.yml: -------------------------------------------------------------------------------- 1 | name: "\U0001F680 Feature request" 2 | description: Submit a proposal/request for a new TRL feature 3 | labels: [ "Feature request" ] 4 | body: 5 | - type: textarea 6 | id: feature-request 7 | validations: 8 | required: true 9 | attributes: 10 | label: Feature request 11 | description: | 12 | A clear and concise description of the feature proposal. Please provide a link to the paper and code in case they exist. 13 | 14 | - type: textarea 15 | id: motivation 16 | validations: 17 | required: true 18 | attributes: 19 | label: Motivation 20 | description: | 21 | Please outline the motivation for the proposal. Is your feature request related to a problem? e.g., I'm always frustrated when [...]. If this is related to another GitHub issue, please link here too. 
22 | 23 | 24 | - type: textarea 25 | id: contribution 26 | validations: 27 | required: true 28 | attributes: 29 | label: Your contribution 30 | description: | 31 | Is there any way that you could help, e.g. by submitting a PR? Make sure to read the CONTRIBUTING.MD [readme](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md) 32 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/new-trainer-addition.yml: -------------------------------------------------------------------------------- 1 | name: "\U0001F31F New trainer addition" 2 | description: Submit a proposal/request to implement a new trainer for a post-training method 3 | labels: [ "New trainer" ] 4 | 5 | body: 6 | - type: textarea 7 | id: description-request 8 | validations: 9 | required: true 10 | attributes: 11 | label: Method description 12 | description: | 13 | Put any and all important information relative to the method 14 | 15 | - type: checkboxes 16 | id: information-tasks 17 | attributes: 18 | label: Open source status 19 | description: | 20 | Please note that if the method implementation isn't available or model weights with training datasets aren't available, we are less likely to implement it in `trl`. 21 | options: 22 | - label: "The method implementation is available" 23 | - label: "The model weights are available" 24 | - label: "The training datasets are available" 25 | 26 | - type: textarea 27 | id: additional-info 28 | attributes: 29 | label: Provide useful links for the implementation 30 | description: | 31 | Please provide information regarding the implementation, the weights, and the authors. 32 | Please mention the authors by @gh-username if you're aware of their usernames. 33 | -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | # What does this PR do? 2 | 3 | 12 | 13 | 14 | 15 | Fixes # (issue) 16 | 17 | 18 | ## Before submitting 19 | - [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case). 20 | - [ ] Did you read the [contributor guideline](https://github.com/huggingface/trl/blob/main/CONTRIBUTING.md#create-a-pull-request), 21 | Pull Request section? 22 | - [ ] Was this discussed/approved via a GitHub issue? Please add a link 23 | to it if that's the case. 24 | - [ ] Did you make sure to update the documentation with your changes? 25 | - [ ] Did you write any new necessary tests? 26 | 27 | 28 | ## Who can review? 29 | 30 | Anyone in the community is free to review the PR once the tests have passed. Feel free to tag 31 | members/contributors who may be interested in your PR. 
-------------------------------------------------------------------------------- /.github/codeql/custom-queries.qls: -------------------------------------------------------------------------------- 1 | import codeql 2 | 3 | from WorkflowString interpolation, Workflow workflow 4 | where 5 | interpolation.getStringValue().matches("${{ github.event.issue.title }}") or 6 | interpolation.getStringValue().matches("${{ github.event.issue.body }}") or 7 | interpolation.getStringValue().matches("${{ github.event.pull_request.title }}") or 8 | interpolation.getStringValue().matches("${{ github.event.pull_request.body }}") or 9 | interpolation.getStringValue().matches("${{ github.event.review.body }}") or 10 | interpolation.getStringValue().matches("${{ github.event.comment.body }}") or 11 | interpolation.getStringValue().matches("${{ github.event.inputs.* }}") or 12 | interpolation.getStringValue().matches("${{ github.event.head_commit.message }}") 13 | interpolation.getStringValue().matches("${{ github.event.* }}") and 14 | ( 15 | step.getKey() = "run" or // Injection in run 16 | step.getKey() = "env" or // Injection via env 17 | step.getKey() = "with" // Injection via with 18 | ) 19 | select workflow, "🚨 Do not use directly as input of action" 20 | -------------------------------------------------------------------------------- /.github/workflows/build_documentation.yml: -------------------------------------------------------------------------------- 1 | name: Build documentation 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | - doc-builder* 8 | - v*-release 9 | 10 | jobs: 11 | build: 12 | uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main 13 | with: 14 | commit_sha: ${{ github.sha }} 15 | package: trl 16 | version_tag_suffix: "" 17 | custom_container: huggingface/transformers-doc-builder 18 | secrets: 19 | hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} 20 | -------------------------------------------------------------------------------- /.github/workflows/build_pr_documentation.yml: -------------------------------------------------------------------------------- 1 | name: Build PR Documentation 2 | 3 | on: 4 | pull_request: 5 | 6 | concurrency: 7 | group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} 8 | cancel-in-progress: true 9 | 10 | jobs: 11 | build: 12 | if: github.event.pull_request.draft == false 13 | uses: huggingface/doc-builder/.github/workflows/build_pr_documentation.yml@main 14 | with: 15 | commit_sha: ${{ github.event.pull_request.head.sha }} 16 | pr_number: ${{ github.event.number }} 17 | package: trl 18 | version_tag_suffix: "" 19 | custom_container: huggingface/transformers-doc-builder 20 | -------------------------------------------------------------------------------- /.github/workflows/clear_cache.yml: -------------------------------------------------------------------------------- 1 | name: "Cleanup Cache" 2 | 3 | on: 4 | workflow_dispatch: 5 | schedule: 6 | - cron: "0 0 * * *" 7 | 8 | jobs: 9 | cleanup: 10 | runs-on: ubuntu-latest 11 | steps: 12 | - name: Check out code 13 | uses: actions/checkout@v4 14 | 15 | - name: Cleanup 16 | run: | 17 | gh extension install actions/gh-actions-cache 18 | 19 | REPO=${{ github.repository }} 20 | 21 | echo "Fetching list of cache key" 22 | cacheKeysForPR=$(gh actions-cache list -R $REPO | cut -f 1 ) 23 | 24 | ## Setting this to not fail the workflow while deleting cache keys. 25 | set +e 26 | echo "Deleting caches..." 
27 | for cacheKey in $cacheKeysForPR 28 | do 29 | gh actions-cache delete $cacheKey -R $REPO --confirm 30 | done 31 | echo "Done" 32 | env: 33 | GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} 34 | -------------------------------------------------------------------------------- /.github/workflows/codeQL.yml: -------------------------------------------------------------------------------- 1 | name: "CodeQL Analysis - Workflows" 2 | 3 | on: 4 | workflow_dispatch: 5 | 6 | jobs: 7 | analyze: 8 | name: "Analyze GitHub Workflows" 9 | runs-on: ubuntu-latest 10 | permissions: 11 | security-events: write 12 | actions: read 13 | contents: read 14 | 15 | steps: 16 | - name: "Checkout repository" 17 | uses: actions/checkout@v4 18 | 19 | - name: "Initialize CodeQL" 20 | uses: github/codeql-action/init@v2 21 | with: 22 | languages: "yaml" 23 | queries: +security-and-quality, ./.github/codeql/custom-queries.qls 24 | 25 | - name: "Perform CodeQL Analysis" 26 | uses: github/codeql-action/analyze@v2 27 | -------------------------------------------------------------------------------- /.github/workflows/docker-build.yml: -------------------------------------------------------------------------------- 1 | name: Build Docker images (scheduled) 2 | 3 | on: 4 | workflow_dispatch: 5 | workflow_call: 6 | schedule: 7 | - cron: "0 1 * * *" 8 | 9 | concurrency: 10 | group: docker-image-builds 11 | cancel-in-progress: false 12 | 13 | env: 14 | CI_SLACK_CHANNEL: ${{ secrets.CI_DOCKER_CHANNEL }} 15 | 16 | jobs: 17 | trl-latest: 18 | name: "Latest TRL GPU" 19 | runs-on: ubuntu-latest 20 | steps: 21 | - name: Cleanup disk 22 | run: | 23 | sudo ls -l /usr/local/lib/ 24 | sudo ls -l /usr/share/ 25 | sudo du -sh /usr/local/lib/ 26 | sudo du -sh /usr/share/ 27 | sudo rm -rf /usr/local/lib/android 28 | sudo rm -rf /usr/share/dotnet 29 | sudo du -sh /usr/local/lib/ 30 | sudo du -sh /usr/share/ 31 | - name: Set up Docker Buildx 32 | uses: docker/setup-buildx-action@v1 33 | - name: Check out code 34 | uses: actions/checkout@v4 35 | - name: Login to DockerHub 36 | uses: docker/login-action@v1 37 | with: 38 | username: ${{ secrets.DOCKERHUB_USERNAME }} 39 | password: ${{ secrets.DOCKERHUB_PASSWORD }} 40 | 41 | - name: Build and Push GPU 42 | uses: docker/build-push-action@v4 43 | with: 44 | context: ./docker/trl-latest-gpu 45 | push: true 46 | tags: huggingface/trl-latest-gpu 47 | 48 | - name: Post to Slack 49 | if: always() 50 | uses: huggingface/hf-workflows/.github/actions/post-slack@main 51 | with: 52 | slack_channel: ${{ env.CI_SLACK_CHANNEL }} 53 | title: 🤗 Results of the trl-latest-gpu Docker Image build 54 | status: ${{ job.status }} 55 | slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }} 56 | 57 | trl-source: 58 | name: "Latest TRL + HF ecosystem from source" 59 | runs-on: ubuntu-latest 60 | steps: 61 | - name: Cleanup disk 62 | run: | 63 | sudo ls -l /usr/local/lib/ 64 | sudo ls -l /usr/share/ 65 | sudo du -sh /usr/local/lib/ 66 | sudo du -sh /usr/share/ 67 | sudo rm -rf /usr/local/lib/android 68 | sudo rm -rf /usr/share/dotnet 69 | sudo du -sh /usr/local/lib/ 70 | sudo du -sh /usr/share/ 71 | - name: Set up Docker Buildx 72 | uses: docker/setup-buildx-action@v1 73 | - name: Check out code 74 | uses: actions/checkout@v4 75 | - name: Login to DockerHub 76 | uses: docker/login-action@v1 77 | with: 78 | username: ${{ secrets.DOCKERHUB_USERNAME }} 79 | password: ${{ secrets.DOCKERHUB_PASSWORD }} 80 | 81 | - name: Build and Push GPU 82 | uses: docker/build-push-action@v4 83 | with: 84 | context: ./docker/trl-source-gpu 85 | push: 
true 86 | tags: huggingface/trl-source-gpu 87 | 88 | - name: Post to Slack 89 | if: always() 90 | uses: huggingface/hf-workflows/.github/actions/post-slack@main 91 | with: 92 | slack_channel: ${{ env.CI_SLACK_CHANNEL }} 93 | title: 🤗 Results of the trl-source-gpu Docker Image build 94 | status: ${{ job.status }} 95 | slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }} 96 | -------------------------------------------------------------------------------- /.github/workflows/issue_auto_labeller.yml: -------------------------------------------------------------------------------- 1 | name: "Hugging Face Issue Labeler" 2 | on: 3 | issues: 4 | types: opened 5 | 6 | jobs: 7 | triage: 8 | runs-on: ubuntu-latest 9 | permissions: 10 | issues: write 11 | steps: 12 | - uses: actions/checkout@v3 13 | - uses: August-murr/auto-labeler@main 14 | with: 15 | hf-api-key: ${{ secrets.CI_HF_API_TOKEN }} 16 | -------------------------------------------------------------------------------- /.github/workflows/slow-tests.yml: -------------------------------------------------------------------------------- 1 | name: Slow tests (on push) 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | paths: 7 | # Run only when python files are modified 8 | - "trl/**.py" 9 | - "examples/**.py" 10 | env: 11 | RUN_SLOW: "yes" 12 | IS_GITHUB_CI: "1" 13 | SLACK_API_TOKEN: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }} 14 | 15 | 16 | jobs: 17 | run_all_tests_single_gpu: 18 | strategy: 19 | fail-fast: false 20 | matrix: 21 | docker-image-name: ["huggingface/trl-latest-gpu:latest", "huggingface/trl-source-gpu:latest"] 22 | runs-on: 23 | group: aws-g4dn-2xlarge 24 | env: 25 | CUDA_VISIBLE_DEVICES: "0" 26 | TEST_TYPE: "single_gpu_${{ matrix.docker-image-name }}" 27 | container: 28 | image: ${{ matrix.docker-image-name }} 29 | options: --gpus all --shm-size "16gb" -e NVIDIA_DISABLE_REQUIRE=true 30 | defaults: 31 | run: 32 | shell: bash 33 | steps: 34 | - uses: actions/checkout@v4 35 | - name: Pip install 36 | run: | 37 | source activate trl 38 | pip install -e ".[test]" --no-deps 39 | pip install pytest-reportlog parameterized 40 | 41 | - name: Run slow SFT tests on single GPU 42 | if: always() 43 | run: | 44 | source activate trl 45 | make slow_tests 46 | 47 | - name: Generate Report 48 | if: always() 49 | run: | 50 | pip install slack_sdk tabulate 51 | python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY 52 | 53 | 54 | run_all_tests_multi_gpu: 55 | strategy: 56 | fail-fast: false 57 | matrix: 58 | docker-image-name: ["huggingface/trl-latest-gpu:latest", "huggingface/trl-source-gpu:latest"] 59 | runs-on: 60 | group: aws-g4dn-2xlarge 61 | env: 62 | CUDA_VISIBLE_DEVICES: "0,1" 63 | TEST_TYPE: "multi_gpu_${{ matrix.docker-image-name }}" 64 | container: 65 | image: ${{ matrix.docker-image-name }} 66 | options: --gpus all --shm-size "16gb" -e NVIDIA_DISABLE_REQUIRE=true 67 | defaults: 68 | run: 69 | shell: bash 70 | steps: 71 | - uses: actions/checkout@v4 72 | - name: Pip install 73 | run: | 74 | source activate trl 75 | pip install -e ".[test]" --no-deps 76 | pip install pytest-reportlog parameterized 77 | 78 | - name: Run slow SFT tests on Multi GPU 79 | if: always() 80 | run: | 81 | source activate trl 82 | make slow_tests 83 | 84 | - name: Run end-to-end examples tests on multi GPU 85 | if: always() 86 | run: | 87 | source activate trl 88 | pip install deepspeed 89 | make test_examples 90 | 91 | - name: Generate Reports 92 | if: always() 93 | run: | 94 | pip install slack_sdk tabulate 95 | python scripts/log_reports.py >> $GITHUB_STEP_SUMMARY 
96 | python scripts/log_example_reports.py --text_file_name temp_results_sft_tests.txt >> $GITHUB_STEP_SUMMARY 97 | python scripts/log_example_reports.py --text_file_name temp_results_dpo_tests.txt >> $GITHUB_STEP_SUMMARY 98 | rm *.txt 99 | -------------------------------------------------------------------------------- /.github/workflows/tests_latest.yml: -------------------------------------------------------------------------------- 1 | name: Tests latest TRL release with dev dependencies 2 | 3 | on: 4 | schedule: 5 | - cron: '0 0 * * *' # Runs daily at midnight UTC 6 | 7 | workflow_dispatch: 8 | 9 | env: 10 | TQDM_DISABLE: 1 11 | CI_SLACK_CHANNEL: ${{ secrets.CI_PUSH_MAIN_CHANNEL }} 12 | 13 | jobs: 14 | tests: 15 | name: Tests latest TRL release with dev dependencies 16 | runs-on: 17 | group: aws-g4dn-2xlarge 18 | container: 19 | image: pytorch/pytorch:2.6.0-cuda12.6-cudnn9-devel 20 | options: --gpus all 21 | defaults: 22 | run: 23 | shell: bash 24 | steps: 25 | - name: Git checkout 26 | uses: actions/checkout@v4 27 | with: { ref: v0.18-release } 28 | 29 | - name: Set up Python 3.12 30 | uses: actions/setup-python@v5 31 | with: 32 | python-version: '3.12' 33 | 34 | - name: Install Make and Git 35 | run: | 36 | apt-get update && apt-get install -y make git curl 37 | 38 | - name: Install uv 39 | run: | 40 | curl -LsSf https://astral.sh/uv/install.sh | sh 41 | 42 | - name: Create Python virtual environment 43 | run: | 44 | uv venv 45 | uv pip install --upgrade setuptools wheel 46 | 47 | - name: Install dependencies 48 | run: | 49 | source .venv/bin/activate 50 | uv pip install -U git+https://github.com/huggingface/accelerate.git 51 | uv pip install -U git+https://github.com/huggingface/datasets.git 52 | uv pip install -U git+https://github.com/huggingface/transformers.git 53 | uv pip install ".[dev]" 54 | 55 | - name: Test with pytest 56 | run: | 57 | source .venv/bin/activate 58 | make test 59 | 60 | - name: Post to Slack 61 | uses: huggingface/hf-workflows/.github/actions/post-slack@main 62 | with: 63 | slack_channel: ${{ env.CI_SLACK_CHANNEL }} 64 | title: Results of latest TRL with Python 3.12 and dev dependencies 65 | status: ${{ job.status }} 66 | slack_token: ${{ secrets.SLACK_CIFEEDBACK_BOT_TOKEN }} 67 | -------------------------------------------------------------------------------- /.github/workflows/trufflehog.yml: -------------------------------------------------------------------------------- 1 | on: 2 | push: 3 | 4 | name: Secret Leaks 5 | 6 | jobs: 7 | trufflehog: 8 | runs-on: ubuntu-latest 9 | steps: 10 | - name: Checkout code 11 | uses: actions/checkout@v4 12 | with: 13 | fetch-depth: 0 14 | - name: Secret Scanning 15 | uses: trufflesecurity/trufflehog@853e1e8d249fd1e29d0fcc7280d29b03df3d643d 16 | with: 17 | # exclude buggy postgres detector that is causing false positives and not relevant to our codebase 18 | extra_args: --results=verified,unknown --exclude-detectors=postgres 19 | -------------------------------------------------------------------------------- /.github/workflows/upload_pr_documentation.yml: -------------------------------------------------------------------------------- 1 | name: Upload PR Documentation 2 | 3 | on: 4 | workflow_run: 5 | workflows: ["Build PR Documentation"] 6 | types: 7 | - completed 8 | 9 | jobs: 10 | build: 11 | uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main 12 | with: 13 | package_name: trl 14 | secrets: 15 | hf_token: ${{ secrets.HF_DOC_BUILD_PUSH }} 16 | comment_bot_token: ${{ 
secrets.COMMENT_BOT_TOKEN }} -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.bak 2 | .gitattributes 3 | .last_checked 4 | .gitconfig 5 | *.bak 6 | *.log 7 | *~ 8 | ~* 9 | _tmp* 10 | tmp* 11 | tags 12 | 13 | # Byte-compiled / optimized / DLL files 14 | __pycache__/ 15 | *.py[cod] 16 | *$py.class 17 | 18 | # C extensions 19 | *.so 20 | 21 | # Distribution / packaging 22 | .Python 23 | env/ 24 | build/ 25 | develop-eggs/ 26 | dist/ 27 | downloads/ 28 | eggs/ 29 | .eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | .hypothesis/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # celery beat schedule file 89 | celerybeat-schedule 90 | 91 | # SageMath parsed files 92 | *.sage.py 93 | 94 | # dotenv 95 | .env 96 | 97 | # virtualenv 98 | .venv 99 | venv/ 100 | ENV/ 101 | 102 | # Spyder project settings 103 | .spyderproject 104 | .spyproject 105 | 106 | # Rope project settings 107 | .ropeproject 108 | 109 | # mkdocs documentation 110 | /site 111 | 112 | # mypy 113 | .mypy_cache/ 114 | 115 | .vscode 116 | *.swp 117 | 118 | # osx generated files 119 | .DS_Store 120 | .DS_Store? 121 | .Trashes 122 | ehthumbs.db 123 | Thumbs.db 124 | .idea 125 | 126 | # pytest 127 | .pytest_cache 128 | 129 | # tools/trust-doc-nbs 130 | docs_src/.last_checked 131 | 132 | # symlinks to fastai 133 | docs_src/fastai 134 | tools/fastai 135 | 136 | # link checker 137 | checklink/cookies.txt 138 | 139 | # .gitconfig is now autogenerated 140 | .gitconfig 141 | 142 | # wandb files 143 | nbs/wandb/ 144 | examples/notebooks/wandb/ 145 | wandb/ -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/astral-sh/ruff-pre-commit 3 | rev: v0.11.10 4 | hooks: 5 | - id: ruff-check 6 | types_or: [ python, pyi ] 7 | args: [ --fix ] 8 | - id: ruff-format 9 | types_or: [ python, pyi ] 10 | 11 | # - repo: https://github.com/codespell-project/codespell 12 | # rev: v2.1.0 13 | # hooks: 14 | # - id: codespell 15 | # args: 16 | # - --ignore-words-list=nd,reacher,thist,ths,magent,ba 17 | # - --skip=docs/css/termynal.css,docs/js/termynal.js 18 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | title: 'TRL: Transformer Reinforcement Learning' 3 | message: >- 4 | If you use this software, please cite it using the 5 | metadata from this file. 
6 | type: software 7 | authors: 8 | - given-names: Leandro 9 | family-names: von Werra 10 | - given-names: Younes 11 | family-names: Belkada 12 | - given-names: Lewis 13 | family-names: Tunstall 14 | - given-names: Edward 15 | family-names: Beeching 16 | - given-names: Tristan 17 | family-names: Thrush 18 | - given-names: Nathan 19 | family-names: Lambert 20 | - given-names: Shengyi 21 | family-names: Huang 22 | - given-names: Kashif 23 | family-names: Rasul 24 | - given-names: Quentin 25 | family-names: Gallouédec 26 | repository-code: 'https://github.com/huggingface/trl' 27 | abstract: "With trl you can train transformer language models with Proximal Policy Optimization (PPO). The library is built on top of the transformers library by \U0001F917 Hugging Face. Therefore, pre-trained language models can be directly loaded via transformers. At this point, most decoder and encoder-decoder architectures are supported." 28 | keywords: 29 | - rlhf 30 | - deep-learning 31 | - pytorch 32 | - transformers 33 | license: Apache-2.0 34 | version: 0.18 35 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include LICENSE 2 | include CONTRIBUTING.md 3 | include README.md 4 | recursive-exclude * __pycache__ 5 | include trl/templates/*.md 6 | include trl/accelerate_configs/*.yaml -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: test precommit common_tests slow_tests test_examples tests_gpu 2 | 3 | check_dirs := examples tests trl 4 | 5 | ACCELERATE_CONFIG_PATH = `pwd`/examples/accelerate_configs 6 | COMMAND_FILES_PATH = `pwd`/commands 7 | 8 | test: 9 | pytest -n auto -m "not slow and not low-priority" -s -v --reruns 5 --reruns-delay 1 --only-rerun '(OSError|Timeout|HTTPError.*502|HTTPError.*504||not less than or equal to 0.01)' tests/ 10 | 11 | precommit: 12 | python scripts/add_copyrights.py 13 | pre-commit run --all-files 14 | 15 | slow_tests: 16 | pytest -m "slow" tests/ $(if $(IS_GITHUB_CI),--report-log "slow_tests.log",) 17 | 18 | test_examples: 19 | touch temp_results_sft_tests.txt 20 | for file in $(ACCELERATE_CONFIG_PATH)/*.yaml; do \ 21 | TRL_ACCELERATE_CONFIG=$${file} bash $(COMMAND_FILES_PATH)/run_sft.sh; \ 22 | echo $$?','$${file} >> temp_results_sft_tests.txt; \ 23 | done 24 | 25 | touch temp_results_dpo_tests.txt 26 | for file in $(ACCELERATE_CONFIG_PATH)/*.yaml; do \ 27 | TRL_ACCELERATE_CONFIG=$${file} bash $(COMMAND_FILES_PATH)/run_dpo.sh; \ 28 | echo $$?','$${file} >> temp_results_dpo_tests.txt; \ 29 | done 30 | -------------------------------------------------------------------------------- /commands/run_dpo.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This script runs an SFT example end-to-end on a tiny model using different possible configurations 3 | # but defaults to QLoRA + PEFT 4 | OUTPUT_DIR="test_dpo/" 5 | MODEL_NAME="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" 6 | DATASET_NAME="trl-internal-testing/hh-rlhf-helpful-base-trl-style" 7 | MAX_STEPS=5 8 | BATCH_SIZE=2 9 | SEQ_LEN=128 10 | 11 | # Handle extra arguments in case one passes accelerate configs. 
12 | EXTRA_ACCELERATE_ARGS="" 13 | EXTRA_TRAINING_ARGS="""--use_peft \ 14 | --load_in_4bit 15 | """ 16 | 17 | # This is a hack to get the number of available GPUs 18 | NUM_GPUS=2 19 | 20 | if [[ "${TRL_ACCELERATE_CONFIG}" == "" ]]; then 21 | EXTRA_ACCELERATE_ARGS="" 22 | else 23 | EXTRA_ACCELERATE_ARGS="--config_file $TRL_ACCELERATE_CONFIG" 24 | # For DeepSpeed configs we need to set the `--fp16` flag to comply with our configs exposed 25 | # on `examples/accelerate_configs` and our runners do not support bf16 mixed precision training. 26 | if [[ $TRL_ACCELERATE_CONFIG == *"deepspeed"* ]]; then 27 | EXTRA_TRAINING_ARGS="--fp16" 28 | else 29 | echo "Keeping QLoRA + PEFT" 30 | fi 31 | fi 32 | 33 | 34 | CMD=""" 35 | accelerate launch $EXTRA_ACCELERATE_ARGS \ 36 | --num_processes $NUM_GPUS \ 37 | --mixed_precision 'fp16' \ 38 | `pwd`/trl/scripts/dpo.py \ 39 | --model_name_or_path $MODEL_NAME \ 40 | --dataset_name $DATASET_NAME \ 41 | --output_dir $OUTPUT_DIR \ 42 | --max_steps $MAX_STEPS \ 43 | --per_device_train_batch_size $BATCH_SIZE \ 44 | --max_length $SEQ_LEN \ 45 | $EXTRA_TRAINING_ARGS 46 | """ 47 | 48 | echo "Starting program..." 49 | 50 | { # try 51 | echo $CMD 52 | eval "$CMD" 53 | } || { # catch 54 | # save log for exception 55 | echo "Operation Failed!" 56 | exit 1 57 | } 58 | exit 0 59 | -------------------------------------------------------------------------------- /commands/run_sft.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # This script runs an SFT example end-to-end on a tiny model using different possible configurations 3 | # but defaults to QLoRA + PEFT 4 | OUTPUT_DIR="test_sft/" 5 | MODEL_NAME="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" 6 | DATASET_NAME="stanfordnlp/imdb" 7 | MAX_STEPS=5 8 | BATCH_SIZE=2 9 | SEQ_LEN=128 10 | 11 | 12 | # Handle extra arguments in case one passes accelerate configs. 13 | EXTRA_ACCELERATE_ARGS="" 14 | EXTRA_TRAINING_ARGS="""--use_peft \ 15 | --load_in_4bit 16 | """ 17 | 18 | # Set your number of GPUs here 19 | NUM_GPUS=2 20 | 21 | if [[ "${TRL_ACCELERATE_CONFIG}" == "" ]]; then 22 | EXTRA_ACCELERATE_ARGS="" 23 | else 24 | EXTRA_ACCELERATE_ARGS="--config_file $TRL_ACCELERATE_CONFIG" 25 | # For DeepSpeed configs we need to set the `--fp16` flag to comply with our configs exposed 26 | # on `examples/accelerate_configs` and our runners do not support bf16 mixed precision training. 27 | if [[ $TRL_ACCELERATE_CONFIG == *"deepspeed"* ]]; then 28 | EXTRA_TRAINING_ARGS="--fp16" 29 | else 30 | echo "Keeping QLoRA + PEFT" 31 | fi 32 | fi 33 | 34 | 35 | CMD=""" 36 | accelerate launch $EXTRA_ACCELERATE_ARGS \ 37 | --num_processes $NUM_GPUS \ 38 | --mixed_precision 'fp16' \ 39 | `pwd`/trl/scripts/sft.py \ 40 | --model_name $MODEL_NAME \ 41 | --dataset_name $DATASET_NAME \ 42 | --output_dir $OUTPUT_DIR \ 43 | --max_steps $MAX_STEPS \ 44 | --per_device_train_batch_size $BATCH_SIZE \ 45 | --max_length $SEQ_LEN \ 46 | $EXTRA_TRAINING_ARGS 47 | """ 48 | 49 | echo "Starting program..." 50 | 51 | { # try 52 | echo $CMD 53 | eval "$CMD" 54 | } || { # catch 55 | # save log for exception 56 | echo "Operation Failed!" 
57 | exit 1 58 | } 59 | exit 0 60 | -------------------------------------------------------------------------------- /docker/trl-latest-gpu/Dockerfile: -------------------------------------------------------------------------------- 1 | # Builds GPU docker image of PyTorch 2 | # Uses multi-staged approach to reduce size 3 | # Stage 1 4 | # Use base conda image to reduce time 5 | FROM continuumio/miniconda3:latest AS compile-image 6 | # Specify py version 7 | ENV PYTHON_VERSION=3.10 8 | # Install apt libs - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile 9 | RUN apt-get update && \ 10 | apt-get install -y curl git wget software-properties-common git-lfs && \ 11 | apt-get clean && \ 12 | rm -rf /var/lib/apt/lists* 13 | 14 | # Install audio-related libraries 15 | RUN apt-get update && \ 16 | apt install -y ffmpeg 17 | 18 | RUN apt install -y libsndfile1-dev 19 | RUN git lfs install 20 | 21 | # Create our conda env - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile 22 | RUN conda create --name trl python=${PYTHON_VERSION} ipython jupyter pip 23 | RUN python3 -m pip install --no-cache-dir --upgrade pip 24 | 25 | # Below is copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile 26 | # We don't install pytorch here yet since CUDA isn't available 27 | # instead we use the direct torch wheel 28 | ENV PATH /opt/conda/envs/trl/bin:$PATH 29 | # Activate our bash shell 30 | RUN chsh -s /bin/bash 31 | SHELL ["/bin/bash", "-c"] 32 | 33 | # Stage 2 34 | FROM nvidia/cuda:12.2.2-devel-ubuntu22.04 AS build-image 35 | COPY --from=compile-image /opt/conda /opt/conda 36 | ENV PATH /opt/conda/bin:$PATH 37 | 38 | RUN chsh -s /bin/bash 39 | SHELL ["/bin/bash", "-c"] 40 | RUN source activate trl && \ 41 | python3 -m pip install --no-cache-dir bitsandbytes optimum auto-gptq 42 | 43 | # Install apt libs 44 | RUN apt-get update && \ 45 | apt-get install -y curl git wget && \ 46 | apt-get clean && \ 47 | rm -rf /var/lib/apt/lists* 48 | 49 | # Activate the conda env and install transformers + accelerate from source 50 | RUN source activate trl && \ 51 | python3 -m pip install -U --no-cache-dir \ 52 | librosa \ 53 | "soundfile>=0.12.1" \ 54 | scipy \ 55 | transformers \ 56 | accelerate \ 57 | peft \ 58 | trl[test]@git+https://github.com/huggingface/trl 59 | 60 | RUN source activate trl && \ 61 | pip freeze | grep trl 62 | 63 | RUN echo "source activate trl" >> ~/.profile 64 | 65 | # Activate the virtualenv 66 | CMD ["/bin/bash"] -------------------------------------------------------------------------------- /docker/trl-source-gpu/Dockerfile: -------------------------------------------------------------------------------- 1 | # Builds GPU docker image of PyTorch 2 | # Uses multi-staged approach to reduce size 3 | # Stage 1 4 | # Use base conda image to reduce time 5 | FROM continuumio/miniconda3:latest AS compile-image 6 | # Specify py version 7 | ENV PYTHON_VERSION=3.10 8 | # Install apt libs - copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile 9 | RUN apt-get update && \ 10 | apt-get install -y curl git wget software-properties-common git-lfs && \ 11 | apt-get clean && \ 12 | rm -rf /var/lib/apt/lists* 13 | 14 | # Install audio-related libraries 15 | RUN apt-get update && \ 16 | apt install -y ffmpeg 17 | 18 | RUN apt install -y libsndfile1-dev 19 | RUN git lfs install 20 | 21 | # Create our conda env - copied from 
https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile 22 | RUN conda create --name trl python=${PYTHON_VERSION} ipython jupyter pip 23 | RUN python3 -m pip install --no-cache-dir --upgrade pip 24 | 25 | # Below is copied from https://github.com/huggingface/accelerate/blob/main/docker/accelerate-gpu/Dockerfile 26 | # We don't install pytorch here yet since CUDA isn't available 27 | # instead we use the direct torch wheel 28 | ENV PATH /opt/conda/envs/trl/bin:$PATH 29 | # Activate our bash shell 30 | RUN chsh -s /bin/bash 31 | SHELL ["/bin/bash", "-c"] 32 | 33 | # Stage 2 34 | FROM nvidia/cuda:12.2.2-devel-ubuntu22.04 AS build-image 35 | COPY --from=compile-image /opt/conda /opt/conda 36 | ENV PATH /opt/conda/bin:$PATH 37 | 38 | RUN chsh -s /bin/bash 39 | SHELL ["/bin/bash", "-c"] 40 | RUN source activate trl && \ 41 | python3 -m pip install --no-cache-dir bitsandbytes optimum auto-gptq 42 | 43 | # Install apt libs 44 | RUN apt-get update && \ 45 | apt-get install -y curl git wget && \ 46 | apt-get clean && \ 47 | rm -rf /var/lib/apt/lists* 48 | 49 | # Activate the conda env and install transformers + accelerate from source 50 | RUN source activate trl && \ 51 | python3 -m pip install -U --no-cache-dir \ 52 | librosa \ 53 | "soundfile>=0.12.1" \ 54 | scipy \ 55 | git+https://github.com/huggingface/transformers \ 56 | git+https://github.com/huggingface/accelerate \ 57 | git+https://github.com/huggingface/peft \ 58 | trl[test]@git+https://github.com/huggingface/trl 59 | 60 | RUN source activate trl && \ 61 | pip freeze | grep transformers 62 | 63 | RUN echo "source activate trl" >> ~/.profile 64 | 65 | # Activate the virtualenv 66 | CMD ["/bin/bash"] -------------------------------------------------------------------------------- /docs/source/_toctree.yml: -------------------------------------------------------------------------------- 1 | - sections: 2 | - local: index 3 | title: TRL 4 | - local: installation 5 | title: Installation 6 | - local: quickstart 7 | title: Quickstart 8 | title: Getting started 9 | - sections: 10 | - local: dataset_formats 11 | title: Dataset Formats 12 | - local: how_to_train 13 | title: Training FAQ 14 | - local: logging 15 | title: Understanding Logs 16 | title: Conceptual Guides 17 | - sections: 18 | - local: clis 19 | title: Command Line Interface (CLI) 20 | - local: customization 21 | title: Customizing the Training 22 | - local: reducing_memory_usage 23 | title: Reducing Memory Usage 24 | - local: speeding_up_training 25 | title: Speeding Up Training 26 | - local: distributing_training 27 | title: Distributing Training 28 | - local: use_model 29 | title: Using Trained Models 30 | title: How-to guides 31 | - sections: 32 | - local: deepspeed_integration 33 | title: DeepSpeed 34 | - local: liger_kernel_integration 35 | title: Liger Kernel 36 | - local: peft_integration 37 | title: PEFT 38 | - local: unsloth_integration 39 | title: Unsloth 40 | - local: vllm_integration 41 | title: vLLM 42 | title: Integrations 43 | - sections: 44 | - local: example_overview 45 | title: Example Overview 46 | - local: community_tutorials 47 | title: Community Tutorials 48 | - local: sentiment_tuning 49 | title: Sentiment Tuning 50 | - local: using_llama_models 51 | title: Training StackLlama 52 | - local: detoxifying_a_lm 53 | title: Detoxifying a Language Model 54 | - local: multi_adapter_rl 55 | title: Multi Adapter RLHF 56 | - local: training_vlm_sft 57 | title: Fine-tuning a Multimodal Model Using SFT (Single or Multi-Image Dataset) 58 | title: 
Examples 59 | - sections: 60 | - sections: # Sorted alphabetically 61 | - local: alignprop_trainer 62 | title: AlignProp 63 | - local: bco_trainer 64 | title: BCO 65 | - local: cpo_trainer 66 | title: CPO 67 | - local: ddpo_trainer 68 | title: DDPO 69 | - local: dpo_trainer 70 | title: DPO 71 | - local: online_dpo_trainer 72 | title: Online DPO 73 | - local: gkd_trainer 74 | title: GKD 75 | - local: grpo_trainer 76 | title: GRPO 77 | - local: kto_trainer 78 | title: KTO 79 | - local: nash_md_trainer 80 | title: Nash-MD 81 | - local: orpo_trainer 82 | title: ORPO 83 | - local: ppo_trainer 84 | title: PPO 85 | - local: prm_trainer 86 | title: PRM 87 | - local: reward_trainer 88 | title: Reward 89 | - local: rloo_trainer 90 | title: RLOO 91 | - local: sft_trainer 92 | title: SFT 93 | - local: iterative_sft_trainer 94 | title: Iterative SFT 95 | - local: xpo_trainer 96 | title: XPO 97 | title: Trainers 98 | - local: models 99 | title: Model Classes 100 | - local: model_utils 101 | title: Model Utilities 102 | - local: best_of_n 103 | title: Best of N Sampling 104 | - local: judges 105 | title: Judges 106 | - local: callbacks 107 | title: Callbacks 108 | - local: data_utils 109 | title: Data Utilities 110 | - local: rewards 111 | title: Reward Functions 112 | - local: script_utils 113 | title: Script Utilities 114 | - local: others 115 | title: Others 116 | title: API 117 | -------------------------------------------------------------------------------- /docs/source/bco_trainer.md: -------------------------------------------------------------------------------- 1 | # BCO Trainer 2 | 3 | [![](https://img.shields.io/badge/All_models-BCO-blue)](https://huggingface.co/models?other=bco,trl) 4 | 5 | TRL supports the Binary Classifier Optimization (BCO). 6 | The [BCO](https://huggingface.co/papers/2404.04656) authors train a binary classifier whose logit serves as a reward so that the classifier maps {prompt, chosen completion} pairs to 1 and {prompt, rejected completion} pairs to 0. 7 | For a full example have a look at [`examples/scripts/bco.py`]. 8 | 9 | ## Expected dataset type 10 | 11 | The [`BCOTrainer`] requires an [unpaired preference dataset](dataset_formats#unpaired-preference). 12 | The [`BCOTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. 13 | 14 | ## Expected model format 15 | The BCO trainer expects a model of `AutoModelForCausalLM`, compared to PPO that expects `AutoModelForCausalLMWithValueHead` for the value function. 16 | 17 | ## Using the `BCOTrainer` 18 | 19 | For a detailed example have a look at the `examples/scripts/bco.py` script. At a high level we need to initialize the `BCOTrainer` with a `model` we wish to train and a reference `ref_model` which we will use to calculate the implicit rewards of the preferred and rejected response. 20 | 21 | The `beta` refers to the hyperparameter of the implicit reward, and the dataset contains the 3 entries listed above. Note that the `model` and `ref_model` need to have the same architecture (ie decoder only or encoder-decoder). 
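For illustration, the `model`, `model_ref`, `tokenizer` and `train_dataset` used below could be prepared as follows (a minimal sketch: the checkpoint name is a placeholder and any unpaired preference dataset will do, so substitute your own):

```py
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "your-org/your-causal-lm"  # placeholder checkpoint, not part of the original example

# `model` and `model_ref` share the same architecture, as required above
model = AutoModelForCausalLM.from_pretrained(model_id)
model_ref = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

# any unpaired preference dataset (prompt / completion / label columns)
train_dataset = load_dataset("trl-lib/kto-mix-14k", split="train")
```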
22 | 23 | 24 | 25 | ```py 26 | training_args = BCOConfig( 27 | beta=0.1, 28 | ) 29 | 30 | bco_trainer = BCOTrainer( 31 | model, 32 | model_ref, 33 | args=training_args, 34 | train_dataset=train_dataset, 35 | processing_class=tokenizer, 36 | ) 37 | ``` 38 | After this one can then call: 39 | 40 | ```py 41 | bco_trainer.train() 42 | ``` 43 | 44 | ## Underlying Distribution matching (UDM) 45 | 46 | In practical scenarios, the thumbs-up and thumbs-down datasets are likely to have divergent underlying distributions of prompts. 47 | Consider an LLM deployed for user feedback: if the model excels in writing tasks but underperforms in coding, the thumbs-up dataset will be dominated by writing-related prompts, while the thumbs-down dataset will contain mostly coding-related prompts. 48 | If the prompts in your desired and undesired datasets differ a lot, it is useful to enable UDM. 49 | 50 | Choose an embedding model and tokenizer: 51 | 52 | ```py 53 | embedding_model = AutoModel.from_pretrained(your_model_id) 54 | embedding_tokenizer = AutoTokenizer.from_pretrained(your_model_id) 55 | 56 | # customize this function depending on your embedding model 57 | def embed_prompt(input_ids, attention_mask, model): 58 | outputs = model(input_ids=input_ids, attention_mask=attention_mask) 59 | return outputs.last_hidden_state.mean(dim=1) 60 | 61 | embedding_model = Accelerator().prepare_model(self.embedding_model) 62 | embedding_func = partial(embed_prompt, model=embedding_model) 63 | ``` 64 | 65 | Set `prompt_sample_size` to define how many prompts are selected to train the UDM classifier and start the training with the provided embedding function: 66 | 67 | ```py 68 | training_args = BCOConfig( 69 | beta=0.1, 70 | prompt_sample_size=512, 71 | ) 72 | 73 | bco_trainer = BCOTrainer( 74 | model, 75 | model_ref, 76 | args=training_args, 77 | train_dataset=train_dataset, 78 | processing_class=tokenizer, 79 | embedding_func=embedding_func, 80 | embedding_tokenizer=self.embedding_tokenizer, 81 | ) 82 | 83 | bco_trainer.train() 84 | ``` 85 | 86 | ### For Mixture of Experts Models: Enabling the auxiliary loss 87 | 88 | MOEs are the most efficient if the load is about equally distributed between experts. 89 | To ensure that we train MOEs similarly during preference-tuning, it is beneficial to add the auxiliary loss from the load balancer to the final loss. 90 | 91 | This option is enabled by setting `output_router_logits=True` in the model config (e.g. MixtralConfig). 92 | To scale how much the auxiliary loss contributes to the total loss, use the hyperparameter `router_aux_loss_coef=...` (default: 0.001). 93 | 94 | ## BCOTrainer 95 | 96 | [[autodoc]] BCOTrainer 97 | 98 | ## BCOConfig 99 | 100 | [[autodoc]] BCOConfig 101 | -------------------------------------------------------------------------------- /docs/source/best_of_n.md: -------------------------------------------------------------------------------- 1 | # Best of N sampling: Alternative ways to get better model output without RL based fine-tuning 2 | 3 | Within the extras module is the `best-of-n` sampler class that serves as an alternative method of generating better model output. 
4 | As to how it fares against the RL based fine-tuning, please look in the `examples` directory for a comparison example 5 | 6 | ## Usage 7 | 8 | To get started quickly, instantiate an instance of the class with a model, a length sampler, a tokenizer and a callable that serves as a proxy reward pipeline that outputs reward scores for input queries 9 | 10 | ```python 11 | 12 | from transformers import pipeline, AutoTokenizer 13 | from trl import AutoModelForCausalLMWithValueHead 14 | from trl.core import LengthSampler 15 | from trl.extras import BestOfNSampler 16 | 17 | ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name) 18 | reward_pipe = pipeline("sentiment-analysis", model=reward_model, device=device) 19 | tokenizer = AutoTokenizer.from_pretrained(ref_model_name) 20 | tokenizer.pad_token = tokenizer.eos_token 21 | 22 | 23 | # callable that takes a list of raw text and returns a list of corresponding reward scores 24 | def queries_to_scores(list_of_strings): 25 | return [output["score"] for output in reward_pipe(list_of_strings)] 26 | 27 | best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler) 28 | 29 | 30 | ``` 31 | 32 | And assuming you have a list/tensor of tokenized queries, you can generate better output by calling the `generate` method 33 | 34 | ```python 35 | 36 | best_of_n.generate(query_tensors, device=device, **gen_kwargs) 37 | 38 | ``` 39 | The default sample size is 4, but you can change it at the time of instance initialization like so 40 | 41 | ```python 42 | 43 | best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler, sample_size=8) 44 | 45 | ``` 46 | 47 | The default output is the result of taking the top scored output for each query, but you can change it to top 2 and so on by passing the `n_candidates` argument at the time of instance initialization 48 | 49 | ```python 50 | 51 | best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler, n_candidates=2) 52 | 53 | ``` 54 | 55 | There is the option of setting the generation settings (like `temperature`, `pad_token_id`) at the time of instance creation as opposed to when calling the `generate` method. 
56 | This is done by passing a `GenerationConfig` from the `transformers` library at the time of initialization 57 | 58 | ```python 59 | 60 | from transformers import GenerationConfig 61 | 62 | generation_config = GenerationConfig(min_length= -1, top_k=0.0, top_p= 1.0, do_sample= True, pad_token_id=tokenizer.eos_token_id) 63 | 64 | best_of_n = BestOfNSampler(model, tokenizer, queries_to_scores, length_sampler=output_length_sampler, generation_config=generation_config) 65 | 66 | best_of_n.generate(query_tensors, device=device) 67 | 68 | ``` 69 | 70 | Furthermore, at the time of initialization you can set the seed to control the repeatability of the generation process and the number of samples to generate for each query 71 | 72 | 73 | -------------------------------------------------------------------------------- /docs/source/callbacks.md: -------------------------------------------------------------------------------- 1 | # Callbacks 2 | 3 | ## SyncRefModelCallback 4 | 5 | [[autodoc]] SyncRefModelCallback 6 | 7 | ## RichProgressCallback 8 | 9 | [[autodoc]] RichProgressCallback 10 | 11 | ## WinRateCallback 12 | 13 | [[autodoc]] WinRateCallback 14 | 15 | ## LogCompletionsCallback 16 | 17 | [[autodoc]] LogCompletionsCallback 18 | 19 | ## MergeModelCallback 20 | 21 | [[autodoc]] MergeModelCallback -------------------------------------------------------------------------------- /docs/source/data_utils.md: -------------------------------------------------------------------------------- 1 | # Data Utilities 2 | 3 | ## is_conversational 4 | 5 | [[autodoc]] is_conversational 6 | 7 | ## apply_chat_template 8 | 9 | [[autodoc]] apply_chat_template 10 | 11 | ## maybe_apply_chat_template 12 | 13 | [[autodoc]] maybe_apply_chat_template 14 | 15 | ## maybe_convert_to_chatml 16 | 17 | [[autodoc]] maybe_convert_to_chatml 18 | 19 | ## extract_prompt 20 | 21 | [[autodoc]] extract_prompt 22 | 23 | ## maybe_extract_prompt 24 | 25 | [[autodoc]] maybe_extract_prompt 26 | 27 | ## unpair_preference_dataset 28 | 29 | [[autodoc]] unpair_preference_dataset 30 | 31 | ## maybe_unpair_preference_dataset 32 | 33 | [[autodoc]] maybe_unpair_preference_dataset 34 | 35 | ## pack_dataset 36 | 37 | [[autodoc]] pack_dataset 38 | 39 | ## truncate_dataset 40 | 41 | [[autodoc]] truncate_dataset 42 | -------------------------------------------------------------------------------- /docs/source/deepspeed_integration.md: -------------------------------------------------------------------------------- 1 | # DeepSpeed Integration 2 | 3 | 4 | 5 | Section under construction. Feel free to contribute! 6 | 7 | 8 | 9 | TRL supports training with DeepSpeed, a library that implements advanced training optimization techniques. These include optimizer state partitioning, offloading, gradient partitioning, and more. 10 | 11 | DeepSpeed integrates the [Zero Redundancy Optimizer (ZeRO)](https://huggingface.co/papers/1910.02054), which allows to scale the model size proportional to the number of devices with sustained high efficiency. 12 | 13 | ![ZeRO Stages](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/zero_stages.png) 14 | 15 | ## Installation 16 | 17 | To use DeepSpeed with TRL, install it using the following command: 18 | 19 | ```bash 20 | pip install deepspeed 21 | ``` 22 | 23 | ## Running Training Scripts with DeepSpeed 24 | 25 | No modifications to your training script are required. 
Simply run it with the DeepSpeed configuration file: 26 | 27 | ```bash 28 | accelerate launch --config_file train.py 29 | ``` 30 | 31 | We provide ready-to-use DeepSpeed configuration files in the [`examples/accelerate_configs`](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs) directory. For example, to run training with ZeRO Stage 2, use the following command: 32 | 33 | ```bash 34 | accelerate launch --config_file examples/accelerate_configs/deepspeed_zero2.yaml train.py 35 | ``` 36 | 37 | ## Additional Resources 38 | 39 | Consult the 🤗 Accelerate [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more information about the DeepSpeed plugin. 40 | -------------------------------------------------------------------------------- /docs/source/distributing_training.md: -------------------------------------------------------------------------------- 1 | # Distributing Training 2 | 3 | 4 | Section under construction. Feel free to contribute! 5 | 6 | 7 | ## Multi-GPU Training with TRL 8 | 9 | The trainers in TRL use [🤗 Accelerate](https://github.com/huggingface/accelerate) to enable distributed training across multiple GPUs or nodes. To do so, first create an [🤗 Accelerate](https://github.com/huggingface/accelerate) config file by running 10 | 11 | ```bash 12 | accelerate config 13 | ``` 14 | 15 | and answering the questions according to your multi-GPU / multi-node setup. You can then launch distributed training by running: 16 | 17 | ```bash 18 | accelerate launch train.py 19 | ``` 20 | 21 | We also provide config files in the [examples folder](https://github.com/huggingface/trl/tree/main/examples/accelerate_configs) that can be used as templates. To use these templates, simply pass the path to the config file when launching a job, e.g.: 22 | 23 | ```shell 24 | accelerate launch --config_file examples/accelerate_configs/multi_gpu.yaml train.py 25 | ``` 26 | 27 | This automatically distributes the workload across all available GPUs. 28 | 29 | Under the hood, [🤗 Accelerate](https://github.com/huggingface/accelerate) creates one model per GPU. Each process: 30 | - Processes its own batch of data 31 | - Computes the loss and gradients for that batch 32 | - Shares gradient updates across all GPUs 33 | 34 | ![](https://huggingface.co/datasets/trl-lib/documentation-images/resolve/main/multi_gpu.png) 35 | 36 | The effective batch size is calculated as: 37 | 38 | $$ 39 | \text{Batch Size} = \text{per\_device\_train\_batch\_size} \times \text{num\_devices} \times \text{gradient\_accumulation\_steps} 40 | $$ 41 | 42 | To maintain a consistent batch size when scaling to multiple GPUs, make sure to update `per_device_train_batch_size` and `gradient_accumulation_steps` accordingly. 43 | 44 | Example, these configurations are equivalent, and should yield the same results: 45 | 46 | | Number of GPUs | Per device batch size | Gradient accumulation steps | Comments | 47 | | --- | --- | --- | --- | 48 | | 1 | 32 | 1 | Possibly high memory usage, but faster training | 49 | | 1 | 4 | 8 | Lower memory usage, slower training | 50 | | 8 | 4 | 1 | Multi-GPU to get the best of both worlds | 51 | 52 | 53 | 54 | Having one model per GPU can lead to high memory usage, which may not be feasible for large models or low-memory GPUs. In such cases, you can leverage [DeepSpeed](https://github.com/deepspeedai/DeepSpeed), which provides optimizations like model sharding, Zero Redundancy Optimizer, mixed precision training, and offloading to CPU or NVMe. 
Check out our [DeepSpeed Integration](deepspeed_integration.md) guide for more details. 55 | 56 | 57 | 58 | ## Multi-Node Training 59 | 60 | We're working on a guide for multi-node training. Stay tuned! 🚀 -------------------------------------------------------------------------------- /docs/source/installation.md: -------------------------------------------------------------------------------- 1 | # Installation 2 | You can install TRL either from PyPI or from source: 3 | 4 | ## PyPI 5 | Install the library with pip or [uv](https://docs.astral.sh/uv/): 6 | 7 | 8 | 9 | 10 | uv is a fast Rust-based Python package and project manager. Refer to [Installation](https://docs.astral.sh/uv/getting-started/installation/) for installation instructions. 11 | 12 | ```bash 13 | uv pip install trl 14 | ``` 15 | 16 | 17 | 18 | 19 | ```bash 20 | pip install trl 21 | ``` 22 | 23 | 24 | 25 | 26 | ## Source 27 | You can also install the latest version from source. First clone the repo and then run the installation with `pip`: 28 | 29 | ```bash 30 | git clone https://github.com/huggingface/trl.git 31 | cd trl/ 32 | pip install -e . 33 | ``` 34 | 35 | If you want the development install, you can replace the pip install with the following: 36 | 37 | ```bash 38 | pip install -e ".[dev]" 39 | ``` 40 | -------------------------------------------------------------------------------- /docs/source/iterative_sft_trainer.md: -------------------------------------------------------------------------------- 1 | # Iterative Trainer 2 | 3 | [![](https://img.shields.io/badge/All_models-Iterative_SFT-blue)](https://huggingface.co/models?other=iterative-sft,trl) 4 | 5 | Iterative fine-tuning is a training method that enables performing custom actions (for example, generation and filtering) between optimization steps. In TRL we provide an easy-to-use API to fine-tune your models in an iterative way in just a few lines of code. 6 | 7 | ## Quickstart 8 | 9 | To get started quickly, you can either pass a model identifier or a pre-instantiated model to the trainer: 10 | 11 | ```python 12 | from trl import IterativeSFTConfig, IterativeSFTTrainer 13 | 14 | # Using a model identifier 15 | trainer = IterativeSFTTrainer( 16 | "facebook/opt-350m", 17 | args=IterativeSFTConfig( 18 | max_length=512, 19 | output_dir="./output", 20 | ), 21 | ) 22 | 23 | # Or using a pre-instantiated model 24 | from transformers import AutoModelForCausalLM, AutoTokenizer 25 | 26 | model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m") 27 | tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m") 28 | 29 | trainer = IterativeSFTTrainer( 30 | model, 31 | args=IterativeSFTConfig( 32 | max_length=512, 33 | output_dir="./output", 34 | ), 35 | processing_class=tokenizer, 36 | ) 37 | ``` 38 | 39 | ## Usage 40 | 41 | The [`IterativeSFTTrainer`] supports two ways of providing input data to the `step` function: 42 | 43 | ### Using a list of tensors as input: 44 | 45 | ```python 46 | inputs = { 47 | "input_ids": input_ids, 48 | "attention_mask": attention_mask, 49 | } 50 | 51 | trainer.step(**inputs) 52 | ``` 53 | 54 | ### Using a list of strings as input: 55 | 56 | ```python 57 | inputs = { 58 | "texts": texts, 59 | "texts_labels": texts_labels, # Optional, defaults to texts 60 | } 61 | 62 | trainer.step(**inputs) 63 | ``` 64 | 65 | For causal language models, labels will automatically be created from `input_ids` or from `texts`. When using sequence-to-sequence models, you will have to provide your own labels or `texts_labels`.
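
As a rough illustration of the intended workflow, the sketch below alternates generation, a custom filtering step, and an optimization step. The prompts, the filtering rule, and the number of iterations are illustrative assumptions, not part of the trainer API:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import IterativeSFTConfig, IterativeSFTTrainer

model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m")
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

trainer = IterativeSFTTrainer(
    model,
    args=IterativeSFTConfig(max_length=512, output_dir="./output"),
    processing_class=tokenizer,
)

prompts = ["The movie was", "The weather today is"]  # illustrative prompts

for _ in range(3):  # a few outer iterations
    # 1. Generate candidate completions with the current model
    batch = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
    generations = model.generate(**batch, max_new_tokens=32, do_sample=True)
    texts = tokenizer.batch_decode(generations, skip_special_tokens=True)

    # 2. Apply any custom filtering or scoring (here: keep the longest completion)
    texts = sorted(texts, key=len, reverse=True)[:1]

    # 3. Run an optimization step on the retained texts
    trainer.step(texts=texts)
```
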
66 | 67 | ## Configuration 68 | 69 | The [`IterativeSFTConfig`] class provides several parameters to customize the training: 70 | 71 | ```python 72 | from trl import IterativeSFTConfig 73 | 74 | config = IterativeSFTConfig( 75 | # Model initialization parameters 76 | model_init_kwargs={"torch_dtype": "bfloat16"}, 77 | 78 | # Data preprocessing parameters 79 | max_length=512, 80 | truncation_mode="keep_end", 81 | 82 | # Training parameters 83 | output_dir="./output", 84 | learning_rate=2e-5, 85 | per_device_train_batch_size=4, 86 | gradient_accumulation_steps=4, 87 | max_steps=1000, 88 | logging_steps=10, 89 | save_steps=100, 90 | optim="adamw_torch", 91 | report_to="wandb", 92 | ) 93 | ``` 94 | 95 | ### Model Initialization 96 | 97 | You can control how the model is initialized by passing keyword arguments to `model_init_kwargs`: 98 | 99 | ```python 100 | config = IterativeSFTConfig( 101 | model_init_kwargs={ 102 | "torch_dtype": "bfloat16", 103 | "device_map": "auto", 104 | "trust_remote_code": True, 105 | } 106 | ) 107 | ``` 108 | 109 | ### Data Preprocessing 110 | 111 | The trainer supports two truncation modes: 112 | 113 | - `keep_end`: Truncates from the start of the sequence 114 | - `keep_start`: Truncates from the end of the sequence 115 | 116 | ```python 117 | config = IterativeSFTConfig( 118 | max_length=512, 119 | truncation_mode="keep_end", # or "keep_start" 120 | ) 121 | ``` 122 | 123 | ### Training Optimization 124 | 125 | You can optimize CUDA cache usage for more memory-efficient training: 126 | 127 | ```python 128 | config = IterativeSFTConfig( 129 | optimize_device_cache=True, 130 | ) 131 | ``` 132 | 133 | ## IterativeSFTTrainer 134 | 135 | [[autodoc]] IterativeSFTTrainer 136 | 137 | ## IterativeSFTConfig 138 | 139 | [[autodoc]] IterativeSFTConfig 140 | -------------------------------------------------------------------------------- /docs/source/judges.md: -------------------------------------------------------------------------------- 1 | # Judges 2 | 3 | 4 | 5 | TRL Judges is an experimental API which is subject to change at any time. 6 | 7 | 8 | 9 | TRL provides judges to easily compare two completions. 10 | 11 | Make sure to have installed the required dependencies by running: 12 | 13 | ```bash 14 | pip install trl[judges] 15 | ``` 16 | 17 | ## Using the provided judges 18 | 19 | TRL provides several judges out of the box. For example, you can use the `HfPairwiseJudge` to compare two completions using a pre-trained model from the Hugging Face model hub: 20 | 21 | ```python 22 | from trl import HfPairwiseJudge 23 | 24 | judge = HfPairwiseJudge() 25 | judge.judge( 26 | prompts=["What is the capital of France?", "What is the biggest planet in the solar system?"], 27 | completions=[["Paris", "Lyon"], ["Saturn", "Jupiter"]], 28 | ) # Outputs: [0, 1] 29 | ``` 30 | 31 | ## Define your own judge 32 | 33 | To define your own judge, we provide several base classes that you can subclass. For rank-based judges, you need to subclass [`BaseRankJudge`] and implement the [`BaseRankJudge.judge`] method. For pairwise judges, you need to subclass [`BasePairwiseJudge`] and implement the [`BasePairwiseJudge.judge`] method. If you want to define a judge that doesn't fit into these categories, you need to subclass [`BaseJudge`] and implement the [`BaseJudge.judge`] method.
34 | 35 | As an example, let's define a pairwise judge that prefers shorter completions: 36 | 37 | ```python 38 | from trl import BasePairwiseJudge 39 | 40 | class PrefersShorterJudge(BasePairwiseJudge): 41 | def judge(self, prompts, completions, shuffle_order=False): 42 | return [0 if len(completion[0]) > len(completion[1]) else 1 for completion in completions] 43 | ``` 44 | 45 | You can then use this judge as follows: 46 | 47 | ```python 48 | judge = PrefersShorterJudge() 49 | judge.judge( 50 | prompts=["What is the capital of France?", "What is the biggest planet in the solar system?"], 51 | completions=[["Paris", "The capital of France is Paris."], ["Jupiter is the biggest planet in the solar system.", "Jupiter"]], 52 | ) # Outputs: [0, 1] 53 | ``` 54 | 55 | ## Provided judges 56 | 57 | ### PairRMJudge 58 | 59 | [[autodoc]] PairRMJudge 60 | 61 | ### HfPairwiseJudge 62 | 63 | [[autodoc]] HfPairwiseJudge 64 | 65 | ### OpenAIPairwiseJudge 66 | 67 | [[autodoc]] OpenAIPairwiseJudge 68 | 69 | ### AllTrueJudge 70 | 71 | [[autodoc]] AllTrueJudge 72 | 73 | ## Base classes 74 | 75 | ### BaseJudge 76 | 77 | [[autodoc]] BaseJudge 78 | 79 | ### BaseBinaryJudge 80 | 81 | [[autodoc]] BaseBinaryJudge 82 | 83 | ### BaseRankJudge 84 | 85 | [[autodoc]] BaseRankJudge 86 | 87 | ### BasePairwiseJudge 88 | 89 | [[autodoc]] BasePairwiseJudge 90 | -------------------------------------------------------------------------------- /docs/source/liger_kernel_integration.md: -------------------------------------------------------------------------------- 1 | # Liger Kernel Integration 2 | 3 | 4 | 5 | Section under construction. Feel free to contribute! 6 | 7 | -------------------------------------------------------------------------------- /docs/source/model_utils.md: -------------------------------------------------------------------------------- 1 | # Model Utilities 2 | 3 | ## get_act_offloading_ctx_manager 4 | 5 | [[autodoc]] models.get_act_offloading_ctx_manager 6 | -------------------------------------------------------------------------------- /docs/source/models.md: -------------------------------------------------------------------------------- 1 | # Models 2 | 3 | With the `AutoModelForCausalLMWithValueHead` class TRL supports all decoder model architectures in transformers such as GPT-2, OPT, and GPT-Neo. In addition, with `AutoModelForSeq2SeqLMWithValueHead` you can use encoder-decoder architectures such as T5. TRL also requires reference models which are frozen copies of the model that is trained. With `create_reference_model` you can easily create a frozen copy and also share layers between the two models to save memory. 
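
As a brief sketch (the base model and the number of shared layers are arbitrary choices for illustration):

```python
from trl import AutoModelForCausalLMWithValueHead, create_reference_model

model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2")

# Frozen copy of the full model
ref_model = create_reference_model(model)

# Or share the first 6 layers between the trained model and the reference model to save memory
ref_model_shared_layers = create_reference_model(model, num_shared_layers=6)
```
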
4 | 5 | ## PreTrainedModelWrapper 6 | 7 | [[autodoc]] PreTrainedModelWrapper 8 | 9 | ## AutoModelForCausalLMWithValueHead 10 | 11 | 12 | [[autodoc]] AutoModelForCausalLMWithValueHead 13 | - __init__ 14 | - forward 15 | - generate 16 | - _init_weights 17 | 18 | ## AutoModelForSeq2SeqLMWithValueHead 19 | 20 | [[autodoc]] AutoModelForSeq2SeqLMWithValueHead 21 | - __init__ 22 | - forward 23 | - generate 24 | - _init_weights 25 | 26 | ## create_reference_model 27 | 28 | [[autodoc]] create_reference_model -------------------------------------------------------------------------------- /docs/source/multi_adapter_rl.md: -------------------------------------------------------------------------------- 1 | # Multi Adapter RL (MARL) - a single base model for everything 2 | 3 | Here we present an approach that uses a single base model for the entire PPO algorithm - which includes retrieving the reference logits, computing the active logits and the rewards. This feature is experimental as we did not test the convergence of the approach. We encourage the community to let us know if they potentially face issues. 4 | 5 | ## Requirements 6 | 7 | You just need to install `peft` and optionally install `bitsandbytes` as well if you want to go for 8bit base models, for more memory efficient finetuning. 8 | 9 | ## Summary 10 | 11 | You need to address this approach in three stages that we summarize as follows: 12 | 13 | 1- Train a base model on the target domain (e.g. [IMDB dataset](https://huggingface.co/datasets/stanfordnlp/imdb)) - this is the Supervised Fine Tuning stage - it can leverage the `SFTTrainer` from TRL. 14 | 2- Train a reward model using `peft`. This is required in order to re-use the adapter during the RL optimisation process (step 3 below). We show an example of leveraging the `RewardTrainer` from TRL in [this example](https://github.com/huggingface/trl/tree/main/examples/scripts/reward_modeling.py) 15 | 3- Fine tune new adapters on the base model using PPO and the reward adapter. ("0 abstraction RL") 16 | 17 | Make sure to use the same model (i.e. same architecture and same weights) for the stages 2 & 3. 18 | 19 | ## Quickstart 20 | 21 | Let us assume you have trained your reward adapter on `llama-7b` model using `RewardTrainer` and pushed the weights on the hub under `trl-lib/llama-7b-hh-rm-adapter`. 22 | When doing PPO, before passing the model to `PPOTrainer` create your model as follows: 23 | 24 | ```python 25 | model_name = "huggyllama/llama-7b" 26 | rm_adapter_id = "trl-lib/llama-7b-hh-rm-adapter" 27 | 28 | # PPO adapter 29 | lora_config = LoraConfig( 30 | r=16, 31 | lora_alpha=32, 32 | lora_dropout=0.05, 33 | bias="none", 34 | task_type="CAUSAL_LM", 35 | ) 36 | 37 | model = AutoModelForCausalLMWithValueHead.from_pretrained( 38 | model_name, 39 | peft_config=lora_config, 40 | reward_adapter=rm_adapter_id, 41 | ) 42 | 43 | ... 44 | trainer = PPOTrainer( 45 | model=model, 46 | ... 47 | ) 48 | 49 | ... 50 | ``` 51 | Then inside your PPO training loop, call the `compute_reward_score` method by accessing the `model` attribute from `PPOTrainer`. 52 | 53 | ```python 54 | rewards = trainer.model.compute_reward_score(**inputs) 55 | ``` 56 | 57 | ## Advanced usage 58 | 59 | ### Control on the adapter name 60 | 61 | If you are familiar with the `peft` library, you know that you can use multiple adapters inside the same model. What you can do is train multiple adapters on the same base model to fine-tune on different policies. 
62 | In this case, you want to be able to control the adapter name you want to activate back, after retrieving the reward. For that, simply pass the appropriate `adapter_name` to `ppo_adapter_name` argument when calling `compute_reward_score`. 63 | 64 | ```python 65 | adapter_name_policy_1 = "policy_1" 66 | rewards = trainer.model.compute_reward_score(**inputs, ppo_adapter_name=adapter_name_policy_1) 67 | ... 68 | ``` 69 | 70 | ### Using 4-bit and 8-bit base models 71 | 72 | For more memory efficient fine-tuning, you can load your base model in 8-bit or 4-bit while keeping the adapters in the default precision (float32). 73 | Just pass the appropriate arguments (i.e. `load_in_8bit=True` or `load_in_4bit=True`) to `AutoModelForCausalLMWithValueHead.from_pretrained` as follows (assuming you have installed `bitsandbytes`): 74 | ```python 75 | model_name = "llama-7b" 76 | rm_adapter_id = "trl-lib/llama-7b-hh-rm-adapter" 77 | 78 | # PPO adapter 79 | lora_config = LoraConfig( 80 | r=16, 81 | lora_alpha=32, 82 | lora_dropout=0.05, 83 | bias="none", 84 | task_type="CAUSAL_LM", 85 | ) 86 | 87 | model = AutoModelForCausalLMWithValueHead.from_pretrained( 88 | model_name, 89 | peft_config=lora_config, 90 | reward_adapter=rm_adapter_id, 91 | load_in_8bit=True, 92 | ) 93 | 94 | ... 95 | trainer = PPOTrainer( 96 | model=model, 97 | ... 98 | ) 99 | ... 100 | ``` 101 | -------------------------------------------------------------------------------- /docs/source/others.md: -------------------------------------------------------------------------------- 1 | # Other 2 | 3 | ## profiling_decorator 4 | 5 | [[autodoc]] extras.profiling.profiling_decorator 6 | 7 | ## profiling_context 8 | 9 | [[autodoc]] extras.profiling.profiling_context 10 | -------------------------------------------------------------------------------- /docs/source/quickstart.md: -------------------------------------------------------------------------------- 1 | # Quickstart 2 | 3 | ## How does it work? 4 | 5 | Fine-tuning a language model via PPO consists of roughly three steps: 6 | 7 | 1. **Rollout**: The language model generates a response or continuation based on a query which could be the start of a sentence. 8 | 2. **Evaluation**: The query and response are evaluated with a function, model, human feedback, or some combination of them. The important thing is that this process should yield a scalar value for each query/response pair. The optimization will aim at maximizing this value. 9 | 3. **Optimization**: This is the most complex part. In the optimisation step the query/response pairs are used to calculate the log-probabilities of the tokens in the sequences. This is done with the model that is trained and a reference model, which is usually the pre-trained model before fine-tuning. The KL-divergence between the two outputs is used as an additional reward signal to make sure the generated responses don't deviate too far from the reference language model. The active language model is then trained with PPO. 10 | 11 | The full process is illustrated in the following figure: 12 | 13 | 14 | ## Minimal example 15 | 16 | The following code illustrates the steps above. 17 | 18 | ```python 19 | # 0. imports 20 | import torch 21 | from transformers import GPT2Tokenizer 22 | 23 | from trl import AutoModelForCausalLMWithValueHead, PPOConfig, PPOTrainer 24 | 25 | 26 | # 1. 
load a pretrained model 27 | model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2") 28 | ref_model = AutoModelForCausalLMWithValueHead.from_pretrained("gpt2") 29 | tokenizer = GPT2Tokenizer.from_pretrained("gpt2") 30 | tokenizer.pad_token = tokenizer.eos_token 31 | 32 | # 2. initialize trainer 33 | ppo_config = {"mini_batch_size": 1, "batch_size": 1} 34 | config = PPOConfig(**ppo_config) 35 | ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer) 36 | 37 | # 3. encode a query 38 | query_txt = "This morning I went to the " 39 | query_tensor = tokenizer.encode(query_txt, return_tensors="pt").to(model.pretrained_model.device) 40 | 41 | # 4. generate model response 42 | generation_kwargs = { 43 | "min_length": -1, 44 | "top_k": 0.0, 45 | "top_p": 1.0, 46 | "do_sample": True, 47 | "pad_token_id": tokenizer.eos_token_id, 48 | "max_new_tokens": 20, 49 | } 50 | response_tensor = ppo_trainer.generate([item for item in query_tensor], return_prompt=False, **generation_kwargs) 51 | response_txt = tokenizer.decode(response_tensor[0]) 52 | 53 | # 5. define a reward for response 54 | # (this could be any reward such as human feedback or output from another model) 55 | reward = [torch.tensor(1.0, device=model.pretrained_model.device)] 56 | 57 | # 6. train model with ppo 58 | train_stats = ppo_trainer.step([query_tensor[0]], [response_tensor[0]], reward) 59 | ``` 60 | 61 | In general, you would run steps 3-6 in a for-loop and run it on many diverse queries. You can find more realistic examples in the examples section. 62 | 63 | ## How to use a trained model 64 | 65 | After training a `AutoModelForCausalLMWithValueHead`, you can directly use the model in `transformers`. 66 | ```python 67 | 68 | # .. Let's assume we have a trained model using `PPOTrainer` and `AutoModelForCausalLMWithValueHead` 69 | 70 | # push the model on the Hub 71 | model.push_to_hub("my-fine-tuned-model-ppo") 72 | 73 | # or save it locally 74 | model.save_pretrained("my-fine-tuned-model-ppo") 75 | 76 | # load the model from the Hub 77 | from transformers import AutoModelForCausalLM 78 | 79 | model = AutoModelForCausalLM.from_pretrained("my-fine-tuned-model-ppo") 80 | ``` 81 | 82 | You can also load your model with `AutoModelForCausalLMWithValueHead` if you want to use the value head, for example to continue training. 83 | 84 | ```python 85 | from trl.model import AutoModelForCausalLMWithValueHead 86 | 87 | model = AutoModelForCausalLMWithValueHead.from_pretrained("my-fine-tuned-model-ppo") 88 | ``` 89 | -------------------------------------------------------------------------------- /docs/source/reward_trainer.md: -------------------------------------------------------------------------------- 1 | # Reward Modeling 2 | 3 | [![](https://img.shields.io/badge/All_models-Reward_Trainer-blue)](https://huggingface.co/models?other=reward-trainer,trl) 4 | 5 | TRL supports custom reward modeling for anyone to perform reward modeling on their dataset and model. 6 | 7 | Check out a complete flexible example at [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/reward_modeling.py). 8 | 9 | ## Expected dataset type 10 | 11 | The [`RewardTrainer`] requires a [*implicit prompt* preference dataset](dataset_formats#preference). It means that the dataset should only contain the columns `"chosen"` and `"rejected"` (and not `"prompt"`). 12 | The [`RewardTrainer`] supports both [conversational](dataset_formats#conversational) and [standard](dataset_formats#standard) dataset format. 
When provided with a conversational dataset, the trainer will automatically apply the chat template to the dataset. 13 | 14 | You can also use a pretokenized dataset, in which case the dataset should contain the following columns: `input_ids_chosen`, `attention_mask_chosen`, `input_ids_rejected` and `attention_mask_rejected`. 15 | 16 | ## Using the `RewardTrainer` 17 | 18 | After preparing your dataset, you can use the [`RewardTrainer`] in the same way as the `Trainer` class from 🤗 Transformers. 19 | You should pass an `AutoModelForSequenceClassification` model to the [`RewardTrainer`], along with a [`RewardConfig`] which configures the hyperparameters of the training. 20 | 21 | ### Leveraging 🤗 PEFT to train a reward model 22 | 23 | Just pass a `peft_config` in the keyword arguments of [`RewardTrainer`], and the trainer should automatically take care of converting the model into a PEFT model! 24 | 25 | ```python 26 | from peft import LoraConfig, TaskType 27 | from transformers import AutoModelForSequenceClassification, AutoTokenizer 28 | from trl import RewardTrainer, RewardConfig 29 | 30 | model = AutoModelForSequenceClassification.from_pretrained("gpt2") 31 | peft_config = LoraConfig( 32 | task_type=TaskType.SEQ_CLS, 33 | inference_mode=False, 34 | r=8, 35 | lora_alpha=32, 36 | lora_dropout=0.1, 37 | ) 38 | 39 | ... 40 | 41 | trainer = RewardTrainer( 42 | model=model, 43 | args=training_args, 44 | processing_class=tokenizer, 45 | train_dataset=dataset, 46 | peft_config=peft_config, 47 | ) 48 | 49 | trainer.train() 50 | 51 | ``` 52 | 53 | ### Adding a margin to the loss 54 | 55 | As in the [Llama 2 paper](https://huggingface.co/papers/2307.09288), you can add a margin to the loss by adding a `margin` column to the dataset. The reward collator will automatically pass it through and the loss will be computed accordingly. 56 | 57 | ```python 58 | def add_margin(row): 59 | # Assume you have a score_chosen and score_rejected columns that you want to use to compute the margin 60 | return {'margin': row['score_chosen'] - row['score_rejected']} 61 | 62 | dataset = dataset.map(add_margin) 63 | ``` 64 | 65 | ### Centering rewards 66 | 67 | In many scenarios, it's preferable to ensure that a reward model's output is mean zero. This is often done by first calculating the model's average score and then subtracting it. 68 | 69 | [[Eisenstein et al., 2023]](https://huggingface.co/papers/2312.09244) proposed an auxiliary loss function designed to directly learn a centered reward model. This auxiliary loss minimizes the squared sum of the rewards, encouraging the model to naturally produce mean-zero outputs: 70 | 71 | $$\Big( R(p, r_1) + R(p, r_2) \Big)^2 $$ 72 | 73 | This auxiliary loss is combined with the main loss function, weighted by the parameter `center_rewards_coefficient` in the `[RewardConfig]`. By default, this feature is deactivated (`center_rewards_coefficient = None`). 74 | 75 | ```python 76 | training_args = RewardConfig( 77 | center_rewards_coefficient=0.01, 78 | ... 79 | ) 80 | ``` 81 | 82 | For reference results, please refer PR [#1932](https://github.com/huggingface/trl/pull/1932). 
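
To see how the pieces above fit together, here is a minimal end-to-end sketch. The base model, dataset, and hyperparameters are illustrative assumptions; the maintained reference script is [`examples/scripts/reward_modeling.py`](https://github.com/huggingface/trl/tree/main/examples/scripts/reward_modeling.py):

```python
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from trl import RewardConfig, RewardTrainer

model_id = "Qwen/Qwen2.5-0.5B-Instruct"  # assumed base model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSequenceClassification.from_pretrained(model_id, num_labels=1)
model.config.pad_token_id = tokenizer.pad_token_id

# Implicit-prompt preference dataset with "chosen" and "rejected" columns (assumed dataset id)
dataset = load_dataset("trl-lib/ultrafeedback_binarized", split="train")

training_args = RewardConfig(
    output_dir="Qwen2.5-0.5B-Reward",
    per_device_train_batch_size=2,
    num_train_epochs=1,
    center_rewards_coefficient=0.01,  # optional auxiliary loss for mean-zero rewards
)

trainer = RewardTrainer(
    model=model,
    args=training_args,
    processing_class=tokenizer,
    train_dataset=dataset,
)
trainer.train()
```
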
83 | 84 | ## RewardTrainer 85 | 86 | [[autodoc]] RewardTrainer 87 | 88 | ## RewardConfig 89 | 90 | [[autodoc]] RewardConfig 91 | -------------------------------------------------------------------------------- /docs/source/rewards.md: -------------------------------------------------------------------------------- 1 | # Reward Functions 2 | 3 | This module contains some useful reward functions, primarily intended for use with the [`GRPOTrainer`]. 4 | 5 | ## Format rewards 6 | 7 | ### think_format_reward 8 | 9 | [[autodoc]] rewards.think_format_reward 10 | -------------------------------------------------------------------------------- /docs/source/script_utils.md: -------------------------------------------------------------------------------- 1 | # Scripts Utilities 2 | 3 | ## ScriptArguments 4 | 5 | [[autodoc]] ScriptArguments 6 | 7 | ## TrlParser 8 | 9 | [[autodoc]] TrlParser 10 | - parse_args_and_config 11 | - parse_args_into_dataclasses 12 | - set_defaults_with_config 13 | -------------------------------------------------------------------------------- /docs/source/sentiment_tuning.md: -------------------------------------------------------------------------------- 1 | # Sentiment Tuning Examples 2 | 3 | The notebooks and scripts in this examples show how to fine-tune a model with a sentiment classifier (such as `lvwerra/distilbert-imdb`). 4 | 5 | Here's an overview of the notebooks and scripts in the [trl repository](https://github.com/huggingface/trl/tree/main/examples): 6 | 7 | 8 | 9 | | File | Description | 10 | |------------------------------------------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------| 11 | | [`examples/scripts/ppo.py`](https://github.com/huggingface/trl/blob/main/examples/scripts/ppo.py) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment.ipynb) | This script shows how to use the `PPOTrainer` to fine-tune a sentiment analysis model using IMDB dataset | 12 | | [`examples/notebooks/gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb) | This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook. | 13 | | [`examples/notebooks/gpt2-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-control.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/trl/blob/main/examples/sentiment/notebooks/gpt2-sentiment-control.ipynb) | This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook. 14 | 15 | 16 | 17 | ## Usage 18 | 19 | ```bash 20 | # 1. run directly 21 | python examples/scripts/ppo.py 22 | # 2. run via `accelerate` (recommended), enabling more features (e.g., multiple GPUs, deepspeed) 23 | accelerate config # will prompt you to define the training configuration 24 | accelerate launch examples/scripts/ppo.py # launches training 25 | # 3. get help text and documentation 26 | python examples/scripts/ppo.py --help 27 | # 4. 
configure logging with wandb and, say, mini_batch_size=1 and gradient_accumulation_steps=16 28 | python examples/scripts/ppo.py --log_with wandb --mini_batch_size 1 --gradient_accumulation_steps 16 29 | ``` 30 | 31 | Note: if you don't want to log with `wandb` remove `log_with="wandb"` in the scripts/notebooks. You can also replace it with your favourite experiment tracker that's [supported by `accelerate`](https://huggingface.co/docs/accelerate/usage_guides/tracking). 32 | 33 | 34 | ## Few notes on multi-GPU 35 | 36 | To run in multi-GPU setup with DDP (distributed Data Parallel) change the `device_map` value to `device_map={"": Accelerator().process_index}` and make sure to run your script with `accelerate launch yourscript.py`. If you want to apply naive pipeline parallelism you can use `device_map="auto"`. -------------------------------------------------------------------------------- /docs/source/speeding_up_training.md: -------------------------------------------------------------------------------- 1 | # Speeding Up Training 2 | 3 | 4 | 5 | Section under construction. Feel free to contribute! 6 | 7 | 8 | 9 | ## vLLM for fast generation in online methods 10 | 11 | Online methods such as GRPO or Online DPO require the model to generate completions, which is often a slow process and can significantly impact training time. 12 | To speed up generation, you can use [vLLM](https://github.com/vllm-project/vllm), a library that enables fast generation through, among other things, PagedAttention. TRL's online trainers support vLLM, greatly improving training speed. 13 | 14 | To use [vLLM](https://github.com/vllm-project/vllm), first install it using: 15 | 16 | ```bash 17 | pip install vllm 18 | ``` 19 | 20 | or 21 | 22 | ```bash 23 | pip install "trl[vllm]" 24 | ``` 25 | 26 | 27 | 28 | 29 | Then, enable it by passing `use_vllm=True` in the training arguments. 30 | 31 | ```python 32 | from trl import OnlineDPOConfig 33 | 34 | training_args = OnlineDPOConfig(..., use_vllm=True) 35 | ``` 36 | 37 | 38 | 39 | 40 | First, start a vLLM server by running: 41 | 42 | ```bash 43 | trl vllm-serve --model 44 | ``` 45 | 46 | Then, run the training script and pass `use_vllm=True` in the training arguments. 47 | 48 | ```python 49 | from trl import GRPOConfig 50 | 51 | training_args = GRPOConfig(..., use_vllm=True) 52 | ``` 53 | 54 | You can customize the server configuration by passing additional arguments. For more information, see [vLLM integration](vllm_integration). 55 | 56 | 57 | 58 | When using vLLM, ensure that the GPUs assigned for training and generation are separate to avoid resource conflicts. For instance, if you plan to use 4 GPUs for training and another 4 for vLLM generation, you can specify GPU allocation using `CUDA_VISIBLE_DEVICES`. 59 | 60 | Set GPUs **0-3** for vLLM generation: 61 | ```sh 62 | CUDA_VISIBLE_DEVICES=0,1,2,3 trl vllm-serve --model 63 | ``` 64 | 65 | And GPUs **4-7** for training: 66 | ```sh 67 | CUDA_VISIBLE_DEVICES=4,5,6,7 accelerate launch train.py 68 | ``` 69 | 70 | 71 | 72 | 73 | 74 | -------------------------------------------------------------------------------- /docs/source/unsloth_integration.md: -------------------------------------------------------------------------------- 1 | # Unsloth Integration 2 | 3 | 4 | 5 | Section under construction. Feel free to contribute! 
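
In the meantime, the usual pattern is to load and patch the model with Unsloth's `FastLanguageModel` and then pass it to a TRL trainer as you normally would. The sketch below is only a rough illustration based on Unsloth's documented API; the model name, LoRA settings, and dataset are assumptions:

```python
from datasets import load_dataset
from unsloth import FastLanguageModel
from trl import SFTConfig, SFTTrainer

# Load a 4-bit quantized model patched with Unsloth's optimized kernels (assumed model id)
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/llama-3-8b-bnb-4bit",
    max_seq_length=2048,
    load_in_4bit=True,
)

# Attach LoRA adapters through Unsloth
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    lora_alpha=16,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)

dataset = load_dataset("stanfordnlp/imdb", split="train")

trainer = SFTTrainer(
    model=model,
    processing_class=tokenizer,
    train_dataset=dataset,
    args=SFTConfig(output_dir="unsloth-sft", max_seq_length=2048),
)
trainer.train()
```
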
6 | 7 | -------------------------------------------------------------------------------- /docs/source/use_model.md: -------------------------------------------------------------------------------- 1 | # Use model after training 2 | 3 | Once you have trained a model using either the SFTTrainer, PPOTrainer, or DPOTrainer, you will have a fine-tuned model that can be used for text generation. In this section, we'll walk through the process of loading the fine-tuned model and generating text. If you need to run an inference server with the trained model, you can explore libraries such as [`text-generation-inference`](https://github.com/huggingface/text-generation-inference). 4 | 5 | ## Load and Generate 6 | 7 | If you have fine-tuned a model fully, meaning without the use of PEFT you can simply load it like any other language model in transformers. E.g. the value head that was trained during the PPO training is no longer needed and if you load the model with the original transformer class it will be ignored: 8 | 9 | ```python 10 | from transformers import AutoTokenizer, AutoModelForCausalLM 11 | 12 | model_name_or_path = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub 13 | device = "cpu" # or "cuda" if you have a GPU 14 | 15 | model = AutoModelForCausalLM.from_pretrained(model_name_or_path).to(device) 16 | tokenizer = AutoTokenizer.from_pretrained(model_name_or_path) 17 | 18 | inputs = tokenizer.encode("This movie was really", return_tensors="pt").to(device) 19 | outputs = model.generate(inputs) 20 | print(tokenizer.decode(outputs[0])) 21 | ``` 22 | 23 | Alternatively you can also use the pipeline: 24 | 25 | ```python 26 | from transformers import pipeline 27 | 28 | model_name_or_path = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub 29 | pipe = pipeline("text-generation", model=model_name_or_path) 30 | print(pipe("This movie was really")[0]["generated_text"]) 31 | ``` 32 | 33 | ## Use Adapters PEFT 34 | 35 | ```python 36 | from peft import PeftConfig, PeftModel 37 | from transformers import AutoModelForCausalLM, AutoTokenizer 38 | 39 | base_model_name = "kashif/stack-llama-2" #path/to/your/model/or/name/on/hub" 40 | adapter_model_name = "path/to/my/adapter" 41 | 42 | model = AutoModelForCausalLM.from_pretrained(base_model_name) 43 | model = PeftModel.from_pretrained(model, adapter_model_name) 44 | 45 | tokenizer = AutoTokenizer.from_pretrained(base_model_name) 46 | ``` 47 | 48 | You can also merge the adapters into the base model so you can use the model like a normal transformers model, however the checkpoint will be significantly bigger: 49 | 50 | ```python 51 | model = AutoModelForCausalLM.from_pretrained(base_model_name) 52 | model = PeftModel.from_pretrained(model, adapter_model_name) 53 | 54 | model = model.merge_and_unload() 55 | model.save_pretrained("merged_adapters") 56 | ``` 57 | 58 | Once you have the model loaded and either merged the adapters or keep them separately on top you can run generation as with a normal model outlined above. 59 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # Examples 2 | 3 | Please check out https://huggingface.co/docs/trl/example_overview for documentation on our examples. 
-------------------------------------------------------------------------------- /examples/accelerate_configs/deepspeed_zero1.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_multinode_launcher: standard 5 | gradient_accumulation_steps: 1 6 | zero3_init_flag: false 7 | zero_stage: 1 8 | distributed_type: DEEPSPEED 9 | downcast_bf16: 'no' 10 | machine_rank: 0 11 | main_training_function: main 12 | mixed_precision: 'bf16' 13 | num_machines: 1 14 | num_processes: 8 15 | rdzv_backend: static 16 | same_network: true 17 | tpu_env: [] 18 | tpu_use_cluster: false 19 | tpu_use_sudo: false 20 | use_cpu: false 21 | -------------------------------------------------------------------------------- /examples/accelerate_configs/deepspeed_zero2.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_multinode_launcher: standard 5 | offload_optimizer_device: none 6 | offload_param_device: none 7 | zero3_init_flag: false 8 | zero_stage: 2 9 | distributed_type: DEEPSPEED 10 | downcast_bf16: 'no' 11 | machine_rank: 0 12 | main_training_function: main 13 | mixed_precision: 'bf16' 14 | num_machines: 1 15 | num_processes: 8 16 | rdzv_backend: static 17 | same_network: true 18 | tpu_env: [] 19 | tpu_use_cluster: false 20 | tpu_use_sudo: false 21 | use_cpu: false 22 | -------------------------------------------------------------------------------- /examples/accelerate_configs/deepspeed_zero3.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_multinode_launcher: standard 5 | offload_optimizer_device: none 6 | offload_param_device: none 7 | zero3_init_flag: true 8 | zero3_save_16bit_model: true 9 | zero_stage: 3 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: 'no' 12 | machine_rank: 0 13 | main_training_function: main 14 | mixed_precision: bf16 15 | num_machines: 1 16 | num_processes: 8 17 | rdzv_backend: static 18 | same_network: true 19 | tpu_env: [] 20 | tpu_use_cluster: false 21 | tpu_use_sudo: false 22 | use_cpu: false 23 | -------------------------------------------------------------------------------- /examples/accelerate_configs/fsdp1.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: FSDP 4 | downcast_bf16: 'no' 5 | enable_cpu_affinity: false 6 | fsdp_config: 7 | fsdp_activation_checkpointing: false 8 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP 9 | fsdp_backward_prefetch: BACKWARD_PRE 10 | fsdp_cpu_ram_efficient_loading: true 11 | fsdp_forward_prefetch: true 12 | fsdp_offload_params: false 13 | fsdp_reshard_after_forward: FULL_SHARD 14 | fsdp_state_dict_type: FULL_STATE_DICT 15 | fsdp_sync_module_states: true 16 | fsdp_use_orig_params: true 17 | fsdp_version: 1 18 | machine_rank: 0 19 | main_training_function: main 20 | mixed_precision: bf16 21 | num_machines: 1 22 | num_processes: 8 23 | rdzv_backend: static 24 | same_network: true 25 | tpu_env: [] 26 | tpu_use_cluster: false 27 | tpu_use_sudo: false 28 | use_cpu: false 29 | -------------------------------------------------------------------------------- /examples/accelerate_configs/fsdp2.yaml: 
-------------------------------------------------------------------------------- 1 | # Requires accelerate 1.7.0 or higher 2 | compute_environment: LOCAL_MACHINE 3 | debug: false 4 | distributed_type: FSDP 5 | downcast_bf16: 'no' 6 | enable_cpu_affinity: false 7 | fsdp_config: 8 | fsdp_activation_checkpointing: false 9 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP 10 | fsdp_cpu_ram_efficient_loading: true 11 | fsdp_offload_params: false 12 | fsdp_reshard_after_forward: true 13 | fsdp_state_dict_type: FULL_STATE_DICT 14 | fsdp_version: 2 15 | machine_rank: 0 16 | main_training_function: main 17 | mixed_precision: bf16 18 | num_machines: 1 19 | num_processes: 8 20 | rdzv_backend: static 21 | same_network: true 22 | tpu_env: [] 23 | tpu_use_cluster: false 24 | tpu_use_sudo: false 25 | use_cpu: false 26 | -------------------------------------------------------------------------------- /examples/accelerate_configs/multi_gpu.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: 'no' 5 | gpu_ids: all 6 | machine_rank: 0 7 | main_training_function: main 8 | mixed_precision: 'bf16' 9 | num_machines: 1 10 | num_processes: 8 11 | rdzv_backend: static 12 | same_network: true 13 | tpu_env: [] 14 | tpu_use_cluster: false 15 | tpu_use_sudo: false 16 | use_cpu: false 17 | -------------------------------------------------------------------------------- /examples/accelerate_configs/single_gpu.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: "NO" 4 | downcast_bf16: 'no' 5 | gpu_ids: all 6 | machine_rank: 0 7 | main_training_function: main 8 | mixed_precision: 'bf16' 9 | num_machines: 1 10 | num_processes: 8 11 | rdzv_backend: static 12 | same_network: true 13 | tpu_env: [] 14 | tpu_use_cluster: false 15 | tpu_use_sudo: false 16 | use_cpu: false 17 | -------------------------------------------------------------------------------- /examples/cli_configs/example_config.yaml: -------------------------------------------------------------------------------- 1 | # This is an example configuration file of TRL CLI, you can use it for 2 | # SFT like that: `trl sft --config config.yaml --output_dir test-sft` 3 | # The YAML file supports environment variables by adding an `env` field 4 | # as below 5 | 6 | # env: 7 | # CUDA_VISIBLE_DEVICES: 0 8 | 9 | model_name_or_path: 10 | Qwen/Qwen2.5-0.5B 11 | dataset_name: 12 | stanfordnlp/imdb 13 | report_to: 14 | none 15 | learning_rate: 16 | 0.0001 17 | lr_scheduler_type: 18 | cosine 19 | -------------------------------------------------------------------------------- /examples/datasets/tldr.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from dataclasses import dataclass, field 16 | from typing import Optional 17 | 18 | from datasets import load_dataset 19 | from huggingface_hub import ModelCard 20 | from transformers import HfArgumentParser 21 | 22 | 23 | @dataclass 24 | class ScriptArguments: 25 | r""" 26 | Arguments for the script. 27 | 28 | Args: 29 | push_to_hub (`bool`, *optional*, defaults to `False`): 30 | Whether to push the dataset to the Hugging Face Hub. 31 | repo_id (`str`, *optional*, defaults to `"trl-lib/tldr"`): 32 | Hugging Face repository ID to push the dataset to. 33 | dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): 34 | Number of workers to use for dataset processing. 35 | """ 36 | 37 | push_to_hub: bool = field( 38 | default=False, 39 | metadata={"help": "Whether to push the dataset to the Hugging Face Hub."}, 40 | ) 41 | repo_id: str = field( 42 | default="trl-lib/tldr", 43 | metadata={"help": "Hugging Face repository ID to push the dataset to."}, 44 | ) 45 | dataset_num_proc: Optional[int] = field( 46 | default=None, 47 | metadata={"help": "Number of workers to use for dataset processing."}, 48 | ) 49 | 50 | 51 | def to_prompt_completion(example): 52 | tldr_format_str = "SUBREDDIT: r/{subreddit}\n\nTITLE: {title}\n\nPOST: {post}\n\nTL;DR:" 53 | prompt = tldr_format_str.format(subreddit=example["subreddit"], title=example["title"], post=example["post"]) 54 | completion = " " + example["summary"] # Add a space to separate the prompt from the completion 55 | return {"prompt": prompt, "completion": completion} 56 | 57 | 58 | model_card = ModelCard(""" 59 | --- 60 | tags: [trl] 61 | --- 62 | 63 | # TL;DR Dataset 64 | 65 | ## Summary 66 | 67 | The TL;DR dataset is a processed version of Reddit posts, specifically curated to train models using the [TRL library](https://github.com/huggingface/trl) for summarization tasks. It leverages the common practice on Reddit where users append "TL;DR" (Too Long; Didn't Read) summaries to lengthy posts, providing a rich source of paired text data for training summarization models. 68 | 69 | ## Data Structure 70 | 71 | - **Format**: [Standard](https://huggingface.co/docs/trl/main/dataset_formats#standard) 72 | - **Type**: [Prompt-completion](https://huggingface.co/docs/trl/main/dataset_formats#prompt-completion) 73 | 74 | Columns: 75 | - `"prompt"`: The unabridged Reddit post. 76 | - `"completion"`: The concise "TL;DR" summary appended by the author. 77 | 78 | This structure enables models to learn the relationship between detailed content and its abbreviated form, enhancing their summarization capabilities. 79 | 80 | ## Generation script 81 | 82 | The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/tldr.py). 
83 | """) 84 | 85 | if __name__ == "__main__": 86 | parser = HfArgumentParser(ScriptArguments) 87 | script_args = parser.parse_args_into_dataclasses()[0] 88 | 89 | # Filtered reddit TL;DR dataset from https://github.com/openai/summarize-from-feedback?tab=readme-ov-file#reddit-tldr-dataset 90 | data_files = { 91 | "train": "https://openaipublic.blob.core.windows.net/summarize-from-feedback/datasets/tldr_3_filtered/train.jsonl", 92 | "validation": "https://openaipublic.blob.core.windows.net/summarize-from-feedback/datasets/tldr_3_filtered/valid.jsonl", 93 | "test": "https://openaipublic.blob.core.windows.net/summarize-from-feedback/datasets/tldr_3_filtered/test.jsonl", 94 | } 95 | dataset = load_dataset("json", data_files=data_files) 96 | 97 | dataset = dataset.map( 98 | to_prompt_completion, 99 | num_proc=script_args.dataset_num_proc, 100 | remove_columns=["id", "subreddit", "title", "post", "summary"], 101 | ) 102 | 103 | if script_args.push_to_hub: 104 | dataset.push_to_hub(script_args.repo_id) 105 | model_card.push_to_hub(script_args.repo_id, repo_type="dataset") 106 | -------------------------------------------------------------------------------- /examples/datasets/ultrafeedback-prompt.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataclasses import dataclass, field 16 | from typing import Optional 17 | 18 | from datasets import load_dataset 19 | from huggingface_hub import ModelCard 20 | from transformers import HfArgumentParser 21 | 22 | 23 | @dataclass 24 | class ScriptArguments: 25 | r""" 26 | Arguments for the script. 27 | 28 | Args: 29 | push_to_hub (`bool`, *optional*, defaults to `False`): 30 | Whether to push the dataset to the Hugging Face Hub. 31 | repo_id (`str`, *optional*, defaults to `"trl-lib/ultrafeedback-prompt"`): 32 | Hugging Face repository ID to push the dataset to. 33 | dataset_num_proc (`int` or `None`, *optional*, defaults to `None`): 34 | Number of workers to use for dataset processing. 
35 | """ 36 | 37 | push_to_hub: bool = field( 38 | default=False, 39 | metadata={"help": "Whether to push the dataset to the Hugging Face Hub."}, 40 | ) 41 | repo_id: str = field( 42 | default="trl-lib/ultrafeedback-prompt", 43 | metadata={"help": "Hugging Face repository ID to push the dataset to."}, 44 | ) 45 | dataset_num_proc: Optional[int] = field( 46 | default=None, 47 | metadata={"help": "Number of workers to use for dataset processing."}, 48 | ) 49 | 50 | 51 | def to_unpaired_preference(example): 52 | prompt = [{"role": "user", "content": example["instruction"]}] 53 | return {"prompt": prompt} 54 | 55 | 56 | def drop_long_prompt(example): 57 | if len(example["prompt"][0]["content"]) > 512: 58 | return False 59 | else: 60 | return True 61 | 62 | 63 | model_card = ModelCard(""" 64 | --- 65 | tags: [trl] 66 | --- 67 | 68 | # UltraFeedback - Prompts Dataset 69 | 70 | ## Summary 71 | 72 | The UltraFeedback - Prompts dataset is a processed version of the [UltraFeedback](https://huggingface.co/datasets/openbmb/UltraFeedback) dataset for model evaluation on specific aspects like helpfulness, honesty, and instruction-following. 73 | 74 | ## Data Structure 75 | 76 | - **Format**: [Conversational](https://huggingface.co/docs/trl/main/dataset_formats#conversational) 77 | - **Type**: [Prompt-only](https://huggingface.co/docs/trl/main/dataset_formats#prompt-only) 78 | 79 | Column: 80 | - `"prompt"`: The input question or instruction provided to the model. 81 | 82 | ## Generation script 83 | 84 | The script used to generate this dataset can be found [here](https://github.com/huggingface/trl/blob/main/examples/datasets/ultrafeedback-prompt.py). 85 | """) 86 | 87 | if __name__ == "__main__": 88 | parser = HfArgumentParser(ScriptArguments) 89 | script_args = parser.parse_args_into_dataclasses()[0] 90 | 91 | dataset = load_dataset("openbmb/UltraFeedback", split="train") 92 | 93 | dataset = dataset.map( 94 | to_unpaired_preference, 95 | remove_columns=["source", "instruction", "models", "completions", "correct_answers", "incorrect_answers"], 96 | num_proc=script_args.dataset_num_proc, 97 | ) 98 | dataset = dataset.filter(drop_long_prompt) 99 | dataset = dataset.train_test_split(test_size=0.05, seed=42) 100 | 101 | if script_args.push_to_hub: 102 | dataset.push_to_hub(script_args.repo_id) 103 | model_card.push_to_hub(script_args.repo_id, repo_type="dataset") 104 | -------------------------------------------------------------------------------- /examples/notebooks/README.md: -------------------------------------------------------------------------------- 1 | # Notebooks 2 | 3 | This directory contains a collection of Jupyter notebooks that demonstrate how to use the TRL library in different applications. 4 | 5 | - [`best_of_n.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/best_of_n.ipynb): This notebook demonstrates how to use the "Best of N" sampling strategy using TRL when fine-tuning your model with PPO. 6 | - [`gpt2-sentiment.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment.ipynb): This notebook demonstrates how to reproduce the GPT2 imdb sentiment tuning example on a jupyter notebook. 7 | - [`gpt2-sentiment-control.ipynb`](https://github.com/huggingface/trl/tree/main/examples/notebooks/gpt2-sentiment-control.ipynb): This notebook demonstrates how to reproduce the GPT2 sentiment control example on a jupyter notebook. 
8 | -------------------------------------------------------------------------------- /examples/research_projects/README.md: -------------------------------------------------------------------------------- 1 | # Research projects that use TRL 2 | 3 | Welcome to the research projects folder! Here you can find the scripts used for some research projects that used TRL and maintained by the developers and the community (LM de-toxification, Stack-Llama, etc.). Check out the READMEs in the subfolders for more information! 4 | 5 | - [De-detoxifying language models](https://github.com/huggingface/trl/tree/main/examples/research_projects/toxicity) 6 | - [Stack-Llama](https://github.com/huggingface/trl/tree/main/examples/research_projects/stack_llama) 7 | - [Stack-Llama-2](https://github.com/huggingface/trl/tree/main/examples/research_projects/stack_llama_2) -------------------------------------------------------------------------------- /examples/research_projects/layer_skip/README.md: -------------------------------------------------------------------------------- 1 | # LayerSkip Training Recipe 2 | 3 | Implements the training recipe as described in the [LayerSkip paper](https://huggingface.co/papers/2404.16710). 4 | 5 | ## Run training 6 | ``` 7 | cd scripts 8 | python layer_skip_sft.py 9 | ``` 10 | 11 | ## Run benchmark 12 | ``` 13 | cd scripts 14 | python benchmark_layer_skip.py 15 | ``` -------------------------------------------------------------------------------- /examples/research_projects/layer_skip/scripts/benchmark_layer_skip.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import config 16 | import torch 17 | from torch.utils import benchmark 18 | from transformers import AutoModelForCausalLM, AutoTokenizer 19 | 20 | 21 | def generate_tokens(model, inputs): 22 | outputs = model.generate( 23 | **inputs, 24 | do_sample=False, 25 | max_new_tokens=64, 26 | ) 27 | return outputs 28 | 29 | 30 | def generate_tokens_with_assistance(model, inputs, assistant_early_exit): 31 | outputs = model.generate( 32 | **inputs, 33 | assistant_early_exit=assistant_early_exit, 34 | do_sample=False, 35 | max_new_tokens=64, 36 | ) 37 | return outputs 38 | 39 | 40 | if __name__ == "__main__": 41 | ckpt = config.hub_model_id 42 | 43 | model = AutoModelForCausalLM.from_pretrained(ckpt, device_map="auto", torch_dtype=torch.bfloat16) 44 | tokenizer = AutoTokenizer.from_pretrained(ckpt) 45 | 46 | prompt = "### Instruction: What are my alarms for the rest of the day?\n ### Response: " 47 | 48 | results = [] 49 | label = "Generation Times" 50 | inputs = tokenizer(prompt, return_tensors="pt").to(model.device) 51 | 52 | results.append( 53 | benchmark.Timer( 54 | stmt="generate_tokens(model, inputs)", 55 | setup="from __main__ import generate_tokens", 56 | globals={"model": model, "inputs": inputs}, 57 | num_threads=torch.get_num_threads(), 58 | label=label, 59 | sub_label="no layer skip", 60 | description="generation", 61 | ).blocked_autorange() 62 | ) 63 | 64 | for i in range(1, model.config.num_hidden_layers): 65 | results.append( 66 | benchmark.Timer( 67 | stmt="generate_tokens_with_assistance(model, inputs, assistant_early_exit)", 68 | setup="from __main__ import generate_tokens_with_assistance", 69 | globals={"model": model, "assistant_early_exit": i, "inputs": inputs}, 70 | num_threads=torch.get_num_threads(), 71 | label=label, 72 | sub_label=f"layer skip {i}", 73 | description="generation", 74 | ).blocked_autorange() 75 | ) 76 | 77 | benchmark.Compare(results).print() 78 | -------------------------------------------------------------------------------- /examples/research_projects/layer_skip/scripts/config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | from huggingface_hub import whoami 16 | 17 | 18 | model_name = "unsloth/Llama-3.2-3B" 19 | tokenizer_name = "unsloth/Llama-3.2-3B" 20 | dataset_name = "WillHeld/top_v2" 21 | 22 | output_root_dir = "./checkpoints/" 23 | hub_model_id = f"{whoami()['name']}/layerskip-{model_name.split('/')[1]}-{dataset_name.split('/')[1]}" 24 | output_dir = f"{output_root_dir}/{hub_model_id}" 25 | 26 | per_device_train_batch_size = 8 27 | gradient_accumulation_steps = 1 28 | learning_rate = 2e-5 29 | -------------------------------------------------------------------------------- /examples/research_projects/layer_skip/scripts/custom_trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from trl import SFTTrainer 16 | 17 | 18 | class LayerSkipSFTTrainer(SFTTrainer): 19 | def __init__(self, *args, **kwargs): 20 | super().__init__(*args, **kwargs) 21 | self.early_exit_layer = 0 # initialize with 0 22 | self.always_last_layer = True 23 | self.early_exit_loss_scale = 1.0 24 | 25 | def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): 26 | self.early_exit_layer = ( 27 | self.early_exit_layer % (model.config.num_hidden_layers - 1) 28 | ) + 1 # rotates between [1, num_hidden_layers-1] 29 | bs, seqlen = inputs.input_ids.shape 30 | 31 | labels = inputs.pop("labels") 32 | outputs = model(**inputs, output_hidden_states=True) 33 | 34 | hidden_state = outputs["hidden_states"][self.early_exit_layer].to(model.dtype) 35 | if self.early_exit_layer != model.config.num_hidden_layers: 36 | hidden_state = model.model.norm(hidden_state) 37 | logits = model.lm_head(hidden_state) 38 | loss_early = model.loss_function(logits=logits, labels=labels, vocab_size=model.vocab_size) 39 | 40 | if self.always_last_layer: 41 | loss_last = model.loss_function(logits=outputs["logits"], labels=labels, vocab_size=model.vocab_size) 42 | loss = self.early_exit_loss_scale * loss_early.to(loss_last.device) + 1.0 * loss_last 43 | # normalize loss scales 44 | loss = loss / (1.0 + self.early_exit_loss_scale) 45 | else: 46 | loss = loss_early 47 | 48 | return loss 49 | -------------------------------------------------------------------------------- /examples/research_projects/layer_skip/scripts/layer_skip_sft.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import config 16 | import torch 17 | from custom_trainer import LayerSkipSFTTrainer 18 | from datasets import load_dataset 19 | from transformers import AutoModelForCausalLM, AutoTokenizer 20 | 21 | from trl import DataCollatorForCompletionOnlyLM, SFTConfig 22 | 23 | 24 | def formatting_prompts_func(example): 25 | text = f"### Instruction: {example['utterance']}\n ### Response: {example['semantic_parse']}" 26 | 27 | # Inject eos_token as a string before tokenization, because they are not always added 28 | # See: https://github.com/huggingface/transformers/issues/22794 and 29 | # https://github.com/huggingface/trl/issues/1623 30 | if tokenizer.eos_token: # usually something like "" for GPT2 or "<|endoftext|>" 31 | text += f"{tokenizer.eos_token}" 32 | 33 | return text 34 | 35 | 36 | if __name__ == "__main__": 37 | # load the dataset 38 | print("[INFO] loading the dataset...") 39 | train_dataset = load_dataset(config.dataset_name, split="train") 40 | 41 | print(f"output_root_dir: {config.output_root_dir}") 42 | print(f"hub_model_id: {config.hub_model_id}") 43 | 44 | # load the model and tokenizer 45 | print("[INFO] loading the model and tokenizer...") 46 | model = AutoModelForCausalLM.from_pretrained(config.model_name, device_map="auto", torch_dtype=torch.bfloat16) 47 | tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name, add_eos_token=True) 48 | 49 | # adding pad and eos tokens if not provided in the tokenizer 50 | if tokenizer.pad_token is None: 51 | # Add '[PAD]' token if it doesn't exist 52 | tokenizer.add_special_tokens({"pad_token": "[PAD]"}) 53 | model.resize_token_embeddings(len(tokenizer)) 54 | model.config.pad_token_id = tokenizer.pad_token_id 55 | 56 | if tokenizer.eos_token is None or tokenizer.eos_token == tokenizer.bos_token: 57 | # Add '[EOS]' token if it doesn't exist 58 | tokenizer.add_special_tokens({"eos_token": "[EOS]"}) 59 | model.resize_token_embeddings(len(tokenizer)) 60 | model.config.eos_token_id = tokenizer.eos_token_id 61 | 62 | response_template = " ### Response:" 63 | collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer) 64 | 65 | args = SFTConfig( 66 | do_train=True, 67 | bf16=True, 68 | max_seq_length=None, 69 | per_device_train_batch_size=config.per_device_train_batch_size, 70 | gradient_accumulation_steps=config.gradient_accumulation_steps, 71 | learning_rate=config.learning_rate, 72 | packing=False, 73 | num_train_epochs=1.0, 74 | report_to="none", 75 | push_to_hub=True, 76 | hub_model_id=config.hub_model_id, 77 | output_dir=config.output_dir, 78 | logging_steps=500, 79 | save_steps=1000, 80 | save_total_limit=2, 81 | ) 82 | 83 | trainer = LayerSkipSFTTrainer( 84 | model, 85 | train_dataset=train_dataset, 86 | args=args, 87 | formatting_func=formatting_prompts_func, 88 | data_collator=collator, 89 | ) 90 | 91 | trainer.train() 92 | -------------------------------------------------------------------------------- /examples/research_projects/stack_llama/scripts/README.md: -------------------------------------------------------------------------------- 1 | # RLHF pipeline 
for the creation of StackLLaMa: a Stack exchange llama-7b model. 2 | There were three main steps to the training process: 3 | 1. Supervised fine-tuning of the base llama-7b model to create llama-7b-se: 4 | - `torchrun --nnodes 1 --nproc_per_node 8 examples/research_projects/stack_llama/scripts/supervised_finetuning.py --model_path= --streaming --learning_rate 1e-5 --max_steps 5000 --output_dir ./llama-se` 5 | 2. Reward modeling using dialog pairs from the SE dataset, starting from llama-7b-se, to create llama-7b-se-rm: 6 | - `torchrun --nnodes 1 --nproc_per_node 8 examples/research_projects/stack_llama/scripts/reward_modeling.py --model_name=` 7 | 3. RL fine-tuning of llama-7b-se with the llama-7b-se-rm reward model: 8 | - `accelerate launch --multi_gpu --num_machines 1 --num_processes 8 examples/research_projects/stack_llama/scripts/rl_training.py --log_with=wandb --model_name= --reward_model_name= --adafactor=False --tokenizer_name= --save_freq=100 --output_max_length=128 --batch_size=8 --gradient_accumulation_steps=8 --batched_gen=True --ppo_epochs=4 --seed=0 --learning_rate=1.4e-5 --early_stopping=True --output_dir=llama-se-rl-finetune-128-8-8-1.4e-5_adam` 9 | 10 | 11 | LoRA layers were used at all stages to reduce memory requirements. 12 | At each stage the PEFT adapter layers were merged with the base model, using: 13 | ```shell 14 | python examples/research_projects/stack_llama/scripts/merge_peft_adapter.py --adapter_model_name=XXX --base_model_name=YYY --output_name=ZZZ 15 | ``` 16 | Note that this script requires `peft>=0.3.0`. 17 | 18 | For access to the base llama-7b model, please see Meta's [release](https://ai.facebook.com/blog/large-language-model-llama-meta-ai/) and [request form](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform). 19 | -------------------------------------------------------------------------------- /examples/research_projects/stack_llama/scripts/merge_peft_adapter.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataclasses import dataclass, field 16 | from typing import Optional 17 | 18 | import torch 19 | from peft import PeftConfig, PeftModel 20 | from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer, HfArgumentParser 21 | 22 | 23 | @dataclass 24 | class ScriptArguments: 25 | """ 26 | The input names representing the Adapter and Base model fine-tuned with PEFT, and the output name representing the 27 | merged model.
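Example invocation (XXX, YYY, and ZZZ are placeholders for the adapter, base model, and output names): `python examples/research_projects/stack_llama/scripts/merge_peft_adapter.py --adapter_model_name=XXX --base_model_name=YYY --output_name=ZZZ`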
28 | """ 29 | 30 | adapter_model_name: Optional[str] = field(default=None, metadata={"help": "the adapter name"}) 31 | base_model_name: Optional[str] = field(default=None, metadata={"help": "the base model name"}) 32 | output_name: Optional[str] = field(default=None, metadata={"help": "the merged model name"}) 33 | 34 | 35 | parser = HfArgumentParser(ScriptArguments) 36 | script_args = parser.parse_args_into_dataclasses()[0] 37 | assert script_args.adapter_model_name is not None, "please provide the name of the Adapter you would like to merge" 38 | assert script_args.base_model_name is not None, "please provide the name of the Base model" 39 | assert script_args.output_name is not None, "please provide the output name of the merged model" 40 | 41 | peft_config = PeftConfig.from_pretrained(script_args.adapter_model_name) 42 | if peft_config.task_type == "SEQ_CLS": 43 | # The sequence classification task is used for the reward model in PPO 44 | model = AutoModelForSequenceClassification.from_pretrained( 45 | script_args.base_model_name, num_labels=1, torch_dtype=torch.bfloat16 46 | ) 47 | else: 48 | model = AutoModelForCausalLM.from_pretrained( 49 | script_args.base_model_name, return_dict=True, torch_dtype=torch.bfloat16 50 | ) 51 | 52 | tokenizer = AutoTokenizer.from_pretrained(script_args.base_model_name) 53 | 54 | # Load the PEFT model 55 | model = PeftModel.from_pretrained(model, script_args.adapter_model_name) 56 | model.eval() 57 | 58 | model = model.merge_and_unload() 59 | 60 | model.save_pretrained(f"{script_args.output_name}") 61 | tokenizer.save_pretrained(f"{script_args.output_name}") 62 | model.push_to_hub(f"{script_args.output_name}", use_temp_dir=False) 63 | -------------------------------------------------------------------------------- /examples/research_projects/stack_llama_2/scripts/README.md: -------------------------------------------------------------------------------- 1 | # DPO pipeline for the creation of StackLlaMa 2: a Stack exchange llama-v2-7b model 2 | 3 | ## Prerequisites 4 | 5 | Install all the dependencies in the `requirements.txt`: 6 | 7 | ``` 8 | $ pip install -U -r requirements.txt 9 | ``` 10 | 11 | Since we will use `accelerate` for training, make sure to run: 12 | ``` 13 | $ accelerate config 14 | ``` 15 | 16 | ## Training 17 | 18 | There were two main steps to the DPO training process: 19 | 1. Supervised fine-tuning of the base llama-v2-7b model to create llama-v2-7b-se: 20 | 21 | ``` 22 | accelerate launch examples/research_projects/stack_llama_2/scripts/sft_llama2.py \ 23 | --output_dir="./sft" \ 24 | --max_steps=500 \ 25 | --logging_steps=10 \ 26 | --save_steps=10 \ 27 | --per_device_train_batch_size=4 \ 28 | --per_device_eval_batch_size=1 \ 29 | --gradient_accumulation_steps=2 \ 30 | --gradient_checkpointing=False \ 31 | --group_by_length=False \ 32 | --learning_rate=1e-4 \ 33 | --lr_scheduler_type="cosine" \ 34 | --warmup_steps=100 \ 35 | --weight_decay=0.05 \ 36 | --optim="paged_adamw_32bit" \ 37 | --bf16=True \ 38 | --remove_unused_columns=False \ 39 | --run_name="sft_llama2" \ 40 | --report_to="wandb" 41 | ``` 42 | 1. 
Run the DPO trainer using the model saved by the previous step: 43 | ``` 44 | accelerate launch examples/research_projects/stack_llama_2/scripts/dpo_llama2.py \ 45 | --model_name_or_path="sft/final_checkpoint" \ 46 | --output_dir="dpo" 47 | ``` 48 | 49 | 50 | ## Merging the adapters 51 | 52 | To merge the adapters into the base model, we can use the `merge_peft_adapter.py` helper script that comes with TRL: 53 | 54 | ``` 55 | python examples/research_projects/stack_llama/scripts/merge_peft_adapter.py --base_model_name="meta-llama/Llama-2-7b-hf" --adapter_model_name="dpo/final_checkpoint/" --output_name="stack-llama-2" 56 | ``` 57 | 58 | which will also push the merged model to your Hugging Face Hub account. 59 | 60 | ## Running the model 61 | 62 | We can load the DPO-trained LoRA adapters saved by the DPO training step via: 63 | 64 | ```py 65 | import torch 66 | from peft import AutoPeftModelForCausalLM 67 | 68 | model = AutoPeftModelForCausalLM.from_pretrained( 69 | "dpo/final_checkpoint", 70 | low_cpu_mem_usage=True, 71 | torch_dtype=torch.float16, 72 | load_in_4bit=True, 73 | ) 74 | 75 | model.generate(...) 76 | ``` 77 | -------------------------------------------------------------------------------- /examples/research_projects/stack_llama_2/scripts/requirements.txt: -------------------------------------------------------------------------------- 1 | transformers 2 | trl 3 | peft 4 | accelerate 5 | datasets 6 | bitsandbytes 7 | wandb 8 | -------------------------------------------------------------------------------- /examples/research_projects/toxicity/README.md: -------------------------------------------------------------------------------- 1 | # Detoxifying language models 2 | 3 | To run this code, do the following: 4 | 5 | ```shell 6 | ACCELERATE_LOG_LEVEL=info accelerate launch --config_file {CONFIG} examples/research_projects/toxicity/scripts/gpt-j-6b-toxicity.py --log_with wandb 7 | ``` 8 | -------------------------------------------------------------------------------- /examples/scripts/cpo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Run the CPO training script with the following command with some example arguments.
17 | In general, the optimal configuration for CPO will be similar to that of DPO: 18 | 19 | # regular: 20 | python examples/scripts/cpo.py \ 21 | --dataset_name trl-lib/ultrafeedback_binarized \ 22 | --model_name_or_path=gpt2 \ 23 | --per_device_train_batch_size 4 \ 24 | --max_steps 1000 \ 25 | --learning_rate 8e-6 \ 26 | --gradient_accumulation_steps 1 \ 27 | --logging_steps 10 \ 28 | --eval_steps 500 \ 29 | --output_dir="gpt2-aligned-cpo" \ 30 | --warmup_steps 150 \ 31 | --report_to wandb \ 32 | --bf16 \ 33 | --logging_first_step \ 34 | --no_remove_unused_columns 35 | 36 | # peft: 37 | python examples/scripts/cpo.py \ 38 | --dataset_name trl-lib/ultrafeedback_binarized \ 39 | --model_name_or_path=gpt2 \ 40 | --per_device_train_batch_size 4 \ 41 | --max_steps 1000 \ 42 | --learning_rate 8e-5 \ 43 | --gradient_accumulation_steps 1 \ 44 | --logging_steps 10 \ 45 | --eval_steps 500 \ 46 | --output_dir="gpt2-lora-aligned-cpo" \ 47 | --optim rmsprop \ 48 | --warmup_steps 150 \ 49 | --report_to wandb \ 50 | --bf16 \ 51 | --logging_first_step \ 52 | --no_remove_unused_columns \ 53 | --use_peft \ 54 | --lora_r=16 \ 55 | --lora_alpha=16 56 | """ 57 | 58 | from datasets import load_dataset 59 | from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser 60 | 61 | from trl import CPOConfig, CPOTrainer, ModelConfig, ScriptArguments, get_peft_config 62 | from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE 63 | 64 | 65 | if __name__ == "__main__": 66 | parser = HfArgumentParser((ScriptArguments, CPOConfig, ModelConfig)) 67 | script_args, training_args, model_args = parser.parse_args_into_dataclasses() 68 | 69 | ################ 70 | # Model & Tokenizer 71 | ################ 72 | model = AutoModelForCausalLM.from_pretrained( 73 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code 74 | ) 75 | tokenizer = AutoTokenizer.from_pretrained( 76 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code 77 | ) 78 | if tokenizer.pad_token is None: 79 | tokenizer.pad_token = tokenizer.eos_token 80 | 81 | ################ 82 | # Dataset 83 | ################ 84 | dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) 85 | if tokenizer.chat_template is None: 86 | tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE 87 | 88 | ################ 89 | # Training 90 | ################ 91 | trainer = CPOTrainer( 92 | model, 93 | args=training_args, 94 | train_dataset=dataset[script_args.dataset_train_split], 95 | eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, 96 | processing_class=tokenizer, 97 | peft_config=get_peft_config(model_args), 98 | ) 99 | 100 | # train and save the model 101 | trainer.train() 102 | 103 | # Save and push to hub 104 | trainer.save_model(training_args.output_dir) 105 | if training_args.push_to_hub: 106 | trainer.push_to_hub(dataset_name=script_args.dataset_name) 107 | -------------------------------------------------------------------------------- /examples/scripts/dpo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | ############################################################################################### 16 | # This file has been moved to https://github.com/huggingface/trl/blob/main/trl/scripts/dpo.py # 17 | ############################################################################################### 18 | -------------------------------------------------------------------------------- /examples/scripts/evals/judge_tldr.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataclasses import dataclass, field 16 | from typing import Optional 17 | 18 | from datasets import load_dataset 19 | from transformers import HfArgumentParser 20 | from vllm import LLM, SamplingParams 21 | 22 | from trl import HfPairwiseJudge, OpenAIPairwiseJudge 23 | 24 | 25 | """ 26 | Examples: 27 | 28 | python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/rloo_tldr --num_examples 1000 29 | Model win rate: 31.40% 30 | 31 | python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/rloo_tldr --judge_model gpt-3.5-turbo-0125 --num_examples 1000 32 | Model win rate: 51.60% 33 | 34 | python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/rloo_tldr --judge_model gpt-4o-mini --num_examples 1000 35 | Model win rate: 51.20% 36 | 37 | python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/ppo_tldr --num_examples 1000 38 | Model win rate: 46.30% 39 | 40 | python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/ppo_tldr --judge_model gpt-3.5-turbo-0125 --num_examples 1000 41 | Model win rate: 52.50% 42 | 43 | python examples/scripts/evals/judge_tldr.py --model_name_or_path vwxyzjn/ppo_tldr --judge_model gpt-4o-mini --num_examples 1000 44 | Model win rate: 63.00% 45 | """ 46 | 47 | 48 | @dataclass 49 | class ScriptArguments: 50 | r""" 51 | Arguments for the script. 52 | 53 | Args: 54 | model_name_or_path (`str`): 55 | Model name or path to the model to evaluate. 56 | judge_model (`str`, *optional*, defaults to `"meta-llama/Meta-Llama-3-70B-Instruct"`): 57 | Model name or path to the model to use as a judge. E.g., 'gpt-3.5-turbo-0125' or 58 | 'meta-llama/Meta-Llama-3-70B-Instruct'. 59 | num_examples (`int` or `None`, *optional*, defaults to `None`): 60 | Number of examples to evaluate. 
61 | """ 62 | 63 | model_name_or_path: str = field(metadata={"help": "Model name or path to the model to evaluate."}) 64 | judge_model: str = field( 65 | default="meta-llama/Meta-Llama-3-70B-Instruct", 66 | metadata={ 67 | "help": "Model name or path to the model to use as a judge. E.g., 'gpt-3.5-turbo-0125' or " 68 | "'meta-llama/Meta-Llama-3-70B-Instruct'." 69 | }, 70 | ) 71 | num_examples: Optional[int] = field(default=None, metadata={"help": "Number of examples to evaluate."}) 72 | 73 | 74 | # Parse the arguments 75 | parser = HfArgumentParser(ScriptArguments) 76 | script_args = parser.parse_args_into_dataclasses()[0] 77 | 78 | # Load the dataset 79 | dataset = load_dataset("trl-lib/tldr", split="validation") 80 | if script_args.num_examples is not None: 81 | dataset = dataset.select(range(script_args.num_examples)) 82 | 83 | # Extract the prompts and reference completions 84 | prompts = dataset["prompt"] 85 | reference_completions = dataset["completion"] 86 | 87 | # Generate the model completions 88 | sampling_params = SamplingParams(temperature=0.0, top_p=0.95, max_tokens=200) # very generous max token length 89 | llm = LLM(model=script_args.model_name_or_path, tensor_parallel_size=1) 90 | outputs = llm.generate(prompts, sampling_params) 91 | model_completions = [output.outputs[0].text.strip() for output in outputs] 92 | 93 | # Judge the outputs 94 | if "gpt" in script_args.judge_model: 95 | judge = OpenAIPairwiseJudge(script_args.judge_model) 96 | else: 97 | judge = HfPairwiseJudge(script_args.judge_model) 98 | 99 | completions = [[c0, c1] for c0, c1 in zip(reference_completions, model_completions)] 100 | best_idxs = judge.judge(prompts, completions) 101 | model_win_rate = best_idxs.count(1) / len(best_idxs) 102 | print(f"Model win rate: {model_win_rate * 100:.2f}%") 103 | -------------------------------------------------------------------------------- /examples/scripts/kto.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Run the KTO training script with the commands below. In general, the optimal configuration for KTO will be similar to that of DPO. 
17 | 18 | # Full training: 19 | python trl/scripts/kto.py \ 20 | --dataset_name trl-lib/kto-mix-14k \ 21 | --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \ 22 | --per_device_train_batch_size 16 \ 23 | --num_train_epochs 1 \ 24 | --learning_rate 5e-7 \ 25 | --lr_scheduler_type=cosine \ 26 | --gradient_accumulation_steps 1 \ 27 | --logging_steps 10 \ 28 | --eval_steps 500 \ 29 | --output_dir=kto-aligned-model \ 30 | --warmup_ratio 0.1 \ 31 | --report_to wandb \ 32 | --bf16 \ 33 | --logging_first_step 34 | 35 | # QLoRA: 36 | python trl/scripts/kto.py \ 37 | --dataset_name trl-lib/kto-mix-14k \ 38 | --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \ 39 | --per_device_train_batch_size 8 \ 40 | --num_train_epochs 1 \ 41 | --learning_rate 5e-7 \ 42 | --lr_scheduler_type=cosine \ 43 | --gradient_accumulation_steps 1 \ 44 | --logging_steps 10 \ 45 | --eval_steps 500 \ 46 | --output_dir=kto-aligned-model-lora \ 47 | --warmup_ratio 0.1 \ 48 | --report_to wandb \ 49 | --bf16 \ 50 | --logging_first_step \ 51 | --use_peft \ 52 | --load_in_4bit \ 53 | --lora_target_modules=all-linear \ 54 | --lora_r=16 \ 55 | --lora_alpha=16 56 | """ 57 | 58 | from datasets import load_dataset 59 | from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser 60 | 61 | from trl import ( 62 | KTOConfig, 63 | KTOTrainer, 64 | ModelConfig, 65 | ScriptArguments, 66 | get_peft_config, 67 | setup_chat_format, 68 | ) 69 | 70 | 71 | if __name__ == "__main__": 72 | parser = HfArgumentParser((ScriptArguments, KTOConfig, ModelConfig)) 73 | script_args, training_args, model_args = parser.parse_args_into_dataclasses() 74 | 75 | # Load a pretrained model 76 | model = AutoModelForCausalLM.from_pretrained( 77 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code 78 | ) 79 | ref_model = AutoModelForCausalLM.from_pretrained( 80 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code 81 | ) 82 | 83 | tokenizer = AutoTokenizer.from_pretrained( 84 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code 85 | ) 86 | if tokenizer.pad_token is None: 87 | tokenizer.pad_token = tokenizer.eos_token 88 | 89 | # If we are aligning a base model, we use ChatML as the default template 90 | if tokenizer.chat_template is None: 91 | model, tokenizer = setup_chat_format(model, tokenizer) 92 | 93 | # Load the dataset 94 | dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) 95 | 96 | # Initialize the KTO trainer 97 | trainer = KTOTrainer( 98 | model, 99 | ref_model, 100 | args=training_args, 101 | train_dataset=dataset[script_args.dataset_train_split], 102 | eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, 103 | processing_class=tokenizer, 104 | peft_config=get_peft_config(model_args), 105 | ) 106 | 107 | # Train and push the model to the Hub 108 | trainer.train() 109 | 110 | # Save and push to hub 111 | trainer.save_model(training_args.output_dir) 112 | if training_args.push_to_hub: 113 | trainer.push_to_hub(dataset_name=script_args.dataset_name) 114 | -------------------------------------------------------------------------------- /examples/scripts/orpo.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Run the ORPO training script with the following command with some example arguments. 17 | In general, the optimal configuration for ORPO will be similar to that of DPO without the need for a reference model: 18 | 19 | # regular: 20 | python examples/scripts/orpo.py \ 21 | --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style \ 22 | --model_name_or_path=gpt2 \ 23 | --per_device_train_batch_size 4 \ 24 | --max_steps 1000 \ 25 | --learning_rate 8e-6 \ 26 | --gradient_accumulation_steps 1 \ 27 | --logging_steps 10 \ 28 | --eval_steps 500 \ 29 | --output_dir="gpt2-aligned-orpo" \ 30 | --warmup_steps 150 \ 31 | --report_to wandb \ 32 | --bf16 \ 33 | --logging_first_step \ 34 | --no_remove_unused_columns 35 | 36 | # peft: 37 | python examples/scripts/orpo.py \ 38 | --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style \ 39 | --model_name_or_path=gpt2 \ 40 | --per_device_train_batch_size 4 \ 41 | --max_steps 1000 \ 42 | --learning_rate 8e-5 \ 43 | --gradient_accumulation_steps 1 \ 44 | --logging_steps 10 \ 45 | --eval_steps 500 \ 46 | --output_dir="gpt2-lora-aligned-orpo" \ 47 | --optim rmsprop \ 48 | --warmup_steps 150 \ 49 | --report_to wandb \ 50 | --bf16 \ 51 | --logging_first_step \ 52 | --no_remove_unused_columns \ 53 | --use_peft \ 54 | --lora_r=16 \ 55 | --lora_alpha=16 56 | """ 57 | 58 | from datasets import load_dataset 59 | from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser 60 | 61 | from trl import ModelConfig, ORPOConfig, ORPOTrainer, ScriptArguments, get_peft_config 62 | from trl.trainer.utils import SIMPLE_CHAT_TEMPLATE 63 | 64 | 65 | if __name__ == "__main__": 66 | parser = HfArgumentParser((ScriptArguments, ORPOConfig, ModelConfig)) 67 | script_args, training_args, model_args = parser.parse_args_into_dataclasses() 68 | 69 | ################ 70 | # Model & Tokenizer 71 | ################ 72 | model = AutoModelForCausalLM.from_pretrained( 73 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code 74 | ) 75 | tokenizer = AutoTokenizer.from_pretrained( 76 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code 77 | ) 78 | if tokenizer.pad_token is None: 79 | tokenizer.pad_token = tokenizer.eos_token 80 | 81 | ################ 82 | # Dataset 83 | ################ 84 | dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) 85 | if tokenizer.chat_template is None: 86 | tokenizer.chat_template = SIMPLE_CHAT_TEMPLATE 87 | 88 | ################ 89 | # Training 90 | ################ 91 | trainer = ORPOTrainer( 92 | model, 93 | args=training_args, 94 | train_dataset=dataset[script_args.dataset_train_split], 95 | eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, 96 | processing_class=tokenizer, 97 | peft_config=get_peft_config(model_args), 98 | ) 99 | 100 | # train and save the model 101 | trainer.train() 102 | 103 | # Save and push to hub 104 | trainer.save_model(training_args.output_dir) 105 | if training_args.push_to_hub: 106 | 
trainer.push_to_hub(dataset_name=script_args.dataset_name) 107 | -------------------------------------------------------------------------------- /examples/scripts/sft.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | ############################################################################################### 16 | # This file has been moved to https://github.com/huggingface/trl/blob/main/trl/scripts/sft.py # 17 | ############################################################################################### 18 | -------------------------------------------------------------------------------- /examples/scripts/sft_gemma3.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Train Gemma-3 on the Codeforces COTS dataset. 
17 | 18 | accelerate launch --config_file examples/accelerate_configs/deepspeed_zero3.yaml examples/scripts/sft_gemma3.py 19 | """ 20 | 21 | from datasets import load_dataset 22 | from transformers import AutoModelForImageTextToText 23 | 24 | from trl import SFTConfig, SFTTrainer 25 | 26 | 27 | def main(): 28 | # Load dataset 29 | train_dataset = load_dataset("open-r1/codeforces-cots", split="train") 30 | train_dataset = train_dataset.remove_columns("prompt") 31 | 32 | # Load model 33 | model_id = "google/gemma-3-12b-it" 34 | model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation="eager") 35 | 36 | # Train model 37 | training_args = SFTConfig( 38 | output_dir=f"{model_id}-codeforces-SFT", 39 | logging_steps=10, 40 | bf16=True, 41 | use_liger_kernel=True, 42 | gradient_checkpointing=True, 43 | gradient_checkpointing_kwargs={"use_reentrant": False}, 44 | max_length=8192, 45 | per_device_train_batch_size=1, 46 | gradient_accumulation_steps=8, 47 | dataset_num_proc=32, 48 | num_train_epochs=1, 49 | ) 50 | trainer = SFTTrainer( 51 | args=training_args, 52 | model=model, 53 | train_dataset=train_dataset, 54 | ) 55 | trainer.train() 56 | 57 | # Push to hub 58 | trainer.push_to_hub(dataset_name="open-r1/codeforces-cots") 59 | 60 | 61 | if __name__ == "__main__": 62 | main() 63 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.ruff] 2 | target-version = "py39" 3 | line-length = 119 4 | 5 | [tool.ruff.lint] 6 | ignore = [ 7 | "B028", # warning without explicit stacklevel 8 | "C408", # dict() calls (stylistic) 9 | "C901", # function complexity 10 | "E501", 11 | ] 12 | extend-select = ["E", "F", "I", "W", "UP", "B", "T", "C"] 13 | 14 | [tool.ruff.lint.per-file-ignores] 15 | # Allow prints in auxiliary scripts 16 | "examples/**.py" = ["T201"] 17 | "scripts/**.py" = ["T201"] 18 | # Ignore import violations in all `__init__.py` files. 19 | "__init__.py" = ["F401"] 20 | 21 | [tool.ruff.lint.isort] 22 | lines-after-imports = 2 23 | known-first-party = ["trl"] 24 | 25 | [tool.pytest.ini_options] 26 | markers = [ 27 | "slow: marks tests as slow (deselect with '-m \"not slow\"')", 28 | "low-priority: marks tests as low priority (deselect with '-m \"not low-priority\"')" 29 | ] 30 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate 2 | datasets 3 | rich 4 | transformers>=4.46.0 -------------------------------------------------------------------------------- /scripts/add_copyrights.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
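# Repository hygiene helper: walks every git-tracked Python file, checks whether it starts
# with the Apache-2.0 header above, and prepends the header when it is missing. The script
# exits with status 1 if any file had to be modified, so a clean run means all files already
# carry the notice.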
14 | 15 | import os 16 | import subprocess 17 | import sys 18 | from datetime import datetime 19 | 20 | 21 | COPYRIGHT_HEADER = f"""# Copyright 2020-{datetime.now().year} The HuggingFace Team. All rights reserved. 22 | # 23 | # Licensed under the Apache License, Version 2.0 (the "License"); 24 | # you may not use this file except in compliance with the License. 25 | # You may obtain a copy of the License at 26 | # 27 | # http://www.apache.org/licenses/LICENSE-2.0 28 | # 29 | # Unless required by applicable law or agreed to in writing, software 30 | # distributed under the License is distributed on an "AS IS" BASIS, 31 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 32 | # See the License for the specific language governing permissions and 33 | # limitations under the License. 34 | """ 35 | 36 | 37 | def get_tracked_python_files(): 38 | """Get a list of all tracked Python files using git.""" 39 | try: 40 | # Get the list of all tracked files from Git 41 | result = subprocess.run(["git", "ls-files"], stdout=subprocess.PIPE, text=True, check=True) 42 | # Split the result by lines to get individual file paths 43 | files = result.stdout.splitlines() 44 | # Filter only Python files 45 | py_files = [f for f in files if f.endswith(".py")] 46 | return py_files 47 | except subprocess.CalledProcessError as e: 48 | print(f"Error fetching tracked files: {e}") 49 | return [] 50 | 51 | 52 | def check_and_add_copyright(file_path): 53 | """Check if the file contains a copyright notice, and add it if missing.""" 54 | if not os.path.isfile(file_path): 55 | print(f"[SKIP] {file_path} does not exist.") 56 | return 57 | 58 | with open(file_path, encoding="utf-8") as f: 59 | content = f.readlines() 60 | 61 | # Check if the exact copyright header exists 62 | if "".join(content).startswith(COPYRIGHT_HEADER): 63 | return True 64 | 65 | # If no copyright notice was found, prepend the header 66 | print(f"[MODIFY] Adding copyright to {file_path}.") 67 | with open(file_path, "w", encoding="utf-8") as f: 68 | # Write the copyright header followed by the original content 69 | f.write(COPYRIGHT_HEADER + "\n" + "".join(content)) 70 | return False 71 | 72 | 73 | def main(): 74 | """Main function to check and add copyright for all tracked Python files.""" 75 | py_files = get_tracked_python_files() 76 | if not py_files: 77 | print("No Python files are tracked in the repository.") 78 | return 79 | 80 | print(f"Checking {len(py_files)} Python files for copyright notice...") 81 | 82 | have_copyright = [check_and_add_copyright(file_path) for file_path in py_files] 83 | if not all(have_copyright): 84 | print("❌ Some files were missing the required copyright and have been updated.") 85 | sys.exit(1) 86 | else: 87 | print("✅ All files have the required copyright.") 88 | sys.exit(0) 89 | 90 | 91 | if __name__ == "__main__": 92 | main() 93 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = trl 3 | version = 0.19.0 4 | description = Train transformer language models with reinforcement learning. 
5 | long_description = file: README.md 6 | long_description_content_type = text/markdown 7 | author = Leandro von Werra 8 | author_email = leandro.vonwerra@gmail.com 9 | url = https://github.com/huggingface/trl 10 | keywords = transformers, huggingface, language modeling, post-training, rlhf, sft, dpo, grpo 11 | license_file = LICENSE 12 | classifiers = 13 | Development Status :: 2 - Pre-Alpha 14 | Intended Audience :: Developers 15 | Intended Audience :: Science/Research 16 | Natural Language :: English 17 | Operating System :: OS Independent 18 | Programming Language :: Python :: 3 19 | Programming Language :: Python :: 3.9 20 | Programming Language :: Python :: 3.10 21 | Programming Language :: Python :: 3.11 22 | Programming Language :: Python :: 3.12 23 | Programming Language :: Python :: 3.13 24 | 25 | [options] 26 | packages = find: 27 | python_requires = >=3.9 28 | include_package_data = True 29 | install_requires = 30 | accelerate>=0.34.0 31 | datasets>=3.0.0 32 | transformers>=4.50.0 33 | 34 | [options.packages.find] 35 | exclude = 36 | tests* 37 | 38 | [options.package_data] 39 | trl = 40 | templates/*.md 41 | accelerate_configs/*.yaml 42 | 43 | [options.extras_require] 44 | bco = 45 | scikit-learn 46 | joblib 47 | deepspeed = 48 | deepspeed>=0.14.4 49 | diffusers = 50 | diffusers>=0.18.0 51 | judges = 52 | openai>=1.23.2 53 | llm-blender>=0.0.2 54 | liger = 55 | liger-kernel>=0.5.9 56 | mergekit = 57 | mergekit>=0.0.5.1 58 | peft = 59 | peft>=0.8.0 60 | quantization = 61 | bitsandbytes 62 | scikit = 63 | scikit-learn 64 | test = 65 | parameterized 66 | pytest-cov 67 | pytest-rerunfailures 68 | pytest-xdist 69 | pytest 70 | vllm = 71 | # vLLM package does not yet support Python 3.13. These constraints can be lifted once support is added: 72 | # see https://github.com/vllm-project/vllm/pull/13164 73 | vllm>=0.8.3; python_version < "3.13" 74 | fastapi; python_version < "3.13" 75 | pydantic; python_version < "3.13" 76 | requests; python_version < "3.13" 77 | uvicorn; python_version < "3.13" 78 | 79 | vlm = 80 | Pillow 81 | dev = 82 | %(bco)s 83 | %(deepspeed)s 84 | %(diffusers)s 85 | %(judges)s 86 | %(liger)s 87 | %(mergekit)s 88 | %(peft)s 89 | %(quantization)s 90 | %(scikit)s 91 | %(test)s 92 | %(vlm)s 93 | 94 | [options.entry_points] 95 | console_scripts = 96 | trl = trl.cli:main 97 | 98 | [coverage:run] 99 | branch = True 100 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from setuptools import setup 16 | 17 | 18 | setup() 19 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/slow/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /tests/slow/testing_constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | MODELS_TO_TEST = [ 16 | "trl-internal-testing/tiny-LlamaForCausalLM-3.2", 17 | "trl-internal-testing/tiny-MistralForCausalLM-0.2", 18 | ] 19 | 20 | # We could have also not declared these variables but let's be verbose 21 | PACKING_OPTIONS = [True, False] 22 | GRADIENT_CHECKPOINTING_KWARGS = [None, {"use_reentrant": False}, {"use_reentrant": True}] 23 | DEVICE_MAP_OPTIONS = [{"": 0}, "auto"] 24 | 25 | DPO_LOSS_TYPES = ["sigmoid", "ipo"] 26 | DPO_PRECOMPUTE_LOGITS = [True, False] 27 | -------------------------------------------------------------------------------- /tests/test_alignprop_trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import gc 16 | import unittest 17 | 18 | import pytest 19 | import torch 20 | from parameterized import parameterized 21 | from transformers.utils import is_peft_available 22 | 23 | from trl.import_utils import is_diffusers_available 24 | 25 | from .testing_utils import require_diffusers 26 | 27 | 28 | if is_diffusers_available() and is_peft_available(): 29 | from trl import AlignPropConfig, AlignPropTrainer, DefaultDDPOStableDiffusionPipeline 30 | 31 | 32 | def scorer_function(images, prompts, metadata): 33 | return torch.randn(1) * 3.0, {} 34 | 35 | 36 | def prompt_function(): 37 | return ("cabbages", {}) 38 | 39 | 40 | @pytest.mark.low_priority 41 | @require_diffusers 42 | class AlignPropTrainerTester(unittest.TestCase): 43 | """ 44 | Test the AlignPropTrainer class. 45 | """ 46 | 47 | def setUp(self): 48 | training_args = AlignPropConfig( 49 | num_epochs=2, 50 | train_gradient_accumulation_steps=1, 51 | train_batch_size=2, 52 | truncated_backprop_rand=False, 53 | mixed_precision=None, 54 | save_freq=1000000, 55 | ) 56 | pretrained_model = "hf-internal-testing/tiny-stable-diffusion-torch" 57 | pretrained_revision = "main" 58 | pipeline_with_lora = DefaultDDPOStableDiffusionPipeline( 59 | pretrained_model, pretrained_model_revision=pretrained_revision, use_lora=True 60 | ) 61 | pipeline_without_lora = DefaultDDPOStableDiffusionPipeline( 62 | pretrained_model, pretrained_model_revision=pretrained_revision, use_lora=False 63 | ) 64 | self.trainer_with_lora = AlignPropTrainer(training_args, scorer_function, prompt_function, pipeline_with_lora) 65 | self.trainer_without_lora = AlignPropTrainer( 66 | training_args, scorer_function, prompt_function, pipeline_without_lora 67 | ) 68 | 69 | def tearDown(self) -> None: 70 | gc.collect() 71 | 72 | @parameterized.expand([True, False]) 73 | def test_generate_samples(self, use_lora): 74 | trainer = self.trainer_with_lora if use_lora else self.trainer_without_lora 75 | output_pairs = trainer._generate_samples(2, with_grad=True) 76 | self.assertEqual(len(output_pairs.keys()), 3) 77 | self.assertEqual(len(output_pairs["images"]), 2) 78 | 79 | @parameterized.expand([True, False]) 80 | def test_calculate_loss(self, use_lora): 81 | trainer = self.trainer_with_lora if use_lora else self.trainer_without_lora 82 | sample = trainer._generate_samples(2) 83 | 84 | images = sample["images"] 85 | prompts = sample["prompts"] 86 | 87 | self.assertTupleEqual(images.shape, (2, 3, 128, 128)) 88 | self.assertEqual(len(prompts), 2) 89 | 90 | rewards = trainer.compute_rewards(sample) 91 | loss = trainer.calculate_loss(rewards) 92 | 93 | self.assertTrue(torch.isfinite(loss.cpu())) 94 | -------------------------------------------------------------------------------- /tests/test_best_of_n_sampler.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import unittest 16 | 17 | import torch 18 | from transformers import AutoTokenizer, GenerationConfig 19 | 20 | from trl import AutoModelForCausalLMWithValueHead 21 | from trl.core import LengthSampler 22 | from trl.extras import BestOfNSampler 23 | 24 | 25 | def queries_to_scores(list_of_strings): 26 | return [torch.rand(1).item() for _ in list_of_strings] 27 | 28 | 29 | class BestOfNSamplerTester(unittest.TestCase): 30 | """ 31 | Tests the BestOfNSampler class 32 | """ 33 | 34 | ref_model_name = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" 35 | output_length_sampler = LengthSampler(2, 6) 36 | model = AutoModelForCausalLMWithValueHead.from_pretrained(ref_model_name) 37 | tokenizer = AutoTokenizer.from_pretrained(ref_model_name) 38 | tokenizer.pad_token = tokenizer.eos_token 39 | output_length_sampler = LengthSampler(2, 6) 40 | 41 | def test_different_input_types(self): 42 | r""" 43 | Tests if the different input types normalizer works 44 | """ 45 | 46 | generation_config = GenerationConfig( 47 | min_length=-1, 48 | top_k=0.0, 49 | top_p=1.0, 50 | do_sample=True, 51 | pad_token_id=self.tokenizer.eos_token_id, 52 | ) 53 | 54 | output_length_sampler = LengthSampler(2, 6) 55 | 56 | best_of_n = BestOfNSampler( 57 | self.model, 58 | self.tokenizer, 59 | queries_to_scores, 60 | length_sampler=output_length_sampler, 61 | generation_config=generation_config, 62 | ) 63 | 64 | queries = ["hello world", "goodbye world"] 65 | tokenized_queries = [self.tokenizer.encode(query) for query in queries] 66 | 67 | various_queries_formats = [ 68 | (tokenized_queries[0], 1), 69 | (tokenized_queries, 2), 70 | (torch.tensor(tokenized_queries[1]), 1), 71 | ([torch.tensor(query) for query in tokenized_queries], 2), 72 | ] 73 | 74 | for q, expected_length in various_queries_formats: 75 | results = best_of_n.generate(q) 76 | self.assertIsInstance(results, list) 77 | self.assertEqual(len(results), expected_length) 78 | 79 | def test_different_sample_sizes_and_n_candidates_values(self): 80 | r""" 81 | Tests different sample sizes and n_candidates values 82 | """ 83 | generation_config = GenerationConfig( 84 | min_length=-1, 85 | top_k=0.0, 86 | top_p=1.0, 87 | do_sample=True, 88 | pad_token_id=self.tokenizer.eos_token_id, 89 | ) 90 | 91 | output_length_sampler = LengthSampler(6, 10) 92 | 93 | for sample_value, n_candidates_values, expected in [ 94 | (4, 2, 2), 95 | (10, 3, 3), 96 | (6, 4, 4), 97 | ]: 98 | best_of_n = BestOfNSampler( 99 | self.model, 100 | self.tokenizer, 101 | queries_to_scores, 102 | length_sampler=output_length_sampler, 103 | generation_config=generation_config, 104 | sample_size=sample_value, 105 | n_candidates=n_candidates_values, 106 | ) 107 | 108 | queries = ["hello world", "troll the world"] 109 | tokenized_queries = [self.tokenizer.encode(query) for query in queries] 110 | results = best_of_n.generate(tokenized_queries) 111 | for result in results: 112 | self.assertEqual(len(result), expected) 113 | -------------------------------------------------------------------------------- /tests/test_collators.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import unittest 16 | 17 | import torch 18 | 19 | from trl.trainer.dpo_trainer import DataCollatorForPreference 20 | 21 | 22 | class TestDataCollatorForPreference(unittest.TestCase): 23 | def setUp(self): 24 | self.collator = DataCollatorForPreference(pad_token_id=0) 25 | 26 | def assertTensorEqual(self, tensor1, tensor2): 27 | self.assertTrue(torch.equal(tensor1, tensor2), f"Tensors are not equal:\n{tensor1}\n{tensor2}") 28 | 29 | def test_padding_behavior(self): 30 | examples = [ 31 | {"prompt_input_ids": [1, 2, 3], "chosen_input_ids": [4, 5], "rejected_input_ids": [6]}, 32 | {"prompt_input_ids": [7, 8], "chosen_input_ids": [9, 10], "rejected_input_ids": [11, 12, 13]}, 33 | ] 34 | output = self.collator.torch_call(examples) 35 | 36 | expected_prompt_input_ids = torch.tensor([[1, 2, 3], [0, 7, 8]]) 37 | expected_prompt_attention_mask = torch.tensor([[1, 1, 1], [0, 1, 1]]) 38 | expected_chosen_input_ids = torch.tensor([[4, 5], [9, 10]]) 39 | expected_chosen_attention_mask = torch.tensor([[1, 1], [1, 1]]) 40 | expected_rejected_input_ids = torch.tensor([[6, 0, 0], [11, 12, 13]]) 41 | expected_rejected_attention_mask = torch.tensor([[1, 0, 0], [1, 1, 1]]) 42 | 43 | self.assertTensorEqual(output["prompt_input_ids"], expected_prompt_input_ids) 44 | self.assertTensorEqual(output["prompt_attention_mask"], expected_prompt_attention_mask) 45 | self.assertTensorEqual(output["chosen_input_ids"], expected_chosen_input_ids) 46 | self.assertTensorEqual(output["chosen_attention_mask"], expected_chosen_attention_mask) 47 | self.assertTensorEqual(output["rejected_input_ids"], expected_rejected_input_ids) 48 | self.assertTensorEqual(output["rejected_attention_mask"], expected_rejected_attention_mask) 49 | 50 | def test_optional_fields(self): 51 | examples = [ 52 | { 53 | "prompt_input_ids": [1], 54 | "chosen_input_ids": [2], 55 | "rejected_input_ids": [3], 56 | "pixel_values": [[[0.1, 0.2], [0.3, 0.4]]], # Example 3D tensor (1x2x2) 57 | }, 58 | { 59 | "prompt_input_ids": [4], 60 | "chosen_input_ids": [5], 61 | "rejected_input_ids": [6], 62 | "pixel_values": [[[0.5, 0.6], [0.7, 0.8]]], # Example 3D tensor (1x2x2) 63 | }, 64 | ] 65 | output = self.collator.torch_call(examples) 66 | 67 | expected_pixel_values = torch.tensor( 68 | [ 69 | [[[0.1, 0.2], [0.3, 0.4]]], 70 | [[[0.5, 0.6], [0.7, 0.8]]], 71 | ] 72 | ) # Shape: (2, 1, 2, 2) 73 | 74 | self.assertTensorEqual(output["pixel_values"], expected_pixel_values) 75 | -------------------------------------------------------------------------------- /tests/test_core.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import unittest 16 | 17 | import torch 18 | 19 | from trl.core import masked_mean, masked_var, masked_whiten 20 | 21 | 22 | class CoreTester(unittest.TestCase): 23 | """ 24 | A wrapper class for testing core utils functions 25 | """ 26 | 27 | def setUp(self): 28 | self.test_input = torch.Tensor([1, 2, 3, 4]) 29 | self.test_mask = torch.Tensor([0, 1, 1, 0]) 30 | self.test_input_unmasked = self.test_input[1:3] 31 | 32 | def test_masked_mean(self): 33 | self.assertEqual(torch.mean(self.test_input_unmasked), masked_mean(self.test_input, self.test_mask)) 34 | 35 | def test_masked_var(self): 36 | self.assertEqual(torch.var(self.test_input_unmasked), masked_var(self.test_input, self.test_mask)) 37 | 38 | def test_masked_whiten(self): 39 | def whiten(values: torch.Tensor) -> torch.Tensor: 40 | mean, var = torch.mean(values), torch.var(values) 41 | return (values - mean) * torch.rsqrt(var + 1e-8) 42 | 43 | whiten_unmasked = whiten(self.test_input_unmasked) 44 | whiten_masked = masked_whiten(self.test_input, self.test_mask)[1:3] 45 | diffs = (whiten_unmasked - whiten_masked).sum() 46 | self.assertLess(abs(diffs.item()), 0.00001) 47 | -------------------------------------------------------------------------------- /tests/test_judges.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
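# Tests for TRL's judge classes: AllTrueJudge aggregation over binary judges,
# HfPairwiseJudge (skipped by default since it needs a valid Hugging Face API key), and
# PairRMJudge (requires llm-blender and is loaded with retries to tolerate concurrent downloads).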
14 | 15 | import time 16 | import unittest 17 | 18 | from trl import AllTrueJudge, HfPairwiseJudge, PairRMJudge 19 | 20 | from .testing_utils import RandomBinaryJudge, require_llm_blender 21 | 22 | 23 | class TestJudges(unittest.TestCase): 24 | def _get_prompts_and_pairwise_completions(self): 25 | prompts = ["The capital of France is", "The biggest planet in the solar system is"] 26 | completions = [["Paris", "Marseille"], ["Saturn", "Jupiter"]] 27 | return prompts, completions 28 | 29 | def _get_prompts_and_single_completions(self): 30 | prompts = ["What's the capital of France?", "What's the color of the sky?"] 31 | completions = ["Marseille", "blue"] 32 | return prompts, completions 33 | 34 | def test_all_true_judge(self): 35 | judge = AllTrueJudge(judges=[RandomBinaryJudge(), RandomBinaryJudge()]) 36 | prompts, completions = self._get_prompts_and_single_completions() 37 | judgements = judge.judge(prompts=prompts, completions=completions) 38 | self.assertEqual(len(judgements), 2) 39 | self.assertTrue(all(judgement in {0, 1, -1} for judgement in judgements)) 40 | 41 | @unittest.skip("This test needs to be run manually since it requires a valid Hugging Face API key.") 42 | def test_hugging_face_judge(self): 43 | judge = HfPairwiseJudge() 44 | prompts, completions = self._get_prompts_and_pairwise_completions() 45 | ranks = judge.judge(prompts=prompts, completions=completions) 46 | self.assertEqual(len(ranks), 2) 47 | self.assertTrue(all(isinstance(rank, int) for rank in ranks)) 48 | self.assertEqual(ranks, [0, 1]) 49 | 50 | def load_pair_rm_judge(self): 51 | # When using concurrent tests, PairRM may fail to load the model while another job is still downloading. 52 | # This is a workaround to retry loading the model a few times. 53 | for _ in range(5): 54 | try: 55 | return PairRMJudge() 56 | except ValueError: 57 | time.sleep(5) 58 | raise ValueError("Failed to load PairRMJudge") 59 | 60 | @require_llm_blender 61 | def test_pair_rm_judge(self): 62 | judge = self.load_pair_rm_judge() 63 | prompts, completions = self._get_prompts_and_pairwise_completions() 64 | ranks = judge.judge(prompts=prompts, completions=completions) 65 | self.assertEqual(len(ranks), 2) 66 | self.assertTrue(all(isinstance(rank, int) for rank in ranks)) 67 | self.assertEqual(ranks, [0, 1]) 68 | 69 | @require_llm_blender 70 | def test_pair_rm_judge_return_scores(self): 71 | judge = self.load_pair_rm_judge() 72 | prompts, completions = self._get_prompts_and_pairwise_completions() 73 | probs = judge.judge(prompts=prompts, completions=completions, return_scores=True) 74 | self.assertEqual(len(probs), 2) 75 | self.assertTrue(all(isinstance(prob, float) for prob in probs)) 76 | self.assertTrue(all(0 <= prob <= 1 for prob in probs)) 77 | -------------------------------------------------------------------------------- /tests/test_modeling_geometric_mixture_wrapper.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import unittest 16 | 17 | import torch 18 | from transformers import AutoModelForCausalLM, GenerationConfig 19 | 20 | from trl.models.modeling_base import GeometricMixtureWrapper, create_reference_model 21 | 22 | 23 | class TestGeometricMixtureWrapper(unittest.TestCase): 24 | def setUp(self): 25 | model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5" 26 | self.model = AutoModelForCausalLM.from_pretrained(model_id) 27 | self.ref_model = create_reference_model(self.model) 28 | self.generation_config = GenerationConfig.from_pretrained(model_id) 29 | self.mixture_coef = 0.5 30 | self.wrapper = GeometricMixtureWrapper( 31 | self.model, self.ref_model, self.generation_config, mixture_coef=self.mixture_coef 32 | ) 33 | 34 | def test_forward(self): 35 | input_ids = torch.tensor([[1, 2, 3, 4, 5]]) 36 | attention_mask = torch.ones_like(input_ids) 37 | 38 | output = self.wrapper(input_ids=input_ids, attention_mask=attention_mask) 39 | 40 | self.assertIsNotNone(output) 41 | self.assertTrue(hasattr(output, "logits")) 42 | self.assertEqual(output.logits.shape, (1, 5, self.model.config.vocab_size)) 43 | 44 | def test_mixture_coefficient(self): 45 | input_ids = torch.tensor([[1, 2, 3, 4, 5]]) 46 | attention_mask = torch.ones_like(input_ids) 47 | 48 | with torch.no_grad(): 49 | model_output = self.model(input_ids=input_ids, attention_mask=attention_mask) 50 | ref_model_output = self.ref_model(input_ids=input_ids, attention_mask=attention_mask) 51 | wrapper_output = self.wrapper(input_ids=input_ids, attention_mask=attention_mask) 52 | 53 | expected_logits = torch.nn.functional.log_softmax( 54 | self.mixture_coef * ref_model_output.logits + (1 - self.mixture_coef) * model_output.logits, dim=-1 55 | ) 56 | 57 | self.assertTrue(torch.allclose(wrapper_output.logits, expected_logits, atol=1e-5)) 58 | 59 | def test_prepare_inputs_for_generation(self): 60 | input_ids = torch.tensor([[1, 2, 3, 4, 5]]) 61 | attention_mask = torch.ones_like(input_ids) 62 | 63 | inputs = self.wrapper.prepare_inputs_for_generation(input_ids, attention_mask=attention_mask, use_cache=True) 64 | 65 | self.assertIn("input_ids", inputs) 66 | self.assertIn("attention_mask", inputs) 67 | self.assertFalse(inputs.get("use_cache", False)) 68 | -------------------------------------------------------------------------------- /tests/test_rewards.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | import unittest 16 | 17 | from trl.rewards import think_format_reward 18 | 19 | 20 | class ThinkFormatRewardTester(unittest.TestCase): 21 | def test_valid_format(self): 22 | completions = [ 23 | "<think>This is my reasoning.</think>This is my answer.", # Simple, one-line reasoning 24 | "<think>\nThis is my reasoning.\n</think>\nThis is my answer.", # Multiline reasoning 25 | "<think>\nThis is\nmy reasoning.\n</think>\nThis is my answer.", # Multiline reasoning 26 | "<think>\nThis is <some tag>my reasoning.</some tag></think>\nThis is my answer.", # Reasoning including other tags 27 | "<think></think>\nThis is my answer.", # Empty reasoning 28 | ] 29 | completions = [[{"content": completion}] for completion in completions] 30 | expected_rewards = [1.0, 1.0, 1.0, 1.0, 1.0] # All should be valid 31 | rewards = think_format_reward(completions) 32 | self.assertEqual(rewards, expected_rewards) 33 | 34 | def test_invalid_format(self): 35 | completions = [ 36 | "<think>\nThis is my reasoning.\nThis is my answer.", # No closing </think> 37 | "<think>This is my reasoning.\nThis is my answer.", # No closing </think> 38 | "This is my reasoning. This is my answer.", # No <think> tags 39 | "This is my reasoning.\nThis is my answer.", # No <think> tags 40 | "This is my reasoning.</think>\nThis is my answer.", # No opening <think> 41 | "This is my reasoning.</think>This is my answer.", # No opening <think> 42 | "This<think>is my reasoning.</think>\nThis is my answer.", # <think> tag in the middle 43 | "This is<think>my reasoning.</think></think>This is my answer.", # Nested <think> tags 44 | "<think>This is\nmy\nreasoning.\nThis is my answer.", # Multiline, no closing </think> 45 | ] 46 | completions = [[{"content": completion}] for completion in completions] 47 | expected_rewards = [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] # All should be invalid 48 | rewards = think_format_reward(completions) 49 | self.assertEqual(rewards, expected_rewards) 50 | 51 | def test_mixed_format(self): 52 | completions = [ 53 | "<think>This is my reasoning.</think>This is my answer.", # Valid 54 | "<think>\nThis is my reasoning.\n</think>\nThis is my answer.", # Valid 55 | "<think>This is my reasoning.\nThis is my answer.", # Invalid, no closing </think> 56 | "This is my reasoning. This is my answer.", # Invalid, no <think> tags 57 | ] 58 | completions = [[{"content": completion}] for completion in completions] 59 | expected_rewards = [1.0, 1.0, 0.0, 0.0] 60 | rewards = think_format_reward(completions) 61 | self.assertEqual(rewards, expected_rewards) 62 | 63 | 64 | if __name__ == "__main__": 65 | unittest.main() 66 | -------------------------------------------------------------------------------- /tests/test_rich_progress_callback.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License.
14 | 15 | import tempfile 16 | import unittest 17 | 18 | import torch 19 | import torch.nn as nn 20 | from datasets import Dataset 21 | from transformers import Trainer, TrainingArguments 22 | 23 | from trl.trainer.callbacks import RichProgressCallback 24 | 25 | from .testing_utils import require_rich 26 | 27 | 28 | class DummyModel(nn.Module): 29 | def __init__(self): 30 | super().__init__() 31 | self.a = nn.Parameter(torch.tensor(1.0)) 32 | 33 | def forward(self, x): 34 | return self.a * x 35 | 36 | 37 | @require_rich 38 | class TestRichProgressCallback(unittest.TestCase): 39 | def setUp(self): 40 | self.dummy_model = DummyModel() 41 | self.dummy_train_dataset = Dataset.from_list([{"x": 1.0, "y": 2.0}] * 5) 42 | self.dummy_val_dataset = Dataset.from_list([{"x": 1.0, "y": 2.0}] * 101) 43 | 44 | def test_rich_progress_callback_logging(self): 45 | with tempfile.TemporaryDirectory() as tmp_dir: 46 | training_args = TrainingArguments( 47 | output_dir=tmp_dir, 48 | per_device_eval_batch_size=2, 49 | per_device_train_batch_size=2, 50 | num_train_epochs=4, 51 | eval_strategy="steps", 52 | eval_steps=1, 53 | logging_strategy="steps", 54 | logging_steps=1, 55 | save_strategy="no", 56 | report_to="none", 57 | disable_tqdm=True, 58 | ) 59 | callbacks = [RichProgressCallback()] 60 | trainer = Trainer( 61 | model=self.dummy_model, 62 | train_dataset=self.dummy_train_dataset, 63 | eval_dataset=self.dummy_val_dataset, 64 | args=training_args, 65 | callbacks=callbacks, 66 | ) 67 | 68 | trainer.train() 69 | trainer.train() 70 | -------------------------------------------------------------------------------- /tests/testing_constants.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | CI_HUB_USER = "__DUMMY_TRANSFORMERS_USER__" 16 | CI_HUB_USER_FULL_NAME = "Dummy User" 17 | 18 | CI_HUB_ENDPOINT = "https://hub-ci.huggingface.co" 19 | -------------------------------------------------------------------------------- /trl/accelerate_configs/fsdp1.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: FSDP 4 | downcast_bf16: 'no' 5 | enable_cpu_affinity: false 6 | fsdp_config: 7 | fsdp_activation_checkpointing: false 8 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP 9 | fsdp_backward_prefetch: BACKWARD_PRE 10 | fsdp_cpu_ram_efficient_loading: true 11 | fsdp_forward_prefetch: true 12 | fsdp_offload_params: false 13 | fsdp_reshard_after_forward: FULL_SHARD 14 | fsdp_state_dict_type: FULL_STATE_DICT 15 | fsdp_sync_module_states: true 16 | fsdp_use_orig_params: true 17 | fsdp_version: 1 18 | machine_rank: 0 19 | main_training_function: main 20 | mixed_precision: bf16 21 | num_machines: 1 22 | num_processes: 8 23 | rdzv_backend: static 24 | same_network: true 25 | tpu_env: [] 26 | tpu_use_cluster: false 27 | tpu_use_sudo: false 28 | use_cpu: false 29 | -------------------------------------------------------------------------------- /trl/accelerate_configs/fsdp2.yaml: -------------------------------------------------------------------------------- 1 | # Requires accelerate 1.7.0 or higher 2 | compute_environment: LOCAL_MACHINE 3 | debug: false 4 | distributed_type: FSDP 5 | downcast_bf16: 'no' 6 | enable_cpu_affinity: false 7 | fsdp_config: 8 | fsdp_activation_checkpointing: false 9 | fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP 10 | fsdp_cpu_ram_efficient_loading: true 11 | fsdp_offload_params: false 12 | fsdp_reshard_after_forward: true 13 | fsdp_state_dict_type: FULL_STATE_DICT 14 | fsdp_version: 2 15 | machine_rank: 0 16 | main_training_function: main 17 | mixed_precision: bf16 18 | num_machines: 1 19 | num_processes: 8 20 | rdzv_backend: static 21 | same_network: true 22 | tpu_env: [] 23 | tpu_use_cluster: false 24 | tpu_use_sudo: false 25 | use_cpu: false 26 | -------------------------------------------------------------------------------- /trl/accelerate_configs/multi_gpu.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: 'no' 5 | gpu_ids: all 6 | machine_rank: 0 7 | main_training_function: main 8 | mixed_precision: 'bf16' 9 | num_machines: 1 10 | num_processes: 8 11 | rdzv_backend: static 12 | same_network: true 13 | tpu_env: [] 14 | tpu_use_cluster: false 15 | tpu_use_sudo: false 16 | use_cpu: false 17 | -------------------------------------------------------------------------------- /trl/accelerate_configs/single_gpu.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: "NO" 4 | downcast_bf16: 'no' 5 | gpu_ids: all 6 | machine_rank: 0 7 | main_training_function: main 8 | mixed_precision: 'bf16' 9 | num_machines: 1 10 | num_processes: 8 11 | rdzv_backend: static 12 | same_network: true 13 | tpu_env: [] 14 | tpu_use_cluster: false 15 | tpu_use_sudo: false 16 | use_cpu: false 17 | -------------------------------------------------------------------------------- /trl/accelerate_configs/zero1.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: 
LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_multinode_launcher: standard 5 | gradient_accumulation_steps: 1 6 | zero3_init_flag: false 7 | zero_stage: 1 8 | distributed_type: DEEPSPEED 9 | downcast_bf16: 'no' 10 | machine_rank: 0 11 | main_training_function: main 12 | mixed_precision: 'bf16' 13 | num_machines: 1 14 | num_processes: 8 15 | rdzv_backend: static 16 | same_network: true 17 | tpu_env: [] 18 | tpu_use_cluster: false 19 | tpu_use_sudo: false 20 | use_cpu: false 21 | -------------------------------------------------------------------------------- /trl/accelerate_configs/zero2.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_multinode_launcher: standard 5 | offload_optimizer_device: none 6 | offload_param_device: none 7 | zero3_init_flag: false 8 | zero_stage: 2 9 | distributed_type: DEEPSPEED 10 | downcast_bf16: 'no' 11 | machine_rank: 0 12 | main_training_function: main 13 | mixed_precision: 'bf16' 14 | num_machines: 1 15 | num_processes: 8 16 | rdzv_backend: static 17 | same_network: true 18 | tpu_env: [] 19 | tpu_use_cluster: false 20 | tpu_use_sudo: false 21 | use_cpu: false 22 | -------------------------------------------------------------------------------- /trl/accelerate_configs/zero3.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | deepspeed_config: 4 | deepspeed_multinode_launcher: standard 5 | offload_optimizer_device: none 6 | offload_param_device: none 7 | zero3_init_flag: true 8 | zero3_save_16bit_model: true 9 | zero_stage: 3 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: 'no' 12 | machine_rank: 0 13 | main_training_function: main 14 | mixed_precision: bf16 15 | num_machines: 1 16 | num_processes: 8 17 | rdzv_backend: static 18 | same_network: true 19 | tpu_env: [] 20 | tpu_use_cluster: false 21 | tpu_use_sudo: false 22 | use_cpu: false 23 | -------------------------------------------------------------------------------- /trl/environment/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from typing import TYPE_CHECKING 16 | 17 | from ..import_utils import _LazyModule 18 | 19 | 20 | _import_structure = { 21 | "base_environment": ["TextEnvironment", "TextHistory"], 22 | } 23 | 24 | if TYPE_CHECKING: 25 | from .base_environment import TextEnvironment, TextHistory 26 | else: 27 | import sys 28 | 29 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) 30 | -------------------------------------------------------------------------------- /trl/extras/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import TYPE_CHECKING 16 | 17 | from ..import_utils import _LazyModule 18 | 19 | 20 | _import_structure = { 21 | "best_of_n_sampler": ["BestOfNSampler"], 22 | } 23 | 24 | if TYPE_CHECKING: 25 | from .best_of_n_sampler import BestOfNSampler 26 | else: 27 | import sys 28 | 29 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) 30 | -------------------------------------------------------------------------------- /trl/extras/profiling.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import contextlib 16 | import functools 17 | import time 18 | from collections.abc import Generator 19 | 20 | from transformers import Trainer 21 | from transformers.integrations import is_mlflow_available, is_wandb_available 22 | 23 | 24 | if is_wandb_available(): 25 | import wandb 26 | 27 | if is_mlflow_available(): 28 | import mlflow 29 | 30 | 31 | @contextlib.contextmanager 32 | def profiling_context(trainer: Trainer, name: str) -> Generator[None, None, None]: 33 | """ 34 | A context manager function for profiling a block of code. Results are logged to Weights & Biases or MLflow 35 | depending on the trainer's configuration. 36 | 37 | Args: 38 | trainer (`~transformers.Trainer`): 39 | Trainer object. 40 | name (`str`): 41 | Name of the block to be profiled. Used as a key in the logged dictionary. 
42 | 43 | Example: 44 | ```python 45 | from transformers import Trainer 46 | from trl.extras.profiling import profiling_context 47 | 48 | class MyTrainer(Trainer): 49 | def some_method(self): 50 | A = np.random.rand(1000, 1000) 51 | B = np.random.rand(1000, 1000) 52 | with profiling_context(self, "matrix_multiplication"): 53 | # Code to profile: simulate a computationally expensive operation 54 | result = A @ B # Matrix multiplication 55 | ``` 56 | """ 57 | start_time = time.perf_counter() 58 | yield 59 | end_time = time.perf_counter() 60 | duration = end_time - start_time 61 | 62 | profiling_metrics = {f"profiling/Time taken: {trainer.__class__.__name__}.{name}": duration} 63 | if "wandb" in trainer.args.report_to and wandb.run is not None and trainer.accelerator.is_main_process: 64 | wandb.log(profiling_metrics) 65 | 66 | if "mlflow" in trainer.args.report_to and mlflow.run is not None and trainer.accelerator.is_main_process: 67 | mlflow.log_metrics(profiling_metrics, step=trainer.state.global_step) 68 | 69 | 70 | def profiling_decorator(func: callable) -> callable: 71 | """ 72 | Decorator to profile a function and log execution time using [`extras.profiling.profiling_context`]. 73 | 74 | Args: 75 | func (`callable`): 76 | Function to be profiled. 77 | 78 | Example: 79 | ```python 80 | from transformers import Trainer 81 | from trl.extras.profiling import profiling_decorator 82 | 83 | class MyTrainer(Trainer): 84 | @profiling_decorator 85 | def some_method(self): 86 | A = np.random.rand(1000, 1000) 87 | B = np.random.rand(1000, 1000) 88 | # Code to profile: simulate a computationally expensive operation 89 | result = A @ B 90 | ``` 91 | """ 92 | 93 | @functools.wraps(func) 94 | def wrapper(self, *args, **kwargs): 95 | with profiling_context(self, func.__name__): 96 | return func(self, *args, **kwargs) 97 | 98 | return wrapper 99 | -------------------------------------------------------------------------------- /trl/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from typing import TYPE_CHECKING 16 | 17 | from ..import_utils import OptionalDependencyNotAvailable, _LazyModule, is_diffusers_available 18 | 19 | 20 | _import_structure = { 21 | "activation_offloading": ["get_act_offloading_ctx_manager"], 22 | "modeling_base": ["GeometricMixtureWrapper", "PreTrainedModelWrapper", "create_reference_model"], 23 | "modeling_value_head": ["AutoModelForCausalLMWithValueHead", "AutoModelForSeq2SeqLMWithValueHead"], 24 | "utils": [ 25 | "SUPPORTED_ARCHITECTURES", 26 | "prepare_deepspeed", 27 | "prepare_fsdp", 28 | "setup_chat_format", 29 | "unwrap_model_for_generation", 30 | ], 31 | } 32 | 33 | try: 34 | if not is_diffusers_available(): 35 | raise OptionalDependencyNotAvailable() 36 | except OptionalDependencyNotAvailable: 37 | pass 38 | else: 39 | _import_structure["modeling_sd_base"] = [ 40 | "DDPOPipelineOutput", 41 | "DDPOSchedulerOutput", 42 | "DDPOStableDiffusionPipeline", 43 | "DefaultDDPOStableDiffusionPipeline", 44 | ] 45 | 46 | if TYPE_CHECKING: 47 | from .activation_offloading import get_act_offloading_ctx_manager 48 | from .modeling_base import GeometricMixtureWrapper, PreTrainedModelWrapper, create_reference_model 49 | from .modeling_value_head import AutoModelForCausalLMWithValueHead, AutoModelForSeq2SeqLMWithValueHead 50 | from .utils import ( 51 | SUPPORTED_ARCHITECTURES, 52 | prepare_deepspeed, 53 | prepare_fsdp, 54 | setup_chat_format, 55 | unwrap_model_for_generation, 56 | ) 57 | 58 | try: 59 | if not is_diffusers_available(): 60 | raise OptionalDependencyNotAvailable() 61 | except OptionalDependencyNotAvailable: 62 | pass 63 | else: 64 | from .modeling_sd_base import ( 65 | DDPOPipelineOutput, 66 | DDPOSchedulerOutput, 67 | DDPOStableDiffusionPipeline, 68 | DefaultDDPOStableDiffusionPipeline, 69 | ) 70 | else: 71 | import sys 72 | 73 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) 74 | -------------------------------------------------------------------------------- /trl/models/auxiliary_modules.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
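# Illustrative usage sketch: the `aesthetic_scorer` factory defined at the bottom of this
# module returns a reward callable with the signature `(images, prompts, metadata) -> (scores, dict)`,
# suitable for DDPO-style training. The hub id and filename below mirror the DDPO example
# defaults and should be treated as assumptions:
#
#     reward_fn = aesthetic_scorer("trl-lib/ddpo-aesthetic-predictor", "aesthetic-model.pth")
#     scores, _ = reward_fn(images, prompts, {})  # `images`: float tensor in [0, 1], shape (B, C, H, W)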
14 | 15 | import os 16 | 17 | import torch 18 | import torch.nn as nn 19 | import torchvision 20 | from huggingface_hub import hf_hub_download 21 | from huggingface_hub.utils import EntryNotFoundError 22 | from transformers import CLIPModel, is_torch_npu_available, is_torch_xpu_available 23 | 24 | 25 | class MLP(nn.Module): 26 | def __init__(self): 27 | super().__init__() 28 | self.layers = nn.Sequential( 29 | nn.Linear(768, 1024), 30 | nn.Dropout(0.2), 31 | nn.Linear(1024, 128), 32 | nn.Dropout(0.2), 33 | nn.Linear(128, 64), 34 | nn.Dropout(0.1), 35 | nn.Linear(64, 16), 36 | nn.Linear(16, 1), 37 | ) 38 | 39 | def forward(self, embed): 40 | return self.layers(embed) 41 | 42 | 43 | class AestheticScorer(torch.nn.Module): 44 | """ 45 | This model attempts to predict the aesthetic score of an image. The aesthetic score 46 | is a numerical approximation of how much a specific image is liked by humans on average. 47 | This is from https://github.com/christophschuhmann/improved-aesthetic-predictor 48 | """ 49 | 50 | def __init__(self, *, dtype, model_id, model_filename): 51 | super().__init__() 52 | self.clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14") 53 | self.normalize = torchvision.transforms.Normalize( 54 | mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711] 55 | ) 56 | self.target_size = 224 57 | self.mlp = MLP() 58 | try: 59 | cached_path = hf_hub_download(model_id, model_filename) 60 | except EntryNotFoundError: 61 | cached_path = os.path.join(model_id, model_filename) 62 | state_dict = torch.load(cached_path, map_location=torch.device("cpu"), weights_only=True) 63 | self.mlp.load_state_dict(state_dict) 64 | self.dtype = dtype 65 | self.eval() 66 | 67 | def __call__(self, images): 68 | device = next(self.parameters()).device 69 | images = torchvision.transforms.Resize(self.target_size)(images) 70 | images = self.normalize(images).to(self.dtype).to(device) 71 | embed = self.clip.get_image_features(pixel_values=images) 72 | # normalize embedding 73 | embed = embed / torch.linalg.vector_norm(embed, dim=-1, keepdim=True) 74 | reward = self.mlp(embed).squeeze(1) 75 | return reward 76 | 77 | 78 | def aesthetic_scorer(hub_model_id, model_filename): 79 | scorer = AestheticScorer( 80 | model_id=hub_model_id, 81 | model_filename=model_filename, 82 | dtype=torch.float32, 83 | ) 84 | if is_torch_npu_available(): 85 | scorer = scorer.npu() 86 | elif is_torch_xpu_available(): 87 | scorer = scorer.xpu() 88 | else: 89 | scorer = scorer.cuda() 90 | 91 | def _fn(images, prompts, metadata): 92 | images = (images).clamp(0, 1) 93 | scores = scorer(images) 94 | return scores, {} 95 | 96 | return _fn 97 | -------------------------------------------------------------------------------- /trl/rewards/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | import sys 17 | from typing import TYPE_CHECKING 18 | 19 | from ..import_utils import _LazyModule 20 | 21 | 22 | _import_structure = { 23 | "format_rewards": ["think_format_reward"], 24 | } 25 | 26 | 27 | if TYPE_CHECKING: 28 | from .format_rewards import think_format_reward 29 | 30 | 31 | else: 32 | sys.modules[__name__] = _LazyModule(__name__, __file__, _import_structure, module_spec=__spec__) 33 | -------------------------------------------------------------------------------- /trl/rewards/format_rewards.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import re 16 | 17 | 18 | def think_format_reward(completions: list[list[dict[str, str]]], **kwargs) -> list[float]: 19 | r""" 20 | Reward function that checks if the reasoning process is enclosed within `"<think>"` and `"</think>"` tags. The 21 | function returns a reward of 1.0 if the format is correct, otherwise 0.0. 22 | 23 | Args: 24 | completions (`list[list[dict[str, str]]]`): 25 | List of completions to be evaluated. Each completion must be a list of one message, i.e. a dictionary 26 | containing the key `"content"` with the value being the text of the completion. 27 | **kwargs: 28 | Additional keyword arguments. This function does not use them, but they are required in the function 29 | signature to ensure compatibility with trainers like [`GRPOTrainer`]. 30 | 31 | Returns: 32 | `list[float]`: 33 | A list of rewards, where each reward is 1.0 if the completion matches the expected format, otherwise 0.0. 34 | 35 | Example: 36 | ```python 37 | >>> from trl.rewards import think_format_reward 38 | >>> completions = [ 39 | ... [{"content": "<think>\nThis is my reasoning.\n</think>\nThis is my answer."}], 40 | ... [{"content": "<think>\nThis is my reasoning.\nThis is my answer."}], 41 | ... ] 42 | >>> think_format_reward(completions) 43 | [1.0, 0.0] 44 | ``` 45 | """ 46 | pattern = r"^<think>(?!.*<think>)(.*?)</think>.*$" 47 | completion_contents = [completion[0]["content"] for completion in completions] 48 | matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents] 49 | return [1.0 if match else 0.0 for match in matches] 50 | -------------------------------------------------------------------------------- /trl/scripts/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from typing import TYPE_CHECKING 16 | 17 | from ..import_utils import _LazyModule 18 | 19 | 20 | _import_structure = { 21 | "utils": ["init_zero_verbose", "ScriptArguments", "TrlParser"], 22 | } 23 | 24 | if TYPE_CHECKING: 25 | from .utils import ScriptArguments, TrlParser, init_zero_verbose 26 | else: 27 | import sys 28 | 29 | sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__) 30 | -------------------------------------------------------------------------------- /trl/scripts/env.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import os 16 | import platform 17 | from importlib.metadata import version 18 | 19 | import torch 20 | from accelerate.commands.config import default_config_file, load_config_from_file 21 | from transformers import is_bitsandbytes_available 22 | from transformers.utils import is_openai_available, is_peft_available 23 | 24 | from .. import __version__ 25 | from ..import_utils import ( 26 | is_deepspeed_available, 27 | is_diffusers_available, 28 | is_liger_kernel_available, 29 | is_llm_blender_available, 30 | is_vllm_available, 31 | ) 32 | from .utils import get_git_commit_hash 33 | 34 | 35 | def print_env(): 36 | devices = None 37 | if torch.cuda.is_available(): 38 | devices = [torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())] 39 | elif torch.backends.mps.is_available(): 40 | devices = ["MPS"] 41 | elif torch.xpu.is_available(): 42 | devices = [torch.xpu.get_device_name(i) for i in range(torch.xpu.device_count())] 43 | 44 | accelerate_config = accelerate_config_str = "not found" 45 | 46 | # Get the default from the config file. 
47 | if os.path.isfile(default_config_file): 48 | accelerate_config = load_config_from_file(default_config_file).to_dict() 49 | 50 | accelerate_config_str = ( 51 | "\n" + "\n".join([f" - {prop}: {val}" for prop, val in accelerate_config.items()]) 52 | if isinstance(accelerate_config, dict) 53 | else accelerate_config 54 | ) 55 | 56 | commit_hash = get_git_commit_hash("trl") 57 | 58 | info = { 59 | "Platform": platform.platform(), 60 | "Python version": platform.python_version(), 61 | "TRL version": f"{__version__}+{commit_hash[:7]}" if commit_hash else __version__, 62 | "PyTorch version": version("torch"), 63 | "accelerator(s)": ", ".join(devices) if devices is not None else "cpu", 64 | "Transformers version": version("transformers"), 65 | "Accelerate version": version("accelerate"), 66 | "Accelerate config": accelerate_config_str, 67 | "Datasets version": version("datasets"), 68 | "HF Hub version": version("huggingface_hub"), 69 | "bitsandbytes version": version("bitsandbytes") if is_bitsandbytes_available() else "not installed", 70 | "DeepSpeed version": version("deepspeed") if is_deepspeed_available() else "not installed", 71 | "Diffusers version": version("diffusers") if is_diffusers_available() else "not installed", 72 | "Liger-Kernel version": version("liger_kernel") if is_liger_kernel_available() else "not installed", 73 | "LLM-Blender version": version("llm_blender") if is_llm_blender_available() else "not installed", 74 | "OpenAI version": version("openai") if is_openai_available() else "not installed", 75 | "PEFT version": version("peft") if is_peft_available() else "not installed", 76 | "vLLM version": version("vllm") if is_vllm_available() else "not installed", 77 | } 78 | 79 | info_str = "\n".join([f"- {prop}: {val}" for prop, val in info.items()]) 80 | print(f"\nCopy-paste the following information when reporting an issue:\n\n{info_str}\n") # noqa 81 | 82 | 83 | if __name__ == "__main__": 84 | print_env() 85 | -------------------------------------------------------------------------------- /trl/scripts/kto.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ 16 | Run the KTO training script with the commands below. In general, the optimal configuration for KTO will be similar to that of DPO. 
17 | 18 | # Full training: 19 | python trl/scripts/kto.py \ 20 | --dataset_name trl-lib/kto-mix-14k \ 21 | --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \ 22 | --per_device_train_batch_size 16 \ 23 | --num_train_epochs 1 \ 24 | --learning_rate 5e-7 \ 25 | --lr_scheduler_type=cosine \ 26 | --gradient_accumulation_steps 1 \ 27 | --logging_steps 10 \ 28 | --eval_steps 500 \ 29 | --output_dir=kto-aligned-model \ 30 | --warmup_ratio 0.1 \ 31 | --report_to wandb \ 32 | --bf16 \ 33 | --logging_first_step 34 | 35 | # QLoRA: 36 | python trl/scripts/kto.py \ 37 | --dataset_name trl-lib/kto-mix-14k \ 38 | --model_name_or_path=trl-lib/qwen1.5-1.8b-sft \ 39 | --per_device_train_batch_size 8 \ 40 | --num_train_epochs 1 \ 41 | --learning_rate 5e-7 \ 42 | --lr_scheduler_type=cosine \ 43 | --gradient_accumulation_steps 1 \ 44 | --logging_steps 10 \ 45 | --eval_steps 500 \ 46 | --output_dir=kto-aligned-model-lora \ 47 | --warmup_ratio 0.1 \ 48 | --report_to wandb \ 49 | --bf16 \ 50 | --logging_first_step \ 51 | --use_peft \ 52 | --load_in_4bit \ 53 | --lora_target_modules=all-linear \ 54 | --lora_r=16 \ 55 | --lora_alpha=16 56 | """ 57 | 58 | import argparse 59 | 60 | from datasets import load_dataset 61 | from transformers import AutoModelForCausalLM, AutoTokenizer 62 | 63 | from trl import ( 64 | KTOConfig, 65 | KTOTrainer, 66 | ModelConfig, 67 | ScriptArguments, 68 | TrlParser, 69 | get_peft_config, 70 | setup_chat_format, 71 | ) 72 | 73 | 74 | def main(script_args, training_args, model_args): 75 | # Load a pretrained model 76 | model = AutoModelForCausalLM.from_pretrained( 77 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code 78 | ) 79 | ref_model = AutoModelForCausalLM.from_pretrained( 80 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code 81 | ) 82 | 83 | tokenizer = AutoTokenizer.from_pretrained( 84 | model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code 85 | ) 86 | if tokenizer.pad_token is None: 87 | tokenizer.pad_token = tokenizer.eos_token 88 | 89 | # If we are aligning a base model, we use ChatML as the default template 90 | if tokenizer.chat_template is None: 91 | model, tokenizer = setup_chat_format(model, tokenizer) 92 | 93 | # Load the dataset 94 | dataset = load_dataset(script_args.dataset_name, name=script_args.dataset_config) 95 | 96 | # Initialize the KTO trainer 97 | trainer = KTOTrainer( 98 | model, 99 | ref_model, 100 | args=training_args, 101 | train_dataset=dataset[script_args.dataset_train_split], 102 | eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None, 103 | processing_class=tokenizer, 104 | peft_config=get_peft_config(model_args), 105 | ) 106 | 107 | # Train and push the model to the Hub 108 | trainer.train() 109 | 110 | # Save and push to hub 111 | trainer.save_model(training_args.output_dir) 112 | if training_args.push_to_hub: 113 | trainer.push_to_hub(dataset_name=script_args.dataset_name) 114 | 115 | 116 | def make_parser(subparsers: argparse._SubParsersAction = None): 117 | dataclass_types = (ScriptArguments, KTOConfig, ModelConfig) 118 | if subparsers is not None: 119 | parser = subparsers.add_parser("kto", help="Run the KTO training script", dataclass_types=dataclass_types) 120 | else: 121 | parser = TrlParser(dataclass_types) 122 | return parser 123 | 124 | 125 | if __name__ == "__main__": 126 | parser = make_parser() 127 | script_args, training_args, model_args = parser.parse_args_and_config() 128 | main(script_args, training_args, 
model_args) 129 | -------------------------------------------------------------------------------- /trl/templates/lm_model_card.md: -------------------------------------------------------------------------------- 1 | --- 2 | {{ card_data }} 3 | --- 4 | 5 | # Model Card for {{ model_name }} 6 | 7 | This model is a fine-tuned version of [{{ base_model }}](https://huggingface.co/{{ base_model }}){% if dataset_name %} on the [{{ dataset_name }}](https://huggingface.co/datasets/{{ dataset_name }}) dataset{% endif %}. 8 | It has been trained using [TRL](https://github.com/huggingface/trl). 9 | 10 | ## Quick start 11 | 12 | ```python 13 | from transformers import pipeline 14 | 15 | question = "If you had a time machine, but could only go to the past or the future once and never return, which would you choose and why?" 16 | generator = pipeline("text-generation", model="{{ hub_model_id }}", device="cuda") 17 | output = generator([{"role": "user", "content": question}], max_new_tokens=128, return_full_text=False)[0] 18 | print(output["generated_text"]) 19 | ``` 20 | 21 | ## Training procedure 22 | 23 | {% if wandb_url %}[Visualize in Weights & Biases]({{ wandb_url }}){% endif %} 24 | {% if comet_url %}[Visualize in Comet]({{ comet_url }}){% endif %} 25 | 26 | This model was trained with {{ trainer_name }}{% if paper_id %}, a method introduced in [{{ paper_title }}](https://huggingface.co/papers/{{ paper_id }}){% endif %}. 27 | 28 | ### Framework versions 29 | 30 | - TRL: {{ trl_version }} 31 | - Transformers: {{ transformers_version }} 32 | - Pytorch: {{ pytorch_version }} 33 | - Datasets: {{ datasets_version }} 34 | - Tokenizers: {{ tokenizers_version }} 35 | 36 | ## Citations 37 | 38 | {% if trainer_citation %}Cite {{ trainer_name }} as: 39 | 40 | ```bibtex 41 | {{ trainer_citation }} 42 | ```{% endif %} 43 | 44 | Cite TRL as: 45 | 46 | ```bibtex 47 | {% raw %}@misc{vonwerra2022trl, 48 | title = {{TRL: Transformer Reinforcement Learning}}, 49 | author = {Leandro von Werra and Younes Belkada and Lewis Tunstall and Edward Beeching and Tristan Thrush and Nathan Lambert and Shengyi Huang and Kashif Rasul and Quentin Gallou{\'e}dec}, 50 | year = 2020, 51 | journal = {GitHub repository}, 52 | publisher = {GitHub}, 53 | howpublished = {\url{https://github.com/huggingface/trl}} 54 | }{% endraw %} 55 | ``` 56 | -------------------------------------------------------------------------------- /trl/trainer/iterative_sft_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataclasses import dataclass, field 16 | from typing import Any, Optional 17 | 18 | from transformers import TrainingArguments 19 | 20 | 21 | @dataclass 22 | class IterativeSFTConfig(TrainingArguments): 23 | r""" 24 | Configuration class for the [`IterativeSFTTrainer`]. 
25 | 26 | This class includes only the parameters that are specific to Iterative SFT training. For a full list of training 27 | arguments, please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this 28 | class may differ from those in [`~transformers.TrainingArguments`]. 29 | 30 | Using [`~transformers.HfArgumentParser`] we can turn this class into 31 | [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the 32 | command line. 33 | 34 | Parameters: 35 | > Parameters that control the model 36 | 37 | model_init_kwargs (`dict[str, Any]` or `None`, *optional*, defaults to `None`): 38 | Keyword arguments for [`~transformers.AutoModelForCausalLM.from_pretrained`], used when the `model` 39 | argument of the [`IterativeSFTTrainer`] is provided as a string. 40 | 41 | > Parameters that control the data preprocessing 42 | 43 | max_length (`int` or `None`, *optional*, defaults to `None`): 44 | Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated. 45 | truncation_mode (`str`, *optional*, defaults to `"keep_end"`): 46 | The truncation mode to use, either `"keep_end"` or `"keep_start"`. 47 | optimize_device_cache (`bool`, *optional*, defaults to `False`): 48 | Whether to optimize CUDA cache for slightly more memory-efficient training. 49 | """ 50 | 51 | # Parameters whose default values are overridden from TrainingArguments 52 | logging_steps: float = field( 53 | default=10, 54 | metadata={ 55 | "help": ( 56 | "Log every X updates steps. Should be an integer or a float in range `[0,1)`. " 57 | "If smaller than 1, will be interpreted as ratio of total training steps." 58 | ) 59 | }, 60 | ) 61 | 62 | # Parameters that control the model 63 | model_init_kwargs: Optional[dict[str, Any]] = field( 64 | default=None, 65 | metadata={ 66 | "help": "Keyword arguments for `AutoModelForCausalLM.from_pretrained`, used when the `model` argument of " 67 | "the `IterativeSFTTrainer` is provided as a string." 68 | }, 69 | ) 70 | 71 | # Parameters that control the data preprocessing 72 | max_length: Optional[int] = field( 73 | default=None, 74 | metadata={ 75 | "help": "Maximum length of the tokenized sequence. Sequences longer than `max_length` are truncated." 76 | }, 77 | ) 78 | truncation_mode: str = field( 79 | default="keep_end", 80 | metadata={"help": "The truncation mode to use, either 'keep_end' or 'keep_start'."}, 81 | ) 82 | optimize_device_cache: bool = field( 83 | default=False, 84 | metadata={"help": "Whether to optimize CUDA cache for slightly more memory-efficient training."}, 85 | ) 86 | 87 | def __post_init__(self): 88 | super().__post_init__() 89 | 90 | if self.truncation_mode not in ["keep_end", "keep_start"]: 91 | raise ValueError(f"truncation_mode must be either 'keep_end' or 'keep_start', got {self.truncation_mode}") 92 | -------------------------------------------------------------------------------- /trl/trainer/nash_md_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataclasses import dataclass, field 16 | 17 | from trl.trainer.online_dpo_config import OnlineDPOConfig 18 | 19 | 20 | @dataclass 21 | class NashMDConfig(OnlineDPOConfig): 22 | r""" 23 | Configuration class for the [`NashMDTrainer`]. 24 | 25 | Subclass of [`OnlineDPOConfig`] we can use all its arguments and add the following: 26 | 27 | Parameters: 28 | mixture_coef (`float` or `list[float]`, *optional*, defaults to `0.5`): 29 | Logit mixture coefficient for the model and reference model. If a list of floats is provided then the 30 | mixture coefficient is selected for each new epoch and the last coefficient is used for the rest of the 31 | epochs. 32 | """ 33 | 34 | mixture_coef: list[float] = field( 35 | default_factory=lambda: [0.5], 36 | metadata={ 37 | "help": "Logit mixture coefficient for the model and reference model. If a list of floats is provided " 38 | "then the mixture coefficient is selected for each new epoch and the last coefficient is used for the " 39 | "rest of the epochs." 40 | }, 41 | ) 42 | 43 | def __post_init__(self): 44 | super().__post_init__() 45 | if hasattr(self.mixture_coef, "__len__") and len(self.mixture_coef) == 1: 46 | self.mixture_coef = self.mixture_coef[0] 47 | -------------------------------------------------------------------------------- /trl/trainer/reward_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | from dataclasses import dataclass, field 16 | from typing import Optional 17 | 18 | from transformers import TrainingArguments 19 | 20 | 21 | @dataclass 22 | class RewardConfig(TrainingArguments): 23 | r""" 24 | Configuration class for the [`RewardTrainer`]. 25 | 26 | This class includes only the parameters that are specific to Reward training. For a full list of training 27 | arguments, please refer to the [`~transformers.TrainingArguments`] documentation. Note that default values in this 28 | class may differ from those in [`~transformers.TrainingArguments`]. 29 | 30 | Using [`~transformers.HfArgumentParser`] we can turn this class into 31 | [argparse](https://docs.python.org/3/library/argparse#module-argparse) arguments that can be specified on the 32 | command line. 33 | 34 | Parameters: 35 | max_length (`int` or `None`, *optional*, defaults to `1024`): 36 | Maximum length of the sequences (prompt + completion) in the batch, filters out entries that exceed the 37 | limit. 
This argument is required if you want to use the default data collator. 38 | disable_dropout (`bool`, *optional*, defaults to `True`): 39 | Whether to disable dropout in the model. 40 | dataset_num_proc (`int`, *optional*, defaults to `None`): 41 | Number of processes to use for processing the dataset. 42 | center_rewards_coefficient (`float`, *optional*, defaults to `None`): 43 | Coefficient to incentivize the reward model to output mean-zero rewards (proposed by 44 | https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`. 45 | remove_unused_columns (`bool`, *optional*, defaults to `False`): 46 | Whether to remove the columns that are not used by the model's forward pass. Can be `True` only if 47 | the dataset is pretokenized. 48 | """ 49 | 50 | # Parameters whose default values are overridden from TrainingArguments 51 | logging_steps: float = field( 52 | default=10, 53 | metadata={ 54 | "help": ( 55 | "Log every X updates steps. Should be an integer or a float in range `[0,1)`. " 56 | "If smaller than 1, will be interpreted as ratio of total training steps." 57 | ) 58 | }, 59 | ) 60 | average_tokens_across_devices: bool = field( 61 | default=True, 62 | metadata={ 63 | "help": "Whether or not to average tokens across devices. If enabled, will use all_reduce to synchronize " 64 | "num_tokens_in_batch for precise loss calculation. Reference: https://github.com/huggingface/transformers/issues/34242 " 65 | }, 66 | ) 67 | 68 | max_length: Optional[int] = field( 69 | default=1024, 70 | metadata={ 71 | "help": "Maximum length of the sequences (prompt + completion) in the batch, filters out entries that " 72 | "exceed the limit. This argument is required if you want to use the default data collator." 73 | }, 74 | ) 75 | disable_dropout: bool = field( 76 | default=True, 77 | metadata={"help": "Whether to disable dropout in the model and reference model."}, 78 | ) 79 | dataset_num_proc: Optional[int] = field( 80 | default=None, 81 | metadata={"help": "Number of processes to use for processing the dataset."}, 82 | ) 83 | center_rewards_coefficient: Optional[float] = field( 84 | default=None, 85 | metadata={ 86 | "help": "Coefficient to incentivize the reward model to output mean-zero rewards (proposed by " 87 | "https://huggingface.co/papers/2312.09244, Eq. 2). Recommended value: `0.01`." 88 | }, 89 | ) 90 | remove_unused_columns: bool = field( 91 | default=False, 92 | metadata={ 93 | "help": "Whether to remove the columns that are not used by the model's forward pass. Can be `True` only " 94 | "if the dataset is pretokenized." 95 | }, 96 | ) 97 | -------------------------------------------------------------------------------- /trl/trainer/xpo_config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020-2025 The HuggingFace Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | from dataclasses import dataclass, field 16 | 17 | from trl.trainer.online_dpo_config import OnlineDPOConfig 18 | 19 | 20 | @dataclass 21 | class XPOConfig(OnlineDPOConfig): 22 | r""" 23 | Configuration class for the [`XPOTrainer`]. 24 | 25 | Subclass of [`OnlineDPOConfig`] we can use all its arguments and add the following: 26 | 27 | Parameters: 28 | alpha (`float` or `list[float]`, *optional*, defaults to `1e-5`): 29 | Weight of the XPO loss term. If a list of floats is provided then the alpha is selected for each new epoch 30 | and the last alpha is used for the rest of the epochs. 31 | """ 32 | 33 | alpha: list[float] = field( 34 | default_factory=lambda: [1e-5], 35 | metadata={ 36 | "help": "Weight of the XPO loss term. If a list of floats is provided then the alpha is selected for each " 37 | "new epoch and the last alpha is used for the rest of the epochs." 38 | }, 39 | ) 40 | 41 | def __post_init__(self): 42 | super().__post_init__() 43 | if hasattr(self.alpha, "__len__") and len(self.alpha) == 1: 44 | self.alpha = self.alpha[0] 45 | --------------------------------------------------------------------------------
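As a closing illustration, here is a minimal sketch (assuming only that `trl` is installed; the output directory names are arbitrary) of how the single-element list handling in `NashMDConfig.mixture_coef` and `XPOConfig.alpha` behaves: a one-element schedule collapses to a scalar in `__post_init__`, while a multi-value schedule is kept as a per-epoch list.

```python
from trl import NashMDConfig, XPOConfig

# A one-element schedule collapses to a scalar in __post_init__ ...
nash_args = NashMDConfig(output_dir="nash-md-out", mixture_coef=[0.5])
assert nash_args.mixture_coef == 0.5

# ... while a multi-value schedule stays a list: one value is selected per epoch,
# and the last value is reused for any remaining epochs.
xpo_args = XPOConfig(output_dir="xpo-out", alpha=[1e-5, 5e-6, 1e-6])
assert xpo_args.alpha == [1e-5, 5e-6, 1e-6]
```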