├── tests ├── __init__.py ├── test_configs.py ├── test_trainers.py └── test_utils.py ├── examples ├── __init__.py ├── .DS_Store ├── reward_model │ ├── ds_config_gpt_j.json │ ├── reward_model.py │ ├── gptj_reward_test.py │ └── train_reward_model_gptj.py ├── triton_config.pbtxt ├── sft_hh.py ├── to_triton.py ├── ilql_hh.py ├── apa_off_hh.py ├── ppo_tldr.py ├── apa_hh.py ├── ppo_hh.py └── apa_tldr.py ├── trlx ├── models │ ├── __init__.py │ └── .DS_Store ├── __init__.py ├── .DS_Store ├── trainer │ ├── .DS_Store │ ├── accelerate_sft_trainer.py │ └── __init__.py ├── pipeline │ ├── .DS_Store │ ├── __init__.py │ ├── ppo_pipeline.py │ ├── offline_pipeline.py │ └── sql_on_pipeline.py ├── data │ ├── __init__.py │ ├── method_configs.py │ ├── accelerate_base_datatypes.py │ ├── ppo_types.py │ ├── ilql_types.py │ ├── default_configs.py │ └── configs.py ├── ray_tune │ ├── train_funcs.py │ ├── __init__.py │ └── wandb.py ├── utils │ ├── loading.py │ └── __init__.py ├── sweep.py └── trlx.py ├── setup.py ├── docs ├── .DS_Store ├── requirements.txt ├── source │ ├── trainer.rst │ ├── pipeline.rst │ ├── index.rst │ ├── configs.rst │ ├── examples.rst │ ├── data.rst │ └── conf.py ├── Makefile └── make.bat ├── configs ├── .DS_Store ├── accelerate │ ├── ddp.yaml │ ├── zero2-fp16.yaml │ ├── zero2-bf16.yaml │ └── zero3.yaml ├── sweeps │ ├── ppo_sweep.yml │ └── ilql_sweep.yml ├── test_config.yml └── nemo_configs │ ├── megatron_20b.yaml │ └── megatron_65b.yaml ├── pyproject.toml ├── CODE_OF_CONDUCT.md ├── LICENSE ├── SUPPORT.md ├── setup.cfg ├── requirements.txt ├── SECURITY.md └── README.md /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /examples/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /trlx/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup() 4 | -------------------------------------------------------------------------------- /trlx/__init__.py: -------------------------------------------------------------------------------- 1 | from .trlx import train 2 | from .utils import logging 3 | -------------------------------------------------------------------------------- /docs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/RLHF-APA/HEAD/docs/.DS_Store -------------------------------------------------------------------------------- /trlx/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/RLHF-APA/HEAD/trlx/.DS_Store -------------------------------------------------------------------------------- /configs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/RLHF-APA/HEAD/configs/.DS_Store -------------------------------------------------------------------------------- /examples/.DS_Store: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/microsoft/RLHF-APA/HEAD/examples/.DS_Store -------------------------------------------------------------------------------- /trlx/models/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/RLHF-APA/HEAD/trlx/models/.DS_Store -------------------------------------------------------------------------------- /trlx/trainer/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/RLHF-APA/HEAD/trlx/trainer/.DS_Store -------------------------------------------------------------------------------- /trlx/pipeline/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/microsoft/RLHF-APA/HEAD/trlx/pipeline/.DS_Store -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.isort] 6 | multi_line_output = 3 7 | profile = "black" 8 | 9 | [tool.black] 10 | line-length = 120 11 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.12.0 2 | datasets==2.4.0 3 | deepspeed==0.7.3 4 | einops==0.4.1 5 | numpy==1.23.2 6 | sphinx==4.0.0 7 | sphinx_rtd_theme 8 | torchtyping 9 | tqdm==4.64.0 10 | transformers==4.21.2 11 | wandb==0.13.2 12 | -------------------------------------------------------------------------------- /configs/accelerate/ddp.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: {} 3 | distributed_type: MULTI_GPU 4 | downcast_bf16: no 5 | dynamo_backend: 'NO' 6 | fsdp_config: {} 7 | gpu_ids: all 8 | machine_rank: 0 9 | main_training_function: main 10 | megatron_lm_config: {} 11 | mixed_precision: bf16 12 | num_machines: 1 13 | num_processes: 8 14 | rdzv_backend: static 15 | same_network: true 16 | use_cpu: false 17 | -------------------------------------------------------------------------------- /configs/sweeps/ppo_sweep.yml: -------------------------------------------------------------------------------- 1 | tune_config: 2 | mode: "max" 3 | metric: "reward/mean" 4 | search_alg: "random" 5 | scheduler: "fifo" 6 | num_samples: 32 7 | 8 | # https://docs.ray.io/en/latest/tune/api_docs/search_space.html#tune-sample-docs 9 | lr: 10 | strategy: "loguniform" 11 | values: [0.00001, 0.01] 12 | init_kl_coef: 13 | strategy: "uniform" 14 | values: [0, 0.2] 15 | vf_coef: 16 | strategy: "uniform" 17 | values: [0.5, 2] 18 | -------------------------------------------------------------------------------- /configs/sweeps/ilql_sweep.yml: -------------------------------------------------------------------------------- 1 | tune_config: 2 | mode: "max" 3 | metric: "metrics/sentiments" 4 | search_alg: "random" 5 | scheduler: "fifo" 6 | num_samples: 32 7 | 8 | lr: 9 | strategy: "loguniform" 10 | values: [0.00001, 0.01] 11 | tau: 12 | strategy: "uniform" 13 | values: [0.6, 0.9] 14 | steps_for_target_q_sync: 15 | strategy: "choice" 16 | values: [1, 5, 10] 17 | alpha: 18 | strategy: "loguniform" 19 | values: [0.001, 1.0] 20 | 
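Note: each sweep entry above pairs a sampling `strategy` with bounding `values`, which `trlx/sweep.py` turns into a Ray Tune search space via `get_param_space`. The snippet below is an illustrative sketch only of how such a mapping could look; the helper name `build_param_space` is hypothetical and the exact behavior of the repository's `trlx.ray_tune.get_param_space` is an assumption here.

# Illustrative sketch: map {strategy, values} entries from a sweep YAML onto Ray Tune samplers.
from ray import tune

_SAMPLERS = {
    "loguniform": lambda v: tune.loguniform(v[0], v[1]),
    "uniform": lambda v: tune.uniform(v[0], v[1]),
    "choice": lambda v: tune.choice(v),
}


def build_param_space(sweep_config: dict) -> dict:
    """Convert {"lr": {"strategy": "loguniform", "values": [1e-5, 1e-2]}, ...}
    into {"lr": tune.loguniform(1e-5, 1e-2), ...} for use as tune.Tuner(param_space=...)."""
    return {name: _SAMPLERS[spec["strategy"]](spec["values"]) for name, spec in sweep_config.items()}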
-------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Microsoft Open Source Code of Conduct 2 | 3 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 4 | 5 | Resources: 6 | 7 | - [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/) 8 | - [Microsoft Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) 9 | - Contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with questions or concerns 10 | -------------------------------------------------------------------------------- /configs/accelerate/zero2-fp16.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | deepspeed_multinode_launcher: standard 4 | gradient_accumulation_steps: 1 5 | gradient_clipping: 1.0 6 | offload_optimizer_device: none 7 | offload_param_device: none 8 | zero3_init_flag: false 9 | zero_stage: 2 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: no 12 | dynamo_backend: 'NO' 13 | fsdp_config: {} 14 | machine_rank: 0 15 | main_training_function: main 16 | megatron_lm_config: {} 17 | mixed_precision: fp16 18 | num_machines: 1 19 | num_processes: 8 20 | rdzv_backend: static 21 | same_network: true 22 | use_cpu: false 23 | -------------------------------------------------------------------------------- /configs/accelerate/zero2-bf16.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | deepspeed_multinode_launcher: standard 4 | gradient_accumulation_steps: 1 5 | gradient_clipping: 1.0 6 | offload_optimizer_device: "cpu" 7 | offload_param_device: "cpu" 8 | zero3_init_flag: false 9 | zero_stage: 2 10 | distributed_type: DEEPSPEED 11 | downcast_bf16: no 12 | dynamo_backend: 'NO' 13 | fsdp_config: {} 14 | machine_rank: 0 15 | main_training_function: main 16 | megatron_lm_config: {} 17 | mixed_precision: bf16 18 | num_machines: 1 19 | num_processes: 1 20 | rdzv_backend: static 21 | same_network: true 22 | use_cpu: false 23 | -------------------------------------------------------------------------------- /configs/accelerate/zero3.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | deepspeed_config: 3 | deepspeed_multinode_launcher: standard 4 | gradient_accumulation_steps: 1 5 | gradient_clipping: 1.0 6 | offload_optimizer_device: none 7 | offload_param_device: none 8 | zero3_init_flag: true 9 | zero3_save_16bit_model: true 10 | zero_stage: 3 11 | distributed_type: DEEPSPEED 12 | downcast_bf16: no 13 | dynamo_backend: 'NO' 14 | fsdp_config: {} 15 | machine_rank: 0 16 | main_training_function: main 17 | megatron_lm_config: {} 18 | mixed_precision: bf16 19 | num_machines: 1 20 | num_processes: 8 21 | rdzv_backend: static 22 | same_network: true 23 | use_cpu: false 24 | -------------------------------------------------------------------------------- /docs/source/trainer.rst: -------------------------------------------------------------------------------- 1 | .. _trainers: 2 | 3 | RL Trainers 4 | ******************* 5 | 6 | RL trainers implement the training loop you run with trlX. Currently, we support PPO and ILQL. 7 | Note that new trainers must be registered with ``trlx.trainer.register_trainer``. 
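For example, a custom trainer could be registered as follows (a minimal sketch: ``MyCustomTrainer`` is a hypothetical stub, and a real trainer would implement ``sample``, ``learn``, ``save`` and ``load``)::

    # Hypothetical sketch of registering a new trainer with trlX.
    # register_trainer accepts either a class directly or an explicit name string.
    from trlx.trainer import BaseRLTrainer, register_trainer

    @register_trainer
    class MyCustomTrainer(BaseRLTrainer):
        """Stub trainer, registered under the lowercased class name 'mycustomtrainer'."""

    # The trainer can then be selected from a config via train.trainer = "MyCustomTrainer";
    # the lookup in trlx.utils.loading.get_trainer is case-insensitive.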
8 | 9 | **General** 10 | 11 | .. autoclass:: trlx.trainer.BaseRLTrainer 12 | :members: 13 | 14 | .. autoclass:: trlx.trainer.accelerate_base_trainer.AccelerateRLTrainer 15 | :members: 16 | 17 | **PPO** 18 | 19 | .. autoclass:: trlx.trainer.accelerate_ppo_trainer.AcceleratePPOTrainer 20 | :members: 21 | 22 | **ILQL** 23 | 24 | .. autoclass:: trlx.trainer.accelerate_ilql_trainer.AccelerateILQLTrainer 25 | :members: 26 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line, and also 5 | # from the environment for the first two. 6 | SPHINXOPTS ?= 7 | SPHINXBUILD ?= sphinx-build 8 | SOURCEDIR = source 9 | BUILDDIR = build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /docs/source/pipeline.rst: -------------------------------------------------------------------------------- 1 | .. _pipeline: 2 | 3 | Pipelines 4 | ************************ 5 | 6 | Pipelines are how you read from a dataset with trlX. Rollout stores are how models store experiences created 7 | for them. It is these experiences in their rollout store that they are trained on. 8 | 9 | **General** 10 | 11 | .. autoclass:: trlx.pipeline.BasePipeline 12 | :members: 13 | 14 | .. autoclass:: trlx.pipeline.BaseRolloutStore 15 | :members: 16 | 17 | **PPO** 18 | 19 | .. autoclass:: trlx.pipeline.ppo_pipeline.PPORolloutStorage 20 | :members: 21 | 22 | **ILQL** 23 | 24 | .. autoclass:: trlx.pipeline.offline_pipeline.PromptPipeline 25 | :members: 26 | 27 | .. autoclass:: trlx.pipeline.offline_pipeline.ILQLRolloutStorage 28 | :members: 29 | -------------------------------------------------------------------------------- /docs/source/index.rst: -------------------------------------------------------------------------------- 1 | .. trlX documentation master file, created by 2 | sphinx-quickstart on Mon Oct 3 21:21:33 2022. 3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Welcome to trlX's documentation! 7 | ================================ 8 | trlX is a library made for training large language models using reinforcement learning. It 9 | currently supports training using PPO or ILQL for models up to 20B using Accelerate. 10 | 11 | .. 
toctree:: 12 | :maxdepth: 2 13 | :caption: Contents: 14 | 15 | data 16 | models 17 | configs 18 | pipeline 19 | examples 20 | 21 | Indices and tables 22 | ================== 23 | 24 | * :ref:`genindex` 25 | * :ref:`modindex` 26 | * :ref:`search` 27 | -------------------------------------------------------------------------------- /trlx/data/__init__.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Iterable 3 | 4 | from torchtyping import TensorType 5 | 6 | 7 | @dataclass 8 | class GeneralElement: 9 | """ 10 | General element outputted by a data pipeline 11 | """ 12 | 13 | pass 14 | 15 | 16 | @dataclass 17 | class RLElement: 18 | """ 19 | Batch element for RL model 20 | """ 21 | 22 | state: Iterable[str] = None # Context/prompts 23 | action: TensorType["N"] = None # Tokens generated by model given prompts 24 | reward: float = None # Reward obtained for that generation 25 | 26 | 27 | @dataclass 28 | class BatchElement: 29 | """ 30 | General batch element for any transformer to use in its forward pass 31 | """ 32 | 33 | tokens: TensorType["BATCH", "SEQ_LEN"] 34 | masks: TensorType["BATCH", "SEQ_LEN"] 35 | -------------------------------------------------------------------------------- /examples/reward_model/ds_config_gpt_j.json: -------------------------------------------------------------------------------- 1 | { 2 | "train_batch_size": "auto", 3 | "fp16": { 4 | "enabled": true, 5 | "min_loss_scale": 1, 6 | "opt_level": "O2" 7 | }, 8 | "zero_optimization": { 9 | "stage": 2, 10 | "offload_param": { 11 | "device": "cpu" 12 | }, 13 | "offload_optimizer": { 14 | "device": "cpu" 15 | }, 16 | "allgather_partitions": true, 17 | "allgather_bucket_size": 5e8, 18 | "contiguous_gradients": true 19 | }, 20 | "optimizer": { 21 | "type": "AdamW", 22 | "params": { 23 | "lr": 1e-5, 24 | "betas": [ 25 | 0.9, 26 | 0.999 27 | ], 28 | "eps": 1e-08 29 | } 30 | }, 31 | "scheduler": { 32 | "type": "WarmupLR", 33 | "params": { 34 | "warmup_min_lr": 0, 35 | "warmup_max_lr": "auto", 36 | "warmup_num_steps": 100 37 | } 38 | } 39 | } 40 | -------------------------------------------------------------------------------- /examples/triton_config.pbtxt: -------------------------------------------------------------------------------- 1 | name: "${model_name}" 2 | backend: "pytorch" 3 | default_model_filename: "traced-model.pt" 4 | max_batch_size: 25 5 | 6 | parameters { 7 | key: "model_name" 8 | value: { 9 | string_value: "${model_name}" 10 | } 11 | } 12 | 13 | instance_group [ 14 | { 15 | count: 1 16 | kind: KIND_GPU 17 | gpus: [0] 18 | } 19 | ] 20 | 21 | input [ 22 | { 23 | name: "input_ids" 24 | data_type: TYPE_INT32 25 | dims: [-1] 26 | } 27 | ] 28 | 29 | output [ 30 | { 31 | name: "rewards" 32 | data_type: TYPE_FP16 33 | dims: [-1] 34 | } 35 | ] 36 | 37 | parameters { 38 | key: "data_type" 39 | value: { 40 | string_value: "fp16" 41 | } 42 | } 43 | 44 | parameters: { 45 | key: "INFERENCE_MODE" 46 | value: { 47 | string_value: "true" 48 | } 49 | } 50 | 51 | version_policy: {specific: {versions: [1]}} 52 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=source 11 | set BUILDDIR=build 12 | 13 | if "%1" == "" goto help 14 | 15 | 
%SPHINXBUILD% >NUL 2>NUL 16 | if errorlevel 9009 ( 17 | echo. 18 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 19 | echo.installed, then set the SPHINXBUILD environment variable to point 20 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 21 | echo.may add the Sphinx directory to PATH. 22 | echo. 23 | echo.If you don't have Sphinx installed, grab it from 24 | echo.https://www.sphinx-doc.org/ 25 | exit /b 1 26 | ) 27 | 28 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 29 | goto end 30 | 31 | :help 32 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% 33 | 34 | :end 35 | popd 36 | -------------------------------------------------------------------------------- /docs/source/configs.rst: -------------------------------------------------------------------------------- 1 | .. _configs: 2 | 3 | Configs 4 | ************************ 5 | 6 | Training a model with trlX requires you to set several configs: 7 | ModelConfig, which contains general info on the model being trained; TrainConfig, which contains 8 | training hyperparameters; and MethodConfig, which contains hyperparameters or settings for 9 | the specific method being used (i.e. ILQL or PPO). 10 | 11 | 12 | **General** 13 | 14 | .. autoclass:: trlx.data.configs.TRLConfig 15 | :members: 16 | 17 | .. autoclass:: trlx.data.configs.ModelConfig 18 | :members: 19 | 20 | .. autoclass:: trlx.data.configs.TrainConfig 21 | :members: 22 | 23 | .. autoclass:: trlx.data.method_configs.MethodConfig 24 | :members: 25 | 26 | **PPO** 27 | 28 | .. autoclass:: trlx.data.method_configs.PPOConfig 29 | :members: 30 | 31 | **ILQL** 32 | 33 | .. autoclass:: trlx.data.method_configs.ILQLConfig 34 | :members: 35 | -------------------------------------------------------------------------------- /docs/source/examples.rst: -------------------------------------------------------------------------------- 1 | .. _examples: 2 | 3 | Examples 4 | ************************ 5 | 6 | In the ``examples`` folder you can find several example training tasks. Check 7 | the ``configs`` folder for the associated config files. ``examples.randomwalks`` 8 | does offline reinforcement learning on a set of graph random walks to stitch shortest 9 | paths to some destination. ``examples.simulacra`` optimizes prompts by using a 10 | prompts-ratings dataset (https://github.com/JD-P/simulacra-aesthetic-captions). 11 | ``examples.architext`` tries to optimize designs represented textually by 12 | minimizing the number of rooms (the pretrained model is under a license on hf). 13 | ``examples.ilql_sentiments`` and ``examples.ppo_sentiments`` train to generate 14 | movie reviews with a positive sentiment, in the offline setting – by fitting to IMDB 15 | dataset sentiment scores – and in the online setting – by sampling from a model finetuned 16 | on IMDB and rating the samples with a learned sentiment reward model. You can tweak 17 | these scripts to your liking and tune hyperparameters to your problem if you 18 | wish to use trlx for some custom task. 19 | -------------------------------------------------------------------------------- /trlx/ray_tune/train_funcs.py: -------------------------------------------------------------------------------- 1 | # Find the optimal hyperparameters to generate positive movie 2 | # reviews by tuning an IMDB-pretrained model with a sentiment reward function.
3 | 4 | from datasets import load_dataset 5 | 6 | import trlx 7 | from trlx.data.configs import TRLConfig 8 | 9 | 10 | def ppo_sentiments_train(config: dict): 11 | from transformers import pipeline 12 | 13 | config = TRLConfig.from_dict(config) 14 | 15 | sentiment_fn = pipeline("sentiment-analysis", "lvwerra/distilbert-imdb", device=-1) 16 | 17 | def reward_fn(samples, **kwargs): 18 | outputs = sentiment_fn(samples, return_all_scores=True) 19 | sentiments = [output[1]["score"] for output in outputs] 20 | return sentiments 21 | 22 | # Take few words off of movies reviews as prompts 23 | imdb = load_dataset("imdb", split="train+test") 24 | prompts = [" ".join(review.split()[:4]) for review in imdb["text"]] 25 | 26 | trlx.train( 27 | "lvwerra/gpt2-imdb", 28 | reward_fn=reward_fn, 29 | prompts=prompts, 30 | eval_prompts=["I don't know much about Hungarian underground"] * 64, 31 | config=config, 32 | ) 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Microsoft Corporation. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE 22 | -------------------------------------------------------------------------------- /SUPPORT.md: -------------------------------------------------------------------------------- 1 | # TODO: The maintainer of this repo has not yet edited this file 2 | 3 | **REPO OWNER**: Do you want Customer Service & Support (CSS) support for this product/project? 4 | 5 | - **No CSS support:** Fill out this template with information about how to file issues and get help. 6 | - **Yes CSS support:** Fill out an intake form at [aka.ms/onboardsupport](https://aka.ms/onboardsupport). CSS will work with/help you to determine next steps. 7 | - **Not sure?** Fill out an intake as though the answer were "Yes". CSS will help you decide. 8 | 9 | *Then remove this first heading from this SUPPORT.MD file before publishing your repo.* 10 | 11 | # Support 12 | 13 | ## How to file issues and get help 14 | 15 | This project uses GitHub Issues to track bugs and feature requests. Please search the existing 16 | issues before filing new issues to avoid duplicates. For new issues, file your bug or 17 | feature request as a new Issue. 18 | 19 | For help and questions about using this project, please **REPO MAINTAINER: INSERT INSTRUCTIONS HERE 20 | FOR HOW TO ENGAGE REPO OWNERS OR COMMUNITY FOR HELP. 
COULD BE A STACK OVERFLOW TAG OR OTHER 21 | CHANNEL. WHERE WILL YOU HELP PEOPLE?**. 22 | 23 | ## Microsoft Support Policy 24 | 25 | Support for this **PROJECT or PRODUCT** is limited to the resources listed above. 26 | -------------------------------------------------------------------------------- /docs/source/data.rst: -------------------------------------------------------------------------------- 1 | .. _data: 2 | 3 | Data Elements 4 | ************************ 5 | 6 | All of the major Carper projects: trlX, CHEESE, and magiCARP use 7 | dataclasses corresponding to batches of data to communicate data between models and different 8 | components. trlX is no different, though it has many different dataclasses for 9 | different components like training or inference. Currently, we support PPO and ILQL, which 10 | each demand different kinds of data during training. 11 | 12 | 13 | **Basic Data Elements for Accelerate** 14 | 15 | .. autoclass:: trlx.data.accelerate_base_datatypes.PromptElement 16 | :members: 17 | 18 | .. autoclass:: trlx.data.accelerate_base_datatypes.PromptBatch 19 | :members: 20 | 21 | .. autoclass:: trlx.data.accelerate_base_datatypes.AccelerateRLElement 22 | :members: 23 | 24 | .. autoclass:: trlx.data.accelerate_base_datatypes.AccelerateRLBatchElement 25 | :members: 26 | 27 | **Data Elements for PPO** 28 | 29 | .. autoclass:: trlx.data.ppo_types.PPORLElement 30 | :members: 31 | 32 | .. autoclass:: trlx.data.ppo_types.PPORLBatch 33 | :members: 34 | 35 | **Data Elements for ILQL** 36 | 37 | .. autoclass:: trlx.data.ilql_types.ILQLElement 38 | :members: 39 | 40 | .. autoclass:: trlx.data.ilql_types.ILQLBatch 41 | :members: 42 | -------------------------------------------------------------------------------- /trlx/data/method_configs.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from dataclasses import dataclass 3 | from typing import Any, Dict 4 | 5 | # specifies a dictionary of method configs 6 | _METHODS: Dict[str, Any] = {} # registry 7 | 8 | 9 | def register_method(name): 10 | """Decorator used register a method config 11 | Args: 12 | name: Name of the method 13 | """ 14 | 15 | def register_class(cls, name): 16 | _METHODS[name] = cls 17 | setattr(sys.modules[__name__], name, cls) 18 | return cls 19 | 20 | if isinstance(name, str): 21 | name = name.lower() 22 | return lambda c: register_class(c, name) 23 | 24 | cls = name 25 | name = cls.__name__ 26 | register_class(cls, name.lower()) 27 | 28 | return cls 29 | 30 | 31 | @dataclass 32 | @register_method 33 | class MethodConfig: 34 | """ 35 | Config for a certain RL method. 
36 | 37 | :param name: Name of the method 38 | :type name: str 39 | """ 40 | 41 | name: str 42 | 43 | @classmethod 44 | def from_dict(cls, config: Dict[str, Any]): 45 | return cls(**config) 46 | 47 | 48 | def get_method(name: str) -> MethodConfig: 49 | """ 50 | Return constructor for specified method config 51 | """ 52 | name = name.lower() 53 | if name in _METHODS: 54 | return _METHODS[name] 55 | else: 56 | raise Exception("Error: Trying to access a method that has not been registered") 57 | -------------------------------------------------------------------------------- /tests/test_configs.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | 4 | from trlx.data.configs import TRLConfig 5 | 6 | 7 | def _get_config_dirs(dir: str, config_dir_name: str = "configs") -> List[str]: 8 | """Returns all sub-directories of `dir` named `configs`.""" 9 | config_dirs = [] 10 | for root, dirs, _ in os.walk(dir): 11 | for d in dirs: 12 | if d == config_dir_name: 13 | config_dirs.append(os.path.join(root, d)) 14 | return config_dirs 15 | 16 | 17 | def _get_yaml_filepaths(dir: str) -> List[str]: 18 | """Returns a list of `yml` filepaths in `dir`.""" 19 | filepaths = [] 20 | for file in os.listdir(dir): 21 | if file.endswith(".yml"): 22 | filepaths.append(os.path.join(dir, file)) 23 | return filepaths 24 | 25 | 26 | def test_repo_trl_configs(): 27 | """Tests to ensure all default configs in the repository are valid.""" 28 | config_dirs = ["configs", *_get_config_dirs("examples")] 29 | config_files = sum(map(_get_yaml_filepaths, config_dirs), []) # sum for flat-map behavior 30 | for file in config_files: 31 | assert os.path.isfile(file), f"Config file {file} does not exist." 32 | assert file.endswith(".yml"), f"Config file {file} is not a yaml file." 33 | try: 34 | config = TRLConfig.load_yaml(file) 35 | assert ( 36 | config.train.entity_name is None 37 | ), f"Unexpected entity name in config file `{file}`. Remove before pushing to repo." 
38 | except Exception as e: 39 | assert False, f"Failed to load config file `{file}` with error `{e}`" 40 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = trlx 3 | author = Alex Havrilla 4 | version = 0.5.0 5 | url = https://github.com/CarperAI/trlx 6 | description = A repo for distributed training of language models with Reinforcement Learning via Human Feedback (RLHF) 7 | long_description = file: README.md 8 | long_description_content_type = text/markdown 9 | license = MIT 10 | 11 | [options] 12 | packages = find: 13 | install_requires = 14 | accelerate>=0.12.0 15 | attrs>=22.1.0 16 | cattrs>=22.2.0 17 | datasets 18 | deepspeed>=0.7.3 19 | einops>=0.4.1 20 | numpy>=1.21.6 21 | torchtyping 22 | transformers>=4.21.2 23 | tqdm 24 | rich 25 | wandb>=0.13.5 26 | ray>=2.0.1 27 | tabulate>=0.9.0 28 | networkx 29 | tritonclient 30 | 31 | [options.extras_require] 32 | bnb = bitsandbytes 33 | dev = 34 | black 35 | isort 36 | flake8 37 | pre-commit 38 | pytest 39 | pytest-cov 40 | 41 | [options.packages.find] 42 | exclude = 43 | docs* 44 | tests* 45 | 46 | [flake8] 47 | max-complexity = 10 48 | max-line-length = 127 49 | # flake8 error codes: https://flake8.pycqa.org/en/latest/user/error-codes.html 50 | # pycodestyle codes: https://pycodestyle.pycqa.org/en/latest/intro.html#error-codes 51 | # E203 # whitespace before ‘,’, ‘;’, or ‘:’ 52 | # E741 # do not use variables named ‘l’, ‘O’, or ‘I’ 53 | # F401 # module imported but unused 54 | # F821 # undefined name name 55 | # W503 # line break before binary operator 56 | # W605 # invalid escape sequence ‘x’ 57 | ignore = 58 | E203 59 | E741 60 | F821 61 | W503 62 | W605 63 | per-file-ignores = __init__.py:F401,loading.py:F401 64 | exclude = 65 | .git 66 | __pycache__ 67 | docs/source/conf.py 68 | build 69 | dist 70 | -------------------------------------------------------------------------------- /trlx/utils/loading.py: -------------------------------------------------------------------------------- 1 | from typing import Callable 2 | 3 | # Register load pipelines via module import 4 | from trlx.pipeline import _DATAPIPELINE 5 | from trlx.pipeline.offline_pipeline import PromptPipeline 6 | 7 | # Register load trainers via module import 8 | from trlx.trainer import _TRAINERS, register_trainer 9 | from trlx.trainer.accelerate_ilql_trainer import AccelerateILQLTrainer 10 | from trlx.trainer.accelerate_ppo_trainer import AcceleratePPOTrainer 11 | from trlx.trainer.accelerate_sppo_trainer import AccelerateSPPOTrainer 12 | from trlx.trainer.accelerate_sft_trainer import AccelerateSFTTrainer 13 | from trlx.trainer.accelerate_sqloff_trainer import AccelerateSQLOffTrainer 14 | # try: 15 | # from trlx.trainer.nemo_ilql_trainer import NeMoILQLTrainer 16 | # except ImportError: 17 | # # NeMo is not installed 18 | # def _trainer_unavailble(name): 19 | # def log_error(*args, **kwargs): 20 | # raise ImportError(f"Unable to import NeMo so {name} is unavailable") 21 | 22 | # return register_trainer(name)(log_error) 23 | 24 | # _trainer_unavailble("NeMoILQLTrainer") 25 | 26 | 27 | def get_trainer(name: str) -> Callable: 28 | """ 29 | Return constructor for specified RL model trainer 30 | """ 31 | name = name.lower() 32 | if name in _TRAINERS: 33 | return _TRAINERS[name] 34 | else: 35 | raise Exception("Error: Trying to access a trainer that has not been registered") 36 | 37 | 38 | def get_pipeline(name: str) -> Callable: 
39 | """ 40 | Return constructor for specified pipeline 41 | """ 42 | name = name.lower() 43 | if name in _DATAPIPELINE: 44 | return _DATAPIPELINE[name] 45 | else: 46 | raise Exception("Error: Trying to access a pipeline that has not been registered") 47 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | accelerate==0.18.0 2 | aiohttp==3.8.4 3 | aiosignal==1.3.1 4 | appdirs==1.4.4 5 | async-timeout==4.0.2 6 | attrs==22.2.0 7 | cattrs==22.2.0 8 | certifi==2022.12.7 9 | charset-normalizer==3.1.0 10 | click==8.1.3 11 | cmake==3.26.0 12 | datasets==2.10.1 13 | deepspeed==0.8.2 14 | dill==0.3.6 15 | distlib==0.3.6 16 | docker-pycreds==0.4.0 17 | einops==0.6.0 18 | exceptiongroup==1.1.1 19 | filelock==3.10.0 20 | frozenlist==1.3.3 21 | fsspec==2023.3.0 22 | gitdb==4.0.10 23 | GitPython==3.1.31 24 | grpcio==1.51.3 25 | hjson==3.1.0 26 | huggingface-hub==0.13.2 27 | idna==3.4 28 | importlib-metadata==6.1.0 29 | Jinja2==3.1.2 30 | jsonschema==4.17.3 31 | lit==15.0.7 32 | markdown-it-py==2.2.0 33 | MarkupSafe==2.1.2 34 | mdurl==0.1.2 35 | mpmath==1.3.0 36 | msgpack==1.0.5 37 | multidict==6.0.4 38 | multiprocess==0.70.14 39 | networkx==3.0 40 | ninja==1.11.1 41 | numpy==1.24.2 42 | packaging==23.0 43 | pandas==1.5.3 44 | pathtools==0.1.2 45 | platformdirs==3.1.1 46 | protobuf==3.20.1 47 | psutil==5.9.4 48 | py-cpuinfo==9.0.0 49 | pyarrow==11.0.0 50 | pydantic==1.10.6 51 | Pygments==2.14.0 52 | pyrsistent==0.19.3 53 | python-dateutil==2.8.2 54 | python-rapidjson==1.10 55 | pytz==2022.7.1 56 | PyYAML==6.0 57 | ray==2.3.0 58 | regex==2022.10.31 59 | requests==2.28.2 60 | responses==0.18.0 61 | rich==13.3.2 62 | sentry-sdk==1.17.0 63 | setproctitle==1.3.2 64 | six==1.16.0 65 | smmap==5.0.0 66 | sympy==1.11.1 67 | tabulate==0.9.0 68 | tokenizers==0.13.2 69 | torch==2.0.0 --extra-index-url https://download.pytorch.org/whl/cu116 70 | torchtyping==0.1.4 71 | tqdm==4.65.0 72 | transformers==4.27.1 73 | triton==2.0.0 74 | tritonclient==2.31.0 75 | typeguard==3.0.1 76 | typing_extensions==4.5.0 77 | urllib3==1.26.15 78 | virtualenv==20.21.0 79 | wandb==0.14.0 80 | xxhash==3.2.0 81 | yarl==1.8.2 82 | zipp==3.15.0 83 | evaluate>=0.4.0 84 | nltk>=3.8.1 85 | rouge-score>=0.1.2 86 | pytorch-lightning==0.7.3 -------------------------------------------------------------------------------- /trlx/data/accelerate_base_datatypes.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Iterable 3 | 4 | from torchtyping import TensorType 5 | 6 | 7 | @dataclass 8 | class PromptElement: 9 | """ 10 | Dataclass for a single prompt, containing its string and tokenized form. 11 | 12 | :param text: The prompt text. 13 | :type text: str 14 | 15 | :param tokens: The prompt tokens. Should be a long tensor 16 | :type tokens: torch.Tensor 17 | """ 18 | 19 | text: str 20 | tokens: TensorType["num_tokens"] 21 | 22 | 23 | @dataclass 24 | class PromptBatch: 25 | """ 26 | Batched PromptElement 27 | 28 | :param text: An iterable of prompt texts. 29 | :type text: Iterable[str] 30 | 31 | :param tokens: A long tensor batch of prompt tokens. 32 | :type tokens: torch.Tensor 33 | """ 34 | 35 | text: Iterable[str] 36 | tokens: TensorType["batch_size", "num_tokens"] 37 | 38 | 39 | @dataclass 40 | class AccelerateRLElement: 41 | """ 42 | Dataclass for RL elements, containing output tokens and rewards for each token. 
43 | 44 | :param tokens: The output tokens. Should be a long tensor 45 | :type tokens: torch.Tensor 46 | 47 | :param rewards: The rewards for each token. Should be a float tensor of same size as tokens. 48 | :type rewards: torch.Tensor 49 | """ 50 | 51 | output_tokens: TensorType["output_size"] 52 | rewards: TensorType["output_size"] 53 | 54 | 55 | @dataclass 56 | class AccelerateRLBatchElement: 57 | """ 58 | Batched accelerate RL element 59 | 60 | :param tokens: Batches of long tensors of output tokens. 61 | :type tokens: torch.Tensor 62 | 63 | :param rewards: Batches of float tensors of rewards for each output token. 64 | :type rewards: torch.Tensor 65 | """ 66 | 67 | output_tokens: TensorType["batch_size", "output_size"] 68 | rewards: TensorType["batch_size", "output_size"] 69 | -------------------------------------------------------------------------------- /trlx/trainer/accelerate_sft_trainer.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from transformers import AutoModelForCausalLM 4 | 5 | from trlx.data.configs import TRLConfig 6 | from trlx.data.method_configs import MethodConfig, register_method 7 | from trlx.trainer import register_trainer 8 | from trlx.trainer.accelerate_base_trainer import AccelerateRLTrainer 9 | 10 | 11 | @dataclass 12 | @register_method 13 | class SFTConfig(MethodConfig): 14 | """ 15 | Config for SFT training 16 | 17 | :param gen_kwargs: kwargs for generation 18 | :type gen_kwargs: Dict[str, Any] 19 | """ 20 | 21 | gen_kwargs: dict 22 | 23 | 24 | @register_trainer 25 | class AccelerateSFTTrainer(AccelerateRLTrainer): 26 | def __init__(self, config: TRLConfig, **kwargs): 27 | super().__init__(config, **kwargs) 28 | 29 | self.generate_kwargs = dict( 30 | config.method.gen_kwargs, 31 | eos_token_id=self.tokenizer.eos_token_id, 32 | pad_token_id=self.tokenizer.pad_token_id, 33 | ) 34 | 35 | def get_arch(self, config): 36 | return AutoModelForCausalLM.from_pretrained(config.model.model_path) 37 | 38 | def loss(self, batch): 39 | loss = self.model(input_ids=batch.input_ids, attention_mask=batch.attention_mask, labels=batch.input_ids).loss 40 | stats = {"loss": loss} 41 | 42 | return loss, stats 43 | 44 | def prepare_learning(self): 45 | train_dataloader = self.store.create_loader(self.config.train.batch_size) 46 | eval_dataloader = self.eval_pipeline.create_loader(self.config.train.batch_size) 47 | 48 | ( 49 | self.model, 50 | self.opt, 51 | self.train_dataloader, 52 | self.eval_dataloader, 53 | ) = self.accelerator.prepare(self.model, self.opt, train_dataloader, eval_dataloader) 54 | 55 | self.n_updates_per_batch = 1 56 | self.total_steps = self.config.train.epochs * len(train_dataloader) 57 | self.total_steps = min(self.total_steps, self.config.train.total_steps) 58 | -------------------------------------------------------------------------------- /examples/sft_hh.py: -------------------------------------------------------------------------------- 1 | import json 2 | import sys 3 | 4 | from datasets import load_dataset 5 | from ppo_hh import create_reward_fn 6 | 7 | import trlx 8 | from trlx.data.default_configs import ( 9 | ModelConfig, 10 | OptimizerConfig, 11 | SchedulerConfig, 12 | SFTConfig, 13 | TokenizerConfig, 14 | TrainConfig, 15 | TRLConfig, 16 | ) 17 | 18 | default_config = TRLConfig( 19 | train=TrainConfig( 20 | seq_length=1024, 21 | epochs=100, 22 | total_steps=10000, 23 | batch_size=4, 24 | checkpoint_interval=10000, 25 | eval_interval=1000, 26 | 
pipeline="PromptPipeline", 27 | trainer="AccelerateSFTTrainer", 28 | checkpoint_dir="checkpoints/sft_hh", 29 | ), 30 | model=ModelConfig(model_path="EleutherAI/gpt-j-6B", num_layers_unfrozen=-1), 31 | tokenizer=TokenizerConfig(tokenizer_path="EleutherAI/gpt-j-6B", truncation_side="left"), 32 | optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=1e-6, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)), 33 | scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=100000000, eta_min=1e-6)), 34 | method=SFTConfig( 35 | name="sftconfig", 36 | gen_kwargs=dict(max_new_tokens=128, top_k=20, top_p=1.0, do_sample=True), 37 | ), 38 | ) 39 | 40 | 41 | def preprocess(sample): 42 | sample["chosen_sample"] = sample["prompt"] + sample["chosen"] 43 | return sample 44 | 45 | 46 | def main(hparams={}): 47 | config = TRLConfig.update(default_config, hparams) 48 | 49 | dataset = load_dataset("Dahoas/full-hh-rlhf").map(preprocess) 50 | reward_fn = create_reward_fn() 51 | 52 | trlx.train( 53 | config=config, 54 | samples=dataset["train"]["chosen_sample"], 55 | eval_prompts=dataset["test"]["prompt"][:280], 56 | metric_fn=lambda **kwargs: {"reward": reward_fn(**kwargs)}, 57 | stop_sequences=["Human:", "human:", "Assistant:", "assistant:"], 58 | ) 59 | 60 | 61 | if __name__ == "__main__": 62 | hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1]) 63 | main(hparams) 64 | -------------------------------------------------------------------------------- /configs/test_config.yml: -------------------------------------------------------------------------------- 1 | train: 2 | seq_length: 64 # Size of LM context 3 | epochs: 100 # Train for max(epochs, total_steps) 4 | total_steps: 1000 # Train for max(epochs, total_steps) 5 | batch_size: 16 # batch size 6 | 7 | checkpoint_interval: 10000 # checkpoint interval 8 | eval_interval: 128 # eval interval 9 | 10 | pipeline: "PromptPipeline" # prompt pipeline to load 11 | trainer: "AcceleratePPOTrainer" # Name of model trainer to load 12 | 13 | model: 14 | model_path: "lvwerra/gpt2-imdb" # Name of hf model to load 15 | num_layers_unfrozen: 2 # Number of bottom layers to freeze during training 16 | 17 | tokenizer: 18 | tokenizer_path: "gpt2" # Name of hf tokenizer to load 19 | truncation_side: "right" # Trim this side of samples if they are longer than LM context 20 | 21 | optimizer: 22 | name: "adamw" # Name of optimizer to load 23 | kwargs: 24 | lr: 1.412e-4 # Learning rate 25 | betas: [0.9, 0.95] # Adam betas 26 | eps: 1.0e-8 # Adam eps 27 | weight_decay: 1.0e-6 # Weight decay param 28 | 29 | scheduler: 30 | name: "cosine_annealing" # Name of learning rate scheduler 31 | kwargs: 32 | T_max: 10000 # Maximum number of steps 33 | eta_min: 1.412e-4 # Minimum learning rate 34 | 35 | method: 36 | name: "ppoconfig" # Name of RL method config 37 | num_rollouts: 128 # Number of rollouts to collect per epoch 38 | chunk_size: 128 # Number of rollouts to collect in one loop 39 | ppo_epochs: 4 # Number of ppo epochs 40 | init_kl_coef: 0.2 # init kl coefficient 41 | target: 6 # target kl coefficient, set None for fixed kl coef 42 | horizon: 10000 # PPO horizon 43 | gamma: 0.99 # PPO discount 44 | lam: 0.95 # PPO lambda 45 | cliprange: 0.2 # clip range 46 | cliprange_value: 0.2 # clip range 47 | vf_coef: 1.0 # value term weight 48 | scale_reward: "running" # False|"ref"|"running" estimate against which to scale rewards 49 | cliprange_reward: 10 50 | ref_mean: null 51 | ref_std: null 52 | gen_kwargs: 53 | max_length: 48 # LM max sample gen length 54 | min_length: 48 # 
LM min sample gen length 55 | top_k: 0.0 # top k 56 | top_p: 1.0 # top p 57 | do_sample: True # sample 58 | -------------------------------------------------------------------------------- /docs/source/conf.py: -------------------------------------------------------------------------------- 1 | # Configuration file for the Sphinx documentation builder. 2 | # 3 | # This file only contains a selection of the most common options. For a full 4 | # list see the documentation: 5 | # https://www.sphinx-doc.org/en/master/usage/configuration.html 6 | 7 | # -- Path setup -------------------------------------------------------------- 8 | 9 | # If extensions (or modules to document with autodoc) are in another directory, 10 | # add these directories to sys.path here. If the directory is relative to the 11 | # documentation root, use os.path.abspath to make it absolute, like shown here. 12 | # 13 | import os 14 | import sys 15 | 16 | import sphinx_rtd_theme 17 | 18 | sys.path.insert(0, os.path.abspath('../..')) 19 | 20 | 21 | # -- Project information ----------------------------------------------------- 22 | 23 | project = 'trlX' 24 | copyright = '2022, CarperAI' 25 | author = 'CarperAI' 26 | 27 | # -- General configuration --------------------------------------------------- 28 | 29 | # Add any Sphinx extension module names here, as strings. They can be 30 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 31 | # ones. 32 | 33 | extensions = ['sphinx_rtd_theme', 'sphinx.ext.todo', 'sphinx.ext.viewcode', 'sphinx.ext.autodoc', 'sphinx.ext.autosummary', 'sphinx.ext.autosectionlabel'] 34 | 35 | # Add any paths that contain templates here, relative to this directory. 36 | templates_path = ['_templates'] 37 | 38 | # List of patterns, relative to source directory, that match files and 39 | # directories to ignore when looking for source files. 40 | # This pattern also affects html_static_path and html_extra_path. 41 | exclude_patterns = [] 42 | 43 | 44 | # -- Options for HTML output ------------------------------------------------- 45 | 46 | # The theme to use for HTML and HTML Help pages. See the documentation for 47 | # a list of builtin themes. 48 | # 49 | html_theme = 'sphinx_rtd_theme' 50 | 51 | # Add any paths that contain custom static files (such as style sheets) here, 52 | # relative to this directory. They are copied after the builtin static files, 53 | # so a file named "default.css" will overwrite the builtin "default.css". 54 | html_static_path = ['_static'] 55 | -------------------------------------------------------------------------------- /trlx/data/ppo_types.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | from torchtyping import TensorType 4 | 5 | 6 | @dataclass 7 | class PPORLElement: 8 | """ 9 | :param query_tensor: The query tensor i.e. the prompt tokens. 10 | Should be a long tensor. 11 | :type query_tensor: torch.Tensor 12 | 13 | :param response_tensor: The response tensor i.e. the output tokens. 14 | Should be a long tensor. 15 | :type response_tensor: torch.Tensor 16 | 17 | :param logprobs: The log probabilities over the response tokens generated 18 | by the policy network (i.e. the autoregressive model). 19 | Should be a float tensor of same size as tokens. 20 | :type logprobs: torch.Tensor 21 | 22 | :param values: The values for each token generated from the value network or value head. 23 | Should be a float tensor of same size as tokens. 
24 | :type values: torch.Tensor 25 | 26 | :param rewards: The rewards for each token outputted in response. 27 | Should be a float tensor of same size as tokens. 28 | :type rewards: torch.Tensor 29 | """ 30 | 31 | query_tensor: TensorType["query_size"] 32 | response_tensor: TensorType["response_size"] 33 | logprobs: TensorType["response_size"] 34 | values: TensorType["response_size"] 35 | rewards: TensorType["response_size"] 36 | 37 | 38 | @dataclass 39 | class PPORLBatch: 40 | """ 41 | A batched version of the PPORLElement. See PPORLElement for more details on individual fields. 42 | 43 | :param query_tensors: A batch of query tensors. Should be a long tensor. 44 | :type query_tensors: torch.Tensor 45 | 46 | :param response_tensors: A batch of response tensors. Should be a long tensor. 47 | :type response_tensors: torch.Tensor 48 | 49 | :param logprobs: A batch of log probabilities from policy 50 | :type logprobs: torch.Tensor 51 | 52 | :param values: A batch of values from value network 53 | :type values: torch.Tensor 54 | 55 | :param rewards: A batch of rewards 56 | :type rewards: torch.Tensor 57 | """ 58 | 59 | query_tensors: TensorType["batch_size", "query_size"] 60 | response_tensors: TensorType["batch_size", "response_size"] 61 | logprobs: TensorType["batch_size", "response_size"] 62 | values: TensorType["batch_size", "response_size"] 63 | rewards: TensorType["batch_size", "response_size"] 64 | -------------------------------------------------------------------------------- /trlx/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | import random 2 | import sys 3 | from abc import abstractmethod, abstractstaticmethod 4 | from typing import Any, Callable, Dict, Iterable 5 | 6 | from torch.utils.data import DataLoader, Dataset 7 | 8 | from trlx.data import GeneralElement, RLElement 9 | 10 | # specifies a dictionary of architectures 11 | _DATAPIPELINE: Dict[str, any] = {} # registry 12 | 13 | 14 | def register_datapipeline(name): 15 | """Decorator used register a CARP architecture 16 | Args: 17 | name: Name of the architecture 18 | """ 19 | 20 | def register_class(cls, name): 21 | _DATAPIPELINE[name] = cls 22 | setattr(sys.modules[__name__], name, cls) 23 | return cls 24 | 25 | if isinstance(name, str): 26 | name = name.lower() 27 | return lambda c: register_class(c, name) 28 | 29 | cls = name 30 | name = cls.__name__ 31 | register_class(cls, name.lower()) 32 | 33 | return cls 34 | 35 | 36 | @register_datapipeline 37 | class BasePipeline(Dataset): 38 | def __init__(self, path: str = "dataset"): 39 | super().__init__() 40 | 41 | @abstractmethod 42 | def __getitem__(self, index: int) -> GeneralElement: 43 | pass 44 | 45 | @abstractmethod 46 | def __len__(self) -> int: 47 | pass 48 | 49 | @abstractmethod 50 | def create_loader( 51 | self, 52 | batch_size: int, 53 | shuffle: bool, 54 | prep_fn: Callable = None, 55 | num_workers: int = 0, 56 | ) -> DataLoader: 57 | """ 58 | Create a dataloader for the pipeline 59 | 60 | :param prep_fn: Typically a tokenizer. Applied to GeneralElement after collation. 
61 | """ 62 | pass 63 | 64 | 65 | class BaseRolloutStore(Dataset): 66 | def __init__(self, capacity=-1): 67 | self.history: Iterable[Any] = None 68 | self.capacity = capacity 69 | 70 | @abstractmethod 71 | def push(self, exps: Iterable[Any]): 72 | """ 73 | Push experiences to rollout storage 74 | """ 75 | pass 76 | 77 | def __getitem__(self, index: int) -> RLElement: 78 | return self.history[index] 79 | 80 | def __len__(self) -> int: 81 | return len(self.history) 82 | 83 | @abstractmethod 84 | def create_loader( 85 | self, 86 | batch_size: int, 87 | shuffle: bool, 88 | prep_fn: Callable = None, 89 | num_workers: int = 0, 90 | ) -> DataLoader: 91 | """ 92 | Create a dataloader for the rollout store 93 | 94 | :param prep_fn: Applied to RLElement after collation (typically tokenizer) 95 | :type prep_fn: Callable 96 | """ 97 | pass 98 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | ## Security 4 | 5 | Microsoft takes the security of our software products and services seriously, which includes all source code repositories managed through our GitHub organizations, which include [Microsoft](https://github.com/microsoft), [Azure](https://github.com/Azure), [DotNet](https://github.com/dotnet), [AspNet](https://github.com/aspnet), [Xamarin](https://github.com/xamarin), and [our GitHub organizations](https://opensource.microsoft.com/). 6 | 7 | If you believe you have found a security vulnerability in any Microsoft-owned repository that meets [Microsoft's definition of a security vulnerability](https://aka.ms/opensource/security/definition), please report it to us as described below. 8 | 9 | ## Reporting Security Issues 10 | 11 | **Please do not report security vulnerabilities through public GitHub issues.** 12 | 13 | Instead, please report them to the Microsoft Security Response Center (MSRC) at [https://msrc.microsoft.com/create-report](https://aka.ms/opensource/security/create-report). 14 | 15 | If you prefer to submit without logging in, send email to [secure@microsoft.com](mailto:secure@microsoft.com). If possible, encrypt your message with our PGP key; please download it from the [Microsoft Security Response Center PGP Key page](https://aka.ms/opensource/security/pgpkey). 16 | 17 | You should receive a response within 24 hours. If for some reason you do not, please follow up via email to ensure we received your original message. Additional information can be found at [microsoft.com/msrc](https://aka.ms/opensource/security/msrc). 18 | 19 | Please include the requested information listed below (as much as you can provide) to help us better understand the nature and scope of the possible issue: 20 | 21 | * Type of issue (e.g. buffer overflow, SQL injection, cross-site scripting, etc.) 22 | * Full paths of source file(s) related to the manifestation of the issue 23 | * The location of the affected source code (tag/branch/commit or direct URL) 24 | * Any special configuration required to reproduce the issue 25 | * Step-by-step instructions to reproduce the issue 26 | * Proof-of-concept or exploit code (if possible) 27 | * Impact of the issue, including how an attacker might exploit the issue 28 | 29 | This information will help us triage your report more quickly. 30 | 31 | If you are reporting for a bug bounty, more complete reports can contribute to a higher bounty award. 
Please visit our [Microsoft Bug Bounty Program](https://aka.ms/opensource/security/bounty) page for more details about our active programs. 32 | 33 | ## Preferred Languages 34 | 35 | We prefer all communications to be in English. 36 | 37 | ## Policy 38 | 39 | Microsoft follows the principle of [Coordinated Vulnerability Disclosure](https://aka.ms/opensource/security/cvd). 40 | 41 | 42 | -------------------------------------------------------------------------------- /examples/to_triton.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from string import Template 4 | 5 | import torch 6 | from huggingface_hub import snapshot_download 7 | from torch import nn 8 | from transformers import AutoModelForCausalLM, AutoTokenizer 9 | 10 | parser = argparse.ArgumentParser() 11 | 12 | parser.add_argument("--base_model", type=str, required=True, help="Path to HF checkpoint with the base model") 13 | 14 | parser.add_argument( 15 | "--checkpoint", 16 | type=str, 17 | required=True, 18 | help="Path to either a local directory or a HF checkpoint with reward model's weights", 19 | ) 20 | 21 | parser.add_argument("--revision", type=str, required=False, help="Optional branch/commit of the HF checkpoint") 22 | 23 | parser.add_argument("--device", type=int, default=0) 24 | args = parser.parse_args() 25 | 26 | model_name = args.checkpoint.split("/")[-1] 27 | device = torch.device(args.device) 28 | 29 | 30 | class RewardModel(nn.Module): 31 | def __init__(self, checkpoint_path, eos_token_id): 32 | super().__init__() 33 | model = AutoModelForCausalLM.from_pretrained(checkpoint_path) 34 | self.transformer = model.transformer 35 | self.v_head = nn.Linear(model.config.n_embd, 1, bias=False) 36 | self.eos_token_id = eos_token_id 37 | 38 | def forward(self, input_ids): 39 | states = self.transformer(input_ids)[0] 40 | rewards = self.v_head(states).squeeze(-1) 41 | ends = torch.argmax((input_ids == self.eos_token_id).float(), dim=1).view(-1, 1) 42 | returns = torch.gather(rewards, 1, ends).squeeze(-1) 43 | return returns 44 | 45 | 46 | if os.path.isdir(args.checkpoint): 47 | directory = args.checkpoint 48 | else: 49 | directory = snapshot_download(args.checkpoint, revision=args.revision) 50 | 51 | print(f"searching through {os.listdir(directory)} in {directory}") 52 | 53 | for fpath in os.listdir(directory): 54 | if fpath.endswith(".pt") or fpath.endswith(".bin"): 55 | checkpoint = os.path.join(directory, fpath) 56 | break 57 | 58 | tokenizer = AutoTokenizer.from_pretrained(args.base_model) 59 | model = RewardModel(args.base_model, tokenizer.eos_token_id) 60 | model.load_state_dict(torch.load(checkpoint)) 61 | model.eval() 62 | model.requires_grad_(False) 63 | model = model.half().to(device) 64 | 65 | input = tokenizer("reward model's hash", return_tensors="pt").to(device) 66 | print(f"{model(input.input_ids)=}") 67 | 68 | traced_script_module = torch.jit.trace(model, input.input_ids) 69 | 70 | os.makedirs(f"model_store/{model_name}/1", exist_ok=True) 71 | traced_script_module.save(f"model_store/{model_name}/1/traced-model.pt") 72 | 73 | config_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "triton_config.pbtxt") 74 | with open(config_path) as f: 75 | template = Template(f.read()) 76 | config = template.substitute({"model_name": model_name}) 77 | with open(f"model_store/{model_name}/config.pbtxt", "w") as f: 78 | f.write(config) 79 | -------------------------------------------------------------------------------- 
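A quick way to sanity-check the artifact written by examples/to_triton.py is to load the traced module back and score a sample before serving it from Triton. A minimal sketch, assuming the export above has already produced model_store/<model_name>/1/traced-model.pt, a GPU is available, and the model name and prompt text below are placeholders:

# Hypothetical smoke test for the traced reward model exported by to_triton.py.
import torch
from transformers import AutoTokenizer

model_name = "my-reward-model"  # placeholder: basename of the --checkpoint used for export
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")  # same --base_model as the export

traced = torch.jit.load(f"model_store/{model_name}/1/traced-model.pt", map_location="cuda:0")
# The reward is read at the first EOS position, so append an EOS token to the sample being scored.
sample = "Human: How do I bake bread? Assistant: Start with flour, water, yeast and salt." + tokenizer.eos_token
input_ids = tokenizer(sample, return_tensors="pt").input_ids.to("cuda:0")
print(traced(input_ids))  # one scalar reward per sequence in the batch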
/trlx/pipeline/ppo_pipeline.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import time 4 | from typing import Iterable 5 | 6 | from torch.nn.utils.rnn import pad_sequence 7 | from torch.utils.data import DataLoader 8 | 9 | from trlx.data.ppo_types import PPORLBatch, PPORLElement 10 | from trlx.pipeline import BaseRolloutStore 11 | 12 | 13 | class PPORolloutStorage(BaseRolloutStore): 14 | """ 15 | Rollout storage for training PPO 16 | """ 17 | 18 | def __init__(self, pad_token_id): 19 | super().__init__() 20 | 21 | self.pad_token_id = pad_token_id 22 | self.history: Iterable[PPORLElement] = [None] 23 | 24 | def push(self, exps: Iterable[PPORLElement]): 25 | self.history += exps 26 | 27 | def clear_history(self): 28 | self.history = [] 29 | 30 | def export_history(self, location: str): 31 | assert os.path.exists(location) 32 | 33 | fpath = os.path.join(location, f"epoch-{str(time.time())}.json") 34 | 35 | def exp_to_dict(exp): 36 | return {k: v.cpu().tolist() for k, v in exp.__dict__.items()} 37 | 38 | data = [exp_to_dict(exp) for exp in self.history] 39 | with open(fpath, "w") as f: 40 | f.write(json.dumps(data, indent=2)) 41 | 42 | def __getitem__(self, index: int) -> PPORLElement: 43 | return self.history[index] 44 | 45 | def __len__(self) -> int: 46 | return len(self.history) 47 | 48 | def create_loader( 49 | self, 50 | batch_size: int, 51 | shuffle: bool, 52 | ) -> DataLoader: 53 | def collate_fn(elems: Iterable[PPORLElement]): 54 | # print("elem shape:", elems[0].query_tensor.flip(0).shape) 55 | # print("elem 1 shape:", elems[1].query_tensor.flip(0).shape) 56 | return PPORLBatch( 57 | # Left padding of already left-padded queries 58 | pad_sequence( 59 | [elem.query_tensor.flip(0) for elem in elems], 60 | padding_value=self.pad_token_id, 61 | batch_first=True, 62 | ).flip(1), 63 | # Right pad the rest, to have a single horizontal query/response split 64 | pad_sequence( 65 | [elem.response_tensor for elem in elems], 66 | padding_value=self.pad_token_id, 67 | batch_first=True, 68 | ), 69 | pad_sequence( 70 | [elem.logprobs for elem in elems], 71 | padding_value=0.0, 72 | batch_first=True, 73 | ), 74 | pad_sequence([elem.values for elem in elems], padding_value=0.0, batch_first=True), 75 | pad_sequence( 76 | [elem.rewards for elem in elems], 77 | padding_value=0.0, 78 | batch_first=True, 79 | ), 80 | ) 81 | 82 | return DataLoader(self, batch_size, shuffle=shuffle, collate_fn=collate_fn) 83 | -------------------------------------------------------------------------------- /trlx/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from abc import abstractmethod 3 | from typing import Any, Callable, Dict, Iterable, Optional 4 | 5 | from trlx.data.configs import TRLConfig 6 | from trlx.pipeline import BaseRolloutStore 7 | 8 | # specifies a dictionary of trainers 9 | _TRAINERS: Dict[str, Any] = {} # registry 10 | 11 | 12 | def register_trainer(name): 13 | """Decorator used to register a trainer 14 | Args: 15 | name: Name of the trainer type to register 16 | """ 17 | 18 | def register_class(cls, name): 19 | _TRAINERS[name] = cls 20 | setattr(sys.modules[__name__], name, cls) 21 | return cls 22 | 23 | if isinstance(name, str): 24 | name = name.lower() 25 | return lambda c: register_class(c, name) 26 | 27 | cls = name 28 | name = cls.__name__ 29 | register_class(cls, name.lower()) 30 | 31 | return cls 32 | 33 | 34 | @register_trainer 35 | class BaseRLTrainer: 36
| def __init__( 37 | self, 38 | config: TRLConfig, 39 | reward_fn=None, 40 | metric_fn=None, 41 | logit_mask=None, 42 | stop_sequences=None, 43 | train_mode=False, 44 | ): 45 | self.store: BaseRolloutStore = None 46 | self.config = config 47 | self.reward_fn = reward_fn 48 | self.metric_fn = metric_fn 49 | self.train_mode = train_mode 50 | self.logit_mask = logit_mask 51 | self.stop_sequences = stop_sequences 52 | 53 | def push_to_store(self, data): 54 | self.store.push(data) 55 | 56 | def add_eval_pipeline(self, eval_pipeline): 57 | """Adds pipeline for validation prompts""" 58 | self.eval_pipeline = eval_pipeline 59 | 60 | @abstractmethod 61 | def sample(self, prompts: Iterable[str], length: int, n_samples: int) -> Iterable[str]: 62 | """ 63 | Sample from the language model. Takes prompts and maximum length to generate. 64 | 65 | :param prompts: List of prompts to tokenize and use as context 66 | 67 | :param length: How many new tokens to generate for each prompt 68 | :type length: int 69 | 70 | :param n_samples: Number of samples to generate; defaults to the number of prompts 71 | """ 72 | pass 73 | 74 | @abstractmethod 75 | def learn( 76 | self, 77 | log_fn: Callable = None, 78 | save_fn: Callable = None, 79 | eval_fn: Callable = None, 80 | ): 81 | """ 82 | Use experiences in RolloutStore to learn 83 | 84 | :param log_fn: Optional function that is called when logging and passed a dict of logging-relevant values 85 | :type log_fn: Callable[Dict[str, any]] 86 | 87 | :param save_fn: Optional function to call after saving. Is passed the components. 88 | :type save_fn: Callable[Dict[str, any]] 89 | 90 | :param eval_fn: Optional function to call during evaluation. Eval doesn't do anything without this. 91 | :type eval_fn: Callable[BaseRLTrainer] 92 | """ 93 | pass 94 | 95 | @abstractmethod 96 | def save(self, directory: Optional[str] = None): 97 | """Creates a checkpoint of training states""" 98 | pass 99 | 100 | @abstractmethod 101 | def load(self, directory=None): 102 | """Loads a checkpoint created from `save`""" 103 | pass 104 | -------------------------------------------------------------------------------- /trlx/sweep.py: -------------------------------------------------------------------------------- 1 | # python -m trlx.sweep --config configs/sweeps/ppo_sweep.yml examples/ppo_sentiments.py 2 | import argparse 3 | import importlib 4 | from pathlib import Path 5 | 6 | import ray 7 | import yaml 8 | from ray import tune 9 | from ray.tune.logger import CSVLoggerCallback 10 | 11 | from trlx.ray_tune import get_param_space, get_tune_config 12 | from trlx.ray_tune.wandb import create_report, log_trials 13 | 14 | 15 | def tune_function(train_function, param_space: dict, tune_config: dict, resources: dict): 16 | tuner = tune.Tuner( 17 | tune.with_resources(train_function, resources=resources), 18 | param_space=param_space, 19 | tune_config=tune.TuneConfig(**tune_config), 20 | run_config=ray.air.RunConfig(local_dir="ray_results", callbacks=[CSVLoggerCallback()]), 21 | ) 22 | 23 | results = tuner.fit() 24 | project_name = tune_config.get("project_name", "sweep") 25 | 26 | log_trials( 27 | tuner._local_tuner.get_experiment_checkpoint_dir(), 28 | project_name, 29 | ) 30 | 31 | create_report( 32 | project_name, 33 | param_space, 34 | tune_config, 35 | Path(tuner._local_tuner.get_experiment_checkpoint_dir()).stem, 36 | results.get_best_result().config, 37 | ) 38 | 39 | print("Best hyperparameters found were: ", results.get_best_result().config) 40 | 41 | 42 | if __name__ == "__main__": 43 | parser =
argparse.ArgumentParser() 44 | parser.add_argument("script", type=str, help="Path to the script") 45 | parser.add_argument( 46 | "--config", 47 | type=str, 48 | required=True, 49 | help="The config file defining the param_space.", 50 | ) 51 | parser.add_argument("--num-cpus", type=int, default=4, help="Number of CPUs to use per exp.") 52 | parser.add_argument("--num-gpus", type=int, default=1, help="Number of GPUs to use per exp.") 53 | parser.add_argument("-y", "--assume-yes", action="store_true", help="Don't ask for confirmation") 54 | parser.add_argument( 55 | "--server-address", 56 | type=str, 57 | default=None, 58 | required=False, 59 | help="The address of server to connect to if using Ray Client.", 60 | ) 61 | 62 | args, _ = parser.parse_known_args() 63 | 64 | # Read config and parse it 65 | with open(args.config) as f: 66 | config = yaml.safe_load(f) 67 | tune_config = get_tune_config(config.pop("tune_config")) 68 | param_space = get_param_space(config) 69 | 70 | # Initialize Ray. 71 | if args.server_address: 72 | ray.init(address=f"ray://{args.server_address}") 73 | else: 74 | ray.init() 75 | 76 | resources = { 77 | "cpu": args.num_cpus, 78 | "gpu": args.num_gpus, 79 | } 80 | 81 | print(f'WARNING: Importing main from "{args.script}" and everything along with it') 82 | 83 | if not args.assume_yes: 84 | print("Please confirm y/n: ", end="") 85 | if input() != "y": 86 | print("Exiting") 87 | exit(1) 88 | 89 | # convert a nested path to a module path 90 | script_path = args.script.replace(".py", "").replace("/", ".") 91 | script = importlib.import_module(script_path) 92 | # Register the training function that will be used for training the model. 93 | tune.register_trainable("train_function", script.main) 94 | tune_function(script.main, param_space, tune_config, resources) 95 | 96 | # Shut down Ray. 97 | ray.shutdown() 98 | -------------------------------------------------------------------------------- /trlx/data/ilql_types.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, fields 2 | 3 | from torchtyping import TensorType # type: ignore 4 | 5 | 6 | def flatten_dataclass(cls: type): 7 | """Return a function that flattens a dataclass into a list""" 8 | cls_fields = [f.name for f in fields(cls)] 9 | return lambda x: [getattr(x, f) for f in cls_fields] 10 | 11 | 12 | def unflatten_dataclass(cls: type): 13 | """Return a function that unflattens a list into a dataclass""" 14 | cls_fields = [f.name for f in fields(cls)] 15 | return lambda x: cls(**dict(zip(cls_fields, x))) 16 | 17 | 18 | @dataclass 19 | class ILQLElement: 20 | """ 21 | Data element for ILQL 22 | 23 | :param input_ids: Input tokens. Should be a long tensor. 24 | :type input_ids: torch.Tensor 25 | 26 | :param attention_mask: Attention mask. Should be a long tensor. 27 | :type attention_mask: torch.Tensor 28 | 29 | :param rewards: Rewards for each token. Should be a float tensor of same size as tokens. 30 | :type rewards: torch.Tensor 31 | """ 32 | 33 | input_ids: TensorType["query_size"] 34 | attention_mask: TensorType["query_size"] 35 | rewards: TensorType["reward_size"] 36 | states_ixs: TensorType["states_size"] 37 | actions_ixs: TensorType["reward_size"] 38 | dones: TensorType["states_size"] 39 | 40 | 41 | @dataclass 42 | class ILQLSeq2SeqElement: 43 | """ 44 | Data element for ILQL 45 | 46 | :param input_ids: Input tokens. Should be a long tensor. 47 | :type input_ids: torch.Tensor 48 | 49 | :param attention_mask: Attention mask. 
Should be a long tensor. 50 | :type attention_mask: torch.Tensor 51 | 52 | :param rewards: Rewards for each token. Should be a float tensor of same size as tokens. 53 | :type rewards: torch.Tensor 54 | """ 55 | 56 | input_ids: TensorType["query_size"] 57 | attention_mask: TensorType["query_size"] 58 | decoder_input_ids: TensorType["reward_size"] 59 | rewards: TensorType["reward_size"] 60 | states_ixs: TensorType["states_size"] 61 | actions_ixs: TensorType["reward_size"] 62 | dones: TensorType["states_size"] 63 | 64 | 65 | @dataclass 66 | class ILQLBatch: 67 | """ 68 | Batched ILQL data elements 69 | 70 | :param input_ids: Batch of input tokens. 71 | :type input_ids: torch.Tensor 72 | 73 | :param attention_mask: Batch of attention masks. 74 | :type attention_mask: torch.Tensor 75 | 76 | :param rewards: Batch of rewards for each token in each token batch. 77 | :type rewards: torch.Tensor 78 | """ 79 | 80 | input_ids: TensorType["batch_size", "query_size"] 81 | attention_mask: TensorType["batch_size", "query_size"] 82 | rewards: TensorType["batch_size", "reward_size"] 83 | states_ixs: TensorType["batch_size", "states_size"] 84 | actions_ixs: TensorType["batch_size", "reward_size"] 85 | dones: TensorType["batch_size", "states_size"] 86 | 87 | 88 | @dataclass 89 | class ILQLSeq2SeqBatch: 90 | """ 91 | Batched ILQL data elements 92 | 93 | :param input_ids: Batch of input tokens. 94 | :type input_ids: torch.Tensor 95 | 96 | :param attention_mask: Batch of attention masks. 97 | :type attention_mask: torch.Tensor 98 | 99 | :param rewards: Batch of rewards for each token in each token batch. 100 | :type rewards: torch.Tensor 101 | """ 102 | 103 | input_ids: TensorType["batch_size", "query_size"] 104 | attention_mask: TensorType["batch_size", "query_size"] 105 | decoder_input_ids: TensorType["batch_size", "reward_size"] 106 | rewards: TensorType["batch_size", "reward_size"] 107 | states_ixs: TensorType["batch_size", "states_size"] 108 | actions_ixs: TensorType["batch_size", "reward_size"] 109 | dones: TensorType["batch_size", "states_size"] 110 | -------------------------------------------------------------------------------- /examples/reward_model/reward_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from transformers import AutoModelForCausalLM, AutoTokenizer 4 | 5 | 6 | class GPTRewardModel(nn.Module): 7 | def __init__(self, model_path): 8 | super().__init__() 9 | model = AutoModelForCausalLM.from_pretrained(model_path) 10 | self.config = model.config 11 | # `gpt-neo(x)` models use `hidden_size` attribute names instead of `n_embd`` 12 | self.config.n_embd = self.config.hidden_size if hasattr(self.config, "hidden_size") else self.config.n_embd 13 | self.transformer = model.transformer 14 | self.v_head = nn.Linear(self.config.n_embd, 1, bias=False) 15 | self.tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") 16 | self.tokenizer.pad_token = self.tokenizer.eos_token 17 | self.PAD_ID = self.tokenizer(self.tokenizer.pad_token)["input_ids"][0] 18 | 19 | def forward( 20 | self, 21 | input_ids=None, 22 | past_key_values=None, 23 | attention_mask=None, 24 | token_type_ids=None, 25 | position_ids=None, 26 | head_mask=None, 27 | inputs_embeds=None, 28 | mc_token_ids=None, 29 | labels=None, 30 | return_dict=False, 31 | output_attentions=False, 32 | output_hidden_states=False, 33 | ): 34 | loss = None 35 | transformer_outputs = self.transformer( 36 | input_ids, 37 | past_key_values=past_key_values, 38 | 
attention_mask=attention_mask, 39 | token_type_ids=token_type_ids, 40 | position_ids=position_ids, 41 | head_mask=head_mask, 42 | inputs_embeds=inputs_embeds, 43 | ) 44 | 45 | hidden_states = transformer_outputs[0] 46 | 47 | rewards = self.v_head(hidden_states).squeeze(-1) 48 | chosen_end_scores = [] 49 | rejected_end_scores = [] 50 | 51 | # Split the inputs and rewards into two parts, chosen and rejected 52 | assert len(input_ids.shape) == 2 53 | bs = input_ids.shape[0] // 2 54 | chosen = input_ids[:bs] 55 | rejected = input_ids[bs:] 56 | chosen_rewards = rewards[:bs] 57 | rejected_rewards = rewards[bs:] 58 | 59 | loss = 0 60 | inference = False 61 | for i in range(bs): 62 | if torch.all(torch.eq(chosen[i], rejected[i])).item(): 63 | c_inds = (chosen[i] == self.PAD_ID).nonzero() 64 | c_ind = c_inds[0].item() if len(c_inds) > 0 else chosen.shape[1] 65 | chosen_end_scores.append(chosen_rewards[i, c_ind - 1]) 66 | inference = True 67 | continue 68 | 69 | # Check if there is any padding otherwise take length of sequence 70 | c_inds = (chosen[i] == self.PAD_ID).nonzero() 71 | c_ind = c_inds[0].item() if len(c_inds) > 0 else chosen.shape[1] 72 | r_inds = (rejected[i] == self.PAD_ID).nonzero() 73 | r_ind = r_inds[0].item() if len(r_inds) > 0 else rejected.shape[1] 74 | end_ind = max(c_ind, r_ind) 75 | 76 | # Retrieve first index where trajectories diverge 77 | divergence_ind = (chosen[i] != rejected[i]).nonzero()[0] 78 | assert divergence_ind > 0 79 | 80 | # Index into the correct rewards 81 | c_truncated_reward = chosen_rewards[i][divergence_ind:end_ind] 82 | r_truncated_reward = rejected_rewards[i][divergence_ind:end_ind] 83 | 84 | # Append the last rewards to the list of end scores 85 | chosen_end_scores.append(c_truncated_reward[-1]) 86 | rejected_end_scores.append(r_truncated_reward[-1]) 87 | 88 | # Compute loss based on truncated rewards (ignore padding) 89 | loss += -torch.log(torch.sigmoid(c_truncated_reward - r_truncated_reward)).mean() 90 | loss = loss / bs 91 | 92 | if not inference: 93 | chosen_end_scores = torch.stack(chosen_end_scores) 94 | rejected_end_scores = torch.stack(rejected_end_scores) 95 | 96 | if inference: 97 | chosen_end_scores = torch.stack(chosen_end_scores) 98 | return {"chosen_end_scores": chosen_end_scores} 99 | 100 | return { 101 | "loss": loss, 102 | "chosen_end_scores": chosen_end_scores, 103 | "rejected_end_scores": rejected_end_scores, 104 | } 105 | -------------------------------------------------------------------------------- /examples/ilql_hh.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import sys 4 | 5 | from datasets import load_dataset 6 | from ppo_hh import create_reward_fn 7 | from datasets import load_from_disk 8 | import random 9 | import numpy as np 10 | import torch 11 | import trlx 12 | from trlx.data.default_configs import ( 13 | ILQLConfig, 14 | ModelConfig, 15 | OptimizerConfig, 16 | SchedulerConfig, 17 | TokenizerConfig, 18 | TrainConfig, 19 | TRLConfig, 20 | ) 21 | RANDOM_SEED = 0 22 | MODEL_SIZE = "125M" 23 | OUTPUT_DIR = "./output" 24 | 25 | 26 | random.seed(RANDOM_SEED) 27 | np.random.seed(RANDOM_SEED) 28 | torch.manual_seed(RANDOM_SEED) 29 | torch.cuda.manual_seed(RANDOM_SEED) 30 | 31 | default_config = TRLConfig( 32 | train=TrainConfig( 33 | seq_length=1000, 34 | batch_size=1, 35 | epochs=100, 36 | total_steps=20000, 37 | checkpoint_interval=1000, 38 | eval_interval=1000, 39 | pipeline="PromptPipeline", 40 | trainer="AccelerateILQLTrainer", 41 | 
checkpoint_dir="checkpoints/ilql_hh", 42 | seed=RANDOM_SEED, 43 | ), 44 | model=ModelConfig(model_path="EleutherAI/gpt-j-6B", num_layers_unfrozen=2), 45 | tokenizer=TokenizerConfig(tokenizer_path="EleutherAI/gpt-j-6B", truncation_side="left"), 46 | optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=1e-6, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)), 47 | scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=1000000000, eta_min=1e-6)), 48 | method=ILQLConfig( 49 | name="ilqlconfig", 50 | tau=0.6, 51 | gamma=0.99, 52 | cql_scale=0.1, 53 | awac_scale=1, 54 | alpha=0.0001, 55 | beta=0, 56 | steps_for_target_q_sync=1, 57 | two_qs=True, 58 | gen_kwargs=dict(max_new_tokens=128, top_k=20, beta=[1, 4], temperature=1.0), 59 | ), 60 | ) 61 | 62 | 63 | 64 | 65 | 66 | config_name = MODEL_SIZE 67 | if config_name == "125M": 68 | default_config.train.batch_size = 4 69 | default_config.model.model_path = "Dahoas/pythia-125M-static-sft" 70 | default_config.tokenizer.tokenizer_path = "EleutherAI/gpt-neox-20b" 71 | elif config_name == "1B": 72 | default_config.train.batch_size = 1 73 | default_config.model.model_path = "Dahoas/pythia-1B-static-sft" 74 | default_config.tokenizer.tokenizer_path = "EleutherAI/gpt-neox-20b" 75 | elif config_name == "6B": 76 | default_config.train.batch_size = 1 77 | default_config.model.model_path = "Dahoas/pythia-6B-static-sft" 78 | default_config.tokenizer.tokenizer_path = "EleutherAI/gpt-neox-20b" 79 | elif config_name == "20B": 80 | default_config.train.batch_size = 1 81 | default_config.train.total_steps = 3000 82 | default_config.train.checkpoint_dir = "checkpoints/ilql_hh_20B" 83 | default_config.model.model_path = "EleutherAI/gpt-neox-20b" 84 | default_config.tokenizer.tokenizer_path = "EleutherAI/gpt-neox-20b" 85 | 86 | reward_fn = create_reward_fn() 87 | def preprocess(sample): 88 | sample["prompt_output"] = [ 89 | [sample["prompt"], sample["chosen"]], 90 | [sample["prompt"], sample["rejected"]], 91 | ] 92 | prompt_res = [ 93 | sample["prompt"] + sample["chosen"], 94 | sample["prompt"] + sample["rejected"], 95 | ] 96 | reward = reward_fn(prompt_res, sample["prompt"], sample["rejected"]) 97 | 98 | sample["reward"] = reward 99 | return sample 100 | 101 | 102 | def main(hparams={}): 103 | output_dir = OUTPUT_DIR 104 | config = TRLConfig.update(default_config, hparams) 105 | config.train.rollout_logging_dir = output_dir 106 | config.train.checkpoint_dir = output_dir 107 | config.train.logging_dir = output_dir 108 | config.train.tracker = "tensorboard" 109 | 110 | dataset = load_dataset("Dahoas/rm-static").map(preprocess) 111 | prompts_outputs = sum(dataset["train"]["prompt_output"], []) 112 | 113 | rewards = sum(dataset["train"]["reward"], []) 114 | eval_prompts = [prompt_output[0][0] for prompt_output in dataset["test"]["prompt_output"]][:280] 115 | 116 | trlx.train( 117 | samples=prompts_outputs, 118 | rewards=rewards, 119 | config=config, 120 | eval_prompts=eval_prompts, 121 | metric_fn=lambda **kwargs: {"reward": reward_fn(**kwargs)}, 122 | stop_sequences=["Human:", "human:", "Assistant:", "assistant:"], 123 | ) 124 | 125 | 126 | if __name__ == "__main__": 127 | hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1]) 128 | main(hparams) 129 | -------------------------------------------------------------------------------- /examples/reward_model/gptj_reward_test.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | from datasets import 
load_dataset 6 | from reward_model import GPTRewardModel 7 | from torch.utils.data import Dataset 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer 10 | 11 | 12 | def set_seed(seed_val=42): 13 | random.seed(seed_val) 14 | np.random.seed(seed_val) 15 | torch.manual_seed(seed_val) 16 | torch.cuda.manual_seed_all(seed_val) 17 | 18 | 19 | def create_comparison_dataset(path="CarperAI/openai_summarize_comparisons", split="train"): 20 | dataset = load_dataset(path, split=split) 21 | if split == "test": 22 | dataset = dataset.select(range(5000)) 23 | 24 | pairs = [] 25 | for sample in tqdm(dataset): 26 | pair = {} 27 | prompt = sample["prompt"] 28 | chosen_summary = sample["chosen"] 29 | rejected_summary = sample["rejected"] 30 | if chosen_summary == rejected_summary: 31 | continue 32 | if len(chosen_summary.split()) < 5 or len(rejected_summary.split()) < 5: 33 | continue 34 | pair["chosen"] = prompt + "\n" + chosen_summary 35 | pair["rejected"] = prompt + "\n" + rejected_summary 36 | pairs.append(pair) 37 | return pairs 38 | 39 | 40 | class PairwiseDataset(Dataset): 41 | def __init__(self, pairs, tokenizer, max_length): 42 | self.chosen_input_ids = [] 43 | self.chosen_attn_masks = [] 44 | self.rejected_input_ids = [] 45 | self.rejected_attn_masks = [] 46 | for pair in pairs: 47 | chosen, rejected = pair["chosen"], pair["rejected"] 48 | chosen_encodings_dict = tokenizer( 49 | "<|startoftext|>" + chosen + "<|endoftext|>", 50 | truncation=True, 51 | max_length=max_length, 52 | padding="max_length", 53 | return_tensors="pt", 54 | ) 55 | rejected_encodings_dict = tokenizer( 56 | "<|startoftext|>" + rejected + "<|endoftext|>", 57 | truncation=True, 58 | max_length=max_length, 59 | padding="max_length", 60 | return_tensors="pt", 61 | ) 62 | self.chosen_input_ids.append(chosen_encodings_dict["input_ids"]) 63 | self.chosen_attn_masks.append(chosen_encodings_dict["attention_mask"]) 64 | self.rejected_input_ids.append(rejected_encodings_dict["input_ids"]) 65 | self.rejected_attn_masks.append(rejected_encodings_dict["attention_mask"]) 66 | 67 | def __len__(self): 68 | return len(self.chosen_input_ids) 69 | 70 | def __getitem__(self, idx): 71 | return ( 72 | self.chosen_input_ids[idx], 73 | self.chosen_attn_masks[idx], 74 | self.rejected_input_ids[idx], 75 | self.rejected_attn_masks[idx], 76 | ) 77 | 78 | 79 | class DataCollatorReward: 80 | def __call__(self, data): 81 | batch = {} 82 | batch["input_ids"] = torch.cat([f[0] for f in data] + [f[2] for f in data]) 83 | batch["attention_mask"] = torch.cat([f[1] for f in data] + [f[3] for f in data]) 84 | batch["labels"] = torch.tensor([0] * len(data) + [1] * len(data)) 85 | return batch 86 | 87 | 88 | if __name__ == "__main__": 89 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") 90 | tokenizer.pad_token = tokenizer.eos_token 91 | PAD_ID = tokenizer(tokenizer.pad_token)["input_ids"][0] 92 | 93 | model = GPTRewardModel("CarperAI/openai_summarize_tldr_sft") 94 | model.load_state_dict(torch.load("rm_checkpoint/pytorch_model.bin")) 95 | max_length = 550 96 | val_pairs = create_comparison_dataset("CarperAI/openai_summarize_comparisons", "test") 97 | dev_dataset = PairwiseDataset(val_pairs, tokenizer, max_length=max_length) 98 | 99 | from torch.utils.data import DataLoader 100 | 101 | dev_dataloader = DataLoader(dev_dataset, shuffle=False, batch_size=6, collate_fn=DataCollatorReward()) 102 | model.cuda() 103 | model.eval() 104 | model.half() 105 | correct = 0 106 | chosen_list = [] 107 | reject_list = [] 108 | with torch.no_grad(): 
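# The evaluation loop below scores each chosen/rejected pair with the reward model and
# accumulates pairwise accuracy: the fraction of pairs for which the chosen summary
# receives a higher end score than the rejected one.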
109 | for step, batch in tqdm(enumerate(dev_dataloader), total=len(dev_dataloader)): 110 | for x in batch: 111 | batch[x] = batch[x].cuda() 112 | outputs = model(**batch) 113 | correct += sum(outputs["chosen_end_scores"] > outputs["rejected_end_scores"]) 114 | chosen_list.append(outputs["chosen_end_scores"].cpu()) 115 | reject_list.append(outputs["rejected_end_scores"].cpu()) 116 | print("Total accuracy: ", correct / len(dev_dataset)) 117 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Fine-tuning language models with Advantage-Induced Policy Alignment 2 | This repo contains the official implementation of the paper "Fine-tuning language models with Advantage-Induced Policy Alignment", by Banghua Zhu, Hiteshi Sharma, Felipe Vieira Frujeri, Shi Dong, Chenguang Zhu, Michael I. Jordan, Jiantao Jiao. 3 | 4 | 5 | ### Abstract 6 | Reinforcement learning from human feedback (RLHF) has emerged as a reliable approach to aligning large language models (LLMs) to human preferences. Among the plethora of RLHF techniques, proximal policy optimization (PPO) is one of the most widely used methods. Despite its popularity, however, PPO may suffer from mode collapse, instability, and poor sample efficiency. We show that these issues can be alleviated by a novel algorithm that we refer to as Advantage-Induced Policy Alignment (APA), which leverages a squared error loss function based on the estimated advantages. We demonstrate empirically that APA consistently outperforms PPO in language tasks by a large margin, when a separate reward model is employed as the evaluator. 7 | In addition, compared with PPO, APA offers a more stable form of control over the deviation from the model's initial policy, ensuring that the model improves its performance without collapsing to deterministic output. 8 | In addition to empirical results, we also provide a theoretical justification supporting the design of our loss function. 9 | 10 | 11 | ### Getting Started 12 | Python 3 is required for the current codebase. It's recommended to use Python 3.9 for installing the dependencies. Due to the current [Ray support issue](https://github.com/ray-project/ray/issues/33232), Python 3.11 may give errors during execution. 13 | 14 | Install the dependencies as follows. 15 | 16 | ```shell 17 | pip install -r requirements.txt 18 | pip install -e . 19 | ``` 20 | 21 | To reproduce the experiments in the paper, execute the following commands: 22 | ```shell 23 | ## For running APA or AWR on HH dataset 24 | accelerate launch --config_file configs/accelerate/zero2-bf16.yaml examples/hh/apa_hh.py 25 | 26 | ## For running PPO on HH dataset 27 | accelerate launch --config_file configs/accelerate/zero2-bf16.yaml examples/hh/ppo_hh.py 28 | 29 | ## For running APA or AWR on TLDR dataset 30 | accelerate launch --config_file configs/accelerate/zero2-bf16.yaml examples/hh/apa_tldr.py 31 | 32 | ## For running PPO on TLDR dataset 33 | accelerate launch --config_file configs/accelerate/zero2-bf16.yaml examples/hh/ppo_tldr.py 34 | 35 | ## For running offline ILQL on HH dataset 36 | accelerate launch --config_file configs/accelerate/zero2-bf16.yaml examples/hh/ilql_hh.py 37 | 38 | ## For running offline APA or AWR on HH dataset 39 | accelerate launch --config_file configs/accelerate/zero2-bf16.yaml examples/hh/apa_off_hh.py 40 | ``` 41 | 42 | Inside each code file, one may adjust the random seed, model size, and algorithm (an illustrative sketch of the APA "square" loss and the AWR "log" loss is included at the end of this README).
Note that this code is not optimized for memory usage; it serves only as a preliminary illustration of the differences between existing policy iteration algorithms for RLHF. The code has been tested on 4 and 8 V100 GPUs for the 125M and 1B models, and on 4 MI200 GPUs for the 6B models. We place the reference model, reward model, and value model on three different GPUs. With fewer GPUs, you may need to change the device numbers in accelerate_sppo_trainer.py (and the other corresponding accelerator files). 43 | 44 | 45 | ### Acknowledgement 46 | Our codebase is built on a stable version of [CarperAI/trlX](https://github.com/CarperAI/trlx). We thank the authors for the nicely organized code! 47 | 48 | 49 | ### Contributing 50 | 51 | This project welcomes contributions and suggestions. Most contributions require you to agree to a 52 | Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us 53 | the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com. 54 | 55 | When you submit a pull request, a CLA bot will automatically determine whether you need to provide 56 | a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions 57 | provided by the bot. You will only need to do this once across all repos using our CLA. 58 | 59 | This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/). 60 | For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or 61 | contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments. 62 | 63 | ### Trademarks 64 | 65 | This project may contain trademarks or logos for projects, products, or services. Authorized use of Microsoft 66 | trademarks or logos is subject to and must follow 67 | [Microsoft's Trademark & Brand Guidelines](https://www.microsoft.com/en-us/legal/intellectualproperty/trademarks/usage/general). 68 | Use of Microsoft trademarks or logos in modified versions of this project must not cause confusion or imply Microsoft sponsorship. 69 | Any use of third-party trademarks or logos is subject to those third parties' policies.
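The example scripts switch between `LOSS = "square"` (APA) and `LOSS = "log"` (AWR), matching the squared-error loss on estimated advantages described in the abstract. The snippet below is a rough, illustrative sketch of the two per-token objectives; the function name and the way `adv_coeff_sq` / `adv_coeff_log` map onto the temperature are assumptions for illustration only, and the exact losses used in this repo live in `trlx/models/modeling_sppo.py` and the corresponding accelerate trainer files, which may differ in details such as clipping and normalization.

```python
import torch


def policy_loss_sketch(logprobs, ref_logprobs, advantages, loss_str="square",
                       adv_coeff_sq=1.0, adv_coeff_log=1.0):
    """Illustrative sketch only: APA ("square") and AWR ("log") policy losses.

    logprobs     -- log pi_theta(a|s) of the sampled tokens under the current policy
    ref_logprobs -- log pi_init(a|s) under the frozen reference policy
    advantages   -- advantage estimates A(s, a) from the value head
    """
    if loss_str == "square":
        # APA: squared error pulling log pi_theta towards log pi_init + coeff * A,
        # i.e. towards an advantage-tilted version of the reference policy.
        target = ref_logprobs + adv_coeff_sq * advantages
        return ((logprobs - target) ** 2).mean()
    if loss_str == "log":
        # AWR: advantage-weighted log-likelihood with exponentiated, clipped weights.
        weights = torch.exp(adv_coeff_log * advantages).clamp(max=100.0).detach()
        return -(weights * logprobs).mean()
    raise ValueError(f"unknown loss_str: {loss_str}")
```

In both cases the advantages come from the value head and the reference log-probabilities from the frozen initial policy, which is why the note above places the reference, reward, and value models on separate GPUs.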
70 | -------------------------------------------------------------------------------- /tests/test_trainers.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | import unittest 4 | from typing import List, Mapping 5 | 6 | import trlx.utils.logging as logging 7 | from trlx.data.configs import ( 8 | ModelConfig, 9 | OptimizerConfig, 10 | SchedulerConfig, 11 | TokenizerConfig, 12 | TrainConfig, 13 | TRLConfig, 14 | ) 15 | from trlx.models.modeling_ppo import PPOConfig 16 | from trlx.utils.loading import get_pipeline, get_trainer 17 | 18 | logging.disable_progress_bar() 19 | logging.set_verbosity(logging.ERROR) 20 | 21 | 22 | def get_default_train_and_eval_prompts() -> Mapping[str, List[str]]: 23 | return dict( 24 | train=[ 25 | "The quick brown fox jumps over the lazy", 26 | "The cat sat on the mat next to the", 27 | "What sort of food does a", 28 | "The nextdoor neighbor's fence couldn't keep the", 29 | "When Tom got home from work he had to walk his", 30 | ], 31 | eval=[ 32 | "I purchased a collar for my new", 33 | "I couldn't help but laugh when the mailman was chased by the", 34 | ], 35 | ) 36 | 37 | 38 | def get_default_reward_fn(): 39 | def reward_fn(samples: List[str], **kwargs): 40 | return [sample.count("dog") for sample in samples] 41 | 42 | return reward_fn 43 | 44 | 45 | class TestAccelerateBaseTrainer(unittest.TestCase): 46 | def setUp(self) -> None: 47 | super().setUp() 48 | self.prompt_dataset = get_default_train_and_eval_prompts() 49 | 50 | @classmethod 51 | def get_default_config(cls): 52 | return TRLConfig( 53 | train=TrainConfig( 54 | seq_length=16, 55 | epochs=1, 56 | total_steps=8, 57 | batch_size=2, 58 | checkpoint_interval=4, 59 | checkpoint_dir="checkpoints", 60 | eval_interval=8, 61 | pipeline="PromptPipeline", 62 | trainer="AcceleratePPOTrainer", 63 | tracker=None, 64 | ), 65 | model=ModelConfig(model_path="gpt2", num_layers_unfrozen=2), 66 | tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"), 67 | optimizer=OptimizerConfig( 68 | name="adamw", kwargs=dict(lr=1.0e-4, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6) 69 | ), 70 | scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=1.0e-4)), 71 | method=PPOConfig( 72 | name="PPOConfig", 73 | num_rollouts=128, 74 | chunk_size=128, 75 | ppo_epochs=4, 76 | init_kl_coef=0.05, 77 | target=6, 78 | horizon=10000, 79 | gamma=1, 80 | lam=0.95, 81 | cliprange=0.2, 82 | cliprange_value=0.2, 83 | vf_coef=1, 84 | scale_reward="ignored", 85 | ref_mean=None, 86 | ref_std=None, 87 | cliprange_reward=10, 88 | gen_kwargs=dict( 89 | max_new_tokens=6, 90 | top_k=0, 91 | top_p=1.0, 92 | do_sample=True, 93 | ), 94 | ), 95 | ) 96 | 97 | def get_trainer(self, config: TRLConfig): 98 | trainer = get_trainer(config.train.trainer)( 99 | config=config, 100 | reward_fn=get_default_reward_fn(), 101 | metric_fn=None, 102 | stop_sequences=None, 103 | **config.train.trainer_kwargs, 104 | ) 105 | 106 | max_prompt_length = config.train.seq_length - config.method.gen_kwargs["max_new_tokens"] 107 | train_pipeline = get_pipeline(config.train.pipeline)( 108 | self.prompt_dataset["train"], max_prompt_length, trainer.tokenizer 109 | ) 110 | trainer.add_prompt_pipeline(train_pipeline) 111 | trainer.make_experience(config.method.num_rollouts) 112 | 113 | eval_pipeline = get_pipeline(config.train.pipeline)( 114 | self.prompt_dataset["eval"], max_prompt_length, trainer.tokenizer 115 | ) 116 | trainer.add_eval_pipeline(eval_pipeline) 117 | 
return trainer 118 | 119 | def test_save_checkpoint(self): 120 | with tempfile.TemporaryDirectory() as tmpdir: 121 | config = self.get_default_config() 122 | config.train.checkpoint_dir = tmpdir 123 | 124 | trainer = self.get_trainer(config) 125 | trainer.learn() 126 | 127 | total_steps = config.train.total_steps 128 | interval = config.train.checkpoint_interval 129 | for i in range(interval, total_steps + 1, interval): 130 | checkpoint_dir = os.path.join(tmpdir, f"checkpoint_{i}") 131 | self.assertTrue(os.path.isdir(checkpoint_dir)) 132 | if total_steps % interval != 0: 133 | self.assertTrue(os.path.isdir(os.path.join(tmpdir, f"checkpoint_{total_steps}"))) 134 | self.assertTrue(os.path.isdir(os.path.join(tmpdir, "best_checkpoint"))) 135 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import accelerate 4 | import pytest 5 | import torch 6 | import transformers 7 | 8 | import trlx.utils as utils 9 | import trlx.utils.modeling as modeling_utils 10 | 11 | try: 12 | import bitsandbytes 13 | 14 | HAS_BNB = True 15 | except ImportError: 16 | HAS_BNB = False 17 | 18 | 19 | # Test general utils 20 | 21 | 22 | @pytest.mark.parametrize( 23 | "optimizer_name", 24 | [o.value for o in utils.OptimizerName], 25 | ) 26 | def test_optimizer_class_getters(optimizer_name: str): 27 | try: 28 | _class = utils.get_optimizer_class(optimizer_name) 29 | except Exception as e: 30 | assert False, "Failed to get optimizer class with error: " + str(e) 31 | 32 | # Hard-check for one of the optimizers 33 | _class = utils.get_optimizer_class("adamw") 34 | assert _class == torch.optim.AdamW 35 | if HAS_BNB: 36 | _bnb_class = utils.get_optimizer_class("adamw_8bit_bnb") 37 | assert _bnb_class == bitsandbytes.optim.AdamW8bit 38 | 39 | 40 | @pytest.mark.parametrize( 41 | "scheduler_name", 42 | [o.value for o in utils.SchedulerName], 43 | ) 44 | def test_scheduler_class_getters(scheduler_name: str): 45 | try: 46 | _class = utils.get_scheduler_class(scheduler_name) 47 | except Exception as e: 48 | assert False, "Failed to get scheduler class with error: " + str(e) 49 | 50 | # Hard-check for one of the schedulers 51 | _class = utils.get_scheduler_class("cosine_annealing") 52 | assert _class == torch.optim.lr_scheduler.CosineAnnealingLR 53 | 54 | 55 | # Test modeling utils 56 | 57 | 58 | @pytest.mark.parametrize( 59 | "model_name", 60 | [ 61 | "EleutherAI/gpt-j-6B", 62 | "EleutherAI/gpt-neox-20b", 63 | "gpt2", 64 | "facebook/opt-1.3b", 65 | ], 66 | ) 67 | def test_hf_attr_getters(model_name: str): 68 | with accelerate.init_empty_weights(): 69 | config = transformers.AutoConfig.from_pretrained(model_name) 70 | arch = transformers.AutoModelForCausalLM.from_config(config) 71 | 72 | arch_getters = [ 73 | modeling_utils.hf_get_decoder, 74 | modeling_utils.hf_get_decoder_final_norm, 75 | modeling_utils.hf_get_decoder_blocks, 76 | modeling_utils.hf_get_lm_head, 77 | ] 78 | for get in arch_getters: 79 | try: 80 | get(arch) 81 | except Exception as e: 82 | assert False, "Failed to get model attribute with error: " + str(e) 83 | 84 | config_getters = [ 85 | modeling_utils.hf_get_hidden_size, 86 | modeling_utils.hf_get_num_hidden_layers, 87 | ] 88 | for get in config_getters: 89 | try: 90 | get(config) 91 | except Exception as e: 92 | assert False, "Failed to get config attribute with error: " + str(e) 93 | 94 | 95 | @pytest.mark.parametrize( 96 | "model_name", 97 | [ 98 | 
"EleutherAI/gpt-j-6B", 99 | "EleutherAI/gpt-neox-20b", 100 | "facebook/opt-1.3b", 101 | "bigscience/bloom-560m", 102 | "google/flan-t5-large", 103 | ], 104 | ) 105 | def test_parse_delta_kwargs(model_name): 106 | config = transformers.AutoConfig.from_pretrained(model_name) 107 | 108 | modified_modules_dict = modeling_utils.MODIFIED_MODULES_DICT[config.model_type] 109 | for default_modifier, default_modified_modules in modified_modules_dict.items(): 110 | delta_type, delta_kwargs = modeling_utils.parse_delta_kwargs( 111 | delta_kwargs={"delta_type": "lora", "modified_modules": default_modifier}, 112 | config=config, 113 | num_layers_unfrozen=4, 114 | ) 115 | # Ensure the parsed module regex patterns capture the default module names 116 | for kwarg_mod, default_mod in zip(delta_kwargs["modified_modules"], default_modified_modules): 117 | assert kwarg_mod.endswith( 118 | default_mod 119 | ), f"Parsed modified module `{kwarg_mod}` should contain the trlx default `{default_mod}`" 120 | assert delta_type == "lora", "Delta type should be lora" 121 | 122 | # Ensure the defaults don't get used if the user specifies a list of `modified_modules` 123 | delta_type, delta_kwargs = modeling_utils.parse_delta_kwargs( 124 | delta_kwargs={"delta_type": "lora", "modified_modules": ["a", "b"]}, 125 | config=config, 126 | num_layers_unfrozen=2, 127 | ) 128 | for kwarg_mod in delta_kwargs["modified_modules"]: 129 | assert kwarg_mod.endswith("a") or kwarg_mod.endswith("b"), "Parsed modified module should contain ['a', 'b']" 130 | 131 | 132 | class TestStatistics(unittest.TestCase): 133 | @classmethod 134 | def setUpClass(cls): 135 | cls.m = modeling_utils.RunningMoments() 136 | cls.a1 = torch.arange(100, dtype=float) 137 | cls.a2 = torch.ones(100, dtype=float) 138 | cls.a3 = torch.exp(torch.arange(10, dtype=float)) 139 | cls.a4 = torch.tensor([-10, -1, 0, 1, 10], dtype=float) 140 | 141 | def test_running_moments(self): 142 | assert torch.isclose(self.m.update(self.a1)[1], self.a1.std(unbiased=True), atol=1e-6) 143 | assert torch.isclose(self.m.update(self.a2)[1], self.a2.std(unbiased=True), atol=1e-6) 144 | assert torch.isclose(self.m.update(self.a3)[1], self.a3.std(unbiased=True), atol=1e-6) 145 | assert torch.isclose(self.m.update(self.a4)[1], self.a4.std(unbiased=True), atol=1e-6) 146 | 147 | a = torch.hstack((self.a1, self.a2, self.a3, self.a4)) 148 | assert torch.isclose(self.m.mean, a.mean(), atol=1e-6) 149 | assert torch.isclose(self.m.std, a.std(unbiased=True), atol=1e-6) 150 | -------------------------------------------------------------------------------- /trlx/trlx.py: -------------------------------------------------------------------------------- 1 | import os 2 | import warnings 3 | from typing import Callable, Dict, Iterable, List, Optional, Tuple 4 | 5 | from trlx.data.configs import TRLConfig 6 | from trlx.data.default_configs import ( 7 | default_ilql_config, 8 | default_ppo_config, 9 | default_sft_config, 10 | ) 11 | from trlx.utils import set_seed 12 | from trlx.utils.loading import get_pipeline, get_trainer 13 | 14 | 15 | def train( # noqa: C901 16 | model_path: Optional[str] = None, 17 | reward_fn: Optional[Callable[[List[str], List[str], List[str]], List[float]]] = None, 18 | dataset: Optional[Iterable[Tuple[str, float]]] = None, 19 | samples: Optional[List[str]] = None, 20 | rewards: Optional[List[float]] = None, 21 | prompts: Optional[List[str]] = None, 22 | eval_prompts: Optional[List[str]] = None, 23 | metric_fn: Optional[Callable[[List[str], List[str], List[str]], Dict[str, 
List[float]]]] = None, 24 | config: Optional[TRLConfig] = None, 25 | stop_sequences: Optional[List[str]] = [], 26 | ): 27 | """ 28 | Dispatches online or offline reinforcement training, or supervised finetuning, 29 | depending on whether a reward function, a list of samples & rewards, or only a list of samples is given 30 | 31 | Args: 32 | model_path (Optional[str]): Path to either a Hugging Face checkpoint or a local directory 33 | config (Optional[TRLConfig]): TRLX configuration object 34 | reward_fn (Optional[Callable[[List[str], List[str], List[str]], List[float]]]): 35 | Function to rate batches of generated samples. Its arguments are 36 | (`samples`, `prompts`, `outputs`) and the return is a list of `rewards` 37 | dataset (List[Union[str, List[str]]], List[float]): 38 | Lists of samples and rewards for offline training. (Use `samples` and `rewards` instead) 39 | samples (List[Union[str, List[str]]]): 40 | List of strings or a list of prompts (questions or environment states) and outputs which are 41 | meant to be optimized. In the latter case the following form is expected: 42 | (prompt_0: str, output_0: str, prompt_1: str, output_1: str ...). 43 | Giving a single string `s` for the sample is a shorthand for (`tokenizer.bos_token`, `s`) 44 | rewards (List[float]): 45 | List of real numbers measuring the goodness of each sample 46 | prompts (List[str]): Prompts to use for generations during online training 47 | eval_prompts (List[str]): Prompts to use for periodic validation of training 48 | metric_fn (Optional[Callable[[List[str], List[str], List[str]], Dict[str, List[float]]]]): 49 | Function to compute statistics on batches of generated samples. Its arguments are the same 50 | as in `reward_fn` (`samples`, `prompts`, `outputs`) but the return is a dictionary with metric 51 | names as keys and lists of numeric values (one per sample in the batch) as values 52 | stop_sequences (Optional[List[str]]): 53 | String sequences at which to trim generations (both during experience generation and evaluation). 54 | Generations will not contain them and will also be right-stripped 55 | """ 56 | if config is None: 57 | warnings.warn( 58 | "Passing the `config` argument implicitly is deprecated, use or " 59 | "adapt some from `trlx/data/default_configs.py` instead" 60 | ) 61 | if reward_fn: 62 | config = default_ppo_config() 63 | elif rewards: 64 | config = default_ilql_config() 65 | else: 66 | config = default_sft_config() 67 | 68 | set_seed(config.train.seed) 69 | 70 | if dataset: 71 | warnings.warn("the `dataset` argument is being deprecated, split it into `samples` and `rewards` instead") 72 | samples, rewards = dataset 73 | 74 | if model_path: 75 | config.model.model_path = model_path 76 | 77 | trainer = get_trainer(config.train.trainer)( 78 | config=config, 79 | reward_fn=reward_fn, 80 | metric_fn=metric_fn, 81 | stop_sequences=stop_sequences, 82 | **config.train.trainer_kwargs, 83 | ) 84 | 85 | batch_size = config.train.batch_size * int(os.environ.get("WORLD_SIZE", 1)) 86 | max_prompt_length = config.train.seq_length - config.method.gen_kwargs["max_new_tokens"] 87 | 88 | # Online training against a reward function (e.g.
PPO) 89 | if reward_fn: 90 | prompts = prompts or [trainer.tokenizer.bos_token] * batch_size 91 | 92 | if eval_prompts is None: 93 | eval_prompts = prompts[:batch_size] 94 | 95 | pipeline = get_pipeline(config.train.pipeline)(prompts, max_prompt_length, trainer.tokenizer) 96 | trainer.add_prompt_pipeline(pipeline) 97 | 98 | if eval_prompts is None: 99 | eval_prompts = prompts[:batch_size] 100 | 101 | trainer.make_experience(config.method.num_rollouts) 102 | 103 | # Offline training from the collected samples (e.g. SFT, ILQL) 104 | elif samples: 105 | if rewards: 106 | if len(samples) != len(rewards): 107 | raise ValueError(f"Number of samples {len(samples)} should match the number of rewards {len(rewards)}") 108 | 109 | if eval_prompts is None: 110 | eval_prompts = [trainer.tokenizer.bos_token] * batch_size 111 | 112 | if rewards: 113 | trainer.make_experience(samples, rewards, config.train.seq_length) 114 | else: 115 | trainer.store = get_pipeline(config.train.pipeline)(samples, max_prompt_length, trainer.tokenizer) 116 | 117 | else: 118 | raise ValueError("Either `samples` or `reward_fn` should be given for training") 119 | 120 | eval_pipeline = get_pipeline(config.train.pipeline)(eval_prompts, max_prompt_length, trainer.tokenizer) 121 | trainer.add_eval_pipeline(eval_pipeline) 122 | 123 | trainer.learn() 124 | return trainer 125 | -------------------------------------------------------------------------------- /trlx/data/default_configs.py: -------------------------------------------------------------------------------- 1 | from trlx.models.modeling_ilql import ILQLConfig 2 | from trlx.models.modeling_ppo import PPOConfig 3 | from trlx.models.modeling_sppo import SPPOConfig 4 | from trlx.trainer.accelerate_sft_trainer import SFTConfig 5 | 6 | from .configs import ( 7 | ModelConfig, 8 | OptimizerConfig, 9 | SchedulerConfig, 10 | TokenizerConfig, 11 | TrainConfig, 12 | TRLConfig, 13 | ) 14 | 15 | 16 | def default_sppo_config(): 17 | return TRLConfig( 18 | train=TrainConfig( 19 | seq_length=1024, 20 | epochs=100, 21 | total_steps=10000, 22 | batch_size=32, 23 | checkpoint_interval=10000, 24 | eval_interval=100, 25 | pipeline="PromptPipeline", 26 | trainer="AccelerateSPPOTrainer", 27 | ), 28 | model=ModelConfig(model_path="lvwerra/gpt2-imdb", num_layers_unfrozen=2), 29 | tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"), 30 | optimizer=OptimizerConfig( 31 | name="adamw", kwargs=dict(lr=1.0e-4, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6) 32 | ), 33 | scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=1.0e-4)), 34 | method=SPPOConfig( 35 | name="PPOConfig", 36 | num_rollouts=128, 37 | chunk_size=128, 38 | ppo_epochs=4, 39 | init_kl_coef=0.05, 40 | target=6, 41 | horizon=10000, 42 | gamma=1, 43 | lam=0.95, 44 | cliprange=0.2, 45 | cliprange_value=0.2, 46 | vf_coef=1, 47 | scale_reward="ignored", 48 | ref_mean=None, 49 | ref_std=None, 50 | cliprange_reward=10, 51 | gen_kwargs=dict( 52 | max_new_tokens=40, 53 | top_k=0, 54 | top_p=1.0, 55 | do_sample=True, 56 | ), 57 | ), 58 | ) 59 | 60 | 61 | def default_ppo_config(): 62 | return TRLConfig( 63 | train=TrainConfig( 64 | seq_length=1024, 65 | epochs=100, 66 | total_steps=10000, 67 | batch_size=32, 68 | checkpoint_interval=10000, 69 | eval_interval=100, 70 | pipeline="PromptPipeline", 71 | trainer="AcceleratePPOTrainer", 72 | ), 73 | model=ModelConfig(model_path="lvwerra/gpt2-imdb", num_layers_unfrozen=2), 74 | tokenizer=TokenizerConfig(tokenizer_path="gpt2", 
truncation_side="right"), 75 | optimizer=OptimizerConfig( 76 | name="adamw", kwargs=dict(lr=1.0e-4, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6) 77 | ), 78 | scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=1.0e-4)), 79 | method=PPOConfig( 80 | name="PPOConfig", 81 | num_rollouts=128, 82 | chunk_size=128, 83 | ppo_epochs=4, 84 | init_kl_coef=0.05, 85 | target=6, 86 | horizon=10000, 87 | gamma=1, 88 | lam=0.95, 89 | cliprange=0.2, 90 | cliprange_value=0.2, 91 | vf_coef=1, 92 | scale_reward="ignored", 93 | ref_mean=None, 94 | ref_std=None, 95 | cliprange_reward=10, 96 | gen_kwargs=dict( 97 | max_new_tokens=40, 98 | top_k=0, 99 | top_p=1.0, 100 | do_sample=True, 101 | ), 102 | ), 103 | ) 104 | 105 | 106 | 107 | def default_ilql_config(): 108 | return TRLConfig( 109 | train=TrainConfig( 110 | seq_length=64, 111 | batch_size=32, 112 | epochs=100, 113 | total_steps=1000, 114 | checkpoint_interval=1000, 115 | eval_interval=100, 116 | pipeline="PromptPipeline", 117 | trainer="AccelerateILQLTrainer", 118 | ), 119 | model=ModelConfig(model_path="gpt2", num_layers_unfrozen=-1), 120 | tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"), 121 | optimizer=OptimizerConfig( 122 | name="adamw", kwargs=dict(lr=5.0e-5, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6) 123 | ), 124 | scheduler=SchedulerConfig( 125 | name="cosine_annealing", kwargs=dict(T_max=1000, eta_min=5.0e-5) # train.total_steps 126 | ), 127 | method=ILQLConfig( 128 | name="ilqlconfig", 129 | tau=0.7, 130 | gamma=0.99, 131 | cql_scale=0.1, 132 | awac_scale=1, 133 | alpha=0.001, 134 | beta=0, 135 | steps_for_target_q_sync=5, 136 | two_qs=True, 137 | gen_kwargs=dict(max_new_tokens=56, top_k=20, beta=4, temperature=1.0), 138 | ), 139 | ) 140 | 141 | 142 | def default_sft_config(): 143 | return TRLConfig( 144 | train=TrainConfig( 145 | seq_length=1024, 146 | epochs=100, 147 | total_steps=1000, 148 | batch_size=8, 149 | checkpoint_interval=10000, 150 | eval_interval=100, 151 | pipeline="PromptPipeline", 152 | trainer="AccelerateSFTTrainer", 153 | ), 154 | model=ModelConfig(model_path="gpt2", num_layers_unfrozen=-1), 155 | tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"), 156 | optimizer=OptimizerConfig( 157 | name="adamw", kwargs=dict(lr=1.0e-4, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6) 158 | ), 159 | scheduler=SchedulerConfig( 160 | name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=1.0e-4) # train.total_steps 161 | ), 162 | method=SFTConfig( 163 | name="sftconfig", 164 | gen_kwargs=dict(max_new_tokens=40, top_k=0, top_p=1.0, do_sample=True), 165 | ), 166 | ) 167 | -------------------------------------------------------------------------------- /examples/reward_model/train_reward_model_gptj.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from datasets import load_dataset 5 | from reward_model import GPTRewardModel 6 | from torch.utils.data import Dataset 7 | from tqdm import tqdm 8 | from transformers import AutoTokenizer, Trainer, TrainingArguments 9 | import wandb 10 | 11 | def create_comparison_dataset(path="CarperAI/openai_summarize_comparisons", split="train"): 12 | dataset = load_dataset(path, split=split) 13 | pairs = [] 14 | for sample in tqdm(dataset): 15 | pair = {} 16 | prompt = sample["prompt"] 17 | chosen_summary = sample["chosen"] 18 | rejected_summary = sample["rejected"] 19 | if chosen_summary == rejected_summary: 20 | continue 21 | if 
len(chosen_summary.split()) < 5 or len(rejected_summary.split()) < 5: 22 | continue 23 | pair["chosen"] = prompt + "\n" + chosen_summary 24 | pair["rejected"] = prompt + "\n" + rejected_summary 25 | pairs.append(pair) 26 | return pairs 27 | 28 | 29 | class PairwiseDataset(Dataset): 30 | def __init__(self, pairs, tokenizer, max_length): 31 | self.chosen_input_ids = [] 32 | self.chosen_attn_masks = [] 33 | self.rejected_input_ids = [] 34 | self.rejected_attn_masks = [] 35 | for pair in tqdm(pairs): 36 | chosen, rejected = pair["chosen"], pair["rejected"] 37 | if chosen != rejected: 38 | chosen_encodings_dict = tokenizer( 39 | "<|startoftext|>" + chosen + "<|endoftext|>", 40 | truncation=True, 41 | max_length=max_length, 42 | padding="max_length", 43 | return_tensors="pt", 44 | ) 45 | rejected_encodings_dict = tokenizer( 46 | "<|startoftext|>" + rejected + "<|endoftext|>", 47 | truncation=True, 48 | max_length=max_length, 49 | padding="max_length", 50 | return_tensors="pt", 51 | ) 52 | self.chosen_input_ids.append(chosen_encodings_dict["input_ids"]) 53 | self.chosen_attn_masks.append(chosen_encodings_dict["attention_mask"]) 54 | self.rejected_input_ids.append(rejected_encodings_dict["input_ids"]) 55 | self.rejected_attn_masks.append(rejected_encodings_dict["attention_mask"]) 56 | 57 | def __len__(self): 58 | return len(self.chosen_input_ids) 59 | 60 | def __getitem__(self, idx): 61 | return ( 62 | self.chosen_input_ids[idx], 63 | self.chosen_attn_masks[idx], 64 | self.rejected_input_ids[idx], 65 | self.rejected_attn_masks[idx], 66 | ) 67 | 68 | 69 | class DataCollatorReward: 70 | def __call__(self, data): 71 | batch = {} 72 | batch["input_ids"] = torch.cat([f[0] for f in data] + [f[2] for f in data]) 73 | batch["attention_mask"] = torch.cat([f[1] for f in data] + [f[3] for f in data]) 74 | batch["labels"] = torch.tensor([0] * len(data) + [1] * len(data)) 75 | return batch 76 | 77 | 78 | def compute_metrics(eval_preds): 79 | chosen_end_scores = eval_preds.predictions[0] # chosen scores 80 | rejected_end_scores = eval_preds.predictions[1] # rejected scores 81 | 82 | result = {} 83 | acc = sum(chosen_end_scores > rejected_end_scores) / len(rejected_end_scores) 84 | result["accuracy"] = acc 85 | 86 | return result 87 | 88 | 89 | if __name__ == "__main__": 90 | tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") 91 | tokenizer.pad_token = tokenizer.eos_token 92 | output_dir = os.path.join(os.getenv('AMLT_LOGS_DIR', os.getenv('AZUREML_CR_HT_CAP_logs_PATH'))) 93 | wandb.init(name='RLHF rl gptj summarization', project='clausa-trlx', tags=['rlhf'], mode='offline', dir=output_dir) 94 | 95 | if not os.path.exists("/mnt/data/users/v-banghua/rm_checkpoint"): 96 | os.mkdir("/mnt/data/users/v-banghua/rm_checkpoint") 97 | 98 | training_args = TrainingArguments( 99 | output_dir="/mnt/data/users/v-banghua/rm_checkpoint/", 100 | num_train_epochs=5, 101 | logging_steps=10, 102 | gradient_accumulation_steps=4, 103 | save_strategy="steps", 104 | evaluation_strategy="steps", 105 | per_device_train_batch_size=1, 106 | per_device_eval_batch_size=1, 107 | eval_accumulation_steps=1, 108 | eval_steps=500, 109 | save_steps=500, 110 | warmup_steps=100, 111 | logging_dir="./logs", 112 | fp16=True, 113 | bf16=False, 114 | learning_rate=1e-5, 115 | deepspeed="examples/summarize_rlhf/reward_model/ds_config_gpt_j.json", 116 | save_total_limit=1, 117 | ) 118 | 119 | # Initialize the reward model from the (supervised) fine-tuned GPT-J 120 | model = GPTRewardModel("CarperAI/openai_summarize_tldr_sft") 121 | 122 | # 
Freeze the first 70% of the hidden layers of the reward model backbone 123 | layers = model.transformer.h 124 | num_layers = len(layers) 125 | num_unfrozen = int(0.3 * num_layers) 126 | for layer in layers[:-num_unfrozen]: 127 | layer.requires_grad_(False) 128 | 129 | # Create the comparisons datasets 130 | data_path = "Dahoas/full-hh-rlhf" # "CarperAI/openai_summarize_comparisons" 131 | train_pairs = create_comparison_dataset(data_path, "train") 132 | val_pairs = create_comparison_dataset(data_path, "test") 133 | 134 | # Make pairwise datasets for training 135 | max_length = 550 136 | train_dataset = PairwiseDataset(train_pairs, tokenizer, max_length=max_length) 137 | val_dataset = PairwiseDataset(val_pairs, tokenizer, max_length=max_length) 138 | 139 | # Create the collator to gather batches of pairwise comparisons 140 | data_collator = DataCollatorReward() 141 | 142 | Trainer( 143 | model=model, 144 | args=training_args, 145 | train_dataset=train_dataset, 146 | compute_metrics=compute_metrics, 147 | eval_dataset=val_dataset, 148 | data_collator=data_collator, 149 | ).train() 150 | -------------------------------------------------------------------------------- /examples/apa_off_hh.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import os 4 | import sys 5 | 6 | import numpy as np 7 | import torch 8 | import tritonclient.grpc as client_util 9 | from datasets import load_dataset 10 | from huggingface_hub import snapshot_download 11 | from ppo_hh import create_reward_fn 12 | from torch import nn 13 | from transformers import AutoModelForCausalLM, AutoTokenizer 14 | from tritonclient.utils import np_to_triton_dtype 15 | from datasets import load_from_disk 16 | 17 | import trlx 18 | from trlx.data.default_configs import ( 19 | ModelConfig, 20 | OptimizerConfig, 21 | SPPOConfig, 22 | SchedulerConfig, 23 | TokenizerConfig, 24 | TrainConfig, 25 | TRLConfig, 26 | ) 27 | import random 28 | 29 | RANDOM_SEED = 1000 30 | MODEL_SIZE = "1B" 31 | LOSS = "log" # "square" or "log", square for APA, and log for AWR 32 | ADV_COEFF_SQ = 1 33 | ADV_COEFF_LOG = 1 34 | OUTPUT_DIR = "output" 35 | 36 | random.seed(RANDOM_SEED) 37 | np.random.seed(RANDOM_SEED) 38 | torch.manual_seed(RANDOM_SEED) 39 | torch.cuda.manual_seed(RANDOM_SEED) 40 | default_config = TRLConfig( 41 | train=TrainConfig( 42 | seq_length=1000, 43 | epochs=10000, 44 | total_steps=20000, 45 | batch_size=4, 46 | checkpoint_interval=10000, 47 | eval_interval=500, 48 | pipeline="PromptPipeline", 49 | trainer="AccelerateSQLOffTrainer", 50 | checkpoint_dir="checkpoints/ppo_hh", 51 | seed = RANDOM_SEED, 52 | ), 53 | model=ModelConfig(model_path="EleutherAI/gpt-j-6B", num_layers_unfrozen=2), 54 | tokenizer=TokenizerConfig(tokenizer_path="EleutherAI/gpt-j-6B", truncation_side="left"), 55 | optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=1e-6, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)), 56 | scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=1e-6)), 57 | method=SPPOConfig( 58 | name="SPPOConfig", 59 | num_rollouts=64, 60 | chunk_size=16, 61 | ppo_epochs=2, 62 | init_kl_coef=0.05, 63 | target=6, 64 | horizon=10000, 65 | gamma=1, 66 | lam=0.95, 67 | cliprange=100, 68 | cliprange_value=100, 69 | vf_coef=1, 70 | scale_reward="running", 71 | loss_str=LOSS, 72 | adv_coeff_sq=ADV_COEFF_SQ, 73 | adv_coeff_log=ADV_COEFF_LOG, 74 | ref_mean=None, 75 | ref_std=None, 76 | cliprange_reward=100, 77 | gen_kwargs=dict( 78 | max_new_tokens=128, 79 | 
do_sample=True, 80 | ), 81 | ), 82 | ) 83 | 84 | 85 | config_name = MODEL_SIZE 86 | if config_name == "125M": 87 | default_config.train.batch_size = 4 88 | default_config.method.chunk_size = 16 89 | default_config.train.total_steps = 20000 90 | default_config.model.model_path = "Dahoas/pythia-125M-static-sft" 91 | default_config.tokenizer.tokenizer_path = "EleutherAI/gpt-neox-20b" 92 | default_config.method.num_rollouts = 128 93 | elif config_name == "1B": 94 | default_config.train.batch_size = 2 95 | default_config.train.total_steps = 20000 96 | default_config.optimizer.kwargs["lr"] = 1e-6 97 | default_config.scheduler.kwargs["eta_min"] = 1e-6 98 | default_config.model.model_path = "Dahoas/pythia-1B-static-sft" 99 | default_config.tokenizer.tokenizer_path = "EleutherAI/gpt-neox-20b" 100 | default_config.method.chunk_size = 4 101 | elif config_name == "6B": 102 | default_config.train.batch_size = 1 103 | default_config.train.total_steps = 20000 104 | default_config.model.model_path = "Dahoas/pythia-6B-static-sft" 105 | default_config.tokenizer.tokenizer_path = "EleutherAI/gpt-neox-20b" 106 | default_config.method.chunk_size = 1 107 | default_config.optimizer.kwargs["lr"] = 1e-6 108 | default_config.scheduler.kwargs["eta_min"] = 1e-6 109 | elif config_name == "20B": 110 | default_config.train.seq_length = 512 111 | default_config.train.batch_size = 1 112 | default_config.train.total_steps = 8000 113 | default_config.optimizer.kwargs["lr"] = 1e-6 114 | default_config.scheduler.kwargs["eta_min"] = 1e-6 115 | default_config.train.checkpoint_dir = "checkpoints/ppo_hh_20B" 116 | default_config.model.model_path = "EleutherAI/gpt-neox-20b" 117 | default_config.tokenizer.tokenizer_path = "EleutherAI/gpt-neox-20b" 118 | default_config.method.num_rollouts = 16 119 | default_config.method.chunk_size = 4 120 | default_config.method.ppo_epochs = 2 121 | 122 | 123 | def prepare_tensor(name: str, input): 124 | t = client_util.InferInput(name, input.shape, np_to_triton_dtype(input.dtype)) 125 | t.set_data_from_numpy(input) 126 | return t 127 | 128 | 129 | reward_fn = create_reward_fn() 130 | def preprocess(sample): 131 | sample["prompt_output"] = [ 132 | [sample["prompt"], sample["chosen"]], 133 | [sample["prompt"], sample["rejected"]], 134 | ] 135 | prompt_res = [ 136 | sample["prompt"] + sample["chosen"], 137 | sample["prompt"] + sample["rejected"], 138 | ] 139 | reward = reward_fn(prompt_res, sample["prompt"], sample["rejected"]) 140 | 141 | sample["reward"] = reward 142 | return sample 143 | 144 | def main(hparams={}): 145 | output_dir = OUTPUT_DIR 146 | config = TRLConfig.update(default_config, hparams) 147 | config.train.checkpoint_dir = output_dir 148 | config.train.logging_dir = output_dir 149 | config.train.tracker = "tensorboard" 150 | 151 | 152 | 153 | dataset = load_dataset("Dahoas/rm-static").map(preprocess) 154 | prompts_outputs = sum(dataset["train"]["prompt_output"], []) 155 | 156 | rewards = sum(dataset["train"]["reward"], []) 157 | 158 | eval_prompts = [prompt_output[0][0] for prompt_output in dataset["test"]["prompt_output"]][:280] #280 159 | 160 | 161 | trlx.train( 162 | samples=prompts_outputs, 163 | rewards=rewards, 164 | config=config, 165 | eval_prompts=eval_prompts, 166 | metric_fn=lambda **kwargs: {"reward": reward_fn(**kwargs)}, 167 | stop_sequences=["Human:", "human:", "Assistant:", "assistant:"], 168 | ) 169 | 170 | 171 | if __name__ == "__main__": 172 | hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1]) 173 | main(hparams) 174 | 
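# Illustrative usage sketch (not part of the script above; strings and numbers are made up).
# The offline trainer consumes parallel lists: each sample is a [prompt, output] pair and each
# reward is a single scalar for that pair, here produced by the pairwise reward model:
#
#   samples = [
#       ["Human: How do I bake bread?\n\nAssistant:", " Mix flour, water and yeast ..."],
#       ["Human: How do I bake bread?\n\nAssistant:", " No idea."],
#   ]
#   rewards = [1.3, -0.7]
#   trlx.train(samples=samples, rewards=rewards, config=config, ...)
#
# Hyperparameters can also be overridden without editing the file, via the JSON argument parsed
# in __main__ and merged by TRLConfig.update (the exact accepted keys depend on TRLConfig.update):
#
#   python examples/apa_off_hh.py '{"train": {"total_steps": 100}}'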
-------------------------------------------------------------------------------- /configs/nemo_configs/megatron_20b.yaml: -------------------------------------------------------------------------------- 1 | name: megatron_gpt 2 | restore_from_path: null # used when starting from a .nemo file 3 | 4 | trainer: 5 | devices: 8 6 | num_nodes: 4 7 | accelerator: gpu 8 | precision: 16 9 | logger: False # logger provided by exp_manager 10 | enable_checkpointing: False 11 | replace_sampler_ddp: False 12 | max_epochs: -1 # PTL default. In practice, max_steps will be reached first. 13 | max_steps: 200 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches 14 | log_every_n_steps: 1 15 | val_check_interval: 20 16 | # check_val_every_n_epoch: null 17 | limit_val_batches: 2 18 | limit_test_batches: 0 19 | accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models 20 | gradient_clip_val: 1.0 21 | benchmark: False 22 | 23 | exp_manager: 24 | # set this to save checkpoints 25 | explicit_log_dir: ilql_sentiments_logs 26 | exp_dir: null 27 | name: megatron_gpt_20b_ilql_sentiments 28 | create_tensorboard_logger: False 29 | create_wandb_logger: True 30 | wandb_logger_kwargs: 31 | project: trlxnemo 32 | name: megatron_gpt_20b_ilql_sentiments 33 | resume_if_exists: False 34 | resume_ignore_no_checkpoint: True 35 | # set this to save checkpoints 36 | create_checkpoint_callback: True 37 | checkpoint_callback_params: 38 | monitor: reduced_train_loss 39 | save_top_k: 1 40 | mode: min 41 | always_save_nemo: False # saves nemo file during validation, not implemented for model parallel 42 | save_nemo_on_train_end: True # not recommended when training large models on clusters with short time limits 43 | filename: 'megatron_gpt-{reduced_train_loss:.2f}-{step}-{consumed_samples}' 44 | model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}} 45 | log_step_timing: True 46 | step_timing_kwargs: 47 | sync_cuda: True 48 | buffer_size: 5 49 | 50 | model: 51 | micro_batch_size: 4 52 | global_batch_size: 512 53 | tensor_model_parallel_size: 4 54 | pipeline_model_parallel_size: 1 55 | resume_from_checkpoint: null # manually set the checkpoint file to load from 56 | # model architecture 57 | encoder_seq_length: 1024 58 | max_position_embeddings: 2048 59 | num_layers: 44 60 | hidden_size: 6144 61 | ffn_hidden_size: ${multiply:4, ${.hidden_size}} # Transformer FFN hidden size. 4 * hidden_size. 62 | num_attention_heads: 48 63 | init_method_std: 0.007 # Standard deviation of the zero mean normal distribution used for weight initialization.') 64 | hidden_dropout: 0.1 # Dropout probability for hidden state transformer. 65 | kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null 66 | apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number. 67 | layernorm_epsilon: 1e-5 68 | make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency. 69 | pre_process: True # add embedding 70 | post_process: True # add pooler 71 | persist_layer_norm: True # Use of persistent fused layer norm kernel. 
72 | grad_div_ar_fusion: True # Fuse grad division into torch.distributed.all_reduce 73 | gradient_accumulation_fusion: True # Fuse weight gradient accumulation to GEMMs 74 | 75 | 76 | ## Activation Checkpointing 77 | activations_checkpoint_granularity: 'selective' #'selective' # 'selective' or 'full' 78 | activations_checkpoint_method: 'uniform' # 'uniform', 'block', not used with 'selective' 79 | activations_checkpoint_num_layers: null # not used with 'selective' 80 | 81 | ## Sequence Parallelism 82 | sequence_parallel: True 83 | 84 | tokenizer: 85 | library: 'megatron' 86 | type: 'GPT2BPETokenizer' 87 | model: null 88 | vocab_file: null 89 | merge_file: null 90 | delimiter: null # only used for tabular tokenizer 91 | sentencepiece_legacy: false # Legacy=True allows you to add special tokens to sentencepiece tokenizers. 92 | 93 | # precision 94 | native_amp_init_scale: 4294967296 # 2 ** 32 95 | native_amp_growth_interval: 1000 96 | hysteresis: 2 # Gradient scale hysteresis 97 | fp32_residual_connection: False # Move residual connections to fp32 98 | fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16 99 | 100 | # Megatron O2-style half-precision 101 | # TODO: this causes hangs for some reason 102 | megatron_amp_O2: True # Enable O2-level automatic mixed precision using main parameters 103 | grad_allreduce_chunk_size_mb: 125 104 | sync_batch_comm: False 105 | # miscellaneous 106 | seed: 1234 107 | use_cpu_initialization: False # Init weights on the CPU (slow for large models) 108 | onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter. 109 | apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this 110 | gradient_as_bucket_view: True # PyTorch DDP argument. 
Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) 111 | 112 | data: 113 | data_prefix: 114 | - dataset: hh 115 | index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix 116 | data_impl: mmap 117 | splits_string: 900,50,50 118 | seq_length: ${model.encoder_seq_length} 119 | skip_warmup: True 120 | num_workers: 2 121 | dataloader_type: cyclic 122 | reset_position_ids: False # Reset position ids after end-of-document token 123 | reset_attention_mask: False # Reset attention mask after end-of-document token 124 | eod_mask_loss: False # Mask loss for the end of document tokens 125 | 126 | # Nsys profiling options 127 | nsys_profile: 128 | enabled: False 129 | start_step: 10 # Global batch to start profiling 130 | end_step: 10 # Global batch to end profiling 131 | ranks: [0, 4, 8, 12] # Global rank IDs to profile 132 | gen_shape: False # Generate model and kernel details including input shapes 133 | 134 | optim: 135 | name: distributed_fused_adam 136 | lr: 5.0e-5 137 | weight_decay: 1.0e-6 138 | betas: 139 | - 0.9 140 | - 0.95 141 | sched: 142 | name: CosineAnnealing 143 | max_steps: 200 144 | min_lr: 5.0e-5 145 | -------------------------------------------------------------------------------- /configs/nemo_configs/megatron_65b.yaml: -------------------------------------------------------------------------------- 1 | name: megatron_gpt 2 | restore_from_path: null # used when starting from a .nemo file 3 | 4 | trainer: 5 | devices: 8 6 | num_nodes: 16 7 | accelerator: gpu 8 | precision: bf16 9 | logger: False # logger provided by exp_manager 10 | enable_checkpointing: False 11 | replace_sampler_ddp: False 12 | max_epochs: -1 # PTL default. In practice, max_steps will be reached first. 13 | max_steps: 1000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches 14 | log_every_n_steps: 1 15 | val_check_interval: 10 16 | limit_val_batches: 0.0 17 | limit_test_batches: 500 18 | accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models 19 | gradient_clip_val: 1.0 20 | benchmark: False 21 | 22 | exp_manager: 23 | explicit_log_dir: null 24 | exp_dir: null 25 | name: megatron_gpt_70b 26 | create_wandb_logger: True 27 | wandb_logger_kwargs: 28 | project: trlx 29 | name: ilql_sentiments_70b 30 | resume_if_exists: True 31 | resume_ignore_no_checkpoint: True 32 | create_checkpoint_callback: False 33 | checkpoint_callback_params: 34 | monitor: val_loss 35 | save_top_k: 1 36 | mode: min 37 | always_save_nemo: False # saves nemo file during validation, not implemented for model parallel 38 | save_nemo_on_train_end: False # not recommended when training large models on clusters with short time limits 39 | filename: 'megatron_gpt--{val_loss:.2f}-{step}-{consumed_samples}' 40 | model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}} 41 | log_step_timing: True 42 | step_timing_kwargs: 43 | sync_cuda: True 44 | buffer_size: 5 45 | 46 | model: 47 | micro_batch_size: 8 48 | global_batch_size: 128 #2048 49 | tensor_model_parallel_size: 8 50 | pipeline_model_parallel_size: 4 #2 51 | resume_from_checkpoint: null # manually set the checkpoint file to load from 52 | 53 | # model architecture 54 | encoder_seq_length: 2048 55 | max_position_embeddings: 2048 56 | num_layers: 80 57 | hidden_size: 8192 58 | ffn_hidden_size: ${multiply:4, ${.hidden_size}} # Transformer FFN hidden size. 
4 * hidden_size. 59 | num_attention_heads: 128 60 | init_method_std: 0.007 # Standard deviation of the zero mean normal distribution used for weight initialization.') 61 | hidden_dropout: 0.1 # Dropout probability for hidden state transformer. 62 | kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null 63 | apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number. 64 | layernorm_epsilon: 1e-5 65 | make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency. 66 | pre_process: True # add embedding 67 | post_process: True # add pooler 68 | persist_layer_norm: True # Use of persistent fused layer norm kernel. 69 | grad_div_ar_fusion: True # Fuse grad division into torch.distributed.all_reduce 70 | gradient_accumulation_fusion: True # Fuse weight gradient accumulation to GEMMs 71 | 72 | sync_batch_comm: True 73 | 74 | ## Activation Checkpointing 75 | activations_checkpoint_granularity: 'selective' # 'selective' or 'full' 76 | activations_checkpoint_method: 'uniform' # 'block' # 'uniform', 'block', not used with 'selective' 77 | activations_checkpoint_num_layers: 1 # 2 # not used with 'selective' 78 | 79 | ## Sequence Parallelism 80 | sequence_parallel: True 81 | 82 | tokenizer: 83 | library: 'megatron' 84 | type: 'GPT2BPETokenizer' 85 | model: null 86 | vocab_file: null 87 | merge_file: null 88 | delimiter: null # only used for tabular tokenizer 89 | sentencepiece_legacy: false # Legacy=True allows you to add special tokens to sentencepiece tokenizers. 90 | 91 | # precision 92 | native_amp_init_scale: 4294967296 # 2 ** 32 93 | native_amp_growth_interval: 1000 94 | hysteresis: 2 # Gradient scale hysteresis 95 | fp32_residual_connection: False # Move residual connections to fp32 96 | fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16 97 | 98 | # Megatron O2-style half-precision 99 | megatron_amp_O2: False # Enable O2-level automatic mixed precision using main parameters 100 | grad_allreduce_chunk_size_mb: 125 101 | 102 | # miscellaneous 103 | seed: 1234 104 | use_cpu_initialization: False # Init weights on the CPU (slow for large models) 105 | onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter. 106 | apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this 107 | gradient_as_bucket_view: True # PyTorch DDP argument. Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) 108 | 109 | data: 110 | # Path to data must be specified by the user. 
111 | # can override from the CLI: "model.data.data_prefix=[.5,/raid/data/pile/my-gpt3_00_text_document,.5,/raid/data/pile/my-gpt3_01_text_document]", 112 | # Or see example below: 113 | # data_prefix: 114 | # - .5 115 | # - /raid/data/pile/my-gpt3_00_text_document 116 | # - .5 117 | # - /raid/data/pile/my-gpt3_01_text_document 118 | data_prefix: 119 | ignored: ignored 120 | index_mapping_dir: null # path to save index mapping .npy files, by default will save in the same location as data_prefix 121 | data_impl: mmap 122 | splits_string: 900,50,50 123 | seq_length: ${model.encoder_seq_length} 124 | skip_warmup: True 125 | num_workers: 2 126 | dataloader_type: single # cyclic 127 | reset_position_ids: False # Reset position ids after end-of-document token 128 | reset_attention_mask: False # Reset attention mask after end-of-document token 129 | eod_mask_loss: False # Mask loss for the end of document tokens 130 | 131 | # Nsys profiling options 132 | nsys_profile: 133 | enabled: False 134 | start_step: 10 # Global batch to start profiling 135 | end_step: 10 # Global batch to end profiling 136 | ranks: [0, 4, 8, 12] # Global rank IDs to profile 137 | gen_shape: False # Generate model and kernel details including input shapes 138 | 139 | optim: 140 | name: distributed_fused_adam 141 | lr: 1.1e-4 142 | weight_decay: 0.1 143 | betas: 144 | - 0.9 145 | - 0.95 146 | sched: 147 | name: CosineAnnealing 148 | warmup_steps: 115 149 | constant_steps: 12500 150 | min_lr: 1.1e-5 151 | -------------------------------------------------------------------------------- /trlx/ray_tune/__init__.py: -------------------------------------------------------------------------------- 1 | from ray import tune 2 | 3 | 4 | def get_param_space(config: dict): # noqa: C901 5 | """Get the param space from the config file.""" 6 | 7 | def get_strategy(value): 8 | """Get search space strategy from config. 9 | A search space defines valid values for your hyperparameters and 10 | can specify how these values are sampled. 11 | 12 | Refer to the documentation for more info: 13 | https://docs.ray.io/en/latest/tune/api_docs/search_space.html#tune-sample-docs 14 | 15 | The user will have to define the search space in the config file by providing 16 | the name of the `strategy` and the `values` to sample from. 17 | 18 | The valid strategies are: 19 | - `uniform` (List) - Samples uniformly between the given bounds. 20 | - `quniform` (List) - Samples uniformly between the given bounds, quantized. 21 | - `loguniform` (List) - Samples uniformly between the given bounds on a log scale. 22 | - `qloguniform` (List) - Samples uniformly between the given bounds on a log scale, quantized. 23 | - `randn` (List) - Samples from a normal distribution. 24 | - `qrandn` (List) - Samples from a normal distribution, quantized. 25 | - `randint` (List) - Samples uniformly between the given bounds, quantized to integers. 26 | - `qrandint` (List) - Samples uniformly between the given bounds, quantized to integers. 27 | - `lograndint` (List) - Samples uniformly between the given bounds on a log scale, quantized to integers. 28 | - `qlograndint` (List) - Samples uniformly between the given bounds on a log scale, quantized to integers. 29 | - `choice` (List) - Samples from a discrete set of values. 30 | - `qrandn` (List) - Samples from a normal distribution, quantized. 31 | - `grid_search` (List) - Samples from the given list of values. 
32 | 33 | """ 34 | 35 | strategy = value["strategy"] 36 | if strategy == "uniform": 37 | assert isinstance(value["values"], list) 38 | assert len(value["values"]) == 2 39 | return tune.uniform(*value["values"]) 40 | elif strategy == "quniform": 41 | assert isinstance(value["values"], list) 42 | assert len(value["values"]) == 3 43 | return tune.quniform(*value["values"]) 44 | elif strategy == "loguniform": 45 | assert isinstance(value["values"], list) 46 | assert 2 <= len(value["values"]) <= 3 47 | return tune.loguniform(*value["values"]) 48 | elif strategy == "qloguniform": 49 | assert isinstance(value["values"], list) 50 | assert len(value["values"]) == 4 51 | return tune.qloguniform(*value["values"]) 52 | elif strategy == "randn": 53 | assert isinstance(value["values"], list) 54 | assert len(value["values"]) == 2 55 | return tune.randn(*value["values"]) 56 | elif strategy == "qrandn": 57 | assert isinstance(value["values"], list) 58 | assert len(value["values"]) == 3 59 | return tune.qrandn(*value["values"]) 60 | elif strategy == "randint": 61 | assert isinstance(value["values"], list) 62 | assert len(value["values"]) == 2 63 | return tune.randint(*value["values"]) 64 | elif strategy == "qrandint": 65 | assert isinstance(value["values"], list) 66 | assert len(value["values"]) == 3 67 | return tune.qrandint(*value["values"]) 68 | elif strategy == "lograndint": 69 | assert isinstance(value["values"], list) 70 | assert len(value["values"]) == 3 71 | return tune.lograndint(*value["values"]) 72 | elif strategy == "qlograndint": 73 | assert isinstance(value["values"], list) 74 | assert len(value["values"]) == 4 75 | return tune.qlograndint(*value["values"]) 76 | elif strategy == "choice": 77 | assert isinstance(value["values"], list) 78 | return tune.choice(value["values"]) 79 | elif strategy == "grid": 80 | assert isinstance(value["values"], list) 81 | return tune.grid_search(value["values"]) 82 | 83 | for k, v in config.items(): 84 | if k != "tune_config": 85 | config[k] = get_strategy(v) 86 | 87 | return config 88 | 89 | 90 | def get_search_alg(tune_config: dict): 91 | """Initialize the search algorithm and return it. 92 | 93 | Bayesian Optimization is currently supported. 94 | """ 95 | search_alg = tune_config["search_alg"] 96 | 97 | if search_alg == "bayesopt": 98 | try: 99 | from ray.tune.search.bayesopt import BayesOptSearch 100 | except ImportError: 101 | raise ImportError("Please pip install bayesian-optimization to use BayesOptSearch.") 102 | 103 | assert "metric" in tune_config.keys() and "mode" in tune_config.keys() 104 | "Please specify metric and mode for BayesOptSearch." 105 | 106 | return BayesOptSearch(metric=tune_config["metric"], mode=tune_config["mode"]) 107 | elif search_alg == "bohb": 108 | try: 109 | from ray.tune.search.bohb import TuneBOHB 110 | except ImportError: 111 | raise ImportError("Please pip install hpbandster and ConfigSpace to use TuneBOHB.") 112 | 113 | assert "metric" in tune_config.keys() and "mode" in tune_config.keys() 114 | "Please specify metric and mode for TuneBOHB." 115 | 116 | return TuneBOHB() 117 | elif search_alg == "random": 118 | return None 119 | else: 120 | NotImplementedError("Search algorithm not supported.") 121 | 122 | 123 | def get_scheduler(tune_config: dict): 124 | """Initialize the scheduler and return it. 125 | 126 | The schedulers can early terminate bad trials, pause trials, 127 | clone trials, and alter hyperparameters of a running trial. 
128 | 129 | Refer to the documentation for more info: 130 | https://docs.ray.io/en/latest/tune/api_docs/schedulers.html#tune-schedulers 131 | 132 | Currently available schedulers are: 133 | - `hyperband` - Implements the HyperBand early stopping algorithm. 134 | 135 | """ 136 | scheduler = tune_config["scheduler"] 137 | 138 | if scheduler == "hyperband": 139 | return tune.schedulers.HyperBandScheduler() 140 | elif scheduler == "hyperbandforbohb": 141 | return tune.schedulers.HyperBandForBOHB() 142 | elif scheduler == "fifo": 143 | return None 144 | else: 145 | NotImplementedError("Scheduler not supported.") 146 | 147 | 148 | def get_tune_config(tune_config: dict): 149 | """Get the tune config to initialized `tune.TuneConfig` 150 | to be passed `tune.Tuner`. 151 | """ 152 | if "search_alg" in tune_config.keys() and tune_config["search_alg"] is not None: 153 | tune_config["search_alg"] = get_search_alg(tune_config) 154 | 155 | if "scheduler" in tune_config.keys() and tune_config["scheduler"] is not None: 156 | tune_config["scheduler"] = get_scheduler(tune_config) 157 | 158 | # Remove config keys with None values. 159 | tune_config = {k: v for k, v in tune_config.items() if v is not None} 160 | 161 | return tune_config 162 | -------------------------------------------------------------------------------- /trlx/ray_tune/wandb.py: -------------------------------------------------------------------------------- 1 | """Utility function to log the results of a Ray Tune experiment to W&B.""" 2 | 3 | import json 4 | import math 5 | import os 6 | from pathlib import Path 7 | 8 | import wandb 9 | 10 | from trlx.utils import significant 11 | 12 | import wandb.apis.reports as wb # isort: skip 13 | 14 | 15 | ray_info = [ 16 | "done", 17 | "time_this_iter_s", 18 | "timesteps_total", 19 | "episodes_total", 20 | "iterations_since_restore", 21 | "timesteps_since_restore", 22 | "time_since_restore", 23 | "warmup_time", 24 | "should_checkpoint", 25 | "training_iteration", 26 | "timestamp", 27 | "pid", 28 | ] 29 | 30 | 31 | def parse_result(result): 32 | out = {} 33 | for k, v in result.items(): 34 | if isinstance(v, (int, float)) and not k.startswith("config.") and k not in ray_info: 35 | out[k] = v 36 | 37 | return out 38 | 39 | 40 | def log_trials(trial_path: str, project_name: str): 41 | trial_path = Path(trial_path) 42 | files = os.listdir(trial_path) 43 | 44 | trial_paths = [] 45 | for filename in files: 46 | tmp_path = os.path.join(trial_path, filename) 47 | if os.path.isdir(tmp_path): 48 | trial_paths.append(tmp_path) 49 | 50 | for trial in trial_paths: 51 | files = os.listdir(trial) 52 | 53 | # Open params.json and load the configs for that trial. 54 | with open(os.path.join(trial, "params.json"), "r") as f: 55 | params = json.load(f) 56 | 57 | name = ",".join(f"{k}={significant(v)}" for k, v in params.items()) 58 | # Initialize wandb 59 | run = wandb.init( 60 | name=name, 61 | project=project_name, 62 | config=params, 63 | group=trial_path.stem, 64 | job_type="hyperopt", 65 | ) 66 | 67 | # Open result.json and log the metrics to W&B. 68 | with open(os.path.join(trial, "result.json"), "r") as f: 69 | for line in f: 70 | result = json.loads(line) 71 | result.pop("config", None) 72 | wandb.log(parse_result(result)) 73 | 74 | # Close the W&B run. 
75 | run.finish() 76 | 77 | 78 | def create_report(project_name, param_space, tune_config, trial_path, best_config=None): 79 | def get_parallel_coordinate(param_space, metric): 80 | column_names = list(param_space.keys()) 81 | columns = [wb.PCColumn(column) for column in column_names] 82 | 83 | return wb.ParallelCoordinatesPlot( 84 | columns=columns + [wb.PCColumn(metric)], 85 | layout={"x": 0, "y": 0, "w": 12 * 2, "h": 5 * 2}, 86 | ) 87 | 88 | def get_param_importance(metric): 89 | return wb.ParameterImportancePlot( 90 | # Get it from the metric name. 91 | with_respect_to=metric, 92 | layout={"x": 0, "y": 5, "w": 6 * 2, "h": 4 * 2}, 93 | ) 94 | 95 | def get_scatter_plot(metric): 96 | return wb.ScatterPlot( 97 | # Get it from the metric name. 98 | title=f"{metric} v. Index", 99 | x="Index", 100 | y=metric, 101 | running_ymin=True, 102 | font_size="small", 103 | layout={"x": 6, "y": 5, "w": 6 * 2, "h": 4 * 2}, 104 | ) 105 | 106 | def get_metrics_with_history(project_name, group_name, entity=None): 107 | entity_project = f"{entity}/{project_name}" if entity else project_name 108 | api = wandb.Api() 109 | runs = api.runs(entity_project) 110 | 111 | runs = sorted( 112 | runs, 113 | key=lambda run: run.summary.get(tune_config["metric"], -math.inf), 114 | reverse=True, 115 | ) 116 | 117 | for run in runs: 118 | if run.group == str(group_name): 119 | history = run.history() 120 | metrics = history.columns 121 | break 122 | 123 | metrics = [metric for metric in metrics if not metric.startswith("_")] 124 | return metrics 125 | 126 | report = wb.Report( 127 | project=project_name, 128 | title=f"Hyperparameter Optimization Report: {trial_path}", 129 | description="This is a report that shows the results of a hyperparameter optimization experiment.", 130 | ) 131 | 132 | report.blocks = [ 133 | wb.P( 134 | "The following plots show the results of the hyperparameter optimization experiment. " 135 | "Use this as a starting point for your analysis. Go in the edit mode to customize the report. " 136 | "Share it with your team to collaborate on the analysis." 137 | ), 138 | wb.H1(text="Analysis"), 139 | wb.P( 140 | "Parallel coordinates chart (top) summarize the relationship between large numbers of hyperparameters " 141 | "and model metrics at a glance. \nThe scatter plot (right) compares the different trials and gives you a " 142 | "insight on how the trials progressed. \nThe parameter importance plot(left) lists the hyperparameters " 143 | "that were the best predictors of, and highly correlated to desirable values of your metrics." 144 | ), 145 | wb.PanelGrid( 146 | panels=[ 147 | get_parallel_coordinate(param_space, tune_config["metric"]), 148 | get_param_importance(tune_config["metric"]), 149 | get_scatter_plot(tune_config["metric"]), 150 | ], 151 | runsets=[wb.Runset(project=project_name).set_filters_with_python_expr(f'group == "{trial_path}"')], 152 | ), 153 | ] 154 | 155 | metrics = get_metrics_with_history( 156 | project_name, 157 | trial_path, 158 | ) 159 | 160 | line_plot_panels = [] 161 | for metric in metrics: 162 | line_plot_panels.append( 163 | wb.LinePlot( 164 | title=f"{metric}", 165 | x="Step", 166 | y=[f"{metric}"], 167 | title_x="Step", 168 | smoothing_show_original=True, 169 | max_runs_to_show=10, 170 | plot_type="line", 171 | font_size="auto", 172 | legend_position="north", 173 | ) 174 | ) 175 | 176 | report.blocks = report.blocks + [ 177 | wb.H1(text="Metrics"), 178 | wb.P( 179 | "The following line plots show the metrics for each trial. 
Use this to investigate the " 180 | "performance of the model for each trial at the metrics level." 181 | ), 182 | wb.PanelGrid( 183 | panels=line_plot_panels, 184 | runsets=[wb.Runset(project=project_name).set_filters_with_python_expr(f'group == "{trial_path}"')], 185 | ), 186 | ] 187 | 188 | if best_config: 189 | report.blocks = report.blocks + [ 190 | wb.H1(text="Best Config"), 191 | wb.P( 192 | "The code block shown below is the best config found by the hyperparameter " 193 | "optimization experiment according to Ray Tune." 194 | ), 195 | wb.CodeBlock(code=[json.dumps(best_config, indent=4)], language="json"), 196 | ] 197 | 198 | report.save() 199 | print(report.url) 200 | -------------------------------------------------------------------------------- /trlx/pipeline/offline_pipeline.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, List, Union 2 | 3 | import torch 4 | from torch.nn.utils.rnn import pad_sequence 5 | from torch.utils.data import DataLoader 6 | from transformers import DataCollatorWithPadding, PreTrainedTokenizer 7 | 8 | from trlx.data.ilql_types import ( 9 | ILQLBatch, 10 | ILQLElement, 11 | ILQLSeq2SeqBatch, 12 | ILQLSeq2SeqElement, 13 | ) 14 | from trlx.pipeline import BasePipeline, BaseRolloutStore, register_datapipeline 15 | 16 | 17 | def tokenize_dialogue(dialogue: Union[str, List[str]], tokenizer, max_length=2048) -> List[int]: # noqa: C901 18 | """ 19 | Tokenize sample with the interleaved form of (prompt_1, output_1, prompt_2, output_2...) 20 | """ 21 | if isinstance(dialogue, str): 22 | dialogue = [tokenizer.bos_token, dialogue] 23 | elif isinstance(dialogue, tuple): 24 | dialogue = list(dialogue) 25 | 26 | out = [] 27 | ctx_length = max_length - 1 28 | if tokenizer.truncation_side == "left": 29 | for phrase in reversed(dialogue): 30 | # Manually added BOS and EOS above so we don't want to add special tokens here 31 | tokens = tokenizer(phrase, add_special_tokens=False).input_ids[-ctx_length:] 32 | ctx_length -= len(tokens) 33 | out.insert(0, tokens) 34 | if ctx_length == 0: 35 | break 36 | 37 | # in case of odd number of phrases (possibly due to truncation) 38 | # since the first phrase always has to be a prompt, force it to be 39 | if len(out) % 2 == 1: 40 | if sum(map(len, out)) == max_length: 41 | out[0].pop(0) 42 | out.insert(0, [tokenizer.bos_token_id]) 43 | 44 | elif tokenizer.truncation_side == "right": 45 | for phrase in dialogue: 46 | # Manually added BOS and EOS above so we don't want to add special tokens here 47 | tokens = tokenizer(phrase, add_special_tokens=False).input_ids[:ctx_length] 48 | ctx_length -= len(tokens) 49 | out.append(tokens) 50 | if ctx_length == 0: 51 | break 52 | 53 | out[-1].append(tokenizer.eos_token_id) 54 | 55 | return out 56 | 57 | 58 | @register_datapipeline 59 | class PromptPipeline(BasePipeline): 60 | """ 61 | Tokenizes prompts, unless they are already tokenized, and truncates them to `max_prompt_length` from the right 62 | """ 63 | 64 | def __init__(self, prompts: List[str], max_prompt_length: int, tokenizer: PreTrainedTokenizer): 65 | super().__init__() 66 | 67 | model_inputs = tokenizer( 68 | prompts, truncation=True, padding=False, max_length=max_prompt_length, add_special_tokens=False 69 | ) 70 | 71 | prompts_tokens = model_inputs["input_ids"] 72 | attention_mask = model_inputs["attention_mask"] 73 | 74 | self.tokenizer = tokenizer 75 | self.prompts = [ 76 | {"input_ids": tokens, "attention_mask": mask} for tokens, mask in zip(prompts_tokens, 
attention_mask) 77 | ] 78 | 79 | def __getitem__(self, ix: int): 80 | return self.prompts[ix] 81 | 82 | def __len__(self) -> int: 83 | return len(self.prompts) 84 | 85 | def create_loader(self, batch_size: int, shuffle=False) -> DataLoader: 86 | collate_fn = DataCollatorWithPadding(self.tokenizer) if self.tokenizer else torch.vstack 87 | return DataLoader(self, batch_size=batch_size, collate_fn=collate_fn, shuffle=shuffle) 88 | 89 | 90 | def ilql_collate_fn(elems: Iterable[ILQLElement]): 91 | return ILQLBatch( 92 | pad_sequence([x.input_ids for x in elems], batch_first=True, padding_value=0), 93 | pad_sequence([x.attention_mask for x in elems], batch_first=True, padding_value=0), 94 | pad_sequence([x.rewards for x in elems], batch_first=True, padding_value=0.0), 95 | pad_sequence([x.states_ixs for x in elems], batch_first=True, padding_value=0), 96 | pad_sequence([x.actions_ixs for x in elems], batch_first=True, padding_value=0), 97 | pad_sequence([x.dones for x in elems], batch_first=True, padding_value=0), 98 | ) 99 | 100 | 101 | class ILQLRolloutStorage(BaseRolloutStore): 102 | """ 103 | Rollout storage for training ILQL 104 | """ 105 | 106 | def __init__(self, input_ids, attention_mask, rewards, states_ixs, actions_ixs, dones): 107 | super().__init__() 108 | 109 | self.input_ids = input_ids 110 | self.attention_mask = attention_mask 111 | self.rewards = rewards 112 | self.states_ixs = states_ixs 113 | self.actions_ixs = actions_ixs 114 | self.dones = dones 115 | 116 | def __getitem__(self, ix: int) -> ILQLElement: 117 | return ILQLElement( 118 | self.input_ids[ix], 119 | self.attention_mask[ix], 120 | self.rewards[ix], 121 | self.states_ixs[ix], 122 | self.actions_ixs[ix], 123 | self.dones[ix], 124 | ) 125 | 126 | def __len__(self) -> int: 127 | return len(self.input_ids) 128 | 129 | def create_loader(self, batch_size: int, drop_last=True): 130 | return DataLoader( 131 | self, 132 | batch_size=batch_size, 133 | shuffle=True, 134 | collate_fn=ilql_collate_fn, 135 | drop_last=drop_last, 136 | ) 137 | 138 | 139 | def ilql_seq2seq_collate_fn(elems: Iterable[ILQLElement]): 140 | return ILQLSeq2SeqBatch( 141 | pad_sequence([x.input_ids for x in elems], batch_first=True, padding_value=0), 142 | pad_sequence([x.attention_mask for x in elems], batch_first=True, padding_value=0), 143 | pad_sequence([x.decoder_input_ids for x in elems], batch_first=True, padding_value=0), 144 | pad_sequence([x.rewards for x in elems], batch_first=True, padding_value=0.0), 145 | pad_sequence([x.states_ixs for x in elems], batch_first=True, padding_value=0), 146 | pad_sequence([x.actions_ixs for x in elems], batch_first=True, padding_value=0), 147 | pad_sequence([x.dones for x in elems], batch_first=True, padding_value=0), 148 | ) 149 | 150 | 151 | class ILQLSeq2SeqRolloutStorage(BaseRolloutStore): 152 | """ 153 | Rollout storage for training ILQL 154 | """ 155 | 156 | def __init__(self, input_ids, attention_mask, decoder_input_ids, rewards, states_ixs, actions_ixs, dones): 157 | super().__init__() 158 | 159 | self.input_ids = input_ids 160 | self.attention_mask = attention_mask 161 | self.decoder_input_ids = decoder_input_ids 162 | self.rewards = rewards 163 | self.states_ixs = states_ixs 164 | self.actions_ixs = actions_ixs 165 | self.dones = dones 166 | 167 | def __getitem__(self, ix: int) -> ILQLElement: 168 | return ILQLSeq2SeqElement( 169 | self.input_ids[ix], 170 | self.attention_mask[ix], 171 | self.decoder_input_ids[ix], 172 | self.rewards[ix], 173 | self.states_ixs[ix], 174 | self.actions_ixs[ix], 175 | 
self.dones[ix], 176 | ) 177 | 178 | def __len__(self) -> int: 179 | return len(self.input_ids) 180 | 181 | def create_loader(self, batch_size: int, drop_last=True): 182 | return DataLoader( 183 | self, 184 | batch_size=batch_size, 185 | shuffle=True, 186 | collate_fn=ilql_seq2seq_collate_fn, 187 | drop_last=drop_last, 188 | ) 189 | -------------------------------------------------------------------------------- /examples/ppo_tldr.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import List 3 | 4 | import torch 5 | import accelerate 6 | from datasets import load_dataset 7 | from reward_model.reward_model import GPTRewardModel 8 | from tqdm import tqdm 9 | from transformers import AutoTokenizer 10 | 11 | import trlx 12 | from trlx.data.configs import ( 13 | ModelConfig, 14 | OptimizerConfig, 15 | SchedulerConfig, 16 | TokenizerConfig, 17 | TrainConfig, 18 | TRLConfig, 19 | ) 20 | from trlx.models.modeling_ppo import PPOConfig 21 | import wandb 22 | import os 23 | 24 | 25 | SFT_MODEL_PATH = "CarperAI/openai_summarize_tldr_sft" 26 | OUTPUT_DIR = "output" 27 | RANDOM_SEED = 1000 28 | 29 | 30 | os.environ["TOKENIZERS_PARALLELISM"] = "false" 31 | REWARD_CHECKPOINT_PATH = "reward_model/rm_checkpoint/pytorch_model.bin" 32 | if not os.path.exists(REWARD_CHECKPOINT_PATH): 33 | os.makedirs("reward_model/rm_checkpoint", exist_ok=True) 34 | os.system( 35 | f"wget -O {REWARD_CHECKPOINT_PATH} \ 36 | https://huggingface.co/CarperAI/openai_summarize_tldr_rm_checkpoint/resolve/main/pytorch_model.bin" 37 | ) 38 | 39 | config = TRLConfig( 40 | train=TrainConfig( 41 | seq_length=550, 42 | epochs=1000, 43 | total_steps=10000, 44 | batch_size=2, 45 | checkpoint_interval=1000, 46 | eval_interval=1000, 47 | pipeline="PromptPipeline", 48 | trainer="AcceleratePPOTrainer", 49 | seed=RANDOM_SEED, 50 | ), 51 | model=ModelConfig( 52 | model_path= "CarperAI/openai_summarize_tldr_sft", 53 | num_layers_unfrozen=8, 54 | ), 55 | tokenizer=TokenizerConfig( 56 | tokenizer_path="gpt2", 57 | truncation_side="right", 58 | ), 59 | optimizer = OptimizerConfig( 60 | name="adamw", 61 | kwargs={ 62 | "lr": 1.0e-6, 63 | "betas": [0.9, 0.999], 64 | "eps": 1.0e-8, 65 | "weight_decay": 0.01, 66 | }, 67 | ), 68 | scheduler=SchedulerConfig( 69 | name="cosine_annealing", 70 | kwargs={ 71 | "T_max": 100000, 72 | "eta_min": 1.0e-6, 73 | }, 74 | ), 75 | method=PPOConfig( 76 | name="PPOConfig", 77 | num_rollouts=128, 78 | chunk_size=1, 79 | ppo_epochs=2, 80 | init_kl_coef=0.1, 81 | target=6, 82 | horizon=10000, 83 | gamma=1, 84 | lam=0.95, 85 | cliprange=0.2, 86 | cliprange_value=0.2, 87 | vf_coef=0.2, 88 | scale_reward=None, 89 | ref_mean=None, 90 | ref_std=None, 91 | cliprange_reward=10, 92 | gen_kwargs=dict( 93 | max_new_tokens=50, 94 | do_sample=True, 95 | ), 96 | ), 97 | ) 98 | 99 | 100 | if __name__ == "__main__": 101 | output_dir = OUTPUT_DIR 102 | wandb.init(name='RLHF rl gptj summarization', project='clausa-trlx', tags=['rlhf'], mode='offline', dir=output_dir) 103 | config.train.rollout_logging_dir = output_dir 104 | config.train.checkpoint_dir = output_dir 105 | config.train.logging_dir = output_dir 106 | config.train.tracker = "tensorboard" 107 | # Load the pre-trained reward model 108 | rw_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B") 109 | rw_tokenizer.pad_token = rw_tokenizer.eos_token 110 | rw_model = GPTRewardModel(SFT_MODEL_PATH) 111 | rw_model.load_state_dict(torch.load(REWARD_CHECKPOINT_PATH)) 112 | rw_model.half() 113 | rw_model.eval() 114 | 
rw_device = torch.cuda.device_count() - 1 # set reward model device 115 | rw_model.to(rw_device) 116 | 117 | def get_scores(samples: List[str]): 118 | scores_list = [] 119 | batch_size = 1 120 | for i in range(0, len(samples), batch_size): 121 | sub_samples = samples[i : i + batch_size] 122 | sub_samples = ["<|startoftext|>" + chosen + "<|endoftext|>" for chosen in sub_samples] 123 | encodings_dict = rw_tokenizer( 124 | sub_samples, 125 | truncation=True, 126 | max_length=config.train.seq_length, 127 | padding="max_length", 128 | return_tensors="pt", 129 | ) 130 | input_ids = encodings_dict["input_ids"].to(rw_device) 131 | attn_masks = encodings_dict["attention_mask"].to(rw_device) 132 | input_ids = input_ids.repeat(2, 1) 133 | attn_masks = attn_masks.repeat(2, 1) 134 | with torch.no_grad(): 135 | sub_scores = rw_model(input_ids=input_ids, attention_mask=attn_masks) 136 | scores_list.append(sub_scores["chosen_end_scores"]) 137 | scores = torch.cat(scores_list, dim=0) 138 | return scores 139 | 140 | def get_prompt_dataset(prompts, max_length): 141 | """ 142 | Get the prompt after T5 decoding to make sure dictionary 143 | of prompts and summaries is consistent decode prompt from trlX pipeline 144 | """ 145 | formatted_prompts = [] 146 | for i in tqdm(range(len(prompts))): 147 | tmp = tokenizer.decode( 148 | tokenizer( 149 | prompts[i].split("TL;DR:")[0], 150 | truncation=True, 151 | max_length=max_length - 5, # to make sure "TL;DR" dont get truncated 152 | add_special_tokens=False, 153 | )["input_ids"], 154 | skip_special_tokens=True, 155 | ).strip() 156 | tmp = tmp + "\nTL;DR:" 157 | tmp = tokenizer.decode( 158 | tokenizer(tmp, truncation=True, max_length=max_length, add_special_tokens=False)["input_ids"], 159 | skip_special_tokens=True, 160 | ).strip() 161 | formatted_prompts.append(tmp) 162 | return formatted_prompts 163 | 164 | def reward_fn(samples: List[str], **kwargs): 165 | # original_samples = [text.split("TL;DR:")[0] + "TL;DR: " for text in samples] 166 | # original_samples = [text + post_summary_dict[text.strip()] for text in original_samples] 167 | # original_scores = get_scores(original_samples) 168 | scores = get_scores(samples) 169 | norms_scores = scores # - original_scores 170 | return norms_scores 171 | 172 | tokenizer = AutoTokenizer.from_pretrained(config.tokenizer.tokenizer_path) 173 | tokenizer.pad_token = tokenizer.eos_token 174 | tokenizer.padding_side = "left" 175 | max_length_input = config.train.seq_length - config.method.gen_kwargs["max_new_tokens"] 176 | 177 | dataset = load_dataset("CarperAI/openai_summarize_tldr") 178 | 179 | # Store data into prompt and label pairs 180 | train_set = [(sample["prompt"], sample["label"]) for sample in dataset["train"]] 181 | val_set = [(sample["prompt"], sample["label"]) for sample in dataset["valid"]] 182 | 183 | # Split contents into summaries and labels 184 | train_posts, train_summaries = zip(*train_set) 185 | val_posts, val_summaries = zip(*val_set) 186 | 187 | # Get the OpenAI summaries 188 | post_summary_dict = {} 189 | train_prompts = get_prompt_dataset(train_posts, max_length_input) 190 | for i in range(len(train_prompts)): 191 | post_summary_dict[train_prompts[i]] = train_summaries[i] 192 | val_prompts = get_prompt_dataset(val_posts, max_length_input) 193 | for i in range(len(val_prompts)): 194 | post_summary_dict[val_prompts[i]] = val_summaries[i] 195 | 196 | trainer = trlx.train( 197 | reward_fn=reward_fn, 198 | prompts=train_prompts, 199 | eval_prompts=val_prompts[0:1000], # sampling 1000 validation prompts for 
evaluation speed in training 200 | config=config, 201 | ) 202 | -------------------------------------------------------------------------------- /trlx/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import random 4 | import subprocess 5 | import time 6 | from dataclasses import is_dataclass 7 | from enum import Enum 8 | from itertools import repeat 9 | from numbers import Number 10 | from typing import Any, Dict, Iterable, Tuple 11 | 12 | import numpy as np 13 | import torch 14 | from accelerate import Accelerator 15 | from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR 16 | 17 | 18 | def print_rank_0(*message): 19 | """ 20 | Print only once from the main rank 21 | """ 22 | if os.environ.get("RANK", "0") == "0": 23 | print(*message) 24 | 25 | 26 | def significant(x: Number, ndigits=2) -> Number: 27 | """ 28 | Cut the number up to its `ndigits` after the most significant 29 | """ 30 | if isinstance(x, torch.Tensor): 31 | x = x.item() 32 | 33 | if not isinstance(x, Number) or math.isnan(x) or x == 0: 34 | return x 35 | 36 | return round(x, ndigits - int(math.floor(math.log10(abs(x))))) 37 | 38 | 39 | def set_seed(seed: int): 40 | """ 41 | Sets seeds across package dependencies for reproducibility. 42 | """ 43 | seed += int(os.environ.get("RANK", 0)) 44 | random.seed(seed) 45 | np.random.seed(seed) 46 | torch.manual_seed(seed) 47 | torch.cuda.manual_seed(seed) 48 | 49 | 50 | # Training utils 51 | 52 | 53 | def get_distributed_config(accelerator: Accelerator): 54 | """ 55 | Return accelerator distributed config 56 | """ 57 | 58 | dist_config = { 59 | "mixed_precision": accelerator.mixed_precision, 60 | "num_gpus": accelerator.num_processes, 61 | } 62 | 63 | if accelerator.state.deepspeed_plugin is not None: 64 | ds_plugin = accelerator.state.deepspeed_plugin 65 | dist_config.update( 66 | { 67 | "gradient_accumulation_steps": ds_plugin.gradient_accumulation_steps, 68 | "gradient_clipping": ds_plugin.gradient_clipping, 69 | "zero_stage": ds_plugin.zero_stage, 70 | "offload_optimizer_device": ds_plugin.offload_optimizer_device, 71 | "offload_param_device": ds_plugin.offload_param_device, 72 | } 73 | ) 74 | 75 | return dist_config 76 | 77 | 78 | class OptimizerName(str, Enum): 79 | """Supported optimizer names""" 80 | 81 | ADAM: str = "adam" 82 | ADAMW: str = "adamw" 83 | ADAM_8BIT_BNB: str = "adam_8bit_bnb" 84 | ADAMW_8BIT_BNB: str = "adamw_8bit_bnb" 85 | SGD: str = "sgd" 86 | 87 | 88 | def get_optimizer_class(name: OptimizerName): 89 | """ 90 | Returns the optimizer class with the given name 91 | 92 | Args: 93 | name (str): Name of the optimizer as found in `OptimizerNames` 94 | """ 95 | if name == OptimizerName.ADAM: 96 | return torch.optim.Adam 97 | if name == OptimizerName.ADAMW: 98 | return torch.optim.AdamW 99 | if name == OptimizerName.ADAM_8BIT_BNB.value: 100 | try: 101 | from bitsandbytes.optim import Adam8bit 102 | 103 | return Adam8bit 104 | except ImportError: 105 | raise ImportError( 106 | "You must install the `bitsandbytes` package to use the 8-bit Adam. " 107 | "Install with: `pip install bitsandbytes`" 108 | ) 109 | if name == OptimizerName.ADAMW_8BIT_BNB.value: 110 | try: 111 | from bitsandbytes.optim import AdamW8bit 112 | 113 | return AdamW8bit 114 | except ImportError: 115 | raise ImportError( 116 | "You must install the `bitsandbytes` package to use 8-bit AdamW. 
" 117 | "Install with: `pip install bitsandbytes`" 118 | ) 119 | if name == OptimizerName.SGD.value: 120 | return torch.optim.SGD 121 | supported_optimizers = [o.value for o in OptimizerName] 122 | raise ValueError(f"`{name}` is not a supported optimizer. " f"Supported optimizers are: {supported_optimizers}") 123 | 124 | 125 | class SchedulerName(str, Enum): 126 | """Supported scheduler names""" 127 | 128 | COSINE_ANNEALING = "cosine_annealing" 129 | LINEAR = "linear" 130 | 131 | 132 | def get_scheduler_class(name: SchedulerName): 133 | """ 134 | Returns the scheduler class with the given name 135 | """ 136 | if name == SchedulerName.COSINE_ANNEALING: 137 | return CosineAnnealingLR 138 | if name == SchedulerName.LINEAR: 139 | return LinearLR 140 | supported_schedulers = [s.value for s in SchedulerName] 141 | raise ValueError(f"`{name}` is not a supported scheduler. " f"Supported schedulers are: {supported_schedulers}") 142 | 143 | 144 | class Clock: 145 | """ 146 | Helper object for keeping track of time for computations. 147 | """ 148 | 149 | def __init__(self): 150 | self.start = time.time() 151 | self.total_time = 0 152 | self.total_samples = 0 153 | 154 | def tick(self, samples: int = 0) -> float: 155 | """ 156 | Returns time (s) since last call to tick(). Also records samples processed since last call. 157 | 158 | :param samples: number of samples that have been processed since last call 159 | """ 160 | end = time.time() 161 | delta = end - self.start 162 | self.start = end 163 | 164 | if samples != 0: 165 | self.total_time += delta 166 | self.total_samples += samples 167 | 168 | return delta 169 | 170 | def get_stat(self, n_samp: int = 1000, reset: bool = False): 171 | """ 172 | Returns average time (s) per n_samp samples processed 173 | 174 | :param reset: Reset counts? 
175 | """ 176 | sec_per_samp = self.total_time / self.total_samples 177 | 178 | if reset: 179 | self.total_samples = 0 180 | self.total_time = 0 181 | 182 | return sec_per_samp * n_samp 183 | 184 | 185 | def tree_map(f, tree: Any) -> Any: 186 | """ 187 | Apply function f to all leaves in tree 188 | """ 189 | if is_dataclass(tree): 190 | return tree.__class__(**{k: tree_map(f, v) for k, v in tree.__dict__.items()}) 191 | elif isinstance(tree, dict): 192 | return {k: tree_map(f, v) for k, v in tree.items()} 193 | elif isinstance(tree, (list, tuple)): 194 | return tree.__class__(tree_map(f, v) for v in tree) 195 | else: 196 | return f(tree) 197 | 198 | 199 | def to_device(tree, device, non_blocking=False): 200 | """ 201 | Move all tensors in tree to device 202 | """ 203 | return tree_map(lambda x: x.to(device, non_blocking=non_blocking), tree) 204 | 205 | 206 | def filter_non_scalars(xs: Dict) -> Dict: 207 | """ 208 | Trims everything that can't be casted to float 209 | """ 210 | ys = {} 211 | for k, v in xs.items(): 212 | try: 213 | ys[k] = float(v) 214 | except TypeError: 215 | continue 216 | 217 | return ys 218 | 219 | 220 | def get_git_tag() -> Tuple[str, str]: 221 | """ 222 | Returns commit's short hash and date 223 | """ 224 | try: 225 | output = subprocess.check_output("git log --format='%h/%as' -n1".split()) 226 | branch = subprocess.check_output("git rev-parse --abbrev-ref HEAD".split()) 227 | return branch.decode()[:-1], output.decode()[1:-2] 228 | except subprocess.CalledProcessError: 229 | return "unknown", "unknown" 230 | 231 | 232 | # Iter utils 233 | 234 | 235 | def infinite_dataloader(dataloader: Iterable) -> Iterable: 236 | """ 237 | Returns a cyclic infinite dataloader from a finite dataloader 238 | """ 239 | for _ in repeat(dataloader): 240 | yield from dataloader 241 | -------------------------------------------------------------------------------- /examples/apa_hh.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import os 4 | import sys 5 | 6 | import numpy as np 7 | import torch 8 | import tritonclient.grpc as client_util 9 | from datasets import load_dataset 10 | from huggingface_hub import snapshot_download 11 | from torch import nn 12 | from transformers import AutoModelForCausalLM, AutoTokenizer 13 | from tritonclient.utils import np_to_triton_dtype 14 | 15 | import trlx 16 | from trlx.data.default_configs import ( 17 | ModelConfig, 18 | OptimizerConfig, 19 | SPPOConfig, 20 | SchedulerConfig, 21 | TokenizerConfig, 22 | TrainConfig, 23 | TRLConfig, 24 | ) 25 | import random 26 | 27 | RANDOM_SEED = 0 28 | MODEL_SIZE = "125M" 29 | LOSS = "square" # "square" or "log", square for APA and log for AWR 30 | ADV_COEFF_SQ = 10 31 | ADV_COEFF_LOG = 0.5 32 | OUTPUT_DIR = "output" 33 | 34 | random.seed(RANDOM_SEED) 35 | np.random.seed(RANDOM_SEED) 36 | torch.manual_seed(RANDOM_SEED) 37 | torch.cuda.manual_seed(RANDOM_SEED) 38 | default_config = TRLConfig( 39 | train=TrainConfig( 40 | seq_length=1000, 41 | epochs=10000, 42 | total_steps=20000, 43 | batch_size=4, 44 | checkpoint_interval=1000, 45 | eval_interval=1000, 46 | pipeline="PromptPipeline", 47 | trainer="AccelerateSPPOTrainer", 48 | checkpoint_dir="checkpoints/ppo_hh", 49 | seed=RANDOM_SEED, 50 | ), 51 | model=ModelConfig(model_path="EleutherAI/gpt-j-6B", num_layers_unfrozen=2), 52 | tokenizer=TokenizerConfig(tokenizer_path="EleutherAI/gpt-j-6B", truncation_side="left"), 53 | optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=1e-6, betas=(0.9, 
0.95), eps=1.0e-8, weight_decay=1.0e-6)), 54 | scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=1e-6)), 55 | method=SPPOConfig( 56 | name="SPPOConfig", 57 | num_rollouts=64, 58 | chunk_size=16, 59 | ppo_epochs=2, 60 | init_kl_coef=0.05, 61 | target=6, 62 | horizon=10000, 63 | gamma=1, 64 | lam=0.95, 65 | cliprange=100, 66 | cliprange_value=100, 67 | vf_coef=1, 68 | scale_reward="running", 69 | ref_mean=None, 70 | ref_std=None, 71 | loss_str=LOSS, 72 | adv_coeff_sq=ADV_COEFF_SQ, 73 | adv_coeff_log=ADV_COEFF_LOG, 74 | cliprange_reward=100, 75 | gen_kwargs=dict( 76 | max_new_tokens=128, 77 | do_sample=True, 78 | ), 79 | ), 80 | ) 81 | 82 | 83 | config_name = MODEL_SIZE 84 | if config_name == "125M": 85 | default_config.train.batch_size = 8 86 | default_config.method.chunk_size = 16 87 | default_config.train.total_steps = 20000 88 | default_config.model.model_path = "Dahoas/pythia-125M-static-sft" 89 | default_config.tokenizer.tokenizer_path = "EleutherAI/gpt-neox-20b" 90 | default_config.method.num_rollouts = 128 91 | elif config_name == "1B": 92 | default_config.train.batch_size = 2 93 | default_config.train.total_steps = 20000 94 | default_config.model.model_path = "Dahoas/pythia-1B-static-sft" 95 | default_config.tokenizer.tokenizer_path = "EleutherAI/gpt-neox-20b" 96 | default_config.method.chunk_size = 4 97 | elif config_name == "6B": 98 | default_config.train.batch_size = 1 99 | default_config.train.total_steps = 20000 100 | default_config.model.model_path = "Dahoas/pythia-6B-static-sft" # "databricks/dolly-v2-7b" # 101 | default_config.tokenizer.tokenizer_path = "EleutherAI/gpt-neox-20b" # "databricks/dolly-v2-7b" # 102 | default_config.method.chunk_size = 1 103 | elif config_name == "20B": 104 | default_config.train.seq_length = 512 105 | default_config.train.batch_size = 1 106 | default_config.train.total_steps = 8000 107 | default_config.model.model_path = "EleutherAI/gpt-neox-20b" 108 | default_config.tokenizer.tokenizer_path = "EleutherAI/gpt-neox-20b" 109 | default_config.method.num_rollouts = 16 110 | default_config.method.chunk_size = 4 111 | default_config.method.ppo_epochs = 2 112 | 113 | 114 | def prepare_tensor(name: str, input): 115 | t = client_util.InferInput(name, input.shape, np_to_triton_dtype(input.dtype)) 116 | t.set_data_from_numpy(input) 117 | return t 118 | 119 | 120 | def create_reward_fn(): # noqa: C901 121 | reward_tokenizer = AutoTokenizer.from_pretrained("gpt2") 122 | reward_tokenizer.pad_token = reward_tokenizer.eos_token 123 | reward_tokenizer.truncation_side = "left" 124 | triton_host = os.environ.get("TRITON_HOST") 125 | 126 | if triton_host: 127 | triton_url, triton_model = triton_host.split("/") 128 | client = client_util.InferenceServerClient(url=triton_url, verbose=False) 129 | 130 | def reward_fn(samples, prompts, outputs): 131 | samples = [s + reward_tokenizer.eos_token for s in samples] 132 | input = reward_tokenizer(samples, padding=True, max_length=1024) 133 | 134 | mbs = 24 135 | out = [] 136 | for i in range(math.ceil(len(samples) / mbs)): 137 | batch_ixs = slice(i * mbs, (i + 1) * mbs) 138 | input_ids = np.array(input.input_ids[batch_ixs], dtype=np.int32) 139 | 140 | result = client.infer(triton_model, [prepare_tensor("input_ids", input_ids)]) 141 | rewards = result.as_numpy("rewards") 142 | out.extend(rewards) 143 | 144 | return out 145 | 146 | elif os.environ.get("RANK", "0") == "0": 147 | 148 | class RewardModel(nn.Module): 149 | def __init__(self, checkpoint_path, eos_token_id): 150 | super().__init__() 
151 | model = AutoModelForCausalLM.from_pretrained(checkpoint_path) 152 | self.transformer = model.transformer 153 | self.v_head = nn.Linear(model.config.n_embd, 1, bias=False) 154 | self.eos_token_id = eos_token_id 155 | 156 | def forward(self, input_ids): 157 | states = self.transformer(input_ids)[0] 158 | rewards = self.v_head(states).squeeze(-1) 159 | ends = torch.argmax((input_ids == self.eos_token_id).float(), dim=1).view(-1, 1) 160 | returns = torch.gather(rewards, 1, ends).squeeze(-1) 161 | return returns 162 | 163 | reward_model = RewardModel("EleutherAI/gpt-j-6B", reward_tokenizer.eos_token_id) 164 | directory = snapshot_download("Dahoas/gptj-rm-static", revision="676bfd4d") 165 | for fpath in os.listdir(directory): 166 | if fpath.endswith(".pt") or fpath.endswith(".bin"): 167 | checkpoint = os.path.join(directory, fpath) 168 | break 169 | 170 | reward_model.load_state_dict(torch.load(checkpoint)) 171 | reward_model.eval() 172 | reward_model.requires_grad_(False) 173 | device = torch.cuda.device_count() - 1 174 | reward_model = reward_model.half().to(device) 175 | 176 | def reward_fn(samples, prompts, outputs): 177 | samples = [s + reward_tokenizer.eos_token for s in samples] 178 | input = reward_tokenizer(samples, padding=True, truncation=True, max_length=1024, return_tensors="pt").to( 179 | device 180 | ) 181 | 182 | mbs = 12 183 | out = [] 184 | for i in range(math.ceil(len(samples) / mbs)): 185 | batch_ixs = slice(i * mbs, (i + 1) * mbs) 186 | input_ids = input.input_ids[batch_ixs] 187 | rewards = reward_model(input_ids) 188 | out.extend(rewards) 189 | 190 | return out 191 | 192 | else: 193 | reward_fn = True 194 | 195 | return reward_fn 196 | 197 | 198 | def main(hparams={}): 199 | output_dir = OUTPUT_DIR 200 | config = TRLConfig.update(default_config, hparams) 201 | config.train.checkpoint_dir = output_dir 202 | config.train.logging_dir = output_dir 203 | config.train.tracker = "tensorboard" 204 | 205 | dataset = load_dataset("Dahoas/rm-static") 206 | prompts = dataset["train"]["prompt"] 207 | eval_prompts = dataset["test"]["prompt"][:280] 208 | reward_fn = create_reward_fn() 209 | 210 | trlx.train( 211 | prompts=prompts, 212 | eval_prompts=eval_prompts, 213 | reward_fn=reward_fn, 214 | config=config, 215 | stop_sequences=["Human:", "human:", "Assistant:", "assistant:"], 216 | ) 217 | 218 | 219 | if __name__ == "__main__": 220 | hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1]) 221 | main(hparams) 222 | -------------------------------------------------------------------------------- /examples/ppo_hh.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import os 4 | import sys 5 | 6 | import numpy as np 7 | import torch 8 | import tritonclient.grpc as client_util 9 | from datasets import load_dataset 10 | from huggingface_hub import snapshot_download 11 | from torch import nn 12 | from transformers import AutoModelForCausalLM, AutoTokenizer 13 | from tritonclient.utils import np_to_triton_dtype 14 | 15 | import trlx 16 | from trlx.data.default_configs import ( 17 | ModelConfig, 18 | OptimizerConfig, 19 | PPOConfig, 20 | SchedulerConfig, 21 | TokenizerConfig, 22 | TrainConfig, 23 | TRLConfig, 24 | ) 25 | import random 26 | 27 | RANDOM_SEED = 100 28 | MODEL_SIZE = "125M" 29 | OUTPUT_DIR = "./output" 30 | 31 | random.seed(RANDOM_SEED) 32 | np.random.seed(RANDOM_SEED) 33 | torch.manual_seed(RANDOM_SEED) 34 | torch.cuda.manual_seed(RANDOM_SEED) 35 | default_config = TRLConfig( 36 | 
train=TrainConfig( 37 | seq_length=1024, 38 | epochs=10000, 39 | total_steps=20000, 40 | batch_size=4, 41 | checkpoint_interval=1000, 42 | eval_interval=1000, 43 | pipeline="PromptPipeline", 44 | trainer="AcceleratePPOTrainer", 45 | checkpoint_dir="checkpoints/ppo_hh", 46 | seed=RANDOM_SEED, 47 | ), 48 | model=ModelConfig(model_path="EleutherAI/gpt-j-6B", num_layers_unfrozen=2), 49 | tokenizer=TokenizerConfig(tokenizer_path="EleutherAI/gpt-j-6B", truncation_side="left"), 50 | optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=8e-6, betas=(0.9, 0.95), eps=1.0e-8, weight_decay=1.0e-6)), 51 | scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=8e-6)), 52 | method=PPOConfig( 53 | name="PPOConfig", 54 | num_rollouts=64, 55 | chunk_size=16, 56 | ppo_epochs=2, 57 | init_kl_coef=0.05, 58 | target=6, 59 | horizon=10000, 60 | gamma=1, 61 | lam=0.95, 62 | cliprange=0.2, 63 | cliprange_value=0.2, 64 | vf_coef=1, 65 | scale_reward="running", 66 | ref_mean=None, 67 | ref_std=None, 68 | cliprange_reward=10, 69 | gen_kwargs=dict( 70 | max_new_tokens=128, 71 | top_k=0, 72 | top_p=1.0, 73 | do_sample=True, 74 | ), 75 | ), 76 | ) 77 | 78 | 79 | config_name = MODEL_SIZE 80 | if config_name == "125M": 81 | default_config.train.batch_size = 8 82 | default_config.method.chunk_size = 16 83 | default_config.train.total_steps = 20000 84 | default_config.model.model_path = "Dahoas/pythia-125M-static-sft" 85 | default_config.tokenizer.tokenizer_path = "EleutherAI/gpt-neox-20b" 86 | default_config.method.num_rollouts = 128 87 | elif config_name == "1B": 88 | default_config.train.batch_size = 2 89 | default_config.train.total_steps = 20000 90 | default_config.optimizer.kwargs["lr"] = 1e-6 91 | default_config.scheduler.kwargs["eta_min"] = 1e-6 92 | default_config.model.model_path = "Dahoas/pythia-1B-static-sft" 93 | default_config.tokenizer.tokenizer_path = "EleutherAI/gpt-neox-20b" 94 | default_config.method.chunk_size = 4 95 | elif config_name == "6B": 96 | default_config.train.batch_size = 2 97 | default_config.train.total_steps = 20000 98 | default_config.model.model_path = "Dahoas/pythia-6B-static-sft" # "databricks/dolly-v2-7b" # 99 | default_config.tokenizer.tokenizer_path = "EleutherAI/gpt-neox-20b" # "databricks/dolly-v2-7b" # 100 | default_config.method.chunk_size = 1 101 | elif config_name == "20B": 102 | default_config.train.seq_length = 512 103 | default_config.train.batch_size = 1 104 | default_config.train.total_steps = 8000 105 | default_config.optimizer.kwargs["lr"] = 1e-6 106 | default_config.scheduler.kwargs["eta_min"] = 1e-6 107 | default_config.train.checkpoint_dir = "checkpoints/ppo_hh_20B" 108 | default_config.model.model_path = "EleutherAI/gpt-neox-20b" 109 | default_config.tokenizer.tokenizer_path = "EleutherAI/gpt-neox-20b" 110 | default_config.method.num_rollouts = 16 111 | default_config.method.chunk_size = 4 112 | default_config.method.ppo_epochs = 2 113 | 114 | 115 | def prepare_tensor(name: str, input): 116 | t = client_util.InferInput(name, input.shape, np_to_triton_dtype(input.dtype)) 117 | t.set_data_from_numpy(input) 118 | return t 119 | 120 | 121 | def create_reward_fn(): # noqa: C901 122 | reward_tokenizer = AutoTokenizer.from_pretrained("gpt2") 123 | reward_tokenizer.pad_token = reward_tokenizer.eos_token 124 | reward_tokenizer.truncation_side = "left" 125 | triton_host = os.environ.get("TRITON_HOST") 126 | 127 | if triton_host: 128 | triton_url, triton_model = triton_host.split("/") 129 | client = client_util.InferenceServerClient(url=triton_url, 
verbose=False) 130 | 131 | def reward_fn(samples, prompts, outputs): 132 | samples = [s + reward_tokenizer.eos_token for s in samples] 133 | input = reward_tokenizer(samples, padding=True, max_length=1024) 134 | 135 | mbs = 24 136 | out = [] 137 | for i in range(math.ceil(len(samples) / mbs)): 138 | batch_ixs = slice(i * mbs, (i + 1) * mbs) 139 | input_ids = np.array(input.input_ids[batch_ixs], dtype=np.int32) 140 | 141 | result = client.infer(triton_model, [prepare_tensor("input_ids", input_ids)]) 142 | rewards = result.as_numpy("rewards") 143 | out.extend(rewards) 144 | 145 | return out 146 | 147 | elif os.environ.get("RANK", "0") == "0": 148 | 149 | class RewardModel(nn.Module): 150 | def __init__(self, checkpoint_path, eos_token_id): 151 | super().__init__() 152 | model = AutoModelForCausalLM.from_pretrained(checkpoint_path) 153 | self.transformer = model.transformer 154 | self.v_head = nn.Linear(model.config.n_embd, 1, bias=False) 155 | self.eos_token_id = eos_token_id 156 | 157 | def forward(self, input_ids): 158 | states = self.transformer(input_ids)[0] 159 | rewards = self.v_head(states).squeeze(-1) 160 | ends = torch.argmax((input_ids == self.eos_token_id).float(), dim=1).view(-1, 1) 161 | returns = torch.gather(rewards, 1, ends).squeeze(-1) 162 | return returns 163 | 164 | reward_model = RewardModel("EleutherAI/gpt-j-6B", reward_tokenizer.eos_token_id) 165 | directory = snapshot_download("Dahoas/gptj-rm-static", revision="676bfd4d") 166 | for fpath in os.listdir(directory): 167 | if fpath.endswith(".pt") or fpath.endswith(".bin"): 168 | checkpoint = os.path.join(directory, fpath) 169 | break 170 | 171 | reward_model.load_state_dict(torch.load(checkpoint)) 172 | reward_model.eval() 173 | reward_model.requires_grad_(False) 174 | device = torch.cuda.device_count() - 1 175 | reward_model = reward_model.half().to(device) 176 | 177 | def reward_fn(samples, prompts, outputs): 178 | samples = [s + reward_tokenizer.eos_token for s in samples] 179 | input = reward_tokenizer(samples, padding=True, truncation=True, max_length=1024, return_tensors="pt").to( 180 | device 181 | ) 182 | 183 | mbs = 6 184 | out = [] 185 | for i in range(math.ceil(len(samples) / mbs)): 186 | batch_ixs = slice(i * mbs, (i + 1) * mbs) 187 | input_ids = input.input_ids[batch_ixs] 188 | rewards = reward_model(input_ids) 189 | out.extend(rewards) 190 | 191 | return out 192 | 193 | else: 194 | reward_fn = True 195 | 196 | return reward_fn 197 | 198 | 199 | def main(hparams={}): 200 | output_dir = OUTPUT_DIR 201 | config = TRLConfig.update(default_config, hparams) 202 | config.train.checkpoint_dir = output_dir 203 | config.train.logging_dir = output_dir 204 | config.train.tracker = "tensorboard" 205 | 206 | dataset = load_dataset("Dahoas/rm-static") 207 | prompts = dataset["train"]["prompt"] 208 | eval_prompts = dataset["test"]["prompt"][:280] 209 | reward_fn = create_reward_fn() 210 | 211 | trlx.train( 212 | prompts=prompts, 213 | eval_prompts=eval_prompts, 214 | reward_fn=reward_fn, 215 | config=config, 216 | stop_sequences=["Human:", "human:", "Assistant:", "assistant:"], 217 | ) 218 | 219 | 220 | if __name__ == "__main__": 221 | hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1]) 222 | main(hparams) 223 | -------------------------------------------------------------------------------- /trlx/pipeline/sql_on_pipeline.py: -------------------------------------------------------------------------------- 1 | from typing import Iterable, List, Union 2 | import os 3 | import torch 4 | import time 5 | import json 
6 | 7 | from torch.nn.utils.rnn import pad_sequence 8 | from torch.utils.data import DataLoader 9 | from transformers import DataCollatorWithPadding, PreTrainedTokenizer 10 | 11 | from trlx.data.ilql_types import ( 12 | ILQLBatch, 13 | ILQLElement, 14 | ILQLSeq2SeqBatch, 15 | ILQLSeq2SeqElement, 16 | ) 17 | from trlx.pipeline import BasePipeline, BaseRolloutStore, register_datapipeline 18 | 19 | 20 | def tokenize_dialogue(dialogue: Union[str, List[str]], tokenizer, max_length=2048) -> List[int]: # noqa: C901 21 | """ 22 | Tokenize sample with the interleaved form of (prompt_1, output_1, prompt_2, output_2...) 23 | """ 24 | if isinstance(dialogue, str): 25 | dialogue = [tokenizer.bos_token, dialogue] 26 | elif isinstance(dialogue, tuple): 27 | dialogue = list(dialogue) 28 | 29 | out = [] 30 | ctx_length = max_length - 1 31 | if tokenizer.truncation_side == "left": 32 | for phrase in reversed(dialogue): 33 | # Manually added BOS and EOS above so we don't want to add special tokens here 34 | tokens = tokenizer(phrase, add_special_tokens=False).input_ids[-ctx_length:] 35 | ctx_length -= len(tokens) 36 | out.insert(0, tokens) 37 | if ctx_length == 0: 38 | break 39 | 40 | # in case of odd number of phrases (possibly due to truncation) 41 | # since the first phrase always has to be a prompt, force it to be 42 | if len(out) % 2 == 1: 43 | if sum(map(len, out)) == max_length: 44 | out[0].pop(0) 45 | out.insert(0, [tokenizer.bos_token_id]) 46 | 47 | elif tokenizer.truncation_side == "right": 48 | for phrase in dialogue: 49 | # Manually added BOS and EOS above so we don't want to add special tokens here 50 | tokens = tokenizer(phrase, add_special_tokens=False).input_ids[:ctx_length] 51 | ctx_length -= len(tokens) 52 | out.append(tokens) 53 | if ctx_length == 0: 54 | break 55 | 56 | out[-1].append(tokenizer.eos_token_id) 57 | 58 | return out 59 | 60 | 61 | @register_datapipeline 62 | class PromptPipeline(BasePipeline): 63 | """ 64 | Tokenizes prompts, unless they are already tokenized, and truncates them to `max_prompt_length` from the right 65 | """ 66 | 67 | def __init__(self, prompts: List[str], max_prompt_length: int, tokenizer: PreTrainedTokenizer): 68 | super().__init__() 69 | 70 | model_inputs = tokenizer( 71 | prompts, truncation=True, padding=False, max_length=max_prompt_length, add_special_tokens=False 72 | ) 73 | 74 | prompts_tokens = model_inputs["input_ids"] 75 | attention_mask = model_inputs["attention_mask"] 76 | 77 | self.tokenizer = tokenizer 78 | self.prompts = [ 79 | {"input_ids": tokens, "attention_mask": mask} for tokens, mask in zip(prompts_tokens, attention_mask) 80 | ] 81 | 82 | def __getitem__(self, ix: int): 83 | return self.prompts[ix] 84 | 85 | def __len__(self) -> int: 86 | return len(self.prompts) 87 | 88 | def create_loader(self, batch_size: int, shuffle=False) -> DataLoader: 89 | collate_fn = DataCollatorWithPadding(self.tokenizer) if self.tokenizer else torch.vstack 90 | return DataLoader(self, batch_size=batch_size, collate_fn=collate_fn, shuffle=shuffle) 91 | 92 | 93 | 94 | 95 | class ILQLRolloutStorage(BaseRolloutStore): 96 | """ 97 | Rollout storage for training ILQL 98 | """ 99 | 100 | def __init__(self, input_ids, attention_mask, rewards, states_ixs, actions_ixs, dones): 101 | super().__init__() 102 | 103 | self.input_ids = input_ids 104 | self.attention_mask = attention_mask 105 | self.rewards = rewards 106 | self.states_ixs = states_ixs 107 | self.actions_ixs = actions_ixs 108 | self.dones = dones 109 | 110 | def __getitem__(self, ix: int) -> ILQLElement: 111 | 
return ILQLElement( 112 | self.input_ids[ix], 113 | self.attention_mask[ix], 114 | self.rewards[ix], 115 | self.states_ixs[ix], 116 | self.actions_ixs[ix], 117 | self.dones[ix], 118 | ) 119 | 120 | def __len__(self) -> int: 121 | return len(self.input_ids) 122 | 123 | def create_loader(self, batch_size: int, drop_last=True): 124 | return DataLoader( 125 | self, 126 | batch_size=batch_size, 127 | shuffle=True, 128 | collate_fn=ilql_collate_fn, 129 | drop_last=drop_last, 130 | ) 131 | 132 | 133 | def ilql_seq2seq_collate_fn(elems: Iterable[ILQLElement]): 134 | return ILQLSeq2SeqBatch( 135 | pad_sequence([x.input_ids for x in elems], batch_first=True, padding_value=0), 136 | pad_sequence([x.attention_mask for x in elems], batch_first=True, padding_value=0), 137 | pad_sequence([x.decoder_input_ids for x in elems], batch_first=True, padding_value=0), 138 | pad_sequence([x.rewards for x in elems], batch_first=True, padding_value=0.0), 139 | pad_sequence([x.states_ixs for x in elems], batch_first=True, padding_value=0), 140 | pad_sequence([x.actions_ixs for x in elems], batch_first=True, padding_value=0), 141 | pad_sequence([x.dones for x in elems], batch_first=True, padding_value=0), 142 | ) 143 | 144 | 145 | class ILQLSeq2SeqRolloutStorage(BaseRolloutStore): 146 | """ 147 | Rollout storage for training ILQL 148 | """ 149 | 150 | def __init__(self, input_ids, attention_mask, decoder_input_ids, rewards, states_ixs, actions_ixs, dones): 151 | super().__init__() 152 | 153 | self.input_ids = input_ids 154 | self.attention_mask = attention_mask 155 | self.decoder_input_ids = decoder_input_ids 156 | self.rewards = rewards 157 | self.states_ixs = states_ixs 158 | self.actions_ixs = actions_ixs 159 | self.dones = dones 160 | 161 | def __getitem__(self, ix: int) -> ILQLSeq2SeqElement: 162 | return ILQLSeq2SeqElement( 163 | self.input_ids[ix], 164 | self.attention_mask[ix], 165 | self.decoder_input_ids[ix], 166 | self.rewards[ix], 167 | self.states_ixs[ix], 168 | self.actions_ixs[ix], 169 | self.dones[ix], 170 | ) 171 | 172 | def __len__(self) -> int: 173 | return len(self.input_ids) 174 | 175 | def create_loader(self, batch_size: int, drop_last=True): 176 | return DataLoader( 177 | self, 178 | batch_size=batch_size, 179 | shuffle=True, 180 | collate_fn=ilql_seq2seq_collate_fn, 181 | drop_last=drop_last, 182 | ) 183 | 184 | 185 | class SQLRolloutStorage(BaseRolloutStore): 186 | """ 187 | Rollout storage for training SQL 188 | """ 189 | 190 | def __init__(self, pad_token_id): 191 | super().__init__() 192 | 193 | self.pad_token_id = pad_token_id 194 | self.history: Iterable[ILQLElement] = [None] 195 | 196 | def push(self, exps: Iterable[ILQLElement]): 197 | self.history += exps 198 | 199 | def clear_history(self): 200 | self.history = [] 201 | 202 | def export_history(self, location: str): 203 | assert os.path.exists(location) 204 | 205 | fpath = os.path.join(location, f"epoch-{str(time.time())}.json") 206 | 207 | def exp_to_dict(exp): 208 | return {k: v.cpu().tolist() for k, v in exp.__dict__.items()} 209 | 210 | data = [exp_to_dict(exp) for exp in self.history] 211 | with open(fpath, "w") as f: 212 | f.write(json.dumps(data, indent=2)) 213 | 214 | def __getitem__(self, index: int) -> ILQLElement: 215 | return self.history[index] 216 | 217 | def __len__(self) -> int: 218 | return len(self.history) 219 | 220 | def create_loader( 221 | self, 222 | batch_size: int, 223 | shuffle: bool, 224 | ) -> DataLoader: 225 | def ilql_collate_fn(elems: Iterable[ILQLElement]): 226 | return ILQLBatch( 227 | pad_sequence([x.input_ids for
x in elems], batch_first=True, padding_value=0), 228 | pad_sequence([x.attention_mask for x in elems], batch_first=True, padding_value=0), 229 | pad_sequence([x.rewards for x in elems], batch_first=True, padding_value=0.0), 230 | pad_sequence([x.states_ixs for x in elems], batch_first=True, padding_value=0), 231 | pad_sequence([x.actions_ixs for x in elems], batch_first=True, padding_value=0), 232 | pad_sequence([x.dones for x in elems], batch_first=True, padding_value=0), 233 | ) 234 | 235 | 236 | return DataLoader(self, batch_size, shuffle=shuffle, collate_fn=ilql_collate_fn) 237 | 238 | 239 | -------------------------------------------------------------------------------- /examples/apa_tldr.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import os 4 | import sys 5 | 6 | import numpy as np 7 | import torch 8 | import tritonclient.grpc as client_util 9 | from datasets import load_dataset 10 | from huggingface_hub import snapshot_download 11 | from torch import nn 12 | 13 | from reward_model.reward_model import GPTRewardModel 14 | 15 | from transformers import AutoModelForCausalLM, AutoTokenizer 16 | from tritonclient.utils import np_to_triton_dtype 17 | 18 | import trlx 19 | from trlx.data.default_configs import ( 20 | ModelConfig, 21 | OptimizerConfig, 22 | SPPOConfig, 23 | SchedulerConfig, 24 | TokenizerConfig, 25 | TrainConfig, 26 | TRLConfig, 27 | ) 28 | import random 29 | 30 | OUTPUT_DIR = "output" 31 | 32 | RANDOM_SEED = 1000 33 | MODEL_SIZE = "6B" 34 | LOSS = "square" # "square" or "log", square for APA, and log for AWR 35 | ADV_COEFF_SQ = 10 36 | ADV_COEFF_LOG = 1 37 | 38 | REWARD_CHECKPOINT_PATH = "reward_model/rm_checkpoint/pytorch_model.bin" 39 | if not os.path.exists(REWARD_CHECKPOINT_PATH): 40 | os.makedirs("reward_model/rm_checkpoint", exist_ok=True) 41 | os.system( 42 | f"wget -O {REWARD_CHECKPOINT_PATH} \ 43 | https://huggingface.co/CarperAI/openai_summarize_tldr_rm_checkpoint/resolve/main/pytorch_model.bin" 44 | ) 45 | SFT_MODEL_PATH = "CarperAI/openai_summarize_tldr_sft" 46 | 47 | random.seed(RANDOM_SEED) 48 | np.random.seed(RANDOM_SEED) 49 | torch.manual_seed(RANDOM_SEED) 50 | torch.cuda.manual_seed(RANDOM_SEED) 51 | default_config = TRLConfig( 52 | train=TrainConfig( 53 | seq_length=550, 54 | epochs=10000, 55 | total_steps=5000, 56 | batch_size=4, 57 | checkpoint_interval=1000, 58 | eval_interval=1000, 59 | pipeline="PromptPipeline", 60 | trainer="AccelerateSPPOTrainer", 61 | checkpoint_dir="checkpoints/ppo_tldr", 62 | seed=RANDOM_SEED, 63 | ), 64 | model=ModelConfig(model_path="CarperAI/openai_summarize_tldr_sft", num_layers_unfrozen=8), 65 | tokenizer=TokenizerConfig(tokenizer_path="gpt2", truncation_side="right"), 66 | optimizer=OptimizerConfig(name="adamw", kwargs=dict(lr=1e-6, betas=(0.9, 0.999), eps=1.0e-8, weight_decay=1.0e-6)), 67 | scheduler=SchedulerConfig(name="cosine_annealing", kwargs=dict(T_max=10000, eta_min=1e-6)), 68 | method=SPPOConfig( 69 | name="SPPOConfig", 70 | num_rollouts=128, 71 | chunk_size=16, 72 | ppo_epochs=2, 73 | init_kl_coef=0.05, 74 | target=6, 75 | horizon=10000, 76 | gamma=1, 77 | lam=0.95, 78 | cliprange=100, 79 | cliprange_value=100, 80 | vf_coef=1, 81 | scale_reward="running", 82 | ref_mean=None, 83 | ref_std=None, 84 | loss_str=LOSS, 85 | adv_coeff_sq=ADV_COEFF_SQ, 86 | adv_coeff_log=ADV_COEFF_LOG, 87 | cliprange_reward=100, 88 | gen_kwargs=dict( 89 | max_new_tokens=50, 90 | ), 91 | ), 92 | ) 93 | 94 | 95 | 96 | def prepare_tensor(name: str, input): 97 | t = 
client_util.InferInput(name, input.shape, np_to_triton_dtype(input.dtype)) 98 | t.set_data_from_numpy(input) 99 | return t 100 | 101 | 102 | def create_reward_fn(post_summary_dict, config): # noqa: C901 103 | reward_tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B", cache_dir="/mnt/data/clausa-rl/tokenizers/gpt-j-6B") 104 | reward_tokenizer.pad_token = reward_tokenizer.eos_token 105 | triton_host = os.environ.get("TRITON_HOST") 106 | 107 | if triton_host: 108 | triton_url, triton_model = triton_host.split("/") 109 | client = client_util.InferenceServerClient(url=triton_url, verbose=False) 110 | 111 | def reward_fn(samples, prompts, outputs): 112 | samples = [s + reward_tokenizer.eos_token for s in samples] 113 | input = reward_tokenizer(samples, padding=True, max_length=1024) 114 | 115 | mbs = 24 116 | out = [] 117 | for i in range(math.ceil(len(samples) / mbs)): 118 | batch_ixs = slice(i * mbs, (i + 1) * mbs) 119 | input_ids = np.array(input.input_ids[batch_ixs], dtype=np.int32) 120 | 121 | result = client.infer(triton_model, [prepare_tensor("input_ids", input_ids)]) 122 | rewards = result.as_numpy("rewards") 123 | out.extend(rewards) 124 | 125 | return out 126 | 127 | elif os.environ.get("RANK", "0") == "0": 128 | reward_model = GPTRewardModel(SFT_MODEL_PATH, cache_dir="/mnt/data/clausa-rl/models/sft_summarization_gptj_tldr") 129 | reward_model.load_state_dict(torch.load(REWARD_CHECKPOINT_PATH)) 130 | reward_model.eval() 131 | reward_model.requires_grad_(False) 132 | device = torch.cuda.device_count() - 1 133 | reward_model = reward_model.half().to(device) 134 | 135 | def get_scores(samples): 136 | scores_list = [] 137 | batch_size = 1 138 | for i in range(0, len(samples), batch_size): 139 | sub_samples = samples[i : i + batch_size] 140 | sub_samples = ["<|startoftext|>" + chosen + "<|endoftext|>" for chosen in sub_samples] 141 | encodings_dict = reward_tokenizer( 142 | sub_samples, 143 | truncation=True, 144 | max_length=config.train.seq_length, 145 | padding="max_length", 146 | return_tensors="pt", 147 | ) 148 | input_ids = encodings_dict["input_ids"].to(device) 149 | attn_masks = encodings_dict["attention_mask"].to(device) 150 | input_ids = input_ids.repeat(2, 1) 151 | attn_masks = attn_masks.repeat(2, 1) 152 | with torch.no_grad(): 153 | sub_scores = reward_model(input_ids=input_ids, attention_mask=attn_masks) 154 | scores_list.append(sub_scores["chosen_end_scores"]) 155 | scores = torch.cat(scores_list, dim=0) 156 | return scores 157 | 158 | def reward_fn(samples, **kwargs): 159 | # original_samples = [text.split("TL;DR:")[0] + "TL;DR: " for text in samples] 160 | # for text in original_samples: 161 | # try: 162 | # original_samples = [text + post_summary_dict[text.strip()]] 163 | # except: 164 | # print("ERROR: \n {0}".format(text.strip())) 165 | 166 | # original_scores = get_scores(original_samples) 167 | scores = get_scores(samples) 168 | norms_scores = scores # - original_scores 169 | return norms_scores 170 | 171 | else: 172 | reward_fn = True 173 | 174 | return reward_fn 175 | 176 | 177 | def main(hparams={}): 178 | 179 | default_config.train.batch_size = 2 180 | default_config.train.total_steps = 5000 181 | default_config.model.model_path = SFT_MODEL_PATH 182 | default_config.method.chunk_size = 1 183 | 184 | output_dir = OUTPUT_DIR 185 | config = TRLConfig.update(default_config, hparams) 186 | # config.train.rollout_logging_dir = output_dir 187 | config.train.checkpoint_dir = output_dir 188 | config.train.logging_dir = output_dir 189 | config.train.tracker = 
"tensorboard" 190 | 191 | tokenizer = AutoTokenizer.from_pretrained(config.tokenizer.tokenizer_path) 192 | tokenizer.pad_token = tokenizer.eos_token 193 | tokenizer.padding_side = "left" 194 | max_length_input = config.train.seq_length - config.method.gen_kwargs["max_new_tokens"] 195 | 196 | def get_prompt_dataset(tokenizer, prompts, max_length): 197 | formatted_prompts = [] 198 | for i in range(len(prompts)): 199 | tmp = tokenizer.decode( 200 | tokenizer( 201 | prompts[i].split("TL;DR:")[0], 202 | truncation=True, 203 | max_length=max_length - 5, # to make sure "TL;DR" dont get truncated 204 | add_special_tokens=False, 205 | )["input_ids"], 206 | skip_special_tokens=True, 207 | ).strip() 208 | tmp = tmp + "\nTL;DR:" 209 | tmp = tokenizer.decode( 210 | tokenizer(tmp, truncation=True, max_length=max_length, add_special_tokens=False)["input_ids"], 211 | skip_special_tokens=True, 212 | ).strip() 213 | formatted_prompts.append(tmp) 214 | return formatted_prompts 215 | 216 | dataset = load_dataset("CarperAI/openai_summarize_tldr") 217 | 218 | train_set = [(sample["prompt"], sample["label"]) for sample in dataset["train"]] 219 | val_set = [(sample["prompt"], sample["label"]) for sample in dataset["valid"]] 220 | 221 | train_posts, train_summaries = zip(*train_set) 222 | val_posts, val_summaries = zip(*val_set) 223 | 224 | post_summary_dict = {} 225 | train_prompts = get_prompt_dataset(tokenizer, train_posts, max_length_input) 226 | 227 | for i in range(len(train_prompts)): 228 | post_summary_dict[train_prompts[i]] = train_summaries[i] 229 | val_prompts = get_prompt_dataset(tokenizer, val_posts, max_length_input) 230 | for i in range(len(val_prompts)): 231 | post_summary_dict[val_prompts[i]] = val_summaries[i] 232 | 233 | reward_fn = create_reward_fn(post_summary_dict, config) 234 | 235 | trlx.train( 236 | prompts=train_prompts, 237 | eval_prompts=val_prompts[0:500], 238 | reward_fn=reward_fn, 239 | config=config, 240 | ) 241 | 242 | 243 | if __name__ == "__main__": 244 | hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1]) 245 | main(hparams) 246 | -------------------------------------------------------------------------------- /trlx/data/configs.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from dataclasses import dataclass, field 3 | from typing import Any, Dict, Optional, Set 4 | 5 | import yaml 6 | 7 | from trlx.data.method_configs import MethodConfig, get_method 8 | 9 | 10 | def merge(base: Dict, update: Dict, updated: Set) -> Dict: 11 | "Recursively updates a nested dictionary with new values" 12 | for k, v in base.items(): 13 | if k in update and isinstance(v, dict): 14 | base[k] = merge(v, update[k], updated) 15 | updated.add(k) 16 | elif k in update: 17 | base[k] = update[k] 18 | updated.add(k) 19 | 20 | return base 21 | 22 | 23 | def _merge_dicts(base: Dict, update: Dict) -> Dict: 24 | "Merge two dictionaries recursively, returning a new dictionary." 25 | 26 | base = deepcopy(base) 27 | 28 | for k, v in update.items(): 29 | if isinstance(v, dict): 30 | base[k] = _merge_dicts(base.get(k, {}), v) 31 | else: 32 | base[k] = v 33 | 34 | return base 35 | 36 | 37 | @dataclass 38 | class ModelConfig: 39 | """ 40 | Config for a model. 41 | 42 | :param model_path: Path or name of the model (local or on huggingface hub) 43 | :type model_path: str 44 | 45 | :param model_arch_type: Type of model architecture. 
Either "causal" or "seq2seq" 46 | :type model_arch_type: str 47 | 48 | :param num_layers_unfrozen: Number of layers to unfreeze for fine-tuning. 49 | -1 means all layers are unfrozen. 50 | :type num_layers_unfrozen: int 51 | 52 | :param delta_kwargs: Keyword arguments for instantiating OpenDelta models for delta-tuning. 53 | Follow the `OpenDelta.AutoDeltaConfig` specification, e.g. for LoRA style tuning, set 54 | the `delta_type` to `lora` and include the model specific hyper-parameters (e.g. `lora_r`) 55 | {"delta_type": "lora", "modified_modules": "all", "lora_r": 8, "lora_alpha": 16, "lora_dropout": 0.0} 56 | or in YAML format: 57 | delta_kwargs: 58 | delta_type: lora 59 | modified_modules: "all" 60 | lora_r: 8 61 | lora_alpha: 16 62 | lora_dropout: 0.0 63 | See: https://opendelta.readthedocs.io/en/latest/modules/auto_delta.html#opendelta.auto_delta.AutoDeltaConfig 64 | :type delta_kwargs: Optional[Dict[str, Any]] 65 | """ 66 | 67 | model_path: str 68 | model_arch_type: str = "causal" 69 | num_layers_unfrozen: int = -1 70 | delta_kwargs: Optional[Dict[str, Any]] = None 71 | 72 | @classmethod 73 | def from_dict(cls, config: Dict[str, Any]): 74 | return cls(**config) 75 | 76 | 77 | @dataclass 78 | class TokenizerConfig: 79 | """ 80 | Config for a model. 81 | 82 | :param tokenizer_path: Path or name of the tokenizer (local or on huggingface hub) 83 | :type tokenizer_path: str 84 | 85 | :param padding_side: Padding side 86 | :type padding_path: str 87 | 88 | :param truncation_side: Truncation side 89 | :type truncation_side: str 90 | """ 91 | 92 | tokenizer_path: str 93 | padding_side: str = "left" 94 | truncation_side: str = "right" 95 | 96 | @classmethod 97 | def from_dict(cls, config: Dict[str, Any]): 98 | return cls(**config) 99 | 100 | 101 | @dataclass 102 | class OptimizerConfig: 103 | """ 104 | Config for an optimizer. 105 | 106 | :param name: Name of the optimizer 107 | :type name: str 108 | 109 | :param kwargs: Keyword arguments for the optimizer (e.g. lr, betas, eps, weight_decay) 110 | :type kwargs: Dict[str, Any] 111 | """ 112 | 113 | name: str 114 | kwargs: Dict[str, Any] = field(default_factory=dict) 115 | 116 | @classmethod 117 | def from_dict(cls, config: Dict[str, Any]): 118 | return cls(**config) 119 | 120 | 121 | @dataclass 122 | class SchedulerConfig: 123 | """ 124 | Config for a learning rate scheduler. 125 | 126 | :param name: Name of the scheduler 127 | :type name: str 128 | 129 | :param kwargs: Keyword arguments for the scheduler instance (e.g. warmup_steps, T_max) 130 | :type kwargs: Dict[str, Any] 131 | """ 132 | 133 | name: str 134 | kwargs: Dict[str, Any] = field(default_factory=dict) 135 | 136 | @classmethod 137 | def from_dict(cls, config: Dict[str, Any]): 138 | return cls(**config) 139 | 140 | 141 | @dataclass 142 | class TrainConfig: 143 | """ 144 | Config for train job on model. 145 | 146 | :param total_steps: Total number of training steps 147 | :type total_steps: int 148 | 149 | :param seq_length: Number of tokens to use as context (max length for tokenizer) 150 | :type seq_length: int 151 | 152 | :param epochs: Total number of passes through data 153 | :type epochs: int 154 | 155 | :param batch_size: Batch size for training 156 | :type batch_size: int 157 | 158 | :param tracker: Tracker to use for logging. Default: "wandb" 159 | :type tracker: str 160 | 161 | :param checkpoint_interval: Save model every checkpoint_interval steps. 
162 | Each checkpoint is stored in a sub-directory of the `TrainConfig.checkpoint_dir` 163 | directory in the format `checkpoint_dir/checkpoint_{step}`. 164 | :type checkpoint_interval: int 165 | 166 | :param eval_interval: Evaluate model every eval_interval steps 167 | :type eval_interval: int 168 | 169 | :param pipeline: Pipeline to use for training. One of the registered pipelines present in trlx.pipeline 170 | :type pipeline: str 171 | 172 | :param trainer: Trainer to use for training. One of the registered trainers present in trlx.trainer 173 | :type trainer: str 174 | 175 | :param trainer_kwargs: Extra keyword arguments for the trainer 176 | :type trainer_kwargs: Dict[str, Any] 177 | 178 | :param project_name: Project name for wandb 179 | :type project_name: str 180 | 181 | :param entity_name: Entity name for wandb 182 | :type entity_name: str 183 | 184 | :param group_name: Group name for wandb (used for grouping runs) 185 | :type group_name: str 186 | 187 | :param checkpoint_dir: Directory to save checkpoints 188 | :type checkpoint_dir: str 189 | 190 | :param rollout_logging_dir: Directory to store generated rollouts for use in Algorithm Distillation. 191 | Only used by AcceleratePPOTrainer. 192 | :type rollout_logging_dir: Optional[str] 193 | 194 | :param save_best: Save best model based on mean reward 195 | :type save_best: bool 196 | 197 | :param seed: Random seed 198 | :type seed: int 199 | """ 200 | 201 | total_steps: int 202 | seq_length: int 203 | epochs: int 204 | batch_size: int 205 | 206 | checkpoint_interval: int 207 | eval_interval: int 208 | 209 | pipeline: str # One of the pipelines in trlx.pipeline 210 | trainer: str # One of the trainers 211 | trainer_kwargs: Dict[str, Any] = field(default_factory=dict) # Extra keyword arguments for the trainer 212 | 213 | project_name: str = "trlx" 214 | entity_name: Optional[str] = None 215 | group_name: Optional[str] = None 216 | 217 | checkpoint_dir: str = "ckpts" 218 | rollout_logging_dir: Optional[str] = None 219 | save_best: bool = True 220 | 221 | tracker: Optional[str] = "wandb" 222 | logging_dir: Optional[str] = None 223 | 224 | seed: int = 1000 225 | 226 | @classmethod 227 | def from_dict(cls, config: Dict[str, Any]): 228 | return cls(**config) 229 | 230 | 231 | @dataclass 232 | class TRLConfig: 233 | """ 234 | Top level config for trlX. Loads configs and can be converted to dictionary. 235 | """ 236 | 237 | method: MethodConfig 238 | model: ModelConfig 239 | optimizer: OptimizerConfig 240 | scheduler: SchedulerConfig 241 | tokenizer: TokenizerConfig 242 | train: TrainConfig 243 | 244 | @classmethod 245 | def load_yaml(cls, yml_fp: str): 246 | """ 247 | Load yaml file as TRLConfig. 248 | 249 | :param yml_fp: Path to yaml file 250 | :type yml_fp: str 251 | """ 252 | with open(yml_fp, mode="r") as file: 253 | config = yaml.safe_load(file) 254 | return cls.from_dict(config) 255 | 256 | def to_dict(self): 257 | """ 258 | Convert TRLConfig to dictionary. 259 | """ 260 | data = { 261 | "method": self.method.__dict__, 262 | "model": self.model.__dict__, 263 | "optimizer": self.optimizer.__dict__, 264 | "scheduler": self.scheduler.__dict__, 265 | "tokenizer": self.tokenizer.__dict__, 266 | "train": self.train.__dict__, 267 | } 268 | 269 | return data 270 | 271 | def evolve(self, **kwargs) -> "TRLConfig": 272 | """ 273 | Evolve TRLConfig with new parameters. Can update nested parameters.
274 | >>> config = trlx.data.default_configs.default_ilql_config() 275 | >>> config = config.evolve(method=dict(gamma=0.99, gen_kwargs=dict(max_new_tokens=100))) 276 | >>> config.method.gamma 277 | 0.99 278 | """ 279 | return TRLConfig.from_dict(_merge_dicts(self.to_dict(), kwargs)) 280 | 281 | @classmethod 282 | def from_dict(cls, config: Dict): 283 | """ 284 | Convert dictionary to TRLConfig. 285 | """ 286 | return cls( 287 | method=get_method(config["method"]["name"]).from_dict(config["method"]), 288 | model=ModelConfig.from_dict(config["model"]), 289 | tokenizer=TokenizerConfig.from_dict(config["tokenizer"]), 290 | optimizer=OptimizerConfig.from_dict(config["optimizer"]), 291 | scheduler=SchedulerConfig.from_dict(config["scheduler"]), 292 | train=TrainConfig.from_dict(config["train"]), 293 | ) 294 | 295 | @classmethod 296 | def update(cls, baseconfig: Dict, config: Dict): 297 | if not isinstance(baseconfig, Dict): 298 | baseconfig = baseconfig.to_dict() 299 | 300 | updates = set() 301 | merged = merge(baseconfig, config, updates) 302 | 303 | for param in config: 304 | if param not in updates: 305 | raise ValueError(f"parameter {param} is not present in the config (typo or a wrong config)") 306 | 307 | return cls.from_dict(merged) 308 | 309 | def __str__(self): 310 | """Returns a human-readable string representation of the config.""" 311 | import json 312 | 313 | return json.dumps(self.to_dict(), indent=4) 314 | --------------------------------------------------------------------------------
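
The example scripts above all end with `main(json.loads(sys.argv[1]))`, and `TRLConfig.update` is what folds those command-line overrides into a base config via the recursive `merge` helper, rejecting any top-level key that is absent from the base. The following is a minimal sketch, not part of the repository; it assumes trlx is installed and uses `configs/test_config.yml` from the tree above, with purely illustrative override values.

import json
import sys

from trlx.data.configs import TRLConfig

# Base config shipped with the repo (assumed to contain all required sections).
base = TRLConfig.load_yaml("configs/test_config.yml")

# e.g. python script.py '{"train": {"total_steps": 1000}, "optimizer": {"kwargs": {"lr": 1e-6}}}'
hparams = {} if len(sys.argv) == 1 else json.loads(sys.argv[1])

# Nested keys are merged recursively into the base by merge(); an unknown
# top-level key raises ValueError, which catches typos early.
config = TRLConfig.update(base, hparams)
print(config)  # __str__ dumps the merged config as indented JSON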
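
Both `create_reward_fn` implementations above ultimately hand `trlx.train` a callable that returns one scalar score per generated sample (ppo_hh.py uses the `(samples, prompts, outputs)` signature, apa_tldr.py accepts `(samples, **kwargs)`). The toy stand-in below is purely illustrative and only mirrors that interface; it is not the GPT-J reward models used in the examples.

from typing import List


def toy_reward_fn(samples: List[str], prompts: List[str], outputs: List[str]) -> List[float]:
    # One scalar score per sample, mirroring the micro-batched loops in
    # create_reward_fn(); here the "reward" simply favors shorter completions.
    return [-float(len(output.split())) for output in outputs]


# trlx.train(prompts=prompts, eval_prompts=eval_prompts, reward_fn=toy_reward_fn, config=config)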
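
`PromptPipeline` in trlx/pipeline/sql_on_pipeline.py tokenizes each prompt once, truncates it to `max_prompt_length`, and batches with dynamic padding through `DataCollatorWithPadding`. A small usage sketch under the assumption that the trlx package above is installed and a GPT-2 tokenizer is available; the prompt strings are illustrative only.

from transformers import AutoTokenizer

from trlx.pipeline.sql_on_pipeline import PromptPipeline

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

pipeline = PromptPipeline(
    ["Human: How do I bake bread?\n\nAssistant:", "Human: Hello!\n\nAssistant:"],
    max_prompt_length=128,
    tokenizer=tokenizer,
)
loader = pipeline.create_loader(batch_size=2)
batch = next(iter(loader))  # padded input_ids / attention_mask for one batch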