├── README.md ├── bigram.ipynb ├── bigram.py ├── gpt.ipynb ├── gpt.py ├── gpt1.ipynb ├── input.txt └── more.txt /README.md: -------------------------------------------------------------------------------- 1 | # GPT-Model 2 | This Is A Large Language Model Like GPT . 3 | -------------------------------------------------------------------------------- /bigram.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch.nn as nn\n", 11 | "from torch.nn import functional as F" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "# hyperparameters\n", 21 | "batch_size = 32 # how many independent sequences will we process in parallel?\n", 22 | "block_size = 8 # what is the maximum context length for predictions?\n", 23 | "max_iters = 3000\n", 24 | "eval_interval = 300\n", 25 | "learning_rate = 1e-2\n", 26 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 27 | "eval_iters = 200\n", 28 | "# ------------" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 3, 34 | "metadata": {}, 35 | "outputs": [ 36 | { 37 | "data": { 38 | "text/plain": [ 39 | "" 40 | ] 41 | }, 42 | "execution_count": 3, 43 | "metadata": {}, 44 | "output_type": "execute_result" 45 | } 46 | ], 47 | "source": [ 48 | "torch.manual_seed(1337)" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 4, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "with open('C:/Users/rocks/OneDrive/Desktop/Projects/My-GPT/input.txt', 'r', encoding='utf-8') as f:\n", 58 | " text = f.read()" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 5, 64 | "metadata": {}, 65 | "outputs": [], 66 | "source": [ 67 | "# here are all the unique characters that occur in this text\n", 68 | "chars = sorted(list(set(text)))\n", 69 | "vocab_size = len(chars)\n", 70 | "# create a mapping from characters to integers\n", 71 | "stoi = { ch:i for i,ch in enumerate(chars) }\n", 72 | "itos = { i:ch for i,ch in enumerate(chars) }\n", 73 | "encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers\n", 74 | "decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": 6, 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "# Train and test splits\n", 84 | "data = torch.tensor(encode(text), dtype=torch.long)\n", 85 | "n = int(0.9*len(data)) # first 90% will be train, rest val\n", 86 | "train_data = data[:n]\n", 87 | "val_data = data[n:]" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 7, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "# data loading\n", 97 | "def get_batch(split):\n", 98 | " # generate a small batch of data of inputs x and targets y\n", 99 | " data = train_data if split == 'train' else val_data\n", 100 | " ix = torch.randint(len(data) - block_size, (batch_size,))\n", 101 | " x = torch.stack([data[i:i+block_size] for i in ix])\n", 102 | " y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n", 103 | " x, y = x.to(device), y.to(device)\n", 104 | " return x, y" 105 | ] 106 | }, 107 | { 108 | "cell_type": "code", 109 | "execution_count": 8, 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 
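A quick sanity check (not a cell from the original notebook) that can be run after the data-loading cells above; it assumes the same input.txt and only confirms the vocabulary and the batch shapes that get_batch produces:

print(vocab_size)                    # 65 distinct characters for this file (see gpt1.ipynb below)
print(decode(encode("hii there")))   # round-trips to 'hii there'
xb, yb = get_batch('train')
print(xb.shape, yb.shape)            # torch.Size([32, 8]) torch.Size([32, 8])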
113 | "@torch.no_grad()\n", 114 | "def estimate_loss():\n", 115 | " out = {}\n", 116 | " model.eval()\n", 117 | " for split in ['train', 'val']:\n", 118 | " losses = torch.zeros(eval_iters)\n", 119 | " for k in range(eval_iters):\n", 120 | " X, Y = get_batch(split)\n", 121 | " logits, loss = model(X, Y)\n", 122 | " losses[k] = loss.item()\n", 123 | " out[split] = losses.mean()\n", 124 | " model.train()\n", 125 | " return out" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 9, 131 | "metadata": {}, 132 | "outputs": [], 133 | "source": [ 134 | "# super simple bigram model\n", 135 | "class BigramLanguageModel(nn.Module):\n", 136 | "\n", 137 | " def __init__(self, vocab_size):\n", 138 | " super().__init__()\n", 139 | " # each token directly reads off the logits for the next token from a lookup table\n", 140 | " self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)\n", 141 | "\n", 142 | " def forward(self, idx, targets=None):\n", 143 | "\n", 144 | " # idx and targets are both (B,T) tensor of integers\n", 145 | " logits = self.token_embedding_table(idx) # (B,T,C)\n", 146 | "\n", 147 | " if targets is None:\n", 148 | " loss = None\n", 149 | " else:\n", 150 | " B, T, C = logits.shape\n", 151 | " logits = logits.view(B*T, C)\n", 152 | " targets = targets.view(B*T)\n", 153 | " loss = F.cross_entropy(logits, targets)\n", 154 | "\n", 155 | " return logits, loss\n", 156 | "\n", 157 | " def generate(self, idx, max_new_tokens):\n", 158 | " # idx is (B, T) array of indices in the current context\n", 159 | " for _ in range(max_new_tokens):\n", 160 | " # get the predictions\n", 161 | " logits, loss = self(idx)\n", 162 | " # focus only on the last time step\n", 163 | " logits = logits[:, -1, :] # becomes (B, C)\n", 164 | " # apply softmax to get probabilities\n", 165 | " probs = F.softmax(logits, dim=-1) # (B, C)\n", 166 | " # sample from the distribution\n", 167 | " idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n", 168 | " # append sampled index to the running sequence\n", 169 | " idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n", 170 | " return idx" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": 10, 176 | "metadata": {}, 177 | "outputs": [], 178 | "source": [ 179 | "model = BigramLanguageModel(vocab_size)\n", 180 | "m = model.to(device)\n", 181 | "\n", 182 | "# create a PyTorch optimizer\n", 183 | "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)" 184 | ] 185 | }, 186 | { 187 | "cell_type": "code", 188 | "execution_count": 11, 189 | "metadata": {}, 190 | "outputs": [ 191 | { 192 | "name": "stdout", 193 | "output_type": "stream", 194 | "text": [ 195 | "step 0: train loss 4.7305, val loss 4.7241\n", 196 | "step 300: train loss 2.8110, val loss 2.8249\n", 197 | "step 600: train loss 2.5434, val loss 2.5682\n", 198 | "step 900: train loss 2.4932, val loss 2.5088\n", 199 | "step 1200: train loss 2.4863, val loss 2.5035\n", 200 | "step 1500: train loss 2.4665, val loss 2.4921\n", 201 | "step 1800: train loss 2.4683, val loss 2.4936\n", 202 | "step 2100: train loss 2.4696, val loss 2.4846\n", 203 | "step 2400: train loss 2.4638, val loss 2.4879\n", 204 | "step 2700: train loss 2.4738, val loss 2.4911\n", 205 | "\n", 206 | "od nos CAy go ghanoray t, co haringoudrou clethe k,LARof fr werar,\n", 207 | "Is fa!\n", 208 | "\n", 209 | "\n", 210 | "Thilemel cia h hmboomyorarifrcitheviPO, tle dst f qur'dig t cof boddo y t o ar pileas h mo wierl t,\n", 211 | "S:\n", 212 | "STENENEat I athe thounomy tinrent 
distesisanimald 3I: eliento ald, avaviconofrisist me Busarend un'soto vat s k,\n", 213 | "SBRI he the f wendleindd t acoe ts ansu, thy ppr h.QULY:\n", 214 | "KIIsqu pr odEd ch,\n", 215 | "APrnes ouse bll owhored miner t ooon'stoume bupromo! fifoveghind hiarnge s.\n", 216 | "MI aswimy or m, wardd tw'To tee abifewoetsphin sed The a\n" 217 | ] 218 | } 219 | ], 220 | "source": [ 221 | "for iter in range(max_iters):\n", 222 | "\n", 223 | " # every once in a while evaluate the loss on train and val sets\n", 224 | " if iter % eval_interval == 0:\n", 225 | " losses = estimate_loss()\n", 226 | " print(f\"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}\")\n", 227 | "\n", 228 | " # sample a batch of data\n", 229 | " xb, yb = get_batch('train')\n", 230 | "\n", 231 | " # evaluate the loss\n", 232 | " logits, loss = model(xb, yb)\n", 233 | " optimizer.zero_grad(set_to_none=True)\n", 234 | " loss.backward()\n", 235 | " optimizer.step()\n", 236 | "\n", 237 | "# generate from the model\n", 238 | "context = torch.zeros((1, 1), dtype=torch.long, device=device)\n", 239 | "print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))" 240 | ] 241 | } 242 | ], 243 | "metadata": { 244 | "kernelspec": { 245 | "display_name": "Python 3", 246 | "language": "python", 247 | "name": "python3" 248 | }, 249 | "language_info": { 250 | "codemirror_mode": { 251 | "name": "ipython", 252 | "version": 3 253 | }, 254 | "file_extension": ".py", 255 | "mimetype": "text/x-python", 256 | "name": "python", 257 | "nbconvert_exporter": "python", 258 | "pygments_lexer": "ipython3", 259 | "version": "3.11.2" 260 | }, 261 | "orig_nbformat": 4 262 | }, 263 | "nbformat": 4, 264 | "nbformat_minor": 2 265 | } 266 | -------------------------------------------------------------------------------- /bigram.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | # hyperparameters 6 | batch_size = 32 # how many independent sequences will we process in parallel? 7 | block_size = 8 # what is the maximum context length for predictions? 
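# note: with batch_size = 32 and block_size = 8, get_batch() below returns x and y of
# shape (32, 8); y is x shifted one character to the right, so each of the 32*8 = 256
# positions in a batch contributes one next-character prediction to the loss.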
8 | max_iters = 3000 9 | eval_interval = 300 10 | learning_rate = 1e-2 11 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 12 | eval_iters = 200 13 | # ------------ 14 | 15 | torch.manual_seed(1337) 16 | 17 | with open('C:/Users/rocks/OneDrive/Desktop/Projects/My-GPT/input.txt', 'r', encoding='utf-8') as f: 18 | text = f.read() 19 | 20 | # here are all the unique characters that occur in this text 21 | chars = sorted(list(set(text))) 22 | vocab_size = len(chars) 23 | # create a mapping from characters to integers 24 | stoi = { ch:i for i,ch in enumerate(chars) } 25 | itos = { i:ch for i,ch in enumerate(chars) } 26 | encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers 27 | decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string 28 | 29 | # Train and test splits 30 | data = torch.tensor(encode(text), dtype=torch.long) 31 | n = int(0.9*len(data)) # first 90% will be train, rest val 32 | train_data = data[:n] 33 | val_data = data[n:] 34 | 35 | # data loading 36 | def get_batch(split): 37 | # generate a small batch of data of inputs x and targets y 38 | data = train_data if split == 'train' else val_data 39 | ix = torch.randint(len(data) - block_size, (batch_size,)) 40 | x = torch.stack([data[i:i+block_size] for i in ix]) 41 | y = torch.stack([data[i+1:i+block_size+1] for i in ix]) 42 | x, y = x.to(device), y.to(device) 43 | return x, y 44 | 45 | @torch.no_grad() 46 | def estimate_loss(): 47 | out = {} 48 | model.eval() 49 | for split in ['train', 'val']: 50 | losses = torch.zeros(eval_iters) 51 | for k in range(eval_iters): 52 | X, Y = get_batch(split) 53 | logits, loss = model(X, Y) 54 | losses[k] = loss.item() 55 | out[split] = losses.mean() 56 | model.train() 57 | return out 58 | 59 | # super simple bigram model 60 | class BigramLanguageModel(nn.Module): 61 | 62 | def __init__(self, vocab_size): 63 | super().__init__() 64 | # each token directly reads off the logits for the next token from a lookup table 65 | self.token_embedding_table = nn.Embedding(vocab_size, vocab_size) 66 | 67 | def forward(self, idx, targets=None): 68 | 69 | # idx and targets are both (B,T) tensor of integers 70 | logits = self.token_embedding_table(idx) # (B,T,C) 71 | 72 | if targets is None: 73 | loss = None 74 | else: 75 | B, T, C = logits.shape 76 | logits = logits.view(B*T, C) 77 | targets = targets.view(B*T) 78 | loss = F.cross_entropy(logits, targets) 79 | 80 | return logits, loss 81 | 82 | def generate(self, idx, max_new_tokens): 83 | # idx is (B, T) array of indices in the current context 84 | for _ in range(max_new_tokens): 85 | # get the predictions 86 | logits, loss = self(idx) 87 | # focus only on the last time step 88 | logits = logits[:, -1, :] # becomes (B, C) 89 | # apply softmax to get probabilities 90 | probs = F.softmax(logits, dim=-1) # (B, C) 91 | # sample from the distribution 92 | idx_next = torch.multinomial(probs, num_samples=1) # (B, 1) 93 | # append sampled index to the running sequence 94 | idx = torch.cat((idx, idx_next), dim=1) # (B, T+1) 95 | return idx 96 | 97 | model = BigramLanguageModel(vocab_size) 98 | m = model.to(device) 99 | 100 | # create a PyTorch optimizer 101 | optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate) 102 | 103 | for iter in range(max_iters): 104 | 105 | # every once in a while evaluate the loss on train and val sets 106 | if iter % eval_interval == 0: 107 | losses = estimate_loss() 108 | print(f"step {iter}: train loss {losses['train']:.4f}, val loss 
{losses['val']:.4f}") 109 | 110 | # sample a batch of data 111 | xb, yb = get_batch('train') 112 | 113 | # evaluate the loss 114 | logits, loss = model(xb, yb) 115 | optimizer.zero_grad(set_to_none=True) 116 | loss.backward() 117 | optimizer.step() 118 | 119 | # generate from the model 120 | context = torch.zeros((1, 1), dtype=torch.long, device=device) 121 | print(decode(m.generate(context, max_new_tokens=500)[0].tolist())) 122 | -------------------------------------------------------------------------------- /gpt.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "import torch\n", 10 | "import torch.nn as nn\n", 11 | "from torch.nn import functional as F" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "metadata": {}, 18 | "outputs": [], 19 | "source": [ 20 | "# hyperparameters\n", 21 | "batch_size = 64 # how many independent sequences will we process in parallel?\n", 22 | "block_size = 256 # what is the maximum context length for predictions?\n", 23 | "max_iters = 5000\n", 24 | "eval_interval = 500\n", 25 | "learning_rate = 3e-4\n", 26 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 27 | "eval_iters = 200\n", 28 | "n_embd = 384\n", 29 | "n_head = 6\n", 30 | "n_layer = 6\n", 31 | "dropout = 0.2" 32 | ] 33 | }, 34 | { 35 | "cell_type": "code", 36 | "execution_count": 3, 37 | "metadata": {}, 38 | "outputs": [ 39 | { 40 | "data": { 41 | "text/plain": [ 42 | "" 43 | ] 44 | }, 45 | "execution_count": 3, 46 | "metadata": {}, 47 | "output_type": "execute_result" 48 | } 49 | ], 50 | "source": [ 51 | "torch.manual_seed(1337)" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 4, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "with open('C:/Users/rocks/OneDrive/Desktop/Projects/My-GPT/input.txt', 'r', encoding='utf-8') as f:\n", 61 | " text = f.read()" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 5, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "# here are all the unique characters that occur in this text\n", 71 | "chars = sorted(list(set(text)))\n", 72 | "vocab_size = len(chars)\n", 73 | "# create a mapping from characters to integers\n", 74 | "stoi = { ch:i for i,ch in enumerate(chars) }\n", 75 | "itos = { i:ch for i,ch in enumerate(chars) }\n", 76 | "encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers\n", 77 | "decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string" 78 | ] 79 | }, 80 | { 81 | "cell_type": "code", 82 | "execution_count": 6, 83 | "metadata": {}, 84 | "outputs": [], 85 | "source": [ 86 | "# Train and test splits\n", 87 | "data = torch.tensor(encode(text), dtype=torch.long)\n", 88 | "n = int(0.9*len(data)) # first 90% will be train, rest val\n", 89 | "train_data = data[:n]\n", 90 | "val_data = data[n:]" 91 | ] 92 | }, 93 | { 94 | "cell_type": "code", 95 | "execution_count": 7, 96 | "metadata": {}, 97 | "outputs": [], 98 | "source": [ 99 | "# data loading\n", 100 | "def get_batch(split):\n", 101 | " # generate a small batch of data of inputs x and targets y\n", 102 | " data = train_data if split == 'train' else val_data\n", 103 | " ix = torch.randint(len(data) - block_size, (batch_size,))\n", 104 | " x = torch.stack([data[i:i+block_size] for i in ix])\n", 105 | " y = 
torch.stack([data[i+1:i+block_size+1] for i in ix])\n", 106 | " x, y = x.to(device), y.to(device)\n", 107 | " return x, y" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 8, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "@torch.no_grad()\n", 117 | "def estimate_loss():\n", 118 | " out = {}\n", 119 | " model.eval()\n", 120 | " for split in ['train', 'val']:\n", 121 | " losses = torch.zeros(eval_iters)\n", 122 | " for k in range(eval_iters):\n", 123 | " X, Y = get_batch(split)\n", 124 | " logits, loss = model(X, Y)\n", 125 | " losses[k] = loss.item()\n", 126 | " out[split] = losses.mean()\n", 127 | " model.train()\n", 128 | " return out" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": 9, 134 | "metadata": {}, 135 | "outputs": [], 136 | "source": [ 137 | "class Head(nn.Module):\n", 138 | " \"\"\" one head of self-attention \"\"\"\n", 139 | "\n", 140 | " def __init__(self, head_size):\n", 141 | " super().__init__()\n", 142 | " self.key = nn.Linear(n_embd, head_size, bias=False)\n", 143 | " self.query = nn.Linear(n_embd, head_size, bias=False)\n", 144 | " self.value = nn.Linear(n_embd, head_size, bias=False)\n", 145 | " self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))\n", 146 | "\n", 147 | " self.dropout = nn.Dropout(dropout)\n", 148 | "\n", 149 | " def forward(self, x):\n", 150 | " # input of size (batch, time-step, channels)\n", 151 | " # output of size (batch, time-step, head size)\n", 152 | " B,T,C = x.shape\n", 153 | " k = self.key(x) # (B,T,hs)\n", 154 | " q = self.query(x) # (B,T,hs)\n", 155 | " # compute attention scores (\"affinities\")\n", 156 | " wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)\n", 157 | " wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)\n", 158 | " wei = F.softmax(wei, dim=-1) # (B, T, T)\n", 159 | " wei = self.dropout(wei)\n", 160 | " # perform the weighted aggregation of the values\n", 161 | " v = self.value(x) # (B,T,hs)\n", 162 | " out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)\n", 163 | " return out" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 10, 169 | "metadata": {}, 170 | "outputs": [], 171 | "source": [ 172 | "class MultiHeadAttention(nn.Module):\n", 173 | " \"\"\" multiple heads of self-attention in parallel \"\"\"\n", 174 | "\n", 175 | " def __init__(self, num_heads, head_size):\n", 176 | " super().__init__()\n", 177 | " self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])\n", 178 | " self.proj = nn.Linear(head_size * num_heads, n_embd)\n", 179 | " self.dropout = nn.Dropout(dropout)\n", 180 | "\n", 181 | " def forward(self, x):\n", 182 | " out = torch.cat([h(x) for h in self.heads], dim=-1)\n", 183 | " out = self.dropout(self.proj(out))\n", 184 | " return out" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": 11, 190 | "metadata": {}, 191 | "outputs": [], 192 | "source": [ 193 | "class FeedFoward(nn.Module):\n", 194 | " \"\"\" a simple linear layer followed by a non-linearity \"\"\"\n", 195 | "\n", 196 | " def __init__(self, n_embd):\n", 197 | " super().__init__()\n", 198 | " self.net = nn.Sequential(\n", 199 | " nn.Linear(n_embd, 4 * n_embd),\n", 200 | " nn.ReLU(),\n", 201 | " nn.Linear(4 * n_embd, n_embd),\n", 202 | " nn.Dropout(dropout),\n", 203 | " )\n", 204 | "\n", 205 | " def forward(self, x):\n", 206 | " return self.net(x)" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | 
"execution_count": 12, 212 | "metadata": {}, 213 | "outputs": [], 214 | "source": [ 215 | "class Block(nn.Module):\n", 216 | " \"\"\" Transformer block: communication followed by computation \"\"\"\n", 217 | "\n", 218 | " def __init__(self, n_embd, n_head):\n", 219 | " # n_embd: embedding dimension, n_head: the number of heads we'd like\n", 220 | " super().__init__()\n", 221 | " head_size = n_embd // n_head\n", 222 | " self.sa = MultiHeadAttention(n_head, head_size)\n", 223 | " self.ffwd = FeedFoward(n_embd)\n", 224 | " self.ln1 = nn.LayerNorm(n_embd)\n", 225 | " self.ln2 = nn.LayerNorm(n_embd)\n", 226 | "\n", 227 | " def forward(self, x):\n", 228 | " x = x + self.sa(self.ln1(x))\n", 229 | " x = x + self.ffwd(self.ln2(x))\n", 230 | " return x" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": 13, 236 | "metadata": {}, 237 | "outputs": [], 238 | "source": [ 239 | "class GPTLanguageModel(nn.Module):\n", 240 | "\n", 241 | " def __init__(self):\n", 242 | " super().__init__()\n", 243 | " # each token directly reads off the logits for the next token from a lookup table\n", 244 | " self.token_embedding_table = nn.Embedding(vocab_size, n_embd)\n", 245 | " self.position_embedding_table = nn.Embedding(block_size, n_embd)\n", 246 | " self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])\n", 247 | " self.ln_f = nn.LayerNorm(n_embd) # final layer norm\n", 248 | " self.lm_head = nn.Linear(n_embd, vocab_size)\n", 249 | "\n", 250 | " # better init, not covered in the original GPT video, but important, will cover in followup video\n", 251 | " self.apply(self._init_weights)\n", 252 | "\n", 253 | " def _init_weights(self, module):\n", 254 | " if isinstance(module, nn.Linear):\n", 255 | " torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n", 256 | " if module.bias is not None:\n", 257 | " torch.nn.init.zeros_(module.bias)\n", 258 | " elif isinstance(module, nn.Embedding):\n", 259 | " torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)\n", 260 | "\n", 261 | " def forward(self, idx, targets=None):\n", 262 | " B, T = idx.shape\n", 263 | "\n", 264 | " # idx and targets are both (B,T) tensor of integers\n", 265 | " tok_emb = self.token_embedding_table(idx) # (B,T,C)\n", 266 | " pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)\n", 267 | " x = tok_emb + pos_emb # (B,T,C)\n", 268 | " x = self.blocks(x) # (B,T,C)\n", 269 | " x = self.ln_f(x) # (B,T,C)\n", 270 | " logits = self.lm_head(x) # (B,T,vocab_size)\n", 271 | "\n", 272 | " if targets is None:\n", 273 | " loss = None\n", 274 | " else:\n", 275 | " B, T, C = logits.shape\n", 276 | " logits = logits.view(B*T, C)\n", 277 | " targets = targets.view(B*T)\n", 278 | " loss = F.cross_entropy(logits, targets)\n", 279 | "\n", 280 | " return logits, loss\n", 281 | "\n", 282 | " def generate(self, idx, max_new_tokens):\n", 283 | " # idx is (B, T) array of indices in the current context\n", 284 | " for _ in range(max_new_tokens):\n", 285 | " # crop idx to the last block_size tokens\n", 286 | " idx_cond = idx[:, -block_size:]\n", 287 | " # get the predictions\n", 288 | " logits, loss = self(idx_cond)\n", 289 | " # focus only on the last time step\n", 290 | " logits = logits[:, -1, :] # becomes (B, C)\n", 291 | " # apply softmax to get probabilities\n", 292 | " probs = F.softmax(logits, dim=-1) # (B, C)\n", 293 | " # sample from the distribution\n", 294 | " idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n", 295 | " # append sampled index to the running sequence\n", 
296 | " idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n", 297 | " return idx" 298 | ] 299 | }, 300 | { 301 | "cell_type": "code", 302 | "execution_count": 14, 303 | "metadata": {}, 304 | "outputs": [ 305 | { 306 | "name": "stdout", 307 | "output_type": "stream", 308 | "text": [ 309 | "10.788929 M parameters\n" 310 | ] 311 | } 312 | ], 313 | "source": [ 314 | "model = GPTLanguageModel()\n", 315 | "m = model.to(device)\n", 316 | "# print the number of parameters in the model\n", 317 | "print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')\n", 318 | "\n", 319 | "# create a PyTorch optimizer\n", 320 | "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)" 321 | ] 322 | }, 323 | { 324 | "cell_type": "code", 325 | "execution_count": 15, 326 | "metadata": {}, 327 | "outputs": [ 328 | { 329 | "name": "stdout", 330 | "output_type": "stream", 331 | "text": [ 332 | "step 0: train loss 4.2221, val loss 4.2306\n" 333 | ] 334 | } 335 | ], 336 | "source": [ 337 | "for iter in range(max_iters):\n", 338 | "\n", 339 | " # every once in a while evaluate the loss on train and val sets\n", 340 | " if iter % eval_interval == 0 or iter == max_iters - 1:\n", 341 | " losses = estimate_loss()\n", 342 | " print(f\"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}\")\n", 343 | "\n", 344 | " # sample a batch of data\n", 345 | " xb, yb = get_batch('train')\n", 346 | "\n", 347 | " # evaluate the loss\n", 348 | " logits, loss = model(xb, yb)\n", 349 | " optimizer.zero_grad(set_to_none=True)\n", 350 | " loss.backward()\n", 351 | " optimizer.step()\n", 352 | "\n", 353 | "# generate from the model\n", 354 | "context = torch.zeros((1, 1), dtype=torch.long, device=device)\n", 355 | "print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))\n", 356 | "#open('more.txt', 'w').write(decode(m.generate(context, max_new_tokens=10000)[0].tolist()))" 357 | ] 358 | } 359 | ], 360 | "metadata": { 361 | "kernelspec": { 362 | "display_name": "Python 3", 363 | "language": "python", 364 | "name": "python3" 365 | }, 366 | "language_info": { 367 | "codemirror_mode": { 368 | "name": "ipython", 369 | "version": 3 370 | }, 371 | "file_extension": ".py", 372 | "mimetype": "text/x-python", 373 | "name": "python", 374 | "nbconvert_exporter": "python", 375 | "pygments_lexer": "ipython3", 376 | "version": "3.11.2" 377 | }, 378 | "orig_nbformat": 4 379 | }, 380 | "nbformat": 4, 381 | "nbformat_minor": 2 382 | } 383 | -------------------------------------------------------------------------------- /gpt.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | 5 | # hyperparameters 6 | batch_size = 64 # how many independent sequences will we process in parallel? 7 | block_size = 256 # what is the maximum context length for predictions? 
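# note: these settings account for the "10.788929 M parameters" printed below
# (vocab_size = 65 for this input.txt):
#   token embedding 65*384 = 24,960;  position embedding 256*384 = 98,304
#   per block: attention 590,208 + feed-forward 1,181,568 + two layer norms 1,536 = 1,773,312
#   6 blocks = 10,639,872;  final layer norm 768;  lm_head 384*65 + 65 = 25,025
#   24,960 + 98,304 + 10,639,872 + 768 + 25,025 = 10,788,929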
8 | max_iters = 5000 9 | eval_interval = 500 10 | learning_rate = 3e-4 11 | device = 'cuda' if torch.cuda.is_available() else 'cpu' 12 | eval_iters = 200 13 | n_embd = 384 14 | n_head = 6 15 | n_layer = 6 16 | dropout = 0.2 17 | # ------------ 18 | 19 | torch.manual_seed(1337) 20 | 21 | with open('C:/Users/rocks/OneDrive/Desktop/Projects/My-GPT/input.txt', 'r', encoding='utf-8') as f: 22 | text = f.read() 23 | 24 | # here are all the unique characters that occur in this text 25 | chars = sorted(list(set(text))) 26 | vocab_size = len(chars) 27 | # create a mapping from characters to integers 28 | stoi = { ch:i for i,ch in enumerate(chars) } 29 | itos = { i:ch for i,ch in enumerate(chars) } 30 | encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers 31 | decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string 32 | 33 | # Train and test splits 34 | data = torch.tensor(encode(text), dtype=torch.long) 35 | n = int(0.9*len(data)) # first 90% will be train, rest val 36 | train_data = data[:n] 37 | val_data = data[n:] 38 | 39 | # data loading 40 | def get_batch(split): 41 | # generate a small batch of data of inputs x and targets y 42 | data = train_data if split == 'train' else val_data 43 | ix = torch.randint(len(data) - block_size, (batch_size,)) 44 | x = torch.stack([data[i:i+block_size] for i in ix]) 45 | y = torch.stack([data[i+1:i+block_size+1] for i in ix]) 46 | x, y = x.to(device), y.to(device) 47 | return x, y 48 | 49 | @torch.no_grad() 50 | def estimate_loss(): 51 | out = {} 52 | model.eval() 53 | for split in ['train', 'val']: 54 | losses = torch.zeros(eval_iters) 55 | for k in range(eval_iters): 56 | X, Y = get_batch(split) 57 | logits, loss = model(X, Y) 58 | losses[k] = loss.item() 59 | out[split] = losses.mean() 60 | model.train() 61 | return out 62 | 63 | class Head(nn.Module): 64 | """ one head of self-attention """ 65 | 66 | def __init__(self, head_size): 67 | super().__init__() 68 | self.key = nn.Linear(n_embd, head_size, bias=False) 69 | self.query = nn.Linear(n_embd, head_size, bias=False) 70 | self.value = nn.Linear(n_embd, head_size, bias=False) 71 | self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size))) 72 | 73 | self.dropout = nn.Dropout(dropout) 74 | 75 | def forward(self, x): 76 | # input of size (batch, time-step, channels) 77 | # output of size (batch, time-step, head size) 78 | B,T,C = x.shape 79 | k = self.key(x) # (B,T,hs) 80 | q = self.query(x) # (B,T,hs) 81 | # compute attention scores ("affinities") 82 | wei = q @ k.transpose(-2,-1) * k.shape[-1]**-0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T) 83 | wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T) 84 | wei = F.softmax(wei, dim=-1) # (B, T, T) 85 | wei = self.dropout(wei) 86 | # perform the weighted aggregation of the values 87 | v = self.value(x) # (B,T,hs) 88 | out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs) 89 | return out 90 | 91 | class MultiHeadAttention(nn.Module): 92 | """ multiple heads of self-attention in parallel """ 93 | 94 | def __init__(self, num_heads, head_size): 95 | super().__init__() 96 | self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)]) 97 | self.proj = nn.Linear(head_size * num_heads, n_embd) 98 | self.dropout = nn.Dropout(dropout) 99 | 100 | def forward(self, x): 101 | out = torch.cat([h(x) for h in self.heads], dim=-1) 102 | out = self.dropout(self.proj(out)) 103 | return out 104 | 105 | class FeedFoward(nn.Module): 106 | """ a simple linear 
layer followed by a non-linearity """ 107 | 108 | def __init__(self, n_embd): 109 | super().__init__() 110 | self.net = nn.Sequential( 111 | nn.Linear(n_embd, 4 * n_embd), 112 | nn.ReLU(), 113 | nn.Linear(4 * n_embd, n_embd), 114 | nn.Dropout(dropout), 115 | ) 116 | 117 | def forward(self, x): 118 | return self.net(x) 119 | 120 | class Block(nn.Module): 121 | """ Transformer block: communication followed by computation """ 122 | 123 | def __init__(self, n_embd, n_head): 124 | # n_embd: embedding dimension, n_head: the number of heads we'd like 125 | super().__init__() 126 | head_size = n_embd // n_head 127 | self.sa = MultiHeadAttention(n_head, head_size) 128 | self.ffwd = FeedFoward(n_embd) 129 | self.ln1 = nn.LayerNorm(n_embd) 130 | self.ln2 = nn.LayerNorm(n_embd) 131 | 132 | def forward(self, x): 133 | x = x + self.sa(self.ln1(x)) 134 | x = x + self.ffwd(self.ln2(x)) 135 | return x 136 | 137 | class GPTLanguageModel(nn.Module): 138 | 139 | def __init__(self): 140 | super().__init__() 141 | # each token directly reads off the logits for the next token from a lookup table 142 | self.token_embedding_table = nn.Embedding(vocab_size, n_embd) 143 | self.position_embedding_table = nn.Embedding(block_size, n_embd) 144 | self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)]) 145 | self.ln_f = nn.LayerNorm(n_embd) # final layer norm 146 | self.lm_head = nn.Linear(n_embd, vocab_size) 147 | 148 | # better init, not covered in the original GPT video, but important, will cover in followup video 149 | self.apply(self._init_weights) 150 | 151 | def _init_weights(self, module): 152 | if isinstance(module, nn.Linear): 153 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 154 | if module.bias is not None: 155 | torch.nn.init.zeros_(module.bias) 156 | elif isinstance(module, nn.Embedding): 157 | torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) 158 | 159 | def forward(self, idx, targets=None): 160 | B, T = idx.shape 161 | 162 | # idx and targets are both (B,T) tensor of integers 163 | tok_emb = self.token_embedding_table(idx) # (B,T,C) 164 | pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C) 165 | x = tok_emb + pos_emb # (B,T,C) 166 | x = self.blocks(x) # (B,T,C) 167 | x = self.ln_f(x) # (B,T,C) 168 | logits = self.lm_head(x) # (B,T,vocab_size) 169 | 170 | if targets is None: 171 | loss = None 172 | else: 173 | B, T, C = logits.shape 174 | logits = logits.view(B*T, C) 175 | targets = targets.view(B*T) 176 | loss = F.cross_entropy(logits, targets) 177 | 178 | return logits, loss 179 | 180 | def generate(self, idx, max_new_tokens): 181 | # idx is (B, T) array of indices in the current context 182 | for _ in range(max_new_tokens): 183 | # crop idx to the last block_size tokens 184 | idx_cond = idx[:, -block_size:] 185 | # get the predictions 186 | logits, loss = self(idx_cond) 187 | # focus only on the last time step 188 | logits = logits[:, -1, :] # becomes (B, C) 189 | # apply softmax to get probabilities 190 | probs = F.softmax(logits, dim=-1) # (B, C) 191 | # sample from the distribution 192 | idx_next = torch.multinomial(probs, num_samples=1) # (B, 1) 193 | # append sampled index to the running sequence 194 | idx = torch.cat((idx, idx_next), dim=1) # (B, T+1) 195 | return idx 196 | 197 | model = GPTLanguageModel() 198 | m = model.to(device) 199 | # print the number of parameters in the model 200 | print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters') 201 | 202 | # create a PyTorch optimizer 203 | optimizer = 
torch.optim.AdamW(model.parameters(), lr=learning_rate) 204 | 205 | for iter in range(max_iters): 206 | 207 | # every once in a while evaluate the loss on train and val sets 208 | if iter % eval_interval == 0 or iter == max_iters - 1: 209 | losses = estimate_loss() 210 | print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}") 211 | 212 | # sample a batch of data 213 | xb, yb = get_batch('train') 214 | 215 | # evaluate the loss 216 | logits, loss = model(xb, yb) 217 | optimizer.zero_grad(set_to_none=True) 218 | loss.backward() 219 | optimizer.step() 220 | 221 | # generate from the model 222 | context = torch.zeros((1, 1), dtype=torch.long, device=device) 223 | print(decode(m.generate(context, max_new_tokens=500)[0].tolist())) 224 | #open('more.txt', 'w').write(decode(m.generate(context, max_new_tokens=10000)[0].tolist())) 225 | -------------------------------------------------------------------------------- /gpt1.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 4, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "data": { 10 | "text/plain": [ 11 | "" 12 | ] 13 | }, 14 | "execution_count": 4, 15 | "metadata": {}, 16 | "output_type": "execute_result" 17 | } 18 | ], 19 | "source": [ 20 | "import torch\n", 21 | "import torch.nn as nn\n", 22 | "from torch.nn import functional as F\n", 23 | "\n", 24 | "# hyperparameters\n", 25 | "batch_size = 32 # how many independent sequences will we process in parallel?\n", 26 | "block_size = 8 # what is the maximum context length for predictions?\n", 27 | "max_iters = 3000\n", 28 | "eval_interval = 300\n", 29 | "learning_rate = 1e-2\n", 30 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 31 | "eval_iters = 200\n", 32 | "# ------------\n", 33 | "\n", 34 | "torch.manual_seed(1337)" 35 | ] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "execution_count": 5, 40 | "metadata": {}, 41 | "outputs": [], 42 | "source": [ 43 | "with open('C:/Users/rocks/OneDrive/Desktop/Projects/My-GPT/input.txt', 'r', encoding='utf-8') as f:\n", 44 | " text = f.read()" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": 6, 50 | "metadata": {}, 51 | "outputs": [ 52 | { 53 | "name": "stdout", 54 | "output_type": "stream", 55 | "text": [ 56 | "length of dataset in characters: 1115394\n" 57 | ] 58 | } 59 | ], 60 | "source": [ 61 | "print(\"length of dataset in characters: \", len(text))" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 7, 67 | "metadata": {}, 68 | "outputs": [ 69 | { 70 | "name": "stdout", 71 | "output_type": "stream", 72 | "text": [ 73 | "First Citizen:\n", 74 | "Before we proceed any further, hear me speak.\n", 75 | "\n", 76 | "All:\n", 77 | "Speak, speak.\n", 78 | "\n", 79 | "First Citizen:\n", 80 | "You are all resolved rather to die than to famish?\n", 81 | "\n", 82 | "All:\n", 83 | "Resolved. 
resolved.\n", 84 | "\n", 85 | "First Citizen:\n", 86 | "First, you know Caius Marcius is chief enemy to the people.\n", 87 | "\n", 88 | "All:\n", 89 | "We know't, we know't.\n", 90 | "\n", 91 | "First Citizen:\n", 92 | "Let us kill him, and we'll have corn at our own price.\n", 93 | "Is't a verdict?\n", 94 | "\n", 95 | "All:\n", 96 | "No more talking on't; let it be done: away, away!\n", 97 | "\n", 98 | "Second Citizen:\n", 99 | "One word, good citizens.\n", 100 | "\n", 101 | "First Citizen:\n", 102 | "We are accounted poor citizens, the patricians good.\n", 103 | "What authority surfeits on would relieve us: if they\n", 104 | "would yield us but the superfluity, while it were\n", 105 | "wholesome, we might guess they relieved us humanely;\n", 106 | "but they think we are too dear: the leanness that\n", 107 | "afflicts us, the object of our misery, is as an\n", 108 | "inventory to particularise their abundance; our\n", 109 | "sufferance is a gain to them Let us revenge this with\n", 110 | "our pikes, ere we become rakes: for the gods know I\n", 111 | "speak this in hunger for bread, not in thirst for revenge.\n", 112 | "\n", 113 | "\n" 114 | ] 115 | } 116 | ], 117 | "source": [ 118 | "# let's look at the first 1000 characters\n", 119 | "print(text[:1000])" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": 8, 125 | "metadata": {}, 126 | "outputs": [ 127 | { 128 | "name": "stdout", 129 | "output_type": "stream", 130 | "text": [ 131 | "\n", 132 | " !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz\n", 133 | "65\n" 134 | ] 135 | } 136 | ], 137 | "source": [ 138 | "# here are all the unique characters that occur in this text\n", 139 | "chars = sorted(list(set(text)))\n", 140 | "vocab_size = len(chars)\n", 141 | "print(''.join(chars))\n", 142 | "print(vocab_size)" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": 9, 148 | "metadata": {}, 149 | "outputs": [ 150 | { 151 | "name": "stdout", 152 | "output_type": "stream", 153 | "text": [ 154 | "[46, 47, 47, 1, 58, 46, 43, 56, 43]\n", 155 | "hii there\n" 156 | ] 157 | } 158 | ], 159 | "source": [ 160 | "# create a mapping from characters to integers\n", 161 | "stoi = { ch:i for i,ch in enumerate(chars) }\n", 162 | "itos = { i:ch for i,ch in enumerate(chars) }\n", 163 | "encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers\n", 164 | "decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string\n", 165 | "\n", 166 | "print(encode(\"hii there\"))\n", 167 | "print(decode(encode(\"hii there\")))" 168 | ] 169 | }, 170 | { 171 | "cell_type": "code", 172 | "execution_count": 10, 173 | "metadata": {}, 174 | "outputs": [ 175 | { 176 | "name": "stdout", 177 | "output_type": "stream", 178 | "text": [ 179 | "torch.Size([1115394]) torch.int64\n", 180 | "tensor([18, 47, 56, 57, 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 14, 43, 44,\n", 181 | " 53, 56, 43, 1, 61, 43, 1, 54, 56, 53, 41, 43, 43, 42, 1, 39, 52, 63,\n", 182 | " 1, 44, 59, 56, 58, 46, 43, 56, 6, 1, 46, 43, 39, 56, 1, 51, 43, 1,\n", 183 | " 57, 54, 43, 39, 49, 8, 0, 0, 13, 50, 50, 10, 0, 31, 54, 43, 39, 49,\n", 184 | " 6, 1, 57, 54, 43, 39, 49, 8, 0, 0, 18, 47, 56, 57, 58, 1, 15, 47,\n", 185 | " 58, 47, 64, 43, 52, 10, 0, 37, 53, 59, 1, 39, 56, 43, 1, 39, 50, 50,\n", 186 | " 1, 56, 43, 57, 53, 50, 60, 43, 42, 1, 56, 39, 58, 46, 43, 56, 1, 58,\n", 187 | " 53, 1, 42, 47, 43, 1, 58, 46, 39, 52, 1, 58, 53, 1, 44, 39, 51, 47,\n", 188 | " 57, 46, 12, 0, 0, 13, 50, 50, 
10, 0, 30, 43, 57, 53, 50, 60, 43, 42,\n", 189 | " 8, 1, 56, 43, 57, 53, 50, 60, 43, 42, 8, 0, 0, 18, 47, 56, 57, 58,\n", 190 | " 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 18, 47, 56, 57, 58, 6, 1, 63,\n", 191 | " 53, 59, 1, 49, 52, 53, 61, 1, 15, 39, 47, 59, 57, 1, 25, 39, 56, 41,\n", 192 | " 47, 59, 57, 1, 47, 57, 1, 41, 46, 47, 43, 44, 1, 43, 52, 43, 51, 63,\n", 193 | " 1, 58, 53, 1, 58, 46, 43, 1, 54, 43, 53, 54, 50, 43, 8, 0, 0, 13,\n", 194 | " 50, 50, 10, 0, 35, 43, 1, 49, 52, 53, 61, 5, 58, 6, 1, 61, 43, 1,\n", 195 | " 49, 52, 53, 61, 5, 58, 8, 0, 0, 18, 47, 56, 57, 58, 1, 15, 47, 58,\n", 196 | " 47, 64, 43, 52, 10, 0, 24, 43, 58, 1, 59, 57, 1, 49, 47, 50, 50, 1,\n", 197 | " 46, 47, 51, 6, 1, 39, 52, 42, 1, 61, 43, 5, 50, 50, 1, 46, 39, 60,\n", 198 | " 43, 1, 41, 53, 56, 52, 1, 39, 58, 1, 53, 59, 56, 1, 53, 61, 52, 1,\n", 199 | " 54, 56, 47, 41, 43, 8, 0, 21, 57, 5, 58, 1, 39, 1, 60, 43, 56, 42,\n", 200 | " 47, 41, 58, 12, 0, 0, 13, 50, 50, 10, 0, 26, 53, 1, 51, 53, 56, 43,\n", 201 | " 1, 58, 39, 50, 49, 47, 52, 45, 1, 53, 52, 5, 58, 11, 1, 50, 43, 58,\n", 202 | " 1, 47, 58, 1, 40, 43, 1, 42, 53, 52, 43, 10, 1, 39, 61, 39, 63, 6,\n", 203 | " 1, 39, 61, 39, 63, 2, 0, 0, 31, 43, 41, 53, 52, 42, 1, 15, 47, 58,\n", 204 | " 47, 64, 43, 52, 10, 0, 27, 52, 43, 1, 61, 53, 56, 42, 6, 1, 45, 53,\n", 205 | " 53, 42, 1, 41, 47, 58, 47, 64, 43, 52, 57, 8, 0, 0, 18, 47, 56, 57,\n", 206 | " 58, 1, 15, 47, 58, 47, 64, 43, 52, 10, 0, 35, 43, 1, 39, 56, 43, 1,\n", 207 | " 39, 41, 41, 53, 59, 52, 58, 43, 42, 1, 54, 53, 53, 56, 1, 41, 47, 58,\n", 208 | " 47, 64, 43, 52, 57, 6, 1, 58, 46, 43, 1, 54, 39, 58, 56, 47, 41, 47,\n", 209 | " 39, 52, 57, 1, 45, 53, 53, 42, 8, 0, 35, 46, 39, 58, 1, 39, 59, 58,\n", 210 | " 46, 53, 56, 47, 58, 63, 1, 57, 59, 56, 44, 43, 47, 58, 57, 1, 53, 52,\n", 211 | " 1, 61, 53, 59, 50, 42, 1, 56, 43, 50, 47, 43, 60, 43, 1, 59, 57, 10,\n", 212 | " 1, 47, 44, 1, 58, 46, 43, 63, 0, 61, 53, 59, 50, 42, 1, 63, 47, 43,\n", 213 | " 50, 42, 1, 59, 57, 1, 40, 59, 58, 1, 58, 46, 43, 1, 57, 59, 54, 43,\n", 214 | " 56, 44, 50, 59, 47, 58, 63, 6, 1, 61, 46, 47, 50, 43, 1, 47, 58, 1,\n", 215 | " 61, 43, 56, 43, 0, 61, 46, 53, 50, 43, 57, 53, 51, 43, 6, 1, 61, 43,\n", 216 | " 1, 51, 47, 45, 46, 58, 1, 45, 59, 43, 57, 57, 1, 58, 46, 43, 63, 1,\n", 217 | " 56, 43, 50, 47, 43, 60, 43, 42, 1, 59, 57, 1, 46, 59, 51, 39, 52, 43,\n", 218 | " 50, 63, 11, 0, 40, 59, 58, 1, 58, 46, 43, 63, 1, 58, 46, 47, 52, 49,\n", 219 | " 1, 61, 43, 1, 39, 56, 43, 1, 58, 53, 53, 1, 42, 43, 39, 56, 10, 1,\n", 220 | " 58, 46, 43, 1, 50, 43, 39, 52, 52, 43, 57, 57, 1, 58, 46, 39, 58, 0,\n", 221 | " 39, 44, 44, 50, 47, 41, 58, 57, 1, 59, 57, 6, 1, 58, 46, 43, 1, 53,\n", 222 | " 40, 48, 43, 41, 58, 1, 53, 44, 1, 53, 59, 56, 1, 51, 47, 57, 43, 56,\n", 223 | " 63, 6, 1, 47, 57, 1, 39, 57, 1, 39, 52, 0, 47, 52, 60, 43, 52, 58,\n", 224 | " 53, 56, 63, 1, 58, 53, 1, 54, 39, 56, 58, 47, 41, 59, 50, 39, 56, 47,\n", 225 | " 57, 43, 1, 58, 46, 43, 47, 56, 1, 39, 40, 59, 52, 42, 39, 52, 41, 43,\n", 226 | " 11, 1, 53, 59, 56, 0, 57, 59, 44, 44, 43, 56, 39, 52, 41, 43, 1, 47,\n", 227 | " 57, 1, 39, 1, 45, 39, 47, 52, 1, 58, 53, 1, 58, 46, 43, 51, 1, 24,\n", 228 | " 43, 58, 1, 59, 57, 1, 56, 43, 60, 43, 52, 45, 43, 1, 58, 46, 47, 57,\n", 229 | " 1, 61, 47, 58, 46, 0, 53, 59, 56, 1, 54, 47, 49, 43, 57, 6, 1, 43,\n", 230 | " 56, 43, 1, 61, 43, 1, 40, 43, 41, 53, 51, 43, 1, 56, 39, 49, 43, 57,\n", 231 | " 10, 1, 44, 53, 56, 1, 58, 46, 43, 1, 45, 53, 42, 57, 1, 49, 52, 53,\n", 232 | " 61, 1, 21, 0, 57, 54, 43, 39, 49, 1, 58, 46, 47, 57, 1, 
47, 52, 1,\n", 233 | " 46, 59, 52, 45, 43, 56, 1, 44, 53, 56, 1, 40, 56, 43, 39, 42, 6, 1,\n", 234 | " 52, 53, 58, 1, 47, 52, 1, 58, 46, 47, 56, 57, 58, 1, 44, 53, 56, 1,\n", 235 | " 56, 43, 60, 43, 52, 45, 43, 8, 0, 0])\n" 236 | ] 237 | } 238 | ], 239 | "source": [ 240 | "# let's now encode the entire text dataset and store it into a torch.Tensor\n", 241 | "import torch # we use PyTorch: https://pytorch.org\n", 242 | "data = torch.tensor(encode(text), dtype=torch.long)\n", 243 | "print(data.shape, data.dtype)\n", 244 | "print(data[:1000]) # the 1000 characters we looked at earier will to the GPT look like this" 245 | ] 246 | }, 247 | { 248 | "cell_type": "code", 249 | "execution_count": 11, 250 | "metadata": {}, 251 | "outputs": [], 252 | "source": [ 253 | "# Let's now split up the data into train and validation sets\n", 254 | "n = int(0.9*len(data)) # first 90% will be train, rest val\n", 255 | "train_data = data[:n]\n", 256 | "val_data = data[n:]" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": 12, 262 | "metadata": {}, 263 | "outputs": [ 264 | { 265 | "data": { 266 | "text/plain": [ 267 | "tensor([18, 47, 56, 57, 58, 1, 15, 47, 58])" 268 | ] 269 | }, 270 | "execution_count": 12, 271 | "metadata": {}, 272 | "output_type": "execute_result" 273 | } 274 | ], 275 | "source": [ 276 | "block_size = 8\n", 277 | "train_data[:block_size+1]" 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": 13, 283 | "metadata": {}, 284 | "outputs": [ 285 | { 286 | "name": "stdout", 287 | "output_type": "stream", 288 | "text": [ 289 | "when input is tensor([18]) the target: 47\n", 290 | "when input is tensor([18, 47]) the target: 56\n", 291 | "when input is tensor([18, 47, 56]) the target: 57\n", 292 | "when input is tensor([18, 47, 56, 57]) the target: 58\n", 293 | "when input is tensor([18, 47, 56, 57, 58]) the target: 1\n", 294 | "when input is tensor([18, 47, 56, 57, 58, 1]) the target: 15\n", 295 | "when input is tensor([18, 47, 56, 57, 58, 1, 15]) the target: 47\n", 296 | "when input is tensor([18, 47, 56, 57, 58, 1, 15, 47]) the target: 58\n" 297 | ] 298 | } 299 | ], 300 | "source": [ 301 | "x = train_data[:block_size]\n", 302 | "y = train_data[1:block_size+1]\n", 303 | "for t in range(block_size):\n", 304 | " context = x[:t+1]\n", 305 | " target = y[t]\n", 306 | " print(f\"when input is {context} the target: {target}\")" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": 14, 312 | "metadata": {}, 313 | "outputs": [ 314 | { 315 | "name": "stdout", 316 | "output_type": "stream", 317 | "text": [ 318 | "inputs:\n", 319 | "torch.Size([4, 8])\n", 320 | "tensor([[24, 43, 58, 5, 57, 1, 46, 43],\n", 321 | " [44, 53, 56, 1, 58, 46, 39, 58],\n", 322 | " [52, 58, 1, 58, 46, 39, 58, 1],\n", 323 | " [25, 17, 27, 10, 0, 21, 1, 54]])\n", 324 | "targets:\n", 325 | "torch.Size([4, 8])\n", 326 | "tensor([[43, 58, 5, 57, 1, 46, 43, 39],\n", 327 | " [53, 56, 1, 58, 46, 39, 58, 1],\n", 328 | " [58, 1, 58, 46, 39, 58, 1, 46],\n", 329 | " [17, 27, 10, 0, 21, 1, 54, 39]])\n", 330 | "----\n", 331 | "when input is [24] the target: 43\n", 332 | "when input is [24, 43] the target: 58\n", 333 | "when input is [24, 43, 58] the target: 5\n", 334 | "when input is [24, 43, 58, 5] the target: 57\n", 335 | "when input is [24, 43, 58, 5, 57] the target: 1\n", 336 | "when input is [24, 43, 58, 5, 57, 1] the target: 46\n", 337 | "when input is [24, 43, 58, 5, 57, 1, 46] the target: 43\n", 338 | "when input is [24, 43, 58, 5, 57, 1, 46, 43] the target: 39\n", 339 
| "when input is [44] the target: 53\n", 340 | "when input is [44, 53] the target: 56\n", 341 | "when input is [44, 53, 56] the target: 1\n", 342 | "when input is [44, 53, 56, 1] the target: 58\n", 343 | "when input is [44, 53, 56, 1, 58] the target: 46\n", 344 | "when input is [44, 53, 56, 1, 58, 46] the target: 39\n", 345 | "when input is [44, 53, 56, 1, 58, 46, 39] the target: 58\n", 346 | "when input is [44, 53, 56, 1, 58, 46, 39, 58] the target: 1\n", 347 | "when input is [52] the target: 58\n", 348 | "when input is [52, 58] the target: 1\n", 349 | "when input is [52, 58, 1] the target: 58\n", 350 | "when input is [52, 58, 1, 58] the target: 46\n", 351 | "when input is [52, 58, 1, 58, 46] the target: 39\n", 352 | "when input is [52, 58, 1, 58, 46, 39] the target: 58\n", 353 | "when input is [52, 58, 1, 58, 46, 39, 58] the target: 1\n", 354 | "when input is [52, 58, 1, 58, 46, 39, 58, 1] the target: 46\n", 355 | "when input is [25] the target: 17\n", 356 | "when input is [25, 17] the target: 27\n", 357 | "when input is [25, 17, 27] the target: 10\n", 358 | "when input is [25, 17, 27, 10] the target: 0\n", 359 | "when input is [25, 17, 27, 10, 0] the target: 21\n", 360 | "when input is [25, 17, 27, 10, 0, 21] the target: 1\n", 361 | "when input is [25, 17, 27, 10, 0, 21, 1] the target: 54\n", 362 | "when input is [25, 17, 27, 10, 0, 21, 1, 54] the target: 39\n" 363 | ] 364 | } 365 | ], 366 | "source": [ 367 | "torch.manual_seed(1337)\n", 368 | "batch_size = 4 # how many independent sequences will we process in parallel?\n", 369 | "block_size = 8 # what is the maximum context length for predictions?\n", 370 | "\n", 371 | "def get_batch(split):\n", 372 | " # generate a small batch of data of inputs x and targets y\n", 373 | " data = train_data if split == 'train' else val_data\n", 374 | " ix = torch.randint(len(data) - block_size, (batch_size,))\n", 375 | " x = torch.stack([data[i:i+block_size] for i in ix])\n", 376 | " y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n", 377 | " return x, y\n", 378 | "\n", 379 | "xb, yb = get_batch('train')\n", 380 | "print('inputs:')\n", 381 | "print(xb.shape)\n", 382 | "print(xb)\n", 383 | "print('targets:')\n", 384 | "print(yb.shape)\n", 385 | "print(yb)\n", 386 | "\n", 387 | "print('----')\n", 388 | "\n", 389 | "for b in range(batch_size): # batch dimension\n", 390 | " for t in range(block_size): # time dimension\n", 391 | " context = xb[b, :t+1]\n", 392 | " target = yb[b,t]\n", 393 | " print(f\"when input is {context.tolist()} the target: {target}\")" 394 | ] 395 | }, 396 | { 397 | "cell_type": "code", 398 | "execution_count": 15, 399 | "metadata": {}, 400 | "outputs": [ 401 | { 402 | "name": "stdout", 403 | "output_type": "stream", 404 | "text": [ 405 | "tensor([[24, 43, 58, 5, 57, 1, 46, 43],\n", 406 | " [44, 53, 56, 1, 58, 46, 39, 58],\n", 407 | " [52, 58, 1, 58, 46, 39, 58, 1],\n", 408 | " [25, 17, 27, 10, 0, 21, 1, 54]])\n" 409 | ] 410 | } 411 | ], 412 | "source": [ 413 | "print(xb) # our input to the transformer" 414 | ] 415 | }, 416 | { 417 | "cell_type": "code", 418 | "execution_count": 16, 419 | "metadata": {}, 420 | "outputs": [ 421 | { 422 | "name": "stdout", 423 | "output_type": "stream", 424 | "text": [ 425 | "torch.Size([32, 65])\n", 426 | "tensor(4.8786, grad_fn=)\n", 427 | "\n", 428 | "Sr?qP-QWktXoL&jLDJgOLVz'RIoDqHdhsV&vLLxatjscMpwLERSPyao.qfzs$Ys$zF-w,;eEkzxjgCKFChs!iWW.ObzDnxA Ms$3\n" 429 | ] 430 | } 431 | ], 432 | "source": [ 433 | "import torch\n", 434 | "import torch.nn as nn\n", 435 | "from torch.nn import functional as 
F\n", 436 | "torch.manual_seed(1337)\n", 437 | "\n", 438 | "class BigramLanguageModel(nn.Module):\n", 439 | "\n", 440 | " def __init__(self, vocab_size):\n", 441 | " super().__init__()\n", 442 | " # each token directly reads off the logits for the next token from a lookup table\n", 443 | " self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)\n", 444 | "\n", 445 | " def forward(self, idx, targets=None):\n", 446 | "\n", 447 | " # idx and targets are both (B,T) tensor of integers\n", 448 | " logits = self.token_embedding_table(idx) # (B,T,C)\n", 449 | " \n", 450 | " if targets is None:\n", 451 | " loss = None\n", 452 | " else:\n", 453 | " B, T, C = logits.shape\n", 454 | " logits = logits.view(B*T, C)\n", 455 | " targets = targets.view(B*T)\n", 456 | " loss = F.cross_entropy(logits, targets)\n", 457 | "\n", 458 | " return logits, loss\n", 459 | " \n", 460 | " def generate(self, idx, max_new_tokens):\n", 461 | " # idx is (B, T) array of indices in the current context\n", 462 | " for _ in range(max_new_tokens):\n", 463 | " # get the predictions\n", 464 | " logits, loss = self(idx)\n", 465 | " # focus only on the last time step\n", 466 | " logits = logits[:, -1, :] # becomes (B, C)\n", 467 | " # apply softmax to get probabilities\n", 468 | " probs = F.softmax(logits, dim=-1) # (B, C)\n", 469 | " # sample from the distribution\n", 470 | " idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n", 471 | " # append sampled index to the running sequence\n", 472 | " idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n", 473 | " return idx\n", 474 | "\n", 475 | "m = BigramLanguageModel(vocab_size)\n", 476 | "logits, loss = m(xb, yb)\n", 477 | "print(logits.shape)\n", 478 | "print(loss)\n", 479 | "\n", 480 | "print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=100)[0].tolist()))" 481 | ] 482 | }, 483 | { 484 | "cell_type": "code", 485 | "execution_count": 17, 486 | "metadata": {}, 487 | "outputs": [], 488 | "source": [ 489 | "# create a PyTorch optimizer\n", 490 | "optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)" 491 | ] 492 | }, 493 | { 494 | "cell_type": "code", 495 | "execution_count": 18, 496 | "metadata": {}, 497 | "outputs": [ 498 | { 499 | "name": "stdout", 500 | "output_type": "stream", 501 | "text": [ 502 | "4.587916374206543\n" 503 | ] 504 | } 505 | ], 506 | "source": [ 507 | "batch_size = 32\n", 508 | "for steps in range(100): # increase number of steps for good results... \n", 509 | " \n", 510 | " # sample a batch of data\n", 511 | " xb, yb = get_batch('train')\n", 512 | "\n", 513 | " # evaluate the loss\n", 514 | " logits, loss = m(xb, yb)\n", 515 | " optimizer.zero_grad(set_to_none=True)\n", 516 | " loss.backward()\n", 517 | " optimizer.step()\n", 518 | "\n", 519 | "print(loss.item())" 520 | ] 521 | }, 522 | { 523 | "cell_type": "code", 524 | "execution_count": 21, 525 | "metadata": {}, 526 | "outputs": [ 527 | { 528 | "name": "stdout", 529 | "output_type": "stream", 530 | "text": [ 531 | "\n", 532 | "xhhmZVK ZKbtgnLT C?uE-Ru3$-trd?PxzrVX'q-bQ3!!eDbAF-Wd&urdTlk!agFM?qmbHq?!YCD mzLys:zKRj$.ysTt tTgO'bot$po!z,pDmx;i. 
sCXqCs -ttR.eq-bnkc,:3nA.-too.muQvvxTEeaYCdx-t3:qfkzkH\n", 533 | "FZXrcLbVbHTENvrLFzXdVju$'P-wapG,R cOtXegSPy3Sto;ivv'nZ3QFKpDllY:JM?ujgCJiG,i\n", 534 | "D.Srzo,m?3T VBQBMpx?KDGt\n", 535 | "RqfcW!lVOIirgUqfBD3:CEKM:CYzDWI3E3QHU C-t&DnnQ3nWNa!oCo-o?Y:loKBL!I&K.\n", 536 | "oQ.NmrL:CHZAIF$fkQVzDCJMXhssaH?q,iKDnYIBT:C -jxfqttOHdw'AfgYQE&y3QyxD..yoBH\n", 537 | "V&yZW:x;aFXgddZR&N:GF$'z,BNCoRSP -&vO fw$wJJrqeQ?EZbGz;DnnlyVJkzy!U?a:CR' qHE3f?\n" 538 | ] 539 | } 540 | ], 541 | "source": [ 542 | "print(decode(m.generate(idx = torch.zeros((1, 1), dtype=torch.long), max_new_tokens=500)[0].tolist()))" 543 | ] 544 | }, 545 | { 546 | "attachments": {}, 547 | "cell_type": "markdown", 548 | "metadata": {}, 549 | "source": [ 550 | "## The mathematical trick in self-attention" 551 | ] 552 | }, 553 | { 554 | "cell_type": "code", 555 | "execution_count": 22, 556 | "metadata": {}, 557 | "outputs": [ 558 | { 559 | "name": "stdout", 560 | "output_type": "stream", 561 | "text": [ 562 | "a=\n", 563 | "tensor([[1.0000, 0.0000, 0.0000],\n", 564 | " [0.5000, 0.5000, 0.0000],\n", 565 | " [0.3333, 0.3333, 0.3333]])\n", 566 | "--\n", 567 | "b=\n", 568 | "tensor([[2., 7.],\n", 569 | " [6., 4.],\n", 570 | " [6., 5.]])\n", 571 | "--\n", 572 | "c=\n", 573 | "tensor([[2.0000, 7.0000],\n", 574 | " [4.0000, 5.5000],\n", 575 | " [4.6667, 5.3333]])\n" 576 | ] 577 | } 578 | ], 579 | "source": [ 580 | "# toy example illustrating how matrix multiplication can be used for a \"weighted aggregation\"\n", 581 | "torch.manual_seed(42)\n", 582 | "a = torch.tril(torch.ones(3, 3))\n", 583 | "a = a / torch.sum(a, 1, keepdim=True)\n", 584 | "b = torch.randint(0,10,(3,2)).float()\n", 585 | "c = a @ b\n", 586 | "print('a=')\n", 587 | "print(a)\n", 588 | "print('--')\n", 589 | "print('b=')\n", 590 | "print(b)\n", 591 | "print('--')\n", 592 | "print('c=')\n", 593 | "print(c)" 594 | ] 595 | }, 596 | { 597 | "cell_type": "code", 598 | "execution_count": 23, 599 | "metadata": {}, 600 | "outputs": [ 601 | { 602 | "data": { 603 | "text/plain": [ 604 | "torch.Size([4, 8, 2])" 605 | ] 606 | }, 607 | "execution_count": 23, 608 | "metadata": {}, 609 | "output_type": "execute_result" 610 | } 611 | ], 612 | "source": [ 613 | "# consider the following toy example:\n", 614 | "\n", 615 | "torch.manual_seed(1337)\n", 616 | "B,T,C = 4,8,2 # batch, time, channels\n", 617 | "x = torch.randn(B,T,C)\n", 618 | "x.shape" 619 | ] 620 | }, 621 | { 622 | "cell_type": "code", 623 | "execution_count": 24, 624 | "metadata": {}, 625 | "outputs": [], 626 | "source": [ 627 | "# We want x[b,t] = mean_{i<=t} x[b,i]\n", 628 | "xbow = torch.zeros((B,T,C))\n", 629 | "for b in range(B):\n", 630 | " for t in range(T):\n", 631 | " xprev = x[b,:t+1] # (t,C)\n", 632 | " xbow[b,t] = torch.mean(xprev, 0)" 633 | ] 634 | }, 635 | { 636 | "cell_type": "code", 637 | "execution_count": 26, 638 | "metadata": {}, 639 | "outputs": [ 640 | { 641 | "data": { 642 | "text/plain": [ 643 | "False" 644 | ] 645 | }, 646 | "execution_count": 26, 647 | "metadata": {}, 648 | "output_type": "execute_result" 649 | } 650 | ], 651 | "source": [ 652 | "# version 2: using matrix multiply for a weighted aggregation\n", 653 | "wei = torch.tril(torch.ones(T, T))\n", 654 | "wei = wei / wei.sum(1, keepdim=True)\n", 655 | "xbow2 = wei @ x # (B, T, T) @ (B, T, C) ----> (B, T, C)\n", 656 | "torch.allclose(xbow, xbow2)" 657 | ] 658 | }, 659 | { 660 | "cell_type": "code", 661 | "execution_count": 27, 662 | "metadata": {}, 663 | "outputs": [ 664 | { 665 | "data": { 666 | "text/plain": [ 667 | "False" 668 | ] 669 | }, 670 | 
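xbow2 here and xbow3 in the cell that follows are meant to reproduce xbow exactly, so the False results from torch.allclose are surprising; they are most likely just floating-point rounding from a different matmul kernel, or an x regenerated by re-running cells out of order (note the skipped execution count). If only rounding is at play, a looser tolerance confirms the three versions agree:

torch.allclose(xbow, xbow2, rtol=1e-4, atol=1e-6)  # expected True when the gap is only rounding
torch.allclose(xbow, xbow3, rtol=1e-4, atol=1e-6)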
"execution_count": 27, 671 | "metadata": {}, 672 | "output_type": "execute_result" 673 | } 674 | ], 675 | "source": [ 676 | "# version 3: use Softmax\n", 677 | "tril = torch.tril(torch.ones(T, T))\n", 678 | "wei = torch.zeros((T,T))\n", 679 | "wei = wei.masked_fill(tril == 0, float('-inf'))\n", 680 | "wei = F.softmax(wei, dim=-1)\n", 681 | "xbow3 = wei @ x\n", 682 | "torch.allclose(xbow, xbow3)" 683 | ] 684 | }, 685 | { 686 | "cell_type": "code", 687 | "execution_count": 28, 688 | "metadata": {}, 689 | "outputs": [ 690 | { 691 | "data": { 692 | "text/plain": [ 693 | "torch.Size([4, 8, 16])" 694 | ] 695 | }, 696 | "execution_count": 28, 697 | "metadata": {}, 698 | "output_type": "execute_result" 699 | } 700 | ], 701 | "source": [ 702 | "# version 4: self-attention!\n", 703 | "torch.manual_seed(1337)\n", 704 | "B,T,C = 4,8,32 # batch, time, channels\n", 705 | "x = torch.randn(B,T,C)\n", 706 | "\n", 707 | "# let's see a single Head perform self-attention\n", 708 | "head_size = 16\n", 709 | "key = nn.Linear(C, head_size, bias=False)\n", 710 | "query = nn.Linear(C, head_size, bias=False)\n", 711 | "value = nn.Linear(C, head_size, bias=False)\n", 712 | "k = key(x) # (B, T, 16)\n", 713 | "q = query(x) # (B, T, 16)\n", 714 | "wei = q @ k.transpose(-2, -1) # (B, T, 16) @ (B, 16, T) ---> (B, T, T)\n", 715 | "\n", 716 | "tril = torch.tril(torch.ones(T, T))\n", 717 | "#wei = torch.zeros((T,T))\n", 718 | "wei = wei.masked_fill(tril == 0, float('-inf'))\n", 719 | "wei = F.softmax(wei, dim=-1)\n", 720 | "\n", 721 | "v = value(x)\n", 722 | "out = wei @ v\n", 723 | "#out = wei @ x\n", 724 | "\n", 725 | "out.shape" 726 | ] 727 | }, 728 | { 729 | "cell_type": "code", 730 | "execution_count": 29, 731 | "metadata": {}, 732 | "outputs": [ 733 | { 734 | "data": { 735 | "text/plain": [ 736 | "tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", 737 | " [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", 738 | " [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", 739 | " [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],\n", 740 | " [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],\n", 741 | " [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],\n", 742 | " [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],\n", 743 | " [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],\n", 744 | " grad_fn=)" 745 | ] 746 | }, 747 | "execution_count": 29, 748 | "metadata": {}, 749 | "output_type": "execute_result" 750 | } 751 | ], 752 | "source": [ 753 | "wei[0]" 754 | ] 755 | }, 756 | { 757 | "attachments": {}, 758 | "cell_type": "markdown", 759 | "metadata": {}, 760 | "source": [ 761 | "# Notes:\n", 762 | "\n", 763 | "- Attention is a communication mechanism. Can be seen as nodes in a directed graph looking at each other and aggregating information with a weighted sum from all nodes that point to them, with data-dependent weights.\n", 764 | "- There is no notion of space. Attention simply acts over a set of vectors. This is why we need to positionally encode tokens.\n", 765 | "- Each example across batch dimension is of course processed completely independently and never \"talk\" to each other\n", 766 | "- In an \"encoder\" attention block just delete the single line that does masking with tril, allowing all tokens to communicate. 
This block here is called a \"decoder\" attention block because it has triangular masking, and is usually used in autoregressive settings, like language modeling.\n", 767 | "- \"self-attention\" just means that the keys and values are produced from the same source as queries. In \"cross-attention\", the queries still get produced from x, but the keys and values come from some other, external source (e.g. an encoder module)\n", 768 | "- \"Scaled\" attention additional divides wei by 1/sqrt(head_size). This makes it so when input Q,K are unit variance, wei will be unit variance too and Softmax will stay diffuse and not saturate too much. Illustration below." 769 | ] 770 | }, 771 | { 772 | "cell_type": "code", 773 | "execution_count": 30, 774 | "metadata": {}, 775 | "outputs": [], 776 | "source": [ 777 | "k = torch.randn(B,T,head_size)\n", 778 | "q = torch.randn(B,T,head_size)\n", 779 | "wei = q @ k.transpose(-2, -1) * head_size**-0.5" 780 | ] 781 | }, 782 | { 783 | "cell_type": "code", 784 | "execution_count": 31, 785 | "metadata": {}, 786 | "outputs": [ 787 | { 788 | "data": { 789 | "text/plain": [ 790 | "tensor(1.0449)" 791 | ] 792 | }, 793 | "execution_count": 31, 794 | "metadata": {}, 795 | "output_type": "execute_result" 796 | } 797 | ], 798 | "source": [ 799 | "k.var()" 800 | ] 801 | }, 802 | { 803 | "cell_type": "code", 804 | "execution_count": 32, 805 | "metadata": {}, 806 | "outputs": [ 807 | { 808 | "data": { 809 | "text/plain": [ 810 | "tensor(1.0700)" 811 | ] 812 | }, 813 | "execution_count": 32, 814 | "metadata": {}, 815 | "output_type": "execute_result" 816 | } 817 | ], 818 | "source": [ 819 | "q.var()" 820 | ] 821 | }, 822 | { 823 | "cell_type": "code", 824 | "execution_count": 33, 825 | "metadata": {}, 826 | "outputs": [ 827 | { 828 | "data": { 829 | "text/plain": [ 830 | "tensor(1.0918)" 831 | ] 832 | }, 833 | "execution_count": 33, 834 | "metadata": {}, 835 | "output_type": "execute_result" 836 | } 837 | ], 838 | "source": [ 839 | "wei.var()" 840 | ] 841 | }, 842 | { 843 | "cell_type": "code", 844 | "execution_count": 34, 845 | "metadata": {}, 846 | "outputs": [ 847 | { 848 | "data": { 849 | "text/plain": [ 850 | "tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])" 851 | ] 852 | }, 853 | "execution_count": 34, 854 | "metadata": {}, 855 | "output_type": "execute_result" 856 | } 857 | ], 858 | "source": [ 859 | "torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5]), dim=-1)" 860 | ] 861 | }, 862 | { 863 | "cell_type": "code", 864 | "execution_count": 35, 865 | "metadata": {}, 866 | "outputs": [ 867 | { 868 | "data": { 869 | "text/plain": [ 870 | "tensor([0.0326, 0.0030, 0.1615, 0.0030, 0.8000])" 871 | ] 872 | }, 873 | "execution_count": 35, 874 | "metadata": {}, 875 | "output_type": "execute_result" 876 | } 877 | ], 878 | "source": [ 879 | "torch.softmax(torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])*8, dim=-1) # gets too peaky, converges to one-hot" 880 | ] 881 | }, 882 | { 883 | "cell_type": "code", 884 | "execution_count": 36, 885 | "metadata": {}, 886 | "outputs": [ 887 | { 888 | "data": { 889 | "text/plain": [ 890 | "torch.Size([32, 100])" 891 | ] 892 | }, 893 | "execution_count": 36, 894 | "metadata": {}, 895 | "output_type": "execute_result" 896 | } 897 | ], 898 | "source": [ 899 | "class LayerNorm1d: # (used to be BatchNorm1d)\n", 900 | " \n", 901 | " def __init__(self, dim, eps=1e-5, momentum=0.1):\n", 902 | " self.eps = eps\n", 903 | " self.gamma = torch.ones(dim)\n", 904 | " self.beta = torch.zeros(dim)\n", 905 | " \n", 906 | " def __call__(self, x):\n", 907 | " # 
calculate the forward pass\n", 908 | " xmean = x.mean(1, keepdim=True) # batch mean\n", 909 | " xvar = x.var(1, keepdim=True) # batch variance\n", 910 | " xhat = (x - xmean) / torch.sqrt(xvar + self.eps) # normalize to unit variance\n", 911 | " self.out = self.gamma * xhat + self.beta\n", 912 | " return self.out\n", 913 | " \n", 914 | " def parameters(self):\n", 915 | " return [self.gamma, self.beta]\n", 916 | "\n", 917 | "torch.manual_seed(1337)\n", 918 | "module = LayerNorm1d(100)\n", 919 | "x = torch.randn(32, 100) # batch size 32 of 100-dimensional vectors\n", 920 | "x = module(x)\n", 921 | "x.shape" 922 | ] 923 | }, 924 | { 925 | "cell_type": "code", 926 | "execution_count": 37, 927 | "metadata": {}, 928 | "outputs": [ 929 | { 930 | "data": { 931 | "text/plain": [ 932 | "(tensor(0.1469), tensor(0.8803))" 933 | ] 934 | }, 935 | "execution_count": 37, 936 | "metadata": {}, 937 | "output_type": "execute_result" 938 | } 939 | ], 940 | "source": [ 941 | "x[:,0].mean(), x[:,0].std() # mean,std of one feature across all batch inputs" 942 | ] 943 | }, 944 | { 945 | "cell_type": "code", 946 | "execution_count": 38, 947 | "metadata": {}, 948 | "outputs": [ 949 | { 950 | "data": { 951 | "text/plain": [ 952 | "(tensor(-9.5367e-09), tensor(1.0000))" 953 | ] 954 | }, 955 | "execution_count": 38, 956 | "metadata": {}, 957 | "output_type": "execute_result" 958 | } 959 | ], 960 | "source": [ 961 | "x[0,:].mean(), x[0,:].std() # mean,std of a single input from the batch, of its features" 962 | ] 963 | }, 964 | { 965 | "cell_type": "code", 966 | "execution_count": 39, 967 | "metadata": {}, 968 | "outputs": [], 969 | "source": [ 970 | "# French to English translation example:\n", 971 | "\n", 972 | "# <--------- ENCODE ------------------><--------------- DECODE ----------------->\n", 973 | "# les réseaux de neurones sont géniaux! 
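(Editor's note, not a notebook cell.) A small cross-check against torch.nn.LayerNorm. The only numerical difference from the toy LayerNorm1d above appears to be the variance estimator: torch.Tensor.var defaults to the unbiased estimate, while nn.LayerNorm uses the biased one, so the sketch below passes unbiased=False to match nn.LayerNorm at initialization (gamma = 1, beta = 0):

import torch
import torch.nn as nn

torch.manual_seed(1337)
x = torch.randn(32, 100)                              # batch of 32 examples, 100 features each

ln = nn.LayerNorm(100)                                # normalizes each row over its features
out_ref = ln(x)

mean = x.mean(1, keepdim=True)                        # per-row mean
var = x.var(1, keepdim=True, unbiased=False)          # per-row biased variance
out_manual = (x - mean) / torch.sqrt(var + ln.eps)    # gamma = 1, beta = 0 at init

print(torch.allclose(out_ref, out_manual, atol=1e-6))           # True
print(out_manual[0].mean().item(), out_manual[0].std().item())  # ~0 and ~1 per row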
neural networks are awesome!\n", 974 | "\n" 975 | ] 976 | }, 977 | { 978 | "attachments": {}, 979 | "cell_type": "markdown", 980 | "metadata": {}, 981 | "source": [ 982 | "## Full finished code, for reference" 983 | ] 984 | }, 985 | { 986 | "cell_type": "code", 987 | "execution_count": 40, 988 | "metadata": {}, 989 | "outputs": [ 990 | { 991 | "name": "stdout", 992 | "output_type": "stream", 993 | "text": [ 994 | "0.209729 M parameters\n", 995 | "step 0: train loss 4.4116, val loss 4.4022\n", 996 | "step 100: train loss 2.6568, val loss 2.6670\n", 997 | "step 200: train loss 2.5091, val loss 2.5059\n", 998 | "step 300: train loss 2.4194, val loss 2.4334\n", 999 | "step 400: train loss 2.3503, val loss 2.3564\n", 1000 | "step 500: train loss 2.2966, val loss 2.3131\n", 1001 | "step 600: train loss 2.2408, val loss 2.2500\n", 1002 | "step 700: train loss 2.2052, val loss 2.2195\n", 1003 | "step 800: train loss 2.1638, val loss 2.1870\n", 1004 | "step 900: train loss 2.1247, val loss 2.1510\n", 1005 | "step 1000: train loss 2.1027, val loss 2.1292\n", 1006 | "step 1100: train loss 2.0701, val loss 2.1194\n", 1007 | "step 1200: train loss 2.0387, val loss 2.0796\n", 1008 | "step 1300: train loss 2.0246, val loss 2.0637\n", 1009 | "step 1400: train loss 1.9928, val loss 2.0376\n", 1010 | "step 1500: train loss 1.9693, val loss 2.0292\n", 1011 | "step 1600: train loss 1.9625, val loss 2.0460\n", 1012 | "step 1700: train loss 1.9417, val loss 2.0143\n", 1013 | "step 1800: train loss 1.9099, val loss 1.9980\n", 1014 | "step 1900: train loss 1.9105, val loss 1.9894\n", 1015 | "step 2000: train loss 1.8846, val loss 1.9956\n", 1016 | "step 2100: train loss 1.8725, val loss 1.9751\n", 1017 | "step 2200: train loss 1.8593, val loss 1.9636\n", 1018 | "step 2300: train loss 1.8558, val loss 1.9530\n", 1019 | "step 2400: train loss 1.8410, val loss 1.9449\n", 1020 | "step 2500: train loss 1.8145, val loss 1.9451\n", 1021 | "step 2600: train loss 1.8268, val loss 1.9424\n", 1022 | "step 2700: train loss 1.8105, val loss 1.9341\n", 1023 | "step 2800: train loss 1.8027, val loss 1.9238\n", 1024 | "step 2900: train loss 1.8047, val loss 1.9310\n", 1025 | "step 3000: train loss 1.7924, val loss 1.9198\n", 1026 | "step 3100: train loss 1.7671, val loss 1.9202\n", 1027 | "step 3200: train loss 1.7504, val loss 1.9115\n", 1028 | "step 3300: train loss 1.7594, val loss 1.9104\n", 1029 | "step 3400: train loss 1.7581, val loss 1.9020\n", 1030 | "step 3500: train loss 1.7395, val loss 1.9000\n", 1031 | "step 3600: train loss 1.7243, val loss 1.8912\n", 1032 | "step 3700: train loss 1.7286, val loss 1.8852\n", 1033 | "step 3800: train loss 1.7202, val loss 1.8941\n", 1034 | "step 3900: train loss 1.7196, val loss 1.8786\n", 1035 | "step 4000: train loss 1.7138, val loss 1.8650\n", 1036 | "step 4100: train loss 1.7089, val loss 1.8755\n", 1037 | "step 4200: train loss 1.7114, val loss 1.8695\n", 1038 | "step 4300: train loss 1.7011, val loss 1.8512\n", 1039 | "step 4400: train loss 1.7056, val loss 1.8719\n", 1040 | "step 4500: train loss 1.6865, val loss 1.8489\n", 1041 | "step 4600: train loss 1.6893, val loss 1.8387\n", 1042 | "step 4700: train loss 1.6835, val loss 1.8530\n", 1043 | "step 4800: train loss 1.6672, val loss 1.8503\n", 1044 | "step 4900: train loss 1.6722, val loss 1.8459\n", 1045 | "step 4999: train loss 1.6648, val loss 1.8298\n", 1046 | "\n", 1047 | "ROMEO:\n", 1048 | "But you freign'd I wish his migute:\n", 1049 | "Not duty I usly call fittle of where\n", 1050 | "whilend with their of that 
drungt upon upon of Frienct gliman;\n", 1051 | "This Iell threws you\n", 1052 | "Than prant: he begave: not to by betters,\n", 1053 | "Ot, tow, go: bitt, Ditire spres shall; there some not.\n", 1054 | "\n", 1055 | "LUSIYI his his love.\n", 1056 | "\n", 1057 | "HES OF GAUNTES:\n", 1058 | "And my frocend,\n", 1059 | "And by thou sovoure, I sidfut bace pillade but hith at hithger\n", 1060 | "Ban, brote in that let mit?\n", 1061 | "\n", 1062 | "PUMPEREY:\n", 1063 | "By fond, I newlinds of Henry; why whosess is them,\n", 1064 | "And not wis wan mad him to is forgivins:\n", 1065 | "Egguardsly come,\n", 1066 | "How my best besom staid us, of may,\n", 1067 | "But be us.\n", 1068 | "\n", 1069 | "BUCKINGHABHOM:\n", 1070 | "Away, from them kinds, take head I lay:-\n", 1071 | "Weat he'singforth, even that confer that you\n", 1072 | "Why deaths!\n", 1073 | "And, and the fraginst. I kno-mond Hers!\n", 1074 | "There, fornurnce allank: rompechried Afforry and\n", 1075 | "gate: come I mote it? would grave thou getray,\n", 1076 | "And with a peassed turn of to him.---\n", 1077 | "\n", 1078 | "HENRY My Pert, Romgout:\n", 1079 | "Onfull you, argive it which, you, profrince answeraw in cousan bulists.\n", 1080 | "\n", 1081 | "MENENIUS:\n", 1082 | "We beavehs exeny with.\n", 1083 | "\n", 1084 | "GLOUCENT:\n", 1085 | "I'll me, what tower suil on on that the would patch stile that than war\n", 1086 | "on way as plasusions, what broth tempind countent.\n", 1087 | "\n", 1088 | "PAUULENT:\n", 1089 | "\n", 1090 | "PROMEEY:\n", 1091 | "Vhall herew but they world that not my connor undolier. But, bramest him\n", 1092 | "Shall bety with my now Murst: heart or ants his morst.\n", 1093 | "Your bout to\n", 1094 | "not a rathing if anys a king that woulds no the said how\n", 1095 | "onderding, lords my behed like to be am\n", 1096 | "in That and I him\n", 1097 | "Augnmon's me, or play's, I and by\n", 1098 | "Wonteeth, niery will, we e slord:\n", 1099 | "That get because at that his say\n", 1100 | "a doth hourts goods a montems, and nobe.\n", 1101 | "\n", 1102 | "BUCKINGHUMBY:\n", 1103 | "Which forget to made yoour never.\n", 1104 | "\n", 1105 | "KING RICHARD II:\n", 1106 | "Who too nead?\n", 1107 | "\n", 1108 | "LORD MARCULK:\n", 1109 | "Down,\n", 1110 | "Then thou say mover a desbrick with.\n", 1111 | "Your hurscure. 
Citt\n", 1112 | "\n", 1113 | "FLORIZEL:\n", 1114 | "Rome'st: wears you have crownmy bewut I make\n", 1115 | "Here figh unfeed that wife thou behuse; thou behod us my pothing,\n", 1116 | "You will this morne twas be the flass that high on unportal nmerping that badreful in them,\n", 1117 | "At nocly from it to repot himss,\n", 1118 | "And what you guect in Con to sure;\n", 1119 | "For In before\n" 1120 | ] 1121 | } 1122 | ], 1123 | "source": [ 1124 | "import torch\n", 1125 | "import torch.nn as nn\n", 1126 | "from torch.nn import functional as F\n", 1127 | "\n", 1128 | "# hyperparameters\n", 1129 | "batch_size = 16 # how many independent sequences will we process in parallel?\n", 1130 | "block_size = 32 # what is the maximum context length for predictions?\n", 1131 | "max_iters = 5000\n", 1132 | "eval_interval = 100\n", 1133 | "learning_rate = 1e-3\n", 1134 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", 1135 | "eval_iters = 200\n", 1136 | "n_embd = 64\n", 1137 | "n_head = 4\n", 1138 | "n_layer = 4\n", 1139 | "dropout = 0.0\n", 1140 | "# ------------\n", 1141 | "\n", 1142 | "torch.manual_seed(1337)\n", 1143 | "\n", 1144 | "with open('C:/Users/rocks/OneDrive/Desktop/Projects/My-GPT/input.txt', 'r', encoding='utf-8') as f:\n", 1145 | " text = f.read()\n", 1146 | "\n", 1147 | "# here are all the unique characters that occur in this text\n", 1148 | "chars = sorted(list(set(text)))\n", 1149 | "vocab_size = len(chars)\n", 1150 | "# create a mapping from characters to integers\n", 1151 | "stoi = { ch:i for i,ch in enumerate(chars) }\n", 1152 | "itos = { i:ch for i,ch in enumerate(chars) }\n", 1153 | "encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers\n", 1154 | "decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string\n", 1155 | "\n", 1156 | "# Train and test splits\n", 1157 | "data = torch.tensor(encode(text), dtype=torch.long)\n", 1158 | "n = int(0.9*len(data)) # first 90% will be train, rest val\n", 1159 | "train_data = data[:n]\n", 1160 | "val_data = data[n:]\n", 1161 | "\n", 1162 | "# data loading\n", 1163 | "def get_batch(split):\n", 1164 | " # generate a small batch of data of inputs x and targets y\n", 1165 | " data = train_data if split == 'train' else val_data\n", 1166 | " ix = torch.randint(len(data) - block_size, (batch_size,))\n", 1167 | " x = torch.stack([data[i:i+block_size] for i in ix])\n", 1168 | " y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n", 1169 | " x, y = x.to(device), y.to(device)\n", 1170 | " return x, y\n", 1171 | "\n", 1172 | "@torch.no_grad()\n", 1173 | "def estimate_loss():\n", 1174 | " out = {}\n", 1175 | " model.eval()\n", 1176 | " for split in ['train', 'val']:\n", 1177 | " losses = torch.zeros(eval_iters)\n", 1178 | " for k in range(eval_iters):\n", 1179 | " X, Y = get_batch(split)\n", 1180 | " logits, loss = model(X, Y)\n", 1181 | " losses[k] = loss.item()\n", 1182 | " out[split] = losses.mean()\n", 1183 | " model.train()\n", 1184 | " return out\n", 1185 | "\n", 1186 | "class Head(nn.Module):\n", 1187 | " \"\"\" one head of self-attention \"\"\"\n", 1188 | "\n", 1189 | " def __init__(self, head_size):\n", 1190 | " super().__init__()\n", 1191 | " self.key = nn.Linear(n_embd, head_size, bias=False)\n", 1192 | " self.query = nn.Linear(n_embd, head_size, bias=False)\n", 1193 | " self.value = nn.Linear(n_embd, head_size, bias=False)\n", 1194 | " self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))\n", 1195 | "\n", 1196 | " 
self.dropout = nn.Dropout(dropout)\n", 1197 | "\n", 1198 | " def forward(self, x):\n", 1199 | " B,T,C = x.shape\n", 1200 | " k = self.key(x) # (B,T,C)\n", 1201 | " q = self.query(x) # (B,T,C)\n", 1202 | " # compute attention scores (\"affinities\")\n", 1203 | " wei = q @ k.transpose(-2,-1) * C**-0.5 # (B, T, C) @ (B, C, T) -> (B, T, T)\n", 1204 | " wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)\n", 1205 | " wei = F.softmax(wei, dim=-1) # (B, T, T)\n", 1206 | " wei = self.dropout(wei)\n", 1207 | " # perform the weighted aggregation of the values\n", 1208 | " v = self.value(x) # (B,T,C)\n", 1209 | " out = wei @ v # (B, T, T) @ (B, T, C) -> (B, T, C)\n", 1210 | " return out\n", 1211 | "\n", 1212 | "class MultiHeadAttention(nn.Module):\n", 1213 | " \"\"\" multiple heads of self-attention in parallel \"\"\"\n", 1214 | "\n", 1215 | " def __init__(self, num_heads, head_size):\n", 1216 | " super().__init__()\n", 1217 | " self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])\n", 1218 | " self.proj = nn.Linear(n_embd, n_embd)\n", 1219 | " self.dropout = nn.Dropout(dropout)\n", 1220 | "\n", 1221 | " def forward(self, x):\n", 1222 | " out = torch.cat([h(x) for h in self.heads], dim=-1)\n", 1223 | " out = self.dropout(self.proj(out))\n", 1224 | " return out\n", 1225 | "\n", 1226 | "class FeedFoward(nn.Module):\n", 1227 | " \"\"\" a simple linear layer followed by a non-linearity \"\"\"\n", 1228 | "\n", 1229 | " def __init__(self, n_embd):\n", 1230 | " super().__init__()\n", 1231 | " self.net = nn.Sequential(\n", 1232 | " nn.Linear(n_embd, 4 * n_embd),\n", 1233 | " nn.ReLU(),\n", 1234 | " nn.Linear(4 * n_embd, n_embd),\n", 1235 | " nn.Dropout(dropout),\n", 1236 | " )\n", 1237 | "\n", 1238 | " def forward(self, x):\n", 1239 | " return self.net(x)\n", 1240 | "\n", 1241 | "class Block(nn.Module):\n", 1242 | " \"\"\" Transformer block: communication followed by computation \"\"\"\n", 1243 | "\n", 1244 | " def __init__(self, n_embd, n_head):\n", 1245 | " # n_embd: embedding dimension, n_head: the number of heads we'd like\n", 1246 | " super().__init__()\n", 1247 | " head_size = n_embd // n_head\n", 1248 | " self.sa = MultiHeadAttention(n_head, head_size)\n", 1249 | " self.ffwd = FeedFoward(n_embd)\n", 1250 | " self.ln1 = nn.LayerNorm(n_embd)\n", 1251 | " self.ln2 = nn.LayerNorm(n_embd)\n", 1252 | "\n", 1253 | " def forward(self, x):\n", 1254 | " x = x + self.sa(self.ln1(x))\n", 1255 | " x = x + self.ffwd(self.ln2(x))\n", 1256 | " return x\n", 1257 | "\n", 1258 | "# super simple bigram model\n", 1259 | "class BigramLanguageModel(nn.Module):\n", 1260 | "\n", 1261 | " def __init__(self):\n", 1262 | " super().__init__()\n", 1263 | " # each token directly reads off the logits for the next token from a lookup table\n", 1264 | " self.token_embedding_table = nn.Embedding(vocab_size, n_embd)\n", 1265 | " self.position_embedding_table = nn.Embedding(block_size, n_embd)\n", 1266 | " self.blocks = nn.Sequential(*[Block(n_embd, n_head=n_head) for _ in range(n_layer)])\n", 1267 | " self.ln_f = nn.LayerNorm(n_embd) # final layer norm\n", 1268 | " self.lm_head = nn.Linear(n_embd, vocab_size)\n", 1269 | "\n", 1270 | " def forward(self, idx, targets=None):\n", 1271 | " B, T = idx.shape\n", 1272 | "\n", 1273 | " # idx and targets are both (B,T) tensor of integers\n", 1274 | " tok_emb = self.token_embedding_table(idx) # (B,T,C)\n", 1275 | " pos_emb = self.position_embedding_table(torch.arange(T, device=device)) # (T,C)\n", 1276 | " x = tok_emb + pos_emb # (B,T,C)\n", 1277 | " x = 
self.blocks(x) # (B,T,C)\n", 1278 | " x = self.ln_f(x) # (B,T,C)\n", 1279 | " logits = self.lm_head(x) # (B,T,vocab_size)\n", 1280 | "\n", 1281 | " if targets is None:\n", 1282 | " loss = None\n", 1283 | " else:\n", 1284 | " B, T, C = logits.shape\n", 1285 | " logits = logits.view(B*T, C)\n", 1286 | " targets = targets.view(B*T)\n", 1287 | " loss = F.cross_entropy(logits, targets)\n", 1288 | "\n", 1289 | " return logits, loss\n", 1290 | "\n", 1291 | " def generate(self, idx, max_new_tokens):\n", 1292 | " # idx is (B, T) array of indices in the current context\n", 1293 | " for _ in range(max_new_tokens):\n", 1294 | " # crop idx to the last block_size tokens\n", 1295 | " idx_cond = idx[:, -block_size:]\n", 1296 | " # get the predictions\n", 1297 | " logits, loss = self(idx_cond)\n", 1298 | " # focus only on the last time step\n", 1299 | " logits = logits[:, -1, :] # becomes (B, C)\n", 1300 | " # apply softmax to get probabilities\n", 1301 | " probs = F.softmax(logits, dim=-1) # (B, C)\n", 1302 | " # sample from the distribution\n", 1303 | " idx_next = torch.multinomial(probs, num_samples=1) # (B, 1)\n", 1304 | " # append sampled index to the running sequence\n", 1305 | " idx = torch.cat((idx, idx_next), dim=1) # (B, T+1)\n", 1306 | " return idx\n", 1307 | "\n", 1308 | "model = BigramLanguageModel()\n", 1309 | "m = model.to(device)\n", 1310 | "# print the number of parameters in the model\n", 1311 | "print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')\n", 1312 | "\n", 1313 | "# create a PyTorch optimizer\n", 1314 | "optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)\n", 1315 | "\n", 1316 | "for iter in range(max_iters):\n", 1317 | "\n", 1318 | " # every once in a while evaluate the loss on train and val sets\n", 1319 | " if iter % eval_interval == 0 or iter == max_iters - 1:\n", 1320 | " losses = estimate_loss()\n", 1321 | " print(f\"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}\")\n", 1322 | "\n", 1323 | " # sample a batch of data\n", 1324 | " xb, yb = get_batch('train')\n", 1325 | "\n", 1326 | " # evaluate the loss\n", 1327 | " logits, loss = model(xb, yb)\n", 1328 | " optimizer.zero_grad(set_to_none=True)\n", 1329 | " loss.backward()\n", 1330 | " optimizer.step()\n", 1331 | "\n", 1332 | "# generate from the model\n", 1333 | "context = torch.zeros((1, 1), dtype=torch.long, device=device)\n", 1334 | "print(decode(m.generate(context, max_new_tokens=2000)[0].tolist()))" 1335 | ] 1336 | } 1337 | ], 1338 | "metadata": { 1339 | "kernelspec": { 1340 | "display_name": "Python 3", 1341 | "language": "python", 1342 | "name": "python3" 1343 | }, 1344 | "language_info": { 1345 | "codemirror_mode": { 1346 | "name": "ipython", 1347 | "version": 3 1348 | }, 1349 | "file_extension": ".py", 1350 | "mimetype": "text/x-python", 1351 | "name": "python", 1352 | "nbconvert_exporter": "python", 1353 | "pygments_lexer": "ipython3", 1354 | "version": "3.11.2" 1355 | }, 1356 | "orig_nbformat": 4 1357 | }, 1358 | "nbformat": 4, 1359 | "nbformat_minor": 2 1360 | } 1361 | -------------------------------------------------------------------------------- /more.txt: -------------------------------------------------------------------------------- 1 | 2 | The top in a world by susphoring grace. 
3 | 4 | LUCIO: 5 | We muse hath resistes him so sovere: son't his other wrough 6 | stands of coverent sh'd: he has here, and stand it 7 | and poor exceeder or a Henry's last, stay 8 | not in faith, forewell's base of graves, thanks, happy comparel, 9 | warmentfully: may as face by the courst, that strangth 10 | errise hath breathed. Hastings come to Valenting. 11 | 12 | HERMIONE: 13 | Well have been bolly poor late 14 | Is the lords. 15 | 16 | ABELLA: 17 | Let's found: I will kind him; 18 | I do braw'sy him business wherein far his face. 19 | 20 | LUCENTIO: 21 | He is last afford: make him diseably to London, 22 | Take him great Hastings, boldness in his natic keeps, 23 | To oftragn lost me ready glust through the house. 24 | Why chose that I dares it be a Montague. 25 | 26 | MONTAGUE: 27 | Woe's Claudly Haste of his own at last the Volscient, 28 | And seen'd helpit: bearn to do it be, and most hop, 29 | Miscause's more conterar than without this lambs 30 | Shall down appla fortune flight flowers. 31 | 32 | FRIAR LAUAURENCE: 33 | His son, do your morself, that leaven your honours 34 | Sufferable in more and suffer five. 35 | A horse! High-graced York rights. And bother Montague 36 | That the caapter, that I soughd him; such a chooson 37 | Woes, that they have splight that care 38 | Fades the respect to her spult: betfore him, 39 | Un tell him up hine, or hope, that throw'st thou carry 40 | apied sing with wear over the plenting long stamper 41 | That doth butcherity. For love, what arful was an soldier 42 | That last twain of all and Romeo runly Froth. 43 | 44 | VALHASINA: 45 | Nobleman; go, then both groans to us. 46 | 47 | AUFIDIUS: 48 | O those prepation! 49 | 50 | AUFIDIUS: 51 | It is: ever crimty be a house. 52 | 53 | Second Citizen: 54 | We give heed. 55 | 56 | All Clarence, that makes not know work. The may say speak way. 57 | How is my sorrow to strange on the fares 58 | That which to play some called Margaret 59 | The state town outward's wife, as the foul sleep; 60 | Trickly of from thy blod'sty day blows here, 61 | And pratess that chrospiles stalk falls up, 62 | The world's hollow princhment, which should a bankind, 63 | At till naKaina-daughter tae truth, 64 | Craged from lares an oar that rems' stol-eat with blass. 65 | Those is sometimes well call the Tale, the rod, 66 | Submished his truth; Right states; but for ourselves, 67 | Claud not thy hand, addingness. 68 | No, there all conslent here pue the fault that yokUCHastisful 69 | From servant and folling 'em how that: be drunk, 70 | Set was halt be else, I will betwixt thee three with Tewar: 71 | I am their man before a vile bad amiss'd 72 | And thought have shorn'd the back-flowed of mine, 73 | And ne'er than this, they leave spectiff. 74 | I am to sure, 75 | To maintain on what rash thy dam of suddise! 76 | Thyself thee pays wither edge. 77 | God I am speak to-morrow's like, to me speak, 78 | Am 79 | Dash your deliverance, nitted tongue to study. 80 | But if you were could not love, if you such commands, 81 | Your ignoration lightnifies 82 | Sufficed hath granted a sacret 83 | Divine: minute hath too should be assured, 84 | Unless, heaven to themselvish, as I am, 85 | Hance my father bend to them speak; 86 | His the business' hath themselves; 87 | For his ordance: bow his hand, hell pluck my pet! 88 | What it brace there of his oath? 89 | Rather? Where, whilst thou garling feet Bark? aim? stay; 90 | So if He and him come, and make his mostake 91 | You forbid you had stoopp'd your grace. 
92 | 93 | Servant: 94 | He may once it indeces: 95 | See do it between. 96 | 97 | Provost: 98 | Ah, sir! shall it stay the heavy nights. 99 | 100 | PeRDINA: 101 | Behind-foot, sir; three manner did he remiss 102 | no slain up is disconful: slave you breast-wish. 103 | 104 | HUMIO: 105 | Why more lose on than ime well so fofter townd 106 | you. 107 | 108 | LUCIO: 109 | How find, I must by our son? 110 | 111 | PRINCE EDWARD: 112 | Not I, sir? 113 | 114 | PETHUMIO: 115 | Base my fa-lor; I have ports to guilty: 116 | It string is remorse: seldiers, thou retirest that Titus; 117 | And I will have my close father- place: 118 | I have well kings your husband; he will flow. 119 | 120 | FRIAR: 121 | I am nother for your highness' remain. 122 | 123 | GREMIO: 124 | O, ho! sent me you, mighty lord. 125 | 126 | LEONTES: 127 | Woe's condemn! 128 | 129 | HERMIONE: 130 | It light bo continued. 131 | 132 | LEONTES: 133 | How? most need: 134 | Affections he hath before a knife stay: 135 | Since I can such add not that heard him was? 136 | 137 | LEONTES: 138 | Shall hest this lives. 139 | Not been dead, lord, Hortensio, Catesby thy nature, 140 | Stay beg the myrripg-neck continuiagement? 141 | If the rest, may be save it die. 142 | 143 | LEONTES: 144 | First? 145 | 146 | LADY CALEY: 147 | Peace! give me low. 148 | 149 | LUCIO: 150 | Now you have. 151 | What will the ways? 152 | 153 | TANTITOPHES: 154 | For thee are spects into the actple. 155 | 156 | LEONTES: 157 | O, compassion of city! How say you? 158 | 159 | LADY CAPULET: 160 | Softly I cape her, ahave; 'tis is boldly better 161 | That will speak of if death. There my sweety fault; 162 | adoption in all the reasons; maintise I banish o'er 163 | hath speak too at the object of his noble cousin: 164 | I of all I flatter with a harmy coward confess them. 165 | 166 | DUCHESS OF YORK: 167 | There shades Ourself God, faith, Somerset, 168 | How shall he furl at the sister ward with sight 169 | Which he elseman hath avided in my pale 170 | The spits of minist him of the ixty service; 171 | He beauty with respesses, and though for his rose-houses, 172 | He mpawring in the bench of farther closer, 173 | From to the hearts enstraved to prison! 174 | My purse in his sorrow witchch harded taunts, 175 | And not such sin a pagman's simple and ch 176 | His prisoner fled with sucknes. 177 | These have is dead executions, 178 | That I met to part; do praise him sick, as it 179 | I cannot ded. this you by not good with with-- 180 | Histonions slughte reward of foreEd tide 181 | Ennointments with the whose sallows I would have dream. 182 | 183 | LUCENTIO: 184 | Amen, my tears poor wooding Kence o'er me. 185 | Let me hear you better here. 186 | 187 | LUCENTIO: 188 | But speak best should dissolve that: 189 | Is as four, vantage we must hear it, look go, 190 | Before I then return no strait for the peoplicy, 191 | Before I did court with being to court his life; 192 | And therefore I must revented with Lord Copitol: 193 | Where pass's and loves him stem token afters, 194 | Whom thou lovest him cand lepther than suffer 195 | Woes year with himself ancient, 196 | Even such knees from your hands remembrance, 197 | And like in Paer laid sworn pray shate these comfort 198 | Than spproves I such better fire feg; 199 | And troaks hath the othem told them and Laurence 200 | That hade brankling Henry himself praging. 201 | For no hence, that say that Warwick say you, 202 | Tile unfrom no think; let without drown him oak: 203 | Being no sun; let him in ourself 204 | And say the prison. 
O, God almoster hath the tape, 205 | Trateful hearts, to your kissness shall the king. 206 | If that he, since his in the heap to 207 | Ord your garden pain to your issue with Roman patience; 208 | Urge for highwas fortings. 209 | 210 | MARCIUS: 211 | Spake you would men? 212 | 213 | First Citizen: 214 | This shame's light: it may keThat his dead, 215 | Elight dangly have with his painting 216 | Remain'd out of less, bit his tongue you of their 217 | , better drops, give for's purross, lords, 218 | 'daughted in his king in preputation; 219 | 'Tis not authority, he is coming town. A thou; 220 | For we have suck'd him show a creature to dear 221 | A blawford. 222 | 223 | MENENIUS: 224 | Marry, I am that. 225 | 226 | MARIUS: 227 | Agreepon't with your honours of lawfully 228 | Is planted. The vile field mont, is only now. 229 | 230 | COMINIUS: 231 | Be borning to you, free father. 232 | Please you mistress your will; say with me 233 | But, come, my nobleman: you'll have me all profanted 234 | What I have heard: mark'd not appointing of unhancy. 235 | There is the no: if I all the time and prepared, 236 | To hear you, repair'd all the narged of this 237 | That so fashen'd by you. A Vicutio's wife, 238 | For I am not; Here's wreck'd-for him the fire; 239 | For thou art at Margaret. 240 | 241 | MARCIUS: 242 | A a guard as he which, Englant stitterp's eye. 243 | 244 | First Senator: 245 | Our woman is not my sit much a little; 246 | Without duke's are for sucrelity blies, 247 | As cipaused him, I have earth 248 | Socing of it object, sthat upon our wretcher, 249 | Yet Vaughan, good nigh that wantingal before him. 250 | Come, an is act my grief, and my letter hence, 251 | I am proce are from my set the sad wars; 252 | I doubt not, and consul's day. 253 | 254 | VOMERDIA: 255 | Nay, 256 | It should, my lord, I did, sit be satisfiance 257 | With swinper, I choose you a snot that the man: 258 | I, beside, I would know, rail yours, whensing; 259 | I gave lands, I warried, and there's a man 260 | Than young king's chambers; and that he beats, a 261 | maskit, have I spoke a man passay. 262 | 263 | LEONTES: 264 | His: 265 | son father: I hofver I have have an said 'A' quarrelishmench, 266 | Preacers no power! 267 | 268 | CAPULET: 269 | These men. 270 | 271 | Shepherd: 272 | Meantime hate, holy great off my youngest adverse. 273 | A thum, give me supper, Within; as are not call'd in 274 | the repless thie: and they are warronged. 275 | 276 | SIR STEPHEN SCROOP: 277 | Either's defend the stain; 278 | It is no purpose off that Calanus. 279 | Come to Thurt, to heaven from that bless them. 280 | 281 | LORD BERLAND: 282 | Not passite; yet thou art quic the gols 283 | Hath more upon them. 284 | 285 | GRUMIO: 286 | Though hoart'st far us it. Go you to unfolve, 287 | You make against usurely please on ourselves, 288 | That with make misude isshape: shall perform, 289 | By God's nufiencence love to love with their morn. 290 | 291 | LORD WILLOUGHBY: 292 | The better weeps not strange; neforth mine at I 293 | Can with all: if they should us small as found 294 | The darting sight opposed in to in the war 295 | Unless tenors affection to hisson's infact. 296 | Swear hied our heart ribbors's blood? 297 | 298 | QUEEN: 299 | I do so, and late yourself 300 | So wast I have: bow I have 301 | My way, this fawn brackless the haste, 302 | Commandementance up from his face and war. 303 | 304 | Second Gentleman: 305 | Be bless be dead;'-- 306 | 'Thank'd you not soon to be have given all thomasks. 
307 | 308 | GLOUCESTER: 309 | We'll, not the number; what will you are graved 310 | Till use it was withal proft to your counselsion. 311 | 312 | KING EDWARD IV: 313 | 'Tis near worthy children, for your lordship as flay. 314 | Best, with what made her on this face? 315 | Better leave me in fine ears and heir for; 316 | A brief and old may be and givish thy grave: 317 | To be my life, 318 | And fear thee the woman to these foul duty! 319 | 320 | DUKE OF YORK: 321 | Pardon mayor yourself be longed: 322 | Alague their laws, and given your cut offinger, 323 | Stand all asleep you can ear faults; 324 | Crown thou dear of thee against and up! 325 | This moonty some prepare merry Poor, 326 | Even so tell in them a paragment. If he, it had 327 | The drink easemt sore gentlemen, and 328 | Caius eseements of courtain: the heir sea would 329 | safests those own would so sever. 330 | 331 | HENRY' 332 | 333 | BOLINGBROKE: 334 | This winter's wisdover-- 335 | Here no man despain, their hands: say hath they given 336 | Had none he enforpore himself? or he is , 337 | His hand lived, as they in his deposech'd; 338 | And we make a sistor selmit on once, 339 | The gear of less and royal night but arms 340 | Against away; and why no --------------------------------------------------------------------------------
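(Editor's note.) The printed samples above start generation from a single zero token (torch.zeros((1, 1))); more.txt appears to contain a longer sample of the same kind. As a small extension sketch, not present in the repository code, generation can instead be primed with a prompt. This assumes m, encode, decode and device from the full training script are still in scope, and the prompt string is only an example:

prompt = "ROMEO:"  # hypothetical prompt; must use characters that occur in input.txt
context = torch.tensor([encode(prompt)], dtype=torch.long, device=device)  # shape (1, len(prompt))
out = m.generate(context, max_new_tokens=200)[0].tolist()
print(decode(out))

Because generate already crops the running context to the last block_size tokens, prompts longer than block_size are handled as well.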