├── .gitignore
├── README.md
├── mingpt.jpg
├── mingpt
│   ├── __init__.py
│   ├── lr_decay.py
│   ├── model.py
│   └── utils.py
├── play_char.ipynb
├── play_math.ipynb
└── requirements.txt

/.gitignore:
--------------------------------------------------------------------------------
.ipynb_checkpoints/
__pycache__/

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------

# minGPT

![mingpt](mingpt.jpg)

A PyTorch re-implementation of [GPT](https://github.com/openai/gpt-3) training. minGPT tries to be small, clean, interpretable and educational, as most of the currently available GPT implementations are a bit sprawling. GPT is not a complicated model and this implementation is appropriately about 300 lines of code, including boilerplate and a totally unnecessary custom causal self-attention module. Anyway, all that's going on is that a sequence of indices goes into a sequence of transformer blocks, and a probability distribution of the next index comes out. The rest of the complexity is just being clever with batching (both across examples and over sequence length) so that training is efficient.

The core minGPT "library" (hah) is `mingpt/model.py`, which contains the actual Transformer model definition; training is handled by PyTorch Lightning's `Trainer`, with `mingpt/lr_decay.py` providing a (GPT-independent) learning-rate warmup/decay callback. The attached Jupyter notebooks then show how the "library" (hah) can be used to train sequence models:

- `play_math.ipynb` trains a GPT focused on addition (inspired by the addition section in the GPT-3 paper)
- `play_char.ipynb` trains a GPT to be a character-level language model on arbitrary text, similar to my older char-rnn but with a transformer instead of an RNN
- `play_words.ipynb` a BPE version that does not yet exist

With a BPE encoder, distributed training and maybe fp16 this implementation may be able to reproduce GPT-1/GPT-2 results, though I haven't tried $$$. GPT-3 is likely out of reach as my understanding is that it does not fit into GPU memory and requires a more careful model-parallel treatment.

### Example usage

This code is simple enough to just hack inline, not "used", but the current API looks something like:

```python
# you're on your own to define a class that returns individual examples as PyTorch LongTensors
import torch
from torch.utils.data import Dataset, DataLoader
train_dataset = MyDataset(...)
val_dataset = MyDataset(...)
train_loader = DataLoader(train_dataset)
val_loader = DataLoader(val_dataset)

# construct a GPT model
from mingpt.model import GPT
model = GPT(vocab_size=train_dataset.vocab_size,
            block_size=train_dataset.block_size,
            n_layer=8,
            n_head=8,
            n_embd=512,
            learning_rate=6e-4)

# construct a trainer
from pytorch_lightning import Trainer
from mingpt.lr_decay import LearningRateDecayCallback

# scheduler: linear warmup, then cosine decay spread over the whole run
lr_decay = LearningRateDecayCallback(learning_rate=6e-4, warmup_tokens=512*20,
                                     final_tokens=500*len(train_dataset)*train_dataset.block_size)  # ~= total tokens seen over max_epochs

trainer = Trainer(gpus=1, precision=16, max_epochs=500,
                  gradient_clip_val=1.0,
                  callbacks=[lr_decay],
                  progress_bar_refresh_rate=1,
                  row_log_interval=1)
trainer.fit(model, train_loader, val_loader)
# (... enjoy the show for a while... )

# sample from the model (the [None, ...] and [0] are to push/pop a needed dummy batch dimension)
from mingpt.utils import sample
x = torch.tensor([1, 2, 3], dtype=torch.long)[None, ...] # context conditioning
y = sample(model, x, steps=30, temperature=1.0, sample=True, top_k=5)[0]
print(y) # our model filled in the integer sequence with 30 additional likely integers
```
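For concreteness, a minimal `MyDataset` could look like the condensed `CharDataset` below, which is essentially the one used in `play_char.ipynb`. The only contract the snippet above relies on is that the dataset exposes `vocab_size` and `block_size` and returns `(x, y)` LongTensor pairs of length `block_size`; everything else here is just one convenient way to get there.

```python
import random
import torch
from torch.utils.data import Dataset

class CharDataset(Dataset):
    """Serves random (block_size + 1)-character windows of a text; x is the window,
    y is the window shifted left by one (next-character targets)."""

    def __init__(self, data, block_size):
        chars = sorted(set(data))
        self.stoi = {ch: i for i, ch in enumerate(chars)}  # char -> token index
        self.itos = {i: ch for i, ch in enumerate(chars)}  # token index -> char
        self.vocab_size = len(chars)
        self.block_size = block_size
        self.data = data

    def __len__(self):
        return len(self.data) // (self.block_size + 1)

    def __getitem__(self, idx):
        # pick a spot in the text at random; the incoming idx is ignored on purpose
        i = random.randint(0, len(self.data) - self.block_size - 2)
        chunk = self.data[i:i + self.block_size + 1]
        dix = [self.stoi[ch] for ch in chunk]
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)
        return x, y
```

You would instantiate this with the raw text and a `block_size`, and pass it wherever `MyDataset(...)` appears above.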
### References

Code:

- [openai/gpt-2](https://github.com/openai/gpt-2) has the model but not the training code, and in TensorFlow
- [openai/image-gpt](https://github.com/openai/image-gpt) has some more modern gpt-3 like modifications in its code, good reference as well
- huggingface/transformers has a [language-modeling example](https://github.com/huggingface/transformers/tree/master/examples/language-modeling). It is full-featured but as a result also somewhat challenging to trace. E.g. some large functions have as much as 90% of their code behind various branching statements that is unused in the default setting of simple language modeling.
- [Teddy Koker/image-gpt in PyTorch Lightning](https://github.com/teddykoker/image-gpt)

Papers + some implementation notes:

#### Improving Language Understanding by Generative Pre-Training (GPT-1)

- Our model largely follows the original transformer work
- We trained a 12-layer decoder-only transformer with masked self-attention heads (768 dimensional states and 12 attention heads). For the position-wise feed-forward networks, we used 3072 dimensional inner states.
- Adam max learning rate of 2.5e-4. (later GPT-3 for this model size uses 6e-4)
- LR decay: increased linearly from zero over the first 2000 updates and annealed to 0 using a cosine schedule
- We train for 100 epochs on minibatches of 64 randomly sampled, contiguous sequences of 512 tokens.
- Since layernorm is used extensively throughout the model, a simple weight initialization of N(0, 0.02) was sufficient
- bytepair encoding (BPE) vocabulary with 40,000 merges
- residual, embedding, and attention dropouts with a rate of 0.1 for regularization.
- modified version of L2 regularization proposed in (37), with w = 0.01 on all non-bias or gain weights
- For the activation function, we used the Gaussian Error Linear Unit (GELU).
- We used learned position embeddings instead of the sinusoidal version proposed in the original work
- For finetuning: we add dropout to the classifier with a rate of 0.1, a learning rate of 6.25e-5 and a batchsize of 32, for 3 epochs. We use a linear learning rate decay schedule with warmup over 0.2% of training. λ was set to 0.5.
- GPT-1 model is 12 layers and d_model 768, ~117M params

#### Language Models are Unsupervised Multitask Learners (GPT-2)

- LayerNorm was moved to the input of each sub-block, similar to a pre-activation residual network
- an additional layer normalization was added after the final self-attention block.
- modified initialization which accounts for the accumulation on the residual path with model depth is used. We scale the weights of residual layers at initialization by a factor of 1/√N where N is the number of residual layers. (weird because in their released code i can only find a simple use of the old 0.02... in their release of image-gpt I found it used for c_proj, and even then only for attn, not for mlp. huh. https://github.com/openai/image-gpt/blob/master/src/model.py; see the sketch after this list)
- the vocabulary is expanded to 50,257
- increase the context size from 512 to 1024 tokens
- larger batchsize of 512 is used
- GPT-2 used 48 layers and d_model 1600 (vs. original 12 layers and d_model 768). ~1.542B params
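That modified initialization is not what `mingpt/model.py` does (its `_init_weights` keeps plain N(0, 0.02) for every Linear and Embedding). If you wanted to experiment with the paper's residual scaling, a minimal sketch could look like the following; the `2 * n_layer` count (two residual additions per block) and the `attn.proj` name filter are assumptions on my part, chosen to mirror what the image-gpt code appears to do for its `c_proj`:

```python
import math
import torch.nn as nn

def scale_residual_projections(model: nn.Module, n_layer: int):
    # GPT-2's note: scale residual-layer weights at init by 1/sqrt(N),
    # N = number of residual layers. Reading N as 2 * n_layer (each block adds an
    # attention branch and an MLP branch to the residual path) gives this std.
    std = 0.02 / math.sqrt(2 * n_layer)
    for name, module in model.named_modules():
        # catches CausalSelfAttention.proj ("blocks.<i>.attn.proj"); the second Linear
        # of each MLP ("blocks.<i>.mlp.2") would need an analogous rule if desired
        if isinstance(module, nn.Linear) and name.endswith("attn.proj"):
            module.weight.data.normal_(mean=0.0, std=std)
```

You would call it once, right after constructing the `GPT`, to override the default init for just those output projections.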
#### Language Models are Few-Shot Learners (GPT-3)

- GPT-3: 96 layers, 96 heads, with d_model of 12,288 (175B parameters).
- GPT-1-like: 12 layers, 12 heads, d_model 768 (125M)
- We use the same model and architecture as GPT-2, including the modified initialization, pre-normalization, and reversible tokenization described therein
- we use alternating dense and locally banded sparse attention patterns in the layers of the transformer, similar to the Sparse Transformer
- we always have the feedforward layer four times the size of the bottleneck layer, d_ff = 4 * d_model
- all models use a context window of n_ctx = 2048 tokens.
- Adam with β1 = 0.9, β2 = 0.95, and eps = 1e-8
- All models use weight decay of 0.1 to provide a small amount of regularization. (NOTE: GPT-1 used 0.01 I believe, see above)
- clip the global norm of the gradient at 1.0
- Linear LR warmup over the first 375 million tokens. Then use cosine decay for learning rate down to 10% of its value, over 260 billion tokens.
- gradually increase the batch size linearly from a small value (32k tokens) to the full value over the first 4-12 billion tokens of training, depending on the model size.
- full 2048-sized time context window is always used, with a special END OF DOCUMENT token delimiter

### License

MIT

-------------------------------------------------------------------------------- /mingpt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/williamFalcon/minGPT/ad77167036e87b72d6db117678020741005da6c6/mingpt.jpg -------------------------------------------------------------------------------- /mingpt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/williamFalcon/minGPT/ad77167036e87b72d6db117678020741005da6c6/mingpt/__init__.py -------------------------------------------------------------------------------- /mingpt/lr_decay.py: -------------------------------------------------------------------------------- 1 | import math 2 | import pytorch_lightning as pl 3 | 4 | 5 | class LearningRateDecayCallback(pl.Callback): 6 | 7 | def __init__(self, learning_rate, warmup_tokens=375e6, final_tokens=260e9, lr_decay=True): 8 | super().__init__() 9 | self.learning_rate = learning_rate 10 | self.tokens = 0 11 | self.final_tokens = final_tokens 12 | self.lr_decay = lr_decay 13 | self.warmup_tokens = warmup_tokens 14 | 15 | def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx): 16 | optimizer = trainer.optimizers[0] 17 | _, y = batch 18 | 19 | if self.lr_decay: 20 | self.tokens += (y >= 0).sum() # number of tokens processed this step (i.e.
label is not -100) 21 | if self.tokens < self.warmup_tokens: 22 | # linear warmup 23 | lr_mult = float(self.tokens) / float(max(1, self.warmup_tokens)) 24 | else: 25 | # cosine learning rate decay 26 | progress = float(self.tokens - self.warmup_tokens) / float( 27 | max(1, self.final_tokens - self.warmup_tokens)) 28 | lr_mult = max(0.1, 0.5 * (1.0 + math.cos(math.pi * progress))) 29 | lr = self.learning_rate * lr_mult 30 | for param_group in optimizer.param_groups: 31 | param_group['lr'] = lr -------------------------------------------------------------------------------- /mingpt/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | GPT model: 3 | - the initial stem consists of a combination of token encoding and a positional encoding 4 | - the meat of it is a uniform sequence of Transformer blocks 5 | - each Transformer is a sequential combination of a 1-hidden-layer MLP block and a self-attention block 6 | - all blocks feed into a central residual pathway similar to resnets 7 | - the final decoder is a linear projection into a vanilla Softmax classifier 8 | """ 9 | 10 | import math 11 | import logging 12 | 13 | import torch 14 | 15 | import torch.nn as nn 16 | from torch.nn import functional as F 17 | import pytorch_lightning as pl 18 | 19 | logger = logging.getLogger(__name__) 20 | 21 | 22 | class GPTConfig: 23 | """ base GPT config, params common to all GPT versions """ 24 | embd_pdrop = 0.1 25 | resid_pdrop = 0.1 26 | attn_pdrop = 0.1 27 | 28 | def __init__(self, vocab_size, block_size, **kwargs): 29 | self.vocab_size = vocab_size 30 | self.block_size = block_size 31 | for k,v in kwargs.items(): 32 | setattr(self, k, v) 33 | 34 | class GPT1Config(GPTConfig): 35 | """ GPT-1 like network roughly 125M params """ 36 | n_layer = 12 37 | n_head = 12 38 | n_embd = 768 39 | 40 | class CausalSelfAttention(nn.Module): 41 | """ 42 | A vanilla multi-head masked self-attention layer with a projection at the end. 43 | I believe I could have just used torch.nn.MultiheadAttention but their documentation 44 | is all but absent and code ugly so I don't trust it, rolling my own here. 
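Shape conventions in forward(): the input x is (B, T, C) = (batch, sequence length, n_embd); key/query/value are reshaped to (B, nh, T, hs) with nh = n_head and hs = C // n_head; the (T, T) attention scores are masked with the lower-triangular `mask` buffer so position t only attends to positions <= t.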
45 | """ 46 | 47 | def __init__(self, config): 48 | super().__init__() 49 | assert config.n_embd % config.n_head == 0 50 | # key, query, value projections for all heads 51 | self.key = nn.Linear(config.n_embd, config.n_embd) 52 | self.query = nn.Linear(config.n_embd, config.n_embd) 53 | self.value = nn.Linear(config.n_embd, config.n_embd) 54 | # regularization 55 | self.attn_drop = nn.Dropout(config.attn_pdrop) 56 | self.resid_drop = nn.Dropout(config.resid_pdrop) 57 | # output projection 58 | self.proj = nn.Linear(config.n_embd, config.n_embd) 59 | # causal mask to ensure that attention is only applied to the left in the input sequence 60 | self.register_buffer("mask", torch.tril(torch.ones(config.block_size, config.block_size)) 61 | .view(1, 1, config.block_size, config.block_size)) 62 | self.n_head = config.n_head 63 | 64 | def forward(self, x, layer_past=None): 65 | B, T, C = x.size() 66 | 67 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 68 | k = self.key(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 69 | q = self.query(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 70 | v = self.value(x).view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 71 | 72 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 73 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 74 | att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf')) 75 | att = F.softmax(att, dim=-1) 76 | att = self.attn_drop(att) 77 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 78 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 79 | 80 | # output projection 81 | y = self.resid_drop(self.proj(y)) 82 | return y 83 | 84 | class Block(nn.Module): 85 | """ an unassuming Transformer block """ 86 | 87 | def __init__(self, config): 88 | super().__init__() 89 | self.ln1 = nn.LayerNorm(config.n_embd) 90 | self.ln2 = nn.LayerNorm(config.n_embd) 91 | self.attn = CausalSelfAttention(config) 92 | self.mlp = nn.Sequential( 93 | nn.Linear(config.n_embd, 4 * config.n_embd), 94 | nn.GELU(), 95 | nn.Linear(4 * config.n_embd, config.n_embd), 96 | nn.Dropout(config.resid_pdrop), 97 | ) 98 | 99 | def forward(self, x): 100 | x = x + self.attn(self.ln1(x)) 101 | x = x + self.mlp(self.ln2(x)) 102 | return x 103 | 104 | 105 | class GPT(pl.LightningModule): 106 | """ the full GPT language model, with a context size of block_size """ 107 | def __init__(self, 108 | vocab_size, 109 | weight_decay=0.1, 110 | betas=(0.9, 0.95), 111 | learning_rate=3e-4, 112 | n_embd=768, 113 | block_size=128, 114 | embd_pdrop=0.1, 115 | n_layer=12, 116 | n_head=4, 117 | resid_pdrop=0.1, 118 | attn_pdrop=0.1 119 | ): 120 | super().__init__() 121 | # auto creates self.hparams from the method signature 122 | self.save_hyperparameters() 123 | 124 | # in lightning the "config" is hparams (for hyperparameters) 125 | self.config = self.hparams 126 | 127 | # input embedding stem 128 | self.tok_emb = nn.Embedding(vocab_size, n_embd) 129 | self.pos_emb = nn.Parameter(torch.zeros(1, block_size, n_embd)) 130 | self.drop = nn.Dropout(embd_pdrop) 131 | # transformer 132 | self.blocks = nn.Sequential(*[Block(self.config) for _ in range(self.config.n_layer)]) 133 | # decoder head 134 | self.ln_f = nn.LayerNorm(self.config.n_embd) 135 | self.head = nn.Linear(self.config.n_embd, self.config.vocab_size, bias=False) 136 | 137 | self.block_size = 
self.config.block_size 138 | self.apply(self._init_weights) 139 | 140 | logger.info("number of parameters: %e", sum(p.numel() for p in self.parameters())) 141 | 142 | def _init_weights(self, module): 143 | if isinstance(module, (nn.Linear, nn.Embedding)): 144 | module.weight.data.normal_(mean=0.0, std=0.02) 145 | if isinstance(module, nn.Linear) and module.bias is not None: 146 | module.bias.data.zero_() 147 | elif isinstance(module, nn.LayerNorm): 148 | module.bias.data.zero_() 149 | module.weight.data.fill_(1.0) 150 | 151 | def get_block_size(self): 152 | return self.block_size 153 | 154 | def configure_optimizers(self): 155 | # create the optimizer 156 | no_decay = ["bias", "LayerNorm.weight"] 157 | params_decay = [p for n, p in self.named_parameters() if not any(nd in n for nd in no_decay)] 158 | params_nodecay = [p for n, p in self.named_parameters() if any(nd in n for nd in no_decay)] 159 | optim_groups = [ 160 | {"params": params_decay, "weight_decay": self.hparams.weight_decay}, 161 | {"params": params_nodecay, "weight_decay": 0.0}, 162 | ] 163 | optimizer = torch.optim.AdamW(optim_groups, lr=self.hparams.learning_rate, betas=self.hparams.betas) 164 | return optimizer 165 | 166 | def forward(self, idx): 167 | b, t = idx.size() 168 | assert t <= self.block_size, "Cannot forward, model block size is exhausted." 169 | 170 | # forward the GPT model 171 | token_embeddings = self.tok_emb(idx) # each index maps to a (learnable) vector 172 | position_embeddings = self.pos_emb[:, :t, :] # each position maps to a (learnable) vector 173 | x = self.drop(token_embeddings + position_embeddings) 174 | x = self.blocks(x) 175 | x = self.ln_f(x) 176 | logits = self.head(x) 177 | return logits 178 | 179 | def training_step(self, batch, batch_idx): 180 | idx, targets = batch 181 | # same action as inference 182 | logits = self(idx) 183 | 184 | # if we are given some desired targets also calculate the loss 185 | loss = None 186 | if targets is not None: 187 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1)) 188 | 189 | result = pl.TrainResult(minimize=loss, checkpoint_on=loss) 190 | result.log('train_loss', loss) 191 | return result 192 | 193 | -------------------------------------------------------------------------------- /mingpt/utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import functional as F 6 | 7 | def set_seed(seed): 8 | random.seed(seed) 9 | np.random.seed(seed) 10 | torch.manual_seed(seed) 11 | torch.cuda.manual_seed_all(seed) 12 | 13 | def top_k_logits(logits, k): 14 | v, ix = torch.topk(logits, k) 15 | out = logits.clone() 16 | out[out < v[:, [-1]]] = -float('Inf') 17 | return out 18 | 19 | @torch.no_grad() 20 | def sample(model, x, steps, temperature=1.0, sample=False, top_k=None): 21 | """ 22 | take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in 23 | the sequence, feeding the predictions back into the model each time. Clearly the sampling 24 | has quadratic complexity unlike an RNN that is only linear, and has a finite context window 25 | of block_size, unlike an RNN that has an infinite context window. 
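The sampling knobs: `temperature` divides the logits before the softmax (values below 1.0 sharpen the distribution, above 1.0 flatten it), `top_k` keeps only the k highest-probability tokens at each step, and `sample=False` skips the random draw and greedily takes the argmax instead.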
26 | """ 27 | block_size = model.get_block_size() 28 | model.eval() 29 | for k in range(steps): 30 | x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed 31 | logits = model(x_cond) 32 | # pluck the logits at the final step and scale by temperature 33 | logits = logits[:, -1, :] / temperature 34 | # optionally crop probabilities to only the top k options 35 | if top_k is not None: 36 | logits = top_k_logits(logits, top_k) 37 | # apply softmax to convert to probabilities 38 | probs = F.softmax(logits, dim=-1) 39 | # sample from the distribution or take the most likely 40 | if sample: 41 | ix = torch.multinomial(probs, num_samples=1) 42 | else: 43 | _, ix = torch.topk(probs, k=1, dim=-1) 44 | # append to the sequence and continue 45 | x = torch.cat((x, ix), dim=1) 46 | 47 | return x 48 | -------------------------------------------------------------------------------- /play_char.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Train a character-level GPT on some text data\n", 8 | "\n", 9 | "The inputs here are simple text files, which we chop up to individual characters and then train GPT on. So you could say this is a char-transformer instead of a char-rnn. Doesn't quite roll off the tongue as well. In this example we will feed it some shakespear, which we'll get it to predict character-level." 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "data": { 19 | "text/plain": [ 20 | "42" 21 | ] 22 | }, 23 | "execution_count": 1, 24 | "metadata": {}, 25 | "output_type": "execute_result" 26 | } 27 | ], 28 | "source": [ 29 | "# make deterministic\n", 30 | "from pytorch_lightning import seed_everything\n", 31 | "seed_everything(42)" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 2, 37 | "metadata": {}, 38 | "outputs": [], 39 | "source": [ 40 | "import numpy as np\n", 41 | "import torch\n", 42 | "import torch.nn as nn\n", 43 | "from torch.nn import functional as F" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 3, 49 | "metadata": {}, 50 | "outputs": [], 51 | "source": [ 52 | "import math\n", 53 | "from torch.utils.data import Dataset, DataLoader\n", 54 | "\n", 55 | "class CharDataset(Dataset):\n", 56 | "\n", 57 | " def __init__(self, data, block_size):\n", 58 | " chars = list(set(data))\n", 59 | " data_size, vocab_size = len(data), len(chars)\n", 60 | " print('data has %d characters, %d unique.' 
% (data_size, vocab_size))\n", 61 | "\n", 62 | " self.stoi = { ch:i for i,ch in enumerate(chars) }\n", 63 | " self.itos = { i:ch for i,ch in enumerate(chars) }\n", 64 | " self.block_size = block_size\n", 65 | " self.vocab_size = vocab_size\n", 66 | " self.data = data\n", 67 | "\n", 68 | " def __len__(self):\n", 69 | " return math.ceil(len(self.data) / (self.block_size + 1))\n", 70 | "\n", 71 | " def __getitem__(self, idx):\n", 72 | " # we're actually going to \"cheat\" and pick a spot in the dataset at random\n", 73 | " i = np.random.randint(0, len(self.data) - (self.block_size + 1))\n", 74 | " chunk = self.data[i:i+self.block_size+1]\n", 75 | " dix = [self.stoi[s] for s in chunk]\n", 76 | " x = torch.tensor(dix[:-1], dtype=torch.long)\n", 77 | " y = torch.tensor(dix[1:], dtype=torch.long)\n", 78 | " return x, y\n" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 4, 84 | "metadata": {}, 85 | "outputs": [], 86 | "source": [ 87 | "block_size = 128 # spatial extent of the model for its context" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 9, 93 | "metadata": {}, 94 | "outputs": [ 95 | { 96 | "name": "stdout", 97 | "output_type": "stream", 98 | "text": [ 99 | "--2020-08-19 16:03:55-- https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt\n", 100 | "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 199.232.64.133\n", 101 | "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|199.232.64.133|:443... connected.\n", 102 | "HTTP request sent, awaiting response... 200 OK\n", 103 | "Length: 1115394 (1.1M) [text/plain]\n", 104 | "Saving to: ‘input.txt’\n", 105 | "\n", 106 | "input.txt 100%[===================>] 1.06M --.-KB/s in 0.03s \n", 107 | "\n", 108 | "2020-08-19 16:03:55 (42.3 MB/s) - ‘input.txt’ saved [1115394/1115394]\n", 109 | "\n" 110 | ] 111 | } 112 | ], 113 | "source": [ 114 | "# download text from \n", 115 | "! 
wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt" 116 | ] 117 | }, 118 | { 119 | "cell_type": "code", 120 | "execution_count": 10, 121 | "metadata": {}, 122 | "outputs": [ 123 | { 124 | "name": "stdout", 125 | "output_type": "stream", 126 | "text": [ 127 | "data has 1115394 characters, 65 unique.\n" 128 | ] 129 | } 130 | ], 131 | "source": [ 132 | "# you can download this file at https://github.com/karpathy/char-rnn/blob/master/data/tinyshakespeare/input.txt\n", 133 | "text = open('input.txt', 'r').read() # don't worry we won't run out of file handles\n", 134 | "train_dataset = CharDataset(text, block_size) # one line of poem is roughly 50 characters\n", 135 | "train_loader = DataLoader(train_dataset, batch_size=256, num_workers=4)" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 11, 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "from mingpt.model import GPT\n", 145 | "model = GPT(vocab_size=train_dataset.vocab_size, \n", 146 | " block_size=train_dataset.block_size,\n", 147 | " n_layer=8, \n", 148 | " n_head=8, \n", 149 | " n_embd=512, \n", 150 | " learning_rate=6e-4)" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": 12, 156 | "metadata": {}, 157 | "outputs": [ 158 | { 159 | "name": "stderr", 160 | "output_type": "stream", 161 | "text": [ 162 | "GPU available: True, used: True\n", 163 | "TPU available: False, using: 0 TPU cores\n", 164 | "CUDA_VISIBLE_DEVICES: [0]\n", 165 | "Using native 16bit precision.\n", 166 | "\n", 167 | " | Name | Type | Params\n", 168 | "---------------------------------------\n", 169 | "0 | tok_emb | Embedding | 33 K \n", 170 | "1 | drop | Dropout | 0 \n", 171 | "2 | blocks | Sequential | 25 M \n", 172 | "3 | ln_f | LayerNorm | 1 K \n", 173 | "4 | head | Linear | 33 K \n" 174 | ] 175 | }, 176 | { 177 | "data": { 178 | "application/vnd.jupyter.widget-view+json": { 179 | "model_id": "ab86a3945eb54225b221c2a07f00509f", 180 | "version_major": 2, 181 | "version_minor": 0 182 | }, 183 | "text/plain": [ 184 | "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…" 185 | ] 186 | }, 187 | "metadata": {}, 188 | "output_type": "display_data" 189 | }, 190 | { 191 | "name": "stderr", 192 | "output_type": "stream", 193 | "text": [ 194 | "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:23: UserWarning: \n", 195 | " When using EvalResult(early_stop_on=X) or TrainResult(early_stop_on=X) the\n", 196 | " 'monitor' key of ModelCheckpoint has no effect.\n", 197 | " Remove ModelCheckpoint(monitor='loss) to fix')\n", 198 | " \n", 199 | " warnings.warn(*args, **kwargs)\n" 200 | ] 201 | }, 202 | { 203 | "name": "stdout", 204 | "output_type": "stream", 205 | "text": [ 206 | "\n" 207 | ] 208 | }, 209 | { 210 | "name": "stderr", 211 | "output_type": "stream", 212 | "text": [ 213 | "/opt/conda/lib/python3.7/site-packages/pytorch_lightning/utilities/distributed.py:23: UserWarning: Detected KeyboardInterrupt, attempting graceful shutdown...\n", 214 | " warnings.warn(*args, **kwargs)\n" 215 | ] 216 | }, 217 | { 218 | "data": { 219 | "text/plain": [ 220 | "1" 221 | ] 222 | }, 223 | "execution_count": 12, 224 | "metadata": {}, 225 | "output_type": "execute_result" 226 | } 227 | ], 228 | "source": [ 229 | "from pytorch_lightning import Trainer\n", 230 | "from mingpt.lr_decay import LearningRateDecayCallback\n", 231 | "\n", 232 | "# scheduler\n", 233 | "lr_decay = 
LearningRateDecayCallback(learning_rate=6e-4, warmup_tokens=512*20,\n", 234 | " final_tokens=00*len(train_dataset)*block_size)\n", 235 | "\n", 236 | "trainer = Trainer(gpus=1, precision=16, max_epochs=500,\n", 237 | " gradient_clip_val=1.0, \n", 238 | " callbacks=[lr_decay], \n", 239 | " progress_bar_refresh_rate=1, \n", 240 | " row_log_interval=1)\n", 241 | "trainer.fit(model, train_loader)" 242 | ] 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": 18, 247 | "metadata": {}, 248 | "outputs": [ 249 | { 250 | "name": "stdout", 251 | "output_type": "stream", 252 | "text": [ 253 | "O God, I code but their friends.\n", 254 | "\n", 255 | "KING EDWARD IV:\n", 256 | "Thou hast thronge indable thy father friar,\n", 257 | "Stand up and desperite, should virtuous advit.\n", 258 | "\n", 259 | "SICINIUS:\n", 260 | "Sir, the king of words this land him.\n", 261 | "\n", 262 | "BIONDA:\n", 263 | "You marry; my lord.\n", 264 | "\n", 265 | "SICINIUS:\n", 266 | " faith, know, you say, My company.\n", 267 | "\n", 268 | "MENENIUS:\n", 269 | "You passion, this name:\n", 270 | "If she do seat your sight, and no more,\n", 271 | "So save man than still, what says 'tis more commongt\n", 272 | "To sling hell bit will be bastanded of your deliver,\n", 273 | "Remither than shall still, his land hand;\n", 274 | "More im thou not, and the subject more,\n", 275 | "Stime at eample, and saffe his corder--feath, this\n", 276 | "manify stiff his life, and what may live, and\n", 277 | "Nor what shorn compassion to my sover; but I do,\n", 278 | "I'll commplainly to still, be born him: I am thought\n", 279 | "In shhe yould still, and say 'anoth;\n", 280 | "For though here do selfs and consul,\n", 281 | "With leave more brings and ours, catisfied,\n", 282 | "Shaill and yourself to your most to think,\n", 283 | "Where believes their dince and thou ne'er to kithfull;\n", 284 | "With hom they have do your high earth to thing,\n", 285 | "Which shall hm lear him ca\n" 286 | ] 287 | } 288 | ], 289 | "source": [ 290 | "# alright, let's sample some character-level shakespear\n", 291 | "from mingpt.utils import sample\n", 292 | "\n", 293 | "context = \"O God, I code but\"\n", 294 | "x = torch.tensor([train_dataset.stoi[s] for s in context], dtype=torch.long)[None,...].to(model.device)\n", 295 | "y = sample(model, x, 1000, temperature=0.9, sample=True, top_k=5)[0]\n", 296 | "completion = ''.join([train_dataset.itos[int(i)] for i in y])\n", 297 | "print(completion)" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": null, 303 | "metadata": {}, 304 | "outputs": [], 305 | "source": [ 306 | "# well that was fun..." 
307 | ] 308 | } 309 | ], 310 | "metadata": { 311 | "kernelspec": { 312 | "display_name": "Python 3", 313 | "language": "python", 314 | "name": "python3" 315 | }, 316 | "language_info": { 317 | "codemirror_mode": { 318 | "name": "ipython", 319 | "version": 3 320 | }, 321 | "file_extension": ".py", 322 | "mimetype": "text/x-python", 323 | "name": "python", 324 | "nbconvert_exporter": "python", 325 | "pygments_lexer": "ipython3", 326 | "version": "3.7.8" 327 | } 328 | }, 329 | "nbformat": 4, 330 | "nbformat_minor": 4 331 | } 332 | -------------------------------------------------------------------------------- /play_math.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Train GPT on addition\n", 8 | "\n", 9 | "Train a GPT model on a dedicated addition dataset to see if a Transformer can learn to add." 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 12, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "# set up logging\n", 19 | "import logging\n", 20 | "logging.basicConfig(\n", 21 | " format=\"%(asctime)s - %(levelname)s - %(name)s - %(message)s\",\n", 22 | " datefmt=\"%m/%d/%Y %H:%M:%S\",\n", 23 | " level=logging.INFO,\n", 24 | ")" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": 13, 30 | "metadata": {}, 31 | "outputs": [], 32 | "source": [ 33 | "# make deterministic\n", 34 | "from mingpt.utils import set_seed\n", 35 | "set_seed(42)" 36 | ] 37 | }, 38 | { 39 | "cell_type": "code", 40 | "execution_count": 14, 41 | "metadata": {}, 42 | "outputs": [], 43 | "source": [ 44 | "import numpy as np\n", 45 | "import torch" 46 | ] 47 | }, 48 | { 49 | "cell_type": "code", 50 | "execution_count": 15, 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "from torch.utils.data import Dataset, DataLoader\n", 55 | "\n", 56 | "class AdditionDataset(Dataset):\n", 57 | " \"\"\"\n", 58 | " Returns addition problems of up to some number of digits in the inputs. Recall\n", 59 | " that all GPT cares about are sequences of integers, and completing them according to\n", 60 | " patterns in the data. Therefore, we have to somehow encode addition problems\n", 61 | " as a sequence of integers.\n", 62 | " \n", 63 | " The sum of two n-digit numbers gives a third up to (n+1)-digit number. So our\n", 64 | " encoding will simply be the n-digit first number, n-digit second number, \n", 65 | " and (n+1)-digit result, all simply concatenated together. Because each addition\n", 66 | " problem is so structured, there is no need to bother the model with encoding\n", 67 | " +, =, or other tokens. Each possible sequence has the same length, and simply\n", 68 | " contains the raw digits of the addition problem.\n", 69 | " \n", 70 | " As a few examples, the 2-digit problems:\n", 71 | " - 85 + 50 = 135 becomes the sequence [8, 5, 5, 0, 1, 3, 5]\n", 72 | " - 6 + 39 = 45 becomes the sequence [0, 6, 3, 9, 0, 4, 5]\n", 73 | " etc.\n", 74 | " \n", 75 | " We will also only train GPT on the final (n+1)-digits because the first\n", 76 | " two n-digits are always assumed to be given. So when we give GPT an exam later,\n", 77 | " we will e.g. 
feed it the sequence [0, 6, 3, 9], which encodes that we'd like\n", 78 | " to add 6 + 39, and hope that the model completes the integer sequence with [0, 4, 5]\n", 79 | " in 3 sequential steps.\n", 80 | " \n", 81 | " fun exercise: does it help if the result is asked to be produced in reverse order?\n", 82 | " \"\"\"\n", 83 | "\n", 84 | " def __init__(self, ndigit, split):\n", 85 | " self.split = split # train/test\n", 86 | " self.ndigit = ndigit\n", 87 | " self.vocab_size = 10 # 10 possible digits 0..9\n", 88 | " # +1 due to potential carry overflow, but then -1 because very last digit doesn't plug back\n", 89 | " self.block_size = ndigit + ndigit + ndigit + 1 - 1\n", 90 | " \n", 91 | " # split up all addition problems into either training data or test data\n", 92 | " num = (10**self.ndigit)**2 # total number of possible combinations\n", 93 | " r = np.random.RandomState(1337) # make deterministic\n", 94 | " perm = r.permutation(num)\n", 95 | " num_test = min(int(num*0.2), 1000) # 20% of the whole dataset, or only up to 1000\n", 96 | " self.ixes = perm[:num_test] if split == 'test' else perm[num_test:]\n", 97 | "\n", 98 | " def __len__(self):\n", 99 | " return self.ixes.size\n", 100 | "\n", 101 | " def __getitem__(self, idx):\n", 102 | " # given a problem index idx, first recover the associated a + b\n", 103 | " idx = self.ixes[idx]\n", 104 | " nd = 10**self.ndigit\n", 105 | " a = idx // nd\n", 106 | " b = idx % nd\n", 107 | " c = a + b\n", 108 | " render = f'%0{self.ndigit}d%0{self.ndigit}d%0{self.ndigit+1}d' % (a,b,c) # e.g. 03+25=28 becomes \"0325028\" \n", 109 | " dix = [int(s) for s in render] # convert each character to its token index\n", 110 | " # x will be input to GPT and y will be the associated expected outputs\n", 111 | " x = torch.tensor(dix[:-1], dtype=torch.long)\n", 112 | " y = torch.tensor(dix[1:], dtype=torch.long) # predict the next token in the sequence\n", 113 | " y[:self.ndigit*2-1] = -100 # we will only train in the output locations. -100 will mask loss to zero\n", 114 | " return x, y\n" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 16, 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [ 123 | "# create a dataset for e.g. 
2-digit addition\n", 124 | "ndigit = 2\n", 125 | "train_dataset = AdditionDataset(ndigit=ndigit, split='train')\n", 126 | "test_dataset = AdditionDataset(ndigit=ndigit, split='test')\n", 127 | "train_dataloader = DataLoader(train_dataset, batch_size=512, num_workers=4)\n", 128 | "val_dataloader = DataLoader(test_dataset, batch_size=512, num_workers=4)" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 17, 134 | "metadata": {}, 135 | "outputs": [ 136 | { 137 | "data": { 138 | "text/plain": [ 139 | "(tensor([4, 7, 1, 7, 0, 6]), tensor([-100, -100, -100, 0, 6, 4]))" 140 | ] 141 | }, 142 | "execution_count": 17, 143 | "metadata": {}, 144 | "output_type": "execute_result" 145 | } 146 | ], 147 | "source": [ 148 | "train_dataset[0] # sample a training instance just to see what one raw example looks like" 149 | ] 150 | }, 151 | { 152 | "cell_type": "code", 153 | "execution_count": 19, 154 | "metadata": {}, 155 | "outputs": [ 156 | { 157 | "name": "stderr", 158 | "output_type": "stream", 159 | "text": [ 160 | "08/19/2020 16:20:44 - INFO - mingpt.model - number of parameters: 4.001280e+05\n" 161 | ] 162 | } 163 | ], 164 | "source": [ 165 | "from mingpt.model import GPT\n", 166 | "\n", 167 | "# initialize a baby GPT model\n", 168 | "model = GPT(vocab_size=train_dataset.vocab_size,\n", 169 | " block_size=train_dataset.block_size,\n", 170 | " n_layer=2,\n", 171 | " n_head=4,\n", 172 | " n_embd=128,\n", 173 | " learning_rate=6e-4)" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 26, 179 | "metadata": {}, 180 | "outputs": [ 181 | { 182 | "name": "stderr", 183 | "output_type": "stream", 184 | "text": [ 185 | "GPU available: True, used: True\n", 186 | "08/19/2020 16:23:14 - INFO - lightning - GPU available: True, used: True\n", 187 | "TPU available: False, using: 0 TPU cores\n", 188 | "08/19/2020 16:23:14 - INFO - lightning - TPU available: False, using: 0 TPU cores\n", 189 | "CUDA_VISIBLE_DEVICES: [0]\n", 190 | "08/19/2020 16:23:14 - INFO - lightning - CUDA_VISIBLE_DEVICES: [0]\n", 191 | "Using native 16bit precision.\n", 192 | "08/19/2020 16:23:14 - INFO - lightning - Using native 16bit precision.\n", 193 | "\n", 194 | " | Name | Type | Params\n", 195 | "---------------------------------------\n", 196 | "0 | tok_emb | Embedding | 1 K \n", 197 | "1 | drop | Dropout | 0 \n", 198 | "2 | blocks | Sequential | 396 K \n", 199 | "3 | ln_f | LayerNorm | 256 \n", 200 | "4 | head | Linear | 1 K \n", 201 | "08/19/2020 16:23:14 - INFO - lightning - \n", 202 | " | Name | Type | Params\n", 203 | "---------------------------------------\n", 204 | "0 | tok_emb | Embedding | 1 K \n", 205 | "1 | drop | Dropout | 0 \n", 206 | "2 | blocks | Sequential | 396 K \n", 207 | "3 | ln_f | LayerNorm | 256 \n", 208 | "4 | head | Linear | 1 K \n" 209 | ] 210 | }, 211 | { 212 | "data": { 213 | "application/vnd.jupyter.widget-view+json": { 214 | "model_id": "be94b974523446f6846e6291b6e6608b", 215 | "version_major": 2, 216 | "version_minor": 0 217 | }, 218 | "text/plain": [ 219 | "HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…" 220 | ] 221 | }, 222 | "metadata": {}, 223 | "output_type": "display_data" 224 | }, 225 | { 226 | "name": "stdout", 227 | "output_type": "stream", 228 | "text": [ 229 | "\n" 230 | ] 231 | }, 232 | { 233 | "data": { 234 | "text/plain": [ 235 | "1" 236 | ] 237 | }, 238 | "execution_count": 26, 239 | "metadata": {}, 240 | "output_type": "execute_result" 241 | } 242 | ], 243 | "source": [ 244 | 
"from pytorch_lightning import Trainer\n", 245 | "from mingpt.lr_decay import LearningRateDecayCallback\n", 246 | "\n", 247 | "lr_decay = LearningRateDecayCallback(learning_rate=6e-4, warmup_tokens=1024,\n", 248 | " final_tokens=50*len(train_dataset)*(ndigit+1))\n", 249 | "\n", 250 | "trainer = Trainer(gpus=1, precision=16, max_epochs=50, callbacks=[lr_decay], row_log_interval=3)\n", 251 | "trainer.fit(model, train_dataloader, val_dataloader)" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": null, 257 | "metadata": {}, 258 | "outputs": [], 259 | "source": [ 260 | "# now let's give the trained model an addition exam\n", 261 | "from torch.utils.data.dataloader import DataLoader\n", 262 | "from mingpt.utils import sample\n", 263 | "\n", 264 | "def give_exam(dataset, batch_size=32, max_batches=-1):\n", 265 | " \n", 266 | " results = []\n", 267 | " loader = DataLoader(dataset, batch_size=batch_size)\n", 268 | " for b, (x, y) in enumerate(loader):\n", 269 | " x = x.to(model.device)\n", 270 | " d1d2 = x[:, :ndigit*2]\n", 271 | " d1d2d3 = sample(model, d1d2, ndigit+1)\n", 272 | " d3 = d1d2d3[:, -(ndigit+1):]\n", 273 | " factors = torch.tensor([[10**i for i in range(ndigit+1)][::-1]]).to(model.device)\n", 274 | " # decode the integers from individual digits\n", 275 | " d1i = (d1d2[:,:ndigit] * factors[:,1:]).sum(1)\n", 276 | " d2i = (d1d2[:,ndigit:ndigit*2] * factors[:,1:]).sum(1)\n", 277 | " d3i_pred = (d3 * factors).sum(1)\n", 278 | " d3i_gt = d1i + d2i\n", 279 | " correct = (d3i_pred == d3i_gt).cpu() # Software 1.0 vs. Software 2.0 fight RIGHT on this line, lol\n", 280 | " for i in range(x.size(0)):\n", 281 | " results.append(int(correct[i]))\n", 282 | " judge = 'YEP!!!' if correct[i] else 'NOPE'\n", 283 | " if not correct[i]:\n", 284 | " print(\"GPT claims that %03d + %03d = %03d (gt is %03d; %s)\" \n", 285 | " % (d1i[i], d2i[i], d3i_pred[i], d3i_gt[i], judge))\n", 286 | " \n", 287 | " if max_batches >= 0 and b+1 >= max_batches:\n", 288 | " break\n", 289 | "\n", 290 | " print(\"final score: %d/%d = %.2f%% correct\" % (np.sum(results), len(results), 100*np.mean(results)))" 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": null, 296 | "metadata": { 297 | "scrolled": true 298 | }, 299 | "outputs": [], 300 | "source": [ 301 | "# training set: how well did we memorize?\n", 302 | "give_exam(train_dataset, batch_size=1024, max_batches=10)" 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": 25, 308 | "metadata": {}, 309 | "outputs": [ 310 | { 311 | "name": "stdout", 312 | "output_type": "stream", 313 | "text": [ 314 | "final score: 1000/1000 = 100.00% correct\n" 315 | ] 316 | } 317 | ], 318 | "source": [ 319 | "# test set: how well did we generalize?\n", 320 | "give_exam(test_dataset, batch_size=1024, max_batches=-1)" 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": null, 326 | "metadata": {}, 327 | "outputs": [], 328 | "source": [ 329 | "# model got all correct!" 
330 | ] 331 | } 332 | ], 333 | "metadata": { 334 | "kernelspec": { 335 | "display_name": "Python 3", 336 | "language": "python", 337 | "name": "python3" 338 | }, 339 | "language_info": { 340 | "codemirror_mode": { 341 | "name": "ipython", 342 | "version": 3 343 | }, 344 | "file_extension": ".py", 345 | "mimetype": "text/x-python", 346 | "name": "python", 347 | "nbconvert_exporter": "python", 348 | "pygments_lexer": "ipython3", 349 | "version": "3.7.8" 350 | } 351 | }, 352 | "nbformat": 4, 353 | "nbformat_minor": 4 354 | } 355 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pytorch-lightning==0.9.0 2 | --------------------------------------------------------------------------------