├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── assets ├── gpt2_124M_loss.png └── nanogpt.jpg ├── bench.py ├── config ├── eval_gpt2.py ├── eval_gpt2_large.py ├── eval_gpt2_medium.py ├── eval_gpt2_xl.py ├── finetune_shakespeare.py ├── train_gpt2.py └── train_shakespeare_char.py ├── configurator.py ├── data ├── openwebtext │ ├── prepare.py │ └── readme.md ├── shakespeare │ ├── prepare.py │ └── readme.md └── shakespeare_char │ ├── prepare.py │ └── readme.md ├── model.py ├── sample.py ├── scaling_laws.ipynb ├── train.py └── transformer_sizing.ipynb /.gitattributes: -------------------------------------------------------------------------------- 1 | # Override jupyter in Github language stats for more accurate estimate of repo code languages 2 | # reference: https://github.com/github/linguist/blob/master/docs/overrides.md#generated-code 3 | *.ipynb linguist-generated 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | .idea 3 | .ipynb_checkpoints/ 4 | .vscode 5 | __pycache__/ 6 | *.bin 7 | *.pkl 8 | *.pt 9 | *.pyc 10 | input.txt 11 | env/ 12 | venv/ -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Andrej Karpathy 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in 
all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # nanoGPT 3 | 4 | ![nanoGPT](assets/nanogpt.jpg) 5 | 6 | The simplest, fastest repository for training/finetuning medium-sized GPTs. It is a rewrite of [minGPT](https://github.com/karpathy/minGPT) that prioritizes teeth over education. Still under active development, but currently the file `train.py` reproduces GPT-2 (124M) on OpenWebText, running on a single 8XA100 40GB node in about 4 days of training. The code itself is plain and readable: `train.py` is a ~300-line boilerplate training loop and `model.py` a ~300-line GPT model definition, which can optionally load the GPT-2 weights from OpenAI. That's it. 7 | 8 | ![repro124m](assets/gpt2_124M_loss.png) 9 | 10 | Because the code is so simple, it is very easy to hack to your needs, train new models from scratch, or finetune pretrained checkpoints (e.g. the biggest one currently available as a starting point would be the 1.5B-parameter GPT-2 XL model from OpenAI). 
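The sizes quoted for the OpenAI checkpoints can be sanity-checked from their layer counts and widths. The snippet below is a rough sketch, not code from this repo: `gpt2_param_count` is a hypothetical helper, the `n_layer` / `n_embd` pairs are the ones listed in the comments of the `config/eval_gpt2*.py` files, and it assumes the standard GPT-2 vocabulary (50257 tokens), a 1024-token context, biases everywhere, and an output head that shares its weights with the token embedding.

```python
# Hypothetical helper (not part of the repository): estimate GPT-2 parameter
# counts from depth and width, assuming tied input/output embeddings.
GPT2_VOCAB_SIZE = 50257  # assumption: standard GPT-2 BPE vocabulary
GPT2_BLOCK_SIZE = 1024   # assumption: standard GPT-2 context length


def gpt2_param_count(n_layer, n_embd,
                     vocab_size=GPT2_VOCAB_SIZE, block_size=GPT2_BLOCK_SIZE):
    d = n_embd
    # token embedding + learned position embedding
    embeddings = vocab_size * d + block_size * d
    # per block: attention (3d^2 + d^2) and MLP (4d^2 + 4d^2) weights = 12 d^2,
    # plus biases (9d) and two LayerNorms (4d) = 13 d
    per_block = 12 * d * d + 13 * d
    final_layernorm = 2 * d
    return embeddings + n_layer * per_block + final_layernorm


# (n_layer, n_embd) as listed in the config/eval_gpt2*.py comments
SIZES = {
    "gpt2": (12, 768),
    "gpt2-medium": (24, 1024),
    "gpt2-large": (36, 1280),
    "gpt2-xl": (48, 1600),
}

for name, (n_layer, n_embd) in SIZES.items():
    millions = gpt2_param_count(n_layer, n_embd) / 1e6
    print(f"{name:12s} {millions:8.1f}M parameters")
```

Under these assumptions the four checkpoints come out at about 124M, 355M, 774M and 1558M parameters, so the largest starting point is roughly a 1.5B model.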
11 | 12 | ## install 13 | 14 | ``` 15 | pip install torch numpy transformers datasets tiktoken wandb tqdm 16 | ``` 17 | 18 | Dependencies: 19 | 20 | - [pytorch](https://pytorch.org) <3 21 | - [numpy](https://numpy.org/install/) <3 22 | - `transformers` for huggingface transformers <3 (to load GPT-2 checkpoints) 23 | - `datasets` for huggingface datasets <3 (if you want to download + preprocess OpenWebText) 24 | - `tiktoken` for OpenAI's fast BPE code <3 25 | - `wandb` for optional logging <3 26 | - `tqdm` for progress bars <3 27 | 28 | ## quick start 29 | 30 | If you are not a deep learning professional and you just want to feel the magic and get your feet wet, the fastest way to get started is to train a character-level GPT on the works of Shakespeare. First, we download it as a single (1MB) file and turn it from raw text into one large stream of integers: 31 | 32 | ```sh 33 | python data/shakespeare_char/prepare.py 34 | ``` 35 | 36 | This creates a `train.bin` and `val.bin` in that data directory. Now it is time to train your GPT. The size of it very much depends on the computational resources of your system: 37 | 38 | **I have a GPU**. Great, we can quickly train a baby GPT with the settings provided in the [config/train_shakespeare_char.py](config/train_shakespeare_char.py) config file: 39 | 40 | ```sh 41 | python train.py config/train_shakespeare_char.py 42 | ``` 43 | 44 | If you peek inside it, you'll see that we're training a GPT with a context size of up to 256 characters, 384 feature channels, and it is a 6-layer Transformer with 6 heads in each layer. On one A100 GPU this training run takes about 3 minutes and the best validation loss is 1.4697. Based on the configuration, the model checkpoints are being written into the `--out_dir` directory `out-shakespeare-char`. 
So once the training finishes we can sample from the best model by pointing the sampling script at this directory: 45 | 46 | ```sh 47 | python sample.py --out_dir=out-shakespeare-char 48 | ``` 49 | 50 | This generates a few samples, for example: 51 | 52 | ``` 53 | ANGELO: 54 | And cowards it be strawn to my bed, 55 | And thrust the gates of my threats, 56 | Because he that ale away, and hang'd 57 | An one with him. 58 | 59 | DUKE VINCENTIO: 60 | I thank your eyes against it. 61 | 62 | DUKE VINCENTIO: 63 | Then will answer him to save the malm: 64 | And what have you tyrannous shall do this? 65 | 66 | DUKE VINCENTIO: 67 | If you have done evils of all disposition 68 | To end his power, the day of thrust for a common men 69 | That I leave, to fight with over-liking 70 | Hasting in a roseman. 71 | ``` 72 | 73 | lol `¯\_(ツ)_/¯`. Not bad for a character-level model after 3 minutes of training on a GPU. Better results are quite likely obtainable by instead finetuning a pretrained GPT-2 model on this dataset (see finetuning section later). 74 | 75 | **I only have a macbook** (or other cheap computer). No worries, we can still train a GPT but we want to dial things down a notch. I recommend getting the bleeding edge PyTorch nightly ([select it here](https://pytorch.org/get-started/locally/) when installing) as it is currently quite likely to make your code more efficient. But even without it, a simple train run could look as follows: 76 | 77 | ```sh 78 | python train.py config/train_shakespeare_char.py --device=cpu --compile=False --eval_iters=20 --log_interval=1 --block_size=64 --batch_size=12 --n_layer=4 --n_head=4 --n_embd=128 --max_iters=2000 --lr_decay_iters=2000 --dropout=0.0 79 | ``` 80 | 81 | Here, since we are running on CPU instead of GPU we must set both `--device=cpu` and also turn off PyTorch 2.0 compile with `--compile=False`. 
Then when we evaluate we get a bit more noisy but faster estimate (`--eval_iters=20`, down from 200), our context size is only 64 characters instead of 256, and the batch size only 12 examples per iteration, not 64. We'll also use a much smaller Transformer (4 layers, 4 heads, 128 embedding size), and decrease the number of iterations to 2000 (and correspondingly usually decay the learning rate to around max_iters with `--lr_decay_iters`). Because our network is so small we also ease down on regularization (`--dropout=0.0`). This still runs in about ~3 minutes, but gets us a loss of only 1.88 and therefore also worse samples, but it's still good fun: 82 | 83 | ```sh 84 | python sample.py --out_dir=out-shakespeare-char --device=cpu 85 | ``` 86 | Generates samples like this: 87 | 88 | ``` 89 | GLEORKEN VINGHARD III: 90 | Whell's the couse, the came light gacks, 91 | And the for mought you in Aut fries the not high shee 92 | bot thou the sought bechive in that to doth groan you, 93 | No relving thee post mose the wear 94 | ``` 95 | 96 | Not bad for ~3 minutes on a CPU, for a hint of the right character gestalt. If you're willing to wait longer, feel free to tune the hyperparameters, increase the size of the network, the context length (`--block_size`), the length of training, etc. 97 | 98 | Finally, on Apple Silicon Macbooks and with a recent PyTorch version make sure to add `--device=mps` (short for "Metal Performance Shaders"); PyTorch then uses the on-chip GPU that can *significantly* accelerate training (2-3X) and allow you to use larger networks. See [Issue 28](https://github.com/karpathy/nanoGPT/issues/28) for more. 99 | 100 | ## reproducing GPT-2 101 | 102 | A more serious deep learning professional may be more interested in reproducing GPT-2 results. 
So here we go - we first tokenize the dataset, in this case the [OpenWebText](https://openwebtext2.readthedocs.io/en/latest/), an open reproduction of OpenAI's (private) WebText: 103 | 104 | ```sh 105 | python data/openwebtext/prepare.py 106 | ``` 107 | 108 | This downloads and tokenizes the [OpenWebText](https://huggingface.co/datasets/openwebtext) dataset. It will create a `train.bin` and `val.bin` which holds the GPT2 BPE token ids in one sequence, stored as raw uint16 bytes. Then we're ready to kick off training. To reproduce GPT-2 (124M) you'll want at least an 8X A100 40GB node and run: 109 | 110 | ```sh 111 | torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py 112 | ``` 113 | 114 | This will run for about 4 days using PyTorch Distributed Data Parallel (DDP) and go down to loss of ~2.85. Now, a GPT-2 model just evaluated on OWT gets a val loss of about 3.11, but if you finetune it it will come down to ~2.85 territory (due to an apparent domain gap), making the two models ~match. 115 | 116 | If you're in a cluster environment and you are blessed with multiple GPU nodes you can make GPU go brrrr e.g. across 2 nodes like: 117 | 118 | ```sh 119 | # Run on the first (master) node with example IP 123.456.123.456: 120 | torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py 121 | # Run on the worker node: 122 | torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py 123 | ``` 124 | 125 | It is a good idea to benchmark your interconnect (e.g. iperf3). In particular, if you don't have Infiniband then also prepend `NCCL_IB_DISABLE=1` to the above launches. Your multinode training will work, but most likely _crawl_. By default checkpoints are periodically written to the `--out_dir`. We can sample from the model by simply `python sample.py`. 126 | 127 | Finally, to train on a single GPU simply run the `python train.py` script. 
Have a look at all of its args, the script tries to be very readable, hackable and transparent. You'll most likely want to tune a number of those variables depending on your needs. 128 | 129 | ## baselines 130 | 131 | OpenAI GPT-2 checkpoints allow us to get some baselines in place for openwebtext. We can get the numbers as follows: 132 | 133 | ```sh 134 | $ python train.py config/eval_gpt2.py 135 | $ python train.py config/eval_gpt2_medium.py 136 | $ python train.py config/eval_gpt2_large.py 137 | $ python train.py config/eval_gpt2_xl.py 138 | ``` 139 | 140 | and observe the following losses on train and val: 141 | 142 | | model | params | train loss | val loss | 143 | | ------| ------ | ---------- | -------- | 144 | | gpt2 | 124M | 3.11 | 3.12 | 145 | | gpt2-medium | 350M | 2.85 | 2.84 | 146 | | gpt2-large | 774M | 2.66 | 2.67 | 147 | | gpt2-xl | 1558M | 2.56 | 2.54 | 148 | 149 | However, we have to note that GPT-2 was trained on (closed, never released) WebText, while OpenWebText is just a best-effort open reproduction of this dataset. This means there is a dataset domain gap. Indeed, taking the GPT-2 (124M) checkpoint and finetuning on OWT directly for a while reaches loss down to ~2.85. This then becomes the more appropriate baseline w.r.t. reproduction. 150 | 151 | ## finetuning 152 | 153 | Finetuning is no different than training, we just make sure to initialize from a pretrained model and train with a smaller learning rate. For an example of how to finetune a GPT on new text go to `data/shakespeare` and run `prepare.py` to download the tiny shakespeare dataset and render it into a `train.bin` and `val.bin`, using the OpenAI BPE tokenizer from GPT-2. Unlike OpenWebText this will run in seconds. Finetuning can take very little time, e.g. on a single GPU just a few minutes. 
Run an example finetuning like: 154 | 155 | ```sh 156 | python train.py config/finetune_shakespeare.py 157 | ``` 158 | 159 | This will load the config parameter overrides in `config/finetune_shakespeare.py` (I didn't tune them much though). Basically, we initialize from a GPT2 checkpoint with `init_from` and train as normal, except shorter and with a small learning rate. If you're running out of memory try decreasing the model size (they are `{'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'}`) or possibly decreasing the `block_size` (context length). The best checkpoint (lowest validation loss) will be in the `out_dir` directory, e.g. in `out-shakespeare` by default, per the config file. You can then run the code in `sample.py --out_dir=out-shakespeare`: 160 | 161 | ``` 162 | THEODORE: 163 | Thou shalt sell me to the highest bidder: if I die, 164 | I sell thee to the first; if I go mad, 165 | I sell thee to the second; if I 166 | lie, I sell thee to the third; if I slay, 167 | I sell thee to the fourth: so buy or sell, 168 | I tell thee again, thou shalt not sell my 169 | possession. 170 | 171 | JULIET: 172 | And if thou steal, thou shalt not sell thyself. 173 | 174 | THEODORE: 175 | I do not steal; I sell the stolen goods. 176 | 177 | THEODORE: 178 | Thou know'st not what thou sell'st; thou, a woman, 179 | Thou art ever a victim, a thing of no worth: 180 | Thou hast no right, no right, but to be sold. 181 | ``` 182 | 183 | Whoa there, GPT, entering some dark place over there. I didn't really tune the hyperparameters in the config too much, feel free to try! 184 | 185 | ## sampling / inference 186 | 187 | Use the script `sample.py` to sample either from pre-trained GPT-2 models released by OpenAI, or from a model you trained yourself. For example, here is a way to sample from the largest available `gpt2-xl` model: 188 | 189 | ```sh 190 | python sample.py \ 191 | --init_from=gpt2-xl \ 192 | --start="What is the answer to life, the universe, and everything?" 
\ 193 | --num_samples=5 --max_new_tokens=100 194 | ``` 195 | 196 | If you'd like to sample from a model you trained, use the `--out_dir` to point the code appropriately. You can also prompt the model with some text from a file, e.g. ```python sample.py --start=FILE:prompt.txt```. 197 | 198 | ## efficiency notes 199 | 200 | For simple model benchmarking and profiling, `bench.py` might be useful. It's identical to what happens in the meat of the training loop of `train.py`, but omits much of the other complexities. 201 | 202 | Note that the code by default uses [PyTorch 2.0](https://pytorch.org/get-started/pytorch-2.0/). At the time of writing (Dec 29, 2022) this makes `torch.compile()` available in the nightly release. The improvement from the one line of code is noticeable, e.g. cutting down iteration time from ~250ms / iter to 135ms / iter. Nice work PyTorch team! 203 | 204 | ## todos 205 | 206 | - Investigate and add FSDP instead of DDP 207 | - Eval zero-shot perplexities on standard evals (e.g. LAMBADA? HELM? etc.) 208 | - Finetune the finetuning script, I think the hyperparams are not great 209 | - Schedule for linear batch size increase during training 210 | - Incorporate other embeddings (rotary, alibi) 211 | - Separate out the optim buffers from model params in checkpoints I think 212 | - Additional logging around network health (e.g. gradient clip events, magnitudes) 213 | - Few more investigations around better init etc. 214 | 215 | ## troubleshooting 216 | 217 | Note that by default this repo uses PyTorch 2.0 (i.e. `torch.compile`). This is fairly new and experimental, and not yet available on all platforms (e.g. Windows). If you're running into related error messages try to disable this by adding `--compile=False` flag. This will slow down the code but at least it will run. 218 | 219 | For some context on this repository, GPT, and language modeling it might be helpful to watch my [Zero To Hero series](https://karpathy.ai/zero-to-hero.html). 
Specifically, the [GPT video](https://www.youtube.com/watch?v=kCc8FmEb1nY) is popular if you have some prior language modeling context. 220 | 221 | For more questions/discussions feel free to stop by **#nanoGPT** on Discord: 222 | 223 | [![](https://dcbadge.vercel.app/api/server/3zy8kqD9Cp?compact=true&style=flat)](https://discord.gg/3zy8kqD9Cp) 224 | 225 | ## acknowledgements 226 | 227 | All nanoGPT experiments are powered by GPUs on [Lambda labs](https://lambdalabs.com), my favorite Cloud GPU provider. Thank you Lambda labs for sponsoring nanoGPT! 228 | -------------------------------------------------------------------------------- /assets/gpt2_124M_loss.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/karpathy/nanoGPT/93a43d9a5c22450bbf06e78da2cb6eeef084b717/assets/gpt2_124M_loss.png -------------------------------------------------------------------------------- /assets/nanogpt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/karpathy/nanoGPT/93a43d9a5c22450bbf06e78da2cb6eeef084b717/assets/nanogpt.jpg -------------------------------------------------------------------------------- /bench.py: -------------------------------------------------------------------------------- 1 | """ 2 | A much shorter version of train.py for benchmarking 3 | """ 4 | import os 5 | from contextlib import nullcontext 6 | import numpy as np 7 | import time 8 | import torch 9 | from model import GPTConfig, GPT 10 | 11 | # ----------------------------------------------------------------------------- 12 | batch_size = 12 13 | block_size = 1024 14 | bias = False 15 | real_data = True 16 | seed = 1337 17 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc. 
18 | dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16' 19 | compile = True # use PyTorch 2.0 to compile the model to be faster 20 | profile = False # use pytorch profiler, or just simple benchmarking? 21 | exec(open('configurator.py').read()) # overrides from command line or config file 22 | # ----------------------------------------------------------------------------- 23 | 24 | torch.manual_seed(seed) 25 | torch.cuda.manual_seed(seed) 26 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 27 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 28 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast 29 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 30 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype) 31 | 32 | # data loading init 33 | if real_data: 34 | dataset = 'openwebtext' 35 | data_dir = os.path.join('data', dataset) 36 | train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r') 37 | def get_batch(split): 38 | data = train_data # note ignore split in benchmarking script 39 | ix = torch.randint(len(data) - block_size, (batch_size,)) 40 | x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) 41 | y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix]) 42 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True) 43 | return x, y 44 | else: 45 | # alternatively, if fixed data is desired to not care about data loading 46 | x = torch.randint(50304, (batch_size, block_size), device=device) 47 | y = torch.randint(50304, (batch_size, block_size), device=device) 48 | get_batch = lambda split: (x, y) 49 | 50 | # model init 51 | gptconf = GPTConfig( 52 | block_size = 
block_size, # how far back does the model look? i.e. context size 53 | n_layer = 12, n_head = 12, n_embd = 768, # size of the model 54 | dropout = 0, # for determinism 55 | bias = bias, 56 | ) 57 | model = GPT(gptconf) 58 | model.to(device) 59 | 60 | optimizer = model.configure_optimizers(weight_decay=1e-2, learning_rate=1e-4, betas=(0.9, 0.95), device_type=device_type) 61 | 62 | if compile: 63 | print("Compiling model...") 64 | model = torch.compile(model) # pytorch 2.0 65 | 66 | if profile: 67 | # useful docs on pytorch profiler: 68 | # - tutorial https://pytorch.org/tutorials/intermediate/tensorboard_profiler_tutorial.html 69 | # - api https://pytorch.org/docs/stable/profiler.html#torch.profiler.profile 70 | wait, warmup, active = 5, 5, 5 71 | num_steps = wait + warmup + active 72 | with torch.profiler.profile( 73 | activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], 74 | schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active, repeat=1), 75 | on_trace_ready=torch.profiler.tensorboard_trace_handler('./bench_log'), 76 | record_shapes=False, 77 | profile_memory=False, 78 | with_stack=False, # incurs an additional overhead, disable if not needed 79 | with_flops=True, 80 | with_modules=False, # only for torchscript models atm 81 | ) as prof: 82 | 83 | X, Y = get_batch('train') 84 | for k in range(num_steps): 85 | with ctx: 86 | logits, loss = model(X, Y) 87 | X, Y = get_batch('train') 88 | optimizer.zero_grad(set_to_none=True) 89 | loss.backward() 90 | optimizer.step() 91 | lossf = loss.item() 92 | print(f"{k}/{num_steps} loss: {lossf:.4f}") 93 | 94 | prof.step() # notify the profiler at end of each step 95 | 96 | else: 97 | 98 | # simple benchmarking 99 | torch.cuda.synchronize() 100 | for stage, num_steps in enumerate([10, 20]): # burnin, then benchmark 101 | t0 = time.time() 102 | X, Y = get_batch('train') 103 | for k in range(num_steps): 104 | with ctx: 105 | logits, loss = model(X, Y) 106 | X, Y = 
get_batch('train') 107 | optimizer.zero_grad(set_to_none=True) 108 | loss.backward() 109 | optimizer.step() 110 | lossf = loss.item() 111 | print(f"{k}/{num_steps} loss: {lossf:.4f}") 112 | torch.cuda.synchronize() 113 | t1 = time.time() 114 | dt = t1-t0 115 | mfu = model.estimate_mfu(batch_size * 1 * num_steps, dt) 116 | if stage == 1: 117 | print(f"time per iteration: {dt/num_steps*1000:.4f}ms, MFU: {mfu*100:.2f}%") 118 | -------------------------------------------------------------------------------- /config/eval_gpt2.py: -------------------------------------------------------------------------------- 1 | # evaluate the base gpt2 2 | # n_layer=12, n_head=12, n_embd=768 3 | # 124M parameters 4 | batch_size = 8 5 | eval_iters = 500 # use more iterations to get good estimate 6 | eval_only = True 7 | wandb_log = False 8 | init_from = 'gpt2' 9 | -------------------------------------------------------------------------------- /config/eval_gpt2_large.py: -------------------------------------------------------------------------------- 1 | # evaluate the base gpt2 2 | # n_layer=36, n_head=20, n_embd=1280 3 | # 774M parameters 4 | batch_size = 8 5 | eval_iters = 500 # use more iterations to get good estimate 6 | eval_only = True 7 | wandb_log = False 8 | init_from = 'gpt2-large' 9 | -------------------------------------------------------------------------------- /config/eval_gpt2_medium.py: -------------------------------------------------------------------------------- 1 | # evaluate the base gpt2 2 | # n_layer=24, n_head=16, n_embd=1024 3 | # 350M parameters 4 | batch_size = 8 5 | eval_iters = 500 # use more iterations to get good estimate 6 | eval_only = True 7 | wandb_log = False 8 | init_from = 'gpt2-medium' 9 | -------------------------------------------------------------------------------- /config/eval_gpt2_xl.py: -------------------------------------------------------------------------------- 1 | # evaluate the base gpt2 2 | # n_layer=48, n_head=25, n_embd=1600 3 
| # 1558M parameters 4 | batch_size = 8 5 | eval_iters = 500 # use more iterations to get good estimate 6 | eval_only = True 7 | wandb_log = False 8 | init_from = 'gpt2-xl' 9 | -------------------------------------------------------------------------------- /config/finetune_shakespeare.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | out_dir = 'out-shakespeare' 4 | eval_interval = 5 5 | eval_iters = 40 6 | wandb_log = False # feel free to turn on 7 | wandb_project = 'shakespeare' 8 | wandb_run_name = 'ft-' + str(time.time()) 9 | 10 | dataset = 'shakespeare' 11 | init_from = 'gpt2-xl' # this is the largest GPT-2 model 12 | 13 | # only save checkpoints if the validation loss improves 14 | always_save_checkpoint = False 15 | 16 | # the number of examples per iter: 17 | # 1 batch_size * 32 grad_accum * 1024 tokens = 32,768 tokens/iter 18 | # shakespeare has 301,966 tokens, so 1 epoch ~= 9.2 iters 19 | batch_size = 1 20 | gradient_accumulation_steps = 32 21 | max_iters = 20 22 | 23 | # finetune at constant LR 24 | learning_rate = 3e-5 25 | decay_lr = False 26 | -------------------------------------------------------------------------------- /config/train_gpt2.py: -------------------------------------------------------------------------------- 1 | # config for training GPT-2 (124M) down to very nice loss of ~2.85 on 1 node of 8X A100 40GB 2 | # launch as the following (e.g. 
in a screen session) and wait ~5 days: 3 | # $ torchrun --standalone --nproc_per_node=8 train.py config/train_gpt2.py 4 | 5 | wandb_log = True 6 | wandb_project = 'owt' 7 | wandb_run_name='gpt2-124M' 8 | 9 | # these make the total batch size be ~0.5M 10 | # 12 batch size * 1024 block size * 5 gradaccum * 8 GPUs = 491,520 11 | batch_size = 12 12 | block_size = 1024 13 | gradient_accumulation_steps = 5 * 8 14 | 15 | # this makes total number of tokens be 300B 16 | max_iters = 600000 17 | lr_decay_iters = 600000 18 | 19 | # eval stuff 20 | eval_interval = 1000 21 | eval_iters = 200 22 | log_interval = 10 23 | 24 | # weight decay 25 | weight_decay = 1e-1 26 | -------------------------------------------------------------------------------- /config/train_shakespeare_char.py: -------------------------------------------------------------------------------- 1 | # train a miniature character-level shakespeare model 2 | # good for debugging and playing on macbooks and such 3 | 4 | out_dir = 'out-shakespeare-char' 5 | eval_interval = 250 # keep frequent because we'll overfit 6 | eval_iters = 200 7 | log_interval = 10 # don't print too too often 8 | 9 | # we expect to overfit on this small dataset, so only save when val improves 10 | always_save_checkpoint = False 11 | 12 | wandb_log = False # override via command line if you like 13 | wandb_project = 'shakespeare-char' 14 | wandb_run_name = 'mini-gpt' 15 | 16 | dataset = 'shakespeare_char' 17 | gradient_accumulation_steps = 1 18 | batch_size = 64 19 | block_size = 256 # context of up to 256 previous characters 20 | 21 | # baby GPT model :) 22 | n_layer = 6 23 | n_head = 6 24 | n_embd = 384 25 | dropout = 0.2 26 | 27 | learning_rate = 1e-3 # with baby networks can afford to go a bit higher 28 | max_iters = 5000 29 | lr_decay_iters = 5000 # make equal to max_iters usually 30 | min_lr = 1e-4 # learning_rate / 10 usually 31 | beta2 = 0.99 # make a bit bigger because number of tokens per iter is small 32 | 33 | warmup_iters = 100 # 
not super necessary potentially 34 | 35 | # on macbook also add 36 | # device = 'cpu' # run on cpu only 37 | # compile = False # do not torch compile the model 38 | -------------------------------------------------------------------------------- /configurator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Poor Man's Configurator. Probably a terrible idea. Example usage: 3 | $ python train.py config/override_file.py --batch_size=32 4 | this will first run config/override_file.py, then override batch_size to 32 5 | 6 | The code in this file will be run as follows from e.g. train.py: 7 | >>> exec(open('configurator.py').read()) 8 | 9 | So it's not a Python module, it's just shuttling this code away from train.py 10 | The code in this script then overrides the globals() 11 | 12 | I know people are not going to love this, I just really dislike configuration 13 | complexity and having to prepend config. to every single variable. If someone 14 | comes up with a better simple Python solution I am all ears. 15 | """ 16 | 17 | import sys 18 | from ast import literal_eval 19 | 20 | for arg in sys.argv[1:]: 21 | if '=' not in arg: 22 | # assume it's the name of a config file 23 | assert not arg.startswith('--') 24 | config_file = arg 25 | print(f"Overriding config with {config_file}:") 26 | with open(config_file) as f: 27 | print(f.read()) 28 | exec(open(config_file).read()) 29 | else: 30 | # assume it's a --key=value argument 31 | assert arg.startswith('--') 32 | key, val = arg.split('=', 1) # split on the first '=' only, so values may contain '=' 33 | key = key[2:] 34 | if key in globals(): 35 | try: 36 | # attempt to eval it (e.g. 
if bool, number, or etc) 37 | attempt = literal_eval(val) 38 | except (SyntaxError, ValueError): 39 | # if that goes wrong, just use the string 40 | attempt = val 41 | # ensure the types match ok 42 | assert type(attempt) == type(globals()[key]) 43 | # cross fingers 44 | print(f"Overriding: {key} = {attempt}") 45 | globals()[key] = attempt 46 | else: 47 | raise ValueError(f"Unknown config key: {key}") 48 | -------------------------------------------------------------------------------- /data/openwebtext/prepare.py: -------------------------------------------------------------------------------- 1 | # saves the openwebtext dataset to a binary file for training. following was helpful: 2 | # https://github.com/HazyResearch/flash-attention/blob/main/training/src/datamodules/language_modeling_hf.py 3 | 4 | import os 5 | from tqdm import tqdm 6 | import numpy as np 7 | import tiktoken 8 | from datasets import load_dataset # huggingface datasets 9 | 10 | # number of workers in .map() call 11 | # good number to use is ~order number of cpu cores // 2 12 | num_proc = 8 13 | 14 | # number of workers in load_dataset() call 15 | # best number might be different from num_proc above as it also depends on NW speed. 
16 | # it is better than 1 usually though 17 | num_proc_load_dataset = num_proc 18 | 19 | enc = tiktoken.get_encoding("gpt2") 20 | 21 | if __name__ == '__main__': 22 | # takes 54GB in huggingface .cache dir, about 8M documents (8,013,769) 23 | dataset = load_dataset("openwebtext", num_proc=num_proc_load_dataset) 24 | 25 | # owt by default only contains the 'train' split, so create a test split 26 | split_dataset = dataset["train"].train_test_split(test_size=0.0005, seed=2357, shuffle=True) 27 | split_dataset['val'] = split_dataset.pop('test') # rename the test split to val 28 | 29 | # this results in: 30 | # >>> split_dataset 31 | # DatasetDict({ 32 | # train: Dataset({ 33 | # features: ['text'], 34 | # num_rows: 8009762 35 | # }) 36 | # val: Dataset({ 37 | # features: ['text'], 38 | # num_rows: 4007 39 | # }) 40 | # }) 41 | 42 | # we now want to tokenize the dataset. first define the encoding function (gpt2 bpe) 43 | def process(example): 44 | ids = enc.encode_ordinary(example['text']) # encode_ordinary ignores any special tokens 45 | ids.append(enc.eot_token) # add the end of text token, e.g. 50256 for gpt2 bpe 46 | # note: I think eot should be prepended not appended... hmm. it's called "eot" though... 
47 | out = {'ids': ids, 'len': len(ids)} 48 | return out 49 | 50 | # tokenize the dataset 51 | tokenized = split_dataset.map( 52 | process, 53 | remove_columns=['text'], 54 | desc="tokenizing the splits", 55 | num_proc=num_proc, 56 | ) 57 | 58 | # concatenate all the ids in each dataset into one large file we can use for training 59 | for split, dset in tokenized.items(): 60 | arr_len = np.sum(dset['len'], dtype=np.uint64) 61 | filename = os.path.join(os.path.dirname(__file__), f'{split}.bin') 62 | dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) 63 | arr = np.memmap(filename, dtype=dtype, mode='w+', shape=(arr_len,)) 64 | total_batches = 1024 65 | 66 | idx = 0 67 | for batch_idx in tqdm(range(total_batches), desc=f'writing {filename}'): 68 | # Batch together samples for faster write 69 | batch = dset.shard(num_shards=total_batches, index=batch_idx, contiguous=True).with_format('numpy') 70 | arr_batch = np.concatenate(batch['ids']) 71 | # Write into mmap 72 | arr[idx : idx + len(arr_batch)] = arr_batch 73 | idx += len(arr_batch) 74 | arr.flush() 75 | 76 | # train.bin is ~17GB, val.bin ~8.5MB 77 | # train has ~9B tokens (9,035,582,198) 78 | # val has ~4M tokens (4,434,897) 79 | 80 | # to read the bin files later, e.g. with numpy: 81 | # m = np.memmap('train.bin', dtype=np.uint16, mode='r') 82 | -------------------------------------------------------------------------------- /data/openwebtext/readme.md: -------------------------------------------------------------------------------- 1 | 2 | ## openwebtext dataset 3 | 4 | after running `prepare.py` (preprocess) we get: 5 | 6 | - train.bin is ~17GB, val.bin ~8.5MB 7 | - train has ~9B tokens (9,035,582,198) 8 | - val has ~4M tokens (4,434,897) 9 | 10 | this came from 8,013,769 documents in total. 
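a minimal sketch of reading one of these files back, assuming the flat uint16 layout that `prepare.py` writes. the helper names `load_tokens` and `demo` and the toy file are made up for illustration, so the sketch runs without the full train.bin:

```python
import os
import tempfile

import numpy as np


def load_tokens(path):
    # hypothetical helper: map the flat uint16 token file without loading it all into RAM
    return np.memmap(path, dtype=np.uint16, mode='r')


def demo():
    # write a tiny stand-in for train.bin (three made-up token ids, the last being gpt2 eot)
    tmp_dir = tempfile.mkdtemp()
    path = os.path.join(tmp_dir, 'toy.bin')
    np.array([15496, 995, 50256], dtype=np.uint16).tofile(path)
    tokens = load_tokens(path)
    # number of tokens is just the array length; each token occupies 2 bytes on disk
    return len(tokens), int(tokens[-1]), os.path.getsize(path)
```

each token is stored in 2 bytes, so 9,035,582,198 tokens come to about 18.07e9 bytes, roughly 16.8 GiB, which is in line with the ~17GB figure above.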
11 | 12 | references: 13 | 14 | - OpenAI's WebText dataset is discussed in [GPT-2 paper](https://d4mucfpksywv.cloudfront.net/better-language-models/language_models_are_unsupervised_multitask_learners.pdf) 15 | - [OpenWebText](https://skylion007.github.io/OpenWebTextCorpus/) dataset 16 | -------------------------------------------------------------------------------- /data/shakespeare/prepare.py: -------------------------------------------------------------------------------- 1 | import os 2 | import requests 3 | import tiktoken 4 | import numpy as np 5 | 6 | # download the tiny shakespeare dataset 7 | input_file_path = os.path.join(os.path.dirname(__file__), 'input.txt') 8 | if not os.path.exists(input_file_path): 9 | data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt' 10 | with open(input_file_path, 'w', encoding='utf-8') as f: 11 | f.write(requests.get(data_url).text) 12 | 13 | with open(input_file_path, 'r', encoding='utf-8') as f: 14 | data = f.read() 15 | n = len(data) 16 | train_data = data[:int(n*0.9)] 17 | val_data = data[int(n*0.9):] 18 | 19 | # encode with tiktoken gpt2 bpe 20 | enc = tiktoken.get_encoding("gpt2") 21 | train_ids = enc.encode_ordinary(train_data) 22 | val_ids = enc.encode_ordinary(val_data) 23 | print(f"train has {len(train_ids):,} tokens") 24 | print(f"val has {len(val_ids):,} tokens") 25 | 26 | # export to bin files 27 | train_ids = np.array(train_ids, dtype=np.uint16) 28 | val_ids = np.array(val_ids, dtype=np.uint16) 29 | train_ids.tofile(os.path.join(os.path.dirname(__file__), 'train.bin')) 30 | val_ids.tofile(os.path.join(os.path.dirname(__file__), 'val.bin')) 31 | 32 | # train.bin has 301,966 tokens 33 | # val.bin has 36,059 tokens 34 | -------------------------------------------------------------------------------- /data/shakespeare/readme.md: -------------------------------------------------------------------------------- 1 | 2 | # tiny shakespeare 3 | 4 | Tiny shakespeare, of the 
good old char-rnn fame :) 5 | 6 | After running `prepare.py`: 7 | 8 | - train.bin has 301,966 tokens 9 | - val.bin has 36,059 tokens 10 | -------------------------------------------------------------------------------- /data/shakespeare_char/prepare.py: -------------------------------------------------------------------------------- 1 | """ 2 | Prepare the Shakespeare dataset for character-level language modeling. 3 | So instead of encoding with GPT-2 BPE tokens, we just map characters to ints. 4 | Will save train.bin, val.bin containing the ids, and meta.pkl containing the 5 | encoder and decoder and some other related info. 6 | """ 7 | import os 8 | import pickle 9 | import requests 10 | import numpy as np 11 | 12 | # download the tiny shakespeare dataset 13 | input_file_path = os.path.join(os.path.dirname(__file__), 'input.txt') 14 | if not os.path.exists(input_file_path): 15 | data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt' 16 | with open(input_file_path, 'w') as f: 17 | f.write(requests.get(data_url).text) 18 | 19 | with open(input_file_path, 'r') as f: 20 | data = f.read() 21 | print(f"length of dataset in characters: {len(data):,}") 22 | 23 | # get all the unique characters that occur in this text 24 | chars = sorted(list(set(data))) 25 | vocab_size = len(chars) 26 | print("all the unique characters:", ''.join(chars)) 27 | print(f"vocab size: {vocab_size:,}") 28 | 29 | # create a mapping from characters to integers 30 | stoi = { ch:i for i,ch in enumerate(chars) } 31 | itos = { i:ch for i,ch in enumerate(chars) } 32 | def encode(s): 33 | return [stoi[c] for c in s] # encoder: take a string, output a list of integers 34 | def decode(l): 35 | return ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string 36 | 37 | # create the train and test splits 38 | n = len(data) 39 | train_data = data[:int(n*0.9)] 40 | val_data = data[int(n*0.9):] 41 | 42 | # encode both to integers 43 | 
train_ids = encode(train_data) 44 | val_ids = encode(val_data) 45 | print(f"train has {len(train_ids):,} tokens") 46 | print(f"val has {len(val_ids):,} tokens") 47 | 48 | # export to bin files 49 | train_ids = np.array(train_ids, dtype=np.uint16) 50 | val_ids = np.array(val_ids, dtype=np.uint16) 51 | train_ids.tofile(os.path.join(os.path.dirname(__file__), 'train.bin')) 52 | val_ids.tofile(os.path.join(os.path.dirname(__file__), 'val.bin')) 53 | 54 | # save the meta information as well, to help us encode/decode later 55 | meta = { 56 | 'vocab_size': vocab_size, 57 | 'itos': itos, 58 | 'stoi': stoi, 59 | } 60 | with open(os.path.join(os.path.dirname(__file__), 'meta.pkl'), 'wb') as f: 61 | pickle.dump(meta, f) 62 | 63 | # length of dataset in characters: 1115394 64 | # all the unique characters: 65 | # !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz 66 | # vocab size: 65 67 | # train has 1003854 tokens 68 | # val has 111540 tokens 69 | -------------------------------------------------------------------------------- /data/shakespeare_char/readme.md: -------------------------------------------------------------------------------- 1 | 2 | # tiny shakespeare, character-level 3 | 4 | Tiny shakespeare, of the good old char-rnn fame :) Treated on character-level. 5 | 6 | After running `prepare.py`: 7 | 8 | - train.bin has 1,003,854 tokens 9 | - val.bin has 111,540 tokens 10 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Full definition of a GPT Language Model, all of it in this single file. 
3 | References: 4 | 1) the official GPT-2 TensorFlow implementation released by OpenAI: 5 | https://github.com/openai/gpt-2/blob/master/src/model.py 6 | 2) huggingface/transformers PyTorch implementation: 7 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py 8 | """ 9 | 10 | import math 11 | import inspect 12 | from dataclasses import dataclass 13 | 14 | import torch 15 | import torch.nn as nn 16 | from torch.nn import functional as F 17 | 18 | class LayerNorm(nn.Module): 19 | """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """ 20 | 21 | def __init__(self, ndim, bias): 22 | super().__init__() 23 | self.weight = nn.Parameter(torch.ones(ndim)) 24 | self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None 25 | 26 | def forward(self, input): 27 | return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5) 28 | 29 | class CausalSelfAttention(nn.Module): 30 | 31 | def __init__(self, config): 32 | super().__init__() 33 | assert config.n_embd % config.n_head == 0 34 | # key, query, value projections for all heads, but in a batch 35 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias) 36 | # output projection 37 | self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias) 38 | # regularization 39 | self.attn_dropout = nn.Dropout(config.dropout) 40 | self.resid_dropout = nn.Dropout(config.dropout) 41 | self.n_head = config.n_head 42 | self.n_embd = config.n_embd 43 | self.dropout = config.dropout 44 | # flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0 45 | self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') 46 | if not self.flash: 47 | print("WARNING: using slow attention. 
Flash Attention requires PyTorch >= 2.0") 48 | # causal mask to ensure that attention is only applied to the left in the input sequence 49 | self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) 50 | .view(1, 1, config.block_size, config.block_size)) 51 | 52 | def forward(self, x): 53 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 54 | 55 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 56 | q, k, v = self.c_attn(x).split(self.n_embd, dim=2) 57 | k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 58 | q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 59 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 60 | 61 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 62 | if self.flash: 63 | # efficient attention using Flash Attention CUDA kernels 64 | y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True) 65 | else: 66 | # manual implementation of attention 67 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 68 | att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) 69 | att = F.softmax(att, dim=-1) 70 | att = self.attn_dropout(att) 71 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 72 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 73 | 74 | # output projection 75 | y = self.resid_dropout(self.c_proj(y)) 76 | return y 77 | 78 | class MLP(nn.Module): 79 | 80 | def __init__(self, config): 81 | super().__init__() 82 | self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias) 83 | self.gelu = nn.GELU() 84 | self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias) 85 | self.dropout = nn.Dropout(config.dropout) 86 | 
87 | def forward(self, x): 88 | x = self.c_fc(x) 89 | x = self.gelu(x) 90 | x = self.c_proj(x) 91 | x = self.dropout(x) 92 | return x 93 | 94 | class Block(nn.Module): 95 | 96 | def __init__(self, config): 97 | super().__init__() 98 | self.ln_1 = LayerNorm(config.n_embd, bias=config.bias) 99 | self.attn = CausalSelfAttention(config) 100 | self.ln_2 = LayerNorm(config.n_embd, bias=config.bias) 101 | self.mlp = MLP(config) 102 | 103 | def forward(self, x): 104 | x = x + self.attn(self.ln_1(x)) 105 | x = x + self.mlp(self.ln_2(x)) 106 | return x 107 | 108 | @dataclass 109 | class GPTConfig: 110 | block_size: int = 1024 111 | vocab_size: int = 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency 112 | n_layer: int = 12 113 | n_head: int = 12 114 | n_embd: int = 768 115 | dropout: float = 0.0 116 | bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster 117 | 118 | class GPT(nn.Module): 119 | 120 | def __init__(self, config): 121 | super().__init__() 122 | assert config.vocab_size is not None 123 | assert config.block_size is not None 124 | self.config = config 125 | 126 | self.transformer = nn.ModuleDict(dict( 127 | wte = nn.Embedding(config.vocab_size, config.n_embd), 128 | wpe = nn.Embedding(config.block_size, config.n_embd), 129 | drop = nn.Dropout(config.dropout), 130 | h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), 131 | ln_f = LayerNorm(config.n_embd, bias=config.bias), 132 | )) 133 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 134 | # with weight tying when using torch.compile() some warnings get generated: 135 | # "UserWarning: functional_call was passed multiple values for tied weights. 136 | # This behavior is deprecated and will be an error in future versions" 137 | # not 100% sure what this is, so far seems to be harmless. 
TODO investigate 138 | self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying 139 | 140 | # init all weights 141 | self.apply(self._init_weights) 142 | # apply special scaled init to the residual projections, per GPT-2 paper 143 | for pn, p in self.named_parameters(): 144 | if pn.endswith('c_proj.weight'): 145 | torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer)) 146 | 147 | # report number of parameters 148 | print("number of parameters: %.2fM" % (self.get_num_params()/1e6,)) 149 | 150 | def get_num_params(self, non_embedding=True): 151 | """ 152 | Return the number of parameters in the model. 153 | For non-embedding count (default), the position embeddings get subtracted. 154 | The token embeddings would too, except due to the parameter sharing these 155 | params are actually used as weights in the final layer, so we include them. 156 | """ 157 | n_params = sum(p.numel() for p in self.parameters()) 158 | if non_embedding: 159 | n_params -= self.transformer.wpe.weight.numel() 160 | return n_params 161 | 162 | def _init_weights(self, module): 163 | if isinstance(module, nn.Linear): 164 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 165 | if module.bias is not None: 166 | torch.nn.init.zeros_(module.bias) 167 | elif isinstance(module, nn.Embedding): 168 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 169 | 170 | def forward(self, idx, targets=None): 171 | device = idx.device 172 | b, t = idx.size() 173 | assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}" 174 | pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t) 175 | 176 | # forward the GPT model itself 177 | tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 178 | pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd) 179 | x = self.transformer.drop(tok_emb + pos_emb) 180 | for block in 
self.transformer.h: 181 | x = block(x) 182 | x = self.transformer.ln_f(x) 183 | 184 | if targets is not None: 185 | # if we are given some desired targets also calculate the loss 186 | logits = self.lm_head(x) 187 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) 188 | else: 189 | # inference-time mini-optimization: only forward the lm_head on the very last position 190 | logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim 191 | loss = None 192 | 193 | return logits, loss 194 | 195 | def crop_block_size(self, block_size): 196 | # model surgery to decrease the block size if necessary 197 | # e.g. we may load the GPT2 pretrained model checkpoint (block size 1024) 198 | # but want to use a smaller block size for some smaller, simpler model 199 | assert block_size <= self.config.block_size 200 | self.config.block_size = block_size 201 | self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:block_size]) 202 | for block in self.transformer.h: 203 | if hasattr(block.attn, 'bias'): 204 | block.attn.bias = block.attn.bias[:,:,:block_size,:block_size] 205 | 206 | @classmethod 207 | def from_pretrained(cls, model_type, override_args=None): 208 | assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'} 209 | override_args = override_args or {} # default to empty dict 210 | # only dropout can be overridden see more notes below 211 | assert all(k == 'dropout' for k in override_args) 212 | from transformers import GPT2LMHeadModel 213 | print("loading weights from pretrained gpt: %s" % model_type) 214 | 215 | # n_layer, n_head and n_embd are determined from model_type 216 | config_args = { 217 | 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params 218 | 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params 219 | 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params 220 | 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M 
params 221 | }[model_type] 222 | print("forcing vocab_size=50257, block_size=1024, bias=True") 223 | config_args['vocab_size'] = 50257 # always 50257 for GPT model checkpoints 224 | config_args['block_size'] = 1024 # always 1024 for GPT model checkpoints 225 | config_args['bias'] = True # always True for GPT model checkpoints 226 | # we can override the dropout rate, if desired 227 | if 'dropout' in override_args: 228 | print(f"overriding dropout rate to {override_args['dropout']}") 229 | config_args['dropout'] = override_args['dropout'] 230 | # create a from-scratch initialized minGPT model 231 | config = GPTConfig(**config_args) 232 | model = GPT(config) 233 | sd = model.state_dict() 234 | sd_keys = sd.keys() 235 | sd_keys = [k for k in sd_keys if not k.endswith('.attn.bias')] # discard this mask / buffer, not a param 236 | 237 | # init a huggingface/transformers model 238 | model_hf = GPT2LMHeadModel.from_pretrained(model_type) 239 | sd_hf = model_hf.state_dict() 240 | 241 | # copy while ensuring all of the parameters are aligned and match in names and shapes 242 | sd_keys_hf = sd_hf.keys() 243 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.masked_bias')] # ignore these, just a buffer 244 | sd_keys_hf = [k for k in sd_keys_hf if not k.endswith('.attn.bias')] # same, just the mask (buffer) 245 | transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight'] 246 | # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla Linear 247 | # this means that we have to transpose these weights when we import them 248 | assert len(sd_keys_hf) == len(sd_keys), f"mismatched keys: {len(sd_keys_hf)} != {len(sd_keys)}" 249 | for k in sd_keys_hf: 250 | if any(k.endswith(w) for w in transposed): 251 | # special treatment for the Conv1D weights we need to transpose 252 | assert sd_hf[k].shape[::-1] == sd[k].shape 253 | with torch.no_grad(): 254 | sd[k].copy_(sd_hf[k].t()) 255 | else: 256 | # 
vanilla copy over the other parameters 257 | assert sd_hf[k].shape == sd[k].shape 258 | with torch.no_grad(): 259 | sd[k].copy_(sd_hf[k]) 260 | 261 | return model 262 | 263 | def configure_optimizers(self, weight_decay, learning_rate, betas, device_type): 264 | # start with all of the candidate parameters 265 | param_dict = {pn: p for pn, p in self.named_parameters()} 266 | # filter out those that do not require grad 267 | param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad} 268 | # create optim groups. Any parameter that is at least 2D will be weight decayed, otherwise not. 269 | # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't. 270 | decay_params = [p for n, p in param_dict.items() if p.dim() >= 2] 271 | nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2] 272 | optim_groups = [ 273 | {'params': decay_params, 'weight_decay': weight_decay}, 274 | {'params': nodecay_params, 'weight_decay': 0.0} 275 | ] 276 | num_decay_params = sum(p.numel() for p in decay_params) 277 | num_nodecay_params = sum(p.numel() for p in nodecay_params) 278 | print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters") 279 | print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters") 280 | # Create AdamW optimizer and use the fused version if it is available 281 | fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters 282 | use_fused = fused_available and device_type == 'cuda' 283 | extra_args = dict(fused=True) if use_fused else dict() 284 | optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args) 285 | print(f"using fused AdamW: {use_fused}") 286 | 287 | return optimizer 288 | 289 | def estimate_mfu(self, fwdbwd_per_iter, dt): 290 | """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """ 291 | # first estimate the number of flops we do per iteration. 
292 | # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311 293 | N = self.get_num_params() 294 | cfg = self.config 295 | L, H, Q, T = cfg.n_layer, cfg.n_head, cfg.n_embd//cfg.n_head, cfg.block_size 296 | flops_per_token = 6*N + 12*L*H*Q*T 297 | flops_per_fwdbwd = flops_per_token * T 298 | flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter 299 | # express our flops throughput as ratio of A100 bfloat16 peak flops 300 | flops_achieved = flops_per_iter * (1.0/dt) # per second 301 | flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS 302 | mfu = flops_achieved / flops_promised 303 | return mfu 304 | 305 | @torch.no_grad() 306 | def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None): 307 | """ 308 | Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete 309 | the sequence max_new_tokens times, feeding the predictions back into the model each time. 310 | Most likely you'll want to make sure to be in model.eval() mode of operation for this. 
311 | """ 312 | for _ in range(max_new_tokens): 313 | # if the sequence context is growing too long we must crop it at block_size 314 | idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:] 315 | # forward the model to get the logits for the index in the sequence 316 | logits, _ = self(idx_cond) 317 | # pluck the logits at the final step and scale by desired temperature 318 | logits = logits[:, -1, :] / temperature 319 | # optionally crop the logits to only the top k options 320 | if top_k is not None: 321 | v, _ = torch.topk(logits, min(top_k, logits.size(-1))) 322 | logits[logits < v[:, [-1]]] = -float('Inf') 323 | # apply softmax to convert logits to (normalized) probabilities 324 | probs = F.softmax(logits, dim=-1) 325 | # sample from the distribution 326 | idx_next = torch.multinomial(probs, num_samples=1) 327 | # append sampled index to the running sequence and continue 328 | idx = torch.cat((idx, idx_next), dim=1) 329 | 330 | return idx 331 | -------------------------------------------------------------------------------- /sample.py: -------------------------------------------------------------------------------- 1 | """ 2 | Sample from a trained model 3 | """ 4 | import os 5 | import pickle 6 | from contextlib import nullcontext 7 | import torch 8 | import tiktoken 9 | from model import GPTConfig, GPT 10 | 11 | # ----------------------------------------------------------------------------- 12 | init_from = 'resume' # either 'resume' (from an out_dir) or a gpt2 variant (e.g. 'gpt2-xl') 13 | out_dir = 'out' # ignored if init_from is not 'resume' 14 | start = "\n" # or "<|endoftext|>" or etc. 
Can also specify a file, use as: "FILE:prompt.txt" 15 | num_samples = 10 # number of samples to draw 16 | max_new_tokens = 500 # number of tokens generated in each sample 17 | temperature = 0.8 # 1.0 = no change, < 1.0 = less random, > 1.0 = more random, in predictions 18 | top_k = 200 # retain only the top_k most likely tokens, clamp others to have 0 probability 19 | seed = 1337 20 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1', etc. 21 | dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32' or 'bfloat16' or 'float16' 22 | compile = False # use PyTorch 2.0 to compile the model to be faster 23 | exec(open('configurator.py').read()) # overrides from command line or config file 24 | # ----------------------------------------------------------------------------- 25 | 26 | torch.manual_seed(seed) 27 | torch.cuda.manual_seed(seed) 28 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 29 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 30 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast 31 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 32 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype) 33 | 34 | # model 35 | if init_from == 'resume': 36 | # init from a model saved in a specific directory 37 | ckpt_path = os.path.join(out_dir, 'ckpt.pt') 38 | checkpoint = torch.load(ckpt_path, map_location=device) 39 | gptconf = GPTConfig(**checkpoint['model_args']) 40 | model = GPT(gptconf) 41 | state_dict = checkpoint['model'] 42 | unwanted_prefix = '_orig_mod.' 
43 | for k,v in list(state_dict.items()): 44 | if k.startswith(unwanted_prefix): 45 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 46 | model.load_state_dict(state_dict) 47 | elif init_from.startswith('gpt2'): 48 | # init from a given GPT-2 model 49 | model = GPT.from_pretrained(init_from, dict(dropout=0.0)) 50 | 51 | model.eval() 52 | model.to(device) 53 | if compile: 54 | model = torch.compile(model) # requires PyTorch 2.0 (optional) 55 | 56 | # look for the meta pickle in case it is available in the dataset folder 57 | load_meta = False 58 | if init_from == 'resume' and 'config' in checkpoint and 'dataset' in checkpoint['config']: # older checkpoints might not have these... 59 | meta_path = os.path.join('data', checkpoint['config']['dataset'], 'meta.pkl') 60 | load_meta = os.path.exists(meta_path) 61 | if load_meta: 62 | print(f"Loading meta from {meta_path}...") 63 | with open(meta_path, 'rb') as f: 64 | meta = pickle.load(f) 65 | # TODO want to make this more general to arbitrary encoder/decoder schemes 66 | stoi, itos = meta['stoi'], meta['itos'] 67 | encode = lambda s: [stoi[c] for c in s] 68 | decode = lambda l: ''.join([itos[i] for i in l]) 69 | else: 70 | # ok let's assume gpt-2 encodings by default 71 | print("No meta.pkl found, assuming GPT-2 encodings...") 72 | enc = tiktoken.get_encoding("gpt2") 73 | encode = lambda s: enc.encode(s, allowed_special={"<|endoftext|>"}) 74 | decode = lambda l: enc.decode(l) 75 | 76 | # encode the beginning of the prompt 77 | if start.startswith('FILE:'): 78 | with open(start[5:], 'r', encoding='utf-8') as f: 79 | start = f.read() 80 | start_ids = encode(start) 81 | x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]) 82 | 83 | # run generation 84 | with torch.no_grad(): 85 | with ctx: 86 | for k in range(num_samples): 87 | y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k) 88 | print(decode(y[0].tolist())) 89 | print('---------------') 90 | 
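The `temperature` and `top_k` settings above control how `model.generate` turns logits into a sampling distribution. Below is a rough, pure-Python sketch of that one step for a single logit vector. `top_k_probs` is a hypothetical helper, not part of the repo, and it works on plain lists instead of torch tensors so that it runs standalone:

```python
import math


def top_k_probs(logits, temperature=1.0, top_k=None):
    # hypothetical illustration of the temperature + top-k step, not the repo's code
    # temperature < 1.0 sharpens the distribution, > 1.0 flattens it
    scaled = [x / temperature for x in logits]
    if top_k is not None:
        k = min(top_k, len(scaled))
        # value of the k-th largest logit; anything strictly below it is masked out
        kth = sorted(scaled, reverse=True)[k - 1]
        scaled = [x if x >= kth else float('-inf') for x in scaled]
    # numerically stable softmax; masked entries become exactly 0.0
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]
```

For example, `top_k_probs([2.0, 1.0, 0.0, -1.0], temperature=0.8, top_k=2)` leaves probability mass only on the first two entries. Sampling from the result (e.g. with `random.choices`) would play the role of `torch.multinomial` in `generate`.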
-------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | This training script can be run both on a single gpu in debug mode, 3 | and also in a larger training run with distributed data parallel (ddp). 4 | 5 | To run on a single GPU, example: 6 | $ python train.py --batch_size=32 --compile=False 7 | 8 | To run with DDP on 4 gpus on 1 node, example: 9 | $ torchrun --standalone --nproc_per_node=4 train.py 10 | 11 | To run with DDP on 8 gpus per node across 2 nodes, example: 12 | - Run on the first (master) node with example IP 123.456.123.456: 13 | $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr=123.456.123.456 --master_port=1234 train.py 14 | - Run on the worker node: 15 | $ torchrun --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr=123.456.123.456 --master_port=1234 train.py 16 | (If your cluster does not have Infiniband interconnect, prepend NCCL_IB_DISABLE=1) 17 | """ 18 | 19 | import os 20 | import time 21 | import math 22 | import pickle 23 | from contextlib import nullcontext 24 | 25 | import numpy as np 26 | import torch 27 | from torch.nn.parallel import DistributedDataParallel as DDP 28 | from torch.distributed import init_process_group, destroy_process_group 29 | 30 | from model import GPTConfig, GPT 31 | 32 | # ----------------------------------------------------------------------------- 33 | # default config values designed to train a gpt2 (124M) on OpenWebText 34 | # I/O 35 | out_dir = 'out' 36 | eval_interval = 2000 37 | log_interval = 1 38 | eval_iters = 200 39 | eval_only = False # if True, script exits right after the first eval 40 | always_save_checkpoint = True # if True, always save a checkpoint after each eval 41 | init_from = 'scratch' # 'scratch' or 'resume' or 'gpt2*' 42 | # wandb logging 43 | wandb_log = False # disabled by default 44 | wandb_project = 'owt' 45 | wandb_run_name = 'gpt2' # 'run' + 
str(time.time()) 46 | # data 47 | dataset = 'openwebtext' 48 | gradient_accumulation_steps = 5 * 8 # used to simulate larger batch sizes 49 | batch_size = 12 # if gradient_accumulation_steps > 1, this is the micro-batch size 50 | block_size = 1024 51 | # model 52 | n_layer = 12 53 | n_head = 12 54 | n_embd = 768 55 | dropout = 0.0 # for pretraining 0 is good, for finetuning try 0.1+ 56 | bias = False # do we use bias inside LayerNorm and Linear layers? 57 | # adamw optimizer 58 | learning_rate = 6e-4 # max learning rate 59 | max_iters = 600000 # total number of training iterations 60 | weight_decay = 1e-1 61 | beta1 = 0.9 62 | beta2 = 0.95 63 | grad_clip = 1.0 # clip gradients at this value, or disable if == 0.0 64 | # learning rate decay settings 65 | decay_lr = True # whether to decay the learning rate 66 | warmup_iters = 2000 # how many steps to warm up for 67 | lr_decay_iters = 600000 # should be ~= max_iters per Chinchilla 68 | min_lr = 6e-5 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla 69 | # DDP settings 70 | backend = 'nccl' # 'nccl', 'gloo', etc. 
71 | # system 72 | device = 'cuda' # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks 73 | dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16' # 'float32', 'bfloat16', or 'float16', the latter will auto implement a GradScaler 74 | compile = True # use PyTorch 2.0 to compile the model to be faster 75 | # ----------------------------------------------------------------------------- 76 | config_keys = [k for k,v in globals().items() if not k.startswith('_') and isinstance(v, (int, float, bool, str))] 77 | exec(open('configurator.py').read()) # overrides from command line or config file 78 | config = {k: globals()[k] for k in config_keys} # will be useful for logging 79 | # ----------------------------------------------------------------------------- 80 | 81 | # various inits, derived attributes, I/O setup 82 | ddp = int(os.environ.get('RANK', -1)) != -1 # is this a ddp run? 83 | if ddp: 84 | init_process_group(backend=backend) 85 | ddp_rank = int(os.environ['RANK']) 86 | ddp_local_rank = int(os.environ['LOCAL_RANK']) 87 | ddp_world_size = int(os.environ['WORLD_SIZE']) 88 | device = f'cuda:{ddp_local_rank}' 89 | torch.cuda.set_device(device) 90 | master_process = ddp_rank == 0 # this process will do logging, checkpointing etc. 
91 | seed_offset = ddp_rank # each process gets a different seed 92 | # world_size number of processes will be training simultaneously, so we can scale 93 | # down the desired gradient accumulation iterations per process proportionally 94 | assert gradient_accumulation_steps % ddp_world_size == 0 95 | gradient_accumulation_steps //= ddp_world_size 96 | else: 97 | # if not ddp, we are running on a single gpu, and one process 98 | master_process = True 99 | seed_offset = 0 100 | ddp_world_size = 1 101 | tokens_per_iter = gradient_accumulation_steps * ddp_world_size * batch_size * block_size 102 | print(f"tokens per iteration will be: {tokens_per_iter:,}") 103 | 104 | if master_process: 105 | os.makedirs(out_dir, exist_ok=True) 106 | torch.manual_seed(1337 + seed_offset) 107 | torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul 108 | torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn 109 | device_type = 'cuda' if 'cuda' in device else 'cpu' # for later use in torch.autocast 110 | # note: float16 data type will automatically use a GradScaler 111 | ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype] 112 | ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype) 113 | 114 | # poor man's data loader 115 | data_dir = os.path.join('data', dataset) 116 | def get_batch(split): 117 | # We recreate np.memmap every batch to avoid a memory leak, as per 118 | # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122 119 | if split == 'train': 120 | data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r') 121 | else: 122 | data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r') 123 | ix = torch.randint(len(data) - block_size, (batch_size,)) 124 | x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix]) 125 | y = 
torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix]) 126 | if device_type == 'cuda': 127 | # pin arrays x,y, which allows us to move them to GPU asynchronously (non_blocking=True) 128 | x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True) 129 | else: 130 | x, y = x.to(device), y.to(device) 131 | return x, y 132 | 133 | # init these up here, can override if init_from='resume' (i.e. from a checkpoint) 134 | iter_num = 0 135 | best_val_loss = 1e9 136 | 137 | # attempt to derive vocab_size from the dataset 138 | meta_path = os.path.join(data_dir, 'meta.pkl') 139 | meta_vocab_size = None 140 | if os.path.exists(meta_path): 141 | with open(meta_path, 'rb') as f: 142 | meta = pickle.load(f) 143 | meta_vocab_size = meta['vocab_size'] 144 | print(f"found vocab_size = {meta_vocab_size} (inside {meta_path})") 145 | 146 | # model init 147 | model_args = dict(n_layer=n_layer, n_head=n_head, n_embd=n_embd, block_size=block_size, 148 | bias=bias, vocab_size=None, dropout=dropout) # start with model_args from command line 149 | if init_from == 'scratch': 150 | # init a new model from scratch 151 | print("Initializing a new model from scratch") 152 | # determine the vocab size we'll use for from-scratch training 153 | if meta_vocab_size is None: 154 | print("defaulting to vocab_size of GPT-2 to 50304 (50257 rounded up for efficiency)") 155 | model_args['vocab_size'] = meta_vocab_size if meta_vocab_size is not None else 50304 156 | gptconf = GPTConfig(**model_args) 157 | model = GPT(gptconf) 158 | elif init_from == 'resume': 159 | print(f"Resuming training from {out_dir}") 160 | # resume training from a checkpoint. 
161 | ckpt_path = os.path.join(out_dir, 'ckpt.pt') 162 | checkpoint = torch.load(ckpt_path, map_location=device) 163 | checkpoint_model_args = checkpoint['model_args'] 164 | # force these config attributes to be equal otherwise we can't even resume training 165 | # the rest of the attributes (e.g. dropout) can stay as desired from command line 166 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: 167 | model_args[k] = checkpoint_model_args[k] 168 | # create the model 169 | gptconf = GPTConfig(**model_args) 170 | model = GPT(gptconf) 171 | state_dict = checkpoint['model'] 172 | # fix the keys of the state dictionary :( 173 | # honestly no idea how checkpoints sometimes get this prefix, have to debug more 174 | unwanted_prefix = '_orig_mod.' 175 | for k,v in list(state_dict.items()): 176 | if k.startswith(unwanted_prefix): 177 | state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k) 178 | model.load_state_dict(state_dict) 179 | iter_num = checkpoint['iter_num'] 180 | best_val_loss = checkpoint['best_val_loss'] 181 | elif init_from.startswith('gpt2'): 182 | print(f"Initializing from OpenAI GPT-2 weights: {init_from}") 183 | # initialize from OpenAI GPT-2 weights 184 | override_args = dict(dropout=dropout) 185 | model = GPT.from_pretrained(init_from, override_args) 186 | # read off the created config params, so we can store them into checkpoint correctly 187 | for k in ['n_layer', 'n_head', 'n_embd', 'block_size', 'bias', 'vocab_size']: 188 | model_args[k] = getattr(model.config, k) 189 | # crop down the model block size if desired, using model surgery 190 | if block_size < model.config.block_size: 191 | model.crop_block_size(block_size) 192 | model_args['block_size'] = block_size # so that the checkpoint will have the right value 193 | model.to(device) 194 | 195 | # initialize a GradScaler. 
If enabled=False scaler is a no-op 196 | scaler = torch.cuda.amp.GradScaler(enabled=(dtype == 'float16')) 197 | 198 | # optimizer 199 | optimizer = model.configure_optimizers(weight_decay, learning_rate, (beta1, beta2), device_type) 200 | if init_from == 'resume': 201 | optimizer.load_state_dict(checkpoint['optimizer']) 202 | checkpoint = None # free up memory 203 | 204 | # compile the model 205 | if compile: 206 | print("compiling the model... (takes a ~minute)") 207 | unoptimized_model = model 208 | model = torch.compile(model) # requires PyTorch 2.0 209 | 210 | # wrap model into DDP container 211 | if ddp: 212 | model = DDP(model, device_ids=[ddp_local_rank]) 213 | 214 | # helps estimate an arbitrarily accurate loss over either split using many batches 215 | @torch.no_grad() 216 | def estimate_loss(): 217 | out = {} 218 | model.eval() 219 | for split in ['train', 'val']: 220 | losses = torch.zeros(eval_iters) 221 | for k in range(eval_iters): 222 | X, Y = get_batch(split) 223 | with ctx: 224 | logits, loss = model(X, Y) 225 | losses[k] = loss.item() 226 | out[split] = losses.mean() 227 | model.train() 228 | return out 229 | 230 | # learning rate decay scheduler (cosine with warmup) 231 | def get_lr(it): 232 | # 1) linear warmup for warmup_iters steps 233 | if it < warmup_iters: 234 | return learning_rate * (it + 1) / (warmup_iters + 1) 235 | # 2) if it > lr_decay_iters, return min learning rate 236 | if it > lr_decay_iters: 237 | return min_lr 238 | # 3) in between, use cosine decay down to min learning rate 239 | decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) 240 | assert 0 <= decay_ratio <= 1 241 | coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1 242 | return min_lr + coeff * (learning_rate - min_lr) 243 | 244 | # logging 245 | if wandb_log and master_process: 246 | import wandb 247 | wandb.init(project=wandb_project, name=wandb_run_name, config=config) 248 | 249 | # training loop 250 | X, Y = get_batch('train') # 
fetch the very first batch 251 | t0 = time.time() 252 | local_iter_num = 0 # number of iterations in the lifetime of this process 253 | raw_model = model.module if ddp else model # unwrap DDP container if needed 254 | running_mfu = -1.0 255 | while True: 256 | 257 | # determine and set the learning rate for this iteration 258 | lr = get_lr(iter_num) if decay_lr else learning_rate 259 | for param_group in optimizer.param_groups: 260 | param_group['lr'] = lr 261 | 262 | # evaluate the loss on train/val sets and write checkpoints 263 | if iter_num % eval_interval == 0 and master_process: 264 | losses = estimate_loss() 265 | print(f"step {iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}") 266 | if wandb_log: 267 | wandb.log({ 268 | "iter": iter_num, 269 | "train/loss": losses['train'], 270 | "val/loss": losses['val'], 271 | "lr": lr, 272 | "mfu": running_mfu*100, # convert to percentage 273 | }) 274 | if losses['val'] < best_val_loss or always_save_checkpoint: 275 | best_val_loss = losses['val'] 276 | if iter_num > 0: 277 | checkpoint = { 278 | 'model': raw_model.state_dict(), 279 | 'optimizer': optimizer.state_dict(), 280 | 'model_args': model_args, 281 | 'iter_num': iter_num, 282 | 'best_val_loss': best_val_loss, 283 | 'config': config, 284 | } 285 | print(f"saving checkpoint to {out_dir}") 286 | torch.save(checkpoint, os.path.join(out_dir, 'ckpt.pt')) 287 | if iter_num == 0 and eval_only: 288 | break 289 | 290 | # forward backward update, with optional gradient accumulation to simulate larger batch size 291 | # and using the GradScaler if data type is float16 292 | for micro_step in range(gradient_accumulation_steps): 293 | if ddp: 294 | # in DDP training we only need to sync gradients at the last micro step. 
295 | # the official way to do this is with model.no_sync() context manager, but 296 | # I really dislike that this bloats the code and forces us to repeat code 297 | # looking at the source of that context manager, it just toggles this variable 298 | model.require_backward_grad_sync = (micro_step == gradient_accumulation_steps - 1) 299 | with ctx: 300 | logits, loss = model(X, Y) 301 | loss = loss / gradient_accumulation_steps # scale the loss to account for gradient accumulation 302 | # immediately async prefetch next batch while model is doing the forward pass on the GPU 303 | X, Y = get_batch('train') 304 | # backward pass, with gradient scaling if training in fp16 305 | scaler.scale(loss).backward() 306 | # clip the gradient 307 | if grad_clip != 0.0: 308 | scaler.unscale_(optimizer) 309 | torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip) 310 | # step the optimizer and scaler if training in fp16 311 | scaler.step(optimizer) 312 | scaler.update() 313 | # flush the gradients as soon as we can, no need for this memory anymore 314 | optimizer.zero_grad(set_to_none=True) 315 | 316 | # timing and logging 317 | t1 = time.time() 318 | dt = t1 - t0 319 | t0 = t1 320 | if iter_num % log_interval == 0 and master_process: 321 | # get loss as float. 
note: this is a CPU-GPU sync point 322 | # scale up to undo the division above, approximating the true total loss (exact would have been a sum) 323 | lossf = loss.item() * gradient_accumulation_steps 324 | if local_iter_num >= 5: # let the training loop settle a bit 325 | mfu = raw_model.estimate_mfu(batch_size * gradient_accumulation_steps, dt) 326 | running_mfu = mfu if running_mfu == -1.0 else 0.9*running_mfu + 0.1*mfu 327 | print(f"iter {iter_num}: loss {lossf:.4f}, time {dt*1000:.2f}ms, mfu {running_mfu*100:.2f}%") 328 | iter_num += 1 329 | local_iter_num += 1 330 | 331 | # termination conditions 332 | if iter_num > max_iters: 333 | break 334 | 335 | if ddp: 336 | destroy_process_group() 337 | -------------------------------------------------------------------------------- /transformer_sizing.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "attachments": {}, 5 | "cell_type": "markdown", 6 | "metadata": {}, 7 | "source": [ 8 | "### Transformer Theoretical Model\n", 9 | "\n", 10 | "This notebook stores a bunch of analysis about a Transformer, e.g. estimates the number of FLOPs, parameters, peak memory footprint, checkpoint size, etc." 
11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "outputs": [], 18 | "source": [ 19 | "from collections import OrderedDict" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "metadata": {}, 26 | "outputs": [], 27 | "source": [ 28 | "# config_args = {\n", 29 | "# 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params\n", 30 | "# 'gpt2-medium': dict(n_layer=24, n_head=16, n_embd=1024), # 350M params\n", 31 | "# 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params\n", 32 | "# 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params\n", 33 | "# }[model_type]\n", 34 | "\n", 35 | "block_size = 1024\n", 36 | "vocab_size = 50257\n", 37 | "n_layer = 12\n", 38 | "n_head = 12\n", 39 | "n_embd = 768\n", 40 | "bias = False\n", 41 | "assert not bias, \"this notebook assumes bias=False just for simplicity\"" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 3, 47 | "metadata": {}, 48 | "outputs": [ 49 | { 50 | "name": "stdout", 51 | "output_type": "stream", 52 | "text": [ 53 | "we see: 124337664, expected: 124337664, match: True\n", 54 | "name params ratio (%) \n", 55 | "emebedding/position 786432 0.6325\n", 56 | "embedding/token 38597376 31.0424\n", 57 | "embedding 39383808 31.6749\n", 58 | "attention/ln 768 0.0006\n", 59 | "attention/kqv 1769472 1.4231\n", 60 | "attention/proj 589824 0.4744\n", 61 | "attention 2360064 1.8981\n", 62 | "mlp/ln 768 0.0006\n", 63 | "mlp/ffw 2359296 1.8975\n", 64 | "mlp/proj 2359296 1.8975\n", 65 | "mlp 4719360 3.7956\n", 66 | "block 7079424 5.6937\n", 67 | "transformer 84953088 68.3245\n", 68 | "ln_f 768 0.0006\n", 69 | "dense 0 0.0000\n", 70 | "total 124337664 100.0000\n" 71 | ] 72 | } 73 | ], 74 | "source": [ 75 | "def params():\n", 76 | " \"\"\" estimates the number of parameters in the model\"\"\"\n", 77 | " out = OrderedDict()\n", 78 | "\n", 79 | " # token and position embeddings\n", 80 | " 
out['emebedding/position'] = n_embd * block_size\n", 81 | " out['embedding/token'] = n_embd * vocab_size\n", 82 | " out['embedding'] = out['emebedding/position'] + out['embedding/token']\n", 83 | "\n", 84 | " # attention blocks\n", 85 | " out['attention/ln'] = n_embd # note, bias=False in our LN\n", 86 | " out['attention/kqv'] = n_embd * 3*n_embd\n", 87 | " out['attention/proj'] = n_embd**2\n", 88 | " out['attention'] = out['attention/ln'] + out['attention/kqv'] + out['attention/proj']\n", 89 | "\n", 90 | " # MLP blocks\n", 91 | " ffw_size = 4*n_embd # feed forward size\n", 92 | " out['mlp/ln'] = n_embd\n", 93 | " out['mlp/ffw'] = n_embd * ffw_size\n", 94 | " out['mlp/proj'] = ffw_size * n_embd\n", 95 | " out['mlp'] = out['mlp/ln'] + out['mlp/ffw'] + out['mlp/proj']\n", 96 | " \n", 97 | " # the transformer and the rest of it\n", 98 | " out['block'] = out['attention'] + out['mlp']\n", 99 | " out['transformer'] = n_layer * out['block']\n", 100 | " out['ln_f'] = n_embd # final layernorm\n", 101 | " out['dense'] = 0 # 0 because of parameter sharing. 
This layer uses the weights from the embedding layer\n", 102 | "\n", 103 | " # total\n", 104 | " out['total'] = out['embedding'] + out['transformer'] + out['ln_f'] + out['dense']\n", 105 | "\n", 106 | " return out\n", 107 | "\n", 108 | "# compare our param count to that reported by PyTorch\n", 109 | "p = params()\n", 110 | "params_total = p['total']\n", 111 | "print(f\"we see: {params_total}, expected: {124337664}, match: {params_total == 124337664}\")\n", 112 | "# create a header\n", 113 | "print(f\"{'name':20s} {'params':10s} {'ratio (%)':10s}\")\n", 114 | "for k,v in p.items():\n", 115 | " print(f\"{k:20s} {v:10d} {v/params_total*100:10.4f}\")\n", 116 | " " 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 4, 122 | "metadata": {}, 123 | "outputs": [ 124 | { 125 | "name": "stdout", 126 | "output_type": "stream", 127 | "text": [ 128 | "est checkpoint size: 1.49 GB\n", 129 | "measured with wc -c ckpt.pt: 1542470366\n", 130 | "fluff ratio: 103.38%\n" 131 | ] 132 | } 133 | ], 134 | "source": [ 135 | "# we can now calculate the size of each checkpoint\n", 136 | "# params are stored in fp32, and the AdamW optimizer has 2 additional buffers per param for statistics\n", 137 | "params_bytes = params_total*4\n", 138 | "params_and_buffers_bytes = params_bytes + 2*params_bytes\n", 139 | "print(f\"est checkpoint size: {params_and_buffers_bytes/1e9:.2f} GB\")\n", 140 | "measured_bytes = 1542470366 # from wc -c ckpt.pt\n", 141 | "print(f\"measured with wc -c ckpt.pt: {measured_bytes}\")\n", 142 | "print(f\"fluff ratio: {measured_bytes/params_and_buffers_bytes*100:.2f}%\")" 143 | ] 144 | }, 145 | { 146 | "attachments": {}, 147 | "cell_type": "markdown", 148 | "metadata": {}, 149 | "source": [ 150 | "We can also estimate the ratio of our GPU memory that will be taken up just by the weights and the buffers inside the AdamW optimizer" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 5, 156 | "metadata": {}, 157 | "outputs": [ 
158 | { 159 | "name": "stdout", 160 | "output_type": "stream", 161 | "text": [ 162 | "memory ratio taken up just for parameters: 3.73%\n" 163 | ] 164 | } 165 | ], 166 | "source": [ 167 | "gpu_memory = 40e9 # 40 GB A100 GPU, roughly\n", 168 | "print(f\"memory ratio taken up just for parameters: {params_and_buffers_bytes / gpu_memory * 100:.2f}%\")" 169 | ] 170 | }, 171 | { 172 | "attachments": {}, 173 | "cell_type": "markdown", 174 | "metadata": {}, 175 | "source": [ 176 | "i.e. not that much of the memory for this tiny model, most of the memory is activations (forward and backward). This of course changes dramatically for larger and larger models." 177 | ] 178 | }, 179 | { 180 | "attachments": {}, 181 | "cell_type": "markdown", 182 | "metadata": {}, 183 | "source": [ 184 | "Let's estimate FLOPs for a single forward pass." 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 6, 190 | "metadata": {}, 191 | "outputs": [ 192 | { 193 | "name": "stdout", 194 | "output_type": "stream", 195 | "text": [ 196 | "name flops ratio (%) \n", 197 | "attention/kqv 3623878656 1.2426\n", 198 | "attention/scores 1610612736 0.5522\n", 199 | "attention/reduce 1610612736 0.5522\n", 200 | "attention/proj 1207959552 0.4142\n", 201 | "attention 8053063680 2.7612\n", 202 | "mlp/ffw1 4831838208 1.6567\n", 203 | "mlp/ffw2 4831838208 1.6567\n", 204 | "mlp 9663676416 3.3135\n", 205 | "block 17716740096 6.0747\n", 206 | "transformer 212600881152 72.8963\n", 207 | "dense 79047426048 27.1037\n", 208 | "forward_total 291648307200 100.0000\n", 209 | "backward_total 583296614400 200.0000\n", 210 | "total 874944921600 300.0000\n" 211 | ] 212 | } 213 | ], 214 | "source": [ 215 | "def flops():\n", 216 | " # we only count Weight FLOPs, all other layers (LayerNorm, Softmax, etc) are effectively irrelevant\n", 217 | " # we count actual FLOPs, not MACs. 
Hence 2* all over the place\n", 218 | " # basically for any matrix multiply A (BxC) @ B (CxD) -> (BxD) flops are 2*B*C*D\n", 219 | "\n", 220 | " out = OrderedDict()\n", 221 | " head_size = n_embd // n_head\n", 222 | "\n", 223 | " # attention blocks\n", 224 | " # 1) the projection to key, query, values\n", 225 | " out['attention/kqv'] = 2 * block_size * (n_embd * 3*n_embd)\n", 226 | " # 2) calculating the attention scores\n", 227 | " out['attention/scores'] = 2 * block_size * block_size * n_embd\n", 228 | " # 3) the reduction of the values (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)\n", 229 | " out['attention/reduce'] = 2 * n_head * (block_size * block_size * head_size)\n", 230 | " # 4) the final linear projection\n", 231 | " out['attention/proj'] = 2 * block_size * (n_embd * n_embd)\n", 232 | " out['attention'] = sum(out['attention/'+k] for k in ['kqv', 'scores', 'reduce', 'proj'])\n", 233 | "\n", 234 | " # MLP blocks\n", 235 | " ffw_size = 4*n_embd # feed forward size\n", 236 | " out['mlp/ffw1'] = 2 * block_size * (n_embd * ffw_size)\n", 237 | " out['mlp/ffw2'] = 2 * block_size * (ffw_size * n_embd)\n", 238 | " out['mlp'] = out['mlp/ffw1'] + out['mlp/ffw2']\n", 239 | "\n", 240 | " # the transformer and the rest of it\n", 241 | " out['block'] = out['attention'] + out['mlp']\n", 242 | " out['transformer'] = n_layer * out['block']\n", 243 | " out['dense'] = 2 * block_size * (n_embd * vocab_size)\n", 244 | "\n", 245 | " # forward,backward,total\n", 246 | " out['forward_total'] = out['transformer'] + out['dense']\n", 247 | " out['backward_total'] = 2 * out['forward_total'] # use common estimate of bwd = 2*fwd\n", 248 | " out['total'] = out['forward_total'] + out['backward_total']\n", 249 | "\n", 250 | " return out\n", 251 | " \n", 252 | "# print the per-component FLOPs breakdown; ratios are relative to the forward pass total\n", 253 | "f = flops()\n", 254 | "flops_total = f['forward_total']\n", 255 | "print(f\"{'name':20s} {'flops':14s} {'ratio (%)':10s}\")\n", 256 | "for k,v in f.items():\n", 257
| " print(f\"{k:20s} {v:14d} {v/flops_total*100:10.4f}\")\n", 258 | " " 259 | ] 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": 7, 264 | "metadata": {}, 265 | "outputs": [ 266 | { 267 | "name": "stdout", 268 | "output_type": "stream", 269 | "text": [ 270 | "palm_flops: 875062886400, flops: 874944921600, ratio: 1.0001\n" 271 | ] 272 | } 273 | ], 274 | "source": [ 275 | "# now here is an estimate copy pasted from the PaLM paper\n", 276 | "# this formula is often used to calculate MFU (model flops utilization)\n", 277 | "def palm_flops():\n", 278 | " \"\"\"estimate of the model flops following PaLM paper formula\"\"\"\n", 279 | " # non-embedding model parameters. note that we do not subtract the\n", 280 | " # embedding/token params because those are tied and get used in the last layer.\n", 281 | " N = params()['total'] - params()['emebedding/position']\n", 282 | " L, H, Q, T = n_layer, n_head, n_embd//n_head, block_size\n", 283 | " mf_per_token = 6*N + 12*L*H*Q*T\n", 284 | " mf = mf_per_token * block_size\n", 285 | " return mf\n", 286 | "\n", 287 | "print(f\"palm_flops: {palm_flops():d}, flops: {flops()['total']:d}, ratio: {palm_flops()/flops()['total']:.4f}\")" 288 | ] 289 | }, 290 | { 291 | "attachments": {}, 292 | "cell_type": "markdown", 293 | "metadata": {}, 294 | "source": [ 295 | "Ok they are quite similar, giving some confidence that my math in flops() function was ~ok. Now, A100 is cited at 312TFLOPS bfloat16 on tensor cores. So what is our model flops utilization (MFU)? I trained the model above with a batch_size of 20 and grad_accum of 5, which runs in about 755ms on a single A100 GPU. 
We get:" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": 8, 301 | "metadata": {}, 302 | "outputs": [ 303 | { 304 | "name": "stdout", 305 | "output_type": "stream", 306 | "text": [ 307 | "fraction of A100 used: 37.14%\n" 308 | ] 309 | } 310 | ], 311 | "source": [ 312 | "# here is what we currently roughly measure\n", 313 | "batch_size = 20 * 5 # 5 is grad_accum, so total batch size is 100\n", 314 | "measured_time = 0.755 # in seconds per iteration\n", 315 | "measured_throughput = batch_size / measured_time\n", 316 | "flops_achieved = f['total'] * measured_throughput\n", 317 | "\n", 318 | "# A100 is cited to be 312 TFLOPS of bfloat16 running on tensor cores\n", 319 | "a100_flops_promised = 312e12\n", 320 | "\n", 321 | "# the fraction of the A100 that we are using:\n", 322 | "print(f\"fraction of A100 used: {flops_achieved / a100_flops_promised * 100:.2f}%\")" 323 | ] 324 | }, 325 | { 326 | "attachments": {}, 327 | "cell_type": "markdown", 328 | "metadata": {}, 329 | "source": [ 330 | "For reference, we'd prefer to be somewhere around 50%+, and not just for a single GPU but for an entire DDP run. So we still have some work to do, but at least we're within a factor of ~2X of what is achievable with this GPU." 
331 | ] 332 | }, 333 | { 334 | "cell_type": "code", 335 | "execution_count": 9, 336 | "metadata": {}, 337 | "outputs": [ 338 | { 339 | "name": "stdout", 340 | "output_type": "stream", 341 | "text": [ 342 | "time needed to train the model: 3.46 days\n" 343 | ] 344 | } 345 | ], 346 | "source": [ 347 | "# Finally let's check out the 6ND approximation as total cost of training in FLOPs\n", 348 | "model_size = params()['total'] # this is number of parameters, N\n", 349 | "tokens_num = 300e9 # 300B tokens, this is dataset size in tokens, D\n", 350 | "a100_flops = 312e12 # 312 TFLOPS\n", 351 | "assumed_mfu = 0.3 # assume this model flops utilization (take the current 37% from above and add some DDP overhead)\n", 352 | "flops_throughput = a100_flops * 8 * assumed_mfu # assume an 8XA100 node at 30% utilization\n", 353 | "flops_needed = 6 * model_size * tokens_num # 6ND\n", 354 | "time_needed_s = flops_needed / flops_throughput # in seconds\n", 355 | "print(f\"time needed to train the model: {time_needed_s/3600/24:.2f} days\")" 356 | ] 357 | }, 358 | { 359 | "attachments": {}, 360 | "cell_type": "markdown", 361 | "metadata": {}, 362 | "source": [ 363 | "This is not a bad estimate at all. I trained this model and it converged in roughly 4 days. Btw as a good reference for where 6ND comes from and some intuition around it I recommend [Dzmitry's post](https://medium.com/@dzmitrybahdanau/the-flops-calculus-of-language-model-training-3b19c1f025e4)." 364 | ] 365 | }, 366 | { 367 | "attachments": {}, 368 | "cell_type": "markdown", 369 | "metadata": {}, 370 | "source": [ 371 | "Now, FLOPs are just one constraint, the other that we have to keep a close track of is the memory bandwidth. TODO estimate LOAD/STORE costs of our model later." 
372 | ] 373 | } 374 | ], 375 | "metadata": { 376 | "kernelspec": { 377 | "display_name": "pytorch2", 378 | "language": "python", 379 | "name": "python3" 380 | }, 381 | "language_info": { 382 | "codemirror_mode": { 383 | "name": "ipython", 384 | "version": 3 385 | }, 386 | "file_extension": ".py", 387 | "mimetype": "text/x-python", 388 | "name": "python", 389 | "nbconvert_exporter": "python", 390 | "pygments_lexer": "ipython3", 391 | "version": "3.10.8" 392 | }, 393 | "orig_nbformat": 4, 394 | "vscode": { 395 | "interpreter": { 396 | "hash": "7f5833218766b48e6e35e4452ee875aac0e2188d05bbe5298f2c62b79f08b222" 397 | } 398 | } 399 | }, 400 | "nbformat": 4, 401 | "nbformat_minor": 2 402 | } 403 | --------------------------------------------------------------------------------
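The two back-of-the-envelope estimates in transformer_sizing.ipynb (model flops utilization and the 6ND training-time approximation) can be folded into small reusable helpers. The sketch below is not part of the repository: the function names `model_flops_utilization` and `training_days` are hypothetical, and the default of 312e12 peak FLOPS is simply the A100 bfloat16 figure quoted in the notebook. It plugs in the notebook's own numbers (874944921600 FLOPs per sequence for forward plus backward, 100 sequences per iteration, 0.755 s per iteration, 124337664 parameters, 300e9 tokens).

```python
def model_flops_utilization(flops_per_sequence, sequences_per_iter,
                            seconds_per_iter, peak_flops=312e12):
    """Fraction of the accelerator's peak FLOPS achieved (MFU).

    flops_per_sequence: forward + backward FLOPs for one sequence
    sequences_per_iter: batch_size * gradient accumulation steps
    peak_flops: assumed peak throughput (A100 bfloat16 figure from the notebook)
    """
    flops_achieved = flops_per_sequence * sequences_per_iter / seconds_per_iter
    return flops_achieved / peak_flops


def training_days(n_params, n_tokens, peak_flops_per_gpu=312e12,
                  n_gpus=8, assumed_mfu=0.3):
    """Wall-clock training time in days under the 6ND approximation."""
    flops_needed = 6 * n_params * n_tokens            # 6ND
    flops_per_second = peak_flops_per_gpu * n_gpus * assumed_mfu
    return flops_needed / flops_per_second / 3600 / 24


# numbers taken from the notebook cells above
util = model_flops_utilization(874944921600, 20 * 5, 0.755)
days = training_days(124337664, 300e9)
print(f"fraction of A100 used: {util * 100:.2f}%")
print(f"time needed to train the model: {days:.2f} days")
```

Keeping `assumed_mfu` and `n_gpus` as arguments makes it easy to see how sensitive the estimate is: the estimated time scales inversely with both, so doubling the utilization halves the number of days.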