├── .gitignore ├── 2.2_tokenizing_text.ipynb ├── 2.5_byte_pair_encoding.ipynb ├── 2.6_data_sampling.ipynb ├── 2.7_token_embeddings.ipynb ├── 3.2_attention_mechanisms.ipynb ├── 3.5_casual_attention.ipynb ├── 3.6_multihead_attention.ipynb ├── 4.1_GPT_architecture.ipynb ├── 4.6_GPT_model.ipynb ├── 5.1_evaluation.ipynb ├── 5.2_training_llms.ipynb ├── 5.3_decoding_strategies.ipynb ├── 5.5_Loading_pretrained_model_from_OpenAI.ipynb ├── 5.5_convert-gpt2-media-from-OpenAI.ipynb ├── 6.GPT_classification_finetuning.ipynb ├── 7.2_Preparing_supervised_instruction_finetuning.ipynb ├── 7.GPT_instruction_finetuning.ipynb ├── README.md ├── ch02 └── the-verdict.txt ├── ch05 ├── download_model.py ├── gpt-to-llama.pdf ├── gpt-to-llama2.ipynb ├── gpt_download.py └── train_plot.png ├── ch06 ├── accuracy-plot.pdf ├── loss-plot.pdf ├── test.csv ├── train.csv └── valid.csv ├── ch07 └── instruction-data.json ├── codes ├── __init__.py ├── configs.py ├── data.py ├── gpt_model.py ├── losses.py ├── model_convert.py ├── plots.py ├── solver.py └── utils.py └── train_llms_from_scratch.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | *.pyc 3 | *.pth 4 | .ipynb_checkpoints/ 5 | #*.csv 6 | gpt2/ 7 | *.zip 8 | sms_spam_collection/ 9 | __pycache__/ -------------------------------------------------------------------------------- /2.2_tokenizing_text.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "83c52fce", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "data": { 11 | "text/plain": [ 12 | "('ch02/the-verdict.txt', )" 13 | ] 14 | }, 15 | "execution_count": 1, 16 | "metadata": {}, 17 | "output_type": "execute_result" 18 | } 19 | ], 20 | "source": [ 21 | "import urllib.request as request\n", 22 | "\n", 23 | "url = 
\"https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch02/01_main-chapter-code/the-verdict.txt\"\n", 24 | "file_path = \"ch02/the-verdict.txt\"\n", 25 | "\n", 26 | "request.urlretrieve(url, file_path)" 27 | ] 28 | }, 29 | { 30 | "cell_type": "code", 31 | "execution_count": 3, 32 | "id": "a1e2d7cd", 33 | "metadata": {}, 34 | "outputs": [ 35 | { 36 | "name": "stdout", 37 | "output_type": "stream", 38 | "text": [ 39 | "Total number of character: 20479\n", 40 | "I HAD always thought Jack Gisburn rather a cheap genius--though a good fellow enough--so it was no \n" 41 | ] 42 | } 43 | ], 44 | "source": [ 45 | "with open(file_path, \"r\", encoding=\"utf-8\") as f:\n", 46 | " raw_data = f.read()\n", 47 | "\n", 48 | "print(\"Total number of character:\", len(raw_data))\n", 49 | "print(raw_data[:99])" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 7, 55 | "id": "bf77377c", 56 | "metadata": {}, 57 | "outputs": [ 58 | { 59 | "name": "stdout", 60 | "output_type": "stream", 61 | "text": [ 62 | "['Hello,', ' ', 'world.', ' ', 'This', ' ', 'is', ' ', 'a', ' ', 'test.']\n" 63 | ] 64 | } 65 | ], 66 | "source": [ 67 | "import re\n", 68 | "text = \"Hello, world. 
This is a test.\"\n", 69 | "result = re.split(r'(\\s)', text)\n", 70 | "print(result)" 71 | ] 72 | }, 73 | { 74 | "cell_type": "code", 75 | "execution_count": 9, 76 | "id": "4c9dff78", 77 | "metadata": {}, 78 | "outputs": [ 79 | { 80 | "name": "stdout", 81 | "output_type": "stream", 82 | "text": [ 83 | "['Hello', ',', '', ' ', 'world', '.', '', ' ', 'This', ' ', 'is', ' ', 'a', ' ', 'test', '.', '']\n" 84 | ] 85 | } 86 | ], 87 | "source": [ 88 | "result = re.split(r\"([,.]|\\s)\", text) # splits on whitespace, commas, and periods\n", 89 | "print(result)" 90 | ] 91 | }, 92 | { 93 | "cell_type": "code", 94 | "execution_count": 10, 95 | "id": "9f0238c7", 96 | "metadata": {}, 97 | "outputs": [ 98 | { 99 | "name": "stdout", 100 | "output_type": "stream", 101 | "text": [ 102 | "['Hello', ',', 'world', '.', 'This', 'is', 'a', 'test', '.']\n" 103 | ] 104 | } 105 | ], 106 | "source": [ 107 | "result = [it for it in result if it.strip()] # remove whitespace charachters\n", 108 | "print(result)" 109 | ] 110 | }, 111 | { 112 | "cell_type": "code", 113 | "execution_count": 11, 114 | "id": "6a1d85ab", 115 | "metadata": {}, 116 | "outputs": [ 117 | { 118 | "name": "stdout", 119 | "output_type": "stream", 120 | "text": [ 121 | "['Hello', ',', 'world', '.', 'Is', 'this', '--', 'a', 'test', '?']\n" 122 | ] 123 | } 124 | ], 125 | "source": [ 126 | "text = \"Hello, world. 
Is this-- a test?\"\n", 127 | "result = re.split(r'([,.:;?_!\"()\\']|--|\\s)', text)\n", 128 | "result = [it.strip() for it in result if it.strip()]\n", 129 | "print(result)" 130 | ] 131 | }, 132 | { 133 | "cell_type": "code", 134 | "execution_count": 12, 135 | "id": "a32ae220", 136 | "metadata": {}, 137 | "outputs": [], 138 | "source": [ 139 | "def tokenizer(text):\n", 140 | " result = re.split(r'([,.:;?_!\"()\\']|--|\\s)', text)\n", 141 | " result = [it.strip() for it in result if it.strip()]\n", 142 | " return result" 143 | ] 144 | }, 145 | { 146 | "cell_type": "code", 147 | "execution_count": 13, 148 | "id": "e4a1d51f", 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "preprocessed = tokenizer(raw_data) # preprocess the whole text data" 153 | ] 154 | }, 155 | { 156 | "cell_type": "code", 157 | "execution_count": 14, 158 | "id": "2f0bad79", 159 | "metadata": {}, 160 | "outputs": [ 161 | { 162 | "name": "stdout", 163 | "output_type": "stream", 164 | "text": [ 165 | "4690 ['I', 'HAD', 'always', 'thought', 'Jack', 'Gisburn', 'rather', 'a', 'cheap', 'genius', '--', 'though', 'a', 'good', 'fellow', 'enough', '--', 'so', 'it', 'was', 'no', 'great', 'surprise', 'to', 'me', 'to', 'hear', 'that', ',', 'in']\n" 166 | ] 167 | } 168 | ], 169 | "source": [ 170 | "print(len(preprocessed), preprocessed[:30])" 171 | ] 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": 15, 176 | "id": "8e2d8f0b", 177 | "metadata": {}, 178 | "outputs": [ 179 | { 180 | "name": "stdout", 181 | "output_type": "stream", 182 | "text": [ 183 | "1130\n" 184 | ] 185 | } 186 | ], 187 | "source": [ 188 | "all_words = sorted(set(preprocessed)) # sorting the unique tokens\n", 189 | "vocab_size = len(all_words)\n", 190 | "print(vocab_size)" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": 17, 196 | "id": "3e04aa62", 197 | "metadata": {}, 198 | "outputs": [ 199 | { 200 | "name": "stdout", 201 | "output_type": "stream", 202 | "text": [ 203 | "! 
0\n", 204 | "\" 1\n", 205 | "' 2\n", 206 | "( 3\n", 207 | ") 4\n", 208 | ", 5\n", 209 | "-- 6\n", 210 | ". 7\n", 211 | ": 8\n", 212 | "; 9\n", 213 | "? 10\n", 214 | "A 11\n", 215 | "Ah 12\n", 216 | "Among 13\n", 217 | "And 14\n", 218 | "Are 15\n", 219 | "Arrt 16\n", 220 | "As 17\n", 221 | "At 18\n", 222 | "Be 19\n", 223 | "Begin 20\n", 224 | "Burlington 21\n", 225 | "But 22\n", 226 | "By 23\n", 227 | "Carlo 24\n", 228 | "Chicago 25\n", 229 | "Claude 26\n", 230 | "Come 27\n", 231 | "Croft 28\n", 232 | "Destroyed 29\n", 233 | "Devonshire 30\n", 234 | "Don 31\n", 235 | "Dubarry 32\n", 236 | "Emperors 33\n", 237 | "Florence 34\n", 238 | "For 35\n", 239 | "Gallery 36\n", 240 | "Gideon 37\n", 241 | "Gisburn 38\n", 242 | "Gisburns 39\n", 243 | "Grafton 40\n", 244 | "Greek 41\n", 245 | "Grindle 42\n", 246 | "Grindles 43\n", 247 | "HAD 44\n", 248 | "Had 45\n", 249 | "Hang 46\n", 250 | "Has 47\n", 251 | "He 48\n", 252 | "Her 49\n", 253 | "Hermia 50\n" 254 | ] 255 | } 256 | ], 257 | "source": [ 258 | "vocab = {token:index for index,token in enumerate(all_words)}\n", 259 | "for i, item in vocab.items():\n", 260 | " print(i, item)\n", 261 | " if item >= 50:\n", 262 | " break" 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": 24, 268 | "id": "6e494d0e", 269 | "metadata": {}, 270 | "outputs": [], 271 | "source": [ 272 | "class SimpleTokenizerV1(object):\n", 273 | " def __init__(self, vocab):\n", 274 | " self.str2int = vocab\n", 275 | " self.int2str = {i:s for s,i in vocab.items()}\n", 276 | " \n", 277 | " def encode(self, text): # string to token ids\n", 278 | " preprocessed = tokenizer(text)\n", 279 | " ids = [self.str2int[s] for s in preprocessed]\n", 280 | " return ids\n", 281 | " \n", 282 | " def decode(self, ids): # token ids to string\n", 283 | " text = \" \".join([self.int2str[i] for i in ids])\n", 284 | " text = re.sub(r'\\s+([,.?!\"()\\'])', r'\\1', text) # remove spaces before the specified punctuation\n", 285 | " return text" 286 | ] 287 | }, 
288 | { 289 | "cell_type": "code", 290 | "execution_count": 25, 291 | "id": "26744966", 292 | "metadata": {}, 293 | "outputs": [], 294 | "source": [ 295 | "my_tokenizer = SimpleTokenizerV1(vocab)" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": 26, 301 | "id": "ca1cc1c7", 302 | "metadata": {}, 303 | "outputs": [ 304 | { 305 | "name": "stdout", 306 | "output_type": "stream", 307 | "text": [ 308 | "[1, 56, 2, 850, 988, 602, 533, 746, 5, 1126, 596, 5, 1, 67, 7, 38, 851, 1108, 754, 793, 7]\n" 309 | ] 310 | } 311 | ], 312 | "source": [ 313 | "text = \"\"\"\"It's the last he painted, you know, \"\n", 314 | " Mrs. Gisburn said with pardonable pride.\"\"\"\n", 315 | "ids = my_tokenizer.encode(text)\n", 316 | "print(ids)" 317 | ] 318 | }, 319 | { 320 | "cell_type": "code", 321 | "execution_count": 27, 322 | "id": "f9b41b32", 323 | "metadata": {}, 324 | "outputs": [ 325 | { 326 | "name": "stdout", 327 | "output_type": "stream", 328 | "text": [ 329 | "\" It' s the last he painted, you know,\" Mrs. 
Gisburn said with pardonable pride.\n" 330 | ] 331 | } 332 | ], 333 | "source": [ 334 | "print(my_tokenizer.decode(ids))" 335 | ] 336 | }, 337 | { 338 | "cell_type": "code", 339 | "execution_count": 32, 340 | "id": "a3ff82a5", 341 | "metadata": {}, 342 | "outputs": [ 343 | { 344 | "name": "stdout", 345 | "output_type": "stream", 346 | "text": [ 347 | "1132\n" 348 | ] 349 | } 350 | ], 351 | "source": [ 352 | "all_words = sorted(set(preprocessed)) # sorting the unique tokens\n", 353 | "all_words.extend([\"<|unk|>\", \"<|endoftext|>\"])\n", 354 | "vocab_size = len(all_words)\n", 355 | "print(vocab_size)" 356 | ] 357 | }, 358 | { 359 | "cell_type": "code", 360 | "execution_count": 33, 361 | "id": "c493a171", 362 | "metadata": {}, 363 | "outputs": [ 364 | { 365 | "name": "stdout", 366 | "output_type": "stream", 367 | "text": [ 368 | "('younger', 1127)\n", 369 | "('your', 1128)\n", 370 | "('yourself', 1129)\n", 371 | "('<|unk|>', 1130)\n", 372 | "('<|endoftext|>', 1131)\n" 373 | ] 374 | } 375 | ], 376 | "source": [ 377 | "vocab = {token:index for index,token in enumerate(all_words)}\n", 378 | "for i, item in list(enumerate(vocab.items()))[-5:]:\n", 379 | " print(item)" 380 | ] 381 | }, 382 | { 383 | "cell_type": "code", 384 | "execution_count": 42, 385 | "id": "48ec21fa", 386 | "metadata": {}, 387 | "outputs": [], 388 | "source": [ 389 | "class SimpleTokenizerV2(object):\n", 390 | " def __init__(self, vocab):\n", 391 | " self.str2int = vocab\n", 392 | " self.int2str = {i:s for s,i in vocab.items()}\n", 393 | " self.unk = \"<|unk|>\"\n", 394 | " self.eof = \"<|endoftext|>\"\n", 395 | " \n", 396 | " def encode(self, text): # string to token ids\n", 397 | " preprocessed = tokenizer(text)\n", 398 | " ids = [self.str2int[s] if s in self.str2int else self.str2int[self.unk] for s in preprocessed]\n", 399 | " return ids\n", 400 | " \n", 401 | " def decode(self, ids): # token ids to string\n", 402 | " text = \" \".join([self.int2str[i] for i in ids])\n", 403 | " text = 
re.sub(r'\\s+([,.?!\"()\\'])', r'\\1', text) # remove spaces before the specified punctuation\n", 404 | " return text\n" 405 | ] 406 | }, 407 | { 408 | "cell_type": "code", 409 | "execution_count": 36, 410 | "id": "e8f4ef96", 411 | "metadata": {}, 412 | "outputs": [ 413 | { 414 | "name": "stdout", 415 | "output_type": "stream", 416 | "text": [ 417 | "Hello, do you like tea? <|endoftext|> In the sunlit terraces of the palace.\n" 418 | ] 419 | } 420 | ], 421 | "source": [ 422 | "text1 = \"Hello, do you like tea?\"\n", 423 | "text2 = \"In the sunlit terraces of the palace.\"\n", 424 | "text = \" <|endoftext|> \".join([text1, text2])\n", 425 | "print(text)" 426 | ] 427 | }, 428 | { 429 | "cell_type": "code", 430 | "execution_count": 43, 431 | "id": "45efcf8f", 432 | "metadata": {}, 433 | "outputs": [], 434 | "source": [ 435 | "my_tokenizer = SimpleTokenizerV2(vocab)" 436 | ] 437 | }, 438 | { 439 | "cell_type": "code", 440 | "execution_count": 44, 441 | "id": "2b6d7f60", 442 | "metadata": {}, 443 | "outputs": [ 444 | { 445 | "name": "stdout", 446 | "output_type": "stream", 447 | "text": [ 448 | "[1130, 5, 355, 1126, 628, 975, 10, 1131, 55, 988, 956, 984, 722, 988, 1130, 7]\n" 449 | ] 450 | } 451 | ], 452 | "source": [ 453 | "ids = my_tokenizer.encode(text)\n", 454 | "print(ids)" 455 | ] 456 | }, 457 | { 458 | "cell_type": "code", 459 | "execution_count": 45, 460 | "id": "6017a670", 461 | "metadata": {}, 462 | "outputs": [ 463 | { 464 | "name": "stdout", 465 | "output_type": "stream", 466 | "text": [ 467 | "<|unk|>, do you like tea? 
<|endoftext|> In the sunlit terraces of the <|unk|>.\n" 468 | ] 469 | } 470 | ], 471 | "source": [ 472 | "print(my_tokenizer.decode(ids))" 473 | ] 474 | }, 475 | { 476 | "cell_type": "code", 477 | "execution_count": null, 478 | "id": "5b6bfff8", 479 | "metadata": {}, 480 | "outputs": [], 481 | "source": [] 482 | } 483 | ], 484 | "metadata": { 485 | "kernelspec": { 486 | "display_name": "Python 3 (ipykernel)", 487 | "language": "python", 488 | "name": "python3" 489 | }, 490 | "language_info": { 491 | "codemirror_mode": { 492 | "name": "ipython", 493 | "version": 3 494 | }, 495 | "file_extension": ".py", 496 | "mimetype": "text/x-python", 497 | "name": "python", 498 | "nbconvert_exporter": "python", 499 | "pygments_lexer": "ipython3", 500 | "version": "3.11.4" 501 | } 502 | }, 503 | "nbformat": 4, 504 | "nbformat_minor": 5 505 | } 506 | -------------------------------------------------------------------------------- /2.5_byte_pair_encoding.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "887ca684", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stdout", 11 | "output_type": "stream", 12 | "text": [ 13 | "0.8.0\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "import tiktoken\n", 19 | "print(tiktoken.__version__)" 20 | ] 21 | }, 22 | { 23 | "cell_type": "code", 24 | "execution_count": 2, 25 | "id": "c69cc34c", 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "tokenizer = tiktoken.get_encoding(\"gpt2\")" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 6, 35 | "id": "261b6182", 36 | "metadata": {}, 37 | "outputs": [], 38 | "source": [ 39 | "text1 = \"Hello, do you like tea?\"\n", 40 | "text2 = \"In the sunlit terraces of the someunkPalace.\"\n", 41 | "text = \" <|endoftext|> \".join([text1, text2])" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": 7, 47 | "id": "72e53baa", 48 | 
"metadata": {}, 49 | "outputs": [ 50 | { 51 | "name": "stdout", 52 | "output_type": "stream", 53 | "text": [ 54 | "[15496, 11, 466, 345, 588, 8887, 30, 220, 50256, 554, 262, 4252, 18250, 8812, 2114, 286, 262, 617, 2954, 11531, 558, 13]\n" 55 | ] 56 | } 57 | ], 58 | "source": [ 59 | "integers = tokenizer.encode(text, allowed_special={\"<|endoftext|>\"})\n", 60 | "print(integers)" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "execution_count": 8, 66 | "id": "696fe9d1", 67 | "metadata": {}, 68 | "outputs": [ 69 | { 70 | "name": "stdout", 71 | "output_type": "stream", 72 | "text": [ 73 | "Hello, do you like tea? <|endoftext|> In the sunlit terraces of the someunkPalace.\n" 74 | ] 75 | } 76 | ], 77 | "source": [ 78 | "print(tokenizer.decode(integers))" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "id": "25de20a5", 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [] 88 | } 89 | ], 90 | "metadata": { 91 | "kernelspec": { 92 | "display_name": "Python 3 (ipykernel)", 93 | "language": "python", 94 | "name": "python3" 95 | }, 96 | "language_info": { 97 | "codemirror_mode": { 98 | "name": "ipython", 99 | "version": 3 100 | }, 101 | "file_extension": ".py", 102 | "mimetype": "text/x-python", 103 | "name": "python", 104 | "nbconvert_exporter": "python", 105 | "pygments_lexer": "ipython3", 106 | "version": "3.9.5" 107 | } 108 | }, 109 | "nbformat": 4, 110 | "nbformat_minor": 5 111 | } 112 | -------------------------------------------------------------------------------- /2.6_data_sampling.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "f566abe9", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stdout", 11 | "output_type": "stream", 12 | "text": [ 13 | "0.8.0\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "import tiktoken\n", 19 | "print(tiktoken.__version__)" 20 | ] 21 | }, 22 | { 23 | 
"cell_type": "code", 24 | "execution_count": 2, 25 | "id": "04b5333e", 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "tokenizer = tiktoken.get_encoding(\"gpt2\")" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": 3, 35 | "id": "909081a4", 36 | "metadata": {}, 37 | "outputs": [ 38 | { 39 | "name": "stdout", 40 | "output_type": "stream", 41 | "text": [ 42 | "5145\n" 43 | ] 44 | } 45 | ], 46 | "source": [ 47 | "with open(\"ch02/the-verdict.txt\", \"r\", encoding=\"utf-8\") as f:\n", 48 | " raw_data = f.read()\n", 49 | " \n", 50 | "enc_text = tokenizer.encode(raw_data)\n", 51 | "print(len(enc_text))" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 5, 57 | "id": "a3e8c483", 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "enc_sample = enc_text[50:]" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": 7, 67 | "id": "78e7d26f", 68 | "metadata": {}, 69 | "outputs": [ 70 | { 71 | "name": "stdout", 72 | "output_type": "stream", 73 | "text": [ 74 | "x: [290, 4920, 2241, 287]\n", 75 | "y: [4920, 2241, 287, 257]\n" 76 | ] 77 | } 78 | ], 79 | "source": [ 80 | "context_size = 4\n", 81 | "x = enc_sample[:context_size]\n", 82 | "y = enc_sample[1:context_size+1]\n", 83 | "\n", 84 | "print(f\"x: {x}\")\n", 85 | "print(f\"y: {y}\")" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 9, 91 | "id": "2bab2e77", 92 | "metadata": {}, 93 | "outputs": [ 94 | { 95 | "name": "stdout", 96 | "output_type": "stream", 97 | "text": [ 98 | "[290] ---> 4920\n", 99 | "[290, 4920] ---> 2241\n", 100 | "[290, 4920, 2241] ---> 287\n", 101 | "[290, 4920, 2241, 287] ---> 257\n" 102 | ] 103 | } 104 | ], 105 | "source": [ 106 | "for i in range(1, context_size + 1):\n", 107 | " context = enc_sample[:i]\n", 108 | " target = enc_sample[i]\n", 109 | " \n", 110 | " print(context, \"--->\", target)" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 10, 116 | "id": 
"d22703f4", 117 | "metadata": {}, 118 | "outputs": [ 119 | { 120 | "name": "stdout", 121 | "output_type": "stream", 122 | "text": [ 123 | " and ---> established\n", 124 | " and established ---> himself\n", 125 | " and established himself ---> in\n", 126 | " and established himself in ---> a\n" 127 | ] 128 | } 129 | ], 130 | "source": [ 131 | "for i in range(1, context_size+1):\n", 132 | " context = enc_sample[:i]\n", 133 | " target = enc_sample[i]\n", 134 | " print(tokenizer.decode(context), \"--->\", tokenizer.decode([target]))" 135 | ] 136 | }, 137 | { 138 | "cell_type": "code", 139 | "execution_count": 2, 140 | "id": "ead4ff4b", 141 | "metadata": {}, 142 | "outputs": [], 143 | "source": [ 144 | "import torch\n", 145 | "from torch.utils.data import Dataset, DataLoader" 146 | ] 147 | }, 148 | { 149 | "cell_type": "code", 150 | "execution_count": 3, 151 | "id": "cb982dfc", 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "class GPTDatasetV2(Dataset):\n", 156 | " def __init__(self, text, tokenizer, max_length, stride):\n", 157 | " self.input_ids = []\n", 158 | " self.target_ids = []\n", 159 | "\n", 160 | " token_ids = tokenizer.encode(text)\n", 161 | " for i in range(0, len(token_ids) - max_length, stride):\n", 162 | " input_chunk = token_ids[i:i+max_length]\n", 163 | " target_chunk = token_ids[i+1:i+max_length+1]\n", 164 | "\n", 165 | " self.input_ids.append(torch.tensor(input_chunk))\n", 166 | " self.target_ids.append(torch.tensor(target_chunk))\n", 167 | "\n", 168 | " def __len__(self):\n", 169 | " return len(self.input_ids)\n", 170 | "\n", 171 | " def __getitem__(self, index):\n", 172 | " return self.input_ids[index], self.target_ids[index]\n" 173 | ] 174 | }, 175 | { 176 | "cell_type": "code", 177 | "execution_count": 4, 178 | "id": "943a9bd3-6694-473a-8005-343228017f80", 179 | "metadata": {}, 180 | "outputs": [], 181 | "source": [ 182 | "def create_dataloader_v1(text, batch_size=4, max_length=256, stride=128, shuffle=True, drop_last=True, 
num_workers=0):\n", 183 | " tokenizer = tiktoken.get_encoding(\"gpt2\")\n", 184 | "\n", 185 | " dataset = GPTDatasetV2(text, tokenizer, max_length, stride)\n", 186 | " dataloader = DataLoader(\n", 187 | " dataset, \n", 188 | " batch_size=batch_size, \n", 189 | " shuffle=shuffle, \n", 190 | " drop_last=drop_last,\n", 191 | " num_workers=num_workers\n", 192 | " )\n", 193 | "\n", 194 | " return dataloader" 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": 7, 200 | "id": "d1162e2f-9ed0-4903-868f-33d9ef5f0565", 201 | "metadata": {}, 202 | "outputs": [ 203 | { 204 | "name": "stdout", 205 | "output_type": "stream", 206 | "text": [ 207 | "[tensor([[ 40, 367, 2885, 1464]]), tensor([[ 367, 2885, 1464, 1807]])]\n" 208 | ] 209 | } 210 | ], 211 | "source": [ 212 | "with open(\"the-verdict.txt\", \"r\", encoding=\"utf-8\") as fin:\n", 213 | " raw_data = fin.read()\n", 214 | "\n", 215 | "dataloader = create_dataloader_v1(raw_data, batch_size=1, max_length=4, stride=1, shuffle=False)\n", 216 | "data_iter = iter(dataloader)\n", 217 | "\n", 218 | "first_batch = next(data_iter)\n", 219 | "\n", 220 | "print(first_batch)" 221 | ] 222 | }, 223 | { 224 | "cell_type": "code", 225 | "execution_count": 8, 226 | "id": "fd631af9-f81e-4031-89db-c0fdfdaf489f", 227 | "metadata": {}, 228 | "outputs": [ 229 | { 230 | "name": "stdout", 231 | "output_type": "stream", 232 | "text": [ 233 | "[tensor([[ 40, 367, 2885, 1464],\n", 234 | " [ 1807, 3619, 402, 271],\n", 235 | " [10899, 2138, 257, 7026],\n", 236 | " [15632, 438, 2016, 257],\n", 237 | " [ 922, 5891, 1576, 438],\n", 238 | " [ 568, 340, 373, 645],\n", 239 | " [ 1049, 5975, 284, 502],\n", 240 | " [ 284, 3285, 326, 11]]), tensor([[ 367, 2885, 1464, 1807],\n", 241 | " [ 3619, 402, 271, 10899],\n", 242 | " [ 2138, 257, 7026, 15632],\n", 243 | " [ 438, 2016, 257, 922],\n", 244 | " [ 5891, 1576, 438, 568],\n", 245 | " [ 340, 373, 645, 1049],\n", 246 | " [ 5975, 284, 502, 284],\n", 247 | " [ 3285, 326, 11, 287]])]\n" 248 | ] 
249 | } 250 | ], 251 | "source": [ 252 | "dataloader = create_dataloader_v1(raw_data, batch_size=8, max_length=4, stride=4, shuffle=False)\n", 253 | "data_iter = iter(dataloader)\n", 254 | "\n", 255 | "first_batch = next(data_iter)\n", 256 | "\n", 257 | "print(first_batch)" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": null, 263 | "id": "ba355b65-2774-4632-8afa-3120c1c5e647", 264 | "metadata": {}, 265 | "outputs": [], 266 | "source": [] 267 | }, 268 | { 269 | "cell_type": "code", 270 | "execution_count": 10, 271 | "id": "76418266-97da-43b4-8689-0ce9fca6f7a0", 272 | "metadata": {}, 273 | "outputs": [], 274 | "source": [ 275 | "import torch.nn as nn\n", 276 | "\n", 277 | "vocab_size = 50257\n", 278 | "embed_dim = 256\n", 279 | "\n", 280 | "vocab = nn.Embedding(vocab_size, embed_dim)" 281 | ] 282 | }, 283 | { 284 | "cell_type": "code", 285 | "execution_count": 11, 286 | "id": "d9604719-8598-49b6-a72c-24ae183f8399", 287 | "metadata": {}, 288 | "outputs": [ 289 | { 290 | "name": "stdout", 291 | "output_type": "stream", 292 | "text": [ 293 | "[tensor([[ 40, 367, 2885, 1464],\n", 294 | " [ 1807, 3619, 402, 271],\n", 295 | " [10899, 2138, 257, 7026],\n", 296 | " [15632, 438, 2016, 257],\n", 297 | " [ 922, 5891, 1576, 438],\n", 298 | " [ 568, 340, 373, 645],\n", 299 | " [ 1049, 5975, 284, 502],\n", 300 | " [ 284, 3285, 326, 11]]), tensor([[ 367, 2885, 1464, 1807],\n", 301 | " [ 3619, 402, 271, 10899],\n", 302 | " [ 2138, 257, 7026, 15632],\n", 303 | " [ 438, 2016, 257, 922],\n", 304 | " [ 5891, 1576, 438, 568],\n", 305 | " [ 340, 373, 645, 1049],\n", 306 | " [ 5975, 284, 502, 284],\n", 307 | " [ 3285, 326, 11, 287]])]\n" 308 | ] 309 | } 310 | ], 311 | "source": [ 312 | "data_loader = create_dataloader_v1(raw_data, batch_size=8, max_length=4, stride=4, shuffle=False)\n", 313 | "data_iter = iter(dataloader)\n", 314 | "\n", 315 | "first_batch = next(data_iter)\n", 316 | "\n", 317 | "print(first_batch)" 318 | ] 319 | }, 320 | { 321 | "cell_type": 
"code", 322 | "execution_count": 12, 323 | "id": "ba30a406-992f-41b7-bdd0-442b3c58d6e4", 324 | "metadata": {}, 325 | "outputs": [ 326 | { 327 | "name": "stdout", 328 | "output_type": "stream", 329 | "text": [ 330 | "input shape torch.Size([8, 4])\n" 331 | ] 332 | } 333 | ], 334 | "source": [ 335 | "inputs, target = first_batch\n", 336 | "print(\"input shape\", inputs.shape)" 337 | ] 338 | }, 339 | { 340 | "cell_type": "code", 341 | "execution_count": 13, 342 | "id": "b4057ab7-04e8-4c87-a1a8-9c35ca2637cc", 343 | "metadata": {}, 344 | "outputs": [ 345 | { 346 | "name": "stdout", 347 | "output_type": "stream", 348 | "text": [ 349 | "torch.Size([8, 4, 256])\n" 350 | ] 351 | } 352 | ], 353 | "source": [ 354 | "token_embedding = vocab(inputs)\n", 355 | "print(token_embedding.shape)" 356 | ] 357 | }, 358 | { 359 | "cell_type": "code", 360 | "execution_count": 14, 361 | "id": "cd9a032e-cf19-411a-9801-0fd147555786", 362 | "metadata": {}, 363 | "outputs": [], 364 | "source": [ 365 | "context_length = 4\n", 366 | "position_embedding_layer = nn.Embedding(context_length, embed_dim)\n", 367 | "position_embeddings = position_embedding_layer(torch.arange(context_length))" 368 | ] 369 | }, 370 | { 371 | "cell_type": "code", 372 | "execution_count": 15, 373 | "id": "fac8eb41-c6c2-4784-b192-03ad6436d794", 374 | "metadata": {}, 375 | "outputs": [ 376 | { 377 | "name": "stdout", 378 | "output_type": "stream", 379 | "text": [ 380 | "torch.Size([4, 256])\n" 381 | ] 382 | } 383 | ], 384 | "source": [ 385 | "print(position_embeddings.shape)" 386 | ] 387 | }, 388 | { 389 | "cell_type": "code", 390 | "execution_count": 16, 391 | "id": "701e1941-63ee-4459-b23f-0add245110bd", 392 | "metadata": {}, 393 | "outputs": [], 394 | "source": [ 395 | "input_embeddings = token_embedding + position_embeddings" 396 | ] 397 | }, 398 | { 399 | "cell_type": "code", 400 | "execution_count": 17, 401 | "id": "d274bb4d-aa78-4373-9c33-81254f70a00d", 402 | "metadata": {}, 403 | "outputs": [ 404 | { 405 | "name": 
"stdout", 406 | "output_type": "stream", 407 | "text": [ 408 | "torch.Size([8, 4, 256])\n" 409 | ] 410 | } 411 | ], 412 | "source": [ 413 | "print(input_embeddings.shape)" 414 | ] 415 | }, 416 | { 417 | "cell_type": "code", 418 | "execution_count": null, 419 | "id": "6ff73744-c9fc-4cf4-809b-a921b649218b", 420 | "metadata": {}, 421 | "outputs": [], 422 | "source": [] 423 | } 424 | ], 425 | "metadata": { 426 | "kernelspec": { 427 | "display_name": "Python 3 (ipykernel)", 428 | "language": "python", 429 | "name": "python3" 430 | }, 431 | "language_info": { 432 | "codemirror_mode": { 433 | "name": "ipython", 434 | "version": 3 435 | }, 436 | "file_extension": ".py", 437 | "mimetype": "text/x-python", 438 | "name": "python", 439 | "nbconvert_exporter": "python", 440 | "pygments_lexer": "ipython3", 441 | "version": "3.11.4" 442 | } 443 | }, 444 | "nbformat": 4, 445 | "nbformat_minor": 5 446 | } 447 | -------------------------------------------------------------------------------- /2.7_token_embeddings.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 4, 6 | "id": "a373b0c4-7a6e-4252-a20d-6f5a0e989113", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch\n", 11 | "import torch.nn as nn" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 5, 17 | "id": "e3d6d72b-1f68-4d5c-831c-477a0198b495", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "input_ids = torch.tensor([2,3,5,1])" 22 | ] 23 | }, 24 | { 25 | "cell_type": "code", 26 | "execution_count": 6, 27 | "id": "29f47ee8-e02d-4eaa-8ecf-9bfe3d63d5f4", 28 | "metadata": {}, 29 | "outputs": [ 30 | { 31 | "name": "stdout", 32 | "output_type": "stream", 33 | "text": [ 34 | "Parameter containing:\n", 35 | "tensor([[ 0.3374, -0.1778, -0.1690],\n", 36 | " [ 0.9178, 1.5810, 1.3010],\n", 37 | " [ 1.2753, -0.2010, -0.1606],\n", 38 | " [-0.4015, 0.9666, -1.1481],\n", 39 
| " [-1.1589, 0.3255, -0.6315],\n", 40 | " [-2.8400, -0.7849, -1.4096]], requires_grad=True)\n" 41 | ] 42 | } 43 | ], 44 | "source": [ 45 | "vocab_size = 6\n", 46 | "embed_dim = 3\n", 47 | "\n", 48 | "torch.manual_seed(123)\n", 49 | "embed_layer = nn.Embedding(vocab_size, embed_dim)\n", 50 | "print(embed_layer.weight)" 51 | ] 52 | }, 53 | { 54 | "cell_type": "code", 55 | "execution_count": 7, 56 | "id": "12326b6c-5f78-4972-923d-4efd54c91592", 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "vocab_size = 50257\n", 61 | "embed_dim = 256\n", 62 | "\n", 63 | "vocab = nn.Embedding(vocab_size, embed_dim)" 64 | ] 65 | }, 66 | { 67 | "cell_type": "code", 68 | "execution_count": 8, 69 | "id": "5f757090-b686-46f5-9f25-30662dc8230f", 70 | "metadata": {}, 71 | "outputs": [], 72 | "source": [ 73 | "with open(\"ch02/the-verdict.txt\", \"r\", encoding=\"utf-8\") as f:\n", 74 | " raw_data = f.read()" 75 | ] 76 | }, 77 | { 78 | "cell_type": "code", 79 | "execution_count": null, 80 | "id": "0e43ec1e-e918-4753-8731-97467dadb36c", 81 | "metadata": {}, 82 | "outputs": [], 83 | "source": [] 84 | } 85 | ], 86 | "metadata": { 87 | "kernelspec": { 88 | "display_name": "Python 3 (ipykernel)", 89 | "language": "python", 90 | "name": "python3" 91 | }, 92 | "language_info": { 93 | "codemirror_mode": { 94 | "name": "ipython", 95 | "version": 3 96 | }, 97 | "file_extension": ".py", 98 | "mimetype": "text/x-python", 99 | "name": "python", 100 | "nbconvert_exporter": "python", 101 | "pygments_lexer": "ipython3", 102 | "version": "3.11.4" 103 | } 104 | }, 105 | "nbformat": 4, 106 | "nbformat_minor": 5 107 | } 108 | -------------------------------------------------------------------------------- /3.2_attention_mechanisms.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "af5bdd9e-71a3-4171-89f1-ad8a6d723ac3", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": 
[ 10 | "import torch\n", 11 | "import torch.nn as nn" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "id": "c8f95b74-2df0-4e44-a96e-9bda03785695", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "inputs = torch.tensor(\n", 22 | " [[0.43, 0.15, 0.89], # Your\n", 23 | " [0.55, 0.87, 0.66], # journey\n", 24 | " [0.57, 0.85, 0.64], # starts\n", 25 | " [0.22, 0.58, 0.33], # with\n", 26 | " [0.77, 0.25, 0.10], # one\n", 27 | " [0.05, 0.80, 0.55] # step\n", 28 | " ])" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 7, 34 | "id": "486dc404-9a68-44a8-bd82-c5cd7f47a620", 35 | "metadata": {}, 36 | "outputs": [ 37 | { 38 | "name": "stdout", 39 | "output_type": "stream", 40 | "text": [ 41 | "tensor([0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605])\n" 42 | ] 43 | } 44 | ], 45 | "source": [ 46 | "query = inputs[2]\n", 47 | "attn_scores = torch.zeros(inputs.shape[0])\n", 48 | "\n", 49 | "for i, x in enumerate(inputs):\n", 50 | " attn_scores[i] = torch.dot(x, query)\n", 51 | "print(attn_scores)" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": 9, 57 | "id": "c5b873ea-40d7-40ae-870f-e2e4d1f5e9b0", 58 | "metadata": {}, 59 | "outputs": [ 60 | { 61 | "name": "stdout", 62 | "output_type": "stream", 63 | "text": [ 64 | "tensor([0.1454, 0.2277, 0.2248, 0.1280, 0.1104, 0.1637])\n", 65 | "tensor(1.)\n" 66 | ] 67 | } 68 | ], 69 | "source": [ 70 | "norm_attn_scores = attn_scores / attn_scores.sum()\n", 71 | "print(norm_attn_scores)\n", 72 | "print(norm_attn_scores.sum())" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": 10, 78 | "id": "f61d8dc1-0ae0-4a0e-a559-f63554e2c398", 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "def softmax_naive(x):\n", 83 | " exp_x = torch.exp(x)\n", 84 | " return exp_x / exp_x.sum(dim=0)" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 11, 90 | "id": "6076b901-4891-4270-8e92-89ebd0427f21", 91 | "metadata": 
{}, 92 | "outputs": [ 93 | { 94 | "name": "stdout", 95 | "output_type": "stream", 96 | "text": [ 97 | "tensor([0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565])\n", 98 | "tensor(1.)\n" 99 | ] 100 | } 101 | ], 102 | "source": [ 103 | "softmax_attn_score = softmax_naive(attn_scores)\n", 104 | "print(softmax_attn_score)\n", 105 | "print(softmax_attn_score.sum())" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 12, 111 | "id": "a94fbed8-9f5e-41fa-a586-0a0a8221217c", 112 | "metadata": {}, 113 | "outputs": [ 114 | { 115 | "name": "stdout", 116 | "output_type": "stream", 117 | "text": [ 118 | "tensor([0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565])\n", 119 | "tensor(1.0000)\n" 120 | ] 121 | } 122 | ], 123 | "source": [ 124 | "softmax_attn_scores = torch.softmax(attn_scores, dim=0)\n", 125 | "print(softmax_attn_scores)\n", 126 | "print(softmax_attn_scores.sum())" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 13, 132 | "id": "5649fad3-6fb2-4713-92c5-052ef12b13a4", 133 | "metadata": {}, 134 | "outputs": [ 135 | { 136 | "name": "stdout", 137 | "output_type": "stream", 138 | "text": [ 139 | "tensor([0.4431, 0.6496, 0.5671])\n" 140 | ] 141 | } 142 | ], 143 | "source": [ 144 | "context_vec = torch.zeros(inputs[0].shape[0])\n", 145 | "for i, x in enumerate(inputs):\n", 146 | " context_vec += softmax_attn_scores[i] * x\n", 147 | "print(context_vec)" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": 14, 153 | "id": "33f01f00-f27e-420e-9c28-22e53a8c71fc", 154 | "metadata": {}, 155 | "outputs": [ 156 | { 157 | "name": "stdout", 158 | "output_type": "stream", 159 | "text": [ 160 | "tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],\n", 161 | " [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],\n", 162 | " [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],\n", 163 | " [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],\n", 164 | " [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],\n", 165 | " 
[0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])\n" 166 | ] 167 | } 168 | ], 169 | "source": [ 170 | "attn_scores = torch.empty(6, 6)\n", 171 | "for i, x in enumerate(inputs):\n", 172 | " for j, y in enumerate(inputs):\n", 173 | " attn_scores[i,j] = torch.dot(x, y)\n", 174 | "print(attn_scores)" 175 | ] 176 | }, 177 | { 178 | "cell_type": "code", 179 | "execution_count": 15, 180 | "id": "4680b5d0-c952-4e1b-8b78-102e3a78f51f", 181 | "metadata": {}, 182 | "outputs": [ 183 | { 184 | "name": "stdout", 185 | "output_type": "stream", 186 | "text": [ 187 | "tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],\n", 188 | " [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],\n", 189 | " [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],\n", 190 | " [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],\n", 191 | " [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],\n", 192 | " [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])\n" 193 | ] 194 | } 195 | ], 196 | "source": [ 197 | "attn_scores = inputs @ inputs.T\n", 198 | "print(attn_scores)" 199 | ] 200 | }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": 16, 204 | "id": "8f3e4dc8-8c29-4ec5-9b7c-a61dadf59fb0", 205 | "metadata": {}, 206 | "outputs": [ 207 | { 208 | "name": "stdout", 209 | "output_type": "stream", 210 | "text": [ 211 | "tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],\n", 212 | " [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],\n", 213 | " [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],\n", 214 | " [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],\n", 215 | " [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],\n", 216 | " [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])\n" 217 | ] 218 | } 219 | ], 220 | "source": [ 221 | "attn_scores = torch.softmax(attn_scores, dim=-1)\n", 222 | "print(attn_scores)" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": 17, 228 | "id": "1af3e47e-cf94-4e27-a271-61a3775cb589", 229 | "metadata": {}, 230 | "outputs": [ 231 | { 232 | 
"name": "stdout", 233 | "output_type": "stream", 234 | "text": [ 235 | "tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])\n" 236 | ] 237 | } 238 | ], 239 | "source": [ 240 | "print(attn_scores.sum(dim=-1))" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": 18, 246 | "id": "e11334f6-552b-45ad-9455-aea45f7bbd44", 247 | "metadata": {}, 248 | "outputs": [ 249 | { 250 | "name": "stdout", 251 | "output_type": "stream", 252 | "text": [ 253 | "tensor([[0.4421, 0.5931, 0.5790],\n", 254 | " [0.4419, 0.6515, 0.5683],\n", 255 | " [0.4431, 0.6496, 0.5671],\n", 256 | " [0.4304, 0.6298, 0.5510],\n", 257 | " [0.4671, 0.5910, 0.5266],\n", 258 | " [0.4177, 0.6503, 0.5645]])\n" 259 | ] 260 | } 261 | ], 262 | "source": [ 263 | "context_vecs = attn_scores @ inputs\n", 264 | "print(context_vecs)" 265 | ] 266 | }, 267 | { 268 | "cell_type": "code", 269 | "execution_count": 20, 270 | "id": "a5b86436-dfa3-4ba1-a097-fe7b98632c3c", 271 | "metadata": {}, 272 | "outputs": [], 273 | "source": [ 274 | "x_2 = inputs[1]\n", 275 | "\n", 276 | "dim_in, dim_out = inputs.shape[1], 2" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": 22, 282 | "id": "0c651280-763b-4bdd-8859-34b52bc32dbf", 283 | "metadata": {}, 284 | "outputs": [], 285 | "source": [ 286 | "torch.manual_seed(123)\n", 287 | "W_query = nn.Parameter(torch.rand(dim_in, dim_out), requires_grad=False)\n", 288 | "W_key = nn.Parameter(torch.rand(dim_in, dim_out), requires_grad=False)\n", 289 | "W_value = nn.Parameter(torch.rand(dim_in, dim_out), requires_grad=False)" 290 | ] 291 | }, 292 | { 293 | "cell_type": "code", 294 | "execution_count": 23, 295 | "id": "a83fb5a6-bcbd-48b5-97ae-2d8b01a5547d", 296 | "metadata": {}, 297 | "outputs": [], 298 | "source": [ 299 | "query_2 = x_2 @ W_query\n", 300 | "key_2 = x_2 @ W_key\n", 301 | "value_2 = x_2 @ W_value" 302 | ] 303 | }, 304 | { 305 | "cell_type": "code", 306 | "execution_count": 24, 307 | "id": "ffb077b9-caa0-4d3c-b66d-0a062e32a35e", 
308 | "metadata": {}, 309 | "outputs": [ 310 | { 311 | "name": "stdout", 312 | "output_type": "stream", 313 | "text": [ 314 | "tensor([0.4306, 1.4551])\n" 315 | ] 316 | } 317 | ], 318 | "source": [ 319 | "print(query_2)" 320 | ] 321 | }, 322 | { 323 | "cell_type": "code", 324 | "execution_count": 26, 325 | "id": "c3f10d7b-bc74-4a3f-acec-d9aa9c720748", 326 | "metadata": {}, 327 | "outputs": [ 328 | { 329 | "name": "stdout", 330 | "output_type": "stream", 331 | "text": [ 332 | "key.shape: torch.Size([6, 2])\n", 333 | "values.shape: torch.Size([6, 2])\n" 334 | ] 335 | } 336 | ], 337 | "source": [ 338 | "keys = inputs @ W_key\n", 339 | "values = inputs @ W_value\n", 340 | "\n", 341 | "print(\"key.shape: \", keys.shape)\n", 342 | "print(\"values.shape: \", values.shape)" 343 | ] 344 | }, 345 | { 346 | "cell_type": "code", 347 | "execution_count": 27, 348 | "id": "e12a1c6d-e715-4bbc-9292-0008d8c15ca6", 349 | "metadata": {}, 350 | "outputs": [ 351 | { 352 | "name": "stdout", 353 | "output_type": "stream", 354 | "text": [ 355 | "tensor(1.8524)\n" 356 | ] 357 | } 358 | ], 359 | "source": [ 360 | "key_2 = keys[1]\n", 361 | "attn_score_22 = query_2.dot(key_2)\n", 362 | "print(attn_score_22)" 363 | ] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "execution_count": 28, 368 | "id": "f5719dff-a708-4dd7-9f6e-8b62e45c75e5", 369 | "metadata": {}, 370 | "outputs": [ 371 | { 372 | "name": "stdout", 373 | "output_type": "stream", 374 | "text": [ 375 | "tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])\n" 376 | ] 377 | } 378 | ], 379 | "source": [ 380 | "attn_scores_2 = query_2 @ keys.T\n", 381 | "print(attn_scores_2)" 382 | ] 383 | }, 384 | { 385 | "cell_type": "code", 386 | "execution_count": 29, 387 | "id": "b3d6b6af-f832-4b09-9335-7a02a9767d2b", 388 | "metadata": {}, 389 | "outputs": [ 390 | { 391 | "name": "stdout", 392 | "output_type": "stream", 393 | "text": [ 394 | "tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])\n" 395 | ] 396 | } 397 | ], 398 | "source": [ 
399 | "d_k = keys.shape[-1]\n", 400 | "attn_weights_2 = torch.softmax(attn_scores_2 / d_k ** 0.5, dim=-1) # scaled-dot product attention\n", 401 | "print(attn_weights_2)" 402 | ] 403 | }, 404 | { 405 | "cell_type": "code", 406 | "execution_count": 30, 407 | "id": "a0cd2c06-257f-486e-99df-2a2e6a5e5683", 408 | "metadata": {}, 409 | "outputs": [ 410 | { 411 | "name": "stdout", 412 | "output_type": "stream", 413 | "text": [ 414 | "tensor([0.3061, 0.8210])\n" 415 | ] 416 | } 417 | ], 418 | "source": [ 419 | "context_vec_2 = attn_weights_2 @ values\n", 420 | "print(context_vec_2)" 421 | ] 422 | }, 423 | { 424 | "cell_type": "code", 425 | "execution_count": 33, 426 | "id": "ded43f86-feff-4afc-a91e-f39c7f26d1a2", 427 | "metadata": {}, 428 | "outputs": [], 429 | "source": [ 430 | "class SelfAttentionV1(nn.Module):\n", 431 | " def __init__(self, dim_in, dim_out):\n", 432 | " super().__init__()\n", 433 | "\n", 434 | " self.W_query = nn.Parameter(torch.rand(dim_in, dim_out))\n", 435 | " self.W_key = nn.Parameter(torch.rand(dim_in, dim_out))\n", 436 | " self.W_value = nn.Parameter(torch.rand(dim_in, dim_out))\n", 437 | "\n", 438 | " def forward(self, x):\n", 439 | " keys = x @ self.W_key\n", 440 | " queries = x @ self.W_query\n", 441 | " values = x @ self.W_value\n", 442 | "\n", 443 | " attn_scores = queries @ keys.T\n", 444 | " attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim = -1)\n", 445 | " context_vec = attn_weights @ values\n", 446 | " return context_vec" 447 | ] 448 | }, 449 | { 450 | "cell_type": "code", 451 | "execution_count": 34, 452 | "id": "4444f682-5d80-4861-8fdd-d799fb09ab05", 453 | "metadata": {}, 454 | "outputs": [ 455 | { 456 | "name": "stdout", 457 | "output_type": "stream", 458 | "text": [ 459 | "tensor([[0.2996, 0.8053],\n", 460 | " [0.3061, 0.8210],\n", 461 | " [0.3058, 0.8203],\n", 462 | " [0.2948, 0.7939],\n", 463 | " [0.2927, 0.7891],\n", 464 | " [0.2990, 0.8040]], grad_fn=)\n" 465 | ] 466 | } 467 | ], 468 | "source": [ 469 | 
"torch.manual_seed(123)\n", 470 | "self_attn = SelfAttentionV1(dim_in, dim_out)\n", 471 | "print(self_attn(inputs))" 472 | ] 473 | }, 474 | { 475 | "cell_type": "code", 476 | "execution_count": 37, 477 | "id": "db956ecb-7d4b-4274-b454-ce275196c5fd", 478 | "metadata": {}, 479 | "outputs": [], 480 | "source": [ 481 | "class SelfAttentionV2(nn.Module):\n", 482 | " def __init__(self, dim_in, dim_out, qkv_bias=False):\n", 483 | " super().__init__()\n", 484 | "\n", 485 | " self.W_query = nn.Linear(dim_in, dim_out, bias=qkv_bias)\n", 486 | " self.W_key = nn.Linear(dim_in, dim_out, bias=qkv_bias)\n", 487 | " self.W_value = nn.Linear(dim_in, dim_out, bias=qkv_bias)\n", 488 | "\n", 489 | " def forward(self, x):\n", 490 | " queries = self.W_query(x)\n", 491 | " keys = self.W_key(x)\n", 492 | " values = self.W_value(x)\n", 493 | "\n", 494 | " attn_scores = queries @ keys.T\n", 495 | " attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)\n", 496 | " context_vec = attn_weights @ values\n", 497 | " return context_vec" 498 | ] 499 | }, 500 | { 501 | "cell_type": "code", 502 | "execution_count": 39, 503 | "id": "bd85902b-6a99-4581-a819-d0796e7a534f", 504 | "metadata": {}, 505 | "outputs": [ 506 | { 507 | "name": "stdout", 508 | "output_type": "stream", 509 | "text": [ 510 | "tensor([[-0.0739, 0.0713],\n", 511 | " [-0.0748, 0.0703],\n", 512 | " [-0.0749, 0.0702],\n", 513 | " [-0.0760, 0.0685],\n", 514 | " [-0.0763, 0.0679],\n", 515 | " [-0.0754, 0.0693]], grad_fn=)\n" 516 | ] 517 | } 518 | ], 519 | "source": [ 520 | "torch.manual_seed(789)\n", 521 | "self_attn_2 = SelfAttentionV2(dim_in, dim_out)\n", 522 | "print(self_attn_2(inputs))" 523 | ] 524 | }, 525 | { 526 | "cell_type": "code", 527 | "execution_count": 47, 528 | "id": "96f2ba54-acf5-43c2-b0ff-9f023bdd0ce3", 529 | "metadata": {}, 530 | "outputs": [ 531 | { 532 | "name": "stdout", 533 | "output_type": "stream", 534 | "text": [ 535 | "Parameter containing:\n", 536 | "tensor([[ 0.3161, 0.4568, 0.5118],\n", 
537 | " [-0.1683, -0.3379, -0.0918]], requires_grad=True)\n" 538 | ] 539 | } 540 | ], 541 | "source": [ 542 | "print(self_attn_2.W_query.weight)" 543 | ] 544 | }, 545 | { 546 | "cell_type": "code", 547 | "execution_count": 54, 548 | "id": "339cf31f-ef9c-45bc-90c8-076b3f6383d9", 549 | "metadata": {}, 550 | "outputs": [], 551 | "source": [ 552 | "self_attn.W_query.data = self_attn_2.W_query.weight.T\n", 553 | "self_attn.W_key.data = self_attn_2.W_key.weight.T\n", 554 | "self_attn.W_value.data = self_attn_2.W_value.weight.T" 555 | ] 556 | }, 557 | { 558 | "cell_type": "code", 559 | "execution_count": 55, 560 | "id": "9ffb764b-ac12-4ed3-846d-6c4c11328cb2", 561 | "metadata": {}, 562 | "outputs": [ 563 | { 564 | "name": "stdout", 565 | "output_type": "stream", 566 | "text": [ 567 | "tensor([[-0.0739, 0.0713],\n", 568 | " [-0.0748, 0.0703],\n", 569 | " [-0.0749, 0.0702],\n", 570 | " [-0.0760, 0.0685],\n", 571 | " [-0.0763, 0.0679],\n", 572 | " [-0.0754, 0.0693]], grad_fn=)\n", 573 | "tensor([[-0.0739, 0.0713],\n", 574 | " [-0.0748, 0.0703],\n", 575 | " [-0.0749, 0.0702],\n", 576 | " [-0.0760, 0.0685],\n", 577 | " [-0.0763, 0.0679],\n", 578 | " [-0.0754, 0.0693]], grad_fn=)\n" 579 | ] 580 | } 581 | ], 582 | "source": [ 583 | "print(self_attn(inputs))\n", 584 | "print(self_attn_2(inputs))" 585 | ] 586 | }, 587 | { 588 | "cell_type": "code", 589 | "execution_count": 56, 590 | "id": "68e13a42-66c6-4712-90aa-91ea2900eed1", 591 | "metadata": {}, 592 | "outputs": [ 593 | { 594 | "name": "stdout", 595 | "output_type": "stream", 596 | "text": [ 597 | "tensor([[ 0.3161, -0.1683],\n", 598 | " [ 0.4568, -0.3379],\n", 599 | " [ 0.5118, -0.0918]]) tensor([[ 0.3161, -0.1683],\n", 600 | " [ 0.4568, -0.3379],\n", 601 | " [ 0.5118, -0.0918]], grad_fn=)\n", 602 | "tensor([[ 0.4058, 0.2134],\n", 603 | " [-0.4704, -0.2601],\n", 604 | " [ 0.2368, -0.5105]]) tensor([[ 0.4058, 0.2134],\n", 605 | " [-0.4704, -0.2601],\n", 606 | " [ 0.2368, -0.5105]], grad_fn=)\n", 607 | "tensor([[ 0.2526, 
0.5191],\n", 608 | " [-0.1415, -0.0852],\n", 609 | " [-0.1962, -0.2043]]) tensor([[ 0.2526, 0.5191],\n", 610 | " [-0.1415, -0.0852],\n", 611 | " [-0.1962, -0.2043]], grad_fn=)\n" 612 | ] 613 | } 614 | ], 615 | "source": [ 616 | "print(self_attn.W_query.data, self_attn_2.W_query.weight.T)\n", 617 | "print(self_attn.W_key.data, self_attn_2.W_key.weight.T)\n", 618 | "print(self_attn.W_value.data, self_attn_2.W_value.weight.T)" 619 | ] 620 | }, 621 | { 622 | "cell_type": "code", 623 | "execution_count": 57, 624 | "id": "991cd48f-80fb-420d-bf6f-a14c32bece2e", 625 | "metadata": {}, 626 | "outputs": [ 627 | { 628 | "name": "stdout", 629 | "output_type": "stream", 630 | "text": [ 631 | "tensor([[-0.0739, 0.0713],\n", 632 | " [-0.0748, 0.0703],\n", 633 | " [-0.0749, 0.0702],\n", 634 | " [-0.0760, 0.0685],\n", 635 | " [-0.0763, 0.0679],\n", 636 | " [-0.0754, 0.0693]], grad_fn=)\n", 637 | "tensor([[-0.0739, 0.0713],\n", 638 | " [-0.0748, 0.0703],\n", 639 | " [-0.0749, 0.0702],\n", 640 | " [-0.0760, 0.0685],\n", 641 | " [-0.0763, 0.0679],\n", 642 | " [-0.0754, 0.0693]], grad_fn=)\n" 643 | ] 644 | } 645 | ], 646 | "source": [ 647 | "print(self_attn(inputs))\n", 648 | "print(self_attn_2(inputs))" 649 | ] 650 | }, 651 | { 652 | "cell_type": "code", 653 | "execution_count": null, 654 | "id": "905f4917-259a-4952-8ba1-1a1eb42fcc35", 655 | "metadata": {}, 656 | "outputs": [], 657 | "source": [] 658 | } 659 | ], 660 | "metadata": { 661 | "kernelspec": { 662 | "display_name": "Python 3 (ipykernel)", 663 | "language": "python", 664 | "name": "python3" 665 | }, 666 | "language_info": { 667 | "codemirror_mode": { 668 | "name": "ipython", 669 | "version": 3 670 | }, 671 | "file_extension": ".py", 672 | "mimetype": "text/x-python", 673 | "name": "python", 674 | "nbconvert_exporter": "python", 675 | "pygments_lexer": "ipython3", 676 | "version": "3.11.4" 677 | } 678 | }, 679 | "nbformat": 4, 680 | "nbformat_minor": 5 681 | } 682 | 
-------------------------------------------------------------------------------- /3.5_casual_attention.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "dccf39be-a8a3-4908-9d2e-efb228b1df96", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch\n", 11 | "import torch.nn as nn" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "id": "a7c133a8-3fbb-4f11-b69e-257d7a6d6e2b", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "inputs = torch.tensor(\n", 22 | " [[0.43, 0.15, 0.89], # Your\n", 23 | " [0.55, 0.87, 0.66], # journey\n", 24 | " [0.57, 0.85, 0.64], # starts\n", 25 | " [0.22, 0.58, 0.33], # with\n", 26 | " [0.77, 0.25, 0.10], # one\n", 27 | " [0.05, 0.80, 0.55] # step\n", 28 | " ])" 29 | ] 30 | }, 31 | { 32 | "cell_type": "code", 33 | "execution_count": 3, 34 | "id": "f8c8ef91-bd07-4ccf-a0bd-67e556705061", 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "class SelfAttentionV2(nn.Module):\n", 39 | " def __init__(self, dim_in, dim_out, qkv_bias=False):\n", 40 | " super().__init__()\n", 41 | "\n", 42 | " self.W_query = nn.Linear(dim_in, dim_out, bias=qkv_bias)\n", 43 | " self.W_key = nn.Linear(dim_in, dim_out, bias=qkv_bias)\n", 44 | " self.W_value = nn.Linear(dim_in, dim_out, bias=qkv_bias)\n", 45 | "\n", 46 | " def forward(self, x):\n", 47 | " queries = self.W_query(x)\n", 48 | " keys = self.W_key(x)\n", 49 | " values = self.W_value(x)\n", 50 | "\n", 51 | " attn_scores = queries @ keys.T\n", 52 | " attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)\n", 53 | " context_vec = attn_weights @ values\n", 54 | " return context_vec" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": 4, 60 | "id": "74d3fe20-3cc9-40c2-8c26-d7b3ff66be28", 61 | "metadata": {}, 62 | "outputs": [], 63 | "source": [ 64 | "dim_in, dim_out = 
inputs.shape[-1], 2" 65 | ] 66 | }, 67 | { 68 | "cell_type": "code", 69 | "execution_count": 8, 70 | "id": "56b278d9-c5f7-42d4-9ea6-a8889613713d", 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "torch.manual_seed(789)\n", 75 | "sa_v2 = SelfAttentionV2(dim_in, dim_out)" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 9, 81 | "id": "c756ae74-6918-44f1-a267-8fb2a728212e", 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "queries = sa_v2.W_query(inputs)\n", 86 | "keys = sa_v2.W_key(inputs)\n", 87 | "values = sa_v2.W_value(inputs)" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 10, 93 | "id": "07285229-a029-4d3c-83bf-702a0416c20e", 94 | "metadata": {}, 95 | "outputs": [ 96 | { 97 | "name": "stdout", 98 | "output_type": "stream", 99 | "text": [ 100 | "tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],\n", 101 | " [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],\n", 102 | " [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],\n", 103 | " [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],\n", 104 | " [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],\n", 105 | " [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],\n", 106 | " grad_fn=)\n" 107 | ] 108 | } 109 | ], 110 | "source": [ 111 | "attn_scores = queries @ keys.T\n", 112 | "attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)\n", 113 | "print(attn_weights)" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": 11, 119 | "id": "b4d4bfe7-aa74-44e7-afaf-93d6f7ce69e1", 120 | "metadata": {}, 121 | "outputs": [ 122 | { 123 | "name": "stdout", 124 | "output_type": "stream", 125 | "text": [ 126 | "tensor([[1., 0., 0., 0., 0., 0.],\n", 127 | " [1., 1., 0., 0., 0., 0.],\n", 128 | " [1., 1., 1., 0., 0., 0.],\n", 129 | " [1., 1., 1., 1., 0., 0.],\n", 130 | " [1., 1., 1., 1., 1., 0.],\n", 131 | " [1., 1., 1., 1., 1., 1.]])\n" 132 | ] 133 | } 134 | ], 135 | "source": [ 136 | "context_length = 
attn_scores.shape[0]\n", 137 | "mask_simple = torch.tril(torch.ones(context_length, context_length))\n", 138 | "print(mask_simple)" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 12, 144 | "id": "6819a092-3c9a-41a0-b0bf-13ace5b5d63f", 145 | "metadata": {}, 146 | "outputs": [ 147 | { 148 | "name": "stdout", 149 | "output_type": "stream", 150 | "text": [ 151 | "tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", 152 | " [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],\n", 153 | " [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],\n", 154 | " [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],\n", 155 | " [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],\n", 156 | " [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],\n", 157 | " grad_fn=)\n" 158 | ] 159 | } 160 | ], 161 | "source": [ 162 | "masked_attn_weights = attn_weights * mask_simple\n", 163 | "print(masked_attn_weights)" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 13, 169 | "id": "f8779ccf-408c-4e6f-b8d1-4b96349efc9a", 170 | "metadata": {}, 171 | "outputs": [], 172 | "source": [ 173 | "row_sum = masked_attn_weights.sum(dim=-1, keepdim=True)" 174 | ] 175 | }, 176 | { 177 | "cell_type": "code", 178 | "execution_count": 14, 179 | "id": "62639f99-a3c0-4578-ae3c-506b0aecd7ef", 180 | "metadata": {}, 181 | "outputs": [ 182 | { 183 | "name": "stdout", 184 | "output_type": "stream", 185 | "text": [ 186 | "tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", 187 | " [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],\n", 188 | " [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],\n", 189 | " [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],\n", 190 | " [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],\n", 191 | " [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],\n", 192 | " grad_fn=)\n" 193 | ] 194 | } 195 | ], 196 | "source": [ 197 | "masked_attn_weights_norm = masked_attn_weights / row_sum\n", 198 | "print(masked_attn_weights_norm)" 199 | ] 200 
| }, 201 | { 202 | "cell_type": "code", 203 | "execution_count": 16, 204 | "id": "a3f34a02-6a9b-4e9e-b849-76e77f4e0865", 205 | "metadata": {}, 206 | "outputs": [ 207 | { 208 | "name": "stdout", 209 | "output_type": "stream", 210 | "text": [ 211 | "tensor([[0., 1., 1., 1., 1., 1.],\n", 212 | " [0., 0., 1., 1., 1., 1.],\n", 213 | " [0., 0., 0., 1., 1., 1.],\n", 214 | " [0., 0., 0., 0., 1., 1.],\n", 215 | " [0., 0., 0., 0., 0., 1.],\n", 216 | " [0., 0., 0., 0., 0., 0.]])\n" 217 | ] 218 | } 219 | ], 220 | "source": [ 221 | "mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)\n", 222 | "print(mask)" 223 | ] 224 | }, 225 | { 226 | "cell_type": "code", 227 | "execution_count": 18, 228 | "id": "832cf176-5164-48fc-bfb2-aa56323473d0", 229 | "metadata": {}, 230 | "outputs": [ 231 | { 232 | "name": "stdout", 233 | "output_type": "stream", 234 | "text": [ 235 | "tensor([[0.2899, -inf, -inf, -inf, -inf, -inf],\n", 236 | " [0.4656, 0.1723, -inf, -inf, -inf, -inf],\n", 237 | " [0.4594, 0.1703, 0.1731, -inf, -inf, -inf],\n", 238 | " [0.2642, 0.1024, 0.1036, 0.0186, -inf, -inf],\n", 239 | " [0.2183, 0.0874, 0.0882, 0.0177, 0.0786, -inf],\n", 240 | " [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],\n", 241 | " grad_fn=)\n" 242 | ] 243 | } 244 | ], 245 | "source": [ 246 | "masked = attn_scores.masked_fill(mask.bool(), -torch.inf)\n", 247 | "print(masked)" 248 | ] 249 | }, 250 | { 251 | "cell_type": "code", 252 | "execution_count": 19, 253 | "id": "969b3520-7274-4fd8-8a03-251a6420ec8f", 254 | "metadata": {}, 255 | "outputs": [ 256 | { 257 | "name": "stdout", 258 | "output_type": "stream", 259 | "text": [ 260 | "tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", 261 | " [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],\n", 262 | " [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],\n", 263 | " [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],\n", 264 | " [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],\n", 265 | " [0.1935, 0.1663, 0.1666, 0.1542, 
0.1666, 0.1529]],\n", 266 | " grad_fn=)\n" 267 | ] 268 | } 269 | ], 270 | "source": [ 271 | "masked_attn_weights = torch.softmax(masked / keys.shape[-1] ** 0.5, dim=-1)\n", 272 | "print(masked_attn_weights)" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": 20, 278 | "id": "619fbcb1-6960-4a71-85ed-620e590df905", 279 | "metadata": {}, 280 | "outputs": [ 281 | { 282 | "name": "stdout", 283 | "output_type": "stream", 284 | "text": [ 285 | "tensor([[-0.0872, 0.0286],\n", 286 | " [-0.0991, 0.0501],\n", 287 | " [-0.0999, 0.0633],\n", 288 | " [-0.0983, 0.0489],\n", 289 | " [-0.0514, 0.1098],\n", 290 | " [-0.0754, 0.0693]], grad_fn=)\n" 291 | ] 292 | } 293 | ], 294 | "source": [ 295 | "context_vec = masked_attn_weights @ values\n", 296 | "print(context_vec)" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": 24, 302 | "id": "1e9389f6-706a-45fa-92bb-f1b64c2e6a9d", 303 | "metadata": {}, 304 | "outputs": [ 305 | { 306 | "name": "stdout", 307 | "output_type": "stream", 308 | "text": [ 309 | "tensor([[1., 1., 1., 1., 1., 1.],\n", 310 | " [1., 1., 1., 1., 1., 1.],\n", 311 | " [1., 1., 1., 1., 1., 1.],\n", 312 | " [1., 1., 1., 1., 1., 1.],\n", 313 | " [1., 1., 1., 1., 1., 1.],\n", 314 | " [1., 1., 1., 1., 1., 1.]])\n", 315 | "tensor([[2., 2., 0., 2., 2., 0.],\n", 316 | " [0., 0., 0., 2., 0., 2.],\n", 317 | " [2., 2., 2., 2., 0., 2.],\n", 318 | " [0., 2., 2., 0., 0., 2.],\n", 319 | " [0., 2., 0., 2., 0., 2.],\n", 320 | " [0., 2., 2., 2., 2., 0.]])\n" 321 | ] 322 | } 323 | ], 324 | "source": [ 325 | "torch.manual_seed(123)\n", 326 | "dropout = nn.Dropout(0.5)\n", 327 | "example = torch.ones(6, 6)\n", 328 | "print(example)\n", 329 | "print(dropout(example))" 330 | ] 331 | }, 332 | { 333 | "cell_type": "code", 334 | "execution_count": 27, 335 | "id": "d7fde114-0313-4f92-a6cc-3c379e90510c", 336 | "metadata": {}, 337 | "outputs": [ 338 | { 339 | "name": "stdout", 340 | "output_type": "stream", 341 | "text": [ 342 | "tensor([[1.0000, 
0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", 343 | " [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],\n", 344 | " [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],\n", 345 | " [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],\n", 346 | " [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],\n", 347 | " [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],\n", 348 | " grad_fn=)\n", 349 | "tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", 350 | " [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],\n", 351 | " [0.7599, 0.6194, 0.6206, 0.0000, 0.0000, 0.0000],\n", 352 | " [0.0000, 0.4921, 0.4925, 0.0000, 0.0000, 0.0000],\n", 353 | " [0.0000, 0.3966, 0.0000, 0.3775, 0.0000, 0.0000],\n", 354 | " [0.0000, 0.3327, 0.3331, 0.3084, 0.3331, 0.0000]],\n", 355 | " grad_fn=)\n" 356 | ] 357 | } 358 | ], 359 | "source": [ 360 | "torch.manual_seed(123)\n", 361 | "print(masked_attn_weights)\n", 362 | "print(dropout(masked_attn_weights))" 363 | ] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "execution_count": 28, 368 | "id": "7d11bb23-d9f6-4fe3-a1e1-b17b80cd9af1", 369 | "metadata": {}, 370 | "outputs": [ 371 | { 372 | "name": "stdout", 373 | "output_type": "stream", 374 | "text": [ 375 | "torch.Size([2, 6, 3])\n" 376 | ] 377 | } 378 | ], 379 | "source": [ 380 | "batch = torch.stack((inputs, inputs), dim=0)\n", 381 | "print(batch.shape)" 382 | ] 383 | }, 384 | { 385 | "cell_type": "code", 386 | "execution_count": 29, 387 | "id": "8c2647f0-866a-4fd7-b837-c7bc1e9962e5", 388 | "metadata": {}, 389 | "outputs": [], 390 | "source": [ 391 | "class CasualAttention(nn.Module):\n", 392 | " def __init__(self, \n", 393 | " dim_in, \n", 394 | " dim_out, \n", 395 | " context_len, \n", 396 | " dropout, \n", 397 | " qkv_bias=False):\n", 398 | " super().__init__()\n", 399 | " \n", 400 | " self.d_out = dim_out\n", 401 | " self.W_query = nn.Linear(dim_in, dim_out, bias=qkv_bias)\n", 402 | " self.W_key = nn.Linear(dim_in, dim_out, bias=qkv_bias)\n", 403 | " self.W_value = nn.Linear(dim_in, 
dim_out, bias=qkv_bias)\n", 404 | " self.dropout = nn.Dropout(dropout)\n", 405 | "\n", 406 | " self.register_buffer(\n", 407 | " \"mask\",\n", 408 | " torch.triu(torch.ones(context_len, context_len), diagonal=1)\n", 409 | " )\n", 410 | "\n", 411 | " def forward(self, x):\n", 412 | " b, num_tokens, dim_in = x.size()\n", 413 | " \n", 414 | " keys = self.W_key(x)\n", 415 | " queries = self.W_query(x)\n", 416 | " values = self.W_value(x)\n", 417 | "\n", 418 | " attn_scores = queries @ keys.transpose(1,2)\n", 419 | " attn_scores.masked_fill(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)\n", 420 | "\n", 421 | " attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)\n", 422 | " attn_weights = self.dropout(attn_weights)\n", 423 | "\n", 424 | " context_vecs = attn_weights @ values\n", 425 | " return context_vecs" 426 | ] 427 | }, 428 | { 429 | "cell_type": "code", 430 | "execution_count": 30, 431 | "id": "0780a04e-822b-4618-84b2-df793f738231", 432 | "metadata": {}, 433 | "outputs": [ 434 | { 435 | "name": "stdout", 436 | "output_type": "stream", 437 | "text": [ 438 | "torch.Size([2, 6, 2])\n", 439 | "tensor([[[-0.5337, -0.1051],\n", 440 | " [-0.5323, -0.1080],\n", 441 | " [-0.5323, -0.1079],\n", 442 | " [-0.5297, -0.1076],\n", 443 | " [-0.5311, -0.1066],\n", 444 | " [-0.5299, -0.1081]],\n", 445 | "\n", 446 | " [[-0.5337, -0.1051],\n", 447 | " [-0.5323, -0.1080],\n", 448 | " [-0.5323, -0.1079],\n", 449 | " [-0.5297, -0.1076],\n", 450 | " [-0.5311, -0.1066],\n", 451 | " [-0.5299, -0.1081]]], grad_fn=)\n" 452 | ] 453 | } 454 | ], 455 | "source": [ 456 | "torch.manual_seed(123)\n", 457 | "context_length = batch.shape[1]\n", 458 | "ca = CasualAttention(\n", 459 | " dim_in, \n", 460 | " dim_out,\n", 461 | " context_length,\n", 462 | " 0.0\n", 463 | ")\n", 464 | "\n", 465 | "context_vecs = ca(batch)\n", 466 | "print(context_vecs.shape)\n", 467 | "print(context_vecs)" 468 | ] 469 | }, 470 | { 471 | "cell_type": "code", 472 | "execution_count": null, 473 
| "id": "2e06f64e-69e8-4e89-807a-f5e3721173c2", 474 | "metadata": {}, 475 | "outputs": [], 476 | "source": [] 477 | } 478 | ], 479 | "metadata": { 480 | "kernelspec": { 481 | "display_name": "Python 3 (ipykernel)", 482 | "language": "python", 483 | "name": "python3" 484 | }, 485 | "language_info": { 486 | "codemirror_mode": { 487 | "name": "ipython", 488 | "version": 3 489 | }, 490 | "file_extension": ".py", 491 | "mimetype": "text/x-python", 492 | "name": "python", 493 | "nbconvert_exporter": "python", 494 | "pygments_lexer": "ipython3", 495 | "version": "3.11.4" 496 | } 497 | }, 498 | "nbformat": 4, 499 | "nbformat_minor": 5 500 | } 501 | -------------------------------------------------------------------------------- /3.6_multihead_attention.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "id": "ecf47e86-71c6-4746-ae5c-a72e994ee51a", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch\n", 11 | "import torch.nn as nn" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 3, 17 | "id": "e6fa2153-3907-41ce-878a-706245448c50", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "class CasualAttention(nn.Module):\n", 22 | " def __init__(self, \n", 23 | " dim_in, \n", 24 | " dim_out, \n", 25 | " context_len, \n", 26 | " dropout, \n", 27 | " qkv_bias=False):\n", 28 | " super().__init__()\n", 29 | " \n", 30 | " self.d_out = dim_out\n", 31 | " self.W_query = nn.Linear(dim_in, dim_out, bias=qkv_bias)\n", 32 | " self.W_key = nn.Linear(dim_in, dim_out, bias=qkv_bias)\n", 33 | " self.W_value = nn.Linear(dim_in, dim_out, bias=qkv_bias)\n", 34 | " self.dropout = nn.Dropout(dropout)\n", 35 | "\n", 36 | " self.register_buffer(\n", 37 | " \"mask\",\n", 38 | " torch.triu(torch.ones(context_len, context_len), diagonal=1)\n", 39 | " )\n", 40 | "\n", 41 | " def forward(self, x):\n", 42 | " b, num_tokens, dim_in = 
x.size()\n", 43 | " \n", 44 | " keys = self.W_key(x)\n", 45 | " queries = self.W_query(x)\n", 46 | " values = self.W_value(x)\n", 47 | "\n", 48 | " attn_scores = queries @ keys.transpose(1,2)\n", 49 | " attn_scores.masked_fill(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)\n", 50 | "\n", 51 | " attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)\n", 52 | " attn_weights = self.dropout(attn_weights)\n", 53 | "\n", 54 | " context_vecs = attn_weights @ values\n", 55 | " return context_vecs" 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": 7, 61 | "id": "b05fdb37-251b-4838-b182-6935271450f1", 62 | "metadata": {}, 63 | "outputs": [], 64 | "source": [ 65 | "class MultiheadAttention(nn.Module):\n", 66 | " def __init__(self, \n", 67 | " dim_in, \n", 68 | " dim_out, \n", 69 | " context_length, \n", 70 | " dropout, \n", 71 | " num_heads,\n", 72 | " qkv_bias=False):\n", 73 | " super().__init__()\n", 74 | "\n", 75 | " self.heads = nn.ModuleList([\n", 76 | " CasualAttention(dim_in, dim_out, context_length, dropout, qkv_bias) for _ in range(num_heads)\n", 77 | " ])\n", 78 | "\n", 79 | " def forward(self, x):\n", 80 | " return torch.cat([head(x) for head in self.heads], dim=-1)\n", 81 | " " 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 15, 87 | "id": "a83d56ef-3131-45a2-a548-ab31a2de7055", 88 | "metadata": {}, 89 | "outputs": [ 90 | { 91 | "name": "stdout", 92 | "output_type": "stream", 93 | "text": [ 94 | "torch.Size([2, 6, 3])\n", 95 | "tensor([[[-0.5337, -0.1051, 0.5085, 0.3508],\n", 96 | " [-0.5323, -0.1080, 0.5084, 0.3508],\n", 97 | " [-0.5323, -0.1079, 0.5084, 0.3506],\n", 98 | " [-0.5297, -0.1076, 0.5074, 0.3471],\n", 99 | " [-0.5311, -0.1066, 0.5076, 0.3446],\n", 100 | " [-0.5299, -0.1081, 0.5077, 0.3493]],\n", 101 | "\n", 102 | " [[-0.5337, -0.1051, 0.5085, 0.3508],\n", 103 | " [-0.5323, -0.1080, 0.5084, 0.3508],\n", 104 | " [-0.5323, -0.1079, 0.5084, 0.3506],\n", 105 | " [-0.5297, -0.1076, 
0.5074, 0.3471],\n", 106 | " [-0.5311, -0.1066, 0.5076, 0.3446],\n", 107 | " [-0.5299, -0.1081, 0.5077, 0.3493]]], grad_fn=)\n", 108 | "torch.Size([2, 6, 4])\n" 109 | ] 110 | } 111 | ], 112 | "source": [ 113 | "torch.manual_seed(123)\n", 114 | "inputs = torch.tensor(\n", 115 | " [[0.43, 0.15, 0.89], # Your\n", 116 | " [0.55, 0.87, 0.66], # journey\n", 117 | " [0.57, 0.85, 0.64], # starts\n", 118 | " [0.22, 0.58, 0.33], # with\n", 119 | " [0.77, 0.25, 0.10], # one\n", 120 | " [0.05, 0.80, 0.55] # step\n", 121 | " ])\n", 122 | "batch = torch.stack((inputs, inputs), dim=0) # [2, 6, 3]\n", 123 | "print(batch.shape)\n", 124 | "\n", 125 | "dim_in, dim_out = batch.shape[-1], 2\n", 126 | "context_length = batch.shape[1]\n", 127 | "\n", 128 | "mha = MultiheadAttention(dim_in, dim_out, context_length, 0.0, num_heads=2)\n", 129 | "context_vecs = mha(batch)\n", 130 | "\n", 131 | "print(context_vecs)\n", 132 | "print(context_vecs.shape)" 133 | ] 134 | }, 135 | { 136 | "cell_type": "code", 137 | "execution_count": 16, 138 | "id": "66e0cdae-1ea4-40d3-888a-8e1f16b9d709", 139 | "metadata": {}, 140 | "outputs": [], 141 | "source": [ 142 | "class MultiheadAttention(nn.Module):\n", 143 | " def __init__(self, dim_in, dim_out,\n", 144 | " context_length, dropout, num_heads, qkv_bias=False):\n", 145 | " super().__init__()\n", 146 | "\n", 147 | " assert dim_out % num_heads == 0, \\\n", 148 | " \"dim_out must be divisible by num_heads\"\n", 149 | "\n", 150 | " self.dim_out = dim_out\n", 151 | " self.num_heads = num_heads\n", 152 | " self.head_dim = dim_out // num_heads\n", 153 | "\n", 154 | " self.W_query = nn.Linear(dim_in, dim_out, bias=qkv_bias)\n", 155 | " self.W_key = nn.Linear(dim_in, dim_out, bias=qkv_bias)\n", 156 | " self.W_value = nn.Linear(dim_in, dim_out, bias=qkv_bias)\n", 157 | "\n", 158 | " self.out_proj = nn.Linear(dim_out, dim_out)\n", 159 | " self.dropout = nn.Dropout(dropout)\n", 160 | "\n", 161 | " self.register_buffer(\n", 162 | " \"mask\",\n", 163 | " 
torch.triu(torch.ones(context_length, context_length), diagonal=1)\n", 164 | " )\n", 165 | "\n", 166 | " def split_heads(self, x, batch_size, num_tokens):\n", 167 | " x = x.view(batch_size, num_tokens, self.num_heads, self.head_dim)\n", 168 | " x = x.transpose(1, 2)\n", 169 | " return x\n", 170 | "\n", 171 | " def forward(self, x):\n", 172 | " b, num_tokens, dim_in = x.shape\n", 173 | " \n", 174 | " keys = self.W_key(x)\n", 175 | " queries = self.W_query(x)\n", 176 | " values = self.W_value(x)\n", 177 | "\n", 178 | " keys = self.split_heads(keys, b, num_tokens)\n", 179 | " queries = self.split_heads(queries, b, num_tokens)\n", 180 | " values = self.split_heads(values, b, num_tokens)\n", 181 | "\n", 182 | " attn_scores = queries @ keys.transpose(2, 3)\n", 183 | " mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n", 184 | " attn_scores = attn_scores.masked_fill(mask_bool, -torch.inf)\n", 185 | " \n", 186 | " attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)\n", 187 | " attn_weights = self.dropout(attn_weights)\n", 188 | "\n", 189 | " context_vecs = (attn_weights @ values).transpose(1, 2)\n", 190 | " context_vecs = context_vecs.contiguous().view(b, num_tokens, self.dim_out)\n", 191 | "\n", 192 | " context_vecs = self.out_proj(context_vecs)\n", 193 | " return context_vecs\n", 194 | " " 195 | ] 196 | }, 197 | { 198 | "cell_type": "code", 199 | "execution_count": 18, 200 | "id": "f817617e-4c9b-4e63-8896-e0efaf658178", 201 | "metadata": {}, 202 | "outputs": [ 203 | { 204 | "name": "stdout", 205 | "output_type": "stream", 206 | "text": [ 207 | "tensor([[[0.3190, 0.4858],\n", 208 | " [0.2943, 0.3897],\n", 209 | " [0.2856, 0.3593],\n", 210 | " [0.2693, 0.3873],\n", 211 | " [0.2639, 0.3928],\n", 212 | " [0.2575, 0.4028]],\n", 213 | "\n", 214 | " [[0.3190, 0.4858],\n", 215 | " [0.2943, 0.3897],\n", 216 | " [0.2856, 0.3593],\n", 217 | " [0.2693, 0.3873],\n", 218 | " [0.2639, 0.3928],\n", 219 | " [0.2575, 0.4028]]], grad_fn=)\n", 220 | 
"torch.Size([2, 6, 2])\n" 221 | ] 222 | } 223 | ], 224 | "source": [ 225 | "torch.manual_seed(123)\n", 226 | "mha = MultiheadAttention(3, 2, 6, 0, 2)\n", 227 | "\n", 228 | "context_vecs = mha(batch)\n", 229 | "print(context_vecs)\n", 230 | "print(context_vecs.shape)" 231 | ] 232 | }, 233 | { 234 | "cell_type": "code", 235 | "execution_count": null, 236 | "id": "5f589666-6879-40d5-b0d7-ad25204162b5", 237 | "metadata": {}, 238 | "outputs": [], 239 | "source": [] 240 | } 241 | ], 242 | "metadata": { 243 | "kernelspec": { 244 | "display_name": "Python 3 (ipykernel)", 245 | "language": "python", 246 | "name": "python3" 247 | }, 248 | "language_info": { 249 | "codemirror_mode": { 250 | "name": "ipython", 251 | "version": 3 252 | }, 253 | "file_extension": ".py", 254 | "mimetype": "text/x-python", 255 | "name": "python", 256 | "nbconvert_exporter": "python", 257 | "pygments_lexer": "ipython3", 258 | "version": "3.11.4" 259 | } 260 | }, 261 | "nbformat": 4, 262 | "nbformat_minor": 5 263 | } 264 | -------------------------------------------------------------------------------- /4.6_GPT_model.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "330e411f-edd1-405c-9997-8d57c1780ae2", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import torch\n", 11 | "import torch.nn as nn" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 2, 17 | "id": "da966431-b9d3-4506-a6c2-c24bbfb51ce6", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "class MultiheadAttention(nn.Module):\n", 22 | " def __init__(self, dim_in, dim_out,\n", 23 | " context_length, dropout, num_heads, qkv_bias=False):\n", 24 | " super().__init__()\n", 25 | "\n", 26 | " assert dim_out % num_heads == 0, \\\n", 27 | " \"dim_out must be divisible by num_heads\"\n", 28 | "\n", 29 | " self.dim_out = dim_out\n", 30 | " self.num_heads = num_heads\n", 31 | " 
self.head_dim = dim_out // num_heads\n", 32 | "\n", 33 | " self.W_query = nn.Linear(dim_in, dim_out, bias=qkv_bias)\n", 34 | " self.W_key = nn.Linear(dim_in, dim_out, bias=qkv_bias)\n", 35 | " self.W_value = nn.Linear(dim_in, dim_out, bias=qkv_bias)\n", 36 | "\n", 37 | " self.out_proj = nn.Linear(dim_out, dim_out)\n", 38 | " self.dropout = nn.Dropout(dropout)\n", 39 | "\n", 40 | " self.register_buffer(\n", 41 | " \"mask\",\n", 42 | " torch.triu(torch.ones(context_length, context_length), diagonal=1)\n", 43 | " )\n", 44 | "\n", 45 | " def split_heads(self, x, batch_size, num_tokens):\n", 46 | " x = x.view(batch_size, num_tokens, self.num_heads, self.head_dim)\n", 47 | " x = x.transpose(1, 2)\n", 48 | " return x\n", 49 | "\n", 50 | " def forward(self, x):\n", 51 | " b, num_tokens, dim_in = x.shape\n", 52 | " \n", 53 | " keys = self.W_key(x)\n", 54 | " queries = self.W_query(x)\n", 55 | " values = self.W_value(x)\n", 56 | "\n", 57 | " keys = self.split_heads(keys, b, num_tokens)\n", 58 | " queries = self.split_heads(queries, b, num_tokens)\n", 59 | " values = self.split_heads(values, b, num_tokens)\n", 60 | "\n", 61 | " attn_scores = queries @ keys.transpose(2, 3)\n", 62 | " mask_bool = self.mask.bool()[:num_tokens, :num_tokens]\n", 63 | " attn_scores = attn_scores.masked_fill(mask_bool, -torch.inf)\n", 64 | " \n", 65 | " attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)\n", 66 | " attn_weights = self.dropout(attn_weights)\n", 67 | "\n", 68 | " context_vecs = (attn_weights @ values).transpose(1, 2)\n", 69 | " context_vecs = context_vecs.contiguous().view(b, num_tokens, self.dim_out)\n", 70 | "\n", 71 | " context_vecs = self.out_proj(context_vecs)\n", 72 | " return context_vecs\n", 73 | " " 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": 3, 79 | "id": "51385151-a8fe-4d7a-9187-69a719fa6673", 80 | "metadata": {}, 81 | "outputs": [], 82 | "source": [ 83 | "class LayerNorm(nn.Module):\n", 84 | " def __init__(self, 
emb_dim):\n", 85 | " super().__init__()\n", 86 | "\n", 87 | " self.eps = 1e-5\n", 88 | " self.scale = nn.Parameter(torch.ones(emb_dim))\n", 89 | " self.shift = nn.Parameter(torch.zeros(emb_dim))\n", 90 | "\n", 91 | " def forward(self, x):\n", 92 | " mean = x.mean(dim=-1, keepdim=True)\n", 93 | " var = x.var(dim=-1, keepdim=True, unbiased=False)\n", 94 | "\n", 95 | " norm_x = (x - mean) / torch.sqrt(var + self.eps)\n", 96 | " return self.scale * norm_x + self.shift" 97 | ] 98 | }, 99 | { 100 | "cell_type": "code", 101 | "execution_count": 4, 102 | "id": "1e2f7aeb-66d3-4726-8dea-f4d5264e9d79", 103 | "metadata": {}, 104 | "outputs": [], 105 | "source": [ 106 | "class GELU(nn.Module):\n", 107 | " def __init__(self):\n", 108 | " super().__init__()\n", 109 | "\n", 110 | " def forward(self, x):\n", 111 | " return 0.5 * x * (1 + torch.tanh((\n", 112 | " torch.sqrt(torch.tensor(2 / torch.pi)) * \n", 113 | " (x + 0.044715 * x**3))\n", 114 | " ))" 115 | ] 116 | }, 117 | { 118 | "cell_type": "code", 119 | "execution_count": 5, 120 | "id": "dbdcbaf5-e50f-4c51-8449-e000ab640b93", 121 | "metadata": {}, 122 | "outputs": [], 123 | "source": [ 124 | "class FeedForward(nn.Module):\n", 125 | " def __init__(self, cfg):\n", 126 | " super().__init__()\n", 127 | "\n", 128 | " self.layers = nn.Sequential(\n", 129 | " nn.Linear(cfg[\"emb_dim\"], 4*cfg[\"emb_dim\"]),\n", 130 | " GELU(),\n", 131 | " nn.Linear(4*cfg[\"emb_dim\"], cfg[\"emb_dim\"])\n", 132 | " )\n", 133 | "\n", 134 | " def forward(self, x):\n", 135 | " return self.layers(x)" 136 | ] 137 | }, 138 | { 139 | "cell_type": "code", 140 | "execution_count": 12, 141 | "id": "dfbd7c58-9ed5-4655-8c12-a7641ea65708", 142 | "metadata": {}, 143 | "outputs": [], 144 | "source": [ 145 | "class TransformerBlock(nn.Module):\n", 146 | " def __init__(self, cfg):\n", 147 | " super().__init__()\n", 148 | "\n", 149 | " self.attn = MultiheadAttention(\n", 150 | " dim_in = cfg[\"emb_dim\"],\n", 151 | " dim_out = cfg[\"emb_dim\"],\n", 152 | " 
context_length = cfg[\"context_length\"],\n", 153 | " num_heads = cfg[\"num_heads\"],\n", 154 | " dropout = cfg[\"drop_rate\"],\n", 155 | " qkv_bias = cfg[\"qkv_bias\"]\n", 156 | " )\n", 157 | "\n", 158 | " self.ffn = FeedForward(cfg)\n", 159 | " self.norm1 = LayerNorm(cfg[\"emb_dim\"])\n", 160 | " self.norm2 = LayerNorm(cfg[\"emb_dim\"])\n", 161 | " self.drop_shortcut = nn.Dropout(cfg[\"drop_rate\"])\n", 162 | "\n", 163 | " def forward(self, x):\n", 164 | " shortcut = x\n", 165 | " x = self.norm1(x)\n", 166 | " x = self.attn(x)\n", 167 | " x = self.drop_shortcut(x)\n", 168 | " x = x + shortcut\n", 169 | "\n", 170 | " shortcut = x\n", 171 | " x = self.norm2(x)\n", 172 | " x = self.ffn(x)\n", 173 | " x = self.drop_shortcut(x)\n", 174 | " x = x + shortcut\n", 175 | " return x" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 13, 181 | "id": "50583a2d-55a3-4e0f-8e36-21963713affd", 182 | "metadata": {}, 183 | "outputs": [], 184 | "source": [ 185 | "GPT_CONFIG_124M = {\n", 186 | " \"vocab_size\": 50257,\n", 187 | " \"context_length\": 1024,\n", 188 | " \"emb_dim\": 768,\n", 189 | " \"num_heads\": 12,\n", 190 | " \"n_layers\": 12,\n", 191 | " \"drop_rate\": 0.1,\n", 192 | " \"qkv_bias\": False\n", 193 | "}" 194 | ] 195 | }, 196 | { 197 | "cell_type": "code", 198 | "execution_count": 14, 199 | "id": "3286d9f2-38ae-43ea-8b37-d4ddc4c2eaec", 200 | "metadata": {}, 201 | "outputs": [ 202 | { 203 | "name": "stdout", 204 | "output_type": "stream", 205 | "text": [ 206 | "torch.Size([2, 4, 768])\n", 207 | "torch.Size([2, 4, 768])\n" 208 | ] 209 | } 210 | ], 211 | "source": [ 212 | "torch.manual_seed(123)\n", 213 | "x = torch.rand(2, 4, 768)\n", 214 | "block = TransformerBlock(GPT_CONFIG_124M)\n", 215 | "out = block(x)\n", 216 | "\n", 217 | "print(x.shape)\n", 218 | "print(out.shape)" 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": 19, 224 | "id": "83d75dcd-356c-4cf4-a382-8863f41bfc80", 225 | "metadata": {}, 226 | "outputs": 
[], 227 | "source": [ 228 | "class GPTModel(nn.Module):\n", 229 | " def __init__(self, cfg):\n", 230 | " super().__init__()\n", 231 | "\n", 232 | " self.tok_emb = nn.Embedding(cfg[\"vocab_size\"], cfg[\"emb_dim\"])\n", 233 | " self.pos_emb = nn.Embedding(cfg[\"context_length\"], cfg[\"emb_dim\"])\n", 234 | " self.drop_emb = nn.Dropout(cfg[\"drop_rate\"])\n", 235 | "\n", 236 | " self.trans_blocks = nn.Sequential(*[TransformerBlock(cfg) for _ in range(cfg[\"n_layers\"])])\n", 237 | "\n", 238 | " self.final_norm = LayerNorm(cfg[\"emb_dim\"])\n", 239 | " self.out_head = nn.Linear(cfg[\"emb_dim\"], cfg[\"vocab_size\"], bias=False)\n", 240 | "\n", 241 | " def forward(self, in_idx):\n", 242 | " batch_size, seq_len = in_idx.shape\n", 243 | " token_embeds = self.tok_emb(in_idx)\n", 244 | " pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device))\n", 245 | "\n", 246 | " x = token_embeds + pos_embeds\n", 247 | " x = self.drop_emb(x)\n", 248 | " x = self.trans_blocks(x)\n", 249 | " x = self.final_norm(x)\n", 250 | " logits = self.out_head(x)\n", 251 | " return logits" 252 | ] 253 | }, 254 | { 255 | "cell_type": "code", 256 | "execution_count": 20, 257 | "id": "b813a642-76be-418e-8fed-b1df129593a0", 258 | "metadata": {}, 259 | "outputs": [], 260 | "source": [ 261 | "torch.manual_seed(123)\n", 262 | "model = GPTModel(GPT_CONFIG_124M)" 263 | ] 264 | }, 265 | { 266 | "cell_type": "code", 267 | "execution_count": 21, 268 | "id": "97c57f1e-2337-4022-81f2-97f0cc196da1", 269 | "metadata": {}, 270 | "outputs": [ 271 | { 272 | "name": "stdout", 273 | "output_type": "stream", 274 | "text": [ 275 | "torch.Size([2, 4, 50257])\n", 276 | "tensor([[[ 0.3613, 0.4222, -0.0711, ..., 0.3483, 0.4661, -0.2838],\n", 277 | " [-0.1792, -0.5660, -0.9485, ..., 0.0477, 0.5181, -0.3168],\n", 278 | " [ 0.7120, 0.0332, 0.1085, ..., 0.1018, -0.4327, -0.2553],\n", 279 | " [-1.0076, 0.3418, -0.1190, ..., 0.7195, 0.4023, 0.0532]],\n", 280 | "\n", 281 | " [[-0.2564, 0.0900, 0.0335, ..., 0.2659, 
0.4454, -0.6806],\n", 282 | " [ 0.1230, 0.3653, -0.2074, ..., 0.7705, 0.2710, 0.2246],\n", 283 | " [ 1.0558, 1.0318, -0.2800, ..., 0.6936, 0.3205, -0.3178],\n", 284 | " [-0.1565, 0.3926, 0.3288, ..., 1.2630, -0.1858, 0.0388]]],\n", 285 | " grad_fn=)\n" 286 | ] 287 | } 288 | ], 289 | "source": [ 290 | "batch = torch.tensor([\n", 291 | " [6109, 3626, 6100, 345],\n", 292 | " [6109, 1110, 6622, 257]\n", 293 | "])\n", 294 | "\n", 295 | "out = model(batch)\n", 296 | "\n", 297 | "print(out.shape)\n", 298 | "print(out)" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": 22, 304 | "id": "0e336eb0-2506-44ff-a55e-740d0e788d6f", 305 | "metadata": {}, 306 | "outputs": [ 307 | { 308 | "name": "stdout", 309 | "output_type": "stream", 310 | "text": [ 311 | "163009536\n" 312 | ] 313 | } 314 | ], 315 | "source": [ 316 | "total_params = sum(p.numel() for p in model.parameters())\n", 317 | "print(total_params)" 318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "execution_count": 23, 323 | "id": "4f5c0505-4ff4-4925-97b1-3845a588fc36", 324 | "metadata": {}, 325 | "outputs": [ 326 | { 327 | "name": "stdout", 328 | "output_type": "stream", 329 | "text": [ 330 | "621.83203125\n" 331 | ] 332 | } 333 | ], 334 | "source": [ 335 | "total_size_type = total_params * 4\n", 336 | "total_size_mb = total_size_type / (1024 * 1024)\n", 337 | "print(total_size_mb)" 338 | ] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "execution_count": 24, 343 | "id": "e9e9fab4-87df-4aee-bb3f-4b9f2147ad48", 344 | "metadata": {}, 345 | "outputs": [ 346 | { 347 | "data": { 348 | "text/plain": [ 349 | "'\\n1. Encodes text input into four token IDs;\\n2. The GPT model returns a matrix consisting of four vectors, where each vector has 50257 dimensions;\\n3. Extract the last vector, which corresponds to the next token that the GPT model is supposed to generate;\\n4. Converts logits into probability distribution using the softmax function;\\n5. 
Identifies the index position of the largest value, which also represents the token ID;\n6. Appends token to the previous inputs for the next round.\n'" 350 | ] 351 | }, 352 | "execution_count": 24, 353 | "metadata": {}, 354 | "output_type": "execute_result" 355 | } 356 | ], 357 | "source": [ 358 | "\"\"\"\n", 359 | "1. Encodes text input into four token IDs;\n", 360 | "2. The GPT model returns a matrix consisting of four vectors, where each vector has 50257 dimensions;\n", 361 | "3. Extract the last vector, which corresponds to the next token that the GPT model is supposed to generate;\n", 362 | "4. Converts logits into probability distribution using the softmax function;\n", 363 | "5. Identifies the index position of the largest value, which also represents the token ID;\n", 364 | "6. Appends token to the previous inputs for the next round.\n", 365 | "\"\"\"" 366 | ] 367 | }, 368 | { 369 | "cell_type": "code", 370 | "execution_count": 25, 371 | "id": "4db4ae6c-238e-4159-a223-957bd2343cfd", 372 | "metadata": {}, 373 | "outputs": [], 374 | "source": [ 375 | "def generate_text_simple(model, idx, max_new_tokens, context_size):\n", 376 | " for _ in range(max_new_tokens):\n", 377 | " idx_second = idx[:, -context_size:]\n", 378 | " with torch.no_grad():\n", 379 | " logits = model(idx_second)\n", 380 | "\n", 381 | " logits = logits[:,-1,:]\n", 382 | " probs = torch.softmax(logits, dim=-1)\n", 383 | " idx_next = torch.argmax(probs, dim=-1, keepdim=True)\n", 384 | "\n", 385 | " idx = torch.cat([idx, idx_next], dim=1)\n", 386 | " return idx" 387 | ] 388 | }, 389 | { 390 | "cell_type": "code", 391 | "execution_count": 27, 392 | "id": "e5e9e385-8118-4a3b-a4bf-c5231609c36d", 393 | "metadata": {}, 394 | "outputs": [ 395 | { 396 | "name": "stdout", 397 | "output_type": "stream", 398 | "text": [ 399 | "tensor([[15496, 11, 314, 716]])\n", 400 | "torch.Size([1, 4])\n" 401 | ] 402 | } 403 | ], 404 | "source": [ 405 | "import tiktoken\n", 406 | "\n", 407 | "start_text = \"Hello, 
I am\"\n", 408 | "tokenizer = tiktoken.get_encoding(\"gpt2\")\n", 409 | "encoded = tokenizer.encode(start_text)\n", 410 | "\n", 411 | "encoded = torch.tensor(encoded).unsqueeze(0)\n", 412 | "print(encoded)\n", 413 | "print(encoded.shape)" 414 | ] 415 | }, 416 | { 417 | "cell_type": "code", 418 | "execution_count": 29, 419 | "id": "4407c20f-50a1-4d75-9361-ae33203da971", 420 | "metadata": {}, 421 | "outputs": [ 422 | { 423 | "name": "stdout", 424 | "output_type": "stream", 425 | "text": [ 426 | "tensor([[15496, 11, 314, 716, 27018, 24086, 47843, 30961, 42348, 7267]])\n", 427 | "10\n" 428 | ] 429 | } 430 | ], 431 | "source": [ 432 | "model.eval()\n", 433 | "out = generate_text_simple(model, encoded, 6, GPT_CONFIG_124M[\"context_length\"])\n", 434 | "\n", 435 | "print(out)\n", 436 | "print(len(out[0]))" 437 | ] 438 | }, 439 | { 440 | "cell_type": "code", 441 | "execution_count": 30, 442 | "id": "aac4c550-8ca5-4082-a5e4-e2e0dbadaa4b", 443 | "metadata": {}, 444 | "outputs": [ 445 | { 446 | "name": "stdout", 447 | "output_type": "stream", 448 | "text": [ 449 | "Hello, I am Featureiman Byeswickattribute argue\n" 450 | ] 451 | } 452 | ], 453 | "source": [ 454 | "decoded_txt = tokenizer.decode(out.squeeze(0).tolist())\n", 455 | "print(decoded_txt)" 456 | ] 457 | }, 458 | { 459 | "cell_type": "code", 460 | "execution_count": null, 461 | "id": "95c7a746-03ef-491c-b67a-3f862c959c96", 462 | "metadata": {}, 463 | "outputs": [], 464 | "source": [] 465 | } 466 | ], 467 | "metadata": { 468 | "kernelspec": { 469 | "display_name": "Python 3 (ipykernel)", 470 | "language": "python", 471 | "name": "python3" 472 | }, 473 | "language_info": { 474 | "codemirror_mode": { 475 | "name": "ipython", 476 | "version": 3 477 | }, 478 | "file_extension": ".py", 479 | "mimetype": "text/x-python", 480 | "name": "python", 481 | "nbconvert_exporter": "python", 482 | "pygments_lexer": "ipython3", 483 | "version": "3.11.4" 484 | } 485 | }, 486 | "nbformat": 4, 487 | "nbformat_minor": 5 488 | } 489 | 
-------------------------------------------------------------------------------- /5.5_Loading_pretrained_model_from_OpenAI.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 4, 6 | "id": "f299eb3e-c51c-49b8-a44d-9207adea6208", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import urllib.request\n", 11 | "url = \"https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch05/01_main-chapter-code/gpt_download.py\"" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": 5, 17 | "id": "c356e121-7426-4f33-9907-2d059dabe120", 18 | "metadata": {}, 19 | "outputs": [ 20 | { 21 | "data": { 22 | "text/plain": [ 23 | "('ch05/gpt_download.py', )" 24 | ] 25 | }, 26 | "execution_count": 5, 27 | "metadata": {}, 28 | "output_type": "execute_result" 29 | } 30 | ], 31 | "source": [ 32 | "filename = \"ch05/\" + url.split(\"/\")[-1]\n", 33 | "urllib.request.urlretrieve(url, filename)" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 6, 39 | "id": "d89e38bf-e35a-4666-b718-903b590bf05a", 40 | "metadata": {}, 41 | "outputs": [ 42 | { 43 | "name": "stdout", 44 | "output_type": "stream", 45 | "text": [ 46 | "File already exists and is up-to-date: ch05/gpt2/124M/checkpoint\n", 47 | "File already exists and is up-to-date: ch05/gpt2/124M/encoder.json\n", 48 | "File already exists and is up-to-date: ch05/gpt2/124M/hparams.json\n", 49 | "File already exists and is up-to-date: ch05/gpt2/124M/model.ckpt.data-00000-of-00001\n", 50 | "File already exists and is up-to-date: ch05/gpt2/124M/model.ckpt.index\n", 51 | "File already exists and is up-to-date: ch05/gpt2/124M/model.ckpt.meta\n", 52 | "File already exists and is up-to-date: ch05/gpt2/124M/vocab.bpe\n" 53 | ] 54 | } 55 | ], 56 | "source": [ 57 | "from ch05.gpt_download import download_and_load_gpt2\n", 58 | "settings, params = download_and_load_gpt2(model_size=\"124M\", 
models_dir=\"ch05/gpt2\")" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": 7, 64 | "id": "f33e2b02-4712-474d-ae05-c7fdca8ad965", 65 | "metadata": {}, 66 | "outputs": [ 67 | { 68 | "name": "stdout", 69 | "output_type": "stream", 70 | "text": [ 71 | "{'n_vocab': 50257, 'n_ctx': 1024, 'n_embd': 768, 'n_head': 12, 'n_layer': 12}\n", 72 | "dict_keys(['blocks', 'b', 'g', 'wpe', 'wte'])\n" 73 | ] 74 | } 75 | ], 76 | "source": [ 77 | "print(settings)\n", 78 | "print(params.keys())" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": 8, 84 | "id": "1d7a7e65-2d10-4950-a700-3f13fcf208b7", 85 | "metadata": {}, 86 | "outputs": [ 87 | { 88 | "name": "stdout", 89 | "output_type": "stream", 90 | "text": [ 91 | "[[-0.11010301 -0.03926672 0.03310751 ... -0.1363697 0.01506208\n", 92 | " 0.04531523]\n", 93 | " [ 0.04034033 -0.04861503 0.04624869 ... 0.08605453 0.00253983\n", 94 | " 0.04318958]\n", 95 | " [-0.12746179 0.04793796 0.18410145 ... 0.08991534 -0.12972379\n", 96 | " -0.08785918]\n", 97 | " ...\n", 98 | " [-0.04453601 -0.05483596 0.01225674 ... 0.10435229 0.09783269\n", 99 | " -0.06952604]\n", 100 | " [ 0.1860082 0.01665728 0.04611587 ... -0.09625227 0.07847701\n", 101 | " -0.02245961]\n", 102 | " [ 0.05135201 -0.02768905 0.0499369 ... 
0.00704835 0.15519823\n", 103 | " 0.12067825]]\n", 104 | "(50257, 768)\n" 105 | ] 106 | } 107 | ], 108 | "source": [ 109 | "print(params[\"wte\"])\n", 110 | "print(params[\"wte\"].shape)" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": 9, 116 | "id": "39dff3ec-5244-4ece-9b87-0eb196401ace", 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "model_configs = {\n", 121 | " \"gpt2-small (124M)\": {\"emb_dim\": 768, \"n_layers\": 12, \"n_heads\": 12},\n", 122 | " \"gpt2-media (355M)\": {\"emb_dim\": 1024, \"n_layers\": 24, \"n_heads\": 16},\n", 123 | " \"gpt2-large (774M)\": {\"emb_dim\": 1280, \"n_layers\": 36, \"n_heads\": 20},\n", 124 | " \"gpt2-xl (1558M)\": {\"emb_dim\": 1600, \"n_layers\": 48, \"n_heads\": 25}\n", 125 | "}" 126 | ] 127 | }, 128 | { 129 | "cell_type": "code", 130 | "execution_count": 10, 131 | "id": "de281e51-ab20-4402-8641-d7e1ab199bc8", 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "from codes.configs import GPT_CONFIG_124M\n", 136 | "\n", 137 | "model_name = \"gpt2-small (124M)\"\n", 138 | "NEW_CONFIG = GPT_CONFIG_124M.copy()\n", 139 | "NEW_CONFIG.update(model_configs[model_name])\n", 140 | "\n", 141 | "NEW_CONFIG.update({\"context_length\": 1024})" 142 | ] 143 | }, 144 | { 145 | "cell_type": "code", 146 | "execution_count": 11, 147 | "id": "6d7de070-2626-4b0a-91be-0a759063031b", 148 | "metadata": {}, 149 | "outputs": [], 150 | "source": [ 151 | "NEW_CONFIG.update({\"qkv_bias\": True})" 152 | ] 153 | }, 154 | { 155 | "cell_type": "code", 156 | "execution_count": 17, 157 | "id": "34ada927-c69f-4b5a-9c0b-0b21c928c31c", 158 | "metadata": {}, 159 | "outputs": [], 160 | "source": [ 161 | "from codes.gpt_model import GPTModel\n", 162 | "gpt = GPTModel(NEW_CONFIG)\n", 163 | "#gpt.eval()" 164 | ] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "execution_count": 18, 169 | "id": "c90bd865-55e3-4b11-b60c-007e338f5784", 170 | "metadata": {}, 171 | "outputs": [], 172 | "source": [ 173 | 
"import torch\n", 174 | "\n", 175 | "def assign(left, right):\n", 176 | " assert left.shape == right.shape, f\"Shape mismatch. Left: {left.shape}, Right: {right.shape}\"\n", 177 | " return torch.nn.Parameter(torch.tensor(right))" 178 | ] 179 | }, 180 | { 181 | "cell_type": "code", 182 | "execution_count": 19, 183 | "id": "399e21bc-41e9-4617-95d8-140efbf0b803", 184 | "metadata": {}, 185 | "outputs": [], 186 | "source": [ 187 | "import numpy as np\n", 188 | "\n", 189 | "def load_weights_into_gpt(gpt, params):\n", 190 | " gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params[\"wpe\"])\n", 191 | " gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params[\"wte\"])\n", 192 | "\n", 193 | " for b in range(len(params[\"blocks\"])):\n", 194 | " q_w, k_w, v_w = np.split(\n", 195 | " (params[\"blocks\"][b][\"attn\"][\"c_attn\"])[\"w\"], 3, axis=-1\n", 196 | " )\n", 197 | " gpt.trans_blocks[b].attn.W_query.weight = assign(\n", 198 | " gpt.trans_blocks[b].attn.W_query.weight, q_w.T\n", 199 | " )\n", 200 | " gpt.trans_blocks[b].attn.W_key.weight = assign(\n", 201 | " gpt.trans_blocks[b].attn.W_key.weight, k_w.T\n", 202 | " )\n", 203 | " gpt.trans_blocks[b].attn.W_value.weight = assign(\n", 204 | " gpt.trans_blocks[b].attn.W_value.weight, v_w.T\n", 205 | " )\n", 206 | "\n", 207 | " q_b, k_b, v_b = np.split(\n", 208 | " (params[\"blocks\"][b][\"attn\"][\"c_attn\"])[\"b\"], 3, axis=-1\n", 209 | " )\n", 210 | " gpt.trans_blocks[b].attn.W_query.bias = assign(\n", 211 | " gpt.trans_blocks[b].attn.W_query.bias, q_b\n", 212 | " )\n", 213 | " gpt.trans_blocks[b].attn.W_key.bias = assign(\n", 214 | " gpt.trans_blocks[b].attn.W_key.bias, k_b\n", 215 | " )\n", 216 | " gpt.trans_blocks[b].attn.W_value.bias = assign(\n", 217 | " gpt.trans_blocks[b].attn.W_value.bias, v_b\n", 218 | " )\n", 219 | "\n", 220 | " gpt.trans_blocks[b].attn.out_proj.weight = assign(\n", 221 | " gpt.trans_blocks[b].attn.out_proj.weight,\n", 222 | " params[\"blocks\"][b][\"attn\"][\"c_proj\"][\"w\"].T\n", 223 | " )\n", 
224 | " gpt.trans_blocks[b].attn.out_proj.bias = assign(\n", 225 | " gpt.trans_blocks[b].attn.out_proj.bias,\n", 226 | " params[\"blocks\"][b][\"attn\"][\"c_proj\"][\"b\"]\n", 227 | " )\n", 228 | "\n", 229 | " gpt.trans_blocks[b].ffn.layers[0].weight = assign(\n", 230 | " gpt.trans_blocks[b].ffn.layers[0].weight, \n", 231 | " params[\"blocks\"][b][\"mlp\"][\"c_fc\"][\"w\"].T\n", 232 | " )\n", 233 | " gpt.trans_blocks[b].ffn.layers[0].bias = assign(\n", 234 | " gpt.trans_blocks[b].ffn.layers[0].bias, \n", 235 | " params[\"blocks\"][b][\"mlp\"][\"c_fc\"][\"b\"]\n", 236 | " )\n", 237 | " gpt.trans_blocks[b].ffn.layers[2].weight = assign(\n", 238 | " gpt.trans_blocks[b].ffn.layers[2].weight, \n", 239 | " params[\"blocks\"][b][\"mlp\"][\"c_proj\"][\"w\"].T\n", 240 | " )\n", 241 | " gpt.trans_blocks[b].ffn.layers[2].bias = assign(\n", 242 | " gpt.trans_blocks[b].ffn.layers[2].bias, \n", 243 | " params[\"blocks\"][b][\"mlp\"][\"c_proj\"][\"b\"]\n", 244 | " )\n", 245 | "\n", 246 | " gpt.trans_blocks[b].norm1.scale = assign(\n", 247 | " gpt.trans_blocks[b].norm1.scale,\n", 248 | " params[\"blocks\"][b][\"ln_1\"][\"g\"]\n", 249 | " )\n", 250 | " gpt.trans_blocks[b].norm1.shift = assign(\n", 251 | " gpt.trans_blocks[b].norm1.shift,\n", 252 | " params[\"blocks\"][b][\"ln_1\"][\"b\"]\n", 253 | " )\n", 254 | " gpt.trans_blocks[b].norm2.scale = assign(\n", 255 | " gpt.trans_blocks[b].norm2.scale,\n", 256 | " params[\"blocks\"][b][\"ln_2\"][\"g\"]\n", 257 | " )\n", 258 | " gpt.trans_blocks[b].norm2.shift = assign(\n", 259 | " gpt.trans_blocks[b].norm2.shift,\n", 260 | " params[\"blocks\"][b][\"ln_2\"][\"b\"]\n", 261 | " )\n", 262 | " gpt.final_norm.scale = assign(\n", 263 | " gpt.final_norm.scale, params[\"g\"]\n", 264 | " )\n", 265 | " gpt.final_norm.shift = assign(\n", 266 | " gpt.final_norm.shift, params[\"b\"]\n", 267 | " )\n", 268 | "\n", 269 | " gpt.out_head.weight = assign(\n", 270 | " gpt.out_head.weight, params[\"wte\"]\n", 271 | " )" 272 | ] 273 | }, 274 | { 275 | 
"cell_type": "code", 276 | "execution_count": 20, 277 | "id": "1775b290-8f50-41be-920b-e5ecf1df4f1a", 278 | "metadata": {}, 279 | "outputs": [], 280 | "source": [ 281 | "load_weights_into_gpt(gpt, params)" 282 | ] 283 | }, 284 | { 285 | "cell_type": "code", 286 | "execution_count": 21, 287 | "id": "c6c5aa89-2108-4e21-84ed-933f961ed03b", 288 | "metadata": {}, 289 | "outputs": [ 290 | { 291 | "data": { 292 | "text/plain": [ 293 | "GPTModel(\n", 294 | " (tok_emb): Embedding(50257, 768)\n", 295 | " (pos_emb): Embedding(1024, 768)\n", 296 | " (drop_emb): Dropout(p=0.1, inplace=False)\n", 297 | " (trans_blocks): Sequential(\n", 298 | " (0): TransformerBlock(\n", 299 | " (attn): MultiheadAttention(\n", 300 | " (W_query): Linear(in_features=768, out_features=768, bias=True)\n", 301 | " (W_key): Linear(in_features=768, out_features=768, bias=True)\n", 302 | " (W_value): Linear(in_features=768, out_features=768, bias=True)\n", 303 | " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", 304 | " (dropout): Dropout(p=0.1, inplace=False)\n", 305 | " )\n", 306 | " (ffn): FeedForward(\n", 307 | " (layers): Sequential(\n", 308 | " (0): Linear(in_features=768, out_features=3072, bias=True)\n", 309 | " (1): GELU()\n", 310 | " (2): Linear(in_features=3072, out_features=768, bias=True)\n", 311 | " )\n", 312 | " )\n", 313 | " (norm1): LayerNorm()\n", 314 | " (norm2): LayerNorm()\n", 315 | " (drop_shortcut): Dropout(p=0.1, inplace=False)\n", 316 | " )\n", 317 | " (1): TransformerBlock(\n", 318 | " (attn): MultiheadAttention(\n", 319 | " (W_query): Linear(in_features=768, out_features=768, bias=True)\n", 320 | " (W_key): Linear(in_features=768, out_features=768, bias=True)\n", 321 | " (W_value): Linear(in_features=768, out_features=768, bias=True)\n", 322 | " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", 323 | " (dropout): Dropout(p=0.1, inplace=False)\n", 324 | " )\n", 325 | " (ffn): FeedForward(\n", 326 | " (layers): Sequential(\n", 327 | " (0): 
Linear(in_features=768, out_features=3072, bias=True)\n", 328 | " (1): GELU()\n", 329 | " (2): Linear(in_features=3072, out_features=768, bias=True)\n", 330 | " )\n", 331 | " )\n", 332 | " (norm1): LayerNorm()\n", 333 | " (norm2): LayerNorm()\n", 334 | " (drop_shortcut): Dropout(p=0.1, inplace=False)\n", 335 | " )\n", 336 | " (2): TransformerBlock(\n", 337 | " (attn): MultiheadAttention(\n", 338 | " (W_query): Linear(in_features=768, out_features=768, bias=True)\n", 339 | " (W_key): Linear(in_features=768, out_features=768, bias=True)\n", 340 | " (W_value): Linear(in_features=768, out_features=768, bias=True)\n", 341 | " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", 342 | " (dropout): Dropout(p=0.1, inplace=False)\n", 343 | " )\n", 344 | " (ffn): FeedForward(\n", 345 | " (layers): Sequential(\n", 346 | " (0): Linear(in_features=768, out_features=3072, bias=True)\n", 347 | " (1): GELU()\n", 348 | " (2): Linear(in_features=3072, out_features=768, bias=True)\n", 349 | " )\n", 350 | " )\n", 351 | " (norm1): LayerNorm()\n", 352 | " (norm2): LayerNorm()\n", 353 | " (drop_shortcut): Dropout(p=0.1, inplace=False)\n", 354 | " )\n", 355 | " (3): TransformerBlock(\n", 356 | " (attn): MultiheadAttention(\n", 357 | " (W_query): Linear(in_features=768, out_features=768, bias=True)\n", 358 | " (W_key): Linear(in_features=768, out_features=768, bias=True)\n", 359 | " (W_value): Linear(in_features=768, out_features=768, bias=True)\n", 360 | " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", 361 | " (dropout): Dropout(p=0.1, inplace=False)\n", 362 | " )\n", 363 | " (ffn): FeedForward(\n", 364 | " (layers): Sequential(\n", 365 | " (0): Linear(in_features=768, out_features=3072, bias=True)\n", 366 | " (1): GELU()\n", 367 | " (2): Linear(in_features=3072, out_features=768, bias=True)\n", 368 | " )\n", 369 | " )\n", 370 | " (norm1): LayerNorm()\n", 371 | " (norm2): LayerNorm()\n", 372 | " (drop_shortcut): Dropout(p=0.1, inplace=False)\n", 373 | 
" )\n", 374 | " (4): TransformerBlock(\n", 375 | " (attn): MultiheadAttention(\n", 376 | " (W_query): Linear(in_features=768, out_features=768, bias=True)\n", 377 | " (W_key): Linear(in_features=768, out_features=768, bias=True)\n", 378 | " (W_value): Linear(in_features=768, out_features=768, bias=True)\n", 379 | " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", 380 | " (dropout): Dropout(p=0.1, inplace=False)\n", 381 | " )\n", 382 | " (ffn): FeedForward(\n", 383 | " (layers): Sequential(\n", 384 | " (0): Linear(in_features=768, out_features=3072, bias=True)\n", 385 | " (1): GELU()\n", 386 | " (2): Linear(in_features=3072, out_features=768, bias=True)\n", 387 | " )\n", 388 | " )\n", 389 | " (norm1): LayerNorm()\n", 390 | " (norm2): LayerNorm()\n", 391 | " (drop_shortcut): Dropout(p=0.1, inplace=False)\n", 392 | " )\n", 393 | " (5): TransformerBlock(\n", 394 | " (attn): MultiheadAttention(\n", 395 | " (W_query): Linear(in_features=768, out_features=768, bias=True)\n", 396 | " (W_key): Linear(in_features=768, out_features=768, bias=True)\n", 397 | " (W_value): Linear(in_features=768, out_features=768, bias=True)\n", 398 | " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", 399 | " (dropout): Dropout(p=0.1, inplace=False)\n", 400 | " )\n", 401 | " (ffn): FeedForward(\n", 402 | " (layers): Sequential(\n", 403 | " (0): Linear(in_features=768, out_features=3072, bias=True)\n", 404 | " (1): GELU()\n", 405 | " (2): Linear(in_features=3072, out_features=768, bias=True)\n", 406 | " )\n", 407 | " )\n", 408 | " (norm1): LayerNorm()\n", 409 | " (norm2): LayerNorm()\n", 410 | " (drop_shortcut): Dropout(p=0.1, inplace=False)\n", 411 | " )\n", 412 | " (6): TransformerBlock(\n", 413 | " (attn): MultiheadAttention(\n", 414 | " (W_query): Linear(in_features=768, out_features=768, bias=True)\n", 415 | " (W_key): Linear(in_features=768, out_features=768, bias=True)\n", 416 | " (W_value): Linear(in_features=768, out_features=768, bias=True)\n", 
417 | " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", 418 | " (dropout): Dropout(p=0.1, inplace=False)\n", 419 | " )\n", 420 | " (ffn): FeedForward(\n", 421 | " (layers): Sequential(\n", 422 | " (0): Linear(in_features=768, out_features=3072, bias=True)\n", 423 | " (1): GELU()\n", 424 | " (2): Linear(in_features=3072, out_features=768, bias=True)\n", 425 | " )\n", 426 | " )\n", 427 | " (norm1): LayerNorm()\n", 428 | " (norm2): LayerNorm()\n", 429 | " (drop_shortcut): Dropout(p=0.1, inplace=False)\n", 430 | " )\n", 431 | " (7): TransformerBlock(\n", 432 | " (attn): MultiheadAttention(\n", 433 | " (W_query): Linear(in_features=768, out_features=768, bias=True)\n", 434 | " (W_key): Linear(in_features=768, out_features=768, bias=True)\n", 435 | " (W_value): Linear(in_features=768, out_features=768, bias=True)\n", 436 | " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", 437 | " (dropout): Dropout(p=0.1, inplace=False)\n", 438 | " )\n", 439 | " (ffn): FeedForward(\n", 440 | " (layers): Sequential(\n", 441 | " (0): Linear(in_features=768, out_features=3072, bias=True)\n", 442 | " (1): GELU()\n", 443 | " (2): Linear(in_features=3072, out_features=768, bias=True)\n", 444 | " )\n", 445 | " )\n", 446 | " (norm1): LayerNorm()\n", 447 | " (norm2): LayerNorm()\n", 448 | " (drop_shortcut): Dropout(p=0.1, inplace=False)\n", 449 | " )\n", 450 | " (8): TransformerBlock(\n", 451 | " (attn): MultiheadAttention(\n", 452 | " (W_query): Linear(in_features=768, out_features=768, bias=True)\n", 453 | " (W_key): Linear(in_features=768, out_features=768, bias=True)\n", 454 | " (W_value): Linear(in_features=768, out_features=768, bias=True)\n", 455 | " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", 456 | " (dropout): Dropout(p=0.1, inplace=False)\n", 457 | " )\n", 458 | " (ffn): FeedForward(\n", 459 | " (layers): Sequential(\n", 460 | " (0): Linear(in_features=768, out_features=3072, bias=True)\n", 461 | " (1): GELU()\n", 462 | 
" (2): Linear(in_features=3072, out_features=768, bias=True)\n", 463 | " )\n", 464 | " )\n", 465 | " (norm1): LayerNorm()\n", 466 | " (norm2): LayerNorm()\n", 467 | " (drop_shortcut): Dropout(p=0.1, inplace=False)\n", 468 | " )\n", 469 | " (9): TransformerBlock(\n", 470 | " (attn): MultiheadAttention(\n", 471 | " (W_query): Linear(in_features=768, out_features=768, bias=True)\n", 472 | " (W_key): Linear(in_features=768, out_features=768, bias=True)\n", 473 | " (W_value): Linear(in_features=768, out_features=768, bias=True)\n", 474 | " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", 475 | " (dropout): Dropout(p=0.1, inplace=False)\n", 476 | " )\n", 477 | " (ffn): FeedForward(\n", 478 | " (layers): Sequential(\n", 479 | " (0): Linear(in_features=768, out_features=3072, bias=True)\n", 480 | " (1): GELU()\n", 481 | " (2): Linear(in_features=3072, out_features=768, bias=True)\n", 482 | " )\n", 483 | " )\n", 484 | " (norm1): LayerNorm()\n", 485 | " (norm2): LayerNorm()\n", 486 | " (drop_shortcut): Dropout(p=0.1, inplace=False)\n", 487 | " )\n", 488 | " (10): TransformerBlock(\n", 489 | " (attn): MultiheadAttention(\n", 490 | " (W_query): Linear(in_features=768, out_features=768, bias=True)\n", 491 | " (W_key): Linear(in_features=768, out_features=768, bias=True)\n", 492 | " (W_value): Linear(in_features=768, out_features=768, bias=True)\n", 493 | " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", 494 | " (dropout): Dropout(p=0.1, inplace=False)\n", 495 | " )\n", 496 | " (ffn): FeedForward(\n", 497 | " (layers): Sequential(\n", 498 | " (0): Linear(in_features=768, out_features=3072, bias=True)\n", 499 | " (1): GELU()\n", 500 | " (2): Linear(in_features=3072, out_features=768, bias=True)\n", 501 | " )\n", 502 | " )\n", 503 | " (norm1): LayerNorm()\n", 504 | " (norm2): LayerNorm()\n", 505 | " (drop_shortcut): Dropout(p=0.1, inplace=False)\n", 506 | " )\n", 507 | " (11): TransformerBlock(\n", 508 | " (attn): MultiheadAttention(\n", 
509 | " (W_query): Linear(in_features=768, out_features=768, bias=True)\n", 510 | " (W_key): Linear(in_features=768, out_features=768, bias=True)\n", 511 | " (W_value): Linear(in_features=768, out_features=768, bias=True)\n", 512 | " (out_proj): Linear(in_features=768, out_features=768, bias=True)\n", 513 | " (dropout): Dropout(p=0.1, inplace=False)\n", 514 | " )\n", 515 | " (ffn): FeedForward(\n", 516 | " (layers): Sequential(\n", 517 | " (0): Linear(in_features=768, out_features=3072, bias=True)\n", 518 | " (1): GELU()\n", 519 | " (2): Linear(in_features=3072, out_features=768, bias=True)\n", 520 | " )\n", 521 | " )\n", 522 | " (norm1): LayerNorm()\n", 523 | " (norm2): LayerNorm()\n", 524 | " (drop_shortcut): Dropout(p=0.1, inplace=False)\n", 525 | " )\n", 526 | " )\n", 527 | " (final_norm): LayerNorm()\n", 528 | " (out_head): Linear(in_features=768, out_features=50257, bias=False)\n", 529 | ")" 530 | ] 531 | }, 532 | "execution_count": 21, 533 | "metadata": {}, 534 | "output_type": "execute_result" 535 | } 536 | ], 537 | "source": [ 538 | "gpt.eval()" 539 | ] 540 | }, 541 | { 542 | "cell_type": "code", 543 | "execution_count": 22, 544 | "id": "a17c873d-e6b8-47a4-a500-7d87a5e135e4", 545 | "metadata": {}, 546 | "outputs": [], 547 | "source": [ 548 | "import tiktoken\n", 549 | "from codes.utils import text_to_token_ids, token_ids_to_text, generate\n", 550 | "\n", 551 | "tokenizer = tiktoken.get_encoding(\"gpt2\")" 552 | ] 553 | }, 554 | { 555 | "cell_type": "code", 556 | "execution_count": 24, 557 | "id": "23f19a41-596a-4709-9f74-59ac18c43bf7", 558 | "metadata": {}, 559 | "outputs": [ 560 | { 561 | "name": "stdout", 562 | "output_type": "stream", 563 | "text": [ 564 | "Every effort moves you toward finding an ideal new way to practice something!\n", 565 | "\n", 566 | "What makes us want to be on top of that?\n", 567 | "\n", 568 | "\n" 569 | ] 570 | } 571 | ], 572 | "source": [ 573 | "torch.manual_seed(123)\n", 574 | "token_ids = generate(\n", 575 | " model=gpt,\n", 
576 | " idx=text_to_token_ids(\"Every effort moves you\", tokenizer=tokenizer).to(\"cpu\"),\n", 577 | " max_new_tokens=25,\n", 578 | " context_size=NEW_CONFIG[\"context_length\"],\n", 579 | " top_k=50,\n", 580 | " temperature=1.5\n", 581 | ")\n", 582 | "\n", 583 | "print(token_ids_to_text(token_ids, tokenizer=tokenizer))" 584 | ] 585 | }, 586 | { 587 | "cell_type": "code", 588 | "execution_count": 26, 589 | "id": "b46ea669-6e46-45c2-8d57-86146fc7ece9", 590 | "metadata": {}, 591 | "outputs": [], 592 | "source": [ 593 | "torch.save(gpt.state_dict(), \"./ch05/gpt2-small-124m-pretrained.pth\")" 594 | ] 595 | }, 596 | { 597 | "cell_type": "code", 598 | "execution_count": null, 599 | "id": "446a2e0a-04a2-4fd5-96bd-5a633ecc0d9f", 600 | "metadata": {}, 601 | "outputs": [], 602 | "source": [] 603 | } 604 | ], 605 | "metadata": { 606 | "kernelspec": { 607 | "display_name": "Python 3 (ipykernel)", 608 | "language": "python", 609 | "name": "python3" 610 | }, 611 | "language_info": { 612 | "codemirror_mode": { 613 | "name": "ipython", 614 | "version": 3 615 | }, 616 | "file_extension": ".py", 617 | "mimetype": "text/x-python", 618 | "name": "python", 619 | "nbconvert_exporter": "python", 620 | "pygments_lexer": "ipython3", 621 | "version": "3.11.4" 622 | } 623 | }, 624 | "nbformat": 4, 625 | "nbformat_minor": 5 626 | } 627 | -------------------------------------------------------------------------------- /5.5_convert-gpt2-media-from-OpenAI.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 2, 6 | "id": "a348dd57-a1f2-47af-8b60-f6df52516389", 7 | "metadata": {}, 8 | "outputs": [ 9 | { 10 | "name": "stdout", 11 | "output_type": "stream", 12 | "text": [ 13 | "File already exists and is up-to-date: ch05/gpt2/355M/checkpoint\n", 14 | "File already exists and is up-to-date: ch05/gpt2/355M/encoder.json\n", 15 | "File already exists and is up-to-date: 
ch05/gpt2/355M/hparams.json\n", 16 | "File already exists and is up-to-date: ch05/gpt2/355M/model.ckpt.data-00000-of-00001\n", 17 | "File already exists and is up-to-date: ch05/gpt2/355M/model.ckpt.index\n", 18 | "File already exists and is up-to-date: ch05/gpt2/355M/model.ckpt.meta\n", 19 | "File already exists and is up-to-date: ch05/gpt2/355M/vocab.bpe\n" 20 | ] 21 | } 22 | ], 23 | "source": [ 24 | "from ch05.gpt_download import download_and_load_gpt2\n", 25 | "settings, params = download_and_load_gpt2(model_size=\"355M\", models_dir=\"ch05/gpt2\")" 26 | ] 27 | }, 28 | { 29 | "cell_type": "code", 30 | "execution_count": 3, 31 | "id": "aba395a9-b393-4467-a069-358946eb6bd8", 32 | "metadata": {}, 33 | "outputs": [ 34 | { 35 | "name": "stdout", 36 | "output_type": "stream", 37 | "text": [ 38 | "{'n_vocab': 50257, 'n_ctx': 1024, 'n_embd': 1024, 'n_head': 16, 'n_layer': 24} ['blocks', 'b', 'g', 'wpe', 'wte']\n" 39 | ] 40 | } 41 | ], 42 | "source": [ 43 | "print(settings, [k for k in params])" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 4, 49 | "id": "002512de-0a68-443e-a653-1381b1f5384c", 50 | "metadata": {}, 51 | "outputs": [ 52 | { 53 | "name": "stdout", 54 | "output_type": "stream", 55 | "text": [ 56 | "{'vocab_size': 50257, 'context_length': 1024, 'emb_dim': 1024, 'num_heads': 16, 'n_layers': 24, 'drop_rate': 0.0, 'qkv_bias': True}\n" 57 | ] 58 | } 59 | ], 60 | "source": [ 61 | "from codes.gpt_model import GPTModel\n", 62 | "from codes.configs import gpt2_media_config\n", 63 | "\n", 64 | "config = gpt2_media_config()\n", 65 | "print(config)" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": 5, 71 | "id": "58d70d62-a6f4-4195-977c-8a48c3b4426a", 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "gpt = GPTModel(config)" 76 | ] 77 | }, 78 | { 79 | "cell_type": "code", 80 | "execution_count": 7, 81 | "id": "d3df77e9-46aa-4d4f-b1b2-dc604404f920", 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | 
"from codes.model_convert import load_weights_into_gpt\n", 86 | "load_weights_into_gpt(gpt, params)" 87 | ] 88 | }, 89 | { 90 | "cell_type": "code", 91 | "execution_count": 9, 92 | "id": "37e3d8b6-5f29-4723-be90-361b91bc3e56", 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "import torch\n", 97 | "torch.save(gpt.state_dict(), \"./ch05/gpt2-media-355m-pretrained.pth\")" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "id": "0706314a-4cb5-4161-b574-45a2f615b78c", 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [] 107 | } 108 | ], 109 | "metadata": { 110 | "kernelspec": { 111 | "display_name": "Python 3 (ipykernel)", 112 | "language": "python", 113 | "name": "python3" 114 | }, 115 | "language_info": { 116 | "codemirror_mode": { 117 | "name": "ipython", 118 | "version": 3 119 | }, 120 | "file_extension": ".py", 121 | "mimetype": "text/x-python", 122 | "name": "python", 123 | "nbconvert_exporter": "python", 124 | "pygments_lexer": "ipython3", 125 | "version": "3.11.4" 126 | } 127 | }, 128 | "nbformat": 4, 129 | "nbformat_minor": 5 130 | } 131 | -------------------------------------------------------------------------------- /7.2_Preparing_supervised_instruction_finetuning.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "id": "58e54b19-ce2d-4497-90e0-4d954fc110d1", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import json\n", 11 | "import os\n", 12 | "import urllib" 13 | ] 14 | }, 15 | { 16 | "cell_type": "code", 17 | "execution_count": 2, 18 | "id": "be7f0a4d-92be-43a2-8df0-c09b25f7ad8e", 19 | "metadata": {}, 20 | "outputs": [ 21 | { 22 | "name": "stdout", 23 | "output_type": "stream", 24 | "text": [ 25 | "1100\n" 26 | ] 27 | } 28 | ], 29 | "source": [ 30 | "def download_and_load_file(file_path, url):\n", 31 | " if not os.path.exists(file_path):\n", 32 | " with 
urllib.request.urlopen(url) as response:\n", 33 | " text_data = response.read().decode(\"utf-8\")\n", 34 | "\n", 35 | " with open(file_path, \"w\", encoding=\"utf-8\") as fout:\n", 36 | " fout.write(text_data)\n", 37 | " #else:\n", 38 | " # with open(file_path, \"r\", encoding=\"utf-8\") as fin:\n", 39 | " # text_data = fin.read()\n", 40 | "\n", 41 | " with open(file_path, \"r\") as fin:\n", 42 | " data = json.load(fin)\n", 43 | "\n", 44 | " return data\n", 45 | "\n", 46 | "file_path = \"ch07/instruction-data.json\"\n", 47 | "url = \"https://raw.githubusercontent.com/rasbt/LLMs-from-scratch/main/ch07/01_main-chapter-code/instruction-data.json\"\n", 48 | "data = download_and_load_file(file_path, url)\n", 49 | "print(len(data))" 50 | ] 51 | }, 52 | { 53 | "cell_type": "code", 54 | "execution_count": 3, 55 | "id": "de02c3e8-e7ff-497e-b475-785dc504bea5", 56 | "metadata": {}, 57 | "outputs": [ 58 | { 59 | "name": "stdout", 60 | "output_type": "stream", 61 | "text": [ 62 | "{'instruction': 'Identify the correct spelling of the following word.', 'input': 'Ocassion', 'output': \"The correct spelling is 'Occasion.'\"}\n" 63 | ] 64 | } 65 | ], 66 | "source": [ 67 | "print(data[50])" 68 | ] 69 | }, 70 | { 71 | "cell_type": "code", 72 | "execution_count": 4, 73 | "id": "a60a19ae-c983-45ef-b58f-b2de0c2cccae", 74 | "metadata": {}, 75 | "outputs": [ 76 | { 77 | "name": "stdout", 78 | "output_type": "stream", 79 | "text": [ 80 | "{'instruction': \"What is an antonym of 'complicated'?\", 'input': '', 'output': \"An antonym of 'complicated' is 'simple'.\"}\n" 81 | ] 82 | } 83 | ], 84 | "source": [ 85 | "print(data[999])" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": 5, 91 | "id": "925085d3-5575-464f-959c-f72c7b29071d", 92 | "metadata": {}, 93 | "outputs": [ 94 | { 95 | "name": "stdout", 96 | "output_type": "stream", 97 | "text": [ 98 | "Bellow is an instruction that describe a task. 
Write a response that appropriately completes the request.\n", 99 | "\n", 100 | "### Instruction:\n", 101 | "Identify the correct spelling of the following word.\n", 102 | "\n", 103 | "### Input: \n", 104 | "Ocassion\n" 105 | ] 106 | } 107 | ], 108 | "source": [ 109 | "def format_input(entry):\n", 110 | " instruction_text = (\n", 111 | " f\"Bellow is an instruction that describe a task. \"\n", 112 | " f\"Write a response that appropriately completes the request.\"\n", 113 | " f\"\\n\\n### Instruction:\\n{entry['instruction']}\"\n", 114 | " )\n", 115 | "\n", 116 | " input_text = f\"\\n\\n### Input: \\n{entry['input']}\" if entry['input'] else \"\"\n", 117 | "\n", 118 | " return instruction_text + input_text\n", 119 | "\n", 120 | "model_input = format_input(data[50])\n", 121 | "print(model_input)" 122 | ] 123 | }, 124 | { 125 | "cell_type": "code", 126 | "execution_count": 6, 127 | "id": "630e1f7d-6a05-4ed2-aaa7-c9c9a8170f37", 128 | "metadata": {}, 129 | "outputs": [ 130 | { 131 | "name": "stdout", 132 | "output_type": "stream", 133 | "text": [ 134 | "Bellow is an instruction that describe a task. Write a response that appropriately completes the request.\n", 135 | "\n", 136 | "### Instruction:\n", 137 | "Identify the correct spelling of the following word.\n", 138 | "\n", 139 | "### Input: \n", 140 | "Ocassion\n", 141 | "\n", 142 | "### Response: \n", 143 | "The correct spelling is 'Occasion.'\n" 144 | ] 145 | } 146 | ], 147 | "source": [ 148 | "desired_output = f\"\\n\\n### Response: \\n{data[50]['output']}\"\n", 149 | "print(model_input + desired_output)" 150 | ] 151 | }, 152 | { 153 | "cell_type": "code", 154 | "execution_count": 7, 155 | "id": "6d3b8861-d298-4b00-9212-a95ea7013e9a", 156 | "metadata": {}, 157 | "outputs": [ 158 | { 159 | "name": "stdout", 160 | "output_type": "stream", 161 | "text": [ 162 | "Bellow is an instruction that describe a task. 
Write a response that appropriately completes the request.\n", 163 | "\n", 164 | "### Instruction:\n", 165 | "What is an antonym of 'complicated'?\n", 166 | "\n", 167 | "### Response: \n", 168 | "An antonym of 'complicated' is 'simple'.\n" 169 | ] 170 | } 171 | ], 172 | "source": [ 173 | "model_input = format_input(data[999])\n", 174 | "desired_output = f\"\\n\\n### Response: \\n{data[999]['output']}\"\n", 175 | "print(model_input + desired_output)" 176 | ] 177 | }, 178 | { 179 | "cell_type": "code", 180 | "execution_count": 8, 181 | "id": "d74e8190-a1f0-47b3-8d6d-6f8cade5e3c0", 182 | "metadata": {}, 183 | "outputs": [], 184 | "source": [ 185 | "train_portion = int(len(data) * 0.85)\n", 186 | "test_portion = int(len(data) * 0.1)\n", 187 | "val_portion = len(data) - train_portion - test_portion" 188 | ] 189 | }, 190 | { 191 | "cell_type": "code", 192 | "execution_count": 9, 193 | "id": "a988922e-3a5b-412e-92cc-a9e767ac5edc", 194 | "metadata": {}, 195 | "outputs": [ 196 | { 197 | "name": "stdout", 198 | "output_type": "stream", 199 | "text": [ 200 | "Training data: 935\n", 201 | "Test data: 110\n", 202 | "Validation data: 55\n" 203 | ] 204 | } 205 | ], 206 | "source": [ 207 | "train_data = data[:train_portion]\n", 208 | "test_data = data[train_portion:train_portion+test_portion]\n", 209 | "val_data = data[train_portion+test_portion:]\n", 210 | "\n", 211 | "print(f\"Training data: {len(train_data)}\")\n", 212 | "print(f\"Test data: {len(test_data)}\")\n", 213 | "print(f\"Validation data: {len(val_data)}\")" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": null, 219 | "id": "fc44edd5-bb9d-40c0-84d4-870bfa87259e", 220 | "metadata": {}, 221 | "outputs": [], 222 | "source": [] 223 | }, 224 | { 225 | "cell_type": "code", 226 | "execution_count": 10, 227 | "id": "3bf21372-287b-40a5-a4f1-4f74fcf33931", 228 | "metadata": {}, 229 | "outputs": [], 230 | "source": [ 231 | "import torch\n", 232 | "\n", 233 | "def custom_collate_fn(\n", 234 | " 
batch,\n", 235 | " pad_token_id=50526,\n", 236 | " ignore_index=-100,\n", 237 | " allowed_max_length=None,\n", 238 | " device=\"cpu\"\n", 239 | "):\n", 240 | " batch_max_length = max([len(item)+1 for item in batch])\n", 241 | " inputs_lst, targets_lst = [], []\n", 242 | "\n", 243 | " for item in batch:\n", 244 | " new_item = item.copy()\n", 245 | " new_item += [pad_token_id]\n", 246 | "\n", 247 | " padded = new_item + [pad_token_id] * (batch_max_length - len(new_item))\n", 248 | " inputs = torch.tensor(padded[:-1])\n", 249 | " targets = torch.tensor(padded[1:])\n", 250 | "\n", 251 | " mask = targets == pad_token_id\n", 252 | " indices = torch.nonzero(mask).squeeze()\n", 253 | "\n", 254 | " if indices.numel() > 1:\n", 255 | " targets[indices[1:]] = ignore_index\n", 256 | "\n", 257 | " if allowed_max_length is not None:\n", 258 | " inputs = inputs[:allowed_max_length]\n", 259 | " targets = targets[:allowed_max_length]\n", 260 | "\n", 261 | " inputs_lst.append(inputs)\n", 262 | " targets_lst.append(targets)\n", 263 | " inputs_tensor = torch.stack(inputs_lst).to(device)\n", 264 | " targets_tensor = torch.stack(targets_lst).to(device)\n", 265 | "\n", 266 | " return inputs_tensor, targets_tensor" 267 | ] 268 | }, 269 | { 270 | "cell_type": "code", 271 | "execution_count": 12, 272 | "id": "252da5ff-078d-4376-ba82-ed94e9d698c4", 273 | "metadata": {}, 274 | "outputs": [], 275 | "source": [ 276 | "from functools import partial\n", 277 | "device = torch.device(\"cpu\")\n", 278 | "\n", 279 | "customized_collate_fn = partial(\n", 280 | " custom_collate_fn,\n", 281 | " device = device,\n", 282 | " allowed_max_length=1024\n", 283 | ")" 284 | ] 285 | }, 286 | { 287 | "cell_type": "code", 288 | "execution_count": 13, 289 | "id": "e9ba7639-1564-4d01-a25d-308acdc2d883", 290 | "metadata": {}, 291 | "outputs": [], 292 | "source": [ 293 | "inputs_1 = [0, 1, 2, 3, 4]\n", 294 | "inputs_2 = [5, 6]\n", 295 | "inputs_3 = [7, 8, 9]\n", 296 | "batch = (inputs_1, inputs_2, inputs_3)" 297 | ] 
298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": 14, 302 | "id": "7f4315db-8b41-4ec5-bad7-ab2a474df0e8", 303 | "metadata": {}, 304 | "outputs": [], 305 | "source": [ 306 | "inputs, targets = custom_collate_fn(batch, device=\"mps\")" 307 | ] 308 | }, 309 | { 310 | "cell_type": "code", 311 | "execution_count": 22, 312 | "id": "03e364ea-3c98-428b-b96e-6078c4663760", 313 | "metadata": {}, 314 | "outputs": [ 315 | { 316 | "name": "stdout", 317 | "output_type": "stream", 318 | "text": [ 319 | "tensor([[ 0, 1, 2, 3, 4],\n", 320 | " [ 5, 6, 50526, 50526, 50526],\n", 321 | " [ 7, 8, 9, 50526, 50526]], device='mps:0')\n", 322 | "tensor([[ 1, 2, 3, 4, 50526],\n", 323 | " [ 6, 50526, -100, -100, -100],\n", 324 | " [ 8, 9, 50526, -100, -100]], device='mps:0')\n", 325 | "torch.Size([3, 5]) torch.Size([3, 5])\n" 326 | ] 327 | } 328 | ], 329 | "source": [ 330 | "print(inputs)\n", 331 | "print(targets)\n", 332 | "\n", 333 | "print(inputs.shape, targets.shape)" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": 16, 339 | "id": "2122987f-d99d-4e51-9cfc-6478fdc06207", 340 | "metadata": {}, 341 | "outputs": [], 342 | "source": [ 343 | "from torch.utils.data import Dataset\n", 344 | "\n", 345 | "class InstructionDataset(Dataset):\n", 346 | " def __init__(self, data, tokenizer):\n", 347 | " self.data = data\n", 348 | " self.encoded_texts = []\n", 349 | "\n", 350 | " for entry in data:\n", 351 | " instruction_plus_input = format_input(entry)\n", 352 | " response_text = f\"\\n\\n### Response:\\n{entry['output']}\"\n", 353 | " full_text = instruction_plus_input + response_text\n", 354 | " self.encoded_texts.append(\n", 355 | " tokenizer.encode(full_text)\n", 356 | " )\n", 357 | "\n", 358 | " def __len__(self):\n", 359 | " return len(self.encoded_texts)\n", 360 | "\n", 361 | " def __getitem__(self, index):\n", 362 | " return self.encoded_texts[index]" 363 | ] 364 | }, 365 | { 366 | "cell_type": "code", 367 | "execution_count": 17, 368 | "id": 
"a396fdd5-3c11-4b25-96ba-872a3b3dbd90", 369 | "metadata": {}, 370 | "outputs": [], 371 | "source": [ 372 | "from torch.utils.data import DataLoader\n", 373 | "import tiktoken\n", 374 | "\n", 375 | "num_workers = 0\n", 376 | "batch_size = 8\n", 377 | "\n", 378 | "torch.manual_seed(123)\n", 379 | "\n", 380 | "tokenizer = tiktoken.get_encoding(\"gpt2\")\n", 381 | "\n", 382 | "train_dataset = InstructionDataset(train_data, tokenizer)\n", 383 | "train_loader = DataLoader(\n", 384 | " train_dataset, \n", 385 | " batch_size=batch_size, \n", 386 | " collate_fn=customized_collate_fn,\n", 387 | " shuffle=True,\n", 388 | " drop_last=True,\n", 389 | " num_workers=num_workers\n", 390 | ")\n", 391 | "\n", 392 | "test_dataset = InstructionDataset(test_data, tokenizer)\n", 393 | "test_loader = DataLoader(\n", 394 | " test_dataset, \n", 395 | " batch_size=batch_size, \n", 396 | " collate_fn=customized_collate_fn,\n", 397 | " shuffle=False,\n", 398 | " drop_last=False,\n", 399 | " num_workers=num_workers\n", 400 | ")\n", 401 | "\n", 402 | "val_dataset = InstructionDataset(val_data, tokenizer)\n", 403 | "val_loader = DataLoader(\n", 404 | " val_dataset, \n", 405 | " batch_size=batch_size, \n", 406 | " collate_fn=customized_collate_fn,\n", 407 | " shuffle=False,\n", 408 | " drop_last=False,\n", 409 | " num_workers=num_workers\n", 410 | ")" 411 | ] 412 | }, 413 | { 414 | "cell_type": "code", 415 | "execution_count": 24, 416 | "id": "4f1ddb5e-012b-4998-8dcc-1d6b27188aa2", 417 | "metadata": {}, 418 | "outputs": [ 419 | { 420 | "name": "stdout", 421 | "output_type": "stream", 422 | "text": [ 423 | "torch.Size([8, 93]) torch.Size([8, 93])\n", 424 | "torch.Size([8, 72]) torch.Size([8, 72])\n", 425 | "torch.Size([8, 63]) torch.Size([8, 63])\n", 426 | "torch.Size([8, 72]) torch.Size([8, 72])\n", 427 | "torch.Size([8, 76]) torch.Size([8, 76])\n", 428 | "torch.Size([8, 68]) torch.Size([8, 68])\n", 429 | "torch.Size([8, 77]) torch.Size([8, 77])\n", 430 | "torch.Size([8, 75]) torch.Size([8, 
75])\n", 431 | "torch.Size([8, 66]) torch.Size([8, 66])\n", 432 | "torch.Size([8, 78]) torch.Size([8, 78])\n", 433 | "torch.Size([8, 63]) torch.Size([8, 63])\n", 434 | "torch.Size([8, 75]) torch.Size([8, 75])\n", 435 | "torch.Size([8, 65]) torch.Size([8, 65])\n", 436 | "torch.Size([8, 89]) torch.Size([8, 89])\n", 437 | "torch.Size([8, 74]) torch.Size([8, 74])\n", 438 | "torch.Size([8, 77]) torch.Size([8, 77])\n", 439 | "torch.Size([8, 75]) torch.Size([8, 75])\n", 440 | "torch.Size([8, 90]) torch.Size([8, 90])\n", 441 | "torch.Size([8, 79]) torch.Size([8, 79])\n", 442 | "torch.Size([8, 70]) torch.Size([8, 70])\n", 443 | "torch.Size([8, 84]) torch.Size([8, 84])\n", 444 | "torch.Size([8, 73]) torch.Size([8, 73])\n", 445 | "torch.Size([8, 60]) torch.Size([8, 60])\n", 446 | "torch.Size([8, 65]) torch.Size([8, 65])\n", 447 | "torch.Size([8, 74]) torch.Size([8, 74])\n", 448 | "torch.Size([8, 68]) torch.Size([8, 68])\n", 449 | "torch.Size([8, 62]) torch.Size([8, 62])\n", 450 | "torch.Size([8, 81]) torch.Size([8, 81])\n", 451 | "torch.Size([8, 65]) torch.Size([8, 65])\n", 452 | "torch.Size([8, 67]) torch.Size([8, 67])\n", 453 | "torch.Size([8, 67]) torch.Size([8, 67])\n", 454 | "torch.Size([8, 72]) torch.Size([8, 72])\n", 455 | "torch.Size([8, 70]) torch.Size([8, 70])\n", 456 | "torch.Size([8, 68]) torch.Size([8, 68])\n", 457 | "torch.Size([8, 92]) torch.Size([8, 92])\n", 458 | "torch.Size([8, 72]) torch.Size([8, 72])\n", 459 | "torch.Size([8, 85]) torch.Size([8, 85])\n", 460 | "torch.Size([8, 62]) torch.Size([8, 62])\n", 461 | "torch.Size([8, 66]) torch.Size([8, 66])\n", 462 | "torch.Size([8, 66]) torch.Size([8, 66])\n", 463 | "torch.Size([8, 78]) torch.Size([8, 78])\n", 464 | "torch.Size([8, 77]) torch.Size([8, 77])\n", 465 | "torch.Size([8, 93]) torch.Size([8, 93])\n", 466 | "torch.Size([8, 66]) torch.Size([8, 66])\n", 467 | "torch.Size([8, 64]) torch.Size([8, 64])\n", 468 | "torch.Size([8, 67]) torch.Size([8, 67])\n", 469 | "torch.Size([8, 68]) torch.Size([8, 68])\n", 
470 | "torch.Size([8, 69]) torch.Size([8, 69])\n", 471 | "torch.Size([8, 71]) torch.Size([8, 71])\n", 472 | "torch.Size([8, 61]) torch.Size([8, 61])\n", 473 | "torch.Size([8, 63]) torch.Size([8, 63])\n", 474 | "torch.Size([8, 64]) torch.Size([8, 64])\n", 475 | "torch.Size([8, 77]) torch.Size([8, 77])\n", 476 | "torch.Size([8, 64]) torch.Size([8, 64])\n", 477 | "torch.Size([8, 71]) torch.Size([8, 71])\n", 478 | "torch.Size([8, 69]) torch.Size([8, 69])\n", 479 | "torch.Size([8, 71]) torch.Size([8, 71])\n", 480 | "torch.Size([8, 58]) torch.Size([8, 58])\n", 481 | "torch.Size([8, 67]) torch.Size([8, 67])\n", 482 | "torch.Size([8, 81]) torch.Size([8, 81])\n", 483 | "torch.Size([8, 67]) torch.Size([8, 67])\n", 484 | "torch.Size([8, 63]) torch.Size([8, 63])\n", 485 | "torch.Size([8, 64]) torch.Size([8, 64])\n", 486 | "torch.Size([8, 67]) torch.Size([8, 67])\n", 487 | "torch.Size([8, 64]) torch.Size([8, 64])\n", 488 | "torch.Size([8, 68]) torch.Size([8, 68])\n", 489 | "torch.Size([8, 76]) torch.Size([8, 76])\n", 490 | "torch.Size([8, 70]) torch.Size([8, 70])\n", 491 | "torch.Size([8, 63]) torch.Size([8, 63])\n", 492 | "torch.Size([8, 58]) torch.Size([8, 58])\n", 493 | "torch.Size([8, 73]) torch.Size([8, 73])\n", 494 | "torch.Size([8, 67]) torch.Size([8, 67])\n", 495 | "torch.Size([8, 58]) torch.Size([8, 58])\n", 496 | "torch.Size([8, 71]) torch.Size([8, 71])\n", 497 | "torch.Size([8, 55]) torch.Size([8, 55])\n", 498 | "torch.Size([8, 68]) torch.Size([8, 68])\n", 499 | "torch.Size([8, 72]) torch.Size([8, 72])\n", 500 | "torch.Size([8, 66]) torch.Size([8, 66])\n", 501 | "torch.Size([8, 58]) torch.Size([8, 58])\n", 502 | "torch.Size([8, 72]) torch.Size([8, 72])\n", 503 | "torch.Size([8, 72]) torch.Size([8, 72])\n", 504 | "torch.Size([8, 66]) torch.Size([8, 66])\n", 505 | "torch.Size([8, 85]) torch.Size([8, 85])\n", 506 | "torch.Size([8, 69]) torch.Size([8, 69])\n", 507 | "torch.Size([8, 84]) torch.Size([8, 84])\n", 508 | "torch.Size([8, 72]) torch.Size([8, 72])\n", 509 | 
"torch.Size([8, 76]) torch.Size([8, 76])\n", 510 | "torch.Size([8, 80]) torch.Size([8, 80])\n", 511 | "torch.Size([8, 84]) torch.Size([8, 84])\n", 512 | "torch.Size([8, 73]) torch.Size([8, 73])\n", 513 | "torch.Size([8, 81]) torch.Size([8, 81])\n", 514 | "torch.Size([8, 72]) torch.Size([8, 72])\n", 515 | "torch.Size([8, 88]) torch.Size([8, 88])\n", 516 | "torch.Size([8, 62]) torch.Size([8, 62])\n", 517 | "torch.Size([8, 69]) torch.Size([8, 69])\n", 518 | "torch.Size([8, 58]) torch.Size([8, 58])\n", 519 | "torch.Size([8, 80]) torch.Size([8, 80])\n", 520 | "torch.Size([8, 69]) torch.Size([8, 69])\n", 521 | "torch.Size([8, 72]) torch.Size([8, 72])\n", 522 | "torch.Size([8, 61]) torch.Size([8, 61])\n", 523 | "torch.Size([8, 74]) torch.Size([8, 74])\n", 524 | "torch.Size([8, 72]) torch.Size([8, 72])\n", 525 | "torch.Size([8, 82]) torch.Size([8, 82])\n", 526 | "torch.Size([8, 84]) torch.Size([8, 84])\n", 527 | "torch.Size([8, 75]) torch.Size([8, 75])\n", 528 | "torch.Size([8, 71]) torch.Size([8, 71])\n", 529 | "torch.Size([8, 72]) torch.Size([8, 72])\n", 530 | "torch.Size([8, 78]) torch.Size([8, 78])\n", 531 | "torch.Size([8, 67]) torch.Size([8, 67])\n", 532 | "torch.Size([8, 84]) torch.Size([8, 84])\n", 533 | "torch.Size([8, 72]) torch.Size([8, 72])\n", 534 | "torch.Size([8, 62]) torch.Size([8, 62])\n", 535 | "torch.Size([8, 81]) torch.Size([8, 81])\n", 536 | "torch.Size([8, 66]) torch.Size([8, 66])\n", 537 | "torch.Size([8, 66]) torch.Size([8, 66])\n", 538 | "torch.Size([8, 68]) torch.Size([8, 68])\n" 539 | ] 540 | } 541 | ], 542 | "source": [ 543 | "for inputs, targets in train_loader:\n", 544 | " print(inputs.shape, targets.shape)" 545 | ] 546 | }, 547 | { 548 | "cell_type": "code", 549 | "execution_count": null, 550 | "id": "0ee60fed-2944-4729-84a8-17d8c43e6e64", 551 | "metadata": {}, 552 | "outputs": [], 553 | "source": [] 554 | } 555 | ], 556 | "metadata": { 557 | "kernelspec": { 558 | "display_name": "Python 3 (ipykernel)", 559 | "language": "python", 560 | 
"name": "python3" 561 | }, 562 | "language_info": { 563 | "codemirror_mode": { 564 | "name": "ipython", 565 | "version": 3 566 | }, 567 | "file_extension": ".py", 568 | "mimetype": "text/x-python", 569 | "name": "python", 570 | "nbconvert_exporter": "python", 571 | "pygments_lexer": "ipython3", 572 | "version": "3.11.4" 573 | } 574 | }, 575 | "nbformat": 4, 576 | "nbformat_minor": 5 577 | } 578 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LLMs_from_scratch 2 | Learning records for building a large language model from scratch 3 | 4 | - Book: [Build a Large Language Model (From Scratch)](https://www.manning.com/books/build-a-large-language-model-from-scratch?utm_source=raschka&utm_medium=affiliate&utm_campaign=book_raschka_build_12_12_23&a_aid=raschka&a_bid=4c2437a0&chan=mm_github) 5 | - Github repo: [rasbt/LLMs-from-scratch](https://github.com/rasbt/LLMs-from-scratch) 6 | 7 | I implemented the codes in the fantastic book. All the codes are run with a Macbook Pro (2020 version). Thus, I can use only two types of devices: `cpu` and `mps`. Perhaps, the Macbook is a little old that is equipped with M1 chip. The training speed of the`mps` mode is even slower than the `cpu` mode. 
8 | 9 | 10 | ### Records 11 | 12 | - **Chapter 2**: 13 | 14 | - [x] Understanding word embeddings 15 | 16 | - [x] Tokenizing text 17 | 18 | - [x] Converting tokens into token IDs 19 | 20 | - [x] Adding special context tokens 21 | 22 | - [x] Byte pair encoding 23 | 24 | - [x] Data sampling with a sliding window 25 | 26 | - [x] Creating token embeddings 27 | 28 | - [x] Encoding word positions 29 | 30 | - **Chapter 3**: 31 | 32 | - [x] Capturing data dependencies with attention mechanisms 33 | 34 | - [x] Implementing self-attention with trainable weights 35 | 36 | - [x] Hiding future words with causal attention 37 | 38 | - [x] Multi-head attention 39 | 40 | - **Chapter 4**: 41 | 42 | - [x] Activations and layer normalization 43 | 44 | - [x] Adding shortcut connections 45 | 46 | - [x] Building a Transformer block 47 | 48 | - [x] Coding the GPT model 49 | 50 | - [x] Generating text 51 | 52 | - **Chapter 5**: 53 | 54 | - [x] Evaluating generative text models 55 | 56 | - [x] Training an LLM 57 | 58 | ``` 59 | $ python train_llms_from_scratch.py 60 | ``` 61 | 62 | ![](ch05/train_plot.png) 63 | 64 | - [x] Greedy search and Top-k sampling 65 | 66 | - [x] Load pretrained weights from OpenAI 67 | 68 | **PS**: When I downloaded the source model data from OpenAI, the download was frequently interrupted, so I retried several times and finally collected both the `small` and `media` models. I have uploaded these models to Baidu Cloud for your convenience.
69 | 70 | 71 | |Model Size|OpenAI Sources|Converted (PyTorch version)| 72 | |:-----:|:-----:|:-----:| 73 | |`small`|[Baidu Cloud](https://pan.baidu.com/s/1BMpqgnkceMsNYGqOzNybxA?pwd=d3wu) (psw: d3wu)| [Baidu Cloud](https://pan.baidu.com/s/1_oL4DSRfWg6wBmSJ6vDISA?pwd=r4hq) (psw: r4hq)| 74 | |`media`|[Baidu Cloud](https://pan.baidu.com/s/1Ih1A0UQPUsAOdwT0eoGmhw?pwd=qqqj) (psw: qqqj) | [Baidu Cloud](https://pan.baidu.com/s/1n_2WndBnEviIhO3X6MShCg?pwd=8whr) (psw: 8whr)| 75 | 76 | - [x] [Convert GPT to Llama2](./ch05/gpt-to-llama2.ipynb) (Finished: RMSNorm, SiLU, SwiGLU; TODO: RoPE and TransformerBlock) 77 | 78 | - [ ] Convert GPT to Llama3 79 | 80 | 81 | - **Chapter 6**: 82 | 83 | - [x] Prepare spam email dataset and dataloader 84 | 85 | - [x] Fine-tune the model on supervised data 86 | 87 | - [x] Use the LLM as a spam classifier 88 | 89 | - **Chapter 7:** 90 | 91 | - [x] Prepare a dataset for supervised instruction fine-tuning 92 | 93 | - [x] Organize data into training batches 94 | 95 | - [x] Finetune the LLM on instruction data 96 | 97 | **PS**: Training the `gpt2-media (355M)` model is challenging on my machine, so I still use the lightweight `gpt2-small (124M)` model. It is therefore no surprise that the predictions of the finetuned model are poor. 98 | 99 | - [ ] Preference tuning with DPO 100 | 101 | -------------------------------------------------------------------------------- /ch02/the-verdict.txt: -------------------------------------------------------------------------------- 1 | I HAD always thought Jack Gisburn rather a cheap genius--though a good fellow enough--so it was no great surprise to me to hear that, in the height of his glory, he had dropped his painting, married a rich widow, and established himself in a villa on the Riviera. (Though I rather thought it would have been Rome or Florence.) 2 | 3 | "The height of his glory"--that was what the women called it. I can hear Mrs.
Gideon Thwing--his last Chicago sitter--deploring his unaccountable abdication. "Of course it's going to send the value of my picture 'way up; but I don't think of that, Mr. Rickham--the loss to Arrt is all I think of." The word, on Mrs. Thwing's lips, multiplied its _rs_ as though they were reflected in an endless vista of mirrors. And it was not only the Mrs. Thwings who mourned. Had not the exquisite Hermia Croft, at the last Grafton Gallery show, stopped me before Gisburn's "Moon-dancers" to say, with tears in her eyes: "We shall not look upon its like again"? 4 | 5 | Well!--even through the prism of Hermia's tears I felt able to face the fact with equanimity. Poor Jack Gisburn! The women had made him--it was fitting that they should mourn him. Among his own sex fewer regrets were heard, and in his own trade hardly a murmur. Professional jealousy? Perhaps. If it were, the honour of the craft was vindicated by little Claude Nutley, who, in all good faith, brought out in the Burlington a very handsome "obituary" on Jack--one of those showy articles stocked with random technicalities that I have heard (I won't say by whom) compared to Gisburn's painting. And so--his resolve being apparently irrevocable--the discussion gradually died out, and, as Mrs. Thwing had predicted, the price of "Gisburns" went up. 6 | 7 | It was not till three years later that, in the course of a few weeks' idling on the Riviera, it suddenly occurred to me to wonder why Gisburn had given up his painting. On reflection, it really was a tempting problem. To accuse his wife would have been too easy--his fair sitters had been denied the solace of saying that Mrs. Gisburn had "dragged him down." For Mrs. Gisburn--as such--had not existed till nearly a year after Jack's resolve had been taken. It might be that he had married her--since he liked his ease--because he didn't want to go on painting; but it would have been hard to prove that he had given up his painting because he had married her. 
8 | 9 | Of course, if she had not dragged him down, she had equally, as Miss Croft contended, failed to "lift him up"--she had not led him back to the easel. To put the brush into his hand again--what a vocation for a wife! But Mrs. Gisburn appeared to have disdained it--and I felt it might be interesting to find out why. 10 | 11 | The desultory life of the Riviera lends itself to such purely academic speculations; and having, on my way to Monte Carlo, caught a glimpse of Jack's balustraded terraces between the pines, I had myself borne thither the next day. 12 | 13 | I found the couple at tea beneath their palm-trees; and Mrs. Gisburn's welcome was so genial that, in the ensuing weeks, I claimed it frequently. It was not that my hostess was "interesting": on that point I could have given Miss Croft the fullest reassurance. It was just because she was _not_ interesting--if I may be pardoned the bull--that I found her so. For Jack, all his life, had been surrounded by interesting women: they had fostered his art, it had been reared in the hot-house of their adulation. And it was therefore instructive to note what effect the "deadening atmosphere of mediocrity" (I quote Miss Croft) was having on him. 14 | 15 | I have mentioned that Mrs. Gisburn was rich; and it was immediately perceptible that her husband was extracting from this circumstance a delicate but substantial satisfaction. It is, as a rule, the people who scorn money who get most out of it; and Jack's elegant disdain of his wife's big balance enabled him, with an appearance of perfect good-breeding, to transmute it into objects of art and luxury. To the latter, I must add, he remained relatively indifferent; but he was buying Renaissance bronzes and eighteenth-century pictures with a discrimination that bespoke the amplest resources. 
16 | 17 | "Money's only excuse is to put beauty into circulation," was one of the axioms he laid down across the Sevres and silver of an exquisitely appointed luncheon-table, when, on a later day, I had again run over from Monte Carlo; and Mrs. Gisburn, beaming on him, added for my enlightenment: "Jack is so morbidly sensitive to every form of beauty." 18 | 19 | Poor Jack! It had always been his fate to have women say such things of him: the fact should be set down in extenuation. What struck me now was that, for the first time, he resented the tone. I had seen him, so often, basking under similar tributes--was it the conjugal note that robbed them of their savour? No--for, oddly enough, it became apparent that he was fond of Mrs. Gisburn--fond enough not to see her absurdity. It was his own absurdity he seemed to be wincing under--his own attitude as an object for garlands and incense. 20 | 21 | "My dear, since I've chucked painting people don't say that stuff about me--they say it about Victor Grindle," was his only protest, as he rose from the table and strolled out onto the sunlit terrace. 22 | 23 | I glanced after him, struck by his last word. Victor Grindle was, in fact, becoming the man of the moment--as Jack himself, one might put it, had been the man of the hour. The younger artist was said to have formed himself at my friend's feet, and I wondered if a tinge of jealousy underlay the latter's mysterious abdication. But no--for it was not till after that event that the _rose Dubarry_ drawing-rooms had begun to display their "Grindles." 24 | 25 | I turned to Mrs. Gisburn, who had lingered to give a lump of sugar to her spaniel in the dining-room. 26 | 27 | "Why _has_ he chucked painting?" I asked abruptly. 28 | 29 | She raised her eyebrows with a hint of good-humoured surprise. 30 | 31 | "Oh, he doesn't _have_ to now, you know; and I want him to enjoy himself," she said quite simply. 
32 | 33 | I looked about the spacious white-panelled room, with its _famille-verte_ vases repeating the tones of the pale damask curtains, and its eighteenth-century pastels in delicate faded frames. 34 | 35 | "Has he chucked his pictures too? I haven't seen a single one in the house." 36 | 37 | A slight shade of constraint crossed Mrs. Gisburn's open countenance. "It's his ridiculous modesty, you know. He says they're not fit to have about; he's sent them all away except one--my portrait--and that I have to keep upstairs." 38 | 39 | His ridiculous modesty--Jack's modesty about his pictures? My curiosity was growing like the bean-stalk. I said persuasively to my hostess: "I must really see your portrait, you know." 40 | 41 | She glanced out almost timorously at the terrace where her husband, lounging in a hooded chair, had lit a cigar and drawn the Russian deerhound's head between his knees. 42 | 43 | "Well, come while he's not looking," she said, with a laugh that tried to hide her nervousness; and I followed her between the marble Emperors of the hall, and up the wide stairs with terra-cotta nymphs poised among flowers at each landing. 44 | 45 | In the dimmest corner of her boudoir, amid a profusion of delicate and distinguished objects, hung one of the familiar oval canvases, in the inevitable garlanded frame. The mere outline of the frame called up all Gisburn's past! 46 | 47 | Mrs. Gisburn drew back the window-curtains, moved aside a _jardiniere_ full of pink azaleas, pushed an arm-chair away, and said: "If you stand here you can just manage to see it. I had it over the mantel-piece, but he wouldn't let it stay." 48 | 49 | Yes--I could just manage to see it--the first portrait of Jack's I had ever had to strain my eyes over! Usually they had the place of honour--say the central panel in a pale yellow or _rose Dubarry_ drawing-room, or a monumental easel placed so that it took the light through curtains of old Venetian point. 
The more modest place became the picture better; yet, as my eyes grew accustomed to the half-light, all the characteristic qualities came out--all the hesitations disguised as audacities, the tricks of prestidigitation by which, with such consummate skill, he managed to divert attention from the real business of the picture to some pretty irrelevance of detail. Mrs. Gisburn, presenting a neutral surface to work on--forming, as it were, so inevitably the background of her own picture--had lent herself in an unusual degree to the display of this false virtuosity. The picture was one of Jack's "strongest," as his admirers would have put it--it represented, on his part, a swelling of muscles, a congesting of veins, a balancing, straddling and straining, that reminded one of the circus-clown's ironic efforts to lift a feather. It met, in short, at every point the demand of lovely woman to be painted "strongly" because she was tired of being painted "sweetly"--and yet not to lose an atom of the sweetness. 50 | 51 | "It's the last he painted, you know," Mrs. Gisburn said with pardonable pride. "The last but one," she corrected herself--"but the other doesn't count, because he destroyed it." 52 | 53 | "Destroyed it?" I was about to follow up this clue when I heard a footstep and saw Jack himself on the threshold. 54 | 55 | As he stood there, his hands in the pockets of his velveteen coat, the thin brown waves of hair pushed back from his white forehead, his lean sunburnt cheeks furrowed by a smile that lifted the tips of a self-confident moustache, I felt to what a degree he had the same quality as his pictures--the quality of looking cleverer than he was. 56 | 57 | His wife glanced at him deprecatingly, but his eyes travelled past her to the portrait. 58 | 59 | "Mr. Rickham wanted to see it," she began, as if excusing herself. He shrugged his shoulders, still smiling. 
60 | 61 | "Oh, Rickham found me out long ago," he said lightly; then, passing his arm through mine: "Come and see the rest of the house." 62 | 63 | He showed it to me with a kind of naive suburban pride: the bath-rooms, the speaking-tubes, the dress-closets, the trouser-presses--all the complex simplifications of the millionaire's domestic economy. And whenever my wonder paid the expected tribute he said, throwing out his chest a little: "Yes, I really don't see how people manage to live without that." 64 | 65 | Well--it was just the end one might have foreseen for him. Only he was, through it all and in spite of it all--as he had been through, and in spite of, his pictures--so handsome, so charming, so disarming, that one longed to cry out: "Be dissatisfied with your leisure!" as once one had longed to say: "Be dissatisfied with your work!" 66 | 67 | But, with the cry on my lips, my diagnosis suffered an unexpected check. 68 | 69 | "This is my own lair," he said, leading me into a dark plain room at the end of the florid vista. It was square and brown and leathery: no "effects"; no bric-a-brac, none of the air of posing for reproduction in a picture weekly--above all, no least sign of ever having been used as a studio. 70 | 71 | The fact brought home to me the absolute finality of Jack's break with his old life. 72 | 73 | "Don't you ever dabble with paint any more?" I asked, still looking about for a trace of such activity. 74 | 75 | "Never," he said briefly. 76 | 77 | "Or water-colour--or etching?" 78 | 79 | His confident eyes grew dim, and his cheeks paled a little under their handsome sunburn. 80 | 81 | "Never think of it, my dear fellow--any more than if I'd never touched a brush." 82 | 83 | And his tone told me in a flash that he never thought of anything else. 
84 | 85 | I moved away, instinctively embarrassed by my unexpected discovery; and as I turned, my eye fell on a small picture above the mantel-piece--the only object breaking the plain oak panelling of the room. 86 | 87 | "Oh, by Jove!" I said. 88 | 89 | It was a sketch of a donkey--an old tired donkey, standing in the rain under a wall. 90 | 91 | "By Jove--a Stroud!" I cried. 92 | 93 | He was silent; but I felt him close behind me, breathing a little quickly. 94 | 95 | "What a wonder! Made with a dozen lines--but on everlasting foundations. You lucky chap, where did you get it?" 96 | 97 | He answered slowly: "Mrs. Stroud gave it to me." 98 | 99 | "Ah--I didn't know you even knew the Strouds. He was such an inflexible hermit." 100 | 101 | "I didn't--till after. . . . She sent for me to paint him when he was dead." 102 | 103 | "When he was dead? You?" 104 | 105 | I must have let a little too much amazement escape through my surprise, for he answered with a deprecating laugh: "Yes--she's an awful simpleton, you know, Mrs. Stroud. Her only idea was to have him done by a fashionable painter--ah, poor Stroud! She thought it the surest way of proclaiming his greatness--of forcing it on a purblind public. And at the moment I was _the_ fashionable painter." 106 | 107 | "Ah, poor Stroud--as you say. Was _that_ his history?" 108 | 109 | "That was his history. She believed in him, gloried in him--or thought she did. But she couldn't bear not to have all the drawing-rooms with her. She couldn't bear the fact that, on varnishing days, one could always get near enough to see his pictures. Poor woman! She's just a fragment groping for other fragments. Stroud is the only whole I ever knew." 110 | 111 | "You ever knew? But you just said--" 112 | 113 | Gisburn had a curious smile in his eyes. 114 | 115 | "Oh, I knew him, and he knew me--only it happened after he was dead." 116 | 117 | I dropped my voice instinctively. "When she sent for you?" 
118 | 119 | "Yes--quite insensible to the irony. She wanted him vindicated--and by me!" 120 | 121 | He laughed again, and threw back his head to look up at the sketch of the donkey. "There were days when I couldn't look at that thing--couldn't face it. But I forced myself to put it here; and now it's cured me--cured me. That's the reason why I don't dabble any more, my dear Rickham; or rather Stroud himself is the reason." 122 | 123 | For the first time my idle curiosity about my companion turned into a serious desire to understand him better. 124 | 125 | "I wish you'd tell me how it happened," I said. 126 | 127 | He stood looking up at the sketch, and twirling between his fingers a cigarette he had forgotten to light. Suddenly he turned toward me. 128 | 129 | "I'd rather like to tell you--because I've always suspected you of loathing my work." 130 | 131 | I made a deprecating gesture, which he negatived with a good-humoured shrug. 132 | 133 | "Oh, I didn't care a straw when I believed in myself--and now it's an added tie between us!" 134 | 135 | He laughed slightly, without bitterness, and pushed one of the deep arm-chairs forward. "There: make yourself comfortable--and here are the cigars you like." 136 | 137 | He placed them at my elbow and continued to wander up and down the room, stopping now and then beneath the picture. 138 | 139 | "How it happened? I can tell you in five minutes--and it didn't take much longer to happen. . . . I can remember now how surprised and pleased I was when I got Mrs. Stroud's note. Of course, deep down, I had always _felt_ there was no one like him--only I had gone with the stream, echoed the usual platitudes about him, till I half got to think he was a failure, one of the kind that are left behind. By Jove, and he _was_ left behind--because he had come to stay! The rest of us had to let ourselves be swept along or go under, but he was high above the current--on everlasting foundations, as you say. 
140 | 141 | "Well, I went off to the house in my most egregious mood--rather moved, Lord forgive me, at the pathos of poor Stroud's career of failure being crowned by the glory of my painting him! Of course I meant to do the picture for nothing--I told Mrs. Stroud so when she began to stammer something about her poverty. I remember getting off a prodigious phrase about the honour being _mine_--oh, I was princely, my dear Rickham! I was posing to myself like one of my own sitters. 142 | 143 | "Then I was taken up and left alone with him. I had sent all my traps in advance, and I had only to set up the easel and get to work. He had been dead only twenty-four hours, and he died suddenly, of heart disease, so that there had been no preliminary work of destruction--his face was clear and untouched. I had met him once or twice, years before, and thought him insignificant and dingy. Now I saw that he was superb. 144 | 145 | "I was glad at first, with a merely aesthetic satisfaction: glad to have my hand on such a 'subject.' Then his strange life-likeness began to affect me queerly--as I blocked the head in I felt as if he were watching me do it. The sensation was followed by the thought: if he _were_ watching me, what would he say to my way of working? My strokes began to go a little wild--I felt nervous and uncertain. 146 | 147 | "Once, when I looked up, I seemed to see a smile behind his close grayish beard--as if he had the secret, and were amusing himself by holding it back from me. That exasperated me still more. The secret? Why, I had a secret worth twenty of his! I dashed at the canvas furiously, and tried some of my bravura tricks. But they failed me, they crumbled. I saw that he wasn't watching the showy bits--I couldn't distract his attention; he just kept his eyes on the hard passages between. Those were the ones I had always shirked, or covered up with some lying paint. And how he saw through my lies! 
148 | 149 | "I looked up again, and caught sight of that sketch of the donkey hanging on the wall near his bed. His wife told me afterward it was the last thing he had done--just a note taken with a shaking hand, when he was down in Devonshire recovering from a previous heart attack. Just a note! But it tells his whole history. There are years of patient scornful persistence in every line. A man who had swum with the current could never have learned that mighty up-stream stroke. . . . 150 | 151 | "I turned back to my work, and went on groping and muddling; then I looked at the donkey again. I saw that, when Stroud laid in the first stroke, he knew just what the end would be. He had possessed his subject, absorbed it, recreated it. When had I done that with any of my things? They hadn't been born of me--I had just adopted them. . . . 152 | 153 | "Hang it, Rickham, with that face watching me I couldn't do another stroke. The plain truth was, I didn't know where to put it--_I had never known_. Only, with my sitters and my public, a showy splash of colour covered up the fact--I just threw paint into their faces. . . . Well, paint was the one medium those dead eyes could see through--see straight to the tottering foundations underneath. Don't you know how, in talking a foreign language, even fluently, one says half the time not what one wants to but what one can? Well--that was the way I painted; and as he lay there and watched me, the thing they called my 'technique' collapsed like a house of cards. He didn't sneer, you understand, poor Stroud--he just lay there quietly watching, and on his lips, through the gray beard, I seemed to hear the question: 'Are you sure you know where you're coming out?' 154 | 155 | "If I could have painted that face, with that question on it, I should have done a great thing. The next greatest thing was to see that I couldn't--and that grace was given me. 
But, oh, at that minute, Rickham, was there anything on earth I wouldn't have given to have Stroud alive before me, and to hear him say: 'It's not too late--I'll show you how'? 156 | 157 | "It _was_ too late--it would have been, even if he'd been alive. I packed up my traps, and went down and told Mrs. Stroud. Of course I didn't tell her _that_--it would have been Greek to her. I simply said I couldn't paint him, that I was too moved. She rather liked the idea--she's so romantic! It was that that made her give me the donkey. But she was terribly upset at not getting the portrait--she did so want him 'done' by some one showy! At first I was afraid she wouldn't let me off--and at my wits' end I suggested Grindle. Yes, it was I who started Grindle: I told Mrs. Stroud he was the 'coming' man, and she told somebody else, and so it got to be true. . . . And he painted Stroud without wincing; and she hung the picture among her husband's things. . . ." 158 | 159 | He flung himself down in the arm-chair near mine, laid back his head, and clasping his arms beneath it, looked up at the picture above the chimney-piece. 160 | 161 | "I like to fancy that Stroud himself would have given it to me, if he'd been able to say what he thought that day." 162 | 163 | And, in answer to a question I put half-mechanically--"Begin again?" he flashed out. "When the one thing that brings me anywhere near him is that I knew enough to leave off?" 164 | 165 | He stood up and laid his hand on my shoulder with a laugh. "Only the irony of it is that I _am_ still painting--since Grindle's doing it for me! The Strouds stand alone, and happen once--but there's no exterminating our kind of art." 
# -------------------------------------------------------------------------------- /ch05/download_model.py: --------------------------------------------------------------------------------
"""Download the original OpenAI GPT-2 TensorFlow checkpoint files.

Usage::

    python download_model.py 124M

Files are written to ``models/<model>/`` relative to the current working
directory and fetched from the same relative path on OpenAI's blob storage.
"""
import os
import sys

import requests
from tqdm import tqdm

# Remote root of the released GPT-2 checkpoints.
BASE_URL = "https://openaipublic.blob.core.windows.net/gpt-2/"

# Files that together make up one released checkpoint.
FILENAMES = [
    'checkpoint', 'encoder.json', 'hparams.json',
    'model.ckpt.data-00000-of-00001', 'model.ckpt.index',
    'model.ckpt.meta', 'vocab.bpe',
]

# 1k for chunk_size, since Ethernet packet size is around 1500 bytes
CHUNK_SIZE = 1000


def download(url, destination):
    """Stream *url* into the local file *destination* with a progress bar.

    Raises:
        requests.HTTPError: on a non-2xx response. The check happens before
            the destination is opened, so an HTML error page is never saved
            in place of a checkpoint file.
    """
    # The context manager releases the pooled connection even on error.
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        # Content-Length may be missing; tqdm then shows an open-ended bar.
        file_size = int(r.headers.get("content-length", 0)) or None
        desc = "Fetching " + os.path.basename(destination)
        with open(destination, 'wb') as f, \
                tqdm(ncols=100, desc=desc, total=file_size,
                     unit="B", unit_scale=True) as pbar:
            for chunk in r.iter_content(chunk_size=CHUNK_SIZE):
                f.write(chunk)
                # Use the real length: the last chunk is usually shorter.
                pbar.update(len(chunk))


def main(argv=None):
    """Parse the model name from *argv* (default ``sys.argv``) and download it.

    Exits with status 1 and a usage hint when the model name is missing.
    """
    argv = sys.argv if argv is None else argv
    if len(argv) != 2:
        print('You must enter the model name as a parameter, e.g.: download_model.py 124M')
        sys.exit(1)

    model = argv[1]

    subdir = os.path.join('models', model)
    os.makedirs(subdir, exist_ok=True)
    subdir = subdir.replace('\\', '/')  # needed for Windows (also used in the URL)

    for filename in FILENAMES:
        download(BASE_URL + subdir + "/" + filename, os.path.join(subdir, filename))


if __name__ == "__main__":
    main()
# -------------------------------------------------------------------------------- /ch05/gpt-to-llama.pdf: --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/yhlleo/LLMs_from_scratch/14ddb8817640e9437d1cd2f1a7cf18cf0f52f799/ch05/gpt-to-llama.pdf
# -------------------------------------------------------------------------------- /ch05/gpt_download.py: --------------------------------------------------------------------------------
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
"""Download OpenAI's GPT-2 checkpoints and load them into nested NumPy dicts."""

import json
import os
import urllib.error
import urllib.request

# import requests  # only needed by the alternative `download_file` below
import numpy as np
import tensorflow as tf
from tqdm import tqdm


def download_and_load_gpt2(model_size, models_dir):
    """Download (if needed) and load an OpenAI GPT-2 checkpoint.

    Args:
        model_size: One of "124M", "355M", "774M", "1558M".
        models_dir: Local root directory; files go to
            ``<models_dir>/<model_size>/``.

    Returns:
        ``(settings, params)``: the parsed ``hparams.json`` dict and the
        nested parameter dict from ``load_gpt2_params_from_tf_ckpt``.

    Raises:
        ValueError: if *model_size* is not one of the released sizes.
        FileNotFoundError: if no TensorFlow checkpoint is present after the
            download step (e.g. a file failed to download).
    """
    # Validate model size
    allowed_sizes = ("124M", "355M", "774M", "1558M")
    if model_size not in allowed_sizes:
        raise ValueError(f"Model size not in {allowed_sizes}")

    # Define paths
    model_dir = os.path.join(models_dir, model_size)
    base_url = "https://openaipublic.blob.core.windows.net/gpt-2/models"
    filenames = [
        "checkpoint", "encoder.json", "hparams.json",
        "model.ckpt.data-00000-of-00001", "model.ckpt.index",
        "model.ckpt.meta", "vocab.bpe"
    ]

    # Download files
    os.makedirs(model_dir, exist_ok=True)
    for filename in filenames:
        # URLs always use "/"; os.path.join would insert "\\" on Windows.
        file_url = f"{base_url}/{model_size}/{filename}"
        file_path = os.path.join(model_dir, filename)
        download_file(file_url, file_path)

    # Load settings and params
    tf_ckpt_path = tf.train.latest_checkpoint(model_dir)
    if tf_ckpt_path is None:
        # download_file only prints HTTP errors, so surface the failure here
        # instead of letting TensorFlow fail on a None path.
        raise FileNotFoundError(
            f"No TensorFlow checkpoint found in {model_dir}; "
            "the download may have failed.")
    with open(os.path.join(model_dir, "hparams.json"), encoding="utf-8") as f:
        settings = json.load(f)
    params = load_gpt2_params_from_tf_ckpt(tf_ckpt_path, settings)

    return settings, params


def download_file(url, destination):
    """Download *url* to *destination*, skipping an already complete file.

    A local file whose size equals the server's Content-Length is treated as
    up to date. Data is streamed into ``<destination>.part`` and moved into
    place only after the transfer finishes, so an interrupted download never
    leaves a truncated file at *destination*.

    HTTP errors are reported with a help message and otherwise ignored
    (best effort); other errors propagate.
    """
    try:
        with urllib.request.urlopen(url) as response:
            # Get the total file size from headers, defaulting to 0 if not present
            file_size = int(response.headers.get("Content-Length", 0))

            # Check if file exists and has the same size
            if os.path.exists(destination):
                file_size_local = os.path.getsize(destination)
                if file_size == file_size_local:
                    print(f"File already exists and is up-to-date: {destination}")
                    return

            # Define the block size for reading the file
            block_size = 1024  # 1 Kilobyte
            tmp_path = destination + ".part"

            # Initialize the progress bar (open-ended when the size is unknown)
            progress_bar_description = url.split("/")[-1]  # Extract filename from URL
            with tqdm(total=file_size or None, unit="iB", unit_scale=True,
                      desc=progress_bar_description) as progress_bar:
                with open(tmp_path, "wb") as file:
                    # Read the file in chunks and write to the temporary file
                    while True:
                        chunk = response.read(block_size)
                        if not chunk:
                            break
                        file.write(chunk)
                        progress_bar.update(len(chunk))

            # Atomically publish the completed file.
            os.replace(tmp_path, destination)
    except urllib.error.HTTPError:
        s = (
            f"The specified URL ({url}) is incorrect, the internet connection cannot be established,"
            "\nor the requested file is temporarily unavailable.\nPlease visit the following website"
            " for help: https://github.com/rasbt/LLMs-from-scratch/discussions/273")
        print(s)


# Alternative way using `requests`
"""
def download_file(url, destination):
    # Send a GET request to download the file in streaming mode
    response = requests.get(url, stream=True)

    # Get the total file size from headers, defaulting to 0 if not present
    file_size = int(response.headers.get("content-length", 0))

    # Check if file exists and has the same size
    if os.path.exists(destination):
        file_size_local = os.path.getsize(destination)
        if file_size == file_size_local:
            print(f"File already exists and is up-to-date: {destination}")
            return

    # Define the block size for reading the file
    block_size = 1024  # 1 Kilobyte

    # Initialize the progress bar with total file size
    progress_bar_description = url.split("/")[-1]  # Extract filename from URL
    with tqdm(total=file_size, unit="iB", unit_scale=True, desc=progress_bar_description) as progress_bar:
        # Open the destination file in binary write mode
        with open(destination, "wb") as file:
            # Iterate over the file data in chunks
            for chunk in response.iter_content(block_size):
                progress_bar.update(len(chunk))  # Update progress bar
                file.write(chunk)  # Write the chunk to the file
"""


def load_gpt2_params_from_tf_ckpt(ckpt_path, settings):
    """Read every variable of a GPT-2 TF checkpoint into a nested dict.

    Each variable name is split on "/" after dropping its first component
    (the 'model/' prefix). If the first remaining component starts with "h",
    the digits after it select ``params["blocks"][<layer>]``; otherwise the
    variable is stored from the top level. Intermediate components become
    nested dict keys and the last component holds the array.

    Args:
        ckpt_path: Checkpoint path, e.g. from ``tf.train.latest_checkpoint``.
        settings: Parsed ``hparams.json``; only ``settings["n_layer"]`` is read.

    Returns:
        A dict with a ``"blocks"`` list (one dict per layer) plus any
        top-level entries; leaves are NumPy arrays with singleton dims
        squeezed out.
    """
    # Initialize parameters dictionary with empty blocks for each layer
    params = {"blocks": [{} for _ in range(settings["n_layer"])]}

    # Iterate over each variable in the checkpoint
    for name, _ in tf.train.list_variables(ckpt_path):
        # Load the variable and remove singleton dimensions
        variable_array = np.squeeze(tf.train.load_variable(ckpt_path, name))

        # Process the variable name to extract relevant parts
        variable_name_parts = name.split("/")[1:]  # Skip the 'model/' prefix

        # Identify the target dictionary for the variable
        target_dict = params
        if variable_name_parts[0].startswith("h"):
            layer_number = int(variable_name_parts[0][1:])
            target_dict = params["blocks"][layer_number]

        # Recursively access or create nested dictionaries
        for key in variable_name_parts[1:-1]:
            target_dict = target_dict.setdefault(key, {})

        # Assign the variable array to the last key
        last_key = variable_name_parts[-1]
        target_dict[last_key] = variable_array

    return params
# -------------------------------------------------------------------------------- /ch05/train_plot.png: --------------------------------------------------------------------------------
# https://raw.githubusercontent.com/yhlleo/LLMs_from_scratch/14ddb8817640e9437d1cd2f1a7cf18cf0f52f799/ch05/train_plot.png
# -------------------------------------------------------------------------------- /ch06/accuracy-plot.pdf: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/yhlleo/LLMs_from_scratch/14ddb8817640e9437d1cd2f1a7cf18cf0f52f799/ch06/accuracy-plot.pdf -------------------------------------------------------------------------------- /ch06/loss-plot.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhlleo/LLMs_from_scratch/14ddb8817640e9437d1cd2f1a7cf18cf0f52f799/ch06/loss-plot.pdf -------------------------------------------------------------------------------- /ch06/valid.csv: -------------------------------------------------------------------------------- 1 | Label,Text 2 | 1,"Mila, age23, blonde, new in UK. I look sex with UK guys. if u like fun with me. Text MTALK to 69866.18 . 30pp/txt 1st 5free. £1.50 increments. Help08718728876" 3 | 1,"Hungry gay guys feeling hungry and up 4 it, now. Call 08718730555 just 10p/min. To stop texts call 08712460324 (10p/min)" 4 | 0,Ugh. Gotta drive back to sd from la. My butt is sore. 5 | 0,Please leave this topic..sorry for telling that.. 6 | 1,We tried to contact you re our offer of New Video Phone 750 anytime any network mins HALF PRICE Rental camcorder call 08000930705 or reply for delivery Wed 7 | 0,do u think that any girl will propose u today by seing ur bloody funky shit fucking face...............asssssholeeee................ 8 | 1,You are a winner U have been specially selected 2 receive £1000 or a 4* holiday (flights inc) speak to a live operator 2 claim 0871277810910p/min (18+) 9 | 0,"Sorry, I'll call later" 10 | 1,"For ur chance to win a £250 wkly shopping spree TXT: SHOP to 80878. T's&C's www.txt-2-shop.com custcare 08715705022, 1x150p/wk" 11 | 0,Aiyah u did ok already lar. E nydc at wheellock? 12 | 0,The whole car appreciated the last two! Dad and are having a map reading semi argument but apart from that things are going ok. P. 
13 | 1,"YOUR CHANCE TO BE ON A REALITY FANTASY SHOW call now = 08707509020 Just 20p per min NTT Ltd, PO Box 1327 Croydon CR9 5WB 0870 is a national = rate call." 14 | 0,Old Orchard near univ. How about you? 15 | 1,"You are a £1000 winner or Guaranteed Caller Prize, this is our Final attempt to contact you! To Claim Call 09071517866 Now! 150ppmPOBox10183BhamB64XE" 16 | 0,Haiyoh... Maybe your hamster was jealous of million 17 | 1,Sorry! U can not unsubscribe yet. THE MOB offer package has a min term of 54 weeks> pls resubmit request after expiry. Reply THEMOB HELP 4 more info 18 | 0,Hey you gave them your photo when you registered for driving ah? Tmr wanna meet at yck? 19 | 1,You have WON a guaranteed £1000 cash or a £2000 prize.To claim yr prize call our customer service representative on 20 | 0,"A swt thought: ""Nver get tired of doing little things 4 lovable persons.."" Coz..somtimes those little things occupy d biggest part in their Hearts.. Gud ni8" 21 | 0,Hi.:)technical support.providing assistance to us customer through call and email:) 22 | 1,"Text82228>> Get more ringtones, logos and games from www.txt82228.com. Questions: info@txt82228.co.uk" 23 | 0,What happened in interview? 24 | 1,Dont forget you can place as many FREE Requests with 1stchoice.co.uk as you wish. For more Information call 08707808226. 25 | 1,Don't b floppy... b snappy & happy! Only gay chat service with photo upload call 08718730666 (10p/min). 2 stop our texts call 08712460324 26 | 0,Lol boo I was hoping for a laugh 27 | 1,Jamster! To get your free wallpaper text HEART to 88888 now! T&C apply. 16 only. Need Help? Call 08701213186. 28 | 1,CDs 4u: Congratulations ur awarded £500 of CD gift vouchers or £125 gift guaranteed & Freeentry 2 £100 wkly draw xt MUSIC to 87066 TnCs www.ldew.com1win150ppmx3age16 29 | 1,FreeMsg: Hey - I'm Buffy. 25 and love to satisfy men. Home alone feeling randy. Reply 2 C my PIX! 
QlynnBV Help08700621170150p a msg Send stop to stop txts 30 | 1,You have WON a guaranteed £1000 cash or a £2000 prize. To claim yr prize call our customer service representative on 08714712379 between 10am-7pm Cost 10p 31 | 0,Your opinion about me? 1. Over 2. Jada 3. Kusruthi 4. Lovable 5. Silent 6. Spl character 7. Not matured 8. Stylish 9. Simple Pls reply.. 32 | 1,Reply to win £100 weekly! Where will the 2006 FIFA World Cup be held? Send STOP to 87239 to end service 33 | 1,"Loan for any purpose £500 - £75,000. Homeowners + Tenants welcome. Have you been previously refused? We can still help. Call Free 0800 1956669 or text back 'help'" 34 | 0,"Its ok., i just askd did u knw tht no?" 35 | 0,Oh yeah I forgot. U can only take 2 out shopping at once. 36 | 0,Let me know how it changes in the next 6hrs. It can even be appendix but you are out of that age range. However its not impossible. So just chill and let me know in 6hrs 37 | 1,Dear 0776xxxxxxx U've been invited to XCHAT. This is our final attempt to contact u! Txt CHAT to 86688 150p/MsgrcvdHG/Suite342/2Lands/Row/W1J6HL LDN 18yrs 38 | 0,May b approve panalam...but it should have more posts.. 39 | 0,"Christmas is An occasion that is Celebrated as a Reflection of UR... Values..., Desires..., Affections...& Traditions.... Have an ideal Christmas..." 40 | 1,"Auction round 4. The highest bid is now £54. Next maximum bid is £71. To bid, send BIDS e. g. 10 (to bid £10) to 83383. Good luck." 41 | 0,"Yeah go on then, bored and depressed sittin waitin for phone to ring... Hope the wind drops though, scary" 42 | 0,I will be gentle baby! Soon you will be taking all <#> inches deep inside your tight pussy... 43 | 0,Waiting 4 my tv show 2 start lor... U leh still busy doing ur report? 44 | 0,Good. Good job. I like entrepreneurs 45 | 1,You have been specially selected to receive a 2000 pound award! Call 08712402050 BEFORE the lines close. Cost 10ppm. 16+. T&Cs apply. AG Promo 46 | 1,This message is free. 
Welcome to the new & improved Sex & Dogging club! To unsubscribe from this service reply STOP. msgs@150p 18+only 47 | 0,Hi good mornin.. Thanku wish u d same.. 48 | 1,"Today's Offer! Claim ur £150 worth of discount vouchers! Text YES to 85023 now! SavaMob, member offers mobile! T Cs 08717898035. £3.00 Sub. 16 . Unsub reply X" 49 | 1,"This is the 2nd attempt to contract U, you have won this weeks top prize of either £1000 cash or £200 prize. Just call 09066361921" 50 | 0,Maybe?! Say hi to and find out if got his card. Great escape or wetherspoons? 51 | 0,Huh but i got lesson at 4 lei n i was thinkin of going to sch earlier n i tot of parkin at kent vale... 52 | 1,"Congrats 2 mobile 3G Videophones R yours. call 09063458130 now! videochat wid ur mates, play java games, Dload polypH music, noline rentl. bx420. ip4. 5we. 150p" 53 | 0,Even my brother is not like to speak with me. They treat me like aids patent. 54 | 0,Also fuck you and your family for going to rhode island or wherever the fuck and leaving me all alone the week I have a new bong >:( 55 | 1,December only! Had your mobile 11mths+? You are entitled to update to the latest colour camera mobile for Free! Call The Mobile Update Co FREE on 08002986906 56 | 0,"2 celebrate my b’day, y else?" 57 | 0,"chile, please! It's only a <DECIMAL> hour drive for me. I come down all the time and will be subletting feb-april for audition season." 58 | 1,Want explicit SEX in 30 secs? Ring 02073162414 now! Costs 20p/min 59 | 0,I had askd u a question some hours before. Its answer 60 | 1,"FreeMSG You have been awarded a FREE mini DIGITAL CAMERA, just reply SNAP to collect your prize! (quizclub Opt out? Stop 80122300p/wk SP:RWM Ph:08704050406)" 61 | 0,I want to sent <#> mesages today. Thats y. Sorry if i hurts 62 | 0,Gam gone after outstanding innings. 63 | 1,"Thank you, winner notified by sms. Good Luck! No future marketing reply STOP to 84122 customer services 08450542832" 64 | 0,:-) yeah! Lol. 
Luckily i didn't have a starring role like you! 65 | 0,Babes I think I got ur brolly I left it in English wil bring it in 2mrw 4 u luv Franxx 66 | 0,Then u better go sleep.. Dun disturb u liao.. U wake up then msg me lor.. 67 | 1,Double your mins & txts on Orange or 1/2 price linerental - Motorola and SonyEricsson with B/Tooth FREE-Nokia FREE Call MobileUpd8 on 08000839402 or2optout/HV9D 68 | 1,Customer Loyalty Offer:The NEW Nokia6650 Mobile from ONLY £10 at TXTAUCTION! Txt word: START to No: 81151 & get yours Now! 4T&Ctxt TC 150p/MTmsg 69 | 0,"Hey there! Glad u r better now. I hear u treated urself to a digi cam, is it good? We r off at 9pm. Have a fab new year, c u in coupla wks!" 70 | 1,"449050000301 You have won a £2,000 price! To claim, call 09050000301." 71 | 0,Thats cool. i am a gentleman and will treat you with dignity and respect. 72 | 0,"Idk. You keep saying that you're not, but since he moved, we keep butting heads over freedom vs. responsibility. And i'm tired. I have so much other shit to deal with that i'm barely keeping myself together once this gets added to it." 73 | 0,Tell me they're female :V how're you throwing in? We're deciding what all to get now 74 | 1,"SMSSERVICES. for yourinclusive text credits, pls goto www.comuk.net login= 3qxj9 unsubscribe with STOP, no extra charge. help 08702840625.COMUK. 220-CM2 9AE" 75 | 1,"Welcome to UK-mobile-date this msg is FREE giving you free calling to 08719839835. Future mgs billed at 150p daily. To cancel send ""go stop"" to 89123" 76 | 1,"Free entry to the gr8prizes wkly comp 4 a chance to win the latest Nokia 8800, PSP or £250 cash every wk.TXT GREAT to 80878 http//www.gr8prizes.com 08715705022" 77 | 0,Wot u up 2 u weirdo? 78 | 0,Havent stuck at orchard in my dad's car. Going 4 dinner now. U leh? So r they free tonight? 79 | 0,K..k..i'm also fine:)when will you complete the course? 80 | 1,"Urgent! Please call 0906346330. 
Your ABTA complimentary 4* Spanish Holiday or £10,000 cash await collection SAE T&Cs BOX 47 PO19 2EZ 150ppm 18+" 81 | 0,Nothing will ever be easy. But don't be looking for a reason not to take a risk on life and love 82 | 1,Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's 83 | 0,You know what hook up means right? 84 | 1,u r a winner U ave been specially selected 2 receive £1000 cash or a 4* holiday (flights inc) speak to a live operator 2 claim 0871277810710p/min (18 ) 85 | 1,"U can WIN £100 of Music Gift Vouchers every week starting NOW Txt the word DRAW to 87066 TsCs www.Idew.com SkillGame, 1Winaweek, age16. 150ppermessSubscription" 86 | 1,Send a logo 2 ur lover - 2 names joined by a heart. Txt LOVE NAME1 NAME2 MOBNO eg LOVE ADAM EVE 07123456789 to 87077 Yahoo! POBox36504W45WQ TxtNO 4 no ads 150p 87 | 1,"URGENT! Your mobile No *********** WON a £2,000 Bonus Caller Prize on 02/06/03! This is the 2nd attempt to reach YOU! Call 09066362220 ASAP! BOX97N7QP, 150ppm" 88 | 0,"Wat time liao, where still got." 89 | 0,I'm still looking for a car to buy. And have not gone 4the driving test yet. 90 | 1,Valentines Day Special! Win over £1000 in our quiz and take your partner on the trip of a lifetime! Send GO to 83600 now. 150p/msg rcvd. CustCare:08718720201. 91 | 1,"YOUR CHANCE TO BE ON A REALITY FANTASY SHOW call now = 08707509020 Just 20p per min NTT Ltd, PO Box 1327 Croydon CR9 5WB 0870 is a national = rate call" 92 | 0,You said not now. No problem. When you can. Let me know. 93 | 1,Marvel Mobile Play the official Ultimate Spider-man game (£4.50) on ur mobile right now. Text SPIDER to 83338 for the game & we ll send u a FREE 8Ball wallpaper 94 | 1,Call Germany for only 1 pence per minute! Call from a fixed line via access number 0844 861 85 85. No prepayment. Direct access! www.telediscount.co.uk 95 | 0,"I know you are thinkin malaria. 
But relax, children cant handle malaria. She would have been worse and its gastroenteritis. If she takes enough to replace her loss her temp will reduce. And if you give her malaria meds now she will just vomit. Its a self limiting illness she has which means in a few days it will completely stop" 96 | 1,"Your free ringtone is waiting to be collected. Simply text the password ""MIX"" to 85069 to verify. Get Usher and Britney. FML, PO Box 5249, MK17 92H. 450Ppw 16" 97 | 0,May i call You later Pls 98 | 1,"HOT LIVE FANTASIES call now 08707509020 Just 20p per min NTT Ltd, PO Box 1327 Croydon CR9 5WB 0870 is a national rate call" 99 | 1,URGENT! We are trying to contact U. Todays draw shows that you have won a £800 prize GUARANTEED. Call 09050001808 from land line. Claim M95. Valid12hrs only 100 | 1,URGENT! We are trying to contact U. Todays draw shows that you have won a £800 prize GUARANTEED. Call 09050001808 from land line. Claim M95. Valid12hrs only 101 | 1,URGENT! Your Mobile number has been awarded with a £2000 prize GUARANTEED. Call 09061790121 from land line. Claim 3030. Valid 12hrs only 150ppm 102 | 1,FREE camera phones with linerental from 4.49/month with 750 cross ntwk mins. 1/2 price txt bundle deals also avble. Call 08001950382 or call2optout/J MF 103 | 0,Don know..he is watching film in computer.. 104 | 1,Camera - You are awarded a SiPix Digital Camera! call 09061221066 fromm landline. Delivery within 28 days. 105 | 0,"It's not that you make me cry. It's just that when all our stuff happens on top of everything else, it pushes me over the edge. You don't underdtand how often i cry over my sorry, sorry life." 106 | 1,Ur cash-balance is currently 500 pounds - to maximize ur cash-in now send CASH to 86688 only 150p/msg. CC: 08708800282 HG/Suite342/2Lands Row/W1J6HL 107 | 1,Email AlertFrom: Jeri StewartSize: 2KBSubject: Low-cost prescripiton drvgsTo listen to email call 123 108 | 0,Dear where you. Call me 109 | 1,December only! Had your mobile 11mths+? 
You are entitled to update to the latest colour camera mobile for Free! Call The Mobile Update VCo FREE on 08002986906 110 | 1,Txt: CALL to No: 86888 & claim your reward of 3 hours talk time to use from your phone now! Subscribe6GBP/mnth inc 3hrs 16 stop?txtStop www.gamb.tv 111 | 1,December only! Had your mobile 11mths+? You are entitled to update to the latest colour camera mobile for Free! Call The Mobile Update Co FREE on 08002986906 112 | 1,PRIVATE! Your 2003 Account Statement for 07808 XXXXXX shows 800 un-redeemed S. I. M. points. Call 08719899217 Identifier Code: 41685 Expires 07/11/04 113 | 1,You'll not rcv any more msgs from the chat svc. For FREE Hardcore services text GO to: 69988 If u get nothing u must Age Verify with yr network & try again 114 | 0,Better. Made up for Friday and stuffed myself like a pig yesterday. Now I feel bleh. But at least its not writhing pain kind of bleh. 115 | 1,Congratulations ur awarded either £500 of CD gift vouchers & Free entry 2 our £100 weekly draw txt MUSIC to 87066 TnCs www.Ldew.com 1 win150ppmx3age16 116 | 0,He's an adult and would learn from the experience. There's no real danger. I just dont like peeps using drugs they dont need. But no comment 117 | 0,Wa... U so efficient... Gee... Thanx... 118 | 1,todays vodafone numbers ending with 0089(my last four digits) are selected to received a £350 award. If your number matches please call 09063442151 to claim your £350 award 119 | 1,UR awarded a City Break and could WIN a £200 Summer Shopping spree every WK. Txt STORE to 88039 . SkilGme. TsCs087147403231Winawk!Age16 £1.50perWKsub 120 | 0,Message:some text missing* Sender:Name Missing* *Number Missing *Sent:Date missing *Missing U a lot thats y everything is missing sent via fullonsms.com 121 | 1,You have WON a guaranteed £1000 cash or a £2000 prize. 
To claim yr prize call our customer service representative on 08714712394 between 10am-7pm 122 | 1,"FREE MESSAGE Activate your 500 FREE Text Messages by replying to this message with the word FREE For terms & conditions, visit www.07781482378.com" 123 | 0,Dont think so. It turns off like randomlly within 5min of opening 124 | 0,Will do. Was exhausted on train this morning. Too much wine and pie. You sleep well too 125 | 1,2/2 146tf150p 126 | 1,U’ve Bin Awarded £50 to Play 4 Instant Cash. Call 08715203028 To Claim. EVERY 9th Player Wins Min £50-£500. OptOut 08718727870 127 | 1,"Xmas Offer! Latest Motorola, SonyEricsson & Nokia & FREE Bluetooth or DVD! Double Mins & 1000 Txt on Orange. Call MobileUpd8 on 08000839402 or call2optout/4QF2" 128 | 0,"Party's at my place at usf, no charge (but if you can contribute in any way it is greatly appreciated) and yeah, we got room for one more" 129 | 0,How are you with moneY...as in to you...money aint a thing....how are you sha! 130 | 1,Hi I'm sue. I am 20 years old and work as a lapdancer. I love sex. Text me live - I'm i my bedroom now. text SUE to 89555. By TextOperator G2 1DA 150ppmsg 18+ 131 | 1,4mths half price Orange line rental & latest camera phones 4 FREE. Had your phone 11mths ? Call MobilesDirect free on 08000938767 to update now! or2stoptxt 132 | 1,Do you want a new Video phone? 600 anytime any network mins 400 Inclusive Video calls AND downloads 5 per week Free delTOMORROW call 08002888812 or reply NOW 133 | 0,Neshanth..tel me who r u? 134 | 0,"Sure thing big man. i have hockey elections at 6, shouldn‘t go on longer than an hour though" 135 | 1,Get a FREE mobile video player FREE movie. To collect text GO to 89105. Its free! Extra films can be ordered t's and c's apply. 18 yrs only 136 | 1,Save money on wedding lingerie at www.bridal.petticoatdreams.co.uk Choose from a superb selection with national delivery. 
Brought to you by WeddingFriend 137 | 1,"Dear Matthew please call 09063440451 from a landline, your complimentary 4*Lux Tenerife holiday or £1000 CASH await collection. ppm150 SAE T&Cs Box334 SK38XH." 138 | 1,Valentines Day Special! Win over £1000 in our quiz and take your partner on the trip of a lifetime! Send GO to 83600 now. 150p/msg rcvd. CustCare:08718720201 139 | 0,No drama Pls.i have had enough from you and family while i am struggling in the hot sun in a strange place.No reason why there should be an ego of not going 'IF NOT INVITED' when actually its necessity to go.wait for very serious reppurcussions. 140 | 0,Cant think of anyone with * spare room off * top of my head 141 | 1,Guess who am I?This is the first time I created a web page WWW.ASJESUS.COM read all I wrote. I'm waiting for your opinions. I want to be your friend 1/1 142 | 1,"FREE2DAY sexy St George's Day pic of Jordan!Txt PIC to 89080 dont miss out, then every wk a saucy celeb!4 more pics c PocketBabe.co.uk 0870241182716 £3/wk" 143 | 1,FREE entry into our £250 weekly comp just send the word ENTER to 88877 NOW. 18 T&C www.textcomp.com 144 | 1,For the most sparkling shopping breaks from 45 per person; call 0121 2025050 or visit www.shortbreaks.org.uk 145 | 0,"My painful personal thought- ""I always try to keep everybody happy all the time. But nobody recognises me when i am alone""" 146 | 0,Hey gorgeous man. My work mobile number is. Have a good one babe. Squishy Mwahs. 147 | 0,Just sleeping..and surfing 148 | 0,"I'm in solihull, | do you want anything?" 149 | 0,"Jay told me already, will do" 150 | 1,"U can WIN £100 of Music Gift Vouchers every week starting NOW Txt the word DRAW to 87066 TsCs www.Idew.com SkillGame, 1Winaweek, age16. 
150ppermessSubscription" 151 | -------------------------------------------------------------------------------- /codes/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yhlleo/LLMs_from_scratch/14ddb8817640e9437d1cd2f1a7cf18cf0f52f799/codes/__init__.py -------------------------------------------------------------------------------- /codes/configs.py: -------------------------------------------------------------------------------- 1 | 2 | GPT_CONFIG_124M = { 3 | "vocab_size": 50257, 4 | "context_length": 256, 5 | "emb_dim": 768, 6 | "num_heads": 12, 7 | "n_layers": 12, 8 | "drop_rate": 0.1, 9 | "qkv_bias": False 10 | } 11 | 12 | gpt2_model_configs = { 13 | "gpt2-small (124M)": {"emb_dim": 768, "n_layers": 12, "num_heads": 12}, 14 | "gpt2-media (355M)": {"emb_dim": 1024, "n_layers": 24, "num_heads": 16}, 15 | "gpt2-large (774M)": {"emb_dim": 1280, "n_layers": 36, "num_heads": 20}, 16 | "gpt2-xl (1558M)": {"emd_dim": 1600, "n_layers": 48, "num_heads": 25} 17 | } 18 | 19 | def _build_gpt2_configs(model_type, update_cfg): 20 | assert model_type in gpt2_model_configs, "Unknown model type!" 
21 | 22 | new_config = GPT_CONFIG_124M.copy() 23 | new_config.update(gpt2_model_configs[model_type]) 24 | new_config.update(update_cfg) 25 | return new_config 26 | 27 | def gpt2_small_config( 28 | model_type = "gpt2-small (124M)", 29 | update_cfg = { 30 | "context_length": 1024, 31 | "qkv_bias": True, 32 | "drop_rate": 0.0} 33 | ): 34 | return _build_gpt2_configs(model_type, update_cfg) 35 | 36 | def gpt2_media_config( 37 | model_type = "gpt2-media (355M)", 38 | update_cfg = { 39 | "context_length": 1024, 40 | "qkv_bias": True, 41 | "drop_rate": 0.0} 42 | ): 43 | return _build_gpt2_configs(model_type, update_cfg) 44 | -------------------------------------------------------------------------------- /codes/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset, DataLoader 3 | 4 | import tiktoken 5 | 6 | from codes.configs import GPT_CONFIG_124M 7 | 8 | class GPTDatasetV2(Dataset): 9 | def __init__(self, text, tokenizer, max_length, stride): 10 | self.input_ids = [] 11 | self.target_ids = [] 12 | 13 | token_ids = tokenizer.encode(text) 14 | for i in range(0, len(token_ids) - max_length, stride): 15 | input_chunk = token_ids[i:i+max_length] 16 | target_chunk = token_ids[i+1:i+max_length+1] 17 | 18 | self.input_ids.append(torch.tensor(input_chunk)) 19 | self.target_ids.append(torch.tensor(target_chunk)) 20 | 21 | def __len__(self): 22 | return len(self.input_ids) 23 | 24 | def __getitem__(self, index): 25 | return self.input_ids[index], self.target_ids[index] 26 | 27 | 28 | def create_dataloader_v1( 29 | text, 30 | batch_size=4, 31 | max_length=256, 32 | stride=128, 33 | shuffle=True, 34 | drop_last=True, 35 | num_workers=0): 36 | tokenizer = tiktoken.get_encoding("gpt2") 37 | 38 | dataset = GPTDatasetV2(text, tokenizer, max_length, stride) 39 | dataloader = DataLoader( 40 | dataset, 41 | batch_size=batch_size, 42 | shuffle=shuffle, 43 | drop_last=drop_last, 44 | 
num_workers=num_workers 45 | ) 46 | 47 | return dataloader 48 | 49 | 50 | def build_dataloader(): 51 | with open("./ch02/the-verdict.txt", "r", encoding="utf-8") as fin: 52 | raw_data = fin.read() 53 | 54 | train_ratio = 0.9 55 | split_idx = int(train_ratio * len(raw_data)) 56 | train_data = raw_data[:split_idx] 57 | val_data = raw_data[split_idx:] 58 | 59 | train_loader = create_dataloader_v1( 60 | train_data, 61 | batch_size=2, 62 | max_length=GPT_CONFIG_124M['context_length'], 63 | stride=GPT_CONFIG_124M['context_length'], 64 | drop_last=True 65 | ) 66 | 67 | val_loader = create_dataloader_v1( 68 | val_data, 69 | batch_size=2, 70 | max_length=GPT_CONFIG_124M['context_length'], 71 | stride=GPT_CONFIG_124M['context_length'], 72 | drop_last=False, 73 | shuffle=False 74 | ) 75 | return train_loader, val_loader -------------------------------------------------------------------------------- /codes/gpt_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class MultiheadAttention(nn.Module): 5 | def __init__(self, dim_in, dim_out, 6 | context_length, dropout, num_heads, qkv_bias=False): 7 | super().__init__() 8 | 9 | assert dim_out % num_heads == 0, \ 10 | "dim_out must be divisible by num_heads" 11 | 12 | self.dim_out = dim_out 13 | self.num_heads = num_heads 14 | self.head_dim = dim_out // num_heads 15 | 16 | self.W_query = nn.Linear(dim_in, dim_out, bias=qkv_bias) 17 | self.W_key = nn.Linear(dim_in, dim_out, bias=qkv_bias) 18 | self.W_value = nn.Linear(dim_in, dim_out, bias=qkv_bias) 19 | 20 | self.out_proj = nn.Linear(dim_out, dim_out) 21 | self.dropout = nn.Dropout(dropout) 22 | 23 | self.register_buffer( 24 | "mask", 25 | torch.triu(torch.ones(context_length, context_length), diagonal=1) 26 | ) 27 | 28 | def split_heads(self, x, batch_size, num_tokens): 29 | x = x.view(batch_size, num_tokens, self.num_heads, self.head_dim) 30 | x = x.transpose(1, 2) 31 | return x 32 | 33 | 
def forward(self, x): 34 | b, num_tokens, dim_in = x.shape 35 | 36 | keys = self.W_key(x) 37 | queries = self.W_query(x) 38 | values = self.W_value(x) 39 | 40 | keys = self.split_heads(keys, b, num_tokens) 41 | queries = self.split_heads(queries, b, num_tokens) 42 | values = self.split_heads(values, b, num_tokens) 43 | 44 | attn_scores = queries @ keys.transpose(2, 3) 45 | mask_bool = self.mask.bool()[:num_tokens, :num_tokens] 46 | attn_scores = attn_scores.masked_fill(mask_bool, -torch.inf) 47 | 48 | attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1) 49 | attn_weights = self.dropout(attn_weights) 50 | 51 | context_vecs = (attn_weights @ values).transpose(1, 2) 52 | context_vecs = context_vecs.contiguous().view(b, num_tokens, self.dim_out) 53 | 54 | context_vecs = self.out_proj(context_vecs) 55 | return context_vecs 56 | 57 | class LayerNorm(nn.Module): 58 | def __init__(self, emb_dim): 59 | super().__init__() 60 | 61 | self.eps = 1e-5 62 | self.scale = nn.Parameter(torch.ones(emb_dim)) 63 | self.shift = nn.Parameter(torch.zeros(emb_dim)) 64 | 65 | def forward(self, x): 66 | mean = x.mean(dim=-1, keepdim=True) 67 | var = x.var(dim=-1, keepdim=True, unbiased=False) 68 | 69 | norm_x = (x - mean) / torch.sqrt(var + self.eps) 70 | return self.scale * norm_x + self.shift 71 | 72 | class GELU(nn.Module): 73 | def __init__(self): 74 | super().__init__() 75 | 76 | def forward(self, x): 77 | return 0.5 * x * (1 + torch.tanh(( 78 | torch.sqrt(torch.tensor(2 / torch.pi)) * 79 | (x + 0.044715 * x**3)) 80 | )) 81 | 82 | class FeedForward(nn.Module): 83 | def __init__(self, cfg): 84 | super().__init__() 85 | 86 | self.layers = nn.Sequential( 87 | nn.Linear(cfg["emb_dim"], 4*cfg["emb_dim"]), 88 | GELU(), 89 | nn.Linear(4*cfg["emb_dim"], cfg["emb_dim"]) 90 | ) 91 | 92 | def forward(self, x): 93 | return self.layers(x) 94 | 95 | 96 | class TransformerBlock(nn.Module): 97 | def __init__(self, cfg): 98 | super().__init__() 99 | 100 | self.attn = 
MultiheadAttention( 101 | dim_in = cfg["emb_dim"], 102 | dim_out = cfg["emb_dim"], 103 | context_length = cfg["context_length"], 104 | num_heads = cfg["num_heads"], 105 | dropout = cfg["drop_rate"], 106 | qkv_bias = cfg["qkv_bias"] 107 | ) 108 | 109 | self.ffn = FeedForward(cfg) 110 | self.norm1 = LayerNorm(cfg["emb_dim"]) 111 | self.norm2 = LayerNorm(cfg["emb_dim"]) 112 | self.drop_shortcut = nn.Dropout(cfg["drop_rate"]) 113 | 114 | def forward(self, x): 115 | shortcut = x 116 | x = self.norm1(x) 117 | x = self.attn(x) 118 | x = self.drop_shortcut(x) 119 | x = x + shortcut 120 | 121 | shortcut = x 122 | x = self.norm2(x) 123 | x = self.ffn(x) 124 | x = self.drop_shortcut(x) 125 | x = x + shortcut 126 | return x 127 | 128 | class GPTModel(nn.Module): 129 | def __init__(self, cfg): 130 | super().__init__() 131 | 132 | self.tok_emb = nn.Embedding(cfg["vocab_size"], cfg["emb_dim"]) 133 | self.pos_emb = nn.Embedding(cfg["context_length"], cfg["emb_dim"]) 134 | self.drop_emb = nn.Dropout(cfg["drop_rate"]) 135 | 136 | self.trans_blocks = nn.Sequential(*[TransformerBlock(cfg) for _ in range(cfg["n_layers"])]) 137 | 138 | self.final_norm = LayerNorm(cfg["emb_dim"]) 139 | self.out_head = nn.Linear(cfg["emb_dim"], cfg["vocab_size"], bias=False) 140 | 141 | def forward(self, in_idx): 142 | batch_size, seq_len = in_idx.shape 143 | token_embeds = self.tok_emb(in_idx) 144 | pos_embeds = self.pos_emb(torch.arange(seq_len, device=in_idx.device)) 145 | 146 | x = token_embeds + pos_embeds 147 | x = self.drop_emb(x) 148 | x = self.trans_blocks(x) 149 | x = self.final_norm(x) 150 | logits = self.out_head(x) 151 | return logits -------------------------------------------------------------------------------- /codes/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | def calc_loss_batch(input_batch, target_batch, model, device=None): 5 | input_batch = input_batch.to(device) 6 | 
target_batch = target_batch.to(device) 7 | 8 | logits = model(input_batch) 9 | loss = F.cross_entropy(logits.flatten(0,1), target_batch.flatten()) 10 | return loss 11 | 12 | 13 | def calc_loss_loader(data_loader, model, device=None, num_batches=None): 14 | total_loss = 0 15 | if len(data_loader) == 0: 16 | return float("nan") 17 | elif num_batches is None: 18 | num_batches = len(data_loader) 19 | else: 20 | num_batches = min(num_batches, len(data_loader)) 21 | 22 | for i, (input_batch, target_batch) in enumerate(data_loader): 23 | if i < num_batches: 24 | loss = calc_loss_batch(input_batch, target_batch, model, device) 25 | total_loss += loss.item() 26 | else: 27 | break 28 | 29 | return total_loss / num_batches 30 | 31 | def evaluate_model(model, train_loader, val_loader, device, eval_iter): 32 | model.eval() 33 | with torch.no_grad(): 34 | train_loss = calc_loss_loader(train_loader, model, device, num_batches=eval_iter) 35 | val_loss = calc_loss_loader(val_loader, model, device, num_batches=eval_iter) 36 | model.train() 37 | return train_loss, val_loss -------------------------------------------------------------------------------- /codes/model_convert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def assign(left, right): 5 | assert left.shape == right.shape, f"Shape mismatch. 
Left: {left.shape}, Right: {right.shape}" 6 | return torch.nn.Parameter(torch.tensor(right)) 7 | 8 | 9 | def load_weights_into_gpt(gpt, params): 10 | gpt.pos_emb.weight = assign(gpt.pos_emb.weight, params["wpe"]) 11 | gpt.tok_emb.weight = assign(gpt.tok_emb.weight, params["wte"]) 12 | 13 | for b in range(len(params["blocks"])): 14 | q_w, k_w, v_w = np.split( 15 | (params["blocks"][b]["attn"]["c_attn"])["w"], 3, axis=-1 16 | ) 17 | gpt.trans_blocks[b].attn.W_query.weight = assign( 18 | gpt.trans_blocks[b].attn.W_query.weight, q_w.T 19 | ) 20 | gpt.trans_blocks[b].attn.W_key.weight = assign( 21 | gpt.trans_blocks[b].attn.W_key.weight, k_w.T 22 | ) 23 | gpt.trans_blocks[b].attn.W_value.weight = assign( 24 | gpt.trans_blocks[b].attn.W_value.weight, v_w.T 25 | ) 26 | 27 | q_b, k_b, v_b = np.split( 28 | (params["blocks"][b]["attn"]["c_attn"])["b"], 3, axis=-1 29 | ) 30 | gpt.trans_blocks[b].attn.W_query.bias = assign( 31 | gpt.trans_blocks[b].attn.W_query.bias, q_b 32 | ) 33 | gpt.trans_blocks[b].attn.W_key.bias = assign( 34 | gpt.trans_blocks[b].attn.W_key.bias, k_b 35 | ) 36 | gpt.trans_blocks[b].attn.W_value.bias = assign( 37 | gpt.trans_blocks[b].attn.W_value.bias, v_b 38 | ) 39 | 40 | gpt.trans_blocks[b].attn.out_proj.weight = assign( 41 | gpt.trans_blocks[b].attn.out_proj.weight, 42 | params["blocks"][b]["attn"]["c_proj"]["w"].T 43 | ) 44 | gpt.trans_blocks[b].attn.out_proj.bias = assign( 45 | gpt.trans_blocks[b].attn.out_proj.bias, 46 | params["blocks"][b]["attn"]["c_proj"]["b"] 47 | ) 48 | 49 | gpt.trans_blocks[b].ffn.layers[0].weight = assign( 50 | gpt.trans_blocks[b].ffn.layers[0].weight, 51 | params["blocks"][b]["mlp"]["c_fc"]["w"].T 52 | ) 53 | gpt.trans_blocks[b].ffn.layers[0].bias = assign( 54 | gpt.trans_blocks[b].ffn.layers[0].bias, 55 | params["blocks"][b]["mlp"]["c_fc"]["b"] 56 | ) 57 | gpt.trans_blocks[b].ffn.layers[2].weight = assign( 58 | gpt.trans_blocks[b].ffn.layers[2].weight, 59 | params["blocks"][b]["mlp"]["c_proj"]["w"].T 60 | ) 61 | 
gpt.trans_blocks[b].ffn.layers[2].bias = assign( 62 | gpt.trans_blocks[b].ffn.layers[2].bias, 63 | params["blocks"][b]["mlp"]["c_proj"]["b"] 64 | ) 65 | 66 | gpt.trans_blocks[b].norm1.scale = assign( 67 | gpt.trans_blocks[b].norm1.scale, 68 | params["blocks"][b]["ln_1"]["g"] 69 | ) 70 | gpt.trans_blocks[b].norm1.shift = assign( 71 | gpt.trans_blocks[b].norm1.shift, 72 | params["blocks"][b]["ln_1"]["b"] 73 | ) 74 | gpt.trans_blocks[b].norm2.scale = assign( 75 | gpt.trans_blocks[b].norm2.scale, 76 | params["blocks"][b]["ln_2"]["g"] 77 | ) 78 | gpt.trans_blocks[b].norm2.shift = assign( 79 | gpt.trans_blocks[b].norm2.shift, 80 | params["blocks"][b]["ln_2"]["b"] 81 | ) 82 | gpt.final_norm.scale = assign( 83 | gpt.final_norm.scale, params["g"] 84 | ) 85 | gpt.final_norm.shift = assign( 86 | gpt.final_norm.shift, params["b"] 87 | ) 88 | 89 | gpt.out_head.weight = assign( 90 | gpt.out_head.weight, params["wte"] 91 | ) -------------------------------------------------------------------------------- /codes/plots.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | from matplotlib.ticker import MaxNLocator 3 | 4 | def plot_losses(epochs_seen, token_seen, train_losses, val_losses): 5 | fig, ax1 = plt.subplots(figsize=(5,3)) 6 | ax1.plot(epochs_seen, train_losses, label="Training loss") 7 | ax1.plot(epochs_seen, val_losses, linestyle="-.", label="Validation loss") 8 | 9 | ax1.set_xlabel("Epochs") 10 | ax1.set_ylabel("Loss") 11 | ax1.legend(loc="upper right") 12 | ax1.xaxis.set_major_locator(MaxNLocator(integer=True)) 13 | ax2 = ax1.twiny() 14 | ax2.plot(token_seen, train_losses, alpha=0) 15 | ax2.set_xlabel("Tokens seen") 16 | fig.tight_layout() 17 | plt.show() -------------------------------------------------------------------------------- /codes/solver.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from .losses import calc_loss_batch, 
evaluate_model 4 | from .utils import generate_and_print_sample 5 | 6 | def train_model_simple( 7 | model, 8 | train_loader, 9 | val_loader, 10 | optimizer, 11 | device, 12 | num_epochs, 13 | eval_freq, 14 | eval_iter, 15 | start_context, 16 | tokenizer 17 | ): 18 | train_losses, val_losses, track_tokens_seen = [], [], [] 19 | token_seen, global_step = 0, -1 20 | 21 | for epoch in range(num_epochs): 22 | model.train() 23 | 24 | for input_batch, target_batch in train_loader: 25 | optimizer.zero_grad() 26 | loss = calc_loss_batch( 27 | input_batch, target_batch, model, device 28 | ) 29 | loss.backward() 30 | optimizer.step() 31 | 32 | token_seen += input_batch.numel() 33 | global_step += 1 34 | 35 | if global_step % eval_freq == 0: 36 | train_loss, val_loss = evaluate_model( 37 | model, train_loader, val_loader, device, eval_iter 38 | ) 39 | train_losses.append(train_loss) 40 | val_losses.append(val_loss) 41 | 42 | track_tokens_seen.append(token_seen) 43 | 44 | print(f"Ep {epoch+1} (Step {global_step:06d}): " 45 | f"Train loss {train_loss:.3f} " 46 | f"Val loss {val_loss:.3f}" 47 | ) 48 | generate_and_print_sample( 49 | model, tokenizer, device, start_context 50 | ) 51 | 52 | return train_losses, val_losses, track_tokens_seen 53 | -------------------------------------------------------------------------------- /codes/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # greedy search 4 | def generate_text_simple(model, idx, max_new_tokens, context_size): 5 | for _ in range(max_new_tokens): 6 | idx_second = idx[:, -context_size:] 7 | with torch.no_grad(): 8 | logits = model(idx_second) 9 | 10 | logits = logits[:,-1,:] 11 | probs = torch.softmax(logits, dim=-1) 12 | idx_next = torch.argmax(probs, dim=-1, keepdim=True) 13 | 14 | idx = torch.cat([idx, idx_next], dim=1) 15 | return idx 16 | 17 | # probabilistic sampling + temperature scaling -> top-k sampling 18 | def generate(model, idx, max_new_tokens, 
context_size, 19 | temperature=0.0, top_k=None, eos_id=None): 20 | for _ in range(max_new_tokens): 21 | idx_cond = idx[:,-context_size:] 22 | with torch.no_grad(): 23 | logits = model(idx_cond) 24 | logits = logits[:, -1, :] 25 | 26 | if top_k is not None: 27 | top_logits, _ = torch.topk(logits, top_k) 28 | min_val = top_logits[:, -1] 29 | logits = torch.where(logits < min_val, 30 | torch.tensor(-torch.inf).to(logits.device), 31 | logits) 32 | 33 | if temperature > 0.0: 34 | logits /= temperature 35 | probs = torch.softmax(logits, dim=-1) 36 | idx_next = torch.multinomial(probs, num_samples=1) 37 | else: 38 | idx_next = torch.argmax(logits, dim=-1, keepdim=True) 39 | 40 | idx = torch.cat([idx, idx_next], dim=-1) 41 | return idx 42 | 43 | 44 | def text_to_token_ids(text, tokenizer): 45 | encoded = tokenizer.encode(text, allowed_special={"<|endoftext|>"}) 46 | encoded_tensor = torch.tensor(encoded).unsqueeze(0) 47 | return encoded_tensor 48 | 49 | 50 | def token_ids_to_text(token_ids, tokenizer): 51 | flat = token_ids.squeeze(0) 52 | return tokenizer.decode(flat.tolist()) 53 | 54 | def generate_and_print_sample(model, tokenizer, device, start_context): 55 | model.eval() 56 | context_size = model.pos_emb.weight.shape[0] 57 | encoded = text_to_token_ids(start_context, tokenizer).to(device) 58 | with torch.no_grad(): 59 | token_ids = generate_text_simple( 60 | model=model, idx=encoded, 61 | max_new_tokens=50, context_size=context_size 62 | ) 63 | decoded_text = token_ids_to_text(token_ids, tokenizer) 64 | print(decoded_text.replace("\n", " ")) 65 | model.train() 66 | 67 | -------------------------------------------------------------------------------- /train_llms_from_scratch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import tiktoken 3 | 4 | from codes.configs import GPT_CONFIG_124M 5 | from codes.gpt_model import GPTModel 6 | from codes.data import build_dataloader 7 | from codes.solver import 
train_model_simple 8 | from codes.plots import plot_losses 9 | 10 | import matplotlib.pyplot as plt 11 | from matplotlib.ticker import MaxNLocator 12 | 13 | if __name__ == '__main__': 14 | torch.manual_seed(123) 15 | 16 | device = torch.device("cpu") 17 | model = GPTModel(GPT_CONFIG_124M) 18 | model.to(device) 19 | 20 | train_loader, val_loader = build_dataloader() 21 | 22 | optimizer = torch.optim.AdamW(model.parameters(), 23 | lr=4e-4, weight_decay=0.1) 24 | 25 | tokenizer = tiktoken.get_encoding("gpt2") 26 | num_epochs = 10 27 | train_losses, val_losses, token_seen = train_model_simple( 28 | model, train_loader, val_loader, optimizer, device, 29 | num_epochs=num_epochs, eval_freq=5, eval_iter=5, 30 | start_context="Every effort moves you", 31 | tokenizer=tokenizer 32 | ) 33 | 34 | epochs_tensor = torch.linspace(0, num_epochs, len(train_losses)) 35 | plot_losses(epochs_tensor, token_seen, train_losses, val_losses) 36 | --------------------------------------------------------------------------------