├── MoE.png ├── Transf.png ├── README.md └── moe.ipynb /MoE.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/antonio-f/mixture-of-experts-from-scratch/HEAD/MoE.png -------------------------------------------------------------------------------- /Transf.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/antonio-f/mixture-of-experts-from-scratch/HEAD/Transf.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Mixture of Experts from scratch 2 | 3 | A simple implementation of the Mixture of Experts technique. 4 | 5 | Also available on [Substack](https://monads.substack.com/p/mixture-of-experts-from-scratch). 6 | -------------------------------------------------------------------------------- /moe.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "id": "b772b407", 6 | "metadata": { 7 | "id": "b772b407" 8 | }, 9 | "source": [ 10 | "# Mixture of Experts from scratch" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "id": "89776a6a", 16 | "metadata": { 17 | "id": "89776a6a" 18 | }, 19 | "source": [ 20 | "\n", 21 | " \"Open\n", 22 | "" 23 | ] 24 | }, 25 | { 26 | "cell_type": "markdown", 27 | "id": "e704e55d", 28 | "metadata": { 29 | "id": "e704e55d" 30 | }, 31 | "source": [ 32 | "This is a simple implementation of the **Mixture of Experts** (**MoE**) technique applied to language modeling tasks.\n", 33 | "\n", 34 | "Training and evaluating deep models can be computationally expensive and time-consuming. 
The Conditional Computation approach has been proposed to tackle this problem.\n", 35 | "**Conditional Computation** refers to a class of algorithms in which each input sample uses a different part of the model, so that (on average) the compute, latency or power (depending on our objective) is reduced. It works by selectively activating only parts of the network at a time." 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "id": "dea0516a", 41 | "metadata": { 42 | "id": "dea0516a" 43 | }, 44 | "source": [ 45 | "### Loading data" 46 | ] 47 | }, 48 | { 49 | "cell_type": "markdown", 50 | "id": "6def97d6", 51 | "metadata": { 52 | "id": "6def97d6" 53 | }, 54 | "source": [ 55 | "We will use the TinyStories dataset ([info](https://huggingface.co/datasets/roneneldan/TinyStories)), which is well suited to this task and not overly large." 56 | ] 57 | }, 58 | { 59 | "cell_type": "code", 60 | "execution_count": null, 61 | "id": "tV2MvYTdVwe2", 62 | "metadata": { 63 | "colab": { 64 | "base_uri": "https://localhost:8080/" 65 | }, 66 | "id": "tV2MvYTdVwe2", 67 | "outputId": "1841330a-0f23-4c9d-a5b3-2d83b6951cf0" 68 | }, 69 | "outputs": [], 70 | "source": [ 71 | "!wget https://huggingface.co/datasets/roneneldan/TinyStories/resolve/main/TinyStories_all_data.tar.gz" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "id": "4015770f", 77 | "metadata": { 78 | "id": "4015770f" 79 | }, 80 | "source": [ 81 | "We import `os` and `glob` for file and path operations, and `json` to parse the dataset files." 
82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": 2, 87 | "id": "ZKuW2rbWW7G-", 88 | "metadata": { 89 | "id": "ZKuW2rbWW7G-" 90 | }, 91 | "outputs": [], 92 | "source": [ 93 | "import os\n", 94 | "import glob\n", 95 | "import json" 96 | ] 97 | }, 98 | { 99 | "cell_type": "markdown", 100 | "id": "8a6b0fb8", 101 | "metadata": { 102 | "id": "8a6b0fb8" 103 | }, 104 | "source": [ 105 | "Now we create the `TinyStories` folder and extract the data into it."
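For environments without a shell, the folder-creation and `tar` step can also be done in pure Python with the standard `tarfile` module. This is a minimal sketch; the helper name `extract_archive` is hypothetical and not part of the notebook:

```python
import os
import tarfile


def extract_archive(archive_path: str, dest_dir: str) -> list[str]:
    """Create dest_dir if needed, extract a .tar.gz archive into it, return member names."""
    # exist_ok=True turns the call into a no-op when the folder already exists,
    # which replaces the explicit os.path.exists() check.
    os.makedirs(dest_dir, exist_ok=True)
    # "r:gz" opens a gzip-compressed tar archive for reading.
    with tarfile.open(archive_path, "r:gz") as tar:
        if hasattr(tarfile, "data_filter"):
            # Newer Python releases support extraction filters; "data" refuses
            # members that would land outside dest_dir.
            tar.extractall(path=dest_dir, filter="data")
        else:
            tar.extractall(path=dest_dir)
        return sorted(tar.getnames())
```

With the notebook's file names this would be called as `extract_archive("TinyStories_all_data.tar.gz", "TinyStories")`.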
106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": 3, 111 | "id": "7dfd7fc0", 112 | "metadata": { 113 | "id": "7dfd7fc0" 114 | }, 115 | "outputs": [], 116 | "source": [ 117 | "if not os.path.exists(\"./TinyStories\"):\n", 118 | " os.makedirs(\"./TinyStories\")\n", 119 | "\n", 120 | "!tar -xzf TinyStories_all_data.tar.gz -C TinyStories" 121 | ] 122 | }, 123 | { 124 | "cell_type": "markdown", 125 | "id": "d505f287", 126 | "metadata": { 127 | "id": "d505f287" 128 | }, 129 | "source": [ 130 | "The following command returns a sorted list of paths such as\n", 131 | "\n", 132 | "`'TinyStories/data00.json'`\n", 133 | "\n", 134 | "`'TinyStories/data01.json'`\n", 135 | "\n", 136 | "`'TinyStories/data02.json'`\n", 137 | "\n", 138 | ". . .\n", 139 | "\n", 140 | "and so on." 141 | ] 142 | }, 143 | { 144 | "cell_type": "code", 145 | "execution_count": 4, 146 | "id": "-N5m4Rk4Wp72", 147 | "metadata": { 148 | "id": "-N5m4Rk4Wp72" 149 | }, 150 | "outputs": [], 151 | "source": [ 152 | "shard_filenames = sorted(glob.glob(os.path.join('TinyStories', \"*.json\")))" 153 | ] 154 | }, 155 | { 156 | "cell_type": "markdown", 157 | "id": "6d89f17c", 158 | "metadata": { 159 | "id": "6d89f17c" 160 | }, 161 | "source": [ 162 | "We load the first JSON shard (`shard_filenames[0]`) into `data`, a list of story records." 163 | ] 164 | }, 165 | { 166 | "cell_type": "code", 167 | "execution_count": 5, 168 | "id": "1m-He99IXCV4", 169 | "metadata": { 170 | "id": "1m-He99IXCV4" 171 | }, 172 | "outputs": [], 173 | "source": [ 174 | "with open(shard_filenames[0], \"r\") as f:\n", 175 | " data = json.load(f)" 176 | ] 177 | }, 178 | { 179 | "cell_type": "markdown", 180 | "id": "0bda234c", 181 | "metadata": { 182 | "id": "0bda234c" 183 | }, 184 | "source": [ 185 | "Let us check the first element of `data`."
186 | ] 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": 6, 191 | "id": "5000d128", 192 | "metadata": { 193 | "colab": { 194 | "base_uri": "https://localhost:8080/" 195 | }, 196 | "id": "5000d128", 197 | "outputId": "1b2c74ef-930e-4a26-8602-af0ada9a6f0a" 198 | }, 199 | "outputs": [ 200 | { 201 | "data": { 202 | "text/plain": [ 203 | "{'story': '\\n\\nLily and Ben are friends. They like to play in the park. One day, they see a big tree with a swing. Lily wants to try the swing. She runs to the tree and climbs on the swing.\\n\"Push me, Ben!\" she says. Ben pushes her gently. Lily feels happy. She swings higher and higher. She laughs and shouts.\\nBen watches Lily. He thinks she is cute. He wants to swing too. He waits for Lily to stop. But Lily does not stop. She swings faster and faster. She is having too much fun.\\n\"Can I swing too, Lily?\" Ben asks. Lily does not hear him. She is too busy swinging. Ben feels sad. He walks away.\\nLily swings so high that she loses her grip. She falls off the swing. She lands on the ground. She hurts her foot. She cries.\\n\"Ow, ow, ow!\" she says. She looks for Ben. She wants him to help her. But Ben is not there. He is gone.\\nLily feels sorry. She wishes she had shared the swing with Ben. She wishes he was there to hug her. She limps to the tree. She sees something hanging from a branch. It is Ben\\'s hat. He left it for her.\\nLily smiles. She thinks Ben is nice. She puts on his hat. She hopes he will come back. She wants to say sorry. She wants to be friends again.',\n", 204 | " 'instruction': {'prompt:': 'Write a short story (3-5 paragraphs) which only uses very simple words that a 3 year old child would understand. The story should use the verb \"hang\", the noun \"foot\" and the adjective \"cute\". The story has the following features: the story should contain at least one dialogue. 
Remember to only use simple words!\\n\\nPossible story:',\n", 205 | " 'words': ['hang', 'foot', 'cute'],\n", 206 | " 'features': ['Dialogue']},\n", 207 | " 'summary': 'Lily and Ben play in the park and Lily gets too caught up in swinging, causing Ben to leave. Lily falls off the swing and hurts herself, but Ben leaves his hat for her as a kind gesture.',\n", 208 | " 'source': 'GPT-4'}" 209 | ] 210 | }, 211 | "execution_count": 6, 212 | "metadata": {}, 213 | "output_type": "execute_result" 214 | } 215 | ], 216 | "source": [ 217 | "data[0]" 218 | ] 219 | }, 220 | { 221 | "cell_type": "markdown", 222 | "id": "5dac7b43", 223 | "metadata": { 224 | "id": "5dac7b43" 225 | }, 226 | "source": [ 227 | "We collect all stories in the `stories` list." 228 | ] 229 | }, 230 | { 231 | "cell_type": "code", 232 | "execution_count": 7, 233 | "id": "wmsSEsysXFXF", 234 | "metadata": { 235 | "id": "wmsSEsysXFXF" 236 | }, 237 | "outputs": [], 238 | "source": [ 239 | "stories = [x['story'] for x in data]" 240 | ] 241 | }, 242 | { 243 | "cell_type": "markdown", 244 | "id": "f7f9b21c", 245 | "metadata": { 246 | "id": "f7f9b21c" 247 | }, 248 | "source": [ 249 | "A sample from `stories` is the following." 250 | ] 251 | }, 252 | { 253 | "cell_type": "code", 254 | "execution_count": 8, 255 | "id": "yigoqX6fXYMl", 256 | "metadata": { 257 | "colab": { 258 | "base_uri": "https://localhost:8080/", 259 | "height": 103 260 | }, 261 | "id": "yigoqX6fXYMl", 262 | "outputId": "5f440e13-2d34-40c0-9f7f-4d4a973b9288" 263 | }, 264 | "outputs": [ 265 | { 266 | "data": { 267 | "application/vnd.google.colaboratory.intrinsic+json": { 268 | "type": "string" 269 | }, 270 | "text/plain": [ 271 | "\"Once upon a time, there was a little girl named Lily. Lily loved to play in the park with her friends. One day, Lily and her friends were playing hide and seek. Lily found a good hiding spot behind a big tree. 
As she was hiding, she started to yawn because she was very tired.\\nSuddenly, Lily saw an enormous shadow coming towards her. She got scared and started to cry. It turned out that the shadow was just her friend, Timmy. Timmy had found her hiding spot and was trying to surprise her. \\nLily learned that sometimes things that seem scary are not really scary at all. She also learned that it's important to get enough sleep so you don't yawn during the day. From that day on, Lily made sure to get plenty of rest before playing with her friends.\"" 272 | ] 273 | }, 274 | "execution_count": 8, 275 | "metadata": {}, 276 | "output_type": "execute_result" 277 | } 278 | ], 279 | "source": [ 280 | "stories[42]" 281 | ] 282 | }, 283 | { 284 | "cell_type": "markdown", 285 | "id": "7000983c", 286 | "metadata": { 287 | "id": "7000983c" 288 | }, 289 | "source": [ 290 | "All the stories are joined into a single string called `text`, with a newline character `\\n` between consecutive stories." 291 | ] 292 | }, 293 | { 294 | "cell_type": "code", 295 | "execution_count": 9, 296 | "id": "1f30df8b", 297 | "metadata": { 298 | "id": "1f30df8b" 299 | }, 300 | "outputs": [], 301 | "source": [ 302 | "text = \"\\n\".join(stories)" 303 | ] 304 | }, 305 | { 306 | "cell_type": "markdown", 307 | "id": "1340c982", 308 | "metadata": { 309 | "id": "1340c982" 310 | }, 311 | "source": [ 312 | "`text` is a very long string: 77,586,884 characters, as `len(text)` shows."
313 | ] 314 | }, 315 | { 316 | "cell_type": "code", 317 | "execution_count": 10, 318 | "id": "a9e34655", 319 | "metadata": { 320 | "colab": { 321 | "base_uri": "https://localhost:8080/" 322 | }, 323 | "id": "a9e34655", 324 | "outputId": "0d79d626-8599-4878-bed5-52cf0acfbdf8" 325 | }, 326 | "outputs": [ 327 | { 328 | "data": { 329 | "text/plain": [ 330 | "77586884" 331 | ] 332 | }, 333 | "execution_count": 10, 334 | "metadata": {}, 335 | "output_type": "execute_result" 336 | } 337 | ], 338 | "source": [ 339 | "len(text)" 340 | ] 341 | }, 342 | { 343 | "cell_type": "code", 344 | "execution_count": 11, 345 | "id": "c596e794", 346 | "metadata": { 347 | "colab": { 348 | "base_uri": "https://localhost:8080/" 349 | }, 350 | "id": "c596e794", 351 | "outputId": "c7179d56-dfbb-4c3b-dcb9-a22fd2103f99" 352 | }, 353 | "outputs": [ 354 | { 355 | "name": "stdout", 356 | "output_type": "stream", 357 | "text": [ 358 | "\n", 359 | "\n", 360 | "Lily and Ben are friends. They like to play in the park. One day, they see a big tree with a swing\n" 361 | ] 362 | } 363 | ], 364 | "source": [ 365 | "print(text[:100])" 366 | ] 367 | }, 368 | { 369 | "cell_type": "markdown", 370 | "id": "28829cf1", 371 | "metadata": { 372 | "id": "28829cf1" 373 | }, 374 | "source": [ 375 | "### Character encoding" 376 | ] 377 | }, 378 | { 379 | "cell_type": "markdown", 380 | "id": "f3885408", 381 | "metadata": { 382 | "id": "f3885408" 383 | }, 384 | "source": [ 385 | "We are going to use PyTorch tensors to store data." 386 | ] 387 | }, 388 | { 389 | "cell_type": "code", 390 | "execution_count": 12, 391 | "id": "14e943d7", 392 | "metadata": { 393 | "id": "14e943d7" 394 | }, 395 | "outputs": [], 396 | "source": [ 397 | "import torch" 398 | ] 399 | }, 400 | { 401 | "cell_type": "markdown", 402 | "id": "e14f124e", 403 | "metadata": { 404 | "id": "e14f124e" 405 | }, 406 | "source": [ 407 | "`chars` is the sorted list of distinct characters found in `text`; there are 97 of them, which gives the vocabulary size."
408 | ] 409 | }, 410 | { 411 | "cell_type": "code", 412 | "execution_count": 13, 413 | "id": "f8InhsLGbFCj", 414 | "metadata": { 415 | "colab": { 416 | "base_uri": "https://localhost:8080/" 417 | }, 418 | "id": "f8InhsLGbFCj", 419 | "outputId": "2214e646-9a72-4dd6-b38a-b97c19dc7c4c" 420 | }, 421 | "outputs": [ 422 | { 423 | "name": "stdout", 424 | "output_type": "stream", 425 | "text": [ 426 | "\t\n", 427 | " !\"$%&'()*+,-./0123456789:;?ABCDEFGHIJKLMNOPQRSTUVWXYZ[]`abcdefghijklmnopqrstuvwxyz|~ éñ–—‘’“”…\n", 428 | "97\n" 429 | ] 430 | } 431 | ], 432 | "source": [ 433 | "chars = sorted(list(set(text)))\n", 434 | "vocab_size = len(chars)\n", 435 | "print(''.join(chars))\n", 436 | "print(vocab_size)" 437 | ] 438 | }, 439 | { 440 | "cell_type": "markdown", 441 | "id": "54de4004", 442 | "metadata": { 443 | "id": "54de4004" 444 | }, 445 | "source": [ 446 | "Below, two dictionaries: the first binds characters to integers and the second does the reverse." 447 | ] 448 | }, 449 | { 450 | "cell_type": "code", 451 | "execution_count": 14, 452 | "id": "UHOqlwJtbKu2", 453 | "metadata": { 454 | "id": "UHOqlwJtbKu2" 455 | }, 456 | "outputs": [], 457 | "source": [ 458 | "ctoi = {ch:i for i, ch in enumerate(chars)}\n", 459 | "itoc = {i:ch for i,ch in enumerate(chars)}" 460 | ] 461 | }, 462 | { 463 | "cell_type": "code", 464 | "execution_count": 15, 465 | "id": "cec90814", 466 | "metadata": { 467 | "colab": { 468 | "base_uri": "https://localhost:8080/" 469 | }, 470 | "id": "cec90814", 471 | "outputId": "f0c5f30c-f7df-40ec-99f5-3e2e11c4d9ae" 472 | }, 473 | "outputs": [ 474 | { 475 | "data": { 476 | "text/plain": [ 477 | "{'\\t': 0,\n", 478 | " '\\n': 1,\n", 479 | " ' ': 2,\n", 480 | " '!': 3,\n", 481 | " '\"': 4,\n", 482 | " '$': 5,\n", 483 | " '%': 6,\n", 484 | " '&': 7,\n", 485 | " \"'\": 8,\n", 486 | " '(': 9,\n", 487 | " ')': 10,\n", 488 | " '*': 11,\n", 489 | " '+': 12,\n", 490 | " ',': 13,\n", 491 | " '-': 14,\n", 492 | " '.': 15,\n", 493 | " '/': 16,\n", 494 | " '0': 17,\n", 495 
| " '1': 18,\n", 496 | " '2': 19,\n", 497 | " '3': 20,\n", 498 | " '4': 21,\n", 499 | " '5': 22,\n", 500 | " '6': 23,\n", 501 | " '7': 24,\n", 502 | " '8': 25,\n", 503 | " '9': 26,\n", 504 | " ':': 27,\n", 505 | " ';': 28,\n", 506 | " '?': 29,\n", 507 | " 'A': 30,\n", 508 | " 'B': 31,\n", 509 | " 'C': 32,\n", 510 | " 'D': 33,\n", 511 | " 'E': 34,\n", 512 | " 'F': 35,\n", 513 | " 'G': 36,\n", 514 | " 'H': 37,\n", 515 | " 'I': 38,\n", 516 | " 'J': 39,\n", 517 | " 'K': 40,\n", 518 | " 'L': 41,\n", 519 | " 'M': 42,\n", 520 | " 'N': 43,\n", 521 | " 'O': 44,\n", 522 | " 'P': 45,\n", 523 | " 'Q': 46,\n", 524 | " 'R': 47,\n", 525 | " 'S': 48,\n", 526 | " 'T': 49,\n", 527 | " 'U': 50,\n", 528 | " 'V': 51,\n", 529 | " 'W': 52,\n", 530 | " 'X': 53,\n", 531 | " 'Y': 54,\n", 532 | " 'Z': 55,\n", 533 | " '[': 56,\n", 534 | " ']': 57,\n", 535 | " '`': 58,\n", 536 | " 'a': 59,\n", 537 | " 'b': 60,\n", 538 | " 'c': 61,\n", 539 | " 'd': 62,\n", 540 | " 'e': 63,\n", 541 | " 'f': 64,\n", 542 | " 'g': 65,\n", 543 | " 'h': 66,\n", 544 | " 'i': 67,\n", 545 | " 'j': 68,\n", 546 | " 'k': 69,\n", 547 | " 'l': 70,\n", 548 | " 'm': 71,\n", 549 | " 'n': 72,\n", 550 | " 'o': 73,\n", 551 | " 'p': 74,\n", 552 | " 'q': 75,\n", 553 | " 'r': 76,\n", 554 | " 's': 77,\n", 555 | " 't': 78,\n", 556 | " 'u': 79,\n", 557 | " 'v': 80,\n", 558 | " 'w': 81,\n", 559 | " 'x': 82,\n", 560 | " 'y': 83,\n", 561 | " 'z': 84,\n", 562 | " '|': 85,\n", 563 | " '~': 86,\n", 564 | " '\\xa0': 87,\n", 565 | " 'é': 88,\n", 566 | " 'ñ': 89,\n", 567 | " '–': 90,\n", 568 | " '—': 91,\n", 569 | " '‘': 92,\n", 570 | " '’': 93,\n", 571 | " '“': 94,\n", 572 | " '”': 95,\n", 573 | " '…': 96}" 574 | ] 575 | }, 576 | "execution_count": 15, 577 | "metadata": {}, 578 | "output_type": "execute_result" 579 | } 580 | ], 581 | "source": [ 582 | "ctoi" 583 | ] 584 | }, 585 | { 586 | "cell_type": "markdown", 587 | "id": "92bbcd2f", 588 | "metadata": { 589 | "id": "92bbcd2f" 590 | }, 591 | "source": [ 592 | "The encoding function transforms 
a string `s` into a list of integers (one per character). The decoding function does the reverse: it takes a list of integers and returns the string obtained by mapping each integer back to its character. For example\n", 593 | "\n", 594 | "`encode(\"Hello, world!\")`\n", 595 | "\n", 596 | "returns the list\n", 597 | "\n", 598 | "`[37, 63, 70, 70, 73, 13, 2, 81, 73, 76, 70, 62, 3]`.\n", 599 | "\n", 600 | "Likewise,\n", 601 | "\n", 602 | "`decode([37, 63, 70, 70, 73, 13, 2, 81, 73, 76, 70, 62, 3])`\n", 603 | "\n", 604 | "returns the string\n", 605 | "\n", 606 | "`'Hello, world!'`." 607 | ] 608 | }, 609 | { 610 | "cell_type": "code", 611 | "execution_count": 16, 612 | "id": "muglCKTAbNQi", 613 | "metadata": { 614 | "id": "muglCKTAbNQi" 615 | }, 616 | "outputs": [], 617 | "source": [ 618 | "encode = lambda s: [ctoi[c] for c in s]\n", 619 | "decode = lambda l: \"\".join([itoc[x] for x in l])" 620 | ] 621 | }, 622 | { 623 | "cell_type": "markdown", 624 | "id": "8837c47e", 625 | "metadata": { 626 | "id": "8837c47e" 627 | }, 628 | "source": [ 629 | "We store the encoded text in a tensor named `data`. Note that this rebinds the name `data`, which previously held the list of story records."
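The vocabulary, the `encode`/`decode` pair and the conversion to a `torch.long` tensor can be checked end to end on a small string. This is a minimal sketch: `build_codec` is a hypothetical helper and `sample` is a toy string standing in for `text`, so its vocabulary is much smaller than the 97 characters of the full dataset.

```python
import torch


def build_codec(text: str):
    """Build a character-level vocabulary and matching encode/decode functions."""
    chars = sorted(set(text))                     # distinct characters, sorted
    ctoi = {ch: i for i, ch in enumerate(chars)}  # character -> integer id
    itoc = {i: ch for i, ch in enumerate(chars)}  # integer id -> character

    def encode(s: str) -> list[int]:
        return [ctoi[c] for c in s]

    def decode(ids) -> str:
        # int() lets decode accept tensor elements as well as plain ints
        return "".join(itoc[int(i)] for i in ids)

    return chars, encode, decode


sample = "Lily and Ben are friends.\nThey like to play in the park."
chars, encode, decode = build_codec(sample)
data = torch.tensor(encode(sample), dtype=torch.long)  # 1-D tensor of character ids
assert decode(data) == sample                          # the round trip is lossless
```

Because the vocabulary is built from the text itself, `encode` raises `KeyError` for any character that never appeared in it.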
630 | ] 631 | }, 632 | { 633 | "cell_type": "code", 634 | "execution_count": 17, 635 | "id": "9dc05388", 636 | "metadata": { 637 | "id": "9dc05388" 638 | }, 639 | "outputs": [], 640 | "source": [ 641 | "data = torch.tensor(encode(text), dtype = torch.long)" 642 | ] 643 | }, 644 | { 645 | "cell_type": "code", 646 | "execution_count": 18, 647 | "id": "a4f27fc3", 648 | "metadata": { 649 | "colab": { 650 | "base_uri": "https://localhost:8080/" 651 | }, 652 | "id": "a4f27fc3", 653 | "outputId": "f99f3f01-4737-4e0f-8f90-6675d5b8cf8f" 654 | }, 655 | "outputs": [ 656 | { 657 | "data": { 658 | "text/plain": [ 659 | "(torch.Size([77586884]), torch.Tensor)" 660 | ] 661 | }, 662 | "execution_count": 18, 663 | "metadata": {}, 664 | "output_type": "execute_result" 665 | } 666 | ], 667 | "source": [ 668 | "data.shape, type(data)" 669 | ] 670 | }, 671 | { 672 | "cell_type": "code", 673 | "execution_count": 19, 674 | "id": "f67b1902", 675 | "metadata": { 676 | "colab": { 677 | "base_uri": "https://localhost:8080/" 678 | }, 679 | "id": "f67b1902", 680 | "outputId": "d9a12499-b006-4d79-f1b9-7114443debe3" 681 | }, 682 | "outputs": [ 683 | { 684 | "data": { 685 | "text/plain": [ 686 | "tensor([ 1, 1, 41, 67, 70, 83, 2, 59, 72, 62, 2, 31, 63, 72, 2, 59, 76, 63,\n", 687 | " 2, 64, 76, 67, 63, 72, 62, 77, 15, 2, 49, 66, 63, 83, 2, 70, 67, 69,\n", 688 | " 63, 2, 78, 73, 2, 74, 70, 59, 83, 2, 67, 72, 2, 78, 66, 63, 2, 74,\n", 689 | " 59, 76, 69, 15, 2, 44, 72, 63, 2, 62, 59, 83, 13, 2, 78, 66, 63, 83,\n", 690 | " 2, 77, 63, 63, 2, 59, 2, 60, 67, 65, 2, 78, 76, 63, 63, 2, 81, 67,\n", 691 | " 78, 66, 2, 59, 2, 77, 81, 67, 72, 65])" 692 | ] 693 | }, 694 | "execution_count": 19, 695 | "metadata": {}, 696 | "output_type": "execute_result" 697 | } 698 | ], 699 | "source": [ 700 | "data[:100]" 701 | ] 702 | }, 703 | { 704 | "cell_type": "markdown", 705 | "id": "aa47c3ca", 706 | "metadata": { 707 | "id": "aa47c3ca" 708 | }, 709 | "source": [ 710 | "### Data splitting" 711 | ] 712 | }, 713 | { 714 | 
"cell_type": "markdown", 715 | "id": "e7a86910", 716 | "metadata": { 717 | "id": "e7a86910" 718 | }, 719 | "source": [ 720 | "Now it's time to create the training and validation datasets.\n", 721 | "The first 90% of the encoded text is used for training and the remaining 10% for validation." 722 | ] 723 | }, 724 | { 725 | "cell_type": "code", 726 | "execution_count": 20, 727 | "id": "66922688", 728 | "metadata": { 729 | "id": "66922688" 730 | }, 731 | "outputs": [], 732 | "source": [ 733 | "n = int(0.9*len(data))\n", 734 | "train_data = data[:n]\n", 735 | "val_data = data[n:]" 736 | ] 737 | }, 738 | { 739 | "cell_type": "markdown", 740 | "id": "7dc4b801", 741 | "metadata": { 742 | "id": "7dc4b801" 743 | }, 744 | "source": [ 745 | "Let's define a temporary block size of 8, for testing purposes only. Later this parameter will be set to 256: it is the context length, i.e. the maximum number of consecutive characters the MoE model receives as input at a time." 746 | ] 747 | }, 748 | { 749 | "cell_type": "code", 750 | "execution_count": 21, 751 | "id": "62038547", 752 | "metadata": { 753 | "id": "62038547" 754 | }, 755 | "outputs": [], 756 | "source": [ 757 | "block_size = 8" 758 | ] 759 | }, 760 | { 761 | "cell_type": "code", 762 | "execution_count": 22, 763 | "id": "78a76ae5", 764 | "metadata": { 765 | "colab": { 766 | "base_uri": "https://localhost:8080/" 767 | }, 768 | "id": "78a76ae5", 769 | "outputId": "ead80c08-d88e-47f7-a180-5d4896e2045f" 770 | }, 771 | "outputs": [ 772 | { 773 | "data": { 774 | "text/plain": [ 775 | "tensor([ 1, 1, 41, 67, 70, 83, 2, 59, 72])" 776 | ] 777 | }, 778 | "execution_count": 22, 779 | "metadata": {}, 780 | "output_type": "execute_result" 781 | } 782 | ], 783 | "source": [ 784 | "# Training data block example\n", 785 | "train_data[:block_size+1]" 786 | ] 787 | }, 788 | { 789 | "cell_type": "markdown", 790 | "id": "57669c35", 791 | "metadata": { 792 | "id": "57669c35" 793 | }, 794 | "source": [ 795 | "Basically, these language 
models are trained to predict the next element of text (a word, part of a word or, as in this character-level case, a single character) given the n elements that precede it. We are going to train a character-level model: for example, if the first 8 characters (the context) are `your nam`, the 9th (the target) should be `e`. So we need a tensor `x` of inputs and a tensor `y` of targets, where `y` is `x` shifted one position ahead in the text." 796 | ] 797 | }, 798 | { 799 | "cell_type": "code", 800 | "execution_count": 23, 801 | "id": "bfd62ded", 802 | "metadata": { 803 | "id": "bfd62ded" 804 | }, 805 | "outputs": [], 806 | "source": [ 807 | "x = train_data[:block_size]" 808 | ] 809 | }, 810 | { 811 | "cell_type": "code", 812 | "execution_count": 24, 813 | "id": "ed9341b9", 814 | "metadata": { 815 | "id": "ed9341b9" 816 | }, 817 | "outputs": [], 818 | "source": [ 819 | "y = train_data[1:block_size+1]" 820 | ] 821 | }, 822 | { 823 | "cell_type": "code", 824 | "execution_count": 25, 825 | "id": "cc4bde3a", 826 | "metadata": { 827 | "colab": { 828 | "base_uri": "https://localhost:8080/" 829 | }, 830 | "id": "cc4bde3a", 831 | "outputId": "1a4ac09a-267a-477a-d18e-7ee15e04ea03" 832 | }, 833 | "outputs": [ 834 | { 835 | "data": { 836 | "text/plain": [ 837 | "(tensor([ 1, 1, 41, 67, 70, 83, 2, 59]),\n", 838 | " tensor([ 1, 41, 67, 70, 83, 2, 59, 72]))" 839 | ] 840 | }, 841 | "execution_count": 25, 842 | "metadata": {}, 843 | "output_type": "execute_result" 844 | } 845 | ], 846 | "source": [ 847 | "x,y" 848 | ] 849 | }, 850 | { 851 | "cell_type": "markdown", 852 | "id": "91359d43", 853 | "metadata": { 854 | "id": "91359d43" 855 | }, 856 | "source": [ 857 | "Here are the context-target pairs contained in the two tensors `x` and `y` above: for each `t`, the context is `x[:t+1]` and the target is `y[t]`."
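The relation between `x` and `y` can be made explicit with a small helper that lists every context-target pair contained in one block of `block_size + 1` elements. This is a sketch (the helper name `context_target_pairs` is hypothetical), applied here to the `your nam` / `e` example from the text; the same slicing works on 1-D tensors such as `train_data[:block_size+1]`.

```python
def context_target_pairs(seq, block_size):
    """Return the (context, target) pairs contained in seq[:block_size + 1].

    x = seq[:block_size] and y = seq[1:block_size + 1], so y is x shifted by one:
    position t gives the context x[:t + 1] and the target y[t].
    """
    x = seq[:block_size]
    y = seq[1:block_size + 1]
    return [(x[:t + 1], y[t]) for t in range(block_size)]


for context, target in context_target_pairs("your name", block_size=8):
    print(repr(context), "->", repr(target))
```

A single block of length `block_size + 1` thus yields `block_size` training examples, from a one-character context up to a full `block_size`-character context.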
858 | ] 859 | }, 860 | { 861 | "cell_type": "code", 862 | "execution_count": 26, 863 | "id": "e2f98640", 864 | "metadata": { 865 | "colab": { 866 | "base_uri": "https://localhost:8080/" 867 | }, 868 | "id": "e2f98640", 869 | "outputId": "dad178ad-95cc-4a13-bcb0-e12d3df81a31" 870 | }, 871 | "outputs": [ 872 | { 873 | "name": "stdout", 874 | "output_type": "stream", 875 | "text": [ 876 | "context tensor([1]) target tensor(1)\n", 877 | "context tensor([1, 1]) target tensor(41)\n", 878 | "context tensor([ 1, 1, 41]) target tensor(67)\n", 879 | "context tensor([ 1, 1, 41, 67]) target tensor(70)\n", 880 | "context tensor([ 1, 1, 41, 67, 70]) target tensor(83)\n", 881 | "context tensor([ 1, 1, 41, 67, 70, 83]) target tensor(2)\n", 882 | "context tensor([ 1, 1, 41, 67, 70, 83, 2]) target tensor(59)\n", 883 | "context tensor([ 1, 1, 41, 67, 70, 83, 2, 59]) target tensor(72)\n" 884 | ] 885 | } 886 | ], 887 | "source": [ 888 | "for t in range(block_size):\n", 889 | " context = x[:t+1]\n", 890 | " target = y[t]\n", 891 | " print(\"context\", context, \"target\", target)" 892 | ] 893 | }, 894 | { 895 | "cell_type": "markdown", 896 | "id": "bbc5d33e", 897 | "metadata": { 898 | "id": "bbc5d33e" 899 | }, 900 | "source": [ 901 | "For reproducibility, we set a seed for PyTorch. Reproducibility is about limiting the number of sources of nondeterministic behavior for a specific platform, device, and PyTorch release. Often, it is possible to control sources of randomness that can cause multiple executions of your application to behave differently." 
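A quick illustration of what the seed provides: re-seeding PyTorch's global generator before a draw reproduces the same random integers, which is what makes the random batch offsets drawn with `torch.randint` later in the notebook repeatable. The helper `sample_offsets` and its default arguments are illustrative assumptions.

```python
import torch


def sample_offsets(seed: int, high: int = 1000, batch_size: int = 4) -> torch.Tensor:
    """Seed PyTorch's global RNG, then draw batch_size random offsets in [0, high)."""
    torch.manual_seed(seed)
    return torch.randint(high, (batch_size,))


first = sample_offsets(0)
second = sample_offsets(0)
print(torch.equal(first, second))  # True: same seed, same draws
```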
902 | ] 903 | }, 904 | { 905 | "cell_type": "code", 906 | "execution_count": 27, 907 | "id": "7d8ab5b0", 908 | "metadata": { 909 | "colab": { 910 | "base_uri": "https://localhost:8080/" 911 | }, 912 | "id": "7d8ab5b0", 913 | "outputId": "90edfdf2-224b-4480-c320-de4c5cd140eb" 914 | }, 915 | "outputs": [ 916 | { 917 | "data": { 918 | "text/plain": [ 919 | "" 920 | ] 921 | }, 922 | "execution_count": 27, 923 | "metadata": {}, 924 | "output_type": "execute_result" 925 | } 926 | ], 927 | "source": [ 928 | "torch.manual_seed(0)" 929 | ] 930 | }, 931 | { 932 | "cell_type": "markdown", 933 | "id": "2bf0bb31", 934 | "metadata": { 935 | "id": "2bf0bb31" 936 | }, 937 | "source": [ 938 | "### Creating batches" 939 | ] 940 | }, 941 | { 942 | "cell_type": "markdown", 943 | "id": "6e9f8fa1", 944 | "metadata": { 945 | "id": "6e9f8fa1" 946 | }, 947 | "source": [ 948 | "We set the batch size to 4 for testing (it will be changed later). The batch size is the number of independent sequences processed in parallel." 949 | ] 950 | }, 951 | { 952 | "cell_type": "code", 953 | "execution_count": 28, 954 | "id": "4f400257", 955 | "metadata": { 956 | "id": "4f400257" 957 | }, 958 | "outputs": [], 959 | "source": [ 960 | "batch_size = 4" 961 | ] 962 | }, 963 | { 964 | "cell_type": "markdown", 965 | "id": "349b2197", 966 | "metadata": { 967 | "id": "349b2197" 968 | }, 969 | "source": [ 970 | "The following function samples a random batch of input sequences `x`, together with their targets `y`, from the training or validation split."
971 | ] 972 | }, 973 | { 974 | "cell_type": "code", 975 | "execution_count": 29, 976 | "id": "6ef9ed1f", 977 | "metadata": { 978 | "id": "6ef9ed1f" 979 | }, 980 | "outputs": [], 981 | "source": [ 982 | "def get_batch(split):\n", 983 | " # sample a random batch of inputs x and targets y\n", 984 | " data = train_data if split == 'train' else val_data\n", 985 | " ix = torch.randint(len(data) - block_size, (batch_size,))\n", 986 | " x = torch.stack([data[i:i+block_size] for i in ix])\n", 987 | " y = torch.stack([data[i+1:i+block_size+1] for i in ix])\n", 988 | " return x, y" 989 | ] 990 | }, 991 | { 992 | "cell_type": "code", 993 | "execution_count": 30, 994 | "id": "1ed4a1c6", 995 | "metadata": { 996 | "id": "1ed4a1c6" 997 | }, 998 | "outputs": [], 999 | "source": [ 1000 | "xb, yb = get_batch('train')" 1001 | ] 1002 | }, 1003 | { 1004 | "cell_type": "code", 1005 | "execution_count": 31, 1006 | "id": "abafad30", 1007 | "metadata": { 1008 | "colab": { 1009 | "base_uri": "https://localhost:8080/" 1010 | }, 1011 | "id": "abafad30", 1012 | "outputId": "7e28e5aa-3a9f-4907-9a84-c707f34e3515" 1013 | }, 1014 | "outputs": [ 1015 | { 1016 | "data": { 1017 | "text/plain": [ 1018 | "tensor([[71, 71, 83, 2, 64, 73, 76, 2],\n", 1019 | " [67, 77, 2, 64, 59, 80, 73, 76],\n", 1020 | " [59, 72, 65, 2, 59, 72, 62, 2],\n", 1021 | " [ 2, 81, 73, 79, 70, 62, 2, 78]])" 1022 | ] 1023 | }, 1024 | "execution_count": 31, 1025 | "metadata": {}, 1026 | "output_type": "execute_result" 1027 | } 1028 | ], 1029 | "source": [ 1030 | "yb" 1031 | ] 1032 | }, 1033 | { 1034 | "cell_type": "markdown", 1035 | "id": "61c5f0ac", 1036 | "metadata": { 1037 | "id": "61c5f0ac" 1038 | }, 1039 | "source": [ 1040 | "Below are the context-target pairs for each of the 4 sequences in the batch."
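For testing, the batching logic of `get_batch` can also be written with its inputs passed explicitly instead of read from globals. In this sketch `get_batch_from` is a hypothetical name, and the optional `torch.Generator` argument is our own addition: it keeps batch sampling reproducible independently of the global seed.

```python
from typing import Optional

import torch


def get_batch_from(data: torch.Tensor, block_size: int, batch_size: int,
                   generator: Optional[torch.Generator] = None):
    """Sample batch_size random windows of length block_size from a 1-D tensor.

    Returns x (inputs) and y (targets), both of shape (batch_size, block_size);
    each row of y is the matching row of x shifted one position ahead in data.
    """
    # Start offsets range over 0 .. len(data) - block_size - 1, so that the
    # target window data[i+1 : i+block_size+1] stays inside the tensor.
    ix = torch.randint(len(data) - block_size, (batch_size,), generator=generator)
    x = torch.stack([data[i:i + block_size] for i in ix])
    y = torch.stack([data[i + 1:i + block_size + 1] for i in ix])
    return x, y


g = torch.Generator().manual_seed(0)
xb, yb = get_batch_from(torch.arange(100), block_size=8, batch_size=4, generator=g)
print(xb.shape, yb.shape)  # torch.Size([4, 8]) torch.Size([4, 8])
```

With `torch.arange` as data, every target equals its input plus one, which makes the one-position shift easy to verify.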
1041 | ] 1042 | }, 1043 | { 1044 | "cell_type": "code", 1045 | "execution_count": 32, 1046 | "id": "f083be50", 1047 | "metadata": { 1048 | "colab": { 1049 | "base_uri": "https://localhost:8080/" 1050 | }, 1051 | "id": "f083be50", 1052 | "outputId": "a3d6f6c5-af1e-4f59-8753-e4a9e10f11c7" 1053 | }, 1054 | "outputs": [ 1055 | { 1056 | "name": "stdout", 1057 | "output_type": "stream", 1058 | "text": [ 1059 | "tensor([73]) tensor(71)\n", 1060 | "tensor([73, 71]) tensor(71)\n", 1061 | "tensor([73, 71, 71]) tensor(83)\n", 1062 | "tensor([73, 71, 71, 83]) tensor(2)\n", 1063 | "tensor([73, 71, 71, 83, 2]) tensor(64)\n", 1064 | "tensor([73, 71, 71, 83, 2, 64]) tensor(73)\n", 1065 | "tensor([73, 71, 71, 83, 2, 64, 73]) tensor(76)\n", 1066 | "tensor([73, 71, 71, 83, 2, 64, 73, 76]) tensor(2)\n", 1067 | "\n", 1068 | "tensor([66]) tensor(67)\n", 1069 | "tensor([66, 67]) tensor(77)\n", 1070 | "tensor([66, 67, 77]) tensor(2)\n", 1071 | "tensor([66, 67, 77, 2]) tensor(64)\n", 1072 | "tensor([66, 67, 77, 2, 64]) tensor(59)\n", 1073 | "tensor([66, 67, 77, 2, 64, 59]) tensor(80)\n", 1074 | "tensor([66, 67, 77, 2, 64, 59, 80]) tensor(73)\n", 1075 | "tensor([66, 67, 77, 2, 64, 59, 80, 73]) tensor(76)\n", 1076 | "\n", 1077 | "tensor([77]) tensor(59)\n", 1078 | "tensor([77, 59]) tensor(72)\n", 1079 | "tensor([77, 59, 72]) tensor(65)\n", 1080 | "tensor([77, 59, 72, 65]) tensor(2)\n", 1081 | "tensor([77, 59, 72, 65, 2]) tensor(59)\n", 1082 | "tensor([77, 59, 72, 65, 2, 59]) tensor(72)\n", 1083 | "tensor([77, 59, 72, 65, 2, 59, 72]) tensor(62)\n", 1084 | "tensor([77, 59, 72, 65, 2, 59, 72, 62]) tensor(2)\n", 1085 | "\n", 1086 | "tensor([63]) tensor(2)\n", 1087 | "tensor([63, 2]) tensor(81)\n", 1088 | "tensor([63, 2, 81]) tensor(73)\n", 1089 | "tensor([63, 2, 81, 73]) tensor(79)\n", 1090 | "tensor([63, 2, 81, 73, 79]) tensor(70)\n", 1091 | "tensor([63, 2, 81, 73, 79, 70]) tensor(62)\n", 1092 | "tensor([63, 2, 81, 73, 79, 70, 62]) tensor(2)\n", 1093 | "tensor([63, 2, 81, 73, 79, 70, 62, 2]) 
tensor(78)\n", 1094 | "\n" 1095 | ] 1096 | } 1097 | ], 1098 | "source": [ 1099 | "for b in range(batch_size):\n", 1100 | " for t in range(block_size):\n", 1101 | " context = xb[b][:t+1]\n", 1102 | " target = yb[b][t]\n", 1103 | " print(context, \" \", target)\n", 1104 | " print()" 1105 | ] 1106 | }, 1107 | { 1108 | "cell_type": "markdown", 1109 | "id": "759eccde", 1110 | "metadata": { 1111 | "id": "759eccde" 1112 | }, 1113 | "source": [ 1114 | "### Models" 1115 | ] 1116 | }, 1117 | { 1118 | "cell_type": "markdown", 1119 | "id": "df33320b", 1120 | "metadata": { 1121 | "id": "df33320b" 1122 | }, 1123 | "source": [ 1124 | "Let's import PyTorch's neural network modules." 1125 | ] 1126 | }, 1127 | { 1128 | "cell_type": "code", 1129 | "execution_count": 33, 1130 | "id": "90997e9e", 1131 | "metadata": { 1132 | "id": "90997e9e" 1133 | }, 1134 | "outputs": [], 1135 | "source": [ 1136 | "import torch.nn as nn\n", 1137 | "from torch.nn import functional as F" 1138 | ] 1139 | }, 1140 | { 1141 | "cell_type": "markdown", 1142 | "id": "9cd1f110", 1143 | "metadata": { 1144 | "id": "9cd1f110" 1145 | }, 1146 | "source": [ 1147 | "The core of the MoE technique is implemented by the following code. The MoE layer is a neural network layer that combines the outputs of multiple expert networks according to a gating mechanism, which is itself learned.\n", 1148 | "\n", 1149 | "The `__init__` method initializes the `MoeLayer` with a list of expert modules (`experts`), a gate module (`gate`) and a parameter `k` (default value 1). The experts are the individual neural networks that make up the mixture; here they are feed-forward networks. The gate is another neural network (a linear layer) that produces gate logits, which are used to weight the contributions of the experts. 
The parameter `k` sets how many experts are selected for each token, based on the gate logits (the raw, unnormalized scores produced by the gate module).\n", 1150 | "\n", 1151 | "Let's now look at the mechanics of the `forward` method. First, the input tensor `inputs` (of shape `(B, T, C)` in this model) is reshaped (squashed) into a 2-D tensor of shape `(B*T, C)`, so that each token is routed independently, and is passed through the gate module to obtain the gate logits. For each token, the `k` experts with the highest gate logits are selected using `torch.topk`.\n", 1152 | "\n", 1153 | "The `k` selected logits are then normalized with the softmax function along the second dimension (`dim=1`). This results in a probability distribution over the selected experts only.\n", 1154 | "\n", 1155 | "For each expert, `torch.where` finds the tokens routed to it; the expert processes only those tokens, and its output is scaled by the corresponding weight and accumulated into `results`. The final result is the weighted sum of the selected experts' outputs, i.e. the output of the mixture of experts layer.\n", 1156 | "\n", 1157 | "Finally, the output tensor is reshaped to match the shape of the input tensor and returned." 1158 | ] 1159 | }, 1160 | { 1161 | "cell_type": "code", 1162 | "execution_count": 34, 1163 | "id": "cc99dca5", 1164 | "metadata": { 1165 | "id": "cc99dca5" 1166 | }, 1167 | "outputs": [], 1168 | "source": [ 1169 | "class MoeLayer(nn.Module):\n", 1170 | " def __init__(self, experts, gate, k=1):\n", 1171 | " super().__init__()\n", 1172 | " assert len(experts) > 0\n", 1173 | " self.experts = nn.ModuleList(experts)\n", 1174 | " self.gate = gate\n", 1175 | " self.k = k # number of experts selected per token\n", 1176 | "\n", 1177 | " def forward(self, inputs: torch.Tensor):\n", 1178 | " inputs_squashed = inputs.view(-1, inputs.shape[-1]) # (B*T, C)\n", 1179 | " gate_logits = self.gate(inputs_squashed) # (B*T, num_experts)\n", 1180 | " weights, selected_experts = torch.topk(\n", 1181 | " gate_logits, self.k\n", 1182 | " ) # both (B*T, k)\n", 1183 | " weights = nn.functional.softmax(\n", 1184 | " weights,\n", 1185 | " dim=1,\n", 1186 | " dtype=torch.float,\n", 1187 | " ).type_as(inputs)\n", 1188 | " results = torch.zeros_like(inputs_squashed)\n", 1189 | " for i, expert in 
enumerate(self.experts):\n", 1190 | " batch_idx, nth_expert = torch.where(selected_experts == i) # tokens routed to expert i\n", 1191 | " results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(\n", 1192 | " inputs_squashed[batch_idx]\n", 1193 | " )\n", 1194 | " return results.view_as(inputs)" 1195 | ] 1196 | }, 1197 | { 1198 | "cell_type": "markdown", 1199 | "id": "dbe9bb49", 1200 | "metadata": { 1201 | "id": "dbe9bb49" 1202 | }, 1203 | "source": [ 1204 | "The picture below shows the plain Transformer encoder architecture (left) and its MoE-modified version (right). The Block module is implemented by the `Block` class, which we will see shortly; the model actually stacks n Block modules, where n is set by `n_layer`." 1205 | ] 1206 | }, 1207 | { 1208 | "cell_type": "markdown", 1209 | "id": "cb065da2", 1210 | "metadata": { 1211 | "id": "cb065da2" 1212 | }, 1213 | "source": [ 1214 | "![Transformer](https://github.com/antonio-f/mixture-of-experts-from-scratch/blob/main/Transf.png?raw=1)" 1215 | ] 1216 | }, 1217 | { 1218 | "cell_type": "markdown", 1219 | "id": "4f671d79", 1220 | "metadata": { 1221 | "id": "4f671d79" 1222 | }, 1223 | "source": [ 1224 | "Below is a more detailed picture of the MoE layer (taken from https://arxiv.org/pdf/2101.03961.pdf). \"Router\" represents the gating module, and the experts are feed-forward networks (FFN 1, 2, 3 and 4)." 1225 | ] 1226 | }, 1227 | { 1228 | "cell_type": "markdown", 1229 | "id": "285a9fd6", 1230 | "metadata": { 1231 | "id": "285a9fd6" 1232 | }, 1233 | "source": [ 1234 | "![MoE](https://github.com/antonio-f/mixture-of-experts-from-scratch/blob/main/MoE.png?raw=1)" 1235 | ] 1236 | }, 1237 | { 1238 | "cell_type": "markdown", 1239 | "id": "39f4ccfa", 1240 | "metadata": { 1241 | "id": "39f4ccfa" 1242 | }, 1243 | "source": [ 1244 | "Below is the code for the Transformer model, modified to include the MoE layer. The Transformer consists of several blocks, so to implement the `Transformer` class we first need to implement the `Block` class. 
In turn, to implement the `Block` class, we need the `MulitHeadAttention` and `FeedForward` classes (in addition to `MoeLayer`, already defined). To define `MulitHeadAttention`, we need the `Head` class." 1245 | ] 1246 | }, 1247 | { 1248 | "cell_type": "code", 1249 | "execution_count": 35, 1250 | "id": "b2a585da", 1251 | "metadata": { 1252 | "id": "b2a585da" 1253 | }, 1254 | "outputs": [], 1255 | "source": [ 1256 | "class Head(nn.Module):\n", 1257 | " def __init__(self, head_size):\n", 1258 | " super().__init__()\n", 1259 | " self.key = nn.Linear(n_embed, head_size, bias=False)\n", 1260 | " self.query = nn.Linear(n_embed, head_size, bias=False)\n", 1261 | " self.value = nn.Linear(n_embed, head_size, bias=False)\n", 1262 | " self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size))) # causal mask\n", 1263 | " self.dropout = nn.Dropout(dropout)\n", 1264 | "\n", 1265 | " def forward(self, x):\n", 1266 | " B, T, C = x.shape\n", 1267 | " k = self.key(x)\n", 1268 | " q = self.query(x)\n", 1269 | " wei = q @ k.transpose(-2, -1) * k.shape[-1]**-0.5 # scale by 1/sqrt(head_size)\n", 1270 | " wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # no attention to future tokens\n", 1271 | " wei = F.softmax(wei, dim=-1)\n", 1272 | " wei = self.dropout(wei)\n", 1273 | " v = self.value(x)\n", 1274 | " out = wei @ v\n", 1275 | " return out\n", 1276 | "\n", 1277 | "class MulitHeadAttention(nn.Module):\n", 1278 | " def __init__(self, num_heads, head_size):\n", 1279 | " super().__init__()\n", 1280 | " self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])\n", 1281 | " self.proj = nn.Linear(n_embed, n_embed)\n", 1282 | " self.dropout = nn.Dropout(dropout)\n", 1283 | "\n", 1284 | " def forward(self, x):\n", 1285 | " x = torch.cat([head(x) for head in self.heads], dim=-1)\n", 1286 | " out = self.dropout(self.proj(x))\n", 1287 | " return out\n", 1288 | "\n", 1289 | "\n", 1290 | "class FeedForward(nn.Module):\n", 1291 | " def __init__(self, n_embed):\n", 1292 | " super().__init__()\n", 1293 | " self.net = nn.Sequential(\n", 1294 | " 
nn.Linear(n_embed, 4 * n_embed),\n", 1295 | " nn.ReLU(),\n", 1296 | " nn.Linear(4 * n_embed, n_embed),\n", 1297 | " nn.Dropout(dropout))\n", 1298 | "\n", 1299 | " def forward(self, x):\n", 1300 | " return self.net(x)\n", 1301 | "\n", 1302 | "class Block(nn.Module):\n", 1303 | " def __init__(self, n_embed, n_head, num_experts=4):\n", 1304 | " super().__init__()\n", 1305 | " self.sa_head = MulitHeadAttention(n_head, n_embed//n_head)\n", 1306 | " self.ffw = MoeLayer( # MoE layer in place of a single FeedForward\n", 1307 | " experts=[FeedForward(n_embed) for _ in range(num_experts)],\n", 1308 | " gate=nn.Linear(n_embed, num_experts, bias=False),\n", 1309 | " )\n", 1310 | "\n", 1311 | " self.ln1 = nn.LayerNorm(n_embed)\n", 1312 | " self.ln2 = nn.LayerNorm(n_embed)\n", 1313 | "\n", 1314 | " def forward(self, x):\n", 1315 | " x = x + self.sa_head(self.ln1(x))\n", 1316 | " x = x + self.ffw(self.ln2(x))\n", 1317 | " return x\n", 1318 | "\n", 1319 | "\n", 1320 | "class Transformer(nn.Module):\n", 1321 | " def __init__(self):\n", 1322 | " super().__init__()\n", 1323 | "\n", 1324 | " self.token_embedding_table = nn.Embedding(vocab_size, n_embed, device=device)\n", 1325 | " self.position_embedding_table = nn.Embedding(block_size, n_embed, device=device)\n", 1326 | " self.blocks = nn.Sequential(*[Block(n_embed, n_head=n_head) for _ in range(n_layer)])\n", 1327 | " self.lm_head = nn.Linear(n_embed, vocab_size)\n", 1328 | "\n", 1329 | "\n", 1330 | " def forward(self, idx, targets=None):\n", 1331 | " B, T = idx.shape\n", 1332 | "\n", 1333 | " token_emb = self.token_embedding_table(idx)\n", 1334 | " pos_emb = self.position_embedding_table(torch.arange(T, device=idx.device))\n", 1335 | " x = token_emb + pos_emb\n", 1336 | " x = self.blocks(x)\n", 1337 | " logits = self.lm_head(x)\n", 1338 | " if targets is None:\n", 1339 | " loss = None\n", 1340 | " else:\n", 1341 | " B, T, C = logits.shape\n", 1342 | " logits = logits.view(B*T, C)\n", 1343 | " targets = targets.view(B*T)\n", 1344 | " loss = F.cross_entropy(logits, targets)\n", 1345 | " 
return logits, loss\n", 1346 | "\n", 1347 | " def generate(self, idx, max_new_tokes):\n", 1348 | " for _ in range(max_new_tokes):\n", 1349 | " idx_cond = idx[:, -block_size:] # crop the context to the last block_size tokens\n", 1350 | " logits, _ = self(idx_cond)\n", 1351 | " logits = logits[:, -1, :] # keep only the last time step\n", 1352 | " probs = F.softmax(logits, dim=-1)\n", 1353 | " idx_next = torch.multinomial(probs, num_samples=1)\n", 1354 | " idx = torch.cat((idx, idx_next), dim=1)\n", 1355 | " return idx" 1356 | ] 1357 | }, 1358 | { 1359 | "cell_type": "markdown", 1360 | "id": "e08f7997", 1361 | "metadata": { 1362 | "id": "e08f7997" 1363 | }, 1364 | "source": [ 1365 | "Here are all the necessary hyperparameters. `max_iters` is set to 3000 for testing (training will still take some time). Results probably start to become meaningful for values larger than 5000." 1366 | ] 1367 | }, 1368 | { 1369 | "cell_type": "code", 1370 | "execution_count": 36, 1371 | "id": "35bc917b", 1372 | "metadata": { 1373 | "id": "35bc917b" 1374 | }, 1375 | "outputs": [], 1376 | "source": [ 1377 | "# hyperparameters\n", 1378 | "batch_size = 64 # independent sequences processed in parallel\n", 1379 | "block_size = 256 # max context length\n", 1380 | "max_iters = 3000\n", 1381 | "eval_interval = 100\n", 1382 | "learning_rate = 1e-3\n", 1383 | "eval_iters = 200\n", 1384 | "n_embd = 384\n", 1385 | "n_embed = 384\n", 1386 | "n_head = 6\n", 1387 | "n_layer = 6\n", 1388 | "dropout = 0.0\n", 1389 | "\n", 1390 | "# set device\n", 1391 | "device = 'cuda' if torch.cuda.is_available() else 'cpu'" 1392 | ] 1393 | }, 1394 | { 1395 | "cell_type": "markdown", 1396 | "id": "1063a544", 1397 | "metadata": { 1398 | "id": "1063a544" 1399 | }, 1400 | "source": [ 1401 | "### Model training" 1402 | ] 1403 | }, 1404 | { 1405 | "cell_type": "markdown", 1406 | "id": "740507d5", 1407 | "metadata": {}, 1408 | "source": [ 1409 | "Our model is the previously defined Transformer."
1410 | ] 1411 | }, 1412 | { 1413 | "cell_type": "code", 1414 | "execution_count": 37, 1415 | "id": "24122aa8", 1416 | "metadata": { 1417 | "id": "24122aa8" 1418 | }, 1419 | "outputs": [], 1420 | "source": [ 1421 | "model = Transformer()" 1422 | ] 1423 | }, 1424 | { 1425 | "cell_type": "markdown", 1426 | "id": "e59af998", 1427 | "metadata": { 1428 | "id": "e59af998" 1429 | }, 1430 | "source": [ 1431 | "The function below estimates the loss on the training and validation splits, averaging it over `eval_iters` batches per split." 1432 | ] 1433 | }, 1434 | { 1435 | "cell_type": "code", 1436 | "execution_count": 38, 1437 | "id": "18a31597", 1438 | "metadata": { 1439 | "id": "18a31597" 1440 | }, 1441 | "outputs": [], 1442 | "source": [ 1443 | "@torch.no_grad()\n", 1444 | "def estimate_loss():\n", 1445 | " out = {}\n", 1446 | " model.eval()\n", 1447 | " for split in ['train', 'val']:\n", 1448 | " losses = torch.zeros(eval_iters)\n", 1449 | " for k in range(eval_iters):\n", 1450 | " X, Y = get_batch(split)\n", 1451 | " X = X.to(device)\n", 1452 | " Y = Y.to(device)\n", 1453 | " logits, loss = model(X, Y)\n", 1454 | " losses[k] = loss.item()\n", 1455 | " out[split] = losses.mean()\n", 1456 | " model.train()\n", 1457 | " return out" 1458 | ] 1459 | }, 1460 | { 1461 | "cell_type": "code", 1462 | "execution_count": 39, 1463 | "id": "QtLU4WDSPoIF", 1464 | "metadata": { 1465 | "colab": { 1466 | "base_uri": "https://localhost:8080/", 1467 | "height": 35 1468 | }, 1469 | "id": "QtLU4WDSPoIF", 1470 | "outputId": "804a9d4d-2f88-462b-e1e7-62d9f498a664" 1471 | }, 1472 | "outputs": [ 1473 | { 1474 | "data": { 1475 | "application/vnd.google.colaboratory.intrinsic+json": { 1476 | "type": "string" 1477 | }, 1478 | "text/plain": [ 1479 | "'cuda'" 1480 | ] 1481 | }, 1482 | "execution_count": 39, 1483 | "metadata": {}, 1484 | "output_type": "execute_result" 1485 | } 1486 | ], 1487 | "source": [ 1488 | "device" 1489 | ] 1490 | }, 1491 | { 1492 | "cell_type": "markdown", 1493 | "id": "2bc8643a", 1494 | "metadata": { 1495 | "id": "2bc8643a" 1496 | }, 
1497 | "source": [ 1498 | "Move the model to the device and use the [AdamW](https://www.fast.ai/posts/2018-07-02-adam-weight-decay.html) optimizer. Note that the learning rate is passed directly as `lr=1e-4`, so the `learning_rate` hyperparameter defined above is not used." 1499 | ] 1500 | }, 1501 | { 1502 | "cell_type": "code", 1503 | "execution_count": 40, 1504 | "id": "ba7349df", 1505 | "metadata": { 1506 | "id": "ba7349df" 1507 | }, 1508 | "outputs": [], 1509 | "source": [ 1510 | "model = model.to(device)\n", 1511 | "optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)" 1512 | ] 1513 | }, 1514 | { 1515 | "cell_type": "markdown", 1516 | "id": "f7aca4de", 1517 | "metadata": { 1518 | "id": "f7aca4de" 1519 | }, 1520 | "source": [ 1521 | "The training loop. If `max_iters` is large, it may take some time to complete." 1522 | ] 1523 | }, 1524 | { 1525 | "cell_type": "code", 1526 | "execution_count": 41, 1527 | "id": "e43b298a", 1528 | "metadata": { 1529 | "colab": { 1530 | "base_uri": "https://localhost:8080/" 1531 | }, 1532 | "id": "e43b298a", 1533 | "outputId": "68f6185f-13a3-437b-c9a5-0dc1456a5dfb" 1534 | }, 1535 | "outputs": [ 1536 | { 1537 | "name": "stdout", 1538 | "output_type": "stream", 1539 | "text": [ 1540 | "step 0: train loss 4.9073, val loss 4.9073\n", 1541 | "step 100: train loss 2.3431, val loss 2.3454\n", 1542 | "step 200: train loss 2.3039, val loss 2.3042\n", 1543 | "step 300: train loss 2.2779, val loss 2.2779\n", 1544 | "step 400: train loss 2.2433, val loss 2.2438\n", 1545 | "step 500: train loss 2.1811, val loss 2.1828\n", 1546 | "step 600: train loss 2.0586, val loss 2.0600\n", 1547 | "step 700: train loss 1.8800, val loss 1.8853\n", 1548 | "step 800: train loss 1.7369, val loss 1.7424\n", 1549 | "step 900: train loss 1.6339, val loss 1.6397\n", 1550 | "step 1000: train loss 1.5603, val loss 1.5576\n", 1551 | "step 1100: train loss 1.4920, val loss 1.4932\n", 1552 | "step 1200: train loss 1.4438, val loss 1.4467\n", 1553 | "step 1300: train loss 1.3997, val loss 1.4049\n", 1554 | "step 1400: train loss 1.3656, val loss 1.3669\n", 1555 | "step 1500: train loss 
1.3264, val loss 1.3289\n", 1556 | "step 1600: train loss 1.3024, val loss 1.2976\n", 1557 | "step 1700: train loss 1.2736, val loss 1.2743\n", 1558 | "step 1800: train loss 1.2499, val loss 1.2537\n", 1559 | "step 1900: train loss 1.2261, val loss 1.2253\n", 1560 | "step 2000: train loss 1.2046, val loss 1.2061\n", 1561 | "step 2100: train loss 1.1865, val loss 1.1890\n", 1562 | "step 2200: train loss 1.1698, val loss 1.1704\n", 1563 | "step 2300: train loss 1.1549, val loss 1.1545\n", 1564 | "step 2400: train loss 1.1383, val loss 1.1397\n", 1565 | "step 2500: train loss 1.1250, val loss 1.1214\n", 1566 | "step 2600: train loss 1.1100, val loss 1.1127\n", 1567 | "step 2700: train loss 1.0963, val loss 1.0971\n", 1568 | "step 2800: train loss 1.0880, val loss 1.0880\n", 1569 | "step 2900: train loss 1.0735, val loss 1.0768\n", 1570 | "step 2999: train loss 1.0622, val loss 1.0644\n" 1571 | ] 1572 | } 1573 | ], 1574 | "source": [ 1575 | "\n", 1576 | "for step in range(max_iters):\n", 1577 | "\n", 1578 | " # print the loss on train and val datasets every eval_interval steps\n", 1579 | " if step % eval_interval == 0 or step == max_iters - 1:\n", 1580 | " losses = estimate_loss()\n", 1581 | " print(f\"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}\")\n", 1582 | "\n", 1583 | " # sample a batch of data\n", 1584 | " xb, yb = get_batch('train')\n", 1585 | " xb = xb.to(device)\n", 1586 | " yb = yb.to(device)\n", 1587 | "\n", 1588 | " # forward pass, then backpropagate and update the parameters\n", 1589 | " logits, loss = model(xb, yb)\n", 1590 | " optimizer.zero_grad(set_to_none=True)\n", 1591 | " loss.backward()\n", 1592 | " optimizer.step()" 1593 | ] 1594 | }, 1595 | { 1596 | "cell_type": "markdown", 1597 | "id": "8d8447c3", 1598 | "metadata": { 1599 | "id": "8d8447c3" 1600 | }, 1601 | "source": [ 1602 | "### Model evaluation" 1603 | ] 1604 | }, 1605 | { 1606 | "cell_type": "markdown", 1607 | "id": "b1da22a0", 1608 | "metadata": { 1609 | "id": "b1da22a0" 1610 | }, 1611 | "source": [ 1612 | "We first test our model by encoding a short prompt `d` and letting it generate 500 new tokens."
1613 | ] 1614 | }, 1615 | { 1616 | "cell_type": "code", 1617 | "execution_count": 42, 1618 | "id": "SGZjghiVcR7q", 1619 | "metadata": { 1620 | "colab": { 1621 | "base_uri": "https://localhost:8080/" 1622 | }, 1623 | "id": "SGZjghiVcR7q", 1624 | "outputId": "e7837b8f-6e2b-4ac4-b5ee-5498312d2b0f" 1625 | }, 1626 | "outputs": [ 1627 | { 1628 | "name": "stdout", 1629 | "output_type": "stream", 1630 | "text": [ 1631 | "a long time ago, there was a she what orn it was drawaying.\n", 1632 | "Lily said on the tress and went fast, what let so deep. So, he said, \"From you, Max! I have full new get special?\" But so atcher amaze her paint and hellped swing that mudre that he every day.\n", 1633 | "One Bunny day, a ball abloove make turn very thought animals alun. Lily asked the field mortor the ground, another of get theree were so aftul, scareful deond again.\n", 1634 | "One day, a mexe, something more sak yurng afr he could the make slove locks? \n", 1635 | "Lily asked her for to man stook \n" 1636 | ] 1637 | } 1638 | ], 1639 | "source": [ 1640 | "d = 'a long time ago, there was a '\n", 1641 | "x = torch.tensor(encode(d), dtype=torch.long, device=device).unsqueeze(0)\n", 1642 | "print(decode(model.generate(x, max_new_tokes=500)[0].tolist()))" 1643 | ] 1644 | }, 1645 | { 1646 | "cell_type": "markdown", 1647 | "id": "92898cb0", 1648 | "metadata": {}, 1649 | "source": [ 1650 | "### Useful links" 1651 | ] 1652 | }, 1653 | { 1654 | "cell_type": "markdown", 1655 | "id": "1c637c0a", 1656 | "metadata": {}, 1657 | "source": [ 1658 | "Mixture of Experts from scratch ([Substack](https://monads.substack.com/p/mixture-of-experts-from-scratch)) ([Wordpress](https://m0nads.wordpress.com/2024/02/08/mixture-of-experts-from-scratch/))\n", 1659 | "\n", 1660 | "Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity
\n", 1661 | "W. Fedus, B. Zoph, N. Shazeer
 \n", 1662 | "[arXiv:2101.03961v3](https://arxiv.org/abs/2101.03961) [cs.LG] (2021, rev. 2022)\n", 1663 | "\n", 1664 | "GShard: Scaling Giant Models with Conditional Computation and Automatic Sharding
\n", 1665 | "D. Lepikhin, H. Lee, Y. Xu, D. Chen, O. Firat, Y. Huang, M. Krikun, N. Shazeer, Z. Chen
\n", 1666 | "[arXiv:2006.16668v1](https://arxiv.org/abs/2006.16668) [cs.CL] (2020)\n", 1667 | "\n", 1668 | "TinyStories dataset ([link](https://huggingface.co/datasets/roneneldan/TinyStories))\n", 1669 | "\n", 1670 | "Mixture of Experts Explained ([link](https://huggingface.co/blog/moe))" 1671 | ] 1672 | } 1673 | ], 1674 | "metadata": { 1675 | "accelerator": "GPU", 1676 | "colab": { 1677 | "gpuType": "T4", 1678 | "provenance": [] 1679 | }, 1680 | "kernelspec": { 1681 | "display_name": "Python 3", 1682 | "name": "python3" 1683 | }, 1684 | "language_info": { 1685 | "codemirror_mode": { 1686 | "name": "ipython", 1687 | "version": 3 1688 | }, 1689 | "file_extension": ".py", 1690 | "mimetype": "text/x-python", 1691 | "name": "python", 1692 | "nbconvert_exporter": "python", 1693 | "pygments_lexer": "ipython3", 1694 | "version": "3.11.7" 1695 | } 1696 | }, 1697 | "nbformat": 4, 1698 | "nbformat_minor": 5 1699 | } 1700 | --------------------------------------------------------------------------------
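The routing step of `MoeLayer.forward` in `moe.ipynb` (keep the `k` largest gate logits, then apply a softmax over those `k` values only) can be illustrated without PyTorch for a single token. What follows is a minimal sketch, not the notebook's implementation: the helpers `top_k_route` and `moe_combine`, the toy "experts" that simply scale their input, and the gate logits are all hypothetical, chosen only to make the arithmetic easy to follow.

```python
import math


def top_k_route(gate_logits, k):
    """Route one token: return (expert_index, weight) pairs.

    Mirrors the torch.topk + softmax steps of MoeLayer.forward: keep the
    k largest gate logits, then normalize only those k values.
    """
    # Indices of the k largest logits, largest first (like torch.topk).
    top = sorted(range(len(gate_logits)),
                 key=lambda i: gate_logits[i], reverse=True)[:k]
    # Numerically stable softmax over the selected logits only.
    m = max(gate_logits[i] for i in top)
    exps = [math.exp(gate_logits[i] - m) for i in top]
    total = sum(exps)
    return [(i, e / total) for i, e in zip(top, exps)]


def moe_combine(x, experts, gate_logits, k):
    """Weighted sum of the selected experts' outputs for one token vector x."""
    out = [0.0] * len(x)
    for i, w in top_k_route(gate_logits, k):
        y = experts[i](x)  # only the selected experts are evaluated
        out = [o + w * yi for o, yi in zip(out, y)]
    return out


# Hypothetical toy experts: expert j multiplies its input by a constant.
def make_scaler(s):
    return lambda v: [s * a for a in v]


experts = [make_scaler(s) for s in (1.0, 2.0, 3.0, 4.0)]
logits = [0.1, 2.0, -1.0, 1.0]  # made-up gate logits for one token

print(top_k_route(logits, 1))                       # [(1, 1.0)]
print(moe_combine([1.0, 1.0], experts, logits, 1))  # [2.0, 2.0]
print(top_k_route(logits, 2))  # experts 1 and 3, weights ~0.731 and ~0.269
```

With `k=1` (the default in `MoeLayer`), the softmax over a single selected logit is always exactly 1.0, so each token is processed by one expert at full weight. A consequence worth keeping in mind: because that weight is constant, the gate receives no gradient through it in this formulation. The Switch Transformer paper instead scales the chosen expert's output by the router probability computed with a softmax over all experts.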