├── README.md
├── decoder_transformers_with_pytorch_and_lightning_v2.ipynb
└── images
├── Brandmark_FullColor_Black.png
├── attention_compute_k.png
├── attention_compute_q.png
├── attention_compute_v.png
├── attention_equations.png
├── attention_final_scores.png
├── attention_final_scores_masked.png
├── attention_masking.png
├── attention_q_times_kt.png
├── attention_scaling_scores.png
├── attention_softmax.png
├── attention_softmax_masked.png
├── dec_transformer.png
├── decoder_diagram.png
├── enc_dec_attention_1.png
├── enc_dec_transformer.png
├── encoder_diagram.png
├── expected_input_output_1.png
├── expected_input_output_2.png
├── masked_attention_1.png
├── pos_encoding_1.png
├── pos_encoding_2.png
├── self_attention_1.png
├── self_attention_2.png
└── squatch_eats_pizza.png
/README.md:
--------------------------------------------------------------------------------
1 | This is code from the book **The StatQuest Illustrated Guide to Neural Networks and AI** by Josh Starmer (me!).
2 |
3 | You can look at it here, but it's more fun to play with it, which you can do in two easy ways:
4 |
5 | - Google colab:
6 |
7 |
8 |
9 | - Lightning Studio:
10 |
11 |
12 |
--------------------------------------------------------------------------------
/decoder_transformers_with_pytorch_and_lightning_v2.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "id": "9f40c232-df6e-49df-9016-6459e4af2e1e",
6 | "metadata": {
7 | "tags": []
8 | },
9 | "source": [
10 | "# StatQuest: Coding Transformers from Scratch!!!\n",
11 | "## Part 1: Decoder-Only Transformers\n",
12 | "\n",
13 | "Copyright 2024, Joshua Starmer"
14 | ]
15 | },
16 | {
17 | "cell_type": "markdown",
18 | "id": "f4226d63-8d76-40bc-a8e6-0f290a159418",
19 | "metadata": {},
20 | "source": [
21 | "---- \n",
22 | "\n",
23 | "In this tutorial, we will use **[PyTorch](https://pytorch.org/) + [Lightning](https://www.lightning.ai/)** to create and optimize a **Decoder-Only Transformer**, like the one shown in the picture below. Decoder-Only Transformers are taking over AI right now, and quite possibly their most famous use is in ChatGPT.\n",
24 | "\n",
25 | "<img src=\"./images/dec_transformer.png\">\n",
26 | "\n",
27 | "Although Decoder-Only Transformers look complicated and can do really cool things, the good news is that they don't actually require a lot of code. A lot of their power comes from simply making multiple copies of each component. So, with that said...\n",
28 | "\n",
29 | "In this tutorial, you will...\n",
30 | "\n",
31 | "- **[Code a Position Encoder Class From Scratch!!!](#position)** The position encoder gives a transformer a way to keep track of the order of the input tokens.\n",
32 | "\n",
33 | "- **[Code an Attention Class From Scratch!!!](#attention)** The attention class allows the transformer to keep track of the relationships among words in the input and the output.\n",
34 | "\n",
35 | "- **[Code a Decoder-Only Transformer Class From Scratch!!!](#decoder)** The Decoder-Only Transformer will combine the position encoder and attention classes that we wrote with built-in pytorch classes to process the user input and generate the output.\n",
36 | "\n",
37 | "- **[Train the Transformer!!!](#train)** We'll train the transformer to answer simple questions.\n",
38 | "\n",
39 | "- **[Use the Trained Transformer!!!](#use)** Finally, we'll use the transformer to answer simple questions.\n",
40 | "\n",
41 | "#### NOTE:\n",
42 | "This tutorial assumes that you already know the basics of coding in **Python** and are familiar with the theory behind **[Decoder-Only Transformers](https://youtu.be/bQ5BoolX9Ag)** and **[Backpropagation](https://youtu.be/IN2XmBhILt4)**. It also assumes that you are familiar with the **[Essential Matrix Algebra for Neural Networks](https://youtu.be/ZTt9gsGcdDo)** and how it applies to **[Transformers](https://youtu.be/KphmOJnLAdI)**. If not, check out the **StatQuests** by clicking on the links for each topic.\n",
43 | "\n",
44 | "#### ALSO NOTE:\n",
45 | "I strongly encourage you to play around with the code. Playing with the code is the best way to learn from it."
46 | ]
47 | },
48 | {
49 | "cell_type": "markdown",
50 | "id": "855d3a52-a43e-44cf-b8b6-b2ce88bea382",
51 | "metadata": {},
52 | "source": [
53 | "----"
54 | ]
55 | },
56 | {
57 | "cell_type": "markdown",
58 | "id": "63b86036-1369-441c-a14d-3ba7bdaa103b",
59 | "metadata": {
60 | "tags": []
61 | },
62 | "source": [
63 | "# Import the modules that will do all the work\n",
64 | "\n",
65 | "The very first thing we need to do is load a bunch of Python modules. Python itself is just a basic programming language. These modules give us extra functionality to create and train a Neural Network.\n",
66 | "\n",
67 | "**NOTE:** The code below will check and see if **Lightning** is installed, and if not, it will install it for you. However, if you also need to install PyTorch, check out their install page **[here.](https://pytorch.org/get-started/locally/)**"
68 | ]
69 | },
70 | {
71 | "cell_type": "code",
72 | "execution_count": null,
73 | "id": "5c520e0b-c6e4-43ce-93f5-c0f2b5e75438",
74 | "metadata": {},
75 | "outputs": [],
76 | "source": [
77 | "## First, check to see if lightning is installed, if not, install it.\n",
78 | "import pip\n",
79 | "try:\n",
80 | " __import__(\"lightning\")\n",
81 | "except ImportError:\n",
82 | " pip.main(['install', \"lightning\"]) \n",
83 | "\n",
84 | "import torch ## torch lets us create tensors and also provides helper functions\n",
85 | "import torch.nn as nn ## torch.nn gives us nn.Module(), nn.Embedding() and nn.Linear()\n",
86 | "import torch.nn.functional as F # This gives us the softmax() and argmax()\n",
87 | "from torch.optim import Adam ## We will use the Adam optimizer, which is, essentially, \n",
88 | " ## a slightly less stochastic version of stochastic gradient descent.\n",
89 | "from torch.utils.data import TensorDataset, DataLoader ## We'll store our data in DataLoaders\n",
90 | "\n",
91 | "import lightning as L ## Lightning makes it easier to write, optimize and scale our code"
92 | ]
93 | },
94 | {
95 | "cell_type": "markdown",
96 | "id": "6e58ebb2-1798-41c3-9ff8-1b4f50605964",
97 | "metadata": {},
98 | "source": [
99 | "----"
100 | ]
101 | },
102 | {
103 | "cell_type": "markdown",
104 | "id": "fa9ddca3-53b6-451b-8fe1-6116e1f7f473",
105 | "metadata": {},
106 | "source": [
107 | "# Create the input and output data"
108 | ]
109 | },
110 | {
111 | "cell_type": "markdown",
112 | "id": "67d17fd2-ac78-4664-93bb-4dd64c448e1e",
113 | "metadata": {},
114 | "source": [
115 | "In this tutorial we will build a simple Decoder-Only Transformer that can answer two super simple questions, **What is StatQuest?** and **StatQuest is what?**, and give them both the same answer, **Awesome!!!**\n",
116 | "\n",
117 | "In order to keep track of our simple dataset,\n",
118 | "we'll create a dictionary that maps the words and tokens to ID numbers. This is because the class we will use to do word embedding for us, `nn.Embedding()`, only accepts ID numbers as input, rather than words or tokens. Then we will use the dictionary to create a **Dataloader** that contains the questions and the desired answers encoded as ID numbers. Ultimately we'll use the **Dataloader** to train the transformer. **NOTE:** Dataloaders are designed to scale to very large datasets, so this simple example should be useful even when you have a terabyte of text.\n",
119 | "\n",
120 | "**ALSO NOTE:** The **inputs** and **labels** for the training data used with a Decoder-Only Transformer can seem a little strange at first. This is because a Decoder-Only Transformer generates a lot of the user input in addition to the response. To get a sense of what this means, let's pretend we want to train our Decoder-Only Transformer to answer the question **What is StatQuest?** with the response **Awesome**. In the figure below, on the left side, we see that the first token in the input **What** generates the output **is**. During training, we can compare this output to the known second token in the input, and if it is different, use that difference to modify the weights and biases in the model. Thus, even though **is** is part of the **input**, it is also part of the **label** that we use to evaluate how well the Decoder-Only Transformer is performing and whether or not the weights and biases should be changed. Likewise, **StatQuest**, **\<EOS\>**, and **awesome** can also be in both the **input** and in the **label** because we know we want the Decoder-Only Transformer to use them as inputs and generate them as outputs."
121 | ]
122 | },
123 | {
124 | "cell_type": "markdown",
125 | "id": "752cda95-7e7e-430a-8c1a-4afc9670379b",
126 | "metadata": {},
127 | "source": [
128 | "<img src=\"./images/expected_input_output_1.png\">"
129 | ]
130 | },
131 | {
132 | "cell_type": "code",
133 | "execution_count": null,
134 | "id": "63631347-8db6-4812-8198-9169b2df0b24",
135 | "metadata": {},
136 | "outputs": [],
137 | "source": [
138 | "## first, we create a dictionary that maps vocabulary tokens to id numbers...\n",
139 | "token_to_id = {'what' : 0,\n",
140 | " 'is' : 1,\n",
141 | " 'statquest' : 2,\n",
142 | " 'awesome': 3,\n",
143 |                    '<EOS>' : 4, ## <EOS> = end of sequence\n",
144 | " }\n",
145 | "## ...then we create a dictionary that maps the ids to tokens. This will help us interpret the output.\n",
146 | "## We use the \"map()\" function to apply the \"reversed()\" function to each tuple (i.e. ('what', 0)) stored\n",
147 | "## in the token_to_id dictionary. We then use dict() to make a new dictionary from the\n",
148 | "## reversed tuples.\n",
149 | "id_to_token = dict(map(reversed, token_to_id.items()))\n",
150 | "\n",
151 | "## NOTE: Because we are using a Decoder-Only Transformer, the inputs contain\n",
152 | "## the questions (\"what is statquest?\" and \"statquest is what?\") followed\n",
153 | "##       by an <EOS> token followed by the response, \"awesome\".\n",
154 | "## This is because all of those tokens will be used as inputs to the Decoder-Only\n",
155 | "## Transformer during Training. (See the illustration above for more details) \n",
156 | "## ALSO NOTE: When we train this way, it's called \"teacher forcing\".\n",
157 | "## Teacher forcing helps us train the neural network faster.\n",
158 | "inputs = torch.tensor([[token_to_id[\"what\"], ## input #1: what is statquest <EOS> awesome\n",
159 | " token_to_id[\"is\"], \n",
160 | " token_to_id[\"statquest\"], \n",
161 |                        token_to_id[\"<EOS>\"],\n",
162 | " token_to_id[\"awesome\"]], \n",
163 | " \n",
164 |                        [token_to_id[\"statquest\"], # input #2: statquest is what <EOS> awesome\n",
165 | " token_to_id[\"is\"], \n",
166 | " token_to_id[\"what\"], \n",
167 |                        token_to_id[\"<EOS>\"], \n",
168 | " token_to_id[\"awesome\"]]])\n",
169 | "\n",
170 | "## NOTE: Because we are using a Decoder-Only Transformer, the outputs, or\n",
171 | "##       the predictions, are the input questions (minus the first word) followed by <EOS>\n",
172 | "##       awesome <EOS>. The first <EOS> means we're done processing the input question\n",
173 | "##       and the second <EOS> means we are done generating the output.\n",
174 | "## See the illustration above for more details.\n",
175 | "labels = torch.tensor([[token_to_id[\"is\"], \n",
176 | " token_to_id[\"statquest\"], \n",
177 |                        token_to_id[\"<EOS>\"], \n",
178 | " token_to_id[\"awesome\"], \n",
179 |                        token_to_id[\"<EOS>\"]], \n",
180 | " \n",
181 | " [token_to_id[\"is\"], \n",
182 | " token_to_id[\"what\"], \n",
183 |                        token_to_id[\"<EOS>\"], \n",
184 | " token_to_id[\"awesome\"], \n",
185 |                        token_to_id[\"<EOS>\"]]])\n",
186 | "\n",
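"## NOTE: Just to make the one-token shift between the inputs and the labels explicit,\n",
"## here is a hypothetical way we could have built them from a single, full sequence of tokens\n",
"## (we don't actually run this, it's just for illustration)...\n",
"##\n",
"##    full_sequence = [\"what\", \"is\", \"statquest\", \"<EOS>\", \"awesome\", \"<EOS>\"]\n",
"##    input_tokens = full_sequence[:-1]  ## everything except the last token\n",
"##    label_tokens = full_sequence[1:]   ## everything except the first token\n",
"##\n",
"## ...which gives exactly the input/label pairs for question #1 that we typed out by hand above.\n",
"\n",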
187 | "## Now let's package everything up into a DataLoader...\n",
188 | "dataset = TensorDataset(inputs, labels) \n",
189 | "dataloader = DataLoader(dataset)"
190 | ]
191 | },
192 | {
193 | "cell_type": "markdown",
194 | "id": "3de72068-c081-4868-9e60-1f5cc6cf45df",
195 | "metadata": {},
196 | "source": [
197 | "Now that we have created the input and output datasets and the **Dataloader** to train the model, let's start building it."
198 | ]
199 | },
200 | {
201 | "cell_type": "markdown",
202 | "id": "3cb1335c-7bd0-4c60-80c1-6043ee11accd",
203 | "metadata": {},
204 | "source": [
205 | "----"
206 | ]
207 | },
208 | {
209 | "cell_type": "markdown",
210 | "id": "42efbdf7-b3de-4b12-b2ec-a54651b9df79",
211 | "metadata": {},
212 | "source": [
213 | "<a id='position'></a>\n",
214 | "# Position Encoding\n",
215 | "\n",
216 | "Position Encoding helps the transformer keep track of the order of the words in the input and the output. For example, in the picture below, we see that the two phrases **Squatch eats pizza** and **Pizza eats Squatch** both have the exact same words, but, due to differences in the word order, have very different meanings. Thus, keeping track of word order is very important.\n",
217 | "\n",
218 | "<img src=\"./images/squatch_eats_pizza.png\">\n",
219 | "\n",
220 | "There are a bunch of ways for a transformer to keep track of word order, but one popular method is to use a series of alternating sine and cosine curves (seen below). The number of sine and cosine curves depends on how many numbers, or word embedding values, we use to represent each token. In the context of Transformers, the number of numbers, or word embedding values, we use to represent each token is the **dimension** of the transformer. So, if the transformer's dimension is 2, meaning that it uses 2 numbers to represent each token, then we only need one sine and one cosine to generate two position encoding values. \n",
221 | "\n",
222 | "<img src=\"./images/pos_encoding_1.png\">\n",
223 | "\n",
224 | "In contrast, as we see in the illustration below, if the transformer's dimension is 4, then we'll need 2 sine curves alternating with 2 cosine curves, for a total of 4 curves.\n",
225 | "\n",
226 | "<img src=\"./images/pos_encoding_2.png\">\n",
227 | "\n",
228 | "As we see in the illustration above, the additional pair of sine and cosine curves have a wider period (they repeat less frequently) than the first pair. Increasing the period for each additional pair of curves ensures that each position is represented by a unique combination of values.\n",
229 | "\n",
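"For reference, the position encoding equations from the original **Attention Is All You Need** manuscript, which generate these alternating curves, are...\n",
"\n",
"$$PE_{(pos, 2i)} = \\sin\\left(\\frac{pos}{10000^{2i/d_{model}}}\\right) \\qquad PE_{(pos, 2i+1)} = \\cos\\left(\\frac{pos}{10000^{2i/d_{model}}}\\right)$$\n",
"\n",
"...where $pos$ is a token's position in the input and $i$ indexes the pairs of sine and cosine curves. We'll work through these equations again in the comments in the code below.\n",
"\n",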
230 | "**NOTE:** The reason why we are bothering to create a class to do positional encoding, instead of just adding this code directly to the transformer, is that we can easily re-use it in an Encoder-Only Transformer or an Encoder-Decoder Transformer. So, by creating a class that does positional encoding, we can code it once, and then just create instances when and where we need it.\n",
231 | "\n",
232 | "**ALSO NOTE:** Since the position encoding values never change, meaning that the first token always uses the same position encoding values regardless of what that token is, we precompute them and save them in a lookup table. This makes adding position encoding values super fast.\n",
233 | "\n",
234 | "Now that we understand the ideas that we want to implement in the Position Encoding class, let's code it!"
235 | ]
236 | },
237 | {
238 | "cell_type": "code",
239 | "execution_count": null,
240 | "id": "b8b62789-ee84-49bf-b736-12c1b47b34ae",
241 | "metadata": {},
242 | "outputs": [],
243 | "source": [
244 | "class PositionEncoding(nn.Module):\n",
245 | " \n",
246 | " def __init__(self, d_model=2, max_len=6):\n",
247 | " ## d_model = The dimension of the transformer, which is also the number of embedding values per token.\n",
248 | " ## In the transformer I used in the StatQuest: Transformer Neural Networks Clearly Explained!!!\n",
249 | " ## d_model=2, so that's what we'll use as a default for now.\n",
250 | " ## However, in \"Attention Is All You Need\" d_model=512\n",
251 | " ## max_len = maximum number of tokens we allow as input.\n",
252 | " ## Since we are precomputing the position encoding values and storing them in a lookup table\n",
253 | " ## we can use d_model and max_len to determine the number of rows and columns in that\n",
254 | " ## lookup table.\n",
255 | " ##\n",
256 | " ## In this simple example, we are only using short phrases, so we are using\n",
257 | " ## max_len=6 as the default setting.\n",
258 | " ## However, in The Annotated Transformer, they set the default value for max_len to 5000\n",
259 | " \n",
260 | " super().__init__()\n",
261 | " ## We call the super's init because by creating our own __init__() method, we overwrite the one\n",
262 |         ## we inherited from nn.Module. So we have to explicitly call nn.Module's __init__(), otherwise it\n",
263 | " ## won't get initialized. NOTE: If we didn't write our own __init__(), then we would not have\n",
264 | " ## to call super().__init__(). Alternatively, if we didn't want to access any of nn.Module's methods, \n",
265 | " ## we wouldn't have to call it then either.\n",
266 | "\n",
267 | " ## Now we create a lookup table, pe, of position encoding values and initialize all of them to 0.\n",
268 | " ## To do this, we will make a matrix of 0s that has max_len rows and d_model columns.\n",
269 | " ## for example...\n",
270 | " ## torch.zeros(3, 2)\n",
271 | " ## ...returns a matrix of 0s with 3 rows and 2 columns...\n",
272 | " ## tensor([[0., 0.],\n",
273 | " ## [0., 0.],\n",
274 | " ## [0., 0.]])\n",
275 | " pe = torch.zeros(max_len, d_model)\n",
276 | "\n",
277 | " ## Now we create a sequence of numbers for each position that a token can have in the input (or output).\n",
278 |         ## For example, if the input tokens were \"I'm happy today!\", then \"I'm\" would get the first\n",
279 | " ## position, 0, \"happy\" would get the second position, 1, and \"today!\" would get the third position, 2.\n",
280 | " ## NOTE: Since we are going to be doing math with these position indices to create the \n",
281 | " ## positional encoding for each one, we need them to be floats rather than ints.\n",
282 | " ## \n",
283 | " ## NOTE: Two ways to create floats are...\n",
284 | " ##\n",
285 | " ## torch.arange(start=0, end=3, step=1, dtype=torch.float)\n",
286 | " ##\n",
287 | " ## ...and...\n",
288 | " ##\n",
289 | " ## torch.arange(start=0, end=3, step=1).float()\n",
290 | " ##\n",
291 | " ## ...but the latter is just as clear and requires less typing.\n",
292 | " ##\n",
293 | " ## Lastly, .unsqueeze(1) converts the single list of numbers that torch.arange creates into a matrix with\n",
294 | " ## one row for each index, and all of the indices in a single column. So if \"max_len\" = 3, then we\n",
295 | " ## would create a matrix with 3 rows and 1 column like this...\n",
296 | " ##\n",
297 | " ## torch.arange(start=0, end=3, step=1, dtype=torch.float).unsqueeze(1)\n",
298 | " ##\n",
299 | " ## ...returns...\n",
300 | " ##\n",
301 | " ## tensor([[0.],\n",
302 | " ## [1.],\n",
303 | " ## [2.]]) \n",
304 | " position = torch.arange(start=0, end=max_len, step=1).float().unsqueeze(1)\n",
305 | "\n",
306 | "\n",
307 | " ## Here is where we start doing the math to determine the y-axis coordinates on the\n",
308 | " ## sine and cosine curves.\n",
309 | " ##\n",
310 | " ## The positional encoding equations used in \"Attention is all you need\" are...\n",
311 | " ##\n",
312 | " ## PE(pos, 2i) = sin(pos / 10000^(2i/d_model))\n",
313 | " ## PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))\n",
314 | " ##\n",
315 | " ## ...and we see, within the sin() and cos() functions, we divide \"pos\" by some number that depends\n",
316 | " ## on the index (i) and total number of PE values we want per token (d_model). \n",
317 | " ##\n",
318 | " ## NOTE: When the index, i, is 0 then we are calculating the y-axis coordinates on the **first pair** \n",
319 |         ##       of sine and cosine curves. When i=1, then we are calculating the y-axis coordinates on the \n",
320 | " ## **second pair** of sine and cosine curves. etc. etc.\n",
321 | " ##\n",
322 | " ## Now, pretty much everyone calculates the term we use to divide \"pos\" by first, and they do it with\n",
323 | " ## code that looks like this...\n",
324 | " ##\n",
325 | " ## div_term = torch.exp(torch.arange(start=0, end=d_model, step=2).float() * -(math.log(10000.0) / d_model))\n",
326 | " ##\n",
327 | " ## Now, at least to me, it's not obvious that div_term = 1/(10000^(2i/d_model)) for a few reasons:\n",
328 | " ##\n",
329 | " ## 1) div_term wraps everything in a call to torch.exp() \n",
330 | " ## 2) It uses log()\n",
331 |         ##    3) The order of the terms is different \n",
332 | " ##\n",
333 | " ## The reason for these differences is, presumably, trying to prevent underflow (getting too close to 0).\n",
334 | " ## So, to show that div_term = 1/(10000^(2i/d_model))...\n",
335 | " ##\n",
336 | " ## 1) Swap out math.log() for torch.log() (doing this requires converting 10000.0 to a tensor, which is my\n",
337 | " ## guess for why they used math.log() instead of torch.log())...\n",
338 | " ##\n",
339 | " ## torch.exp(torch.arange(start=0, end=d_model, step=2).float() * -(torch.log(torch.tensor(10000.0)) / d_model))\n",
340 | " ##\n",
341 | " ## 2) Rearrange the terms...\n",
342 | " ##\n",
343 | " ## torch.exp(-1 * (torch.log(torch.tensor(10000.0)) * torch.arange(start=0, end=d_model, step=2).float() / d_model))\n",
344 | " ##\n",
345 | " ## 3) Pull out the -1 with exp(-1 * x) = 1/exp(x)\n",
346 | " ##\n",
347 | " ## 1/torch.exp(torch.log(torch.tensor(10000.0)) * torch.arange(start=0, end=d_model, step=2).float() / d_model)\n",
348 | " ##\n",
349 | " ## 4) Use exp(a * b) = exp(a)^b to pull out the 2i/d_model term...\n",
350 | " ##\n",
351 | " ## 1/torch.exp(torch.log(torch.tensor(10000.0)))^(torch.arange(start=0, end=d_model, step=2).float() / d_model)\n",
352 | " ##\n",
353 | " ## 5) Use exp(log(x)) = x to get the original form of the denominator...\n",
354 | " ##\n",
355 | " ## 1/(torch.tensor(10000.0)^(torch.arange(start=0, end=d_model, step=2).float() / d_model))\n",
356 | " ##\n",
357 | " ## 6) Bam.\n",
358 | " ## \n",
359 | " ## So, that being said, I don't think underflow is actually that big an issue. In fact, some coder at Hugging Face\n",
360 | " ## also doesn't think so, and their code for positional encoding in DistilBERT (a streamlined version of BERT, which\n",
361 | " ## is a transformer model)\n",
362 | " ## calculates the values directly - using the form of the equation found in original Attention is all you need\n",
363 | " ## manuscript. See...\n",
364 | " ## https://github.com/huggingface/transformers/blob/455c6390938a5c737fa63e78396cedae41e4e87e/src/transformers/modeling_distilbert.py#L53\n",
365 | " ## So I think we can simplify the code, but I'm also writing all these comments to show that it is equivalent to what\n",
366 | " ## you'll see in the wild...\n",
367 | " ##\n",
368 | " ## Now let's create an index for the embedding positions to simplify the code a little more...\n",
369 | " embedding_index = torch.arange(start=0, end=d_model, step=2).float()\n",
370 |         ## NOTE: Setting step=2 results in the same sequence of numbers that we would get if we multiplied i by 2.\n",
371 | " ## So we can save ourselves a little math by just setting step=2.\n",
372 | "\n",
373 | " ## And now, finally, let's create div_term...\n",
374 | " div_term = 1/torch.tensor(10000.0)**(embedding_index / d_model)\n",
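        ## For example (just to sanity check the math), if d_model=2, then embedding_index=[0.]\n",
        ## and div_term=[1.]; if d_model=4, then embedding_index=[0., 2.] and div_term=[1., 0.01].\n",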
375 | " \n",
376 | " ## Now we calculate the actual positional encoding values. Remember 'pe' was initialized as a matrix of 0s\n",
377 | " ## with max_len (max number of input tokens) rows and d_model (number of embedding values per token) columns.\n",
378 | " pe[:, 0::2] = torch.sin(position * div_term) ## every other column, starting with the 1st, has sin() values\n",
379 | " pe[:, 1::2] = torch.cos(position * div_term) ## every other column, starting with the 2nd, has cos() values\n",
380 | " ## NOTE: If the notation for indexing 'pe[]' looks cryptic to you, read on...\n",
381 | " ##\n",
382 | " ## First, let's look at the general indexing notation:\n",
383 | " ##\n",
384 |         ## For each row or column in a matrix, we can select elements in that\n",
385 |         ##       row or column with the following indexes...\n",
386 | " ##\n",
387 | " ## i:j:k = select elements between i and j with stepsize = k.\n",
388 | " ##\n",
389 | " ## ...where...\n",
390 | " ##\n",
391 | " ## i defaults to 0\n",
392 | " ## j defaults to the number of elements in the row, column or whatever.\n",
393 | " ## k defaults to 1\n",
394 | " ##\n",
395 | " ## Now that we have looked at the general notation, let's look at specific\n",
396 | " ## examples so that we can understand it.\n",
397 | " ##\n",
398 | " ## We'll start with: pe[:, 0::2]\n",
399 | " ##\n",
400 | " ## The stuff that comes before the comma (in this case ':') refers to the rows we want to select.\n",
401 | " ## The ':' before the comma means \"select all rows\" because we are not providing specific \n",
402 | " ## values for i, j and k and, instead, just using the default values.\n",
403 | " ##\n",
404 | " ## The stuff after the comma refers to the columns we want to select.\n",
405 | " ## In this case, we have '0::2', and that means we start with\n",
406 | " ## the first column (column = 0) and go to the end (using the default value for j)\n",
407 | " ## and we set the stepsize to 2, which means we skip every other column.\n",
408 | " ##\n",
409 | " ## Now to understand pe[:, 1::2]\n",
410 | " ##\n",
411 | " ## Again, the stuff before the comma refers to the rows, and, just like before\n",
412 | " ## we use default values for i,j and k, so we select all rows.\n",
413 | " ##\n",
414 | " ## The stuff that comes after the comma refers to the columns.\n",
415 | " ## In this case, we start with the 2nd column (column = 1), and go to the end\n",
416 | " ## (using the default value for 'j') and we set the stepsize to 2, which\n",
417 | " ## means we skip every other column.\n",
418 | " ##\n",
419 | " ## NOTE: using this ':' based notation is called \"indexing\" and also called \"slicing\"\n",
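        ##\n",
        ## Putting it all together: with the default d_model=2 and max_len=6, div_term=[1.], so the\n",
        ## first few rows of 'pe' end up (approximately) being...\n",
        ##\n",
        ##    position 0: [sin(0), cos(0)] = [0.00,  1.00]\n",
        ##    position 1: [sin(1), cos(1)] = [0.84,  0.54]\n",
        ##    position 2: [sin(2), cos(2)] = [0.91, -0.42]\n",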
420 | " \n",
421 |         ## Now we \"register\" 'pe'.\n",
422 | " self.register_buffer('pe', pe) ## \"register_buffer()\" ensures that\n",
423 | " ## 'pe' will be moved to wherever the model gets\n",
424 | " ## moved to. So if the model is moved to a GPU, then,\n",
425 | " ## even though we don't need to optimize 'pe', it will \n",
426 | " ## also be moved to that GPU. This, in turn, means\n",
427 |                                        ## that accessing 'pe' will be relatively fast compared\n",
428 | " ## to having a GPU have to get the data from a CPU.\n",
429 | "\n",
430 | " ## Because this class, PositionEncoding, inherits from nn.Module, the forward() method \n",
431 | " ## is called by default when we use a PositionEncoding() object.\n",
432 | " ## In other words, after we create a PositionEncoding() object, pe = PositionEncoding(),\n",
433 | " ## then pe(word_embeddings) will call forward() and so this is where \n",
434 | " ## we will add the position encoding values to the word embedding values\n",
435 | " def forward(self, word_embeddings):\n",
436 | " \n",
437 | " return word_embeddings + self.pe[:word_embeddings.size(0), :] ## word_embeddings.size(0) = number of embeddings\n",
438 | " ## NOTE: That second ':' is optional and \n",
439 | " ## we could re-write it like this: \n",
440 | " ## self.pe[:word_embeddings.size(0)]"
441 | ]
442 | },
443 | {
444 | "cell_type": "markdown",
445 | "id": "f4fe202d-6995-4322-ad5c-119ef7dc4b28",
446 | "metadata": {},
447 | "source": [
448 | "----"
449 | ]
450 | },
451 | {
452 | "cell_type": "markdown",
453 | "id": "c20eccdc-3188-4c6d-94da-d28db057dc24",
454 | "metadata": {},
455 | "source": [
456 | "<a id='attention'></a>\n",
457 | "# Attention\n",
458 | "We're going to code an `Attention` class to do all of the types of attention that a transformer might need: **Self-Attention**, **Masked Self-Attention** (which is used by the Decoder during training), and **Encoder-Decoder Attention**.\n",
459 | "\n",
460 | "**Self-Attention** is a type of attention used in Encoder-Decoder and Encoder-Only transformers. It allows every word in a phrase to define a relationship with any other word in the phrase, regardless of the order of the words. In other words, if the phrase is **The pizza came out of the oven and it tasted good!**, then the word **it** can define its relationship with every word in that phrase, including words that came after it, like **tasted** and **good**, as illustrated by the blue arrows in the figure below.\n",
461 | "\n",
462 | "<img src=\"./images/self_attention_1.png\">\n",
463 | "\n",
464 | "**Masked Self-Attention** is used by Encoder-Decoder and Decoder-Only transformers and it allows each word in a phrase to define a relationship with itself and the words that came before it. In other words, **Masked Self-Attention** prevents the transformer from \"looking ahead\". This is illustrated below where the word **it** can define relationships with itself and everything that came earlier in the input. In Encoder-Decoder transformers, **Masked Self-Attention** is used during training, when we know what the output should be, but we still force the decoder to generate it one token at a time, thus, limiting attention to only output words that came earlier. In contrast, Decoder-Only transformers use **Masked Self-Attention** all the time, on the input and the output, during training and inference. Thus, even though the Decoder-Only transformer can see all of the input during training and inference, it still only allows the attention values for each word to depend on words that came before it.\n",
465 | "\n",
466 | "<img src=\"./images/masked_attention_1.png\">\n",
467 | "\n",
468 | "**Encoder-Decoder Attention** is only used in Encoder-Decoder transformers, where there is a distinct separation of the part of the transformer that processes the input (the encoder) from the part that generates the output (the decoder). **Encoder-Decoder Attention** lets each word in the output (in the decoder) define relationships with all the words in the input (in the encoder), as illustrated in the figure below.\n",
469 | "\n",
470 | "<img src=\"./images/enc_dec_attention_1.png\">\n",
471 | "\n",
472 | "Now that we have a general sense of the three types of attention used in transformers, we can talk about how it's calculated. "
473 | ]
474 | },
475 | {
476 | "cell_type": "markdown",
477 | "id": "749d3c97-616d-4af1-a463-960fc36fdcf6",
478 | "metadata": {},
479 | "source": [
480 | "First, the general equations for the different types of attention are almost identical as seen in the figure below. In the equations, **Q** is for the **Query** matrix, **K** is for the **Key** matrix and **V** is for the **Value** matrix. On the left, we have the equation for Self-Attention and Encoder-Decoder Attention. The differences in these types of attention are not from the equation we use, but from how the **Q**, **K**, and **V** matrices are computed. On the right, we see the equation for Masked Self-Attention and the only difference it has from the equation on the left is the addition of a **Mask** matrix, **M**, that prevents words that come after a specific **Query** from being included in the final attention scores. \n",
481 | "\n",
482 | "<img src=\"./images/attention_equations.png\">\n",
483 | "\n",
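"In plain math notation, the equation on the left is...\n",
"\n",
"$$\\text{Attention}(Q, K, V) = \\text{SoftMax}\\left(\\frac{QK^T}{\\sqrt{d_k}}\\right)V$$\n",
"\n",
"...where $d_k$ is the number of columns in the **Key** matrix (for us, that's just `d_model`), and the masked equation on the right simply adds the mask matrix $M$ to $\\frac{QK^T}{\\sqrt{d_k}}$ before the SoftMax is applied.\n",
"\n",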
484 | "**NOTE:** Since both equations are very similar, we'll go through one example and point out the key differences when we get to them.\n",
485 | "\n",
486 | "First, given word embedding values for each word/token in the input phrase **\<SOS\> let's go** in matrix form, we multiply them by matrices of weights to create **Queries**, **Keys**, and **Values**. \n",
487 | "\n",
488 | "<img src=\"./images/attention_compute_q.png\">\n",
489 | "\n",
490 | "<img src=\"./images/attention_compute_k.png\">\n",
491 | "\n",
492 | "<img src=\"./images/attention_compute_v.png\">\n",
493 | "\n",
494 | "We then multiply the **Queries** by the transpose of the **Keys** so that the query for each word calculates a similarity value with the keys for all of the words. **NOTE:** As seen in the illustration below, Masked Self-Attention calculates the values for all **Query/Key** pairs, but, ultimately, ignores values for when a token's **Query** comes before other tokens' **Keys**. For example, if the **Query** is for the first token **\<SOS\>**, then Masked Self-Attention will ignore the values calculated with **Keys** for **Let's** and **go** because those tokens come after **\<SOS\>**.\n",
495 | "\n",
496 | "<img src=\"./images/attention_q_times_kt.png\">\n",
497 | "\n",
498 | "The next step is to scale the similarity scores by the square root of the number of columns in the **Key** matrix, which represents the number of values used to represent each token. In this case, we scale by the square root of 2.\n",
499 | "\n",
500 | "<img src=\"./images/attention_scaling_scores.png\">\n",
501 | "\n",
502 | "Now, if we were doing Masked Self-Attention, we would mask out the values we want to ignore by adding -infinity to them, as seen below. This step is the only difference between Self-Attention and Masked Self-Attention. \n",
503 | "\n",
504 | "<img src=\"./images/attention_masking.png\">\n",
505 | "\n",
506 | "The next step is to apply the **SoftMax()** function to each row in the scaled similarities. We'll do this first for the Self-Attention without a mask (below)...\n",
507 | "\n",
508 | "<img src=\"./images/attention_softmax.png\">\n",
509 | "\n",
510 | "...and we'll also do it for Masked Self-Attention (below).\n",
511 | "\n",
512 | "<img src=\"./images/attention_softmax_masked.png\">\n",
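"\n",
"As a quick sanity check (with made-up similarity values, not ones from our actual model), you can verify that adding a huge negative number before `softmax()` leaves the masked positions with essentially 0% of the attention...\n",
"\n",
"```python\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"scaled_sims = torch.tensor([[1.0, -1e9, -1e9],   ## row 1: only the 1st token is visible\n",
"                            [0.5,  2.0, -1e9],   ## row 2: the 1st and 2nd tokens are visible\n",
"                            [0.3,  1.2,  0.7]])  ## row 3: all 3 tokens are visible\n",
"print(F.softmax(scaled_sims, dim=1)) ## the masked entries all come out as (essentially) 0\n",
"```\n",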
513 | "\n",
514 | "The `SoftMax()` function gives us percentages that the **Values** for each token should contribute to the attention score for a specific token. Thus, we can get the final attention scores by multiplying the percentages with the **Values** in matrix **V**. First, we'll do this with the unmasked percentages...\n",
515 | "\n",
516 | "<img src=\"./images/attention_final_scores.png\">\n",
517 | "\n",
518 | "...and then we'll calculate the final Masked Self-Attention scores.\n",
519 | "\n",
520 | "<img src=\"./images/attention_final_scores_masked.png\">\n",
521 | "\n",
522 | "# BAM!\n",
523 | "\n",
524 | "Now that we know how to calculate the different types of attention, let's code the `Attention()` class."
525 | ]
526 | },
527 | {
528 | "cell_type": "code",
529 | "execution_count": null,
530 | "id": "e3392130-cd25-4000-97bb-9612764c83a8",
531 | "metadata": {},
532 | "outputs": [],
533 | "source": [
534 | "class Attention(nn.Module): \n",
535 | " \n",
536 | " def __init__(self, d_model=2):\n",
537 | " ## d_model = the number of embedding values per token.\n",
538 | " ## In the transformer I used in the StatQuest: Transformer Neural Networks Clearly Explained!!!\n",
539 | " ## d_model=2, so that's what we'll use as a default for now.\n",
540 | " ## However, in \"Attention Is All You Need\" d_model=512\n",
541 | "\n",
542 | " \n",
543 | " super().__init__()\n",
544 | " \n",
545 | " self.d_model=d_model\n",
546 | " \n",
547 | " ## Initialize the Weights (W) that we'll use to create the\n",
548 | " ## query (q), key (k) and value (v) numbers for each token\n",
549 | " ## NOTE: Most implementations that I looked at include the bias terms\n",
550 | " ## but I didn't use them in my video (since they are not in the \n",
551 | " ## original Attention is All You Need paper).\n",
552 | " self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)\n",
553 | " self.W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)\n",
554 | " self.W_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)\n",
555 | " \n",
556 | " ## NOTE: In this simple example, we are not training on the data in \"batches\"\n",
557 | " ## However, by defining variables for row_dim and col_dim, we could\n",
558 |         ##       allow for batches by setting row_dim to 1 and col_dim to 2.\n",
559 | " self.row_dim = 0\n",
560 | " self.col_dim = 1\n",
561 | "\n",
562 | " \n",
563 | " def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):\n",
564 | " ## Create the query, key and values using the encodings\n",
565 | " ## associated with each token (token encodings)\n",
566 | " ##\n",
567 | " ## NOTE: For Encoder-Decoder Attention, the encodings for q come from\n",
568 | " ## the decoder and the encodings for k and v come from the output\n",
569 | " ## from the encoder.\n",
570 | " ## In all of the other types of attention, the encodings all\n",
571 | " ## come from the same source.\n",
572 | " q = self.W_q(encodings_for_q)\n",
573 | " k = self.W_k(encodings_for_k)\n",
574 | " v = self.W_v(encodings_for_v)\n",
575 | "\n",
576 | " ## Compute attention scores\n",
577 | " ## the equation is (q * k^T)/sqrt(d_model)\n",
578 | " ## NOTE: It seems most people use \"reverse indexing\" for the dimensions when transposing k\n",
579 | " ## k.transpose(dim0, dim1) will transpose k by swapping dim0 and dim1\n",
580 | " ## In standard matrix notation, we would want to swap rows (dim=0) with columns (dim=1)\n",
581 | " ## If we have 3 dimensions, because of batching, and the batch was the first dimension\n",
582 | " ## And thus dims are defined batch = 0, rows = 1, columns = 2\n",
583 | " ## then dim0=-2 = 3 - 2 = 1. dim1=-1 = 3 - 1 = 2.\n",
584 | " ## Alternatively, we could put the batches in dim 3, and thus, dim 0 would still be rows\n",
585 |         ##        Alternatively, we could put the batches in the 3rd dimension, and thus, dim 0 would still be rows\n",
586 | " ##\n",
587 |         ## Likewise, the q.size(-1) uses negative indexing to refer to the number of columns in the query\n",
588 |         ## which tells us d_model. Alternatively, we could use q.size(2) if we have batches in the first\n",
589 | " ## dimension or q.size(1) if we have batches in the 3rd dimension.\n",
590 | " ##\n",
591 | " ## Since there are a bunch of ways to index things, I think the best thing to do is use\n",
592 | " ## variables \"row_dim\" and \"col_dim\" instead of numbers...\n",
593 | " sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))\n",
594 | "\n",
595 | " scaled_sims = sims / torch.tensor(k.size(self.col_dim)**0.5)\n",
596 | "\n",
597 | " if mask is not None:\n",
598 | " ## Here we are masking out things we don't want to pay attention to,\n",
599 | " ## like tokens that come after the current token.\n",
600 |             ## We can also use masking to block out the <PAD> token,\n",
601 |             ## which is used when we have a batch of input sequences\n",
602 |             ## and they are not all the exact same length. Because the batch is passed\n",
603 |             ## in as a matrix, each input sequence has to have the same length, so we\n",
604 |             ## add <PAD> tokens to the shorter sequences so that they are all as long as the\n",
605 |             ## longest sequence.\n",
606 |             ##\n",
607 |             ## We replace <PAD> tokens, or tokens that come after the current token,\n",
608 | " ## with a very large negative number so that the SoftMax() function\n",
609 | " ## will give all masked elements an output value (or \"probability\") of 0.\n",
610 | " scaled_sims = scaled_sims.masked_fill(mask=mask, value=-1e9) # I've also seen -1e20 and -9e15 used in masking\n",
611 | " \n",
612 | " ## Apply softmax to determine what percent of each token's value to\n",
613 | " ## use in the final attention values.\n",
614 | " attention_percents = F.softmax(scaled_sims, dim=self.col_dim)\n",
615 | "\n",
616 | " ## Scale the values by their associated percentages and add them up.\n",
617 | " attention_scores = torch.matmul(attention_percents, v)\n",
618 | " \n",
619 | " return attention_scores"
620 | ]
621 | },
622 | {
623 | "cell_type": "markdown",
624 | "id": "d92c2f59-b64e-4dbc-9a33-74391ff84683",
625 | "metadata": {},
626 | "source": [
627 | "# BAM!\n",
628 | "\n",
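"If you'd like to poke at the `Attention()` class on its own before moving on, here is a tiny, hypothetical example (the token encodings are just made-up numbers) that runs three 2-dimensional encodings through self-attention...\n",
"\n",
"```python\n",
"torch.manual_seed(42) ## so the randomly initialized weights are reproducible\n",
"\n",
"attention = Attention(d_model=2)\n",
"\n",
"## three made-up token encodings, with d_model=2 values each\n",
"encodings = torch.tensor([[1.16, 0.23],\n",
"                          [0.57, 1.36],\n",
"                          [4.41, -2.16]])\n",
"\n",
"## for self-attention, q, k and v are all calculated from the same encodings\n",
"print(attention(encodings, encodings, encodings))\n",
"```\n",
"\n",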
629 | "Now that we have coded the `Attention()` class, we can build a Decoder-Only Transformer."
630 | ]
631 | },
632 | {
633 | "cell_type": "markdown",
634 | "id": "291359b2-d509-43d3-aacc-648721690246",
635 | "metadata": {},
636 | "source": [
637 | "----"
638 | ]
639 | },
640 | {
641 | "cell_type": "markdown",
642 | "id": "791ffd7e-6671-4153-8d47-4c3b74e61341",
643 | "metadata": {},
644 | "source": [
645 | "<a id='decoder'></a>\n",
646 | "# The Decoder-Only Transformer"
647 | ]
648 | },
649 | {
650 | "cell_type": "markdown",
651 | "id": "abf2d3f7-41cd-4645-a105-6b9a924df22d",
652 | "metadata": {},
653 | "source": [
654 | "\n",
655 | "\n",
656 | "<img src=\"./images/dec_transformer.png\">\n",
657 | "\n",
658 | "A Decoder-Only Transformer simply brings together...\n",
659 | "\n",
660 | "- Word Embedding\n",
661 | "- Position Encoding\n",
662 | "- Masked Self-Attention\n",
663 | "- Residual Connections\n",
664 | "- A fully connected layer\n",
665 | "- SoftMax - However, the loss function we are using, `nn.CrossEntropyLoss()`, applies the SoftMax for us."
666 | ]
667 | },
668 | {
669 | "cell_type": "code",
670 | "execution_count": null,
671 | "id": "1f286589-e933-46d2-aeda-600c74799357",
672 | "metadata": {},
673 | "outputs": [],
674 | "source": [
675 | "class DecoderOnlyTransformer(L.LightningModule):\n",
676 | " \n",
677 | " def __init__(self, num_tokens=4, d_model=2, max_len=6):\n",
678 | " \n",
679 | " super().__init__()\n",
680 | " \n",
681 |         ## We set the seed so that you can get the same results as me.\n",
682 | " L.seed_everything(seed=42)\n",
683 | " \n",
684 | " \n",
685 | " ## NOTE: In this simple example, we are just using a \"single layer\" decoder.\n",
686 | " ## If we wanted to have multiple layers of decoder, then we would\n",
687 | " ## take the output of one decoder module and use it as input to\n",
688 | " ## the next module.\n",
689 | " \n",
690 | " self.we = nn.Embedding(num_embeddings=num_tokens, \n",
691 | " embedding_dim=d_model) \n",
692 | " \n",
693 | " self.pe = PositionEncoding(d_model=d_model, \n",
694 | " max_len=max_len)\n",
695 | "\n",
696 | " self.self_attention = Attention(d_model=d_model)\n",
697 | " ## NOTE: In this simple example, we are not doing multi-head attention\n",
698 | " ## If we wanted to do multi-head attention, we could\n",
699 |         ##       initialize more Attention objects like this...\n",
700 | " ##\n",
701 | " ## self.self_attention_2 = Attention(d_model=d_model)\n",
702 | " ## self.self_attention_3 = Attention(d_model=d_model)\n",
703 | " ##\n",
704 | " ## If d_model=2, then using 3 self_attention objects would \n",
705 | " ## result in d_model*3 = 6 self-attention values per token, \n",
706 | " ## so we would need to initialize\n",
707 | " ## a fully connected layer to reduce the dimension of the \n",
708 | " ## self attention values back down to d_model like this:\n",
709 | " ## \n",
710 | " ## self.reduce_attention_dim = nn.Linear(in_features=(num_attention_heads*d_model), out_features=d_model)\n",
711 | "\n",
712 | " self.fc_layer = nn.Linear(in_features=d_model, out_features=num_tokens)\n",
713 | " \n",
714 | " self.loss = nn.CrossEntropyLoss()\n",
715 | " \n",
716 | " \n",
717 | " def forward(self, token_ids):\n",
718 | " \n",
719 | " word_embeddings = self.we(token_ids) \n",
720 | " position_encoded = self.pe(word_embeddings)\n",
721 | " \n",
722 | " ## For the decoder-only transformer, we need to use \"masked self-attention\" so that \n",
723 | " ## when we are training we can't cheat and look ahead at\n",
724 | " ## what words come after the current word.\n",
725 | " ## To create the mask we are creating a matrix where the lower triangle\n",
726 |         ## is filled with 1s, and everything above the diagonal is filled with 0s.\n",
727 | " mask = torch.tril(torch.ones((token_ids.size(dim=0), token_ids.size(dim=0)), device=self.device))\n",
728 | " ## NOTE: The device=self.device is needed because we are creating a new\n",
729 | " ## tensor, mask, in the forward() method, which, by default, goes\n",
730 | " ## to the CPU. If all we have is a CPU, then we don't need it, but\n",
731 | " ## if we want to train on a GPU, we need to make sure mask goes\n",
732 |         ##       there too. Using self.device allows us to not worry about whether\n",
733 | " ## or not we are using a GPU or CPU or whatever, it will make sure\n",
734 | " ## mask is where it needs to go.\n",
735 | "\n",
736 |         ## We then replace the 0s above the diagonal, which represent the words\n",
737 | " ## we want to be masked out, with \"True\", and replace the 1s in the lower\n",
738 |         ## triangle, which represent the words we want to include when we calculate\n",
739 | " ## self-attention for a specific word in the output, with \"False\".\n",
740 | " mask = mask == 0\n",
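        ## For example, if we had 3 input tokens, 'mask' would end up looking like this...\n",
        ##\n",
        ##    tensor([[False,  True,  True],\n",
        ##            [False, False,  True],\n",
        ##            [False, False, False]])\n",
        ##\n",
        ## ...where 'True' marks the later tokens that masked self-attention is not allowed to look at.\n",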
741 | " \n",
742 | " self_attention_values = self.self_attention(position_encoded, \n",
743 | " position_encoded, \n",
744 | " position_encoded, \n",
745 | " mask=mask)\n",
746 | " ## NOTE: If we were doing multi-head attention, we would\n",
747 | " ## calculate the self-attention values with the other attention objects\n",
748 | " ## like this...\n",
749 | " ##\n",
750 | " ## self_attention_values_2 = self.self_attention_2(...)\n",
751 |         ##        self_attention_values_3 = self.self_attention_3(...)\n",
752 | " ## \n",
753 | " ## ...then we would concatenate all the self attention values...\n",
754 | " ##\n",
755 |         ##        all_self_attention_values = torch.cat((self_attention_values, self_attention_values_2, ...), dim=1)\n",
756 | " ##\n",
757 |         ##        ...and then run them through reduce_attention_dim to get back to d_model values per token\n",
758 | " ##\n",
759 | " ## final_self_attention_values = self.reduce_attention_dim(all_self_attention_values)\n",
760 | " \n",
761 | " residual_connection_values = position_encoded + self_attention_values\n",
762 | " \n",
763 | " fc_layer_output = self.fc_layer(residual_connection_values)\n",
764 | " \n",
765 | " return fc_layer_output\n",
766 | " \n",
767 | " \n",
768 | " def configure_optimizers(self): \n",
769 | " ## configure_optimizers() simply passes the parameters we want to\n",
770 |         ## optimize to the optimizer and sets the learning rate\n",
771 | " return Adam(self.parameters(), lr=0.1)\n",
772 | " \n",
773 | " \n",
774 | " def training_step(self, batch, batch_idx): \n",
775 | " ## training_step() is called by Lightning trainer when \n",
776 | " ## we want to train the model.\n",
777 | " input_tokens, labels = batch # collect input\n",
778 | " output = self.forward(input_tokens[0])\n",
779 | " loss = self.loss(output, labels[0])\n",
780 | " \n",
781 | " return loss"
782 | ]
783 | },
784 | {
785 | "cell_type": "markdown",
786 | "id": "dae5288f-80dc-4ae9-96d5-55b951feeec7",
787 | "metadata": {},
788 | "source": [
789 | "# BAM!\n",
790 | "\n",
791 | "Now that we have coded up the `DecoderOnlyTransformer()` class, let's see if it works correctly without training. You never know, it might just work! \n",
792 | "\n",
793 | "To use the transformer, we run an input phrase, either **what is statquest \<EOS\>** or **statquest is what \<EOS\>**, through the transformer to get the next predicted token. If the next predicted token is not **\<EOS\>**, then we add the predicted token to the input tokens and run that through the transformer and repeat until we get the **\<EOS\>** token or reach the maximum sequence length."
794 | ]
795 | },
796 | {
797 | "cell_type": "code",
798 | "execution_count": null,
799 | "id": "0c7d3c9f-17c0-41e7-b654-c03540713b90",
800 | "metadata": {},
801 | "outputs": [],
802 | "source": [
803 | "## First, create a model from DecoderOnlyTransformer()\n",
804 | "model = DecoderOnlyTransformer(num_tokens=len(token_to_id), d_model=2, max_len=6)\n",
805 | "\n",
806 | "## Now create the input for the transformer...\n",
807 | "model_input = torch.tensor([token_to_id[\"what\"], \n",
808 | " token_to_id[\"is\"], \n",
809 | " token_to_id[\"statquest\"], \n",
810 |                             token_to_id[\"<EOS>\"]])\n",
811 | "input_length = model_input.size(dim=0)\n",
812 | "\n",
813 | "## Now get predictions from the model\n",
814 | "predictions = model(model_input) \n",
815 | "## NOTE: \"predictions\" is the output from the fully connected layer,\n",
816 | "## not a softmax() function. We could, if we wanted to,\n",
817 | "## Run \"predictions\" through a softmax() function, but \n",
818 | "## since we're going to select the item with the largest value\n",
819 | "## we can just use argmax instead...\n",
820 | "## ALSO NOTE: \"predictions\" is a matrix, with one row of predicted values\n",
821 | "## per input token. Since we only want the prediction from the\n",
822 | "## last row (the most recent prediction) we use reverse index for the\n",
823 | "## row, -1.\n",
824 | "predicted_id = torch.tensor([torch.argmax(predictions[-1,:])])\n",
825 | "## We'll store predicted_id in an array, predicted_ids, that\n",
826 | "## we'll add to each time we predict a new output token.\n",
827 | "predicted_ids = predicted_id\n",
828 | "\n",
829 | "## Now use a loop to predict output tokens until we get an <EOS>\n",
830 | "## token.\n",
831 | "max_length = 6\n",
832 | "for i in range(input_length, max_length):\n",
833 |     if (predicted_id == token_to_id[\"<EOS>\"]): # if the prediction is <EOS>, then we are done\n",
834 | " break\n",
835 | " \n",
836 | " model_input = torch.cat((model_input, predicted_id))\n",
837 | " \n",
838 | " predictions = model(model_input) \n",
839 | " predicted_id = torch.tensor([torch.argmax(predictions[-1,:])])\n",
840 | " predicted_ids = torch.cat((predicted_ids, predicted_id))\n",
841 | " \n",
842 | "## Now printout the predicted output phrase.\n",
843 | "print(\"Predicted Tokens:\\n\") \n",
844 | "for id in predicted_ids: \n",
845 | " print(\"\\t\", id_to_token[id.item()])"
846 | ]
847 | },
848 | {
849 | "cell_type": "markdown",
850 | "id": "5c451f90-3382-4931-92fc-5cd3670be7e3",
851 | "metadata": {},
852 | "source": [
853 | "And, without training, the transformer predicts **\<EOS\>**, but we wanted it to predict **awesome \<EOS\>**. So, since the transformer didn't correctly respond to the prompt, we'll have to train it."
854 | ]
855 | },
856 | {
857 | "cell_type": "markdown",
858 | "id": "39d51f3f-02c7-4d30-8e18-43be37996454",
859 | "metadata": {},
860 | "source": [
861 | "----"
862 | ]
863 | },
864 | {
865 | "cell_type": "markdown",
866 | "id": "982ca23f-0c46-4aa9-aebf-32e5bcedece6",
867 | "metadata": {},
868 | "source": [
869 | "<a id='train'></a>\n",
870 | "# Train the Decoder-Only Transformer!!!"
871 | ]
872 | },
873 | {
874 | "cell_type": "markdown",
875 | "id": "ddc78dd9-ca22-4c6a-bd0f-26540716620e",
876 | "metadata": {},
877 | "source": [
878 | "To train a decoder-only transformer, we simply create a Lightning `Trainer()` and train the transformer with the `dataloader` that we created earlier."
879 | ]
880 | },
881 | {
882 | "cell_type": "code",
883 | "execution_count": null,
884 | "id": "a96bd93c-31f4-4ef8-8c4f-fd33c3c20344",
885 | "metadata": {
886 | "tags": []
887 | },
888 | "outputs": [],
889 | "source": [
890 | "trainer = L.Trainer(max_epochs=30)\n",
891 | "trainer.fit(model, train_dataloaders=dataloader)"
892 | ]
893 | },
894 | {
895 | "cell_type": "markdown",
896 | "id": "0acf9d1f-ac29-4316-88bb-5c6d0f3e04b6",
897 | "metadata": {},
898 | "source": [
899 | "# Double BAM!!!\n",
900 | "\n",
901 | "Now that we've trained the transformer, let's use it!"
902 | ]
903 | },
904 | {
905 | "cell_type": "markdown",
906 | "id": "1b80a32f-ded6-4299-bf89-12c6bbd8ced2",
907 | "metadata": {},
908 | "source": [
909 | "----"
910 | ]
911 | },
912 | {
913 | "cell_type": "markdown",
914 | "id": "c040285c-c6de-4aac-b5b9-925e84f5e83f",
915 | "metadata": {},
916 | "source": [
917 | "<a id='use'></a>\n",
918 | "# Use the Trained Transformer!!!"
919 | ]
920 | },
921 | {
922 | "cell_type": "markdown",
923 | "id": "0666f428-eb0d-4c43-91d5-48c67cdcc5f8",
924 | "metadata": {},
925 | "source": [
926 | "To use the transformer that we just trained, we just repeat what we did earlier, only this time we use the trained transformer instead of an untrained transformer. First, we'll see if it correctly responds to the prompt **What is StatQuest?**"
927 | ]
928 | },
929 | {
930 | "cell_type": "code",
931 | "execution_count": null,
932 | "id": "717eaa4e-2249-46a1-b01f-9ed76beed0de",
933 | "metadata": {},
934 | "outputs": [],
935 | "source": [
936 | "model_input = torch.tensor([token_to_id[\"what\"], \n",
937 | " token_to_id[\"is\"], \n",
938 | " token_to_id[\"statquest\"], \n",
939 |                             token_to_id[\"<EOS>\"]])\n",
940 | "input_length = model_input.size(dim=0)\n",
941 | "\n",
942 | "predictions = model(model_input) \n",
943 | "predicted_id = torch.tensor([torch.argmax(predictions[-1,:])])\n",
944 | "predicted_ids = predicted_id\n",
945 | "\n",
946 | "for i in range(input_length, max_length):\n",
947 |     if (predicted_id == token_to_id[\"<EOS>\"]): # if the prediction is <EOS>, then we are done\n",
948 | " break\n",
949 | " \n",
950 | " model_input = torch.cat((model_input, predicted_id))\n",
951 | " \n",
952 | " predictions = model(model_input) \n",
953 | " predicted_id = torch.tensor([torch.argmax(predictions[-1,:])])\n",
954 | " predicted_ids = torch.cat((predicted_ids, predicted_id))\n",
955 | " \n",
956 | "print(\"Predicted Tokens:\\n\") \n",
957 | "for id in predicted_ids: \n",
958 | " print(\"\\t\", id_to_token[id.item()])"
959 | ]
960 | },
961 | {
962 | "cell_type": "markdown",
963 | "id": "40d52e9f-81dc-4f51-a699-5a19127b4a19",
964 | "metadata": {},
965 | "source": [
966 | "Hooray!!! We got the correct output! Now let's see if it correctly responds to the prompt **StatQuest is what?**"
967 | ]
968 | },
969 | {
970 | "cell_type": "code",
971 | "execution_count": null,
972 | "id": "89881ada-c49d-4805-a989-21020d9ca4c9",
973 | "metadata": {},
974 | "outputs": [],
975 | "source": [
976 | "## Now let's ask the other question...\n",
977 | "model_input = torch.tensor([token_to_id[\"statquest\"], \n",
978 | " token_to_id[\"is\"], \n",
979 | " token_to_id[\"what\"], \n",
980 |                             token_to_id[\"<EOS>\"]])\n",
981 | "input_length = model_input.size(dim=0)\n",
982 | "\n",
983 | "predictions = model(model_input) \n",
984 | "predicted_id = torch.tensor([torch.argmax(predictions[-1,:])])\n",
985 | "predicted_ids = predicted_id\n",
986 | "\n",
987 | "for i in range(input_length, max_length):\n",
988 |     if (predicted_id == token_to_id[\"<EOS>\"]): # if the prediction is <EOS>, then we are done\n",
989 | " break\n",
990 | " \n",
991 | " model_input = torch.cat((model_input, predicted_id))\n",
992 | " \n",
993 | " predictions = model(model_input) \n",
994 | " predicted_id = torch.tensor([torch.argmax(predictions[-1,:])])\n",
995 | " predicted_ids = torch.cat((predicted_ids, predicted_id))\n",
996 | " \n",
997 | "print(\"Predicted Tokens:\\n\") \n",
998 | "for id in predicted_ids: \n",
999 | " print(\"\\t\", id_to_token[id.item()])"
1000 | ]
1001 | },
1002 | {
1003 | "cell_type": "markdown",
1004 | "id": "e3a6dda5-f868-410e-8038-b1bd8d4024df",
1005 | "metadata": {},
1006 | "source": [
1007 | "And the output for both questions is **awesome \<EOS\>**, which is exactly what we want.\n",
1008 | "\n",
1009 | "# TRIPLE BAM!!!"
1010 | ]
1011 | },
1012 | {
1013 | "cell_type": "markdown",
1014 | "id": "c5e297c9-9735-44fd-83ef-fd267cba327c",
1015 | "metadata": {},
1016 | "source": [
1017 | "**NOTE:** With all the comments in the `PositionEncoding()`, `Attention()`, and `DecoderOnlyTransformer()` classes, it may seem like we had to write a lot of code to create a Decoder-Only Transformer. Not so. Below we see that all three classes are only a handful of lines of code. Isn't that bonkers? The most state-of-the-art model isn't that big of a deal!\n",
1018 | "\n",
1019 | "```python\n",
1020 | "class PositionEncoding(nn.Module):\n",
1021 | " \n",
1022 | " def __init__(self, d_model=2, max_len=6):\n",
1023 | " \n",
1024 | " super().__init__()\n",
1025 | " \n",
1026 | " pe = torch.zeros(max_len, d_model)\n",
1027 | " \n",
1028 | " position = torch.arange(start=0, end=max_len, step=1).float().unsqueeze(1)\n",
1029 | " embedding_index = torch.arange(start=0, end=d_model, step=2).float()\n",
1030 | " \n",
1031 | " div_term = 1/torch.tensor(10000.0)**(embedding_index / d_model)\n",
1032 | " \n",
1033 | "\n",
1034 | " pe[:, 0::2] = torch.sin(position * div_term) \n",
1035 | " pe[:, 1::2] = torch.cos(position * div_term) \n",
1036 | " \n",
1037 | " self.register_buffer('pe', pe) \n",
1038 | "\n",
1039 | " \n",
1040 | " def forward(self, word_embeddings):\n",
1041 | " \n",
1042 | " return word_embeddings + self.pe[:word_embeddings.size(0), :] \n",
1043 | "\n",
1044 | "class Attention(nn.Module): \n",
1045 | " \n",
1046 | " def __init__(self, d_model=2):\n",
1047 | " \n",
1048 | " super().__init__()\n",
1049 | " \n",
1050 | " self.W_q = nn.Linear(in_features=d_model, out_features=d_model, bias=False)\n",
1051 | " self.W_k = nn.Linear(in_features=d_model, out_features=d_model, bias=False)\n",
1052 | " self.W_v = nn.Linear(in_features=d_model, out_features=d_model, bias=False)\n",
1053 | " \n",
1054 | " self.row_dim = 0\n",
1055 | " self.col_dim = 1\n",
1056 | "\n",
1057 | " \n",
1058 | " def forward(self, encodings_for_q, encodings_for_k, encodings_for_v, mask=None):\n",
1059 | "\n",
1060 | " q = self.W_q(encodings_for_q)\n",
1061 | " k = self.W_k(encodings_for_k)\n",
1062 | " v = self.W_v(encodings_for_v)\n",
1063 | "\n",
1064 | " sims = torch.matmul(q, k.transpose(dim0=self.row_dim, dim1=self.col_dim))\n",
1065 | "\n",
1066 | " scaled_sims = sims / torch.tensor(k.size(self.col_dim)**0.5)\n",
1067 | "\n",
1068 | " if mask is not None:\n",
1069 | " scaled_sims = scaled_sims.masked_fill(mask=mask, value=-1e9)\n",
1070 | " \n",
1071 | " attention_percents = F.softmax(scaled_sims, dim=self.col_dim)\n",
1072 | " attention_scores = torch.matmul(attention_percents, v)\n",
1073 | " \n",
1074 | " return attention_scores\n",
1075 | "\n",
1076 | "class DecoderOnlyTransformer(L.LightningModule):\n",
1077 | " \n",
1078 | " def __init__(self, num_tokens=4, d_model=2, max_len=6):\n",
1079 | " \n",
1080 | " super().__init__()\n",
1081 | " \n",
1082 | " L.seed_everything(seed=42)\n",
1083 | " \n",
1084 | " self.we = nn.Embedding(num_embeddings=num_tokens, \n",
1085 | " embedding_dim=d_model) \n",
1086 | " self.pe = PositionEncoding(d_model=d_model, \n",
1087 | " max_len=max_len)\n",
1088 | " self.self_attention = Attention(d_model=d_model)\n",
1089 | " self.fc_layer = nn.Linear(in_features=d_model, out_features=num_tokens)\n",
1090 | " \n",
1091 | " self.loss = nn.CrossEntropyLoss()\n",
1092 | " \n",
1093 | " \n",
1094 | " def forward(self, token_ids):\n",
1095 | " \n",
1096 | " word_embeddings = self.we(token_ids) \n",
1097 | " position_encoded = self.pe(word_embeddings)\n",
1098 | " \n",
1099 | " mask = torch.tril(torch.ones((token_ids.size(dim=0), token_ids.size(dim=0))))\n",
1100 | " mask = mask == 0\n",
1101 | " \n",
1102 | " self_attention_values = self.self_attention(position_encoded, \n",
1103 | " position_encoded, \n",
1104 | " position_encoded, \n",
1105 | " mask=mask)\n",
1106 | " \n",
1107 | " residual_connection_values = position_encoded + self_attention_values \n",
1108 | " fc_layer_output = self.fc_layer(residual_connection_values)\n",
1109 | " \n",
1110 | " return fc_layer_output\n",
1111 | " \n",
1112 | " \n",
1113 | " def configure_optimizers(self): \n",
1114 | " return Adam(self.parameters(), lr=0.1)\n",
1115 | " \n",
1116 | " \n",
1117 | " def training_step(self, batch, batch_idx): \n",
1118 | " input_tokens, labels = batch # collect input\n",
1119 | " output = self.forward(input_tokens[0])\n",
1120 | " loss = self.loss(output, labels[0])\n",
1121 | " \n",
1122 | " return loss\n",
1123 | "```"
1124 | ]
1125 | },
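{
"cell_type": "markdown",
"id": "7f3e9a21-5c44-4d1b-9c3e-2a8f6b1d0e57",
"metadata": {},
"source": [
"Just to show that the handful of lines above really does run, here is a minimal sketch that creates a fresh (untrained) model and pushes some made-up token ids through it. This sketch assumes the imports and classes from earlier in this notebook (`torch`, `nn`, `F`, `L`, and the three classes above) are already defined; the token ids are arbitrary and only meant to show the shapes.\n",
"\n",
"```python\n",
"# create a tiny, untrained model with the default settings\n",
"model = DecoderOnlyTransformer(num_tokens=4, d_model=2, max_len=6)\n",
"\n",
"# made-up token ids (in the tutorial, real ids come from token_to_id)\n",
"token_ids = torch.tensor([0, 1, 2, 3])\n",
"\n",
"# one row of logits (one value per token in the vocabulary) for each input token\n",
"logits = model(token_ids)\n",
"print(logits.shape)  # torch.Size([4, 4])\n",
"```"
]
},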
1126 | {
1127 | "cell_type": "markdown",
1128 | "id": "3c21bd5a-7a4a-4ec6-adb2-d3fa38718fe2",
1129 | "metadata": {},
1130 | "source": [
1131 | "# BONUS BAM!!!\n",
1132 | "\n",
1133 | "Now that we can code our own transformer, let's do what people actually do in practice with transformers \n",
1134 | "\n",
1135 | "- **[Fine tune a pre-trained LLM](https://lightning.ai/lightning-ai/studios/instruction-finetuning-tinyllama-1-1b-llm)**\n",
1136 | "\n",
1137 | "- **[Introdcution to RAG (retrieval augmented generation)](https://lightning.ai/lightning-ai/studios/document-search-and-retrieval-using-rag)**\n",
1138 | "\n",
1139 | "- **[Use an pre-trained LLM chatbot paired with RAG (retrieval augmented generation)](https://lightning.ai/lightning-ai/studios/document-chat-assistant-using-rag)**"
1140 | ]
1141 | }
1142 | ],
1143 | "metadata": {
1144 | "kernelspec": {
1145 | "display_name": "Python 3 (ipykernel)",
1146 | "language": "python",
1147 | "name": "python3"
1148 | },
1149 | "language_info": {
1150 | "codemirror_mode": {
1151 | "name": "ipython",
1152 | "version": 3
1153 | },
1154 | "file_extension": ".py",
1155 | "mimetype": "text/x-python",
1156 | "name": "python",
1157 | "nbconvert_exporter": "python",
1158 | "pygments_lexer": "ipython3",
1159 | "version": "3.9.6"
1160 | }
1161 | },
1162 | "nbformat": 4,
1163 | "nbformat_minor": 5
1164 | }
1165 |
--------------------------------------------------------------------------------
/images/Brandmark_FullColor_Black.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatQuest/decoder_transformer_from_scratch/4faf475d765a7f6ff55bdf482676f1644b570e25/images/Brandmark_FullColor_Black.png
--------------------------------------------------------------------------------
/images/attention_compute_k.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatQuest/decoder_transformer_from_scratch/4faf475d765a7f6ff55bdf482676f1644b570e25/images/attention_compute_k.png
--------------------------------------------------------------------------------
/images/attention_compute_q.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatQuest/decoder_transformer_from_scratch/4faf475d765a7f6ff55bdf482676f1644b570e25/images/attention_compute_q.png
--------------------------------------------------------------------------------
/images/attention_compute_v.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatQuest/decoder_transformer_from_scratch/4faf475d765a7f6ff55bdf482676f1644b570e25/images/attention_compute_v.png
--------------------------------------------------------------------------------
/images/attention_equations.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatQuest/decoder_transformer_from_scratch/4faf475d765a7f6ff55bdf482676f1644b570e25/images/attention_equations.png
--------------------------------------------------------------------------------
/images/attention_final_scores.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatQuest/decoder_transformer_from_scratch/4faf475d765a7f6ff55bdf482676f1644b570e25/images/attention_final_scores.png
--------------------------------------------------------------------------------
/images/attention_final_scores_masked.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatQuest/decoder_transformer_from_scratch/4faf475d765a7f6ff55bdf482676f1644b570e25/images/attention_final_scores_masked.png
--------------------------------------------------------------------------------
/images/attention_masking.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatQuest/decoder_transformer_from_scratch/4faf475d765a7f6ff55bdf482676f1644b570e25/images/attention_masking.png
--------------------------------------------------------------------------------
/images/attention_q_times_kt.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatQuest/decoder_transformer_from_scratch/4faf475d765a7f6ff55bdf482676f1644b570e25/images/attention_q_times_kt.png
--------------------------------------------------------------------------------
/images/attention_scaling_scores.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatQuest/decoder_transformer_from_scratch/4faf475d765a7f6ff55bdf482676f1644b570e25/images/attention_scaling_scores.png
--------------------------------------------------------------------------------
/images/attention_softmax.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatQuest/decoder_transformer_from_scratch/4faf475d765a7f6ff55bdf482676f1644b570e25/images/attention_softmax.png
--------------------------------------------------------------------------------
/images/attention_softmax_masked.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatQuest/decoder_transformer_from_scratch/4faf475d765a7f6ff55bdf482676f1644b570e25/images/attention_softmax_masked.png
--------------------------------------------------------------------------------
/images/dec_transformer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatQuest/decoder_transformer_from_scratch/4faf475d765a7f6ff55bdf482676f1644b570e25/images/dec_transformer.png
--------------------------------------------------------------------------------
/images/decoder_diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatQuest/decoder_transformer_from_scratch/4faf475d765a7f6ff55bdf482676f1644b570e25/images/decoder_diagram.png
--------------------------------------------------------------------------------
/images/enc_dec_attention_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatQuest/decoder_transformer_from_scratch/4faf475d765a7f6ff55bdf482676f1644b570e25/images/enc_dec_attention_1.png
--------------------------------------------------------------------------------
/images/enc_dec_transformer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatQuest/decoder_transformer_from_scratch/4faf475d765a7f6ff55bdf482676f1644b570e25/images/enc_dec_transformer.png
--------------------------------------------------------------------------------
/images/encoder_diagram.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatQuest/decoder_transformer_from_scratch/4faf475d765a7f6ff55bdf482676f1644b570e25/images/encoder_diagram.png
--------------------------------------------------------------------------------
/images/expected_input_output_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatQuest/decoder_transformer_from_scratch/4faf475d765a7f6ff55bdf482676f1644b570e25/images/expected_input_output_1.png
--------------------------------------------------------------------------------
/images/expected_input_output_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatQuest/decoder_transformer_from_scratch/4faf475d765a7f6ff55bdf482676f1644b570e25/images/expected_input_output_2.png
--------------------------------------------------------------------------------
/images/masked_attention_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatQuest/decoder_transformer_from_scratch/4faf475d765a7f6ff55bdf482676f1644b570e25/images/masked_attention_1.png
--------------------------------------------------------------------------------
/images/pos_encoding_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatQuest/decoder_transformer_from_scratch/4faf475d765a7f6ff55bdf482676f1644b570e25/images/pos_encoding_1.png
--------------------------------------------------------------------------------
/images/pos_encoding_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatQuest/decoder_transformer_from_scratch/4faf475d765a7f6ff55bdf482676f1644b570e25/images/pos_encoding_2.png
--------------------------------------------------------------------------------
/images/self_attention_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatQuest/decoder_transformer_from_scratch/4faf475d765a7f6ff55bdf482676f1644b570e25/images/self_attention_1.png
--------------------------------------------------------------------------------
/images/self_attention_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatQuest/decoder_transformer_from_scratch/4faf475d765a7f6ff55bdf482676f1644b570e25/images/self_attention_2.png
--------------------------------------------------------------------------------
/images/squatch_eats_pizza.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/StatQuest/decoder_transformer_from_scratch/4faf475d765a7f6ff55bdf482676f1644b570e25/images/squatch_eats_pizza.png
--------------------------------------------------------------------------------