├── .gitignore ├── Cargo.lock ├── Cargo.toml ├── README.md ├── build.rs ├── src ├── bindings.rs ├── lib.rs └── main.rs └── vendor └── llama.cpp /.gitignore: -------------------------------------------------------------------------------- 1 | /target 2 | -------------------------------------------------------------------------------- /Cargo.lock: -------------------------------------------------------------------------------- 1 | # This file is automatically @generated by Cargo. 2 | # It is not intended for manual editing. 3 | version = 3 4 | 5 | [[package]] 6 | name = "aho-corasick" 7 | version = "1.0.4" 8 | source = "registry+https://github.com/rust-lang/crates.io-index" 9 | checksum = "6748e8def348ed4d14996fa801f4122cd763fff530258cdc03f64b25f89d3a5a" 10 | dependencies = [ 11 | "memchr", 12 | ] 13 | 14 | [[package]] 15 | name = "anyhow" 16 | version = "1.0.75" 17 | source = "registry+https://github.com/rust-lang/crates.io-index" 18 | checksum = "a4668cab20f66d8d020e1fbc0ebe47217433c1b6c8f2040faf858554e394ace6" 19 | 20 | [[package]] 21 | name = "bindgen" 22 | version = "0.66.1" 23 | source = "registry+https://github.com/rust-lang/crates.io-index" 24 | checksum = "f2b84e06fc203107bfbad243f4aba2af864eb7db3b1cf46ea0a023b0b433d2a7" 25 | dependencies = [ 26 | "bitflags", 27 | "cexpr", 28 | "clang-sys", 29 | "lazy_static", 30 | "lazycell", 31 | "log", 32 | "peeking_take_while", 33 | "prettyplease", 34 | "proc-macro2", 35 | "quote", 36 | "regex", 37 | "rustc-hash", 38 | "shlex", 39 | "syn", 40 | "which", 41 | ] 42 | 43 | [[package]] 44 | name = "bitflags" 45 | version = "2.4.0" 46 | source = "registry+https://github.com/rust-lang/crates.io-index" 47 | checksum = "b4682ae6287fcf752ecaabbfcc7b6f9b72aa33933dc23a554d853aea8eea8635" 48 | 49 | [[package]] 50 | name = "cc" 51 | version = "1.0.83" 52 | source = "registry+https://github.com/rust-lang/crates.io-index" 53 | checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" 54 | dependencies = [ 55 | "libc", 56 | ] 57 | 58 | [[package]] 59 | name = "cexpr" 60 | version = "0.6.0" 61 | source = "registry+https://github.com/rust-lang/crates.io-index" 62 | checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" 63 | dependencies = [ 64 | "nom", 65 | ] 66 | 67 | [[package]] 68 | name = "cfg-if" 69 | version = "1.0.0" 70 | source = "registry+https://github.com/rust-lang/crates.io-index" 71 | checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" 72 | 73 | [[package]] 74 | name = "clang-sys" 75 | version = "1.6.1" 76 | source = "registry+https://github.com/rust-lang/crates.io-index" 77 | checksum = "c688fc74432808e3eb684cae8830a86be1d66a2bd58e1f248ed0960a590baf6f" 78 | dependencies = [ 79 | "glob", 80 | "libc", 81 | "libloading", 82 | ] 83 | 84 | [[package]] 85 | name = "either" 86 | version = "1.9.0" 87 | source = "registry+https://github.com/rust-lang/crates.io-index" 88 | checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" 89 | 90 | [[package]] 91 | name = "glob" 92 | version = "0.3.1" 93 | source = "registry+https://github.com/rust-lang/crates.io-index" 94 | checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" 95 | 96 | [[package]] 97 | name = "lazy_static" 98 | version = "1.4.0" 99 | source = "registry+https://github.com/rust-lang/crates.io-index" 100 | checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" 101 | 102 | [[package]] 103 | name = "lazycell" 104 | version = "1.3.0" 105 | source = "registry+https://github.com/rust-lang/crates.io-index" 106 | checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" 107 | 108 | [[package]] 109 | name = "libc" 110 | version = "0.2.147" 111 | source = "registry+https://github.com/rust-lang/crates.io-index" 112 | checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3" 113 | 114 | [[package]] 115 | name = "libloading" 116 | version = "0.7.4" 117 | source = "registry+https://github.com/rust-lang/crates.io-index" 118 | checksum = "b67380fd3b2fbe7527a606e18729d21c6f3951633d0500574c4dc22d2d638b9f" 119 | dependencies = [ 120 | "cfg-if", 121 | "winapi", 122 | ] 123 | 124 | [[package]] 125 | name = "llama-rs" 126 | version = "0.1.0" 127 | dependencies = [ 128 | "anyhow", 129 | "bindgen", 130 | "cc", 131 | ] 132 | 133 | [[package]] 134 | name = "log" 135 | version = "0.4.20" 136 | source = "registry+https://github.com/rust-lang/crates.io-index" 137 | checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" 138 | 139 | [[package]] 140 | name = "memchr" 141 | version = "2.5.0" 142 | source = "registry+https://github.com/rust-lang/crates.io-index" 143 | checksum = "2dffe52ecf27772e601905b7522cb4ef790d2cc203488bbd0e2fe85fcb74566d" 144 | 145 | [[package]] 146 | name = "minimal-lexical" 147 | version = "0.2.1" 148 | source = "registry+https://github.com/rust-lang/crates.io-index" 149 | checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" 150 | 151 | [[package]] 152 | name = "nom" 153 | version = "7.1.3" 154 | source = "registry+https://github.com/rust-lang/crates.io-index" 155 | checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" 156 | dependencies = [ 157 | "memchr", 158 | "minimal-lexical", 159 | ] 160 | 161 | [[package]] 162 | name = "once_cell" 163 | version = "1.18.0" 164 | source = "registry+https://github.com/rust-lang/crates.io-index" 165 | checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" 166 | 167 | [[package]] 168 | name = "peeking_take_while" 169 | version = "0.1.2" 170 | source = "registry+https://github.com/rust-lang/crates.io-index" 171 | checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" 172 | 173 | [[package]] 174 | name = "prettyplease" 175 | version = "0.2.12" 176 | source = "registry+https://github.com/rust-lang/crates.io-index" 177 | checksum = "6c64d9ba0963cdcea2e1b2230fbae2bab30eb25a174be395c41e764bfb65dd62" 178 | dependencies = [ 179 | "proc-macro2", 180 | "syn", 181 | ] 182 | 183 | [[package]] 184 | name = "proc-macro2" 185 | version = "1.0.66" 186 | source = "registry+https://github.com/rust-lang/crates.io-index" 187 | checksum = "18fb31db3f9bddb2ea821cde30a9f70117e3f119938b5ee630b7403aa6e2ead9" 188 | dependencies = [ 189 | "unicode-ident", 190 | ] 191 | 192 | [[package]] 193 | name = "quote" 194 | version = "1.0.33" 195 | source = "registry+https://github.com/rust-lang/crates.io-index" 196 | checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" 197 | dependencies = [ 198 | "proc-macro2", 199 | ] 200 | 201 | [[package]] 202 | name = "regex" 203 | version = "1.9.3" 204 | source = "registry+https://github.com/rust-lang/crates.io-index" 205 | checksum = "81bc1d4caf89fac26a70747fe603c130093b53c773888797a6329091246d651a" 206 | dependencies = [ 207 | "aho-corasick", 208 | "memchr", 209 | "regex-automata", 210 | "regex-syntax", 211 | ] 212 | 213 | [[package]] 214 | name = "regex-automata" 215 | version = "0.3.6" 216 | source = "registry+https://github.com/rust-lang/crates.io-index" 217 | checksum = "fed1ceff11a1dddaee50c9dc8e4938bd106e9d89ae372f192311e7da498e3b69" 218 | dependencies = [ 219 | "aho-corasick", 220 | "memchr", 221 | "regex-syntax", 222 | ] 223 | 224 | [[package]] 225 | name = "regex-syntax" 226 | version = "0.7.4" 227 | source = "registry+https://github.com/rust-lang/crates.io-index" 228 | checksum = "e5ea92a5b6195c6ef2a0295ea818b312502c6fc94dde986c5553242e18fd4ce2" 229 | 230 | [[package]] 231 | name = "rustc-hash" 232 | version = "1.1.0" 233 | source = "registry+https://github.com/rust-lang/crates.io-index" 234 | checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" 235 | 236 | [[package]] 237 | name = "shlex" 238 | version = "1.1.0" 239 | source = "registry+https://github.com/rust-lang/crates.io-index" 240 | checksum = "43b2853a4d09f215c24cc5489c992ce46052d359b5109343cbafbf26bc62f8a3" 241 | 242 | [[package]] 243 | name = "syn" 244 | version = "2.0.29" 245 | source = "registry+https://github.com/rust-lang/crates.io-index" 246 | checksum = "c324c494eba9d92503e6f1ef2e6df781e78f6a7705a0202d9801b198807d518a" 247 | dependencies = [ 248 | "proc-macro2", 249 | "quote", 250 | "unicode-ident", 251 | ] 252 | 253 | [[package]] 254 | name = "unicode-ident" 255 | version = "1.0.11" 256 | source = "registry+https://github.com/rust-lang/crates.io-index" 257 | checksum = "301abaae475aa91687eb82514b328ab47a211a533026cb25fc3e519b86adfc3c" 258 | 259 | [[package]] 260 | name = "which" 261 | version = "4.4.0" 262 | source = "registry+https://github.com/rust-lang/crates.io-index" 263 | checksum = "2441c784c52b289a054b7201fc93253e288f094e2f4be9058343127c4226a269" 264 | dependencies = [ 265 | "either", 266 | "libc", 267 | "once_cell", 268 | ] 269 | 270 | [[package]] 271 | name = "winapi" 272 | version = "0.3.9" 273 | source = "registry+https://github.com/rust-lang/crates.io-index" 274 | checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" 275 | dependencies = [ 276 | "winapi-i686-pc-windows-gnu", 277 | "winapi-x86_64-pc-windows-gnu", 278 | ] 279 | 280 | [[package]] 281 | name = "winapi-i686-pc-windows-gnu" 282 | version = "0.4.0" 283 | source = "registry+https://github.com/rust-lang/crates.io-index" 284 | checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" 285 | 286 | [[package]] 287 | name = "winapi-x86_64-pc-windows-gnu" 288 | version = "0.4.0" 289 | source = "registry+https://github.com/rust-lang/crates.io-index" 290 | checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" 291 | -------------------------------------------------------------------------------- /Cargo.toml: -------------------------------------------------------------------------------- 1 | [package] 2 | name = "llama-rs" 3 | version = "0.1.0" 4 | edition = "2021" 5 | 6 | # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html 7 | 8 | [dependencies] 9 | anyhow = "1" 10 | 11 | [build-dependencies] 12 | bindgen = "0.66" 13 | cc = "1.0" 14 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # llama-rs 2 | 3 | Rust bindings to llama.cpp, for macOS, with metal support, for testing and 4 | evaluating whether it would be worthwhile to run an Llama model locally in 5 | a Rust app. 6 | 7 | ### Setup 8 | 9 | 1. Clone llama.cpp into `vendor/llama.cpp` 10 | 2. Build llama.cpp: `LLAMA_METAL=1 make` 11 | 3. Download a llama2 model: https://huggingface.co/TheBloke/Llama-2-7B-GGML/tree/main 12 | 4. Convert the model to llama.cpp's GGUF format using the script in the `llama.cpp` repo. 13 | -------------------------------------------------------------------------------- /build.rs: -------------------------------------------------------------------------------- 1 | use bindgen::Builder; 2 | use std::{env, fs, path::PathBuf}; 3 | 4 | fn main() { 5 | let out_path = PathBuf::from(env::var("OUT_DIR").unwrap()); 6 | Builder::default() 7 | .clang_args(["-D", "LLAMA_METAL=1", "-xc++"]) 8 | .header("./vendor/llama.cpp/llama.h") 9 | .layout_tests(false) 10 | .generate() 11 | .expect("failed to generate bindings") 12 | .write_to_file(out_path.join("bindings.rs")) 13 | // .write_to_file("src/bindings.rs") 14 | .expect("failed to write bindings"); 15 | 16 | fs::create_dir_all("target/debug").unwrap(); 17 | fs::copy( 18 | "./vendor/llama.cpp/ggml-metal.metal", 19 | "./target/debug/ggml-metal.metal", 20 | ) 21 | .unwrap(); 22 | 23 | println!("cargo:rustc-link-lib=framework=System"); 24 | println!("cargo:rustc-link-lib=framework=Metal"); 25 | println!("cargo:rustc-link-lib=framework=MetalKit"); 26 | println!("cargo:rustc-link-lib=framework=Accelerate"); 27 | println!("cargo:rustc-link-lib=objc"); 28 | println!("cargo:rustc-link-lib=framework=Foundation"); 29 | 30 | cc::Build::new() 31 | .cpp(true) 32 | .object("./vendor/llama.cpp/common.o") 33 | .object("./vendor/llama.cpp/ggml-alloc.o") 34 | .object("./vendor/llama.cpp/ggml-metal.o") 35 | .object("./vendor/llama.cpp/ggml.o") 36 | .object("./vendor/llama.cpp/k_quants.o") 37 | .object("./vendor/llama.cpp/llama.o") 38 | .compile("binding"); 39 | } 40 | -------------------------------------------------------------------------------- /src/bindings.rs: -------------------------------------------------------------------------------- 1 | #![allow(non_upper_case_globals)] 2 | #![allow(non_camel_case_types)] 3 | #![allow(non_snake_case)] 4 | 5 | include!(concat!(env!("OUT_DIR"), "/bindings.rs")); 6 | -------------------------------------------------------------------------------- /src/lib.rs: -------------------------------------------------------------------------------- 1 | use anyhow::{anyhow, Result}; 2 | use std::{ 3 | ffi::{c_char, CString}, 4 | path::Path, 5 | }; 6 | 7 | mod bindings; 8 | use bindings as c; 9 | 10 | static INIT: std::sync::Once = std::sync::Once::new(); 11 | 12 | pub struct Model(*mut c::llama_model); 13 | pub struct Context(*mut c::llama_context); 14 | pub type Token = c::llama_token; 15 | 16 | #[derive(Copy, Clone)] 17 | pub struct Params(c::llama_context_params); 18 | 19 | impl Params { 20 | pub fn new() -> Self { 21 | Self(unsafe { c::llama_context_default_params() }) 22 | } 23 | 24 | pub fn gpu_layers(mut self, count: usize) -> Self { 25 | self.0.n_gpu_layers = count as i32; 26 | self 27 | } 28 | 29 | pub fn batch_size(mut self, size: usize) -> Self { 30 | self.0.n_batch = size as i32; 31 | self 32 | } 33 | 34 | pub fn embedding_only(mut self) -> Self { 35 | self.0.embedding = true; 36 | self 37 | } 38 | 39 | pub fn context_size(mut self, size: usize) -> Self { 40 | self.0.n_ctx = size as i32; 41 | self 42 | } 43 | } 44 | 45 | impl Default for Params { 46 | fn default() -> Self { 47 | Self::new() 48 | } 49 | } 50 | 51 | impl Model { 52 | pub fn new(path: &Path, params: Params) -> Result { 53 | unsafe { 54 | INIT.call_once(|| c::llama_backend_init(true)); 55 | 56 | let path = CString::new(path.to_str().ok_or_else(|| anyhow!("invalid path"))?)?; 57 | let ptr = c::llama_load_model_from_file(path.as_ptr(), params.0); 58 | if ptr.is_null() { 59 | Err(anyhow::anyhow!("Failed to load model")) 60 | } else { 61 | Ok(Self(ptr)) 62 | } 63 | } 64 | } 65 | 66 | pub fn model_type(&self) -> String { 67 | unsafe { 68 | let mut result = vec![0u8; 256]; 69 | let len = c::llama_model_type(self.0, result.as_ptr() as *mut c_char, result.len()); 70 | result.truncate(len.max(0) as usize); 71 | String::from_utf8(result).unwrap_or(String::new()) 72 | } 73 | } 74 | } 75 | 76 | impl Drop for Model { 77 | fn drop(&mut self) { 78 | unsafe { 79 | c::llama_free_model(self.0); 80 | } 81 | } 82 | } 83 | 84 | impl Context { 85 | pub fn new(model: &Model, params: Params) -> Result { 86 | unsafe { 87 | let ptr = c::llama_new_context_with_model(model.0, params.0); 88 | if ptr.is_null() { 89 | Err(anyhow!("failed to create context")) 90 | } else { 91 | Ok(Self(ptr)) 92 | } 93 | } 94 | } 95 | 96 | pub fn tokenize(&self, text: &str, output: &mut [Token], is_start: bool) -> Result { 97 | let text = CString::new(text)?; 98 | unsafe { 99 | let len = c::llama_tokenize( 100 | self.0, 101 | text.as_ptr(), 102 | output.as_mut_ptr(), 103 | output.len() as i32, 104 | is_start, 105 | ); 106 | if len > 0 { 107 | Ok(len as usize) 108 | } else { 109 | Err(anyhow!("failed to tokenize")) 110 | } 111 | } 112 | } 113 | 114 | pub fn eval(&mut self, tokens: &[Token], n_past: usize, n_threads: usize) -> Result<()> { 115 | unsafe { 116 | let code = c::llama_eval( 117 | self.0, 118 | tokens.as_ptr(), 119 | tokens.len() as i32, 120 | n_past as i32, 121 | n_threads as i32, 122 | ); 123 | if code == 0 { 124 | Ok(()) 125 | } else { 126 | Err(anyhow!("failed to eval")) 127 | } 128 | } 129 | } 130 | 131 | pub fn embeddings(&self) -> &[f32] { 132 | unsafe { 133 | let len = c::llama_n_embd(self.0); 134 | let ptr = c::llama_get_embeddings(self.0); 135 | std::slice::from_raw_parts(ptr, len as usize) 136 | } 137 | } 138 | } 139 | 140 | impl Drop for Context { 141 | fn drop(&mut self) { 142 | unsafe { 143 | c::llama_free(self.0); 144 | } 145 | } 146 | } 147 | -------------------------------------------------------------------------------- /src/main.rs: -------------------------------------------------------------------------------- 1 | use llama_rs::{Context, Model, Params, Token}; 2 | use std::{fs, time::Instant}; 3 | 4 | fn main() { 5 | let params = Params::new() 6 | .embedding_only() 7 | .gpu_layers(16) 8 | .batch_size(1024) 9 | .context_size(2048); 10 | 11 | let model = Model::new("../llama.cpp/models/llama-2-7b.gguf".as_ref(), params).unwrap(); 12 | let mut llama = Context::new(&model, params).unwrap(); 13 | let mut tokens = vec![0; 2048]; 14 | 15 | let filenames = std::env::args().skip(1); 16 | for filename in filenames { 17 | let text = fs::read_to_string(&filename).expect("failed to read file"); 18 | 19 | for _ in 0..10 { 20 | let t0 = Instant::now(); 21 | let embeddings = &get_embeddings(&mut llama, &mut tokens, &text)[0..5]; 22 | println!("{filename:?}: time: {:?}, {embeddings:?}", t0.elapsed()); 23 | } 24 | } 25 | } 26 | 27 | fn get_embeddings<'a>( 28 | llama: &'a mut Context, 29 | token_buffer: &mut Vec, 30 | text: &str, 31 | ) -> &'a [f32] { 32 | let len = llama.tokenize(text, token_buffer, true).unwrap(); 33 | token_buffer.truncate(len); 34 | llama.eval(&token_buffer, 0, 1).unwrap(); 35 | llama.embeddings() 36 | } 37 | -------------------------------------------------------------------------------- /vendor/llama.cpp: -------------------------------------------------------------------------------- 1 | ../../llama.cpp --------------------------------------------------------------------------------