├── LICENSE ├── README.md ├── images ├── 42.png ├── a10.png ├── afterattention.png ├── archi.png ├── attention.png ├── embeddings.png ├── finallayer.png ├── freq_cis.png ├── god.png ├── heads.png ├── implllama3_30_0.png ├── implllama3_39_0.png ├── implllama3_41_0.png ├── implllama3_42_0.png ├── implllama3_50_0.png ├── implllama3_52_0.png ├── implllama3_54_0.png ├── karpathyminbpe.png ├── keys.png ├── keys0.png ├── last_norm.png ├── mask.png ├── model.png ├── norm.png ├── norm_after.png ├── q_per_token.png ├── qkmatmul.png ├── qkv.png ├── qsplit.png ├── rms.png ├── rope.png ├── ropesplit.png ├── softmax.png ├── stacked.png ├── swiglu.png ├── tokens.png ├── v0.png ├── value.png └── weightmatrix.png ├── llama3-from-scratch.ipynb └── requirements.txt /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Nishant Aklecha 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # llama3 implemented from scratch 2 | in this file, i implemented llama3 from scratch, one tensor and matrix multiplication at a time. 3 |
4 | also, im going to load tensors directly from the model file that meta provided for llama3, you need to download the weights before running this file. 5 | here is the official link to download the weights: https://llama.meta.com/llama-downloads/ 6 | 7 |
8 | 9 |
10 | 11 | ## tokenizer 12 | im not going to implement a bpe tokenizer (but andrej karpathy has a really clean implementation) 13 |
14 | link to his implementation: https://github.com/karpathy/minbpe 15 | 16 |
17 | 18 |
19 | 20 | 21 | 22 | ```python 23 | from pathlib import Path 24 | import tiktoken 25 | from tiktoken.load import load_tiktoken_bpe 26 | import torch 27 | import json 28 | import matplotlib.pyplot as plt 29 | 30 | tokenizer_path = "Meta-Llama-3-8B/tokenizer.model" 31 | special_tokens = [ 32 | "<|begin_of_text|>", 33 | "<|end_of_text|>", 34 | "<|reserved_special_token_0|>", 35 | "<|reserved_special_token_1|>", 36 | "<|reserved_special_token_2|>", 37 | "<|reserved_special_token_3|>", 38 | "<|start_header_id|>", 39 | "<|end_header_id|>", 40 | "<|reserved_special_token_4|>", 41 | "<|eot_id|>", # end of turn 42 | ] + [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)] 43 | mergeable_ranks = load_tiktoken_bpe(tokenizer_path) 44 | tokenizer = tiktoken.Encoding( 45 | name=Path(tokenizer_path).name, 46 | pat_str=r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+", 47 | mergeable_ranks=mergeable_ranks, 48 | special_tokens={token: len(mergeable_ranks) + i for i, token in enumerate(special_tokens)}, 49 | ) 50 | 51 | tokenizer.decode(tokenizer.encode("hello world!")) 52 | ``` 53 | 54 | 55 | 56 | 57 | 'hello world!' 58 | 59 | 60 | 61 | ## reading the model file 62 | normally, reading this depends on how the model classes are written and the variable names inside them. 63 |
64 | but since we are implementing llama3 from scratch we will read the file one tensor at a time. 65 |
66 | 67 |
68 | 69 | 70 | ```python 71 | model = torch.load("Meta-Llama-3-8B/consolidated.00.pth") 72 | print(json.dumps(list(model.keys())[:20], indent=4)) 73 | ``` 74 | 75 | [ 76 | "tok_embeddings.weight", 77 | "layers.0.attention.wq.weight", 78 | "layers.0.attention.wk.weight", 79 | "layers.0.attention.wv.weight", 80 | "layers.0.attention.wo.weight", 81 | "layers.0.feed_forward.w1.weight", 82 | "layers.0.feed_forward.w3.weight", 83 | "layers.0.feed_forward.w2.weight", 84 | "layers.0.attention_norm.weight", 85 | "layers.0.ffn_norm.weight", 86 | "layers.1.attention.wq.weight", 87 | "layers.1.attention.wk.weight", 88 | "layers.1.attention.wv.weight", 89 | "layers.1.attention.wo.weight", 90 | "layers.1.feed_forward.w1.weight", 91 | "layers.1.feed_forward.w3.weight", 92 | "layers.1.feed_forward.w2.weight", 93 | "layers.1.attention_norm.weight", 94 | "layers.1.ffn_norm.weight", 95 | "layers.2.attention.wq.weight" 96 | ] 97 | 98 | 99 | 100 | ```python 101 | with open("Meta-Llama-3-8B/params.json", "r") as f: 102 | config = json.load(f) 103 | config 104 | ``` 105 | 106 | 107 | 108 | 109 | {'dim': 4096, 110 | 'n_layers': 32, 111 | 'n_heads': 32, 112 | 'n_kv_heads': 8, 113 | 'vocab_size': 128256, 114 | 'multiple_of': 1024, 115 | 'ffn_dim_multiplier': 1.3, 116 | 'norm_eps': 1e-05, 117 | 'rope_theta': 500000.0} 118 | 119 | 120 | 121 | ## we use this config to infer details about the model like 122 | 1. the model has 32 transformer layers 123 | 2. each multi-head attention block has 32 heads 124 | 3. 
the vocab size and so on 125 | 126 | 127 | ```python 128 | dim = config["dim"] 129 | n_layers = config["n_layers"] 130 | n_heads = config["n_heads"] 131 | n_kv_heads = config["n_kv_heads"] 132 | vocab_size = config["vocab_size"] 133 | multiple_of = config["multiple_of"] 134 | ffn_dim_multiplier = config["ffn_dim_multiplier"] 135 | norm_eps = config["norm_eps"] 136 | rope_theta = torch.tensor(config["rope_theta"]) 137 | ``` 138 | 139 | ## converting text to tokens 140 | here we use tiktoken (openai's bpe tokenizer library) as the tokenizer 141 |
142 | 143 |
144 | 145 | 146 | ```python 147 | prompt = "the answer to the ultimate question of life, the universe, and everything is " 148 | tokens = [128000] + tokenizer.encode(prompt) 149 | print(tokens) 150 | tokens = torch.tensor(tokens) 151 | prompt_split_as_tokens = [tokenizer.decode([token.item()]) for token in tokens] 152 | print(prompt_split_as_tokens) 153 | ``` 154 | 155 | [128000, 1820, 4320, 311, 279, 17139, 3488, 315, 2324, 11, 279, 15861, 11, 323, 4395, 374, 220] 156 | ['<|begin_of_text|>', 'the', ' answer', ' to', ' the', ' ultimate', ' question', ' of', ' life', ',', ' the', ' universe', ',', ' and', ' everything', ' is', ' '] 157 | 158 | 159 | ## converting tokens to their embedding 160 | IM SORRY but this is the only part of the codebase where i use an inbuilt neural network module 161 |
162 | anyway, so our [17x1] tokens are now [17x4096], i.e. 17 embeddings (one for each token) of length 4096 163 |
164 |
165 | note: keep track of the shapes, it makes it much easier to understand everything 166 | 167 |
168 | 169 |
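the embedding layer in the next cell only looks up rows of `tok_embeddings.weight`, so the one inbuilt module is easy to drop. below is a minimal sketch that uses a tiny random matrix in place of the real [128256x4096] weights (the toy sizes are assumptions, picked to keep it small). it shows that `torch.nn.Embedding` gives exactly the same result as plain row indexing:

```python
import torch

# toy sizes standing in for vocab_size=128256 and dim=4096 (assumption, just to keep it small)
toy_vocab_size, toy_dim = 10, 4
toy_weight = torch.randn(toy_vocab_size, toy_dim)  # plays the role of model["tok_embeddings.weight"]
toy_tokens = torch.tensor([3, 1, 3, 7])

# the inbuilt-module route used in the notebook
toy_embedding_layer = torch.nn.Embedding(toy_vocab_size, toy_dim)
toy_embedding_layer.weight.data.copy_(toy_weight)
via_module = toy_embedding_layer(toy_tokens)

# the from-scratch route: token id i just selects row i of the weight matrix
via_indexing = toy_weight[toy_tokens]

print(via_module.shape)                       # torch.Size([4, 4])
print(torch.equal(via_module, via_indexing))  # True
```

repeated token ids (the two 3s) give identical rows, the same way the three "the" tokens in our prompt start out with identical embeddings.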
170 | 171 | 172 | ```python 173 | embedding_layer = torch.nn.Embedding(vocab_size, dim) 174 | embedding_layer.weight.data.copy_(model["tok_embeddings.weight"]) 175 | token_embeddings_unnormalized = embedding_layer(tokens).to(torch.bfloat16) 176 | token_embeddings_unnormalized.shape 177 | ``` 178 | 179 | 180 | 181 | 182 | torch.Size([17, 4096]) 183 | 184 | 185 | 186 | ## we then normalize the embedding using rms normalization 187 | please note: after this step the shapes don't change, the values are just normalized 188 |
189 | one thing to keep in mind: we need norm_eps (from config) because we don't want to accidentally end up with an rms of 0 and divide by 0 190 |
191 | here is the formula: 192 |
193 | 194 |
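here is the same formula written out so it lines up term by term with the `rms_norm` function in the next cell. $x$ is one token's embedding of length $d = 4096$, $w$ is the learned `norm_weights`, and $\epsilon$ is `norm_eps`:

$$
\mathrm{RMSNorm}(x)_j = \frac{x_j}{\sqrt{\frac{1}{d}\sum_{k=1}^{d} x_k^{2} + \epsilon}}\; w_j, \qquad d = 4096
$$

the mean is taken over the last dimension only (`mean(-1)` in the code), so each token is normalized independently of the others.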
195 | 196 | 197 | ```python 198 | # def rms_norm(tensor, norm_weights): 199 | # rms = (tensor.pow(2).mean(-1, keepdim=True) + norm_eps)**0.5 200 | # return tensor * (norm_weights / rms) 201 | def rms_norm(tensor, norm_weights): 202 | return (tensor * torch.rsqrt(tensor.pow(2).mean(-1, keepdim=True) + norm_eps)) * norm_weights 203 | ``` 204 | 205 | # building the first layer of the transformer 206 | 207 | ### normalization 208 | you will see me accessing layers.0 from the model dict (this is the first layer) 209 |
210 | anyway, so after normalizing our shapes are still [17x4096] same as embedding but normalized 211 | 212 |
213 | 214 |
215 | 216 | 217 | ```python 218 | token_embeddings = rms_norm(token_embeddings_unnormalized, model["layers.0.attention_norm.weight"]) 219 | token_embeddings.shape 220 | ``` 221 | 222 | 223 | 224 | 225 | torch.Size([17, 4096]) 226 | 227 | 228 | 229 | ### attention implemented from scratch 230 | let's load the attention heads of the first layer of the transformer 231 |
232 | 233 |
234 | 235 |
236 | 237 | > when we load the query, key, value and output weight matrices from the model we notice their shapes are [4096x4096], [1024x4096], [1024x4096], [4096x4096] 238 |
239 | > at first glance this is weird because ideally we want each q,k,v and o for each head individually 240 |
241 | > the authors of the code bundled them together because it's easier, and it helps parallelize the attention head multiplication. 242 |
243 | > im going to unwrap everything... 244 | 245 | 246 | ```python 247 | print( 248 | model["layers.0.attention.wq.weight"].shape, 249 | model["layers.0.attention.wk.weight"].shape, 250 | model["layers.0.attention.wv.weight"].shape, 251 | model["layers.0.attention.wo.weight"].shape 252 | ) 253 | ``` 254 | 255 | torch.Size([4096, 4096]) torch.Size([1024, 4096]) torch.Size([1024, 4096]) torch.Size([4096, 4096]) 256 | 257 | 258 | ### unwrapping query 259 | in the next section we will unwrap the queries from multiple attention heads, the resulting shape is [32x128x4096] 260 |

261 | here, 32 is the number of attention heads in llama3, 128 is the size of the query vector and 4096 is the size of the token embedding 262 | 263 | 264 | ```python 265 | q_layer0 = model["layers.0.attention.wq.weight"] 266 | head_dim = q_layer0.shape[0] // n_heads 267 | q_layer0 = q_layer0.view(n_heads, head_dim, dim) 268 | q_layer0.shape 269 | ``` 270 | 271 | 272 | 273 | 274 | torch.Size([32, 128, 4096]) 275 | 276 | 277 | 278 | ### im going to implement the first head of the first layer 279 | here i access the query weight matrix of the first head of the first layer, the size of this query weight matrix is [128x4096] 280 | 281 | 282 | ```python 283 | q_layer0_head0 = q_layer0[0] 284 | q_layer0_head0.shape 285 | ``` 286 | 287 | 288 | 289 | 290 | torch.Size([128, 4096]) 291 | 292 | 293 | 294 | ### we now multiply the query weights with the token embeddings, to receive a query for each token 295 | here you can see the resulting shape is [17x128], this is because we have 17 tokens and for each token there is a 128 length query. 296 |
297 | 298 |
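a quick way to convince yourself that the bundled [4096x4096] `wq` really is 32 separate [128x4096] head matrices stacked on top of each other: do one big matmul with the whole matrix, then split the output columns into 128-wide chunks. you get the same per-head queries as the head-by-head matmul. this is a minimal sketch with small random tensors, and the toy sizes are assumptions, not llama3's:

```python
import torch

torch.manual_seed(0)
# toy sizes standing in for n_heads=32, head_dim=128, dim=4096 and 17 tokens
toy_n_heads, toy_head_dim, toy_dim, toy_seq = 4, 8, 32, 5
toy_wq = torch.randn(toy_n_heads * toy_head_dim, toy_dim)  # bundled query weights, like wq
toy_x = torch.randn(toy_seq, toy_dim)                        # stands in for token_embeddings

# head by head, as in the notebook: view as [n_heads, head_dim, dim] and matmul one head at a time
toy_wq_per_head = toy_wq.view(toy_n_heads, toy_head_dim, toy_dim)
q_head_by_head = torch.stack([toy_x @ toy_wq_per_head[h].T for h in range(toy_n_heads)])

# bundled: one big matmul, then split the 32 output columns into 4 heads of 8
q_bundled = (toy_x @ toy_wq.T).view(toy_seq, toy_n_heads, toy_head_dim).transpose(0, 1)

print(q_head_by_head.shape)                                  # torch.Size([4, 5, 8])
print(torch.allclose(q_head_by_head, q_bundled, atol=1e-4))  # True
```

this is the parallelism the authors get from bundling: one large matmul instead of 32 small ones.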
299 | 300 | 301 | ```python 302 | q_per_token = torch.matmul(token_embeddings, q_layer0_head0.T) 303 | q_per_token.shape 304 | ``` 305 | 306 | 307 | 308 | 309 | torch.Size([17, 128]) 310 | 311 | 312 | 313 | ## positional encoding 314 | we are now at a stage where we have a query vector for each token in our prompt, but if you think about it, each individual query vector has no idea about its position in the prompt. 315 |

316 | query: "the answer to the ultimate question of life, the universe, and everything is " 317 |

318 | in our prompt we have used "the" three times, so we need the query vectors of all 3 "the" tokens (each of size [1x128]) to differ based on their positions in the prompt. we do this by rotating them with RoPE (rotary positional embedding). 319 |

320 | ### RoPE 321 | watch this video (this is what i watched) to understand the math. 322 | https://www.youtube.com/watch?v=o29P0Kpobz0&t=530s 323 | 324 | 325 |
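below is the rotation from the video, written in this notebook's notation. it is a sketch of the standard rope formulation, matching how the cells below compute it. $d = 128$ is the head dim, $k = 0, \dots, 63$ indexes the pairs, $m$ is the token position, and $\theta_{\text{base}}$ = `rope_theta` = 500000:

$$
\theta_k = \theta_{\text{base}}^{-2k/d} = \frac{1}{500000^{\,k/64}}, \qquad
\begin{pmatrix} q'_{2k} \\ q'_{2k+1} \end{pmatrix}
=
\begin{pmatrix} \cos m\theta_k & -\sin m\theta_k \\ \sin m\theta_k & \cos m\theta_k \end{pmatrix}
\begin{pmatrix} q_{2k} \\ q_{2k+1} \end{pmatrix}
$$

this is the same as multiplying the complex number $q_{2k} + \mathrm{i}\,q_{2k+1}$ by $e^{\mathrm{i} m \theta_k}$. the keys get exactly the same rotation.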
326 | 327 |
328 | 329 | 330 | ```python 331 | q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2) 332 | q_per_token_split_into_pairs.shape 333 | ``` 334 | 335 | 336 | 337 | 338 | torch.Size([17, 64, 2]) 339 | 340 | 341 | 342 | in the above step, we split the query vectors into pairs; we will apply a rotational angle shift to each pair! 343 |

344 | we now have a tensor of size [17x64x2]: the 128-length query of each token in the prompt, split into 64 pairs! each of those 64 pairs will be rotated by m*theta, where m is the position of the token whose query we are rotating and theta depends on which of the 64 pairs it is! 345 | 346 | 347 |
348 | 349 |
350 | 351 | ## using multiplication of complex numbers to rotate a vector 352 |
353 | 354 |
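here is why multiplying by a unit complex number rotates a vector. view a pair $(a, b)$ as $a + bi$. multiplying it by $e^{i\phi}$, built with `torch.polar(1, phi)`, gives exactly the 2x2 rotation by $\phi$, and the length of the pair does not change. this is a minimal sketch with made-up numbers, using the same `view_as_complex` / `view_as_real` calls as the notebook:

```python
import math
import torch

pairs = torch.tensor([[3.0, 4.0], [1.0, 0.0]])  # two (real, imag) pairs, like rows of q_per_token_split_into_pairs
phi = 0.7                                        # an arbitrary angle, stands in for m * theta_k

# complex route: view_as_complex -> multiply by polar(1, phi) -> view_as_real
z = torch.view_as_complex(pairs)                                  # shape [2], complex
unit_rotation = torch.polar(torch.tensor(1.0), torch.tensor(phi))
rotated_complex = torch.view_as_real(z * unit_rotation)           # back to shape [2, 2]

# explicit 2x2 rotation matrix route, applied to each pair (row)
rotation_matrix = torch.tensor([[math.cos(phi), -math.sin(phi)],
                                [math.sin(phi),  math.cos(phi)]])
rotated_matrix = pairs @ rotation_matrix.T

print(torch.allclose(rotated_complex, rotated_matrix, atol=1e-5))  # True
print(rotated_complex.norm(dim=-1))  # ~[5.0, 1.0]: the lengths are unchanged
```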
355 | 356 | 357 | ```python 358 | zero_to_one_split_into_64_parts = torch.tensor(range(64))/64 359 | zero_to_one_split_into_64_parts 360 | ``` 361 | 362 | 363 | 364 | 365 | tensor([0.0000, 0.0156, 0.0312, 0.0469, 0.0625, 0.0781, 0.0938, 0.1094, 0.1250, 366 | 0.1406, 0.1562, 0.1719, 0.1875, 0.2031, 0.2188, 0.2344, 0.2500, 0.2656, 367 | 0.2812, 0.2969, 0.3125, 0.3281, 0.3438, 0.3594, 0.3750, 0.3906, 0.4062, 368 | 0.4219, 0.4375, 0.4531, 0.4688, 0.4844, 0.5000, 0.5156, 0.5312, 0.5469, 369 | 0.5625, 0.5781, 0.5938, 0.6094, 0.6250, 0.6406, 0.6562, 0.6719, 0.6875, 370 | 0.7031, 0.7188, 0.7344, 0.7500, 0.7656, 0.7812, 0.7969, 0.8125, 0.8281, 371 | 0.8438, 0.8594, 0.8750, 0.8906, 0.9062, 0.9219, 0.9375, 0.9531, 0.9688, 372 | 0.9844]) 373 | 374 | 375 | 376 | 377 | ```python 378 | freqs = 1.0 / (rope_theta ** zero_to_one_split_into_64_parts) 379 | freqs 380 | ``` 381 | 382 | 383 | 384 | 385 | tensor([1.0000e+00, 8.1462e-01, 6.6360e-01, 5.4058e-01, 4.4037e-01, 3.5873e-01, 386 | 2.9223e-01, 2.3805e-01, 1.9392e-01, 1.5797e-01, 1.2869e-01, 1.0483e-01, 387 | 8.5397e-02, 6.9566e-02, 5.6670e-02, 4.6164e-02, 3.7606e-02, 3.0635e-02, 388 | 2.4955e-02, 2.0329e-02, 1.6560e-02, 1.3490e-02, 1.0990e-02, 8.9523e-03, 389 | 7.2927e-03, 5.9407e-03, 4.8394e-03, 3.9423e-03, 3.2114e-03, 2.6161e-03, 390 | 2.1311e-03, 1.7360e-03, 1.4142e-03, 1.1520e-03, 9.3847e-04, 7.6450e-04, 391 | 6.2277e-04, 5.0732e-04, 4.1327e-04, 3.3666e-04, 2.7425e-04, 2.2341e-04, 392 | 1.8199e-04, 1.4825e-04, 1.2077e-04, 9.8381e-05, 8.0143e-05, 6.5286e-05, 393 | 5.3183e-05, 4.3324e-05, 3.5292e-05, 2.8750e-05, 2.3420e-05, 1.9078e-05, 394 | 1.5542e-05, 1.2660e-05, 1.0313e-05, 8.4015e-06, 6.8440e-06, 5.5752e-06, 395 | 4.5417e-06, 3.6997e-06, 3.0139e-06, 2.4551e-06]) 396 | 397 | 398 | 399 | 400 | ```python 401 | freqs_for_each_token = torch.outer(torch.arange(17), freqs) 402 | freqs_cis = torch.polar(torch.ones_like(freqs_for_each_token), freqs_for_each_token) 403 | freqs_cis.shape 404 |
405 | # viewing row 3 of freqs_cis (the rotations applied at token position 3) 406 | value = freqs_cis[3] 407 | plt.figure() 408 | for i, element in enumerate(value[:17]): 409 | plt.plot([0, element.real], [0, element.imag], color='blue', linewidth=1, label=f"Index: {i}") 410 | plt.annotate(f"{i}", xy=(element.real, element.imag), color='red') 411 | plt.xlabel('Real') 412 | plt.ylabel('Imaginary') 413 | plt.title('Plot of one row of freqs_cis') 414 | plt.show() 415 | ``` 416 | 417 | 418 | 419 | ![png](images/implllama3_30_0.png) 420 | 421 | 422 | 423 | ### now that we have a complex number (the angle change vector) for every token's query element 424 | we can convert our queries (the ones we split into pairs) into complex numbers and then multiply them by freqs_cis to rotate each query based on its position 425 |
426 | honestly this is beautiful to think about :) 427 | 428 | 429 | ```python 430 | q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs) 431 | q_per_token_as_complex_numbers.shape 432 | ``` 433 | 434 | 435 | 436 | 437 | torch.Size([17, 64]) 438 | 439 | 440 | 441 | 442 | ```python 443 | q_per_token_as_complex_numbers_rotated = q_per_token_as_complex_numbers * freqs_cis 444 | q_per_token_as_complex_numbers_rotated.shape 445 | ``` 446 | 447 | 448 | 449 | 450 | torch.Size([17, 64]) 451 | 452 | 453 | 454 | ### after the rotated vector is obtained 455 | we can get the queries back as pairs by viewing the complex numbers as real numbers again 456 | 457 | 458 | ```python 459 | q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers_rotated) 460 | q_per_token_split_into_pairs_rotated.shape 461 | ``` 462 | 463 | 464 | 465 | 466 | torch.Size([17, 64, 2]) 467 | 468 | 469 | 470 | the rotated pairs can now be merged back, giving a new query vector (rotated query vector) for each token, of shape [17x128] where 17 is the number of tokens and 128 is the dim of the query vector 471 | 472 | 473 | ```python 474 | q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape) 475 | q_per_token_rotated.shape 476 | ``` 477 | 478 | 479 | 480 | 481 | torch.Size([17, 128]) 482 | 483 | 484 | 485 | # keys (almost the same as queries) 486 |
487 | 488 |
489 | im lazy as fuck, so im not going to go through the math for keys, the only things you need to keep in mind are: 490 |
491 | > keys generate key vectors, also of dimension 128 492 |
493 | > keys have only 1/4th the number of weights as queries, because the weights for keys are shared across 4 heads at a time, to reduce the number of computations needed 494 |
495 | > keys are also rotated to add positional info, just like queries, for the same reasons 496 | 497 | 498 | ```python 499 | k_layer0 = model["layers.0.attention.wk.weight"] 500 | k_layer0 = k_layer0.view(n_kv_heads, k_layer0.shape[0] // n_kv_heads, dim) 501 | k_layer0.shape 502 | ``` 503 | 504 | 505 | 506 | 507 | torch.Size([8, 128, 4096]) 508 | 509 | 510 | 511 | 512 | ```python 513 | k_layer0_head0 = k_layer0[0] 514 | k_layer0_head0.shape 515 | ``` 516 | 517 | 518 | 519 | 520 | torch.Size([128, 4096]) 521 | 522 | 523 | 524 | 525 | ```python 526 | k_per_token = torch.matmul(token_embeddings, k_layer0_head0.T) 527 | k_per_token.shape 528 | ``` 529 | 530 | 531 | 532 | 533 | torch.Size([17, 128]) 534 | 535 | 536 | 537 | 538 | ```python 539 | k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2) 540 | k_per_token_split_into_pairs.shape 541 | ``` 542 | 543 | 544 | 545 | 546 | torch.Size([17, 64, 2]) 547 | 548 | 549 | 550 | 551 | ```python 552 | k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs) 553 | k_per_token_as_complex_numbers.shape 554 | ``` 555 | 556 | 557 | 558 | 559 | torch.Size([17, 64]) 560 | 561 | 562 | 563 | 564 | ```python 565 | k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis) 566 | k_per_token_split_into_pairs_rotated.shape 567 | ``` 568 | 569 | 570 | 571 | 572 | torch.Size([17, 64, 2]) 573 | 574 | 575 | 576 | 577 | ```python 578 | k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape) 579 | k_per_token_rotated.shape 580 | ``` 581 | 582 | 583 | 584 | 585 | torch.Size([17, 128]) 586 | 587 | 588 | 589 | ## at this stage we now have both the rotated queries and the rotated keys for each token. 590 |
591 | 592 |
593 | each of the queries and keys are now of shape [17x128]. 594 | 595 | ## in the next step we will multiply the queries and key matrices 596 | doing this will give us a score mapping each token with one another 597 |
598 | this score describes how well each token's query relates to each token's key. 599 | THIS IS SELF ATTENTION :) 600 |
601 | the shape of the attention score matrix (qk_per_token) is [17x17] where 17 is the number of tokens in the prompt 602 | 603 |
604 | 605 |
606 | 607 | 608 | ```python 609 | qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T)/(head_dim)**0.5 610 | qk_per_token.shape 611 | ``` 612 | 613 | 614 | 615 | 616 | torch.Size([17, 17]) 617 | 618 | 619 | 620 | # we now have to mask query key scores 621 | during the training process of llama3, the future token qk scores are masked. 622 |
623 | why? because during training we only learn to predict tokens using past tokens. 624 |
625 | as a result, during inference we mask the future tokens too: their scores are set to -inf, which the softmax turns into zero attention weight. 626 |
627 | 628 |
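here is a tiny sketch of what the mask does, using random toy scores and 4 tokens instead of 17. adding `-inf` above the diagonal and then applying softmax row by row puts exactly zero weight on future tokens, and each row still sums to 1. the names are prefixed with `toy_` so they don't clobber the notebook's `mask`:

```python
import torch

torch.manual_seed(0)
n_toy_tokens = 4
toy_scores = torch.randn(n_toy_tokens, n_toy_tokens)  # stand-in for qk_per_token

# same construction as the notebook: -inf strictly above the diagonal, 0 elsewhere
toy_mask = torch.triu(torch.full((n_toy_tokens, n_toy_tokens), float("-inf")), diagonal=1)
toy_weights = torch.softmax(toy_scores + toy_mask, dim=1)

print(toy_weights)
print(torch.all(toy_weights.triu(diagonal=1) == 0))                        # tensor(True): no weight on future tokens
print(torch.allclose(toy_weights.sum(dim=1), torch.ones(n_toy_tokens)))   # True: every row sums to 1
```

note that the first row can only attend to itself, so its weight is exactly 1 on the diagonal.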
629 | 630 | 631 | ```python 632 | def display_qk_heatmap(qk_per_token): 633 | _, ax = plt.subplots() 634 | im = ax.imshow(qk_per_token.to(float).detach(), cmap='viridis') 635 | ax.set_xticks(range(len(prompt_split_as_tokens))) 636 | ax.set_yticks(range(len(prompt_split_as_tokens))) 637 | ax.set_xticklabels(prompt_split_as_tokens) 638 | ax.set_yticklabels(prompt_split_as_tokens) 639 | ax.figure.colorbar(im, ax=ax) 640 | 641 | display_qk_heatmap(qk_per_token) 642 | ``` 643 | 644 | 645 | 646 | ![png](images/implllama3_50_0.png) 647 | 648 | 649 | 650 | 651 | ```python 652 | mask = torch.full((len(tokens), len(tokens)), float("-inf"), device=tokens.device) 653 | mask = torch.triu(mask, diagonal=1) 654 | mask 655 | ``` 656 | 657 | 658 | 659 | 660 | tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf], 661 | [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf], 662 | [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf], 663 | [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf], 664 | [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf], 665 | [0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf], 666 | [0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf], 667 | [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf], 668 | [0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf], 669 | [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf], 670 | [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf], 671 | [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf], 672 | [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, 
-inf], 673 | [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf], 674 | [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf], 675 | [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf], 676 | [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]) 677 | 678 | 679 | 680 | 681 | ```python 682 | qk_per_token_after_masking = qk_per_token + mask 683 | display_qk_heatmap(qk_per_token_after_masking) 684 | ``` 685 | 686 | 687 | 688 | ![png](images/implllama3_52_0.png) 689 | 690 | 691 | 692 |
693 | 694 |
695 | 696 | 697 | ```python 698 | qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16) 699 | display_qk_heatmap(qk_per_token_after_masking_after_softmax) 700 | ``` 701 | 702 | 703 | 704 | ![png](images/implllama3_54_0.png) 705 | 706 | 707 | 708 | ## values (almost the end of attention) 709 | 710 |
711 | 712 |
713 | these scores (0-1) are used to determine how much of each token's value vector is used 714 |
715 | > just like keys, value weights are also shared across every 4 attention heads (to save computation) 716 |
717 | > as a result, the shape of the value weight matrix below is [8x128x4096] 718 | 719 | 720 | 721 | ```python 722 | v_layer0 = model["layers.0.attention.wv.weight"] 723 | v_layer0 = v_layer0.view(n_kv_heads, v_layer0.shape[0] // n_kv_heads, dim) 724 | v_layer0.shape 725 | ``` 726 | 727 | 728 | 729 | 730 | torch.Size([8, 128, 4096]) 731 | 732 | 733 | 734 | the first layer, first head value weight matrix is given below 735 | 736 | 737 | ```python 738 | v_layer0_head0 = v_layer0[0] 739 | v_layer0_head0.shape 740 | ``` 741 | 742 | 743 | 744 | 745 | torch.Size([128, 4096]) 746 | 747 | 748 | 749 | ## value vectors 750 |
751 | 752 |
753 | we now use the value weights to get the value vector for each token, this is of size [17x128] where 17 is the number of tokens in the prompt and 128 is the dim of the value vector per token 754 | 755 | 756 | ```python 757 | v_per_token = torch.matmul(token_embeddings, v_layer0_head0.T) 758 | v_per_token.shape 759 | ``` 760 | 761 | 762 | 763 | 764 | torch.Size([17, 128]) 765 | 766 | 767 | 768 | ## attention 769 |
770 | 771 |
772 | the resulting attention vector, after multiplying the scores with the values per token, is of shape [17x128] 773 | 774 | 775 | ```python 776 | qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token) 777 | qkv_attention.shape 778 | ``` 779 | 780 | 781 | 782 | 783 | torch.Size([17, 128]) 784 | 785 | 786 | 787 | # multi head attention 788 |
789 | 790 |
791 | WE NOW HAVE THE ATTENTION VALUE OF THE FIRST LAYER AND FIRST HEAD 792 |
793 | now im going to run a loop and perform the exact same math as the cells above but for every head in the first layer 794 | 795 | 796 | ```python 797 | qkv_attention_store = [] 798 | 799 | for head in range(n_heads): 800 | q_layer0_head = q_layer0[head] 801 | k_layer0_head = k_layer0[head//4] # key weights are shared across 4 heads 802 | v_layer0_head = v_layer0[head//4] # value weights are shared across 4 heads 803 | q_per_token = torch.matmul(token_embeddings, q_layer0_head.T) 804 | k_per_token = torch.matmul(token_embeddings, k_layer0_head.T) 805 | v_per_token = torch.matmul(token_embeddings, v_layer0_head.T) 806 | 807 | q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2) 808 | q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs) 809 | q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers * freqs_cis[:len(tokens)]) 810 | q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape) 811 | 812 | k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2) 813 | k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs) 814 | k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis[:len(tokens)]) 815 | k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape) 816 | 817 | qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T)/(head_dim)**0.5 818 | mask = torch.full((len(tokens), len(tokens)), float("-inf"), device=tokens.device) 819 | mask = torch.triu(mask, diagonal=1) 820 | qk_per_token_after_masking = qk_per_token + mask 821 | qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16) 822 | qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token) 823 | 824 | qkv_attention_store.append(qkv_attention) 825 | 826 | len(qkv_attention_store) 827 | ``` 828 | 829 | 830 | 831 | 832 | 32 833 | 834 | 835 | 836 |
837 | 838 |
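the per-head loop above can also be written without a python loop. this is a minimal sketch under two assumptions: it uses random toy weights instead of the llama3 checkpoint, and it leaves out rope to keep things short (you would rotate the queries and keys before the score matmul, exactly as above). `repeat_interleave` copies each kv head `n_heads // n_kv_heads` times in a row, which reproduces the `head//4` indexing from the loop. `causal_gqa_attention` is a hypothetical helper name, not something from meta's code:

```python
import torch

def causal_gqa_attention(x, wq, wk, wv, n_heads, n_kv_heads):
    """Hypothetical helper: causal multi-head attention where groups of query heads share one kv head.

    x: [seq, dim]; wq: [n_heads * head_dim, dim]; wk, wv: [n_kv_heads * head_dim, dim].
    Returns all heads concatenated, [seq, n_heads * head_dim], i.e. before the wo projection.
    RoPE is left out on purpose to keep the sketch short.
    """
    seq, _ = x.shape
    head_dim = wq.shape[0] // n_heads
    group = n_heads // n_kv_heads  # 32 // 8 = 4 query heads per kv head for llama3-8b
    q = (x @ wq.T).view(seq, n_heads, head_dim).transpose(0, 1)     # [n_heads, seq, head_dim]
    k = (x @ wk.T).view(seq, n_kv_heads, head_dim).transpose(0, 1)  # [n_kv_heads, seq, head_dim]
    v = (x @ wv.T).view(seq, n_kv_heads, head_dim).transpose(0, 1)
    # repeat each kv head `group` times in a row, so query head h lines up with kv head h // group
    k = k.repeat_interleave(group, dim=0)                           # [n_heads, seq, head_dim]
    v = v.repeat_interleave(group, dim=0)
    scores = q @ k.transpose(1, 2) / head_dim ** 0.5                # [n_heads, seq, seq]
    mask = torch.triu(torch.full((seq, seq), float("-inf")), diagonal=1)
    weights = torch.softmax(scores + mask, dim=-1)
    out = weights @ v                                               # [n_heads, seq, head_dim]
    return out.transpose(0, 1).reshape(seq, n_heads * head_dim)

# check against a per-head loop in the style of the notebook cell (toy sizes, random weights, no rope)
torch.manual_seed(0)
toy_seq, toy_dim, toy_n_heads, toy_n_kv_heads, toy_head_dim = 5, 16, 8, 2, 4
toy_x = torch.randn(toy_seq, toy_dim)
toy_wq = torch.randn(toy_n_heads * toy_head_dim, toy_dim) / toy_dim ** 0.5
toy_wk = torch.randn(toy_n_kv_heads * toy_head_dim, toy_dim) / toy_dim ** 0.5
toy_wv = torch.randn(toy_n_kv_heads * toy_head_dim, toy_dim) / toy_dim ** 0.5

toy_group = toy_n_heads // toy_n_kv_heads
toy_mask = torch.triu(torch.full((toy_seq, toy_seq), float("-inf")), diagonal=1)
loop_outputs = []
for h in range(toy_n_heads):
    q_h = toy_x @ toy_wq.view(-1, toy_head_dim, toy_dim)[h].T
    k_h = toy_x @ toy_wk.view(-1, toy_head_dim, toy_dim)[h // toy_group].T
    v_h = toy_x @ toy_wv.view(-1, toy_head_dim, toy_dim)[h // toy_group].T
    w_h = torch.softmax(q_h @ k_h.T / toy_head_dim ** 0.5 + toy_mask, dim=1)
    loop_outputs.append(w_h @ v_h)
loop_result = torch.cat(loop_outputs, dim=-1)

vectorized = causal_gqa_attention(toy_x, toy_wq, toy_wk, toy_wv, toy_n_heads, toy_n_kv_heads)
print(vectorized.shape)                                    # torch.Size([5, 32])
print(torch.allclose(vectorized, loop_result, atol=1e-5))  # True
```

the loop is easier to read, while the batched version does the same math in a handful of large matmuls.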
839 | we now have the qkv_attention matrix for all 32 heads of the first layer, next im going to concatenate the outputs of all heads into one large matrix of size [17x4096] 840 |
841 | we are almost at the end :) 842 | 843 | 844 | ```python 845 | stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1) 846 | stacked_qkv_attention.shape 847 | ``` 848 | 849 | 850 | 851 | 852 | torch.Size([17, 4096]) 853 | 854 | 855 | 856 | # weight matrix, one of the final steps 857 |
858 | 859 |
860 | one of the last things to do for the layer 0 attention is to multiply the stacked attention output by the output weight matrix (wo) 861 | 862 | 863 | ```python 864 | w_layer0 = model["layers.0.attention.wo.weight"] 865 | w_layer0.shape 866 | ``` 867 | 868 | 869 | 870 | 871 | torch.Size([4096, 4096]) 872 | 873 | 874 | 875 | ### this is a simple linear layer, so we just matmul 876 | 877 | 878 | ```python 879 | embedding_delta = torch.matmul(stacked_qkv_attention, w_layer0.T) 880 | embedding_delta.shape 881 | ``` 882 | 883 | 884 | 885 | 886 | torch.Size([17, 4096]) 887 | 888 | 889 | 890 |
891 | 892 |
893 | we now have the change in the embedding value after attention, which gets added to the original (unnormalized) token embeddings 894 | 895 | 896 | ```python 897 | embedding_after_edit = token_embeddings_unnormalized + embedding_delta 898 | embedding_after_edit.shape 899 | ``` 900 | 901 | 902 | 903 | 904 | torch.Size([17, 4096]) 905 | 906 | 907 | 908 | ## we normalize and then run a feed forward neural network through the edited embedding 909 |
910 | 911 |
912 | 913 | 914 | ```python 915 | embedding_after_edit_normalized = rms_norm(embedding_after_edit, model["layers.0.ffn_norm.weight"]) 916 | embedding_after_edit_normalized.shape 917 | ``` 918 | 919 | 920 | 921 | 922 | torch.Size([17, 4096]) 923 | 924 | 925 | 926 | ## loading the ff weights and implementing the feed forward network 927 |
928 | 929 |
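where does the feed forward hidden size come from? in llama3-8b, `w1` and `w3` have 14336 rows. that number falls out of the `multiple_of` and `ffn_dim_multiplier` values we read from params.json but never used. the sketch below follows the hidden-size rule in meta's reference llama `model.py` as far as i know it, so treat it as an assumption, not a spec. it then applies a swiglu block to toy random weights, laid out the same way as in the next cell:

```python
import torch

def ffn_hidden_dim(dim, multiple_of, ffn_dim_multiplier=None):
    """Hidden size of the swiglu ffn, following the rule in meta's reference llama model.py (assumed)."""
    hidden_dim = int(2 * (4 * dim) / 3)  # start from 4*dim, scaled by 2/3 for the gated variant
    if ffn_dim_multiplier is not None:
        hidden_dim = int(ffn_dim_multiplier * hidden_dim)
    # round up to the next multiple of `multiple_of`
    return multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)

print(ffn_hidden_dim(4096, 1024, 1.3))  # 14336, the number of rows in w1 and w3

def swiglu_ffn(x, w1, w2, w3):
    """silu(x @ w1.T) gates (x @ w3.T) elementwise, then w2 projects back down to dim."""
    return (torch.nn.functional.silu(x @ w1.T) * (x @ w3.T)) @ w2.T

# toy weights with the same layout as the checkpoint: w1, w3 are [hidden, dim], w2 is [dim, hidden]
torch.manual_seed(0)
toy_dim = 8
toy_hidden = ffn_hidden_dim(toy_dim, 4, 1.3)
toy_x = torch.randn(3, toy_dim)
toy_w1 = torch.randn(toy_hidden, toy_dim)
toy_w3 = torch.randn(toy_hidden, toy_dim)
toy_w2 = torch.randn(toy_dim, toy_hidden)
print(toy_hidden, swiglu_ffn(toy_x, toy_w1, toy_w2, toy_w3).shape)  # 28 torch.Size([3, 8])
```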
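930 | in llama3, they used a SwiGLU feedforward network. if x is the normalized embedding, it is projected up by both w1 and w3; silu(x @ w1.T) then acts as a gate that is multiplied elementwise with x @ w3.T, and w2 projects the result back down to 4096. this gating is where the non-linearity comes from. 931 |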
932 | it's pretty standard to use this feed forward network architecture in llms these days 933 | 934 | 935 | ```python 936 | w1 = model["layers.0.feed_forward.w1.weight"] 937 | w2 = model["layers.0.feed_forward.w2.weight"] 938 | w3 = model["layers.0.feed_forward.w3.weight"] 939 | output_after_feedforward = torch.matmul(torch.nn.functional.silu(torch.matmul(embedding_after_edit_normalized, w1.T)) * torch.matmul(embedding_after_edit_normalized, w3.T), w2.T) 940 | output_after_feedforward.shape 941 | ``` 942 | 943 | 944 | 945 | 946 | torch.Size([17, 4096]) 947 | 948 | 949 | 950 | # WE FINALLY HAVE NEW EDITED EMBEDDINGS FOR EACH TOKEN AFTER THE FIRST LAYER 951 | just 31 more layers to go before we are done (one for loop away) 952 |
953 | you can imagine this edited embedding as having information about all queries asked on the first layer 954 |
955 | now each layer will encode more and more complex queries on the questions asked, until we have an embedding that knows everything we need about the next token. 956 | 957 | 958 | ```python 959 | layer_0_embedding = embedding_after_edit+output_after_feedforward 960 | layer_0_embedding.shape 961 | ``` 962 | 963 | 964 | 965 | 966 | torch.Size([17, 4096]) 967 | 968 | 969 | 970 | # god, everything all at once 971 |
972 | 973 |
974 | yep, this is it. everything we did before, all at once, for every single layer. 975 |
# have fun reading :)

```python
final_embedding = token_embeddings_unnormalized
seq_len = len(token_embeddings_unnormalized)
# causal mask: -inf above the diagonal so no token can attend to later tokens (same for every head and layer)
mask = torch.full((seq_len, seq_len), float("-inf"))
mask = torch.triu(mask, diagonal=1)
for layer in range(n_layers):
    qkv_attention_store = []
    # pre-attention rms norm
    layer_embedding_norm = rms_norm(final_embedding, model[f"layers.{layer}.attention_norm.weight"])
    q_layer = model[f"layers.{layer}.attention.wq.weight"]
    q_layer = q_layer.view(n_heads, q_layer.shape[0] // n_heads, dim)
    k_layer = model[f"layers.{layer}.attention.wk.weight"]
    k_layer = k_layer.view(n_kv_heads, k_layer.shape[0] // n_kv_heads, dim)
    v_layer = model[f"layers.{layer}.attention.wv.weight"]
    v_layer = v_layer.view(n_kv_heads, v_layer.shape[0] // n_kv_heads, dim)
    w_layer = model[f"layers.{layer}.attention.wo.weight"]
    head_dim = q_layer.shape[1]  # 128 for llama3-8b
    heads_per_kv = n_heads // n_kv_heads  # 4 query heads share each key/value head
    for head in range(n_heads):
        q_layer_head = q_layer[head]
        k_layer_head = k_layer[head // heads_per_kv]
        v_layer_head = v_layer[head // heads_per_kv]
        q_per_token = torch.matmul(layer_embedding_norm, q_layer_head.T)
        k_per_token = torch.matmul(layer_embedding_norm, k_layer_head.T)
        v_per_token = torch.matmul(layer_embedding_norm, v_layer_head.T)
        # rope: treat each pair of dims as a complex number and rotate it by its position's angle
        q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)
        q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)
        q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers * freqs_cis)
        q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)
        k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)
        k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)
        k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis)
        k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)
        # scaled dot-product scores, causally masked, softmaxed over the key dimension
        qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T) / head_dim ** 0.5
        qk_per_token_after_masking = qk_per_token + mask
        qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=-1).to(torch.bfloat16)
        qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)
        qkv_attention_store.append(qkv_attention)

    # concatenate all heads, project with wo and add back to the residual stream
    stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1)
    embedding_delta = torch.matmul(stacked_qkv_attention, w_layer.T)
    embedding_after_edit = final_embedding + embedding_delta
    # swiglu feed-forward block with its own pre-norm and residual connection
    embedding_after_edit_normalized = rms_norm(embedding_after_edit, model[f"layers.{layer}.ffn_norm.weight"])
    w1 = model[f"layers.{layer}.feed_forward.w1.weight"]
    w2 = model[f"layers.{layer}.feed_forward.w2.weight"]
    w3 = model[f"layers.{layer}.feed_forward.w3.weight"]
    output_after_feedforward = torch.matmul(torch.nn.functional.silu(torch.matmul(embedding_after_edit_normalized, w1.T)) * torch.matmul(embedding_after_edit_normalized, w3.T), w2.T)
    final_embedding = embedding_after_edit + output_after_feedforward
```

# we now have the final embedding, which holds the model's best guess about the next token
the shape of the embedding is the same as the regular token embeddings, [17x4096], where 17 is the number of tokens and 4096 is the embedding dim
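The key/value head indexing in the layer loop (`head // 4` for llama3-8b, i.e. `head // (n_heads // n_kv_heads)`) is grouped-query attention. There are 32 query heads but only 8 key/value heads, so each group of 4 consecutive query heads reuses one key/value head. The sketch below checks a per-head loop against a batched version that repeats each key/value head with `torch.repeat_interleave`. It assumes toy sizes and random tensors and leaves out RoPE and the `wo` projection. `gqa_loop`, `gqa_batched` and the uppercase constants are hypothetical names. The concatenated heads are `n_heads x head_dim = dim` wide (32 x 128 = 4096 for llama3-8b). So after `wo` and the residual add, each layer hands back an embedding of the same [17x4096] shape it received.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# hypothetical toy sizes; llama3-8b uses 32 query heads, 8 kv heads, head_dim 128, dim 4096
SEQ, N_HEADS, N_KV_HEADS, HEAD_DIM = 6, 8, 2, 4
DIM = N_HEADS * HEAD_DIM
GROUP = N_HEADS // N_KV_HEADS  # query heads per key/value head (4 in llama3-8b)

x = torch.randn(SEQ, DIM)                                  # stands in for the normalized embeddings
wq = torch.randn(N_HEADS, HEAD_DIM, DIM) / DIM ** 0.5      # per-head query weights
wk = torch.randn(N_KV_HEADS, HEAD_DIM, DIM) / DIM ** 0.5   # fewer key heads than query heads
wv = torch.randn(N_KV_HEADS, HEAD_DIM, DIM) / DIM ** 0.5
causal_mask = torch.triu(torch.full((SEQ, SEQ), float("-inf")), diagonal=1)

def gqa_loop():
    # one query head at a time, like the notebook loop (RoPE omitted)
    outs = []
    for head in range(N_HEADS):
        q = x @ wq[head].T
        k = x @ wk[head // GROUP].T
        v = x @ wv[head // GROUP].T
        scores = q @ k.T / HEAD_DIM ** 0.5 + causal_mask
        outs.append(F.softmax(scores, dim=-1) @ v)
    return torch.cat(outs, dim=-1)                         # [SEQ, N_HEADS * HEAD_DIM]

def gqa_batched():
    # all heads at once: repeat each kv head GROUP times so it lines up with its query heads
    q = torch.einsum("sd,hkd->hsk", x, wq)                 # [N_HEADS, SEQ, HEAD_DIM]
    k = torch.einsum("sd,hkd->hsk", x, wk).repeat_interleave(GROUP, dim=0)
    v = torch.einsum("sd,hkd->hsk", x, wv).repeat_interleave(GROUP, dim=0)
    scores = q @ k.transpose(-1, -2) / HEAD_DIM ** 0.5 + causal_mask  # [N_HEADS, SEQ, SEQ]
    out = F.softmax(scores, dim=-1) @ v                    # [N_HEADS, SEQ, HEAD_DIM]
    return out.transpose(0, 1).reshape(SEQ, N_HEADS * HEAD_DIM)

print(torch.allclose(gqa_loop(), gqa_batched(), atol=1e-5))  # True
print(gqa_batched().shape)  # torch.Size([6, 32])
```

The loop is easier to read and the batched form is much faster. Both compute the same thing because repeating a key/value head is equivalent to indexing it with `head // GROUP`.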
before decoding, the final embedding goes through one last rms norm, using the model's `norm.weight`. the shape stays the same


```python
final_embedding = rms_norm(final_embedding, model["norm.weight"])
final_embedding.shape
```




    torch.Size([17, 4096])



# finally, let's decode the embedding into the token value
we will use the output decoder (the `output.weight` matrix) to convert the final embedding into a token


```python
model["output.weight"].shape
```




    torch.Size([128256, 4096])



each of its 128256 rows corresponds to one token in the vocabulary, and each row has the same 4096 dims as our embedding

# we use the embedding of the last token to predict the next token
hopefully in our case, 42 :)
note: according to the book "the hitchhiker's guide to the galaxy", 42 is the answer to the ultimate question of life, the universe, and everything. most modern llms would complete our prompt "the answer to the ultimate question of life, the universe, and everything is " with 42, so getting 42 here is a good sanity check for our entire implementation! wish me luck :)


```python
logits = torch.matmul(final_embedding[-1], model["output.weight"].T)
logits.shape
```




    torch.Size([128256])



that is one score (logit) per token in the vocabulary

### the model predicted token number 2983 as the next token. is this the token number for 42?
I'M HYPING YOU UP, we're down to the last couple of cells of code, hopefully you had fun :)


```python
next_token = torch.argmax(logits, dim=-1)
next_token
```




    tensor(2983)



# let's fucking go
```python
tokenizer.decode([next_token.item()])
```




    '42'



# thank you, i love you :)

This is the end. Hopefully you enjoyed reading it!

If you want to support my work:

1. follow me on twitter https://twitter.com/naklecha
2. or, buy me a coffee [https://www.buymeacoffee.com/naklecha](https://www.buymeacoffee.com/naklecha)

Honestly, if you made it this far you already made my day :)

## what motivates me?

My friends and I are on a mission: to make research more accessible!
We created a research lab called A10: [AAAAAAAAAA.org](http://aaaaaaaaaa.org/)

A10 twitter: https://twitter.com/aaaaaaaaaaorg

our thesis:
1136 | -------------------------------------------------------------------------------- /images/42.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/42.png -------------------------------------------------------------------------------- /images/a10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/a10.png -------------------------------------------------------------------------------- /images/afterattention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/afterattention.png -------------------------------------------------------------------------------- /images/archi.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/archi.png -------------------------------------------------------------------------------- /images/attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/attention.png -------------------------------------------------------------------------------- /images/embeddings.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/embeddings.png -------------------------------------------------------------------------------- /images/finallayer.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/finallayer.png -------------------------------------------------------------------------------- /images/freq_cis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/freq_cis.png -------------------------------------------------------------------------------- /images/god.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/god.png -------------------------------------------------------------------------------- /images/heads.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/heads.png -------------------------------------------------------------------------------- /images/implllama3_30_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/implllama3_30_0.png -------------------------------------------------------------------------------- /images/implllama3_39_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/implllama3_39_0.png -------------------------------------------------------------------------------- /images/implllama3_41_0.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/implllama3_41_0.png -------------------------------------------------------------------------------- /images/implllama3_42_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/implllama3_42_0.png -------------------------------------------------------------------------------- /images/implllama3_50_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/implllama3_50_0.png -------------------------------------------------------------------------------- /images/implllama3_52_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/implllama3_52_0.png -------------------------------------------------------------------------------- /images/implllama3_54_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/implllama3_54_0.png -------------------------------------------------------------------------------- /images/karpathyminbpe.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/karpathyminbpe.png -------------------------------------------------------------------------------- /images/keys.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/keys.png -------------------------------------------------------------------------------- /images/keys0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/keys0.png -------------------------------------------------------------------------------- /images/last_norm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/last_norm.png -------------------------------------------------------------------------------- /images/mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/mask.png -------------------------------------------------------------------------------- /images/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/model.png -------------------------------------------------------------------------------- /images/norm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/norm.png -------------------------------------------------------------------------------- /images/norm_after.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/norm_after.png 
-------------------------------------------------------------------------------- /images/q_per_token.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/q_per_token.png -------------------------------------------------------------------------------- /images/qkmatmul.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/qkmatmul.png -------------------------------------------------------------------------------- /images/qkv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/qkv.png -------------------------------------------------------------------------------- /images/qsplit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/qsplit.png -------------------------------------------------------------------------------- /images/rms.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/rms.png -------------------------------------------------------------------------------- /images/rope.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/rope.png -------------------------------------------------------------------------------- /images/ropesplit.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/ropesplit.png -------------------------------------------------------------------------------- /images/softmax.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/softmax.png -------------------------------------------------------------------------------- /images/stacked.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/stacked.png -------------------------------------------------------------------------------- /images/swiglu.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/swiglu.png -------------------------------------------------------------------------------- /images/tokens.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/tokens.png -------------------------------------------------------------------------------- /images/v0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/v0.png -------------------------------------------------------------------------------- /images/value.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/value.png -------------------------------------------------------------------------------- /images/weightmatrix.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/naklecha/llama3-from-scratch/1b866ac638dceb667b2050692d5366844f81bc37/images/weightmatrix.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | sentencepiece 2 | tiktoken 3 | torch 4 | blobfile 5 | matplotlib --------------------------------------------------------------------------------