├── .devcontainer ├── Dockerfile └── devcontainer.json ├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.yml │ ├── documentation.yml │ └── feature_request.yml └── workflows │ ├── build.yml │ └── code_quality.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── configs ├── accelerate │ ├── ddp.yaml │ ├── zero2-bf16.yaml │ ├── zero2-fp16.yaml │ └── zero3.yaml ├── nemo_configs │ ├── megatron_1.3b.yaml │ ├── megatron_20b.yaml │ ├── megatron_2b.yaml │ ├── megatron_65b.yaml │ └── sft_megatron_20b.yaml ├── sweeps │ ├── ilql_sweep.yml │ └── ppo_sweep.yml └── test_config.yml ├── docs ├── Makefile ├── make.bat ├── requirements.txt └── source │ ├── api.rst │ ├── conf.py │ ├── configs.rst │ ├── data.rst │ ├── examples.rst │ ├── index.rst │ ├── installation.rst │ ├── pipelines.rst │ └── trainers.rst ├── examples ├── __init__.py ├── alpaca │ ├── README.md │ └── sft_alpaca.py ├── architext.py ├── experiments │ └── grounded_program_synthesis │ │ ├── README.md │ │ ├── __init__.py │ │ ├── configs │ │ └── trlx_ppo_config.yml │ │ ├── lang.py │ │ └── train_trlx.py ├── hh │ ├── README.md │ ├── ilql_hh.py │ ├── ppo_hh.py │ ├── sft_hh.py │ ├── to_triton.py │ └── triton_config.pbtxt ├── ilql_sentiments.py ├── ilql_sentiments_t5.py ├── llama_nemo │ ├── README.md │ ├── convert_llama_to_nemo.py │ ├── dist_train.sh │ ├── megatron_llama_cfg.yaml │ └── nemo_llama2_ppo_sentiments.py ├── nemo_ilql_inference.py ├── nemo_ilql_sentiments.py ├── nemo_ppo_inference.py ├── nemo_ppo_sentiments.py ├── nemo_sft_sentiments.py ├── nemo_vs_ds_chat.py ├── notebooks │ ├── trlx_sentiments.ipynb │ └── trlx_simulacra.ipynb ├── ppo_dense_sentiments.py ├── ppo_sentiments.py ├── ppo_sentiments_llama.py ├── ppo_sentiments_peft.py ├── ppo_sentiments_t5.py ├── ppo_translation_t5.py ├── randomwalks │ ├── README.md │ ├── __init__.py │ ├── graph-example.png │ ├── ilql_randomwalks.py │ ├── ppo_randomwalks.py │ ├── randomwalks.py │ └── rft_randomwalks.py ├── rft_sentiments.py ├── sft_sentiments.py ├── simulacra.py ├── summarize_daily_cnn │ ├── __init__.py │ └── t5_summarize_daily_cnn.py └── summarize_rlhf │ ├── README.md │ ├── configs │ ├── default_accelerate_config.yaml │ └── ds_config_trlx_gptj_summarize.json │ ├── ilql_summarize_t5.py │ ├── requirements.txt │ ├── reward_model │ ├── ds_config_gpt_j.json │ ├── gptj_reward_test.py │ ├── reward_model.py │ └── train_reward_model_gptj.py │ ├── sft │ ├── ds_config_gptj.json │ ├── summarize_dataset.py │ └── train_gptj_summarize.py │ ├── trlx_gptj_text_summarization.py │ └── trlx_inference_gptj.py ├── pyproject.toml ├── requirements.txt ├── scripts ├── accelerate_train_example.sh ├── benchmark.sh ├── slurm_train.sh └── sweep-cw.sh ├── setup.cfg ├── setup.py ├── tests ├── __init__.py ├── test_configs.py ├── test_minibatch.py ├── test_models.py ├── test_peft.py ├── test_pipelines.py ├── test_trainers.py └── test_utils.py └── trlx ├── __init__.py ├── data ├── __init__.py ├── accelerate_base_datatypes.py ├── configs.py ├── default_configs.py ├── ilql_types.py ├── method_configs.py └── ppo_types.py ├── models ├── README.md ├── __init__.py ├── modeling_base.py ├── modeling_ilql.py ├── modeling_nemo_ilql.py ├── modeling_nemo_ppo.py ├── modeling_nemo_sft.py └── modeling_ppo.py ├── pipeline ├── __init__.py ├── offline_pipeline.py └── ppo_pipeline.py ├── reference.py ├── sweep.py ├── trainer ├── __init__.py ├── accelerate_base_trainer.py ├── accelerate_ilql_trainer.py ├── accelerate_ppo_trainer.py ├── accelerate_rft_trainer.py ├── 
accelerate_sft_trainer.py ├── nemo_ilql_trainer.py ├── nemo_ppo_trainer.py └── nemo_sft_trainer.py ├── trlx.py └── utils ├── __init__.py ├── loading.py ├── logging.py └── modeling.py /.devcontainer/Dockerfile: -------------------------------------------------------------------------------- 1 | # Use Nvidia Ubuntu 20 base (includes CUDA if a supported GPU is present) 2 | # https://hub.docker.com/r/nvidia/cuda 3 | FROM nvidia/cuda:11.6.2-cudnn8-devel-ubuntu20.04@sha256:55211df43bf393d3393559d5ab53283d4ebc3943d802b04546a24f3345825bd9 4 | 5 | ARG USERNAME 6 | ARG USER_UID=1000 7 | ARG USER_GID=$USER_UID 8 | 9 | # Create the user 10 | # https://code.visualstudio.com/remote/advancedcontainers/add-nonroot-user 11 | RUN groupadd --gid $USER_GID $USERNAME \ 12 | && useradd --uid $USER_UID --gid $USER_GID -m $USERNAME \ 13 | && usermod -a -G video user \ 14 | && apt-get update \ 15 | && apt-get install -y sudo \ 16 | && echo $USERNAME ALL=\(root\) NOPASSWD:ALL > /etc/sudoers.d/$USERNAME \ 17 | && chmod 0440 /etc/sudoers.d/$USERNAME 18 | 19 | # Install dependencies 20 | RUN sudo apt-get update && \ 21 | DEBIAN_FRONTEND=noninteractive apt-get -qq -y install \ 22 | build-essential \ 23 | python3.9 \ 24 | python3.9-dev \ 25 | python3.9-distutils \ 26 | python3.9-venv \ 27 | curl \ 28 | git 29 | 30 | # Install pip (we need the latest version not the standard Ubuntu version, to 31 | # support modern wheels) 32 | RUN sudo curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py && python3.9 get-pip.py 33 | 34 | # Set python aliases 35 | RUN sudo update-alternatives --install /usr/bin/python python /usr/bin/python3.9 1 36 | RUN sudo update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.9 1 37 | 38 | # User the new user 39 | USER $USERNAME 40 | 41 | # Install python dev dependencies 42 | RUN pip install \ 43 | autopep8 \ 44 | jedi \ 45 | mypy \ 46 | pytest \ 47 | toml \ 48 | yapf 49 | -------------------------------------------------------------------------------- /.devcontainer/devcontainer.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "Python 3", 3 | "build": { 4 | "dockerfile": "Dockerfile", 5 | "args": { 6 | "USERNAME": "user" 7 | } 8 | }, 9 | "customizations": { 10 | "vscode": { 11 | "settings": { 12 | "python.formatting.autopep8Path": "autopep8", 13 | "python.linting.mypyPath": "mypy" 14 | }, 15 | "extensions": [ 16 | "davidanson.vscode-markdownlint", 17 | "donjayamanne.githistory", 18 | "donjayamanne.python-extension-pack", 19 | "github.vscode-pull-request-github", 20 | "ms-python.python", 21 | "ms-toolsai.jupyter", 22 | "ms-vsliveshare.vsliveshare-pack", 23 | "njpwerner.autodocstring", 24 | "stkb.rewrap", 25 | "streetsidesoftware.code-spell-checker", 26 | "tushortz.python-extended-snippets", 27 | "yzhang.markdown-all-in-one", 28 | "elagil.pre-commit-helper" 29 | ] 30 | } 31 | }, 32 | "containerUser": "user", 33 | "postCreateCommand": "pip install -e .[dev]" 34 | } 35 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/bug_report.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: 🐛 Bug Report 3 | description: Report a bug or unexpected behavior to help us improve trlX 4 | labels: 5 | - bug 6 | 7 | body: 8 | - type: markdown 9 | attributes: 10 | value: > 11 | #### Before submitting your bug report, please check to see that the 12 | issue hasn't already been reported and/or fixed in a latest version. 
13 | [Search Issues][Issue Search]. 14 | 15 | If you're asking a question or seeking support, please consider creating a 16 | new [GitHub discussion][Discussions] or heading over to CarperAI's 17 | [Discord server][CarperAI Discord]. 18 | 19 | 20 | [Issue Search]: https://github.com/CarperAI/trlx/search?q=is%3Aissue&type=issues 21 | 22 | [Discussions]: https://github.com/CarperAI/trlx/discussions 23 | 24 | [CarperAI Discord]: https://discord.gg/X2gHZMRP6m 25 | 26 | - type: textarea 27 | attributes: 28 | label: 🐛 Describe the bug 29 | description: >- 30 | Please provide a clear and concise description of what the problem is, 31 | preferably with self-contained code to reproduce the issue. You may want 32 | to follow the suggestions outlined in [this guide][Guide]. If you observe 33 | an error, please paste the error message including the full traceback. 34 | 35 | 36 | [Guide]: https://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports 37 | 38 | placeholder: | 39 | A description of what the bug is. 40 | 41 | ```python 42 | # Sample code to reproduce the bug, if applicable. 43 | ``` 44 | 45 | ``` 46 | The error message, with the full traceback. 47 | ``` 48 | 49 | validations: 50 | required: true 51 | 52 | - type: input 53 | attributes: 54 | label: Which trlX version are you using? 55 | placeholder: For example, `trlx==1.0.0` 56 | 57 | - type: input 58 | attributes: 59 | label: Additional system and package information 60 | placeholder: Python version, `transformers` version, OS (Linux/Mac/Windows/WSL), etc. 61 | 62 | - type: markdown 63 | attributes: 64 | value: > 65 | Thanks for contributing 🐠! 66 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: 📚 Documentation 3 | description: Report an issue related to https://trlx.readthedocs.io/en/latest/index.html 4 | labels: 5 | - documentation 6 | 7 | body: 8 | - type: textarea 9 | attributes: 10 | label: 📚 The doc issue 11 | description: > 12 | Please provide a clear and concise description of what content in https://trlx.readthedocs.io/en/latest/index.html is an issue. 13 | validations: 14 | required: true 15 | 16 | - type: textarea 17 | attributes: 18 | label: Suggest a potential alternative/fix 19 | description: > 20 | Tell us how we could improve the documentation in this regard. 21 | 22 | - type: markdown 23 | attributes: 24 | value: > 25 | Thanks for contributing 🐠! 26 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: 🚀 Feature Request 3 | description: Submit a proposal/request for a new trlX feature 4 | labels: 5 | - feature request 6 | 7 | body: 8 | - type: textarea 9 | attributes: 10 | label: 🚀 The feature, motivation, and pitch 11 | description: > 12 | Please provide a clear and concise description of the feature proposal. 13 | Outline the motivation for the proposal; is your feature request related to a 14 | specific problem? E.g., *"I'm working on X and would like Y to be 15 | possible"*. If this is related to another GitHub issue, please link here 16 | too. 17 | validations: 18 | required: true 19 | 20 | - type: textarea 21 | attributes: 22 | label: Alternatives 23 | description: > 24 | A description of any alternative solutions or features you've considered, 25 | if any. 
26 | 27 | - type: textarea 28 | attributes: 29 | label: Additional context 30 | description: > 31 | Add any other context or screenshots about the feature request. 32 | 33 | - type: markdown 34 | attributes: 35 | value: > 36 | Thanks for contributing 🐠! 37 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Build 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | build: 11 | 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - uses: actions/checkout@v2 16 | 17 | - name: Set up Python 3.9 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: 3.9 21 | 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip 25 | pip install -r requirements.txt 26 | # Install extras 27 | # [bnb] (TODO: Remove `scipy` once `bnb` adds it as hard dep) 28 | pip install bitsandbytes scipy 29 | # [dev] 30 | pip install black hypothesis isort flake8 pre-commit pytest pytest-cov 31 | 32 | - name: Lint with flake8 33 | run: | 34 | # Stop the build if there are Python syntax errors or undefined names 35 | flake8 . --count --select=E9,F63,F7 --show-source --statistics 36 | # Exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 37 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 38 | 39 | - name: Run tests 40 | run: | 41 | pytest -vv --cov=trlx/ tests/ 42 | 43 | - name: Upload coverage to Codecov 44 | run: | 45 | bash <(curl -s https://codecov.io/bash) -t $CODECOV_TOKEN 46 | -------------------------------------------------------------------------------- /.github/workflows/code_quality.yml: -------------------------------------------------------------------------------- 1 | name: Code Quality 2 | 3 | on: [pull_request] 4 | 5 | jobs: 6 | code-quality: 7 | runs-on: ubuntu-20.04 8 | steps: 9 | - uses: actions/checkout@v2 10 | - uses: actions/setup-python@v2 11 | with: 12 | python-version: 3.8 13 | - uses: pre-commit/action@v2.0.3 14 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.bak 2 | .gitattributes 3 | .last_checked 4 | .gitconfig 5 | *.bak 6 | *.log 7 | *~ 8 | ~* 9 | _tmp* 10 | tmp* 11 | tags 12 | 13 | # Byte-compiled / optimized / DLL files 14 | __pycache__/ 15 | *.py[cod] 16 | *$py.class 17 | 18 | # C extensions 19 | *.so 20 | 21 | # Distribution / packaging 22 | .Python 23 | env/ 24 | build/ 25 | develop-eggs/ 26 | dist/ 27 | downloads/ 28 | eggs/ 29 | .eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | .hypothesis/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # celery beat schedule file 89 | celerybeat-schedule 90 | 91 | # SageMath parsed files 92 | *.sage.py 93 | 94 | # dotenv 95 | .env 96 | 97 | # virtualenv 98 | .venv 99 | venv/ 100 | ENV/ 101 | 102 | # Spyder project settings 103 | .spyderproject 104 | .spyproject 105 | 106 | # Rope project settings 107 | .ropeproject 108 | 109 | # mkdocs documentation 110 | /site 111 | 112 | # mypy 113 | .mypy_cache/ 114 | 115 | .vscode 116 | *.swp 117 | 118 | # osx generated files 119 | .DS_Store 120 | .DS_Store? 121 | .Trashes 122 | ehthumbs.db 123 | Thumbs.db 124 | .idea 125 | 126 | # pytest 127 | .pytest_cache 128 | 129 | # tools/trust-doc-nbs 130 | docs_src/.last_checked 131 | 132 | # symlinks to fastai 133 | docs_src/fastai 134 | tools/fastai 135 | 136 | # link checker 137 | checklink/cookies.txt 138 | 139 | # .gitconfig is now autogenerated 140 | .gitconfig 141 | 142 | 143 | nbs/wandb/ 144 | 145 | wandb/ 146 | 147 | OUT/ 148 | 149 | 150 | examples/experiments/grounded_program_synthesis/dataset 151 | ckpts/ 152 | 153 | ray_results/ 154 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v4.4.0 6 | hooks: 7 | - id: check-case-conflict 8 | - id: check-json 9 | - id: check-symlinks 10 | - id: check-yaml 11 | - id: destroyed-symlinks 12 | - id: end-of-file-fixer 13 | exclude: docs/CNAME 14 | - id: fix-byte-order-marker 15 | - id: fix-encoding-pragma 16 | args: [--remove] 17 | - id: mixed-line-ending 18 | args: [--fix=lf] 19 | - id: requirements-txt-fixer 20 | - id: trailing-whitespace 21 | - repo: https://github.com/psf/black 22 | rev: 23.1.0 23 | hooks: 24 | - id: black 25 | files: ^(trlx|examples|tests|setup.py)/ 26 | - repo: https://github.com/pycqa/isort 27 | rev: 5.12.0 28 | hooks: 29 | - id: isort 30 | name: isort (python) 31 | - repo: https://github.com/pycqa/flake8 32 | rev: 6.0.0 33 | hooks: 34 | - id: flake8 35 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | sphinx: 4 | configuration: docs/source/conf.py 5 | 6 | build: 7 | os: ubuntu-22.04 8 | tools: 9 | python: "3.8" 10 | nodejs: "18" 11 | rust: "1.64" 12 | golang: "1.19" 13 | 14 | python: 15 | install: 16 | - requirements: docs/requirements.txt 17 | - method: pip 18 | path: . 
19 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | caperai@stability.ai. 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. 
Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | https://www.contributor-covenant.org/faq. Translations are available at 128 | https://www.contributor-covenant.org/translations. 129 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to `trlX` 2 | 3 | Looking to improve `trlX`? Thanks for considering! 4 | 5 | There are many ways to contribute, from writing tutorials in [Colab notebooks](https://colab.research.google.com) to improving the project's [documentation](https://trlx.readthedocs.io), submitting bug reports and feature requests, or even implementing new features themselves. See the outstanding [issues](https://github.com/CarperAI/trlx/issues) for ideas on where to begin. 6 | 7 | Here are some guidelines to help you get started 🚀. 8 | 9 | ## Submitting a bug report or a feature request¶ 10 | 11 | To submit a bug report or a feature request, please open an [issue](https://github.com/CarperAI/trlx/issues) by clicking on the `New Issue` button and selecting the respective issue template. 
Make sure to fill out all the required information and provide as much detail as possible. For bug reports, this means including a minimal code example that reproduces the bug, and for feature requests, a clear and detailed description of the feature you would like to see implemented.
12 |
13 | ## Submitting code
14 |
15 | > **Note**: Make sure to first search through the [issue tracker](https://github.com/CarperAI/trlx/issues) and [PR list](https://github.com/CarperAI/trlx/pulls) to avoid duplicating work. If you want to work on a non-trivial feature, we highly recommend that you first open an issue in the [issue tracker](https://github.com/CarperAI/trlx/issues) to get feedback from core developers.
16 |
17 | Follow these steps to start contributing code:
18 |
19 | 1. Create your own [fork](https://docs.github.com/en/get-started/quickstart/fork-a-repo#forking-a-repository) of the repository and clone it to your local machine.
20 | ```bash
21 | git clone https://github.com/<your-username>/trlx.git
22 | cd trlx
23 | git remote add upstream https://github.com/CarperAI/trlx.git
24 | ```
25 | 2. Create a new branch for your changes and give it a concise name that reflects your contribution.
26 | ```bash
27 | git checkout -b <branch-name>
28 | ```
29 | 3. Install the development dependencies in a Python environment.
30 | ```bash
31 | pip install -e ".[dev]"
32 | pre-commit install
33 | ```
34 | 4. Implement your changes. Make small, independent, and well-documented commits along the way (check out [these](https://cbea.ms/git-commit/) tips).
35 | 5. Add unit tests whenever appropriate and ensure that the tests pass. To run the entire test suite, use the following command from within the project root directory.
36 | ```bash
37 | pytest
38 | ```
39 | For changes with minimal project scope (e.g. a simple bug fix), you might want to run the unit tests for just a specific test file instead:
40 | ```bash
41 | pytest -vv -k "<test-file-name>"
42 | ```
43 | 6. Commit your final changes. Our `pre-commit` hooks will automatically run before each commit and will prevent you from committing code that does not pass our style and linter checks. They'll also automatically format your code! To run these manually, use the following command:
44 | ```bash
45 | pre-commit run --all-files
46 | ```
47 |
48 | 7. Push the changes to your fork.
49 |
50 | Finally ... 🥁 ... Create a [pull request](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/creating-a-pull-request) to the `trlX` repository! Make sure to include a description of your changes and link to any relevant issues.
51 |
52 | > __Tip__: If you're looking to introduce an experimental feature, we suggest testing the behavior of your proposed feature on some of the existing [examples](https://github.com/CarperAI/trlx/tree/master/examples), such as [random walks](https://github.com/CarperAI/trlx/blob/master/examples/randomwalks). This will help you get a better sense of how the feature would work in practice and will also help you identify any potential flaws in the implementation.
53 |
54 | ## Asking questions
55 |
56 | Have a question? Rather than opening an issue, you can readily chat with the core team on our [Discord server](https://discord.gg/canadagoose).
57 |
58 | ## Code of conduct
59 |
60 | This project adheres to the [Contributor Covenant Code of Conduct](https://github.com/CarperAI/trlx/blob/master/CODE_OF_CONDUCT.md). By participating, you are expected to uphold this code.
61 | 62 | ## License 63 | 64 | By contributing, you agree that your contributions will be licensed under its MIT License. 65 | 66 | # Thank you for your contribution 🐠! 67 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 CarperAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /configs/accelerate/ddp.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: {} 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: no 5 | dynamo_config: {} 6 | fsdp_config: {} 7 | gpu_ids: all 8 | machine_rank: 0 9 | main_training_function: main 10 | megatron_lm_config: {} 11 | mixed_precision: bf16 12 | num_machines: 1 13 | num_processes: 8 14 | rdzv_backend: static 15 | same_network: true 16 | use_cpu: false 17 | -------------------------------------------------------------------------------- /configs/accelerate/zero2-bf16.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | deepspeed_multinode_launcher: standard 4 | gradient_accumulation_steps: 1 5 | gradient_clipping: 1.0 6 | offload_optimizer_device: none 7 | offload_param_device: none 8 | zero3_init_flag: false 9 | zero_stage: 2 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: no 12 | dynamo_config: {} 13 | fsdp_config: {} 14 | machine_rank: 0 15 | main_training_function: main 16 | megatron_lm_config: {} 17 | mixed_precision: bf16 18 | num_machines: 1 19 | num_processes: 8 20 | rdzv_backend: static 21 | same_network: true 22 | use_cpu: false 23 | -------------------------------------------------------------------------------- /configs/accelerate/zero2-fp16.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | deepspeed_multinode_launcher: standard 4 | gradient_accumulation_steps: 1 5 | gradient_clipping: 1.0 6 | offload_optimizer_device: none 7 | offload_param_device: none 8 | zero3_init_flag: false 9 | zero_stage: 2 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: no 12 | dynamo_config: {} 13 | fsdp_config: {} 14 | machine_rank: 0 15 | main_training_function: main 16 | 
megatron_lm_config: {} 17 | mixed_precision: fp16 18 | num_machines: 1 19 | num_processes: 8 20 | rdzv_backend: static 21 | same_network: true 22 | use_cpu: false 23 | -------------------------------------------------------------------------------- /configs/accelerate/zero3.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | deepspeed_multinode_launcher: standard 4 | gradient_accumulation_steps: 1 5 | gradient_clipping: 1.0 6 | offload_optimizer_device: none 7 | offload_param_device: none 8 | zero3_init_flag: true 9 | zero3_save_16bit_model: true 10 | zero_stage: 3 11 | distributed_type: DEEPSPEED 12 | downcast_bf16: no 13 | dynamo_config: {} 14 | fsdp_config: {} 15 | machine_rank: 0 16 | main_training_function: main 17 | megatron_lm_config: {} 18 | mixed_precision: bf16 19 | num_machines: 1 20 | num_processes: 8 21 | rdzv_backend: static 22 | same_network: true 23 | use_cpu: false 24 | -------------------------------------------------------------------------------- /configs/sweeps/ilql_sweep.yml: -------------------------------------------------------------------------------- 1 | tune_config: 2 | mode: "max" 3 | metric: "metrics/sentiments" 4 | search_alg: "random" 5 | scheduler: "fifo" 6 | num_samples: 64 7 | 8 | # https://docs.ray.io/en/latest/tune/api_docs/search_space.html#tune-sample-docs 9 | optimizer.kwargs.lr: 10 | strategy: "loguniform" 11 | values: [0.000001, 0.001] 12 | method.tau: 13 | strategy: "uniform" 14 | values: [0.6, 0.9] 15 | method.steps_for_target_q_sync: 16 | strategy: "choice" 17 | values: [1, 5, 10] 18 | method.alpha: 19 | strategy: "loguniform" 20 | values: [0.001, 1.0] 21 | 22 | # disable checkpointing for storage sake 23 | train.checkpoint_interval: 24 | strategy: "choice" 25 | values: [10000000] 26 | train.save_best: 27 | strategy: "choice" 28 | values: [false] 29 | -------------------------------------------------------------------------------- /configs/sweeps/ppo_sweep.yml: -------------------------------------------------------------------------------- 1 | tune_config: 2 | mode: "max" 3 | metric: "reward/mean" 4 | search_alg: "random" 5 | scheduler: "fifo" 6 | num_samples: 32 7 | 8 | # https://docs.ray.io/en/latest/tune/api_docs/search_space.html#tune-sample-docs 9 | optimizer.kwargs.lr: 10 | strategy: "loguniform" 11 | values: [0.000001, 0.001] 12 | method.init_kl_coef: 13 | strategy: "loguniform" 14 | values: [0.0001, 0.2] 15 | model.num_layers_unfrozen: 16 | strategy: "choice" 17 | values: [-1, 2, 6] 18 | method.num_rollouts: 19 | strategy: "choice" 20 | values: [32, 128, 512] 21 | method.target: 22 | strategy: "choice" 23 | values: [null, 1] 24 | 25 | # disable checkpointing for storage sake 26 | train.checkpoint_interval: 27 | strategy: "choice" 28 | values: [10000000] 29 | train.save_best: 30 | strategy: "choice" 31 | values: [false] 32 | -------------------------------------------------------------------------------- /configs/test_config.yml: -------------------------------------------------------------------------------- 1 | train: 2 | seq_length: 64 # Size of LM context 3 | epochs: 100 # Train for max(epochs, total_steps) 4 | total_steps: 1000 # Train for max(epochs, total_steps) 5 | batch_size: 16 # batch size 6 | 7 | checkpoint_interval: 10000 # checkpoint interval 8 | eval_interval: 128 # eval interval 9 | 10 | pipeline: "PromptPipeline" # prompt pipeline to load 11 | trainer: "AcceleratePPOTrainer" # Name of model trainer to load 12 | 13 | model: 
14 | model_path: "lvwerra/gpt2-imdb" # Name of hf model to load 15 | num_layers_unfrozen: 2 # Number of bottom layers to freeze during training 16 | 17 | tokenizer: 18 | tokenizer_path: "gpt2" # Name of hf tokenizer to load 19 | truncation_side: "right" # Trim this side of samples if they are longer than LM context 20 | 21 | optimizer: 22 | name: "adamw" # Name of optimizer to load 23 | kwargs: 24 | lr: 1.412e-4 # Learning rate 25 | betas: [0.9, 0.95] # Adam betas 26 | eps: 1.0e-8 # Adam eps 27 | weight_decay: 1.0e-6 # Weight decay param 28 | 29 | scheduler: 30 | name: "cosine_annealing" # Name of learning rate scheduler 31 | kwargs: 32 | T_max: 10000 # Maximum number of steps 33 | eta_min: 1.412e-4 # Minimum learning rate 34 | 35 | method: 36 | name: "ppoconfig" # Name of RL method config 37 | num_rollouts: 128 # Number of rollouts to collect per epoch 38 | chunk_size: 128 # Number of rollouts to collect in one loop 39 | ppo_epochs: 4 # Number of ppo epochs 40 | init_kl_coef: 0.2 # init kl coefficient 41 | target: 6 # target kl coefficient, set None for fixed kl coef 42 | horizon: 10000 # PPO horizon 43 | gamma: 0.99 # PPO discount 44 | lam: 0.95 # PPO lambda 45 | cliprange: 0.2 # clip range 46 | cliprange_value: 0.2 # clip range 47 | vf_coef: 1.0 # value term weight 48 | scale_reward: "running" # False|"ref"|"running" estimate against which to scale rewards 49 | cliprange_reward: 10 50 | ref_mean: null 51 | ref_std: null 52 | gen_kwargs: 53 | max_length: 48 # LM max sample gen length 54 | min_length: 48 # LM min sample gen length 55 | top_k: 0.0 # top k 56 | top_p: 1.0 # top p 57 | do_sample: True # sample 58 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 
23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.https://www.sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx==4.0.0 2 | sphinx_rtd_theme 3 | torch 4 | torchtyping 5 | -------------------------------------------------------------------------------- /docs/source/api.rst: -------------------------------------------------------------------------------- 1 | .. _api: 2 | 3 | API 4 | === 5 | 6 | trlX uses a single entrypoint for training, which will execute training conditioned on the passed config and the necessary arguments for a specific training routine. For the online training `prompts` (a list of strings to prompt the training model) and `reward_fn` (a function which gives reward for model outputs sampled from `prompts`) are necessary, while for offline training `samples` (a list of environment/model interactions) and `rewards` (precomputed scores for each interaction) are required. 7 | 8 | Training 9 | -------- 10 | 11 | .. autofunction:: trlx.train 12 | 13 | Distributed 14 | ----------- 15 | 16 | Accelerate 17 | ^^^^^^^^^^ 18 | 19 | To launch distributed training with Accelerate, first you have to specify the training configuration. You only have to execute this command once per each training node. 20 | 21 | .. code-block:: console 22 | 23 | $ accelerate config 24 | $ accelerate launch examples/ppo_sentiments.py 25 | 26 | You can also use configs provided in `trlX repository `_): 27 | 28 | .. code-block:: console 29 | 30 | $ accelerate launch --config_file configs/accelerate/zero2-bf16.yaml examples/ppo_sentiments.py 31 | 32 | 33 | NVIDIA NeMo 34 | ^^^^^^^^^^^ 35 | 36 | For training with NeMo you have to use a model stored in the NeMo format. You can convert an existing llama model with the following script: 37 | 38 | .. code-block:: console 39 | 40 | $ python examples/llama_nemo/convert_llama_to_nemo.py --model_path NousResearch/Llama-2-7b-hf --output_folder nemo_llama2_7b --total_tp 4 --name 7b 41 | 42 | To start training you have to execute python script per each GPU, or launch the following sbatch script which has `-ntasks-per-node=8` 43 | 44 | .. code-block:: console 45 | 46 | $ sbatch examples/llama_nemo/dist_train.sh 47 | 48 | Run example: `wandb `_ 49 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
12 | # 13 | import os 14 | import sys 15 | 16 | import sphinx_rtd_theme 17 | 18 | sys.path.insert(0, os.path.abspath('../..')) 19 | 20 | 21 | # -- Project information ----------------------------------------------------- 22 | 23 | project = 'trlX' 24 | copyright = '2022, CarperAI' 25 | author = 'CarperAI' 26 | 27 | # -- General configuration --------------------------------------------------- 28 | 29 | # Add any Sphinx extension module names here, as strings. They can be 30 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 31 | # ones. 32 | 33 | extensions = ['sphinx_rtd_theme', 'sphinx.ext.todo', 'sphinx.ext.viewcode', 'sphinx.ext.autodoc', 'sphinx.ext.autosummary', 'sphinx.ext.autosectionlabel'] 34 | 35 | # Add any paths that contain templates here, relative to this directory. 36 | templates_path = ['_templates'] 37 | 38 | # List of patterns, relative to source directory, that match files and 39 | # directories to ignore when looking for source files. 40 | # This pattern also affects html_static_path and html_extra_path. 41 | exclude_patterns = [] 42 | 43 | 44 | # -- Options for HTML output ------------------------------------------------- 45 | 46 | # The theme to use for HTML and HTML Help pages. See the documentation for 47 | # a list of builtin themes. 48 | # 49 | html_theme = 'sphinx_rtd_theme' 50 | 51 | # Add any paths that contain custom static files (such as style sheets) here, 52 | # relative to this directory. They are copied after the builtin static files, 53 | # so a file named "default.css" will overwrite the builtin "default.css". 54 | html_static_path = ['_static'] 55 | -------------------------------------------------------------------------------- /docs/source/configs.rst: -------------------------------------------------------------------------------- 1 | .. _configs: 2 | 3 | Configs 4 | ************************ 5 | 6 | Training requires configuration to be passed through a set of configs: `TrainConfig` with training configuration, `ModelConfig`, `TokenizerConfig`, `OptimizerConfig`, `SchedulerConfig` and a `MethodConfig` for a specific configuration of a particular algorithm (PPO, ILQL or SFT) 7 | 8 | **General** 9 | 10 | .. autoclass:: trlx.data.configs.TRLConfig 11 | :members: 12 | 13 | .. autoclass:: trlx.data.configs.TrainConfig 14 | :members: 15 | 16 | .. autoclass:: trlx.data.configs.ModelConfig 17 | :members: 18 | 19 | .. autoclass:: trlx.data.configs.TokenizerConfig 20 | :members: 21 | 22 | .. autoclass:: trlx.data.configs.OptimizerConfig 23 | :members: 24 | 25 | .. autoclass:: trlx.data.configs.SchedulerConfig 26 | :members: 27 | 28 | .. autoclass:: trlx.data.method_configs.MethodConfig 29 | :members: 30 | 31 | **PPO** 32 | 33 | .. autoclass:: trlx.models.modeling_ppo.PPOConfig 34 | :members: 35 | 36 | **ILQL** 37 | 38 | .. autoclass:: trlx.models.modeling_ilql.ILQLConfig 39 | :members: 40 | -------------------------------------------------------------------------------- /docs/source/data.rst: -------------------------------------------------------------------------------- 1 | .. _data: 2 | 3 | Data Classes 4 | ============ 5 | 6 | Data Elements contain the necessary information for each individual training sample. 7 | 8 | PPO Data Classes 9 | ---------------- 10 | 11 | .. autoclass:: trlx.data.ppo_types.PPORLElement 12 | :members: 13 | 14 | .. autoclass:: trlx.data.ppo_types.PPORLBatch 15 | :members: 16 | 17 | ILQL Data Classes 18 | ----------------- 19 | 20 | .. autoclass:: trlx.data.ilql_types.ILQLElement 21 | :members: 22 | 23 | .. 
autoclass:: trlx.models.modeling_ilql.CausalILQLOutput 24 | :members: 25 | 26 | .. autoclass:: trlx.data.ilql_types.ILQLSeq2SeqElement 27 | :members: 28 | 29 | .. autoclass:: trlx.models.modeling_ilql.Seq2SeqILQLOutput 30 | :members: 31 | 32 | .. autoclass:: trlx.data.ilql_types.ILQLBatch 33 | :members: 34 | 35 | .. autoclass:: trlx.data.ilql_types.ILQLSeq2SeqBatch 36 | :members: 37 | -------------------------------------------------------------------------------- /docs/source/examples.rst: -------------------------------------------------------------------------------- 1 | .. _examples: 2 | 3 | Examples 4 | ======== 5 | 6 | Random Walks 7 | ------------ 8 | 9 | This is a simple toy example described in `Decision Transformer 10 | (Lili Chen et al. 2021) `_. It's simple enough that it can be used for testing with a 1M sized LLM, training of which can complete entirely on CPU. 11 | 12 | Description 13 | ^^^^^^^^^^^ 14 | 15 | The task is to find the shortest path on a directed graph. The reward is based 16 | on how optimal the path is compared to the shortest possible. Paths are 17 | represented as strings of letters, where each letter corresponds to a node in 18 | the graph. 19 | 20 | Training 21 | ^^^^^^^^ 22 | 23 | For `PPO Training 24 | `_, 25 | a language model continually samples paths in a graph and directly optimizes for 26 | their shortness using surrogate reward function. For `ILQL Training 27 | `_ 28 | a language model learns directly from a set of 1000 pre-sampled randomwalks in a 29 | graph paired with their relative lengths' shortness. 30 | 31 | W&B runs: 32 | 33 | - PPO https://wandb.ai/sorry/trlx-references/runs/sf8ept0l 34 | - ILQL https://wandb.ai/sorry/trlx-references/runs/g44npaoq 35 | 36 | Positive Sentiment 37 | ------------------ 38 | 39 | Description 40 | ^^^^^^^^^^^ 41 | The task is to optimize a language model to generate positive sentiment responses for a given prompt. 42 | 43 | Training 44 | ^^^^^^^^ 45 | 46 | The training is done by using `PPO trainer 47 | `_ to 48 | maximize a score from pre-trained sentiment classifier trained on IMDB review 49 | sentiments `dataset `_ . For `ILQL Training 50 | `_ the 51 | model is trained directly on the dataset and its labels: `0` for a negative 52 | review and `1` for a positive one. For `SFT Training 53 | `_ the 54 | model is trained only on the positive reviews. 55 | 56 | W&B runs: 57 | 58 | - PPO: https://wandb.ai/sorry/trlx-references/runs/9ohlfd3s 59 | - ILQL: https://wandb.ai/sorry/trlx-references/runs/tplhaji6 60 | - SFT: https://wandb.ai/sorry/trlx-references/runs/vfxfv081 61 | 62 | Helpful & Harmless 63 | ------------------- 64 | 65 | Description 66 | ^^^^^^^^^^^ 67 | 68 | The task is to improve both helpfulness and harmlessness of the 69 | model's outputs following Anthropic's paper `Training a Helpful and Harmless 70 | Assistant with Reinforcement Learning from Human Feedback 71 | `_ 72 | 73 | Training 74 | ^^^^^^^^ 75 | 76 | The training is done by either utilizing a reward model trained on the 77 | Anthropic's Helpful & Harmless `dataset 78 | `_ using `PPO trainer 79 | `_, or by 80 | using the dataset directly by reward labeling each selected and rejected with 81 | `+1` and `-1` respectively using `ILQL trainer 82 | `_, or using 83 | `SFT trainer 84 | `_ and 85 | finetuning only over selected responses. 86 | 87 | The setup used for this example assumes a single machine with 8xA100 80GB, the 88 | last of which will be dedicated to hosting a reward model. 
Optionally you can 89 | use `Triton Inference Server `_ to 90 | host it elsewhere, otherwise the training script will instantiate it (`a 91 | pretrained one `_) on its own. 92 | 93 | Launch training of `GPT-J `_ on 7 94 | GPUs with 8th GPU hosting a reward model: 95 | 96 | .. code-block:: console 97 | 98 | accelerate launch --num_processes 7 --config_file ../../configs/accelerate/zero2-bf16.yaml ppo_hh.py 99 | # or for training from other predefined checkpoint 100 | CONFIG_NAME=125M accelerate launch --num_processes 7 --config_file ../../configs/accelerate/zero2-bf16.yaml ppo_hh.py 101 | 102 | Optional steps to setup a reward model using Triton Server: 103 | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 104 | 105 | .. code-block:: console 106 | 107 | # convert the model and create a config and a folder `model_store` structured for Triton 108 | python to_triton.py --base_model EleutherAI/gpt-j-6B --checkpoint Dahoas/gptj-rm-static --revision 676bfd4d 109 | 110 | # convert the docker image (skip this if you use docker instead) 111 | singularity build --sandbox tritonserver-pyt.sif docker://nvcr.io/nvidia/tritonserver:22.08-pyt-python-py3 112 | 113 | # start Triton Server pointing to the `model_store` containing the reward model 114 | SINGULARITYENV_CUDA_VISIBLE_DEVICES=7 singularity run --nv --bind model_store:/model_store tritonserver-pyt.sif tritonserver --model-repository=/model_store & 115 | 116 | Launch training: 117 | 118 | .. code-block:: console 119 | 120 | # set model's url and replace the name after the slash if you use a different checkpoint 121 | export TRITON_HOST=localhost:8001/gptj-rm-static 122 | accelerate launch --num_processes 7 --config_file ../../configs/accelerate/zero2-bf16.yaml ppo_hh.py 123 | 124 | W&B runs: 125 | 126 | - PPO GPT-J: https://wandb.ai/sorry/trlx/runs/v0bir5s9 127 | - ILQL GPT-J: https://wandb.ai/sorry/trlx/runs/1qqxp72a 128 | - SFT GPT-J: https://wandb.ai/sorry/trlx/runs/a7ng078v 129 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to trlX's documentation! 2 | ================================ 3 | trlX is a library for training large language models with reinforcement learning. Training can be done with two RL algorithms: PPO (`Schulman et al. 2017 `_) for online training and ILQL (`Snell et al. 2022 `_) for offline training. For distributed training two backends are supported: `Huggingface 🤗 Accelerate `_ and `NVIDIA NeMo `_. 4 | 5 | .. toctree:: 6 | :maxdepth: 2 7 | :caption: Contents: 8 | 9 | installation 10 | api 11 | examples 12 | configs 13 | trainers 14 | pipelines 15 | data 16 | -------------------------------------------------------------------------------- /docs/source/installation.rst: -------------------------------------------------------------------------------- 1 | .. _installation: 2 | 3 | Installation 4 | ============ 5 | 6 | trlX is a pure Python library that supports two optional distributed backends: `Huggingface 🤗 Accelerate `_ and `NVIDIA NeMo `_, the latter is optional and can be installed separately. 7 | 8 | Requirements 9 | ------------ 10 | 11 | * OS: Linux 12 | * Python: 3.9-3.11 13 | 14 | Install with pip 15 | ---------------- 16 | 17 | You can install trlX using pip: 18 | 19 | .. code-block:: console 20 | 21 | $ pip install -U git+https://github.com/CarperAI/trlx.git 22 | 23 | .. 
_build_from_source: 24 | 25 | Install from source 26 | ------------------- 27 | 28 | You can also install trlX from source: 29 | 30 | .. code-block:: console 31 | 32 | $ git clone https://github.com/CarperAI/trlx.git 33 | $ cd trlx 34 | $ pip install torch --extra-index-url https://download.pytorch.org/whl/cu118 35 | $ pip install -e . 36 | 37 | Install NeMo 38 | ____________ 39 | 40 | Install NeMo version v1.17.0: 41 | 42 | .. code-block:: console 43 | 44 | $ git clone https://github.com/NVIDIA/NeMo/ 45 | $ cd NeMo 46 | $ git checkout d3017e4 47 | $ pip install -e '.[all]' 48 | 49 | Install Apex: 50 | 51 | .. code-block:: console 52 | 53 | $ git clone https://github.com/NVIDIA/apex 54 | $ cd apex 55 | $ # if pip >= 23.1 (ref: https://pip.pypa.io/en/stable/news/#v23-1) which supports multiple `--config-settings` with the same key... 56 | $ pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ 57 | -------------------------------------------------------------------------------- /docs/source/pipelines.rst: -------------------------------------------------------------------------------- 1 | .. _pipeline: 2 | 3 | Pipelines 4 | ========= 5 | 6 | Pipelines are used for accumulation and convertion of the training data to appropriate format. 7 | 8 | .. autoclass:: trlx.pipeline.BasePipeline 9 | :members: 10 | 11 | .. autoclass:: trlx.pipeline.BaseRolloutStore 12 | :members: 13 | 14 | .. autoclass:: trlx.pipeline.offline_pipeline.DialogMessage 15 | :members: 16 | 17 | .. autoclass:: trlx.pipeline.offline_pipeline.DialogStore 18 | :members: 19 | 20 | .. autofunction:: trlx.pipeline.offline_pipeline.tokenize_dialogue 21 | 22 | .. autoclass:: trlx.pipeline.ppo_pipeline.PPORolloutStorage 23 | :members: 24 | 25 | .. autoclass:: trlx.pipeline.offline_pipeline.PromptPipeline 26 | :members: 27 | 28 | .. autoclass:: trlx.pipeline.offline_pipeline.ILQLRolloutStorage 29 | :members: 30 | 31 | .. autoclass:: trlx.pipeline.offline_pipeline.ILQLSeq2SeqRolloutStorage 32 | :members: 33 | -------------------------------------------------------------------------------- /docs/source/trainers.rst: -------------------------------------------------------------------------------- 1 | .. _trainers: 2 | 3 | Trainers 4 | ======== 5 | 6 | Abstract Trainers 7 | ----------------- 8 | 9 | .. autoclass:: trlx.trainer.BaseRLTrainer 10 | :members: 11 | 12 | .. autoclass:: trlx.trainer.accelerate_base_trainer.AccelerateRLTrainer 13 | :members: 14 | 15 | Accelerate Trainers 16 | ------------------- 17 | 18 | .. autoclass:: trlx.trainer.accelerate_ppo_trainer.AcceleratePPOTrainer 19 | :members: 20 | 21 | .. autoclass:: trlx.trainer.accelerate_ilql_trainer.AccelerateILQLTrainer 22 | :members: 23 | 24 | .. autoclass:: trlx.trainer.accelerate_sft_trainer.AccelerateSFTTrainer 25 | :members: 26 | 27 | NeMo Trainers 28 | ------------- 29 | 30 | .. autoclass:: trlx.trainer.nemo_ppo_trainer.NeMoPPOTrainer 31 | :members: 32 | 33 | .. autoclass:: trlx.trainer.nemo_ilql_trainer.NeMoILQLTrainer 34 | :members: 35 | 36 | .. 
autoclass:: trlx.trainer.nemo_sft_trainer.NeMoSFTTrainer 37 | :members: 38 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CarperAI/trlx/3340c2f3a56d1d14fdd5f13ad575121fa26b6d92/examples/__init__.py -------------------------------------------------------------------------------- /examples/alpaca/README.md: -------------------------------------------------------------------------------- 1 | ## Alpaca 2 | 3 | Finetune a model on [Alpaca](https://huggingface.co/datasets/tatsu-lab/alpaca) 4 | ```bash 5 | python sft_alpaca.py --model_name EleutherAI/gpt-j-6B --dataset tatsu-lab/alpaca 6 | ``` 7 | 8 | Finetune a model on [Alpaca-Cleaned](https://huggingface.co/datasets/yahma/alpaca-cleaned) 9 | ```bash 10 | python sft_alpaca.py --model_name EleutherAI/gpt-j-6B --dataset yahma/alpaca-cleaned 11 | ``` 12 | -------------------------------------------------------------------------------- /examples/alpaca/sft_alpaca.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from argparse import ArgumentParser 4 | from typing import Dict, List 5 | 6 | from datasets import load_dataset 7 | from transformers import pipeline 8 | 9 | import trlx 10 | from trlx.data.default_configs import TRLConfig, default_sft_config 11 | 12 | 13 | def get_positive_score(scores): 14 | "Extract value associated with a positive sentiment from pipeline's output" 15 | return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] 16 | 17 | 18 | def preprocess(instruction: str, input: str, output: str): 19 | """Build Alpaca prompt and output from instruction and input/output examples""" 20 | if input: 21 | prefix = ( 22 | "Below is an instruction that describes a task, paired with an input that provides further context. " 23 | "Write a response that appropriately completes the request." 24 | ) 25 | prompt = f"{prefix}\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" 26 | return [prompt, output] 27 | else: 28 | prefix = ( 29 | "Below is an instruction that describes a task. Write a response that appropriately completes the request." 
30 | ) 31 | prompt = f"{prefix}\n\n### Instruction:\n{instruction}\n\n### Response:\n" 32 | return [prompt, output] 33 | 34 | 35 | def main(hparams={}, model_name="EleutherAI/gpt-j-6B", dataset="tatsu-lab/alpaca"): 36 | config = default_sft_config() 37 | config = config.evolve( 38 | train=dict( 39 | total_steps=2400, 40 | batch_size=4, 41 | seq_length=1024, 42 | ), 43 | model=dict( 44 | model_path=model_name, 45 | ), 46 | tokenizer=dict( 47 | tokenizer_path=model_name, 48 | ), 49 | optimizer=dict(kwargs=dict(lr=2e-5)), 50 | scheduler=dict(kwargs=dict(eta_min=2e-5)), 51 | method=dict( 52 | gen_kwargs=dict( 53 | max_new_tokens=256, 54 | ) 55 | ), 56 | ) 57 | 58 | # Merge sweep config with default config if given 59 | config = TRLConfig.update(config.to_dict(), hparams) 60 | 61 | # alpaca = load_dataset("tatsu-lab/alpaca", split="train") 62 | alpaca = load_dataset(dataset, split="train") 63 | alpaca = [preprocess(x["instruction"], x["input"], x["output"]) for x in alpaca] 64 | 65 | sentiment_fn = pipeline( 66 | "sentiment-analysis", 67 | "lvwerra/distilbert-imdb", 68 | top_k=2, 69 | truncation=True, 70 | batch_size=256, 71 | device=0 if int(os.environ.get("LOCAL_RANK", 0)) == 0 else -1, 72 | ) 73 | 74 | def metric_fn(samples: List[str], prompts: List[str], outputs: List[str]) -> Dict[str, List[float]]: 75 | sentiments = list(map(get_positive_score, sentiment_fn(outputs))) 76 | return {"sentiments": sentiments} 77 | 78 | imdb = load_dataset("imdb", split="test") 79 | bad_reviews = imdb.filter(lambda sample: sample["label"] == 0).select(range(256)) 80 | zs_rewrite = [preprocess("Rewrite the input into a positive review.", x["text"][:1024], "")[0] for x in bad_reviews] 81 | 82 | trainer = trlx.train( 83 | samples=alpaca, 84 | eval_prompts=zs_rewrite, 85 | metric_fn=metric_fn, 86 | config=config, 87 | ) 88 | 89 | slug = f"{model_name.split('/')[-1]}-{dataset.split('/')[-1]}" 90 | trainer.save_pretrained(f"{slug}-sft") 91 | 92 | 93 | if __name__ == "__main__": 94 | parser = ArgumentParser() 95 | parser.add_argument("override_hparams", type=str, default="{}", nargs="?") 96 | parser.add_argument("--model_name", type=str, default="EleutherAI/gpt-j-6B") 97 | parser.add_argument("--dataset", type=str, default="tatsu-lab/alpaca") 98 | 99 | args = parser.parse_args() 100 | hparams = json.loads(args.override_hparams) 101 | 102 | main(hparams, args.model_name, args.dataset) 103 | -------------------------------------------------------------------------------- /examples/architext.py: -------------------------------------------------------------------------------- 1 | # Toy example of optimizing textual interior designs to output the least number of rooms 2 | # Also see https://architext.design/ 3 | import trlx 4 | from trlx.data.default_configs import default_ppo_config 5 | 6 | 7 | def reward_fn(samples, **kwargs): 8 | "Gives a negative count of rooms for each sample" 9 | return [-sample.count(":") for sample in samples] 10 | 11 | 12 | prompts = [ 13 | "[prompt] the bedroom is adjacent to the living room [layout]", 14 | "[prompt] a bedroom is adjacent to the living room [layout]", 15 | "[prompt] the bedroom is adjacent to the kitchen [layout]", 16 | "[prompt] a bedroom is adjacent to the kitchen [layout]", 17 | "[prompt] the bedroom is adjacent to the kitchen [layout]", 18 | "[prompt] the kitchen is adjacent to the bathroom [layout]", 19 | "[prompt] a bathroom is adjacent to the living room [layout]", 20 | "[prompt] the bathroom is adjacent to the living room [layout]", 21 | "[prompt] the bedroom is not 
adjacent to the living room [layout]", 22 | "[prompt] a bedroom is not adjacent to the living room [layout]", 23 | "[prompt] the bedroom is not adjacent to the kitchen [layout]", 24 | "[prompt] a bedroom is not adjacent to the kitchen [layout]", 25 | "[prompt] the bedroom is not adjacent to the kitchen [layout]", 26 | "[prompt] the kitchen is not adjacent to the bathroom [layout]", 27 | ] 28 | 29 | 30 | def main(): 31 | config = default_ppo_config() 32 | 33 | trlx.train(model_path="architext/gptj-162M", reward_fn=reward_fn, prompts=prompts, config=config) 34 | 35 | 36 | if __name__ == "__main__": 37 | main() 38 | -------------------------------------------------------------------------------- /examples/experiments/grounded_program_synthesis/README.md: -------------------------------------------------------------------------------- 1 | # Interpreter Grounded Program Synthesis 2 | *Program synthesis* is the task of automatically generating programs that solve a given task by satisfying an IO condition. In neural program synthesis, the synthesizer is a neural network (here, a language model) that takes in an input/output pair and tries to generate a program in the toy DSL grammar defined below. 3 | 4 | ## Toy List Manipulation DSL Grammar 5 | The DSL has the following grammar: 6 | ``` 7 | list_expr := list[int] 8 | integer := -5 | -4 | -3 | -2 | -1 | 0 | 1 | 2 | 3 | 4 | 5 9 | statement := 10 | | take(list_expr,integer) 11 | | drop(list_expr,integer) 12 | | reverse(list_expr) 13 | | sort_asc(list_expr) 14 | | sort_des(list_expr) 15 | | add_n(list_expr,integer) 16 | | sub_n(list_expr,integer) 17 | | mul_n(list_expr,integer) 18 | | expand_copy(list_expr) 19 | 20 | 21 | ``` 22 | For example, the program `add_n(reverse([-2, -5, -4]),1)` reverses the list and adds one to each element, giving `[-3,-4,-1]` (a plain-Python sketch of these operations is included at the end of this README). 23 | More examples are shown below: 24 | ``` 25 | take([1,2,3],2) -> [1,2] 26 | drop([1,2,3],2) -> [1] 27 | reverse([1,2,3]) -> [3,2,1] 28 | sort_asc([10,5,6]) -> [5,6,10] 29 | sort_des([10,5,6]) -> [10,6,5] 30 | 31 | ``` 32 | To generate training/testing data, run `python3 -m lang`. The dataset will be saved in `./dataset/train.json` and `./dataset/test.json`. A pre-processed copy of the dataset is available at this [Google Drive link](https://drive.google.com/drive/folders/1093FlJA0MF7gh25yi4-__yU6Fj-onK1v?usp=share_link). 33 | Each datapoint in the dataset looks like this: 34 | ```json 35 | {"input": "Input: [4, -2, 0, 0, 5, 5] Output: [25, 25, 20, 0, 0, -10] Function:", 36 | "output": "sort_des(reverse(mul_n(sort_asc(sort_asc([4, -2, 0, 0, 5, 5])),5)))"} 37 | ``` 38 | ## Caveat on DSL design 39 | The DSL designed here is a very simple toy example in which every function returns type `list`; a real-world list-manipulation DSL would typically be more complex, with additional types such as strings. 40 | ## Training with TRLX 41 | Run `python3 train_trlx.py` to start training with the grounded interpreter. The `reward_fn` returns `-1` if a generated sample has invalid syntax, `-0.5` if the syntax is valid but the program does not satisfy the IO condition, and `1` if the program's output matches the target.
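Below is a minimal plain-Python sketch of the list operations, inferred from the grammar and the examples above; it is not the repository's `Interpreter` from `lang.py`, and `expand_copy` is omitted because its behaviour is not illustrated in this README:
```python
# Rough sketch of the DSL's list operations, inferred from the examples in this README.
def take(xs, n): return xs[:n]                     # take([1,2,3],2)    -> [1,2]
def drop(xs, n): return xs[:len(xs) - n]           # drop([1,2,3],2)    -> [1]
def reverse(xs): return xs[::-1]                   # reverse([1,2,3])   -> [3,2,1]
def sort_asc(xs): return sorted(xs)                # sort_asc([10,5,6]) -> [5,6,10]
def sort_des(xs): return sorted(xs, reverse=True)  # sort_des([10,5,6]) -> [10,6,5]
def add_n(xs, n): return [x + n for x in xs]
def sub_n(xs, n): return [x - n for x in xs]
def mul_n(xs, n): return [x * n for x in xs]

# The worked example from above: reverse the list, then add one to every element.
assert add_n(reverse([-2, -5, -4]), 1) == [-3, -4, -1]

# The sample datapoint: its "output" program maps the input list to the target output list.
assert sort_des(reverse(mul_n(sort_asc(sort_asc([4, -2, 0, 0, 5, 5])), 5))) == [25, 25, 20, 0, 0, -10]
```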
42 | -------------------------------------------------------------------------------- /examples/experiments/grounded_program_synthesis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CarperAI/trlx/3340c2f3a56d1d14fdd5f13ad575121fa26b6d92/examples/experiments/grounded_program_synthesis/__init__.py -------------------------------------------------------------------------------- /examples/experiments/grounded_program_synthesis/configs/trlx_ppo_config.yml: -------------------------------------------------------------------------------- 1 | train: 2 | seq_length: 256 3 | epochs: 10 4 | total_steps: 80000 5 | batch_size: 8 6 | 7 | checkpoint_interval: 1000000 8 | eval_interval: 16 9 | 10 | pipeline: "PromptPipeline" 11 | trainer: "AcceleratePPOTrainer" 12 | 13 | model: 14 | model_path: "reshinthadith/codegen_350M_list_manip_5_len" 15 | num_layers_unfrozen: 2 16 | 17 | tokenizer: 18 | tokenizer_path: "reshinthadith/codegen_350M_list_manip_5_len" 19 | 20 | optimizer: 21 | name: "adamw" 22 | kwargs: 23 | lr: 1.412e-4 24 | betas: [0.9, 0.95] 25 | eps: 1.0e-8 26 | weight_decay: 1.0e-6 27 | 28 | scheduler: 29 | name: "cosine_annealing" 30 | kwargs: 31 | T_max: 80000 # train.total_steps 32 | eta_min: 1.412e-4 33 | 34 | method: 35 | name: "ppoconfig" 36 | num_rollouts: 8 37 | chunk_size: 8 38 | ppo_epochs: 4 39 | init_kl_coef: 0.2 40 | target: 6 41 | horizon: 10000 42 | gamma: 1 43 | lam: 0.95 44 | cliprange: 0.2 45 | cliprange_value: 0.2 46 | vf_coef: 0.2 47 | scale_reward: False 48 | cliprange_reward: 10 49 | ref_mean: null 50 | ref_std: null 51 | gen_kwargs: 52 | max_new_tokens: 256 53 | top_k: 0 54 | top_p: 0.7 55 | do_sample: True 56 | temperature: 0.5 57 | -------------------------------------------------------------------------------- /examples/experiments/grounded_program_synthesis/train_trlx.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import pathlib 4 | 5 | import yaml 6 | from lang import Interpreter 7 | 8 | import trlx 9 | from trlx.data.configs import TRLConfig 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | class DSLDataset: 15 | def __init__(self): 16 | with open("dataset/train.json", "r") as f: 17 | self.train_data = json.load(f) 18 | with open("dataset/test.json", "r") as f: 19 | self.test_data = json.load(f) 20 | logger.info("Sucessfully loaded the dataset") 21 | 22 | def load_datapoints(self, split="train"): 23 | if split == "train": 24 | for datapoint in self.train_data: 25 | if "ERROR" not in datapoint["input"]: 26 | yield datapoint["input"] 27 | elif split == "test": 28 | for datapoint in self.test_data: 29 | yield datapoint["input"] 30 | 31 | 32 | interpreter = Interpreter() 33 | 34 | 35 | def reward_fn(samples, **kwargs): 36 | reward_list = [] 37 | for sample in samples: 38 | code = sample.split("Function:")[1].strip() 39 | output = eval(sample.split("Output:")[1].strip().split("Function:")[0].strip()) 40 | interpreted_output = interpreter(code) 41 | if interpreted_output == "ERROR": 42 | # If the code is unparsable, we give it a negative reward. 43 | reward_list.append(-1) 44 | else: 45 | # if the code is parseable 46 | if output == interpreted_output: 47 | # if the output is correct, we give it a positive reward. 48 | reward_list.append(1) 49 | else: 50 | # if the output is incorrect, we give it a negative reward. 
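# (to summarize the shaping: unparsable sample -> -1.0, parseable but wrong output -> -0.5,
# output matching the target -> 1.0; the asserts under __main__ below check exactly these values)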
51 | reward_list.append(-0.5) 52 | 53 | return reward_list 54 | 55 | 56 | config_path = pathlib.Path(__file__).parent.joinpath("configs/trlx_ppo_config.yml") 57 | with config_path.open() as f: 58 | default_config = yaml.safe_load(f) 59 | 60 | 61 | def main(hparams={}): 62 | config = TRLConfig.update(default_config, hparams) 63 | 64 | # Dataset 65 | dataset = DSLDataset() 66 | train_prompts = list(dataset.load_datapoints(split="train"))[:1000] 67 | 68 | trainer = trlx.train( 69 | reward_fn=reward_fn, 70 | prompts=train_prompts, 71 | config=config, 72 | ) 73 | trainer.save_pretrained("dataset/trained_model") 74 | 75 | 76 | if __name__ == "__main__": 77 | # TEST REWARD FUNTION 78 | assert (reward_fn(["Input: 1 Output: [-4,-5,-2] Function: div_n(reverse([-2, -5, -4]),1)"])) == [1] 79 | assert (reward_fn(["Input: 1 Output: [-4,-5,-2] Function: div_n(reverse([-2, -5, -a]),1)"])) == [-1] 80 | assert (reward_fn(["Input: 1 Output: [-4,-5,-2] Function: div_n(reverse([-2, -5, -3]),1)"])) == [-0.5] 81 | 82 | main() 83 | -------------------------------------------------------------------------------- /examples/hh/README.md: -------------------------------------------------------------------------------- 1 | ### Training on Anthropic's Helpful & Harmless [dataset](https://github.com/anthropics/hh-rlhf) 2 | 3 | As an example, the following setup assumes a single machine with 8xA100 80GB, the last of which will be dedicated to hosting a reward model. Optionally you can use [Triton Inference Server](https://github.com/triton-inference-server) to host it elsewhere, otherwise the training script will instantiate it ([a default one](https://huggingface.co/Dahoas/gptj-rm-static)) on its own. 4 | 5 | Launch training of [GPT-J](https://huggingface.co/EleutherAI/gpt-j-6B) on 7 GPUs with 8th GPU hosting a reward model: 6 | ```sh 7 | accelerate launch --num_processes 7 --config_file ../../configs/accelerate/zero2-bf16.yaml ppo_hh.py 8 | ``` 9 | Or if you want to train a smaller model or start from a supervised checkpoint, you can use one of the [configs](../../configs) 10 | ```sh 11 | CONFIG_NAME=125M accelerate launch --num_processes 7 --config_file ../../configs/accelerate/zero2-bf16.yaml ppo_hh.py 12 | ``` 13 | 14 | Already trained models are hosted on https://huggingface.co/reciprocate 15 | 16 | #### Optional steps to setup a reward model (trained with [Dahoas/reward-modeling](https://github.com/Dahoas/reward-modeling)) with Triton Server: 17 | 18 | ```sh 19 | # convert the model and create a config and a folder `model_store` structured for Triton 20 | python to_triton.py --base_model EleutherAI/gpt-j-6B --checkpoint Dahoas/gptj-rm-static --revision 676bfd4d 21 | 22 | # convert the docker image (skip this if you use docker instead) 23 | singularity build --sandbox tritonserver-pyt.sif docker://nvcr.io/nvidia/tritonserver:22.08-pyt-python-py3 24 | ``` 25 | 26 | ```sh 27 | # start Triton Server pointing to the `model_store` containing the reward model 28 | SINGULARITYENV_CUDA_VISIBLE_DEVICES=7 singularity run --nv --bind model_store:/model_store tritonserver-pyt.sif tritonserver --model-repository=/model_store & 29 | 30 | # set model's url and replace the name after the slash if you use a different checkpoint 31 | export TRITON_HOST=localhost:8001/gptj-rm-static 32 | 33 | # launch training 34 | accelerate launch --num_processes 7 --config_file ../../configs/accelerate/zero2-bf16.yaml ppo_hh.py 35 | ``` 36 | 37 | #### Sample W&B runs 38 | 39 | PPO GPT-J: https://wandb.ai/sorry/trlx/runs/v0bir5s9 40 | 41 | ILQL 
GPT-J: https://wandb.ai/sorry/trlx/runs/1qqxp72a 42 | 43 | SFT GPT-J: https://wandb.ai/sorry/trlx/runs/a7ng078v 44 | -------------------------------------------------------------------------------- /examples/hh/ilql_hh.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | from itertools import islice 5 | 6 | from datasets import load_dataset 7 | from ppo_hh import create_reward_fn 8 | 9 | import trlx 10 | from trlx.data.default_configs import ( 11 | ILQLConfig, 12 | ModelConfig, 13 | OptimizerConfig, 14 | SchedulerConfig, 15 | TokenizerConfig, 16 | TrainConfig, 17 | TRLConfig, 18 | ) 19 | 20 | default_config = TRLConfig( 21 | train=TrainConfig( 22 | seq_length=1024, 23 | batch_size=4, 24 | epochs=100, 25 | total_steps=20000, 26 | checkpoint_interval=10000, 27 | eval_interval=1000, 28 | pipeline="PromptPipeline", 29 | trainer="AccelerateILQLTrainer", 30 | checkpoint_dir="checkpoints/ilql_hh", 31 | ), 32 | model=ModelConfig(model_path="EleutherAI/gpt-j-6B", num_layers_unfrozen=-1), 33 | tokenizer=TokenizerConfig(tokenizer_path="EleutherAI/gpt-j-6B", truncation_side="left"), 34 | optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=1e-6, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)), 35 | scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=1000000000, eta_min=1e-6)), 36 | method=ILQLConfig( 37 | name="ilqlconfig", 38 | tau=0.6, 39 | gamma=0.99, 40 | cql_scale=0.1, 41 | awac_scale=1, 42 | alpha=0.0001, 43 | beta=0, 44 | steps_for_target_q_sync=1, 45 | two_qs=True, 46 | gen_kwargs=dict(max_new_tokens=128, top_k=20, beta=[1, 4], temperature=1.0), 47 | ), 48 | ) 49 | 50 | config_name = os.environ.get("CONFIG_NAME") 51 | if config_name == "125M": 52 | default_config.train.batch_size = 16 53 | default_config.train.checkpoint_dir = "checkpoints/ilql_hh_125M" 54 | default_config.model.model_path = "EleutherAI/pythia-125m-deduped" 55 | default_config.tokenizer.tokenizer_path = "EleutherAI/gpt-neox-20b" 56 | elif config_name == "1B": 57 | default_config.train.batch_size = 8 58 | default_config.train.checkpoint_dir = "checkpoints/ilql_hh_1B" 59 | default_config.model.model_path = "EleutherAI/pythia-1.4b-deduped" 60 | default_config.tokenizer.tokenizer_path = "EleutherAI/gpt-neox-20b" 61 | elif config_name == "6B": 62 | default_config.train.batch_size = 4 63 | default_config.train.checkpoint_dir = "checkpoints/ilql_hh_6B" 64 | default_config.model.model_path = "EleutherAI/pythia-6.9b-deduped" 65 | default_config.tokenizer.tokenizer_path = "EleutherAI/gpt-neox-20b" 66 | elif config_name == "20B": 67 | default_config.train.batch_size = 1 68 | default_config.train.total_steps = 3000 69 | default_config.train.checkpoint_dir = "checkpoints/ilql_hh_20B" 70 | default_config.model.model_path = "EleutherAI/gpt-neox-20b" 71 | default_config.tokenizer.tokenizer_path = "EleutherAI/gpt-neox-20b" 72 | 73 | 74 | def preprocess(sample): 75 | sample["prompt_output"] = [ 76 | [sample["prompt"], sample["chosen"]], 77 | [sample["prompt"], sample["rejected"]], 78 | ] 79 | sample["reward"] = [1, -1] 80 | return sample 81 | 82 | 83 | def main(hparams={}): 84 | config = TRLConfig.update(default_config, hparams) 85 | 86 | dataset = load_dataset("Dahoas/full-hh-rlhf").map(preprocess) 87 | prompts_outputs = sum(dataset["train"]["prompt_output"], []) 88 | 89 | rewards = sum(dataset["train"]["reward"], []) 90 | eval_prompts = [{"prompt": x["prompt"], "original_output": x["chosen"]} for x in islice(dataset["test"], 280)] 91 | 
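# At this point `prompts_outputs` is a flat list of [prompt, completion] pairs and `rewards` holds
# one scalar per pair (+1 for the chosen completion, -1 for the rejected one, as set in `preprocess`
# above); this is the offline data ILQL trains on. `reward_fn` below is only used through `metric_fn`
# to score generations on the evaluation prompts.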
reward_fn = create_reward_fn() 92 | 93 | trlx.train( 94 | samples=prompts_outputs, 95 | rewards=rewards, 96 | config=config, 97 | eval_prompts=eval_prompts, 98 | metric_fn=lambda **kwargs: {"reward": reward_fn(**kwargs)}, 99 | stop_sequences=["Human:", "human:", "Assistant:", "assistant:"], 100 | ) 101 | 102 | 103 | if __name__ == "__main__": 104 | hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1]) 105 | main(hparams) 106 | -------------------------------------------------------------------------------- /examples/hh/sft_hh.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | 4 | from datasets import load_dataset 5 | from ppo_hh import create_reward_fn 6 | 7 | import trlx 8 | from trlx.data.default_configs import ( 9 | ModelConfig, 10 | OptimizerConfig, 11 | SchedulerConfig, 12 | SFTConfig, 13 | TokenizerConfig, 14 | TrainConfig, 15 | TRLConfig, 16 | ) 17 | 18 | default_config = TRLConfig( 19 | train=TrainConfig( 20 | seq_length=1024, 21 | epochs=100, 22 | total_steps=10000, 23 | batch_size=4, 24 | checkpoint_interval=10000, 25 | eval_interval=1000, 26 | pipeline="PromptPipeline", 27 | trainer="AccelerateSFTTrainer", 28 | checkpoint_dir="checkpoints/sft_hh", 29 | ), 30 | model=ModelConfig(model_path="EleutherAI/gpt-j-6B", num_layers_unfrozen=-1), 31 | tokenizer=TokenizerConfig(tokenizer_path="EleutherAI/gpt-j-6B", truncation_side="left"), 32 | optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=1e-6, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)), 33 | scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=100000000, eta_min=1e-6)), 34 | method=SFTConfig( 35 | name="sftconfig", 36 | gen_kwargs=dict(max_new_tokens=128, top_k=20, top_p=1.0, do_sample=True), 37 | ), 38 | ) 39 | 40 | 41 | def preprocess(sample): 42 | sample["chosen_sample"] = sample["prompt"] + sample["chosen"] 43 | return sample 44 | 45 | 46 | def main(hparams={}): 47 | config = TRLConfig.update(default_config, hparams) 48 | 49 | dataset = load_dataset("Dahoas/full-hh-rlhf").map(preprocess) 50 | reward_fn = create_reward_fn() 51 | 52 | trlx.train( 53 | config=config, 54 | samples=dataset["train"]["chosen_sample"], 55 | eval_prompts=dataset["test"]["prompt"][:280], 56 | metric_fn=lambda **kwargs: {"reward": reward_fn(**kwargs)}, 57 | stop_sequences=["Human:", "human:", "Assistant:", "assistant:"], 58 | ) 59 | 60 | 61 | if __name__ == "__main__": 62 | hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1]) 63 | main(hparams) 64 | -------------------------------------------------------------------------------- /examples/hh/to_triton.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from string import Template 4 | 5 | import torch 6 | from huggingface_hub import snapshot_download 7 | from torch import nn 8 | from transformers import AutoModelForCausalLM, AutoTokenizer 9 | 10 | parser = argparse.ArgumentParser() 11 | 12 | parser.add_argument("--base_model", type=str, required=True, help="Path to HF checkpoint with the base model") 13 | 14 | parser.add_argument( 15 | "--checkpoint", 16 | type=str, 17 | required=True, 18 | help="Path to either a local directory or a HF checkpoint with reward model's weights", 19 | ) 20 | 21 | parser.add_argument("--revision", type=str, required=False, help="Optional branch/commit of the HF checkpoint") 22 | 23 | parser.add_argument("--device", type=int, default=0) 24 | args = parser.parse_args() 25 | 26 | model_name = 
args.checkpoint.split("/")[-1] 27 | device = torch.device(args.device) 28 | 29 | 30 | class RewardModel(nn.Module): 31 | def __init__(self, checkpoint_path, eos_token_id): 32 | super().__init__() 33 | model = AutoModelForCausalLM.from_pretrained(checkpoint_path) 34 | self.transformer = model.transformer 35 | self.v_head = nn.Linear(model.config.n_embd, 1, bias=False) 36 | self.eos_token_id = eos_token_id 37 | 38 | def forward(self, input_ids): 39 | states = self.transformer(input_ids)[0] 40 | rewards = self.v_head(states).squeeze(-1) 41 | ends = torch.argmax((input_ids == self.eos_token_id).float(), dim=1).view(-1, 1) 42 | returns = torch.gather(rewards, 1, ends).squeeze(-1) 43 | return returns 44 | 45 | 46 | if os.path.isdir(args.checkpoint): 47 | directory = args.checkpoint 48 | else: 49 | directory = snapshot_download(args.checkpoint, revision=args.revision) 50 | 51 | print(f"searching through {os.listdir(directory)} in {directory}") 52 | 53 | for fpath in os.listdir(directory): 54 | if fpath.endswith(".pt") or fpath.endswith(".bin"): 55 | checkpoint = os.path.join(directory, fpath) 56 | break 57 | 58 | tokenizer = AutoTokenizer.from_pretrained(args.base_model) 59 | model = RewardModel(args.base_model, tokenizer.eos_token_id) 60 | model.load_state_dict(torch.load(checkpoint)) 61 | model.eval() 62 | model.requires_grad_(False) 63 | model = model.half().to(device) 64 | 65 | input = tokenizer("reward model's hash", return_tensors="pt").to(device) 66 | print(f"{model(input.input_ids)=}") 67 | 68 | traced_script_module = torch.jit.trace(model, input.input_ids) 69 | 70 | os.makedirs(f"model_store/{model_name}/1", exist_ok=True) 71 | traced_script_module.save(f"model_store/{model_name}/1/traced-model.pt") 72 | 73 | config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "triton_config.pbtxt") 74 | with open(config_path) as f: 75 | template = Template(f.read()) 76 | config = template.substitute({"model_name": model_name}) 77 | with open(f"model_store/{model_name}/config.pbtxt", "w") as f: 78 | f.write(config) 79 | -------------------------------------------------------------------------------- /examples/hh/triton_config.pbtxt: -------------------------------------------------------------------------------- 1 | name: "${model_name}" 2 | backend: "pytorch" 3 | default_model_filename: "traced-model.pt" 4 | max_batch_size: 25 5 | 6 | parameters { 7 | key: "model_name" 8 | value: { 9 | string_value: "${model_name}" 10 | } 11 | } 12 | 13 | instance_group [ 14 | { 15 | count: 1 16 | kind: KIND_GPU 17 | gpus: [0] 18 | } 19 | ] 20 | 21 | input [ 22 | { 23 | name: "input_ids" 24 | data_type: TYPE_INT32 25 | dims: [-1] 26 | } 27 | ] 28 | 29 | output [ 30 | { 31 | name: "rewards" 32 | data_type: TYPE_FP16 33 | dims: [-1] 34 | } 35 | ] 36 | 37 | parameters { 38 | key: "data_type" 39 | value: { 40 | string_value: "fp16" 41 | } 42 | } 43 | 44 | parameters: { 45 | key: "INFERENCE_MODE" 46 | value: { 47 | string_value: "true" 48 | } 49 | } 50 | 51 | version_policy: {specific: {versions: [1]}} 52 | -------------------------------------------------------------------------------- /examples/ilql_sentiments.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | from typing import Dict, List 5 | 6 | from datasets import load_dataset 7 | from transformers import pipeline 8 | 9 | import trlx 10 | from trlx.data.default_configs import TRLConfig, default_ilql_config 11 | 12 | 13 | def get_positive_score(scores): 14 | "Extract 
value associated with a positive sentiment from pipeline's output" 15 | return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] 16 | 17 | 18 | def main(hparams={}): 19 | # Merge sweep config with default config if given 20 | config = TRLConfig.update(default_ilql_config().to_dict(), hparams) 21 | 22 | sentiment_fn = pipeline( 23 | "sentiment-analysis", 24 | "lvwerra/distilbert-imdb", 25 | top_k=2, 26 | truncation=True, 27 | batch_size=256, 28 | device=0 if int(os.environ.get("LOCAL_RANK", 0)) == 0 else -1, 29 | ) 30 | 31 | def metric_fn(samples: List[str], **kwargs) -> Dict[str, List[float]]: 32 | sentiments = list(map(get_positive_score, sentiment_fn(samples))) 33 | return {"sentiments": sentiments} 34 | 35 | imdb = load_dataset("imdb", split="train+test") 36 | 37 | trlx.train( 38 | samples=imdb["text"], 39 | rewards=imdb["label"], 40 | eval_prompts=["I don't know much about Hungarian underground"] * 256, 41 | metric_fn=metric_fn, 42 | config=config, 43 | ) 44 | 45 | 46 | if __name__ == "__main__": 47 | hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1]) 48 | main(hparams) 49 | -------------------------------------------------------------------------------- /examples/ilql_sentiments_t5.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict, List 3 | 4 | import numpy as np 5 | from datasets import load_dataset 6 | from transformers import AutoTokenizer, pipeline 7 | 8 | import trlx 9 | from trlx.data.configs import ( 10 | ModelConfig, 11 | OptimizerConfig, 12 | SchedulerConfig, 13 | TokenizerConfig, 14 | TrainConfig, 15 | TRLConfig, 16 | ) 17 | from trlx.models.modeling_ilql import ILQLConfig 18 | 19 | 20 | def get_positive_score(scores): 21 | "Extract value associated with a positive sentiment from pipeline's output" 22 | return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] 23 | 24 | 25 | default_config = TRLConfig( 26 | train=TrainConfig( 27 | seq_length=128, 28 | epochs=100, 29 | total_steps=1000, 30 | batch_size=32, 31 | checkpoint_interval=1000, 32 | eval_interval=100, 33 | pipeline="PromptPipeline", 34 | trainer="AccelerateILQLTrainer", 35 | save_best=False, 36 | ), 37 | model=ModelConfig( 38 | model_path="lvwerra/t5-imdb", 39 | num_layers_unfrozen=-1, 40 | model_arch_type="seq2seq", 41 | ), 42 | tokenizer=TokenizerConfig( 43 | tokenizer_path="lvwerra/t5-imdb", 44 | padding_side="right", 45 | truncation_side="right", 46 | ), 47 | optimizer=OptimizerConfig( 48 | name="adamw", 49 | kwargs={ 50 | "lr": 5.0e-5, 51 | "betas": [0.9, 0.999], 52 | "eps": 1.0e-8, 53 | "weight_decay": 1.0e-6, 54 | }, 55 | ), 56 | scheduler=SchedulerConfig( 57 | name="cosine_annealing", 58 | kwargs={ 59 | "T_max": 100000, 60 | "eta_min": 5.0e-5, 61 | }, 62 | ), 63 | method=ILQLConfig( 64 | name="ILQLConfig", 65 | tau=0.7, 66 | gamma=0.99, 67 | cql_scale=0.1, 68 | awac_scale=1, 69 | alpha=0.001, 70 | beta=0, 71 | steps_for_target_q_sync=5, 72 | two_qs=True, 73 | gen_kwargs=dict(max_new_tokens=56, top_k=20, beta=4, temperature=1.0), 74 | ), 75 | ) 76 | 77 | 78 | class LengthSampler: 79 | """ 80 | Samples a length 81 | """ 82 | 83 | def __init__(self, min_value, max_value): 84 | self.values = list(range(min_value, max_value)) 85 | self.rng = np.random.default_rng(seed=2023) 86 | 87 | def __call__(self): 88 | return self.rng.choice(self.values) 89 | 90 | 91 | def main(hparams={}): 92 | config = TRLConfig.update(default_config, hparams) 93 | 94 | def metric_fn(samples: List[str], **kwargs) -> Dict[str, 
List[float]]: 95 | sentiments = list(map(get_positive_score, sentiment_fn(samples))) 96 | return dict(sentiments=sentiments) 97 | 98 | sentiment_fn = pipeline( 99 | "sentiment-analysis", 100 | "lvwerra/distilbert-imdb", 101 | top_k=2, 102 | truncation=True, 103 | batch_size=256, 104 | device=0 if int(os.environ.get("LOCAL_RANK", 0)) == 0 else -1, 105 | ) 106 | tokenizer = AutoTokenizer.from_pretrained("lvwerra/t5-imdb") 107 | 108 | def build_imdb_dataset_test(tokenizer, input_min_text_length=2, input_max_text_length=8): 109 | # load imdb with datasets 110 | ds = load_dataset("imdb", split="test") 111 | ds = ds.rename_columns({"text": "review"}) 112 | ds = ds.filter(lambda x: len(x["review"]) > 200, batched=False) 113 | 114 | input_size = LengthSampler(input_min_text_length, input_max_text_length) 115 | 116 | def tokenize(sample): 117 | sample["review"] = sample["review"].replace("/>br", "") 118 | input_ids = tokenizer.encode(sample["review"])[: input_size()] + [tokenizer.eos_token_id] 119 | sample["query"] = tokenizer.decode(input_ids) 120 | return sample 121 | 122 | ds = ds.map(tokenize, batched=False) 123 | return ds 124 | 125 | dataset = load_dataset("imdb", split="train") 126 | prompts = dataset["text"] 127 | rewards = dataset["label"] 128 | val_prompts = build_imdb_dataset_test(tokenizer)["query"][0:100] 129 | 130 | trlx.train( 131 | samples=prompts, 132 | rewards=rewards, 133 | eval_prompts=val_prompts, 134 | metric_fn=metric_fn, 135 | config=config, 136 | ) 137 | 138 | 139 | if __name__ == "__main__": 140 | main() 141 | -------------------------------------------------------------------------------- /examples/llama_nemo/README.md: -------------------------------------------------------------------------------- 1 | ### NeMo Megatron setup: 2 | 3 | - Install NeMo version: v1.17.0 4 | 5 | ```bash 6 | git clone https://github.com/NVIDIA/NeMo/ 7 | cd NeMo 8 | git checkout d3017e4 9 | pip install -e '.[all]' 10 | ``` 11 | 12 | - Install Apex: 13 | ```bash 14 | git clone https://github.com/NVIDIA/apex 15 | cd apex 16 | # if pip >= 23.1 (ref: https://pip.pypa.io/en/stable/news/#v23-1) which supports multiple `--config-settings` with the same key... 
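# (the command below is the pip >= 23.1 form; with older pip the same build options were typically
#  passed as --global-option="--cpp_ext" --global-option="--cuda_ext" instead - check the upstream
#  apex README for the exact command matching your pip version)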
17 | pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./ 18 | ``` 19 | 20 | ### Convert LLaMa to NeMo: 21 | Example: 22 | 23 | ```bash 24 | python convert_llama_to_nemo.py --model_path NousResearch/Llama-2-7b-hf --output_folder nemo_llama2_7b --total_tp 4 --name 7b 25 | ``` 26 | 27 | ### Training: 28 | Example: [wandb](https://wandb.ai/carperai/trlxnemo/runs/v7592y73?workspace=user-pvduy) 29 | 30 | ```bash 31 | sbatch dist_train.sh 32 | ``` 33 | -------------------------------------------------------------------------------- /examples/llama_nemo/dist_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=llama 3 | #SBATCH --partition=g40 4 | #SBATCH --nodes=1 5 | #SBATCH --ntasks-per-node=8 6 | #SBATCH --mem=0 7 | #SBATCH --cpus-per-task=8 8 | #SBATCH --output=out.txt 9 | #SBATCH --error=error.txt 10 | #SBATCH --exclusive 11 | 12 | cd examples/llama_nemo 13 | srun --label python nemo_llama2_ppo_sentiments.py 14 | -------------------------------------------------------------------------------- /examples/llama_nemo/nemo_llama2_ppo_sentiments.py: -------------------------------------------------------------------------------- 1 | # Generates positive movie reviews by tuning a pretrained model on IMDB dataset 2 | # with a sentiment reward function 3 | import json 4 | import os 5 | import sys 6 | from typing import List 7 | 8 | from datasets import load_dataset 9 | from transformers import DistilBertForSequenceClassification, pipeline 10 | 11 | import trlx 12 | from trlx.data.default_configs import TRLConfig, default_ppo_config 13 | 14 | 15 | def get_positive_score(scores): 16 | "Extract value associated with a positive sentiment from pipeline's output" 17 | return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] 18 | 19 | 20 | def load_nemo_config(): 21 | """Load nemo-megatron-1.3b model and trainer config""" 22 | # Import here to not require nemo as a dependency 23 | from omegaconf import OmegaConf 24 | 25 | return OmegaConf.load("nemo_llama2_7b/megatron_7b.yaml") 26 | 27 | 28 | def main(hparams={}): 29 | # Merge sweep config with default config if given 30 | default_config = TRLConfig.update(default_ppo_config().to_dict(), hparams) 31 | nemo_config = load_nemo_config() 32 | print(nemo_config) 33 | cfg_name = "llama2-7b" 34 | config = default_config.evolve( 35 | train=dict( 36 | total_steps=1600, 37 | seq_length=256, 38 | batch_size=16, 39 | epochs=100, 40 | eval_interval=100, 41 | trainer="NeMoPPOTrainer", 42 | trainer_kwargs=dict( 43 | pretrained_model="nemo_llama2_7b/", 44 | megatron_cfg=nemo_config, 45 | ), 46 | checkpoint_interval=256, 47 | checkpoint_dir=f"nemo_{cfg_name}_ppo_sentiments", 48 | seed=2023, 49 | project_name="trlxnemo", 50 | tags=["nemo", "ppo", "sentiments", cfg_name], 51 | ), 52 | optimizer=dict( 53 | name="adamw", 54 | kwargs=dict( 55 | lr=1e-5, 56 | weight_decay=1e-06, 57 | eps=1.0e-8, 58 | betas=(0.9, 0.95), 59 | ), 60 | ), 61 | scheduler=dict( 62 | name="CosineAnnealing", 63 | ), 64 | model=dict(num_layers_unfrozen=24), 65 | method=dict( 66 | num_rollouts=128, 67 | init_kl_coef=0.05, 68 | vf_coef=1, 69 | scale_reward="ignored", 70 | gamma=1, 71 | lam=0.95, 72 | cliprange=0.2, 73 | cliprange_value=0.2, 74 | gen_kwargs=dict(temperature=1.0, max_new_tokens=64), 75 | chunk_size=64, 76 | ppo_epochs=4, 77 | ), 78 | ) 79 | config.scheduler.kwargs = dict(warmup_steps=0, 
constant_steps=1e12, min_lr=1e-6) 80 | 81 | rank = int(os.environ["SLURM_PROCID"]) 82 | local_rank = rank % 8 83 | 84 | reward_model = DistilBertForSequenceClassification.from_pretrained("lvwerra/distilbert-imdb") 85 | reward_model.to("cpu") 86 | sentiment_fn = pipeline( 87 | "sentiment-analysis", 88 | model=reward_model, # "lvwerra/distilbert-imdb", 89 | tokenizer="lvwerra/distilbert-imdb", 90 | top_k=2, 91 | truncation=True, 92 | batch_size=256, 93 | device=local_rank, 94 | ) 95 | 96 | def reward_fn(samples: List[str], **kwargs) -> List[float]: 97 | reward_model.to(local_rank) 98 | sentiments = list(map(get_positive_score, sentiment_fn(samples))) 99 | reward_model.to("cpu") 100 | return sentiments 101 | 102 | # Take few words off of movies reviews as prompts 103 | imdb = load_dataset("imdb", split="train+test") 104 | prompts = [" ".join(review.split()[:4]) for review in imdb["text"]] 105 | trlx.train( 106 | reward_fn=reward_fn, 107 | prompts=prompts, 108 | eval_prompts=["I don't know much about Hungarian underground"] * 256, 109 | config=config, 110 | ) 111 | 112 | 113 | if __name__ == "__main__": 114 | hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1]) 115 | main(hparams) 116 | -------------------------------------------------------------------------------- /examples/nemo_ilql_inference.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import sys 3 | from glob import glob 4 | 5 | from nemo.collections.nlp.modules.common.megatron.megatron_init import ( 6 | fake_initialize_model_parallel, 7 | ) 8 | from nemo.utils.app_state import AppState 9 | from nemo.utils.model_utils import inject_model_parallel_rank 10 | from omegaconf.omegaconf import OmegaConf 11 | 12 | from trlx.data.configs import TrainConfig 13 | from trlx.data.default_configs import default_ilql_config 14 | from trlx.trainer.nemo_ilql_trainer import ILQLGPT, megatron_trainer 15 | 16 | default_config = default_ilql_config() 17 | 18 | trl_config = default_config.evolve( 19 | train=TrainConfig( 20 | **dict( 21 | default_config.train.__dict__, 22 | trainer="NeMoILQLTrainer", 23 | trainer_kwargs=dict( 24 | pretrained_model=None, 25 | megatron_cfg="megatron_20b.yaml", 26 | ), 27 | ), 28 | ) 29 | ) 30 | 31 | 32 | def find_checkpoints(checkpoint_dir): 33 | checkpoints = glob(os.path.join(checkpoint_dir, "*", "*.ckpt")) 34 | names = [os.path.basename(c) for c in checkpoints] 35 | return set(names) 36 | 37 | 38 | def main(megatron_cfg_path, checkpoint_path): 39 | ilql_config = trl_config.method 40 | 41 | megatron_cfg = OmegaConf.load(megatron_cfg_path) 42 | megatron_cfg.trainer.num_nodes = 1 43 | megatron_cfg.trainer.devices = 4 44 | megatron_cfg.model.resume_from_checkpoint = checkpoint_path 45 | megatron_cfg.exp_manager.create_wandb_logger = False 46 | megatron_cfg.exp_manager.create_checkpoint_callback = False 47 | 48 | trainer = megatron_trainer(megatron_cfg) 49 | 50 | # Manually set up the TP and PP groups 51 | app_state = AppState() 52 | app_state.model_parallel_size = ( 53 | megatron_cfg.model.tensor_model_parallel_size * megatron_cfg.model.pipeline_model_parallel_size 54 | ) 55 | app_state.tensor_model_parallel_size = megatron_cfg.model.tensor_model_parallel_size 56 | app_state.pipeline_model_parallel_size = megatron_cfg.model.pipeline_model_parallel_size 57 | ( 58 | app_state.tensor_model_parallel_rank, 59 | app_state.pipeline_model_parallel_rank, 60 | app_state.model_parallel_size, 61 | app_state.data_parallel_size, 62 | 
app_state.pipeline_model_parallel_split_rank, 63 | app_state.virtual_pipeline_model_parallel_rank, 64 | ) = fake_initialize_model_parallel( 65 | world_size=app_state.model_parallel_size, 66 | rank=trainer.global_rank, 67 | tensor_model_parallel_size_=megatron_cfg.model.tensor_model_parallel_size, 68 | pipeline_model_parallel_size_=megatron_cfg.model.pipeline_model_parallel_size, 69 | pipeline_model_parallel_split_rank_=None, 70 | ) 71 | 72 | checkpoint_names = find_checkpoints(checkpoint_path) 73 | checkpoint_name = next(iter(checkpoint_names)) 74 | print(f"Loading checkpoint {checkpoint_name}, found {checkpoint_names} checkpoints") 75 | 76 | checkpoint_path = inject_model_parallel_rank(os.path.join(checkpoint_path, checkpoint_name)) 77 | 78 | model = ILQLGPT.load_from_checkpoint( 79 | checkpoint_path, 80 | cfg=megatron_cfg.model, 81 | trainer=trainer, 82 | ilql_config=ilql_config, 83 | ) 84 | 85 | model.sequence_parallel_(False) 86 | model.activation_checkpointing_(False) 87 | 88 | test = ["I don't know much about Hungarian underground"] 89 | test = [model.tokenizer.tokenizer.bos_token + t for t in test] 90 | 91 | print(model.generate(test, dict(max_length=40, min_length=0))["sentences"]) 92 | 93 | 94 | if __name__ == "__main__": 95 | main(sys.argv[1], sys.argv[2]) 96 | -------------------------------------------------------------------------------- /examples/nemo_ilql_sentiments.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | from datasets import load_dataset 4 | from transformers import pipeline 5 | 6 | import trlx 7 | from trlx.data.default_configs import default_ilql_config 8 | 9 | 10 | def get_positive_score(scores): 11 | "Extract value associated with a positive sentiment from pipeline's output" 12 | return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] 13 | 14 | 15 | default_config = default_ilql_config() 16 | 17 | 18 | def main(hparams={}): 19 | # Merge sweep config with default config if given 20 | 21 | config = default_config.evolve( 22 | train=dict( 23 | seq_length=1024, 24 | batch_size=512, 25 | total_steps=200, 26 | trainer="NeMoILQLTrainer", 27 | trainer_kwargs=dict( 28 | pretrained_model=None, 29 | megatron_cfg="megatron_20b.yaml", 30 | ), 31 | ), 32 | method=dict( 33 | gen_kwargs=dict( 34 | beta=2.0, 35 | temperature=0.9, 36 | ) 37 | ), 38 | ) 39 | config = config.evolve(**hparams) 40 | 41 | sentiment_fn = pipeline( 42 | "sentiment-analysis", 43 | "lvwerra/distilbert-imdb", 44 | top_k=2, 45 | truncation=True, 46 | batch_size=256, 47 | device=-1, 48 | ) 49 | 50 | def metric_fn(samples: List[str], **kwargs) -> Dict[str, List[float]]: 51 | sentiments = list(map(get_positive_score, sentiment_fn(samples))) 52 | return {"sentiments": sentiments} 53 | 54 | imdb = load_dataset("imdb", split="train+test") 55 | 56 | trlx.train( 57 | samples=imdb["text"], 58 | rewards=imdb["label"], 59 | eval_prompts=["I don't know much about Hungarian underground"] * 128, 60 | metric_fn=metric_fn, 61 | config=config, 62 | ) 63 | 64 | 65 | if __name__ == "__main__": 66 | main() 67 | -------------------------------------------------------------------------------- /examples/nemo_ppo_inference.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import sys 3 | from glob import glob 4 | 5 | from omegaconf.omegaconf import OmegaConf 6 | 7 | from trlx.data.default_configs import default_ppo_config 8 | from trlx.trainer.nemo_ppo_trainer import PPOGPT, 
megatron_trainer 9 | 10 | default_config = default_ppo_config() 11 | 12 | trl_config = default_config.evolve( 13 | train=dict( 14 | default_config.train.__dict__, 15 | trainer="NeMoPPOTrainer", 16 | trainer_kwargs=dict( 17 | pretrained_model=None, 18 | megatron_cfg="megatron_20b.yaml", 19 | ), 20 | ), 21 | ) 22 | 23 | 24 | def find_checkpoints(checkpoint_dir): 25 | checkpoints = glob(os.path.join(checkpoint_dir, "*", "*.ckpt")) 26 | names = [os.path.basename(c) for c in checkpoints] 27 | return set(names) 28 | 29 | 30 | def main(megatron_cfg_path, checkpoint_path): 31 | ppo_config = trl_config.method 32 | 33 | megatron_cfg = OmegaConf.load(megatron_cfg_path) 34 | megatron_cfg.trainer.num_nodes = 1 35 | megatron_cfg.trainer.devices = ( 36 | megatron_cfg.model.tensor_model_parallel_size * megatron_cfg.model.pipeline_model_parallel_size 37 | ) 38 | # Overriden in generate 39 | megatron_cfg.model.global_batch_size = megatron_cfg.model.micro_batch_size 40 | megatron_cfg.model.resume_from_checkpoint = checkpoint_path 41 | megatron_cfg.exp_manager.create_wandb_logger = False 42 | megatron_cfg.exp_manager.create_checkpoint_callback = False 43 | 44 | trainer = megatron_trainer(megatron_cfg) 45 | 46 | if trainer.world_size != megatron_cfg.trainer.devices: 47 | raise ValueError("Inference only supports data parallel world size of 1") 48 | 49 | # Initialize PyTorch Lightning DDP 50 | 51 | def dummy(): 52 | return 53 | 54 | if trainer.strategy.launcher is not None: 55 | trainer.strategy.launcher.launch(dummy, trainer=trainer) 56 | trainer.strategy.setup_environment() 57 | 58 | model = PPOGPT(ppo_config=ppo_config, cfg=megatron_cfg.model, trainer=trainer, build_reference_model=False) 59 | model.load_from_pretrained(checkpoint_path) 60 | 61 | test = ["I don't know much about Hungarian underground"] 62 | test = [model.tokenizer.tokenizer.bos_token + t for t in test] 63 | 64 | print(model.generate(test, dict(max_length=40, min_length=0))["sentences"]) 65 | 66 | 67 | if __name__ == "__main__": 68 | main(sys.argv[1], sys.argv[2]) 69 | -------------------------------------------------------------------------------- /examples/nemo_ppo_sentiments.py: -------------------------------------------------------------------------------- 1 | # Generates positive movie reviews by tuning a pretrained model on IMDB dataset 2 | # with a sentiment reward function 3 | import json 4 | import os 5 | import sys 6 | from typing import List 7 | 8 | from datasets import load_dataset 9 | from transformers import DistilBertForSequenceClassification, pipeline 10 | 11 | import trlx 12 | from trlx.data.default_configs import ( 13 | TRLConfig, 14 | default_nemo_1_3b_config, 15 | default_nemo_2b_config, 16 | default_nemo_20b_config, 17 | default_ppo_config, 18 | ) 19 | 20 | 21 | def get_positive_score(scores): 22 | "Extract value associated with a positive sentiment from pipeline's output" 23 | return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] 24 | 25 | 26 | def main(hparams={}): 27 | # Merge sweep config with default config if given 28 | default_config = TRLConfig.update(default_ppo_config().to_dict(), hparams) 29 | 30 | cfg_name = os.environ.get("NEMO_CONFIG", "1.3B") 31 | if cfg_name == "1.3B": 32 | nemo_config = default_nemo_1_3b_config() 33 | elif cfg_name == "2B": 34 | nemo_config = default_nemo_2b_config() 35 | elif cfg_name == "20B": 36 | nemo_config = default_nemo_20b_config() 37 | else: 38 | raise ValueError(f"Unknown NEMO_CONFIG: {cfg_name}") 39 | 40 | config = default_config.evolve( 41 | train=dict( 42 | 
total_steps=512, 43 | seq_length=2048, 44 | batch_size=32, 45 | epochs=100, 46 | eval_interval=64, 47 | trainer="NeMoPPOTrainer", 48 | trainer_kwargs=dict( 49 | pretrained_model=f"/mnt/hdd/nemo-megatron-gpt-{cfg_name}/", 50 | megatron_cfg=nemo_config, 51 | ), 52 | checkpoint_interval=256, 53 | checkpoint_dir=f"nemo_{cfg_name}_ppo_sentiments", 54 | seed=2023, 55 | project_name="trlxnemo", 56 | tags=["nemo", "ppo", "sentiments", cfg_name], 57 | ), 58 | optimizer=dict( 59 | name="distributed_fused_adam", 60 | kwargs=dict( 61 | lr=6.001e-5, 62 | weight_decay=1e-06, 63 | eps=1.0e-8, 64 | betas=(0.9, 0.95), 65 | ), 66 | ), 67 | scheduler=dict( 68 | name="CosineAnnealing", 69 | ), 70 | model=dict(num_layers_unfrozen=2), 71 | method=dict( 72 | num_rollouts=128, 73 | init_kl_coef=0.05, 74 | scale_reward="ref", 75 | vf_coef=1, 76 | gen_kwargs=dict(temperature=1.0, max_new_tokens=40), 77 | chunk_size=128, 78 | ppo_epochs=4, 79 | ), 80 | ) 81 | config.scheduler.kwargs = dict(warmup_steps=0, constant_steps=1e12, min_lr=6.0e-5) 82 | 83 | rank = int(os.environ["SLURM_PROCID"]) 84 | local_rank = rank % 8 85 | 86 | reward_model = DistilBertForSequenceClassification.from_pretrained("lvwerra/distilbert-imdb") 87 | reward_model.to(local_rank) 88 | sentiment_fn = pipeline( 89 | "sentiment-analysis", 90 | model=reward_model, # "lvwerra/distilbert-imdb", 91 | tokenizer="lvwerra/distilbert-imdb", 92 | top_k=2, 93 | truncation=True, 94 | batch_size=256, 95 | device=local_rank, 96 | ) 97 | 98 | def reward_fn(samples: List[str], **kwargs) -> List[float]: 99 | reward_model.to(local_rank) 100 | sentiments = list(map(get_positive_score, sentiment_fn(samples))) 101 | reward_model.to("cpu") 102 | return sentiments 103 | 104 | # Take few words off of movies reviews as prompts 105 | imdb = load_dataset("imdb", split="train+test") 106 | prompts = [" ".join(review.split()[:4]) for review in imdb["text"]] 107 | 108 | trlx.train( 109 | reward_fn=reward_fn, 110 | prompts=prompts, 111 | eval_prompts=["I don't know much about Hungarian underground"] * 256, 112 | config=config, 113 | ) 114 | 115 | 116 | if __name__ == "__main__": 117 | hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1]) 118 | main(hparams) 119 | -------------------------------------------------------------------------------- /examples/nemo_sft_sentiments.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict, List 3 | 4 | from datasets import load_dataset 5 | from transformers import pipeline 6 | 7 | import trlx 8 | from trlx.data.default_configs import ( 9 | TRLConfig, 10 | default_nemo_1_3b_config, 11 | default_nemo_20b_config, 12 | default_sft_config, 13 | ) 14 | 15 | 16 | def get_positive_score(scores): 17 | "Extract value associated with a positive sentiment from pipeline's output" 18 | return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] 19 | 20 | 21 | def main(hparams={}): 22 | # Merge sweep config with default config if given 23 | 24 | default_config = TRLConfig.update(default_sft_config(), hparams) 25 | 26 | cfg_name = os.environ.get("NEMO_CONFIG", "1.3B") 27 | if cfg_name == "1.3B": 28 | nemo_config = default_nemo_1_3b_config() 29 | elif cfg_name == "20B": 30 | nemo_config = default_nemo_20b_config() 31 | else: 32 | raise ValueError(f"Unknown NEMO_CONFIG: {cfg_name}") 33 | 34 | nemo_config.exp_manager.create_wandb_logger = True 35 | nemo_config.exp_manager.wandb_logger_kwargs.name = f"nemo-sft-sentiments-{cfg_name}" 36 | 37 | config = default_config.evolve( 
38 | train=dict( 39 | trainer="NeMoSFTTrainer", 40 | trainer_kwargs=dict( 41 | pretrained_model=f"/mnt/hdd/nemo-megatron-gpt-{cfg_name}/", 42 | megatron_cfg=nemo_config, 43 | ), 44 | ), 45 | model=dict(num_layers_unfrozen=-1), 46 | tags=["nemo", "sft", "sentiments", cfg_name], 47 | ) 48 | 49 | imdb = load_dataset("imdb", split="train+test") 50 | # Finetune on only positive reviews 51 | imdb = imdb.filter(lambda sample: sample["label"] == 1) 52 | 53 | sentiment_fn = pipeline( 54 | "sentiment-analysis", 55 | "lvwerra/distilbert-imdb", 56 | top_k=2, 57 | truncation=True, 58 | batch_size=256, 59 | device=-1, 60 | ) 61 | 62 | def metric_fn(samples: List[str], **kwargs) -> Dict[str, List[float]]: 63 | sentiments = list(map(get_positive_score, sentiment_fn(samples))) 64 | return {"sentiments": sentiments} 65 | 66 | trlx.train( 67 | samples=imdb["text"], 68 | eval_prompts=["I don't know much about Hungarian underground"] * 64, 69 | metric_fn=metric_fn, 70 | config=config, 71 | ) 72 | 73 | 74 | if __name__ == "__main__": 75 | main() 76 | -------------------------------------------------------------------------------- /examples/ppo_dense_sentiments.py: -------------------------------------------------------------------------------- 1 | # Generates positive movie reviews by tuning a pretrained model on IMDB dataset 2 | # with a sentiment reward function 3 | import json 4 | import os 5 | import sys 6 | from typing import List 7 | 8 | import torch 9 | from datasets import load_dataset 10 | from transformers import pipeline 11 | 12 | import trlx 13 | from trlx.data.default_configs import TRLConfig, default_ppo_config 14 | 15 | 16 | def get_positive_score(scores): 17 | "Extract value associated with a positive sentiment from pipeline's output" 18 | return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] 19 | 20 | 21 | def get_negative_score(scores): 22 | return dict(map(lambda x: tuple(x.values()), scores))["NEGATIVE"] 23 | 24 | 25 | def main(hparams={}): 26 | # Merge sweep config with default config if given 27 | config = TRLConfig.update(default_ppo_config().to_dict(), hparams) 28 | 29 | if torch.cuda.is_available(): 30 | device = int(os.environ.get("LOCAL_RANK", 0)) 31 | else: 32 | device = -1 33 | 34 | sentiment_fn = pipeline( 35 | "sentiment-analysis", 36 | "lvwerra/distilbert-imdb", 37 | top_k=2, 38 | truncation=True, 39 | batch_size=256, 40 | device=device, 41 | ) 42 | 43 | def dense_reward_fn(samples: List[str], prompts: List[str], outputs: List[str], tokenizer, **kwargs) -> List[float]: 44 | # Reward positively for initially negative then positive review 45 | # Reward functions should never receive padded text except for a single EOS at the end 46 | # Reward function should return token rewards for just the response 47 | first_halves = [".".join(sample.split(".")[: len(sample.split(".")) // 2]) for sample in samples] 48 | negative_first_halves = list(map(get_negative_score, sentiment_fn(first_halves))) 49 | second_halves = [".".join(sample.split(".")[len(sample.split(".")) // 2 :]) for sample in samples] 50 | positive_second_halves = list(map(get_positive_score, sentiment_fn(second_halves))) 51 | text_scores = [[f, s] for f, s in zip(negative_first_halves, positive_second_halves)] 52 | tok_scores = [] 53 | for sample, prompt, response, text_score in zip(samples, prompts, outputs, text_scores): 54 | toks = tokenizer(response).input_ids 55 | tok_score = [0] * len(toks) 56 | tok_score[len(tok_score) // 2] = text_score[0] 57 | tok_score[-1] = text_score[1] 58 | tok_scores.append(tok_score) 
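# Illustration with made-up numbers: for a response of 6 tokens and text_score = [0.8, 0.9],
# tok_score ends up as [0, 0, 0, 0.8, 0, 0.9]: the "how negative is the first half" score sits on
# the middle response token and the "how positive is the second half" score on the final token,
# so trlx receives a dense per-token reward list whose length matches the tokenized response.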
59 | return tok_scores 60 | 61 | # Take few words off of movies reviews as prompts 62 | imdb = load_dataset("imdb", split="train+test") 63 | prompts = [" ".join(review.split()[:4]) for review in imdb["text"]] 64 | 65 | trlx.train( 66 | reward_fn=dense_reward_fn, 67 | prompts=prompts, 68 | eval_prompts=["I don't know much about Hungarian underground"] * 256, 69 | config=config, 70 | ) 71 | 72 | 73 | if __name__ == "__main__": 74 | hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1]) 75 | main(hparams) 76 | -------------------------------------------------------------------------------- /examples/ppo_sentiments.py: -------------------------------------------------------------------------------- 1 | # Generates positive movie reviews by tuning a pretrained model on IMDB dataset 2 | # with a sentiment reward function 3 | import json 4 | import os 5 | import sys 6 | from typing import List 7 | 8 | import torch 9 | from datasets import load_dataset 10 | from transformers import pipeline 11 | 12 | import trlx 13 | from trlx.data.default_configs import TRLConfig, default_ppo_config 14 | 15 | 16 | def get_positive_score(scores): 17 | "Extract value associated with a positive sentiment from pipeline's output" 18 | return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] 19 | 20 | 21 | def main(hparams={}): 22 | # Merge sweep config with default config if given 23 | config = TRLConfig.update(default_ppo_config().to_dict(), hparams) 24 | 25 | if torch.cuda.is_available(): 26 | device = int(os.environ.get("LOCAL_RANK", 0)) 27 | else: 28 | device = -1 29 | 30 | sentiment_fn = pipeline( 31 | "sentiment-analysis", 32 | "lvwerra/distilbert-imdb", 33 | top_k=2, 34 | truncation=True, 35 | batch_size=256, 36 | device=device, 37 | ) 38 | 39 | def reward_fn(samples: List[str], **kwargs) -> List[float]: 40 | sentiments = list(map(get_positive_score, sentiment_fn(samples))) 41 | return sentiments 42 | 43 | # Take few words off of movies reviews as prompts 44 | imdb = load_dataset("imdb", split="train+test") 45 | prompts = [" ".join(review.split()[:4]) for review in imdb["text"]] 46 | 47 | trlx.train( 48 | reward_fn=reward_fn, 49 | prompts=prompts, 50 | eval_prompts=["I don't know much about Hungarian underground"] * 256, 51 | config=config, 52 | ) 53 | 54 | 55 | if __name__ == "__main__": 56 | hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1]) 57 | main(hparams) 58 | -------------------------------------------------------------------------------- /examples/ppo_sentiments_llama.py: -------------------------------------------------------------------------------- 1 | # Generates positive movie reviews by tuning a pretrained model on IMDB dataset 2 | # with a sentiment reward function 3 | import json 4 | import os 5 | import sys 6 | from typing import List 7 | 8 | import torch 9 | from datasets import load_dataset 10 | from transformers import pipeline 11 | 12 | import trlx 13 | from trlx.data.default_configs import ( 14 | ModelConfig, 15 | OptimizerConfig, 16 | PPOConfig, 17 | SchedulerConfig, 18 | TokenizerConfig, 19 | TrainConfig, 20 | TRLConfig, 21 | ) 22 | 23 | 24 | def get_positive_score(scores): 25 | "Extract value associated with a positive sentiment from pipeline's output" 26 | return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] 27 | 28 | 29 | def llama_config(): 30 | return TRLConfig( 31 | train=TrainConfig( 32 | seq_length=1024, 33 | epochs=100, 34 | total_steps=400, 35 | batch_size=32, 36 | checkpoint_interval=10000, 37 | eval_interval=100, 38 | 
pipeline="PromptPipeline", 39 | trainer="AcceleratePPOTrainer", 40 | save_best=False, 41 | ), 42 | model=ModelConfig(model_path="NousResearch/Llama-2-7b-hf", num_layers_unfrozen=2), 43 | tokenizer=TokenizerConfig(tokenizer_path="NousResearch/Llama-2-7b-hf", truncation_side="right"), 44 | optimizer=OptimizerConfig( 45 | name="adamw", kwargs=dict(lr=1e-5, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6) 46 | ), 47 | scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=1.0e-5)), 48 | method=PPOConfig( 49 | name="PPOConfig", 50 | num_rollouts=128, 51 | chunk_size=128, 52 | ppo_epochs=4, 53 | init_kl_coef=0.001, 54 | target=6, 55 | horizon=10000, 56 | gamma=1, 57 | lam=0.95, 58 | cliprange=0.2, 59 | cliprange_value=0.2, 60 | vf_coef=1, 61 | scale_reward="ignored", 62 | ref_mean=None, 63 | ref_std=None, 64 | cliprange_reward=10, 65 | gen_kwargs=dict( 66 | max_new_tokens=40, 67 | top_k=0, 68 | top_p=1.0, 69 | do_sample=True, 70 | ), 71 | ), 72 | ) 73 | 74 | 75 | def main(hparams={}): 76 | # Merge sweep config with default config if given 77 | config = TRLConfig.update(llama_config().to_dict(), hparams) 78 | 79 | if torch.cuda.is_available(): 80 | device = int(os.environ.get("LOCAL_RANK", 0)) 81 | else: 82 | device = -1 83 | 84 | sentiment_fn = pipeline( 85 | "sentiment-analysis", 86 | "lvwerra/distilbert-imdb", 87 | top_k=2, 88 | truncation=True, 89 | batch_size=256, 90 | device=device, 91 | ) 92 | 93 | def reward_fn(samples: List[str], **kwargs) -> List[float]: 94 | sentiments = list(map(get_positive_score, sentiment_fn(samples))) 95 | return sentiments 96 | 97 | # Take few words off of movies reviews as prompts 98 | imdb = load_dataset("imdb", split="train+test") 99 | prompts = [" ".join(review.split()[:4]) for review in imdb["text"]] 100 | 101 | trlx.train( 102 | reward_fn=reward_fn, 103 | prompts=prompts, 104 | eval_prompts=["I don't know much about Hungarian underground"] * 64, 105 | config=config, 106 | ) 107 | 108 | 109 | if __name__ == "__main__": 110 | hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1]) 111 | main(hparams) 112 | -------------------------------------------------------------------------------- /examples/ppo_sentiments_peft.py: -------------------------------------------------------------------------------- 1 | # Generates positive movie reviews by tuning a pretrained model on IMDB dataset 2 | # with a sentiment reward function 3 | import json 4 | import os 5 | import sys 6 | from typing import List 7 | 8 | import torch 9 | from datasets import load_dataset 10 | from peft import LoraConfig 11 | from peft.utils.config import TaskType 12 | from transformers import pipeline 13 | 14 | import trlx 15 | from trlx.data.default_configs import TRLConfig, default_ppo_config 16 | 17 | 18 | def get_positive_score(scores): 19 | "Extract value associated with a positive sentiment from pipeline's output" 20 | return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] 21 | 22 | 23 | def main(hparams={}): 24 | # Merge sweep config with default config if given 25 | config = TRLConfig.update(default_ppo_config().to_dict(), hparams) 26 | 27 | if torch.cuda.is_available(): 28 | device = int(os.environ.get("LOCAL_RANK", 0)) 29 | else: 30 | device = -1 31 | 32 | sentiment_fn = pipeline( 33 | "sentiment-analysis", 34 | "lvwerra/distilbert-imdb", 35 | top_k=2, 36 | truncation=True, 37 | batch_size=256, 38 | device=device, 39 | ) 40 | 41 | # Just insert your peft config here (the type must be an instance of peft.PeftConfig or a dict). 
42 | config.model.peft_config = LoraConfig( 43 | r=8, 44 | task_type=TaskType.CAUSAL_LM, 45 | lora_alpha=32, 46 | lora_dropout=0.1, 47 | ) 48 | 49 | def reward_fn(samples: List[str], **kwargs) -> List[float]: 50 | sentiments = list(map(get_positive_score, sentiment_fn(samples))) 51 | return sentiments 52 | 53 | # Take few words off of movies reviews as prompts 54 | imdb = load_dataset("imdb", split="train+test") 55 | prompts = [" ".join(review.split()[:4]) for review in imdb["text"]] 56 | 57 | trlx.train( 58 | reward_fn=reward_fn, 59 | prompts=prompts, 60 | eval_prompts=["I don't know much about Hungarian underground"] * 256, 61 | config=config, 62 | ) 63 | 64 | 65 | if __name__ == "__main__": 66 | hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1]) 67 | main(hparams) 68 | -------------------------------------------------------------------------------- /examples/ppo_sentiments_t5.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | from typing import Dict, List 5 | 6 | import numpy as np 7 | from datasets import load_dataset 8 | from transformers import AutoTokenizer, pipeline 9 | 10 | import trlx 11 | from trlx.data.configs import ( 12 | ModelConfig, 13 | OptimizerConfig, 14 | SchedulerConfig, 15 | TokenizerConfig, 16 | TrainConfig, 17 | TRLConfig, 18 | ) 19 | from trlx.models.modeling_ppo import PPOConfig 20 | 21 | 22 | def get_positive_score(scores): 23 | "Extract value associated with a positive sentiment from pipeline's output" 24 | return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] 25 | 26 | 27 | default_config = TRLConfig( 28 | train=TrainConfig( 29 | seq_length=128, 30 | epochs=100, 31 | total_steps=100000, 32 | batch_size=12, 33 | checkpoint_interval=10000, 34 | eval_interval=100, 35 | pipeline="PromptPipeline", 36 | trainer="AcceleratePPOTrainer", 37 | save_best=False, 38 | ), 39 | model=ModelConfig( 40 | model_path="lvwerra/t5-imdb", 41 | num_layers_unfrozen=-1, 42 | model_arch_type="seq2seq", 43 | ), 44 | tokenizer=TokenizerConfig( 45 | tokenizer_path="lvwerra/t5-imdb", 46 | padding_side="right", 47 | truncation_side="right", 48 | ), 49 | optimizer=OptimizerConfig( 50 | name="adamw", 51 | kwargs={ 52 | "lr": 5.0e-5, 53 | "betas": [0.9, 0.999], 54 | "eps": 1.0e-8, 55 | "weight_decay": 1.0e-6, 56 | }, 57 | ), 58 | scheduler=SchedulerConfig( 59 | name="cosine_annealing", 60 | kwargs={ 61 | "T_max": 100000, 62 | "eta_min": 5.0e-5, 63 | }, 64 | ), 65 | method=PPOConfig( 66 | name="PPOConfig", 67 | num_rollouts=128, 68 | chunk_size=12, 69 | ppo_epochs=4, 70 | init_kl_coef=0.05, 71 | target=6, 72 | horizon=10000, 73 | gamma=0.99, 74 | lam=0.95, 75 | cliprange=0.2, 76 | cliprange_value=0.2, 77 | vf_coef=1, 78 | scale_reward=None, 79 | ref_mean=None, 80 | ref_std=None, 81 | cliprange_reward=10, 82 | gen_kwargs={ 83 | "max_new_tokens": 50, 84 | "do_sample": True, 85 | "top_k": 0, 86 | "top_p": 1, 87 | "eos_token_id": -1, 88 | }, 89 | ), 90 | ) 91 | 92 | 93 | class LengthSampler: 94 | """ 95 | Samples a length 96 | """ 97 | 98 | def __init__(self, min_value, max_value): 99 | self.values = list(range(min_value, max_value)) 100 | self.rng = np.random.default_rng(seed=2023) 101 | 102 | def __call__(self): 103 | return self.rng.choice(self.values) 104 | 105 | 106 | def main(hparams={}): 107 | config = TRLConfig.update(default_config, hparams) 108 | 109 | def metric_fn(samples: List[str], **kwargs) -> Dict[str, List[float]]: 110 | sentiments = list(map(get_positive_score, 
sentiment_fn(samples))) 111 | return sentiments 112 | 113 | sentiment_fn = pipeline( 114 | "sentiment-analysis", 115 | "lvwerra/distilbert-imdb", 116 | top_k=2, 117 | truncation=True, 118 | batch_size=256, 119 | device=0 if int(os.environ.get("LOCAL_RANK", 0)) == 0 else -1, 120 | ) 121 | tokenizer = AutoTokenizer.from_pretrained("lvwerra/t5-imdb") 122 | 123 | def build_imdb_dataset(tokenizer, input_min_text_length=2, input_max_text_length=8): 124 | # load imdb with datasets 125 | ds = load_dataset("imdb", split="train") 126 | ds = ds.rename_columns({"text": "review"}) 127 | ds = ds.filter(lambda x: len(x["review"]) > 200, batched=False) 128 | 129 | input_size = LengthSampler(input_min_text_length, input_max_text_length) 130 | 131 | def tokenize(sample): 132 | sample["review"] = sample["review"].replace("/>br", "") 133 | sample["input_ids"] = tokenizer.encode(sample["review"])[: input_size()] + [tokenizer.eos_token_id] 134 | sample["query"] = tokenizer.decode(sample["input_ids"]) 135 | return sample 136 | 137 | ds = ds.map(tokenize, batched=False) 138 | ds.set_format(type="torch") 139 | return ds 140 | 141 | def build_imdb_dataset_test(tokenizer, input_min_text_length=2, input_max_text_length=8): 142 | # load imdb with datasets 143 | ds = load_dataset("imdb", split="test") 144 | ds = ds.rename_columns({"text": "review"}) 145 | ds = ds.filter(lambda x: len(x["review"]) > 200, batched=False) 146 | 147 | input_size = LengthSampler(input_min_text_length, input_max_text_length) 148 | 149 | def tokenize(sample): 150 | sample["review"] = sample["review"].replace("/>br", "") 151 | sample["input_ids"] = tokenizer.encode(sample["review"])[: input_size()] + [tokenizer.eos_token_id] 152 | sample["query"] = tokenizer.decode(sample["input_ids"]) 153 | return sample 154 | 155 | ds = ds.map(tokenize, batched=False) 156 | ds.set_format(type="torch") 157 | return ds 158 | 159 | dataset = build_imdb_dataset(tokenizer) 160 | prompts = dataset["query"] 161 | val_prompts = build_imdb_dataset_test(tokenizer)["query"][0:100] 162 | 163 | trlx.train( 164 | prompts=prompts, 165 | eval_prompts=val_prompts, 166 | reward_fn=metric_fn, 167 | config=config, 168 | ) 169 | 170 | 171 | if __name__ == "__main__": 172 | hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1]) 173 | main(hparams) 174 | -------------------------------------------------------------------------------- /examples/randomwalks/README.md: -------------------------------------------------------------------------------- 1 | # Random Walks: Decision Tree Example 2 | 3 | This example uses the Toy Problem described in [Decision Transformer (Lili Chen 4 | et al. 2021)](https://arxiv.org/abs/2106.01345). 5 | 6 | ## Game Description 7 | 8 | The task is to find the shortest path on a directed graph. The reward is based 9 | on how optimal the path is compared to the shortest possible (bounded in [0, 10 | 1]). 11 | 12 | Note this is different to the paper, which gave rewards of -1 for every 13 | turn not at the goal state, and 0 at the goal state. Here the model instead 14 | receives its reward at the end of the full trajectory, based on how optimal it 15 | is compared to the minimum number of steps to reach the goal state (bounded in 16 | [0, 1]). 17 | 18 | Paths are represented as strings of letters, with each letter corresponding to a 19 | node in the graph. 20 | 21 | ## Training 22 | 23 | ![Graph Example](graph-example.png) 24 | Source: Decision Transformer (Lili Chen et al. 
2021) 25 | 26 | For PPO, a language model was fine-tuned to predict the next token in a sequence 27 | of returns-to-go (sum of future rewards), states and actions. It was trained 28 | only on random walk data. 29 | 30 | ILQL by contrast learns from the samples directly. 31 | -------------------------------------------------------------------------------- /examples/randomwalks/__init__.py: -------------------------------------------------------------------------------- 1 | from .randomwalks import generate_random_walks 2 | -------------------------------------------------------------------------------- /examples/randomwalks/graph-example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CarperAI/trlx/3340c2f3a56d1d14fdd5f13ad575121fa26b6d92/examples/randomwalks/graph-example.png -------------------------------------------------------------------------------- /examples/randomwalks/ilql_randomwalks.py: -------------------------------------------------------------------------------- 1 | from transformers import GPT2Config 2 | 3 | import trlx 4 | from examples.randomwalks import generate_random_walks 5 | from trlx.data.default_configs import ( 6 | ILQLConfig, 7 | ModelConfig, 8 | OptimizerConfig, 9 | SchedulerConfig, 10 | TokenizerConfig, 11 | TrainConfig, 12 | TRLConfig, 13 | ) 14 | 15 | 16 | def main(hparams): 17 | config = TRLConfig.update(default_config, hparams) 18 | 19 | metric_fn, eval_prompts, walks, _ = generate_random_walks(seed=config.train.seed) 20 | rewards = metric_fn(walks)["optimality"] 21 | # split each random walk into (starting state, rest of the walk) 22 | walks = [[walk[:1], walk[1:]] for walk in walks] 23 | 24 | trlx.train( 25 | model_path=GPT2Config(n_layer=6, n_embd=144, vocab_size=23), 26 | samples=walks, 27 | rewards=rewards, 28 | eval_prompts=eval_prompts, 29 | metric_fn=lambda samples, **kwargs: metric_fn(samples), 30 | config=config, 31 | stop_sequences=["|"], 32 | ) 33 | 34 | 35 | default_config = TRLConfig( 36 | train=TrainConfig( 37 | seq_length=11, 38 | batch_size=100, 39 | epochs=20, 40 | total_steps=1000, 41 | checkpoint_interval=1000, 42 | eval_interval=16, 43 | pipeline="PromptPipeline", 44 | trainer="AccelerateILQLTrainer", 45 | ), 46 | model=ModelConfig(model_path=GPT2Config(n_layer=6, n_embd=144, vocab_size=23), num_layers_unfrozen=-1), 47 | tokenizer=TokenizerConfig(tokenizer_path="CarperAI/randomwalks", truncation_side="right"), 48 | optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=2e-4, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)), 49 | scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=1000, eta_min=2e-4)), 50 | method=ILQLConfig( 51 | name="ilqlconfig", 52 | tau=0.8, 53 | gamma=0.99, 54 | cql_scale=0.1, 55 | awac_scale=1, 56 | alpha=0.1, 57 | beta=0, 58 | steps_for_target_q_sync=5, 59 | two_qs=True, 60 | gen_kwargs=dict(max_new_tokens=9, top_k=10, beta=[0, 1, 100], temperature=1.0), 61 | ), 62 | ) 63 | 64 | if __name__ == "__main__": 65 | import json 66 | import sys 67 | 68 | hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1]) 69 | main(hparams) 70 | -------------------------------------------------------------------------------- /examples/randomwalks/ppo_randomwalks.py: -------------------------------------------------------------------------------- 1 | import trlx 2 | from examples.randomwalks import generate_random_walks 3 | from trlx.data.default_configs import ( 4 | ModelConfig, 5 | OptimizerConfig, 6 | PPOConfig, 7 | 
SchedulerConfig, 8 | TokenizerConfig, 9 | TrainConfig, 10 | TRLConfig, 11 | ) 12 | 13 | default_config = TRLConfig( 14 | train=TrainConfig( 15 | seq_length=10, 16 | epochs=20, 17 | total_steps=10000, 18 | batch_size=100, 19 | checkpoint_interval=10000, 20 | eval_interval=20, 21 | pipeline="PromptPipeline", 22 | trainer="AcceleratePPOTrainer", 23 | ), 24 | model=ModelConfig(model_path="CarperAI/randomwalks", num_layers_unfrozen=-1), 25 | tokenizer=TokenizerConfig(tokenizer_path="CarperAI/randomwalks", truncation_side="right"), 26 | optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=3.0e-4, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)), 27 | scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=3.0e-4)), 28 | method=PPOConfig( 29 | name="PPOConfig", 30 | num_rollouts=128, 31 | chunk_size=128, 32 | ppo_epochs=4, 33 | init_kl_coef=0, 34 | target=None, 35 | horizon=10000, 36 | gamma=1, 37 | lam=0.95, 38 | cliprange=0.2, 39 | cliprange_value=0.2, 40 | vf_coef=1.2, 41 | scale_reward="ignored", 42 | ref_mean=None, 43 | ref_std=None, 44 | cliprange_reward=1, 45 | gen_kwargs=dict( 46 | max_new_tokens=9, 47 | top_k=0, 48 | top_p=1.0, 49 | do_sample=True, 50 | ), 51 | ), 52 | ) 53 | 54 | 55 | def main(hparams={}): 56 | config = TRLConfig.update(default_config, hparams) 57 | metric_fn, prompts, *_ = generate_random_walks(seed=config.train.seed) 58 | 59 | trlx.train( 60 | # An "optimality" reward function is used, with scores in [0,1] 61 | # depending on how close the path is to the shortest possible path. 62 | reward_fn=lambda samples, **kwargs: metric_fn(samples)["optimality"], 63 | # The prompts are simply the first nodes (represented as letters) to 64 | # start from. 65 | prompts=prompts, 66 | eval_prompts=prompts, 67 | metric_fn=lambda samples, **kwargs: metric_fn(samples), 68 | config=config, 69 | ) 70 | 71 | 72 | if __name__ == "__main__": 73 | import json 74 | import sys 75 | 76 | hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1]) 77 | main(hparams) 78 | -------------------------------------------------------------------------------- /examples/randomwalks/rft_randomwalks.py: -------------------------------------------------------------------------------- 1 | import trlx 2 | from examples.randomwalks import generate_random_walks 3 | from trlx.data.default_configs import ( 4 | ModelConfig, 5 | OptimizerConfig, 6 | SchedulerConfig, 7 | TokenizerConfig, 8 | TrainConfig, 9 | TRLConfig, 10 | ) 11 | from trlx.trainer.accelerate_rft_trainer import RFTConfig 12 | 13 | default_config = TRLConfig( 14 | train=TrainConfig( 15 | seq_length=10, 16 | epochs=100, 17 | total_steps=1000, 18 | batch_size=100, 19 | checkpoint_interval=1000, 20 | eval_interval=100, 21 | pipeline="PromptPipeline", 22 | trainer="AccelerateRFTTrainer", 23 | checkpoint_dir="checkpoints/randomwalks", 24 | ), 25 | model=ModelConfig(model_path="CarperAI/randomwalks", num_layers_unfrozen=-1), 26 | tokenizer=TokenizerConfig(tokenizer_path="CarperAI/randomwalks", truncation_side="right"), 27 | optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=3.0e-4, betas=(0.9, 0.99), eps=1.0e-8, weight_decay=0)), 28 | scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=3.0e-4)), 29 | method=RFTConfig( 30 | name="RFTConfig", 31 | n_generations_per_prompt=100, 32 | start_percentile=0.9, 33 | end_percentile=0.95, 34 | n_improve_steps=1, 35 | gen_kwargs=dict( 36 | max_new_tokens=9, 37 | top_k=0, 38 | top_p=1.0, 39 | temperature=1.0, 40 | do_sample=True, 41 | ), 42 | ), 43 
| ) 44 | 45 | 46 | def main(hparams={}): 47 | config = TRLConfig.update(default_config, hparams) 48 | metric_fn, prompts, *_ = generate_random_walks(seed=config.train.seed) 49 | 50 | trlx.train( 51 | reward_fn=lambda samples, **kwargs: metric_fn(samples)["optimality"], 52 | prompts=prompts, 53 | eval_prompts=prompts, 54 | metric_fn=lambda samples, **kwargs: metric_fn(samples), 55 | config=config, 56 | ) 57 | 58 | 59 | if __name__ == "__main__": 60 | import json 61 | import sys 62 | 63 | hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1]) 64 | main(hparams) 65 | -------------------------------------------------------------------------------- /examples/rft_sentiments.py: -------------------------------------------------------------------------------- 1 | # This script trains a model to output positive reviews 2 | # using rejection finetuning with a sentiment classifier reward function. 3 | import json 4 | import os 5 | import sys 6 | from typing import List 7 | 8 | import torch 9 | from datasets import load_dataset 10 | from transformers import pipeline 11 | 12 | import trlx 13 | from trlx.data.default_configs import ( 14 | ModelConfig, 15 | OptimizerConfig, 16 | SchedulerConfig, 17 | TokenizerConfig, 18 | TrainConfig, 19 | TRLConfig, 20 | ) 21 | from trlx.trainer.accelerate_rft_trainer import RFTConfig 22 | 23 | 24 | def get_positive_score(scores): 25 | "Extract value associated with a positive sentiment from pipeline's output" 26 | return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] 27 | 28 | 29 | default_config = TRLConfig( 30 | train=TrainConfig( 31 | seq_length=1024, 32 | epochs=100, 33 | total_steps=1000, 34 | batch_size=32, 35 | checkpoint_interval=10000, 36 | eval_interval=100, 37 | pipeline="PromptPipeline", 38 | trainer="AccelerateRFTTrainer", 39 | ), 40 | model=ModelConfig(model_path="lvwerra/gpt2-imdb", num_layers_unfrozen=-1), 41 | tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"), 42 | optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=3e-5, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)), 43 | scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=3e-5)), 44 | method=RFTConfig( 45 | name="RFTConfig", 46 | n_generations_per_prompt=4, 47 | start_percentile=0.9, 48 | end_percentile=0.95, 49 | n_improve_steps=1, 50 | gen_kwargs=dict( 51 | max_new_tokens=40, 52 | top_k=0, 53 | top_p=1.0, 54 | temperature=1.0, 55 | do_sample=True, 56 | ), 57 | ), 58 | ) 59 | 60 | 61 | def main(hparams={}): 62 | config = TRLConfig.update(default_config, hparams) 63 | 64 | if torch.cuda.is_available(): 65 | device = int(os.environ.get("LOCAL_RANK", 0)) 66 | else: 67 | device = -1 68 | 69 | sentiment_fn = pipeline( 70 | "sentiment-analysis", 71 | "lvwerra/distilbert-imdb", 72 | top_k=2, 73 | truncation=True, 74 | batch_size=256, 75 | device=device, 76 | ) 77 | 78 | def reward_fn(samples: List[str], **kwargs) -> List[float]: 79 | sentiments = list(map(get_positive_score, sentiment_fn(samples))) 80 | return sentiments 81 | 82 | # Take few words off of movies reviews as prompts 83 | imdb = load_dataset("imdb", split="train[:512]") 84 | prompts = [" ".join(review.split()[:4]) for review in imdb["text"]] 85 | 86 | trlx.train( 87 | reward_fn=reward_fn, 88 | prompts=prompts, 89 | eval_prompts=["I don't know much about Hungarian underground"] * 256, 90 | config=config, 91 | ) 92 | 93 | 94 | if __name__ == "__main__": 95 | hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1]) 96 | main(hparams) 97 | 
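The `RFTConfig` fields above (`n_generations_per_prompt`, `start_percentile`, `end_percentile`, `n_improve_steps`) describe the rejection-finetuning loop: sample several completions per prompt, score them with the reward function, and finetune only on the completions whose reward lands in the chosen percentile band. The snippet below is a rough, self-contained sketch of that selection step — it is not the actual `AccelerateRFTTrainer` code, and the helper name is made up for illustration.

```python
import numpy as np


def select_percentile_band(rewards, start_percentile=0.9, end_percentile=0.95):
    """Return indices of samples whose reward falls inside the [start, end] percentile band."""
    rewards = np.asarray(rewards, dtype=float)
    lo, hi = np.quantile(rewards, [start_percentile, end_percentile])
    return [i for i, r in enumerate(rewards) if lo <= r <= hi]


# e.g. keep only the top-scoring slice of four generations before the improve (finetuning) step
print(select_percentile_band([0.1, 0.4, 0.8, 0.95], 0.5, 1.0))  # -> [2, 3]
```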
-------------------------------------------------------------------------------- /examples/sft_sentiments.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | from typing import Dict, List 5 | 6 | from datasets import load_dataset 7 | from transformers import pipeline 8 | 9 | import trlx 10 | from trlx.data.default_configs import TRLConfig, default_sft_config 11 | 12 | 13 | def get_positive_score(scores): 14 | "Extract value associated with a positive sentiment from pipeline's output" 15 | return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] 16 | 17 | 18 | def main(hparams={}): 19 | # Merge sweep config with default config if given 20 | config = TRLConfig.update(default_sft_config().to_dict(), hparams) 21 | 22 | imdb = load_dataset("imdb", split="train+test") 23 | # Finetune on only positive reviews 24 | imdb = imdb.filter(lambda sample: sample["label"] == 1) 25 | 26 | sentiment_fn = pipeline( 27 | "sentiment-analysis", 28 | "lvwerra/distilbert-imdb", 29 | top_k=2, 30 | truncation=True, 31 | batch_size=256, 32 | device=0 if int(os.environ.get("LOCAL_RANK", 0)) == 0 else -1, 33 | ) 34 | 35 | def metric_fn(samples: List[str], **kwargs) -> Dict[str, List[float]]: 36 | sentiments = list(map(get_positive_score, sentiment_fn(samples))) 37 | return {"sentiments": sentiments} 38 | 39 | trainer = trlx.train( 40 | samples=imdb["text"], 41 | eval_prompts=["I don't know much about Hungarian underground"] * 64, 42 | metric_fn=metric_fn, 43 | config=config, 44 | ) 45 | trainer.save_pretrained("reviews-sft") 46 | 47 | 48 | if __name__ == "__main__": 49 | hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1]) 50 | main(hparams) 51 | -------------------------------------------------------------------------------- /examples/simulacra.py: -------------------------------------------------------------------------------- 1 | # Optimize prompts by training on prompts-ratings pairings dataset 2 | # taken from https://github.com/JD-P/simulacra-aesthetic-captions 3 | 4 | import os 5 | import sqlite3 6 | from urllib.request import urlretrieve 7 | 8 | from accelerate import Accelerator 9 | 10 | import trlx 11 | from trlx.data.default_configs import default_ilql_config 12 | 13 | url = "https://raw.githubusercontent.com/JD-P/simulacra-aesthetic-captions/main/sac_public_2022_06_29.sqlite" 14 | dbpath = "sac_public_2022_06_29.sqlite" 15 | 16 | if __name__ == "__main__": 17 | accelerator = Accelerator() 18 | if os.environ.get("LOCAL_RANK", "0") == "0" and not os.path.exists(dbpath): 19 | print(f"fetching {dbpath}") 20 | urlretrieve(url, dbpath) 21 | accelerator.wait_for_everyone() 22 | 23 | conn = sqlite3.connect(dbpath) 24 | c = conn.cursor() 25 | c.execute( 26 | "SELECT prompt, rating FROM ratings " 27 | "JOIN images ON images.id=ratings.iid " 28 | "JOIN generations ON images.gid=generations.id " 29 | "WHERE rating IS NOT NULL;" 30 | ) 31 | 32 | prompts, ratings = tuple(map(list, zip(*c.fetchall()))) 33 | trlx.train( 34 | config=default_ilql_config(), 35 | samples=prompts, 36 | rewards=ratings, 37 | eval_prompts=["An astronaut riding a horse"] * 64, 38 | ) 39 | -------------------------------------------------------------------------------- /examples/summarize_daily_cnn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CarperAI/trlx/3340c2f3a56d1d14fdd5f13ad575121fa26b6d92/examples/summarize_daily_cnn/__init__.py 
-------------------------------------------------------------------------------- /examples/summarize_daily_cnn/t5_summarize_daily_cnn.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | from datasets import load_dataset 4 | 5 | import trlx 6 | from trlx.data.configs import ( 7 | ModelConfig, 8 | OptimizerConfig, 9 | SchedulerConfig, 10 | TokenizerConfig, 11 | TrainConfig, 12 | TRLConfig, 13 | ) 14 | from trlx.models.modeling_ppo import PPOConfig 15 | 16 | try: 17 | import evaluate 18 | except ImportError: 19 | raise ImportError( 20 | "To run this example, please install the `evaluate` and `nltk` packages" "by running `pip install evaluate`" 21 | ) 22 | 23 | config = TRLConfig( 24 | train=TrainConfig( 25 | seq_length=612, 26 | epochs=100, 27 | total_steps=100000, 28 | batch_size=12, 29 | checkpoint_interval=10000, 30 | eval_interval=500, 31 | pipeline="PromptPipeline", 32 | trainer="AcceleratePPOTrainer", 33 | ), 34 | model=ModelConfig( 35 | model_path="google/flan-t5-large", 36 | model_arch_type="seq2seq", 37 | num_layers_unfrozen=2, 38 | ), 39 | tokenizer=TokenizerConfig( 40 | tokenizer_path="google/flan-t5-large", 41 | truncation_side="right", 42 | ), 43 | optimizer=OptimizerConfig( 44 | name="adamw", 45 | kwargs={ 46 | "lr": 1.0e-5, 47 | "betas": [0.9, 0.999], 48 | "eps": 1.0e-8, 49 | "weight_decay": 1.0e-6, 50 | }, 51 | ), 52 | scheduler=SchedulerConfig( 53 | name="cosine_annealing", 54 | kwargs={ 55 | "T_max": 10000, 56 | "eta_min": 1.0e-6, 57 | }, 58 | ), 59 | method=PPOConfig( 60 | name="PPOConfig", 61 | num_rollouts=512, 62 | chunk_size=12, 63 | ppo_epochs=4, 64 | init_kl_coef=0.05, 65 | target=6, 66 | horizon=10000, 67 | gamma=0.99, 68 | lam=0.95, 69 | cliprange=0.2, 70 | cliprange_value=0.2, 71 | vf_coef=1.0, 72 | scale_reward=None, 73 | ref_mean=None, 74 | ref_std=None, 75 | cliprange_reward=10, 76 | gen_kwargs={ 77 | "max_new_tokens": 100, 78 | }, 79 | gen_experience_kwargs={ 80 | "max_new_tokens": 100, 81 | "do_sample": True, 82 | "temperature": 1.0, 83 | "top_k": 50, 84 | "top_p": 0.95, 85 | }, 86 | ), 87 | ) 88 | 89 | 90 | meteor = evaluate.load("meteor") # use meteor as the reward function 91 | 92 | if __name__ == "__main__": 93 | 94 | def reward_fn(samples: List[str], prompts: List[str], outputs: List[str], original_summaries: List[str], **kwargs): 95 | scores = [ 96 | meteor.compute(predictions=[output.strip()], references=[original_summary])["meteor"] 97 | for (original_summary, output) in zip(original_summaries, outputs) 98 | ] 99 | return scores 100 | 101 | dataset = load_dataset("cnn_dailymail", "3.0.0", cache_dir="data") 102 | 103 | # take 20,000 samples from the training set as prompts for training 104 | prompts = dataset["train"]["article"][0:20000] 105 | summaries = dataset["train"]["highlights"][0:20000] 106 | prompts = ["Summarize: " + prompt for prompt in prompts] 107 | 108 | # take 1,000 samples from the validation set as prompts for evaluation 109 | val_prompts = ["Summarize: " + prompt for prompt in dataset["validation"]["article"][0:1000]] 110 | val_summaries = dataset["validation"]["highlights"][0:1000] 111 | 112 | trlx.train( 113 | reward_fn=reward_fn, 114 | prompts=[{"prompt": prompt, "original_summaries": summary} for prompt, summary in zip(prompts, summaries)], 115 | eval_prompts=[ 116 | {"prompt": prompt, "original_summaries": summary} for prompt, summary in zip(val_prompts, val_summaries) 117 | ], 118 | config=config, 119 | ) 120 | 
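Because `t5_summarize_daily_cnn.py` passes its prompts as dicts, the extra `original_summaries` entries are handed to `reward_fn` as a keyword argument, and each generated summary is scored against its reference with METEOR. The snippet below is a small, illustrative sanity check of that reward in isolation (it is not part of the example itself), useful for confirming the `evaluate` setup before launching a full run.

```python
import evaluate

meteor = evaluate.load("meteor")


def meteor_scores(outputs, references):
    """Score each generated summary against its reference, mirroring the example's reward_fn."""
    return [
        meteor.compute(predictions=[output.strip()], references=[reference])["meteor"]
        for output, reference in zip(outputs, references)
    ]


print(meteor_scores(["A cat sat on the mat."], ["The cat was sitting on the mat."]))
```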
-------------------------------------------------------------------------------- /examples/summarize_rlhf/README.md: -------------------------------------------------------------------------------- 1 | ## Learning to summarize from Human Feedback using `trlx` 2 | 3 | This example shows how to use `trlx` to train a summarization model using human feedback 4 | following the fine-tuning procedures described in Stiennon et al.'s, "[Learning to Summarize from human feedback](https://arxiv.org/abs/2009.01325)". 5 | 6 | 7 | Before running everything, we need some extra packages not included in the `trlx` dependency list. Specifically, we need HuggingFace's [`evaluate`](https://huggingface.co/docs/evaluate/index) package and Google's re-implementation of ROUGE, [`rouge-score`](https://github.com/google-research/google-research/tree/master/rouge). To install them, run `requirements.txt` in this example's root directory: 8 | 9 | ```bash 10 | pip install -r requirements.txt 11 | ``` 12 | 13 | ### Training Process 14 | 15 | For an in-depth description of the example, please refer to our [blog post](http://wandb.me/summarize-rlhf-trlx). We leave the following for a quick overview of the fine-tuning process and what scripts to run. 16 | 17 | 18 | 1. Train SFT: 19 | ```bash 20 | cd sft/ && deepspeed train_gptj_summarize.py 21 | ``` 22 | Checkpoint: [SFT](https://huggingface.co/CarperAI/openai_summarize_tldr_sft) 23 | 24 | 2. Train Reward Model: 25 | ```bash 26 | cd reward_model/ && deepspeed train_reward_model_gptj.py 27 | ``` 28 | Download reward model checkpoint: 29 | ```bash 30 | mkdir reward_model/rm_checkpoint 31 | wget https://huggingface.co/CarperAI/openai_summarize_tldr_rm_checkpoint/resolve/main/pytorch_model.bin -O reward_model/rm_checkpoint/pytorch_model.bin 32 | ``` 33 | 34 | 3. PPO training: 35 | ```bash 36 | accelerate launch --config_file configs/default_accelerate_config.yaml trlx_gptj_text_summarization.py 37 | ``` 38 | Checkpoint: [PPO](https://huggingface.co/CarperAI/openai_summarize_tldr_ppo) 39 | 40 | 🩹 Warning: This particular training configuration requires at least 55GB of VRAM and is setup to use two GPUs, decrease `batch_size` in case you're running out of memory. 41 | 42 | 43 | ### Results 44 | 45 | The following tables display ROUGE and reward scores on the test set of the TL;DR dataset between SFT and PPO models. 46 | 47 | 1. SFT vs PPO 48 | 49 | __ROUGE scores__ 50 | 51 | | Model | Rouge-1 | Rouge-2 | Rouge-L | Average | 52 | | --- | --- | --- | --- | --- | 53 | | SFT | 0.334 | 0.125 | 0.261 | 0.240 | 54 | | PPO | 0.323 | 0.109 | 0.238 | 0.223 | 55 | 56 | __Reward scores__ 57 | 58 | | Model | Average Reward | Reward $\Delta$ | 59 | | --- | --- | --- | 60 | | SFT | 2.729 | -0.181 | 61 | | PPO | 3.291 | +0.411 | 62 | 63 | 64 | 2. Examples of generated summaries can be found [here](https://wandb.ai/carperai/summarize_RLHF/runs/2uirt89a). 65 | 66 | 3. Check our blog post for metric logs and other results [here](http://wandb.me/summarize-rlhf-trlx). 67 | 68 | ## References 69 | 70 | 1. Nisan Stiennon, Long Ouyang, Jeff Wu, Daniel M. Ziegler, Ryan Lowe, Chelsea Voss, Alec Radford, Dario Amodei, Paul Christiano, "[Learning to Summarize from human feedback](https://arxiv.org/abs/2009.01325)", Neural Information Processing Systems, 2020. 
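
For quick inspection of the checkpoints linked above, the sketch below loads the SFT model for inference. It assumes the Hub repo is a standard causal-LM checkpoint and reuses the GPT-J tokenizer setup from the training scripts in this example; it is an illustrative snippet, not an official inference script.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained("CarperAI/openai_summarize_tldr_sft")
model.eval()

# A TL;DR-style prompt in the same format as the SFT data (placeholders only)
prompt = "SUBREDDIT: r/AskReddit\nTITLE: ...\nPOST: ...\nTL;DR:"
inputs = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=50, do_sample=True, top_k=50, top_p=0.95)
print(tokenizer.decode(output_ids[0, inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```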
71 | -------------------------------------------------------------------------------- /examples/summarize_rlhf/configs/default_accelerate_config.yaml: -------------------------------------------------------------------------------- 1 | command_file: null 2 | commands: null 3 | compute_environment: LOCAL_MACHINE 4 | deepspeed_config: 5 | deepspeed_config_file: configs/ds_config_trlx_gptj_summarize.json 6 | zero3_init_flag: false 7 | distributed_type: DEEPSPEED 8 | downcast_bf16: 'no' 9 | dynamo_config: {} 10 | fsdp_config: {} 11 | gpu_ids: null 12 | machine_rank: 0 13 | main_process_ip: null 14 | main_process_port: null 15 | main_training_function: main 16 | megatron_lm_config: {} 17 | num_machines: 1 18 | num_processes: 1 19 | rdzv_backend: static 20 | same_network: true 21 | tpu_name: null 22 | tpu_zone: null 23 | use_cpu: false 24 | -------------------------------------------------------------------------------- /examples/summarize_rlhf/configs/ds_config_trlx_gptj_summarize.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_micro_batch_size_per_gpu": 2, 3 | "gradient_accumulation_steps": 4, 4 | "fp16": { 5 | "enabled": true, 6 | "min_loss_scale": 0.5, 7 | "fp16_scale_tolerance": 0.25, 8 | "opt_level": "O2" 9 | }, 10 | "zero_optimization": { 11 | "stage": 2, 12 | "offload_param": { 13 | "device": "cpu" 14 | }, 15 | "offload_optimizer": { 16 | "device": "cpu" 17 | }, 18 | "allgather_partitions": true, 19 | "allgather_bucket_size": 5e8, 20 | "contiguous_gradients": true 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /examples/summarize_rlhf/ilql_summarize_t5.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from datasets import load_dataset 5 | from reward_model.reward_model import GPTRewardModel 6 | from transformers import AutoTokenizer 7 | 8 | import trlx 9 | from trlx.data.default_configs import ( 10 | ILQLConfig, 11 | ModelConfig, 12 | OptimizerConfig, 13 | SchedulerConfig, 14 | TokenizerConfig, 15 | TrainConfig, 16 | TRLConfig, 17 | ) 18 | 19 | default_config = TRLConfig( 20 | train=TrainConfig( 21 | seq_length=550, 22 | batch_size=8, 23 | epochs=100, 24 | total_steps=5000, 25 | checkpoint_interval=10000, 26 | eval_interval=1000, 27 | pipeline="PromptPipeline", 28 | trainer="AccelerateILQLTrainer", 29 | checkpoint_dir="ilql_summarize_t5", 30 | ), 31 | model=ModelConfig(model_path="pvduy/flant5-xl_openai_tldr_sft", num_layers_unfrozen=-1, model_arch_type="seq2seq"), 32 | tokenizer=TokenizerConfig(tokenizer_path="pvduy/flant5-xl_openai_tldr_sft", truncation_side="left"), 33 | optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=1e-6, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)), 34 | scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=5000, eta_min=1e-6)), 35 | method=ILQLConfig( 36 | name="ilqlconfig", 37 | tau=0.6, 38 | gamma=0.99, 39 | cql_scale=0.1, 40 | awac_scale=1, 41 | alpha=0.0001, 42 | beta=0, 43 | steps_for_target_q_sync=1, 44 | two_qs=True, 45 | gen_kwargs=dict(max_new_tokens=50, top_k=50, beta=[1, 2, 3], temperature=1.0), 46 | ), 47 | ) 48 | 49 | REWARD_CHECKPOINT_PATH = "reward_model/rm_checkpoint/pytorch_model.bin" 50 | if not os.path.exists(REWARD_CHECKPOINT_PATH): 51 | os.makedirs("reward_model/rm_checkpoint", exist_ok=True) 52 | os.system( 53 | f"wget -O {REWARD_CHECKPOINT_PATH} \ 54 | 
https://huggingface.co/CarperAI/openai_summarize_tldr_rm_checkpoint/resolve/main/pytorch_model.bin" 55 | ) 56 | SFT_MODEL_PATH = "CarperAI/openai_summarize_tldr_sft" 57 | 58 | 59 | def main(hparams={}): 60 | config = TRLConfig.update(default_config, hparams) 61 | 62 | rw_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") 63 | rw_tokenizer.pad_token = rw_tokenizer.eos_token 64 | rw_model = GPTRewardModel(SFT_MODEL_PATH) 65 | rw_model.load_state_dict(torch.load(REWARD_CHECKPOINT_PATH)) 66 | rw_model.half() 67 | rw_model.eval() 68 | rw_device = torch.device("cuda:{}".format(1)) # set reward model device 69 | rw_model.to(rw_device) 70 | 71 | def reward_fn(samples): 72 | scores_list = [] 73 | batch_size = 2 74 | for i in range(0, len(samples), batch_size): 75 | sub_samples = samples[i : i + batch_size] 76 | sub_samples = ["<|startoftext|>" + chosen + "<|endoftext|>" for chosen in sub_samples] 77 | encodings_dict = rw_tokenizer( 78 | sub_samples, 79 | truncation=True, 80 | max_length=config.train.seq_length, 81 | padding="max_length", 82 | return_tensors="pt", 83 | ) 84 | input_ids = encodings_dict["input_ids"].to(rw_device) 85 | attn_masks = encodings_dict["attention_mask"].to(rw_device) 86 | input_ids = input_ids.repeat(2, 1) 87 | attn_masks = attn_masks.repeat(2, 1) 88 | with torch.no_grad(): 89 | sub_scores = rw_model(input_ids=input_ids, attention_mask=attn_masks) 90 | scores_list.append(sub_scores["chosen_end_scores"]) 91 | scores = torch.cat(scores_list, dim=0) 92 | return scores 93 | 94 | def preprocess(sample): 95 | sample["prompt_output"] = [ 96 | [sample["prompt"] + " TL;DR:", sample["chosen"][7:]], 97 | [sample["prompt"] + " TL;DR:", sample["rejected"][7:]], 98 | ] 99 | sample["reward"] = [1, -1] 100 | return sample 101 | 102 | dataset = load_dataset("CarperAI/openai_summarize_comparisons") 103 | dataset["train"] = dataset["train"] 104 | dataset = dataset.map(preprocess) 105 | 106 | prompts_outputs = sum(dataset["train"]["prompt_output"], []) 107 | rewards = sum(dataset["train"]["reward"], []) 108 | val_dataset = load_dataset("CarperAI/openai_summarize_tldr", split="valid") 109 | eval_prompts = list(val_dataset["prompt"])[:1000] 110 | 111 | trlx.train( 112 | dataset=(prompts_outputs, rewards), 113 | metric_fn=lambda samples, **kwargs: {"rewards": reward_fn(samples)}, 114 | eval_prompts=eval_prompts, 115 | config=config, 116 | ) 117 | 118 | 119 | if __name__ == "__main__": 120 | import json 121 | import sys 122 | 123 | hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1]) 124 | main(hparams) 125 | -------------------------------------------------------------------------------- /examples/summarize_rlhf/requirements.txt: -------------------------------------------------------------------------------- 1 | evaluate>=0.4.0 2 | nltk>=3.8.1 3 | rouge-score>=0.1.2 4 | -------------------------------------------------------------------------------- /examples/summarize_rlhf/reward_model/ds_config_gpt_j.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": 32, 3 | "fp16": { 4 | "enabled": true, 5 | "min_loss_scale": 1, 6 | "opt_level": "O2" 7 | }, 8 | "zero_optimization": { 9 | "stage": 2, 10 | "offload_param": { 11 | "device": "cpu" 12 | }, 13 | "offload_optimizer": { 14 | "device": "cpu" 15 | }, 16 | "allgather_partitions": true, 17 | "allgather_bucket_size": 5e8, 18 | "contiguous_gradients": true 19 | }, 20 | "optimizer": { 21 | "type": "AdamW", 22 | "params": { 23 | "lr": 1e-5, 24 | "betas": [ 25 | 0.9, 26 
| 0.999 27 | ], 28 | "eps": 1e-08 29 | } 30 | }, 31 | "scheduler": { 32 | "type": "WarmupLR", 33 | "params": { 34 | "warmup_min_lr": 0, 35 | "warmup_max_lr": "auto", 36 | "warmup_num_steps": 100 37 | } 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /examples/summarize_rlhf/reward_model/gptj_reward_test.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | from datasets import load_dataset 6 | from reward_model import GPTRewardModel 7 | from torch.utils.data import Dataset 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer 10 | 11 | 12 | def set_seed(seed_val=42): 13 | random.seed(seed_val) 14 | np.random.seed(seed_val) 15 | torch.manual_seed(seed_val) 16 | torch.cuda.manual_seed_all(seed_val) 17 | 18 | 19 | def create_comparison_dataset(path="CarperAI/openai_summarize_comparisons", split="train"): 20 | dataset = load_dataset(path, split=split) 21 | if split == "test": 22 | dataset = dataset.select(range(5000)) 23 | 24 | pairs = [] 25 | for sample in tqdm(dataset): 26 | pair = {} 27 | prompt = sample["prompt"] 28 | chosen_summary = sample["chosen"] 29 | rejected_summary = sample["rejected"] 30 | if chosen_summary == rejected_summary: 31 | continue 32 | if len(chosen_summary.split()) < 5 or len(rejected_summary.split()) < 5: 33 | continue 34 | pair["chosen"] = prompt + "\n" + chosen_summary 35 | pair["rejected"] = prompt + "\n" + rejected_summary 36 | pairs.append(pair) 37 | return pairs 38 | 39 | 40 | class PairwiseDataset(Dataset): 41 | def __init__(self, pairs, tokenizer, max_length): 42 | self.chosen_input_ids = [] 43 | self.chosen_attn_masks = [] 44 | self.rejected_input_ids = [] 45 | self.rejected_attn_masks = [] 46 | for pair in pairs: 47 | chosen, rejected = pair["chosen"], pair["rejected"] 48 | chosen_encodings_dict = tokenizer( 49 | "<|startoftext|>" + chosen + "<|endoftext|>", 50 | truncation=True, 51 | max_length=max_length, 52 | padding="max_length", 53 | return_tensors="pt", 54 | ) 55 | rejected_encodings_dict = tokenizer( 56 | "<|startoftext|>" + rejected + "<|endoftext|>", 57 | truncation=True, 58 | max_length=max_length, 59 | padding="max_length", 60 | return_tensors="pt", 61 | ) 62 | if not torch.all(torch.eq(chosen_encodings_dict["input_ids"], rejected_encodings_dict["input_ids"])).item(): 63 | self.chosen_input_ids.append(chosen_encodings_dict["input_ids"]) 64 | self.chosen_attn_masks.append(chosen_encodings_dict["attention_mask"]) 65 | self.rejected_input_ids.append(rejected_encodings_dict["input_ids"]) 66 | self.rejected_attn_masks.append(rejected_encodings_dict["attention_mask"]) 67 | 68 | def __len__(self): 69 | return len(self.chosen_input_ids) 70 | 71 | def __getitem__(self, idx): 72 | return ( 73 | self.chosen_input_ids[idx], 74 | self.chosen_attn_masks[idx], 75 | self.rejected_input_ids[idx], 76 | self.rejected_attn_masks[idx], 77 | ) 78 | 79 | 80 | class DataCollatorReward: 81 | def __call__(self, data): 82 | batch = {} 83 | batch["input_ids"] = torch.cat([f[0] for f in data] + [f[2] for f in data]) 84 | batch["attention_mask"] = torch.cat([f[1] for f in data] + [f[3] for f in data]) 85 | batch["labels"] = torch.tensor([0] * len(data) + [1] * len(data)) 86 | return batch 87 | 88 | 89 | if __name__ == "__main__": 90 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") 91 | tokenizer.pad_token = tokenizer.eos_token 92 | PAD_ID = tokenizer(tokenizer.pad_token)["input_ids"][0] 93 | 94 | 
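# NOTE: this evaluation script expects the trained reward-model weights at
# rm_checkpoint/pytorch_model.bin (relative to the reward_model/ directory); step 2 of the
# summarize_rlhf README shows how to download them from the
# CarperAI/openai_summarize_tldr_rm_checkpoint repo before running this script.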
model = GPTRewardModel("CarperAI/openai_summarize_tldr_sft") 95 | model.load_state_dict(torch.load("rm_checkpoint/pytorch_model.bin")) 96 | max_length = 550 97 | val_pairs = create_comparison_dataset("CarperAI/openai_summarize_comparisons", "test") 98 | dev_dataset = PairwiseDataset(val_pairs, tokenizer, max_length=max_length) 99 | 100 | from torch.utils.data import DataLoader 101 | 102 | dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=6, collate_fn=DataCollatorReward()) 103 | model.cuda() 104 | model.eval() 105 | model.half() 106 | correct = 0 107 | chosen_list = [] 108 | reject_list = [] 109 | with torch.no_grad(): 110 | for step, batch in tqdm(enumerate(dev_dataloader), total=len(dev_dataloader)): 111 | for x in batch: 112 | batch[x] = batch[x].cuda() 113 | outputs = model(**batch) 114 | correct += sum(outputs["chosen_end_scores"] > outputs["rejected_end_scores"]) 115 | chosen_list.append(outputs["chosen_end_scores"].cpu()) 116 | reject_list.append(outputs["rejected_end_scores"].cpu()) 117 | print("Total accuracy: ", correct / len(dev_dataset)) 118 | -------------------------------------------------------------------------------- /examples/summarize_rlhf/reward_model/reward_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from transformers import AutoModelForCausalLM, AutoTokenizer 4 | 5 | 6 | class GPTRewardModel(nn.Module): 7 | def __init__(self, model_path): 8 | super().__init__() 9 | model = AutoModelForCausalLM.from_pretrained(model_path) 10 | self.config = model.config 11 | # `gpt-neo(x)` models use `hidden_size` attribute names instead of `n_embd`` 12 | self.config.n_embd = self.config.hidden_size if hasattr(self.config, "hidden_size") else self.config.n_embd 13 | self.transformer = model.transformer 14 | self.v_head = nn.Linear(self.config.n_embd, 1, bias=False) 15 | self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") 16 | self.tokenizer.pad_token = self.tokenizer.eos_token 17 | self.PAD_ID = self.tokenizer(self.tokenizer.pad_token)["input_ids"][0] 18 | 19 | def forward( 20 | self, 21 | input_ids=None, 22 | past_key_values=None, 23 | attention_mask=None, 24 | token_type_ids=None, 25 | position_ids=None, 26 | head_mask=None, 27 | inputs_embeds=None, 28 | mc_token_ids=None, 29 | labels=None, 30 | return_dict=False, 31 | output_attentions=False, 32 | output_hidden_states=False, 33 | ): 34 | loss = None 35 | transformer_outputs = self.transformer( 36 | input_ids, 37 | past_key_values=past_key_values, 38 | attention_mask=attention_mask, 39 | token_type_ids=token_type_ids, 40 | position_ids=position_ids, 41 | head_mask=head_mask, 42 | inputs_embeds=inputs_embeds, 43 | ) 44 | 45 | hidden_states = transformer_outputs[0] 46 | 47 | rewards = self.v_head(hidden_states).squeeze(-1) 48 | chosen_end_scores = [] 49 | rejected_end_scores = [] 50 | 51 | # Split the inputs and rewards into two parts, chosen and rejected 52 | assert len(input_ids.shape) == 2 53 | bs = input_ids.shape[0] // 2 54 | chosen = input_ids[:bs] 55 | rejected = input_ids[bs:] 56 | chosen_rewards = rewards[:bs] 57 | rejected_rewards = rewards[bs:] 58 | 59 | loss = 0 60 | inference = False 61 | for i in range(bs): 62 | if torch.all(torch.eq(chosen[i], rejected[i])).item(): 63 | c_inds = (chosen[i] == self.PAD_ID).nonzero() 64 | c_ind = c_inds[0].item() if len(c_inds) > 0 else chosen.shape[1] 65 | chosen_end_scores.append(chosen_rewards[i, c_ind - 1]) 66 | inference = True 67 | continue 68 | 69 | # 
Check if there is any padding otherwise take length of sequence 70 | c_inds = (chosen[i] == self.PAD_ID).nonzero() 71 | c_ind = c_inds[0].item() if len(c_inds) > 0 else chosen.shape[1] 72 | r_inds = (rejected[i] == self.PAD_ID).nonzero() 73 | r_ind = r_inds[0].item() if len(r_inds) > 0 else rejected.shape[1] 74 | end_ind = max(c_ind, r_ind) 75 | 76 | # Retrieve first index where trajectories diverge 77 | divergence_ind = (chosen[i] != rejected[i]).nonzero()[0] 78 | assert divergence_ind > 0 79 | 80 | # Index into the correct rewards 81 | c_truncated_reward = chosen_rewards[i][divergence_ind:end_ind] 82 | r_truncated_reward = rejected_rewards[i][divergence_ind:end_ind] 83 | 84 | # Append the last rewards to the list of end scores 85 | chosen_end_scores.append(c_truncated_reward[-1]) 86 | rejected_end_scores.append(r_truncated_reward[-1]) 87 | 88 | # Compute loss based on truncated rewards (ignore padding) 89 | loss += -torch.log(torch.sigmoid(c_truncated_reward - r_truncated_reward)).mean() 90 | loss = loss / bs 91 | 92 | if not inference: 93 | chosen_end_scores = torch.stack(chosen_end_scores) 94 | rejected_end_scores = torch.stack(rejected_end_scores) 95 | 96 | if inference: 97 | chosen_end_scores = torch.stack(chosen_end_scores) 98 | return {"chosen_end_scores": chosen_end_scores} 99 | 100 | return { 101 | "loss": loss, 102 | "chosen_end_scores": chosen_end_scores, 103 | "rejected_end_scores": rejected_end_scores, 104 | } 105 | -------------------------------------------------------------------------------- /examples/summarize_rlhf/reward_model/train_reward_model_gptj.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from datasets import load_dataset 5 | from reward_model import GPTRewardModel 6 | from torch.utils.data import Dataset 7 | from tqdm import tqdm 8 | from transformers import AutoTokenizer, Trainer, TrainingArguments 9 | 10 | 11 | def create_comparison_dataset(path="CarperAI/openai_summarize_comparisons", split="train"): 12 | dataset = load_dataset(path, split=split) 13 | pairs = [] 14 | for sample in tqdm(dataset): 15 | pair = {} 16 | prompt = sample["prompt"] 17 | chosen_summary = sample["chosen"] 18 | rejected_summary = sample["rejected"] 19 | if chosen_summary == rejected_summary: 20 | continue 21 | if len(chosen_summary.split()) < 5 or len(rejected_summary.split()) < 5: 22 | continue 23 | pair["chosen"] = prompt + "\n" + chosen_summary 24 | pair["rejected"] = prompt + "\n" + rejected_summary 25 | pairs.append(pair) 26 | return pairs 27 | 28 | 29 | class PairwiseDataset(Dataset): 30 | def __init__(self, pairs, tokenizer, max_length): 31 | self.chosen_input_ids = [] 32 | self.chosen_attn_masks = [] 33 | self.rejected_input_ids = [] 34 | self.rejected_attn_masks = [] 35 | for pair in tqdm(pairs): 36 | chosen, rejected = pair["chosen"], pair["rejected"] 37 | chosen_encodings_dict = tokenizer( 38 | "<|startoftext|>" + chosen + "<|endoftext|>", 39 | truncation=True, 40 | max_length=max_length, 41 | padding="max_length", 42 | return_tensors="pt", 43 | ) 44 | rejected_encodings_dict = tokenizer( 45 | "<|startoftext|>" + rejected + "<|endoftext|>", 46 | truncation=True, 47 | max_length=max_length, 48 | padding="max_length", 49 | return_tensors="pt", 50 | ) 51 | if not torch.all(torch.eq(chosen_encodings_dict["input_ids"], rejected_encodings_dict["input_ids"])).item(): 52 | self.chosen_input_ids.append(chosen_encodings_dict["input_ids"]) 53 | 
self.chosen_attn_masks.append(chosen_encodings_dict["attention_mask"]) 54 | self.rejected_input_ids.append(rejected_encodings_dict["input_ids"]) 55 | self.rejected_attn_masks.append(rejected_encodings_dict["attention_mask"]) 56 | 57 | def __len__(self): 58 | return len(self.chosen_input_ids) 59 | 60 | def __getitem__(self, idx): 61 | return ( 62 | self.chosen_input_ids[idx], 63 | self.chosen_attn_masks[idx], 64 | self.rejected_input_ids[idx], 65 | self.rejected_attn_masks[idx], 66 | ) 67 | 68 | 69 | class DataCollatorReward: 70 | def __call__(self, data): 71 | batch = {} 72 | batch["input_ids"] = torch.cat([f[0] for f in data] + [f[2] for f in data]) 73 | batch["attention_mask"] = torch.cat([f[1] for f in data] + [f[3] for f in data]) 74 | batch["labels"] = torch.tensor([0] * len(data) + [1] * len(data)) 75 | return batch 76 | 77 | 78 | def compute_metrics(eval_preds): 79 | chosen_end_scores = eval_preds.predictions[0] # chosen scores 80 | rejected_end_scores = eval_preds.predictions[1] # rejected scores 81 | 82 | result = {} 83 | acc = sum(chosen_end_scores > rejected_end_scores) / len(rejected_end_scores) 84 | result["accuracy"] = acc 85 | 86 | return result 87 | 88 | 89 | if __name__ == "__main__": 90 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") 91 | tokenizer.pad_token = tokenizer.eos_token 92 | 93 | if not os.path.exists("rm_checkpoint"): 94 | os.mkdir("rm_checkpoint") 95 | 96 | training_args = TrainingArguments( 97 | output_dir="rm_checkpoint/", 98 | num_train_epochs=5, 99 | logging_steps=10, 100 | gradient_accumulation_steps=4, 101 | save_strategy="steps", 102 | evaluation_strategy="steps", 103 | per_device_train_batch_size=1, 104 | per_device_eval_batch_size=1, 105 | eval_accumulation_steps=1, 106 | eval_steps=500, 107 | save_steps=500, 108 | warmup_steps=100, 109 | logging_dir="./logs", 110 | fp16=True, 111 | bf16=False, 112 | learning_rate=1e-5, 113 | deepspeed="ds_config_gpt_j.json", 114 | save_total_limit=1, 115 | ) 116 | 117 | # Initialize the reward model from the (supervised) fine-tuned GPT-J 118 | model = GPTRewardModel("CarperAI/openai_summarize_tldr_sft") 119 | 120 | # Freeze the first 70% of the hidden layers of the reward model backbone 121 | layers = model.transformer.h 122 | num_layers = len(layers) 123 | num_unfrozen = int(0.3 * num_layers) 124 | for layer in layers[:-num_unfrozen]: 125 | layer.requires_grad_(False) 126 | 127 | # Create the comparisons datasets 128 | data_path = "CarperAI/openai_summarize_comparisons" 129 | train_pairs = create_comparison_dataset(data_path, "train") 130 | val_pairs = create_comparison_dataset(data_path, "test") 131 | 132 | # Make pairwise datasets for training 133 | max_length = 550 134 | train_dataset = PairwiseDataset(train_pairs, tokenizer, max_length=max_length) 135 | val_dataset = PairwiseDataset(val_pairs, tokenizer, max_length=max_length) 136 | 137 | # Create the collator to gather batches of pairwise comparisons 138 | data_collator = DataCollatorReward() 139 | 140 | Trainer( 141 | model=model, 142 | args=training_args, 143 | train_dataset=train_dataset, 144 | compute_metrics=compute_metrics, 145 | eval_dataset=val_dataset, 146 | data_collator=data_collator, 147 | ).train() 148 | -------------------------------------------------------------------------------- /examples/summarize_rlhf/sft/ds_config_gptj.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": 128, 3 | "fp16": { 4 | "enabled": true, 5 | "min_loss_scale": 1, 6 | "opt_level": "O2" 7 
| }, 8 | "zero_optimization": { 9 | "stage": 2, 10 | "offload_param": { 11 | "device": "cpu" 12 | }, 13 | "offload_optimizer": { 14 | "device": "cpu" 15 | }, 16 | "allgather_partitions": true, 17 | "allgather_bucket_size": 5e8, 18 | "contiguous_gradients": true 19 | }, 20 | "optimizer": { 21 | "type": "AdamW", 22 | "params": { 23 | "lr": 1e-05, 24 | "betas": [ 25 | 0.9, 26 | 0.95 27 | ], 28 | "eps": 1e-08 29 | } 30 | }, 31 | "scheduler": { 32 | "type": "WarmupLR", 33 | "params": { 34 | "warmup_min_lr": 0, 35 | "warmup_max_lr": 1e-05, 36 | "warmup_num_steps": "auto" 37 | } 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /examples/summarize_rlhf/sft/summarize_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | 3 | import pandas as pd 4 | import torch 5 | from datasets import load_dataset 6 | from torch.utils.data import Dataset 7 | 8 | 9 | def get_dataset_from_jsonl(jsonl_file, return_summary=True): 10 | # if return_summary is True, return a list of posts with summary concatenated 11 | # if return_summary is False, return a list of posts and a list of summaries 12 | with open(jsonl_file, "r") as f: 13 | dataset = [json.loads(line) for line in f] 14 | post_list = [] 15 | summary_list = [] 16 | for d in dataset: 17 | if return_summary: 18 | post = f"SUBREDDIT: r/{d['subreddit']}\nTITLE: {d['title']}\nPOST: {d['post']}\nTL;DR: {d['summary']}" 19 | else: 20 | post = f"SUBREDDIT: r/{d['subreddit']}\nTITLE: {d['title']}\nPOST: {d['post']}\nTL;DR: " 21 | summary_list.append(d["summary"]) 22 | post_list.append(post) 23 | if not return_summary: 24 | return post_list, summary_list 25 | return post_list 26 | 27 | 28 | class TLDRDataset(Dataset): 29 | def __init__(self, train_path, tokenizer, split, max_length=550): 30 | self.post_list = [] 31 | dataset = load_dataset(train_path, split=split) 32 | for sample in dataset: 33 | self.post_list.append(sample["prompt"] + sample["label"]) 34 | if "valid" in split: 35 | self.post_list = self.post_list[0:2000] 36 | self.tokenizer = tokenizer 37 | self.max_length = max_length 38 | self.input_ids = [] 39 | self.attn_masks = [] 40 | 41 | def __len__(self): 42 | return len(self.post_list) 43 | 44 | def __getitem__(self, idx): 45 | txt = self.post_list[idx] 46 | encodings_dict = self.tokenizer(txt, truncation=True, max_length=self.max_length, padding="max_length") 47 | input_ids = torch.tensor(encodings_dict["input_ids"]) 48 | attn_masks = torch.tensor(encodings_dict["attention_mask"]) 49 | 50 | return { 51 | "input_ids": input_ids, 52 | "attention_mask": attn_masks, 53 | "labels": input_ids, 54 | } 55 | 56 | 57 | class ComparisonDataset(Dataset): 58 | def __init__(self, comparison_path, tokenizer, max_length=550): 59 | with open(comparison_path, "r") as f: 60 | dataset = [json.loads(line) for line in f] 61 | 62 | self.tokenizer = tokenizer 63 | self.post_list = [] 64 | self.summaries_0 = [] 65 | self.summaries_1 = [] 66 | self.labels = [] 67 | self.max_length = max_length 68 | 69 | def make_text(post, summarize): 70 | return f"SUBREDDIT: r/{post['subreddit']}\nTITLE: {post['title']}\nPOST: {post['post']}\nTL;DR: {summarize}" 71 | 72 | for sample in dataset: # chosen summary is always the first one 73 | self.post_list.append(sample["info"]["post"]) 74 | # NOTE: The chosen summary is always the first one, i.e. 
`sample["summaries"][0]` 75 | if sample["choice"] == 0: 76 | self.summaries_0.append(make_text(sample["info"], sample["summaries"][0]["text"])) 77 | self.summaries_1.append(make_text(sample["info"], sample["summaries"][1]["text"])) 78 | else: 79 | self.summaries_0.append(make_text(sample["info"], sample["summaries"][1]["text"])) 80 | self.summaries_1.append(make_text(sample["info"], sample["summaries"][0]["text"])) 81 | self.labels.append(0) 82 | 83 | def __len__(self): 84 | return len(self.post_list) 85 | 86 | def __getitem__(self, idx): 87 | summ0 = self.summaries_0[idx] 88 | summ1 = self.summaries_1[idx] 89 | encodings_dict = self.tokenizer( 90 | [summ0, summ1], 91 | truncation=True, 92 | max_length=self.max_length, 93 | padding="max_length", 94 | ) 95 | input_ids = torch.tensor(encodings_dict["input_ids"]) 96 | attention_mask = torch.tensor(encodings_dict["attention_mask"]) 97 | return {"input_ids": input_ids, "attention_mask": attention_mask} 98 | 99 | 100 | class AllSummDataset(Dataset): 101 | def __init__(self, train_path, tokenizer, split, max_length=1024): 102 | df = pd.read_parquet(train_path) 103 | if split == "valid": 104 | df = df.sample(n=5000) 105 | self.summarizes = [] 106 | for i, row in df.iterrows(): 107 | self.summarizes.append(f"Summarize: {row['text']}. TL;DR: {row['summary']}") 108 | self.tokenizer = tokenizer 109 | self.max_length = max_length 110 | self.input_ids = [] 111 | self.attn_masks = [] 112 | 113 | def __len__(self): 114 | return len(self.summarizes) 115 | 116 | def __getitem__(self, idx): 117 | txt = self.summarizes[idx] 118 | encodings_dict = self.tokenizer(txt, truncation=True, max_length=self.max_length, padding="max_length") 119 | input_ids = torch.tensor(encodings_dict["input_ids"]) 120 | attn_masks = torch.tensor(encodings_dict["attention_mask"]) 121 | 122 | return { 123 | "input_ids": input_ids, 124 | "attention_mask": attn_masks, 125 | "labels": input_ids, 126 | } 127 | -------------------------------------------------------------------------------- /examples/summarize_rlhf/sft/train_gptj_summarize.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import evaluate 4 | import numpy as np 5 | import torch 6 | from summarize_dataset import TLDRDataset 7 | from transformers import ( 8 | AutoModelForCausalLM, 9 | AutoTokenizer, 10 | Trainer, 11 | TrainingArguments, 12 | default_data_collator, 13 | ) 14 | 15 | 16 | def set_seed(seed_val=42): 17 | random.seed(seed_val) 18 | np.random.seed(seed_val) 19 | torch.manual_seed(seed_val) 20 | torch.cuda.manual_seed_all(seed_val) 21 | 22 | 23 | if __name__ == "__main__": 24 | output_dir = "gptj-supervised-summarize-checkpoint" 25 | train_batch_size = 16 26 | gradient_accumulation_steps = 1 27 | learning_rate = 1e-5 28 | eval_batch_size = 1 29 | eval_steps = 500 30 | max_input_length = 550 31 | save_steps = 1000 32 | num_train_epochs = 5 33 | random.seed(42) 34 | 35 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") 36 | model = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", use_cache=False) 37 | tokenizer.pad_token = tokenizer.eos_token 38 | model.resize_token_embeddings(len(tokenizer)) 39 | tokenizer.pad_token_id = tokenizer.eos_token_id 40 | model.config.end_token_id = tokenizer.eos_token_id 41 | model.config.pad_token_id = model.config.eos_token_id 42 | 43 | # Set up the datasets 44 | data_path = "CarperAI/openai_summarize_tldr" 45 | train_dataset = TLDRDataset( 46 | data_path, 47 | tokenizer, 48 | "train", 49 | 
max_length=max_input_length, 50 | ) 51 | dev_dataset = TLDRDataset( 52 | data_path, 53 | tokenizer, 54 | "valid", 55 | max_length=max_input_length, 56 | ) 57 | 58 | # Set up the metric 59 | rouge = evaluate.load("rouge") 60 | 61 | def compute_metrics(eval_preds): 62 | labels_ids = eval_preds.label_ids 63 | pred_ids = eval_preds.predictions 64 | pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True) 65 | label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True) 66 | result = rouge.compute(predictions=pred_str, references=label_str) 67 | return result 68 | 69 | # Create a preprocessing function to extract out the proper logits from the model output 70 | def preprocess_logits_for_metrics(logits, labels): 71 | if isinstance(logits, tuple): 72 | logits = logits[0] 73 | return logits.argmax(dim=-1) 74 | 75 | # Prepare the trainer and start training 76 | training_args = TrainingArguments( 77 | output_dir=output_dir, 78 | evaluation_strategy="steps", 79 | eval_accumulation_steps=1, 80 | learning_rate=learning_rate, 81 | per_device_train_batch_size=train_batch_size, 82 | per_device_eval_batch_size=eval_batch_size, 83 | gradient_checkpointing=True, 84 | half_precision_backend=True, 85 | fp16=True, 86 | adam_beta1=0.9, 87 | adam_beta2=0.95, 88 | gradient_accumulation_steps=gradient_accumulation_steps, 89 | num_train_epochs=num_train_epochs, 90 | warmup_steps=100, 91 | eval_steps=eval_steps, 92 | save_steps=save_steps, 93 | load_best_model_at_end=True, 94 | logging_steps=50, 95 | deepspeed="./ds_config_gptj.json", 96 | ) 97 | 98 | trainer = Trainer( 99 | model=model, 100 | args=training_args, 101 | train_dataset=train_dataset, 102 | eval_dataset=dev_dataset, 103 | compute_metrics=compute_metrics, 104 | data_collator=default_data_collator, 105 | preprocess_logits_for_metrics=preprocess_logits_for_metrics, 106 | ) 107 | trainer.train() 108 | trainer.save_model(output_dir) 109 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.isort] 6 | multi_line_output = 3 7 | profile = "black" 8 | 9 | [tool.black] 10 | line-length = 120 11 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu118 2 | accelerate==0.22.0 3 | aiohttp==3.8.5 4 | aiosignal==1.3.1 5 | appdirs==1.4.4 6 | async-timeout==4.0.3 7 | attrs==23.1.0 8 | cattrs==23.1.2 9 | certifi==2023.7.22 10 | charset-normalizer==3.2.0 11 | click==8.1.7 12 | cmake==3.25.0 13 | datasets==2.14.4 14 | deepspeed==0.10.1 15 | dill==0.3.7 16 | docker-pycreds==0.4.0 17 | einops==0.6.1 18 | exceptiongroup==1.1.3 19 | filelock==3.9.0 20 | frozenlist==1.4.0 21 | fsspec==2023.6.0 22 | gitdb==4.0.10 23 | GitPython==3.1.32 24 | grpcio==1.57.0 25 | hjson==3.1.0 26 | huggingface-hub==0.16.4 27 | idna==3.4 28 | Jinja2==3.1.2 29 | jsonschema==4.19.0 30 | jsonschema-specifications==2023.7.1 31 | lit==15.0.7 32 | markdown-it-py==3.0.0 33 | MarkupSafe==2.1.2 34 | mdurl==0.1.2 35 | mpmath==1.2.1 36 | msgpack==1.0.5 37 | multidict==6.0.4 38 | multiprocess==0.70.15 39 | networkx==3.0 40 | ninja==1.11.1 41 | numpy==1.25.2 42 | packaging==23.1 43 | pandas==2.0.3 44 | pathtools==0.1.2 45 | peft==0.5.0 46 | protobuf==4.24.2 47 | psutil==5.9.5 
48 | py-cpuinfo==9.0.0 49 | pyarrow==13.0.0 50 | pydantic==1.10.12 51 | Pygments==2.16.1 52 | python-dateutil==2.8.2 53 | python-rapidjson==1.10 54 | pytz==2023.3 55 | PyYAML==6.0.1 56 | ray==2.6.3 57 | referencing==0.30.2 58 | regex==2023.8.8 59 | requests==2.31.0 60 | rich==13.5.2 61 | rpds-py==0.9.2 62 | safetensors==0.3.3 63 | sentry-sdk==1.29.2 64 | setproctitle==1.3.2 65 | six==1.16.0 66 | smmap==5.0.0 67 | sympy==1.11.1 68 | tabulate==0.9.0 69 | tokenizers==0.13.3 70 | torch==2.0.1+cu118 71 | torchtyping==0.1.4 72 | tqdm==4.66.1 73 | transformers==4.32.0 74 | triton==2.0.0 75 | tritonclient==2.36.0 76 | typeguard==4.1.3 77 | typing_extensions==4.7.1 78 | tzdata==2023.3 79 | urllib3==2.0.4 80 | wandb==0.15.8 81 | xxhash==3.3.0 82 | yarl==1.9.2 83 | -------------------------------------------------------------------------------- /scripts/accelerate_train_example.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -exuo pipefail 4 | 5 | # HOSTNAMES MASTER_ADDR MASTER_PORT COUNT_NODE are coming from the main script 6 | H=`hostname` 7 | RANK=`echo -e $HOSTNAMES | python3 -c "import sys;[sys.stdout.write(str(i)) for i,line in enumerate(next(sys.stdin).split(' ')) if line.strip() == '$H'.strip()]"` 8 | 9 | CONFIG_FILE=${1-configs/deepspeed/zero2-bf16.yaml} # relative to TRLX_DIR 10 | CONDA_DIR=${2:-/admin/home-amuzio/miniconda3} 11 | CONDA_ENV_NAME=${3:-trlx} 12 | 13 | # This script assumes the following: 14 | # (1) a conda environment named $CONDA_ENV_NAME 15 | # (2) It is being run from the $TRLX_DIR directory 16 | # If using venv, you can remove the conda stuff and just activate the venv directly 17 | set +x 18 | export PATH="$CONDA_DIR/condabin:$PATH" 19 | source $CONDA_DIR/etc/profile.d/conda.sh 20 | conda activate $CONDA_ENV_NAME 21 | set -x 22 | 23 | 24 | accelerate launch \ 25 | --num_processes $((8 * $COUNT_NODE)) \ 26 | --num_machines $COUNT_NODE \ 27 | --machine_rank $RANK \ 28 | --main_process_ip $MASTER_ADDR \ 29 | --config_file $CONFIG_FILE \ 30 | examples/ilql_sentiments.py 31 | -------------------------------------------------------------------------------- /scripts/benchmark.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | origin=CarperAI/trlx 5 | branch=main 6 | entity=null 7 | only_hash=false 8 | only_tiny=false 9 | 10 | while [[ "$#" -gt 0 ]]; do 11 | case $1 in 12 | --origin) origin="$2"; shift ;; 13 | --branch) branch="$2"; shift ;; 14 | --public) entity='"CarperAI"' ;; 15 | --only_hash) only_hash=true ;; 16 | --only_tiny) only_tiny=true ;; 17 | *) echo "Unknown parameter passed: $1"; exit 1 ;; 18 | esac 19 | shift 20 | done 21 | 22 | dir=`mktemp -d -p .` 23 | if [ ! -d "$dir" ]; then 24 | echo "Couldn't create a temporary directory, aborting" 25 | exit 1 26 | fi 27 | 28 | cd $dir 29 | trap "rm -rf ../$dir" EXIT 30 | 31 | git clone --depth 1 --single-branch -b $branch https://github.com/$origin . 32 | 33 | hash=`find . -not \( -path ./.git -prune \) -not -name "*.md" -type f -print0 | sort -z | xargs -0 sha1sum | sha1sum | cut -f1 -d" "` 34 | git_hash=`git log --format=%h/%s/%as -n1` 35 | 36 | if [ "$only_hash" = true ]; then 37 | echo "$hash" 38 | echo "$git_hash" 39 | exit 0 40 | fi 41 | 42 | python -m venv venv 43 | . venv/bin/activate 44 | python -m pip install pip --upgrade 45 | pip install -r requirements.txt 46 | pip install -e . 
47 | 48 | args='{"train": {"project_name": "trlx-references", "entity_name": '$entity', "tags": ["'$hash'"]}}' 49 | python examples/randomwalks/ilql_randomwalks.py "$args" 50 | python examples/randomwalks/ppo_randomwalks.py "$args" 51 | 52 | if [ "$only_tiny" = true ]; then 53 | exit 0 54 | fi 55 | 56 | rm -rf ../benchmark_logs && mkdir ../benchmark_logs 57 | 58 | CUDA_VISIBLE_DEVICES=0 accelerate launch --num_processes 1 --config_file configs/accelerate/zero2-bf16.yaml --main_process_port 8880 examples/ppo_sentiments.py "$args" > ../benchmark_logs/ppo_sentiments.log 2>&1 & 59 | CUDA_VISIBLE_DEVICES=1 accelerate launch --num_processes 1 --config_file configs/accelerate/zero2-bf16.yaml --main_process_port 8881 examples/sft_sentiments.py "$args" > ../benchmark_logs/sft_sentiments.log 2>&1 & 60 | CUDA_VISIBLE_DEVICES=2 accelerate launch --num_processes 1 --config_file configs/accelerate/zero2-bf16.yaml --main_process_port 8882 examples/ilql_sentiments.py "$args" > ../benchmark_logs/ilql_sentiments.log 2>&1 & 61 | CUDA_VISIBLE_DEVICES=3 accelerate launch --num_processes 1 --config_file configs/accelerate/zero2-bf16.yaml --main_process_port 8883 examples/ppo_sentiments_t5.py "$args" > ../benchmark_logs/ppo_sentiments_t5.log 2>&1 & 62 | 63 | wait 64 | 65 | args='{"train": {"total_steps": 1500, "seq_length": 512, "project_name": "trlx-references", "entity_name": '$entity', "tags": ["'$hash'"]}}' 66 | CONFIG_NAME=6B accelerate launch --num_processes 7 --config_file configs/accelerate/zero2-bf16.yaml examples/hh/ppo_hh.py "$args" 67 | -------------------------------------------------------------------------------- /scripts/slurm_train.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=trlx 3 | #SBATCH --nodes=1 4 | #SBATCH --ntasks-per-node=1 5 | #SBATCH --partition=g40 6 | #SBATCH --mem=0 7 | #SBATCH --output=logs/%x_%j.out 8 | #SBATCH --error=logs/%x_%j.err 9 | #SBATCH --comment=carperai 10 | #SBATCH --exclusive 11 | 12 | # Example usage: 13 | # sbatch slurm_train.sh TRLX_DIR 14 | 15 | set -exuo pipefail 16 | 17 | export LD_LIBRARY_PATH=/opt/aws-ofi-nccl/lib:/opt/amazon/efa/lib64:/usr/local/cuda-11.0/efa/lib:/usr/local/cuda-11.0/lib:/usr/local/cuda-11.0/lib64:/usr/local/cuda-11.0:/opt/nccl/build/lib:/opt/aws-ofi-nccl-install/lib:/opt/aws-ofi-nccl/lib:$LD_LIBRARY_PATH 18 | export PATH=/opt/amazon/efa/bin:/opt/amazon/openmpi/bin:$PATH 19 | 20 | export NCCL_DEBUG=WARN 21 | export NCCL_PROTO=simple 22 | export FI_EFA_FORK_SAFE=1 23 | export FI_LOG_LEVEL=1 24 | export FI_EFA_USE_DEVICE_RDMA=1 # use for p4dn 25 | export FI_EFA_ENABLE_SHM_TRANSFER=0 26 | export FI_PROVIDER=efa 27 | export FI_EFA_TX_MIN_CREDITS=64 28 | # export CUDA_LAUNCH_BLOCKING=1 29 | 30 | export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"` 31 | export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 32 | export MASTER_PORT=1234 33 | export COUNT_NODE=`scontrol show hostnames "$SLURM_JOB_NODELIST" | wc -l` 34 | 35 | TRLX_DIR=${1:-/fsx/home-amuzio/trlx} 36 | TRAIN_SCRIPT=${2-scripts/accelerate_train_example.sh} # relative to TRLX_DIR 37 | CONFIG_FILE=${3-configs/accelerate/zero2-bf16.yaml} # relative to TRLX_DIR 38 | CONDA_DIR=${4:-/admin/home-amuzio/miniconda3} 39 | CONDA_ENV_NAME=${5:-trlx} 40 | 41 | pushd $TRLX_DIR 42 | srun --comment carperai $TRAIN_SCRIPT \ 43 | $CONFIG_FILE \ 44 | $CONDA_DIR \ 45 | $CONDA_ENV_NAME 46 | -------------------------------------------------------------------------------- /scripts/sweep-cw.sh:
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | #SBATCH --job-name=trlx-sweep 3 | #SBATCH --account=trlx 4 | #SBATCH --partition=a100-cu117 5 | #SBATCH --nodes=2 6 | #SBATCH --ntasks-per-node=1 7 | #SBATCH --mem=0 8 | #SBATCH --output=%j 9 | #SBATCH --exclusive 10 | 11 | export NCCL_DEBUG=WARN 12 | export NCCL_PROTO=simple 13 | export FI_EFA_FORK_SAFE=1 14 | export FI_LOG_LEVEL=1 15 | export FI_EFA_USE_DEVICE_RDMA=1 16 | export FI_EFA_ENABLE_SHM_TRANSFER=0 17 | export FI_PROVIDER=efa 18 | export FI_EFA_TX_MIN_CREDITS=64 19 | # export CUDA_LAUNCH_BLOCKING=1 20 | 21 | export HOSTNAMES=`scontrol show hostnames "$SLURM_JOB_NODELIST"` 22 | export MASTER_ADDR=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1) 23 | 24 | cd $TRLX 25 | source $TRLX/venv-with-pinned-ray/bin/activate 26 | 27 | ray start --head --port=6379 & 28 | 29 | export HOSTNAMES=($HOSTNAMES) 30 | for node in ${HOSTNAMES[@]:1}; do 31 | echo "Starting ray worker @ $node" 32 | srun --nodes=1 --ntasks=1 -w "$node" ray start --address $MASTER_ADDR:6379 --block & 33 | done 34 | 35 | sleep 10 36 | ray status 37 | 38 | NUM_GPUS=16 39 | python -m trlx.sweep -y --config configs/sweeps/ppo_sweep.yml --accelerate_config configs/accelerate/zero2-bf16.yaml --num_gpus $NUM_GPUS examples/ppo_sentiments.py 40 | # python -m trlx.sweep -y --config configs/sweeps/ilql_sweep.yml --default_config configs/ilql_config.yml --accelerate_config configs/accelerate/zero2-bf16.yaml --num_gpus $NUM_GPUS examples/ilql_sentiments.py 41 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = trlx 3 | author = Alex Havrilla 4 | version = 0.7.0 5 | url = https://github.com/CarperAI/trlx 6 | description = A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF) 7 | long_description = file: README.md 8 | long_description_content_type = text/markdown 9 | license = MIT 10 | 11 | [options] 12 | packages = find: 13 | install_requires = 14 | accelerate>=0.17.1 15 | attrs>=22.1.0 16 | cattrs>=22.2.0 17 | datasets 18 | deepspeed>=0.8.1 19 | einops>=0.4.1 20 | numpy>=1.23.2 21 | torchtyping 22 | transformers>=4.27.1 23 | tqdm 24 | rich 25 | wandb>=0.13.5 26 | ray>=2.4.0 27 | tabulate>=0.9.0 28 | networkx 29 | tritonclient 30 | 31 | [options.extras_require] 32 | bnb = 33 | bitsandbytes 34 | scipy # fix(bnb): Remove when `bitsandbytes` adds this dependency 35 | dev = 36 | black 37 | hypothesis 38 | isort 39 | flake8 40 | pre-commit 41 | pytest 42 | pytest-cov 43 | 44 | [options.packages.find] 45 | exclude = 46 | docs* 47 | tests* 48 | 49 | [flake8] 50 | max-complexity = 10 51 | max-line-length = 127 52 | # flake8 error codes: https://flake8.pycqa.org/en/latest/user/error-codes.html 53 | # pycodestyle codes: https://pycodestyle.pycqa.org/en/latest/intro.html#error-codes 54 | # E203 # whitespace before ‘,’, ‘;’, or ‘:’ 55 | # E741 # do not use variables named ‘l’, ‘O’, or ‘I’ 56 | # F401 # module imported but unused 57 | # F821 # undefined name name 58 | # W503 # line break before binary operator 59 | # W605 # invalid escape sequence ‘x’ 60 | ignore = 61 | E203 62 | E741 63 | F821 64 | W503 65 | W605 66 | per-file-ignores = __init__.py:F401,loading.py:F401 67 | exclude = 68 | .git 69 | __pycache__ 70 | docs/source/conf.py 71 | build 72 | dist 73 | 
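The [options] section of setup.cfg above declares version floors (e.g. accelerate>=0.17.1, transformers>=4.27.1) rather than exact pins. As a minimal sketch, not part of the repository, here is one way to check a few of those floors against the active environment; the floor subset and the check_floors helper are illustrative choices, and packaging is assumed to be available (it is pinned in requirements.txt above).

# Illustrative sketch (not part of the repository): compare installed versions
# against a few of the minimum versions declared in setup.cfg's install_requires.
from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version

# Floors copied from setup.cfg's [options] install_requires
FLOORS = {"accelerate": "0.17.1", "transformers": "4.27.1", "deepspeed": "0.8.1", "ray": "2.4.0"}


def check_floors(floors):
    for name, floor in floors.items():
        try:
            installed = Version(version(name))
        except PackageNotFoundError:
            print(f"{name}: not installed (setup.cfg requires >={floor})")
            continue
        status = "ok" if installed >= Version(floor) else "too old"
        print(f"{name}: {installed} vs >={floor} -> {status}")


if __name__ == "__main__":
    check_floors(FLOORS)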
-------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup() 4 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CarperAI/trlx/3340c2f3a56d1d14fdd5f13ad575121fa26b6d92/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_configs.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | 4 | from trlx.data.configs import TRLConfig 5 | 6 | 7 | def _get_config_dirs(dir: str, config_dir_name: str = "configs") -> List[str]: 8 | """Returns all sub-directories of `dir` named `configs`.""" 9 | config_dirs = [] 10 | for root, dirs, _ in os.walk(dir): 11 | for d in dirs: 12 | if d == config_dir_name: 13 | config_dirs.append(os.path.join(root, d)) 14 | return config_dirs 15 | 16 | 17 | def _get_yaml_filepaths(dir: str) -> List[str]: 18 | """Returns a list of `yml` filepaths in `dir`.""" 19 | filepaths = [] 20 | for file in os.listdir(dir): 21 | if file.endswith(".yml"): 22 | filepaths.append(os.path.join(dir, file)) 23 | return filepaths 24 | 25 | 26 | def test_repo_trl_configs(): 27 | """Tests to ensure all default configs in the repository are valid.""" 28 | config_dirs = ["configs", *_get_config_dirs("examples")] 29 | config_files = sum(map(_get_yaml_filepaths, config_dirs), []) # sum for flat-map behavior 30 | for file in config_files: 31 | assert os.path.isfile(file), f"Config file {file} does not exist." 32 | assert file.endswith(".yml"), f"Config file {file} is not a yaml file." 33 | try: 34 | config = TRLConfig.load_yaml(file) 35 | assert ( 36 | config.train.entity_name is None 37 | ), f"Unexpected entity name in config file `{file}`. Remove before pushing to repo." 
38 | except Exception as e: 39 | assert False, f"Failed to load config file `{file}` with error `{e}`" 40 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import accelerate 4 | import pytest 5 | import torch 6 | import transformers 7 | 8 | import trlx.utils as utils 9 | import trlx.utils.modeling as modeling_utils 10 | 11 | try: 12 | import bitsandbytes 13 | 14 | HAS_BNB = True 15 | except ImportError: 16 | HAS_BNB = False 17 | 18 | 19 | # Test general utils 20 | 21 | 22 | @pytest.mark.parametrize( 23 | "optimizer_name", 24 | [o.value for o in utils.OptimizerName], 25 | ) 26 | def test_optimizer_class_getters(optimizer_name: str): 27 | try: 28 | _class = utils.get_optimizer_class(optimizer_name) 29 | except Exception as e: 30 | assert False, "Failed to get optimizer class with error: " + str(e) 31 | 32 | # Hard-check for one of the optimizers 33 | _class = utils.get_optimizer_class("adamw") 34 | assert _class == torch.optim.AdamW 35 | if HAS_BNB: 36 | _bnb_class = utils.get_optimizer_class("adamw_8bit_bnb") 37 | assert _bnb_class == bitsandbytes.optim.AdamW8bit 38 | 39 | 40 | @pytest.mark.parametrize( 41 | "scheduler_name", 42 | [o.value for o in utils.SchedulerName], 43 | ) 44 | def test_scheduler_class_getters(scheduler_name: str): 45 | try: 46 | _class = utils.get_scheduler_class(scheduler_name) 47 | except Exception as e: 48 | assert False, "Failed to get scheduler class with error: " + str(e) 49 | 50 | # Hard-check for one of the schedulers 51 | _class = utils.get_scheduler_class("cosine_annealing") 52 | assert _class == torch.optim.lr_scheduler.CosineAnnealingLR 53 | 54 | 55 | # Test modeling utils 56 | 57 | 58 | @pytest.mark.parametrize( 59 | "model_name", 60 | [ 61 | "EleutherAI/gpt-j-6B", 62 | "EleutherAI/gpt-neox-20b", 63 | "gpt2", 64 | "facebook/opt-1.3b", 65 | ], 66 | ) 67 | def test_hf_attr_getters(model_name: str): 68 | with accelerate.init_empty_weights(): 69 | config = transformers.AutoConfig.from_pretrained(model_name) 70 | arch = transformers.AutoModelForCausalLM.from_config(config) 71 | 72 | arch_getters = [ 73 | modeling_utils.hf_get_decoder, 74 | modeling_utils.hf_get_decoder_final_norm, 75 | modeling_utils.hf_get_decoder_blocks, 76 | modeling_utils.hf_get_lm_head, 77 | ] 78 | for get in arch_getters: 79 | try: 80 | get(arch) 81 | except Exception as e: 82 | assert False, "Failed to get model attribute with error: " + str(e) 83 | 84 | config_getters = [ 85 | modeling_utils.hf_get_hidden_size, 86 | modeling_utils.hf_get_num_hidden_layers, 87 | ] 88 | for get in config_getters: 89 | try: 90 | get(config) 91 | except Exception as e: 92 | assert False, "Failed to get config attribute with error: " + str(e) 93 | 94 | 95 | class TestStatistics(unittest.TestCase): 96 | @classmethod 97 | def setUpClass(cls): 98 | cls.m = modeling_utils.RunningMoments() 99 | cls.a1 = torch.arange(100, dtype=float) 100 | cls.a2 = torch.ones(100, dtype=float) 101 | cls.a3 = torch.exp(torch.arange(10, dtype=float)) 102 | cls.a4 = torch.tensor([-10, -1, 0, 1, 10], dtype=float) 103 | 104 | def test_running_moments(self): 105 | assert torch.isclose(self.m.update(self.a1)[1], self.a1.std(unbiased=True), atol=1e-6) 106 | assert torch.isclose(self.m.update(self.a2)[1], self.a2.std(unbiased=True), atol=1e-6) 107 | assert torch.isclose(self.m.update(self.a3)[1], self.a3.std(unbiased=True), atol=1e-6) 108 | assert 
torch.isclose(self.m.update(self.a4)[1], self.a4.std(unbiased=True), atol=1e-6) 109 | 110 | a = torch.hstack((self.a1, self.a2, self.a3, self.a4)) 111 | assert torch.isclose(self.m.mean, a.mean(), atol=1e-6) 112 | assert torch.isclose(self.m.std, a.std(unbiased=True), atol=1e-6) 113 | -------------------------------------------------------------------------------- /trlx/__init__.py: -------------------------------------------------------------------------------- 1 | from .trlx import train 2 | from .utils import logging 3 | -------------------------------------------------------------------------------- /trlx/data/__init__.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Iterable 3 | 4 | from torchtyping import TensorType 5 | 6 | 7 | @dataclass 8 | class GeneralElement: 9 | """ 10 | General element outputted by a data pipeline 11 | """ 12 | 13 | pass 14 | 15 | 16 | @dataclass 17 | class RLElement: 18 | """ 19 | Batch element for RL model 20 | """ 21 | 22 | state: Iterable[str] = None # Context/prompts 23 | action: TensorType["N"] = None # Tokens generated by model given prompts 24 | reward: float = None # Reward obtained for that generation 25 | 26 | 27 | @dataclass 28 | class BatchElement: 29 | """ 30 | General batch element for any transformer to use in its forward pass 31 | """ 32 | 33 | tokens: TensorType["BATCH", "SEQ_LEN"] 34 | masks: TensorType["BATCH", "SEQ_LEN"] 35 | -------------------------------------------------------------------------------- /trlx/data/accelerate_base_datatypes.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Iterable 3 | 4 | from torchtyping import TensorType 5 | 6 | 7 | @dataclass 8 | class PromptElement: 9 | """ 10 | Dataclass for a single prompt, containing its string and tokenized form. 11 | 12 | :param text: The prompt text. 13 | :type text: str 14 | 15 | :param tokens: The prompt tokens. Should be a long tensor 16 | :type tokens: torch.Tensor 17 | """ 18 | 19 | text: str 20 | tokens: TensorType["num_tokens"] 21 | 22 | 23 | @dataclass 24 | class PromptBatch: 25 | """ 26 | Batched PromptElement 27 | 28 | :param text: An iterable of prompt texts. 29 | :type text: Iterable[str] 30 | 31 | :param tokens: A long tensor batch of prompt tokens. 32 | :type tokens: torch.Tensor 33 | """ 34 | 35 | text: Iterable[str] 36 | tokens: TensorType["batch_size", "num_tokens"] 37 | 38 | 39 | @dataclass 40 | class AccelerateRLElement: 41 | """ 42 | Dataclass for RL elements, containing output tokens and rewards for each token. 43 | 44 | :param tokens: The output tokens. Should be a long tensor 45 | :type tokens: torch.Tensor 46 | 47 | :param rewards: The rewards for each token. Should be a float tensor of same size as tokens. 48 | :type rewards: torch.Tensor 49 | """ 50 | 51 | output_tokens: TensorType["output_size"] 52 | rewards: TensorType["output_size"] 53 | 54 | 55 | @dataclass 56 | class AccelerateRLBatchElement: 57 | """ 58 | Batched accelerate RL element 59 | 60 | :param tokens: Batches of long tensors of output tokens. 61 | :type tokens: torch.Tensor 62 | 63 | :param rewards: Batches of float tensors of rewards for each output token. 
64 | :type rewards: torch.Tensor 65 | """ 66 | 67 | output_tokens: TensorType["batch_size", "output_size"] 68 | rewards: TensorType["batch_size", "output_size"] 69 | -------------------------------------------------------------------------------- /trlx/data/default_configs.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from trlx.models.modeling_ilql import ILQLConfig 4 | from trlx.models.modeling_ppo import PPOConfig 5 | from trlx.trainer.accelerate_sft_trainer import SFTConfig 6 | 7 | from .configs import ( 8 | ModelConfig, 9 | OptimizerConfig, 10 | SchedulerConfig, 11 | TokenizerConfig, 12 | TrainConfig, 13 | TRLConfig, 14 | ) 15 | 16 | 17 | def default_ppo_config(): 18 | return TRLConfig( 19 | train=TrainConfig( 20 | seq_length=1024, 21 | epochs=100, 22 | total_steps=10000, 23 | batch_size=32, 24 | checkpoint_interval=10000, 25 | eval_interval=100, 26 | pipeline="PromptPipeline", 27 | trainer="AcceleratePPOTrainer", 28 | ), 29 | model=ModelConfig(model_path="lvwerra/gpt2-imdb", num_layers_unfrozen=2), 30 | tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"), 31 | optimizer=OptimizerConfig( 32 | name="adamw", kwargs=dict(lr=3e-5, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6) 33 | ), 34 | scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=3e-5)), 35 | method=PPOConfig( 36 | name="PPOConfig", 37 | num_rollouts=128, 38 | chunk_size=128, 39 | ppo_epochs=4, 40 | init_kl_coef=0.001, 41 | target=None, 42 | horizon=10000, 43 | gamma=1, 44 | lam=0.95, 45 | cliprange=0.2, 46 | cliprange_value=0.2, 47 | vf_coef=1, 48 | scale_reward="ignored", 49 | ref_mean=None, 50 | ref_std=None, 51 | cliprange_reward=10, 52 | gen_kwargs=dict( 53 | max_new_tokens=40, 54 | top_k=0, 55 | top_p=1.0, 56 | do_sample=True, 57 | ), 58 | ), 59 | ) 60 | 61 | 62 | def default_ilql_config(): 63 | return TRLConfig( 64 | train=TrainConfig( 65 | seq_length=64, 66 | batch_size=128, 67 | epochs=100, 68 | total_steps=1000, 69 | checkpoint_interval=1000, 70 | eval_interval=100, 71 | pipeline="PromptPipeline", 72 | trainer="AccelerateILQLTrainer", 73 | ), 74 | model=ModelConfig(model_path="gpt2", num_layers_unfrozen=-1), 75 | tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"), 76 | optimizer=OptimizerConfig( 77 | name="adamw", kwargs=dict(lr=5.0e-5, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6) 78 | ), 79 | scheduler=SchedulerConfig( 80 | name="cosine_annealing", kwargs=dict(T_max=1e12, eta_min=5.0e-5) # train.total_steps 81 | ), 82 | method=ILQLConfig( 83 | name="ilqlconfig", 84 | tau=0.7, 85 | gamma=0.99, 86 | cql_scale=0.1, 87 | awac_scale=1, 88 | alpha=0.001, 89 | beta=0, 90 | steps_for_target_q_sync=5, 91 | two_qs=True, 92 | gen_kwargs=dict(max_new_tokens=56, top_k=20, beta=1, temperature=1.0), 93 | ), 94 | ) 95 | 96 | 97 | def default_sft_config(): 98 | return TRLConfig( 99 | train=TrainConfig( 100 | seq_length=1024, 101 | epochs=100, 102 | total_steps=1000, 103 | batch_size=8, 104 | checkpoint_interval=10000, 105 | eval_interval=100, 106 | pipeline="PromptPipeline", 107 | trainer="AccelerateSFTTrainer", 108 | ), 109 | model=ModelConfig(model_path="gpt2", num_layers_unfrozen=-1), 110 | tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"), 111 | optimizer=OptimizerConfig( 112 | name="adamw", kwargs=dict(lr=1.0e-4, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6) 113 | ), 114 | scheduler=SchedulerConfig( 115 | name="cosine_annealing", 
kwargs=dict(T_max=1e12, eta_min=1.0e-4) # train.total_steps 116 | ), 117 | method=SFTConfig( 118 | name="sftconfig", 119 | gen_kwargs=dict(max_new_tokens=40, top_k=0, top_p=1.0, do_sample=True), 120 | ), 121 | ) 122 | 123 | 124 | def default_nemo_20b_config(): 125 | """Load nemo-megatron-20b model and trainer config""" 126 | # Import here to not require nemo as a dependency 127 | from omegaconf import OmegaConf 128 | 129 | here = Path(__file__).parent 130 | return OmegaConf.load(here.parent.parent / "configs" / "nemo_configs" / "megatron_20b.yaml") 131 | 132 | 133 | def default_nemo_2b_config(): 134 | """Load nemo-megatron-2b model and trainer config""" 135 | # Import here to not require nemo as a dependency 136 | from omegaconf import OmegaConf 137 | 138 | here = Path(__file__).parent 139 | return OmegaConf.load(here.parent.parent / "configs" / "nemo_configs" / "megatron_2b.yaml") 140 | 141 | 142 | def default_nemo_1_3b_config(): 143 | """Load nemo-megatron-1.3b model and trainer config""" 144 | # Import here to not require nemo as a dependency 145 | from omegaconf import OmegaConf 146 | 147 | here = Path(__file__).parent 148 | return OmegaConf.load(here.parent.parent / "configs" / "nemo_configs" / "megatron_1.3b.yaml") 149 | -------------------------------------------------------------------------------- /trlx/data/ilql_types.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from torchtyping import TensorType # type: ignore 4 | 5 | 6 | @dataclass 7 | class ILQLElement: 8 | """ 9 | A single data item for ILQL training 10 | 11 | :param input_ids: Long tensor of input tokens. 12 | :type input_ids: torch.Tensor 13 | 14 | :param attention_mask: Attention mask for input tokens. Should be a long tensor. 15 | :type attention_mask: torch.Tensor 16 | 17 | :param rewards: Rewards for each input token. 18 | :type rewards: torch.Tensor 19 | 20 | :param states_ixs: Indices of states (user input or environment input for example) in the `input_ids`. 21 | :type states_ixs: torch.Tensor 22 | 23 | :param actions_ixs: Indices of actions (model output) in the `input_ids` tensor. 24 | :type actions_ixs: torch.Tensor 25 | 26 | :param dones: Indicator for the terminal state (end of episode) in the `input_ids` tensor. 27 | :type dones: torch.Tensor 28 | """ 29 | 30 | input_ids: TensorType["query_size"] 31 | attention_mask: TensorType["query_size"] 32 | rewards: TensorType["reward_size"] 33 | states_ixs: TensorType["states_size"] 34 | actions_ixs: TensorType["reward_size"] 35 | dones: TensorType["states_size"] 36 | 37 | 38 | @dataclass 39 | class ILQLSeq2SeqElement: 40 | """ 41 | A single data item for ILQL training 42 | 43 | :param input_ids: Long tensor of input tokens. 44 | :type input_ids: torch.Tensor 45 | 46 | :param attention_mask: Attention mask for input tokens. Should be a long tensor. 47 | :type attention_mask: torch.Tensor 48 | 49 | :param decoder_input_ids: Long tensor of target input tokens. 50 | :type decoder_input_ids: torch.Tensor 51 | 52 | :param rewards: Rewards for each input token. 53 | :type rewards: torch.Tensor 54 | 55 | :param states_ixs: Indices of states (user input or environment input for example) in the `input_ids`. 56 | :type states_ixs: torch.Tensor 57 | 58 | :param actions_ixs: Indices of actions (model output) in the `input_ids` tensor. 59 | :type actions_ixs: torch.Tensor 60 | 61 | :param dones: Indicator for the terminal state (end of episode) in the `input_ids` tensor.
62 | :type dones: torch.Tensor 63 | """ 64 | 65 | input_ids: TensorType["query_size"] 66 | attention_mask: TensorType["query_size"] 67 | decoder_input_ids: TensorType["reward_size"] 68 | rewards: TensorType["reward_size"] 69 | states_ixs: TensorType["states_size"] 70 | actions_ixs: TensorType["reward_size"] 71 | dones: TensorType["states_size"] 72 | 73 | 74 | @dataclass 75 | class ILQLBatch: 76 | """ 77 | Batched ILQL data elements 78 | 79 | :param input_ids: Batch of input tokens. 80 | :type input_ids: torch.Tensor 81 | 82 | :param attention_mask: Batch of attention masks. 83 | :type attention_mask: torch.Tensor 84 | 85 | :param rewards: Batch of rewards for each token in each token batch. 86 | :type rewards: torch.Tensor 87 | 88 | :param states_ixs: Batch of indices of states (user input or environment input for example) in the `input_ids`. 89 | :type states_ixs: torch.Tensor 90 | 91 | :param actions_ixs: Batch of indices of actions (model output) in the `input_ids` tensor. 92 | :type actions_ixs: torch.Tensor 93 | 94 | :param dones: Batch of indicators for the terminal state (end of episode) in the `input_ids` tensor. 95 | :type dones: torch.Tensor 96 | """ 97 | 98 | input_ids: TensorType["batch_size", "query_size"] 99 | attention_mask: TensorType["batch_size", "query_size"] 100 | rewards: TensorType["batch_size", "reward_size"] 101 | states_ixs: TensorType["batch_size", "states_size"] 102 | actions_ixs: TensorType["batch_size", "reward_size"] 103 | dones: TensorType["batch_size", "states_size"] 104 | 105 | 106 | @dataclass 107 | class ILQLSeq2SeqBatch: 108 | """ 109 | Batched ILQL data elements 110 | 111 | :param input_ids: Batch of input tokens. 112 | :type input_ids: torch.Tensor 113 | 114 | :param attention_mask: Batch of attention masks. 115 | :type attention_mask: torch.Tensor 116 | 117 | :param decoder_input_ids: Batch of target input tokens. 118 | :type decoder_input_ids: torch.Tensor 119 | 120 | :param rewards: Batch of rewards for each token in each token batch. 121 | :type rewards: torch.Tensor 122 | 123 | :param states_ixs: Batch of indices of states (user input or environment input for example) in the `input_ids`. 124 | :type states_ixs: torch.Tensor 125 | 126 | :param actions_ixs: Batch of indices of actions (model output) in the `input_ids` tensor. 127 | :type actions_ixs: torch.Tensor 128 | 129 | :param dones: Batch of indicators for the terminal state (end of episode) in the `input_ids` tensor.
130 | :type dones: torch.Tensor 131 | """ 132 | 133 | input_ids: TensorType["batch_size", "query_size"] 134 | attention_mask: TensorType["batch_size", "query_size"] 135 | decoder_input_ids: TensorType["batch_size", "reward_size"] 136 | rewards: TensorType["batch_size", "reward_size"] 137 | states_ixs: TensorType["batch_size", "states_size"] 138 | actions_ixs: TensorType["batch_size", "reward_size"] 139 | dones: TensorType["batch_size", "states_size"] 140 | -------------------------------------------------------------------------------- /trlx/data/method_configs.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from dataclasses import dataclass 3 | from typing import Any, Dict 4 | 5 | # specifies a dictionary of method configs 6 | _METHODS: Dict[str, Any] = {} # registry 7 | 8 | 9 | def register_method(name): 10 | """Decorator used to register a method config 11 | Args: 12 | name: Name of the method 13 | """ 14 | 15 | def register_class(cls, name): 16 | _METHODS[name] = cls 17 | setattr(sys.modules[__name__], name, cls) 18 | return cls 19 | 20 | if isinstance(name, str): 21 | name = name.lower() 22 | return lambda c: register_class(c, name) 23 | 24 | cls = name 25 | name = cls.__name__ 26 | register_class(cls, name.lower()) 27 | 28 | return cls 29 | 30 | 31 | @dataclass 32 | @register_method 33 | class MethodConfig: 34 | """ 35 | Config for a certain RL method. 36 | 37 | :param name: Name of the method 38 | :type name: str 39 | """ 40 | 41 | name: str 42 | 43 | @classmethod 44 | def from_dict(cls, config: Dict[str, Any]): 45 | return cls(**config) 46 | 47 | 48 | def get_method(name: str) -> MethodConfig: 49 | """ 50 | Return constructor for specified method config 51 | """ 52 | name = name.lower() 53 | if name in _METHODS: 54 | return _METHODS[name] 55 | else: 56 | raise Exception("Error: Trying to access a method that has not been registered") 57 | -------------------------------------------------------------------------------- /trlx/data/ppo_types.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from torchtyping import TensorType 4 | 5 | 6 | @dataclass 7 | class PPORLElement: 8 | """ 9 | :param query_tensor: The query tensor i.e. the prompt tokens. 10 | Should be a long tensor. 11 | :type query_tensor: torch.Tensor 12 | 13 | :param response_tensor: The response tensor i.e. the output tokens. 14 | Should be a long tensor. 15 | :type response_tensor: torch.Tensor 16 | 17 | :param logprobs: The log probabilities over the response tokens generated 18 | by the policy network (i.e. the autoregressive model). 19 | Should be a float tensor of same size as tokens. 20 | :type logprobs: torch.Tensor 21 | 22 | :param values: The values for each token generated from the value network or value head. 23 | Should be a float tensor of same size as tokens. 24 | :type values: torch.Tensor 25 | 26 | :param rewards: The rewards for each token outputted in response. 27 | Should be a float tensor of same size as tokens. 28 | :type rewards: torch.Tensor 29 | """ 30 | 31 | query_tensor: TensorType["query_size"] 32 | response_tensor: TensorType["response_size"] 33 | logprobs: TensorType["response_size"] 34 | values: TensorType["response_size"] 35 | rewards: TensorType["response_size"] 36 | 37 | 38 | @dataclass 39 | class PPORLBatch: 40 | """ 41 | A batched version of the PPORLElement. See PPORLElement for more details on individual fields.
42 | 43 | :param query_tensors: A batch of query tensors. Should be a long tensor. 44 | :type query_tensors: torch.Tensor 45 | 46 | :param response_tensors: A batch of response tensors. Should be a long tensor. 47 | :type response_tensors: torch.Tensor 48 | 49 | :param logprobs: A batch of log probabilities from policy 50 | :type logprobs: torch.Tensor 51 | 52 | :param values: A batch of values from value network 53 | :type values: torch.Tensor 54 | 55 | :param rewards: A batch of rewards 56 | :type rewards: torch.Tensor 57 | """ 58 | 59 | query_tensors: TensorType["batch_size", "query_size"] 60 | response_tensors: TensorType["batch_size", "response_size"] 61 | logprobs: TensorType["batch_size", "response_size"] 62 | values: TensorType["batch_size", "response_size"] 63 | rewards: TensorType["batch_size", "response_size"] 64 | -------------------------------------------------------------------------------- /trlx/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CarperAI/trlx/3340c2f3a56d1d14fdd5f13ad575121fa26b6d92/trlx/models/__init__.py -------------------------------------------------------------------------------- /trlx/pipeline/ppo_pipeline.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import time 4 | from functools import partial 5 | from typing import Iterable 6 | 7 | from torch.nn.utils.rnn import pad_sequence 8 | from torch.utils.data import DataLoader 9 | 10 | from trlx.data.ppo_types import PPORLBatch, PPORLElement 11 | from trlx.pipeline import BaseRolloutStore 12 | 13 | 14 | def ppo_collate_fn(padding_side: str, pad_token_id: int, elems: Iterable[PPORLElement]): 15 | if padding_side == "left": 16 | # Left padding of already left-padded queries 17 | query_tensors = pad_sequence( 18 | [elem.query_tensor.flip(0) for elem in elems], 19 | padding_value=pad_token_id, 20 | batch_first=True, 21 | ).flip(1) 22 | elif padding_side == "right": 23 | query_tensors = pad_sequence( 24 | [elem.query_tensor for elem in elems], 25 | padding_value=pad_token_id, 26 | batch_first=True, 27 | ) 28 | else: 29 | raise ValueError(f"Invalid padding side: {padding_side}") 30 | 31 | return PPORLBatch( 32 | query_tensors, 33 | # Right pad the rest, to have a single horizontal query/response split 34 | pad_sequence( 35 | [elem.response_tensor for elem in elems], 36 | padding_value=pad_token_id, 37 | batch_first=True, 38 | ), 39 | pad_sequence( 40 | [elem.logprobs for elem in elems], 41 | padding_value=0.0, 42 | batch_first=True, 43 | ), 44 | pad_sequence([elem.values for elem in elems], padding_value=0.0, batch_first=True), 45 | pad_sequence( 46 | [elem.rewards for elem in elems], 47 | padding_value=0.0, 48 | batch_first=True, 49 | ), 50 | ) 51 | 52 | 53 | class PPORolloutStorage(BaseRolloutStore): 54 | """ 55 | Rollout storage for training PPO 56 | """ 57 | 58 | def __init__(self, pad_token_id, padding_side): 59 | super().__init__() 60 | 61 | self.pad_token_id = pad_token_id 62 | self.padding_side = padding_side 63 | self.history: Iterable[PPORLElement] = [None] 64 | 65 | def push(self, exps: Iterable[PPORLElement]): 66 | self.history += exps 67 | 68 | def clear_history(self): 69 | self.history = [] 70 | 71 | def export_history(self, location: str, only_text=True): 72 | assert os.path.exists(location) 73 | 74 | fpath = os.path.join(location, f"epoch-{str(time.time())}.json") 75 | 76 | def exp_to_dict(exp): 77 | return {k: v.cpu().tolist() for k, v in 
exp.__dict__.items()} 78 | 79 | def filter_text(d, only_text): 80 | if only_text: 81 | keys = list(d.keys()) 82 | for key in keys: 83 | if key != "query_tensor" and key != "response_tensor": 84 | d.pop(key) 85 | return d 86 | 87 | data = [filter_text(exp_to_dict(exp), only_text) for exp in self.history] 88 | with open(fpath, "w") as f: 89 | f.write(json.dumps(data, indent=2)) 90 | 91 | def __getitem__(self, index: int) -> PPORLElement: 92 | return self.history[index] 93 | 94 | def __len__(self) -> int: 95 | return len(self.history) 96 | 97 | def create_loader( 98 | self, 99 | batch_size: int, 100 | shuffle: bool, 101 | ) -> DataLoader: 102 | return DataLoader( 103 | self, batch_size, shuffle=shuffle, collate_fn=partial(ppo_collate_fn, self.padding_side, self.pad_token_id) 104 | ) 105 | -------------------------------------------------------------------------------- /trlx/reference.py: -------------------------------------------------------------------------------- 1 | # python -m trlx.reference CarperAI/trlx:add-benchmark-tools --against CarperAI/trlx:main 2 | 3 | import argparse 4 | import os 5 | import subprocess 6 | 7 | import wandb 8 | import wandb.apis.reports as wb 9 | 10 | parser = argparse.ArgumentParser() 11 | parser.add_argument("branch", type=str, help="Git branch in the format `origin:branch`") 12 | parser.add_argument("--against", type=str, default="CarperAI/trlx:main", help="Reference git branch") 13 | parser.add_argument("--public", action="store_true", help="Use CarperAI entity to store/pull from w&b runs") 14 | args = parser.parse_args() 15 | 16 | pr_origin = ref_origin = "CarperAI/trlx" 17 | pr_branch = args.branch 18 | ref_branch = args.against 19 | if ":" in pr_branch: 20 | pr_origin, pr_branch = pr_branch.rsplit(":", 1) 21 | if ":" in ref_branch: 22 | ref_origin, ref_branch = ref_branch.rsplit(":", 1) 23 | 24 | out = os.popen(f"./scripts/benchmark.sh --origin {pr_origin} --branch {pr_branch} --only_hash") 25 | pr_hash, pr_git_hash = [x[:-1] for x in out.readlines()] 26 | 27 | out = os.popen(f"./scripts/benchmark.sh --origin {ref_origin} --branch {ref_branch} --only_hash") 28 | ref_hash, ref_git_hash = [x[:-1] for x in out.readlines()] 29 | 30 | print(f"{pr_origin}:{pr_branch=} {pr_hash=} {pr_git_hash=}") 31 | print(f"{ref_origin}:{ref_branch} {ref_hash=} {ref_git_hash=}") 32 | 33 | api = wandb.Api() 34 | project_name = "CarperAI/trlx-references" if args.public else "trlx-references" 35 | public = "--public" if args.public else "" 36 | 37 | runs = api.runs(project_name, filters={"tags": {"$in": [ref_hash]}}) 38 | if runs: 39 | print(f"On {ref_branch} @{ref_git_hash} these runs were already made: \n{chr(10).join(run.name for run in runs)}") 40 | else: 41 | print(f"Making runs on {ref_branch} @{ref_git_hash}") 42 | subprocess.run(f"./scripts/benchmark.sh --origin {ref_origin} --branch {ref_branch} {public}".split()) 43 | 44 | runs = api.runs(project_name, filters={"tags": {"$in": [pr_hash]}}) 45 | if runs: 46 | print(f"On {pr_branch} @{pr_git_hash} these runs were already made: \n{chr(10).join(run.name for run in runs)}") 47 | else: 48 | print(f"Making runs on {pr_branch} @{pr_git_hash}") 49 | subprocess.run(f"./scripts/benchmark.sh --origin {pr_origin} --branch {pr_branch} {public}".split()) 50 | 51 | report = wb.Report( 52 | project=project_name.split("/")[1] if args.public else project_name, 53 | title=f"{pr_branch} v. 
{ref_branch}", 54 | description=f"{pr_branch}\n@{pr_git_hash}\n\n{ref_branch}\n@{ref_git_hash}", 55 | ) 56 | blocks = [] 57 | 58 | experiment_names = set(x.name.split(":")[0] for x in api.runs(project_name)) 59 | for name in experiment_names: 60 | filters = {"$and": [{"display_name": {"$regex": f"^{name}"}}, {"tags": {"$in": [pr_hash, ref_hash]}}]} 61 | 62 | runs = api.runs(project_name, filters=filters) 63 | metrics = set(sum([[metric for metric in run.history().columns if not metric.startswith("_")] for run in runs], [])) 64 | 65 | metrics_panels = [ 66 | wb.LinePlot( 67 | title=f"{metric}", 68 | x="Step", 69 | y=[metric], 70 | title_x="Step", 71 | smoothing_show_original=True, 72 | max_runs_to_show=2, 73 | plot_type="line", 74 | font_size="auto", 75 | legend_position="north", 76 | ) 77 | for metric in metrics 78 | ] 79 | 80 | # sort the most important metrics to be shown first 81 | major_metrics = set() 82 | for metric in metrics: 83 | if metric.startswith("reward") or metric.startswith("metric"): 84 | major_metrics.add(metric) 85 | metrics = metrics - major_metrics 86 | 87 | blocks.extend( 88 | [ 89 | wb.H1(text=name), 90 | wb.PanelGrid( 91 | panels=[panel for panel in metrics_panels if panel.title in major_metrics], 92 | runsets=[wb.Runset(project=project_name, filters=filters)], 93 | ), 94 | wb.PanelGrid( 95 | panels=[panel for panel in metrics_panels if panel.title in metrics], 96 | runsets=[wb.Runset(project=project_name, filters=filters)], 97 | ), 98 | ] 99 | ) 100 | 101 | report.blocks = blocks 102 | report.save() 103 | print(report.url) 104 | -------------------------------------------------------------------------------- /trlx/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from abc import abstractmethod 3 | from typing import Any, Callable, Dict, Iterable, Optional 4 | 5 | from trlx.data.configs import TRLConfig 6 | from trlx.pipeline import BaseRolloutStore 7 | 8 | # specifies a dictionary of architectures 9 | _TRAINERS: Dict[str, Any] = {} # registry 10 | 11 | 12 | def register_trainer(name): 13 | """Decorator used to register a trainer 14 | Args: 15 | name: Name of the trainer type to register 16 | """ 17 | 18 | def register_class(cls, name): 19 | _TRAINERS[name] = cls 20 | setattr(sys.modules[__name__], name, cls) 21 | return cls 22 | 23 | if isinstance(name, str): 24 | name = name.lower() 25 | return lambda c: register_class(c, name) 26 | 27 | cls = name 28 | name = cls.__name__ 29 | register_class(cls, name.lower()) 30 | 31 | return cls 32 | 33 | 34 | @register_trainer 35 | class BaseRLTrainer: 36 | def __init__( 37 | self, 38 | config: TRLConfig, 39 | reward_fn=None, 40 | metric_fn=None, 41 | logit_mask=None, 42 | stop_sequences=None, 43 | train_mode=False, 44 | ): 45 | self.store: BaseRolloutStore = None 46 | self.config = config 47 | self.reward_fn = reward_fn 48 | self.metric_fn = metric_fn 49 | self.logit_mask = logit_mask 50 | self.train_mode = train_mode 51 | self.stop_sequences = stop_sequences 52 | 53 | def push_to_store(self, data): 54 | """ 55 | Append new data to the rollout store 56 | """ 57 | self.store.push(data) 58 | 59 | @abstractmethod 60 | def learn(self): 61 | """ 62 | Use data in the the rollout store to update the model 63 | """ 64 | pass 65 | -------------------------------------------------------------------------------- /trlx/trainer/accelerate_sft_trainer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import 
dataclass 2 | 3 | from transformers import AutoModelForCausalLM, PretrainedConfig 4 | 5 | from trlx.data.configs import TRLConfig 6 | from trlx.data.method_configs import MethodConfig, register_method 7 | from trlx.pipeline.offline_pipeline import ( 8 | DialogStore, 9 | PromptPipeline, 10 | tokenize_dialogue, 11 | ) 12 | from trlx.trainer import register_trainer 13 | from trlx.trainer.accelerate_base_trainer import AccelerateRLTrainer 14 | 15 | 16 | @dataclass 17 | @register_method 18 | class SFTConfig(MethodConfig): 19 | """ 20 | Config for SFT training 21 | 22 | :param gen_kwargs: kwargs for generation 23 | :type gen_kwargs: Dict[str, Any] 24 | """ 25 | 26 | gen_kwargs: dict 27 | 28 | 29 | @register_trainer 30 | class AccelerateSFTTrainer(AccelerateRLTrainer): 31 | def __init__(self, config: TRLConfig, **kwargs): 32 | super().__init__(config, **kwargs) 33 | 34 | self.generate_kwargs = dict( 35 | config.method.gen_kwargs, 36 | eos_token_id=self.tokenizer.eos_token_id, 37 | pad_token_id=self.tokenizer.pad_token_id, 38 | ) 39 | 40 | def get_arch(self, config): 41 | from_fn = AutoModelForCausalLM.from_pretrained 42 | if issubclass(type(config.model.model_path), PretrainedConfig): 43 | from_fn = AutoModelForCausalLM.from_config 44 | 45 | model = from_fn(config.model.model_path, **config.model.model_extra_configs) 46 | 47 | if config.model.peft_config is not None: 48 | # Initialize the peft adapter 49 | import peft 50 | 51 | peft_config = config.model.peft_config 52 | if not isinstance(peft_config, peft.PeftConfig): 53 | if isinstance(peft_config, dict): 54 | peft_config = peft.get_peft_config(peft_config) 55 | else: 56 | raise ValueError("`peft_config` should be an instance of `peft.PeftConfig` or a dict.") 57 | model = peft.get_peft_model(model, peft_config) 58 | if self.accelerator.is_main_process: 59 | model.print_trainable_parameters() 60 | 61 | return model 62 | 63 | def loss(self, batch): 64 | if "labels" in batch: 65 | labels = batch.labels.clone() 66 | else: 67 | labels = batch.input_ids.clone() 68 | labels[~batch.attention_mask.bool()] = -100 69 | 70 | loss = self.model(input_ids=batch.input_ids, attention_mask=batch.attention_mask, labels=labels).loss 71 | stats = {"loss": loss.item()} 72 | 73 | return loss, stats 74 | 75 | def create_train_dataloader(self): 76 | return self.accelerator.prepare(self.store.create_loader(self.config.train.batch_size)) 77 | 78 | def prepare_learning(self): 79 | self.train_dataloader = self.create_train_dataloader() 80 | eval_dataloader = self.eval_pipeline.create_loader(self.config.train.batch_size) 81 | 82 | ( 83 | self.model, 84 | self.opt, 85 | self.eval_dataloader, 86 | ) = self.accelerator.prepare(self.model, self.opt, eval_dataloader) 87 | 88 | self.n_inner_epochs = 1 89 | self.total_steps = self.config.train.epochs * len(self.train_dataloader) 90 | self.total_steps = min(self.total_steps, self.config.train.total_steps) 91 | 92 | def make_experience(self, samples, seq_length): 93 | if isinstance(samples[0], str): 94 | self.store = PromptPipeline(samples, seq_length, self.tokenizer) 95 | else: 96 | dialogs = [tokenize_dialogue(d, self.tokenizer, seq_length) for d in samples] 97 | self.store = DialogStore(dialogs, self.tokenizer) 98 | -------------------------------------------------------------------------------- /trlx/utils/loading.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, List 2 | 3 | # Register load pipelines via module import 4 | from trlx.pipeline import _DATAPIPELINE 
5 | from trlx.pipeline.offline_pipeline import PromptPipeline 6 | 7 | # Register load trainers via module import 8 | from trlx.trainer import _TRAINERS, register_trainer 9 | from trlx.trainer.accelerate_ilql_trainer import AccelerateILQLTrainer 10 | from trlx.trainer.accelerate_ppo_trainer import AcceleratePPOTrainer 11 | from trlx.trainer.accelerate_rft_trainer import AccelerateRFTTrainer 12 | from trlx.trainer.accelerate_sft_trainer import AccelerateSFTTrainer 13 | 14 | try: 15 | from trlx.trainer.nemo_ilql_trainer import NeMoILQLTrainer 16 | from trlx.trainer.nemo_ppo_trainer import NeMoPPOTrainer 17 | from trlx.trainer.nemo_sft_trainer import NeMoSFTTrainer 18 | except ImportError: 19 | # NeMo is not installed 20 | def _trainers_unavailable(names: List[str]): 21 | def log_error(*args, **kwargs): 22 | raise ImportError("NeMo is not installed. Please install `nemo_toolkit` to use NeMo-based trainers.") 23 | 24 | # Register dummy trainers 25 | for name in names: 26 | register_trainer(name)(log_error) 27 | 28 | _trainers_unavailable(["NeMoILQLTrainer", "NeMoSFTTrainer", "NeMoPPOTrainer"]) 29 | 30 | 31 | def get_trainer(name: str) -> Callable: 32 | """ 33 | Return constructor for specified RL model trainer 34 | """ 35 | name = name.lower() 36 | if name in _TRAINERS: 37 | return _TRAINERS[name] 38 | else: 39 | raise Exception("Error: Trying to access a trainer that has not been registered") 40 | 41 | 42 | def get_pipeline(name: str) -> Callable: 43 | """ 44 | Return constructor for specified pipeline 45 | """ 46 | name = name.lower() 47 | if name in _DATAPIPELINE: 48 | return _DATAPIPELINE[name] 49 | else: 50 | raise Exception("Error: Trying to access a pipeline that has not been registered") 51 | --------------------------------------------------------------------------------
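As a closing illustration, not part of the repository, the registries populated in trlx/utils/loading.py are what turn the class names stored in a TRLConfig into concrete trainer and pipeline classes. A minimal sketch, assuming trlx and its accelerate-based dependencies are installed and that PromptPipeline is registered under its lower-cased class name just as the trainers are:

# Illustrative sketch (not part of the repository): resolve the trainer and
# pipeline named by the default PPO config through the registries shown above.
from trlx.data.default_configs import default_ppo_config
from trlx.utils.loading import get_pipeline, get_trainer

config = default_ppo_config()

# TrainConfig stores "AcceleratePPOTrainer" and "PromptPipeline"; both getters
# lower-case the key before the registry lookup, so the names pass through as-is.
trainer_cls = get_trainer(config.train.trainer)
pipeline_cls = get_pipeline(config.train.pipeline)
print(trainer_cls.__name__, pipeline_cls.__name__)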