├── transformer.png ├── README.md ├── LICENSE └── Transformer_walkthrough.ipynb /transformer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/markriedl/transformer-walkthrough/main/transformer.png -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # A walkthrough of transformer architecture code 2 | 3 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1PXZMr0KrOUYsWHI7Pae6iNkLPoWEkI4n?usp=sharing) 4 | 5 | The notebook walks through a single forward pass of the Transformer architecture in pytorch. It is meant for illustration and educational purposes only. The walkthrough explains every stage of the architecture accompanied by a detailed computation graph. 6 | 7 | ![Transformer Computation Graph](https://github.com/markriedl/transformer-walkthrough/blob/main/transformer.png) 8 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 markriedl 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Transformer_walkthrough.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Transformer-walkthrough.ipynb", 7 | "provenance": [], 8 | "collapsed_sections": [], 9 | "authorship_tag": "ABX9TyN3m5R0G4c8CGAWO6LxY3nV", 10 | "include_colab_link": true 11 | }, 12 | "kernelspec": { 13 | "name": "python3", 14 | "display_name": "Python 3" 15 | }, 16 | "language_info": { 17 | "name": "python" 18 | } 19 | }, 20 | "cells": [ 21 | { 22 | "cell_type": "markdown", 23 | "metadata": { 24 | "id": "view-in-github", 25 | "colab_type": "text" 26 | }, 27 | "source": [ 28 | "\"Open" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "source": [ 34 | "# Transformer Code Walkthrough\n", 35 | "\n", 36 | "[Mark Riedl](http://eilab.gatech.edu/mark-riedl.html)\n", 37 | "\n", 38 | "This notebook walks through a single forward pass of the Transformer architecture in pytorch. It is meant for illustration and educational purposes only. \n", 39 | "\n", 40 | "The Transformer was introduced by Vaswani et al. (2017) in their paper, titled [Attention Is All You Need](https://arxiv.org/abs/1706.03762)." 
41 | ], 42 | "metadata": { 43 | "id": "CoRK-2VVfbMp" 44 | } 45 | }, 46 | { 47 | "cell_type": "markdown", 48 | "source": [ 49 | "# Computation Graph" 50 | ], 51 | "metadata": { 52 | "id": "9riNZp-Jf-HG" 53 | } 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "source": [ 58 | "This is the computation graph, an illustrated diagram of the mathematical operations, their inputs, and their outputs. The inputs at the bottom are fed upward into an encoder and a decoder (depicted side by side, as in a sequence-to-sequence network). At every stage, it shows the matrices and their shapes (excluding the batch dimension, which would make the tensors look more complicated without adding much information). The bubbles show which part of the code below is responsible for each part of the diagram.\n", 59 | "\n", 60 | "![Computation Graph](https://www.dropbox.com/s/bjwdq06zvq703b4/transformer.png?dl=1)" 61 | ], 62 | "metadata": { 63 | "id": "w4kYasNjgBKI" 64 | } 65 | }, 66 | { 67 | "cell_type": "markdown", 68 | "source": [ 69 | "# Imports" 70 | ], 71 | "metadata": { 72 | "id": "ESzQqTWmfDJN" 73 | } 74 | }, 75 | { 76 | "cell_type": "code", 77 | "source": [ 78 | "import torch\n", 79 | "import torch.nn as nn\n", 80 | "import torch.nn.functional as F\n", 81 | "import math\n", 82 | "import numpy as np\n", 83 | "import matplotlib.pyplot as plt" 84 | ], 85 | "metadata": { 86 | "id": "S-cmUU-75qcq" 87 | }, 88 | "execution_count": 1, 89 | "outputs": [] 90 | }, 91 | { 92 | "cell_type": "markdown", 93 | "source": [ 94 | "# Hyperparameters" 95 | ], 96 | "metadata": { 97 | "id": "fl8-t60g85Uz" 98 | } 99 | }, 100 | { 101 | "cell_type": "code", 102 | "source": [ 103 | "d_embed = 512 # embedding size for the attention modules\n", 104 | "num_heads = 8 # Number of attention heads\n", 105 | "num_batches = 1 # batch size, i.e. sequences per batch (1 makes it easier to see what is going on)\n", 106 | "vocab = 50000 # vocab size\n", 107 | "max_len = 5000 # Maximum sequence length for which positional encodings are precomputed\n", 108 | "n_layers = 1 # number of
attention layers (not used but would be an expected hyperparameter)\n", 109 | "d_ff = 2048 # hidden state size in the feed forward layers\n", 110 | "epsilon = 1e-6 # epsilon to use when we need a small non-zero number\n" 111 | ], 112 | "metadata": { 113 | "id": "99lu6YH3kxwF" 114 | }, 115 | "execution_count": 2, 116 | "outputs": [] 117 | }, 118 | { 119 | "cell_type": "markdown", 120 | "source": [ 121 | "# Make dummy data" 122 | ], 123 | "metadata": { 124 | "id": "xt8SMUhu5bSk" 125 | } 126 | }, 127 | { 128 | "cell_type": "markdown", 129 | "source": [ 130 | "Here we create some dummy input data, consisting of three tokens. The 2nd token will be the masked token (in the masks, 1 marks a visible token and 0 marks a masked one). Initially we have an input `x` of size `batch_size x sequence_length`. Throughout, we will use `x` to denote the tensor that originated from the input sequence, and `y` to denote the tensor that originated from the target sequence. " 131 | ], 132 | "metadata": { 133 | "id": "KCzWWhuvfs9h" 134 | } 135 | }, 136 | { 137 | "cell_type": "code", 138 | "source": [ 139 | "x = torch.tensor([[1, 2, 3]]) # input will be 3 tokens\n", 140 | "y = torch.tensor([[1, 2, 3]]) # for many applications, the target is the same as the input\n", 141 | "x_mask = torch.tensor([[1, 0, 1]]) # Mask the 2nd input token\n", 142 | "y_mask = torch.tensor([[1, 0, 1]]) # Mask the 2nd target token\n", 143 | "print(\"x\", x.size())\n", 144 | "print(\"y\", y.size())" 145 | ], 146 | "metadata": { 147 | "colab": { 148 | "base_uri": "https://localhost:8080/" 149 | }, 150 | "id": "AVjmrgvysDd6", 151 | "outputId": "5ed2b8d3-2fff-4d5e-b165-30a3c5eed3a3" 152 | }, 153 | "execution_count": 3, 154 | "outputs": [ 155 | { 156 | "output_type": "stream", 157 | "name": "stdout", 158 | "text": [ 159 | "x torch.Size([1, 3])\n", 160 | "y torch.Size([1, 3])\n" 161 | ] 162 | } 163 | ] 164 | }, 165 | { 166 | "cell_type": "markdown", 167 | "source": [ 168 | "# 1.
Encoder" 169 | ], 170 | "metadata": { 171 | "id": "QG8oYLnl8_mX" 172 | } 173 | }, 174 | { 175 | "cell_type": "markdown", 176 | "source": [ 177 | "This section walks through one attention layer of the encoder. The purpose of the encoder is to create a *hidden state*, an encoded representation of the input sequence. The hidden state is then passed to the decoder." 178 | ], 179 | "metadata": { 180 | "id": "n4NB3UzKgUAH" 181 | } 182 | }, 183 | { 184 | "cell_type": "markdown", 185 | "source": [ 186 | "## 1.1 Encoder Embeddings" 187 | ], 188 | "metadata": { 189 | "id": "yvhxS-vm9QMb" 190 | } 191 | }, 192 | { 193 | "cell_type": "markdown", 194 | "source": [ 195 | "Use a conventional embedding layer to convert the tokens into embeddings of size `d_embed`. The embedding activations are then scaled up by `sqrt(d_embed)`. This matters when the positional encoding is added next: the positional values lie between -1 and 1, and scaling up the embeddings keeps the token information from being drowned out by the position information. The result is a tensor of size `batch_size x sequence_length x embedding_size`." 196 | ], 197 | "metadata": { 198 | "id": "KgtVhtKegehe" 199 | } 200 | }, 201 | { 202 | "cell_type": "code", 203 | "source": [ 204 | "# Make the embedding module. It maps each token ID to its own learned vector of size d_embed.\n", 205 | "emb = nn.Embedding(vocab, d_embed)\n", 206 | "x = emb(x)\n", 207 | "# Scale the embedding\n", 208 | "x = x * math.sqrt(d_embed)\n", 209 | "print(x.size())" 210 | ], 211 | "metadata": { 212 | "colab": { 213 | "base_uri": "https://localhost:8080/" 214 | }, 215 | "id": "VKckLst4k4_N", 216 | "outputId": "fe0ac2c1-ed76-4652-d15b-8f34708d4b87" 217 | }, 218 | "execution_count": 4, 219 | "outputs": [ 220 | { 221 | "output_type": "stream", 222 | "name": "stdout", 223 | "text": [ 224 | "torch.Size([1, 3, 512])\n" 225 | ] 226 | } 227 | ] 228 | }, 229 | { 230 | "cell_type": "markdown", 231 | "source": [ 232 | "Next we add positional embedding information.
The code below creates a pattern of overlapping sine and cosine waves that are added to the embedding. This differentiates embedded tokens based on where they are in the sequence. That is, if an input sequence contains the same token twice, the two embeddings will end up slightly different because of their different positions." 233 | ], 234 | "metadata": { 235 | "id": "CYsqRpXesZtV" 236 | } 237 | }, 238 | { 239 | "cell_type": "code", 240 | "source": [ 241 | "# Start with a tensor of zeros\n", 242 | "pe = torch.zeros(max_len, d_embed, requires_grad=False)\n", 243 | "# column vector of position indices 0...max_len-1\n", 244 | "position = torch.arange(0, max_len).unsqueeze(1)\n", 245 | "divisor = torch.exp(torch.arange(0, d_embed, 2) * -(math.log(10000.0) / d_embed)) # 1 / 10000^(2i / d_embed)\n", 246 | "# Make overlapping sine (even dimensions) and cosine (odd dimensions) waves inside the positional embedding tensor\n", 247 | "pe[:, 0::2] = torch.sin(position * divisor)\n", 248 | "pe[:, 1::2] = torch.cos(position * divisor)\n", 249 | "pe = pe.unsqueeze(0)\n", 250 | "# Add the position embedding to the main embedding\n", 251 | "x = x + pe[:, :x.size(1)]\n", 252 | "print(x.size())" 253 | ], 254 | "metadata": { 255 | "colab": { 256 | "base_uri": "https://localhost:8080/" 257 | }, 258 | "id": "WTaufnQksbjC", 259 | "outputId": "1ea218c2-08fe-4dc7-d7bd-5cbe50ae7d66" 260 | }, 261 | "execution_count": 5, 262 | "outputs": [ 263 | { 264 | "output_type": "stream", 265 | "name": "stdout", 266 | "text": [ 267 | "torch.Size([1, 3, 512])\n" 268 | ] 269 | } 270 | ] 271 | }, 272 | { 273 | "cell_type": "markdown", 274 | "source": [ 275 | "To see how positional embeddings work, we can visualize the values that get added to each dimension of the embedding. For readability, the plot below uses a smaller 16-dimensional positional encoding and shows only its first 4 dimensions over the first 50 positions."
276 | ], 277 | "metadata": { 278 | "id": "xTHKykBG4z6F" 279 | } 280 | }, 281 | { 282 | "cell_type": "code", 283 | "source": [ 284 | "plt.figure(figsize=(15, 5)) # Make a plot\n", 285 | "d_embed_plot = 16 # for illustration purposes, use a 16-dimensional positional encoding\n", 286 | "pe_plot = torch.zeros(max_len, d_embed_plot, requires_grad=False) # positional embedding tensor\n", 287 | "position_plot = torch.arange(0, max_len).unsqueeze(1)\n", 288 | "divisor_plot = torch.exp(torch.arange(0, d_embed_plot, 2) * -(math.log(10000.0) / d_embed_plot))\n", 289 | "pe_plot[:, 0::2] = torch.sin(position_plot * divisor_plot)\n", 290 | "pe_plot[:, 1::2] = torch.cos(position_plot * divisor_plot)\n", 291 | "pe_plot = pe_plot.unsqueeze(0)\n", 292 | "# plot the first 4 dimensions over the first 50 positions\n", 293 | "num_positions = 50\n", 294 | "y_plot = pe_plot[:, :num_positions]\n", 295 | "plt.plot(np.arange(num_positions), y_plot[0, :, 0:4].data.numpy())\n", 296 | "plt.legend([\"dim %d\"%p for p in range(4)])" 297 | ], 298 | "metadata": { 299 | "colab": { 300 | "base_uri": "https://localhost:8080/", 301 | "height": 338 302 | }, 303 | "id": "Ka5VtFT5iLzc", 304 | "outputId": "1f510f12-23d8-4fa8-f3a5-b60cc886ea66" 305 | }, 306 | "execution_count": 6, 307 | "outputs": [ 308 | { 309 | "output_type": "execute_result", 310 | "data": { 311 | "text/plain": [ 312 | "" 313 | ] 314 | }, 315 | "metadata": {}, 316 | "execution_count": 6 317 | }, 318 | { 319 | "output_type": "display_data", 320 | "data": { 321 | "text/plain": [ 322 | "
" 323 | ], 324 | "image/png": "\n" 325 | }, 326 | "metadata": { 327 | "needs_background": "light" 328 | } 329 | } 330 | ] 331 | }, 332 | { 333 | "cell_type": "markdown", 334 | "source": [ 335 | "## 1.2 Encoder Attention Layers" 336 | ], 337 | "metadata": { 338 | "id": "0m4932dp9TOQ" 339 | } 340 | }, 341 | { 342 | "cell_type": "markdown", 343 | "source": [ 344 | "The sub-layers in this section are repeated N times (`n_layers`); this walkthrough takes us through only one. Each encoder attention layer consists of a **self-attention** module followed by a **feed forward** module. \n", 345 | "\n", 346 | "The self-attention and feed forward modules are each wrapped with a residual connection. A residual connection adds the input of a block to the output of the block. Thus one can think of the block as learning how much to add to or subtract from its input. This stabilizes training because the block is not entirely responsible for everything that happens in the forward and backward passes. Looking at the encoder in the computation graph, one can see the residual connections bypassing the self-attention and providing a direct link to the hidden state. That is, the embedding at the bottom has the option of doing much of the heavy lifting for the final hidden state encoding. Self-attention and the other sub-layers may add a little to that final hidden state, or a lot if doing so helps reduce the loss. Another way of thinking about residuals is as subroutines in a conventional computer program that compute some side effect. One subroutine computes the final hidden state. Another subroutine branches off and computes self-attention. But because every module is on the gradient path, each branch must contribute something to the final loss."
347 | ], 348 | "metadata": { 349 | "id": "63Fk4B5U8Ht5" 350 | } 351 | }, 352 | { 353 | "cell_type": "markdown", 354 | "source": [ 355 | "### 1.2.1 Self-Attention Module" 356 | ], 357 | "metadata": { 358 | "id": "RAwy0gSVAtAw" 359 | } 360 | }, 361 | { 362 | "cell_type": "markdown", 363 | "source": [ 364 | "#### 1.2.1.1 Set aside residual" 365 | ], 366 | "metadata": { 367 | "id": "HijbI99O8LNO" 368 | } 369 | }, 370 | { 371 | "cell_type": "markdown", 372 | "source": [ 373 | "A residual adds the inputs back into the outputs so that what happens in between can be thought of as computing a delta to the original. \n", 374 | "\n", 375 | "Strictly speaking, `x_residual = x` would be enough here, because every later step assigns a new tensor to `x` rather than modifying it in place. We use `clone()` anyway as a defensive copy: since we reuse the same `x` variable at every step, the clone guarantees that the residual cannot be overwritten." 376 | ], 377 | "metadata": { 378 | "id": "Wj4Tw3Ec9GGv" 379 | } 380 | }, 381 | { 382 | "cell_type": "code", 383 | "source": [ 384 | "x_residual = x.clone() \n", 385 | "print(x.size())" 386 | ], 387 | "metadata": { 388 | "colab": { 389 | "base_uri": "https://localhost:8080/" 390 | }, 391 | "id": "n5MS2WtY46Kr", 392 | "outputId": "3b65980d-dff5-40af-d3e8-c20a86ea3a40" 393 | }, 394 | "execution_count": 7, 395 | "outputs": [ 396 | { 397 | "output_type": "stream", 398 | "name": "stdout", 399 | "text": [ 400 | "torch.Size([1, 3, 512])\n" 401 | ] 402 | } 403 | ] 404 | }, 405 | { 406 | "cell_type": "markdown", 407 | "source": [ 408 | "#### 1.2.1.2 Pre-Self-Attention Layer Normalization" 409 | ], 410 | "metadata": { 411 | "id": "XmMC5eFW7bF8" 412 | } 413 | }, 414 | { 415 | "cell_type": "markdown", 416 | "source": [ 417 | "Before we compute self-attention, we perform layer normalization. Layer normalization stabilizes training by reducing the chance that values drift to extremes. For each token, it subtracts the mean of that token's embedding and divides by its standard deviation, then applies a learned scale (`W1`) and shift (`b1`). (Normalizing *before* each sub-layer is the \"pre-norm\" arrangement; the original paper instead applied layer normalization after adding the residual.)"
418 | ], 419 | "metadata": { 420 | "id": "deaGpdDo0ezK" 421 | } 422 | }, 423 | { 424 | "cell_type": "code", 425 | "source": [ 426 | "mean = x.mean(-1, keepdim=True) # per-token mean over the embedding dimension\n", 427 | "std = x.std(-1, keepdim=True) # per-token standard deviation\n", 428 | "W1 = nn.Parameter(torch.ones(d_embed)) # learned scale\n", 429 | "b1 = nn.Parameter(torch.zeros(d_embed)) # learned shift\n", 430 | "x = W1 * (x - mean) / (std + epsilon) + b1\n", 431 | "print(x.size())" 432 | ], 433 | "metadata": { 434 | "colab": { 435 | "base_uri": "https://localhost:8080/" 436 | }, 437 | "id": "h0Bmhybj7cia", 438 | "outputId": "ac9eaf1c-ece3-4f5a-f55c-1a7a192cb1b4" 439 | }, 440 | "execution_count": 8, 441 | "outputs": [ 442 | { 443 | "output_type": "stream", 444 | "name": "stdout", 445 | "text": [ 446 | "torch.Size([1, 3, 512])\n" 447 | ] 448 | } 449 | ] 450 | }, 451 | { 452 | "cell_type": "markdown", 453 | "source": [ 454 | "#### 1.2.1.3 Self-Attention" 455 | ], 456 | "metadata": { 457 | "id": "PhNCyBcLlNhI" 458 | } 459 | }, 460 | { 461 | "cell_type": "markdown", 462 | "source": [ 463 | "Self-attention is a process of generating scores that indicate how relevant each token is to every other token. Thus we would expect a `seq_length x seq_length` matrix of values between 0 and 1 (each row summing to 1), where the cell in row i and column j indicates how much the i-th token should attend to the j-th token. What does it mean to be \"relevant\"? Whatever reduces loss. The model must learn how to produce the scores.\n", 464 | "\n", 465 | "A metaphor for understanding self-attention is a hash table. In a hash table, there is a list of keys, each of which is associated with a value. A query is sent to the hash table, and the hash table has to find the matching key and return the associated value. Now imagine a fuzzy hash table, in which the query doesn't have to exactly match any key: the table returns a blend of the values, weighted by how closely each key matches the query.\n", 466 | "\n", 467 | "The input to self-attention is a `batch_size x sequence_length x embedding_size` matrix.
Ignoring the batch dimension, what we have is a sequence of embedded tokens. Self-attention copies this input, `x`, three times and calls the copies the \"query\" (`q`), \"keys\" (`k`), and \"values\" (`v`). Each of those matrices goes through its own linear layer. These linear layers are where the network learns to make scores. They make each matrix different, and if they learn the right projections, the attention scores will be good; good attention scores, in turn, reduce the loss.\n", 468 | "\n", 469 | "Attention scores are generated as follows. First, we split the `q`, `k`, and `v` matrices into multiple parts (called \"heads\"). This is called multi-head attention. We do this so that each head can independently produce different attention scores, which allows each token to have several \"best\" other tokens. In the implementation, we simply assign a chunk of each token embedding to each head.\n", 470 | "\n", 471 | "Then `q` is multiplied by the transpose of `k` (and scaled by `1/sqrt(d_k)`, where `d_k` is the per-head embedding size). This creates a `batch_size x num_heads x sequence_length x sequence_length` tensor. Ignoring batching and heads, one can interpret this matrix as containing the raw scores, where each cell measures how related the i-th token is to the j-th token (i is the row and j is the column). \n", 472 | "\n", 473 | "Next we pass this matrix through a softmax. The secret to softmax is that it can act like a soft, differentiable argmax: it can pick the best match. Softmax squishes all values along a particular dimension into the range 0...1 so that they sum to 1, and it exaggerates differences between scores. When one score in a row is much larger than the others, softmax pushes that cell toward 1 and the rest toward 0. If we multiply this softmaxed score matrix by the `v` matrix, we are in essence asking (for each head) which column is best for each row. Recall that rows and columns correspond to tokens. So we are asking which token goes best with every other token.
Again, if the earlier linear layers get their parameters right, this multiplication will make good choices and the loss will improve.\n", 474 | "\n", 475 | "At this point we can think of the softmaxed scores multiplied against `v` as trying to zero out everything but the most relevant token embedding (several, because of multiple heads). The result, which we store back in `x` for consistency, is mainly the most-attended token embedding (again, several because of multiple heads) plus a little bit of every other embedded token sprinkled in. We can't do an actual argmax; the best we can do is push everything irrelevant close to zero so that it has little impact.\n", 476 | "\n", 477 | "This multiplication of the scores against the `v` matrix is what we refer to as *self-attention*. It is essentially a weighted sum of the value vectors, where the weights come from dot products between learned query and key projections. It tells us where we should look for good information. The decoder will use the same mechanism later." 478 | ], 479 | "metadata": { 480 | "id": "IiXzh1P3_7rC" 481 | } 482 | }, 483 | { 484 | "cell_type": "code", 485 | "source": [ 486 | "# Make three versions of x, for the query, key, and value\n", 487 | "# We don't need to clone because these will immediately go through linear layers, making new tensors\n", 488 | "k = x # key\n", 489 | "q = x # query\n", 490 | "v = x # value\n", 491 | "# Make three linear layers\n", 492 | "# This is where the network learns to make scores\n", 493 | "linear_k = nn.Linear(d_embed, d_embed)\n", 494 | "linear_q = nn.Linear(d_embed, d_embed)\n", 495 | "linear_v = nn.Linear(d_embed, d_embed)\n", 496 | "# We are going to fold the embedding dimensions and treat each fold as an attention head\n", 497 | "d_k = d_embed // num_heads\n", 498 | "# Pass q, k, v through their linear layers\n", 499 | "q = linear_q(q)\n", 500 | "k = linear_k(k)\n", 501 | "v = linear_v(v)\n", 502 | "# Do the fold, treating each group of d_k dimensions as a head\n", 503 | "# Put the head dimension in the second position: (batch, heads, seq_len, d_k)\n", 504 |
"q = q.view(num_batches, -1, num_heads, d_k).transpose(1, 2)\n", 505 | "k = k.view(num_batches, -1, num_heads, d_k).transpose(1, 2)\n", 506 | "v = v.view(num_batches, -1, num_heads, d_k).transpose(1, 2)\n", 507 | "print(\"q\", q.size())\n", 508 | "print(\"k\", k.size())\n", 509 | "print(\"v\", v.size())" 510 | ], 511 | "metadata": { 512 | "colab": { 513 | "base_uri": "https://localhost:8080/" 514 | }, 515 | "id": "kpExj3wl5uRq", 516 | "outputId": "34f2f847-6628-428d-b3b2-b9631dd41b65" 517 | }, 518 | "execution_count": 9, 519 | "outputs": [ 520 | { 521 | "output_type": "stream", 522 | "name": "stdout", 523 | "text": [ 524 | "q torch.Size([1, 8, 3, 64])\n", 525 | "k torch.Size([1, 8, 3, 64])\n", 526 | "v torch.Size([1, 8, 3, 64])\n" 527 | ] 528 | } 529 | ] 530 | }, 531 | { 532 | "cell_type": "markdown", 533 | "source": [ 534 | "To produce the attention scores, we multiply `q` by the transpose of `k` and divide by `sqrt(d_k)`. We then apply the mask: masked positions get a very large negative score, so after the softmax no token attends to a masked token. The softmax emulates argmax (relevant positions close to 1, irrelevant positions close to 0). You won't see this happen if you look at `attn`, because the linear layers aren't trained yet. Finally, the attention scores are applied to `v`."
535 | ], 536 | "metadata": { 537 | "id": "dkJ2lxguABuE" 538 | } 539 | }, 540 | { 541 | "cell_type": "code", 542 | "source": [ 543 | "d_k = q.size(-1)\n", 544 | "# Compute the raw scores by multiplying q by the transpose of k (and scale by 1/sqrt(d_k))\n", 545 | "scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)\n", 546 | "# Mask out the scores: masked positions get a large negative score, so softmax gives them ~0 weight\n", 547 | "scores = scores.masked_fill(x_mask == 0, -1e9)\n", 548 | "# Softmax the scores, ideally creating one score close to 1 and the rest close to 0\n", 549 | "# (Note: this won't happen if you look at the numbers because the linear layers haven't \n", 550 | "# learned anything yet.)\n", 551 | "attn = F.softmax(scores, dim=-1)\n", 552 | "print(\"attention\", attn.size())\n", 553 | "# Apply the scores to v\n", 554 | "x = torch.matmul(attn, v)\n", 555 | "print(\"x\", x.size())" 556 | ], 557 | "metadata": { 558 | "colab": { 559 | "base_uri": "https://localhost:8080/" 560 | }, 561 | "id": "CsYkRUoIDDfU", 562 | "outputId": "954521e1-e00c-4141-8e59-31151a466216" 563 | }, 564 | "execution_count": 10, 565 | "outputs": [ 566 | { 567 | "output_type": "stream", 568 | "name": "stdout", 569 | "text": [ 570 | "attention torch.Size([1, 8, 3, 3])\n", 571 | "x torch.Size([1, 8, 3, 64])\n" 572 | ] 573 | } 574 | ] 575 | }, 576 | { 577 | "cell_type": "markdown", 578 | "source": [ 579 | "The following is an illustration of what self-attention is doing. In the `attn` matrix below, each row and column represents a different position in the input sequence, such that `attn[i][j]` is how much affinity the i-th position has for the j-th position. In a perfect world, the softmax pushes one element in each row close to 1 and everything else close to 0. By multiplying `attn` against `v`, we pick an embedding (hidden state) for each position. (With multi-head attention we pick several, one per head, and combine them, but the cell below doesn't show that.)
" 580 | ], 581 | "metadata": { 582 | "id": "K2tJuasVgbzx" 583 | } 584 | }, 585 | { 586 | "cell_type": "code", 587 | "source": [ 588 | "# Make fake attention scores with extreme values\n", 589 | "attn = torch.zeros(3, 3)\n", 590 | "attn[0,1] = 1\n", 591 | "attn[1,2] = 1\n", 592 | "attn[2,0] = 1\n", 593 | "print(\"attn:\")\n", 594 | "print(attn)\n", 595 | "# Make a fake v embedding: row r holds the values 10*r ... 10*r + 9\n", 596 | "v = torch.arange(30, dtype=torch.float32).view(3, 10)\n", 597 | "print(\"v:\") \n", 598 | "print(v)\n", 599 | "print(\"Matmul result:\")\n", 600 | "print(torch.matmul(attn, v))" 601 | ], 602 | "metadata": { 603 | "colab": { 604 | "base_uri": "https://localhost:8080/" 605 | }, 606 | "id": "z0W5rFyfeSWz", 607 | "outputId": "774a05bd-1ab6-41f2-aee8-59d0a741613a" 608 | }, 609 | "execution_count": 30, 610 | "outputs": [ 611 | { 612 | "output_type": "stream", 613 | "name": "stdout", 614 | "text": [ 615 | "attn:\n", 616 | "tensor([[0., 1., 0.],\n", 617 | " [0., 0., 1.],\n", 618 | " [1., 0., 0.]])\n", 619 | "v:\n", 620 | "tensor([[ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9.],\n", 621 | " [10., 11., 12., 13., 14., 15., 16., 17., 18., 19.],\n", 622 | " [20., 21., 22., 23., 24., 25., 26., 27., 28., 29.]])\n", 623 | "Matmul result:\n", 624 | "tensor([[10., 11., 12., 13., 14., 15., 16., 17., 18., 19.],\n", 625 | " [20., 21., 22., 23., 24., 25., 26., 27., 28., 29.],\n", 626 | " [ 0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]])\n" 627 | ] 628 | } 629 | ] 630 | }, 631 | { 632 | "cell_type": "markdown", 633 | "source": [ 634 | "But now the embeddings are swapped around, and with softer attention scores a little bit of many positions can be mixed together. This is why the residual connection is important: it ensures that the original embeddings are not lost from their original positions." 635 | ], 636 | "metadata": { 637 | "id": "mGAk--KSiSED" 638 | } 639 | }, 640 | { 641 | "cell_type": "markdown", 642 | "source": [ 643 | "Recombine the multiple attention heads (unfold)."
644 | ], 645 | "metadata": { 646 | "id": "6TFWgbA8lRl0" 647 | } 648 | }, 649 | { 650 | "cell_type": "code", 651 | "source": [ 652 | "x = x.transpose(1, 2).contiguous().view(num_batches, -1, num_heads * (d_embed // num_heads)) # (batch, heads, seq_len, d_k) -> (batch, seq_len, d_embed)\n", 653 | "print(x.size())" 654 | ], 655 | "metadata": { 656 | "colab": { 657 | "base_uri": "https://localhost:8080/" 658 | }, 659 | "id": "mdlAJY9UJeUe", 660 | "outputId": "f7fa80fd-4236-4a0f-8263-a359d67420e6" 661 | }, 662 | "execution_count": null, 663 | "outputs": [ 664 | { 665 | "output_type": "stream", 666 | "name": "stdout", 667 | "text": [ 668 | "torch.Size([1, 3, 512])\n" 669 | ] 670 | } 671 | ] 672 | }, 673 | { 674 | "cell_type": "markdown", 675 | "source": [ 676 | "#### 1.2.1.4 Post-Self-Attention Feed forward" 677 | ], 678 | "metadata": { 679 | "id": "aLze8Ea5Beiw" 680 | } 681 | }, 682 | { 683 | "cell_type": "markdown", 684 | "source": [ 685 | "At this point, each token position holds a mixture of value vectors, drawn mostly from the positions whose attention weights were pushed toward 1. We need to prepare this matrix to be added back into the residual, so it goes through one more linear layer (the output projection, called $W^O$ in the original paper) that mixes the concatenated heads together. Whatever comes out of this transformation has to be a set of values that change the original embedding values for each token by some delta up or down."
686 | ], 687 | "metadata": { 688 | "id": "ZEt2a7DD6jss" 689 | } 690 | }, 691 | { 692 | "cell_type": "code", 693 | "source": [ 694 | "ff = nn.Linear(d_embed, d_embed) # output projection over the concatenated heads\n", 695 | "x = ff(x)\n", 696 | "print(x.size())" 697 | ], 698 | "metadata": { 699 | "colab": { 700 | "base_uri": "https://localhost:8080/" 701 | }, 702 | "id": "_EGA4t_myfIA", 703 | "outputId": "b154d19d-b74d-4fe1-a8e4-a032b276125c" 704 | }, 705 | "execution_count": null, 706 | "outputs": [ 707 | { 708 | "output_type": "stream", 709 | "name": "stdout", 710 | "text": [ 711 | "torch.Size([1, 3, 512])\n" 712 | ] 713 | } 714 | ] 715 | }, 716 | { 717 | "cell_type": "markdown", 718 | "source": [ 719 | "#### 1.2.1.5 Add residual back in" 720 | ], 721 | "metadata": { 722 | "id": "BX-aPZGO7-bf" 723 | } 724 | }, 725 | { 726 | "cell_type": "code", 727 | "source": [ 728 | "x = x_residual + x\n", 729 | "print(x.size())" 730 | ], 731 | "metadata": { 732 | "colab": { 733 | "base_uri": "https://localhost:8080/" 734 | }, 735 | "id": "1ECgQ-hr7_QB", 736 | "outputId": "3a5cd86d-3598-4cbc-b940-848fcc5afc52" 737 | }, 738 | "execution_count": null, 739 | "outputs": [ 740 | { 741 | "output_type": "stream", 742 | "name": "stdout", 743 | "text": [ 744 | "torch.Size([1, 3, 512])\n" 745 | ] 746 | } 747 | ] 748 | }, 749 | { 750 | "cell_type": "markdown", 751 | "source": [ 752 | "### 1.2.2 Feed Forward Module" 753 | ], 754 | "metadata": { 755 | "id": "MwVW4AGcA_jU" 756 | } 757 | }, 758 | { 759 | "cell_type": "markdown", 760 | "source": [ 761 | "This is a straightforward decoding and re-encoding of the embedding plus self-attention. What we want by the end of the encoding stage is a hidden state. As in a sequence-to-sequence network, we want a *stack* of hidden states, one for each token. That way, the decoder will be able to look back and attend to the hidden state that will be most useful for decoding by looking just at this stack instead of iterating over all the input tokens.
So whatever is in each token position has to be representative of what is going on in the input. To move the matrix toward a hidden state we expand the embeddings, giving the network some capacity, and then collapse it down again to force it to make trade-offs." 762 | ], 763 | "metadata": { 764 | "id": "xUg-4uJW7JH5" 765 | } 766 | }, 767 | { 768 | "cell_type": "markdown", 769 | "source": [ 770 | "#### 1.2.2.1 Set aside residual" 771 | ], 772 | "metadata": { 773 | "id": "_9B8fYjZ8PMX" 774 | } 775 | }, 776 | { 777 | "cell_type": "code", 778 | "source": [ 779 | "x_residual = x.clone() \n", 780 | "print(x.size())" 781 | ], 782 | "metadata": { 783 | "colab": { 784 | "base_uri": "https://localhost:8080/" 785 | }, 786 | "id": "2msqv_Br6fHu", 787 | "outputId": "4ccaa96e-85ac-42b9-c4b1-17fd502d2aeb" 788 | }, 789 | "execution_count": null, 790 | "outputs": [ 791 | { 792 | "output_type": "stream", 793 | "name": "stdout", 794 | "text": [ 795 | "torch.Size([1, 3, 512])\n" 796 | ] 797 | } 798 | ] 799 | }, 800 | { 801 | "cell_type": "markdown", 802 | "source": [ 803 | "#### 1.2.2.2 Pre-Feed-Forward Layer Normalization" 804 | ], 805 | "metadata": { 806 | "id": "APWv9beO1rki" 807 | } 808 | }, 809 | { 810 | "cell_type": "code", 811 | "source": [ 812 | "mean = x.mean(-1, keepdim=True)\n", 813 | "std = x.std(-1, keepdim=True)\n", 814 | "W2 = nn.Parameter(torch.ones(d_embed))\n", 815 | "b2 = nn.Parameter(torch.zeros(d_embed))\n", 816 | "x = W2 * (x - mean) / (std + epsilon) + b2\n", 817 | "print(x.size())" 818 | ], 819 | "metadata": { 820 | "colab": { 821 | "base_uri": "https://localhost:8080/" 822 | }, 823 | "id": "i2riRAUL8RFu", 824 | "outputId": "bbfd95fd-0768-4be7-bdd2-2bb5d1cf476a" 825 | }, 826 | "execution_count": null, 827 | "outputs": [ 828 | { 829 | "output_type": "stream", 830 | "name": "stdout", 831 | "text": [ 832 | "torch.Size([1, 3, 512])\n" 833 | ] 834 | } 835 | ] 836 | }, 837 | { 838 | "cell_type": "markdown", 839 | "source": [ 840 | "#### 1.2.2.3 Feed Forward" 841 | ], 
842 | "metadata": { 843 | "id": "AwE944o96SDh" 844 | } 845 | }, 846 | { 847 | "cell_type": "markdown", 848 | "source": [ 849 | "This feed forward module expands the embeddings (from `d_embed` to `d_ff` dimensions) and then compresses them again. This is part of the process of transforming the outputs of the self-attention module into a hidden state encoding." 850 | ], 851 | "metadata": { 852 | "id": "zIwN6aUQk2dw" 853 | } 854 | }, 855 | { 856 | "cell_type": "code", 857 | "source": [ 858 | "linear_expand = nn.Linear(d_embed, d_ff)\n", 859 | "linear_compress = nn.Linear(d_ff, d_embed)\n", 860 | "x = linear_compress(F.relu(linear_expand(x)))\n", 861 | "print(x.size())" 862 | ], 863 | "metadata": { 864 | "colab": { 865 | "base_uri": "https://localhost:8080/" 866 | }, 867 | "id": "7aNAnARK6CRp", 868 | "outputId": "58685c56-50f5-4551-e2f9-c6db59a6e926" 869 | }, 870 | "execution_count": null, 871 | "outputs": [ 872 | { 873 | "output_type": "stream", 874 | "name": "stdout", 875 | "text": [ 876 | "torch.Size([1, 3, 512])\n" 877 | ] 878 | } 879 | ] 880 | }, 881 | { 882 | "cell_type": "markdown", 883 | "source": [ 884 | "#### 1.2.2.4 Add residual back in" 885 | ], 886 | "metadata": { 887 | "id": "3ksl3ra36qqc" 888 | } 889 | }, 890 | { 891 | "cell_type": "code", 892 | "source": [ 893 | "x = x_residual + x\n", 894 | "print(x.size())" 895 | ], 896 | "metadata": { 897 | "colab": { 898 | "base_uri": "https://localhost:8080/" 899 | }, 900 | "id": "VhWKoJsi6sec", 901 | "outputId": "e7094c72-7132-432f-f037-c6bb8ea1de01" 902 | }, 903 | "execution_count": null, 904 | "outputs": [ 905 | { 906 | "output_type": "stream", 907 | "name": "stdout", 908 | "text": [ 909 | "torch.Size([1, 3, 512])\n" 910 | ] 911 | } 912 | ] 913 | }, 914 | { 915 | "cell_type": "markdown", 916 | "source": [ 917 | "## 1.3 Final Encoder Layer Normalization" 918 | ], 919 | "metadata": { 920 | "id": "qbm-9J7V8z7w" 921 | } 922 | }, 923 | { 924 | "cell_type": "markdown", 925 | "source": [ 926 | 
"After repeating the self-attention and feed forward sub-layers N
times, we apply one last layer normalization." 927 | ], 928 | "metadata": { 929 | "id": "vrl8f-TEBxMk" 930 | } 931 | }, 932 | { 933 | "cell_type": "code", 934 | "source": [ 935 | "mean = x.mean(-1, keepdim=True)\n", 936 | "std = x.std(-1, keepdim=True)\n", 937 | "Wn = nn.Parameter(torch.ones(d_embed))\n", 938 | "bn = nn.Parameter(torch.zeros(d_embed))\n", 939 | "x = Wn * (x - mean) / (std + epsilon) + bn\n", 940 | "print(x.size())" 941 | ], 942 | "metadata": { 943 | "colab": { 944 | "base_uri": "https://localhost:8080/" 945 | }, 946 | "id": "Lz9Wf7cO6t8I", 947 | "outputId": "1ae80ca2-0fa9-4d1b-b359-83b7b9ec8922" 948 | }, 949 | "execution_count": null, 950 | "outputs": [ 951 | { 952 | "output_type": "stream", 953 | "name": "stdout", 954 | "text": [ 955 | "torch.Size([1, 3, 512])\n" 956 | ] 957 | } 958 | ] 959 | }, 960 | { 961 | "cell_type": "markdown", 962 | "source": [ 963 | "At this point, we should have a matrix, stored in `x`, that we can interpret as a stack of hidden states. The Decoder will attempt to attend to this stack and pick out (via softmax emulating argmax) the hidden state that is most helpful in guessing the word that goes in the masked position." 964 | ], 965 | "metadata": { 966 | "id": "V-Gbr1_UzQUD" 967 | } 968 | }, 969 | { 970 | "cell_type": "code", 971 | "source": [ 972 | "# Signify that the output is the hidden state\n", 973 | "hidden = x\n", 974 | "print(hidden.size())" 975 | ], 976 | "metadata": { 977 | "colab": { 978 | "base_uri": "https://localhost:8080/" 979 | }, 980 | "id": "atpOCyod0Xkw", 981 | "outputId": "68928ee4-8272-484a-8502-a1f05803c1c4" 982 | }, 983 | "execution_count": null, 984 | "outputs": [ 985 | { 986 | "output_type": "stream", 987 | "name": "stdout", 988 | "text": [ 989 | "torch.Size([1, 3, 512])\n" 990 | ] 991 | } 992 | ] 993 | }, 994 | { 995 | "cell_type": "markdown", 996 | "source": [ 997 | "# 2.
Decoder" 998 | ], 999 | "metadata": { 1000 | "id": "EBZ52DJo-Lry" 1001 | } 1002 | }, 1003 | { 1004 | "cell_type": "markdown", 1005 | "source": [ 1006 | "The Decoder works a lot like the Encoder except for one major change. In addition to self-attention and feed-forward modules, the Decoder also includes a *source-attention* module, in which it attends to the hidden state output of the encoder. \n", 1007 | "\n", 1008 | "We will be operating on `y`, the sequence of target tokens, instead of `x`. It may seem odd to treat the target the same way as an input. The closest analog is the sequence-to-sequence network, which would generate a sequence of output tokens one at a time and compare them to the target sequence to compute loss. But here we don't need to generate the output sequence because there is no recurrence. So we just take the target output and treat it as if it were generated by the transformer. The exception is the masked output token (which is normally in the same position as the masked input). For computing loss, we only care whether we get a good prediction for the masked target tokens."
1009 | ], 1010 | "metadata": { 1011 | "id": "7Y0CisCQzmtW" 1012 | } 1013 | }, 1014 | { 1015 | "cell_type": "markdown", 1016 | "source": [ 1017 | "## 2.1 Decoder Embeddings" 1018 | ], 1019 | "metadata": { 1020 | "id": "gD3IQX2A-S9r" 1021 | } 1022 | }, 1023 | { 1024 | "cell_type": "code", 1025 | "source": [ 1026 | "emb_d = nn.Embedding(vocab, d_embed)\n", 1027 | "y = emb_d(y) * math.sqrt(d_embed)\n", 1028 | "print(y.size())" 1029 | ], 1030 | "metadata": { 1031 | "colab": { 1032 | "base_uri": "https://localhost:8080/" 1033 | }, 1034 | "id": "9gzaYeb8-NJB", 1035 | "outputId": "fbebfe93-bd6a-487f-8891-4d0b026cdde3" 1036 | }, 1037 | "execution_count": null, 1038 | "outputs": [ 1039 | { 1040 | "output_type": "stream", 1041 | "name": "stdout", 1042 | "text": [ 1043 | "torch.Size([1, 3, 512])\n" 1044 | ] 1045 | } 1046 | ] 1047 | }, 1048 | { 1049 | "cell_type": "markdown", 1050 | "source": [ 1051 | "Add positional embeddings." 1052 | ], 1053 | "metadata": { 1054 | "id": "qIJG3QC2-VWo" 1055 | } 1056 | }, 1057 | { 1058 | "cell_type": "code", 1059 | "source": [ 1060 | "pe = torch.zeros(max_len, d_embed, requires_grad=False)\n", 1061 | "position = torch.arange(0, max_len).unsqueeze(1)\n", 1062 | "divisor = torch.exp(torch.arange(0, d_embed, 2) * -(math.log(10000.0) / d_embed))\n", 1063 | "pe[:, 0::2] = torch.sin(position * divisor)\n", 1064 | "pe[:, 1::2] = torch.cos(position * divisor)\n", 1065 | "pe = pe.unsqueeze(0)\n", 1066 | "y = y + pe[:, :y.size(1)]\n", 1067 | "print(y.size())" 1068 | ], 1069 | "metadata": { 1070 | "colab": { 1071 | "base_uri": "https://localhost:8080/" 1072 | }, 1073 | "id": "gOXTa8yJ-VC1", 1074 | "outputId": "8b4b9923-0252-44f6-83b2-5e212a58faa0" 1075 | }, 1076 | "execution_count": null, 1077 | "outputs": [ 1078 | { 1079 | "output_type": "stream", 1080 | "name": "stdout", 1081 | "text": [ 1082 | "torch.Size([1, 3, 512])\n" 1083 | ] 1084 | } 1085 | ] 1086 | }, 1087 | { 1088 | "cell_type": "markdown", 1089 | "source": [ 1090 | "## 2.2 Decoder Attention 
Layers" 1091 | ], 1092 | "metadata": { 1093 | "id": "YkI23w9Z_LzP" 1094 | } 1095 | }, 1096 | { 1097 | "cell_type": "markdown", 1098 | "source": [ 1099 | "The decoder layers are repeated N times; this walkthrough only takes us through one. The Decoder Attention Layer consists of a self-attention sub-layer, followed by a source-attention sub-layer, followed by a feed-forward sub-layer. Each of these is wrapped with a residual connection. " 1100 | ], 1101 | "metadata": { 1102 | "id": "0UJrTA06DNoj" 1103 | } 1104 | }, 1105 | { 1106 | "cell_type": "markdown", 1107 | "source": [ 1108 | "### 2.2.1 Self-Attention Sub-Layer" 1109 | ], 1110 | "metadata": { 1111 | "id": "nYKLH9tgEnVz" 1112 | } 1113 | }, 1114 | { 1115 | "cell_type": "markdown", 1116 | "source": [ 1117 | "#### 2.2.1.1 Set aside residual" 1118 | ], 1119 | "metadata": { 1120 | "id": "0mgRzWgkAfFQ" 1121 | } 1122 | }, 1123 | { 1124 | "cell_type": "code", 1125 | "source": [ 1126 | "y_residual = y.clone() \n", 1127 | "print(y.size())" 1128 | ], 1129 | "metadata": { 1130 | "colab": { 1131 | "base_uri": "https://localhost:8080/" 1132 | }, 1133 | "id": "0qZZBDBr_NRh", 1134 | "outputId": "c3dcd539-ffdb-4ff9-bef0-7a861d2ce8ad" 1135 | }, 1136 | "execution_count": null, 1137 | "outputs": [ 1138 | { 1139 | "output_type": "stream", 1140 | "name": "stdout", 1141 | "text": [ 1142 | "torch.Size([1, 3, 512])\n" 1143 | ] 1144 | } 1145 | ] 1146 | }, 1147 | { 1148 | "cell_type": "markdown", 1149 | "source": [ 1150 | "#### 2.2.1.2 Pre-Self-Attention Layer Normalization" 1151 | ], 1152 | "metadata": { 1153 | "id": "oX38ceiTAqp7" 1154 | } 1155 | }, 1156 | { 1157 | "cell_type": "code", 1158 | "source": [ 1159 | "mean = y.mean(-1, keepdim=True)\n", 1160 | "std = y.std(-1, keepdim=True)\n", 1161 | "W1_d = nn.Parameter(torch.ones(d_embed))\n", 1162 | "b1_d = nn.Parameter(torch.zeros(d_embed))\n", 1163 | "y = W1_d * (y - mean) / (std + epsilon) + b1_d\n", 1164 | "print(y.size())" 1165 | ], 1166 | "metadata": { 1167 | 
"colab": { 1168 | "base_uri":
"https://localhost:8080/" 1169 | }, 1170 | "id": "qfK0B9FEArrw", 1171 | "outputId": "980ee0fc-cb6e-41af-a19b-9bada6a49e42" 1172 | }, 1173 | "execution_count": null, 1174 | "outputs": [ 1175 | { 1176 | "output_type": "stream", 1177 | "name": "stdout", 1178 | "text": [ 1179 | "torch.Size([1, 3, 512])\n" 1180 | ] 1181 | } 1182 | ] 1183 | }, 1184 | { 1185 | "cell_type": "markdown", 1186 | "source": [ 1187 | "#### 2.2.1.3 Self-Attention" 1188 | ], 1189 | "metadata": { 1190 | "id": "1ZOJN-Y6AkgT" 1191 | } 1192 | }, 1193 | { 1194 | "cell_type": "code", 1195 | "source": [ 1196 | "k = y\n", 1197 | "q = y\n", 1198 | "v = y\n", 1199 | "linear_q_self = nn.Linear(d_embed, d_embed)\n", 1200 | "linear_k_self = nn.Linear(d_embed, d_embed)\n", 1201 | "linear_v_self = nn.Linear(d_embed, d_embed)\n", 1202 | "d_k = d_embed // num_heads\n", 1203 | "q = linear_q_self(q)\n", 1204 | "k = linear_k_self(k)\n", 1205 | "v = linear_v_self(v)\n", 1206 | "q = q.view(num_batches, -1, num_heads, d_k).transpose(1, 2)\n", 1207 | "k = k.view(num_batches, -1, num_heads, d_k).transpose(1, 2)\n", 1208 | "v = v.view(num_batches, -1, num_heads, d_k).transpose(1, 2)\n", 1209 | "print(\"q\", q.size())\n", 1210 | "print(\"k\", k.size())\n", 1211 | "print(\"v\", v.size())" 1212 | ], 1213 | "metadata": { 1214 | "colab": { 1215 | "base_uri": "https://localhost:8080/" 1216 | }, 1217 | "id": "OZF8LTbiBBNp", 1218 | "outputId": "a2cb6354-784a-4fe4-fb4b-0f4cf8e880db" 1219 | }, 1220 | "execution_count": null, 1221 | "outputs": [ 1222 | { 1223 | "output_type": "stream", 1224 | "name": "stdout", 1225 | "text": [ 1226 | "q torch.Size([1, 8, 3, 64])\n", 1227 | "k torch.Size([1, 8, 3, 64])\n", 1228 | "v torch.Size([1, 8, 3, 64])\n" 1229 | ] 1230 | } 1231 | ] 1232 | }, 1233 | { 1234 | "cell_type": "code", 1235 | "source": [ 1236 | "d_k = q.size(-1)\n", 1237 | "scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k) # one row per query, one column per key\n", 1238 | "scores = scores.masked_fill(y_mask == 0, -1e9) # large 
negative so masked positions get ~0 weight\n", 1239 | "attn = F.softmax(scores,
dim = -1)\n", 1240 | "print(\"attention\", attn.size())\n", 1241 | "y = torch.matmul(attn, v)\n", 1242 | "print(\"y\", y.size())" 1243 | ], 1244 | "metadata": { 1245 | "colab": { 1246 | "base_uri": "https://localhost:8080/" 1247 | }, 1248 | "id": "IWIeDIOpBJ10", 1249 | "outputId": "f0479edd-e5cc-48b8-de7f-b6a3dd9b1010" 1250 | }, 1251 | "execution_count": null, 1252 | "outputs": [ 1253 | { 1254 | "output_type": "stream", 1255 | "name": "stdout", 1256 | "text": [ 1257 | "attention torch.Size([1, 8, 3, 3])\n", 1258 | "y torch.Size([1, 8, 3, 64])\n" 1259 | ] 1260 | } 1261 | ] 1262 | }, 1263 | { 1264 | "cell_type": "markdown", 1265 | "source": [ 1266 | "Assemble heads" 1267 | ], 1268 | "metadata": { 1269 | "id": "6ato2UicBYqj" 1270 | } 1271 | }, 1272 | { 1273 | "cell_type": "code", 1274 | "source": [ 1275 | "y = y.transpose(1, 2).contiguous().view(num_batches, -1, num_heads * (d_embed // num_heads))\n", 1276 | "print(y.size())" 1277 | ], 1278 | "metadata": { 1279 | "colab": { 1280 | "base_uri": "https://localhost:8080/" 1281 | }, 1282 | "id": "4BeJtRBjBX0c", 1283 | "outputId": "b8b486d8-3e24-4e88-a7f9-6a8f4f22ce97" 1284 | }, 1285 | "execution_count": null, 1286 | "outputs": [ 1287 | { 1288 | "output_type": "stream", 1289 | "name": "stdout", 1290 | "text": [ 1291 | "torch.Size([1, 3, 512])\n" 1292 | ] 1293 | } 1294 | ] 1295 | }, 1296 | { 1297 | "cell_type": "markdown", 1298 | "source": [ 1299 | "#### 2.2.1.4 Post-Self-Attention Feed Forward" 1300 | ], 1301 | "metadata": { 1302 | "id": "2bierEXaBjR9" 1303 | } 1304 | }, 1305 | { 1306 | "cell_type": "code", 1307 | "source": [ 1308 | "ff_d1 = nn.Linear(d_embed, d_embed)\n", 1309 | "y = ff_d1(y)\n", 1310 | "print(y.size())" 1311 | ], 1312 | "metadata": { 1313 | "colab": { 1314 | "base_uri": "https://localhost:8080/" 1315 | }, 1316 | "id": "FCnS-owhBioq", 1317 | "outputId": "c9402c1a-41f8-4312-9a95-6f24befe648f" 1318 | }, 1319 | "execution_count": null, 1320 | "outputs": [ 1321 | { 1322 | "output_type": "stream", 1323 | 
"name": "stdout", 1324 | "text": [ 1325 | "torch.Size([1, 3, 512])\n" 1326 | ] 1327 | } 1328 | ] 1329 | }, 1330 | { 1331 | "cell_type": "markdown", 1332 | "source": [ 1333 | "#### 2.2.1.5 Add residual back in" 1334 | ], 1335 | "metadata": { 1336 | "id": "m1vjSeu7Bp7p" 1337 | } 1338 | }, 1339 | { 1340 | "cell_type": "code", 1341 | "source": [ 1342 | "y = y_residual + y\n", 1343 | "print(y.size())" 1344 | ], 1345 | "metadata": { 1346 | "colab": { 1347 | "base_uri": "https://localhost:8080/" 1348 | }, 1349 | "id": "yLGImc73BpS4", 1350 | "outputId": "b0d2b24c-ed51-46b2-8a99-400ea1fe54a8" 1351 | }, 1352 | "execution_count": null, 1353 | "outputs": [ 1354 | { 1355 | "output_type": "stream", 1356 | "name": "stdout", 1357 | "text": [ 1358 | "torch.Size([1, 3, 512])\n" 1359 | ] 1360 | } 1361 | ] 1362 | }, 1363 | { 1364 | "cell_type": "markdown", 1365 | "source": [ 1366 | "### 2.2.2 Source-Attention Sub-Layer" 1367 | ], 1368 | "metadata": { 1369 | "id": "lmgiijtbFiXs" 1370 | } 1371 | }, 1372 | { 1373 | "cell_type": "markdown", 1374 | "source": [ 1375 | "#### 2.2.2.1 Set residual aside" 1376 | ], 1377 | "metadata": { 1378 | "id": "dMk0234JBvJg" 1379 | } 1380 | }, 1381 | { 1382 | "cell_type": "code", 1383 | "source": [ 1384 | "y_residual = y.clone() \n", 1385 | "print(y.size())" 1386 | ], 1387 | "metadata": { 1388 | "colab": { 1389 | "base_uri": "https://localhost:8080/" 1390 | }, 1391 | "id": "afyYY7fmBzB4", 1392 | "outputId": "7c39d2d8-84c9-4f78-e857-f3765ec0e14c" 1393 | }, 1394 | "execution_count": null, 1395 | "outputs": [ 1396 | { 1397 | "output_type": "stream", 1398 | "name": "stdout", 1399 | "text": [ 1400 | "torch.Size([1, 3, 512])\n" 1401 | ] 1402 | } 1403 | ] 1404 | }, 1405 | { 1406 | "cell_type": "markdown", 1407 | "source": [ 1408 | "#### 2.2.2.2 Pre-Source-Attention Layer Normalization" 1409 | ], 1410 | "metadata": { 1411 | "id": "2BxZbrnWB9aW" 1412 | } 1413 | }, 1414 | { 1415 | "cell_type": "code", 1416 | "source": [ 1417 | "mean = y.mean(-1, keepdim=True)\n", 
1418 | "std = y.std(-1, keepdim=True)\n", 1419 | "W2_d = nn.Parameter(torch.ones(d_embed))\n", 1420 | "b2_d = nn.Parameter(torch.zeros(d_embed))\n", 1421 | "y = W2_d * (y - mean) / (std + epsilon) + b2_d\n", 1422 | "print(y.size())" 1423 | ], 1424 | "metadata": { 1425 | "colab": { 1426 | "base_uri": "https://localhost:8080/" 1427 | }, 1428 | "id": "Z4fC-0REB8h8", 1429 | "outputId": "ed1b9ee6-211e-4a57-9045-5c4bcc0493a0" 1430 | }, 1431 | "execution_count": null, 1432 | "outputs": [ 1433 | { 1434 | "output_type": "stream", 1435 | "name": "stdout", 1436 | "text": [ 1437 | "torch.Size([1, 3, 512])\n" 1438 | ] 1439 | } 1440 | ] 1441 | }, 1442 | { 1443 | "cell_type": "markdown", 1444 | "source": [ 1445 | "#### 2.2.2.3 Source Attention" 1446 | ], 1447 | "metadata": { 1448 | "id": "88dSgLpqB4GV" 1449 | } 1450 | }, 1451 | { 1452 | "cell_type": "markdown", 1453 | "source": [ 1454 | "Source attention works just like self-attention, except that the queries come from the decoder while the keys and values come from the encoder output. The scores compare each decoder position's query against every encoder position's key, and the resulting weights are used to mix the encoder's values. That is, for each position in the target sequence, the decoder decides which parts of the encoded source sequence are most relevant."
1455 | ], 1456 | "metadata": { 1457 | "id": "9gww-JtEFxx_" 1458 | } 1459 | }, 1460 | { 1461 | "cell_type": "code", 1462 | "source": [ 1463 | "q = y\n", 1464 | "k = x # notice we are using x (the encoder output)\n", 1465 | "v = x # notice we are using x (the encoder output)\n", 1466 | "linear_q_source = nn.Linear(d_embed, d_embed)\n", 1467 | "linear_k_source = nn.Linear(d_embed, d_embed)\n", 1468 | "linear_v_source = nn.Linear(d_embed, d_embed)\n", 1469 | "d_k = d_embed // num_heads\n", 1470 | "q = linear_q_source(q)\n", 1471 | "k = linear_k_source(k)\n", 1472 | "v = linear_v_source(v)\n", 1473 | "q = q.view(num_batches, -1, num_heads, d_k).transpose(1, 2)\n", 1474 | "k = k.view(num_batches, -1, num_heads, d_k).transpose(1, 2)\n", 1475 | "v = v.view(num_batches, -1, num_heads, d_k).transpose(1, 2)\n", 1476 | "print(\"q\", q.size())\n", 1477 | "print(\"k\", k.size())\n", 1478 | "print(\"v\", v.size())" 1479 | ], 1480 | "metadata": { 1481 | "colab": { 1482 | "base_uri": "https://localhost:8080/" 1483 | }, 1484 | "id": "JTz6K1Z-B3sA", 1485 | "outputId": "714c0487-2571-40ef-a34a-f0dc53d71222" 1486 | }, 1487 | "execution_count": null, 1488 | "outputs": [ 1489 | { 1490 | "output_type": "stream", 1491 | "name": "stdout", 1492 | "text": [ 1493 | "q torch.Size([1, 8, 3, 64])\n", 1494 | "k torch.Size([1, 8, 3, 64])\n", 1495 | "v torch.Size([1, 8, 3, 64])\n" 1496 | ] 1497 | } 1498 | ] 1499 | }, 1500 | { 1501 | "cell_type": "code", 1502 | "source": [ 1503 | "d_k = q.size(-1)\n", 1504 | "scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k) # decoder queries x encoder keys\n", 1505 | "scores = scores.masked_fill(x_mask == 0, -1e9) # note source mask\n", 1506 | "attn = F.softmax(scores, dim = -1)\n", 1507 | "y = torch.matmul(attn, v)\n", 1508 | "print(y.size())" 1509 | ], 1510 | "metadata": { 1511 | "colab": { 1512 | "base_uri": "https://localhost:8080/" 1513 | }, 1514 | "id": "HcRRIaAGC02k", 1515 | "outputId": "ebc2928f-cfde-4f5d-839e-9770b54dfe9f" 1516 | }, 
1517 | "execution_count": null, 1518 | "outputs": [ 1519 | { 1520 | "output_type": "stream", 1521 |
"name": "stdout", 1522 | "text": [ 1523 | "torch.Size([1, 8, 3, 64])\n" 1524 | ] 1525 | } 1526 | ] 1527 | }, 1528 | { 1529 | "cell_type": "markdown", 1530 | "source": [ 1531 | "Assemble heads" 1532 | ], 1533 | "metadata": { 1534 | "id": "cefcyfWDC4JX" 1535 | } 1536 | }, 1537 | { 1538 | "cell_type": "code", 1539 | "source": [ 1540 | "y = y.transpose(1, 2).contiguous().view(num_batches, -1, num_heads * (d_embed // num_heads))\n", 1541 | "print(y.size())" 1542 | ], 1543 | "metadata": { 1544 | "colab": { 1545 | "base_uri": "https://localhost:8080/" 1546 | }, 1547 | "id": "Rrj9aoLXC5Uo", 1548 | "outputId": "fd602406-c63e-4de8-a7e7-d75acb089b49" 1549 | }, 1550 | "execution_count": null, 1551 | "outputs": [ 1552 | { 1553 | "output_type": "stream", 1554 | "name": "stdout", 1555 | "text": [ 1556 | "torch.Size([1, 3, 512])\n" 1557 | ] 1558 | } 1559 | ] 1560 | }, 1561 | { 1562 | "cell_type": "markdown", 1563 | "source": [ 1564 | "#### 2.2.2.4 Post-Source-Attention Feed forward" 1565 | ], 1566 | "metadata": { 1567 | "id": "3FfLZyREC832" 1568 | } 1569 | }, 1570 | { 1571 | "cell_type": "code", 1572 | "source": [ 1573 | "ff_d2 = nn.Linear(d_embed, d_embed)\n", 1574 | "y = ff_d2(y)\n", 1575 | "print(y.size())" 1576 | ], 1577 | "metadata": { 1578 | "colab": { 1579 | "base_uri": "https://localhost:8080/" 1580 | }, 1581 | "id": "bUfjx2KhC9_w", 1582 | "outputId": "b520e61f-a21b-44fe-8f03-8792138af711" 1583 | }, 1584 | "execution_count": null, 1585 | "outputs": [ 1586 | { 1587 | "output_type": "stream", 1588 | "name": "stdout", 1589 | "text": [ 1590 | "torch.Size([1, 3, 512])\n" 1591 | ] 1592 | } 1593 | ] 1594 | }, 1595 | { 1596 | "cell_type": "markdown", 1597 | "source": [ 1598 | "#### 2.2.2.5 Add residual back in" 1599 | ], 1600 | "metadata": { 1601 | "id": "BcZiKusQDBjc" 1602 | } 1603 | }, 1604 | { 1605 | "cell_type": "code", 1606 | "source": [ 1607 | "y = y_residual + y\n", 1608 | "print(y.size())" 1609 | ], 1610 | "metadata": { 1611 | "colab": { 1612 | "base_uri": 
"https://localhost:8080/" 1613 | }, 1614 | "id": "Dih1yg_kDDg5", 1615 | "outputId": "5a1cd47c-256e-4906-b182-30b2f91730c5" 1616 | }, 1617 | "execution_count": null, 1618 | "outputs": [ 1619 | { 1620 | "output_type": "stream", 1621 | "name": "stdout", 1622 | "text": [ 1623 | "torch.Size([1, 3, 512])\n" 1624 | ] 1625 | } 1626 | ] 1627 | }, 1628 | { 1629 | "cell_type": "markdown", 1630 | "source": [ 1631 | "### 2.2.3 Feed Forward Sub-Layer" 1632 | ], 1633 | "metadata": { 1634 | "id": "6FRfeSRBGbgy" 1635 | } 1636 | }, 1637 | { 1638 | "cell_type": "markdown", 1639 | "source": [ 1640 | "#### 2.2.3.1 Set aside residual" 1641 | ], 1642 | "metadata": { 1643 | "id": "Af-n0PTFDQ9U" 1644 | } 1645 | }, 1646 | { 1647 | "cell_type": "code", 1648 | "source": [ 1649 | "y_residual = y.clone()\n", 1650 | "print(y.size())" 1651 | ], 1652 | "metadata": { 1653 | "colab": { 1654 | "base_uri": "https://localhost:8080/" 1655 | }, 1656 | "id": "Bno9hwYxDSag", 1657 | "outputId": "4386075e-dd95-430e-bd5e-42742ac016f2" 1658 | }, 1659 | "execution_count": null, 1660 | "outputs": [ 1661 | { 1662 | "output_type": "stream", 1663 | "name": "stdout", 1664 | "text": [ 1665 | "torch.Size([1, 3, 512])\n" 1666 | ] 1667 | } 1668 | ] 1669 | }, 1670 | { 1671 | "cell_type": "markdown", 1672 | "source": [ 1673 | "#### 2.2.3.2 Pre-Feed-Forward Layer Normalization" 1674 | ], 1675 | "metadata": { 1676 | "id": "xcz5yTDsDX8O" 1677 | } 1678 | }, 1679 | { 1680 | "cell_type": "code", 1681 | "source": [ 1682 | "mean = y.mean(-1, keepdim=True)\n", 1683 | "std = y.std(-1, keepdim=True)\n", 1684 | "W3_d = nn.Parameter(torch.ones(d_embed))\n", 1685 | "b3_d = nn.Parameter(torch.zeros(d_embed))\n", 1686 | "y = W3_d * (y - mean) / (std + epsilon) + b3_d\n", 1687 | "print(y.size())" 1688 | ], 1689 | "metadata": { 1690 | "colab": { 1691 | "base_uri": "https://localhost:8080/" 1692 | }, 1693 | "id": "Ncx0PcG-DYt0", 1694 | "outputId": "e7c7a1c3-8329-485f-8d26-2a8317b7c9c8" 1695 | }, 1696 | "execution_count": null, 1697 | 
"outputs": [ 1698 | { 1699 | "output_type": "stream", 1700 | "name": "stdout", 1701 | "text": [ 1702 | "torch.Size([1, 3, 512])\n" 1703 | ] 1704 | } 1705 | ] 1706 | }, 1707 | { 1708 | "cell_type": "markdown", 1709 | "source": [ 1710 | "#### 2.2.3.3 Feed Forward" 1711 | ], 1712 | "metadata": { 1713 | "id": "or2ZZdXfDegg" 1714 | } 1715 | }, 1716 | { 1717 | "cell_type": "code", 1718 | "source": [ 1719 | "linear_expand_d = nn.Linear(d_embed, d_ff)\n", 1720 | "linear_compress_d = nn.Linear(d_ff, d_embed)\n", 1721 | "y = linear_compress_d(F.relu(linear_expand_d(y)))\n", 1722 | "print(y.size())" 1723 | ], 1724 | "metadata": { 1725 | "colab": { 1726 | "base_uri": "https://localhost:8080/" 1727 | }, 1728 | "id": "YHrTcy-xDgaq", 1729 | "outputId": "ecc8e139-1fc4-4478-f484-1355cf3c9895" 1730 | }, 1731 | "execution_count": null, 1732 | "outputs": [ 1733 | { 1734 | "output_type": "stream", 1735 | "name": "stdout", 1736 | "text": [ 1737 | "torch.Size([1, 3, 512])\n" 1738 | ] 1739 | } 1740 | ] 1741 | }, 1742 | { 1743 | "cell_type": "markdown", 1744 | "source": [ 1745 | "#### 2.2.3.4 Add residual back in" 1746 | ], 1747 | "metadata": { 1748 | "id": "41h94ULPDn52" 1749 | } 1750 | }, 1751 | { 1752 | "cell_type": "code", 1753 | "source": [ 1754 | "y = y_residual + y\n", 1755 | "print(y.size())" 1756 | ], 1757 | "metadata": { 1758 | "colab": { 1759 | "base_uri": "https://localhost:8080/" 1760 | }, 1761 | "id": "iJ7CK78HDpQv", 1762 | "outputId": "bf122ca5-96b4-4cb3-fdac-0d261a268366" 1763 | }, 1764 | "execution_count": null, 1765 | "outputs": [ 1766 | { 1767 | "output_type": "stream", 1768 | "name": "stdout", 1769 | "text": [ 1770 | "torch.Size([1, 3, 512])\n" 1771 | ] 1772 | } 1773 | ] 1774 | }, 1775 | { 1776 | "cell_type": "markdown", 1777 | "source": [ 1778 | "## 2.3 Final Decoder Layer Normalization" 1779 | ], 1780 | "metadata": { 1781 | "id": "lyxSbVuxDseS" 1782 | } 1783 | }, 1784 | { 1785 | "cell_type": "code", 1786 | "source": [ 1787 | "mean = y.mean(-1, keepdim=True)\n", 1788 | 
"std = y.std(-1, keepdim=True)\n", 1789 | "Wn_d = nn.Parameter(torch.ones(d_embed))\n", 1790 | "bn_d = nn.Parameter(torch.zeros(d_embed))\n", 1791 | "y = Wn_d * (y - mean) / (std + epsilon) + bn_d\n", 1792 | "print(y.size())" 1793 | ], 1794 | "metadata": { 1795 | "colab": { 1796 | "base_uri": "https://localhost:8080/" 1797 | }, 1798 | "id": "ld4nj8XJD0DO", 1799 | "outputId": "601a0cc4-e84f-4335-a045-7d83f70a4713" 1800 | }, 1801 | "execution_count": null, 1802 | "outputs": [ 1803 | { 1804 | "output_type": "stream", 1805 | "name": "stdout", 1806 | "text": [ 1807 | "torch.Size([1, 3, 512])\n" 1808 | ] 1809 | } 1810 | ] 1811 | }, 1812 | { 1813 | "cell_type": "markdown", 1814 | "source": [ 1815 | "# 3. Generate Probability Distribution" 1816 | ], 1817 | "metadata": { 1818 | "id": "HMYkSaeIEL33" 1819 | } 1820 | }, 1821 | { 1822 | "cell_type": "markdown", 1823 | "source": [ 1824 | "This next module sits on top of the decoder and expands the decoder output into a log probability distribution over the vocabulary for each token position. This is done for all tokens, though the only ones that will matter for loss computation are the ones that are masked. The loss calculation is not done here." 
1825 | ], 1826 | "metadata": { 1827 | "id": "UDYHcYJiHZWI" 1828 | } 1829 | }, 1830 | { 1831 | "cell_type": "code", 1832 | "source": [ 1833 | "linear_scores = nn.Linear(d_embed, vocab)\n", 1834 | "probs = F.log_softmax(linear_scores(y), dim=-1)\n", 1835 | "print(probs.size())" 1836 | ], 1837 | "metadata": { 1838 | "colab": { 1839 | "base_uri": "https://localhost:8080/" 1840 | }, 1841 | "id": "BvdmMdREEOtF", 1842 | "outputId": "9e8e0c72-84a9-4264-de3a-cb7b73c91048" 1843 | }, 1844 | "execution_count": null, 1845 | "outputs": [ 1846 | { 1847 | "output_type": "stream", 1848 | "name": "stdout", 1849 | "text": [ 1850 | "torch.Size([1, 3, 50000])\n" 1851 | ] 1852 | } 1853 | ] 1854 | }, 1855 | { 1856 | "cell_type": "markdown", 1857 | "source": [ 1858 | "# Loss and Training" 1859 | ], 1860 | "metadata": { 1861 | "id": "OBOyeEY76IKk" 1862 | } 1863 | }, 1864 | { 1865 | "cell_type": "markdown", 1866 | "source": [ 1867 | "This notebook does not go through loss computation and training at this time. Loss is computed by taking the predicted log probabilities at the masked positions and measuring their KL divergence from the actual target tokens. The code above will support a single call to `.backward()`, but because it recreates the linear layers etc. from scratch each time it runs, it won't really learn anything. See [The Annotated Transformer](https://nlp.seas.harvard.edu/2018/04/03/attention.html) for a version that is closer to a real implementation, as well as a more in-depth description of the loss computation and training loop." 1868 | ], 1869 | "metadata": { 1870 | "id": "louGBtN_6MVf" 1871 | } 1872 | }, 1873 | { 1874 | "cell_type": "markdown", 1875 | "source": [ 1876 | "# Etc." 1877 | ], 1878 | "metadata": { 1879 | "id": "2VD-nZuHyOJE" 1880 | } 1881 | }, 1882 | { 1883 | "cell_type": "markdown", 1884 | "source": [ 1885 | "Careful observers will notice that I have left out a few details, such as Dropout layers, which appear in various places in the actual implementation.
These details improve learning but do not significantly alter our understanding of how Transformers work." 1886 | ], 1887 | "metadata": { 1888 | "id": "-N_xmdHnyQkh" 1889 | } 1890 | } 1891 | ] 1892 | } 1893 | --------------------------------------------------------------------------------
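For reference, scaled dot-product attention from Vaswani et al. (2017) can be written as a standalone function. It multiplies queries by transposed keys, so each row of the score matrix belongs to one query position, and it fills masked positions with a large negative number before the softmax so they receive essentially zero weight. This is a minimal sketch: the function name, the toy shapes, and the padding-style source mask are assumptions for illustration. It uses a 3-token decoder sequence attending over a 5-token encoder output, so the shapes distinguish queries from keys.

```python
import math

import torch
import torch.nn.functional as F


def scaled_dot_product_attention(q, k, v, mask=None):
    """softmax(Q K^T / sqrt(d_k)) V  (hypothetical helper name).

    q:    (batch, heads, len_q, d_k)  queries (decoder side in source attention)
    k:    (batch, heads, len_k, d_k)  keys (encoder side in source attention)
    v:    (batch, heads, len_k, d_k)  values, aligned with the keys
    mask: broadcastable to (batch, heads, len_q, len_k); 0 marks a blocked key
    """
    d_k = q.size(-1)
    # One row per query position, one column per key position.
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        # A large negative score gives a masked key ~0 weight after softmax.
        scores = scores.masked_fill(mask == 0, -1e9)
    attn = F.softmax(scores, dim=-1)  # each row sums to 1 over the keys
    return torch.matmul(attn, v), attn


# Assumed toy shapes: 1 batch, 8 heads, 64 dims per head.
torch.manual_seed(0)
q = torch.randn(1, 8, 3, 64)  # 3 decoder positions
k = torch.randn(1, 8, 5, 64)  # 5 encoder positions
v = torch.randn(1, 8, 5, 64)
# Hypothetical source mask: hide the last two encoder positions (e.g. padding).
src_mask = torch.tensor([1, 1, 1, 0, 0]).view(1, 1, 1, 5)

out, attn = scaled_dot_product_attention(q, k, v, src_mask)
print(out.shape)   # torch.Size([1, 8, 3, 64])
print(attn.shape)  # torch.Size([1, 8, 3, 5])
```

With unequal source and target lengths, swapping the roles of `q` and `k` in the first `matmul` would produce a (1, 8, 5, 3) score matrix that can no longer be multiplied by `v`, so the mistake surfaces as a shape error instead of silently computing the wrong weights.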
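Every sub-layer in the walkthrough repeats the same four steps: set aside the residual, apply layer normalization, run the sub-layer, and add the residual back in. A minimal sketch of that pattern as reusable PyTorch modules follows. The class names are hypothetical, `eps=1e-6` stands in for the notebook's `epsilon`, and `d_ff = 2048` is an assumed value; the layer norm is written the same way as the notebook cells (unbiased `std`, `eps` added to the standard deviation), which differs slightly from `nn.LayerNorm`.

```python
import torch
import torch.nn as nn


class NotebookLayerNorm(nn.Module):
    """Layer norm in the same form as the notebook cells:
    W * (x - mean) / (std + eps) + b, over the last dimension."""

    def __init__(self, d_embed, eps=1e-6):  # eps value is an assumption
        super().__init__()
        self.W = nn.Parameter(torch.ones(d_embed))
        self.b = nn.Parameter(torch.zeros(d_embed))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.W * (x - mean) / (std + self.eps) + self.b


class PreNormResidual(nn.Module):
    """One sub-layer: set aside the residual, normalize,
    apply the sub-layer, then add the residual back in."""

    def __init__(self, d_embed, sublayer):
        super().__init__()
        self.norm = NotebookLayerNorm(d_embed)
        self.sublayer = sublayer

    def forward(self, x):
        return x + self.sublayer(self.norm(x))


# d_embed = 512 matches the printed shapes; d_ff = 2048 is hypothetical.
d_embed, d_ff = 512, 2048
feed_forward = nn.Sequential(
    nn.Linear(d_embed, d_ff),  # expand
    nn.ReLU(),
    nn.Linear(d_ff, d_embed),  # compress
)
block = PreNormResidual(d_embed, feed_forward)

torch.manual_seed(0)
x = torch.randn(1, 3, d_embed)  # (batch, seq_len, d_embed)
out = block(x)
print(out.shape)  # torch.Size([1, 3, 512])
```

Because the parameters are created once in `__init__` rather than inside each cell, calling `block` repeatedly reuses the same weights, which is what a training loop needs. The self-attention and source-attention sub-layers could be wrapped the same way.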
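The loss described under "Loss and Training" can be sketched on toy tensors. This assumes a hypothetical boolean tensor `is_masked` marking which positions were masked, hypothetical target ids, and a vocabulary of 10 instead of the notebook's 50,000. Against a one-hot target, the KL divergence reduces to the negative log-likelihood of the correct token (the entropy of a one-hot distribution is zero), which the sketch checks numerically.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab = 10  # toy vocabulary size (assumption)

# Stand-in for the notebook's `probs`: log probabilities, (batch, seq_len, vocab).
log_probs = F.log_softmax(torch.randn(1, 3, vocab), dim=-1)
target = torch.tensor([[4, 7, 2]])                # hypothetical target token ids
is_masked = torch.tensor([[False, True, False]])  # hypothetical: position 1 was masked

# Keep only the masked positions: (num_masked, vocab) and (num_masked,).
masked_log_probs = log_probs[is_masked]
masked_target = target[is_masked]

# Negative log-likelihood of the correct token, averaged over masked positions.
nll = F.nll_loss(masked_log_probs, masked_target)

# KL divergence from a one-hot target distribution, averaged the same way.
one_hot = F.one_hot(masked_target, num_classes=vocab).float()
kl = F.kl_div(masked_log_probs, one_hot, reduction="batchmean")

print(nll.item(), kl.item())
```

If the target distribution is smoothed rather than one-hot (The Annotated Transformer uses label smoothing), the two quantities no longer coincide and the KL form is the one to use.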