├── .gitignore ├── LICENSE.txt ├── README.md ├── Slides.pdf ├── download.sh ├── inference.py ├── model.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | share/python-wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | MANIFEST 28 | 29 | # PyInstaller 30 | # Usually these files are written by a python script from a template 31 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 32 | *.manifest 33 | *.spec 34 | 35 | # Installer logs 36 | pip-log.txt 37 | pip-delete-this-directory.txt 38 | 39 | # Unit test / coverage reports 40 | htmlcov/ 41 | .tox/ 42 | .nox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | *.py,cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | cover/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | .pybuilder/ 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | # For a library or package, you might want to ignore these files since the code is 87 | # intended to run in multiple environments; otherwise, check them in: 88 | # .python-version 89 | 90 | # pipenv 91 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 92 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 93 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 94 | # install all needed dependencies. 95 | #Pipfile.lock 96 | 97 | # poetry 98 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 99 | # This is especially recommended for binary packages to ensure reproducibility, and is more 100 | # commonly ignored for libraries. 101 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 102 | #poetry.lock 103 | 104 | # pdm 105 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 106 | #pdm.lock 107 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 108 | # in version control. 109 | # https://pdm.fming.dev/#use-with-ide 110 | .pdm.toml 111 | 112 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 113 | __pypackages__/ 114 | 115 | # Celery stuff 116 | celerybeat-schedule 117 | celerybeat.pid 118 | 119 | # SageMath parsed files 120 | *.sage.py 121 | 122 | # Environments 123 | .env 124 | .venv 125 | env/ 126 | venv/ 127 | ENV/ 128 | env.bak/ 129 | venv.bak/ 130 | 131 | # Spyder project settings 132 | .spyderproject 133 | .spyproject 134 | 135 | # Rope project settings 136 | .ropeproject 137 | 138 | # mkdocs documentation 139 | /site 140 | 141 | # mypy 142 | .mypy_cache/ 143 | .dmypy.json 144 | dmypy.json 145 | 146 | # Pyre type checker 147 | .pyre/ 148 | 149 | # pytype static type analyzer 150 | .pytype/ 151 | 152 | # Cython debug symbols 153 | cython_debug/ 154 | 155 | # PyCharm 156 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 157 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 158 | # and can be added to the global gitignore or merged into this file. For a more nuclear 159 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 160 | #.idea/ 161 | 162 | # Ignore weights 163 | llama-2-* 164 | tokenizer.model 165 | sampling_tests.ipynb -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) [year] [fullname] 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # pytorch-llama 2 | LLaMA 2 implemented from scratch in PyTorch 3 | -------------------------------------------------------------------------------- /Slides.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/hkproj/pytorch-llama/067f8a37fe36ac8b52dca9cc6f2a2e8d6aa372d6/Slides.pdf -------------------------------------------------------------------------------- /download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Copyright (c) Meta Platforms, Inc. and affiliates. 4 | # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement. 
5 | 6 | read -p "Enter the URL from email: " PRESIGNED_URL 7 | echo "" 8 | read -p "Enter the list of models to download without spaces (7B,13B,70B,7B-chat,13B-chat,70B-chat), or press Enter for all: " MODEL_SIZE 9 | TARGET_FOLDER="." # where all files should end up 10 | mkdir -p ${TARGET_FOLDER} 11 | 12 | if [[ $MODEL_SIZE == "" ]]; then 13 | MODEL_SIZE="7B,13B,70B,7B-chat,13B-chat,70B-chat" 14 | fi 15 | 16 | echo "Downloading LICENSE and Acceptable Usage Policy" 17 | wget ${PRESIGNED_URL/'*'/"LICENSE"} -O ${TARGET_FOLDER}"/LICENSE" 18 | wget ${PRESIGNED_URL/'*'/"USE_POLICY.md"} -O ${TARGET_FOLDER}"/USE_POLICY.md" 19 | 20 | echo "Downloading tokenizer" 21 | wget ${PRESIGNED_URL/'*'/"tokenizer.model"} -O ${TARGET_FOLDER}"/tokenizer.model" 22 | wget ${PRESIGNED_URL/'*'/"tokenizer_checklist.chk"} -O ${TARGET_FOLDER}"/tokenizer_checklist.chk" 23 | (cd ${TARGET_FOLDER} && md5sum -c tokenizer_checklist.chk) 24 | 25 | for m in ${MODEL_SIZE//,/ } 26 | do 27 | if [[ $m == "7B" ]]; then 28 | SHARD=0 29 | MODEL_PATH="llama-2-7b" 30 | elif [[ $m == "7B-chat" ]]; then 31 | SHARD=0 32 | MODEL_PATH="llama-2-7b-chat" 33 | elif [[ $m == "13B" ]]; then 34 | SHARD=1 35 | MODEL_PATH="llama-2-13b" 36 | elif [[ $m == "13B-chat" ]]; then 37 | SHARD=1 38 | MODEL_PATH="llama-2-13b-chat" 39 | elif [[ $m == "70B" ]]; then 40 | SHARD=7 41 | MODEL_PATH="llama-2-70b" 42 | elif [[ $m == "70B-chat" ]]; then 43 | SHARD=7 44 | MODEL_PATH="llama-2-70b-chat" 45 | fi 46 | 47 | echo "Downloading ${MODEL_PATH}" 48 | mkdir -p ${TARGET_FOLDER}"/${MODEL_PATH}" 49 | 50 | for s in $(seq -f "0%g" 0 ${SHARD}) 51 | do 52 | wget --retry-connrefused --waitretry=1 --read-timeout=20 --timeout=15 -t 0 --continue ${PRESIGNED_URL/'*'/"${MODEL_PATH}/consolidated.${s}.pth"} -O ${TARGET_FOLDER}"/${MODEL_PATH}/consolidated.${s}.pth" 53 | done 54 | 55 | wget ${PRESIGNED_URL/'*'/"${MODEL_PATH}/params.json"} -O ${TARGET_FOLDER}"/${MODEL_PATH}/params.json" 56 | wget ${PRESIGNED_URL/'*'/"${MODEL_PATH}/checklist.chk"} -O ${TARGET_FOLDER}"/${MODEL_PATH}/checklist.chk" 57 | echo "Checking checksums" 58 | (cd ${TARGET_FOLDER}"/${MODEL_PATH}" && md5sum -c checklist.chk) 59 | done 60 | 61 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import torch 3 | import time 4 | from pathlib import Path 5 | import json 6 | from sentencepiece import SentencePieceProcessor 7 | from tqdm import tqdm 8 | 9 | from model import ModelArgs, Transformer 10 | 11 | class LLaMA: 12 | 13 | def __init__(self, model: Transformer, tokenizer: SentencePieceProcessor, model_args: ModelArgs): 14 | self.model = model 15 | self.tokenizer = tokenizer 16 | self.args = model_args 17 | 18 | @staticmethod 19 | def build(checkpoints_dir: str, tokenizer_path: str, load_model: bool, max_seq_len: int, max_batch_size: int, device: str): 20 | prev_time = time.time() 21 | if load_model: 22 | checkpoints = sorted(Path(checkpoints_dir).glob("*.pth")) 23 | assert len(checkpoints) > 0, f"no checkpoint files found in {checkpoints_dir}" 24 | ckpt_path = checkpoints[0] 25 | print(f'Loading checkpoint "{ckpt_path}"') 26 | checkpoint = torch.load(ckpt_path, map_location="cpu") 27 | print(f"Loaded checkpoint in {time.time() - prev_time:.2f}s") 28 | prev_time = time.time() 29 | with open(Path(checkpoints_dir) / "params.json", "r") as f: 30 | params = json.loads(f.read()) 31 | 32 | model_args: ModelArgs = ModelArgs( 33 | max_seq_len=max_seq_len, 34 
| max_batch_size=max_batch_size, 35 | device=device, 36 | **params 37 | ) 38 | 39 | tokenizer = SentencePieceProcessor() 40 | tokenizer.load(tokenizer_path) 41 | model_args.vocab_size = tokenizer.vocab_size() 42 | 43 | if device == "cuda": 44 | torch.set_default_tensor_type(torch.cuda.HalfTensor) 45 | else: 46 | torch.set_default_tensor_type(torch.BFloat16Tensor) 47 | 48 | model = Transformer(model_args).to(device) 49 | 50 | if load_model: 51 | # The only unmatched key in the checkpoint is rope.freqs. Remove it 52 | del checkpoint['rope.freqs'] 53 | model.load_state_dict(checkpoint, strict=True) 54 | print(f"Loaded state dict in {time.time() - prev_time:.2f}s") 55 | 56 | return LLaMA(model, tokenizer, model_args) 57 | 58 | def text_completion(self, prompts: list[str], temperature: float = 0.6, top_p: float = 0.9, max_gen_len: Optional[int] = None): 59 | if max_gen_len is None: 60 | max_gen_len = self.args.max_seq_len - 1 61 | # Convert each prompt into tokens 62 | prompt_tokens = [self.tokenizer.encode(prompt, out_type=int, add_bos=True, add_eos=False) for prompt in prompts] 63 | # Make sure the batch size is not too large 64 | batch_size = len(prompt_tokens) 65 | assert batch_size <= self.args.max_batch_size, f"batch size must be less than or equal to {self.args.max_batch_size}" 66 | max_prompt_len = max(len(prompt) for prompt in prompt_tokens) 67 | # Make sure the prompt length is not larger than the maximum sequence length 68 | assert max_prompt_len <= self.args.max_seq_len, f"prompt length must be less than or equal to {self.args.max_seq_len}" 69 | total_len = min(self.args.max_seq_len, max_gen_len + max_prompt_len) 70 | 71 | # Create the list that will contain the generated tokens, along with the initial prompt tokens 72 | pad_id = self.tokenizer.pad_id() 73 | tokens = torch.full((batch_size, total_len), pad_id, dtype=torch.long, device=device) 74 | for k, t in enumerate(prompt_tokens): 75 | # Populate the initial tokens with the prompt tokens 76 | tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device=device) 77 | 78 | eos_reached = torch.tensor([False] * batch_size, device=device) 79 | prompt_tokens_mask = tokens != pad_id # True if the token is a prompt token, False otherwise 80 | cur_iterator = tqdm(range(1, total_len), desc="Generating tokens") 81 | for cur_pos in cur_iterator: 82 | with torch.no_grad(): 83 | logits = self.model.forward(tokens[:, cur_pos-1:cur_pos], cur_pos) 84 | if temperature > 0: 85 | # The temperature is applied before the softmax 86 | probs = torch.softmax(logits[:, -1] / temperature, dim=-1) 87 | next_token = self._sample_top_p(probs, top_p) 88 | else: 89 | # Greedily select the token with the max probability 90 | next_token = torch.argmax(logits[:, -1], dim=-1) 91 | 92 | next_token = next_token.reshape(-1) 93 | # Only replace token if it is a padding token 94 | next_token = torch.where(prompt_tokens_mask[:, cur_pos], tokens[:, cur_pos], next_token) 95 | tokens[:, cur_pos] = next_token 96 | # EOS is reached only if we found an EOS token for a padding position 97 | eos_reached |= (~prompt_tokens_mask[:, cur_pos]) & (next_token == self.tokenizer.eos_id) 98 | if all(eos_reached): 99 | break 100 | 101 | out_tokens = [] 102 | out_text = [] 103 | for prompt_index, current_prompt_tokens in enumerate(tokens.tolist()): 104 | # Cut to the EOS token, if present 105 | if self.tokenizer.eos_id in current_prompt_tokens: 106 | eos_idx = current_prompt_tokens.index(self.tokenizer.eos_id) 107 | current_prompt_tokens = current_prompt_tokens[:eos_idx] 108 | 
out_tokens.append(current_prompt_tokens) 109 | out_text.append(self.tokenizer.decode(current_prompt_tokens)) 110 | return (out_tokens, out_text) 111 | 112 | def _sample_top_p(self, probs, p): 113 | # (B, vocab_size) 114 | probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) 115 | # (B, vocab_size) 116 | probs_sum = torch.cumsum(probs_sort, dim=-1) 117 | # (B, vocab_size) 118 | # (Subtracting "probs_sort" shifts the cumulative sum by 1 position to the right before masking) 119 | mask = probs_sum - probs_sort > p 120 | # Zero out all the probabilities of tokens that are not selected by the Top P 121 | probs_sort[mask] = 0.0 122 | # Redistribute the probabilities so that they sum up to 1. 123 | probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) 124 | # Sample a token (its index) from the top p distribution 125 | next_token = torch.multinomial(probs_sort, num_samples=1) 126 | # Get the token position in the vocabulary corresponding to the sampled index 127 | next_token = torch.gather(probs_idx, -1, next_token) 128 | return next_token 129 | 130 | 131 | 132 | if __name__ == '__main__': 133 | torch.manual_seed(0) 134 | 135 | allow_cuda = False 136 | device = 'cuda' if torch.cuda.is_available() and allow_cuda else 'cpu' 137 | 138 | prompts = [ 139 | "Simply put, the theory of relativity states that ", 140 | "If Google was an Italian company founded in Milan, it would", 141 | # Few shot prompt 142 | """Translate English to French: 143 | 144 | sea otter => loutre de mer 145 | peppermint => menthe poivrée 146 | plush girafe => girafe peluche 147 | cheese =>""", 148 | # Zero shot prompt 149 | """Tell me if the following person is actually Doraemon disguised as human: 150 | Name: Umar Jamil 151 | Decision: 152 | """ 153 | ] 154 | 155 | model = LLaMA.build( 156 | checkpoints_dir='llama-2-7b/', 157 | tokenizer_path='tokenizer.model', 158 | load_model=True, 159 | max_seq_len=1024, 160 | max_batch_size=len(prompts), 161 | device=device 162 | ) 163 | 164 | out_tokens, out_texts = (model.text_completion(prompts, max_gen_len=64)) 165 | assert len(out_texts) == len(prompts) 166 | for i in range(len(out_texts)): 167 | print(f'{out_texts[i]}') 168 | print('-' * 50) 169 | 170 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | from typing import Optional 3 | import math 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | @dataclass 10 | class ModelArgs: 11 | dim: int = 4096 12 | n_layers: int = 32 13 | n_heads: int = 32 14 | n_kv_heads: Optional[int] = None 15 | vocab_size: int = -1 # Later set in the build method 16 | multiple_of: int = 256 17 | ffn_dim_multiplier: Optional[float] = None 18 | norm_eps: float = 1e-5 19 | 20 | # Needed for KV cache 21 | max_batch_size: int = 32 22 | max_seq_len: int = 2048 23 | 24 | device: str = None 25 | 26 | 27 | class RMSNorm(nn.Module): 28 | def __init__(self, dim: int, eps: float = 1e-6): 29 | super().__init__() 30 | self.eps = eps 31 | # The gamma parameter 32 | self.weight = nn.Parameter(torch.ones(dim)) 33 | 34 | def _norm(self, x: torch.Tensor): 35 | # (B, Seq_Len, Dim) * (B, Seq_Len, 1) = (B, Seq_Len, Dim) 36 | # rsqrt: 1 / sqrt(x) 37 | return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) 38 | 39 | def forward(self, x: torch.Tensor): 40 | # (Dim) * (B, Seq_Len, Dim) = (B, Seq_Len, Dim) 41 | return self.weight * 
self._norm(x.float()).type_as(x) 42 | 43 | 44 | def precompute_theta_pos_frequencies(head_dim: int, seq_len: int, device: str, theta: float = 10000.0): 45 | # As written in the paragraph 3.2.2 of the paper 46 | # >> In order to generalize our results in 2D to any xi ∈ Rd where **d is even**, [...] 47 | assert head_dim % 2 == 0, "Dimension must be divisible by 2" 48 | # Build the theta parameter 49 | # According to the formula theta_i = 10000^(-2(i-1)/dim) for i = [1, 2, ... dim/2] 50 | # Shape: (Head_Dim / 2) 51 | theta_numerator = torch.arange(0, head_dim, 2).float() 52 | # Shape: (Head_Dim / 2) 53 | theta = 1.0 / (theta ** (theta_numerator / head_dim)).to(device) # (Dim / 2) 54 | # Construct the positions (the "m" parameter) 55 | # Shape: (Seq_Len) 56 | m = torch.arange(seq_len, device=device) 57 | # Multiply each theta by each position using the outer product. 58 | # Shape: (Seq_Len) outer_product* (Head_Dim / 2) -> (Seq_Len, Head_Dim / 2) 59 | freqs = torch.outer(m, theta).float() 60 | # We can compute complex numbers in the polar form c = R * exp(m * theta), where R = 1 as follows: 61 | # (Seq_Len, Head_Dim / 2) -> (Seq_Len, Head_Dim / 2) 62 | freqs_complex = torch.polar(torch.ones_like(freqs), freqs) 63 | return freqs_complex 64 | 65 | def apply_rotary_embeddings(x: torch.Tensor, freqs_complex: torch.Tensor, device: str): 66 | # Separate the last dimension pairs of two values, representing the real and imaginary parts of the complex number 67 | # Two consecutive values will become a single complex number 68 | # (B, Seq_Len, H, Head_Dim) -> (B, Seq_Len, H, Head_Dim/2) 69 | x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2)) 70 | # Reshape the freqs_complex tensor to match the shape of the x_complex tensor. So we need to add the batch dimension and the head dimension 71 | # (Seq_Len, Head_Dim/2) --> (1, Seq_Len, 1, Head_Dim/2) 72 | freqs_complex = freqs_complex.unsqueeze(0).unsqueeze(2) 73 | # Multiply each complex number in the x_complex tensor by the corresponding complex number in the freqs_complex tensor 74 | # Which results in the rotation of the complex number as shown in the Figure 1 of the paper 75 | # (B, Seq_Len, H, Head_Dim/2) * (1, Seq_Len, 1, Head_Dim/2) = (B, Seq_Len, H, Head_Dim/2) 76 | x_rotated = x_complex * freqs_complex 77 | # Convert the complex number back to the real number 78 | # (B, Seq_Len, H, Head_Dim/2) -> (B, Seq_Len, H, Head_Dim/2, 2) 79 | x_out = torch.view_as_real(x_rotated) 80 | # (B, Seq_Len, H, Head_Dim/2, 2) -> (B, Seq_Len, H, Head_Dim) 81 | x_out = x_out.reshape(*x.shape) 82 | return x_out.type_as(x).to(device) 83 | 84 | 85 | def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: 86 | batch_size, seq_len, n_kv_heads, head_dim = x.shape 87 | if n_rep == 1: 88 | return x 89 | return ( 90 | # (B, Seq_Len, N_KV_Heads, 1, Head_Dim) 91 | x[:, :, :, None, :] 92 | # (B, Seq_Len, N_KV_Heads, N_Rep, Head_Dim) 93 | .expand(batch_size, seq_len, n_kv_heads, n_rep, head_dim) 94 | # (B, Seq_Len, N_KV_Heads * N_Rep, Head_Dim) 95 | .reshape(batch_size, seq_len, n_kv_heads * n_rep, head_dim) 96 | ) 97 | 98 | 99 | class SelfAttention(nn.Module): 100 | def __init__(self, args: ModelArgs): 101 | super().__init__() 102 | 103 | # Indicates the number of heads for the Keys and Values 104 | self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads 105 | # Indicates the number of heads for the Queries 106 | self.n_heads_q = args.n_heads 107 | # Indicates how many times the Keys and Values should be repeated 108 | self.n_rep = 
self.n_heads_q // self.n_kv_heads 109 | # Indicates the dimension of each head, that is, the part of the embedding that each head will be responsible for 110 | self.head_dim = args.dim // args.n_heads 111 | 112 | self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False) 113 | self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) 114 | self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False) 115 | self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False) 116 | 117 | self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim)) 118 | self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim)) 119 | 120 | def forward( 121 | self, 122 | x: torch.Tensor, 123 | start_pos: int, 124 | freqs_complex: torch.Tensor 125 | ): 126 | batch_size, seq_len, _ = x.shape # (B, 1, Dim) 127 | 128 | # (B, 1, Dim) -> (B, 1, H_Q * Head_Dim) 129 | xq = self.wq(x) 130 | # (B, 1, Dim) -> (B, 1, H_KV * Head_Dim) 131 | xk = self.wk(x) 132 | # (B, 1, Dim) -> (B, 1, H_KV * Head_Dim) 133 | xv = self.wv(x) 134 | 135 | # (B, 1, H_Q * Head_Dim) -> (B, 1, H_Q, Head_Dim) 136 | xq = xq.view(batch_size, seq_len, self.n_heads_q, self.head_dim) 137 | # (B, 1, H_KV * Head_Dim) -> (B, 1, H_KV, Head_Dim) 138 | xk = xk.view(batch_size, seq_len, self.n_kv_heads, self.head_dim) 139 | # (B, 1, H_KV * Head_Dim) -> (B, 1, H_KV, Head_Dim) 140 | xv = xv.view(batch_size, seq_len, self.n_kv_heads, self.head_dim) 141 | 142 | # (B, 1, H_Q, Head_Dim) --> (B, 1, H_Q, Head_Dim) 143 | xq = apply_rotary_embeddings(xq, freqs_complex, device=x.device) 144 | # (B, 1, H_KV, Head_Dim) --> (B, 1, H_KV, Head_Dim) 145 | xk = apply_rotary_embeddings(xk, freqs_complex, device=x.device) 146 | 147 | # Replace the entry in the cache 148 | self.cache_k[:batch_size, start_pos : start_pos + seq_len] = xk 149 | self.cache_v[:batch_size, start_pos : start_pos + seq_len] = xv 150 | 151 | # (B, Seq_Len_KV, H_KV, Head_Dim) 152 | keys = self.cache_k[:batch_size, : start_pos + seq_len] 153 | # (B, Seq_Len_KV, H_KV, Head_Dim) 154 | values = self.cache_v[:batch_size, : start_pos + seq_len] 155 | 156 | # Since every group of Q shares the same K and V heads, just repeat the K and V heads for every Q in the same group. 
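# For example, with n_heads_q = 32 and n_kv_heads = 8 (hypothetical values, not from params.json), n_rep = 4 and each K/V head is shared by a group of 4 query heads; with the ModelArgs defaults above (n_kv_heads = None, i.e. n_kv_heads == n_heads), n_rep = 1 and repeat_kv is a no-op.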
157 | 158 | # (B, Seq_Len_KV, H_KV, Head_Dim) --> (B, Seq_Len_KV, H_Q, Head_Dim) 159 | keys = repeat_kv(keys, self.n_rep) 160 | # (B, Seq_Len_KV, H_KV, Head_Dim) --> (B, Seq_Len_KV, H_Q, Head_Dim) 161 | values = repeat_kv(values, self.n_rep) 162 | 163 | # (B, 1, H_Q, Head_Dim) -> (B, H_Q, 1, Head_Dim) 164 | xq = xq.transpose(1, 2) 165 | # (B, Seq_Len_KV, H_Q, Head_Dim) -> (B, H_Q, Seq_Len_KV, Head_Dim) 166 | keys = keys.transpose(1, 2) 167 | # (B, Seq_Len_KV, H_Q, Head_Dim) -> (B, H_Q, Seq_Len_KV, Head_Dim) 168 | values = values.transpose(1, 2) 169 | 170 | # (B, H_Q, 1, Head_Dim) @ (B, H_Q, Head_Dim, Seq_Len_KV) -> (B, H_Q, 1, Seq_Len_KV) 171 | scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim) 172 | # (B, H_Q, 1, Seq_Len_KV) -> (B, H_Q, 1, Seq_Len_KV) 173 | scores = F.softmax(scores.float(), dim=-1).type_as(xq) 174 | 175 | # (B, H_Q, 1, Seq_Len) @ (B, H_Q, Seq_Len_KV, Head_Dim) -> (B, H_Q, 1, Head_Dim) 176 | output = torch.matmul(scores, values) 177 | # (B, H_Q, 1, Head_Dim) -> (B, 1, H_Q, Head_Dim) -> (B, 1, Dim) 178 | output = (output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)) 179 | return self.wo(output) # (B, 1, Dim) -> (B, 1, Dim) 180 | 181 | 182 | class FeedForward(nn.Module): 183 | def __init__( 184 | self, 185 | args: ModelArgs 186 | ): 187 | super().__init__() 188 | 189 | hidden_dim = 4 * args.dim 190 | hidden_dim = int(2 * hidden_dim / 3) 191 | if args.ffn_dim_multiplier is not None: 192 | hidden_dim = int(args.ffn_dim_multiplier * hidden_dim) 193 | # Round the hidden_dim to the nearest multiple of the multiple_of parameter 194 | hidden_dim = args.multiple_of * ((hidden_dim + args.multiple_of - 1) // args.multiple_of) 195 | 196 | self.w1 = nn.Linear(args.dim, hidden_dim, bias=False) 197 | self.w2 = nn.Linear(hidden_dim, args.dim, bias=False) 198 | self.w3 = nn.Linear(args.dim, hidden_dim, bias=False) 199 | 200 | def forward(self, x: torch.Tensor): 201 | # (B, Seq_Len, Dim) --> (B, Seq_Len, Hidden_Dim) 202 | swish = F.silu(self.w1(x)) 203 | # (B, Seq_Len, Dim) --> (B, Seq_Len, Hidden_Dim) 204 | x_V = self.w3(x) 205 | # (B, Seq_Len, Hidden_Dim) * (B, Seq_Len, Hidden_Dim) --> (B, Seq_Len, Hidden_Dim) 206 | x = swish * x_V 207 | # (B, Seq_Len, Hidden_Dim) --> (B, Seq_Len, Dim) 208 | x = self.w2(x) 209 | return x 210 | 211 | 212 | class EncoderBlock(nn.Module): 213 | 214 | def __init__(self, args: ModelArgs): 215 | super().__init__() 216 | 217 | self.n_heads = args.n_heads 218 | self.dim = args.dim 219 | self.head_dim = args.dim // args.n_heads 220 | 221 | self.attention = SelfAttention(args) 222 | self.feed_forward = FeedForward(args) 223 | 224 | # Normalization BEFORE the attention block 225 | self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) 226 | # Normalization BEFORE the feed forward block 227 | self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) 228 | 229 | def forward(self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor): 230 | # (B, Seq_Len, Dim) + (B, Seq_Len, Dim) --> (B, Seq_Len, Dim) 231 | h = x + self.attention.forward( 232 | self.attention_norm(x), start_pos, freqs_complex 233 | ) 234 | # (B, Seq_Len, Dim) + (B, Seq_Len, Dim) --> (B, Seq_Len, Dim) 235 | out = h + self.feed_forward.forward(self.ffn_norm(h)) 236 | return out 237 | 238 | class Transformer(nn.Module): 239 | 240 | def __init__(self, args: ModelArgs): 241 | super().__init__() 242 | 243 | assert args.vocab_size != -1, "Vocab size must be set" 244 | 245 | self.args = args 246 | self.vocab_size = args.vocab_size 247 | self.n_layers = args.n_layers 
248 | self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim) 249 | 250 | self.layers = nn.ModuleList() 251 | for layer_id in range(args.n_layers): 252 | self.layers.append(EncoderBlock(args)) 253 | 254 | self.norm = RMSNorm(args.dim, eps=args.norm_eps) 255 | self.output = nn.Linear(args.dim, self.vocab_size, bias=False) 256 | 257 | self.freqs_complex = precompute_theta_pos_frequencies(self.args.dim // self.args.n_heads, self.args.max_seq_len * 2, device=self.args.device) 258 | 259 | def forward(self, tokens: torch.Tensor, start_pos: int): 260 | # (B, Seq_Len) 261 | batch_size, seq_len = tokens.shape 262 | assert seq_len == 1, "Only one token at a time can be processed" 263 | 264 | # (B, Seq_Len) -> (B, Seq_Len, Dim) 265 | h = self.tok_embeddings(tokens) 266 | 267 | # Retrieve the pairs (m, theta) corresponding to the positions [start_pos, start_pos + seq_len] 268 | freqs_complex = self.freqs_complex[start_pos:start_pos + seq_len] 269 | 270 | # Consecutively apply all the encoder layers 271 | for layer in self.layers: 272 | h = layer(h, start_pos, freqs_complex) 273 | h = self.norm(h) 274 | output = self.output(h).float() 275 | return output -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | sentencepiece 3 | tqdm --------------------------------------------------------------------------------
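A minimal standalone sketch (not part of the repository) of the top-p trick used by LLaMA._sample_top_p in inference.py: subtracting probs_sort from the cumulative sum shifts it one position to the right, so the first token whose cumulative probability crosses the threshold p is still kept. The toy distribution and the value of p below are made up for illustration; only torch from requirements.txt is needed.

import torch

# Toy, already-sorted distribution over a 4-token vocabulary (batch size 1), threshold p = 0.7
probs = torch.tensor([[0.50, 0.30, 0.15, 0.05]])
p = 0.7

probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)            # [0.50, 0.80, 0.95, 1.00]
mask = probs_sum - probs_sort > p                       # [False, False, True, True]
probs_sort[mask] = 0.0                                  # only the first two tokens survive
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))   # renormalized to [0.625, 0.375, 0.0, 0.0]
next_token = torch.gather(probs_idx, -1, torch.multinomial(probs_sort, num_samples=1))
print(next_token)                                       # tensor([[0]]) or tensor([[1]]), never index 2 or 3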