├── LLaMA_Final.pdf
├── README.md
└── mqa_comparison.py

--------------------------------------------------------------------------------
/LLaMA_Final.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hkproj/pytorch-llama-notes/af0c07671373d29ff3279e78b90d6b9a83249a1d/LLaMA_Final.pdf

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# pytorch-llama-notes

Notes for the video https://youtu.be/Mn_9W1nCFLo

--------------------------------------------------------------------------------
/mqa_comparison.py:
--------------------------------------------------------------------------------
import torch

# Algorithms described in the paper
# "Fast Transformer Decoding: One Write-Head is All You Need", Noam Shazeer, 2019


def MultiheadAttentionBatched():
    # Standard multi-head attention over a whole sequence (training / prompt processing).
    d_model, seq_len_kv, seq_len, b, h, d_k, d_v = 512, 10, 10, 32, 8, (512 // 8), (512 // 8)

    X = torch.rand(b, seq_len, d_model)  # Query
    M = torch.rand(b, seq_len_kv, d_model)  # Key and Value
    mask = torch.rand(b, h, seq_len, seq_len_kv)  # Stand-in for an additive attention mask
    P_q = torch.rand(h, d_model, d_k)  # W_q
    P_k = torch.rand(h, d_model, d_k)  # W_k
    P_v = torch.rand(h, d_model, d_v)  # W_v
    P_o = torch.rand(h, d_model, d_v)  # W_o

    # Project the inputs into one query, key and value per head
    Q = torch.einsum("bnd,hdk->bhnk", X, P_q)
    K = torch.einsum("bmd,hdk->bhmk", M, P_k)
    V = torch.einsum("bmd,hdv->bhmv", M, P_v)

    logits = torch.einsum("bhnk,bhmk->bhnm", Q, K)
    weights = torch.softmax(logits + mask, dim=-1)

    O = torch.einsum("bhnm,bhmv->bhnv", weights, V)
    Y = torch.einsum("bhnv,hdv->bnd", O, P_o)
    return Y


def MultiheadSelfAttentionIncremental():
    # Multi-head attention for a single decoding step: one new token attends to the
    # cached keys/values of the previous tokens. The cache holds one K and one V per head.
    d_model, b, h, d_k, d_v = 512, 32, 8, (512 // 8), (512 // 8)

    m = 5  # Suppose we have already cached "m" tokens
    prev_K = torch.rand(b, h, m, d_k)
    prev_V = torch.rand(b, h, m, d_v)

    X = torch.rand(b, d_model)  # Query
    M = torch.rand(b, d_model)  # Key and Value
    P_q = torch.rand(h, d_model, d_k)  # W_q
    P_k = torch.rand(h, d_model, d_k)  # W_k
    P_v = torch.rand(h, d_model, d_v)  # W_v
    P_o = torch.rand(h, d_model, d_v)  # W_o

    q = torch.einsum("bd,hdk->bhk", X, P_q)
    # Append the new token's key and value to the per-head cache
    new_K = torch.cat(
        [prev_K, torch.einsum("bd,hdk->bhk", M, P_k).unsqueeze(2)], dim=2
    )
    new_V = torch.cat(
        [prev_V, torch.einsum("bd,hdv->bhv", M, P_v).unsqueeze(2)], dim=2
    )
    logits = torch.einsum("bhk,bhmk->bhm", q, new_K)
    weights = torch.softmax(logits, dim=-1)
    O = torch.einsum("bhm,bhmv->bhv", weights, new_V)
    y = torch.einsum("bhv,hdv->bd", O, P_o)
    return y, new_K, new_V


def MultiquerySelfAttentionIncremental():
    # Multi-query attention for a single decoding step: the query still has "h" heads,
    # but the key/value projections have no head dimension, so all heads share a single
    # K and V and the cache shrinks by a factor of "h".
    d, b, h, k, v = 512, 32, 8, (512 // 8), (512 // 8)

    m = 5  # Suppose we have already cached "m" tokens
    prev_K = torch.rand(b, m, k)
    prev_V = torch.rand(b, m, v)

    X = torch.rand(b, d)  # Query
    M = torch.rand(b, d)  # Key and Value
    P_q = torch.rand(h, d, k)  # W_q
    P_k = torch.rand(d, k)  # W_k (shared across heads)
    P_v = torch.rand(d, v)  # W_v (shared across heads)
    P_o = torch.rand(h, d, v)  # W_o

    q = torch.einsum("bd,hdk->bhk", X, P_q)
    K = torch.cat([prev_K, torch.einsum("bd,dk->bk", M, P_k).unsqueeze(1)], dim=1)
    V = torch.cat([prev_V, torch.einsum("bd,dv->bv", M, P_v).unsqueeze(1)], dim=1)
    logits = torch.einsum("bhk,bmk->bhm", q, K)
    weights = torch.softmax(logits, dim=-1)
    O = torch.einsum("bhm,bmv->bhv", weights, V)
    y = torch.einsum("bhv,hdv->bd", O, P_o)
    return y, K, V


if __name__ == "__main__":
    MultiheadAttentionBatched()
    MultiheadSelfAttentionIncremental()
    MultiquerySelfAttentionIncremental()

--------------------------------------------------------------------------------
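
The practical difference between the two incremental variants is the size of the KV-cache carried between decoding steps: multi-head attention caches one K and one V per head, with shapes (b, h, m, d_k) and (b, h, m, d_v), while multi-query attention caches a single shared K and V of shapes (b, m, k) and (b, m, v). A minimal sketch of that comparison is shown below; it is not part of the repository and assumes mqa_comparison.py is importable from the working directory.

```python
# Hypothetical check, not part of the repo: compare the cache sizes returned by
# the two incremental attention variants defined in mqa_comparison.py.
from mqa_comparison import (
    MultiheadSelfAttentionIncremental,
    MultiquerySelfAttentionIncremental,
)

_, mha_K, mha_V = MultiheadSelfAttentionIncremental()   # K, V: (b, h, m+1, d_k) / (b, h, m+1, d_v)
_, mqa_K, mqa_V = MultiquerySelfAttentionIncremental()  # K, V: (b, m+1, k) / (b, m+1, v)

mha_cache = mha_K.numel() + mha_V.numel()  # b * h * (m+1) * (d_k + d_v)
mqa_cache = mqa_K.numel() + mqa_V.numel()  # b * (m+1) * (k + v)

# With the dimensions used in the file (h = 8), the multi-query cache is 8x smaller.
print(mha_cache // mqa_cache)  # -> 8
```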