├── .gitattributes ├── LICENSE ├── README.md ├── assets ├── gpt2_124M_loss.png └── nanogpt.jpg ├── bench.py ├── chatgpt_dev_teaching.ipynb ├── config ├── config.yaml ├── config_reward.yaml ├── config_rl.yaml ├── eval_gpt2.py ├── eval_gpt2_large.py ├── eval_gpt2_medium.py ├── eval_gpt2_xl.py ├── finetune_shakespeare.py ├── train_gpt2.py └── train_shakespeare_char.py ├── configurator.py ├── data ├── openai_summarize_tldr │ └── prepare.py ├── openwebtext │ ├── prepare.py │ └── readme.md ├── shakespeare │ ├── prepare.py │ └── readme.md └── shakespeare_char │ ├── prepare.py │ └── readme.md ├── model.py ├── requirements.txt ├── sample.py ├── scaling_laws.ipynb ├── train.py ├── train_reward_model.py ├── train_reward_model_simple.py ├── train_rl.py ├── trainers ├── reward_trainer.py ├── rl_trainer.py └── trainer.py ├── transformer_sizing.ipynb └── utils.py /.gitattributes: -------------------------------------------------------------------------------- 1 | # Override jupyter in Github language stats for more accurate estimate of repo code languages 2 | # reference: https://github.com/github/linguist/blob/master/docs/overrides.md#generated-code 3 | *.ipynb linguist-generated 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Andrej Karpathy 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # nanoChatGPT 2 | 3 | A crude RLHF (Reinforcement Learning from Human Feedback) layer on top of nanoGPT to test an idea I had: that you can backpropagate through the reward function rather than use policy gradients. I have verified it works for a very basic example where you incentivise the network to produce words containing 'and'. The trick is to use the Straight-Through Gumbel-Softmax estimator. 4 | 5 | Also check out chatgpt_dev_teaching.ipynb and the YouTube video explaining fine-tuning with RL: https://m.youtube.com/watch?v=soqTT0o1ZKo 6 | 7 | Prepare data: 8 | 9 | ``` 10 | $ python data/shakespeare/prepare.py 11 | ``` 12 | 13 | Once the data is prepared, start training. The configs assume CUDA; if you don't have a GPU, change the device to cpu in the config.
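For example, a minimal sketch of the `system` block in `config/config.yaml` for a CPU-only run (the keys mirror the default config shown later in this repo; switching `dtype` to `float32` is an assumption, since the default `float16` is aimed at GPU autocast):

```
system:
  device: cpu      # default is cuda
  dtype: float32   # assumed here; the default float16 targets GPU mixed precision
  compile: False   # already False in the default config
```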
14 | 15 | ``` 16 | $ python train.py # settings in config/config.yaml 17 | ``` 18 | 19 | Once a basic model is trained, you can fine-tune a reward model for an underlying reward rule. 20 | 21 | ``` 22 | $ python train_reward_model_simple.py # settings in config/config_reward.yaml 23 | ``` 24 | 25 | This creates a multihead model on top of the existing one. Once the reward model is trained sufficiently you can train the RL policy using: 26 | 27 | ``` 28 | $ python train_rl.py # settings in config/config_rl.yaml 29 | ``` 30 | 31 | The default config uses the Gumbel trick, but it can be set to PG to do policy gradient instead (the latter still needs a critic implementation, etc.). I have validated that the Gumbel method works given that the preceding steps also worked. I am curious to see if this would scale to large models - let me know if you're able to test this. 32 | 33 | Model output after a short amount of training produces results like: 34 | 35 | ``` 36 | hand hand thousand the thousand the hand hand hand hand thousand 37 | ``` 38 | 39 | If you're feeling adventurous you can also try: 40 | 41 | ``` 42 | $ python train_reward_model.py 43 | ``` 44 | 45 | This uses the Reddit TL;DR dataset. The pipeline should work, but I have not actually finetuned it at all. 46 | 47 | References: 48 | 49 | Gumbel 50 | https://arxiv.org/pdf/1611.01144.pdf 51 | 52 | InstructGPT 53 | https://arxiv.org/abs/2203.02155 54 | 55 | Below is Andrej Karpathy's original README; make sure you have installed the relevant packages. 56 | ________ 57 | 58 | # nanoGPT 59 | 60 | ![nanoGPT](assets/nanogpt.jpg) 61 | 62 | The simplest, fastest repository for training/finetuning medium-sized GPTs. It is a rewrite of [minGPT](https://github.com/karpathy/minGPT) that prioritizes teeth over education. Still under active development, but currently the file `train.py` reproduces GPT-2 (124M) on OpenWebText, running on a single 8XA100 40GB node in about 4 days of training. The code itself is plain and readable: `train.py` is a ~300-line boilerplate training loop and `model.py` a ~300-line GPT model definition, which can optionally load the GPT-2 weights from OpenAI. That's it. 63 | 64 | ![repro124m](assets/gpt2_124M_loss.png) 65 | 66 | Because the code is so simple, it is very easy to hack to your needs, train new models from scratch, or finetune pretrained checkpoints (e.g. the biggest one currently available as a starting point would be the GPT-2 1.5B model from OpenAI). 67 | 68 | ## install 69 | 70 | Dependencies: 71 | 72 | - [pytorch](https://pytorch.org) <3 73 | - [numpy](https://numpy.org/install/) <3 74 | - `pip install transformers` for huggingface transformers <3 (to load GPT-2 checkpoints) 75 | - `pip install datasets` for huggingface datasets <3 (if you want to download + preprocess OpenWebText) 76 | - `pip install tiktoken` for OpenAI's fast BPE code <3 77 | - `pip install wandb` for optional logging <3 78 | - `pip install tqdm` 79 | 80 | ## quick start 81 | 82 | If you are not a deep learning professional and you just want to feel the magic and get your feet wet, the fastest way to get started is to train a character-level GPT on the works of Shakespeare. First, we download it as a single (1MB) file and turn it from raw text into one large stream of integers: 83 | 84 | ``` 85 | $ python data/shakespeare_char/prepare.py 86 | ``` 87 | 88 | This creates a `train.bin` and `val.bin` in that data directory. Now it is time to train your GPT.
The size of it very much depends on the computational resources of your system: 89 | 90 | **I have a GPU**. Great, we can quickly train a baby GPT with the settings provided in the [config/train_shakespeare_char.py](config/train_shakespeare_char.py) config file: 91 | 92 | ``` 93 | $ python train.py config/train_shakespeare_char.py 94 | ``` 95 | 96 | If you peek inside it, you'll see that we're training a GPT with a context size of up to 256 characters, 384 feature channels, and it is a 6-layer Transformer with 6 heads in each layer. On one A100 GPU this training run takes about 3 minutes and the best validation loss is 1.4697. Based on the configuration, the model checkpoints are being written into the `--out_dir` directory `out-shakespeare-char`. So once the training finishes we can sample from the best model by pointing the sampling script at this directory: 97 | 98 | ``` 99 | $ python sample.py --out_dir=out-shakespeare-char 100 | ``` 101 | 102 | This generates a few samples, for example: 103 | 104 | ``` 105 | ANGELO: 106 | And cowards it be strawn to my bed, 107 | And thrust the gates of my threats, 108 | Because he that ale away, and hang'd 109 | An one with him. 110 | 111 | DUKE VINCENTIO: 112 | I thank your eyes against it. 113 | 114 | DUKE VINCENTIO: 115 | Then will answer him to save the malm: 116 | And what have you tyrannous shall do this? 117 | 118 | DUKE VINCENTIO: 119 | If you have done evils of all disposition 120 | To end his power, the day of thrust for a common men 121 | That I leave, to fight with over-liking 122 | Hasting in a roseman. 123 | ``` 124 | 125 | lol `¯\_(ツ)_/¯`. Not bad for a character-level model after 3 minutes of training on a GPU. Better results are quite likely obtainable by instead finetuning a pretrained GPT-2 model on this dataset (see finetuning section later). 126 | 127 | **I only have a macbook** (or other cheap computer). No worries, we can still train a GPT but we want to dial things down a notch. I recommend getting the bleeding edge PyTorch nightly ([select it here](https://pytorch.org/get-started/locally/) when installing) as it is currently quite likely to make your code more efficient. But even without it, a simple train run could look as follows: 128 | 129 | ``` 130 | $ python train.py config/train_shakespeare_char.py --device=cpu --compile=False --eval_iters=20 --log_interval=1 --block_size=64 --batch_size=12 --n_layer=4 --n_head=4 --n_embd=128 --max_iters=2000 --lr_decay_iters=2000 --dropout=0.0 131 | ``` 132 | 133 | Here, since we are running on CPU instead of GPU we must set both `--device=cpu` and also turn off PyTorch 2.0 compile with `--compile=False`. Then when we evaluate we get a bit noisier but faster estimate (`--eval_iters=20`, down from 200), our context size is only 64 characters instead of 256, and the batch size only 12 examples per iteration, not 64. We'll also use a much smaller Transformer (4 layers, 4 heads, 128 embedding size), and decrease the number of iterations to 2000 (and correspondingly usually decay the learning rate to around max_iters with `--lr_decay_iters`). Because our network is so small we also ease down on regularization (`--dropout=0.0`).
This still runs in about ~3 minutes, but gets us a loss of only 1.88 and therefore also worse samples; it's still good fun though: 134 | 135 | ``` 136 | GLEORKEN VINGHARD III: 137 | Whell's the couse, the came light gacks, 138 | And the for mought you in Aut fries the not high shee 139 | bot thou the sought bechive in that to doth groan you, 140 | No relving thee post mose the wear 141 | ``` 142 | 143 | Not bad for ~3 minutes on a CPU, for a hint of the right character gestalt. If you're willing to wait longer, feel free to tune the hyperparameters, increase the size of the network, the context length (`--block_size`), the length of training, etc. 144 | 145 | Finally, on Apple Silicon Macbooks and with a recent PyTorch version make sure to add `--device mps` (short for "Metal Performance Shaders"); PyTorch then uses the on-chip GPU that can *significantly* accelerate training (2-3X) and allow you to use larger networks. See [Issue 28](https://github.com/karpathy/nanoGPT/issues/28) for more. 146 | 147 | ## reproducing GPT-2 148 | 149 | A more serious deep learning professional may be more interested in reproducing GPT-2 results. So here we go - we first tokenize the dataset, in this case the [OpenWebText](https://openwebtext2.readthedocs.io/en/latest/), an open reproduction of OpenAI's (private) WebText: 150 | 151 | ``` 152 | $ python data/openwebtext/prepare.py 153 | ``` 154 | 155 | This downloads and tokenizes the [OpenWebText](https://huggingface.co/datasets/openwebtext) dataset. It will create a `train.bin` and `val.bin` which hold the GPT2 BPE token ids in one sequence, stored as raw uint16 bytes. Then we're ready to kick off training. To reproduce GPT-2 (124M) you'll want at least an 8X A100 40GB node and run: 156 | 157 | ``` 158 | $ torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py 159 | ``` 160 | 161 | This will run for about 4 days using PyTorch Distributed Data Parallel (DDP) and go down to a loss of ~2.85. Now, a GPT-2 model just evaluated on OWT gets a val loss of about 3.11, but if you finetune it, it will come down to ~2.85 territory (due to an apparent domain gap), making the two models ~match. 162 | 163 | If you're in a cluster environment and you are blessed with multiple GPU nodes you can make GPU go brrrr e.g. across 2 nodes like: 164 | 165 | ``` 166 | Run on the first (master) node with example IP 123.456.123.456: 167 | $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py 168 | Run on the worker node: 169 | $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py 170 | ``` 171 | 172 | It is a good idea to benchmark your interconnect (e.g. iperf3). In particular, if you don't have Infiniband then also prepend `NCCL_IB_DISABLE=1` to the above launches. Your multinode training will work, but most likely _crawl_. By default checkpoints are periodically written to the `--out_dir`. We can sample from the model by simply running `$ python sample.py`. 173 | 174 | Finally, to train on a single GPU simply run the `$ python train.py` script. Have a look at all of its args; the script tries to be very readable, hackable and transparent. You'll most likely want to tune a number of those variables depending on your needs. 175 | 176 | ## baselines 177 | 178 | OpenAI GPT-2 checkpoints allow us to get some baselines in place for openwebtext.
We can get the numbers as follows: 179 | 180 | ``` 181 | $ python train.py eval_gpt2 182 | $ python train.py eval_gpt2_medium 183 | $ python train.py eval_gpt2_large 184 | $ python train.py eval_gpt2_xl 185 | ``` 186 | 187 | and observe the following losses on train and val: 188 | 189 | | model | params | train loss | val loss | 190 | | ------| ------ | ---------- | -------- | 191 | | gpt2 | 124M | 3.11 | 3.12 | 192 | | gpt2-medium | 350M | 2.85 | 2.84 | 193 | | gpt2-large | 774M | 2.66 | 2.67 | 194 | | gpt2-xl | 1558M | 2.56 | 2.54 | 195 | 196 | However, we have to note that GPT-2 was trained on (closed, never released) WebText, while OpenWebText is just a best-effort open reproduction of this dataset. This means there is a dataset domain gap. Indeed, taking the GPT-2 (124M) checkpoint and finetuning on OWT directly for a while reaches loss down to ~2.85. This then becomes the more appropriate baseline w.r.t. reproduction. 197 | 198 | ## finetuning 199 | 200 | Finetuning is no different than training, we just make sure to initialize from a pretrained model and train with a smaller learning rate. For an example of how to finetune a GPT on new text go to `data/shakespeare` and run `prepare.py` to download the tiny shakespeare dataset and render it into a `train.bin` and `val.bin`, using the OpenAI BPE tokenizer from GPT-2. Unlike OpenWebText this will run in seconds. Finetuning can take very little time, e.g. on a single GPU just a few minutes. Run an example finetuning like: 201 | 202 | ``` 203 | $ python train.py config/finetune_shakespeare.py 204 | ``` 205 | 206 | This will load the config parameter overrides in `config/finetune_shakespeare.py` (I didn't tune them much though). Basically, we initialize from a GPT2 checkpoint with `init_from` and train as normal, except shorter and with a small learning rate. If you're running out of memory try decreasing the model size (they are `{'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}`) or possibly decreasing the `block_size` (context length). The best checkpoint (lowest validation loss) will be in the `out_dir` directory, e.g. in `out-shakespeare` by default, per the config file. You can then run the code in `sample.py --out_dir=out-shakespeare`: 207 | 208 | ``` 209 | THEODORE: 210 | Thou shalt sell me to the highest bidder: if I die, 211 | I sell thee to the first; if I go mad, 212 | I sell thee to the second; if I 213 | lie, I sell thee to the third; if I slay, 214 | I sell thee to the fourth: so buy or sell, 215 | I tell thee again, thou shalt not sell my 216 | possession. 217 | 218 | JULIET: 219 | And if thou steal, thou shalt not sell thyself. 220 | 221 | THEODORE: 222 | I do not steal; I sell the stolen goods. 223 | 224 | THEODORE: 225 | Thou know'st not what thou sell'st; thou, a woman, 226 | Thou art ever a victim, a thing of no worth: 227 | Thou hast no right, no right, but to be sold. 228 | ``` 229 | 230 | Whoa there, GPT, entering some dark place over there. I didn't really tune the hyperparameters in the config too much, feel free to try! 231 | 232 | ## sampling / inference 233 | 234 | Use the script `sample.py` to sample either from pre-trained GPT-2 models released by OpenAI, or from a model you trained yourself. For example, here is a way to sample from the largest available `gpt2-xl` model: 235 | 236 | ``` 237 | $ python sample.py \ 238 | --init_from=gpt2-xl \ 239 | --start="What is the answer to life, the universe, and everything?" 
\ 240 | --num_samples=5 --max_new_tokens=100 241 | ``` 242 | 243 | If you'd like to sample from a model you trained, use the `--out_dir` to point the code appropriately. You can also prompt the model with some text from a file, e.g. `$ python sample.py --start=FILE:prompt.txt`. 244 | 245 | ## efficiency notes 246 | 247 | For simple model benchmarking and profiling, `bench.py` might be useful. It's identical to what happens in the meat of the training loop of `train.py`, but omits much of the other complexities. 248 | 249 | Note that the code by default uses [PyTorch 2.0](https://pytorch.org/get-started/pytorch-2.0/). At the time of writing (Dec 29, 2022) this makes `torch.compile()` available in the nightly release. The improvement from the one line of code is noticeable, e.g. cutting down iteration time from ~250ms / iter to 135ms / iter. Nice work PyTorch team! 250 | 251 | ## todos 252 | 253 | - Investigate and add FSDP instead of DDP 254 | - Eval zero-shot perplexities on standard evals (e.g. LAMBADA? HELM? etc.) 255 | - Finetune the finetuning script, I think the hyperparams are not great 256 | - Schedule for linear batch size increase during training 257 | - Incorporate other embeddings (rotary, alibi) 258 | - Separate out the optim buffers from model params in checkpoints I think 259 | - Additional logging around network health (e.g. gradient clip events, magnitudes) 260 | - Few more investigations around better init etc. 261 | 262 | ## troubleshooting 263 | 264 | Note that by default this repo uses PyTorch 2.0 (i.e. `torch.compile`). This is fairly new and experimental, and not yet available on all platforms (e.g. Windows). If you're running into related error messages try to disable this by adding `--compile=False` flag. This will slow down the code but at least it will run. 265 | 266 | For some context on this repository, GPT, and language modeling it might be helpful to watch my [Zero To Hero series](https://karpathy.ai/zero-to-hero.html). Specifically, the [GPT video](https://www.youtube.com/watch?v=kCc8FmEb1nY) is popular if you have some prior language modeling context. 267 | 268 | For more questions/discussions feel free to stop by **#nanoGPT** on Discord: 269 | 270 | [![](https://dcbadge.vercel.app/api/server/3zy8kqD9Cp?compact=true&style=flat)](https://discord.gg/3zy8kqD9Cp) 271 | 272 | ## acknowledgements 273 | 274 | All nanoGPT experiments are powered by GPUs on [Lambda labs](https://lambdalabs.com), my favorite Cloud GPU provider. Thank you Lambda labs for sponsoring nanoGPT! 
275 | -------------------------------------------------------------------------------- /assets/gpt2_124M_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanjeevanahilan/nanoChatGPT/a77c85c2cdfca470a22df9547ed0757a34b1583f/assets/gpt2_124M_loss.png -------------------------------------------------------------------------------- /assets/nanogpt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/sanjeevanahilan/nanoChatGPT/a77c85c2cdfca470a22df9547ed0757a34b1583f/assets/nanogpt.jpg -------------------------------------------------------------------------------- /bench.py: -------------------------------------------------------------------------------- 1 | """ 2 | A much shorter version of train.py for benchmarking 3 | """ 4 | import os 5 | from contextlib import nullcontext 6 | import numpy as np 7 | import time 8 | import torch 9 | from model import GPTConfig, GPT 10 | 11 | # ----------------------------------------------------------------------------- 12 | batch_size = 12 13 | block_size = 1024 14 | bias = False 15 | real_data = True 16 | seed = 1337 17 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc. 18 | dtype = 'bfloat16' # 'float32' or 'bfloat16' or 'float16' 19 | compile = True # use PyTorch 2.0 to compile the model to be faster 20 | profile = False # use pytorch profiler, or just simple benchmarking? 21 | exec(open('configurator.py').read()) # overrides from command line or config file 22 | # ----------------------------------------------------------------------------- 23 | 24 | torch.manual_seed(seed) 25 | torch.cuda.manual_seed(seed) 26 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 27 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 28 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast 29 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 30 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype) 31 | 32 | # data loading init 33 | if real_data: 34 | dataset = 'openwebtext' 35 | data_dir = os.path.join('data', dataset) 36 | train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r') 37 | def get_batch(split): 38 | data = train_data # note ignore split in benchmarking script 39 | ix = torch.randint(len(data) - block_size, (batch_size,)) 40 | x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) 41 | y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix]) 42 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True) 43 | return x, y 44 | else: 45 | # alternatively, if fixed data is desired to not care about data loading 46 | x = torch.randint(50304, (batch_size, block_size), device=device) 47 | y = torch.randint(50304, (batch_size, block_size), device=device) 48 | get_batch = lambda split: (x, y) 49 | 50 | # model init 51 | gptconf = GPTConfig( 52 | block_size = block_size, # how far back does the model look? i.e. 
context size 53 | n_layer = 12, n_head = 12, n_embd = 768, # size of the model 54 | dropout = 0, # for determinism 55 | bias = bias, 56 | ) 57 | model = GPT(gptconf) 58 | model.to(device) 59 | 60 | optimizer = model.configure_optimizers(weight_decay=1e-2, learning_rate=1e-4, betas=(0.9, 0.95), device_type=device_type) 61 | 62 | if compile: 63 | print("Compiling model...") 64 | model = torch.compile(model) # pytorch 2.0 65 | 66 | if profile: 67 | # useful docs on pytorch profiler: 68 | # - tutorial https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html 69 | # - api https://pytorch.org/docs/stable/profiler.html#torch.profiler.profile 70 | wait, warmup, active = 5, 5, 5 71 | num_steps = wait + warmup + active 72 | with torch.profiler.profile( 73 | activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], 74 | schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1), 75 | on_trace_ready=torch.profiler.tensorboard_trace_handler('./bench_log'), 76 | record_shapes=False, 77 | profile_memory=False, 78 | with_stack=False, # incurs an additional overhead, disable if not needed 79 | with_flops=True, 80 | with_modules=False, # only for torchscript models atm 81 | ) as prof: 82 | 83 | X, Y = get_batch('train') 84 | for k in range(num_steps): 85 | with ctx: 86 | logits, loss = model(X, Y) 87 | X, Y = get_batch('train') 88 | optimizer.zero_grad(set_to_none=True) 89 | loss.backward() 90 | optimizer.step() 91 | lossf = loss.item() 92 | print(f"{k}/{num_steps} loss: {lossf:.4f}") 93 | 94 | prof.step() # notify the profiler at end of each step 95 | 96 | else: 97 | 98 | # simple benchmarking 99 | torch.cuda.synchronize() 100 | for stage, num_steps in enumerate([10, 20]): # burnin, then benchmark 101 | t0 = time.time() 102 | X, Y = get_batch('train') 103 | for k in range(num_steps): 104 | with ctx: 105 | logits, loss = model(X, Y) 106 | X, Y = get_batch('train') 107 | optimizer.zero_grad(set_to_none=True) 108 | loss.backward() 109 | optimizer.step() 110 | lossf = loss.item() 111 | print(f"{k}/{num_steps} loss: {lossf:.4f}") 112 | torch.cuda.synchronize() 113 | t1 = time.time() 114 | dt = t1-t0 115 | mfu = model.estimate_mfu(batch_size * 1 * num_steps, dt) 116 | if stage == 1: 117 | print(f"time per iteration: {dt/num_steps*1000:.4f}ms, MFU: {mfu*100:.2f}%") 118 | -------------------------------------------------------------------------------- /config/config.yaml: -------------------------------------------------------------------------------- 1 | IO: 2 | out_dir: out 3 | eval_interval: 2000 4 | log_interval: 1 5 | eval_iters: 200 6 | eval_only: False # if True, script exits right after the first eval 7 | always_save_checkpoint: True # if True, always save a checkpoint after each eval 8 | init_from: scratch # 'scratch' or 'resume' or 'gpt2*' 9 | wandb: 10 | wandb_log: False # disabled by default 11 | wandb_project: rlhf # 'gpt2' 12 | wandb_run_name: gpt2 # 'run' + str(time.time()) 13 | data: 14 | dataset: shakespeare # 'openwebtext', 'shakespeare', 'openai_summarize_tldr' 15 | gradient_accumulation_steps: 1 # used to simulate larger batch sizes 16 | batch_size: 12 # if gradient_accumulation_steps > 1, this is the micro-batch size 17 | block_size: 32 18 | model: 19 | n_layer: 2 20 | n_head: 2 21 | n_embd: 32 22 | dropout: 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 23 | bias: False # do we use bias inside LayerNorm and Linear layers? 
24 | optimizer: # adamw 25 | learning_rate: 6.0e-4 # max learning rate 26 | max_iters: 600000 # total number of training iterations 27 | weight_decay: 1.0e-2 28 | beta1: 0.9 29 | beta2: 0.95 30 | grad_clip: 1.0 # clip gradients at this value, or disable if == 0.0 31 | decay_lr: True # whether to decay the learning rate 32 | warmup_iters: 2000 # how many steps to warm up for 33 | lr_decay_iters: 600000 # should be ~= max_iters per Chinchilla 34 | min_lr: 6.0e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla 35 | DDP: 36 | backend: nccl # 'nccl', 'gloo', etc. 37 | system: 38 | device: cuda # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks 39 | dtype: float16 # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler 40 | compile: False # use PyTorch 2.0 to compile the model to be faster 41 | -------------------------------------------------------------------------------- /config/config_reward.yaml: -------------------------------------------------------------------------------- 1 | IO: 2 | out_dir: out 3 | eval_interval: 500 4 | log_interval: 1 5 | eval_iters: 100 6 | eval_only: False # if True, script exits right after the first eval 7 | always_save_checkpoint: True # if True, always save a checkpoint after each eval 8 | init_from: resume # 'scratch' or 'resume' or 'gpt2*' 9 | init_multihead_from: scratch 10 | out_dir_multihead: out_reward # used if restoring multihead 11 | wandb: 12 | wandb_log: True # disabled by default 13 | wandb_project: rlhf # 'gpt2' 14 | wandb_run_name: gpt2 # 'run' + str(time.time()) 15 | data: 16 | dataset: 'shakespeare' # 'openwebtext', 'shakespeare', 'openai_summarize_tldr' 17 | gradient_accumulation_steps: 1 # used to simulate larger batch sizes 18 | batch_size: 64 # if gradient_accumulation_steps > 1, this is the micro-batch size 19 | block_size: 32 20 | model: 21 | n_layer: 2 22 | n_head: 2 23 | n_embd: 32 24 | dropout: 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 25 | bias: False # do we use bias inside LayerNorm and Linear layers? 26 | optimizer: # adamw 27 | learning_rate: 6.0e-4 # max learning rate 28 | max_iters: 600000 # total number of training iterations 29 | weight_decay: 1.0e-2 30 | beta1: 0.9 31 | beta2: 0.95 32 | grad_clip: 1.0 # clip gradients at this value, or disable if == 0.0 33 | decay_lr: True # whether to decay the learning rate 34 | warmup_iters: 2000 # how many steps to warm up for 35 | lr_decay_iters: 600000 # should be ~= max_iters per Chinchilla 36 | min_lr: 6.0e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla 37 | DDP: 38 | backend: nccl # 'nccl', 'gloo', etc. 
39 | system: 40 | device: cuda # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks 41 | dtype: float16 # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler 42 | compile: False # use PyTorch 2.0 to compile the model to be faster 43 | -------------------------------------------------------------------------------- /config/config_rl.yaml: -------------------------------------------------------------------------------- 1 | algorithm: 2 | method: gumbel # pg or gumbel 3 | hard_code_reward: False # use a learned reward model or hard code reward (latter does not work with Gumbel) 4 | separate_reward_model: True # when using a reward model, instantiate it separately rather than share params with LM 5 | discrete_reward: True # reward output is 0 or 1 sample if True, otherwise reward is continuous 6 | episode_length: 32 7 | IO: 8 | out_dir: out 9 | eval_interval: 100 10 | log_interval: 1 11 | eval_iters: 200 12 | eval_only: False # if True, script exits right after the first eval 13 | always_save_checkpoint: True # if True, always save a checkpoint after each eval 14 | init_from: scratch # 'scratch' or 'resume' or 'gpt2*' 15 | init_multihead_from: scratch 16 | out_dir_multihead: out_reward # used if restoring multihead 17 | wandb: 18 | wandb_log: False # disabled by default 19 | wandb_project: rlhf # 'gpt2' 20 | wandb_run_name: gpt2 # 'run' + str(time.time()) 21 | data: 22 | dataset: shakespeare # 'openwebtext', 'shakespeare', 'openai_summarize_tldr' 23 | gradient_accumulation_steps: 1 # used to simulate larger batch sizes 24 | batch_size: 12 # if gradient_accumulation_steps > 1, this is the micro-batch size 25 | block_size: 32 26 | model: 27 | n_layer: 2 28 | n_head: 2 29 | n_embd: 32 30 | dropout: 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 31 | bias: False # do we use bias inside LayerNorm and Linear layers? 32 | optimizer: # adamw 33 | learning_rate: 6.0e-4 # max learning rate 34 | max_iters: 600000 # total number of training iterations 35 | weight_decay: 1.0e-2 36 | beta1: 0.9 37 | beta2: 0.95 38 | grad_clip: 1.0 # clip gradients at this value, or disable if == 0.0 39 | decay_lr: True # whether to decay the learning rate 40 | warmup_iters: 2000 # how many steps to warm up for 41 | lr_decay_iters: 600000 # should be ~= max_iters per Chinchilla 42 | min_lr: 6.0e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla 43 | DDP: 44 | backend: nccl # 'nccl', 'gloo', etc. 
45 | system: 46 | device: cuda # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks 47 | dtype: float16 # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler 48 | compile: False # use PyTorch 2.0 to compile the model to be faster 49 | -------------------------------------------------------------------------------- /config/eval_gpt2.py: -------------------------------------------------------------------------------- 1 | # evaluate the base gpt2 2 | # n_layer=12, n_head=12, n_embd=768 3 | # 124M parameters 4 | batch_size = 8 5 | eval_iters = 500 # use more iterations to get good estimate 6 | eval_only = True 7 | wandb_log = False 8 | init_from = 'gpt2' 9 | -------------------------------------------------------------------------------- /config/eval_gpt2_large.py: -------------------------------------------------------------------------------- 1 | # evaluate the base gpt2 2 | # n_layer=36, n_head=20, n_embd=1280 3 | # 774M parameters 4 | batch_size = 8 5 | eval_iters = 500 # use more iterations to get good estimate 6 | eval_only = True 7 | wandb_log = False 8 | init_from = 'gpt2-large' 9 | -------------------------------------------------------------------------------- /config/eval_gpt2_medium.py: -------------------------------------------------------------------------------- 1 | # evaluate the base gpt2 2 | # n_layer=24, n_head=16, n_embd=1024 3 | # 350M parameters 4 | batch_size = 8 5 | eval_iters = 500 # use more iterations to get good estimate 6 | eval_only = True 7 | wandb_log = False 8 | init_from = 'gpt2-medium' 9 | -------------------------------------------------------------------------------- /config/eval_gpt2_xl.py: -------------------------------------------------------------------------------- 1 | # evaluate the base gpt2 2 | # n_layer=48, n_head=25, n_embd=1600 3 | # 1558M parameters 4 | batch_size = 8 5 | eval_iters = 500 # use more iterations to get good estimate 6 | eval_only = True 7 | wandb_log = False 8 | init_from = 'gpt2-xl' 9 | -------------------------------------------------------------------------------- /config/finetune_shakespeare.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | out_dir = 'out-shakespeare' 4 | eval_interval = 5 5 | eval_iters = 40 6 | wandb_log = False # feel free to turn on 7 | wandb_project = 'shakespeare' 8 | wandb_run_name = 'ft-' + str(time.time()) 9 | 10 | dataset = 'shakespeare' 11 | init_from = 'gpt2-xl' # this is the largest GPT-2 model 12 | 13 | # only save checkpoints if the validation loss improves 14 | always_save_checkpoint = False 15 | 16 | # the number of examples per iter: 17 | # 1 batch_size * 32 grad_accum * 1024 tokens = 32,768 tokens/iter 18 | # shakespeare has 301,966 tokens, so 1 epoch ~= 9.2 iters 19 | batch_size = 1 20 | gradient_accumulation_steps = 32 21 | max_iters = 20 22 | 23 | # finetune at constant LR 24 | learning_rate = 3e-5 25 | decay_lr = False 26 | -------------------------------------------------------------------------------- /config/train_gpt2.py: -------------------------------------------------------------------------------- 1 | # config for training GPT-2 (124M) down to very nice loss of ~2.85 on 1 node of 8X A100 40GB 2 | # launch as the following (e.g. 
in a screen session) and wait ~5 days: 3 | # $ torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py 4 | 5 | wandb_log = True 6 | wandb_project = 'owt' 7 | wandb_run_name='gpt2-124M' 8 | 9 | # these make the total batch size be ~0.5M 10 | # 12 batch size * 1024 block size * 5 gradaccum * 8 GPUs = 491,520 11 | batch_size = 12 12 | block_size = 1024 13 | gradient_accumulation_steps = 5 14 | 15 | # this makes total number of tokens be 300B 16 | max_iters = 600000 17 | lr_decay_iters = 600000 18 | 19 | # eval stuff 20 | eval_interval = 1000 21 | eval_iters = 200 22 | log_interval = 10 23 | 24 | # weight decay 25 | weight_decay = 1e-1 26 | -------------------------------------------------------------------------------- /config/train_shakespeare_char.py: -------------------------------------------------------------------------------- 1 | # train a miniature character-level shakespeare model 2 | # good for debugging and playing on macbooks and such 3 | 4 | out_dir = 'out-shakespeare-char' 5 | eval_interval = 250 # keep frequent because we'll overfit 6 | eval_iters = 200 7 | log_interval = 10 # don't print too too often 8 | 9 | # we expect to overfit on this small dataset, so only save when val improves 10 | always_save_checkpoint = False 11 | 12 | wandb_log = False # override via command line if you like 13 | wandb_project = 'shakespeare-char' 14 | wandb_run_name = 'mini-gpt' 15 | 16 | dataset = 'shakespeare_char' 17 | batch_size = 64 18 | block_size = 256 # context of up to 256 previous characters 19 | 20 | # baby GPT model :) 21 | n_layer = 6 22 | n_head = 6 23 | n_embd = 384 24 | dropout = 0.2 25 | 26 | learning_rate = 1e-3 # with baby networks can afford to go a bit higher 27 | max_iters = 5000 28 | lr_decay_iters = 5000 # make equal to max_iters usually 29 | min_lr = 1e-4 # learning_rate / 10 usually 30 | beta2 = 0.99 # make a bit bigger because number of tokens per iter is small 31 | 32 | warmup_iters = 100 # not super necessary potentially 33 | 34 | # on macbook also add 35 | # device = 'cpu' # run on cpu only 36 | # compile = False # do not torch compile the model 37 | -------------------------------------------------------------------------------- /configurator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Poor Man's Configurator. Probably a terrible idea. Example usage: 3 | $ python train.py config/override_file.py --batch_size=32 4 | this will first run config/override_file.py, then override batch_size to 32 5 | 6 | The code in this file will be run as follows from e.g. train.py: 7 | >>> exec(open('configurator.py').read()) 8 | 9 | So it's not a Python module, it's just shuttling this code away from train.py 10 | The code in this script then overrides the globals() 11 | 12 | I know people are not going to love this, I just really dislike configuration 13 | complexity and having to prepend config. to every single variable. If someone 14 | comes up with a better simple Python solution I am all ears. 
15 | """ 16 | 17 | import sys 18 | from ast import literal_eval 19 | 20 | for arg in sys.argv[1:]: 21 | if '=' not in arg: 22 | # assume it's the name of a config file 23 | assert not arg.startswith('--') 24 | config_file = arg 25 | print(f"Overriding config with {config_file}:") 26 | with open(config_file) as f: 27 | print(f.read()) 28 | exec(open(config_file).read()) 29 | else: 30 | # assume it's a --key=value argument 31 | assert arg.startswith('--') 32 | key, val = arg.split('=') 33 | key = key[2:] 34 | if key in globals(): 35 | try: 36 | # attempt to eval it it (e.g. if bool, number, or etc) 37 | attempt = literal_eval(val) 38 | except (SyntaxError, ValueError): 39 | # if that goes wrong, just use the string 40 | attempt = val 41 | # ensure the types match ok 42 | assert type(attempt) == type(globals()[key]) 43 | # cross fingers 44 | print(f"Overriding: {key} = {attempt}") 45 | globals()[key] = attempt 46 | else: 47 | raise ValueError(f"Unknown config key: {key}") 48 | -------------------------------------------------------------------------------- /data/openai_summarize_tldr/prepare.py: -------------------------------------------------------------------------------- 1 | # saves the openwebtext dataset to a binary file for training. following was helpful: 2 | # https://github.com/HazyResearch/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py 3 | 4 | import os 5 | from tqdm import tqdm 6 | import numpy as np 7 | import tiktoken 8 | from datasets import load_dataset # huggingface datasets 9 | 10 | # number of workers in .map() call 11 | # good number to use is ~order number of cpu cores // 2 12 | num_proc = 16 13 | 14 | # takes 54GB in huggingface .cache dir, about 8M documents (8,013,769) 15 | dataset = load_dataset("CarperAI/openai_summarize_tldr") 16 | 17 | 18 | 19 | # class TLDRDataset(Dataset): 20 | # def __init__(self, split): 21 | # self.text = [] 22 | # dataset = load_dataset(train_path, split=split) 23 | # for sample in dataset: 24 | # self.text.append(sample["prompt"] + sample["label"]) 25 | # # if "valid" in train_path: 26 | # # self.post_list = self.post_list[0:2000] 27 | # # self.tokenizer = tokenizer 28 | # # self.max_length = max_length 29 | # # self.input_ids = [] 30 | # # self.attn_masks = [] 31 | 32 | # def __len__(self): 33 | # return len(self.text) 34 | 35 | # def __getitem__(self, idx): 36 | # txt = self.text[idx] 37 | # # encodings_dict = self.tokenizer(txt, truncation=True, max_length=self.max_length, padding="max_length") 38 | # # input_ids = torch.tensor(encodings_dict["input_ids"]) 39 | # # attn_masks = torch.tensor(encodings_dict["attention_mask"]) 40 | 41 | # return { 42 | # # "input_ids": input_ids, 43 | # # "attention_mask": attn_masks, 44 | # # "labels": input_ids, 45 | # } 46 | 47 | # dataset = TLDRDataset(split="train") 48 | 49 | train_text_list = [] 50 | for sample in dataset['train']: 51 | train_text_list.append(sample['prompt'] + sample['label']) 52 | dataset['train'] = dataset['train'].add_column('text', train_text_list) # add the text column to the train dataset 53 | 54 | dataset['val'] = dataset.pop('valid') # rename the valid dataset to val 55 | 56 | val_text_list = [] 57 | for sample in dataset['val']: 58 | val_text_list.append(sample['prompt'] + sample['label']) 59 | dataset['val'] = dataset['val'].add_column('text', val_text_list) # add the text column to the val dataset 60 | 61 | dataset.pop('test') # remove the test dataset 62 | 63 | split_dataset = dataset 64 | 65 | # this results in: 66 | # >>> split_dataset 67 | # 
DatasetDict({ 68 | # train: Dataset({ 69 | # features: ['prompt', 'label', 'text'], 70 | # num_rows: 116722 71 | # }) 72 | # val: Dataset({ 73 | # features: ['prompt', 'label', 'text'], 74 | # num_rows: 6447 75 | # }) 76 | # }) 77 | 78 | # we now want to tokenize the dataset. first define the encoding function (gpt2 bpe) 79 | enc = tiktoken.get_encoding("gpt2") 80 | def process(example): 81 | ids = enc.encode_ordinary(example['text']) # encode_ordinary ignores any special tokens 82 | ids.append(enc.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe 83 | # note: I think eot should be prepended not appended... hmm. it's called "eot" though... 84 | out = {'ids': ids, 'len': len(ids)} 85 | return out 86 | 87 | # tokenize the dataset 88 | tokenized = split_dataset.map( 89 | process, 90 | remove_columns=['text','prompt','label'], 91 | desc="tokenizing the splits", 92 | num_proc=num_proc, 93 | ) 94 | 95 | # concatenate all the ids in each dataset into one large file we can use for training 96 | for split, dset in tokenized.items(): 97 | arr_len = np.sum(dset['len']) 98 | filename = os.path.join(os.path.dirname(__file__), f'{split}.bin') 99 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) 100 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,)) 101 | 102 | print(f"writing {filename}...") 103 | idx = 0 104 | for example in tqdm(dset): 105 | arr[idx : idx + example['len']] = example['ids'] 106 | idx += example['len'] 107 | arr.flush() 108 | 109 | # train.bin is ~17GB, val.bin ~8.5MB 110 | # train has ~9B tokens (9,035,582,198) 111 | # val has ~4M tokens (4,434,897) 112 | 113 | # to read the bin files later, e.g. with numpy: 114 | # m = np.memmap('train.bin', dtype=np.uint16, mode='r') 115 | -------------------------------------------------------------------------------- /data/openwebtext/prepare.py: -------------------------------------------------------------------------------- 1 | # saves the openwebtext dataset to a binary file for training. following was helpful: 2 | # https://github.com/HazyResearch/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py 3 | 4 | import os 5 | from tqdm import tqdm 6 | import numpy as np 7 | import tiktoken 8 | from datasets import load_dataset # huggingface datasets 9 | 10 | # number of workers in .map() call 11 | # good number to use is ~order number of cpu cores // 2 12 | num_proc = 8 13 | 14 | # takes 54GB in huggingface .cache dir, about 8M documents (8,013,769) 15 | dataset = load_dataset("openwebtext") 16 | 17 | # owt by default only contains the 'train' split, so create a test split 18 | split_dataset = dataset["train"].train_test_split(test_size=0.0005, seed=2357, shuffle=True) 19 | split_dataset['val'] = split_dataset.pop('test') # rename the test split to val 20 | 21 | # this results in: 22 | # >>> split_dataset 23 | # DatasetDict({ 24 | # train: Dataset({ 25 | # features: ['text'], 26 | # num_rows: 8009762 27 | # }) 28 | # val: Dataset({ 29 | # features: ['text'], 30 | # num_rows: 4007 31 | # }) 32 | # }) 33 | 34 | # we now want to tokenize the dataset. first define the encoding function (gpt2 bpe) 35 | enc = tiktoken.get_encoding("gpt2") 36 | def process(example): 37 | ids = enc.encode_ordinary(example['text']) # encode_ordinary ignores any special tokens 38 | ids.append(enc.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe 39 | # note: I think eot should be prepended not appended... hmm. it's called "eot" though... 
40 | out = {'ids': ids, 'len': len(ids)} 41 | return out 42 | 43 | # tokenize the dataset 44 | tokenized = split_dataset.map( 45 | process, 46 | remove_columns=['text'], 47 | desc="tokenizing the splits", 48 | num_proc=num_proc, 49 | ) 50 | 51 | # concatenate all the ids in each dataset into one large file we can use for training 52 | for split, dset in tokenized.items(): 53 | arr_len = np.sum(dset['len']) 54 | filename = os.path.join(os.path.dirname(__file__), f'{split}.bin') 55 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) 56 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,)) 57 | 58 | print(f"writing {filename}...") 59 | idx = 0 60 | for example in tqdm(dset): 61 | arr[idx : idx + example['len']] = example['ids'] 62 | idx += example['len'] 63 | arr.flush() 64 | 65 | # train.bin is ~17GB, val.bin ~8.5MB 66 | # train has ~9B tokens (9,035,582,198) 67 | # val has ~4M tokens (4,434,897) 68 | 69 | # to read the bin files later, e.g. with numpy: 70 | # m = np.memmap('train.bin', dtype=np.uint16, mode='r') 71 | -------------------------------------------------------------------------------- /data/openwebtext/readme.md: -------------------------------------------------------------------------------- 1 | 2 | ## openwebtext dataset 3 | 4 | after running `prepare.py` (preprocess) we get: 5 | 6 | - train.bin is ~17GB, val.bin ~8.5MB 7 | - train has ~9B tokens (9,035,582,198) 8 | - val has ~4M tokens (4,434,897) 9 | 10 | this came from 8,013,769 documents in total. 11 | 12 | references: 13 | 14 | - OpenAI's WebText dataset is discussed in [GPT-2 paper](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) 15 | - [OpenWebText](https://skylion007.github.io/OpenWebTextCorpus/) dataset 16 | -------------------------------------------------------------------------------- /data/shakespeare/prepare.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | import tiktoken 4 | import numpy as np 5 | 6 | # download the tiny shakespeare dataset 7 | input_file_path = os.path.join(os.path.dirname(__file__), 'input.txt') 8 | if not os.path.exists(input_file_path): 9 | data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt' 10 | with open(input_file_path, 'w') as f: 11 | f.write(requests.get(data_url).text) 12 | 13 | with open(input_file_path, 'r') as f: 14 | data = f.read() 15 | n = len(data) 16 | train_data = data[:int(n*0.9)] 17 | val_data = data[int(n*0.9):] 18 | 19 | # encode with tiktoken gpt2 bpe 20 | enc = tiktoken.get_encoding("gpt2") 21 | train_ids = enc.encode_ordinary(train_data) 22 | val_ids = enc.encode_ordinary(val_data) 23 | print(f"train has {len(train_ids):,} tokens") 24 | print(f"val has {len(val_ids):,} tokens") 25 | 26 | # export to bin files 27 | train_ids = np.array(train_ids, dtype=np.uint16) 28 | val_ids = np.array(val_ids, dtype=np.uint16) 29 | train_ids.tofile(os.path.join(os.path.dirname(__file__), 'train.bin')) 30 | val_ids.tofile(os.path.join(os.path.dirname(__file__), 'val.bin')) 31 | 32 | # train.bin has 301,966 tokens 33 | # val.bin has 36,059 tokens 34 | -------------------------------------------------------------------------------- /data/shakespeare/readme.md: -------------------------------------------------------------------------------- 1 | 2 | # tiny shakespeare 3 | 4 | Tiny shakespeare, of the good old char-rnn fame :) 5 | 6 | After running `prepare.py`: 
7 | 8 | - train.bin has 301,966 tokens 9 | - val.bin has 36,059 tokens 10 | -------------------------------------------------------------------------------- /data/shakespeare_char/prepare.py: -------------------------------------------------------------------------------- 1 | """ 2 | Prepare the Shakespeare dataset for character-level language modeling. 3 | So instead of encoding with GPT-2 BPE tokens, we just map characters to ints. 4 | Will save train.bin, val.bin containing the ids, and meta.pkl containing the 5 | encoder and decoder and some other related info. 6 | """ 7 | import os 8 | import pickle 9 | import requests 10 | import numpy as np 11 | 12 | # download the tiny shakespeare dataset 13 | input_file_path = os.path.join(os.path.dirname(__file__), 'input.txt') 14 | if not os.path.exists(input_file_path): 15 | data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt' 16 | with open(input_file_path, 'w') as f: 17 | f.write(requests.get(data_url).text) 18 | 19 | with open(input_file_path, 'r') as f: 20 | data = f.read() 21 | print(f"length of dataset in characters: {len(data):,}") 22 | 23 | # get all the unique characters that occur in this text 24 | chars = sorted(list(set(data))) 25 | vocab_size = len(chars) 26 | print("all the unique characters:", ''.join(chars)) 27 | print(f"vocab size: {vocab_size:,}") 28 | 29 | # create a mapping from characters to integers 30 | stoi = { ch:i for i,ch in enumerate(chars) } 31 | itos = { i:ch for i,ch in enumerate(chars) } 32 | def encode(s): 33 | return [stoi[c] for c in s] # encoder: take a string, output a list of integers 34 | def decode(l): 35 | return ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string 36 | 37 | # create the train and test splits 38 | n = len(data) 39 | train_data = data[:int(n*0.9)] 40 | val_data = data[int(n*0.9):] 41 | 42 | # encode both to integers 43 | train_ids = encode(train_data) 44 | val_ids = encode(val_data) 45 | print(f"train has {len(train_ids):,} tokens") 46 | print(f"val has {len(val_ids):,} tokens") 47 | 48 | # export to bin files 49 | train_ids = np.array(train_ids, dtype=np.uint16) 50 | val_ids = np.array(val_ids, dtype=np.uint16) 51 | train_ids.tofile(os.path.join(os.path.dirname(__file__), 'train.bin')) 52 | val_ids.tofile(os.path.join(os.path.dirname(__file__), 'val.bin')) 53 | 54 | # save the meta information as well, to help us encode/decode later 55 | meta = { 56 | 'vocab_size': vocab_size, 57 | 'itos': itos, 58 | 'stoi': stoi, 59 | } 60 | with open(os.path.join(os.path.dirname(__file__), 'meta.pkl'), 'wb') as f: 61 | pickle.dump(meta, f) 62 | 63 | # length of dataset in characters: 1115394 64 | # all the unique characters: 65 | # !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz 66 | # vocab size: 65 67 | # train has 1003854 tokens 68 | # val has 111540 tokens 69 | -------------------------------------------------------------------------------- /data/shakespeare_char/readme.md: -------------------------------------------------------------------------------- 1 | 2 | # tiny shakespeare, character-level 3 | 4 | Tiny shakespeare, of the good old char-rnn fame :) Treated at the character level.
5 | 6 | After running `prepare.py`: 7 | 8 | - train.bin has 1,003,854 tokens 9 | - val.bin has 111,540 tokens 10 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Full definition of a GPT Language Model, all of it in this single file. 3 | References: 4 | 1) the official GPT-2 TensorFlow implementation released by OpenAI: 5 | https://github.com/openai/gpt-2/blob/master/src/model.py 6 | 2) huggingface/transformers PyTorch implementation: 7 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py 8 | """ 9 | 10 | import math 11 | import inspect 12 | from dataclasses import dataclass 13 | 14 | import torch 15 | import torch.nn as nn 16 | from torch.nn import functional as F 17 | 18 | # @torch.jit.script # good to enable when not using torch.compile, disable when using (our default) 19 | def new_gelu(x): 20 | """ 21 | Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). 22 | Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415 23 | """ 24 | return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) 25 | 26 | class LayerNorm(nn.Module): 27 | """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """ 28 | 29 | def __init__(self, ndim, bias): 30 | super().__init__() 31 | self.weight = nn.Parameter(torch.ones(ndim)) 32 | self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None 33 | 34 | def forward(self, input): 35 | return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) 36 | 37 | class CausalSelfAttention(nn.Module): 38 | 39 | def __init__(self, config): 40 | super().__init__() 41 | assert config.n_embd % config.n_head == 0 42 | # key, query, value projections for all heads, but in a batch 43 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) 44 | # output projection 45 | self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) 46 | # regularization 47 | self.attn_dropout = nn.Dropout(config.dropout) 48 | self.resid_dropout = nn.Dropout(config.dropout) 49 | self.n_head = config.n_head 50 | self.n_embd = config.n_embd 51 | self.dropout = config.dropout 52 | # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary 53 | self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and self.dropout == 0.0 54 | if not self.flash: 55 | print("WARNING: using slow attention. 
Flash Attention atm needs PyTorch nightly and dropout=0.0") 56 | # causal mask to ensure that attention is only applied to the left in the input sequence 57 | self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) 58 | .view(1, 1, config.block_size, config.block_size)) 59 | 60 | def forward(self, x): 61 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 62 | 63 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 64 | q, k ,v = self.c_attn(x).split(self.n_embd, dim=2) 65 | k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 66 | q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 67 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 68 | 69 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 70 | if self.flash: 71 | # efficient attention using Flash Attention CUDA kernels 72 | y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True) 73 | else: 74 | # manual implementation of attention 75 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 76 | att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) 77 | att = F.softmax(att, dim=-1) 78 | att = self.attn_dropout(att) 79 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 80 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 81 | 82 | # output projection 83 | y = self.resid_dropout(self.c_proj(y)) 84 | return y 85 | 86 | class MLP(nn.Module): 87 | 88 | def __init__(self, config): 89 | super().__init__() 90 | self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) 91 | self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) 92 | self.dropout = nn.Dropout(config.dropout) 93 | 94 | def forward(self, x): 95 | x = self.c_fc(x) 96 | x = new_gelu(x) 97 | x = self.c_proj(x) 98 | x = self.dropout(x) 99 | return x 100 | 101 | class Block(nn.Module): 102 | 103 | def __init__(self, config): 104 | super().__init__() 105 | self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) 106 | self.attn = CausalSelfAttention(config) 107 | self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) 108 | self.mlp = MLP(config) 109 | 110 | def forward(self, x): 111 | x = x + self.attn(self.ln_1(x)) 112 | x = x + self.mlp(self.ln_2(x)) 113 | return x 114 | 115 | @dataclass 116 | class GPTConfig: 117 | block_size: int = 1024 118 | vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency 119 | n_layer: int = 12 120 | n_head: int = 12 121 | n_embd: int = 768 122 | dropout: float = 0.0 123 | bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. 
False: a bit better and faster 124 | 125 | class GPT(nn.Module): 126 | 127 | def __init__(self, config): 128 | super().__init__() 129 | assert config.vocab_size is not None 130 | assert config.block_size is not None 131 | self.config = config 132 | 133 | self.transformer = nn.ModuleDict(dict( 134 | wte = nn.Embedding(config.vocab_size, config.n_embd), 135 | wpe = nn.Embedding(config.block_size, config.n_embd), 136 | drop = nn.Dropout(config.dropout), 137 | h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), 138 | ln_f = LayerNorm(config.n_embd, bias=config.bias), 139 | )) 140 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 141 | # with weight tying when using torch.compile() some warnings get generated: 142 | # "UserWarning: functional_call was passed multiple values for tied weights. 143 | # This behavior is deprecated and will be an error in future versions" 144 | # not 100% sure what this is, so far seems to be harmless. TODO investigate 145 | self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying 146 | 147 | # init all weights 148 | self.apply(self._init_weights) 149 | # apply special scaled init to the residual projections, per GPT-2 paper 150 | for pn, p in self.named_parameters(): 151 | if pn.endswith('c_proj.weight'): 152 | torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer)) 153 | 154 | # report number of parameters 155 | print("number of parameters: %.2fM" % (self.get_num_params()/1e6,)) 156 | 157 | def get_num_params(self, non_embedding=True): 158 | """ 159 | Return the number of parameters in the model. 160 | For non-embedding count (default), the position embeddings get subtracted. 161 | The token embeddings would too, except due to the parameter sharing these 162 | params are actually used as weights in the final layer, so we include them. 
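        As a rough point of reference (not computed by this function), with the default
        GPTConfig above the position embeddings account for block_size * n_embd =
        1024 * 768 = 786,432 (~0.79M) parameters, which is what gets subtracted here.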
163 | """ 164 | n_params = sum(p.numel() for p in self.parameters()) 165 | if non_embedding: 166 | n_params -= self.transformer.wpe.weight.numel() 167 | return n_params 168 | 169 | def _init_weights(self, module): 170 | if isinstance(module, nn.Linear): 171 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 172 | if module.bias is not None: 173 | torch.nn.init.zeros_(module.bias) 174 | elif isinstance(module, nn.Embedding): 175 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 176 | 177 | def forward(self, idx, targets=None): 178 | device = idx.device 179 | b, t = idx.size() 180 | assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" 181 | pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t) 182 | 183 | # forward the GPT model itself 184 | tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 185 | pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd) 186 | x = self.transformer.drop(tok_emb + pos_emb) 187 | for block in self.transformer.h: 188 | x = block(x) 189 | x = self.transformer.ln_f(x) 190 | 191 | if targets is not None: 192 | # if we are given some desired targets also calculate the loss 193 | logits = self.lm_head(x) 194 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) 195 | else: 196 | # inference-time mini-optimization: only forward the lm_head on the very last position 197 | logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim 198 | loss = None 199 | 200 | return logits, loss 201 | 202 | def crop_block_size(self, block_size): 203 | # model surgery to decrease the block size if necessary 204 | # e.g. 
we may load the GPT2 pretrained model checkpoint (block size 1024) 205 | # but want to use a smaller block size for some smaller, simpler model 206 | assert block_size <= self.config.block_size 207 | self.config.block_size = block_size 208 | self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size]) 209 | for block in self.transformer.h: 210 | block.attn.bias = block.attn.bias[:,:,:block_size,:block_size] 211 | 212 | @classmethod 213 | def from_pretrained(cls, model_type, override_args=None): 214 | assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'} 215 | override_args = override_args or {} # default to empty dict 216 | # only dropout can be overridden see more notes below 217 | assert all(k == 'dropout' for k in override_args) 218 | from transformers import GPT2LMHeadModel 219 | print("loading weights from pretrained gpt: %s" % model_type) 220 | 221 | # n_layer, n_head and n_embd are determined from model_type 222 | config_args = { 223 | 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params 224 | 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params 225 | 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params 226 | 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params 227 | }[model_type] 228 | print("forcing vocab_size=50257, block_size=1024, bias=True") 229 | config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints 230 | config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints 231 | config_args['bias'] = True # always True for GPT model checkpoints 232 | # we can override the dropout rate, if desired 233 | if 'dropout' in override_args: 234 | print(f"overriding dropout rate to {override_args['dropout']}") 235 | config_args['dropout'] = override_args['dropout'] 236 | # create a from-scratch initialized minGPT model 237 | config = GPTConfig(**config_args) 238 | model = GPT(config) 239 | sd = model.state_dict() 240 | sd_keys = sd.keys() 241 | sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param 242 | 243 | # init a huggingface/transformers model 244 | model_hf = GPT2LMHeadModel.from_pretrained(model_type) 245 | sd_hf = model_hf.state_dict() 246 | 247 | # copy while ensuring all of the parameters are aligned and match in names and shapes 248 | sd_keys_hf = sd_hf.keys() 249 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer 250 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer) 251 | transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight'] 252 | # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear 253 | # this means that we have to transpose these weights when we import them 254 | assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}" 255 | for k in sd_keys_hf: 256 | if any(k.endswith(w) for w in transposed): 257 | # special treatment for the Conv1D weights we need to transpose 258 | assert sd_hf[k].shape[::-1] == sd[k].shape 259 | with torch.no_grad(): 260 | sd[k].copy_(sd_hf[k].t()) 261 | else: 262 | # vanilla copy over the other parameters 263 | assert sd_hf[k].shape == sd[k].shape 264 | with torch.no_grad(): 265 | sd[k].copy_(sd_hf[k]) 266 | 267 | return model 268 | 269 | def configure_optimizers(self, weight_decay, learning_rate, betas, device_type): 270 | """ 271 | This 
long function is unfortunately doing something very simple and is being very defensive: 272 | We are separating out all parameters of the model into two buckets: those that will experience 273 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights). 274 | We are then returning the PyTorch optimizer object. 275 | """ 276 | 277 | # separate out all parameters to those that will and won't experience regularizing weight decay 278 | decay = set() 279 | no_decay = set() 280 | whitelist_weight_modules = (torch.nn.Linear, ) 281 | blacklist_weight_modules = (torch.nn.LayerNorm, LayerNorm, torch.nn.Embedding) 282 | for mn, m in self.named_modules(): 283 | for pn, p in m.named_parameters(): 284 | fpn = '%s.%s' % (mn, pn) if mn else pn # full param name 285 | # random note: because named_modules and named_parameters are recursive 286 | # we will see the same tensors p many many times. but doing it this way 287 | # allows us to know which parent module any tensor p belongs to... 288 | if pn.endswith('bias'): 289 | # all biases will not be decayed 290 | no_decay.add(fpn) 291 | elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): 292 | # weights of whitelist modules will be weight decayed 293 | decay.add(fpn) 294 | elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): 295 | # weights of blacklist modules will NOT be weight decayed 296 | no_decay.add(fpn) 297 | 298 | # subtle: 'transformer.wte.weight' and 'lm_head.weight' are tied, so they 299 | # will appear in the no_decay and decay sets respectively after the above. 300 | # In addition, because named_parameters() doesn't return duplicates, it 301 | # will only return the first occurence, key'd by 'transformer.wte.weight', below. 302 | # so let's manually remove 'lm_head.weight' from decay set. This will include 303 | # this tensor into optimization via transformer.wte.weight only, and not decayed. 304 | decay.remove('lm_head.weight') 305 | 306 | # validate that we considered every parameter 307 | param_dict = {pn: p for pn, p in self.named_parameters()} 308 | inter_params = decay & no_decay 309 | union_params = decay | no_decay 310 | assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) 311 | assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ 312 | % (str(param_dict.keys() - union_params), ) 313 | 314 | # create the pytorch optimizer object 315 | optim_groups = [ 316 | {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": weight_decay}, 317 | {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, 318 | ] 319 | # new PyTorch nightly has a new 'fused' option for AdamW that is much faster 320 | use_fused = (device_type == 'cuda') and ('fused' in inspect.signature(torch.optim.AdamW).parameters) 321 | print(f"using fused AdamW: {use_fused}") 322 | extra_args = dict(fused=True) if use_fused else dict() 323 | optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args) 324 | 325 | return optimizer 326 | 327 | def estimate_mfu(self, fwdbwd_per_iter, dt): 328 | """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """ 329 | # first estimate the number of flops we do per iteration. 
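        # back-of-the-envelope illustration (not computed here), assuming the default
        # GPT-2 124M config with N ~ 124e6, L = 12, H = 12, Q = 64, T = 1024:
        #   6*N          ~ 0.74e9 FLOPs per token (parameter term)
        #   12*L*H*Q*T   ~ 0.11e9 FLOPs per token (attention term)
        # so flops_per_fwdbwd ~ 0.85e9 * 1024 ~ 0.9e12 FLOPs per sequence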
330 | # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311 331 | N = self.get_num_params() 332 | cfg = self.config 333 | L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size 334 | flops_per_token = 6*N + 12*L*H*Q*T 335 | flops_per_fwdbwd = flops_per_token * T 336 | flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter 337 | # express our flops throughput as ratio of A100 bfloat16 peak flops 338 | flops_achieved = flops_per_iter * (1.0/dt) # per second 339 | flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS 340 | mfu = flops_achieved / flops_promised 341 | return mfu 342 | 343 | @torch.no_grad() 344 | def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): 345 | """ 346 | Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete 347 | the sequence max_new_tokens times, feeding the predictions back into the model each time. 348 | Most likely you'll want to make sure to be in model.eval() mode of operation for this. 349 | """ 350 | for _ in range(max_new_tokens): 351 | # if the sequence context is growing too long we must crop it at block_size 352 | idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:] 353 | # forward the model to get the logits for the index in the sequence 354 | logits, _ = self(idx_cond) 355 | # pluck the logits at the final step and scale by desired temperature 356 | logits = logits[:, -1, :] / temperature 357 | # optionally crop the logits to only the top k options 358 | if top_k is not None: 359 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 360 | logits[logits < v[:, [-1]]] = -float('Inf') 361 | # apply softmax to convert logits to (normalized) probabilities 362 | probs = F.softmax(logits, dim=-1) 363 | # sample from the distribution 364 | idx_next = torch.multinomial(probs, num_samples=1) 365 | # append sampled index to the running sequence and continue 366 | idx = torch.cat((idx, idx_next), dim=1) 367 | 368 | return idx 369 | 370 | class RLHF(nn.Module): 371 | def __init__(self, model, mode, discrete_reward=False): 372 | super().__init__() 373 | self.model = model 374 | self.config = model.config 375 | 376 | # reward model 377 | self.n_embd = model.lm_head.in_features 378 | self.block_size = model.config.block_size 379 | model.policy_head = nn.Linear(model.lm_head.in_features, model.lm_head.out_features, bias=False) 380 | self.mode = mode 381 | self.discrete_reward = discrete_reward 382 | if discrete_reward: 383 | model.reward_head = nn.Linear(model.lm_head.in_features, 2, bias=False) 384 | else: 385 | model.reward_head = nn.Linear(self.n_embd*self.block_size, 1, bias=False) 386 | 387 | def forward_reward(self, idx, targets=None): 388 | device = idx.device 389 | b, t = idx.size() 390 | assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" 391 | pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t) 392 | 393 | # forward the GPT model itself 394 | tok_emb = self.model.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 395 | pos_emb = self.model.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd) 396 | x = self.model.transformer.drop(tok_emb + pos_emb) 397 | for block in self.model.transformer.h: 398 | x = block(x) 399 | x = self.model.transformer.ln_f(x) 400 | 401 | rewards = self.model.reward_head(x[:, -1, :]) 402 | 403 | if self.discrete_reward: 404 | probs = torch.softmax(rewards,1) 405 | if 
targets is not None: 406 | # if we are given some desired targets also calculate the loss 407 | loss = F.cross_entropy(probs, targets, ignore_index=-1) 408 | else: 409 | loss = None 410 | return probs, loss 411 | else: 412 | return rewards 413 | 414 | def forward(self, idx, targets=None): 415 | if self.mode == 'reward': 416 | return self.forward_reward(idx, targets) 417 | else: 418 | return self.model(idx, targets) 419 | 420 | def generate(self, idx, max_new_tokens, device, block_size, use_reference=True, reward_model=None, hard_code_reward=True): 421 | # idx is (B, T) array of indices in the current context 422 | log_probs = torch.tensor([]).to(device) 423 | log_probs_ref = torch.tensor([]).to(device) 424 | values = torch.tensor([]).to(device) 425 | 426 | idx_cond_all = torch.zeros((idx.shape[0], block_size, max_new_tokens)).to(device) 427 | values_all = torch.zeros((idx.shape[0], max_new_tokens)).to(device) 428 | actions_all = torch.zeros((idx.shape[0], max_new_tokens)).to(device) 429 | rewards_all = torch.zeros((idx.shape[0],)).to(device) 430 | log_probs_all = torch.zeros((idx.shape[0], max_new_tokens)).to(device) 431 | advantages_all = torch.zeros((idx.shape[0], max_new_tokens)).to(device) 432 | returns_all = torch.zeros((idx.shape[0], max_new_tokens)).to(device) 433 | gamma = 1 434 | lam = 1 435 | 436 | # TODO: Critic, PPO 437 | for i in range(max_new_tokens): 438 | # crop idx to the last block_size tokens 439 | # block_size = 256 440 | idx_cond = idx[:, -block_size:] 441 | 442 | # get the predictions 443 | logits, loss = self(idx_cond) 444 | 445 | # focus only on the last time step 446 | logits = logits[:, -1, :] # becomes (B, C) 447 | # apply softmax to get probabilities 448 | 449 | probs_next = F.softmax(logits, dim=-1) # (B, C) 450 | # sample from the distribution 451 | idx_next = torch.multinomial(probs_next, num_samples=1) # (B, 1) 452 | 453 | probs_idx_next = torch.gather(probs_next, 1, idx_next) 454 | log_probs_idx_next = torch.log(probs_idx_next) 455 | log_probs = torch.cat((log_probs, log_probs_idx_next), dim=1) 456 | 457 | if use_reference: 458 | logits_ref, _ = self.model(idx_cond) 459 | logits_ref = logits_ref[:, -1, :] # becomes (B, C) 460 | probs_ref_next = F.softmax(logits_ref, dim=-1) # (B, C) 461 | probs_ref_idx_next = torch.gather(probs_ref_next, 1, idx_next) 462 | log_probs_ref_idx_next = torch.log(probs_ref_idx_next) 463 | log_probs_ref = torch.cat((log_probs_ref, log_probs_ref_idx_next), dim=1) 464 | 465 | # append sampled index to the running sequence 466 | idx = torch.cat((idx, idx_next), dim=1) # (B, T+1) 467 | 468 | 469 | if i == max_new_tokens-1: 470 | states = idx[:,-max_new_tokens:] 471 | if hard_code_reward: 472 | # simple test where reward for outputting the letter 'z' (89) 473 | rewards = torch.zeros_like(states, dtype=torch.float16) 474 | rewards[states==89] = 1.0 475 | rewards = torch.sum(rewards, 1, keepdim=True) 476 | rewards[rewards > 1] = 1 477 | 478 | else: 479 | if self.discrete_reward: 480 | rewards = reward_model.forward_reward(torch.tensor(states))[0][:,1].unsqueeze(-1) 481 | else: 482 | rewards = reward_model.forward_reward(torch.tensor(states)) 483 | 484 | 485 | for t in reversed(range(max_new_tokens)): 486 | if t == max_new_tokens - 1: 487 | # value at last state is 0 488 | delta = rewards[:].squeeze() - values_all[:, t] 489 | advantages_all[:, t] = delta 490 | # returns_all[:, t] = rewards[:] 491 | else: 492 | # rewards can only be non-zero at the last state 493 | delta = gamma * values_all[:, t + 1] - values_all[:, t] 494 | 
advantages_all[:, t] = delta + gamma * lam * advantages_all[:, t + 1] 495 | # returns_all[:, t] += gamma * returns_all[:, t + 1] 496 | 497 | 498 | 499 | return idx, log_probs[:,-max_new_tokens:], log_probs_ref[:,-max_new_tokens:], rewards, advantages_all 500 | 501 | def generate_gumbel(self, idx, max_new_tokens, device, block_size, reward_model, use_reference=True): 502 | 503 | onehot = torch.tensor([]).to(device) 504 | for i in range(max_new_tokens): 505 | # crop idx to the last block_size tokens 506 | # block_size = 256 507 | idx_cond = idx[:, -block_size:] 508 | 509 | # get the predictions 510 | logits, loss = self(idx_cond) 511 | 512 | # focus only on the last time step 513 | logits = logits[:, -1, :] # becomes (B, C) 514 | 515 | 516 | #gumbel sample 517 | idx_next, onehot_next = self.gumbel_softmax(logits, tau=1, device=idx.device) 518 | 519 | # append sampled index to the running sequence 520 | idx = torch.cat((idx, idx_next), dim=1) # (B, T+1) 521 | 522 | onehot = torch.cat((onehot, onehot_next), dim=2) # (B, T+1) 523 | 524 | if i == max_new_tokens-1: 525 | if self.discrete_reward: 526 | rewards = reward_model.forward_reward_gumbel(onehot)[0][:,1].unsqueeze(-1) 527 | else: 528 | rewards = reward_model.forward_reward_gumbel(onehot) 529 | 530 | return idx[:,-max_new_tokens:], rewards 531 | 532 | 533 | # modified for PyTorch from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb 534 | def sample_gumbel(self, shape, eps=1e-20): 535 | """Sample from Gumbel(0, 1)""" 536 | U = torch.distributions.Uniform(0,1).sample(shape) 537 | return -torch.log(-torch.log(U + eps) + eps) 538 | 539 | # modified for PyTorch from https://github.com/ericjang/gumbel-softmax/blob/master/Categorical%20VAE.ipynb 540 | def gumbel_softmax_sample(self, logits, tau, device, dim=1): 541 | """ Draw a sample from the Gumbel-Softmax distribution""" 542 | y = logits + self.sample_gumbel(logits.shape).to(device) 543 | return F.softmax(y / tau, dim=dim) 544 | 545 | def gumbel_softmax(self, logits, tau, device): 546 | gumbel_sample = self.gumbel_softmax_sample(logits, tau, device) 547 | 548 | # Alternatively could try 549 | # probs = F.softmax(gumbel_sample, dim=-1) 550 | # idx_next = torch.multinomial(probs, num_samples=1) 551 | 552 | idx_next = gumbel_sample.max(-1, keepdim=True)[1] 553 | onehot_idx_next = torch.nn.functional.one_hot(idx_next, num_classes=logits.shape[1]).squeeze() 554 | y = (onehot_idx_next-gumbel_sample).detach() + gumbel_sample 555 | idx_next_from_y = torch.argmax(y, dim=1).unsqueeze(-1) 556 | return idx_next_from_y, y.unsqueeze(-1) 557 | 558 | def forward_reward_gumbel(self, onehots, idx=None, targets=None): 559 | # (embd, vocab) @ (vocab, embd) = (embd,embd) 560 | 561 | device = onehots.device 562 | t = onehots.shape[2] 563 | b = onehots.shape[0] 564 | # b, t = idx.size() 565 | # assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" 566 | pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t) 567 | 568 | # forward the GPT model itself 569 | # tok_emb = self.model.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 570 | tok_emb = (self.model.transformer.wte.weight.T @ onehots) 571 | tok_emb = torch.transpose(tok_emb, 1, 2) 572 | 573 | if idx is not None: 574 | tok_emb2 = self.model.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 575 | assert torch.all(tok_emb == tok_emb2) 576 | pos_emb = self.model.transformer.wpe(pos) # position embeddings of 
shape (1, t, n_embd) 577 | x = self.model.transformer.drop(tok_emb + pos_emb) 578 | for block in self.model.transformer.h: 579 | x = block(x) 580 | x = self.model.transformer.ln_f(x) 581 | 582 | rewards = self.model.reward_head(x[:, -1, :]) 583 | 584 | if self.discrete_reward: 585 | probs = torch.softmax(rewards,1) 586 | if targets is not None: 587 | # if we are given some desired targets also calculate the loss 588 | loss = F.cross_entropy(rewards, targets, ignore_index=-1) 589 | else: 590 | loss = None 591 | return probs, loss 592 | else: 593 | return rewards -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tiktoken 2 | numpy 3 | contextlib 4 | torch 5 | wandb 6 | -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Sample from a trained model 3 | """ 4 | import os 5 | import pickle 6 | from contextlib import nullcontext 7 | import torch 8 | import tiktoken 9 | from model import GPTConfig, GPT 10 | 11 | # ----------------------------------------------------------------------------- 12 | init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl') 13 | out_dir = 'out' # ignored if init_from is not 'resume' 14 | start = "TITLE" # or "<|endoftext|>" or etc. Can also specify a file, use as: "FILE:prompt.txt" 15 | num_samples = 5 # number of samples to draw 16 | max_new_tokens = 500 # number of tokens generated in each sample 17 | temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions 18 | top_k = 200 # retain only the top_k most likely tokens, clamp others to have 0 probability 19 | seed = 1337 20 | device = 'cpu' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc. 21 | dtype = 'float16' # 'float32' or 'bfloat16' or 'float16' 22 | compile = False # use PyTorch 2.0 to compile the model to be faster 23 | exec(open('configurator.py').read()) # overrides from command line or config file 24 | # ----------------------------------------------------------------------------- 25 | 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed(seed) 28 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 29 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 30 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast 31 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 32 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype) 33 | 34 | # model 35 | if init_from == 'resume': 36 | # init from a model saved in a specific directory 37 | ckpt_path = os.path.join(out_dir, 'ckpt.pt') 38 | checkpoint = torch.load(ckpt_path, map_location=device) 39 | gptconf = GPTConfig(**checkpoint['model_args']) 40 | model = GPT(gptconf) 41 | state_dict = checkpoint['model'] 42 | unwanted_prefix = '_orig_mod.' 
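    # checkpoints saved from a torch.compile()'d model typically carry this
    # '_orig_mod.' prefix on their state_dict keys; strip it so the weights
    # load into the plain, uncompiled GPT module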
43 | for k,v in list(state_dict.items()): 44 | if k.startswith(unwanted_prefix): 45 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 46 | model.load_state_dict(state_dict) 47 | elif init_from.startswith('gpt2'): 48 | # init from a given GPT-2 model 49 | model = GPT.from_pretrained(init_from, dict(dropout=0.0)) 50 | 51 | model.eval() 52 | model.to(device) 53 | if compile: 54 | model = torch.compile(model) # requires PyTorch 2.0 (optional) 55 | 56 | # look for the meta pickle in case it is available in the dataset folder 57 | load_meta = False 58 | if init_from == 'resume' and 'config' in checkpoint and 'dataset' in checkpoint['config']: # older checkpoints might not have these... 59 | meta_path = os.path.join('data', checkpoint['config']['dataset'], 'meta.pkl') 60 | load_meta = os.path.exists(meta_path) 61 | if load_meta: 62 | print(f"Loading meta from {meta_path}...") 63 | with open(meta_path, 'rb') as f: 64 | meta = pickle.load(f) 65 | # TODO want to make this more general to arbitrary encoder/decoder schemes 66 | stoi, itos = meta['stoi'], meta['itos'] 67 | encode = lambda s: [stoi[c] for c in s] 68 | decode = lambda l: ''.join([itos[i] for i in l]) 69 | else: 70 | # ok let's assume gpt-2 encodings by default 71 | print("No meta.pkl found, assuming GPT-2 encodings...") 72 | enc = tiktoken.get_encoding("gpt2") 73 | encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"}) 74 | decode = lambda l: enc.decode(l) 75 | 76 | # encode the beginning of the prompt 77 | if start.startswith('FILE:'): 78 | with open(start[5:], 'r', encoding='utf-8') as f: 79 | start = f.read() 80 | start_ids = encode(start) 81 | x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]) 82 | 83 | # run generation 84 | with torch.no_grad(): 85 | with ctx: 86 | for k in range(num_samples): 87 | y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k) 88 | print(decode(y[0].tolist())) 89 | print('---------------') 90 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | This training script can be run both on a single gpu in debug mode, 3 | and also in a larger training run with distributed data parallel (ddp). 
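Note: in this repository the hyperparameters are read from config/config.yaml
(see the Trainer construction at the bottom of this file), so the --flag overrides
shown in the examples below may not take effect.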
4 | 5 | To run on a single GPU, example: 6 | $ python train.py --batch_size=32 --compile=False 7 | 8 | To run with DDP on 4 gpus on 1 node, example: 9 | $ torchrun --standalone --nproc_per_node=4 train.py 10 | 11 | To run with DDP on 4 gpus across 2 nodes, example: 12 | - Run on the first (master) node with example IP 123.456.123.456: 13 | $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py 14 | - Run on the worker node: 15 | $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py 16 | (If your cluster does not have Infiniband interconnect prepend NCCL_IB_DISABLE=1) 17 | """ 18 | 19 | import os 20 | import time 21 | import math 22 | import pickle 23 | from contextlib import nullcontext 24 | 25 | import numpy as np 26 | import torch 27 | 28 | from model import GPTConfig, GPT 29 | import yaml 30 | from torch.nn.parallel import DistributedDataParallel as DDP 31 | from trainers.trainer import Trainer 32 | 33 | # load config.yaml from current directory 34 | with open('config/config.yaml') as f: 35 | conf = yaml.load(f, Loader=yaml.FullLoader) 36 | # nested dictionary structure 37 | config = {} 38 | for k, v in conf.items(): 39 | for k2, v2 in v.items(): 40 | config[k2] = v2 41 | # convert to dotdict 42 | print(config) 43 | trainer = Trainer(config) 44 | 45 | trainer.train() -------------------------------------------------------------------------------- /train_reward_model.py: -------------------------------------------------------------------------------- 1 | 2 | from trainers.reward_trainer import RewardModelTrainer 3 | import yaml 4 | 5 | from datasets import load_dataset 6 | from torch.utils.data import Dataset 7 | from tqdm import tqdm 8 | import tiktoken 9 | import torch 10 | 11 | 12 | # with inspiration from CarperAI's trlx library 13 | 14 | with open('config/config_reward.yaml') as f: 15 | conf = yaml.load(f, Loader=yaml.FullLoader) 16 | # nested dictionary structure 17 | config = {} 18 | for k, v in conf.items(): 19 | for k2, v2 in v.items(): 20 | config[k2] = v2 21 | print(config) 22 | 23 | def create_comparison_dataset(path="CarperAI/openai_summarize_comparisons", split="train"): 24 | dataset = load_dataset(path, split=split) 25 | pairs = [] 26 | for sample in tqdm(dataset): 27 | pair = {} 28 | prompt = sample["prompt"] 29 | chosen_summary = sample["chosen"] 30 | rejected_summary = sample["rejected"] 31 | if chosen_summary == rejected_summary: 32 | continue 33 | if len(chosen_summary.split()) < 5 or len(rejected_summary.split()) < 5: 34 | continue 35 | pair["chosen"] = prompt + "\n" + chosen_summary 36 | pair["rejected"] = prompt + "\n" + rejected_summary 37 | pairs.append(pair) 38 | return pairs 39 | 40 | 41 | class PairwiseDataset(Dataset): 42 | def __init__(self, pairs, max_length): 43 | # self.chosen_input_ids = [] 44 | # self.chosen_attn_masks = [] 45 | # self.rejected_input_ids = [] 46 | # self.rejected_attn_masks = [] 47 | self.chosens = [] 48 | self.rejecteds = [] 49 | self.enc = tiktoken.get_encoding("gpt2") 50 | for pair in tqdm(pairs): 51 | chosen_enc = self.enc.encode("<|startoftext|>" + pair['chosen'] + "<|endoftext|>", allowed_special="all")[-max_length:] 52 | rejected_enc = self.enc.encode("<|startoftext|>" + pair['rejected'] + "<|endoftext|>", allowed_special="all")[-max_length:] 53 | self.chosens.append(chosen_enc) 54 | self.rejecteds.append(rejected_enc) 55 | 56 | def __len__(self): 57 | return len(self.chosens) 58 | 59 | def __getitem__(self, idx): 
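        # each item is the pre-tokenized (chosen, rejected) pair of token-id
        # lists built in __init__ above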
60 | return ( 61 | self.chosens[idx], 62 | self.rejecteds[idx], 63 | ) 64 | 65 | class DataCollatorReward: 66 | def __call__(self, data): 67 | batch = {} 68 | batch["chosen_ids"] = torch.tensor([f[0] for f in data]) 69 | batch["rejected_ids"] = torch.tensor([f[1] for f in data]) 70 | return batch 71 | 72 | def collate_fn(data): 73 | batch = {} 74 | batch["chosen_ids"] = torch.tensor([f[0] for f in data]) 75 | batch["rejected_ids"] = torch.tensor([f[1] for f in data]) 76 | return batch 77 | 78 | data_path = "CarperAI/openai_summarize_comparisons" 79 | train_pairs = create_comparison_dataset(data_path, "train") 80 | val_pairs = create_comparison_dataset(data_path, "test") 81 | 82 | 83 | # Make pairwise datasets for training 84 | print("Creating pairwise datasets") 85 | train_dataset = PairwiseDataset(train_pairs, max_length=config['block_size']) 86 | val_dataset = PairwiseDataset(val_pairs, max_length=config['block_size']) 87 | 88 | trainer = RewardModelTrainer(config, train_dataset, val_dataset, collate_fn=collate_fn) 89 | 90 | trainer.train() 91 | 92 | -------------------------------------------------------------------------------- /train_reward_model_simple.py: -------------------------------------------------------------------------------- 1 | 2 | from trainers.reward_trainer import ProbRewardModelTrainer 3 | import yaml 4 | 5 | from tqdm import tqdm 6 | import tiktoken 7 | import torch 8 | 9 | with open('config/config_reward.yaml') as f: 10 | conf = yaml.load(f, Loader=yaml.FullLoader) 11 | # nested dictionary structure 12 | config = {} 13 | for k, v in conf.items(): 14 | for k2, v2 in v.items(): 15 | config[k2] = v2 16 | print(config) 17 | 18 | trainer = ProbRewardModelTrainer(config, discrete_reward=True) 19 | 20 | trainer.train() 21 | 22 | -------------------------------------------------------------------------------- /train_rl.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from torch.nn.parallel import DistributedDataParallel as DDP 3 | from trainers.rl_trainer import PolicyGradientTrainer, GumbelTrainer 4 | 5 | # load config.yaml from current directory 6 | with open('config/config_rl.yaml') as f: 7 | conf = yaml.load(f, Loader=yaml.FullLoader) 8 | # nested dictionary structure 9 | config = {} 10 | for k, v in conf.items(): 11 | for k2, v2 in v.items(): 12 | config[k2] = v2 13 | # convert to dotdict 14 | 15 | if config['method'] == 'gumbel': 16 | print('Using Gumbel method') 17 | assert config['hard_code_reward'] == False, 'hard_code_reward must be False for Gumbel method' 18 | trainer = GumbelTrainer(config) 19 | elif config['method'] == 'pg': 20 | print('Using Policy Gradient method') 21 | trainer = PolicyGradientTrainer(config) 22 | else: 23 | raise NotImplementedError 24 | 25 | trainer.train() -------------------------------------------------------------------------------- /trainers/reward_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.nn.parallel import DistributedDataParallel as DDP 4 | from torch.distributed import destroy_process_group 5 | import wandb 6 | import time, os 7 | from model import RLHF 8 | from trainers.trainer import Trainer 9 | 10 | # This one for reward models similar to InstructGPT paper (rewards based on comparisons) 11 | class RewardModelTrainer(Trainer): 12 | def __init__(self, config, train_data, val_data, collate_fn): 13 | super().__init__(config) 14 | import tiktoken 15 | self.enc = 
tiktoken.get_encoding("gpt2") 16 | self.mode = 'reward' 17 | from torch.utils.data import DataLoader 18 | train_dataloader = DataLoader(train_data, batch_size=self.batch_size, shuffle=True, collate_fn=collate_fn) 19 | val_dataloader = DataLoader(val_data, batch_size=self.batch_size, shuffle=True, collate_fn=collate_fn) 20 | self.train_dataloader = train_dataloader 21 | self.val_dataloader = val_dataloader 22 | 23 | 24 | def get_batch(self, split): 25 | dataloader = self.train_dataloader if split == 'train' else self.val_dataloader 26 | batch = next(iter(dataloader)) 27 | x, y = batch['chosen_ids'], batch['rejected_ids'] 28 | if self.device_type == 'cuda': 29 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True) 30 | x, y = x.pin_memory().to(self.device, non_blocking=True), y.pin_memory().to(self.device, non_blocking=True) 31 | else: 32 | x, y = x.to(self.device), y.to(self.device) 33 | return x, y 34 | 35 | # helps estimate an arbitrarily accurate loss over either split using many batches 36 | @torch.no_grad() 37 | def estimate_loss(self, model, ctx): 38 | out = {} 39 | model.eval() 40 | for split in ['train', 'val']: 41 | losses = torch.zeros(self.eval_iters) 42 | for k in range(self.eval_iters): 43 | chosen, rejected = self.get_batch(split) 44 | with ctx: 45 | reward_chosen = model(chosen) 46 | reward_rejected = model(rejected) 47 | loss = -torch.log(torch.sigmoid(reward_chosen - reward_rejected)).mean() 48 | losses[k] = loss.item() 49 | out[split] = losses.mean() 50 | model.train() 51 | return out 52 | 53 | def evaluate(self, model, ctx): 54 | losses = self.estimate_loss(model, ctx) 55 | print(f"step {self.iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}") 56 | 57 | if self.wandb_log: 58 | wandb.log({ 59 | "iter": self.iter_num, 60 | "train/loss": losses['train'], 61 | "val/loss": losses['val'], 62 | "lr": self.lr, 63 | "mfu": self.running_mfu*100, # convert to percentage 64 | }) 65 | if losses['val'] < self.best_val_loss or self.always_save_checkpoint: 66 | self.best_val_loss = losses['val'] 67 | raw_model = model.module if self.ddp else model 68 | if self.iter_num > 0: 69 | checkpoint = { 70 | 'model': raw_model.state_dict(), 71 | 'optimizer': self.optimizer.state_dict(), 72 | 'model_args': self.model_args, 73 | 'iter_num': self.iter_num, 74 | 'best_val_loss': self.best_val_loss, 75 | 'config': self.config, 76 | } 77 | print(f"saving checkpoint to {self.config['out_dir_multihead']}") 78 | torch.save(checkpoint, os.path.join(self.config['out_dir_multihead'], 'ckpt.pt')) 79 | 80 | def train(self): 81 | # set up distributed training 82 | self.setup_ddp() 83 | 84 | ctx, meta_vocab_size = self.setup() 85 | 86 | # model init 87 | 88 | 89 | model = self.init_model() 90 | model = RLHF(model, self.mode) 91 | print('Config of model: ', model.config) 92 | 93 | if self.config['init_multihead_from'] == 'scratch': 94 | print("initializing multihead from scratch") 95 | else: 96 | if self.config['init_multihead_from'] == 'resume': 97 | print(f"Resuming training from {self.config['out_dir_multihead']}") 98 | # resume training from a checkpoint. 99 | ckpt_path = os.path.join(self.config['out_dir_multihead'], 'ckpt.pt') 100 | checkpoint = torch.load(ckpt_path, map_location=self.device) 101 | state_dict = checkpoint['model'] 102 | # fix the keys of the state dictionary :( 103 | # honestly no idea how checkpoints sometimes get this prefix, have to debug more 104 | unwanted_prefix = '_orig_mod.' 
105 | for k,v in list(state_dict.items()): 106 | if k.startswith(unwanted_prefix): 107 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 108 | model.load_state_dict(state_dict) 109 | 110 | 111 | model.to(self.device) 112 | 113 | # self.optimizer = torch.optim.AdamW(model.model.reward_head.parameters(), lr=1e-3) 114 | self.optimizer = torch.optim.AdamW(model.model.parameters(), lr=1e-4) 115 | 116 | model = self.setup_model(model) 117 | 118 | # logging 119 | if self.wandb_log and self.master_process: 120 | wandb.init(project=self.wandb_project, name=self.wandb_run_name, config=self.config) 121 | 122 | # training loop 123 | chosen, rejected = self.get_batch('train') # fetch the very first batch 124 | t0 = time.time() 125 | local_iter_num = 0 # number of iterations in the lifetime of this process 126 | self.running_mfu = -1.0 127 | loss = None 128 | while True: 129 | 130 | # determine and set the learning rate for this iteration 131 | lr = self.get_lr(self.iter_num) if self.decay_lr else self.learning_rate 132 | for param_group in self.optimizer.param_groups: 133 | param_group['lr'] = lr 134 | 135 | # # every once in a while evaluate the loss on train and val sets 136 | if self.iter_num % self.eval_interval == 0 and self.master_process: 137 | self.evaluate(model, ctx) 138 | 139 | if self.iter_num == 0 and self.eval_only: 140 | break 141 | 142 | # sample a batch of data 143 | chosen, rejected = self.get_batch('train') 144 | 145 | # evaluate the loss 146 | reward_chosen = model(chosen) 147 | reward_rejected = model(rejected) 148 | loss = -torch.log(torch.sigmoid(reward_chosen - reward_rejected)).mean() 149 | 150 | 151 | 152 | self.optimizer.zero_grad(set_to_none=True) 153 | loss.backward() 154 | self.optimizer.step() 155 | 156 | # timing and logging 157 | t1 = time.time() 158 | # dt = t1 - t0 159 | t0 = t1 160 | self.iter_num += 1 161 | local_iter_num += 1 162 | 163 | # termination conditions 164 | if self.iter_num > self.max_iters: 165 | break 166 | 167 | if self.ddp: 168 | destroy_process_group() 169 | 170 | # This one is for reward models which output a probability of reward directly from a given text (no comparison) 171 | class ProbRewardModelTrainer(Trainer): 172 | def __init__(self, config, discrete_reward=False): 173 | super().__init__(config) 174 | import tiktoken 175 | self.enc = tiktoken.get_encoding("gpt2") 176 | self.mode = 'reward' 177 | self.discrete_reward = discrete_reward 178 | 179 | def get_batch(self, split): 180 | # generate a small batch of data of inputs x and targets y 181 | data = self.train_data if split == 'train' else self.val_data 182 | ix = torch.randint(len(data) - self.block_size, (self.batch_size,)) 183 | x = torch.stack([torch.from_numpy((data[i:i+self.block_size]).astype(np.int64)) for i in ix]) 184 | y = torch.stack([self.reward(torch.from_numpy((data[i+1:i+1+self.block_size]).astype(np.int64))) for i in ix]) 185 | 186 | 187 | if self.device_type == 'cuda': 188 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True) 189 | x, y = x.pin_memory().to(self.device, non_blocking=True), y.pin_memory().to(self.device, non_blocking=True) 190 | else: 191 | x, y = x.to(self.device), y.to(self.device) 192 | return x, y 193 | 194 | def reward(self, sequence, t='and'): 195 | if t in self.enc.decode(sequence.tolist()): 196 | # print('hello') 197 | return torch.tensor([0.0,1.0]) 198 | else: 199 | return torch.tensor([1.0, 0.0]) 200 | 201 | def evaluate(self, model, ctx, X, lr): 202 | losses = self.estimate_loss(model, ctx) 203 | 
print(f"step {self.iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}") 204 | 205 | 206 | text = self.enc.decode(X[self.iter_num % self.eval_interval].tolist()) 207 | 208 | try: 209 | reward_probs, _ = model(X[self.iter_num % self.eval_interval].unsqueeze(0)) 210 | actual_reward_probs = self.reward(X[self.iter_num % self.eval_interval])[1] 211 | 212 | print("input: ", text[:30], f"expect {actual_reward_probs}, reward: {reward_probs[0][-1]} \n") 213 | except: 214 | pass 215 | 216 | # test_text = text[:4] + 'z' + text[4 + 1:-1] 217 | test_text = list(text) 218 | test_text[3] = ' ' 219 | test_text[4] = 'a' 220 | test_text[5] = 'n' 221 | test_text[6] = 'd' 222 | test_text[7] = ' ' 223 | test_text = ''.join(test_text) 224 | try: 225 | test_text_enc = torch.tensor(self.enc.encode(test_text)[:self.block_size]).unsqueeze(0) 226 | test_reward_probs, _ = model(test_text_enc.to(self.device)) 227 | actual_reward_probs = self.reward(test_text_enc[0].to(self.device))[1] 228 | 229 | print("input: ", test_text[:30], f"expect {actual_reward_probs}, reward: {test_reward_probs[0][-1]} \n") 230 | except: 231 | pass 232 | 233 | if self.wandb_log: 234 | wandb.log({ 235 | "iter": self.iter_num, 236 | "train/loss": losses['train'], 237 | "val/loss": losses['val'], 238 | "lr": lr, 239 | # "mfu": self.running_mfu*100, # convert to percentage 240 | }) 241 | if losses['val'] < self.best_val_loss or self.always_save_checkpoint: 242 | self.best_val_loss = losses['val'] 243 | raw_model = model.module if self.ddp else model 244 | if self.iter_num > 0: 245 | checkpoint = { 246 | 'model': raw_model.state_dict(), 247 | 'optimizer': self.optimizer.state_dict(), 248 | 'model_args': self.model_args, 249 | 'iter_num': self.iter_num, 250 | 'best_val_loss': self.best_val_loss, 251 | 'config': self.config, 252 | } 253 | print(f"saving checkpoint to {self.config['out_dir_multihead']}") 254 | torch.save(checkpoint, os.path.join(self.config['out_dir_multihead'], 'ckpt.pt')) 255 | 256 | def train(self): 257 | # set up distributed training 258 | self.setup_ddp() 259 | 260 | ctx, meta_vocab_size = self.setup() 261 | 262 | # model init 263 | 264 | if self.master_process: 265 | os.makedirs(self.config['out_dir_multihead'], exist_ok=True) 266 | 267 | model = self.init_model() 268 | model = RLHF(model, self.mode, discrete_reward=self.discrete_reward) 269 | 270 | if self.config['init_multihead_from'] == 'scratch': 271 | print("initializing multihead from scratch") 272 | else: 273 | if self.config['init_multihead_from'] == 'resume': 274 | print(f"Resuming training from {self.config['out_dir_multihead']}") 275 | # resume training from a checkpoint. 276 | ckpt_path = os.path.join(self.config['out_dir_multihead'], 'ckpt.pt') 277 | checkpoint = torch.load(ckpt_path, map_location=self.device) 278 | state_dict = checkpoint['model'] 279 | # fix the keys of the state dictionary :( 280 | # honestly no idea how checkpoints sometimes get this prefix, have to debug more 281 | unwanted_prefix = '_orig_mod.' 
282 | for k,v in list(state_dict.items()): 283 | if k.startswith(unwanted_prefix): 284 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 285 | model.load_state_dict(state_dict) 286 | 287 | 288 | model.to(self.device) 289 | 290 | # self.optimizer = torch.optim.AdamW(model.model.reward_head.parameters(), lr=1e-3) 291 | self.optimizer = torch.optim.AdamW(model.model.parameters(), lr=1e-3) 292 | 293 | model = self.setup_model(model) 294 | 295 | # logging 296 | if self.wandb_log and self.master_process: 297 | wandb.init(project=self.wandb_project, name=self.wandb_run_name, config=self.config) 298 | 299 | # training loop 300 | X, Y = self.get_batch('train') # fetch the very first batch 301 | t0 = time.time() 302 | local_iter_num = 0 # number of iterations in the lifetime of this process 303 | self.running_mfu = -1.0 304 | while True: 305 | 306 | # determine and set the learning rate for this iteration 307 | lr = self.get_lr(self.iter_num) if self.decay_lr else self.learning_rate 308 | for param_group in self.optimizer.param_groups: 309 | param_group['lr'] = lr 310 | 311 | # every once in a while evaluate the loss on train and val sets 312 | if self.iter_num % self.eval_interval == 0 and self.master_process: 313 | self.evaluate(model, ctx, X, lr) 314 | 315 | if self.iter_num == 0 and self.eval_only: 316 | break 317 | 318 | # sample a batch of data 319 | X, Y = self.get_batch('train') 320 | 321 | # evaluate the loss 322 | logits, loss = model(X, Y) 323 | self.optimizer.zero_grad(set_to_none=True) 324 | loss.backward() 325 | self.optimizer.step() 326 | 327 | # timing and logging 328 | t1 = time.time() 329 | # dt = t1 - t0 330 | t0 = t1 331 | self.iter_num += 1 332 | local_iter_num += 1 333 | 334 | # termination conditions 335 | if self.iter_num > self.max_iters: 336 | break 337 | 338 | if self.ddp: 339 | destroy_process_group() -------------------------------------------------------------------------------- /trainers/rl_trainer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torch.nn.parallel import DistributedDataParallel as DDP 4 | from torch.distributed import destroy_process_group 5 | import time, os 6 | from model import RLHF 7 | from trainers.trainer import Trainer 8 | 9 | # TODO: this works but is currently crude and incomplete, critic implementation plus PPO are obvious next steps 10 | class PolicyGradientTrainer(Trainer): 11 | def __init__(self, config): 12 | super().__init__(config) 13 | import tiktoken 14 | self.enc = tiktoken.get_encoding("gpt2") 15 | self.mode = 'RL' 16 | 17 | def train(self): 18 | 19 | self.setup_ddp() 20 | 21 | ctx, meta_vocab_size = self.setup() 22 | 23 | # model init 24 | model = self.init_model() 25 | 26 | model = RLHF(model, self.mode, discrete_reward=self.config['discrete_reward']) 27 | 28 | if self.config['init_multihead_from'] == 'scratch': 29 | print("initializing multihead from scratch") 30 | else: 31 | if self.config['init_multihead_from'] == 'resume': 32 | print(f"Resuming training from {self.config['out_dir_multihead']}") 33 | # resume training from a checkpoint. 34 | ckpt_path = os.path.join(self.config['out_dir_multihead'], 'ckpt.pt') 35 | checkpoint = torch.load(ckpt_path, map_location=self.device) 36 | state_dict = checkpoint['model'] 37 | # fix the keys of the state dictionary :( 38 | # honestly no idea how checkpoints sometimes get this prefix, have to debug more 39 | unwanted_prefix = '_orig_mod.' 
40 | for k,v in list(state_dict.items()): 41 | if k.startswith(unwanted_prefix): 42 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 43 | model.load_state_dict(state_dict) 44 | 45 | 46 | if self.config['hard_code_reward']: 47 | reward_model = None 48 | print('Using hard-coded reward') 49 | else: 50 | print('Using learned reward model') 51 | if self.config['separate_reward_model']: 52 | import copy 53 | reward_model = copy.deepcopy(model) 54 | print('Reward model instantiated separately') 55 | else: 56 | reward_model = model 57 | print('Reward model and actor model share backbone') 58 | reward_model.to(self.device) 59 | 60 | model.to(self.device) 61 | 62 | # actor_optimizer = torch.optim.AdamW(model.model.policy_head.parameters(), lr=1e-2) 63 | actor_optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3) 64 | 65 | last_time = time.time() 66 | rews_all = [] 67 | max_iters = 100000 68 | X, Y = self.get_batch('train') # fetch the very first batch 69 | t0 = time.time() 70 | for iter in range(max_iters): 71 | 72 | states, log_probs, log_probs_reference, rewards, advantages = model.generate( 73 | X, self.block_size, self.device, self.block_size, reward_model=reward_model, hard_code_reward=self.config['hard_code_reward']) 74 | 75 | # minus KL divergence 76 | rets = advantages * log_probs.squeeze() #- 1*(log_probs-log_probs_reference) #- 0.05*log_probs 77 | actor_loss = -rets.sum() 78 | actor_optimizer.zero_grad(set_to_none=True) 79 | actor_loss.backward() 80 | actor_optimizer.step() 81 | 82 | torch.mean(rewards) 83 | 84 | rews_all.append(rewards.mean().detach().cpu().numpy()) 85 | 86 | if iter % 1000 == 0: 87 | t1 = time.time() 88 | print(f'iter: {iter}, time: {t1-t0}') 89 | # print(actor_loss, critic_loss) 90 | print(f'Actor loss: {actor_loss}, iter: {iter}') 91 | print(f'rets: {np.mean(rews_all[-1000:])}') 92 | current_time = time.time() 93 | # print(current_time - last_time) 94 | last_time = current_time 95 | text = model.generate(X, self.block_size, self.device, self.block_size, reward_model=reward_model)[0] 96 | for i in range(1): 97 | text_i = text[i,:] 98 | # print(reward(text_i)) 99 | try: 100 | print(self.enc.decode(text_i.tolist())) 101 | except: 102 | continue 103 | 104 | 105 | class GumbelTrainer(Trainer): 106 | def __init__(self, config): 107 | super().__init__(config) 108 | import tiktoken 109 | self.enc = tiktoken.get_encoding("gpt2") 110 | self.mode = 'RL' 111 | 112 | def train(self): 113 | 114 | self.setup_ddp() 115 | 116 | ctx, meta_vocab_size = self.setup() 117 | 118 | # model init 119 | model = self.init_model() 120 | 121 | rl_model = RLHF(model, self.mode, discrete_reward=self.config['discrete_reward']) 122 | 123 | 124 | # The current approach is to use a separate reward model because otherwise optimisation of the reward model changes upstream parameters impacting performance of the multihead 125 | # I therefore load the language model from 'out_dir' and the reward model from 'out_dir_multihead' 126 | 127 | if self.config['init_multihead_from'] == 'scratch': 128 | print("initializing multihead from scratch") 129 | else: 130 | if self.config['init_multihead_from'] == 'resume': 131 | print(f"Resuming training from {self.config['out_dir']}") 132 | # resume training from a checkpoint. 
133 | ckpt_path = os.path.join(self.config['out_dir'], 'ckpt.pt') 134 | checkpoint = torch.load(ckpt_path, map_location=self.device) 135 | state_dict = checkpoint['model'] 136 | # fix the keys of the state dictionary :( 137 | # honestly no idea how checkpoints sometimes get this prefix, have to debug more 138 | unwanted_prefix = '_orig_mod.' 139 | for k,v in list(state_dict.items()): 140 | if k.startswith(unwanted_prefix): 141 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 142 | model.load_state_dict(state_dict) 143 | 144 | separate_reward_model = True 145 | if separate_reward_model: 146 | print('Reward model instantiated as copy') 147 | import copy 148 | reward_model = copy.deepcopy(model) 149 | 150 | print(f"Resuming reward model from {self.config['out_dir_multihead']}") 151 | 152 | reward_model = RLHF(reward_model, self.mode, discrete_reward=self.config['discrete_reward']) 153 | # resume training from a checkpoint. 154 | ckpt_path = os.path.join(self.config['out_dir_multihead'], 'ckpt.pt') 155 | checkpoint = torch.load(ckpt_path, map_location=self.device) 156 | state_dict = checkpoint['model'] 157 | # fix the keys of the state dictionary :( 158 | # honestly no idea how checkpoints sometimes get this prefix, have to debug more 159 | unwanted_prefix = '_orig_mod.' 160 | for k,v in list(state_dict.items()): 161 | if k.startswith(unwanted_prefix): 162 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 163 | reward_model.load_state_dict(state_dict) 164 | else: 165 | reward_model = rl_model 166 | rl_model.to(self.device) 167 | reward_model.to(self.device) 168 | 169 | gumbel_optimizer = torch.optim.AdamW(rl_model.parameters(), lr=1e-3) 170 | 171 | # initialize a GradScaler. If enabled=False scaler is a no-op 172 | scaler = torch.cuda.amp.GradScaler(enabled=(self.dtype == 'float16')) 173 | 174 | last_time = time.time() 175 | rews_all = [] 176 | max_iters = 100000 177 | 178 | X, Y = self.get_batch('train') # fetch the very first batch 179 | 180 | X = torch.zeros((X.shape[0], 1), dtype=torch.long).to(self.device) # for now there is no prompt 181 | 182 | t0 = time.time() 183 | for iter in range(max_iters): 184 | 185 | for micro_step in range(self.gradient_accumulation_steps): 186 | if self.ddp: 187 | # in DDP training we only need to sync gradients at the last micro step. 
188 | # the official way to do this is with model.no_sync() context manager, but 189 | # I really dislike that this bloats the code and forces us to repeat code 190 | # looking at the source of that context manager, it just toggles this variable 191 | rl_model.require_backward_grad_sync = (micro_step == self.gradient_accumulation_steps - 1) 192 | with ctx: 193 | states, rewards = rl_model.generate_gumbel(X, self.config['episode_length'], self.device, self.block_size, reward_model=reward_model) 194 | mean_reward = rewards.mean() 195 | loss = -mean_reward 196 | # # immediately async prefetch next batch while model is doing the forward pass on the GPU 197 | # X, Y = self.get_batch('train') 198 | # backward pass, with gradient scaling if training in fp16 199 | scaler.scale(loss).backward() 200 | 201 | # clip the gradient 202 | if self.grad_clip != 0.0: 203 | scaler.unscale_(gumbel_optimizer) 204 | torch.nn.utils.clip_grad_norm_(rl_model.parameters(), self.grad_clip) 205 | # step the optimizer and scaler if training in fp16 206 | scaler.step(gumbel_optimizer) 207 | scaler.update() 208 | # flush the gradients as soon as we can, no need for this memory anymore 209 | gumbel_optimizer.zero_grad(set_to_none=True) 210 | 211 | rews_all.append(mean_reward.detach().cpu().numpy()) 212 | eval_interval = self.config['eval_interval'] 213 | if iter % eval_interval == 0: 214 | t1 = time.time() 215 | print(f'iter: {iter}, time: {t1-t0}') 216 | print(f'rets: {np.mean(rews_all[-eval_interval:])}') 217 | current_time = time.time() 218 | # print(current_time - last_time) 219 | last_time = current_time 220 | text = rl_model.generate(X, self.config['episode_length'], self.device, self.block_size, reward_model=reward_model)[0] 221 | for i in range(1): 222 | text_i = text[i,:] 223 | # print(reward(text_i)) 224 | try: 225 | print(self.enc.decode(text_i.tolist())) 226 | except: 227 | continue -------------------------------------------------------------------------------- /trainers/trainer.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import time 4 | import math 5 | import pickle 6 | from contextlib import nullcontext 7 | 8 | import numpy as np 9 | import torch 10 | 11 | from model import GPTConfig, GPT, RLHF 12 | import yaml 13 | from torch.nn.parallel import DistributedDataParallel as DDP 14 | from torch.distributed import init_process_group, destroy_process_group 15 | from utils import dotdict 16 | import wandb 17 | 18 | from torch.nn import functional as F 19 | 20 | class Trainer(): 21 | def __init__(self, config): 22 | self.config = config 23 | self.from_config(config) 24 | 25 | self.model_args = dict(n_layer=self.n_layer, n_head=self.n_head, n_embd=self.n_embd, block_size=self.block_size, 26 | bias=self.bias, vocab_size=None, dropout=self.dropout) # start with model_args from command line 27 | self.meta_vocab_size = None 28 | self.iter_num = 0 29 | self.best_val_loss = float('inf') 30 | 31 | def from_config(self, config): 32 | config = dotdict(config) 33 | 34 | # IO 35 | self.out_dir = config.out_dir 36 | self.eval_interval = config.eval_interval 37 | self.log_interval = config.log_interval 38 | self.eval_iters = config.eval_iters 39 | self.eval_only = config.eval_only 40 | self.always_save_checkpoint = config.always_save_checkpoint 41 | self.init_from = config.init_from 42 | 43 | # wandb 44 | self.wandb_log = config.wandb_log 45 | self.wandb_project = config.wandb_project 46 | self.wandb_run_name = config.wandb_run_name 47 | 48 | # data 49 | 
self.dataset = config.dataset 50 | self.gradient_accumulation_steps = config.gradient_accumulation_steps 51 | self.batch_size = config.batch_size 52 | self.block_size = config.block_size 53 | 54 | # model 55 | self.n_layer = config.n_layer 56 | self.n_head = config.n_head 57 | self.n_embd = config.n_embd 58 | self.dropout = config.dropout 59 | self.bias = config.bias 60 | 61 | # optimizer 62 | self.learning_rate = config.learning_rate 63 | self.max_iters = config.max_iters 64 | self.weight_decay = config.weight_decay 65 | self.beta1 = config.beta1 66 | self.beta2 = config.beta2 67 | self.grad_clip = config.grad_clip 68 | self.decay_lr = config.decay_lr 69 | self.warmup_iters = config.warmup_iters 70 | self.lr_decay_iters = config.lr_decay_iters 71 | self.min_lr = config.min_lr 72 | 73 | # DDP 74 | self.backend = config.backend 75 | 76 | # system 77 | self.device = config.device 78 | self.dtype = config.dtype 79 | self.compile = config.compile 80 | 81 | print(self.out_dir) 82 | 83 | def setup_ddp(self): 84 | 85 | self.ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run? 86 | if self.ddp: 87 | init_process_group(backend=self.backend) 88 | self.ddp_rank = int(os.environ['RANK']) 89 | self.ddp_local_rank = int(os.environ['LOCAL_RANK']) 90 | self.world_size = int(os.environ['WORLD_SIZE']) # total number of training processes 91 | self.device = f'cuda:{self.ddp_local_rank}' 92 | torch.cuda.set_device(self.device) 93 | self.master_process = self.ddp_rank == 0 # this process will do logging, checkpointing etc. 94 | self.seed_offset = self.ddp_rank # each process gets a different seed 95 | else: 96 | # if not ddp, we are running on a single gpu, and one process 97 | self.world_size = 1 98 | self.master_process = True 99 | self.seed_offset = 0 100 | self.ddp_local_rank = None 101 | 102 | def get_batch(self, split): 103 | data = self.train_data if split == 'train' else self.val_data 104 | ix = torch.randint(len(data) - self.block_size, (self.batch_size,)) 105 | x = torch.stack([torch.from_numpy((data[i:i+self.block_size]).astype(np.int64)) for i in ix]) 106 | y = torch.stack([torch.from_numpy((data[i+1:i+1+self.block_size]).astype(np.int64)) for i in ix]) 107 | if self.device_type == 'cuda': 108 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True) 109 | x, y = x.pin_memory().to(self.device, non_blocking=True), y.pin_memory().to(self.device, non_blocking=True) 110 | else: 111 | x, y = x.to(self.device), y.to(self.device) 112 | return x, y 113 | 114 | def get_lr(self, it): 115 | # learning rate decay scheduler (cosine with warmup) 116 | # 1) linear warmup for warmup_iters steps 117 | if it < self.warmup_iters: 118 | return self.learning_rate * it / self.warmup_iters 119 | # 2) if it > lr_decay_iters, return min learning rate 120 | if it > self.lr_decay_iters: 121 | return self.min_lr 122 | # 3) in between, use cosine decay down to min learning rate 123 | decay_ratio = (it - self.warmup_iters) / (self.lr_decay_iters - self.warmup_iters) 124 | assert 0 <= decay_ratio <= 1 125 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 126 | return self.min_lr + coeff * (self.learning_rate - self.min_lr) 127 | 128 | def init_model(self): 129 | if self.init_from == 'scratch': 130 | # init a new model from scratch 131 | print("Initializing a new model from scratch") 132 | # determine the vocab size we'll use for from-scratch training 133 | if self.meta_vocab_size is None: 134 | print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded 
up for efficiency)") 135 | self.model_args['vocab_size'] = self.meta_vocab_size if self.meta_vocab_size is not None else 50304 136 | gptconf = GPTConfig(**self.model_args) 137 | model = GPT(gptconf) 138 | elif self.init_from == 'resume': 139 | print(f"Resuming training from {self.out_dir}") 140 | # resume training from a checkpoint. 141 | ckpt_path = os.path.join(self.out_dir, 'ckpt.pt') 142 | checkpoint = torch.load(ckpt_path, map_location=self.device) 143 | checkpoint_model_args = checkpoint['model_args'] 144 | # force these config attributes to be equal otherwise we can't even resume training 145 | # the rest of the attributes (e.g. dropout) can stay as desired from command line 146 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: 147 | self.model_args[k] = checkpoint_model_args[k] 148 | # create the model 149 | gptconf = GPTConfig(**self.model_args) 150 | model = GPT(gptconf) 151 | state_dict = checkpoint['model'] 152 | # fix the keys of the state dictionary :( 153 | # honestly no idea how checkpoints sometimes get this prefix, have to debug more 154 | unwanted_prefix = '_orig_mod.' 155 | for k,v in list(state_dict.items()): 156 | if k.startswith(unwanted_prefix): 157 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 158 | model.load_state_dict(state_dict) 159 | self.iter_num = checkpoint['iter_num'] 160 | self.best_val_loss = checkpoint['best_val_loss'] 161 | self.checkpoint = checkpoint 162 | elif self.init_from.startswith('gpt2'): 163 | print(f"Initializing from OpenAI GPT-2 weights: {self.init_from}") 164 | # initialize from OpenAI GPT-2 weights 165 | override_args = dict(dropout=self.dropout) 166 | model = GPT.from_pretrained(init_from, override_args) 167 | # read off the created config params, so we can store them into checkpoint correctly 168 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: 169 | self.model_args[k] = getattr(model.config, k) 170 | # crop down the model block size if desired, using model surgery 171 | if self.block_size < model.config.block_size: 172 | model.crop_block_size(self.block_size) 173 | self.model_args['block_size'] = self.block_size # so that the checkpoint will have the right value 174 | 175 | return model 176 | def setup(self): 177 | if self.master_process: 178 | os.makedirs(self.out_dir, exist_ok=True) 179 | 180 | torch.manual_seed(1337 + self.seed_offset) 181 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 182 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 183 | self.device_type = 'cuda' if 'cuda' in self.device else 'cpu' # for later use in torch.autocast 184 | # note: float16 data type will automatically use a GradScaler 185 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[self.dtype] 186 | ctx = nullcontext() if self.device_type == 'cpu' else torch.amp.autocast(device_type=self.device_type, dtype=ptdtype) 187 | 188 | # poor man's data loader 189 | data_dir = os.path.join('data', self.dataset) 190 | self.train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r') 191 | self.val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r') 192 | 193 | 194 | # attempt to derive vocab_size from the dataset 195 | meta_path = os.path.join(data_dir, 'meta.pkl') 196 | meta_vocab_size = None 197 | if os.path.exists(meta_path): 198 | with open(meta_path, 'rb') as f: 199 | meta = pickle.load(f) 200 | meta_vocab_size = meta['vocab_size'] 201 | print(f"found 
vocab_size = {meta_vocab_size} (inside {meta_path})") 202 | 203 | return ctx, meta_vocab_size 204 | 205 | def setup_model(self, model): 206 | # compile the model 207 | if self.compile: 208 | print("compiling the model... (takes a ~minute)") 209 | unoptimized_model = model 210 | model = torch.compile(model) # requires PyTorch 2.0 211 | 212 | # wrap model into DDP container 213 | if self.ddp: 214 | model = DDP(model, device_ids=[self.ddp_local_rank]) 215 | 216 | return model 217 | 218 | def evaluate(self, model, ctx, lr): 219 | losses = self.estimate_loss(model, ctx) 220 | print(f"step {self.iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}") 221 | if self.wandb_log: 222 | wandb.log({ 223 | "iter": self.iter_num, 224 | "train/loss": losses['train'], 225 | "val/loss": losses['val'], 226 | "lr": lr, 227 | # "mfu": running_mfu*100, # convert to percentage 228 | }) 229 | if losses['val'] < self.best_val_loss or self.always_save_checkpoint: 230 | self.best_val_loss = losses['val'] 231 | raw_model = model.module if self.ddp else model 232 | if self.iter_num > 0: 233 | checkpoint = { 234 | 'model': raw_model.state_dict(), 235 | 'optimizer': self.optimizer.state_dict(), 236 | 'model_args': self.model_args, 237 | 'iter_num': self.iter_num, 238 | 'best_val_loss': self.best_val_loss, 239 | 'config': self.config, 240 | } 241 | print(f"saving checkpoint to {self.out_dir}") 242 | torch.save(checkpoint, os.path.join(self.out_dir, 'ckpt.pt')) 243 | 244 | def train(self): 245 | # set up distributed training 246 | self.setup_ddp() 247 | 248 | ctx, self.meta_vocab_size = self.setup() # store the dataset-derived vocab size so init_model() can use it 249 | 250 | # model init 251 | model = self.init_model() 252 | 253 | model.to(self.device) 254 | 255 | # initialize a GradScaler. If enabled=False scaler is a no-op 256 | scaler = torch.cuda.amp.GradScaler(enabled=(self.dtype == 'float16')) 257 | 258 | # optimizer 259 | self.optimizer = model.configure_optimizers(self.weight_decay, self.learning_rate, \ 260 | (self.beta1, self.beta2), self.device_type) 261 | if self.init_from == 'resume': 262 | self.optimizer.load_state_dict(self.checkpoint['optimizer']) 263 | 264 | model = self.setup_model(model) 265 | 266 | # logging 267 | if self.wandb_log and self.master_process: 268 | wandb.init(project=self.wandb_project, name=self.wandb_run_name, config=self.config) 269 | 270 | # training loop 271 | X, Y = self.get_batch('train') # fetch the very first batch 272 | t0 = time.time() 273 | local_iter_num = 0 # number of iterations in the lifetime of this process 274 | running_mfu = -1.0 275 | while True: 276 | 277 | # determine and set the learning rate for this iteration 278 | lr = self.get_lr(self.iter_num) if self.decay_lr else self.learning_rate 279 | for param_group in self.optimizer.param_groups: 280 | param_group['lr'] = lr 281 | 282 | # evaluate the loss on train/val sets and write checkpoints 283 | if self.iter_num % self.eval_interval == 0 and self.master_process: 284 | self.evaluate(model, ctx, lr) 285 | 286 | if self.iter_num == 0 and self.eval_only: 287 | break 288 | 289 | # forward backward update, with optional gradient accumulation to simulate larger batch size 290 | # and using the GradScaler if data type is float16 291 | for micro_step in range(self.gradient_accumulation_steps): 292 | if self.ddp: 293 | # in DDP training we only need to sync gradients at the last micro step. 
294 | # the official way to do this is with model.no_sync() context manager, but 295 | # I really dislike that this bloats the code and forces us to repeat code 296 | # looking at the source of that context manager, it just toggles this variable 297 | model.require_backward_grad_sync = (micro_step == self.gradient_accumulation_steps - 1) 298 | with ctx: 299 | logits, loss = model(X, Y) 300 | # immediately async prefetch next batch while model is doing the forward pass on the GPU 301 | X, Y = self.get_batch('train') 302 | # backward pass, with gradient scaling if training in fp16 303 | scaler.scale(loss).backward() 304 | # clip the gradient 305 | if self.grad_clip != 0.0: 306 | scaler.unscale_(self.optimizer) 307 | torch.nn.utils.clip_grad_norm_(model.parameters(), self.grad_clip) 308 | # step the optimizer and scaler if training in fp16 309 | scaler.step(self.optimizer) 310 | scaler.update() 311 | # flush the gradients as soon as we can, no need for this memory anymore 312 | self.optimizer.zero_grad(set_to_none=True) 313 | 314 | # timing and logging 315 | t1 = time.time() 316 | dt = t1 - t0 317 | t0 = t1 318 | if self.iter_num % self.log_interval == 0 and self.master_process: 319 | lossf = loss.item() # loss as float. note: this is a CPU-GPU sync point 320 | if local_iter_num >= 5: # let the training loop settle a bit 321 | mfu = (model.module if self.ddp else model).estimate_mfu(self.batch_size * self.world_size * self.gradient_accumulation_steps, dt) # unwrap the DDP container before calling estimate_mfu on the raw model 322 | running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu 323 | print(f"iter {self.iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%") 324 | self.iter_num += 1 325 | local_iter_num += 1 326 | 327 | # termination conditions 328 | if self.iter_num > self.max_iters: 329 | break 330 | 331 | if self.ddp: 332 | destroy_process_group() 333 | 334 | # helps estimate an arbitrarily accurate loss over either split using many batches 335 | @torch.no_grad() 336 | def estimate_loss(self, model, ctx): 337 | out = {} 338 | model.eval() 339 | for split in ['train', 'val']: 340 | losses = torch.zeros(self.eval_iters) 341 | for k in range(self.eval_iters): 342 | X, Y = self.get_batch(split) 343 | with ctx: 344 | logits, loss = model(X, Y) 345 | losses[k] = loss.item() 346 | out[split] = losses.mean() 347 | model.train() 348 | return out -------------------------------------------------------------------------------- /transformer_sizing.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "### Transformer Theoretical Model\n", 9 | "\n", 10 | "This notebook stores a bunch of analysis about a Transformer, e.g. estimates the number of FLOPs, parameters, peak memory footprint, checkpoint size, etc." 
11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "from collections import OrderedDict" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "# config_args = {\n", 29 | "# 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params\n", 30 | "# 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params\n", 31 | "# 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params\n", 32 | "# 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params\n", 33 | "# }[model_type]\n", 34 | "\n", 35 | "block_size = 1024\n", 36 | "vocab_size = 50257\n", 37 | "n_layer = 12\n", 38 | "n_head = 12\n", 39 | "n_embd = 768\n", 40 | "bias = False\n", 41 | "assert not bias, \"this notebook assumes bias=False just for simplicity\"" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 3, 47 | "metadata": {}, 48 | "outputs": [ 49 | { 50 | "name": "stdout", 51 | "output_type": "stream", 52 | "text": [ 53 | "we see: 124337664, expected: 124337664, match: True\n", 54 | "name params ratio (%) \n", 55 | "emebedding/position 786432 0.6325\n", 56 | "embedding/token 38597376 31.0424\n", 57 | "embedding 39383808 31.6749\n", 58 | "attention/ln 768 0.0006\n", 59 | "attention/kqv 1769472 1.4231\n", 60 | "attention/proj 589824 0.4744\n", 61 | "attention 2360064 1.8981\n", 62 | "mlp/ln 768 0.0006\n", 63 | "mlp/ffw 2359296 1.8975\n", 64 | "mlp/proj 2359296 1.8975\n", 65 | "mlp 4719360 3.7956\n", 66 | "block 7079424 5.6937\n", 67 | "transformer 84953088 68.3245\n", 68 | "ln_f 768 0.0006\n", 69 | "dense 0 0.0000\n", 70 | "total 124337664 100.0000\n" 71 | ] 72 | } 73 | ], 74 | "source": [ 75 | "def params():\n", 76 | " \"\"\" estimates the number of parameters in the model\"\"\"\n", 77 | " out = OrderedDict()\n", 78 | "\n", 79 | " # token and position embeddings\n", 80 | " out['emebedding/position'] = n_embd * block_size\n", 81 | " out['embedding/token'] = n_embd * vocab_size\n", 82 | " out['embedding'] = out['emebedding/position'] + out['embedding/token']\n", 83 | "\n", 84 | " # attention blocks\n", 85 | " out['attention/ln'] = n_embd # note, bias=False in our LN\n", 86 | " out['attention/kqv'] = n_embd * 3*n_embd\n", 87 | " out['attention/proj'] = n_embd**2\n", 88 | " out['attention'] = out['attention/ln'] + out['attention/kqv'] + out['attention/proj']\n", 89 | "\n", 90 | " # MLP blocks\n", 91 | " ffw_size = 4*n_embd # feed forward size\n", 92 | " out['mlp/ln'] = n_embd\n", 93 | " out['mlp/ffw'] = n_embd * ffw_size\n", 94 | " out['mlp/proj'] = ffw_size * n_embd\n", 95 | " out['mlp'] = out['mlp/ln'] + out['mlp/ffw'] + out['mlp/proj']\n", 96 | " \n", 97 | " # the transformer and the rest of it\n", 98 | " out['block'] = out['attention'] + out['mlp']\n", 99 | " out['transformer'] = n_layer * out['block']\n", 100 | " out['ln_f'] = n_embd # final layernorm\n", 101 | " out['dense'] = 0 # 0 because of parameter sharing. 
This layer uses the weights from the embedding layer\n", 102 | "\n", 103 | " # total\n", 104 | " out['total'] = out['embedding'] + out['transformer'] + out['ln_f'] + out['dense']\n", 105 | "\n", 106 | " return out\n", 107 | "\n", 108 | "# compare our param count to that reported by PyTorch\n", 109 | "p = params()\n", 110 | "params_total = p['total']\n", 111 | "print(f\"we see: {params_total}, expected: {124337664}, match: {params_total == 124337664}\")\n", 112 | "# create a header\n", 113 | "print(f\"{'name':20s} {'params':10s} {'ratio (%)':10s}\")\n", 114 | "for k,v in p.items():\n", 115 | " print(f\"{k:20s} {v:10d} {v/params_total*100:10.4f}\")\n", 116 | " " 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 4, 122 | "metadata": {}, 123 | "outputs": [ 124 | { 125 | "name": "stdout", 126 | "output_type": "stream", 127 | "text": [ 128 | "est checkpoint size: 1.49 GB\n", 129 | "measured with wc -c ckpt.pt: 1542470366\n", 130 | "fluff ratio: 103.38%\n" 131 | ] 132 | } 133 | ], 134 | "source": [ 135 | "# we can now calculate the size of each checkpoint\n", 136 | "# params are stored in fp32, and the AdamW optimizer has 2 additional buffers per param for statistics\n", 137 | "params_bytes = params_total*4\n", 138 | "params_and_buffers_bytes = params_bytes + 2*params_bytes\n", 139 | "print(f\"est checkpoint size: {params_and_buffers_bytes/1e9:.2f} GB\")\n", 140 | "measured_bytes = 1542470366 # from wc -c ckpt.pt\n", 141 | "print(f\"measured with wc -c ckpt.pt: {measured_bytes}\")\n", 142 | "print(f\"fluff ratio: {measured_bytes/params_and_buffers_bytes*100:.2f}%\")" 143 | ] 144 | }, 145 | { 146 | "attachments": {}, 147 | "cell_type": "markdown", 148 | "metadata": {}, 149 | "source": [ 150 | "We can also estimate the ratio of our GPU memory that will be taken up just by the weights and the buffers inside the AdamW optimizer" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 5, 156 | "metadata": {}, 157 | "outputs": [ 158 | { 159 | "name": "stdout", 160 | "output_type": "stream", 161 | "text": [ 162 | "memory ratio taken up just for parameters: 3.73%\n" 163 | ] 164 | } 165 | ], 166 | "source": [ 167 | "gpu_memory = 40e9 # 40 GB A100 GPU, roughly\n", 168 | "print(f\"memory ratio taken up just for parameters: {params_and_buffers_bytes / gpu_memory * 100:.2f}%\")" 169 | ] 170 | }, 171 | { 172 | "attachments": {}, 173 | "cell_type": "markdown", 174 | "metadata": {}, 175 | "source": [ 176 | "i.e. not that much of the memory for this tiny model, most of the memory is activations (forward and backward). This of course changes dramatically for larger and larger models." 177 | ] 178 | }, 179 | { 180 | "attachments": {}, 181 | "cell_type": "markdown", 182 | "metadata": {}, 183 | "source": [ 184 | "Let's estimate FLOPs for a single forward pass." 
185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 6, 190 | "metadata": {}, 191 | "outputs": [ 192 | { 193 | "name": "stdout", 194 | "output_type": "stream", 195 | "text": [ 196 | "name flops ratio (%) \n", 197 | "attention/kqv 3623878656 1.2426\n", 198 | "attention/scores 1610612736 0.5522\n", 199 | "attention/reduce 1610612736 0.5522\n", 200 | "attention/proj 1207959552 0.4142\n", 201 | "attention 8053063680 2.7612\n", 202 | "mlp/ffw1 4831838208 1.6567\n", 203 | "mlp/ffw2 4831838208 1.6567\n", 204 | "mlp 9663676416 3.3135\n", 205 | "block 17716740096 6.0747\n", 206 | "transformer 212600881152 72.8963\n", 207 | "dense 79047426048 27.1037\n", 208 | "forward_total 291648307200 100.0000\n", 209 | "backward_total 583296614400 200.0000\n", 210 | "total 874944921600 300.0000\n" 211 | ] 212 | } 213 | ], 214 | "source": [ 215 | "def flops():\n", 216 | " # we only count Weight FLOPs, all other layers (LayerNorm, Softmax, etc) are effectively irrelevant\n", 217 | " # we count actual FLOPs, not MACs. Hence 2* all over the place\n", 218 | " # basically for any matrix multiply A (BxC) @ B (CxD) -> (BxD) flops are 2*B*C*D\n", 219 | "\n", 220 | " out = OrderedDict()\n", 221 | " head_size = n_embd // n_head\n", 222 | "\n", 223 | " # attention blocks\n", 224 | " # 1) the projection to key, query, values\n", 225 | " out['attention/kqv'] = 2 * block_size * (n_embd * 3*n_embd)\n", 226 | " # 2) calculating the attention scores\n", 227 | " out['attention/scores'] = 2 * block_size * block_size * n_embd\n", 228 | " # 3) the reduction of the values (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)\n", 229 | " out['attention/reduce'] = 2 * n_head * (block_size * block_size * head_size)\n", 230 | " # 4) the final linear projection\n", 231 | " out['attention/proj'] = 2 * block_size * (n_embd * n_embd)\n", 232 | " out['attention'] = sum(out['attention/'+k] for k in ['kqv', 'scores', 'reduce', 'proj'])\n", 233 | "\n", 234 | " # MLP blocks\n", 235 | " ffw_size = 4*n_embd # feed forward size\n", 236 | " out['mlp/ffw1'] = 2 * block_size * (n_embd * ffw_size)\n", 237 | " out['mlp/ffw2'] = 2 * block_size * (ffw_size * n_embd)\n", 238 | " out['mlp'] = out['mlp/ffw1'] + out['mlp/ffw2']\n", 239 | "\n", 240 | " # the transformer and the rest of it\n", 241 | " out['block'] = out['attention'] + out['mlp']\n", 242 | " out['transformer'] = n_layer * out['block']\n", 243 | " out['dense'] = 2 * block_size * (n_embd * vocab_size)\n", 244 | "\n", 245 | " # forward,backward,total\n", 246 | " out['forward_total'] = out['transformer'] + out['dense']\n", 247 | " out['backward_total'] = 2 * out['forward_total'] # use common estimate of bwd = 2*fwd\n", 248 | " out['total'] = out['forward_total'] + out['backward_total']\n", 249 | "\n", 250 | " return out\n", 251 | " \n", 252 | "# compare our param count to that reported by PyTorch\n", 253 | "f = flops()\n", 254 | "flops_total = f['forward_total']\n", 255 | "print(f\"{'name':20s} {'flops':14s} {'ratio (%)':10s}\")\n", 256 | "for k,v in f.items():\n", 257 | " print(f\"{k:20s} {v:14d} {v/flops_total*100:10.4f}\")\n", 258 | " " 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": 7, 264 | "metadata": {}, 265 | "outputs": [ 266 | { 267 | "name": "stdout", 268 | "output_type": "stream", 269 | "text": [ 270 | "palm_flops: 875062886400, flops: 874944921600, ratio: 1.0001\n" 271 | ] 272 | } 273 | ], 274 | "source": [ 275 | "# now here is an estimate copy pasted from the PaLM paper\n", 276 | "# this formula is often used to calculate MFU (model flops 
utilization)\n", 277 | "def palm_flops():\n", 278 | " \"\"\"estimate of the model flops following PaLM paper formula\"\"\"\n", 279 | " # non-embedding model parameters. note that we do not subtract the\n", 280 | " # embedding/token params because those are tied and get used in the last layer.\n", 281 | " N = params()['total'] - params()['emebedding/position']\n", 282 | " L, H, Q, T = n_layer, n_head, n_embd//n_head, block_size\n", 283 | " mf_per_token = 6*N + 12*L*H*Q*T\n", 284 | " mf = mf_per_token * block_size\n", 285 | " return mf\n", 286 | "\n", 287 | "print(f\"palm_flops: {palm_flops():d}, flops: {flops()['total']:d}, ratio: {palm_flops()/flops()['total']:.4f}\")" 288 | ] 289 | }, 290 | { 291 | "attachments": {}, 292 | "cell_type": "markdown", 293 | "metadata": {}, 294 | "source": [ 295 | "Ok they are quite similar, giving some confidence that my math in flops() function was ~ok. Now, A100 is cited at 312 TFLOPS bfloat16 on tensor cores. So what is our model flops utilization (MFU)? I trained the model above with a batch_size of 20 and grad_accum of 5, which runs in about 755ms on a single A100 GPU. We get:" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": 8, 301 | "metadata": {}, 302 | "outputs": [ 303 | { 304 | "name": "stdout", 305 | "output_type": "stream", 306 | "text": [ 307 | "fraction of A100 used: 37.14%\n" 308 | ] 309 | } 310 | ], 311 | "source": [ 312 | "# here is what we currently roughly measure\n", 313 | "batch_size = 20 * 5 # 5 is grad_accum, so total batch size is 100\n", 314 | "measured_time = 0.755 # in seconds per iteration\n", 315 | "measured_throughput = batch_size / measured_time\n", 316 | "flops_achieved = f['total'] * measured_throughput\n", 317 | "\n", 318 | "# A100 is cited to be 312 TFLOPS of bfloat16 running on tensor cores\n", 319 | "a100_flops_promised = 312e12\n", 320 | "\n", 321 | "# the fraction of the A100 that we are using:\n", 322 | "print(f\"fraction of A100 used: {flops_achieved / a100_flops_promised * 100:.2f}%\")" 323 | ] 324 | }, 325 | { 326 | "attachments": {}, 327 | "cell_type": "markdown", 328 | "metadata": {}, 329 | "source": [ 330 | "For reference, we'd prefer to be somewhere around 50%+, and not just for a single GPU but for an entire DDP run. So we still have some work to do, but at least we're within a factor of ~2X of what is achievable with this GPU." 
331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": 9, 336 | "metadata": {}, 337 | "outputs": [ 338 | { 339 | "name": "stdout", 340 | "output_type": "stream", 341 | "text": [ 342 | "time needed to train the model: 3.46 days\n" 343 | ] 344 | } 345 | ], 346 | "source": [ 347 | "# Finally let's check out the 6ND approximation as total cost of training in FLOPs\n", 348 | "model_size = params()['total'] # this is number of parameters, N\n", 349 | "tokens_num = 300e9 # 300B tokens, this is dataset size in tokens, D\n", 350 | "a100_flops = 312e12 # 312 TFLOPS\n", 351 | "assumed_mfu = 0.3 # assume this model flops utilization (take the current 37% from above and add some DDP overhead)\n", 352 | "flops_throughput = a100_flops * 8 * assumed_mfu # assume an 8XA100 node at 30% utilization\n", 353 | "flops_needed = 6 * model_size * tokens_num # 6ND\n", 354 | "time_needed_s = flops_needed / flops_throughput # in seconds\n", 355 | "print(f\"time needed to train the model: {time_needed_s/3600/24:.2f} days\")" 356 | ] 357 | }, 358 | { 359 | "attachments": {}, 360 | "cell_type": "markdown", 361 | "metadata": {}, 362 | "source": [ 363 | "This is not a bad estimate at all. I trained this model and it converged in roughly 4 days. Btw as a good reference for where 6ND comes from and some intuition around it I recommend [Dzmitry's post](https://medium.com/@dzmitrybahdanau/the-flops-calculus-of-language-model-training-3b19c1f025e4)." 364 | ] 365 | }, 366 | { 367 | "attachments": {}, 368 | "cell_type": "markdown", 369 | "metadata": {}, 370 | "source": [ 371 | "Now, FLOPs are just one constraint, the other that we have to keep a close track of is the memory bandwidth. TODO estimate LOAD/STORE costs of our model later." 372 | ] 373 | } 374 | ], 375 | "metadata": { 376 | "kernelspec": { 377 | "display_name": "pytorch2", 378 | "language": "python", 379 | "name": "python3" 380 | }, 381 | "language_info": { 382 | "codemirror_mode": { 383 | "name": "ipython", 384 | "version": 3 385 | }, 386 | "file_extension": ".py", 387 | "mimetype": "text/x-python", 388 | "name": "python", 389 | "nbconvert_exporter": "python", 390 | "pygments_lexer": "ipython3", 391 | "version": "3.10.8" 392 | }, 393 | "orig_nbformat": 4, 394 | "vscode": { 395 | "interpreter": { 396 | "hash": "7f5833218766b48e6e35e4452ee875aac0e2188d05bbe5298f2c62b79f08b222" 397 | } 398 | } 399 | }, 400 | "nbformat": 4, 401 | "nbformat_minor": 2 402 | } 403 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class dotdict(dict): 4 | """dot.notation access to dictionary attributes""" 5 | __getattr__ = dict.get 6 | __setattr__ = dict.__setitem__ 7 | __delattr__ = dict.__delitem__ 8 | 9 | --------------------------------------------------------------------------------
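A short note on the `dotdict` helper in `utils.py` above: it gives a plain dictionary attribute-style access, i.e. the same `config.batch_size`-style reads the trainer code uses on its `config` object. Whether the YAML configs are actually wrapped in `dotdict` or loaded through another library is not shown here, so treat the sketch below as illustrative only; the keys and values are made up. One behaviour worth knowing: because `__getattr__` is bound to `dict.get`, reading a missing key silently returns `None` instead of raising `AttributeError`, so typos in key names fail silently.

```
from utils import dotdict

cfg = dotdict({'batch_size': 12, 'block_size': 64, 'device': 'cpu'})
print(cfg.batch_size)   # 12, attribute read routed through dict.get
cfg.device = 'cuda'     # attribute write, stored as cfg['device']
del cfg.block_size      # attribute delete, removes the key
print(cfg.block_size)   # None: missing keys return None rather than raising AttributeError
```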