├── .github ├── ISSUE_TEMPLATE │ ├── bug_report.yml │ ├── documentation.yml │ └── feature_request.yml └── workflows │ ├── build.yml │ └── code_quality.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .readthedocs.yml ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── configs ├── deepspeed_configs │ ├── accelerate_t5_configs.yml │ └── default_configs.yml ├── ilql_config.yml ├── ppo_config.yml ├── ppo_config_t5.yml ├── ppo_config_t5_old.yml ├── ppo_gptj.yml ├── sweeps │ ├── ilql_sweep.yml │ └── ppo_sweep.yml └── test_config.yml ├── docs ├── Makefile ├── make.bat ├── requirements.txt └── source │ ├── conf.py │ ├── configs.rst │ ├── data.rst │ ├── examples.rst │ ├── index.rst │ ├── models.rst │ ├── orchestrator.rst │ └── pipeline.rst ├── examples ├── __init__.py ├── architext.py ├── experiments │ └── grounded_program_synthesis │ │ ├── README.md │ │ ├── __init__.py │ │ ├── configs │ │ └── trlx_ppo_config.yml │ │ ├── lang.py │ │ └── train_trlx.py ├── ilql_sentiments.py ├── ppo_reward_model.py ├── ppo_sentiments.py ├── prompts.json ├── randomwalks │ ├── README.md │ ├── __init__.py │ ├── configs │ │ ├── ilql_randomwalks.yml │ │ └── ppo_randomwalks.yml │ ├── ilql_randomwalks.py │ ├── ppo_randomwalks.py │ └── randomwalks.py └── simulacra.py ├── pyproject.toml ├── setup.cfg ├── setup.py ├── tests ├── __init__.py ├── test_configs.py ├── test_ppo.py └── test_utils.py └── trlx ├── __init__.py ├── data ├── __init__.py ├── accelerate_base_datatypes.py ├── configs.py ├── ilql_types.py ├── method_configs.py └── ppo_types.py ├── model ├── __init__.py ├── accelerate_base_model.py ├── accelerate_ilql_model.py ├── accelerate_ppo_model.py └── nn │ ├── __init__.py │ ├── ilql_models.py │ └── ppo_models.py ├── orchestrator ├── __init__.py ├── offline_orchestrator.py └── ppo_orchestrator.py ├── pipeline ├── __init__.py ├── offline_pipeline.py └── ppo_pipeline.py ├── ray_tune ├── __init__.py ├── train_funcs.py └── wandb.py ├── sweep.py ├── trlx.py └── utils ├── __init__.py ├── loading.py └── modeling.py /.github/ISSUE_TEMPLATE/bug_report.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: 🐛 Bug Report 3 | description: Report a bug or unexpected behavior to help us improve trlX 4 | labels: 5 | - bug 6 | 7 | body: 8 | - type: markdown 9 | attributes: 10 | value: > 11 | #### Before submitting your bug report, please check to see that the 12 | issue hasn't already been reported and/or fixed in a latest version. 13 | [Search Issues][Issue Search]. 14 | 15 | If you're asking a question or seeking support, please consider creating a 16 | new [GitHub discussion][Discussions] or heading over to CarperAI's 17 | [Discord server][CarperAI Discord]. 18 | 19 | 20 | [Issue Search]: https://github.com/CarperAI/trlx/search?q=is%3Aissue&type=issues 21 | 22 | [Discussions]: https://github.com/CarperAI/trlx/discussions 23 | 24 | [CarperAI Discord]: https://discord.gg/X2gHZMRP6m 25 | 26 | - type: textarea 27 | attributes: 28 | label: 🐛 Describe the bug 29 | description: >- 30 | Please provide a clear and concise description of what the problem is, 31 | preferably with self-contained code to reproduce the issue. You may want 32 | to follow the suggestions outlined in [this guide][Guide]. If you observe 33 | an error, please paste the error message including the full traceback. 34 | 35 | 36 | [Guide]: https://matthewrocklin.com/blog/work/2018/02/28/minimal-bug-reports 37 | 38 | placeholder: | 39 | A description of what the bug is. 
40 | 41 | ```python 42 | # Sample code to reproduce the bug, if applicable. 43 | ``` 44 | 45 | ``` 46 | The error message, with the full traceback. 47 | ``` 48 | 49 | validations: 50 | required: true 51 | 52 | - type: input 53 | attributes: 54 | label: Which trlX version are you using? 55 | placeholder: For example, `trlx==1.0.0` 56 | 57 | - type: input 58 | attributes: 59 | label: Additional system and package information 60 | placeholder: Python version, `transformers` version, OS (Linux/Mac/Windows/WSL), etc. 61 | 62 | - type: markdown 63 | attributes: 64 | value: > 65 | Thanks for contributing 🐠! 66 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/documentation.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: 📚 Documentation 3 | description: Report an issue related to https://trlx.readthedocs.io/en/latest/index.html 4 | labels: 5 | - documentation 6 | 7 | body: 8 | - type: textarea 9 | attributes: 10 | label: 📚 The doc issue 11 | description: > 12 | Please provide a clear and concise description of what content in https://trlx.readthedocs.io/en/latest/index.html is an issue. 13 | validations: 14 | required: true 15 | 16 | - type: textarea 17 | attributes: 18 | label: Suggest a potential alternative/fix 19 | description: > 20 | Tell us how we could improve the documentation in this regard. 21 | 22 | - type: markdown 23 | attributes: 24 | value: > 25 | Thanks for contributing 🐠! 26 | -------------------------------------------------------------------------------- /.github/ISSUE_TEMPLATE/feature_request.yml: -------------------------------------------------------------------------------- 1 | --- 2 | name: 🚀 Feature Request 3 | description: Submit a proposal/request for a new trlX feature 4 | labels: 5 | - feature request 6 | 7 | body: 8 | - type: textarea 9 | attributes: 10 | label: 🚀 The feature, motivation, and pitch 11 | description: > 12 | Please provide a clear and concise description of the feature proposal. 13 | Outline the motivation for the proposal; is your feature request related to a 14 | specific problem? E.g., *"I'm working on X and would like Y to be 15 | possible"*. If this is related to another GitHub issue, please link here 16 | too. 17 | validations: 18 | required: true 19 | 20 | - type: textarea 21 | attributes: 22 | label: Alternatives 23 | description: > 24 | A description of any alternative solutions or features you've considered, 25 | if any. 26 | 27 | - type: textarea 28 | attributes: 29 | label: Additional context 30 | description: > 31 | Add any other context or screenshots about the feature request. 32 | 33 | - type: markdown 34 | attributes: 35 | value: > 36 | Thanks for contributing 🐠! 
37 | -------------------------------------------------------------------------------- /.github/workflows/build.yml: -------------------------------------------------------------------------------- 1 | name: Build 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | build: 11 | 12 | runs-on: ubuntu-latest 13 | 14 | steps: 15 | - uses: actions/checkout@v2 16 | 17 | - name: Set up Python 3.9 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: 3.9.13 21 | cache: 'pip' 22 | 23 | - name: Install dependencies 24 | run: | 25 | python -m pip install --upgrade pip 26 | pip install -e .[dev] 27 | 28 | - name: Lint with flake8 29 | run: | 30 | # Stop the build if there are Python syntax errors or undefined names 31 | flake8 . --count --select=E9,F63,F7 --show-source --statistics 32 | # Exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide 33 | flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics 34 | 35 | - name: Run tests 36 | run: | 37 | pytest -vv --cov=trlx/ tests/ 38 | 39 | - name: Upload coverage to Codecov 40 | run: | 41 | bash <(curl -s https://codecov.io/bash) -t $CODECOV_TOKEN 42 | -------------------------------------------------------------------------------- /.github/workflows/code_quality.yml: -------------------------------------------------------------------------------- 1 | name: Code Quality 2 | 3 | on: [pull_request] 4 | 5 | jobs: 6 | code-quality: 7 | runs-on: ubuntu-20.04 8 | steps: 9 | - uses: actions/checkout@v2 10 | - uses: actions/setup-python@v2 11 | with: 12 | python-version: 3.9 13 | - uses: pre-commit/action@v2.0.3 14 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.bak 2 | .gitattributes 3 | .last_checked 4 | .gitconfig 5 | *.bak 6 | *.log 7 | *~ 8 | ~* 9 | _tmp* 10 | tmp* 11 | tags 12 | 13 | # Byte-compiled / optimized / DLL files 14 | __pycache__/ 15 | *.py[cod] 16 | *$py.class 17 | 18 | # C extensions 19 | *.so 20 | 21 | # Distribution / packaging 22 | .Python 23 | env/ 24 | build/ 25 | develop-eggs/ 26 | dist/ 27 | downloads/ 28 | eggs/ 29 | .eggs/ 30 | lib/ 31 | lib64/ 32 | parts/ 33 | sdist/ 34 | var/ 35 | wheels/ 36 | *.egg-info/ 37 | .installed.cfg 38 | *.egg 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | .hypothesis/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # celery beat schedule file 89 | celerybeat-schedule 90 | 91 | # SageMath parsed files 92 | *.sage.py 93 | 94 | # dotenv 95 | .env 96 | 97 | # virtualenv 98 | .venv 99 | venv/ 100 | ENV/ 101 | 102 | # Spyder project settings 103 | .spyderproject 104 | .spyproject 105 | 106 | # Rope project settings 107 | .ropeproject 108 | 109 | # mkdocs documentation 110 | /site 111 | 112 | # mypy 113 | .mypy_cache/ 114 | 115 | .vscode 116 | *.swp 117 | 118 | # osx generated files 119 | .DS_Store 120 | .DS_Store? 121 | .Trashes 122 | ehthumbs.db 123 | Thumbs.db 124 | .idea 125 | 126 | # pytest 127 | .pytest_cache 128 | 129 | # tools/trust-doc-nbs 130 | docs_src/.last_checked 131 | 132 | # symlinks to fastai 133 | docs_src/fastai 134 | tools/fastai 135 | 136 | # link checker 137 | checklink/cookies.txt 138 | 139 | # .gitconfig is now autogenerated 140 | .gitconfig 141 | 142 | 143 | nbs/wandb/ 144 | 145 | wandb/ 146 | 147 | OUT/ 148 | 149 | 150 | examples/experiments/grounded_program_synthesis/dataset 151 | ckpts/ 152 | 153 | ray_results/ 154 | checkpoints/ 155 | _checkpoints/ 156 | base_models/ 157 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | # This should be the _latest_ version of python supported by us 4 | default_language_version: 5 | python: python3.9 6 | repos: 7 | - repo: https://github.com/pre-commit/pre-commit-hooks 8 | rev: v3.2.0 9 | hooks: 10 | - id: trailing-whitespace 11 | - id: end-of-file-fixer 12 | - id: check-yaml 13 | - repo: https://github.com/psf/black 14 | rev: 22.10.0 15 | hooks: 16 | - id: black 17 | files: ^(trlx|examples|unittests|setup.py)/ 18 | -------------------------------------------------------------------------------- /.readthedocs.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | 3 | sphinx: 4 | configuration: docs/source/conf.py 5 | 6 | python: 7 | version: 3.9 8 | install: 9 | - requirements: docs/requirements.txt 10 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Contributor Covenant Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | We as members, contributors, and leaders pledge to make participation in our 6 | community a harassment-free experience for everyone, regardless of age, body 7 | size, visible or invisible disability, ethnicity, sex characteristics, gender 8 | identity and expression, level of experience, education, socio-economic status, 9 | nationality, personal appearance, race, religion, or sexual identity 10 | and orientation. 
11 | 12 | We pledge to act and interact in ways that contribute to an open, welcoming, 13 | diverse, inclusive, and healthy community. 14 | 15 | ## Our Standards 16 | 17 | Examples of behavior that contributes to a positive environment for our 18 | community include: 19 | 20 | * Demonstrating empathy and kindness toward other people 21 | * Being respectful of differing opinions, viewpoints, and experiences 22 | * Giving and gracefully accepting constructive feedback 23 | * Accepting responsibility and apologizing to those affected by our mistakes, 24 | and learning from the experience 25 | * Focusing on what is best not just for us as individuals, but for the 26 | overall community 27 | 28 | Examples of unacceptable behavior include: 29 | 30 | * The use of sexualized language or imagery, and sexual attention or 31 | advances of any kind 32 | * Trolling, insulting or derogatory comments, and personal or political attacks 33 | * Public or private harassment 34 | * Publishing others' private information, such as a physical or email 35 | address, without their explicit permission 36 | * Other conduct which could reasonably be considered inappropriate in a 37 | professional setting 38 | 39 | ## Enforcement Responsibilities 40 | 41 | Community leaders are responsible for clarifying and enforcing our standards of 42 | acceptable behavior and will take appropriate and fair corrective action in 43 | response to any behavior that they deem inappropriate, threatening, offensive, 44 | or harmful. 45 | 46 | Community leaders have the right and responsibility to remove, edit, or reject 47 | comments, commits, code, wiki edits, issues, and other contributions that are 48 | not aligned to this Code of Conduct, and will communicate reasons for moderation 49 | decisions when appropriate. 50 | 51 | ## Scope 52 | 53 | This Code of Conduct applies within all community spaces, and also applies when 54 | an individual is officially representing the community in public spaces. 55 | Examples of representing our community include using an official e-mail address, 56 | posting via an official social media account, or acting as an appointed 57 | representative at an online or offline event. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported to the community leaders responsible for enforcement at 63 | caperai@stability.ai. 64 | All complaints will be reviewed and investigated promptly and fairly. 65 | 66 | All community leaders are obligated to respect the privacy and security of the 67 | reporter of any incident. 68 | 69 | ## Enforcement Guidelines 70 | 71 | Community leaders will follow these Community Impact Guidelines in determining 72 | the consequences for any action they deem in violation of this Code of Conduct: 73 | 74 | ### 1. Correction 75 | 76 | **Community Impact**: Use of inappropriate language or other behavior deemed 77 | unprofessional or unwelcome in the community. 78 | 79 | **Consequence**: A private, written warning from community leaders, providing 80 | clarity around the nature of the violation and an explanation of why the 81 | behavior was inappropriate. A public apology may be requested. 82 | 83 | ### 2. Warning 84 | 85 | **Community Impact**: A violation through a single incident or series 86 | of actions. 87 | 88 | **Consequence**: A warning with consequences for continued behavior. 
No 89 | interaction with the people involved, including unsolicited interaction with 90 | those enforcing the Code of Conduct, for a specified period of time. This 91 | includes avoiding interactions in community spaces as well as external channels 92 | like social media. Violating these terms may lead to a temporary or 93 | permanent ban. 94 | 95 | ### 3. Temporary Ban 96 | 97 | **Community Impact**: A serious violation of community standards, including 98 | sustained inappropriate behavior. 99 | 100 | **Consequence**: A temporary ban from any sort of interaction or public 101 | communication with the community for a specified period of time. No public or 102 | private interaction with the people involved, including unsolicited interaction 103 | with those enforcing the Code of Conduct, is allowed during this period. 104 | Violating these terms may lead to a permanent ban. 105 | 106 | ### 4. Permanent Ban 107 | 108 | **Community Impact**: Demonstrating a pattern of violation of community 109 | standards, including sustained inappropriate behavior, harassment of an 110 | individual, or aggression toward or disparagement of classes of individuals. 111 | 112 | **Consequence**: A permanent ban from any sort of public interaction within 113 | the community. 114 | 115 | ## Attribution 116 | 117 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], 118 | version 2.0, available at 119 | https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. 120 | 121 | Community Impact Guidelines were inspired by [Mozilla's code of conduct 122 | enforcement ladder](https://github.com/mozilla/diversity). 123 | 124 | [homepage]: https://www.contributor-covenant.org 125 | 126 | For answers to common questions about this code of conduct, see the FAQ at 127 | https://www.contributor-covenant.org/faq. Translations are available at 128 | https://www.contributor-covenant.org/translations. 129 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to `trlX` 2 | 3 | Looking to improve `trlX`? Thanks for considering! 4 | 5 | There are many ways to contribute, from writing tutorials in [Colab notebooks](https://colab.research.google.com) to improving the project's [documentation](https://trlx.readthedocs.io), submitting bug reports and feature requests, or even implementing new features themselves. See the outstanding [issues](https://github.com/CarperAI/trlx/issues) for ideas on where to begin. 6 | 7 | Here are some guidelines to help you get started 🚀. 8 | 9 | ## Submitting a bug report or a feature request¶ 10 | 11 | To submit a bug report or a feature request, please open an [issue](https://github.com/CarperAI/trlx/issues) by clicking on the `New Issue` button and selecting the respective issue template. Make sure to fill out all the required information and provide as much detail as possible. For bug reports, this means including a minimal code example that reproduces the bug, and for feature requests, it means providing a clear and detailed description of the feature you would like to see implemented. 12 | 13 | ## Submitting code 14 | 15 | > **Note**: Make sure to first search through the [issue tracker](https://github.com/CarperAI/trlx/issues) and [PR list](https://github.com/CarperAI/trlx/pulls) to avoid duplicating work. 
If you want to work on a non-trivial feature, we highly recommend that you first open an issue in the [issue tracker](https://github.com/CarperAI/trlx/issues) to get feedback from core developers. 16 | 17 | Follow these steps to start contributing code: 18 | 19 | 1. Create your own [fork](https://docs.github.com/en/get-started/quickstart/fork-a-repo#forking-a-repository) of the repository and clone it to your local machine. 20 | ```bash 21 | git clone https://github.com/<your-username>/trlx.git 22 | cd trlx 23 | git remote add upstream https://github.com/CarperAI/trlx.git 24 | ``` 25 | 2. Create a new branch for your changes and give it a concise name that reflects your contribution. 26 | ```bash 27 | git checkout -b <branch-name> 28 | ``` 29 | 3. Install the development dependencies in a Python environment. 30 | ```bash 31 | pip install -e ".[dev]" 32 | pre-commit install 33 | ``` 34 | 4. Implement your changes. Make small, independent, and well-documented commits along the way (check out [these](https://cbea.ms/git-commit/) tips). 35 | 5. Add unit tests whenever appropriate and ensure that the tests pass. To run the entire test suite, use the following command from within the project root directory. 36 | ```bash 37 | pytest 38 | ``` 39 | For changes with minimal project scope (e.g. a simple bug fix), you might want to run the unit tests for just a specific test file instead: 40 | ```bash 41 | pytest -vv -k "<test-file-name>" 42 | ``` 43 | 6. Commit your final changes. Our `pre-commit` hooks will automatically run before each commit and will prevent you from committing code that does not pass our style and linter checks. They'll also automatically format your code! To run these manually, use the following command: 44 | ```bash 45 | pre-commit run --all-files 46 | ``` 47 | 48 | 7. Push the changes to your fork. 49 | 50 | Finally ... 🥁 ... Create a [pull request](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests/creating-a-pull-request) to the `trlX` repository! Make sure to include a description of your changes and link to any relevant issues. 51 | 52 | > __Tip__: If you're looking to introduce an experimental feature, we suggest testing the behavior of your proposed feature on some of the existing [examples](https://github.com/CarperAI/trlx/tree/master/examples), such as [random walks](https://github.com/CarperAI/trlx/blob/master/examples/randomwalks). This will help you get a better sense of how the feature would work in practice and will also help you identify any potential flaws in the implementation. 53 | 54 | ## Asking questions 55 | 56 | Have a question? Rather than opening an issue, you can readily chat with the core team on our [Discord server](https://discord.gg/canadagoose). 57 | 58 | ## Code of conduct 59 | 60 | This project adheres to the [Contributor Covenant Code of Conduct](https://github.com/CarperAI/trlx/blob/master/CODE_OF_CONDUCT.md). By participating, you are expected to uphold this code. 61 | 62 | ## License 63 | 64 | By contributing, you agree that your contributions will be licensed under the project's MIT License. 65 | 66 | # Thank you for your contribution 🐠!
67 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 CarperAI 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [docs-image]: https://readthedocs.org/projects/trlX/badge/?version=latest 2 | [docs-url]: https://trlX.readthedocs.io/en/latest/?badge=latest 3 | 4 | # Transformer Reinforcement Learning X 5 | 6 | trlX allows you to fine-tune 🤗 Hugging Face supported language models (`gpt2`, `gpt-j`, `gpt-neo` and `gpt-neox` based) up to 20B parameters using reinforcement learning via either a provided reward function or reward-labeled dataset. Proximal Policy Optimization ([PPO](https://arxiv.org/pdf/1909.08593.pdf)) and Implicit Language Q-Learning ([ILQL](https://sea-snell.github.io/ILQL_site/)) are implemented. 7 | 8 | You can read more about trlX in our [documentation](https://trlX.readthedocs.io). 9 | 10 | ## Installation 11 | ```bash 12 | git clone https://github.com/CarperAI/trlx.git 13 | cd trlx 14 | pip install torch --extra-index-url https://download.pytorch.org/whl/cu116 # for cuda 15 | pip install -e . 16 | ``` 17 | 18 | ## How to Train 19 | You can train a model using a reward function or a reward-labeled dataset. 20 | 21 | #### Using a reward function 22 | ```python 23 | model = trlx.train('gpt2', reward_fn=lambda samples: [sample.count('cats') for sample in samples]) 24 | ``` 25 | #### Using a reward-labeled dataset 26 | ```python 27 | model = trlx.train('EleutherAI/gpt-j-6B', dataset=[('dolphins', 'geese'), (1.0, 100.0)]) 28 | ``` 29 | 30 | #### Trained model is a wrapper over a given autoregressive model 31 | ```python 32 | model.generate(**tokenizer('Q: Who rules the world? 
A:', return_tensors='pt'), do_sample=True) 33 | ``` 34 | 35 | #### Use 🤗 Accelerate to launch distributed training 36 | 37 | ```bash 38 | accelerate config # choose DeepSpeed option 39 | accelerate launch examples/simulacra.py 40 | ``` 41 | 42 | #### Use Ray Tune to launch hyperparameter sweep 43 | ```bash 44 | python -m trlx.sweep --config configs/sweeps/ppo_sweep.yml examples/ppo_sentiments.py 45 | ``` 46 | 47 | For more usage see [examples](./examples) 48 | 49 | ## Contributing 50 | 51 | For development check out these [guidelines](./CONTRIBUTING.md) 52 | and also read our [docs](https://trlX.readthedocs.io) 53 | 54 | ## Acknowledgements 55 | 56 | Many thanks to Leandro von Werra for contributing with [trl](https://github.com/lvwerra/trl/), a library that initially inspired this repo. 57 | -------------------------------------------------------------------------------- /configs/deepspeed_configs/accelerate_t5_configs.yml: -------------------------------------------------------------------------------- 1 | command_file: null 2 | commands: null 3 | compute_environment: LOCAL_MACHINE 4 | deepspeed_config: 5 | gradient_accumulation_steps: 1 6 | gradient_clipping: 1.0 7 | offload_optimizer_device: cpu 8 | offload_param_device: none 9 | zero3_init_flag: false 10 | zero_stage: 2 11 | distributed_type: DEEPSPEED 12 | downcast_bf16: 'no' 13 | dynamo_backend: 'NO' 14 | fsdp_config: {} 15 | gpu_ids: null 16 | machine_rank: 0 17 | main_process_ip: null 18 | main_process_port: null 19 | main_training_function: main 20 | megatron_lm_config: {} 21 | mixed_precision: bf16 22 | num_machines: 1 23 | num_processes: 1 24 | rdzv_backend: static 25 | same_network: true 26 | tpu_name: null 27 | tpu_zone: null 28 | use_cpu: false -------------------------------------------------------------------------------- /configs/deepspeed_configs/default_configs.yml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | gradient_accumulation_steps: 1 4 | gradient_clipping: 1.0 5 | offload_optimizer_device: none 6 | offload_param_device: none 7 | zero3_init_flag: true 8 | zero_stage: 2 9 | distributed_type: DEEPSPEED 10 | downcast_bf16: 'no' 11 | fsdp_config: {} 12 | machine_rank: 0 13 | main_process_ip: null 14 | main_process_port: null 15 | main_training_function: main 16 | mixed_precision: 'no' 17 | num_machines: 1 18 | num_processes: 2 19 | use_cpu: false 20 | -------------------------------------------------------------------------------- /configs/ilql_config.yml: -------------------------------------------------------------------------------- 1 | train: 2 | seq_length: 64 3 | batch_size: 128 4 | epochs: 100 5 | total_steps: 1000 6 | 7 | checkpoint_interval: 1000 8 | eval_interval: 100 9 | 10 | pipeline: "PromptPipeline" 11 | orchestrator: "OfflineOrchestrator" 12 | 13 | seed: 1000 14 | 15 | model: 16 | model_type: "AccelerateILQLModel" 17 | model_path: "gpt2" 18 | tokenizer_path: "gpt2" 19 | num_layers_unfrozen: -1 20 | 21 | optimizer: 22 | name: "adamw" 23 | kwargs: 24 | lr: 5.0e-5 25 | betas: [0.9, 0.95] 26 | eps: 1.0e-8 27 | weight_decay: 1.0e-6 28 | 29 | scheduler: 30 | name: "cosine_annealing" 31 | kwargs: 32 | T_max: 1000 # train.total_steps 33 | eta_min: 5.0e-5 34 | 35 | method: 36 | name: "ilqlconfig" 37 | tau: 0.7 38 | gamma: 0.99 39 | cql_scale: 0.1 40 | awac_scale: 1 41 | alpha: 0.001 42 | steps_for_target_q_sync: 5 43 | two_qs: true 44 | gen_kwargs: 45 | max_new_tokens: 56 46 | top_k: 20 47 | beta: 4 48 | 
temperature: 1.0 49 | -------------------------------------------------------------------------------- /configs/ppo_config.yml: -------------------------------------------------------------------------------- 1 | train: 2 | seq_length: 1024 3 | epochs: 50 4 | total_steps: 10000 5 | batch_size: 128 6 | 7 | checkpoint_interval: 10000 8 | eval_interval: 100 9 | 10 | pipeline: "PromptPipeline" 11 | orchestrator: "PPOOrchestrator" 12 | 13 | model: 14 | model_type: "AcceleratePPOModel" 15 | model_path: "lvwerra/gpt2-imdb" 16 | tokenizer_path: "gpt2" 17 | num_layers_unfrozen: 2 18 | 19 | optimizer: 20 | name: "adamw" 21 | kwargs: 22 | lr: 1.0e-4 23 | betas: [0.9, 0.95] 24 | eps: 1.0e-8 25 | weight_decay: 1.0e-6 26 | 27 | scheduler: 28 | name: "cosine_annealing" 29 | kwargs: 30 | T_max: 10000 # train.total_steps 31 | eta_min: 1.0e-4 32 | 33 | method: 34 | name: "ppoconfig" 35 | num_rollouts: 128 36 | chunk_size: 128 37 | ppo_epochs: 4 38 | init_kl_coef: 0.05 39 | target: 6 40 | horizon: 10000 41 | gamma: 1 42 | lam: 0.95 43 | cliprange: 0.2 44 | cliprange_value: 0.2 45 | vf_coef: 1 46 | scale_reward: False 47 | ref_mean: null 48 | ref_std: null 49 | cliprange_reward: 10 50 | gen_kwargs: 51 | max_new_tokens: 40 52 | top_k: 0 53 | top_p: 1.0 54 | do_sample: True 55 | -------------------------------------------------------------------------------- /configs/ppo_config_t5.yml: -------------------------------------------------------------------------------- 1 | train: 2 | seq_length: 512 3 | epochs: 500000 # We'll stop training when we reach 10,000 steps 4 | total_steps: 6000 5 | batch_size: 1 6 | 7 | checkpoint_interval: 100000000000 # Don't save checkpoints 8 | checkpoint_dir: "/root/trlx-with-T5/checkpoints" 9 | eval_interval: 64 10 | 11 | pipeline: "PromptPipeline" 12 | orchestrator: "T5PPOOrchestrator" 13 | seed: 314159 14 | 15 | model: 16 | model_type: "T5AcceleratePPOModel" 17 | model_path: "/root/trlx-with-T5/base_models/FT-Flan-T5-XXL" 18 | tokenizer_path: "google/flan-t5-xxl" 19 | # model_path: "google/flan-t5-small" 20 | # tokenizer_path: "google/flan-t5-small" 21 | 22 | optimizer: 23 | name: "adamw" 24 | kwargs: 25 | lr: 5.0e-5 26 | betas: [0.9, 0.95] 27 | eps: 1.0e-8 28 | weight_decay: 1.0e-6 29 | 30 | scheduler: 31 | name: "cosine_annealing" 32 | kwargs: 33 | T_max: 6000 # train.total_steps 34 | eta_min: 5.0e-5 35 | 36 | method: 37 | name: "ppoconfig" 38 | num_rollouts: 8 39 | chunk_size: 8 40 | ppo_epochs: 4 41 | init_kl_coef: 0.5 42 | target: 6.0 43 | horizon: 6000 44 | gamma: 0.99 45 | lam: 0.95 46 | cliprange: 0.2 47 | cliprange_value: 0.2 48 | vf_coef: 1.0 49 | scale_reward: False 50 | ref_mean: null 51 | ref_std: null 52 | cliprange_reward: 12 53 | gen_kwargs: 54 | max_new_tokens: 50 55 | top_k: 0 56 | top_p: 1.0 57 | do_sample: True 58 | -------------------------------------------------------------------------------- /configs/ppo_config_t5_old.yml: -------------------------------------------------------------------------------- 1 | train: 2 | seq_length: 512 3 | epochs: 500000 # We'll stop training when we reach 10,000 steps 4 | total_steps: 6000 5 | batch_size: 1 6 | 7 | checkpoint_interval: 100000000000 # Don't save checkpoints 8 | checkpoint_dir: "/root/trlx-with-T5/checkpoints" 9 | eval_interval: 64 10 | 11 | pipeline: "PromptPipeline" 12 | orchestrator: "T5PPOOrchestrator" 13 | seed: 1000 14 | 15 | model: 16 | model_type: "T5AcceleratePPOModel" 17 | model_path: "/root/trlx-with-T5/base_models/FT-Flan-T5-XXL" 18 | tokenizer_path: "google/flan-t5-xxl" 19 | # model_path: 
"google/flan-t5-small" 20 | # tokenizer_path: "google/flan-t5-small" 21 | 22 | optimizer: 23 | name: "adamw" 24 | kwargs: 25 | lr: 5.0e-5 26 | betas: [0.9, 0.95] 27 | eps: 1.0e-8 28 | weight_decay: 1.0e-6 29 | 30 | scheduler: 31 | name: "cosine_annealing" 32 | kwargs: 33 | T_max: 6000 # train.total_steps 34 | eta_min: 5.0e-5 35 | 36 | method: 37 | name: "ppoconfig" 38 | num_rollouts: 8 39 | chunk_size: 8 40 | ppo_epochs: 4 41 | init_kl_coef: 0.5 42 | target: 6 43 | horizon: 6000 44 | gamma: 0.99 45 | lam: 0.95 46 | cliprange: 0.2 47 | cliprange_value: 0.2 48 | vf_coef: 1.0 49 | scale_reward: False 50 | ref_mean: null 51 | ref_std: null 52 | cliprange_reward: 12 53 | gen_kwargs: 54 | max_new_tokens: 50 55 | top_k: 0 56 | top_p: 1.0 57 | do_sample: True -------------------------------------------------------------------------------- /configs/ppo_gptj.yml: -------------------------------------------------------------------------------- 1 | train: 2 | seq_length: 48 3 | epochs: 10 4 | total_steps: 80000 5 | batch_size: 8 6 | 7 | checkpoint_interval: 1000000 8 | eval_interval: 16 9 | 10 | pipeline: "PromptPipeline" 11 | orchestrator: "PPOOrchestrator" 12 | 13 | model: 14 | model_type: "AcceleratePPOModel" 15 | model_path: "EleutherAI/gpt-j-6B" 16 | tokenizer_path: "gpt2" 17 | num_layers_unfrozen: 2 18 | 19 | optimizer: 20 | name: "adamw" 21 | kwargs: 22 | lr: 1.412e-4 23 | betas: [0.9, 0.95] 24 | eps: 1.0e-8 25 | weight_decay: 1.0e-6 26 | 27 | scheduler: 28 | name: "cosine_annealing" 29 | kwargs: 30 | T_max: 80000 # train.total_steps 31 | eta_min: 1.412e-4 32 | 33 | method: 34 | name: "ppoconfig" 35 | num_rollouts: 8 36 | chunk_size: 8 37 | ppo_epochs: 4 38 | init_kl_coef: 0.2 39 | target: 6 40 | horizon: 10000 41 | gamma: 1 42 | lam: 0.95 43 | cliprange: 0.2 44 | cliprange_value: 0.2 45 | vf_coef: 0.2 46 | scale_reward: False 47 | ref_mean: null 48 | ref_std: null 49 | cliprange_reward: 10 50 | gen_kwargs: 51 | max_new_tokens: 48 52 | top_k: 0.0 53 | top_p: 0.7 54 | do_sample: True 55 | temperature: 0.5 56 | -------------------------------------------------------------------------------- /configs/sweeps/ilql_sweep.yml: -------------------------------------------------------------------------------- 1 | tune_config: 2 | mode: "max" 3 | metric: "metrics/sentiments" 4 | search_alg: "random" 5 | scheduler: "fifo" 6 | num_samples: 32 7 | 8 | lr_init: 9 | strategy: "loguniform" 10 | values: [0.00001, 0.01] 11 | tau: 12 | strategy: "uniform" 13 | values: [0.6, 0.9] 14 | steps_for_target_q_sync: 15 | strategy: "choice" 16 | values: [1, 5, 10] 17 | alpha: 18 | strategy: "loguniform" 19 | values: [0.001, 1.0] 20 | -------------------------------------------------------------------------------- /configs/sweeps/ppo_sweep.yml: -------------------------------------------------------------------------------- 1 | tune_config: 2 | mode: "max" 3 | metric: "mean_reward" 4 | search_alg: "random" 5 | scheduler: "fifo" 6 | num_samples: 32 7 | 8 | # https://docs.ray.io/en/latest/tune/api_docs/search_space.html#tune-sample-docs 9 | lr_init: 10 | strategy: "loguniform" 11 | values: [0.00001, 0.01] 12 | init_kl_coef: 13 | strategy: "uniform" 14 | values: [0, 0.2] 15 | vf_coef: 16 | strategy: "uniform" 17 | values: [0.5, 2] 18 | -------------------------------------------------------------------------------- /configs/test_config.yml: -------------------------------------------------------------------------------- 1 | train: 2 | seq_length: 64 # Size of LM context 3 | epochs: 100 # Train for max(epochs, total_steps) 4 | 
total_steps: 1000 # Train for max(epochs, total_steps) 5 | batch_size: 16 # batch size 6 | 7 | checkpoint_interval: 10000 # checkpoint interval 8 | eval_interval: 128 # eval interval 9 | 10 | pipeline: "PromptPipeline" # prompt pipeline to load 11 | orchestrator: "PPOOrchestrator" # orchestrator to load 12 | 13 | model: 14 | model_type: "AcceleratePPOModel" # Name of accelerate model type to load 15 | model_path: "lvwerra/gpt2-imdb" # Name of hf model to load 16 | tokenizer_path: "gpt2" # Name of hf tokenizer to load 17 | num_layers_unfrozen: 2 # Number of bottom layers to freeze during training 18 | 19 | optimizer: 20 | name: "adamw" # Name of optimizer to load 21 | kwargs: 22 | lr: 1.412e-4 # Learning rate 23 | betas: [0.9, 0.95] # Adam betas 24 | eps: 1.0e-8 # Adam eps 25 | weight_decay: 1.0e-6 # Weight decay param 26 | 27 | scheduler: 28 | name: "cosine_annealing" # Name of learning rate scheduler 29 | kwargs: 30 | T_max: 10000 # Maximum number of steps 31 | eta_min: 1.412e-4 # Minimum learning rate 32 | 33 | method: 34 | name: "ppoconfig" # Name of RL method config 35 | num_rollouts: 128 # Number of rollouts to collect per epoch 36 | chunk_size: 128 # Number of rollouts to collect in one loop of orchestrator 37 | ppo_epochs: 4 # Number of ppo epochs 38 | init_kl_coef: 0.2 # init kl coefficient 39 | target: 6 # target kl coefficient, set None for fixed kl coef 40 | horizon: 10000 # PPO horizon 41 | gamma: 0.99 # PPO discount 42 | lam: 0.95 # PPO lambda 43 | cliprange: 0.2 # clip range 44 | cliprange_value: 0.2 # clip range 45 | vf_coef: 1.0 # value term weight 46 | scale_reward: "running" # False|"ref"|"running" estimate against which to scale rewards 47 | cliprange_reward: 10 48 | ref_mean: null 49 | ref_std: null 50 | gen_kwargs: 51 | max_length: 48 # LM max sample gen length 52 | min_length: 48 # LM min sample gen length 53 | top_k: 0.0 # top k 54 | top_p: 1.0 # top p 55 | do_sample: True # sample 56 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | %SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. 
Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.https://www.sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx==4.0.0 2 | sphinx_rtd_theme 3 | accelerate==0.12.0 4 | datasets==2.4.0 5 | deepspeed==0.7.3 6 | einops==0.4.1 7 | numpy==1.23.2 8 | tqdm==4.64.0 9 | transformers==4.21.2 10 | wandb==0.13.2 11 | torchtyping 12 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | sys.path.insert(0, os.path.abspath('../..')) 16 | 17 | 18 | # -- Project information ----------------------------------------------------- 19 | 20 | project = 'trlX' 21 | copyright = '2022, CarperAI' 22 | author = 'CarperAI' 23 | 24 | # -- General configuration --------------------------------------------------- 25 | 26 | # Add any Sphinx extension module names here, as strings. They can be 27 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 28 | # ones. 29 | 30 | import sphinx_rtd_theme 31 | 32 | extensions = [ 33 | 'sphinx_rtd_theme', 34 | 'sphinx.ext.todo', 35 | 'sphinx.ext.viewcode', 36 | 'sphinx.ext.autodoc', 37 | 'sphinx.ext.autosummary', 38 | 'sphinx.ext.autosectionlabel' 39 | ] 40 | 41 | # Add any paths that contain templates here, relative to this directory. 42 | templates_path = ['_templates'] 43 | 44 | # List of patterns, relative to source directory, that match files and 45 | # directories to ignore when looking for source files. 46 | # This pattern also affects html_static_path and html_extra_path. 47 | exclude_patterns = [] 48 | 49 | 50 | # -- Options for HTML output ------------------------------------------------- 51 | 52 | # The theme to use for HTML and HTML Help pages. See the documentation for 53 | # a list of builtin themes. 54 | # 55 | html_theme = 'sphinx_rtd_theme' 56 | 57 | # Add any paths that contain custom static files (such as style sheets) here, 58 | # relative to this directory. They are copied after the builtin static files, 59 | # so a file named "default.css" will overwrite the builtin "default.css". 60 | html_static_path = ['_static'] 61 | -------------------------------------------------------------------------------- /docs/source/configs.rst: -------------------------------------------------------------------------------- 1 | .. 
_configs: 2 | 3 | Configs 4 | ************************ 5 | 6 | Training a model in trlX requires you to set several configs: 7 | ModelConfig, which contains general info on the model being trained; TrainConfig, which contains 8 | training hyperparameters; and MethodConfig, which contains hyperparameters or settings for 9 | the specific method being used (i.e. ILQL or PPO). 10 | 11 | 12 | **General** 13 | 14 | .. autoclass:: trlx.data.configs.TRLConfig 15 | :members: 16 | 17 | .. autoclass:: trlx.data.configs.ModelConfig 18 | :members: 19 | 20 | .. autoclass:: trlx.data.configs.TrainConfig 21 | :members: 22 | 23 | .. autoclass:: trlx.data.method_configs.MethodConfig 24 | :members: 25 | 26 | **PPO** 27 | 28 | .. autoclass:: trlx.data.method_configs.PPOConfig 29 | :members: 30 | 31 | **ILQL** 32 | 33 | .. autoclass:: trlx.data.method_configs.ILQLConfig 34 | :members: 35 | -------------------------------------------------------------------------------- /docs/source/data.rst: -------------------------------------------------------------------------------- 1 | .. _data: 2 | 3 | Data Elements 4 | ************************ 5 | 6 | All of the major Carper projects (trlX, CHEESE, and magiCARP) use 7 | dataclasses corresponding to batches of data to communicate data between models and different 8 | components. trlX is no different, though it has many different dataclasses for 9 | different components like training or inference. Currently, we support PPO and ILQL, which 10 | each demand different kinds of data during training. 11 | 12 | 13 | **Basic Data Elements for Accelerate** 14 | 15 | .. autoclass:: trlx.data.accelerate_base_datatypes.PromptElement 16 | :members: 17 | 18 | .. autoclass:: trlx.data.accelerate_base_datatypes.PromptBatch 19 | :members: 20 | 21 | .. autoclass:: trlx.data.accelerate_base_datatypes.AccelerateRLElement 22 | :members: 23 | 24 | .. autoclass:: trlx.data.accelerate_base_datatypes.AccelerateRLBatchElement 25 | :members: 26 | 27 | **Data Elements for PPO** 28 | 29 | .. autoclass:: trlx.data.ppo_types.PPORLElement 30 | :members: 31 | 32 | .. autoclass:: trlx.data.ppo_types.PPORLBatch 33 | :members: 34 | 35 | **Data Elements for ILQL** 36 | 37 | .. autoclass:: trlx.data.ilql_types.ILQLElement 38 | :members: 39 | 40 | .. autoclass:: trlx.data.ilql_types.ILQLBatch 41 | :members: 42 | -------------------------------------------------------------------------------- /docs/source/examples.rst: -------------------------------------------------------------------------------- 1 | .. _examples: 2 | 3 | Examples 4 | ************************ 5 | 6 | In the ``examples`` folder you can find several example training tasks. Check 7 | the configs folder for the associated config files. ``examples.randomwalks`` 8 | does offline reinforcement learning on a set of graph random walks to stitch shortest 9 | paths to some destination. ``examples.simulacra`` optimizes prompts by using a 10 | prompts-ratings dataset (https://github.com/JD-P/simulacra-aesthetic-captions). 11 | ``examples.architext`` tries to optimize designs represented textually by 12 | minimizing the number of rooms (the pretrained model is under a license on hf).
13 | ``examples.ilql_sentiments`` and ``examples.ppo_sentiments`` train models to generate 14 | movie reviews with a positive sentiment: in the offline setting, by fitting to IMDB 15 | dataset sentiment scores, and in the online setting, by sampling from a model finetuned on IMDB 16 | and rating the samples with a learned sentiment reward model. You can tweak 17 | these scripts to your liking and tune the hyperparameters for your own problem if you 18 | wish to use trlx for some custom task. 19 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. trlX documentation master file, created by 2 | sphinx-quickstart on Mon Oct 3 21:21:33 2022. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to trlX's documentation! 7 | ================================ 8 | trlX is a library made for training large language models using reinforcement learning. It 9 | currently supports training using PPO or ILQL for models up to 20B parameters using Accelerate. 10 | 11 | .. toctree:: 12 | :maxdepth: 2 13 | :caption: Contents: 14 | 15 | data 16 | models 17 | orchestrator 18 | configs 19 | pipeline 20 | examples 21 | 22 | Indices and tables 23 | ================== 24 | 25 | * :ref:`genindex` 26 | * :ref:`modindex` 27 | * :ref:`search` 28 | -------------------------------------------------------------------------------- /docs/source/models.rst: -------------------------------------------------------------------------------- 1 | .. _models: 2 | 3 | RL Models 4 | ******************* 5 | 6 | RL Models are what you're training with trlX. Currently, we support PPO and ILQL. 7 | Note that new models must be registered with ``trlx.model.register_model``. 8 | 9 | **General** 10 | 11 | .. autoclass:: trlx.model.BaseRLModel 12 | :members: 13 | 14 | .. autoclass:: trlx.model.accelerate_base_model.AccelerateRLModel 15 | :members: 16 | 17 | **PPO** 18 | 19 | .. autoclass:: trlx.model.accelerate_ppo_model.AcceleratePPOModel 20 | :members: 21 | 22 | .. autoclass:: trlx.model.nn.ppo_models.CausalLMWithValueHead 23 | :members: 24 | 25 | .. autoclass:: trlx.model.nn.ppo_models.GPTModelBranch 26 | :members: 27 | 28 | .. autoclass:: trlx.model.nn.ppo_models.OPTModelBranch 29 | :members: 30 | 31 | .. autoclass:: trlx.model.nn.ppo_models.CausalLMHydraWithValueHead 32 | :members: 33 | 34 | **ILQL** 35 | 36 | .. autoclass:: trlx.model.accelerate_ilql_model.AccelerateILQLModel 37 | :members: 38 | 39 | .. autoclass:: trlx.model.nn.ilql_models.CausalLMWithValueHeads 40 | :members: 41 | -------------------------------------------------------------------------------- /docs/source/orchestrator.rst: -------------------------------------------------------------------------------- 1 | .. _orchestrator: 2 | 3 | Orchestrators 4 | ******************* 5 | 6 | Orchestrators manage reading data from a pipeline and creating RL data elements (i.e. ``trlx.data.RLElement``) 7 | to push to a model's rollout storage. Use the ``trlx.orchestrator.register_orchestrator`` decorator when creating 8 | new orchestrators. 9 | 10 | **General** 11 | 12 | .. autoclass:: trlx.orchestrator.Orchestrator 13 | :members: 14 | 15 | **PPO** 16 | 17 | .. autoclass:: trlx.orchestrator.ppo_orchestrator.PPOOrchestrator 18 | :members: 19 | 20 | **ILQL** 21 | 22 | ..
autoclass:: trlx.orchestrator.offline_orchestrator.OfflineOrchestrator 23 | :members: 24 | -------------------------------------------------------------------------------- /docs/source/pipeline.rst: -------------------------------------------------------------------------------- 1 | .. _pipeline: 2 | 3 | Pipelines 4 | ************************ 5 | 6 | Pipelines are how you read from a dataset with trlX. Rollout stores are how models store experiences created 7 | for them by the orchestrator. The model is then trained on the experiences in its rollout store. 8 | 9 | **General** 10 | 11 | .. autoclass:: trlx.pipeline.BasePipeline 12 | :members: 13 | 14 | .. autoclass:: trlx.pipeline.BaseRolloutStore 15 | :members: 16 | 17 | **PPO** 18 | 19 | .. autoclass:: trlx.pipeline.ppo_pipeline.PPORolloutStorage 20 | :members: 21 | 22 | **ILQL** 23 | 24 | .. autoclass:: trlx.pipeline.offline_pipeline.PromptPipeline 25 | :members: 26 | 27 | .. autoclass:: trlx.pipeline.offline_pipeline.ILQLRolloutStorage 28 | :members: 29 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CG80499/trlx-with-T5/f4eae6703eee125a7adf6a291031f5efe76e2ed7/examples/__init__.py -------------------------------------------------------------------------------- /examples/architext.py: -------------------------------------------------------------------------------- 1 | # Toy example of optimizing textual interior designs to output the least number of rooms 2 | # Also see https://architext.design/ 3 | import yaml 4 | import trlx 5 | from trlx.data.configs import TRLConfig 6 | 7 | def reward_fn(samples): 8 | "Gives a negative count of rooms for each sample" 9 | return [-sample.count(":") for sample in samples] 10 | 11 | 12 | prompts = [ 13 | "[prompt] the bedroom is adjacent to the living room [layout]", 14 | "[prompt] a bedroom is adjacent to the living room [layout]", 15 | "[prompt] the bedroom is adjacent to the kitchen [layout]", 16 | "[prompt] a bedroom is adjacent to the kitchen [layout]", 17 | "[prompt] the bedroom is adjacent to the kitchen [layout]", 18 | "[prompt] the kitchen is adjacent to the bathroom [layout]", 19 | "[prompt] a bathroom is adjacent to the living room [layout]", 20 | "[prompt] the bathroom is adjacent to the living room [layout]", 21 | "[prompt] the bedroom is not adjacent to the living room [layout]", 22 | "[prompt] a bedroom is not adjacent to the living room [layout]", 23 | "[prompt] the bedroom is not adjacent to the kitchen [layout]", 24 | "[prompt] a bedroom is not adjacent to the kitchen [layout]", 25 | "[prompt] the bedroom is not adjacent to the kitchen [layout]", 26 | "[prompt] the kitchen is not adjacent to the bathroom [layout]", 27 | ] 28 | 29 | 30 | default_config = yaml.safe_load(open("configs/ppo_config.yml")) 31 | 32 | 33 | def main(hparams={}): 34 | config = TRLConfig.update(default_config, hparams) 35 | 36 | model = trlx.train( 37 | "architext/gptj-162M", reward_fn=reward_fn, prompts=prompts, config=config 38 | ) 39 | 40 | 41 | if __name__ == "__main__": 42 | main() 43 | -------------------------------------------------------------------------------- /examples/experiments/grounded_program_synthesis/README.md: -------------------------------------------------------------------------------- 1 | # Interpreter Grounded Program Synthesis 2 | *Program synthesis* is the task of automatically generating programs that solve a given task by satisfying an IO condition.
In neural program synthesis, the synthesizer is a language model that takes in an input/output pair and tries to generate a satisfying program in the defined toy DSL's grammar. 3 | 4 | ## Toy List Manipulation DSL Grammar 5 | The DSL has the following grammar: 6 | ``` 7 | list_expr := list[int] 8 | integer := -5 | -4 | -3 | -2 | -1 | 0 | 1 | 2 | 3 | 4 | 5 9 | statement := 10 | | take(list_expr,integer) 11 | | drop(list_expr,integer) 12 | | reverse(list_expr) 13 | | sort_asc(list_expr) 14 | | sort_des(list_expr) 15 | | add_n(list_expr,integer) 16 | | sub_n(list_expr,integer) 17 | | mul_n(list_expr,integer) 18 | | expand_copy(list_expr) 19 | 20 | 21 | ``` 22 | For example, the program `add_n(reverse([-2, -5, -4]),1)` reverses the list and adds one to each element, giving `[-3,-4,-1]`. 23 | More examples are showcased below: 24 | ``` 25 | take([1,2,3],2) -> [1,2] 26 | drop([1,2,3],2) -> [3] 27 | reverse([1,2,3]) -> [3,2,1] 28 | sort_asc([10,5,6]) -> [5,6,10] 29 | sort_des([10,5,6]) -> [10,6,5] 30 | 31 | ``` 32 | To generate training/testing data, run `python3 -m lang`. The dataset will be saved in `./dataset/train.json` and `./dataset/test.json`. To use the preprocessed dataset, refer to this [Google Drive link](https://drive.google.com/drive/folders/1093FlJA0MF7gh25yi4-__yU6Fj-onK1v?usp=share_link). 33 | Each datapoint in the dataset looks like: 34 | ```json 35 | {"input": "Input: [4, -2, 0, 0, 5, 5] Output: [25, 25, 20, 0, 0, -10] Function:", 36 | "output": "sort_des(reverse(mul_n(sort_asc(sort_asc([4, -2, 0, 0, 5, 5])),5)))"} 37 | ``` 38 | ## Caveat on DSL design 39 | The DSL designed here is a very simple toy example in which every function returns type `list`; in a real-world scenario, even list manipulation DSLs would be more complex, with different types such as strings. 40 | ## Training with TRLX 41 | Run `python3 train_trlx.py` to start training with the grounded interpreter. The `reward_fn` returns `-1` if a generated sample has invalid syntax and `0.5` if the syntax is valid but the sample doesn't satisfy the IO condition.
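As a rough illustration, here is a minimal, hypothetical sketch of such an interpreter-grounded reward function (the actual implementation lives in `train_trlx.py` and may differ; the parsing of the `Output:`/`Function:` fields and the `1.0` reward for a full IO match are assumptions made for this sketch):

```python
from lang import Interpreter  # Interpreter evaluates a DSL program string and returns "ERROR" on bad syntax

interpreter = Interpreter()


def reward_fn(samples):
    # Hypothetical sketch: samples are prompt + generation, e.g.
    # "Input: [4, -2] Output: [8, -4] Function: mul_n([4, -2],2)"
    rewards = []
    for sample in samples:
        try:
            io_part, program = sample.split("Function:", 1)
            expected_output = eval(io_part.split("Output:", 1)[1].strip())
        except Exception:
            rewards.append(-1.0)  # sample does not even follow the prompt format
            continue
        result = interpreter(program.strip())
        if result == "ERROR":
            rewards.append(-1.0)  # invalid syntax
        elif result == expected_output:
            rewards.append(1.0)  # assumed top score when the IO condition is satisfied
        else:
            rewards.append(0.5)  # valid syntax, but the output does not match
    return rewards
```

A `reward_fn` of this shape can then be passed to `trlx.train` in the same way as in the other examples (e.g. `examples/architext.py`).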
42 | -------------------------------------------------------------------------------- /examples/experiments/grounded_program_synthesis/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CG80499/trlx-with-T5/f4eae6703eee125a7adf6a291031f5efe76e2ed7/examples/experiments/grounded_program_synthesis/__init__.py -------------------------------------------------------------------------------- /examples/experiments/grounded_program_synthesis/configs/trlx_ppo_config.yml: -------------------------------------------------------------------------------- 1 | train: 2 | seq_length: 256 3 | epochs: 10 4 | total_steps: 80000 5 | batch_size: 8 6 | 7 | checkpoint_interval: 1000000 8 | eval_interval: 16 9 | 10 | pipeline: "PromptPipeline" 11 | orchestrator: "PPOOrchestrator" 12 | 13 | model: 14 | model_type: "AcceleratePPOModel" 15 | model_path: "reshinthadith/codegen_350M_list_manip_5_len" 16 | tokenizer_path: "reshinthadith/codegen_350M_list_manip_5_len" 17 | num_layers_unfrozen: 2 18 | 19 | optimizer: 20 | name: "adamw" 21 | kwargs: 22 | lr: 1.412e-4 23 | betas: [0.9, 0.95] 24 | eps: 1.0e-8 25 | weight_decay: 1.0e-6 26 | 27 | scheduler: 28 | name: "cosine_annealing" 29 | kwargs: 30 | T_max: 80000 # train.total_steps 31 | eta_min: 1.412e-4 32 | 33 | method: 34 | name: "ppoconfig" 35 | num_rollouts: 8 36 | chunk_size: 8 37 | ppo_epochs: 4 38 | init_kl_coef: 0.2 39 | target: 6 40 | horizon: 10000 41 | gamma: 1 42 | lam: 0.95 43 | cliprange: 0.2 44 | cliprange_value: 0.2 45 | vf_coef: 0.2 46 | scale_reward: False 47 | cliprange_reward: 10 48 | ref_mean: null 49 | ref_std: null 50 | gen_kwargs: 51 | max_new_tokens: 256 52 | top_k: 0 53 | top_p: 0.7 54 | do_sample: True 55 | temperature: 0.5 56 | -------------------------------------------------------------------------------- /examples/experiments/grounded_program_synthesis/lang.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import json 3 | import random 4 | from pathlib import Path 5 | from pprint import pprint 6 | 7 | from tqdm import tqdm 8 | from transformers import AutoTokenizer 9 | 10 | 11 | def init_random_input(len_range: int = 5, value_gen=5) -> list: 12 | len_gen = random.randint(2, len_range + 1) 13 | value_range = list(range(-value_gen, value_gen + 1)) 14 | output = [] 15 | for index in range(len_gen): 16 | value_gen = random.choice(value_range) 17 | output.append(value_gen) 18 | return output 19 | 20 | 21 | const_integer = [-5, -4, -3, -2, -1, 1, 2, 3, 4, 5] 22 | 23 | # Functions in the DSL 24 | # Each function defines a transformation in the given DSL Grammar. 
25 | def take(input_list: list, n: int) -> list: 26 | return input_list[:n] 27 | 28 | 29 | def drop(input_list: list, n: int) -> list: 30 | return input_list[n:] 31 | 32 | 33 | def minimum(input_list: list) -> int: 34 | return min(input_list) 35 | 36 | 37 | def maximum(input_list: list) -> int: 38 | return max(input_list) 39 | 40 | 41 | def reverse(input_list: list) -> list: 42 | return input_list[::-1] 43 | 44 | 45 | def sort_asc(input_list: list) -> list: 46 | return sorted(input_list) 47 | 48 | 49 | def sort_des(input_list: list) -> list: 50 | return sorted(input_list, reverse=True) 51 | 52 | 53 | def add_n(input_list: list, n: int) -> list: 54 | return [x + n for x in input_list] 55 | 56 | 57 | def sub_n(input_list: list, n: int) -> list: 58 | return [x - n for x in input_list] 59 | 60 | 61 | def mul_n(input_list: list, n: int) -> list: 62 | return [x * n for x in input_list] 63 | 64 | 65 | def div_n(input_list: list, n: int) -> list: 66 | return [x / n for x in input_list] 67 | 68 | 69 | def expand_copy(input_list: list) -> list: 70 | return input_list + input_list 71 | 72 | 73 | # Main Production Rules for the Toy DSL. 74 | list_manip_dsl = { 75 | "take": take, 76 | "drop": drop, 77 | "reverse": reverse, 78 | "sort_asc": sort_asc, 79 | "sort_des": sort_des, 80 | "add_n": add_n, 81 | "sub_n": sub_n, 82 | "mul_n": mul_n, 83 | "expand_copy": expand_copy, 84 | } 85 | 86 | 87 | # Use this class to execute programs written in the DSL. 88 | class Interpreter: 89 | def __init__(self) -> None: 90 | self.parser = list_manip_dsl 91 | 92 | def __call__(self, statement_string: str): 93 | """ 94 | Evaluation Function for the interpreter. 95 | args: 96 | statement_string (str) : Statement String 97 | """ 98 | try: 99 | return eval(statement_string) # Adding an exception to unparsable strings 100 | except: 101 | return "ERROR" 102 | 103 | 104 | interpreter = Interpreter() 105 | 106 | # TEMPLATE 107 | # This is used to store the input, output and the function template. 108 | # Input : List given as an input to the function. 109 | # function_template : The atomic function in a given DSL Grammar 110 | # Output : Transformed outut by applying function on the input. 111 | generation_template = {"function_template": "NONE", "output": "NONE", "input": []} 112 | 113 | 114 | # Each of the generate function is used to generate a 115 | # template for a given function 116 | # if chosen while sampling the dataset. 117 | # each function takes in expressions based on the grammar and generates a template. 118 | # Example: gen_take() generates a template for the take function. 119 | # take function has two arguments, 120 | # list_expression and a bounded integer(Should not be more 121 | # than the length of the list).. 
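# Illustrative example, with a fixed (hypothetical) input instead of a random one:
#   gen_take([1, 2, 3], 2) returns
#   {"function_template": "take([1, 2, 3],2)", "output": [1, 2], "input": [[1, 2, 3], 2]}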
122 | 123 | 124 | def gen_take(expr1=None, expr2=None): 125 | if expr1 == None: 126 | expr1 = init_random_input() 127 | if expr2 == None: 128 | expr2 = random.choice(range(1, len(expr1) - 1)) 129 | 130 | formatted_fn = f"take({expr1},{expr2})" 131 | template = copy.copy(generation_template) 132 | template["function_template"] = formatted_fn 133 | template["output"] = interpreter(formatted_fn) 134 | template["input"] = [expr1, expr2] 135 | return template 136 | 137 | 138 | def gen_drop(expr1=None, expr2=None): 139 | if expr1 == None: 140 | expr1 = init_random_input() 141 | if expr2 == None: 142 | expr2 = random.choice(range(1, len(expr1) - 1)) 143 | 144 | formatted_fn = f"drop({expr1},{expr2})" 145 | template = copy.copy(generation_template) 146 | template["function_template"] = formatted_fn 147 | template["output"] = interpreter(formatted_fn) 148 | template["input"] = [expr1, expr2] 149 | return template 150 | 151 | 152 | def gen_minimum(expr1=None): 153 | if expr1 == None: 154 | expr1 = init_random_input() 155 | 156 | formatted_fn = f"minimum({expr1})" 157 | template = copy.copy(generation_template) 158 | template["function_template"] = formatted_fn 159 | template["output"] = interpreter(formatted_fn) 160 | template["input"] = [expr1] 161 | return template 162 | 163 | 164 | def gen_maximum(expr1=None): 165 | if expr1 == None: 166 | expr1 = init_random_input() 167 | 168 | formatted_fn = f"maximum({expr1})" 169 | template = copy.copy(generation_template) 170 | template["function_template"] = formatted_fn 171 | template["output"] = interpreter(formatted_fn) 172 | template["input"] = [expr1] 173 | return template 174 | 175 | 176 | def gen_reverse(expr1=None): 177 | if expr1 == None: 178 | expr1 = init_random_input() 179 | 180 | formatted_fn = f"reverse({expr1})" 181 | template = copy.copy(generation_template) 182 | template["function_template"] = formatted_fn 183 | template["output"] = interpreter(formatted_fn) 184 | template["input"] = [expr1] 185 | return template 186 | 187 | 188 | def gen_sort_asc(expr1=None): 189 | if expr1 == None: 190 | expr1 = init_random_input() 191 | 192 | formatted_fn = f"sort_asc({expr1})" 193 | template = copy.copy(generation_template) 194 | template["function_template"] = formatted_fn 195 | template["output"] = interpreter(formatted_fn) 196 | template["input"] = [expr1] 197 | return template 198 | 199 | 200 | def gen_sort_des(expr1=None): 201 | if expr1 == None: 202 | expr1 = init_random_input() 203 | 204 | formatted_fn = f"sort_des({expr1})" 205 | template = copy.copy(generation_template) 206 | template["function_template"] = formatted_fn 207 | template["output"] = interpreter(formatted_fn) 208 | template["input"] = [expr1] 209 | return template 210 | 211 | 212 | def gen_add_n(expr1=None, expr2=None): 213 | if expr1 == None: 214 | expr1 = init_random_input() 215 | if expr2 == None: 216 | expr2 = random.choice(const_integer) 217 | 218 | formatted_fn = f"add_n({expr1},{expr2})" 219 | template = copy.copy(generation_template) 220 | template["function_template"] = formatted_fn 221 | template["output"] = interpreter(formatted_fn) 222 | template["input"] = [expr1, expr2] 223 | return template 224 | 225 | 226 | def gen_sub_n(expr1=None, expr2=None): 227 | if expr1 == None: 228 | expr1 = init_random_input() 229 | if expr2 == None: 230 | expr2 = random.choice(const_integer) 231 | 232 | formatted_fn = f"sub_n({expr1},{expr2})" 233 | template = copy.copy(generation_template) 234 | template["function_template"] = formatted_fn 235 | template["output"] = 
interpreter(formatted_fn) 236 | template["input"] = [expr1, expr2] 237 | return template 238 | 239 | 240 | def gen_mul_n(expr1=None, expr2=None): 241 | if expr1 == None: 242 | expr1 = init_random_input() 243 | if expr2 == None: 244 | expr2 = random.choice(const_integer) 245 | 246 | formatted_fn = f"mul_n({expr1},{expr2})" 247 | template = copy.copy(generation_template) 248 | template["function_template"] = formatted_fn 249 | template["output"] = interpreter(formatted_fn) 250 | template["input"] = [expr1, expr2] 251 | return template 252 | 253 | 254 | def gen_div_n(expr1=None, expr2=None): 255 | if expr1 == None: 256 | expr1 = init_random_input() 257 | if expr2 == None: 258 | expr2 = random.choice(const_integer) 259 | 260 | formatted_fn = f"div_n({expr1},{expr2})" 261 | template = copy.copy(generation_template) 262 | template["function_template"] = formatted_fn 263 | template["output"] = interpreter(formatted_fn) 264 | template["input"] = [expr1, expr2] 265 | return template 266 | 267 | 268 | def gen_expand_copy(expr1=None, expr2=None): 269 | if expr1 == None: 270 | expr1 = init_random_input() 271 | if expr2 == None: 272 | expr2 = random.choice(range(1, 3)) 273 | 274 | formatted_fn = f"expand_copy({expr1},{expr2})" 275 | template = copy.copy(generation_template) 276 | template["function_template"] = formatted_fn 277 | template["output"] = interpreter(formatted_fn) 278 | template["input"] = [expr1, expr2] 279 | return template 280 | 281 | 282 | list_manip_dsl_gen = { 283 | "take": gen_take, 284 | "drop": gen_drop, 285 | "minimum": gen_minimum, 286 | "maximum": gen_maximum, 287 | "reverse": gen_reverse, 288 | "sort_asc": gen_sort_asc, 289 | "sort_des": gen_sort_des, 290 | "add_n": gen_add_n, 291 | "sub_n": gen_sub_n, 292 | "mul_n": gen_mul_n, 293 | "div_n": gen_div_n, 294 | "expand_copy": gen_expand_copy, 295 | } 296 | 297 | 298 | class Sampler: 299 | def __init__( 300 | self, 301 | max_sample_length: int = 5, 302 | code_sep: str = ";", 303 | interpreter_sep: str = "->", 304 | ): 305 | self.max_sample_length = max_sample_length 306 | self.parser = Interpreter() 307 | self.production_list = list_manip_dsl 308 | self.production_idt = [i for i in self.production_list.keys()] 309 | self.production_gen_list = list_manip_dsl_gen 310 | self.code_sep = code_sep 311 | self.interpreter_sep = interpreter_sep 312 | 313 | def sample_production(self, gen_length: int = 5): 314 | init_flag = True 315 | hash_functions = [] 316 | if gen_length == None: 317 | gen_length = self.max_sample_length 318 | 319 | for ind in range(gen_length): 320 | if init_flag: 321 | random_chosen_function = random.choice(self.production_idt) 322 | generated_function = self.production_gen_list[random_chosen_function]() 323 | hash_functions.append(generated_function) 324 | init_flag = False 325 | else: 326 | random_chosen_function = random.choice(self.production_idt) 327 | generated_function = self.production_gen_list[random_chosen_function]( 328 | hash_functions[-1]["function_template"] 329 | ) 330 | if generated_function["output"] == "ERROR": 331 | break 332 | hash_functions.append(generated_function) 333 | 334 | return hash_functions 335 | 336 | 337 | def create_synthetic_dataset(size: int, io_size=3) -> dict: 338 | output_list = [] 339 | sampler = Sampler() 340 | for i in tqdm(range(size)): 341 | try: 342 | sampled = sampler.sample_production() 343 | inp = sampled[0]["input"][0] 344 | out = sampled[-1]["output"] 345 | function = sampled[-1]["function_template"] 346 | prompt_inp = f"Input: {inp} Output: {out} Function:" 347 | 
prompt_out = function 348 | if out != [] and out != "ERROR": 349 | output_list.append( 350 | { 351 | "input": prompt_inp, 352 | "output": prompt_out, 353 | "io_inp": inp, 354 | "io_out": out, 355 | } 356 | ) 357 | except: 358 | pass 359 | 360 | return output_list 361 | 362 | 363 | def write_to_json(data: dict, file_name: str): 364 | with open(file_name, "w") as f: 365 | json.dump(data, f, indent=2) 366 | 367 | 368 | def basic_stats(dataset, tokenizer): 369 | """ 370 | Basic stats to calculate the token length of the dataset. 371 | """ 372 | length_list = [] 373 | for examples in tqdm(dataset): 374 | datapoint = tokenizer( 375 | examples["input"] + " " + examples["output"] + "<|endoftext|>" 376 | ) 377 | length_list.append(len(datapoint["input_ids"])) 378 | return { 379 | "max": max(length_list), 380 | "min": min(length_list), 381 | "mean": sum(length_list) / len(length_list), 382 | } 383 | 384 | 385 | if __name__ == "__main__": 386 | # sampler = Sampler() 387 | # pprint(sampler.sample_production()) 388 | # pprint(interpreter("div_n(reverse([-2, -5, -4]),1)")) 389 | train_data = create_synthetic_dataset(2000000) 390 | test_data = create_synthetic_dataset(2_000) 391 | print(f"Train data size: {len(train_data)}") 392 | print(f"Test data size: {len(test_data)}") 393 | Path("dataset").mkdir(parents=True, exist_ok=True) 394 | write_to_json(train_data, "dataset/train.json") 395 | write_to_json(test_data, "dataset/test.json") 396 | -------------------------------------------------------------------------------- /examples/experiments/grounded_program_synthesis/train_trlx.py: -------------------------------------------------------------------------------- 1 | import trlx 2 | from trlx.data.configs import TRLConfig 3 | from lang import Interpreter 4 | import json 5 | import logging 6 | import yaml 7 | 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class DSLDataset: 13 | def __init__(self): 14 | self.train_data = json.load(open("dataset/train.json", "r")) 15 | self.test_data = json.load(open("dataset/test.json", "r")) 16 | logger.info("Sucessfully loaded the dataset") 17 | 18 | def load_datapoints(self, split="train"): 19 | if split == "train": 20 | for datapoint in self.train_data: 21 | if "ERROR" not in datapoint["input"]: 22 | yield datapoint["input"] 23 | elif split == "test": 24 | for datapoint in self.test_data: 25 | yield datapoint["input"] 26 | 27 | 28 | interpreter = Interpreter() 29 | 30 | 31 | def reward_fn(samples): 32 | reward_list = [] 33 | for sample in samples: 34 | code = sample.split("Function:")[1].strip() 35 | output = eval(sample.split("Output:")[1].strip().split("Function:")[0].strip()) 36 | interpreted_output = interpreter(code) 37 | if interpreted_output == "ERROR": 38 | # If the code is unparsable, we give it a negative reward. 39 | reward_list.append(-1) 40 | else: 41 | # if the code is parseable 42 | if output == interpreted_output: 43 | # if the output is correct, we give it a positive reward. 44 | reward_list.append(1) 45 | else: 46 | # if the output is incorrect, we give it a negative reward. 
47 | reward_list.append(-0.5) 48 | 49 | return reward_list 50 | 51 | 52 | default_config = yaml.safe_load(open("configs/trlx_ppo_config.yml")) 53 | 54 | 55 | def main(hparams={}): 56 | config = TRLConfig.update(default_config, hparams) 57 | 58 | # Dataset 59 | dataset = DSLDataset() 60 | train_prompts = list(dataset.load_datapoints(split="train"))[:1000] 61 | 62 | model = trlx.train( 63 | reward_fn=reward_fn, 64 | prompts=train_prompts, 65 | config=config, 66 | ) 67 | model.save_pretrained("dataset/trained_model") 68 | 69 | 70 | if __name__ == "__main__": 71 | # TEST REWARD FUNTION 72 | assert ( 73 | reward_fn( 74 | ["Input: 1 Output: [-4,-5,-2] Function: div_n(reverse([-2, -5, -4]),1)"] 75 | ) 76 | ) == [1] 77 | assert ( 78 | reward_fn( 79 | ["Input: 1 Output: [-4,-5,-2] Function: div_n(reverse([-2, -5, -a]),1)"] 80 | ) 81 | ) == [-1] 82 | assert ( 83 | reward_fn( 84 | ["Input: 1 Output: [-4,-5,-2] Function: div_n(reverse([-2, -5, -3]),1)"] 85 | ) 86 | ) == [-0.5] 87 | 88 | main() 89 | -------------------------------------------------------------------------------- /examples/ilql_sentiments.py: -------------------------------------------------------------------------------- 1 | from datasets import load_dataset 2 | from transformers import pipeline 3 | 4 | import trlx 5 | import yaml 6 | from typing import List, Dict 7 | import os 8 | from trlx.data.configs import TRLConfig 9 | 10 | 11 | def get_positive_score(scores): 12 | "Extract value associated with a positive sentiment from pipeline's output" 13 | return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] 14 | 15 | 16 | default_config = yaml.safe_load(open("configs/ilql_config.yml")) 17 | 18 | 19 | def main(hparams={}): 20 | config = TRLConfig.update(default_config, hparams) 21 | 22 | sentiment_fn = pipeline( 23 | "sentiment-analysis", 24 | "lvwerra/distilbert-imdb", 25 | top_k=2, 26 | truncation=True, 27 | batch_size=256, 28 | device=0 if int(os.environ.get("LOCAL_RANK", 0)) == 0 else -1, 29 | ) 30 | 31 | def metric_fn(samples: List[str]) -> Dict[str, List[float]]: 32 | sentiments = list(map(get_positive_score, sentiment_fn(samples))) 33 | return {"sentiments": sentiments} 34 | 35 | imdb = load_dataset("imdb", split="train+test") 36 | 37 | trlx.train( 38 | "gpt2", 39 | dataset=(imdb["text"], imdb["label"]), 40 | eval_prompts=["I don't know much about Hungarian underground"] * 64, 41 | metric_fn=metric_fn, 42 | config=config, 43 | ) 44 | 45 | 46 | if __name__ == "__main__": 47 | main() 48 | -------------------------------------------------------------------------------- /examples/ppo_reward_model.py: -------------------------------------------------------------------------------- 1 | # Generates positive movie reviews by tuning a pretrained model on IMDB dataset 2 | # with a sentiment reward function 3 | 4 | from datasets import load_dataset 5 | from transformers import pipeline 6 | import os 7 | import json 8 | 9 | import trlx 10 | import torch 11 | from typing import List 12 | 13 | import requests 14 | 15 | from trlx.data.configs import TRLConfig 16 | 17 | url = 'http://65.108.33.75:5000/rewards' 18 | 19 | def reward_fn(samples: List[str]) -> List[float]: 20 | return requests.post(url, json = {"texts": samples}).json()["rewards"] 21 | 22 | PROMPT = """Question: {query} 23 | 24 | Relevant paper: 25 | Title: {title} 26 | Abstract: {abstract} 27 | 28 | Write a helpful 1-line summary of the paper based on the question. 
29 | 30 | Helpful summary:""" 31 | 32 | with open("/root/fine-tuning-takeaway-models/human_ft_data.json", "r") as f: 33 | data = json.load(f) 34 | 35 | with open("/root/fine-tuning-takeaway-models/human_ft_data_test.json", "r") as f: 36 | data_test = json.load(f) 37 | 38 | prompts = [ 39 | PROMPT.format(query=d["query"], title=d["title"], abstract=d["abstract"][-2200:]) 40 | for d in data 41 | ] 42 | 43 | eval_prompts = [ 44 | PROMPT.format(query=d["query"], title=d["title"], abstract=d["abstract"][-2200:]) 45 | for d in data_test 46 | ] 47 | 48 | def main(): 49 | 50 | config = TRLConfig.load_yaml("configs/ppo_config_t5_old.yml") 51 | 52 | model = trlx.train( 53 | reward_fn=reward_fn, 54 | prompts=prompts, 55 | eval_prompts=eval_prompts, 56 | config=config, 57 | ) 58 | 59 | if __name__ == "__main__": 60 | main() -------------------------------------------------------------------------------- /examples/ppo_sentiments.py: -------------------------------------------------------------------------------- 1 | # Generates positive movie reviews by tuning a pretrained model on IMDB dataset 2 | # with a sentiment reward function 3 | 4 | from datasets import load_dataset 5 | from transformers import pipeline 6 | import os 7 | import yaml 8 | 9 | import trlx 10 | import torch 11 | from typing import List 12 | from trlx.data.configs import TRLConfig 13 | 14 | 15 | def get_positive_score(scores): 16 | "Extract value associated with a positive sentiment from pipeline's output" 17 | return dict(map(lambda x: tuple(x.values()), scores))["POSITIVE"] 18 | 19 | 20 | default_config = yaml.safe_load(open("configs/ppo_config.yml")) 21 | 22 | 23 | def main(hparams={}): 24 | config = TRLConfig.update(default_config, hparams) 25 | 26 | if torch.cuda.is_available(): 27 | device = int(os.environ.get("LOCAL_RANK", 0)) 28 | else: 29 | device = -1 30 | 31 | sentiment_fn = pipeline( 32 | "sentiment-analysis", 33 | "lvwerra/distilbert-imdb", 34 | top_k=2, 35 | truncation=True, 36 | batch_size=256, 37 | device=device, 38 | ) 39 | 40 | def reward_fn(samples: List[str]) -> List[float]: 41 | sentiments = list(map(get_positive_score, sentiment_fn(samples))) 42 | return sentiments 43 | 44 | # Take few words off of movies reviews as prompts 45 | imdb = load_dataset("imdb", split="train+test") 46 | prompts = [" ".join(review.split()[:4]) for review in imdb["text"]] 47 | 48 | model = trlx.train( 49 | reward_fn=reward_fn, 50 | prompts=prompts, 51 | eval_prompts=["I don't know much about Hungarian underground"] * 64, 52 | config=config, 53 | ) 54 | 55 | 56 | if __name__ == "__main__": 57 | main() 58 | -------------------------------------------------------------------------------- /examples/randomwalks/README.md: -------------------------------------------------------------------------------- 1 | Toy problem similar to the one described in [Decision Transformer (Lili Chen et al. 2021)](https://arxiv.org/abs/2106.01345) [1]: 2 | finding graph's shortest paths by learning from a dataset of sampled random 3 | walks. 4 | 5 | In this implementation there are not environment dynamics – impossible and 6 | incorrect paths are penalized the same way by a single reward which is given at 7 | the end of the trajectory, measuring how optimal the path is compared to the 8 | shortest possible (bounded in [0, 1]). Paths are represented as strings of 9 | letters, with each letter corresponding to a node in a graph. PPO example uses a 10 | pretrained model for starting transition probabilities, ILQL learns them from 11 | the samples directly. 
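For example, with the default `max_length=10` used by `metric_fn` in `randomwalks.py`, a sampled walk that reaches the goal in 5 nodes when the shortest path from that start node has 3 nodes receives a reward of (10 - 5) / (10 - 3) ≈ 0.71, while an invalid or non-terminating walk receives 0.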
12 | 13 | [1] code for which is not present in the official repo, see issue 14 | https://github.com/kzl/decision-transformer/issues/48 15 | -------------------------------------------------------------------------------- /examples/randomwalks/__init__.py: -------------------------------------------------------------------------------- 1 | from .randomwalks import generate_random_walks 2 | -------------------------------------------------------------------------------- /examples/randomwalks/configs/ilql_randomwalks.yml: -------------------------------------------------------------------------------- 1 | train: 2 | seq_length: 10 3 | batch_size: 100 4 | epochs: 20 5 | total_steps: 1000 6 | 7 | checkpoint_interval: 100000 8 | eval_interval: 16 9 | 10 | pipeline: "PromptPipeline" 11 | orchestrator: "OfflineOrchestrator" 12 | 13 | seed: 1000 14 | 15 | model: 16 | model_type: "AccelerateILQLModel" 17 | model_path: "CarperAI/randomwalks" 18 | tokenizer_path: "CarperAI/randomwalks" 19 | num_layers_unfrozen: -1 20 | 21 | optimizer: 22 | name: "adamw" 23 | kwargs: 24 | lr: 2.0e-4 25 | betas: [0.9, 0.95] 26 | eps: 1.0e-8 27 | weight_decay: 1.0e-6 28 | 29 | scheduler: 30 | name: "cosine_annealing" 31 | kwargs: 32 | T_max: 1000 # train.total_steps 33 | eta_min: 2.0e-4 34 | 35 | method: 36 | name: "ilqlconfig" 37 | tau: 0.8 38 | gamma: 0.99 39 | cql_scale: 0.1 40 | awac_scale: 1 41 | alpha: 0.1 42 | steps_for_target_q_sync: 5 43 | two_qs: true 44 | gen_kwargs: 45 | max_new_tokens: 9 46 | top_k: 1 47 | beta: 100 48 | temperature: 1.0 49 | -------------------------------------------------------------------------------- /examples/randomwalks/configs/ppo_randomwalks.yml: -------------------------------------------------------------------------------- 1 | train: 2 | seq_length: 10 3 | batch_size: 100 4 | epochs: 20 5 | total_steps: 1000 6 | 7 | checkpoint_interval: 10000 8 | eval_interval: 20 9 | 10 | pipeline: "PromptPipeline" 11 | orchestrator: "PPOOrchestrator" 12 | 13 | model: 14 | model_type: "AcceleratePPOModel" 15 | model_path: "CarperAI/randomwalks" 16 | tokenizer_path: "CarperAI/randomwalks" 17 | num_layers_unfrozen: -1 18 | 19 | optimizer: 20 | name: "adamw" 21 | kwargs: 22 | lr: 3.0e-4 23 | betas: [0.9, 0.95] 24 | eps: 1.0e-8 25 | weight_decay: 1.0e-6 26 | 27 | scheduler: 28 | name: "cosine_annealing" 29 | kwargs: 30 | T_max: 1000 # train.total_steps 31 | eta_min: 3.0e-4 32 | 33 | method: 34 | name: "ppoconfig" 35 | num_rollouts: 128 36 | chunk_size: 128 37 | ppo_epochs: 4 38 | init_kl_coef: 0.05 39 | target: 6 40 | horizon: 10000 41 | gamma: 1 42 | lam: 0.95 43 | cliprange: 0.2 44 | cliprange_value: 0.2 45 | vf_coef: 1.2 46 | scale_reward: False 47 | ref_mean: null 48 | ref_std: null 49 | cliprange_reward: 1 50 | gen_kwargs: 51 | max_new_tokens: 9 52 | top_k: 0.0 53 | top_p: 1.0 54 | do_sample: True 55 | -------------------------------------------------------------------------------- /examples/randomwalks/ilql_randomwalks.py: -------------------------------------------------------------------------------- 1 | from examples.randomwalks import generate_random_walks 2 | 3 | import os 4 | import trlx 5 | from trlx.data.configs import TRLConfig 6 | import yaml 7 | from transformers import GPT2Config 8 | 9 | config_path = os.path.join(os.path.dirname(__file__), "configs/ilql_randomwalks.yml") 10 | default_config = yaml.safe_load(open(config_path)) 11 | 12 | 13 | def main(hparams={}): 14 | config = TRLConfig.update(default_config, hparams) 15 | 16 | metric_fn, eval_prompts, walks, _ = 
generate_random_walks(seed=config.train.seed) 17 | rewards = metric_fn(walks)["optimality"] 18 | 19 | trlx.train( 20 | GPT2Config(n_layer=6, n_embd=144, vocab_size=23), 21 | dataset=(walks, rewards), 22 | eval_prompts=eval_prompts, 23 | metric_fn=metric_fn, 24 | config=config, 25 | ) 26 | 27 | 28 | if __name__ == "__main__": 29 | main() 30 | -------------------------------------------------------------------------------- /examples/randomwalks/ppo_randomwalks.py: -------------------------------------------------------------------------------- 1 | from examples.randomwalks import generate_random_walks 2 | 3 | import yaml 4 | import trlx 5 | from trlx.data.configs import TRLConfig 6 | import os 7 | 8 | config_path = os.path.join(os.path.dirname(__file__), "configs/ppo_randomwalks.yml") 9 | default_config = yaml.safe_load(open(config_path)) 10 | 11 | 12 | def main(hparams={}): 13 | config = TRLConfig.update(default_config, hparams) 14 | 15 | metric_fn, prompts, *_ = generate_random_walks(seed=config.train.seed) 16 | 17 | trlx.train( 18 | "CarperAI/randomwalks", 19 | reward_fn=lambda walks: metric_fn(walks)["optimality"], 20 | prompts=prompts, 21 | eval_prompts=prompts, 22 | metric_fn=metric_fn, 23 | config=config, 24 | ) 25 | 26 | 27 | if __name__ == "__main__": 28 | main() 29 | -------------------------------------------------------------------------------- /examples/randomwalks/randomwalks.py: -------------------------------------------------------------------------------- 1 | import networkx as nx 2 | import numpy as np 3 | import torch 4 | 5 | 6 | def randexclude(rng: np.random.RandomState, n: int, exclude: int) -> int: 7 | while True: 8 | x = rng.randint(n) 9 | if x != exclude: 10 | return x 11 | 12 | 13 | def generate_random_walks( 14 | n_nodes=21, max_length=10, n_walks=1000, p_edge=0.1, seed=1002, gpt2_tokenizer=False 15 | ): 16 | rng = np.random.RandomState(seed) 17 | 18 | while True: 19 | adj = rng.rand(n_nodes, n_nodes) > (1 - p_edge) 20 | np.fill_diagonal(adj, 0) 21 | if np.all(adj.sum(1)): 22 | break 23 | 24 | # terminal state 25 | adj[0, :] = 0 26 | adj[0, 0] = 1 27 | 28 | char_to_node = {chr(ix + ord("a")): ix for ix in range(n_nodes)} 29 | node_to_char = {ix: chr(ix + ord("a")) for ix in range(n_nodes)} 30 | 31 | goal = 0 32 | sample_walks = [] 33 | for _ in range(n_walks): 34 | node = randexclude(rng, n_nodes, goal) 35 | walk = [node] 36 | 37 | for istep in range(max_length - 1): 38 | node = rng.choice(np.nonzero(adj[node])[0]) 39 | walk.append(node) 40 | if node == goal: 41 | break 42 | 43 | # code each node by a letter 44 | # for bpe tokenizer join them over | for a guaranteed split 45 | walk = [node_to_char[ix] for ix in walk] 46 | delimiter = "|" if gpt2_tokenizer else "" 47 | 48 | sample_walks.append(delimiter.join(walk)) 49 | 50 | # calculate the shortest paths for comparison 51 | shortest_lengths = [] 52 | g = nx.from_numpy_array(adj, create_using=nx.DiGraph) 53 | for start in set(range(n_nodes)) - {goal}: 54 | try: 55 | shortest_path = nx.shortest_path(g, start, goal)[:max_length] 56 | shortest_lengths.append(len(shortest_path)) 57 | except Exception: 58 | shortest_lengths.append(max_length) 59 | 60 | shortest_lengths = torch.tensor(shortest_lengths) 61 | 62 | def metric_fn(samples): 63 | # a measure for an invalid or a not found path 64 | infty = 100 65 | lengths = [] 66 | ref_lengths = [] 67 | 68 | for s in samples: 69 | if gpt2_tokenizer: 70 | s = s.replace("|", "") 71 | 72 | s = [char_to_node.get(c, 1000) for c in s] 73 | length = None 74 | for ix in range(len(s)): 
75 | # a nonexisting path is taken 76 | if s[ix] >= n_nodes or ix > 0 and not adj[s[ix - 1], s[ix]]: 77 | length = infty 78 | break 79 | elif s[ix] == 0: 80 | length = ix + 1 81 | break 82 | 83 | if length is None: 84 | length = infty 85 | 86 | lengths.append(length) 87 | # allows for inorder checking of % optimality 88 | ref_lengths.append(shortest_lengths[s[0] - 1]) 89 | 90 | lengths = torch.tensor(lengths, dtype=torch.float) 91 | bound_lengths = torch.where(lengths.eq(infty), max_length, lengths).abs() 92 | ref_lengths = torch.as_tensor(ref_lengths) 93 | 94 | return { 95 | "lengths": lengths, 96 | # percentage-optimal \in (0, 1) when compared to the shortest path 97 | "optimality": (max_length - bound_lengths) / (max_length - ref_lengths), 98 | } 99 | 100 | logit_mask = torch.tensor(adj) 101 | 102 | eval_prompts = list(sorted(set(w[0] for w in sample_walks))) 103 | eval_prompts = [prompt + delimiter for prompt in eval_prompts] 104 | 105 | return metric_fn, eval_prompts, sample_walks, logit_mask 106 | -------------------------------------------------------------------------------- /examples/simulacra.py: -------------------------------------------------------------------------------- 1 | # Optimize prompts by training on prompts-ratings pairings dataset 2 | # taken from https://github.com/JD-P/simulacra-aesthetic-captions 3 | 4 | import sqlite3 5 | 6 | from urllib.request import urlretrieve 7 | import os 8 | 9 | import trlx 10 | 11 | url = "https://raw.githubusercontent.com/JD-P/simulacra-aesthetic-captions/main/sac_public_2022_06_29.sqlite" 12 | dbpath = "sac_public_2022_06_29.sqlite" 13 | 14 | if __name__ == "__main__": 15 | if not os.path.exists(dbpath): 16 | print(f"fetching {dbpath}") 17 | urlretrieve(url, dbpath) 18 | 19 | conn = sqlite3.connect(dbpath) 20 | c = conn.cursor() 21 | c.execute( 22 | "SELECT prompt, rating FROM ratings " 23 | "JOIN images ON images.id=ratings.iid " 24 | "JOIN generations ON images.gid=generations.id " 25 | "WHERE rating IS NOT NULL;" 26 | ) 27 | 28 | prompts, ratings = tuple(map(list, zip(*c.fetchall()))) 29 | model = trlx.train( 30 | "gpt2", 31 | dataset=(prompts, ratings), 32 | eval_prompts=["Hatsune Miku, Red Dress"] * 64, 33 | ) 34 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.isort] 6 | multi_line_output = 3 7 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = trlx 3 | author = Alex Havrilla 4 | version = 0.3.0 5 | url = https://github.com/CarperAI/trlx 6 | description = A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF) 7 | long_description = file: README.md 8 | long_description_content_type = text/markdown 9 | license = MIT 10 | 11 | [options] 12 | packages = find: 13 | install_requires = 14 | accelerate>=0.12.0 15 | datasets 16 | deepspeed>=0.7.3 17 | einops>=0.4.1 18 | numpy>=1.23.2 19 | torchtyping 20 | transformers>=4.21.2 21 | tqdm 22 | wandb 23 | ray>=2.0.1 24 | tabulate>=0.9.0 25 | networkx 26 | 27 | [options.extras_require] 28 | dev = 29 | black 30 | isort 31 | flake8 32 | pre-commit 33 | pytest 34 | pytest-cov 35 | 36 | [options.packages.find] 37 | exclude = 38 | docs* 39 | tests* 40 | 
-------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup() 4 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CG80499/trlx-with-T5/f4eae6703eee125a7adf6a291031f5efe76e2ed7/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_configs.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from trlx.data.configs import TRLConfig 4 | from typing import List 5 | 6 | 7 | def _get_config_dirs(dir: str, config_dir_name: str = "configs") -> List[str]: 8 | """Returns all sub-directories of `dir` named `configs`.""" 9 | config_dirs = [] 10 | for root, dirs, _ in os.walk(dir): 11 | for d in dirs: 12 | if d == config_dir_name: 13 | config_dirs.append(os.path.join(root, d)) 14 | return config_dirs 15 | 16 | 17 | def _get_yaml_filepaths(dir: str) -> List[str]: 18 | """Returns a list of `yml` filepaths in `dir`.""" 19 | filepaths = [] 20 | for file in os.listdir(dir): 21 | if file.endswith(".yml"): 22 | filepaths.append(os.path.join(dir, file)) 23 | return filepaths 24 | 25 | 26 | def test_repo_trl_configs(): 27 | """Tests to ensure all default configs in the repository are valid.""" 28 | config_dirs = ["configs", *_get_config_dirs("examples")] 29 | config_files = sum(map(_get_yaml_filepaths, config_dirs), []) # sum for flat-map behavior 30 | for file in config_files: 31 | assert os.path.isfile(file), f"Config file {file} does not exist." 32 | assert file.endswith(".yml"), f"Config file {file} is not a yaml file." 33 | try: 34 | TRLConfig.load_yaml(file) 35 | except Exception as e: 36 | assert False, f"Failed to load config file `{file}` with error `{e}`" 37 | -------------------------------------------------------------------------------- /tests/test_ppo.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | from trlx.data.configs import TRLConfig 3 | from trlx.model.nn.ppo_models import CausalLMHydraWithValueHead 4 | from trlx.utils.modeling import RunningMoments 5 | from transformers import AutoTokenizer 6 | import torch 7 | 8 | 9 | # Note tests must start with "test_" 10 | class TestHydraHead(unittest.TestCase): 11 | @classmethod 12 | def setUpClass(cls): 13 | print("Testing Hydra model...") 14 | config = TRLConfig.load_yaml("configs/test_config.yml") 15 | cls.hydra_model = CausalLMHydraWithValueHead( 16 | config.model.model_path, config.model.num_layers_unfrozen 17 | ) 18 | 19 | tokenizer = AutoTokenizer.from_pretrained(config.model.tokenizer_path) 20 | tokenizer.pad_token = tokenizer.eos_token 21 | tokenizer.padding_side = "left" 22 | 23 | cls.dummy_inputs = tokenizer("Once upon a time there was a happy goose named Louis. 
He liked to eat bananas.", truncation=True, padding="max_length", max_length=4, return_tensors="pt") 24 | 25 | def test_lm_heads(self): 26 | with torch.no_grad(): 27 | unfrozen_outputs = TestHydraHead.hydra_model(**TestHydraHead.dummy_inputs, return_dict=True, output_hidden_states=True) 28 | unfrozen_logits = unfrozen_outputs.logits 29 | last_hidden_states = unfrozen_outputs.hidden_states[-1].to(torch.float32) 30 | frozen_logits = TestHydraHead.hydra_model.frozen_head.lm_head(last_hidden_states) 31 | diff = torch.sum(unfrozen_logits - frozen_logits).item() 32 | self.assertEqual(diff, 0) 33 | 34 | def test_forward(self): 35 | with torch.no_grad(): 36 | unfrozen_outputs = TestHydraHead.hydra_model(**TestHydraHead.dummy_inputs, return_dict=True, output_hidden_states=True) 37 | unfrozen_last_hidden_states = unfrozen_outputs.hidden_states[-1] 38 | unfrozen_logits = unfrozen_outputs.logits 39 | 40 | frozen_outputs = TestHydraHead.hydra_model.forward_hydra(**TestHydraHead.dummy_inputs, return_dict=True, output_hidden_states=True) 41 | frozen_last_hidden_states = frozen_outputs.hidden_states[-1] 42 | frozen_logits = frozen_outputs.logits 43 | 44 | hs_diff = torch.sum(unfrozen_last_hidden_states - frozen_last_hidden_states).item() 45 | logits_diff = torch.sum(unfrozen_logits - frozen_logits).item() 46 | self.assertEqual(hs_diff, 0) 47 | self.assertEqual(logits_diff, 0) 48 | 49 | class TestStatistics(unittest.TestCase): 50 | @classmethod 51 | def setUpClass(cls): 52 | cls.m = RunningMoments() 53 | cls.a1 = torch.arange(100, dtype=float) 54 | cls.a2 = torch.ones(100, dtype=float) 55 | cls.a3 = torch.exp(torch.arange(10, dtype=float)) 56 | cls.a4 = torch.tensor([-10, -1, 0, 1, 10], dtype=float) 57 | 58 | def test_running_moments(self): 59 | assert torch.isclose(self.m.update(self.a1)[1], self.a1.std(unbiased=True), atol=1e-6) 60 | assert torch.isclose(self.m.update(self.a2)[1], self.a2.std(unbiased=True), atol=1e-6) 61 | assert torch.isclose(self.m.update(self.a3)[1], self.a3.std(unbiased=True), atol=1e-6) 62 | assert torch.isclose(self.m.update(self.a4)[1], self.a4.std(unbiased=True), atol=1e-6) 63 | 64 | a = torch.hstack((self.a1, self.a2, self.a3, self.a4)) 65 | assert torch.isclose(self.m.mean, a.mean(), atol=1e-6) 66 | assert torch.isclose(self.m.std, a.std(unbiased=True), atol=1e-6) 67 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | import transformers 4 | 5 | import accelerate 6 | import trlx.utils as utils 7 | import trlx.utils.modeling as modeling_utils 8 | 9 | # Test general utils 10 | 11 | 12 | @pytest.mark.parametrize( 13 | "optimizer_name", 14 | [o.value for o in utils.OptimizerNames], 15 | ) 16 | def test_optimizer_class_getters(optimizer_name: str): 17 | try: 18 | _class = utils.get_optimizer_class(optimizer_name) 19 | except Exception as e: 20 | assert False, "Failed to get optimizer class with error: " + str(e) 21 | 22 | # Hard-check for one of the optimizers 23 | _class = utils.get_optimizer_class("adamw") 24 | assert _class == torch.optim.AdamW 25 | 26 | 27 | @pytest.mark.parametrize( 28 | "scheduler_name", 29 | [o.value for o in utils.SchedulerNames], 30 | ) 31 | def test_scheduler_class_getters(scheduler_name: str): 32 | try: 33 | _class = utils.get_scheduler_class(scheduler_name) 34 | except Exception as e: 35 | assert False, "Failed to get scheduler class with error: " + str(e) 36 | 37 | # Hard-check for 
one of the schedulers 38 | _class = utils.get_scheduler_class("cosine_annealing") 39 | assert _class == torch.optim.lr_scheduler.CosineAnnealingLR 40 | 41 | 42 | # Test modeling utils 43 | 44 | 45 | @pytest.mark.parametrize( 46 | "model_name", 47 | [ 48 | "EleutherAI/gpt-j-6B", 49 | "EleutherAI/gpt-neox-20b", 50 | "gpt2", 51 | "facebook/opt-1.3b", 52 | ], 53 | ) 54 | def test_hf_attr_getters(model_name: str): 55 | with accelerate.init_empty_weights(): 56 | config = transformers.AutoConfig.from_pretrained(model_name) 57 | arch = transformers.AutoModelForCausalLM.from_config(config) 58 | 59 | arch_getters = [ 60 | modeling_utils.hf_get_causal_base_model, 61 | modeling_utils.hf_get_causal_final_norm, 62 | modeling_utils.hf_get_causal_hidden_layers, 63 | modeling_utils.hf_get_lm_head, 64 | ] 65 | for get in arch_getters: 66 | try: 67 | get(arch) 68 | except Exception as e: 69 | assert False, "Failed to get model attribute with error: " + str(e) 70 | 71 | config_getters = [ 72 | modeling_utils.hf_get_hidden_size, 73 | modeling_utils.hf_get_num_hidden_layers, 74 | ] 75 | for get in config_getters: 76 | try: 77 | get(config) 78 | except Exception as e: 79 | assert False, "Failed to get config attribute with error: " + str(e) 80 | -------------------------------------------------------------------------------- /trlx/__init__.py: -------------------------------------------------------------------------------- 1 | from .trlx import train 2 | -------------------------------------------------------------------------------- /trlx/data/__init__.py: -------------------------------------------------------------------------------- 1 | import random 2 | from dataclasses import dataclass 3 | from typing import Any, Callable, Iterable, List 4 | 5 | from torchtyping import TensorType 6 | 7 | 8 | @dataclass 9 | class GeneralElement: 10 | """ 11 | General element outputted by data pipeline being read by orchestrator. 12 | """ 13 | 14 | pass 15 | 16 | 17 | @dataclass 18 | class SimElement: 19 | """ 20 | Batch element for Gyarados or Gyarados-like similarity scoring model 21 | """ 22 | 23 | content: Any = None 24 | preference: Any = None 25 | score: float = None 26 | 27 | 28 | @dataclass 29 | class RLElement: 30 | """ 31 | Batch element for RL model 32 | """ 33 | 34 | state: Iterable[str] = None # Context/prompts 35 | action: TensorType["N"] = None # Tokens generated by model given prompts 36 | reward: float = None # Reward obtained for that generation 37 | 38 | 39 | @dataclass 40 | class BatchElement: 41 | """ 42 | General batch element for any transformer to use in its forward pass 43 | """ 44 | 45 | tokens: TensorType["BATCH", "SEQ_LEN"] 46 | masks: TensorType["BATCH", "SEQ_LEN"] 47 | -------------------------------------------------------------------------------- /trlx/data/accelerate_base_datatypes.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Iterable 3 | 4 | from torchtyping import TensorType 5 | 6 | 7 | @dataclass 8 | class PromptElement: 9 | """ 10 | Dataclass for a single prompt, containing its string and tokenized form. 11 | 12 | :param text: The prompt text. 13 | :type text: str 14 | 15 | :param tokens: The prompt tokens. Should be a long tensor 16 | :type tokens: torch.Tensor 17 | """ 18 | 19 | text: str 20 | tokens: TensorType["num_tokens"] 21 | 22 | 23 | @dataclass 24 | class PromptBatch: 25 | """ 26 | Batched PromptElement 27 | 28 | :param text: An iterable of prompt texts. 
29 | :type text: Iterable[str] 30 | 31 | :param tokens: A long tensor batch of prompt tokens. 32 | :type tokens: torch.Tensor 33 | """ 34 | 35 | text: Iterable[str] 36 | tokens: TensorType["batch_size", "num_tokens"] 37 | 38 | 39 | @dataclass 40 | class AccelerateRLElement: 41 | """ 42 | Dataclass for RL elements, containing output tokens and rewards for each token. 43 | 44 | :param tokens: The output tokens. Should be a long tensor 45 | :type tokens: torch.Tensor 46 | 47 | :param rewards: The rewards for each token. Should be a float tensor of same size as tokens. 48 | :type rewards: torch.Tensor 49 | """ 50 | 51 | output_tokens: TensorType["output_size"] 52 | rewards: TensorType["output_size"] 53 | 54 | 55 | @dataclass 56 | class AccelerateRLBatchElement: 57 | """ 58 | Batched accelerate RL element 59 | 60 | :param tokens: Batches of long tensors of output tokens. 61 | :type tokens: torch.Tensor 62 | 63 | :param rewards: Batches of float tensors of rewards for each output token. 64 | :type rewards: torch.Tensor 65 | """ 66 | 67 | output_tokens: TensorType["batch_size", "output_size"] 68 | rewards: TensorType["batch_size", "output_size"] 69 | -------------------------------------------------------------------------------- /trlx/data/configs.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Any, Dict, Optional, Tuple, Set 3 | 4 | import yaml 5 | 6 | from trlx.data.method_configs import MethodConfig, get_method 7 | import os 8 | 9 | 10 | def merge(base: Dict, update: Dict, updated: Set) -> Dict: 11 | "Recursively updates a nested dictionary with new values" 12 | for k, v in base.items(): 13 | if isinstance(v, dict): 14 | base[k] = merge(v, update, updated) 15 | elif k in update: 16 | base[k] = update[k] 17 | updated.add(k) 18 | 19 | return base 20 | 21 | 22 | @dataclass 23 | class ModelConfig: 24 | """ 25 | Config for a model. 26 | 27 | :param model_type: One of the registered RL models present in trlx.model 28 | :type model_type: str 29 | 30 | :param model_path: Path or name of the model (local or on huggingface hub) 31 | :type model_path: str 32 | 33 | :param tokenizer_path: Path or name of the tokenizer (local or on huggingface hub) 34 | :type tokenizer_path: str 35 | 36 | :param num_layers_unfrozen: Number of layers to unfreeze for fine-tuning. 37 | -1 means all layers are unfrozen. 38 | :type num_layers_unfrozen: int 39 | """ 40 | 41 | model_type: str 42 | model_path: str 43 | tokenizer_path: str 44 | num_layers_unfrozen: int = -1 45 | 46 | @classmethod 47 | def from_dict(cls, config: Dict[str, Any]): 48 | return cls(**config) 49 | 50 | 51 | @dataclass 52 | class OptimizerConfig: 53 | """ 54 | Config for an optimizer. 55 | 56 | :param name: Name of the optimizer 57 | :type name: str 58 | 59 | :param kwargs: Keyword arguments for the optimizer (e.g. lr, betas, eps, weight_decay) 60 | :type kwargs: Dict[str, Any] 61 | """ 62 | 63 | name: str 64 | kwargs: Dict[str, Any] = None 65 | 66 | @classmethod 67 | def from_dict(cls, config: Dict[str, Any]): 68 | return cls(**config) 69 | 70 | 71 | @dataclass 72 | class SchedulerConfig: 73 | """ 74 | Config for a learning rate scheduler. 75 | 76 | :param name: Name of the scheduler 77 | :type name: str 78 | 79 | :param kwargs: Keyword arguments for the scheduler instance (e.g. 
warmup_steps, T_max) 80 | :type kwargs: Dict[str, Any] 81 | """ 82 | 83 | name: str 84 | kwargs: Dict[str, Any] = None 85 | 86 | @classmethod 87 | def from_dict(cls, config: Dict[str, Any]): 88 | return cls(**config) 89 | 90 | 91 | @dataclass 92 | class TrainConfig: 93 | """ 94 | Config for train job on model. 95 | 96 | :param total_steps: Total number of training steps 97 | :type total_steps: int 98 | 99 | :param seq_length: Number of tokens to use as context (max length for tokenizer) 100 | :type seq_length: int 101 | 102 | :param epochs: Total number of passes through data 103 | :type epochs: int 104 | 105 | :param batch_size: Batch size for training 106 | :type batch_size: int 107 | 108 | :param checkpoint_interval: Save model every checkpoint_interval steps 109 | :type checkpoint_interval: int 110 | 111 | :param eval_interval: Evaluate model every eval_interval steps 112 | :type eval_interval: int 113 | 114 | :param pipeline: Pipeline to use for training. One of the registered pipelines present in trlx.pipeline 115 | :type pipeline: str 116 | 117 | :param orchestrator: Orchestrator to use for training. One of the registered orchestrators present in trlx.orchestrator 118 | :type orchestrator: str 119 | 120 | :param project_name: Project name for wandb 121 | :type project_name: str 122 | 123 | :param entity_name: Entity name for wandb 124 | :type entity_name: str 125 | 126 | :param checkpoint_dir: Directory to save checkpoints 127 | :type checkpoint_dir: str 128 | 129 | :param rollout_logging_dir: Directory to store generated rollouts for use in Algorithm Distillation. Only used by AcceleratePPOModel. 130 | :type rollout_logging_dir: Optional[str] 131 | 132 | :param seed: Random seed 133 | :type seed: int 134 | """ 135 | 136 | total_steps: int 137 | seq_length: int 138 | epochs: int 139 | batch_size: int 140 | 141 | checkpoint_interval: int 142 | eval_interval: int 143 | 144 | pipeline: str # One of the pipelines in framework.pipeline 145 | orchestrator: str # One of the orchestrators 146 | 147 | project_name: str = "trlx" 148 | entity_name: Optional[str] = None 149 | 150 | checkpoint_dir: str = "ckpts" 151 | rollout_logging_dir: Optional[str] = None 152 | 153 | seed: int = 1000 154 | 155 | @classmethod 156 | def from_dict(cls, config: Dict[str, Any]): 157 | return cls(**config) 158 | 159 | 160 | @dataclass 161 | class TRLConfig: 162 | """ 163 | Top level config for trlX. Loads configs and can be converted to dictionary. 164 | """ 165 | 166 | method: MethodConfig 167 | model: ModelConfig 168 | optimizer: OptimizerConfig 169 | scheduler: SchedulerConfig 170 | train: TrainConfig 171 | 172 | @classmethod 173 | def load_yaml(cls, yml_fp: str): 174 | """ 175 | Load yaml file as TRLConfig. 176 | 177 | :param yml_fp: Path to yaml file 178 | :type yml_fp: str 179 | """ 180 | with open(yml_fp, mode="r") as file: 181 | config = yaml.safe_load(file) 182 | return cls.from_dict(config) 183 | 184 | def to_dict(self): 185 | """ 186 | Convert TRLConfig to dictionary. 187 | """ 188 | data = { 189 | "method": self.method.__dict__, 190 | "model": self.model.__dict__, 191 | "optimizer": self.optimizer.__dict__, 192 | "scheduler": self.scheduler.__dict__, 193 | "train": self.train.__dict__, 194 | } 195 | 196 | return data 197 | 198 | @classmethod 199 | def from_dict(cls, config: Dict): 200 | """ 201 | Convert dictionary to TRLConfig. 
202 | """ 203 | return cls( 204 | method=get_method(config["method"]["name"]).from_dict(config["method"]), 205 | model=ModelConfig.from_dict(config["model"]), 206 | optimizer=OptimizerConfig.from_dict(config["optimizer"]), 207 | scheduler=SchedulerConfig.from_dict(config["scheduler"]), 208 | train=TrainConfig.from_dict(config["train"]), 209 | ) 210 | 211 | @classmethod 212 | def update(cls, baseconfig: Dict, config: Dict): 213 | updates = set() 214 | merged = merge(baseconfig, config, updates) 215 | 216 | for param in config: 217 | if param not in updates: 218 | raise ValueError( 219 | f"parameter {param} is not present in the config (typo or a wrong config)" 220 | ) 221 | 222 | return cls.from_dict(merged) 223 | 224 | def __str__(self): 225 | """Returns a human-readable string representation of the config.""" 226 | import json 227 | 228 | return json.dumps(self.to_dict(), indent=4) 229 | -------------------------------------------------------------------------------- /trlx/data/ilql_types.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from torchtyping import TensorType # type: ignore 4 | 5 | 6 | @dataclass 7 | class ILQLElement: 8 | """ 9 | Data element for ILQL 10 | 11 | :param input_ids: Input tokens. Should be a long tensor. 12 | :type input_ids: torch.Tensor 13 | 14 | :param attention_mask: Attention mask. Should be a long tensor. 15 | :type attention_mask: torch.Tensor 16 | 17 | :param rewards: Rewards for each token. Should be a float tensor of same size as tokens. 18 | :type rewards: torch.Tensor 19 | """ 20 | 21 | input_ids: TensorType["query_size"] 22 | attention_mask: TensorType["query_size"] 23 | rewards: TensorType["reward_size"] 24 | states_ixs: TensorType["states_size"] 25 | actions_ixs: TensorType["reward_size"] 26 | dones: TensorType["states_size"] 27 | 28 | 29 | @dataclass 30 | class ILQLBatch: 31 | """ 32 | Batched ILQL data elements 33 | 34 | :param input_ids: Batch of input tokens. 35 | :type input_ids: torch.Tensor 36 | 37 | :param attention_mask: Batch of attention masks. 38 | :type attention_mask: torch.Tensor 39 | 40 | :param rewards: Batch of rewards for each token in each token batch. 
41 | :type rewards: torch.Tensor 42 | """ 43 | 44 | input_ids: TensorType["batch_size", "query_size"] 45 | attention_mask: TensorType["batch_size", "query_size"] 46 | rewards: TensorType["batch_size", "reward_size"] 47 | states_ixs: TensorType["batch_size", "states_size"] 48 | actions_ixs: TensorType["batch_size", "reward_size"] 49 | dones: TensorType["batch_size", "states_size"] 50 | -------------------------------------------------------------------------------- /trlx/data/method_configs.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from dataclasses import dataclass 3 | from typing import Any, Callable, Dict, List 4 | 5 | # specifies a dictionary of method configs 6 | _METHODS: Dict[str, Any] = {} # registry 7 | 8 | 9 | def register_method(name): 10 | """Decorator used register a method config 11 | Args: 12 | name: Name of the method 13 | """ 14 | 15 | def register_class(cls, name): 16 | _METHODS[name] = cls 17 | setattr(sys.modules[__name__], name, cls) 18 | return cls 19 | 20 | if isinstance(name, str): 21 | name = name.lower() 22 | return lambda c: register_class(c, name) 23 | 24 | cls = name 25 | name = cls.__name__ 26 | register_class(cls, name.lower()) 27 | 28 | return cls 29 | 30 | 31 | @dataclass 32 | @register_method 33 | class MethodConfig: 34 | """ 35 | Config for a certain RL method. 36 | 37 | :param name: Name of the method 38 | :type name: str 39 | """ 40 | 41 | name: str 42 | 43 | @classmethod 44 | def from_dict(cls, config: Dict[str, Any]): 45 | return cls(**config) 46 | 47 | 48 | def get_method(name: str) -> MethodConfig: 49 | """ 50 | Return constructor for specified method config 51 | """ 52 | name = name.lower() 53 | if name in _METHODS: 54 | return _METHODS[name] 55 | else: 56 | raise Exception("Error: Trying to access a method that has not been registered") 57 | -------------------------------------------------------------------------------- /trlx/data/ppo_types.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from torchtyping import TensorType 4 | 5 | 6 | @dataclass 7 | class PPORLElement: 8 | """ 9 | :param query_tensor: The query tensor i.e. the prompt tokens. Should be a long tensor. 10 | :type query_tensor: torch.Tensor 11 | 12 | :param response_tensor: The response tensor i.e. the output tokens. Should be a long tensor. 13 | :type response_tensor: torch.Tensor 14 | 15 | :param logprobs: The log probabilities over all tokens in the vocabulary for each token generated from the policy network (i.e. the autoregressive model). Should be a float tensor of same size as tokens, with a dimension across the vocabulary. 16 | :type logprobs: torch.Tensor 17 | 18 | :param values: The values for each token generated from the value network or value head. Should be a float tensor of same size as tokens. 19 | :type values: torch.Tensor 20 | 21 | :param rewards: The rewards for each token outputted in response. Should be a float tensor of same size as tokens. 22 | :type rewards: torch.Tensor 23 | """ 24 | 25 | query_tensor: TensorType["query_size"] 26 | response_tensor: TensorType["response_size"] 27 | logprobs: TensorType["response_size", "vocab_size"] 28 | values: TensorType["response_size"] 29 | rewards: TensorType["response_size"] 30 | 31 | 32 | @dataclass 33 | class PPORLBatch: 34 | """ 35 | A batched version of the PPORLElement. See PPORLElement for more details on individual fields. 
36 | 37 | :param query_tensors: A batch of query tensors. Should be a long tensor. 38 | :type query_tensors: torch.Tensor 39 | 40 | :param response_tensors: A batch of response tensors. Should be a long tensor. 41 | :type response_tensors: torch.Tensor 42 | 43 | :param logprobs: A batch of log probabilities from policy 44 | :type logprobs: torch.Tensor 45 | 46 | :param values: A batch of values from value network 47 | :type values: torch.Tensor 48 | 49 | :param rewards: A batch of rewards 50 | :type rewards: torch.Tensor 51 | """ 52 | 53 | query_tensors: TensorType["batch_size", "query_size"] 54 | response_tensors: TensorType["batch_size", "response_size"] 55 | logprobs: TensorType["batch_size", "response_size", "vocab_size"] 56 | values: TensorType["batch_size", "response_size"] 57 | rewards: TensorType["batch_size", "response_size"] 58 | -------------------------------------------------------------------------------- /trlx/model/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | from abc import abstractmethod 4 | from typing import Any, Callable, Dict, Iterable 5 | 6 | import torch 7 | 8 | from trlx.data import RLElement 9 | from trlx.data.configs import TRLConfig 10 | from trlx.pipeline import BaseRolloutStore 11 | from trlx.utils import safe_mkdir 12 | 13 | # specifies a dictionary of architectures 14 | _MODELS: Dict[str, Any] = {} # registry 15 | 16 | 17 | def register_model(name): 18 | """Decorator used register an architecture 19 | Args: 20 | name: Name of the architecture 21 | """ 22 | 23 | def register_class(cls, name): 24 | _MODELS[name] = cls 25 | setattr(sys.modules[__name__], name, cls) 26 | return cls 27 | 28 | if isinstance(name, str): 29 | name = name.lower() 30 | return lambda c: register_class(c, name) 31 | 32 | cls = name 33 | name = cls.__name__ 34 | register_class(cls, name.lower()) 35 | 36 | return cls 37 | 38 | 39 | @register_model 40 | class BaseRLModel: 41 | def __init__(self, config: TRLConfig, train_mode=False): 42 | self.store: BaseRolloutStore = None 43 | self.config = config 44 | self.train_mode = train_mode 45 | 46 | def push_to_store(self, data): 47 | self.store.push(data) 48 | 49 | def add_eval_pipeline(self, eval_pipeline): 50 | """Adds pipeline from with validation prompts""" 51 | self.eval_pipeline = eval_pipeline 52 | 53 | @abstractmethod 54 | def act(self, data: RLElement) -> RLElement: 55 | """ 56 | Given RLElement with state, produce an action and add it to the RLElement. 57 | Orchestrator should call this, get reward and push subsequent RLElement to RolloutStore 58 | """ 59 | pass 60 | 61 | @abstractmethod 62 | def sample( 63 | self, prompts: Iterable[str], length: int, n_samples: int 64 | ) -> Iterable[str]: 65 | """ 66 | Sample from the language. Takes prompts and maximum length to generate. 
67 | 68 | :param prompts: List of prompts to tokenize and use as context 69 | 70 | :param length: How many new tokens to genrate for each prompt 71 | :type length: int 72 | 73 | :param n_samples: Default behavior is to take number of prompts as this 74 | """ 75 | pass 76 | 77 | @abstractmethod 78 | def learn( 79 | self, 80 | log_fn: Callable = None, 81 | save_fn: Callable = None, 82 | eval_fn: Callable = None, 83 | ): 84 | """ 85 | Use experiences in RolloutStore to learn 86 | 87 | :param log_fn: Optional function that is called when logging and passed a dict of logging relevant values 88 | :type log_fn: Callable[Dict[str, any]] 89 | 90 | :param save_fn: Optional function to call after saving. Is passed the components. 91 | :type save_fn: Callable[Dict[str, any]] 92 | 93 | :param eval_fn: Optional function to call during evaluation. Eval doesn't do anything without this. 94 | :type eval_fn: Callable[BaseRLModel] 95 | """ 96 | pass 97 | 98 | @abstractmethod 99 | def save(self, directory=None): 100 | """Creates a checkpoint of training states""" 101 | pass 102 | 103 | @abstractmethod 104 | def load(self, directory=None): 105 | """Loads a checkpoint created from `save`""" 106 | pass 107 | 108 | def intervals(self, steps: int) -> Dict[str, bool]: 109 | """ 110 | Using config and current step number, returns a dict of whether certain things should be done 111 | """ 112 | 113 | return { 114 | "do_log": (steps + 1) % self.config.train.log_interval == 0, 115 | "do_eval": (steps + 1) % self.config.train.eval_interval == 0, 116 | "do_save": (steps + 1) % self.config.train.checkpoint_interval == 0, 117 | } 118 | -------------------------------------------------------------------------------- /trlx/model/accelerate_base_model.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import json 3 | import os 4 | import sys 5 | from abc import abstractmethod 6 | from time import time 7 | from typing import Any, Dict, Iterable, Sequence, Tuple, Union 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from transformers import AutoTokenizer 12 | 13 | import wandb 14 | from accelerate import Accelerator # type: ignore 15 | 16 | if importlib.util.find_spec("rich") is not None: 17 | from tqdm.rich import tqdm 18 | else: 19 | from tqdm import tqdm 20 | 21 | import ray 22 | from ray.air import session 23 | from ray.air.checkpoint import Checkpoint 24 | 25 | from trlx.data.configs import TRLConfig 26 | from trlx.model import BaseRLModel, register_model 27 | from trlx.utils import ( 28 | filter_non_scalars, 29 | get_optimizer_class, 30 | get_scheduler_class, 31 | get_distributed_config, 32 | get_git_tag, 33 | ) 34 | from trlx.utils.modeling import freeze_bottom_causal_layers 35 | 36 | 37 | @register_model 38 | class AccelerateRLModel(BaseRLModel): 39 | """ 40 | RL Model that uses accelerate for training 41 | """ 42 | 43 | def __init__(self, config, train_mode=True): 44 | super().__init__(config, train_mode) 45 | 46 | self.accelerator = Accelerator(log_with="wandb") 47 | 48 | if int(os.environ.get("WORLD_SIZE", 1)) > 1: 49 | torch.distributed.barrier(device_ids=[int(os.environ.get("LOCAL_RANK", 0))]) 50 | 51 | self.max_length = config.train.seq_length 52 | 53 | # Retrieves model equipped for ppo, ilql, etc 54 | self.model = self.get_arch(self.config) 55 | 56 | if not "t5" in config.model.model_path.lower(): 57 | freeze_bottom_causal_layers( 58 | self.model.base_model, self.config.model.num_layers_unfrozen 59 | ) 60 | 61 | if 
config.model.tokenizer_path: 62 | self.tokenizer = AutoTokenizer.from_pretrained(config.model.tokenizer_path) 63 | self.tokenizer.pad_token = self.tokenizer.eos_token 64 | self.tokenizer.padding_side = "left" 65 | else: 66 | self.tokenizer = None 67 | 68 | script_name = os.path.basename(sys.argv[0]).rsplit(".", 1)[0] 69 | if not isinstance(config.model.model_path, str): 70 | model_name = str(config.model.model_path).split()[0] 71 | else: 72 | model_name = config.model.model_path.split("/")[-1] 73 | run_name = f"{script_name}/{model_name}" 74 | 75 | if self.accelerator.is_main_process and not ray.is_initialized(): 76 | config_dict = self.config.to_dict() 77 | dist_config = get_distributed_config(self.accelerator) 78 | config_dict["distributed"] = dist_config 79 | self.accelerator.init_trackers( 80 | project_name=self.config.train.project_name, 81 | config=config_dict, 82 | init_kwargs={ 83 | "wandb": { 84 | "name": run_name, 85 | "entity": self.config.train.entity_name, 86 | "tags": [get_git_tag()], 87 | "mode": "disabled" 88 | if os.environ.get("debug", False) 89 | else "online", 90 | } 91 | }, 92 | ) 93 | 94 | self.opt = get_optimizer_class(config.optimizer.name)( 95 | self.model.parameters(), 96 | **config.optimizer.kwargs, 97 | ) 98 | 99 | self.scheduler = get_scheduler_class(config.scheduler.name)( 100 | self.opt, 101 | **config.scheduler.kwargs, 102 | ) 103 | self.best_mean_reward = -float("inf") 104 | 105 | def tokenize(self, text: Union[Sequence[str], Sequence[torch.LongTensor]]): 106 | """ 107 | Tokenize a batch of text after adding bos token to each of the samples 108 | """ 109 | if isinstance(text[0], torch.LongTensor): 110 | return text 111 | 112 | text = [self.tokenizer.bos_token + txt for txt in text] 113 | return self.tokenizer( 114 | text, 115 | truncation=True, 116 | max_length=self.config.seq_length, 117 | return_tensors="pt", 118 | # NOTE: We manually add special tokens (bos) above so we set this False 119 | # to avoid models that automatically add special tokens (e.g. OPT) 120 | # adding them twice more. 
121 | add_special_tokens=False, 122 | ) 123 | 124 | def generate(self, input_ids, attention_mask=None, **kwargs): 125 | """Wraps hf's `generate` adding some specific method's defaults""" 126 | input_ids = input_ids.to(self.accelerator.device) 127 | if attention_mask is not None: 128 | attention_mask = attention_mask.to(self.accelerator.device) 129 | 130 | kwargs = dict(self.generate_kwargs, **kwargs) 131 | 132 | with torch.no_grad(): 133 | return self.accelerator.unwrap_model(self.model).generate( 134 | input_ids=input_ids, attention_mask=attention_mask, **kwargs 135 | ) 136 | 137 | def save(self, directory=None): 138 | """Creates checkpoint of optimizer, scheduler and a model""" 139 | self.accelerator.save_state(directory or self.config.train.checkpoint_dir) 140 | 141 | def load(self, directory=None): 142 | """Load checkpoint of optimizer, scheduler and a model""" 143 | self.accelerator.load_state(directory or self.config.train.checkpoint_dir) 144 | 145 | def add_eval_pipeline(self, eval_pipeline): 146 | """Adds pipeline from with validation prompts""" 147 | self.eval_pipeline = eval_pipeline 148 | 149 | def evaluate(self): 150 | """Samples model on `eval_prompts`, logs stats with `reward_fn` or `metric_fn` if provided""" 151 | stats = {} 152 | all_samples = [] 153 | prompts_sizes = [] 154 | generate_time = time() 155 | for prompts in self.eval_dataloader: 156 | if isinstance(prompts, torch.Tensor): 157 | samples = self.generate(prompts) 158 | else: 159 | samples = self.generate(**prompts) 160 | 161 | if isinstance(samples, tuple): 162 | samples, *_ = samples 163 | 164 | pad_token = self.tokenizer.eos_token_id if self.tokenizer else 0 165 | all_samples.append( 166 | F.pad( 167 | samples, 168 | (0, self.max_length - samples.shape[1]), 169 | value=pad_token, 170 | ) 171 | ) 172 | sizes = torch.tensor(prompts.input_ids.shape[1]).repeat( 173 | len(prompts.input_ids) 174 | ) 175 | prompts_sizes.append(sizes.to(samples.device)) 176 | 177 | stats["time/generate"] = time() - generate_time 178 | 179 | samples = self.accelerator.gather(torch.vstack(all_samples)) 180 | prompts_sizes = self.accelerator.gather(torch.hstack(prompts_sizes)) 181 | 182 | if self.accelerator.is_main_process: 183 | if self.tokenizer: 184 | str_samples = self.tokenizer.batch_decode( 185 | samples, skip_special_tokens=True 186 | ) 187 | 188 | prompts, responses = [], [] 189 | for sample, prompt_size in zip(samples, prompts_sizes): 190 | prompts.append(sample[:prompt_size]) 191 | responses.append(sample[prompt_size:]) 192 | 193 | str_prompts = self.tokenizer.batch_decode( 194 | prompts, skip_special_tokens=True 195 | ) 196 | str_responses = self.tokenizer.batch_decode( 197 | responses, skip_special_tokens=True 198 | ) 199 | 200 | if isinstance(str_samples[0], str): 201 | columns_data = [str_prompts, str_responses] 202 | else: 203 | columns_data = [samples.tolist()] 204 | columns = ["prompt", "response"] 205 | 206 | # in online setting, compute the reward for validation 207 | if self.reward_fn: 208 | rewards = torch.tensor(self.reward_fn(str_samples), dtype=torch.float) 209 | mean_reward = rewards.mean() 210 | columns.append("reward") 211 | columns_data.append(rewards) 212 | stats["reward/mean"] = mean_reward 213 | print(f"Mean rewards: {mean_reward}") 214 | if mean_reward > self.best_mean_reward: 215 | self.best_mean_reward = mean_reward 216 | self.save() 217 | print("=== Saved ===") 218 | 219 | # additionally log any other metrics 220 | if self.metric_fn: 221 | metric_time = time() 222 | metrics = self.metric_fn(str_samples) 
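# NOTE: `metric_fn` is expected to return a dict mapping each metric name to its per-sample values, e.g. {"metric_name": [v_1, ..., v_n]} (illustrative); each entry is averaged below and logged under the `metrics/` prefix, and the raw values are appended as extra columns of the samples table.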
223 | stats["time/metric"] = time() - metric_time 224 | 225 | mean_metrics = { 226 | f"metrics/{k}": torch.as_tensor(xs).mean(-1) 227 | for k, xs in metrics.items() 228 | } 229 | 230 | stats.update(mean_metrics) 231 | 232 | for metric, values in metrics.items(): 233 | columns.append(metric) 234 | columns_data.append(values) 235 | 236 | rows = list(zip(*columns_data)) 237 | print(rows[0]) 238 | if not ray.is_initialized(): 239 | stats["samples"] = wandb.Table(columns=columns, rows=rows) 240 | 241 | return stats 242 | 243 | def learn(self): 244 | """ 245 | Samples batches from `self.store`, updates model and periodically evaluates it on `self.eval_dataloader` 246 | """ 247 | 248 | self.prepare_learning() 249 | self.iter_count = 0 250 | 251 | if ray.is_initialized(): 252 | checkpoint = session.get_checkpoint() 253 | if checkpoint: 254 | with checkpoint.as_directory() as dir: 255 | self.accelerator.load_state(dir) 256 | 257 | with open(os.path.join(dir, "state.json")) as f: 258 | state = json.load(f) 259 | self.iter_count = state["iter_count"] 260 | else: 261 | results = self.evaluate() 262 | self.accelerator.log(results, step=self.iter_count) 263 | 264 | tbar = tqdm( 265 | initial=self.iter_count, 266 | total=self.total_steps, 267 | disable=not self.accelerator.is_local_main_process, 268 | ) 269 | 270 | for _ in range(self.config.train.epochs): 271 | for batch in self.train_dataloader: 272 | for _ in range(self.n_updates_per_batch): 273 | forward_time = time() 274 | loss, stats = self.loss(batch) 275 | forward_time = time() - forward_time 276 | 277 | backward_time = time() 278 | self.accelerator.backward(loss) 279 | backward_time = time() - backward_time 280 | 281 | self.opt.step() 282 | self.opt.zero_grad() 283 | self.scheduler.step() 284 | self.iter_count += 1 285 | 286 | if self.iter_count % self.config.train.checkpoint_interval == 0: 287 | pass#self.save() 288 | 289 | stats["time/forward"] = forward_time 290 | stats["time/backward"] = backward_time 291 | 292 | if self.iter_count % self.config.train.eval_interval == 0: 293 | results = self.evaluate() 294 | stats.update(results) 295 | 296 | # Report the metrics to Ray Tune. 
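# NOTE: under Ray Tune, the accelerator state is checkpointed into a local "state" directory together with the current `iter_count`, and the stats are reported back to the Tune session; `filter_non_scalars` presumably strips non-scalar entries (such as the wandb samples table) before reporting.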
297 | if ray.is_initialized(): 298 | self.save("state") 299 | with open("state/state.json", "w") as f: 300 | json.dump(dict(iter_count=self.iter_count), f) 301 | checkpoint = Checkpoint.from_directory("state") 302 | session.report( 303 | filter_non_scalars(stats), checkpoint=checkpoint 304 | ) 305 | 306 | if not ray.is_initialized(): 307 | self.accelerator.log(stats, step=self.iter_count) 308 | 309 | desc = ", ".join( 310 | f"{k}: {v:.2f}" 311 | for k, v in stats.items() 312 | if k.startswith("loss") 313 | ) 314 | tbar.set_description(desc) 315 | tbar.update() 316 | 317 | if self.iter_count >= self.total_steps: 318 | #self.save() 319 | return self.evaluate() 320 | 321 | self.post_backward_callback() 322 | 323 | self.post_epoch_callback() 324 | 325 | @abstractmethod 326 | def get_arch(self, config: TRLConfig): 327 | """Returns a specific wrapper of the decoder architecture""" 328 | pass 329 | 330 | @abstractmethod 331 | def loss(self, batch) -> Tuple[float, Dict]: 332 | """Compute loss on a batch from `store` and return some statistics""" 333 | pass 334 | 335 | @abstractmethod 336 | def post_backward_callback(self): 337 | """Do something after model update""" 338 | pass 339 | 340 | @abstractmethod 341 | def post_epoch_callback(self): 342 | """Do something after exhausting/single pass over `self.store`""" 343 | pass -------------------------------------------------------------------------------- /trlx/model/accelerate_ilql_model.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, Sequence, Union, cast 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | from trlx.model import register_model 8 | from trlx.model.nn.ilql_models import ILQLConfig, CausalLMWithValueHeads 9 | from trlx.data.ilql_types import ILQLBatch 10 | from trlx.data.configs import TRLConfig 11 | from trlx.utils import to_device 12 | 13 | from .accelerate_base_model import AccelerateRLModel 14 | 15 | 16 | @register_model 17 | class AccelerateILQLModel(AccelerateRLModel): 18 | def __init__( 19 | self, 20 | config: TRLConfig, 21 | logit_mask=None, 22 | metric_fn=None, 23 | train_mode=True, 24 | ): 25 | super().__init__(config, train_mode) 26 | self.logit_mask = logit_mask 27 | self.metric_fn = metric_fn 28 | self.reward_fn = None 29 | 30 | if not isinstance(config.method, ILQLConfig): 31 | raise ValueError("config.method must be ILQLConfig") 32 | 33 | self.ilql: ILQLConfig = cast(ILQLConfig, config.method) 34 | 35 | self.generate_kwargs = dict( 36 | config.method.gen_kwargs, 37 | max_length=self.max_length, 38 | logit_mask=self.logit_mask, 39 | eos_token_id=self.tokenizer.eos_token_id if self.tokenizer else 0, 40 | pad_token_id=self.tokenizer.pad_token_id if self.tokenizer else 0, 41 | ) 42 | 43 | def get_arch(self, config): 44 | return CausalLMWithValueHeads( 45 | config.model.model_path, 46 | ilql_config=config.method, 47 | num_layers_unfrozen=config.model.num_layers_unfrozen, 48 | ) 49 | 50 | def tokenize(self, texts: Union[Sequence[str], Sequence[torch.LongTensor]]): 51 | if isinstance(texts[0], torch.LongTensor): 52 | return texts 53 | 54 | tokenized = self.tokenizer( 55 | [self.tokenizer.bos_token + x + self.tokenizer.eos_token for x in texts], 56 | max_length=self.max_length, 57 | truncation=True, 58 | # NOTE: We manually add special tokens (bos) above so we set this False 59 | # to avoid models that automatically add special tokens (e.g. OPT) 60 | # adding them twice more. 
61 | add_special_tokens=False, 62 | ) 63 | input_ids = list(map(torch.as_tensor, tokenized.input_ids)) 64 | return input_ids 65 | 66 | def post_backward_callback(self): 67 | if self.iter_count % self.config.method.steps_for_target_q_sync == 0: 68 | self.accelerator.unwrap_model(self.model).sync_target_q_heads() 69 | 70 | def loss(self, batch: ILQLBatch): 71 | batch = to_device(batch, self.accelerator.device) 72 | 73 | logits, qs, target_qs, vs, _ = self.model( 74 | input_ids=batch.input_ids, 75 | attention_mask=batch.attention_mask, 76 | actions_ixs=batch.actions_ixs, 77 | states_ixs=batch.states_ixs, 78 | ) 79 | 80 | return self.ilql.loss((logits, (qs, target_qs, vs)), batch) 81 | 82 | def prepare_learning(self): 83 | train_dataloader = self.store.create_loader(self.config.train.batch_size) 84 | eval_dataloader = self.eval_pipeline.create_loader(self.config.train.batch_size) 85 | 86 | ( 87 | self.model, 88 | self.opt, 89 | self.train_dataloader, 90 | self.eval_dataloader, 91 | ) = self.accelerator.prepare( 92 | self.model, self.opt, train_dataloader, eval_dataloader 93 | ) 94 | 95 | self.n_updates_per_batch = 1 96 | self.total_steps = self.config.train.epochs * len(train_dataloader) 97 | self.total_steps = min(self.total_steps, self.config.train.total_steps) 98 | -------------------------------------------------------------------------------- /trlx/model/accelerate_ppo_model.py: -------------------------------------------------------------------------------- 1 | from time import time 2 | from typing import Tuple 3 | import uuid, os, json 4 | 5 | import torch 6 | from torchtyping import TensorType 7 | 8 | from trlx.data.configs import TRLConfig 9 | from trlx.data.ppo_types import PPORLBatch 10 | from trlx.model import register_model 11 | from trlx.model.accelerate_base_model import AccelerateRLModel 12 | from trlx.model.nn.ppo_models import ( 13 | AdaptiveKLController, 14 | FixedKLController, 15 | CausalLMHydraWithValueHead, 16 | T5HydraWithValueHead, 17 | ) 18 | from trlx.pipeline.ppo_pipeline import PPORolloutStorage 19 | from trlx.utils.modeling import logprobs_from_logits 20 | import torch.nn.functional as F 21 | import wandb 22 | 23 | import ray 24 | 25 | from tqdm import tqdm 26 | 27 | 28 | @register_model 29 | class AcceleratePPOModel(AccelerateRLModel): 30 | def __init__(self, config): 31 | super().__init__(config) 32 | 33 | if config.train.rollout_logging_dir is not None: 34 | self.log_rollouts = True 35 | self.setup_rollout_logging(config) 36 | else: 37 | self.log_rollouts = False 38 | 39 | self.store = PPORolloutStorage(self.tokenizer.pad_token_id) 40 | 41 | rollout_loader = self.store.create_loader( 42 | self.config.train.batch_size, shuffle=True 43 | ) 44 | 45 | self.model, self.opt, self.scheduler, rollout_loader = self.accelerator.prepare( 46 | self.model, self.opt, self.scheduler, rollout_loader 47 | ) 48 | 49 | self.store.clear_history() 50 | if config.method.target is not None: 51 | self.kl_ctl = AdaptiveKLController( 52 | config.method.init_kl_coef, config.method.target, config.method.horizon 53 | ) 54 | else: 55 | self.kl_ctl = FixedKLController(config.method.init_kl_coef) 56 | 57 | self.generate_kwargs = dict( 58 | config.method.gen_kwargs, 59 | eos_token_id=self.tokenizer.eos_token_id, 60 | pad_token_id=self.tokenizer.eos_token_id, 61 | ) 62 | 63 | def get_arch(self, config: TRLConfig): 64 | return CausalLMHydraWithValueHead( 65 | config.model.model_path, config.model.num_layers_unfrozen 66 | ) 67 | 68 | def get_model_inputs( 69 | self, 70 | query_tensors: 
TensorType["batch_size", "query_size"], 71 | response_tensors: TensorType["batch_size", "response_size"], 72 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 73 | tokens = torch.cat((query_tensors, response_tensors), dim=1)[ 74 | :, -self.max_length : 75 | ] 76 | attention_mask = ( 77 | tokens.not_equal(self.tokenizer.pad_token_id).long().to(tokens.device) 78 | ) 79 | # For a proper positional encoding in case of left padding 80 | position_ids = attention_mask.cumsum(-1) - 1 81 | position_ids.masked_fill_(attention_mask.eq(0), 0) 82 | return tokens, attention_mask, position_ids 83 | 84 | def loss(self, batch: PPORLBatch): 85 | # Move `batch` data to `accelerator` device 86 | query_tensors = batch.query_tensors.to(self.accelerator.device) 87 | response_tensors = batch.response_tensors.to(self.accelerator.device) 88 | old_logprobs = batch.logprobs.to(self.accelerator.device) 89 | old_values = batch.values.to(self.accelerator.device) 90 | old_rewards = batch.rewards.to(self.accelerator.device) 91 | 92 | response_length = old_rewards.shape[1] 93 | 94 | advantages, returns = self.config.method.get_advantages_and_returns( 95 | old_values, old_rewards, response_length 96 | ) 97 | 98 | tokens, attention_mask, position_ids = self.get_model_inputs( 99 | query_tensors, response_tensors 100 | ) 101 | 102 | logits, *_, values_pred = self.model( 103 | tokens, attention_mask=attention_mask, position_ids=position_ids 104 | ) 105 | values_pred = values_pred[:, :-1] 106 | logprobs = logprobs_from_logits(logits[:, :-1, :], tokens[:, 1:]) 107 | attention_mask = attention_mask[:, :-1] 108 | 109 | # Only the response part of the values/logprobs is needed 110 | start = query_tensors.shape[1] - 1 111 | end = start + response_length 112 | logprobs, values_pred, mask = ( 113 | logprobs[:, start:end], 114 | values_pred[:, start:end], 115 | attention_mask[:, start:end], 116 | ) 117 | 118 | loss, stats = self.config.method.loss( 119 | logprobs=logprobs, 120 | values=values_pred, 121 | old_logprobs=old_logprobs, 122 | old_values=old_values, 123 | advantages=advantages, 124 | returns=returns, 125 | mask=mask, 126 | ) 127 | self.approx_kl = stats["policy/approx_kl"] # Update kl controller stats 128 | return loss, stats 129 | 130 | def setup_rollout_logging(self, config): 131 | # Make rollout logging dir for this run and store config 132 | exists = os.path.exists(config.train.rollout_logging_dir) 133 | isdir = os.path.isdir(config.train.rollout_logging_dir) 134 | assert exists and isdir 135 | 136 | self.run_id = f"run-{uuid.uuid4()}" 137 | self.rollout_logging_dir = os.path.join( 138 | config.train.rollout_logging_dir, self.run_id 139 | ) 140 | os.mkdir(self.rollout_logging_dir) 141 | 142 | with open(os.path.join(self.rollout_logging_dir, "config.json"), "w") as f: 143 | f.write(json.dumps(config.to_dict(), indent=2)) 144 | 145 | def post_epoch_callback(self): 146 | if self.log_rollouts: 147 | self.store.export_history(location=self.rollout_logging_dir) 148 | self.store.clear_history() 149 | self.orch.make_experience( 150 | self.config.method.num_rollouts, self.iter_count 151 | ) # Collect more rollouts for training 152 | 153 | def post_backward_callback(self): 154 | self.kl_ctl.update(self.approx_kl, n_steps=self.config.train.batch_size) 155 | 156 | def prepare_learning(self): 157 | eval_dataloader = self.eval_pipeline.create_loader(self.config.train.batch_size) 158 | 159 | train_dataloader = self.store.create_loader( 160 | self.config.train.batch_size, shuffle=True 161 | ) 162 | 163 | self.train_dataloader, 
self.eval_dataloader = self.accelerator.prepare( 164 | train_dataloader, eval_dataloader 165 | ) 166 | 167 | self.n_updates_per_batch = self.config.method.ppo_epochs 168 | self.total_steps = ( 169 | self.config.train.epochs 170 | * self.n_updates_per_batch 171 | * len(self.train_dataloader) 172 | ) 173 | self.total_steps = min(self.total_steps, self.config.train.total_steps) 174 | 175 | @register_model 176 | class T5AcceleratePPOModel(AcceleratePPOModel): 177 | 178 | def __init__(self, config: TRLConfig): 179 | super().__init__(config) 180 | 181 | self.tokenizer.padding_side = "right" # Left padding not supported 182 | 183 | def get_arch(self, config: TRLConfig): 184 | return T5HydraWithValueHead( 185 | config.model.model_path, 186 | ) 187 | 188 | def get_model_inputs( 189 | self, 190 | query_tensors: TensorType["batch_size", "query_size"], 191 | response_tensors: TensorType["batch_size", "response_size"], 192 | ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 193 | input_ids = query_tensors[:, :self.max_length] 194 | attention_mask = ( 195 | input_ids.not_equal(self.tokenizer.pad_token_id).long().to(input_ids.device) 196 | ) 197 | decoder_input_ids = response_tensors[:, : self.max_length] 198 | 199 | decoder_attention_mask = ( 200 | decoder_input_ids.not_equal(self.tokenizer.pad_token_id) 201 | .long() 202 | .to(decoder_input_ids.device) 203 | ) 204 | 205 | return input_ids, attention_mask, decoder_input_ids, decoder_attention_mask 206 | 207 | def loss(self, batch: PPORLBatch): 208 | # Move `batch` data to `accelerator` device 209 | query_tensors = batch.query_tensors.to(self.accelerator.device) 210 | response_tensors = batch.response_tensors.to(self.accelerator.device) 211 | old_logprobs = batch.logprobs.to(self.accelerator.device) 212 | old_values = batch.values.to(self.accelerator.device) 213 | old_rewards = batch.rewards.to(self.accelerator.device) 214 | 215 | response_length = old_rewards.shape[1] 216 | 217 | advantages, returns = self.config.method.get_advantages_and_returns( 218 | old_values, old_rewards, response_length 219 | ) 220 | 221 | input_ids, attention_mask, decoder_input_ids, decoder_attention_mask = self.get_model_inputs( 222 | query_tensors, response_tensors 223 | ) 224 | 225 | outputs = self.model( 226 | input_ids=input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, decoder_attention_mask=decoder_attention_mask 227 | ) 228 | logits, values_pred = outputs.logits, outputs.value 229 | values_pred = values_pred[:, :-1] 230 | logprobs = logprobs_from_logits(logits[:, :-1, :], decoder_input_ids[:, 1:]) # decoder_input_ids doesn't include the start token so we don't need to shift it 231 | attention_mask = attention_mask[:, :-1] 232 | 233 | loss, stats = self.config.method.loss( 234 | logprobs=logprobs, 235 | values=values_pred, 236 | old_logprobs=old_logprobs, 237 | old_values=old_values, 238 | advantages=advantages, 239 | returns=returns, 240 | mask=decoder_attention_mask[:, 1:], 241 | ) 242 | self.approx_kl = stats["policy/approx_kl"] # Update kl controller stats 243 | return loss, stats 244 | 245 | def evaluate(self): 246 | """Samples model on `eval_prompts`, logs stats with `reward_fn` or `metric_fn` if provided""" 247 | stats = {} 248 | generate_time = time() 249 | prompts_list, responses = [], [] 250 | print("Starting evaluation...") 251 | for prompts in tqdm(self.eval_dataloader, desc="Evaluating"): 252 | if isinstance(prompts, torch.Tensor): 253 | attention_mask = ( 254 | prompts.not_equal(self.tokenizer.pad_token_id) 255 | 
.long() 256 | .to(prompts.device) 257 | ) 258 | samples = self.generate(prompts, attention_mask=attention_mask, use_cache=True, max_new_tokens=50, do_sample=False) 259 | prompts_list.extend( 260 | self.tokenizer.batch_decode( 261 | prompts, skip_special_tokens=True 262 | ) 263 | ) 264 | else: 265 | samples = self.generate(**prompts, use_cache=True, max_new_tokens=50, do_sample=False) 266 | prompts_list.extend( 267 | self.tokenizer.batch_decode( 268 | prompts["input_ids"], skip_special_tokens=True 269 | ) 270 | ) 271 | 272 | if isinstance(samples, tuple): 273 | samples, *_ = samples 274 | 275 | responses.extend( 276 | self.tokenizer.batch_decode( 277 | samples, skip_special_tokens=True 278 | ) 279 | ) 280 | 281 | 282 | stats["time/generate"] = time() - generate_time 283 | 284 | if self.accelerator.is_main_process: 285 | 286 | columns_data = [prompts_list, responses] 287 | 288 | columns = ["prompt", "response"] 289 | 290 | str_samples = [f"{prompt} {response}" for prompt, response in zip(prompts_list, responses)] 291 | 292 | # in online setting, compute the reward for validation 293 | if self.reward_fn: 294 | rewards = torch.tensor(self.reward_fn(str_samples), dtype=torch.float) 295 | mean_reward = rewards.mean() 296 | columns.append("reward") 297 | columns_data.append(rewards) 298 | stats["reward/mean"] = mean_reward 299 | print(f"Mean rewards: {mean_reward}") 300 | if mean_reward > self.best_mean_reward: 301 | self.best_mean_reward = mean_reward 302 | self.save() 303 | print("=== Saved ===") 304 | 305 | # additionally log any other metrics 306 | if self.metric_fn: 307 | metric_time = time() 308 | metrics = self.metric_fn(str_samples) 309 | stats["time/metric"] = time() - metric_time 310 | 311 | mean_metrics = { 312 | f"metrics/{k}": torch.as_tensor(xs).mean(-1) 313 | for k, xs in metrics.items() 314 | } 315 | 316 | stats.update(mean_metrics) 317 | 318 | for metric, values in metrics.items(): 319 | columns.append(metric) 320 | columns_data.append(values) 321 | 322 | rows = list(zip(*columns_data)) 323 | if not ray.is_initialized(): 324 | stats["samples"] = wandb.Table(columns=columns, rows=rows) 325 | 326 | return stats -------------------------------------------------------------------------------- /trlx/model/nn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/CG80499/trlx-with-T5/f4eae6703eee125a7adf6a291031f5efe76e2ed7/trlx/model/nn/__init__.py -------------------------------------------------------------------------------- /trlx/model/nn/ilql_models.py: -------------------------------------------------------------------------------- 1 | import inspect 2 | import os 3 | from collections import defaultdict 4 | from copy import deepcopy 5 | from dataclasses import dataclass 6 | from functools import reduce 7 | from itertools import chain 8 | from typing import Any, Dict, Union, Sequence 9 | 10 | from trlx.data.ilql_types import ILQLBatch 11 | from trlx.data.method_configs import register_method, MethodConfig 12 | from trlx.utils.modeling import ( 13 | freeze_bottom_causal_layers, 14 | hf_get_causal_base_model, 15 | hf_get_hidden_size, 16 | hf_get_lm_head, 17 | make_head, 18 | ) 19 | 20 | 21 | import deepspeed # type: ignore 22 | import numpy as np 23 | import torch 24 | import torch.nn.functional as F 25 | import transformers 26 | from torch import nn 27 | 28 | 29 | def topk_mask(xs: torch.FloatTensor, k: int): 30 | if k > xs.shape[-1]: 31 | return xs 32 | mintop = torch.topk(xs, k)[0][:, -1].unsqueeze(-1) 33 | 
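# `mintop` holds the k-th largest logit of each row; anything strictly below it is pushed to -inf so that a later softmax assigns it zero probability.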
return torch.where(xs < mintop, -np.inf * torch.ones_like(xs, dtype=xs.dtype), xs) 34 | 35 | 36 | @dataclass 37 | @register_method 38 | class ILQLConfig(MethodConfig): 39 | tau: float 40 | gamma: float 41 | cql_scale: float 42 | awac_scale: float 43 | alpha: float 44 | steps_for_target_q_sync: float 45 | two_qs: bool 46 | gen_kwargs: dict 47 | 48 | def heads(self, hidden_size: int, vocab_size: int): 49 | return ILQLHeads(self, hidden_size, vocab_size) 50 | 51 | def loss(self, outputs, labels: ILQLBatch): 52 | logits, (qs, target_qs, vs) = outputs 53 | actions = ( 54 | labels.input_ids[:, 1:] 55 | .gather(dim=1, index=labels.actions_ixs) 56 | .unsqueeze(-1) 57 | ) 58 | bsize, ntokens, dsize = logits.shape 59 | 60 | Q = [q.gather(-1, actions).squeeze(-1) for q in qs] 61 | targetQs = [q.gather(-1, actions).squeeze(-1).detach() for q in target_qs] 62 | targetQ = reduce(torch.minimum, targetQs) 63 | terminal_mask = labels.dones[:, :-1] 64 | n_nonterminal = max(1, terminal_mask.sum()) 65 | 66 | # values of current states 67 | V = vs[:, :-1].squeeze() 68 | # values of next states 69 | Vnext = vs[:, 1:].squeeze() * labels.dones[:, 1:] 70 | # target to fit Q 71 | Q_ = labels.rewards + self.gamma * Vnext.detach() 72 | 73 | loss_qs = [((Qi - Q_) * terminal_mask).pow(2).sum() / n_nonterminal for Qi in Q] 74 | loss_q = sum(loss_qs) 75 | 76 | targetQ = targetQ.detach() 77 | 78 | loss_v = ( 79 | ( 80 | (targetQ >= V).int() * self.tau * (targetQ - V).pow(2) 81 | + (targetQ < V).int() * (1 - self.tau) * (targetQ - V).pow(2) 82 | ) 83 | * terminal_mask 84 | ).sum() / n_nonterminal 85 | 86 | nactions = qs[0].shape[1] 87 | 88 | def cql_loss(q): 89 | loss = F.cross_entropy( 90 | q.reshape(-1, dsize), actions.reshape(-1), reduction="none" 91 | ) 92 | loss = loss.reshape(bsize, nactions) * terminal_mask 93 | loss = loss.sum() / n_nonterminal 94 | return loss 95 | 96 | loss_cql = sum(cql_loss(q) for q in qs) 97 | 98 | loss_awac = ( 99 | F.cross_entropy( 100 | logits[:, :-1, :].reshape(-1, dsize), 101 | labels.input_ids[:, 1:].reshape(-1), 102 | reduction="none", 103 | ).reshape(bsize, ntokens - 1) 104 | * labels.attention_mask[:, 1:] 105 | ).sum() / labels.attention_mask[:, 1:].sum() 106 | 107 | loss = loss_q + loss_v + self.cql_scale * loss_cql + self.awac_scale * loss_awac 108 | 109 | stats = { 110 | f"losses/{k}": v 111 | for k, v in locals().items() 112 | if k in ["loss", "loss_v", "loss_q", "loss_cql", "loss_awac"] 113 | } 114 | 115 | return loss, stats 116 | 117 | 118 | class ILQLHeads(nn.Module): 119 | def __init__(self, config: ILQLConfig, hidden_size: int, vocab_size: int): 120 | super().__init__() 121 | 122 | self.hidden_size = hidden_size 123 | self.vocab_size = vocab_size 124 | self.v_head = make_head(self.hidden_size, 1) 125 | self.config = config 126 | 127 | n_qs = 2 if self.config.two_qs else 1 128 | 129 | self.q_heads = nn.ModuleList( 130 | make_head(self.hidden_size, self.vocab_size) for _ in range(n_qs) 131 | ) 132 | self.target_q_heads = nn.ModuleList(deepcopy(q_head) for q_head in self.q_heads) 133 | 134 | for q_head in self.target_q_heads: 135 | q_head.requires_grad_(False) 136 | 137 | def forward( 138 | self, 139 | hs: torch.Tensor, 140 | states_ixs: torch.Tensor = None, 141 | actions_ixs: torch.Tensor = None, 142 | ): 143 | if states_ixs is not None: 144 | states_hs = hs.gather( 145 | dim=1, index=states_ixs.unsqueeze(-1).repeat(1, 1, hs.shape[-1]) 146 | ) 147 | actions_hs = hs.gather( 148 | dim=1, index=actions_ixs.unsqueeze(-1).repeat(1, 1, hs.shape[-1]) 149 | ) 150 | else: 151 | 
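# No `states_ixs`/`actions_ixs` were provided, so every position's hidden state is fed to both the value head and the Q-heads.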
states_hs = actions_hs = hs 152 | 153 | qs = tuple(q_head(actions_hs) for q_head in self.q_heads) 154 | target_qs = tuple(q_head(actions_hs) for q_head in self.target_q_heads) 155 | vs = self.v_head(states_hs) 156 | 157 | return qs, target_qs, vs 158 | 159 | def _sync_target_q_heads(self, alpha): 160 | for target_q_head, q_head in zip(self.target_q_heads, self.q_heads): 161 | for target_param, copy_param in zip( 162 | target_q_head.parameters(), q_head.parameters() 163 | ): 164 | target_param.data.copy_( 165 | (alpha * copy_param.data) + (1.0 - alpha) * target_param.data 166 | ) 167 | 168 | def sync_target_q_heads(self): 169 | if os.environ.get("DEEPSPEED_ZERO_STAGE", "0") == "3": 170 | params = chain( 171 | chain(q_head.parameters() for q_head in self.q_heads), 172 | chain(q_head.parameters() for q_head in self.target_q_heads), 173 | ) 174 | 175 | with deepspeed.zero.GatheredParameters(list(params), modifier_rank=0): 176 | if deepspeed.comm.get_rank() == 0: 177 | self._sync_target_q_heads(self.config.alpha) 178 | else: 179 | self._sync_target_q_heads(self.config.alpha) 180 | 181 | 182 | class CausalLMWithValueHeads(nn.Module): 183 | """This is a wrapper around huggingface AutoModelForCausalLM with two additional scalar heads""" 184 | 185 | def __init__( 186 | self, 187 | config: Union[transformers.PretrainedConfig, str], 188 | ilql_config: ILQLConfig, 189 | num_layers_unfrozen=-1, 190 | ): 191 | super().__init__() 192 | 193 | # enable zero3 init within from_pretrained 194 | if os.environ.get("DEEPSPEED_ZERO_STAGE", "0") == "3": 195 | config_path = os.environ.get("DEEPSPEED_CONFIG_FILE", "") 196 | if config_path: 197 | _hfconfig = transformers.deepspeed.HfDeepSpeedConfig( # noqa: F841 198 | config_path 199 | ) 200 | if isinstance(config, str): 201 | self.config = transformers.AutoConfig.from_pretrained(config) 202 | else: 203 | self.config = config 204 | 205 | self.base_model = transformers.AutoModelForCausalLM.from_pretrained( 206 | self.config.name_or_path, 207 | ) 208 | self.base_model.transformer = hf_get_causal_base_model(self.base_model) 209 | self.base_model.lm_head = hf_get_lm_head(self.base_model) 210 | freeze_bottom_causal_layers(self.base_model, num_layers_unfrozen) 211 | 212 | # Cache `transformer.forward` args for general use (avoids incompatible args across architectures) 213 | self.base_model_transformer_args = inspect.getfullargspec( 214 | self.base_model.transformer.forward 215 | ).args 216 | 217 | self.hidden_size = hf_get_hidden_size(self.config) 218 | self.ilql_heads = ilql_config.heads(self.hidden_size, self.config.vocab_size) 219 | self.ilql_config = ilql_config 220 | 221 | def _get_compatible_forward_kwargs(self, **kwargs) -> Dict[str, Any]: 222 | """Filter out arguments not supported by the specific instance of `base_model.transformer.forward`""" 223 | return { 224 | k: v for k, v in kwargs.items() if k in self.base_model_transformer_args 225 | } 226 | 227 | def sync_target_q_heads(self): 228 | self.ilql_heads.sync_target_q_heads() 229 | 230 | def forward( 231 | self, 232 | input_ids, 233 | attention_mask=None, 234 | position_ids=None, 235 | past_key_values=None, 236 | actions_ixs=None, 237 | states_ixs=None, 238 | ): 239 | forward_kwargs = self._get_compatible_forward_kwargs( 240 | input_ids=input_ids, 241 | attention_mask=attention_mask, 242 | position_ids=position_ids, 243 | past_key_values=past_key_values, 244 | ) 245 | out = self.base_model.transformer(**forward_kwargs) 246 | hs = out.last_hidden_state 247 | 248 | logits = self.base_model.lm_head(hs) 249 | qs, 
target_qs, vs = self.ilql_heads( 250 | hs, states_ixs=states_ixs, actions_ixs=actions_ixs 251 | ) 252 | 253 | return logits, qs, target_qs, vs, out.past_key_values 254 | 255 | def generate( 256 | self, 257 | input_ids, 258 | attention_mask=None, 259 | position_ids=None, 260 | past_key_values=None, 261 | beta=1, 262 | max_new_tokens=32, 263 | max_length=1024, 264 | temperature=1, 265 | top_k=20, 266 | logit_mask=None, 267 | pad_token_id=None, 268 | eos_token_id=None, 269 | ): 270 | """ 271 | Generates samples akin to hf's `.generate` but with custom logp prepossessing: changing token probabilities as to how advantageous they would be according to value functions estimations. 272 | """ 273 | if attention_mask is None: 274 | attention_mask = input_ids.not_equal(pad_token_id) 275 | 276 | if position_ids is None: 277 | position_ids = attention_mask.cumsum(-1) - 1 278 | position_ids.masked_fill_(attention_mask.eq(0), 0) 279 | 280 | samples = input_ids.clone() 281 | max_new_tokens = min(max_new_tokens, max_length - input_ids.shape[1]) 282 | 283 | finished = torch.zeros( 284 | input_ids.shape[0], 1, dtype=torch.long, device=input_ids.device 285 | ) 286 | for _ in range(max_new_tokens): 287 | out = self.forward( 288 | input_ids=input_ids, 289 | attention_mask=attention_mask, 290 | position_ids=position_ids, 291 | past_key_values=past_key_values, 292 | ) 293 | 294 | logits, _, target_qs, vs, past_key_values = out 295 | if self.ilql_config.two_qs: 296 | qs = torch.minimum(target_qs[0][:, -1, :], target_qs[1][:, -1, :]) 297 | else: 298 | qs = target_qs[:, -1, :] 299 | 300 | logits = logits[:, -1, :] 301 | vs = vs[:, -1, :] 302 | 303 | if logit_mask is not None: 304 | mask = logit_mask[input_ids[:, -1].squeeze().to(logit_mask.device)] 305 | logits[torch.where(mask)] = -np.inf 306 | 307 | adv = qs - vs 308 | pi_beta = F.log_softmax(logits, -1) 309 | pi_top_k = topk_mask(pi_beta + beta * adv, top_k) 310 | pi = F.softmax(pi_top_k / temperature, -1) 311 | 312 | input_ids = torch.multinomial(pi, num_samples=1) 313 | input_ids = (1 - finished) * input_ids + finished * eos_token_id 314 | finished = (input_ids == eos_token_id).long() 315 | 316 | samples = torch.hstack((samples, input_ids)) 317 | attention_mask = torch.hstack( 318 | (attention_mask, (input_ids != eos_token_id).long()) 319 | ) 320 | position_ids = (position_ids[:, -1] + 1).view(-1, 1) 321 | 322 | if torch.all(finished): 323 | break 324 | 325 | return samples 326 | 327 | @property 328 | def dummy_inputs(self): 329 | return { 330 | "input_ids": torch.ones( 331 | 1, 1, device=self.base_model.device, dtype=torch.long 332 | ) 333 | } 334 | 335 | @property 336 | def device(self): 337 | return self.base_model.device 338 | -------------------------------------------------------------------------------- /trlx/orchestrator/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from abc import abstractmethod 3 | from typing import Dict 4 | 5 | from trlx.model import BaseRLModel 6 | from trlx.pipeline import BasePipeline 7 | 8 | # specifies a dictionary of architectures 9 | _ORCH: Dict[str, any] = {} # registry 10 | 11 | 12 | def register_orchestrator(name): 13 | """Decorator used register a CARP architecture 14 | Args: 15 | name: Name of the architecture 16 | """ 17 | 18 | def register_class(cls, name): 19 | _ORCH[name] = cls 20 | setattr(sys.modules[__name__], name, cls) 21 | return cls 22 | 23 | if isinstance(name, str): 24 | name = name.lower() 25 | return lambda c: register_class(c, name) 26 | 
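# Reached when the decorator is applied without arguments: `name` is then the class itself, so register it under its lowercased class name.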
27 | cls = name 28 | name = cls.__name__ 29 | register_class(cls, name.lower()) 30 | 31 | return cls 32 | 33 | 34 | @register_orchestrator 35 | class Orchestrator: 36 | def __init__(self, pipeline: BasePipeline, rl_model: BaseRLModel): 37 | self.pipeline = pipeline 38 | self.rl_model = rl_model 39 | 40 | @abstractmethod 41 | def make_experience(self): 42 | """ 43 | Draw from pipeline, get action, generate reward 44 | Push to models RolloutStorage 45 | """ 46 | pass 47 | -------------------------------------------------------------------------------- /trlx/orchestrator/offline_orchestrator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from trlx.orchestrator import Orchestrator, register_orchestrator 4 | from trlx.pipeline.offline_pipeline import ILQLRolloutStorage 5 | 6 | 7 | @register_orchestrator 8 | class OfflineOrchestrator(Orchestrator): 9 | """ 10 | Orchestrator that creates a static dataset for offline training 11 | """ 12 | 13 | def __init__(self, model, split_token=None): 14 | self.model = model 15 | self.split_token = split_token 16 | 17 | def make_experience(self, samples, rewards): 18 | """ 19 | Tokenizes samples and shapes rewards into proper tensors and then inserts the resulting dataset into the model 20 | """ 21 | if self.model.tokenizer: 22 | input_ids = self.model.tokenize(samples) 23 | else: 24 | input_ids = samples 25 | 26 | input_ids = list(map(torch.as_tensor, input_ids)) 27 | 28 | states_ixs, actions_ixs = [], [] 29 | dones = [] 30 | for s, s_tok in zip(samples, input_ids): 31 | # split samples on (prompts, continuations) on a given substring `split_token` 32 | if self.split_token: 33 | prompt_str_len = s.index(self.split_token) + len(self.split_token) 34 | prompt_tok_len = len(self.model.tokenizer(s[:prompt_str_len]).input_ids) 35 | # else assume that the prompt is a bos token 36 | else: 37 | prompt_tok_len = 1 38 | 39 | # indices of continuations, to mask prompts in loss computation 40 | a_ixs = torch.arange(prompt_tok_len - 1, len(s_tok) - 1) 41 | # same continuations but for value computation, with the premise to eventually support interleaved dialog 42 | s_ixs = torch.arange(prompt_tok_len - 1, len(s_tok)) 43 | # mask continuation's ending 44 | terminals = torch.ones_like(s_ixs) 45 | terminals[-1] = 0 46 | 47 | actions_ixs.append(a_ixs) 48 | states_ixs.append(s_ixs) 49 | dones.append(terminals) 50 | 51 | if self.model.tokenizer: 52 | prompt = self.model.tokenizer.decode(input_ids[0][: states_ixs[0][1]]) 53 | response = self.model.tokenizer.decode(input_ids[0][states_ixs[0][1] :]) 54 | print("[Sample example]") 55 | print("Prompt: ", prompt) 56 | print("Response: ", response) 57 | 58 | print(f"[Mean reward] {torch.Tensor(rewards).mean():.2f}") 59 | print( 60 | f"[Mean sample length] {torch.mean(torch.Tensor(list(map(len, input_ids)))):.2f}" 61 | ) 62 | 63 | returns = torch.as_tensor(rewards, dtype=torch.float) 64 | returns = (returns - returns.mean()) / (returns.std() + 1e-30) 65 | 66 | rewards = [torch.zeros(x.shape[0]) for x in actions_ixs] 67 | for rs, G in zip(rewards, returns): 68 | rs[-1] = G 69 | 70 | attention_mask = [torch.ones(x.shape[0], dtype=int) for x in input_ids] 71 | 72 | self.model.store = ILQLRolloutStorage( 73 | input_ids, attention_mask, rewards, states_ixs, actions_ixs, dones 74 | ) 75 | -------------------------------------------------------------------------------- /trlx/orchestrator/ppo_orchestrator.py: 
-------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | import torch 4 | from trlx.data.accelerate_base_datatypes import PromptBatch 5 | from trlx.data.ppo_types import PPORLElement 6 | from trlx.model import BaseRLModel 7 | from trlx.orchestrator import Orchestrator, register_orchestrator 8 | from trlx.pipeline import BasePipeline 9 | from trlx.utils import Clock 10 | from trlx.utils.modeling import logprobs_from_logits, RunningMoments 11 | from trlx.model.accelerate_ppo_model import T5AcceleratePPOModel 12 | 13 | from time import time 14 | import ray 15 | 16 | import transformers 17 | 18 | def _add_start_token_to_decoder_ids(decoder_input_ids, decoder_attention_mask): 19 | """Add padding to decoder_input_ids""" 20 | batch_size, seq_len = decoder_input_ids.shape 21 | padding = torch.zeros(batch_size, 1, dtype=decoder_input_ids.dtype).to( 22 | decoder_input_ids.device 23 | ) 24 | decoder_attention_mask = torch.cat( 25 | [1 - padding, decoder_attention_mask], dim=1 # Start token is not masked 26 | ) 27 | 28 | return torch.cat([padding, decoder_input_ids], dim=1), decoder_attention_mask 29 | 30 | GPU_REFERENCE_MODEL = 2 31 | 32 | @register_orchestrator 33 | class PPOOrchestrator(Orchestrator): 34 | """ 35 | Orchestrator that prepares data for PPO training: transforms samples from `pipeline` into `PPOBatch` and pushes them into model's `store` 36 | """ 37 | 38 | def __init__( 39 | self, 40 | model: BaseRLModel, 41 | pipeline: BasePipeline, 42 | reward_fn: Callable, 43 | metric_fn: Callable = None, 44 | chunk_size: int = 512, 45 | ): 46 | self.pipeline = pipeline 47 | self.rl_model = model 48 | self.chunk_size = chunk_size 49 | 50 | self.pipeline_loader = self.pipeline.create_loader( 51 | self.chunk_size, shuffle=True 52 | ) 53 | self.pipeline_loader = self.rl_model.accelerator.prepare(self.pipeline_loader) 54 | self.pipeline_iterator = iter(self.pipeline_loader) 55 | 56 | if not hasattr(self.rl_model.model, "frozen_head"): 57 | self.ref_model = self.rl_model.get_arch(self.rl_model.config) 58 | 59 | self.rl_model.orch = self 60 | self.rl_model.reward_fn = reward_fn 61 | self.rl_model.metric_fn = metric_fn 62 | 63 | self.running = RunningMoments() 64 | self.ref_mean = self.rl_model.config.method.ref_mean 65 | self.ref_std = self.rl_model.config.method.ref_std 66 | 67 | def score(self, samples): 68 | """ 69 | Batched scoring function taking text and generating scalar 70 | """ 71 | return self.rl_model.reward_fn(samples) 72 | 73 | def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): 74 | """ 75 | Takes `num_rollouts` prompts from `pipeline`, samples model, computes KL againts a reference model appends PPOElements to model's `store` 76 | """ 77 | ppo_rl_elements = [] 78 | stats = {} 79 | clock = Clock() 80 | while len(ppo_rl_elements) < num_rollouts: 81 | # Get next batch in prompt dataset and refresh if exhausted 82 | try: 83 | batch: PromptBatch = next(self.pipeline_iterator) 84 | except StopIteration: 85 | self.pipeline_iterator = iter(self.pipeline_loader) 86 | batch = next(self.pipeline_iterator) 87 | 88 | exp_generate_time = time() 89 | samples = self.rl_model.generate(**batch) 90 | stats["time/exp_generate"] = time() - exp_generate_time 91 | 92 | query_tensors = batch.input_ids 93 | response_tensors = samples[:, query_tensors.shape[1] :] 94 | texts = self.rl_model.tokenizer.batch_decode( 95 | samples, skip_special_tokens=True 96 | ) 97 | exp_score_time = time() 98 | scores = torch.tensor( 99 | 
self.score(texts), device=samples.device, dtype=torch.float 100 | ) 101 | stats["time/exp_score"] = time() - exp_score_time 102 | 103 | # store statistics of the initial rollout as reference 104 | if self.ref_mean is None: 105 | self.ref_mean, self.ref_std = scores.mean(), scores.std() 106 | all_scores_mean, all_scores_std = self.running.update(scores) 107 | stats["exp_scores/mean"] = all_scores_mean 108 | stats["exp_scores/std"] = all_scores_std 109 | stats["exp_scores/running_mean"] = self.running.mean 110 | stats["exp_scores/running_std"] = self.running.std 111 | 112 | if self.rl_model.config.method.scale_reward == "running": 113 | scores /= self.running.std 114 | elif self.rl_model.config.method.scale_reward == "ref": 115 | scores /= self.ref_std 116 | 117 | clip_reward = self.rl_model.config.method.cliprange_reward 118 | if clip_reward: 119 | scores = torch.clip(scores, -clip_reward, clip_reward) 120 | 121 | # Precompute logprobs, values 122 | all_tokens, attention_mask, position_ids = self.rl_model.get_model_inputs( 123 | query_tensors.to(response_tensors.device), response_tensors 124 | ) 125 | with torch.no_grad(): 126 | logits, *_, values = self.rl_model.model( 127 | all_tokens, attention_mask=attention_mask, position_ids=position_ids 128 | ) 129 | # TODO(dahoas): When hydra model works need to also support generation on hydra head 130 | if hasattr(self.rl_model.model, "frozen_head"): 131 | ref_logits = self.rl_model.model.forward_hydra( 132 | all_tokens, 133 | attention_mask=attention_mask, 134 | position_ids=position_ids, 135 | return_dict=False, 136 | ) 137 | else: 138 | ref_logits, _, *_ = self.ref_model( 139 | all_tokens.to(self.reference_device), 140 | attention_mask=attention_mask.to(self.reference_device), 141 | position_ids=position_ids.to(self.reference_device), 142 | ) 143 | ref_logits = ref_logits.to(self.rl_model.accelerator.device) 144 | 145 | logprobs = logprobs_from_logits(logits[:, :-1, :], all_tokens[:, 1:]) 146 | ref_logprobs = logprobs_from_logits( 147 | ref_logits[:, :-1, :], all_tokens[:, 1:] 148 | ) 149 | 150 | n = samples.shape[0] 151 | values = values.cpu()[:, :-1] 152 | logprobs = logprobs.cpu() 153 | ref_logprobs = ref_logprobs.cpu() 154 | query_tensors = query_tensors.cpu() 155 | response_tensors = response_tensors.cpu() 156 | 157 | start = query_tensors.shape[1] - 1 158 | ends = start + attention_mask[:, start:].sum(1) 159 | all_values = [values[ix, start : ends[ix]] for ix in range(n)] 160 | all_logprobs = [logprobs[ix, start : ends[ix]] for ix in range(n)] 161 | 162 | # Compute rewards 163 | rewards = -self.rl_model.kl_ctl.value * (logprobs - ref_logprobs) 164 | all_rewards = [None] * n 165 | for ix in range(n): 166 | rs = rewards[ix][start : ends[ix]] 167 | rs[-1] = scores[ix] 168 | all_rewards[ix] = rs 169 | 170 | new_ppo_rl_elements = [ 171 | PPORLElement( 172 | query_tensor=query_tensors[i], 173 | response_tensor=response_tensors[i], 174 | logprobs=all_logprobs[i], 175 | values=all_values[i], 176 | rewards=all_rewards[i], 177 | ) 178 | for i in range(n) 179 | ] 180 | 181 | ppo_rl_elements += new_ppo_rl_elements 182 | exp_time = clock.tick() 183 | 184 | stats["kl_ctl_value"] = self.rl_model.kl_ctl.value 185 | stats["time/exp"] = exp_time 186 | 187 | if not ray.is_initialized(): 188 | self.rl_model.accelerator.log(stats, step=iter_count) 189 | 190 | # Push samples and rewards to model's rollout storage 191 | self.rl_model.push_to_store(ppo_rl_elements) 192 | 193 | @register_orchestrator 194 | class T5PPOOrchestrator(PPOOrchestrator): 195 | 196 | def 
__init__( 197 | self, 198 | model: T5AcceleratePPOModel, 199 | pipeline: BasePipeline, 200 | reward_fn: Callable, 201 | metric_fn: Callable = None, 202 | chunk_size: int = 512, 203 | ): 204 | super().__init__(model, pipeline, reward_fn, metric_fn, chunk_size) 205 | 206 | print(" ===== LOADING REFERENCE MODEL =====") 207 | self.ref_model = transformers.T5ForConditionalGeneration.from_pretrained( 208 | model.config.model.model_path 209 | ) 210 | 211 | self.reference_device = torch.device(f"cuda:{GPU_REFERENCE_MODEL}") 212 | 213 | #if GPU_REFERENCE_MODEL: 214 | print(" ===== MOVING REFERENCE MODEL TO GPU =====") 215 | self.ref_model = self.ref_model.to(torch.bfloat16).to(self.reference_device) 216 | 217 | 218 | def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): 219 | """ 220 | Takes `num_rollouts` prompts from `pipeline`, samples model, computes KL againts a reference model appends PPOElements to model's `store` 221 | """ 222 | ppo_rl_elements = [] 223 | stats = {} 224 | clock = Clock() 225 | while len(ppo_rl_elements) < num_rollouts: 226 | # Get next batch in prompt dataset and refresh if exhausted 227 | try: 228 | batch: PromptBatch = next(self.pipeline_iterator) 229 | except StopIteration: 230 | self.pipeline_iterator = iter(self.pipeline_loader) 231 | batch = next(self.pipeline_iterator) 232 | 233 | exp_generate_time = time() 234 | samples = self.rl_model.generate(**batch, use_cache=True, max_new_tokens=50, do_sample=True) 235 | stats["time/exp_generate"] = time() - exp_generate_time 236 | 237 | query_tensors = batch.input_ids 238 | response_tensors = samples 239 | response_texts = self.rl_model.tokenizer.batch_decode( 240 | samples, skip_special_tokens=True 241 | ) 242 | 243 | query_texts = self.rl_model.tokenizer.batch_decode( 244 | query_tensors, skip_special_tokens=True 245 | ) 246 | 247 | texts = [f"{q} {r}" for q, r in zip(query_texts, response_texts)] 248 | 249 | exp_score_time = time() 250 | scores = torch.tensor( 251 | self.score(texts), device=samples.device, dtype=torch.float 252 | ) 253 | stats["time/exp_score"] = time() - exp_score_time 254 | 255 | # store statistics of the initial rollout as reference 256 | if self.ref_mean is None: 257 | self.ref_mean, self.ref_std = scores.mean(), scores.std() 258 | all_scores_mean, all_scores_std = self.running.update(scores) 259 | stats["exp_scores/mean"] = all_scores_mean 260 | stats["exp_scores/std"] = all_scores_std 261 | stats["exp_scores/running_mean"] = self.running.mean 262 | stats["exp_scores/running_std"] = self.running.std 263 | 264 | if self.rl_model.config.method.scale_reward == "running": 265 | scores /= self.running.std 266 | elif self.rl_model.config.method.scale_reward == "ref": 267 | scores /= self.ref_std 268 | 269 | clip_reward = self.rl_model.config.method.cliprange_reward 270 | if clip_reward: 271 | scores = torch.clip(scores, -clip_reward, clip_reward) 272 | 273 | # Precompute logprobs, values 274 | input_ids, attention_mask, decoder_input_ids, decoder_attention_mask = self.rl_model.get_model_inputs( 275 | query_tensors.to(response_tensors.device), response_tensors 276 | ) 277 | 278 | with torch.no_grad(): 279 | outputs = self.rl_model.model( 280 | input_ids, attention_mask, decoder_input_ids, decoder_attention_mask 281 | ) 282 | logits, values = outputs.logits, outputs.value 283 | ref_logits = self.ref_model( 284 | input_ids=input_ids.to(self.reference_device), 285 | attention_mask=attention_mask.to(self.reference_device), 286 | decoder_input_ids=decoder_input_ids.to(self.reference_device), 
287 | decoder_attention_mask=decoder_attention_mask.to(self.reference_device), 288 | ).logits 289 | 290 | ref_logits = ref_logits.to(self.rl_model.accelerator.device) 291 | 292 | logprobs = logprobs_from_logits(logits[:, :-1, :], decoder_input_ids[:, 1:]) 293 | ref_logprobs = logprobs_from_logits( 294 | ref_logits[:, :-1, :], decoder_input_ids[:, 1:] 295 | ) 296 | 297 | n = samples.shape[0] 298 | values = values.cpu()[:, :-1] 299 | logprobs = logprobs.cpu() 300 | ref_logprobs = ref_logprobs.cpu() 301 | query_tensors = query_tensors.cpu() 302 | response_tensors = response_tensors.cpu() 303 | 304 | all_values = [values[ix] for ix in range(n)] 305 | all_logprobs = [logprobs[ix] for ix in range(n)] 306 | 307 | # Compute rewards 308 | rewards = -self.rl_model.kl_ctl.value * (logprobs - ref_logprobs) 309 | all_rewards = [None] * n 310 | for ix in range(n): 311 | rs = rewards[ix] 312 | rs[-1] = scores[ix] 313 | all_rewards[ix] = rs 314 | 315 | new_ppo_rl_elements = [ 316 | PPORLElement( 317 | query_tensor=query_tensors[i], 318 | response_tensor=response_tensors[i], 319 | logprobs=all_logprobs[i], 320 | values=all_values[i], 321 | rewards=all_rewards[i], 322 | ) 323 | for i in range(n) 324 | ] 325 | 326 | ppo_rl_elements += new_ppo_rl_elements 327 | exp_time = clock.tick() 328 | 329 | stats["kl_ctl_value"] = self.rl_model.kl_ctl.value 330 | stats["time/exp"] = exp_time 331 | 332 | if not ray.is_initialized(): 333 | self.rl_model.accelerator.log(stats, step=iter_count) 334 | 335 | # Push samples and rewards to model's rollout storage 336 | self.rl_model.push_to_store(ppo_rl_elements) 337 | -------------------------------------------------------------------------------- /trlx/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | import random 2 | import sys 3 | from abc import abstractmethod, abstractstaticmethod 4 | from typing import Any, Callable, Dict, Iterable 5 | 6 | from datasets import load_from_disk 7 | from torch.utils.data import DataLoader, Dataset 8 | 9 | from trlx.data import GeneralElement, RLElement 10 | 11 | # specifies a dictionary of architectures 12 | _DATAPIPELINE: Dict[str, any] = {} # registry 13 | 14 | 15 | def register_datapipeline(name): 16 | """Decorator used register a CARP architecture 17 | Args: 18 | name: Name of the architecture 19 | """ 20 | 21 | def register_class(cls, name): 22 | _DATAPIPELINE[name] = cls 23 | setattr(sys.modules[__name__], name, cls) 24 | return cls 25 | 26 | if isinstance(name, str): 27 | name = name.lower() 28 | return lambda c: register_class(c, name) 29 | 30 | cls = name 31 | name = cls.__name__ 32 | register_class(cls, name.lower()) 33 | 34 | return cls 35 | 36 | 37 | @register_datapipeline 38 | class BasePipeline(Dataset): 39 | def __init__(self, path: str = "dataset"): 40 | super().__init__() 41 | 42 | @abstractmethod 43 | def __getitem__(self, index: int) -> GeneralElement: 44 | pass 45 | 46 | @abstractmethod 47 | def __len__(self) -> int: 48 | pass 49 | 50 | @abstractmethod 51 | def create_loader( 52 | self, 53 | batch_size: int, 54 | shuffle: bool, 55 | prep_fn: Callable = None, 56 | num_workers: int = 0, 57 | ) -> DataLoader: 58 | """ 59 | Create a dataloader for the pipeline 60 | 61 | :param prep_fn: Typically a tokenizer. Applied to GeneralElement after collation. 
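:type prep_fn: Callable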
62 | """ 63 | pass 64 | 65 | 66 | class BaseRolloutStore(Dataset): 67 | def __init__(self, capacity=-1): 68 | self.history: Iterable[Any] = None 69 | self.capacity = capacity 70 | 71 | @abstractmethod 72 | def push(self, exps: Iterable[Any]): 73 | """ 74 | Push experiences to rollout storage 75 | """ 76 | pass 77 | 78 | def __getitem__(self, index: int) -> RLElement: 79 | return self.history[index] 80 | 81 | def __len__(self) -> int: 82 | return len(self.history) 83 | 84 | @abstractmethod 85 | def create_loader( 86 | self, 87 | batch_size: int, 88 | shuffle: bool, 89 | prep_fn: Callable = None, 90 | num_workers: int = 0, 91 | ) -> DataLoader: 92 | """ 93 | Create a dataloader for the rollout store 94 | 95 | :param prep_fn: Applied to RLElement after collation (typically tokenizer) 96 | :type prep_fn: Callable 97 | """ 98 | pass 99 | -------------------------------------------------------------------------------- /trlx/pipeline/offline_pipeline.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, List 2 | 3 | import torch 4 | from torch.nn.utils.rnn import pad_sequence 5 | from torch.utils.data import DataLoader 6 | from transformers import DataCollatorWithPadding 7 | 8 | from trlx.data.ilql_types import ILQLBatch, ILQLElement 9 | from trlx.pipeline import BasePipeline, BaseRolloutStore, register_datapipeline 10 | 11 | 12 | @register_datapipeline 13 | class PromptPipeline(BasePipeline): 14 | """ 15 | Tokenizes prompts, unless they are already tokenized, and truncates them to `max_prompt_length` from the right 16 | """ 17 | 18 | def __init__(self, prompts: List[str], max_prompt_length: int, tokenizer=None): 19 | super().__init__() 20 | 21 | if tokenizer: 22 | prompts = tokenizer(prompts).input_ids 23 | 24 | self.tokenizer = tokenizer 25 | self.prompts = [prompt[-max_prompt_length:] for prompt in prompts] 26 | self.prompts = [ 27 | {"input_ids": prompt, "attention_mask": [1] * len(prompt)} 28 | for prompt in self.prompts 29 | ] 30 | 31 | def __getitem__(self, ix: int): 32 | return self.prompts[ix] 33 | 34 | def __len__(self) -> int: 35 | return len(self.prompts) 36 | 37 | def create_loader(self, batch_size: int, shuffle=False) -> DataLoader: 38 | collate_fn = ( 39 | DataCollatorWithPadding(self.tokenizer) if self.tokenizer else torch.vstack 40 | ) 41 | return DataLoader( 42 | self, batch_size=batch_size, collate_fn=collate_fn, shuffle=shuffle 43 | ) 44 | 45 | 46 | class ILQLRolloutStorage(BaseRolloutStore): 47 | """ 48 | Rollout storage for training ILQL 49 | """ 50 | 51 | def __init__( 52 | self, input_ids, attention_mask, rewards, states_ixs, actions_ixs, dones 53 | ): 54 | super().__init__() 55 | 56 | self.input_ids = input_ids 57 | self.attention_mask = attention_mask 58 | self.rewards = rewards 59 | self.states_ixs = states_ixs 60 | self.actions_ixs = actions_ixs 61 | self.dones = dones 62 | 63 | def __getitem__(self, ix: int) -> ILQLElement: 64 | return ILQLElement( 65 | self.input_ids[ix], 66 | self.attention_mask[ix], 67 | self.rewards[ix], 68 | self.states_ixs[ix], 69 | self.actions_ixs[ix], 70 | self.dones[ix], 71 | ) 72 | 73 | def __len__(self) -> int: 74 | return len(self.input_ids) 75 | 76 | def create_loader(self, batch_size: int): 77 | def collate_fn(elems: Iterable[ILQLElement]): 78 | return ILQLBatch( 79 | pad_sequence( 80 | [x.input_ids for x in elems], batch_first=True, padding_value=0 81 | ), 82 | pad_sequence( 83 | [x.attention_mask for x in elems], batch_first=True, padding_value=0 84 | ), 85 | 
pad_sequence( 86 | [x.rewards for x in elems], batch_first=True, padding_value=0.0 87 | ), 88 | pad_sequence( 89 | [x.states_ixs for x in elems], batch_first=True, padding_value=0 90 | ), 91 | pad_sequence( 92 | [x.actions_ixs for x in elems], batch_first=True, padding_value=0 93 | ), 94 | pad_sequence( 95 | [x.dones for x in elems], batch_first=True, padding_value=0 96 | ), 97 | ) 98 | 99 | return DataLoader( 100 | self, batch_size=batch_size, shuffle=True, collate_fn=collate_fn 101 | ) 102 | -------------------------------------------------------------------------------- /trlx/pipeline/ppo_pipeline.py: -------------------------------------------------------------------------------- 1 | import os, json, time 2 | 3 | from typing import Iterable, Optional 4 | 5 | from torch.nn.utils.rnn import pad_sequence 6 | from torch.utils.data import DataLoader 7 | from torchtyping import TensorType 8 | 9 | from trlx.data.ppo_types import PPORLBatch, PPORLElement 10 | from trlx.pipeline import BaseRolloutStore 11 | 12 | 13 | class PPORolloutStorage(BaseRolloutStore): 14 | """ 15 | Rollout storage for training PPO 16 | """ 17 | 18 | def __init__(self, pad_token_id): 19 | super().__init__() 20 | 21 | self.pad_token_id = pad_token_id 22 | self.history: Iterable[PPORLElement] = [None] 23 | 24 | def push(self, exps: Iterable[PPORLElement]): 25 | self.history += exps 26 | 27 | def clear_history(self): 28 | self.history = [] 29 | 30 | def export_history(self, location: str): 31 | assert os.path.exists(location) 32 | 33 | fpath = os.path.join(location, f"epoch-{str(time.time())}.json") 34 | exp_to_dict = lambda exp: {k: v.cpu().tolist() for k, v in exp.__dict__.items()} 35 | data = [exp_to_dict(exp) for exp in self.history] 36 | with open(fpath, "w") as f: 37 | f.write(json.dumps(data, indent=2)) 38 | 39 | def __getitem__(self, index: int) -> PPORLElement: 40 | return self.history[index] 41 | 42 | def __len__(self) -> int: 43 | return len(self.history) 44 | 45 | def create_loader( 46 | self, 47 | batch_size: int, 48 | shuffle: bool, 49 | ) -> DataLoader: 50 | def collate_fn(elems: Iterable[PPORLElement]): 51 | return PPORLBatch( 52 | # Left padding of already left-padded queries 53 | pad_sequence( 54 | [elem.query_tensor.flip(0) for elem in elems], 55 | padding_value=self.pad_token_id, 56 | batch_first=True, 57 | ).flip(1), 58 | # Right pad the rest, to have a single horizontal query/response split 59 | pad_sequence( 60 | [elem.response_tensor for elem in elems], 61 | padding_value=self.pad_token_id, 62 | batch_first=True, 63 | ), 64 | pad_sequence( 65 | [elem.logprobs for elem in elems], 66 | padding_value=0.0, 67 | batch_first=True, 68 | ), 69 | pad_sequence( 70 | [elem.values for elem in elems], padding_value=0.0, batch_first=True 71 | ), 72 | pad_sequence( 73 | [elem.rewards for elem in elems], 74 | padding_value=0.0, 75 | batch_first=True, 76 | ), 77 | ) 78 | 79 | return DataLoader(self, batch_size, shuffle=shuffle, collate_fn=collate_fn) 80 | -------------------------------------------------------------------------------- /trlx/ray_tune/__init__.py: -------------------------------------------------------------------------------- 1 | from ray import tune 2 | 3 | 4 | def get_param_space(config: dict): 5 | """Get the param space from the config file.""" 6 | 7 | def get_strategy(value): 8 | """Get search space strategy from config. 9 | A search space defines valid values for your hyperparameters and 10 | can specify how these values are sampled. 
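For example, a hypothetical sweep-config entry such as::

    lr_init:
        strategy: "loguniform"
        values: [0.00001, 0.01]

would be converted below into ``tune.loguniform(0.00001, 0.01)``.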
11 | 12 | Refer to the documentation for more info: 13 | https://docs.ray.io/en/latest/tune/api_docs/search_space.html#tune-sample-docs 14 | 15 | The user will have to define the search space in the config file by providing 16 | the name of the `strategy` and the `values` to sample from. 17 | 18 | The valid strategies are: 19 | - `uniform` (List) - Samples uniformly between the given bounds. 20 | - `quniform` (List) - Samples uniformly between the given bounds, quantized. 21 | - `loguniform` (List) - Samples uniformly between the given bounds on a log scale. 22 | - `qloguniform` (List) - Samples uniformly between the given bounds on a log scale, quantized. 23 | - `randn` (List) - Samples from a normal distribution. 24 | - `qrandn` (List) - Samples from a normal distribution, quantized. 25 | - `randint` (List) - Samples integers uniformly between the given bounds. 26 | - `qrandint` (List) - Samples integers uniformly between the given bounds, quantized. 27 | - `lograndint` (List) - Samples integers log-uniformly between the given bounds. 28 | - `qlograndint` (List) - Samples integers log-uniformly between the given bounds, quantized. 29 | - `choice` (List) - Samples from a discrete set of values. 30 | - `grid` (List) - Performs a grid search over the given list of values. 31 | 32 | 33 | """ 34 | 35 | strategy = value["strategy"] 36 | if strategy == "uniform": 37 | assert isinstance(value["values"], list) 38 | assert len(value["values"]) == 2 39 | return tune.uniform(*value["values"]) 40 | elif strategy == "quniform": 41 | assert isinstance(value["values"], list) 42 | assert len(value["values"]) == 3 43 | return tune.quniform(*value["values"]) 44 | elif strategy == "loguniform": 45 | assert isinstance(value["values"], list) 46 | assert 2 <= len(value["values"]) <= 3 47 | return tune.loguniform(*value["values"]) 48 | elif strategy == "qloguniform": 49 | assert isinstance(value["values"], list) 50 | assert len(value["values"]) == 4 51 | return tune.qloguniform(*value["values"]) 52 | elif strategy == "randn": 53 | assert isinstance(value["values"], list) 54 | assert len(value["values"]) == 2 55 | return tune.randn(*value["values"]) 56 | elif strategy == "qrandn": 57 | assert isinstance(value["values"], list) 58 | assert len(value["values"]) == 3 59 | return tune.qrandn(*value["values"]) 60 | elif strategy == "randint": 61 | assert isinstance(value["values"], list) 62 | assert len(value["values"]) == 2 63 | return tune.randint(*value["values"]) 64 | elif strategy == "qrandint": 65 | assert isinstance(value["values"], list) 66 | assert len(value["values"]) == 3 67 | return tune.qrandint(*value["values"]) 68 | elif strategy == "lograndint": 69 | assert isinstance(value["values"], list) 70 | assert len(value["values"]) == 3 71 | return tune.lograndint(*value["values"]) 72 | elif strategy == "qlograndint": 73 | assert isinstance(value["values"], list) 74 | assert len(value["values"]) == 4 75 | return tune.qlograndint(*value["values"]) 76 | elif strategy == "choice": 77 | assert isinstance(value["values"], list) 78 | return tune.choice(value["values"]) 79 | elif strategy == "grid": 80 | assert isinstance(value["values"], list) 81 | return tune.grid_search(value["values"]) 82 | 83 | for k, v in config.items(): 84 | if k != "tune_config": 85 | config[k] = get_strategy(v) 86 | 87 | return config 88 | 89 | 90 | def get_search_alg(tune_config: dict): 91 | """Initialize the search algorithm
and return it. 92 | 93 | Bayesian Optimization (`bayesopt`), BOHB (`bohb`), and random search (`random`) are currently supported. 94 | """ 95 | search_alg = tune_config["search_alg"] 96 | 97 | if search_alg == "bayesopt": 98 | try: 99 | from ray.tune.search.bayesopt import BayesOptSearch 100 | except ImportError: 101 | raise ImportError( 102 | "Please pip install bayesian-optimization to use BayesOptSearch." 103 | ) 104 | 105 | assert "metric" in tune_config.keys() and "mode" in tune_config.keys(), \ 106 | "Please specify metric and mode for BayesOptSearch." 107 | 108 | return BayesOptSearch(metric=tune_config["metric"], mode=tune_config["mode"]) 109 | elif search_alg == "bohb": 110 | try: 111 | from ray.tune.search.bohb import TuneBOHB 112 | except ImportError: 113 | raise ImportError( 114 | "Please pip install hpbandster and ConfigSpace to use TuneBOHB." 115 | ) 116 | 117 | assert "metric" in tune_config.keys() and "mode" in tune_config.keys(), \ 118 | "Please specify metric and mode for TuneBOHB." 119 | 120 | return TuneBOHB() 121 | elif search_alg == "random": 122 | return None 123 | else: 124 | raise NotImplementedError("Search algorithm not supported.") 125 | 126 | 127 | def get_scheduler(tune_config: dict): 128 | """Initialize the scheduler and return it. 129 | 130 | The schedulers can early terminate bad trials, pause trials, 131 | clone trials, and alter hyperparameters of a running trial. 132 | 133 | Refer to the documentation for more info: 134 | https://docs.ray.io/en/latest/tune/api_docs/schedulers.html#tune-schedulers 135 | 136 | Currently available schedulers are: 137 | - `hyperband` - Implements the HyperBand early stopping algorithm (`hyperbandforbohb` is the variant required by BOHB; `fifo` falls back to the default first-in-first-out behaviour). 138 | 139 | """ 140 | scheduler = tune_config["scheduler"] 141 | 142 | if scheduler == "hyperband": 143 | return tune.schedulers.HyperBandScheduler() 144 | elif scheduler == "hyperbandforbohb": 145 | return tune.schedulers.HyperBandForBOHB() 146 | elif scheduler == "fifo": 147 | return None 148 | else: 149 | raise NotImplementedError("Scheduler not supported.") 150 | 151 | 152 | def get_tune_config(tune_config: dict): 153 | """Get the tune config used to initialize `tune.TuneConfig`, 154 | which is passed to `tune.Tuner`. 155 | """ 156 | if "search_alg" in tune_config.keys() and tune_config["search_alg"] is not None: 157 | tune_config["search_alg"] = get_search_alg(tune_config) 158 | 159 | if "scheduler" in tune_config.keys() and tune_config["scheduler"] is not None: 160 | tune_config["scheduler"] = get_scheduler(tune_config) 161 | 162 | # Remove config keys with None values. 163 | tune_config = {k: v for k, v in tune_config.items() if v is not None} 164 | 165 | return tune_config 166 | -------------------------------------------------------------------------------- /trlx/ray_tune/train_funcs.py: -------------------------------------------------------------------------------- 1 | # Find the optimal hyperparameters to generate positive movie 2 | # reviews by tuning a model pretrained on IMDB with a sentiment reward function.
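# A hypothetical usage sketch (illustrative only): each Ray Tune trial calls
# `ppo_sentiments_train` with a plain config dict whose layout is assumed to match
# what `TRLConfig.from_dict` expects, e.g.
#
#     import yaml
#     with open("configs/ppo_config.yml") as f:
#         ppo_sentiments_train(yaml.safe_load(f))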
3 | 4 | from datasets import load_dataset 5 | 6 | import trlx 7 | from trlx.data.configs import TRLConfig 8 | 9 | 10 | def ppo_sentiments_train(config: dict): 11 | from transformers import pipeline 12 | 13 | config = TRLConfig.from_dict(config) 14 | 15 | sentiment_fn = pipeline("sentiment-analysis", "lvwerra/distilbert-imdb", device=-1) 16 | 17 | def reward_fn(samples): 18 | outputs = sentiment_fn(samples, return_all_scores=True) 19 | sentiments = [output[1]["score"] for output in outputs] 20 | return sentiments 21 | 22 | # Take few words off of movies reviews as prompts 23 | imdb = load_dataset("imdb", split="train+test") 24 | prompts = [" ".join(review.split()[:4]) for review in imdb["text"]] 25 | 26 | model = trlx.train( 27 | "lvwerra/gpt2-imdb", 28 | reward_fn=reward_fn, 29 | prompts=prompts, 30 | eval_prompts=["I don't know much about Hungarian underground"] * 64, 31 | config=config, 32 | ) 33 | -------------------------------------------------------------------------------- /trlx/ray_tune/wandb.py: -------------------------------------------------------------------------------- 1 | """Utility function to log the results of a Ray Tune experiment to W&B.""" 2 | 3 | import os 4 | import json 5 | import pandas as pd 6 | from pathlib import Path 7 | import math 8 | 9 | import wandb 10 | 11 | wandb.require("report-editing") 12 | import wandb.apis.reports as wb 13 | 14 | ray_info = [ 15 | "done", 16 | "time_this_iter_s", 17 | "timesteps_total", 18 | "episodes_total", 19 | "iterations_since_restore", 20 | "timesteps_since_restore", 21 | "time_since_restore", 22 | "warmup_time", 23 | "should_checkpoint", 24 | "training_iteration", 25 | "timestamp", 26 | "pid", 27 | ] 28 | 29 | 30 | def parse_result(result): 31 | out = {} 32 | for k, v in result.items(): 33 | if ( 34 | isinstance(v, (int, float)) 35 | and not k.startswith("config.") 36 | and k not in ray_info 37 | ): 38 | out[k] = v 39 | 40 | return out 41 | 42 | 43 | def significant(x): 44 | return round(x, 1 - int(math.floor(math.log10(x)))) 45 | 46 | 47 | def log_trials(trial_path: str, project_name: str): 48 | trial_path = Path(trial_path) 49 | files = os.listdir(trial_path) 50 | 51 | trial_paths = [] 52 | for filename in files: 53 | tmp_path = os.path.join(trial_path, filename) 54 | if os.path.isdir(tmp_path): 55 | trial_paths.append(tmp_path) 56 | 57 | for trial in trial_paths: 58 | files = os.listdir(trial) 59 | 60 | # Open params.json and load the configs for that trial. 61 | with open(os.path.join(trial, "params.json"), "r") as f: 62 | params = json.load(f) 63 | 64 | name = ",".join(f"{k}={significant(v)}" for k, v in params.items()) 65 | # Initialize wandb 66 | run = wandb.init( 67 | name=name, 68 | project=project_name, 69 | config=params, 70 | group=trial_path.stem, 71 | job_type="hyperopt", 72 | ) 73 | 74 | # Open result.json and log the metrics to W&B. 75 | with open(os.path.join(trial, "result.json"), "r") as f: 76 | for line in f: 77 | result = json.loads(line) 78 | result.pop("config", None) 79 | wandb.log(parse_result(result)) 80 | 81 | # Close the W&B run. 
82 | run.finish() 83 | 84 | 85 | def create_report(project_name, param_space, tune_config, trial_path, best_config=None): 86 | def get_parallel_coordinate(param_space, metric): 87 | column_names = list(param_space.keys()) 88 | columns = [wb.reports.PCColumn(column) for column in column_names] 89 | 90 | return wb.ParallelCoordinatesPlot( 91 | columns=columns + [wb.reports.PCColumn(metric)], 92 | layout={"x": 0, "y": 0, "w": 12 * 2, "h": 5 * 2}, 93 | ) 94 | 95 | def get_param_importance(metric): 96 | return wb.ParameterImportancePlot( 97 | # Get it from the metric name. 98 | with_respect_to=metric, 99 | layout={"x": 0, "y": 5, "w": 6 * 2, "h": 4 * 2}, 100 | ) 101 | 102 | def get_scatter_plot(metric): 103 | return wb.ScatterPlot( 104 | # Get it from the metric name. 105 | title=f"{metric} v. Index", 106 | x="Index", 107 | y=metric, 108 | running_ymin=True, 109 | font_size="small", 110 | layout={"x": 6, "y": 5, "w": 6 * 2, "h": 4 * 2}, 111 | ) 112 | 113 | def get_metrics_with_history(project_name, group_name, entity=None): 114 | entity_project = f"{entity}/{project_name}" if entity else project_name 115 | api = wandb.Api() 116 | runs = api.runs(entity_project) 117 | 118 | runs = sorted( 119 | runs, 120 | key=lambda run: run.summary.get(tune_config["metric"], -math.inf), 121 | reverse=True, 122 | ) 123 | 124 | for run in runs: 125 | if run.group == str(group_name): 126 | history = run.history() 127 | metrics = history.columns 128 | break 129 | 130 | metrics = [metric for metric in metrics if not metric.startswith("_")] 131 | return metrics 132 | 133 | report = wb.Report( 134 | project=project_name, 135 | title=f"Hyperparameter Optimization Report: {trial_path}", 136 | description="This is a report that shows the results of a hyperparameter optimization experiment.", 137 | ) 138 | 139 | report.blocks = [ 140 | wb.P( 141 | "The following plots show the results of the hyperparameter optimization experiment. " 142 | "Use this as a starting point for your analysis. Go in the edit mode to customize the report. " 143 | "Share it with your team to collaborate on the analysis." 144 | ), 145 | wb.H1(text="Analysis"), 146 | wb.P( 147 | "Parallel coordinates chart (top) summarize the relationship between large numbers of hyperparameters " 148 | "and model metrics at a glance. \nThe scatter plot (right) compares the different trials and gives you a " 149 | "insight on how the trials progressed. \nThe parameter importance plot(left) lists the hyperparameters " 150 | "that were the best predictors of, and highly correlated to desirable values of your metrics." 
151 | ), 152 | wb.PanelGrid( 153 | panels=[ 154 | get_parallel_coordinate(param_space, tune_config["metric"]), 155 | get_param_importance(tune_config["metric"]), 156 | get_scatter_plot(tune_config["metric"]), 157 | ], 158 | runsets=[ 159 | wb.RunSet(project=project_name).set_filters_with_python_expr( 160 | f'group == "{trial_path}"' 161 | ) 162 | ], 163 | ), 164 | ] 165 | 166 | metrics = get_metrics_with_history( 167 | project_name, 168 | trial_path, 169 | ) 170 | 171 | line_plot_panels = [] 172 | for metric in metrics: 173 | line_plot_panels.append( 174 | wb.LinePlot( 175 | title=f"{metric}", 176 | x="Step", 177 | y=[f"{metric}"], 178 | title_x="Step", 179 | smoothing_show_original=True, 180 | max_runs_to_show=10, 181 | plot_type="line", 182 | font_size="auto", 183 | legend_position="north", 184 | ) 185 | ) 186 | 187 | report.blocks = report.blocks + [ 188 | wb.H1(text="Metrics"), 189 | wb.P( 190 | "The following line plots show the metrics for each trial. Use this to investigate the " 191 | "performance of the model for each trial at the metrics level." 192 | ), 193 | wb.PanelGrid( 194 | panels=line_plot_panels, 195 | runsets=[ 196 | wb.RunSet(project=project_name).set_filters_with_python_expr( 197 | f'group == "{trial_path}"' 198 | ) 199 | ], 200 | ), 201 | ] 202 | 203 | if best_config: 204 | report.blocks = report.blocks + [ 205 | wb.H1(text="Best Config"), 206 | wb.P( 207 | "The code block shown below is the best config found by the hyperparameter " 208 | "optimization experiment according to Ray Tune." 209 | ), 210 | wb.CodeBlock(code=[json.dumps(best_config, indent=4)], language="json"), 211 | ] 212 | 213 | report.save() 214 | print(report.url) 215 | -------------------------------------------------------------------------------- /trlx/sweep.py: -------------------------------------------------------------------------------- 1 | # python -m trlx.sweep --config configs/sweeps/ppo_sweep.yml examples/ppo_sentiments.py 2 | import wandb 3 | import argparse 4 | from pathlib import Path 5 | 6 | import ray 7 | from ray.air import session 8 | from ray import tune 9 | import importlib 10 | import yaml 11 | 12 | import trlx 13 | from trlx.ray_tune import get_param_space 14 | from trlx.ray_tune import get_tune_config 15 | from trlx.ray_tune.wandb import log_trials, create_report 16 | 17 | from ray.tune.logger import JsonLoggerCallback 18 | from ray.tune.logger import CSVLoggerCallback 19 | 20 | 21 | def tune_function( 22 | train_function, param_space: dict, tune_config: dict, resources: dict 23 | ): 24 | tuner = tune.Tuner( 25 | tune.with_resources(train_function, resources=resources), 26 | param_space=param_space, 27 | tune_config=tune.TuneConfig(**tune_config), 28 | run_config=ray.air.RunConfig( 29 | local_dir="ray_results", callbacks=[CSVLoggerCallback()] 30 | ), 31 | ) 32 | 33 | results = tuner.fit() 34 | project_name = tune_config.get("project_name", "sweep") 35 | 36 | log_trials( 37 | tuner._local_tuner.get_experiment_checkpoint_dir(), 38 | project_name, 39 | ) 40 | 41 | create_report( 42 | project_name, 43 | param_space, 44 | tune_config, 45 | Path(tuner._local_tuner.get_experiment_checkpoint_dir()).stem, 46 | results.get_best_result().config, 47 | ) 48 | 49 | print("Best hyperparameters found were: ", results.get_best_result().config) 50 | 51 | 52 | if __name__ == "__main__": 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument("script", type=str, help="Path to the script") 55 | parser.add_argument( 56 | "--config", 57 | type=str, 58 | required=True, 59 | help="The config 
file defining the param_space.", 60 | ) 61 | parser.add_argument( 62 | "--num-cpus", type=int, default=4, help="Number of CPUs to use per exp." 63 | ) 64 | parser.add_argument( 65 | "--num-gpus", type=int, default=1, help="Number of GPUs to use per exp." 66 | ) 67 | parser.add_argument( 68 | "-y", "--assume-yes", action="store_true", help="Don't ask for confirmation" 69 | ) 70 | parser.add_argument( 71 | "--server-address", 72 | type=str, 73 | default=None, 74 | required=False, 75 | help="The address of server to connect to if using Ray Client.", 76 | ) 77 | 78 | args, _ = parser.parse_known_args() 79 | 80 | # Read config and parse it 81 | with open(args.config) as f: 82 | config = yaml.safe_load(f) 83 | tune_config = get_tune_config(config.pop("tune_config")) 84 | param_space = get_param_space(config) 85 | 86 | # Initialize Ray. 87 | if args.server_address: 88 | ray.init(address=f"ray://{args.server_address}") 89 | else: 90 | ray.init() 91 | 92 | resources = { 93 | "cpu": args.num_cpus, 94 | "gpu": args.num_gpus, 95 | } 96 | 97 | print(f'WARNING: Importing main from "{args.script}" and everything along with it') 98 | 99 | if not args.assume_yes: 100 | print("Please confirm y/n: ", end="") 101 | if input() != "y": 102 | print("Exiting") 103 | exit(1) 104 | 105 | # convert a nested path to a module path 106 | script_path = args.script.replace(".py", "").replace("/", ".") 107 | script = importlib.import_module(script_path) 108 | # Register the training function that will be used for training the model. 109 | tune.register_trainable("train_function", script.main) 110 | tune_function(script.main, param_space, tune_config, resources) 111 | 112 | # Shut down Ray. 113 | ray.shutdown() 114 | -------------------------------------------------------------------------------- /trlx/trlx.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Callable, Iterable, List, Optional, Tuple 3 | 4 | from trlx.data.configs import TRLConfig 5 | from trlx.utils import set_seed 6 | from trlx.utils.loading import get_model, get_orchestrator, get_pipeline 7 | 8 | def train( 9 | model_path: Optional[str] = None, 10 | reward_fn: Optional[Callable] = None, 11 | dataset: Optional[Iterable[Tuple[str, float]]] = None, 12 | prompts: Optional[List[str]] = None, 13 | eval_prompts: Optional[List[str]] = None, 14 | metric_fn: Optional[Callable] = None, 15 | config: Optional[TRLConfig] = None, 16 | split_token: Optional[str] = None, 17 | logit_mask: Optional[List[List[bool]]] = None, 18 | ): 19 | """ 20 | Dispatches online or offline reinforcement training depending on whether a reward function or a list of samples & rewards is given 21 | 22 | Args: 23 | model_path (Optional[str]): Path to either huggingface checkpoint or a local directory 24 | reward_fn (List[str] -> List[float]): Function to rate batches of generated samples 25 | dataset (List[str], List[float]): Lists of samples and rewards 26 | prompts (List[str]): Prompts to sample off from during online training 27 | eval_prompts (List[str]): Prompts to periodically validate training on 28 | metric_fn (Optional[Callable[List[str], List[float]]]): Function to compute statistics on validation samples 29 | config (Optional[TRLConfig]): TRL configuration object to override default settings 30 | split_token (Optional[str]): Split samples in the dataset on prompts and continuations 31 | logit_mask (Optional[List]): Bigram masking matrix 32 | """ 33 | if reward_fn is not None: 34 | if config is None: 35 | config = 
TRLConfig.load_yaml("configs/ppo_config.yml") 36 | set_seed(config.train.seed) 37 | 38 | if model_path: 39 | config.model.model_path = model_path 40 | 41 | model = get_model(config.model.model_type)(config) 42 | 43 | batch_size = config.train.batch_size * int(os.environ.get("WORLD_SIZE", 1)) 44 | prompts = prompts or [model.tokenizer.bos_token] * batch_size 45 | 46 | if eval_prompts is None: 47 | eval_prompts = prompts[:batch_size] 48 | 49 | max_prompt_length = ( 50 | config.train.seq_length - config.method.gen_kwargs["max_new_tokens"] 51 | ) 52 | pipeline = get_pipeline(config.train.pipeline)( 53 | prompts, max_prompt_length, model.tokenizer 54 | ) 55 | orch = get_orchestrator(config.train.orchestrator)( 56 | model, pipeline, reward_fn=reward_fn, chunk_size=config.method.chunk_size 57 | ) 58 | orch.make_experience(config.method.num_rollouts) 59 | eval_pipeline = get_pipeline(config.train.pipeline)( 60 | eval_prompts, max_prompt_length, model.tokenizer 61 | ) 62 | model.add_eval_pipeline(eval_pipeline) 63 | 64 | elif dataset is not None: 65 | samples, rewards = dataset 66 | 67 | if len(samples) != len(rewards): 68 | raise ValueError( 69 | f"Number of samples {len(samples)} should match the number of rewards {len(rewards)}" 70 | ) 71 | 72 | if config is None: 73 | config = TRLConfig.load_yaml("configs/ilql_config.yml") 74 | set_seed(config.train.seed) 75 | 76 | if model_path: 77 | config.model.model_path = model_path 78 | 79 | model = get_model(config.model.model_type)( 80 | config=config, 81 | logit_mask=logit_mask, 82 | metric_fn=metric_fn, 83 | ) 84 | 85 | batch_size = config.train.batch_size * int(os.environ.get("WORLD_SIZE", 1)) 86 | max_prompt_length = ( 87 | config.train.seq_length - config.method.gen_kwargs["max_new_tokens"] 88 | ) 89 | 90 | if eval_prompts is None: 91 | eval_prompts = [model.tokenizer.bos_token] * batch_size 92 | eval_pipeline = get_pipeline(config.train.pipeline)( 93 | eval_prompts, max_prompt_length, model.tokenizer 94 | ) 95 | 96 | orch = get_orchestrator(config.train.orchestrator)( 97 | model, split_token=split_token 98 | ) 99 | orch.make_experience(samples, rewards) 100 | model.add_eval_pipeline(eval_pipeline) 101 | 102 | else: 103 | raise ValueError(f"Either {dataset=} or {reward_fn=} should be given") 104 | 105 | model.learn() 106 | return model 107 | -------------------------------------------------------------------------------- /trlx/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import time 4 | from enum import Enum 5 | from functools import reduce 6 | from typing import Any, Iterable, List, Dict 7 | from dataclasses import is_dataclass 8 | import subprocess 9 | 10 | import numpy as np 11 | import torch 12 | from torch.optim.lr_scheduler import ChainedScheduler, LinearLR 13 | from torchtyping import TensorType 14 | 15 | import accelerate 16 | from accelerate import Accelerator 17 | 18 | 19 | def set_seed(seed: int): 20 | """ 21 | Sets seeds across package dependencies for reproducibility. 22 | """ 23 | seed += int(os.environ.get("RANK", 0)) 24 | random.seed(seed) 25 | np.random.seed(seed) 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed(seed) 28 | 29 | 30 | def flatten(L: Iterable[Iterable[Any]]) -> Iterable[Any]: 31 | """ 32 | Flatten a list of lists into a single list (i.e. 
[[1, 2], [3, 4]] -> [1,2,3,4]) 33 | """ 34 | return list(reduce(lambda acc, x: acc + x, L, [])) 35 | 36 | 37 | def chunk(L: Iterable[Any], chunk_size: int) -> List[Iterable[Any]]: 38 | """ 39 | Chunk iterable into list of iterables of given chunk size 40 | """ 41 | return [L[i : i + chunk_size] for i in range(0, len(L), chunk_size)] 42 | 43 | 44 | # Training utils 45 | 46 | 47 | def rampup_decay(ramp_steps, decay_steps, decay_target, opt): 48 | return ChainedScheduler( 49 | [ 50 | LinearLR(opt, decay_target, 1, total_iters=ramp_steps), 51 | LinearLR(opt, 1, decay_target, total_iters=decay_steps), 52 | ] 53 | ) 54 | 55 | 56 | def safe_mkdir(path: str): 57 | """ 58 | Make directory if it doesn't exist, otherwise do nothing 59 | """ 60 | if os.path.isdir(path): 61 | return 62 | os.mkdir(path) 63 | 64 | 65 | def get_distributed_config(accelerator: Accelerator): 66 | """ 67 | Return accelerator distributed config 68 | """ 69 | 70 | accelerate_config = accelerator.state 71 | dist_config = { 72 | "mixed_precision": accelerate_config.mixed_precision, 73 | "num_gpus": accelerate_config.num_processes, 74 | } 75 | 76 | if accelerator.state.deepspeed_plugin is not None: 77 | ds_plugin = accelerator.state.deepspeed_plugin 78 | dist_config.update( 79 | { 80 | "gradient_accumulation_steps": ds_plugin.gradient_accumulation_steps, 81 | "gradient_clipping": ds_plugin.gradient_clipping, 82 | "zero_stage": ds_plugin.zero_stage, 83 | "offload_optimizer_device": ds_plugin.offload_optimizer_device, 84 | "offload_param_device": ds_plugin.offload_param_device, 85 | } 86 | ) 87 | 88 | return dist_config 89 | 90 | 91 | class OptimizerNames(Enum): 92 | """Supported optimizer names""" 93 | 94 | ADAM: str = "adam" 95 | ADAMW: str = "adamw" 96 | SGD: str = "sgd" 97 | 98 | 99 | def get_optimizer_class(name: str): 100 | """ 101 | Returns the optimizer class with the given name 102 | """ 103 | if name == OptimizerNames.ADAM.value: 104 | return torch.optim.Adam 105 | if name == OptimizerNames.ADAMW.value: 106 | return torch.optim.AdamW 107 | if name == OptimizerNames.SGD.value: 108 | return torch.optim.SGD 109 | supported_optimizers = [o.value for o in OptimizerNames] 110 | raise ValueError( 111 | f"`{name}` is not a supported optimizer. " 112 | f"Supported optimizers are: {supported_optimizers}" 113 | ) 114 | 115 | 116 | class SchedulerNames(Enum): 117 | """Supported scheduler names""" 118 | 119 | COSINE_ANNEALING: str = "cosine_annealing" 120 | 121 | 122 | def get_scheduler_class(name: str): 123 | """ 124 | Returns the scheduler class with the given name 125 | """ 126 | if name == SchedulerNames.COSINE_ANNEALING.value: 127 | return torch.optim.lr_scheduler.CosineAnnealingLR 128 | supported_schedulers = [s.value for s in SchedulerNames] 129 | raise ValueError( 130 | f"`{name}` is not a supported scheduler. " 131 | f"Supported schedulers are: {supported_schedulers}" 132 | ) 133 | 134 | 135 | # Stats 136 | 137 | 138 | class Clock: 139 | """ 140 | Helper object for keeping track of time for computations. 141 | """ 142 | 143 | def __init__(self): 144 | self.start = time.time() 145 | self.total_time = 0 146 | self.total_samples = 0 147 | 148 | def tick(self, samples: int = 0) -> float: 149 | """ 150 | Returns time (s) since last call to tick(). Also records samples processed since last call. 
151 | 152 | :param samples: number of samples that have been processed since last call 153 | """ 154 | end = time.time() 155 | delta = end - self.start 156 | self.start = end 157 | 158 | if samples != 0: 159 | self.total_time += delta 160 | self.total_samples += samples 161 | 162 | return delta 163 | 164 | def get_stat(self, n_samp: int = 1000, reset: bool = False): 165 | """ 166 | Returns average time (s) per n_samp samples processed 167 | 168 | :param reset: Reset counts? 169 | """ 170 | sec_per_samp = self.total_time / self.total_samples 171 | 172 | if reset: 173 | self.total_samples = 0 174 | self.total_time = 0 175 | 176 | return sec_per_samp * n_samp 177 | 178 | 179 | # Sampling 180 | 181 | 182 | def topk_mask(xs: TensorType["Batch", "Vocab"], k: int): 183 | """ 184 | Takes batched distribution over tokens and masks out scores for tokens 185 | that are not in the top k for that distribution. 186 | """ 187 | 188 | # Get topk per distribution 189 | # For each dist, getting last value gives k-th largest 190 | mintop = torch.topk(xs, k)[0][:, -1].unsqueeze(-1) 191 | return torch.where(xs < mintop, -np.inf * torch.ones_like(xs), xs) 192 | 193 | 194 | # Sentiment/scores 195 | 196 | 197 | def sentiment_score(sentiments: Iterable[float]): 198 | """ 199 | Return tensor of scores in [-1, 1] from sentiment analysis pipeline output 200 | """ 201 | sentiments = torch.tensor( 202 | [-s["score"] if s["label"] == "NEGATIVE" else s["score"] for s in sentiments] 203 | ) 204 | return sentiments 205 | 206 | 207 | def tree_map(f, tree): 208 | """ 209 | Apply function f to all leaves in tree 210 | """ 211 | if is_dataclass(tree): 212 | return tree.__class__(**{k: tree_map(f, v) for k, v in tree.__dict__.items()}) 213 | elif isinstance(tree, dict): 214 | return {k: tree_map(f, v) for k, v in tree.items()} 215 | elif isinstance(tree, (list, tuple)): 216 | return tree.__class__(tree_map(f, v) for v in tree) 217 | else: 218 | return f(tree) 219 | 220 | 221 | def to_device(tree, device): 222 | """ 223 | Move all tensors in tree to device 224 | """ 225 | return tree_map(lambda x: x.to(device), tree) 226 | 227 | 228 | def filter_non_scalars(xs: Dict) -> Dict: 229 | """ 230 | Trims everything that can't be casted to float 231 | """ 232 | ys = {} 233 | for k, v in xs.items(): 234 | try: 235 | ys[k] = float(v) 236 | except TypeError: 237 | continue 238 | 239 | return ys 240 | 241 | 242 | def get_git_tag() -> str: 243 | """ 244 | Returns commit's short hash and date 245 | """ 246 | output = subprocess.check_output("git log --format='%h/%as' -n1".split()) 247 | branch = subprocess.check_output("git rev-parse --abbrev-ref HEAD".split()) 248 | return f"{branch.decode()[:-1]}/{output.decode()[1:-2]}" 249 | -------------------------------------------------------------------------------- /trlx/utils/loading.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | # Register load models via module import 4 | from trlx.model import _MODELS 5 | from trlx.model.accelerate_ilql_model import AccelerateILQLModel 6 | from trlx.model.accelerate_ppo_model import AcceleratePPOModel 7 | 8 | # Register load orchestrators via module import 9 | from trlx.orchestrator import _ORCH 10 | from trlx.orchestrator.offline_orchestrator import OfflineOrchestrator 11 | from trlx.orchestrator.ppo_orchestrator import PPOOrchestrator 12 | 13 | # Register load pipelines via module import 14 | from trlx.pipeline import _DATAPIPELINE 15 | from trlx.pipeline.offline_pipeline import 
PromptPipeline 16 | 17 | 18 | def get_model(name: str) -> Callable: 19 | """ 20 | Return constructor for specified model 21 | """ 22 | name = name.lower() 23 | if name in _MODELS: 24 | return _MODELS[name] 25 | else: 26 | raise Exception("Error: Trying to access a model that has not been registered") 27 | 28 | 29 | def get_pipeline(name: str) -> Callable: 30 | """ 31 | Return constructor for specified pipeline 32 | """ 33 | name = name.lower() 34 | if name in _DATAPIPELINE: 35 | return _DATAPIPELINE[name] 36 | else: 37 | raise Exception( 38 | "Error: Trying to access a pipeline that has not been registered" 39 | ) 40 | 41 | 42 | def get_orchestrator(name: str) -> Callable: 43 | """ 44 | Return constructor for specified orchestrator 45 | """ 46 | name = name.lower() 47 | if name in _ORCH: 48 | return _ORCH[name] 49 | else: 50 | raise Exception( 51 | "Error: Trying to access an orchestrator that has not been registered" 52 | ) 53 | -------------------------------------------------------------------------------- /trlx/utils/modeling.py: -------------------------------------------------------------------------------- 1 | from typing import MutableMapping, Tuple, Union 2 | 3 | import functools 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.distributed as dist 8 | import transformers 9 | from typing import Tuple 10 | import numpy as np 11 | 12 | 13 | def make_head(n_embd: int, out: int) -> nn.Sequential: 14 | """Returns a generic sequential MLP head.""" 15 | return nn.Sequential( 16 | nn.Linear(n_embd, n_embd * 2), 17 | nn.ReLU(), 18 | nn.Linear(n_embd * 2, out), 19 | ) 20 | 21 | 22 | def freeze_bottom_causal_layers(model: nn.Module, num_layers_unfrozen: int = 0): 23 | """Freezes the bottom transformer block layers of the specified model.""" 24 | hidden_layers = hf_get_causal_hidden_layers(model) 25 | if num_layers_unfrozen == 0: 26 | hidden_layers_to_freeze = list(hidden_layers) 27 | elif num_layers_unfrozen > 0: 28 | hidden_layers_to_freeze = list(hidden_layers)[:-num_layers_unfrozen] 29 | else: 30 | hidden_layers_to_freeze = [] 31 | for layer in hidden_layers_to_freeze: 32 | layer.requires_grad_(False) 33 | 34 | 35 | # HuggingFace utilities 36 | 37 | 38 | def rhasattr(obj, attr): 39 | """A chain-able attribute version of hasattr. For example, to check if 40 | `obj` has the attribute `foo.bar.baz`, you can use: 41 | `rhasattr(obj, "foo.bar.baz")` 42 | Reference: https://stackoverflow.com/a/67303315 43 | """ 44 | _nested_attrs = attr.split(".") 45 | _curr_obj = obj 46 | for _a in _nested_attrs[:-1]: 47 | if hasattr(_curr_obj, _a): 48 | _curr_obj = getattr(_curr_obj, _a) 49 | else: 50 | return False 51 | return hasattr(_curr_obj, _nested_attrs[-1]) 52 | 53 | 54 | def rgetattr(obj, attr: str, *args) -> object: 55 | """A chain-able attribute version of getattr. 
For example, to get the 56 | attribute `foo.bar.baz` from `obj`, you can use: 57 | `rgetattr(obj, "foo.bar.baz")` 58 | Reference: https://stackoverflow.com/a/31174427 59 | """ 60 | 61 | def _getattr(obj, attr): 62 | return getattr(obj, attr, *args) 63 | 64 | return functools.reduce(_getattr, [obj] + attr.split(".")) 65 | 66 | 67 | def findattr(obj, attrs: Tuple[str]) -> Union[object, None]: 68 | for attr in attrs: 69 | if rhasattr(obj, attr): 70 | return rgetattr(obj, attr) 71 | raise ValueError(f"Could not find an attribute from `{attrs}` in `{obj}`") 72 | 73 | 74 | def hf_get_causal_base_model(model: transformers.AutoModelForCausalLM) -> nn.Module: 75 | """Returns the causal decoder backbone of the specified HuggingFace transformers 76 | model. 77 | NOTE: Different model configurations have different causal decoder attribute 78 | names. 79 | - transformer: (GPT2LMHeadModel, GPTJConfig) 80 | - model.decoder: (OPTConfig, BloomConfig) 81 | - gpt_neox: (GPTNeoXConfig) 82 | """ 83 | decoder_attrs = ("transformer", "model.decoder", "gpt_neox") 84 | return findattr(model, decoder_attrs) 85 | 86 | 87 | def hf_get_causal_final_norm(model: nn.Module) -> float: 88 | """Returns the final (layer) norm of the specified model. 89 | NOTE: Different model configurations have different final norm attribute names. 90 | - transformer.ln_f: (GPT2LMHeadModel, GPTJForCausalLM) 91 | - model.decoder.final_layer_norm: (OPTForCausalLM) 92 | - gpt_neox.layers.final_layer_norm: (GPTNeoXForCausalLM) 93 | """ 94 | norm_attrs = ( 95 | "transformer.ln_f", 96 | "model.decoder.final_layer_norm", 97 | "gpt_neox.final_layer_norm", 98 | ) 99 | return findattr(model, norm_attrs) 100 | 101 | 102 | def hf_get_causal_hidden_layers(model: nn.Module) -> Tuple[nn.Module]: 103 | """Returns the hidden layers of the specified model. 104 | NOTE: Different model configurations have different hidden layer attribute names. 105 | - transformer.h: (BloomForCausalLM, GPT2LMHeadModel, GPTJForCausalLM) 106 | - model.decoder.layers: (OPTForCausalLM) 107 | - gpt_neox.layers: (GPTNeoXForCausalLM) 108 | """ 109 | hidden_layers_attrs = ( 110 | "transformer.h", 111 | "model.decoder.layers", 112 | "gpt_neox.layers", 113 | ) 114 | return findattr(model, hidden_layers_attrs) 115 | 116 | 117 | def hf_get_lm_head(model: transformers.AutoModelForCausalLM) -> nn.Module: 118 | """Returns the language modeling (lm) head of the specified HuggingFace 119 | transformers model. 120 | NOTE: Different model configurations have different `lm_head` attribute names. 121 | - lm_head: (GPT2LMHeadModel, BloomForCausalLM) 122 | - embed_out: (GPTNeoXForCausalLM) 123 | """ 124 | return model.get_output_embeddings() 125 | 126 | 127 | def hf_get_hidden_size(config: transformers.PretrainedConfig) -> int: 128 | """Returns the hidden layer dimensionality of the model architecture specified 129 | by the HuggingFace transformers config. 130 | NOTE: Different model configurations have different hidden size attribute names. 131 | - hidden_size: (OPTConfig, BloomConfig) 132 | - n_embd: (GPT2Config, GPTJConfig) 133 | - d_model: (PegasusConfig, XLNetConfig) 134 | """ 135 | hidden_size_attrs = ("hidden_size", "n_embd", "d_model") 136 | return findattr(config, hidden_size_attrs) 137 | 138 | 139 | def hf_get_num_hidden_layers(config: transformers.PretrainedConfig) -> int: 140 | """Returns the number of hidden layers in the model architecture specified 141 | by the HuggingFace transformers config. 
142 | NOTE: Different model configurations have different number-of-layers attribute 143 | names. 144 | - num_hidden_layers: (GPTNeoXConfig, OPTConfig) 145 | - n_layer: (GPT2Config, GPTJConfig, BloomConfig) 146 | """ 147 | num_hidden_layers_attrs = ("num_hidden_layers", "n_layer") 148 | return findattr(config, num_hidden_layers_attrs) 149 | 150 | 151 | def get_global_statistics(xs: torch.Tensor) -> Tuple[float, float, int]: 152 | """ 153 | Computes element-wise mean and variance of the tensor across processes 154 | """ 155 | sum_and_count = torch.tensor([xs.sum(), xs.numel()], device=xs.device) 156 | dist.all_reduce(sum_and_count, dist.ReduceOp.SUM) 157 | global_sum, count = sum_and_count 158 | global_mean = global_sum / count 159 | 160 | sum_var = torch.sum((xs - global_mean) ** 2) 161 | dist.all_reduce(sum_var, dist.ReduceOp.SUM) 162 | global_var = sum_var / count 163 | return global_mean, global_var, count 164 | 165 | 166 | def whiten(xs: torch.Tensor, shift_mean=True, distributed=True) -> torch.Tensor: 167 | """Whitens values""" 168 | if distributed and dist.is_initialized(): 169 | mean, var, _ = get_global_statistics(xs) 170 | else: 171 | var, mean = torch.var_mean(xs) 172 | 173 | whitened = (xs - mean) * torch.rsqrt(var + 1e-8) 174 | if not shift_mean: 175 | whitened += mean 176 | return whitened 177 | 178 | 179 | def logprobs_from_logits(logits, labels): 180 | """Compute log softmax values from logits.""" 181 | logprobs = F.log_softmax(logits, dim=-1) 182 | logprobs_labels = torch.gather(logprobs, dim=-1, index=labels.unsqueeze(-1)) 183 | return logprobs_labels.squeeze(-1) 184 | 185 | 186 | def flatten_dict( 187 | d: Union[dict, MutableMapping], 188 | parent_key: str = "", 189 | sep: str = "/", 190 | ) -> dict: 191 | # From: https://stackoverflow.com/a/6027615 192 | items = [] 193 | for k, v in d.items(): 194 | new_key = parent_key + sep + k if parent_key else k 195 | if isinstance(v, MutableMapping): 196 | items.extend(flatten_dict(v, new_key, sep=sep).items()) 197 | else: 198 | items.append((new_key, v)) 199 | return dict(items) 200 | 201 | 202 | def get_tensor_stats(xs: torch.Tensor, mask: torch.Tensor, n: int): 203 | mean = (xs * mask).sum() / n 204 | return dict( 205 | mean=mean, 206 | min=torch.where(mask.bool(), xs, np.inf).min(), 207 | max=torch.where(mask.bool(), xs, -np.inf).max(), 208 | std=torch.sqrt(((xs - mean) * mask).pow(2).sum() / n), 209 | ) 210 | 211 | 212 | class RunningMoments: 213 | def __init__(self): 214 | """ 215 | Calculates the running mean and standard deviation of a data stream. 
Modified version of 216 | https://github.com/DLR-RM/stable-baselines3/blob/a6f5049a99a4c21a6f0bcce458ca3306cef310e0/stable_baselines3/common/running_mean_std.py 217 | """ 218 | self.mean = 0 219 | self.std = 1 220 | self.var = 1 221 | self.count = 1e-24 222 | 223 | def update(self, xs: torch.Tensor) -> Tuple[float, float]: 224 | """Updates running moments from batch's moments computed across ranks""" 225 | if dist.is_initialized(): 226 | xs_mean, xs_var, xs_count = get_global_statistics(xs) 227 | else: 228 | xs_count = xs.numel() 229 | xs_var, xs_mean = torch.var_mean(xs, unbiased=False) 230 | 231 | delta = xs_mean - self.mean 232 | tot_count = self.count + xs_count 233 | 234 | new_sum = xs_var * xs_count 235 | # correct old_sum deviation accounting for the new mean 236 | old_sum = self.var * self.count + delta**2 * self.count * xs_count / tot_count 237 | tot_sum = old_sum + new_sum 238 | 239 | self.mean += delta * xs_count / tot_count 240 | self.var = tot_sum / tot_count 241 | self.std = (self.var * tot_count / (tot_count - 1)).sqrt() 242 | self.count = tot_count 243 | 244 | return xs_mean, (xs_var * xs_count / (xs_count - 1)).sqrt() 245 | --------------------------------------------------------------------------------
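As a closing illustration, here is a minimal sketch (assuming `trlx` and its dependencies are installed and importable; the tensor shapes below are arbitrary) of how the helpers defined in `trlx/utils/modeling.py` above behave in isolation, mirroring the way the PPO orchestrator uses them:

import torch

from trlx.utils.modeling import RunningMoments, logprobs_from_logits, whiten

batch, seq_len, vocab = 2, 6, 50257
logits = torch.randn(batch, seq_len, vocab)          # model outputs over seq_len positions
tokens = torch.randint(0, vocab, (batch, seq_len))   # tokens that were actually sampled

# Log-probabilities of the sampled tokens, shifted as in the orchestrator:
# logits at position t score the token at position t + 1.
logprobs = logprobs_from_logits(logits[:, :-1, :], tokens[:, 1:])
print(logprobs.shape)  # torch.Size([2, 5])

# Whitening to zero mean and unit variance (single-process path).
advantages = whiten(torch.randn(batch, seq_len), distributed=False)
print(advantages.mean().item(), advantages.std().item())

# Running mean/std accumulated over successive reward batches.
moments = RunningMoments()
for _ in range(3):
    batch_mean, batch_std = moments.update(torch.randn(16))
print(moments.mean, moments.std)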