├── assets
│   ├── the-primer.png
│   └── pagedattention.png
├── README.md
├── llama3-naive.py
├── tokenizer.py
├── LICENSE
└── llama3-paged.py
/assets/the-primer.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tspeterkim/paged-attention-minimal/HEAD/assets/the-primer.png
--------------------------------------------------------------------------------
/assets/pagedattention.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/tspeterkim/paged-attention-minimal/HEAD/assets/pagedattention.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # paged-attention-minimal
2 | 
3 | A minimal PagedAttention cache manager. `llama3-paged.py` is a <300 line implementation of:
4 | * Llama3 batch inference using ShareGPT prompts (`sharegpt-filtered.json`).
5 | * A KV cache manager for PagedAttention.
6 | 
7 | This repo aims to show, minimally,
8 | how PagedAttention achieves larger batch sizes and higher request throughput.
9 | 
10 | To be clear, this is not a from-scratch implementation of PagedAttention. We'll use Flash Attention's
11 | PagedAttention kernel, but write our own KV cache manager as
12 | Tri Dao [suggests](https://github.com/Dao-AILab/flash-attention/issues/660):
13 | 
14 | ![the-primer](assets/the-primer.png)
15 | 
16 | ## Prereqs
17 | 
18 | ### Llama3 Weights
19 | 
20 | [Download](https://github.com/meta-llama/llama3?tab=readme-ov-file#download)
21 | the pretrained weights. Here's one way using the command line, after you `pip install huggingface-hub`:
22 | 
23 | ```bash
24 | huggingface-cli download meta-llama/Meta-Llama-3-8B-Instruct --include "original/*" --local-dir Meta-Llama-3-8B-Instruct --token {YOUR_HF_TOKEN}
25 | ```
26 | 
27 | ### Dependencies
28 | 
29 | ```bash
30 | pip install torch # cuda required
31 | pip install tiktoken # for the tokenizer
32 | pip install flash-attn --no-build-isolation # for its PagedAttention implementation
33 | ```
34 | 
35 | ## Quick Start
36 | 
37 | ### Naive
38 | 
39 | Generate responses for 4 requests using the naive method
40 | (pre-allocating the full max sequence length of the KV cache for each request):
41 | ```bash
42 | python llama3-naive.py 4
43 | ```
44 | This will return 4 responses generated using Llama3.
45 | 
46 | Try increasing the number to see at which batch size OOM occurs. For my setup (a 4090 with 24GB of memory),
47 | the maximum batch size is 7 (I get an OOM with 8):
48 | ```bash
49 | $ python llama3-naive.py 7
50 | ...
51 | --------------------------------------------------
52 | Fragmented Memory: 7.23 GB (28.46%)
53 | ```
54 | Note how **~30% of the entire GPU memory becomes fragmented and unusable.**
55 | Let's see how using PagedAttention improves this.
56 | 
57 | ### PagedAttention
58 | 
59 | With PagedAttention, we allocate memory only as it is needed while generating tokens.
60 | **This decreases fragmentation to <1%, and increases maximum batch size by 7X for me:**
61 | ```bash
62 | $ python llama3-paged.py 49
63 | ...
64 | --------------------------------------------------
65 | Fragmented Memory: 0.14 GB (0.57%)
66 | ```
67 | 
68 | Note that these batch sizes are specific to my setup. If you have more GPU memory available,
69 | you will be able to use a larger batch size before you OOM.
70 | Regardless, the fact remains that PagedAttention lets you dramatically increase your batch size
71 | by reducing memory fragmentation.
72 | The benefit of PagedAttention will be apparent on any GPU.
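For a rough sense of where your own setup will tip over, here's a back-of-the-envelope sketch (assuming the Llama3-8B config that `params.json` provides: 32 layers, 8 KV heads, head dim 128, and a bf16 cache):

```python
# Hedged estimate of the naive method's pre-allocated KV cache footprint per request (Llama3-8B, bf16).
n_layers, n_kv_heads, head_dim, max_seq_len = 32, 8, 128, 8192
bytes_per_token = 2 * n_kv_heads * head_dim * 2        # K and V, 2 bytes each in bf16
per_request = n_layers * max_seq_len * bytes_per_token
print(f'{per_request / 1e9:.2f} GB reserved per request')  # ~1.07 GB
```

With roughly 16 GB of bf16 weights already resident on a 24GB card, reserving ~1 GB of cache per request is roughly what caps the naive method at a batch size of 7.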
73 | 
74 | ## Fun Details
75 | 
76 | ### PagedAttention
77 | 
78 | Traditionally, a request's KV cache is 1) stored in contiguous memory space, and 2) pre-allocated with the maximum
79 | context length (8192 for Llama3). This results in severe internal memory fragmentation: e.g. if a request ends up
80 | generating only 792 tokens, ~90% (7400/8192) of its pre-allocated memory is fragmented, i.e. unable to be used by
81 | any other request.
82 | 
83 | To reduce memory fragmentation and increase request throughput (batch size), PagedAttention offers a non-contiguous
84 | KV cache memory management scheme, loosely following [OS paging](https://en.wikipedia.org/wiki/Memory_paging).
85 | This ensures that memory fragmentation occurs only in the last block assigned to each request: in the diagram below,
86 | outlined in red, 3 tokens in Physical Block 3 for request A, and 2 tokens in Physical Block 2 for request B.
87 | 
88 | ![paged-attention](/assets/pagedattention.png)
89 | 
90 | I also found it helpful to think about it in code. Instead of this:
91 | ```python
92 | y = attn(k_cache=k_cache, v_cache=v_cache, ...)
93 | ```
94 | PagedAttention does this:
95 | ```python
96 | y = paged_attn(k_cache=k_cache_paged, v_cache=v_cache_paged, block_table=block_table, ...)
97 | ```
98 | Unlike `k_cache`, `k_cache_paged` is non-contiguous and shared by all requests. Physical blocks 0~8 can be
99 | assigned to any request, which is why we also pass in `block_table`: the per-request mapping of
100 | logical blocks to physical blocks. In the diagram above, `block_table` would look something like
101 | `{0: [7,1,3], 1: [5,2]}` (0 and 1 being the indices of requests A and B, respectively).
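To make the indirection concrete, here's a toy sketch (not code from this repo, and using a made-up block size of 4 instead of the 256 that `llama3-paged.py` uses) of how a token position is resolved through the block table:

```python
# Toy illustration of the block-table lookup that the paged attention kernel performs internally.
block_table = {0: [7, 1, 3], 1: [5, 2]}  # request index -> ordered list of physical block indices

def physical_slot(request: int, pos: int, block_size: int = 4):
    logical_block, offset = divmod(pos, block_size)  # which logical block, and the offset inside it
    return block_table[request][logical_block], offset

print(physical_slot(0, 9))  # token 9 of request A lives in physical block 3, slot 1
```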
102 | 
103 | So who makes these assignments?
104 | 
105 | ### KV cache manager
106 | 
107 | This is what I [reimplement](https://github.com/tspeterkim/paged-attention-minimal/blob/main/llama3-paged.py#L134-L224)
108 | in this repo. I also added a very basic optimization: freeing the blocks of finished requests so that they can be
109 | reused by other, unfinished requests. Overall, my cache manager focuses on simplicity at the cost of performance.
110 | Some of the design choices I made make the inference latency scale linearly with the number of requests.
111 | Shame on me. Please feel free to suggest improvements. In principle, the inference latency should be constant for any
112 | batch size.
113 | 
114 | To be fair, I made these design choices because they were enough to show the increase in request throughput
115 | from PagedAttention in <300 lines. Making the cache manager more performant would probably sacrifice that
116 | minimality. However, performance does matter in the end, and it is the reason why 8.5K-LOC
117 | systems like [vLLM](https://github.com/vllm-project/vllm) exist.
118 | 
119 | ## Acknowledgements
120 | 
121 | Thanks to:
122 | * Meta for the Llama3 [code](https://github.com/meta-llama/llama3) and weights
123 | * @naklecha for the minimal (and entertaining) Llama3 inference [code](https://github.com/naklecha/llama3-from-scratch)
124 | * The authors of the PagedAttention [paper](https://arxiv.org/pdf/2309.06180).
125 | * Tri Dao for the Flash Attention repo and its PagedAttention [implementation](https://github.com/Dao-AILab/flash-attention).
126 | 
--------------------------------------------------------------------------------
/llama3-naive.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import json
3 | 
4 | import torch
5 | from flash_attn import flash_attn_with_kvcache
6 | 
7 | from tokenizer import ChatFormat, Tokenizer
8 | 
9 | # Housekeeping to load pretrained llama.
10 | device = 'cuda'
11 | model_name = 'Meta-Llama-3-8B-Instruct'
12 | tokenizer_path = f'{model_name}/original/tokenizer.model'
13 | tokenizer = Tokenizer(model_path=tokenizer_path)
14 | 
15 | model = torch.load(f'{model_name}/original/consolidated.00.pth', map_location=device, mmap=False)
16 | 
17 | with open(f'{model_name}/original/params.json', 'r') as f:
18 |     config = json.load(f)
19 | 
20 | dim = config['dim']
21 | n_layers = config['n_layers']
22 | n_heads = config['n_heads']
23 | n_kv_heads = config['n_kv_heads']
24 | vocab_size = config['vocab_size']
25 | multiple_of = config['multiple_of']
26 | ffn_dim_multiplier = config['ffn_dim_multiplier']
27 | norm_eps = config['norm_eps']
28 | rope_theta = torch.tensor(config['rope_theta'], device=device)
29 | head_dim = dim // n_heads # 4096 // 32 = 128
30 | max_seq_len = 8192
31 | 
32 | stop_tokens = torch.tensor(list(tokenizer.stop_tokens), device=device)
33 | 
34 | # Set Embedding
35 | embedding_layer = torch.nn.Embedding(vocab_size, dim, device=device, _weight=model['tok_embeddings.weight'])
36 | 
37 | # Precompute freqs cis for rope
38 | zero_to_one_split_into_64_parts = torch.tensor(range(head_dim//2), device=device)/(head_dim//2)
39 | freqs = 1.0 / (rope_theta ** zero_to_one_split_into_64_parts)
40 | freqs_for_each_token = torch.outer(torch.arange(max_seq_len, device=device), freqs)
41 | freqs_cis_max = torch.polar(torch.ones_like(freqs_for_each_token), freqs_for_each_token)
42 | 
43 | # Utility funcs for rope
44 | def reshape_for_broadcast(freqs_cis, x):
45 |     shape = [d if i == 1 or i == x.ndim - 1 else 1 for i, d in enumerate(x.shape)]
46 |     return freqs_cis.view(*shape)
47 | 
48 | def apply_rotary_emb(xq, xk, freqs_cis):
49 |     xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
50 |     xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
51 |     freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
52 |     xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
53 |     xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
54 |     return xq_out.type_as(xq), xk_out.type_as(xk)
55 | 
56 | def rms_norm(tensor, norm_weights):
57 |     return (tensor * torch.rsqrt(tensor.pow(2).mean(-1, keepdim=True) + norm_eps)) * norm_weights
58 | 
59 | # Generate next token i.e.
do one forward pass of llama 60 | def forward(tokens, start_pos): 61 | bsz, T = tokens.shape 62 | final_embedding = embedding_layer(tokens) 63 | freqs_cis = freqs_cis_max[start_pos:start_pos+T, :] 64 | 65 | for layer in range(n_layers): 66 | q_layer = model[f'layers.{layer}.attention.wq.weight'] 67 | k_layer = model[f'layers.{layer}.attention.wk.weight'] 68 | v_layer = model[f'layers.{layer}.attention.wv.weight'] 69 | w_layer = model[f'layers.{layer}.attention.wo.weight'] 70 | 71 | layer_embedding_norm = rms_norm(final_embedding, model[f'layers.{layer}.attention_norm.weight']) 72 | 73 | q = layer_embedding_norm @ q_layer.T 74 | k = layer_embedding_norm @ k_layer.T 75 | v = layer_embedding_norm @ v_layer.T 76 | 77 | q = q.view(bsz, T, n_heads, head_dim) 78 | k = k.view(bsz, T, n_kv_heads, head_dim) 79 | v = v.view(bsz, T, n_kv_heads, head_dim) 80 | 81 | q, k = apply_rotary_emb(q, k, freqs_cis) 82 | 83 | # Use flash attention with kv-cache support. 84 | k_cache, v_cache = kv_cache[layer] 85 | y = flash_attn_with_kvcache(q, k_cache, v_cache, k, v, cache_seqlens=start_pos, causal=True) 86 | 87 | stacked_qkv_attention = y.view(bsz, T, dim) 88 | 89 | embedding_delta = torch.matmul(stacked_qkv_attention, w_layer.T) 90 | embedding_after_edit = final_embedding + embedding_delta 91 | embedding_after_edit_normalized = rms_norm(embedding_after_edit, model[f'layers.{layer}.ffn_norm.weight']) 92 | w1 = model[f'layers.{layer}.feed_forward.w1.weight'] 93 | w2 = model[f'layers.{layer}.feed_forward.w2.weight'] 94 | w3 = model[f'layers.{layer}.feed_forward.w3.weight'] 95 | output_after_feedforward = torch.matmul(torch.functional.F.silu(torch.matmul(embedding_after_edit_normalized, w1.T)) * torch.matmul(embedding_after_edit_normalized, w3.T), w2.T) 96 | final_embedding = embedding_after_edit + output_after_feedforward 97 | 98 | final_embedding = rms_norm(final_embedding, model['norm.weight']) 99 | logits = torch.matmul(final_embedding[:,-1,:], model['output.weight'].T) 100 | tokens = torch.argmax(logits, dim=-1) 101 | return tokens 102 | 103 | # Load ShareGPT prompts 104 | with open('sharegpt-filtered.json') as f: 105 | sharegpt = json.load(f) 106 | 107 | requests = [] 108 | for i in range(len(sharegpt)): 109 | conversations = sharegpt[i]['conversations'] 110 | if len(conversations) > 0: 111 | requests.append([{'role': 'user', 'content': sharegpt[i]['conversations'][0]['value']}]) 112 | 113 | # Use given amount of requests 114 | num_requests = int(sys.argv[1]) 115 | dialogs = requests[:num_requests] 116 | 117 | # Tokenize 118 | prompt_tokens = [ChatFormat(tokenizer).encode_dialog_prompt(d) for d in dialogs] 119 | bsz = len(prompt_tokens) 120 | min_prompt_len = min(len(t) for t in prompt_tokens) 121 | 122 | tokens = torch.full((bsz, max_seq_len), tokenizer.pad_id, dtype=torch.long, device=device) 123 | for k, t in enumerate(prompt_tokens): 124 | tokens[k, :len(t)] = torch.tensor(t, dtype=torch.long, device=device) 125 | 126 | prev_pos = 0 127 | eos_reached = torch.tensor([False] * bsz, device=device) 128 | input_text_mask = tokens != tokenizer.pad_id 129 | 130 | # Pre-allocate KV Cache. 131 | # Notice how we reserve `max_seq_len` length of tokens per request. 132 | # Other requests cannot use this space, leading to internal fragmentation. 
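# (For scale: with Llama3-8B's config (32 layers, 8 KV heads, head_dim 128), each of the K and V slabs below
# is 8192 * 8 * 128 * 2 bytes ~= 16 MB per request per layer, so across K, V, and all layers this reserves
# roughly 1 GB of cache per request before a single token is generated.)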
133 | kv_cache = [(torch.randn((bsz, max_seq_len, n_kv_heads, head_dim), dtype=torch.bfloat16, device=device), 134 | torch.randn((bsz, max_seq_len, n_kv_heads, head_dim), dtype=torch.bfloat16, device=device)) for _ in range(n_layers)] 135 | 136 | # Do inference 137 | for cur_pos in range(min_prompt_len, max_seq_len): 138 | next_token = forward(tokens[:, prev_pos:cur_pos], prev_pos) 139 | next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token) 140 | tokens[:, cur_pos] = next_token 141 | 142 | eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens)) 143 | prev_pos = cur_pos 144 | 145 | if all(eos_reached): 146 | break 147 | 148 | # Print generated answers / calculate fragmented memory size 149 | fragmented_memory_size = 0 150 | for i, toks in enumerate(tokens.tolist()): 151 | start = 0 if False else len(prompt_tokens[i]) 152 | toks = toks[start: len(prompt_tokens[i]) + max_seq_len] 153 | for stop_token in tokenizer.stop_tokens: 154 | try: 155 | eos_idx = toks.index(stop_token) 156 | toks = toks[:eos_idx] 157 | fragmented_memory_size += (max_seq_len - eos_idx) * n_kv_heads * head_dim * 2 * 2 * n_layers 158 | except ValueError: 159 | pass 160 | print(tokenizer.decode(toks)) 161 | print('-'*50) 162 | 163 | # Print fragmented memory size and percentage 164 | fragmented_ratio = fragmented_memory_size / torch.cuda.get_device_properties(0).total_memory 165 | print(f'Fragmented Memory: {fragmented_memory_size / 1e9:.2f} GB ({fragmented_ratio * 100:.2f}%)') 166 | -------------------------------------------------------------------------------- /tokenizer.py: -------------------------------------------------------------------------------- 1 | # Taken from official llama3 repo. 2 | # Copyright (c) Meta Platforms, Inc. and affiliates. 3 | # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. 4 | 5 | import os 6 | from logging import getLogger 7 | from pathlib import Path 8 | from typing import ( 9 | AbstractSet, 10 | cast, 11 | Collection, 12 | Dict, 13 | Iterator, 14 | List, 15 | Literal, 16 | Sequence, 17 | TypedDict, 18 | Union, 19 | ) 20 | 21 | import tiktoken 22 | from tiktoken.load import load_tiktoken_bpe 23 | 24 | 25 | logger = getLogger(__name__) 26 | 27 | 28 | Role = Literal["system", "user", "assistant"] 29 | 30 | 31 | class Message(TypedDict): 32 | role: Role 33 | content: str 34 | 35 | 36 | Dialog = Sequence[Message] 37 | 38 | 39 | class Tokenizer: 40 | """ 41 | Tokenizing and encoding/decoding text using the Tiktoken tokenizer. 42 | """ 43 | 44 | special_tokens: Dict[str, int] 45 | 46 | num_reserved_special_tokens = 256 47 | 48 | pat_str = r"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+" # noqa: E501 49 | 50 | def __init__(self, model_path: str): 51 | """ 52 | Initializes the Tokenizer with a Tiktoken model. 53 | 54 | Args: 55 | model_path (str): The path to the Tiktoken model file. 
56 | """ 57 | assert os.path.isfile(model_path), model_path 58 | 59 | mergeable_ranks = load_tiktoken_bpe(model_path) 60 | num_base_tokens = len(mergeable_ranks) 61 | special_tokens = [ 62 | "<|begin_of_text|>", 63 | "<|end_of_text|>", 64 | "<|reserved_special_token_0|>", 65 | "<|reserved_special_token_1|>", 66 | "<|reserved_special_token_2|>", 67 | "<|reserved_special_token_3|>", 68 | "<|start_header_id|>", 69 | "<|end_header_id|>", 70 | "<|reserved_special_token_4|>", 71 | "<|eot_id|>", # end of turn 72 | ] + [ 73 | f"<|reserved_special_token_{i}|>" 74 | for i in range(5, self.num_reserved_special_tokens - 5) 75 | ] 76 | self.special_tokens = { 77 | token: num_base_tokens + i for i, token in enumerate(special_tokens) 78 | } 79 | self.model = tiktoken.Encoding( 80 | name=Path(model_path).name, 81 | pat_str=self.pat_str, 82 | mergeable_ranks=mergeable_ranks, 83 | special_tokens=self.special_tokens, 84 | ) 85 | logger.info(f"Reloaded tiktoken model from {model_path}") 86 | 87 | self.n_words: int = self.model.n_vocab 88 | # BOS / EOS token IDs 89 | self.bos_id: int = self.special_tokens["<|begin_of_text|>"] 90 | self.eos_id: int = self.special_tokens["<|end_of_text|>"] 91 | self.pad_id: int = -1 92 | self.stop_tokens = { 93 | self.special_tokens["<|end_of_text|>"], 94 | self.special_tokens["<|eot_id|>"], 95 | } 96 | logger.info( 97 | f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}" 98 | ) 99 | 100 | def encode( 101 | self, 102 | s: str, 103 | *, 104 | bos: bool, 105 | eos: bool, 106 | allowed_special: Union[Literal["all"], AbstractSet[str]] = set(), 107 | disallowed_special: Union[Literal["all"], Collection[str]] = (), 108 | ) -> List[int]: 109 | """ 110 | Encodes a string into a list of token IDs. 111 | 112 | Args: 113 | s (str): The input string to be encoded. 114 | bos (bool): Whether to prepend the beginning-of-sequence token. 115 | eos (bool): Whether to append the end-of-sequence token. 116 | allowed_tokens ("all"|set[str]): allowed special tokens in string 117 | disallowed_tokens ("all"|set[str]): special tokens that raise an error when in string 118 | 119 | Returns: 120 | list[int]: A list of token IDs. 121 | 122 | By default, setting disallowed_special=() encodes a string by ignoring 123 | special tokens. Specifically: 124 | - Setting `disallowed_special` to () will cause all text corresponding 125 | to special tokens to be encoded as natural text (insteading of raising 126 | an error). 127 | - Setting `allowed_special` to "all" will treat all text corresponding 128 | to special tokens to be encoded as special tokens. 129 | """ 130 | assert type(s) is str 131 | 132 | # The tiktoken tokenizer can handle <=400k chars without 133 | # pyo3_runtime.PanicException. 134 | TIKTOKEN_MAX_ENCODE_CHARS = 400_000 135 | 136 | # https://github.com/openai/tiktoken/issues/195 137 | # Here we iterate over subsequences and split if we exceed the limit 138 | # of max consecutive non-whitespace or whitespace characters. 
139 | MAX_NO_WHITESPACES_CHARS = 25_000 140 | 141 | substrs = ( 142 | substr 143 | for i in range(0, len(s), TIKTOKEN_MAX_ENCODE_CHARS) 144 | for substr in self._split_whitespaces_or_nonwhitespaces( 145 | s[i : i + TIKTOKEN_MAX_ENCODE_CHARS], MAX_NO_WHITESPACES_CHARS 146 | ) 147 | ) 148 | t: List[int] = [] 149 | for substr in substrs: 150 | t.extend( 151 | self.model.encode( 152 | substr, 153 | allowed_special=allowed_special, 154 | disallowed_special=disallowed_special, 155 | ) 156 | ) 157 | if bos: 158 | t.insert(0, self.bos_id) 159 | if eos: 160 | t.append(self.eos_id) 161 | return t 162 | 163 | def decode(self, t: Sequence[int]) -> str: 164 | """ 165 | Decodes a list of token IDs into a string. 166 | 167 | Args: 168 | t (List[int]): The list of token IDs to be decoded. 169 | 170 | Returns: 171 | str: The decoded string. 172 | """ 173 | # Typecast is safe here. Tiktoken doesn't do anything list-related with the sequence. 174 | return self.model.decode(cast(List[int], t)) 175 | 176 | @staticmethod 177 | def _split_whitespaces_or_nonwhitespaces( 178 | s: str, max_consecutive_slice_len: int 179 | ) -> Iterator[str]: 180 | """ 181 | Splits the string `s` so that each substring contains no more than `max_consecutive_slice_len` 182 | consecutive whitespaces or consecutive non-whitespaces. 183 | """ 184 | current_slice_len = 0 185 | current_slice_is_space = s[0].isspace() if len(s) > 0 else False 186 | slice_start = 0 187 | 188 | for i in range(len(s)): 189 | is_now_space = s[i].isspace() 190 | 191 | if current_slice_is_space ^ is_now_space: 192 | current_slice_len = 1 193 | current_slice_is_space = is_now_space 194 | else: 195 | current_slice_len += 1 196 | if current_slice_len > max_consecutive_slice_len: 197 | yield s[slice_start:i] 198 | slice_start = i 199 | current_slice_len = 1 200 | yield s[slice_start:] 201 | 202 | 203 | class ChatFormat: 204 | def __init__(self, tokenizer: Tokenizer): 205 | self.tokenizer = tokenizer 206 | 207 | def encode_header(self, message: Message) -> List[int]: 208 | tokens = [] 209 | tokens.append(self.tokenizer.special_tokens["<|start_header_id|>"]) 210 | tokens.extend(self.tokenizer.encode(message["role"], bos=False, eos=False)) 211 | tokens.append(self.tokenizer.special_tokens["<|end_header_id|>"]) 212 | tokens.extend(self.tokenizer.encode("\n\n", bos=False, eos=False)) 213 | return tokens 214 | 215 | def encode_message(self, message: Message) -> List[int]: 216 | tokens = self.encode_header(message) 217 | tokens.extend( 218 | self.tokenizer.encode(message["content"].strip(), bos=False, eos=False) 219 | ) 220 | tokens.append(self.tokenizer.special_tokens["<|eot_id|>"]) 221 | return tokens 222 | 223 | def encode_dialog_prompt(self, dialog: Dialog) -> List[int]: 224 | tokens = [] 225 | tokens.append(self.tokenizer.special_tokens["<|begin_of_text|>"]) 226 | for message in dialog: 227 | tokens.extend(self.encode_message(message)) 228 | # Add the start of an assistant message for the model to complete. 229 | tokens.extend(self.encode_header({"role": "assistant", "content": ""})) 230 | return tokens -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /llama3-paged.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import json 3 | import math 4 | import random 5 | 6 | import torch 7 | from flash_attn import flash_attn_with_kvcache 8 | 9 | from tokenizer import ChatFormat, Tokenizer 10 | 11 | # Housekeeping to load pretrained llama. 12 | device = 'cuda' 13 | model_name = 'Meta-Llama-3-8B-Instruct' 14 | tokenizer_path = f'{model_name}/original/tokenizer.model' 15 | tokenizer = Tokenizer(model_path=tokenizer_path) 16 | 17 | model = torch.load(f'{model_name}/original/consolidated.00.pth', map_location=device, mmap=False) 18 | 19 | with open(f'{model_name}/original/params.json', 'r') as f: 20 | config = json.load(f) 21 | 22 | dim = config['dim'] 23 | n_layers = config['n_layers'] 24 | n_heads = config['n_heads'] 25 | n_kv_heads = config['n_kv_heads'] 26 | vocab_size = config['vocab_size'] 27 | multiple_of = config['multiple_of'] 28 | ffn_dim_multiplier = config['ffn_dim_multiplier'] 29 | norm_eps = config['norm_eps'] 30 | rope_theta = torch.tensor(config['rope_theta'], device=device) 31 | head_dim = dim // n_heads # 4096 // 32 = 128 32 | max_seq_len = 8192 33 | 34 | stop_tokens = torch.tensor(list(tokenizer.stop_tokens), device=device) 35 | 36 | # Set Embedding 37 | embedding_layer = torch.nn.Embedding(vocab_size, dim, device=device, _weight=model['tok_embeddings.weight']) 38 | 39 | # Precompute freqs cis for rope 40 | zero_to_one_split_into_64_parts = torch.tensor(range(head_dim//2), device=device)/(head_dim//2) 41 | freqs = 1.0 / (rope_theta ** zero_to_one_split_into_64_parts) 42 | freqs_for_each_token = torch.outer(torch.arange(max_seq_len, device=device), freqs) 43 | freqs_cis_max = torch.polar(torch.ones_like(freqs_for_each_token), freqs_for_each_token) 44 | 45 | # Utility funcs for rope 46 | def reshape_for_broadcast(freqs_cis, x): 47 | shape = [d if i == 1 or i == x.ndim - 1 else 1 for i, d in enumerate(x.shape)] 48 | return freqs_cis.view(*shape) 49 | 50 | def apply_rotary_emb(xq, xk, freqs_cis): 51 | xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) 52 | xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) 53 | freqs_cis = reshape_for_broadcast(freqs_cis, xq_) 54 | xq_out = torch.view_as_real(xq_ * 
freqs_cis).flatten(3) 55 | xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) 56 | return xq_out.type_as(xq), xk_out.type_as(xk) 57 | 58 | def rms_norm(tensor, norm_weights): 59 | return (tensor * torch.rsqrt(tensor.pow(2).mean(-1, keepdim=True) + norm_eps)) * norm_weights 60 | 61 | # Generate next token i.e. do one forward pass of llama 62 | def forward(tokens, start_pos): 63 | bsz, T = tokens.shape 64 | final_embedding = embedding_layer(tokens) 65 | freqs_cis = freqs_cis_max[start_pos:start_pos+T, :] 66 | 67 | for layer in range(n_layers): 68 | q_layer = model[f'layers.{layer}.attention.wq.weight'] 69 | k_layer = model[f'layers.{layer}.attention.wk.weight'] 70 | v_layer = model[f'layers.{layer}.attention.wv.weight'] 71 | w_layer = model[f'layers.{layer}.attention.wo.weight'] 72 | 73 | layer_embedding_norm = rms_norm(final_embedding, model[f'layers.{layer}.attention_norm.weight']) 74 | 75 | q = layer_embedding_norm @ q_layer.T 76 | k = layer_embedding_norm @ k_layer.T 77 | v = layer_embedding_norm @ v_layer.T 78 | 79 | q = q.view(bsz, T, n_heads, head_dim) 80 | k = k.view(bsz, T, n_kv_heads, head_dim) 81 | v = v.view(bsz, T, n_kv_heads, head_dim) 82 | 83 | q, k = apply_rotary_emb(q, k, freqs_cis) 84 | 85 | # Use KV Cache Manager to get paged_kv_cache and (logical) block table 86 | block_table = cms[layer].get_block_table() 87 | k_cache_paged, v_cache_paged = cms[layer].get_kv_cache() 88 | cache_seqlens = torch.where(eos_reached, cms[layer].get_last_pos(), torch.tensor([start_pos]*bsz, dtype=torch.int32, device=device)) 89 | y = flash_attn_with_kvcache(q, k_cache_paged, v_cache_paged, k, v, cache_seqlens=cache_seqlens, block_table=block_table, causal=True) 90 | 91 | stacked_qkv_attention = y.view(bsz, T, dim) 92 | 93 | embedding_delta = torch.matmul(stacked_qkv_attention, w_layer.T) 94 | embedding_after_edit = final_embedding + embedding_delta 95 | embedding_after_edit_normalized = rms_norm(embedding_after_edit, model[f'layers.{layer}.ffn_norm.weight']) 96 | w1 = model[f'layers.{layer}.feed_forward.w1.weight'] 97 | w2 = model[f'layers.{layer}.feed_forward.w2.weight'] 98 | w3 = model[f'layers.{layer}.feed_forward.w3.weight'] 99 | output_after_feedforward = torch.matmul(torch.functional.F.silu(torch.matmul(embedding_after_edit_normalized, w1.T)) * torch.matmul(embedding_after_edit_normalized, w3.T), w2.T) 100 | final_embedding = embedding_after_edit + output_after_feedforward 101 | 102 | final_embedding = rms_norm(final_embedding, model['norm.weight']) 103 | logits = torch.matmul(final_embedding[:,-1,:], model['output.weight'].T) 104 | tokens = torch.argmax(logits, dim=-1) 105 | return tokens 106 | 107 | # Load ShareGPT prompts 108 | with open('sharegpt-filtered.json') as f: 109 | sharegpt = json.load(f) 110 | 111 | requests = [] 112 | for i in range(len(sharegpt)): 113 | conversations = sharegpt[i]['conversations'] 114 | if len(conversations) > 0: 115 | requests.append([{'role': 'user', 'content': sharegpt[i]['conversations'][0]['value']}]) 116 | 117 | # Use given amount of requests 118 | num_requests = int(sys.argv[1]) 119 | dialogs = requests[:num_requests] 120 | 121 | # Tokenize 122 | prompt_tokens = [ChatFormat(tokenizer).encode_dialog_prompt(d) for d in dialogs] 123 | bsz = len(prompt_tokens) 124 | min_prompt_len = min(len(t) for t in prompt_tokens) 125 | 126 | tokens = torch.full((bsz, max_seq_len), tokenizer.pad_id, dtype=torch.long, device=device) 127 | for k, t in enumerate(prompt_tokens): 128 | tokens[k, :len(t)] = torch.tensor(t, dtype=torch.long, device=device) 129 | 130 | 
prev_pos = 0 131 | eos_reached = torch.tensor([False] * bsz, device=device) 132 | input_text_mask = tokens != tokenizer.pad_id 133 | 134 | # KV Cache Manager for PagedAttention 135 | # Flash Attention currently only supports block sizes of multiples of 256. (https://github.com/Dao-AILab/flash-attention/pull/824) 136 | block_size = 256 137 | class CacheManager: 138 | def __init__(self, tokens, block_size=block_size, batch_size=bsz, n_kv_heads=n_kv_heads, head_dim=head_dim): 139 | self.block_size = block_size 140 | self.batch_size = bsz 141 | self.n_kv_heads = n_kv_heads 142 | self.head_dim = head_dim 143 | self.num_blocks = (max_seq_len // block_size) * 5 # TODO: make this dynamic 144 | self.block_table = {i: [] for i in range(batch_size)} 145 | self.free_blocks = set(range(self.num_blocks)) 146 | self.k_cache_paged = torch.randn(self.num_blocks, block_size, n_kv_heads, head_dim, device=device, dtype=torch.bfloat16) 147 | self.v_cache_paged = torch.randn(self.num_blocks, block_size, n_kv_heads, head_dim, device=device, dtype=torch.bfloat16) 148 | 149 | seq_lens = (tokens != -1).sum(1) 150 | for i, t in enumerate(seq_lens.tolist()): 151 | num_blocks_to_reserve = math.ceil(t / block_size) 152 | num_filled_positions = t % block_size 153 | for b in range(num_blocks_to_reserve): 154 | index = self.get_free_block() 155 | if b == num_blocks_to_reserve-1: 156 | self.block_table[i].append((index, num_filled_positions)) 157 | else: 158 | self.block_table[i].append((index, block_size)) 159 | 160 | # Returns a free block to allocate more tokens to. 161 | # For simplicity, I raise an error when we run out of free blocks. 162 | # In the actual implementation, it solves this through scheduling and preemption (see paper) 163 | def get_free_block(self): 164 | if len(self.free_blocks) == 0: 165 | raise Exception('No more free blocks. Implement scheduling and preemption.') 166 | index = random.choice(list(self.free_blocks)) 167 | self.free_blocks.remove(index) 168 | return index 169 | 170 | # Gets the logical block table that PagedAttention uses 171 | # TODO: Serial computation makes it slow. Is there a faster way? 172 | def get_block_table(self): 173 | max_len = max(len(b) for b in self.block_table.values()) 174 | block_table = [[-1] * max_len for _ in range(self.batch_size)] 175 | for i, b in self.block_table.items(): 176 | for j, (index, _) in enumerate(b): 177 | block_table[i][j] = index 178 | return torch.tensor(block_table, dtype=torch.int32, device=device) 179 | 180 | def get_kv_cache(self): 181 | return self.k_cache_paged, self.v_cache_paged 182 | 183 | # Specific to my KV implementation. Returns the last sequence position given the block table. 184 | def get_last_pos(self): 185 | last_pos = [(len(b)-1)*self.block_size + b[len(b)-1][1]-1 for b in self.block_table.values()] 186 | return torch.tensor(last_pos, dtype=torch.int32, device=device) 187 | 188 | # Frees request's blocks. 189 | # Here, I leave one block, and free the rest. This is a limitation imposed by my kv cache implementation. 190 | # TODO: Avoid this limitation. 191 | def free_memory(self, index): 192 | blocks = self.block_table[index] 193 | if len(blocks) == 1: 194 | return 195 | for i, _ in blocks[1:]: 196 | self.free_blocks.add(i) 197 | self.block_table[index] = blocks[:1] 198 | 199 | # Updates block table and filled positions. 200 | # TODO: Again, pretty slow. Faster parallel way? 
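# Per request: prompt positions are skipped; a finished request has its extra blocks freed; otherwise we either
# bump the fill count of the request's last block, or grab a fresh block when that block is already full.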
201 | def update(self, eos_reached, input_text_mask): 202 | for i, (eos, is_prompt) in enumerate(zip(eos_reached, input_text_mask)): 203 | if is_prompt: # if the token is part of the original prompt, we skip 204 | continue 205 | if eos: # free the request's blocks since we have generated the complete answer 206 | self.free_memory(i) 207 | continue 208 | 209 | old_index, n = self.block_table[i][-1] 210 | if n == self.block_size: # allocate new block if necessary 211 | new_index = self.get_free_block() 212 | self.block_table[i].append((new_index, 1)) 213 | else: # otherwise, just use the next available slot in the block 214 | self.block_table[i][-1] = (old_index, n+1) 215 | 216 | def get_fragmented_memory_size(self): 217 | size = 0 218 | for b in self.block_table.values(): 219 | _, filled = b[-1] # only the last block has fragmentation 220 | size += (self.block_size - filled) * n_kv_heads * head_dim * 2 * 2 221 | return size 222 | 223 | # Create CacheManagers for each layer 224 | cms = [CacheManager(tokens) for _ in range(n_layers)] 225 | 226 | # Do inference 227 | for cur_pos in range(min_prompt_len, max_seq_len): 228 | next_token = forward(tokens[:,prev_pos:cur_pos], prev_pos) 229 | next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token) 230 | tokens[:, cur_pos] = next_token 231 | 232 | # Update CacheManagers. Increment filled positions + allocate new block if required. 233 | for layer in range(n_layers): 234 | cms[layer].update(eos_reached.tolist(), input_text_mask[:, cur_pos].tolist()) 235 | 236 | eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens)) 237 | prev_pos = cur_pos 238 | 239 | if all(eos_reached): 240 | break 241 | 242 | # Print generated answers 243 | for i, toks in enumerate(tokens.tolist()): 244 | start = 0 if False else len(prompt_tokens[i]) 245 | toks = toks[start: len(prompt_tokens[i]) + max_seq_len] 246 | for stop_token in tokenizer.stop_tokens: 247 | try: 248 | eos_idx = toks.index(stop_token) 249 | toks = toks[:eos_idx] 250 | except ValueError: 251 | pass 252 | print(tokenizer.decode(toks)) 253 | print('-'*50) 254 | 255 | # Print fragmented memory size and percentage 256 | fragmented_memory_size = sum(cms[layer].get_fragmented_memory_size() for layer in range(n_layers)) 257 | fragmented_ratio = fragmented_memory_size / torch.cuda.get_device_properties(0).total_memory 258 | print(f'Fragmented Memory: {fragmented_memory_size / 1e9:.2f} GB ({fragmented_ratio * 100:.2f}%)') --------------------------------------------------------------------------------