├── .github
│   └── workflows
│       └── python-publish.yml
├── .gitignore
├── LICENSE
├── README.md
├── chatgpt.png
├── data
│   ├── README.md
│   └── enwik8.gz
├── examples.py
├── palm_rlhf_pytorch
│   ├── __init__.py
│   ├── attention.py
│   ├── grpo.py
│   ├── implicit_process_reward.py
│   ├── lora.py
│   ├── palm.py
│   ├── ppo.py
│   ├── reward.py
│   └── utils.py
├── setup.py
└── train.py
/.github/workflows/python-publish.yml:
--------------------------------------------------------------------------------
1 |
2 |
3 | # This workflow will upload a Python Package using Twine when a release is created
4 | # For more information see: https://help.github.com/en/actions/language-and-framework-guides/using-python-with-github-actions#publishing-to-package-registries
5 |
6 | # This workflow uses actions that are not certified by GitHub.
7 | # They are provided by a third-party and are governed by
8 | # separate terms of service, privacy policy, and support
9 | # documentation.
10 |
11 | name: Upload Python Package
12 |
13 | on:
14 |   release:
15 |     types: [published]
16 | 
17 | jobs:
18 |   deploy:
19 | 
20 |     runs-on: ubuntu-latest
21 | 
22 |     steps:
23 |     - uses: actions/checkout@v2
24 |     - name: Set up Python
25 |       uses: actions/setup-python@v2
26 |       with:
27 |         python-version: '3.x'
28 |     - name: Install dependencies
29 |       run: |
30 |         python -m pip install --upgrade pip
31 |         pip install build
32 |     - name: Build package
33 |       run: python -m build
34 |     - name: Publish package
35 |       uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
36 |       with:
37 |         user: __token__
38 |         password: ${{ secrets.PYPI_API_TOKEN }}
39 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | *.py,cover
51 | .hypothesis/
52 | .pytest_cache/
53 |
54 | # Translations
55 | *.mo
56 | *.pot
57 |
58 | # Django stuff:
59 | *.log
60 | local_settings.py
61 | db.sqlite3
62 | db.sqlite3-journal
63 |
64 | # Flask stuff:
65 | instance/
66 | .webassets-cache
67 |
68 | # Scrapy stuff:
69 | .scrapy
70 |
71 | # Sphinx documentation
72 | docs/_build/
73 |
74 | # PyBuilder
75 | target/
76 |
77 | # Jupyter Notebook
78 | .ipynb_checkpoints
79 |
80 | # IPython
81 | profile_default/
82 | ipython_config.py
83 |
84 | # pyenv
85 | .python-version
86 |
87 | # pipenv
88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
91 | # install all needed dependencies.
92 | #Pipfile.lock
93 |
94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95 | __pypackages__/
96 |
97 | # Celery stuff
98 | celerybeat-schedule
99 | celerybeat.pid
100 |
101 | # SageMath parsed files
102 | *.sage.py
103 |
104 | # Environments
105 | .env
106 | .venv
107 | env/
108 | venv/
109 | ENV/
110 | env.bak/
111 | venv.bak/
112 |
113 | # Spyder project settings
114 | .spyderproject
115 | .spyproject
116 |
117 | # Rope project settings
118 | .ropeproject
119 |
120 | # mkdocs documentation
121 | /site
122 |
123 | # mypy
124 | .mypy_cache/
125 | .dmypy.json
126 | dmypy.json
127 |
128 | # Pyre type checker
129 | .pyre/
130 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Phil Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | *official chatgpt blogpost*
4 |
5 | ## PaLM + RLHF - Pytorch (wip)
6 |
7 | Implementation of RLHF (Reinforcement Learning from Human Feedback) on top of the PaLM architecture. Maybe I'll add retrieval functionality too, à la RETRO
8 |
9 | If you are interested in replicating something like ChatGPT out in the open, please consider joining Laion
10 |
11 | Potential successor: Direct Preference Optimization - all the code in this repo essentially becomes a binary cross entropy loss, in fewer than 5 lines of code. So much for reward models and PPO
12 |
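For intuition, here is a minimal sketch of that reduction (illustrative only, not part of this repository): given summed token log-probabilities of a preferred and a rejected completion under the policy and a frozen reference model, the DPO objective is just a binary cross entropy on the difference of their log-ratios. The `dpo_loss` helper and its `logp_*` inputs below are assumptions for the sketch.

```python
import torch
import torch.nn.functional as F

# hypothetical helper, not part of palm_rlhf_pytorch
# logp_* are summed per-token log probs of a full completion under policy / reference model
def dpo_loss(logp_chosen, logp_rejected, ref_logp_chosen, ref_logp_rejected, beta = 0.1):
    logits = beta * ((logp_chosen - ref_logp_chosen) - (logp_rejected - ref_logp_rejected))
    return -F.logsigmoid(logits).mean()

loss = dpo_loss(torch.randn(8), torch.randn(8), torch.randn(8), torch.randn(8))
```
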
13 | ## FAQ
14 |
15 | - Does this contain a model for inference?
16 |
17 | There is no trained model. This is just the ship and overall map. We still need millions of dollars of compute + data to sail to the correct point in high dimensional parameter space. Even then, you need professional sailors (like Robin Rombach of Stable Diffusion fame) to actually guide the ship through turbulent times to that point.
18 |
19 | ## Community
20 |
21 | CarperAI had been working on an RLHF framework for large language models for many months prior to the release of ChatGPT.
22 |
23 | Yannic Kilcher is also working on an open-source implementation
24 |
25 | AI Coffeebreak w/ Letitia | Code Emporium | Code Emporium Part 2
26 |
27 | ## Appreciation
28 |
29 | - Stability.ai for the generous sponsorship to work on cutting edge artificial intelligence research
30 |
31 | - 🤗 Hugging Face and CarperAI for penning the blog post Illustrating Reinforcement Learning from Human Feedback (RLHF), and the former also for their accelerate library
32 |
33 | - @kisseternity and @taynoel84 for the code review and finding bugs
34 |
35 | - Enrico for integrating Flash Attention from Pytorch 2.0
36 |
37 | ## Install
38 |
39 | ```bash
40 | $ pip install palm-rlhf-pytorch
41 | ```
42 |
43 | ## Usage
44 |
45 | First train `PaLM`, like any other autoregressive transformer
46 |
47 | ```python
48 | import torch
49 | from palm_rlhf_pytorch import PaLM
50 |
51 | palm = PaLM(
52 |     num_tokens = 20000,
53 |     dim = 512,
54 |     depth = 12,
55 |     flash_attn = True # https://arxiv.org/abs/2205.14135
56 | ).cuda()
57 |
58 | seq = torch.randint(0, 20000, (1, 2048)).cuda()
59 |
60 | loss = palm(seq, return_loss = True)
61 | loss.backward()
62 |
63 | # after much training, you can now generate sequences
64 |
65 | generated = palm.generate(2048) # (1, 2048)
66 | ```
67 |
68 | Then train your reward model on the curated human feedback. In the original paper, they could not get the reward model to be finetuned from a pretrained transformer without overfitting, but I gave the option to finetune with `LoRA` anyway, since it is still open research.
69 |
70 | ```python
71 | import torch
72 | from palm_rlhf_pytorch import PaLM, RewardModel
73 |
74 | palm = PaLM(
75 |     num_tokens = 20000,
76 |     dim = 512,
77 |     depth = 12,
78 |     causal = False
79 | )
80 |
81 | reward_model = RewardModel(
82 |     palm,
83 |     num_binned_output = 5 # say rating from 1 to 5
84 | ).cuda()
85 |
86 | # mock data
87 |
88 | seq = torch.randint(0, 20000, (1, 1024)).cuda()
89 | prompt_mask = torch.zeros(1, 1024).bool().cuda() # which part of the sequence is prompt, which part is response
90 | labels = torch.randint(0, 5, (1,)).cuda()
91 |
92 | # train
93 |
94 | loss = reward_model(seq, prompt_mask = prompt_mask, labels = labels)
95 | loss.backward()
96 |
97 | # after much training
98 |
99 | reward = reward_model(seq, prompt_mask = prompt_mask)
100 | ```
101 |
102 | Then you will pass your transformer and the reward model to the `RLHFTrainer`
103 |
104 | ```python
105 | import torch
106 | from palm_rlhf_pytorch import PaLM, RewardModel, RLHFTrainer
107 |
108 | # load your pretrained palm
109 |
110 | palm = PaLM(
111 |     num_tokens = 20000,
112 |     dim = 512,
113 |     depth = 12
114 | ).cuda()
115 |
116 | palm.load('./path/to/pretrained/palm.pt')
117 |
118 | # load your pretrained reward model
119 |
120 | reward_model = RewardModel(
121 |     palm,
122 |     num_binned_output = 5
123 | ).cuda()
124 |
125 | reward_model.load('./path/to/pretrained/reward_model.pt')
126 |
127 | # ready your list of prompts for reinforcement learning
128 |
129 | prompts = torch.randint(0, 256, (50000, 512)).cuda() # 50k prompts
130 |
131 | # pass it all to the trainer and train
132 |
133 | trainer = RLHFTrainer(
134 |     palm = palm,
135 |     reward_model = reward_model,
136 |     prompt_token_ids = prompts
137 | )
138 |
139 | trainer.train(num_episodes = 50000)
140 |
141 | # then, if it succeeded...
142 | # generate say 10 samples and use the reward model to return the best one
143 |
144 | answer = trainer.generate(2048, prompt = prompts[0], num_samples = 10) # (<= 2048,)
145 | ```
146 |
147 | ## Todo
148 |
149 | - [x] clone base transformer with separate lora for critic
150 | - [x] also allow for non-LoRA based finetuning
151 | - [x] redo normalize to support a masked version; not sure if anyone will ever use per-token rewards / values, but good practice to implement
152 | - [x] equip with the best attention
153 |
154 | - [ ] add Hugging Face accelerate and test out wandb instrumentation
155 | - [ ] search the literature to figure out the latest SOTA for PPO, assuming the RL field is still making progress
156 | - [ ] test the system using a pretrained sentiment network as reward model
157 | - [ ] write the memory in PPO to memmapped numpy file
158 | - [ ] get sampling with variable-length prompts working, even if it is not needed given the bottleneck is human feedback
159 | - [ ] allow for finetuning only the penultimate N layers in either actor or critic, assuming pretrained
160 | - [ ] incorporate some learning points from Sparrow, given Letitia's video
161 | - [ ] simple web interface with django + htmx for collecting human feedback
162 | - [ ] consider RLAIF
163 |
164 | ## Citations
165 |
166 | ```bibtex
167 | @article{Stiennon2020LearningTS,
168 | title = {Learning to summarize from human feedback},
169 | author = {Nisan Stiennon and Long Ouyang and Jeff Wu and Daniel M. Ziegler and Ryan J. Lowe and Chelsea Voss and Alec Radford and Dario Amodei and Paul Christiano},
170 | journal = {ArXiv},
171 | year = {2020},
172 | volume = {abs/2009.01325}
173 | }
174 | ```
175 |
176 | ```bibtex
177 | @inproceedings{Chowdhery2022PaLMSL,
178 | title = {PaLM: Scaling Language Modeling with Pathways},
179 | author = {Aakanksha Chowdhery and Sharan Narang and Jacob Devlin and Maarten Bosma and Gaurav Mishra and Adam Roberts and Paul Barham and Hyung Won Chung and Charles Sutton and Sebastian Gehrmann and Parker Schuh and Kensen Shi and Sasha Tsvyashchenko and Joshua Maynez and Abhishek Rao and Parker Barnes and Yi Tay and Noam M. Shazeer and Vinodkumar Prabhakaran and Emily Reif and Nan Du and Benton C. Hutchinson and Reiner Pope and James Bradbury and Jacob Austin and Michael Isard and Guy Gur-Ari and Pengcheng Yin and Toju Duke and Anselm Levskaya and Sanjay Ghemawat and Sunipa Dev and Henryk Michalewski and Xavier Garc{\'i}a and Vedant Misra and Kevin Robinson and Liam Fedus and Denny Zhou and Daphne Ippolito and David Luan and Hyeontaek Lim and Barret Zoph and Alexander Spiridonov and Ryan Sepassi and David Dohan and Shivani Agrawal and Mark Omernick and Andrew M. Dai and Thanumalayan Sankaranarayana Pillai and Marie Pellat and Aitor Lewkowycz and Erica Oliveira Moreira and Rewon Child and Oleksandr Polozov and Katherine Lee and Zongwei Zhou and Xuezhi Wang and Brennan Saeta and Mark Diaz and Orhan Firat and Michele Catasta and Jason Wei and Kathleen S. Meier-Hellstern and Douglas Eck and Jeff Dean and Slav Petrov and Noah Fiedel},
180 | year = {2022}
181 | }
182 | ```
183 |
184 | ```bibtex
185 | @article{Hu2021LoRALA,
186 | title = {LoRA: Low-Rank Adaptation of Large Language Models},
187 | author = {Edward J. Hu and Yelong Shen and Phillip Wallis and Zeyuan Allen-Zhu and Yuanzhi Li and Shean Wang and Weizhu Chen},
188 | journal = {ArXiv},
189 | year = {2021},
190 | volume = {abs/2106.09685}
191 | }
192 | ```
193 |
194 | ```bibtex
195 | @inproceedings{Sun2022ALT,
196 | title = {A Length-Extrapolatable Transformer},
197 | author = {Yutao Sun and Li Dong and Barun Patra and Shuming Ma and Shaohan Huang and Alon Benhaim and Vishrav Chaudhary and Xia Song and Furu Wei},
198 | year = {2022}
199 | }
200 | ```
201 |
202 | ```bibtex
203 | @misc{gilmer2023intriguing,
204 |     title  = {Intriguing Properties of Transformer Training Instabilities},
205 |     author = {Justin Gilmer and Andrea Schioppa and Jeremy Cohen},
206 |     year   = {2023},
207 |     status = {to be published - one attention stabilization technique is circulating within Google Brain, being used by multiple teams}
208 | }
209 | ```
210 |
211 | ```bibtex
212 | @inproceedings{dao2022flashattention,
213 | title = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
214 | author = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
215 | booktitle = {Advances in Neural Information Processing Systems},
216 | year = {2022}
217 | }
218 | ```
219 |
220 | ```bibtex
221 | @misc{Rubin2024,
222 | author = {Ohad Rubin},
223 | url = {https://medium.com/@ohadrubin/exploring-weight-decay-in-layer-normalization-challenges-and-a-reparameterization-solution-ad4d12c24950}
224 | }
225 | ```
226 |
227 | ```bibtex
228 | @inproceedings{Yuan2024FreePR,
229 | title = {Free Process Rewards without Process Labels},
230 | author = {Lifan Yuan and Wendi Li and Huayu Chen and Ganqu Cui and Ning Ding and Kaiyan Zhang and Bowen Zhou and Zhiyuan Liu and Hao Peng},
231 | year = {2024},
232 | url = {https://api.semanticscholar.org/CorpusID:274445748}
233 | }
234 | ```
235 |
236 | ```bibtex
237 | @article{Shao2024DeepSeekMathPT,
238 | title = {DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models},
239 | author = {Zhihong Shao and Peiyi Wang and Qihao Zhu and Runxin Xu and Jun-Mei Song and Mingchuan Zhang and Y. K. Li and Yu Wu and Daya Guo},
240 | journal = {ArXiv},
241 | year = {2024},
242 | volume = {abs/2402.03300},
243 | url = {https://api.semanticscholar.org/CorpusID:267412607}
244 | }
245 | ```
246 |
247 | ```bibtex
248 | @article{Farebrother2024StopRT,
249 | title = {Stop Regressing: Training Value Functions via Classification for Scalable Deep RL},
250 | author = {Jesse Farebrother and Jordi Orbay and Quan Ho Vuong and Adrien Ali Taiga and Yevgen Chebotar and Ted Xiao and Alex Irpan and Sergey Levine and Pablo Samuel Castro and Aleksandra Faust and Aviral Kumar and Rishabh Agarwal},
251 | journal = {ArXiv},
252 | year = {2024},
253 | volume = {abs/2403.03950},
254 | url = {https://api.semanticscholar.org/CorpusID:268253088}
255 | }
256 | ```
257 |
258 | ```bibtex
259 | @misc{Liu2025,
260 | title = {Understanding R1-Zero-Like Training: A Critical Perspective},
261 |     author = {Zichen Liu and Changyu Chen and Wenjun Li and Penghui Qi and Tianyu Pang and Chao Du and Wee Sun Lee and Min Lin},
262 | url = {https://github.com/sail-sg/understand-r1-zero/blob/main/understand-r1-zero.pdf}
263 | }
264 | ```
265 |
266 | ```bibtex
267 | @inproceedings{Yue2025DoesRL,
268 | title = {Does Reinforcement Learning Really Incentivize Reasoning Capacity in LLMs Beyond the Base Model?},
269 | author = {Yang Yue and Zhiqi Chen and Rui Lu and Andrew Zhao and Zhaokai Wang and Shiji Song and Gao Huang},
270 | year = {2025},
271 | url = {https://api.semanticscholar.org/CorpusID:277940134}
272 | }
273 | ```
274 |
--------------------------------------------------------------------------------
/chatgpt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/PaLM-rlhf-pytorch/114b4db005809c8c055db35257af3ce49395b75a/chatgpt.png
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | # Data source
2 |
3 | The enwik8 data was downloaded from the Hutter prize page: http://prize.hutter1.net/
--------------------------------------------------------------------------------
/data/enwik8.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lucidrains/PaLM-rlhf-pytorch/114b4db005809c8c055db35257af3ce49395b75a/data/enwik8.gz
--------------------------------------------------------------------------------
/examples.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from palm_rlhf_pytorch import PaLM, RewardModel, RLHFTrainer
3 | from accelerate import Accelerator
4 |
5 | accelerator = Accelerator()
6 | device = accelerator.device
7 |
8 | # load your pretrained palm
9 |
10 | palm = PaLM(
11 |     num_tokens = 20000,
12 |     dim = 512,
13 |     depth = 12
14 | ).to(device)
15 |
16 |
17 | # load your pretrained reward model
18 |
19 | reward_model = RewardModel(
20 |     palm,
21 |     num_binned_output = 5
22 | ).to(device)
23 |
24 | # Train your reward model on mock data:
25 | # mock data
26 |
27 | seq = torch.randint(0, 20000, (1, 1024)).to(device)
28 | prompt_mask = torch.zeros(1, 1024).bool().to(device) # which part of the sequence is prompt, which part is response
29 | labels = torch.randint(0, 5, (1,)).to(device)
30 |
31 | # train
32 | loss = reward_model(seq, prompt_mask = prompt_mask, labels = labels)
33 | accelerator.backward(loss)
34 |
35 | # after much training
36 | reward = reward_model(seq, prompt_mask = prompt_mask)
37 |
38 |
39 | # ready your list of prompts for reinforcement learning
40 |
41 | prompts = torch.randint(0, 256, (1, 512)).to(device) # 1 prompt
42 |
43 | # pass it all to the trainer and train
44 |
45 | trainer = RLHFTrainer(
46 |     palm = palm,
47 |     reward_model = reward_model,
48 |     prompt_token_ids = prompts
49 | )
50 |
51 | accelerator.print("Training")
52 | trainer.train(
53 |     num_episodes = 1,
54 |     max_timesteps = 1,
55 |     update_timesteps = 1,
56 |     max_batch_size = 256,
57 |     max_seq_len = 2048,
58 |     eos_token = None,
59 |     temperature = 1.
60 | )
61 |
62 | # then, if it succeeded...
63 | # generate say 10 samples and use the reward model to return the best one
64 | accelerator.print("Generating answer")
65 | answer = trainer.generate(2048, prompt = prompts[0], num_samples = 10) # (<= 2048,)
66 | accelerator.print(f"answer: {answer}")
--------------------------------------------------------------------------------
/palm_rlhf_pytorch/__init__.py:
--------------------------------------------------------------------------------
1 | from palm_rlhf_pytorch.palm import PaLM
2 | from palm_rlhf_pytorch.ppo import RLHFTrainer, ActorCritic
3 |
4 | from palm_rlhf_pytorch.reward import RewardModel
5 | from palm_rlhf_pytorch.implicit_process_reward import ImplicitPRM
6 |
--------------------------------------------------------------------------------
/palm_rlhf_pytorch/attention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn, einsum
3 | from torch.nn import Module
4 | import torch.nn.functional as F
5 |
6 | from collections import namedtuple
7 | from functools import wraps
8 | from packaging import version
9 |
10 | from einops import rearrange
11 |
12 | # constants
13 |
14 | Config = namedtuple('EfficientAttentionConfig', ['enable_flash', 'enable_math', 'enable_mem_efficient'])
15 |
16 | # helpers
17 |
18 | def exists(val):
19 | return val is not None
20 |
21 | def once(fn):
22 | called = False
23 | @wraps(fn)
24 | def inner(x):
25 | nonlocal called
26 | if called:
27 | return
28 | called = True
29 | return fn(x)
30 | return inner
31 |
32 | print_once = once(print)
33 |
34 | # main class
35 |
36 | class Attention(Module):
37 | def __init__(
38 | self,
39 | dropout = 0.,
40 | causal = False,
41 | use_flash_attn = False
42 | ):
43 | super().__init__()
44 | self.dropout = dropout
45 | self.attn_dropout = nn.Dropout(dropout)
46 |
47 | self.causal = causal
48 | self.register_buffer("mask", None, persistent=False)
49 |
50 | self.use_flash_attn = use_flash_attn
51 | assert not (use_flash_attn and version.parse(torch.__version__) < version.parse('2.0.0')), 'in order to use flash attention, you must be using pytorch 2.0 or above'
52 |
53 | # determine efficient attention configs for cuda and cpu
54 |
55 | self.cpu_config = Config(True, True, True)
56 | self.cuda_config = None
57 |
58 | if not torch.cuda.is_available() or not use_flash_attn:
59 | return
60 |
61 | device_properties = torch.cuda.get_device_properties(torch.device('cuda'))
62 |
63 | if device_properties.major == 8 and device_properties.minor == 0:
64 | print_once('A100 GPU detected, using flash attention if input tensor is on cuda')
65 | self.cuda_config = Config(True, False, False)
66 | else:
67 | print_once('Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda')
68 | self.cuda_config = Config(False, True, True)
69 |
70 | def get_mask(self, n, device):
71 | if exists(self.mask) and self.mask.shape[-1] >= n:
72 | return self.mask[:n, :n]
73 |
74 | mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
75 | self.register_buffer("mask", mask, persistent=False)
76 | return mask
77 |
78 | def flash_attn(self, q, k, v, mask = None):
79 | _, heads, q_len, _, k_len, is_cuda = *q.shape, k.shape[-2], q.is_cuda
80 |
81 | # Recommended for multi-query single-key-value attention by Tri Dao
82 | # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
83 |
84 | k = rearrange(k, 'b ... -> b 1 ...').expand_as(q)
85 | v = rearrange(v, 'b ... -> b 1 ...').expand_as(q)
86 |
87 | # Check if mask exists and expand to compatible shape
88 | # The mask is B L, so it would have to be expanded to B H N L
89 |
90 | if exists(mask):
91 | mask = rearrange(mask, 'b j -> b 1 1 j')
92 | mask = mask.expand(-1, heads, q_len, -1)
93 |
94 | # Check if there is a compatible device for flash attention
95 |
96 | config = self.cuda_config if is_cuda else self.cpu_config
97 |
98 | # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
99 |
100 | with torch.backends.cuda.sdp_kernel(**config._asdict()):
101 | out = F.scaled_dot_product_attention(
102 | q, k, v,
103 | attn_mask = mask,
104 | dropout_p = self.dropout if self.training else 0.,
105 | is_causal = self.causal
106 | )
107 |
108 | return out
109 |
110 | def forward(self, q, k, v, mask = None):
111 | """
112 | einstein notation
113 | b - batch
114 | h - heads
115 | n, i, j - sequence length (base sequence length, source, target)
116 | d - feature dimension
117 | """
118 |
119 | n, device = q.shape[-2], q.device
120 |
121 | scale = q.shape[-1] ** -0.5
122 |
123 | if self.use_flash_attn:
124 | return self.flash_attn(q, k, v, mask = mask)
125 |
126 | # similarity
127 |
128 | sim = einsum("b h i d, b j d -> b h i j", q, k) * scale
129 |
130 | # key padding mask
131 |
132 | if exists(mask):
133 | mask = rearrange(mask, 'b j -> b 1 1 j')
134 | sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)
135 |
136 | # causal mask
137 |
138 | if self.causal:
139 | causal_mask = self.get_mask(n, device)
140 | sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
141 |
142 | # attention
143 |
144 | attn = sim.softmax(dim=-1)
145 | attn = self.attn_dropout(attn)
146 |
147 | # aggregate values
148 |
149 | out = einsum("b h i j, b j d -> b h i d", attn, v)
150 |
151 | return out
152 |
--------------------------------------------------------------------------------
/palm_rlhf_pytorch/grpo.py:
--------------------------------------------------------------------------------
1 | """
2 | GRPO based training logic - https://arxiv.org/abs/2402.03300
3 | """
4 |
5 | from __future__ import annotations
6 | from typing import Callable, Deque
7 |
8 | import math
9 | import copy
10 | from pathlib import Path
11 | from functools import partial
12 | from collections import deque, namedtuple
13 | from random import randrange
14 |
15 | import torch
16 | from torch import nn, Tensor
17 | from torch.nn import Module
18 | import torch.nn.functional as F
19 |
20 | from torch.utils.data import Dataset, DataLoader
21 | from torch.nn.utils.rnn import pad_sequence
22 |
23 | from adam_atan2_pytorch import AdoptAtan2
24 |
25 | from palm_rlhf_pytorch.palm import PaLM
26 | from palm_rlhf_pytorch.reward import RewardModel
27 | from palm_rlhf_pytorch.utils import masked_mean, eval_decorator
28 |
29 | from accelerate import Accelerator
30 | from accelerate.utils.tqdm import tqdm
31 |
32 | from einx import get_at
33 | from einops import rearrange, repeat, reduce, pack, unpack
34 | from einops.layers.torch import Rearrange
35 |
36 | # einstein notation
37 |
38 | # b - batch
39 | # n - sequence
40 | # d - feature dimension
41 | # l - logits
42 |
43 | # grpo based training
44 |
45 | # critic completely replaced with monte carlo sampling from actor + reward model
46 | # https://www.youtube.com/watch?v=bAWV_yrqx4w
47 |
48 | GRPOActionReturn = namedtuple('GRPOActionReturn', [
49 | 'actions',
50 | 'sequence',
51 | 'mask',
52 | 'prompt_mask',
53 | 'action_logits',
54 | ])
55 |
56 | class Actor(Module):
57 | def __init__(
58 | self,
59 | palm: PaLM,
60 | actor_lora = True,
61 | actor_lora_r = 8,
62 | actor_lora_scope = 'actor',
63 | actor_dropout = 0.,
64 | ):
65 | super().__init__()
66 | self.actor_palm = palm
67 |
68 | self.actor_palm.set_dropout(actor_dropout)
69 |
70 | self.actor_lora = actor_lora
71 |
72 | self.actor_lora_scope = actor_lora_scope if actor_lora else None
73 |
74 | if self.actor_lora:
75 | self.actor_palm.add_finetune_params(actor_lora_scope, lora_r = actor_lora_r)
76 |
77 | def parameters(self):
78 | if not self.actor_lora:
79 | return self.actor_palm.parameters()
80 |
81 | return [
82 | *self.actor_palm.finetune_parameters(self.actor_lora_scope)
83 | ]
84 |
85 | @torch.no_grad()
86 | @eval_decorator
87 | def generate(
88 | self,
89 | state,
90 | max_seq_len,
91 | eos_token = None,
92 | **kwargs
93 | ):
94 | actions = self.actor_palm.generate(
95 | max_seq_len,
96 | prompt = state,
97 | eos_token = eos_token,
98 | finetune_scope = self.actor_lora_scope,
99 | use_tqdm = True,
100 | **kwargs
101 | )
102 |
103 | sequence = torch.cat((state, actions), dim = -1)
104 | action_len = actions.shape[-1]
105 | state_len = state.shape[-1]
106 |
107 | prompt_mask = torch.arange(sequence.shape[-1], device = state.device) < state_len
108 | prompt_mask = repeat(prompt_mask, 'n -> b n', b = sequence.shape[0])
109 |
110 | action_mask = ~prompt_mask
111 |
112 | mask = None
113 | if exists(eos_token):
114 | mask = ((sequence == eos_token).cumsum(dim = -1) == 0)
115 | mask = F.pad(mask, (1, -1), value = True) # include eos token
116 | action_mask &= mask
117 |
118 | action_logits = self.forward(
119 | sequence,
120 | mask = action_mask,
121 | )
122 |
123 | return GRPOActionReturn(
124 | actions,
125 | sequence,
126 | mask,
127 | prompt_mask,
128 | action_logits
129 | )
130 |
131 | def forward(
132 | self,
133 | x,
134 | mask = None,
135 | ):
136 | return self.actor_palm(x, finetune_scope = self.actor_lora_scope)
137 |
138 | # data
139 |
140 | Memory = namedtuple('Memory', [
141 | 'sequence',
142 | 'prompt_mask',
143 | 'mask',
144 | 'action_prob',
145 | 'action_log_prob',
146 | 'group_relative_normalized_reward',
147 | ])
148 |
149 | class ExperienceDataset(Dataset):
150 | def __init__(
151 | self,
152 | data,
153 | device = None
154 | ):
155 | super().__init__()
156 | self.data = data
157 | self.device = device
158 |
159 | def __len__(self):
160 | return self.data[0].shape[0]
161 |
162 | def __getitem__(self, ind):
163 | return tuple(map(lambda t: t[ind].to(self.device), self.data))
164 |
165 | def create_dataloader(data, batch_size, shuffle = True, device = None, **kwargs):
166 | ds = ExperienceDataset(data, device = device)
167 | return DataLoader(ds, batch_size = batch_size, shuffle = shuffle, **kwargs)
168 |
169 | # helper functions
170 |
171 | def exists(val):
172 | return val is not None
173 |
174 | def default(val, d):
175 | if exists(val):
176 | return val
177 | return d() if callable(d) else d
178 |
179 | def first(x):
180 | return x[0]
181 |
182 | def divisible_by(num, den):
183 | return (num % den) == 0
184 |
185 | def pad_sequence_fixed(sequences, *args, **kwargs):
186 | first_el = sequences[0]
187 | has_no_dimension = first_el.ndim == 0
188 |
189 | # if no dimensions, add a single dimension
190 | if has_no_dimension:
191 | sequences = tuple(map(lambda t: t[None], sequences))
192 |
193 | out = pad_sequence(sequences, *args, **kwargs)
194 |
195 | if not has_no_dimension:
196 | return out
197 |
198 | return rearrange(out, '... 1 -> ...')
199 |
200 | def log(t, eps = 1e-20):
201 | return torch.log(t.clamp(min = eps))
202 |
203 | def shift(t, value = 0, shift = 1, dim = -1):
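    # shifts `t` by `shift` positions along `dim`, padding the vacated slots with `value`
    # used below to align each sequence position's logits with the token actually sampled at that position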
204 | zeros = (0, 0) * (-dim - 1)
205 | return F.pad(t, (*zeros, shift, -shift), value = value)
206 |
207 | def masked_entropy(prob, dim = -1, mask = None):
208 | entropies = (prob * log(prob)).sum(dim = -1)
209 | return masked_mean(entropies, mask = mask).mean()
210 |
211 | def masked_kl_div(prob1, prob2, mask = None, reduce_batch = False):
212 | """
213 | need to account for variable sequence lengths, therefore not using the built-in functional version
214 | """
215 | kl_divs = (prob1 * (log(prob1) - log(prob2))).sum(dim = -1)
216 | loss = masked_mean(kl_divs, mask)
217 |
218 | if not reduce_batch:
219 | return loss
220 |
221 | return loss.mean()
222 |
223 | # rlhf trainer
224 |
225 | class RLHFTrainer(Module):
226 |
227 | def __init__(
228 | self,
229 | *,
230 | prompts: list[str] | None = None,
231 | prompts_path: str | None = None,
232 | prompt_token_ids: Tensor | None = None,
233 | tokenizer: Callable | None = None,
234 | palm: PaLM,
235 | reward_model: RewardModel,
236 | grpo_num_times_sample_rewards = 10,
237 | actor_lr = 1e-4,
238 | actor_wd = 0.,
239 | actor_lora = True,
240 | actor_lora_r = 8,
241 | actor_dropout = 0.,
242 | betas = (0.9, 0.999),
243 | max_norm = None,
244 | eps_clip = 0.2,
245 | beta_s = .01,
246 | pad_value = 0.,
247 | minibatch_size = 16,
248 | epochs = 1,
249 | kl_div_loss_weight = 0.1, # between old action probs and new action probs - not sure what the right value is
250 | accelerate_kwargs: dict = dict(),
251 | ):
252 | super().__init__()
253 |
254 | self.accelerate = Accelerator(**accelerate_kwargs)
255 |
256 | # take care of prompts -> token ids
257 |
258 | assert (exists(prompts) + exists(prompts_path) + exists(prompt_token_ids)) == 1
259 |
260 | if exists(prompts_path):
261 | path = Path(prompts_path)
262 | prompts = path.read_text().split('\n')
263 |
264 | if exists(prompts):
265 | assert len(prompts) > 0, 'no prompts'
266 | assert exists(tokenizer), 'tokenizer must be passed in if raw text prompts are given'
267 | prompt_token_ids = tokenizer(prompts)
268 |
269 | self.pad_value = pad_value # token pad value
270 | self.num_prompts = prompt_token_ids.shape[0]
271 | self.register_buffer('prompt_token_ids', prompt_token_ids)
272 |
273 | # models
274 |
275 | self.palm = palm
276 |
277 | actor = Actor(
278 | palm = palm,
279 | actor_lora = actor_lora,
280 | actor_lora_r = actor_lora_r,
281 | actor_dropout = actor_dropout,
282 | )
283 |
284 | self.actor = actor
285 |
286 | self.actor_generate = self.actor.generate
287 |
288 | self.reward_model = reward_model.eval()
289 |
290 | # train hyperparameters
291 |
292 | self.epochs = epochs
293 | self.minibatch_size = minibatch_size
294 | self.max_norm = max_norm
295 |
296 | self.kl_div_loss_weight = kl_div_loss_weight
297 |
298 | # optimizers
299 |
300 | self.actor_optim = AdoptAtan2(actor.parameters(), lr = actor_lr, weight_decay = actor_wd, betas = betas)
301 |
302 | # grpo hyperparams
303 |
304 | self.eps_clip = eps_clip
305 | self.beta_s = beta_s
306 |
307 | # grpo - the number of times to sample rewards for a given state (prompt) for normalization
308 |
309 | self.grpo_num_times_sample_rewards = grpo_num_times_sample_rewards
310 |
311 | # prepare with accelerator
312 |
313 | (
314 | self.actor,
315 | self.reward_model,
316 | self.actor_optim,
317 | ) = self.accelerate.prepare(
318 | self.actor,
319 | self.reward_model,
320 | self.actor_optim,
321 | )
322 |
323 |
324 | def print(self, msg):
325 | return self.accelerate.print(msg)
326 |
327 | def save(self, filepath = './checkpoint.pt'):
328 | torch.save(self.actor.state_dict(), filepath)
329 |
330 | def load(self, filepath = './checkpoint.pt'):
331 | state_dict = torch.load(filepath)
332 | self.actor.load_state_dict(state_dict)
333 |
334 | @property
335 | def device(self):
336 | return self.accelerate.device
337 |
338 | @torch.no_grad()
339 | def generate(
340 | self,
341 | max_seq_len,
342 | *args,
343 | prompt,
344 | num_samples = 4, # sample 4 per prompt and select the one with highest reward
345 | **kwargs
346 | ):
347 | assert prompt.ndim == 1, 'only one prompt allowed at a time for now'
348 | prompt = repeat(prompt, 'n -> b n', b = num_samples)
349 |
350 | actor = self.accelerate.unwrap_model(self.actor)
351 | reward_model = self.accelerate.unwrap_model(self.reward_model)
352 |
353 | actor.eval()
354 |
355 | (
356 | actions,
357 | sequences,
358 | mask,
359 | prompt_mask,
360 |             action_logits
362 | ) = actor.generate(
363 | prompt,
364 | *args,
365 | max_seq_len = max_seq_len,
366 | **kwargs
367 | )
368 |
369 | rewards = reward_model(
370 | sequences,
371 | prompt_mask = prompt_mask,
372 | mask = mask
373 | )
374 |
375 | best_sequence_index = rewards.topk(1, dim = -1).indices
376 |
377 | best_sequence = sequences[best_sequence_index]
378 | best_sequence = rearrange(best_sequence, '1 ... -> ...')
379 |
380 | return best_sequence
381 |
382 | def learn(
383 | self,
384 | memories: Deque[Memory]
385 | ):
386 | # stack all data stored in the memories
387 |
388 | all_memories_stacked_and_padded = list(map(partial(pad_sequence_fixed, batch_first = True), zip(*memories)))
389 |
390 | # prepare dataloader for policy phase training
391 |
392 | dl = create_dataloader(all_memories_stacked_and_padded, self.minibatch_size, device = self.device)
393 |
394 | self.actor.train()
395 |
396 | # GRPO training
397 |
398 | for _ in range(self.epochs):
399 | for (
400 | sequences,
401 | prompt_masks,
402 | masks,
403 | old_action_probs,
404 | old_log_probs,
405 | rewards,
406 | ) in dl:
407 | action_masks = ~prompt_masks & masks
408 |
409 | action_logits = self.actor(
410 | sequences,
411 | mask = action_masks
412 | )
413 |
414 | action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token
415 | action_len = old_log_probs.shape[-1]
416 |
417 | action_probs = action_logits.softmax(dim = -1)
418 | action_log_probs = get_at('b n [l], b n -> b n', action_probs, sequences)
419 | action_log_probs = action_log_probs[:, -action_len:]
420 |
421 | # calculate entropies, taking into account which part of the sequence is actually an action
422 |
423 | entropies = masked_entropy(action_probs, mask = action_masks)
424 |
425 | # calculate kl div between old action probs and new ones, taking into account which part of the sequence is action or not
426 |
427 | kl_penalty = 0.
428 |
429 | if self.kl_div_loss_weight > 0:
430 | kl_penalty = masked_kl_div(old_action_probs, action_probs, mask = action_masks) * self.kl_div_loss_weight
431 |
432 | # subtract the kl penalty from the rewards
433 |
434 | rewards = rewards - kl_penalty
435 |
436 | # calculate clipped surrogate objective, classic PPO loss
437 |
438 | ratios = (action_log_probs - old_log_probs).exp()
439 |
440 | surr1 = ratios * rewards
441 | surr2 = ratios.clamp(1 - self.eps_clip, 1 + self.eps_clip) * rewards
442 | policy_loss = - torch.min(surr1, surr2) - self.beta_s * entropies
443 |
444 | # combine losses
445 |
446 | loss = policy_loss.mean()
447 |
448 | # update actor
449 |
450 | self.accelerate.backward(loss)
451 |
452 | self.print(f'policy_loss: {loss.item():.3f}')
453 |
454 | if exists(self.max_norm):
455 |                     self.accelerate.clip_grad_norm_(self.actor.parameters(), self.max_norm)
456 |
457 | self.actor_optim.step()
458 | self.actor_optim.zero_grad()
459 |
460 | def train(
461 | self,
462 | num_episodes = 50000,
463 | max_timesteps = 500,
464 | update_timesteps = 5000,
465 | max_batch_size = 16,
466 | max_seq_len = 2048,
467 | eos_token = None,
468 | temperature = 1.
469 | ):
470 | action_sample_times = self.grpo_num_times_sample_rewards
471 |
472 | device = self.device
473 |
474 | time = 0
475 | memories = deque([])
476 |
477 | for eps in tqdm(range(num_episodes), desc = 'episodes'):
478 | for timestep in range(max_timesteps):
479 | time += 1
480 |
481 | # select a bunch of random states (prompts)
482 | # and get the action (sampled sequence from palm as well as the action probs)
483 | # also calculate the reward using reward model and store
484 |
485 | rand_prompt_index = randrange(0, self.num_prompts)
486 |
487 | state = self.prompt_token_ids[rand_prompt_index]
488 |
489 | # remove padding from state
490 |
491 | state_mask = state != self.pad_value
492 | state = state[state_mask]
493 |
494 |                 # sample each state (prompt) multiple times to estimate a baseline, allowing the critic to be removed altogether, as in the GRPO paper (Shao et al.)
495 |
496 | states = repeat(state, 'n -> b n', b = action_sample_times + 1)
497 |
498 | # get predicted sequence
499 |
500 | (
501 | actions,
502 | sequence,
503 | mask,
504 | prompt_mask,
505 | action_logits,
506 | ) = self.actor_generate(
507 | states,
508 | max_seq_len = max_seq_len,
509 | eos_token = eos_token,
510 | temperature = temperature,
511 | )
512 |
513 | action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token
514 |
515 | action_prob = action_logits.softmax(dim = -1)
516 |
517 | action_len = actions.shape[-1]
518 | action_log_prob = get_at('b n [l], b n -> b n', action_prob, sequence)
519 | action_log_prob = action_log_prob[:, -action_len:]
520 |
521 | # get reward as given by supervised trained reward model
522 |
523 | sequence = torch.cat((states, actions), dim = 1)
524 |
525 | prompt_length = states.shape[1]
526 | prompt_mask = torch.arange(sequence.shape[-1], device = device) < prompt_length
527 | prompt_mask = repeat(prompt_mask, 'n -> b n', b = action_sample_times + 1)
528 |
529 | mask = default(mask, lambda: torch.ones(sequence.shape, dtype = torch.bool, device = device))
530 |
531 | rewards = self.reward_model(
532 | sequence,
533 | prompt_mask = prompt_mask,
534 | mask = mask,
535 | sample = True
536 | )
537 |
538 | rewards = rewards.float()
539 |
540 | # rewards are normalized for use as advantages
541 | # following Dr. GRPO paper from Sea AI labs, remove the standard deviation
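                # i.e. group-relative advantage_i = reward_i - mean(rewards) over the group sampled for this prompt,
                # with no division by std(rewards) as in vanilla GRPO (only the constant group-size scaling below)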
542 |
543 | normalized_rewards = (rewards - rewards.mean()) / (action_sample_times + 1)
544 |
545 | # store memory for learning
546 |
547 | detach_to_cpu_ = lambda t: t.detach().cpu()
548 |
549 | memories.extend([Memory(*memories) for memories in zip(*map(detach_to_cpu_, (
550 | sequence,
551 | prompt_mask,
552 | mask,
553 | action_prob,
554 | action_log_prob,
555 | normalized_rewards,
556 | )))])
557 |
558 | # learn from the stored memories
559 |
560 | if divisible_by(time, update_timesteps):
561 | self.learn(memories)
562 | memories.clear()
563 |
564 | print('dr grpo rlhf training complete')
565 |
--------------------------------------------------------------------------------
/palm_rlhf_pytorch/implicit_process_reward.py:
--------------------------------------------------------------------------------
1 | # Free Process Rewards without Process Labels
2 | # Yuan et al. https://arxiv.org/abs/2412.01981 - paper that led to Prime
3 |
4 | from __future__ import annotations
5 | from copy import deepcopy
6 |
7 | import torch
8 | from torch.nn import Module
9 | from torch.nn.functional import logsigmoid
10 |
11 | from einops import rearrange
12 |
13 | # helpers
14 |
15 | def exists(v):
16 | return v is not None
17 |
18 | def get_logprob_at(logits, seq):
19 | log_probs = logits.log_softmax(dim = -1)
20 | seq = rearrange(seq, '... -> ... 1')
21 | log_prob = log_probs.gather(-1, seq)
22 | return rearrange(log_prob, '... 1 -> ...')
23 |
24 | class ImplicitPRM(Module):
25 | """ PRM stands for process reward model, an openai paper that shows that rewarding the steps a model takes to its outcome is better than only rewarding based on final answer or outcome. basically same as when a teacher gives you some credit for showing your steps on an exam """
26 |
27 | def __init__(
28 | self,
29 | model: Module,
30 | ref_model: Module | None = None,
31 | beta = 0.1
32 | ):
33 | super().__init__()
34 | self.model = model
35 |
36 | # only drawback to this technique is needing a reference model
37 |
38 | if not exists(ref_model):
39 | ref_model = deepcopy(model)
40 |
41 | self.ref_model = ref_model
42 | ref_model.requires_grad_(False) # insurance
43 |
44 | self.beta = beta
45 |
46 | def parameters(self):
47 | return self.model.parameters() # only main model is trained
48 |
49 | def forward(
50 | self,
51 | seq,
52 | labels = None
53 | ):
54 | source_seq, target_seq = seq[:, :-1], seq[:, 1:]
55 |
56 | mask = target_seq >= 0 # assume any token ids < 0 to be padding
57 |
58 | model_logits = self.model(source_seq)
59 | ref_model_logits = self.ref_model(source_seq)
60 |
61 | log_prob = get_logprob_at(model_logits, target_seq)
62 | ref_log_prob = get_logprob_at(ref_model_logits, target_seq)
63 |
64 | # main formula is DPO-like, and has some connection with Q-learning https://arxiv.org/abs/2404.12358 . it is all connected
65 |
66 | implicit_rewards = self.beta * (log_prob - ref_log_prob)
67 |
68 | # zero out rewards in padding
69 |
70 | implicit_rewards = implicit_rewards.masked_fill(~mask, 0.)
71 |
72 | # early return if not training, as in Prime with alternating model and prm training
73 |
74 | if not exists(labels):
75 | return implicit_rewards
76 |
77 | labels = rearrange(labels, 'b -> b 1')
78 |
79 | # otherwise use the cross entropy formulation from their paper (eq 5)
80 |
81 |         loss = -(
82 |             labels * logsigmoid(implicit_rewards) +
83 |             (1. - labels) * logsigmoid(-implicit_rewards) # (1. - sigmoid(x)) == sigmoid(-x)
84 |         )
85 | 
86 |         return loss[mask].mean()
87 |
88 | # make it easy for others to copy paste into another project
89 |
90 | if __name__ == '__main__':
91 | from palm_rlhf_pytorch import PaLM
92 |
93 | palm = PaLM(
94 | num_tokens = 256,
95 | dim = 64,
96 | depth = 2
97 | )
98 |
99 | ref_palm = PaLM(
100 | num_tokens = 256,
101 | dim = 64,
102 | depth = 2
103 | )
104 |
105 | implicit_prm = ImplicitPRM(
106 | palm,
107 | ref_model = ref_palm
108 | )
109 |
110 | # mock data
111 |
112 | seq = torch.randint(0, 256, (2, 1024))
113 | labels = torch.randint(0, 2, (2,))
114 |
115 | loss = implicit_prm(seq, labels)
116 | loss.backward()
117 |
118 | # after much training
119 |
120 | implicit_rewards = implicit_prm(seq) # Float[2, 1024]
121 |
122 | # there you go, free process reward model
123 | # now you can use this dense reward for rlhf, beam search whatever
124 |
--------------------------------------------------------------------------------
/palm_rlhf_pytorch/lora.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import Module
4 |
5 | # helper functions
6 |
7 | def exists(val):
8 | return val is not None
9 |
10 | def default(val, d):
11 | return val if exists(val) else d
12 |
13 | # LoRA - https://arxiv.org/abs/2106.09685
14 |
15 | class LoRA(Module):
16 | def __init__(
17 | self,
18 | dim,
19 | dim_out,
20 | r = 8,
21 | alpha = None
22 | ):
23 | super().__init__()
24 | alpha = default(alpha, r)
25 | self.scale = alpha / r
26 |
27 | self.A = nn.Parameter(torch.randn(dim, r))
28 | self.B = nn.Parameter(torch.zeros(r, dim_out))
29 |
30 | @property
31 | def weight(self):
32 | return (self.A @ self.B) * self.scale
33 |
34 | def forward(self, x):
35 | return x @ self.weight
36 |
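# a minimal usage sketch (illustrative, not part of the library API): adapt a frozen nn.Linear with a
# LoRA delta, then merge the low-rank update back into the dense weight, analogous to what
# PaLM.merge_finetune_params does for the fused attention projections. note that B is zero-initialized,
# so the delta starts out as a no-op until A and B are trained

if __name__ == '__main__':
    dim, dim_out = 64, 128

    linear = nn.Linear(dim, dim_out, bias = False)
    linear.requires_grad_(False)            # frozen pretrained weight
    lora = LoRA(dim, dim_out, r = 8)        # only A and B would be trained

    x = torch.randn(2, dim)
    out = linear(x) + lora(x)               # adapted forward pass

    with torch.no_grad():
        # W <- W + ((A @ B) * scale)^T, after which the LoRA module can be discarded
        linear.weight.add_(lora.weight.t())

    assert torch.allclose(out, linear(x), atol = 1e-5)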
--------------------------------------------------------------------------------
/palm_rlhf_pytorch/palm.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import math
4 | import copy
5 | from pathlib import Path
6 | from collections import namedtuple
7 | from functools import wraps
8 | from itertools import zip_longest
9 |
10 | from tqdm import tqdm
11 | from beartype import beartype
12 |
13 | import torch
14 | from torch import einsum, nn
15 | import torch.nn.functional as F
16 | from torch.nn import Module, ModuleList, ModuleDict
17 |
18 | from einops import rearrange, repeat, reduce, pack, unpack
19 | from einops.layers.torch import Rearrange, Reduce
20 |
21 | from palm_rlhf_pytorch.attention import Attention
22 | from palm_rlhf_pytorch.utils import top_p, top_k, masked_mean, gumbel_sample, eval_decorator
23 | from palm_rlhf_pytorch.lora import LoRA
24 |
25 | # functions and decorators
26 |
27 | def exists(val):
28 | return val is not None
29 |
30 | def default(val, d):
31 | return val if exists(val) else d
32 |
33 | def identity(t, *args, **kwargs):
34 | return t
35 |
36 | def l2norm(t):
37 | return F.normalize(t, dim = -1)
38 |
39 | # normalization
40 | # they use layernorm without bias, something that pytorch does not offer
41 |
42 | class LayerNorm(Module):
43 | def __init__(self, dim):
44 | super().__init__()
45 | self.gamma = nn.Parameter(torch.zeros(dim))
46 | self.register_buffer("beta", torch.zeros(dim))
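        # the gain is stored as an offset from 1 (gamma initialized at zero), so weight decay
        # pulls the effective scale toward identity rather than toward zero - see the Rubin reference in the readme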
47 |
48 | def forward(self, x):
49 | return F.layer_norm(x, x.shape[-1:], (self.gamma + 1), self.beta)
50 |
51 | # residual
52 |
53 |
54 | class Residual(Module):
55 | def __init__(self, fn):
56 | super().__init__()
57 | self.fn = fn
58 |
59 | def forward(self, x, **kwargs):
60 | y = self.fn(x, **kwargs)
61 |
62 | if not any([t.requires_grad for t in (x, y)]):
63 | return x.add_(y)
64 |
65 | return y + x
66 |
67 | # rotary positional embedding w/ xpos
68 | # https://arxiv.org/abs/2104.09864
69 | # https://arxiv.org/abs/2212.10554v1
70 |
71 | class RotaryEmbedding(Module):
72 | def __init__(self, dim, scale_base = 512, use_xpos = True):
73 | super().__init__()
74 | inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
75 | self.register_buffer("inv_freq", inv_freq)
76 |
77 | self.use_xpos = use_xpos
78 | self.scale_base = scale_base
79 | scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
80 | self.register_buffer('scale', scale)
81 |
82 | def forward(self, seq_len, device):
83 | t = torch.arange(seq_len, device = device).type_as(self.inv_freq)
84 | freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
85 | freqs = torch.cat((freqs, freqs), dim = -1)
86 |
87 | if not self.use_xpos:
88 | return freqs, torch.ones(1, device = device)
89 |
90 | power = (t - (seq_len // 2)) / self.scale_base
91 | scale = self.scale ** rearrange(power, 'n -> n 1')
92 | scale = torch.cat((scale, scale), dim = -1)
93 |
94 | return freqs, scale
95 |
96 | def rotate_half(x):
97 | x1, x2 = x.chunk(2, dim=-1)
98 | return torch.cat((-x2, x1), dim=-1)
99 |
100 |
101 | def apply_rotary_pos_emb(pos, t, scale = 1.):
102 | return (t * pos.cos() * scale) + (rotate_half(t) * pos.sin() * scale)
103 |
104 |
105 | # classic Noam Shazeer paper, except here they use SwiGLU instead of the more popular GEGLU for gating the feedforward
106 | # https://arxiv.org/abs/2002.05202
107 |
108 |
109 | class SwiGLU(Module):
110 | def forward(self, x):
111 | x, gate = x.chunk(2, dim=-1)
112 | return F.silu(gate) * x
113 |
114 |
115 | # parallel attention and feedforward with residual
116 | # discovered by Wang et al + EleutherAI from GPT-J fame
117 |
118 |
119 | class ParallelTransformerBlock(Module):
120 | def __init__(
121 | self,
122 | dim,
123 | dim_head = 64,
124 | causal = True,
125 | heads = 8,
126 | qk_rmsnorm = False,
127 | qk_scale = 8,
128 | ff_mult = 4,
129 | attn_dropout = 0.,
130 | ff_dropout = 0.,
131 | use_xpos = True,
132 | xpos_scale_base = 512,
133 | flash_attn = False,
134 | ):
135 | super().__init__()
136 | self.norm = LayerNorm(dim)
137 |
138 | attn_inner_dim = dim_head * heads
139 | ff_inner_dim = dim * ff_mult
140 | self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))
141 |
142 | self.qk_rmsnorm = qk_rmsnorm
143 |
144 | if qk_rmsnorm:
145 | self.q_scale = nn.Parameter(torch.ones(dim_head))
146 | self.k_scale = nn.Parameter(torch.ones(dim_head))
147 |
148 | self.attend = Attention(
149 | causal = causal,
150 | dropout = attn_dropout,
151 | use_flash_attn = flash_attn
152 | )
153 |
154 | self.heads = heads
155 | self.scale = (dim_head ** -0.5) if not qk_rmsnorm else qk_scale
156 | self.causal = causal
157 |
158 | self.rotary_emb = RotaryEmbedding(dim_head, scale_base = xpos_scale_base, use_xpos = use_xpos and causal)
159 |
160 | self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
161 |
162 | self.flash_attn = flash_attn
163 | self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)
164 | self.attn_dropout = nn.Dropout(attn_dropout)
165 | self.flash_attn_dropout = attn_dropout
166 |
167 | # parallel feedforward tail
168 |
169 | self.ff_out = nn.Sequential(
170 | SwiGLU(),
171 | nn.Dropout(ff_dropout),
172 | nn.Linear(ff_inner_dim, dim, bias=False)
173 | )
174 |
175 | # for caching causal mask and rotary embeddings
176 |
177 | self.register_buffer("pos_emb", None, persistent=False)
178 | self.register_buffer("pos_emb_scale", None, persistent=False)
179 |
180 | def get_rotary_embedding(self, n, device):
181 | if exists(self.pos_emb) and self.pos_emb.shape[-2] >= n:
182 | return self.pos_emb[:n], self.pos_emb_scale[:n]
183 |
184 | pos_emb, scale = self.rotary_emb(n, device=device)
185 | self.register_buffer("pos_emb", pos_emb, persistent=False)
186 | self.register_buffer("pos_emb_scale", scale, persistent=False)
187 | return pos_emb, scale
188 |
189 | def forward(
190 | self,
191 | x,
192 | mask = None,
193 | finetune_modules = None
194 | ):
195 | """
196 | einstein notation
197 | b - batch
198 | h - heads
199 | n, i, j - sequence length (base sequence length, source, target)
200 | d - feature dimension
201 | """
202 |
203 | n, device, h = x.shape[1], x.device, self.heads
204 |
205 | # pre layernorm
206 |
207 | x = self.norm(x)
208 |
209 | # attention queries, keys, values, and feedforward inner
210 |
211 | q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)
212 |
213 | # finetune loras
214 |
215 | lora_q = lora_k = lora_v = lora_o = None
216 |
217 | if exists(finetune_modules):
218 | lora_q, lora_k, lora_v, lora_o = finetune_modules
219 | q = q + lora_q(x)
220 | k = k + lora_k(x)
221 | v = v + lora_v(x)
222 |
223 | # split heads
224 | # they use multi-query single-key-value attention, yet another Noam Shazeer paper
225 | # they found no performance loss past a certain scale, and more efficient decoding obviously
226 | # https://arxiv.org/abs/1911.02150
227 |
228 | q = rearrange(q, "b n (h d) -> b h n d", h=h)
229 |
230 | # qk rmsnorm
231 |
232 | if self.qk_rmsnorm:
233 | q, k = map(l2norm, (q, k))
234 | q = q * self.q_scale
235 | k = k * self.k_scale
236 |
237 | # rotary embeddings with xpos decay for better length extrapolation
238 |
239 | positions, scale = self.get_rotary_embedding(n, device)
240 |
241 | q = apply_rotary_pos_emb(positions, q, scale)
242 | k = apply_rotary_pos_emb(positions, k, scale ** -1)
243 |
244 | # attention function, either regular or flash
245 |
246 | out = self.attend(q, k, v, mask = mask)
247 |
248 | # merge heads
249 |
250 | out = rearrange(out, "b h n d -> b n (h d)")
251 |
252 | attn_out = self.attn_out(out)
253 |
254 | ff_out = self.ff_out(ff)
255 |
256 | if exists(lora_o):
257 | attn_out = attn_out + lora_o(out)
258 |
259 | return attn_out + ff_out
260 |
261 | # transformer
262 |
263 | class PaLM(Module):
264 | @beartype
265 | def __init__(
266 | self,
267 | *,
268 | dim,
269 | num_tokens,
270 | depth,
271 | causal = True,
272 | dim_head = 64,
273 | heads = 8,
274 | ff_mult = 4,
275 | attn_dropout = 0.,
276 | ff_dropout = 0.,
277 | qk_rmsnorm = False,
278 | lora_r = 8,
279 | rotary_xpos_scale_base = 512,
280 | flash_attn = False,
281 | finetune_scopes = tuple(),
282 | cross_entropy_ignore_index = 0
283 | ):
284 | super().__init__()
285 | self.dim = dim
286 | self.dim_head = dim_head
287 | self.heads = heads
288 | self.causal = causal
289 | self.num_tokens = num_tokens
290 |
291 | self.token_emb = nn.Embedding(num_tokens, dim)
292 | self.layers = ModuleList([])
293 |
294 | for _ in range(depth):
295 | block = Residual(ParallelTransformerBlock(
296 | dim = dim,
297 | causal = causal,
298 | dim_head = dim_head,
299 | heads = heads,
300 | qk_rmsnorm = qk_rmsnorm,
301 | ff_mult = ff_mult,
302 | attn_dropout = attn_dropout,
303 | ff_dropout = ff_dropout,
304 | xpos_scale_base = rotary_xpos_scale_base,
305 | flash_attn = flash_attn
306 | ))
307 |
308 | self.layers.append(block)
309 |
310 | self.norm = LayerNorm(dim)
311 | self.to_logits = nn.Linear(dim, num_tokens, bias=False)
312 |
313 | self.to_logits.weight = self.token_emb.weight
314 |
315 | nn.init.normal_(self.token_emb.weight, std=0.02)
316 |
317 | # fine tuning related
318 |
319 | self.lora_r = lora_r
320 | self.finetune_modules = ModuleDict({})
321 |
322 | for scope in finetune_scopes:
323 | self.add_finetune_params(scope)
324 |
325 | # loss related
326 |
327 | self.cross_entropy_ignore_index = cross_entropy_ignore_index
328 |
329 | @property
330 | def device(self):
331 | return next(self.parameters()).device
332 |
333 | def load(self, path):
334 | path = Path(path)
335 | assert path.exists()
336 | self.load_state_dict(torch.load(str(path)))
337 |
338 | def set_dropout(self, dropout):
339 | for module in self.layers.modules():
340 | if isinstance(module, nn.Dropout):
341 | module.p = dropout
342 | return self
343 |
344 | def add_finetune_params(self, scope, lora_r = None):
345 | assert scope not in self.finetune_modules, f'finetune scope {scope} already found'
346 | dim, dim_head, heads, r, device = self.dim, self.dim_head, self.heads, default(lora_r, self.lora_r), self.device
347 |
348 | q_inner_dim = heads * dim_head
349 | kv_inner_dim = dim_head
350 |
351 | lora_modules = ModuleList([])
352 |
353 | for _ in range(len(self.layers)):
354 | lora_modules.append(ModuleList([
355 | LoRA(dim, q_inner_dim, r = r), # queries
356 | LoRA(dim, kv_inner_dim, r = r), # keys
357 | LoRA(dim, kv_inner_dim, r = r), # values
358 | LoRA(q_inner_dim, dim, r = r) # wo
359 | ]))
360 |
361 | self.finetune_modules[scope] = lora_modules.to(device)
362 |
363 | def remove_finetune_params(self, scope):
364 | assert scope in self.finetune_modules, f'finetune scope {scope} not found'
365 | return self.finetune_modules.pop(scope)
366 |
367 | @torch.no_grad()
368 | def merge_finetune_params(self, scope):
369 | """ in the case one wants to merge the fine-tuned actor LORA parameters and do multiple rounds of fine tuning off different reward models """
370 |
371 | assert scope in self.finetune_modules, f'finetune scope {scope} not found'
372 |
373 | lora_modules = self.finetune_modules.pop(scope)
374 |
375 | for layer, (lora_q, lora_k, lora_v, lora_o) in zip(self.layers, lora_modules):
376 | block = layer.fn
377 |
378 | fused_attn_ff_weight = block.fused_attn_ff_proj.weight
379 | attn_out_weight = block.attn_out.weight
380 |
381 | fused_proj_out_dim = fused_attn_ff_weight.shape[0]
382 |
383 | lora_qkv_weight, _ = pack([lora_q.weight, lora_k.weight, lora_v.weight], 'i *')
384 | lora_qkv_weight = F.pad(lora_qkv_weight, (0, fused_proj_out_dim - lora_qkv_weight.shape[1]))
385 |
386 | lora_qkv_weight = rearrange(lora_qkv_weight, 'i o -> o i')
387 | lora_o_weight = rearrange(lora_o.weight, 'i o -> o i')
388 |
389 | fused_attn_ff_weight.add_(lora_qkv_weight)
390 | attn_out_weight.add_(lora_o_weight)
391 |
392 |     # the researcher trains the palm parameters first,
393 |     # before finetuning
394 |
395 | def palm_parameters(self):
396 | return set(self.parameters()) - set(self.finetune_modules.parameters())
397 |
398 | def finetune_parameters(self, scope = 'default'):
399 | assert scope in self.finetune_modules, f'finetune parameters of scope {scope} not found'
400 | return self.finetune_modules[scope].parameters()
401 |
402 | # generate function
403 |
404 | @torch.no_grad()
405 | @eval_decorator
406 | def generate(
407 | self,
408 | seq_len,
409 | prompt = None,
410 | temperature = 1.,
411 | filter_logits_fn = top_k,
412 | filter_thres = 0.9,
413 | pad_value = 0.,
414 | eos_token = None,
415 | return_seq_without_prompt = True,
416 | use_tqdm = False,
417 | **kwargs
418 | ):
419 | if not exists(prompt):
420 | prompt = torch.randint(0, self.num_tokens, (1, 1))
421 | prompt = prompt.to(self.device)
422 | return_seq_without_prompt = False
423 |
424 | prompt, leading_dims = pack([prompt], '* n')
425 |
426 | n, out = prompt.shape[-1], prompt.clone()
427 |
428 | wrapper_fn = identity if not use_tqdm else tqdm
429 | sample_num_times = max(1, seq_len - prompt.shape[-1])
430 |
431 | for _ in wrapper_fn(range(sample_num_times)):
432 | logits, embeds = self.forward(out, return_logits_with_embedding = True, **kwargs)
433 | logits, embeds = logits[:, -1], embeds[:, -1]
434 |
435 | if exists(filter_logits_fn):
436 | logits = filter_logits_fn(logits, thres = filter_thres)
437 |
438 | sample = gumbel_sample(logits, temperature = temperature, dim = -1)
439 | out, _ = pack([out, sample], 'b *')
440 |
441 | if exists(eos_token):
442 | is_eos_tokens = (out == eos_token)
443 |
444 | if is_eos_tokens.any(dim = -1).all():
445 | # mask out everything after the eos tokens
446 | shifted_is_eos_tokens = F.pad(is_eos_tokens, (1, -1))
447 | mask = shifted_is_eos_tokens.float().cumsum(dim = -1) >= 1
448 | out = out.masked_fill(mask, pad_value)
449 | break
450 |
451 | out, = unpack(out, leading_dims, '* n')
452 |
453 | if not return_seq_without_prompt:
454 | return out
455 |
456 | return out[..., n:]
457 |
458 | def forward(
459 | self,
460 | x,
461 | return_loss = False,
462 | disable_lora = False,
463 | finetune_scope = None,
464 | extra_embed = None,
465 | return_only_embedding = False,
466 | return_logits_with_embedding = False
467 | ):
468 | if return_loss:
469 | x, labels = x[:, :-1], x[:, 1:]
470 |
471 | # mask if encoder
472 | # treat any token ids that are negative as tokens to mask out - only needed if not autoregressive
473 |
474 | if not self.causal:
475 | mask = x >= 0
476 | x = x.masked_fill(~mask, 0)
477 | else:
478 | mask = None
479 |
480 | # get token embedding
481 |
482 | x = self.token_emb(x)
483 |
484 | if exists(extra_embed):
485 | x = x + extra_embed
486 |
487 | # finetune modules
488 |
489 | finetune_modules = tuple()
490 | if exists(finetune_scope) and not disable_lora:
491 | assert finetune_scope in self.finetune_modules
492 | finetune_modules = self.finetune_modules[finetune_scope]
493 |
494 | # parallel attention / ff blocks, passing in finetuning loras
495 |
496 |         for layer, layer_finetune_modules in zip_longest(self.layers, finetune_modules):
497 |             x = layer(x, mask = mask, finetune_modules = layer_finetune_modules)
498 |
499 | # final norm
500 |
501 | embeds = self.norm(x)
502 |
503 | if return_only_embedding:
504 | return embeds
505 |
506 | # to logits
507 |
508 | logits = self.to_logits(embeds)
509 |
510 | ret = (logits, embeds) if return_logits_with_embedding else logits
511 |
512 | if not return_loss:
513 | return ret
514 |
515 | logits = rearrange(logits, 'b n c -> b c n')
516 | return F.cross_entropy(logits, labels, ignore_index = self.cross_entropy_ignore_index)
--------------------------------------------------------------------------------
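
A minimal usage sketch for the PaLM module defined in palm.py above. The hyperparameters, the scope name 'adapter', and the random token ids are illustrative assumptions, not values taken from the repository:

    import torch
    from palm_rlhf_pytorch import PaLM

    palm = PaLM(num_tokens = 256, dim = 512, depth = 8)

    seq = torch.randint(0, 256, (1, 1024))        # (batch, seq_len) token ids
    loss = palm(seq, return_loss = True)          # next-token cross entropy over the shifted sequence
    loss.backward()

    # add LoRA adapters under a named scope, then optimize only those parameters
    palm.add_finetune_params('adapter')
    adapter_params = palm.finetune_parameters('adapter')

    # sample 96 more tokens conditioned on a 32-token prompt, routing through the adapters
    sampled = palm.generate(128, prompt = seq[:, :32], finetune_scope = 'adapter')
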
/palm_rlhf_pytorch/ppo.py:
--------------------------------------------------------------------------------
1 | from __future__ import annotations
2 |
3 | import math
4 | from pathlib import Path
5 | import copy
6 | from accelerate.utils.tqdm import tqdm
7 | from functools import partial
8 | from collections import deque, namedtuple
9 | from random import randrange
10 |
11 | from beartype import beartype
12 | from beartype.typing import Callable, Deque
13 |
14 | import torch
15 | from torch import nn, Tensor
16 | from torch.nn import Module
17 | import torch.nn.functional as F
18 |
19 | from torch.utils.data import Dataset, DataLoader
20 | from torch.nn.utils.rnn import pad_sequence
21 |
22 | from einops import rearrange, repeat, reduce
23 | from einops.layers.torch import Rearrange
24 |
25 | from adam_atan2_pytorch import AdoptAtan2
26 |
27 | from hl_gauss_pytorch import HLGaussLoss
28 |
29 | from palm_rlhf_pytorch.palm import PaLM
30 | from palm_rlhf_pytorch.reward import RewardModel
31 | from palm_rlhf_pytorch.implicit_process_reward import ImplicitPRM
32 | from palm_rlhf_pytorch.utils import masked_mean, eval_decorator
33 | from accelerate import Accelerator
34 |
35 | # actor critic - PaLM with lora
36 |
37 | PPOActionCriticReturn = namedtuple('PPOActionCriticReturn', [
38 | 'actions',
39 | 'sequence',
40 | 'mask',
41 | 'prompt_mask',
42 | 'action_logits',
43 | 'values'
44 | ])
45 |
46 | class ActorCritic(Module):
47 | @beartype
48 | def __init__(
49 | self,
50 | palm: PaLM,
51 | critic: PaLM | ImplicitPRM | None = None,
52 | pooled_values = False,
53 | actor_lora = True,
54 | critic_lora = True,
55 | actor_lora_r = 8,
56 | critic_lora_r = 8,
57 | actor_lora_scope = 'actor',
58 | critic_lora_scope = 'critic',
59 | actor_dropout = 0.,
60 | critic_dropout = 0.,
61 | critic_dim_out = 6 # rewards from 0 to 5
62 | ):
63 | super().__init__()
64 | self.actor_palm = palm
65 |
66 | # detect implicit prm and auto-set some hyperparameters
67 |
68 | critic_is_prm = isinstance(critic, ImplicitPRM)
69 |
70 | critic_lora &= not critic_is_prm
71 | pooled_values |= critic_is_prm
72 |
73 | self.critic_is_prm = critic_is_prm
74 |
75 | # critic
76 |
77 | self.critic = critic
78 |
79 | if not exists(self.critic):
80 | self.critic = copy.deepcopy(palm)
81 |
82 | self.actor_palm.set_dropout(actor_dropout)
83 |
84 | if not critic_is_prm:
85 | self.critic.set_dropout(critic_dropout)
86 |
87 | self.actor_lora = actor_lora
88 | self.critic_lora = critic_lora
89 |
90 | self.actor_lora_scope = actor_lora_scope if actor_lora else None
91 | self.critic_lora_scope = critic_lora_scope if critic_lora else None
92 |
93 | if self.actor_lora:
94 | self.actor_palm.add_finetune_params(actor_lora_scope, lora_r = actor_lora_r)
95 |
96 | if self.critic_lora:
97 | self.critic.add_finetune_params(critic_lora_scope, lora_r = critic_lora_r)
98 |
99 | self.pooled_values = pooled_values
100 | self.value_head = nn.Identity()
101 |
102 | if not critic_is_prm:
103 | assert critic_dim_out > 1
104 | self.value_head = nn.Linear(palm.dim, critic_dim_out)
105 |
106 | nn.init.zeros_(self.value_head.bias)
107 | nn.init.orthogonal_(self.value_head.weight, gain = math.sqrt(2))
108 |
109 | def actor_parameters(self):
110 | if not self.actor_lora:
111 | return self.actor_palm.parameters()
112 |
113 | return [
114 | *self.actor_palm.finetune_parameters(self.actor_lora_scope)
115 | ]
116 |
117 | def critic_parameters(self):
118 | if self.critic_is_prm:
119 | return self.critic.parameters()
120 |
121 |         if not self.critic_lora:
122 | return [*self.critic.parameters(), *self.value_head.parameters()]
123 |
124 | return [
125 | *self.critic.finetune_parameters(self.critic_lora_scope),
126 | *self.value_head.parameters()
127 | ]
128 |
129 | @torch.no_grad()
130 | @eval_decorator
131 | def generate(
132 | self,
133 | state,
134 | max_seq_len,
135 | eos_token = None,
136 | return_values = False,
137 | **kwargs
138 | ):
139 | actions = self.actor_palm.generate(
140 | max_seq_len,
141 | prompt = state,
142 | eos_token = eos_token,
143 | finetune_scope = self.actor_lora_scope,
144 | use_tqdm = True,
145 | **kwargs
146 | )
147 |
148 | sequence = torch.cat((state, actions), dim = -1)
149 | action_len = actions.shape[-1]
150 | state_len = state.shape[-1]
151 |
152 | prompt_mask = torch.arange(sequence.shape[-1], device = state.device) < state_len
153 | prompt_mask = repeat(prompt_mask, 'n -> b n', b = sequence.shape[0])
154 |
155 | action_mask = ~prompt_mask
156 |
157 | mask = None
158 | if exists(eos_token):
159 | mask = ((sequence == eos_token).cumsum(dim = -1) == 0)
160 | mask = F.pad(mask, (1, -1), value = True) # include eos token
161 | action_mask &= mask
162 |
163 | action_logits, value = self.forward(
164 | sequence,
165 | mask = action_mask,
166 | return_values = return_values
167 | )
168 |
169 | return PPOActionCriticReturn(
170 | actions,
171 | sequence,
172 | mask,
173 | prompt_mask,
174 | action_logits,
175 | value
176 | )
177 |
178 | def forward(
179 | self,
180 | x,
181 | mask = None,
182 | return_values = True
183 | ):
184 | action_logits = self.actor_palm(
185 | x,
186 | finetune_scope = self.actor_lora_scope
187 | )
188 |
189 | if not return_values:
190 | return action_logits, None
191 |
192 | if self.critic_is_prm:
193 | values = self.critic(x)
194 | return action_logits, values
195 |
196 | critic_embeds = self.critic(
197 | x,
198 | return_only_embedding = True,
199 | finetune_scope = self.critic_lora_scope
200 | )
201 |
202 | if self.pooled_values:
203 | critic_embeds = shift(critic_embeds, shift = 1, dim = -2)
204 | critic_embeds = masked_mean(critic_embeds, mask, dim = 1)
205 |
206 | values = self.value_head(critic_embeds)
207 |
208 | return action_logits, values
209 |
210 | # data
211 |
212 | Memory = namedtuple('Memory', [
213 | 'sequence',
214 | 'prompt_mask',
215 | 'mask',
216 | 'action_prob',
217 | 'action_log_prob',
218 | 'reward',
219 | 'value'
220 | ])
221 |
222 | class ExperienceDataset(Dataset):
223 | @beartype
224 | def __init__(
225 | self,
226 | data,
227 | device = None
228 | ):
229 | super().__init__()
230 | self.data = data
231 | self.device = device
232 |
233 | def __len__(self):
234 | return self.data[0].shape[0]
235 |
236 | def __getitem__(self, ind):
237 | return tuple(map(lambda t: t[ind].to(self.device), self.data))
238 |
239 | def create_dataloader(data, batch_size, shuffle = True, device = None, **kwargs):
240 | ds = ExperienceDataset(data, device = device)
241 | return DataLoader(ds, batch_size = batch_size, shuffle = shuffle, **kwargs)
242 |
243 | # helper functions
244 |
245 | def exists(val):
246 | return val is not None
247 |
248 | def default(val, d):
249 | if exists(val):
250 | return val
251 | return d() if callable(d) else d
252 |
253 | def masked_normalize(t, eps = 1e-5, mask = None, dim = None):
254 | dim = default(dim, tuple(range(t.ndim)))
255 | kwargs = dict(dim = dim, keepdim = True)
256 |
257 | mean = masked_mean(t, mask = mask, **kwargs)
258 | mean_centered = t - mean
259 | var = masked_mean(mean_centered ** 2, mask = mask, **kwargs)
260 |
261 | return mean_centered * var.clamp(min = eps).rsqrt()
262 |
263 | def pad_sequence_fixed(sequences, *args, **kwargs):
264 | first_el = sequences[0]
265 | has_no_dimension = first_el.ndim == 0
266 |
267 | # if no dimensions, add a single dimension
268 | if has_no_dimension:
269 | sequences = tuple(map(lambda t: t[None], sequences))
270 |
271 | out = pad_sequence(sequences, *args, **kwargs)
272 |
273 | if has_no_dimension:
274 | out = rearrange(out, '... 1 -> ...')
275 |
276 | return out
277 |
278 | def log(t, eps = 1e-20):
279 | return torch.log(t.clamp(min = eps))
280 |
281 | def log_prob(prob, indices):
282 | assert prob.shape[:2] == indices.shape, f'preceding shapes of prob {prob.shape[:2]} and indices {indices.shape} must match'
283 | return log(prob.gather(-1, indices[..., None])).squeeze(-1)
284 |
285 | def shift(t, value = 0, shift = 1, dim = -1):
286 | zeros = (0, 0) * (-dim - 1)
287 | return F.pad(t, (*zeros, shift, -shift), value = value)
288 |
289 | def masked_entropy(prob, dim = -1, mask = None):
290 |     entropies = -(prob * log(prob)).sum(dim = dim) # entropy is -sum(p * log(p)); used as a bonus in the policy loss
291 | return masked_mean(entropies, mask = mask).mean()
292 |
293 | def masked_kl_div(prob1, prob2, mask = None, reduce_batch = False):
294 | """
295 | need to account for variable sequence lengths, therefore not using the built-in functional version
296 | """
297 | kl_divs = (prob1 * (log(prob1) - log(prob2))).sum(dim = -1)
298 | loss = masked_mean(kl_divs, mask)
299 |
300 | if reduce_batch:
301 | return loss.mean()
302 |
303 | return loss
304 |
305 | # rlhf trainer
306 |
307 | class RLHFTrainer(Module):
308 | @beartype
309 | def __init__(
310 | self,
311 | *,
312 | prompts: list[str] | None = None,
313 | prompts_path: str | None = None,
314 | prompt_token_ids: Tensor | None = None,
315 | tokenizer: Callable | None = None,
316 | palm: PaLM,
317 | reward_model: RewardModel,
318 | critic: PaLM | ImplicitPRM | None = None,
319 | actor_critic: ActorCritic | None = None,
320 | actor_lr = 1e-4,
321 | critic_lr = 1e-4,
322 | actor_wd = 0.,
323 | critic_wd = 0.,
324 | actor_lora = True,
325 | critic_lora = True,
326 | actor_lora_r = 8,
327 | critic_lora_r = 8,
328 | critic_pooled_values = True,
329 | actor_dropout = 0.,
330 | critic_dropout = 0.,
331 | betas = (0.9, 0.999),
332 | max_norm = None,
333 | eps_clip = 0.2,
334 | value_clip = 0.4,
335 | beta_s = .01,
336 | pad_value = 0.,
337 | minibatch_size = 16,
338 | epochs = 1,
339 | kl_div_loss_weight = 0.1, # between old action probs and new action probs - not sure what the right value is
340 | accelerate_kwargs: dict = dict(),
341 | critic_num_pred_bins = 6,
342 | hl_gauss_loss_kwargs: dict = dict(
343 | min_value = 0.,
344 | max_value = 5.,
345 | min_max_value_on_bin_center = True,
346 | clamp_to_range = True,
347 | sigma_to_bin_ratio = 1.
348 | )
349 | ):
350 | super().__init__()
351 |
352 | self.accelerate = Accelerator(**accelerate_kwargs)
353 |
354 | # take care of prompts -> token ids
355 |
356 | assert (exists(prompts) + exists(prompts_path) + exists(prompt_token_ids)) == 1
357 |
358 | if exists(prompts_path):
359 | path = Path(prompts_path)
360 | prompts = path.read_text().split('\n')
361 |
362 | if exists(prompts):
363 | assert len(prompts) > 0, 'no prompts'
364 | assert exists(tokenizer), 'tokenizer must be passed in if raw text prompts are given'
365 | prompt_token_ids = tokenizer(prompts)
366 |
367 | self.pad_value = pad_value # token pad value
368 | self.num_prompts = prompt_token_ids.shape[0]
369 | self.register_buffer('prompt_token_ids', prompt_token_ids)
370 |
371 | # models
372 |
373 | self.palm = palm
374 |
375 | if not exists(actor_critic):
376 | actor_critic = ActorCritic(
377 | palm = palm,
378 | critic = critic,
379 | actor_lora = actor_lora,
380 | critic_lora = critic_lora,
381 | actor_lora_r = actor_lora_r,
382 | critic_lora_r = critic_lora_r,
383 | pooled_values = critic_pooled_values,
384 | actor_dropout = actor_dropout,
385 | critic_dropout = critic_dropout,
386 | critic_dim_out = critic_num_pred_bins
387 | ).to(palm.device)
388 |
389 | self.actor_critic = actor_critic
390 |
391 | self.actor_critic_generate = actor_critic.generate
392 |
393 | self.reward_model = reward_model.eval()
394 |
395 | # critic outputs reward bin prediction
396 | # for classification loss, buying into "Stop Regressing" paper from Farebrother et al. https://arxiv.org/abs/2403.03950
397 |
398 | self.critic_hl_gauss_loss = HLGaussLoss(num_bins = critic_num_pred_bins, **hl_gauss_loss_kwargs)
399 |
400 | # train hyperparameters
401 |
402 | self.epochs = epochs
403 | self.minibatch_size = minibatch_size
404 | self.max_norm = max_norm
405 |
406 | self.kl_div_loss_weight = kl_div_loss_weight
407 |
408 | # optimizers
409 |
410 | self.actor_optim = AdoptAtan2(actor_critic.actor_parameters(), lr = actor_lr, weight_decay = actor_wd, betas = betas)
411 | self.critic_optim = AdoptAtan2(actor_critic.critic_parameters(), lr = critic_lr, weight_decay = critic_wd, betas = betas)
412 |
413 | # ppo hyperparams
414 |
415 | self.eps_clip = eps_clip
416 | self.value_clip = value_clip
417 | self.beta_s = beta_s
418 |
419 | # prepare with accelerator
420 |
421 | (
422 | self.actor_critic,
423 | self.reward_model,
424 | self.actor_optim,
425 | self.critic_optim
426 | ) = self.accelerate.prepare(
427 | self.actor_critic,
428 | self.reward_model,
429 | self.actor_optim,
430 | self.critic_optim
431 | )
432 |
433 |
434 | def print(self, msg):
435 | return self.accelerate.print(msg)
436 |
437 | def save(self, filepath = './checkpoint.pt'):
438 | torch.save(self.actor_critic.state_dict(), filepath)
439 |
440 | def load(self, filepath = './checkpoint.pt'):
441 | state_dict = torch.load(filepath)
442 | self.actor_critic.load_state_dict(state_dict)
443 |
444 | @property
445 | def device(self):
446 | return self.accelerate.device
447 |
448 | @torch.no_grad()
449 | def generate(
450 | self,
451 | max_seq_len,
452 | *args,
453 | prompt,
454 | num_samples = 4, # sample 4 per prompt and select the one with highest reward
455 | **kwargs
456 | ):
457 | assert prompt.ndim == 1, 'only one prompt allowed at a time for now'
458 | prompt = repeat(prompt, 'n -> b n', b = num_samples)
459 |
460 | actor_critic = self.accelerate.unwrap_model(self.actor_critic)
461 | reward_model = self.accelerate.unwrap_model(self.reward_model)
462 |
463 | actor_critic.eval()
464 |
465 | (
466 | actions,
467 | sequences,
468 | mask,
469 | prompt_mask,
470 | action_logits,
471 | _
472 | ) = actor_critic.generate(
473 | prompt,
474 | *args,
475 | max_seq_len = max_seq_len,
476 | return_values = False,
477 | **kwargs
478 | )
479 |
480 | rewards = reward_model(
481 | sequences,
482 | prompt_mask = prompt_mask,
483 | mask = mask
484 | )
485 |
486 | best_sequence_index = rewards.topk(1, dim = -1).indices
487 |
488 | best_sequence = sequences[best_sequence_index]
489 | best_sequence = rearrange(best_sequence, '1 ... -> ...')
490 |
491 | return best_sequence
492 |
493 | def learn(
494 | self,
495 | memories: Deque[Memory]
496 | ):
497 | # stack all data stored in the memories
498 |
499 | all_memories_stacked_and_padded = list(map(partial(pad_sequence_fixed, batch_first = True), zip(*memories)))
500 |
501 | # prepare dataloader for policy phase training
502 |
503 | dl = create_dataloader(all_memories_stacked_and_padded, self.minibatch_size, device = self.device)
504 |
505 | self.actor_critic.train()
506 |
507 | # PPO training
508 |
509 | for _ in range(self.epochs):
510 |
511 | for (
512 | sequences,
513 | prompt_masks,
514 | masks,
515 | old_action_probs,
516 | old_log_probs,
517 | rewards,
518 | old_values_bins
519 | ) in dl:
520 |
521 | action_masks = ~prompt_masks & masks
522 |
523 | action_logits, values_bins = self.actor_critic(
524 | sequences,
525 | mask = action_masks
526 | )
527 |
528 | action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token
529 | action_len = old_log_probs.shape[-1]
530 |
531 | action_probs = action_logits.softmax(dim = -1)
532 | action_log_probs = log_prob(action_probs, sequences)
533 | action_log_probs = action_log_probs[:, -action_len:]
534 |
535 | # calculate entropies, taking into account which part of the sequence is actually an action
536 |
537 | entropies = masked_entropy(action_probs, mask = action_masks)
538 |
539 | # calculate kl div between old action probs and new ones, taking into account which part of the sequence is action or not
540 |
541 | kl_penalty = 0.
542 |
543 | if self.kl_div_loss_weight > 0:
544 | kl_penalty = masked_kl_div(old_action_probs, action_probs, mask = action_masks) * self.kl_div_loss_weight
545 |
546 | # subtract the kl penalty from the rewards
547 |
548 | rewards = rewards - kl_penalty
549 |
550 | # convert binned value predictions to scalar value
551 |
552 | to_pred_value = self.critic_hl_gauss_loss.transform_from_logits
553 |
554 | old_values, values = map(to_pred_value, (old_values_bins, values_bins))
555 |
556 | # handle non-pooled values
557 |
558 | normalize_kwargs = dict()
559 |
560 | if old_values.ndim == 2:
561 | old_values, values = map(lambda t: shift(t, shift = 1, dim = -2), (old_values, values))
562 |
563 | old_values = old_values[:, -action_len:]
564 | values = values[:, -action_len:]
565 | rewards = rearrange(rewards, 'b -> b 1')
566 | normalize_kwargs = dict(dim = -1, mask = action_masks[:, -action_len:])
567 |
568 | if values.ndim < rewards.ndim:
569 | values = rearrange(values, '... -> ... 1')
570 |
571 | # calculate clipped surrogate objective, classic PPO loss
572 |
573 | ratios = (action_log_probs - old_log_probs).exp()
574 | advantages = masked_normalize(rewards - old_values, **normalize_kwargs)
575 |
576 | if advantages.ndim == 1:
577 | advantages = rearrange(advantages, 'b -> b 1')
578 |
579 | surr1 = ratios * advantages
580 | surr2 = ratios.clamp(1 - self.eps_clip, 1 + self.eps_clip) * advantages
581 | policy_loss = - torch.min(surr1, surr2) - self.beta_s * entropies
582 |
583 | # combine losses
584 |
585 | loss = policy_loss.mean()
586 |
587 | # update actor
588 |
589 | self.accelerate.backward(loss)
590 |
591 | self.print(f'policy_loss: {loss.item():.3f}')
592 |
593 | if exists(self.max_norm):
594 |                     self.accelerate.clip_grad_norm_(self.actor_critic.actor_parameters(), self.max_norm)
595 |
596 | self.actor_optim.step()
597 | self.actor_optim.zero_grad()
598 |
599 |                 # calculate clipped value loss and update value network separately from policy network
600 |
601 | value_clipped = old_values + (values - old_values).clamp(-self.value_clip, self.value_clip)
602 |
603 | rewards.detach_()
604 |
605 | value_loss_1 = self.critic_hl_gauss_loss(value_clipped, rewards, reduction = 'none')
606 | value_loss_2 = self.critic_hl_gauss_loss(values_bins, rewards, reduction = 'none')
607 |
608 | value_loss = torch.maximum(value_loss_1, value_loss_2).mean()
609 |
610 | self.print(f'critic_loss: {value_loss.item():.3f}')
611 |
612 | self.accelerate.backward(value_loss)
613 |
614 | if exists(self.max_norm):
615 |                     self.accelerate.clip_grad_norm_(self.actor_critic.critic_parameters(), self.max_norm)
616 |
617 | self.critic_optim.step()
618 | self.critic_optim.zero_grad()
619 |
620 | def train(
621 | self,
622 | num_episodes = 50000,
623 | max_timesteps = 500,
624 | update_timesteps = 5000,
625 | max_batch_size = 16,
626 | max_seq_len = 2048,
627 | eos_token = None,
628 | temperature = 1.
629 | ):
630 | device = self.device
631 |
632 | time = 0
633 | memories = deque([])
634 |
635 | for eps in tqdm(range(num_episodes), desc = 'episodes'):
636 | for timestep in range(max_timesteps):
637 | time += 1
638 |
639 | # select a bunch of random states (prompts)
640 | # and get the action (sampled sequence from palm as well as the action probs)
641 | # also calculate the reward using reward model and store
642 |
643 | rand_prompt_index = randrange(0, self.num_prompts)
644 |
645 | state = self.prompt_token_ids[rand_prompt_index]
646 |
647 | # remove padding from state
648 |
649 | state_mask = state != self.pad_value
650 | state = state[state_mask]
651 |
652 | # get predicted sequence
653 |
654 | (
655 | actions,
656 | sequence,
657 | mask,
658 | prompt_mask,
659 | action_logits,
660 | values_bins
661 | ) = self.actor_critic_generate(
662 | rearrange(state, 'n ... -> 1 n ...'),
663 | max_seq_len = max_seq_len,
664 | eos_token = eos_token,
665 | temperature = temperature,
666 | return_values = True
667 | )
668 |
669 | action_logits = shift(action_logits, shift = 1, dim = -2) # need to shift along sequence dimension by 1, since actions start from the last prompt (state) token
670 |
671 | action_prob = action_logits.softmax(dim = -1)
672 |
673 | action_len = actions.shape[-1]
674 | action_log_prob = log_prob(action_prob, sequence)
675 | action_log_prob = action_log_prob[:, -action_len:]
676 |
677 | actions = rearrange(actions, '1 ... -> ...')
678 |
679 | # get reward as given by supervised trained reward model
680 |
681 | sequence = torch.cat((state, actions), dim = 0)
682 |
683 | prompt_length = len(state)
684 | prompt_mask = torch.arange(sequence.shape[-1], device = device) < prompt_length
685 |
686 | sequence = rearrange(sequence, 'n -> 1 n')
687 | prompt_mask = rearrange(prompt_mask, 'n -> 1 n')
688 | mask = default(mask, lambda: torch.ones(sequence.shape, dtype = torch.bool, device = device))
689 |
690 | reward = self.reward_model(
691 | sequence,
692 | prompt_mask = prompt_mask,
693 | mask = mask
694 | )
695 |
696 | detach_to_cpu_ = lambda t: rearrange(t.detach().cpu(), '1 ... -> ...')
697 |
698 | # store memory for learning
699 |
700 | memories.append(Memory(*map(detach_to_cpu_, (
701 | sequence,
702 | prompt_mask,
703 | mask,
704 | action_prob,
705 | action_log_prob,
706 | reward,
707 | values_bins
708 | ))))
709 |
710 | # learn from the stored memories
711 |
712 | if time % update_timesteps == 0:
713 | self.learn(memories)
714 | memories.clear()
715 |
716 | print('rlhf training complete')
717 |
--------------------------------------------------------------------------------
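
A rough end-to-end sketch of how the RLHFTrainer above could be wired together. In practice palm would already be pretrained and reward_model fit on human feedback; the prompt tensor, episode counts, and bin setting below are placeholders only:

    import torch
    from palm_rlhf_pytorch.palm import PaLM
    from palm_rlhf_pytorch.reward import RewardModel
    from palm_rlhf_pytorch.ppo import RLHFTrainer

    palm = PaLM(num_tokens = 256, dim = 512, depth = 8)
    reward_model = RewardModel(palm, num_binned_output = 6)   # binned rewards 0..5, matching the critic defaults

    prompt_token_ids = torch.randint(0, 256, (32, 128))       # (num_prompts, prompt_len) placeholder prompts

    trainer = RLHFTrainer(
        palm = palm,
        reward_model = reward_model,
        prompt_token_ids = prompt_token_ids
    )

    # roll out sequences, score them with the reward model, and run PPO updates
    trainer.train(num_episodes = 10, max_timesteps = 8, update_timesteps = 8, max_seq_len = 256)

    # after training, sample several continuations for one prompt and keep the highest-reward one
    answer = trainer.generate(256, prompt = prompt_token_ids[0], num_samples = 4)
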
/palm_rlhf_pytorch/reward.py:
--------------------------------------------------------------------------------
1 | import copy
2 | from pathlib import Path
3 |
4 | from tqdm import tqdm
5 | from beartype import beartype
6 |
7 | import torch
8 | from torch import nn
9 | from torch.nn import Module
10 | import torch.nn.functional as F
11 |
12 | from einops import rearrange, repeat, reduce, pack, unpack
13 | from einops.layers.torch import Rearrange, Reduce
14 |
15 | from palm_rlhf_pytorch.utils import masked_mean, gumbel_sample
16 | from palm_rlhf_pytorch.palm import PaLM
17 |
18 | # helper functions
19 |
20 | def exists(val):
21 | return val is not None
22 |
23 | def default(val, default_val):
24 | return val if exists(val) else default_val
25 |
26 | # Reward Model - PaLM with a scalar head
27 |
28 | class RewardModel(Module):
29 | @beartype
30 | def __init__(
31 | self,
32 | palm: PaLM,
33 | dropout = 0.1,
34 |         num_binned_output = 0,
35 | use_lora = True,
36 | lora_r = 8,
37 | reward_lora_scope = 'reward',
38 | sample_from_bins = None,
39 | sample_temperature = 1.
40 | ):
41 | super().__init__()
42 |
43 | self.palm = copy.deepcopy(palm)
44 | self.palm.set_dropout(dropout)
45 |
46 | self.reward_lora_scope = reward_lora_scope if use_lora else None
47 |
48 | if exists(self.reward_lora_scope):
49 | self.palm.add_finetune_params(reward_lora_scope, lora_r = lora_r)
50 |
51 | dim = palm.dim
52 |
53 | self.binned_output = num_binned_output > 1
54 |
55 | self.prompt_embed = nn.Parameter(torch.zeros(1, 1, dim))
56 | self.response_embed = nn.Parameter(torch.zeros(1, 1, dim))
57 |
58 | if self.binned_output:
59 | self.to_pred = nn.Linear(dim, num_binned_output)
60 | else:
61 | self.to_pred = nn.Sequential(
62 | nn.Linear(dim, 1, bias = False),
63 | Rearrange('... 1 -> ...')
64 | )
65 |
66 | self.sample_from_bins = default(sample_from_bins, self.binned_output)
67 | self.sample_temperature = sample_temperature
68 |
69 | def load(self, path):
70 | path = Path(path)
71 | assert path.exists()
72 | self.load_state_dict(torch.load(str(path)))
73 |
74 | def finetune_parameters(self):
75 | return [
76 | *self.to_pred.parameters(),
77 | *(self.palm.finetune_parameters(self.reward_lora_scope) if exists(self.reward_lora_scope) else self.palm.parameters())
78 | ]
79 |
80 | def forward(
81 | self,
82 | x,
83 | mask = None,
84 | prompt_mask = None,
85 | prompt_lengths = None,
86 | labels = None,
87 | disable_lora = False
88 | ):
89 |
90 | assert not (exists(prompt_mask) and exists(prompt_lengths))
91 |
92 | # derive prompt mask from prompt lengths
93 |
94 | if exists(prompt_lengths):
95 | batch, seq_len = x.shape
96 | arange = torch.arange(seq_len, device = x.device)
97 | prompt_mask = repeat(arange, 'n -> b n', b = batch) < rearrange(prompt_lengths, 'b -> b 1')
98 |
99 | # reward model should have an understanding of which section is prompt, and which section is response
100 |
101 | extra_embed = None
102 |
103 | if exists(prompt_mask):
104 | extra_embed = torch.where(
105 | rearrange(prompt_mask, 'b n -> b n 1'),
106 | self.prompt_embed,
107 | self.response_embed
108 | )
109 |
110 | # get embeddings from palm
111 |
112 | embeds = self.palm(
113 | x,
114 | extra_embed = extra_embed,
115 | return_only_embedding = True,
116 | disable_lora = disable_lora,
117 | finetune_scope = self.reward_lora_scope
118 | )
119 |
120 | pooled = masked_mean(embeds, mask, dim = 1)
121 | pred = self.to_pred(pooled)
122 |
123 | if self.sample_from_bins and self.binned_output:
124 | assert not exists(labels)
125 | pred = gumbel_sample(pred, temperature = self.sample_temperature, dim = -1)
126 |
127 | if not exists(labels):
128 | return pred
129 |
130 | if not self.binned_output:
131 | return F.mse_loss(pred, labels)
132 |
133 | return F.cross_entropy(pred, labels)
134 |
--------------------------------------------------------------------------------
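
A brief sketch of fitting the RewardModel above on human-scored sequences with the default scalar head; the token ids, prompt lengths, and scores below are random placeholders:

    import torch
    from palm_rlhf_pytorch.palm import PaLM
    from palm_rlhf_pytorch.reward import RewardModel

    palm = PaLM(num_tokens = 256, dim = 512, depth = 8)
    reward_model = RewardModel(palm)                      # scalar reward head, LoRA finetuning enabled by default

    seq = torch.randint(0, 256, (4, 512))                 # prompt + response token ids
    prompt_lengths = torch.randint(1, 128, (4,))          # how many leading tokens of each sequence are prompt
    scores = torch.rand(4)                                # human-assigned rewards (random here)

    optim = torch.optim.Adam(reward_model.finetune_parameters(), lr = 1e-4)

    loss = reward_model(seq, prompt_lengths = prompt_lengths, labels = scores)   # mse loss for the scalar head
    loss.backward()
    optim.step()
    optim.zero_grad()

    # without labels, the forward pass returns the predicted reward per sequence
    rewards = reward_model(seq, prompt_lengths = prompt_lengths)
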
/palm_rlhf_pytorch/utils.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import einsum, nn
4 | import torch.nn.functional as F
5 |
6 | from einops import rearrange
7 |
8 | def exists(val):
9 | return val is not None
10 |
11 | # decorators
12 |
13 | def eval_decorator(fn):
14 | def inner(self, *args, **kwargs):
15 | was_training = self.training
16 | self.eval()
17 | out = fn(self, *args, **kwargs)
18 | self.train(was_training)
19 | return out
20 | return inner
21 |
22 | # tensor helpers
23 |
24 | def log(t, eps = 1e-20):
25 | return torch.log(t.clamp(min = eps))
26 |
27 | def masked_mean(seq, mask = None, dim = 1, keepdim = False):
28 | if not exists(mask):
29 | return seq.mean(dim = dim)
30 |
31 | if seq.ndim == 3:
32 | mask = rearrange(mask, 'b n -> b n 1')
33 |
34 | masked_seq = seq.masked_fill(~mask, 0.)
35 | numer = masked_seq.sum(dim = dim, keepdim = keepdim)
36 | denom = mask.sum(dim = dim, keepdim = keepdim)
37 |
38 | masked_mean = numer / denom.clamp(min = 1e-3)
39 | masked_mean = masked_mean.masked_fill(denom == 0, 0.)
40 | return masked_mean
41 |
42 | # sampling helpers
43 |
44 | def gumbel_noise(t):
45 | noise = torch.zeros_like(t).uniform_(0, 1)
46 | return -log(-log(noise))
47 |
48 | def gumbel_sample(t, temperature = 1., dim = -1):
49 | return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)
50 |
51 | def top_p(logits, thres = 0.9):
52 | sorted_logits, sorted_indices = torch.sort(logits, descending=True)
53 | cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
54 |
55 | sorted_indices_to_remove = cum_probs > (1 - thres)
56 | sorted_indices_to_remove[:, 1:] = sorted_indices_to_remove[:, :-1].clone()
57 | sorted_indices_to_remove[:, 0] = 0
58 |
59 | sorted_logits[sorted_indices_to_remove] = float('-inf')
60 | return sorted_logits.scatter(1, sorted_indices, sorted_logits)
61 |
62 | def top_k(logits, thres = 0.9):
63 | k = math.ceil((1 - thres) * logits.shape[-1])
64 | val, ind = torch.topk(logits, k)
65 | probs = torch.full_like(logits, float('-inf'))
66 | probs.scatter_(1, ind, val)
67 | return probs
68 |
--------------------------------------------------------------------------------
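
A toy demonstration of how the sampling helpers above compose, using made-up logits:

    import torch
    from palm_rlhf_pytorch.utils import top_k, gumbel_sample

    logits = torch.randn(2, 256)                              # (batch, vocab)

    filtered = top_k(logits, thres = 0.9)                     # keep roughly the top 10% of logits, set the rest to -inf
    tokens = gumbel_sample(filtered, temperature = 1.)        # gumbel noise + argmax == sampling from the softmax
    print(tokens.shape)                                       # torch.Size([2])
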
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 | setup(
4 | name = 'PaLM-rlhf-pytorch',
5 | packages = find_packages(exclude=[]),
6 | version = '0.5.3',
7 | license='MIT',
8 | description = 'PaLM + Reinforcement Learning with Human Feedback - Pytorch',
9 | author = 'Phil Wang',
10 | author_email = 'lucidrains@gmail.com',
11 | long_description_content_type = 'text/markdown',
12 | url = 'https://github.com/lucidrains/PaLM-rlhf-pytorch',
13 | keywords = [
14 | 'artificial intelligence',
15 | 'deep learning',
16 | 'transformers',
17 | 'attention mechanism',
18 | 'reinforcement learning',
19 | 'human feedback'
20 | ],
21 | install_requires=[
22 | 'accelerate',
23 | 'adam-atan2-pytorch',
24 | 'beartype',
25 | 'einx>=0.3.0',
26 | 'einops>=0.8.0',
27 | 'hl-gauss-pytorch>=0.1.19',
28 | 'torch>=2.2',
29 | 'tqdm'
30 | ],
31 | classifiers=[
32 | 'Development Status :: 4 - Beta',
33 | 'Intended Audience :: Developers',
34 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
35 | 'License :: OSI Approved :: MIT License',
36 | 'Programming Language :: Python :: 3.6',
37 | ],
38 | )
39 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import gzip
2 | import random
3 | from accelerate.utils.tqdm import tqdm
4 | import numpy as np
5 |
6 | import torch
7 | from lion_pytorch import Lion
8 | from torch.nn import functional as F
9 | from torch.utils.data import DataLoader, Dataset
10 |
11 | from palm_rlhf_pytorch import PaLM
12 | from accelerate import Accelerator
13 |
14 | # constants
15 |
16 | NUM_BATCHES = int(1e5)
17 | BATCH_SIZE = 4
18 | GRADIENT_ACCUMULATE_EVERY = 4
19 | LEARNING_RATE = 1e-4
20 | VALIDATE_EVERY = 100
21 | PRIME_LENGTH = 128
22 | GENERATE_EVERY = 500
23 | GENERATE_LENGTH = 512
24 | SEQ_LEN = 1024
25 |
26 | # helpers
27 |
28 | def cycle(loader):
29 | while True:
30 | for data in loader:
31 | yield data
32 |
33 | def decode_token(token):
34 | return str(chr(max(32, token)))
35 |
36 | def decode_tokens(tokens):
37 | return "".join(list(map(decode_token, tokens)))
38 |
39 |
40 | # accelerator
41 |
42 | accelerator = Accelerator(gradient_accumulation_steps=GRADIENT_ACCUMULATE_EVERY)
43 | device = accelerator.device
44 |
45 | # instantiate palm
46 |
47 | model = PaLM(
48 | num_tokens=256,
49 | dim=512,
50 | depth=8,
51 | flash_attn=True
52 | ).to(device)
53 |
54 | # prepare enwik8 data
55 |
56 | with gzip.open("./data/enwik8.gz") as file:
57 | data = np.frombuffer(file.read(int(95e6)), dtype=np.uint8).copy()
58 | np_train, np_valid = np.split(data, [int(90e6)])
59 | data_train, data_val = torch.from_numpy(np_train), torch.from_numpy(np_valid)
60 |
61 | class TextSamplerDataset(Dataset):
62 | def __init__(self, data, seq_len):
63 | super().__init__()
64 | self.data = data
65 | self.seq_len = seq_len
66 |
67 | def __getitem__(self, index):
68 | rand_start = torch.randint(0, self.data.size(0) - self.seq_len, (1,))
69 | full_seq = self.data[rand_start : rand_start + self.seq_len + 1].long()
70 | return full_seq.to(device)
71 |
72 | def __len__(self):
73 | return self.data.size(0) // self.seq_len
74 |
75 | train_dataset = TextSamplerDataset(data_train, SEQ_LEN)
76 | val_dataset = TextSamplerDataset(data_val, SEQ_LEN)
77 | train_loader = cycle(DataLoader(train_dataset, batch_size=BATCH_SIZE))
78 | val_loader = cycle(DataLoader(val_dataset, batch_size=BATCH_SIZE))
79 |
80 | # optimizer
81 |
82 | optim = Lion(model.palm_parameters(), lr = LEARNING_RATE)
83 |
84 | model, optim, train_loader, val_loader = accelerator.prepare(
85 | model, optim, train_loader, val_loader
86 | )
87 |
88 | # training
89 |
90 | for i in tqdm(range(NUM_BATCHES), mininterval=10.0, desc="training"):
91 | model.train()
92 |
93 | with accelerator.accumulate(model):
94 | loss = model(next(train_loader), return_loss = True)
95 | accelerator.backward(loss / GRADIENT_ACCUMULATE_EVERY)
96 |
97 | accelerator.print(f"training loss: {loss.item()}")
98 | accelerator.clip_grad_norm_(model.parameters(), 0.5)
99 |
100 | optim.step()
101 | optim.zero_grad()
102 |
103 | if i % VALIDATE_EVERY == 0:
104 | model.eval()
105 | with torch.no_grad():
106 | loss = model(next(val_loader), return_loss = True)
107 | accelerator.print(f"validation loss: {loss.item()}")
108 |
109 | if i % GENERATE_EVERY == 0:
110 | model.eval()
111 | inp = random.choice(val_dataset)[:PRIME_LENGTH]
112 | prime = decode_tokens(inp)
113 |         accelerator.print(f"{prime} \n\n {'*' * 100}")
114 |
115 | # Check if model is wrapped
116 | if hasattr(model, "module"):
117 | sample = model.module.generate(GENERATE_LENGTH, inp[None, ...])
118 | else:
119 | sample = model.generate(GENERATE_LENGTH, inp[None, ...])
120 |
121 | output_str = decode_tokens(sample[0])
122 | accelerator.print(output_str, "\n")
123 |
--------------------------------------------------------------------------------
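
Note that train.py relies on the Accelerator for device placement and gradient accumulation, so it is normally started through the accelerate CLI (run accelerate config once, then accelerate launch train.py); a plain python train.py invocation also works and falls back to a single process. The script imports lion_pytorch, which is not listed in setup.py's install_requires, so it must be installed separately.
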