├── .gitignore ├── .gitmodules ├── .vscode └── launch.json ├── Cargo.lock ├── Cargo.toml ├── README.md ├── download_weights.sh ├── src └── main.rs └── tokenizer.bin /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | /weights -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "llama2.c"] 2 | path = llama2.c 3 | url = git@github.com:karpathy/llama2.c.git 4 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Verwendet IntelliSense zum Ermitteln möglicher Attribute. 3 | // Zeigen Sie auf vorhandene Attribute, um die zugehörigen Beschreibungen anzuzeigen. 4 | // Weitere Informationen finden Sie unter https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "type": "lldb", 9 | "request": "launch", 10 | "name": "Debug executable 'llama2-rs'", 11 | "cargo": { 12 | "args": [ 13 | "build", 14 | "--bin=llama2-rs", 15 | "--package=llama2-rs" 16 | ], 17 | "filter": { 18 | "name": "llama2-rs", 19 | "kind": "bin" 20 | } 21 | }, 22 | "args": [], 23 | "cwd": "${workspaceFolder}" 24 | }, 25 | { 26 | "type": "lldb", 27 | "request": "launch", 28 | "name": "Debug unit tests in executable 'llama2-rs'", 29 | "cargo": { 30 | "args": [ 31 | "test", 32 | "--no-run", 33 | "--bin=llama2-rs", 34 | "--package=llama2-rs" 35 | ], 36 | "filter": { 37 | "name": "llama2-rs", 38 | "kind": "bin" 39 | } 40 | }, 41 | "args": [], 42 | "cwd": "${workspaceFolder}" 43 | } 44 | ] 45 | } -------------------------------------------------------------------------------- /Cargo.lock: -------------------------------------------------------------------------------- 1 | # This file is automatically @generated by Cargo. 
2 | # It is not intended for manual editing. 3 | version = 3 4 | 5 | [[package]] 6 | name = "autocfg" 7 | version = "1.1.0" 8 | source = "registry+https://github.com/rust-lang/crates.io-index" 9 | checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" 10 | 11 | [[package]] 12 | name = "byteorder" 13 | version = "1.4.3" 14 | source = "registry+https://github.com/rust-lang/crates.io-index" 15 | checksum = "14c189c53d098945499cdfa7ecc63567cf3886b3332b312a5b4585d8d3a6a610" 16 | 17 | [[package]] 18 | name = "cfg-if" 19 | version = "1.0.0" 20 | source = "registry+https://github.com/rust-lang/crates.io-index" 21 | checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" 22 | 23 | [[package]] 24 | name = "crossbeam-channel" 25 | version = "0.5.8" 26 | source = "registry+https://github.com/rust-lang/crates.io-index" 27 | checksum = "a33c2bf77f2df06183c3aa30d1e96c0695a313d4f9c453cc3762a6db39f99200" 28 | dependencies = [ 29 | "cfg-if", 30 | "crossbeam-utils", 31 | ] 32 | 33 | [[package]] 34 | name = "crossbeam-deque" 35 | version = "0.8.3" 36 | source = "registry+https://github.com/rust-lang/crates.io-index" 37 | checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef" 38 | dependencies = [ 39 | "cfg-if", 40 | "crossbeam-epoch", 41 | "crossbeam-utils", 42 | ] 43 | 44 | [[package]] 45 | name = "crossbeam-epoch" 46 | version = "0.9.15" 47 | source = "registry+https://github.com/rust-lang/crates.io-index" 48 | checksum = "ae211234986c545741a7dc064309f67ee1e5ad243d0e48335adc0484d960bcc7" 49 | dependencies = [ 50 | "autocfg", 51 | "cfg-if", 52 | "crossbeam-utils", 53 | "memoffset", 54 | "scopeguard", 55 | ] 56 | 57 | [[package]] 58 | name = "crossbeam-utils" 59 | version = "0.8.16" 60 | source = "registry+https://github.com/rust-lang/crates.io-index" 61 | checksum = "5a22b2d63d4d1dc0b7f1b6b2747dd0088008a9be28b6ddf0b1e7d335e3037294" 62 | dependencies = [ 63 | "cfg-if", 64 | ] 65 | 66 | [[package]] 67 | name = "either" 
68 | version = "1.9.0" 69 | source = "registry+https://github.com/rust-lang/crates.io-index" 70 | checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" 71 | 72 | [[package]] 73 | name = "getrandom" 74 | version = "0.2.10" 75 | source = "registry+https://github.com/rust-lang/crates.io-index" 76 | checksum = "be4136b2a15dd319360be1c07d9933517ccf0be8f16bf62a3bee4f0d618df427" 77 | dependencies = [ 78 | "cfg-if", 79 | "libc", 80 | "wasi", 81 | ] 82 | 83 | [[package]] 84 | name = "hermit-abi" 85 | version = "0.3.2" 86 | source = "registry+https://github.com/rust-lang/crates.io-index" 87 | checksum = "443144c8cdadd93ebf52ddb4056d257f5b52c04d3c804e657d19eb73fc33668b" 88 | 89 | [[package]] 90 | name = "libc" 91 | version = "0.2.147" 92 | source = "registry+https://github.com/rust-lang/crates.io-index" 93 | checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3" 94 | 95 | [[package]] 96 | name = "llama2-rs" 97 | version = "0.1.0" 98 | dependencies = [ 99 | "byteorder", 100 | "memmap2", 101 | "num_cpus", 102 | "rand", 103 | "rayon", 104 | ] 105 | 106 | [[package]] 107 | name = "memmap2" 108 | version = "0.7.1" 109 | source = "registry+https://github.com/rust-lang/crates.io-index" 110 | checksum = "f49388d20533534cd19360ad3d6a7dadc885944aa802ba3995040c5ec11288c6" 111 | dependencies = [ 112 | "libc", 113 | ] 114 | 115 | [[package]] 116 | name = "memoffset" 117 | version = "0.9.0" 118 | source = "registry+https://github.com/rust-lang/crates.io-index" 119 | checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" 120 | dependencies = [ 121 | "autocfg", 122 | ] 123 | 124 | [[package]] 125 | name = "num_cpus" 126 | version = "1.16.0" 127 | source = "registry+https://github.com/rust-lang/crates.io-index" 128 | checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" 129 | dependencies = [ 130 | "hermit-abi", 131 | "libc", 132 | ] 133 | 134 | [[package]] 135 | name = "ppv-lite86" 136 | version = 
"0.2.17" 137 | source = "registry+https://github.com/rust-lang/crates.io-index" 138 | checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" 139 | 140 | [[package]] 141 | name = "rand" 142 | version = "0.8.5" 143 | source = "registry+https://github.com/rust-lang/crates.io-index" 144 | checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" 145 | dependencies = [ 146 | "libc", 147 | "rand_chacha", 148 | "rand_core", 149 | ] 150 | 151 | [[package]] 152 | name = "rand_chacha" 153 | version = "0.3.1" 154 | source = "registry+https://github.com/rust-lang/crates.io-index" 155 | checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" 156 | dependencies = [ 157 | "ppv-lite86", 158 | "rand_core", 159 | ] 160 | 161 | [[package]] 162 | name = "rand_core" 163 | version = "0.6.4" 164 | source = "registry+https://github.com/rust-lang/crates.io-index" 165 | checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" 166 | dependencies = [ 167 | "getrandom", 168 | ] 169 | 170 | [[package]] 171 | name = "rayon" 172 | version = "1.7.0" 173 | source = "registry+https://github.com/rust-lang/crates.io-index" 174 | checksum = "1d2df5196e37bcc87abebc0053e20787d73847bb33134a69841207dd0a47f03b" 175 | dependencies = [ 176 | "either", 177 | "rayon-core", 178 | ] 179 | 180 | [[package]] 181 | name = "rayon-core" 182 | version = "1.11.0" 183 | source = "registry+https://github.com/rust-lang/crates.io-index" 184 | checksum = "4b8f95bd6966f5c87776639160a66bd8ab9895d9d4ab01ddba9fc60661aebe8d" 185 | dependencies = [ 186 | "crossbeam-channel", 187 | "crossbeam-deque", 188 | "crossbeam-utils", 189 | "num_cpus", 190 | ] 191 | 192 | [[package]] 193 | name = "scopeguard" 194 | version = "1.2.0" 195 | source = "registry+https://github.com/rust-lang/crates.io-index" 196 | checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" 197 | 198 | [[package]] 199 | name = "wasi" 200 | version = 
"0.11.0+wasi-snapshot-preview1" 201 | source = "registry+https://github.com/rust-lang/crates.io-index" 202 | checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" 203 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "llama2-rs" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | [dependencies] 7 | byteorder = "1.4.3" 8 | memmap2 = "0.7.1" 9 | num_cpus = "1.16.0" 10 | rand = "0.8.5" 11 | rayon = "1.7.0" 12 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # llama2-rs 2 | 3 | LLaMA2 port for Rust inspired by `llama2.c`. 4 | 5 | TODOs: 6 | 7 | - [X] Implement loading of the model 8 | - [X] Implement forward pass 9 | - [X] Implement generation 10 | - [X] Implement tokens/sec 11 | - [X] Support prompting and tokenization 12 | - [ ] Command line args 13 | - [X] Parallelize implementation 14 | - [ ] Support Quantization 15 | - [ ] Optimize performance (SIMD/vectorization, fuse loops etc.) 16 | 17 | 18 | Current Performance on my M1 Pro: 19 | 20 | ``` 21 | tokens / seconds = 145.67 22 | 23 | 24 | Once upon a time, there was a little girl named Lily. She loved to play outside in the sunshine. One day, she saw a big, red ball in the sky. It was the sun! She thought it was so pretty. 25 | Lily wanted to play with the ball, but it was too high up in the sky. She tried to jump and reach it, but she couldn't. Then, she had an idea. She would use a stick to knock the ball down. 26 | Lily found a stick and tried to hit the ball. But the stick was too short. She tried again and again, but she couldn't reach it. She felt sad. 27 | Suddenly, a kind man came by and saw Lily. He asked her what was wrong. Lily told him about the ball. The man smiled and said, "I have a useful idea!" 
He took out a long stick and used it to knock the ball down. Lily was so happy! She thanked the man and they played together in the sunshine. 28 | 29 | Once upon a time, there was a little girl named Lily. She loved to play outside in the sunshine. One day, she saw a big,% 30 | ``` -------------------------------------------------------------------------------- /download_weights.sh: -------------------------------------------------------------------------------- 1 | WEIGHT_DIR="weights" 2 | 3 | if [ ! -d "$WEIGHT_DIR" ]; then 4 | mkdir "$WEIGHT_DIR" 5 | fi 6 | cd "$WEIGHT_DIR" 7 | 8 | # Download 15M tinystories model from Andrej Karpathy 9 | wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin 10 | -------------------------------------------------------------------------------- /src/main.rs: -------------------------------------------------------------------------------- 1 | use byteorder::ByteOrder; 2 | use memmap2::Mmap; 3 | use rand::Rng; 4 | use rayon::iter::ParallelIterator; 5 | use rayon::prelude::{IndexedParallelIterator, IntoParallelRefMutIterator}; 6 | use rayon::slice::ParallelSliceMut; 7 | use std::collections::HashMap; 8 | use std::fs::File; 9 | use std::io::{BufReader, Read, Result}; 10 | use std::iter::zip; 11 | use std::time::Instant; 12 | 13 | const BOS_TOKEN: usize = 1; 14 | 15 | type RawConfigI32 = [i32; 7]; 16 | 17 | #[derive(Debug, Default)] 18 | struct Config { 19 | dim: usize, // transformer dimension 20 | hidden_dim: usize, // for ffn layers 21 | n_layers: usize, // number of layers 22 | n_heads: usize, // number of query heads 23 | head_size: usize, // size of each head (dim / n_heads) 24 | n_kv_heads: usize, // number of key/value heads 25 | shared_weights: bool, 26 | vocab_size: usize, // vocabulary size 27 | seq_len: usize, // max. 
sequence length 28 | } 29 | 30 | impl Config { 31 | fn from_file(weights_file_path: &str) -> Result { 32 | // mmap binary weights file 33 | let file = File::open(weights_file_path)?; 34 | let mmap = unsafe { Mmap::map(&file)? }; 35 | 36 | // Read config from weights file 37 | let config_byte_size = std::mem::size_of::(); 38 | if mmap.len() < config_byte_size { 39 | return Err(std::io::Error::new( 40 | std::io::ErrorKind::InvalidData, 41 | "Something is wrong with the weights file. Are you sure you are using the correct file?", 42 | )); 43 | } 44 | 45 | let raw_config = unsafe { 46 | std::mem::transmute::<[u8; std::mem::size_of::()], RawConfigI32>( 47 | mmap[..config_byte_size].try_into().unwrap(), 48 | ) 49 | }; 50 | 51 | Ok(Self { 52 | dim: raw_config[0] as usize, 53 | hidden_dim: raw_config[1] as usize, 54 | n_layers: raw_config[2] as usize, 55 | n_heads: raw_config[3] as usize, 56 | head_size: (raw_config[0] as usize) / (raw_config[3] as usize), 57 | n_kv_heads: raw_config[4] as usize, 58 | shared_weights: raw_config[5] > 0, // weird hack from Andrej 59 | vocab_size: raw_config[5].abs() as usize, 60 | seq_len: raw_config[6] as usize, 61 | }) 62 | } 63 | } 64 | 65 | fn read_n(reader: R, bytes_to_read: usize) -> Result> 66 | where 67 | R: Read, 68 | { 69 | let mut buf = vec![]; 70 | let mut chunk = reader.take(bytes_to_read as u64); 71 | let n = chunk.read_to_end(&mut buf)?; 72 | assert_eq!(bytes_to_read, n); 73 | Ok(buf) 74 | } 75 | 76 | #[derive(Debug, Default)] 77 | struct Tokenizer { 78 | vocab_scores: Vec, 79 | vocab: Vec, 80 | word_to_token_id: HashMap, 81 | max_token_length: usize, 82 | } 83 | 84 | impl Tokenizer { 85 | fn from_file(tokenizer_file_path: &str, vocab_size: usize) -> Result { 86 | let mut vocab = Tokenizer::default(); 87 | vocab.vocab_scores.reserve(vocab_size); 88 | vocab.vocab.reserve(vocab_size); 89 | 90 | let file = File::open(tokenizer_file_path)?; 91 | let mut reader = BufReader::new(file); 92 | 93 | // Read max_token_length 94 | let 
max_token_length_buffer = read_n(&mut reader, std::mem::size_of::())?; 95 | vocab.max_token_length = byteorder::LittleEndian::read_u32(&max_token_length_buffer) as usize; 96 | 97 | for _ in 0..vocab_size { 98 | // Read vocab score 99 | let vocab_score_buffer = read_n(&mut reader, std::mem::size_of::())?; 100 | let score = byteorder::LittleEndian::read_f32(&vocab_score_buffer); 101 | vocab.vocab_scores.push(score); 102 | 103 | // Read length from file stream 104 | let length_buffer = read_n(&mut reader, std::mem::size_of::())?; 105 | let string_length = byteorder::LittleEndian::read_i32(&length_buffer); 106 | 107 | // Read string from file stream 108 | let string_buffer = read_n(&mut reader, string_length as usize)?; 109 | let string = String::from_utf8(string_buffer).expect("could not read word"); 110 | vocab.vocab.push(string); 111 | } 112 | 113 | vocab.word_to_token_id.reserve(vocab_size); 114 | vocab.vocab.iter().enumerate().for_each(|(token_id, word)| { 115 | vocab.word_to_token_id.insert(word.to_string(), token_id); 116 | }); 117 | 118 | Ok(vocab) 119 | } 120 | 121 | fn decode(&self, token_id: usize) -> &str { 122 | &self.vocab[token_id] 123 | } 124 | 125 | fn lookup_word(&self, word: &str) -> Option { 126 | match self.word_to_token_id.get(word) { 127 | Some(token_id) => Some(*token_id), 128 | None => None 129 | } 130 | } 131 | 132 | fn bpe_encode(&self, s: &str) -> Vec { 133 | let mut tokens = Vec::new(); 134 | tokens.reserve(s.len()); 135 | 136 | // encode every individual byte 137 | for i in 0..s.len() { 138 | let token_id = self.lookup_word(&s[i..i+1]).unwrap(); 139 | tokens.push(token_id); 140 | } 141 | 142 | let mut str_buffer = String::with_capacity(2 * self.max_token_length); 143 | 144 | // merge the best consecutive pair each iteration, according the scores in vocab_scores 145 | loop { 146 | let mut best_score = -1e10; 147 | let mut best_token_id = usize::MAX; 148 | let mut best_idx = usize::MAX; 149 | 150 | for i in 0..tokens.len() - 1 { 151 | // 
Copy the two consecutive tokens into a single string 152 | str_buffer.clear(); 153 | str_buffer.push_str(&self.vocab[tokens[i]]); 154 | str_buffer.push_str(&self.vocab[tokens[i + 1]]); 155 | 156 | if let Some(token_id) = self.lookup_word(&str_buffer) { 157 | if self.vocab_scores[token_id] > best_score { 158 | best_score = self.vocab_scores[token_id]; 159 | best_token_id = token_id; 160 | best_idx = i; 161 | } 162 | } 163 | } 164 | 165 | if best_idx == usize::MAX { 166 | break; 167 | } 168 | 169 | // Merge the best pair and delete the second token 170 | tokens[best_idx] = best_token_id; 171 | tokens.remove(best_idx + 1); 172 | } 173 | 174 | tokens 175 | } 176 | } 177 | 178 | #[derive(Debug, Default)] 179 | struct TransformerWeights { 180 | // Token Embedding Table 181 | token_embedding_table: Vec, // (vocab_size, dim) 182 | // Weights for RMSNorm 183 | rms_att_weight: Vec, // (layer, dim) 184 | rms_ffn_weight: Vec, // (layer, dim) 185 | // Weights for matmuls in attn 186 | wq: Vec, // (layer, dim, dim) 187 | wk: Vec, // (layer, dim, dim) 188 | wv: Vec, // (layer, dim, dim) 189 | wo: Vec, // (layer, dim, dim) 190 | // Weights for ffn 191 | w1: Vec, // (layer, hidden_dim, dim) 192 | w2: Vec, // (layer, dim, hidden_dim) 193 | w3: Vec, // (layer, hidden_dim, dim) 194 | // final RMSNorm 195 | rms_final_weights: Vec, // (dim) 196 | // freq_cis for RoPE relatively positional embeddings 197 | freq_cis_real: Vec, // (seq_len, head_size/2) 198 | freq_cis_imag: Vec, // (seq_len, head_size/2) 199 | // (optional) classifier weights for the logits, on the last layer 200 | wcls: Vec, // (vocab_size, dim) 201 | } 202 | 203 | fn byte_chunk_to_vec(byte_chunk: &[u8], number_elements: usize) -> Vec 204 | where 205 | T: Clone, 206 | { 207 | unsafe { 208 | let data = byte_chunk.as_ptr() as *const T; 209 | let slice_data = std::slice::from_raw_parts(data, number_elements); 210 | slice_data.to_vec() 211 | } 212 | } 213 | 214 | impl TransformerWeights { 215 | fn from_file(weights_file_path: 
&str, config: &Config) -> Result { 216 | // mmap binary weights file 217 | let file = File::open(weights_file_path)?; 218 | let mmap = unsafe { Mmap::map(&file)? }; 219 | 220 | let mut offset = std::mem::size_of::(); 221 | 222 | // Read the weights 223 | let token_embedding_table_size = config.vocab_size * config.dim; 224 | let token_embedding_table: Vec = 225 | byte_chunk_to_vec(&mmap[offset..], token_embedding_table_size); 226 | offset += token_embedding_table_size * std::mem::size_of::(); 227 | 228 | // Read the RMSNorm weights for attention 229 | let rms_att_weight_size = config.n_layers * config.dim; 230 | let rms_att_weight: Vec = byte_chunk_to_vec(&mmap[offset..], rms_att_weight_size); 231 | offset += rms_att_weight_size * std::mem::size_of::(); 232 | 233 | // Read the attention weights 234 | let wq_size = config.n_layers * config.dim * config.dim; 235 | let wq: Vec = byte_chunk_to_vec(&mmap[offset..], wq_size); 236 | offset += wq_size * std::mem::size_of::(); 237 | 238 | let wk_size = config.n_layers * config.dim * config.dim; 239 | let wk: Vec = byte_chunk_to_vec(&mmap[offset..], wk_size); 240 | offset += wk_size * std::mem::size_of::(); 241 | 242 | let wv_size = config.n_layers * config.dim * config.dim; 243 | let wv: Vec = byte_chunk_to_vec(&mmap[offset..], wv_size); 244 | offset += wv_size * std::mem::size_of::(); 245 | 246 | let wo_size = config.n_layers * config.dim * config.dim; 247 | let wo: Vec = byte_chunk_to_vec(&mmap[offset..], wo_size); 248 | offset += wo_size * std::mem::size_of::(); 249 | 250 | // Read the RMSNorm weights for ffn 251 | let rms_ffn_weight_size = config.n_layers * config.dim; 252 | let rms_ffn_weight: Vec = byte_chunk_to_vec(&mmap[offset..], rms_ffn_weight_size); 253 | offset += rms_ffn_weight_size * std::mem::size_of::(); 254 | 255 | // Read the ffn weights 256 | let w1_size = config.n_layers * config.hidden_dim * config.dim; 257 | let w1: Vec = byte_chunk_to_vec(&mmap[offset..], w1_size); 258 | offset += w1_size * 
std::mem::size_of::(); 259 | 260 | let w2_size = config.n_layers * config.dim * config.hidden_dim; 261 | let w2: Vec = byte_chunk_to_vec(&mmap[offset..], w2_size); 262 | offset += w2_size * std::mem::size_of::(); 263 | 264 | let w3_size = config.n_layers * config.hidden_dim * config.dim; 265 | let w3: Vec = byte_chunk_to_vec(&mmap[offset..], w3_size); 266 | offset += w3_size * std::mem::size_of::(); 267 | 268 | // Read the final RMSNorm weights 269 | let rms_final_weights_size = config.dim; 270 | let rms_final_weights: Vec = 271 | byte_chunk_to_vec(&mmap[offset..], rms_final_weights_size); 272 | offset += rms_final_weights_size * std::mem::size_of::(); 273 | 274 | // Read the freq_cis for RoPE relatively positional embeddings 275 | let freq_cis_real_size = config.seq_len * config.head_size / 2; 276 | let freq_cis_real: Vec = byte_chunk_to_vec(&mmap[offset..], freq_cis_real_size); 277 | offset += freq_cis_real_size * std::mem::size_of::(); 278 | 279 | let freq_cis_imag_size = config.seq_len * config.head_size / 2; 280 | let freq_cis_imag: Vec = byte_chunk_to_vec(&mmap[offset..], freq_cis_imag_size); 281 | offset += freq_cis_imag_size * std::mem::size_of::(); 282 | 283 | // Read the classifier weights 284 | let wcls_size = config.vocab_size * config.dim; 285 | let wcls: Vec = if config.shared_weights { 286 | token_embedding_table.clone() 287 | } else { 288 | byte_chunk_to_vec(&mmap[offset..], wcls_size) 289 | }; 290 | 291 | Ok(TransformerWeights { 292 | token_embedding_table, 293 | rms_att_weight, 294 | rms_ffn_weight, 295 | wq, 296 | wk, 297 | wv, 298 | wo, 299 | w1, 300 | w2, 301 | w3, 302 | rms_final_weights, 303 | freq_cis_real, 304 | freq_cis_imag, 305 | wcls, 306 | }) 307 | } 308 | 309 | // Note: does not include the token embedding table 310 | fn num_parameters(&self) -> usize { 311 | let mut n = 0; 312 | n += self.rms_att_weight.len(); 313 | n += self.wq.len(); 314 | n += self.wk.len(); 315 | n += self.wv.len(); 316 | n += self.wo.len(); 317 | n += 
self.rms_ffn_weight.len(); 318 | n += self.w1.len(); 319 | n += self.w2.len(); 320 | n += self.w3.len(); 321 | n += self.rms_final_weights.len(); 322 | n += self.freq_cis_real.len(); 323 | n += self.freq_cis_imag.len(); 324 | n += self.wcls.len(); 325 | n 326 | } 327 | 328 | fn memory_usage_in_bytes(&self) -> usize { 329 | (self.num_parameters() + self.token_embedding_table.len()) * std::mem::size_of::() 330 | } 331 | } 332 | 333 | // F.silu; silu(x)=x*σ(x),where σ(x) is the logistic sigmoid 334 | fn silu(x: f32) -> f32 { 335 | x / (1.0 + (-x).exp()) 336 | } 337 | 338 | fn add_vectors(target: &mut [f32], source: &[f32]) { 339 | target 340 | .iter_mut() 341 | .zip(source.iter()) 342 | .for_each(|(t, s)| *t += s); 343 | } 344 | 345 | fn rmsnorm(x: &mut [f32], weight: &[f32]) { 346 | let size = x.len(); 347 | 348 | let squared_sum = x.iter().fold(0.0, |acc, x| acc + x * x); 349 | let rms = 1. / (squared_sum / size as f32 + 1e-5).sqrt(); // 1e-5 epsilon (as in llama2.c) avoids div-by-zero on all-zero input 350 | 351 | x.iter_mut() 352 | .zip(weight.iter()) 353 | .for_each(|(x, w)| *x *= rms * w); 354 | } 355 | 356 | fn rmsnorm_with_dest(dest: &mut [f32], x: &[f32], weight: &[f32]) { 357 | let size = x.len(); 358 | 359 | let squared_sum = x.iter().fold(0.0, |acc, x| acc + x * x); 360 | let rms = 1. / (squared_sum / size as f32 + 1e-5).sqrt(); // 1e-5 epsilon (as in llama2.c) avoids div-by-zero on all-zero input 361 | 362 | dest.iter_mut() 363 | .zip(x.iter()) 364 | .zip(weight.iter()) 365 | .for_each(|((d, x), w)| { 366 | *d = x * rms * w; 367 | }); 368 | } 369 | 370 | fn softmax(logits: &mut [f32]) { 371 | let n = logits.len(); 372 | 373 | // Find max. 
for fixing stability 374 | let mut max_logit = logits[0]; 375 | for i in 1..n { 376 | max_logit = max_logit.max(logits[i]); 377 | } 378 | 379 | // Exponentiate and sum logits 380 | let mut sum = 0.0; 381 | for i in 0..n { 382 | logits[i] = (logits[i] - max_logit).exp(); 383 | sum += logits[i]; 384 | } 385 | 386 | // Normalize 387 | for i in 0..n { 388 | logits[i] /= sum; 389 | } 390 | } 391 | 392 | // (out_dim, in_dim) @ (d,) -> (out_dim,) 393 | // w @ x -> target 394 | fn matmul(target: &mut [f32], w: &[f32], x: &[f32]) { 395 | let in_dim = x.len(); 396 | target.par_iter_mut().enumerate().for_each(|(i, t)| { 397 | let row_offset = i * in_dim; 398 | *t = x 399 | .iter() 400 | .zip(w[row_offset..].iter()) 401 | .fold(0.0, |result, (x, w)| result + x * w); 402 | }); 403 | } 404 | 405 | fn inner_product(x: &[f32], y: &[f32]) -> f32 { 406 | zip(x, y).fold(0.0, |acc, (a, b)| acc + a * b) 407 | } 408 | 409 | fn argmax(x: &[f32]) -> usize { 410 | let mut max = std::f32::MIN; 411 | let mut argmax = 0; 412 | for (i, v) in x.iter().enumerate() { 413 | if *v > max { 414 | max = *v; 415 | argmax = i; 416 | } 417 | } 418 | argmax 419 | } 420 | 421 | fn sample(probs: &[f32]) -> usize { 422 | let mut rng = rand::thread_rng(); 423 | let mut cdf = 0.0; 424 | let r = rng.gen_range(0.0..1.0); 425 | for (i, p) in probs.iter().enumerate() { 426 | cdf += p; 427 | if cdf > r { 428 | return i; 429 | } 430 | } 431 | probs.len() - 1 432 | } 433 | 434 | #[derive(Debug)] 435 | struct LLaMA2<'a> { 436 | // buffers for current activations 437 | x: Vec, // activation at current timestep (dim,) 438 | xb: Vec, // same, but inside a residual branch (dim,) 439 | xb2: Vec, // additional buffer (dim,) 440 | hb: Vec, // buffer for hidden dimension in the ffn (hidden_dim,) 441 | hb2: Vec, // buffer for hidden dimension in the ffn (hidden_dim,) 442 | q: Vec, // query (dim,) 443 | k: Vec, // key (dim,) 444 | v: Vec, // value (dim,) 445 | att: Vec, // attention scores (n_heads, seq_len) 446 | logits: Vec, 
// output logits (vocab_size,) 447 | // kv cache 448 | key_cache: Vec, // (layer, seq_len, dim) 449 | value_cache: Vec, // (layer, seq_len, dim) 450 | // weights & config 451 | transformer: &'a TransformerWeights, 452 | config: &'a Config, 453 | } 454 | 455 | impl<'a> LLaMA2<'a> { 456 | fn new(transformer: &'a TransformerWeights, config: &'a Config) -> LLaMA2<'a> { 457 | Self { 458 | x: vec![0.0; config.dim], 459 | xb: vec![0.0; config.dim], 460 | xb2: vec![0.0; config.dim], 461 | hb: vec![0.0; config.hidden_dim], 462 | hb2: vec![0.0; config.hidden_dim], 463 | q: vec![0.0; config.dim], 464 | k: vec![0.0; config.dim], 465 | v: vec![0.0; config.dim], 466 | att: vec![0.0; config.n_heads * config.seq_len], 467 | logits: vec![0.0; config.vocab_size], 468 | key_cache: vec![0.0; config.n_layers * config.seq_len * config.dim], // n_layers (not n_kv_heads): cache_kv() indexes by layer * seq_len * dim 469 | value_cache: vec![0.0; config.n_layers * config.seq_len * config.dim], 470 | transformer, 471 | config, 472 | } 473 | } 474 | 475 | // PyTorch: xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) 476 | fn attn_qkv_matmuls(&mut self, layer: usize) { 477 | let weight_from = layer * self.config.dim * self.config.dim; 478 | let weight_to = (layer + 1) * self.config.dim * self.config.dim; 479 | 480 | matmul( 481 | self.q.as_mut_slice(), // out: (dim,) 482 | &self.transformer.wq[weight_from..weight_to], // W: (dim, dim) 483 | self.xb.as_slice(), // x: (dim,) 484 | ); 485 | 486 | matmul( 487 | self.k.as_mut_slice(), // out: (dim,) 488 | &self.transformer.wk[weight_from..weight_to], // W: (dim, dim) 489 | self.xb.as_slice(), // x: (dim,) 490 | ); 491 | 492 | matmul( 493 | self.v.as_mut_slice(), // out: (dim,) 494 | &self.transformer.wv[weight_from..weight_to], // W: (dim, dim) 495 | self.xb.as_slice(), // x: (dim,) 496 | ); 497 | } 498 | 499 | fn attn_rope(&mut self, layer: usize, pos: usize) { 500 | // apply RoPE rotation to the q and k vectors for each head 501 | 502 | let freq_cis_real_offset = pos * self.config.head_size / 2; 503 | let 
freq_cis_imag_offset = pos * self.config.head_size / 2; 504 | 505 | // rotate q and k by the freq_cis_real and freq_cis_imag 506 | // For more information checkout the Roformer paper, 507 | // section 3.4.2: https://arxiv.org/pdf/2104.09864.pdf 508 | for i in (0..self.config.dim).step_by(2) { 509 | let q0 = self.q[i]; 510 | let q1 = self.q[i + 1]; 511 | 512 | let k0 = self.k[i]; 513 | let k1 = self.k[i + 1]; 514 | 515 | let cos = self.transformer.freq_cis_real 516 | [freq_cis_real_offset + (i % self.config.head_size) / 2]; 517 | let sin = self.transformer.freq_cis_imag 518 | [freq_cis_imag_offset + (i % self.config.head_size) / 2]; 519 | 520 | self.q[i] = q0 * cos - q1 * sin; 521 | self.q[i + 1] = q1 * cos + q0 * sin; 522 | 523 | self.k[i] = k0 * cos - k1 * sin; 524 | self.k[i + 1] = k1 * cos + k0 * sin; 525 | } 526 | } 527 | 528 | fn cache_kv(&mut self, layer: usize, pos: usize) { 529 | // cache the key, value for the current timestep (pos) 530 | let layer_offset = layer * self.config.seq_len * self.config.dim; // offset to get to the cache of the current layer 531 | let cache_from = layer_offset + pos * self.config.dim; 532 | let cache_to = layer_offset + (pos + 1) * self.config.dim; 533 | 534 | self.key_cache[cache_from..cache_to].copy_from_slice(&self.k.as_slice()); 535 | self.value_cache[cache_from..cache_to].copy_from_slice(&self.v.as_slice()); 536 | } 537 | 538 | fn multihead_attn(&mut self, layer: usize, pos: usize) { 539 | let layer_offset_for_cache = layer * self.config.seq_len * self.config.dim; // offset to get to the cache of the current layer 540 | 541 | let sqrt_d = (self.config.head_size as f32).sqrt(); 542 | 543 | self.att 544 | .par_chunks_exact_mut(self.config.seq_len) 545 | .zip(self.xb.par_chunks_exact_mut(self.config.head_size)) 546 | .enumerate() 547 | .for_each(|(h, (attn_scores, xb))| { 548 | assert_eq!(attn_scores.len(), self.config.seq_len); 549 | assert_eq!(xb.len(), self.config.head_size); 550 | 551 | // get query vector of the timestep 
pos for the current head 552 | let q_from = h * self.config.head_size; 553 | let q_to = (h + 1) * self.config.head_size; 554 | let q = &self.q[q_from..q_to]; 555 | 556 | // Compute temp = (K * q_pos) / sqrt(dim) 557 | for t in 0..=pos { 558 | let timestep_and_layer_offset = layer_offset_for_cache + t * self.config.dim; // key_cache[l, t] 559 | // for the current key, we need to select the correct range which corresponds to the current head 560 | let key_vector_from = timestep_and_layer_offset + h * self.config.head_size; 561 | let key_vector_to = timestep_and_layer_offset + (h + 1) * self.config.head_size; 562 | let key_vector = &self.key_cache[key_vector_from..key_vector_to]; 563 | 564 | attn_scores[t] = inner_product(q, key_vector) / sqrt_d; 565 | } 566 | 567 | // softmax the scores to get attention weights, from 0..pos inclusively 568 | // Compute temp2 = softmax(temp) 569 | softmax(&mut attn_scores[..(pos + 1)]); 570 | 571 | // Compute temp2^T * V 572 | xb.fill(0.0); 573 | 574 | for t in 0..=pos { 575 | let timestep_and_layer_offset = layer_offset_for_cache + t * self.config.dim; // value_cache[l, t] 576 | // for the current value, we need to select the correct range which corresponds to the current head 577 | let value_vector_from = timestep_and_layer_offset + h * self.config.head_size; 578 | let value_vector_to = 579 | timestep_and_layer_offset + (h + 1) * self.config.head_size; 580 | let value_vector = &self.value_cache[value_vector_from..value_vector_to]; 581 | 582 | // weighted sum with attention scores as weights 583 | let attention_weight = attn_scores[t]; 584 | for i in 0..self.config.head_size { 585 | xb[i] += attention_weight * value_vector[i]; 586 | } 587 | } 588 | }); 589 | } 590 | 591 | // multi-head attention with RoPE 592 | fn attn(&mut self, layer: usize, pos: usize) { 593 | // qkv matmuls 594 | // PyTorch: xq, xk, xv = self.wq(x), self.wk(x), self.wv(x) 595 | self.attn_qkv_matmuls(layer); 596 | 597 | // apply RoPE rotation to the q and k 
vectors for each head
        self.attn_rope(layer, pos);

        // Multi-head attention with caching
        // Idea:
        //
        // Let the current sequence length until the current timestep pos be n.
        // The idea is to only compute the attention score for the token at timestep pos.
        // Therefore, we compute for each head:
        //
        // attn_pos = softmax((K * q_pos) / sqrt(dim))^T * V
        //
        // where
        // - attn_pos: attention score for timestep pos. dim(head_size,)
        // - q_pos: query vector for timestep pos. dim(head_size,)
        // - K/V: key/value vectors for all timesteps up to n. dim(n,head_size)
        // ==> this is also the reason why we need the caching, to store all the previous key/value vectors
        self.cache_kv(layer, pos);
        self.multihead_attn(layer, pos);

        // Project the concatenated head outputs back to the model dimension.
        // The result is written to xb2 so the caller can add the residual from x.
        // PyTorch: x = self.wo(x)
        let weight_from = layer * self.config.dim * self.config.dim;
        let weight_to = (layer + 1) * self.config.dim * self.config.dim;
        matmul(
            self.xb2.as_mut_slice(),                      // out: (dim,)
            &self.transformer.wo[weight_from..weight_to], // W: (dim, dim)
            self.xb.as_slice(),                           // x: (dim,)
        );
    }

    /// SwiGLU feed-forward network for one layer. Reads the normalized
    /// activations from `xb` and writes the FFN output back into `xb`.
    ///
    /// PyTorch: self.w2(F.silu(self.w1(x)) * self.w3(x))
    fn ffn(&mut self, layer: usize) {
        // w1/w2/w3 are stored flattened over all layers; these offsets select
        // this layer's (hidden_dim * dim)-sized weight slice.
        let weight_from = layer * self.config.hidden_dim * self.config.dim;
        let weight_to = (layer + 1) * self.config.hidden_dim * self.config.dim;

        // PyTorch: self.w1(x)
        matmul(
            self.hb.as_mut_slice(),                       // out: (hidden_dim,)
            &self.transformer.w1[weight_from..weight_to], // W: (hidden_dim, dim)
            self.xb.as_slice(),                           // x: (dim,)
        );

        // PyTorch: self.w3(x)
        matmul(
            self.hb2.as_mut_slice(),                      // out: (hidden_dim,)
            &self.transformer.w3[weight_from..weight_to], // W: (hidden_dim, dim)
            self.xb.as_slice(),                           // x: (dim,)
        );

        // PyTorch: x = F.silu(self.w1(x)) * self.w3(x)
        // Note: Fused the activation and elementwise multiplication loop
        for i in 0..self.config.hidden_dim {
            self.hb[i] = silu(self.hb[i]) * self.hb2[i];
        }

        // Project back down from hidden_dim to the model dimension.
        // The same weight_from/weight_to offsets apply since
        // dim * hidden_dim == hidden_dim * dim.
        // PyTorch: self.w2(x)
        matmul(
            self.xb.as_mut_slice(),                       // out: (dim,)
            &self.transformer.w2[weight_from..weight_to], // W: (dim, hidden_dim)
            self.hb.as_slice(),                           // x: (hidden_dim,)
        );
    }

    /// Runs one full transformer layer on the activations in `x`:
    /// RMSNorm -> attention -> residual add, then RMSNorm -> FFN -> residual add.
    fn layer(&mut self, layer: usize, pos: usize) {
        // PyTorch: h = x + self.attention.forward(self.attention_norm(x), start_pos, freqs_cis, mask)
        // Note: we leave the buffer x as it is because we need it for the residual connection
        rmsnorm_with_dest(
            self.xb.as_mut_slice(),
            self.x.as_slice(),
            &self.transformer.rms_att_weight
                [layer * self.config.dim..(layer + 1) * self.config.dim],
        );
        self.attn(layer, pos);
        // residual connection (attn left its output in xb2)
        add_vectors(self.x.as_mut_slice(), self.xb2.as_slice());

        // PyTorch: out = h + self.feed_forward.forward(self.ffn_norm(h))
        // Note: we leave the buffer x as it is because we need it for the residual connection
        rmsnorm_with_dest(
            self.xb.as_mut_slice(),
            self.x.as_slice(),
            &self.transformer.rms_ffn_weight
                [layer * self.config.dim..(layer + 1) * self.config.dim],
        );
        self.ffn(layer);
        // residual connection (ffn left its output in xb)
        add_vectors(self.x.as_mut_slice(), self.xb.as_slice());
    }

    /// Forward pass for a single token at sequence position `pos`.
    /// On return, `self.logits` holds the unnormalized next-token scores
    /// (vocab_size,).
    fn forward(&mut self, token: usize, pos: usize) {
        // fetch the token embedding
        // PyTorch: h = self.tok_embeddings(tokens)
        self.x.copy_from_slice(
            &self.transformer.token_embedding_table
                [(token * self.config.dim)..((token + 1) * self.config.dim)],
        );

        // Note: here it always holds that seqlen == 1 in comparison to the PyTorch implementation

        // forward through the layers
        // PyTorch:
        // for layer in self.layers:
        //     h = layer(h, start_pos, freqs_cis, mask)
        for l in 0..self.config.n_layers {
            self.layer(l, pos);
        }

        // final RMSNorm
        // PyTorch: h = self.norm(h)
        rmsnorm(
            self.x.as_mut_slice(),
            self.transformer.rms_final_weights.as_slice(),
        );

        // generate logits, i.e., map activations from dim to vocab_size
        // PyTorch: output = self.output(h).float()
        matmul(
            self.logits.as_mut_slice(),       // out: (vocab_size,)
            self.transformer.wcls.as_slice(), // W: (vocab_size, dim)
            self.x.as_slice(),                // x: (dim,)
        );
    }

    /// Autoregressively generates a sequence of `n_tokens` token ids
    /// (including the leading BOS token and the prompt tokens).
    ///
    /// The prompt is first forced through the model to populate the KV-cache;
    /// then the remaining positions are sampled. With `temperature == 0.0`
    /// decoding is greedy (argmax over the logits); otherwise the logits are
    /// scaled by 1/temperature, softmaxed, and sampled.
    fn generate(&mut self, prompt_tokens: &Vec<usize>, n_tokens: usize, temperature: f32) -> Vec<usize> {
        let mut tokens = vec![];
        tokens.reserve(n_tokens);

        // Generation always starts from the BOS token.
        let mut token = BOS_TOKEN;
        tokens.push(token);

        // forward through the prompt to fill up the KV-cache!
        // Note: logits produced here are discarded; the prompt token is forced.
        for (pos, prompt_token) in prompt_tokens.iter().enumerate() {
            self.forward(token, pos);
            token = *prompt_token;
            tokens.push(token);
        }

        // complete the prompt
        for pos in prompt_tokens.len()..(n_tokens - 1) {
            self.forward(token, pos);

            if temperature == 0.0 {
                // greedy decoding
                token = argmax(self.logits.as_slice());
            } else {
                // Apply temperature and then sample.
                // If temperature < 1.0 then the distribution becomes more peaked ==> lower variance in sampling
                // If temperature > 1.0 then the distribution becomes more flat ==> higher variance in sampling
                self.logits.iter_mut().for_each(|p| *p = *p / temperature);
                softmax(&mut self.logits.as_mut_slice());
                token = sample(self.logits.as_slice());
            }

            tokens.push(token);
        }

        tokens
    }

    /// Approximate heap memory footprint in bytes: all runtime activation
    /// buffers and KV-caches plus the transformer weights.
    /// Uses `capacity()` (allocated, not just used) for each buffer.
    fn memory_usage_in_bytes(&self) -> usize {
        let mut memory_usage = 0;

        memory_usage += self.x.capacity() * std::mem::size_of::<f32>();
        memory_usage += self.xb.capacity() * std::mem::size_of::<f32>();
        memory_usage += self.xb2.capacity() * std::mem::size_of::<f32>();
        memory_usage += self.hb.capacity() * std::mem::size_of::<f32>();
        memory_usage += self.hb2.capacity() * std::mem::size_of::<f32>();
        memory_usage += self.q.capacity() * std::mem::size_of::<f32>();
        memory_usage += self.k.capacity() * std::mem::size_of::<f32>();
        memory_usage += self.v.capacity() * std::mem::size_of::<f32>();
        memory_usage += self.att.capacity() * std::mem::size_of::<f32>();
        memory_usage += self.logits.capacity() * std::mem::size_of::<f32>();
        memory_usage += self.key_cache.capacity() * std::mem::size_of::<f32>();
        memory_usage += self.value_cache.capacity() * std::mem::size_of::<f32>();

        memory_usage += self.transformer.memory_usage_in_bytes();

        memory_usage
    }
}

/// Entry point: loads the model config, tokenizer, and weights, configures
/// the rayon thread pool, and generates a completion for a hard-coded prompt.
fn main() -> Result<()> {
    // Hard-coded run configuration (model path, prompt, decoding parameters).
    let file_path = "weights/stories15M.bin";
    let prompt = "One day, Lily met a bear";
    let temperature = 0.0;
    let steps = 256;

    // Setup

    println!("Loading config...");
    let config = Config::from_file(file_path)?;
    println!("Loaded config: {:?}", config);

    println!("Loading vocab...");
    let tokenizer = Tokenizer::from_file("tokenizer.bin", config.vocab_size as usize)?;

    println!("Loading weights...");
    let transformer_weights = TransformerWeights::from_file(file_path, &config)?;

    println!("Done.");

    println!(
        "Number of parameters: {}",
        transformer_weights.num_parameters()
    );

    // Configure rayon
    // Cap the thread count at n_heads: the attention heads are presumably the
    // unit of parallelism, so extra threads beyond n_heads would sit idle.
    let cpus = num_cpus::get();
    let active_cpus = (cpus).max(1).min(config.n_heads);
    println!("Using {} threads", active_cpus);
    rayon::ThreadPoolBuilder::new()
        .num_threads(active_cpus)
        .build_global()
        .unwrap();

    // Inference

    let start = Instant::now();
    let mut llama2 = LLaMA2::new(&transformer_weights, &config);

    let llama_memory_mib = llama2.memory_usage_in_bytes() as f32 / ((1 as usize) << 20) as f32;
    println!("Memory usage in MiB: {llama_memory_mib}");

    let prompt_tokens = tokenizer.bpe_encode(&prompt);
    let generated_tokens = llama2.generate(&prompt_tokens, steps, temperature);

    let time_elapsed = start.elapsed().as_secs_f32();
    let tokens_per_sec = (steps as f32) / time_elapsed;
    println!("tokens / seconds = {:.2?}", tokens_per_sec);

    // Decode and print the generated tokens, stripping the sentencepiece
    // leading space after BOS.
    // NOTE(review): this condition tests whether the *current* token is BOS
    // (token == 1). In llama2.c the leading space is stripped when the
    // *previous* token was BOS — confirm whether this should track the
    // preceding token instead; as written, pieces following BOS keep their
    // leading space.
    print!("{}", prompt);
    for token in generated_tokens {
        if token == 1 && tokenizer.decode(token).starts_with(' ') {
            print!("{}", &tokenizer.decode(token)[1..]);
        } else {
            print!("{}", tokenizer.decode(token));
        };
    }

    Ok(())
}