├── mingpt ├── __init__.py ├── trainer.py ├── utils.py ├── bpe.py └── model.py ├── mingpt.jpg ├── .gitignore ├── projects ├── adder │ ├── readme.md │ └── adder.py ├── readme.md └── chargpt │ ├── readme.md │ └── chargpt.py ├── LICENSE ├── tests └── test_huggingface_import.py ├── generate.ipynb ├── README.md └── demo.ipynb /mingpt/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /mingpt.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/nat/minGPT/master/mingpt.jpg -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints/ 2 | __pycache__/ 3 | *.swp 4 | .env 5 | .pylintrc 6 | -------------------------------------------------------------------------------- /projects/adder/readme.md: -------------------------------------------------------------------------------- 1 | 2 | ### adder 3 | 4 | Train a GPT model to add n-digit numbers 5 | -------------------------------------------------------------------------------- /projects/readme.md: -------------------------------------------------------------------------------- 1 | 2 | ### minGPT projects 3 | 4 | Various projects that use the minGPT library to achieve great things. 5 | -------------------------------------------------------------------------------- /projects/chargpt/readme.md: -------------------------------------------------------------------------------- 1 | # chargpt 2 | 3 | chargpt trains a character-level language model. 4 | 5 | We support three settings: 1 convenience setting and 2 "benchmark" settings that have academic literature results: 6 | 7 | - a user specified `input.txt` file that we train an LM on (e.g.
get tiny-shakespeare (1.1MB of data) [here](https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt)) 8 | - TODO [text8](http://mattmahoney.net/dc/textdata.html): also derived from Wikipedia text but all XML is removed and is lowercased to only 26 characters of English plus spaces 9 | - TODO [enwik8](http://prize.hutter1.net) benchmark ("Hutter Prize"), first 100M bytes of a Wikipedia XML dump, with 205 unique tokens 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | The MIT License (MIT) Copyright (c) 2020 Andrej Karpathy 2 | 3 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 4 | 5 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 6 | 7 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
8 | -------------------------------------------------------------------------------- /tests/test_huggingface_import.py: -------------------------------------------------------------------------------- 1 | """ 2 | Ensure that we can load huggingface/transformer GPTs into minGPT 3 | """ 4 | 5 | import unittest 6 | import torch 7 | from transformers import GPT2Tokenizer, GPT2LMHeadModel 8 | from mingpt.model import GPT 9 | from mingpt.bpe import BPETokenizer 10 | # ----------------------------------------------------------------------------- 11 | 12 | class TestHuggingFaceImport(unittest.TestCase): 13 | 14 | def test_gpt2(self): 15 | model_type = 'gpt2' 16 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 17 | prompt = "Hello!!!!!!!!!? 🤗, my dog is a little" 18 | 19 | # create a minGPT and a huggingface/transformers model 20 | model = GPT.from_pretrained(model_type) 21 | model_hf = GPT2LMHeadModel.from_pretrained(model_type) # init a HF model too 22 | 23 | # ship both to device 24 | model.to(device) 25 | model_hf.to(device) 26 | 27 | # set both to eval mode 28 | model.eval() 29 | model_hf.eval() 30 | 31 | # tokenize input prompt 32 | # ... with mingpt 33 | tokenizer = BPETokenizer() 34 | x1 = tokenizer(prompt).to(device) 35 | # ... 
with huggingface/transformers 36 | tokenizer_hf = GPT2Tokenizer.from_pretrained(model_type) 37 | model_hf.config.pad_token_id = model_hf.config.eos_token_id # suppress a warning 38 | encoded_input = tokenizer_hf(prompt, return_tensors='pt').to(device) 39 | x2 = encoded_input['input_ids'] 40 | 41 | # ensure the logits match exactly 42 | logits1, loss = model(x1) 43 | logits2 = model_hf(x2).logits 44 | self.assertTrue(torch.allclose(logits1, logits2)) 45 | 46 | # now draw the argmax samples from each 47 | y1 = model.generate(x1, max_new_tokens=20, do_sample=False)[0] 48 | y2 = model_hf.generate(x2, max_new_tokens=20, do_sample=False)[0] 49 | self.assertTrue(torch.equal(y1, y2)) # compare the raw sampled indices 50 | 51 | # convert indices to strings 52 | out1 = tokenizer.decode(y1.cpu().squeeze()) 53 | out2 = tokenizer_hf.decode(y2.cpu().squeeze()) 54 | self.assertTrue(out1 == out2) # compare the exact output strings too 55 | 56 | if __name__ == '__main__': 57 | unittest.main() 58 | -------------------------------------------------------------------------------- /mingpt/trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple training loop; Boilerplate that could apply to any arbitrary neural network, 3 | so nothing in this file really has anything to do with GPT specifically. 
4 | """ 5 | 6 | import time 7 | from collections import defaultdict 8 | 9 | import torch 10 | from torch.utils.data.dataloader import DataLoader 11 | from mingpt.utils import CfgNode as CN 12 | 13 | class Trainer: 14 | 15 | @staticmethod 16 | def get_default_config(): 17 | C = CN() 18 | # device to train on 19 | C.device = 'auto' 20 | # dataloder parameters 21 | C.num_workers = 4 22 | # optimizer parameters 23 | C.max_iters = None 24 | C.batch_size = 64 25 | C.learning_rate = 3e-4 26 | C.betas = (0.9, 0.95) 27 | C.weight_decay = 0.1 # only applied on matmul weights 28 | C.grad_norm_clip = 1.0 29 | return C 30 | 31 | def __init__(self, config, model, train_dataset): 32 | self.config = config 33 | self.model = model 34 | self.train_dataset = train_dataset 35 | self.callbacks = defaultdict(list) 36 | 37 | # determine the device we'll train on 38 | if config.device == 'auto': 39 | self.device = 'cuda' if torch.cuda.is_available() else 'cpu' 40 | else: 41 | self.device = config.device 42 | self.model = self.model.to(self.device) 43 | print("running on device", self.device) 44 | 45 | # variables that will be assigned to trainer class later for logging and etc 46 | self.iter_num = 0 47 | self.iter_time = 0.0 48 | self.iter_dt = 0.0 49 | 50 | def add_callback(self, onevent: str, callback): 51 | self.callbacks[onevent].append(callback) 52 | 53 | def set_callback(self, onevent: str, callback): 54 | self.callbacks[onevent] = [callback] 55 | 56 | def trigger_callbacks(self, onevent: str): 57 | for callback in self.callbacks.get(onevent, []): 58 | callback(self) 59 | 60 | def run(self): 61 | model, config = self.model, self.config 62 | 63 | # setup the optimizer 64 | optimizer = model.configure_optimizers(config) 65 | 66 | # setup the dataloader 67 | train_loader = DataLoader( 68 | self.train_dataset, 69 | sampler=torch.utils.data.RandomSampler(self.train_dataset, replacement=True, num_samples=int(1e10)), 70 | shuffle=False, 71 | pin_memory=True, 72 | 
batch_size=config.batch_size, 73 | num_workers=config.num_workers, 74 | ) 75 | 76 | model.train() 77 | self.iter_num = 0 78 | self.iter_time = time.time() 79 | data_iter = iter(train_loader) 80 | while True: 81 | 82 | # fetch the next batch (x, y) and re-init iterator if needed 83 | try: 84 | batch = next(data_iter) 85 | except StopIteration: 86 | data_iter = iter(train_loader) 87 | batch = next(data_iter) 88 | batch = [t.to(self.device) for t in batch] 89 | x, y = batch 90 | 91 | # forward the model 92 | logits, self.loss = model(x, y) 93 | 94 | # backprop and update the parameters 95 | model.zero_grad(set_to_none=True) 96 | self.loss.backward() 97 | torch.nn.utils.clip_grad_norm_(model.parameters(), config.grad_norm_clip) 98 | optimizer.step() 99 | 100 | self.trigger_callbacks('on_batch_end') 101 | self.iter_num += 1 102 | tnow = time.time() 103 | self.iter_dt = tnow - self.iter_time 104 | self.iter_time = tnow 105 | 106 | # termination conditions 107 | if config.max_iters is not None and self.iter_num >= config.max_iters: 108 | break 109 | -------------------------------------------------------------------------------- /mingpt/utils.py: -------------------------------------------------------------------------------- 1 | 2 | import os 3 | import sys 4 | import json 5 | import random 6 | from ast import literal_eval 7 | 8 | import numpy as np 9 | import torch 10 | 11 | # ----------------------------------------------------------------------------- 12 | 13 | def set_seed(seed): 14 | random.seed(seed) 15 | np.random.seed(seed) 16 | torch.manual_seed(seed) 17 | torch.cuda.manual_seed_all(seed) 18 | 19 | def setup_logging(config): 20 | """ monotonous bookkeeping """ 21 | work_dir = config.system.work_dir 22 | # create the work directory if it doesn't already exist 23 | os.makedirs(work_dir, exist_ok=True) 24 | # log the args (if any) 25 | with open(os.path.join(work_dir, 'args.txt'), 'w') as f: 26 | f.write(' '.join(sys.argv)) 27 | # log the config itself 28 | with 
open(os.path.join(work_dir, 'config.json'), 'w') as f: 29 | f.write(json.dumps(config.to_dict(), indent=4)) 30 | 31 | class CfgNode: 32 | """ a lightweight configuration class inspired by yacs """ 33 | # TODO: convert to subclass from a dict like in yacs? 34 | # TODO: implement freezing to prevent shooting of own foot 35 | # TODO: additional existence/override checks when reading/writing params? 36 | 37 | def __init__(self, **kwargs): 38 | self.__dict__.update(kwargs) 39 | 40 | def __str__(self): 41 | return self._str_helper(0) 42 | 43 | def _str_helper(self, indent): 44 | """ need to have a helper to support nested indentation for pretty printing """ 45 | parts = [] 46 | for k, v in self.__dict__.items(): 47 | if isinstance(v, CfgNode): 48 | parts.append("%s:\n" % k) 49 | parts.append(v._str_helper(indent + 1)) 50 | else: 51 | parts.append("%s: %s\n" % (k, v)) 52 | parts = [' ' * (indent * 4) + p for p in parts] 53 | return "".join(parts) 54 | 55 | def to_dict(self): 56 | """ return a dict representation of the config """ 57 | return { k: v.to_dict() if isinstance(v, CfgNode) else v for k, v in self.__dict__.items() } 58 | 59 | def merge_from_dict(self, d): 60 | self.__dict__.update(d) 61 | 62 | def merge_from_args(self, args): 63 | """ 64 | update the configuration from a list of strings that is expected 65 | to come from the command line, i.e. sys.argv[1:]. 66 | 67 | The arguments are expected to be in the form of `--arg=value`, and 68 | the arg can use . to denote nested sub-attributes. Example: 69 | 70 | --model.n_layer=10 --trainer.batch_size=32 71 | """ 72 | for arg in args: 73 | 74 | keyval = arg.split('=') 75 | assert len(keyval) == 2, "expecting each override arg to be of form --arg=value, got %s" % arg 76 | key, val = keyval # unpack 77 | 78 | # first translate val into a python object 79 | try: 80 | val = literal_eval(val) 81 | """ 82 | need some explanation here. 
83 | - if val is simply a string, literal_eval will throw a ValueError 84 | - if val represents a thing (like an 3, 3.14, [1,2,3], False, None, etc.) it will get created 85 | """ 86 | except ValueError: 87 | pass 88 | 89 | # find the appropriate object to insert the attribute into 90 | assert key[:2] == '--' 91 | key = key[2:] # strip the '--' 92 | keys = key.split('.') 93 | obj = self 94 | for k in keys[:-1]: 95 | obj = getattr(obj, k) 96 | leaf_key = keys[-1] 97 | 98 | # ensure that this attribute exists 99 | assert hasattr(obj, leaf_key), f"{key} is not an attribute that exists in the config" 100 | 101 | # overwrite the attribute 102 | print("command line overwriting config attribute %s with %s" % (key, val)) 103 | setattr(obj, leaf_key, val) 104 | -------------------------------------------------------------------------------- /projects/chargpt/chargpt.py: -------------------------------------------------------------------------------- 1 | """ 2 | Trains a character-level language model. 
3 | """ 4 | 5 | import os 6 | import sys 7 | 8 | import torch 9 | from torch.utils.data import Dataset 10 | from torch.utils.data.dataloader import DataLoader 11 | 12 | from mingpt.model import GPT 13 | from mingpt.trainer import Trainer 14 | from mingpt.utils import set_seed, setup_logging, CfgNode as CN 15 | 16 | # ----------------------------------------------------------------------------- 17 | 18 | def get_config(): 19 | 20 | C = CN() 21 | 22 | # system 23 | C.system = CN() 24 | C.system.seed = 3407 25 | C.system.work_dir = './out/chargpt' 26 | 27 | # data 28 | C.data = CharDataset.get_default_config() 29 | 30 | # model 31 | C.model = GPT.get_default_config() 32 | C.model.model_type = 'gpt-mini' 33 | 34 | # trainer 35 | C.trainer = Trainer.get_default_config() 36 | C.trainer.learning_rate = 5e-4 # the model we're using is so small that we can go a bit faster 37 | 38 | return C 39 | 40 | # ----------------------------------------------------------------------------- 41 | 42 | class CharDataset(Dataset): 43 | """ 44 | Emits batches of characters 45 | """ 46 | 47 | @staticmethod 48 | def get_default_config(): 49 | C = CN() 50 | C.block_size = 128 51 | return C 52 | 53 | def __init__(self, config, data): 54 | self.config = config 55 | 56 | chars = sorted(list(set(data))) 57 | data_size, vocab_size = len(data), len(chars) 58 | print('data has %d characters, %d unique.' 
% (data_size, vocab_size)) 59 | 60 | self.stoi = { ch:i for i,ch in enumerate(chars) } 61 | self.itos = { i:ch for i,ch in enumerate(chars) } 62 | self.vocab_size = vocab_size 63 | self.data = data 64 | 65 | def get_vocab_size(self): 66 | return self.vocab_size 67 | 68 | def get_block_size(self): 69 | return self.config.block_size 70 | 71 | def __len__(self): 72 | return len(self.data) - self.config.block_size 73 | 74 | def __getitem__(self, idx): 75 | # grab a chunk of (block_size + 1) characters from the data 76 | chunk = self.data[idx:idx + self.config.block_size + 1] 77 | # encode every character to an integer 78 | dix = [self.stoi[s] for s in chunk] 79 | # return as tensors 80 | x = torch.tensor(dix[:-1], dtype=torch.long) 81 | y = torch.tensor(dix[1:], dtype=torch.long) 82 | return x, y 83 | 84 | # ----------------------------------------------------------------------------- 85 | 86 | if __name__ == '__main__': 87 | 88 | # get default config and overrides from the command line, if any 89 | config = get_config() 90 | config.merge_from_args(sys.argv[1:]) 91 | print(config) 92 | setup_logging(config) 93 | set_seed(config.system.seed) 94 | 95 | # construct the training dataset 96 | text = open('input.txt', 'r').read() # don't worry we won't run out of file handles 97 | train_dataset = CharDataset(config.data, text) 98 | 99 | # construct the model 100 | config.model.vocab_size = train_dataset.get_vocab_size() 101 | config.model.block_size = train_dataset.get_block_size() 102 | model = GPT(config.model) 103 | 104 | # construct the trainer object 105 | trainer = Trainer(config.trainer, model, train_dataset) 106 | 107 | # iteration callback 108 | def batch_end_callback(trainer): 109 | 110 | if trainer.iter_num % 10 == 0: 111 | print(f"iter_dt {trainer.iter_dt * 1000:.2f}ms; iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}") 112 | 113 | if trainer.iter_num % 500 == 0: 114 | # evaluate both the train and test score 115 | model.eval() 116 | with 
torch.no_grad(): 117 | # sample from the model... 118 | context = "O God, O God!" 119 | x = torch.tensor([train_dataset.stoi[s] for s in context], dtype=torch.long)[None,...].to(trainer.device) 120 | y = model.generate(x, 500, temperature=1.0, do_sample=True, top_k=10)[0] 121 | completion = ''.join([train_dataset.itos[int(i)] for i in y]) 122 | print(completion) 123 | # save the latest model 124 | print("saving model") 125 | ckpt_path = os.path.join(config.system.work_dir, "model.pt") 126 | torch.save(model.state_dict(), ckpt_path) 127 | # revert model to training mode 128 | model.train() 129 | 130 | trainer.set_callback('on_batch_end', batch_end_callback) 131 | 132 | # run the optimization 133 | trainer.run() 134 | -------------------------------------------------------------------------------- /generate.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "Shows how one can generate text given a prompt and some hyperparameters, using either minGPT or huggingface/transformers" 8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import torch\n", 17 | "from transformers import GPT2Tokenizer, GPT2LMHeadModel\n", 18 | "from mingpt.model import GPT\n", 19 | "from mingpt.utils import set_seed\n", 20 | "from mingpt.bpe import BPETokenizer\n", 21 | "set_seed(3407)" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 2, 27 | "metadata": {}, 28 | "outputs": [], 29 | "source": [ 30 | "use_mingpt = True # use minGPT or huggingface/transformers model?\n", 31 | "model_type = 'gpt2-xl'\n", 32 | "device = 'cuda'" 33 | ] 34 | }, 35 | { 36 | "cell_type": "code", 37 | "execution_count": 3, 38 | "metadata": {}, 39 | "outputs": [ 40 | { 41 | "name": "stdout", 42 | "output_type": "stream", 43 | "text": [ 44 | "number of parameters: 1557.61M\n" 45 | ] 46 
| } 47 | ], 48 | "source": [ 49 | "if use_mingpt:\n", 50 | " model = GPT.from_pretrained(model_type)\n", 51 | "else:\n", 52 | " model = GPT2LMHeadModel.from_pretrained(model_type)\n", 53 | " model.config.pad_token_id = model.config.eos_token_id # suppress a warning\n", 54 | "\n", 55 | "# ship model to device and set to eval mode\n", 56 | "model.to(device)\n", 57 | "model.eval();" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 4, 63 | "metadata": {}, 64 | "outputs": [], 65 | "source": [ 66 | "\n", 67 | "def generate(prompt='', num_samples=10, steps=20, do_sample=True):\n", 68 | " \n", 69 | " # tokenize the input prompt into integer input sequence\n", 70 | " if use_mingpt:\n", 71 | " tokenizer = BPETokenizer()\n", 72 | " if prompt == '':\n", 73 | " # to create unconditional samples...\n", 74 | " # manually create a tensor with only the special <|endoftext|> token\n", 75 | " # similar to what openai's code does here https://github.com/openai/gpt-2/blob/master/src/generate_unconditional_samples.py\n", 76 | " x = torch.tensor([[tokenizer.encoder.encoder['<|endoftext|>']]], dtype=torch.long)\n", 77 | " else:\n", 78 | " x = tokenizer(prompt).to(device)\n", 79 | " else:\n", 80 | " tokenizer = GPT2Tokenizer.from_pretrained(model_type)\n", 81 | " if prompt == '': \n", 82 | " # to create unconditional samples...\n", 83 | " # huggingface/transformers tokenizer special cases these strings\n", 84 | " prompt = '<|endoftext|>'\n", 85 | " encoded_input = tokenizer(prompt, return_tensors='pt').to(device)\n", 86 | " x = encoded_input['input_ids']\n", 87 | " \n", 88 | " # we'll process all desired num_samples in a batch, so expand out the batch dim\n", 89 | " x = x.expand(num_samples, -1)\n", 90 | "\n", 91 | " # forward the model `steps` times to get samples, in a batch\n", 92 | " y = model.generate(x, max_new_tokens=steps, do_sample=do_sample, top_k=40)\n", 93 | " \n", 94 | " for i in range(num_samples):\n", 95 | " out = 
tokenizer.decode(y[i].cpu().squeeze())\n", 96 | " print('-'*80)\n", 97 | " print(out)\n", 98 | " " 99 | ] 100 | }, 101 | { 102 | "cell_type": "code", 103 | "execution_count": 5, 104 | "metadata": {}, 105 | "outputs": [ 106 | { 107 | "name": "stdout", 108 | "output_type": "stream", 109 | "text": [ 110 | "--------------------------------------------------------------------------------\n", 111 | "Andrej Karpathy, the chief of the criminal investigation department, said during a news conference, \"We still have a lot of\n", 112 | "--------------------------------------------------------------------------------\n", 113 | "Andrej Karpathy, the man whom most of America believes is the architect of the current financial crisis. He runs the National Council\n", 114 | "--------------------------------------------------------------------------------\n", 115 | "Andrej Karpathy, the head of the Department for Regional Reform of Bulgaria and an MP in the centre-right GERB party\n", 116 | "--------------------------------------------------------------------------------\n", 117 | "Andrej Karpathy, the former head of the World Bank's IMF department, who worked closely with the IMF. 
The IMF had\n", 118 | "--------------------------------------------------------------------------------\n", 119 | "Andrej Karpathy, the vice president for innovation and research at Citi who oversaw the team's work to make sense of the\n", 120 | "--------------------------------------------------------------------------------\n", 121 | "Andrej Karpathy, the CEO of OOAK Research, said that the latest poll indicates that it won't take much to\n", 122 | "--------------------------------------------------------------------------------\n", 123 | "Andrej Karpathy, the former prime minister of Estonia was at the helm of a three-party coalition when parliament met earlier this\n", 124 | "--------------------------------------------------------------------------------\n", 125 | "Andrej Karpathy, the director of the Institute of Economic and Social Research, said if the rate of return is only 5 per\n", 126 | "--------------------------------------------------------------------------------\n", 127 | "Andrej Karpathy, the minister of commerce for Latvia's western neighbour: \"The deal means that our two countries have reached more\n", 128 | "--------------------------------------------------------------------------------\n", 129 | "Andrej Karpathy, the state's environmental protection commissioner. 
\"That's why we have to keep these systems in place.\"\n", 130 | "\n" 131 | ] 132 | } 133 | ], 134 | "source": [ 135 | "generate(prompt='Andrej Karpathy, the', num_samples=10, steps=20)" 136 | ] 137 | } 138 | ], 139 | "metadata": { 140 | "kernelspec": { 141 | "display_name": "Python 3.10.4 64-bit", 142 | "language": "python", 143 | "name": "python3" 144 | }, 145 | "language_info": { 146 | "codemirror_mode": { 147 | "name": "ipython", 148 | "version": 3 149 | }, 150 | "file_extension": ".py", 151 | "mimetype": "text/x-python", 152 | "name": "python", 153 | "nbconvert_exporter": "python", 154 | "pygments_lexer": "ipython3", 155 | "version": "3.10.4" 156 | }, 157 | "orig_nbformat": 4, 158 | "vscode": { 159 | "interpreter": { 160 | "hash": "3ad933181bd8a04b432d3370b9dc3b0662ad032c4dfaa4e4f1596c548f763858" 161 | } 162 | } 163 | }, 164 | "nbformat": 4, 165 | "nbformat_minor": 2 166 | } 167 | -------------------------------------------------------------------------------- /projects/adder/adder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Trains a GPT to add n-digit numbers. 
3 | """ 4 | 5 | import os 6 | import sys 7 | import json 8 | 9 | import torch 10 | from torch.utils.data import Dataset 11 | from torch.utils.data.dataloader import DataLoader 12 | 13 | from mingpt.model import GPT 14 | from mingpt.trainer import Trainer 15 | from mingpt.utils import set_seed, setup_logging, CfgNode as CN 16 | 17 | # ----------------------------------------------------------------------------- 18 | 19 | def get_config(): 20 | 21 | C = CN() 22 | 23 | # system 24 | C.system = CN() 25 | C.system.seed = 3407 26 | C.system.work_dir = './out/adder' 27 | 28 | # data 29 | C.data = AdditionDataset.get_default_config() 30 | 31 | # model 32 | C.model = GPT.get_default_config() 33 | C.model.model_type = 'gpt-nano' 34 | 35 | # trainer 36 | C.trainer = Trainer.get_default_config() 37 | C.trainer.learning_rate = 5e-4 # the model we're using is so small that we can go a bit faster 38 | 39 | return C 40 | 41 | # ----------------------------------------------------------------------------- 42 | 43 | class AdditionDataset(Dataset): 44 | """ 45 | Creates n-digit addition problems. For example, if n=2, then an example 46 | addition problem would be to add 85 + 50 = 135. This problem would be 47 | represented as the following string for the GPT: 48 | 49 | "8550531" 50 | 51 | This is because: 52 | - we are discarding the + and =, which are not necessary. We just encode the digits 53 | of the input numbers concatenated together. 54 | - the result 135 is encoded backwards to make the addition easier to learn for the 55 | GPT model, because of how the addition algorithm works. 56 | 57 | As one more example, the problem 6 + 39 = 45 would be encoded as: 58 | 59 | "0639054" 60 | 61 | where you will notice that we are padding with zeros to make sure that we always 62 | produce strings of the exact same size: n + n + (n + 1). When n=2, this is 7. 
63 | At test time, we will feed in an addition problem by giving the first 2n digits, 64 | and hoping that the GPT model completes the sequence with the next (n+1) digits 65 | correctly. 66 | """ 67 | 68 | @staticmethod 69 | def get_default_config(): 70 | C = CN() 71 | C.ndigit = 2 72 | return C 73 | 74 | def __init__(self, config, split): 75 | self.config = config 76 | self.split = split # train/test 77 | 78 | # split up all addition problems into either training data or test data 79 | ndigit = self.config.ndigit 80 | assert ndigit <= 3, "the lines below would be very memory inefficient, in future maybe refactor to support" 81 | num = (10**ndigit)**2 # total number of possible addition problems with ndigit numbers 82 | rng = torch.Generator() 83 | rng.manual_seed(1337) 84 | perm = torch.randperm(num, generator=rng) 85 | num_test = min(int(num*0.2), 500) # 20% of the whole dataset, or only up to 500 86 | self.ixes = perm[:num_test] if split == 'test' else perm[num_test:] 87 | 88 | def get_vocab_size(self): 89 | return 10 # digits 0..9 90 | 91 | def get_block_size(self): 92 | # a,b,a+b, and +1 due to potential carry overflow, 93 | # but then also -1 because very last digit doesn't ever plug back 94 | # as there is no explicit token to predict, it is implied 95 | return 3*self.config.ndigit + 1 - 1 96 | 97 | def __len__(self): 98 | return self.ixes.nelement() 99 | 100 | def __getitem__(self, idx): 101 | ndigit = self.config.ndigit 102 | # given a problem index idx, first recover the associated a + b 103 | idx = self.ixes[idx].item() 104 | nd = 10**ndigit 105 | a = idx // nd 106 | b = idx % nd 107 | # calculate the "label" of the addition problem a + b 108 | c = a + b 109 | # encode the digits of a, b, c into strings 110 | astr = f'%0{ndigit}d' % a 111 | bstr = f'%0{ndigit}d' % b 112 | cstr = (f'%0{ndigit+1}d' % c)[::-1] # reverse c to make addition easier 113 | render = astr + bstr + cstr 114 | dix = [int(s) for s in render] # convert each character to its token 
index 115 | # x will be input to GPT and y will be the associated expected outputs 116 | x = torch.tensor(dix[:-1], dtype=torch.long) 117 | y = torch.tensor(dix[1:], dtype=torch.long) # predict the next token in the sequence 118 | y[:ndigit*2-1] = -1 # we will only train in the output locations. -1 will mask loss to zero 119 | return x, y 120 | 121 | # ----------------------------------------------------------------------------- 122 | 123 | if __name__ == '__main__': 124 | 125 | # get default config and overrides from the command line, if any 126 | config = get_config() 127 | config.merge_from_args(sys.argv[1:]) 128 | print(config) 129 | setup_logging(config) 130 | set_seed(config.system.seed) 131 | 132 | # construct train and test datasets 133 | train_dataset = AdditionDataset(config.data, split='train') 134 | test_dataset = AdditionDataset(config.data, split='test') 135 | 136 | # construct the model 137 | config.model.vocab_size = train_dataset.get_vocab_size() 138 | config.model.block_size = train_dataset.get_block_size() 139 | model = GPT(config.model) 140 | 141 | # construct the trainer object 142 | trainer = Trainer(config.trainer, model, train_dataset) 143 | 144 | # helper function for the evaluation of a model 145 | def eval_split(trainer, split, max_batches=None): 146 | dataset = {'train':train_dataset, 'test':test_dataset}[split] 147 | ndigit = config.data.ndigit 148 | results = [] 149 | mistakes_printed_already = 0 150 | factors = torch.tensor([[10**i for i in range(ndigit+1)][::-1]]).to(trainer.device) 151 | loader = DataLoader(dataset, batch_size=100, num_workers=0, drop_last=False) 152 | for b, (x, y) in enumerate(loader): 153 | x = x.to(trainer.device) 154 | # isolate the first two digits of the input sequence alone 155 | d1d2 = x[:, :ndigit*2] 156 | # let the model sample the rest of the sequence 157 | d1d2d3 = model.generate(d1d2, ndigit+1, do_sample=False) # using greedy argmax, not sampling 158 | # isolate the last digit of the sampled sequence 
159 | d3 = d1d2d3[:, -(ndigit+1):] 160 | d3 = d3.flip(1) # reverse the digits to their "normal" order 161 | # decode the integers from individual digits 162 | d1i = (d1d2[:,:ndigit] * factors[:,1:]).sum(1) 163 | d2i = (d1d2[:,ndigit:ndigit*2] * factors[:,1:]).sum(1) 164 | d3i_pred = (d3 * factors).sum(1) 165 | d3i_gt = d1i + d2i # manually calculate the ground truth 166 | # evaluate the correctness of the results in this batch 167 | correct = (d3i_pred == d3i_gt).cpu() # Software 1.0 vs. Software 2.0 fight RIGHT on this line haha 168 | for i in range(x.size(0)): 169 | results.append(int(correct[i])) 170 | if not correct[i] and mistakes_printed_already < 5: # only print up to 5 mistakes to get a sense 171 | mistakes_printed_already += 1 172 | print("GPT claims that %d + %d = %d but gt is %d" % (d1i[i], d2i[i], d3i_pred[i], d3i_gt[i])) 173 | if max_batches is not None and b+1 >= max_batches: 174 | break 175 | rt = torch.tensor(results, dtype=torch.float) 176 | print("%s final score: %d/%d = %.2f%% correct" % (split, rt.sum(), len(results), 100*rt.mean())) 177 | return rt.sum() 178 | 179 | # iteration callback 180 | top_score = 0 181 | def batch_end_callback(trainer): 182 | global top_score 183 | 184 | if trainer.iter_num % 10 == 0: 185 | print(f"iter_dt {trainer.iter_dt * 1000:.2f}ms; iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}") 186 | 187 | if trainer.iter_num % 500 == 0: 188 | # evaluate both the train and test score 189 | train_max_batches = {1: None, 2: None, 3: 5}[config.data.ndigit] # if ndigit=2 we can afford the whole train set, ow no 190 | model.eval() 191 | with torch.no_grad(): 192 | train_score = eval_split(trainer, 'train', max_batches=train_max_batches) 193 | test_score = eval_split(trainer, 'test', max_batches=None) 194 | score = train_score + test_score 195 | # save the model if this is the best score we've seen so far 196 | if score > top_score: 197 | top_score = score 198 | print(f"saving model with new top score of {score}") 199 | 
ckpt_path = os.path.join(config.system.work_dir, "model.pt") 200 | torch.save(model.state_dict(), ckpt_path) 201 | # revert model to training mode 202 | model.train() 203 | 204 | trainer.set_callback('on_batch_end', batch_end_callback) 205 | 206 | # run the optimization 207 | trainer.run() 208 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | # minGPT 3 | 4 | ![mingpt](mingpt.jpg) 5 | 6 | A PyTorch re-implementation of [GPT](https://github.com/openai/gpt-2), both training and inference. minGPT tries to be small, clean, interpretable and educational, as most of the currently available GPT model implementations can be a bit sprawling. GPT is not a complicated model and this implementation is appropriately about 300 lines of code (see [mingpt/model.py](mingpt/model.py)). All that's going on is that a sequence of indices feeds into a [Transformer](https://arxiv.org/abs/1706.03762), and a probability distribution over the next index in the sequence comes out. The majority of the complexity is just being clever with batching (both across examples and over sequence length) for efficiency. 7 | 8 | The minGPT library is three files: [mingpt/model.py](mingpt/model.py) contains the actual Transformer model definition, [mingpt/bpe.py](mingpt/bpe.py) contains a mildly refactored Byte Pair Encoder that translates between text and sequences of integers exactly like OpenAI did in GPT, and [mingpt/trainer.py](mingpt/trainer.py) is (GPT-independent) PyTorch boilerplate code that trains the model.
Then there are a number of demos and projects that use the library in the `projects` folder: 9 | 10 | - `projects/adder` trains a GPT from scratch to add numbers (inspired by the addition section in the GPT-3 paper) 11 | - `projects/chargpt` trains a GPT to be a character-level language model on some input text file 12 | - `demo.ipynb` shows a minimal usage of the `GPT` and `Trainer` in a notebook format on a simple sorting example 13 | - `generate.ipynb` shows how one can load a pretrained GPT2 and generate text given some prompt 14 | 15 | ### Usage 16 | 17 | Here's how you'd instantiate a GPT-2 (124M param version): 18 | 19 | ```python 20 | from mingpt.model import GPT 21 | model_config = GPT.get_default_config() 22 | model_config.model_type = 'gpt2' 23 | model_config.vocab_size = 50257 # openai's model vocabulary 24 | model_config.block_size = 1024 # openai's model block_size (i.e. input context length) 25 | model = GPT(model_config) 26 | ``` 27 | 28 | And here's how you'd train it: 29 | 30 | ```python 31 | # your subclass of torch.utils.data.Dataset that emits example 32 | # torch LongTensor of lengths up to 1024, with integers from [0,50257) 33 | train_dataset = YourDataset() 34 | 35 | from mingpt.trainer import Trainer 36 | train_config = Trainer.get_default_config() 37 | train_config.learning_rate = 5e-4 # many possible options, see the file 38 | train_config.max_iters = 1000 39 | train_config.batch_size = 32 40 | trainer = Trainer(train_config, model, train_dataset) 41 | trainer.run() 42 | ``` 43 | 44 | See `demo.ipynb` for a more concrete example. 
45 | 46 | ### Unit tests 47 | 48 | Coverage is not super amazing just yet but: 49 | 50 | ``` 51 | python -m unittest discover tests 52 | ``` 53 | 54 | ### todos 55 | 56 | - add gpt-2 finetuning demo on arbitrary given text file 57 | - add dialog agent demo 58 | - better docs of outcomes for existing projects (adder, chargpt) 59 | - add mixed precision and related training scaling goodies 60 | - distributed training support 61 | - reproduce some benchmarks in projects/, e.g. text8 or other language modeling 62 | - proper logging instead of print statement amateur hour haha 63 | - i probably should have a requirements.txt file... 64 | - it should be possible to load in many other model weights other than just gpt2-\* 65 | 66 | ### References 67 | 68 | Code: 69 | 70 | - [openai/gpt-2](https://github.com/openai/gpt-2) has the model definition in TensorFlow, but not the training code 71 | - [openai/image-gpt](https://github.com/openai/image-gpt) has some more modern GPT-3-like modifications in its code, good reference as well 72 | - [huggingface/transformers](https://github.com/huggingface/transformers) has a [language-modeling example](https://github.com/huggingface/transformers/tree/master/examples/pytorch/language-modeling). It is full-featured but as a result also somewhat challenging to trace. E.g. some large functions have as much as 90% of their code behind various branching statements that go unused in the default setting of simple language modeling 73 | 74 | Papers + some implementation notes: 75 | 76 | #### Improving Language Understanding by Generative Pre-Training (GPT-1) 77 | 78 | - Our model largely follows the original transformer work 79 | - We trained a 12-layer decoder-only transformer with masked self-attention heads (768 dimensional states and 12 attention heads). For the position-wise feed-forward networks, we used 3072 dimensional inner states. 80 | - Adam max learning rate of 2.5e-4.
(later GPT-3 for this model size uses 6e-4) 81 | - LR decay: increased linearly from zero over the first 2000 updates and annealed to 0 using a cosine schedule 82 | - We train for 100 epochs on minibatches of 64 randomly sampled, contiguous sequences of 512 tokens. 83 | - Since layernorm is used extensively throughout the model, a simple weight initialization of N(0, 0.02) was sufficient 84 | - bytepair encoding (BPE) vocabulary with 40,000 merges 85 | - residual, embedding, and attention dropouts with a rate of 0.1 for regularization. 86 | - modified version of L2 regularization proposed in (37), with w = 0.01 on all non bias or gain weights 87 | - For the activation function, we used the Gaussian Error Linear Unit (GELU). 88 | - We used learned position embeddings instead of the sinusoidal version proposed in the original work 89 | - For finetuning: We add dropout to the classifier with a rate of 0.1. learning rate of 6.25e-5 and a batchsize of 32. 3 epochs. We use a linear learning rate decay schedule with warmup over 0.2% of training. λ was set to 0.5. 90 | - GPT-1 model is 12 layers and d_model 768, ~117M params 91 | 92 | #### Language Models are Unsupervised Multitask Learners (GPT-2) 93 | 94 | - LayerNorm was moved to the input of each sub-block, similar to a pre-activation residual network 95 | - an additional layer normalization was added after the final self-attention block. 96 | - modified initialization which accounts for the accumulation on the residual path with model depth is used. We scale the weights of residual layers at initialization by a factor of 1/√N where N is the number of residual layers. (weird because in their released code i can only find a simple use of the old 0.02... in their release of image-gpt I found it used for c_proj, and even then only for attn, not for mlp. huh. 
https://github.com/openai/image-gpt/blob/master/src/model.py) 97 | - the vocabulary is expanded to 50,257 98 | - increase the context size from 512 to 1024 tokens 99 | - larger batchsize of 512 is used 100 | - GPT-2 used 48 layers and d_model 1600 (vs. original 12 layers and d_model 768). ~1.542B params 101 | 102 | #### Language Models are Few-Shot Learners (GPT-3) 103 | 104 | - GPT-3: 96 layers, 96 heads, with d_model of 12,288 (175B parameters). 105 | - GPT-1-like: 12 layers, 12 heads, d_model 768 (125M) 106 | - We use the same model and architecture as GPT-2, including the modified initialization, pre-normalization, and reversible tokenization described therein 107 | - we use alternating dense and locally banded sparse attention patterns in the layers of the transformer, similar to the Sparse Transformer 108 | - we always have the feedforward layer four times the size of the bottleneck layer, $d_{ff} = 4 \cdot d_{model}$ 109 | - all models use a context window of $n_{ctx} = 2048$ tokens. 110 | - Adam with $\beta_1 = 0.9$, $\beta_2 = 0.95$, and $\epsilon = 10^{-8}$ 111 | - All models use weight decay of 0.1 to provide a small amount of regularization. (NOTE: GPT-1 used 0.01 I believe, see above) 112 | - clip the global norm of the gradient at 1.0 113 | - Linear LR warmup over the first 375 million tokens. Then use cosine decay for learning rate down to 10% of its value, over 260 billion tokens. 114 | - gradually increase the batch size linearly from a small value (32k tokens) to the full value over the first 4-12 billion tokens of training, depending on the model size. 115 | - full 2048-sized time context window is always used, with a special END OF DOCUMENT token delimiter 116 | 117 | #### Generative Pretraining from Pixels (Image GPT) 118 | 119 | - When working with images, we pick the identity permutation $\pi_i = i$ for $1 \le i \le n$, also known as raster order. 120 | - we create our own 9-bit color palette by clustering (R, G, B) pixel values using k-means with k = 512.
121 | - Our largest model, iGPT-XL, contains L = 60 layers and uses an embedding size of d = 3072 for a total of 6.8B parameters. 122 | - Our next largest model, iGPT-L, is essentially identical to GPT-2 with L = 48 layers, but contains a slightly smaller embedding size of d = 1536 (vs 1600) for a total of 1.4B parameters. 123 | - We use the same model code as GPT-2, except that we initialize weights in the layer-dependent fashion as in Sparse Transformer (Child et al., 2019) and zero-initialize all projections producing logits. 124 | - We also train iGPT-M, a 455M parameter model with L = 36 and d = 1024 125 | - iGPT-S, a 76M parameter model with L = 24 and d = 512 (okay, and how many heads? looks like the Github code claims 8) 126 | - When pre-training iGPT-XL, we use a batch size of 64 and train for 2M iterations, and for all other models we use a batch size of 128 and train for 1M iterations. 127 | - Adam with β1 = 0.9 and β2 = 0.95 128 | - The learning rate is warmed up for one epoch, and then decays to 0 129 | - We did not use weight decay because applying a small weight decay of 0.01 did not change representation quality. 130 | - iGPT-S lr 0.003 131 | - No dropout is used. 132 | 133 | ### License 134 | 135 | MIT 136 | -------------------------------------------------------------------------------- /demo.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "A cute little demo showing the simplest usage of minGPT. Configured to run fine on Macbook Air in like a minute."
8 | ] 9 | }, 10 | { 11 | "cell_type": "code", 12 | "execution_count": 1, 13 | "metadata": {}, 14 | "outputs": [], 15 | "source": [ 16 | "import torch\n", 17 | "from torch.utils.data import Dataset\n", 18 | "from torch.utils.data.dataloader import DataLoader\n", 19 | "from mingpt.utils import set_seed\n", 20 | "set_seed(3407)" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": 2, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "import pickle\n", 30 | "\n", 31 | "class SortDataset(Dataset):\n", 32 | " \"\"\" \n", 33 | " Dataset for the Sort problem. E.g. for problem length 6:\n", 34 | " Input: 0 0 2 1 0 1 -> Output: 0 0 0 1 1 2\n", 35 | " Which will feed into the transformer concatenated as:\n", 36 | " input: 0 0 2 1 0 1 0 0 0 1 1\n", 37 | " output: I I I I I 0 0 0 1 1 2\n", 38 | " where I is \"ignore\", as the transformer is reading the input sequence\n", 39 | " \"\"\"\n", 40 | "\n", 41 | " def __init__(self, split, length=6, num_digits=3):\n", 42 | " assert split in {'train', 'test'}\n", 43 | " self.split = split\n", 44 | " self.length = length\n", 45 | " self.num_digits = num_digits\n", 46 | " \n", 47 | " def __len__(self):\n", 48 | " return 10000 # ...\n", 49 | " \n", 50 | " def get_vocab_size(self):\n", 51 | " return self.num_digits\n", 52 | " \n", 53 | " def get_block_size(self):\n", 54 | " # the length of the sequence that will feed into transformer, \n", 55 | " # containing concatenated input and the output, but -1 because\n", 56 | " # the transformer starts making predictions at the last input element\n", 57 | " return self.length * 2 - 1\n", 58 | "\n", 59 | " def __getitem__(self, idx):\n", 60 | " \n", 61 | " # use rejection sampling to generate an input example from the desired split\n", 62 | " while True:\n", 63 | " # generate some random integers\n", 64 | " inp = torch.randint(self.num_digits, size=(self.length,), dtype=torch.long)\n", 65 | " # half of the time let's try to boost the number of examples that \n", 66 | " 
# have a large number of repeats, as this is what the model seems to struggle\n", 67 | " # with later in training, and they are kind of rare\n", 68 | " if torch.rand(1).item() < 0.5:\n", 69 | " if inp.unique().nelement() > self.length // 2:\n", 70 | " # too many unique digits, re-sample\n", 71 | " continue\n", 72 | " # figure out if this generated example is train or test based on its hash\n", 73 | " h = hash(pickle.dumps(inp.tolist()))\n", 74 | " inp_split = 'test' if h % 4 == 0 else 'train' # designate 25% of examples as test\n", 75 | " if inp_split == self.split:\n", 76 | " break # ok\n", 77 | " \n", 78 | " # solve the task: i.e. sort\n", 79 | " sol = torch.sort(inp)[0]\n", 80 | "\n", 81 | " # concatenate the problem specification and the solution\n", 82 | " cat = torch.cat((inp, sol), dim=0)\n", 83 | "\n", 84 | " # the inputs to the transformer will be the offset sequence\n", 85 | " x = cat[:-1].clone()\n", 86 | " y = cat[1:].clone()\n", 87 | " # we only want to predict at output locations, mask out the loss at the input locations\n", 88 | " y[:self.length-1] = -1\n", 89 | " return x, y\n" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 3, 95 | "metadata": {}, 96 | "outputs": [ 97 | { 98 | "name": "stdout", 99 | "output_type": "stream", 100 | "text": [ 101 | "1 -1\n", 102 | "0 -1\n", 103 | "1 -1\n", 104 | "0 -1\n", 105 | "0 -1\n", 106 | "0 0\n", 107 | "0 0\n", 108 | "0 0\n", 109 | "0 0\n", 110 | "0 1\n", 111 | "1 1\n" 112 | ] 113 | } 114 | ], 115 | "source": [ 116 | "# print an example instance of the dataset\n", 117 | "train_dataset = SortDataset('train')\n", 118 | "test_dataset = SortDataset('test')\n", 119 | "x, y = train_dataset[0]\n", 120 | "for a, b in zip(x,y):\n", 121 | " print(int(a),int(b))" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 4, 127 | "metadata": {}, 128 | "outputs": [ 129 | { 130 | "name": "stdout", 131 | "output_type": "stream", 132 | "text": [ 133 | "number of parameters: 0.09M\n"
134 | ] 135 | } 136 | ], 137 | "source": [ 138 | "# create a GPT instance\n", 139 | "from mingpt.model import GPT\n", 140 | "\n", 141 | "model_config = GPT.get_default_config()\n", 142 | "model_config.model_type = 'gpt-nano'\n", 143 | "model_config.vocab_size = train_dataset.get_vocab_size()\n", 144 | "model_config.block_size = train_dataset.get_block_size()\n", 145 | "model = GPT(model_config)" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 5, 151 | "metadata": {}, 152 | "outputs": [ 153 | { 154 | "name": "stdout", 155 | "output_type": "stream", 156 | "text": [ 157 | "running on device cuda\n" 158 | ] 159 | } 160 | ], 161 | "source": [ 162 | "# create a Trainer object\n", 163 | "from mingpt.trainer import Trainer\n", 164 | "\n", 165 | "train_config = Trainer.get_default_config()\n", 166 | "train_config.learning_rate = 5e-4 # the model we're using is so small that we can go a bit faster\n", 167 | "train_config.max_iters = 2000\n", 168 | "train_config.num_workers = 0\n", 169 | "trainer = Trainer(train_config, model, train_dataset)" 170 | ] 171 | }, 172 | { 173 | "cell_type": "code", 174 | "execution_count": 6, 175 | "metadata": {}, 176 | "outputs": [ 177 | { 178 | "name": "stdout", 179 | "output_type": "stream", 180 | "text": [ 181 | "iter_dt 0.00ms; iter 0: train loss 1.06407\n", 182 | "iter_dt 18.17ms; iter 100: train loss 0.14712\n", 183 | "iter_dt 18.70ms; iter 200: train loss 0.05315\n", 184 | "iter_dt 19.65ms; iter 300: train loss 0.04404\n", 185 | "iter_dt 31.64ms; iter 400: train loss 0.04724\n", 186 | "iter_dt 18.43ms; iter 500: train loss 0.02521\n", 187 | "iter_dt 19.83ms; iter 600: train loss 0.03352\n", 188 | "iter_dt 19.58ms; iter 700: train loss 0.00539\n", 189 | "iter_dt 18.72ms; iter 800: train loss 0.02057\n", 190 | "iter_dt 18.26ms; iter 900: train loss 0.00360\n", 191 | "iter_dt 18.50ms; iter 1000: train loss 0.00788\n", 192 | "iter_dt 20.64ms; iter 1100: train loss 0.01162\n", 193 | "iter_dt 18.63ms; iter 1200: 
train loss 0.00963\n", 194 | "iter_dt 18.32ms; iter 1300: train loss 0.02066\n", 195 | "iter_dt 18.40ms; iter 1400: train loss 0.01739\n", 196 | "iter_dt 18.37ms; iter 1500: train loss 0.00376\n", 197 | "iter_dt 18.67ms; iter 1600: train loss 0.00133\n", 198 | "iter_dt 18.38ms; iter 1700: train loss 0.00179\n", 199 | "iter_dt 18.66ms; iter 1800: train loss 0.00079\n", 200 | "iter_dt 18.48ms; iter 1900: train loss 0.00042\n" 201 | ] 202 | } 203 | ], 204 | "source": [ 205 | "def batch_end_callback(trainer):\n", 206 | " if trainer.iter_num % 100 == 0:\n", 207 | " print(f\"iter_dt {trainer.iter_dt * 1000:.2f}ms; iter {trainer.iter_num}: train loss {trainer.loss.item():.5f}\")\n", 208 | "trainer.set_callback('on_batch_end', batch_end_callback)\n", 209 | "\n", 210 | "trainer.run()" 211 | ] 212 | }, 213 | { 214 | "cell_type": "code", 215 | "execution_count": 7, 216 | "metadata": {}, 217 | "outputs": [], 218 | "source": [ 219 | "# now let's perform some evaluation\n", 220 | "model.eval();" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 8, 226 | "metadata": {}, 227 | "outputs": [ 228 | { 229 | "name": "stdout", 230 | "output_type": "stream", 231 | "text": [ 232 | "train final score: 5000/5000 = 100.00% correct\n", 233 | "test final score: 5000/5000 = 100.00% correct\n" 234 | ] 235 | } 236 | ], 237 | "source": [ 238 | "def eval_split(trainer, split, max_batches):\n", 239 | " dataset = {'train':train_dataset, 'test':test_dataset}[split]\n", 240 | " n = train_dataset.length # naugy direct access shrug\n", 241 | " results = []\n", 242 | " mistakes_printed_already = 0\n", 243 | " loader = DataLoader(dataset, batch_size=100, num_workers=0, drop_last=False)\n", 244 | " for b, (x, y) in enumerate(loader):\n", 245 | " x = x.to(trainer.device)\n", 246 | " y = y.to(trainer.device)\n", 247 | " # isolate the input pattern alone\n", 248 | " inp = x[:, :n]\n", 249 | " sol = y[:, -n:]\n", 250 | " # let the model sample the rest of the sequence\n", 251 | " cat 
= model.generate(inp, n, do_sample=False) # using greedy argmax, not sampling\n", 252 | " sol_candidate = cat[:, n:] # isolate the filled in sequence\n", 253 | " # compare the predicted sequence to the true sequence\n", 254 | " correct = (sol == sol_candidate).all(1).cpu() # Software 1.0 vs. Software 2.0 fight RIGHT on this line haha\n", 255 | " for i in range(x.size(0)):\n", 256 | " results.append(int(correct[i]))\n", 257 | " if not correct[i] and mistakes_printed_already < 3: # only print up to 3 mistakes to get a sense\n", 258 | " mistakes_printed_already += 1\n", 259 | " print(\"GPT claims that %s sorted is %s but gt is %s\" % (inp[i].tolist(), sol_candidate[i].tolist(), sol[i].tolist()))\n", 260 | " if max_batches is not None and b+1 >= max_batches:\n", 261 | " break\n", 262 | " rt = torch.tensor(results, dtype=torch.float)\n", 263 | " print(\"%s final score: %d/%d = %.2f%% correct\" % (split, rt.sum(), len(results), 100*rt.mean()))\n", 264 | " return rt.sum()\n", 265 | "\n", 266 | "# run a lot of examples from both train and test through the model and verify the output correctness\n", 267 | "with torch.no_grad():\n", 268 | " train_score = eval_split(trainer, 'train', max_batches=50)\n", 269 | " test_score = eval_split(trainer, 'test', max_batches=50)" 270 | ] 271 | }, 272 | { 273 | "cell_type": "code", 274 | "execution_count": 9, 275 | "metadata": {}, 276 | "outputs": [ 277 | { 278 | "name": "stdout", 279 | "output_type": "stream", 280 | "text": [ 281 | "input sequence : [[0, 0, 2, 1, 0, 1]]\n", 282 | "predicted sorted: [[0, 0, 0, 1, 1, 2]]\n", 283 | "gt sort : [0, 0, 0, 1, 1, 2]\n", 284 | "matches : True\n" 285 | ] 286 | } 287 | ], 288 | "source": [ 289 | "# let's run a random given sequence through the model as well\n", 290 | "n = train_dataset.length # naughty direct access shrug\n", 291 | "inp = torch.tensor([[0, 0, 2, 1, 0, 1]], dtype=torch.long).to(trainer.device)\n", 292 | "assert inp[0].nelement() == n\n", 293 | "with torch.no_grad():\n", 294 | " cat
model.generate(inp, n, do_sample=False)\n", 295 | "sol = torch.sort(inp[0])[0]\n", 296 | "sol_candidate = cat[:, n:]\n", 297 | "print('input sequence :', inp.tolist())\n", 298 | "print('predicted sorted:', sol_candidate.tolist())\n", 299 | "print('gt sort :', sol.tolist())\n", 300 | "print('matches :', bool((sol == sol_candidate).all()))" 301 | ] 302 | } 303 | ], 304 | "metadata": { 305 | "kernelspec": { 306 | "display_name": "Python 3.10.4 64-bit", 307 | "language": "python", 308 | "name": "python3" 309 | }, 310 | "language_info": { 311 | "codemirror_mode": { 312 | "name": "ipython", 313 | "version": 3 314 | }, 315 | "file_extension": ".py", 316 | "mimetype": "text/x-python", 317 | "name": "python", 318 | "nbconvert_exporter": "python", 319 | "pygments_lexer": "ipython3", 320 | "version": "3.10.4" 321 | }, 322 | "orig_nbformat": 4, 323 | "vscode": { 324 | "interpreter": { 325 | "hash": "3ad933181bd8a04b432d3370b9dc3b0662ad032c4dfaa4e4f1596c548f763858" 326 | } 327 | } 328 | }, 329 | "nbformat": 4, 330 | "nbformat_minor": 2 331 | } 332 | -------------------------------------------------------------------------------- /mingpt/bpe.py: -------------------------------------------------------------------------------- 1 | """ 2 | bpe is short for Byte Pair Encoder. It translates arbitrary utf-8 strings into 3 | sequences of integers, where each integer represents small chunks of commonly 4 | occuring characters. This implementation is based on openai's gpt2 encoder.py: 5 | https://github.com/openai/gpt-2/blob/master/src/encoder.py 6 | but was midly modified because the original implementation is a bit confusing. 7 | I also tried to add as many comments as possible, my own understanding of what's 8 | going on. 
9 | """ 10 | 11 | import os 12 | import json 13 | import regex as re 14 | import requests 15 | 16 | import torch 17 | 18 | # ----------------------------------------------------------------------------- 19 | 20 | def bytes_to_unicode(): 21 | """ 22 | Everu possible byte (really an integer 0..255) gets mapped by OpenAI to a unicode 23 | character that represents it visually. Some bytes have their appearance preserved 24 | because they don't cause any trouble. These are defined in list bs. For example: 25 | chr(33) returns "!", so in the returned dictionary we simply have d[33] -> "!". 26 | However, chr(0), for example, is '\x00', which looks ugly. So OpenAI maps these 27 | bytes, into new characters in a range where chr() returns a single nice character. 28 | So in the final dictionary we have d[0] -> 'Ā' instead, which is just chr(0 + 2**8). 29 | In particular, the space character is 32, which we can see by ord(' '). Instead, 30 | this function will shift space (32) by 256 to 288, so d[32] -> 'Ġ'. 31 | So this is just a simple one-to-one mapping of bytes 0..255 into unicode characters 32 | that "look nice", either in their original form, or a funny shifted character 33 | like 'Ā', or 'Ġ', etc. 
34 | """ 35 | # the 188 integers that render fine in their original form and need no shifting 36 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 37 | cs = bs[:] # all integers b in bs will simply map to chr(b) in the output dict 38 | # now get the representations of the other 68 integers that do need shifting 39 | # each will get mapped chr(256 + n), where n will grow from 0...67 in the loop 40 | n = 0 41 | for b in range(2**8): 42 | if b not in bs: 43 | # if this byte is "ugly" then map it to the next available "nice" character 44 | bs.append(b) 45 | cs.append(2**8+n) 46 | n += 1 47 | cs = [chr(n) for n in cs] 48 | d = dict(zip(bs, cs)) 49 | return d 50 | 51 | def get_pairs(word): 52 | """ 53 | Return all bigrams as a set of tuples, of consecutive elements in the iterable word. 54 | """ 55 | pairs = set() 56 | prev_char = word[0] 57 | for char in word[1:]: 58 | pairs.add((prev_char, char)) 59 | prev_char = char 60 | return pairs 61 | 62 | class Encoder: 63 | 64 | def __init__(self, encoder, bpe_merges): 65 | # byte encoder/decoder 66 | self.byte_encoder = bytes_to_unicode() 67 | self.byte_decoder = {v:k for k, v in self.byte_encoder.items()} 68 | # bpe token encoder/decoder 69 | self.encoder = encoder 70 | self.decoder = {v:k for k,v in self.encoder.items()} 71 | # bpe merge list that defines the bpe "tree", of tuples (a,b) that are to merge to token ab 72 | self.bpe_ranks = dict(zip(bpe_merges, range(len(bpe_merges)))) 73 | # the splitting pattern used for pre-tokenization 74 | # Should haved added re.IGNORECASE so BPE merges can happen for capitalized versions of contractions <-- original openai comment 75 | """ 76 | ok so what is this regex looking for, exactly? 
77 | python re reference: https://docs.python.org/3/library/re.html 78 | - the vertical bars | is OR, so re.findall will chunkate text as the pieces match, from left to right 79 | - '\'s' would split up things like Andrej's -> (Andrej, 's) 80 | - ' ?\p{L}': optional space followed by 1+ unicode code points in the category "letter" 81 | - ' ?\p{N}': optional space followed by 1+ unicode code points in the category "number" 82 | - ' ?[^\s\p{L}\p{N}]+': optional space, then 1+ things that are NOT a whitespace, letter or number 83 | - '\s+(?!\S)': 1+ whitespace characters (e.g. space or tab or etc) UNLESS they are followed by non-whitespace 84 | so this will consume whitespace characters in a sequence but exclude the last whitespace in 85 | that sequence. that last whitespace has the opportunity to then match the optional ' ?' in 86 | earlier patterns. 87 | - '\s+': 1+ whitespace characters, intended probably to catch a full trailing sequence of whitespaces at end of string 88 | So TLDR: 89 | - we are special casing a few common apostrophe constructs ('s, 't, 're, ...) and making those into separate tokens 90 | - we then separate out strings into consecutive chunks of 1) letters, 2) numbers, 3) non-letter-numbers, 4) whitespaces 91 | """ 92 | self.pat = re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""") 93 | self.cache = {} 94 | 95 | def bpe(self, token): 96 | """ 97 | this function uses self.bpe_ranks to iterative merge all the possible bpe tokens 98 | up the tree. token is a string of one individual 'word' (after regex tokenization) 99 | and after byte encoding, e.g. 'Ġthere'. 100 | """ 101 | # token is a string of one individual 'word', after byte encoding, e.g. 
'Ġthere' 102 | 103 | # memoization, for efficiency 104 | if token in self.cache: 105 | return self.cache[token] 106 | 107 | word = tuple(token) # individual characters that make up the token, in a tuple 108 | pairs = get_pairs(word) # get all bigrams 109 | 110 | if not pairs: 111 | return token 112 | 113 | while True: 114 | 115 | # find the next lowest rank bigram that can be merged 116 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 117 | if bigram not in self.bpe_ranks: 118 | break # no more bigrams are eligible to be merged 119 | first, second = bigram 120 | 121 | # we will now replace all occurences of (first, second) in the list of current 122 | # words into one merged token first_second, in the output list new_words 123 | new_word = [] 124 | i = 0 125 | while i < len(word): 126 | 127 | # find the next occurence of first in the sequence of current words 128 | try: 129 | j = word.index(first, i) 130 | new_word.extend(word[i:j]) 131 | i = j 132 | except: 133 | new_word.extend(word[i:]) 134 | break 135 | 136 | # if this occurence is also followed by second, then merge them into one 137 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 138 | new_word.append(first+second) 139 | i += 2 140 | else: 141 | new_word.append(word[i]) 142 | i += 1 143 | 144 | # all occurences of (first, second) have been merged to first_second 145 | new_word = tuple(new_word) 146 | word = new_word 147 | if len(word) == 1: 148 | break 149 | else: 150 | pairs = get_pairs(word) 151 | 152 | # concat all words into a string, and use ' ' as the separator. 
Note that 153 | # by now all characters have been byte encoded, guaranteeing that ' ' is 154 | # not used in the actual data and is a 'special' delimiter character 155 | word = ' '.join(word) 156 | 157 | # cache the result and return 158 | self.cache[token] = word 159 | return word 160 | 161 | def encode(self, text): 162 | """ string goes in, list of integers comes out """ 163 | bpe_idx = [] 164 | # pre-tokenize the input text into string tokens (words, roughly speaking) 165 | tokens = re.findall(self.pat, text) 166 | # process each token into BPE integers 167 | for token in tokens: 168 | # encode the token as a bytes (b'') object 169 | token_bytes = token.encode('utf-8') 170 | # translate all bytes to their unicode string representation and flatten 171 | token_translated = ''.join(self.byte_encoder[b] for b in token_bytes) 172 | # perform all the applicable bpe merges according to self.bpe_ranks 173 | token_merged = self.bpe(token_translated).split(' ') 174 | # translate all bpe tokens to integers 175 | token_ix = [self.encoder[bpe_token] for bpe_token in token_merged] 176 | # extend our running list of all output integers 177 | bpe_idx.extend(token_ix) 178 | return bpe_idx 179 | 180 | def encode_and_show_work(self, text): 181 | """ debugging function, same as encode but returns all intermediate work """ 182 | bpe_idx = [] 183 | parts = [] 184 | tokens = re.findall(self.pat, text) 185 | for token in tokens: 186 | token_bytes = token.encode('utf-8') 187 | token_translated = ''.join(self.byte_encoder[b] for b in token_bytes) 188 | token_merged = self.bpe(token_translated).split(' ') 189 | token_ix = [self.encoder[bpe_token] for bpe_token in token_merged] 190 | bpe_idx.extend(token_ix) 191 | parts.append({ 192 | 'token': token, 193 | 'token_bytes': token_bytes, 194 | 'token_translated': token_translated, 195 | 'token_merged': token_merged, 196 | 'token_ix': token_ix, 197 | }) 198 | out = { 199 | 'bpe_idx': bpe_idx, # the actual output sequence 200 | 'tokens': tokens, 
# result of pre-tokenization 201 | 'parts': parts, # intermediates for each token part 202 | } 203 | return out 204 | 205 | def decode(self, bpe_idx): 206 | """ list of integers comes in, string comes out """ 207 | # inverse map the integers to get the tokens 208 | tokens_merged = [self.decoder[token] for token in bpe_idx] 209 | # inverse the byte encoder, e.g. recovering 'Ġ' -> ' ', and get the bytes 210 | tokens_flat = ''.join(tokens_merged) 211 | tokens_bytes = bytearray([self.byte_decoder[c] for c in tokens_flat]) 212 | # recover the full utf-8 string 213 | text = tokens_bytes.decode('utf-8', errors='replace') 214 | return text 215 | 216 | def get_file(local_file, remote_file): 217 | """ downloads remote_file to local_file if necessary """ 218 | if not os.path.isfile(local_file): 219 | print(f"downloading {remote_file} to {local_file}") 220 | response = requests.get(remote_file) 221 | open(local_file, "wb").write(response.content) 222 | 223 | def get_encoder(): 224 | """ 225 | Returns an instance of the GPT BPE Encoder/Decoder 226 | and handles caching of "database" files. 227 | """ 228 | home_dir = os.path.expanduser('~') 229 | cache_dir = os.path.join(home_dir, '.cache', 'mingpt') 230 | os.makedirs(cache_dir, exist_ok=True) 231 | 232 | # load encoder.json that has the raw mappings from token -> bpe index 233 | encoder_local_file = os.path.join(cache_dir, 'encoder.json') 234 | encoder_remote_file = 'https://openaipublic.blob.core.windows.net/gpt-2/models/124M/encoder.json' 235 | get_file(encoder_local_file, encoder_remote_file) 236 | with open(encoder_local_file, 'r') as f: 237 | encoder = json.load(f) 238 | assert len(encoder) == 50257 # 256 individual byte tokens, 50,000 merged tokens, and 1 special <|endoftext|> token 239 | 240 | # load vocab.bpe that contains the bpe merges, i.e. 
the bpe tree structure 241 | # in the form tuples (a, b), that indicate that (a, b) is to be merged to one token ab 242 | vocab_local_file = os.path.join(cache_dir, 'vocab.bpe') 243 | vocab_remote_file = 'https://openaipublic.blob.core.windows.net/gpt-2/models/124M/vocab.bpe' 244 | get_file(vocab_local_file, vocab_remote_file) 245 | with open(vocab_local_file, 'r', encoding="utf-8") as f: 246 | bpe_data = f.read() 247 | # light postprocessing: strip the version on first line and the last line is a blank 248 | bpe_merges = [tuple(merge_str.split()) for merge_str in bpe_data.split('\n')[1:-1]] 249 | assert len(bpe_merges) == 50000 # 50,000 merged tokens 250 | 251 | # construct the Encoder object and return 252 | enc = Encoder(encoder, bpe_merges) 253 | return enc 254 | 255 | # ----------------------------------------------------------------------------- 256 | 257 | class BPETokenizer: 258 | """ PyTorch-aware class that wraps the Encoder above """ 259 | 260 | def __init__(self): 261 | self.encoder = get_encoder() 262 | 263 | def __call__(self, text, return_tensors='pt'): 264 | # PyTorch only; here because we want to match huggingface/transformers interface 265 | assert return_tensors == 'pt' 266 | # single string input for now, in the future potentially a list of strings 267 | assert isinstance(text, str) 268 | # encode and create a "batch dimension" of 1 269 | idx = [self.encoder.encode(text)] 270 | # wrap into PyTorch tensor 271 | out = torch.tensor(idx, dtype=torch.long) 272 | return out 273 | 274 | def decode(self, idx): 275 | # ensure a simple 1D tensor for now 276 | assert idx.ndim == 1 277 | # decode indices to text 278 | text = self.encoder.decode(idx.tolist()) 279 | return text 280 | -------------------------------------------------------------------------------- /mingpt/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Full definition of a GPT Language Model, all of it in this single file. 
3 | 4 | References: 5 | 1) the official GPT-2 TensorFlow implementation released by OpenAI: 6 | https://github.com/openai/gpt-2/blob/master/src/model.py 7 | 2) huggingface/transformers PyTorch implementation: 8 | https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py 9 | """ 10 | 11 | import math 12 | 13 | import torch 14 | import torch.nn as nn 15 | from torch.nn import functional as F 16 | 17 | from mingpt.utils import CfgNode as CN 18 | 19 | # ----------------------------------------------------------------------------- 20 | 21 | class NewGELU(nn.Module): 22 | """ 23 | Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). 24 | Reference: Gaussian Error Linear Units (GELU) paper: https://arxiv.org/abs/1606.08415 25 | """ 26 | def forward(self, x): 27 | return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0)))) 28 | 29 | class CausalSelfAttention(nn.Module): 30 | """ 31 | A vanilla multi-head masked self-attention layer with a projection at the end. 32 | It is possible to use torch.nn.MultiheadAttention here but I am including an 33 | explicit implementation here to show that there is nothing too scary here. 
34 | """ 35 | 36 | def __init__(self, config): 37 | super().__init__() 38 | assert config.n_embd % config.n_head == 0 39 | # key, query, value projections for all heads, but in a batch 40 | self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd) 41 | # output projection 42 | self.c_proj = nn.Linear(config.n_embd, config.n_embd) 43 | # regularization 44 | self.attn_dropout = nn.Dropout(config.attn_pdrop) 45 | self.resid_dropout = nn.Dropout(config.resid_pdrop) 46 | # causal mask to ensure that attention is only applied to the left in the input sequence 47 | self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)) 48 | .view(1, 1, config.block_size, config.block_size)) 49 | self.n_head = config.n_head 50 | self.n_embd = config.n_embd 51 | 52 | def forward(self, x): 53 | B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd) 54 | 55 | # calculate query, key, values for all heads in batch and move head forward to be the batch dim 56 | q, k ,v = self.c_attn(x).split(self.n_embd, dim=2) 57 | k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 58 | q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 59 | v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs) 60 | 61 | # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T) 62 | att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) 63 | att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf')) 64 | att = F.softmax(att, dim=-1) 65 | att = self.attn_dropout(att) 66 | y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs) 67 | y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side 68 | 69 | # output projection 70 | y = self.resid_dropout(self.c_proj(y)) 71 | return y 72 | 73 | class Block(nn.Module): 74 | """ an unassuming Transformer block """ 75 | 76 | def __init__(self, config): 77 | 
super().__init__() 78 | self.ln_1 = nn.LayerNorm(config.n_embd) 79 | self.attn = CausalSelfAttention(config) 80 | self.ln_2 = nn.LayerNorm(config.n_embd) 81 | self.mlp = nn.ModuleDict(dict( 82 | c_fc = nn.Linear(config.n_embd, 4 * config.n_embd), 83 | c_proj = nn.Linear(4 * config.n_embd, config.n_embd), 84 | act = NewGELU(), 85 | dropout = nn.Dropout(config.resid_pdrop), 86 | )) 87 | m = self.mlp 88 | self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x)))) # MLP forward 89 | 90 | def forward(self, x): 91 | x = x + self.attn(self.ln_1(x)) 92 | x = x + self.mlpf(self.ln_2(x)) 93 | return x 94 | 95 | class GPT(nn.Module): 96 | """ GPT Language Model """ 97 | 98 | @staticmethod 99 | def get_default_config(): 100 | C = CN() 101 | # either model_type or (n_layer, n_head, n_embd) must be given in the config 102 | C.model_type = 'gpt' 103 | C.n_layer = None 104 | C.n_head = None 105 | C.n_embd = None 106 | # these options must be filled in externally 107 | C.vocab_size = None 108 | C.block_size = None 109 | # dropout hyperparameters 110 | C.embd_pdrop = 0.1 111 | C.resid_pdrop = 0.1 112 | C.attn_pdrop = 0.1 113 | return C 114 | 115 | def __init__(self, config): 116 | super().__init__() 117 | assert config.vocab_size is not None 118 | assert config.block_size is not None 119 | self.block_size = config.block_size 120 | 121 | type_given = config.model_type is not None 122 | params_given = all([config.n_layer is not None, config.n_head is not None, config.n_embd is not None]) 123 | assert (type_given and not params_given) or (not type_given and params_given) # exactly one of these 124 | if type_given: 125 | # translate from model_type to detailed configuration 126 | config.merge_from_dict({ 127 | # names follow the huggingface naming conventions 128 | # GPT-1 129 | 'openai-gpt': dict(n_layer=12, n_head=12, n_embd=768), # 117M params 130 | # GPT-2 configs 131 | 'gpt2': dict(n_layer=12, n_head=12, n_embd=768), # 124M params 132 | 'gpt2-medium': dict(n_layer=24, n_head=16, 
n_embd=1024), # 350M params 133 | 'gpt2-large': dict(n_layer=36, n_head=20, n_embd=1280), # 774M params 134 | 'gpt2-xl': dict(n_layer=48, n_head=25, n_embd=1600), # 1558M params 135 | # Gophers 136 | 'gopher-44m': dict(n_layer=8, n_head=16, n_embd=512), 137 | # (there are a number more...) 138 | # I made these tiny models up 139 | 'gpt-mini': dict(n_layer=6, n_head=6, n_embd=192), 140 | 'gpt-micro': dict(n_layer=4, n_head=4, n_embd=128), 141 | 'gpt-nano': dict(n_layer=3, n_head=3, n_embd=48), 142 | }[config.model_type]) 143 | 144 | self.transformer = nn.ModuleDict(dict( 145 | wte = nn.Embedding(config.vocab_size, config.n_embd), 146 | wpe = nn.Embedding(config.block_size, config.n_embd), 147 | drop = nn.Dropout(config.embd_pdrop), 148 | h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]), 149 | ln_f = nn.LayerNorm(config.n_embd), 150 | )) 151 | self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) 152 | 153 | # init all weights, and apply a special scaled init to the residual projections, per GPT-2 paper 154 | self.apply(self._init_weights) 155 | for pn, p in self.named_parameters(): 156 | if pn.endswith('c_proj.weight'): 157 | torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer)) 158 | 159 | # report number of parameters (note we don't count the decoder parameters in lm_head) 160 | n_params = sum(p.numel() for p in self.transformer.parameters()) 161 | print("number of parameters: %.2fM" % (n_params/1e6,)) 162 | 163 | def _init_weights(self, module): 164 | if isinstance(module, nn.Linear): 165 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 166 | if module.bias is not None: 167 | torch.nn.init.zeros_(module.bias) 168 | elif isinstance(module, nn.Embedding): 169 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 170 | elif isinstance(module, nn.LayerNorm): 171 | torch.nn.init.zeros_(module.bias) 172 | torch.nn.init.ones_(module.weight) 173 | 174 | @classmethod 175 | def from_pretrained(cls, 
model_type): 176 | """ 177 | Initialize a pretrained GPT model by copying over the weights 178 | from a huggingface/transformers checkpoint. 179 | """ 180 | assert model_type in {'gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl'} 181 | from transformers import GPT2LMHeadModel 182 | 183 | # create a from-scratch initialized minGPT model 184 | config = cls.get_default_config() 185 | config.model_type = model_type 186 | config.vocab_size = 50257 # openai's model vocabulary 187 | config.block_size = 1024 # openai's model block_size 188 | model = GPT(config) 189 | sd = model.state_dict() 190 | 191 | # init a huggingface/transformers model 192 | model_hf = GPT2LMHeadModel.from_pretrained(model_type) 193 | sd_hf = model_hf.state_dict() 194 | 195 | # copy while ensuring all of the parameters are aligned and match in names and shapes 196 | keys = [k for k in sd_hf if not k.endswith('attn.masked_bias')] # ignore these 197 | transposed = ['attn.c_attn.weight', 'attn.c_proj.weight', 'mlp.c_fc.weight', 'mlp.c_proj.weight'] 198 | # basically the openai checkpoints use a "Conv1D" module, but we only want to use a vanilla nn.Linear. 
199 | # this means that we have to transpose these weights when we import them 200 | assert len(keys) == len(sd) 201 | for k in keys: 202 | if any(k.endswith(w) for w in transposed): 203 | # special treatment for the Conv1D weights we need to transpose 204 | assert sd_hf[k].shape[::-1] == sd[k].shape 205 | with torch.no_grad(): 206 | sd[k].copy_(sd_hf[k].t()) 207 | else: 208 | # vanilla copy over the other parameters 209 | assert sd_hf[k].shape == sd[k].shape 210 | with torch.no_grad(): 211 | sd[k].copy_(sd_hf[k]) 212 | 213 | return model 214 | 215 | def configure_optimizers(self, train_config): 216 | """ 217 | This long function is unfortunately doing something very simple and is being very defensive: 218 | We are separating out all parameters of the model into two buckets: those that will experience 219 | weight decay for regularization and those that won't (biases, and layernorm/embedding weights). 220 | We are then returning the PyTorch optimizer object. 221 | """ 222 | 223 | # separate out all parameters to those that will and won't experience regularizing weight decay 224 | decay = set() 225 | no_decay = set() 226 | whitelist_weight_modules = (torch.nn.Linear, ) 227 | blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding) 228 | for mn, m in self.named_modules(): 229 | for pn, p in m.named_parameters(): 230 | fpn = '%s.%s' % (mn, pn) if mn else pn # full param name 231 | # random note: because named_modules and named_parameters are recursive 232 | # we will see the same tensors p many many times. but doing it this way 233 | # allows us to know which parent module any tensor p belongs to... 
234 | if pn.endswith('bias'): 235 | # all biases will not be decayed 236 | no_decay.add(fpn) 237 | elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules): 238 | # weights of whitelist modules will be weight decayed 239 | decay.add(fpn) 240 | elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules): 241 | # weights of blacklist modules will NOT be weight decayed 242 | no_decay.add(fpn) 243 | 244 | # validate that we considered every parameter 245 | param_dict = {pn: p for pn, p in self.named_parameters()} 246 | inter_params = decay & no_decay 247 | union_params = decay | no_decay 248 | assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), ) 249 | assert len(param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \ 250 | % (str(param_dict.keys() - union_params), ) 251 | 252 | # create the pytorch optimizer object 253 | optim_groups = [ 254 | {"params": [param_dict[pn] for pn in sorted(list(decay))], "weight_decay": train_config.weight_decay}, 255 | {"params": [param_dict[pn] for pn in sorted(list(no_decay))], "weight_decay": 0.0}, 256 | ] 257 | optimizer = torch.optim.AdamW(optim_groups, lr=train_config.learning_rate, betas=train_config.betas) 258 | return optimizer 259 | 260 | def forward(self, idx, targets=None): 261 | device = idx.device 262 | b, t = idx.size() 263 | assert t <= self.block_size, f"Cannot forward sequence of length {t}, block size is only {self.block_size}" 264 | pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t) 265 | 266 | # forward the GPT model itself 267 | tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd) 268 | pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd) 269 | x = self.transformer.drop(tok_emb + pos_emb) 270 | for block in self.transformer.h: 271 | x = block(x) 272 | x = self.transformer.ln_f(x) 273 | logits = 
self.lm_head(x) 274 | 275 | # if we are given some desired targets also calculate the loss 276 | loss = None 277 | if targets is not None: 278 | loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1) 279 | 280 | return logits, loss 281 | 282 | @torch.no_grad() 283 | def generate(self, idx, max_new_tokens, temperature=1.0, do_sample=False, top_k=None): 284 | """ 285 | Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete 286 | the sequence max_new_tokens times, feeding the predictions back into the model each time. 287 | Most likely you'll want to make sure to be in model.eval() mode of operation for this. 288 | """ 289 | for _ in range(max_new_tokens): 290 | # if the sequence context is growing too long we must crop it at block_size 291 | idx_cond = idx if idx.size(1) <= self.block_size else idx[:, -self.block_size:] 292 | # forward the model to get the logits for the index in the sequence 293 | logits, _ = self(idx_cond) 294 | # pluck the logits at the final step and scale by desired temperature 295 | logits = logits[:, -1, :] / temperature 296 | # optionally crop the logits to only the top k options 297 | if top_k is not None: 298 | v, _ = torch.topk(logits, top_k) 299 | logits[logits < v[:, [-1]]] = -float('Inf') 300 | # apply softmax to convert logits to (normalized) probabilities 301 | probs = F.softmax(logits, dim=-1) 302 | # either sample from the distribution or take the most likely element 303 | if do_sample: 304 | idx_next = torch.multinomial(probs, num_samples=1) 305 | else: 306 | _, idx_next = torch.topk(probs, k=1, dim=-1) 307 | # append sampled index to the running sequence and continue 308 | idx = torch.cat((idx, idx_next), dim=1) 309 | 310 | return idx 311 | --------------------------------------------------------------------------------