├── .github └── workflows │ └── python-publish.yml ├── .gitignore ├── LICENSE ├── README.md ├── chatgpt.png ├── data ├── README.md └── enwik8.gz ├── examples.py ├── palm_rlhf_pytorch ├── __init__.py ├── attention.py ├── grpo.py ├── implicit_process_reward.py ├── lora.py ├── palm.py ├── ppo.py ├── reward.py └── utils.py ├── setup.py └── train.py /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | 2 | 3 | # This workflow will upload a Python Package using Twine when a release is created 4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries 5 | 6 | # This workflow uses actions that are not certified by GitHub. 7 | # They are provided by a third-party and are governed by 8 | # separate terms of service, privacy policy, and support 9 | # documentation. 10 | 11 | name: Upload Python Package 12 | 13 | on: 14 | release: 15 | types: [published] 16 | 17 | jobs: 18 | deploy: 19 | 20 | runs-on: ubuntu-latest 21 | 22 | steps: 23 | - uses: actions/checkout@v2 24 | - name: Set up Python 25 | uses: actions/setup-python@v2 26 | with: 27 | python-version: '3.x' 28 | - name: Install dependencies 29 | run: | 30 | python -m pip install --upgrade pip 31 | pip install build 32 | - name: Build package 33 | run: python -m build 34 | - name: Publish package 35 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 36 | with: 37 | user: __token__ 38 | password: ${{ secrets.PYPI_API_TOKEN }} 39 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | *official chatgpt blogpost* 4 | 5 | ## PaLM + RLHF - Pytorch (wip) 6 | 7 | Implementation of RLHF (Reinforcement Learning with Human Feedback) on top of the PaLM architecture. Maybe I'll add retrieval functionality too, à la RETRO 8 | 9 | If you are interested in replicating something like ChatGPT out in the open, please consider joining Laion Join us on Discord 10 | 11 | Potential successor: Direct Preference Optimization - all the code in this repo becomes ~ binary cross entropy loss, < 5 loc. So much for Reward models and PPO 12 | 13 | ## FAQ 14 | 15 | - Does this contain a model for inference? 16 | 17 | There is no trained model. This is just the ship and overall map. We still need millions of dollars of compute + data to sail to the correct point in high dimensional parameter space. Even then, you need professional sailors (like Robin Rombach of Stable Diffusion fame) to actually guide the ship through turbulent times to that point. 18 | 19 | ## Community 20 | 21 | CarperAI had been working on an RLHF framework for large language models for many months prior to the release of ChatGPT. 
22 | 23 | Yannic Kilcher is also working on an open sourced implementation 24 | 25 | AI Coffeebreak w/ Letitia | Code Emporium | Code Emporium Part 2 26 | 27 | ## Appreciation 28 | 29 | - Stability.ai for the generous sponsorship to work on cutting edge artificial intelligence research 30 | 31 | - 🤗 Hugging Face and CarperAI for penning the blog post Illustrating Reinforcement Learning from Human Feedback (RLHF), and the former also for their accelerate library 32 | 33 | - @kisseternity and @taynoel84 for the code review and finding bugs 34 | 35 | - Enrico for integrating Flash Attention from Pytorch 2.0 36 | 37 | ## Install 38 | 39 | ```bash 40 | $ pip install palm-rlhf-pytorch 41 | ``` 42 | 43 | ## Usage 44 | 45 | First train `PaLM`, like any other autoregressive transformer 46 | 47 | ```python 48 | import torch 49 | from palm_rlhf_pytorch import PaLM 50 | 51 | palm = PaLM( 52 | num_tokens = 20000, 53 | dim = 512, 54 | depth = 12, 55 | flash_attn = True # https://arxiv.org/abs/2205.14135 56 | ).cuda() 57 | 58 | seq = torch.randint(0, 20000, (1, 2048)).cuda() 59 | 60 | loss = palm(seq, return_loss = True) 61 | loss.backward() 62 | 63 | # after much training, you can now generate sequences 64 | 65 | generated = palm.generate(2048) # (1, 2048) 66 | ``` 67 | 68 | Then train your reward model, with the curated human feedback. In the original paper, they could not get reward model to be finetuned from a pretrained transformer without overfitting, but I gave the option to finetune with `LoRA` anyways, since it is still open research. 69 | 70 | ```python 71 | import torch 72 | from palm_rlhf_pytorch import PaLM, RewardModel 73 | 74 | palm = PaLM( 75 | num_tokens = 20000, 76 | dim = 512, 77 | depth = 12, 78 | causal = False 79 | ) 80 | 81 | reward_model = RewardModel( 82 | palm, 83 | num_binned_output = 5 # say rating from 1 to 5 84 | ).cuda() 85 | 86 | # mock data 87 | 88 | seq = torch.randint(0, 20000, (1, 1024)).cuda() 89 | prompt_mask = torch.zeros(1, 1024).bool().cuda() # which part of the sequence is prompt, which part is response 90 | labels = torch.randint(0, 5, (1,)).cuda() 91 | 92 | # train 93 | 94 | loss = reward_model(seq, prompt_mask = prompt_mask, labels = labels) 95 | loss.backward() 96 | 97 | # after much training 98 | 99 | reward = reward_model(seq, prompt_mask = prompt_mask) 100 | ``` 101 | 102 | Then you will pass your transformer and the rewards model to the `RLHFTrainer` 103 | 104 | ```python 105 | import torch 106 | from palm_rlhf_pytorch import PaLM, RewardModel, RLHFTrainer 107 | 108 | # load your pretrained palm 109 | 110 | palm = PaLM( 111 | num_tokens = 20000, 112 | dim = 512, 113 | depth = 12 114 | ).cuda() 115 | 116 | palm.load('./path/to/pretrained/palm.pt') 117 | 118 | # load your pretrained reward model 119 | 120 | reward_model = RewardModel( 121 | palm, 122 | num_binned_output = 5 123 | ).cuda() 124 | 125 | reward_model.load('./path/to/pretrained/reward_model.pt') 126 | 127 | # ready your list of prompts for reinforcement learning 128 | 129 | prompts = torch.randint(0, 256, (50000, 512)).cuda() # 50k prompts 130 | 131 | # pass it all to the trainer and train 132 | 133 | trainer = RLHFTrainer( 134 | palm = palm, 135 | reward_model = reward_model, 136 | prompt_token_ids = prompts 137 | ) 138 | 139 | trainer.train(num_episodes = 50000) 140 | 141 | # then, if it succeeded... 
142 | # generate say 10 samples and use the reward model to return the best one 143 | 144 | answer = trainer.generate(2048, prompt = prompts[0], num_samples = 10) # (<= 2048,) 145 | ``` 146 | 147 | ## Todo 148 | 149 | - [x] clone base transformer with separate lora for critic 150 | - [x] also allow for non-LoRA based finetuning 151 | - [x] redo normalize to be able to have a masked version, not sure if anyone will ever use per token rewards / values, but good practice to implement 152 | - [x] equip with the best attention 153 | 154 | - [ ] add Hugging Face accelerate and test out wandb instrumentation 155 | - [ ] search literature to figure out what is the latest SOTA for PPO, assuming RL field is still making progress. 156 | - [ ] test the system using a pretrained sentiment network as reward model 157 | - [ ] write the memory in PPO to memmapped numpy file 158 | - [ ] get sampling with variable lengthed prompts working, even if it is not needed given bottleneck is human feedback 159 | - [ ] allow for finetuning penultimate N layers only in either actor or critic, assuming if pretrained 160 | - [ ] incorporate some learning points from Sparrow, given Letitia's video 161 | - [ ] simple web interface with django + htmx for collecting human feedback 162 | - [ ] consider RLAIF 163 | 164 | ## Citations 165 | 166 | ```bibtex 167 | @article{Stiennon2020LearningTS, 168 | title = {Learning to summarize from human feedback}, 169 | author = {Nisan Stiennon and Long Ouyang and Jeff Wu and Daniel M. Ziegler and Ryan J. Lowe and Chelsea Voss and Alec Radford and Dario Amodei and Paul Christiano}, 170 | journal = {ArXiv}, 171 | year = {2020}, 172 | volume = {abs/2009.01325} 173 | } 174 | ``` 175 | 176 | ```bibtex 177 | @inproceedings{Chowdhery2022PaLMSL, 178 | title = {PaLM: Scaling Language Modeling with Pathways}, 179 | author = {Aakanksha Chowdhery and Sharan Narang and Jacob Devlin and Maarten Bosma and Gaurav Mishra and Adam Roberts and Paul Barham and Hyung Won Chung and Charles Sutton and Sebastian Gehrmann and Parker Schuh and Kensen Shi and Sasha Tsvyashchenko and Joshua Maynez and Abhishek Rao and Parker Barnes and Yi Tay and Noam M. Shazeer and Vinodkumar Prabhakaran and Emily Reif and Nan Du and Benton C. Hutchinson and Reiner Pope and James Bradbury and Jacob Austin and Michael Isard and Guy Gur-Ari and Pengcheng Yin and Toju Duke and Anselm Levskaya and Sanjay Ghemawat and Sunipa Dev and Henryk Michalewski and Xavier Garc{\'i}a and Vedant Misra and Kevin Robinson and Liam Fedus and Denny Zhou and Daphne Ippolito and David Luan and Hyeontaek Lim and Barret Zoph and Alexander Spiridonov and Ryan Sepassi and David Dohan and Shivani Agrawal and Mark Omernick and Andrew M. Dai and Thanumalayan Sankaranarayana Pillai and Marie Pellat and Aitor Lewkowycz and Erica Oliveira Moreira and Rewon Child and Oleksandr Polozov and Katherine Lee and Zongwei Zhou and Xuezhi Wang and Brennan Saeta and Mark Diaz and Orhan Firat and Michele Catasta and Jason Wei and Kathleen S. Meier-Hellstern and Douglas Eck and Jeff Dean and Slav Petrov and Noah Fiedel}, 180 | year = {2022} 181 | } 182 | ``` 183 | 184 | ```bibtex 185 | @article{Hu2021LoRALA, 186 | title = {LoRA: Low-Rank Adaptation of Large Language Models}, 187 | author = {Edward J. 
Hu and Yelong Shen and Phillip Wallis and Zeyuan Allen-Zhu and Yuanzhi Li and Shean Wang and Weizhu Chen}, 188 | journal = {ArXiv}, 189 | year = {2021}, 190 | volume = {abs/2106.09685} 191 | } 192 | ``` 193 | 194 | ```bibtex 195 | @inproceedings{Sun2022ALT, 196 | title = {A Length-Extrapolatable Transformer}, 197 | author = {Yutao Sun and Li Dong and Barun Patra and Shuming Ma and Shaohan Huang and Alon Benhaim and Vishrav Chaudhary and Xia Song and Furu Wei}, 198 | year = {2022} 199 | } 200 | ``` 201 | 202 | ```bibtex 203 | @misc{gilmer2023intriguing 204 | title = {Intriguing Properties of Transformer Training Instabilities}, 205 | author = {Justin Gilmer, Andrea Schioppa, and Jeremy Cohen}, 206 | year = {2023}, 207 | status = {to be published - one attention stabilization technique is circulating within Google Brain, being used by multiple teams} 208 | } 209 | ``` 210 | 211 | ```bibtex 212 | @inproceedings{dao2022flashattention, 213 | title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness}, 214 | author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher}, 215 | booktitle = {Advances in Neural Information Processing Systems}, 216 | year = {2022} 217 | } 218 | ``` 219 | 220 | ```bibtex 221 | @misc{Rubin2024, 222 | author = {Ohad Rubin}, 223 | url = {https://medium.com/@ohadrubin/exploring-weight-decay-in-layer-normalization-challenges-and-a-reparameterization-solution-ad4d12c24950} 224 | } 225 | ``` 226 | 227 | ```bibtex 228 | @inproceedings{Yuan2024FreePR, 229 | title = {Free Process Rewards without Process Labels}, 230 | author = {Lifan Yuan and Wendi Li and Huayu Chen and Ganqu Cui and Ning Ding and Kaiyan Zhang and Bowen Zhou and Zhiyuan Liu and Hao Peng}, 231 | year = {2024}, 232 | url = {https://api.semanticscholar.org/CorpusID:274445748} 233 | } 234 | ``` 235 | 236 | ```bibtex 237 | @article{Shao2024DeepSeekMathPT, 238 | title = {DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models}, 239 | author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Jun-Mei Song and Mingchuan Zhang and Y. K. 
Li and Yu Wu and Daya Guo}, 240 | journal = {ArXiv}, 241 | year = {2024}, 242 | volume = {abs/2402.03300}, 243 | url = {https://api.semanticscholar.org/CorpusID:267412607} 244 | } 245 | ``` 246 | 247 | ```bibtex 248 | @article{Farebrother2024StopRT, 249 | title = {Stop Regressing: Training Value Functions via Classification for Scalable Deep RL}, 250 | author = {Jesse Farebrother and Jordi Orbay and Quan Ho Vuong and Adrien Ali Taiga and Yevgen Chebotar and Ted Xiao and Alex Irpan and Sergey Levine and Pablo Samuel Castro and Aleksandra Faust and Aviral Kumar and Rishabh Agarwal}, 251 | journal = {ArXiv}, 252 | year = {2024}, 253 | volume = {abs/2403.03950}, 254 | url = {https://api.semanticscholar.org/CorpusID:268253088} 255 | } 256 | ``` 257 | 258 | ```bibtex 259 | @misc{Liu2025, 260 | title = {Understanding R1-Zero-Like Training: A Critical Perspective}, 261 | author = {Zichen Liu, Changyu Chen, Wenjun Li, Penghui Qi, Tianyu Pang, Chao Du, Wee Sun Lee, Min Lin}, 262 | url = {https://github.com/sail-sg/understand-r1-zero/blob/main/understand-r1-zero.pdf} 263 | } 264 | ``` 265 | 266 | ```bibtex 267 | @inproceedings{Yue2025DoesRL, 268 | title = {Does Reinforcement Learning Really Incentivize Reasoning Capacity in LLMs Beyond the Base Model?}, 269 | author = {Yang Yue and Zhiqi Chen and Rui Lu and Andrew Zhao and Zhaokai Wang and Shiji Song and Gao Huang}, 270 | year = {2025}, 271 | url = {https://api.semanticscholar.org/CorpusID:277940134} 272 | } 273 | ``` 274 | -------------------------------------------------------------------------------- /chatgpt.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/PaLM-rlhf-pytorch/114b4db005809c8c055db35257af3ce49395b75a/chatgpt.png -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Data source 2 | 3 | The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/ -------------------------------------------------------------------------------- /data/enwik8.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lucidrains/PaLM-rlhf-pytorch/114b4db005809c8c055db35257af3ce49395b75a/data/enwik8.gz -------------------------------------------------------------------------------- /examples.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from palm_rlhf_pytorch import PaLM, RewardModel, RLHFTrainer 3 | from accelerate import Accelerator 4 | 5 | accelerator = Accelerator() 6 | device = accelerator.device 7 | 8 | # load your pretrained palm 9 | 10 | palm = PaLM( 11 | num_tokens = 20000, 12 | dim = 512, 13 | depth = 12 14 | ).to(device) 15 | 16 | 17 | # load your pretrained reward model 18 | 19 | reward_model = RewardModel( 20 | palm, 21 | num_binned_output = 5 22 | ).to(device) 23 | 24 | # Train you reward model on mock data : 25 | # mock data 26 | 27 | seq = torch.randint(0, 20000, (1, 1024)).to(device) 28 | prompt_mask = torch.zeros(1, 1024).bool().to(device) # which part of the sequence is prompt, which part is response 29 | labels = torch.randint(0, 5, (1,)).to(device) 30 | 31 | # train 32 | loss = reward_model(seq, prompt_mask = prompt_mask, labels = labels) 33 | accelerator.backward(loss) 34 | 35 | # after much training 36 | reward = reward_model(seq, prompt_mask = prompt_mask) 37 | 38 | 39 | # ready your 
list of prompts for reinforcement learning 40 | 41 | prompts = torch.randint(0, 256, (1, 512)).to(device) # 1 prompt 42 | 43 | # pass it all to the trainer and train 44 | 45 | trainer = RLHFTrainer( 46 | palm = palm, 47 | reward_model = reward_model, 48 | prompt_token_ids = prompts 49 | ) 50 | 51 | accelerator.print("Training") 52 | trainer.train( 53 | num_episodes = 1, 54 | max_timesteps = 1, 55 | update_timesteps = 1, 56 | max_batch_size = 256, 57 | max_seq_len = 2048, 58 | eos_token = None, 59 | temperature = 1. 60 | ) 61 | 62 | # then, if it succeeded... 63 | # generate say 10 samples and use the reward model to return the best one 64 | accelerator.print("Generating answer") 65 | answer = trainer.generate(2048, prompt = prompts[0], num_samples = 10) # (<= 2048,) 66 | accelerator.print(f"answer: {answer}") -------------------------------------------------------------------------------- /palm_rlhf_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | from palm_rlhf_pytorch.palm import PaLM 2 | from palm_rlhf_pytorch.ppo import RLHFTrainer, ActorCritic 3 | 4 | from palm_rlhf_pytorch.reward import RewardModel 5 | from palm_rlhf_pytorch.implicit_process_reward import ImplicitPRM 6 | -------------------------------------------------------------------------------- /palm_rlhf_pytorch/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn, einsum 3 | from torch.nn import Module 4 | import torch.nn.functional as F 5 | 6 | from collections import namedtuple 7 | from functools import wraps 8 | from packaging import version 9 | 10 | from einops import rearrange 11 | 12 | # constants 13 | 14 | Config = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient']) 15 | 16 | # helpers 17 | 18 | def exists(val): 19 | return val is not None 20 | 21 | def once(fn): 22 | called = False 23 | @wraps(fn) 24 | def inner(x): 25 | nonlocal called 26 | if called: 27 | return 28 | called = True 29 | return fn(x) 30 | return inner 31 | 32 | print_once = once(print) 33 | 34 | # main class 35 | 36 | class Attention(Module): 37 | def __init__( 38 | self, 39 | dropout = 0., 40 | causal = False, 41 | use_flash_attn = False 42 | ): 43 | super().__init__() 44 | self.dropout = dropout 45 | self.attn_dropout = nn.Dropout(dropout) 46 | 47 | self.causal = causal 48 | self.register_buffer("mask", None, persistent=False) 49 | 50 | self.use_flash_attn = use_flash_attn 51 | assert not (use_flash_attn and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above' 52 | 53 | # determine efficient attention configs for cuda and cpu 54 | 55 | self.cpu_config = Config(True, True, True) 56 | self.cuda_config = None 57 | 58 | if not torch.cuda.is_available() or not use_flash_attn: 59 | return 60 | 61 | device_properties = torch.cuda.get_device_properties(torch.device('cuda')) 62 | 63 | if device_properties.major == 8 and device_properties.minor == 0: 64 | print_once('A100 GPU detected, using flash attention if input tensor is on cuda') 65 | self.cuda_config = Config(True, False, False) 66 | else: 67 | print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda') 68 | self.cuda_config = Config(False, True, True) 69 | 70 | def get_mask(self, n, device): 71 | if exists(self.mask) and self.mask.shape[-1] >= n: 72 | return self.mask[:n, :n] 73 | 74 | mask = 
torch.ones((n, n), device=device, dtype=torch.bool).triu(1) 75 | self.register_buffer("mask", mask, persistent=False) 76 | return mask 77 | 78 | def flash_attn(self, q, k, v, mask = None): 79 | _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda 80 | 81 | # Recommended for multi-query single-key-value attention by Tri Dao 82 | # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64]) 83 | 84 | k = rearrange(k, 'b ... -> b 1 ...').expand_as(q) 85 | v = rearrange(v, 'b ... -> b 1 ...').expand_as(q) 86 | 87 | # Check if mask exists and expand to compatible shape 88 | # The mask is B L, so it would have to be expanded to B H N L 89 | 90 | if exists(mask): 91 | mask = rearrange(mask, 'b j -> b 1 1 j') 92 | mask = mask.expand(-1, heads, q_len, -1) 93 | 94 | # Check if there is a compatible device for flash attention 95 | 96 | config = self.cuda_config if is_cuda else self.cpu_config 97 | 98 | # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale 99 | 100 | with torch.backends.cuda.sdp_kernel(**config._asdict()): 101 | out = F.scaled_dot_product_attention( 102 | q, k, v, 103 | attn_mask = mask, 104 | dropout_p = self.dropout if self.training else 0., 105 | is_causal = self.causal 106 | ) 107 | 108 | return out 109 | 110 | def forward(self, q, k, v, mask = None): 111 | """ 112 | einstein notation 113 | b - batch 114 | h - heads 115 | n, i, j - sequence length (base sequence length, source, target) 116 | d - feature dimension 117 | """ 118 | 119 | n, device = q.shape[-2], q.device 120 | 121 | scale = q.shape[-1] ** -0.5 122 | 123 | if self.use_flash_attn: 124 | return self.flash_attn(q, k, v, mask = mask) 125 | 126 | # similarity 127 | 128 | sim = einsum("b h i d, b j d -> b h i j", q, k) * scale 129 | 130 | # key padding mask 131 | 132 | if exists(mask): 133 | mask = rearrange(mask, 'b j -> b 1 1 j') 134 | sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) 135 | 136 | # causal mask 137 | 138 | if self.causal: 139 | causal_mask = self.get_mask(n, device) 140 | sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max) 141 | 142 | # attention 143 | 144 | attn = sim.softmax(dim=-1) 145 | attn = self.attn_dropout(attn) 146 | 147 | # aggregate values 148 | 149 | out = einsum("b h i j, b j d -> b h i d", attn, v) 150 | 151 | return out 152 | -------------------------------------------------------------------------------- /palm_rlhf_pytorch/grpo.py: -------------------------------------------------------------------------------- 1 | """ 2 | GRPO based training logic - https://arxiv.org/abs/2402.03300 3 | """ 4 | 5 | from __future__ import annotations 6 | from typing import Callable, Deque 7 | 8 | import math 9 | import copy 10 | from pathlib import Path 11 | from functools import partial 12 | from collections import deque, namedtuple 13 | from random import randrange 14 | 15 | import torch 16 | from torch import nn, Tensor 17 | from torch.nn import Module 18 | import torch.nn.functional as F 19 | 20 | from torch.utils.data import Dataset, DataLoader 21 | from torch.nn.utils.rnn import pad_sequence 22 | 23 | from adam_atan2_pytorch import AdoptAtan2 24 | 25 | from palm_rlhf_pytorch.palm import PaLM 26 | from palm_rlhf_pytorch.reward import RewardModel 27 | from palm_rlhf_pytorch.utils import masked_mean, eval_decorator 28 | 29 | from accelerate import Accelerator 30 | from accelerate.utils.tqdm import tqdm 31 | 32 | from einx import get_at 33 | from einops import rearrange, repeat, reduce, pack, unpack 34 | from einops.layers.torch import Rearrange 
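# editorial sketch, not part of the original module: the core idea implemented below, in isolation.
# GRPO drops the learned critic - each prompt is sampled several times, and the mean reward of that
# group serves as the baseline for every sample in it. standard GRPO additionally divides by the
# group standard deviation; the Dr. GRPO variant followed later in this file removes that division
# (see the reward normalization near the end of the train loop for the exact form used here).
# the function name is hypothetical and exists only for illustration.

def _group_relative_advantage_sketch(rewards: Tensor, use_std_normalization: bool = False) -> Tensor:
    # rewards: (group_size,) scalar rewards for one prompt sampled group_size times
    advantages = rewards - rewards.mean()

    if use_std_normalization:
        # plain GRPO; Dr. GRPO omits this step
        advantages = advantages / rewards.std().clamp(min = 1e-6)

    return advantages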
35 | 36 | # einstein notation 37 | 38 | # b - batch 39 | # n - sequence 40 | # d - feature dimension 41 | # l - logits 42 | 43 | # grpo based training 44 | 45 | # critic completely replaced with monte carlo sampling from actor + reward model 46 | # https://www.youtube.com/watch?v=bAWV_yrqx4w 47 | 48 | GRPOActionReturn = namedtuple('GRPOActionReturn', [ 49 | 'actions', 50 | 'sequence', 51 | 'mask', 52 | 'prompt_mask', 53 | 'action_logits', 54 | ]) 55 | 56 | class Actor(Module): 57 | def __init__( 58 | self, 59 | palm: PaLM, 60 | actor_lora = True, 61 | actor_lora_r = 8, 62 | actor_lora_scope = 'actor', 63 | actor_dropout = 0., 64 | ): 65 | super().__init__() 66 | self.actor_palm = palm 67 | 68 | self.actor_palm.set_dropout(actor_dropout) 69 | 70 | self.actor_lora = actor_lora 71 | 72 | self.actor_lora_scope = actor_lora_scope if actor_lora else None 73 | 74 | if self.actor_lora: 75 | self.actor_palm.add_finetune_params(actor_lora_scope, lora_r = actor_lora_r) 76 | 77 | def parameters(self): 78 | if not self.actor_lora: 79 | return self.actor_palm.parameters() 80 | 81 | return [ 82 | *self.actor_palm.finetune_parameters(self.actor_lora_scope) 83 | ] 84 | 85 | @torch.no_grad() 86 | @eval_decorator 87 | def generate( 88 | self, 89 | state, 90 | max_seq_len, 91 | eos_token = None, 92 | **kwargs 93 | ): 94 | actions = self.actor_palm.generate( 95 | max_seq_len, 96 | prompt = state, 97 | eos_token = eos_token, 98 | finetune_scope = self.actor_lora_scope, 99 | use_tqdm = True, 100 | **kwargs 101 | ) 102 | 103 | sequence = torch.cat((state, actions), dim = -1) 104 | action_len = actions.shape[-1] 105 | state_len = state.shape[-1] 106 | 107 | prompt_mask = torch.arange(sequence.shape[-1], device = state.device) < state_len 108 | prompt_mask = repeat(prompt_mask, 'n -> b n', b = sequence.shape[0]) 109 | 110 | action_mask = ~prompt_mask 111 | 112 | mask = None 113 | if exists(eos_token): 114 | mask = ((sequence == eos_token).cumsum(dim = -1) == 0) 115 | mask = F.pad(mask, (1, -1), value = True) # include eos token 116 | action_mask &= mask 117 | 118 | action_logits = self.forward( 119 | sequence, 120 | mask = action_mask, 121 | ) 122 | 123 | return GRPOActionReturn( 124 | actions, 125 | sequence, 126 | mask, 127 | prompt_mask, 128 | action_logits 129 | ) 130 | 131 | def forward( 132 | self, 133 | x, 134 | mask = None, 135 | ): 136 | return self.actor_palm(x, finetune_scope = self.actor_lora_scope) 137 | 138 | # data 139 | 140 | Memory = namedtuple('Memory', [ 141 | 'sequence', 142 | 'prompt_mask', 143 | 'mask', 144 | 'action_prob', 145 | 'action_log_prob', 146 | 'group_relative_normalized_reward', 147 | ]) 148 | 149 | class ExperienceDataset(Dataset): 150 | def __init__( 151 | self, 152 | data, 153 | device = None 154 | ): 155 | super().__init__() 156 | self.data = data 157 | self.device = device 158 | 159 | def __len__(self): 160 | return self.data[0].shape[0] 161 | 162 | def __getitem__(self, ind): 163 | return tuple(map(lambda t: t[ind].to(self.device), self.data)) 164 | 165 | def create_dataloader(data, batch_size, shuffle = True, device = None, **kwargs): 166 | ds = ExperienceDataset(data, device = device) 167 | return DataLoader(ds, batch_size = batch_size, shuffle = shuffle, **kwargs) 168 | 169 | # helper functions 170 | 171 | def exists(val): 172 | return val is not None 173 | 174 | def default(val, d): 175 | if exists(val): 176 | return val 177 | return d() if callable(d) else d 178 | 179 | def first(x): 180 | return x[0] 181 | 182 | def divisible_by(num, den): 183 | return (num % den) == 0 
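# editorial note, not part of the original module: a small worked example of the `shift` helper
# defined a few lines below. for the shift = 1, dim = -2 case used throughout this file,
# shifted[:, i] == logits[:, i - 1], so each sequence position ends up holding the distribution
# under which that position's token was actually sampled - which is what the later gather over
# the sequence token ids assumes.
#
#     logits  = torch.randn(2, 5, 10)               # (batch, seq, vocab)
#     shifted = shift(logits, shift = 1, dim = -2)
#     assert torch.allclose(shifted[:, 1:], logits[:, :-1])
#     assert (shifted[:, 0] == 0).all()              # front is filled with the pad value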
184 | 185 | def pad_sequence_fixed(sequences, *args, **kwargs): 186 | first_el = sequences[0] 187 | has_no_dimension = first_el.ndim == 0 188 | 189 | # if no dimensions, add a single dimension 190 | if has_no_dimension: 191 | sequences = tuple(map(lambda t: t[None], sequences)) 192 | 193 | out = pad_sequence(sequences, *args, **kwargs) 194 | 195 | if not has_no_dimension: 196 | return out 197 | 198 | return rearrange(out, '... 1 -> ...') 199 | 200 | def log(t, eps = 1e-20): 201 | return torch.log(t.clamp(min = eps)) 202 | 203 | def shift(t, value = 0, shift = 1, dim = -1): 204 | zeros = (0, 0) * (-dim - 1) 205 | return F.pad(t, (*zeros, shift, -shift), value = value) 206 | 207 | def masked_entropy(prob, dim = -1, mask = None): 208 | entropies = (prob * log(prob)).sum(dim = -1) 209 | return masked_mean(entropies, mask = mask).mean() 210 | 211 | def masked_kl_div(prob1, prob2, mask = None, reduce_batch = False): 212 | """ 213 | need to account for variable sequence lengths, therefore not using the built-in functional version 214 | """ 215 | kl_divs = (prob1 * (log(prob1) - log(prob2))).sum(dim = -1) 216 | loss = masked_mean(kl_divs, mask) 217 | 218 | if not reduce_batch: 219 | return loss 220 | 221 | return loss.mean() 222 | 223 | # rlhf trainer 224 | 225 | class RLHFTrainer(Module): 226 | 227 | def __init__( 228 | self, 229 | *, 230 | prompts: list[str] | None = None, 231 | prompts_path: str | None = None, 232 | prompt_token_ids: Tensor | None = None, 233 | tokenizer: Callable | None = None, 234 | palm: PaLM, 235 | reward_model: RewardModel, 236 | grpo_num_times_sample_rewards = 10, 237 | actor_lr = 1e-4, 238 | actor_wd = 0., 239 | actor_lora = True, 240 | actor_lora_r = 8, 241 | actor_dropout = 0., 242 | betas = (0.9, 0.999), 243 | max_norm = None, 244 | eps_clip = 0.2, 245 | beta_s = .01, 246 | pad_value = 0., 247 | minibatch_size = 16, 248 | epochs = 1, 249 | kl_div_loss_weight = 0.1, # between old action probs and new action probs - not sure what the right value is 250 | accelerate_kwargs: dict = dict(), 251 | ): 252 | super().__init__() 253 | 254 | self.accelerate = Accelerator(**accelerate_kwargs) 255 | 256 | # take care of prompts -> token ids 257 | 258 | assert (exists(prompts) + exists(prompts_path) + exists(prompt_token_ids)) == 1 259 | 260 | if exists(prompts_path): 261 | path = Path(prompts_path) 262 | prompts = path.read_text().split('\n') 263 | 264 | if exists(prompts): 265 | assert len(prompts) > 0, 'no prompts' 266 | assert exists(tokenizer), 'tokenizer must be passed in if raw text prompts are given' 267 | prompt_token_ids = tokenizer(prompts) 268 | 269 | self.pad_value = pad_value # token pad value 270 | self.num_prompts = prompt_token_ids.shape[0] 271 | self.register_buffer('prompt_token_ids', prompt_token_ids) 272 | 273 | # models 274 | 275 | self.palm = palm 276 | 277 | actor = Actor( 278 | palm = palm, 279 | actor_lora = actor_lora, 280 | actor_lora_r = actor_lora_r, 281 | actor_dropout = actor_dropout, 282 | ) 283 | 284 | self.actor = actor 285 | 286 | self.actor_generate = self.actor.generate 287 | 288 | self.reward_model = reward_model.eval() 289 | 290 | # train hyperparameters 291 | 292 | self.epochs = epochs 293 | self.minibatch_size = minibatch_size 294 | self.max_norm = max_norm 295 | 296 | self.kl_div_loss_weight = kl_div_loss_weight 297 | 298 | # optimizers 299 | 300 | self.actor_optim = AdoptAtan2(actor.parameters(), lr = actor_lr, weight_decay = actor_wd, betas = betas) 301 | 302 | # grpo hyperparams 303 | 304 | self.eps_clip = eps_clip 305 | self.beta_s = 
beta_s 306 | 307 | # grpo - the number of times to sample rewards for a given state (prompt) for normalization 308 | 309 | self.grpo_num_times_sample_rewards = grpo_num_times_sample_rewards 310 | 311 | # prepare with accelerator 312 | 313 | ( 314 | self.actor, 315 | self.reward_model, 316 | self.actor_optim, 317 | ) = self.accelerate.prepare( 318 | self.actor, 319 | self.reward_model, 320 | self.actor_optim, 321 | ) 322 | 323 | 324 | def print(self, msg): 325 | return self.accelerate.print(msg) 326 | 327 | def save(self, filepath = './checkpoint.pt'): 328 | torch.save(self.actor.state_dict(), filepath) 329 | 330 | def load(self, filepath = './checkpoint.pt'): 331 | state_dict = torch.load(filepath) 332 | self.actor.load_state_dict(state_dict) 333 | 334 | @property 335 | def device(self): 336 | return self.accelerate.device 337 | 338 | @torch.no_grad() 339 | def generate( 340 | self, 341 | max_seq_len, 342 | *args, 343 | prompt, 344 | num_samples = 4, # sample 4 per prompt and select the one with highest reward 345 | **kwargs 346 | ): 347 | assert prompt.ndim == 1, 'only one prompt allowed at a time for now' 348 | prompt = repeat(prompt, 'n -> b n', b = num_samples) 349 | 350 | actor = self.accelerate.unwrap_model(self.actor) 351 | reward_model = self.accelerate.unwrap_model(self.reward_model) 352 | 353 | actor.eval() 354 | 355 | ( 356 | actions, 357 | sequences, 358 | mask, 359 | prompt_mask, 360 | action_logits, 361 | _ 362 | ) = actor.generate( 363 | prompt, 364 | *args, 365 | max_seq_len = max_seq_len, 366 | **kwargs 367 | ) 368 | 369 | rewards = reward_model( 370 | sequences, 371 | prompt_mask = prompt_mask, 372 | mask = mask 373 | ) 374 | 375 | best_sequence_index = rewards.topk(1, dim = -1).indices 376 | 377 | best_sequence = sequences[best_sequence_index] 378 | best_sequence = rearrange(best_sequence, '1 ... -> ...') 379 | 380 | return best_sequence 381 | 382 | def learn( 383 | self, 384 | memories: Deque[Memory] 385 | ): 386 | # stack all data stored in the memories 387 | 388 | all_memories_stacked_and_padded = list(map(partial(pad_sequence_fixed, batch_first = True), zip(*memories))) 389 | 390 | # prepare dataloader for policy phase training 391 | 392 | dl = create_dataloader(all_memories_stacked_and_padded, self.minibatch_size, device = self.device) 393 | 394 | self.actor.train() 395 | 396 | # GRPO training 397 | 398 | for _ in range(self.epochs): 399 | for ( 400 | sequences, 401 | prompt_masks, 402 | masks, 403 | old_action_probs, 404 | old_log_probs, 405 | rewards, 406 | ) in dl: 407 | action_masks = ~prompt_masks & masks 408 | 409 | action_logits = self.actor( 410 | sequences, 411 | mask = action_masks 412 | ) 413 | 414 | action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token 415 | action_len = old_log_probs.shape[-1] 416 | 417 | action_probs = action_logits.softmax(dim = -1) 418 | action_log_probs = get_at('b n [l], b n -> b n', action_probs, sequences) 419 | action_log_probs = action_log_probs[:, -action_len:] 420 | 421 | # calculate entropies, taking into account which part of the sequence is actually an action 422 | 423 | entropies = masked_entropy(action_probs, mask = action_masks) 424 | 425 | # calculate kl div between old action probs and new ones, taking into account which part of the sequence is action or not 426 | 427 | kl_penalty = 0. 
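# note (added comment, not in the original source): the block below computes, for each sequence,
# the KL divergence between the behaviour policy that generated the rollout (old_action_probs)
# and the current policy (action_probs), averaged over the action positions via the mask.
# scaled by kl_div_loss_weight, it is then subtracted from the group-relative rewards, which
# penalizes the actor for drifting too far from the policy the rewards were collected under.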
428 | 429 | if self.kl_div_loss_weight > 0: 430 | kl_penalty = masked_kl_div(old_action_probs, action_probs, mask = action_masks) * self.kl_div_loss_weight 431 | 432 | # subtract the kl penalty from the rewards 433 | 434 | rewards = rewards - kl_penalty 435 | 436 | # calculate clipped surrogate objective, classic PPO loss 437 | 438 | ratios = (action_log_probs - old_log_probs).exp() 439 | 440 | surr1 = ratios * rewards 441 | surr2 = ratios.clamp(1 - self.eps_clip, 1 + self.eps_clip) * rewards 442 | policy_loss = - torch.min(surr1, surr2) - self.beta_s * entropies 443 | 444 | # combine losses 445 | 446 | loss = policy_loss.mean() 447 | 448 | # update actor 449 | 450 | self.accelerate.backward(loss) 451 | 452 | self.print(f'policy_loss: {loss.item():.3f}') 453 | 454 | if exists(self.max_norm): 455 | self.accelerator.clip_grad_norm_(self.actor.actor_parameters(), self.max_norm) 456 | 457 | self.actor_optim.step() 458 | self.actor_optim.zero_grad() 459 | 460 | def train( 461 | self, 462 | num_episodes = 50000, 463 | max_timesteps = 500, 464 | update_timesteps = 5000, 465 | max_batch_size = 16, 466 | max_seq_len = 2048, 467 | eos_token = None, 468 | temperature = 1. 469 | ): 470 | action_sample_times = self.grpo_num_times_sample_rewards 471 | 472 | device = self.device 473 | 474 | time = 0 475 | memories = deque([]) 476 | 477 | for eps in tqdm(range(num_episodes), desc = 'episodes'): 478 | for timestep in range(max_timesteps): 479 | time += 1 480 | 481 | # select a bunch of random states (prompts) 482 | # and get the action (sampled sequence from palm as well as the action probs) 483 | # also calculate the reward using reward model and store 484 | 485 | rand_prompt_index = randrange(0, self.num_prompts) 486 | 487 | state = self.prompt_token_ids[rand_prompt_index] 488 | 489 | # remove padding from state 490 | 491 | state_mask = state != self.pad_value 492 | state = state[state_mask] 493 | 494 | # will sample each state more than once to get an estimate of the value, for removing the critic altogether from GRPO paper, Shao et al. 495 | 496 | states = repeat(state, 'n -> b n', b = action_sample_times + 1) 497 | 498 | # get predicted sequence 499 | 500 | ( 501 | actions, 502 | sequence, 503 | mask, 504 | prompt_mask, 505 | action_logits, 506 | ) = self.actor_generate( 507 | states, 508 | max_seq_len = max_seq_len, 509 | eos_token = eos_token, 510 | temperature = temperature, 511 | ) 512 | 513 | action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token 514 | 515 | action_prob = action_logits.softmax(dim = -1) 516 | 517 | action_len = actions.shape[-1] 518 | action_log_prob = get_at('b n [l], b n -> b n', action_prob, sequence) 519 | action_log_prob = action_log_prob[:, -action_len:] 520 | 521 | # get reward as given by supervised trained reward model 522 | 523 | sequence = torch.cat((states, actions), dim = 1) 524 | 525 | prompt_length = states.shape[1] 526 | prompt_mask = torch.arange(sequence.shape[-1], device = device) < prompt_length 527 | prompt_mask = repeat(prompt_mask, 'n -> b n', b = action_sample_times + 1) 528 | 529 | mask = default(mask, lambda: torch.ones(sequence.shape, dtype = torch.bool, device = device)) 530 | 531 | rewards = self.reward_model( 532 | sequence, 533 | prompt_mask = prompt_mask, 534 | mask = mask, 535 | sample = True 536 | ) 537 | 538 | rewards = rewards.float() 539 | 540 | # rewards are normalized for use as advantages 541 | # following Dr. 
GRPO paper from Sea AI labs, remove the standard deviation 542 | 543 | normalized_rewards = (rewards - rewards.mean()) / (action_sample_times + 1) 544 | 545 | # store memory for learning 546 | 547 | detach_to_cpu_ = lambda t: t.detach().cpu() 548 | 549 | memories.extend([Memory(*memories) for memories in zip(*map(detach_to_cpu_, ( 550 | sequence, 551 | prompt_mask, 552 | mask, 553 | action_prob, 554 | action_log_prob, 555 | normalized_rewards, 556 | )))]) 557 | 558 | # learn from the stored memories 559 | 560 | if divisible_by(time, update_timesteps): 561 | self.learn(memories) 562 | memories.clear() 563 | 564 | print('dr grpo rlhf training complete') 565 | -------------------------------------------------------------------------------- /palm_rlhf_pytorch/implicit_process_reward.py: -------------------------------------------------------------------------------- 1 | # Free Process Rewards without Process Labels 2 | # Yuan et al. https://arxiv.org/abs/2412.01981 - paper that led to Prime 3 | 4 | from __future__ import annotations 5 | from copy import deepcopy 6 | 7 | import torch 8 | from torch.nn import Module 9 | from torch.nn.functional import logsigmoid 10 | 11 | from einops import rearrange 12 | 13 | # helpers 14 | 15 | def exists(v): 16 | return v is not None 17 | 18 | def get_logprob_at(logits, seq): 19 | log_probs = logits.log_softmax(dim = -1) 20 | seq = rearrange(seq, '... -> ... 1') 21 | log_prob = log_probs.gather(-1, seq) 22 | return rearrange(log_prob, '... 1 -> ...') 23 | 24 | class ImplicitPRM(Module): 25 | """ PRM stands for process reward model, an openai paper that shows that rewarding the steps a model takes to its outcome is better than only rewarding based on final answer or outcome. basically same as when a teacher gives you some credit for showing your steps on an exam """ 26 | 27 | def __init__( 28 | self, 29 | model: Module, 30 | ref_model: Module | None = None, 31 | beta = 0.1 32 | ): 33 | super().__init__() 34 | self.model = model 35 | 36 | # only drawback to this technique is needing a reference model 37 | 38 | if not exists(ref_model): 39 | ref_model = deepcopy(model) 40 | 41 | self.ref_model = ref_model 42 | ref_model.requires_grad_(False) # insurance 43 | 44 | self.beta = beta 45 | 46 | def parameters(self): 47 | return self.model.parameters() # only main model is trained 48 | 49 | def forward( 50 | self, 51 | seq, 52 | labels = None 53 | ): 54 | source_seq, target_seq = seq[:, :-1], seq[:, 1:] 55 | 56 | mask = target_seq >= 0 # assume any token ids < 0 to be padding 57 | 58 | model_logits = self.model(source_seq) 59 | ref_model_logits = self.ref_model(source_seq) 60 | 61 | log_prob = get_logprob_at(model_logits, target_seq) 62 | ref_log_prob = get_logprob_at(ref_model_logits, target_seq) 63 | 64 | # main formula is DPO-like, and has some connection with Q-learning https://arxiv.org/abs/2404.12358 . it is all connected 65 | 66 | implicit_rewards = self.beta * (log_prob - ref_log_prob) 67 | 68 | # zero out rewards in padding 69 | 70 | implicit_rewards = implicit_rewards.masked_fill(~mask, 0.) 71 | 72 | # early return if not training, as in Prime with alternating model and prm training 73 | 74 | if not exists(labels): 75 | return implicit_rewards 76 | 77 | labels = rearrange(labels, 'b -> b 1') 78 | 79 | # otherwise use the cross entropy formulation from their paper (eq 5) 80 | 81 | loss = ( 82 | labels * logsigmoid(implicit_rewards) + 83 | (1. - labels) * logsigmoid(-implicit_rewards) # (1. 
- sigmoid(x)) == sigmoid(-x) 84 | ) 85 | 86 | return loss[mask].mean() 87 | 88 | # make it easy for others to copy paste into another project 89 | 90 | if __name__ == '__main__': 91 | from palm_rlhf_pytorch import PaLM 92 | 93 | palm = PaLM( 94 | num_tokens = 256, 95 | dim = 64, 96 | depth = 2 97 | ) 98 | 99 | ref_palm = PaLM( 100 | num_tokens = 256, 101 | dim = 64, 102 | depth = 2 103 | ) 104 | 105 | implicit_prm = ImplicitPRM( 106 | palm, 107 | ref_model = ref_palm 108 | ) 109 | 110 | # mock data 111 | 112 | seq = torch.randint(0, 256, (2, 1024)) 113 | labels = torch.randint(0, 2, (2,)) 114 | 115 | loss = implicit_prm(seq, labels) 116 | loss.backward() 117 | 118 | # after much training 119 | 120 | implicit_rewards = implicit_prm(seq) # Float[2, 1024] 121 | 122 | # there you go, free process reward model 123 | # now you can use this dense reward for rlhf, beam search whatever 124 | -------------------------------------------------------------------------------- /palm_rlhf_pytorch/lora.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import Module 4 | 5 | # helper functions 6 | 7 | def exists(val): 8 | return val is not None 9 | 10 | def default(val, d): 11 | return val if exists(val) else d 12 | 13 | # LoRA - https://arxiv.org/abs/2106.09685 14 | 15 | class LoRA(Module): 16 | def __init__( 17 | self, 18 | dim, 19 | dim_out, 20 | r = 8, 21 | alpha = None 22 | ): 23 | super().__init__() 24 | alpha = default(alpha, r) 25 | self.scale = alpha / r 26 | 27 | self.A = nn.Parameter(torch.randn(dim, r)) 28 | self.B = nn.Parameter(torch.zeros(r, dim_out)) 29 | 30 | @property 31 | def weight(self): 32 | return (self.A @ self.B) * self.scale 33 | 34 | def forward(self, x): 35 | return x @ self.weight 36 | -------------------------------------------------------------------------------- /palm_rlhf_pytorch/palm.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import math 4 | import copy 5 | from pathlib import Path 6 | from collections import namedtuple 7 | from functools import wraps 8 | from itertools import zip_longest 9 | 10 | from tqdm import tqdm 11 | from beartype import beartype 12 | 13 | import torch 14 | from torch import einsum, nn 15 | import torch.nn.functional as F 16 | from torch.nn import Module, ModuleList, ModuleDict 17 | 18 | from einops import rearrange, repeat, reduce, pack, unpack 19 | from einops.layers.torch import Rearrange, Reduce 20 | 21 | from palm_rlhf_pytorch.attention import Attention 22 | from palm_rlhf_pytorch.utils import top_p, top_k, masked_mean, gumbel_sample, eval_decorator 23 | from palm_rlhf_pytorch.lora import LoRA 24 | 25 | # functions and decorators 26 | 27 | def exists(val): 28 | return val is not None 29 | 30 | def default(val, d): 31 | return val if exists(val) else d 32 | 33 | def identity(t, *args, **kwargs): 34 | return t 35 | 36 | def l2norm(t): 37 | return F.normalize(t, dim = -1) 38 | 39 | # normalization 40 | # they use layernorm without bias, something that pytorch does not offer 41 | 42 | class LayerNorm(Module): 43 | def __init__(self, dim): 44 | super().__init__() 45 | self.gamma = nn.Parameter(torch.zeros(dim)) 46 | self.register_buffer("beta", torch.zeros(dim)) 47 | 48 | def forward(self, x): 49 | return F.layer_norm(x, x.shape[-1:], (self.gamma + 1), self.beta) 50 | 51 | # residual 52 | 53 | 54 | class Residual(Module): 55 | def __init__(self, fn): 56 | 
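# note (added comment, not in the original source): Residual wraps a block so that the block's
# output is added back onto its input. in forward, when neither tensor requires grad (pure
# inference), the addition is done in-place to save memory; otherwise a regular out-of-place
# add is used so autograd remains correct.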
super().__init__() 57 | self.fn = fn 58 | 59 | def forward(self, x, **kwargs): 60 | y = self.fn(x, **kwargs) 61 | 62 | if not any([t.requires_grad for t in (x, y)]): 63 | return x.add_(y) 64 | 65 | return y + x 66 | 67 | # rotary positional embedding w/ xpos 68 | # https://arxiv.org/abs/2104.09864 69 | # https://arxiv.org/abs/2212.10554v1 70 | 71 | class RotaryEmbedding(Module): 72 | def __init__(self, dim, scale_base = 512, use_xpos = True): 73 | super().__init__() 74 | inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) 75 | self.register_buffer("inv_freq", inv_freq) 76 | 77 | self.use_xpos = use_xpos 78 | self.scale_base = scale_base 79 | scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) 80 | self.register_buffer('scale', scale) 81 | 82 | def forward(self, seq_len, device): 83 | t = torch.arange(seq_len, device = device).type_as(self.inv_freq) 84 | freqs = torch.einsum('i , j -> i j', t, self.inv_freq) 85 | freqs = torch.cat((freqs, freqs), dim = -1) 86 | 87 | if not self.use_xpos: 88 | return freqs, torch.ones(1, device = device) 89 | 90 | power = (t - (seq_len // 2)) / self.scale_base 91 | scale = self.scale ** rearrange(power, 'n -> n 1') 92 | scale = torch.cat((scale, scale), dim = -1) 93 | 94 | return freqs, scale 95 | 96 | def rotate_half(x): 97 | x1, x2 = x.chunk(2, dim=-1) 98 | return torch.cat((-x2, x1), dim=-1) 99 | 100 | 101 | def apply_rotary_pos_emb(pos, t, scale = 1.): 102 | return (t * pos.cos() * scale) + (rotate_half(t) * pos.sin() * scale) 103 | 104 | 105 | # classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward 106 | # https://arxiv.org/abs/2002.05202 107 | 108 | 109 | class SwiGLU(Module): 110 | def forward(self, x): 111 | x, gate = x.chunk(2, dim=-1) 112 | return F.silu(gate) * x 113 | 114 | 115 | # parallel attention and feedforward with residual 116 | # discovered by Wang et al + EleutherAI from GPT-J fame 117 | 118 | 119 | class ParallelTransformerBlock(Module): 120 | def __init__( 121 | self, 122 | dim, 123 | dim_head = 64, 124 | causal = True, 125 | heads = 8, 126 | qk_rmsnorm = False, 127 | qk_scale = 8, 128 | ff_mult = 4, 129 | attn_dropout = 0., 130 | ff_dropout = 0., 131 | use_xpos = True, 132 | xpos_scale_base = 512, 133 | flash_attn = False, 134 | ): 135 | super().__init__() 136 | self.norm = LayerNorm(dim) 137 | 138 | attn_inner_dim = dim_head * heads 139 | ff_inner_dim = dim * ff_mult 140 | self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2)) 141 | 142 | self.qk_rmsnorm = qk_rmsnorm 143 | 144 | if qk_rmsnorm: 145 | self.q_scale = nn.Parameter(torch.ones(dim_head)) 146 | self.k_scale = nn.Parameter(torch.ones(dim_head)) 147 | 148 | self.attend = Attention( 149 | causal = causal, 150 | dropout = attn_dropout, 151 | use_flash_attn = flash_attn 152 | ) 153 | 154 | self.heads = heads 155 | self.scale = (dim_head ** -0.5) if not qk_rmsnorm else qk_scale 156 | self.causal = causal 157 | 158 | self.rotary_emb = RotaryEmbedding(dim_head, scale_base = xpos_scale_base, use_xpos = use_xpos and causal) 159 | 160 | self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False) 161 | 162 | self.flash_attn = flash_attn 163 | self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False) 164 | self.attn_dropout = nn.Dropout(attn_dropout) 165 | self.flash_attn_dropout = attn_dropout 166 | 167 | # parallel feedforward tail 168 | 169 | self.ff_out = nn.Sequential( 170 | SwiGLU(), 171 | nn.Dropout(ff_dropout), 172 | nn.Linear(ff_inner_dim, dim, bias=False) 
173 | ) 174 | 175 | # for caching causal mask and rotary embeddings 176 | 177 | self.register_buffer("pos_emb", None, persistent=False) 178 | self.register_buffer("pos_emb_scale", None, persistent=False) 179 | 180 | def get_rotary_embedding(self, n, device): 181 | if exists(self.pos_emb) and self.pos_emb.shape[-2] >= n: 182 | return self.pos_emb[:n], self.pos_emb_scale[:n] 183 | 184 | pos_emb, scale = self.rotary_emb(n, device=device) 185 | self.register_buffer("pos_emb", pos_emb, persistent=False) 186 | self.register_buffer("pos_emb_scale", scale, persistent=False) 187 | return pos_emb, scale 188 | 189 | def forward( 190 | self, 191 | x, 192 | mask = None, 193 | finetune_modules = None 194 | ): 195 | """ 196 | einstein notation 197 | b - batch 198 | h - heads 199 | n, i, j - sequence length (base sequence length, source, target) 200 | d - feature dimension 201 | """ 202 | 203 | n, device, h = x.shape[1], x.device, self.heads 204 | 205 | # pre layernorm 206 | 207 | x = self.norm(x) 208 | 209 | # attention queries, keys, values, and feedforward inner 210 | 211 | q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1) 212 | 213 | # finetune loras 214 | 215 | lora_q = lora_k = lora_v = lora_o = None 216 | 217 | if exists(finetune_modules): 218 | lora_q, lora_k, lora_v, lora_o = finetune_modules 219 | q = q + lora_q(x) 220 | k = k + lora_k(x) 221 | v = v + lora_v(x) 222 | 223 | # split heads 224 | # they use multi-query single-key-value attention, yet another Noam Shazeer paper 225 | # they found no performance loss past a certain scale, and more efficient decoding obviously 226 | # https://arxiv.org/abs/1911.02150 227 | 228 | q = rearrange(q, "b n (h d) -> b h n d", h=h) 229 | 230 | # qk rmsnorm 231 | 232 | if self.qk_rmsnorm: 233 | q, k = map(l2norm, (q, k)) 234 | q = q * self.q_scale 235 | k = k * self.k_scale 236 | 237 | # rotary embeddings with xpos decay for better length extrapolation 238 | 239 | positions, scale = self.get_rotary_embedding(n, device) 240 | 241 | q = apply_rotary_pos_emb(positions, q, scale) 242 | k = apply_rotary_pos_emb(positions, k, scale ** -1) 243 | 244 | # attention function, either regular or flash 245 | 246 | out = self.attend(q, k, v, mask = mask) 247 | 248 | # merge heads 249 | 250 | out = rearrange(out, "b h n d -> b n (h d)") 251 | 252 | attn_out = self.attn_out(out) 253 | 254 | ff_out = self.ff_out(ff) 255 | 256 | if exists(lora_o): 257 | attn_out = attn_out + lora_o(out) 258 | 259 | return attn_out + ff_out 260 | 261 | # transformer 262 | 263 | class PaLM(Module): 264 | @beartype 265 | def __init__( 266 | self, 267 | *, 268 | dim, 269 | num_tokens, 270 | depth, 271 | causal = True, 272 | dim_head = 64, 273 | heads = 8, 274 | ff_mult = 4, 275 | attn_dropout = 0., 276 | ff_dropout = 0., 277 | qk_rmsnorm = False, 278 | lora_r = 8, 279 | rotary_xpos_scale_base = 512, 280 | flash_attn = False, 281 | finetune_scopes = tuple(), 282 | cross_entropy_ignore_index = 0 283 | ): 284 | super().__init__() 285 | self.dim = dim 286 | self.dim_head = dim_head 287 | self.heads = heads 288 | self.causal = causal 289 | self.num_tokens = num_tokens 290 | 291 | self.token_emb = nn.Embedding(num_tokens, dim) 292 | self.layers = ModuleList([]) 293 | 294 | for _ in range(depth): 295 | block = Residual(ParallelTransformerBlock( 296 | dim = dim, 297 | causal = causal, 298 | dim_head = dim_head, 299 | heads = heads, 300 | qk_rmsnorm = qk_rmsnorm, 301 | ff_mult = ff_mult, 302 | attn_dropout = attn_dropout, 303 | ff_dropout = ff_dropout, 304 | xpos_scale_base = 
rotary_xpos_scale_base, 305 | flash_attn = flash_attn 306 | )) 307 | 308 | self.layers.append(block) 309 | 310 | self.norm = LayerNorm(dim) 311 | self.to_logits = nn.Linear(dim, num_tokens, bias=False) 312 | 313 | self.to_logits.weight = self.token_emb.weight 314 | 315 | nn.init.normal_(self.token_emb.weight, std=0.02) 316 | 317 | # fine tuning related 318 | 319 | self.lora_r = lora_r 320 | self.finetune_modules = ModuleDict({}) 321 | 322 | for scope in finetune_scopes: 323 | self.add_finetune_params(scope) 324 | 325 | # loss related 326 | 327 | self.cross_entropy_ignore_index = cross_entropy_ignore_index 328 | 329 | @property 330 | def device(self): 331 | return next(self.parameters()).device 332 | 333 | def load(self, path): 334 | path = Path(path) 335 | assert path.exists() 336 | self.load_state_dict(torch.load(str(path))) 337 | 338 | def set_dropout(self, dropout): 339 | for module in self.layers.modules(): 340 | if isinstance(module, nn.Dropout): 341 | module.p = dropout 342 | return self 343 | 344 | def add_finetune_params(self, scope, lora_r = None): 345 | assert scope not in self.finetune_modules, f'finetune scope {scope} already found' 346 | dim, dim_head, heads, r, device = self.dim, self.dim_head, self.heads, default(lora_r, self.lora_r), self.device 347 | 348 | q_inner_dim = heads * dim_head 349 | kv_inner_dim = dim_head 350 | 351 | lora_modules = ModuleList([]) 352 | 353 | for _ in range(len(self.layers)): 354 | lora_modules.append(ModuleList([ 355 | LoRA(dim, q_inner_dim, r = r), # queries 356 | LoRA(dim, kv_inner_dim, r = r), # keys 357 | LoRA(dim, kv_inner_dim, r = r), # values 358 | LoRA(q_inner_dim, dim, r = r) # wo 359 | ])) 360 | 361 | self.finetune_modules[scope] = lora_modules.to(device) 362 | 363 | def remove_finetune_params(self, scope): 364 | assert scope in self.finetune_modules, f'finetune scope {scope} not found' 365 | return self.finetune_modules.pop(scope) 366 | 367 | @torch.no_grad() 368 | def merge_finetune_params(self, scope): 369 | """ in the case one wants to merge the fine-tuned actor LORA parameters and do multiple rounds of fine tuning off different reward models """ 370 | 371 | assert scope in self.finetune_modules, f'finetune scope {scope} not found' 372 | 373 | lora_modules = self.finetune_modules.pop(scope) 374 | 375 | for layer, (lora_q, lora_k, lora_v, lora_o) in zip(self.layers, lora_modules): 376 | block = layer.fn 377 | 378 | fused_attn_ff_weight = block.fused_attn_ff_proj.weight 379 | attn_out_weight = block.attn_out.weight 380 | 381 | fused_proj_out_dim = fused_attn_ff_weight.shape[0] 382 | 383 | lora_qkv_weight, _ = pack([lora_q.weight, lora_k.weight, lora_v.weight], 'i *') 384 | lora_qkv_weight = F.pad(lora_qkv_weight, (0, fused_proj_out_dim - lora_qkv_weight.shape[1])) 385 | 386 | lora_qkv_weight = rearrange(lora_qkv_weight, 'i o -> o i') 387 | lora_o_weight = rearrange(lora_o.weight, 'i o -> o i') 388 | 389 | fused_attn_ff_weight.add_(lora_qkv_weight) 390 | attn_out_weight.add_(lora_o_weight) 391 | 392 | # researcher train palm parameters first 393 | # before finetuning 394 | 395 | def palm_parameters(self): 396 | return set(self.parameters()) - set(self.finetune_modules.parameters()) 397 | 398 | def finetune_parameters(self, scope = 'default'): 399 | assert scope in self.finetune_modules, f'finetune parameters of scope {scope} not found' 400 | return self.finetune_modules[scope].parameters() 401 | 402 | # generate function 403 | 404 | @torch.no_grad() 405 | @eval_decorator 406 | def generate( 407 | self, 408 | seq_len, 409 | prompt = 
None, 410 | temperature = 1., 411 | filter_logits_fn = top_k, 412 | filter_thres = 0.9, 413 | pad_value = 0., 414 | eos_token = None, 415 | return_seq_without_prompt = True, 416 | use_tqdm = False, 417 | **kwargs 418 | ): 419 | if not exists(prompt): 420 | prompt = torch.randint(0, self.num_tokens, (1, 1)) 421 | prompt = prompt.to(self.device) 422 | return_seq_without_prompt = False 423 | 424 | prompt, leading_dims = pack([prompt], '* n') 425 | 426 | n, out = prompt.shape[-1], prompt.clone() 427 | 428 | wrapper_fn = identity if not use_tqdm else tqdm 429 | sample_num_times = max(1, seq_len - prompt.shape[-1]) 430 | 431 | for _ in wrapper_fn(range(sample_num_times)): 432 | logits, embeds = self.forward(out, return_logits_with_embedding = True, **kwargs) 433 | logits, embeds = logits[:, -1], embeds[:, -1] 434 | 435 | if exists(filter_logits_fn): 436 | logits = filter_logits_fn(logits, thres = filter_thres) 437 | 438 | sample = gumbel_sample(logits, temperature = temperature, dim = -1) 439 | out, _ = pack([out, sample], 'b *') 440 | 441 | if exists(eos_token): 442 | is_eos_tokens = (out == eos_token) 443 | 444 | if is_eos_tokens.any(dim = -1).all(): 445 | # mask out everything after the eos tokens 446 | shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1)) 447 | mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1 448 | out = out.masked_fill(mask, pad_value) 449 | break 450 | 451 | out, = unpack(out, leading_dims, '* n') 452 | 453 | if not return_seq_without_prompt: 454 | return out 455 | 456 | return out[..., n:] 457 | 458 | def forward( 459 | self, 460 | x, 461 | return_loss = False, 462 | disable_lora = False, 463 | finetune_scope = None, 464 | extra_embed = None, 465 | return_only_embedding = False, 466 | return_logits_with_embedding = False 467 | ): 468 | if return_loss: 469 | x, labels = x[:, :-1], x[:, 1:] 470 | 471 | # mask if encoder 472 | # treat any token ids that are negative as tokens to mask out - only needed if not autoregressive 473 | 474 | if not self.causal: 475 | mask = x >= 0 476 | x = x.masked_fill(~mask, 0) 477 | else: 478 | mask = None 479 | 480 | # get token embedding 481 | 482 | x = self.token_emb(x) 483 | 484 | if exists(extra_embed): 485 | x = x + extra_embed 486 | 487 | # finetune modules 488 | 489 | finetune_modules = tuple() 490 | if exists(finetune_scope) and not disable_lora: 491 | assert finetune_scope in self.finetune_modules 492 | finetune_modules = self.finetune_modules[finetune_scope] 493 | 494 | # parallel attention / ff blocks, passing in finetuning loras 495 | 496 | for layer, finetune_modules in zip_longest(self.layers, finetune_modules): 497 | x = layer(x, mask = mask, finetune_modules = finetune_modules) 498 | 499 | # final norm 500 | 501 | embeds = self.norm(x) 502 | 503 | if return_only_embedding: 504 | return embeds 505 | 506 | # to logits 507 | 508 | logits = self.to_logits(embeds) 509 | 510 | ret = (logits, embeds) if return_logits_with_embedding else logits 511 | 512 | if not return_loss: 513 | return ret 514 | 515 | logits = rearrange(logits, 'b n c -> b c n') 516 | return F.cross_entropy(logits, labels, ignore_index = self.cross_entropy_ignore_index) -------------------------------------------------------------------------------- /palm_rlhf_pytorch/ppo.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | import math 4 | from pathlib import Path 5 | import copy 6 | from accelerate.utils.tqdm import tqdm 7 | from functools import partial 8 | from collections 
import deque, namedtuple 9 | from random import randrange 10 | 11 | from beartype import beartype 12 | from beartype.typing import Callable, Deque 13 | 14 | import torch 15 | from torch import nn, Tensor 16 | from torch.nn import Module 17 | import torch.nn.functional as F 18 | 19 | from torch.utils.data import Dataset, DataLoader 20 | from torch.nn.utils.rnn import pad_sequence 21 | 22 | from einops import rearrange, repeat, reduce 23 | from einops.layers.torch import Rearrange 24 | 25 | from adam_atan2_pytorch import AdoptAtan2 26 | 27 | from hl_gauss_pytorch import HLGaussLoss 28 | 29 | from palm_rlhf_pytorch.palm import PaLM 30 | from palm_rlhf_pytorch.reward import RewardModel 31 | from palm_rlhf_pytorch.implicit_process_reward import ImplicitPRM 32 | from palm_rlhf_pytorch.utils import masked_mean, eval_decorator 33 | from accelerate import Accelerator 34 | 35 | # actor critic - PaLM with lora 36 | 37 | PPOActionCriticReturn = namedtuple('PPOActionCriticReturn', [ 38 | 'actions', 39 | 'sequence', 40 | 'mask', 41 | 'prompt_mask', 42 | 'action_logits', 43 | 'values' 44 | ]) 45 | 46 | class ActorCritic(Module): 47 | @beartype 48 | def __init__( 49 | self, 50 | palm: PaLM, 51 | critic: PaLM | ImplicitPRM | None = None, 52 | pooled_values = False, 53 | actor_lora = True, 54 | critic_lora = True, 55 | actor_lora_r = 8, 56 | critic_lora_r = 8, 57 | actor_lora_scope = 'actor', 58 | critic_lora_scope = 'critic', 59 | actor_dropout = 0., 60 | critic_dropout = 0., 61 | critic_dim_out = 6 # rewards from 0 to 5 62 | ): 63 | super().__init__() 64 | self.actor_palm = palm 65 | 66 | # detect implicit prm and auto-set some hyperparameters 67 | 68 | critic_is_prm = isinstance(critic, ImplicitPRM) 69 | 70 | critic_lora &= not critic_is_prm 71 | pooled_values |= critic_is_prm 72 | 73 | self.critic_is_prm = critic_is_prm 74 | 75 | # critic 76 | 77 | self.critic = critic 78 | 79 | if not exists(self.critic): 80 | self.critic = copy.deepcopy(palm) 81 | 82 | self.actor_palm.set_dropout(actor_dropout) 83 | 84 | if not critic_is_prm: 85 | self.critic.set_dropout(critic_dropout) 86 | 87 | self.actor_lora = actor_lora 88 | self.critic_lora = critic_lora 89 | 90 | self.actor_lora_scope = actor_lora_scope if actor_lora else None 91 | self.critic_lora_scope = critic_lora_scope if critic_lora else None 92 | 93 | if self.actor_lora: 94 | self.actor_palm.add_finetune_params(actor_lora_scope, lora_r = actor_lora_r) 95 | 96 | if self.critic_lora: 97 | self.critic.add_finetune_params(critic_lora_scope, lora_r = critic_lora_r) 98 | 99 | self.pooled_values = pooled_values 100 | self.value_head = nn.Identity() 101 | 102 | if not critic_is_prm: 103 | assert critic_dim_out > 1 104 | self.value_head = nn.Linear(palm.dim, critic_dim_out) 105 | 106 | nn.init.zeros_(self.value_head.bias) 107 | nn.init.orthogonal_(self.value_head.weight, gain = math.sqrt(2)) 108 | 109 | def actor_parameters(self): 110 | if not self.actor_lora: 111 | return self.actor_palm.parameters() 112 | 113 | return [ 114 | *self.actor_palm.finetune_parameters(self.actor_lora_scope) 115 | ] 116 | 117 | def critic_parameters(self): 118 | if self.critic_is_prm: 119 | return self.critic.parameters() 120 | 121 | if not self.actor_lora: 122 | return [*self.critic.parameters(), *self.value_head.parameters()] 123 | 124 | return [ 125 | *self.critic.finetune_parameters(self.critic_lora_scope), 126 | *self.value_head.parameters() 127 | ] 128 | 129 | @torch.no_grad() 130 | @eval_decorator 131 | def generate( 132 | self, 133 | state, 134 | max_seq_len, 135 | eos_token = 
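        # eos_token is optional - when given, everything sampled after the first eos is dropped from the
        # action mask below, so those positions do not feed into the PPO update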
None, 136 | return_values = False, 137 | **kwargs 138 | ): 139 | actions = self.actor_palm.generate( 140 | max_seq_len, 141 | prompt = state, 142 | eos_token = eos_token, 143 | finetune_scope = self.actor_lora_scope, 144 | use_tqdm = True, 145 | **kwargs 146 | ) 147 | 148 | sequence = torch.cat((state, actions), dim = -1) 149 | action_len = actions.shape[-1] 150 | state_len = state.shape[-1] 151 | 152 | prompt_mask = torch.arange(sequence.shape[-1], device = state.device) < state_len 153 | prompt_mask = repeat(prompt_mask, 'n -> b n', b = sequence.shape[0]) 154 | 155 | action_mask = ~prompt_mask 156 | 157 | mask = None 158 | if exists(eos_token): 159 | mask = ((sequence == eos_token).cumsum(dim = -1) == 0) 160 | mask = F.pad(mask, (1, -1), value = True) # include eos token 161 | action_mask &= mask 162 | 163 | action_logits, value = self.forward( 164 | sequence, 165 | mask = action_mask, 166 | return_values = return_values 167 | ) 168 | 169 | return PPOActionCriticReturn( 170 | actions, 171 | sequence, 172 | mask, 173 | prompt_mask, 174 | action_logits, 175 | value 176 | ) 177 | 178 | def forward( 179 | self, 180 | x, 181 | mask = None, 182 | return_values = True 183 | ): 184 | action_logits = self.actor_palm( 185 | x, 186 | finetune_scope = self.actor_lora_scope 187 | ) 188 | 189 | if not return_values: 190 | return action_logits, None 191 | 192 | if self.critic_is_prm: 193 | values = self.critic(x) 194 | return action_logits, values 195 | 196 | critic_embeds = self.critic( 197 | x, 198 | return_only_embedding = True, 199 | finetune_scope = self.critic_lora_scope 200 | ) 201 | 202 | if self.pooled_values: 203 | critic_embeds = shift(critic_embeds, shift = 1, dim = -2) 204 | critic_embeds = masked_mean(critic_embeds, mask, dim = 1) 205 | 206 | values = self.value_head(critic_embeds) 207 | 208 | return action_logits, values 209 | 210 | # data 211 | 212 | Memory = namedtuple('Memory', [ 213 | 'sequence', 214 | 'prompt_mask', 215 | 'mask', 216 | 'action_prob', 217 | 'action_log_prob', 218 | 'reward', 219 | 'value' 220 | ]) 221 | 222 | class ExperienceDataset(Dataset): 223 | @beartype 224 | def __init__( 225 | self, 226 | data, 227 | device = None 228 | ): 229 | super().__init__() 230 | self.data = data 231 | self.device = device 232 | 233 | def __len__(self): 234 | return self.data[0].shape[0] 235 | 236 | def __getitem__(self, ind): 237 | return tuple(map(lambda t: t[ind].to(self.device), self.data)) 238 | 239 | def create_dataloader(data, batch_size, shuffle = True, device = None, **kwargs): 240 | ds = ExperienceDataset(data, device = device) 241 | return DataLoader(ds, batch_size = batch_size, shuffle = shuffle, **kwargs) 242 | 243 | # helper functions 244 | 245 | def exists(val): 246 | return val is not None 247 | 248 | def default(val, d): 249 | if exists(val): 250 | return val 251 | return d() if callable(d) else d 252 | 253 | def masked_normalize(t, eps = 1e-5, mask = None, dim = None): 254 | dim = default(dim, tuple(range(t.ndim))) 255 | kwargs = dict(dim = dim, keepdim = True) 256 | 257 | mean = masked_mean(t, mask = mask, **kwargs) 258 | mean_centered = t - mean 259 | var = masked_mean(mean_centered ** 2, mask = mask, **kwargs) 260 | 261 | return mean_centered * var.clamp(min = eps).rsqrt() 262 | 263 | def pad_sequence_fixed(sequences, *args, **kwargs): 264 | first_el = sequences[0] 265 | has_no_dimension = first_el.ndim == 0 266 | 267 | # if no dimensions, add a single dimension 268 | if has_no_dimension: 269 | sequences = tuple(map(lambda t: t[None], sequences)) 270 | 271 | out = 
pad_sequence(sequences, *args, **kwargs) 272 | 273 | if has_no_dimension: 274 | out = rearrange(out, '... 1 -> ...') 275 | 276 | return out 277 | 278 | def log(t, eps = 1e-20): 279 | return torch.log(t.clamp(min = eps)) 280 | 281 | def log_prob(prob, indices): 282 | assert prob.shape[:2] == indices.shape, f'preceding shapes of prob {prob.shape[:2]} and indices {indices.shape} must match' 283 | return log(prob.gather(-1, indices[..., None])).squeeze(-1) 284 | 285 | def shift(t, value = 0, shift = 1, dim = -1): 286 | zeros = (0, 0) * (-dim - 1) 287 | return F.pad(t, (*zeros, shift, -shift), value = value) 288 | 289 | def masked_entropy(prob, dim = -1, mask = None): 290 | entropies = (prob * log(prob)).sum(dim = -1) 291 | return masked_mean(entropies, mask = mask).mean() 292 | 293 | def masked_kl_div(prob1, prob2, mask = None, reduce_batch = False): 294 | """ 295 | need to account for variable sequence lengths, therefore not using the built-in functional version 296 | """ 297 | kl_divs = (prob1 * (log(prob1) - log(prob2))).sum(dim = -1) 298 | loss = masked_mean(kl_divs, mask) 299 | 300 | if reduce_batch: 301 | return loss.mean() 302 | 303 | return loss 304 | 305 | # rlhf trainer 306 | 307 | class RLHFTrainer(Module): 308 | @beartype 309 | def __init__( 310 | self, 311 | *, 312 | prompts: list[str] | None = None, 313 | prompts_path: str | None = None, 314 | prompt_token_ids: Tensor | None = None, 315 | tokenizer: Callable | None = None, 316 | palm: PaLM, 317 | reward_model: RewardModel, 318 | critic: PaLM | ImplicitPRM | None = None, 319 | actor_critic: ActorCritic | None = None, 320 | actor_lr = 1e-4, 321 | critic_lr = 1e-4, 322 | actor_wd = 0., 323 | critic_wd = 0., 324 | actor_lora = True, 325 | critic_lora = True, 326 | actor_lora_r = 8, 327 | critic_lora_r = 8, 328 | critic_pooled_values = True, 329 | actor_dropout = 0., 330 | critic_dropout = 0., 331 | betas = (0.9, 0.999), 332 | max_norm = None, 333 | eps_clip = 0.2, 334 | value_clip = 0.4, 335 | beta_s = .01, 336 | pad_value = 0., 337 | minibatch_size = 16, 338 | epochs = 1, 339 | kl_div_loss_weight = 0.1, # between old action probs and new action probs - not sure what the right value is 340 | accelerate_kwargs: dict = dict(), 341 | critic_num_pred_bins = 6, 342 | hl_gauss_loss_kwargs: dict = dict( 343 | min_value = 0., 344 | max_value = 5., 345 | min_max_value_on_bin_center = True, 346 | clamp_to_range = True, 347 | sigma_to_bin_ratio = 1. 
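        # these histogram loss settings pair with critic_num_pred_bins = 6: the critic predicts a distribution
        # over reward bins spanning 0. to 5., and transform_from_logits (used later in learn()) converts that
        # distribution back to a scalar value estimate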
348 | ) 349 | ): 350 | super().__init__() 351 | 352 | self.accelerate = Accelerator(**accelerate_kwargs) 353 | 354 | # take care of prompts -> token ids 355 | 356 | assert (exists(prompts) + exists(prompts_path) + exists(prompt_token_ids)) == 1 357 | 358 | if exists(prompts_path): 359 | path = Path(prompts_path) 360 | prompts = path.read_text().split('\n') 361 | 362 | if exists(prompts): 363 | assert len(prompts) > 0, 'no prompts' 364 | assert exists(tokenizer), 'tokenizer must be passed in if raw text prompts are given' 365 | prompt_token_ids = tokenizer(prompts) 366 | 367 | self.pad_value = pad_value # token pad value 368 | self.num_prompts = prompt_token_ids.shape[0] 369 | self.register_buffer('prompt_token_ids', prompt_token_ids) 370 | 371 | # models 372 | 373 | self.palm = palm 374 | 375 | if not exists(actor_critic): 376 | actor_critic = ActorCritic( 377 | palm = palm, 378 | critic = critic, 379 | actor_lora = actor_lora, 380 | critic_lora = critic_lora, 381 | actor_lora_r = actor_lora_r, 382 | critic_lora_r = critic_lora_r, 383 | pooled_values = critic_pooled_values, 384 | actor_dropout = actor_dropout, 385 | critic_dropout = critic_dropout, 386 | critic_dim_out = critic_num_pred_bins 387 | ).to(palm.device) 388 | 389 | self.actor_critic = actor_critic 390 | 391 | self.actor_critic_generate = actor_critic.generate 392 | 393 | self.reward_model = reward_model.eval() 394 | 395 | # critic outputs reward bin prediction 396 | # for classification loss, buying into "Stop Regressing" paper from Farebrother et al. https://arxiv.org/abs/2403.03950 397 | 398 | self.critic_hl_gauss_loss = HLGaussLoss(num_bins = critic_num_pred_bins, **hl_gauss_loss_kwargs) 399 | 400 | # train hyperparameters 401 | 402 | self.epochs = epochs 403 | self.minibatch_size = minibatch_size 404 | self.max_norm = max_norm 405 | 406 | self.kl_div_loss_weight = kl_div_loss_weight 407 | 408 | # optimizers 409 | 410 | self.actor_optim = AdoptAtan2(actor_critic.actor_parameters(), lr = actor_lr, weight_decay = actor_wd, betas = betas) 411 | self.critic_optim = AdoptAtan2(actor_critic.critic_parameters(), lr = critic_lr, weight_decay = critic_wd, betas = betas) 412 | 413 | # ppo hyperparams 414 | 415 | self.eps_clip = eps_clip 416 | self.value_clip = value_clip 417 | self.beta_s = beta_s 418 | 419 | # prepare with accelerator 420 | 421 | ( 422 | self.actor_critic, 423 | self.reward_model, 424 | self.actor_optim, 425 | self.critic_optim 426 | ) = self.accelerate.prepare( 427 | self.actor_critic, 428 | self.reward_model, 429 | self.actor_optim, 430 | self.critic_optim 431 | ) 432 | 433 | 434 | def print(self, msg): 435 | return self.accelerate.print(msg) 436 | 437 | def save(self, filepath = './checkpoint.pt'): 438 | torch.save(self.actor_critic.state_dict(), filepath) 439 | 440 | def load(self, filepath = './checkpoint.pt'): 441 | state_dict = torch.load(filepath) 442 | self.actor_critic.load_state_dict(state_dict) 443 | 444 | @property 445 | def device(self): 446 | return self.accelerate.device 447 | 448 | @torch.no_grad() 449 | def generate( 450 | self, 451 | max_seq_len, 452 | *args, 453 | prompt, 454 | num_samples = 4, # sample 4 per prompt and select the one with highest reward 455 | **kwargs 456 | ): 457 | assert prompt.ndim == 1, 'only one prompt allowed at a time for now' 458 | prompt = repeat(prompt, 'n -> b n', b = num_samples) 459 | 460 | actor_critic = self.accelerate.unwrap_model(self.actor_critic) 461 | reward_model = self.accelerate.unwrap_model(self.reward_model) 462 | 463 | actor_critic.eval() 464 | 465 | ( 
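            # unpack the PPOActionCriticReturn namedtuple from ActorCritic.generate - the critic value is not
            # needed here (return_values = False), since this path only reranks the num_samples candidates by reward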
466 | actions, 467 | sequences, 468 | mask, 469 | prompt_mask, 470 | action_logits, 471 | _ 472 | ) = actor_critic.generate( 473 | prompt, 474 | *args, 475 | max_seq_len = max_seq_len, 476 | return_values = False, 477 | **kwargs 478 | ) 479 | 480 | rewards = reward_model( 481 | sequences, 482 | prompt_mask = prompt_mask, 483 | mask = mask 484 | ) 485 | 486 | best_sequence_index = rewards.topk(1, dim = -1).indices 487 | 488 | best_sequence = sequences[best_sequence_index] 489 | best_sequence = rearrange(best_sequence, '1 ... -> ...') 490 | 491 | return best_sequence 492 | 493 | def learn( 494 | self, 495 | memories: Deque[Memory] 496 | ): 497 | # stack all data stored in the memories 498 | 499 | all_memories_stacked_and_padded = list(map(partial(pad_sequence_fixed, batch_first = True), zip(*memories))) 500 | 501 | # prepare dataloader for policy phase training 502 | 503 | dl = create_dataloader(all_memories_stacked_and_padded, self.minibatch_size, device = self.device) 504 | 505 | self.actor_critic.train() 506 | 507 | # PPO training 508 | 509 | for _ in range(self.epochs): 510 | 511 | for ( 512 | sequences, 513 | prompt_masks, 514 | masks, 515 | old_action_probs, 516 | old_log_probs, 517 | rewards, 518 | old_values_bins 519 | ) in dl: 520 | 521 | action_masks = ~prompt_masks & masks 522 | 523 | action_logits, values_bins = self.actor_critic( 524 | sequences, 525 | mask = action_masks 526 | ) 527 | 528 | action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token 529 | action_len = old_log_probs.shape[-1] 530 | 531 | action_probs = action_logits.softmax(dim = -1) 532 | action_log_probs = log_prob(action_probs, sequences) 533 | action_log_probs = action_log_probs[:, -action_len:] 534 | 535 | # calculate entropies, taking into account which part of the sequence is actually an action 536 | 537 | entropies = masked_entropy(action_probs, mask = action_masks) 538 | 539 | # calculate kl div between old action probs and new ones, taking into account which part of the sequence is action or not 540 | 541 | kl_penalty = 0. 542 | 543 | if self.kl_div_loss_weight > 0: 544 | kl_penalty = masked_kl_div(old_action_probs, action_probs, mask = action_masks) * self.kl_div_loss_weight 545 | 546 | # subtract the kl penalty from the rewards 547 | 548 | rewards = rewards - kl_penalty 549 | 550 | # convert binned value predictions to scalar value 551 | 552 | to_pred_value = self.critic_hl_gauss_loss.transform_from_logits 553 | 554 | old_values, values = map(to_pred_value, (old_values_bins, values_bins)) 555 | 556 | # handle non-pooled values 557 | 558 | normalize_kwargs = dict() 559 | 560 | if old_values.ndim == 2: 561 | old_values, values = map(lambda t: shift(t, shift = 1, dim = -2), (old_values, values)) 562 | 563 | old_values = old_values[:, -action_len:] 564 | values = values[:, -action_len:] 565 | rewards = rearrange(rewards, 'b -> b 1') 566 | normalize_kwargs = dict(dim = -1, mask = action_masks[:, -action_len:]) 567 | 568 | if values.ndim < rewards.ndim: 569 | values = rearrange(values, '... -> ... 
1') 570 | 571 | # calculate clipped surrogate objective, classic PPO loss 572 | 573 | ratios = (action_log_probs - old_log_probs).exp() 574 | advantages = masked_normalize(rewards - old_values, **normalize_kwargs) 575 | 576 | if advantages.ndim == 1: 577 | advantages = rearrange(advantages, 'b -> b 1') 578 | 579 | surr1 = ratios * advantages 580 | surr2 = ratios.clamp(1 - self.eps_clip, 1 + self.eps_clip) * advantages 581 | policy_loss = - torch.min(surr1, surr2) - self.beta_s * entropies 582 | 583 | # combine losses 584 | 585 | loss = policy_loss.mean() 586 | 587 | # update actor 588 | 589 | self.accelerate.backward(loss) 590 | 591 | self.print(f'policy_loss: {loss.item():.3f}') 592 | 593 | if exists(self.max_norm): 594 | self.accelerator.clip_grad_norm_(self.actor_critic.actor_parameters(), self.max_norm) 595 | 596 | self.actor_optim.step() 597 | self.actor_optim.zero_grad() 598 | 599 | # calculate clipped value loss and update value network separate from policy network 600 | 601 | value_clipped = old_values + (values - old_values).clamp(-self.value_clip, self.value_clip) 602 | 603 | rewards.detach_() 604 | 605 | value_loss_1 = self.critic_hl_gauss_loss(value_clipped, rewards, reduction = 'none') 606 | value_loss_2 = self.critic_hl_gauss_loss(values_bins, rewards, reduction = 'none') 607 | 608 | value_loss = torch.maximum(value_loss_1, value_loss_2).mean() 609 | 610 | self.print(f'critic_loss: {value_loss.item():.3f}') 611 | 612 | self.accelerate.backward(value_loss) 613 | 614 | if exists(self.max_norm): 615 | self.accelerator.clip_grad_norm_(self.actor_critic.critic_parameters(), self.max_norm) 616 | 617 | self.critic_optim.step() 618 | self.critic_optim.zero_grad() 619 | 620 | def train( 621 | self, 622 | num_episodes = 50000, 623 | max_timesteps = 500, 624 | update_timesteps = 5000, 625 | max_batch_size = 16, 626 | max_seq_len = 2048, 627 | eos_token = None, 628 | temperature = 1. 629 | ): 630 | device = self.device 631 | 632 | time = 0 633 | memories = deque([]) 634 | 635 | for eps in tqdm(range(num_episodes), desc = 'episodes'): 636 | for timestep in range(max_timesteps): 637 | time += 1 638 | 639 | # select a bunch of random states (prompts) 640 | # and get the action (sampled sequence from palm as well as the action probs) 641 | # also calculate the reward using reward model and store 642 | 643 | rand_prompt_index = randrange(0, self.num_prompts) 644 | 645 | state = self.prompt_token_ids[rand_prompt_index] 646 | 647 | # remove padding from state 648 | 649 | state_mask = state != self.pad_value 650 | state = state[state_mask] 651 | 652 | # get predicted sequence 653 | 654 | ( 655 | actions, 656 | sequence, 657 | mask, 658 | prompt_mask, 659 | action_logits, 660 | values_bins 661 | ) = self.actor_critic_generate( 662 | rearrange(state, 'n ... -> 1 n ...'), 663 | max_seq_len = max_seq_len, 664 | eos_token = eos_token, 665 | temperature = temperature, 666 | return_values = True 667 | ) 668 | 669 | action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token 670 | 671 | action_prob = action_logits.softmax(dim = -1) 672 | 673 | action_len = actions.shape[-1] 674 | action_log_prob = log_prob(action_prob, sequence) 675 | action_log_prob = action_log_prob[:, -action_len:] 676 | 677 | actions = rearrange(actions, '1 ... 
-> ...') 678 | 679 | # get reward as given by supervised trained reward model 680 | 681 | sequence = torch.cat((state, actions), dim = 0) 682 | 683 | prompt_length = len(state) 684 | prompt_mask = torch.arange(sequence.shape[-1], device = device) < prompt_length 685 | 686 | sequence = rearrange(sequence, 'n -> 1 n') 687 | prompt_mask = rearrange(prompt_mask, 'n -> 1 n') 688 | mask = default(mask, lambda: torch.ones(sequence.shape, dtype = torch.bool, device = device)) 689 | 690 | reward = self.reward_model( 691 | sequence, 692 | prompt_mask = prompt_mask, 693 | mask = mask 694 | ) 695 | 696 | detach_to_cpu_ = lambda t: rearrange(t.detach().cpu(), '1 ... -> ...') 697 | 698 | # store memory for learning 699 | 700 | memories.append(Memory(*map(detach_to_cpu_, ( 701 | sequence, 702 | prompt_mask, 703 | mask, 704 | action_prob, 705 | action_log_prob, 706 | reward, 707 | values_bins 708 | )))) 709 | 710 | # learn from the stored memories 711 | 712 | if time % update_timesteps == 0: 713 | self.learn(memories) 714 | memories.clear() 715 | 716 | print('rlhf training complete') 717 | -------------------------------------------------------------------------------- /palm_rlhf_pytorch/reward.py: -------------------------------------------------------------------------------- 1 | import copy 2 | from pathlib import Path 3 | 4 | from tqdm import tqdm 5 | from beartype import beartype 6 | 7 | import torch 8 | from torch import nn 9 | from torch.nn import Module 10 | import torch.nn.functional as F 11 | 12 | from einops import rearrange, repeat, reduce, pack, unpack 13 | from einops.layers.torch import Rearrange, Reduce 14 | 15 | from palm_rlhf_pytorch.utils import masked_mean, gumbel_sample 16 | from palm_rlhf_pytorch.palm import PaLM 17 | 18 | # helper functions 19 | 20 | def exists(val): 21 | return val is not None 22 | 23 | def default(val, default_val): 24 | return val if exists(val) else default_val 25 | 26 | # Reward Model - PaLM with a scalar head 27 | 28 | class RewardModel(Module): 29 | @beartype 30 | def __init__( 31 | self, 32 | palm: PaLM, 33 | dropout = 0.1, 34 | num_binned_output = 0., 35 | use_lora = True, 36 | lora_r = 8, 37 | reward_lora_scope = 'reward', 38 | sample_from_bins = None, 39 | sample_temperature = 1. 40 | ): 41 | super().__init__() 42 | 43 | self.palm = copy.deepcopy(palm) 44 | self.palm.set_dropout(dropout) 45 | 46 | self.reward_lora_scope = reward_lora_scope if use_lora else None 47 | 48 | if exists(self.reward_lora_scope): 49 | self.palm.add_finetune_params(reward_lora_scope, lora_r = lora_r) 50 | 51 | dim = palm.dim 52 | 53 | self.binned_output = num_binned_output > 1 54 | 55 | self.prompt_embed = nn.Parameter(torch.zeros(1, 1, dim)) 56 | self.response_embed = nn.Parameter(torch.zeros(1, 1, dim)) 57 | 58 | if self.binned_output: 59 | self.to_pred = nn.Linear(dim, num_binned_output) 60 | else: 61 | self.to_pred = nn.Sequential( 62 | nn.Linear(dim, 1, bias = False), 63 | Rearrange('... 
1 -> ...') 64 | ) 65 | 66 | self.sample_from_bins = default(sample_from_bins, self.binned_output) 67 | self.sample_temperature = sample_temperature 68 | 69 | def load(self, path): 70 | path = Path(path) 71 | assert path.exists() 72 | self.load_state_dict(torch.load(str(path))) 73 | 74 | def finetune_parameters(self): 75 | return [ 76 | *self.to_pred.parameters(), 77 | *(self.palm.finetune_parameters(self.reward_lora_scope) if exists(self.reward_lora_scope) else self.palm.parameters()) 78 | ] 79 | 80 | def forward( 81 | self, 82 | x, 83 | mask = None, 84 | prompt_mask = None, 85 | prompt_lengths = None, 86 | labels = None, 87 | disable_lora = False 88 | ): 89 | 90 | assert not (exists(prompt_mask) and exists(prompt_lengths)) 91 | 92 | # derive prompt mask from prompt lengths 93 | 94 | if exists(prompt_lengths): 95 | batch, seq_len = x.shape 96 | arange = torch.arange(seq_len, device = x.device) 97 | prompt_mask = repeat(arange, 'n -> b n', b = batch) < rearrange(prompt_lengths, 'b -> b 1') 98 | 99 | # reward model should have an understanding of which section is prompt, and which section is response 100 | 101 | extra_embed = None 102 | 103 | if exists(prompt_mask): 104 | extra_embed = torch.where( 105 | rearrange(prompt_mask, 'b n -> b n 1'), 106 | self.prompt_embed, 107 | self.response_embed 108 | ) 109 | 110 | # get embeddings from palm 111 | 112 | embeds = self.palm( 113 | x, 114 | extra_embed = extra_embed, 115 | return_only_embedding = True, 116 | disable_lora = disable_lora, 117 | finetune_scope = self.reward_lora_scope 118 | ) 119 | 120 | pooled = masked_mean(embeds, mask, dim = 1) 121 | pred = self.to_pred(pooled) 122 | 123 | if self.sample_from_bins and self.binned_output: 124 | assert not exists(labels) 125 | pred = gumbel_sample(pred, temperature = self.sample_temperature, dim = -1) 126 | 127 | if not exists(labels): 128 | return pred 129 | 130 | if not self.binned_output: 131 | return F.mse_loss(pred, labels) 132 | 133 | return F.cross_entropy(pred, labels) 134 | -------------------------------------------------------------------------------- /palm_rlhf_pytorch/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import einsum, nn 4 | import torch.nn.functional as F 5 | 6 | from einops import rearrange 7 | 8 | def exists(val): 9 | return val is not None 10 | 11 | # decorators 12 | 13 | def eval_decorator(fn): 14 | def inner(self, *args, **kwargs): 15 | was_training = self.training 16 | self.eval() 17 | out = fn(self, *args, **kwargs) 18 | self.train(was_training) 19 | return out 20 | return inner 21 | 22 | # tensor helpers 23 | 24 | def log(t, eps = 1e-20): 25 | return torch.log(t.clamp(min = eps)) 26 | 27 | def masked_mean(seq, mask = None, dim = 1, keepdim = False): 28 | if not exists(mask): 29 | return seq.mean(dim = dim) 30 | 31 | if seq.ndim == 3: 32 | mask = rearrange(mask, 'b n -> b n 1') 33 | 34 | masked_seq = seq.masked_fill(~mask, 0.) 35 | numer = masked_seq.sum(dim = dim, keepdim = keepdim) 36 | denom = mask.sum(dim = dim, keepdim = keepdim) 37 | 38 | masked_mean = numer / denom.clamp(min = 1e-3) 39 | masked_mean = masked_mean.masked_fill(denom == 0, 0.) 
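    # rows whose mask is entirely False would otherwise divide by zero - the clamp keeps the division finite
    # and the masked_fill above returns 0. for those rows instead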
40 | return masked_mean 41 | 42 | # sampling helpers 43 | 44 | def gumbel_noise(t): 45 | noise = torch.zeros_like(t).uniform_(0, 1) 46 | return -log(-log(noise)) 47 | 48 | def gumbel_sample(t, temperature = 1., dim = -1): 49 | return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim) 50 | 51 | def top_p(logits, thres = 0.9): 52 | sorted_logits, sorted_indices = torch.sort(logits, descending=True) 53 | cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1) 54 | 55 | sorted_indices_to_remove = cum_probs > (1 - thres) 56 | sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone() 57 | sorted_indices_to_remove[:, 0] = 0 58 | 59 | sorted_logits[sorted_indices_to_remove] = float('-inf') 60 | return sorted_logits.scatter(1, sorted_indices, sorted_logits) 61 | 62 | def top_k(logits, thres = 0.9): 63 | k = math.ceil((1 - thres) * logits.shape[-1]) 64 | val, ind = torch.topk(logits, k) 65 | probs = torch.full_like(logits, float('-inf')) 66 | probs.scatter_(1, ind, val) 67 | return probs 68 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name = 'PaLM-rlhf-pytorch', 5 | packages = find_packages(exclude=[]), 6 | version = '0.5.3', 7 | license='MIT', 8 | description = 'PaLM + Reinforcement Learning with Human Feedback - Pytorch', 9 | author = 'Phil Wang', 10 | author_email = 'lucidrains@gmail.com', 11 | long_description_content_type = 'text/markdown', 12 | url = 'https://github.com/lucidrains/PaLM-rlhf-pytorch', 13 | keywords = [ 14 | 'artificial intelligence', 15 | 'deep learning', 16 | 'transformers', 17 | 'attention mechanism', 18 | 'reinforcement learning', 19 | 'human feedback' 20 | ], 21 | install_requires=[ 22 | 'accelerate', 23 | 'adam-atan2-pytorch', 24 | 'beartype', 25 | 'einx>=0.3.0', 26 | 'einops>=0.8.0', 27 | 'hl-gauss-pytorch>=0.1.19', 28 | 'torch>=2.2', 29 | 'tqdm' 30 | ], 31 | classifiers=[ 32 | 'Development Status :: 4 - Beta', 33 | 'Intended Audience :: Developers', 34 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 35 | 'License :: OSI Approved :: MIT License', 36 | 'Programming Language :: Python :: 3.6', 37 | ], 38 | ) 39 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import random 3 | from accelerate.utils.tqdm import tqdm 4 | import numpy as np 5 | 6 | import torch 7 | from lion_pytorch import Lion 8 | from torch.nn import functional as F 9 | from torch.utils.data import DataLoader, Dataset 10 | 11 | from palm_rlhf_pytorch import PaLM 12 | from accelerate import Accelerator 13 | 14 | # constants 15 | 16 | NUM_BATCHES = int(1e5) 17 | BATCH_SIZE = 4 18 | GRADIENT_ACCUMULATE_EVERY = 4 19 | LEARNING_RATE = 1e-4 20 | VALIDATE_EVERY = 100 21 | PRIME_LENGTH = 128 22 | GENERATE_EVERY = 500 23 | GENERATE_LENGTH = 512 24 | SEQ_LEN = 1024 25 | 26 | # helpers 27 | 28 | def cycle(loader): 29 | while True: 30 | for data in loader: 31 | yield data 32 | 33 | def decode_token(token): 34 | return str(chr(max(32, token))) 35 | 36 | def decode_tokens(tokens): 37 | return "".join(list(map(decode_token, tokens))) 38 | 39 | 40 | # accelerator 41 | 42 | accelerator = Accelerator(gradient_accumulation_steps=GRADIENT_ACCUMULATE_EVERY) 43 | device = accelerator.device 44 | 45 | # instantiate palm 46 | 47 | 
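# enwik8 is modeled at the byte level, so num_tokens = 256 below covers every possible byte value;
# dim = 512 / depth = 8 are small demonstration settings - scale them up (and consider qk_rmsnorm = True)
# for anything beyond a smoke test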
model = PaLM( 48 | num_tokens=256, 49 | dim=512, 50 | depth=8, 51 | flash_attn=True 52 | ).to(device) 53 | 54 | # prepare enwik8 data 55 | 56 | with gzip.open("./data/enwik8.gz") as file: 57 | data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy() 58 | np_train, np_valid = np.split(data, [int(90e6)]) 59 | data_train, data_val = torch.from_numpy(np_train), torch.from_numpy(np_valid) 60 | 61 | class TextSamplerDataset(Dataset): 62 | def __init__(self, data, seq_len): 63 | super().__init__() 64 | self.data = data 65 | self.seq_len = seq_len 66 | 67 | def __getitem__(self, index): 68 | rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,)) 69 | full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long() 70 | return full_seq.to(device) 71 | 72 | def __len__(self): 73 | return self.data.size(0) // self.seq_len 74 | 75 | train_dataset = TextSamplerDataset(data_train, SEQ_LEN) 76 | val_dataset = TextSamplerDataset(data_val, SEQ_LEN) 77 | train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE)) 78 | val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE)) 79 | 80 | # optimizer 81 | 82 | optim = Lion(model.palm_parameters(), lr = LEARNING_RATE) 83 | 84 | model, optim, train_loader, val_loader = accelerator.prepare( 85 | model, optim, train_loader, val_loader 86 | ) 87 | 88 | # training 89 | 90 | for i in tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"): 91 | model.train() 92 | 93 | with accelerator.accumulate(model): 94 | loss = model(next(train_loader), return_loss = True) 95 | accelerator.backward(loss / GRADIENT_ACCUMULATE_EVERY) 96 | 97 | accelerator.print(f"training loss: {loss.item()}") 98 | accelerator.clip_grad_norm_(model.parameters(), 0.5) 99 | 100 | optim.step() 101 | optim.zero_grad() 102 | 103 | if i % VALIDATE_EVERY == 0: 104 | model.eval() 105 | with torch.no_grad(): 106 | loss = model(next(val_loader), return_loss = True) 107 | accelerator.print(f"validation loss: {loss.item()}") 108 | 109 | if i % GENERATE_EVERY == 0: 110 | model.eval() 111 | inp = random.choice(val_dataset)[:PRIME_LENGTH] 112 | prime = decode_tokens(inp) 113 | accelerator.print(f"%s \n\n %s", (prime, "*" * 100)) 114 | 115 | # Check if model is wrapped 116 | if hasattr(model, "module"): 117 | sample = model.module.generate(GENERATE_LENGTH, inp[None, ...]) 118 | else: 119 | sample = model.generate(GENERATE_LENGTH, inp[None, ...]) 120 | 121 | output_str = decode_tokens(sample[0]) 122 | accelerator.print(output_str, "\n") 123 | --------------------------------------------------------------------------------
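Below is a minimal end-to-end sketch of how the pieces above fit together - pretrain the PaLM, fit the reward model, then run PPO-based RLHF and sample from the tuned policy. It is illustrative only: the random data, the 0-5 scoring scale and the tiny episode counts are placeholder assumptions, and only the call signatures defined in palm.py, reward.py and ppo.py above are used.

import torch

from palm_rlhf_pytorch import PaLM
from palm_rlhf_pytorch.reward import RewardModel
from palm_rlhf_pytorch.ppo import RLHFTrainer

# 1. pretrain the language model with the usual next-token cross entropy (train.py does this on enwik8)

palm = PaLM(num_tokens = 256, dim = 512, depth = 8)

seq = torch.randint(0, 256, (1, 1024))
loss = palm(seq, return_loss = True)
loss.backward()

# 2. fit the reward model on human-scored sequences (scalar head + mse, since num_binned_output defaults to 0)

reward_model = RewardModel(palm)

seq = torch.randint(0, 256, (1, 512))
prompt_mask = (torch.arange(512) < 128).unsqueeze(0)  # first 128 tokens are the prompt, the rest the response
score = torch.tensor([4.])                            # illustrative human score, on the same 0-5 range the critic uses

reward_loss = reward_model(seq, prompt_mask = prompt_mask, labels = score)
reward_loss.backward()

# 3. RLHF with PPO against the trained reward model

prompt_token_ids = torch.randint(0, 256, (64, 128))   # stand-in prompts; real ones would come from a tokenizer

trainer = RLHFTrainer(
    palm = palm,
    reward_model = reward_model,
    prompt_token_ids = prompt_token_ids
)

trainer.train(
    num_episodes = 10,
    max_timesteps = 10,
    update_timesteps = 10,
    max_seq_len = 256
)

# sample a few candidate responses and keep the one the reward model scores highest

answer = trainer.generate(256, prompt = prompt_token_ids[0], num_samples = 4)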