├── README.md
├── architectures_comparison.pdf
└── building-llama-3-from-scratch.ipynb

/README.md:
--------------------------------------------------------------------------------
## Building LLaMA 3 LLM From Scratch using Python

LLaMA 3 is one of the most promising open-source models after Mistral, and it can solve a wide range of tasks. I previously wrote a blog on Medium about creating an LLM with over 2.3 million parameters from scratch using the LLaMA architecture. Now that LLaMA 3 has been released, we will recreate it in a simpler manner.

We won't be using a GPU for this blog, but you'll need at least 17 GB of RAM because we are going to load some files that are more than 15 GB in size. If that is a problem, you can use Kaggle instead. Because we don't need a GPU, a CPU-only Kaggle session is enough, and it offers 30 GB of RAM.

Here is the blog link that guides you through creating a 2.3+ million parameter LLM from scratch:
[2.3+ Million Parameter LLM From Scratch](https://levelup.gitconnected.com/building-a-million-parameter-llm-from-scratch-using-python-f612398f06c2)

## Table of Contents

- [Prerequisites](#prerequisites)
- [Difference between LLaMA 2 and LLaMA 3](#difference-between-llama-2-and-llama-3)
- [Understanding the Transformer Architecture of LLaMA 3](#understanding-the-transformer-architecture-of-llama-3)
  - [Pre-normalization Using RMSNorm](#1-pre-normalization-using-rmsnorm)
  - [SwiGLU Activation Function](#2-swiglu-activation-function)
  - [Rotary Embeddings (RoPE)](#3-rotary-embeddings-rope)
  - [Byte Pair Encoding (BPE) Algorithm](#4-byte-pair-encoding-bpe-algorithm)
- [Setting the Stage](#setting-the-stage)
- [Understanding the File Structure](#understanding-the-file-structure)
- [Tokenizing our input data](#tokenizing-our-input-data)
- [Creating Embedding for each Token](#creating-embedding-for-each-token)
- [Normalization Using RMSNorm](#normalization-using-rmsnorm)
- [Attention Heads (Query, Key, Values)](#attention-heads-query-key-values)
- [Implementing RoPE](#implementing-rope)
- [Implementing Self Attention](#implementing-self-attention)
- [Implementing Multi-Head Attention](#implementing-multi-head-attention)
- [Implementing SwiGLU Activation Function](#implementing-swiglu-activation-function)
- [Merging everything](#merging-everything)
- [Generating the Output](#generating-the-output)


## Prerequisites
The good news is that we won't be using object-oriented programming (OOP), just plain Python. You should, however, have a basic understanding of neural networks and the Transformer architecture. Those are the only two prerequisites for following along with the blog.

| Topic | Link |
| ---- | ---- |
| Transformer Theory | [Video Link](https://www.youtube.com/watch?v=zxQyTK8quyY) |
| Neural Networks Theory | [Video Link](https://www.youtube.com/watch?v=Jy4wM2X21u0) |
| Python basics | [Video Link](https://www.youtube.com/watch?v=eWRfhZUzrAc) |

## Difference between LLaMA 2 and LLaMA 3
Before we get into the technical details, the first thing to know is that the overall architecture of LLaMA 3 is essentially the same as LLaMA 2. So if you haven't studied the technical details of LLaMA 3 yet, you can still follow this blog. If you don't know the LLaMA 2 architecture either, don't worry. We will also give a high-level overview of its technical details, so this blog works for you either way.

If you are already familiar with LLaMA 2, here are the key differences between the two:

| Feature | Llama 3 | Llama 2 |
|----------------------------------------|----------------------------|------------------------------------|
| Tokenizer | tiktoken (developed by OpenAI) | SentencePiece |
| Number of parameters | 8B, 70B | 7B, 13B, 70B |
| Training data | 15T+ tokens | 2T tokens |
| Context length | 8192 tokens | 4096 tokens |
| Attention mechanism | Grouped-query attention (all sizes) | Grouped-query attention (70B only; 7B and 13B use multi-head attention) |
| Fine-tuned models | Yes | Yes |
| Performance | Better than Llama 2 on Meta's reported benchmarks | Better than Llama 1 on most benchmarks |
| Computational requirements | Very high (70B model) | Very high (70B model) |
| Availability | Open weights (community license) | Open weights (community license) |
| Reinforcement learning from human feedback | Yes | Yes |
| Number of languages supported | 30 languages | 20 languages |
| Suitable for | Best for more demanding tasks, such as reasoning, coding, and proficiency tests | Good for more demanding tasks, such as reasoning, coding, and proficiency tests |

## Understanding the Transformer Architecture of LLaMA 3
You need to understand the architecture of LLaMA 3 before you start coding it. To make it easier to picture, here is a diagram comparing the vanilla Transformer, LLaMA 2/3, and Mistral.

![Comparison](https://i.ibb.co/kKN4Cks/Group-1-1.png)

Let's look at the most important components of LLaMA 3 in a bit more detail:
### 1. Pre-normalization Using RMSNorm
In LLaMA 3, as in LLaMA 2, a technique called RMSNorm is used to normalize the input of each transformer sub-layer.

Imagine you're studying for a big exam, and you have a massive textbook full of chapters. Each chapter represents a different topic, but some chapters are more crucial for understanding the subject than others.

Now, before diving into the entire textbook, you decide to evaluate the importance of each chapter. You don't want to spend the same amount of time on every chapter; you want to focus more on the critical ones.
This is where pre-normalization using RMSNorm comes into play for large language models (LLMs) like LLaMA. It's like assigning a weight to each chapter based on its significance. Chapters that are fundamental to the subject get higher weights, while less important ones get lower weights.

So, before going deeply into studying, you adjust your study plan based on the weighted importance of each chapter. You allocate more time and effort to the chapters with higher weights, ensuring you grasp the core concepts thoroughly.

![RMSNorm](https://cdn-images-1.medium.com/v2/resize:fit:1000/0*GIr8bvByN_iAGQBW.png)

In concrete terms, RMSNorm first divides each token's embedding vector by its root mean square (the square root of the mean of its squared values). This puts every vector on a consistent scale before it enters a sub-layer. RMSNorm then multiplies each dimension by a learned weight, and these weights play the role of the chapter weights in the analogy. LLaMA normalizes the input of every attention and feed-forward sub-layer rather than its output (pre-normalization) because this improves training stability. Interested readers can find the details in the original RMSNorm paper.

### 2. SwiGLU Activation Function
Following PaLM, LLaMA replaces the ReLU activation in its feed-forward layers with SwiGLU, an activation function introduced by Shazeer (2020).

Imagine you're a teacher trying to explain a complex topic to your students. You have a big whiteboard where you write down key points and draw diagrams to make things clearer. But sometimes your handwriting isn't very neat, or your diagrams aren't drawn perfectly, and that makes the material harder for your students to follow.

Now imagine you had a magic pen that automatically adjusted the size and style of your handwriting based on how important each point is. If something is really crucial, the pen writes it bigger and clearer so that it stands out. If it's less important, the pen writes it smaller, but still legibly.
SwiGLU is like that magic pen for LLMs like LLaMA. Inside each feed-forward block, it computes two linear projections of a token's hidden representation. It passes one of them through the SiLU (Swish) function and uses the result as a gate that scales each feature of the other projection up or down, much as the magic pen adjusts the size of each point. In Meta's reference implementation, the feed-forward block computes `w2(silu(w1(x)) * w3(x))`.

![SwiGLU](https://cdn-images-1.medium.com/max/1000/0*NtNn2CFuDNEH6jFC.png)

So, for each token, the features that matter in the current context pass through more strongly while others are dampened, just as the magic pen helps you create clearer explanations on the whiteboard. In practice, this gating improves model quality compared with a plain ReLU feed-forward layer. More details on SwiGLU are in the paper "GLU Variants Improve Transformer".

### 3. Rotary Embeddings (RoPE)
Rotary Embeddings, or RoPE, are the type of position embedding used in LLaMA 3.

Imagine you're in a classroom and you want to assign seats to students for group discussions. Usually you might arrange the seats in rows and columns, with each student in a fixed position. Sometimes, however, you want a more dynamic seating arrangement where students can move around and interact more freely.

RoPE is like a special seating arrangement that allows students to rotate and change positions while still maintaining their relative positions to each other. Instead of being fixed in one place, students can now move around in a circular motion, allowing for more fluid interactions.

In this scenario, each student represents a word or token in a text sequence, and their seat corresponds to their position in the sequence. Just as the students rotate around the room, RoPE rotates each token's query and key vectors by an angle proportional to the token's position. Because both vectors are rotated, the attention score between two tokens depends on how far apart they are, not on their absolute positions.
So instead of adding fixed, static position vectors to the embeddings, RoPE encodes position through these rotations, which captures the relative relationships between words in the sequence. This helps models like LLaMA understand and generate text that flows naturally and stays coherent, much as a dynamic seating arrangement encourages more interactive discussions in a classroom. Those interested in the mathematical details can refer to the RoPE paper (RoFormer).

### 4. Byte Pair Encoding (BPE) Algorithm

LLaMA 3 uses Byte Pair Encoding (BPE) through the tiktoken library introduced by OpenAI, while the LLaMA 2 tokenizer's BPE is based on the SentencePiece library. The two differ slightly, but first, let's learn what BPE actually is.

Let's start with a simple example. Suppose we have a text corpus with the words "ab", "bc", "bcd", and "cde". We begin by initializing our vocabulary with all the individual characters in the corpus, so our initial vocabulary is {"a", "b", "c", "d", "e"}.

Next, we calculate the frequency of each character in the corpus. For our example, the frequencies are: {"a": 1, "b": 3, "c": 3, "d": 2, "e": 1}.

Now we start the merging process. We repeat the following steps until the vocabulary reaches the desired size or there are no pairs left to merge:

1. First, we count every pair of adjacent symbols. "ab" appears once, "bc" twice (in "bc" and "bcd"), "cd" twice (in "bcd" and "cde"), and "de" once. "bc" and "cd" are tied, and we break the tie by picking the pair that appears first, so we merge "bc" into a new subword unit. The words become [a, b], [bc], [bc, d], [c, d, e], and the updated frequencies are {"a": 1, "b": 1, "c": 1, "d": 2, "e": 1, "bc": 2}. We add "bc" to the vocabulary, which becomes {"a", "b", "c", "d", "e", "bc"}.

2. We repeat the process. Every remaining pair ("ab", "bc" + "d", "cd", "de") now appears exactly once, so the tie-break picks the first one, "ab". The words become [ab], [bc], [bc, d], [c, d, e], and the vocabulary becomes {"a", "b", "c", "d", "e", "bc", "ab"}.

3. Next, we merge "bc" + "d" into "bcd". The words become [ab], [bc], [bcd], [c, d, e], and the vocabulary becomes {"a", "b", "c", "d", "e", "bc", "ab", "bcd"}.

4. Only two pairs remain, "cd" and "de", each appearing once. We merge "cd", which gives [ab], [bc], [bcd], [cd, e], and add "cd" to the vocabulary.

5. Finally, we merge "cd" + "e" into "cde". Every word is now a single token ([ab], [bc], [bcd], [cde]), so there are no pairs left, and the final vocabulary is {"a", "b", "c", "d", "e", "bc", "ab", "bcd", "cd", "cde"}.

This technique can improve the performance of LLMs and helps them handle rare and out-of-vocabulary words, because any unseen word can still be split into known subword units.
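
You can reproduce the merge loop above with a short script. The following is a minimal, character-level sketch for this toy corpus only. tiktoken's real BPE works on bytes and uses pre-trained merge ranks. `train_toy_bpe` is a hypothetical helper name, and ties go to the pair seen first, as in the walkthrough.

```python
from collections import Counter


def train_toy_bpe(words, num_merges):
    """Hypothetical helper (not part of tiktoken or the notebook): learn BPE merges."""
    # Start with every word split into single characters
    corpus = [list(word) for word in words]
    vocab = sorted({ch for symbols in corpus for ch in symbols})
    merges = []

    for _ in range(num_merges):
        # Count adjacent symbol pairs; Counter keeps first-seen insertion order
        pair_counts = Counter()
        for symbols in corpus:
            for pair in zip(symbols, symbols[1:]):
                pair_counts[pair] += 1
        if not pair_counts:
            break  # every word is already a single token

        # max() returns the first pair with the highest count, so ties go to the earliest pair
        best = max(pair_counts, key=pair_counts.get)
        new_symbol = best[0] + best[1]
        merges.append(new_symbol)
        vocab.append(new_symbol)

        # Replace every occurrence of the best pair with the merged symbol
        merged_corpus = []
        for symbols in corpus:
            out, i = [], 0
            while i < len(symbols):
                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                    out.append(new_symbol)
                    i += 2
                else:
                    out.append(symbols[i])
                    i += 1
            merged_corpus.append(out)
        corpus = merged_corpus

    return merges, vocab, corpus


merges, vocab, corpus = train_toy_bpe(["ab", "bc", "bcd", "cde"], num_merges=10)
print(merges)  # ['bc', 'ab', 'bcd', 'cd', 'cde']
print(corpus)  # [['ab'], ['bc'], ['bcd'], ['cde']]
```

Each learned merge becomes a new vocabulary entry. Applying the same merges, in the same order, to new text reproduces this segmentation.
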

A notable difference between tiktoken's BPE and SentencePiece's BPE is in how known words are handled. tiktoken first checks whether a whole pre-split word is already in the vocabulary, and if it is, tiktoken keeps it as a single token without applying any merges. For example, if "hugging" is in the vocabulary, it stays as one token instead of being split into ["hug", "ging"].

## Setting the Stage
We will be working with a small set of Python libraries, but it's better to install them up front to avoid "No module named ..." errors.


```python
!pip install sentencepiece tiktoken torch blobfile matplotlib huggingface_hub
```

After installing the required libraries, we need to download some files. Since we're going to replicate the architecture of Llama-3-8B, you need a Hugging Face account. Llama 3 is also a gated model, so you have to accept its terms and conditions before you can access the model files.

Here are the steps:
1. Create a Hugging Face account from this [link](https://huggingface.co/join?next=%2Fmeta-llama%2FMeta-Llama-3-8B)
2. Accept the terms and conditions of Llama-3-8B from this [link](https://huggingface.co/meta-llama/Meta-Llama-3-8B)

Once you've completed both steps, you can download the files. There are two ways to do that:

(Option 1: Manual) Go to the Llama-3-8B HF directory from this [link](https://huggingface.co/meta-llama/Meta-Llama-3-8B/tree/main/original) and manually download each of these three files.

![](https://cdn-images-1.medium.com/max/1000/1*QpaH8EzAEEZsLv_EJ1OsFg.png)

(Option 2: Coding) We can use the huggingface_hub library, which we installed earlier, to download all of these files. First, though, we need to log in to the Hugging Face Hub within our working notebook using our HF token. You can create a new token or copy an existing one from this [link](https://huggingface.co/settings/tokens).


```python
# Import the `notebook_login` function from the `huggingface_hub` module.
from huggingface_hub import notebook_login

# Execute the `notebook_login` function to log in to the Hugging Face Hub.
notebook_login()
```

```python
# Special tokens of the LLaMA 3 tokenizer
special_tokens = [
    "<|begin_of_text|>",  # Marks the beginning of a text sequence.
    "<|end_of_text|>",  # Marks the end of a text sequence.
    "<|reserved_special_token_0|>",  # Reserved for future use.
    "<|reserved_special_token_1|>",  # Reserved for future use.
    "<|reserved_special_token_2|>",  # Reserved for future use.
    "<|reserved_special_token_3|>",  # Reserved for future use.
    "<|start_header_id|>",  # Indicates the start of a header ID.
    "<|end_header_id|>",  # Indicates the end of a header ID.
    "<|reserved_special_token_4|>",  # Reserved for future use.
    "<|eot_id|>",  # Marks the end of a turn (in a conversational context).
] + [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)]  # A large set of tokens reserved for future use.
```

Next, we define the rules for splitting text into tokens by specifying patterns that match various types of substrings in the input text. Here's how we can do that.


```python
# Pattern used to pre-split text into pieces before BPE is applied
tokenize_breaker = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"
```

This pattern pulls contractions (such as 's and 'll), words, numbers (up to three digits at a time), runs of punctuation and other symbols, and whitespace out of the input text. You can customize it for your own tokenizer, but to reproduce LLaMA 3's tokenization exactly you must keep this pattern unchanged.
Next, we write a simple tokenizer using tiktoken's BPE. It takes three inputs: tokenizer_model, tokenize_breaker, and special_tokens, and it encodes and decodes our input text.


```python
# Initialize the tokenizer with the specified parameters
tokenizer = tiktoken.Encoding(

    # A name for this encoding (only a label; the BPE ranks come from tokenizer_model)
    name="/kaggle/working/llama-3-8B/original/tokenizer.model",

    # Regex pattern used to pre-split the text
    pat_str=tokenize_breaker,

    # BPE mergeable ranks (token bytes -> rank) from LLaMA 3's tokenizer_model
    mergeable_ranks=tokenizer_model,

    # Special tokens get the IDs right after the regular vocabulary
    special_tokens={token: len(tokenizer_model) + i for i, token in enumerate(special_tokens)},
)

# Encode "hello world!" and decode the tokens back to a string
tokenizer.decode(tokenizer.encode("hello world!"))
```




    'hello world!'



To check that the tokenizer works, we pass "hello world!" into it. It first encodes the text into token IDs and then decodes them back into text, and we get "hello world!" again. The round trip succeeds, so we can now tokenize our input.
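
Before running BPE, tiktoken splits the text into pieces with the `pat_str` pattern (our `tokenize_breaker`) and encodes each piece separately. The sketch below shows that step on its own using the third-party `regex` module, which tiktoken depends on. Python's built-in `re` does not support the `\p{L}` and `\p{N}` Unicode classes in the pattern. The example strings are arbitrary.

```python
import regex  # third-party module with support for Unicode classes such as \p{L}

# Same pre-tokenization pattern as tokenize_breaker above
tokenize_breaker = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"

# Each piece returned here is BPE-encoded independently
print(regex.findall(tokenize_breaker, "hello world!"))
# ['hello', ' world', '!']
print(regex.findall(tokenize_breaker, "It's 12345 o'clock"))
# ['It', "'s", ' ', '123', '45', ' o', "'clock"]
```

Numbers are chunked into groups of at most three digits, and a leading space stays attached to the word that follows it. That is why most decoded tokens of our prompt start with a space.
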


```python
# Input prompt
prompt = "the answer to the ultimate question of life, the universe, and everything is "

# Encode the prompt and prepend the <|begin_of_text|> special token (ID 128000)
tokens = [128000] + tokenizer.encode(prompt)

print(tokens)  # Print the encoded tokens

# Convert the list of tokens into a PyTorch tensor
tokens = torch.tensor(tokens)

# Decode each token back into its corresponding string
prompt_split_as_tokens = [tokenizer.decode([token.item()]) for token in tokens]

print(prompt_split_as_tokens)  # Print the decoded tokens
```

    [128000, 1820, 4320, 311, 279, 17139, 3488, 315, 2324, 11, 279, 15861, 11, 323, 4395, 374, 220]
    ['<|begin_of_text|>', 'the', ' answer', ' to', ' the', ' ultimate', ' question', ' of', ' life', ',', ' the', ' universe', ',', ' and', ' everything', ' is', ' ']


We encoded our input text "the answer to the ultimate question of life, the universe, and everything is " into 17 tokens, starting with the `<|begin_of_text|>` special token.

## Creating Embedding for each Token

Let's check the embedding dimension and the number of tokens in our input:


```python
# Check the embedding dimension from the LLaMA 3 config and the number of input tokens
print(dim, len(tokens))
```

    4096 17


Our input is currently a vector of 17 token IDs. Each token needs to be turned into an embedding, so the input becomes a [17x4096] matrix in which each token has a corresponding embedding of length 4096.
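
Under the hood, an embedding layer is a lookup table: each token ID selects one row of the weight matrix. Here is a tiny PyTorch sketch with made-up sizes, a 10-token vocabulary and 4-dimensional embeddings, which are not LLaMA 3's real dimensions:

```python
import torch

torch.manual_seed(0)

# Toy sizes for illustration only
toy_vocab_size, toy_dim = 10, 4
embedding = torch.nn.Embedding(toy_vocab_size, toy_dim)

token_ids = torch.tensor([3, 1, 3])  # a "prompt" of three token IDs
vectors = embedding(token_ids)

print(vectors.shape)  # torch.Size([3, 4])
# Each output row is simply the weight row for that token ID
print(torch.equal(vectors, embedding.weight[token_ids]))  # True
```

The next step does the same thing at full scale: each of the 17 token IDs selects its 4096-dimensional row from the pre-trained `tok_embeddings.weight` matrix.
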


```python
# Define an embedding layer with the vocabulary size and embedding dimension
embedding_layer = torch.nn.Embedding(vocab_size, dim)

# Copy the pre-trained token embeddings into the embedding layer
embedding_layer.weight.data.copy_(model["tok_embeddings.weight"])

# Look up the embeddings for our tokens and convert them to torch.bfloat16
token_embeddings_unnormalized = embedding_layer(tokens).to(torch.bfloat16)

# Print the shape of the resulting token embeddings
token_embeddings_unnormalized.shape
```




    torch.Size([17, 4096])



These embeddings are not normalized yet. LLaMA normalizes the input of every sub-layer, so in the next section we normalize our input vectors.

## Normalization Using RMSNorm
We normalize the input vectors using the RMSNorm formula we saw earlier.

![RMSNorm](https://cdn-images-1.medium.com/max/1000/0*GIr8bvByN_iAGQBW.png)


```python
# RMSNorm: divide each vector by its root mean square, then scale by learned weights
def rms_norm(tensor, norm_weights):

    # Mean of the squared values along the last (embedding) dimension
    squared_mean = tensor.pow(2).mean(-1, keepdim=True)

    # Reciprocal square root; norm_eps (from the model config) avoids division by zero
    normalized = torch.rsqrt(squared_mean + norm_eps)

    # Rescale the tensor, then multiply by the learned per-dimension weights
    return (tensor * normalized) * norm_weights
```

To normalize our embeddings, we use the RMSNorm weights of layer 0's attention block (`layers.0.attention_norm.weight`), since we are now building the first layer of our LLaMA 3 transformer.
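
Written out, the `rms_norm` function above computes the following. Here $a$ is a token vector with $n = 4096$ entries, $g$ holds the learned weights (in this case `layers.0.attention_norm.weight`), and $\epsilon$ is the small constant `norm_eps`:

$$
\mathrm{RMS}(a) = \sqrt{\frac{1}{n}\sum_{i=1}^{n} a_i^2 + \epsilon}, \qquad \bar{a}_i = \frac{a_i}{\mathrm{RMS}(a)}\, g_i
$$

Unlike LayerNorm, RMSNorm does not subtract the mean first, which makes it slightly cheaper to compute.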


```python
# Apply RMS normalization using layer 0's attention-norm weights
token_embeddings = rms_norm(token_embeddings_unnormalized,
                            model["layers.0.attention_norm.weight"])

# Print the shape of the resulting token embeddings
token_embeddings.shape
```




    torch.Size([17, 4096])



The shape doesn't change because we are only rescaling the vectors.

## Attention Heads (Query, Key, Values)
First, let's look at the shapes of the query, key, value, and output weight matrices of the first layer.


```python
# Print the shapes of the different weights
print(
    # Query weight shape
    model["layers.0.attention.wq.weight"].shape,

    # Key weight shape
    model["layers.0.attention.wk.weight"].shape,

    # Value weight shape
    model["layers.0.attention.wv.weight"].shape,

    # Output weight shape
    model["layers.0.attention.wo.weight"].shape
)
```

    torch.Size([4096, 4096]) torch.Size([1024, 4096]) torch.Size([1024, 4096]) torch.Size([4096, 4096])


These shapes show that the downloaded weights are not stored per head. Each matrix holds the weights of all attention heads at once, so all heads can be computed in parallel with a single matrix multiplication. The key and value matrices are smaller than the query matrix because LLaMA 3 uses grouped-query attention, in which fewer key/value heads are shared across the query heads. We can, however, unwrap these matrices to access a single head.
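
A quick back-of-the-envelope sketch in plain Python shows how these shapes break down into heads. It assumes LLaMA 3 8B's 32 query heads and 8 key/value heads, which are consistent with the shapes printed in this walkthrough. The sketch computes the per-head size and shows which key/value head each query head shares under grouped-query attention.

```python
dim = 4096          # embedding size (rows of wq)
n_heads = 32        # query heads (assumed LLaMA 3 8B value)
n_kv_heads = 8      # key/value heads (assumed LLaMA 3 8B value)

head_dim = dim // n_heads                  # values per head
kv_rows = n_kv_heads * head_dim            # rows in wk / wv
group_size = n_heads // n_kv_heads         # query heads sharing one key/value head

print(head_dim, kv_rows, group_size)  # 128 1024 4

# Under grouped-query attention, query head h reads key/value head h // group_size
kv_head_for_query_head = [h // group_size for h in range(n_heads)]
print(kv_head_for_query_head[:8])  # [0, 0, 0, 0, 1, 1, 1, 1]
```
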


```python
# Retrieve the query weight of the first attention layer
q_layer0 = model["layers.0.attention.wq.weight"]

# Calculate the dimension per head
head_dim = q_layer0.shape[0] // n_heads

# Reshape the query weight to separate the heads
q_layer0 = q_layer0.view(n_heads, head_dim, dim)

# Print the shape of the reshaped query weight tensor
q_layer0.shape
```




    torch.Size([32, 128, 4096])



Here, 32 is the number of attention heads in LLaMA 3, 128 is the size of each query vector, and 4096 is the size of the token embedding.
We can access the query weight matrix of the first head of the first layer using:


```python
# Extract the query weight of the first head of the first attention layer
q_layer0_head0 = q_layer0[0]

# Print the shape of the extracted query weight tensor for the first head
q_layer0_head0.shape
```




    torch.Size([128, 4096])



To get the query vector for each token, we multiply the token embeddings by the transposed query weights.


```python
# Matrix multiplication: token embeddings times the transposed query weight of the first head
q_per_token = torch.matmul(token_embeddings, q_layer0_head0.T)

# Shape of the resulting tensor: one query per token
q_per_token.shape
```




    torch.Size([17, 128])



The query vectors don't inherently know their position in the prompt, so we'll use RoPE to make them aware of it.

## Implementing RoPE

We split each query vector into pairs and then rotate each pair by an angle that depends on the token's position.
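
Rotating a pair of numbers is the same as multiplying a complex number by $e^{i\phi}$, and the following steps rely on that trick. The self-contained sketch below, in plain Python with an arbitrary example pair and angle, checks that multiplying by a unit complex number gives the same result as the usual 2D rotation matrix:

```python
import cmath
import math

x, y = 3.0, 1.0          # one (even, odd) pair from a query vector (example values)
phi = math.pi / 6        # example rotation angle (30 degrees)

# Rotation via the 2D rotation matrix
x_rot = x * math.cos(phi) - y * math.sin(phi)
y_rot = x * math.sin(phi) + y * math.cos(phi)

# Rotation via complex multiplication: (x + iy) * e^(i*phi)
z_rot = complex(x, y) * cmath.exp(1j * phi)

print(math.isclose(z_rot.real, x_rot), math.isclose(z_rot.imag, y_rot))  # True True
# Rotation preserves the length of the pair
print(math.isclose(abs(z_rot), math.hypot(x, y)))  # True
```
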


```python
# Convert the queries to float and split each one into pairs
q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)

# Print the shape of the resulting tensor after splitting into pairs
q_per_token_split_into_pairs.shape
```




    torch.Size([17, 64, 2])



We now have a tensor of shape [17x64x2], which holds the 128-length query of each of the 17 tokens split into 64 pairs. Each pair will be rotated by m*theta, where m is the position of the token whose query we are rotating and theta depends on the pair.
To rotate a pair, we treat it as a complex number and multiply it by a unit complex number.


```python
# Generate 64 evenly spaced values in [0, 1)
zero_to_one_split_into_64_parts = torch.tensor(range(64))/64

# Print the resulting tensor
zero_to_one_split_into_64_parts
```




    tensor([0.0000, 0.0156, 0.0312, 0.0469, 0.0625, 0.0781, 0.0938, 0.1094, 0.1250,
            0.1406, 0.1562, 0.1719, 0.1875, 0.2031, 0.2188, 0.2344, 0.2500, 0.2656,
            0.2812, 0.2969, 0.3125, 0.3281, 0.3438, 0.3594, 0.3750, 0.3906, 0.4062,
            0.4219, 0.4375, 0.4531, 0.4688, 0.4844, 0.5000, 0.5156, 0.5312, 0.5469,
            0.5625, 0.5781, 0.5938, 0.6094, 0.6250, 0.6406, 0.6562, 0.6719, 0.6875,
            0.7031, 0.7188, 0.7344, 0.7500, 0.7656, 0.7812, 0.7969, 0.8125, 0.8281,
            0.8438, 0.8594, 0.8750, 0.8906, 0.9062, 0.9219, 0.9375, 0.9531, 0.9688,
            0.9844])



Next, we calculate the rotation frequency for each of the 64 pairs.
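
The frequency for pair $i$ (of 64) is $\theta_i = \text{rope\_theta}^{-i/64}$. The first pair rotates fastest, and each later pair rotates more slowly. The plain-Python sketch below assumes `rope_theta = 500000`, LLaMA 3's RoPE base, which reproduces the printed frequencies below.

```python
rope_theta = 500000.0   # assumed: LLaMA 3's RoPE base
num_pairs = 64          # 128-dimensional head split into 64 pairs

# theta_i = rope_theta ** (-i / num_pairs) for i = 0, 1, ..., 63
freqs = [rope_theta ** (-i / num_pairs) for i in range(num_pairs)]

print(round(freqs[0], 5), round(freqs[1], 5))  # 1.0 0.81462

# A token at position m rotates pair i by the angle m * theta_i
m = 3
angles = [m * f for f in freqs]
print(angles[0])  # 3.0
```
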


```python
# Frequency for each pair: 1 / rope_theta^(i/64)
freqs = 1.0 / (rope_theta ** zero_to_one_split_into_64_parts)

# Display the resulting frequencies
freqs
```




    tensor([1.0000e+00, 8.1462e-01, 6.6360e-01, 5.4058e-01, 4.4037e-01, 3.5873e-01,
            2.9223e-01, 2.3805e-01, 1.9392e-01, 1.5797e-01, 1.2869e-01, 1.0483e-01,
            8.5397e-02, 6.9566e-02, 5.6670e-02, 4.6164e-02, 3.7606e-02, 3.0635e-02,
            2.4955e-02, 2.0329e-02, 1.6560e-02, 1.3490e-02, 1.0990e-02, 8.9523e-03,
            7.2927e-03, 5.9407e-03, 4.8394e-03, 3.9423e-03, 3.2114e-03, 2.6161e-03,
            2.1311e-03, 1.7360e-03, 1.4142e-03, 1.1520e-03, 9.3847e-04, 7.6450e-04,
            6.2277e-04, 5.0732e-04, 4.1327e-04, 3.3666e-04, 2.7425e-04, 2.2341e-04,
            1.8199e-04, 1.4825e-04, 1.2077e-04, 9.8381e-05, 8.0143e-05, 6.5286e-05,
            5.3183e-05, 4.3324e-05, 3.5292e-05, 2.8750e-05, 2.3420e-05, 1.9078e-05,
            1.5542e-05, 1.2660e-05, 1.0313e-05, 8.4015e-06, 6.8440e-06, 5.5752e-06,
            4.5417e-06, 3.6997e-06, 3.0139e-06, 2.4551e-06])



Now we convert each query pair into a complex number. We then build a unit complex number for every (position, pair) combination, whose angle is the token's position times the pair's frequency, and we rotate the queries by multiplying the two.
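
Queries and keys are both rotated because the rotations cancel out in the attention score. If a query at position m is rotated by $m\theta$ and a key at position n by $n\theta$, their dot product depends only on $m - n$. The plain-Python sketch below checks this for a single pair. The helpers `rotate` and `dot` are hypothetical, and the vectors, $\theta$, and positions are arbitrary illustrative values.

```python
import cmath


def rotate(pair, position, theta):
    """Hypothetical helper: rotate a 2D pair (stored as a complex number) by position * theta."""
    return pair * cmath.exp(1j * position * theta)


def dot(a, b):
    """Hypothetical helper: real dot product of two 2D vectors stored as complex numbers."""
    return (a * b.conjugate()).real


q, k = complex(0.5, -1.2), complex(2.0, 0.3)   # arbitrary query/key pair
theta = 0.8

# Same distance (query 2 positions after key), different absolute positions
score_a = dot(rotate(q, 5, theta), rotate(k, 3, theta))
score_b = dot(rotate(q, 12, theta), rotate(k, 10, theta))
# Different distance (query 4 positions after key)
score_c = dot(rotate(q, 5, theta), rotate(k, 1, theta))

print(abs(score_a - score_b) < 1e-9)  # True
print(abs(score_a - score_c) < 1e-9)  # False
```
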


```python
# Convert each query pair into a complex number
q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)

q_per_token_as_complex_numbers.shape  # torch.Size([17, 64])

# Rotation angle for every (position, pair): position index (0..16 for our 17 tokens) times frequency
freqs_for_each_token = torch.outer(torch.arange(17), freqs)

# Unit complex numbers e^(i * angle), built from magnitude 1 and the angles
freqs_cis = torch.polar(torch.ones_like(freqs_for_each_token), freqs_for_each_token)

# Multiplying by freqs_cis rotates each pair by its angle
q_per_token_as_complex_numbers_rotated = q_per_token_as_complex_numbers * freqs_cis

q_per_token_as_complex_numbers_rotated.shape  # torch.Size([17, 64])
```




    torch.Size([17, 64])



With the rotated complex numbers in hand, we can turn them back into pairs of real numbers.


```python
# Convert the rotated complex numbers back to real pairs
q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers_rotated)

# Print the shape of the resulting tensor
q_per_token_split_into_pairs_rotated.shape
```




    torch.Size([17, 64, 2])



The rotated pairs are now merged back together. The result is a new, rotated query matrix of shape [17x128], where 17 is the number of tokens and 128 is the dimension of each query vector.
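
The keys will go through the same split, rotate, and merge steps, so all of them, including the final merge, can be bundled into one helper. This is a sketch, not code from the original notebook. `apply_rope` is a hypothetical name, the input is assumed to have shape [seq_len, head_dim], and `rope_theta` defaults to 500000, which is assumed to be LLaMA 3's RoPE base.

```python
import torch


def apply_rope(x: torch.Tensor, rope_theta: float = 500000.0) -> torch.Tensor:
    """Hypothetical helper: rotate each (even, odd) pair of x by position * frequency.

    x has shape [seq_len, head_dim]; the result has the same shape (float32).
    """
    seq_len, head_dim = x.shape
    num_pairs = head_dim // 2

    # One frequency per pair: rope_theta ** (-i / num_pairs)
    freqs = 1.0 / (rope_theta ** (torch.arange(num_pairs) / num_pairs))

    # Angle for every (position, pair) and the matching unit complex numbers
    angles = torch.outer(torch.arange(seq_len), freqs)
    freqs_cis = torch.polar(torch.ones_like(angles), angles)

    # Split into pairs, view as complex, rotate, and merge back
    x_complex = torch.view_as_complex(x.float().reshape(seq_len, num_pairs, 2))
    return torch.view_as_real(x_complex * freqs_cis).reshape(seq_len, head_dim)


torch.manual_seed(0)
x = torch.randn(17, 128)  # stand-in for q_per_token
x_rot = apply_rope(x)
print(x_rot.shape)  # torch.Size([17, 128])
# Position 0 is rotated by angle 0, so it is unchanged
print(torch.allclose(x_rot[0], x[0]))  # True
```

In the notebook, `apply_rope(q_per_token)` should match `q_per_token_rotated` as long as `rope_theta` equals the config value.
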
781 | 782 | 783 | ```python 784 | # Reshape rotated token queries to match the original shape 785 | q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape) 786 | 787 | # Print the shape of the resulting tensor 788 | q_per_token_rotated.shape 789 | ``` 790 | 791 | 792 | 793 | 794 | torch.Size([17, 128]) 795 | 796 | 797 | 798 | Keys follow the same process. Key vectors are also 128-dimensional, but the key weight matrix holds only a quarter as many weights as the query matrix. This is because each key head is shared by 4 query heads to reduce computation. Like queries, keys are rotated to encode positional information. 799 | 800 | 801 | ```python 802 | # Extract the weight tensor for the attention mechanism's key in the first layer of the model 803 | k_layer0 = model["layers.0.attention.wk.weight"] 804 | 805 | # Reshape key weight for the first layer of attention to separate heads 806 | k_layer0 = k_layer0.view(n_kv_heads, k_layer0.shape[0] // n_kv_heads, dim) 807 | 808 | # Print the shape of the reshaped key weight tensor 809 | k_layer0.shape # Output: torch.Size([8, 128, 4096]) 810 | 811 | # Extract the key weight for the first head of the first layer of attention 812 | k_layer0_head0 = k_layer0[0] 813 | 814 | # Print the shape of the extracted key weight tensor for the first head 815 | k_layer0_head0.shape # Output: torch.Size([128, 4096]) 816 | 817 | # Calculate key per token by matrix multiplication 818 | k_per_token = torch.matmul(token_embeddings, k_layer0_head0.T) 819 | 820 | # Print the shape of the resulting tensor representing keys per token 821 | k_per_token.shape # Output: torch.Size([17, 128]) 822 | 823 | # Split key per token into pairs and convert to float 824 | k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2) 825 | 826 | # Print the shape of the resulting tensor after splitting into pairs 827 | k_per_token_split_into_pairs.shape # Output: torch.Size([17, 64, 2]) 828 | 829 | # Convert
key per token to complex numbers 830 | k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs) 831 | 832 | # Print the shape of the resulting tensor representing key per token as complex numbers 833 | k_per_token_as_complex_numbers.shape # Output: torch.Size([17, 64]) 834 | 835 | # Rotate complex key per token by frequencies 836 | k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis) 837 | 838 | # Print the shape of the rotated key pairs 839 | k_per_token_split_into_pairs_rotated.shape # Output: torch.Size([17, 64, 2]) 840 | 841 | # Reshape rotated key per token to match the original shape 842 | k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape) 843 | 844 | # Print the shape of the rotated key per token 845 | k_per_token_rotated.shape # Output: torch.Size([17, 128]) 846 | ``` 847 | 848 | 849 | 850 | 851 | torch.Size([17, 128]) 852 | 853 | 854 | 855 | We now have the rotated queries and keys for each token, each of size [17x128]. 856 | 857 | ## Implementing Self Attention 858 | Multiplying the query matrix by the transposed key matrix gives a score for every pair of tokens. Each score shows how well one token's query matches another token's key. The scores are divided by the square root of the head dimension (128) to keep their scale stable. 859 | 860 | 861 | ```python 862 | # Calculate query-key dot products per token 863 | qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T) / (head_dim) ** 0.5 864 | 865 | # Print the shape of the resulting tensor representing query-key dot products per token 866 | qk_per_token.shape 867 | ``` 868 | 869 | 870 | 871 | 872 | torch.Size([17, 17]) 873 | 874 | 875 | 876 | The [17x17] result (qk_per_token) holds these attention scores, where 17 is the number of tokens in the prompt. 877 | Next, we mask the query-key scores. During training, the scores for future tokens are masked because the model learns to predict each token only from the tokens that come before it.
The same causal mask is applied during inference: scores for future tokens are set to negative infinity, so they become zero after the softmax. 878 | 879 | 880 | ```python 881 | # Create a mask tensor filled with negative infinity values 882 | mask = torch.full((len(tokens), len(tokens)), float("-inf"), device=tokens.device) 883 | 884 | # Set upper triangular part of the mask tensor to negative infinity 885 | mask = torch.triu(mask, diagonal=1) 886 | 887 | # Print the resulting mask tensor 888 | mask 889 | ``` 890 | 891 | 892 | 893 | 894 | tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf], 895 | [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf], 896 | [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf], 897 | [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf], 898 | [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf], 899 | [0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf], 900 | [0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf], 901 | [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf], 902 | [0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf], 903 | [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf], 904 | [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf], 905 | [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf], 906 | [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf], 907 | [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf], 908 | [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf], 909 | [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf], 910 | [0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0.]]) 911 | 912 | 913 | 914 | Now we add the mask to the query-key scores and apply softmax along each row. This turns each token's scores into attention weights over itself and the tokens before it. Every row sums to 1, and the masked (future) positions receive a weight of 0. 915 | 916 | 917 | ```python 918 | # Add the mask to the query-key dot products per token 919 | qk_per_token_after_masking = qk_per_token + mask 920 | 921 | # Apply softmax along the second dimension after masking 922 | qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16) 923 | ``` 924 | 925 | The value matrix is the last ingredient of self-attention. As with keys, the value weights are shared by every 4 attention heads to save computation, so the reshaped value weight tensor has the shape [8x128x4096].
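Sharing key and value weights across groups of query heads is known as grouped-query attention. This walkthrough uses 32 query heads, 8 key/value heads, 128-dimensional heads and 4096-dimensional embeddings. With those sizes, each key/value head serves 4 consecutive query heads, which is what `head // 4` selects in the multi-head loop later on. The short plain-Python arithmetic sketch below uses new variable names so it does not overwrite the notebook's `n_heads` or `dim`:

```python
num_q_heads, num_kv_heads, head_size, model_dim = 32, 8, 128, 4096  # sizes used in this walkthrough

group_size = num_q_heads // num_kv_heads  # 4 query heads share one key/value head

# Key/value head used by each query head (the notebook's head // 4)
kv_head_for_query_head = [head // group_size for head in range(num_q_heads)]
print(kv_head_for_query_head[:8])  # [0, 0, 0, 0, 1, 1, 1, 1]

# Weight shapes before splitting into heads: (out_features, in_features)
wq_shape = (num_q_heads * head_size, model_dim)   # (4096, 4096)
wk_shape = (num_kv_heads * head_size, model_dim)  # (1024, 4096); wv has the same shape
print(wq_shape[0] // wk_shape[0])  # 4: keys and values carry a quarter of the query weights
```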
926 | 927 | 928 | ```python 929 | # Retrieve the value weight for the first layer of attention 930 | v_layer0 = model["layers.0.attention.wv.weight"] 931 | 932 | # Reshape value weight for the first layer of attention to separate heads 933 | v_layer0 = v_layer0.view(n_kv_heads, v_layer0.shape[0] // n_kv_heads, dim) 934 | 935 | # Print the shape of the reshaped value weight tensor 936 | v_layer0.shape 937 | ``` 938 | 939 | 940 | 941 | 942 | torch.Size([8, 128, 4096]) 943 | 944 | 945 | 946 | As with the query and key weights, the value weights for the first head of the first layer can be obtained with: 947 | 948 | 949 | ```python 950 | # Extract the value weight for the first head of the first layer of attention 951 | v_layer0_head0 = v_layer0[0] 952 | 953 | # Print the shape of the extracted value weight tensor for the first head 954 | v_layer0_head0.shape 955 | ``` 956 | 957 | 958 | 959 | 960 | torch.Size([128, 4096]) 961 | 962 | 963 | 964 | Using these value weights, we compute a value vector for each token, resulting in a matrix of size [17x128]. Here, 17 denotes the number of tokens in the prompt, and 128 is the dimension of each value vector.
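End to end, one attention head computes softmax(QKᵀ/√d + mask)·V. The plain-Python sketch below uses toy 2-dimensional vectors for 3 tokens (assumed values) and omits RoPE for brevity. It shows the effect of the causal mask: after the softmax, each row of weights sums to 1 and every future position gets weight 0, so the first token can only attend to itself. `causal_attention` is a hypothetical helper, not part of the notebook.

```python
import math


def causal_attention(queries, keys, values):
    """Single-head scaled dot-product attention with a causal mask (lists of equal-length vectors)."""
    n_tokens, head_dim = len(queries), len(queries[0])
    weights, outputs = [], []
    for i in range(n_tokens):
        # Scaled scores against every token; future tokens (j > i) get -inf, like torch.triu(mask, 1)
        scores = [
            sum(q * k for q, k in zip(queries[i], keys[j])) / math.sqrt(head_dim) if j <= i else float("-inf")
            for j in range(n_tokens)
        ]
        # Softmax over the row; math.exp(-inf) == 0.0, so masked positions get zero weight
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        row = [e / total for e in exps]
        weights.append(row)
        # Weighted sum of the value vectors
        outputs.append([sum(w * v[d] for w, v in zip(row, values)) for d in range(len(values[0]))])
    return weights, outputs


toy_q = [[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]]
toy_k = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
toy_v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
attn_weights, attn_out = causal_attention(toy_q, toy_k, toy_v)
for row in attn_weights:
    print([round(w, 3) for w in row])  # first row is [1.0, 0.0, 0.0]
```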
965 | 966 | 967 | ```python 968 | # Calculate value per token by matrix multiplication 969 | v_per_token = torch.matmul(token_embeddings, v_layer0_head0.T) 970 | 971 | # Print the shape of the resulting tensor representing values per token 972 | v_per_token.shape 973 | ``` 974 | 975 | 976 | 977 | 978 | torch.Size([17, 128]) 979 | 980 | 981 | 982 | To obtain the attention output for this head, we multiply the attention weights by the value vectors: 983 | 984 | 985 | ```python 986 | # Calculate QKV attention by matrix multiplication 987 | qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token) 988 | 989 | # Print the shape of the resulting tensor 990 | qkv_attention.shape 991 | ``` 992 | 993 | 994 | 995 | 996 | torch.Size([17, 128]) 997 | 998 | 999 | 1000 | We now have the self-attention output for the first head of the first layer. 1001 | 1002 | ## Implementing Multi-Head Attention 1003 | 1004 | Next, we repeat the same calculations in a loop for every head of the first layer.
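The multi-head loop that follows has a simple structure. Each query head is paired with key/value head `head // 4`, its attention output is computed, and the head outputs are concatenated along the feature dimension. The self-contained sketch below uses tiny assumed sizes: 3 tokens, 4 query heads, 2 key/value heads and a head dimension of 2. Here the causal mask is implemented by scoring only tokens j ≤ i, which is equivalent to adding -inf to the rest. `attend` and `multi_head_attention` are hypothetical helpers.

```python
import math


def attend(queries, keys, values):
    """Causal single-head attention on lists of vectors (RoPE omitted for brevity)."""
    outputs = []
    scale = math.sqrt(len(queries[0]))
    for i, q in enumerate(queries):
        scores = [sum(a * b for a, b in zip(q, keys[j])) / scale for j in range(i + 1)]  # only j <= i
        top = max(scores)
        exps = [math.exp(s - top) for s in scores]
        total = sum(exps)
        outputs.append([sum(e / total * values[j][d] for j, e in enumerate(exps)) for d in range(len(values[0]))])
    return outputs


def multi_head_attention(q_heads, k_heads, v_heads):
    """q_heads: one list of token vectors per query head; k_heads/v_heads: fewer, shared heads."""
    group_size = len(q_heads) // len(k_heads)
    head_outputs = [attend(q, k_heads[h // group_size], v_heads[h // group_size]) for h, q in enumerate(q_heads)]
    # Concatenate the head outputs for each token, like torch.cat(qkv_attention_store, dim=-1)
    n_tokens = len(q_heads[0])
    return [[x for head in head_outputs for x in head[t]] for t in range(n_tokens)]


# 3 tokens, 4 query heads, 2 key/value heads, head_dim = 2 (toy values)
toy_q_heads = [[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]] for _ in range(4)]
toy_k_heads = [[[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]] for _ in range(2)]
toy_v_heads = [[[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]], [[-1.0, 0.0], [0.0, -1.0], [1.0, 1.0]]]
toy_stacked = multi_head_attention(toy_q_heads, toy_k_heads, toy_v_heads)
print(len(toy_stacked), len(toy_stacked[0]))  # 3 tokens, 4 heads * 2 dims = 8 features
```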
1005 | 1006 | 1007 | ```python 1008 | # Store QKV attention for each head in a list 1009 | qkv_attention_store = [] 1010 | 1011 | # Iterate through each head 1012 | for head in range(n_heads): 1013 | # Extract query, key, and value weights for the current head 1014 | q_layer0_head = q_layer0[head] 1015 | k_layer0_head = k_layer0[head//4] # Key weights are shared across 4 heads 1016 | v_layer0_head = v_layer0[head//4] # Value weights are shared across 4 heads 1017 | 1018 | # Calculate query per token by matrix multiplication 1019 | q_per_token = torch.matmul(token_embeddings, q_layer0_head.T) 1020 | 1021 | # Calculate key per token by matrix multiplication 1022 | k_per_token = torch.matmul(token_embeddings, k_layer0_head.T) 1023 | 1024 | # Calculate value per token by matrix multiplication 1025 | v_per_token = torch.matmul(token_embeddings, v_layer0_head.T) 1026 | 1027 | # Split query per token into pairs and rotate them 1028 | q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2) 1029 | q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs) 1030 | q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers * freqs_cis[:len(tokens)]) 1031 | q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape) 1032 | 1033 | # Split key per token into pairs and rotate them 1034 | k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2) 1035 | k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs) 1036 | k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis[:len(tokens)]) 1037 | k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape) 1038 | 1039 | # Calculate query-key dot products per token 1040 | qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T) / (128) ** 0.5 1041 | 1042 | # Create a mask tensor filled with 
negative infinity values 1043 | mask = torch.full((len(tokens), len(tokens)), float("-inf"), device=tokens.device) 1044 | # Set upper triangular part of the mask tensor to negative infinity 1045 | mask = torch.triu(mask, diagonal=1) 1046 | # Add the mask to the query-key dot products per token 1047 | qk_per_token_after_masking = qk_per_token + mask 1048 | 1049 | # Apply softmax along the second dimension after masking 1050 | qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16) 1051 | 1052 | # Calculate QKV attention by matrix multiplication 1053 | qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token) 1054 | 1055 | # Store QKV attention for the current head 1056 | qkv_attention_store.append(qkv_attention) 1057 | 1058 | # Print the number of QKV attentions stored 1059 | len(qkv_attention_store) 1060 | ``` 1061 | 1062 | 1063 | 1064 | 1065 | 32 1066 | 1067 | 1068 | 1069 | Now that we have the attention output of all 32 heads in the first layer, we concatenate them into one large matrix of size [17x4096] (32 heads x 128 dimensions = 4096). 1070 | 1071 | 1072 | ```python 1073 | # Concatenate QKV attentions from all heads along the last dimension 1074 | stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1) 1075 | 1076 | # Print the shape of the resulting tensor 1077 | stacked_qkv_attention.shape 1078 | ``` 1079 | 1080 | 1081 | 1082 | 1083 | torch.Size([17, 4096]) 1084 | 1085 | 1086 | 1087 | The last step of the layer 0 attention block is to multiply the stacked attention output by the output weight matrix (wo).
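The output projection maps the concatenated head outputs back to the model dimension. The result is then added to the layer's input as a residual connection (the `embedding_after_edit` step). The toy plain-Python sketch below covers both steps. The sizes and weights are assumed values, and `matmul_transposed` is a hypothetical helper that mirrors `torch.matmul(x, W.T)`.

```python
def matmul_transposed(x, weight):
    """Compute x @ weight.T for lists of rows: out[t][o] = sum_i x[t][i] * weight[o][i]."""
    return [[sum(a * b for a, b in zip(row, w_row)) for w_row in weight] for row in x]


def add(x, y):
    """Element-wise sum of two equally shaped lists of rows (the residual connection)."""
    return [[a + b for a, b in zip(rx, ry)] for rx, ry in zip(x, y)]


toy_heads = [[1.0, 0.0, 2.0, 0.0], [0.0, 1.0, 0.0, 3.0]]  # 2 tokens, 2 heads x 2 dims, concatenated
toy_wo = [[1.0, 0.0, 0.0, 0.0],  # out_features x in_features, like attention.wo.weight
          [0.0, 0.0, 1.0, 0.0],
          [0.0, 1.0, 0.0, 0.0],
          [0.0, 0.0, 0.0, 1.0]]
toy_input = [[0.5, 0.5, 0.5, 0.5], [1.0, 1.0, 1.0, 1.0]]  # stands in for token_embeddings_unnormalized

toy_delta = matmul_transposed(toy_heads, toy_wo)
toy_after_edit = add(toy_input, toy_delta)
print(toy_delta)       # [[1.0, 2.0, 0.0, 0.0], [0.0, 0.0, 1.0, 3.0]]
print(toy_after_edit)  # [[1.5, 2.5, 0.5, 0.5], [1.0, 1.0, 2.0, 4.0]]
```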
1088 | 1089 | 1090 | ```python 1091 | # Calculate the embedding delta by matrix multiplication with the output weight 1092 | embedding_delta = torch.matmul(stacked_qkv_attention, model["layers.0.attention.wo.weight"].T) 1093 | 1094 | # Print the shape of the resulting tensor 1095 | embedding_delta.shape 1096 | ``` 1097 | 1098 | 1099 | 1100 | 1101 | torch.Size([17, 4096]) 1102 | 1103 | 1104 | 1105 | This gives the change in the embeddings produced by attention, which we add to the original (unnormalized) token embeddings as a residual connection. 1106 | 1107 | 1108 | ```python 1109 | # Add the embedding delta to the unnormalized token embeddings (residual connection) 1110 | embedding_after_edit = token_embeddings_unnormalized + embedding_delta 1111 | 1112 | # Print the shape of the resulting tensor 1113 | embedding_after_edit.shape 1114 | ``` 1115 | 1116 | 1117 | 1118 | 1119 | torch.Size([17, 4096]) 1120 | 1121 | 1122 | 1123 | The updated embeddings are then normalized with RMSNorm, using the layer's ffn_norm weights, before being passed through the feedforward neural network. 1124 | 1125 | 1126 | ```python 1127 | # Normalize edited embeddings using root mean square normalization and provided weights 1128 | embedding_after_edit_normalized = rms_norm(embedding_after_edit, model["layers.0.ffn_norm.weight"]) 1129 | 1130 | # Print the shape of resulting normalized embeddings 1131 | embedding_after_edit_normalized.shape 1132 | ``` 1133 | 1134 | 1135 | 1136 | 1137 | torch.Size([17, 4096]) 1138 | 1139 | 1140 | 1141 | ## Implementing SwiGLU Activation Function 1142 | Since we covered the SwiGLU activation function earlier, we can now apply its equation to the feedforward layer.
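Written out, the feedforward block computes FFN(x) = W2 (SiLU(W1 x) ⊙ W3 x), where SiLU(z) = z · sigmoid(z) and ⊙ is element-wise multiplication. W1 and W3 project up to the hidden size, and W2 projects back down; this matches the torch one-liner in the cell that follows. The plain-Python sketch below uses toy sizes (model dimension 2, hidden dimension 3) and made-up weights, so it only illustrates the data flow.

```python
import math


def silu(z):
    """SiLU (swish) activation: z * sigmoid(z)."""
    return z / (1.0 + math.exp(-z))


def linear(x, weight):
    """x @ weight.T for a single vector x; weight is out_features x in_features."""
    return [sum(a * b for a, b in zip(x, w_row)) for w_row in weight]


def swiglu_ffn(x, w1, w2, w3):
    gate = [silu(g) for g in linear(x, w1)]    # SiLU(x W1^T)
    up = linear(x, w3)                         # x W3^T
    hidden = [g * u for g, u in zip(gate, up)]  # element-wise product
    return linear(hidden, w2)                  # project back down with W2


# Toy weights: model dim 2, hidden dim 3 (assumed values)
toy_w1 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
toy_w3 = [[0.5, 0.5], [1.0, -1.0], [0.0, 2.0]]
toy_w2 = [[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]]
toy_ffn_out = swiglu_ffn([1.0, 2.0], toy_w1, toy_w2, toy_w3)
print(toy_ffn_out)  # 2 values: back to the model dimension
```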
1143 | 1144 | ![](https://cdn-images-1.medium.com/max/1000/1*q5FbOgDpo6H-86AefVzdNQ.png) 1145 | 1146 | 1147 | ```python 1148 | # Retrieve weights for feedforward layer 1149 | w1 = model["layers.0.feed_forward.w1.weight"] 1150 | w2 = model["layers.0.feed_forward.w2.weight"] 1151 | w3 = model["layers.0.feed_forward.w3.weight"] 1152 | 1153 | # SwiGLU feedforward: (SiLU(x @ w1.T) * (x @ w3.T)) @ w2.T 1154 | output_after_feedforward = torch.matmul(torch.nn.functional.silu(torch.matmul(embedding_after_edit_normalized, w1.T)) * torch.matmul(embedding_after_edit_normalized, w3.T), w2.T) 1155 | 1156 | # Print the shape of the resulting tensor after feedforward 1157 | output_after_feedforward.shape 1158 | ``` 1159 | 1160 | 1161 | 1162 | 1163 | torch.Size([17, 4096]) 1164 | 1165 | 1166 | 1167 | ## Merging everything 1168 | Now that every piece is ready, we merge the code into a single loop that applies the same steps to all 32 layers (n_layers), not just layer 0. 1169 | 1170 | 1171 | ```python 1172 | # Initialize final embedding with unnormalized token embeddings 1173 | final_embedding = token_embeddings_unnormalized 1174 | 1175 | # Iterate through each layer 1176 | for layer in range(n_layers): 1177 | # Initialize list to store QKV attentions for each head 1178 | qkv_attention_store = [] 1179 | 1180 | # Normalize the final embedding using root mean square normalization and weights from the current layer 1181 | layer_embedding_norm = rms_norm(final_embedding, model[f"layers.{layer}.attention_norm.weight"]) 1182 | 1183 | # Retrieve query, key, value, and output weights for the attention mechanism of the current layer 1184 | q_layer = model[f"layers.{layer}.attention.wq.weight"] 1185 | q_layer = q_layer.view(n_heads, q_layer.shape[0] // n_heads, dim) 1186 | k_layer = model[f"layers.{layer}.attention.wk.weight"] 1187 | k_layer = k_layer.view(n_kv_heads, k_layer.shape[0] // n_kv_heads, dim) 1188 | v_layer = model[f"layers.{layer}.attention.wv.weight"] 1189 | v_layer = v_layer.view(n_kv_heads, v_layer.shape[0] // n_kv_heads, dim) 1190 | w_layer =
model[f"layers.{layer}.attention.wo.weight"] 1191 | 1192 | # Iterate through each head 1193 | for head in range(n_heads): 1194 | # Extract query, key, and value weights for the current head 1195 | q_layer_head = q_layer[head] 1196 | k_layer_head = k_layer[head//4] # Key weights are shared across 4 heads 1197 | v_layer_head = v_layer[head//4] # Value weights are shared across 4 heads 1198 | 1199 | # Calculate query per token by matrix multiplication 1200 | q_per_token = torch.matmul(layer_embedding_norm, q_layer_head.T) 1201 | 1202 | # Calculate key per token by matrix multiplication 1203 | k_per_token = torch.matmul(layer_embedding_norm, k_layer_head.T) 1204 | 1205 | # Calculate value per token by matrix multiplication 1206 | v_per_token = torch.matmul(layer_embedding_norm, v_layer_head.T) 1207 | 1208 | # Split query per token into pairs and rotate them 1209 | q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2) 1210 | q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs) 1211 | q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers * freqs_cis) 1212 | q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape) 1213 | 1214 | # Split key per token into pairs and rotate them 1215 | k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2) 1216 | k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs) 1217 | k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis) 1218 | k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape) 1219 | 1220 | # Calculate query-key dot products per token 1221 | qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T) / (128) ** 0.5 1222 | 1223 | # Create a mask tensor filled with negative infinity values 1224 | mask = torch.full((len(token_embeddings_unnormalized), 
len(token_embeddings_unnormalized)), float("-inf")) 1225 | # Set upper triangular part of the mask tensor to negative infinity 1226 | mask = torch.triu(mask, diagonal=1) 1227 | # Add the mask to the query-key dot products per token 1228 | qk_per_token_after_masking = qk_per_token + mask 1229 | 1230 | # Apply softmax along the second dimension after masking 1231 | qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16) 1232 | 1233 | # Calculate QKV attention by matrix multiplication 1234 | qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token) 1235 | 1236 | # Store QKV attention for the current head 1237 | qkv_attention_store.append(qkv_attention) 1238 | 1239 | # Concatenate QKV attentions from all heads along the last dimension 1240 | stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1) 1241 | 1242 | # Calculate embedding delta by matrix multiplication with the output weight 1243 | embedding_delta = torch.matmul(stacked_qkv_attention, w_layer.T) 1244 | 1245 | # Add the embedding delta to the current embedding to get the edited embedding 1246 | embedding_after_edit = final_embedding + embedding_delta 1247 | 1248 | # Normalize the edited embedding using root mean square normalization and weights from the current layer 1249 | embedding_after_edit_normalized = rms_norm(embedding_after_edit, model[f"layers.{layer}.ffn_norm.weight"]) 1250 | 1251 | # Retrieve weights for the feedforward layer 1252 | w1 = model[f"layers.{layer}.feed_forward.w1.weight"] 1253 | w2 = model[f"layers.{layer}.feed_forward.w2.weight"] 1254 | w3 = model[f"layers.{layer}.feed_forward.w3.weight"] 1255 | 1256 | # Perform operations for the feedforward layer (SwiGLU) 1257 | output_after_feedforward = torch.matmul(torch.nn.functional.silu(torch.matmul(embedding_after_edit_normalized, w1.T)) * torch.matmul(embedding_after_edit_normalized, w3.T), w2.T) 1258 | 1259 | # Update the final embedding with the
edited embedding plus the output from the feedforward layer 1260 | final_embedding = embedding_after_edit + output_after_feedforward 1261 | ``` 1262 | 1263 | ## Generating the Output 1264 | We now have the final embeddings after all 32 layers. Their shape matches the regular token embeddings, [17x4096]: 17 tokens, each with an embedding dimension of 4096. The model uses the last token's embedding to guess the next token. 1265 | 1266 | 1267 | ```python 1268 | # Normalize the final embedding using root mean square normalization and provided weights 1269 | final_embedding = rms_norm(final_embedding, model["norm.weight"]) 1270 | 1271 | # Print the shape of the resulting normalized final embedding 1272 | final_embedding.shape 1273 | ``` 1274 | 1275 | 1276 | 1277 | 1278 | torch.Size([17, 4096]) 1279 | 1280 | 1281 | 1282 | Next, we decode an embedding into a token using the output weight matrix, which has one row for each of the 128,256 tokens in the vocabulary. 1283 | 1284 | 1285 | 1286 | ```python 1287 | # Print the shape of the output weight tensor 1288 | model["output.weight"].shape 1289 | ``` 1290 | 1291 | 1292 | 1293 | 1294 | torch.Size([128256, 4096]) 1295 | 1296 | 1297 | 1298 | To predict the next token, we use only the embedding of the last token. 1299 | 1300 | 1301 | ```python 1302 | # Calculate logits by matrix multiplication between the final embedding and the transpose of the output weight tensor 1303 | logits = torch.matmul(final_embedding[-1], model["output.weight"].T) 1304 | 1305 | # Find the index of the maximum value along the last dimension to determine the next token 1306 | next_token = torch.argmax(logits, dim=-1) 1307 | 1308 | # Decode the index of the next token using the tokenizer 1309 | tokenizer.decode([next_token.item()]) 1310 | ``` 1311 | 1312 | 1313 | 1314 | 1315 | '42' 1316 | 1317 | 1318 | 1319 | So, our input was "the answer to the ultimate question of life, the universe, and everything is ", and the model predicts "42", which is the correct answer.
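The decoding step is a matrix-vector product followed by an argmax: each row of the output weight matrix scores one vocabulary entry against the last token's embedding. The plain-Python sketch below uses a made-up 4-word vocabulary and 3-dimensional embeddings (all values are assumptions for illustration). In the notebook, the same roles are played by `model["output.weight"]` (128256 x 4096) and `tokenizer.decode`.

```python
toy_vocab = ["the", "answer", "42", "life"]
toy_output_weight = [  # one row per vocabulary entry: vocab_size x embedding_dim
    [0.1, 0.0, 0.2],
    [0.0, 0.3, 0.1],
    [0.9, 0.4, 0.0],
    [0.2, 0.2, 0.2],
]
toy_final_embeddings = [[0.5, 0.1, 0.3], [1.0, 0.5, -0.2]]  # 2 tokens; only the last one is used

last_embedding = toy_final_embeddings[-1]
toy_logits = [sum(w * e for w, e in zip(row, last_embedding)) for row in toy_output_weight]
next_token_id = max(range(len(toy_logits)), key=toy_logits.__getitem__)  # argmax
print(toy_vocab[next_token_id])  # 42
```

To generate a longer answer, you would append the predicted token to the prompt and run the model again. The notebook stops after a single prediction.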
You can experiment with different prompts by changing just these two lines; the rest of the code stays the same. 1321 | 1322 | ```python 1323 | # Input prompt 1324 | prompt = "Your Input" 1325 | 1326 | # Replace 17 with the total number of tokens in your input 1327 | # (you can check the token count with len(tokens)) 1328 | freqs_for_each_token = torch.outer(torch.arange(17), freqs) 1329 | ``` 1330 | 1331 | ### I hope you enjoyed this blog and learned something new! 1332 | -------------------------------------------------------------------------------- /architectures_comparison.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/FareedKhan-dev/Building-llama3-from-scratch/506878635e6ce2ca30b2b31c6b6b37480353a3bb/architectures_comparison.pdf -------------------------------------------------------------------------------- /building-llama-3-from-scratch.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "## Building LLaMA 3 LLM From Scratch using Python\n", 8 | "\n", 9 | "LLaMA 3 is one of the most promising open-source models after Mistral, solving a wide range of tasks. I previously wrote a blog on Medium about creating an LLM with over 2.3 million parameters from scratch using the LLaMA architecture. Now that LLaMA-3 is released, we will recreate it in a simpler manner." 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "We won't be using a GPU for this blog, but you'll need at least 17 GB of RAM because we are going to load some files that are more than 15 GB in size. If this is an issue for you, you can use Kaggle as a solution.
Since we don't need a GPU, Kaggle offers 30 GB of RAM while using only CPU cores as an accelerator.\n", 17 | "\n", 18 | "Here is the blog link which guides you on how to create a 2.3+ million parameter LLM from scratch:\n", 19 | "[2.3+ Million Parameter LLM From Scratch](https://levelup.gitconnected.com/building-a-million-parameter-llm-from-scratch-using-python-f612398f06c2)" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": { 25 | "execution": { 26 | "iopub.execute_input": "2024-05-26T18:58:30.350370Z", 27 | "iopub.status.busy": "2024-05-26T18:58:30.349842Z", 28 | "iopub.status.idle": "2024-05-26T18:58:30.361066Z", 29 | "shell.execute_reply": "2024-05-26T18:58:30.359662Z", 30 | "shell.execute_reply.started": "2024-05-26T18:58:30.350337Z" 31 | } 32 | }, 33 | "source": [ 34 | "## Table of Contents\n", 35 | "\n", 36 | "1. [Prerequisites](#prerequisites)\n", 37 | "2. [Difference between LLaMA 2 and LLaMA 3](#difference-between-llama-2-and-llama-3)\n", 38 | "3. [Understanding the Transformer Architecture of LLaMA 3](#understanding-the-transformer-architecture-of-llama-3)\n", 39 | " - Pre-normalization Using RMSNorm\n", 40 | " - SwiGLU Activation Function\n", 41 | " - Rotary Embeddings (RoPE)\n", 42 | " - Byte Pair Encoding (BPE) Algorithm\n", 43 | "4. [Setting the Stage](#setting-the-stage)\n", 44 | "5. [Understanding the File Structure](#understanding-the-file-structure)\n", 45 | "6. [Tokenizing our Input Data](#tokenizing-our-input-data)\n", 46 | "7. [Creating Embedding for each Token](#creating-embedding-for-each-token)\n", 47 | "8. [Normalization Using RMSNorm](#normalization-using-rmsnorm)\n", 48 | "9. [Attention Heads (Query, Key, Values)](#attention-heads-query-key-values)\n", 49 | "10. [Implementing RoPE](#implementing-rope)\n", 50 | "11. [Implementing Self Attention](#implementing-self-attention)\n", 51 | "12. [Implementing Multi-Head Attention](#implementing-multi-head-attention)\n", 52 | "13. 
[Implementing SwiGLU Activation Function](#implementing-swiglu-activation-function)\n", 53 | "14. [Merging Everything](#merging-everything)\n", 54 | "15. [Generating the Output](#generating-the-output)\n" 55 | ] 56 | }, 57 | { 58 | "cell_type": "markdown", 59 | "metadata": {}, 60 | "source": [ 61 | "## Prerequisites\n", 62 | "The good part is we won't be using object-oriented programming (OOP) coding, just plain Python programming. However, you should have a basic understanding of neural networks and Transformer architecture. These are the only two prerequisites needed to follow along with the blog.\n", 63 | "\n", 64 | "| Topic | Link |\n", 65 | "| ---- | ---- |\n", 66 | "| Transformer Theory | [Video Link](https://www.youtube.com/watch?v=zxQyTK8quyY) |\n", 67 | "| Neural Networks Theory | [Video Link](https://www.youtube.com/watch?v=Jy4wM2X21u0) |\n", 68 | "| Python basics | [Video Link](https://www.youtube.com/watch?v=eWRfhZUzrAc) |" 69 | ] 70 | }, 71 | { 72 | "cell_type": "markdown", 73 | "metadata": {}, 74 | "source": [ 75 | "## Difference between LLaMA 2 and LLaMA 3\n", 76 | "Before looking into the technical details, the first thing you must know is that the entire architecture of LLaMA 3 is the same as LLaMA 2. So, if you haven't gone through the technical details of LLaMA 3 yet, it won't be a problem for you to follow this blog. Even if you don't have an understanding of LLaMA 2 architecture, don't worry, we will also look at a high-level overview of its technical details. This blog is designed for you either way.\n", 77 | "\n", 78 | "Here are some key points about LLaMA 2 and LLaMA 3. 
If you are already familiar with their architecture, the table below gives a quick comparison:\n", 79 | "\n", 80 | "| FEATURE | Llama 3 | Llama 2 |\n", 81 | "|----------------------------------------|----------------------------|------------------------------------|\n", 82 | "| Tokenizer | Tiktoken (developed by OpenAI) | SentencePiece |\n", 83 | "| Number of Parameters | 8B, 70B | 70B, 13B, 7B |\n", 84 | "| Training Data | 15T tokens | 2T tokens |\n", 85 | "| Context Length | 8192 tokens | 4096 tokens |\n", 86 | "| Attention Mechanism | Grouped-query attention (8B and 70B) | Grouped-query attention (70B only) |\n", 87 | "| Fine-Tuned Models | Yes | Yes |\n", 88 | "| Performance | Better than Llama 2 on all benchmarks | Better than Llama 1 on most benchmarks |\n", 89 | "| Computational Requirements | Very high (70B model) | Very high (70B model) |\n", 90 | "| Availability | Open source | Open source |\n", 91 | "| Reinforcement learning from human feedback | Yes | Yes |\n", 92 | "| Number of languages supported | 30 languages | 20 languages |\n", 93 | "| Suitable for | Best for more demanding tasks, such as reasoning, coding, and proficiency tests | Good for more demanding tasks, such as reasoning, coding, and proficiency tests |" 94 | ] 95 | }, 96 | { 97 | "cell_type": "markdown", 98 | "metadata": {}, 99 | "source": [ 100 | "## Understanding the Transformer Architecture of LLaMA 3\n", 101 | "Understanding the architecture of LLaMA 3 is important before diving into coding it. For a better visual understanding, here's a comparison diagram between the vanilla Transformer, LLaMA 2/3, and Mistral.\n", 102 | "\n", 103 | "![Comparison](https://i.ibb.co/kKN4Cks/Group-1-1.png)" 104 | ] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "metadata": {}, 109 | "source": [ 110 | "Let's look into the most important components of LLaMA 3 with a bit more detail:\n", 111 | "### 1.
Pre-normalization Using RMSNorm:\n", 112 | "In LLaMA 3, as in LLaMA 2, a technique called RMSNorm is used to normalize the input of each transformer sub-layer.\n", 113 | "\n", 114 | "Imagine you're studying for a big exam, and you have a massive textbook full of chapters. Each chapter represents a different topic, but some chapters are more crucial for understanding the subject than others.\n", 115 | "Now, before diving into the entire textbook, you decide to evaluate the importance of each chapter. You don't want to spend the same amount of time on every chapter; you want to focus more on the critical ones.\n", 116 | "This is where Pre-normalization using RMSNorm comes into play for large language models (LLMs) like ChatGPT. It's like assigning a weight to each chapter based on its significance. Chapters that are fundamental to the subject get higher weights, while less important ones get lower weights.\n", 117 | "\n", 118 | "So, before going deeply into studying, you adjust your study plan based on the weighted importance of each chapter. You allocate more time and effort to the chapters with higher weights, ensuring you grasp the core concepts thoroughly.\n", 119 | "\n", 120 | "![N](https://cdn-images-1.medium.com/v2/resize:fit:1000/0*GIr8bvByN_iAGQBW.png)\n", 121 | "\n", 122 | "Similarly, Pre-normalization using RMSNorm helps LLMs prioritize which parts of the text are more critical for understanding the context and meaning. It assigns higher weights to essential elements and lower weights to less crucial ones, ensuring the model focuses its attention where it's most needed for accurate comprehension. Interested readers can explore the details of RMSNorm in the original RMSNorm paper." 123 | ] 124 | }, 125 | { 126 | "cell_type": "markdown", 127 | "metadata": {}, 128 | "source": [ 129 | "### 2.
SwiGLU Activation Function:\n", 130 | "LLaMA introduces the SwiGLU activation function, drawing inspiration from PaLM.\n", 131 | "\n", 132 | "Imagine you're a teacher trying to explain a complex topic to your students. You have a big whiteboard where you write down key points and draw diagrams to make things clearer. But sometimes, your handwriting might not be very neat, or your diagrams might not be perfectly drawn. This can make it harder for your students to understand the material.\n", 133 | "\n", 134 | "Now, imagine if you had a magic pen that automatically adjusted the size and style of your handwriting based on how important each point is. If something is really crucial, the pen writes it bigger and clearer, making it stand out. If it's less important, the pen writes it smaller, but still legible.\n", 135 | "SwiGLU is like that magic pen for large language models (LLMs) like ChatGPT. Before generating text, SwiGLU adjusts the importance of each word or phrase based on its relevance to the context. Just like the magic pen adjusts the size and style of your writing, SwiGLU adjusts the emphasis of each word or phrase.\n", 136 | "\n", 137 | "![SwigLU](https://cdn-images-1.medium.com/max/1000/0*NtNn2CFuDNEH6jFC.png)\n", 138 | "\n", 139 | "So, when the LLM generates text, it can give more prominence to the important parts, making them more noticeable and ensuring they contribute more to the overall understanding of the text. This way, SwiGLU helps LLMs produce text that's clearer and easier to understand, much like how the magic pen helps you create clearer explanations for your students on the whiteboard. Further details on SwiGLU can be found in the associated paper." 140 | ] 141 | }, 142 | { 143 | "cell_type": "markdown", 144 | "metadata": {}, 145 | "source": [ 146 | "### 3. 
Rotary Embeddings (RoPE):\n", 147 | "Rotary Embeddings, or RoPE, is a type of position embedding used in LLaMA 3.\n", 148 | "\n", 149 | "Imagine you're in a classroom, and you want to assign seats to students for group discussions. Typically, you might arrange the seats in rows and columns, with each student having a fixed position. However, in some cases, you want to create a more dynamic seating arrangement where students can move around and interact more freely.\n", 150 | "\n", 151 | "RoPE is like a special seating arrangement that allows students to rotate and change positions while still maintaining their relative positions to each other. Instead of being fixed in one place, students can now move around in a circular motion, allowing for more fluid interactions.\n", 152 | "\n", 153 | "In this scenario, each student represents a word or token in a text sequence, and their seat corresponds to their position in the sequence. Just as the students rotate while keeping their positions relative to each other, RoPE rotates each token's query and key vectors by an angle proportional to the token's position, so the angle between two tokens depends only on how far apart they are.\n", 154 | "So, when processing text, instead of adding a fixed positional embedding to each token, RoPE encodes position as a rotation, which makes the attention score between two tokens depend on their relative distance. This helps models like LLaMA better understand and generate text that flows naturally and maintains coherence, similar to how a dynamic seating arrangement fosters more interactive discussions in a classroom. Those interested in the mathematical details can refer to the RoPE paper, \"RoFormer: Enhanced Transformer with Rotary Position Embedding\"." 155 | ] 156 | }, 157 | { 158 | "cell_type": "markdown", 159 | "metadata": {}, 160 | "source": [ 161 | "### 4. 
Byte Pair Encoding (BPE) Algorithm\n", 162 | "\n", 163 | "LLaMA 3 uses Byte Pair Encoding (BPE) from the tiktoken library introduced by OpenAI, whereas the LLaMA 2 BPE tokenizer is based on the sentencepiece library. There is a slight difference between them, but first, let's learn what BPE actually is.\n", 164 | "\n", 165 | "Let's start with a simple example. Suppose we have a text corpus with the words: \"ab\", \"bc\", \"bcd\", and \"cde\". We begin by initializing our vocabulary with all the individual characters in the text corpus, so our initial vocabulary is {\"a\", \"b\", \"c\", \"d\", \"e\"}.\n", 166 | "\n", 167 | "Next, we calculate the frequency of each character in the text corpus. For our example, the frequencies are: {\"a\": 1, \"b\": 3, \"c\": 3, \"d\": 2, \"e\": 1}.\n", 168 | "\n", 169 | "Now, we start the merging process. We repeat the following steps until our vocabulary reaches the desired size or there are no pairs left to merge.\n", 170 | "Pairs are only counted inside a word, and when two pairs are tied, we pick the one that appears first in the corpus:\n", 171 | "\n", 172 | "1. The adjacent pairs are \"ab\" (1), \"bc\" (2, from \"bc\" and \"bcd\"), \"cd\" (2, from \"bcd\" and \"cde\"), and \"de\" (1). \"bc\" and \"cd\" are tied, so we merge \"bc\", the one that appears first, into a new subword unit \"bc\".\n", 173 | "   The words become [a, b], [bc], [bc, d], [c, d, e], the updated frequencies are {\"a\": 1, \"b\": 1, \"c\": 1, \"d\": 2, \"e\": 1, \"bc\": 2}, and the vocabulary becomes {\"a\", \"b\", \"c\", \"d\", \"e\", \"bc\"}.\n", 174 | "\n", 175 | "2. Every remaining pair (\"ab\", \"bc\" + \"d\", \"cd\", \"de\") now occurs once, so the tie-break picks \"ab\" and we merge it into the subword unit \"ab\".\n", 176 | "   The words become [ab], [bc], [bc, d], [c, d, e], the frequencies are {\"a\": 0, \"b\": 0, \"c\": 1, \"d\": 2, \"e\": 1, \"bc\": 2, \"ab\": 1}, and \"ab\" is added to the vocabulary.\n", 177 | "\n", 
178 | "3. Next, we merge \"bc\" + \"d\" into the subword unit \"bcd\".\n", 179 | "   The words become [ab], [bc], [bcd], [c, d, e], the frequencies are {\"a\": 0, \"b\": 0, \"c\": 1, \"d\": 1, \"e\": 1, \"bc\": 1, \"ab\": 1, \"bcd\": 1}, and \"bcd\" is added to the vocabulary.\n", 180 | "\n", 181 | "4. Then we merge \"c\" + \"d\" (inside \"cde\") into the subword unit \"cd\".\n", 182 | "   The words become [ab], [bc], [bcd], [cd, e], the frequencies are {\"a\": 0, \"b\": 0, \"c\": 0, \"d\": 0, \"e\": 1, \"bc\": 1, \"ab\": 1, \"bcd\": 1, \"cd\": 1}, and \"cd\" is added to the vocabulary.\n", 183 | "\n", 184 | "5. Finally, we merge \"cd\" + \"e\" into the subword unit \"cde\".\n", 185 | "   Every word is now a single token ([ab], [bc], [bcd], [cde]), so there are no pairs left, and the final vocabulary is {\"a\", \"b\", \"c\", \"d\", \"e\", \"bc\", \"ab\", \"bcd\", \"cd\", \"cde\"}.\n", 186 | "\n", 187 | "This technique lets LLMs handle rare and out-of-vocabulary words by breaking them into known subword units. The big difference between tiktoken BPE and sentencepiece BPE is that tiktoken BPE doesn't split a word into smaller parts if the whole word is already in the vocabulary. For example, if \"hugging\" is in the vocabulary, it stays as one token instead of being split into [\"hug\", \"ging\"]."
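The merge loop described above can be sketched in a few lines of plain Python. This is a toy trainer, not tiktoken's or sentencepiece's training code: each word starts as a list of characters, pairs are counted only inside a word, and ties are broken by taking the pair seen first (an assumption of this sketch; real trainers may break ties differently). On the toy corpus "ab", "bc", "bcd", "cde" it learns the merges "bc", "ab", "bcd", "cd", "cde".

```python
from collections import Counter


def train_bpe(words, num_merges):
    """Toy BPE trainer: repeatedly merge the most frequent adjacent symbol pair."""
    corpus = [list(w) for w in words]           # each word starts as characters
    vocab = sorted({ch for w in words for ch in w})
    merges = []
    for _ in range(num_merges):
        # Count adjacent symbol pairs inside each word.
        pairs = Counter()
        for symbols in corpus:
            for left, right in zip(symbols, symbols[1:]):
                pairs[(left, right)] += 1
        if not pairs:
            break  # every word is already a single token
        # max() keeps the first maximal key, so ties go to the pair seen first.
        best = max(pairs, key=pairs.get)
        merged = best[0] + best[1]
        merges.append(best)
        vocab.append(merged)
        # Replace every occurrence of the best pair with the merged symbol.
        new_corpus = []
        for symbols in corpus:
            out, i = [], 0
            while i < len(symbols):
                if i + 1 < len(symbols) and (symbols[i], symbols[i + 1]) == best:
                    out.append(merged)
                    i += 2
                else:
                    out.append(symbols[i])
                    i += 1
            new_corpus.append(out)
        corpus = new_corpus
    return vocab, merges, corpus


vocab, merges, corpus = train_bpe(["ab", "bc", "bcd", "cde"], num_merges=10)
print(merges)  # [('b', 'c'), ('a', 'b'), ('bc', 'd'), ('c', 'd'), ('cd', 'e')]
print(vocab)   # ['a', 'b', 'c', 'd', 'e', 'bc', 'ab', 'bcd', 'cd', 'cde']
print(corpus)  # [['ab'], ['bc'], ['bcd'], ['cde']]
```

Asking for `num_merges=10` simply stops early: after five merges every word is a single token, so there are no pairs left to count.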
188 | ] 189 | }, 190 | { 191 | "cell_type": "markdown", 192 | "metadata": {}, 193 | "source": [ 194 | "## Setting the Stage\n", 195 | "We will be working with a small range of Python libraries, but it's better to install them to avoid encountering \"no module found\" errors." 196 | ] 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 3, 201 | "metadata": { 202 | "execution": { 203 | "iopub.execute_input": "2024-05-27T02:56:41.674943Z", 204 | "iopub.status.busy": "2024-05-27T02:56:41.674503Z", 205 | "iopub.status.idle": "2024-05-27T02:56:56.698775Z", 206 | "shell.execute_reply": "2024-05-27T02:56:56.697490Z", 207 | "shell.execute_reply.started": "2024-05-27T02:56:41.674901Z" 208 | } 209 | }, 210 | "outputs": [ 211 | { 212 | "name": "stdout", 213 | "output_type": "stream", 214 | "text": [ 215 | "Requirement already satisfied: sentencepiece in /opt/conda/lib/python3.10/site-packages (0.2.0)\n", 216 | "Requirement already satisfied: tiktoken in /opt/conda/lib/python3.10/site-packages (0.7.0)\n", 217 | "Requirement already satisfied: torch in /opt/conda/lib/python3.10/site-packages (2.1.2+cpu)\n", 218 | "Requirement already satisfied: blobfile in /opt/conda/lib/python3.10/site-packages (2.1.1)\n", 219 | "Requirement already satisfied: matplotlib in /opt/conda/lib/python3.10/site-packages (3.7.5)\n", 220 | "Requirement already satisfied: huggingface_hub in /opt/conda/lib/python3.10/site-packages (0.22.2)\n", 221 | "Requirement already satisfied: regex>=2022.1.18 in /opt/conda/lib/python3.10/site-packages (from tiktoken) (2023.12.25)\n", 222 | "Requirement already satisfied: requests>=2.26.0 in /opt/conda/lib/python3.10/site-packages (from tiktoken) (2.31.0)\n", 223 | "Requirement already satisfied: filelock in /opt/conda/lib/python3.10/site-packages (from torch) (3.13.1)\n", 224 | "Requirement already satisfied: typing-extensions in /opt/conda/lib/python3.10/site-packages (from torch) (4.9.0)\n", 225 | "Requirement already satisfied: sympy in 
/opt/conda/lib/python3.10/site-packages (from torch) (1.12)\n", 226 | "Requirement already satisfied: networkx in /opt/conda/lib/python3.10/site-packages (from torch) (3.2.1)\n", 227 | "Requirement already satisfied: jinja2 in /opt/conda/lib/python3.10/site-packages (from torch) (3.1.2)\n", 228 | "Requirement already satisfied: fsspec in /opt/conda/lib/python3.10/site-packages (from torch) (2024.2.0)\n", 229 | "Requirement already satisfied: pycryptodomex~=3.8 in /opt/conda/lib/python3.10/site-packages (from blobfile) (3.20.0)\n", 230 | "Requirement already satisfied: urllib3<3,>=1.25.3 in /opt/conda/lib/python3.10/site-packages (from blobfile) (1.26.18)\n", 231 | "Requirement already satisfied: lxml~=4.9 in /opt/conda/lib/python3.10/site-packages (from blobfile) (4.9.4)\n", 232 | "Requirement already satisfied: contourpy>=1.0.1 in /opt/conda/lib/python3.10/site-packages (from matplotlib) (1.2.0)\n", 233 | "Requirement already satisfied: cycler>=0.10 in /opt/conda/lib/python3.10/site-packages (from matplotlib) (0.12.1)\n", 234 | "Requirement already satisfied: fonttools>=4.22.0 in /opt/conda/lib/python3.10/site-packages (from matplotlib) (4.47.0)\n", 235 | "Requirement already satisfied: kiwisolver>=1.0.1 in /opt/conda/lib/python3.10/site-packages (from matplotlib) (1.4.5)\n", 236 | "Requirement already satisfied: numpy<2,>=1.20 in /opt/conda/lib/python3.10/site-packages (from matplotlib) (1.26.4)\n", 237 | "Requirement already satisfied: packaging>=20.0 in /opt/conda/lib/python3.10/site-packages (from matplotlib) (21.3)\n", 238 | "Requirement already satisfied: pillow>=6.2.0 in /opt/conda/lib/python3.10/site-packages (from matplotlib) (9.5.0)\n", 239 | "Requirement already satisfied: pyparsing>=2.3.1 in /opt/conda/lib/python3.10/site-packages (from matplotlib) (3.1.1)\n", 240 | "Requirement already satisfied: python-dateutil>=2.7 in /opt/conda/lib/python3.10/site-packages (from matplotlib) (2.9.0.post0)\n", 241 | "Requirement already satisfied: pyyaml>=5.1 in 
/opt/conda/lib/python3.10/site-packages (from huggingface_hub) (6.0.1)\n", 242 | "Requirement already satisfied: tqdm>=4.42.1 in /opt/conda/lib/python3.10/site-packages (from huggingface_hub) (4.66.1)\n", 243 | "Requirement already satisfied: six>=1.5 in /opt/conda/lib/python3.10/site-packages (from python-dateutil>=2.7->matplotlib) (1.16.0)\n", 244 | "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.10/site-packages (from requests>=2.26.0->tiktoken) (3.3.2)\n", 245 | "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.10/site-packages (from requests>=2.26.0->tiktoken) (3.6)\n", 246 | "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.10/site-packages (from requests>=2.26.0->tiktoken) (2024.2.2)\n", 247 | "Requirement already satisfied: MarkupSafe>=2.0 in /opt/conda/lib/python3.10/site-packages (from jinja2->torch) (2.1.3)\n", 248 | "Requirement already satisfied: mpmath>=0.19 in /opt/conda/lib/python3.10/site-packages (from sympy->torch) (1.3.0)\n" 249 | ] 250 | } 251 | ], 252 | "source": [ 253 | "!pip install sentencepiece tiktoken torch blobfile matplotlib huggingface_hub" 254 | ] 255 | }, 256 | { 257 | "cell_type": "markdown", 258 | "metadata": {}, 259 | "source": [ 260 | "After installing the required libraries, we need to download some files. Since we're going to replicate the architecture of Llama-3-8B, you must have a Hugging Face account. Additionally, since Llama 3 is a gated model, you have to accept Meta's terms and conditions to access the model files.\n", 261 | "\n", 262 | "Here are the steps:\n", 263 | "1. Create a Hugging Face account from this [link](https://huggingface.co/join?next=%2Fmeta-llama%2FMeta-Llama-3-8B)\n", 264 | "2. Accept the terms and conditions of Llama-3-8B from this [link](https://huggingface.co/meta-llama/Meta-Llama-3-8B)\n", 265 | "\n", 266 | "Once you've completed both of these steps, we need to download the model files. 
There are two options to do that:\n", 267 | "\n", 268 | "(Option 1: Manual) Go to the Llama-3-8B HF directory from this [link](https://huggingface.co/meta-llama/Meta-Llama-3-8B/tree/main/original) and manually download each of the three files shown below.\n", 269 | "\n", 270 | "![](https://cdn-images-1.medium.com/max/1000/1*QpaH8EzAEEZsLv_EJ1OsFg.png)" 271 | ] 272 | }, 273 | { 274 | "cell_type": "markdown", 275 | "metadata": {}, 276 | "source": [ 277 | "(Option 2: Coding) We can use the huggingface_hub library, which we installed earlier, to download all of these files. However, first, we need to log in to the Hugging Face Hub within our working notebook using our HF token. You can create a new token or copy an existing one from this [link](https://huggingface.co/settings/tokens)." 278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": 4, 283 | "metadata": { 284 | "execution": { 285 | "iopub.execute_input": "2024-05-27T02:56:56.700767Z", 286 | "iopub.status.busy": "2024-05-27T02:56:56.700441Z", 287 | "iopub.status.idle": "2024-05-27T02:56:56.729869Z", 288 | "shell.execute_reply": "2024-05-27T02:56:56.728713Z", 289 | "shell.execute_reply.started": "2024-05-27T02:56:56.700735Z" 290 | } 291 | }, 292 | "outputs": [ 293 | { 294 | "data": { 295 | "application/vnd.jupyter.widget-view+json": { 296 | "model_id": "0d8605bb48d848de9547fd6a4944dca0", 297 | "version_major": 2, 298 | "version_minor": 0 299 | }, 300 | "text/plain": [ 301 | "VBox(children=(HTML(value='
\", # Marks the beginning of a text sequence.\n", 725 | " \"<|end_of_text|>\", # Marks the end of a text sequence.\n", 726 | " \"<|reserved_special_token_0|>\", # Reserved for future use.\n", 727 | " \"<|reserved_special_token_1|>\", # Reserved for future use.\n", 728 | " \"<|reserved_special_token_2|>\", # Reserved for future use.\n", 729 | " \"<|reserved_special_token_3|>\", # Reserved for future use.\n", 730 | " \"<|start_header_id|>\", # Indicates the start of a header ID.\n", 731 | " \"<|end_header_id|>\", # Indicates the end of a header ID.\n", 732 | " \"<|reserved_special_token_4|>\", # Reserved for future use.\n", 733 | " \"<|eot_id|>\", # Marks the end of a turn (in a conversational context).\n", 734 | "] + [f\"<|reserved_special_token_{i}|>\" for i in range(5, 256 - 5)] # A large set of tokens reserved for future use." 735 | ] 736 | }, 737 | { 738 | "cell_type": "markdown", 739 | "metadata": {}, 740 | "source": [ 741 | "Next we define the rules for splitting text into tokens by specifying different patterns to match various types of substrings in the input text. Here's how we can do that." 
742 | ] 743 | }, 744 | { 745 | "cell_type": "code", 746 | "execution_count": 13, 747 | "metadata": { 748 | "execution": { 749 | "iopub.execute_input": "2024-05-27T03:00:04.757642Z", 750 | "iopub.status.busy": "2024-05-27T03:00:04.757203Z", 751 | "iopub.status.idle": "2024-05-27T03:00:04.767797Z", 752 | "shell.execute_reply": "2024-05-27T03:00:04.766574Z", 753 | "shell.execute_reply.started": "2024-05-27T03:00:04.757607Z" 754 | } 755 | }, 756 | "outputs": [], 757 | "source": [ 758 | "# Regex pattern that splits text into pieces before BPE merges are applied\n", 759 | "tokenize_breaker = r\"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+\"" 760 | ] 761 | }, 762 | { 763 | "cell_type": "markdown", 764 | "metadata": {}, 765 | "source": [ 766 | "This pattern extracts words (optionally with one leading space or punctuation character), contractions, numbers (up to three digits at a time), runs of punctuation, and whitespace from the input text. You can customize it for your own tokenizer, but to reproduce LLaMA 3's tokenization exactly it must stay unchanged.\n", 767 | "Next, we build a simple tokenizer with tiktoken's Encoding class, which takes three inputs: tokenizer_model (the BPE merge ranks), tokenize_breaker, and special_tokens. It will encode and decode our input text."
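The `mergeable_ranks` argument maps byte strings to token IDs, and a lower rank means the merge was learned earlier. As a rough mental model (a simplified sketch, not tiktoken's actual implementation), encoding one pre-token works like this: if the whole piece is already in the vocabulary it is returned as a single token; otherwise we start from single bytes and repeatedly merge the adjacent pair whose concatenation has the lowest rank. The tiny `toy_ranks` table below is hypothetical.

```python
def bpe_encode(piece: bytes, ranks: dict[bytes, int]) -> list[int]:
    """Simplified rank-based BPE encoding of a single pre-token."""
    # Shortcut: a piece that is already a known token is not split.
    if piece in ranks:
        return [ranks[piece]]
    parts = [piece[i:i + 1] for i in range(len(piece))]  # start from single bytes
    while len(parts) > 1:
        # Find the adjacent pair whose merged bytes have the lowest rank.
        best_idx, best_rank = None, None
        for i in range(len(parts) - 1):
            rank = ranks.get(parts[i] + parts[i + 1])
            if rank is not None and (best_rank is None or rank < best_rank):
                best_idx, best_rank = i, rank
        if best_idx is None:
            break  # no mergeable pair left
        parts[best_idx:best_idx + 2] = [parts[best_idx] + parts[best_idx + 1]]
    return [ranks[p] for p in parts]


def bpe_decode(ids: list[int], ranks: dict[bytes, int]) -> bytes:
    # Invert the rank table and concatenate the byte strings.
    inverse = {rank: token for token, rank in ranks.items()}
    return b"".join(inverse[i] for i in ids)


# Hypothetical toy vocabulary: single bytes first, then learned merges.
toy_ranks = {b"h": 0, b"e": 1, b"l": 2, b"o": 3,
             b"ll": 4, b"he": 5, b"llo": 6, b"hello": 7}

print(bpe_encode(b"hello", toy_ranks))  # [7]     whole word is known
print(bpe_encode(b"hell", toy_ranks))   # [5, 4]  "he" + "ll"
print(bpe_decode(bpe_encode(b"olleh", toy_ranks), toy_ranks))  # b'olleh'
```

The shortcut at the top mirrors the whole-word behaviour described in the BPE section; without it, "hello" would still reach the same single token here, but only after several merge steps.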
768 | ] 769 | }, 770 | { 771 | "cell_type": "code", 772 | "execution_count": 14, 773 | "metadata": { 774 | "execution": { 775 | "iopub.execute_input": "2024-05-27T03:00:04.773732Z", 776 | "iopub.status.busy": "2024-05-27T03:00:04.773352Z", 777 | "iopub.status.idle": "2024-05-27T03:00:05.272888Z", 778 | "shell.execute_reply": "2024-05-27T03:00:05.271628Z", 779 | "shell.execute_reply.started": "2024-05-27T03:00:04.773701Z" 780 | } 781 | }, 782 | "outputs": [ 783 | { 784 | "data": { 785 | "text/plain": [ 786 | "'hello world!'" 787 | ] 788 | }, 789 | "execution_count": 14, 790 | "metadata": {}, 791 | "output_type": "execute_result" 792 | } 793 | ], 794 | "source": [ 795 | "# Initialize tokenizer with specified parameters\n", 796 | "tokenizer = tiktoken.Encoding(\n", 797 | "\n", 798 | " # A label for this encoding (tiktoken only uses it as a name; the BPE ranks come from mergeable_ranks)\n", 799 | " name = \"/kaggle/working/llama-3-8B/original/tokenizer.model\",\n", 800 | "\n", 801 | " # Define tokenization pattern string\n", 802 | " pat_str = tokenize_breaker,\n", 803 | "\n", 804 | " # Assign BPE mergeable ranks from tokenizer_model of LLaMA-3\n", 805 | " mergeable_ranks = tokenizer_model,\n", 806 | "\n", 807 | " # Set special tokens with indices\n", 808 | " special_tokens={token: len(tokenizer_model) + i for i, token in enumerate(special_tokens)},\n", 809 | ")\n", 810 | "\n", 811 | "# Encode \"hello world!\" and decode tokens to string\n", 812 | "tokenizer.decode(tokenizer.encode(\"hello world!\"))" 813 | ] 814 | }, 815 | { 816 | "cell_type": "markdown", 817 | "metadata": {}, 818 | "source": [ 819 | "To verify that the tokenizer works correctly, we pass \"hello world!\" into it. First, it encodes the text into token IDs. Then, it decodes those IDs back to text, resulting in \"hello world!\" again. This round trip confirms that the tokenizer is working correctly. Let's tokenize our input."
820 | ] 821 | }, 822 | { 823 | "cell_type": "code", 824 | "execution_count": 15, 825 | "metadata": { 826 | "execution": { 827 | "iopub.execute_input": "2024-05-27T03:00:05.274591Z", 828 | "iopub.status.busy": "2024-05-27T03:00:05.274236Z", 829 | "iopub.status.idle": "2024-05-27T03:00:05.303138Z", 830 | "shell.execute_reply": "2024-05-27T03:00:05.302025Z", 831 | "shell.execute_reply.started": "2024-05-27T03:00:05.274562Z" 832 | } 833 | }, 834 | "outputs": [ 835 | { 836 | "name": "stdout", 837 | "output_type": "stream", 838 | "text": [ 839 | "[128000, 1820, 4320, 311, 279, 17139, 3488, 315, 2324, 11, 279, 15861, 11, 323, 4395, 374, 220]\n", 840 | "['<|begin_of_text|>', 'the', ' answer', ' to', ' the', ' ultimate', ' question', ' of', ' life', ',', ' the', ' universe', ',', ' and', ' everything', ' is', ' ']\n" 841 | ] 842 | } 843 | ], 844 | "source": [ 845 | "# input prompt\n", 846 | "prompt = \"the answer to the ultimate question of life, the universe, and everything is \"\n", 847 | "\n", 848 | "# Encode the prompt and prepend the <|begin_of_text|> special token (ID 128000)\n", 849 | "tokens = [128000] + tokenizer.encode(prompt)\n", 850 | "\n", 851 | "print(tokens) # Print the encoded tokens\n", 852 | "\n", 853 | "# Convert the list of tokens into a PyTorch tensor\n", 854 | "tokens = torch.tensor(tokens)\n", 855 | "\n", 856 | "# Decode each token back into its corresponding string\n", 857 | "prompt_split_as_tokens = [tokenizer.decode([token.item()]) for token in tokens]\n", 858 | "\n", 859 | "print(prompt_split_as_tokens) # Print the decoded tokens" 860 | ] 861 | }, 862 | { 863 | "cell_type": "markdown", 864 | "metadata": {}, 865 | "source": [ 866 | "We encoded our input text \"the answer to the ultimate question of life, the universe, and everything is \" into 17 tokens, starting with the special <|begin_of_text|> token (ID 128000)."
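The ID 128000 is not arbitrary. The tokenizer cell assigns special-token IDs as `len(tokenizer_model) + i`, and the output above shows `<|begin_of_text|>` mapped to 128000, so the base BPE vocabulary has 128,000 entries and the 256 special tokens take IDs 128000 to 128255. The sketch below rebuilds the special-token list in the same order as the `special_tokens` definition to check that arithmetic; `base_vocab_size = 128000` is inferred from the printed output rather than read from the model file.

```python
base_vocab_size = 128000  # inferred from tokens[0] == 128000 in the output above

special_tokens = [
    "<|begin_of_text|>",
    "<|end_of_text|>",
    "<|reserved_special_token_0|>",
    "<|reserved_special_token_1|>",
    "<|reserved_special_token_2|>",
    "<|reserved_special_token_3|>",
    "<|start_header_id|>",
    "<|end_header_id|>",
    "<|reserved_special_token_4|>",
    "<|eot_id|>",
] + [f"<|reserved_special_token_{i}|>" for i in range(5, 256 - 5)]

# Same rule as the tokenizer cell: special tokens come right after the BPE ranks.
special_ids = {tok: base_vocab_size + i for i, tok in enumerate(special_tokens)}

print(len(special_tokens))                     # 256
print(special_ids["<|begin_of_text|>"])        # 128000
print(special_ids["<|eot_id|>"])               # 128009
print(base_vocab_size + len(special_tokens))   # 128256 IDs in total
```

The 10 named tokens plus 246 reserved ones (indices 5 to 250) add up to exactly 256, which is why the total comes out to 128256.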
867 | ] 868 | }, 869 | { 870 | "cell_type": "markdown", 871 | "metadata": {}, 872 | "source": [ 873 | "## Creating Embedding for each Token\n", 874 | "\n", 875 | "If we check the embedding dimension and the length of our input, we get:" 876 | ] 877 | }, 878 | { 879 | "cell_type": "code", 880 | "execution_count": 16, 881 | "metadata": { 882 | "execution": { 883 | "iopub.execute_input": "2024-05-27T03:00:05.305300Z", 884 | "iopub.status.busy": "2024-05-27T03:00:05.304831Z", 885 | "iopub.status.idle": "2024-05-27T03:00:05.311936Z", 886 | "shell.execute_reply": "2024-05-27T03:00:05.310770Z", 887 | "shell.execute_reply.started": "2024-05-27T03:00:05.305260Z" 888 | } 889 | }, 890 | "outputs": [ 891 | { 892 | "name": "stdout", 893 | "output_type": "stream", 894 | "text": [ 895 | "4096 17\n" 896 | ] 897 | } 898 | ], 899 | "source": [ 900 | "# Embedding dimension of the llama-3 architecture and number of tokens in our prompt\n", 901 | "print(dim, len(tokens))" 902 | ] 903 | }, 904 | { 905 | "cell_type": "markdown", 906 | "metadata": {}, 907 | "source": [ 908 | "Our input, which is currently a vector of 17 token IDs (17x1), needs to be transformed into an embedding for each token. This means our (17x1) tokens will become a (17x4096) matrix, where each token has a corresponding embedding of length 4096."
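`torch.nn.Embedding` is essentially a lookup table: row `i` of its weight matrix is the vector for token ID `i`, which is equivalent to multiplying a one-hot vector by that matrix. The pure-Python sketch below shows the idea with a hypothetical 5-token vocabulary and 3-dimensional embeddings instead of LLaMA 3's 128256 x 4096 table.

```python
# Hypothetical toy embedding table: vocab_size=5, dim=3.
embedding_table = [
    [0.1, 0.2, 0.3],
    [0.4, 0.5, 0.6],
    [0.7, 0.8, 0.9],
    [1.0, 1.1, 1.2],
    [1.3, 1.4, 1.5],
]


def one_hot_matmul(token_id, table):
    # one_hot(token_id) @ table: picks out exactly one row of the table.
    one_hot = [1.0 if i == token_id else 0.0 for i in range(len(table))]
    return [sum(one_hot[i] * table[i][j] for i in range(len(table)))
            for j in range(len(table[0]))]


token_ids = [0, 3, 3, 1]  # a 4-token "prompt"

# Lookup: one row per token, so 4 IDs become a 4 x 3 matrix.
embeddings = [embedding_table[t] for t in token_ids]

print(len(embeddings), len(embeddings[0]))  # 4 3
print(all(one_hot_matmul(t, embedding_table) == embedding_table[t] for t in token_ids))  # True
```

The lookup form is what embedding layers actually do in practice, since materializing a one-hot vector of length 128256 per token would be wasteful.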
909 | ] 910 | }, 911 | { 912 | "cell_type": "code", 913 | "execution_count": 17, 914 | "metadata": { 915 | "execution": { 916 | "iopub.execute_input": "2024-05-27T03:00:05.314096Z", 917 | "iopub.status.busy": "2024-05-27T03:00:05.313743Z", 918 | "iopub.status.idle": "2024-05-27T03:00:11.088133Z", 919 | "shell.execute_reply": "2024-05-27T03:00:11.086993Z", 920 | "shell.execute_reply.started": "2024-05-27T03:00:05.314067Z" 921 | } 922 | }, 923 | "outputs": [ 924 | { 925 | "data": { 926 | "text/plain": [ 927 | "torch.Size([17, 4096])" 928 | ] 929 | }, 930 | "execution_count": 17, 931 | "metadata": {}, 932 | "output_type": "execute_result" 933 | } 934 | ], 935 | "source": [ 936 | "# Define embedding layer with vocab size and embedding dimension\n", 937 | "embedding_layer = torch.nn.Embedding(vocab_size, dim)\n", 938 | "\n", 939 | "# Copy pre-trained token embeddings to the embedding layer\n", 940 | "embedding_layer.weight.data.copy_(model[\"tok_embeddings.weight\"])\n", 941 | "\n", 942 | "# Get token embeddings for given tokens, converting to torch.bfloat16 format\n", 943 | "token_embeddings_unnormalized = embedding_layer(tokens).to(torch.bfloat16)\n", 944 | "\n", 945 | "# Print shape of resulting token embeddings\n", 946 | "token_embeddings_unnormalized.shape" 947 | ] 948 | }, 949 | { 950 | "cell_type": "markdown", 951 | "metadata": {}, 952 | "source": [ 953 | "These embeddings are not normalized yet. LLaMA 3 normalizes the input of every attention and feed-forward sub-layer with RMSNorm, and skipping this step would feed the pretrained weights values at a different scale than the ones they were trained on. In the next section, we will normalize our input vectors."
954 | ] 955 | }, 956 | { 957 | "cell_type": "markdown", 958 | "metadata": { 959 | "execution": { 960 | "iopub.status.busy": "2024-05-26T17:50:03.647697Z", 961 | "iopub.status.idle": "2024-05-26T17:50:03.648075Z", 962 | "shell.execute_reply": "2024-05-26T17:50:03.647918Z", 963 | "shell.execute_reply.started": "2024-05-26T17:50:03.647903Z" 964 | } 965 | }, 966 | "source": [ 967 | "## Normalization Using RMSNorm\n", 968 | "We will normalize the input vectors using the RMSNorm formula we saw earlier.\n", 969 | "\n", 970 | "![RMSNorm formula](https://cdn-images-1.medium.com/max/1000/0*GIr8bvByN_iAGQBW.png)" 971 | ] 972 | }, 973 | { 974 | "cell_type": "code", 975 | "execution_count": 18, 976 | "metadata": { 977 | "execution": { 978 | "iopub.execute_input": "2024-05-27T03:00:11.090735Z", 979 | "iopub.status.busy": "2024-05-27T03:00:11.089634Z", 980 | "iopub.status.idle": "2024-05-27T03:00:11.098194Z", 981 | "shell.execute_reply": "2024-05-27T03:00:11.096941Z", 982 | "shell.execute_reply.started": "2024-05-27T03:00:11.090683Z" 983 | } 984 | }, 985 | "outputs": [], 986 | "source": [ 987 | "# Calculating RMSNorm\n", 988 | "def rms_norm(tensor, norm_weights):\n", 989 | "\n", 990 | " # Calculate the mean of the square of tensor values along the last dimension\n", 991 | " squared_mean = tensor.pow(2).mean(-1, keepdim=True)\n", 992 | " \n", 993 | " # Reciprocal of the root mean square; norm_eps avoids division by zero\n", 994 | " inv_rms = torch.rsqrt(squared_mean + norm_eps)\n", 995 | " \n", 996 | " # Scale each vector to unit RMS, then apply the learned per-dimension weights\n", 997 | " return (tensor * inv_rms) * norm_weights" 998 | ] 999 | }, 1000 | { 1001 | "cell_type": "markdown", 1002 | "metadata": {}, 1003 | "source": [ 1004 | "We will use the normalization weights of layer 0's attention block (layers.0.attention_norm.weight) to normalize our embeddings. We use layer 0 because we are now building the first layer of our LLaMA-3 transformer architecture."
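To make the formula concrete, here is the same computation on a toy 2-dimensional vector in plain Python: divide each element by the root mean square of the vector, then multiply by a learned per-dimension weight. The values `x = [3.0, 4.0]`, the unit weights, and `eps = 1e-5` are illustrative choices for this sketch, not parameters read from the model.

```python
import math


def rms_norm_list(x, weights, eps=1e-5):
    # Root mean square of the vector (eps keeps it away from zero).
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    # Scale each element by 1/rms, then by its learned weight.
    return [(v / rms) * w for v, w in zip(x, weights)]


x = [3.0, 4.0]
# mean of squares = (9 + 16) / 2 = 12.5, so rms is about 3.5355
out = rms_norm_list(x, [1.0, 1.0])
print([round(v, 4) for v in out])  # [0.8485, 1.1314]

# With unit weights, the output has an RMS of (almost exactly) 1.
print(round(math.sqrt(sum(v * v for v in out) / len(out)), 4))  # 1.0
```

Unlike LayerNorm, nothing is subtracted here: RMSNorm only rescales, so the direction of each vector is preserved and only its magnitude (plus the per-dimension weighting) changes.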
1005 | ] 1006 | }, 1007 | { 1008 | "cell_type": "code", 1009 | "execution_count": 19, 1010 | "metadata": { 1011 | "execution": { 1012 | "iopub.execute_input": "2024-05-27T03:00:11.101006Z", 1013 | "iopub.status.busy": "2024-05-27T03:00:11.100196Z", 1014 | "iopub.status.idle": "2024-05-27T03:00:11.163803Z", 1015 | "shell.execute_reply": "2024-05-27T03:00:11.162310Z", 1016 | "shell.execute_reply.started": "2024-05-27T03:00:11.100974Z" 1017 | } 1018 | }, 1019 | "outputs": [ 1020 | { 1021 | "data": { 1022 | "text/plain": [ 1023 | "torch.Size([17, 4096])" 1024 | ] 1025 | }, 1026 | "execution_count": 19, 1027 | "metadata": {}, 1028 | "output_type": "execute_result" 1029 | } 1030 | ], 1031 | "source": [ 1032 | "# Apply RMSNorm using layer 0's attention-norm weights\n", 1033 | "token_embeddings = rms_norm(token_embeddings_unnormalized, \n", 1034 | " model[\"layers.0.attention_norm.weight\"])\n", 1035 | "\n", 1036 | "# Print the shape of the resulting token embeddings\n", 1037 | "token_embeddings.shape" 1038 | ] 1039 | }, 1040 | { 1041 | "cell_type": "markdown", 1042 | "metadata": {}, 1043 | "source": [ 1044 | "The shape stays [17, 4096] because normalization only rescales the vectors; it doesn't change their dimension." 1045 | ] 1046 | }, 1047 | { 1048 | "cell_type": "markdown", 1049 | "metadata": {}, 1050 | "source": [ 1051 | "## Attention Heads (Query, Key, Values)\n", 1052 | "First, let's load the query, key, value, and output weight matrices from the model."
1053 | ] 1054 | }, 1055 | { 1056 | "cell_type": "code", 1057 | "execution_count": 20, 1058 | "metadata": { 1059 | "execution": { 1060 | "iopub.execute_input": "2024-05-27T03:00:11.166431Z", 1061 | "iopub.status.busy": "2024-05-27T03:00:11.165746Z", 1062 | "iopub.status.idle": "2024-05-27T03:00:11.172693Z", 1063 | "shell.execute_reply": "2024-05-27T03:00:11.171756Z", 1064 | "shell.execute_reply.started": "2024-05-27T03:00:11.166398Z" 1065 | } 1066 | }, 1067 | "outputs": [ 1068 | { 1069 | "name": "stdout", 1070 | "output_type": "stream", 1071 | "text": [ 1072 | "torch.Size([4096, 4096]) torch.Size([1024, 4096]) torch.Size([1024, 4096]) torch.Size([4096, 4096])\n" 1073 | ] 1074 | } 1075 | ], 1076 | "source": [ 1077 | "# Print the shapes of different weights\n", 1078 | "print(\n", 1079 | " # Query weight shape\n", 1080 | " model[\"layers.0.attention.wq.weight\"].shape,\n", 1081 | " \n", 1082 | " # Key weight shape\n", 1083 | " model[\"layers.0.attention.wk.weight\"].shape,\n", 1084 | " \n", 1085 | " # Value weight shape\n", 1086 | " model[\"layers.0.attention.wv.weight\"].shape,\n", 1087 | " \n", 1088 | " # Output weight shape\n", 1089 | " model[\"layers.0.attention.wo.weight\"].shape\n", 1090 | ")" 1091 | ] 1092 | }, 1093 | { 1094 | "cell_type": "markdown", 1095 | "metadata": {}, 1096 | "source": [ 1097 | "The shapes show that the downloaded weights are not stored per head: the matrices of all heads are stacked into one so they can be computed in parallel. The query matrix has 4096 rows (32 heads x 128), while the key and value matrices have only 1024 rows (8 heads x 128), because LLaMA 3 uses grouped-query attention, in which several query heads share one key/value head. We can unwrap these matrices to get the weights of a single head."
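Reshaping with `view(n_heads, head_dim, dim)` reinterprets the 4096 rows of `wq` as 32 consecutive blocks of 128 rows without copying anything, so head `h` owns rows `h*128` to `h*128 + 127`. The same bookkeeping explains the key/value matrices: 1024 rows give 1024 / 128 = 8 key/value heads, and each is shared by 32 / 8 = 4 consecutive query heads. A pure-Python sketch of both index calculations (the helper names are hypothetical):

```python
n_heads, n_kv_heads, head_dim = 32, 8, 128


def rows_of_head(h, head_dim=head_dim):
    # Row range of the stacked (n_heads * head_dim, dim) weight owned by head h.
    start = h * head_dim
    return range(start, start + head_dim)


def kv_head_for_query_head(h, n_heads=n_heads, n_kv_heads=n_kv_heads):
    # Consecutive groups of n_heads // n_kv_heads query heads share one KV head.
    return h // (n_heads // n_kv_heads)


print(rows_of_head(0), rows_of_head(31))               # range(0, 128) range(3968, 4096)
print(1024 // head_dim)                                # 8 key/value heads
print([kv_head_for_query_head(h) for h in range(8)])   # [0, 0, 0, 0, 1, 1, 1, 1]
```

So when computing attention for query head `h`, the matching key and value weights come from KV head `h // 4`.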
1098 | ] 1099 | }, 1100 | { 1101 | "cell_type": "code", 1102 | "execution_count": 21, 1103 | "metadata": { 1104 | "execution": { 1105 | "iopub.execute_input": "2024-05-27T03:00:11.175268Z", 1106 | "iopub.status.busy": "2024-05-27T03:00:11.174335Z", 1107 | "iopub.status.idle": "2024-05-27T03:00:11.188207Z", 1108 | "shell.execute_reply": "2024-05-27T03:00:11.186907Z", 1109 | "shell.execute_reply.started": "2024-05-27T03:00:11.175208Z" 1110 | } 1111 | }, 1112 | "outputs": [ 1113 | { 1114 | "data": { 1115 | "text/plain": [ 1116 | "torch.Size([32, 128, 4096])" 1117 | ] 1118 | }, 1119 | "execution_count": 21, 1120 | "metadata": {}, 1121 | "output_type": "execute_result" 1122 | } 1123 | ], 1124 | "source": [ 1125 | "# Retrieve query weight for the first layer of attention\n", 1126 | "q_layer0 = model[\"layers.0.attention.wq.weight\"]\n", 1127 | "\n", 1128 | "# Calculate dimension per head\n", 1129 | "head_dim = q_layer0.shape[0] // n_heads\n", 1130 | "\n", 1131 | "# Reshape query weight to separate heads\n", 1132 | "q_layer0 = q_layer0.view(n_heads, head_dim, dim)\n", 1133 | "\n", 1134 | "# Print the shape of the reshaped query weight tensor\n", 1135 | "q_layer0.shape" 1136 | ] 1137 | }, 1138 | { 1139 | "cell_type": "markdown", 1140 | "metadata": {}, 1141 | "source": [ 1142 | "Here, 32 is the number of attention heads in Llama-3, 128 is the size of each head's query vector, and 4096 is the size of the token embedding.\n", 1143 | "We can access the query weight matrix of the first head of the first layer using:" 1144 | ] 1145 | }, 1146 | { 1147 | "cell_type": "code", 1148 | "execution_count": 22, 1149 | "metadata": { 1150 | "execution": { 1151 | "iopub.execute_input": "2024-05-27T03:00:11.190921Z", 1152 | "iopub.status.busy": "2024-05-27T03:00:11.189926Z", 1153 | "iopub.status.idle": "2024-05-27T03:00:11.201264Z", 1154 | "shell.execute_reply": "2024-05-27T03:00:11.199956Z", 1155 | "shell.execute_reply.started": "2024-05-27T03:00:11.190869Z" 1156 | } 1157 | }, 1158 | "outputs": [
1159 | { 1160 | "data": { 1161 | "text/plain": [ 1162 | "torch.Size([128, 4096])" 1163 | ] 1164 | }, 1165 | "execution_count": 22, 1166 | "metadata": {}, 1167 | "output_type": "execute_result" 1168 | } 1169 | ], 1170 | "source": [ 1171 | "# Extract the query weight for the first head of the first layer of attention\n", 1172 | "q_layer0_head0 = q_layer0[0]\n", 1173 | "\n", 1174 | "# Print the shape of the extracted query weight tensor for the first head\n", 1175 | "q_layer0_head0.shape" 1176 | ] 1177 | }, 1178 | { 1179 | "cell_type": "markdown", 1180 | "metadata": {}, 1181 | "source": [ 1182 | "To find the query vector for each token, we multiply the token embeddings by the transpose of the head's query weight matrix, which gives one 128-dimensional query per token." 1183 | ] 1184 | }, 1185 | { 1186 | "cell_type": "code", 1187 | "execution_count": 23, 1188 | "metadata": { 1189 | "execution": { 1190 | "iopub.execute_input": "2024-05-27T03:00:11.203858Z", 1191 | "iopub.status.busy": "2024-05-27T03:00:11.202973Z", 1192 | "iopub.status.idle": "2024-05-27T03:00:11.238942Z", 1193 | "shell.execute_reply": "2024-05-27T03:00:11.237663Z", 1194 | "shell.execute_reply.started": "2024-05-27T03:00:11.203821Z" 1195 | } 1196 | }, 1197 | "outputs": [ 1198 | { 1199 | "data": { 1200 | "text/plain": [ 1201 | "torch.Size([17, 128])" 1202 | ] 1203 | }, 1204 | "execution_count": 23, 1205 | "metadata": {}, 1206 | "output_type": "execute_result" 1207 | } 1208 | ], 1209 | "source": [ 1210 | "# Matrix multiplication: token embeddings with transpose of query weight for first head\n", 1211 | "q_per_token = torch.matmul(token_embeddings, q_layer0_head0.T)\n", 1212 | "\n", 1213 | "# Shape of resulting tensor: queries per token\n", 1214 | "q_per_token.shape" 1215 | ] 1216 | }, 1217 | { 1218 | "cell_type": "markdown", 1219 | "metadata": {}, 1220 | "source": [ 1221 | "The query vectors don't inherently know their position in the prompt, so we'll use RoPE to make them aware of it."
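For reference, this is the rotation RoPE applies, written with the conventions used in this notebook (adjacent components form a pair, 64 frequencies per 128-dimensional head, and `rope_theta` is the base loaded from the model parameters). Treating pair $i$ of a query at position $m$ as the complex number $x_{2i} + \mathrm{i}\,x_{2i+1}$, RoPE multiplies it by $e^{\mathrm{i} m\theta_i}$:

$$
\theta_i = \text{rope\_theta}^{-i/64}, \qquad i = 0, 1, \dots, 63
$$

$$
(x_{2i} + \mathrm{i}\,x_{2i+1})\,e^{\mathrm{i}\,m\theta_i} = \left(x_{2i}\cos m\theta_i - x_{2i+1}\sin m\theta_i\right) + \mathrm{i}\left(x_{2i}\sin m\theta_i + x_{2i+1}\cos m\theta_i\right)
$$

Keys are rotated the same way, and the 2D dot product of two pairs is the real part of one complex number times the conjugate of the other. For a query pair $q$ at position $m$ and a key pair $k$ at position $n$:

$$
\operatorname{Re}\!\left[\left(q\,e^{\mathrm{i}m\theta_i}\right)\overline{\left(k\,e^{\mathrm{i}n\theta_i}\right)}\right] = \operatorname{Re}\!\left[q\,\bar{k}\,e^{\mathrm{i}(m-n)\theta_i}\right]
$$

so each pair's contribution to the attention score depends on the two positions only through the offset $m - n$.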
1222 | ] 1223 | }, 1224 | { 1225 | "cell_type": "markdown", 1226 | "metadata": { 1227 | "execution": { 1228 | "iopub.execute_input": "2024-05-26T18:02:35.200388Z", 1229 | "iopub.status.busy": "2024-05-26T18:02:35.199990Z", 1230 | "iopub.status.idle": "2024-05-26T18:02:35.207049Z", 1231 | "shell.execute_reply": "2024-05-26T18:02:35.205839Z", 1232 | "shell.execute_reply.started": "2024-05-26T18:02:35.200358Z" 1233 | } 1234 | }, 1235 | "source": [ 1236 | "## Implementing RoPE\n", 1237 | "\n", 1238 | "We split the query vectors into pairs and then apply a rotational angle shift to each pair." 1239 | ] 1240 | }, 1241 | { 1242 | "cell_type": "code", 1243 | "execution_count": 24, 1244 | "metadata": { 1245 | "execution": { 1246 | "iopub.execute_input": "2024-05-27T03:00:11.241095Z", 1247 | "iopub.status.busy": "2024-05-27T03:00:11.240650Z", 1248 | "iopub.status.idle": "2024-05-27T03:00:11.251282Z", 1249 | "shell.execute_reply": "2024-05-27T03:00:11.249919Z", 1250 | "shell.execute_reply.started": "2024-05-27T03:00:11.241053Z" 1251 | } 1252 | }, 1253 | "outputs": [ 1254 | { 1255 | "data": { 1256 | "text/plain": [ 1257 | "torch.Size([17, 64, 2])" 1258 | ] 1259 | }, 1260 | "execution_count": 24, 1261 | "metadata": {}, 1262 | "output_type": "execute_result" 1263 | } 1264 | ], 1265 | "source": [ 1266 | "# Convert queries per token to float and split into pairs\n", 1267 | "q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)\n", 1268 | "\n", 1269 | "# Print the shape of the resulting tensor after splitting into pairs\n", 1270 | "q_per_token_split_into_pairs.shape" 1271 | ] 1272 | }, 1273 | { 1274 | "cell_type": "markdown", 1275 | "metadata": {}, 1276 | "source": [ 1277 | "We now have a tensor of shape [17, 64, 2], which represents the 128-length queries split into 64 pairs for each token in the prompt. 
Each pair will be rotated by m*theta, where m is the position of the token whose query we are rotating and theta is the frequency assigned to that pair.\n", 1278 | "We'll rotate each pair by treating it as a complex number and multiplying it by a unit complex number; complex multiplication performs the rotation." 1279 | ] 1280 | }, 1281 | { 1282 | "cell_type": "code", 1283 | "execution_count": 25, 1284 | "metadata": { 1285 | "execution": { 1286 | "iopub.execute_input": "2024-05-27T03:00:11.253499Z", 1287 | "iopub.status.busy": "2024-05-27T03:00:11.252995Z", 1288 | "iopub.status.idle": "2024-05-27T03:00:11.296362Z", 1289 | "shell.execute_reply": "2024-05-27T03:00:11.294937Z", 1290 | "shell.execute_reply.started": "2024-05-27T03:00:11.253449Z" 1291 | } 1292 | }, 1293 | "outputs": [ 1294 | { 1295 | "data": { 1296 | "text/plain": [ 1297 | "tensor([0.0000, 0.0156, 0.0312, 0.0469, 0.0625, 0.0781, 0.0938, 0.1094, 0.1250,\n", 1298 | " 0.1406, 0.1562, 0.1719, 0.1875, 0.2031, 0.2188, 0.2344, 0.2500, 0.2656,\n", 1299 | " 0.2812, 0.2969, 0.3125, 0.3281, 0.3438, 0.3594, 0.3750, 0.3906, 0.4062,\n", 1300 | " 0.4219, 0.4375, 0.4531, 0.4688, 0.4844, 0.5000, 0.5156, 0.5312, 0.5469,\n", 1301 | " 0.5625, 0.5781, 0.5938, 0.6094, 0.6250, 0.6406, 0.6562, 0.6719, 0.6875,\n", 1302 | " 0.7031, 0.7188, 0.7344, 0.7500, 0.7656, 0.7812, 0.7969, 0.8125, 0.8281,\n", 1303 | " 0.8438, 0.8594, 0.8750, 0.8906, 0.9062, 0.9219, 0.9375, 0.9531, 0.9688,\n", 1304 | " 0.9844])" 1305 | ] 1306 | }, 1307 | "execution_count": 25, 1308 | "metadata": {}, 1309 | "output_type": "execute_result" 1310 | } 1311 | ], 1312 | "source": [ 1313 | "# Generate values from 0 to 1 split into 64 parts\n", 1314 | "zero_to_one_split_into_64_parts = torch.tensor(range(64))/64\n", 1315 | "\n", 1316 | "# Print the resulting tensor\n", 1317 | "zero_to_one_split_into_64_parts" 1318 | ] 1319 | }, 1320 | { 1321 | "cell_type": "markdown", 1322 | "metadata": {}, 1323 | "source": [ 1324 | "Next, we use these values to calculate a rotation frequency for each of the 64 pairs."
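As a rough cross-check of the next cell, here is a plain-Python sketch of the same frequency formula: freq_i = 1 / rope_theta^(i/64) for i = 0..63. It assumes rope_theta = 500000, which is the value that reproduces the frequencies printed below (for example 8.1462e-01 for i = 1). No PyTorch is needed.

```python
import math

# Assumption: rope_theta = 500000, which matches the frequencies printed in this notebook
rope_theta = 500000.0
n_pairs = 64  # a 128-dimensional head split into 64 pairs

# Same exponents as zero_to_one_split_into_64_parts: 0/64, 1/64, ..., 63/64
exponents = [i / n_pairs for i in range(n_pairs)]

# One rotation frequency per pair: early pairs rotate fast, later pairs rotate slowly
freqs = [1.0 / (rope_theta ** e) for e in exponents]

def rotation_angle(position, pair_index):
    # Angle (in radians) applied to pair `pair_index` of the token at `position`
    return position * freqs[pair_index]

print(round(freqs[1], 4))   # 0.8146
print(f'{freqs[32]:.4e}')   # 1.4142e-03
```

The geometric spacing means the first pairs encode fine-grained, short-range position information. The last pairs change very slowly across the prompt.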
1325 | ] 1326 | }, 1327 | { 1328 | "cell_type": "code", 1329 | "execution_count": 26, 1330 | "metadata": { 1331 | "execution": { 1332 | "iopub.execute_input": "2024-05-27T03:00:11.307426Z", 1333 | "iopub.status.busy": "2024-05-27T03:00:11.306558Z", 1334 | "iopub.status.idle": "2024-05-27T03:00:11.319484Z", 1335 | "shell.execute_reply": "2024-05-27T03:00:11.318071Z", 1336 | "shell.execute_reply.started": "2024-05-27T03:00:11.307379Z" 1337 | } 1338 | }, 1339 | "outputs": [ 1340 | { 1341 | "data": { 1342 | "text/plain": [ 1343 | "tensor([1.0000e+00, 8.1462e-01, 6.6360e-01, 5.4058e-01, 4.4037e-01, 3.5873e-01,\n", 1344 | " 2.9223e-01, 2.3805e-01, 1.9392e-01, 1.5797e-01, 1.2869e-01, 1.0483e-01,\n", 1345 | " 8.5397e-02, 6.9566e-02, 5.6670e-02, 4.6164e-02, 3.7606e-02, 3.0635e-02,\n", 1346 | " 2.4955e-02, 2.0329e-02, 1.6560e-02, 1.3490e-02, 1.0990e-02, 8.9523e-03,\n", 1347 | " 7.2927e-03, 5.9407e-03, 4.8394e-03, 3.9423e-03, 3.2114e-03, 2.6161e-03,\n", 1348 | " 2.1311e-03, 1.7360e-03, 1.4142e-03, 1.1520e-03, 9.3847e-04, 7.6450e-04,\n", 1349 | " 6.2277e-04, 5.0732e-04, 4.1327e-04, 3.3666e-04, 2.7425e-04, 2.2341e-04,\n", 1350 | " 1.8199e-04, 1.4825e-04, 1.2077e-04, 9.8381e-05, 8.0143e-05, 6.5286e-05,\n", 1351 | " 5.3183e-05, 4.3324e-05, 3.5292e-05, 2.8750e-05, 2.3420e-05, 1.9078e-05,\n", 1352 | " 1.5542e-05, 1.2660e-05, 1.0313e-05, 8.4015e-06, 6.8440e-06, 5.5752e-06,\n", 1353 | " 4.5417e-06, 3.6997e-06, 3.0139e-06, 2.4551e-06])" 1354 | ] 1355 | }, 1356 | "execution_count": 26, 1357 | "metadata": {}, 1358 | "output_type": "execute_result" 1359 | } 1360 | ], 1361 | "source": [ 1362 | "# Calculate frequencies using a power operation\n", 1363 | "freqs = 1.0 / (rope_theta ** zero_to_one_split_into_64_parts)\n", 1364 | "\n", 1365 | "# Display the resulting frequencies\n", 1366 | "freqs" 1367 | ] 1368 | }, 1369 | { 1370 | "cell_type": "markdown", 1371 | "metadata": {}, 1372 | "source": [ 1373 | "Now, with a complex number for each token's query element, we convert our queries into 
complex numbers and then rotate them according to their position by multiplying each one with a unit complex number." 1374 | ] 1375 | }, 1376 | { 1377 | "cell_type": "code", 1378 | "execution_count": 27, 1379 | "metadata": { 1380 | "execution": { 1381 | "iopub.execute_input": "2024-05-27T03:00:11.321968Z", 1382 | "iopub.status.busy": "2024-05-27T03:00:11.321278Z", 1383 | "iopub.status.idle": "2024-05-27T03:00:11.344481Z", 1384 | "shell.execute_reply": "2024-05-27T03:00:11.343316Z", 1385 | "shell.execute_reply.started": "2024-05-27T03:00:11.321926Z" 1386 | } 1387 | }, 1388 | "outputs": [ 1389 | { 1390 | "data": { 1391 | "text/plain": [ 1392 | "torch.Size([17, 64])" 1393 | ] 1394 | }, 1395 | "execution_count": 27, 1396 | "metadata": {}, 1397 | "output_type": "execute_result" 1398 | } 1399 | ], 1400 | "source": [ 1401 | "# Convert queries per token to complex numbers\n", 1402 | "q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)\n", 1403 | "\n", 1404 | "q_per_token_as_complex_numbers.shape\n", 1405 | "# Output: torch.Size([17, 64])\n", 1406 | "\n", 1407 | "# Rotation angle for each (token position, pair): outer product of positions 0..16 and freqs\n", 1408 | "freqs_for_each_token = torch.outer(torch.arange(17), freqs)\n", 1409 | "\n", 1410 | "# Convert the angles in freqs_for_each_token into unit complex numbers using polar coordinates\n", 1411 | "freqs_cis = torch.polar(torch.ones_like(freqs_for_each_token), freqs_for_each_token)\n", 1412 | "\n", 1413 | "# Rotate the complex query pairs by complex multiplication\n", 1414 | "q_per_token_as_complex_numbers_rotated = q_per_token_as_complex_numbers * freqs_cis\n", 1415 | "\n", 1416 | "q_per_token_as_complex_numbers_rotated.shape\n", 1417 | "# Output: torch.Size([17, 64])" 1418 | ] 1419 | }, 1420 | { 1421 | "cell_type": "markdown", 1422 | "metadata": {}, 1423 | "source": [ 1424 | "After obtaining the rotated vectors, we convert them back into real-valued pairs by viewing the complex numbers as real numbers again."
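Why does multiplying by freqs_cis rotate a pair? torch.polar(1, angle) builds the unit complex number cos(angle) + i·sin(angle). Multiplying a pair (x, y), viewed as x + iy, by that number gives the same result as applying a 2-D rotation matrix to (x, y). The plain-Python sketch below checks this. The pair values, position, and frequency are made up for illustration.

```python
import cmath
import math

def rotate_with_complex(x, y, angle):
    # View (x, y) as x + iy and multiply by the unit complex number e^(i*angle)
    # (the kind of value torch.polar(ones, angle) produces for each pair)
    z = complex(x, y) * cmath.exp(1j * angle)
    return z.real, z.imag

def rotate_with_matrix(x, y, angle):
    # The same rotation written with an explicit 2x2 rotation matrix
    c, s = math.cos(angle), math.sin(angle)
    return c * x - s * y, s * x + c * y

# Hypothetical query pair, token position m and pair frequency theta
x, y = 0.75, -1.25
m, theta = 5, 0.8146
angle = m * theta

print(rotate_with_complex(x, y, angle))
print(rotate_with_matrix(x, y, angle))  # same values up to floating-point rounding
```

The rotation only changes the angle of each pair, so the length of each pair is preserved. Position is encoded purely in the pair's direction.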
1425 | ] 1426 | }, 1427 | { 1428 | "cell_type": "code", 1429 | "execution_count": 28, 1430 | "metadata": { 1431 | "execution": { 1432 | "iopub.execute_input": "2024-05-27T03:00:11.346345Z", 1433 | "iopub.status.busy": "2024-05-27T03:00:11.345950Z", 1434 | "iopub.status.idle": "2024-05-27T03:00:11.355632Z", 1435 | "shell.execute_reply": "2024-05-27T03:00:11.354135Z", 1436 | "shell.execute_reply.started": "2024-05-27T03:00:11.346316Z" 1437 | } 1438 | }, 1439 | "outputs": [ 1440 | { 1441 | "data": { 1442 | "text/plain": [ 1443 | "torch.Size([17, 64, 2])" 1444 | ] 1445 | }, 1446 | "execution_count": 28, 1447 | "metadata": {}, 1448 | "output_type": "execute_result" 1449 | } 1450 | ], 1451 | "source": [ 1452 | "# Convert rotated complex numbers back to real numbers\n", 1453 | "q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers_rotated)\n", 1454 | "\n", 1455 | "# Print the shape of the resulting tensor\n", 1456 | "q_per_token_split_into_pairs_rotated.shape" 1457 | ] 1458 | }, 1459 | { 1460 | "cell_type": "markdown", 1461 | "metadata": {}, 1462 | "source": [ 1463 | "The rotated pairs are now merged, resulting in a new query vector (rotated query vector) that has the shape [17x128], where 17 is the number of tokens and 128 is the dimension of the query vector." 
1464 | ] 1465 | }, 1466 | { 1467 | "cell_type": "code", 1468 | "execution_count": 29, 1469 | "metadata": { 1470 | "execution": { 1471 | "iopub.execute_input": "2024-05-27T03:00:11.357857Z", 1472 | "iopub.status.busy": "2024-05-27T03:00:11.357408Z", 1473 | "iopub.status.idle": "2024-05-27T03:00:11.367948Z", 1474 | "shell.execute_reply": "2024-05-27T03:00:11.366558Z", 1475 | "shell.execute_reply.started": "2024-05-27T03:00:11.357811Z" 1476 | } 1477 | }, 1478 | "outputs": [ 1479 | { 1480 | "data": { 1481 | "text/plain": [ 1482 | "torch.Size([17, 128])" 1483 | ] 1484 | }, 1485 | "execution_count": 29, 1486 | "metadata": {}, 1487 | "output_type": "execute_result" 1488 | } 1489 | ], 1490 | "source": [ 1491 | "# Reshape rotated token queries to match the original shape\n", 1492 | "q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)\n", 1493 | "\n", 1494 | "# Print the shape of the resulting tensor\n", 1495 | "q_per_token_rotated.shape" 1496 | ] 1497 | }, 1498 | { 1499 | "cell_type": "markdown", 1500 | "metadata": {}, 1501 | "source": [ 1502 | "For keys, the process is similar. Key vectors are also 128-dimensional, but the key weight matrix has only a quarter as many weights as the query weight matrix, because each key head is shared by 4 query heads to reduce computation. Keys are also rotated with RoPE to include positional information, just like queries."
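The key-sharing arithmetic can be checked in a few lines. The sketch assumes 32 query heads and 8 key/value heads, the counts implied by the [8x128x4096] key weights here and the 32-head loop later. With those counts, each key/value head serves a group of 4 consecutive query heads, which is why the multi-head loop indexes keys with head//4. The code below is plain-Python bookkeeping only.

```python
# Assumed LLaMA 3 8B attention sizes, consistent with the shapes printed in this notebook
n_heads = 32      # query heads
n_kv_heads = 8    # key/value heads
head_dim = 128
dim = n_heads * head_dim  # 4096 model dimension

group_size = n_heads // n_kv_heads  # query heads per key/value head

# Which key/value head each query head reads from
kv_head_for_query_head = [head // group_size for head in range(n_heads)]

# Rows in the projection matrices: wq has 4096 rows, wk and wv have 1024 rows
q_weight_rows = n_heads * head_dim
kv_weight_rows = n_kv_heads * head_dim

print(group_size)                       # 4
print(kv_head_for_query_head[:8])       # [0, 0, 0, 0, 1, 1, 1, 1]
print(q_weight_rows // kv_weight_rows)  # 4
```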
1503 | ] 1504 | }, 1505 | { 1506 | "cell_type": "code", 1507 | "execution_count": 30, 1508 | "metadata": { 1509 | "execution": { 1510 | "iopub.execute_input": "2024-05-27T03:00:11.370253Z", 1511 | "iopub.status.busy": "2024-05-27T03:00:11.369660Z", 1512 | "iopub.status.idle": "2024-05-27T03:00:11.395550Z", 1513 | "shell.execute_reply": "2024-05-27T03:00:11.394310Z", 1514 | "shell.execute_reply.started": "2024-05-27T03:00:11.370190Z" 1515 | } 1516 | }, 1517 | "outputs": [ 1518 | { 1519 | "data": { 1520 | "text/plain": [ 1521 | "torch.Size([17, 128])" 1522 | ] 1523 | }, 1524 | "execution_count": 30, 1525 | "metadata": {}, 1526 | "output_type": "execute_result" 1527 | } 1528 | ], 1529 | "source": [ 1530 | "# Extract the weight tensor for the attention mechanism's key in the first layer of the model\n", 1531 | "k_layer0 = model[\"layers.0.attention.wk.weight\"]\n", 1532 | "\n", 1533 | "# Reshape key weight for the first layer of attention to separate heads\n", 1534 | "k_layer0 = k_layer0.view(n_kv_heads, k_layer0.shape[0] // n_kv_heads, dim)\n", 1535 | "\n", 1536 | "# Print the shape of the reshaped key weight tensor\n", 1537 | "k_layer0.shape # Output: torch.Size([8, 128, 4096])\n", 1538 | "\n", 1539 | "# Extract the key weight for the first head of the first layer of attention\n", 1540 | "k_layer0_head0 = k_layer0[0]\n", 1541 | "\n", 1542 | "# Print the shape of the extracted key weight tensor for the first head\n", 1543 | "k_layer0_head0.shape # Output: torch.Size([128, 4096])\n", 1544 | "\n", 1545 | "# Calculate key per token by matrix multiplication\n", 1546 | "k_per_token = torch.matmul(token_embeddings, k_layer0_head0.T)\n", 1547 | "\n", 1548 | "# Print the shape of the resulting tensor representing keys per token\n", 1549 | "k_per_token.shape # Output: torch.Size([17, 128])\n", 1550 | "\n", 1551 | "# Split key per token into pairs and convert to float\n", 1552 | "k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)\n", 1553 | 
"\n", 1554 | "# Print the shape of the resulting tensor after splitting into pairs\n", 1555 | "k_per_token_split_into_pairs.shape # Output: torch.Size([17, 64, 2])\n", 1556 | "\n", 1557 | "# Convert key per token to complex numbers\n", 1558 | "k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)\n", 1559 | "\n", 1560 | "# Print the shape of the resulting tensor representing key per token as complex numbers\n", 1561 | "k_per_token_as_complex_numbers.shape # Output: torch.Size([17, 64])\n", 1562 | "\n", 1563 | "# Rotate complex key per token by frequencies\n", 1564 | "k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis)\n", 1565 | "\n", 1566 | "# Print the shape of the rotated complex key per token\n", 1567 | "k_per_token_split_into_pairs_rotated.shape # Output: torch.Size([17, 64, 2])\n", 1568 | "\n", 1569 | "# Reshape rotated key per token to match the original shape\n", 1570 | "k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)\n", 1571 | "\n", 1572 | "# Print the shape of the rotated key per token\n", 1573 | "k_per_token_rotated.shape # Output: torch.Size([17, 128])" 1574 | ] 1575 | }, 1576 | { 1577 | "cell_type": "markdown", 1578 | "metadata": {}, 1579 | "source": [ 1580 | "We now have the rotated queries and keys for each token, with each being of size [17x128]." 1581 | ] 1582 | }, 1583 | { 1584 | "cell_type": "markdown", 1585 | "metadata": {}, 1586 | "source": [ 1587 | "## Implementing Self Attention\n", 1588 | "Multiplying the query and key matrices will give us a score that maps each token to another. This score indicates the relationship between each token's query and key." 
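RoPE pairs well with these dot-product scores for a specific reason. After rotation, the score between a query at position m and a key at position n depends only on the offset m - n, not on the absolute positions. The plain-Python sketch below checks this on a tiny, made-up 4-dimensional head (2 pairs) with hypothetical frequencies. It mirrors the notebook's complex-number view but is not the notebook's code.

```python
import cmath
import math

def apply_rope(pairs, position, freqs):
    # Rotate each complex pair by position * frequency (same idea as multiplying by freqs_cis)
    return [z * cmath.exp(1j * position * f) for z, f in zip(pairs, freqs)]

def scaled_score(q_pairs, k_pairs):
    # Dot product of the underlying real vectors: Re(z_q * conj(z_k)) = x_q*x_k + y_q*y_k per pair
    dot = sum((zq * zk.conjugate()).real for zq, zk in zip(q_pairs, k_pairs))
    head_dim = 2 * len(q_pairs)
    return dot / math.sqrt(head_dim)

# Hypothetical query/key vectors (2 pairs = 4 dimensions) and frequencies
q = [complex(0.5, -1.0), complex(2.0, 0.25)]
k = [complex(-0.75, 1.5), complex(1.0, 1.0)]
freqs = [1.0, 0.01]

def score_at(m, n):
    # Score of the query at position m against the key at position n
    return scaled_score(apply_rope(q, m, freqs), apply_rope(k, n, freqs))

print(score_at(3, 1), score_at(10, 8))  # equal up to rounding: both have offset 2
```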
1589 | ] 1590 | }, 1591 | { 1592 | "cell_type": "code", 1593 | "execution_count": 31, 1594 | "metadata": { 1595 | "execution": { 1596 | "iopub.execute_input": "2024-05-27T03:00:11.397911Z", 1597 | "iopub.status.busy": "2024-05-27T03:00:11.397119Z", 1598 | "iopub.status.idle": "2024-05-27T03:00:11.409650Z", 1599 | "shell.execute_reply": "2024-05-27T03:00:11.408282Z", 1600 | "shell.execute_reply.started": "2024-05-27T03:00:11.397867Z" 1601 | } 1602 | }, 1603 | "outputs": [ 1604 | { 1605 | "data": { 1606 | "text/plain": [ 1607 | "torch.Size([17, 17])" 1608 | ] 1609 | }, 1610 | "execution_count": 31, 1611 | "metadata": {}, 1612 | "output_type": "execute_result" 1613 | } 1614 | ], 1615 | "source": [ 1616 | "# Calculate scaled query-key dot products per token (divided by sqrt(head_dim))\n", 1617 | "qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T) / (head_dim) ** 0.5\n", 1618 | "\n", 1619 | "# Print the shape of the resulting tensor representing query-key dot products per token\n", 1620 | "qk_per_token.shape" 1621 | ] 1622 | }, 1623 | { 1624 | "cell_type": "markdown", 1625 | "metadata": {}, 1626 | "source": [ 1627 | "The [17x17] shape represents the attention scores (qk_per_token), where 17 is the number of tokens in the prompt.\n", 1628 | "We need to mask the query-key scores. During training, the scores for future tokens are masked because the model learns to predict each token using only past tokens. We apply the same mask during inference: future positions are set to negative infinity, so they receive a weight of zero after the softmax."
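The mask that the next cell builds with torch.full and torch.triu(..., diagonal=1) can be written out by hand. Row i (the query token) keeps 0 for columns j <= i and gets -inf for every future column j > i. Here is a minimal plain-Python sketch for a hypothetical 4-token prompt:

```python
def causal_mask(n_tokens):
    # 0.0 where key position j is not in the future of query position i, -inf otherwise
    # (same pattern as torch.triu(torch.full((n, n), -inf), diagonal=1))
    return [[0.0 if j <= i else float('-inf') for j in range(n_tokens)]
            for i in range(n_tokens)]

# Hypothetical 4-token prompt
for row in causal_mask(4):
    print(row)
# [0.0, -inf, -inf, -inf]
# [0.0, 0.0, -inf, -inf]
# [0.0, 0.0, 0.0, -inf]
# [0.0, 0.0, 0.0, 0.0]
```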
1629 | ] 1630 | }, 1631 | { 1632 | "cell_type": "code", 1633 | "execution_count": 32, 1634 | "metadata": { 1635 | "execution": { 1636 | "iopub.execute_input": "2024-05-27T03:00:11.418619Z", 1637 | "iopub.status.busy": "2024-05-27T03:00:11.418065Z", 1638 | "iopub.status.idle": "2024-05-27T03:00:11.432397Z", 1639 | "shell.execute_reply": "2024-05-27T03:00:11.431065Z", 1640 | "shell.execute_reply.started": "2024-05-27T03:00:11.418586Z" 1641 | } 1642 | }, 1643 | "outputs": [ 1644 | { 1645 | "data": { 1646 | "text/plain": [ 1647 | "tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],\n", 1648 | " [0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],\n", 1649 | " [0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],\n", 1650 | " [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],\n", 1651 | " [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],\n", 1652 | " [0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],\n", 1653 | " [0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],\n", 1654 | " [0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],\n", 1655 | " [0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],\n", 1656 | " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf],\n", 1657 | " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf],\n", 1658 | " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],\n", 1659 | " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf, -inf],\n", 1660 | " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf, -inf],\n", 1661 | " [0., 0., 0., 0., 0., 0., 0., 
0., 0., 0., 0., 0., 0., 0., 0., -inf, -inf],\n", 1662 | " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., -inf],\n", 1663 | " [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])" 1664 | ] 1665 | }, 1666 | "execution_count": 32, 1667 | "metadata": {}, 1668 | "output_type": "execute_result" 1669 | } 1670 | ], 1671 | "source": [ 1672 | "# Create a mask tensor filled with negative infinity values\n", 1673 | "mask = torch.full((len(tokens), len(tokens)), float(\"-inf\"), device=tokens.device)\n", 1674 | "\n", 1675 | "# Keep -inf above the diagonal (future tokens) and set everything else to 0\n", 1676 | "mask = torch.triu(mask, diagonal=1)\n", 1677 | "\n", 1678 | "# Print the resulting mask tensor\n", 1679 | "mask" 1680 | ] 1681 | }, 1682 | { 1683 | "cell_type": "markdown", 1684 | "metadata": {}, 1685 | "source": [ 1686 | "Now, we add the mask to the query-key scores and apply softmax along each row to convert the scores into attention weights. Each row then sums to 1, and the masked (future) positions get a weight of exactly 0. As a result, each token attends only to itself and the tokens before it."
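To make the masking-plus-softmax step concrete, the plain-Python sketch below adds a causal mask to a small, made-up score matrix. It then applies a row-wise softmax, the equivalent of softmax(..., dim=1). Masked entries become exactly 0 because exp(-inf) = 0, and each row still sums to 1.

```python
import math

def softmax_row(row):
    # Subtract the row maximum for numerical stability; exp(-inf) evaluates to 0.0
    largest = max(row)
    exps = [math.exp(x - largest) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical 3x3 query-key scores (already divided by sqrt(head_dim))
scores = [[1.0, 2.0, 3.0],
          [0.5, 0.5, 4.0],
          [2.0, 1.0, 0.0]]
mask = [[0.0 if j <= i else float('-inf') for j in range(3)] for i in range(3)]

masked = [[s + mk for s, mk in zip(s_row, m_row)] for s_row, m_row in zip(scores, mask)]
weights = [softmax_row(row) for row in masked]

for row in weights:
    print([round(w, 3) for w in row])
# [1.0, 0.0, 0.0]
# [0.5, 0.5, 0.0]
# [0.665, 0.245, 0.09]
```

The large scores in the upper triangle (2.0, 3.0, 4.0) have no effect. Only the past and current positions compete for attention.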
1687 | ] 1688 | }, 1689 | { 1690 | "cell_type": "code", 1691 | "execution_count": 33, 1692 | "metadata": { 1693 | "execution": { 1694 | "iopub.execute_input": "2024-05-27T03:00:11.436035Z", 1695 | "iopub.status.busy": "2024-05-27T03:00:11.434670Z", 1696 | "iopub.status.idle": "2024-05-27T03:00:11.445875Z", 1697 | "shell.execute_reply": "2024-05-27T03:00:11.444453Z", 1698 | "shell.execute_reply.started": "2024-05-27T03:00:11.435949Z" 1699 | } 1700 | }, 1701 | "outputs": [], 1702 | "source": [ 1703 | "# Add the mask to the query-key dot products per token\n", 1704 | "qk_per_token_after_masking = qk_per_token + mask\n", 1705 | "\n", 1706 | "# Apply softmax along the second dimension after masking\n", 1707 | "qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)" 1708 | ] 1709 | }, 1710 | { 1711 | "cell_type": "markdown", 1712 | "metadata": {}, 1713 | "source": [ 1714 | "The value matrix is the last piece of the self-attention step. As with keys, each set of value weights is shared by 4 attention heads to save computation, so the reshaped value weight tensor has the shape [8x128x4096]."
1715 | ] 1716 | }, 1717 | { 1718 | "cell_type": "code", 1719 | "execution_count": 34, 1720 | "metadata": { 1721 | "execution": { 1722 | "iopub.execute_input": "2024-05-27T03:00:11.448024Z", 1723 | "iopub.status.busy": "2024-05-27T03:00:11.447547Z", 1724 | "iopub.status.idle": "2024-05-27T03:00:11.458656Z", 1725 | "shell.execute_reply": "2024-05-27T03:00:11.457431Z", 1726 | "shell.execute_reply.started": "2024-05-27T03:00:11.447985Z" 1727 | } 1728 | }, 1729 | "outputs": [ 1730 | { 1731 | "data": { 1732 | "text/plain": [ 1733 | "torch.Size([8, 128, 4096])" 1734 | ] 1735 | }, 1736 | "execution_count": 34, 1737 | "metadata": {}, 1738 | "output_type": "execute_result" 1739 | } 1740 | ], 1741 | "source": [ 1742 | "# Retrieve the value weight for the first layer of attention\n", 1743 | "v_layer0 = model[\"layers.0.attention.wv.weight\"]\n", 1744 | "\n", 1745 | "# Reshape value weight for the first layer of attention to separate heads\n", 1746 | "v_layer0 = v_layer0.view(n_kv_heads, v_layer0.shape[0] // n_kv_heads, dim)\n", 1747 | "\n", 1748 | "# Print the shape of the reshaped value weight tensor\n", 1749 | "v_layer0.shape" 1750 | ] 1751 | }, 1752 | { 1753 | "cell_type": "markdown", 1754 | "metadata": {}, 1755 | "source": [ 1756 | "Similar to the query and key matrices, the value matrix for the first layer and first head can be obtained using:" 1757 | ] 1758 | }, 1759 | { 1760 | "cell_type": "code", 1761 | "execution_count": 35, 1762 | "metadata": { 1763 | "execution": { 1764 | "iopub.execute_input": "2024-05-27T03:00:11.461191Z", 1765 | "iopub.status.busy": "2024-05-27T03:00:11.460270Z", 1766 | "iopub.status.idle": "2024-05-27T03:00:11.471921Z", 1767 | "shell.execute_reply": "2024-05-27T03:00:11.470436Z", 1768 | "shell.execute_reply.started": "2024-05-27T03:00:11.461146Z" 1769 | } 1770 | }, 1771 | "outputs": [ 1772 | { 1773 | "data": { 1774 | "text/plain": [ 1775 | "torch.Size([128, 4096])" 1776 | ] 1777 | }, 1778 | "execution_count": 35, 1779 | "metadata": {}, 1780 | 
"output_type": "execute_result" 1781 | } 1782 | ], 1783 | "source": [ 1784 | "# Extract the value weight for the first head of the first layer of attention\n", 1785 | "v_layer0_head0 = v_layer0[0]\n", 1786 | "\n", 1787 | "# Print the shape of the extracted value weight tensor for the first head\n", 1788 | "v_layer0_head0.shape" 1789 | ] 1790 | }, 1791 | { 1792 | "cell_type": "markdown", 1793 | "metadata": {}, 1794 | "source": [ 1795 | "Using the value weights, we compute the value vector for each token, resulting in a matrix of size [17x128]. Here, 17 denotes the number of tokens in the prompt, and 128 indicates the dimension of the value vector for each token." 1796 | ] 1797 | }, 1798 | { 1799 | "cell_type": "code", 1800 | "execution_count": 36, 1801 | "metadata": { 1802 | "execution": { 1803 | "iopub.execute_input": "2024-05-27T03:00:11.475738Z", 1804 | "iopub.status.busy": "2024-05-27T03:00:11.473685Z", 1805 | "iopub.status.idle": "2024-05-27T03:00:11.493795Z", 1806 | "shell.execute_reply": "2024-05-27T03:00:11.492635Z", 1807 | "shell.execute_reply.started": "2024-05-27T03:00:11.475692Z" 1808 | } 1809 | }, 1810 | "outputs": [ 1811 | { 1812 | "data": { 1813 | "text/plain": [ 1814 | "torch.Size([17, 128])" 1815 | ] 1816 | }, 1817 | "execution_count": 36, 1818 | "metadata": {}, 1819 | "output_type": "execute_result" 1820 | } 1821 | ], 1822 | "source": [ 1823 | "# Calculate value per token by matrix multiplication\n", 1824 | "v_per_token = torch.matmul(token_embeddings, v_layer0_head0.T)\n", 1825 | "\n", 1826 | "# Print the shape of the resulting tensor representing values per token\n", 1827 | "v_per_token.shape" 1828 | ] 1829 | }, 1830 | { 1831 | "cell_type": "markdown", 1832 | "metadata": {}, 1833 | "source": [ 1834 | "To obtain the attention output, we multiply the attention weights by the value vectors:" 1835 | ] 1836 | }, 1837 | { 1838 | "cell_type": "code", 1839 | "execution_count": 37, 1840 | "metadata": { 1841 | "execution": { 1842 |
"iopub.execute_input": "2024-05-27T03:00:11.496509Z", 1843 | "iopub.status.busy": "2024-05-27T03:00:11.495676Z", 1844 | "iopub.status.idle": "2024-05-27T03:00:11.505015Z", 1845 | "shell.execute_reply": "2024-05-27T03:00:11.503901Z", 1846 | "shell.execute_reply.started": "2024-05-27T03:00:11.496460Z" 1847 | } 1848 | }, 1849 | "outputs": [ 1850 | { 1851 | "data": { 1852 | "text/plain": [ 1853 | "torch.Size([17, 128])" 1854 | ] 1855 | }, 1856 | "execution_count": 37, 1857 | "metadata": {}, 1858 | "output_type": "execute_result" 1859 | } 1860 | ], 1861 | "source": [ 1862 | "# Calculate QKV attention by matrix multiplication\n", 1863 | "qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)\n", 1864 | "\n", 1865 | "# Print the shape of the resulting tensor\n", 1866 | "qkv_attention.shape" 1867 | ] 1868 | }, 1869 | { 1870 | "cell_type": "markdown", 1871 | "metadata": {}, 1872 | "source": [ 1873 | "We now have the self-attention output for the first head of the first layer." 1874 | ] 1875 | }, 1876 | { 1877 | "cell_type": "markdown", 1878 | "metadata": {}, 1879 | "source": [ 1880 | "## Implementing Multi-Head Attention\n", 1881 | "\n", 1882 | "Next, we run a loop that repeats the same calculations for every head in the first layer."
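The structure of the loop below can be sketched without PyTorch. Each query head uses key/value head head // group_size and produces an [n_tokens x head_dim] output. The per-head outputs are then concatenated along the feature dimension. The sizes here (4 query heads, 2 key/value heads, head_dim 2, 3 tokens) and all weights are hypothetical, and RoPE is omitted to keep the sketch short.

```python
import math

def matmul(a, b):
    # Plain-Python matrix product of an [n x k] and a [k x m] list of lists
    return [[sum(a[i][t] * b[t][j] for t in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(a):
    return [list(col) for col in zip(*a)]

def causal_attention(q, k, v):
    # softmax(q k^T / sqrt(d) + causal mask) v for one head (RoPE omitted in this sketch)
    d = len(q[0])
    scores = matmul(q, transpose(k))
    attn_weights = []
    for i, row in enumerate(scores):
        row = [s / math.sqrt(d) if j <= i else float('-inf') for j, s in enumerate(row)]
        largest = max(row)
        exps = [math.exp(x - largest) for x in row]
        total = sum(exps)
        attn_weights.append([e / total for e in exps])
    return matmul(attn_weights, v)

# Hypothetical sizes: 3 tokens, model dim 4, 4 query heads, 2 key/value heads, head_dim 2
n_tokens, n_heads, n_kv_heads, head_dim = 3, 4, 2, 2
group_size = n_heads // n_kv_heads
x = [[0.1 * (i + j) for j in range(4)] for i in range(n_tokens)]  # token embeddings
# Hypothetical per-head projection weights, each [head_dim x dim] like q_layer0[head]
wq = [[[0.01 * (h + r + c) for c in range(4)] for r in range(head_dim)] for h in range(n_heads)]
wk = [[[0.02 * (h - r + c) for c in range(4)] for r in range(head_dim)] for h in range(n_kv_heads)]
wv = [[[0.03 * (h + r - c) for c in range(4)] for r in range(head_dim)] for h in range(n_kv_heads)]

head_outputs = []
for head in range(n_heads):
    kv = head // group_size  # key/value weights shared across group_size query heads
    q = matmul(x, transpose(wq[head]))
    k = matmul(x, transpose(wk[kv]))
    v = matmul(x, transpose(wv[kv]))
    head_outputs.append(causal_attention(q, k, v))

# Concatenate along the feature dimension, like torch.cat(qkv_attention_store, dim=-1)
stacked = [sum((h[t] for h in head_outputs), []) for t in range(n_tokens)]
print(len(stacked), len(stacked[0]))  # 3 8
```

Because of the causal mask, the first token can only attend to itself. Its output for each head is therefore exactly that head's value vector for token 0.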
1883 | ] 1884 | }, 1885 | { 1886 | "cell_type": "code", 1887 | "execution_count": 38, 1888 | "metadata": { 1889 | "execution": { 1890 | "iopub.execute_input": "2024-05-27T03:00:11.506930Z", 1891 | "iopub.status.busy": "2024-05-27T03:00:11.506558Z", 1892 | "iopub.status.idle": "2024-05-27T03:00:12.386253Z", 1893 | "shell.execute_reply": "2024-05-27T03:00:12.385092Z", 1894 | "shell.execute_reply.started": "2024-05-27T03:00:11.506901Z" 1895 | } 1896 | }, 1897 | "outputs": [ 1898 | { 1899 | "data": { 1900 | "text/plain": [ 1901 | "32" 1902 | ] 1903 | }, 1904 | "execution_count": 38, 1905 | "metadata": {}, 1906 | "output_type": "execute_result" 1907 | } 1908 | ], 1909 | "source": [ 1910 | "# Store QKV attention for each head in a list\n", 1911 | "qkv_attention_store = []\n", 1912 | "\n", 1913 | "# Iterate through each head\n", 1914 | "for head in range(n_heads):\n", 1915 | " # Extract query, key, and value weights for the current head\n", 1916 | " q_layer0_head = q_layer0[head]\n", 1917 | " k_layer0_head = k_layer0[head//4] # Key weights are shared across 4 heads\n", 1918 | " v_layer0_head = v_layer0[head//4] # Value weights are shared across 4 heads\n", 1919 | " \n", 1920 | " # Calculate query per token by matrix multiplication\n", 1921 | " q_per_token = torch.matmul(token_embeddings, q_layer0_head.T)\n", 1922 | " \n", 1923 | " # Calculate key per token by matrix multiplication\n", 1924 | " k_per_token = torch.matmul(token_embeddings, k_layer0_head.T)\n", 1925 | " \n", 1926 | " # Calculate value per token by matrix multiplication\n", 1927 | " v_per_token = torch.matmul(token_embeddings, v_layer0_head.T)\n", 1928 | " \n", 1929 | " # Split query per token into pairs and rotate them\n", 1930 | " q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)\n", 1931 | " q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)\n", 1932 | " q_per_token_split_into_pairs_rotated = 
torch.view_as_real(q_per_token_as_complex_numbers * freqs_cis[:len(tokens)])\n", 1933 | " q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)\n", 1934 | " \n", 1935 | " # Split key per token into pairs and rotate them\n", 1936 | " k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)\n", 1937 | " k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)\n", 1938 | " k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis[:len(tokens)])\n", 1939 | " k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)\n", 1940 | " \n", 1941 | " # Calculate query-key dot products per token\n", 1942 | " qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T) / (128) ** 0.5\n", 1943 | " \n", 1944 | " # Create a mask tensor filled with negative infinity values\n", 1945 | " mask = torch.full((len(tokens), len(tokens)), float(\"-inf\"), device=tokens.device)\n", 1946 | " # Set upper triangular part of the mask tensor to negative infinity\n", 1947 | " mask = torch.triu(mask, diagonal=1)\n", 1948 | " # Add the mask to the query-key dot products per token\n", 1949 | " qk_per_token_after_masking = qk_per_token + mask\n", 1950 | " \n", 1951 | " # Apply softmax along the second dimension after masking\n", 1952 | " qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)\n", 1953 | " \n", 1954 | " # Calculate QKV attention by matrix multiplication\n", 1955 | " qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)\n", 1956 | " \n", 1957 | " # Store QKV attention for the current head\n", 1958 | " qkv_attention_store.append(qkv_attention)\n", 1959 | "\n", 1960 | "# Print the number of QKV attentions stored\n", 1961 | "len(qkv_attention_store)" 1962 | ] 1963 | }, 1964 | { 1965 | "cell_type": "markdown", 1966 | "metadata": {}, 
1967 | "source": [ 1968 | "Now that we have the QKV attention output for all 32 heads in the first layer, we concatenate them into one large matrix of size [17x4096] (32 heads x 128 dimensions = 4096)." 1969 | ] 1970 | }, 1971 | { 1972 | "cell_type": "code", 1973 | "execution_count": 39, 1974 | "metadata": { 1975 | "execution": { 1976 | "iopub.execute_input": "2024-05-27T03:00:12.388623Z", 1977 | "iopub.status.busy": "2024-05-27T03:00:12.388044Z", 1978 | "iopub.status.idle": "2024-05-27T03:00:12.400295Z", 1979 | "shell.execute_reply": "2024-05-27T03:00:12.399179Z", 1980 | "shell.execute_reply.started": "2024-05-27T03:00:12.388591Z" 1981 | } 1982 | }, 1983 | "outputs": [ 1984 | { 1985 | "data": { 1986 | "text/plain": [ 1987 | "torch.Size([17, 4096])" 1988 | ] 1989 | }, 1990 | "execution_count": 39, 1991 | "metadata": {}, 1992 | "output_type": "execute_result" 1993 | } 1994 | ], 1995 | "source": [ 1996 | "# Concatenate QKV attentions from all heads along the last dimension\n", 1997 | "stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1)\n", 1998 | "\n", 1999 | "# Print the shape of the resulting tensor\n", 2000 | "stacked_qkv_attention.shape" 2001 | ] 2002 | }, 2003 | { 2004 | "cell_type": "markdown", 2005 | "metadata": {}, 2006 | "source": [ 2007 | "One of the last steps of the layer 0 attention block is to multiply the stacked QKV matrix by the transpose of the output weight matrix (wo)."
2008 | ] 2009 | }, 2010 | { 2011 | "cell_type": "code", 2012 | "execution_count": 40, 2013 | "metadata": { 2014 | "execution": { 2015 | "iopub.execute_input": "2024-05-27T03:00:12.402289Z", 2016 | "iopub.status.busy": "2024-05-27T03:00:12.401818Z", 2017 | "iopub.status.idle": "2024-05-27T03:00:12.647349Z", 2018 | "shell.execute_reply": "2024-05-27T03:00:12.646269Z", 2019 | "shell.execute_reply.started": "2024-05-27T03:00:12.402248Z" 2020 | } 2021 | }, 2022 | "outputs": [ 2023 | { 2024 | "data": { 2025 | "text/plain": [ 2026 | "torch.Size([17, 4096])" 2027 | ] 2028 | }, 2029 | "execution_count": 40, 2030 | "metadata": {}, 2031 | "output_type": "execute_result" 2032 | } 2033 | ], 2034 | "source": [ 2035 | "# Calculate the embedding delta by matrix multiplication with the output weight\n", 2036 | "embedding_delta = torch.matmul(stacked_qkv_attention, model[\"layers.0.attention.wo.weight\"].T)\n", 2037 | "\n", 2038 | "# Print the shape of the resulting tensor\n", 2039 | "embedding_delta.shape" 2040 | ] 2041 | }, 2042 | { 2043 | "cell_type": "markdown", 2044 | "metadata": {}, 2045 | "source": [ 2046 | "We now have the change in the embedding values after attention, which should be added to the original token embeddings." 
2047 | ] 2048 | }, 2049 | { 2050 | "cell_type": "code", 2051 | "execution_count": 41, 2052 | "metadata": { 2053 | "execution": { 2054 | "iopub.execute_input": "2024-05-27T03:00:12.649046Z", 2055 | "iopub.status.busy": "2024-05-27T03:00:12.648711Z", 2056 | "iopub.status.idle": "2024-05-27T03:00:12.656370Z", 2057 | "shell.execute_reply": "2024-05-27T03:00:12.655063Z", 2058 | "shell.execute_reply.started": "2024-05-27T03:00:12.649018Z" 2059 | } 2060 | }, 2061 | "outputs": [ 2062 | { 2063 | "data": { 2064 | "text/plain": [ 2065 | "torch.Size([17, 4096])" 2066 | ] 2067 | }, 2068 | "execution_count": 41, 2069 | "metadata": {}, 2070 | "output_type": "execute_result" 2071 | } 2072 | ], 2073 | "source": [ 2074 | "# Add the embedding delta to the unnormalized token embeddings (residual connection) to get the edited embeddings\n", 2075 | "embedding_after_edit = token_embeddings_unnormalized + embedding_delta\n", 2076 | "\n", 2077 | "# Print the shape of the resulting tensor\n", 2078 | "embedding_after_edit.shape" 2079 | ] 2080 | }, 2081 | { 2082 | "cell_type": "markdown", 2083 | "metadata": {}, 2084 | "source": [ 2085 | "The edited embeddings are then normalized with RMSNorm and passed through a feedforward neural network."
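For reference, here is a plain-Python sketch of the two operations that follow. The first is RMSNorm: x divided by the root mean square of its elements, then scaled element-wise by a learned weight. The second is the SwiGLU feedforward: silu(x W1^T) * (x W3^T), projected back with W2^T. The epsilon of 1e-5, the sizes, and the weights are assumptions for illustration. The notebook's own rms_norm helper is defined earlier and may differ in details.

```python
import math

def rms_norm(x, weight, eps=1e-5):
    # x / sqrt(mean(x^2) + eps), scaled element-wise by a learned weight (eps is an assumed value)
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * w for v, w in zip(x, weight)]

def silu(v):
    # SiLU / Swish activation: v * sigmoid(v)
    return v / (1.0 + math.exp(-v))

def matvec(w, x):
    # w is [out x in], x is [in]; returns w @ x (equivalent to x @ w.T for a single token)
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]

def swiglu_ffn(x, w1, w2, w3):
    # w2 @ (silu(w1 @ x) * (w3 @ x)): the per-token form of the feedforward used here
    gate = [silu(v) for v in matvec(w1, x)]
    up = matvec(w3, x)
    return matvec(w2, [g * u for g, u in zip(gate, up)])

# Hypothetical single token: model dim 2, hidden dim 3
x = [3.0, -4.0]
x_norm = rms_norm(x, weight=[1.0, 1.0], eps=0.0)
print(x_norm)  # roughly [0.8485, -1.1314]: the RMS of [3, -4] is sqrt(12.5)

w1 = [[0.5, 0.1], [-0.2, 0.3], [0.0, 1.0]]
w3 = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
w2 = [[0.2, -0.1, 0.4], [0.3, 0.3, -0.2]]
print(swiglu_ffn(x_norm, w1, w2, w3))  # a 2-element vector, same size as the input
```

The FFN expands each token to the hidden size, gates it, and projects it back. Its output therefore has the model dimension and can be added back to the residual stream.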
2086 | ] 2087 | }, 2088 | { 2089 | "cell_type": "code", 2090 | "execution_count": 42, 2091 | "metadata": { 2092 | "execution": { 2093 | "iopub.execute_input": "2024-05-27T03:00:12.658349Z", 2094 | "iopub.status.busy": "2024-05-27T03:00:12.657987Z", 2095 | "iopub.status.idle": "2024-05-27T03:00:12.671973Z", 2096 | "shell.execute_reply": "2024-05-27T03:00:12.670683Z", 2097 | "shell.execute_reply.started": "2024-05-27T03:00:12.658319Z" 2098 | } 2099 | }, 2100 | "outputs": [ 2101 | { 2102 | "data": { 2103 | "text/plain": [ 2104 | "torch.Size([17, 4096])" 2105 | ] 2106 | }, 2107 | "execution_count": 42, 2108 | "metadata": {}, 2109 | "output_type": "execute_result" 2110 | } 2111 | ], 2112 | "source": [ 2113 | "# Normalize edited embeddings using root mean square normalization and provided weights\n", 2114 | "embedding_after_edit_normalized = rms_norm(embedding_after_edit, model[\"layers.0.ffn_norm.weight\"])\n", 2115 | "\n", 2116 | "# Print the shape of resulting normalized embeddings\n", 2117 | "embedding_after_edit_normalized.shape" 2118 | ] 2119 | }, 2120 | { 2121 | "cell_type": "markdown", 2122 | "metadata": {}, 2123 | "source": [ 2124 | "## Implementing SwiGLU Activation Function\n", 2125 | "Given our familiarity with the SwiGLU activation function from the previous section, we will apply the equation we studied earlier here.\n", 2126 | "\n", 2127 | "![](https://cdn-images-1.medium.com/max/1000/1*q5FbOgDpo6H-86AefVzdNQ.png)" 2128 | ] 2129 | }, 2130 | { 2131 | "cell_type": "code", 2132 | "execution_count": 43, 2133 | "metadata": { 2134 | "execution": { 2135 | "iopub.execute_input": "2024-05-27T03:00:12.673958Z", 2136 | "iopub.status.busy": "2024-05-27T03:00:12.673440Z", 2137 | "iopub.status.idle": "2024-05-27T03:00:15.193930Z", 2138 | "shell.execute_reply": "2024-05-27T03:00:15.192559Z", 2139 | "shell.execute_reply.started": "2024-05-27T03:00:12.673917Z" 2140 | } 2141 | }, 2142 | "outputs": [ 2143 | { 2144 | "data": { 2145 | "text/plain": [ 2146 | "torch.Size([17, 
4096])" 2147 | ] 2148 | }, 2149 | "execution_count": 43, 2150 | "metadata": {}, 2151 | "output_type": "execute_result" 2152 | } 2153 | ], 2154 | "source": [ 2155 | "# Retrieve weights for the feedforward layer\n", 2156 | "w1 = model[\"layers.0.feed_forward.w1.weight\"]\n", 2157 | "w2 = model[\"layers.0.feed_forward.w2.weight\"]\n", 2158 | "w3 = model[\"layers.0.feed_forward.w3.weight\"]\n", 2159 | "\n", 2160 | "# SwiGLU feedforward: (silu(x @ w1.T) * (x @ w3.T)) @ w2.T\n", 2161 | "output_after_feedforward = torch.matmul(torch.nn.functional.silu(torch.matmul(embedding_after_edit_normalized, w1.T)) * torch.matmul(embedding_after_edit_normalized, w3.T), w2.T)\n", 2162 | "\n", 2163 | "# Print the shape of the resulting tensor after feedforward\n", 2164 | "output_after_feedforward.shape" 2165 | ] 2166 | }, 2167 | { 2168 | "cell_type": "markdown", 2169 | "metadata": {}, 2170 | "source": [ 2171 | "## Merging everything\n", 2172 | "Now that every component is ready, we can combine them into a single loop that runs all 32 layers (n_layers). The loop starts again from the unnormalized token embeddings, so layer 0 is recomputed along with the remaining 31 layers." 2173 | ] 2174 | }, 2175 | { 2176 | "cell_type": "code", 2177 | "execution_count": 44, 2178 | "metadata": { 2179 | "execution": { 2180 | "iopub.execute_input": "2024-05-27T03:00:15.196517Z", 2181 | "iopub.status.busy": "2024-05-27T03:00:15.196114Z", 2182 | "iopub.status.idle": "2024-05-27T03:02:08.179910Z", 2183 | "shell.execute_reply": "2024-05-27T03:02:08.178647Z", 2184 | "shell.execute_reply.started": "2024-05-27T03:00:15.196487Z" 2185 | } 2186 | }, 2187 | "outputs": [], 2188 | "source": [ 2189 | "# Initialize final embedding with unnormalized token embeddings\n", 2190 | "final_embedding = token_embeddings_unnormalized\n", 2191 | "\n", 2192 | "# Iterate through each layer\n", 2193 | "for layer in range(n_layers):\n", 2194 | " # Initialize list to store QKV attentions for each head\n", 2195 | " qkv_attention_store = []\n", 2196 | " \n", 2197 | " # Normalize the final embedding using root mean square normalization and weights from the current layer\n", 2198 | " layer_embedding_norm
= rms_norm(final_embedding, model[f\"layers.{layer}.attention_norm.weight\"])\n", 2199 | " \n", 2200 | " # Retrieve query, key, value, and output weights for the attention mechanism of the current layer\n", 2201 | " q_layer = model[f\"layers.{layer}.attention.wq.weight\"]\n", 2202 | " q_layer = q_layer.view(n_heads, q_layer.shape[0] // n_heads, dim)\n", 2203 | " k_layer = model[f\"layers.{layer}.attention.wk.weight\"]\n", 2204 | " k_layer = k_layer.view(n_kv_heads, k_layer.shape[0] // n_kv_heads, dim)\n", 2205 | " v_layer = model[f\"layers.{layer}.attention.wv.weight\"]\n", 2206 | " v_layer = v_layer.view(n_kv_heads, v_layer.shape[0] // n_kv_heads, dim)\n", 2207 | " w_layer = model[f\"layers.{layer}.attention.wo.weight\"]\n", 2208 | " \n", 2209 | " # Iterate through each head\n", 2210 | " for head in range(n_heads):\n", 2211 | " # Extract query, key, and value weights for the current head\n", 2212 | " q_layer_head = q_layer[head]\n", 2213 | " k_layer_head = k_layer[head // (n_heads // n_kv_heads)] # Each key head is shared by n_heads // n_kv_heads (= 4) query heads\n", 2214 | " v_layer_head = v_layer[head // (n_heads // n_kv_heads)] # Each value head is shared by n_heads // n_kv_heads (= 4) query heads\n", 2215 | " \n", 2216 | " # Calculate query per token by matrix multiplication\n", 2217 | " q_per_token = torch.matmul(layer_embedding_norm, q_layer_head.T)\n", 2218 | " \n", 2219 | " # Calculate key per token by matrix multiplication\n", 2220 | " k_per_token = torch.matmul(layer_embedding_norm, k_layer_head.T)\n", 2221 | " \n", 2222 | " # Calculate value per token by matrix multiplication\n", 2223 | " v_per_token = torch.matmul(layer_embedding_norm, v_layer_head.T)\n", 2224 | " \n", 2225 | " # Split query per token into pairs and rotate them\n", 2226 | " q_per_token_split_into_pairs = q_per_token.float().view(q_per_token.shape[0], -1, 2)\n", 2227 | " q_per_token_as_complex_numbers = torch.view_as_complex(q_per_token_split_into_pairs)\n", 2228 | " q_per_token_split_into_pairs_rotated = torch.view_as_real(q_per_token_as_complex_numbers * freqs_cis)\n", 2229 | "
q_per_token_rotated = q_per_token_split_into_pairs_rotated.view(q_per_token.shape)\n", 2230 | " \n", 2231 | " # Split key per token into pairs and rotate them\n", 2232 | " k_per_token_split_into_pairs = k_per_token.float().view(k_per_token.shape[0], -1, 2)\n", 2233 | " k_per_token_as_complex_numbers = torch.view_as_complex(k_per_token_split_into_pairs)\n", 2234 | " k_per_token_split_into_pairs_rotated = torch.view_as_real(k_per_token_as_complex_numbers * freqs_cis)\n", 2235 | " k_per_token_rotated = k_per_token_split_into_pairs_rotated.view(k_per_token.shape)\n", 2236 | " \n", 2237 | " # Calculate query-key dot products per token, scaled by sqrt(head_dim) (head_dim = 128)\n", 2238 | " qk_per_token = torch.matmul(q_per_token_rotated, k_per_token_rotated.T) / (q_per_token.shape[-1] ** 0.5)\n", 2239 | " \n", 2240 | " # Create a mask tensor filled with negative infinity values\n", 2241 | " mask = torch.full((len(token_embeddings_unnormalized), len(token_embeddings_unnormalized)), float(\"-inf\"))\n", 2242 | " # Keep -inf only above the diagonal (future tokens); all other entries become 0\n", 2243 | " mask = torch.triu(mask, diagonal=1)\n", 2244 | " # Add the mask to the query-key dot products per token\n", 2245 | " qk_per_token_after_masking = qk_per_token + mask\n", 2246 | " \n", 2247 | " # Apply softmax across each row (dim=1) so each token's attention weights sum to 1\n", 2248 | " qk_per_token_after_masking_after_softmax = torch.nn.functional.softmax(qk_per_token_after_masking, dim=1).to(torch.bfloat16)\n", 2249 | " \n", 2250 | " # Calculate QKV attention by matrix multiplication\n", 2251 | " qkv_attention = torch.matmul(qk_per_token_after_masking_after_softmax, v_per_token)\n", 2252 | " \n", 2253 | " # Store QKV attention for the current head\n", 2254 | " qkv_attention_store.append(qkv_attention)\n", 2255 | " \n", 2256 | " # Concatenate QKV attentions from all heads along the last dimension\n", 2257 | " stacked_qkv_attention = torch.cat(qkv_attention_store, dim=-1)\n", 2258 | " \n", 2259 | " # Calculate embedding delta by matrix multiplication with the output
weight\n", 2260 | " embedding_delta = torch.matmul(stacked_qkv_attention, w_layer.T)\n", 2261 | " \n", 2262 | " # Add the embedding delta to the current embedding to get the edited embedding\n", 2263 | " embedding_after_edit = final_embedding + embedding_delta\n", 2264 | " \n", 2265 | " # Normalize the edited embedding using root mean square normalization and weights from the current layer\n", 2266 | " embedding_after_edit_normalized = rms_norm(embedding_after_edit, model[f\"layers.{layer}.ffn_norm.weight\"])\n", 2267 | " \n", 2268 | " # Retrieve weights for the feedforward layer\n", 2269 | " w1 = model[f\"layers.{layer}.feed_forward.w1.weight\"]\n", 2270 | " w2 = model[f\"layers.{layer}.feed_forward.w2.weight\"]\n", 2271 | " w3 = model[f\"layers.{layer}.feed_forward.w3.weight\"]\n", 2272 | " \n", 2273 | " # SwiGLU feedforward: (silu(x @ w1.T) * (x @ w3.T)) @ w2.T\n", 2274 | " output_after_feedforward = torch.matmul(torch.nn.functional.silu(torch.matmul(embedding_after_edit_normalized, w1.T)) * torch.matmul(embedding_after_edit_normalized, w3.T), w2.T)\n", 2275 | " \n", 2276 | " # Update the final embedding with the edited embedding plus the output from the feedforward layer\n", 2277 | " final_embedding = embedding_after_edit + output_after_feedforward" 2278 | ] 2279 | }, 2280 | { 2281 | "cell_type": "markdown", 2282 | "metadata": {}, 2283 | "source": [ 2284 | "## Generating the Output\n", 2285 | "We now have the final embedding after the last layer; its last row carries the model's guess for the next token. Its shape matches the input token embeddings, [17x4096]: 17 tokens with an embedding dimension of 4096."
2286 | ] 2287 | }, 2288 | { 2289 | "cell_type": "code", 2290 | "execution_count": 45, 2291 | "metadata": { 2292 | "execution": { 2293 | "iopub.execute_input": "2024-05-27T03:02:08.181761Z", 2294 | "iopub.status.busy": "2024-05-27T03:02:08.181381Z", 2295 | "iopub.status.idle": "2024-05-27T03:02:08.190110Z", 2296 | "shell.execute_reply": "2024-05-27T03:02:08.189023Z", 2297 | "shell.execute_reply.started": "2024-05-27T03:02:08.181730Z" 2298 | } 2299 | }, 2300 | "outputs": [ 2301 | { 2302 | "data": { 2303 | "text/plain": [ 2304 | "torch.Size([17, 4096])" 2305 | ] 2306 | }, 2307 | "execution_count": 45, 2308 | "metadata": {}, 2309 | "output_type": "execute_result" 2310 | } 2311 | ], 2312 | "source": [ 2313 | "# Normalize the final embedding using root mean square normalization and provided weights\n", 2314 | "final_embedding = rms_norm(final_embedding, model[\"norm.weight\"])\n", 2315 | "\n", 2316 | "# Print the shape of the resulting normalized final embedding\n", 2317 | "final_embedding.shape" 2318 | ] 2319 | }, 2320 | { 2321 | "cell_type": "markdown", 2322 | "metadata": {}, 2323 | "source": [ 2324 | "Next, we decode the embedding into a token. The output weight matrix has one 4096-dimensional row per vocabulary entry, [128256x4096], so multiplying an embedding by its transpose gives one logit for each of the 128,256 tokens in the vocabulary.\n" 2325 | ] 2326 | }, 2327 | { 2328 | "cell_type": "code", 2329 | "execution_count": 46, 2330 | "metadata": { 2331 | "execution": { 2332 | "iopub.execute_input": "2024-05-27T03:02:08.192506Z", 2333 | "iopub.status.busy": "2024-05-27T03:02:08.191868Z", 2334 | "iopub.status.idle": "2024-05-27T03:02:08.201531Z", 2335 | "shell.execute_reply": "2024-05-27T03:02:08.200371Z", 2336 | "shell.execute_reply.started": "2024-05-27T03:02:08.192469Z" 2337 | } 2338 | }, 2339 | "outputs": [ 2340 | { 2341 | "data": { 2342 | "text/plain": [ 2343 | "torch.Size([128256, 4096])" 2344 | ] 2345 | }, 2346 | "execution_count": 46, 2347 | "metadata": {}, 2348 | "output_type": "execute_result" 2349 | } 2350 | ], 2351 | "source": [ 2352 | "# Print the shape of the output weight tensor\n", 2353 | "model[\"output.weight\"].shape" 2354 | ] 2355 | },
2356 | { 2357 | "cell_type": "markdown", 2358 | "metadata": {}, 2359 | "source": [ 2360 | "To predict the next token, we use only the embedding of the last token, because that position's output is the prediction for what comes next." 2361 | ] 2362 | }, 2363 | { 2364 | "cell_type": "code", 2365 | "execution_count": 47, 2366 | "metadata": { 2367 | "execution": { 2368 | "iopub.execute_input": "2024-05-27T03:02:08.203485Z", 2369 | "iopub.status.busy": "2024-05-27T03:02:08.203038Z", 2370 | "iopub.status.idle": "2024-05-27T03:02:08.725469Z", 2371 | "shell.execute_reply": "2024-05-27T03:02:08.724301Z", 2372 | "shell.execute_reply.started": "2024-05-27T03:02:08.203448Z" 2373 | } 2374 | }, 2375 | "outputs": [ 2376 | { 2377 | "data": { 2378 | "text/plain": [ 2379 | "'42'" 2380 | ] 2381 | }, 2382 | "execution_count": 47, 2383 | "metadata": {}, 2384 | "output_type": "execute_result" 2385 | } 2386 | ], 2387 | "source": [ 2388 | "# Calculate logits by multiplying the last token's final embedding with the transpose of the output weight tensor\n", 2389 | "logits = torch.matmul(final_embedding[-1], model[\"output.weight\"].T)\n", 2390 | "\n", 2391 | "# Greedy decoding: take the index of the largest logit as the next token\n", 2392 | "next_token = torch.argmax(logits, dim=-1)\n", 2393 | "\n", 2394 | "# Decode the index of the next token using the tokenizer\n", 2395 | "tokenizer.decode([next_token.item()])" 2396 | ] 2397 | }, 2398 | { 2399 | "cell_type": "markdown", 2400 | "metadata": {}, 2401 | "source": [ 2402 | "So, our input was \"the answer to the ultimate question of life, the universe, and everything is \", and the output is \"42\", the expected answer.\n", 2403 | "You can experiment with different input texts by changing the two lines below wherever they appear in the code; the rest of the code remains the same."
2404 | ] 2405 | }, 2406 | { 2407 | "cell_type": "markdown", 2408 | "metadata": {}, 2409 | "source": [ 2410 | "```python\n", 2411 | "# Input prompt\n", 2412 | "prompt = \"Your Input\"\n", 2413 | "\n", 2414 | "# Use the number of tokens in your input instead of the hardcoded 17\n", 2415 | "# (len(tokens) gives that count)\n", 2416 | "freqs_for_each_token = torch.outer(torch.arange(len(tokens)), freqs)\n", 2417 | "```" 2418 | ] 2419 | }, 2420 | { 2421 | "cell_type": "markdown", 2422 | "metadata": {}, 2423 | "source": [ 2424 | "### I hope you enjoyed this blog and learned something new!" 2425 | ] 2426 | } 2427 | ], 2428 | "metadata": { 2429 | "kaggle": { 2430 | "accelerator": "none", 2431 | "dataSources": [], 2432 | "dockerImageVersionId": 30698, 2433 | "isGpuEnabled": false, 2434 | "isInternetEnabled": true, 2435 | "language": "python", 2436 | "sourceType": "notebook" 2437 | }, 2438 | "kernelspec": { 2439 | "display_name": "Python 3 (ipykernel)", 2440 | "language": "python", 2441 | "name": "python3" 2442 | }, 2443 | "language_info": { 2444 | "codemirror_mode": { 2445 | "name": "ipython", 2446 | "version": 3 2447 | }, 2448 | "file_extension": ".py", 2449 | "mimetype": "text/x-python", 2450 | "name": "python", 2451 | "nbconvert_exporter": "python", 2452 | "pygments_lexer": "ipython3", 2453 | "version": "3.9.0" 2454 | } 2455 | }, 2456 | "nbformat": 4, 2457 | "nbformat_minor": 4 2458 | } 2459 | --------------------------------------------------------------------------------
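The SwiGLU feedforward cell in the notebook packs the whole computation into one line. Here is a minimal sketch of the same computation as a small function. It assumes toy sizes and random weights in place of the real layer weights, and `swiglu_ffn` is an illustrative name, not part of the notebook. It also checks that `silu(a)` matches the explicit Swish form `a * sigmoid(a)`.

```python
import torch
import torch.nn.functional as F


def swiglu_ffn(x: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor, w3: torch.Tensor) -> torch.Tensor:
    """SwiGLU feedforward with the notebook's weight layout (illustrative helper).

    x: [seq_len, dim]; w1, w3: [hidden_dim, dim]; w2: [dim, hidden_dim].
    """
    gate = F.silu(x @ w1.T)    # [seq_len, hidden_dim], SiLU/Swish-activated gate branch
    up = x @ w3.T              # [seq_len, hidden_dim], linear branch
    return (gate * up) @ w2.T  # element-wise gating, then project back to [seq_len, dim]


# Toy sizes (hypothetical); Llama 3 uses dim=4096 and a much larger hidden size
torch.manual_seed(0)
seq_len, dim, hidden_dim = 5, 8, 16
x = torch.randn(seq_len, dim)
w1 = torch.randn(hidden_dim, dim)
w2 = torch.randn(dim, hidden_dim)
w3 = torch.randn(hidden_dim, dim)

out = swiglu_ffn(x, w1, w2, w3)

# Same computation with SiLU written out explicitly: silu(a) = a * sigmoid(a)
a = x @ w1.T
manual = ((a * torch.sigmoid(a)) * (x @ w3.T)) @ w2.T

print(out.shape)  # torch.Size([5, 8])
print(torch.allclose(out, manual, atol=1e-4))
```

The output keeps the input shape `[seq_len, dim]`. This is why the notebook can add `output_after_feedforward` straight back onto `embedding_after_edit` as a residual.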
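The merged loop combines three details in each attention head:

- Each group of query heads reuses one shared key/value head (grouped-query attention).
- Scores are scaled by the square root of the head dimension.
- A causal mask hides future tokens.

Here is a minimal sketch of those steps for one layer. It assumes toy sizes and random weights, and it deliberately leaves out RoPE to stay short. `causal_attention_head` is a hypothetical helper name.

```python
import torch


def causal_attention_head(x, wq_head, wk_head, wv_head):
    """One causal attention head (hypothetical helper; RoPE omitted).

    x: [seq_len, dim]; each w*_head: [head_dim, dim].
    Returns the head output [seq_len, head_dim] and the attention weights.
    """
    q = x @ wq_head.T
    k = x @ wk_head.T
    v = x @ wv_head.T
    head_dim = q.shape[-1]
    scores = (q @ k.T) / head_dim ** 0.5  # scaled dot products, [seq_len, seq_len]
    seq_len = x.shape[0]
    # -inf strictly above the diagonal hides future tokens; zeros elsewhere leave scores unchanged
    mask = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
    weights = torch.softmax(scores + mask, dim=-1)  # each row sums to 1
    return weights @ v, weights


# Toy grouped-query layout (hypothetical sizes): 8 query heads share 2 key/value heads
torch.manual_seed(0)
seq_len, dim, n_heads, n_kv_heads, head_dim = 4, 16, 8, 2, 2
x = torch.randn(seq_len, dim)
wq = torch.randn(n_heads, head_dim, dim)
wk = torch.randn(n_kv_heads, head_dim, dim)
wv = torch.randn(n_kv_heads, head_dim, dim)
group = n_heads // n_kv_heads  # query heads per key/value head (32 // 8 = 4 in Llama 3 8B)

outputs = []
for head in range(n_heads):
    out, weights = causal_attention_head(x, wq[head], wk[head // group], wv[head // group])
    outputs.append(out)

stacked = torch.cat(outputs, dim=-1)  # [seq_len, n_heads * head_dim]
print(stacked.shape)  # torch.Size([4, 16])
```

Because `n_heads * head_dim` equals `dim`, the concatenated heads have the same width as the embeddings. This is what lets the notebook project them back with `w_layer.T`.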
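The decoding cell multiplies only `final_embedding[-1]` by the output weights. A small sketch with random stand-in tensors shows two things: this equals the last row of the full projection, and greedy decoding is just an argmax. The sizes are hypothetical, much smaller than the real 128,256-token vocabulary. `torch.topk` is added here only to inspect the runner-up candidates.

```python
import torch

# Toy stand-ins (hypothetical sizes): 6 tokens, embedding dim 8, vocabulary of 50
torch.manual_seed(0)
seq_len, dim, vocab_size = 6, 8, 50
final_embedding = torch.randn(seq_len, dim)   # plays the role of the normalized final embedding
output_weight = torch.randn(vocab_size, dim)  # same layout as model["output.weight"]: [vocab, dim]

# Logits for every position vs. only the last position
all_logits = final_embedding @ output_weight.T       # [seq_len, vocab_size]
last_logits = final_embedding[-1] @ output_weight.T  # [vocab_size]

# Greedy decoding: the next token id is the index of the largest logit
next_token_id = int(torch.argmax(last_logits))

# Top-5 candidate ids, useful for seeing what else the model considered
top_values, top_ids = torch.topk(last_logits, k=5)
print(all_logits.shape, last_logits.shape, next_token_id, top_ids.tolist())
```

Projecting only the last row avoids computing `seq_len - 1` rows of logits that greedy next-token prediction never uses.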
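The last cell replaces the hardcoded 17 with the prompt length. Here is a minimal sketch of building the RoPE rotation factors from `len(tokens)`. It assumes the standard RoPE frequency formula with `head_dim = 128` (64 rotation pairs) and `rope_theta = 500000`. The notebook defines its own `freqs` earlier, and this sketch may differ from it in details. The token ids, and the helper name `rope_cis_for`, are hypothetical.

```python
import torch

# Assumptions: standard RoPE frequencies for head_dim = 128 with rope_theta = 500000
rope_theta = 500000.0
head_dim = 128
freqs = 1.0 / (rope_theta ** (torch.arange(0, head_dim // 2) / (head_dim // 2)))  # [64]


def rope_cis_for(tokens):
    """Unit complex rotation factors, one row per token position (hypothetical helper)."""
    # Sized from the prompt instead of a hardcoded 17
    freqs_for_each_token = torch.outer(torch.arange(len(tokens)), freqs)  # [len(tokens), 64]
    return torch.polar(torch.ones_like(freqs_for_each_token), freqs_for_each_token)


tokens = torch.tensor([128000, 1, 2, 3, 4])  # hypothetical ids: BOS (128000) followed by 4 tokens
freqs_cis = rope_cis_for(tokens)
print(freqs_cis.shape)  # torch.Size([5, 64])
```

Every factor has magnitude 1, so multiplying the query/key pairs by `freqs_cis` only rotates them. Position 0 is left unchanged.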