├── .gitignore
├── README.md
├── mingpt.jpg
├── mingpt
│   ├── __init__.py
│   ├── block.py
│   ├── callback.py
│   ├── lr_decay.py
│   └── utils.py
├── requirements.txt
└── train.py

/.gitignore:
--------------------------------------------------------------------------------
.ipynb_checkpoints/
__pycache__/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# minGPT with Lightning & DeepSpeed

**Lightning now has its own Lightning GPT example! We highly recommend using their repo [here](https://github.com/Lightning-AI/lightning-GPT).**

![mingpt](mingpt.jpg)

[![Open in Streamlit](https://static.streamlit.io/badges/streamlit_badge_black_white.svg)](https://share.streamlit.io/seannaren/mingpt/streamlit/app.py)

This repo modifies [Andrej's](https://github.com/karpathy/minGPT) and [William's](https://github.com/williamFalcon/minGPT) awesome code to provide a minimal example of how to pair Lightning and DeepSpeed with a minimal GPT model.

*Note: to keep things minimal and readable, this example won't be as efficient/optimized as more specialized repos, but large model training is still achievable.*

## Usage

```
pip install -r requirements.txt
```

### Training Billion+ Parameter GPT Models

A lot of this information comes from the very helpful [Lightning Model Parallel Documentation](https://pytorch-lightning.readthedocs.io/en/latest/advanced/model_parallel.html#fully-sharded-training).

In the examples below the batch size is set to 1 to reduce VRAM usage as much as possible, but you can scale it with your compute; in each case the batch size could be increased significantly to fill the leftover GPU memory.

For the 20B/45B parameter models you'll need a reasonable amount of CPU RAM, since partitions are offloaded to the CPU. For the 45B parameter model you'll need around 1TB of CPU memory, which is the default for the p4d.24xlarge instance on AWS (roughly 9 dollars an hour for a spot instance).

Note that these examples enable CPU offloading. Offloading has a huge impact on throughput and should in most cases be turned off when training from scratch; at these model sizes, consider scaling the number of GPUs rather than enabling offloading.
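As a rough sanity check on the sizes quoted below (an approximation added here for illustration, not something this repo computes), the parameter count is dominated by the transformer blocks, roughly 12 * n_layer * n_embd^2 weights (4*n_embd^2 for attention plus 8*n_embd^2 for the MLP per block, ignoring embeddings, biases and LayerNorms):

```python
# Back-of-the-envelope parameter estimate for the configurations below
# (ignores token/position embeddings, biases and LayerNorm parameters).
def approx_params(n_layer: int, n_embd: int) -> int:
    # per Block: 4*n_embd^2 attention weights (key/query/value/proj) + 8*n_embd^2 MLP weights
    return 12 * n_layer * n_embd ** 2

print(f"{approx_params(15, 3072) / 1e9:.1f}B")  # ~1.7B
print(f"{approx_params(13, 8192) / 1e9:.1f}B")  # ~10.5B
print(f"{approx_params(25, 8192) / 1e9:.1f}B")  # ~20.1B
print(f"{approx_params(56, 8192) / 1e9:.1f}B")  # ~45.1B
```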

##### 1.7B (Requires around 2GiB per GPU with 8 GPUs, 5.1GiB with 1 GPU)
```bash
python train.py --n_layer 15 --n_head 16 --n_embd 3072 --gpus 8 --precision 16 --batch_size 1 --strategy deepspeed_stage_3_offload
```

##### ~10B (Requires around 6GiB per GPU with 8 GPUs, 26GiB with 1 GPU)
```bash
python train.py --n_layer 13 --n_head 16 --n_embd 8192 --gpus 8 --precision 16 --batch_size 1 --strategy deepspeed_stage_3_offload
```

##### ~20B (Requires around 8GiB per GPU with 8 GPUs, OOM with 1 GPU, offloading onto ~500GB of CPU RAM)
```bash
python train.py --n_layer 25 --n_head 16 --n_embd 8192 --gpus 8 --precision 16 --batch_size 1 --strategy deepspeed_stage_3_offload
```

##### ~45B (Requires around 14GiB per GPU with 8 GPUs, OOM with 1 GPU, offloading onto ~950GB of CPU RAM)
```bash
python train.py --n_layer 56 --n_head 16 --n_embd 8192 --gpus 8 --precision 16 --batch_size 1 --strategy deepspeed_stage_3_offload
```

### Model Loading and Evaluation Example

The best model is checkpointed during training and stored by default in the "checkpoints" directory. With DeepSpeed, model checkpoints are saved as directories, which can cause issues when trying to load models/trainers from PyTorch Lightning checkpoints. To properly restore the model and run the test set, call evaluate.py with arguments similar to the train script:

```bash
python evaluate.py --gpus 1 --precision 16 --batch_size 1 --strategy deepspeed_stage_2
```

This will first convert the model checkpoint directory into a single .pt model file, then load the trainer using deepspeed_stage_2 and run the test set. For simplicity, the test set is identical to the training set.

### License

MIT

--------------------------------------------------------------------------------
/mingpt.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SeanNaren/minGPT/2d185a773d3a6457bdcbaacfb30a532abc3d2052/mingpt.jpg

--------------------------------------------------------------------------------
/mingpt/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SeanNaren/minGPT/2d185a773d3a6457bdcbaacfb30a532abc3d2052/mingpt/__init__.py

--------------------------------------------------------------------------------
/mingpt/block.py:
--------------------------------------------------------------------------------
import math

import torch
import torch.nn as nn
from torch.nn import functional as F


class CausalSelfAttention(nn.Module):
    """
    A vanilla multi-head masked self-attention layer with a projection at the end.
    I believe I could have just used torch.nn.MultiheadAttention but their documentation
    is all but absent and code ugly so I don't trust it, rolling my own here.
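    Causality is enforced by the lower-triangular "mask" buffer registered in __init__,
    so each position can only attend to itself and to earlier positions in the sequence.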
13 | """ 14 | 15 | def __init__(self, config): 16 | super().__init__() 17 | assert config.n_embd % config.n_head == 0 18 | # key, query, value projections for all heads 19 | self.key = nn.Linear(config.n_embd, config.n_embd) 20 | self.query = nn.Linear(config.n_embd, config.n_embd) 21 | self.value = nn.Linear(config.n_embd, config.n_embd) 22 | # regularization 23 | self.attn_drop = nn.Dropout(config.attn_pdrop) 24 | self.resid_drop = nn.Dropout(config.resid_pdrop) 25 | # output projection 26 | self.proj = nn.Linear(config.n_embd, config.n_embd) 27 | # causal mask to ensure that attention is only applied to the left in the input sequence 28 | self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size)) 29 | .view(1, 1, config.block_size, config.block_size)) 30 | self.n_head = config.n_head 31 | 32 | def forward(self, x, layer_past=None): 33 | B, T, C = x.size() 34 | 35 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 36 | k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 37 | q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 38 | v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 39 | 40 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 41 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 42 | att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf')) 43 | att = F.softmax(att, dim=-1) 44 | att = self.attn_drop(att) 45 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 46 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 47 | 48 | # output projection 49 | y = self.resid_drop(self.proj(y)) 50 | return y 51 | 52 | 53 | class Block(nn.Module): 54 | """ an unassuming Transformer block """ 55 | 56 | def __init__(self, config): 57 | super().__init__() 58 | self.ln1 = nn.LayerNorm(config.n_embd) 59 | self.ln2 = nn.LayerNorm(config.n_embd) 60 | self.attn = CausalSelfAttention(config) 61 | self.mlp = nn.Sequential( 62 | nn.Linear(config.n_embd, 4 * config.n_embd), 63 | nn.GELU(), 64 | nn.Linear(4 * config.n_embd, config.n_embd), 65 | nn.Dropout(config.resid_pdrop), 66 | ) 67 | 68 | def forward(self, x): 69 | x = x + self.attn(self.ln1(x)) 70 | x = x + self.mlp(self.ln2(x)) 71 | return x 72 | -------------------------------------------------------------------------------- /mingpt/callback.py: -------------------------------------------------------------------------------- 1 | import time 2 | 3 | import torch 4 | from pytorch_lightning import Callback 5 | from pytorch_lightning.utilities import rank_zero_info 6 | 7 | 8 | class CUDACallback(Callback): 9 | 10 | def on_train_epoch_start(self, trainer, pl_module): 11 | # Reset the memory use counter 12 | torch.cuda.reset_peak_memory_stats(self.root_gpu(trainer)) 13 | torch.cuda.synchronize(self.root_gpu(trainer)) 14 | self.start_time = time.time() 15 | 16 | def on_train_epoch_end(self, trainer, pl_module): 17 | torch.cuda.synchronize(self.root_gpu(trainer)) 18 | max_memory = torch.cuda.max_memory_allocated(self.root_gpu(trainer)) / 2 ** 20 19 | epoch_time = time.time() - self.start_time 20 | 21 | max_memory = trainer.training_type_plugin.reduce(max_memory) 22 | epoch_time = trainer.training_type_plugin.reduce(epoch_time) 23 | 24 | rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds") 25 | rank_zero_info(f"Average Peak memory 
{max_memory:.2f}MiB") 26 | 27 | def root_gpu(self, trainer): 28 | return trainer.strategy.root_device.index 29 | -------------------------------------------------------------------------------- /mingpt/lr_decay.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pytorch_lightning as pl 3 | 4 | 5 | class LearningRateDecayCallback(pl.Callback): 6 | 7 | def __init__(self, learning_rate, warmup_tokens=375e6, final_tokens=260e9, lr_decay=True): 8 | super().__init__() 9 | self.learning_rate = learning_rate 10 | self.tokens = 0 11 | self.final_tokens = final_tokens 12 | self.lr_decay = lr_decay 13 | self.warmup_tokens = warmup_tokens 14 | 15 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 16 | optimizer = trainer.optimizers[0] 17 | _, y = batch 18 | 19 | if self.lr_decay: 20 | self.tokens += (y >= 0).sum() # number of tokens processed this step (i.e. label is not -100) 21 | if self.tokens < self.warmup_tokens: 22 | # linear warmup 23 | lr_mult = float(self.tokens) / float(max(1, self.warmup_tokens)) 24 | else: 25 | # cosine learning rate decay 26 | progress = float(self.tokens - self.warmup_tokens) / float( 27 | max(1, self.final_tokens - self.warmup_tokens)) 28 | lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress))) 29 | lr = self.learning_rate * lr_mult 30 | for param_group in optimizer.param_groups: 31 | param_group['lr'] = lr -------------------------------------------------------------------------------- /mingpt/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | 7 | def set_seed(seed): 8 | random.seed(seed) 9 | np.random.seed(seed) 10 | torch.manual_seed(seed) 11 | torch.cuda.manual_seed_all(seed) 12 | 13 | def top_k_logits(logits, k): 14 | v, ix = torch.topk(logits, k) 15 | out = logits.clone() 16 | out[out < v[:, [-1]]] = -float('Inf') 17 | return out 18 | 19 | @torch.no_grad() 20 | def sample(model, x, steps, temperature=1.0, sample=False, top_k=None): 21 | """ 22 | take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in 23 | the sequence, feeding the predictions back into the model each time. Clearly the sampling 24 | has quadratic complexity unlike an RNN that is only linear, and has a finite context window 25 | of block_size, unlike an RNN that has an infinite context window. 
26 | """ 27 | block_size = model.get_block_size() 28 | model.eval() 29 | for k in range(steps): 30 | x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed 31 | logits = model(x_cond) 32 | # pluck the logits at the final step and scale by temperature 33 | logits = logits[:, -1, :] / temperature 34 | # optionally crop probabilities to only the top k options 35 | if top_k is not None: 36 | logits = top_k_logits(logits, top_k) 37 | # apply softmax to convert to probabilities 38 | probs = F.softmax(logits, dim=-1) 39 | # sample from the distribution or take the most likely 40 | if sample: 41 | ix = torch.multinomial(probs, num_samples=1) 42 | else: 43 | _, ix = torch.topk(probs, k=1, dim=-1) 44 | # append to the sequence and continue 45 | x = torch.cat((x, ix), dim=1) 46 | 47 | return x 48 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch-lightning>=1.6.0 2 | torch>=1.10.0 3 | deepspeed 4 | numpy 5 | psutil 6 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentParser 3 | 4 | import numpy as np 5 | from pytorch_lightning import Trainer 6 | from pytorch_lightning import seed_everything 7 | from pytorch_lightning.strategies import DeepSpeedStrategy 8 | from pytorch_lightning.utilities import rank_zero_info 9 | from pytorch_lightning.utilities.meta import init_meta_context 10 | from torch.utils.data import Dataset, DataLoader 11 | import math 12 | 13 | import deepspeed 14 | import pytorch_lightning as pl 15 | import torch 16 | import torch.nn as nn 17 | from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam 18 | from torch.nn import functional as F 19 | from mingpt.callback import CUDACallback 20 | from mingpt.lr_decay import LearningRateDecayCallback 21 | from mingpt.block import Block 22 | 23 | 24 | class CharDataset(Dataset): 25 | 26 | def __init__(self, data, block_size): 27 | chars = list(set(data)) 28 | data_size, vocab_size = len(data), len(chars) 29 | rank_zero_info('data has %d characters, %d unique.' 

        self.stoi = {ch: i for i, ch in enumerate(chars)}
        self.itos = {i: ch for i, ch in enumerate(chars)}
        self.block_size = block_size
        self.vocab_size = vocab_size
        self.data = data

    def __len__(self):
        return math.ceil(len(self.data) / (self.block_size + 1))

    def __getitem__(self, idx):
        # we're actually going to "cheat" and pick a spot in the dataset at random
        i = np.random.randint(0, len(self.data) - (self.block_size + 1))
        chunk = self.data[i:i + self.block_size + 1]
        dix = [self.stoi[s] for s in chunk]
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)
        return x, y


class GPT(pl.LightningModule):
    def __init__(self,
                 vocab_size,
                 weight_decay=0.1,
                 betas=(0.9, 0.95),
                 learning_rate=3e-4,
                 n_embd=768,
                 block_size=128,
                 embd_pdrop=0.1,
                 n_layer=12,
                 n_head=4,
                 resid_pdrop=0.1,
                 attn_pdrop=0.1
                 ):
        super().__init__()
        self.save_hyperparameters()
        # input embedding stem
        self.tok_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Parameter(torch.zeros(1, block_size, n_embd))
        self.drop = nn.Dropout(embd_pdrop)

        # decoder head
        self.ln_f = nn.LayerNorm(self.hparams.n_embd)
        self.head = nn.Linear(self.hparams.n_embd, self.hparams.vocab_size, bias=False)

        self.block_size = self.hparams.block_size

        self.blocks = nn.ModuleList([Block(self.hparams) for _ in range(self.hparams.n_layer)])

    def configure_optimizers(self):
        no_decay = ["bias", "LayerNorm.weight"]
        params_decay = [p for n, p in self.named_parameters() if not any(nd in n for nd in no_decay)]
        params_nodecay = [p for n, p in self.named_parameters() if any(nd in n for nd in no_decay)]
        optim_groups = [
            {"params": params_decay, "weight_decay": self.hparams.weight_decay},
            {"params": params_nodecay, "weight_decay": 0.0},
        ]
        # use DeepSpeed's CPU-capable Adam only when offloading to the CPU, otherwise FusedAdam
        if self.deepspeed_offload:
            return DeepSpeedCPUAdam(optim_groups, lr=self.hparams.learning_rate, betas=self.hparams.betas)
        return FusedAdam(optim_groups, lr=self.hparams.learning_rate, betas=self.hparams.betas)

    @property
    def deepspeed_offload(self) -> bool:
        strategy = self.trainer.strategy
        if isinstance(strategy, DeepSpeedStrategy):
            config = strategy.config['zero_optimization']
            return config.get('offload_optimizer') or config.get('offload_param')
        return False

    def forward(self, idx):
        b, t = idx.size()
        assert t <= self.block_size, "Cannot forward, model block size is exhausted."
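        # each Block below runs through deepspeed.checkpointing.checkpoint (activation
        # checkpointing): activations are recomputed in the backward pass instead of being
        # stored, trading extra compute for a large reduction in activation memory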

        # forward the GPT model
        token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector
        position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector
        x = self.drop(token_embeddings + position_embeddings)
        for block in self.blocks:
            x = deepspeed.checkpointing.checkpoint(block, x)
        x = self.ln_f(x)
        logits = self.head(x)
        return logits

    def training_step(self, batch, batch_idx):
        idx, targets = batch
        logits = self(idx)
        loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
        self.log('train_loss', loss)
        return loss


if __name__ == '__main__':
    seed_everything(42)

    parser = ArgumentParser()
    parser = Trainer.add_argparse_args(parser)
    parser.add_argument('--n_layer', default=22, type=int)
    parser.add_argument('--n_head', default=16, type=int)
    parser.add_argument('--n_embd', default=3072, type=int)
    parser.add_argument('--learning_rate', default=6e-4, type=float)
    parser.add_argument('--block_size', default=128, type=int)
    parser.add_argument('--batch_size', default=1, type=int)
    parser.add_argument('--num_workers', default=0, type=int)
    args = parser.parse_args()

    if not os.path.exists("input.txt"):
        os.system("wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt")

    text = open('input.txt', 'r').read() # don't worry we won't run out of file handles
    train_dataset = CharDataset(text, args.block_size) # one line of poem is roughly 50 characters
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, num_workers=args.num_workers)

    with init_meta_context():
        model = GPT(
            vocab_size=train_dataset.vocab_size,
            block_size=train_dataset.block_size,
            n_layer=args.n_layer,
            n_head=args.n_head,
            n_embd=args.n_embd,
            learning_rate=args.learning_rate
        )

    lr_decay = LearningRateDecayCallback(
        learning_rate=6e-4,
        warmup_tokens=512 * 20,
        final_tokens=2 * len(train_dataset) * args.block_size
    )

    trainer = Trainer.from_argparse_args(
        args,
        max_epochs=10,
        gradient_clip_val=1.0,
        callbacks=[lr_decay, CUDACallback()],
        precision=16,
    )
    trainer.fit(model, train_loader)

--------------------------------------------------------------------------------
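The listing above does not include the evaluate.py script referenced in the README. As a minimal sketch of the checkpoint-consolidation step the README describes (an illustration using PyTorch Lightning's DeepSpeed utility, not the repo's actual script; the paths are placeholders):

```python
# Sketch: consolidate a DeepSpeed ZeRO checkpoint directory produced during training
# into a single .pt state-dict file, as described in the README's evaluation section.
# "checkpoints/last.ckpt" and "checkpoints/model.pt" are placeholder paths.
from pytorch_lightning.utilities.deepspeed import convert_zero_checkpoint_to_fp32_state_dict

convert_zero_checkpoint_to_fp32_state_dict("checkpoints/last.ckpt", "checkpoints/model.pt")
```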